
Sridhar et al. 2025 Dataset

Marmoset retinal ganglion cell responses to naturalistic movies and spatiotemporal white noise, originally published in Sridhar et al. (2025): Modeling spatial contrast sensitivity in responses of primate retinal ganglion cells to natural movies, bioRxiv.

Dataset: doi.gin.g-node.org/10.12751/g-node.3dfiti

Models trained on this dataset were developed as part of A systematic comparison of predictive models on the retina.

Dataloaders

dataloaders

MarmosetMovieDataset

MarmosetMovieDataset(
    responses: dict,
    dir,
    *data_keys: str,
    indices: list,
    frames: ndarray,
    fixations,
    temporal_dilation: int | tuple[int, ...] = 1,
    hidden_temporal_dilation: int | tuple[int, ...] = 1,
    test_frames: int = TEST_DATA_FRAMES,
    img_dir_name: str = "stimuli",
    frame_file: str = "_img_",
    test: bool = False,
    crop: int | tuple = 0,
    subsample: int = 1,
    num_of_frames: int = 15,
    num_of_hidden_frames: int | tuple = 15,
    num_of_layers: int = 1,
    device: str = "cpu",
    time_chunk_size: Optional[int] = None,
    full_img_w: int = 1000,
    full_img_h: int = 800,
    img_w: int = 800,
    img_h: int = 600,
    padding: int = 200,
    excluded_cells=None,
    locations=None,
)

Bases: Dataset

A PyTorch‐style dataset that delivers time chunks of marmoset movie stimuli together with the corresponding neuronal responses recorded from the retina.

Specifically designed for the Sridhar et al. 2025 dataset. The original dataset is available at https://doi.gin.g-node.org/10.12751/g-node.3dfiti/ and the HuggingFace version at https://huggingface.co/datasets/open-retina/nm_marmoset_data

The dataset is designed for temporal models that predict neural activity from short clips of video input. Compared with a vanilla torchvision.datasets.VisionDataset, it additionally performs:

  • Stimulus‐aligned eye-fixation cropping of movie frames
Parameters

responses : dict
    Dictionary with at least two keys:
    * "train_responses" — (n_neurons, T_train, n_trials) array.
    * "test_responses" — (n_neurons, T_test) array.
dir : str or pathlib.Path
    Root folder that contains the raw movie frames on disk.
data_keys : list[str]
    Order of modalities returned by __getitem__. Typical options are "inputs" and
    "targets". The tuple is forwarded to self.data_point to build the sample object.
indices : list[int]
    Trial indices to draw from train_responses. Ignored for the test set, which is
    averaged over trials.
frames : np.ndarray
    All stimulus frames loaded in memory, shape == (N_frames, full_img_h, full_img_w).
fixations : Sequence[Mapping[str, Any]]
    Per-frame metadata with gaze center and flip flag (expects keys "img_index",
    "center_x", "center_y", "flip").
temporal_dilation : int, default 1
    Step (in frames) between successive visible inputs fed to the network.
hidden_temporal_dilation : int | tuple[int], default 1
    Temporal dilation(s) applied between hidden layers. If an int is provided, it is
    broadcast to num_of_layers - 1 layers.
test_frames : int, optional
    Offset that separates the first test_frames frames from the rest.
img_dir_name : str, default "stimuli"
    Subfolder inside dir that contains individual frame files.
frame_file : str, default "_img_"
    Substring common to every frame filename (used by :pyfunc:glob).
test : bool, default False
    If True, the dataset serves the held-out test set and skips the indices /
    shuffling logic.
crop : int | tuple[int, int, int, int], default 0
    Pixels cropped from (top, bottom, left, right) before subsampling. Passing an
    int applies the same crop on every side.
subsample : int, default 1
    Spatial down-sampling factor. 1 keeps full resolution.
num_of_frames : int, default 15
    Number of visible frames given to the first network layer.
num_of_hidden_frames : int | tuple[int], default 15
    Number of frames seen by each hidden layer. If an int is provided, it is
    broadcast to num_of_layers - 1 layers.
num_of_layers : int, default 1
    Total depth of the model (visible + hidden). Required if hidden_temporal_dilation
    or num_of_hidden_frames are tuples.
device : str, default "cpu"
    Torch device where tensors are materialised.
time_chunk_size : int, optional
    Total length (in frames) of each sample returned by __getitem__. If provided,
    the iterator chunks the movie into non-overlapping segments of
    time_chunk_size − frame_overhead frames.
full_img_w, full_img_h : int, default 1000, 800
    Dimensions of the uncropped video frames.
img_w, img_h : int, default 800, 600
    Target spatial resolution after cropping/subsampling.
padding : int, default 200
    Extra pixels gathered around each frame.

Source code in openretina/data_io/sridhar_2025/dataloaders.py
def __init__(
    self,
    responses: dict,
    dir,
    *data_keys: str,
    indices: list,
    frames: np.ndarray,
    fixations,
    temporal_dilation: int | tuple[int, ...] = 1,
    hidden_temporal_dilation: int | tuple[int, ...] = 1,
    test_frames: int = TEST_DATA_FRAMES,
    img_dir_name: str = "stimuli",
    frame_file: str = "_img_",
    test: bool = False,
    crop: int | tuple = 0,
    subsample: int = 1,
    num_of_frames: int = 15,
    num_of_hidden_frames: int | tuple = 15,
    num_of_layers: int = 1,
    device: str = "cpu",
    time_chunk_size: Optional[int] = None,
    full_img_w: int = 1000,
    full_img_h: int = 800,
    img_w: int = 800,
    img_h: int = 600,
    padding: int = 200,
    excluded_cells=None,
    locations=None,
):
    self.data_keys: tuple[str, ...] = data_keys
    if set(data_keys) == {"inputs", "targets"}:
        self.data_point = default_image_datapoint

    if isinstance(crop, int):
        crop = (crop, crop, crop, crop)
    self.crop = crop
    self.temporal_dilation = temporal_dilation
    self.num_of_layers = num_of_layers

    self.img_h = img_h
    self.img_w = img_w
    self.full_img_h = full_img_h
    self.full_img_w = full_img_w
    self.padding = padding

    self.test_frames = test_frames
    self.excluded_cells = excluded_cells
    self.locations = locations

    self.num_of_frames = num_of_frames

    if isinstance(hidden_temporal_dilation, str):
        hidden_temporal_dilation = int(hidden_temporal_dilation)

    if isinstance(hidden_temporal_dilation, int):
        hidden_temporal_dilation = (hidden_temporal_dilation,) * (self.num_of_layers - 1)
    if isinstance(num_of_hidden_frames, int):
        num_of_hidden_frames = (num_of_hidden_frames,) * (self.num_of_layers - 1)

    if self.num_of_layers > 1:
        hidden_reach = sum((f - 1) * d for f, d in zip(num_of_hidden_frames, hidden_temporal_dilation, strict=True))
    else:
        hidden_reach = 0

    if num_of_hidden_frames is None:
        self.num_of_hidden_frames: tuple = (self.num_of_frames,)
    else:
        self.num_of_hidden_frames = num_of_hidden_frames

    self.frame_overhead = (num_of_frames - 1) * self.temporal_dilation + hidden_reach

    if time_chunk_size is not None:
        self.time_chunk_size = time_chunk_size + self.frame_overhead

    self.subsample = subsample
    self.device = device
    self.basepath = Path(dir).absolute()
    self.img_dir_name = img_dir_name
    self.frame_file = frame_file
    self.response_dict = responses
    if indices is not None:
        self.train_responses = torch.from_numpy(responses["train_responses"]).float()
    self.test_responses = torch.from_numpy(responses["test_responses"]).float()
    raw_test_by_trial = responses.get("test_responses_by_trial")
    self._test_responses_by_trial: torch.Tensor | None = (
        torch.from_numpy(np.transpose(raw_test_by_trial, (2, 1, 0))).float()
        if raw_test_by_trial is not None
        else None
    )

    self.fixations = fixations
    self.indices = indices
    self.frames = frames

    self.random_indices = np.random.permutation(indices)
    self.n_neurons = self.train_responses.shape[0]
    self.num_of_trials = self.train_responses.shape[2]
    self.num_of_imgs = int(self.train_responses.shape[1])

    if test:
        self.num_of_imgs = self.test_responses.shape[1]

    self.cache: list[Any] = []
    self.last_start_index = -1
    self.last_end_index = -1

    self._test = test
    if self._test:
        self._len = int(
            np.floor((self.num_of_imgs - self.frame_overhead) / (self.time_chunk_size - self.frame_overhead))
        )
    else:
        self._len = int(
            len(self.indices)
            * np.floor((self.num_of_imgs - self.frame_overhead) / (self.time_chunk_size - self.frame_overhead))
        )
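The temporal bookkeeping in __init__ can be checked with a standalone sketch: the visible and hidden receptive fields contribute a fixed frame_overhead, each stored chunk spans time_chunk_size + frame_overhead frames, and consecutive chunks overlap by frame_overhead frames, so the number of chunks per trial reduces to floor((T − frame_overhead) / time_chunk_size). This is a re-derivation for illustration, not the library code:

```python
import numpy as np

def chunks_per_trial(num_imgs, time_chunk_size, num_of_frames,
                     temporal_dilation, num_of_hidden_frames, hidden_temporal_dilation):
    # Temporal reach of the hidden layers, as in __init__ above.
    hidden_reach = sum((f - 1) * d for f, d in zip(num_of_hidden_frames, hidden_temporal_dilation))
    frame_overhead = (num_of_frames - 1) * temporal_dilation + hidden_reach
    stored_chunk = time_chunk_size + frame_overhead  # length of each stored sample
    return int(np.floor((num_imgs - frame_overhead) / (stored_chunk - frame_overhead)))

# 15 visible frames plus one hidden layer of 15 frames, no dilation:
# overhead = 14 + 14 = 28, so a 1028-frame trial with chunk size 100 yields 10 chunks.
print(chunks_per_trial(1028, 100, 15, 1, (15,), (1,)))  # 10
```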

responses property

responses: Tensor

Return test responses time-major for evaluation.

movies property

movies: Tensor

Build the full test movie once and cache it. Used for evaluation only.

Shape: (channels=1, time, height, width)

get_all_locations

get_all_locations()

Returns the locations of all cells in the dataset. :return: list of locations for each cell in the dataset

Source code in openretina/data_io/sridhar_2025/dataloaders.py
def get_all_locations(self):
    """
    Returns the locations of all cells in the dataset.
    :return: list of locations for each cell in the dataset
    """
    if self.locations is not None:
        return self.locations
    else:
        raise ValueError("Locations are not available in this dataset.")

transform

transform(images)

Applies transformations to the images: cropping, spatial downsampling, and dimension reordering.

Source code in openretina/data_io/sridhar_2025/dataloaders.py
def transform(self, images):
    """
    Applies transformations to the images: cropping, spatial downsampling, and dimension reordering.
    """
    if len(images.shape) == 3:
        h, w, num_of_imgs = images.shape
        images = images[
            self.crop[0] : h - self.crop[1] : self.subsample,
            self.crop[2] : w - self.crop[3] : self.subsample,
            :,
        ]
        return images

    elif len(images.shape) == 4:
        h, w = images.shape[:2]
        images = images[
            self.crop[0] : h - self.crop[1] : self.subsample,
            self.crop[2] : w - self.crop[3] : self.subsample,
            :,
        ]
        images = images.permute(0, 3, 1, 2)
    else:
        raise ValueError(
            f"Image shape has to be three dimensional (time as channels) or four dimensional "
            f"(time, with w x h x c). got image shape {images.shape}"
        )
    return images
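The 3-D branch of transform is plain strided slicing; a standalone NumPy illustration with hypothetical sizes:

```python
import numpy as np

# Hypothetical height x width x time stimulus block, as in the 3-D branch above.
images = np.arange(100 * 80 * 4).reshape(100, 80, 4)
crop = (10, 10, 5, 5)  # pixels removed from (top, bottom, left, right)
subsample = 2          # spatial down-sampling factor

h, w, _ = images.shape
out = images[crop[0] : h - crop[1] : subsample,
             crop[2] : w - crop[3] : subsample, :]
print(out.shape)  # (40, 35, 4): (100 - 20) / 2 rows, (80 - 10) / 2 columns
```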

get_frames

get_frames(
    trial_index: int,
    starting_img_index: int,
    ending_img_index: int,
)

Returns a tensor of frames for the given trial based on the starting and ending image indices. The starting and ending indices index into the list of fixations, whose entries correspond to lines in the fixation file. Each fixation entry defines the index of the frame to use, the center around which the crop is made, and whether the image should be flipped.

Source code in openretina/data_io/sridhar_2025/dataloaders.py
def get_frames(self, trial_index: int, starting_img_index: int, ending_img_index: int):
    """
    Returns a tensor of frames for the given trial based on the starting and ending image indices.
    The starting and ending indices index into the list of fixations which correspond to the lines in the
    fixation file.
    Each fixations element defines the index of the used frame, the center around which the crop is made,
    and whether the image should be flipped.
    """
    cache = []
    if self._test:
        starting_line = starting_img_index
        ending_line = ending_img_index
    else:
        starting_line = (starting_img_index + self.test_frames) + self.num_of_imgs * trial_index
        ending_line = starting_line + self.time_chunk_size

    fixations = self.fixations[int(starting_line) : int(ending_line)]
    frames = torch.zeros((self.time_chunk_size, self.img_h, self.img_w))
    for i, (fixation, index) in enumerate(zip(fixations, range(starting_img_index, ending_img_index))):
        if (index >= self.last_start_index) and (index < self.last_end_index):
            img = self.cache[index - self.last_start_index]
        else:
            img = torch.from_numpy(self.frames[fixation["img_index"]].astype(np.float32))
            img = crop_based_on_fixation(
                img=img,
                x_center=fixation["center_x"],
                y_center=fixation["center_y"],
                flip=fixation["flip"] == 0,
                img_h=self.img_h,
                img_w=self.img_w,
                padding=self.padding,
            )
            img = torch.movedim(img, 0, 1)
        frames[i] = img
        cache.append(img)
    frames = torch.movedim(frames, 0, 2)
    frames = self.transform(frames)
    frames = torch.movedim(frames, 2, 0)
    self.last_start_index = starting_img_index
    self.last_end_index = ending_img_index
    self.cache = cache
    return frames
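get_frames delegates the spatial work to crop_based_on_fixation. A minimal sketch of a fixation-centred crop (the border, padding, and flip conventions of the real helper are assumptions here, for illustration only):

```python
import numpy as np

def crop_around_fixation(img, x_center, y_center, img_h, img_w, flip=False):
    """Cut an img_h x img_w window centred on the gaze position (sketch only;
    the library's crop_based_on_fixation may treat borders and padding differently)."""
    top = int(y_center) - img_h // 2
    left = int(x_center) - img_w // 2
    out = img[top : top + img_h, left : left + img_w]
    if flip:
        out = out[:, ::-1]  # mirror horizontally
    return out

frame = np.arange(800 * 1000).reshape(800, 1000)  # full_img_h x full_img_w
patch = crop_around_fixation(frame, x_center=500, y_center=400, img_h=600, img_w=800)
print(patch.shape)  # (600, 800)
```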

NoiseDataset

NoiseDataset(
    responses: dict,
    dir,
    *data_keys: list,
    indices: list,
    use_cache: bool = True,
    trial_prefix: str = "trial",
    test: bool = False,
    cache_maxsize: int = 5,
    crop: int | tuple = 20,
    subsample: int = 1,
    num_of_frames: int = 15,
    num_of_layers: int = 1,
    device: str = "cpu",
    time_chunk_size: Optional[int] = None,
    temporal_dilation: int = 1,
    hidden_temporal_dilation: int | str | tuple = 1,
    num_of_hidden_frames: int | tuple | None = 15,
    extra_frame: int = 0,
    locations: Optional[list] = None,
    excluded_cells: Optional[list] = None,
)

Bases: Dataset

Dataset for the following (example) file structure:

    ├── data
    │   ├── non-repeating stimuli   [directory with as many files as trials]
    │   │   ├── trial_000
    │   │   │   └── all_images.npy
    │   │   ├── trial_001
    │   │   │   └── all_images.npy
    │   │   ├── ...
    │   │   └── trial_247
    │   │       └── all_images.npy
    │   ├── repeating stimuli       [directory with one file for the test set]
    │   │   └── all_images.npy
    │   └── responses               [directory with as many files as retinas]

PARAMETER DESCRIPTION
responses

Dictionary containing train set responses under the key 'train_responses' and test responses under the key 'test_responses'. Expected train set response shape: cells x num_of_images x num_of_trials. Expected test set response shape: cells x num_of_images (i.e. averaged over test trials beforehand).

TYPE: dict

dir

Path to directory where images are stored. Expected format for image files is: f'{dir}/{trial_prefix}_{int representing trial number}.zfill(3)/all_images.npy'. Expected shape of numpy array in all_images.npy is: height x width x num_of_images in trial.

data_keys

List of keys to be used for the datapoints, expected ['inputs', 'targets'].

TYPE: list DEFAULT: ()

indices

Indices of the trials selected for the given dataset.

TYPE: list

transforms

List of transformations that are supposed to be performed on images.

use_cache

Whether to use caching when loading image data.

TYPE: bool DEFAULT: True

trial_prefix

Prefix of trial file, followed by '_{trial number}'.

TYPE: str DEFAULT: 'trial'

test

Whether the data we are loading is test data.

TYPE: bool DEFAULT: False

cache_maxsize

Maximum number of trials that can be in the cache at a given point. Cache is NOT implemented as LRU. The last cached item is always kicked out first. This is because the trials are always iterated through in the same order.

TYPE: int DEFAULT: 5

crop

How much to crop the images - top, bottom, left, right.

TYPE: int | tuple DEFAULT: 20

subsample

Whether/how much images should be subsampled.

TYPE: int DEFAULT: 1

num_of_frames

Indicates how many frames should be used to make one prediction.

TYPE: int DEFAULT: 15

num_of_layers

Number of expected convolutional layers, used to calculate the shrink in dimensions or padding.

TYPE: int DEFAULT: 1

device

Device to use.

TYPE: str DEFAULT: 'cpu'

time_chunk_size

Indicates how many predictions should be made at once by the model. The 'inputs' in datapoints are padded accordingly in the temporal dimension with respect to num_of_frames and num_of_layers. Only valid if single_prediction is false.

TYPE: Optional[int] DEFAULT: None
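The documented path pattern for trial files can be reproduced with a short helper (a sketch of the stated convention; trial_image_path is not part of the library):

```python
from pathlib import Path

def trial_image_path(base: str, trial_prefix: str, trial: int) -> Path:
    # Matches the documented pattern: <dir>/<trial_prefix>_<NNN>/all_images.npy
    return Path(base) / f"{trial_prefix}_{str(trial).zfill(3)}" / "all_images.npy"

print(trial_image_path("data/noise", "trial", 7).as_posix())
# data/noise/trial_007/all_images.npy
```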

Source code in openretina/data_io/sridhar_2025/dataloaders.py
def __init__(
    self,
    responses: dict,
    dir,
    *data_keys: list,
    indices: list,
    use_cache: bool = True,
    trial_prefix: str = "trial",
    test: bool = False,
    cache_maxsize: int = 5,
    crop: int | tuple = 20,
    subsample: int = 1,
    num_of_frames: int = 15,
    num_of_layers: int = 1,
    device: str = "cpu",
    time_chunk_size: Optional[int] = None,
    temporal_dilation: int = 1,
    hidden_temporal_dilation: int | str | tuple = 1,
    num_of_hidden_frames: int | tuple | None = 15,
    extra_frame: int = 0,
    locations: Optional[list] = None,
    excluded_cells: Optional[list] = None,
):
    """
    Dataset for the following (example) file structure:

        ├── data
        │   ├── non-repeating stimuli   [directory with as many files as trials]
        │   │   ├── trial_000
        │   │   │   └── all_images.npy
        │   │   ├── trial_001
        │   │   │   └── all_images.npy
        │   │   ├── ...
        │   │   └── trial_247
        │   │       └── all_images.npy
        │   ├── repeating stimuli       [directory with one file for the test set]
        │   │   └── all_images.npy
        │   └── responses               [directory with as many files as retinas]

    Args:
        responses: Dictionary containing train set responses under the key 'train_responses'
            and test responses under the key 'test_responses'.
            Expected train set response shape: cells x num_of_images x num_of_trials.
            Expected test set response shape: cells x num_of_images (i.e. test trials averaged before).
        dir: Path to directory where images are stored.
            Expected format for image files is:
            f'{dir}/{trial_prefix}_{int representing trial number}.zfill(3)/all_images.npy'.
            Expected shape of numpy array in all_images.npy is: height x width x num_of_images in trial.
        data_keys: List of keys to be used for the datapoints, expected ['inputs', 'targets'].
        indices: Indices of the trials selected for the given dataset.
        transforms: List of transformations that are supposed to be performed on images.
        use_cache: Whether to use caching when loading image data.
        trial_prefix: Prefix of trial file, followed by '_{trial number}'.
        test: Whether the data we are loading is test data.
        cache_maxsize: Maximum number of trials that can be in the cache at a given point.
            Cache is NOT implemented as LRU. The last cached item is always kicked out first.
            This is because the trials are always iterated through in the same order.
        crop: How much to crop the images - top, bottom, left, right.
        subsample: Whether/how much images should be subsampled.
        num_of_frames: Indicates how many frames should be used to make one prediction.
        num_of_layers: Number of expected convolutional layers,
            used to calculate the shrink in dimensions or padding.
        device: Device to use.
        time_chunk_size: Indicates how many predictions should be made at once by the model.
            The 'inputs' in datapoints are padded accordingly in the temporal dimension with respect
            to num_of_frames and num_of_layers. Only valid if single_prediction is false.
    """

    self.use_cache = use_cache
    self.data_keys = data_keys
    if set(data_keys) == {"inputs", "targets"}:
        self.data_point = default_image_datapoint
    if self.use_cache:
        self.cache_maxsize = cache_maxsize
    if isinstance(crop, int):
        crop = (crop, crop, crop, crop)
    self.crop = crop
    self.extra_frame = extra_frame
    self.num_of_layers = num_of_layers
    self.temporal_dilation = temporal_dilation

    self.num_of_frames = num_of_frames

    if isinstance(hidden_temporal_dilation, str):
        hidden_temporal_dilation = int(hidden_temporal_dilation)

    if isinstance(hidden_temporal_dilation, int):
        hidden_temporal_dilation = (hidden_temporal_dilation,) * (self.num_of_layers - 1)
    if isinstance(num_of_hidden_frames, int):
        num_of_hidden_frames = (num_of_hidden_frames,) * (self.num_of_layers - 1)

    if num_of_hidden_frames is None:
        self.num_of_hidden_frames: tuple = (num_of_frames,)
    else:
        self.num_of_hidden_frames = num_of_hidden_frames

    hidden_reach = sum((f - 1) * d for f, d in zip(self.num_of_hidden_frames, hidden_temporal_dilation))

    self.frame_overhead = (self.num_of_frames - 1) * self.temporal_dilation + hidden_reach

    if time_chunk_size is not None:
        self.time_chunk_size = time_chunk_size + self.frame_overhead
    self.subsample = subsample
    self.device = device
    self.trial_prefix = trial_prefix
    self.data_keys = data_keys
    self.basepath = dir
    if indices is not None:
        self.train_responses = torch.from_numpy(responses["train_responses"]).float()

    self.test_responses = torch.from_numpy(responses["test_responses"]).float()
    raw_test_by_trial = responses.get("test_responses_by_trial")
    self._test_responses_by_trial: torch.Tensor | None = (
        torch.from_numpy(np.transpose(raw_test_by_trial, (2, 0, 1))).float()
        if raw_test_by_trial is not None
        else None
    )
    self.indices = indices
    self.random_indices = np.random.permutation(indices)
    self.n_neurons = self.train_responses.shape[0]
    self.num_of_trials = self.train_responses.shape[2]
    self.num_of_imgs = int(self.train_responses.shape[1])
    self.locations = locations
    self.excluded_cells = excluded_cells

    # Checks if trials are saved in evenly sized files
    if test:
        self.num_of_imgs = self.test_responses.shape[1]
    self._test = test
    if self._test:
        self._len = (
            int(np.floor((self.num_of_imgs - self.frame_overhead) / (self.time_chunk_size - self.frame_overhead)))
            - self.extra_frame
        )

    else:
        self._len = (
            int(
                len(self.indices)
                * np.floor((self.num_of_imgs - self.frame_overhead) / (self.time_chunk_size - self.frame_overhead))
            )
            - self.extra_frame
        )

    self._cache: dict[Any, Any] = {data_key: {} for data_key in data_keys}
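The bounded, non-LRU cache described for cache_maxsize (the most recently added entry is evicted first, because trials are always iterated in the same order) can be sketched as a standalone toy (cache_put is hypothetical, not the library's cache):

```python
def cache_put(cache: dict, key, value, maxsize: int) -> None:
    """Insert into a bounded cache, evicting the most recently added entry
    first (deliberately NOT LRU), as documented for cache_maxsize above."""
    if len(cache) >= maxsize:
        newest = list(cache)[-1]  # dicts preserve insertion order (Python 3.7+)
        del cache[newest]
    cache[key] = value

cache: dict = {}
for trial in range(4):
    cache_put(cache, trial, f"trial_{trial}", maxsize=3)
print(sorted(cache))  # [0, 1, 3]: trial 2 was evicted to admit trial 3
```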

transform_image

transform_image(images)

applies transformations to the image: downsampling, cropping, rescaling, and dimension expansion.

Source code in openretina/data_io/sridhar_2025/dataloaders.py
def transform_image(self, images):
    """
    Applies transformations to the images: cropping, spatial downsampling, and dimension reordering.
    """

    if len(images.shape) == 3:
        h, w, num_of_imgs = images.shape
        images = images[
            self.crop[0] : h - self.crop[1] : self.subsample,
            self.crop[2] : w - self.crop[3] : self.subsample,
            :,
        ]
        return images

    elif len(images.shape) == 4:
        h, w = images.shape[:2]
        images = images[
            self.crop[0] : h - self.crop[1] : self.subsample,
            self.crop[2] : w - self.crop[3] : self.subsample,
            :,
        ]
        images = images.permute(0, 3, 1, 2)
    else:
        raise ValueError(
            f"Image shape has to be three dimensional (time as channels) or four dimensional "
            f"(time, with w x h x c). got image shape {images.shape}"
        )
    return images

frame_movie_loader

frame_movie_loader(
    files,
    fixation_files: dict[int, str],
    big_crops: dict[int, int | tuple[int, int, int, int]],
    basepath,
    batch_size: int = 16,
    seed=None,
    train_frac: float = 0.8,
    subsample: int = 1,
    crop: int | tuple[int, int, int, int] = 0,
    num_of_trials_to_use: int | None = None,
    start_using_trial: int = 0,
    num_of_frames=None,
    temporal_dilation: int | tuple[int, ...] = 1,
    hidden_temporal_dilation: int | tuple[int, ...] = 1,
    cell_index=None,
    retina_index=None,
    device: str = "cpu",
    time_chunk_size=None,
    num_of_layers=None,
    excluded_cells=None,
    frame_file="_img_",
    img_dir_name="stimuli",
    full_img_w: int = 1400,
    full_img_h: int = 1200,
    img_h=None,
    img_w=None,
    final_training: bool = False,
    padding: int = 200,
    retina_specific_crops: bool = True,
    stimulus_seed: int = 0,
    hard_coded=None,
    flip_imgs: bool = False,
    shuffle=None,
    sta_dir="stas",
    get_locations: bool = True,
    **kwargs,
)

Build train/validation/test PyTorch dataloaders for movie–response experiments using fixation-aligned image crops and trial-wise splits.

This function wires together the I/O helpers (process_fixations, load_responses, load_frames), performs a trial-wise split of the training set, and instantiates three dataloaders (train, validation, test) per retina. Internally, each dataloader wraps a temporal dataset (e.g. :class:MarmosetMovieDataset) configured with your spatiotemporal parameters (frame counts, dilations, chunk size, cropping, etc.).

Parameters

files : Mapping[int, str] or similar
    Specification used by load_responses to locate neural response files for each
    retina. Keys are retina indices; values are paths or file descriptors understood
    by load_responses. The loaded arrays are expected to have shape
    (n_neurons, frames_per_trial, n_trials).
fixation_files : Mapping[int, str]
    Mapping from retina index to a fixation file path (relative to basepath).
    TODO: only the first entry is actually read and used for all retinas in this
    function, assuming a shared fixation stream.
big_crops : Mapping[int, int | tuple[int, int, int, int]]
    Retina-specific crop specifications (top, bottom, left, right) used if
    retina_specific_crops=True.
basepath : str or pathlib.Path
    Root directory that contains the response files and the stimulus frames.
batch_size : int, default 16
    Batch size used for all splits.
seed : int, optional
    Random seed forwarded to the trial-wise split helper.
train_frac : float, default 0.8
    Fraction of training trials assigned to the train split by
    get_trial_wise_validation_split (unless hard_coded dictates otherwise).
subsample : int, default 1
    Spatial down-sampling factor applied to frames inside the dataset.
crop : int or tuple[int, int, int, int], default 0
    Global crop (top, bottom, left, right) applied if retina_specific_crops=False.
    If an int is given, the same value is used on all sides.
num_of_trials_to_use : int, optional
    Cap on the number of trials to use, starting from start_using_trial. By default,
    all available trials (train + validation) for the selected retina are used.
start_using_trial : int, default 0
    Offset (0-based) for selecting a contiguous block of trials to use.
num_of_frames : int, optional
    Number of visible frames per sample (passed to the underlying dataset). If None,
    the dataset default is used.
num_of_hidden_frames : int or tuple[int], optional
    Hidden-layer look-back window(s); broadcast/forwarded to the dataset.
temporal_dilation : int, default 1
    First-layer temporal dilation; forwarded to the dataset.
hidden_temporal_dilation : int or tuple[int], default 1
    Hidden-layer temporal dilations; forwarded to the dataset.
cell_index : int, optional
    Selects a single neuron (used to train single-cell models); forwarded to
    load_responses.
retina_index : int, optional
    If given, build dataloaders for only this retina. Otherwise, build loaders for
    all keys in fixation_files.
device : {"cpu", "cuda", ...}, default "cpu"
    Torch device on which tensors will be created.
time_chunk_size : int, optional
    Length (in frames) of each temporal chunk returned by the dataset.
num_of_layers : int, optional
    Model depth (used to broadcast hidden-frame/dilation parameters in the dataset).
excluded_cells : array-like of int, optional
    Indices of neurons to drop prior to training; forwarded to load_responses.
frame_file : str, default "_img_"
    Substring pattern used by load_frames to find frame files.
img_dir_name : str, default "stimuli"
    Subdirectory of basepath that contains individual stimulus frames.
full_img_w, full_img_h : int, default 1400, 1200
    Spatial dimensions of the raw, uncropped full-frame stimuli on disk
    (350 × 300 for the subsampled version).
img_h, img_w : int, optional
    Target spatial dimensions after fixation cropping. If either is None, both
    default to full_img_* - 3 * padding.
final_training : bool, default False
    Passed through to the split helper; when True you may configure the helper to
    collapse validation into training (implementation-dependent).
padding : int, default 200
    Extra margin captured around the gaze center when extracting crops.
retina_specific_crops : bool, default True
    If True, override crop with big_crops[retina_index] for each retina.
stimulus_seed : int, default 0
    Random seed forwarded to load_responses for stimulus/order reproducibility.
hard_coded : Mapping[str, Sequence[int]] or None
    Optional hard-coded split specification consumed by
    get_trial_wise_validation_split.
flip_imgs : bool, default False
    If True, pass a flag to process_fixations to mirror fixations (and thus crops)
    across the vertical axis.
shuffle : bool or None, optional
    If None, the training dataloader shuffles and validation/test do not. If a
    boolean is given, that value is used for both train and validation; test
    remains False.

Returns

dataloaders : dict[str, dict[int, torch.utils.data.DataLoader]] A nested dictionary with three top-level keys: "train", "validation", and "test". Each maps retina indices to a dataloader configured for that split.

* Train/validation loaders iterate over **training responses** for the
  selected trials.
* The test loader iterates over the **held-out test responses**; its
  dataset ignores trial indices but uses the same temporal chunking.
Notes
  • Expected response shapes: load_responses should return, for each retina, a dict with keys "train_responses" (n_neurons × T_train × n_trials) and "test_responses" (n_neurons × T_test).
  • TODO: Fixations. Right now, only the first path from fixation_files is opened and parsed via process_fixations and then reused for all retinas.
  • Trial selection window. After the split, trials are restricted to the contiguous window [start_using_trial, start_using_trial + num_of_trials_to_use) (clipped to the number of available trials).
  • Printing side effects. The function prints the selected training and validation trial IDs to stdout.
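The trial-selection window from the last note can be written out explicitly (select_trial_window is a hypothetical helper illustrating the documented clipping; the real logic lives inside frame_movie_loader):

```python
def select_trial_window(trials, start_using_trial=0, num_of_trials_to_use=None):
    """Restrict a split's trial IDs to the contiguous window
    [start_using_trial, start_using_trial + num_of_trials_to_use),
    clipped to the number of available trials."""
    if num_of_trials_to_use is None:
        return trials[start_using_trial:]
    end = min(start_using_trial + num_of_trials_to_use, len(trials))
    return trials[start_using_trial:end]

print(select_trial_window(list(range(10)), start_using_trial=2, num_of_trials_to_use=5))
# [2, 3, 4, 5, 6]
```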
Source code in openretina/data_io/sridhar_2025/dataloaders.py
def frame_movie_loader(
    files,
    fixation_files: dict[int, str],
    big_crops: dict[int, int | tuple[int, int, int, int]],
    basepath,
    batch_size: int = 16,
    seed=None,
    train_frac: float = 0.8,
    subsample: int = 1,
    crop: int | tuple[int, int, int, int] = 0,
    num_of_trials_to_use: int | None = None,
    start_using_trial: int = 0,
    num_of_frames=None,
    temporal_dilation: int | tuple[int, ...] = 1,
    hidden_temporal_dilation: int | tuple[int, ...] = 1,
    cell_index=None,
    retina_index=None,
    device: str = "cpu",
    time_chunk_size=None,
    num_of_layers=None,
    excluded_cells=None,
    frame_file="_img_",
    img_dir_name="stimuli",
    full_img_w: int = 1400,
    full_img_h: int = 1200,
    img_h=None,
    img_w=None,
    final_training: bool = False,
    padding: int = 200,
    retina_specific_crops: bool = True,
    stimulus_seed: int = 0,
    hard_coded=None,
    flip_imgs: bool = False,
    shuffle=None,
    sta_dir="stas",
    get_locations: bool = True,
    **kwargs,
):
    """
    Build train/validation/test PyTorch dataloaders for *movie–response* experiments
    using fixation-aligned image crops and trial-wise splits.

    This function wires together the I/O helpers (``process_fixations``,
    ``load_responses``, ``load_frames``), performs a **trial-wise** split of the
    training set, and instantiates three dataloaders (train, validation, test)
    per retina. Internally, each dataloader wraps a temporal dataset
    (e.g. :class:`MarmosetMovieDataset`) configured with your spatiotemporal
    parameters (frame counts, dilations, chunk size, cropping, etc.).

    Parameters
    ----------
    files : Mapping[int, str] or similar
        Specification used by ``load_responses`` to locate neural response files
        for each retina. Keys are retina indices; values are paths or file
        descriptors understood by ``load_responses``. The loaded arrays are
        expected to have shape ``(n_neurons, frames_per_trial, n_trials)``.
    fixation_files : Mapping[int, str]
        Mapping from retina index to a fixation file path (relative to
        ``basepath``). TODO: only the first entry is actually read and used
        for all retinas in this function, assuming a shared fixation stream.
    big_crops : Mapping[int, int | tuple[int, int, int, int]]
        Retina-specific crop specifications *(top, bottom, left, right)* used
        if ``retina_specific_crops=True``.
    basepath : str or pathlib.Path
        Root directory that contains the response files and the stimulus frames.
    batch_size : int, default 16
        Batch size used for all splits.
    seed : int, optional
        Random seed forwarded to the trial-wise split helper.
    train_frac : float, default 0.8
        Fraction of training trials assigned to the train split by
        ``get_trial_wise_validation_split`` (unless ``hard_coded`` dictates
        otherwise).
    subsample : int, default 1
        Spatial down-sampling factor applied to frames inside the dataset.
    crop : int or tuple[int, int, int, int], default 0
        Global crop *(top, bottom, left, right)* applied if
        ``retina_specific_crops=False``. If an ``int`` is given, the same value
        is used on all sides.
    num_of_trials_to_use : int, optional
        Cap on the number of trials to use starting from
        ``start_using_trial``. By default, uses all available trials (train +
        validation) for the selected retina.
    start_using_trial : int, default 0
        Offset (0-based) for selecting a contiguous block of trials to use.
    num_of_frames : int or sequence[int], optional
        Number of **visible** frames per sample (passed to the underlying
        dataset). If a sequence is given, the first entry is the visible
        window and the remaining entries are forwarded to the dataset as
        hidden-layer look-back window(s). If ``None``, the dataset default
        is used.
    temporal_dilation : int, default 1
        First layer temporal dilation; forwarded to the dataset.
    hidden_temporal_dilation : int or tuple[int], default 1
        Hidden-layer temporal dilations; forwarded to the dataset.
    cell_index : int, optional
        Used to train single cell models, selects a single neuron, forwarded to ``load_responses``.
    retina_index : int, optional
        If given, build dataloaders for only this retina. Otherwise, build
        loaders for all keys in ``fixation_files``.
    device : {"cpu", "cuda", ...}, default "cpu"
        Torch device on which tensors will be created.
    time_chunk_size : int, optional
        Length (in frames) of each temporal chunk returned by the dataset.
    num_of_layers : int, optional
        Model depth (used to broadcast hidden-frame/dilation parameters in the
        dataset).
    excluded_cells : array-like of int, optional
        Indices of neurons to drop prior to training; forwarded to
        ``load_responses``.
    frame_file : str, default "_img_"
        Substring pattern used by ``load_frames`` to find frame files.
    img_dir_name : str, default "stimuli"
        Subdirectory of ``basepath`` that contains individual stimulus frames.
    full_img_w, full_img_h : int, default 1400, 1200
        Spatial dimensions of the *raw* (non-subsampled, non-cropped)
        full-frame stimuli on disk; for the subsampled stimuli they are
        350, 300.
    img_h, img_w : int, optional
        Target spatial dimensions after fixation cropping. If either is
        ``None``, both default to ``full_img_* - 3 * padding``.
    final_training : bool, default False
        Passed through to the split helper; when ``True`` you may configure the
        helper to collapse validation into training (implementation-dependent).
    padding : int, default 200
        Extra margin captured around the gaze center when extracting crops.
    retina_specific_crops : bool, default True
        If ``True``, override ``crop`` with ``big_crops[retina_index]`` for each
        retina.
    stimulus_seed : int, default 0
        Random seed forwarded to ``load_responses`` for stimulus/order
        reproducibility.
    hard_coded : Mapping[str, Sequence[int]] or None
        Optional hard-coded split specification consumed by
        ``get_trial_wise_validation_split``.
    flip_imgs : bool, default False
        If ``True``, pass a flag to ``process_fixations`` to mirror fixations
        (and thus crops) across the vertical axis.
    shuffle : bool or None, optional
        If ``None``, training dataloader shuffles and validation/test do not.
        If a boolean is given, that value is used for **both** train and
        validation; test remains ``False``.

    Returns
    -------
    dataloaders : dict[str, dict[int, torch.utils.data.DataLoader]]
        A nested dictionary with three top-level keys: ``"train"``,
        ``"validation"``, and ``"test"``. Each maps retina indices to a
        dataloader configured for that split.

        * Train/validation loaders iterate over **training responses** for the
          selected trials.
        * The test loader iterates over the **held-out test responses**; its
          dataset ignores trial indices but uses the same temporal chunking.

    Notes
    -----
    - **Expected response shapes:** ``load_responses`` should return, for each
      retina, a dict with keys ``"train_responses"`` (``n_neurons × T_train ×
      n_trials``) and ``"test_responses"`` (``n_neurons × T_test``).
    - TODO: **Fixations.** Right now, only the first path from ``fixation_files`` is opened and
      parsed via ``process_fixations`` and then reused for all retinas.
    - **Trial selection window.** After the split, trials are restricted to the
      contiguous window
      ``[start_using_trial, start_using_trial + num_of_trials_to_use)`` (clipped
      to the number of available trials).
    - **Printing side effects.** The function prints the selected training and
      validation trial IDs to stdout."""

    basepath = get_local_file_path(str(basepath))

    dataloaders: dict[str, dict] = {"train": {}, "validation": {}, "test": {}}
    if retina_index is None:
        retina_indices = list(fixation_files.keys())
    else:
        retina_indices = [retina_index]

    with open(f"{basepath}/{fixation_files[retina_indices[0]]}", "r") as file:
        fixation_file = file.readlines()
        fixations = process_fixations(fixation_file, flip_imgs=flip_imgs)

    responses = load_responses(
        basepath, files=files, stimulus_seed=stimulus_seed, excluded_cells=excluded_cells, cell_index=cell_index
    )
    frames = load_frames(
        img_dir_name=os.path.join(basepath, img_dir_name),
        frame_file=frame_file,
        full_img_h=full_img_h,
        full_img_w=full_img_w,
    )

    if img_h is None:
        img_h = full_img_h - 3 * padding
        img_w = full_img_w - 3 * padding

    for retina_index in retina_indices:
        train_responses, test_responses = (
            responses[retina_index]["train_responses"],
            responses[retina_index]["test_responses"],
        )

        all_train_ids, all_validation_ids = get_trial_wise_validation_split(
            train_responses=train_responses,
            train_frac=train_frac,
            seed=seed,
            final_training=final_training,
            hard_coded=hard_coded,
        )

        train_ids, valid_ids = filter_trials(
            train_responses=train_responses,
            all_train_ids=all_train_ids,
            all_validation_ids=all_validation_ids,
            hard_coded=hard_coded,
            num_of_trials_to_use=num_of_trials_to_use,
            starting_trial=start_using_trial,
        )

        # print(f"Trial separation for {retina_index}")
        # print("training trials: ", len(train_ids), train_ids)
        # print("validation trials: ", len(valid_ids), valid_ids, "\n")

        if retina_specific_crops:
            crop = big_crops[retina_index]
        locations = None
        if get_locations:
            assert sta_dir is not None
            locations = get_locations_from_stas(
                sta_dir=os.path.join(basepath, sta_dir),
                retina_index=retina_index,
                cells=[cell_index]
                if cell_index is not None
                else [
                    x
                    for x in range(0, train_responses.shape[0] + len(excluded_cells[retina_index]))
                    if x not in excluded_cells[retina_index]
                ],
                crop=crop,
                flip_sta=True,
            )
        if isinstance(num_of_frames, int):
            num_of_frames = [num_of_frames]
        train_loader = get_dataloader(
            {
                "train_responses": train_responses,
                "test_responses": test_responses,
                "test_responses_by_trial": responses[retina_index].get("test_responses_by_trial"),
            },
            fixations=fixations,
            path=basepath,
            indices=train_ids,
            test=False,
            batch_size=batch_size,
            num_of_frames=num_of_frames[0],
            device=device,
            crop=crop,
            shuffle=True if shuffle is None else shuffle,
            subsample=subsample,
            time_chunk_size=time_chunk_size,
            num_of_layers=num_of_layers,
            frames=frames,
            num_of_hidden_frames=num_of_frames[1:] if len(num_of_frames) > 1 else None,
            padding=padding,
            full_img_h=full_img_h,
            full_img_w=full_img_w,
            img_h=img_h,
            img_w=img_w,
            temporal_dilation=temporal_dilation,
            hidden_temporal_dilation=hidden_temporal_dilation,
            excluded_cells=excluded_cells,
            locations=locations,
        )

        valid_loader = get_dataloader(
            {
                "train_responses": train_responses,
                "test_responses": test_responses,
                "test_responses_by_trial": responses[retina_index].get("test_responses_by_trial"),
            },
            path=basepath,
            indices=valid_ids,
            test=False,
            batch_size=batch_size,
            fixations=fixations,
            num_of_frames=num_of_frames[0],
            device=device,
            crop=crop,
            shuffle=False if shuffle is None else shuffle,
            subsample=subsample,
            time_chunk_size=time_chunk_size,
            num_of_layers=num_of_layers,
            frames=frames,
            num_of_hidden_frames=num_of_frames[1:] if len(num_of_frames) > 1 else None,
            padding=padding,
            full_img_h=full_img_h,
            full_img_w=full_img_w,
            img_h=img_h,
            img_w=img_w,
            temporal_dilation=temporal_dilation,
            hidden_temporal_dilation=hidden_temporal_dilation,
            excluded_cells=excluded_cells,
            locations=locations,
        )

        test_loader = get_dataloader(
            {
                "train_responses": train_responses,
                "test_responses": test_responses,
                "test_responses_by_trial": responses[retina_index].get("test_responses_by_trial"),
            },
            fixations=fixations,
            path=basepath,
            indices=train_ids,
            test=True,
            batch_size=batch_size,
            num_of_frames=num_of_frames[0],
            device=device,
            crop=crop,
            shuffle=False,
            subsample=subsample,
            time_chunk_size=time_chunk_size,
            num_of_layers=num_of_layers,
            frames=frames,
            num_of_hidden_frames=num_of_frames[1:] if len(num_of_frames) > 1 else None,
            padding=padding,
            full_img_h=full_img_h,
            full_img_w=full_img_w,
            img_h=img_h,
            img_w=img_w,
            temporal_dilation=temporal_dilation,
            hidden_temporal_dilation=hidden_temporal_dilation,
            excluded_cells=excluded_cells,
            locations=locations,
        )

        dataloaders["train"][retina_index] = train_loader
        dataloaders["validation"][retina_index] = valid_loader
        dataloaders["test"][retina_index] = test_loader

    return dataloaders
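The way `frame_movie_loader` interprets `num_of_frames` can be distilled from the source above: an `int` specifies the visible window only, while a sequence uses the first entry as the visible window and the remainder as hidden-layer windows. A minimal sketch of that broadcast:

```python
# Sketch of the num_of_frames handling in frame_movie_loader (mirrors the
# isinstance check and the num_of_frames[1:] slice in the source above).
def split_frame_spec(num_of_frames):
    if isinstance(num_of_frames, int):
        num_of_frames = [num_of_frames]
    visible = num_of_frames[0]
    hidden = num_of_frames[1:] if len(num_of_frames) > 1 else None
    return visible, hidden

print(split_frame_spec(15))          # -> (15, None)
print(split_frame_spec([15, 9, 9]))  # -> (15, [9, 9])
```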

Stimuli

stimuli

load_frames

load_frames(
    img_dir_name: str | PathLike,
    frame_file: str,
    full_img_w: int,
    full_img_h: int,
) -> Float[ndarray, "frames width height"]

Loads all stimulus frames of the movie into memory.

Source code in openretina/data_io/sridhar_2025/stimuli.py
def load_frames(
    img_dir_name: str | os.PathLike, frame_file: str, full_img_w: int, full_img_h: int
) -> Float[np.ndarray, "frames width height"]:
    """
    loads all stimulus frames of the movie into memory
    """
    img_dir_name = get_local_file_path(str(img_dir_name))
    print("Loading all frames from:", img_dir_name, "into memory")
    images = os.listdir(img_dir_name)
    images = [frame for frame in images if frame_file in frame]
    all_frames = np.zeros((len(images), full_img_w, full_img_h), dtype=np.float16)
    i = 0
    for img_file in tqdm(sorted(images)):
        img = np.load(f"{img_dir_name}/{img_file}")

        all_frames[i] = img / 255
        i += 1
    return all_frames
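The file-selection step above can be illustrated on a toy directory listing (an assumption for illustration: frame files are `.npy` arrays whose names contain the `frame_file` substring, loaded in lexicographically sorted order and scaled from 8-bit values to [0, 1]):

```python
# Toy illustration of load_frames's filtering and ordering; file names are
# hypothetical, not taken from the actual dataset.
frame_file = "_img_"
listing = ["movie_img_0002.npy", "notes.txt", "movie_img_0001.npy"]
selected = sorted(f for f in listing if frame_file in f)
print(selected)  # -> ['movie_img_0001.npy', 'movie_img_0002.npy']
```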

build_placeholder_movies

build_placeholder_movies(
    session_ids,
    *,
    channels: int,
    height: int,
    width: int,
    time_bins: int = 1,
    test_time_bins: int | None = None,
    stim_id_prefix: str = "sridhar_2025",
    norm_mean: float = 0.0,
    norm_std: float = 1.0,
) -> dict[str, MoviesTrainTestSplit]

Create lightweight MoviesTrainTestSplit placeholders that encode the spatial dimensions of the stimuli. They are not used directly by the dataloader or model, but serve for data_info computation.

Note: For accurate frame counts, use build_movies_from_responses instead, which reads the response files to determine actual training/test frame counts.

Source code in openretina/data_io/sridhar_2025/stimuli.py
def build_placeholder_movies(
    session_ids,
    *,
    channels: int,
    height: int,
    width: int,
    time_bins: int = 1,
    test_time_bins: int | None = None,
    stim_id_prefix: str = "sridhar_2025",
    norm_mean: float = 0.0,
    norm_std: float = 1.0,
) -> dict[str, MoviesTrainTestSplit]:
    """
    Create lightweight MoviesTrainTestSplit placeholders that encode spatial dimensions of the stimuli.
    Will not be used directly by the dataloader and model, but rather for `data_info` computation.

    Note: For accurate frame counts, use `build_movies_from_responses` instead, which reads the
    response files to determine actual training/test frame counts.
    """
    if test_time_bins is None:
        test_time_bins = time_bins

    movies = {}
    for session_id in session_ids:
        train = np.zeros((channels, time_bins, height, width), dtype=np.float32)
        test = np.zeros((channels, test_time_bins, height, width), dtype=np.float32)
        movies[session_id] = MoviesTrainTestSplit(
            train=train,
            test=test,
            stim_id=f"{stim_id_prefix}_{session_id}",
            norm_mean=norm_mean,
            norm_std=norm_std,
        )
    return movies
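The placeholder shapes produced above can be checked with a small stand-in (MoviesTrainTestSplit itself lives in openretina, so a minimal dataclass is substituted here for illustration):

```python
import numpy as np
from dataclasses import dataclass

# Minimal stand-in for MoviesTrainTestSplit, for shape illustration only.
@dataclass
class Split:
    train: np.ndarray
    test: np.ndarray
    stim_id: str

def placeholders(session_ids, channels, height, width, time_bins=1, test_time_bins=None):
    if test_time_bins is None:
        test_time_bins = time_bins  # same default as build_placeholder_movies
    return {
        sid: Split(
            train=np.zeros((channels, time_bins, height, width), np.float32),
            test=np.zeros((channels, test_time_bins, height, width), np.float32),
            stim_id=f"sridhar_2025_{sid}",
        )
        for sid in session_ids
    }

movies = placeholders(["s1"], channels=1, height=300, width=350)
print(movies["s1"].train.shape)  # -> (1, 1, 300, 350)
```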

build_movies_from_responses

build_movies_from_responses(
    base_path: str | PathLike,
    response_files: dict[str, str],
    *,
    channels: int,
    height: int,
    width: int,
    stim_id_prefix: str = "sridhar_2025",
    stimulus_seed: int = 0,
    norm_mean: float = 0.0,
    norm_std: float = 1.0,
) -> dict[str, MoviesTrainTestSplit]

Create MoviesTrainTestSplit placeholders with accurate frame counts by reading response files.

This function loads the response pickle files to determine the actual number of training and test frames per session, accounting for stimulus_seed filtering. The resulting placeholders have correct time dimensions for accurate dataset statistics reporting.

Each session gets a unique stim_id because different sessions see different visual content due to different fixation patterns and trial assignments.

Parameters

base_path : str | PathLike
    Base directory containing the response files.
response_files : dict[str, str]
    Dictionary mapping session IDs to response pickle file paths (relative to base_path).
channels : int
    Number of channels in the stimulus.
height : int
    Height of the stimulus frames.
width : int
    Width of the stimulus frames.
stim_id_prefix : str, default 'sridhar_2025'
    Prefix for the stimulus ID (each session will have "{prefix}_{session_id}").
stimulus_seed : int, default 0
    Random seed used for trial selection (affects frame counts for some sessions).
norm_mean : float, default 0.0
    Normalization mean for the stimulus.
norm_std : float, default 1.0
    Normalization std for the stimulus.

Returns

dict[str, MoviesTrainTestSplit]
    Dictionary mapping session IDs to MoviesTrainTestSplit placeholders with accurate frame counts.

Source code in openretina/data_io/sridhar_2025/stimuli.py
def build_movies_from_responses(
    base_path: str | os.PathLike,
    response_files: dict[str, str],
    *,
    channels: int,
    height: int,
    width: int,
    stim_id_prefix: str = "sridhar_2025",
    stimulus_seed: int = 0,
    norm_mean: float = 0.0,
    norm_std: float = 1.0,
) -> dict[str, MoviesTrainTestSplit]:
    """
    Create MoviesTrainTestSplit placeholders with accurate frame counts by reading response files.

    This function loads the response pickle files to determine the actual number of training and
    test frames per session, accounting for stimulus_seed filtering. The resulting placeholders
    have correct time dimensions for accurate dataset statistics reporting.

    Each session gets a unique stim_id because different sessions see different visual content
    due to different fixation patterns and trial assignments.

    Args:
        base_path: Base directory containing the response files.
        response_files: Dictionary mapping session IDs to response pickle file paths (relative to base_path).
        channels: Number of channels in the stimulus.
        height: Height of the stimulus frames.
        width: Width of the stimulus frames.
        stim_id_prefix: Prefix for the stimulus ID (each session will have "{prefix}_{session_id}").
        stimulus_seed: Random seed used for trial selection (affects frame counts for some sessions).
        norm_mean: Normalization mean for the stimulus.
        norm_std: Normalization std for the stimulus.

    Returns:
        Dictionary mapping session IDs to MoviesTrainTestSplit placeholders with accurate frame counts.
    """

    base_path = get_local_file_path(str(base_path))
    movies = {}

    for session_id, response_file in response_files.items():
        with open(os.path.join(base_path, response_file), "rb") as f:
            data = pickle.load(f)

        train_responses = data["train_responses"]
        test_responses = data["test_responses"]

        # Apply seed filtering logic (same as in responses.py load_responses)
        if "seeds" in data:
            seed_info = data["seeds"]
            if stimulus_seed in seed_info:
                trials = data["trial_separation"][stimulus_seed]
                train_responses = train_responses[:, :, trials]
            # Note: test responses are not filtered by seed in the dataloader

        # Get frame counts
        _, frames_per_trial, n_trials = train_responses.shape
        train_time_bins = frames_per_trial * n_trials
        test_time_bins = test_responses.shape[1]

        # Create placeholder arrays with correct dimensions
        train = np.zeros((channels, train_time_bins, height, width), dtype=np.float32)
        test = np.zeros((channels, test_time_bins, height, width), dtype=np.float32)

        movies[session_id] = MoviesTrainTestSplit(
            train=train,
            test=test,
            stim_id=f"{stim_id_prefix}_{session_id}",
            norm_mean=norm_mean,
            norm_std=norm_std,
        )

    return movies
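The frame-count logic used above can be reproduced on toy arrays: training responses are shaped (n_neurons, frames_per_trial, n_trials), with the trial axis flattened into time, while test time bins are read straight from axis 1.

```python
import numpy as np

# Toy arrays standing in for the pickled responses (shapes are illustrative).
train_responses = np.zeros((20, 100, 8))  # (n_neurons, frames_per_trial, n_trials)
test_responses = np.zeros((20, 250))      # (n_neurons, T_test)

_, frames_per_trial, n_trials = train_responses.shape
train_time_bins = frames_per_trial * n_trials
test_time_bins = test_responses.shape[1]
print(train_time_bins, test_time_bins)  # -> 800 250
```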

Responses

responses

response_splits_from_pickles

response_splits_from_pickles(
    base_path: str | PathLike,
    files: dict[str, str],
    stimulus_seed: int = 0,
    excluded_cells: Optional[dict[Any, list[int]]] = None,
    cell_index: Optional[int] = None,
) -> dict[str, ResponsesTrainTestSplit]

Convert Sridhar pickled responses into ResponsesTrainTestSplit objects compatible with the unified pipeline. The output is not used directly by the dataloader or model, but serves for data_info computation.

Source code in openretina/data_io/sridhar_2025/responses.py
def response_splits_from_pickles(
    base_path: str | os.PathLike,
    files: dict[str, str],
    stimulus_seed: int = 0,
    excluded_cells: Optional[dict[Any, list[int]]] = None,
    cell_index: Optional[int] = None,
) -> dict[str, ResponsesTrainTestSplit]:
    """
    Convert Sridhar pickled responses into ``ResponsesTrainTestSplit`` objects compatible with the unified pipeline.
    The output will not be used directly by the dataloader and model, but rather for `data_info` computation.
    """
    raw_responses = load_responses(
        base_path,
        files=files,
        stimulus_seed=stimulus_seed,
        excluded_cells=excluded_cells,
        cell_index=cell_index,
    )

    splits: dict[str, ResponsesTrainTestSplit] = {}
    for session_id, tensors in raw_responses.items():
        train_responses = np.asarray(tensors["train_responses"], dtype=np.float32)
        test_responses = np.asarray(tensors["test_responses"], dtype=np.float32)
        test_by_trial = tensors.get("test_responses_by_trial")

        n_neurons, frames_per_trial, n_trials = train_responses.shape
        train_matrix = train_responses.reshape(n_neurons, frames_per_trial * n_trials)

        test_by_trial_formatted = None
        if test_by_trial is not None and test_by_trial.ndim == 3:
            # Expecting shape (neurons, time, trials); re-order to (trials, neurons, time)
            test_by_trial_formatted = np.transpose(test_by_trial, (2, 0, 1))
        if test_responses.ndim == 3:
            test_responses = np.mean(test_responses, axis=-1)

        splits[session_id] = ResponsesTrainTestSplit(
            train=train_matrix,
            test=test_responses,
            test_by_trial=test_by_trial_formatted,
            stim_id=f"sridhar_{session_id}",
            session_kwargs={
                "stimulus_seed": stimulus_seed,
                "frames_per_trial": frames_per_trial,
                "num_trials": n_trials,
            },
        )
    return splits
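The two array manipulations above can be verified on toy data: trials are concatenated along time for the training matrix, and per-trial test responses are permuted from (neurons, time, trials) to (trials, neurons, time).

```python
import numpy as np

# Toy data with illustrative shapes, mirroring the reshape and transpose
# performed in response_splits_from_pickles.
train = np.zeros((5, 100, 4))              # (n_neurons, frames_per_trial, n_trials)
train_matrix = train.reshape(5, 100 * 4)   # (n_neurons, T_train)

test_by_trial = np.zeros((5, 250, 3))      # (neurons, time, trials)
reordered = np.transpose(test_by_trial, (2, 0, 1))
print(train_matrix.shape, reordered.shape)  # -> (5, 400) (3, 5, 250)
```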