Skip to content

core.py

get_valid_ephys_position_times_from_interval(interval_list_name, nwb_file_name)

Returns the intersection of valid times across ephys and position data.

Finds the intersection of the valid times for the interval list, the valid times for the ephys data, and the valid times for the position data.

Parameters:

Name Type Description Default
interval_list_name str
required
nwb_file_name str
required

Returns:

Name Type Description
valid_ephys_position_times (ndarray, shape(n_valid_times, 2))
Source code in src/spyglass/decoding/v0/core.py
def _position_interval_epoch_number(pos_interval_name: str) -> int:
    """Return the first integer embedded in a position interval name.

    Used as a sort key so position intervals are ordered numerically.
    """
    import re

    # NOTE: the previous implementation used
    # ``name.strip("pos valid time")``. ``str.strip`` removes any of the
    # given *characters* from both ends, not a prefix/suffix, so it only
    # produced the digits by coincidence. Extract the digits explicitly.
    match = re.search(r"\d+", pos_interval_name)
    if match is None:
        raise ValueError(
            f"Position interval name {pos_interval_name!r} does not contain "
            "an epoch number."
        )
    return int(match.group())


def get_valid_ephys_position_times_from_interval(
    interval_list_name: str, nwb_file_name: str
) -> np.ndarray:
    """Returns the intersection of valid times across ephys and position data.

    Finds the intersection of the valid times for the interval list, the
    valid times for the ephys data, and the valid times for the position data.

    Parameters
    ----------
    interval_list_name : str
        Name of the IntervalList entry whose valid times are intersected.
    nwb_file_name : str
        NWB file whose intervals are queried.

    Returns
    -------
    valid_ephys_position_times : np.ndarray, shape (n_valid_times, 2)

    Raises
    ------
    ValueError
        If ``nwb_file_name`` has no RawPosition intervals, or a position
        interval name contains no number to sort by.

    """
    interval_valid_times = (
        IntervalList
        & {
            "nwb_file_name": nwb_file_name,
            "interval_list_name": interval_list_name,
        }
    ).fetch1("valid_times")

    position_interval_names = (
        RawPosition
        & {
            "nwb_file_name": nwb_file_name,
        }
    ).fetch("interval_list_name")
    # Order position intervals numerically by their embedded number
    # (presumably names like "pos <N> valid times" — TODO confirm).
    position_interval_names = sorted(
        position_interval_names, key=_position_interval_epoch_number
    )
    valid_pos_times = [
        (
            IntervalList
            & {
                "nwb_file_name": nwb_file_name,
                "interval_list_name": pos_interval_name,
            }
        ).fetch1("valid_times")
        for pos_interval_name in position_interval_names
    ]
    if not valid_pos_times:
        # np.concatenate([]) would fail with an unhelpful message.
        raise ValueError(
            f"No RawPosition intervals found for {nwb_file_name!r}."
        )

    # Valid times of the raw ephys recording for this file.
    valid_ephys_times = (
        IntervalList
        & {
            "nwb_file_name": nwb_file_name,
            "interval_list_name": "raw data valid times",
        }
    ).fetch1("valid_times")

    return interval_list_intersect(
        interval_list_intersect(interval_valid_times, valid_ephys_times),
        np.concatenate(valid_pos_times),
    )

get_epoch_interval_names(nwb_file_name)

Find the interval names that are epochs.

Parameters:

Name Type Description Default
nwb_file_name str
required

Returns:

Name Type Description
epoch_names list[str]

List of interval names that are epochs.

Source code in src/spyglass/decoding/v0/core.py
def get_epoch_interval_names(nwb_file_name: str) -> list[str]:
    """Find the interval names that are epochs.

    An interval name counts as an epoch when it matches
    ``^(?:\\d+)_(?:\\w+)$``, i.e. digits, an underscore, then word
    characters.

    Parameters
    ----------
    nwb_file_name : str
        NWB file whose IntervalList entries are searched.

    Returns
    -------
    epoch_names : list[str]
        List of interval names that are epochs. Empty when the file has no
        matching intervals (including when it has no intervals at all).
    """
    import re

    epoch_pattern = re.compile(r"^(?:\d+)_(?:\w+)$")

    # Fetch only the name column. Building a DataFrame from an empty query
    # yields a column-less frame, on which ``.interval_list_name`` raises
    # AttributeError.
    interval_names = (
        IntervalList() & {"nwb_file_name": nwb_file_name}
    ).fetch("interval_list_name")

    # isinstance check mirrors the previous ``na=False`` behavior.
    return [
        name
        for name in interval_names
        if isinstance(name, str) and epoch_pattern.search(name)
    ]

