Skip to content

sorted_spikes.py

Pipeline for decoding the animal's mental position and some category of interest from clustered spike times. See [1] for details.

References

[1] Denovellis, E. L. et al. Hippocampal replay of experience at real-world speeds. eLife 10, e64505 (2021).

SortedSpikesDecodingV1

Bases: SpyglassMixin, Computed

Source code in src/spyglass/decoding/v1/sorted_spikes.py
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
@schema
class SortedSpikesDecodingV1(SpyglassMixin, dj.Computed):
    definition = """
    -> SortedSpikesDecodingSelection
    ---
    results_path: filepath@analysis # path to the results file
    classifier_path: filepath@analysis # path to the classifier file
    """

    def make(self, key):
        """Populate the decoding model.

        1. Fetches parameters and position data from DecodingParameters and
            PositionGroup tables.
        2. Decomposes intervals into encoding and decoding.
        3. Optionally estimates decoding parameters, otherwise uses the provided
            parameters.
        4. Uses SortedSpikesDetector from non_local_detector package to decode
            the animal's mental position, including initial and discrete state
            transition information.
        5. Optionally includes the discrete transition coefficients.
        6. Saves the results and model to disk in the analysis directory, under
            the nwb file name's folder.
        7. Inserts the results and model paths into SortedSpikesDecodingV1 and
            DecodingOutput tables.

        Parameters
        ----------
        key : dict
            Selection key to populate. ``decoding_param_name``,
            ``nwb_file_name``, ``encoding_interval``, ``decoding_interval`` and
            ``estimate_decoding_params`` are read from it; ``results_path`` and
            ``classifier_path`` are added to it before it is inserted.

        Raises
        ------
        ValueError
            If parameters are not being estimated and none of the decoding
            intervals contains any position samples, so there is nothing to
            decode.
        """
        # Untouched copy for the DecodingOutput insert; `key` itself gains the
        # output paths below.
        orig_key = copy.deepcopy(key)

        # Get model parameters
        model_params = (
            DecodingParameters
            & {"decoding_param_name": key["decoding_param_name"]}
        ).fetch1()
        decoding_params, decoding_kwargs = (
            model_params["decoding_params"],
            model_params["decoding_kwargs"],
        )
        decoding_kwargs = decoding_kwargs or {}

        # Get position data
        (
            position_info,
            position_variable_names,
        ) = self.fetch_position_info(key)

        # Get the spike times for the selected units. Don't need to filter by
        # interval since the non_local_detector code will do that

        spike_times = self.fetch_spike_data(key, filter_by_interval=False)

        # Get the encoding and decoding intervals
        encoding_interval = (
            IntervalList
            & {
                "nwb_file_name": key["nwb_file_name"],
                "interval_list_name": key["encoding_interval"],
            }
        ).fetch1("valid_times")

        # A sample is used for training when it lies in an encoding interval
        # (bounds inclusive) and none of its position variables is NaN.
        is_training = np.zeros(len(position_info), dtype=bool)
        for interval_start, interval_end in encoding_interval:
            is_training[
                np.logical_and(
                    position_info.index >= interval_start,
                    position_info.index <= interval_end,
                )
            ] = True
        is_training[
            position_info[position_variable_names].isna().values.max(axis=1)
        ] = False

        # A user-supplied mask in decoding_kwargs takes precedence.
        if "is_training" not in decoding_kwargs:
            decoding_kwargs["is_training"] = is_training

        decoding_interval = (
            IntervalList
            & {
                "nwb_file_name": key["nwb_file_name"],
                "interval_list_name": key["decoding_interval"],
            }
        ).fetch1("valid_times")

        # Decode
        classifier = SortedSpikesDetector(**decoding_params)

        if key["estimate_decoding_params"]:
            # If estimating parameters, then we need to treat times outside the
            # decoding interval as missing. This means that times outside the
            # decoding interval will not use the spiking data. A better
            # approach would be to treat the intervals as multiple sequences
            # (see
            # https://en.wikipedia.org/wiki/Baum%E2%80%93Welch_algorithm#Multiple_sequences)

            is_missing = np.ones(len(position_info), dtype=bool)
            for interval_start, interval_end in decoding_interval:
                is_missing[
                    np.logical_and(
                        position_info.index >= interval_start,
                        position_info.index <= interval_end,
                    )
                ] = False
            if "is_missing" not in decoding_kwargs:
                decoding_kwargs["is_missing"] = is_missing
            results = classifier.estimate_parameters(
                position_time=position_info.index.to_numpy(),
                position=position_info[position_variable_names].to_numpy(),
                spike_times=spike_times,
                time=position_info.index.to_numpy(),
                **decoding_kwargs,
            )
        else:
            VALID_FIT_KWARGS = [
                "is_training",
                "encoding_group_labels",
                "environment_labels",
                "discrete_transition_covariate_data",
            ]

            # Only forward the kwargs that `fit` is given here. The loop
            # variable is deliberately not called `key`, to avoid shadowing
            # the method argument.
            fit_kwargs = {
                kwarg_name: kwarg_value
                for kwarg_name, kwarg_value in decoding_kwargs.items()
                if kwarg_name in VALID_FIT_KWARGS
            }
            classifier.fit(
                position_time=position_info.index.to_numpy(),
                position=position_info[position_variable_names].to_numpy(),
                spike_times=spike_times,
                **fit_kwargs,
            )
            VALID_PREDICT_KWARGS = [
                "is_missing",
                "discrete_transition_covariate_data",
                "return_causal_posterior",
            ]
            predict_kwargs = {
                kwarg_name: kwarg_value
                for kwarg_name, kwarg_value in decoding_kwargs.items()
                if kwarg_name in VALID_PREDICT_KWARGS
            }

            # We treat each decoding interval as a separate sequence
            results = []
            for interval_start, interval_end in decoding_interval:
                # Slice once and reuse for both the times and the positions.
                interval_position = position_info.loc[
                    interval_start:interval_end
                ]
                interval_time = interval_position.index.to_numpy()

                if interval_time.size == 0:
                    logger.warning(
                        f"Interval {interval_start}:{interval_end} is empty"
                    )
                    continue
                results.append(
                    classifier.predict(
                        position_time=interval_time,
                        position=interval_position[
                            position_variable_names
                        ].to_numpy(),
                        spike_times=spike_times,
                        time=interval_time,
                        **predict_kwargs,
                    )
                )

            # Fail with an actionable message instead of the generic error
            # xr.concat raises when handed an empty list.
            if not results:
                raise ValueError(
                    "No position samples fall inside any decoding interval "
                    f"of '{key['decoding_interval']}'; nothing to decode."
                )
            results = xr.concat(results, dim="intervals")

        # Save discrete transition and initial conditions
        results["initial_conditions"] = xr.DataArray(
            classifier.initial_conditions_,
            name="initial_conditions",
        )
        results["discrete_state_transitions"] = xr.DataArray(
            classifier.discrete_state_transitions_,
            dims=("states", "states"),
            name="discrete_state_transitions",
        )
        # The coefficients attribute may be absent or None; only store it when
        # it is set.
        if (
            vars(classifier).get("discrete_transition_coefficients_")
            is not None
        ):
            results["discrete_transition_coefficients"] = (
                classifier.discrete_transition_coefficients_
            )

        # Insert results
        # in future use https://github.com/rly/ndx-xarray and analysis nwb file?

        nwb_file_name = key["nwb_file_name"].replace("_.nwb", "")

        # Generate a unique path for the results file
        path_exists = True
        while path_exists:
            results_path = (
                Path(config["SPYGLASS_ANALYSIS_DIR"])
                / nwb_file_name
                / f"{nwb_file_name}_{str(uuid.uuid4())}.nc"
            )
            path_exists = results_path.exists()
        classifier.save_results(
            results,
            results_path,
        )
        key["results_path"] = results_path

        # The model is saved next to the results, same stem, ".pkl" suffix.
        classifier_path = results_path.with_suffix(".pkl")
        classifier.save_model(classifier_path)
        key["classifier_path"] = classifier_path

        self.insert1(key)

        from spyglass.decoding.decoding_merge import DecodingOutput

        DecodingOutput.insert1(orig_key, skip_duplicates=True)

    def fetch_results(self) -> xr.Dataset:
        """Load the decoding results stored at this entry's ``results_path``.

        Returns
        -------
        xr.Dataset
            The decoding results (posteriors, etc.)
        """
        results_path = self.fetch1("results_path")
        return SortedSpikesDetector.load_results(results_path)

    def fetch_model(self):
        """Load the decoding model stored at this entry's ``classifier_path``."""
        classifier_path = self.fetch1("classifier_path")
        return SortedSpikesDetector.load_model(classifier_path)

    @classmethod
    def fetch_environments(cls, key):
        """Build the environments used by the decoding model.

        A new ``SortedSpikesDetector`` is created from the stored decoding
        parameters and its environments are initialized from the position
        data; no saved model is loaded.

        Parameters
        ----------
        key : dict
            The decoding selection key

        Returns
        -------
        list
            The ``environments`` attribute of the initialized detector
        """
        key = cls.get_fully_defined_key(
            key, required_fields=["decoding_param_name"]
        )

        param_restriction = {"decoding_param_name": key["decoding_param_name"]}
        model_params = (DecodingParameters & param_restriction).fetch1()
        decoding_params = model_params["decoding_params"]
        decoding_kwargs = model_params["decoding_kwargs"]
        if decoding_kwargs is None:
            # No extra kwargs stored: fall back to default environment labels.
            decoding_kwargs = {}

        position_info, position_variable_names = (
            SortedSpikesDecodingV1.fetch_position_info(key)
        )

        classifier = SortedSpikesDetector(**decoding_params)
        position = position_info[position_variable_names].to_numpy()
        environment_labels = decoding_kwargs.get("environment_labels", None)
        classifier.initialize_environments(
            position=position,
            environment_labels=environment_labels,
        )

        return classifier.environments

    @classmethod
    def fetch_position_info(cls, key):
        """Fetch the position data of the model's position group.

        The data are requested between the minimum and maximum times given by
        ``_get_interval_range(key)``.

        Parameters
        ----------
        key : dict
            The decoding selection key

        Returns
        -------
        Tuple[pd.DataFrame, List[str]]
            The position information and the names of the position variables
        """
        key = cls.get_fully_defined_key(
            key,
            required_fields=[
                "position_group_name",
                "nwb_file_name",
                "encoding_interval",
                "decoding_interval",
            ],
        )

        # Restrict by the group's identifying fields only, not the whole key.
        restriction = dict(
            position_group_name=key["position_group_name"],
            nwb_file_name=key["nwb_file_name"],
        )
        min_time, max_time = _get_interval_range(key)

        position_group = PositionGroup & restriction
        position_info, position_variable_names = (
            position_group.fetch_position_info(
                min_time=min_time, max_time=max_time
            )
        )
        return position_info, position_variable_names

    @classmethod
    def fetch_linear_position_info(cls, key):
        """Fetch the position data and linearize it along the track graph.

        The first environment returned by ``fetch_environments`` supplies the
        track graph, edge order and edge spacing.

        Parameters
        ----------
        key : dict
            The decoding selection key

        Returns
        -------
        pd.DataFrame
            The linearized position columns followed by the original position
            columns, restricted to the model's interval range
        """
        required_fields = [
            "position_group_name",
            "nwb_file_name",
            "encoding_interval",
            "decoding_interval",
        ]
        key = cls.get_fully_defined_key(key, required_fields=required_fields)

        environment = SortedSpikesDecodingV1.fetch_environments(key)[0]
        position_df = SortedSpikesDecodingV1.fetch_position_info(key)[0]

        position_group = PositionGroup & key
        position_variable_names = position_group.fetch1("position_variables")
        position = np.asarray(position_df[position_variable_names])

        linear_position_df = get_linearized_position(
            position=position,
            track_graph=environment.track_graph,
            edge_order=environment.edge_order,
            edge_spacing=environment.edge_spacing,
        )
        min_time, max_time = _get_interval_range(key)

        # Give the linearized frame the time index of the position data so
        # the two can be joined column-wise and sliced by time.
        linear_position_df = linear_position_df.set_index(position_df.index)
        combined_df = pd.concat([linear_position_df, position_df], axis=1)
        return combined_df.loc[min_time:max_time]

    @classmethod
    def fetch_spike_data(
        cls,
        key,
        filter_by_interval=True,
        time_slice=None,
        return_unit_ids=False,
    ) -> Union[list[np.ndarray], tuple[list[np.ndarray], list[dict]]]:
        """Fetch the spike times for the decoding model

        Parameters
        ----------
        key : dict
            The decoding selection key
        filter_by_interval : bool, optional
            Whether to filter for spike times in the model interval,
            by default True. When False, all spike times are returned and
            `time_slice` is ignored.
        time_slice : Slice, optional
            User provided slice of time to restrict spikes to, by default None.
            Both bounds are inclusive; a bound of None leaves that side
            unbounded. When None, the range from `_get_interval_range(key)`
            is used.
        return_unit_ids : bool, optional
            if True, return the unit_ids along with the spike times, by default
            False. Unit ids are defined as a list of dictionaries with keys
            'spikesorting_merge_id' and 'unit_number'

        Returns
        -------
        list[np.ndarray]
            List of spike times for each unit in the model's spike group
        list[dict]
            The unit ids. Only returned (as the second element of a tuple)
            when `return_unit_ids` is True.
        """
        key = cls.get_fully_defined_key(
            key,
            required_fields=[
                "encoding_interval",
                "decoding_interval",
            ],
        )

        spike_times, unit_ids = SortedSpikesGroup.fetch_spike_data(
            key, return_unit_ids=True
        )

        if filter_by_interval:
            if time_slice is None:
                min_time, max_time = _get_interval_range(key)
            else:
                # An open-ended slice (e.g. ``slice(None, 10)``) means "no
                # bound on that side"; comparing an array with None would
                # otherwise raise.
                min_time = (
                    -np.inf if time_slice.start is None else time_slice.start
                )
                max_time = (
                    np.inf if time_slice.stop is None else time_slice.stop
                )

            spike_times = [
                unit_spike_times[
                    np.logical_and(
                        unit_spike_times >= min_time,
                        unit_spike_times <= max_time,
                    )
                ]
                for unit_spike_times in spike_times
            ]

        # Honour `return_unit_ids` on every path, including the unfiltered one
        # (previously the unfiltered path always returned spike times only).
        if return_unit_ids:
            return spike_times, unit_ids
        return spike_times

    def spike_times_sorted_by_place_field_peak(self, time_slice=None):
        """Return each unit's spike times, ordered by place field peak.

        For every encoding model in the fitted classifier, units are ordered
        by the index of the maximum (NaNs ignored) of their place field.

        Parameters
        ----------
        time_slice : Slice, optional
            time range to limit returned spikes to (bounds inclusive), by
            default None, which applies no additional restriction

        Returns
        -------
        dict
            Maps each key of ``classifier.encoding_model_`` to a list of spike
            time arrays, one per unit, in sorted order
        """
        if time_slice is None:
            time_slice = slice(-np.inf, np.inf)
        start, stop = time_slice.start, time_slice.stop

        spike_times = self.fetch_spike_data(self.fetch1())
        classifier = self.fetch_model()

        def _restrict(unit_spike_times):
            """Keep only the spikes with start <= t <= stop."""
            in_slice = (unit_spike_times >= start) & (unit_spike_times <= stop)
            return unit_spike_times[in_slice]

        sorted_spike_times = {}
        for encoding_model in classifier.encoding_model_:
            model = classifier.encoding_model_[encoding_model]
            place_fields = np.asarray(model["place_fields"])
            peak_index = np.nanargmax(place_fields, axis=1).squeeze()
            unit_order = np.argsort(peak_index)
            sorted_spike_times[encoding_model] = [
                _restrict(spike_times[unit_ind]) for unit_ind in unit_order
            ]
        return sorted_spike_times

    def get_orientation_col(self, df):
        """Return the name of the orientation column to use for ``df``.

        Gives ``"orientation"`` when ``df.columns`` contains it and
        ``"head_orientation"`` otherwise; the fallback name is not checked
        against the columns.
        """
        if "orientation" in df.columns:
            return "orientation"
        return "head_orientation"

    def get_ahead_behind_distance(self, track_graph=None, time_slice=None):
        """Get relative decoded position from the animal's actual position

        Parameters
        ----------
        track_graph : TrackGraph, optional
            environment track graph to project position on, by default None.
            When None, the track graph of the fitted model's first environment
            is used; if that is also None the 2D distance is computed instead.
        time_slice : Slice, optional
            time interval to restrict to, by default None (no restriction)

        Returns
        -------
        distance_metrics : np.ndarray
            Information about the distance of the animal to the mental position.
        """
        # TODO: store in table

        if time_slice is None:
            time_slice = slice(-np.inf, np.inf)

        classifier = self.fetch_model()
        # Sum the acausal posterior over states so it depends on position only.
        posterior = (
            self.fetch_results()
            .acausal_posterior.sel(time=time_slice)
            .squeeze()
            .unstack("state_bins")
            .sum("state")
        )

        if track_graph is None:
            track_graph = classifier.environments[0].track_graph

        if track_graph is not None:
            linear_position_info = self.fetch_linear_position_info(
                self.fetch1("KEY")
            ).loc[time_slice]

            orientation_name = self.get_orientation_col(linear_position_info)

            traj_data = analysis.get_trajectory_data(
                posterior=posterior,
                track_graph=track_graph,
                decoder=classifier,
                actual_projected_position=linear_position_info[
                    ["projected_x_position", "projected_y_position"]
                ],
                track_segment_id=linear_position_info["track_segment_id"],
                actual_orientation=linear_position_info[orientation_name],
            )

            return analysis.get_ahead_behind_distance(track_graph, *traj_data)
        else:
            # fetch_position_info returns (DataFrame, variable names); take the
            # DataFrame before slicing by time. Calling .loc on the tuple
            # itself raises AttributeError.
            position_info, _ = self.fetch_position_info(self.fetch1("KEY"))
            position_info = position_info.loc[time_slice]
            map_position = analysis.maximum_a_posteriori_estimate(posterior)

            orientation_name = self.get_orientation_col(position_info)

            position_variable_names = (
                PositionGroup & self.fetch1("KEY")
            ).fetch1("position_variables")

            return analysis.get_ahead_behind_distance2D(
                position_info[position_variable_names].to_numpy(),
                position_info[orientation_name].to_numpy(),
                map_position,
                classifier.environments[0].track_graphDD,
            )

