@schema
class MetricCuration(SpyglassMixin, dj.Computed):
    """Computed table holding the results of metric-based curation.

    For each `MetricCurationSelection` entry, `make` extracts waveforms,
    computes the requested metrics, derives curation labels and merge groups
    from them, and stores the result in an analysis NWB file. The class
    methods `get_metrics`, `get_labels` and `get_merge_groups` read those
    results back from the file.
    """

    definition = """
    # Results of applying curation based on quality metrics. To do additional curation, insert another row in `CurationV1`
    -> MetricCurationSelection
    ---
    -> AnalysisNwbfile
    object_id: varchar(40) # Object ID for the metrics in NWB file
    """

    def make(self, key):
        """Compute metrics, labels and merge groups for one selection entry.

        Parameters
        ----------
        key : dict
            primary key of the `MetricCurationSelection` entry to process.
            `analysis_file_name` and `object_id` are added to it before it is
            inserted into this table.
        """
        # FETCH
        nwb_file_name = (
            SpikeSortingSelection * MetricCurationSelection & key
        ).fetch1("nwb_file_name")
        waveform_params = (
            WaveformParameters * MetricCurationSelection & key
        ).fetch1("waveform_params")
        metric_params = (
            MetricParameters * MetricCurationSelection & key
        ).fetch1("metric_params")
        label_params, merge_params = (
            MetricCurationParameters * MetricCurationSelection & key
        ).fetch1("label_params", "merge_params")
        sorting_id, curation_id = (MetricCurationSelection & key).fetch1(
            "sorting_id", "curation_id"
        )

        # DO
        # load recording and sorting
        recording = CurationV1.get_recording(
            {"sorting_id": sorting_id, "curation_id": curation_id}
        )
        sorting = CurationV1.get_sorting(
            {"sorting_id": sorting_id, "curation_id": curation_id}
        )

        # extract waveforms
        # "whiten" is popped so that it is not forwarded to
        # `si.extract_waveforms` together with the remaining parameters.
        if waveform_params.pop("whiten", False):
            recording = sp.whiten(recording, dtype=np.float64)

        waveforms_dir = os.path.join(temp_dir, str(key["metric_curation_id"]))
        os.makedirs(waveforms_dir, exist_ok=True)

        logger.info("Extracting waveforms...")
        waveforms = si.extract_waveforms(
            recording=recording,
            sorting=sorting,
            folder=waveforms_dir,
            overwrite=True,
            **waveform_params,
        )

        # compute metrics
        logger.info("Computing metrics...")
        metrics = {
            metric_name: self._compute_metric(
                waveforms, metric_name, **metric_param_dict
            )
            for metric_name, metric_param_dict in metric_params.items()
        }
        # "nn_isolation" is optional: use .get() so that parameter sets
        # without this metric do not fail with a KeyError.
        # Only the first element of each unit's value is kept.
        # NOTE(review): presumably the metric returns a tuple per unit --
        # confirm against the function registered for "nn_isolation".
        nn_isolation = metrics.get("nn_isolation")
        if nn_isolation:
            metrics["nn_isolation"] = {
                unit_id: value[0] for unit_id, value in nn_isolation.items()
            }

        logger.info("Applying curation...")
        labels = self._compute_labels(metrics, label_params)
        merge_groups = self._compute_merge_groups(metrics, merge_params)

        logger.info("Saving to NWB...")
        (
            key["analysis_file_name"],
            key["object_id"],
        ) = _write_metric_curation_to_nwb(
            nwb_file_name, waveforms, metrics, labels, merge_groups
        )

        # INSERT
        AnalysisNwbfile().add(
            nwb_file_name,
            key["analysis_file_name"],
        )
        self.insert1(key)

    @classmethod
    def get_waveforms(cls):
        """Not implemented yet.

        Raises
        ------
        NotImplementedError
            always.
        """
        raise NotImplementedError("get_waveforms is not implemented")

    @staticmethod
    def _load_metric_units(analysis_file_name, object_id):
        """Read the stored metrics object from an analysis NWB file.

        Parameters
        ----------
        analysis_file_name : str
            name of the analysis NWB file, as stored in this table
        object_id : str
            ID of the object within the NWB file that holds the metrics

        Returns
        -------
        units : pandas.DataFrame
            result of calling `to_dataframe()` on the stored object
        """
        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            analysis_file_name
        )
        with pynwb.NWBHDF5IO(
            path=analysis_file_abs_path,
            mode="r",
            load_namespaces=True,
        ) as io:
            nwbf = io.read()
            return nwbf.objects[object_id].to_dataframe()

    @classmethod
    def get_metrics(cls, key: dict):
        """Returns metrics identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration

        Returns
        -------
        metrics : dict
            maps each metric name found in the entry's `metric_params` to a
            dict from unit (index of the stored table) to metric value
        """
        analysis_file_name, object_id, metric_params = (
            cls * MetricCurationSelection * MetricParameters & key
        ).fetch1(
            "analysis_file_name",
            "object_id",
            "metric_params",
        )
        units = cls._load_metric_units(analysis_file_name, object_id)
        return {
            name: dict(zip(units.index, units[name])) for name in metric_params
        }

    @classmethod
    def get_labels(cls, key: dict):
        """Returns curation labels identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration

        Returns
        -------
        labels : dict
            maps each unit (index of the stored table) to the content of its
            "curation_label" column
        """
        analysis_file_name, object_id = (cls & key).fetch1(
            "analysis_file_name", "object_id"
        )
        units = cls._load_metric_units(analysis_file_name, object_id)
        return dict(zip(units.index, units["curation_label"]))

    @classmethod
    def get_merge_groups(cls, key: dict):
        """Returns merge groups identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration

        Returns
        -------
        merge_groups
            the per-unit "merge_groups" column converted by
            `_merge_dict_to_list`
        """
        analysis_file_name, object_id = (cls & key).fetch1(
            "analysis_file_name", "object_id"
        )
        units = cls._load_metric_units(analysis_file_name, object_id)
        merge_group_dict = dict(zip(units.index, units["merge_groups"]))
        return _merge_dict_to_list(merge_group_dict)

    @staticmethod
    def _compute_metric(waveform_extractor, metric_name, **metric_params):
        """Compute a single metric for every unit of a waveform extractor.

        Parameters
        ----------
        waveform_extractor
            object handed to the metric function; its `sorting.get_unit_ids()`
            provides the units for per-unit metrics
        metric_name : str
            key into `_metric_name_to_func`
        **metric_params
            parameters of the metric. "snr", "peak_offset" and "peak_channel"
            require `peak_sign`.

        Returns
        -------
        metric
            for peak-sign metrics, whatever the metric function returns;
            otherwise a dict mapping unit ID to the metric value of that unit

        Raises
        ------
        ValueError
            if a peak-sign metric is requested without `peak_sign`
        """
        metric_func = _metric_name_to_func[metric_name]

        peak_sign_metrics = ["snr", "peak_offset", "peak_channel"]
        if metric_name in peak_sign_metrics:
            if "peak_sign" not in metric_params:
                raise ValueError(
                    f"{peak_sign_metrics} metrics require peak_sign "
                    "to be defined in the metric parameters"
                )
            return metric_func(
                waveform_extractor,
                peak_sign=metric_params.pop("peak_sign"),
                **metric_params,
            )

        # NOTE(review): `metric_params` is not forwarded to per-unit metric
        # functions, so user-supplied parameters are ignored for them. Left
        # unchanged because the signatures of the registered functions are
        # not visible here -- confirm whether they should be passed through.
        return {
            unit_id: metric_func(waveform_extractor, this_unit_id=unit_id)
            for unit_id in waveform_extractor.sorting.get_unit_ids()
        }

    @staticmethod
    def _compute_labels(
        metrics: Dict[str, Dict[str, Union[float, List[float]]]],
        label_params: Dict[str, List[Any]],
    ) -> Dict[str, List[str]]:
        """Computes the labels based on the metric and label parameters.

        Parameters
        ----------
        metrics : dict
            Example: {"snr" : {"1" : 2, "2" : 0.1, "3" : 2.3}}
            This indicates that the values of the "snr" quality metric
            for the units "1", "2", "3" are 2, 0.1, and 2.3, respectively.
        label_params : dict
            One condition of exactly three items per metric:
            [comparison, threshold, labels].
            Example: {"snr" : [">", 1, ["good", "mua"]]}
            This indicates that units with values of the "snr" quality metric
            greater than 1 should be given the labels "good" and "mua".
            Metrics that were not computed are skipped with a warning.

        Returns
        -------
        labels : dict
            Example: {"1" : ["good", "mua"], "2" : [], "3" : ["good", "mua"]}
            Empty if there are no label parameters or no metrics.

        Raises
        ------
        ValueError
            if a condition does not have exactly three items
        """
        if not label_params or not metrics:
            return {}

        # Units are taken from the first metric.
        unit_ids = list(next(iter(metrics.values())))
        labels = {unit_id: [] for unit_id in unit_ids}

        for metric, condition in label_params.items():
            if metric not in metrics:
                logger.warning(
                    f"{metric} not found in quality metrics; skipping"
                )
                continue

            if len(condition) != 3:
                raise ValueError(f"Condition {condition} must be of length 3")

            operator_key, threshold, new_labels = condition
            compare = _comparison_to_function[operator_key]
            metric_values = metrics[metric]
            for unit_id in unit_ids:
                if compare(metric_values[unit_id], threshold):
                    labels[unit_id].extend(new_labels)
        return labels

    @staticmethod
    def _compute_merge_groups(
        metrics: Dict[str, Dict[str, Union[float, List[float]]]],
        merge_params: Dict[str, List[Any]],
    ) -> Dict[str, List[str]]:
        """Identifies units to be merged based on the metrics and merge parameters.

        Parameters
        ----------
        metrics : dict
            Example: {"cosine_similarity" : {
                "1" : {"1" : 1.00, "2" : 0.10, "3": 0.95},
                "2" : {"1" : 0.10, "2" : 1.00, "3": 0.70},
                "3" : {"1" : 0.95, "2" : 0.70, "3": 1.00}
            }}
            This shows the pairwise values of the "cosine_similarity" quality metric
            for the units "1", "2", "3" as a nested dict.
        merge_params : dict
            Example: {"cosine_similarity" : [">", 0.9]}
            This indicates that units with values of the "cosine_similarity" quality metric
            greater than 0.9 should be placed in the same merge group.
            Metrics that were not computed are skipped with a warning.

        Returns
        -------
        merge_groups : dict
            Example: {"1" : ["3"], "2" : [], "3" : ["1"]}
            An empty list is returned when there are no merge parameters or
            no metrics.
        """
        # NOTE(review): the empty case returns a list although the populated
        # case returns a dict; kept as is because the consumer
        # (`_write_metric_curation_to_nwb`) is not visible here.
        if not merge_params or not metrics:
            return []

        # Units are taken from the first metric.
        unit_ids = list(next(iter(metrics.values())))
        merge_groups = {unit_id: [] for unit_id in unit_ids}

        for metric, condition in merge_params.items():
            if metric not in metrics:
                logger.warning(
                    f"{metric} not found in quality metrics; skipping"
                )
                continue

            compare = _comparison_to_function[condition[0]]
            threshold = condition[1]
            pairwise = metrics[metric]
            for unit_id in unit_ids:
                for other_unit_id in unit_ids:
                    if other_unit_id == unit_id:
                        continue
                    if compare(pairwise[unit_id][other_unit_id], threshold):
                        # append, not extend: extend would split a string ID
                        # into characters and fail on an integer ID
                        merge_groups[unit_id].append(other_unit_id)
        return merge_groups