Skip to content

metric_curation.py

WaveformParameters

Bases: SpyglassMixin, Lookup

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@schema
class WaveformParameters(SpyglassMixin, dj.Lookup):
    definition = """
    # Parameters for extracting waveforms from the recording based on the sorting.
    waveform_param_name: varchar(80) # name of waveform extraction parameters
    ---
    waveform_params: blob # a dict of waveform extraction parameters
    """

    # Both default entries share every extraction parameter except `whiten`.
    _shared_defaults = {
        "ms_before": 0.5,
        "ms_after": 0.5,
        "max_spikes_per_unit": 5000,
        "n_jobs": 5,
        "total_memory": "5G",
    }

    contents = [
        ["default_not_whitened", dict(_shared_defaults, whiten=False)],
        ["default_whitened", dict(_shared_defaults, whiten=True)],
    ]

    @classmethod
    def insert_default(cls):
        """Insert default waveform parameters."""
        cls.insert(cls.contents, skip_duplicates=True)

insert_default() classmethod

Insert default waveform parameters.

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def insert_default(cls):
    """Insert default waveform parameters."""
    # Inserts the rows defined in `cls.contents`; duplicates are skipped,
    # so repeated calls are safe (idempotent).
    cls.insert(cls.contents, skip_duplicates=True)

MetricParameters

Bases: SpyglassMixin, Lookup

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@schema
class MetricParameters(SpyglassMixin, dj.Lookup):
    definition = """
    # Parameters for computing quality metrics of sorted units.
    metric_param_name: varchar(200)
    ---
    metric_params: blob
    """
    # Name and parameters of the lab's default quality-metric set.
    metric_default_param_name = "franklab_default"
    metric_default_param = {
        "snr": {
            "peak_sign": "neg",
            "random_chunk_kwargs_dict": {
                "num_chunks_per_segment": 20,
                "chunk_size": 10000,
                "seed": 0,
            },
        },
        "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0.0},
        "nn_isolation": {
            "max_spikes": 1000,
            "min_spikes": 10,
            "n_neighbors": 5,
            "n_components": 7,
            "radius_um": 100,
            "seed": 0,
        },
        "nn_noise_overlap": {
            "max_spikes": 1000,
            "min_spikes": 10,
            "n_neighbors": 5,
            "n_components": 7,
            "radius_um": 100,
            "seed": 0,
        },
        "peak_channel": {"peak_sign": "neg"},
        "num_spikes": {},
    }
    contents = [[metric_default_param_name, metric_default_param]]

    @classmethod
    def insert_default(cls):
        """Insert default metric parameters."""
        cls.insert(cls.contents, skip_duplicates=True)

    @classmethod
    def show_available_metrics(cls):
        """Prints the available metrics and their descriptions."""
        # BUG FIX: the first parameter of this classmethod was named `self`;
        # renamed to `cls` per convention.
        for metric_name, metric_func in _metric_name_to_func.items():
            # Use the first docstring line as a short description; guard
            # against metric functions that lack a docstring.
            metric_doc = (metric_func.__doc__ or "").split("\n")[0]
            logger.info(f"{metric_name} : {metric_doc}\n")

insert_default() classmethod

Insert default metric parameters.

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def insert_default(cls):
    """Insert default metric parameters."""
    # Inserts the rows defined in `cls.contents`; duplicates are skipped,
    # so repeated calls are safe (idempotent).
    cls.insert(cls.contents, skip_duplicates=True)

show_available_metrics() classmethod

Prints the available metrics and their descriptions.

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def show_available_metrics(cls):
    """Prints the available metrics and their descriptions."""
    # BUG FIX: the first parameter of this classmethod was named `self`;
    # renamed to `cls` per convention.
    for metric in _metric_name_to_func:
        # First docstring line serves as the metric's short description.
        metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0]
        logger.info(f"{metric} : {metric_doc}\n")

MetricCurationParameters

