Skip to content

metric_curation.py

MetricCurationSelection

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@schema
class MetricCurationSelection(SpyglassMixin, dj.Manual):
    definition = """
    # Spike sorting and parameters for metric curation. Use `insert_selection` to insert a row into this table.
    metric_curation_id: uuid
    ---
    -> CurationV1
    -> WaveformParameters
    -> MetricParameters
    -> MetricCurationParameters
    """

    @classmethod
    def insert_selection(cls, key: dict):
        """Insert a row into MetricCurationSelection with an
        automatically generated unique metric curation ID as the sole primary key.

        If a row matching ``key`` already exists, a warning is logged, nothing
        is inserted, and the existing row is returned instead.

        Parameters
        ----------
        key : dict
            primary key of CurationV1, WaveformParameters, MetricParameters,
            and MetricCurationParameters. When a new row is inserted, this
            dict is updated in place with the generated ``metric_curation_id``.

        Returns
        -------
        key : dict
            key for the inserted row, or the full existing row if one
            already matched ``key``
        """
        if cls & key:
            logger.warning("This row has already been inserted.")
            return (cls & key).fetch1()
        key["metric_curation_id"] = uuid.uuid4()
        cls.insert1(key, skip_duplicates=True)
        return key

insert_selection(key) classmethod

Insert a row into MetricCurationSelection with an automatically generated unique metric curation ID as the sole primary key.

Parameters:

Name Type Description Default
key dict

primary key of CurationV1, WaveformParameters, MetricParameters, and MetricCurationParameters

required

Returns:

Name Type Description
key dict

key for the inserted row

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def insert_selection(cls, key: dict):
    """Insert a row into MetricCurationSelection with an
    automatically generated unique metric curation ID as the sole primary key.

    If a row matching ``key`` already exists, a warning is logged, nothing
    is inserted, and the existing row is returned instead.

    Parameters
    ----------
    key : dict
        primary key of CurationV1, WaveformParameters, MetricParameters,
        and MetricCurationParameters. When a new row is inserted, this
        dict is updated in place with the generated ``metric_curation_id``.

    Returns
    -------
    key : dict
        key for the inserted row, or the full existing row if one
        already matched ``key``
    """
    if cls & key:
        logger.warning("This row has already been inserted.")
        return (cls & key).fetch1()
    key["metric_curation_id"] = uuid.uuid4()
    cls.insert1(key, skip_duplicates=True)
    return key

