Skip to content

common_task.py

Task

Bases: SpyglassMixin, Manual

Source code in src/spyglass/common/common_task.py
@schema
class Task(SpyglassMixin, dj.Manual):
    # Task names with optional description, type and subtype metadata
    definition = """
     task_name: varchar(80)
     ---
     task_description = NULL: varchar(2000)    # description of this task
     task_type = NULL: varchar(2000)           # type of task
     task_subtype = NULL: varchar(2000)        # subtype of task
     """

    @classmethod
    def insert_from_nwbfile(cls, nwbf: pynwb.NWBFile):
        """Insert tasks from an NWB file.

        Every data interface in the file's ``tasks`` processing module that
        passes ``check_task_table`` is inserted; other interfaces are ignored.
        If the module is missing, a warning is logged and nothing is inserted.

        Parameters
        ----------
        nwbf : pynwb.NWBFile
            The source NWB file object.
        """
        tasks_mod = nwbf.processing.get("tasks")
        if tasks_mod is None:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning(f"No tasks processing module found in {nwbf}\n")
            return
        for task in tasks_mod.data_interfaces.values():
            if cls.check_task_table(task):
                cls.insert_from_task_table(task)

    @classmethod
    def insert_from_task_table(cls, task_table: pynwb.core.DynamicTable):
        """Insert tasks from a pynwb DynamicTable containing task metadata.

        Only the ``task_name`` and ``task_description`` columns are read.
        Duplicate tasks will not be added (``skip_duplicates=True``).

        Parameters
        ----------
        task_table : pynwb.core.DynamicTable
            The table representing task metadata, one task per row.
        """
        taskdf = task_table.to_dataframe()
        task_dicts = [
            dict(task_name=row.task_name, task_description=row.task_description)
            for row in taskdf.itertuples(index=False)
        ]
        cls.insert(task_dicts, skip_duplicates=True)

    @classmethod
    def check_task_table(cls, task_table: pynwb.core.DynamicTable) -> bool:
        """Check format of pynwb DynamicTable containing task metadata.

        The table should be an instance of pynwb.core.DynamicTable and contain
        the columns 'task_name' and 'task_description'.

        Parameters
        ----------
        task_table : pynwb.core.DynamicTable
            The table representing task metadata.

        Returns
        -------
        bool
            Whether the DynamicTable conforms to the expected format for loading
            data into the Task table.
        """
        return (
            isinstance(task_table, pynwb.core.DynamicTable)
            and hasattr(task_table, "task_name")
            and hasattr(task_table, "task_description")
        )

insert_from_nwbfile(nwbf) classmethod

Insert tasks from an NWB file.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `nwbf` | `NWBFile` | The source NWB file object. | *required* |
Source code in src/spyglass/common/common_task.py
@classmethod
def insert_from_nwbfile(cls, nwbf: pynwb.NWBFile):
    """Insert tasks from an NWB file.

    Every data interface in the file's ``tasks`` processing module that
    passes ``check_task_table`` is inserted; other interfaces are ignored.
    If the module is missing, a warning is logged and nothing is inserted.

    Parameters
    ----------
    nwbf : pynwb.NWBFile
        The source NWB file object.
    """
    tasks_mod = nwbf.processing.get("tasks")
    if tasks_mod is None:
        # Logger.warn is a deprecated alias; use Logger.warning.
        logger.warning(f"No tasks processing module found in {nwbf}\n")
        return
    for task in tasks_mod.data_interfaces.values():
        if cls.check_task_table(task):
            cls.insert_from_task_table(task)

insert_from_task_table(task_table) classmethod

Insert tasks from a pynwb DynamicTable containing task metadata.

Duplicate tasks will not be added.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `task_table` | `DynamicTable` | The table representing task metadata. | *required* |
Source code in src/spyglass/common/common_task.py
@classmethod
def insert_from_task_table(cls, task_table: pynwb.core.DynamicTable):
    """Add every task described in a pynwb DynamicTable to the Task table.

    Only the ``task_name`` and ``task_description`` columns are read. Rows
    already present in the table are skipped rather than raising.

    Parameters
    ----------
    task_table : pynwb.core.DynamicTable
        Table of task metadata, one task per row.
    """
    rows = [
        dict(task_name=rec.task_name, task_description=rec.task_description)
        for rec in task_table.to_dataframe().itertuples(index=False)
    ]
    cls.insert(rows, skip_duplicates=True)