Bases: SpyglassMixin, Lookup

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@schema
class MetricCurationParameters(SpyglassMixin, dj.Lookup):
    definition = """
    # Parameters for curating a spike sorting based on the metrics.
    metric_curation_param_name: varchar(200)
    ---
    label_params: blob   # dict of param to label units
    merge_params: blob   # dict of param to merge units
    """

    # Each entry: [name, label_params, merge_params]. The default labels
    # units whose nn_noise_overlap exceeds 0.1 as noise/reject; "none"
    # applies no labeling or merging.
    _default_label_params = {
        "nn_noise_overlap": [">", 0.1, ["noise", "reject"]]
    }
    contents = [
        ["default", _default_label_params, {}],
        ["none", {}, {}],
    ]

    @classmethod
    def insert_default(cls):
        """Insert the default metric-curation parameter rows (idempotent)."""
        cls.insert(cls.contents, skip_duplicates=True)

insert_default() classmethod

Insert default metric curation parameters.

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def insert_default(cls):
    """Insert default metric curation parameters."""
    # Inserts the rows defined in `cls.contents`; duplicates are skipped,
    # so repeated calls are safe (idempotent).
    cls.insert(cls.contents, skip_duplicates=True)

MetricCurationSelection

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@schema
class MetricCurationSelection(SpyglassMixin, dj.Manual):
    definition = """
    # Spike sorting and parameters for metric curation. Use `insert_selection` to insert a row into this table.
    metric_curation_id: uuid
    ---
    -> CurationV1
    -> WaveformParameters
    -> MetricParameters
    -> MetricCurationParameters
    """

    @classmethod
    def insert_selection(cls, key: dict):
        """Insert a row into MetricCurationSelection with an
        automatically generated unique metric curation ID as the sole primary key.

        Parameters
        ----------
        key : dict
            primary key of CurationV1, WaveformParameters, MetricParameters,
            and MetricCurationParameters

        Returns
        -------
        key : dict
            key for the inserted row (the input dict, updated in place with
            the new ``metric_curation_id``)
        """
        if cls & key:
            # An identical selection already exists; return it rather than
            # inserting a second row under a new UUID.
            logger.warning("This row has already been inserted.")
            return (cls & key).fetch1()
        key["metric_curation_id"] = uuid.uuid4()
        # A freshly generated UUID cannot collide with an existing primary
        # key, so `skip_duplicates` is effectively a no-op here; kept for
        # backward compatibility.
        cls.insert1(key, skip_duplicates=True)
        return key

insert_selection(key) classmethod

Insert a row into MetricCurationSelection with an automatically generated unique metric curation ID as the sole primary key.

Parameters:

Name Type Description Default
key dict

primary key of CurationV1, WaveformParameters, MetricParameters, and MetricCurationParameters

required

Returns:

Name Type Description
key dict

key for the inserted row

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def insert_selection(cls, key: dict):
    """Insert a row into MetricCurationSelection with an
    automatically generated unique metric curation ID as the sole primary key.

    Parameters
    ----------
    key : dict
        primary key of CurationV1, WaveformParameters, MetricParameters,
        and MetricCurationParameters

    Returns
    -------
    key : dict
        key for the inserted row (the input dict, updated in place with
        the new ``metric_curation_id``)
    """
    if cls & key:
        # An identical selection already exists; return it rather than
        # inserting a second row under a new UUID.
        logger.warning("This row has already been inserted.")
        return (cls & key).fetch1()
    key["metric_curation_id"] = uuid.uuid4()
    cls.insert1(key, skip_duplicates=True)
    return key

