Skip to content

export.py

ExportMixin

Source code in src/spyglass/utils/mixins/export.py
class ExportMixin:
    """Mixin that logs DataJoint reads for paper/analysis exports.

    While an export is in progress (a non-zero ``export_id``, stored in the
    ``EXPORT_ENV_VAR`` environment variable), the intercepted ``fetch``,
    ``fetch1``, ``restrict`` and ``join`` methods record each table and
    restriction in ``ExportSelection.Table``. ``_log_fetch_nwb`` additionally
    records analysis file names in ``ExportSelection.File``.
    """

    # Restrictions already logged in the current export, keyed by
    # full_table_name. Class-level, so it is shared by every instance of every
    # table using this mixin; _export_id_cleanup clears it in place.
    _export_cache = defaultdict(set)

    # ------------------------------ Version Info -----------------------------

    @cached_property
    def _spyglass_version(self):
        """Get the Spyglass version, trimmed to ``Major.Minor.Patch``.

        Segments after the third dot (e.g. commit info) are dropped. In test
        mode the value is returned unvalidated and truncated to 16 characters.

        Raises
        ------
        ValueError
            Outside test mode, if the version does not start with ``#.#.#``.
        """
        from spyglass import __version__ as sg_version

        ret = ".".join(sg_version.split(".")[:3])  # Ditch commit info

        if self._test_mode:
            return ret[:16]  # slicing is a no-op for shorter strings

        if not re_match(r"^\d+\.\d+\.\d+", ret):  # Major.Minor.Patch
            raise ValueError(
                f"Spyglass version issues. Expected #.#.#, Got {ret}. "
                + "Please try running `hatch build` from your spyglass dir."
            )

        return ret

    def compare_versions(
        self, version: str, other: str = None, msg: str = None
    ) -> None:
        """Compare two versions. Raise error if not equal.

        Parameters
        ----------
        version : str
            Version to compare.
        other : str, optional
            Other version to compare. Default None, which uses
            self._spyglass_version.
        msg : str, optional
            Additional error message info, appended on its own line.
            Default None (nothing appended).

        Raises
        ------
        RuntimeError
            If the parsed versions differ. The check is skipped in test mode.
        """
        if self._test_mode:
            return

        other = other or self._spyglass_version

        if version_parse(version) != version_parse(other):
            # Only append msg when provided; None used to render as "None".
            detail = f"\n{msg}" if msg else ""
            raise RuntimeError(
                f"Found mismatched versions: {version} vs {other}{detail}"
            )

    # ------------------------------- Dependency -------------------------------

    @cached_property
    def _export_table(self):
        """Lazy load export selection table."""
        from spyglass.common.common_usage import ExportSelection

        return ExportSelection()

    # ------------------------------ ID Property ------------------------------

    @property
    def export_id(self):
        """ID of export in progress (0 when no export is running).

        NOTE: Using an environment variable to store export_id may not be
        thread safe. Exports must be run in sequence, not in parallel.
        """
        return int(environ.get(EXPORT_ENV_VAR, 0))

    @export_id.setter
    def export_id(self, value):
        """Set ID of export using `table.export_id = X` notation.

        Raises
        ------
        RuntimeError
            If a different export is already in progress.
        """
        if self.export_id != 0 and self.export_id != value:
            raise RuntimeError("Export already in progress.")
        environ[EXPORT_ENV_VAR] = str(value)
        exit_register(self._export_id_cleanup)  # End export on exit

    @export_id.deleter
    def export_id(self):
        """Delete ID of export using `del table.export_id` notation."""
        self._export_id_cleanup()

    def _export_id_cleanup(self):
        """End the current export: clear the log cache and the export ID."""
        # Clear the shared class-level cache in place. Rebinding it on the
        # instance (the old `self._export_cache = dict()`) left the shared
        # cache stale, so later exports skipped previously seen
        # restrictions. It also made `self._export_cache[name]` raise KeyError
        # on this instance.
        self._export_cache.clear()
        environ.pop(EXPORT_ENV_VAR, None)
        exit_unregister(self._export_id_cleanup)  # Remove exit hook

    # ------------------------------- Export API -------------------------------

    def _start_export(self, paper_id, analysis_id):
        """Start export process, ending any export already in progress.

        Inserts a row into ExportSelection and uses its primary key as the
        new export_id.
        """
        if self.export_id:
            logger.info(f"Export {self.export_id} in progress. Starting new.")
            self._stop_export(warn=False)

        self.export_id = self._export_table.insert1_return_pk(
            dict(
                paper_id=paper_id,
                analysis_id=analysis_id,
                spyglass_version=self._spyglass_version,
            )
        )

    def _stop_export(self, warn=True):
        """End export process, warning if none is in progress and warn=True."""
        if not self.export_id and warn:
            logger.warning("Export not in progress.")
        del self.export_id

    # ------------------------------- Log Fetch -------------------------------

    def _called_funcs(self):
        """Return names of functions on the current call stack.

        Interpreter/IPython plumbing frames and this mixin's own logging
        helpers are excluded.
        """
        ignore = {
            "__and__",  # caught by restrict
            "__mul__",  # caught by join
            "_called_funcs",  # run here
            "_log_fetch",  # run here
            "_log_fetch_nwb",  # run here
            "<module>",
            "_exec_file",
            "_pseudo_sync_runner",
            "_run_cell",
            "_run_cmd_line_code",
            "_run_with_log",
            "execfile",
            "init_code",
            "initialize",
            "inner",
            "interact",
            "launch_instance",
            "mainloop",
            "run",
            "run_ast_nodes",
            "run_cell",
            "run_cell_async",
            "run_code",
            "run_line_magic",
            "safe_execfile",
            "start",
            "start_ipython",
        }

        ret = {i.function for i in inspect_stack()} - ignore
        return ret

    def _log_fetch(self, restriction=None, *args, **kwargs):
        """Record a table/restriction pair in ``ExportSelection.Table``.

        Does nothing when no export is in progress, when this table is in the
        ``common_usage`` database, when FETCH_LOG_FLAG is unset, or when
        called from a banned function (previews, deletes, permission checks).
        If ``limit``/``offset`` are passed, the fetched rows themselves become
        the restriction. Each (table, restriction) pair is inserted at most
        once per export, tracked in ``_export_cache``.

        Raises
        ------
        RuntimeError
            If the restriction contains a subquery or exceeds 2048 characters.
        """
        if (
            not self.export_id
            or self.database == "common_usage"
            or not FETCH_LOG_FLAG.get()
        ):
            return

        banned = [
            "head",  # Prevents on Table().head() call
            "tail",  # Prevents on Table().tail() call
            "preview",  # Prevents on Table() call
            "_repr_html_",  # Prevents on Table() call in notebook
            "cautious_delete",  # Prevents add on permission check during delete
            # "get_abs_path",  # Assumes that fetch_nwb will catch file/table
            "_check_delete_permission",  # Prevents on Table().delete()
            "delete",  # Prevents on Table().delete()
            "_load_admin",  # Prevents on permission check
        ]  # if called by any in banned, return
        if set(banned) & self._called_funcs():
            return

        restr = restriction or self.restriction or True
        limit = kwargs.get("limit")
        offset = kwargs.get("offset")
        if limit or offset:  # Use result as restr if limit/offset
            restr = self.restrict(restr).fetch(
                log_export=False, as_dict=True, limit=limit, offset=offset
            )

        restr_str = make_condition(self, restr, set())

        if restr_str is True:
            restr_str = "True"  # otherwise stored in table as '1'

        if isinstance(restr_str, str) and "SELECT" in restr_str:
            raise RuntimeError(
                "Export cannot handle subquery restrictions. Please submit a "
                + "bug report on GitHub with the code you ran and this "
                + f"restriction:\n\t{restr_str}"
            )

        if isinstance(restr_str, str) and len(restr_str) > 2048:
            raise RuntimeError(
                "Export cannot handle restrictions > 2048.\n\t"
                + "If required, please open an issue on GitHub.\n\t"
                + f"Restriction: {restr_str}"
            )

        if isinstance(restr_str, str):
            restr_str = bash_escape_sql(restr_str, add_newline=False)

        if restr_str in self._export_cache[self.full_table_name]:
            return
        self._export_cache[self.full_table_name].add(restr_str)

        self._export_table.Table.insert1(
            dict(
                export_id=self.export_id,
                table_name=self.full_table_name,
                restriction=restr_str,
            )
        )
        # str(): make_condition may return a non-str (e.g. False), which has
        # no .replace; this line is only for the debug log.
        restr_logline = (
            str(restr_str).replace("AND", "\n\tAND").replace("OR", "\n\tOR")
        )
        logger.debug(f"\nTable: {self.full_table_name}\nRestr: {restr_logline}")

    @staticmethod
    def _sql_in_list(values):
        """Format ``values`` as a SQL ``IN`` list of single-quoted strings.

        ``['a', 'b']`` becomes ``"('a', 'b')"``. Empty input yields
        ``"('')"`` instead of the syntactically invalid ``"()"``.
        """
        # NOTE(review): values are not quote-escaped; assumes file names
        # never contain a single quote.
        quoted = ", ".join(f"'{value}'" for value in values)
        return "(" + (quoted or "''") + ")"

    def _log_fetch_nwb(self, table, table_attr):
        """Log a ``fetch_nwb`` call for export.

        Records each ``analysis_file_name`` of this (restricted) table in
        ``ExportSelection.File``. It then logs a fetch on ``table()``,
        restricted to those file names.

        Parameters
        ----------
        table : callable
            Presumably a table class; it is called (``table()``) and the
            result's ``_log_fetch`` is invoked.
        table_attr : str
            Not used by this method.
        """
        tbl_pk = "analysis_file_name"
        fnames = self.fetch(tbl_pk, log_export=True)
        logger.debug(
            f"Export: fetch_nwb\nTable:{self.full_table_name},\nFiles: {fnames}"
        )
        self._export_table.File.insert(
            [{"export_id": self.export_id, tbl_pk: fname} for fname in fnames],
            skip_duplicates=True,
        )
        # log AnalysisFile table; previously joined with "', " which dropped
        # the opening quote of every name after the first.
        fnames_str = self._sql_in_list(fnames)
        table()._log_fetch(restriction=f"{tbl_pk} in {fnames_str}")

    def _run_join(self, **kwargs):
        """Log join for export.

        Special case to log primary keys of each table in join, avoiding
        long restriction strings. ``other`` is only logged if it also has
        ``_log_fetch`` (i.e. uses this mixin); otherwise a warning is emitted.
        """
        table_list = [self]
        other = kwargs.get("other")

        if hasattr(other, "_log_fetch"):  # Check if other has mixin
            table_list.append(other)  # can other._log_fetch
        else:
            logger.warning(f"Cannot export log join for\n{other}")

        joined = self.proj().join(other.proj(), log_export=False)
        for table in table_list:  # log separate for unique pks
            if isinstance(table, type) and issubclass(table, Table):
                table = table()  # adapted from dj.declare.compile_foreign_key
            for r in joined.fetch(*table.primary_key, as_dict=True):
                table._log_fetch(restriction=r)

    def _run_with_log(self, method, *args, log_export=True, **kwargs):
        """Run method, log fetch, and return result.

        Uses FETCH_LOG_FLAG to prevent multiple logs in one user call: the
        flag is turned off while ``method`` runs (so nested fetches are not
        logged) and restored afterwards, even if ``method`` raises.
        """
        log_this_call = FETCH_LOG_FLAG.get()  # One log per fetch call

        if log_this_call and self.database != "common_usage":
            FETCH_LOG_FLAG.set(False)

        try:
            ret = method(*args, **kwargs)
        finally:
            if log_this_call:
                FETCH_LOG_FLAG.set(True)

        if log_export and self.export_id and log_this_call:
            if getattr(method, "__name__", None) == "join":  # special case
                self._run_join(**kwargs)
            else:
                self._log_fetch(restriction=kwargs.get("restriction"))
            logger.debug(f"Export: {self._called_funcs()}")

        return ret

    # -------------------------- Intercept DJ methods --------------------------

    def fetch(self, *args, log_export=True, **kwargs):
        """Log fetch for export."""
        if not self.export_id:
            return super().fetch(*args, **kwargs)
        return self._run_with_log(
            super().fetch, *args, log_export=log_export, **kwargs
        )

    def fetch1(self, *args, log_export=True, **kwargs):
        """Log fetch1 for export."""
        if not self.export_id:
            return super().fetch1(*args, **kwargs)
        return self._run_with_log(
            super().fetch1, *args, log_export=log_export, **kwargs
        )

    def restrict(self, restriction):
        """Log restrict for export.

        The new restriction is AND-ed with the existing one. Logging is
        skipped when called from within ``fetch_nwb``.
        """
        if not self.export_id:
            return super().restrict(restriction)
        log_export = "fetch_nwb" not in self._called_funcs()
        return self._run_with_log(
            super().restrict,
            restriction=AndList([restriction, self.restriction]),
            log_export=log_export,
        )

    def join(self, other, log_export=True, *args, **kwargs):
        """Log join for export.

        Joins in dj_helper_func related to fetch_nwb use `log_export=False`
        because these entries are caught by the file cascade in RestrGraph.
        """
        if not self.export_id:
            return super().join(other=other, *args, **kwargs)

        return self._run_with_log(
            super().join, other=other, log_export=log_export, *args, **kwargs
        )

compare_versions(version, other=None, msg=None)

Compare two versions. Raise error if not equal.

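Parameters:

| Name    | Type | Description                                                                 | Default  |
|---------|------|-----------------------------------------------------------------------------|----------|
| version | str  | Version to compare.                                                         | required |
| other   | str  | Other version to compare. Default None, which uses self._spyglass_version. | None     |
| msg     | str  | Additional error message info. Default None.                                | None     |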
Source code in src/spyglass/utils/mixins/export.py
def compare_versions(
    self, version: str, other: str = None, msg: str = None
) -> None:
    """Compare two versions. Raise error if not equal.

    Parameters
    ----------
    version : str
        Version to compare.
    other : str, optional
        Other version to compare. Default None, which uses
        self._spyglass_version.
    msg : str, optional
        Additional error message info, appended on its own line.
        Default None (nothing appended).

    Raises
    ------
    RuntimeError
        If the parsed versions differ. The check is skipped in test mode.
    """
    if self._test_mode:
        return

    other = other or self._spyglass_version

    if version_parse(version) != version_parse(other):
        # Only append msg when provided; None used to render as "None".
        detail = f"\n{msg}" if msg else ""
        raise RuntimeError(
            f"Found mismatched versions: {version} vs {other}{detail}"
        )

export_id (property; writable, deletable)

ID of export in progress.

NOTE: Using an environment variable to store export_id may not be thread safe. Exports must be run in sequence, not in parallel.

fetch(*args, log_export=True, **kwargs)

Log fetch for export.

Source code in src/spyglass/utils/mixins/export.py
def fetch(self, *args, log_export=True, **kwargs):
    """Fetch rows, logging the call when an export is in progress.

    With no active export this forwards straight to the parent ``fetch``;
    otherwise the parent ``fetch`` is run through ``_run_with_log``.
    """
    if self.export_id:
        return self._run_with_log(
            super().fetch, *args, log_export=log_export, **kwargs
        )
    return super().fetch(*args, **kwargs)

fetch1(*args, log_export=True, **kwargs)

Log fetch1 for export.

Source code in src/spyglass/utils/mixins/export.py
def fetch1(self, *args, log_export=True, **kwargs):
    """Fetch a single row, logging the call when an export is in progress.

    With no active export this forwards straight to the parent ``fetch1``;
    otherwise the parent ``fetch1`` is run through ``_run_with_log``.
    """
    if self.export_id:
        return self._run_with_log(
            super().fetch1, *args, log_export=log_export, **kwargs
        )
    return super().fetch1(*args, **kwargs)

restrict(restriction)

Log restrict for export.

Source code in src/spyglass/utils/mixins/export.py
def restrict(self, restriction):
    """Restrict the table, logging the restriction during an export.

    The new restriction is AND-ed with the table's current restriction.
    Logging is suppressed when ``fetch_nwb`` is on the call stack.
    """
    if self.export_id:
        inside_fetch_nwb = "fetch_nwb" in self._called_funcs()
        combined = AndList([restriction, self.restriction])
        return self._run_with_log(
            super().restrict,
            restriction=combined,
            log_export=not inside_fetch_nwb,
        )
    return super().restrict(restriction)

join(other, log_export=True, *args, **kwargs)

Log join for export.

Joins in dj_helper_func related to fetch_nwb use log_export=False because these entries are caught by the file cascade in RestrGraph.

Source code in src/spyglass/utils/mixins/export.py
def join(self, other, log_export=True, *args, **kwargs):
    """Join with ``other``, logging the join when an export is in progress.

    Joins in dj_helper_func related to fetch_nwb pass `log_export=False`
    because those entries are caught by the file cascade in RestrGraph.
    """
    if self.export_id:
        logged_call = self._run_with_log
        return logged_call(
            super().join, other=other, log_export=log_export, *args, **kwargs
        )
    return super().join(other=other, *args, **kwargs)