Base Dataloader

MovieDataSet

Bases: Dataset

A dataset class for handling movie data and corresponding neural responses.

Parameters:

    movies (Float[ndarray | Tensor, "n_channels n_frames h w"], required):
        The movie data.
    responses (Float[ndarray, "n_frames n_neurons"], required):
        The neural responses.
    roi_ids (Optional[Float[ndarray, " n_neurons"]], required):
        A list of ROI IDs.
    roi_coords (Optional[Float[ndarray, "n_neurons 2"]], required):
        A list of ROI coordinates.
    group_assignment (Optional[Float[ndarray, " n_neurons"]], required):
        A list of group assignments (cell types).
    split (Literal["train", "validation", "val", "test"], required):
        The data split, either "train", "validation", "val", or "test".
    chunk_size (int, required):
        The size of the chunks to split the data into.

Attributes:

    samples (tuple):
        A tuple containing movie data and neural responses.
    test_responses_by_trial (Optional[Dict[str, Any]]):
        A dictionary containing test responses by trial (only for the test split).
    roi_ids (Optional[Float[ndarray, " n_neurons"]]):
        A list of region of interest (ROI) IDs.
    chunk_size (int):
        The size of the chunks to split the data into.
    mean_response (Tensor):
        The mean response per neuron.
    group_assignment (Optional[Float[ndarray, " n_neurons"]]):
        A list of group assignments.
    roi_coords (Optional[Float[ndarray, "n_neurons 2"]]):
        A list of ROI coordinates.

Methods:

    __getitem__:
        Returns a DataPoint object for the given index or slice.
    movies:
        Returns the movie data.
    responses:
        Returns the neural responses.
    __len__:
        Returns the number of chunks of clips and responses used for training.
    __str__:
        Returns a string representation of the dataset.
    __repr__:
        Returns a string representation of the dataset.

Source code in openretina/data_io/base_dataloader.py
class MovieDataSet(Dataset):
    """
    A dataset class for handling movie data and corresponding neural responses.

    Args:
        movies (Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"]): The movie data.
        responses (Float[np.ndarray, "n_frames n_neurons"]): The neural responses.
        roi_ids (Optional[Float[np.ndarray, " n_neurons"]]): A list of ROI IDs.
        roi_coords (Optional[Float[np.ndarray, "n_neurons 2"]]): A list of ROI coordinates.
        group_assignment (Optional[Float[np.ndarray, " n_neurons"]]): A list of group assignments (cell types).
        split (Literal["train", "validation", "val", "test"]):
                                                    The data split, either "train", "validation", "val", or "test".
        chunk_size (int): The size of the chunks to split the data into.

    Attributes:
        samples (tuple): A tuple containing movie data and neural responses.
        test_responses_by_trial (Optional[Dict[str, Any]]):
                                                A dictionary containing test responses by trial (only for test split).
        roi_ids (Optional[Float[np.ndarray, " n_neurons"]]): A list of region of interest (ROI) IDs.
        chunk_size (int): The size of the chunks to split the data into.
        mean_response (torch.Tensor): The mean response per neuron.
        group_assignment (Optional[Float[np.ndarray, " n_neurons"]]): A list of group assignments.
        roi_coords (Optional[Float[np.ndarray, "n_neurons 2"]]): A list of ROI coordinates.

    Methods:
        __getitem__(idx): Returns a DataPoint object for the given index or slice.
        movies: Returns the movie data.
        responses: Returns the neural responses.
        __len__(): Returns the number of chunks of clips and responses used for training.
        __str__(): Returns a string representation of the dataset.
        __repr__(): Returns a string representation of the dataset.
    """

    def __init__(
        self,
        movies: Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"],
        responses: Float[np.ndarray | torch.Tensor, "n_frames n_neurons"],
        roi_ids: Float[np.ndarray, " n_neurons"] | None,
        roi_coords: Float[np.ndarray, "n_neurons 2"] | None,
        group_assignment: Float[np.ndarray, " n_neurons"] | None,
        split: str | Literal["train", "validation", "val", "test"],
        chunk_size: int,
    ):
        # Will only be a dictionary for certain types of datasets, e.g. Hoefling 2022
        if split == "test" and isinstance(responses, dict):
            self.samples: tuple = movies, responses["avg"]
            self.test_responses_by_trial = responses["by_trial"]
            self.roi_ids = roi_ids
        else:
            self.samples = movies, responses

        self.chunk_size = chunk_size
        # Calculate the mean response per neuron (used for bias init in the model)
        self.mean_response = torch.mean(torch.Tensor(self.samples[1]), dim=0)
        self.group_assignment = group_assignment
        self.roi_coords = roi_coords

    def __getitem__(self, idx: int | slice) -> DataPoint:
        if isinstance(idx, slice):
            return DataPoint(*[self.samples[0][:, idx, ...], self.samples[1][idx, ...]])
        else:
            return DataPoint(
                *[
                    self.samples[0][:, idx : idx + self.chunk_size, ...],
                    self.samples[1][idx : idx + self.chunk_size, ...],
                ]
            )

    @property
    def movies(self):
        return self.samples[0]

    @property
    def responses(self):
        return self.samples[1]

    def __len__(self) -> int:
        # Returns the number of chunks of clips and responses used for training
        return self.samples[1].shape[0] // self.chunk_size

    def __str__(self) -> str:
        return (
            f"MovieDataSet with {self.samples[1].shape[1]} neuron responses "
            f"to a movie of shape {list(self.samples[0].shape)}."
        )

    def __repr__(self) -> str:
        return str(self)
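
A minimal usage sketch with synthetic data (shapes follow the docstring above; the exact fields of the returned DataPoint are defined elsewhere in the module):

import numpy as np

movies = np.random.rand(1, 1000, 18, 16).astype(np.float32)  # (n_channels, n_frames, h, w)
responses = np.random.rand(1000, 42).astype(np.float32)      # (n_frames, n_neurons)

dataset = MovieDataSet(
    movies, responses,
    roi_ids=None, roi_coords=None, group_assignment=None,
    split="train", chunk_size=50,
)
print(len(dataset))  # 1000 // 50 = 20 chunks
chunk = dataset[0]   # DataPoint holding a (1, 50, 18, 16) clip and its (50, 42) responses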

MovieSampler

Bases: Sampler

A custom sampler for selecting movie frames for training, validation, or testing.

Parameters:

    start_indices (list[int], required):
        List of starting indices for the movie sections to select.
    split (Literal["train", "validation", "val", "test"], required):
        The type of data split.
    chunk_size (int, required):
        The size of each contiguous chunk of frames to select.
    movie_length (int, required):
        The total length of the movie.
    scene_length (int, required):
        The length of each scene, if the movie is divided into scenes.
    allow_over_boundaries (bool, default False):
        Whether to allow selected chunks to go over scene boundaries.

Attributes:

    indices (list[int]):
        The starting indices for the movie sections to sample.
    split (str):
        The type of data split.
    chunk_size (int):
        The size of each chunk of frames.
    movie_length (int):
        The total length of the movie.
    scene_length (int):
        The length of each scene, if the movie is made up of scenes.
    allow_over_boundaries (bool):
        Whether to allow chunks to go over scene boundaries.

Methods:

    __iter__:
        Returns an iterator over the sampled indices.
    __len__:
        Returns the number of starting indices (which corresponds to the number of sampled clips).

Source code in openretina/data_io/base_dataloader.py
class MovieSampler(Sampler):
    """
    A custom sampler for selecting movie frames for training, validation, or testing.

    Args:
        start_indices (list[int]): List of starting indices for the movie sections to select.
        split (Literal["train", "validation", "val", "test"]): The type of data split.
        chunk_size (int): The size of each contiguous chunk of frames to select.
        movie_length (int): The total length of the movie.
        scene_length (int): The length of each scene, if the movie is divided into scenes.
        allow_over_boundaries (bool, optional): Whether to allow selected chunks to go over scene boundaries.
                                                Defaults to False.

    Attributes:
        indices (list[int]): The starting indices for the movie sections to sample.
        split (str): The type of data split.
        chunk_size (int): The size of each chunk of frames.
        movie_length (int): The total length of the movie.
        scene_length (int): The length of each scene, if the movie is made up of scenes.
        allow_over_boundaries (bool): Whether to allow chunks to go over scene boundaries.

    Methods:
        __iter__(): Returns an iterator over the sampled indices.
        __len__(): Returns the number of starting indices (which corresponds to the number of sampled clips).
    """

    def __init__(
        self,
        start_indices: list[int],
        split: str | Literal["train", "validation", "val", "test"],
        chunk_size: int,
        movie_length: int,
        scene_length: int,
        allow_over_boundaries: bool = False,
    ):
        super().__init__()
        self.indices = start_indices
        self.split = split
        self.chunk_size = chunk_size
        self.movie_length = movie_length
        self.scene_length = scene_length
        self.allow_over_boundaries = allow_over_boundaries

    def __iter__(self):
        if self.split == "train" and (self.scene_length != self.chunk_size):
            if self.allow_over_boundaries:
                shifts = np.random.randint(0, self.chunk_size, len(self.indices))
                # apply shifts while making sure we do not exceed the movie length
                shifted_indices = np.minimum(self.indices + shifts, self.movie_length - self.chunk_size)
            else:
                shifted_indices = gen_shifts_with_boundaries(
                    np.arange(0, self.movie_length + 1, self.scene_length),
                    self.indices,
                    self.chunk_size,
                )
            # Shuffle the indices
            indices_shuffling = np.random.permutation(len(self.indices))
        else:
            shifted_indices = self.indices
            indices_shuffling = np.arange(len(self.indices))

        return iter(np.array(shifted_indices)[indices_shuffling])

    def __len__(self) -> int:
        return len(self.indices)
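
A short sketch of the sampler's train-split behaviour (synthetic numbers; start indices are assumed to be aligned to 100-frame scenes):

import numpy as np

start_indices = list(range(0, 1000, 100))  # one candidate start per 100-frame scene
sampler = MovieSampler(
    start_indices, split="train", chunk_size=50,
    movie_length=1000, scene_length=100, allow_over_boundaries=True,
)
# Each epoch, the starts are randomly shifted by up to chunk_size frames
# (clamped so no chunk runs past the movie end) and then shuffled:
print(list(iter(sampler)))
print(len(sampler))  # 10, one clip per start index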

NeuronDataSplit

Source code in openretina/data_io/base_dataloader.py
class NeuronDataSplit:
    def __init__(
        self,
        responses: ResponsesTrainTestSplit,
        val_clip_idx: List[int],
        num_clips: int,
        clip_length: int,
        key: Optional[dict] = None,
        **kwargs,
    ):
        """
        Initialize the NeuronData object.
        Boilerplate class to compute and store neuron data train/test/validation splits before feeding into a dataloader

        Args:
            key (dict): The key information for the neuron data,
                        includes date, exp_num, experimenter, field_id, stim_id.
            responses (ResponsesTrainTestSplit): The train and test responses of neurons.
            val_clip_idx (List[int]): The indices of validation clips.
            num_clips (int): The number of clips.
            clip_length (int): The length of each clip.
            key (dict, optional): Additional key information.
        """
        self.neural_responses = responses
        self.num_neurons = self.neural_responses.n_neurons
        self.key = key
        self.roi_coords = ()
        self.clip_length = clip_length
        self.num_clips = num_clips
        self.val_clip_idx = val_clip_idx

        # Transpose the responses to have the shape (n_timepoints, n_neurons)
        self.responses_train_and_val = self.neural_responses.train.T

        self.responses_train, self.responses_val = self.split_data_train_val()
        self.test_responses_by_trial = np.array([])  # Added for compatibility with Hoefling et al., 2024

    def split_data_train_val(self) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute validation responses and updated train responses with the validation clips stripped out.
        Handles unsorted validation clip indices and parallels the way movie validation clips are handled.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The updated train and validation responses.
        """
        # Initialise validation responses
        base_movie_sorting = np.arange(self.num_clips)

        validation_mask = np.ones_like(self.responses_train_and_val, dtype=bool)
        responses_val = np.zeros([len(self.val_clip_idx) * self.clip_length, self.num_neurons])

        # Compute validation responses and remove sections from training responses
        for i, ind1 in enumerate(self.val_clip_idx):
            grab_index = base_movie_sorting[ind1]
            responses_val[i * self.clip_length : (i + 1) * self.clip_length, :] = self.responses_train_and_val[
                grab_index * self.clip_length : (grab_index + 1) * self.clip_length,
                :,
            ]
            validation_mask[
                (grab_index * self.clip_length) : (grab_index + 1) * self.clip_length,
                :,
            ] = False

        responses_train = self.responses_train_and_val[validation_mask].reshape(-1, self.num_neurons)

        return responses_train, responses_val

    @property
    def response_dict(self) -> dict:
        """
        Create and return a dictionary of neural responses for train, validation, and test datasets.
        """
        return {
            "train": torch.tensor(self.responses_train, dtype=torch.float),
            "validation": torch.tensor(self.responses_val, dtype=torch.float),
            "test": {
                "avg": self.response_dict_test,
                "by_trial": torch.tensor(self.test_responses_by_trial, dtype=torch.float),
            },
        }

    @property
    def response_dict_test(self) -> dict[str, torch.Tensor]:
        return {name: torch.tensor(responses.T) for name, responses in self.neural_responses.test_dict.items()}

response_dict property

Create and return a dictionary of neural responses for train, validation, and test datasets.
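
For orientation, the layout of the returned dictionary, sketched as comments (the keys under "avg" depend on the names in the session's test_dict):

# {
#     "train":      Tensor of shape (n_train_timepoints, n_neurons),
#     "validation": Tensor of shape (n_val_timepoints, n_neurons),
#     "test": {
#         "avg":      {test_name: Tensor},  # per-stimulus averages from response_dict_test
#         "by_trial": Tensor,               # empty placeholder here, kept for compatibility
#     },
# }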

__init__(responses, val_clip_idx, num_clips, clip_length, key=None, **kwargs)

Initialize the NeuronDataSplit object: a boilerplate class that computes and stores train/validation/test splits of neuron data before feeding them into a dataloader.

Parameters:

    responses (ResponsesTrainTestSplit, required):
        The train and test responses of neurons.
    val_clip_idx (List[int], required):
        The indices of validation clips.
    num_clips (int, required):
        The number of clips.
    clip_length (int, required):
        The length of each clip.
    key (dict, default None):
        Optional key information for the neuron data; may include date, exp_num,
        experimenter, field_id, and stim_id.
Source code in openretina/data_io/base_dataloader.py
def __init__(
    self,
    responses: ResponsesTrainTestSplit,
    val_clip_idx: List[int],
    num_clips: int,
    clip_length: int,
    key: Optional[dict] = None,
    **kwargs,
):
    """
    Initialize the NeuronData object.
    Boilerplate class to compute and store neuron data train/test/validation splits before feeding into a dataloader

    Args:
        key (dict): The key information for the neuron data,
                    includes date, exp_num, experimenter, field_id, stim_id.
        responses (ResponsesTrainTestSplit): The train and test responses of neurons.
        val_clip_idx (List[int]): The indices of validation clips.
        num_clips (int): The number of clips.
        clip_length (int): The length of each clip.
        key (dict, optional): Additional key information.
    """
    self.neural_responses = responses
    self.num_neurons = self.neural_responses.n_neurons
    self.key = key
    self.roi_coords = ()
    self.clip_length = clip_length
    self.num_clips = num_clips
    self.val_clip_idx = val_clip_idx

    # Transpose the responses to have the shape (n_timepoints, n_neurons)
    self.responses_train_and_val = self.neural_responses.train.T

    self.responses_train, self.responses_val = self.split_data_train_val()
    self.test_responses_by_trial = np.array([])  # Added for compatibility with Hoefling et al., 2024

split_data_train_val()

Compute validation responses and updated train responses with the validation clips stripped out. Handles unsorted validation clip indices and parallels the way movie validation clips are handled.

Returns:

    tuple[np.ndarray, np.ndarray]:
        The updated train and validation responses.

Source code in openretina/data_io/base_dataloader.py
def split_data_train_val(self) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute validation responses and updated train responses with the validation clips stripped out.
    Handles unsorted validation clip indices and parallels the way movie validation clips are handled.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The updated train and validation responses.
    """
    # Initialise validation responses
    base_movie_sorting = np.arange(self.num_clips)

    validation_mask = np.ones_like(self.responses_train_and_val, dtype=bool)
    responses_val = np.zeros([len(self.val_clip_idx) * self.clip_length, self.num_neurons])

    # Compute validation responses and remove sections from training responses
    for i, ind1 in enumerate(self.val_clip_idx):
        grab_index = base_movie_sorting[ind1]
        responses_val[i * self.clip_length : (i + 1) * self.clip_length, :] = self.responses_train_and_val[
            grab_index * self.clip_length : (grab_index + 1) * self.clip_length,
            :,
        ]
        validation_mask[
            (grab_index * self.clip_length) : (grab_index + 1) * self.clip_length,
            :,
        ] = False

    responses_train = self.responses_train_and_val[validation_mask].reshape(-1, self.num_neurons)

    return responses_train, responses_val
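
To see what the validation mask does, here is a self-contained numpy rendition of the same logic on toy numbers (two clips of length 3, clip 1 held out for validation):

import numpy as np

clip_length, num_clips, n_neurons = 3, 2, 2
responses = np.arange(clip_length * num_clips * n_neurons).reshape(-1, n_neurons)
val_clip_idx = [1]

mask = np.ones_like(responses, dtype=bool)
val = np.zeros([len(val_clip_idx) * clip_length, n_neurons])
for i, clip in enumerate(val_clip_idx):
    rows = slice(clip * clip_length, (clip + 1) * clip_length)
    val[i * clip_length : (i + 1) * clip_length] = responses[rows]
    mask[rows] = False

train = responses[mask].reshape(-1, n_neurons)
# train now holds the frames of clip 0, val the frames of clip 1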

gen_shifts_with_boundaries(clip_bounds, start_indices, clip_chunk_size=50)

Generate shifted indices based on clip bounds and start indices. Assumes that the original start indices already lie within the clip bounds; if they do not, the overflowing indices are moved to respect the closest bound.

Parameters:

    clip_bounds (list, required):
        A list of clip bounds.
    start_indices (list, required):
        A list of start indices.
    clip_chunk_size (int, default 50):
        The size of each clip chunk.

Returns:

    list[int]:
        A list of shifted indices.

Source code in openretina/data_io/base_dataloader.py
def gen_shifts_with_boundaries(
    clip_bounds: list[int] | np.ndarray, start_indices: list[int] | np.ndarray, clip_chunk_size: int = 50
) -> list[int]:
    """
    Generate shifted indices based on clip bounds and start indices.
    Assumes that the original start indices are already within the clip bounds.
    If they are not, the overflowing indices are moved to respect the closest bound.

    Args:
        clip_bounds (list): A list of clip bounds.
        start_indices (list): A list of start indices.
        clip_chunk_size (int, optional): The size of each clip chunk. Defaults to 50.

    Returns:
        list: A list of shifted indices.

    """

    def get_next_bound(value, bounds):
        insertion_index = bisect.bisect_right(bounds, value)
        return bounds[min(insertion_index, len(bounds) - 1)]

    shifted_indices = []
    shifts = np.random.randint(0, clip_chunk_size // 2, len(start_indices))

    for i, start_idx in enumerate(start_indices):
        next_bound = get_next_bound(start_idx, clip_bounds)
        if start_idx + shifts[i] + clip_chunk_size < next_bound:
            shifted_indices.append(start_idx + shifts[i])
        elif start_idx + clip_chunk_size > next_bound:
            shifted_indices.append(next_bound - clip_chunk_size)
        else:
            shifted_indices.append(start_idx)

    # Ensure we do not exceed the movie length when allowing over boundaries
    if shifted_indices[-1] + clip_chunk_size > clip_bounds[-1]:
        shifted_indices[-1] = clip_bounds[-1] - clip_chunk_size
    return shifted_indices
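
For example, with 100-frame scenes and 50-frame chunks (the import path follows the source location shown above; the exact output varies with the random shifts):

import numpy as np
from openretina.data_io.base_dataloader import gen_shifts_with_boundaries

clip_bounds = np.arange(0, 301, 100)  # scenes [0, 100), [100, 200), [200, 300)
start_indices = [0, 100, 200]
shifted = gen_shifts_with_boundaries(clip_bounds, start_indices, clip_chunk_size=50)
# Each start is shifted forward by up to 24 frames, but every 50-frame chunk
# still ends before its scene's upper bound, e.g. [13, 104, 211].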

get_movie_dataloader(movie, responses, *, split, scene_length, chunk_size, batch_size, start_indices=None, roi_ids=None, roi_coords=None, group_assignment=None, drop_last=True, allow_over_boundaries=True, **kwargs)

Create a DataLoader for processing movie data and associated responses. This function prepares the dataset and sampler for training or evaluation based on the specified parameters.

Parameters:

    movie (Float[ndarray | Tensor, "n_channels n_frames h w"], required):
        The movie data represented as a multi-dimensional array or tensor.
    responses (Float[ndarray, "n_frames n_neurons"], required):
        The responses corresponding to the frames of the movie.
    split (str | Literal["train", "validation", "val", "test"], required):
        The dataset split to use (train, validation, or test).
    scene_length (int, required):
        The length of the scene to be processed.
    chunk_size (int, required):
        The size of each chunk to be extracted from the movie.
    batch_size (int, required):
        The number of samples per batch.
    start_indices (list[int] | None, default None):
        The starting indices for each chunk. If None, will be computed.
    roi_ids (Float[ndarray, " n_neurons"] | None, default None):
        The region of interest IDs. If None, will not be used.
    roi_coords (Float[ndarray, "n_neurons 2"] | None, default None):
        The coordinates of the regions of interest. If None, will not be used.
    group_assignment (Float[ndarray, " n_neurons"] | None, default None):
        The group assignments (cell types) for the neurons. If None, will not be used.
    drop_last (bool, default True):
        Whether to drop the last incomplete batch.
    allow_over_boundaries (bool, default True):
        Whether to allow chunks that exceed the scene boundaries.
    **kwargs:
        Additional keyword arguments for the DataLoader.

Returns:

    DataLoader:
        A DataLoader instance configured with the specified dataset and sampler.

Raises:

    ValueError:
        If allow_over_boundaries is False and chunk_size exceeds scene_length during training.

Source code in openretina/data_io/base_dataloader.py
def get_movie_dataloader(
    movie: Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"],
    responses: Float[np.ndarray | torch.Tensor, "n_frames n_neurons"],
    *,
    split: str | Literal["train", "validation", "val", "test"],
    scene_length: int,
    chunk_size: int,
    batch_size: int,
    start_indices: list[int] | None = None,
    roi_ids: Float[np.ndarray, " n_neurons"] | None = None,
    roi_coords: Float[np.ndarray, "n_neurons 2"] | None = None,
    group_assignment: Float[np.ndarray, " n_neurons"] | None = None,
    drop_last: bool = True,
    allow_over_boundaries: bool = True,
    **kwargs,
) -> DataLoader:
    """
    Create a DataLoader for processing movie data and associated responses.
    This function prepares the dataset and sampler for training or evaluation based on the specified parameters.

    Args:
        movie (Float[np.ndarray | torch.Tensor, "n_channels n_frames h w"]):
            The movie data represented as a multi-dimensional array or tensor.
        responses (Float[np.ndarray, "n_frames n_neurons"]):
            The responses corresponding to the frames of the movie.
        split (str | Literal["train", "validation", "val", "test"]):
            The dataset split to use (train, validation, or test).
        scene_length (int):
            The length of the scene to be processed.
        chunk_size (int):
            The size of each chunk to be extracted from the movie.
        batch_size (int):
            The number of samples per batch.
        start_indices (list[int] | None, optional):
            The starting indices for each chunk. If None, will be computed.
        roi_ids (Float[np.ndarray, " n_neurons"] | None, optional):
            The region of interest IDs. If None, will not be used.
        roi_coords (Float[np.ndarray, "n_neurons 2"] | None, optional):
            The coordinates of the regions of interest. If None, will not be used.
        group_assignment (Float[np.ndarray, " n_neurons"] | None, optional):
            The group assignments (cell types) for the neurons. If None, will not be used.
        drop_last (bool, optional):
            Whether to drop the last incomplete batch. Defaults to True.
        allow_over_boundaries (bool, optional):
            Whether to allow chunks that exceed the scene boundaries. Defaults to True.
        **kwargs:
            Additional keyword arguments for the DataLoader.

    Returns:
        DataLoader:
            A DataLoader instance configured with the specified dataset and sampler.

    Raises:
        ValueError:
            If `allow_over_boundaries` is False and `chunk_size` exceeds `scene_length` during training.
    """
    if isinstance(responses, torch.Tensor) and bool(torch.isnan(responses).any()):
        print("Nans in responses, skipping this dataloader")
        return  # type: ignore

    if not allow_over_boundaries and split == "train" and chunk_size > scene_length:
        raise ValueError("Clip chunk size must be smaller than scene length to not exceed clip bounds during training.")

    if start_indices is None:
        start_indices = handle_missing_start_indices(movie.shape[1], chunk_size, scene_length, split)
    dataset = MovieDataSet(movie, responses, roi_ids, roi_coords, group_assignment, split, chunk_size)
    sampler = MovieSampler(
        start_indices,
        split,
        chunk_size,
        movie_length=movie.shape[1],
        scene_length=scene_length,
        allow_over_boundaries=allow_over_boundaries,
    )

    return DataLoader(
        dataset, sampler=sampler, batch_size=batch_size, drop_last=split == "train" and drop_last, **kwargs
    )
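
A minimal call with synthetic data (a sketch; the import path follows the source location shown above):

import numpy as np
from openretina.data_io.base_dataloader import get_movie_dataloader

movie = np.random.rand(1, 1000, 18, 16).astype(np.float32)
responses = np.random.rand(1000, 42).astype(np.float32)

train_loader = get_movie_dataloader(
    movie=movie, responses=responses,
    split="train", scene_length=100, chunk_size=50, batch_size=8,
)
for batch in train_loader:
    # each batch stacks 50-frame clips with their matching responses
    break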

handle_missing_start_indices(movie_length, chunk_size, scene_length, split)

Handle missing start indices for different splits of the dataset.

Parameters:

    movie_length (int):
        The total length of the movie, in frames.
    chunk_size (int or None):
        The size of each chunk for the training split. Required if split is "train".
    scene_length (int or None):
        The length of each scene. Required if split is "validation" or "val".
    split (str):
        The type of split, one of "train", "validation", "val", or "test".

Returns:

    list[int]:
        The generated start indices.

Raises:

    AssertionError: If chunk_size is not provided for the train split.
    AssertionError: If scene_length is not provided for the validation split.
    NotImplementedError: If split is not one of "train", "validation", "val", or "test".

Source code in openretina/data_io/base_dataloader.py
def handle_missing_start_indices(
    movie_length: int, chunk_size: int | None, scene_length: int | None, split: str
) -> list[int]:
    """
    Handle missing start indices for different splits of the dataset.

    Parameters:
    movies (np.ndarray or torch.Tensor): The movies data, as an array.
    chunk_size (int or None): The size of each chunk for training split. Required if split is "train".
    scene_length (int or None): The length of each scene. Required if split is "validation" or "val".
    split (str): The type of split, one of "train", "validation", "val", or "test".

    Returns:
    dict or list: The generated or provided start indices for each movie.

    Raises:
    AssertionError: If chunk_size is not provided for training split when start_indices is None.
    AssertionError: If scene_length is not provided for validation split when start_indices is None.
    NotImplementedError: If start_indices is None and split is not one of "train", "validation", "val", or "test".
    """

    if split == "train":
        assert chunk_size is not None, "Chunk size or start indices must be provided for training."
        interval = chunk_size
    elif split in {"validation", "val"}:
        assert scene_length is not None, "Scene length or start indices must be provided for validation."
        interval = scene_length
    elif split == "test":
        interval = movie_length
    else:
        raise NotImplementedError("Start indices could not be recovered.")

    return np.arange(0, movie_length, interval).tolist()  # type: ignore
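
Concretely, for a 1000-frame movie (the import path follows the source location shown above):

from openretina.data_io.base_dataloader import handle_missing_start_indices

handle_missing_start_indices(1000, chunk_size=50, scene_length=100, split="train")
# -> [0, 50, 100, ..., 950]   one start per training chunk
handle_missing_start_indices(1000, chunk_size=None, scene_length=100, split="val")
# -> [0, 100, ..., 900]       one start per scene
handle_missing_start_indices(1000, chunk_size=None, scene_length=None, split="test")
# -> [0]                      the whole movie as a single clip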

multiple_movies_dataloaders(neuron_data_dictionary, movies_dictionary, train_chunk_size=50, batch_size=32, seed=42, clip_length=100, num_val_clips=10, val_clip_indices=None, allow_over_boundaries=True)

Create multiple dataloaders for training, validation, and testing from given neuron and movie data. This function ensures that the neuron data and movie data are aligned and generates dataloaders for each session. It does not assume that the movies in different sessions are the same, have the same length, or are composed of the same clips in the same order.

Parameters:

    neuron_data_dictionary (dict[str, ResponsesTrainTestSplit], required):
        A dictionary containing neuron response data split for training and testing.
    movies_dictionary (dict[str, MoviesTrainTestSplit], required):
        A dictionary containing movie data split for training and testing.
    train_chunk_size (int, default 50):
        The size of the chunks for training data.
    batch_size (int, default 32):
        The number of samples per batch.
    seed (int, default 42):
        The random seed for reproducibility.
    clip_length (int, default 100):
        The length of each clip.
    num_val_clips (int, default 10):
        The number of validation clips to draw.
    val_clip_indices (list[int], default None):
        The indices of validation clips to use. If provided, num_val_clips is ignored.
    allow_over_boundaries (bool, default True):
        Whether to allow selected chunks to go over scene boundaries.

Returns:

    dict[str, dict[str, DataLoader]]:
        A dictionary containing dataloaders for training, validation, and testing for each session.

Raises:

    AssertionError:
        If the keys of neuron_data_dictionary and movies_dictionary do not match exactly.

Source code in openretina/data_io/base_dataloader.py
def multiple_movies_dataloaders(
    neuron_data_dictionary: dict[str, ResponsesTrainTestSplit],
    movies_dictionary: dict[str, MoviesTrainTestSplit],
    train_chunk_size: int = 50,
    batch_size: int = 32,
    seed: int = 42,
    clip_length: int = 100,
    num_val_clips: int = 10,
    val_clip_indices: list[int] | None = None,
    allow_over_boundaries: bool = True,
) -> dict[str, dict[str, DataLoader]]:
    """
    Create multiple dataloaders for training, validation, and testing from given neuron and movie data.
    This function ensures that the neuron data and movie data are aligned and generates dataloaders for each session.
    It does not assume that the movies in different sessions are the same, have the same length, or are composed
    of the same clips in the same order.

    Args:
        neuron_data_dictionary (dict[str, ResponsesTrainTestSplit]):
            A dictionary containing neuron response data split for training and testing.
        movies_dictionary (dict[str, MoviesTrainTestSplit]):
            A dictionary containing movie data split for training and testing.
        train_chunk_size (int, optional):
            The size of the chunks for training data. Defaults to 50.
        batch_size (int, optional):
            The number of samples per batch. Defaults to 32.
        seed (int, optional):
            The random seed for reproducibility. Defaults to 42.
        clip_length (int, optional):
            The length of each clip. Defaults to 100.
        num_val_clips (int, optional):
            The number of validation clips to draw. Defaults to 10.
        val_clip_indices (list[int], optional): The indices of validation clips to use. If provided, num_val_clips is
                                                ignored. Defaults to None.
        allow_over_boundaries (bool, optional): Whether to allow selected chunks to go over scene boundaries.
                                                Defaults to True.

    Returns:
        dict:
            A dictionary containing dataloaders for training, validation, and testing for each session.

    Raises:
        AssertionError:
            If the keys of neuron_data_dictionary and movies_dictionary do not match exactly.
    """
    assert set(neuron_data_dictionary.keys()) == set(movies_dictionary.keys()), (
        "The keys of neuron_data_dictionary and movies_dictionary should match exactly."
    )

    # Initialise dataloaders
    dataloaders: dict[str, Any] = collections.defaultdict(dict)

    for session_key, session_data in tqdm(neuron_data_dictionary.items(), desc="Creating movie dataloaders"):
        # Extract all data related to the movies first
        num_clips = movies_dictionary[session_key].train.shape[1] // clip_length

        if val_clip_indices is not None:
            val_clip_idx = val_clip_indices
        else:
            # Draw validation clips based on the random seed
            rnd = np.random.RandomState(seed)
            val_clip_idx = list(rnd.choice(num_clips, num_val_clips, replace=False))

        movie_train_subset, movie_val, movie_test_dict = generate_movie_splits(
            movies_dictionary[session_key].train,
            movies_dictionary[session_key].test_dict,
            val_clip_idc=val_clip_idx,
            num_clips=num_clips,
            clip_length=clip_length,
        )

        # Extract all splits from neural data
        neuron_data = NeuronDataSplit(
            responses=session_data,
            val_clip_idx=val_clip_idx,
            num_clips=num_clips,
            clip_length=clip_length,
        )

        # Create dataloaders for each fold
        for fold, movie, chunk_size in [
            ("train", movie_train_subset, train_chunk_size),
            ("validation", movie_val, clip_length),
        ]:
            dataloaders[fold][session_key] = get_movie_dataloader(
                movie=movie,
                responses=neuron_data.response_dict[fold],
                split=fold,
                chunk_size=chunk_size,
                batch_size=batch_size,
                scene_length=clip_length,
                allow_over_boundaries=allow_over_boundaries,
            )
        # test movies
        for name, movie in movie_test_dict.items():
            dataloaders[name][session_key] = get_movie_dataloader(
                movie=movie,
                responses=neuron_data.response_dict_test[name],
                split="test",
                chunk_size=movie.shape[1],
                batch_size=batch_size,
                scene_length=clip_length,
                allow_over_boundaries=allow_over_boundaries,
            )

    return dataloaders
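
An end-to-end sketch. The constructors and field names of ResponsesTrainTestSplit and MoviesTrainTestSplit below are assumptions inferred from how the function uses them (.train, .test_dict, .n_neurons); check the actual dataclasses in openretina before relying on this:

import numpy as np
from openretina.data_io.base_dataloader import multiple_movies_dataloaders
# Hypothetical import path for the split containers; adjust to your install:
# from openretina.data_io import MoviesTrainTestSplit, ResponsesTrainTestSplit

movies = {
    "session_1": MoviesTrainTestSplit(
        train=np.random.rand(1, 1000, 18, 16),                  # (channels, frames, h, w)
        test_dict={"natural": np.random.rand(1, 100, 18, 16)},
    )
}
responses = {
    "session_1": ResponsesTrainTestSplit(
        train=np.random.rand(42, 1000),                         # (n_neurons, n_timepoints)
        test_dict={"natural": np.random.rand(42, 100)},
    )
}

dataloaders = multiple_movies_dataloaders(
    neuron_data_dictionary=responses,
    movies_dictionary=movies,
    train_chunk_size=50, clip_length=100, num_val_clips=2,
)
train_loader = dataloaders["train"]["session_1"]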