MetricCuration

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@schema
class MetricCuration(SpyglassMixin, dj.Computed):
    definition = """
    # Results of applying curation based on quality metrics. To do additional curation, insert another row in `CurationV1`
    -> MetricCurationSelection
    ---
    -> AnalysisNwbfile
    object_id: varchar(40) # Object ID for the metrics in NWB file
    """

    # Skip the outer populate transaction (long-running job) but still allow
    # direct inserts from `make`.
    _use_transaction, _allow_insert = False, True

    def make(self, key):
        """Populate MetricCuration table.

        1. Fetches...
            - Waveform parameters from WaveformParameters
            - Metric parameters from MetricParameters
            - Label and merge parameters from MetricCurationParameters
            - Sorting ID and curation ID from MetricCurationSelection
        2. Loads the recording and sorting from CurationV1.
        3. Optionally whitens the recording with spikeinterface
        4. Extracts waveforms from the recording based on the sorting.
        5. Optionally computes quality metrics for the units.
        6. Applies curation based on the metrics, computing labels and merge
            groups.
        7. Saves the waveforms, metrics, labels, and merge groups to an
            analysis NWB file and inserts into MetricCuration table.
        """

        # Record creation start time for provenance logging (see .log below).
        AnalysisNwbfile()._creation_times["pre_create_time"] = time()
        # FETCH
        nwb_file_name = (
            SpikeSortingSelection * MetricCurationSelection & key
        ).fetch1("nwb_file_name")

        # TODO: reduce fetch calls on same tables
        waveform_params = (
            WaveformParameters * MetricCurationSelection & key
        ).fetch1("waveform_params")
        metric_params = (
            MetricParameters * MetricCurationSelection & key
        ).fetch1("metric_params")
        label_params, merge_params = (
            MetricCurationParameters * MetricCurationSelection & key
        ).fetch1("label_params", "merge_params")
        sorting_id, curation_id = (MetricCurationSelection & key).fetch1(
            "sorting_id", "curation_id"
        )
        # DO
        # load recording and sorting
        recording = CurationV1.get_recording(
            {"sorting_id": sorting_id, "curation_id": curation_id}
        )
        sorting = CurationV1.get_sorting(
            {"sorting_id": sorting_id, "curation_id": curation_id}
        )
        # extract waveforms
        # `whiten` is popped so it is not forwarded to extract_waveforms.
        if waveform_params.pop("whiten", False):
            recording = sp.whiten(recording, dtype=np.float64)

        # Waveforms are written to a per-curation subdirectory of temp_dir.
        waveforms_dir = temp_dir + "/" + str(key["metric_curation_id"])
        os.makedirs(waveforms_dir, exist_ok=True)

        logger.info("Extracting waveforms...")

        # Extract non-sparse waveforms by default
        waveform_params.setdefault("sparse", False)

        waveforms = si.extract_waveforms(
            recording=recording,
            sorting=sorting,
            folder=waveforms_dir,
            overwrite=True,
            **waveform_params,
        )
        # compute metrics
        logger.info("Computing metrics...")
        metrics = {
            metric_name: self._compute_metric(
                waveforms, metric_name, **metric_param_dict
            )
            for metric_name, metric_param_dict in metric_params.items()
        }
        # BUG FIX: use .get() — the original indexed metrics["nn_isolation"]
        # unconditionally and raised KeyError when that metric was not among
        # the requested metric_params.
        if metrics.get("nn_isolation"):
            # nn_isolation values are per-unit sequences; keep only the
            # first element (the isolation score).
            metrics["nn_isolation"] = {
                unit_id: value[0]
                for unit_id, value in metrics["nn_isolation"].items()
            }

        logger.info("Applying curation...")
        labels = self._compute_labels(metrics, label_params)
        merge_groups = self._compute_merge_groups(metrics, merge_params)

        logger.info("Saving to NWB...")
        (
            key["analysis_file_name"],
            key["object_id"],
        ) = _write_metric_curation_to_nwb(
            nwb_file_name, waveforms, metrics, labels, merge_groups
        )

        # INSERT
        AnalysisNwbfile().add(
            nwb_file_name,
            key["analysis_file_name"],
        )
        AnalysisNwbfile().log(key, table=self.full_table_name)
        self.insert1(key)

    @classmethod
    def get_waveforms(cls):
        """Return waveforms identified by metric curation.

        Raises
        ------
        NotImplementedError
            Always; this method is not yet implemented.
        """
        # BUG FIX: the original *returned* the NotImplementedError class
        # (a truthy value) instead of raising it.
        raise NotImplementedError

    @classmethod
    def _fetch_units_df(cls, analysis_file_name: str, object_id: str):
        """Read the curation units table from an analysis NWB file.

        Returns the NWB object identified by ``object_id`` as a pandas
        DataFrame indexed by unit ID. Shared by the get_* accessors below.
        """
        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            analysis_file_name
        )
        with pynwb.NWBHDF5IO(
            path=analysis_file_abs_path,
            mode="r",
            load_namespaces=True,
        ) as io:
            nwbf = io.read()
            # to_dataframe() copies the data into memory, so the result
            # remains valid after the file is closed.
            return nwbf.objects[object_id].to_dataframe()

    @classmethod
    def get_metrics(cls, key: dict):
        """Returns metrics identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration
        """
        # NOTE: the original also fetched "metric_param_name" but never
        # used it; that fetch has been dropped.
        analysis_file_name, object_id, metric_params = (
            cls * MetricCurationSelection * MetricParameters & key
        ).fetch1("analysis_file_name", "object_id", "metric_params")
        units = cls._fetch_units_df(analysis_file_name, object_id)
        # One {unit_id: value} mapping per requested metric name.
        return {
            name: dict(zip(units.index, units[name])) for name in metric_params
        }

    @classmethod
    def get_labels(cls, key: dict):
        """Returns curation labels identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration
        """
        analysis_file_name, object_id = (cls & key).fetch1(
            "analysis_file_name", "object_id"
        )
        units = cls._fetch_units_df(analysis_file_name, object_id)
        return dict(zip(units.index, units["curation_label"]))

    @classmethod
    def get_merge_groups(cls, key: dict):
        """Returns merge groups identified by metric curation

        Parameters
        ----------
        key : dict
            primary key to MetricCuration
        """
        analysis_file_name, object_id = (cls & key).fetch1(
            "analysis_file_name", "object_id"
        )
        units = cls._fetch_units_df(analysis_file_name, object_id)
        merge_group_dict = dict(zip(units.index, units["merge_groups"]))

        return _merge_dict_to_list(merge_group_dict)

    @staticmethod
    def _compute_metric(waveform_extractor, metric_name, **metric_params):
        """Compute one quality metric for all units.

        Metrics listed in ``peak_sign_metrics`` are computed in a single
        call over all units and require ``peak_sign`` in the parameters;
        all other metrics are computed one unit at a time.
        """
        metric_func = _metric_name_to_func[metric_name]

        peak_sign_metrics = ["snr", "peak_offset", "peak_channel"]
        if metric_name in peak_sign_metrics:
            if "peak_sign" not in metric_params:
                raise Exception(
                    f"{peak_sign_metrics} metrics require peak_sign",
                    "to be defined in the metric parameters",
                )
            return metric_func(
                waveform_extractor,
                peak_sign=metric_params.pop("peak_sign"),
                **metric_params,
            )

        return {
            unit_id: metric_func(waveform_extractor, this_unit_id=unit_id)
            for unit_id in waveform_extractor.sorting.get_unit_ids()
        }

    @staticmethod
    def _compute_labels(
        metrics: Dict[str, Dict[str, Union[float, List[float]]]],
        label_params: Dict[str, List[Any]],
    ) -> Dict[str, List[str]]:
        """Computes the labels based on the metric and label parameters.

        Parameters
        ----------
        metrics : dict
            Example: {"snr" : {"1" : 2, "2" : 0.1, "3" : 2.3}}
            This indicates that the values of the "snr" quality metric
            for the units "1", "2", "3" are 2, 0.1, and 2.3, respectively.

        label_params : dict
            Example: {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]}
            Each metric maps to a single three-element condition
            [comparison, threshold, labels]: units whose metric value
            satisfies the comparison against the threshold receive the
            listed labels. (The previous docstring showed a list of
            several conditions per metric, which the code does not
            support and would reject with ValueError.)

        Returns
        -------
        labels : dict
            Example: {"1" : ["noise", "reject"], "2" : [], "3" : []}

        """
        if not label_params or not metrics:
            return {}

        # Unit IDs are taken from the first metric's entries; all metrics
        # are assumed to cover the same set of units.
        unit_ids = list(next(iter(metrics.values())).keys())
        labels = {unit_id: [] for unit_id in unit_ids}

        for metric, condition in label_params.items():
            if metric not in metrics:
                # BUG FIX: the original instantiated Warning(...) without
                # raising or logging it, a silent no-op.
                logger.warning(
                    f"{metric} not found in quality metrics; skipping"
                )
                continue

            if len(condition) != 3:
                raise ValueError(f"Condition {condition} must be of length 3")

            compare = _comparison_to_function[condition[0]]
            for unit_id in unit_ids:
                if compare(metrics[metric][unit_id], condition[1]):
                    labels[unit_id].extend(condition[2])
        return labels

    @staticmethod
    def _compute_merge_groups(
        metrics: Dict[str, Dict[str, Union[float, List[float]]]],
        merge_params: Dict[str, List[Any]],
    ) -> Dict[str, List[str]]:
        """Identifies units to be merged based on the metrics and merge parameters.

        Parameters
        ---------
        metrics : dict
            Example: {"cosine_similarity" : {
                                             "1" : {"1" : 1.00, "2" : 0.10, "3": 0.95},
                                             "2" : {"1" : 0.10, "2" : 1.00, "3": 0.70},
                                             "3" : {"1" : 0.95, "2" : 0.70, "3": 1.00}
                                            }}
            This shows the pairwise values of the "cosine_similarity" quality metric
            for the units "1", "2", "3" as a nested dict.

        merge_params : dict
            Example: {"cosine_similarity" : [">", 0.9]}
            This indicates that units with values of the "cosine_similarity" quality metric
            greater than 0.9 should be placed in the same merge group.


        Returns
        -------
        merge_groups : dict
            Example: {"1" : ["3"], "2" : [], "3" : ["1"]}

        """

        if not merge_params:
            # NOTE(review): an empty list is returned here although the
            # annotation and the non-empty path produce a dict; preserved
            # for backward compatibility with existing callers.
            return []

        if not metrics:
            return {}

        # Unit IDs from the first metric; all metrics are assumed to cover
        # the same set of units.
        unit_ids = list(next(iter(metrics.values())).keys())
        merge_groups = {unit_id: [] for unit_id in unit_ids}
        for metric in merge_params:
            if metric not in metrics:
                # BUG FIX: Warning(...) was instantiated but never raised
                # or logged; log it instead.
                logger.warning(
                    f"{metric} not found in quality metrics; skipping"
                )
                continue
            compare = _comparison_to_function[merge_params[metric][0]]
            threshold = merge_params[metric][1]
            for unit_id in unit_ids:
                for other_unit_id in unit_ids:
                    if other_unit_id == unit_id:
                        continue
                    if compare(
                        metrics[metric][unit_id][other_unit_id],
                        threshold,
                    ):
                        # BUG FIX: the original used list.extend(), which
                        # splits a string unit ID into characters (and
                        # raises TypeError for int IDs); append the whole
                        # ID instead.
                        merge_groups[unit_id].append(other_unit_id)
        return merge_groups

