Skip to content

common_dio.py

DIOEvents

Bases: SpyglassMixin, Imported

Source code in src/spyglass/common/common_dio.py
@schema
class DIOEvents(SpyglassMixin, dj.Imported):
    """Table of digital input/output (DIO) event series for a session.

    ``make`` inserts one row per time series found in the
    ``behavioral_events`` data interface of the session's NWB file.
    """

    definition = """
    -> Session
    dio_event_name: varchar(80)   # the name assigned to this DIO event
    ---
    dio_object_id: varchar(40)    # the object id of the data in the NWB file
    -> IntervalList               # the list of intervals for this object
    """

    _nwb_table = Nwbfile

    def make(self, key):
        """Insert one row per behavioral-event time series in the NWB file.

        Parameters
        ----------
        key : dict
            Must contain ``"nwb_file_name"``. It is updated in place with
            ``interval_list_name``, ``dio_event_name`` and ``dio_object_id``
            before each insert.

        Notes
        -----
        If the NWB file has no ``behavioral_events`` interface of type
        ``pynwb.behavior.BehavioralEvents``, a warning is logged and nothing
        is inserted.
        """
        nwb_file_name = key["nwb_file_name"]
        nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name)
        nwbf = get_nwb_file(nwb_file_abspath)

        behav_events = get_data_interface(
            nwbf, "behavioral_events", pynwb.behavior.BehavioralEvents
        )
        if behav_events is None:
            # ``Logger.warn`` is a deprecated alias (removed in Python 3.13).
            logger.warning(
                "No conforming behavioral events data interface found in "
                + f"{nwb_file_name}\n"
            )
            return

        # Times for these events correspond to the valid times for the raw data
        key["interval_list_name"] = (
            Raw() & {"nwb_file_name": nwb_file_name}
        ).fetch1("interval_list_name")
        for event_series in behav_events.time_series.values():
            key["dio_event_name"] = event_series.name
            key["dio_object_id"] = event_series.object_id
            self.insert1(key, skip_duplicates=True)

    def plot_all_dio_events(self, return_fig=False):
        """Plot all DIO events in the session.

        Each event series is drawn as a step trace on its own row, with the
        valid-time span of every interval list whose name starts with a digit
        shaded behind it. Interval names are written above the first row.

        Parameters
        ----------
        return_fig : bool, optional
            If True, return the created matplotlib figure. Default False.

        Returns
        -------
        matplotlib.figure.Figure or None
            The figure when ``return_fig`` is True, otherwise None.

        Raises
        ------
        ValueError
            If the current restriction contains no DIO events.

        Examples
        --------
        > restr1 = {'nwb_file_name': 'arthur20220314_.nwb'}
        > restr2 = {'nwb_file_name': 'arthur20220316_.nwb'}
        > (DIOEvents & restr1).plot_all_dio_events()
        > (DIOEvents & [restr1, restr2]).plot_all_dio_events()

        """
        behavioral_events = self.fetch_nwb()
        n_events = len(behavioral_events)
        if n_events == 0:
            # Fail early with a clear message instead of an obscure error
            # from the interval lookup or from ``plt.subplots(0, 1)``.
            raise ValueError("No DIO events found for the current restriction")

        nwb_file_names = np.unique(
            [event["nwb_file_name"] for event in behavioral_events]
        )
        epoch_valid_times = (
            pd.DataFrame(
                IntervalList()
                & [
                    {"nwb_file_name": nwb_file_name}
                    for nwb_file_name in nwb_file_names
                ]
            )
            .set_index("interval_list_name")
            .filter(regex=r"^[0-9]", axis=0)
            .valid_times
        )

        # squeeze=False always yields a 2-D array of Axes; without it a single
        # event returns a bare Axes object and ``axes.flat`` fails.
        fig, axes = plt.subplots(
            n_events,
            1,
            figsize=(15, n_events * 0.3),
            dpi=100,
            sharex=True,
            constrained_layout=True,
            squeeze=False,
        )

        for ind, (ax, event) in enumerate(zip(axes.flat, behavioral_events)):
            for epoch_name, epoch in epoch_valid_times.items():
                # assumes each of these interval lists holds a single
                # (start, stop) pair -- TODO confirm
                start_time, stop_time = epoch.squeeze()
                ax.axvspan(start_time, stop_time, alpha=0.5)
                if ind == 0:
                    # Label each interval once, above the top row only.
                    ax.text(
                        start_time + (stop_time - start_time) / 2,
                        1.001,
                        epoch_name,
                        ha="center",
                        va="bottom",
                    )
            ax.step(
                np.asarray(event["dio"].timestamps),
                np.asarray(event["dio"].data),
                where="post",
                color="black",
            )
            ax.set_ylabel(
                event["dio_event_name"], rotation=0, ha="right", va="center"
            )
            ax.set_yticks([])
        # The x axis is shared, so only the bottom row is labelled.
        axes.flat[-1].set_xlabel("Time")

        # Joining a single name gives that name, so one branch is enough.
        fig.suptitle(f"DIO events in {', '.join(nwb_file_names)}")

        if return_fig:
            return fig

