Skip to content

merge.py

SpikeSortingOutput

Bases: _Merge, SpyglassMixin

Source code in src/spyglass/spikesorting/merge.py
@schema
class SpikeSortingOutput(_Merge, SpyglassMixin):
    """Merge table that gathers the outputs of the spike sorting pipelines.

    Every part table links one ``merge_id`` of this master table to a single
    entry of a source table (``CurationV1``, ``ImportedSpikeSorting`` or
    ``CuratedSpikeSorting``). The class methods resolve a merge key to its
    source entry and expose recordings, sortings, spike times and derived
    firing-rate arrays.
    """

    definition = """
    # Output of spike sorting pipelines.
    merge_id: uuid
    ---
    source: varchar(32)
    """

    # Part table linking a merge entry to a CurationV1 entry.
    class CurationV1(SpyglassMixin, dj.Part):  # noqa: F811
        definition = """
        -> master
        ---
        -> CurationV1
        """

    # Part table linking a merge entry to an ImportedSpikeSorting entry.
    class ImportedSpikeSorting(SpyglassMixin, dj.Part):  # noqa: F811
        definition = """
        -> master
        ---
        -> ImportedSpikeSorting
        """

    # Part table linking a merge entry to a CuratedSpikeSorting entry.
    class CuratedSpikeSorting(SpyglassMixin, dj.Part):  # noqa: F811
        definition = """
        -> master
        ---
        -> CuratedSpikeSorting
        """

    @classmethod
    def get_recording(cls, key):
        """Get the recording associated with a spike sorting output.

        The source table class is looked up in ``source_class_dict`` under the
        camel-cased table name of the parent of ``key``, restricted by the
        matching part-table entry, and its own ``get_recording`` is called
        with the fetched primary key(s).

        Parameters
        ----------
        key
            Restriction identifying the merge entry.

        Returns
        -------
        Whatever the source table's ``get_recording`` returns.
        """
        parent_name = cls.merge_get_parent(key).table_name
        source_table = source_class_dict[to_camel_case(parent_name)]
        query = source_table & cls.merge_get_part(key)
        return query.get_recording(query.fetch("KEY"))

    @classmethod
    def get_sorting(cls, key):
        """Get the sorting associated with a spike sorting output.

        The source table class is looked up in ``source_class_dict`` under the
        camel-cased table name of the parent of ``key``, restricted by the
        matching part-table entry, and its own ``get_sorting`` is called with
        the fetched primary key(s).

        Parameters
        ----------
        key
            Restriction identifying the merge entry.

        Returns
        -------
        Whatever the source table's ``get_sorting`` returns.
        """
        parent_name = cls.merge_get_parent(key).table_name
        source_table = source_class_dict[to_camel_case(parent_name)]
        query = source_table & cls.merge_get_part(key)
        return query.get_sorting(query.fetch("KEY"))

    @classmethod
    def get_spike_times(cls, key):
        """Collect the ``spike_times`` entries of every NWB object for ``key``.

        Parameters
        ----------
        key
            Restriction passed to ``fetch_nwb``.

        Returns
        -------
        list
            The concatenation of the ``spike_times`` column (converted with
            ``to_list()``) of each fetched NWB entry; presumably one array of
            spike times per unit.
        """
        spike_times = []
        for nwb_file in cls.fetch_nwb(key):
            # V1 uses 'object_id', V0 uses 'units'
            file_loc = "object_id" if "object_id" in nwb_file else "units"
            spike_times.extend(nwb_file[file_loc]["spike_times"].to_list())
        return spike_times

    @classmethod
    def get_spike_indicator(cls, key, time):
        """Count spikes per time bin for every unit of ``key``.

        Parameters
        ----------
        key
            Restriction passed to ``get_spike_times``.
        time
            1-D array-like of bin times; its first and last elements bound
            the spikes that are counted.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(len(time), n_units)`` holding the number of
            spikes of each unit that fall into each time bin.
        """
        time = np.asarray(time)
        min_time, max_time = time[[0, -1]]
        spike_times = cls.get_spike_times(key)
        spike_indicator = np.zeros((len(time), len(spike_times)))

        for ind, times in enumerate(spike_times):
            times = np.asarray(times)
            # Mask THIS unit's spikes; masking with the full `spike_times`
            # list compares every unit at once and breaks the indexing.
            in_range = np.logical_and(times >= min_time, times <= max_time)
            spike_indicator[:, ind] = np.bincount(
                np.digitize(times[in_range], time[1:-1]),
                minlength=time.shape[0],
            )

        return spike_indicator

    @classmethod
    def get_firing_rate(cls, key, time, multiunit=False):
        """Compute firing rates from the spike indicator of ``key``.

        Parameters
        ----------
        key
            Restriction passed to ``get_spike_indicator``.
        time
            1-D array-like of bin times; the sampling frequency is taken as
            the inverse of the median spacing between consecutive entries.
        multiunit : bool, optional
            If True, sum the indicator across units first so a single
            population rate is computed. Defaults to False.

        Returns
        -------
        numpy.ndarray
            The results of ``get_multiunit_population_firing_rate`` for each
            indicator column, stacked along axis 1.
        """
        spike_indicator = cls.get_spike_indicator(key, time)
        # Defensive: promote a 1-D indicator to a single column.
        if spike_indicator.ndim == 1:
            spike_indicator = spike_indicator[:, np.newaxis]

        sampling_frequency = 1 / np.median(np.diff(time))

        if multiunit:
            spike_indicator = spike_indicator.sum(axis=1, keepdims=True)
        return np.stack(
            [
                get_multiunit_population_firing_rate(
                    indicator[:, np.newaxis], sampling_frequency
                )
                for indicator in spike_indicator.T
            ],
            axis=1,
        )

get_recording(key) classmethod

Get the recording associated with a spike sorting output.

Source code in src/spyglass/spikesorting/merge.py
@classmethod
def get_recording(cls, key):
    """Return the recording of the source entry behind a merge key.

    Looks up the source table class in ``source_class_dict`` under the
    camel-cased table name of the parent of ``key``, restricts it by the
    matching part-table entry, and delegates to that table's
    ``get_recording`` with the fetched primary key(s).
    """
    parent_name = cls.merge_get_parent(key).table_name
    source_table = source_class_dict[to_camel_case(parent_name)]
    restricted = source_table & cls.merge_get_part(key)
    return restricted.get_recording(restricted.fetch("KEY"))

get_sorting(key) classmethod

Get the sorting associated with a spike sorting output.

Source code in src/spyglass/spikesorting/merge.py
@classmethod
def get_sorting(cls, key):
    """Return the sorting of the source entry behind a merge key.

    Looks up the source table class in ``source_class_dict`` under the
    camel-cased table name of the parent of ``key``, restricts it by the
    matching part-table entry, and delegates to that table's
    ``get_sorting`` with the fetched primary key(s).
    """
    parent_name = cls.merge_get_parent(key).table_name
    source_table = source_class_dict[to_camel_case(parent_name)]
    restricted = source_table & cls.merge_get_part(key)
    return restricted.get_sorting(restricted.fetch("KEY"))