waveform_features.py

WaveformFeaturesParams

Bases: SpyglassMixin, Lookup

Defines types of waveform features computed for a given spike time.

Source code in src/spyglass/decoding/v1/waveform_features.py
@schema
class WaveformFeaturesParams(SpyglassMixin, dj.Lookup):
    """Defines types of waveform features computed for a given spike time."""

    definition = """
    features_param_name : varchar(80) # a name for this set of parameters
    ---
    params : longblob # the parameters for the waveform features
    """
    _default_waveform_feature_params = {
        "amplitude": {
            "peak_sign": "neg",
            "estimate_peak_time": False,
        }
    }
    _default_waveform_extract_params = {
        "ms_before": 0.5,
        "ms_after": 0.5,
        "max_spikes_per_unit": None,
        "n_jobs": 5,
        "chunk_duration": "1000s",
    }
    contents = [
        [
            "amplitude",
            {
                "waveform_features_params": _default_waveform_feature_params,
                "waveform_extraction_params": _default_waveform_extract_params,
            },
        ],
        [
            "amplitude, spike_location",
            {
                "waveform_features_params": {
                    "amplitude": _default_waveform_feature_params["amplitude"],
                    "spike_location": {},
                },
                "waveform_extraction_params": _default_waveform_extract_params,
            },
        ],
    ]

    @classmethod
    def insert_default(cls):
        """Insert default waveform features parameters"""
        cls.insert(cls.contents, skip_duplicates=True)

    @staticmethod
    def check_supported_waveform_features(waveform_features: list[str]) -> bool:
        """Checks whether the requested waveform features types are supported

        Parameters
        ----------
        waveform_features : list
        """
        supported_features = set(WAVEFORM_FEATURE_FUNCTIONS)
        return set(waveform_features).issubset(supported_features)

    @property
    def supported_waveform_features(self) -> list[str]:
        """Returns the list of supported waveform features"""
        return list(WAVEFORM_FEATURE_FUNCTIONS)
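
For example, a stored parameter set can be retrieved by name with a standard DataJoint restriction (a minimal usage sketch; it assumes the default contents above have already been inserted):

params = (
    WaveformFeaturesParams & {"features_param_name": "amplitude"}
).fetch1("params")
params["waveform_features_params"]    # {"amplitude": {"peak_sign": "neg", ...}}
params["waveform_extraction_params"]  # {"ms_before": 0.5, "ms_after": 0.5, ...}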

insert_default() classmethod

Insert default waveform feature parameters

Source code in src/spyglass/decoding/v1/waveform_features.py
@classmethod
def insert_default(cls):
    """Insert default waveform features parameters"""
    cls.insert(cls.contents, skip_duplicates=True)
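
A typical call (sketch; safe to repeat, since the insert passes skip_duplicates=True):

WaveformFeaturesParams.insert_default()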

check_supported_waveform_features(waveform_features) staticmethod

Checks whether the requested waveform feature types are supported

Parameters:

Name               Type  Description                               Default
waveform_features  list  List of waveform feature names to check  required
Source code in src/spyglass/decoding/v1/waveform_features.py
@staticmethod
def check_supported_waveform_features(waveform_features: list[str]) -> bool:
    """Checks whether the requested waveform features types are supported

    Parameters
    ----------
    waveform_features : list
    """
    supported_features = set(WAVEFORM_FEATURE_FUNCTIONS)
    return set(waveform_features).issubset(supported_features)
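
For example (a sketch; "spike_width" is a hypothetical feature name used here only to show a failing check):

WaveformFeaturesParams.check_supported_waveform_features(
    ["amplitude", "spike_location"]
)  # True
WaveformFeaturesParams.check_supported_waveform_features(
    ["spike_width"]
)  # False, assuming no such feature is registered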

supported_waveform_features: list[str] property

Returns the list of supported waveform features
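
A quick way to inspect the available options (sketch; the exact list depends on WAVEFORM_FEATURE_FUNCTIONS):

WaveformFeaturesParams().supported_waveform_features
# e.g. ["amplitude", "spike_location", ...]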

UnitWaveformFeatures

Bases: SpyglassMixin, Computed

For each spike time, compute the waveform features associated with that spike.

Used for clusterless decoding.

Source code in src/spyglass/decoding/v1/waveform_features.py
@schema
class UnitWaveformFeatures(SpyglassMixin, dj.Computed):
    """For each spike time, compute waveform feature associated with that spike.

    Used for clusterless decoding.
    """

    definition = """
    -> UnitWaveformFeaturesSelection
    ---
    -> AnalysisNwbfile
    object_id: varchar(40) # the NWB object that stores the waveforms
    """

    _parallel_make = True

    def make(self, key):
        """Populate UnitWaveformFeatures table."""
        AnalysisNwbfile()._creation_times["pre_create_time"] = time()
        # get the list of feature parameters
        params = (WaveformFeaturesParams & key).fetch1("params")

        # check that the feature type is supported
        if not WaveformFeaturesParams.check_supported_waveform_features(
            params["waveform_features_params"]
        ):
            raise NotImplementedError(
                f"Features {set(params['waveform_features_params'])} are "
                + "not supported"
            )

        merge_key = {"merge_id": key["spikesorting_merge_id"]}
        waveform_extractor = self._fetch_waveform(
            merge_key, params["waveform_extraction_params"]
        )

        source_key = SpikeSortingOutput().merge_get_parent(merge_key).fetch1()
        # v0 pipeline
        if "sorter" in source_key and "nwb_file_name" in source_key:
            sorter = source_key["sorter"]
            nwb_file_name = source_key["nwb_file_name"]
            analysis_nwb_key = "units"
        # v1 pipeline
        else:
            sorting_id = (SpikeSortingOutput.CurationV1 & merge_key).fetch1(
                "sorting_id"
            )
            sorter, nwb_file_name = (
                SpikeSortingSelection & {"sorting_id": sorting_id}
            ).fetch1("sorter", "nwb_file_name")
            analysis_nwb_key = "object_id"

        waveform_features = {}

        for feature, feature_params in params[
            "waveform_features_params"
        ].items():
            waveform_features[feature] = self._compute_waveform_features(
                waveform_extractor,
                feature,
                feature_params,
                sorter,
            )

        nwb = SpikeSortingOutput().fetch_nwb(merge_key)[0]
        spike_times = (
            nwb[analysis_nwb_key]["spike_times"]
            if analysis_nwb_key in nwb
            else pd.DataFrame()
        )

        (
            key["analysis_file_name"],
            key["object_id"],
        ) = _write_waveform_features_to_nwb(
            nwb_file_name,
            waveform_extractor,
            spike_times,
            waveform_features,
        )

        AnalysisNwbfile().add(
            nwb_file_name,
            key["analysis_file_name"],
        )
        AnalysisNwbfile().log(key, table=self.full_table_name)

        self.insert1(key)

    @staticmethod
    def _fetch_waveform(
        merge_key: dict, waveform_extraction_params: dict
    ) -> si.WaveformExtractor:
        # get the recording from the parent table
        recording = SpikeSortingOutput().get_recording(merge_key)
        if recording.get_num_segments() > 1:
            recording = si.concatenate_recordings([recording])
        # get the sorting from the parent table
        sorting = SpikeSortingOutput().get_sorting(merge_key)

        waveforms_temp_dir = temp_dir + "/" + str(merge_key["merge_id"])
        os.makedirs(waveforms_temp_dir, exist_ok=True)

        return si.extract_waveforms(
            recording=recording,
            sorting=sorting,
            folder=waveforms_temp_dir,
            overwrite=True,
            **waveform_extraction_params,
        )

    @staticmethod
    def _compute_waveform_features(
        waveform_extractor: si.WaveformExtractor,
        feature: str,
        feature_params: dict,
        sorter: str,
    ) -> dict:
        feature_func = WAVEFORM_FEATURE_FUNCTIONS[feature]
        if sorter == "clusterless_thresholder" and feature == "amplitude":
            feature_params["estimate_peak_time"] = False

        return {
            unit_id: feature_func(waveform_extractor, unit_id, **feature_params)
            for unit_id in waveform_extractor.sorting.get_unit_ids()
        }

    def fetch_data(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
        """Fetches the spike times and features for each unit.

        Returns
        -------
        spike_times : list of np.ndarray
            List of spike times for each unit
        features : list of np.ndarray
            List of features for each unit

        """
        return tuple(
            zip(
                *list(
                    chain(
                        *[self._convert_data(data) for data in self.fetch_nwb()]
                    )
                )
            )
        )

    @staticmethod
    def _convert_data(nwb_data) -> list[tuple[np.ndarray, np.ndarray]]:
        feature_df = nwb_data["object_id"]

        feature_columns = [
            column for column in feature_df.columns if column != "spike_times"
        ]

        return [
            (
                unit.spike_times,
                np.concatenate(unit[feature_columns].to_numpy(), axis=1),
            )
            for _, unit in feature_df.iterrows()
        ]
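
As a dj.Computed table, it is typically filled via populate once a corresponding UnitWaveformFeaturesSelection entry exists (a minimal sketch; the selection insert is assumed to have happened upstream):

UnitWaveformFeatures.populate()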

make(key)

Populate UnitWaveformFeatures table.

Source code in src/spyglass/decoding/v1/waveform_features.py
def make(self, key):
    """Populate UnitWaveformFeatures table."""
    AnalysisNwbfile()._creation_times["pre_create_time"] = time()
    # get the list of feature parameters
    params = (WaveformFeaturesParams & key).fetch1("params")

    # check that the feature type is supported
    if not WaveformFeaturesParams.check_supported_waveform_features(
        params["waveform_features_params"]
    ):
        raise NotImplementedError(
            f"Features {set(params['waveform_features_params'])} are "
            + "not supported"
        )

    merge_key = {"merge_id": key["spikesorting_merge_id"]}
    waveform_extractor = self._fetch_waveform(
        merge_key, params["waveform_extraction_params"]
    )

    source_key = SpikeSortingOutput().merge_get_parent(merge_key).fetch1()
    # v0 pipeline
    if "sorter" in source_key and "nwb_file_name" in source_key:
        sorter = source_key["sorter"]
        nwb_file_name = source_key["nwb_file_name"]
        analysis_nwb_key = "units"
    # v1 pipeline
    else:
        sorting_id = (SpikeSortingOutput.CurationV1 & merge_key).fetch1(
            "sorting_id"
        )
        sorter, nwb_file_name = (
            SpikeSortingSelection & {"sorting_id": sorting_id}
        ).fetch1("sorter", "nwb_file_name")
        analysis_nwb_key = "object_id"

    waveform_features = {}

    for feature, feature_params in params[
        "waveform_features_params"
    ].items():
        waveform_features[feature] = self._compute_waveform_features(
            waveform_extractor,
            feature,
            feature_params,
            sorter,
        )

    nwb = SpikeSortingOutput().fetch_nwb(merge_key)[0]
    spike_times = (
        nwb[analysis_nwb_key]["spike_times"]
        if analysis_nwb_key in nwb
        else pd.DataFrame()
    )

    (
        key["analysis_file_name"],
        key["object_id"],
    ) = _write_waveform_features_to_nwb(
        nwb_file_name,
        waveform_extractor,
        spike_times,
        waveform_features,
    )

    AnalysisNwbfile().add(
        nwb_file_name,
        key["analysis_file_name"],
    )
    AnalysisNwbfile().log(key, table=self.full_table_name)

    self.insert1(key)
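
make is normally invoked indirectly through populate, optionally restricted to a single merge entry (sketch; merge_id stands in for a hypothetical UUID from SpikeSortingOutput):

UnitWaveformFeatures.populate({"spikesorting_merge_id": merge_id})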

fetch_data()

Fetches the spike times and features for each unit.

Returns:

Name         Type                Description
spike_times  list of np.ndarray  List of spike times for each unit
features     list of np.ndarray  List of features for each unit

Source code in src/spyglass/decoding/v1/waveform_features.py
def fetch_data(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
    """Fetches the spike times and features for each unit.

    Returns
    -------
    spike_times : list of np.ndarray
        List of spike times for each unit
    features : list of np.ndarray
        List of features for each unit

    """
    return tuple(
        zip(
            *list(
                chain(
                    *[self._convert_data(data) for data in self.fetch_nwb()]
                )
            )
        )
    )
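
A usage sketch (assumes the table has been populated; key is a hypothetical restriction):

spike_times, features = (UnitWaveformFeatures & key).fetch_data()
spike_times[0]  # spike times for the first unit
features[0]     # per-spike feature array for that unit, one row per spike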