
Base Dataloader

MovieDataSet

MovieDataSet(
    movies: Float[
        ndarray | Tensor, "n_channels n_frames h w"
    ],
    responses: Float[ndarray | Tensor, "n_frames n_neurons"]
    | dict[str, Any],
    roi_ids: Float[ndarray, " n_neurons"] | None,
    roi_coords: Float[ndarray, "n_neurons 2"] | None,
    group_assignment: Float[ndarray, " n_neurons"] | None,
    split: str
    | Literal["train", "validation", "val", "test"],
    chunk_size: int,
)

Bases: Dataset

A dataset class for handling movie data and corresponding neural responses.

PARAMETER DESCRIPTION
movies

The movie data.

TYPE: Float[ndarray | Tensor, 'n_channels n_frames h w']

responses

The neural responses. For the test split, a dictionary with "avg" and "by_trial" entries can be passed instead.

TYPE: Float[ndarray, 'n_frames n_neurons'] | dict[str, Float[ndarray, 'n_frames n_neurons']]

roi_ids

A list of ROI IDs.

TYPE: Optional[Float[ndarray, ' n_neurons']]

roi_coords

A list of ROI coordinates.

TYPE: Optional[Float[ndarray, 'n_neurons 2']]

group_assignment

A list of group assignments (cell types).

TYPE: Optional[Float[ndarray, ' n_neurons']]

split

The data split, one of "train", "validation", "val", or "test".

TYPE: Literal['train', 'validation', 'val', 'test']

chunk_size

The size of the chunks to split the data into.

TYPE: int

ATTRIBUTE DESCRIPTION
samples

A tuple containing movie data and neural responses.

TYPE: tuple

test_responses_by_trial

The per-trial test responses (only populated for the test split).

TYPE: Optional[Tensor]

roi_ids

A list of region of interest (ROI) IDs.

TYPE: Optional[Float[ndarray, ' n_neurons']]

chunk_size

The size of the chunks to split the data into.

TYPE: int

mean_response

The mean response per neuron.

TYPE: Tensor

group_assignment

A list of group assignments.

TYPE: Optional[Float[ndarray, ' n_neurons']]

roi_coords

A list of ROI coordinates.

TYPE: Optional[Float[ndarray, 'n_neurons 2']]

METHOD DESCRIPTION
__getitem__

Returns a DataPoint object for the given index or slice.

movies

Returns the movie data.

responses

Returns the neural responses.

__len__

Returns the number of chunks of clips and responses used for training.

__str__

Returns a string representation of the dataset.

__repr__

Returns a string representation of the dataset.

Source code in openretina/data_io/base_dataloader.py
def __init__(
    self,
    movies: Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"],
    responses: Float[np.ndarray | torch.Tensor, "n_frames n_neurons"] | dict[str, Any],
    roi_ids: Float[np.ndarray, " n_neurons"] | None,
    roi_coords: Float[np.ndarray, "n_neurons 2"] | None,
    group_assignment: Float[np.ndarray, " n_neurons"] | None,
    split: str | Literal["train", "validation", "val", "test"],
    chunk_size: int,
):
    self.roi_ids = roi_ids
    self.test_responses_by_trial: torch.Tensor | None = None

    # Test responses can be passed as a dictionary by other constructors,
    # with key "avg" for the averaged responses and "by_trial" for the per-trial responses.
    if split == "test" and isinstance(responses, dict):
        responses_dict = cast(dict[str, Any], responses)
        self.samples: tuple = movies, responses_dict["avg"]
        self.test_responses_by_trial = responses_dict.get("by_trial")
    else:
        self.samples = movies, responses

    self.chunk_size = chunk_size
    # Calculate the mean response per neuron (used for bias init in the model)
    self.mean_response = torch.mean(torch.Tensor(self.samples[1]), dim=0)
    self.group_assignment = group_assignment
    self.roi_coords = roi_coords
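For a non-test split, the constructor boils down to storing a `(movies, responses)` tuple and precomputing the per-neuron mean response. A minimal numpy sketch of that behaviour (hypothetical shapes; plain arrays stand in for the class and its tensors):

```python
import numpy as np

# Hypothetical shapes: 1 channel, 300 frames, 18x16 pixels, 5 neurons.
movies = np.random.rand(1, 300, 18, 16).astype(np.float32)
responses = np.random.rand(300, 5).astype(np.float32)

# What MovieDataSet.__init__ does for a non-test split, sketched in numpy:
samples = (movies, responses)
# Mean response per neuron (used for bias initialisation in the model):
mean_response = samples[1].mean(axis=0)

print(mean_response.shape)  # (5,)
```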

MovieSampler

MovieSampler(
    start_indices: list[int],
    split: str
    | Literal["train", "validation", "val", "test"],
    chunk_size: int,
    movie_length: int,
    scene_length: int,
    allow_over_boundaries: bool = False,
)

Bases: Sampler

A custom sampler for selecting movie frames for training, validation, or testing.

PARAMETER DESCRIPTION
start_indices

List of starting indices for the movie sections to select.

TYPE: list[int]

split

The type of data split.

TYPE: Literal['train', 'validation', 'val', 'test']

chunk_size

The size of each contiguous chunk of frames to select.

TYPE: int

movie_length

The total length of the movie.

TYPE: int

scene_length

The length of each scene, if the movie is divided into scenes.

TYPE: int

allow_over_boundaries

Whether to allow selected chunks to go over scene boundaries. Defaults to False.

TYPE: bool DEFAULT: False

ATTRIBUTE DESCRIPTION
indices

The starting indices for the movie sections to sample.

TYPE: list[int]

split

The type of data split.

TYPE: str

chunk_size

The size of each chunk of frames.

TYPE: int

movie_length

The total length of the movie.

TYPE: int

scene_length

The length of each scene, if the movie is made up of scenes.

TYPE: int

allow_over_boundaries

Whether to allow chunks to go over scene boundaries.

TYPE: bool

METHOD DESCRIPTION
__iter__

Returns an iterator over the sampled indices.

__len__

Returns the number of starting indices (which corresponds to the number of sampled clips).

Source code in openretina/data_io/base_dataloader.py
def __init__(
    self,
    start_indices: list[int],
    split: str | Literal["train", "validation", "val", "test"],
    chunk_size: int,
    movie_length: int,
    scene_length: int,
    allow_over_boundaries: bool = False,
):
    super().__init__()
    self.indices = start_indices
    self.split = split
    self.chunk_size = chunk_size
    self.movie_length = movie_length
    self.scene_length = scene_length
    self.allow_over_boundaries = allow_over_boundaries

NeuronDataSplit

NeuronDataSplit(
    responses: ResponsesTrainTestSplit,
    val_clip_idx: List[int],
    num_clips: int,
    clip_length: int,
    key: Optional[dict] = None,
    **kwargs,
)

Preprocesses ResponsesTrainTestSplit objects before feeding them to dataloaders.

Responsibilities
  • Remove validation clips from the training responses while storing them separately.
  • Expose torch tensors for train/val/test splits via response_dict.
  • Surface averaged and per-trial test responses for each test stimulus name so that downstream MovieDataSet instances can provide dataset.test_responses_by_trial.

Initialize the NeuronData object: a boilerplate class to compute and store train/validation/test splits of neuron data before feeding them into a dataloader.

PARAMETER DESCRIPTION
key

The key information for the neuron data, includes date, exp_num, experimenter, field_id, stim_id.

TYPE: dict DEFAULT: None

responses

The train and test responses of neurons.

TYPE: ResponsesTrainTestSplit

val_clip_idx

The indices of validation clips.

TYPE: List[int]

num_clips

The number of clips.

TYPE: int

clip_length

The length of each clip.

TYPE: int

Source code in openretina/data_io/base_dataloader.py
def __init__(
    self,
    responses: ResponsesTrainTestSplit,
    val_clip_idx: List[int],
    num_clips: int,
    clip_length: int,
    key: Optional[dict] = None,
    **kwargs,
):
    """
    Initialize the NeuronData object.
    Boilerplate class to compute and store train/validation/test splits of neuron data before feeding them into a dataloader.

    Args:
        key (dict): The key information for the neuron data,
                    includes date, exp_num, experimenter, field_id, stim_id.
        responses (ResponsesTrainTestSplit): The train and test responses of neurons.
        val_clip_idx (List[int]): The indices of validation clips.
        num_clips (int): The number of clips.
        clip_length (int): The length of each clip.
    """
    self.neural_responses = responses
    self.num_neurons = self.neural_responses.n_neurons
    self.key = key
    self.roi_coords = ()
    self.clip_length = clip_length
    self.num_clips = num_clips
    self.val_clip_idx = val_clip_idx

    # Transpose the responses to have the shape (n_timepoints, n_neurons)
    self.responses_train_and_val = self.neural_responses.train.T

    self.responses_train, self.responses_val = self.split_data_train_val()
    self.test_responses_by_trial: dict[str, np.ndarray] = (
        {name: np.asarray(by_trial) for name, by_trial in self.neural_responses.test_by_trial_dict.items()}
        if self.neural_responses.test_by_trial_dict is not None
        else {}
    )

response_dict property

response_dict: dict

Create and return a dictionary of neural responses for train, validation, and test datasets.

Structure

{
    "train": Tensor[T_train, neurons],
    "validation": Tensor[T_val, neurons],
    "test": {
        stimulus_name: {
            "avg": Tensor[T_test, neurons],
            "by_trial": Optional[Tensor[trials, T_test, neurons]],
        },
        ...
    },
}
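As an illustration, a mock of this layout with numpy arrays standing in for torch tensors, and "chirp" as a hypothetical test-stimulus name:

```python
import numpy as np

# Mock of the response_dict layout; numpy arrays stand in for torch tensors,
# and "chirp" is a hypothetical test-stimulus name.
n_neurons = 4
response_dict = {
    "train": np.zeros((1000, n_neurons)),      # Tensor[T_train, neurons]
    "validation": np.zeros((200, n_neurons)),  # Tensor[T_val, neurons]
    "test": {
        "chirp": {
            "avg": np.zeros((150, n_neurons)),          # averaged over trials
            "by_trial": np.zeros((3, 150, n_neurons)),  # may also be None
        },
    },
}
```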

response_dict_test property

response_dict_test: dict[str, dict[str, Tensor | None]]

Torch representation of the averaged and per-trial test responses keyed by stimulus name.

RETURNS DESCRIPTION
dict[str, dict[str, Tensor | None]]

{
    stimulus_name: {
        "avg": Tensor[t_test, neurons],
        "by_trial": Optional[Tensor[trials, t_test, neurons]],
    },
    ...
}

split_data_train_val

split_data_train_val() -> tuple[ndarray, ndarray]

Compute validation responses and updated train responses stripped of validation clips. Can deal with unsorted validation clip indices, and parallels the way movie validation clips are handled.

RETURNS DESCRIPTION
tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: The updated train and validation responses.

Source code in openretina/data_io/base_dataloader.py
def split_data_train_val(self) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute validation responses and updated train responses stripped of validation clips.
    Can deal with unsorted validation clip indices, and parallels the way movie validation clips are handled.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The updated train and validation responses.
    """
    # Initialise validation responses
    base_movie_sorting = np.arange(self.num_clips)

    validation_mask = np.ones_like(self.responses_train_and_val, dtype=bool)
    responses_val = np.zeros([len(self.val_clip_idx) * self.clip_length, self.num_neurons])

    # Compute validation responses and remove sections from training responses
    for i, ind1 in enumerate(self.val_clip_idx):
        grab_index = base_movie_sorting[ind1]
        responses_val[i * self.clip_length : (i + 1) * self.clip_length, :] = self.responses_train_and_val[
            grab_index * self.clip_length : (grab_index + 1) * self.clip_length,
            :,
        ]
        validation_mask[
            (grab_index * self.clip_length) : (grab_index + 1) * self.clip_length,
            :,
        ] = False

    responses_train = self.responses_train_and_val[validation_mask].reshape(-1, self.num_neurons)

    return responses_train, responses_val
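The masking logic above can be reproduced standalone on toy data to see how unsorted validation clip indices are handled (hypothetical sizes):

```python
import numpy as np

# Standalone sketch of the validation-split logic: 4 clips of length 3,
# 2 neurons; validation clips deliberately given out of order.
num_clips, clip_length, n_neurons = 4, 3, 2
responses = np.arange(num_clips * clip_length * n_neurons, dtype=float).reshape(
    num_clips * clip_length, n_neurons
)
val_clip_idx = [2, 0]  # unsorted on purpose

validation_mask = np.ones_like(responses, dtype=bool)
responses_val = np.zeros((len(val_clip_idx) * clip_length, n_neurons))
for i, clip in enumerate(val_clip_idx):
    sl = slice(clip * clip_length, (clip + 1) * clip_length)
    # Copy the validation clip in the order given by val_clip_idx ...
    responses_val[i * clip_length : (i + 1) * clip_length] = responses[sl]
    # ... and mask it out of the training responses.
    validation_mask[sl] = False

responses_train = responses[validation_mask].reshape(-1, n_neurons)
print(responses_train.shape, responses_val.shape)  # (6, 2) (6, 2)
```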

generate_movie_splits

generate_movie_splits(
    movie_train,
    movie_test: dict[str, ndarray],
    val_clip_idc: list[int],
    num_clips: int,
    clip_length: int,
) -> tuple[Tensor, Tensor, dict[str, Tensor]]

Split training movies into train/validation subsets and convert test movies to tensors.

Source code in openretina/data_io/base_dataloader.py
def generate_movie_splits(
    movie_train,
    movie_test: dict[str, np.ndarray],
    val_clip_idc: list[int],
    num_clips: int,
    clip_length: int,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]:
    """Split training movies into train/validation subsets and convert test movies to tensors."""
    movie_train = torch.tensor(movie_train, dtype=torch.float)
    movie_test_dict = {n: torch.tensor(movie, dtype=torch.float) for n, movie in movie_test.items()}

    channels, _, px_y, px_x = movie_train.shape

    # Prepare validation movie data
    movie_val = torch.zeros((channels, len(val_clip_idc) * clip_length, px_y, px_x), dtype=torch.float)
    for i, idx in enumerate(val_clip_idc):
        movie_val[:, i * clip_length : (i + 1) * clip_length, ...] = movie_train[
            :, idx * clip_length : (idx + 1) * clip_length, ...
        ]

    # Create a boolean mask to indicate which clips are not part of the validation set
    mask = np.ones(num_clips, dtype=bool)
    mask[val_clip_idc] = False
    train_clip_idx = np.arange(num_clips)[mask]

    movie_train_subset = torch.cat(
        [movie_train[:, i * clip_length : (i + 1) * clip_length] for i in train_clip_idx],
        dim=1,
    )

    return movie_train_subset, movie_val, movie_test_dict
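The same clip selection can be sketched in plain numpy (hypothetical sizes; `np.concatenate` stands in for `torch.cat`):

```python
import numpy as np

# Numpy sketch of the clip selection above: 1 channel, 5 clips of 10 frames,
# 4x4 pixels; clips 1 and 3 go to validation.
clip_length, num_clips = 10, 5
movie_train = np.random.rand(1, num_clips * clip_length, 4, 4)
val_clip_idc = [1, 3]

# Validation movie: validation clips concatenated along the frame axis.
movie_val = np.concatenate(
    [movie_train[:, i * clip_length : (i + 1) * clip_length] for i in val_clip_idc],
    axis=1,
)

# Boolean mask selecting the remaining (training) clips.
mask = np.ones(num_clips, dtype=bool)
mask[val_clip_idc] = False
train_clip_idx = np.arange(num_clips)[mask]

movie_train_subset = np.concatenate(
    [movie_train[:, i * clip_length : (i + 1) * clip_length] for i in train_clip_idx],
    axis=1,
)
print(movie_train_subset.shape, movie_val.shape)  # (1, 30, 4, 4) (1, 20, 4, 4)
```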

gen_shifts_with_boundaries

gen_shifts_with_boundaries(
    clip_bounds: list[int] | ndarray,
    start_indices: list[int] | ndarray,
    clip_chunk_size: int = 50,
) -> list[int]

Generate shifted indices based on clip bounds and start indices. Assumes that the original start indices are already within the clip bounds. If they are not, overflowing indices are moved to respect the closest bound.

PARAMETER DESCRIPTION
clip_bounds

A list of clip bounds.

TYPE: list

start_indices

A list of start indices.

TYPE: list

clip_chunk_size

The size of each clip chunk. Defaults to 50.

TYPE: int DEFAULT: 50

RETURNS DESCRIPTION
list

A list of shifted indices.

TYPE: list[int]

Source code in openretina/data_io/base_dataloader.py
def gen_shifts_with_boundaries(
    clip_bounds: list[int] | np.ndarray, start_indices: list[int] | np.ndarray, clip_chunk_size: int = 50
) -> list[int]:
    """
    Generate shifted indices based on clip bounds and start indices.
    Assumes that the original start indices are already within the clip bounds.
    If they are not, it moves the overflowing indices to respect the closest bound.

    Args:
        clip_bounds (list): A list of clip bounds.
        start_indices (list): A list of start indices.
        clip_chunk_size (int, optional): The size of each clip chunk. Defaults to 50.

    Returns:
        list: A list of shifted indices.

    """

    def get_next_bound(value, bounds):
        insertion_index = bisect.bisect_right(bounds, value)
        return bounds[min(insertion_index, len(bounds) - 1)]

    shifted_indices = []
    shifts = np.random.randint(0, clip_chunk_size // 2, len(start_indices))

    for i, start_idx in enumerate(start_indices):
        next_bound = get_next_bound(start_idx, clip_bounds)
        if start_idx + shifts[i] + clip_chunk_size < next_bound:
            shifted_indices.append(start_idx + shifts[i])
        elif start_idx + clip_chunk_size > next_bound:
            shifted_indices.append(next_bound - clip_chunk_size)
        else:
            shifted_indices.append(start_idx)

    # Ensure we do not exceed the movie length when allowing over boundaries
    if shifted_indices[-1] + clip_chunk_size > clip_bounds[-1]:
        shifted_indices[-1] = clip_bounds[-1] - clip_chunk_size
    return shifted_indices
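A usage sketch with the function reproduced inline (so the example is self-contained) and a fixed seed; two hypothetical clips of 100 frames each:

```python
import bisect
import numpy as np

def gen_shifts_with_boundaries(clip_bounds, start_indices, clip_chunk_size=50):
    # Same logic as the source above, reproduced for a runnable example.
    def get_next_bound(value, bounds):
        insertion_index = bisect.bisect_right(bounds, value)
        return bounds[min(insertion_index, len(bounds) - 1)]

    shifted_indices = []
    shifts = np.random.randint(0, clip_chunk_size // 2, len(start_indices))
    for i, start_idx in enumerate(start_indices):
        next_bound = get_next_bound(start_idx, clip_bounds)
        if start_idx + shifts[i] + clip_chunk_size < next_bound:
            shifted_indices.append(start_idx + shifts[i])  # shift fits in the clip
        elif start_idx + clip_chunk_size > next_bound:
            shifted_indices.append(next_bound - clip_chunk_size)  # snap back to bound
        else:
            shifted_indices.append(start_idx)  # no room to shift; keep as-is
    if shifted_indices[-1] + clip_chunk_size > clip_bounds[-1]:
        shifted_indices[-1] = clip_bounds[-1] - clip_chunk_size
    return shifted_indices

np.random.seed(0)
# Two clips of 100 frames; chunks of 50 starting every 50 frames.
shifted = gen_shifts_with_boundaries([0, 100, 200], [0, 50, 100, 150], clip_chunk_size=50)
# Chunks at 50 and 150 sit flush against a bound, so they cannot shift;
# every shifted chunk stays within [0, 200 - 50].
assert all(0 <= s <= 150 for s in shifted)
```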

get_movie_dataloader

get_movie_dataloader(
    movie: Float[
        ndarray | Tensor, "n_channels n_frames h w"
    ],
    responses: Float[ndarray | Tensor, "n_frames n_neurons"]
    | dict[str, Any],
    *,
    split: str
    | Literal["train", "validation", "val", "test"],
    scene_length: int,
    chunk_size: int,
    batch_size: int,
    start_indices: list[int] | None = None,
    roi_ids: Float[ndarray, " n_neurons"] | None = None,
    roi_coords: Float[ndarray, "n_neurons 2"] | None = None,
    group_assignment: Float[ndarray, " n_neurons"]
    | None = None,
    drop_last: bool = True,
    allow_over_boundaries: bool = True,
    **kwargs,
) -> DataLoader

Create a DataLoader for processing movie data and associated responses. This function prepares the dataset and sampler for training or evaluation based on the specified parameters.

PARAMETER DESCRIPTION
movie

The movie data represented as a multi-dimensional array or tensor.

TYPE: Float[ndarray | Tensor, 'n_channels n_frames h w']

responses

The responses corresponding to the frames of the movie; for the test split, may be a dictionary with "avg" and "by_trial" entries.

TYPE: Float[ndarray | Tensor, 'n_frames n_neurons'] | dict[str, Any]

split

The dataset split to use (train, validation, or test).

TYPE: str | Literal['train', 'validation', 'val', 'test']

scene_length

The length of the scene to be processed.

TYPE: int

chunk_size

The size of each chunk to be extracted from the movie.

TYPE: int

batch_size

The number of samples per batch.

TYPE: int

start_indices

The starting indices for each chunk. If None, will be computed.

TYPE: list[int] | None DEFAULT: None

roi_ids

The region of interest IDs. If None, will not be used.

TYPE: Float[ndarray, ' n_neurons'] | None DEFAULT: None

roi_coords

The coordinates of the regions of interest. If None, will not be used.

TYPE: Float[ndarray, 'n_neurons 2'] | None DEFAULT: None

group_assignment

The group assignments (cell types) for the neurons. If None, will not be used.

TYPE: Float[ndarray, ' n_neurons'] | None DEFAULT: None

drop_last

Whether to drop the last incomplete batch. Defaults to True.

TYPE: bool DEFAULT: True

allow_over_boundaries

Whether to allow chunks that exceed the scene boundaries. Defaults to True.

TYPE: bool DEFAULT: True

**kwargs

Additional keyword arguments for the DataLoader.

DEFAULT: {}

RETURNS DESCRIPTION
DataLoader

A DataLoader instance configured with the specified dataset and sampler.

TYPE: DataLoader

RAISES DESCRIPTION
ValueError

If allow_over_boundaries is False and chunk_size exceeds scene_length during training.

Source code in openretina/data_io/base_dataloader.py
def get_movie_dataloader(
    movie: Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"],
    responses: Float[np.ndarray | torch.Tensor, "n_frames n_neurons"] | dict[str, Any],
    *,
    split: str | Literal["train", "validation", "val", "test"],
    scene_length: int,
    chunk_size: int,
    batch_size: int,
    start_indices: list[int] | None = None,
    roi_ids: Float[np.ndarray, " n_neurons"] | None = None,
    roi_coords: Float[np.ndarray, "n_neurons 2"] | None = None,
    group_assignment: Float[np.ndarray, " n_neurons"] | None = None,
    drop_last: bool = True,
    allow_over_boundaries: bool = True,
    **kwargs,
) -> DataLoader:
    """
    Create a DataLoader for processing movie data and associated responses.
    This function prepares the dataset and sampler for training or evaluation based on the specified parameters.

    Args:
        movie (Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"]):
            The movie data represented as a multi-dimensional array or tensor.
        responses (Float[np.ndarray, "n_frames n_neurons"]):
            The responses corresponding to the frames of the movie.
        split (str | Literal["train", "validation", "val", "test"]):
            The dataset split to use (train, validation, or test).
        scene_length (int):
            The length of the scene to be processed.
        chunk_size (int):
            The size of each chunk to be extracted from the movie.
        batch_size (int):
            The number of samples per batch.
        start_indices (list[int] | None, optional):
            The starting indices for each chunk. If None, will be computed.
        roi_ids (Float[np.ndarray, " n_neurons"] | None, optional):
            The region of interest IDs. If None, will not be used.
        roi_coords (Float[np.ndarray, "n_neurons 2"] | None, optional):
            The coordinates of the regions of interest. If None, will not be used.
        group_assignment (Float[np.ndarray, " n_neurons"] | None, optional):
            The group assignments (cell types) for the neurons. If None, will not be used.
        drop_last (bool, optional):
            Whether to drop the last incomplete batch. Defaults to True.
        allow_over_boundaries (bool, optional):
            Whether to allow chunks that exceed the scene boundaries. Defaults to True.
        **kwargs:
            Additional keyword arguments for the DataLoader.

    Returns:
        DataLoader:
            A DataLoader instance configured with the specified dataset and sampler.

    Raises:
        ValueError:
            If `allow_over_boundaries` is False and `chunk_size` exceeds `scene_length` during training.
    """
    if (isinstance(responses, torch.Tensor) and bool(torch.isnan(responses).any())) or (
        isinstance(responses, np.ndarray) and bool(np.isnan(responses).any())
    ):
        log.warning("Nans in responses, skipping this dataloader")
        return  # type: ignore

    if not allow_over_boundaries and split == "train" and chunk_size > scene_length:
        raise ValueError("Clip chunk size must be smaller than scene length to not exceed clip bounds during training.")

    if start_indices is None:
        if split == "train":
            interval = chunk_size
        elif split in {"validation", "val"}:
            interval = scene_length
        elif split == "test":
            interval = movie.shape[1] if allow_over_boundaries else scene_length
        else:
            raise NotImplementedError("Start indices could not be recovered.")
        start_indices = np.arange(0, movie.shape[1], interval).tolist()  # type: ignore

    dataset = MovieDataSet(movie, responses, roi_ids, roi_coords, group_assignment, split, chunk_size)
    sampler = MovieSampler(
        start_indices,
        split,
        chunk_size,
        movie_length=movie.shape[1],
        scene_length=scene_length,
        allow_over_boundaries=allow_over_boundaries,
    )

    return DataLoader(
        dataset, sampler=sampler, batch_size=batch_size, drop_last=split == "train" and drop_last, **kwargs
    )
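When `start_indices` is None, the defaults reduce to a simple `np.arange` per split. A sketch for a hypothetical 300-frame movie with `scene_length=100` and `chunk_size=50`:

```python
import numpy as np

# Default start-index computation in get_movie_dataloader, per split:
n_frames, scene_length, chunk_size = 300, 100, 50

start_train = np.arange(0, n_frames, chunk_size).tolist()    # one index per chunk
start_val = np.arange(0, n_frames, scene_length).tolist()    # one index per scene
# For "test" with allow_over_boundaries=True the interval is the whole movie,
# so the test split is sampled as a single clip:
start_test = np.arange(0, n_frames, n_frames).tolist()

print(start_train)  # [0, 50, 100, 150, 200, 250]
print(start_val)    # [0, 100, 200]
print(start_test)   # [0]
```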

multiple_movies_dataloaders

multiple_movies_dataloaders(
    neuron_data_dictionary: dict[
        str, ResponsesTrainTestSplit
    ],
    movies_dictionary: dict[str, MoviesTrainTestSplit],
    train_chunk_size: int = 50,
    batch_size: int = 32,
    seed: int = 42,
    clip_length: int = 100,
    num_val_clips: int = 10,
    val_clip_indices: list[int] | None = None,
    allow_over_boundaries: bool = True,
) -> dict[str, dict[str, DataLoader]]

Create multiple dataloaders for training, validation, and testing from given neuron and movie data. This function ensures that the neuron data and movie data are aligned and generates dataloaders for each session. It does not assume that the movies in different sessions are the same, have the same length, or are composed of the same clips in the same order.

PARAMETER DESCRIPTION
neuron_data_dictionary

A dictionary containing neuron response data split for training and testing.

TYPE: dict[str, ResponsesTrainTestSplit]

movies_dictionary

A dictionary containing movie data split for training and testing.

TYPE: dict[str, MoviesTrainTestSplit]

train_chunk_size

The size of the chunks for training data. Defaults to 50.

TYPE: int DEFAULT: 50

batch_size

The number of samples per batch. Defaults to 32.

TYPE: int DEFAULT: 32

seed

The random seed for reproducibility. Defaults to 42.

TYPE: int DEFAULT: 42

clip_length

The length of each clip. Defaults to 100.

TYPE: int DEFAULT: 100

num_val_clips

The number of validation clips to draw. Defaults to 10.

TYPE: int DEFAULT: 10

val_clip_indices

The indices of validation clips to use. If provided, num_val_clips is ignored. Defaults to None.

TYPE: list[int] DEFAULT: None

allow_over_boundaries

Whether to allow selected chunks to go over scene boundaries.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

A dictionary containing dataloaders for training, validation, and testing for each session.

TYPE: dict[str, dict[str, DataLoader]]

RAISES DESCRIPTION
AssertionError

If the keys of neuron_data_dictionary and movies_dictionary do not match exactly.

Source code in openretina/data_io/base_dataloader.py
def multiple_movies_dataloaders(
    neuron_data_dictionary: dict[str, ResponsesTrainTestSplit],
    movies_dictionary: dict[str, MoviesTrainTestSplit],
    train_chunk_size: int = 50,
    batch_size: int = 32,
    seed: int = 42,
    clip_length: int = 100,
    num_val_clips: int = 10,
    val_clip_indices: list[int] | None = None,
    allow_over_boundaries: bool = True,
) -> dict[str, dict[str, DataLoader]]:
    """
    Create multiple dataloaders for training, validation, and testing from given neuron and movie data.
    This function ensures that the neuron data and movie data are aligned and generates dataloaders for each session.
    It does not assume that the movies in different sessions are the same, have the same length, or are composed
    of the same clips in the same order.

    Args:
        neuron_data_dictionary (dict[str, ResponsesTrainTestSplit]):
            A dictionary containing neuron response data split for training and testing.
        movies_dictionary (dict[str, MoviesTrainTestSplit]):
            A dictionary containing movie data split for training and testing.
        train_chunk_size (int, optional):
            The size of the chunks for training data. Defaults to 50.
        batch_size (int, optional):
            The number of samples per batch. Defaults to 32.
        seed (int, optional):
            The random seed for reproducibility. Defaults to 42.
        clip_length (int, optional):
            The length of each clip. Defaults to 100.
        num_val_clips (int, optional):
            The number of validation clips to draw. Defaults to 10.
        val_clip_indices (list[int], optional): The indices of validation clips to use. If provided, num_val_clips is
                                                ignored. Defaults to None.
        allow_over_boundaries (bool, optional):  Whether to allow selected chunks to go over scene boundaries.

    Returns:
        dict:
            A dictionary containing dataloaders for training, validation, and testing for each session.

    Raises:
        AssertionError:
            If the keys of neuron_data_dictionary and movies_dictionary do not match exactly.
    """
    assert set(neuron_data_dictionary.keys()) == set(movies_dictionary.keys()), (
        "The keys of neuron_data_dictionary and movies_dictionary should match exactly."
    )

    # Initialise dataloaders
    dataloaders: dict[str, Any] = collections.defaultdict(dict)

    for session_key, session_data in tqdm(neuron_data_dictionary.items(), desc="Creating movie dataloaders"):
        # Extract all data related to the movies first
        num_clips = movies_dictionary[session_key].train.shape[1] // clip_length

        if val_clip_indices is not None:
            val_clip_idx = val_clip_indices
        else:
            # Draw validation clips based on the random seed
            rnd = np.random.RandomState(seed)
            val_clip_idx = list(rnd.choice(num_clips, num_val_clips, replace=False))

        movie_train_subset, movie_val, movie_test_dict = generate_movie_splits(
            movies_dictionary[session_key].train,
            movies_dictionary[session_key].test_dict,
            val_clip_idc=val_clip_idx,
            num_clips=num_clips,
            clip_length=clip_length,
        )

        # Extract all splits from neural data
        neuron_data = NeuronDataSplit(
            responses=session_data,
            val_clip_idx=val_clip_idx,
            num_clips=num_clips,
            clip_length=clip_length,
        )

        # Create dataloaders for each fold
        for fold, movie, chunk_size in [
            ("train", movie_train_subset, train_chunk_size),
            ("validation", movie_val, clip_length),
        ]:
            dataloaders[fold][session_key] = get_movie_dataloader(
                movie=movie,
                responses=neuron_data.response_dict[fold],
                split=fold,
                chunk_size=chunk_size,
                batch_size=batch_size,
                scene_length=clip_length,
                allow_over_boundaries=allow_over_boundaries,
            )
        # test movies
        for name, movie in movie_test_dict.items():
            if allow_over_boundaries:
                test_chunk_size = movie.shape[1]
            else:
                test_chunk_size = clip_length

            test_batch_size = _compute_test_batch_size(batch_size, train_chunk_size, test_chunk_size)
            dataloaders[name][session_key] = get_movie_dataloader(
                movie=movie,
                responses=neuron_data.response_dict_test[name],
                split="test",
                chunk_size=test_chunk_size,
                batch_size=test_batch_size,
                scene_length=clip_length,
                allow_over_boundaries=allow_over_boundaries,
            )

    return dataloaders