make(key)

Populate the decoding model.

  1. Fetches parameters and position data from DecodingParameters and PositionGroup tables.
  2. Decomposes intervals into encoding and decoding.
  3. Optionally estimates decoding parameters, otherwise uses the provided parameters.
  4. Uses SortedSpikesDetector from non_local_detector package to decode the animal's mental position, including initial and discrete state transition information.
  5. Optionally includes the discrete transition coefficients.
  6. Saves the results and model to disk in the analysis directory, under the nwb file name's folder.
  7. Inserts the results and model paths into SortedSpikesDecodingV1 and DecodingOutput tables.
Source code in src/spyglass/decoding/v1/sorted_spikes.py
def make(self, key):
    """Populate the decoding model.

    1. Fetches parameters and position data from DecodingParameters and
        PositionGroup tables.
    2. Decomposes intervals into encoding and decoding.
    3. Optionally estimates decoding parameters, otherwise uses the provided
        parameters.
    4. Uses SortedSpikesDetector from non_local_detector package to decode
        the animal's mental position, including initial and discrete state
        transition information.
    5. Optionally includes the discrete transition coefficients.
    6. Saves the results and model to disk in the analysis directory, under
        the nwb file name's folder.
    7. Inserts the results and model paths into SortedSpikesDecodingV1 and
        DecodingOutput tables.

    Parameters
    ----------
    key : dict
        Selection key to populate. ``decoding_param_name``,
        ``nwb_file_name``, ``encoding_interval``, ``decoding_interval`` and
        ``estimate_decoding_params`` are read from it; ``results_path`` and
        ``classifier_path`` are added to it before it is inserted.

    Raises
    ------
    ValueError
        If parameters are not being estimated and none of the decoding
        intervals contains any position samples, so there is nothing to
        decode.
    """
    # Untouched copy for the DecodingOutput insert; `key` itself gains the
    # output paths below.
    orig_key = copy.deepcopy(key)

    # Get model parameters
    model_params = (
        DecodingParameters
        & {"decoding_param_name": key["decoding_param_name"]}
    ).fetch1()
    decoding_params, decoding_kwargs = (
        model_params["decoding_params"],
        model_params["decoding_kwargs"],
    )
    decoding_kwargs = decoding_kwargs or {}

    # Get position data
    (
        position_info,
        position_variable_names,
    ) = self.fetch_position_info(key)

    # Get the spike times for the selected units. Don't need to filter by
    # interval since the non_local_detector code will do that

    spike_times = self.fetch_spike_data(key, filter_by_interval=False)

    # Get the encoding and decoding intervals
    encoding_interval = (
        IntervalList
        & {
            "nwb_file_name": key["nwb_file_name"],
            "interval_list_name": key["encoding_interval"],
        }
    ).fetch1("valid_times")

    # A sample is used for training when it lies in an encoding interval
    # (bounds inclusive) and none of its position variables is NaN.
    is_training = np.zeros(len(position_info), dtype=bool)
    for interval_start, interval_end in encoding_interval:
        is_training[
            np.logical_and(
                position_info.index >= interval_start,
                position_info.index <= interval_end,
            )
        ] = True
    is_training[
        position_info[position_variable_names].isna().values.max(axis=1)
    ] = False

    # A user-supplied mask in decoding_kwargs takes precedence.
    if "is_training" not in decoding_kwargs:
        decoding_kwargs["is_training"] = is_training

    decoding_interval = (
        IntervalList
        & {
            "nwb_file_name": key["nwb_file_name"],
            "interval_list_name": key["decoding_interval"],
        }
    ).fetch1("valid_times")

    # Decode
    classifier = SortedSpikesDetector(**decoding_params)

    if key["estimate_decoding_params"]:
        # If estimating parameters, then we need to treat times outside the
        # decoding interval as missing. This means that times outside the
        # decoding interval will not use the spiking data. A better
        # approach would be to treat the intervals as multiple sequences
        # (see
        # https://en.wikipedia.org/wiki/Baum%E2%80%93Welch_algorithm#Multiple_sequences)

        is_missing = np.ones(len(position_info), dtype=bool)
        for interval_start, interval_end in decoding_interval:
            is_missing[
                np.logical_and(
                    position_info.index >= interval_start,
                    position_info.index <= interval_end,
                )
            ] = False
        if "is_missing" not in decoding_kwargs:
            decoding_kwargs["is_missing"] = is_missing
        results = classifier.estimate_parameters(
            position_time=position_info.index.to_numpy(),
            position=position_info[position_variable_names].to_numpy(),
            spike_times=spike_times,
            time=position_info.index.to_numpy(),
            **decoding_kwargs,
        )
    else:
        VALID_FIT_KWARGS = [
            "is_training",
            "encoding_group_labels",
            "environment_labels",
            "discrete_transition_covariate_data",
        ]

        # Only forward the kwargs that `fit` is given here. The loop
        # variable is deliberately not called `key`, to avoid shadowing
        # the method argument.
        fit_kwargs = {
            kwarg_name: kwarg_value
            for kwarg_name, kwarg_value in decoding_kwargs.items()
            if kwarg_name in VALID_FIT_KWARGS
        }
        classifier.fit(
            position_time=position_info.index.to_numpy(),
            position=position_info[position_variable_names].to_numpy(),
            spike_times=spike_times,
            **fit_kwargs,
        )
        VALID_PREDICT_KWARGS = [
            "is_missing",
            "discrete_transition_covariate_data",
            "return_causal_posterior",
        ]
        predict_kwargs = {
            kwarg_name: kwarg_value
            for kwarg_name, kwarg_value in decoding_kwargs.items()
            if kwarg_name in VALID_PREDICT_KWARGS
        }

        # We treat each decoding interval as a separate sequence
        results = []
        for interval_start, interval_end in decoding_interval:
            # Slice once and reuse for both the times and the positions.
            interval_position = position_info.loc[interval_start:interval_end]
            interval_time = interval_position.index.to_numpy()

            if interval_time.size == 0:
                logger.warning(
                    f"Interval {interval_start}:{interval_end} is empty"
                )
                continue
            results.append(
                classifier.predict(
                    position_time=interval_time,
                    position=interval_position[
                        position_variable_names
                    ].to_numpy(),
                    spike_times=spike_times,
                    time=interval_time,
                    **predict_kwargs,
                )
            )

        # Fail with an actionable message instead of the generic error
        # xr.concat raises when handed an empty list.
        if not results:
            raise ValueError(
                "No position samples fall inside any decoding interval "
                f"of '{key['decoding_interval']}'; nothing to decode."
            )
        results = xr.concat(results, dim="intervals")

    # Save discrete transition and initial conditions
    results["initial_conditions"] = xr.DataArray(
        classifier.initial_conditions_,
        name="initial_conditions",
    )
    results["discrete_state_transitions"] = xr.DataArray(
        classifier.discrete_state_transitions_,
        dims=("states", "states"),
        name="discrete_state_transitions",
    )
    # The coefficients attribute may be absent or None; only store it when
    # it is set.
    if (
        vars(classifier).get("discrete_transition_coefficients_")
        is not None
    ):
        results["discrete_transition_coefficients"] = (
            classifier.discrete_transition_coefficients_
        )

    # Insert results
    # in future use https://github.com/rly/ndx-xarray and analysis nwb file?

    nwb_file_name = key["nwb_file_name"].replace("_.nwb", "")

    # Generate a unique path for the results file
    path_exists = True
    while path_exists:
        results_path = (
            Path(config["SPYGLASS_ANALYSIS_DIR"])
            / nwb_file_name
            / f"{nwb_file_name}_{str(uuid.uuid4())}.nc"
        )
        path_exists = results_path.exists()
    classifier.save_results(
        results,
        results_path,
    )
    key["results_path"] = results_path

    # The model is saved next to the results, same stem, ".pkl" suffix.
    classifier_path = results_path.with_suffix(".pkl")
    classifier.save_model(classifier_path)
    key["classifier_path"] = classifier_path

    self.insert1(key)

    from spyglass.decoding.decoding_merge import DecodingOutput

    DecodingOutput.insert1(orig_key, skip_duplicates=True)