check_task_table(task_table) classmethod

Check format of pynwb DynamicTable containing task metadata.

The table should be an instance of pynwb.core.DynamicTable and contain the columns 'task_name' and 'task_description'.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `task_table` | `DynamicTable` | The table representing task metadata. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `bool` | Whether the DynamicTable conforms to the expected format for loading data into the Task table. |

Source code in src/spyglass/common/common_task.py
@classmethod
def check_task_table(cls, task_table: pynwb.core.DynamicTable) -> bool:
    """Report whether ``task_table`` can be loaded into the Task table.

    A loadable table is a pynwb.core.DynamicTable that exposes both a
    'task_name' and a 'task_description' column.

    Parameters
    ----------
    task_table : pynwb.core.DynamicTable
        The table representing task metadata.

    Returns
    -------
    bool
        True when the table has the expected type and columns, else False.
    """
    if not isinstance(task_table, pynwb.core.DynamicTable):
        return False
    required_columns = ("task_name", "task_description")
    return all(hasattr(task_table, column) for column in required_columns)

TaskEpoch

Bases: SpyglassMixin, Imported

Source code in src/spyglass/common/common_task.py
@schema
class TaskEpoch(SpyglassMixin, dj.Imported):
    # Tasks, session and time intervals
    definition = """
     -> Session
     epoch: int  # the session epoch for this task and apparatus(1 based)
     ---
     -> Task
     -> [nullable] CameraDevice
     -> IntervalList
     task_environment = NULL: varchar(200)  # the environment the animal was in
     camera_names : blob # list of keys corresponding to entry in CameraDevice
     """

    def make(self, key):
        """Populate TaskEpoch from the processing module in the NWB file.

        Epochs are collected from two optional sources: the NWB file's
        ``tasks`` processing module and the ``Tasks`` entry of the session
        config. Each epoch is matched to a single IntervalList name for the
        session; epochs without a unique match are skipped with a warning.

        Parameters
        ----------
        key : dict
            Key of the entry being populated; must contain ``nwb_file_name``.
        """
        nwb_file_name = key["nwb_file_name"]
        nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
        nwbf = get_nwb_file(nwb_file_abspath)
        config = get_config(nwb_file_abspath, calling_table=self.camel_name)

        # The tasks refer to a camera_id that is unique within the NWB file but
        # not within the CameraDevice table, so map camera ID -> camera_name
        # (the CameraDevice primary key).
        camera_names = dict()
        for device in nwbf.devices.values():
            if isinstance(device, ndx_franklab_novela.CameraDevice):
                # the ID is the second whitespace-separated token of the name
                camera_id = int(str.split(device.name)[1])
                camera_names[camera_id] = device.camera_name
        for device in config.get("CameraDevice") or []:
            ids = device.get("camera_id", -1)
            names = device.get("camera_name")
            # Keep the ID -> name direction used by the lookups below; accept
            # one camera per entry or parallel lists of IDs and names.
            if isinstance(ids, (list, tuple)):
                camera_names.update(zip(ids, names))
            else:
                camera_names[ids] = names

        tasks_mod = nwbf.processing.get("tasks")
        config_tasks = config.get("Tasks")
        if tasks_mod is None and config_tasks is None:
            logger.warning(
                f"No tasks processing module found in {nwbf} or config\n"
            )
            return

        # get the interval lists for this session; each epoch is matched to
        # the raw-data interval with the same number. Users should define
        # more restrictive intervals as required for analyses.
        session_intervals = (
            IntervalList() & {"nwb_file_name": nwb_file_name}
        ).fetch("interval_list_name")

        def camera_keys_for(camera_ids):
            """CameraDevice keys for the IDs that have a known camera name."""
            return [
                {"camera_name": camera_names[camera_id]}
                for camera_id in camera_ids
                if camera_id in camera_names
            ]

        def epoch_rows(task_key, epochs):
            """One insert row per epoch that maps to a unique interval."""
            rows = []
            for epoch in epochs:
                target_interval = self.get_epoch_interval_name(
                    epoch, session_intervals, nwb_file_name=nwb_file_name
                )
                if target_interval is None:
                    logger.warning("Skipping epoch.")
                    continue
                rows.append(
                    {
                        **task_key,
                        "epoch": epoch,
                        "interval_list_name": target_interval,
                    }
                )
            return rows

        task_inserts = []

        # Tasks from the NWB file. Each task starts from the session key so
        # cameras/environment set for one task cannot leak into the next.
        nwb_tasks = (
            tasks_mod.data_interfaces.values() if tasks_mod is not None else []
        )
        for task in nwb_tasks:
            if not self.check_task_table(task):
                continue
            # add the task to the Task table if it isn't there already
            Task.insert_from_task_table(task)
            task_key = {**key, "task_name": task.task_name[0]}

            camera_ids = task.camera_id[0]
            if camera_keys := camera_keys_for(camera_ids):
                task_key["camera_names"] = camera_keys
            else:
                logger.warning(
                    f"No camera device found with ID {camera_ids} in NWB "
                    + f"file {nwbf}\n"
                )
            if hasattr(task, "task_environment"):
                task_key["task_environment"] = task.task_environment[0]

            # TODO in beans file, task_epochs[0] is 1x2 dset of ints,
            # so epoch would be an int
            task_inserts.extend(epoch_rows(task_key, task.task_epochs[0]))

        # Tasks from the config; each row is built from that task's own key.
        for task in config_tasks or []:
            task_key = {
                **key,
                "task_name": task.get("task_name"),
                "task_environment": task.get("task_environment", None),
            }
            if camera_keys := camera_keys_for(task.get("camera_id", [])):
                task_key["camera_names"] = camera_keys
            task_inserts.extend(
                epoch_rows(task_key, task.get("task_epochs", []))
            )

        self.insert(task_inserts, allow_direct_insert=True)

    @classmethod
    def get_epoch_interval_name(
        cls, epoch, session_intervals, nwb_file_name=None
    ):
        """Get the interval name for a given epoch based on matching number.

        Parameters
        ----------
        epoch : int
            Epoch number; zero-padded to two digits and matched as a
            substring of each interval name.
        session_intervals : iterable of str
            Candidate interval list names for the session.
        nwb_file_name : str, optional
            Only used to add context to warning messages.

        Returns
        -------
        str or None
            The single matching interval name, or None (after logging a
            warning) when there is no match or more than one match.
        """
        target_interval = str(epoch).zfill(2)
        possible_targets = [
            interval
            for interval in session_intervals
            if target_interval in interval
        ]
        # The file name is optional context; omit it when not supplied.
        source = f" in {nwb_file_name}" if nwb_file_name else ""
        if not possible_targets:
            logger.warning(f"Interval not found for epoch {epoch}{source}.")
            return None
        if len(possible_targets) > 1:
            logger.warning(
                f"Multiple intervals found for epoch {epoch}{source}. "
                + f"matches are {possible_targets}."
            )
            return None
        return possible_targets[0]

    @classmethod
    def update_entries(cls, restrict=True):
        """Update entries in the TaskEpoch table based on a restriction.

        For each entry matching ``restrict`` whose ``camera_names`` is falsy,
        set ``camera_names`` to a one-element list holding the entry's
        ``camera_name`` and write it back with ``update1``.
        """
        existing_entries = (cls & restrict).fetch("KEY")
        for row in existing_entries:
            if (cls & row).fetch1("camera_names"):
                continue
            row["camera_names"] = [
                {"camera_name": (cls & row).fetch1("camera_name")}
            ]
            cls.update1(row=row)

    @classmethod
    def check_task_table(cls, task_table: pynwb.core.DynamicTable) -> bool:
        """Check format of pynwb DynamicTable containing task metadata.

        The table should be an instance of pynwb.core.DynamicTable and contain
        the columns 'task_name', 'task_description', 'camera_id', and
        'task_epochs'.

        Parameters
        ----------
        task_table : pynwb.core.DynamicTable
            The table representing task metadata.

        Returns
        -------
        bool
            Whether the DynamicTable conforms to the expected format for
            loading data into the TaskEpoch table.
        """

        # TODO this could be more strict and check data types, but really it
        # should be schematized
        return (
            Task.check_task_table(task_table)
            and hasattr(task_table, "camera_id")
            and hasattr(task_table, "task_epochs")
        )

