Skip to content


Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/analysis/v1/
# Manual DataJoint table of named unit-filtering presets. Each row is keyed by
# unit_filter_params_name and stores optional include/exclude label lists;
# SortedSpikesGroup.fetch_spike_data reads these two fields to filter units.
# insert_default writes the rows listed in `contents` (duplicates skipped).
# NOTE(review): this rendered listing is incomplete -- the `definition`
# string opened below is never closed, the `contents` rows are truncated, and
# insert_default is documented as a classmethod but no decorator is shown.
class UnitSelectionParams(SpyglassMixin, dj.Manual):
    definition = """
    unit_filter_params_name: varchar(32)
    include_labels = Null: longblob
    exclude_labels = Null: longblob
    # NOTE: pk reduced from 128 to 32 to avoid long primary key error
    contents = [
            ["noise", "mua"],
            ["noise", "mua"],

    def insert_default(cls):
        """Insert default unit selection parameters"""
        cls.insert(cls.contents, skip_duplicates=True)

insert_default() classmethod

Insert default unit selection parameters

Source code in src/spyglass/spikesorting/analysis/v1/
def insert_default(cls):
    """Insert the built-in unit selection parameter rows.

    Every row in ``cls.contents`` is handed to ``cls.insert`` with
    ``skip_duplicates=True``, so rows that already exist are not re-inserted.
    """
    default_rows = cls.contents
    cls.insert(default_rows, skip_duplicates=True)


Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/analysis/v1/
# Manual DataJoint table that names a group of curated spike-sorting outputs
# for one Session, together with the UnitSelectionParams preset used to filter
# their units. Member sorts live in the Units part table, which references
# SpikeSortingOutput with merge_id renamed to spikesorting_merge_id.
# Methods: create_group (insert a group and its Units rows), filter_units
# (boolean keep-mask from unit labels), fetch_spike_data (per-unit spike
# times, optionally time-restricted), get_spike_indicator (per-unit spike
# counts binned on a time vector) and get_firing_rate (rate computed from the
# spike indicator; smoothing_sigma documents a gaussian smoothing width).
# NOTE(review): this rendered listing is missing lines (decorators, closing
# quotes/brackets, `self`/`cls` parameters and some statement bodies), so it
# is not valid Python as shown; consult the original module before editing.
class SortedSpikesGroup(SpyglassMixin, dj.Manual):
    definition = """
    -> Session
    -> UnitSelectionParams
    sorted_spikes_group_name: varchar(80)

    class Units(SpyglassMixinPart):
        definition = """
        -> master
        -> SpikeSortingOutput.proj(spikesorting_merge_id='merge_id')

    def create_group(
        group_name: str,
        nwb_file_name: str,
        unit_filter_params_name: str = "all_units",
        keys: list[dict] = [],
        """Create a new group of sorted spikes"""
        group_key = {
            "sorted_spikes_group_name": group_name,
            "nwb_file_name": nwb_file_name,
            "unit_filter_params_name": unit_filter_params_name,
        if self & group_key:
            if test_mode:
            raise ValueError(
                f"Group {nwb_file_name}: {group_name} already exists",
                "please delete the group before creating a new one",

        parts_insert = [{**key, **group_key} for key in keys]

        self.Units.insert(parts_insert, skip_duplicates=True)

    def filter_units(
        labels: list[list[str]],
        include_labels: list[str],
        exclude_labels: list[str],
    ) -> np.ndarray:
        Filter units based on labels

        labels: list of list of strings
            list of labels for each unit
        include_labels: list of strings
            if provided, only units with any of these labels will be included
        exclude_labels: list of strings
            if provided, units with any of these labels will be excluded
        include_labels = np.unique(include_labels)
        exclude_labels = np.unique(exclude_labels)

        if include_labels.size == 0 and exclude_labels.size == 0:
            # if no labels are provided, include all units
            return np.ones(len(labels), dtype=bool)

        include_mask = np.zeros(len(labels), dtype=bool)
        for ind, unit_labels in enumerate(labels):
            if isinstance(unit_labels, str):
                unit_labels = [unit_labels]
            if (
                include_labels.size > 0
                and np.all(~np.isin(unit_labels, include_labels))
            ) or np.any(np.isin(unit_labels, exclude_labels)):
                # if the unit does not have any of the include labels
                # or has any of the exclude labels, skip
            include_mask[ind] = True
        return include_mask

    def fetch_spike_data(
        key: dict, time_slice: list[float] = None, return_unit_ids: bool = False
    ) -> Union[list[np.ndarray], Optional[list[dict]]]:
        """fetch spike times for units in the group

        key : dict
            dictionary containing the group key
        time_slice : list of float, optional
            if provided, filter for spikes occurring in the interval [start, stop], by default None
        return_unit_ids : bool, optional
            if True, return the unit_ids along with the spike times, by default False
            Unit ids defined as a list of dictionaries with keys 'spikesorting_merge_id' and 'unit_number'

        list of np.ndarray
            list of spike times for each unit in the group
        # get merge_ids for SpikeSortingOutput
        merge_ids = (
                & {
                    "nwb_file_name": key["nwb_file_name"],
                    "sorted_spikes_group_name": key["sorted_spikes_group_name"],

        # get the filtering parameters
        include_labels, exclude_labels = (UnitSelectionParams & key).fetch1(
            "include_labels", "exclude_labels"

        # get the spike times for each merge_id
        spike_times = []
        unit_ids = []
        merge_keys = [dict(merge_id=merge_id) for merge_id in merge_ids]
        nwb_file_list, merge_ids = (SpikeSortingOutput & merge_keys).fetch_nwb(
        for nwb_file, merge_id in zip(nwb_file_list, merge_ids):
            nwb_field_name = _get_spike_obj_name(nwb_file, allow_empty=True)
            if nwb_field_name is None:
                # case where no units found or curation removed all units
            sorting_spike_times = nwb_file[nwb_field_name][
            file_unit_ids = [
                {"spikesorting_merge_id": merge_id, "unit_id": unit_id}
                for unit_id in range(len(sorting_spike_times))

            # filter the spike times based on the labels if present
            if "label" in nwb_file[nwb_field_name]:
                group_label_list = nwb_file[nwb_field_name]["label"].to_list()
                include_unit = SortedSpikesGroup.filter_units(
                    group_label_list, include_labels, exclude_labels

                sorting_spike_times = list(
                    compress(sorting_spike_times, include_unit)
                file_unit_ids = list(compress(file_unit_ids, include_unit))

            # filter the spike times based on the time slice if provided
            if time_slice is not None:
                sorting_spike_times = [
                            times >= time_slice.start, times <= time_slice.stop
                    for times in sorting_spike_times

            # append the approved spike times to the list

        if return_unit_ids:
            return spike_times, unit_ids
        return spike_times

    def get_spike_indicator(cls, key: dict, time: np.ndarray) -> np.ndarray:
        """Get spike indicator matrix for the group

        key : dict
            key to identify the group
        time : np.ndarray
            time vector for which to calculate the spike indicator matrix

            spike indicator matrix with shape (len(time), n_units)
        time = np.asarray(time)
        min_time, max_time = time[[0, -1]]
        spike_times = cls.fetch_spike_data(key)
        spike_indicator = np.zeros((len(time), len(spike_times)))

        for ind, times in enumerate(spike_times):
            times = times[np.logical_and(times >= min_time, times <= max_time)]
            spike_indicator[:, ind] = np.bincount(
                np.digitize(times, time[1:-1]),

        if spike_indicator.ndim == 1:
            spike_indicator = spike_indicator[:, np.newaxis]

        return spike_indicator

    def get_firing_rate(
        key: dict,
        time: np.ndarray,
        multiunit: bool = False,
        smoothing_sigma: float = 0.015,
    ) -> np.ndarray:
        """Get time-dependent firing rate for units in the group

        key : dict
            key to identify the group
        time : np.ndarray
            time vector for which to calculate the firing rate
        multiunit : bool, optional
            if True, return the multiunit firing rate for units in the group,
            by default False
        smoothing_sigma : float, optional
            standard deviation of gaussian filter to smooth firing rates in
            seconds, by default 0.015

            time-dependent firing rate with shape (len(time), n_units)
        return firing_rate_from_spike_indicator(
            spike_indicator=cls.get_spike_indicator(key, time),

create_group(group_name, nwb_file_name, unit_filter_params_name='all_units', keys=[])

Create a new group of sorted spikes

Source code in src/spyglass/spikesorting/analysis/v1/
# Create a named group of sorted-spike outputs for one NWB session.
# group_key combines the group name, the session's nwb_file_name and the
# UnitSelectionParams preset; each dict in `keys` is merged with group_key
# (group_key wins on conflicts) and inserted into the Units part table with
# skip_duplicates=True. Raises ValueError when a row matching group_key
# already exists (what happens when test_mode is set is not shown here).
# NOTE(review): this listing is incomplete -- `self` is used but absent from
# the signature, brackets are unclosed, the `if test_mode:` branch has no
# body, and no insert of group_key into the master table is shown.
# NOTE(review): the two message strings are passed to ValueError as separate
# arguments rather than concatenated, so the error renders as a tuple.
# NOTE(review): `keys=[]` is a mutable default; it is only read here.
def create_group(
    group_name: str,
    nwb_file_name: str,
    unit_filter_params_name: str = "all_units",
    keys: list[dict] = [],
    """Create a new group of sorted spikes"""
    group_key = {
        "sorted_spikes_group_name": group_name,
        "nwb_file_name": nwb_file_name,
        "unit_filter_params_name": unit_filter_params_name,
    if self & group_key:
        if test_mode:
        raise ValueError(
            f"Group {nwb_file_name}: {group_name} already exists",
            "please delete the group before creating a new one",

    parts_insert = [{**key, **group_key} for key in keys]

    self.Units.insert(parts_insert, skip_duplicates=True)

filter_units(labels, include_labels, exclude_labels) staticmethod

Filter units based on labels

labels: list of list of strings — list of labels for each unit.
include_labels: list of strings — if provided, only units with any of these labels will be included.
exclude_labels: list of strings — if provided, units with any of these labels will be excluded.

Source code in src/spyglass/spikesorting/analysis/v1/
def filter_units(
    labels: list[list[str]],
    include_labels: list[str],
    exclude_labels: list[str],
) -> np.ndarray:
    """Filter units based on labels.

    Parameters
    ----------
    labels : list of list of str
        Labels for each unit. A bare string is treated as a single label.
    include_labels : list of str or None
        If non-empty, only units carrying at least one of these labels are
        kept. ``None`` (e.g. a Null database value) is treated as empty.
    exclude_labels : list of str or None
        If non-empty, units carrying any of these labels are dropped.
        ``None`` is treated as empty.

    Returns
    -------
    np.ndarray
        Boolean mask of length ``len(labels)``; True where the unit is kept.
    """
    # np.unique(None) returns array([None]) (size 1), which would make every
    # unit fail the include test; treat None as "no labels" instead.
    include_labels = np.unique(
        [] if include_labels is None else include_labels
    )
    exclude_labels = np.unique(
        [] if exclude_labels is None else exclude_labels
    )

    if include_labels.size == 0 and exclude_labels.size == 0:
        # if no labels are provided, include all units
        return np.ones(len(labels), dtype=bool)

    include_mask = np.zeros(len(labels), dtype=bool)
    for ind, unit_labels in enumerate(labels):
        if isinstance(unit_labels, str):
            unit_labels = [unit_labels]
        if (
            include_labels.size > 0
            and np.all(~np.isin(unit_labels, include_labels))
        ) or np.any(np.isin(unit_labels, exclude_labels)):
            # if the unit does not have any of the include labels
            # or has any of the exclude labels, skip (mask stays False)
            continue
        include_mask[ind] = True
    return include_mask

fetch_spike_data(key, time_slice=None, return_unit_ids=False) staticmethod

fetch spike times for units in the group


Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `dict` | dictionary containing the group key | required |
| `time_slice` | `list of float` (as annotated); the code reads `time_slice.start` and `time_slice.stop`, so pass a `slice`-like object | if provided, keep only spikes in the closed interval [start, stop] | `None` |
| `return_unit_ids` | `bool` | if True, also return the unit ids: a list of dictionaries with keys 'spikesorting_merge_id' and 'unit_id' | `False` |



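Returns:

| Type | Description |
| --- | --- |
| `list of np.ndarray` | list of spike times for each unit in the group; if `return_unit_ids` is True, the tuple `(spike_times, unit_ids)` is returned instead |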

Source code in src/spyglass/spikesorting/analysis/v1/
# Return spike times for every unit in a SortedSpikesGroup.
# Steps: look up the group's merge ids (the queried table is not shown in this
# listing), read include/exclude labels from UnitSelectionParams, load each
# merge id's NWB units object via SpikeSortingOutput.fetch_nwb, drop units
# with filter_units when a "label" column exists, optionally keep only spikes
# in [time_slice.start, time_slice.stop], and optionally also return the
# per-unit id dicts.
# NOTE(review): time_slice is accessed via .start/.stop, so it must be a
# slice-like object rather than the "list of float" the docstring claims.
# NOTE(review): unit id dicts use the key "unit_id" (the unit's position in
# the sorting before label filtering), not "unit_number" as documented.
# NOTE(review): this listing is incomplete -- decorator, closing brackets,
# the body of `if nwb_field_name is None:` and the appends to
# spike_times/unit_ids are not shown.
def fetch_spike_data(
    key: dict, time_slice: list[float] = None, return_unit_ids: bool = False
) -> Union[list[np.ndarray], Optional[list[dict]]]:
    """fetch spike times for units in the group

    key : dict
        dictionary containing the group key
    time_slice : list of float, optional
        if provided, filter for spikes occurring in the interval [start, stop], by default None
    return_unit_ids : bool, optional
        if True, return the unit_ids along with the spike times, by default False
        Unit ids defined as a list of dictionaries with keys 'spikesorting_merge_id' and 'unit_number'

    list of np.ndarray
        list of spike times for each unit in the group
    # get merge_ids for SpikeSortingOutput
    merge_ids = (
            & {
                "nwb_file_name": key["nwb_file_name"],
                "sorted_spikes_group_name": key["sorted_spikes_group_name"],

    # get the filtering parameters
    include_labels, exclude_labels = (UnitSelectionParams & key).fetch1(
        "include_labels", "exclude_labels"

    # get the spike times for each merge_id
    spike_times = []
    unit_ids = []
    merge_keys = [dict(merge_id=merge_id) for merge_id in merge_ids]
    nwb_file_list, merge_ids = (SpikeSortingOutput & merge_keys).fetch_nwb(
    for nwb_file, merge_id in zip(nwb_file_list, merge_ids):
        nwb_field_name = _get_spike_obj_name(nwb_file, allow_empty=True)
        if nwb_field_name is None:
            # case where no units found or curation removed all units
        sorting_spike_times = nwb_file[nwb_field_name][
        file_unit_ids = [
            {"spikesorting_merge_id": merge_id, "unit_id": unit_id}
            for unit_id in range(len(sorting_spike_times))

        # filter the spike times based on the labels if present
        if "label" in nwb_file[nwb_field_name]:
            group_label_list = nwb_file[nwb_field_name]["label"].to_list()
            include_unit = SortedSpikesGroup.filter_units(
                group_label_list, include_labels, exclude_labels

            sorting_spike_times = list(
                compress(sorting_spike_times, include_unit)
            file_unit_ids = list(compress(file_unit_ids, include_unit))

        # filter the spike times based on the time slice if provided
        if time_slice is not None:
            sorting_spike_times = [
                        times >= time_slice.start, times <= time_slice.stop
                for times in sorting_spike_times

        # append the approved spike times to the list

    if return_unit_ids:
        return spike_times, unit_ids
    return spike_times

get_spike_indicator(key, time) classmethod

Get spike indicator matrix for the group


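Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `dict` | key to identify the group | required |
| `time` | `ndarray` | time vector for which to calculate the spike indicator matrix | required |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | spike indicator matrix with shape (len(time), n_units) |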

Source code in src/spyglass/spikesorting/analysis/v1/
def get_spike_indicator(cls, key: dict, time: np.ndarray) -> np.ndarray:
    """Get spike indicator matrix for the group

    Parameters
    ----------
    key : dict
        key to identify the group
    time : np.ndarray
        time vector for which to calculate the spike indicator matrix;
        assumed monotonically increasing (required by np.digitize)

    Returns
    -------
    np.ndarray
        spike indicator matrix with shape (len(time), n_units); entry
        [i, j] counts the spikes of unit j assigned to time bin i
    """
    time = np.asarray(time)
    spike_times = cls.fetch_spike_data(key)
    spike_indicator = np.zeros((len(time), len(spike_times)))
    if time.size == 0:
        # no time bins: nothing to count (time[[0, -1]] would raise)
        return spike_indicator
    min_time, max_time = time[[0, -1]]

    for ind, times in enumerate(spike_times):
        times = np.asarray(times)
        times = times[np.logical_and(times >= min_time, times <= max_time)]
        # Digitizing against the interior points maps each spike to the
        # index of the bin it falls in. Those indices never exceed
        # len(time) - 2, so minlength pads the counts to a full column.
        spike_indicator[:, ind] = np.bincount(
            np.digitize(times, time[1:-1]),
            minlength=len(time),
        )

    return spike_indicator

get_firing_rate(key, time, multiunit=False, smoothing_sigma=0.015) classmethod

Get time-dependent firing rate for units in the group


Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key` | `dict` | key to identify the group | required |
| `time` | `ndarray` | time vector for which to calculate the firing rate | required |
| `multiunit` | `bool` | if True, return the multiunit firing rate for units in the group | `False` |
| `smoothing_sigma` | `float` | standard deviation, in seconds, of the gaussian filter used to smooth firing rates | `0.015` |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | time-dependent firing rate with shape (len(time), n_units) |

Source code in src/spyglass/spikesorting/analysis/v1/
# Time-dependent firing rate for the group's units: builds the spike
# indicator matrix on `time` via cls.get_spike_indicator and passes it to
# firing_rate_from_spike_indicator (the remaining call arguments are cut off
# in this listing). `cls` is used but neither a `cls` parameter nor a
# classmethod decorator is shown here.
def get_firing_rate(
    key: dict,
    time: np.ndarray,
    multiunit: bool = False,
    smoothing_sigma: float = 0.015,
) -> np.ndarray:
    """Get time-dependent firing rate for units in the group

    key : dict
        key to identify the group
    time : np.ndarray
        time vector for which to calculate the firing rate
    multiunit : bool, optional
        if True, return the multiunit firing rate for units in the group,
        by default False
    smoothing_sigma : float, optional
        standard deviation of gaussian filter to smooth firing rates in
        seconds, by default 0.015

        time-dependent firing rate with shape (len(time), n_units)
    return firing_rate_from_spike_indicator(
        spike_indicator=cls.get_spike_indicator(key, time),