MetricCuration

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@schema
class MetricCuration(SpyglassMixin, dj.Computed):
    definition = """
    # Results of applying curation based on quality metrics. To do additional curation, insert another row in `CurationV1`
    -> MetricCurationSelection
    ---
    -> AnalysisNwbfile
    object_id: varchar(40) # Object ID for the metrics in NWB file
    """

    def make(self, key):
        """Compute metrics, labels, and merge groups for one selection entry.

        Loads the curated recording and sorting referenced by
        ``MetricCurationSelection``, extracts waveforms under ``temp_dir``,
        computes every metric listed in the selected ``MetricParameters``,
        derives labels and merge groups from ``MetricCurationParameters``,
        writes everything to an analysis NWB file, and inserts the row.

        Parameters
        ----------
        key : dict
            primary key of MetricCurationSelection
        """
        # FETCH
        nwb_file_name = (
            SpikeSortingSelection * MetricCurationSelection & key
        ).fetch1("nwb_file_name")

        waveform_params = (
            WaveformParameters * MetricCurationSelection & key
        ).fetch1("waveform_params")
        metric_params = (
            MetricParameters * MetricCurationSelection & key
        ).fetch1("metric_params")
        label_params, merge_params = (
            MetricCurationParameters * MetricCurationSelection & key
        ).fetch1("label_params", "merge_params")
        sorting_id, curation_id = (MetricCurationSelection & key).fetch1(
            "sorting_id", "curation_id"
        )

        # DO
        # load recording and sorting
        recording = CurationV1.get_recording(
            {"sorting_id": sorting_id, "curation_id": curation_id}
        )
        sorting = CurationV1.get_sorting(
            {"sorting_id": sorting_id, "curation_id": curation_id}
        )

        # extract waveforms
        # "whiten" is popped so it is not forwarded to si.extract_waveforms.
        if waveform_params.pop("whiten", False):
            recording = sp.whiten(recording, dtype=np.float64)

        waveforms_dir = os.path.join(temp_dir, str(key["metric_curation_id"]))
        os.makedirs(waveforms_dir, exist_ok=True)

        logger.info("Extracting waveforms...")
        waveforms = si.extract_waveforms(
            recording=recording,
            sorting=sorting,
            folder=waveforms_dir,
            overwrite=True,
            **waveform_params,
        )

        # compute metrics
        logger.info("Computing metrics...")
        metrics = {
            metric_name: self._compute_metric(
                waveforms, metric_name, **metric_param_dict
            )
            for metric_name, metric_param_dict in metric_params.items()
        }
        # Each per-unit nn_isolation value is indexable; only its first
        # element is kept. `.get` avoids a KeyError when the selected
        # MetricParameters do not include nn_isolation.
        if metrics.get("nn_isolation"):
            metrics["nn_isolation"] = {
                unit_id: value[0]
                for unit_id, value in metrics["nn_isolation"].items()
            }

        logger.info("Applying curation...")
        labels = self._compute_labels(metrics, label_params)
        merge_groups = self._compute_merge_groups(metrics, merge_params)

        logger.info("Saving to NWB...")
        (
            key["analysis_file_name"],
            key["object_id"],
        ) = _write_metric_curation_to_nwb(
            nwb_file_name, waveforms, metrics, labels, merge_groups
        )

        # INSERT
        AnalysisNwbfile().add(
            nwb_file_name,
            key["analysis_file_name"],
        )
        self.insert1(key)

    @classmethod
    def get_waveforms(cls):
        """Not implemented.

        Raises
        ------
        NotImplementedError
            always
        """
        raise NotImplementedError

    @staticmethod
    def _read_units_dataframe(analysis_file_name: str, object_id: str):
        """Read NWB object ``object_id`` from an analysis file as a DataFrame."""
        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            analysis_file_name
        )
        with pynwb.NWBHDF5IO(
            path=analysis_file_abs_path,
            mode="r",
            load_namespaces=True,
        ) as io:
            nwbf = io.read()
            # Convert while the file is still open.
            return nwbf.objects[object_id].to_dataframe()

    @classmethod
    def get_metrics(cls, key: dict):
        """Returns metrics identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration

        Returns
        -------
        dict
            Maps each metric name in the entry's ``metric_params`` to a dict
            of {unit_id: value} read from the stored units table.
        """
        analysis_file_name, object_id, metric_params = (
            cls * MetricCurationSelection * MetricParameters & key
        ).fetch1("analysis_file_name", "object_id", "metric_params")
        units = cls._read_units_dataframe(analysis_file_name, object_id)
        return {
            name: dict(zip(units.index, units[name])) for name in metric_params
        }

    @classmethod
    def get_labels(cls, key: dict):
        """Returns curation labels identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration

        Returns
        -------
        dict
            Maps each unit to its stored "curation_label" value.
        """
        analysis_file_name, object_id = (cls & key).fetch1(
            "analysis_file_name", "object_id"
        )
        units = cls._read_units_dataframe(analysis_file_name, object_id)
        return dict(zip(units.index, units["curation_label"]))

    @classmethod
    def get_merge_groups(cls, key: dict):
        """Returns merge groups identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration

        Returns
        -------
        The stored per-unit "merge_groups" column converted by
        ``_merge_dict_to_list``.
        """
        analysis_file_name, object_id = (cls & key).fetch1(
            "analysis_file_name", "object_id"
        )
        units = cls._read_units_dataframe(analysis_file_name, object_id)
        merge_group_dict = dict(zip(units.index, units["merge_groups"]))

        return _merge_dict_to_list(merge_group_dict)

    @staticmethod
    def _compute_metric(waveform_extractor, metric_name, **metric_params):
        """Compute one quality metric for all units of a waveform extractor.

        Parameters
        ----------
        waveform_extractor
            waveform extractor whose ``sorting`` provides the unit IDs
        metric_name : str
            key into ``_metric_name_to_func``
        **metric_params
            parameters forwarded to the metric function

        Returns
        -------
        For "snr", "peak_offset", and "peak_channel", whatever the metric
        function returns. For all other metrics, a dict of
        {unit_id: metric value}.

        Raises
        ------
        ValueError
            If a peak-sign metric is requested without ``peak_sign``.
        """
        metric_func = _metric_name_to_func[metric_name]

        peak_sign_metrics = ["snr", "peak_offset", "peak_channel"]
        if metric_name in peak_sign_metrics:
            if "peak_sign" not in metric_params:
                raise ValueError(
                    f"{peak_sign_metrics} metrics require peak_sign "
                    "to be defined in the metric parameters"
                )
            return metric_func(
                waveform_extractor,
                peak_sign=metric_params.pop("peak_sign"),
                **metric_params,
            )

        # Per-unit metrics: forward the configured parameters so the settings
        # stored in MetricParameters actually take effect.
        return {
            unit_id: metric_func(
                waveform_extractor, this_unit_id=unit_id, **metric_params
            )
            for unit_id in waveform_extractor.sorting.get_unit_ids()
        }

    @staticmethod
    def _compute_labels(
        metrics: Dict[str, Dict[str, Union[float, List[float]]]],
        label_params: Dict[str, List[Any]],
    ) -> Dict[str, List[str]]:
        """Computes the labels based on the metric and label parameters.

        Parameters
        ----------
        metrics : dict
            Example: {"snr" : {"1" : 2, "2" : 0.1, "3" : 2.3}}
            This indicates that the values of the "snr" quality metric
            for the units "1", "2", "3" are 2, 0.1, and 2.3, respectively.
            The set of unit IDs is taken from the first metric.

        label_params : dict
            Maps a metric name to one condition
            ``[comparison, threshold, labels]``.
            Example: {"snr" : [">", 1, ["good", "mua"]]}
            This indicates that units with values of the "snr" quality metric
            greater than 1 should be given the labels "good" and "mua".
            Metrics not present in ``metrics`` are skipped with a warning.

        Returns
        -------
        labels : dict
            Example: {"1" : ["good", "mua"], "2" : [], "3" : ["good", "mua"]}
            Empty if ``label_params`` or ``metrics`` is empty.

        Raises
        ------
        ValueError
            If a condition does not have exactly three elements.
        """
        if not label_params or not metrics:
            return {}

        unit_ids = list(next(iter(metrics.values())).keys())
        labels = {unit_id: [] for unit_id in unit_ids}

        for metric, condition in label_params.items():
            if metric not in metrics:
                logger.warning(
                    "%s not found in quality metrics; skipping", metric
                )
                continue

            if len(condition) != 3:
                raise ValueError(f"Condition {condition} must be of length 3")

            comparison, threshold, new_labels = condition
            compare = _comparison_to_function[comparison]
            for unit_id in unit_ids:
                if compare(metrics[metric][unit_id], threshold):
                    labels[unit_id].extend(new_labels)
        return labels

    @staticmethod
    def _compute_merge_groups(
        metrics: Dict[str, Dict[str, Union[float, List[float]]]],
        merge_params: Dict[str, List[Any]],
    ) -> Dict[str, List[str]]:
        """Identifies units to be merged based on the metrics and merge parameters.

        Parameters
        ----------
        metrics : dict
            Example: {"cosine_similarity" : {
                                             "1" : {"1" : 1.00, "2" : 0.10, "3": 0.95},
                                             "2" : {"1" : 0.10, "2" : 1.00, "3": 0.70},
                                             "3" : {"1" : 0.95, "2" : 0.70, "3": 1.00}
                                            }}
            This shows the pairwise values of the "cosine_similarity" quality metric
            for the units "1", "2", "3" as a nested dict.
            The set of unit IDs is taken from the first metric.

        merge_params : dict
            Example: {"cosine_similarity" : [">", 0.9]}
            This indicates that units with values of the "cosine_similarity" quality metric
            greater than 0.9 should be placed in the same merge group.
            Metrics not present in ``metrics`` are skipped with a warning.

        Returns
        -------
        merge_groups : dict
            Example: {"1" : ["3"], "2" : [], "3" : ["1"]}
            Each unit maps to the other unit IDs it should merge with, each
            listed once. Empty if ``merge_params`` or ``metrics`` is empty.
        """
        if not merge_params or not metrics:
            return {}

        unit_ids = list(next(iter(metrics.values())).keys())
        merge_groups = {unit_id: [] for unit_id in unit_ids}
        for metric, condition in merge_params.items():
            if metric not in metrics:
                logger.warning(
                    "%s not found in quality metrics; skipping", metric
                )
                continue
            compare = _comparison_to_function[condition[0]]
            threshold = condition[1]
            pairwise = metrics[metric]
            for unit_id in unit_ids:
                group = merge_groups[unit_id]
                for other_unit_id in unit_ids:
                    # Skip self-pairs and partners already recorded by an
                    # earlier metric so each partner appears once.
                    if other_unit_id == unit_id or other_unit_id in group:
                        continue
                    if compare(pairwise[unit_id][other_unit_id], threshold):
                        # append, not extend: extend would split string IDs
                        # into characters and fail for integer IDs.
                        group.append(other_unit_id)
        return merge_groups