make(key)

Populate TaskEpoch from the processing module in the NWB file.

Source code in src/spyglass/common/common_task.py
def make(self, key):
    """Populate TaskEpoch from the processing module in the NWB file.

    Epochs are collected from two optional sources: the NWB file's
    ``tasks`` processing module and the ``Tasks`` entry of the session
    config. Each epoch is matched to a single IntervalList name for the
    session; epochs without a unique match are skipped with a warning.

    Parameters
    ----------
    key : dict
        Key of the entry being populated; must contain ``nwb_file_name``.
    """
    nwb_file_name = key["nwb_file_name"]
    nwb_file_abspath = Nwbfile().get_abs_path(nwb_file_name)
    nwbf = get_nwb_file(nwb_file_abspath)
    config = get_config(nwb_file_abspath, calling_table=self.camel_name)

    # The tasks refer to a camera_id that is unique within the NWB file but
    # not within the CameraDevice table, so map camera ID -> camera_name
    # (the CameraDevice primary key).
    camera_names = dict()
    for device in nwbf.devices.values():
        if isinstance(device, ndx_franklab_novela.CameraDevice):
            # the ID is the second whitespace-separated token of the name
            camera_id = int(str.split(device.name)[1])
            camera_names[camera_id] = device.camera_name
    for device in config.get("CameraDevice") or []:
        ids = device.get("camera_id", -1)
        names = device.get("camera_name")
        # Keep the ID -> name direction used by the lookups below; accept
        # one camera per entry or parallel lists of IDs and names.
        if isinstance(ids, (list, tuple)):
            camera_names.update(zip(ids, names))
        else:
            camera_names[ids] = names

    tasks_mod = nwbf.processing.get("tasks")
    config_tasks = config.get("Tasks")
    if tasks_mod is None and config_tasks is None:
        logger.warning(
            f"No tasks processing module found in {nwbf} or config\n"
        )
        return

    # get the interval lists for this session; each epoch is matched to
    # the raw-data interval with the same number. Users should define
    # more restrictive intervals as required for analyses.
    session_intervals = (
        IntervalList() & {"nwb_file_name": nwb_file_name}
    ).fetch("interval_list_name")

    def camera_keys_for(camera_ids):
        """CameraDevice keys for the IDs that have a known camera name."""
        return [
            {"camera_name": camera_names[camera_id]}
            for camera_id in camera_ids
            if camera_id in camera_names
        ]

    def epoch_rows(task_key, epochs):
        """One insert row per epoch that maps to a unique interval."""
        rows = []
        for epoch in epochs:
            target_interval = self.get_epoch_interval_name(
                epoch, session_intervals
            )
            if target_interval is None:
                logger.warning("Skipping epoch.")
                continue
            rows.append(
                {
                    **task_key,
                    "epoch": epoch,
                    "interval_list_name": target_interval,
                }
            )
        return rows

    task_inserts = []

    # Tasks from the NWB file. Each task starts from the session key so
    # cameras/environment set for one task cannot leak into the next.
    nwb_tasks = (
        tasks_mod.data_interfaces.values() if tasks_mod is not None else []
    )
    for task in nwb_tasks:
        if not self.check_task_table(task):
            continue
        # add the task to the Task table if it isn't there already
        Task.insert_from_task_table(task)
        task_key = {**key, "task_name": task.task_name[0]}

        camera_ids = task.camera_id[0]
        if camera_keys := camera_keys_for(camera_ids):
            task_key["camera_names"] = camera_keys
        else:
            logger.warning(
                f"No camera device found with ID {camera_ids} in NWB "
                + f"file {nwbf}\n"
            )
        if hasattr(task, "task_environment"):
            task_key["task_environment"] = task.task_environment[0]

        # TODO in beans file, task_epochs[0] is 1x2 dset of ints,
        # so epoch would be an int
        task_inserts.extend(epoch_rows(task_key, task.task_epochs[0]))

    # Tasks from the config; each row is built from that task's own key.
    for task in config_tasks or []:
        task_key = {
            **key,
            "task_name": task.get("task_name"),
            "task_environment": task.get("task_environment", None),
        }
        if camera_keys := camera_keys_for(task.get("camera_id", [])):
            task_key["camera_names"] = camera_keys
        task_inserts.extend(epoch_rows(task_key, task.get("task_epochs", [])))

    self.insert(task_inserts, allow_direct_insert=True)