make(key)

Populate MetricCuration table.

  1. Fetches...
    • Waveform parameters from WaveformParameters
    • Metric parameters from MetricParameters
    • Label and merge parameters from MetricCurationParameters
    • Sorting ID and curation ID from MetricCurationSelection
  2. Loads the recording and sorting from CurationV1.
  3. Optionally whitens the recording with spikeinterface
  4. Extracts waveforms from the recording based on the sorting.
  5. Optionally computes quality metrics for the units.
  6. Applies curation based on the metrics, computing labels and merge groups.
  7. Saves the waveforms, metrics, labels, and merge groups to an analysis NWB file and inserts into MetricCuration table.
Source code in src/spyglass/spikesorting/v1/metric_curation.py
def make(self, key):
    """Populate MetricCuration table.

    1. Fetches...
        - Waveform parameters from WaveformParameters
        - Metric parameters from MetricParameters
        - Label and merge parameters from MetricCurationParameters
        - Sorting ID and curation ID from MetricCurationSelection
    2. Loads the recording and sorting from CurationV1.
    3. Optionally whitens the recording with spikeinterface
    4. Extracts waveforms from the recording based on the sorting.
    5. Optionally computes quality metrics for the units.
    6. Applies curation based on the metrics, computing labels and merge
        groups.
    7. Saves the waveforms, metrics, labels, and merge groups to an
        analysis NWB file and inserts into MetricCuration table.
    """

    # Record creation start time for provenance logging (see .log below).
    AnalysisNwbfile()._creation_times["pre_create_time"] = time()
    # FETCH
    nwb_file_name = (
        SpikeSortingSelection * MetricCurationSelection & key
    ).fetch1("nwb_file_name")

    # TODO: reduce fetch calls on same tables
    waveform_params = (
        WaveformParameters * MetricCurationSelection & key
    ).fetch1("waveform_params")
    metric_params = (
        MetricParameters * MetricCurationSelection & key
    ).fetch1("metric_params")
    label_params, merge_params = (
        MetricCurationParameters * MetricCurationSelection & key
    ).fetch1("label_params", "merge_params")
    sorting_id, curation_id = (MetricCurationSelection & key).fetch1(
        "sorting_id", "curation_id"
    )
    # DO
    # load recording and sorting
    recording = CurationV1.get_recording(
        {"sorting_id": sorting_id, "curation_id": curation_id}
    )
    sorting = CurationV1.get_sorting(
        {"sorting_id": sorting_id, "curation_id": curation_id}
    )
    # extract waveforms
    # `whiten` is popped here so it is not forwarded to extract_waveforms.
    if "whiten" in waveform_params:
        if waveform_params.pop("whiten"):
            recording = sp.whiten(recording, dtype=np.float64)

    # Waveforms are written to a per-curation subdirectory of temp_dir.
    waveforms_dir = temp_dir + "/" + str(key["metric_curation_id"])
    os.makedirs(waveforms_dir, exist_ok=True)

    logger.info("Extracting waveforms...")

    # Extract non-sparse waveforms by default
    waveform_params.setdefault("sparse", False)

    waveforms = si.extract_waveforms(
        recording=recording,
        sorting=sorting,
        folder=waveforms_dir,
        overwrite=True,
        **waveform_params,
    )
    # compute metrics
    logger.info("Computing metrics...")
    metrics = {}
    for metric_name, metric_param_dict in metric_params.items():
        metrics[metric_name] = self._compute_metric(
            waveforms, metric_name, **metric_param_dict
        )
    # NOTE(review): this indexes "nn_isolation" unconditionally and raises
    # KeyError when that metric is absent from metric_params — confirm all
    # parameter sets include it, or guard with metrics.get().
    if metrics["nn_isolation"]:
        # nn_isolation values are per-unit sequences; keep only the first
        # element of each.
        metrics["nn_isolation"] = {
            unit_id: value[0]
            for unit_id, value in metrics["nn_isolation"].items()
        }

    logger.info("Applying curation...")
    labels = self._compute_labels(metrics, label_params)
    merge_groups = self._compute_merge_groups(metrics, merge_params)

    logger.info("Saving to NWB...")
    (
        key["analysis_file_name"],
        key["object_id"],
    ) = _write_metric_curation_to_nwb(
        nwb_file_name, waveforms, metrics, labels, merge_groups
    )

    # INSERT
    AnalysisNwbfile().add(
        nwb_file_name,
        key["analysis_file_name"],
    )
    AnalysisNwbfile().log(key, table=self.full_table_name)
    self.insert1(key)

