Skip to content

dj_decoder_conversion.py

Converts decoder classes into dictionaries and dictionaries into classes so that datajoint can store them in tables.

restore_classes(params)

Converts the dictionary entries of a parameter set back into class instances, since DataJoint cannot store class instances directly.

Source code in src/spyglass/decoding/v0/dj_decoder_conversion.py
def restore_classes(params: dict) -> dict:
    """Rebuild class-based classifier parameters from their dictionary form.

    The entries under ``params["classifier_params"]`` are replaced in place
    with the results of the module's conversion helpers, and the same
    ``params`` object is returned.

    Parameters
    ----------
    params : dict
        Must contain a ``"classifier_params"`` mapping with the keys
        ``"continuous_transition_types"`` (an iterable of iterables),
        ``"environments"`` (an iterable), ``"discrete_transition_type"`` and
        ``"initial_conditions_type"``. ``"observation_models"`` is optional.

    Returns
    -------
    dict
        The same ``params`` object, after modification.

    Raises
    ------
    KeyError
        If ``"classifier_params"`` or one of the required keys is missing.
    """
    # Lookup tables handed to _convert_dict_to_class: stored name -> class.
    continuous_lookup = {
        "RandomWalk": RandomWalk,
        "RandomWalkDirection1": RandomWalkDirection1,
        "RandomWalkDirection2": RandomWalkDirection2,
        "Uniform": Uniform,
        "Identity": Identity,
    }
    discrete_lookup = {
        "DiagonalDiscrete": DiagonalDiscrete,
        "UniformDiscrete": UniformDiscrete,
        "RandomDiscrete": RandomDiscrete,
        "UserDefinedDiscrete": UserDefinedDiscrete,
    }
    initial_conditions_lookup = {
        "UniformInitialConditions": UniformInitialConditions,
        "UniformOneEnvironmentInitialConditions": UniformOneEnvironmentInitialConditions,
    }

    classifier_params = params["classifier_params"]

    # Continuous transitions are nested one level deep; convert row by row
    # and only store the result once every row has been converted.
    restored_rows = []
    for row in classifier_params["continuous_transition_types"]:
        restored_rows.append(
            [_convert_dict_to_class(entry, continuous_lookup) for entry in row]
        )
    classifier_params["continuous_transition_types"] = restored_rows

    restored_environments = [
        _convert_env_dict(env_dict)
        for env_dict in classifier_params["environments"]
    ]
    classifier_params["environments"] = restored_environments

    discrete_entry = classifier_params["discrete_transition_type"]
    classifier_params["discrete_transition_type"] = _convert_dict_to_class(
        discrete_entry, discrete_lookup
    )

    initial_entry = classifier_params["initial_conditions_type"]
    classifier_params["initial_conditions_type"] = _convert_dict_to_class(
        initial_entry, initial_conditions_lookup
    )

    # Optional entry: skipped when the key is absent or its value is falsy
    # (None or an empty collection).
    stored_models = classifier_params.get("observation_models")
    if stored_models:
        classifier_params["observation_models"] = [
            ObservationModel(model) for model in stored_models
        ]

    return params

convert_classes_to_dict(key)

Converts the class-based classifier parameters into plain dictionaries so that DataJoint can store them.

Source code in src/spyglass/decoding/v0/dj_decoder_conversion.py
def convert_classes_to_dict(key: dict) -> dict:
    """Convert the classifier parameters in ``key`` into dictionaries.

    The entries under ``key["classifier_params"]`` are replaced in place with
    the results of the module's conversion helpers so that datajoint can
    store them, and the same ``key`` object is returned.

    Parameters
    ----------
    key : dict
        Must contain a ``"classifier_params"`` mapping with the keys
        ``"environments"`` (an iterable of environments, or a single
        non-iterable environment), ``"continuous_transition_types"``,
        ``"discrete_transition_type"`` and ``"initial_conditions_type"``.
        ``"observation_models"`` and ``"clusterless_algorithm_params"`` are
        optional.

    Returns
    -------
    dict
        The same ``key`` object, after modification.

    Raises
    ------
    KeyError
        If ``"classifier_params"`` or one of the required keys is missing.
    """
    classifier_params = key["classifier_params"]

    # A single, non-iterable environment is accepted and wrapped in a list.
    # Only the iteration itself is guarded, so a TypeError raised while
    # converting an individual environment is not mistaken for
    # "not iterable" and silently retried with the whole collection.
    environments = classifier_params["environments"]
    try:
        environments = list(environments)
    except TypeError:
        environments = [environments]
    classifier_params["environments"] = [
        _convert_environment_to_dict(env) for env in environments
    ]

    classifier_params["continuous_transition_types"] = (
        _convert_transitions_to_dict(
            classifier_params["continuous_transition_types"]
        )
    )
    classifier_params["discrete_transition_type"] = _to_dict(
        classifier_params["discrete_transition_type"]
    )
    classifier_params["initial_conditions_type"] = _to_dict(
        classifier_params["initial_conditions_type"]
    )

    # Optional entry: a missing key is treated like None, matching how
    # restore_classes reads it back with .get().
    observation_models = classifier_params.get("observation_models")
    if observation_models is not None:
        classifier_params["observation_models"] = [
            vars(obs) for obs in observation_models
        ]

    # Optional entry: only its absence is tolerated. A KeyError raised inside
    # the helper itself is allowed to propagate instead of being swallowed.
    if "clusterless_algorithm_params" in classifier_params:
        classifier_params["clusterless_algorithm_params"] = (
            _convert_algorithm_params(
                classifier_params["clusterless_algorithm_params"]
            )
        )

    return key