get_metrics(key) classmethod

Returns metrics identified by metric curation

Parameters:

Name Type Description Default
key dict

primary key to MetricCuration

required
Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def get_metrics(cls, key: dict):
    """Returns metrics identified by metric curation

    Parameters
    ----------
    key : dict
        primary key to MetricCuration

    Returns
    -------
    dict
        Maps each metric name in the entry's ``metric_params`` to a dict of
        {unit_id: value} read from the stored units table.
    """
    analysis_file_name, object_id, metric_params = (
        cls * MetricCurationSelection * MetricParameters & key
    ).fetch1("analysis_file_name", "object_id", "metric_params")
    analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
        analysis_file_name
    )
    with pynwb.NWBHDF5IO(
        path=analysis_file_abs_path,
        mode="r",
        load_namespaces=True,
    ) as io:
        nwbf = io.read()
        units = nwbf.objects[object_id].to_dataframe()
    return {
        name: dict(zip(units.index, units[name])) for name in metric_params
    }

get_labels(key) classmethod

Returns curation labels identified by metric curation

Parameters:

Name Type Description Default
key dict

primary key to MetricCuration

required
Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def get_labels(cls, key: dict):
    """Look up the curation labels stored by a MetricCuration entry.

    Parameters
    ----------
    key : dict
        primary key to MetricCuration

    Returns
    -------
    dict
        Maps each unit (the index of the stored units table) to its
        "curation_label" value.
    """
    entry = cls & key
    file_name, nwb_object_id = entry.fetch1("analysis_file_name", "object_id")
    abs_path = AnalysisNwbfile.get_abs_path(file_name)
    with pynwb.NWBHDF5IO(
        path=abs_path, mode="r", load_namespaces=True
    ) as reader:
        units_df = reader.read().objects[nwb_object_id].to_dataframe()
    label_column = units_df["curation_label"]
    return {unit: label for unit, label in zip(units_df.index, label_column)}

get_merge_groups(key) classmethod

Returns merge groups identified by metric curation

Parameters:

Name Type Description Default
key dict

primary key to MetricCuration

required
Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def get_merge_groups(cls, key: dict):
    """Look up the merge groups stored by a MetricCuration entry.

    Parameters
    ----------
    key : dict
        primary key to MetricCuration

    Returns
    -------
    The per-unit "merge_groups" column of the stored units table, converted
    by ``_merge_dict_to_list``.
    """
    entry = cls & key
    file_name, nwb_object_id = entry.fetch1("analysis_file_name", "object_id")
    abs_path = AnalysisNwbfile.get_abs_path(file_name)
    with pynwb.NWBHDF5IO(
        path=abs_path, mode="r", load_namespaces=True
    ) as reader:
        units_df = reader.read().objects[nwb_object_id].to_dataframe()
    per_unit_groups = {
        unit: group
        for unit, group in zip(units_df.index, units_df["merge_groups"])
    }
    return _merge_dict_to_list(per_unit_groups)