Skip to content

Base Data Classes

Core data containers used across all datasets.

Data Containers

MoviesTrainTestSplit dataclass

MoviesTrainTestSplit(
    train: Float[
        ndarray, "channels train_time height width"
    ],
    test_dict: dict = (lambda: {})(),
    test: InitVar[
        Float[ndarray, "channels test_time height width"]
        | None
    ] = None,
    stim_id: Optional[str] = None,
    random_sequences: Optional[ndarray] = None,
    norm_mean: Optional[float] = None,
    norm_std: Optional[float] = None,
)

Container for stimulus movies used during training and evaluation.

ATTRIBUTE DESCRIPTION
train

Continuous movie shown during training.

TYPE: Float[ndarray, 'channels train_time height width']

test_dict

Named dictionary of frozen test stimuli. For legacy single-test datasets pass test; it will automatically be wrapped into {"test": test}.

TYPE: dict

test

Convenience field to pass a single frozen movie.

TYPE: InitVar[Float[ndarray, 'channels test_time height width'] | None]

stim_id

Optional identifier (e.g. "natural") to keep responses/movies aligned.

TYPE: Optional[str]

random_sequences

Optional clip permutations (Höfling 2024 format).

TYPE: Optional[ndarray]

norm_mean

Mean used to normalize the train and test movies.

TYPE: Optional[float]

norm_std

Standard deviation used to normalize the train and test movies.

TYPE: Optional[float]

ResponsesTrainTestSplit dataclass

ResponsesTrainTestSplit(
    train: Float[ndarray, "neurons train_time"],
    test_dict: dict = (lambda: {})(),
    test: InitVar[
        Float[ndarray, "neurons test_time"] | None
    ] = None,
    test_by_trial: Float[
        ndarray, "trials neurons test_time"
    ]
    | None = None,
    test_by_trial_dict: dict = (lambda: {})(),
    stim_id: str | None = None,
    session_kwargs: dict[str, Any] = (lambda: {})(),
)

Container for neural responses paired with MoviesTrainTestSplit.

Supports multiple test stimuli via test_dict and per-trial traces via test_by_trial_dict. For single-test datasets you may provide test and optionally test_by_trial; both will be lifted into the matching dictionaries.

get_test_by_trial

get_test_by_trial(
    name: str = "test",
) -> Float[ndarray, "trials neurons test_time"] | None

Return the per-trial responses for a specific stimulus.

PARAMETER DESCRIPTION
name

Key inside test_by_trial_dict. Default is "test", for the default single test stimulus case.

TYPE: str DEFAULT: 'test'

RETURNS DESCRIPTION
Float[ndarray, 'trials neurons test_time'] | None

Array of shape (trials, neurons, time) if available, otherwise None.

Source code in openretina/data_io/base.py
175
176
177
178
179
180
181
182
183
184
185
186
187
def get_test_by_trial(self, name: str = "test") -> Float[np.ndarray, "trials neurons test_time"] | None:
    """
    Return the per-trial responses for a specific stimulus.

    Args:
        name: Key inside `test_by_trial_dict`. Default is "test", for the default
            single test stimulus case.

    Returns:
        Array of shape (trials, neurons, time) if available, otherwise `None`.
    """
    # `.get` already returns None both for an empty dict and for a missing key,
    # so no separate emptiness guard is needed.
    return self.test_by_trial_dict.get(name)

DatasetStatistics dataclass

DatasetStatistics(
    unique_train_frames: int,
    unique_val_frames: int,
    unique_train_val_frames: int,
    unique_test_frames: dict[str, int],
    unique_train_transitions: int,
    unique_val_transitions: int,
    unique_test_transitions: dict[str, int],
    n_sessions: int,
)

Statistics about unique frames and transitions across sessions, computed from dataloaders.

ATTRIBUTE DESCRIPTION
unique_train_frames

Number of unique training frames seen across all sessions.

TYPE: int

unique_val_frames

Number of unique validation frames seen across all sessions.

TYPE: int

unique_train_val_frames

Union of unique train and val frames (deduplicated).

TYPE: int

unique_test_frames

Dict mapping test split name to unique frame count.

TYPE: dict[str, int]

unique_train_transitions

