Clusterless Decoding¶
Overview¶
Developer Note: if you may make a PR in the future, be sure to copy this notebook, and use the gitignore prefix `temp` to avoid future conflicts.
This is one notebook in a multi-part series on Spyglass.
- To set up your Spyglass environment and database, see the Setup notebook
- This tutorial assumes you've already extracted waveforms and loaded position data. For 1D decoding, the position data should also be linearized.
Clusterless decoding can be performed on either 1D or 2D data. We will start with 2D data.
Elements of Clusterless Decoding¶
- Position Data: This is the data that we want to decode. It can be 1D or 2D.
- Spike Waveform Features: These are the features that we will use to decode the position data.
- Decoding Model Parameters: This is how we define the model that we will use to decode the position data.
Grouping Data¶
An important concept here is groups. Groups are tables that allow us to specify collections of data. We will use groups in two situations here:
- We want to decode from more than one tetrode (or probe), so we will create a group that contains all of the tetrodes that we want to decode from.
- Similarly, we will create a group for the position data that we want to decode, so that we can decode from position data from multiple sessions.
Grouping Waveform Features¶
Let's start with grouping the Waveform Features. We will first inspect the waveform features that we have extracted to figure out the primary keys of the data that we want to decode from. We need to use the tables `SpikeSortingSelection` and `SpikeSortingOutput` to figure out the `merge_id` associated with `nwb_file_name` to get the waveform features associated with the NWB file of interest.
from pathlib import Path
import datajoint as dj
dj.config.load(
Path("../dj_local_conf.json").absolute()
) # load config for database connection info
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
import spyglass.spikesorting.v1 as sgs
from spyglass.decoding.v1.waveform_features import (
UnitWaveformFeaturesSelection,
UnitWaveformFeatures,
)
nwb_copy_file_name = "mediumnwb20230802_.nwb"
sorter_keys = {
"nwb_file_name": nwb_copy_file_name,
"sorter": "clusterless_thresholder",
"sorter_param_name": "default_clusterless",
}
feature_key = {"features_param_name": "amplitude"}
(
UnitWaveformFeaturesSelection.proj(merge_id="spikesorting_merge_id")
* SpikeSortingOutput.CurationV1
* sgs.SpikeSortingSelection
) & SpikeSortingOutput().get_restricted_merge_ids(
sorter_keys, sources=["v1"], as_dict=True
)
merge_id | features_param_name | sorting_id | curation_id | recording_id | sorter | sorter_param_name | nwb_file_name | interval_list_name |
---|---|---|---|---|---|---|---|---|
0233e49a-b849-7eab-7434-9c298eea87b8 | amplitude | 85cb4efd-5dd9-4637-8c47-50927da56ecb | 0 | d6ec337b-f131-47fa-8d04-f152459539ab | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | d4d3d806-13dc-42b9-a149-267fa170aa8f |
07239cea-7578-5409-692c-18c9d26b4d36 | amplitude | 17abb5a3-cc9a-4a7f-8fbf-ae3bcffad239 | 0 | 9b34c86e-f2d0-4c6c-a7b8-302ef30b0fff | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 24608f0d-ffca-4f56-8dd3-a274b7248b63 |
08be9775-370d-6492-0b4e-a5db4ce7a128 | amplitude | 2056130f-b8c9-46d1-9c27-4287d237f63f | 0 | e9ea1b3c-6e7b-4960-a593-0dd6d5ab0990 | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | c96e245d-efef-4ab6-b549-683270857dbb |
11819f33-11d5-f0f8-2590-ce3d60b76f3a | amplitude | 71add870-7efe-4e64-b5fc-079c7b6d4a8a | 0 | 8f4b5933-7f9d-4ca1-a262-9a7978630101 | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 9d5a025a-2b46-47b3-94f4-70d58db68e60 |
1c2ea289-2e7f-dcda-0464-ce97d3d6a392 | amplitude | 46b8a445-1513-44ce-8a14-d1c9dec80d74 | 0 | 0d247564-2302-4ace-9157-c3891eceaf2c | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 56cbb21e-8fe8-4f4a-b2b0-537ad6039543 |
20f24092-d191-0c58-55c8-d43d453f9fd4 | amplitude | aec60cb7-017c-42ed-91be-0fb2a5f75948 | 0 | 747f4eea-6df3-422b-941e-b5aaad7ec607 | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 65009b63-5830-45b5-9954-cd5341aa8cef |
2598b48e-49a0-3389-dd15-0230e8d326e4 | amplitude | e26863d0-7a77-455c-b687-0af1bd626486 | 0 | 34ea9dd3-b728-4bd3-872c-7a4e37fb2ac9 | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | e4daaf56-e40d-41d3-8523-097237d98bbd |
483055a5-9775-27b7-856e-01543bd920aa | amplitude | 9af6681f-2e37-496e-823e-7acbdd436a27 | 0 | 73c9e01c-b37c-41a2-8571-0df13c32bf76 | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 3da02b84-1a7f-4f2a-81bf-2e92c4d88e96 |
50ae3f7e-65a8-5fc2-5304-ab534b90fa46 | amplitude | 2483d0c7-4cfe-4d6f-8dd6-2e13a8289d94 | 0 | 03cc7709-66e7-47ac-a3bd-63add028d9f8 | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 8cfc1ccb-8de3-4eee-9e18-f8b8f5c45821 |
50b29d01-2d74-e37e-2842-ad56d833c5f9 | amplitude | 1dcecaac-8e0d-4d18-8296-cdb50eef9506 | 0 | d8a8c564-13c7-4fab-9a33-1eac416869da | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 96678676-89dd-42e4-89f6-ce56c618ce83 |
5e756e76-68be-21b7-7764-cb78d9aa4ef8 | amplitude | 552176ab-d870-41c4-8621-07e71f6e9a19 | 0 | fa4faf43-e747-43ca-b8a5-53a02d7938ec | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 07036486-e9f5-4dba-8662-7fb5ff2a6711 |
67f156e1-5da7-9c89-03b1-cc2dba88dacd | amplitude | 8f45b210-c8f9-4a27-96c2-9b85f16b3451 | 0 | 30895f0f-1eec-481d-b763-edae7667ef00 | clusterless_thresholder | default_clusterless | mediumnwb20230802_.nwb | 22fb2b64-fc3c-44af-a8c1-dacc9010beab |
...
Total: 26
from spyglass.decoding.v1.waveform_features import UnitWaveformFeaturesSelection
# find the merge ids that correspond to the sorter key restrictions
merge_ids = SpikeSortingOutput().get_restricted_merge_ids(
sorter_keys, sources=["v1"], as_dict=True
)
# find the previously populated waveform selection keys that correspond to these sorts
waveform_selection_keys = (
UnitWaveformFeaturesSelection().proj(merge_id="spikesorting_merge_id")
& merge_ids
& feature_key
).fetch(as_dict=True)
for key in waveform_selection_keys:
key["spikesorting_merge_id"] = key.pop("merge_id")
UnitWaveformFeaturesSelection & waveform_selection_keys
spikesorting_merge_id | features_param_name |
---|---|
0233e49a-b849-7eab-7434-9c298eea87b8 | amplitude |
07239cea-7578-5409-692c-18c9d26b4d36 | amplitude |
08be9775-370d-6492-0b4e-a5db4ce7a128 | amplitude |
11819f33-11d5-f0f8-2590-ce3d60b76f3a | amplitude |
1c2ea289-2e7f-dcda-0464-ce97d3d6a392 | amplitude |
20f24092-d191-0c58-55c8-d43d453f9fd4 | amplitude |
2598b48e-49a0-3389-dd15-0230e8d326e4 | amplitude |
483055a5-9775-27b7-856e-01543bd920aa | amplitude |
50ae3f7e-65a8-5fc2-5304-ab534b90fa46 | amplitude |
50b29d01-2d74-e37e-2842-ad56d833c5f9 | amplitude |
5e756e76-68be-21b7-7764-cb78d9aa4ef8 | amplitude |
67f156e1-5da7-9c89-03b1-cc2dba88dacd | amplitude |
...
Total: 26
We will create a group called `test_group` that contains all of the tetrodes that we want to decode from. We will use the `create_group` function to create this group. This function takes the name of the NWB file, the name of the group, and the keys of the tables that we want to include in the group.
from spyglass.decoding.v1.clusterless import UnitWaveformFeaturesGroup
UnitWaveformFeaturesGroup().create_group(
nwb_file_name=nwb_copy_file_name,
group_name="test_group",
keys=waveform_selection_keys,
)
UnitWaveformFeaturesGroup & {"waveform_features_group_name": "test_group"}
nwb_file_name | waveform_features_group_name |
---|---|
mediumnwb20230802_.nwb | test_group |
Total: 1
We can see that we successfully associated "test_group" with the tetrodes that we want to decode from by using the `get_group` function.
UnitWaveformFeaturesGroup.UnitFeatures & {
"nwb_file_name": nwb_copy_file_name,
"waveform_features_group_name": "test_group",
}
nwb_file_name | waveform_features_group_name | spikesorting_merge_id | features_param_name |
---|---|---|---|
mediumnwb20230802_.nwb | test_group | 0751a1e1-a406-7f87-ae6f-ce4ffc60621c | amplitude |
mediumnwb20230802_.nwb | test_group | 485a4ddf-332d-35b5-3ad4-0561736c1844 | amplitude |
mediumnwb20230802_.nwb | test_group | 4a712103-c223-864f-82e0-6c23de79cc14 | amplitude |
mediumnwb20230802_.nwb | test_group | 4a72c253-b3ca-8c13-e615-736a7ebff35c | amplitude |
mediumnwb20230802_.nwb | test_group | 5c53bd33-d57c-fbba-e0fb-55e0bcb85d03 | amplitude |
mediumnwb20230802_.nwb | test_group | 614d796c-0b95-6364-aaa0-b6cb1e7bbb83 | amplitude |
mediumnwb20230802_.nwb | test_group | 6acb99b8-6a0c-eb83-1141-5f603c5895e0 | amplitude |
mediumnwb20230802_.nwb | test_group | 6d039a63-17ad-0b78-4b1e-f02d5f3dbbc5 | amplitude |
mediumnwb20230802_.nwb | test_group | 74e10781-1228-4075-0870-af224024ffdc | amplitude |
mediumnwb20230802_.nwb | test_group | 7e3fa66e-727e-1541-819a-b01309bb30ae | amplitude |
mediumnwb20230802_.nwb | test_group | 86897349-ff68-ac72-02eb-739dd88936e6 | amplitude |
mediumnwb20230802_.nwb | test_group | 8bbddc0f-d6ae-6260-9400-f884a6e25ae8 | amplitude |
...
Total: 23
Grouping Position Data¶
We will now create a group called `test_group` that contains all of the position data that we want to decode from. As before, we will use the `create_group` function to create this group, passing the name of the NWB file, the name of the group, and the keys of the tables that we want to include in the group.
We use the `PositionOutput` table to figure out the `merge_id` associated with `nwb_file_name` to get the position data associated with the NWB file of interest. In this case, we only have one position to insert, but we could insert multiple positions if we wanted to decode from multiple sessions.
Note that we can use the `upsample_rate` parameter to define the rate (in Hz) to which the position data will be upsampled for decoding. This is useful if we want to decode at a finer time scale than the position data sampling frequency. In practice, a value of 500 Hz is used in many analyses. Skipping this parameter or providing a null value will default to the position sampling rate.
You will also want to specify the names of the position variables if they differ from the defaults, `position_x` and `position_y`.
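For example, a call along these lines could set custom variable names when creating the group (a minimal sketch; the `position_variables` argument and the column names `head_position_x`/`head_position_y` are illustrative assumptions, and `position_merge_ids` is fetched from `PositionOutput` as shown later in this notebook):
from spyglass.decoding.v1.core import PositionGroup
# Hedged sketch: create a position group with non-default position variable
# names. The `position_variables` keyword and column names are assumptions;
# adjust them to match your data.
PositionGroup().create_group(
    nwb_file_name=nwb_copy_file_name,
    group_name="custom_position_group",  # hypothetical group name
    keys=[{"pos_merge_id": merge_id} for merge_id in position_merge_ids],
    upsample_rate=500,  # upsample position data to 500 Hz for decoding
    position_variables=["head_position_x", "head_position_y"],  # assumed names
)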
from spyglass.position import PositionOutput
import spyglass.position as sgp
sgp.v1.TrodesPosParams.insert1(
{
"trodes_pos_params_name": "default_decoding",
"params": {
"max_LED_separation": 9.0,
"max_plausible_speed": 300.0,
"position_smoothing_duration": 0.125,
"speed_smoothing_std_dev": 0.100,
"orient_smoothing_std_dev": 0.001,
"led1_is_front": 1,
"is_upsampled": 1,
"upsampling_sampling_rate": 250,
"upsampling_interpolation_method": "linear",
},
},
skip_duplicates=True,
)
trodes_s_key = {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 0 valid times",
"trodes_pos_params_name": "default_decoding",
}
sgp.v1.TrodesPosSelection.insert1(
trodes_s_key,
skip_duplicates=True,
)
sgp.v1.TrodesPosV1.populate(trodes_s_key)
PositionOutput.TrodesPosV1 & trodes_s_key
[10:24:13][INFO] Spyglass: Writing new NWB file mediumnwb20230802_FUSH604NQA.nwb
Computing position for: {'nwb_file_name': 'mediumnwb20230802_.nwb', 'interval_list_name': 'pos 0 valid times', 'trodes_pos_params_name': 'default_decoding'}
[10:24:14][INFO] Spyglass: No video frame index found. Assuming all camera frames are present.
(pynwb transposed-data and skipped-checksum warnings omitted)
merge_id | nwb_file_name | interval_list_name | trodes_pos_params_name |
---|---|---|---|
6dfae23d-6034-e483-06e7-28ab4c29282f | mediumnwb20230802_.nwb | pos 0 valid times | default_decoding |
Total: 1
from spyglass.decoding.v1.core import PositionGroup
position_merge_ids = (
PositionOutput.TrodesPosV1
& {
"nwb_file_name": nwb_copy_file_name,
"interval_list_name": "pos 0 valid times",
"trodes_pos_params_name": "default_decoding",
}
).fetch("merge_id")
PositionGroup().create_group(
nwb_file_name=nwb_copy_file_name,
group_name="test_group",
keys=[{"pos_merge_id": merge_id} for merge_id in position_merge_ids],
upsample_rate=500,
)
PositionGroup & {
"nwb_file_name": nwb_copy_file_name,
"position_group_name": "test_group",
}
nwb_file_name | position_group_name | position_variables |
---|---|---|
mediumnwb20230802_.nwb | test_group | =BLOB= |
Total: 1
(
PositionGroup
& {"nwb_file_name": nwb_copy_file_name, "position_group_name": "test_group"}
).fetch1("position_variables")
['position_x', 'position_y']
PositionGroup.Position & {
"nwb_file_name": nwb_copy_file_name,
"position_group_name": "test_group",
}
nwb_file_name | position_group_name | pos_merge_id |
---|---|---|
mediumnwb20230802_.nwb | test_group | 6dfae23d-6034-e483-06e7-28ab4c29282f |
Total: 1
Decoding Model Parameters¶
We will use the `non_local_detector` package to decode the data. This package is highly flexible and allows several different types of models to be used. In this case, we will use the `ContFragClusterlessClassifier` to decode the data. This has two discrete states: Continuous and Fragmented, which correspond to different types of movement models. To read more about this model, see:
Denovellis, E.L., Gillespie, A.K., Coulter, M.E., Sosa, M., Chung, J.E., Eden, U.T., and Frank, L.M. (2021). Hippocampal replay of experience at real-world speeds. eLife 10, e64505. 10.7554/eLife.64505.
Let's first look at the model and the default parameters:
from non_local_detector.models import ContFragClusterlessClassifier
ContFragClusterlessClassifier()
ContFragClusterlessClassifier(clusterless_algorithm='clusterless_kde', clusterless_algorithm_params={'block_size': 10000, 'position_std': 6.0, 'waveform_std': 24.0}, continuous_initial_conditions_types=[UniformInitialConditions(), UniformInitialConditions()], continuous_transition_types=[[RandomWalk(environment_name='', movement_var=6.0, movement_mean=0.0, use... environments=(Environment(environment_name='', place_bin_size=2.0, track_graph=None, edge_order=None, edge_spacing=None, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False, bin_count_threshold=0),), infer_track_interior=True, no_spike_rate=1e-10, observation_models=[ObservationModel(environment_name='', encoding_group=0, is_local=False, is_no_spike=False), ObservationModel(environment_name='', encoding_group=0, is_local=False, is_no_spike=False)], sampling_frequency=500.0, state_names=['Continuous', 'Fragmented'])
You can change these parameters like so:
from non_local_detector.models import ContFragClusterlessClassifier
ContFragClusterlessClassifier(
clusterless_algorithm_params={
"block_size": 10000,
"position_std": 12.0,
"waveform_std": 24.0,
},
)
ContFragClusterlessClassifier(clusterless_algorithm='clusterless_kde', clusterless_algorithm_params={'block_size': 10000, 'position_std': 12.0, 'waveform_std': 24.0}, continuous_initial_conditions_types=[UniformInitialConditions(), UniformInitialConditions()], continuous_transition_types=[[RandomWalk(environment_name='', movement_var=6.0, movement_mean=0.0, us... environments=(Environment(environment_name='', place_bin_size=2.0, track_graph=None, edge_order=None, edge_spacing=None, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False, bin_count_threshold=0),), infer_track_interior=True, no_spike_rate=1e-10, observation_models=[ObservationModel(environment_name='', encoding_group=0, is_local=False, is_no_spike=False), ObservationModel(environment_name='', encoding_group=0, is_local=False, is_no_spike=False)], sampling_frequency=500.0, state_names=['Continuous', 'Fragmented'])
This is how to insert the model parameters into the database:
from spyglass.decoding.v1.core import DecodingParameters
DecodingParameters.insert1(
{
"decoding_param_name": "contfrag_clusterless",
"decoding_params": ContFragClusterlessClassifier(),
"decoding_kwargs": dict(),
},
skip_duplicates=True,
)
DecodingParameters & {"decoding_param_name": "contfrag_clusterless"}
decoding_param_name | decoding_params | decoding_kwargs |
---|---|---|
contfrag_clusterless | =BLOB= | =BLOB= |
Total: 1
We can retrieve these parameters and rebuild the model like so:
model_params = (
DecodingParameters & {"decoding_param_name": "contfrag_clusterless"}
).fetch1()
ContFragClusterlessClassifier(**model_params["decoding_params"])
ContFragClusterlessClassifier(clusterless_algorithm='clusterless_kde', clusterless_algorithm_params={'block_size': 10000, 'position_std': 6.0, 'waveform_std': 24.0}, continuous_initial_conditions_types=[UniformInitialConditions(), UniformInitialConditions()], continuous_transition_types=[[RandomWalk(environment_name='', movement_var=6.0, movement_mean=0.0, use... environments=[Environment(environment_name='', place_bin_size=2.0, track_graph=None, edge_order=None, edge_spacing=None, is_track_interior=None, position_range=None, infer_track_interior=True, fill_holes=False, dilate=False, bin_count_threshold=0)], infer_track_interior=True, no_spike_rate=1e-10, observation_models=[ObservationModel(environment_name='', encoding_group=0, is_local=False, is_no_spike=False), ObservationModel(environment_name='', encoding_group=0, is_local=False, is_no_spike=False)], sampling_frequency=500.0, state_names=['Continuous', 'Fragmented'])
1D Decoding¶
If you want to do 1D decoding, you will need to specify the `track_graph`, `edge_order`, and `edge_spacing` in the `environments` parameter. You can read more about these parameters in the linearization notebook. You can retrieve these parameters from the `TrackGraph` table if you have stored them there. These will then go into the `environments` parameter of the `ContFragClusterlessClassifier` model.
from non_local_detector.environment import Environment
?Environment
Init signature:
Environment(
    environment_name: str = '',
    place_bin_size: Union[float, Tuple[float]] = 2.0,
    track_graph: Optional[networkx.classes.graph.Graph] = None,
    edge_order: Optional[tuple] = None,
    edge_spacing: Optional[tuple] = None,
    is_track_interior: Optional[numpy.ndarray] = None,
    position_range: Optional[numpy.ndarray] = None,
    infer_track_interior: bool = True,
    fill_holes: bool = False,
    dilate: bool = False,
    bin_count_threshold: int = 0,
) -> None
Docstring:
Represent the spatial environment with a discrete grid.
Parameters
----------
environment_name : str, optional
place_bin_size : float, optional
    Approximate size of the position bins.
track_graph : networkx.Graph, optional
    Graph representing the 1D spatial topology
edge_order : tuple of 2-tuples, optional
    The order of the edges in 1D space
edge_spacing : None or int or tuples of len n_edges-1, optional
    Any gaps between the edges in 1D space
is_track_interior : np.ndarray or None, optional
    If given, this will be used to define the valid areas of the track.
    Must be of type boolean.
position_range : sequence, optional
    A sequence of `n_position_dims`, each an optional (lower, upper) tuple
    giving the outer bin edges for position. An entry of None in the sequence
    results in the minimum and maximum values being used for the corresponding
    dimension. The default, None, is equivalent to passing a tuple of
    `n_position_dims` None values.
infer_track_interior : bool, optional
    If True, then use the given positions to figure out the valid track areas.
fill_holes : bool, optional
    Fill holes when inferring the track
dilate : bool, optional
    Inflate the available track area with binary dilation
bin_count_threshold : int, optional
    Greater than this number of samples should be in the bin for it to be
    considered on the track.
File: ~/miniconda3/envs/spyglass/lib/python3.9/site-packages/non_local_detector/environment.py
Type: type
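Putting these pieces together might look like the following sketch (not run in this notebook); the `TrackGraph` import path, the `track_graph_name`, and the field/method names are assumptions that may vary across Spyglass versions:
from non_local_detector.environment import Environment
from non_local_detector.models import ContFragClusterlessClassifier
from spyglass.linearization.v1 import TrackGraph  # import path may differ
# Hedged sketch: build a 1D Environment from stored linearization parameters.
# "my_track_graph" is a hypothetical entry; field/method names are assumed.
track_graph_key = {"track_graph_name": "my_track_graph"}
track_graph_table = TrackGraph() & track_graph_key
environment = Environment(
    track_graph=track_graph_table.get_networkx_track_graph(),  # assumed method
    edge_order=track_graph_table.fetch1("linear_edge_order"),
    edge_spacing=track_graph_table.fetch1("linear_edge_spacing"),
)
# The environment then goes into the classifier's `environments` parameter.
ContFragClusterlessClassifier(environments=[environment])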
Decoding¶
Now that we have grouped the data and defined the model parameters, we have set up all of the elements that we need to decode the data. We now need to use the `ClusterlessDecodingSelection` table to fully specify all of the parameters and data that we want.
This has:
- `waveform_features_group_name`: the name of the group that contains the waveform features that we want to decode from
- `position_group_name`: the name of the group that contains the position data that we want to decode from
- `decoding_param_name`: the name of the decoding parameters that we want to use
- `nwb_file_name`: the name of the NWB file that we want to decode from
- `encoding_interval`: the interval of time that we want to train the initial model on
- `decoding_interval`: the interval of time that we want to decode from
- `estimate_decoding_params`: whether or not we want to estimate the decoding parameters
The first three parameters should be familiar to you.
Decoding and Encoding Intervals¶
The `encoding_interval` is the interval of time that we want to train the initial model on. The `decoding_interval` is the interval of time that we want to decode from. These two intervals can be the same, but they do not have to be. For example, we may want to train the model on a long interval of time that is representative of the entire session, but decode only from a short interval of interest.
These keys come from the `IntervalList` table, which contains the `nwb_file_name` and `interval_list_name` that we need to specify the `encoding_interval` and `decoding_interval`. We will specify a short decoding interval called `test decoding interval` and use that to decode from.
Estimating Decoding Parameters¶
The last parameter is `estimate_decoding_params`. This is a boolean that specifies whether or not we want to estimate the decoding parameters. If this is `True`, then we will estimate the initial conditions and discrete transition matrix from the data.
NOTE: If estimating parameters, then times outside the decoding interval are treated as missing. This means that, outside the decoding interval, the spiking data are not used, and only the state transition matrix and the previous time step inform the estimate. This may or may not be desired, depending on the length of this missing interval.
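For reference, a selection that turns this estimation on would look like the sketch below (illustrative only; in this tutorial we set it to False):
# Hedged sketch: the same selection used later in this notebook, but with
# parameter estimation enabled. Shown for illustration; not inserted here.
example_estimating_key = {
    "waveform_features_group_name": "test_group",
    "position_group_name": "test_group",
    "decoding_param_name": "contfrag_clusterless",
    "nwb_file_name": "mediumnwb20230802_.nwb",
    "encoding_interval": "pos 0 valid times",
    "decoding_interval": "test decoding interval",
    # estimate initial conditions and the discrete transition matrix from data
    "estimate_decoding_params": True,
}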
from spyglass.decoding.v1.clusterless import ClusterlessDecodingSelection
ClusterlessDecodingSelection()
nwb_file_name | waveform_features_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
---|---|---|---|---|---|---|
Total: 0
from spyglass.common import IntervalList
IntervalList & {"nwb_file_name": nwb_copy_file_name}
nwb_file_name | interval_list_name | valid_times | pipeline |
---|---|---|---|
mediumnwb20230802_.nwb | 02_r1 | =BLOB= | |
mediumnwb20230802_.nwb | 04f3ecb4-a18c-4ffb-85d8-2f5f62d4d6d4 | =BLOB= | spikesorting_recording_v1 |
mediumnwb20230802_.nwb | 0e848c38-9105-4ea4-b6ba-dbdd5b46a088 | =BLOB= | spikesorting_artifact_v1 |
mediumnwb20230802_.nwb | 0f91197e-bebb-4dc6-ad41-5bf89c3eed28 | =BLOB= | spikesorting_artifact_v1 |
mediumnwb20230802_.nwb | 15c8a3e8-5ce9-4654-891e-6ee4109d6f1a | =BLOB= | spikesorting_artifact_v1 |
mediumnwb20230802_.nwb | 1d2b5966-415a-4c65-955a-0e422d8b5b00 | =BLOB= | spikesorting_recording_v1 |
mediumnwb20230802_.nwb | 1e3f3707-613e-4a44-93f1-c7e5484112cd | =BLOB= | spikesorting_recording_v1 |
mediumnwb20230802_.nwb | 2402805a-04f9-4a88-9ccf-071376c8de19 | =BLOB= | spikesorting_recording_v1 |
mediumnwb20230802_.nwb | 24107d8c-ce26-4c77-8f6a-bf6955d8a3c7 | =BLOB= | spikesorting_recording_v1 |
mediumnwb20230802_.nwb | 257c077b-8f3b-4abb-a631-6b8084d6a1ea | =BLOB= | spikesorting_recording_v1 |
mediumnwb20230802_.nwb | 2b93bcd0-7b05-457c-8aab-c41ef543ecf2 | =BLOB= | spikesorting_artifact_v1 |
mediumnwb20230802_.nwb | 2b9fbf14-74a0-4294-a805-26702340aac9 | =BLOB= | spikesorting_artifact_v1 |
...
Total: 52
decoding_interval_valid_times = [
[1625935714.6359036, 1625935714.6359036 + 15.0]
]
IntervalList.insert1(
{
"nwb_file_name": "mediumnwb20230802_.nwb",
"interval_list_name": "test decoding interval",
"valid_times": decoding_interval_valid_times,
},
skip_duplicates=True,
)
Once we have figured out the keys that we need, we can insert the `ClusterlessDecodingSelection` entry into the database.
selection_key = {
"waveform_features_group_name": "test_group",
"position_group_name": "test_group",
"decoding_param_name": "contfrag_clusterless",
"nwb_file_name": nwb_copy_file_name,
"encoding_interval": "pos 0 valid times",
"decoding_interval": "test decoding interval",
"estimate_decoding_params": False,
}
ClusterlessDecodingSelection.insert1(
selection_key,
skip_duplicates=True,
)
ClusterlessDecodingSelection & selection_key
nwb_file_name | waveform_features_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
---|---|---|---|---|---|---|
mediumnwb20230802_.nwb | test_group | test_group | contfrag_clusterless | pos 0 valid times | test decoding interval | 0 |
Total: 1
ClusterlessDecodingSelection()
nwb_file_name | waveform_features_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
---|---|---|---|---|---|---|
mediumnwb20230802_.nwb | test_group | test_group | contfrag_clusterless | pos 0 valid times | test decoding interval | 0 |
Total: 1
To run decoding, we simply populate the `ClusterlessDecodingV1` table. This will run the decoding and insert the results into the database. We can then retrieve the results from the database.
from spyglass.decoding.v1.clusterless import ClusterlessDecodingV1
ClusterlessDecodingV1.populate(selection_key)
[10:24:17][WARNING] Spyglass: Upsampled position data, frame indices are invalid. Setting add_frame_ind=False
(skipped-checksum warnings for the analysis NWB files omitted)
Encoding models: 0%| | 0/23 [00:00<?, ?electrode/s]
Non-Local Likelihood: 0%| | 0/23 [00:00<?, ?electrode/s]
(xarray FutureWarning about pandas.MultiIndex coordinates and UserWarning about duplicate dimension names omitted)
We can now see it as an entry in the `DecodingOutput` table.
from spyglass.decoding.decoding_merge import DecodingOutput
DecodingOutput.ClusterlessDecodingV1 & selection_key
merge_id | nwb_file_name | waveform_features_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
---|---|---|---|---|---|---|---|
b63395dd-402e-270a-a8d1-7aabaf83d452 | mediumnwb20230802_.nwb | test_group | test_group | contfrag_clusterless | pos 0 valid times | test decoding interval | 0 |
Total: 1
We can load the results of the decoding:
decoding_results = (ClusterlessDecodingV1 & selection_key).fetch_results()
decoding_results
(skipped-checksum and duplicate-dimension warnings omitted)
<xarray.Dataset>
Dimensions: (state_ind: 26668, dim_0: 26668, time: 3750, states: 2, intervals: 1, state_bins: 26668)
Coordinates:
  * state_ind (state_ind) int32 0 0 0 0 0 0 0 ... 1 1 1 1 1 1
  * time (time) float64 1.626e+09 ... 1.626e+09
  * states (states) object 'Continuous' 'Fragmented'
    environments (states) object ...
    encoding_groups (states) int32 ...
  * state_bins (state_bins) object MultiIndex
  * state (state_bins) object 'Continuous' ... 'Fragme...
  * x_position (state_bins) float64 29.02 29.02 ... 262.7
  * y_position (state_bins) float64 0.5211 2.516 ... 224.0
Dimensions without coordinates: dim_0, intervals
Data variables:
    initial_conditions (dim_0) float64 ...
    discrete_state_transitions (states, states) float64 ...
    acausal_posterior (intervals, time, state_bins) float32 ...
    acausal_state_probabilities (intervals, time, states) float64 ...
Attributes:
    marginal_log_likelihoods: -159596.89
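As a sketch of working with this Dataset (using the dimension and variable names shown above, mirroring the commented visualization code later in this notebook), the 2D posterior and state probabilities can be pulled out like so:
# Hedged example: extract the acausal posterior for the first decoding
# interval and sum over the discrete states to get a posterior over position.
posterior_2d = (
    decoding_results.acausal_posterior.isel(intervals=0)
    .unstack("state_bins")
    .sum("state")
)
# Probability of each discrete state (Continuous vs. Fragmented) over time.
state_probabilities = decoding_results.acausal_state_probabilities.isel(
    intervals=0
)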
Finally, if we delete entries from the database, we can use the `cleanup` function to remove the orphaned result files from the file system:
DecodingOutput().cleanup()
[10:26:49][INFO] Spyglass: Cleaning up decoding outputs
[10:26:49][INFO] Spyglass: Removing /Users/edeno/Documents/GitHub/spyglass/DATA/analysis/mediumnwb20230802/mediumnwb20230802_259e8498-1eb6-4e84-a0f6-7575c4ab9b87.nc
[10:26:49][INFO] Spyglass: Removing /Users/edeno/Documents/GitHub/spyglass/DATA/analysis/mediumnwb20230802/mediumnwb20230802_8404dc49-081c-48c7-b448-34767512e8ed.nc
[10:26:49][INFO] Spyglass: Removing /Users/edeno/Documents/GitHub/spyglass/DATA/analysis/mediumnwb20230802/mediumnwb20230802_259e8498-1eb6-4e84-a0f6-7575c4ab9b87.pkl
[10:26:49][INFO] Spyglass: Removing /Users/edeno/Documents/GitHub/spyglass/DATA/analysis/mediumnwb20230802/mediumnwb20230802_8404dc49-081c-48c7-b448-34767512e8ed.pkl
(skipped-checksum warnings omitted)
Visualization of Decoding Output¶
The output of decoding can be challenging to visualize with static graphs, especially if the decoding is performed on 2D data.
We can interactively visualize the output of decoding using the figurl package. This package allows us to create a visualization of the decoding output that can be viewed in a web browser. This is useful for exploring the decoding output over time and sharing the results with others.
NOTE: You will need a kachery cloud instance to use this feature. If you are a member of the Frank lab, you should have access to the Frank lab kachery cloud instance. If you are not a member of the Frank lab, you can create your own kachery cloud instance by following the instructions here.
Each user will need to run `kachery-cloud-init` in the terminal and follow the instructions to associate their computer with their GitHub user on the kachery-cloud network.
# from non_local_detector.visualization import (
# create_interactive_2D_decoding_figurl,
# )
# (
# position_info,
# position_variable_names,
# ) = ClusterlessDecodingV1.fetch_position_info(selection_key)
# results_time = decoding_results.acausal_posterior.isel(intervals=0).time.values
# position_info = position_info.loc[results_time[0] : results_time[-1]]
# env = ClusterlessDecodingV1.fetch_environments(selection_key)[0]
# spike_times, _ = ClusterlessDecodingV1.fetch_spike_data(selection_key)
# create_interactive_2D_decoding_figurl(
# position_time=position_info.index.to_numpy(),
# position=position_info[position_variable_names],
# env=env,
# results=decoding_results,
# posterior=decoding_results.acausal_posterior.isel(intervals=0)
# .unstack("state_bins")
# .sum("state"),
# spike_times=spike_times,
# head_dir=position_info["orientation"],
# speed=position_info["speed"],
# )
GPUs¶
We can use GPUs for decoding, which results in a significant speedup. This is achieved using the jax package.
Ensuring jax can find a GPU¶
Assuming you've set up a GPU, we can use `jax.devices()` to make sure the decoding code can see the GPU. If a GPU is available, it will be listed.
In the following instance, we do not have a GPU:
import jax
jax.devices()
[CpuDevice(id=0)]
Selecting a GPU¶
If you do have multiple GPUs, you can use the `jax` package to set the device (GPU) that you want to use. For example, if you want to use the GPU with index 2, you can use the following code (uncomment first):
# device_id = 2
# device = jax.devices()[device_id]
# jax.config.update("jax_default_device", device)
# device
Monitoring GPU Usage¶
You can see which GPUs are occupied (if you have multiple GPUs) by running the command `nvidia-smi` in a terminal (or `!nvidia-smi` in a notebook). Pick a GPU with low memory usage.
We can monitor GPU use with the terminal command `watch -n 0.1 nvidia-smi`, which will update `nvidia-smi` every 100 ms. This won't work in a notebook, as it won't display the updates.
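If you want a rough in-notebook alternative, a simple polling loop is one option (a sketch, not part of the original tutorial; it assumes the nvidia-smi CLI is on your PATH):
import subprocess
import time
# Hedged sketch: print GPU index, memory use, and utilization a few times.
for _ in range(5):
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=index,memory.used,utilization.gpu", "--format=csv"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    time.sleep(1.0)  # pause between polls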
Other ways to monitor GPU usage are:
- A Jupyter widget by NVIDIA to monitor GPU usage in the notebook
- A terminal program, like nvidia-smi, with more information about which GPUs are being utilized and by whom.