plot_all_dio_events(return_fig=False)

Plot all DIO events in the session.

Examples:

restr1 = {'nwb_file_name': 'arthur20220314_.nwb'}
restr2 = {'nwb_file_name': 'arthur20220316_.nwb'}
(DIOEvents & restr1).plot_all_dio_events()
(DIOEvents & [restr1, restr2]).plot_all_dio_events()

Source code in src/spyglass/common/common_dio.py
def plot_all_dio_events(self, return_fig=False):
    """Plot all DIO events in the session.

    Each event series is drawn as a step trace on its own row, with the
    valid-time span of every interval list whose name starts with a digit
    shaded behind it. Interval names are written above the first row.

    Parameters
    ----------
    return_fig : bool, optional
        If True, return the created matplotlib figure. Default False.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure when ``return_fig`` is True, otherwise None.

    Raises
    ------
    ValueError
        If the current restriction contains no DIO events.

    Examples
    --------
    > restr1 = {'nwb_file_name': 'arthur20220314_.nwb'}
    > restr2 = {'nwb_file_name': 'arthur20220316_.nwb'}
    > (DIOEvents & restr1).plot_all_dio_events()
    > (DIOEvents & [restr1, restr2]).plot_all_dio_events()

    """
    behavioral_events = self.fetch_nwb()
    n_events = len(behavioral_events)
    if n_events == 0:
        # Fail early with a clear message instead of an obscure error from
        # the interval lookup or from ``plt.subplots(0, 1)``.
        raise ValueError("No DIO events found for the current restriction")

    nwb_file_names = np.unique(
        [event["nwb_file_name"] for event in behavioral_events]
    )
    epoch_valid_times = (
        pd.DataFrame(
            IntervalList()
            & [
                {"nwb_file_name": nwb_file_name}
                for nwb_file_name in nwb_file_names
            ]
        )
        .set_index("interval_list_name")
        .filter(regex=r"^[0-9]", axis=0)
        .valid_times
    )

    # squeeze=False always yields a 2-D array of Axes; without it a single
    # event returns a bare Axes object and ``axes.flat`` fails.
    fig, axes = plt.subplots(
        n_events,
        1,
        figsize=(15, n_events * 0.3),
        dpi=100,
        sharex=True,
        constrained_layout=True,
        squeeze=False,
    )

    for ind, (ax, event) in enumerate(zip(axes.flat, behavioral_events)):
        for epoch_name, epoch in epoch_valid_times.items():
            # assumes each of these interval lists holds a single
            # (start, stop) pair -- TODO confirm
            start_time, stop_time = epoch.squeeze()
            ax.axvspan(start_time, stop_time, alpha=0.5)
            if ind == 0:
                # Label each interval once, above the top row only.
                ax.text(
                    start_time + (stop_time - start_time) / 2,
                    1.001,
                    epoch_name,
                    ha="center",
                    va="bottom",
                )
        ax.step(
            np.asarray(event["dio"].timestamps),
            np.asarray(event["dio"].data),
            where="post",
            color="black",
        )
        ax.set_ylabel(
            event["dio_event_name"], rotation=0, ha="right", va="center"
        )
        ax.set_yticks([])
    # The x axis is shared, so only the bottom row is labelled.
    axes.flat[-1].set_xlabel("Time")

    # Joining a single name gives that name, so one branch is enough.
    fig.suptitle(f"DIO events in {', '.join(nwb_file_names)}")

    if return_fig:
        return fig