Skip to content

spikesorting_artifact.py

ArtifactDetectionParameters

Bases: SpyglassMixin, Manual

Source code in src/spyglass/spikesorting/v0/spikesorting_artifact.py
@schema
class ArtifactDetectionParameters(SpyglassMixin, dj.Manual):
    definition = """
    # Parameters for detecting artifact times within a sort group.
    artifact_params_name: varchar(200)
    ---
    artifact_params: blob  # dictionary of parameters
    """

    def insert_default(self):
        """Insert the default artifact parameters with an appropriate parameter dict."""
        # "default": amplitude-only thresholding over all electrodes of the group
        default_params = {
            "zscore_thresh": None,  # must be None or >= 0
            "amplitude_thresh": 3000,  # must be None or >= 0
            "proportion_above_thresh": 1.0,  # all electrodes of sort group
            "removal_window_ms": 1.0,  # in milliseconds
        }
        self.insert1(["default", default_params], skip_duplicates=True)

        # "none": both thresholds disabled, i.e. no artifact detection
        none_params = {"zscore_thresh": None, "amplitude_thresh": None}
        self.insert1(["none", none_params], skip_duplicates=True)

insert_default()

Insert the default artifact parameters with an appropriate parameter dict.

Source code in src/spyglass/spikesorting/v0/spikesorting_artifact.py
def insert_default(self):
    """Insert the default artifact parameters with an appropriate parameter dict."""
    # "default": amplitude-only thresholding over all electrodes of the group
    default_params = {
        "zscore_thresh": None,  # must be None or >= 0
        "amplitude_thresh": 3000,  # must be None or >= 0
        "proportion_above_thresh": 1.0,  # all electrodes of sort group
        "removal_window_ms": 1.0,  # in milliseconds
    }
    self.insert1(["default", default_params], skip_duplicates=True)

    # "none": both thresholds disabled, i.e. no artifact detection
    none_params = {"zscore_thresh": None, "amplitude_thresh": None}
    self.insert1(["none", none_params], skip_duplicates=True)

ArtifactDetection

Bases: SpyglassMixin, Computed

Source code in src/spyglass/spikesorting/v0/spikesorting_artifact.py
@schema
class ArtifactDetection(SpyglassMixin, dj.Computed):
    definition = """
    # Stores artifact times and valid no-artifact times as intervals.
    -> ArtifactDetectionSelection
    ---
    artifact_times: longblob # np array of artifact intervals
    artifact_removed_valid_times: longblob # np array of valid no-artifact intervals
    artifact_removed_interval_list_name: varchar(200) # name of the array of no-artifact valid time intervals
    """

    _parallel_make = True

    def make(self, key):
        """Populate the ArtifactDetection table.

        If custom_artifact_detection is set in selection table, do nothing.

        Fetches...
            - Parameters from ArtifactDetectionParameters
            - Recording from SpikeSortingRecording (loads with spikeinterface)
        Uses module-level function _get_artifact_times to detect artifacts.
        """
        # Skip entirely when the user supplied their own artifact intervals.
        if (ArtifactDetectionSelection & key).fetch1("custom_artifact_detection"):
            return

        # get the dict of artifact params associated with this artifact_params_name
        artifact_params = (ArtifactDetectionParameters & key).fetch1(
            "artifact_params"
        )

        recording_path = (SpikeSortingRecording & key).fetch1("recording_path")
        recording_name = SpikeSortingRecording._get_recording_name(key)
        recording = si.load_extractor(recording_path)

        # spikeinterface chunked-job settings; progress_bar must be a
        # boolean (was the string "True", which is truthy but wrong type)
        job_kwargs = {
            "chunk_duration": "10s",
            "n_jobs": 4,
            "progress_bar": True,
        }

        artifact_removed_valid_times, artifact_times = _get_artifact_times(
            recording, **artifact_params, **job_kwargs
        )

        key["artifact_times"] = artifact_times
        key["artifact_removed_valid_times"] = artifact_removed_valid_times

        # set up a name for no-artifact times using recording id
        key["artifact_removed_interval_list_name"] = (
            recording_name
            + "_"
            + key["artifact_params_name"]
            + "_artifact_removed_valid_times"
        )

        ArtifactRemovedIntervalList.insert1(key, replace=True)

        # also insert into IntervalList so downstream pipelines can use it
        tmp_key = {
            "nwb_file_name": key["nwb_file_name"],
            "interval_list_name": key["artifact_removed_interval_list_name"],
            "valid_times": key["artifact_removed_valid_times"],
            "pipeline": "spikesorting_artifact_v0",
        }
        IntervalList.insert1(tmp_key, replace=True)

        # insert into computed table
        self.insert1(key)

make(key)

Populate the ArtifactDetection table.

If custom_artifact_detection is set in selection table, do nothing.

Fetches parameters from ArtifactDetectionParameters and the recording from SpikeSortingRecording (loaded with spikeinterface), then uses the module-level function _get_artifact_times to detect artifacts.

Source code in src/spyglass/spikesorting/v0/spikesorting_artifact.py
def make(self, key):
    """Populate the ArtifactDetection table.

    If custom_artifact_detection is set in selection table, do nothing.

    Fetches...
        - Parameters from ArtifactDetectionParameters
        - Recording from SpikeSortingRecording (loads with spikeinterface)
    Uses module-level function _get_artifact_times to detect artifacts.
    """
    # Skip entirely when the user supplied their own artifact intervals.
    if (ArtifactDetectionSelection & key).fetch1("custom_artifact_detection"):
        return

    # get the dict of artifact params associated with this artifact_params_name
    artifact_params = (ArtifactDetectionParameters & key).fetch1(
        "artifact_params"
    )

    recording_path = (SpikeSortingRecording & key).fetch1("recording_path")
    recording_name = SpikeSortingRecording._get_recording_name(key)
    recording = si.load_extractor(recording_path)

    # spikeinterface chunked-job settings; progress_bar must be a
    # boolean (was the string "True", which is truthy but wrong type)
    job_kwargs = {
        "chunk_duration": "10s",
        "n_jobs": 4,
        "progress_bar": True,
    }

    artifact_removed_valid_times, artifact_times = _get_artifact_times(
        recording, **artifact_params, **job_kwargs
    )

    key["artifact_times"] = artifact_times
    key["artifact_removed_valid_times"] = artifact_removed_valid_times

    # set up a name for no-artifact times using recording id
    key["artifact_removed_interval_list_name"] = (
        recording_name
        + "_"
        + key["artifact_params_name"]
        + "_artifact_removed_valid_times"
    )

    ArtifactRemovedIntervalList.insert1(key, replace=True)

    # also insert into IntervalList so downstream pipelines can use it
    tmp_key = {
        "nwb_file_name": key["nwb_file_name"],
        "interval_list_name": key["artifact_removed_interval_list_name"],
        "valid_times": key["artifact_removed_valid_times"],
        "pipeline": "spikesorting_artifact_v0",
    }
    IntervalList.insert1(tmp_key, replace=True)

    # insert into computed table
    self.insert1(key)