
Vector Field Analysis

Tools for analyzing model responses via principal component analysis and vector field visualization.

compute_lsta_library

compute_lsta_library(
    model: Module,
    movies: ndarray,
    session_id: str,
    cell_id: int,
    batch_size: int = 64,
    integration_window: tuple[int, int] = (5, 15),
    device: str = "cuda",
) -> tuple[ndarray, ndarray]

Computes the Local Spike-Triggered Average (LSTA) library and response library for a given model, set of movies, and cell_id.

For each batch of input movies, this function:

- Runs the model to obtain outputs for all cells and time points.
- Selects the output for the specified cell over the integration window (time range).
- Sums the selected outputs and computes the gradient of this sum with respect to the input movies.
- Averages the resulting gradients (LSTA maps) over the integration window for each movie.
- Collects both the LSTA maps and the raw model outputs for all movies.

Parameters

model (torch.nn.Module): The neural network model to evaluate.
movies (np.ndarray or torch.Tensor): Array of input movie stimuli with shape (num_samples, channels, frames,
  height, width).
session_id (str): Identifier for the session/data key used by the model.
cell_id (int): Index of the cell for which to compute LSTA.
batch_size (int, optional): Number of samples per batch. Default is 64.
integration_window (tuple, optional): Tuple (start, end) specifying the time window (frame indices) over which
  to sum outputs. Default is (5, 15).
device (str, optional): Device to run computations on ('cuda' or 'cpu'). Default is 'cuda'.

Returns

lsta_library (np.ndarray): Array of LSTA maps averaged over the integration window, shape
 (num_samples, channels, height, width).
response_library (np.ndarray): Array of model outputs for all batches, shape (num_samples, frames, num_cells).

Raises

IndexError: If cell_id is out of bounds for the model output.

Notes

- The LSTA map for each movie is computed as the gradient of the summed output for the specified cell and time
  window, with respect to the input movie frames.
- The returned lsta_library is averaged over the integration window (i.e., mean gradient across selected frames).
- The response_library contains the raw model outputs for all movies, all frames, and all cells.
- Default integration_window is not always optimal;
  adjust based on model architecture and expected response timing.
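The gradient-of-summed-output idea can be illustrated without a full network: for a purely linear "model", the LSTA recovers the linear filter itself. The sketch below is a minimal numpy illustration (the `linear_model` function and filter `w` are hypothetical stand-ins; nothing here uses openretina or torch autograd, which the actual function relies on):

```python
import numpy as np

rng = np.random.default_rng(0)
height, width = 4, 5
w = rng.normal(size=(height, width))  # hypothetical linear receptive field

def linear_model(movie_frame: np.ndarray) -> float:
    # Stand-in "model": the response is the inner product with the filter.
    return float((w * movie_frame).sum())

# For a linear model, d(response)/d(input) is the filter itself, so a
# finite-difference "LSTA" should recover w regardless of the stimulus.
frame = rng.normal(size=(height, width))
eps = 1e-6
lsta = np.zeros_like(frame)
for i in range(height):
    for j in range(width):
        bumped = frame.copy()
        bumped[i, j] += eps
        lsta[i, j] = (linear_model(bumped) - linear_model(frame)) / eps

assert np.allclose(lsta, w, atol=1e-4)
```

For a nonlinear model the same gradient is only a local linearization around the stimulus, which is why the library computes one LSTA per input movie rather than a single global filter.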
Source code in openretina/insilico/vector_field_analysis/vector_field_analysis.py
def compute_lsta_library(
    model: torch.nn.Module,
    movies: np.ndarray,
    session_id: str,
    cell_id: int,
    batch_size: int = 64,
    integration_window: tuple[int, int] = (5, 15),
    device: str = "cuda",
) -> tuple[np.ndarray, np.ndarray]:
    """
    Computes the Local Spike-Triggered Average (LSTA) library and response library for a given model,
      set of movies, and cell_id.

    For each batch of input movies, this function:
        - Runs the model to obtain outputs for all cells and time points.
        - Selects the output for a specific cell over a specified integration window (time range).
        - Sums the selected outputs and computes the gradient of this sum with respect to the input movies.
        - The resulting gradients (LSTA maps) are averaged over the integration window for each movie.
        - Collects both the LSTA maps and the raw model outputs for all movies.

    Parameters
    ----------
        model (torch.nn.Module): The neural network model to evaluate.
        movies (np.ndarray or torch.Tensor): Array of input movie stimuli with shape (num_samples, channels, frames,
          height, width).
        session_id (str): Identifier for the session/data key used by the model.
        cell_id (int): Index of the cell for which to compute LSTA.
        batch_size (int, optional): Number of samples per batch. Default is 64.
        integration_window (tuple, optional): Tuple (start, end) specifying the time window (frame indices) over which
          to sum outputs. Default is (5, 15).
        device (str, optional): Device to run computations on ('cuda' or 'cpu'). Default is 'cuda'.

    Returns
    -------
        lsta_library (np.ndarray): Array of LSTA maps averaged over the integration window, shape
         (num_samples, channels, height, width).
        response_library (np.ndarray): Array of model outputs for all batches, shape (num_samples, frames, num_cells).

    Raises
    ------
        IndexError: If cell_id is out of bounds for the model output.

    Notes
    -----
        - The LSTA map for each movie is computed as the gradient of the summed output for the specified cell and
          time window, with respect to the input movie frames.
        - The returned lsta_library is averaged over the integration window
          (i.e., mean gradient across selected frames).
        - The response_library contains the raw model outputs for all movies, all frames, and all cells.
        - Default integration_window is not always optimal;
          adjust based on model architecture and expected response timing.
    """
    model.eval()
    all_lstas = []
    all_outputs = []

    for i in range(0, len(movies), batch_size):
        batch_movies = torch.tensor(movies[i : i + batch_size], dtype=torch.float32, device=device, requires_grad=True)

        outputs = model(batch_movies, data_key=session_id)
        num_cells = outputs.shape[-1]
        if not (0 <= cell_id < num_cells):
            raise IndexError(f"cell_id {cell_id} is out of bounds (number of cells: {num_cells})")

        chosen_cell_outputs = outputs[:, integration_window[0] : integration_window[1], cell_id].sum()
        chosen_cell_outputs.backward()

        assert batch_movies.grad is not None
        batch_lstas = batch_movies.grad.detach()
        all_lstas.append(batch_lstas)
        all_outputs.append(outputs.detach())

        # Clear gradients
        del batch_movies
        torch.cuda.empty_cache()

    lstas = torch.cat(all_lstas, dim=0)
    lsta_library = lstas.mean(dim=2)  # Average over the integration window (frames)
    response_library = torch.cat(all_outputs, dim=0)
    return lsta_library.cpu().numpy(), response_library.cpu().numpy()

get_pc_from_pca

get_pc_from_pca(
    model,
    channel: int,
    lsta_library: ndarray,
    plot: bool = False,
) -> tuple[ndarray, ndarray, ndarray]

Computes the first two principal components (PC1 and PC2) from a PCA analysis on a selected channel of the input data.

Parameters

model (object): Model object containing data information, specifically the input shape in model.data_info["input_shape"].
channel (int): Index of the channel to select from lsta_library for PCA analysis.
lsta_library (np.ndarray): Input data array of shape (samples, channels, height, width).
plot (bool, optional): If True, plots the first two principal components as images using matplotlib. Default is False.

Returns

PC1 (np.ndarray): The first principal component as a flattened array.
PC2 (np.ndarray): The second principal component as a flattened array.
explained_variance (np.ndarray): Explained variance ratio for the first two principal components.

Notes

- The function reshapes the selected channel data to (samples, height * width) before applying PCA.
- If plot is True, displays the principal components as images with color mapping.
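The reshape-then-PCA step can be sketched with plain numpy, where an SVD of the mean-centered data stands in for sklearn's `PCA` used in the source (the random `lsta_library` here is synthetic, for illustration only):

```python
import numpy as np

rng = np.random.default_rng(1)
samples, channels, h, w = 32, 2, 6, 6
lsta_library = rng.normal(size=(samples, channels, h, w))

channel = 0
# Select one channel and flatten the spatial dims: (samples, h * w)
X = lsta_library[:, channel].reshape(samples, -1)

# PCA via SVD of the centered data (equivalent to what sklearn's PCA computes)
Xc = X - X.mean(axis=0)
U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
PC1, PC2 = Vt[0], Vt[1]
explained_variance = (S[:2] ** 2) / (S ** 2).sum()

assert PC1.shape == (h * w,)
```

The components come back as unit-norm flattened vectors, which is why downstream functions can project images onto them with a simple dot product.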
Source code in openretina/insilico/vector_field_analysis/vector_field_analysis.py
def get_pc_from_pca(
    model, channel: int, lsta_library: np.ndarray, plot: bool = False
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Computes the first two principal components (PC1 and PC2) from a PCA analysis on a selected channel of the
      input data.
    Parameters
    ----------
    model : object
        Model object containing data information, specifically the input shape in `model.data_info["input_shape"]`.
    channel : int
        Index of the channel to select from `lsta_library` for PCA analysis.
    lsta_library : np.ndarray
        Input data array of shape (samples, channels, height, width).
    plot : bool, optional
        If True, plots the first two principal components as images using matplotlib.
    Returns
    -------
    PC1 : np.ndarray
        The first principal component as a flattened array.
    PC2 : np.ndarray
        The second principal component as a flattened array.
    explained_variance : np.ndarray
        Array containing the explained variance ratio for the first two principal components.
    Notes
    -----
    - The function reshapes the selected channel data to (samples, height * width) before applying PCA.
    - If `plot` is True, displays the principal components as images with color mapping.
    """

    # Select channel and reshape
    lsta_reshaped = lsta_library[:, channel, :, :].reshape(lsta_library.shape[0], -1)

    pca = PCA(n_components=2)
    pca.fit(lsta_reshaped)

    explained_variance = pca.explained_variance_ratio_
    PC1, PC2 = pca.components_

    if plot:
        PC_max = max(np.abs(PC1).max(), np.abs(PC2).max())
        plt.figure(figsize=(10, 5))
        for component in range(2):
            plt.subplot(1, 2, component + 1)
            plt.imshow(
                pca.components_[component].reshape(model.data_info["input_shape"][1:3]),
                cmap="bwr",
                vmin=-PC_max,
                vmax=PC_max,
            )
            plt.title(f"PCA {component} ({explained_variance[component]:.2f} e.v.)")
            plt.axis("off")

    return PC1, PC2, explained_variance

get_images_coordinate

get_images_coordinate(
    images: ndarray,
    PC1: ndarray,
    PC2: ndarray,
    plot: bool = False,
) -> ndarray

Projects a set of images onto two principal component vectors and optionally plots their coordinates.

Parameters

images (np.ndarray): Array of images with shape (n_samples, height, width).
PC1 (np.ndarray): First principal component vector with shape (height * width,).
PC2 (np.ndarray): Second principal component vector with shape (height * width,).
plot (bool, optional): If True, plots the projected coordinates. Default is False.

Returns

np.ndarray: Array of shape (n_samples, 2) containing the coordinates of each image projected onto PC1 and PC2.

Note

The function reshapes each image to a 1D vector before projection.
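The projection is a single matrix multiply once the images are flattened. A minimal numpy sketch of the same computation (synthetic images and hypothetical unit-norm components, not openretina data):

```python
import numpy as np

rng = np.random.default_rng(2)
n_samples, h, w = 10, 6, 6
images = rng.normal(size=(n_samples, h, w))

# Hypothetical unit-norm principal component vectors
PC1 = rng.normal(size=h * w)
PC1 /= np.linalg.norm(PC1)
PC2 = rng.normal(size=h * w)
PC2 /= np.linalg.norm(PC2)

# Flatten each image and project onto both components in one matmul
flat = images.reshape(n_samples, -1)            # (N, h * w)
coords = flat @ np.stack([PC1, PC2], axis=0).T  # (N, 2)

assert coords.shape == (n_samples, 2)
```

Each row of `coords` is `(image · PC1, image · PC2)`, the point at which that image sits in the 2D PC plane.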
Source code in openretina/insilico/vector_field_analysis/vector_field_analysis.py
def get_images_coordinate(images: np.ndarray, PC1: np.ndarray, PC2: np.ndarray, plot: bool = False) -> np.ndarray:
    """
    Projects a set of images onto two principal component vectors and optionally plots their coordinates.
    Parameters
    ----------
        images (np.ndarray): Array of images with shape (n_samples, height, width).
        PC1 (np.ndarray): First principal component vector with shape (height * width,).
        PC2 (np.ndarray): Second principal component vector with shape (height * width,).
        plot (bool, optional): If True, plots the projected coordinates. Default is False.
    Returns
    -------
        np.ndarray: Array of shape (n_samples, 2) containing the coordinates of each image projected onto PC1 and PC2.
    Note
    -----
        The function reshapes each image to a 1D vector before projection.
    """
    flatten_images = images.reshape(images.shape[0], -1)
    # Vectorized dot product: (N, features) @ (2, features).T -> (N, 2)
    PC_stack = np.stack([PC1, PC2], axis=0)  # Shape: (2, features)
    images_coordinate = flatten_images @ PC_stack.T  # Shape: (N, 2)

    if plot:
        pt_x = images_coordinate[:, 0]
        pt_y = images_coordinate[:, 1]
        plt.figure()
        plt.scatter(pt_x, pt_y)

    return images_coordinate

plot_untreated_vectorfield

plot_untreated_vectorfield(
    lsta_library: ndarray,
    channel: int,
    PC1: ndarray,
    PC2: ndarray,
    images_coordinate: ndarray,
) -> Figure

Plots a vector field visualization using principal components from an LSTA library.

This function extracts the specified channel from the LSTA library, projects each LSTA onto two principal components (PC1 and PC2), and visualizes the resulting vector field at the given image coordinates using matplotlib's quiver plot. Additionally, it displays the PC1 and PC2 components as inset images. This function is primarily for visualization in notebooks; it returns the figure for saving or further customization.

Parameters

lsta_library (np.ndarray): 4D array containing the LSTA library data, shape (n_samples, n_channels, x_size, y_size).
channel (int): Index of the channel to extract from the LSTA library for analysis.
PC1 (np.ndarray): First principal component vector used for projection.
PC2 (np.ndarray): Second principal component vector used for projection.
images_coordinate (np.ndarray): 2D array of shape (n_samples, 2) containing the (x, y) coordinates for each LSTA sample.

Returns

plt.Figure: The matplotlib Figure object containing the vector field plot with PC1 and PC2 inset images. Call plt.show() to display, or fig.savefig() to save.

Notes

- The function uses matplotlib's quiver plot to visualize the vector field.
- PC1 and PC2 are displayed as inset images for reference.
- The axes are turned off for a cleaner visualization.
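The arrow directions are just the LSTA projections onto the PC plane, computed the same way as the image coordinates. A small numpy sketch of that step (synthetic data; the actual plotting via quiver is omitted):

```python
import numpy as np

rng = np.random.default_rng(3)
n_samples, h, w = 8, 5, 5
lstas = rng.normal(size=(n_samples, h, w))  # one channel already selected

# Hypothetical principal component vectors
PC1 = rng.normal(size=h * w)
PC2 = rng.normal(size=h * w)

# Each arrow points along its LSTA's projection onto (PC1, PC2)
flat = lstas.reshape(n_samples, -1)
arrowheads = flat @ np.stack([PC1, PC2], axis=0).T  # (N, 2)

assert arrowheads.shape == (n_samples, 2)
```

In the plot itself, the arrow for sample i is drawn with its tail at `images_coordinate[i]` and direction `arrowheads[i]`, so the field shows how each stimulus would be nudged along the PC axes to increase the cell's response.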
Source code in openretina/insilico/vector_field_analysis/vector_field_analysis.py
def plot_untreated_vectorfield(
    lsta_library: np.ndarray, channel: int, PC1: np.ndarray, PC2: np.ndarray, images_coordinate: np.ndarray
) -> plt.Figure:
    """
    Plots a vector field visualization using principal components from an LSTA library.

    This function extracts the specified channel from the LSTA library, projects each LSTA onto two principal
    components (PC1 and PC2), and visualizes the resulting vector field at the given image coordinates using
    matplotlib's quiver plot. Additionally, it displays the PC1 and PC2 components as inset images.

    This function is primarily for visualization in notebooks; it returns the figure for saving or further
    customization.
    Parameters
    ----------
    lsta_library : np.ndarray
        A 4D numpy array containing the LSTA library data with shape (n_samples, n_channels, x_size, y_size).
    channel : int
        The index of the channel to extract from the LSTA library for analysis.
    PC1 : np.ndarray
        The first principal component vector used for projection.
    PC2 : np.ndarray
        The second principal component vector used for projection.
    images_coordinate : np.ndarray
        A 2D numpy array of shape (n_samples, 2) containing the (x, y) coordinates for each LSTA sample.
    Returns
    -------
    plt.Figure
        The matplotlib Figure object containing the vector field plot with PC1 and PC2 inset images. Call plt.show()
         to display,
        or fig.savefig() to save.
    Notes
    -----
    - The function uses matplotlib's quiver plot to visualize the vector field.
    - PC1 and PC2 are displayed as inset images for reference.
    - The axes are turned off for a cleaner visualization.
    """
    lsta_library = lsta_library[:, channel, :, :]
    arrowheads = np.array([[np.dot(PC1, lsta.flatten()), np.dot(PC2, lsta.flatten())] for lsta in lsta_library])
    fig, ax = plt.subplots(figsize=(20, 15))
    window_size = int(max(images_coordinate[:, 0].max(), images_coordinate[:, 1].max()) * 1.1)
    ax.quiver(
        images_coordinate[: len(lsta_library), 0],
        images_coordinate[: len(lsta_library), 1],
        arrowheads[:, 0],
        arrowheads[:, 1],
        width=0.002,
        scale_units="xy",
        angles="xy",
        scale=arrowheads.max(),
        alpha=0.5,
    )
    ax.set_xlim((-window_size, window_size))
    ax.set_ylim((-window_size, window_size))
    ax.axis("off")

    x_size = lsta_library.shape[-2]
    y_size = lsta_library.shape[-1]
    plot_pc_insets(fig, PC1, PC2, x_size, y_size)
    return fig

plot_clean_vectorfield

plot_clean_vectorfield(
    lsta_library: ndarray,
    channel: int,
    PC1: ndarray,
    PC2: ndarray,
    images: list[Any] | ndarray,
    images_coordinate: ndarray,
    explained_variance: ndarray,
    x_bins: int = 31,
    y_bins: int = 31,
    responses: ndarray | None = None,
) -> Figure

Plots a cleaned vector field representation of binned image and LSTA data projected onto principal components.

This function bins images and their corresponding LSTA (Local Spike-Triggered Average) responses based on spatial coordinates, projects the binned data onto two principal components (PC1 and PC2), and visualizes the resulting vector field using quiver plots. Insets showing the PC1 and PC2 components are added to the figure for reference. This function is primarily for visualization in notebooks; it returns the figure for saving or further customization.

Parameters

lsta_library (np.ndarray): Array of LSTA responses with shape (n_samples, n_channels, x_size, y_size).
channel (int): Index of the channel to select from lsta_library.
PC1 (np.ndarray): First principal component vector for projection (flattened).
PC2 (np.ndarray): Second principal component vector for projection (flattened).
images (np.ndarray): Array of images corresponding to LSTA responses, shape (n_samples, x_size, y_size).
images_coordinate (np.ndarray): Array of spatial coordinates for each image, shape (n_samples, 2).
explained_variance (np.ndarray): Explained variance for each principal component.
x_bins (int, optional): Number of bins along the x-axis for spatial binning. Default is 31.
y_bins (int, optional): Number of bins along the y-axis for spatial binning. Default is 31.
responses (np.ndarray, optional): Response values for each image, shape (n_samples,). If provided, response magnitudes are overlaid as colored markers at each location.

Returns

fig (matplotlib.figure.Figure): The matplotlib figure object containing the vector field plot and PC insets. Call plt.show() to display, or fig.savefig() to save.

Raises

ValueError: If no images are found in the coordinate bins (e.g., due to bin size or coordinate range).

Notes


- This function is primarily intended for visualization in Jupyter notebooks.
- The vector field arrows represent the projection of binned images and LSTA responses onto the first two principal components.
- Insets display the spatial structure of PC1 and PC2 for interpretability.
- If responses are provided, they are averaged within bins and displayed as colored markers.
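The spatial binning uses np.digitize against evenly spaced edges, then averages whatever falls in each bin. A reduced numpy sketch of that bookkeeping on synthetic coordinates (scalar values stand in for the binned images/LSTAs):

```python
import numpy as np

rng = np.random.default_rng(4)
coords = rng.uniform(-1, 1, size=(200, 2))  # synthetic PC-plane coordinates
values = rng.normal(size=200)               # stand-in for per-sample data
x_bins = y_bins = 5

x_edges = np.linspace(coords[:, 0].min(), coords[:, 0].max(), x_bins + 1)
y_edges = np.linspace(coords[:, 1].min(), coords[:, 1].max(), y_bins + 1)
x_idx = np.digitize(coords[:, 0], x_edges) - 1
y_idx = np.digitize(coords[:, 1], y_edges) - 1
valid = (x_idx >= 0) & (x_idx < x_bins) & (y_idx >= 0) & (y_idx < y_bins)

bin_means = []
for xi in range(x_bins):
    for yi in range(y_bins):
        mask = valid & (x_idx == xi) & (y_idx == yi)
        if mask.any():  # empty bins are skipped, as in the source
            bin_means.append(values[mask].mean())

assert 0 < len(bin_means) <= x_bins * y_bins
```

Note that np.digitize assigns a point sitting exactly on the last edge to an out-of-range bin, so it is dropped by the validity mask; the source handles the same edge case the same way.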
Source code in openretina/insilico/vector_field_analysis/vector_field_analysis.py
def plot_clean_vectorfield(
    lsta_library: np.ndarray,
    channel: int,
    PC1: np.ndarray,
    PC2: np.ndarray,
    images: list[Any] | np.ndarray,
    images_coordinate: np.ndarray,
    explained_variance: np.ndarray,
    x_bins: int = 31,
    y_bins: int = 31,
    responses: np.ndarray | None = None,
) -> plt.Figure:
    """
    Plots a cleaned vector field representation of binned image and LSTA data projected onto principal components.

    This function bins images and their corresponding LSTA (Local Spike-Triggered Average) responses based on
    spatial coordinates, projects the binned data onto two principal components (PC1 and PC2), and visualizes the
    resulting vector field using quiver plots. Insets showing the PC1 and PC2 components are added to the figure
    for reference.

    This function is primarily for visualization in notebooks; it returns the figure for saving or further
    customization.
    Parameters
    ----------
    lsta_library : np.ndarray
        Array of LSTA responses with shape (n_samples, n_channels, x_size, y_size).
    channel : int
        Index of the channel to select from lsta_library.
    PC1 : np.ndarray
        First principal component vector for projection (flattened).
    PC2 : np.ndarray
        Second principal component vector for projection (flattened).
    images : np.ndarray
        Array of images corresponding to LSTA responses, shape (n_samples, x_size, y_size).
    images_coordinate : np.ndarray
        Array of spatial coordinates for each image, shape (n_samples, 2).
    explained_variance : np.ndarray
        Array containing explained variance for each principal component.
    x_bins : int, optional
        Number of bins along the x-axis for spatial binning (default is 31).
    y_bins : int, optional
        Number of bins along the y-axis for spatial binning (default is 31).
    responses : np.ndarray, optional
        Array of response values for each image, shape (n_samples,). If provided, will overlay
        response magnitudes as colored markers at each location.
    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object containing the vector field plot and PC insets. Call plt.show() to display,
        or fig.savefig() to save.
    Raises
    ------
    ValueError
        If no images are found in the coordinate bins (e.g., due to bin size or coordinate range).
    Notes
    -----
    - This function is primarily intended for visualization in Jupyter notebooks.
    - The vector field arrows represent the projection of binned images and LSTA responses onto the first two
    principal components.
    - Insets display the spatial structure of PC1 and PC2 for interpretability.
    - If responses are provided, they will be averaged within bins and displayed as colored markers.
    """
    lsta_library = lsta_library[:, channel, :, :]
    x_size = lsta_library.shape[-2]
    y_size = lsta_library.shape[-1]

    # Bin edges for PC1 and PC2 coordinates
    x_edges = np.linspace(images_coordinate[:, 0].min(), images_coordinate[:, 0].max(), x_bins + 1)
    y_edges = np.linspace(images_coordinate[:, 1].min(), images_coordinate[:, 1].max(), y_bins + 1)

    # Digitize coordinates to bins
    x_bin_idx = np.digitize(images_coordinate[:, 0], x_edges) - 1
    y_bin_idx = np.digitize(images_coordinate[:, 1], y_edges) - 1

    # Mask for valid bins
    valid_mask = (x_bin_idx >= 0) & (x_bin_idx < x_bins) & (y_bin_idx >= 0) & (y_bin_idx < y_bins)

    # Prepare lists for binned images and lstas
    binned_imgs_list = []
    binned_lstas_list = []
    bin_coords_list = []
    binned_responses_list = []

    # For each bin, average images and lstas assigned to it
    for xi in range(x_bins):
        for yi in range(y_bins):
            bin_mask = valid_mask & (x_bin_idx == xi) & (y_bin_idx == yi)
            if np.any(bin_mask):
                binned_imgs_list.append(images[bin_mask].mean(axis=0))
                binned_lstas_list.append(lsta_library[bin_mask].mean(axis=0))
                # Use bin center as coordinate
                bin_coords_list.append([0.5 * (x_edges[xi] + x_edges[xi + 1]), 0.5 * (y_edges[yi] + y_edges[yi + 1])])
                # Average responses within bin if provided
                if responses is not None:
                    binned_responses_list.append(responses[bin_mask].mean())

    binned_imgs = np.array(binned_imgs_list)
    binned_lstas = np.array(binned_lstas_list)
    images_coordinate = np.array(bin_coords_list)
    if responses is not None:
        binned_responses = np.array(binned_responses_list)
    # Check if we have any binned data

    if len(binned_imgs) == 0:
        raise ValueError("No images found in coordinate bins. Try adjusting bin size or coordinate range.")

    flatten_binned_imgs = binned_imgs.reshape(binned_imgs.shape[0], -1)
    flatten_binned_lstas = binned_lstas.reshape(binned_lstas.shape[0], -1)

    binned_arrowtails = np.array([[np.dot(PC1, img), np.dot(PC2, img)] for img in flatten_binned_imgs])
    binned_arrowheads = np.array([[np.dot(PC1, lsta), np.dot(PC2, lsta)] for lsta in flatten_binned_lstas])

    fig, ax = plt.subplots(figsize=(20, 20))

    # Calculate plot limits
    xlim = max(np.abs(binned_arrowtails[:, 0]).max(), np.abs(images_coordinate[:, 0]).max()) * 1.1
    ylim = max(np.abs(binned_arrowtails[:, 1]).max(), np.abs(images_coordinate[:, 1]).max()) * 1.1
    plot_limit = max(xlim, ylim)

    # Overlay response magnitudes as density plot if provided
    if responses is not None:
        # Create a grid for interpolation
        grid_resolution = 100
        x_interval = np.linspace(-plot_limit, plot_limit, grid_resolution)
        y_interval = np.linspace(-plot_limit, plot_limit, grid_resolution)
        xi_grid, yi_grid = np.meshgrid(x_interval, y_interval)

        # Interpolate the response values onto the grid
        zi = griddata(binned_arrowtails, binned_responses, (xi_grid, yi_grid), method="linear", fill_value=np.nan)

        # Create the density plot using pcolormesh
        density = ax.pcolormesh(x_interval, y_interval, zi, cmap="viridis", alpha=0.4, shading="gouraud", zorder=0)

        # Add colorbar
        cbar = plt.colorbar(density, ax=ax)
        cbar.set_label("Response magnitude", size=14)

    ax.quiver(
        binned_arrowtails[:, 0],
        binned_arrowtails[:, 1],
        binned_arrowheads[:, 0],
        binned_arrowheads[:, 1],
        color="black",
        width=0.002,
        scale_units="xy",
        angles="xy",
        scale=binned_arrowheads.max(),
        zorder=2,
    )

    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["left"].set_visible(False)

    # Add arrowheads to axes using matplotlib arrow function
    ax.arrow(
        -plot_limit * 0.75,
        0,
        1.5 * plot_limit,
        0,
        head_width=plot_limit * 0.02,
        head_length=plot_limit * 0.02,
        fc="k",
        ec="k",
        linewidth=1,
    )
    ax.arrow(
        0,
        -plot_limit * 0.75,
        0,
        1.5 * plot_limit,
        head_width=plot_limit * 0.02,
        head_length=plot_limit * 0.02,
        fc="k",
        ec="k",
        linewidth=1,
    )
    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_xlim((-plot_limit, plot_limit))
    ax.set_ylim((-plot_limit, plot_limit))

    plot_pc_insets(fig, PC1, PC2, x_size, y_size, explained_variance)
    return fig

Helper Functions

load_and_preprocess_images

load_and_preprocess_images(
    image_dir: str,
    target_h: int,
    target_w: int,
    n_channels: int,
) -> ndarray

Loads PNG images from a directory, downsamples, center-crops, and repeats channels as needed.

Parameters

image_dir (str): Directory containing PNG images.
target_h (int): Target height for cropping.
target_w (int): Target width for cropping.
n_channels (int): Number of channels to repeat.

Returns

np.ndarray: Array of shape (num_images, n_channels, target_h, target_w).

Raises

ValueError: If no PNG images are found in the directory.

Notes

- Images are downsampled to fit within target dimensions while maintaining aspect ratio.
- Center-cropping is applied after downsampling.
- Single-channel images are repeated across channels if n_channels > 1.
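The downsample-then-center-crop strategy can be sketched on a single synthetic image (the array and target sizes below are arbitrary; the real helper additionally loads PNGs and handles the channel axis):

```python
import numpy as np

rng = np.random.default_rng(5)
img = rng.integers(0, 255, size=(64, 96), dtype=np.uint8)  # synthetic "PNG"
target_h, target_w = 18, 20

# One integer stride chosen so the result still covers the target size
factor = max(1, int(min(img.shape[0] / target_h, img.shape[1] / target_w)))
ds = img[::factor, ::factor]

# Center crop the downsampled image to exactly (target_h, target_w)
start_h = (ds.shape[0] - target_h) // 2
start_w = (ds.shape[1] - target_w) // 2
cropped = ds[start_h:start_h + target_h, start_w:start_w + target_w]

assert cropped.shape == (target_h, target_w)
```

Taking the floor of the smaller ratio keeps the aspect ratio intact and guarantees the downsampled image is at least as large as the crop target in both dimensions.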
Source code in openretina/insilico/vector_field_analysis/vector_field_analysis.py
def load_and_preprocess_images(image_dir: str, target_h: int, target_w: int, n_channels: int) -> np.ndarray:
    """
    Loads PNG images from a directory, downsamples, center-crops, and repeats channels as needed.
    Parameters
    ----------
        image_dir (str): Directory containing PNG images.
        target_h (int): Target height for cropping.
        target_w (int): Target width for cropping.
        n_channels (int): Number of channels to repeat.
    Returns
    -------
        np.ndarray: Array of shape (num_images, n_channels, target_h, target_w).
    Raises
    ------
        ValueError: If no PNG images are found in the directory.
    Notes
    -----
        - Images are downsampled to fit within target dimensions while maintaining aspect ratio.
        - Center-cropping is applied after downsampling.
        - Single-channel images are repeated across channels if n_channels > 1.
    """
    image_files = sorted([f for f in os.listdir(image_dir) if f.lower().endswith(".png")])
    images = np.array([np.array(Image.open(os.path.join(image_dir, f))) for f in image_files])
    # Downsample and crop using array operations; all images share one integer stride
    downsample_factor = max(1, int(min(images.shape[1] / target_h, images.shape[2] / target_w)))

    # Downsample
    ds_images = images[:, ::downsample_factor, ::downsample_factor]

    # Center crop
    h, w = ds_images.shape[1:3]
    start_h = (h - target_h) // 2
    start_w = (w - target_w) // 2
    cropped_images = ds_images[:, start_h : start_h + target_h, start_w : start_w + target_w]

    compressed_images = cropped_images.astype(np.float32)
    # Add channel dimension
    compressed_images = compressed_images[:, np.newaxis]
    # Repeat channels if needed
    if n_channels > 1:
        compressed_images = np.repeat(compressed_images, n_channels, axis=1)
    return compressed_images

prepare_movies_dataset

prepare_movies_dataset(
    model: BaseCoreReadout,
    session_id: str,
    n_image_frames: int = 16,
    normalize_movies: bool = True,
    image_library: ndarray | None = None,
    image_dir: str | None = None,
    device: str = "cuda",
) -> tuple[ndarray, int]

Prepares a dataset of movie stimuli for input into a neural model. This function delegates image loading, preprocessing, normalization, and temporal padding to helper functions.

Parameters

model: Neural model object with `data_info` attribute.
session_id (str): Identifier for the session.
n_image_frames (int, optional): Number of frames per movie for each image.
normalize_movies (bool, optional): Whether to normalize the movies.
image_library (np.ndarray, optional): Preprocessed image library.
image_dir (str, optional): Directory containing image files (.png).
device (str, optional): Device for torch tensors.

Returns

movies (np.ndarray): Array of shape (num_images, n_channels, n_frames, target_h, target_w).
n_empty_frames (int): Number of initial empty frames for temporal padding.

Raises

ValueError: If both `image_library` and `image_dir` are provided.
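The core tensor manipulation, repeating each static image along a new time axis and greying out the initial padding frames, can be sketched in numpy (the padding length below is a hypothetical placeholder; the real function derives it from the model's temporal receptive field plus a border margin):

```python
import numpy as np

rng = np.random.default_rng(6)
n_images, n_channels, h, w = 3, 2, 8, 8
images = rng.normal(size=(n_images, n_channels, h, w)).astype(np.float32)

n_empty_frames = 12  # hypothetical temporal padding length
n_image_frames = 16

# Repeat each static image along a new frames axis: (N, C, T, H, W)
movies = np.repeat(images[:, :, np.newaxis], n_empty_frames + n_image_frames, axis=2)

# Replace the leading padding frames with a uniform mean-grey value
movies[:, :, :n_empty_frames] = movies.mean()

assert movies.shape == (n_images, n_channels, n_empty_frames + n_image_frames, h, w)
```

The grey lead-in lets the model's temporal filters settle before the image appears, so the responses analyzed downstream reflect the image rather than onset transients.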
Source code in openretina/insilico/vector_field_analysis/vector_field_analysis.py
def prepare_movies_dataset(
    model: BaseCoreReadout,
    session_id: str,
    n_image_frames: int = 16,
    normalize_movies: bool = True,
    image_library: np.ndarray | None = None,
    image_dir: str | None = None,
    device: str = "cuda",
) -> tuple[np.ndarray, int]:
    """
    Prepares a dataset of movie stimuli for input into a neural model.
    This function delegates image loading, preprocessing, normalization, and temporal padding to helper functions.
    Parameters
    ----------
        model: Neural model object with `data_info` attribute.
        session_id (str): Identifier for the session.
        n_image_frames (int, optional): Number of frames per movie for each image.
        normalize_movies (bool, optional): Whether to normalize the movies.
        image_library (np.ndarray, optional): Preprocessed image library.
        image_dir (str, optional): Directory containing image files (.png).
        device (str, optional): Device for torch tensors.
    Returns
    -------
        movies (np.ndarray): Array of shape (num_images, n_channels, n_frames, target_h, target_w).
        n_empty_frames (int): Number of initial empty frames for temporal padding.
    Raises
    ------
        ValueError: If both `image_library` and `image_dir` are provided.
    """
    n_channels = model.data_info["input_shape"][0]
    target_h, target_w = model.data_info["input_shape"][1:3]

    if image_library is not None and image_dir is not None:
        raise ValueError("Provide either image_library or image_dir, not both.")
    if image_dir is not None:
        LOGGER.info(f"Loading images from {image_dir}...")
        compressed_images = load_and_preprocess_images(image_dir, target_h, target_w, n_channels)
    elif image_library is not None:
        LOGGER.info("Using provided image library...")
        compressed_images = image_library
    else:
        raise ValueError("Provide either image_library or image_dir.")

    # number of grey frames = size of equivalent temporal filter of the full model + 10 for border effects
    n_empty_frames = get_model_temporal_padding(model, n_channels, target_h, target_w, device) + 10
    movies = np.repeat(compressed_images[:, :, np.newaxis, :, :], n_empty_frames + n_image_frames, axis=2)

    if normalize_movies:
        movies = normalize_movies_array(movies, model, session_id, n_channels)

    # Set initial empty frames to mean grey
    movies[:, :, :n_empty_frames, :, :] = movies.mean()
    return movies, n_empty_frames