get_waveforms() classmethod

Returns waveforms identified by metric curation. Not implemented.

Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def get_waveforms(cls):
    """Return waveforms identified by metric curation.

    Raises
    ------
    NotImplementedError
        Always; this method is not yet implemented.
    """
    # BUG FIX: the original *returned* the NotImplementedError class
    # (a truthy value) instead of raising it.
    raise NotImplementedError

get_metrics(key) classmethod

Returns metrics identified by metric curation

Parameters:

Name Type Description Default
key dict

primary key to MetricCuration

required
Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def get_metrics(cls, key: dict):
    """Returns metrics identified by metric curation

    Parameters
    ----------
    key : dict
        primary key to MetricCuration
    """
    # NOTE(review): "metric_param_name" is fetched but never used below.
    analysis_file_name, object_id, metric_param_name, metric_params = (
        cls * MetricCurationSelection * MetricParameters & key
    ).fetch1(
        "analysis_file_name",
        "object_id",
        "metric_param_name",
        "metric_params",
    )
    analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
        analysis_file_name
    )
    with pynwb.NWBHDF5IO(
        path=analysis_file_abs_path,
        mode="r",
        load_namespaces=True,
    ) as io:
        nwbf = io.read()
        # to_dataframe() copies the data into memory, so `units` remains
        # valid after the NWB file is closed.
        units = nwbf.objects[object_id].to_dataframe()
    # One {unit_id: value} mapping per requested metric name.
    return {
        name: dict(zip(units.index, units[name])) for name in metric_params
    }

