Linear-Nonlinear Models
SingleCellSeparatedLNP
SingleCellSeparatedLNP(
in_shape: Int[tuple, "channel time height width"],
rf_location: Optional[Int[tuple, "y x"]] = None,
spat_kernel_size: Int[tuple, "height width"] = (15, 15),
learning_rate: float = 0.001,
rank: int = 1,
smooth_weight_spat: float = 0.0,
smooth_weight_temp: float = 0.0,
sparse_weight: float = 0.0,
smooth_regularizer_spat: str = "LaplaceL2norm",
smooth_regularizer_temp: str = "Laplace1d",
smooth_regularizer: str = "LaplaceL2norm",
laplace_padding=None,
nonlinearity: str = "exp",
normalize_weights: bool = True,
loss=None,
validation_loss=None,
**kwargs,
)
Bases: LightningModule
Single-cell, separable LNP model implemented as a PyTorch LightningModule.
This model implements an LNP-style encoding model where the linear filter is constrained to be space-time separable.
The distinguishing feature of this model is that it is "single-cell": it predicts the activity of a single neuron in a single session, unlike Core-Readout models, which predict many neurons at once.
Spatial filtering is performed with a 3D convolution whose kernel size is (1, H, W), i.e. no temporal mixing in the first stage. Temporal filtering is performed with a second 3D convolution spanning the full input time axis (T, 1, 1). The separable rank controls the number of spatial and temporal components used.
The module crops the input around a receptive-field location so that the spatial kernel operates on a local patch rather than the full image.
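The two filtering stages above can be sketched with plain `torch.nn.Conv3d` layers. This is a minimal illustration of the separable space-time factorization, not the library's actual implementation; all shapes and the `exp` output nonlinearity are assumptions for the example.

```python
import torch

# Hypothetical shapes: input channels, input time, spatial patch size,
# temporal kernel length, and separable rank.
C, T_in, H, W = 1, 30, 15, 15
T_k, rank = 20, 1

# Stage 1: spatial conv with kernel (1, H, W) -- no temporal mixing.
space_conv = torch.nn.Conv3d(C, rank, kernel_size=(1, H, W), bias=False)
# Stage 2: temporal conv with kernel (T_k, 1, 1) -- mixes only across time.
time_conv = torch.nn.Conv3d(rank, 1, kernel_size=(T_k, 1, 1), bias=True)

x = torch.randn(2, C, T_in, H, W)   # (batch, C, T, H, W)
s = space_conv(x)                   # (2, rank, 30, 1, 1): spatial dims collapse
t = time_conv(s)                    # (2, 1, 11, 1, 1): time shrinks by T_k - 1
# Apply an output nonlinearity and reshape to (batch, time, neurons).
out = torch.exp(t).squeeze(-1).squeeze(-1).permute(0, 2, 1)   # (2, 11, 1)
```

With `rank > 1`, stage 1 would produce `rank` spatial maps and stage 2 would combine them with `rank` temporal filters, which is what makes the factorization progressively more expressive.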
Parameters
in_shape:
Tuple (channels, time, height, width) describing the expected stimulus shape
excluding batch dimension.
rf_location:
Optional (y, x) center location (in input pixel coordinates) used when
cropping to a spatial patch of size spat_kernel_size. If None, defaults
to the spatial center of the input.
spat_kernel_size:
(height, width) of the spatial kernel / crop window.
learning_rate:
Learning rate used by the optimizer.
rank:
The separable rank specifies the number of spatial/temporal filter pairs that are learned.
rank=1 corresponds to a single spatial filter and a single temporal filter.
The maximum rank corresponds to a full 3D convolution.
smooth_weight_spat:
Weight for spatial smoothness regularization (applied to space_conv).
smooth_weight_temp:
Weight for temporal smoothness regularization (applied to time_conv).
sparse_weight:
Weight for L1 sparsity penalty on both spatial and temporal kernels.
smooth_regularizer_spat:
Name of the spatial regularizer class in regularizers.__dict__.
Examples in this code path include "LaplaceL2norm" or "GaussianLaplaceL2".
smooth_regularizer_temp:
Name of the temporal regularizer class in regularizers.__dict__.
smooth_regularizer:
Currently stored but not used directly in this implementation (kept for API
compatibility / future use).
laplace_padding:
Passed through to regularizer constructors as padding=.... For GaussianLaplaceL2,
a kernel=... argument is also supplied.
nonlinearity:
Output nonlinearity applied after the temporal stage. If "parametrized_softplus",
uses ParametrizedSoftplus(). Otherwise, uses torch.nn.functional.<nonlinearity>
via F.__dict__[nonlinearity] (e.g. "exp", "softplus", ...).
normalize_weights:
If True, renormalizes spatial and temporal kernels to unit norm at every
forward pass (in-place, under no_grad).
loss:
Training loss. Defaults to PoissonLoss3d() if None.
validation_loss:
Validation metric/loss. Defaults to CorrelationLoss3d(avg=True) if None.
During training/validation, correlation is logged as the negative of this loss.
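The nonlinearity lookup described above can be sketched as a dictionary access on `torch.nn.functional`. The snippet uses `"softplus"`, which is known to exist there; treat the lookup mechanism itself as an illustration of the documented behavior rather than the library's exact code.

```python
import math
import torch
import torch.nn.functional as F

# Resolve the nonlinearity by name, as the docstring describes.
nonlinearity = "softplus"
fn = F.__dict__[nonlinearity]          # same as getattr(F, "softplus")
y = fn(torch.zeros(3))                 # softplus(0) = ln(2) ~ 0.6931
```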
Input / Output shapes
Forward input x:
Tensor of shape (batch, channels, time, height, width).
Forward output:
Tensor of shape (batch, time, neurons) where neurons=1.
Notes
Cropping: If spat_kernel_size does not match the full input spatial size,
the module crops a patch centered at rf_location before applying convolutions.
Cropping is boundary-safe (it clips at edges).
Regularization: regularizer() = laplace() + sparse_weight * weights_l1() where laplace() applies the configured smoothness penalties to spatial and temporal kernels, and weights_l1() applies L1 to both kernels.
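The regularizer composition can be sketched as follows. The kernels, weights, and the quadratic `laplace_penalty` stand-in are all hypothetical; the real model uses the configured regularizer classes (e.g. `LaplaceL2norm`), not this placeholder.

```python
import torch

# Hypothetical stand-ins for the space_conv / time_conv weight tensors.
spat_kernel = torch.randn(1, 1, 1, 15, 15)
temp_kernel = torch.randn(1, 1, 20, 1, 1)

smooth_weight_spat, smooth_weight_temp, sparse_weight = 0.1, 0.1, 0.01

def laplace_penalty(k: torch.Tensor) -> torch.Tensor:
    # Placeholder smoothness penalty; the model uses the configured
    # regularizer class (e.g. LaplaceL2norm) here instead.
    return (k ** 2).sum()

def weights_l1(average: bool = True) -> torch.Tensor:
    # L1 penalty over both kernels, mean or sum per kernel.
    agg = torch.mean if average else torch.sum
    return agg(spat_kernel.abs()) + agg(temp_kernel.abs())

def regularizer() -> torch.Tensor:
    # regularizer() = laplace() + sparse_weight * weights_l1()
    laplace = (smooth_weight_spat * laplace_penalty(spat_kernel)
               + smooth_weight_temp * laplace_penalty(temp_kernel))
    return laplace + sparse_weight * weights_l1()

reg = regularizer()
```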
Logged metrics
Training:

- regularization_loss_core
- train_total_loss
- train_loss
- train_correlation

Validation:

- val_loss
- val_regularization_loss
- val_total_loss
- val_correlation
Source code in openretina/models/linear_nonlinear.py
crop_input
crop_input(
input_tensor: Float[Tensor, "batch channels t h w"],
)
Crops the input tensor to the size of the receptive field.
Source code in openretina/models/linear_nonlinear.py
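A boundary-safe crop of this kind can be sketched as below. This is a hypothetical standalone helper illustrating the documented behavior (center the window on `rf_location`, clip at the image edges), not the module's method.

```python
import torch

def crop_input(x, rf_location, spat_kernel_size=(15, 15)):
    # x: (batch, channels, time, height, width)
    *_, H, W = x.shape
    kh, kw = spat_kernel_size
    y, xc = rf_location
    # Center the window on (y, xc), then clip so it stays inside the image.
    top = min(max(y - kh // 2, 0), H - kh)
    left = min(max(xc - kw // 2, 0), W - kw)
    return x[..., top:top + kh, left:left + kw]

x = torch.randn(1, 1, 5, 20, 20)
patch = crop_input(x, rf_location=(0, 0))   # clipped to the top-left corner
```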
weights_l1
weights_l1(average: bool = True)
Returns L1 regularization across all weight dimensions.

| PARAMETER | DESCRIPTION |
|---|---|
| `average` | Use mean of weights instead of sum. Defaults to `True`. TYPE: `bool` |
Source code in openretina/models/linear_nonlinear.py
normalize_kernels
normalize_kernels()
Normalizes the kernels to have unit norm.
Source code in openretina/models/linear_nonlinear.py
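Unit-norm renormalization of this kind can be sketched as an in-place rescale under `no_grad`. The `spat_kernel` parameter here is a hypothetical stand-in for the model's spatial kernel.

```python
import torch

# Stand-in for a learnable kernel; in the model this is a conv weight.
spat_kernel = torch.nn.Parameter(torch.randn(1, 1, 1, 15, 15))

# In-place renormalization to unit norm, done without tracking gradients.
with torch.no_grad():
    spat_kernel /= spat_kernel.norm()
```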