
sorted_spikes.py

Pipeline for decoding the animal's mental position and some category of interest from clustered spike times. See [1] for details.

References

[1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021).

SortedSpikesIndicatorSelection

Bases: SpyglassMixin, Lookup

Bins spike times into regular intervals given by the sampling rate.

Start and stop time of the interval are defined by the interval list.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
@schema
class SortedSpikesIndicatorSelection(SpyglassMixin, dj.Lookup):
    """Bins spike times into regular intervals given by the sampling rate.

    Start and stop time of the interval are defined by the interval list.
    """

    definition = """
    -> CuratedSpikeSorting
    -> IntervalList
    sampling_rate=500 : float
    ---
    """

SortedSpikesIndicator

Bases: SpyglassMixin, Computed

Bins spike times into regular intervals given by the sampling rate.

Useful for GLMs and for decoding.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
@schema
class SortedSpikesIndicator(SpyglassMixin, dj.Computed):
    """Bins spike times into regular intervals given by the sampling rate.

    Useful for GLMs and for decoding.
    """

    definition = """
    -> SortedSpikesIndicatorSelection
    ---
    -> AnalysisNwbfile
    spike_indicator_object_id: varchar(40)
    """

    def make(self, key):
        """Populate the SortedSpikesIndicator table.

        Fetches the spike times from the CuratedSpikeSorting table and bins
        them into regular intervals given by the sampling rate. The spike
        indicator is stored in an AnalysisNwbfile.
        """
        pprint.pprint(key)
        # TODO: intersection of sort interval and interval list
        interval_times = (IntervalList & key).fetch1("valid_times")

        sampling_rate = (SortedSpikesIndicatorSelection & key).fetch(
            "sampling_rate"
        )

        time = get_time_bins_from_interval(interval_times, sampling_rate)

        spikes_nwb = (CuratedSpikeSorting & key).fetch_nwb()
        # restrict to cases with units
        spikes_nwb = [entry for entry in spikes_nwb if "units" in entry]
        spike_times_list = [
            np.asarray(n_trode["units"]["spike_times"])
            for n_trode in spikes_nwb
        ]
        if len(spike_times_list) > 0:  # if units
            spikes = np.concatenate(spike_times_list)

            # Bin spikes into time bins
            spike_indicator = []
            for spike_times in spikes:
                spike_times = spike_times[
                    (spike_times > time[0]) & (spike_times <= time[-1])
                ]
                spike_indicator.append(
                    np.bincount(
                        np.digitize(spike_times, time[1:-1]),
                        minlength=time.shape[0],
                    )
                )

            column_names = np.concatenate(
                [
                    [
                        f'{n_trode["sort_group_id"]:04d}_{unit_number:04d}'
                        for unit_number in n_trode["units"].index
                    ]
                    for n_trode in spikes_nwb
                ]
            )
            spike_indicator = pd.DataFrame(
                np.stack(spike_indicator, axis=1),
                index=pd.Index(time, name="time"),
                columns=column_names,
            )

            # Insert into analysis nwb file
            nwb_analysis_file = AnalysisNwbfile()
            key["analysis_file_name"] = nwb_analysis_file.create(
                key["nwb_file_name"]
            )

            key["spike_indicator_object_id"] = nwb_analysis_file.add_nwb_object(
                analysis_file_name=key["analysis_file_name"],
                nwb_object=spike_indicator.reset_index(),
            )

            nwb_analysis_file.add(
                nwb_file_name=key["nwb_file_name"],
                analysis_file_name=key["analysis_file_name"],
            )

            self.insert1(key)

    def fetch1_dataframe(self) -> pd.DataFrame:
        """Return the first spike indicator as a dataframe."""
        return self.fetch_dataframe()[0]

    def fetch_dataframe(self) -> list[pd.DataFrame]:
        """Return all spike indicators as a list of dataframes."""
        return pd.concat(
            [
                data["spike_indicator"].set_index("time")
                for data in self.fetch_nwb()
            ],
            axis=1,
        )

make(key)

Populate the SortedSpikesIndicator table.

Fetches the spike times from the CuratedSpikeSorting table and bins them into regular intervals given by the sampling rate. The spike indicator is stored in an AnalysisNwbfile.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def make(self, key):
    """Populate the SortedSpikesIndicator table.

    Fetches the spike times from the CuratedSpikeSorting table and bins
    them into regular intervals given by the sampling rate. The spike
    indicator is stored in an AnalysisNwbfile.
    """
    pprint.pprint(key)
    # TODO: intersection of sort interval and interval list
    interval_times = (IntervalList & key).fetch1("valid_times")

    sampling_rate = (SortedSpikesIndicatorSelection & key).fetch(
        "sampling_rate"
    )

    time = get_time_bins_from_interval(interval_times, sampling_rate)

    spikes_nwb = (CuratedSpikeSorting & key).fetch_nwb()
    # restrict to cases with units
    spikes_nwb = [entry for entry in spikes_nwb if "units" in entry]
    spike_times_list = [
        np.asarray(n_trode["units"]["spike_times"])
        for n_trode in spikes_nwb
    ]
    if len(spike_times_list) > 0:  # if units
        spikes = np.concatenate(spike_times_list)

        # Bin spikes into time bins
        spike_indicator = []
        for spike_times in spikes:
            spike_times = spike_times[
                (spike_times > time[0]) & (spike_times <= time[-1])
            ]
            spike_indicator.append(
                np.bincount(
                    np.digitize(spike_times, time[1:-1]),
                    minlength=time.shape[0],
                )
            )

        column_names = np.concatenate(
            [
                [
                    f'{n_trode["sort_group_id"]:04d}_{unit_number:04d}'
                    for unit_number in n_trode["units"].index
                ]
                for n_trode in spikes_nwb
            ]
        )
        spike_indicator = pd.DataFrame(
            np.stack(spike_indicator, axis=1),
            index=pd.Index(time, name="time"),
            columns=column_names,
        )

        # Insert into analysis nwb file
        nwb_analysis_file = AnalysisNwbfile()
        key["analysis_file_name"] = nwb_analysis_file.create(
            key["nwb_file_name"]
        )

        key["spike_indicator_object_id"] = nwb_analysis_file.add_nwb_object(
            analysis_file_name=key["analysis_file_name"],
            nwb_object=spike_indicator.reset_index(),
        )

        nwb_analysis_file.add(
            nwb_file_name=key["nwb_file_name"],
            analysis_file_name=key["analysis_file_name"],
        )

        self.insert1(key)
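
For intuition about the binning inside make(), the following toy sketch reproduces the digitize/bincount step in isolation. The time vector here is fabricated; in the pipeline it comes from get_time_bins_from_interval.

import numpy as np

time = np.linspace(0.0, 2.0, 5)               # toy bin times: [0.0, 0.5, 1.0, 1.5, 2.0]
spike_times = np.array([0.2, 0.6, 0.7, 1.9])  # toy spike times within the interval

# Digitize against the interior edges so each count lines up with an entry of `time`
bin_indices = np.digitize(spike_times, time[1:-1])                  # -> [0, 1, 1, 3]
spike_indicator = np.bincount(bin_indices, minlength=time.shape[0])
print(spike_indicator)                                              # [1 2 0 1 0]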

fetch1_dataframe()

Return the first spike indicator as a dataframe.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def fetch1_dataframe(self) -> pd.DataFrame:
    """Return the first spike indicator as a dataframe."""
    return self.fetch_dataframe()[0]

fetch_dataframe()

Return all spike indicators as a list of dataframes.

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def fetch_dataframe(self) -> list[pd.DataFrame]:
    """Return all spike indicators as a list of dataframes."""
    return pd.concat(
        [
            data["spike_indicator"].set_index("time")
            for data in self.fetch_nwb()
        ],
        axis=1,
    )
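
A retrieval sketch; the restriction values are hypothetical and the import path follows the source path shown above.

from spyglass.decoding.v0.sorted_spikes import SortedSpikesIndicator

restriction = {
    "nwb_file_name": "example20230101_.nwb",  # hypothetical
    "interval_list_name": "02_r1",            # hypothetical
}
# Columns are "<sort_group_id>_<unit_id>"; the index is the bin time in seconds
spike_indicator = (SortedSpikesIndicator & restriction).fetch_dataframe()
print(spike_indicator.shape)  # (n_time, n_units)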

SortedSpikesClassifierParameters

Bases: SpyglassMixin, Manual

Stores parameters for decoding with sorted spikes

Parameters:

classifier_param_name (required)
    A name for this set of parameters
classifier_params (required)
    Initialization parameters, including ...
        environments: list
        observation_models
        continuous_transition_types
        discrete_transition_type: DiagonalDiscrete
        initial_conditions_type: UniformInitialConditions
        infer_track_interior: bool
        clusterless_algorithm: str, optional
        clusterless_algorithm_params: dict, optional
        sorted_spikes_algorithm: str, optional
        sorted_spikes_algorithm_params: dict, optional
    For more information, see the replay_trajectory_classification documentation
fit_params (required)
predict_params (required)
    Prediction parameters, including ...
        is_compute_acausal: bool
        use_gpu: bool
        state_names: List[str]
Source code in src/spyglass/decoding/v0/sorted_spikes.py
@schema
class SortedSpikesClassifierParameters(SpyglassMixin, dj.Manual):
    """Stores parameters for decoding with sorted spikes

    Parameters
    ----------
    classifier_param_name: str
        A name for this set of parameters
    classifier_params: dict
        Initialization parameters, including ...
            environments: list
            observation_models
            continuous_transition_types
            discrete_transition_type: DiagonalDiscrete
            initial_conditions_type: UniformInitialConditions
            infer_track_interior: bool
            clusterless_algorithm: str, optional
            clusterless_algorithm_params: dict, optional
            sorted_spikes_algorithm: str, optional
            sorted_spikes_algorithm_params: dict, optional
        For more information, see replay_trajectory_classification documentation
    fit_params: dict, optional
    predict_params: dict, optional
        Prediction parameters, including ...
            is_compute_acausal: bool
            use_gpu: bool
            state_names: List[str]
    """

    definition = """
    classifier_param_name : varchar(80) # a name for this set of parameters
    ---
    classifier_params :   BLOB    # initialization parameters
    fit_params :          BLOB    # fit parameters
    predict_params :      BLOB    # prediction parameters
    """

    def insert_default(self):
        """Insert default parameters for decoding with sorted spikes"""
        self.insert(
            [
                make_default_decoding_params(),
                make_default_decoding_params(use_gpu=True),
            ],
            skip_duplicates=True,
        )

    def insert1(self, key, **kwargs):
        """Override insert1 to convert classes to dict"""
        super().insert1(convert_classes_to_dict(key), **kwargs)

    def fetch1(self, *args, **kwargs):
        """Override fetch1 to restore classes"""
        return restore_classes(super().fetch1(*args, **kwargs))

insert_default()

Insert default parameters for decoding with sorted spikes

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def insert_default(self):
    """Insert default parameters for decoding with sorted spikes"""
    self.insert(
        [
            make_default_decoding_params(),
            make_default_decoding_params(use_gpu=True),
        ],
        skip_duplicates=True,
    )

insert1(key, **kwargs)

Override insert1 to convert classes to dict

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def insert1(self, key, **kwargs):
    """Override insert1 to convert classes to dict"""
    super().insert1(convert_classes_to_dict(key), **kwargs)

fetch1(*args, **kwargs)

Override fetch1 to restore classes

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def fetch1(self, *args, **kwargs):
    """Override fetch1 to restore classes"""
    return restore_classes(super().fetch1(*args, **kwargs))
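
A sketch of typical parameter handling; it only assumes the defaults inserted by insert_default() and does not hard-code their names.

from spyglass.decoding.v0.sorted_spikes import SortedSpikesClassifierParameters

params_table = SortedSpikesClassifierParameters()
params_table.insert_default()  # inserts the CPU and GPU default parameter sets

# Pick one of the stored parameter sets; fetch1 restores the
# replay_trajectory_classification objects that insert1 converted to dicts.
param_name = params_table.fetch("classifier_param_name")[0]
params = (params_table & {"classifier_param_name": param_name}).fetch1()
classifier_params = params["classifier_params"]
fit_params = params["fit_params"]
predict_params = params["predict_params"]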

get_spike_indicator(key, time_range, sampling_rate=500.0)

Returns a dataframe with the spike indicator for each unit

Parameters:

key : dict (required)
time_range : tuple[float, float] (required)
    Start and end time of the spike indicator
sampling_rate : float (default: 500.0)

Returns:

spike_indicator : pd.DataFrame, shape (n_time, n_units)
    A dataframe with the spike indicator for each unit

Source code in src/spyglass/decoding/v0/sorted_spikes.py
def get_spike_indicator(
    key: dict, time_range: tuple[float, float], sampling_rate: float = 500.0
) -> pd.DataFrame:
    """Returns a dataframe with the spike indicator for each unit

    Parameters
    ----------
    key : dict
    time_range : tuple[float, float]
        Start and end time of the spike indicator
    sampling_rate : float, optional

    Returns
    -------
    spike_indicator : pd.DataFrame, shape (n_time, n_units)
        A dataframe with the spike indicator for each unit
    """
    start_time, end_time = time_range
    n_samples = int(np.ceil((end_time - start_time) * sampling_rate)) + 1
    time = np.linspace(start_time, end_time, n_samples)

    spike_indicator = dict()
    spikes_nwb_table = CuratedSpikeSorting() & key

    for n_trode in spikes_nwb_table.fetch_nwb():
        try:
            for unit_id, unit_spike_times in n_trode["units"][
                "spike_times"
            ].items():
                unit_spike_times = unit_spike_times[
                    (unit_spike_times > time[0])
                    & (unit_spike_times <= time[-1])
                ]
                unit_name = f'{n_trode["sort_group_id"]:04d}_{unit_id:04d}'
                spike_indicator[unit_name] = np.bincount(
                    np.digitize(unit_spike_times, time[1:-1]),
                    minlength=time.shape[0],
                )
        except KeyError:
            pass

    return pd.DataFrame(
        spike_indicator,
        index=pd.Index(time, name="time"),
    )
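
A usage sketch with hypothetical values; the key can be any restriction that selects the desired CuratedSpikeSorting entries.

from spyglass.decoding.v0.sorted_spikes import get_spike_indicator

key = {"nwb_file_name": "example20230101_.nwb"}  # hypothetical restriction
spikes = get_spike_indicator(
    key,
    time_range=(1000.0, 1010.0),  # start/end in seconds (hypothetical)
    sampling_rate=500.0,          # 2 ms bins
)
# One row per time bin, one column per unit ("<sort_group_id>_<unit_id>")
print(spikes.sum())  # total spikes per unit over the window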

get_decoding_data_for_epoch(nwb_file_name, interval_list_name, position_info_param_name='default', additional_spike_keys={})

Collects the data needed for decoding

Parameters:

nwb_file_name : str (required)
interval_list_name : str (required)
position_info_param_name : str (default: 'default')
additional_spike_keys : dict (default: {})

Returns:

position_info : pd.DataFrame, shape (n_time, n_position_features)
spikes : pd.DataFrame, shape (n_time, n_units)
valid_slices : list[slice]
Source code in src/spyglass/decoding/v0/sorted_spikes.py
def get_decoding_data_for_epoch(
    nwb_file_name: str,
    interval_list_name: str,
    position_info_param_name: str = "default",
    additional_spike_keys: dict = {},
) -> tuple[pd.DataFrame, pd.DataFrame, list[slice]]:
    """Collects the data needed for decoding

    Parameters
    ----------
    nwb_file_name : str
    interval_list_name : str
    position_info_param_name : str, optional
    additional_spike_keys : dict, optional

    Returns
    -------
    position_info : pd.DataFrame, shape (n_time, n_position_features)
    spikes : pd.DataFrame, shape (n_time, n_units)
    valid_slices : list[slice]

    """

    valid_slices = convert_valid_times_to_slice(
        get_valid_ephys_position_times_by_epoch(nwb_file_name)[
            interval_list_name
        ]
    )

    # position interval
    nwb_dict = dict(nwb_file_name=nwb_file_name)
    pos_interval_dict = dict(
        nwb_dict,
        interval_list_name=convert_epoch_interval_name_to_position_interval_name(
            {
                **nwb_dict,
                "interval_list_name": interval_list_name,
            }
        ),
    )

    position_info = (
        IntervalPositionInfo()
        & {
            **pos_interval_dict,
            "position_info_param_name": position_info_param_name,
        }
    ).fetch1_dataframe()

    # spikes
    valid_times = np.asarray(
        [(times.start, times.stop) for times in valid_slices]
    )

    curated_spikes_key = {
        "nwb_file_name": nwb_file_name,
        **additional_spike_keys,
    }
    spikes = get_spike_indicator(
        curated_spikes_key,
        (valid_times.min(), valid_times.max()),
        sampling_rate=500,
    )
    spikes = pd.concat([spikes.loc[times] for times in valid_slices])

    new_time = spikes.index.to_numpy()
    new_index = pd.Index(
        np.unique(np.concatenate((position_info.index, new_time))),
        name="time",
    )
    position_info = (
        position_info.reindex(index=new_index)
        .interpolate(method="linear")
        .reindex(index=new_time)
    )

    return position_info, spikes, valid_slices
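
A usage sketch with hypothetical names; it assumes the epoch has valid ephys/position times and that the corresponding IntervalPositionInfo entry exists.

from spyglass.decoding.v0.sorted_spikes import get_decoding_data_for_epoch

position_info, spikes, valid_slices = get_decoding_data_for_epoch(
    nwb_file_name="example20230101_.nwb",  # hypothetical
    interval_list_name="02_r1",            # hypothetical epoch name
    position_info_param_name="default",
)
# Position is interpolated onto the spike time bins, so the two stay aligned
assert position_info.shape[0] == spikes.shape[0]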

get_data_for_multiple_epochs(nwb_file_name, epoch_names, position_info_param_name='decoding', additional_spike_keys={})

Collects the data needed for decoding for multiple epochs

Parameters:

nwb_file_name : str (required)
epoch_names : list (required)
position_info_param_name : str (default: 'decoding')
additional_spike_keys : dict (default: {})

Returns:

position_info : pd.DataFrame, shape (n_time, n_position_features)
spikes : pd.DataFrame, shape (n_time, n_units)
valid_slices : list[slice]
environment_labels : np.ndarray, shape (n_time,)
    The environment label for each time point
sort_group_ids : np.ndarray, shape (n_units,)
    The sort group of each unit
Source code in src/spyglass/decoding/v0/sorted_spikes.py
def get_data_for_multiple_epochs(
    nwb_file_name: str,
    epoch_names: list,
    position_info_param_name: str = "decoding",
    additional_spike_keys: dict = {},
) -> tuple[pd.DataFrame, pd.DataFrame, list[slice], np.ndarray, np.ndarray]:
    """Collects the data needed for decoding for multiple epochs

    Parameters
    ----------
    nwb_file_name : str
    epoch_names : list
    position_info_param_name : str, optional
    additional_spike_keys : dict, optional

    Returns
    -------
    position_info : pd.DataFrame, shape (n_time, n_position_features)
    spikes : pd.DataFrame, shape (n_time, n_units)
    valid_slices : list[slice]
    environment_labels : np.ndarray, shape (n_time,)
        The environment label for each time point
    sort_group_ids : np.ndarray, shape (n_units,)
        The sort group of each unit
    """
    data = []
    environment_labels = []

    for epoch in epoch_names:
        logger.info(epoch)
        data.append(
            get_decoding_data_for_epoch(
                nwb_file_name,
                epoch,
                position_info_param_name=position_info_param_name,
                additional_spike_keys=additional_spike_keys,
            )
        )
        n_time = data[-1][0].shape[0]
        environment_labels.append([epoch] * n_time)

    environment_labels = np.concatenate(environment_labels, axis=0)
    position_info, spikes, valid_slices = list(zip(*data))
    position_info = pd.concat(position_info, axis=0)
    spikes = pd.concat(spikes, axis=0)
    valid_slices = {
        epoch: valid_slice
        for epoch, valid_slice in zip(epoch_names, valid_slices)
    }

    assert position_info.shape[0] == spikes.shape[0]

    sort_group_ids = np.asarray(
        [int(col.split("_")[0]) for col in spikes.columns]
    )

    return (
        position_info,
        spikes,
        valid_slices,
        environment_labels,
        sort_group_ids,
    )
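
A usage sketch with hypothetical epoch names; spikes from all epochs are concatenated in time and environment_labels records the source epoch of each time bin.

import numpy as np

from spyglass.decoding.v0.sorted_spikes import get_data_for_multiple_epochs

(
    position_info,
    spikes,
    valid_slices,
    environment_labels,
    sort_group_ids,
) = get_data_for_multiple_epochs(
    nwb_file_name="example20230101_.nwb",  # hypothetical
    epoch_names=["02_r1", "04_r2"],        # hypothetical epochs
    position_info_param_name="decoding",
)
print(np.unique(environment_labels))  # ["02_r1" "04_r2"]
print(sort_group_ids)                 # sort group of each unit column in `spikes`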