spikesorting_curation.py

apply_merge_groups_to_sorting(sorting, merge_groups)

Apply merge groups to a sorting extractor.

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def apply_merge_groups_to_sorting(
    sorting: si.BaseSorting, merge_groups: List[List[int]]
):
    """Apply merge groups to a sorting extractor."""
    # return a new sorting where the units are merged according to merge_groups
    # merge_groups is a list of lists of unit_ids.
    # for example: merge_groups = [[1, 2], [5, 8, 4]]

    return MergedSortingExtractor(
        parent_sorting=sorting, merge_groups=merge_groups
    )
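
A minimal usage sketch (the toy sorting and unit ids below are illustrative; generate_sorting is a spikeinterface utility for synthetic data):

from spyglass.spikesorting.v0.spikesorting_curation import (
    apply_merge_groups_to_sorting,
)
import spikeinterface as si

# toy sorting with unit ids 0-7 (illustrative only)
sorting = si.generate_sorting(num_units=8)

# merge units 1 and 2 into one unit, and 4, 5, 7 into another
merged = apply_merge_groups_to_sorting(sorting, [[1, 2], [4, 5, 7]])
print(merged.get_unit_ids())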

Curation

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class Curation(SpyglassMixin, dj.Manual):
    definition = """
    # Stores each spike sorting; similar to IntervalList
    curation_id: int # a number corresponding to the index of this curation
    -> SpikeSorting
    ---
    parent_curation_id=-1: int
    curation_labels: blob # a dictionary of labels for the units
    merge_groups: blob # a list of merge groups for the units
    quality_metrics: blob # a list of quality metrics for the units (if available)
    description='': varchar(1000) #optional description for this curated sort
    time_of_creation: int   # in Unix time, to the nearest second
    """

    _nwb_table = AnalysisNwbfile

    @staticmethod
    def insert_curation(
        sorting_key: dict,
        parent_curation_id: int = -1,
        labels=None,
        merge_groups=None,
        metrics=None,
        description="",
    ):
        """Given a SpikeSorting key and the parent_sorting_id (and optional
        arguments) insert an entry into Curation.


        Parameters
        ----------
        sorting_key : dict
            The key for the original SpikeSorting
        parent_curation_id : int, optional
            The id of the parent sorting
        labels : dict or None, optional
        merge_groups : dict or None, optional
        metrics : dict or None, optional
            Computed metrics for sorting
        description : str, optional
            text description of this sort

        Returns
        -------
        curation_key : dict

        """
        if parent_curation_id == -1:
            # check to see if this sorting with a parent of -1 has already been
            # inserted and if so, warn the user
            inserted_curation = (Curation & sorting_key).fetch("KEY")
            if len(inserted_curation) > 0:
                logger.warning(
                    "Sorting has already been inserted, returning key to "
                    "previously inserted curation"
                )
                return inserted_curation[0]

        if labels is None:
            labels = {}
        if merge_groups is None:
            merge_groups = []
        if metrics is None:
            metrics = {}

        # generate a unique number for this curation
        id = (Curation & sorting_key).fetch("curation_id")
        if len(id) > 0:
            curation_id = max(id) + 1
        else:
            curation_id = 0

        # convert unit_ids in labels to integers for labels from sortingview.
        new_labels = {int(unit_id): labels[unit_id] for unit_id in labels}

        sorting_key.update(
            {
                "curation_id": curation_id,
                "parent_curation_id": parent_curation_id,
                "description": description,
                "curation_labels": new_labels,
                "merge_groups": merge_groups,
                "quality_metrics": metrics,
                "time_of_creation": int(time.time()),
            }
        )

        # mike: added skip duplicates
        Curation.insert1(sorting_key, skip_duplicates=True)

        # get the primary key for this curation
        curation_key = {
            item: sorting_key[item] for item in Curation.primary_key
        }

        return curation_key

    @staticmethod
    def get_recording(key: dict):
        """Returns the recording extractor for the recording related to this curation

        Parameters
        ----------
        key : dict
            SpikeSortingRecording key

        Returns
        -------
        recording_extractor : spike interface recording extractor

        """
        recording_path = (SpikeSortingRecording & key).fetch1("recording_path")
        return si.load_extractor(recording_path)

    @staticmethod
    def get_curated_sorting(key: dict):
        """Returns the sorting extractor related to this curation,
        with merges applied.

        Parameters
        ----------
        key : dict
            Curation key

        Returns
        -------
        sorting_extractor: spike interface sorting extractor

        """
        sorting_path = (SpikeSorting & key).fetch1("sorting_path")
        sorting = si.load_extractor(sorting_path)
        merge_groups = (Curation & key).fetch1("merge_groups")
        # TODO: write code to get merged sorting extractor
        if len(merge_groups) != 0:
            return MergedSortingExtractor(
                parent_sorting=sorting, merge_groups=merge_groups
            )
        else:
            return sorting

    @staticmethod
    def save_sorting_nwb(
        key,
        sorting,
        timestamps,
        sort_interval_list_name,
        sort_interval,
        labels=None,
        metrics=None,
        unit_ids=None,
    ):
        """Store a sorting in a new AnalysisNwbfile

        Parameters
        ----------
        key : dict
            key to SpikeSorting table
        sorting : si.Sorting
            sorting
        timestamps : array_like
            Time stamps of the sorted recording;
            used to convert the spike timings from index to real time
        sort_interval_list_name : str
            name of sort interval
        sort_interval : list
            interval for start and end of sort
        labels : dict, optional
            curation labels, by default None
        metrics : dict, optional
            quality metrics, by default None
        unit_ids : list, optional
            IDs of units whose spiketrains to save, by default None

        Returns
        -------
        analysis_file_name : str
        units_object_id : str

        """
        analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"])

        sort_interval_valid_times = (
            IntervalList & {"interval_list_name": sort_interval_list_name}
        ).fetch1("valid_times")

        units = dict()
        units_valid_times = dict()
        units_sort_interval = dict()

        if unit_ids is None:
            unit_ids = sorting.get_unit_ids()

        for unit_id in unit_ids:
            spike_times_in_samples = sorting.get_unit_spike_train(
                unit_id=unit_id
            )
            units[unit_id] = timestamps[spike_times_in_samples]
            units_valid_times[unit_id] = sort_interval_valid_times
            units_sort_interval[unit_id] = [sort_interval]

        object_ids = AnalysisNwbfile().add_units(
            analysis_file_name,
            units,
            units_valid_times,
            units_sort_interval,
            metrics=metrics,
            labels=labels,
        )
        AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name)

        if object_ids == "":
            logger.warning(
                "Sorting contains no units. "
                "Created an empty analysis nwb file anyway."
            )
            units_object_id = ""
        else:
            units_object_id = object_ids[0]

        return analysis_file_name, units_object_id

insert_curation(sorting_key, parent_curation_id=-1, labels=None, merge_groups=None, metrics=None, description='') staticmethod

Given a SpikeSorting key and the parent_curation_id (and optional arguments), insert an entry into Curation.

Parameters:

    sorting_key : dict (required)
        The key for the original SpikeSorting
    parent_curation_id : int (default: -1)
        The id of the parent sorting
    labels : dict or None (default: None)
    merge_groups : dict or None (default: None)
    metrics : dict or None (default: None)
        Computed metrics for sorting
    description : str (default: '')
        text description of this sort

Returns:

    curation_key : dict
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@staticmethod
def insert_curation(
    sorting_key: dict,
    parent_curation_id: int = -1,
    labels=None,
    merge_groups=None,
    metrics=None,
    description="",
):
    """Given a SpikeSorting key and the parent_sorting_id (and optional
    arguments) insert an entry into Curation.


    Parameters
    ----------
    sorting_key : dict
        The key for the original SpikeSorting
    parent_curation_id : int, optional
        The id of the parent sorting
    labels : dict or None, optional
    merge_groups : dict or None, optional
    metrics : dict or None, optional
        Computed metrics for sorting
    description : str, optional
        text description of this sort

    Returns
    -------
    curation_key : dict

    """
    if parent_curation_id == -1:
        # check to see if this sorting with a parent of -1 has already been
        # inserted and if so, warn the user
        inserted_curation = (Curation & sorting_key).fetch("KEY")
        if len(inserted_curation) > 0:
            logger.warning(
                "Sorting has already been inserted, returning key to "
                "previously inserted curation"
            )
            return inserted_curation[0]

    if labels is None:
        labels = {}
    if merge_groups is None:
        merge_groups = []
    if metrics is None:
        metrics = {}

    # generate a unique number for this curation
    id = (Curation & sorting_key).fetch("curation_id")
    if len(id) > 0:
        curation_id = max(id) + 1
    else:
        curation_id = 0

    # convert unit_ids in labels to integers for labels from sortingview.
    new_labels = {int(unit_id): labels[unit_id] for unit_id in labels}

    sorting_key.update(
        {
            "curation_id": curation_id,
            "parent_curation_id": parent_curation_id,
            "description": description,
            "curation_labels": new_labels,
            "merge_groups": merge_groups,
            "quality_metrics": metrics,
            "time_of_creation": int(time.time()),
        }
    )

    # mike: added skip duplicates
    Curation.insert1(sorting_key, skip_duplicates=True)

    # get the primary key for this curation
    curation_key = {
        item: sorting_key[item] for item in Curation.primary_key
    }

    return curation_key
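
A hedged usage sketch (the restriction values are hypothetical; a real sorting_key must identify an existing SpikeSorting entry):

# fetch a key for an existing spike sorting (file name hypothetical)
sorting_key = (
    SpikeSorting & {"nwb_file_name": "example20230622_.nwb"}
).fetch("KEY")[0]

curation_key = Curation.insert_curation(
    sorting_key,
    parent_curation_id=-1,        # -1 marks an initial, uncurated entry
    labels={"1": ["accept"]},     # unit ids may arrive as strings from sortingview
    description="initial curation",
)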

get_recording(key) staticmethod

Returns the recording extractor for the recording related to this curation

Parameters:

    key : dict (required)
        SpikeSortingRecording key

Returns:

    recording_extractor : spike interface recording extractor
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@staticmethod
def get_recording(key: dict):
    """Returns the recording extractor for the recording related to this curation

    Parameters
    ----------
    key : dict
        SpikeSortingRecording key

    Returns
    -------
    recording_extractor : spike interface recording extractor

    """
    recording_path = (SpikeSortingRecording & key).fetch1("recording_path")
    return si.load_extractor(recording_path)

get_curated_sorting(key) staticmethod

Returns the sorting extractor related to this curation, with merges applied.

Parameters:

    key : dict (required)
        Curation key

Returns:

    sorting_extractor : spike interface sorting extractor
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@staticmethod
def get_curated_sorting(key: dict):
    """Returns the sorting extractor related to this curation,
    with merges applied.

    Parameters
    ----------
    key : dict
        Curation key

    Returns
    -------
    sorting_extractor: spike interface sorting extractor

    """
    sorting_path = (SpikeSorting & key).fetch1("sorting_path")
    sorting = si.load_extractor(sorting_path)
    merge_groups = (Curation & key).fetch1("merge_groups")
    # TODO: write code to get merged sorting extractor
    if len(merge_groups) != 0:
        return MergedSortingExtractor(
            parent_sorting=sorting, merge_groups=merge_groups
        )
    else:
        return sorting
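
For example, to inspect the units of a curation with merges applied (restriction values hypothetical):

key = (Curation & {"curation_id": 1}).fetch("KEY")[0]
sorting = Curation.get_curated_sorting(key)
print(sorting.get_unit_ids())   # each merge group appears as a single unit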

save_sorting_nwb(key, sorting, timestamps, sort_interval_list_name, sort_interval, labels=None, metrics=None, unit_ids=None) staticmethod

Store a sorting in a new AnalysisNwbfile

Parameters:

    key : dict (required)
        key to SpikeSorting table
    sorting : Sorting (required)
        sorting
    timestamps : array_like (required)
        Time stamps of the sorted recording; used to convert the spike
        timings from index to real time
    sort_interval_list_name : str (required)
        name of sort interval
    sort_interval : list (required)
        interval for start and end of sort
    labels : dict (default: None)
        curation labels
    metrics : dict (default: None)
        quality metrics
    unit_ids : list (default: None)
        IDs of units whose spiketrains to save

Returns:

    analysis_file_name : str
    units_object_id : str
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@staticmethod
def save_sorting_nwb(
    key,
    sorting,
    timestamps,
    sort_interval_list_name,
    sort_interval,
    labels=None,
    metrics=None,
    unit_ids=None,
):
    """Store a sorting in a new AnalysisNwbfile

    Parameters
    ----------
    key : dict
        key to SpikeSorting table
    sorting : si.Sorting
        sorting
    timestamps : array_like
        Time stamps of the sorted recording;
        used to convert the spike timings from index to real time
    sort_interval_list_name : str
        name of sort interval
    sort_interval : list
        interval for start and end of sort
    labels : dict, optional
        curation labels, by default None
    metrics : dict, optional
        quality metrics, by default None
    unit_ids : list, optional
        IDs of units whose spiketrains to save, by default None

    Returns
    -------
    analysis_file_name : str
    units_object_id : str

    """
    analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"])

    sort_interval_valid_times = (
        IntervalList & {"interval_list_name": sort_interval_list_name}
    ).fetch1("valid_times")

    units = dict()
    units_valid_times = dict()
    units_sort_interval = dict()

    if unit_ids is None:
        unit_ids = sorting.get_unit_ids()

    for unit_id in unit_ids:
        spike_times_in_samples = sorting.get_unit_spike_train(
            unit_id=unit_id
        )
        units[unit_id] = timestamps[spike_times_in_samples]
        units_valid_times[unit_id] = sort_interval_valid_times
        units_sort_interval[unit_id] = [sort_interval]

    object_ids = AnalysisNwbfile().add_units(
        analysis_file_name,
        units,
        units_valid_times,
        units_sort_interval,
        metrics=metrics,
        labels=labels,
    )
    AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name)

    if object_ids == "":
        logger.warning(
            "Sorting contains no units. "
            "Created an empty analysis nwb file anyway."
        )
        units_object_id = ""
    else:
        units_object_id = object_ids[0]

    return analysis_file_name, units_object_id
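
A sketch of the expected arguments (all inputs below are illustrative, and sorting_key and sorting are assumed to be in scope; in practice they come from Curation.get_curated_sorting and the recording):

import numpy as np

timestamps = np.arange(30_000 * 60) / 30_000.0   # 60 s at 30 kHz (illustrative)
analysis_file_name, units_object_id = Curation.save_sorting_nwb(
    key=sorting_key,                              # key into SpikeSorting
    sorting=sorting,                              # si.BaseSorting to save
    timestamps=timestamps,                        # sample index -> seconds
    sort_interval_list_name="raw data valid times",
    sort_interval=[timestamps[0], timestamps[-1]],
    labels={1: "accept"},
    metrics={"snr": {1: 4.2}},
)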

WaveformParameters

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class WaveformParameters(SpyglassMixin, dj.Manual):
    definition = """
    waveform_params_name: varchar(80) # name of waveform extraction parameters
    ---
    waveform_params: blob # a dict of waveform extraction parameters
    """

    def insert_default(self):
        """Inserts default waveform parameters"""
        waveform_params_name = "default_not_whitened"
        waveform_params = {
            "ms_before": 0.5,
            "ms_after": 0.5,
            "max_spikes_per_unit": 5000,
            "n_jobs": 5,
            "total_memory": "5G",
            "whiten": False,
        }
        self.insert1(
            [waveform_params_name, waveform_params], skip_duplicates=True
        )
        waveform_params_name = "default_whitened"
        waveform_params = {
            "ms_before": 0.5,
            "ms_after": 0.5,
            "max_spikes_per_unit": 5000,
            "n_jobs": 5,
            "total_memory": "5G",
            "whiten": True,
        }
        self.insert1(
            [waveform_params_name, waveform_params], skip_duplicates=True
        )

insert_default()

Inserts default waveform parameters

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def insert_default(self):
    """Inserts default waveform parameters"""
    waveform_params_name = "default_not_whitened"
    waveform_params = {
        "ms_before": 0.5,
        "ms_after": 0.5,
        "max_spikes_per_unit": 5000,
        "n_jobs": 5,
        "total_memory": "5G",
        "whiten": False,
    }
    self.insert1(
        [waveform_params_name, waveform_params], skip_duplicates=True
    )
    waveform_params_name = "default_whitened"
    waveform_params = {
        "ms_before": 0.5,
        "ms_after": 0.5,
        "max_spikes_per_unit": 5000,
        "n_jobs": 5,
        "total_memory": "5G",
        "whiten": True,
    }
    self.insert1(
        [waveform_params_name, waveform_params], skip_duplicates=True
    )
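
Custom parameter sets follow the same pattern; a sketch with a hypothetical name and values:

wp = WaveformParameters()
wp.insert_default()
wp.insert1(
    [
        "short_window",               # hypothetical parameter set name
        {
            "ms_before": 0.3,
            "ms_after": 0.3,
            "max_spikes_per_unit": 2000,
            "n_jobs": 5,
            "total_memory": "5G",
            "whiten": True,
        },
    ],
    skip_duplicates=True,
)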

Waveforms

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class Waveforms(SpyglassMixin, dj.Computed):
    definition = """
    -> WaveformSelection
    ---
    waveform_extractor_path: varchar(400)
    -> AnalysisNwbfile
    waveforms_object_id: varchar(40)   # Object ID for the waveforms in NWB file
    """

    def make(self, key):
        """Populate Waveforms table with waveform extraction results

        1. Fetches ...
            - Recording and sorting from Curation table
            - Parameters from WaveformParameters table
        2. Uses spikeinterface to extract waveforms
        3. Generates an analysis NWB file with the waveforms
        4. Inserts the key into Waveforms table
        """
        key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
            key["nwb_file_name"]
        )
        recording = Curation.get_recording(key)
        if recording.get_num_segments() > 1:
            recording = si.concatenate_recordings([recording])

        sorting = Curation.get_curated_sorting(key)

        logger.info("Extracting waveforms...")
        waveform_params = (WaveformParameters & key).fetch1("waveform_params")
        if "whiten" in waveform_params:
            if waveform_params.pop("whiten"):
                recording = sip.whiten(recording, dtype="float32")

        waveform_extractor_name = self._get_waveform_extractor_name(key)
        key["waveform_extractor_path"] = str(
            Path(waveforms_dir) / Path(waveform_extractor_name)
        )
        if os.path.exists(key["waveform_extractor_path"]):
            shutil.rmtree(key["waveform_extractor_path"])
        waveforms = si.extract_waveforms(
            recording=recording,
            sorting=sorting,
            folder=key["waveform_extractor_path"],
            **waveform_params,
        )

        object_id = AnalysisNwbfile().add_units_waveforms(
            key["analysis_file_name"], waveform_extractor=waveforms
        )
        key["waveforms_object_id"] = object_id
        AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"])

        AnalysisNwbfile().log(key, table=self.full_table_name)
        self.insert1(key)

    def load_waveforms(self, key: dict):
        """Returns a spikeinterface waveform extractor specified by key

        Parameters
        ----------
        key : dict
            Could be an entry in Waveforms, or some other key that uniquely defines
            an entry in Waveforms

        Returns
        -------
        we : spikeinterface.WaveformExtractor
        """
        we_path = (self & key).fetch1("waveform_extractor_path")
        we = si.WaveformExtractor.load_from_folder(we_path)
        return we

    def fetch_nwb(self, key):
        """Fetches the NWB file path for the waveforms. NOT YET IMPLEMENTED."""
        # TODO: implement fetching waveforms from NWB
        raise NotImplementedError

    def _get_waveform_extractor_name(self, key):
        waveform_params_name = (WaveformParameters & key).fetch1(
            "waveform_params_name"
        )

        return (
            f'{key["nwb_file_name"]}_{str(uuid.uuid4())[0:8]}_'
            f'{key["curation_id"]}_{waveform_params_name}_waveforms'
        )

make(key)

Populate Waveforms table with waveform extraction results

  1. Fetches ...
    • Recording and sorting from Curation table
    • Parameters from WaveformParameters table
  2. Uses spikeinterface to extract waveforms
  3. Generates an analysis NWB file with the waveforms
  4. Inserts the key into Waveforms table
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def make(self, key):
    """Populate Waveforms table with waveform extraction results

    1. Fetches ...
        - Recording and sorting from Curation table
        - Parameters from WaveformParameters table
    2. Uses spikeinterface to extract waveforms
    3. Generates an analysis NWB file with the waveforms
    4. Inserts the key into Waveforms table
    """
    key["analysis_file_name"] = AnalysisNwbfile().create(  # logged
        key["nwb_file_name"]
    )
    recording = Curation.get_recording(key)
    if recording.get_num_segments() > 1:
        recording = si.concatenate_recordings([recording])

    sorting = Curation.get_curated_sorting(key)

    logger.info("Extracting waveforms...")
    waveform_params = (WaveformParameters & key).fetch1("waveform_params")
    if "whiten" in waveform_params:
        if waveform_params.pop("whiten"):
            recording = sip.whiten(recording, dtype="float32")

    waveform_extractor_name = self._get_waveform_extractor_name(key)
    key["waveform_extractor_path"] = str(
        Path(waveforms_dir) / Path(waveform_extractor_name)
    )
    if os.path.exists(key["waveform_extractor_path"]):
        shutil.rmtree(key["waveform_extractor_path"])
    waveforms = si.extract_waveforms(
        recording=recording,
        sorting=sorting,
        folder=key["waveform_extractor_path"],
        **waveform_params,
    )

    object_id = AnalysisNwbfile().add_units_waveforms(
        key["analysis_file_name"], waveform_extractor=waveforms
    )
    key["waveforms_object_id"] = object_id
    AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"])

    AnalysisNwbfile().log(key, table=self.full_table_name)
    self.insert1(key)
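
Downstream usage follows the usual DataJoint selection-then-populate pattern; a sketch assuming curation_key identifies a Curation entry:

WaveformSelection.insert1(
    {**curation_key, "waveform_params_name": "default_whitened"},
    skip_duplicates=True,
)
Waveforms.populate(curation_key)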

load_waveforms(key)

Returns a spikeinterface waveform extractor specified by key

Parameters:

    key : dict (required)
        Could be an entry in Waveforms, or some other key that uniquely
        defines an entry in Waveforms

Returns:

    we : spikeinterface.WaveformExtractor
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def load_waveforms(self, key: dict):
    """Returns a spikeinterface waveform extractor specified by key

    Parameters
    ----------
    key : dict
        Could be an entry in Waveforms, or some other key that uniquely defines
        an entry in Waveforms

    Returns
    -------
    we : spikeinterface.WaveformExtractor
    """
    we_path = (self & key).fetch1("waveform_extractor_path")
    we = si.WaveformExtractor.load_from_folder(we_path)
    return we
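
For example (key values hypothetical):

we = Waveforms().load_waveforms(curation_key)
unit_id = we.sorting.get_unit_ids()[0]
wfs = we.get_waveforms(unit_id=unit_id)
print(wfs.shape)   # (num_spikes, num_samples, num_channels)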

fetch_nwb(key)

Fetches the NWB file path for the waveforms. NOT YET IMPLEMENTED.

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def fetch_nwb(self, key):
    """Fetches the NWB file path for the waveforms. NOT YET IMPLEMENTED."""
    # TODO: implement fetching waveforms from NWB
    raise NotImplementedError

MetricParameters

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class MetricParameters(SpyglassMixin, dj.Manual):
    definition = """
    # Parameters for computing quality metrics of sorted units
    metric_params_name: varchar(64)
    ---
    metric_params: blob
    """

    # NOTE: See #630, #664. Excessive key length.

    metric_default_params = {
        "snr": {
            "peak_sign": "neg",
            "random_chunk_kwargs_dict": {
                "num_chunks_per_segment": 20,
                "chunk_size": 10000,
                "seed": 0,
            },
        },
        "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0.0},
        "nn_isolation": {
            "max_spikes": 1000,
            "min_spikes": 10,
            "n_neighbors": 5,
            "n_components": 7,
            "radius_um": 100,
            "seed": 0,
        },
        "nn_noise_overlap": {
            "max_spikes": 1000,
            "min_spikes": 10,
            "n_neighbors": 5,
            "n_components": 7,
            "radius_um": 100,
            "seed": 0,
        },
        "peak_channel": {"peak_sign": "neg"},
        "num_spikes": {},
    }
    # Example of peak_offset parameters 'peak_offset': {'peak_sign': 'neg'}
    available_metrics = [
        "snr",
        "isi_violation",
        "nn_isolation",
        "nn_noise_overlap",
        "peak_offset",
        "peak_channel",
        "num_spikes",
    ]

    def get_metric_default_params(self, metric: str):
        "Returns default params for the given metric"
        return self.metric_default_params[metric]

    def insert_default(self) -> None:
        """Inserts default metric parameters"""
        self.insert1(
            ["franklab_default3", self.metric_default_params],
            skip_duplicates=True,
        )

    def get_available_metrics(self):
        """Log available metrics and their descriptions"""
        for metric in _metric_name_to_func:
            if metric in self.available_metrics:
                metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0]
                metric_string = ("{metric_name} : {metric_doc}").format(
                    metric_name=metric, metric_doc=metric_doc
                )
                logger.info(metric_string + "\n")

    # TODO
    def _validate_metrics_list(self, key):
        """Checks whether a row to be inserted contains only available metrics"""
        # get available metrics list
        # get metric list from key
        # compare
        raise NotImplementedError

get_metric_default_params(metric)

Returns default params for the given metric

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def get_metric_default_params(self, metric: str):
    "Returns default params for the given metric"
    return self.metric_default_params[metric]

insert_default()

Inserts default metric parameters

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def insert_default(self) -> None:
    """Inserts default metric parameters"""
    self.insert1(
        ["franklab_default3", self.metric_default_params],
        skip_duplicates=True,
    )
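
A custom set can reuse pieces of the class-level defaults; a sketch with a hypothetical name:

custom_params = {
    "snr": dict(MetricParameters.metric_default_params["snr"]),
    "num_spikes": {},
}
MetricParameters().insert1(
    ["snr_only", custom_params], skip_duplicates=True
)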

get_available_metrics()

Log available metrics and their descriptions

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def get_available_metrics(self):
    """Log available metrics and their descriptions"""
    for metric in _metric_name_to_func:
        if metric in self.available_metrics:
            metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0]
            metric_string = ("{metric_name} : {metric_doc}").format(
                metric_name=metric, metric_doc=metric_doc
            )
            logger.info(metric_string + "\n")

MetricSelection

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class MetricSelection(SpyglassMixin, dj.Manual):
    definition = """
    -> Waveforms
    -> MetricParameters
    """

    def insert1(self, key, **kwargs):
        """Overriding insert1 to add warnings for peak_offset and peak_channel"""
        waveform_params = (WaveformParameters & key).fetch1("waveform_params")
        metric_params = (MetricParameters & key).fetch1("metric_params")
        if "peak_offset" in metric_params:
            if waveform_params["whiten"]:
                warnings.warn(
                    "Calculating 'peak_offset' metric on "
                    "whitened waveforms may result in slight "
                    "discrepancies"
                )
        if "peak_channel" in metric_params:
            if waveform_params["whiten"]:
                warnings.warn(
                    "Calculating 'peak_channel' metric on "
                    "whitened waveforms may result in slight "
                    "discrepancies"
                )
        super().insert1(key, **kwargs)

insert1(key, **kwargs)

Overriding insert1 to add warnings for peak_offset and peak_channel

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def insert1(self, key, **kwargs):
    """Overriding insert1 to add warnings for peak_offset and peak_channel"""
    waveform_params = (WaveformParameters & key).fetch1("waveform_params")
    metric_params = (MetricParameters & key).fetch1("metric_params")
    if "peak_offset" in metric_params:
        if waveform_params["whiten"]:
            warnings.warn(
                "Calculating 'peak_offset' metric on "
                "whitened waveforms may result in slight "
                "discrepancies"
            )
    if "peak_channel" in metric_params:
        if waveform_params["whiten"]:
            warnings.warn(
                "Calculating 'peak_channel' metric on "
                "whitened waveforms may result in slight "
                "discrepancies"
            )
    super().insert1(key, **kwargs)

QualityMetrics

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class QualityMetrics(SpyglassMixin, dj.Computed):
    definition = """
    -> MetricSelection
    ---
    quality_metrics_path: varchar(500)
    -> AnalysisNwbfile
    object_id: varchar(40) # Object ID for the metrics in NWB file
    """

    def make(self, key):
        """Populate QualityMetrics table with quality metric results.

        1. Fetches ...
            - Waveform extractor from Waveforms table
            - Parameters from MetricParameters table
        2. Computes metrics, including SNR, ISI violation, NN isolation,
            NN noise overlap, peak offset, peak channel, and number of spikes.
        3. Generates an analysis NWB file with the metrics.
        4. Inserts the key into QualityMetrics table
        """
        analysis_file_name = AnalysisNwbfile().create(  # logged
            key["nwb_file_name"]
        )
        waveform_extractor = Waveforms().load_waveforms(key)
        key["analysis_file_name"] = (
            analysis_file_name  # add to key here to prevent fetch errors
        )
        qm = {}
        params = (MetricParameters & key).fetch1("metric_params")
        for metric_name, metric_params in params.items():
            metric = self._compute_metric(
                waveform_extractor, metric_name, **metric_params
            )
            qm[metric_name] = metric
        qm_name = self._get_quality_metrics_name(key)
        key["quality_metrics_path"] = str(
            Path(waveforms_dir) / Path(qm_name + ".json")
        )
        # save metrics dict as json
        logger.info(f"Computed all metrics: {qm}")
        self._dump_to_json(qm, key["quality_metrics_path"])

        key["object_id"] = AnalysisNwbfile().add_units_metrics(
            key["analysis_file_name"], metrics=qm
        )
        AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"])
        AnalysisNwbfile().log(key, table=self.full_table_name)

        self.insert1(key)

    def _get_quality_metrics_name(self, key):
        wf_name = Waveforms()._get_waveform_extractor_name(key)
        qm_name = wf_name + "_qm"
        return qm_name

    def _compute_metric(self, waveform_extractor, metric_name, **metric_params):
        metric_func = _metric_name_to_func[metric_name]

        peak_sign_metrics = ["snr", "peak_offset", "peak_channel"]
        if metric_name == "isi_violation":
            return metric_func(waveform_extractor, **metric_params)
        elif metric_name in peak_sign_metrics:
            if "peak_sign" not in metric_params:
                raise Exception(
                    f"{peak_sign_metrics} metrics require peak_sign",
                    "to be defined in the metric parameters",
                )
            return metric_func(
                waveform_extractor,
                peak_sign=metric_params.pop("peak_sign"),
                **metric_params,
            )

        metric = {}
        num_spikes = sq.compute_num_spikes(waveform_extractor)

        is_nn_iso = metric_name == "nn_isolation"
        is_nn_overlap = metric_name == "nn_noise_overlap"
        min_spikes = metric_params.get("min_spikes", 10)

        for unit_id in waveform_extractor.sorting.get_unit_ids():
            # checks to avoid bug in spikeinterface 0.98.2
            if num_spikes[unit_id] < min_spikes:
                if is_nn_iso:
                    metric[str(unit_id)] = (np.nan, np.nan)
                elif is_nn_overlap:
                    metric[str(unit_id)] = np.nan

            else:
                metric[str(unit_id)] = metric_func(
                    waveform_extractor,
                    this_unit_id=int(unit_id),
                    **metric_params,
                )
            # nn_isolation returns tuple with isolation and unit number.
            # We only want isolation.
            if is_nn_iso:
                metric[str(unit_id)] = metric[str(unit_id)][0]
        return metric

    def _dump_to_json(self, qm_dict, save_path):
        new_qm = {}
        for key, value in qm_dict.items():
            m = {}
            for unit_id, metric_val in value.items():
                m[str(unit_id)] = np.float64(metric_val)
            new_qm[str(key)] = m
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(new_qm, f, ensure_ascii=False, indent=4)

make(key)

Populate QualityMetrics table with quality metric results.

  1. Fetches ...
    • Waveform extractor from Waveforms table
    • Parameters from MetricParameters table
  2. Computes metrics, including SNR, ISI violation, NN isolation, NN noise overlap, peak offset, peak channel, and number of spikes.
  3. Generates an analysis NWB file with the metrics.
  4. Inserts the key into QualityMetrics table
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def make(self, key):
    """Populate QualityMetrics table with quality metric results.

    1. Fetches ...
        - Waveform extractor from Waveforms table
        - Parameters from MetricParameters table
    2. Computes metrics, including SNR, ISI violation, NN isolation,
        NN noise overlap, peak offset, peak channel, and number of spikes.
    3. Generates an analysis NWB file with the metrics.
    4. Inserts the key into QualityMetrics table
    """
    analysis_file_name = AnalysisNwbfile().create(  # logged
        key["nwb_file_name"]
    )
    waveform_extractor = Waveforms().load_waveforms(key)
    key["analysis_file_name"] = (
        analysis_file_name  # add to key here to prevent fetch errors
    )
    qm = {}
    params = (MetricParameters & key).fetch1("metric_params")
    for metric_name, metric_params in params.items():
        metric = self._compute_metric(
            waveform_extractor, metric_name, **metric_params
        )
        qm[metric_name] = metric
    qm_name = self._get_quality_metrics_name(key)
    key["quality_metrics_path"] = str(
        Path(waveforms_dir) / Path(qm_name + ".json")
    )
    # save metrics dict as json
    logger.info(f"Computed all metrics: {qm}")
    self._dump_to_json(qm, key["quality_metrics_path"])

    key["object_id"] = AnalysisNwbfile().add_units_metrics(
        key["analysis_file_name"], metrics=qm
    )
    AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"])
    AnalysisNwbfile().log(key, table=self.full_table_name)

    self.insert1(key)
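
A sketch of the upstream selection step, assuming waveforms_key identifies a Waveforms entry:

MetricSelection.insert1(
    {**waveforms_key, "metric_params_name": "franklab_default3"},
    skip_duplicates=True,
)
QualityMetrics.populate(waveforms_key)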

AutomaticCurationParameters

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class AutomaticCurationParameters(SpyglassMixin, dj.Manual):
    definition = """
    auto_curation_params_name: varchar(36)   # name of this parameter set
    ---
    merge_params: blob   # dictionary of params to merge units
    label_params: blob   # dictionary params to label units
    """

    # NOTE: No existing entries impacted by this change

    def insert1(self, key, **kwargs):
        """Overriding insert1 to validats label_params and merge_params"""
        # validate the labels and then insert
        # TODO: add validation for merge_params
        for metric in key["label_params"]:
            if metric not in _metric_name_to_func:
                raise Exception(f"{metric} not in list of available metrics")
            comparison_list = key["label_params"][metric]
            if comparison_list[0] not in _comparison_to_function:
                raise Exception(
                    f'{metric}: "{comparison_list[0]}" '
                    f"not in list of available comparisons"
                )
            if not isinstance(comparison_list[1], (int, float)):
                raise Exception(
                    f"{metric}: {comparison_list[1]} is of type "
                    f"{type(comparison_list[1])} and not a number"
                )
            for label in comparison_list[2]:
                if label not in valid_labels:
                    raise Exception(
                        f'{metric}: "{label}" '
                        f"not in list of valid labels: {valid_labels}"
                    )
        super().insert1(key, **kwargs)

    def insert_default(self):
        """Inserts default automatic curation parameters"""
        # label_params parsing: Each key is the name of a metric,
        # the contents are a three value list with the comparison, a value,
        # and a list of labels to apply if the comparison is true
        default_params = {
            "auto_curation_params_name": "default",
            "merge_params": {},
            "label_params": {
                "nn_noise_overlap": [">", 0.1, ["noise", "reject"]]
            },
        }
        self.insert1(default_params, skip_duplicates=True)

        # Second default parameter set for not applying any labels,
        # or merges, but adding metrics
        no_label_params = {
            "auto_curation_params_name": "none",
            "merge_params": {},
            "label_params": {},
        }
        self.insert1(no_label_params, skip_duplicates=True)

insert1(key, **kwargs)

Overriding insert1 to validate label_params and merge_params

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def insert1(self, key, **kwargs):
    """Overriding insert1 to validats label_params and merge_params"""
    # validate the labels and then insert
    # TODO: add validation for merge_params
    for metric in key["label_params"]:
        if metric not in _metric_name_to_func:
            raise Exception(f"{metric} not in list of available metrics")
        comparison_list = key["label_params"][metric]
        if comparison_list[0] not in _comparison_to_function:
            raise Exception(
                f'{metric}: "{comparison_list[0]}" '
                f"not in list of available comparisons"
            )
        if not isinstance(comparison_list[1], (int, float)):
            raise Exception(
                f"{metric}: {comparison_list[1]} is of type "
                f"{type(comparison_list[1])} and not a number"
            )
        for label in comparison_list[2]:
            if label not in valid_labels:
                raise Exception(
                    f'{metric}: "{label}" '
                    f"not in list of valid labels: {valid_labels}"
                )
    super().insert1(key, **kwargs)

insert_default()

Inserts default automatic curation parameters

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def insert_default(self):
    """Inserts default automatic curation parameters"""
    # label_params parsing: Each key is the name of a metric,
    # the contents are a three value list with the comparison, a value,
    # and a list of labels to apply if the comparison is true
    default_params = {
        "auto_curation_params_name": "default",
        "merge_params": {},
        "label_params": {
            "nn_noise_overlap": [">", 0.1, ["noise", "reject"]]
        },
    }
    self.insert1(default_params, skip_duplicates=True)

    # Second default parameter set for not applying any labels,
    # or merges, but adding metrics
    no_label_params = {
        "auto_curation_params_name": "none",
        "merge_params": {},
        "label_params": {},
    }
    self.insert1(no_label_params, skip_duplicates=True)
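
A custom parameter set follows the same three-element format for each label_params entry (comparison operator as a string, threshold value, labels to apply); the name and threshold below are illustrative:

AutomaticCurationParameters().insert1(
    {
        "auto_curation_params_name": "strict_noise",   # hypothetical name
        "merge_params": {},
        # label units as noise at a stricter threshold than the default 0.1
        "label_params": {"nn_noise_overlap": [">", 0.05, ["noise", "reject"]]},
    },
    skip_duplicates=True,
)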

AutomaticCuration

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class AutomaticCuration(SpyglassMixin, dj.Computed):
    definition = """
    -> AutomaticCurationSelection
    ---
    auto_curation_key: blob # the key to the curation inserted by make
    """

    def make(self, key):
        """Populate AutomaticCuration table with automatic curation results.

        1. Fetches ...
            - Quality metrics from QualityMetrics table
            - Parameters from AutomaticCurationParameters table
            - Parent curation/sorting from Curation table
        2. Curates the sorting based on provided merge and label parameters
        3. Inserts IDs into AutomaticCuration and Curation tables
        """
        metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path")
        with open(metrics_path) as f:
            quality_metrics = json.load(f)

        # get the curation information and the curated sorting
        parent_curation = (Curation & key).fetch(as_dict=True)[0]
        parent_merge_groups = parent_curation["merge_groups"]
        parent_labels = parent_curation["curation_labels"]
        parent_curation_id = parent_curation["curation_id"]
        parent_sorting = Curation.get_curated_sorting(key)

        merge_params = (AutomaticCurationParameters & key).fetch1(
            "merge_params"
        )
        merge_groups, units_merged = self.get_merge_groups(
            parent_sorting, parent_merge_groups, quality_metrics, merge_params
        )

        label_params = (AutomaticCurationParameters & key).fetch1(
            "label_params"
        )
        labels = self.get_labels(
            parent_sorting, parent_labels, quality_metrics, label_params
        )

        # keep the quality metrics only if no merging occurred.
        metrics = quality_metrics if not units_merged else None

        # insert this sorting into the CuratedSpikeSorting Table
        # first remove keys that aren't part of the Sorting (the primary key of curation)
        c_key = (SpikeSorting & key).fetch("KEY")[0]
        curation_key = {item: key[item] for item in key if item in c_key}
        key["auto_curation_key"] = Curation.insert_curation(
            curation_key,
            parent_curation_id=parent_curation_id,
            labels=labels,
            merge_groups=merge_groups,
            metrics=metrics,
            description="auto curated",
        )

        self.insert1(key)

    @staticmethod
    def get_merge_groups(
        sorting, parent_merge_groups, quality_metrics, merge_params
    ):
        """Identifies units to be merged based on the quality_metrics and
        merge parameters and returns an updated list of merges for the curation.

        Parameters
        ---------
        sorting : spikeinterface.sorting
        parent_merge_groups : list
            Information about previous merges
        quality_metrics : list
        merge_params : dict

        Returns
        -------
        merge_groups : list of lists
        merge_occurred : bool

        """

        # overview:
        # 1. Use quality metrics to determine merge groups for units
        # 2. Combine merge groups with current merge groups to produce union of merges

        if not merge_params:
            return parent_merge_groups, False
        else:
            # TODO: use the metrics to identify clusters that should be merged
            # new_merges should then reflect those merges and the line below should be deleted.
            new_merges = []
            # append these merges to the parent merge_groups
            for new_merge in new_merges:
                # check to see if the first cluster listed is in a current merge group
                for previous_merge in parent_merge_groups:
                    if new_merge[0] == previous_merge[0]:
                        # add the additional units in new_merge to the identified merge group.
                        previous_merge.extend(new_merge[1:])
                        previous_merge.sort()
                        break
                else:
                    # append this merge group to the list if no previous merge
                    parent_merge_groups.append(new_merge)
            parent_merge_groups.sort()
            return parent_merge_groups, True

    @staticmethod
    def get_labels(sorting, parent_labels, quality_metrics, label_params):
        """Returns a dictionary of labels using quality_metrics and label
        parameters.

        Parameters
        ---------
        sorting : spikeinterface.sorting
        parent_labels : dict
            Labels from the parent curation
        quality_metrics : list
        label_params : dict

        Returns
        -------
        parent_labels : dict

        """
        # overview:
        # 1. Use quality metrics to determine labels for units
        # 2. Append labels to current labels, checking for inconsistencies
        if not label_params:
            return parent_labels
        else:
            for metric in label_params:
                if metric not in quality_metrics:
                    logger.warning(
                        f"{metric} not found in quality metrics; skipping"
                    )
                else:
                    compare = _comparison_to_function[label_params[metric][0]]

                    for unit_id in quality_metrics[metric].keys():
                        # compare the quality metric to the threshold with the specified operator
                        # note that label_params[metric] is a three element list with a comparison operator as a string,
                        # the threshold value, and a list of labels to be applied if the comparison is true
                        if compare(
                            quality_metrics[metric][unit_id],
                            label_params[metric][1],
                        ):
                            if unit_id not in parent_labels:
                                # copy to avoid aliasing the params entry
                                parent_labels[unit_id] = list(
                                    label_params[metric][2]
                                )
                            else:
                                # add each label that is not already present
                                for label in label_params[metric][2]:
                                    if label not in parent_labels[unit_id]:
                                        parent_labels[unit_id].append(label)
            return parent_labels

make(key)

Populate AutomaticCuration table with automatic curation results.

  1. Fetches ...
    • Quality metrics from QualityMetrics table
    • Parameters from AutomaticCurationParameters table
    • Parent curation/sorting from Curation table
  2. Curates the sorting based on provided merge and label parameters
  3. Inserts IDs into AutomaticCuration and Curation tables
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def make(self, key):
    """Populate AutomaticCuration table with automatic curation results.

    1. Fetches ...
        - Quality metrics from QualityMetrics table
        - Parameters from AutomaticCurationParameters table
        - Parent curation/sorting from Curation table
    2. Curates the sorting based on provided merge and label parameters
    3. Inserts IDs into AutomaticCuration and Curation tables
    """
    metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path")
    with open(metrics_path) as f:
        quality_metrics = json.load(f)

    # get the curation information and the curated sorting
    parent_curation = (Curation & key).fetch(as_dict=True)[0]
    parent_merge_groups = parent_curation["merge_groups"]
    parent_labels = parent_curation["curation_labels"]
    parent_curation_id = parent_curation["curation_id"]
    parent_sorting = Curation.get_curated_sorting(key)

    merge_params = (AutomaticCurationParameters & key).fetch1(
        "merge_params"
    )
    merge_groups, units_merged = self.get_merge_groups(
        parent_sorting, parent_merge_groups, quality_metrics, merge_params
    )

    label_params = (AutomaticCurationParameters & key).fetch1(
        "label_params"
    )
    labels = self.get_labels(
        parent_sorting, parent_labels, quality_metrics, label_params
    )

    # keep the quality metrics only if no merging occurred.
    metrics = quality_metrics if not units_merged else None

    # insert this sorting into the CuratedSpikeSorting Table
    # first remove keys that aren't part of the Sorting (the primary key of curation)
    c_key = (SpikeSorting & key).fetch("KEY")[0]
    curation_key = {item: key[item] for item in key if item in c_key}
    key["auto_curation_key"] = Curation.insert_curation(
        curation_key,
        parent_curation_id=parent_curation_id,
        labels=labels,
        merge_groups=merge_groups,
        metrics=metrics,
        description="auto curated",
    )

    self.insert1(key)
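
A sketch of the selection-then-populate step, assuming metrics_key identifies a QualityMetrics entry:

AutomaticCurationSelection.insert1(
    {**metrics_key, "auto_curation_params_name": "default"},
    skip_duplicates=True,
)
AutomaticCuration.populate(metrics_key)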

get_merge_groups(sorting, parent_merge_groups, quality_metrics, merge_params) staticmethod

Identifies units to be merged based on the quality_metrics and merge parameters and returns an updated list of merges for the curation.

Parameters:

    sorting : spikeinterface.sorting (required)
    parent_merge_groups : list (required)
        Information about previous merges
    quality_metrics : list (required)
    merge_params : dict (required)

Returns:

    merge_groups : list of lists
    merge_occurred : bool
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@staticmethod
def get_merge_groups(
    sorting, parent_merge_groups, quality_metrics, merge_params
):
    """Identifies units to be merged based on the quality_metrics and
    merge parameters and returns an updated list of merges for the curation.

    Parameters
    ---------
    sorting : spikeinterface.sorting
    parent_merge_groups : list
        Information about previous merges
    quality_metrics : list
    merge_params : dict

    Returns
    -------
    merge_groups : list of lists
    merge_occurred : bool

    """

    # overview:
    # 1. Use quality metrics to determine merge groups for units
    # 2. Combine merge groups with current merge groups to produce union of merges

    if not merge_params:
        return parent_merge_groups, False
    else:
        # TODO: use the metrics to identify clusters that should be merged
        # new_merges should then reflect those merges and the line below should be deleted.
        new_merges = []
        # append these merges to the parent merge_groups
        for new_merge in new_merges:
            # check to see if the first cluster listed is in a current merge group
            for previous_merge in parent_merge_groups:
                if new_merge[0] == previous_merge[0]:
                    # add the additional units in new_merge to the identified merge group.
                    previous_merge.extend(new_merge[1:])
                    previous_merge.sort()
                    break
            else:
                # append this merge group to the list if no previous merge
                parent_merge_groups.append(new_merge)
        parent_merge_groups.sort()
        return parent_merge_groups, True

get_labels(sorting, parent_labels, quality_metrics, label_params) staticmethod

Returns a dictionary of labels using quality_metrics and label parameters.

Parameters:

    sorting : spikeinterface.sorting (required)
    parent_labels : dict (required)
        Labels from the parent curation
    quality_metrics : list (required)
    label_params : dict (required)

Returns:

    parent_labels : dict
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@staticmethod
def get_labels(sorting, parent_labels, quality_metrics, label_params):
    """Returns a dictionary of labels using quality_metrics and label
    parameters.

    Parameters
    ---------
    sorting : spikeinterface.sorting
    parent_labels : dict
        Labels from the parent curation
    quality_metrics : list
    label_params : dict

    Returns
    -------
    parent_labels : dict

    """
    # overview:
    # 1. Use quality metrics to determine labels for units
    # 2. Append labels to current labels, checking for inconsistencies
    if not label_params:
        return parent_labels
    else:
        for metric in label_params:
            if metric not in quality_metrics:
                logger.warning(
                    f"{metric} not found in quality metrics; skipping"
                )
            else:
                compare = _comparison_to_function[label_params[metric][0]]

                for unit_id in quality_metrics[metric].keys():
                    # compare the quality metric to the threshold with the specified operator
                    # note that label_params[metric] is a three element list with a comparison operator as a string,
                    # the threshold value, and a list of labels to be applied if the comparison is true
                    if compare(
                        quality_metrics[metric][unit_id],
                        label_params[metric][1],
                    ):
                        if unit_id not in parent_labels:
                            # copy to avoid aliasing the params entry
                            parent_labels[unit_id] = list(
                                label_params[metric][2]
                            )
                        else:
                            # add each label that is not already present
                            for label in label_params[metric][2]:
                                if label not in parent_labels[unit_id]:
                                    parent_labels[unit_id].append(label)
        return parent_labels
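
A self-contained sketch of the comparison logic (the operator module stands in for the module's _comparison_to_function mapping; values are illustrative):

import operator

label_params = {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]}
quality_metrics = {"nn_noise_overlap": {"1": 0.05, "2": 0.3}}

op, threshold, new_labels = label_params["nn_noise_overlap"]
compare = {">": operator.gt, "<": operator.lt}[op]

labels = {
    unit_id: list(new_labels)
    for unit_id, value in quality_metrics["nn_noise_overlap"].items()
    if compare(value, threshold)
}
print(labels)   # {'2': ['noise', 'reject']}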

CuratedSpikeSorting

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@schema
class CuratedSpikeSorting(SpyglassMixin, dj.Computed):
    definition = """
    -> CuratedSpikeSortingSelection
    ---
    -> AnalysisNwbfile
    units_object_id: varchar(40)
    """

    class Unit(SpyglassMixin, dj.Part):
        definition = """
        # Table for holding sorted units
        -> CuratedSpikeSorting
        unit_id: int   # ID for each unit
        ---
        label='': varchar(200)   # optional set of labels for each unit
        nn_noise_overlap=-1: float   # noise overlap metric for each unit
        nn_isolation=-1: float   # isolation score metric for each unit
        isi_violation=-1: float   # ISI violation score for each unit
        snr=0: float            # SNR for each unit
        firing_rate=-1: float   # firing rate
        num_spikes=-1: int   # total number of spikes
        peak_channel=null: int # channel of maximum amplitude for each unit
        """

    def make(self, key):
        """Populate CuratedSpikeSorting table with curated sorting results.

        1. Fetches metrics and sorting from the Curation table
        2. Saves the sorting in an analysis NWB file
        3. Inserts key into CuratedSpikeSorting table and units into part table.
        """
        AnalysisNwbfile()._creation_times["pre_create_time"] = time.time()
        unit_labels_to_remove = ["reject"]
        # check that the Curation has metrics
        metrics = (Curation & key).fetch1("quality_metrics")
        if metrics == {}:
            logger.warning(
                f"Metrics for Curation {key} should normally be calculated "
                + "before insertion here"
            )

        sorting = Curation.get_curated_sorting(key)
        unit_ids = sorting.get_unit_ids()

        # Get the labels for the units; keep only those units whose labels
        # do not include any label in unit_labels_to_remove
        unit_labels = (Curation & key).fetch1("curation_labels")
        accepted_units = []
        for unit_id in unit_ids:
            if unit_id in unit_labels:
                if (
                    len(set(unit_labels_to_remove) & set(unit_labels[unit_id]))
                    == 0
                ):
                    accepted_units.append(unit_id)
            else:
                accepted_units.append(unit_id)

        # get the labels for the accepted units
        labels = {}
        for unit_id in accepted_units:
            if unit_id in unit_labels:
                labels[unit_id] = ",".join(unit_labels[unit_id])

        # convert unit_ids in metrics to integers, including only accepted units.
        # TODO: do this int conversion somewhere else
        final_metrics = {}
        for metric in metrics:
            final_metrics[metric] = {
                int(unit_id): metrics[metric][unit_id]
                for unit_id in metrics[metric]
                if int(unit_id) in accepted_units
            }

        logger.info(f"Found {len(accepted_units)} accepted units")

        # get the sorting and save it in the NWB file
        sorting = Curation.get_curated_sorting(key)
        recording = Curation.get_recording(key)

        # get the sort_interval and sorting interval list
        sort_interval = (SortInterval & key).fetch1("sort_interval")
        sort_interval_list_name = (SpikeSorting & key).fetch1(
            "artifact_removed_interval_list_name"
        )

        timestamps = SpikeSortingRecording._get_recording_timestamps(recording)

        (
            key["analysis_file_name"],
            key["units_object_id"],
        ) = Curation().save_sorting_nwb(
            key,
            sorting,
            timestamps,
            sort_interval_list_name,
            sort_interval,
            metrics=final_metrics,
            unit_ids=accepted_units,
            labels=labels,
        )

        AnalysisNwbfile().log(key, table=self.full_table_name)
        self.insert1(key)

        # now add the units
        # Remove the non primary key entries.
        del key["units_object_id"]
        del key["analysis_file_name"]

        metric_fields = self.metrics_fields()
        for unit_id in accepted_units:
            key["unit_id"] = unit_id
            if unit_id in labels:
                key["label"] = labels[unit_id]
            for field in metric_fields:
                if field in final_metrics:
                    key[field] = final_metrics[field][unit_id]
                else:
                    logger.warning(
                        f"No metric named {field} in computed unit quality "
                        + "metrics; skipping"
                    )
            CuratedSpikeSorting.Unit.insert1(key)

    def metrics_fields(self):
        """Returns a list of the metrics that are currently in the Units table."""
        unit_info = self.Unit().fetch(limit=1, format="frame")
        unit_fields = list(unit_info.columns)
        unit_fields.remove("label")
        return unit_fields

    @classmethod
    def get_recording(cls, key):
        """Returns the recording related to this curation. Useful for operations downstream of merge table"""
        # expand the key
        recording_key = (cls & key).fetch1("KEY")
        return SpikeSortingRecording()._get_filtered_recording(recording_key)

    @classmethod
    def get_sorting(cls, key):
        """Returns the sorting related to this curation. Useful for operations downstream of merge table"""
        # expand the key
        sorting_key = (cls & key).fetch1("KEY")
        return Curation.get_curated_sorting(sorting_key)

    @classmethod
    def get_sort_group_info(cls, key):
        """Returns the sort group information for the curation
        (e.g. brain region, electrode placement, etc.)

        Parameters
        ----------
        key : dict
            restriction on CuratedSpikeSorting table

        Returns
        -------
        sort_group_info : Table
            Table with information about the sort groups
        """
        table = cls & key

        electrode_restrict_list = []
        for entry in table:
            # Just take one electrode entry per sort group
            electrode_restrict_list.extend(
                ((SortGroup.SortGroupElectrode() & entry) * Electrode).fetch(
                    limit=1
                )
            )
        # Join with the information tables and return the result
        sort_group_info = (
            (Electrode & electrode_restrict_list)
            * table
            * SortGroup.SortGroupElectrode()
        ) * BrainRegion()
        return sort_group_info
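
A minimal usage sketch (the restriction below is hypothetical; populate
and fetch follow standard DataJoint semantics):

# hypothetical restriction identifying a single curation
key = {"nwb_file_name": "example20230101_.nwb", "curation_id": 1}

# compute and insert the curated sorting for the selected entries
CuratedSpikeSorting.populate(key)

# per-unit labels and metrics as a pandas DataFrame
units = (CuratedSpikeSorting.Unit & key).fetch(format="frame")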

make(key)

Populate CuratedSpikeSorting table with curated sorting results.

  1. Fetches metrics and sorting from the Curation table
  2. Saves the sorting in an analysis NWB file
  3. Inserts key into CuratedSpikeSorting table and units into part table.
Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def make(self, key):
    """Populate CuratedSpikeSorting table with curated sorting results.

    1. Fetches metrics and sorting from the Curation table
    2. Saves the sorting in an analysis NWB file
    3. Inserts key into CuratedSpikeSorting table and units into part table.
    """
    AnalysisNwbfile()._creation_times["pre_create_time"] = time.time()
    unit_labels_to_remove = ["reject"]
    # check that the Curation has metrics
    metrics = (Curation & key).fetch1("quality_metrics")
    if metrics == {}:
        logger.warning(
            f"Metrics for Curation {key} should normally be calculated "
            + "before insertion here"
        )

    sorting = Curation.get_curated_sorting(key)
    unit_ids = sorting.get_unit_ids()

    # Get the labels for the units; keep only those units whose labels
    # do not include any label in unit_labels_to_remove
    unit_labels = (Curation & key).fetch1("curation_labels")
    accepted_units = []
    for unit_id in unit_ids:
        if unit_id in unit_labels:
            if (
                len(set(unit_labels_to_remove) & set(unit_labels[unit_id]))
                == 0
            ):
                accepted_units.append(unit_id)
        else:
            accepted_units.append(unit_id)

    # get the labels for the accepted units
    labels = {}
    for unit_id in accepted_units:
        if unit_id in unit_labels:
            labels[unit_id] = ",".join(unit_labels[unit_id])

    # convert unit_ids in metrics to integers, including only accepted units.
    # TODO: do this int conversion somewhere else
    final_metrics = {}
    for metric in metrics:
        final_metrics[metric] = {
            int(unit_id): metrics[metric][unit_id]
            for unit_id in metrics[metric]
            if int(unit_id) in accepted_units
        }

    logger.info(f"Found {len(accepted_units)} accepted units")

    # get the sorting and save it in the NWB file
    sorting = Curation.get_curated_sorting(key)
    recording = Curation.get_recording(key)

    # get the sort_interval and sorting interval list
    sort_interval = (SortInterval & key).fetch1("sort_interval")
    sort_interval_list_name = (SpikeSorting & key).fetch1(
        "artifact_removed_interval_list_name"
    )

    timestamps = SpikeSortingRecording._get_recording_timestamps(recording)

    (
        key["analysis_file_name"],
        key["units_object_id"],
    ) = Curation().save_sorting_nwb(
        key,
        sorting,
        timestamps,
        sort_interval_list_name,
        sort_interval,
        metrics=final_metrics,
        unit_ids=accepted_units,
        labels=labels,
    )

    AnalysisNwbfile().log(key, table=self.full_table_name)
    self.insert1(key)

    # now add the units
    # Remove the non primary key entries.
    del key["units_object_id"]
    del key["analysis_file_name"]

    metric_fields = self.metrics_fields()
    for unit_id in accepted_units:
        key["unit_id"] = unit_id
        if unit_id in labels:
            key["label"] = labels[unit_id]
        for field in metric_fields:
            if field in final_metrics:
                key[field] = final_metrics[field][unit_id]
            else:
                logger.warning(
                    f"No metric named {field} in computed unit quality "
                    + "metrics; skipping"
                )
        CuratedSpikeSorting.Unit.insert1(key)

metrics_fields()

Returns a list of the metrics that are currently in the Units table.

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
def metrics_fields(self):
    """Returns a list of the metrics that are currently in the Units table."""
    unit_info = self.Unit().fetch(limit=1, format="frame")
    unit_fields = list(unit_info.columns)
    unit_fields.remove("label")
    return unit_fields
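
Given the Unit definition above, and assuming DataJoint's frame format
indexes by the primary key and exposes only secondary attributes as
columns, the result would be every non-key column except label:

CuratedSpikeSorting().metrics_fields()
# ['nn_noise_overlap', 'nn_isolation', 'isi_violation', 'snr',
#  'firing_rate', 'num_spikes', 'peak_channel']

Note that peak_channel appears even though it is not a quality metric,
since only label is excluded.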

get_recording(key) classmethod

Returns the recording related to this curation. Useful for operations downstream of the merge table

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@classmethod
def get_recording(cls, key):
    """Returns the recording related to this curation. Useful for operations downstream of merge table"""
    # expand the key
    recording_key = (cls & key).fetch1("KEY")
    return SpikeSortingRecording()._get_filtered_recording(recording_key)

get_sorting(key) classmethod

Returns the sorting related to this curation. Useful for operations downstream of the merge table

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@classmethod
def get_sorting(cls, key):
    """Returns the sorting related to this curation. Useful for operations downstream of merge table"""
    # expand the key
    sorting_key = (cls & key).fetch1("KEY")
    return Curation.get_curated_sorting(sorting_key)
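
A brief usage sketch (the key is a hypothetical restriction and must
resolve to a single entry, since both methods call fetch1 internally):

key = {"nwb_file_name": "example20230101_.nwb", "curation_id": 1}

recording = CuratedSpikeSorting.get_recording(key)  # SpikeInterface recording
sorting = CuratedSpikeSorting.get_sorting(key)      # SpikeInterface sorting

# the usual SpikeInterface accessors then apply
n_channels = recording.get_num_channels()
unit_ids = sorting.get_unit_ids()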

get_sort_group_info(key) classmethod

Returns the sort group information for the curation (e.g. brain region, electrode placement, etc.)

Parameters:

| Name | Type | Description                              | Default  |
|------|------|------------------------------------------|----------|
| key  | dict | restriction on CuratedSpikeSorting table | required |

Returns:

| Name            | Type  | Description                                   |
|-----------------|-------|-----------------------------------------------|
| sort_group_info | Table | Table with information about the sort groups |

Source code in src/spyglass/spikesorting/v0/spikesorting_curation.py
@classmethod
def get_sort_group_info(cls, key):
    """Returns the sort group information for the curation
    (e.g. brain region, electrode placement, etc.)

    Parameters
    ----------
    key : dict
        restriction on CuratedSpikeSorting table

    Returns
    -------
    sort_group_info : Table
        Table with information about the sort groups
    """
    table = cls & key

    electrode_restrict_list = []
    for entry in table:
        # Just take one electrode entry per sort group
        electrode_restrict_list.extend(
            ((SortGroup.SortGroupElectrode() & entry) * Electrode).fetch(
                limit=1
            )
        )
    # Join with the information tables and return the result
    sort_group_info = (
        (Electrode & electrode_restrict_list)
        * table
        * SortGroup.SortGroupElectrode()
    ) * BrainRegion()
    return sort_group_info
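
A usage sketch; the attribute names fetched below (sort_group_id from
SortGroup, region_name from BrainRegion) are assumptions about the
columns exposed by the joined tables:

key = {"nwb_file_name": "example20230101_.nwb", "curation_id": 1}
sort_group_info = CuratedSpikeSorting.get_sort_group_info(key)

# e.g., the brain region associated with each sort group
groups, regions = sort_group_info.fetch("sort_group_id", "region_name")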