Sorted Spikes Decoding
The mechanics of decoding with sorted spikes are largely similar to those of decoding with unsorted spikes. You should familiarize yourself with the clusterless decoding tutorial before proceeding with this one.
The elements we will need to decode with sorted spikes are:
- PositionGroup
- SortedSpikesGroup
- DecodingParameters
- encoding_interval
- decoding_interval
This time, instead of extracting waveform features, we can proceed directly from the SpikeSortingOutput table to specify which units we want to decode. The rest of the decoding process is the same as before.
from pathlib import Path
import datajoint as dj
dj.config.load(
    Path("../dj_local_conf.json").absolute()
)  # load config for database connection info
SortedSpikesGroup
SortedSpikesGroup is a child table of SpikeSortingOutput in the spikesorting pipeline. It lets us group spike sorting results from multiple sources (e.g., multiple tetrode groups or intervals) into a single entry. Here we will group together the sorted units from multiple tetrode groups to use for decoding.
This table also lets us filter units by the labels assigned during curation (e.g., only include units labeled "good", or exclude units labeled "noise") by choosing a parameter set from UnitSelectionParams. When data are accessed through SortedSpikesGroup, only units with at least one label in include_labels and no label in exclude_labels are kept. If include_labels is empty, as in default_exclusion below, there is no inclusion requirement and only exclude_labels applies. We can look at the available parameter sets here:
from spyglass.spikesorting.analysis.v1.group import UnitSelectionParams
UnitSelectionParams().insert_default()
# look at the filter set we'll use here
unit_filter_params_name = "default_exclusion"
print(
    (
        UnitSelectionParams()
        & {"unit_filter_params_name": unit_filter_params_name}
    ).fetch1()
)
# look at full table
UnitSelectionParams()
{'unit_filter_params_name': 'default_exclusion', 'include_labels': [], 'exclude_labels': ['noise', 'mua']}
| unit_filter_params_name | include_labels | exclude_labels |
|---|---|---|
| all_units | =BLOB= | =BLOB= |
| default_exclusion | =BLOB= | =BLOB= |
| exclude_noise | =BLOB= | =BLOB= |
| MS2220180629 | =BLOB= | =BLOB= |
Total: 4
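The include/exclude rule described above can be sketched in plain Python. This is a minimal illustration of the rule as stated in this tutorial, not Spyglass's own filtering code. The function name keep_unit and the per-unit labels (including "accept") are hypothetical; only the parameter values are taken from the default_exclusion entry printed above.

```python
def keep_unit(unit_labels, include_labels, exclude_labels):
    """Return True if a unit passes the include/exclude label filter.

    A unit is dropped if it carries any excluded label. If include_labels is
    non-empty, the unit must also carry at least one of those labels; an empty
    include_labels list places no inclusion requirement on the unit.
    """
    labels = set(unit_labels)
    if labels & set(exclude_labels):
        return False
    if include_labels and not labels & set(include_labels):
        return False
    return True


# Parameter values from the "default_exclusion" entry shown above
params = {"include_labels": [], "exclude_labels": ["noise", "mua"]}

# Hypothetical curation labels for four units
units = {0: ["accept"], 1: ["noise"], 2: [], 3: ["mua", "accept"]}
kept = [unit_id for unit_id, labels in units.items() if keep_unit(labels, **params)]
print(kept)  # [0, 2]
```

With default_exclusion, unlabeled units survive because nothing is required for inclusion; a parameter set with include_labels=["good"] would instead drop every unit lacking that label.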
Now we can create our sorted spikes group using this unit selection parameter set.
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
from spyglass.spikesorting.analysis.v1.group import SortedSpikesGroup
import spyglass.spikesorting.v1 as sgs
nwb_copy_file_name = "mediumnwb20230802_.nwb"
sorter_keys = {
    "nwb_file_name": nwb_copy_file_name,
    "sorter": "mountainsort4",
    "curation_id": 1,
}
# check the set of sorting we'll use
(
    sgs.SpikeSortingSelection & sorter_keys
) * SpikeSortingOutput.CurationV1 & sorter_keys
| sorting_id | merge_id | recording_id | sorter | sorter_param_name | nwb_file_name | interval_list_name | curation_id |
|---|---|---|---|---|---|---|---|
| 642242ff-5f0e-45a2-bcc1-ca681f37b4a3 | 75286bf3-f876-4550-f235-321f2a7badef | 01c5b8e9-933d-4f1e-9a5d-c494276edb3a | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 0a6611b3-c593-4900-a715-66bb1396940e | 1 |
| a4b5a94d-ba41-4634-92d0-1d31c9daa913 | 143dff79-3779-c0d2-46fe-7c5040404219 | a8a1d29d-ffdf-4370-8b3d-909fef57f9d4 | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 3d782852-a56b-4a9d-89ca-be9e1a15c957 | 1 |
| 874775be-df0f-4850-8f88-59ba1bbead89 | a900c1c8-909d-e583-c377-e98c4f0deebf | 747f4eea-6df3-422b-941e-b5aaad7ec607 | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 9cf9e3cd-7115-4b59-a718-3633725d4738 | 1 |
Total: 3
Finding the merge IDs that correspond to an interpretable restriction, such as the sorter or the interval list, can require several joins with upstream tables. To simplify this, SpikeSortingOutput provides the helper method get_restricted_merge_ids(), which performs the necessary joins and returns the matching merge IDs.
# get the merge_ids for the selected sorting
spikesorting_merge_ids = SpikeSortingOutput().get_restricted_merge_ids(
    sorter_keys, restrict_by_artifact=False
)
# create a new sorted spikes group
unit_filter_params_name = "default_exclusion"
SortedSpikesGroup().create_group(
    group_name="test_group",
    nwb_file_name=nwb_copy_file_name,
    keys=[
        {"spikesorting_merge_id": merge_id}
        for merge_id in spikesorting_merge_ids
    ],
    unit_filter_params_name=unit_filter_params_name,
)
# check the new group
SortedSpikesGroup & {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "test_group",
}
| nwb_file_name | unit_filter_params_name | sorted_spikes_group_name |
|---|---|---|
| mediumnwb20230802_.nwb | default_exclusion | test_group |
Total: 1
# look at the sorting within the group we just made
SortedSpikesGroup.Units & {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": unit_filter_params_name,
}
| nwb_file_name | unit_filter_params_name | sorted_spikes_group_name | spikesorting_merge_id |
|---|---|---|---|
| mediumnwb20230802_.nwb | default_exclusion | test_group | 143dff79-3779-c0d2-46fe-7c5040404219 |
| mediumnwb20230802_.nwb | default_exclusion | test_group | 75286bf3-f876-4550-f235-321f2a7badef |
| mediumnwb20230802_.nwb | default_exclusion | test_group | a900c1c8-909d-e583-c377-e98c4f0deebf |
Total: 3
Model parameters
As before, we specify the model parameters. The only difference is that we use the ContFragSortedSpikesClassifier instead of the ContFragClusterlessClassifier.
from spyglass.decoding.v1.core import DecodingParameters
from non_local_detector.models import ContFragSortedSpikesClassifier
DecodingParameters.insert1(
    {
        "decoding_param_name": "contfrag_sorted",
        "decoding_params": ContFragSortedSpikesClassifier(),
        "decoding_kwargs": dict(),
    },
    skip_duplicates=True,
)
DecodingParameters()
| decoding_param_name | decoding_params | decoding_kwargs |
|---|---|---|
| contfrag_clusterless | =BLOB= | =BLOB= |
| contfrag_clusterless_0.5.13 | =BLOB= | =BLOB= |
| contfrag_clusterless_6track | =BLOB= | =BLOB= |
| contfrag_sorted | =BLOB= | =BLOB= |
| contfrag_sorted_0.5.13 | =BLOB= | =BLOB= |
| j1620210710_contfrag_clusterless_1D | =BLOB= | =BLOB= |
| j1620210710_test_contfrag_clusterless | =BLOB= | =BLOB= |
| MS2220180629_contfrag_sorted | =BLOB= | =BLOB= |
| ms_lineartrack_2023_contfrag_sorted | =BLOB= | =BLOB= |
| ms_lineartrack_contfrag_clusterless | =BLOB= | =BLOB= |
| ms_lineartrack_contfrag_sorted | =BLOB= | =BLOB= |
| ms_wtrack_2023_contfrag_sorted | =BLOB= | =BLOB= |
...
Total: 15
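Unlike the clusterless model, the sorted-spikes model works from each unit's spike train rather than from waveform features. To give a sense of how unit firing can be turned into a position estimate, here is a simplified sketch of the standard Poisson likelihood that sorted-spike decoders are commonly built on. It is an illustration under stated assumptions, not the non_local_detector implementation, which also estimates place fields from the encoding interval and combines the likelihood with the continuous/fragmented state dynamics. The function name and the toy place fields are hypothetical.

```python
import numpy as np


def poisson_log_likelihood(spike_counts, place_fields, dt):
    """Log-likelihood of one time bin's spike counts at every position bin.

    spike_counts: (n_units,) spike counts observed in the time bin
    place_fields: (n_units, n_position_bins) firing rates in Hz
    dt: time-bin width in seconds
    Returns shape (n_position_bins,), dropping the position-independent
    constant -sum(log k!).
    """
    # Expected count per unit and position bin; clip to avoid log(0)
    expected = np.clip(place_fields * dt, 1e-15, None)
    return spike_counts @ np.log(expected) - expected.sum(axis=0)


# Hypothetical place fields: unit 0 prefers position bin 0, unit 1 prefers bin 2
place_fields = np.array([[10.0, 1.0, 1.0],
                         [1.0, 1.0, 10.0]])
log_like = poisson_log_likelihood(np.array([3, 0]), place_fields, dt=0.1)

# With a flat prior, the posterior over position is the normalized likelihood
posterior = np.exp(log_like - log_like.max())
posterior /= posterior.sum()
print(posterior.argmax())  # 0
```

Three spikes from unit 0 and none from unit 1 favor the bin where unit 0 fires most, which is the intuition the full model builds on.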
1D Decoding
As in the clusterless notebook, we can decode 1D position if we specify the track_graph, edge_order, and edge_spacing parameters in the Environment class constructor. See the clusterless decoding tutorial for more details.
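As a toy picture of what edge_order and edge_spacing control, assume the linearized track lays the graph's edges end to end in edge_order, leaving a gap of edge_spacing between consecutive edges. A point's 1D position is then its edge's start plus its distance along that edge. The helper below is a hypothetical sketch of that layout, not code from track_linearization or non_local_detector.

```python
import numbers


def edge_start_positions(edge_lengths, edge_order, edge_spacing):
    """Start of each edge along a toy 1D (linearized) track.

    edge_lengths: dict mapping an edge (node pair) to its length, e.g. in cm
    edge_order: edges in the order they are laid end to end
    edge_spacing: a scalar gap, or one gap per pair of consecutive edges
    """
    n_gaps = len(edge_order) - 1
    if isinstance(edge_spacing, numbers.Real):
        gaps = [float(edge_spacing)] * n_gaps
    else:
        gaps = [float(g) for g in edge_spacing]
        if len(gaps) != n_gaps:
            raise ValueError("edge_spacing must have len(edge_order) - 1 values")
    starts, position = {}, 0.0
    for i, edge in enumerate(edge_order):
        starts[edge] = position
        position += edge_lengths[edge]
        if i < n_gaps:
            position += gaps[i]  # gap before the next edge
    return starts


# Hypothetical three-edge track: a center stem plus two side arms
lengths = {(0, 1): 10.0, (1, 2): 20.0, (1, 3): 10.0}
starts = edge_start_positions(lengths, [(0, 1), (1, 2), (1, 3)], edge_spacing=15)
print(starts)  # {(0, 1): 0.0, (1, 2): 25.0, (1, 3): 60.0}

# A point 4 cm along edge (1, 3) maps to linear position 60 + 4
print(starts[(1, 3)] + 4.0)  # 64.0
```

The gaps keep arms that meet at the same node from abutting in 1D, so position bins on different arms stay separated.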
Decoding
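Now we can decode position from the sorted spikes by inserting an entry into the SortedSpikesDecodingSelection table and populating SortedSpikesDecodingV1. Here we assume that the PositionGroup ("test_group") and the "test decoding interval" interval list have already been created as in the clusterless decoding tutorial.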
selection_key = {
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": "default_exclusion",
    "position_group_name": "test_group",
    "decoding_param_name": "contfrag_sorted",
    "nwb_file_name": "mediumnwb20230802_.nwb",
    "encoding_interval": "pos 0 valid times",
    "decoding_interval": "test decoding interval",
    "estimate_decoding_params": False,
}
from spyglass.decoding import SortedSpikesDecodingSelection
SortedSpikesDecodingSelection.insert1(
    selection_key,
    skip_duplicates=True,
)
from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1
SortedSpikesDecodingV1.populate(selection_key)
[12:19:30][WARNING] Spyglass: Upsampled position data, frame indices are invalid. Setting add_frame_ind=False
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
We verify that the results have been inserted into the DecodingOutput merge table.
from spyglass.decoding.decoding_merge import DecodingOutput
DecodingOutput.SortedSpikesDecodingV1 & selection_key
| merge_id | nwb_file_name | unit_filter_params_name | sorted_spikes_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
|---|---|---|---|---|---|---|---|---|
| 42e9e7f9-a6f2-9242-63ce-94228bc72743 | mediumnwb20230802_.nwb | default_exclusion | test_group | test_group | contfrag_sorted | pos 0 valid times | test decoding interval | 0 |
Total: 1
We can load the results as before:
results = (SortedSpikesDecodingV1 & selection_key).fetch_results()
results
<xarray.Dataset>
Dimensions: (state_ind: 25752, dim_0: 25752, time: 5001,
states: 2, intervals: 1, state_bins: 25752)
Coordinates:
* state_ind (state_ind) int32 0 0 0 0 0 0 0 ... 1 1 1 1 1 1
* time (time) float64 1.626e+09 ... 1.626e+09
* states (states) object 'Continuous' 'Fragmented'
environments (states) object ...
encoding_groups (states) int32 ...
* state_bins (state_bins) object MultiIndex
* state (state_bins) object 'Continuous' ... 'Fragme...
* x_position (state_bins) float64 29.02 29.02 ... 258.8
* y_position (state_bins) float64 5.828 7.811 ... 224.0
Dimensions without coordinates: dim_0, intervals
Data variables:
initial_conditions (dim_0) float64 ...
discrete_state_transitions (states, states) float64 ...
acausal_posterior (intervals, time, state_bins) float32 ...
acausal_state_probabilities (intervals, time, states) float64 ...
Attributes:
marginal_log_likelihoods: -39514.59
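The returned Dataset can be summarized with ordinary xarray operations. As one hypothetical example, the helper below reports the fraction of time bins in which a given discrete state has acausal probability above a threshold (0.5 here, an arbitrary choice). It is demonstrated on a small synthetic Dataset that reuses the variable and dimension names shown above; on the real output you would pass results instead of demo.

```python
import numpy as np
import xarray as xr


def fraction_time_in_state(results, state="Continuous", threshold=0.5, interval=0):
    """Fraction of time bins where `state` has probability above `threshold`."""
    probs = results["acausal_state_probabilities"].isel(intervals=interval)
    return float((probs.sel(states=state) > threshold).mean(dim="time"))


# Synthetic stand-in with the same variable and dimension names as `results`
demo = xr.Dataset(
    {
        "acausal_state_probabilities": (
            ("intervals", "time", "states"),
            np.array([[[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]]),
        )
    },
    coords={"time": np.arange(4) * 0.002, "states": ["Continuous", "Fragmented"]},
)
print(fraction_time_in_state(demo))  # 0.75
```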