Sorted Spikes Decoding
The mechanics of decoding with sorted spikes are largely similar to those of decoding with unsorted spikes. You should familiarize yourself with the clusterless decoding tutorial before proceeding with this one.
The elements we will need to decode with sorted spikes are:

- PositionGroup
- SortedSpikesGroup
- DecodingParameters
- encoding_interval
- decoding_interval
This time, instead of extracting waveform features, we can proceed directly from the SpikeSortingOutput table to specify which units we want to decode. The rest of the decoding process is the same as before.
from pathlib import Path

import datajoint as dj

dj.config.load(
    Path("../dj_local_conf.json").absolute()
)  # load config for database connection info
SortedSpikesGroup

SortedSpikesGroup is a child table of SpikeSortingOutput in the spikesorting pipeline. It allows us to group the spike sorting results from multiple sources (e.g. multiple tetrode groups or intervals) into a single entry. Here we will group together the spiking of multiple tetrode groups to use for decoding.

This table also allows us to filter units by their annotation labels from curation (e.g. only include units labeled "good", or exclude units labeled "noise") by defining parameters in UnitSelectionParams. When accessing data through SortedSpikesGroup, the table will include only units with at least one label in include_labels and no labels in exclude_labels. We can look at those here:
from spyglass.spikesorting.analysis.v1.group import UnitSelectionParams
UnitSelectionParams().insert_default()
# look at the filter set we'll use here
unit_filter_params_name = "default_exclusion"
print(
    (
        UnitSelectionParams()
        & {"unit_filter_params_name": unit_filter_params_name}
    ).fetch1()
)
# look at full table
UnitSelectionParams()
{'unit_filter_params_name': 'default_exclusion', 'include_labels': [], 'exclude_labels': ['noise', 'mua']}
| unit_filter_params_name | include_labels | exclude_labels |
| --- | --- | --- |
| all_units | =BLOB= | =BLOB= |
| default_exclusion | =BLOB= | =BLOB= |
| exclude_noise | =BLOB= | =BLOB= |
| MS2220180629 | =BLOB= | =BLOB= |

Total: 4
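If none of the stored filter sets fits, a new one can be inserted directly. Below is a minimal sketch; the entry name "good_only" and its label lists are hypothetical, chosen only to illustrate the include/exclude semantics described above.

# hypothetical filter set: keep only units labeled "good",
# and drop any unit also labeled "noise" or "mua"
UnitSelectionParams().insert1(
    {
        "unit_filter_params_name": "good_only",  # hypothetical name
        "include_labels": ["good"],
        "exclude_labels": ["noise", "mua"],
    },
    skip_duplicates=True,
)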
Now we can make our sorted spikes group with this unit selection parameter:
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
import spyglass.spikesorting.v1 as sgs
nwb_copy_file_name = "mediumnwb20230802_.nwb"
sorter_keys = {
    "nwb_file_name": nwb_copy_file_name,
    "sorter": "mountainsort4",
    "curation_id": 1,
}
# check the set of sortings we'll use
(
    sgs.SpikeSortingSelection & sorter_keys
) * SpikeSortingOutput.CurationV1 & sorter_keys
| sorting_id | merge_id | recording_id | sorter | sorter_param_name | nwb_file_name | interval_list_name | curation_id |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 642242ff-5f0e-45a2-bcc1-ca681f37b4a3 | 75286bf3-f876-4550-f235-321f2a7badef | 01c5b8e9-933d-4f1e-9a5d-c494276edb3a | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 0a6611b3-c593-4900-a715-66bb1396940e | 1 |
| a4b5a94d-ba41-4634-92d0-1d31c9daa913 | 143dff79-3779-c0d2-46fe-7c5040404219 | a8a1d29d-ffdf-4370-8b3d-909fef57f9d4 | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 3d782852-a56b-4a9d-89ca-be9e1a15c957 | 1 |
| 874775be-df0f-4850-8f88-59ba1bbead89 | a900c1c8-909d-e583-c377-e98c4f0deebf | 747f4eea-6df3-422b-941e-b5aaad7ec607 | mountainsort4 | franklab_tetrode_hippocampus_30KHz | mediumnwb20230802_.nwb | 9cf9e3cd-7115-4b59-a718-3633725d4738 | 1 |

Total: 3
Finding the merge IDs that correspond to an interpretable restriction (e.g. on the sorter or the interval list name) can require several join steps with upstream tables. To simplify this process, we can use the included helper function SpikeSortingOutput().get_restricted_merge_ids() to perform the necessary joins and return the matching merge IDs.
# get the merge_ids for the selected sorting
spikesorting_merge_ids = SpikeSortingOutput().get_restricted_merge_ids(
    sorter_keys, restrict_by_artifact=False
)

# create a new sorted spikes group
unit_filter_params_name = "default_exclusion"
SortedSpikesGroup().create_group(
    group_name="test_group",
    nwb_file_name=nwb_copy_file_name,
    keys=[
        {"spikesorting_merge_id": merge_id}
        for merge_id in spikesorting_merge_ids
    ],
    unit_filter_params_name=unit_filter_params_name,
)
# check the new group
SortedSpikesGroup & {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "test_group",
}
nwb_file_name name of the NWB file | unit_filter_params_name | sorted_spikes_group_name |
---|---|---|
mediumnwb20230802_.nwb | default_exclusion | test_group |
Total: 1
# look at the sortings within the group we just made
SortedSpikesGroup.Units & {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": unit_filter_params_name,
}
| nwb_file_name | unit_filter_params_name | sorted_spikes_group_name | spikesorting_merge_id |
| --- | --- | --- | --- |
| mediumnwb20230802_.nwb | default_exclusion | test_group | 143dff79-3779-c0d2-46fe-7c5040404219 |
| mediumnwb20230802_.nwb | default_exclusion | test_group | 75286bf3-f876-4550-f235-321f2a7badef |
| mediumnwb20230802_.nwb | default_exclusion | test_group | a900c1c8-909d-e583-c377-e98c4f0deebf |

Total: 3
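As a quick sanity check, we can pull the spike times of the filtered units through the group. This is a minimal sketch that assumes the SortedSpikesGroup.fetch_spike_data helper available in recent Spyglass versions.

# pull spike times for each unit that passed the label filter
# (fetch_spike_data is assumed available in this Spyglass version)
group_key = {
    "nwb_file_name": nwb_copy_file_name,
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": unit_filter_params_name,
}
spike_times = SortedSpikesGroup.fetch_spike_data(group_key)
print(f"{len(spike_times)} units passed the unit filter")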
Model parameters
As before, we can specify the model parameters. The only difference is that we will use the ContFragSortedSpikesClassifier instead of the ContFragClusterlessClassifier.
from spyglass.decoding.v1.core import DecodingParameters
from non_local_detector.models import ContFragSortedSpikesClassifier

DecodingParameters.insert1(
    {
        "decoding_param_name": "contfrag_sorted",
        "decoding_params": ContFragSortedSpikesClassifier(),
        "decoding_kwargs": dict(),
    },
    skip_duplicates=True,
)
DecodingParameters()
| decoding_param_name | decoding_params | decoding_kwargs |
| --- | --- | --- |
| contfrag_clusterless | =BLOB= | =BLOB= |
| contfrag_clusterless_0.5.13 | =BLOB= | =BLOB= |
| contfrag_clusterless_6track | =BLOB= | =BLOB= |
| contfrag_sorted | =BLOB= | =BLOB= |
| contfrag_sorted_0.5.13 | =BLOB= | =BLOB= |
| j1620210710_contfrag_clusterless_1D | =BLOB= | =BLOB= |
| j1620210710_test_contfrag_clusterless | =BLOB= | =BLOB= |
| MS2220180629_contfrag_sorted | =BLOB= | =BLOB= |
| ms_lineartrack_2023_contfrag_sorted | =BLOB= | =BLOB= |
| ms_lineartrack_contfrag_clusterless | =BLOB= | =BLOB= |
| ms_lineartrack_contfrag_sorted | =BLOB= | =BLOB= |
| ms_wtrack_2023_contfrag_sorted | =BLOB= | =BLOB= |
| ... | ... | ... |

Total: 15
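To confirm what was stored, we can fetch the entry back with a standard DataJoint query; this small sketch simply prints the stored initialization parameters.

# fetch the stored model parameters back to confirm the insert
stored_params = (
    DecodingParameters & {"decoding_param_name": "contfrag_sorted"}
).fetch1("decoding_params")
print(stored_params)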
1D Decoding

As in the clusterless notebook, we can decode 1D position if we specify the track_graph, edge_order, and edge_spacing parameters in the Environment class constructor. See the clusterless decoding tutorial for more details.
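For illustration, here is a minimal sketch of such a 1D setup. The two-node track and its coordinates are hypothetical placeholders (a real session would use the session's track geometry, as in the clusterless tutorial), and the import paths assume the track_linearization and non_local_detector packages installed alongside Spyglass.

from non_local_detector.environment import Environment
from track_linearization import make_track_graph

# hypothetical linear track: a single 100 cm segment between two nodes
track_graph = make_track_graph(
    node_positions=[(0.0, 0.0), (100.0, 0.0)],
    edges=[(0, 1)],
)

ContFragSortedSpikesClassifier(
    environments=Environment(
        track_graph=track_graph,
        edge_order=[(0, 1)],  # order in which edges are concatenated when linearizing
        edge_spacing=[],  # gaps between consecutive edges; none for a single edge
    )
)

A classifier built this way could then be inserted into DecodingParameters under its own decoding_param_name, exactly as above.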
Decoding

Now we can decode position from the sorted spikes using the SortedSpikesDecodingSelection table. Here we assume that PositionGroup has been specified as in the clusterless decoding tutorial.
selection_key = {
    "sorted_spikes_group_name": "test_group",
    "unit_filter_params_name": "default_exclusion",
    "position_group_name": "test_group",
    "decoding_param_name": "contfrag_sorted",
    "nwb_file_name": "mediumnwb20230802_.nwb",
    "encoding_interval": "pos 0 valid times",
    "decoding_interval": "test decoding interval",
    "estimate_decoding_params": False,
}
from spyglass.decoding import SortedSpikesDecodingSelection

SortedSpikesDecodingSelection.insert1(
    selection_key,
    skip_duplicates=True,
)
from spyglass.decoding.v1.sorted_spikes import SortedSpikesDecodingV1
SortedSpikesDecodingV1.populate(selection_key)
[12:19:30][WARNING] Spyglass: Upsampled position data, frame indices are invalid. Setting add_frame_ind=False
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Encoding models: 0%| | 0/54 [00:00<?, ?cell/s]
Non-Local Likelihood: 0%| | 0/54 [00:00<?, ?cell/s]
We verify that the results have been inserted into the DecodingOutput merge table.
from spyglass.decoding.decoding_merge import DecodingOutput
DecodingOutput.SortedSpikesDecodingV1 & selection_key
| merge_id | nwb_file_name | unit_filter_params_name | sorted_spikes_group_name | position_group_name | decoding_param_name | encoding_interval | decoding_interval | estimate_decoding_params |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 42e9e7f9-a6f2-9242-63ce-94228bc72743 | mediumnwb20230802_.nwb | default_exclusion | test_group | test_group | contfrag_sorted | pos 0 valid times | test decoding interval | 0 |

Total: 1
We can load the results as before:
results = (SortedSpikesDecodingV1 & selection_key).fetch_results()
results
<xarray.Dataset>
Dimensions:                       (state_ind: 25752, dim_0: 25752, time: 5001,
                                   states: 2, intervals: 1, state_bins: 25752)
Coordinates:
  * state_ind                     (state_ind) int32 0 0 0 0 0 0 0 ... 1 1 1 1 1 1
  * time                          (time) float64 1.626e+09 ... 1.626e+09
  * states                        (states) object 'Continuous' 'Fragmented'
    environments                  (states) object ...
    encoding_groups               (states) int32 ...
  * state_bins                    (state_bins) object MultiIndex
  * state                         (state_bins) object 'Continuous' ... 'Fragme...
  * x_position                    (state_bins) float64 29.02 29.02 ... 258.8
  * y_position                    (state_bins) float64 5.828 7.811 ... 224.0
Dimensions without coordinates: dim_0, intervals
Data variables:
    initial_conditions            (dim_0) float64 ...
    discrete_state_transitions    (states, states) float64 ...
    acausal_posterior             (intervals, time, state_bins) float32 ...
    acausal_state_probabilities   (intervals, time, states) float64 ...
Attributes:
    marginal_log_likelihoods:  -39514.59
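From here, the posterior can be inspected directly with xarray. As a minimal sketch (assuming matplotlib is installed), we can plot the probability of the 'Continuous' and 'Fragmented' states over time, using the dimension names shown in the repr above:

import matplotlib.pyplot as plt

# probability of each discrete state over time (single decoding interval)
state_prob = results.isel(intervals=0)["acausal_state_probabilities"]
state_prob.plot.line(x="time", hue="states")
plt.show()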