fetch_results()

Retrieve the decoding results

Returns:

Type Description
Dataset

The decoding results (posteriors, etc.)

Source code in src/spyglass/decoding/v1/sorted_spikes.py
def fetch_results(self) -> xr.Dataset:
    """Load the decoding results stored at this entry's ``results_path``.

    Returns
    -------
    xr.Dataset
        The decoding results (posteriors, etc.)
    """
    results_path = self.fetch1("results_path")
    return SortedSpikesDetector.load_results(results_path)

fetch_model()

Retrieve the decoding model

Source code in src/spyglass/decoding/v1/sorted_spikes.py
def fetch_model(self):
    """Load the decoding model stored at this entry's ``classifier_path``."""
    classifier_path = self.fetch1("classifier_path")
    return SortedSpikesDetector.load_model(classifier_path)

fetch_environments(key) classmethod

Fetch the environments for the decoding model

Parameters:

Name Type Description Default
key dict

The decoding selection key

required

Returns:

Type Description
List[TrackGraph]

list of track graphs in the trained model

Source code in src/spyglass/decoding/v1/sorted_spikes.py
@classmethod
def fetch_environments(cls, key):
    """Build the environments used by the decoding model.

    A new ``SortedSpikesDetector`` is created from the stored decoding
    parameters and its environments are initialized from the position
    data; no saved model is loaded.

    Parameters
    ----------
    key : dict
        The decoding selection key

    Returns
    -------
    list
        The ``environments`` attribute of the initialized detector
    """
    key = cls.get_fully_defined_key(
        key, required_fields=["decoding_param_name"]
    )

    param_restriction = {"decoding_param_name": key["decoding_param_name"]}
    model_params = (DecodingParameters & param_restriction).fetch1()
    decoding_params = model_params["decoding_params"]
    decoding_kwargs = model_params["decoding_kwargs"]
    if decoding_kwargs is None:
        # No extra kwargs stored: fall back to default environment labels.
        decoding_kwargs = {}

    position_info, position_variable_names = (
        SortedSpikesDecodingV1.fetch_position_info(key)
    )

    classifier = SortedSpikesDetector(**decoding_params)
    position = position_info[position_variable_names].to_numpy()
    environment_labels = decoding_kwargs.get("environment_labels", None)
    classifier.initialize_environments(
        position=position,
        environment_labels=environment_labels,
    )

    return classifier.environments

