sorted_spikes.py

Pipeline for decoding the animal's mental position and some category of interest from clustered spike times. See [1] for details.

References

[1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021).

SortedSpikesIndicatorSelection

Bases: SpyglassMixin, Lookup

Bins spike times into regular intervals given by the sampling rate.

Start and stop time of the interval are defined by the interval list.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
@schema
class SortedSpikesIndicatorSelection(SpyglassMixin, dj.Lookup):
    """Bins spike times into regular intervals given by the sampling rate.

    Start and stop time of the interval are defined by the interval list.
    """

    definition = """
    -> CuratedSpikeSorting
    -> IntervalList
    sampling_rate=500 : float
    ---
    """

SortedSpikesIndicator

Bases: SpyglassMixin, Computed

Bins spike times into regular intervals given by the sampling rate.

Useful for GLMs and for decoding.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
@schema
class SortedSpikesIndicator(SpyglassMixin, dj.Computed):
    """Bins spike times into regular intervals given by the sampling rate.

    Useful for GLMs and for decoding.
    """

    definition = """
    -> SortedSpikesIndicatorSelection
    ---
    -> AnalysisNwbfile
    spike_indicator_object_id: varchar(40)
    """

    def make(self, key):
        """Populate the SortedSpikesIndicator table.

        Fetches the spike times from the CuratedSpikeSorting table and bins
        them into regular intervals given by the sampling rate. The spike
        indicator is stored in an AnalysisNwbfile.
        """
        pprint.pprint(key)
        # TODO: intersection of sort interval and interval list
        interval_times = (IntervalList & key).fetch1("valid_times")

        sampling_rate = (SortedSpikesIndicatorSelection & key).fetch(
            "sampling_rate"
        )

        time = get_time_bins_from_interval(interval_times, sampling_rate)

        spikes_nwb = (CuratedSpikeSorting & key).fetch_nwb()
        # restrict to cases with units
        spikes_nwb = [entry for entry in spikes_nwb if "units" in entry]
        spike_times_list = [
            np.asarray(n_trode["units"]["spike_times"])
            for n_trode in spikes_nwb
        ]
        if len(spike_times_list) > 0:  # if units
            spikes = np.concatenate(spike_times_list)

            # Bin spikes into time bins
            spike_indicator = []
            for spike_times in spikes:
                spike_times = spike_times[
                    (spike_times > time[0]) & (spike_times <= time[-1])
                ]
                spike_indicator.append(
                    np.bincount(
                        np.digitize(spike_times, time[1:-1]),
                        minlength=time.shape[0],
                    )
                )

            column_names = np.concatenate(
                [
                    [
                        f'{n_trode["sort_group_id"]:04d}_{unit_number:04d}'
                        for unit_number in n_trode["units"].index
                    ]
                    for n_trode in spikes_nwb
                ]
            )
            spike_indicator = pd.DataFrame(
                np.stack(spike_indicator, axis=1),
                index=pd.Index(time, name="time"),
                columns=column_names,
            )

            # Insert into analysis nwb file
            nwb_analysis_file = AnalysisNwbfile()
            key["analysis_file_name"] = nwb_analysis_file.create(
                key["nwb_file_name"]
            )

            key["spike_indicator_object_id"] = nwb_analysis_file.add_nwb_object(
                analysis_file_name=key["analysis_file_name"],
                nwb_object=spike_indicator.reset_index(),
            )

            nwb_analysis_file.add(
                nwb_file_name=key["nwb_file_name"],
                analysis_file_name=key["analysis_file_name"],
            )

            self.insert1(key)

    def fetch1_dataframe(self) -> pd.DataFrame:
        """Return the first spike indicator as a dataframe."""
        return self.fetch_dataframe()[0]

    def fetch_dataframe(self) -> list[pd.DataFrame]:
        """Return all spike indicators as a list of dataframes."""
        return pd.concat(
            [
                data["spike_indicator"].set_index("time")
                for data in self.fetch_nwb()
            ],
            axis=1,
        )

make(key)

Populate the SortedSpikesIndicator table.

Fetches the spike times from the CuratedSpikeSorting table and bins them into regular intervals given by the sampling rate. The spike indicator is stored in an AnalysisNwbfile.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def make(self, key):
    """Populate the SortedSpikesIndicator table.

    Fetches the spike times from the CuratedSpikeSorting table and bins
    them into regular intervals given by the sampling rate. The spike
    indicator is stored in an AnalysisNwbfile.
    """
    pprint.pprint(key)
    # TODO: intersection of sort interval and interval list
    interval_times = (IntervalList & key).fetch1("valid_times")

    sampling_rate = (SortedSpikesIndicatorSelection & key).fetch(
        "sampling_rate"
    )

    time = get_time_bins_from_interval(interval_times, sampling_rate)

    spikes_nwb = (CuratedSpikeSorting & key).fetch_nwb()
    # restrict to cases with units
    spikes_nwb = [entry for entry in spikes_nwb if "units" in entry]
    spike_times_list = [
        np.asarray(n_trode["units"]["spike_times"])
        for n_trode in spikes_nwb
    ]
    if len(spike_times_list) > 0:  # if units
        spikes = np.concatenate(spike_times_list)

        # Bin spikes into time bins
        spike_indicator = []
        for spike_times in spikes:
            spike_times = spike_times[
                (spike_times > time[0]) & (spike_times <= time[-1])
            ]
            spike_indicator.append(
                np.bincount(
                    np.digitize(spike_times, time[1:-1]),
                    minlength=time.shape[0],
                )
            )

        column_names = np.concatenate(
            [
                [
                    f'{n_trode["sort_group_id"]:04d}_{unit_number:04d}'
                    for unit_number in n_trode["units"].index
                ]
                for n_trode in spikes_nwb
            ]
        )
        spike_indicator = pd.DataFrame(
            np.stack(spike_indicator, axis=1),
            index=pd.Index(time, name="time"),
            columns=column_names,
        )

        # Insert into analysis nwb file
        nwb_analysis_file = AnalysisNwbfile()
        key["analysis_file_name"] = nwb_analysis_file.create(
            key["nwb_file_name"]
        )

        key["spike_indicator_object_id"] = nwb_analysis_file.add_nwb_object(
            analysis_file_name=key["analysis_file_name"],
            nwb_object=spike_indicator.reset_index(),
        )

        nwb_analysis_file.add(
            nwb_file_name=key["nwb_file_name"],
            analysis_file_name=key["analysis_file_name"],
        )

        self.insert1(key)
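
The binning above keeps spikes inside (time[0], time[-1]], assigns each spike to a bin with np.digitize against the interior edges, and counts spikes per bin with np.bincount. A self-contained NumPy sketch of that pattern on toy data (not a library call):

import numpy as np

# Toy example: a 1-second interval binned into 10 bins (11 time points)
time = np.linspace(0.0, 1.0, 11)
spike_times = np.array([0.05, 0.12, 0.13, 0.71])

# Restrict to (time[0], time[-1]], then count spikes per time bin
spike_times = spike_times[(spike_times > time[0]) & (spike_times <= time[-1])]
indicator = np.bincount(
    np.digitize(spike_times, time[1:-1]),  # bin index for each spike
    minlength=time.shape[0],               # one count per time point
)
# indicator -> array([1, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0])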

fetch1_dataframe()

Return the first spike indicator as a dataframe.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def fetch1_dataframe(self) -> pd.DataFrame:
    """Return the first spike indicator as a dataframe."""
    return self.fetch_dataframe()[0]

fetch_dataframe()

Return all spike indicators as a list of dataframes.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def fetch_dataframe(self) -> list[pd.DataFrame]:
    """Return all spike indicators as a list of dataframes."""
    return pd.concat(
        [
            data["spike_indicator"].set_index("time")
            for data in self.fetch_nwb()
        ],
        axis=1,
    )
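
A hedged sketch of the typical populate-and-fetch pattern, assuming a selection entry such as the placeholder selection_key shown earlier already exists:

# Compute binned spike indicators for the selected entries
SortedSpikesIndicator.populate(selection_key)

# Retrieve the result as a time-indexed DataFrame with one column per unit
spike_indicator = (SortedSpikesIndicator & selection_key).fetch_dataframe()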

SortedSpikesClassifierParameters

Bases: SpyglassMixin, Manual

Stores parameters for decoding with sorted spikes

Source code in src/spyglass/decoding/v0/sorted_spikes.py
@schema
class SortedSpikesClassifierParameters(SpyglassMixin, dj.Manual):
    """Stores parameters for decoding with sorted spikes"""

    definition = """
    classifier_param_name : varchar(80) # a name for this set of parameters
    ---
    classifier_params :   BLOB    # initialization parameters
    fit_params :          BLOB    # fit parameters
    predict_params :      BLOB    # prediction parameters
    """

    def insert_default(self):
        """Insert default parameters for decoding with sorted spikes"""
        self.insert(
            [
                make_default_decoding_params(),
                make_default_decoding_params(use_gpu=True),
            ],
            skip_duplicates=True,
        )

    def insert1(self, key, **kwargs):
        """Override insert1 to convert classes to dict"""
        super().insert1(convert_classes_to_dict(key), **kwargs)

    def fetch1(self, *args, **kwargs):
        """Override fetch1 to restore classes"""
        return restore_classes(super().fetch1(*args, **kwargs))

insert_default()

Insert default parameters for decoding with sorted spikes

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def insert_default(self):
    """Insert default parameters for decoding with sorted spikes"""
    self.insert(
        [
            make_default_decoding_params(),
            make_default_decoding_params(use_gpu=True),
        ],
        skip_duplicates=True,
    )

insert1(key, **kwargs)

Override insert1 to convert classes to dict

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def insert1(self, key, **kwargs):
    """Override insert1 to convert classes to dict"""
    super().insert1(convert_classes_to_dict(key), **kwargs)

fetch1(*args, **kwargs)

Override fetch1 to restore classes

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def fetch1(self, *args, **kwargs):
    """Override fetch1 to restore classes"""
    return restore_classes(super().fetch1(*args, **kwargs))
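
A sketch of the intended workflow for this parameter table: insert the default CPU and GPU parameter sets once, then fetch a named set back with its classes restored. The parameter name below is a placeholder; the actual names are assigned by make_default_decoding_params.

# Insert the default parameter sets (safe to repeat; duplicates are skipped)
SortedSpikesClassifierParameters().insert_default()

# Fetch one parameter set by name; the overridden fetch1 restores classes
params = (
    SortedSpikesClassifierParameters
    & {"classifier_param_name": "default_decoding"}  # hypothetical name
).fetch1()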

get_spike_indicator(key, time_range, sampling_rate=500.0)

Returns a dataframe with the spike indicator for each unit

Parameters:

    key : dict (required)
    time_range : tuple[float, float] (required)
        Start and end time of the spike indicator
    sampling_rate : float, optional (default: 500.0)

Returns:

    spike_indicator : pd.DataFrame, shape (n_time, n_units)
        A dataframe with the spike indicator for each unit

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def get_spike_indicator(
    key: dict, time_range: tuple[float, float], sampling_rate: float = 500.0
) -> pd.DataFrame:
    """Returns a dataframe with the spike indicator for each unit

    Parameters
    ----------
    key : dict
    time_range : tuple[float, float]
        Start and end time of the spike indicator
    sampling_rate : float, optional

    Returns
    -------
    spike_indicator : pd.DataFrame, shape (n_time, n_units)
        A dataframe with the spike indicator for each unit
    """
    start_time, end_time = time_range
    n_samples = int(np.ceil((end_time - start_time) * sampling_rate)) + 1
    time = np.linspace(start_time, end_time, n_samples)

    spike_indicator = dict()
    spikes_nwb_table = CuratedSpikeSorting() & key

    for n_trode in spikes_nwb_table.fetch_nwb():
        try:
            for unit_id, unit_spike_times in n_trode["units"][
                "spike_times"
            ].items():
                unit_spike_times = unit_spike_times[
                    (unit_spike_times > time[0])
                    & (unit_spike_times <= time[-1])
                ]
                unit_name = f'{n_trode["sort_group_id"]:04d}_{unit_id:04d}'
                spike_indicator[unit_name] = np.bincount(
                    np.digitize(unit_spike_times, time[1:-1]),
                    minlength=time.shape[0],
                )
        except KeyError:
            pass

    return pd.DataFrame(
        spike_indicator,
        index=pd.Index(time, name="time"),
    )
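
A minimal call sketch, assuming a key that restricts CuratedSpikeSorting to one session; the file name and time range are placeholders:

spikes = get_spike_indicator(
    key={"nwb_file_name": "example20230101_.nwb"},  # hypothetical restriction
    time_range=(1000.0, 1600.0),                    # start/end in seconds
    sampling_rate=500.0,
)
# spikes: shape (n_time, n_units); columns named "<sort_group>_<unit>"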

get_decoding_data_for_epoch(nwb_file_name, interval_list_name, position_info_param_name='default', additional_spike_keys={})

Collects the data needed for decoding

Parameters:

    nwb_file_name : str (required)
    interval_list_name : str (required)
    position_info_param_name : str, optional (default: 'default')
    additional_spike_keys : dict, optional (default: {})

Returns:

    position_info : pd.DataFrame, shape (n_time, n_position_features)
    spikes : pd.DataFrame, shape (n_time, n_units)
    valid_slices : list[slice]

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def get_decoding_data_for_epoch(
    nwb_file_name: str,
    interval_list_name: str,
    position_info_param_name: str = "default",
    additional_spike_keys: dict = {},
) -> tuple[pd.DataFrame, pd.DataFrame, list[slice]]:
    """Collects the data needed for decoding

    Parameters
    ----------
    nwb_file_name : str
    interval_list_name : str
    position_info_param_name : str, optional
    additional_spike_keys : dict, optional

    Returns
    -------
    position_info : pd.DataFrame, shape (n_time, n_position_features)
    spikes : pd.DataFrame, shape (n_time, n_units)
    valid_slices : list[slice]

    """

    valid_slices = convert_valid_times_to_slice(
        get_valid_ephys_position_times_by_epoch(nwb_file_name)[
            interval_list_name
        ]
    )

    # position interval
    nwb_dict = dict(nwb_file_name=nwb_file_name)
    pos_interval_dict = dict(
        nwb_dict,
        interval_list_name=convert_epoch_interval_name_to_position_interval_name(
            {
                **nwb_dict,
                "interval_list_name": interval_list_name,
            }
        ),
    )

    position_info = (
        IntervalPositionInfo()
        & {
            **pos_interval_dict,
            "position_info_param_name": position_info_param_name,
        }
    ).fetch1_dataframe()

    # spikes
    valid_times = np.asarray(
        [(times.start, times.stop) for times in valid_slices]
    )

    curated_spikes_key = {
        "nwb_file_name": nwb_file_name,
        **additional_spike_keys,
    }
    spikes = get_spike_indicator(
        curated_spikes_key,
        (valid_times.min(), valid_times.max()),
        sampling_rate=500,
    )
    spikes = pd.concat([spikes.loc[times] for times in valid_slices])

    new_time = spikes.index.to_numpy()
    new_index = pd.Index(
        np.unique(np.concatenate((position_info.index, new_time))),
        name="time",
    )
    position_info = (
        position_info.reindex(index=new_index)
        .interpolate(method="linear")
        .reindex(index=new_time)
    )

    return position_info, spikes, valid_slices
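
A usage sketch with placeholder session and epoch names; position_info is interpolated onto the spike time bins, so the two frames share an index:

position_info, spikes, valid_slices = get_decoding_data_for_epoch(
    nwb_file_name="example20230101_.nwb",  # hypothetical session file
    interval_list_name="02_r1",            # hypothetical run epoch
    position_info_param_name="default",
)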

get_data_for_multiple_epochs(nwb_file_name, epoch_names, position_info_param_name='decoding', additional_spike_keys={})

Collects the data needed for decoding for multiple epochs

Parameters:

    nwb_file_name : str (required)
    epoch_names : list (required)
    position_info_param_name : str, optional (default: 'decoding')
    additional_spike_keys : dict, optional (default: {})

Returns:

    position_info : pd.DataFrame, shape (n_time, n_position_features)
    spikes : pd.DataFrame, shape (n_time, n_units)
    valid_slices : list[slice]
    environment_labels : np.ndarray, shape (n_time,)
        The environment label for each time point
    sort_group_ids : np.ndarray, shape (n_units,)
        The sort group of each unit

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def get_data_for_multiple_epochs(
    nwb_file_name: str,
    epoch_names: list,
    position_info_param_name: str = "decoding",
    additional_spike_keys: dict = {},
) -> tuple[pd.DataFrame, pd.DataFrame, list[slice], np.ndarray, np.ndarray]:
    """Collects the data needed for decoding for multiple epochs

    Parameters
    ----------
    nwb_file_name : str
    epoch_names : list
    position_info_param_name : str, optional
    additional_spike_keys : dict, optional

    Returns
    -------
    position_info : pd.DataFrame, shape (n_time, n_position_features)
    spikes : pd.DataFrame, shape (n_time, n_units)
    valid_slices : list[slice]
    environment_labels : np.ndarray, shape (n_time,)
        The environment label for each time point
    sort_group_ids : np.ndarray, shape (n_units,)
        The sort group of each unit
    """
    data = []
    environment_labels = []

    for epoch in epoch_names:
        logger.info(epoch)
        data.append(
            get_decoding_data_for_epoch(
                nwb_file_name,
                epoch,
                position_info_param_name=position_info_param_name,
                additional_spike_keys=additional_spike_keys,
            )
        )
        n_time = data[-1][0].shape[0]
        environment_labels.append([epoch] * n_time)

    environment_labels = np.concatenate(environment_labels, axis=0)
    position_info, spikes, valid_slices = list(zip(*data))
    position_info = pd.concat(position_info, axis=0)
    spikes = pd.concat(spikes, axis=0)
    valid_slices = {
        epoch: valid_slice
        for epoch, valid_slice in zip(epoch_names, valid_slices)
    }

    assert position_info.shape[0] == spikes.shape[0]

    sort_group_ids = np.asarray(
        [int(col.split("_")[0]) for col in spikes.columns]
    )

    return (
        position_info,
        spikes,
        valid_slices,
        environment_labels,
        sort_group_ids,
    )
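
A usage sketch over two placeholder epochs; the returned environment_labels mark which epoch each time point came from, and sort_group_ids are parsed from the spike column names:

(
    position_info,
    spikes,
    valid_slices,
    environment_labels,
    sort_group_ids,
) = get_data_for_multiple_epochs(
    nwb_file_name="example20230101_.nwb",  # hypothetical session file
    epoch_names=["02_r1", "04_r2"],        # hypothetical run epochs
    position_info_param_name="decoding",
)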