Number of unique consecutive-frame transitions in training.

TYPE: int

unique_val_transitions

Number of unique consecutive-frame transitions in validation.

TYPE: int

unique_test_transitions

Dict mapping test split name to unique transition count.

TYPE: dict[str, int]

n_sessions

Total number of sessions.

TYPE: int

empty classmethod

empty() -> DatasetStatistics

Create an empty DatasetStatistics instance (all counts zero).

Source code in openretina/data_io/base.py
218
219
220
221
222
223
224
225
226
227
228
229
230
@classmethod
def empty(cls) -> "DatasetStatistics":
    """Create an empty DatasetStatistics instance (all counts zero)."""
    # Scalar counters all start at zero; the per-test-split maps start empty.
    zero_counts = dict.fromkeys(
        (
            "unique_train_frames",
            "unique_val_frames",
            "unique_train_val_frames",
            "unique_train_transitions",
            "unique_val_transitions",
            "n_sessions",
        ),
        0,
    )
    return cls(unique_test_frames={}, unique_test_transitions={}, **zero_counts)

Helper Functions

normalize_train_test_movies

normalize_train_test_movies(
    train: Float[
        ndarray, "channels train_time height width"
    ],
    test: Float[ndarray, "channels test_time height width"],
) -> tuple[
    Float[ndarray, "channels train_time height width"],
    Float[ndarray, "channels test_time height width"],
    dict[str, float | None],
]

z-score normalization of train and test movies using the mean and standard deviation of the train movie.

Parameters: - train: train movie with shape (channels, time, height, width) - test: test movie with shape (channels, time, height, width)

Returns: - train_video_preproc: normalized train movie - test_video_preproc: normalized test movie - norm_stats: dictionary containing the mean and standard deviation of the train movie

Note: The function casts the input to torch tensors to calculate the mean and standard deviation of large inputs more efficiently.

Source code in openretina/data_io/base.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def normalize_train_test_movies(
    train: Float[np.ndarray, "channels train_time height width"],
    test: Float[np.ndarray, "channels test_time height width"],
) -> tuple[
    Float[np.ndarray, "channels train_time height width"],
    Float[np.ndarray, "channels test_time height width"],
    dict[str, float | None],
]:
    """
    z-score normalization of train and test movies using the mean and standard deviation of the train movie.

    Parameters:
    - train: train movie with shape (channels, time, height, width)
    - test: test movie with shape (channels, time, height, width)

    Returns:
    - normalized train movie
    - normalized test movie
    - dictionary containing the mean and standard deviation of the train movie

    Note: The function casts the input to torch tensors to calculate the mean and standard deviation of large
    inputs more efficiently.
    """

    def _as_f32_tensor(movie: np.ndarray) -> torch.Tensor:
        # Copy into a float32 tensor so the reductions below run in torch.
        return torch.tensor(movie, dtype=torch.float32)

    train_tensor = _as_f32_tensor(train)
    test_tensor = _as_f32_tensor(test)
    mean = train_tensor.mean()
    std = train_tensor.std()  # torch default: unbiased (ddof=1) std

    def _zscore(tensor: torch.Tensor) -> np.ndarray:
        # Normalize with the *train* statistics and hand back a numpy array.
        return ((tensor - mean) / std).cpu().detach().numpy()

    norm_stats = {"norm_mean": mean.item(), "norm_std": std.item()}
    return _zscore(train_tensor), _zscore(test_tensor), norm_stats

compute_data_info

compute_data_info(
    neuron_data_dictionary: dict[
        str, ResponsesTrainTestSplit
    ],
    movies_dictionary: dict[str, MoviesTrainTestSplit]
    | MoviesTrainTestSplit,
    partial_data_info: dict[str, Any] | None = None,
) -> dict[str, Any]

Computes information related to the data used to train a model, including the number of neurons, the shape of the movies, and the normalization statistics. This information should be fed to and saved with the models.

Parameters: - neuron_data_dictionary: dictionary of responses for each session - movies_dictionary: dictionary of movies for each session - partial_data_info: dictionary of partial data info from the config, to be merged with the computed data info