fetch_position_info(key) classmethod

Fetch the position information for the decoding model

Parameters:

Name Type Description Default
key dict

The decoding selection key

required

Returns:

Type Description
Tuple[DataFrame, List[str]]

The position information and the names of the position variables

Source code in src/spyglass/decoding/v1/sorted_spikes.py
@classmethod
def fetch_position_info(cls, key):
    """Fetch the position data of the model's position group.

    The data are requested between the minimum and maximum times given by
    ``_get_interval_range(key)``.

    Parameters
    ----------
    key : dict
        The decoding selection key

    Returns
    -------
    Tuple[pd.DataFrame, List[str]]
        The position information and the names of the position variables
    """
    key = cls.get_fully_defined_key(
        key,
        required_fields=[
            "position_group_name",
            "nwb_file_name",
            "encoding_interval",
            "decoding_interval",
        ],
    )

    # Restrict by the group's identifying fields only, not the whole key.
    restriction = dict(
        position_group_name=key["position_group_name"],
        nwb_file_name=key["nwb_file_name"],
    )
    min_time, max_time = _get_interval_range(key)

    position_group = PositionGroup & restriction
    position_info, position_variable_names = (
        position_group.fetch_position_info(
            min_time=min_time, max_time=max_time
        )
    )
    return position_info, position_variable_names