get_epoch_interval_name(epoch, session_intervals) classmethod

Get the interval name for a given epoch based on matching number

Source code in src/spyglass/common/common_task.py
@classmethod
def get_epoch_interval_name(cls, epoch, session_intervals, nwb_file_name=None):
    """Get the interval name for a given epoch based on matching number.

    Parameters
    ----------
    epoch : int
        Epoch number; zero-padded to two digits and matched as a
        substring of each interval name.
    session_intervals : iterable of str
        Candidate interval list names for the session.
    nwb_file_name : str, optional
        Only used to add context to warning messages. (Previously this name
        was referenced without being defined, so every warning path raised
        NameError.)

    Returns
    -------
    str or None
        The single matching interval name, or None (after logging a
        warning) when there is no match or more than one match.
    """
    target_interval = str(epoch).zfill(2)
    possible_targets = [
        interval
        for interval in session_intervals
        if target_interval in interval
    ]
    # The file name is optional context; omit it when not supplied.
    source = f" in {nwb_file_name}" if nwb_file_name else ""
    if not possible_targets:
        logger.warning(f"Interval not found for epoch {epoch}{source}.")
        return None
    if len(possible_targets) > 1:
        logger.warning(
            f"Multiple intervals found for epoch {epoch}{source}. "
            + f"matches are {possible_targets}."
        )
        return None
    return possible_targets[0]

