Skip to content

merged_sorting_extractor.py

MergedSortingSegment

Bases: BaseSortingSegment

Source code in src/spyglass/spikesorting/v0/merged_sorting_extractor.py
class MergedSortingSegment(si.BaseSortingSegment):
    """Sorting segment that keeps every unit's spike train in memory.

    Spike trains are registered with ``add_unit`` and returned, optionally
    restricted to a half-open frame window, by ``get_unit_spike_train``.
    """

    def __init__(self):
        """Create an empty segment whose unit spike trains are held in RAM."""
        si.BaseSortingSegment.__init__(self)
        # Maps unit id -> array of spike times; populated via add_unit().
        self._unit_spike_trains: Dict[int, np.ndarray] = {}

    def add_unit(self, unit_id: int, spike_times: np.ndarray):
        """Add (or replace) the spike train for one unit.

        Parameters
        ----------
        unit_id : int
            Identifier of the unit.
        spike_times : array-like
            Spike times for the unit. They are compared against frame bounds
            in ``get_unit_spike_train``, so presumably frame indices (confirm
            with callers).
        """
        # np.asarray lets plain sequences (e.g. lists) support the boolean-mask
        # filtering in get_unit_spike_train; an existing ndarray is stored
        # as-is without copying.
        self._unit_spike_trains[unit_id] = np.asarray(spike_times)

    def get_unit_spike_train(
        self,
        unit_id,
        start_frame: Union[int, None] = None,
        end_frame: Union[int, None] = None,
    ) -> np.ndarray:
        """Return the spike train of ``unit_id``, optionally windowed.

        Parameters
        ----------
        unit_id
            Identifier previously passed to ``add_unit``.
        start_frame : int or None
            Inclusive lower bound; ``None`` leaves the lower side open.
        end_frame : int or None
            Exclusive upper bound; ``None`` leaves the upper side open.

        Returns
        -------
        np.ndarray
            Spike times ``t`` with ``start_frame <= t < end_frame``. When both
            bounds are ``None`` the stored array itself is returned (not a
            copy).

        Raises
        ------
        KeyError
            If ``unit_id`` has not been added.
        """
        spike_times = self._unit_spike_trains[unit_id]
        if start_frame is not None:
            spike_times = spike_times[spike_times >= start_frame]
        if end_frame is not None:
            spike_times = spike_times[spike_times < end_frame]
        return spike_times

__init__()

Store all the unit spike trains in RAM.

Source code in src/spyglass/spikesorting/v0/merged_sorting_extractor.py
def __init__(self):
    """Create an empty segment whose unit spike trains are held in RAM."""
    si.BaseSortingSegment.__init__(self)
    # Maps unit id -> array of spike times; populated via add_unit().
    self._unit_spike_trains: Dict[int, np.ndarray] = {}

add_unit(unit_id, spike_times)

Add a unit spike train.

Source code in src/spyglass/spikesorting/v0/merged_sorting_extractor.py
def add_unit(self, unit_id: int, spike_times: np.ndarray):
    """Add (or replace) the spike train for one unit.

    Parameters
    ----------
    unit_id : int
        Identifier of the unit.
    spike_times : array-like
        Spike times for the unit. They are compared against frame bounds
        when retrieved, so presumably frame indices (confirm with callers).
    """
    # np.asarray lets plain sequences (e.g. lists) support boolean-mask
    # filtering on retrieval; an existing ndarray is stored as-is (no copy).
    self._unit_spike_trains[unit_id] = np.asarray(spike_times)

get_unit_spike_train(unit_id, start_frame=None, end_frame=None)

Get a unit spike train.

Source code in src/spyglass/spikesorting/v0/merged_sorting_extractor.py
def get_unit_spike_train(
    self,
    unit_id,
    start_frame: Union[int, None] = None,
    end_frame: Union[int, None] = None,
) -> np.ndarray:
    """Look up a unit's spike train and clip it to ``[start_frame, end_frame)``.

    Either bound may be ``None``, in which case that side stays open. When both
    bounds are ``None`` the stored array object is returned unchanged. An
    unknown ``unit_id`` raises ``KeyError``.
    """
    train = self._unit_spike_trains[unit_id]

    # Inclusive lower bound.
    if start_frame is not None:
        at_or_after_start = train >= start_frame
        train = train[at_or_after_start]

    # Exclusive upper bound.
    if end_frame is not None:
        before_end = train < end_frame
        train = train[before_end]

    return train