
spikesorting_curation.py

Curation

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/spikesorting_curation.py
@schema
class Curation(SpyglassMixin, dj.Manual):
    definition = """
    # Stores each spike sorting; similar to IntervalList
    curation_id: int # a number corresponding to the index of this curation
    -> SpikeSorting
    ---
    parent_curation_id=-1: int
    curation_labels: blob # a dictionary of labels for the units
    merge_groups: blob # a list of merge groups for the units
    quality_metrics: blob # a list of quality metrics for the units (if available)
    description='': varchar(1000) #optional description for this curated sort
    time_of_creation: int   # in Unix time, to the nearest second
    """

    _nwb_table = AnalysisNwbfile

    @staticmethod
    def insert_curation(
        sorting_key: dict,
        parent_curation_id: int = -1,
        labels=None,
        merge_groups=None,
        metrics=None,
        description="",
    ):
        """Given a SpikeSorting key and the parent_sorting_id (and optional
        arguments) insert an entry into Curation.


        Parameters
        ----------
        sorting_key : dict
            The key for the original SpikeSorting
        parent_curation_id : int, optional
            The id of the parent sorting
        labels : dict or None, optional
            Curation labels for units
        merge_groups : list or None, optional
            Lists of unit ids to be merged
        metrics : dict or None, optional
            Computed metrics for sorting
        description : str, optional
            text description of this sort

        Returns
        -------
        curation_key : dict

        """
        if parent_curation_id == -1:
            # check to see if this sorting with a parent of -1 has already been
            # inserted and if so, warn the user
            inserted_curation = (Curation & sorting_key).fetch("KEY")
            if len(inserted_curation) > 0:
                logger.warning(
                    "Sorting has already been inserted, returning key to previously "
                    "inserted curation"
                )
                return inserted_curation[0]

        if labels is None:
            labels = {}
        if merge_groups is None:
            merge_groups = []
        if metrics is None:
            metrics = {}

        # generate a unique number for this curation
        existing_ids = (Curation & sorting_key).fetch("curation_id")
        if len(existing_ids) > 0:
            curation_id = max(existing_ids) + 1
        else:
            curation_id = 0

        # convert unit_ids in labels to integers for labels from sortingview.
        new_labels = {int(unit_id): labels[unit_id] for unit_id in labels}

        sorting_key["curation_id"] = curation_id
        sorting_key["parent_curation_id"] = parent_curation_id
        sorting_key["description"] = description
        sorting_key["curation_labels"] = new_labels
        sorting_key["merge_groups"] = merge_groups
        sorting_key["quality_metrics"] = metrics
        sorting_key["time_of_creation"] = int(time.time())

        # mike: added skip duplicates
        Curation.insert1(sorting_key, skip_duplicates=True)

        # build the key from the primary key attributes of Curation
        curation_key = {item: sorting_key[item] for item in Curation.primary_key}

        return curation_key

    @staticmethod
    def get_recording(key: dict):
        """Returns the recording extractor for the recording related to this curation

        Parameters
        ----------
        key : dict
            SpikeSortingRecording key

        Returns
        -------
        recording_extractor : spike interface recording extractor

        """
        recording_path = (SpikeSortingRecording & key).fetch1("recording_path")
        return si.load_extractor(recording_path)

    @staticmethod
    def get_curated_sorting(key: dict):
        """Returns the sorting extractor related to this curation,
        with merges applied.

        Parameters
        ----------
        key : dict
            Curation key

        Returns
        -------
        sorting_extractor: spike interface sorting extractor

        """
        sorting_path = (SpikeSorting & key).fetch1("sorting_path")
        sorting = si.load_extractor(sorting_path)
        merge_groups = (Curation & key).fetch1("merge_groups")
        # TODO: write code to get merged sorting extractor
        if len(merge_groups) != 0:
            return MergedSortingExtractor(
                parent_sorting=sorting, merge_groups=merge_groups
            )
        else:
            return sorting

    @staticmethod
    def save_sorting_nwb(
        key,
        sorting,
        timestamps,
        sort_interval_list_name,
        sort_interval,
        labels=None,
        metrics=None,
        unit_ids=None,
    ):
        """Store a sorting in a new AnalysisNwbfile

        Parameters
        ----------
        key : dict
            key to SpikeSorting table
        sorting : si.Sorting
            sorting
        timestamps : array_like
            Time stamps of the sorted recording;
            used to convert the spike timings from index to real time
        sort_interval_list_name : str
            name of sort interval
        sort_interval : list
            interval for start and end of sort
        labels : dict, optional
            curation labels, by default None
        metrics : dict, optional
            quality metrics, by default None
        unit_ids : list, optional
            IDs of units whose spiketrains to save, by default None

        Returns
        -------
        analysis_file_name : str
        units_object_id : str

        """

        sort_interval_valid_times = (
            IntervalList & {"interval_list_name": sort_interval_list_name}
        ).fetch1("valid_times")

        units = dict()
        units_valid_times = dict()
        units_sort_interval = dict()

        if unit_ids is None:
            unit_ids = sorting.get_unit_ids()

        for unit_id in unit_ids:
            spike_times_in_samples = sorting.get_unit_spike_train(
                unit_id=unit_id
            )
            units[unit_id] = timestamps[spike_times_in_samples]
            units_valid_times[unit_id] = sort_interval_valid_times
            units_sort_interval[unit_id] = [sort_interval]

        analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"])
        object_ids = AnalysisNwbfile().add_units(
            analysis_file_name,
            units,
            units_valid_times,
            units_sort_interval,
            metrics=metrics,
            labels=labels,
        )
        AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name)

        if object_ids == "":
            logger.warning(
                "Sorting contains no units. "
                "Created an empty analysis nwb file anyway."
            )
            units_object_id = ""
        else:
            units_object_id = object_ids[0]

        return analysis_file_name, units_object_id

insert_curation(sorting_key, parent_curation_id=-1, labels=None, merge_groups=None, metrics=None, description='') staticmethod

Given a SpikeSorting key and the parent_curation_id (and optional arguments), insert an entry into Curation.

Parameters:

    sorting_key : dict (required)
        The key for the original SpikeSorting
    parent_curation_id : int (default: -1)
        The id of the parent sorting
    labels : dict or None (default: None)
        Curation labels for units
    merge_groups : list or None (default: None)
        Lists of unit ids to be merged
    metrics : dict or None (default: None)
        Computed metrics for sorting
    description : str (default: '')
        Text description of this sort

Returns:

    curation_key : dict
Source code in src/spyglass/spikesorting/spikesorting_curation.py
@staticmethod
def insert_curation(
    sorting_key: dict,
    parent_curation_id: int = -1,
    labels=None,
    merge_groups=None,
    metrics=None,
    description="",
):
    """Given a SpikeSorting key and the parent_sorting_id (and optional
    arguments) insert an entry into Curation.


    Parameters
    ----------
    sorting_key : dict
        The key for the original SpikeSorting
    parent_curation_id : int, optional
        The id of the parent sorting
    labels : dict or None, optional
        Curation labels for units
    merge_groups : list or None, optional
        Lists of unit ids to be merged
    metrics : dict or None, optional
        Computed metrics for sorting
    description : str, optional
        text description of this sort

    Returns
    -------
    curation_key : dict

    """
    if parent_curation_id == -1:
        # check to see if this sorting with a parent of -1 has already been
        # inserted and if so, warn the user
        inserted_curation = (Curation & sorting_key).fetch("KEY")
        if len(inserted_curation) > 0:
            logger.warning(
                "Sorting has already been inserted, returning key to previously "
                "inserted curation"
            )
            return inserted_curation[0]

    if labels is None:
        labels = {}
    if merge_groups is None:
        merge_groups = []
    if metrics is None:
        metrics = {}

    # generate a unique number for this curation
    existing_ids = (Curation & sorting_key).fetch("curation_id")
    if len(existing_ids) > 0:
        curation_id = max(existing_ids) + 1
    else:
        curation_id = 0

    # convert unit_ids in labels to integers for labels from sortingview.
    new_labels = {int(unit_id): labels[unit_id] for unit_id in labels}

    sorting_key["curation_id"] = curation_id
    sorting_key["parent_curation_id"] = parent_curation_id
    sorting_key["description"] = description
    sorting_key["curation_labels"] = new_labels
    sorting_key["merge_groups"] = merge_groups
    sorting_key["quality_metrics"] = metrics
    sorting_key["time_of_creation"] = int(time.time())

    # mike: added skip duplicates
    Curation.insert1(sorting_key, skip_duplicates=True)

    # build the key from the primary key attributes of Curation
    curation_key = {item: sorting_key[item] for item in Curation.primary_key}

    return curation_key
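
A minimal usage sketch, assuming the usual Spyglass import path and an existing SpikeSorting entry (the restriction values below are hypothetical):

from spyglass.spikesorting import Curation, SpikeSorting  # assumed import path

# Hypothetical restriction; must match an existing SpikeSorting entry.
sorting_key = (
    SpikeSorting & {"nwb_file_name": "example20230101_.nwb"}
).fetch1("KEY")

# Insert a first curation (parent_curation_id defaults to -1).
curation_key = Curation.insert_curation(
    sorting_key,
    labels={"1": ["accept"]},  # unit_id -> list of labels; keys are cast to int
    description="initial curation",
)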

get_recording(key) staticmethod

Returns the recording extractor for the recording related to this curation

Parameters:

    key : dict (required)
        SpikeSortingRecording key

Returns:

    recording_extractor : spikeinterface recording extractor
Source code in src/spyglass/spikesorting/spikesorting_curation.py
@staticmethod
def get_recording(key: dict):
    """Returns the recording extractor for the recording related to this curation

    Parameters
    ----------
    key : dict
        SpikeSortingRecording key

    Returns
    -------
    recording_extractor : spike interface recording extractor

    """
    recording_path = (SpikeSortingRecording & key).fetch1("recording_path")
    return si.load_extractor(recording_path)

get_curated_sorting(key) staticmethod

Returns the sorting extractor related to this curation, with merges applied.

Parameters:

    key : dict (required)
        Curation key

Returns:

    sorting_extractor : spikeinterface sorting extractor
Source code in src/spyglass/spikesorting/spikesorting_curation.py
@staticmethod
def get_curated_sorting(key: dict):
    """Returns the sorting extractor related to this curation,
    with merges applied.

    Parameters
    ----------
    key : dict
        Curation key

    Returns
    -------
    sorting_extractor: spike interface sorting extractor

    """
    sorting_path = (SpikeSorting & key).fetch1("sorting_path")
    sorting = si.load_extractor(sorting_path)
    merge_groups = (Curation & key).fetch1("merge_groups")
    # TODO: write code to get merged sorting extractor
    if len(merge_groups) != 0:
        return MergedSortingExtractor(
            parent_sorting=sorting, merge_groups=merge_groups
        )
    else:
        return sorting
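
A sketch of how get_recording and get_curated_sorting are typically used together; the restriction below is a hypothetical Curation key:

# Hypothetical key identifying one Curation entry.
key = (
    Curation & {"nwb_file_name": "example20230101_.nwb", "curation_id": 0}
).fetch1("KEY")

recording = Curation.get_recording(key)      # spikeinterface recording extractor
sorting = Curation.get_curated_sorting(key)  # sorting with merge_groups applied
print(recording.get_num_channels(), sorting.get_unit_ids())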

save_sorting_nwb(key, sorting, timestamps, sort_interval_list_name, sort_interval, labels=None, metrics=None, unit_ids=None) staticmethod

Store a sorting in a new AnalysisNwbfile

Parameters:

    key : dict (required)
        Key to the SpikeSorting table
    sorting : si.Sorting (required)
        The sorting to save
    timestamps : array_like (required)
        Time stamps of the sorted recording; used to convert the spike
        timings from index to real time
    sort_interval_list_name : str (required)
        Name of the sort interval
    sort_interval : list (required)
        Interval for the start and end of the sort
    labels : dict (default: None)
        Curation labels
    metrics : dict (default: None)
        Quality metrics
    unit_ids : list (default: None)
        IDs of the units whose spike trains to save

Returns:

    analysis_file_name : str
    units_object_id : str
Source code in src/spyglass/spikesorting/spikesorting_curation.py
@staticmethod
def save_sorting_nwb(
    key,
    sorting,
    timestamps,
    sort_interval_list_name,
    sort_interval,
    labels=None,
    metrics=None,
    unit_ids=None,
):
    """Store a sorting in a new AnalysisNwbfile

    Parameters
    ----------
    key : dict
        key to SpikeSorting table
    sorting : si.Sorting
        sorting
    timestamps : array_like
        Time stamps of the sorted recording;
        used to convert the spike timings from index to real time
    sort_interval_list_name : str
        name of sort interval
    sort_interval : list
        interval for start and end of sort
    labels : dict, optional
        curation labels, by default None
    metrics : dict, optional
        quality metrics, by default None
    unit_ids : list, optional
        IDs of units whose spiketrains to save, by default None

    Returns
    -------
    analysis_file_name : str
    units_object_id : str

    """

    sort_interval_valid_times = (
        IntervalList & {"interval_list_name": sort_interval_list_name}
    ).fetch1("valid_times")

    units = dict()
    units_valid_times = dict()
    units_sort_interval = dict()

    if unit_ids is None:
        unit_ids = sorting.get_unit_ids()

    for unit_id in unit_ids:
        spike_times_in_samples = sorting.get_unit_spike_train(
            unit_id=unit_id
        )
        units[unit_id] = timestamps[spike_times_in_samples]
        units_valid_times[unit_id] = sort_interval_valid_times
        units_sort_interval[unit_id] = [sort_interval]

    analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"])
    object_ids = AnalysisNwbfile().add_units(
        analysis_file_name,
        units,
        units_valid_times,
        units_sort_interval,
        metrics=metrics,
        labels=labels,
    )
    AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name)

    if object_ids == "":
        logger.warning(
            "Sorting contains no units. "
            "Created an empty analysis nwb file anyway."
        )
        units_object_id = ""
    else:
        units_object_id = object_ids[0]

    return analysis_file_name, units_object_id
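
The conversion at the core of save_sorting_nwb is plain fancy indexing: spike frames index into the timestamps array to yield spike times in seconds. A self-contained illustration (toy values, not tied to any database entry):

import numpy as np

# 30 kHz recording: timestamps[i] is the time of sample i, in seconds.
timestamps = np.arange(0, 10, 1 / 30000)
spike_times_in_samples = np.array([10, 250, 9000])  # as from get_unit_spike_train
spike_times_sec = timestamps[spike_times_in_samples]
print(spike_times_sec)  # [0.00033333 0.00833333 0.3]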

Waveforms

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/spikesorting_curation.py
@schema
class Waveforms(SpyglassMixin, dj.Computed):
    definition = """
    -> WaveformSelection
    ---
    waveform_extractor_path: varchar(400)
    -> AnalysisNwbfile
    waveforms_object_id: varchar(40)   # Object ID for the waveforms in NWB file
    """

    def make(self, key):
        recording = Curation.get_recording(key)
        if recording.get_num_segments() > 1:
            recording = si.concatenate_recordings([recording])

        sorting = Curation.get_curated_sorting(key)

        logger.info("Extracting waveforms...")
        waveform_params = (WaveformParameters & key).fetch1("waveform_params")
        if "whiten" in waveform_params:
            if waveform_params.pop("whiten"):
                recording = sip.whiten(recording, dtype="float32")

        waveform_extractor_name = self._get_waveform_extractor_name(key)
        key["waveform_extractor_path"] = str(
            Path(waveform_dir) / Path(waveform_extractor_name)
        )
        if os.path.exists(key["waveform_extractor_path"]):
            shutil.rmtree(key["waveform_extractor_path"])
        waveforms = si.extract_waveforms(
            recording=recording,
            sorting=sorting,
            folder=key["waveform_extractor_path"],
            **waveform_params,
        )

        key["analysis_file_name"] = AnalysisNwbfile().create(
            key["nwb_file_name"]
        )
        object_id = AnalysisNwbfile().add_units_waveforms(
            key["analysis_file_name"], waveform_extractor=waveforms
        )
        key["waveforms_object_id"] = object_id
        AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"])

        self.insert1(key)

    def load_waveforms(self, key: dict):
        """Returns a spikeinterface waveform extractor specified by key

        Parameters
        ----------
        key : dict
            Could be an entry in Waveforms, or some other key that uniquely defines
            an entry in Waveforms

        Returns
        -------
        we : spikeinterface.WaveformExtractor
        """
        we_path = (self & key).fetch1("waveform_extractor_path")
        we = si.WaveformExtractor.load_from_folder(we_path)
        return we

    def fetch_nwb(self, key):
        # TODO: implement fetching waveforms from NWB
        raise NotImplementedError

    def _get_waveform_extractor_name(self, key):
        waveform_params_name = (WaveformParameters & key).fetch1(
            "waveform_params_name"
        )

        return (
            f'{key["nwb_file_name"]}_{str(uuid.uuid4())[0:8]}_'
            f'{key["curation_id"]}_{waveform_params_name}_waveforms'
        )

load_waveforms(key)

Returns a spikeinterface waveform extractor specified by key

Parameters:

    key : dict (required)
        Could be an entry in Waveforms, or some other key that uniquely
        defines an entry in Waveforms

Returns:

    we : spikeinterface.WaveformExtractor
Source code in src/spyglass/spikesorting/spikesorting_curation.py
def load_waveforms(self, key: dict):
    """Returns a spikeinterface waveform extractor specified by key

    Parameters
    ----------
    key : dict
        Could be an entry in Waveforms, or some other key that uniquely defines
        an entry in Waveforms

    Returns
    -------
    we : spikeinterface.WaveformExtractor
    """
    we_path = (self & key).fetch1("waveform_extractor_path")
    we = si.WaveformExtractor.load_from_folder(we_path)
    return we
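
A usage sketch, assuming an entry already exists in Waveforms for the given key and the legacy spikeinterface WaveformExtractor API (the key values are hypothetical):

waveform_key = {"nwb_file_name": "example20230101_.nwb", "curation_id": 0}
we = Waveforms().load_waveforms(waveform_key)
unit_ids = we.sorting.get_unit_ids()
template = we.get_template(unit_id=unit_ids[0])  # mean waveform (samples, channels)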

MetricParameters

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/spikesorting_curation.py
@schema
class MetricParameters(SpyglassMixin, dj.Manual):
    definition = """
    # Parameters for computing quality metrics of sorted units
    metric_params_name: varchar(64)
    ---
    metric_params: blob
    """

    # NOTE: See #630, #664. Excessive key length.

    metric_default_params = {
        "snr": {
            "peak_sign": "neg",
            "random_chunk_kwargs_dict": {
                "num_chunks_per_segment": 20,
                "chunk_size": 10000,
                "seed": 0,
            },
        },
        "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0.0},
        "nn_isolation": {
            "max_spikes": 1000,
            "min_spikes": 10,
            "n_neighbors": 5,
            "n_components": 7,
            "radius_um": 100,
            "seed": 0,
        },
        "nn_noise_overlap": {
            "max_spikes": 1000,
            "min_spikes": 10,
            "n_neighbors": 5,
            "n_components": 7,
            "radius_um": 100,
            "seed": 0,
        },
        "peak_channel": {"peak_sign": "neg"},
        "num_spikes": {},
    }
    # Example of peak_offset parameters 'peak_offset': {'peak_sign': 'neg'}
    available_metrics = [
        "snr",
        "isi_violation",
        "nn_isolation",
        "nn_noise_overlap",
        "peak_offset",
        "peak_channel",
        "num_spikes",
    ]

    def get_metric_default_params(self, metric: str):
        "Returns default params for the given metric"
        return self.metric_default_params[metric]

    def insert_default(self):
        self.insert1(
            ["franklab_default3", self.metric_default_params],
            skip_duplicates=True,
        )

    def get_available_metrics(self):
        for metric in _metric_name_to_func:
            if metric in self.available_metrics:
                metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0]
                metric_string = ("{metric_name} : {metric_doc}").format(
                    metric_name=metric, metric_doc=metric_doc
                )
                logger.info(metric_string + "\n")

    # TODO
    def _validate_metrics_list(self, key):
        """Checks whether a row to be inserted contains only the available metrics"""
        # get available metrics list
        # get metric list from key
        # compare
        raise NotImplementedError

get_metric_default_params(metric)

Returns default params for the given metric

Source code in src/spyglass/spikesorting/spikesorting_curation.py
def get_metric_default_params(self, metric: str):
    "Returns default params for the given metric"
    return self.metric_default_params[metric]
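
A sketch of deriving a custom parameter set from the defaults shown above and inserting it; the parameter-set name and the changed value are hypothetical:

import copy

# Deep-copy so the class-level defaults are not mutated.
custom_params = copy.deepcopy(MetricParameters.metric_default_params)
custom_params["snr"]["peak_sign"] = "pos"  # hypothetical change

MetricParameters().insert1(
    {"metric_params_name": "my_metric_params", "metric_params": custom_params},
    skip_duplicates=True,
)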

AutomaticCuration

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/spikesorting_curation.py
@schema
class AutomaticCuration(SpyglassMixin, dj.Computed):
    definition = """
    -> AutomaticCurationSelection
    ---
    auto_curation_key: blob # the key to the curation inserted by make
    """

    def make(self, key):
        metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path")
        with open(metrics_path) as f:
            quality_metrics = json.load(f)

        # get the curation information and the curated sorting
        parent_curation = (Curation & key).fetch(as_dict=True)[0]
        parent_merge_groups = parent_curation["merge_groups"]
        parent_labels = parent_curation["curation_labels"]
        parent_curation_id = parent_curation["curation_id"]
        parent_sorting = Curation.get_curated_sorting(key)

        merge_params = (AutomaticCurationParameters & key).fetch1(
            "merge_params"
        )
        merge_groups, units_merged = self.get_merge_groups(
            parent_sorting, parent_merge_groups, quality_metrics, merge_params
        )

        label_params = (AutomaticCurationParameters & key).fetch1(
            "label_params"
        )
        labels = self.get_labels(
            parent_sorting, parent_labels, quality_metrics, label_params
        )

        # keep the quality metrics only if no merging occurred.
        metrics = quality_metrics if not units_merged else None

        # insert this sorting into the CuratedSpikeSorting Table
        # first remove keys that aren't part of the Sorting (the primary key of curation)
        c_key = (SpikeSorting & key).fetch("KEY")[0]
        curation_key = {item: key[item] for item in key if item in c_key}
        key["auto_curation_key"] = Curation.insert_curation(
            curation_key,
            parent_curation_id=parent_curation_id,
            labels=labels,
            merge_groups=merge_groups,
            metrics=metrics,
            description="auto curated",
        )

        self.insert1(key)

    @staticmethod
    def get_merge_groups(
        sorting, parent_merge_groups, quality_metrics, merge_params
    ):
        """Identifies units to be merged based on the quality_metrics and
        merge parameters and returns an updated list of merges for the curation.

        Parameters
        ----------
        sorting : spikeinterface.sorting
        parent_merge_groups : list
            Information about previous merges
        quality_metrics : dict
        merge_params : dict

        Returns
        -------
        merge_groups : list of lists
        merge_occurred : bool

        """

        # overview:
        # 1. Use quality metrics to determine merge groups for units
        # 2. Combine merge groups with current merge groups to produce union of merges

        if not merge_params:
            return parent_merge_groups, False
        else:
            # TODO: use the metrics to identify clusters that should be merged
            # new_merges should then reflect those merges and the line below should be deleted.
            new_merges = []
            # append these merges to the parent merge_groups
            for new_merge in new_merges:
                # check to see if the first cluster listed is in a current merge group
                for previous_merge in parent_merge_groups:
                    if new_merge[0] == previous_merge[0]:
                        # add the additional units in new_merge to the identified merge group.
                        previous_merge.extend(new_merge[1:])
                        previous_merge.sort()
                        break
                else:
                    # append this merge group to the list if no previous merge
                    parent_merge_groups.append(new_merge)
            parent_merge_groups.sort()
            return parent_merge_groups, True

    @staticmethod
    def get_labels(sorting, parent_labels, quality_metrics, label_params):
        """Returns a dictionary of labels using quality_metrics and label
        parameters.

        Parameters
        ----------
        sorting : spikeinterface.sorting
        parent_labels : dict
            Labels from the parent curation
        quality_metrics : dict
        label_params : dict

        Returns
        -------
        parent_labels : dict

        """
        # overview:
        # 1. Use quality metrics to determine labels for units
        # 2. Append labels to current labels, checking for inconsistencies
        if not label_params:
            return parent_labels
        else:
            for metric in label_params:
                if metric not in quality_metrics:
                    Warning(f"{metric} not found in quality metrics; skipping")
                else:
                    compare = _comparison_to_function[label_params[metric][0]]

                    for unit_id in quality_metrics[metric].keys():
                        # compare the quality metric to the threshold with the specified operator
                        # note that label_params[metric] is a three element list with a comparison operator as a string,
                        # the threshold value, and a list of labels to be applied if the comparison is true
                        if compare(
                            quality_metrics[metric][unit_id],
                            label_params[metric][1],
                        ):
                            if unit_id not in parent_labels:
                                parent_labels[unit_id] = label_params[metric][2]
                            # check if the label is already there, and if not, add it
                            elif (
                                label_params[metric][2]
                                not in parent_labels[unit_id]
                            ):
                                parent_labels[unit_id].extend(
                                    label_params[metric][2]
                                )
            return parent_labels

get_merge_groups(sorting, parent_merge_groups, quality_metrics, merge_params) staticmethod

Identifies units to be merged based on the quality_metrics and merge parameters and returns an updated list of merges for the curation.

Parameters:

    sorting : spikeinterface.sorting (required)
    parent_merge_groups : list (required)
        Information about previous merges
    quality_metrics : dict (required)
    merge_params : dict (required)

Returns:

    merge_groups : list of lists
    merge_occurred : bool
Source code in src/spyglass/spikesorting/spikesorting_curation.py
@staticmethod
def get_merge_groups(
    sorting, parent_merge_groups, quality_metrics, merge_params
):
    """Identifies units to be merged based on the quality_metrics and
    merge parameters and returns an updated list of merges for the curation.

    Parameters
    ----------
    sorting : spikeinterface.sorting
    parent_merge_groups : list
        Information about previous merges
    quality_metrics : dict
    merge_params : dict

    Returns
    -------
    merge_groups : list of lists
    merge_occurred : bool

    """

    # overview:
    # 1. Use quality metrics to determine merge groups for units
    # 2. Combine merge groups with current merge groups to produce union of merges

    if not merge_params:
        return parent_merge_groups, False
    else:
        # TODO: use the metrics to identify clusters that should be merged
        # new_merges should then reflect those merges and the line below should be deleted.
        new_merges = []
        # append these merges to the parent merge_groups
        for new_merge in new_merges:
            # check to see if the first cluster listed is in a current merge group
            for previous_merge in parent_merge_groups:
                if new_merge[0] == previous_merge[0]:
                    # add the additional units in new_merge to the identified merge group.
                    previous_merge.extend(new_merge[1:])
                    previous_merge.sort()
                    break
            else:
                # append this merge group to the list if no previous merge
                parent_merge_groups.append(new_merge)
        parent_merge_groups.sort()
        return parent_merge_groups, True

get_labels(sorting, parent_labels, quality_metrics, label_params) staticmethod

Returns a dictionary of labels using quality_metrics and label parameters.

Parameters:

    sorting : spikeinterface.sorting (required)
    parent_labels : dict (required)
        Labels from the parent curation
    quality_metrics : dict (required)
    label_params : dict (required)

Returns:

    parent_labels : dict
Source code in src/spyglass/spikesorting/spikesorting_curation.py
@staticmethod
def get_labels(sorting, parent_labels, quality_metrics, label_params):
    """Returns a dictionary of labels using quality_metrics and label
    parameters.

    Parameters
    ----------
    sorting : spikeinterface.sorting
    parent_labels : dict
        Labels from the parent curation
    quality_metrics : dict
    label_params : dict

    Returns
    -------
    parent_labels : dict

    """
    # overview:
    # 1. Use quality metrics to determine labels for units
    # 2. Append labels to current labels, checking for inconsistencies
    if not label_params:
        return parent_labels
    else:
        for metric in label_params:
            if metric not in quality_metrics:
                Warning(f"{metric} not found in quality metrics; skipping")
            else:
                compare = _comparison_to_function[label_params[metric][0]]

                for unit_id in quality_metrics[metric].keys():
                    # compare the quality metric to the threshold with the specified operator
                    # note that label_params[metric] is a three element list with a comparison operator as a string,
                    # the threshold value, and a list of labels to be applied if the comparison is true
                    if compare(
                        quality_metrics[metric][unit_id],
                        label_params[metric][1],
                    ):
                        if unit_id not in parent_labels:
                            parent_labels[unit_id] = label_params[metric][2]
                        # check if the label is already there, and if not, add it
                        elif (
                            label_params[metric][2]
                            not in parent_labels[unit_id]
                        ):
                            parent_labels[unit_id].extend(
                                label_params[metric][2]
                            )
        return parent_labels
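
Each label_params entry is a three-element list: a comparison-operator string (a key of _comparison_to_function), a threshold value, and the labels to apply when the comparison is true. A hypothetical, self-contained example (the threshold is illustrative):

# quality_metrics as loaded from the QualityMetrics JSON: metric -> {unit_id: value}
quality_metrics = {"nn_noise_overlap": {"1": 0.05, "2": 0.3}}
label_params = {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]}

labels = AutomaticCuration.get_labels(
    sorting=None,  # unused by the labeling logic
    parent_labels={},
    quality_metrics=quality_metrics,
    label_params=label_params,
)
# labels -> {"2": ["noise", "reject"]}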

CuratedSpikeSorting

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/spikesorting_curation.py
@schema
class CuratedSpikeSorting(SpyglassMixin, dj.Computed):
    definition = """
    -> CuratedSpikeSortingSelection
    ---
    -> AnalysisNwbfile
    units_object_id: varchar(40)
    """

    class Unit(SpyglassMixin, dj.Part):
        definition = """
        # Table for holding sorted units
        -> CuratedSpikeSorting
        unit_id: int   # ID for each unit
        ---
        label='': varchar(200)   # optional set of labels for each unit
        nn_noise_overlap=-1: float   # noise overlap metric for each unit
        nn_isolation=-1: float   # isolation score metric for each unit
        isi_violation=-1: float   # ISI violation score for each unit
        snr=0: float            # SNR for each unit
        firing_rate=-1: float   # firing rate
        num_spikes=-1: int   # total number of spikes
        peak_channel=null: int # channel of maximum amplitude for each unit
        """

    def make(self, key):
        unit_labels_to_remove = ["reject"]
        # check that the Curation has metrics
        metrics = (Curation & key).fetch1("quality_metrics")
        if metrics == {}:
            logger.warning(
                f"Metrics for Curation {key} should normally be calculated before insertion here"
            )

        sorting = Curation.get_curated_sorting(key)
        unit_ids = sorting.get_unit_ids()
        # Get the labels for the units, add only those units that do not have 'reject' or 'noise' labels
        unit_labels = (Curation & key).fetch1("curation_labels")
        accepted_units = []
        for unit_id in unit_ids:
            if unit_id in unit_labels:
                if (
                    len(set(unit_labels_to_remove) & set(unit_labels[unit_id]))
                    == 0
                ):
                    accepted_units.append(unit_id)
            else:
                accepted_units.append(unit_id)

        # get the labels for the accepted units
        labels = {}
        for unit_id in accepted_units:
            if unit_id in unit_labels:
                labels[unit_id] = ",".join(unit_labels[unit_id])

        # convert unit_ids in metrics to integers, including only accepted units.
        # TODO: do this int conversion somewhere else
        final_metrics = {}
        for metric in metrics:
            final_metrics[metric] = {
                int(unit_id): metrics[metric][unit_id]
                for unit_id in metrics[metric]
                if int(unit_id) in accepted_units
            }

        logger.info(f"Found {len(accepted_units)} accepted units")

        # get the sorting and save it in the NWB file
        sorting = Curation.get_curated_sorting(key)
        recording = Curation.get_recording(key)

        # get the sort_interval and sorting interval list
        sort_interval = (SortInterval & key).fetch1("sort_interval")
        sort_interval_list_name = (SpikeSorting & key).fetch1(
            "artifact_removed_interval_list_name"
        )

        timestamps = SpikeSortingRecording._get_recording_timestamps(recording)

        (
            key["analysis_file_name"],
            key["units_object_id"],
        ) = Curation().save_sorting_nwb(
            key,
            sorting,
            timestamps,
            sort_interval_list_name,
            sort_interval,
            metrics=final_metrics,
            unit_ids=accepted_units,
            labels=labels,
        )
        self.insert1(key)

        # now add the units
        # Remove the non primary key entries.
        del key["units_object_id"]
        del key["analysis_file_name"]

        metric_fields = self.metrics_fields()
        for unit_id in accepted_units:
            key["unit_id"] = unit_id
            if unit_id in labels:
                key["label"] = labels[unit_id]
            for field in metric_fields:
                if field in final_metrics:
                    key[field] = final_metrics[field][unit_id]
                else:
                    logger.warning(
                        f"No metric named {field} in computed unit quality metrics; skipping"
                    )
            CuratedSpikeSorting.Unit.insert1(key)

    def metrics_fields(self):
        """Returns a list of the metrics that are currently in the Units table."""
        unit_info = self.Unit().fetch(limit=1, format="frame")
        unit_fields = [column for column in unit_info.columns]
        unit_fields.remove("label")
        return unit_fields

    @classmethod
    def get_recording(cls, key):
        """Returns the recording related to this curation. Useful for operations downstream of merge table"""
        # expand the key
        recording_key = (cls & key).fetch1("KEY")
        return SpikeSortingRecording()._get_filtered_recording(recording_key)

    @classmethod
    def get_sorting(cls, key):
        """Returns the sorting related to this curation. Useful for operations downstream of merge table"""
        # expand the key
        sorting_key = (cls & key).fetch1("KEY")
        return Curation.get_curated_sorting(sorting_key)

metrics_fields()

Returns a list of the metrics that are currently in the Units table.

Source code in src/spyglass/spikesorting/spikesorting_curation.py
def metrics_fields(self):
    """Returns a list of the metrics that are currently in the Units table."""
    unit_info = self.Unit().fetch(limit=1, format="frame")
    unit_fields = [column for column in unit_info.columns]
    unit_fields.remove("label")
    return unit_fields

get_recording(key) classmethod

Returns the recording related to this curation. Useful for operations downstream of merge table

Source code in src/spyglass/spikesorting/spikesorting_curation.py
@classmethod
def get_recording(cls, key):
    """Returns the recording related to this curation. Useful for operations downstream of merge table"""
    # expand the key
    recording_key = (cls & key).fetch1("KEY")
    return SpikeSortingRecording()._get_filtered_recording(recording_key)

get_sorting(key) classmethod

Returns the sorting related to this curation. Useful for operations downstream of merge table

Source code in src/spyglass/spikesorting/spikesorting_curation.py
@classmethod
def get_sorting(cls, key):
    """Returns the sorting related to this curation. Useful for operations downstream of merge table"""
    # expand the key
    sorting_key = (cls & key).fetch1("KEY")
    return Curation.get_curated_sorting(sorting_key)
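
A usage sketch for the two class methods; the key values are hypothetical and must resolve to exactly one CuratedSpikeSorting entry:

css_key = {"nwb_file_name": "example20230101_.nwb", "curation_id": 1}
recording = CuratedSpikeSorting.get_recording(css_key)  # filtered recording
sorting = CuratedSpikeSorting.get_sorting(css_key)      # curated sorting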

UnitInclusionParameters

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/spikesorting_curation.py
@schema
class UnitInclusionParameters(SpyglassMixin, dj.Manual):
    definition = """
    unit_inclusion_param_name: varchar(80) # the name of the list of thresholds for unit inclusion
    ---
    inclusion_param_dict: blob # the dictionary of inclusion / exclusion parameters
    """

    def insert1(self, key, **kwargs):
        # check to see that the dictionary fits the specifications
        # The inclusion parameter dict has the following form:
        # param_dict['metric_name'] = (operator, value)
        #    where operator is '<', '>', '<=', '>=', or '==' and value is the (float) comparison value
        # param_dict['exclude_labels'] = [list of labels to exclude]
        pdict = key["inclusion_param_dict"]
        metrics_list = CuratedSpikeSorting().metrics_fields()

        for k in pdict:
            if k not in metrics_list and k != "exclude_labels":
                raise Exception(
                    f"key {k} is not a valid element of the inclusion_param_dict"
                )
            if k in metrics_list:
                if pdict[k][0] not in _comparison_to_function:
                    raise Exception(
                        f"operator {pdict[k][0]} for metric {k} is not in the valid operators list: {_comparison_to_function.keys()}"
                    )
            if k == "exclude_labels":
                for label in pdict[k]:
                    if label not in valid_labels:
                        raise Exception(
                            f"exclude label {label} is not in the valid_labels list: {valid_labels}"
                        )
        super().insert1(key, **kwargs)

    def get_included_units(
        self, curated_sorting_key, unit_inclusion_param_name
    ):
        """Given a reference to a set of curated sorting units and the name of
        a unit inclusion parameter list, returns unit key

        Parameters
        ----------
        curated_sorting_key : dict
            key to select a set of curated sortings
        unit_inclusion_param_name : str
            name of a unit inclusion parameter entry

        Returns
        -------
        dict
            key to select all of the included units
        """
        curated_sortings = (CuratedSpikeSorting() & curated_sorting_key).fetch()
        inc_param_dict = (
            UnitInclusionParameters
            & {"unit_inclusion_param_name": unit_inclusion_param_name}
        ).fetch1("inclusion_param_dict")
        units = (CuratedSpikeSorting().Unit() & curated_sortings).fetch()
        units_key = (CuratedSpikeSorting().Unit() & curated_sortings).fetch(
            "KEY"
        )
        # get the list of labels to exclude if there is one
        if "exclude_labels" in inc_param_dict:
            exclude_labels = inc_param_dict["exclude_labels"]
            del inc_param_dict["exclude_labels"]
        else:
            exclude_labels = []

        # create a list of the units to keep.
        keep = np.asarray([True] * len(units))
        for metric in inc_param_dict:
            # for all units, go through each metric, compare it to the value
            # specified, and update the list to be kept
            keep = np.logical_and(
                keep,
                _comparison_to_function[inc_param_dict[metric][0]](
                    units[metric], inc_param_dict[metric][1]
                ),
            )

        # now exclude by label if it is specified
        if len(exclude_labels):
            for unit_ind in np.ravel(np.argwhere(keep)):
                labels = units[unit_ind]["label"].split(",")
                for label in labels:
                    if label in exclude_labels:
                        keep[unit_ind] = False
                        break

        # return units that passed all of the tests
        # TODO: Make this more efficient
        return {i: units_key[i] for i in np.ravel(np.argwhere(keep))}

get_included_units(curated_sorting_key, unit_inclusion_param_name)

Given a reference to a set of curated sorting units and the name of a unit inclusion parameter list, returns the keys of the included units.

Parameters:

    curated_sorting_key : dict (required)
        Key to select a set of curated sortings
    unit_inclusion_param_name : str (required)
        Name of a unit inclusion parameter entry

Returns:

    dict
        Key to select all of the included units

Source code in src/spyglass/spikesorting/spikesorting_curation.py
def get_included_units(
    self, curated_sorting_key, unit_inclusion_param_name
):
    """Given a reference to a set of curated sorting units and the name of
    a unit inclusion parameter list, returns unit key

    Parameters
    ----------
    curated_sorting_key : dict
        key to select a set of curated sortings
    unit_inclusion_param_name : str
        name of a unit inclusion parameter entry

    Returns
    -------
    dict
        key to select all of the included units
    """
    curated_sortings = (CuratedSpikeSorting() & curated_sorting_key).fetch()
    inc_param_dict = (
        UnitInclusionParameters
        & {"unit_inclusion_param_name": unit_inclusion_param_name}
    ).fetch1("inclusion_param_dict")
    units = (CuratedSpikeSorting().Unit() & curated_sortings).fetch()
    units_key = (CuratedSpikeSorting().Unit() & curated_sortings).fetch(
        "KEY"
    )
    # get the list of labels to exclude if there is one
    if "exclude_labels" in inc_param_dict:
        exclude_labels = inc_param_dict["exclude_labels"]
        del inc_param_dict["exclude_labels"]
    else:
        exclude_labels = []

    # create a list of the units to keep.
    keep = np.asarray([True] * len(units))
    for metric in inc_param_dict:
        # for all units, go through each metric, compare it to the value
        # specified, and update the list to be kept
        keep = np.logical_and(
            keep,
            _comparison_to_function[inc_param_dict[metric][0]](
                units[metric], inc_param_dict[metric][1]
            ),
        )

    # now exclude by label if it is specified
    if len(exclude_labels):
        for unit_ind in np.ravel(np.argwhere(keep)):
            labels = units[unit_ind]["label"].split(",")
            for label in labels:
                if label in exclude_labels:
                    keep[unit_ind] = False
                    break

    # return units that passed all of the tests
    # TODO: Make this more efficient
    return {i: units_key[i] for i in np.ravel(np.argwhere(keep))}
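
A sketch of the full workflow, from inserting a parameter set in the format validated by insert1 above to selecting units; the name, threshold, and labels are hypothetical, and exclude_labels entries must appear in valid_labels:

UnitInclusionParameters().insert1(
    {
        "unit_inclusion_param_name": "high_snr",
        "inclusion_param_dict": {
            "snr": (">", 5.0),  # metric_name: (operator, value)
            "exclude_labels": ["noise", "mua"],  # hypothetical valid_labels entries
        },
    },
    skip_duplicates=True,
)

included_units = UnitInclusionParameters().get_included_units(
    {"nwb_file_name": "example20230101_.nwb"},  # hypothetical restriction
    "high_snr",
)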