Converts decoder classes into dictionaries so that DataJoint can store them in tables,
and converts those stored dictionaries back into classes.
restore_classes(params)
Converts a dictionary of stored parameters back into decoder class instances (the parameters are stored as plain dictionaries because DataJoint cannot store classes).
Source code in src/spyglass/decoding/v0/dj_decoder_conversion.py
def restore_classes(params: dict) -> dict:
    """Convert stored dictionaries in ``params`` back into decoder class instances.

    DataJoint cannot store class instances, so classifier parameters are
    saved as plain dictionaries (see ``convert_classes_to_dict``). This
    function reverses that conversion for the entries of
    ``params["classifier_params"]``.

    Parameters
    ----------
    params : dict
        Must contain a ``"classifier_params"`` dict with the keys
        ``"continuous_transition_types"`` (a list of lists),
        ``"environments"``, ``"discrete_transition_type"`` and
        ``"initial_conditions_type"``. ``"observation_models"`` is optional.

    Returns
    -------
    dict
        The same ``params`` object, modified in place.

    Raises
    ------
    KeyError
        If ``"classifier_params"`` or one of its required keys is missing.
    """
    # Name -> class lookup tables used by _convert_dict_to_class.
    continuous_state_transition_types = {
        "RandomWalk": RandomWalk,
        "RandomWalkDirection1": RandomWalkDirection1,
        "RandomWalkDirection2": RandomWalkDirection2,
        "Uniform": Uniform,
        "Identity": Identity,
    }
    discrete_state_transition_types = {
        "DiagonalDiscrete": DiagonalDiscrete,
        "UniformDiscrete": UniformDiscrete,
        "RandomDiscrete": RandomDiscrete,
        "UserDefinedDiscrete": UserDefinedDiscrete,
    }
    initial_conditions_types = {
        "UniformInitialConditions": UniformInitialConditions,
        "UniformOneEnvironmentInitialConditions": UniformOneEnvironmentInitialConditions,
    }

    # Alias of the nested dict: every assignment below mutates it, and
    # therefore ``params``, in place.
    classifier_params = params["classifier_params"]

    # Continuous transitions are stored as a list of lists (one row per state).
    classifier_params["continuous_transition_types"] = [
        [
            _convert_dict_to_class(st, continuous_state_transition_types)
            for st in sts
        ]
        for sts in classifier_params["continuous_transition_types"]
    ]
    classifier_params["environments"] = [
        _convert_env_dict(env_params)
        for env_params in classifier_params["environments"]
    ]
    classifier_params["discrete_transition_type"] = _convert_dict_to_class(
        classifier_params["discrete_transition_type"],
        discrete_state_transition_types,
    )
    classifier_params["initial_conditions_type"] = _convert_dict_to_class(
        classifier_params["initial_conditions_type"],
        initial_conditions_types,
    )
    if classifier_params.get("observation_models"):
        # convert_classes_to_dict stores each model as vars(obs), i.e. a dict
        # of its attributes, so rebuild it from keyword arguments instead of
        # passing the whole dict as one positional argument.
        # NOTE(review): assumes ObservationModel accepts its attributes as
        # keyword arguments (e.g. a dataclass) — confirm.
        # Non-dict entries keep the previous ObservationModel(obs) behavior.
        classifier_params["observation_models"] = [
            (
                ObservationModel(**obs)
                if isinstance(obs, dict)
                else ObservationModel(obs)
            )
            for obs in classifier_params["observation_models"]
        ]
    return params
|
convert_classes_to_dict(key)
Converts the class instances in the classifier parameters into dictionaries so that DataJoint can store them.
Source code in src/spyglass/decoding/v0/dj_decoder_conversion.py
def convert_classes_to_dict(key: dict) -> dict:
    """Convert the class instances in ``key["classifier_params"]`` to dicts.

    DataJoint cannot store class instances, so the environments, transition
    types, initial conditions and (when present) observation models and
    clusterless algorithm parameters are replaced by dictionary
    representations produced by the module's conversion helpers.

    Parameters
    ----------
    key : dict
        Must contain a ``"classifier_params"`` dict with the keys
        ``"environments"`` (a single environment or an iterable of them),
        ``"continuous_transition_types"``, ``"discrete_transition_type"``
        and ``"initial_conditions_type"``. ``"observation_models"`` and
        ``"clusterless_algorithm_params"`` are optional.

    Returns
    -------
    dict
        The same ``key`` object, modified in place.

    Raises
    ------
    KeyError
        If ``"classifier_params"`` or one of its required keys is missing.
    """
    # Alias of the nested dict: every assignment below mutates it, and
    # therefore ``key``, in place.
    classifier_params = key["classifier_params"]

    environments = classifier_params["environments"]
    # Accept either a single environment or an iterable of them. Only the
    # iterability check is guarded, so a TypeError raised inside
    # _convert_environment_to_dict is no longer mistaken for "single env".
    try:
        env_iter = iter(environments)
    except TypeError:
        env_iter = iter([environments])
    classifier_params["environments"] = [
        _convert_environment_to_dict(env) for env in env_iter
    ]

    classifier_params["continuous_transition_types"] = (
        _convert_transitions_to_dict(
            classifier_params["continuous_transition_types"]
        )
    )
    classifier_params["discrete_transition_type"] = _to_dict(
        classifier_params["discrete_transition_type"]
    )
    classifier_params["initial_conditions_type"] = _to_dict(
        classifier_params["initial_conditions_type"]
    )

    # Optional key, mirroring restore_classes which also treats it as optional.
    if classifier_params.get("observation_models") is not None:
        # vars() returns each object's attribute dict.
        classifier_params["observation_models"] = [
            vars(obs) for obs in classifier_params["observation_models"]
        ]

    # Optional key. An explicit presence check replaces `except KeyError:
    # pass`, which also hid KeyErrors raised inside _convert_algorithm_params.
    if "clusterless_algorithm_params" in classifier_params:
        classifier_params["clusterless_algorithm_params"] = (
            _convert_algorithm_params(
                classifier_params["clusterless_algorithm_params"]
            )
        )
    return key
|