get_labels(key) classmethod

Returns curation labels identified by metric curation

Parameters:

Name Type Description Default
key dict

primary key to MetricCuration

required
Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def get_labels(cls, key: dict):
    """Return the curation label(s) for each unit found by metric curation.

    Parameters
    ----------
    key : dict
        primary key to MetricCuration
    """
    analysis_file_name, object_id = (cls & key).fetch1(
        "analysis_file_name", "object_id"
    )
    nwb_path = AnalysisNwbfile.get_abs_path(analysis_file_name)
    # to_dataframe() materializes the data, so the frame stays valid
    # after the NWB file is closed.
    with pynwb.NWBHDF5IO(
        path=nwb_path, mode="r", load_namespaces=True
    ) as io:
        units_df = io.read().objects[object_id].to_dataframe()
    return {
        unit_id: label
        for unit_id, label in zip(units_df.index, units_df["curation_label"])
    }

get_merge_groups(key) classmethod

Returns merge groups identified by metric curation

Parameters:

Name Type Description Default
key dict

primary key to MetricCuration

required
Source code in src/spyglass/spikesorting/v1/metric_curation.py
@classmethod
def get_merge_groups(cls, key: dict):
    """Returns merge groups identified by metric curation

    Parameters
    ----------
    key : dict
        primary key to MetricCuration
    """
    analysis_file_name, object_id = (cls & key).fetch1(
        "analysis_file_name", "object_id"
    )
    analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
        analysis_file_name
    )
    with pynwb.NWBHDF5IO(
        path=analysis_file_abs_path,
        mode="r",
        load_namespaces=True,
    ) as io:
        nwbf = io.read()
        # to_dataframe() copies the data into memory, so `units` remains
        # valid after the NWB file is closed.
        units = nwbf.objects[object_id].to_dataframe()
    # Map each unit ID to its stored merge-group entry, then convert the
    # per-unit mapping into a list-of-groups representation.
    merge_group_dict = dict(zip(units.index, units["merge_groups"]))

    return _merge_dict_to_list(merge_group_dict)