Skip to content

position_dlc_model.py

DLCModelInput

Bases: SpyglassMixin, Manual

Table to hold the model path when a model is imported from local disk instead of being trained within Spyglass.

Source code in src/spyglass/position/v1/position_dlc_model.py
@schema
class DLCModelInput(SpyglassMixin, dj.Manual):
    """Table to hold model path if model is being input
    from local disk instead of Spyglass
    """

    definition = """
    dlc_model_name : varchar(64)  # Different than dlc_model_name in DLCModelSource... not great
    -> DLCProject
    ---
    project_path         : varchar(255) # Path to project directory
    """

    def insert1(self, key, **kwargs):
        """Insert a locally stored DLC model and register it in DLCModelSource.

        Parameters
        ----------
        key : dict
            Entry from DLCProject. It must contain ``config_path`` and
            ``project_name``. The caller's dict is not modified.
        **kwargs
            Forwarded to the base-class ``insert1``.

        Raises
        ------
        FileNotFoundError
            If the directory containing ``config_path`` does not exist.
        """
        # Work on a copy so the caller's dict is left untouched.
        key = dict(key)
        project_path = Path(key.pop("config_path")).parent
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not project_path.exists():
            raise FileNotFoundError(
                f"project path does not exist: {project_path}"
            )
        # Model name: text before the first "model" in the project directory
        # name, with "model" appended (appended even if "model" was absent).
        key["dlc_model_name"] = f'{project_path.name.split("model")[0]}model'
        key["project_path"] = project_path.as_posix()
        super().insert1(key, **kwargs)
        DLCModelSource.insert_entry(
            dlc_model_name=key["dlc_model_name"],
            project_name=key["project_name"],
            source="FromImport",
            key=key,
            skip_duplicates=True,
        )

DLCModelSource

Bases: SpyglassMixin, Manual

Table recording whether a model originates from the upstream DLCModelTraining table or from a local directory.

Source code in src/spyglass/position/v1/position_dlc_model.py
@schema
class DLCModelSource(SpyglassMixin, dj.Manual):
    """Table to determine whether model originates from
    upstream DLCModelTraining table, or from local directory
    """

    definition = """
    -> DLCProject
    dlc_model_name : varchar(64)    # User-friendly model name
    ---
    source         : enum ('FromUpstream', 'FromImport')
    """

    class FromImport(SpyglassMixin, dj.Part):
        definition = """
        -> DLCModelSource
        -> DLCModelInput
        ---
        project_path : varchar(255)
        """

    class FromUpstream(SpyglassMixin, dj.Part):
        definition = """
        -> DLCModelSource
        -> DLCModelTraining
        ---
        project_path : varchar(255)
        """

    @classmethod
    @accepts(None, None, ("FromUpstream", "FromImport"), None)
    def insert_entry(
        cls,
        dlc_model_name: str,
        project_name: str,
        source: str = "FromUpstream",
        key: dict = None,
        **kwargs,
    ):
        """Insert a master row plus the matching part-table row.

        Parameters
        ----------
        dlc_model_name : str
            Model name stored in the master and part rows.
        project_name : str
            Project name. It is also used to look up ``project_path`` in the
            part table's parent.
        source : str
            Name of the part table to fill: "FromUpstream" or "FromImport".
        key : dict, optional
            Extra attributes merged into the part-table row. They override the
            values computed here. ``None`` means no extra attributes.
        **kwargs
            Forwarded to both ``insert1`` calls (e.g. ``skip_duplicates``).
        """
        # `**None` would raise TypeError, so normalize the default.
        key = key or {}
        cls.insert1(
            {
                "dlc_model_name": dlc_model_name,
                "project_name": project_name,
                "source": source,
            },
            **kwargs,
        )
        part_table = getattr(cls, source)
        # NOTE(review): relies on parents()[-1] being the non-master parent
        # (DLCModelInput / DLCModelTraining). It restricts by project_name only,
        # so fetch1 fails if that parent has several rows for one project.
        table_query = dj.FreeTable(
            dj.conn(), full_table_name=part_table.parents()[-1]
        ) & {"project_name": project_name}
        project_path = table_query.fetch1("project_path")
        part_table.insert1(
            {
                "dlc_model_name": dlc_model_name,
                "project_name": project_name,
                "project_path": project_path,
                **key,
            },
            **kwargs,
        )

DLCModelEvaluation

Bases: SpyglassMixin, Computed

