sorting.py

SpikeSortingSelection

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v1/sorting.py
@schema
class SpikeSortingSelection(SpyglassMixin, dj.Manual):
    definition = """
    # Processed recording and spike sorting parameters. Use `insert_selection` method to insert rows.
    sorting_id: uuid
    ---
    -> SpikeSortingRecording
    -> SpikeSorterParameters
    -> IntervalList
    """

    @classmethod
    def insert_selection(cls, key: dict):
        """Insert a row into SpikeSortingSelection with an
        automatically generated unique sorting ID as the sole primary key.

        Parameters
        ----------
        key : dict
            primary key of SpikeSortingRecording, SpikeSorterParameters, IntervalList tables

        Returns
        -------
        key : dict
            the input key with a generated `sorting_id` (uuid) added; if
            matching rows already exist, they are returned as a list of dicts
        """
        query = cls & key
        if query:
            logger.info("Similar row(s) already inserted.")
            return query.fetch(as_dict=True)
        key["sorting_id"] = uuid.uuid4()
        cls.insert1(key, skip_duplicates=True)
        return key

insert_selection(key) classmethod

Insert a row into SpikeSortingSelection with an automatically generated unique sorting ID as the sole primary key.

Parameters:

    key : dict (required)
        primary key of SpikeSortingRecording, SpikeSorterParameters, IntervalList tables

Returns:

    key : dict
        the input key with a generated sorting_id (uuid) added; if matching rows already exist, they are returned as a list of dicts
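
A minimal usage sketch. The recording_id, nwb_file_name, and interval_list_name attributes appear in the source above; the values and the sorter_param_name attribute are hypothetical placeholders, so check the upstream table definitions for the exact primary keys.

key = {
    "recording_id": "00000000-0000-0000-0000-000000000000",  # placeholder uuid from SpikeSortingRecording
    "sorter": "mountainsort4",  # sorter name from SpikeSorterParameters
    "sorter_param_name": "default",  # hypothetical parameter-set name
    "nwb_file_name": "example20230101_.nwb",  # placeholder, from IntervalList
    "interval_list_name": "artifact_removed_interval_list",  # placeholder
}
sorting_key = SpikeSortingSelection.insert_selection(key)
# insert_selection generates a fresh uuid as sorting_id and returns the key;
# if a matching row already exists, the existing row(s) are returned instead.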


SpikeSorting

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v1/sorting.py
@schema
class SpikeSorting(SpyglassMixin, dj.Computed):
    definition = """
    -> SpikeSortingSelection
    ---
    -> AnalysisNwbfile
    object_id: varchar(40)          # Object ID for the sorting in NWB file
    time_of_sort: int               # in Unix time, to the nearest second
    """

    def make(self, key: dict):
        """Runs spike sorting on the data and parameters specified by the
        SpikeSortingSelection table and inserts a new entry to SpikeSorting table.
        """
        # FETCH:
        # - information about the recording
        # - artifact free intervals
        # - spike sorter and sorter params
        recording_key = (
            SpikeSortingRecording * SpikeSortingSelection & key
        ).fetch1()
        artifact_removed_intervals = (
            IntervalList
            & {
                "nwb_file_name": (SpikeSortingSelection & key).fetch1(
                    "nwb_file_name"
                ),
                "interval_list_name": (SpikeSortingSelection & key).fetch1(
                    "interval_list_name"
                ),
            }
        ).fetch1("valid_times")
        sorter, sorter_params = (
            SpikeSorterParameters * SpikeSortingSelection & key
        ).fetch1("sorter", "sorter_params")

        # DO:
        # - load recording
        # - concatenate artifact removed intervals
        # - run spike sorting
        # - save output to NWB file
        recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(
            recording_key["analysis_file_name"]
        )
        recording = se.read_nwb_recording(
            recording_analysis_nwb_file_abs_path, load_time_vector=True
        )

        timestamps = recording.get_times()

        artifact_removed_intervals_ind = _consolidate_intervals(
            artifact_removed_intervals, timestamps
        )

        # if the artifact removed intervals do not span the entire time range
        if (
            (len(artifact_removed_intervals_ind) > 1)
            or (artifact_removed_intervals_ind[0][0] > 0)
            or (artifact_removed_intervals_ind[-1][1] < len(timestamps))
        ):
            # set the artifact intervals to zero
            list_triggers = []
            if artifact_removed_intervals_ind[0][0] > 0:
                # zero out all samples before the first artifact-free interval
                list_triggers.append(
                    np.arange(0, artifact_removed_intervals_ind[0][0])
                )
            for interval_ind in range(len(artifact_removed_intervals_ind) - 1):
                list_triggers.append(
                    np.arange(
                        (artifact_removed_intervals_ind[interval_ind][1] + 1),
                        artifact_removed_intervals_ind[interval_ind + 1][0],
                    )
                )
            if artifact_removed_intervals_ind[-1][1] < len(timestamps):
                # zero out all samples after the last artifact-free interval
                list_triggers.append(
                    np.arange(
                        artifact_removed_intervals_ind[-1][1], len(timestamps)
                    )
                )

            list_triggers = [list(np.concatenate(list_triggers))]
            recording = sip.remove_artifacts(
                recording=recording,
                list_triggers=list_triggers,
                ms_before=None,
                ms_after=None,
                mode="zeros",
            )

        if sorter == "clusterless_thresholder":
            # need to remove tempdir and whiten from sorter_params
            sorter_params.pop("tempdir", None)
            sorter_params.pop("whiten", None)
            sorter_params.pop("outputs", None)

            # Detect peaks for clusterless decoding
            detected_spikes = detect_peaks(recording, **sorter_params)
            sorting = si.NumpySorting.from_times_labels(
                times_list=detected_spikes["sample_index"],
                labels_list=np.zeros(len(detected_spikes), dtype=int),
                sampling_frequency=recording.get_sampling_frequency(),
            )
        else:
            # Specify tempdir (expected by some sorters like mountainsort4)
            sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir)
            sorter_params["tempdir"] = sorter_temp_dir.name
            # if whitening is specified in sorter params, apply whitening separately
            # prior to sorting and turn off "sorter whitening"
            if sorter_params.get("whiten", False):
                recording = sip.whiten(recording, dtype=np.float64)
                sorter_params["whiten"] = False
            sorting = sis.run_sorter(
                sorter,
                recording,
                output_folder=sorter_temp_dir.name,
                remove_existing_folder=True,
                **sorter_params,
            )
        key["time_of_sort"] = int(time.time())
        sorting = sic.remove_excess_spikes(sorting, recording)
        key["analysis_file_name"], key["object_id"] = _write_sorting_to_nwb(
            sorting,
            timestamps,
            artifact_removed_intervals,
            (SpikeSortingSelection & key).fetch1("nwb_file_name"),
        )

        # INSERT
        # - new entry to AnalysisNwbfile
        # - new entry to SpikeSorting
        AnalysisNwbfile().add(
            (SpikeSortingSelection & key).fetch1("nwb_file_name"),
            key["analysis_file_name"],
        )
        self.insert1(key, skip_duplicates=True)

    @classmethod
    def get_sorting(cls, key: dict) -> si.BaseSorting:
        """Get sorting in the analysis NWB file as spikeinterface BaseSorting

        Parameters
        ----------
        key : dict
            primary key of SpikeSorting

        Returns
        -------
        sorting : si.BaseSorting

        """

        recording_id = (
            SpikeSortingRecording * SpikeSortingSelection & key
        ).fetch1("recording_id")
        recording = SpikeSortingRecording.get_recording(
            {"recording_id": recording_id}
        )
        sampling_frequency = recording.get_sampling_frequency()
        analysis_file_name = (cls & key).fetch1("analysis_file_name")
        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            analysis_file_name
        )
        with pynwb.NWBHDF5IO(
            analysis_file_abs_path, "r", load_namespaces=True
        ) as io:
            nwbf = io.read()
            units = nwbf.units.to_dataframe()
        units_dict_list = [
            {
                unit_id: np.searchsorted(recording.get_times(), spike_times)
                for unit_id, spike_times in zip(
                    units.index, units["spike_times"]
                )
            }
        ]

        sorting = si.NumpySorting.from_unit_dict(
            units_dict_list, sampling_frequency=sampling_frequency
        )

        return sorting

make(key)

Run spike sorting on the data and parameters specified by the SpikeSortingSelection table and insert a new entry into the SpikeSorting table.
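
Since SpikeSorting is a dj.Computed table, make() is normally triggered through DataJoint's populate mechanism rather than called directly. A minimal sketch, assuming sorting_key is the key returned by SpikeSortingSelection.insert_selection above:

SpikeSorting.populate(sorting_key)
# populate() calls make() for every SpikeSortingSelection entry matching
# sorting_key that does not yet have a corresponding SpikeSorting entry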

get_sorting(key) classmethod

Get the sorting in the analysis NWB file as a spikeinterface BaseSorting.

Parameters:

    key : dict (required)
        primary key of SpikeSorting

Returns:

    sorting : si.BaseSorting
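
A short usage sketch, assuming sorting_key identifies a populated SpikeSorting entry:

sorting = SpikeSorting.get_sorting(sorting_key)
print(sorting.get_unit_ids())
# spike trains are stored as sample indices into the recording's time vector
print(sorting.get_unit_spike_train(unit_id=sorting.get_unit_ids()[0]))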