Skip to content

core.py

DecodingParameters

Bases: SpyglassMixin, Lookup

Parameters for decoding the animal's mental position and some category of interest.

Source code in src/spyglass/decoding/v1/core.py
@schema
class DecodingParameters(SpyglassMixin, dj.Lookup):
    """Parameters for decoding the animal's mental position and some category of interest.

    ``decoding_params`` is supplied as a model object (e.g.
    ``ContFragClusterlessClassifier()``). ``insert`` flattens it to a dict with
    ``convert_classes_to_dict``, and ``fetch``/``fetch1`` rebuild it with
    ``restore_classes``.
    """

    definition = """
    decoding_param_name : varchar(80)  # a name for this set of parameters
    ---
    decoding_params : BLOB             # initialization parameters for model
    decoding_kwargs = NULL : BLOB      # additional keyword arguments
    """

    # Default parameter sets, one per model type, with names suffixed by
    # ``non_local_detector_version``. ``decoding_params`` holds model
    # instances; ``insert`` converts copies of these rows, so the dicts here
    # are never modified and the defaults can be inserted repeatedly.
    contents = [
        {
            "decoding_param_name": f"contfrag_clusterless_{non_local_detector_version}",
            "decoding_params": ContFragClusterlessClassifier(),
        },
        {
            "decoding_param_name": f"nonlocal_clusterless_{non_local_detector_version}",
            "decoding_params": NonLocalClusterlessDetector(),
        },
        {
            "decoding_param_name": f"contfrag_sorted_{non_local_detector_version}",
            "decoding_params": ContFragSortedSpikesClassifier(),
        },
        {
            "decoding_param_name": f"nonlocal_sorted_{non_local_detector_version}",
            "decoding_params": NonLocalSortedSpikesDetector(),
        },
    ]

    @classmethod
    def insert_default(cls):
        """Insert the default parameter sets from ``contents``, skipping duplicates."""
        cls.insert(cls.contents, skip_duplicates=True)

    def insert(self, rows, *args, **kwargs):
        """Insert rows, converting each ``decoding_params`` to a plain dict.

        Parameters
        ----------
        rows : iterable of dict
            Rows to insert. Each row's ``decoding_params`` may be a model
            object (its attributes are read with ``vars``) or a dict, which is
            passed to ``convert_classes_to_dict`` directly.
        *args, **kwargs
            Forwarded unchanged to the parent ``insert``.

        Raises
        ------
        KeyError
            If a row has no ``decoding_params`` entry.
        """
        converted = []
        for row in rows:
            # Shallow-copy so the caller's dicts (including the class-level
            # ``contents``) are not mutated; converting in place made a repeat
            # insert of the same rows call ``vars()`` on a dict and fail.
            row = dict(row)
            params = row["decoding_params"]
            if not isinstance(params, dict):
                params = vars(params)
            row["decoding_params"] = convert_classes_to_dict(params)
            converted.append(row)
        # Pass the materialized list: ``rows`` may be a one-shot iterator that
        # the loop above has already exhausted.
        super().insert(converted, *args, **kwargs)

    def fetch(self, *args, **kwargs):
        """Fetch rows with ``decoding_params`` rebuilt into model objects.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to the parent ``fetch``.

        Returns
        -------
        The parent ``fetch`` result, with ``decoding_params`` restored when it
        is part of the result:

        - no attributes, record-array result: a list of tuples, one per row
          (field order as fetched);
        - no attributes, ``as_dict=True``: a list of dicts;
        - no attributes, ``format="frame"``: a DataFrame copy;
        - attributes including ``"decoding_params"``: a list of restored
          params for a single attribute, a list of per-attribute sequences for
          several, or a list of dicts with ``as_dict=True``.

        Anything that does not contain ``decoding_params`` is returned as-is.
        """
        rows = super().fetch(*args, **kwargs)

        # NOTE(review): DataJoint's attribute fetches appear to call
        # ``proj(...).fetch(format="array")`` on this table and index the
        # result by field name, so that call must get the raw record array;
        # the outer call then restores the classes exactly once. A user
        # passing format="array" explicitly therefore gets unconverted dicts.
        # Confirm against the installed DataJoint version.
        if kwargs.get("format") == "array":
            return rows

        def restore_row(row):
            # Copy of a dict row with its params restored, if it has them.
            if "decoding_params" not in row:
                return row
            return {**row, "decoding_params": restore_classes(row["decoding_params"])}

        if args:
            if "decoding_params" not in args:
                return rows
            if kwargs.get("as_dict"):
                return [restore_row(row) for row in rows]
            if len(args) == 1:
                # A single sequence holding only the stored parameter dicts.
                return [restore_classes(params) for params in rows]
            # One sequence per requested attribute, in request order.
            columns = list(rows)
            index = args.index("decoding_params")
            columns[index] = [restore_classes(params) for params in columns[index]]
            return columns

        if len(rows) == 0:
            return rows

        if hasattr(rows, "columns"):
            # DataFrame result (format="frame").
            if "decoding_params" not in rows.columns:
                return rows
            rows = rows.copy()
            rows["decoding_params"] = [
                restore_classes(params) for params in rows["decoding_params"]
            ]
            return rows

        if isinstance(rows[0], dict):
            # as_dict=True: one dict per row.
            return [restore_row(row) for row in rows]

        # Record array: one record per row. Projections without the params
        # field (e.g. primary key only) are returned unchanged.
        names = getattr(getattr(rows, "dtype", None), "names", None) or ()
        if "decoding_params" not in names:
            return rows
        index = names.index("decoding_params")
        return [
            tuple(
                restore_classes(value) if position == index else value
                for position, value in enumerate(record)
            )
            for record in rows
        ]

    def fetch1(self, *args, **kwargs):
        """Fetch exactly one row with ``decoding_params`` rebuilt into a model object.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to the parent ``fetch1``.

        Returns
        -------
        With no attributes, the row dict with ``decoding_params`` restored.
        With attributes, the parent result with the ``decoding_params`` value
        restored if it was requested. This assumes DataJoint's convention of a
        bare value for one attribute and a tuple for several — confirm for the
        installed version.
        """
        row = super().fetch1(*args, **kwargs)

        if not args:
            if "decoding_params" in row:
                row["decoding_params"] = restore_classes(row["decoding_params"])
            return row

        if "decoding_params" not in args:
            return row
        if len(args) == 1:
            # The parent returned the stored params value itself.
            return restore_classes(row)
        values = list(row)
        index = args.index("decoding_params")
        values[index] = restore_classes(values[index])
        return tuple(values)