Source code in src/spyglass/position/v1/position_dlc_model.py
@schema
class DLCModelEvaluation(SpyglassMixin, dj.Computed):
    definition = """
    -> DLCModel
    ---
    train_iterations   : int   # Training iterations
    train_error=null   : float # Train error (px)
    test_error=null    : float # Test error (px)
    p_cutoff=null      : float # p-cutoff used
    train_error_p=null : float # Train error with p-cutoff
    test_error_p=null  : float # Test error with p-cutoff
    """

    def make(self, key):
        """.populate() method will launch evaluation for each unique entry in Model.

        Runs DeepLabCut's ``evaluate_network`` for the DLCModel entry in
        ``key``. It then inserts the metrics from the most recently modified
        CSV in the evaluation folder.

        Parameters
        ----------
        key : dict
            Primary key restricting DLCModel to a single entry.

        Raises
        ------
        AssertionError
            If the evaluation folder does not exist after evaluation.
        FileNotFoundError
            If the evaluation folder contains no CSV files.
        """
        import csv
        from pathlib import Path

        from deeplabcut import evaluate_network
        from deeplabcut.utils.auxiliaryfunctions import get_evaluation_folder

        dlc_config, project_path, model_prefix, shuffle, trainingsetindex = (
            DLCModel & key
        ).fetch1(
            "config_template",
            "project_path",
            "model_prefix",
            "shuffle",
            "trainingsetindex",
        )

        yml_path, _ = dlc_reader.read_yaml(project_path)

        evaluate_network(
            yml_path,
            Shuffles=[shuffle],  # this needs to be a list
            trainingsetindex=trainingsetindex,
            comparisonbodyparts="all",
        )

        eval_folder = get_evaluation_folder(
            trainFraction=dlc_config["TrainingFraction"][trainingsetindex],
            shuffle=shuffle,
            cfg=dlc_config,
            modelprefix=model_prefix,
        )
        # The fetched project_path may be a plain str; Path() makes `/` work.
        eval_path = Path(project_path) / eval_folder
        assert (
            eval_path.exists()
        ), f"Couldn't find evaluation folder:\n{eval_path}"

        eval_csvs = list(eval_path.glob("*csv"))
        if not eval_csvs:
            raise FileNotFoundError(
                f"No evaluation CSV found in:\n{eval_path}"
            )
        # Newest file by modification time (presumably the one written by
        # evaluate_network above).
        eval_csv_latest = max(eval_csvs, key=os.path.getmtime)
        with open(eval_csv_latest, newline="") as f:
            results = list(csv.DictReader(f, delimiter=","))[0]
        # In testing, test_error_p returned an empty string. Store NULL for the
        # nullable float columns rather than an unparsable "".
        self.insert1(
            dict(
                key,
                train_iterations=results["Training iterations:"],
                train_error=results[" Train error(px)"] or None,
                test_error=results[" Test error(px)"] or None,
                p_cutoff=results["p-cutoff used"] or None,
                train_error_p=results["Train error with p-cutoff"] or None,
                test_error_p=results["Test error with p-cutoff"] or None,
            )
        )

make(key)

Calling `.populate()` launches an evaluation for each unique entry in DLCModel.

Source code in src/spyglass/position/v1/position_dlc_model.py
def make(self, key):
    """.populate() method will launch evaluation for each unique entry in Model.

    Runs DeepLabCut's ``evaluate_network`` for the DLCModel entry in ``key``.
    It then inserts the metrics from the most recently modified CSV in the
    evaluation folder.

    Parameters
    ----------
    key : dict
        Primary key restricting DLCModel to a single entry.

    Raises
    ------
    AssertionError
        If the evaluation folder does not exist after evaluation.
    FileNotFoundError
        If the evaluation folder contains no CSV files.
    """
    import csv
    from pathlib import Path

    from deeplabcut import evaluate_network
    from deeplabcut.utils.auxiliaryfunctions import get_evaluation_folder

    dlc_config, project_path, model_prefix, shuffle, trainingsetindex = (
        DLCModel & key
    ).fetch1(
        "config_template",
        "project_path",
        "model_prefix",
        "shuffle",
        "trainingsetindex",
    )

    yml_path, _ = dlc_reader.read_yaml(project_path)

    evaluate_network(
        yml_path,
        Shuffles=[shuffle],  # this needs to be a list
        trainingsetindex=trainingsetindex,
        comparisonbodyparts="all",
    )

    eval_folder = get_evaluation_folder(
        trainFraction=dlc_config["TrainingFraction"][trainingsetindex],
        shuffle=shuffle,
        cfg=dlc_config,
        modelprefix=model_prefix,
    )
    # The fetched project_path may be a plain str; Path() makes `/` work.
    eval_path = Path(project_path) / eval_folder
    assert (
        eval_path.exists()
    ), f"Couldn't find evaluation folder:\n{eval_path}"

    eval_csvs = list(eval_path.glob("*csv"))
    if not eval_csvs:
        raise FileNotFoundError(f"No evaluation CSV found in:\n{eval_path}")
    # Newest file by modification time (presumably the one written by
    # evaluate_network above).
    eval_csv_latest = max(eval_csvs, key=os.path.getmtime)
    with open(eval_csv_latest, newline="") as f:
        results = list(csv.DictReader(f, delimiter=","))[0]
    # In testing, test_error_p returned an empty string. Store NULL for the
    # nullable float columns rather than an unparsable "".
    self.insert1(
        dict(
            key,
            train_iterations=results["Training iterations:"],
            train_error=results[" Train error(px)"] or None,
            test_error=results[" Test error(px)"] or None,
            p_cutoff=results["p-cutoff used"] or None,
            train_error_p=results["Train error with p-cutoff"] or None,
            test_error_p=results["Test error with p-cutoff"] or None,
        )
    )

str_to_bool(value)

Return True if the provided value represents true (e.g. "yes", "true", "on", "1"); otherwise return False.

Source code in src/spyglass/position/v1/position_dlc_model.py
def str_to_bool(value) -> bool:
    """Interpret ``value`` as a boolean flag.

    Falsy inputs (``None``, ``""``, ``0``) yield False. Any other value is
    converted with ``str``, lower-cased, and matched against the affirmative
    tokens "y", "yes", "t", "true", "on" and "1".
    """
    # Replacement for the distutils helper deprecated in Python 3.10.
    # Adapted from github.com/PostHog/posthog/blob/master/posthog/utils.py
    affirmative = {"y", "yes", "t", "true", "on", "1"}
    return bool(value) and str(value).lower() in affirmative