artifact.py

ArtifactDetectionParameters

Bases: SpyglassMixin, Lookup

Source code in src/spyglass/spikesorting/v1/artifact.py
@schema
class ArtifactDetectionParameters(SpyglassMixin, dj.Lookup):
    definition = """
    # Parameters for detecting artifacts (non-neural high amplitude events).
    artifact_param_name : varchar(200)
    ---
    artifact_params : blob
    """

    contents = [
        [
            "default",
            {
                "zscore_thresh": None,
                "amplitude_thresh_uV": 3000,
                "proportion_above_thresh": 1.0,
                "removal_window_ms": 1.0,
                "chunk_duration": "10s",
                "n_jobs": 4,
                "progress_bar": "True",
            },
        ],
        [
            "none",
            {
                "zscore_thresh": None,
                "amplitude_thresh_uV": None,
                "chunk_duration": "10s",
                "n_jobs": 4,
                "progress_bar": "True",
            },
        ],
    ]

    @classmethod
    def insert_default(cls):
        """Insert default parameters into ArtifactDetectionParameters."""
        cls.insert(cls.contents, skip_duplicates=True)

insert_default() classmethod

Insert default parameters into ArtifactDetectionParameters.

Source code in src/spyglass/spikesorting/v1/artifact.py
@classmethod
def insert_default(cls):
    """Insert default parameters into ArtifactDetectionParameters."""
    cls.insert(cls.contents, skip_duplicates=True)
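
To run detection with different thresholds, insert a new parameter set before creating a selection. A minimal sketch, reusing the same parameter keys as the "default" entry above (the set name "stim_2000uV" is hypothetical):

from spyglass.spikesorting.v1.artifact import ArtifactDetectionParameters

# Hypothetical parameter set: mark windows where all channels exceed 2000 uV
ArtifactDetectionParameters.insert1(
    [
        "stim_2000uV",  # hypothetical artifact_param_name
        {
            "zscore_thresh": None,
            "amplitude_thresh_uV": 2000,
            "proportion_above_thresh": 1.0,
            "removal_window_ms": 1.0,
            "chunk_duration": "10s",
            "n_jobs": 4,
            "progress_bar": "True",
        },
    ],
    skip_duplicates=True,
)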

ArtifactDetectionSelection

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v1/artifact.py
@schema
class ArtifactDetectionSelection(SpyglassMixin, dj.Manual):
    definition = """
    # Processed recording/artifact detection parameters. See `insert_selection`.
    artifact_id: uuid
    ---
    -> SpikeSortingRecording
    -> ArtifactDetectionParameters
    """

    @classmethod
    def insert_selection(cls, key: dict):
        """Insert a row into ArtifactDetectionSelection.

        Automatically generates a unique artifact ID as the sole primary key.

        Parameters
        ----------
        key : dict
            primary key of SpikeSortingRecording and ArtifactDetectionParameters

        Returns
        -------
        key : dict
            the input key with the generated `artifact_id` added; if a
            similar row already exists, the existing row(s) are returned
            as a list of dicts instead
        """
        query = cls & key
        if query:
            logger.warning("Similar row(s) already inserted.")
            return query.fetch(as_dict=True)
        key["artifact_id"] = uuid.uuid4()
        cls.insert1(key, skip_duplicates=True)
        return key

insert_selection(key) classmethod

Insert a row into ArtifactDetectionSelection.

Automatically generates a unique artifact ID as the sole primary key.

Parameters:

    key : dict, required
        primary key of SpikeSortingRecording and ArtifactDetectionParameters

Returns:

    key : dict
        the input key with the generated artifact_id added; if a similar row already exists, the existing row(s) are returned instead

Source code in src/spyglass/spikesorting/v1/artifact.py
@classmethod
def insert_selection(cls, key: dict):
    """Insert a row into ArtifactDetectionSelection.

    Automatically generates a unique artifact ID as the sole primary key.

    Parameters
    ----------
    key : dict
        primary key of SpikeSortingRecording and ArtifactDetectionParameters

    Returns
    -------
    key : dict
        the input key with the generated `artifact_id` added; if a
        similar row already exists, the existing row(s) are returned
        as a list of dicts instead
    """
    query = cls & key
    if query:
        logger.warning("Similar row(s) already inserted.")
        return query.fetch(as_dict=True)
    key["artifact_id"] = uuid.uuid4()
    cls.insert1(key, skip_duplicates=True)
    return key
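
For example, to pair a processed recording with the "default" parameter set (a sketch; the recording_id field is assumed from the v1 SpikeSortingRecording primary key, and its value here is hypothetical):

from spyglass.spikesorting.v1.artifact import ArtifactDetectionSelection

key = ArtifactDetectionSelection.insert_selection(
    {
        "recording_id": "0f3bb01e-...",  # hypothetical SpikeSortingRecording key
        "artifact_param_name": "default",
    }
)
# key now includes the generated artifact_id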

ArtifactDetection

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v1/artifact.py
@schema
class ArtifactDetection(SpyglassMixin, dj.Computed):
    definition = """
    # Detected artifacts (e.g. large transients from movement).
    # Intervals are stored in IntervalList with `artifact_id` as `interval_list_name`.
    -> ArtifactDetectionSelection
    """

    def make(self, key):
        """Populate ArtifactDetection with detected artifacts.

        1. Fetches...
            - Artifact parameters from ArtifactDetectionParameters
            - Recording analysis NWB file from SpikeSortingRecording
            - Valid times from IntervalList
        2. Loads the recording from the NWB file with spikeinterface
        3. Detects artifacts using module-level `_get_artifact_times`
        4. Inserts the result into IntervalList with `artifact_id` as
            `interval_list_name`
        """
        # FETCH:
        # - artifact parameters
        # - recording analysis nwb file
        artifact_params, recording_analysis_nwb_file = (
            ArtifactDetectionParameters
            * SpikeSortingRecording
            * ArtifactDetectionSelection
            & key
        ).fetch1("artifact_params", "analysis_file_name")
        sort_interval_valid_times = (
            IntervalList
            & {
                "nwb_file_name": (
                    SpikeSortingRecordingSelection * ArtifactDetectionSelection
                    & key
                ).fetch1("nwb_file_name"),
                "interval_list_name": (
                    SpikeSortingRecordingSelection * ArtifactDetectionSelection
                    & key
                ).fetch1("interval_list_name"),
            }
        ).fetch1("valid_times")

        # DO:
        # - load recording
        recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(
            recording_analysis_nwb_file
        )
        recording = se.read_nwb_recording(
            recording_analysis_nwb_file_abs_path, load_time_vector=True
        )

        # - detect artifacts
        artifact_removed_valid_times, _ = _get_artifact_times(
            recording,
            sort_interval_valid_times,
            **artifact_params,
        )

        # INSERT
        # - into IntervalList
        IntervalList.insert1(
            dict(
                nwb_file_name=(
                    SpikeSortingRecordingSelection * ArtifactDetectionSelection
                    & key
                ).fetch1("nwb_file_name"),
                interval_list_name=str(key["artifact_id"]),
                valid_times=artifact_removed_valid_times,
                pipeline="spikesorting_artifact_v1",
            ),
            skip_duplicates=True,
        )
        # - into ArtifactDetection (this table)
        self.insert1(key)

make(key)

Populate ArtifactDetection with detected artifacts.

  1. Fetches...
    • Artifact parameters from ArtifactDetectionParameters
    • Recording analysis NWB file from SpikeSortingRecording
    • Valid times from IntervalList
  2. Loads the recording from the NWB file with spikeinterface
  3. Detects artifacts using module-level _get_artifact_times
  4. Inserts the result into IntervalList with artifact_id as interval_list_name
Source code in src/spyglass/spikesorting/v1/artifact.py
def make(self, key):
    """Populate ArtifactDetection with detected artifacts.

    1. Fetches...
        - Artifact parameters from ArtifactDetectionParameters
        - Recording analysis NWB file from SpikeSortingRecording
        - Valid times from IntervalList
    2. Loads the recording from the NWB file with spikeinterface
    3. Detects artifacts using module-level `_get_artifact_times`
    4. Inserts the result into IntervalList with `artifact_id` as
        `interval_list_name`
    """
    # FETCH:
    # - artifact parameters
    # - recording analysis nwb file
    artifact_params, recording_analysis_nwb_file = (
        ArtifactDetectionParameters
        * SpikeSortingRecording
        * ArtifactDetectionSelection
        & key
    ).fetch1("artifact_params", "analysis_file_name")
    sort_interval_valid_times = (
        IntervalList
        & {
            "nwb_file_name": (
                SpikeSortingRecordingSelection * ArtifactDetectionSelection
                & key
            ).fetch1("nwb_file_name"),
            "interval_list_name": (
                SpikeSortingRecordingSelection * ArtifactDetectionSelection
                & key
            ).fetch1("interval_list_name"),
        }
    ).fetch1("valid_times")

    # DO:
    # - load recording
    recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(
        recording_analysis_nwb_file
    )
    recording = se.read_nwb_recording(
        recording_analysis_nwb_file_abs_path, load_time_vector=True
    )

    # - detect artifacts
    artifact_removed_valid_times, _ = _get_artifact_times(
        recording,
        sort_interval_valid_times,
        **artifact_params,
    )

    # INSERT
    # - into IntervalList
    IntervalList.insert1(
        dict(
            nwb_file_name=(
                SpikeSortingRecordingSelection * ArtifactDetectionSelection
                & key
            ).fetch1("nwb_file_name"),
            interval_list_name=str(key["artifact_id"]),
            valid_times=artifact_removed_valid_times,
            pipeline="spikesorting_artifact_v1",
        ),
        skip_duplicates=True,
    )
    # - into ArtifactDetection (this table)
    self.insert1(key)
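
Once a selection row exists, populate the table and read the artifact-removed valid times back from IntervalList under the artifact_id. A sketch, continuing the hypothetical key from insert_selection above (the nwb_file_name is a placeholder):

from spyglass.common import IntervalList
from spyglass.spikesorting.v1.artifact import ArtifactDetection

ArtifactDetection.populate(key)

valid_times = (
    IntervalList
    & {
        "nwb_file_name": "example20230101_.nwb",  # hypothetical file name
        "interval_list_name": str(key["artifact_id"]),
    }
).fetch1("valid_times")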

merge_intervals(intervals)

Takes a list of intervals, each of the form [start_time, stop_time], and merges (takes the union of) any intervals that overlap.

Parameters:

    intervals : list of [start_time, stop_time]
        intervals to merge

Returns:

    np.ndarray
        array of merged intervals, where overlapping input intervals have been combined

Source code in src/spyglass/spikesorting/v1/artifact.py
def merge_intervals(intervals):
    """Takes a list of intervals each of which is [start_time, stop_time]
    and takes union over intervals that are intersecting

    Parameters
    ----------
    intervals : _type_
        _description_

    Returns
    -------
    _type_
        _description_
    """
    # TODO: Migrate to common_interval.py

    if len(intervals) == 0:
        return []

    # Sort the intervals based on their start times
    intervals.sort(key=lambda x: x[0])

    merged = [intervals[0]]

    for i in range(1, len(intervals)):
        current_start, current_stop = intervals[i]
        last_merged_start, last_merged_stop = merged[-1]

        if current_start <= last_merged_stop:
            # Overlapping intervals, merge them
            merged[-1] = [
                last_merged_start,
                max(last_merged_stop, current_stop),
            ]
        else:
            # Non-overlapping intervals, add the current one to the list
            merged.append([current_start, current_stop])

    return np.asarray(merged)
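
A quick illustration of the merge behavior (note that the input list is sorted in place):

intervals = [[0.0, 2.0], [5.0, 7.0], [1.5, 3.0]]
merged = merge_intervals(intervals)
# [0.0, 2.0] and [1.5, 3.0] overlap and are combined:
# array([[0., 3.],
#        [5., 7.]])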