fetch_linear_position_info(key) classmethod

Fetch the position information and project it onto the track graph

Parameters:

Name Type Description Default
key dict

The decoding selection key

required

Returns:

Type Description
DataFrame

The linearized position information

Source code in src/spyglass/decoding/v1/sorted_spikes.py
@classmethod
def fetch_linear_position_info(cls, key):
    """Fetch the position data and linearize it along the track graph.

    The first environment returned by ``fetch_environments`` supplies the
    track graph, edge order and edge spacing.

    Parameters
    ----------
    key : dict
        The decoding selection key

    Returns
    -------
    pd.DataFrame
        The linearized position columns followed by the original position
        columns, restricted to the model's interval range
    """
    required_fields = [
        "position_group_name",
        "nwb_file_name",
        "encoding_interval",
        "decoding_interval",
    ]
    key = cls.get_fully_defined_key(key, required_fields=required_fields)

    environment = SortedSpikesDecodingV1.fetch_environments(key)[0]
    position_df = SortedSpikesDecodingV1.fetch_position_info(key)[0]

    position_group = PositionGroup & key
    position_variable_names = position_group.fetch1("position_variables")
    position = np.asarray(position_df[position_variable_names])

    linear_position_df = get_linearized_position(
        position=position,
        track_graph=environment.track_graph,
        edge_order=environment.edge_order,
        edge_spacing=environment.edge_spacing,
    )
    min_time, max_time = _get_interval_range(key)

    # Give the linearized frame the time index of the position data so
    # the two can be joined column-wise and sliced by time.
    linear_position_df = linear_position_df.set_index(position_df.index)
    combined_df = pd.concat([linear_position_df, position_df], axis=1)
    return combined_df.loc[min_time:max_time]