update_entries(restrict=True) classmethod

Update entries in the TaskEpoch table based on a restriction.

Source code in src/spyglass/common/common_task.py
@classmethod
def update_entries(cls, restrict=True):
    """Backfill empty ``camera_names`` from the ``camera_name`` attribute.

    For every TaskEpoch entry matching ``restrict`` whose ``camera_names``
    value is falsy, store a one-element list holding that entry's
    ``camera_name`` and write it back with ``update1``.

    Parameters
    ----------
    restrict : optional
        Restriction applied to the table; the default True selects all rows.
    """
    for entry_key in (cls & restrict).fetch("KEY"):
        if not (cls & entry_key).fetch1("camera_names"):
            camera_name = (cls & entry_key).fetch1("camera_name")
            entry_key["camera_names"] = [{"camera_name": camera_name}]
            cls.update1(row=entry_key)

check_task_table(task_table) classmethod

Check format of pynwb DynamicTable containing task metadata.

The table should be an instance of pynwb.core.DynamicTable and contain the columns 'task_name', 'task_description', 'camera_id', and 'task_epochs'.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `task_table` | `DynamicTable` | The table representing task metadata. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `bool` | Whether the DynamicTable conforms to the expected format for loading data into the TaskEpoch table. |

Source code in src/spyglass/common/common_task.py
@classmethod
def check_task_table(cls, task_table: pynwb.core.DynamicTable) -> bool:
    """Report whether ``task_table`` can be loaded into the TaskEpoch table.

    A loadable table first satisfies ``Task.check_task_table`` (a
    pynwb.core.DynamicTable with 'task_name' and 'task_description'
    columns) and additionally exposes 'camera_id' and 'task_epochs'.

    Parameters
    ----------
    task_table : pynwb.core.DynamicTable
        The table representing task metadata.

    Returns
    -------
    bool
        True when the table has the expected type and columns, else False.
    """
    # TODO this could be more strict and check data types, but really it
    # should be schematized
    if not Task.check_task_table(task_table):
        return False
    epoch_columns = ("camera_id", "task_epochs")
    return all(hasattr(task_table, column) for column in epoch_columns)