Returns: - data_info: dictionary containing various data info useful for downstream tasks, including the number of neurons, the shape of the movies, the movie normalization statistics, and any extra session kwargs related to the data, including partial data information passed in the training config.

Source code in openretina/data_io/base.py
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
def compute_data_info(
    neuron_data_dictionary: dict[str, ResponsesTrainTestSplit],
    movies_dictionary: dict[str, MoviesTrainTestSplit] | MoviesTrainTestSplit,
    partial_data_info: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Computes information related to the data used to train a model, including the number of neurons, the shape of the
    movies, and the normalization statistics. This information should be fed to and saved with the models.

    Parameters:
    - neuron_data_dictionary: dictionary of responses for each session
    - movies_dictionary: dictionary of movies for each session, or a single MoviesTrainTestSplit shared by all sessions
    - partial_data_info: dictionary of partial data info from the config, to be merged with the computed data info

    Returns:
    - data_info: dictionary containing various data info useful for downstream tasks, including the number of neurons,
    the shape of the movies, the movie normalization statistics, and any extra session kwargs related to the data,
    including partial data information passed in the training config.

    Raises:
    - ValueError: if movies_dictionary is an empty dict, or if normalization statistics or input shapes are
    inconsistent across stimuli.
    """
    n_neurons_dict = get_n_neurons_per_session(neuron_data_dictionary)

    # Mean activity per neuron for each session, computed from training responses.
    # responses.train has shape (n_neurons, n_timepoints); average over the time axis.
    mean_activity_dict = {}
    for session_name, responses in neuron_data_dictionary.items():
        mean_activity = torch.tensor(responses.train.mean(axis=1), dtype=torch.float32)
        mean_activity_dict[session_name] = mean_activity

    if isinstance(movies_dictionary, MoviesTrainTestSplit):
        # Single shared movie split: take its statistics and shape directly.
        stim_mean = movies_dictionary.norm_mean
        stim_std = movies_dictionary.norm_std
        # Drop the time axis: input_shape is (channels, height, width).
        input_shape = (
            movies_dictionary.train.shape[0],
            *movies_dictionary.train.shape[2:],
        )
    else:
        if not movies_dictionary:
            # Fail fast with a clear message instead of an IndexError further down.
            raise ValueError("movies_dictionary is empty: cannot compute input shape or normalization statistics.")

        norm_means = [movie.norm_mean for movie in movies_dictionary.values() if movie.norm_mean is not None]
        norm_stds = [movie.norm_std for movie in movies_dictionary.values() if movie.norm_std is not None]

        # All stimuli must share (approximately) the same normalization statistics;
        # atol=1 tolerates small per-stimulus differences in the stored values.
        if len(norm_means) > 0:
            if not np.allclose(norm_means, norm_means[0], atol=1, rtol=0):
                raise ValueError(f"Normalization means are not consistent across stimuli: {norm_means}")
            stim_mean = norm_means[0]
        else:
            stim_mean = 0.0
            warnings.warn(f"No stimulus mean set, setting {stim_mean=}")
        if len(norm_stds) > 0:
            if not np.allclose(norm_stds, norm_stds[0], atol=1, rtol=0):
                raise ValueError(f"Normalization stds are not consistent across stimuli: {norm_stds}")
            stim_std = norm_stds[0]
        else:
            stim_std = 1.0
            warnings.warn(f"No stimulus stds set, setting {stim_std=}")

        # The model consumes a single input shape, so all stimuli must agree on it.
        input_shapes = [(movie.train.shape[0], *movie.train.shape[2:]) for movie in movies_dictionary.values()]
        if any(shape != input_shapes[0] for shape in input_shapes):
            raise ValueError(f"Input shapes are not consistent across stimuli: {input_shapes}")

        input_shape = input_shapes[0]

    sessions_kwargs = {
        session_name: responses.session_kwargs for session_name, responses in neuron_data_dictionary.items()
    }

    # partial_data_info entries intentionally override computed values on key collisions.
    return {
        "n_neurons_dict": n_neurons_dict,
        "mean_activity_dict": mean_activity_dict,
        "input_shape": input_shape,
        "sessions_kwargs": sessions_kwargs,
        "stim_mean": stim_mean,
        "stim_std": stim_std,
        **(partial_data_info or {}),
    }