fetch_spike_data(key, filter_by_interval=True, time_slice=None, return_unit_ids=False) classmethod

Fetch the spike times for the decoding model

Parameters:

Name Type Description Default
key dict

The decoding selection key

required
filter_by_interval bool

Whether to filter for spike times in the model interval, by default True

True
time_slice Slice

User provided slice of time to restrict spikes to, by default None

None
return_unit_ids bool

If True, return the unit_ids along with the spike times, by default False. Unit ids are defined as a list of dictionaries with keys 'spikesorting_merge_id' and 'unit_number'.

False

Returns:

Type Description
list[ndarray]

List of spike times for each unit in the model's spike group

Source code in src/spyglass/decoding/v1/sorted_spikes.py
@classmethod
def fetch_spike_data(
    cls,
    key,
    filter_by_interval=True,
    time_slice=None,
    return_unit_ids=False,
) -> Union[list[np.ndarray], tuple[list[np.ndarray], list[dict]]]:
    """Fetch the spike times for the decoding model

    Parameters
    ----------
    key : dict
        The decoding selection key
    filter_by_interval : bool, optional
        Whether to filter for spike times in the model interval,
        by default True
    time_slice : Slice, optional
        User provided slice of time to restrict spikes to, by default None.
        Only used when ``filter_by_interval`` is True; when None, the range
        from ``_get_interval_range(key)`` is used instead.
    return_unit_ids : bool, optional
        if True, return the unit_ids along with the spike times, by default
        False. Unit ids defined as a list of dictionaries with keys
        'spikesorting_merge_id' and 'unit_number'

    Returns
    -------
    list[np.ndarray]
        List of spike times for each unit in the model's spike group
    list[dict]
        The unit ids, returned as the second element of a tuple only when
        ``return_unit_ids`` is True
    """
    key = cls.get_fully_defined_key(
        key,
        required_fields=[
            "encoding_interval",
            "decoding_interval",
        ],
    )

    spike_times, unit_ids = SortedSpikesGroup.fetch_spike_data(
        key, return_unit_ids=True
    )

    if filter_by_interval:
        if time_slice is None:
            min_time, max_time = _get_interval_range(key)
        else:
            min_time, max_time = time_slice.start, time_slice.stop

        # Keep only spikes inside the closed interval [min_time, max_time].
        spike_times = [
            unit_spike_times[
                np.logical_and(
                    unit_spike_times >= min_time,
                    unit_spike_times <= max_time,
                )
            ]
            for unit_spike_times in spike_times
        ]

    # Honor return_unit_ids on both the filtered and unfiltered paths;
    # previously the unfiltered path dropped the unit ids.
    if return_unit_ids:
        return spike_times, unit_ids
    return spike_times