get_valid_ephys_position_times_by_epoch(nwb_file_name)

Get the valid ephys position times for each epoch.

Parameters:

Name Type Description Default
nwb_file_name str
required

Returns:

Name Type Description
valid_ephys_position_times_by_epoch dict[str, ndarray]

Dictionary of epoch names and valid ephys position times.

Source code in src/spyglass/decoding/v0/core.py
def get_valid_ephys_position_times_by_epoch(
    nwb_file_name: str,
) -> dict[str, np.ndarray]:
    """Map each epoch interval of a file to its valid ephys/position times.

    Parameters
    ----------
    nwb_file_name : str
        NWB file whose epoch intervals are looked up.

    Returns
    -------
    valid_ephys_position_times_by_epoch : dict[str, np.ndarray]
        Keys are the names returned by ``get_epoch_interval_names``; each
        value is the result of ``get_valid_ephys_position_times_from_interval``
        for that epoch.

    """
    times_by_epoch = {}
    for epoch_name in get_epoch_interval_names(nwb_file_name):
        times_by_epoch[epoch_name] = (
            get_valid_ephys_position_times_from_interval(
                epoch_name, nwb_file_name
            )
        )
    return times_by_epoch

convert_valid_times_to_slice(valid_times)

Converts valid times to a list of slices for easy indexing.

Parameters:

Name Type Description Default
valid_times (ndarray, shape(n_valid_times, 2))

Start and end times for each valid time.

required

Returns:

Name Type Description
valid_time_slices list[slice]
Source code in src/spyglass/decoding/v0/core.py
def convert_valid_times_to_slice(valid_times: np.ndarray) -> list[slice]:
    """Turn each (start, end) row of valid times into a ``slice``.

    Parameters
    ----------
    valid_times : np.ndarray, shape (n_valid_times, 2)
        Start and end times for each valid time. Only the first two entries
        of each row are used.

    Returns
    -------
    valid_time_slices : list[slice]
        One ``slice(start, end)`` per row, in row order.

    """
    time_slices = []
    for interval in valid_times:
        start, end = interval[0], interval[1]
        time_slices.append(slice(start, end))
    return time_slices

create_model_for_multiple_epochs(epoch_names, env_kwargs)

Creates the observation model, environment, and continuous transition types for multiple epochs for decoding.

Parameters:

Name Type Description Default
epoch_names (list[str], length(n_epochs))
required
env_kwargs dict

Environment keyword arguments.

required

Returns:

Name Type Description
observation_models list[ObservationModel]

Observation model for each epoch.

environments list[Environment]

Environment for each epoch.

continuous_transition_types list[list[object]]

Continuous transition types for each epoch.

Source code in src/spyglass/decoding/v0/core.py
def create_model_for_multiple_epochs(
    epoch_names: list[str], env_kwargs: dict
) -> tuple[list[ObservationModel], list[Environment], list[list[object]]]:
    """Creates the observation model, environment, and continuous transition
    types for multiple epochs for decoding.

    For every epoch an ``ObservationModel(epoch)`` and an
    ``Environment(epoch, **env_kwargs)`` are built. The continuous transition
    types form an ``n_epochs x n_epochs`` nested list:
    - Diagonal entries are ``RandomWalk(epoch, use_diffusion=False)``.
    - Off-diagonal entries are ``Uniform(epoch_i, epoch_j)``.

    Parameters
    ----------
    epoch_names : list[str], length (n_epochs)
        Epoch names used to construct each model, environment and transition.
    env_kwargs : dict
        Environment keyword arguments, shared by every epoch's Environment.

    Returns
    -------
    observation_models : list[ObservationModel]
        Observation model for each epoch.
    environments : list[Environment]
        Environment for each epoch.
    continuous_transition_types : list[list[object]]
        Continuous transition types for each epoch. Entry ``[i][j]`` is
        built from ``(epoch_names[i], epoch_names[j])``.

    """
    observation_models = []
    environments = []
    # Keep model/environment construction interleaved per epoch, as before.
    for epoch in epoch_names:
        observation_models.append(ObservationModel(epoch))
        environments.append(Environment(epoch, **env_kwargs))

    # Row = first epoch, column = second epoch; same construction order as
    # the original nested loops.
    continuous_transition_types = [
        [
            RandomWalk(epoch1, use_diffusion=False)
            if epoch1 == epoch2
            else Uniform(epoch1, epoch2)
            for epoch2 in epoch_names
        ]
        for epoch1 in epoch_names
    ]

    return observation_models, environments, continuous_transition_types