
Base Dataloader

MovieDataSet

MovieDataSet(
    movies: Float[
        ndarray | Tensor, "n_channels n_frames h w"
    ],
    responses: Float[ndarray | Tensor, "n_frames n_neurons"]
    | dict[str, Any],
    roi_ids: Float[ndarray, " n_neurons"] | None,
    roi_coords: Float[ndarray, "n_neurons 2"] | None,
    group_assignment: Float[ndarray, " n_neurons"] | None,
    split: str
    | Literal["train", "validation", "val", "test"],
    chunk_size: int,
)

Bases: Dataset

A dataset class for handling movie data and corresponding neural responses.

PARAMETER DESCRIPTION
movies

The movie data.

TYPE: Float[ndarray | Tensor, 'n_channels n_frames h w']

responses

The neural responses aligned to the movie frames. For the test split, this may instead be a dictionary with key "avg" for the trial-averaged responses and "by_trial" for the per-trial responses.

TYPE: Float[ndarray, 'n_frames n_neurons'] | dict[str, Float[ndarray, 'n_frames n_neurons']]

roi_ids

A list of ROI IDs.

TYPE: Optional[Float[ndarray, ' n_neurons']]

roi_coords

A list of ROI coordinates.

TYPE: Optional[Float[ndarray, 'n_neurons 2']]

group_assignment

A list of group assignments (cell types).

TYPE: Optional[Float[ndarray, ' n_neurons']]

split

The data split, either "train", "validation", "val", or "test".

TYPE: Literal['train', 'validation', 'val', 'test']

chunk_size

The size of the chunks to split the data into.

TYPE: int

ATTRIBUTE DESCRIPTION
samples

A tuple containing movie data and neural responses.

TYPE: tuple

test_responses_by_trial

A dictionary containing test responses by trial (only for the test split).

TYPE: Optional[Dict[str, Any]]

roi_ids

A list of region of interest (ROI) IDs.

TYPE: Optional[Float[ndarray, ' n_neurons']]

chunk_size

The size of the chunks to split the data into.

TYPE: int

mean_response

The mean response per neuron.

TYPE: Tensor

group_assignment

A list of group assignments.

TYPE: Optional[Float[ndarray, ' n_neurons']]

roi_coords

A list of ROI coordinates.

TYPE: Optional[Float[ndarray, 'n_neurons 2']]

METHOD DESCRIPTION
__getitem__

Returns a DataPoint object for the given index or slice.

movies

Returns the movie data.

responses

Returns the neural responses.

__len__

Returns the number of chunks of clips and responses used for training.

__str__

Returns a string representation of the dataset.

__repr__

Returns a string representation of the dataset.

Source code in openretina/data_io/base_dataloader.py
def __init__(
    self,
    movies: Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"],
    responses: Float[np.ndarray | torch.Tensor, "n_frames n_neurons"] | dict[str, Any],
    roi_ids: Float[np.ndarray, " n_neurons"] | None,
    roi_coords: Float[np.ndarray, "n_neurons 2"] | None,
    group_assignment: Float[np.ndarray, " n_neurons"] | None,
    split: str | Literal["train", "validation", "val", "test"],
    chunk_size: int,
):
    self.roi_ids = roi_ids
    self.test_responses_by_trial: torch.Tensor | None = None

    # Test responses can be passed as a dictionary by other constructors,
    # with key "avg" for the averaged responses and "by_trial" for the per-trial responses.
    if split == "test" and isinstance(responses, dict):
        responses_dict = cast(dict[str, Any], responses)
        self.samples: tuple = movies, responses_dict["avg"]
        self.test_responses_by_trial = responses_dict.get("by_trial")
    else:
        self.samples = movies, responses

    self.chunk_size = chunk_size
    # Calculate the mean response per neuron (used for bias init in the model)
    self.mean_response = torch.mean(torch.Tensor(self.samples[1]), dim=0)
    self.group_assignment = group_assignment
    self.roi_coords = roi_coords
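For intuition, the chunk/response pairing the dataset exposes can be sketched standalone in NumPy (illustrative shapes only; this is not the class's actual `__getitem__` implementation):

```python
import numpy as np

# Illustrative shapes (not from the library).
n_channels, n_frames, h, w, n_neurons = 1, 200, 8, 8, 5
movies = np.random.rand(n_channels, n_frames, h, w)
responses = np.random.rand(n_frames, n_neurons)
chunk_size = 50

def get_chunk(start: int) -> tuple[np.ndarray, np.ndarray]:
    """Pair chunk_size movie frames with the matching response rows."""
    return (
        movies[:, start : start + chunk_size],
        responses[start : start + chunk_size],
    )

frames, resp = get_chunk(100)
# Per-neuron mean over time, as used for bias initialisation in the model.
mean_response = responses.mean(axis=0)
```

Frames are indexed on the second axis (time) of the movie, while responses are indexed on the first, which is why the two layouts in the signature differ.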

MovieSampler

MovieSampler(
    start_indices: list[int],
    split: str
    | Literal["train", "validation", "val", "test"],
    chunk_size: int,
    movie_length: int,
    scene_length: int,
    allow_over_boundaries: bool = False,
)

Bases: Sampler

A custom sampler for selecting movie frames for training, validation, or testing.

PARAMETER DESCRIPTION
start_indices

List of starting indices for the movie sections to select.

TYPE: list[int]

split

The type of data split.

TYPE: Literal['train', 'validation', 'val', 'test']

chunk_size

The size of each contiguous chunk of frames to select.

TYPE: int

movie_length

The total length of the movie.

TYPE: int

scene_length

The length of each scene, if the movie is divided into scenes.

TYPE: int

allow_over_boundaries

Whether to allow selected chunks to go over scene boundaries. Defaults to False.

TYPE: bool DEFAULT: False

ATTRIBUTE DESCRIPTION
indices

The starting indices for the movie sections to sample.

TYPE: list[int]

split

The type of data split.

TYPE: str

chunk_size

The size of each chunk of frames.

TYPE: int

movie_length

The total length of the movie.

TYPE: int

scene_length

The length of each scene, if the movie is made up of scenes.

TYPE: int

allow_over_boundaries

Whether to allow chunks to go over scene boundaries.

TYPE: bool

METHOD DESCRIPTION
__iter__

Returns an iterator over the sampled indices.

__len__

Returns the number of starting indices (which corresponds to the number of sampled clips).

Source code in openretina/data_io/base_dataloader.py
def __init__(
    self,
    start_indices: list[int],
    split: str | Literal["train", "validation", "val", "test"],
    chunk_size: int,
    movie_length: int,
    scene_length: int,
    allow_over_boundaries: bool = False,
):
    super().__init__()
    self.indices = start_indices
    self.split = split
    self.chunk_size = chunk_size
    self.movie_length = movie_length
    self.scene_length = scene_length
    self.allow_over_boundaries = allow_over_boundaries
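The sampling behaviour can be approximated with a hedged sketch (an assumption based on the documented behaviour, not the class's actual `__iter__`): evaluation splits yield the fixed start indices in order, while the training split jitters each start within its scene before yielding.

```python
import random

# Illustrative values (not from the library).
start_indices = [0, 100, 200]
chunk_size = 50
scene_length = 100

def sample_starts(split: str, rng: random.Random) -> list[int]:
    """Hypothetical sketch of the sampler's behaviour per split."""
    if split in ("validation", "val", "test"):
        # Evaluation splits: yield the fixed start indices as given.
        return list(start_indices)
    # Training split: jitter each start inside its scene so chunks vary
    # across epochs, clamping so a chunk never crosses the next scene bound.
    shifted = []
    for start in start_indices:
        next_bound = start + scene_length
        shift = rng.randrange(0, chunk_size // 2)
        shifted.append(min(start + shift, next_bound - chunk_size))
    return shifted

train_starts = sample_starts("train", random.Random(0))
```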

NeuronDataSplit

NeuronDataSplit(
    responses: ResponsesTrainTestSplit,
    val_clip_idx: List[int],
    num_clips: int,
    clip_length: int,
    key: Optional[dict] = None,
    **kwargs,
)

Preprocesses ResponsesTrainTestSplit objects before feeding them to dataloaders.

Responsibilities
  • Remove validation clips from the training responses while storing them separately.
  • Expose torch tensors for train/val/test splits via response_dict.
  • Surface averaged and per-trial test responses for each test stimulus name so that downstream MovieDataSet instances can provide dataset.test_responses_by_trial.

Initialize the NeuronData object, a boilerplate class that computes and stores the train/validation/test splits of the neuron data before they are fed into a dataloader.

PARAMETER DESCRIPTION
key

The key information for the neuron data, includes date, exp_num, experimenter, field_id, stim_id.

TYPE: dict DEFAULT: None

responses

The train and test responses of neurons.

TYPE: ResponsesTrainTestSplit

val_clip_idx

The indices of validation clips.

TYPE: List[int]

num_clips

The number of clips.

TYPE: int

clip_length

The length of each clip.

TYPE: int

Source code in openretina/data_io/base_dataloader.py
def __init__(
    self,
    responses: ResponsesTrainTestSplit,
    val_clip_idx: List[int],
    num_clips: int,
    clip_length: int,
    key: Optional[dict] = None,
    **kwargs,
):
    """
    Initialize the NeuronData object.
    Boilerplate class to compute and store neuron data train/validation/test splits before feeding them into a dataloader.

    Args:
        key (dict): The key information for the neuron data,
                    includes date, exp_num, experimenter, field_id, stim_id.
        responses (ResponsesTrainTestSplit): The train and test responses of neurons.
        val_clip_idx (List[int]): The indices of validation clips.
        num_clips (int): The number of clips.
        clip_length (int): The length of each clip.
    """
    self.neural_responses = responses
    self.num_neurons = self.neural_responses.n_neurons
    self.key = key
    self.roi_coords = ()
    self.clip_length = clip_length
    self.num_clips = num_clips
    self.val_clip_idx = val_clip_idx

    # Transpose the responses to have the shape (n_timepoints, n_neurons)
    self.responses_train_and_val = self.neural_responses.train.T

    self.responses_train, self.responses_val = self.split_data_train_val()
    self.test_responses_by_trial: dict[str, np.ndarray] = (
        {name: np.asarray(by_trial) for name, by_trial in self.neural_responses.test_by_trial_dict.items()}
        if self.neural_responses.test_by_trial_dict is not None
        else {}
    )
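The transposition convention can be shown standalone (hypothetical shapes):

```python
import numpy as np

n_neurons, n_timepoints = 4, 300
# ResponsesTrainTestSplit stores train responses as (n_neurons, n_timepoints)...
train = np.random.rand(n_neurons, n_timepoints)
# ...while the dataloaders expect time first: (n_timepoints, n_neurons).
responses_train_and_val = train.T
```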

response_dict property

response_dict: dict

Create and return a dictionary of neural responses for train, validation, and test datasets.

Structure

    {
        "train": Tensor[T_train, neurons],
        "validation": Tensor[T_val, neurons],
        "test": {
            stimulus_name: {
                "avg": Tensor[T_test, neurons],
                "by_trial": Optional[Tensor[trials, T_test, neurons]],
            },
            ...
        },
    }

response_dict_test property

response_dict_test: dict[str, dict[str, Tensor | None]]

Torch representation of the averaged and per-trial test responses keyed by stimulus name.

RETURNS DESCRIPTION
dict[str, dict[str, Tensor | None]]

    {
        stimulus_name: {
            "avg": Tensor[t_test, neurons],
            "by_trial": Optional[Tensor[trials, t_test, neurons]],
        },
        ...
    }

split_data_train_val

split_data_train_val() -> tuple[ndarray, ndarray]

Compute validation responses and updated train responses stripped of validation clips. Can deal with unsorted validation clip indices, and parallels the way movie validation clips are handled.

RETURNS DESCRIPTION
tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: The updated train and validation responses.

Source code in openretina/data_io/base_dataloader.py
def split_data_train_val(self) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute validation responses and updated train responses stripped of validation clips.
    Can deal with unsorted validation clip indices, and parallels the way movie validation clips are handled.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The updated train and validation responses.
    """
    # Initialise validation responses
    base_movie_sorting = np.arange(self.num_clips)

    validation_mask = np.ones_like(self.responses_train_and_val, dtype=bool)
    responses_val = np.zeros([len(self.val_clip_idx) * self.clip_length, self.num_neurons])

    # Compute validation responses and remove sections from training responses
    for i, ind1 in enumerate(self.val_clip_idx):
        grab_index = base_movie_sorting[ind1]
        responses_val[i * self.clip_length : (i + 1) * self.clip_length, :] = self.responses_train_and_val[
            grab_index * self.clip_length : (grab_index + 1) * self.clip_length,
            :,
        ]
        validation_mask[
            (grab_index * self.clip_length) : (grab_index + 1) * self.clip_length,
            :,
        ] = False

    responses_train = self.responses_train_and_val[validation_mask].reshape(-1, self.num_neurons)

    return responses_train, responses_val
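The masking logic can be exercised on toy data (a standalone re-implementation of the loop above with tiny, hypothetical shapes):

```python
import numpy as np

num_clips, clip_length, n_neurons = 4, 3, 2
# Responses laid out as consecutive clips along the time axis.
responses = np.arange(num_clips * clip_length * n_neurons, dtype=float).reshape(
    num_clips * clip_length, n_neurons
)
val_clip_idx = [2, 0]  # deliberately unsorted

mask = np.ones_like(responses, dtype=bool)
val = np.zeros((len(val_clip_idx) * clip_length, n_neurons))
for i, clip in enumerate(val_clip_idx):
    # Copy the clip's rows into the validation block in the order given...
    val[i * clip_length : (i + 1) * clip_length] = responses[
        clip * clip_length : (clip + 1) * clip_length
    ]
    # ...and mask those rows out of the training responses.
    mask[clip * clip_length : (clip + 1) * clip_length] = False
train = responses[mask].reshape(-1, n_neurons)
```

Note that the validation block preserves the (possibly unsorted) order of `val_clip_idx`, while the remaining training rows keep their original temporal order.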

generate_movie_splits

generate_movie_splits(
    movie_train,
    movie_test: dict[str, ndarray],
    val_clip_idc: list[int],
    num_clips: int,
    clip_length: int,
) -> tuple[Tensor, Tensor, dict[str, Tensor]]

Split training movies into train/validation subsets and convert test movies to tensors.

Source code in openretina/data_io/base_dataloader.py
def generate_movie_splits(
    movie_train,
    movie_test: dict[str, np.ndarray],
    val_clip_idc: list[int],
    num_clips: int,
    clip_length: int,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
    """Split training movies into train/validation subsets and convert test movies to tensors."""
    movie_train = torch.tensor(movie_train, dtype=torch.float)
    movie_test_dict = {n: torch.tensor(movie, dtype=torch.float) for n, movie in movie_test.items()}

    channels, _, px_y, px_x = movie_train.shape

    # Prepare validation movie data
    movie_val = torch.zeros((channels, len(val_clip_idc) * clip_length, px_y, px_x), dtype=torch.float)
    for i, idx in enumerate(val_clip_idc):
        movie_val[:, i * clip_length : (i + 1) * clip_length, ...] = movie_train[
            :, idx * clip_length : (idx + 1) * clip_length, ...
        ]

    # Create a boolean mask to indicate which clips are not part of the validation set
    mask = np.ones(num_clips, dtype=bool)
    mask[val_clip_idc] = False
    train_clip_idx = np.arange(num_clips)[mask]

    movie_train_subset = torch.cat(
        [movie_train[:, i * clip_length : (i + 1) * clip_length] for i in train_clip_idx],
        dim=1,
    )

    return movie_train_subset, movie_val, movie_test_dict
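The clip-masking step can be sketched in NumPy alone (same logic as the tensor code above, toy shapes):

```python
import numpy as np

num_clips, clip_length = 5, 10
channels, h, w = 1, 4, 4
movie_train = np.random.rand(channels, num_clips * clip_length, h, w)
val_clip_idc = [1, 3]

# Boolean mask over clips: True for training clips.
mask = np.ones(num_clips, dtype=bool)
mask[val_clip_idc] = False
train_clip_idx = np.arange(num_clips)[mask]

# Concatenate the selected clips back along the time axis.
movie_val = np.concatenate(
    [movie_train[:, i * clip_length : (i + 1) * clip_length] for i in val_clip_idc],
    axis=1,
)
movie_train_subset = np.concatenate(
    [movie_train[:, i * clip_length : (i + 1) * clip_length] for i in train_clip_idx],
    axis=1,
)
```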

gen_shifts_with_boundaries

gen_shifts_with_boundaries(
    clip_bounds: list[int] | ndarray,
    start_indices: list[int] | ndarray,
    clip_chunk_size: int = 50,
) -> list[int]

Generate shifted indices based on clip bounds and start indices. Assumes that the original start indices are already within the clip bounds. If they are not, it changes the overflowing indices to respect the closest bound.

PARAMETER DESCRIPTION
clip_bounds

A list of clip bounds.

TYPE: list

start_indices

A list of start indices.

TYPE: list

clip_chunk_size

The size of each clip chunk. Defaults to 50.

TYPE: int DEFAULT: 50

RETURNS DESCRIPTION
list

A list of shifted indices.

TYPE: list[int]

Source code in openretina/data_io/base_dataloader.py
def gen_shifts_with_boundaries(
    clip_bounds: list[int] | np.ndarray, start_indices: list[int] | np.ndarray, clip_chunk_size: int = 50
) -> list[int]:
    """
    Generate shifted indices based on clip bounds and start indices.
    Assumes that the original start indices are already within the clip bounds.
    If they are not, it changes the overflowing indices to respect the closest bound.

    Args:
        clip_bounds (list): A list of clip bounds.
        start_indices (list): A list of start indices.
        clip_chunk_size (int, optional): The size of each clip chunk. Defaults to 50.

    Returns:
        list: A list of shifted indices.

    """

    def get_next_bound(value, bounds):
        insertion_index = bisect.bisect_right(bounds, value)
        return bounds[min(insertion_index, len(bounds) - 1)]

    shifted_indices = []
    shifts = np.random.randint(0, clip_chunk_size // 2, len(start_indices))

    for i, start_idx in enumerate(start_indices):
        next_bound = get_next_bound(start_idx, clip_bounds)
        if start_idx + shifts[i] + clip_chunk_size < next_bound:
            shifted_indices.append(start_idx + shifts[i])
        elif start_idx + clip_chunk_size > next_bound:
            shifted_indices.append(next_bound - clip_chunk_size)
        else:
            shifted_indices.append(start_idx)

    # Ensure we do not exceed the movie length when allowing over boundaries
    if shifted_indices[-1] + clip_chunk_size > clip_bounds[-1]:
        shifted_indices[-1] = clip_bounds[-1] - clip_chunk_size
    return shifted_indices
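The bound lookup at the heart of this function is plain `bisect` (standalone demo with illustrative bounds):

```python
import bisect

clip_bounds = [0, 100, 200, 300]

def get_next_bound(value: int, bounds: list[int]) -> int:
    """Smallest bound strictly greater than value (clamped to the last bound)."""
    insertion_index = bisect.bisect_right(bounds, value)
    return bounds[min(insertion_index, len(bounds) - 1)]
```

`bisect_right` returns the insertion point after any equal elements, so a start index sitting exactly on a bound gets the *next* bound, and a value at or past the last bound is clamped to it.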

handle_missing_start_indices

handle_missing_start_indices(
    movie_length: int,
    chunk_size: int | None,
    scene_length: int | None,
    split: str,
) -> list[int]

Handle missing start indices for different splits of the dataset.

Parameters:
    movie_length (int): The total number of frames in the movie.
    chunk_size (int or None): The size of each chunk for the training split. Required if split is "train".
    scene_length (int or None): The length of each scene. Required if split is "validation" or "val".
    split (str): The type of split, one of "train", "validation", "val", or "test".

Returns:
    list[int]: The generated start indices for the given split.

Raises:
    AssertionError: If chunk_size is not provided for the training split.
    AssertionError: If scene_length is not provided for the validation split.
    NotImplementedError: If split is not one of "train", "validation", "val", or "test".

Source code in openretina/data_io/base_dataloader.py
def handle_missing_start_indices(
    movie_length: int, chunk_size: int | None, scene_length: int | None, split: str
) -> list[int]:
    """
    Handle missing start indices for different splits of the dataset.

    Parameters:
    movie_length (int): The total number of frames in the movie.
    chunk_size (int or None): The size of each chunk for the training split. Required if split is "train".
    scene_length (int or None): The length of each scene. Required if split is "validation" or "val".
    split (str): The type of split, one of "train", "validation", "val", or "test".

    Returns:
    list[int]: The generated start indices for the given split.

    Raises:
    AssertionError: If chunk_size is not provided for the training split.
    AssertionError: If scene_length is not provided for the validation split.
    NotImplementedError: If split is not one of "train", "validation", "val", or "test".
    """

    if split == "train":
        assert chunk_size is not None, "Chunk size or start indices must be provided for training."
        interval = chunk_size
    elif split in {"validation", "val"}:
        assert scene_length is not None, "Scene length or start indices must be provided for validation."
        interval = scene_length
    elif split == "test":
        interval = movie_length
    else:
        raise NotImplementedError("Start indices could not be recovered.")

    return np.arange(0, movie_length, interval).tolist()  # type: ignore

get_movie_dataloader

get_movie_dataloader(
    movie: Float[
        ndarray | Tensor, "n_channels n_frames h w"
    ],
    responses: Float[ndarray | Tensor, "n_frames n_neurons"]
    | dict[str, Any],
    *,
    split: str
    | Literal["train", "validation", "val", "test"],
    scene_length: int,
    chunk_size: int,
    batch_size: int,
    start_indices: list[int] | None = None,
    roi_ids: Float[ndarray, " n_neurons"] | None = None,
    roi_coords: Float[ndarray, "n_neurons 2"] | None = None,
    group_assignment: Float[ndarray, " n_neurons"]
    | None = None,
    drop_last: bool = True,
    allow_over_boundaries: bool = True,
    **kwargs,
) -> DataLoader

Create a DataLoader for processing movie data and associated responses. This function prepares the dataset and sampler for training or evaluation based on the specified parameters.

PARAMETER DESCRIPTION
movie

The movie data represented as a multi-dimensional array or tensor.

TYPE: Float[ndarray | Tensor, 'n_channels n_frames h w']

responses

The responses corresponding to the frames of the movie. For the test split, may be a dictionary with "avg" and "by_trial" entries.

TYPE: Float[ndarray | Tensor, 'n_frames n_neurons'] | dict[str, Any]

split

The dataset split to use (train, validation, or test).

TYPE: str | Literal['train', 'validation', 'val', 'test']

scene_length

The length of the scene to be processed.

TYPE: int

chunk_size

The size of each chunk to be extracted from the movie.

TYPE: int

batch_size

The number of samples per batch.

TYPE: int

start_indices

The starting indices for each chunk. If None, will be computed.

TYPE: list[int] | None DEFAULT: None

roi_ids

The region of interest IDs. If None, will not be used.

TYPE: Float[ndarray, ' n_neurons'] | None DEFAULT: None

roi_coords

The coordinates of the regions of interest. If None, will not be used.

TYPE: Float[ndarray, 'n_neurons 2'] | None DEFAULT: None

group_assignment

The group assignments (cell types) for the neurons. If None, will not be used.

TYPE: Float[ndarray, ' n_neurons'] | None DEFAULT: None

drop_last

Whether to drop the last incomplete batch. Defaults to True.

TYPE: bool DEFAULT: True

allow_over_boundaries

Whether to allow chunks that exceed the scene boundaries. Defaults to True.

TYPE: bool DEFAULT: True

**kwargs

Additional keyword arguments for the DataLoader.

DEFAULT: {}

RETURNS DESCRIPTION
DataLoader

A DataLoader instance configured with the specified dataset and sampler.

TYPE: DataLoader

RAISES DESCRIPTION
ValueError

If allow_over_boundaries is False and chunk_size exceeds scene_length during training.

Source code in openretina/data_io/base_dataloader.py
def get_movie_dataloader(
    movie: Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"],
    responses: Float[np.ndarray | torch.Tensor, "n_frames n_neurons"] | dict[str, Any],
    *,
    split: str | Literal["train", "validation", "val", "test"],
    scene_length: int,
    chunk_size: int,
    batch_size: int,
    start_indices: list[int] | None = None,
    roi_ids: Float[np.ndarray, " n_neurons"] | None = None,
    roi_coords: Float[np.ndarray, "n_neurons 2"] | None = None,
    group_assignment: Float[np.ndarray, " n_neurons"] | None = None,
    drop_last: bool = True,
    allow_over_boundaries: bool = True,
    **kwargs,
) -> DataLoader:
    """
    Create a DataLoader for processing movie data and associated responses.
    This function prepares the dataset and sampler for training or evaluation based on the specified parameters.

    Args:
        movie (Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"]):
            The movie data represented as a multi-dimensional array or tensor.
        responses (Float[np.ndarray | torch.Tensor, "n_frames n_neurons"] | dict[str, Any]):
            The responses corresponding to the frames of the movie.
        split (str | Literal["train", "validation", "val", "test"]):
            The dataset split to use (train, validation, or test).
        scene_length (int):
            The length of the scene to be processed.
        chunk_size (int):
            The size of each chunk to be extracted from the movie.
        batch_size (int):
            The number of samples per batch.
        start_indices (list[int] | None, optional):
            The starting indices for each chunk. If None, will be computed.
        roi_ids (Float[np.ndarray, " n_neurons"] | None, optional):
            The region of interest IDs. If None, will not be used.
        roi_coords (Float[np.ndarray, "n_neurons 2"] | None, optional):
            The coordinates of the regions of interest. If None, will not be used.
        group_assignment (Float[np.ndarray, " n_neurons"] | None, optional):
            The group assignments (cell types) for the neurons. If None, will not be used.
        drop_last (bool, optional):
            Whether to drop the last incomplete batch. Defaults to True.
        allow_over_boundaries (bool, optional):
            Whether to allow chunks that exceed the scene boundaries. Defaults to True.
        **kwargs:
            Additional keyword arguments for the DataLoader.

    Returns:
        DataLoader:
            A DataLoader instance configured with the specified dataset and sampler.

    Raises:
        ValueError:
            If `allow_over_boundaries` is False and `chunk_size` exceeds `scene_length` during training.
    """
    if isinstance(responses, torch.Tensor) and bool(torch.isnan(responses).any()):
        print("Nans in responses, skipping this dataloader")
        return  # type: ignore

    if not allow_over_boundaries and split == "train" and chunk_size > scene_length:
        raise ValueError("Clip chunk size must be smaller than scene length to not exceed clip bounds during training.")

    if start_indices is None:
        start_indices = handle_missing_start_indices(movie.shape[1], chunk_size, scene_length, split)
    dataset = MovieDataSet(movie, responses, roi_ids, roi_coords, group_assignment, split, chunk_size)
    sampler = MovieSampler(
        start_indices,
        split,
        chunk_size,
        movie_length=movie.shape[1],
        scene_length=scene_length,
        allow_over_boundaries=allow_over_boundaries,
    )

    return DataLoader(
        dataset, sampler=sampler, batch_size=batch_size, drop_last=split == "train" and drop_last, **kwargs
    )
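Note the batching rule in the final `DataLoader` call: the last incomplete batch is dropped only for the training split, so validation and test always see every chunk. A standalone sketch of that predicate:

```python
def effective_drop_last(split: str, drop_last: bool = True) -> bool:
    """Incomplete batches are dropped only during training (and only if requested)."""
    return split == "train" and drop_last
```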

multiple_movies_dataloaders

multiple_movies_dataloaders(
    neuron_data_dictionary: dict[
        str, ResponsesTrainTestSplit
    ],
    movies_dictionary: dict[str, MoviesTrainTestSplit],
    train_chunk_size: int = 50,
    batch_size: int = 32,
    seed: int = 42,
    clip_length: int = 100,
    num_val_clips: int = 10,
    val_clip_indices: list[int] | None = None,
    allow_over_boundaries: bool = True,
) -> dict[str, dict[str, DataLoader]]

Create multiple dataloaders for training, validation, and testing from given neuron and movie data. This function ensures that the neuron data and movie data are aligned and generates dataloaders for each session. It does not assume that the movies in different sessions are the same, have the same length, or are composed of the same clips in the same order.

PARAMETER DESCRIPTION
neuron_data_dictionary

A dictionary containing neuron response data split for training and testing.

TYPE: dict[str, ResponsesTrainTestSplit]

movies_dictionary

A dictionary containing movie data split for training and testing.

TYPE: dict[str, MoviesTrainTestSplit]

train_chunk_size

The size of the chunks for training data. Defaults to 50.

TYPE: int DEFAULT: 50

batch_size

The number of samples per batch. Defaults to 32.

TYPE: int DEFAULT: 32

seed

The random seed for reproducibility. Defaults to 42.

TYPE: int DEFAULT: 42

clip_length

The length of each clip. Defaults to 100.

TYPE: int DEFAULT: 100

num_val_clips

The number of validation clips to draw. Defaults to 10.

TYPE: int DEFAULT: 10

val_clip_indices

The indices of validation clips to use. If provided, num_val_clips is ignored. Defaults to None.

TYPE: list[int] DEFAULT: None

allow_over_boundaries

Whether to allow selected chunks to go over scene boundaries.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

A dictionary containing dataloaders for training, validation, and testing for each session.

TYPE: dict[str, dict[str, DataLoader]]

RAISES DESCRIPTION
AssertionError

If the keys of neuron_data_dictionary and movies_dictionary do not match exactly.

Source code in openretina/data_io/base_dataloader.py
def multiple_movies_dataloaders(
    neuron_data_dictionary: dict[str, ResponsesTrainTestSplit],
    movies_dictionary: dict[str, MoviesTrainTestSplit],
    train_chunk_size: int = 50,
    batch_size: int = 32,
    seed: int = 42,
    clip_length: int = 100,
    num_val_clips: int = 10,
    val_clip_indices: list[int] | None = None,
    allow_over_boundaries: bool = True,
) -> dict[str, dict[str, DataLoader]]:
    """
    Create multiple dataloaders for training, validation, and testing from given neuron and movie data.
    This function ensures that the neuron data and movie data are aligned and generates dataloaders for each session.
    It does not assume that the movies in different sessions are the same, have the same length, or are composed
    of the same clips in the same order.

    Args:
        neuron_data_dictionary (dict[str, ResponsesTrainTestSplit]):
            A dictionary containing neuron response data split for training and testing.
        movies_dictionary (dict[str, MoviesTrainTestSplit]):
            A dictionary containing movie data split for training and testing.
        train_chunk_size (int, optional):
            The size of the chunks for training data. Defaults to 50.
        batch_size (int, optional):
            The number of samples per batch. Defaults to 32.
        seed (int, optional):
            The random seed for reproducibility. Defaults to 42.
        clip_length (int, optional):
            The length of each clip. Defaults to 100.
        num_val_clips (int, optional):
            The number of validation clips to draw. Defaults to 10.
        val_clip_indices (list[int], optional): The indices of validation clips to use. If provided, num_val_clips is
                                                ignored. Defaults to None.
        allow_over_boundaries (bool, optional): Whether to allow selected chunks to go over scene boundaries. Defaults to True.

    Returns:
        dict:
            A dictionary containing dataloaders for training, validation, and testing for each session.

    Raises:
        AssertionError:
            If the keys of neuron_data_dictionary and movies_dictionary do not match exactly.
    """
    assert set(neuron_data_dictionary.keys()) == set(movies_dictionary.keys()), (
        "The keys of neuron_data_dictionary and movies_dictionary should match exactly."
    )

    # Initialise dataloaders
    dataloaders: dict[str, Any] = collections.defaultdict(dict)

    for session_key, session_data in tqdm(neuron_data_dictionary.items(), desc="Creating movie dataloaders"):
        # Extract all data related to the movies first
        num_clips = movies_dictionary[session_key].train.shape[1] // clip_length

        if val_clip_indices is not None:
            val_clip_idx = val_clip_indices
        else:
            # Draw validation clips based on the random seed
            rnd = np.random.RandomState(seed)
            val_clip_idx = list(rnd.choice(num_clips, num_val_clips, replace=False))

        movie_train_subset, movie_val, movie_test_dict = generate_movie_splits(
            movies_dictionary[session_key].train,
            movies_dictionary[session_key].test_dict,
            val_clip_idc=val_clip_idx,
            num_clips=num_clips,
            clip_length=clip_length,
        )

        # Extract all splits from neural data
        neuron_data = NeuronDataSplit(
            responses=session_data,
            val_clip_idx=val_clip_idx,
            num_clips=num_clips,
            clip_length=clip_length,
        )

        # Create dataloaders for each fold
        for fold, movie, chunk_size in [
            ("train", movie_train_subset, train_chunk_size),
            ("validation", movie_val, clip_length),
        ]:
            dataloaders[fold][session_key] = get_movie_dataloader(
                movie=movie,
                responses=neuron_data.response_dict[fold],
                split=fold,
                chunk_size=chunk_size,
                batch_size=batch_size,
                scene_length=clip_length,
                allow_over_boundaries=allow_over_boundaries,
            )
        # test movies
        for name, movie in movie_test_dict.items():
            dataloaders[name][session_key] = get_movie_dataloader(
                movie=movie,
                responses=neuron_data.response_dict_test[name],
                split="test",
                chunk_size=movie.shape[1],
                batch_size=batch_size,
                scene_length=clip_length,
                allow_over_boundaries=allow_over_boundaries,
            )

    return dataloaders
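The seeded validation draw is what makes repeated dataloader construction consistent across runs; a standalone demo with a hypothetical clip count:

```python
import numpy as np

# Illustrative values (not from the library).
num_clips, num_val_clips, seed = 20, 5, 42

# Draw validation clips without replacement from a seeded RandomState.
rnd = np.random.RandomState(seed)
val_clip_idx = list(rnd.choice(num_clips, num_val_clips, replace=False))

# Same seed, same draw: rebuilding the dataloaders reproduces the split.
rnd2 = np.random.RandomState(seed)
assert list(rnd2.choice(num_clips, num_val_clips, replace=False)) == val_clip_idx
```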