spike_times_sorted_by_place_field_peak(time_slice=None)

Spike times of units sorted by place field peak location

Parameters:

Name Type Description Default
time_slice Slice

time range to limit returned spikes to, by default None

None
Source code in src/spyglass/decoding/v1/sorted_spikes.py
def spike_times_sorted_by_place_field_peak(self, time_slice=None):
    """Spike times per encoding model, ordered by place field peak

    For each entry of the fitted model's ``encoding_model_``, units are
    ordered by the index of the maximum (ignoring NaNs) of their place
    field, and each unit's spikes are limited to ``time_slice``.

    Parameters
    ----------
    time_slice : Slice, optional
        time range to limit returned spikes to, by default None

    Returns
    -------
    dict
        Maps each encoding model key to a list of spike time arrays, one
        per unit, in peak-sorted order
    """
    if time_slice is None:
        # No restriction: accept every spike time.
        time_slice = slice(-np.inf, np.inf)
    start, stop = time_slice.start, time_slice.stop

    spike_times = self.fetch_spike_data(self.fetch1())
    classifier = self.fetch_model()

    sorted_spikes = {}
    for model_key in classifier.encoding_model_:
        fields = np.asarray(
            classifier.encoding_model_[model_key]["place_fields"]
        )
        peak_bins = np.nanargmax(fields, axis=1).squeeze()
        unit_order = np.argsort(peak_bins)

        per_unit = []
        for unit_ind in unit_order:
            unit_spikes = spike_times[unit_ind]
            in_window = np.logical_and(
                unit_spikes >= start, unit_spikes <= stop
            )
            per_unit.append(unit_spikes[in_window])
        sorted_spikes[model_key] = per_unit

    return sorted_spikes

