sorting.py

SpikeSorterParameters

Bases: SpyglassMixin, Lookup

Source code in src/spyglass/spikesorting/v1/sorting.py
@schema
class SpikeSorterParameters(SpyglassMixin, dj.Lookup):
    definition = """
    # Spike sorting algorithm and associated parameters.
    sorter: varchar(200)
    sorter_param_name: varchar(200)
    ---
    sorter_params: blob
    """
    contents = [
        [
            "mountainsort4",
            "franklab_tetrode_hippocampus_30KHz",
            {
                "detect_sign": -1,
                "adjacency_radius": 100,
                "freq_min": 600,
                "freq_max": 6000,
                "filter": False,
                "whiten": True,
                "num_workers": 1,
                "clip_size": 40,
                "detect_threshold": 3,
                "detect_interval": 10,
            },
        ],
        [
            "mountainsort4",
            "franklab_probe_ctx_30KHz",
            {
                "detect_sign": -1,
                "adjacency_radius": 100,
                "freq_min": 300,
                "freq_max": 6000,
                "filter": False,
                "whiten": True,
                "num_workers": 1,
                "clip_size": 40,
                "detect_threshold": 3,
                "detect_interval": 10,
            },
        ],
        [
            "clusterless_thresholder",
            "default_clusterless",
            {
                "detect_threshold": 100.0,  # uV
                # Locally exclusive means one unit per spike detected
                "method": "locally_exclusive",
                "peak_sign": "neg",
                "exclude_sweep_ms": 0.1,
                "local_radius_um": 100,
                # noise levels need to be 1.0 so that units are in uV rather than MAD
                "noise_levels": np.asarray([1.0]),
                "random_chunk_kwargs": {},
                # output needs to be set to sorting for the rest of the pipeline
                "outputs": "sorting",
            },
        ],
    ]
    contents.extend(
        [
            [sorter, "default", sis.get_default_sorter_params(sorter)]
            for sorter in sis.available_sorters()
        ]
    )

    @classmethod
    def insert_default(cls):
        """Insert default sorter parameters into SpikeSorterParameters table."""
        cls.insert(cls.contents, skip_duplicates=True)

insert_default() classmethod

Insert default sorter parameters into SpikeSorterParameters table.

Source code in src/spyglass/spikesorting/v1/sorting.py
@classmethod
def insert_default(cls):
    """Insert default sorter parameters into SpikeSorterParameters table."""
    cls.insert(cls.contents, skip_duplicates=True)
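
A minimal usage sketch, assuming a configured Spyglass database connection; the custom parameter-set name and the tweaked threshold value below are hypothetical, not defaults shipped with the table:

from spyglass.spikesorting.v1.sorting import SpikeSorterParameters

# Populate the table with the default parameter sets defined in `contents`
SpikeSorterParameters.insert_default()

# Fetch one of the default parameter sets and derive a custom variant
params = (
    SpikeSorterParameters
    & {
        "sorter": "mountainsort4",
        "sorter_param_name": "franklab_tetrode_hippocampus_30KHz",
    }
).fetch1("sorter_params")
params["detect_threshold"] = 4  # illustrative tweak
SpikeSorterParameters.insert1(
    ["mountainsort4", "my_custom_params", params],  # hypothetical name
    skip_duplicates=True,
)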

SpikeSortingSelection

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v1/sorting.py
@schema
class SpikeSortingSelection(SpyglassMixin, dj.Manual):
    definition = """
    # Processed recording and spike sorting parameters. See `insert_selection`.
    sorting_id: uuid
    ---
    -> SpikeSortingRecording
    -> SpikeSorterParameters
    -> IntervalList
    """

    @classmethod
    def insert_selection(cls, key: dict):
        """Insert a row into SpikeSortingSelection with an
        automatically generated unique sorting ID as the sole primary key.

        Parameters
        ----------
        key : dict
            primary key of the SpikeSortingRecording, SpikeSorterParameters, and IntervalList tables

        Returns
        -------
        sorting_id : uuid
            the unique sorting ID serving as primary key for SpikeSorting
        """
        query = cls & key
        if query:
            logger.info("Similar row(s) already inserted.")
            return query.fetch(as_dict=True)
        key["sorting_id"] = uuid.uuid4()
        cls.insert1(key, skip_duplicates=True)
        return key

insert_selection(key) classmethod

Insert a row into SpikeSortingSelection with an automatically generated unique sorting ID as the sole primary key.

Parameters:

    key : dict (required)
        primary key of the SpikeSortingRecording, SpikeSorterParameters,
        and IntervalList tables

Returns:

    sorting_id : uuid
        the unique sorting ID serving as primary key for SpikeSorting

Source code in src/spyglass/spikesorting/v1/sorting.py
@classmethod
def insert_selection(cls, key: dict):
    """Insert a row into SpikeSortingSelection with an
    automatically generated unique sorting ID as the sole primary key.

    Parameters
    ----------
    key : dict
        primary key of the SpikeSortingRecording, SpikeSorterParameters, and IntervalList tables

    Returns
    -------
    sorting_id : uuid
        the unique sorting ID serving as primary key for SpikeSorting
    """
    query = cls & key
    if query:
        logger.info("Similar row(s) already inserted.")
        return query.fetch(as_dict=True)
    key["sorting_id"] = uuid.uuid4()
    cls.insert1(key, skip_duplicates=True)
    return key
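
For illustration, a sketch of how a selection row might be inserted; every value below is a placeholder to be replaced with real entries from your database:

from spyglass.spikesorting.v1.sorting import SpikeSortingSelection

selection_key = {
    # primary key of SpikeSortingRecording (placeholder)
    "recording_id": "<recording-uuid>",
    # primary key of SpikeSorterParameters
    "sorter": "mountainsort4",
    "sorter_param_name": "franklab_tetrode_hippocampus_30KHz",
    # primary key of IntervalList (artifact-removed intervals)
    "nwb_file_name": "<nwb-file-name>",
    "interval_list_name": "<artifact-removed-interval-list>",
}
# Returns the key with a newly generated `sorting_id`, or the matching
# row(s) if an identical selection already exists.
selection_key = SpikeSortingSelection.insert_selection(selection_key)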

SpikeSorting

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v1/sorting.py
@schema
class SpikeSorting(SpyglassMixin, dj.Computed):
    definition = """
    -> SpikeSortingSelection
    ---
    -> AnalysisNwbfile
    object_id: varchar(40)          # Object ID for the sorting in NWB file
    time_of_sort: int               # in Unix time, to the nearest second
    """

    _use_transaction, _allow_insert = False, True

    def make(self, key: dict):
        """Runs spike sorting on the data and parameters specified by the
        SpikeSortingSelection table and inserts a new entry to SpikeSorting table.
        """
        # FETCH:
        # - information about the recording
        # - artifact free intervals
        # - spike sorter and sorter params
        AnalysisNwbfile()._creation_times["pre_create_time"] = time.time()

        recording_key = (
            SpikeSortingRecording * SpikeSortingSelection & key
        ).fetch1()
        artifact_removed_intervals = (
            IntervalList
            & {
                "nwb_file_name": (SpikeSortingSelection & key).fetch1(
                    "nwb_file_name"
                ),
                "interval_list_name": (SpikeSortingSelection & key).fetch1(
                    "interval_list_name"
                ),
            }
        ).fetch1("valid_times")
        sorter, sorter_params = (
            SpikeSorterParameters * SpikeSortingSelection & key
        ).fetch1("sorter", "sorter_params")
        recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(
            recording_key["analysis_file_name"]
        )

        # DO:
        # - load recording
        # - concatenate artifact removed intervals
        # - run spike sorting
        # - save output to NWB file
        recording = se.read_nwb_recording(
            recording_analysis_nwb_file_abs_path, load_time_vector=True
        )

        timestamps = recording.get_times()

        artifact_removed_intervals_ind = _consolidate_intervals(
            artifact_removed_intervals, timestamps
        )

        # if the artifact removed intervals do not span the entire time range
        if (
            (len(artifact_removed_intervals_ind) > 1)
            or (artifact_removed_intervals_ind[0][0] > 0)
            or (artifact_removed_intervals_ind[-1][1] < len(timestamps))
        ):
            # set the artifact intervals to zero
            list_triggers = []
            if artifact_removed_intervals_ind[0][0] > 0:
                list_triggers.append(
                    np.arange(0, artifact_removed_intervals_ind[0][0])
                )
            for interval_ind in range(len(artifact_removed_intervals_ind) - 1):
                list_triggers.append(
                    np.arange(
                        (artifact_removed_intervals_ind[interval_ind][1] + 1),
                        artifact_removed_intervals_ind[interval_ind + 1][0],
                    )
                )
            if artifact_removed_intervals_ind[-1][1] < len(timestamps):
                list_triggers.append(
                    np.arange(
                        artifact_removed_intervals_ind[-1][1],
                        len(timestamps) - 1,
                    )
                )

            list_triggers = [list(np.concatenate(list_triggers))]
            recording = sip.remove_artifacts(
                recording=recording,
                list_triggers=list_triggers,
                ms_before=None,
                ms_after=None,
                mode="zeros",
            )

        if sorter == "clusterless_thresholder":
            # remove tempdir, whiten, and outputs from sorter_params; they do not apply to peak detection
            sorter_params.pop("tempdir", None)
            sorter_params.pop("whiten", None)
            sorter_params.pop("outputs", None)
            if "local_radius_um" in sorter_params:
                sorter_params["radius_um"] = sorter_params.pop(
                    "local_radius_um"
                )  # correct existing parameter sets for spikeinterface>=0.99.1

            # Detect peaks for clusterless decoding
            detected_spikes = detect_peaks(recording, **sorter_params)
            sorting = si.NumpySorting.from_times_labels(
                times_list=detected_spikes["sample_index"],
                labels_list=np.zeros(len(detected_spikes), dtype=np.int32),
                sampling_frequency=recording.get_sampling_frequency(),
            )
        else:
            # Specify tempdir (expected by some sorters like mountainsort4)
            sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir)
            sorter_params["tempdir"] = sorter_temp_dir.name
            os.chmod(sorter_params["tempdir"], 0o777)

            if sorter == "mountainsort5":
                _ = sorter_params.pop("tempdir", None)

            # if whitening is specified in sorter params, apply whitening separately
            # prior to sorting and turn off "sorter whitening"
            if sorter_params.get("whiten", False):
                recording = sip.whiten(recording, dtype=np.float64)
                sorter_params["whiten"] = False

            common_sorter_items = {
                "sorter_name": sorter,
                "recording": recording,
                "output_folder": sorter_temp_dir.name,
                "remove_existing_folder": True,
            }

            if sorter.lower() in ["kilosort2_5", "kilosort3", "ironclust"]:
                sorter_params = {
                    k: v
                    for k, v in sorter_params.items()
                    if k
                    not in ["tempdir", "mp_context", "max_threads_per_process"]
                }
                sorting = sis.run_sorter(
                    **common_sorter_items,
                    singularity_image=True,
                    **sorter_params,
                )
            else:
                sorting = sis.run_sorter(
                    **common_sorter_items,
                    **sorter_params,
                )
        key["time_of_sort"] = int(time.time())
        sorting = sic.remove_excess_spikes(sorting, recording)
        key["analysis_file_name"], key["object_id"] = _write_sorting_to_nwb(
            sorting,
            timestamps,
            artifact_removed_intervals,
            (SpikeSortingSelection & key).fetch1("nwb_file_name"),
        )

        # INSERT
        # - new entry to AnalysisNwbfile
        # - new entry to SpikeSorting
        AnalysisNwbfile().add(
            (SpikeSortingSelection & key).fetch1("nwb_file_name"),
            key["analysis_file_name"],
        )
        AnalysisNwbfile().log(key, table=self.full_table_name)
        self.insert1(key, skip_duplicates=True)

    @classmethod
    def get_sorting(cls, key: dict) -> si.BaseSorting:
        """Get sorting in the analysis NWB file as spikeinterface BaseSorting

        Parameters
        ----------
        key : dict
            primary key of SpikeSorting

        Returns
        -------
        sorting : si.BaseSorting

        """

        recording_id = (
            SpikeSortingRecording * SpikeSortingSelection & key
        ).fetch1("recording_id")
        recording = SpikeSortingRecording.get_recording(
            {"recording_id": recording_id}
        )
        sampling_frequency = recording.get_sampling_frequency()
        analysis_file_name = (cls & key).fetch1("analysis_file_name")
        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
            analysis_file_name
        )
        with pynwb.NWBHDF5IO(
            analysis_file_abs_path, "r", load_namespaces=True
        ) as io:
            nwbf = io.read()
            units = nwbf.units.to_dataframe()
        units_dict_list = [
            {
                unit_id: np.searchsorted(recording.get_times(), spike_times)
                for unit_id, spike_times in zip(
                    units.index, units["spike_times"]
                )
            }
        ]

        sorting = si.NumpySorting.from_unit_dict(
            units_dict_list, sampling_frequency=sampling_frequency
        )

        return sorting

make(key)

Runs spike sorting on the data using the parameters specified by the SpikeSortingSelection table and inserts a new entry into the SpikeSorting table.

Source code in src/spyglass/spikesorting/v1/sorting.py
def make(self, key: dict):
    """Runs spike sorting on the data and parameters specified by the
    SpikeSortingSelection table and inserts a new entry to SpikeSorting table.
    """
    # FETCH:
    # - information about the recording
    # - artifact free intervals
    # - spike sorter and sorter params
    AnalysisNwbfile()._creation_times["pre_create_time"] = time.time()

    recording_key = (
        SpikeSortingRecording * SpikeSortingSelection & key
    ).fetch1()
    artifact_removed_intervals = (
        IntervalList
        & {
            "nwb_file_name": (SpikeSortingSelection & key).fetch1(
                "nwb_file_name"
            ),
            "interval_list_name": (SpikeSortingSelection & key).fetch1(
                "interval_list_name"
            ),
        }
    ).fetch1("valid_times")
    sorter, sorter_params = (
        SpikeSorterParameters * SpikeSortingSelection & key
    ).fetch1("sorter", "sorter_params")
    recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(
        recording_key["analysis_file_name"]
    )

    # DO:
    # - load recording
    # - concatenate artifact removed intervals
    # - run spike sorting
    # - save output to NWB file
    recording = se.read_nwb_recording(
        recording_analysis_nwb_file_abs_path, load_time_vector=True
    )

    timestamps = recording.get_times()

    artifact_removed_intervals_ind = _consolidate_intervals(
        artifact_removed_intervals, timestamps
    )

    # if the artifact removed intervals do not span the entire time range
    if (
        (len(artifact_removed_intervals_ind) > 1)
        or (artifact_removed_intervals_ind[0][0] > 0)
        or (artifact_removed_intervals_ind[-1][1] < len(timestamps))
    ):
        # set the artifact intervals to zero
        list_triggers = []
        if artifact_removed_intervals_ind[0][0] > 0:
            list_triggers.append(
                np.arange(0, artifact_removed_intervals_ind[0][0])
            )
        for interval_ind in range(len(artifact_removed_intervals_ind) - 1):
            list_triggers.append(
                np.arange(
                    (artifact_removed_intervals_ind[interval_ind][1] + 1),
                    artifact_removed_intervals_ind[interval_ind + 1][0],
                )
            )
        if artifact_removed_intervals_ind[-1][1] < len(timestamps):
            list_triggers.append(
                np.arange(
                    artifact_removed_intervals_ind[-1][1],
                    len(timestamps) - 1,
                )
            )

        list_triggers = [list(np.concatenate(list_triggers))]
        recording = sip.remove_artifacts(
            recording=recording,
            list_triggers=list_triggers,
            ms_before=None,
            ms_after=None,
            mode="zeros",
        )

    if sorter == "clusterless_thresholder":
        # remove tempdir, whiten, and outputs from sorter_params; they do not apply to peak detection
        sorter_params.pop("tempdir", None)
        sorter_params.pop("whiten", None)
        sorter_params.pop("outputs", None)
        if "local_radius_um" in sorter_params:
            sorter_params["radius_um"] = sorter_params.pop(
                "local_radius_um"
            )  # correct existing parameter sets for spikeinterface>=0.99.1

        # Detect peaks for clusterless decoding
        detected_spikes = detect_peaks(recording, **sorter_params)
        sorting = si.NumpySorting.from_times_labels(
            times_list=detected_spikes["sample_index"],
            labels_list=np.zeros(len(detected_spikes), dtype=np.int32),
            sampling_frequency=recording.get_sampling_frequency(),
        )
    else:
        # Specify tempdir (expected by some sorters like mountainsort4)
        sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir)
        sorter_params["tempdir"] = sorter_temp_dir.name
        os.chmod(sorter_params["tempdir"], 0o777)

        if sorter == "mountainsort5":
            _ = sorter_params.pop("tempdir", None)

        # if whitening is specified in sorter params, apply whitening separately
        # prior to sorting and turn off "sorter whitening"
        if sorter_params.get("whiten", False):
            recording = sip.whiten(recording, dtype=np.float64)
            sorter_params["whiten"] = False

        common_sorter_items = {
            "sorter_name": sorter,
            "recording": recording,
            "output_folder": sorter_temp_dir.name,
            "remove_existing_folder": True,
        }

        if sorter.lower() in ["kilosort2_5", "kilosort3", "ironclust"]:
            sorter_params = {
                k: v
                for k, v in sorter_params.items()
                if k
                not in ["tempdir", "mp_context", "max_threads_per_process"]
            }
            sorting = sis.run_sorter(
                **common_sorter_items,
                singularity_image=True,
                **sorter_params,
            )
        else:
            sorting = sis.run_sorter(
                **common_sorter_items,
                **sorter_params,
            )
    key["time_of_sort"] = int(time.time())
    sorting = sic.remove_excess_spikes(sorting, recording)
    key["analysis_file_name"], key["object_id"] = _write_sorting_to_nwb(
        sorting,
        timestamps,
        artifact_removed_intervals,
        (SpikeSortingSelection & key).fetch1("nwb_file_name"),
    )

    # INSERT
    # - new entry to AnalysisNwbfile
    # - new entry to SpikeSorting
    AnalysisNwbfile().add(
        (SpikeSortingSelection & key).fetch1("nwb_file_name"),
        key["analysis_file_name"],
    )
    AnalysisNwbfile().log(key, table=self.full_table_name)
    self.insert1(key, skip_duplicates=True)
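
`make` is not called directly; it runs through DataJoint's `populate`. A minimal sketch, assuming a selection was inserted as shown above:

from spyglass.spikesorting.v1.sorting import SpikeSorting

# Run sorting for one selection; `populate` invokes `make` for each
# matching SpikeSortingSelection entry that has not yet been sorted.
SpikeSorting.populate({"sorting_id": selection_key["sorting_id"]})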

get_sorting(key) classmethod

Get the sorting in the analysis NWB file as a spikeinterface BaseSorting

Parameters:

    key : dict (required)
        primary key of SpikeSorting

Returns:

    sorting : si.BaseSorting

Source code in src/spyglass/spikesorting/v1/sorting.py
@classmethod
def get_sorting(cls, key: dict) -> si.BaseSorting:
    """Get sorting in the analysis NWB file as spikeinterface BaseSorting

    Parameters
    ----------
    key : dict
        primary key of SpikeSorting

    Returns
    -------
    sorting : si.BaseSorting

    """

    recording_id = (
        SpikeSortingRecording * SpikeSortingSelection & key
    ).fetch1("recording_id")
    recording = SpikeSortingRecording.get_recording(
        {"recording_id": recording_id}
    )
    sampling_frequency = recording.get_sampling_frequency()
    analysis_file_name = (cls & key).fetch1("analysis_file_name")
    analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
        analysis_file_name
    )
    with pynwb.NWBHDF5IO(
        analysis_file_abs_path, "r", load_namespaces=True
    ) as io:
        nwbf = io.read()
        units = nwbf.units.to_dataframe()
    units_dict_list = [
        {
            unit_id: np.searchsorted(recording.get_times(), spike_times)
            for unit_id, spike_times in zip(
                units.index, units["spike_times"]
            )
        }
    ]

    sorting = si.NumpySorting.from_unit_dict(
        units_dict_list, sampling_frequency=sampling_frequency
    )

    return sorting
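
A short usage sketch, assuming `selection_key` refers to a completed SpikeSorting entry:

from spyglass.spikesorting.v1.sorting import SpikeSorting

sorting = SpikeSorting.get_sorting(
    {"sorting_id": selection_key["sorting_id"]}
)
print(sorting.get_unit_ids())  # unit IDs recovered from the NWB units table
# spike frames (sample indices) for the first unit
print(sorting.get_unit_spike_train(sorting.get_unit_ids()[0]))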