lfp_artifact_MAD_detection.py

mad_artifact_detector(recording, mad_thresh=6.0, proportion_above_thresh=0.1, removal_window_ms=10.0, sampling_frequency=1000.0, *args, **kwargs)

Detect LFP artifacts using the median absolute deviation method.
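As a rough illustration of the scaling step, the sketch below centers each electrode on its median and divides by its MAD, so that a threshold such as `mad_thresh` can be read as a number of robust standard deviations. The name `mad_scale` is hypothetical and is not the module's private helper; taking the absolute value of the deviation is an assumption made for this example.

```python
import numpy as np
from scipy.stats import median_abs_deviation


def mad_scale(lfps: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in: deviation from the median in units of MAD.

    lfps is assumed to have shape (n_samples, n_electrodes).
    """
    # scale="normal" rescales the MAD so it is comparable to a standard
    # deviation when the data are normally distributed.
    mad = median_abs_deviation(lfps, axis=0, nan_policy="omit", scale="normal")
    median = np.nanmedian(lfps, axis=0)
    # Assumption: use the absolute deviation so both polarities count.
    return np.abs(lfps - median) / mad


# Synthetic data: unit-variance noise with one large deflection on all channels
rng = np.random.default_rng(0)
lfps = rng.normal(size=(1000, 4))
lfps[500] = 50.0

scaled = mad_scale(lfps)
print(scaled[500] > 6.0)  # the injected deflection exceeds the default threshold
```

Because the median and MAD are barely affected by a few extreme samples, the injected deflection stands out clearly while ordinary noise stays well below the threshold.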

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| recording | RecordingExtractor | The recording extractor object | required |
| mad_thresh | float | Threshold on the median absolute deviation scaled LFPs | 6.0 |
| proportion_above_thresh | float | Proportion of electrodes that need to be above the threshold | 0.1 |
| removal_window_ms | float | Width of the window in milliseconds to mask out per artifact (window/2 removed on each side of threshold crossing) | 10.0 |
| sampling_frequency | float | Sampling frequency of the recording extractor, in Hz | 1000.0 |
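To make the interplay between `mad_thresh` and `proportion_above_thresh` concrete, here is a minimal sketch of one plausible reading: a sample is flagged when the fraction of electrodes above the threshold reaches the given proportion. The function `flag_samples` is hypothetical, and the use of `>=` rather than `>` is an assumption, since the module's private helper is not shown here.

```python
import numpy as np


def flag_samples(
    scaled: np.ndarray,
    mad_thresh: float = 6.0,
    proportion_above_thresh: float = 0.1,
) -> np.ndarray:
    """Hypothetical sketch: flag samples where enough electrodes cross."""
    # Boolean array of shape (n_samples, n_electrodes)
    above = scaled > mad_thresh
    # Fraction of electrodes above threshold at each sample
    return above.mean(axis=1) >= proportion_above_thresh


# Three samples on four electrodes, already MAD-scaled
scaled = np.array(
    [
        [0.5, 1.0, 0.2, 0.1],  # no electrode above 6
        [7.0, 1.0, 0.2, 0.1],  # one of four electrodes above 6
        [7.0, 8.0, 9.0, 6.5],  # all electrodes above 6
    ]
)

loose = flag_samples(scaled, proportion_above_thresh=0.1)
strict = flag_samples(scaled, proportion_above_thresh=0.5)
print(loose)   # [False  True  True]
print(strict)  # [False False  True]
```

A low proportion catches artifacts confined to a few electrodes, while a higher proportion only flags events shared across most of the array.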

Returns:

| Name | Type | Description |
| --- | --- | --- |
| artifact_removed_valid_times | ndarray | Intervals of valid times where artifacts were not detected, unit: seconds |
| artifact_intervals | ndarray | Intervals in which artifacts are detected (including removal windows), unit: seconds |
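Both return values are arrays of time intervals derived from a per-sample boolean mask. The sketch below shows one way such a conversion could work; `intervals_from_mask` is a hypothetical helper, and treating the stop time as the timestamp of the last flagged sample (inclusive) is an assumption rather than the module's documented behavior.

```python
import numpy as np


def intervals_from_mask(mask: np.ndarray, timestamps: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (n, 2) array of [start, stop] times per True run."""
    mask = np.asarray(mask, dtype=bool)
    # Pad with False so runs touching either edge still produce a transition
    padded = np.concatenate(([False], mask, [False]))
    changes = np.flatnonzero(np.diff(padded.astype(int)))
    starts = changes[::2]        # index of the first True sample in each run
    stops = changes[1::2] - 1    # index of the last True sample in each run
    return np.column_stack((timestamps[starts], timestamps[stops]))


timestamps = np.arange(10) / 1000.0  # 10 samples at 1000 Hz, in seconds
is_artifact = np.zeros(10, dtype=bool)
is_artifact[2:5] = True
is_artifact[8] = True

artifact_intervals = intervals_from_mask(is_artifact, timestamps)
valid_times = intervals_from_mask(~is_artifact, timestamps)
print(artifact_intervals)  # two intervals: samples 2-4 and sample 8
print(valid_times)         # three intervals: samples 0-1, 5-7, and 9
```

Applying the same conversion to the mask and to its negation yields complementary sets of intervals, which mirrors how the two return values relate to each other.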

Source code in src/spyglass/lfp/v1/lfp_artifact_MAD_detection.py
import numpy as np
from scipy.stats import median_abs_deviation


def mad_artifact_detector(
    recording: None,
    mad_thresh: float = 6.0,
    proportion_above_thresh: float = 0.1,
    removal_window_ms: float = 10.0,
    sampling_frequency: float = 1000.0,
    *args,
    **kwargs,
) -> tuple[np.ndarray, np.ndarray]:
    """Detect LFP artifacts using the median absolute deviation method.

    Parameters
    ----------
    recording : RecordingExtractor
        The recording extractor object
    mad_thresh : float, optional
        Threshold on the median absolute deviation scaled LFPs, defaults to 6.0
    proportion_above_thresh : float, optional
        Proportion of electrodes that need to be above the threshold, defaults
        to 0.1
    removal_window_ms : float, optional
        Width of the window in milliseconds to mask out per artifact
        (window/2 removed on each side of threshold crossing), defaults to
        10 ms
    sampling_frequency : float, optional
        Sampling frequency of the recording extractor, defaults to 1000.0

    Returns
    -------
    artifact_removed_valid_times : np.ndarray
        Intervals of valid times where artifacts were not detected,
        unit: seconds
    artifact_intervals : np.ndarray
        Intervals in which artifacts are detected (including removal windows),
        unit: seconds
    """

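    # Pull the sample times and the LFP samples out of the recording object
    timestamps = np.asarray(recording.timestamps)
    lfps = np.asarray(recording.data)

    # MAD along axis 0, ignoring NaNs; scale="normal" makes it comparable to
    # a standard deviation for normally distributed data
    mad = median_abs_deviation(lfps, axis=0, nan_policy="omit", scale="normal")
    # Flag samples using mad_thresh and proportion_above_thresh
    is_artifact = _is_above_proportion_thresh(
        _mad_scale_lfps(lfps, mad), mad_thresh, proportion_above_thresh
    )

    # Convert half of the removal window from ms to a whole number of
    # samples (int() truncates), then extend the artifact mask by it
    MILLISECONDS_PER_SECOND = 1000.0
    half_removal_window_s = (removal_window_ms / MILLISECONDS_PER_SECOND) * 0.5
    half_removal_window_idx = int(half_removal_window_s * sampling_frequency)
    is_artifact = _extend_array_by_window(is_artifact, half_removal_window_idx)

    # Time intervals covered by the artifact mask
    artifact_intervals_s = np.array(
        _get_time_intervals_from_bool_array(is_artifact, timestamps)
    )

    # Time intervals outside the artifact mask
    valid_times = np.array(
        _get_time_intervals_from_bool_array(~is_artifact, timestamps)
    )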

    return valid_times, artifact_intervals_s
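
The removal-window arithmetic in the function body can be checked in isolation. The sketch below reproduces the millisecond-to-sample conversion from the source and pairs it with a hypothetical `extend_mask` helper that widens each flagged sample by the half window on both sides; the helper is an illustrative assumption, not the module's private implementation.

```python
import numpy as np

MILLISECONDS_PER_SECOND = 1000.0


def half_window_samples(removal_window_ms: float, sampling_frequency: float) -> int:
    """Same conversion as in the function body: half window, in samples."""
    half_removal_window_s = (removal_window_ms / MILLISECONDS_PER_SECOND) * 0.5
    return int(half_removal_window_s * sampling_frequency)  # int() truncates


def extend_mask(mask: np.ndarray, half_window_idx: int) -> np.ndarray:
    """Hypothetical helper: mark samples within half_window_idx of a True."""
    mask = np.asarray(mask, dtype=bool)
    extended = mask.copy()
    for i in np.flatnonzero(mask):
        # Clip the left edge at 0; slicing clips the right edge automatically
        extended[max(i - half_window_idx, 0) : i + half_window_idx + 1] = True
    return extended


half_default = half_window_samples(10.0, 1000.0)  # defaults: 5 samples per side
half_short = half_window_samples(1.0, 1000.0)     # 0.5 samples truncates to 0

mask = np.zeros(20, dtype=bool)
mask[10] = True
extended = extend_mask(mask, half_default)
print(half_default, half_short, int(extended.sum()))  # 5 0 11
```

Note that because the conversion truncates, a window shorter than two sample periods (for example 1 ms at 1000 Hz) yields a half window of zero samples, so the mask is not widened at all.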