
Readout Modules

Modules that map core feature representations to individual neuron response predictions.

Base Readout

Readout

Bases: Module

Base readout class for all individual readouts. MultiReadout expects its readouts to inherit from this base class.

initialize

initialize(*args: Any, **kwargs: Any) -> None
Source code in openretina/modules/readout/base.py
def initialize(self, *args: Any, **kwargs: Any) -> None:
    raise NotImplementedError("initialize is not implemented for ", self.__class__.__name__)

regularizer

regularizer(
    reduction: Literal["sum", "mean", None] = "sum",
) -> Tensor
Source code in openretina/modules/readout/base.py
def regularizer(
    self,
    reduction: Literal["sum", "mean", None] = "sum",
) -> torch.Tensor:
    raise NotImplementedError("regularizer is not implemented for ", self.__class__.__name__)

apply_reduction

apply_reduction(
    x: Tensor,
    reduction: Literal["sum", "mean", None] = "mean",
) -> Tensor

Applies a reduction to the output of the regularizer.

Args:
    x: output of the regularizer
    reduction: reduction method for the regularizer. Supported values are 'mean', 'sum', and None.

Returns: the reduced value of the regularizer

Source code in openretina/modules/readout/base.py
def apply_reduction(self, x: torch.Tensor, reduction: Literal["sum", "mean", None] = "mean") -> torch.Tensor:
    """
    Applies a reduction on the output of the regularizer.
    Args:
        x: output of the regularizer
        reduction: method of reduction for the regularizer. Currently possible are ['mean', 'sum', None].

    Returns: reduced value of the regularizer
    """

    if reduction == "mean":
        return x.mean()
    elif reduction == "sum":
        return x.sum()
    elif reduction is None:
        return x
    else:
        raise ValueError(
            f"Reduction method '{reduction}' is not recognized. Valid values are ['mean', 'sum', None]"
        )
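For illustration, the same branching can be sketched without PyTorch. The helper below is hypothetical (not part of openretina) and operates on a plain list of per-neuron regularizer terms:

```python
def reduce_terms(values, reduction="mean"):
    """Reduce a list of per-neuron regularizer terms, mirroring
    Readout.apply_reduction: 'mean', 'sum', or None (no reduction)."""
    if reduction == "mean":
        return sum(values) / len(values)
    elif reduction == "sum":
        return sum(values)
    elif reduction is None:
        return values  # leave unreduced, e.g. for per-neuron inspection
    else:
        raise ValueError(
            f"Reduction method '{reduction}' is not recognized. Valid values are ['mean', 'sum', None]"
        )

print(reduce_terms([1.0, 2.0, 3.0], "mean"))  # 2.0
print(reduce_terms([1.0, 2.0, 3.0], "sum"))   # 6.0
```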

initialize_bias

initialize_bias(
    mean_activity: Optional[
        Float[Tensor, " n_neurons"]
    ] = None,
) -> None

Initializes the biases in the readout.

Args:
    mean_activity: tensor containing the mean activity of the neurons. If None, the bias is initialized to zero and a warning is emitted.

Source code in openretina/modules/readout/base.py
def initialize_bias(self, mean_activity: Optional[Float[torch.Tensor, " n_neurons"]] = None) -> None:
    """
    Initialize the biases in readout.
    Args:
        mean_activity: Tensor containing the mean activity of neurons.

    Returns:

    """
    if mean_activity is None:
        warnings.warn("Readout is NOT initialized with mean activity but with 0!")
        self.bias.data.fill_(0)
    else:
        self.bias.data = mean_activity
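The fallback logic is small enough to sketch in isolation; this hypothetical stand-in (not openretina API) mirrors the warn-and-zero behavior on plain lists:

```python
import warnings

def init_bias(mean_activity=None, n_neurons=4):
    """Mirror of Readout.initialize_bias: use the neurons' mean activity
    when available, otherwise fall back to zeros with a warning."""
    if mean_activity is None:
        warnings.warn("Readout is NOT initialized with mean activity but with 0!")
        return [0.0] * n_neurons
    return list(mean_activity)

print(init_bias([0.12, 0.07]))  # [0.12, 0.07]
```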

PointGaussianReadout

Spatial readout using learned Gaussian-sampled grid positions.

PointGaussianReadout

PointGaussianReadout(
    in_shape: tuple[int, int, int, int],
    outdims,
    bias,
    init_mu_range=0.1,
    init_sigma_range=0.15,
    batch_sample=True,
    align_corners=True,
    gauss_type="full",
    grid_mean_predictor=None,
    shared_features=None,
    shared_grid=None,
    init_grid=None,
    source_grid=None,
    mean_activity=None,
    gamma_readout=1.0,
    **kwargs,
)

Bases: Readout

A readout module that samples the output for each neuron at a single spatial location from the core feature map, where the location is drawn from a learned 2D Gaussian distribution for each neuron.

First introduced in Lurz et al., 2021: https://openreview.net/forum?id=Tp7kI90Htd

Key notes
  • Unlike mask-based readouts (GaussianMaskReadout, FactorisedReadout), this readout does NOT produce a spatial mask over the input. Instead, each neuron's response is determined by interpolating the feature map at a point sampled from a Gaussian distribution (parameterized by mean and covariance) for each neuron.

  • This mechanism results in each neuron having a flexible receptive field location within the Gaussian window during training (as the location is sampled from the Gaussian distribution), and a fixed location during inference (as the location is then fixed to the mean of the Gaussian distribution).

  • Instead of a spatial mask, the readout spatial location is a single "point" (x, y) in feature space per neuron per sample: there is no spatial integration or summing across a spatial region as in mask-based readouts.

  • Feature weights are still learned and behave like in the classic FactorisedReadout and GaussianMaskReadout.
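The point mechanics can be sketched without PyTorch. The snippet below uses hypothetical, dependency-free helpers (not openretina API): it draws a location from N(mu, sigma), clamps it to the normalized [-1, 1] grid, and bilinearly interpolates a single-channel feature map at that point, roughly what F.grid_sample with align_corners=True does for one neuron:

```python
import random

def sample_point(mu, sigma, train=True):
    """Draw (x, y) ~ N(mu, sigma) during training; collapse to mu at eval time."""
    eps = (random.gauss(0, 1), random.gauss(0, 1)) if train else (0.0, 0.0)
    return tuple(max(-1.0, min(1.0, m + s * e)) for m, s, e in zip(mu, sigma, eps))

def bilinear(fmap, x, y):
    """Bilinearly interpolate fmap (list of rows) at normalized coords in [-1, 1]."""
    h, w = len(fmap), len(fmap[0])
    px = (x + 1) / 2 * (w - 1)  # align_corners=True: -1 and 1 hit the border pixels
    py = (y + 1) / 2 * (h - 1)
    x0, y0 = int(px), int(py)
    x1, y1 = min(x0 + 1, w - 1), min(y0 + 1, h - 1)
    dx, dy = px - x0, py - y0
    top = fmap[y0][x0] * (1 - dx) + fmap[y0][x1] * dx
    bot = fmap[y1][x0] * (1 - dx) + fmap[y1][x1] * dx
    return top * (1 - dy) + bot * dy

fmap = [[0.0, 1.0], [2.0, 3.0]]
# at eval time the readout reads the map exactly at the (clamped) mean mu
print(bilinear(fmap, *sample_point((0.0, 0.0), (0.1, 0.1), train=False)))  # 1.5
```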

PARAMETER DESCRIPTION
in_shape

shape of the input feature map [channels, time, width, height]

TYPE: (list, tuple)

outdims

number of output units

TYPE: int

bias

adds a bias term

TYPE: bool

init_mu_range

initializes the mean with Uniform([-init_mu_range, init_mu_range]) [expected: positive value <= 1]. Default: 0.1

TYPE: float DEFAULT: 0.1

init_sigma_range

Initializes the standard deviation of the Gaussian to init_sigma_range when gauss_type is 'isotropic' or 'uncorrelated'. When gauss_type='full', initializes the square root of the covariance matrix with Uniform([-init_sigma_range, init_sigma_range]). Default: 0.15

TYPE: float DEFAULT: 0.15

batch_sample

if True, samples a position for each image in the batch separately [default: True as it decreases convergence time and performs just as well]

TYPE: bool DEFAULT: True

align_corners

Keyword argument to grid_sample for bilinear interpolation. Its behavior changed in PyTorch 1.3. The default align_corners=True restores the pre-1.3 behavior for comparability.

TYPE: bool DEFAULT: True

gauss_type

Which Gaussian to use. Options are 'isotropic', 'uncorrelated', or 'full' (default).

TYPE: str DEFAULT: 'full'

grid_mean_predictor

Parameters for a predictor of the mean grid locations. Must have the form {'hidden_layers': 0, 'hidden_features': 20, 'final_tanh': False}.

TYPE: dict DEFAULT: None

shared_features

Used when feature vectors are shared, either between neurons within this readout or between this readout and other readouts. Must be a dictionary of the form { 'match_ids': (numpy.array), 'shared_features': torch.nn.Parameter or None }. The match_ids identify entries that should be shared within or across scans. If shared_features is None, this readout creates its own features; if it is set to the feature Parameter of another readout, it replaces this readout's features. Features are accessed in increasing order of the sorted unique match_ids. For instance, if match_ids=[2,0,0,1], there should be 3 features in order [0,1,2]; when this readout creates features, it does so in that order.

TYPE: dict DEFAULT: None

shared_grid

Like shared_features. Use dictionary like { 'match_ids': (numpy.array), 'shared_grid': torch.nn.Parameter or None } See documentation of shared_features for specification.

TYPE: dict DEFAULT: None

source_grid
Source grid for the grid_mean_predictor.
Needs to be of size neurons x grid_mean_predictor[input_dimensions]

TYPE: array DEFAULT: None

init_grid

Initial grid locations for the neurons. Only set if both shared_grid and grid_mean_predictor are None.

TYPE: array DEFAULT: None

Source code in openretina/modules/readout/gaussian.py
def __init__(
    self,
    in_shape: tuple[int, int, int, int],
    outdims,
    bias,
    init_mu_range=0.1,
    init_sigma_range=0.15,
    batch_sample=True,
    align_corners=True,
    gauss_type="full",
    grid_mean_predictor=None,
    shared_features=None,
    shared_grid=None,
    init_grid=None,  # initial grid locations for the neurons
    source_grid=None,
    mean_activity=None,
    gamma_readout=1.0,
    **kwargs,
):
    super().__init__()
    self.gamma_readout = gamma_readout
    self.mean_activity = mean_activity
    # determines whether the Gaussian is isotropic or not
    self.gauss_type = gauss_type

    if init_mu_range > 1.0 or init_mu_range <= 0.0 or init_sigma_range <= 0.0:
        raise ValueError("either init_mu_range doesn't belong to [0.0, 1.0] or init_sigma_range is non-positive")

    # store statistics about the images and neurons
    self.in_shape = in_shape
    self.outdims = outdims

    # sample a different location per example
    self.batch_sample = batch_sample

    # position grid shape
    self.grid_shape = (1, outdims, 1, 2)

    # the grid can be predicted from another grid
    self._predicted_grid = False
    self._shared_grid = False
    self._original_grid = not self._predicted_grid

    if grid_mean_predictor is None and shared_grid is None:
        self._mu = Parameter(torch.Tensor(*self.grid_shape))  # mean location of gaussian for each neuron
        if init_grid is not None:
            self._mu.data = init_grid
    elif grid_mean_predictor is not None and shared_grid is not None:
        raise ValueError("Shared grid_mean_predictor and shared_grid_mean cannot both be set")
    elif grid_mean_predictor is not None:
        self.init_grid_predictor(source_grid=source_grid, **grid_mean_predictor)
    elif shared_grid is not None:
        self.initialize_shared_grid(**(shared_grid or {}))

    if gauss_type == "full":
        self.sigma_shape = (1, outdims, 2, 2)
    elif gauss_type == "uncorrelated":
        self.sigma_shape = (1, outdims, 1, 2)
    elif gauss_type == "isotropic":
        self.sigma_shape = (1, outdims, 1, 1)
    else:
        raise ValueError(f'gauss_type "{gauss_type}" not known')

    self.init_sigma_range = init_sigma_range
    self.sigma = Parameter(torch.Tensor(*self.sigma_shape))  # standard deviation for gaussian for each neuron

    self.initialize_features(**(shared_features or {}))

    if bias:
        bias = Parameter(torch.Tensor(outdims))
        self.register_parameter("bias", bias)
    else:
        self.register_parameter("bias", None)

    self.init_mu_range = init_mu_range
    self.align_corners = align_corners
    self.initialize(mean_activity)

forward

forward(x, sample=None, shift=None, out_idx=None, **kwargs)

Propagates the input forward through the readout.

Args:
    x: input data
    sample (bool | None): determines whether to draw a sample from the Gaussian distribution N(mu, sigma) defined per neuron, or to use its mean mu without sampling. If sample is None (default), the readout samples from N(mu, sigma) during training and fixes the location to the mean mu during evaluation. If sample is True/False, it overrides the model state (training or eval) and does as instructed.
    shift (Tensor): shifts the grid locations (e.g. from eye-tracking data)
    out_idx (array): indices of the neurons to be predicted

RETURNS DESCRIPTION
y

neuronal activity

Source code in openretina/modules/readout/gaussian.py
def forward(self, x, sample=None, shift=None, out_idx=None, **kwargs):
    """
    Propagates the input forwards through the readout
    Args:
        x: input data
        sample (bool/None): sample determines whether we draw a sample from Gaussian distribution, N(mu,sigma),
                            defined per neuron or use the mean, mu, of the Gaussian distribution without sampling.
                            if sample is None (default), samples from the N(mu,sigma) during training phase and
                            fixes to the mean, mu, during evaluation phase.
                            if sample is True/False, overrides the model_state (i.e training or eval)
                            and does as instructed
        shift (bool): shifts the location of the grid (from eye-tracking data)
        out_idx (bool): index of neurons to be predicted

    Returns:
        y: neuronal activity
    """
    N, c, w, h = x.size()
    c_in, _, w_in, h_in = self.in_shape
    if (c_in, w_in, h_in) != (c, w, h):
        warnings.warn("the specified feature map dimension is not the readout's expected input dimension")
    feat = self.features.view(1, c, self.outdims)
    bias = self.bias
    outdims = self.outdims

    if self.batch_sample:
        # sample the grid_locations separately per image per batch
        grid = self.sample_grid(batch_size=N, sample=sample)  # sample determines sampling from Gaussian
    else:
        # use one sampled grid_locations for all images in the batch
        grid = self.sample_grid(batch_size=1, sample=sample).expand(N, outdims, 1, 2)

    if out_idx is not None:
        if isinstance(out_idx, np.ndarray):
            if out_idx.dtype == bool:
                out_idx = np.where(out_idx)[0]
        feat = feat[:, :, out_idx]
        grid = grid[:, out_idx]
        if bias is not None:
            bias = bias[out_idx]
        outdims = len(out_idx)

    if shift is not None:
        grid = grid + shift[:, None, None, :]

    y = F.grid_sample(x, grid, align_corners=self.align_corners)
    y = (y.squeeze(-1) * feat).sum(1).view(N, outdims)

    if self.bias is not None:
        y = y + bias
    return y

sample_grid

sample_grid(batch_size, sample=None)

Returns grid locations from the core by sampling from a Gaussian distribution.

Args:
    batch_size (int): size of the batch
    sample (bool | None): determines whether to draw a sample from the Gaussian distribution N(mu, sigma) defined per neuron, or to use its mean mu without sampling. If sample is None (default), samples from N(mu, sigma) during training and fixes to the mean mu during evaluation. If sample is True/False, overrides the model state and does as instructed.

Source code in openretina/modules/readout/gaussian.py
def sample_grid(self, batch_size, sample=None):
    """
    Returns the grid locations from the core by sampling from a Gaussian distribution
    Args:
        batch_size (int): size of the batch
        sample (bool/None): sample determines whether we draw a sample from Gaussian distribution, N(mu,sigma),
                            defined per neuron or use the mean, mu, of the Gaussian distribution without sampling.
                            if sample is None (default), samples from the N(mu,sigma) during training phase and
                            fixes to the mean, mu, during evaluation phase.
                            if sample is True/False, overrides the model_state (i.e training or eval)
                            and does as instructed
    """

    # at eval time, only self.mu is used so it must belong to [-1,1]
    # sigma/variance is always a positive quantity
    with torch.no_grad():
        self.mu.clamp_(min=-1, max=1)

    grid_shape = (batch_size,) + self.grid_shape[1:]

    sample = self.training if sample is None else sample
    if sample:
        norm = self.mu.new(*grid_shape).normal_()
    else:
        norm = self.mu.new(*grid_shape).zero_()  # for consistency and CUDA capability

    if self.gauss_type != "full":
        return torch.clamp(
            norm * self.sigma + self.mu, min=-1, max=1
        )  # grid locations in feature space sampled randomly around the mean self.mu
    else:
        return torch.clamp(
            torch.einsum("ancd,bnid->bnic", self.sigma, norm) + self.mu,
            min=-1,
            max=1,
        )  # grid locations in feature space sampled randomly around the mean self.mu
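For the "full" case, the einsum above applies a per-neuron 2x2 matrix (the square root of the covariance) to a standard-normal draw. A hypothetical single-neuron, dependency-free equivalent:

```python
import random

def sample_full(mu, L, train=True):
    """Sample mu + L @ eps for one neuron, where L is the 2x2 square root of
    the covariance, then clamp each coordinate to [-1, 1] as sample_grid does."""
    eps = [random.gauss(0, 1), random.gauss(0, 1)] if train else [0.0, 0.0]
    raw = [mu[i] + L[i][0] * eps[0] + L[i][1] * eps[1] for i in range(2)]
    return [max(-1.0, min(1.0, v)) for v in raw]

# at eval time (sample=False) the draw collapses to the clamped mean
print(sample_full([0.2, -0.3], [[0.1, 0.0], [0.05, 0.1]], train=False))  # [0.2, -0.3]
```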

regularizer

regularizer(reduction='sum', average=None)
Source code in openretina/modules/readout/gaussian.py
def regularizer(self, reduction="sum", average=None):
    return self.feature_l1(reduction=reduction) * self.gamma_readout

GaussianMaskReadout

Readout using a factorized Gaussian mask over spatial feature maps.

GaussianMaskReadout

GaussianMaskReadout(
    in_shape: tuple[int, int, int, int],
    outdims: int,
    mean_activity: Float[Tensor, " outdims"] | None = None,
    gaussian_mean_scale: float = 1.0,
    gaussian_var_scale: float = 1.0,
    positive: bool = False,
    scale: bool = False,
    bias: bool = True,
    nonlinearity_function=softplus,
    mask_l1_reg: float = 1.0,
    feature_weights_l1_reg: float = 1.0,
)

Bases: Readout

A readout module that computes each neuron's output as a weighted sum (dot product) across the spatial extent of the core feature map, using a 2D Gaussian mask per neuron. It can be considered as an extension of the classic FactorisedReadout, where the spatial mask is enforced to have a Gaussian shape.

First introduced in Hoefling et al., 2024: https://doi.org/10.7554/eLife.86860

Key notes
  • Unlike point-based Gaussian readouts (see PointGaussianReadout), this class produces a full spatial mask for each neuron, effectively performing spatial integration (weighted by a Gaussian) across the entire input feature map in the spatial dimensions.

  • Each neuron has a single mask_log_var scalar (not per-axis), used as variance for both x and y, so the receptive field is circular (in the normalized grid), axis-aligned, and isotropic.
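A mask of this form can be sketched directly. The function below is a hypothetical illustration (openretina's exact scaling and grid layout may differ): one variance shared by both axes, trained as a log-variance so the variance stays positive:

```python
import math

def gaussian_mask(w, h, mean, log_var, mean_scale=1.0, var_scale=1.0):
    """Isotropic 2D Gaussian mask over a normalized [-1, 1] x [-1, 1] grid.
    A single variance is shared by both axes, so the mask is circular."""
    mx, my = mean[0] * mean_scale, mean[1] * mean_scale
    var = math.exp(log_var * var_scale)  # train on log-variance for positivity
    xs = [-1 + 2 * i / (w - 1) for i in range(w)]
    ys = [-1 + 2 * j / (h - 1) for j in range(h)]
    return [[math.exp(-((x - mx) ** 2 + (y - my) ** 2) / (2 * var)) for y in ys]
            for x in xs]

mask = gaussian_mask(5, 5, mean=(0.0, 0.0), log_var=-1.0)
print(max(max(row) for row in mask))  # 1.0, at the mean location
```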

PARAMETER DESCRIPTION
in_shape

The shape of the input tensor (c, t, w, h).

TYPE: tuple[int, int, int, int]

outdims

The number of output dimensions (usually the number of neurons in the session).

TYPE: int

mean_activity

The mean activity of the neurons, used to initialize the bias. Defaults to None.

TYPE: Float[Tensor, ' outdims'] | None DEFAULT: None

gaussian_mean_scale

The scale factor for the Gaussian mask mean. Defaults to 1e0.

TYPE: float DEFAULT: 1.0

gaussian_var_scale

The scale factor for the Gaussian mask variance. Defaults to 1e0.

TYPE: float DEFAULT: 1.0

positive

Whether the output should be positive. Defaults to False.

TYPE: bool DEFAULT: False

scale

Whether to include a scale parameter. Defaults to False.

TYPE: bool DEFAULT: False

bias

Whether to include a bias parameter. Defaults to True.

TYPE: bool DEFAULT: True

nonlinearity_function

torch nonlinearity function, e.g. nn.functional.softplus

DEFAULT: softplus

mask_l1_reg

The regularization strength for the sparsity of the spatial mask. Defaults to 1.0.

TYPE: float DEFAULT: 1.0

feature_weights_l1_reg

The regularization strength for the sparsity of feature weights. Defaults to 1.0.

TYPE: float DEFAULT: 1.0

Source code in openretina/modules/readout/factorized_gaussian.py
def __init__(
    self,
    in_shape: tuple[int, int, int, int],
    outdims: int,
    mean_activity: Float[torch.Tensor, " outdims"] | None = None,
    gaussian_mean_scale: float = 1e0,
    gaussian_var_scale: float = 1e0,
    positive: bool = False,
    scale: bool = False,
    bias: bool = True,
    nonlinearity_function=torch.nn.functional.softplus,
    mask_l1_reg: float = 1.0,
    feature_weights_l1_reg: float = 1.0,
):
    """
    Args:
        in_shape: The shape of the input tensor (c, t, w, h).
        outdims: The number of output dimensions (usually the number of neurons in the session).
        mean_activity: The mean activity of the neurons, used to initialize the bias. Defaults to None.
        gaussian_mean_scale: The scale factor for the Gaussian mask mean. Defaults to 1e0.
        gaussian_var_scale: The scale factor for the Gaussian mask variance. Defaults to 1e0.
        positive: Whether the output should be positive. Defaults to False.
        scale: Whether to include a scale parameter. Defaults to False.
        bias: Whether to include a bias parameter. Defaults to True.
        nonlinearity_function: torch nonlinearity function , e.g. nn.functional.softplus
        mask_l1_reg: The regularization strength for the sparsity of the spatial mask. Defaults to 1.0.
        feature_weights_l1_reg: The regularization strength for the sparsity of feature weights. Defaults to 1.0.
    """
    super().__init__()
    self.in_shape = in_shape
    c, t, w, h = in_shape
    self.outdims = outdims
    self.gaussian_mean_scale = gaussian_mean_scale
    self.gaussian_var_scale = gaussian_var_scale
    self.positive = positive
    self.nonlinearity_function = nonlinearity_function
    self.mask_l1_reg = mask_l1_reg
    self.feature_weights_l1_reg = feature_weights_l1_reg

    """we train on the log var and transform to var in a separate step"""
    self.mask_mean = torch.nn.Parameter(data=torch.zeros(self.outdims, 2), requires_grad=True)
    self.mask_log_var = torch.nn.Parameter(data=torch.zeros(self.outdims), requires_grad=True)

    # Grid is fixed and untrainable, so we register it as a buffer
    self.register_buffer("grid", self.make_mask_grid(outdims, w, h))

    self.features = nn.Parameter(torch.Tensor(1, c, 1, outdims))
    self.features.data.normal_(1.0 / c, 0.01)

    if scale:
        self.scale_param = nn.Parameter(torch.ones(outdims))
        self.scale_param.data.normal_(1.0, 0.01)
    else:
        self.register_buffer("scale_param", torch.ones(outdims))  # Non-trainable

    if bias:
        self.bias = nn.Parameter(torch.zeros(outdims))
    else:
        self.register_buffer("bias", torch.zeros(outdims))  # Non-trainable

    self.initialize(mean_activity)

forward

forward(x: Tensor) -> Tensor
Source code in openretina/modules/readout/factorized_gaussian.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    masks = self.masks
    y = torch.einsum("nctwh,whd->nctd", x, masks)

    if self.positive:
        self.features.data.clamp_(0)
    y = (y * self.features).sum(1)

    y = self.nonlinearity_function(y * self.scale_param + self.bias)
    return y

regularizer

regularizer(
    reduction: Literal["sum", "mean", None] = "sum",
) -> Tensor
Source code in openretina/modules/readout/factorized_gaussian.py
def regularizer(self, reduction: Literal["sum", "mean", None] = "sum") -> torch.Tensor:
    reg = (
        self.mask_l1(average=reduction == "mean") * self.mask_l1_reg
        + self.feature_l1(average=reduction == "mean") * self.feature_weights_l1_reg
    )
    return reg

FactorizedReadout

Freeform factorized readout: a learned sparse spatial mask paired with per-neuron feature weights.

FactorizedReadout

FactorizedReadout(
    in_shape: tuple[int, int, int, int],
    outdims: int,
    mask_l1_reg: float,
    weights_l1_reg: float,
    laplace_mask_reg: float,
    mask_size: int | tuple[int, int],
    readout_bias: bool = False,
    weights_constraint: Literal["abs", "norm", "absnorm"]
    | None = None,
    mask_constraint: Literal["abs"] | None = None,
    init_mask: Tensor | None = None,
    init_weights: Tensor | None = None,
    init_scales: Sequence[tuple[float, float]]
    | None = None,
    mean_activity: Float[Tensor, " outdims"] | None = None,
)

Bases: Readout

The canonical factorized readout module: each neuron's output is the dot product of a learned 2D spatial mask and feature weights.

This module implements the general Factorized (a.k.a. "Klindt") Readout, where—for each neuron—the spatial integration is performed via a freeform, unconstrained (but sparse) mask, and the stimulus dimensions (e.g., features, channels, time) are combined by a separate learned vector of feature weights. The spatial mask is independently learned for every neuron without any restriction to a particular functional form, but with sparsity penalties.

First introduced in Klindt et al., 2017: https://doi.org/10.48550/arXiv.1711.02653

Key notes
  • Unlike parametric-masked readouts (see GaussianMaskReadout), this class allows the spatial mask to take any shape, offering maximum expressive power for fitting the spatial receptive field.

  • Typical regularizations include sparsity (L1), Laplace smoothness penalties, and optional constraints (non-negativity, normalization) on the mask or weights.
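The factorization itself is easy to state in plain Python. In the hypothetical sketch below, one neuron's full readout weight is the outer product of its flattened spatial mask and its feature-weight vector, so the response reduces to a masked spatial sum per channel, recombined by the feature weights:

```python
def factorized_response(x, mask, feat):
    """Response of one neuron under a factorized readout.
    x: feature map as [channel][pixel] (spatial dims flattened)
    mask: flattened spatial mask; feat: per-channel feature weights."""
    n_ch, n_px = len(x), len(x[0])
    return sum(
        feat[c] * sum(mask[p] * x[c][p] for p in range(n_px))
        for c in range(n_ch)
    )

# 2 channels over a flattened 2x2 map, one neuron
x = [[1.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0]]
mask = [0.5, 0.5, 0.0, 0.0]
feat = [1.0, 3.0]
print(factorized_response(x, mask, feat))  # 1.0*(0.5*1.0) + 3.0*(0.5*2.0) = 3.5
```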

Initializes the FactorizedReadout module: (2D spatial mask + feature weights) per cell.

Args:
    in_shape: The shape of the input tensor (c, t, w, h).
    outdims (int): Number of output neurons.
    mask_l1_reg (float): L1 regularization strength for the mask.
    weights_l1_reg (float): L1 regularization strength for the weights.
    laplace_mask_reg (float): Laplace regularization strength for the mask.
    mask_size (int | tuple[int, int]): Size of the mask as (height, width), or a single int for a square mask.
    readout_bias (bool, optional): If True, includes a bias in the readout. Defaults to False.
    weights_constraint (str, optional): Constraint for the weights ('abs', 'norm', 'absnorm'). Defaults to None.
    mask_constraint (str, optional): Constraint for the mask ('abs'). Defaults to None.
    init_mask (torch.Tensor, optional): Initial mask tensor. Defaults to None.
    init_weights (torch.Tensor, optional): Initial weights tensor. Defaults to None.
    init_scales (Sequence[tuple[float, float]], optional): Initialization scales for mask and weights. Defaults to None.
    mean_activity (Float[torch.Tensor, " outdims"] | None): Mean activity of neurons. Defaults to None.

Raises:
    ValueError: If neither init_mask nor init_scales is provided.

Source code in openretina/modules/readout/factorized.py
def __init__(
    self,
    in_shape: tuple[int, int, int, int],
    outdims: int,
    mask_l1_reg: float,
    weights_l1_reg: float,
    laplace_mask_reg: float,
    mask_size: int | tuple[int, int],
    readout_bias: bool = False,
    weights_constraint: Literal["abs", "norm", "absnorm"] | None = None,
    mask_constraint: Literal["abs"] | None = None,
    init_mask: torch.Tensor | None = None,
    init_weights: torch.Tensor | None = None,
    init_scales: Sequence[tuple[float, float]] | None = None,
    mean_activity: Float[torch.Tensor, " outdims"] | None = None,
):
    """
    Initializes the FactorizedReadout module : (2d spatial mask + feature weights) / cell.
    Args:
        in_shape: The shape of the input tensor (c, t, w, h).
        outdims (int): Number of output neurons.
        mask_l1_reg (float): L1 regularization strength for mask.
        weights_l1_reg (float): L1 regularization strength for weights.
        laplace_mask_reg (float): Laplace regularization strength for mask.
        mask_size (int | Tuple[int, int]): Size of the mask (height, width) or (height).
        readout_bias (bool, optional): If True, includes bias in readout. Defaults to False.
        weights_constraint (Optional[str], optional): Constraint for weights. Defaults to None.
        mask_constraint (Optional[str], optional): Constraint for mask. Defaults to None.
        init_mask (Optional[torch.Tensor], optional): Initial mask tensor. Defaults to None.
        init_weights (Optional[torch.Tensor], optional): Initial weights tensor. Defaults to None.
        init_scales (Optional[Sequence[Tuple[float, float]]], optional): Initialization scales for mask
        and weights. Defaults to None.
        mean_activity (Float[torch.Tensor, " outdims"] | None): Mean activity of neurons. Defaults to None.
    Raises:
        ValueError: If neither init_mask nor init_scales is provided.
    """
    super().__init__()

    self.outdims = outdims
    self.in_shape = in_shape
    channels, _, _, _ = in_shape
    self.reg = [mask_l1_reg, weights_l1_reg, laplace_mask_reg]
    self.readout_bias = readout_bias
    self.weights_constraint = weights_constraint
    self.mask_constraint = mask_constraint
    self._input_weights_regularizer_spatial = FlatLaplaceL23dnorm(padding=0)
    num_neurons = outdims

    if isinstance(mask_size, int):
        num_mask_pixels = mask_size**2
        self.mask_size = (mask_size, mask_size)
    else:
        h, w = mask_size
        num_mask_pixels = h * w
        self.mask_size = mask_size

    if init_mask is not None:
        assert num_neurons == init_mask.shape[0], "Number of neurons in init_mask does not match outdims"
        h, w = self.mask_size
        H, W = init_mask.shape[2], init_mask.shape[3]
        h_offset = (H - h) // 2
        w_offset = (W - w) // 2

        # Crop center region
        cropped = init_mask[:, :, h_offset : h_offset + h, w_offset : w_offset + w]  # shape: (num_neurons, 1, h, w)

        # Reshape to (num_mask_pixels, num_neurons)
        reshaped = cropped.reshape(num_neurons, -1).T  # shape: (h*w, num_neurons)

        # Convert to tensor and register as parameter
        self.mask_weights = nn.Parameter(torch.tensor(reshaped, dtype=torch.float32))
    else:
        if init_scales is None:
            raise ValueError("Either init_mask or init_scales must be provided")
        mean, std = init_scales[0]
        self.mask_weights = nn.Parameter(torch.normal(mean=mean, std=std, size=(num_mask_pixels, num_neurons)))

    if init_weights is not None:
        self.feature_weights = nn.Parameter(init_weights)
    else:
        assert init_scales is not None
        mean, std = init_scales[1]
        self.feature_weights = nn.Parameter(torch.normal(mean=mean, std=std, size=(channels, num_neurons)))

    if readout_bias:
        self.bias = nn.Parameter(torch.zeros(outdims))
    else:
        self.register_buffer("bias", torch.zeros(outdims))  # Non-trainable

    self.initialize(mean_activity)

forward

forward(x: Tensor, **kwargs: Any) -> Tensor
Source code in openretina/modules/readout/factorized.py
def forward(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    self.apply_constraints()
    B, C, T, H, W = x.shape
    h, w = self.mask_size
    assert H == h and W == w
    x_flat = x.view(B, C, T, -1).permute(0, 2, 1, 3)
    masked = torch.matmul(x_flat, self.mask_weights)
    masked = masked.permute(0, 1, 3, 2)
    out = (masked * self.feature_weights.T.unsqueeze(0).unsqueeze(0)).sum(dim=3)
    if self.readout_bias:
        out = out + self.bias
    return F.softplus(out)

regularizer

regularizer(
    reduction: Optional[Literal["sum", "mean"]] = None,
) -> Tensor
Source code in openretina/modules/readout/factorized.py
def regularizer(self, reduction: Optional[Literal["sum", "mean"]] = None) -> torch.Tensor:
    mask_r = self.reg[0] * torch.mean(torch.sum(torch.abs(self.mask_weights), dim=0))
    wt_r = self.reg[1] * torch.mean(torch.sum(torch.abs(self.feature_weights), dim=0))
    reshaped_masked_weights = self.mask_weights.reshape(-1, 1, 1, self.mask_size[0], self.mask_size[1])
    laplace_mask_r = self.reg[2] * self._input_weights_regularizer_spatial(reshaped_masked_weights, avg=False)
    if reduction == "mean":
        return mask_r + wt_r + laplace_mask_r / 3
    else:
        return mask_r + wt_r + laplace_mask_r

LNPReadout

Linear-Nonlinear-Poisson readout mapping the input directly to spike rates.

LNPReadout

LNPReadout(
    in_shape: Int[tuple, "channel time height width"],
    outdims: int,
    mean_activity: Float[Tensor, " outdims"] | None = None,
    smooth_weight: float = 0.0,
    sparse_weight: float = 0.0,
    smooth_regularizer: str = "LaplaceL2norm",
    laplace_padding=None,
    nonlinearity: str = "exp",
    bias: bool = False,
    **kwargs,
)

Bases: Readout

Linear-Nonlinear-Poisson (LNP) readout. To use it as a full LNP model, combine this readout with a DummyCore that passes the input through unchanged.
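The LNP pipeline can be sketched end to end with hypothetical, dependency-free helpers (not openretina API): a linear filter, the pointwise nonlinearity (here exp, the readout's default), and a Poisson draw from the resulting rate:

```python
import math
import random

def lnp_rate(stimulus, kernel, nonlinearity=math.exp):
    """Linear stage (inner product with the kernel) followed by a pointwise
    nonlinearity; the result is the rate of an assumed Poisson spike count."""
    drive = sum(k * s for k, s in zip(kernel, stimulus))
    return nonlinearity(drive)

def poisson_sample(rate):
    """Draw a Poisson spike count at the given rate (inverse-CDF method)."""
    u, k, p = random.random(), 0, math.exp(-rate)
    cdf = p
    while u > cdf and p > 1e-12:
        k += 1
        p *= rate / k
        cdf += p
    return k

rate = lnp_rate([1.0, -0.5, 0.25], [0.2, 0.4, 0.8])  # exp(0.2 - 0.2 + 0.2)
print(round(rate, 4))  # 1.2214
spikes = poisson_sample(rate)  # stochastic spike count, >= 0
```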

Source code in openretina/modules/readout/linear_nonlinear_poison.py
def __init__(
    self,
    in_shape: Int[tuple, "channel time height width"],
    outdims: int,
    mean_activity: Float[torch.Tensor, " outdims"] | None = None,
    smooth_weight: float = 0.0,
    sparse_weight: float = 0.0,
    smooth_regularizer: str = "LaplaceL2norm",
    laplace_padding=None,
    nonlinearity: str = "exp",
    bias: bool = False,
    **kwargs,
):
    super().__init__()
    self.smooth_weight = smooth_weight
    self.sparse_weight = sparse_weight
    self.kernel_size = tuple(in_shape[2:])
    self.in_channels = in_shape[0]
    self.n_neurons = outdims
    self.nonlinearity = torch.exp if nonlinearity == "exp" else F.__dict__[nonlinearity]

    self.inner_product_kernel = nn.Conv3d(
        in_channels=self.in_channels,
        out_channels=self.n_neurons,  # Each neuron gets its own kernel
        kernel_size=(1, *self.kernel_size),  # Not using time
        bias=bias,
        stride=1,
    )

    nn.init.xavier_normal_(self.inner_product_kernel.weight.data)

    regularizer_config = (
        dict(padding=laplace_padding, kernel=self.kernel_size)
        if smooth_regularizer == "GaussianLaplaceL2"
        else dict(padding=laplace_padding)
    )

    self._smooth_reg_fn = regularizers.__dict__[smooth_regularizer](**regularizer_config)
    self.initialize(mean_activity)

forward

forward(
    x: Float[Tensor, "batch channels t h w"],
    data_key=None,
    **kwargs,
)
Source code in openretina/modules/readout/linear_nonlinear_poison.py
def forward(self, x: Float[torch.Tensor, "batch channels t h w"], data_key=None, **kwargs):
    out = self.inner_product_kernel(x)
    out = self.nonlinearity(out)
    out = rearrange(out, "batch neurons t 1 1 -> batch t neurons")
    return out
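The LNP forward pass reduces to a per-neuron full-field linear filter followed by a pointwise nonlinearity. A minimal sketch with hypothetical dimensions, using `squeeze`/`permute` in place of the library's `einops.rearrange` call:

```python
import torch
import torch.nn as nn

# Hypothetical shapes: batch 2, 4 core channels, 12 time bins, 5x5 spatial grid, 3 neurons
B, C, T, H, W = 2, 4, 12, 5, 5
n_neurons = 3
x = torch.randn(B, C, T, H, W)

# One (1, H, W) kernel per neuron: a full-field spatial filter with no temporal extent
kernel = nn.Conv3d(C, n_neurons, kernel_size=(1, H, W), bias=False)

out = torch.exp(kernel(x))                           # (B, n_neurons, T, 1, 1); exp keeps rates positive
out = out.squeeze(-1).squeeze(-1).permute(0, 2, 1)   # (B, T, n_neurons)

print(out.shape)  # torch.Size([2, 12, 3])
```

Because the kernel spans the whole spatial grid, the convolution collapses height and width to 1 and the output is one rate per neuron per time bin.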

regularizer

regularizer(**kwargs)
Source code in openretina/modules/readout/linear_nonlinear_poison.py
def regularizer(self, **kwargs):
    return self.smooth_weight * self.laplace() + self.sparse_weight * self.weights_l1()

Multi-Session Readouts

Wrappers that manage one readout instance per recording session.

MultiReadoutBase

MultiReadoutBase(
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    base_readout: type[Readout] | None = None,
    mean_activity_dict: dict[str, Float[Tensor, " neurons"]]
    | None = None,
    clone_readout=False,
    readout_reg_avg: bool = False,
    **kwargs,
)

Bases: ModuleDict

Base class for MultiReadouts: a dictionary mapping data keys to the readouts for the corresponding datasets.

Adapted from neuralpredictors. Original code at: https://github.com/sinzlab/neuralpredictors/blob/v0.3.0.pre/neuralpredictors/layers/readouts/multi_readout.py

PARAMETER DESCRIPTION
in_shape

shape of the core's output, as (channels, time, height, width). Shared across all datasets.

TYPE: tuple

n_neurons_dict

dictionary mapping each data_key to the corresponding dataset's number of neurons.

TYPE: dict

base_readout

base readout class. If None, self._base_readout_cls must be set manually in the inheriting class's definition.

TYPE: type[Readout] DEFAULT: None

mean_activity_dict

dictionary mapping each data_key to the corresponding dataset's mean responses, used to initialize the readout bias. If None, the bias is initialized to 0.

TYPE: dict DEFAULT: None

clone_readout

whether to clone the first data_key's readout for all other readouts, allowing only a per-readout scale and offset. This is a simple way to enforce parameter sharing between readouts.

TYPE: bool DEFAULT: False

readout_reg_avg

whether to average (rather than sum) the readout regularization terms.

TYPE: bool DEFAULT: False

**kwargs

additional keyword arguments passed to the base_readout's constructor.

DEFAULT: {}

Source code in openretina/modules/readout/multi_readout.py
def __init__(
    self,
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    base_readout: type[Readout] | None = None,
    mean_activity_dict: dict[str, Float[torch.Tensor, " neurons"]] | None = None,
    clone_readout=False,
    readout_reg_avg: bool = False,
    **kwargs,
):
    # The `base_readout` can be overridden only if the static property `_base_readout_cls` is not set
    if self._base_readout_cls is None:
        assert base_readout is not None, (
            "Argument `base_readout` must be provided if the class variable `_base_readout_cls` is not set"
        )
        self._base_readout_cls = base_readout

    self._readout_kwargs = kwargs
    self._in_shape = in_shape
    self.readout_reg_avg = readout_reg_avg
    self.readout_reg_reduction: Literal["mean", "sum"] = "mean" if readout_reg_avg else "sum"
    super().__init__()

    for i, data_key in enumerate(n_neurons_dict):
        mean_activity = mean_activity_dict[data_key] if mean_activity_dict is not None else None

        if i == 0 or clone_readout is False:
            self.add_module(
                data_key,
                self._base_readout_cls(
                    in_shape=in_shape,
                    outdims=n_neurons_dict[data_key],
                    mean_activity=mean_activity,
                    **kwargs,
                ),
            )
            original_readout = data_key
        elif i > 0 and clone_readout is True:
            original_readout_object: Readout = self[original_readout]  # type: ignore
            self.add_module(data_key, ClonedReadout(original_readout_object))

    self.initialize(mean_activity_dict)

forward

forward(
    *args, data_key: str | None = None, **kwargs
) -> Tensor
Source code in openretina/modules/readout/multi_readout.py
def forward(self, *args, data_key: str | None = None, **kwargs) -> torch.Tensor:
    if data_key is None:
        warnings.warn(
            "No data key provided, returning concatenated responses from all readouts",
            stacklevel=2,
            category=UserWarning,
        )
        readout_responses = []
        for readout_key in self.readout_keys():
            resp = self[readout_key](*args, **kwargs)
            readout_responses.append(resp)
        response = torch.cat(readout_responses, dim=-1)
    else:
        response = self[data_key](*args, **kwargs)
    return response
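The dispatch semantics above can be sketched with toy per-session modules. The session names and `nn.Linear` stand-ins below are hypothetical, not the real Readout classes:

```python
import torch
import torch.nn as nn

# Toy stand-ins for per-session readouts with different neuron counts
readouts = nn.ModuleDict({
    "session_a": nn.Linear(8, 3),  # 3 neurons recorded in session A
    "session_b": nn.Linear(8, 5),  # 5 neurons recorded in session B
})

features = torch.randn(2, 8)

# With a data_key: responses of that session's neurons only
resp_a = readouts["session_a"](features)  # (2, 3)

# Without a data_key: concatenate all sessions' responses along the neuron axis
resp_all = torch.cat([readouts[k](features) for k in readouts], dim=-1)  # (2, 8)

print(resp_a.shape, resp_all.shape)
```

Since the no-key path concatenates over all sessions, the neuron ordering follows the insertion order of the session keys.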

add_sessions

add_sessions(
    n_neurons_dict: dict[str, int],
    mean_activity_dict: dict[str, Float[Tensor, " neurons"]]
    | None = None,
) -> None

Adds new sessions to an existing readout wrapper. Individual readouts should override this method to add additional checks.

Source code in openretina/modules/readout/multi_readout.py
def add_sessions(
    self,
    n_neurons_dict: dict[str, int],
    mean_activity_dict: dict[str, Float[torch.Tensor, " neurons"]] | None = None,
) -> None:
    """Wrapper method to add new sessions to the readout wrapper.
    Can be called to add new sessions to an existing readout wrapper.
    Individual readouts should override this method to add additional checks.
    """
    self._add_sessions(n_neurons_dict, mean_activity_dict)
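The two dictionaries passed to add_sessions must agree on keys and neuron counts. A small sketch with hypothetical session names (the wrapper instance itself is assumed to exist, so the call is shown commented out):

```python
import torch

# Hypothetical new sessions: data_key -> neuron count, plus matching mean activities
n_neurons_dict = {"session_c": 42, "session_d": 17}
mean_activity_dict = {key: torch.rand(n) for key, n in n_neurons_dict.items()}

# multi_readout.add_sessions(n_neurons_dict, mean_activity_dict)  # hypothetical existing wrapper

# Each mean-activity vector must have one entry per neuron in its session
ok = all(mean_activity_dict[k].shape == (n,) for k, n in n_neurons_dict.items())
print(ok)  # True
```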

regularizer

regularizer(
    data_key: str | None = None,
    reduction: Literal["sum", "mean"] | None = None,
)
Source code in openretina/modules/readout/multi_readout.py
def regularizer(self, data_key: str | None = None, reduction: Literal["sum", "mean"] | None = None):
    if reduction is None:
        reduction = self.readout_reg_reduction
    if data_key is None and len(self) == 1:
        data_key = list(self.keys())[0]
    elif data_key is None:
        raise ValueError("data_key is required when there are multiple sessions")
    return self[data_key].regularizer(reduction=reduction)

MultiGaussianMaskReadout

MultiGaussianMaskReadout(
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    scale: bool,
    bias: bool,
    gaussian_mean_scale: float,
    gaussian_var_scale: float,
    positive: bool,
    mask_l1_reg: float = 1.0,
    feature_weights_l1_reg: float = 1.0,
    readout_reg_avg: bool = False,
    mean_activity_dict: dict[str, Float[Tensor, " neurons"]]
    | None = None,
)

Bases: MultiReadoutBase

Multi-session version of the GaussianMaskReadout, a factorized Gaussian readout.

Source code in openretina/modules/readout/multi_readout.py
def __init__(
    self,
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    scale: bool,
    bias: bool,
    gaussian_mean_scale: float,
    gaussian_var_scale: float,
    positive: bool,
    mask_l1_reg: float = 1.0,
    feature_weights_l1_reg: float = 1.0,
    readout_reg_avg: bool = False,
    mean_activity_dict: dict[str, Float[torch.Tensor, " neurons"]] | None = None,
):
    super().__init__(
        in_shape=in_shape,
        n_neurons_dict=n_neurons_dict,
        mean_activity_dict=mean_activity_dict,
        scale=scale,
        bias=bias,
        gaussian_mean_scale=gaussian_mean_scale,
        gaussian_var_scale=gaussian_var_scale,
        positive=positive,
        mask_l1_reg=mask_l1_reg,
        feature_weights_l1_reg=feature_weights_l1_reg,
        readout_reg_avg=readout_reg_avg,
    )

MultiFactorizedReadout

MultiFactorizedReadout(
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    mask_l1_reg: float,
    weights_l1_reg: float,
    laplace_mask_reg: float,
    readout_bias: bool = False,
    weights_constraint: Literal["abs", "norm", "absnorm"]
    | None = None,
    mask_constraint: Literal["abs"] | None = None,
    init_mask: Optional[Tensor] = None,
    init_weights: Optional[Tensor] = None,
    init_scales: Optional[Iterable[Iterable[float]]] = None,
    readout_reg_avg: bool = False,
    mean_activity_dict: dict[str, Float[Tensor, " neurons"]]
    | None = None,
)

Bases: MultiReadoutBase

Multi-session version of the classic factorized readout.

Source code in openretina/modules/readout/multi_readout.py
def __init__(
    self,
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    mask_l1_reg: float,
    weights_l1_reg: float,
    laplace_mask_reg: float,
    readout_bias: bool = False,
    weights_constraint: Literal["abs", "norm", "absnorm"] | None = None,
    mask_constraint: Literal["abs"] | None = None,
    init_mask: Optional[torch.Tensor] = None,
    init_weights: Optional[torch.Tensor] = None,
    init_scales: Optional[Iterable[Iterable[float]]] = None,
    readout_reg_avg: bool = False,
    mean_activity_dict: dict[str, Float[torch.Tensor, " neurons"]] | None = None,
):
    mask_size = in_shape[2:]
    super().__init__(
        in_shape=in_shape,
        n_neurons_dict=n_neurons_dict,
        mask_size=mask_size,
        mask_l1_reg=mask_l1_reg,
        weights_l1_reg=weights_l1_reg,
        laplace_mask_reg=laplace_mask_reg,
        readout_bias=readout_bias,
        weights_constraint=weights_constraint,
        mask_constraint=mask_constraint,
        init_mask=init_mask,
        init_weights=init_weights,
        init_scales=init_scales,
        readout_reg_avg=readout_reg_avg,
        mean_activity_dict=mean_activity_dict,
    )

MultiSampledGaussianReadout

MultiSampledGaussianReadout(
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    bias: bool,
    init_mu_range: float,
    init_sigma_range: float,
    batch_sample: bool = True,
    align_corners: bool = True,
    gauss_type: Literal["full", "iso"] = "full",
    grid_mean_predictor=None,
    shared_features=None,
    shared_grid=None,
    init_grid=None,
    gamma: float = 1.0,
    reg_avg: bool = False,
    nonlinearity_function: Callable[
        [Tensor], Tensor
    ] = softplus,
    mean_activity_dict: dict[str, Float[Tensor, " neurons"]]
    | None = None,
)

Bases: MultiReadoutBase

Multi-session version of the sampled-point Gaussian readout.

Source code in openretina/modules/readout/multi_readout.py
def __init__(
    self,
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    bias: bool,
    init_mu_range: float,
    init_sigma_range: float,
    batch_sample: bool = True,
    align_corners: bool = True,
    gauss_type: Literal["full", "iso"] = "full",
    grid_mean_predictor=None,
    shared_features=None,
    shared_grid=None,
    init_grid=None,
    gamma: float = 1.0,
    reg_avg: bool = False,
    nonlinearity_function: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.softplus,
    mean_activity_dict: dict[str, Float[torch.Tensor, " neurons"]] | None = None,
):
    super().__init__(
        in_shape=in_shape,
        n_neurons_dict=n_neurons_dict,
        bias=bias,
        init_mu_range=init_mu_range,
        init_sigma_range=init_sigma_range,
        batch_sample=batch_sample,
        align_corners=align_corners,
        gauss_type=gauss_type,
        grid_mean_predictor=grid_mean_predictor,
        shared_features=shared_features,
        shared_grid=shared_grid,
        init_grid=init_grid,
        gamma_readout=gamma,
        readout_reg_avg=reg_avg,
        mean_activity_dict=mean_activity_dict,
    )

    self.nonlinearity = nonlinearity_function

MultipleLNPReadout

MultipleLNPReadout(
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    **kwargs,
)

Bases: MultiReadoutBase

Multi-session Linear-Nonlinear-Poisson (LNP) readout. To use it as a full LNP model, pair this readout with a DummyCore that passes the input through unchanged.

Source code in openretina/modules/readout/multi_readout.py
def __init__(
    self,
    in_shape: tuple[int, int, int, int],
    n_neurons_dict: dict[str, int],
    **kwargs,
):
    super().__init__(
        in_shape=in_shape,
        n_neurons_dict=n_neurons_dict,
        **kwargs,
    )