get_orientation_col(df)

Examine columns of an input df and return orientation col name

Source code in src/spyglass/decoding/v1/sorted_spikes.py
def get_orientation_col(self, df):
    """Return the orientation column name to use for ``df``.

    Gives "orientation" when ``df.columns`` contains it; otherwise falls
    back to "head_orientation" without checking that it is present.
    """
    if "orientation" in df.columns:
        return "orientation"
    return "head_orientation"

get_ahead_behind_distance(track_graph=None, time_slice=None)

Get relative decoded position from the animal's actual position

Parameters:

Name Type Description Default
track_graph TrackGraph

environment track graph to project position on, by default None

None
time_slice Slice

time interval to restrict to, by default None

None

Returns:

Name Type Description
distance_metrics ndarray

Information about the distance of the animal to the mental position.

Source code in src/spyglass/decoding/v1/sorted_spikes.py
def get_ahead_behind_distance(self, track_graph=None, time_slice=None):
    """Get relative decoded position from the animal's actual position

    When a track graph is available (passed in, or taken from the model's
    first environment), the distance is computed along the track using the
    linearized position. Otherwise the 2D variant is used with the raw
    position variables and the maximum a posteriori decoded position.

    Parameters
    ----------
    track_graph : TrackGraph, optional
        environment track graph to project position on, by default None
    time_slice : Slice, optional
        time interval to restrict to, by default None

    Returns
    -------
    distance_metrics : np.ndarray
        Information about the distance of the animal to the mental position.
    """
    # TODO: store in table

    if time_slice is None:
        time_slice = slice(-np.inf, np.inf)

    classifier = self.fetch_model()
    # Collapse the state dimension so the posterior is over position only.
    posterior = (
        self.fetch_results()
        .acausal_posterior.sel(time=time_slice)
        .squeeze()
        .unstack("state_bins")
        .sum("state")
    )

    if track_graph is None:
        track_graph = classifier.environments[0].track_graph

    if track_graph is not None:
        linear_position_info = self.fetch_linear_position_info(
            self.fetch1("KEY")
        ).loc[time_slice]

        orientation_name = self.get_orientation_col(linear_position_info)

        traj_data = analysis.get_trajectory_data(
            posterior=posterior,
            track_graph=track_graph,
            decoder=classifier,
            actual_projected_position=linear_position_info[
                ["projected_x_position", "projected_y_position"]
            ],
            track_segment_id=linear_position_info["track_segment_id"],
            actual_orientation=linear_position_info[orientation_name],
        )

        return analysis.get_ahead_behind_distance(track_graph, *traj_data)
    else:
        key = self.fetch1("KEY")
        # fetch_position_info returns (DataFrame, variable names); select
        # the DataFrame before slicing, since a tuple has no ``.loc``.
        position_info = self.fetch_position_info(key)[0].loc[time_slice]
        map_position = analysis.maximum_a_posteriori_estimate(posterior)

        orientation_name = self.get_orientation_col(position_info)

        position_variable_names = (PositionGroup & key).fetch1(
            "position_variables"
        )

        # NOTE(review): `track_graphDD` is kept exactly as in the original;
        # confirm this attribute name against the environment class.
        return analysis.get_ahead_behind_distance2D(
            position_info[position_variable_names].to_numpy(),
            position_info[orientation_name].to_numpy(),
            map_position,
            classifier.environments[0].track_graphDD,
        )