Loss Functions
Loss functions for training retinal models. All losses automatically handle the temporal lag between predictions and targets: the model may produce fewer time steps than the target because temporal convolutions shorten the time axis.
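The NumPy sketch below shows one way such a lag can be absorbed. It assumes the model drops the leading frames (as an unpadded, causal temporal convolution does), so the last output.shape[1] target frames line up with the output. align_time is a hypothetical helper for illustration, not part of openretina.

```python
import numpy as np


def align_time(output: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Crop a (batch, time, neurons) target to the output's time length.

    Hypothetical helper: assumes the model drops the *leading* frames,
    so the last output.shape[1] target frames line up with the output.
    """
    lag = target.shape[1] - output.shape[1]
    if lag < 0:
        raise ValueError("output has more time steps than target")
    return target[:, lag:, :]


# A temporal kernel of length 5 (no padding) removes 5 - 1 = 4 frames.
target = np.arange(2 * 10 * 3, dtype=float).reshape(2, 10, 3)
output = np.zeros((2, 6, 3))
aligned = align_time(output, target)
print(aligned.shape)  # (2, 6, 3)
```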
Poisson Losses
PoissonLoss3d
PoissonLoss3d(
bias: float = 1e-16,
per_neuron: bool = False,
avg: bool = False,
)
Bases: Module
Poisson negative log-likelihood loss for 3D (batch, time, neurons) predictions.
Handles temporal lag between target and output automatically.
Source code in openretina/modules/losses/poisson.py
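For reference, the Poisson negative log-likelihood (dropping the constant log(y!) term) of a predicted rate given an observed count y is rate - y * log(rate + bias). The NumPy sketch below assumes that per_neuron keeps one value per neuron and that avg switches the reduction from a sum to a mean. It is an illustration of the formula, not the openretina source.

```python
import numpy as np


def poisson_loss_3d(output, target, bias=1e-16, per_neuron=False, avg=False):
    """Sketch of a Poisson NLL for (batch, time, neurons) arrays (assumed semantics)."""
    # Drop leading target frames so both arrays have the same time length
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    # Poisson NLL without the constant log(target!) term; bias avoids log(0)
    loss = output - target * np.log(output + bias)
    if per_neuron:
        # Flatten batch and time, keep one value per neuron
        loss = loss.reshape(-1, loss.shape[-1])
        return loss.mean(axis=0) if avg else loss.sum(axis=0)
    return loss.mean() if avg else loss.sum()


rates = np.ones((1, 2, 2))   # predicted firing rates (already non-negative)
spikes = np.ones((1, 3, 2))  # target has one extra leading frame
print(poisson_loss_3d(rates, spikes))                   # 4.0
print(poisson_loss_3d(rates, spikes, per_neuron=True))  # [2. 2.]
```

The loss is minimized when the predicted rate equals the observed count, which is why it suits spike-count targets.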
L1PoissonLoss3d
L1PoissonLoss3d(
bias: float = 1e-16,
per_neuron: bool = False,
avg: bool = False,
gamma_l1: float = 0.001,
)
Bases: Module
Poisson loss for 3D (batch, time, neurons) data with an L1 penalty on the output, weighted by gamma_l1. Useful for models predicting sparse firing rates.
Source code in openretina/modules/losses/poisson.py
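A sketch of the combined objective, assuming the L1 term is gamma_l1 times the sum (or, with avg=True, the mean) of the absolute predicted rates, added to the Poisson term reduced the same way. The per_neuron option is omitted for brevity; the function name and this reduction scheme are assumptions, not the openretina implementation.

```python
import numpy as np


def l1_poisson_loss_3d(output, target, bias=1e-16, avg=False, gamma_l1=0.001):
    """Sketch: Poisson NLL plus an L1 penalty on the predicted rates (assumed form)."""
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    poisson = output - target * np.log(output + bias)
    reduce = np.mean if avg else np.sum
    # L1 penalty pushes predicted rates toward zero, favoring sparse outputs
    return reduce(poisson) + gamma_l1 * reduce(np.abs(output))


rates = np.ones((1, 2, 2))
spikes = np.ones((1, 2, 2))
total = l1_poisson_loss_3d(rates, spikes)               # Poisson sum 4 plus 0.001 * 4
mean = l1_poisson_loss_3d(rates, spikes, avg=True)      # Poisson mean 1 plus 0.001 * 1
```

Larger gamma_l1 trades predictive fit for sparser predicted rates.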
CelltypePoissonLoss3d
CelltypePoissonLoss3d(
bias: float = 1e-16,
per_neuron: bool = False,
avg: bool = False,
)
Bases: Module
Poisson loss with inverse-frequency weighting by cell type, so that under-represented cell types contribute as much to the loss as common ones.
Source code in openretina/modules/losses/poisson.py
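One way to implement inverse-frequency weighting, sketched in NumPy: each neuron is weighted by 1 / (number of neurons of its type), and the weights are then rescaled to sum to the number of neurons, so every cell type carries the same total weight. How openretina receives cell-type labels and normalizes the weights is not stated here; the cell_types argument, both function names, and the normalization are assumptions.

```python
import numpy as np


def inverse_frequency_weights(cell_types):
    """Per-neuron weights so each cell type has the same total weight (assumed scheme)."""
    _, inverse, counts = np.unique(cell_types, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    # Rescale so the weights sum to the number of neurons
    return weights * len(weights) / weights.sum()


def celltype_poisson_loss_3d(output, target, cell_types, bias=1e-16, avg=False):
    """Sketch: Poisson NLL with per-neuron inverse-frequency weights."""
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    loss = output - target * np.log(output + bias)  # (batch, time, neurons)
    weighted = loss * inverse_frequency_weights(cell_types)  # broadcasts over neurons
    return weighted.mean() if avg else weighted.sum()


# Three neurons of type 0 and one of type 1
types = np.array([0, 0, 0, 1])
w = inverse_frequency_weights(types)
# Type-0 neurons each get 2/3 and the single type-1 neuron gets 2,
# so both types contribute a total weight of 2.
```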
Correlation Losses
CorrelationLoss3d
CorrelationLoss3d(
bias: float = 1e-16,
per_neuron: bool = False,
avg: bool = False,
)
Bases: Module
Negative Pearson correlation loss for 3D (batch, time, neurons) data.
Returns the negated correlation, so minimizing the loss maximizes correlation.
Source code in openretina/modules/losses/correlation.py
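A NumPy sketch of the negative Pearson correlation. It assumes the correlation is computed per neuron over the flattened batch and time axes, that bias is added to the denominator for numerical stability, that per_neuron returns one value per neuron, and that avg switches the sum over neurons to a mean. None of these details are confirmed by the signature alone.

```python
import numpy as np


def correlation_loss_3d(output, target, bias=1e-16, per_neuron=False, avg=False):
    """Sketch: negative per-neuron Pearson correlation (assumed semantics)."""
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    n_neurons = output.shape[-1]
    # Treat every (batch, time) entry as one sample per neuron
    out = output.reshape(-1, n_neurons)
    tgt = target.reshape(-1, n_neurons)
    out_c = out - out.mean(axis=0)
    tgt_c = tgt - tgt.mean(axis=0)
    cov = (out_c * tgt_c).sum(axis=0)
    norm = np.sqrt((out_c**2).sum(axis=0) * (tgt_c**2).sum(axis=0))
    loss = -cov / (norm + bias)  # bias guards against division by zero
    if per_neuron:
        return loss
    return loss.mean() if avg else loss.sum()


rng = np.random.default_rng(0)
target = rng.random((2, 5, 3))
# A positive affine transform of the target is perfectly correlated with it,
# so the per-neuron loss is approximately -1 for every neuron.
per_neuron_loss = correlation_loss_3d(2.0 * target + 1.0, target, per_neuron=True)
```

Unlike the Poisson loss, correlation ignores the scale and offset of the predictions, so it rewards getting the response shape right rather than the absolute rate.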
CelltypeCorrelationLoss3d
CelltypeCorrelationLoss3d(
bias: float = 1e-16,
per_neuron: bool = False,
avg: bool = False,
)
Bases: Module
Correlation loss with inverse-frequency weighting by cell type.
Source code in openretina/modules/losses/correlation.py
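A hypothetical sketch that weights each neuron's negative correlation by the inverse frequency of its cell type, so the loss is not dominated by the most common types. The cell_types argument, the per-neuron correlation over flattened batch and time, and the weight normalization (weights summing to the number of neurons) are all assumptions.

```python
import numpy as np


def celltype_correlation_loss_3d(output, target, cell_types, bias=1e-16, avg=False):
    """Sketch: per-neuron negative correlation with inverse-frequency weights."""
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    n_neurons = output.shape[-1]
    out = output.reshape(-1, n_neurons)
    tgt = target.reshape(-1, n_neurons)
    out_c = out - out.mean(axis=0)
    tgt_c = tgt - tgt.mean(axis=0)
    corr = (out_c * tgt_c).sum(axis=0) / (
        np.sqrt((out_c**2).sum(axis=0) * (tgt_c**2).sum(axis=0)) + bias
    )
    # Inverse-frequency weights, rescaled to sum to the number of neurons
    _, inverse, counts = np.unique(cell_types, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    weights = weights * n_neurons / weights.sum()
    loss = -corr * weights
    return loss.mean() if avg else loss.sum()


rng = np.random.default_rng(0)
target = rng.random((2, 5, 4))
output = target.copy()
output[..., 3] = -target[..., 3]  # the single type-1 neuron is anti-correlated
types = np.array([0, 0, 0, 1])
loss = celltype_correlation_loss_3d(output, target, types)
# Unweighted, the three well-predicted type-0 neurons would give -3 + 1 = -2;
# weighted, both types count equally and the loss is about 0.
```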
ScaledCorrelationLoss3d
ScaledCorrelationLoss3d(
bias=1e-16,
scale=30,
per_neuron=False,
avg=False,
)
Bases: Module
Correlation loss computed over non-overlapping temporal windows, then averaged.
The scale parameter sets the window size in frames.
Source code in openretina/modules/losses/correlation.py
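A NumPy sketch of the windowed correlation. It assumes that leftover frames which do not fill a whole window are dropped and that, within each window, correlation is computed per neuron over the batch and window frames; per_neuron is omitted for brevity. Because each window is centered separately, slow offsets that differ between windows do not lower the score.

```python
import numpy as np


def scaled_correlation_loss_3d(output, target, scale=30, bias=1e-16, avg=False):
    """Sketch: negative correlation averaged over non-overlapping time windows."""
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    n_neurons = output.shape[-1]
    n_windows = output.shape[1] // scale  # trailing partial window dropped (assumption)
    if n_windows == 0:
        raise ValueError("fewer time steps than one window")
    corrs = []
    for w in range(n_windows):
        window = slice(w * scale, (w + 1) * scale)
        out = output[:, window, :].reshape(-1, n_neurons)
        tgt = target[:, window, :].reshape(-1, n_neurons)
        out_c = out - out.mean(axis=0)
        tgt_c = tgt - tgt.mean(axis=0)
        norm = np.sqrt((out_c**2).sum(axis=0) * (tgt_c**2).sum(axis=0))
        corrs.append((out_c * tgt_c).sum(axis=0) / (norm + bias))
    loss = -np.mean(corrs, axis=0)  # average over windows, one value per neuron
    return loss.mean() if avg else loss.sum()


rng = np.random.default_rng(0)
target = rng.random((2, 60, 3))
output = target.copy()
output[:, 30:, :] += 100.0  # constant offset in the second window only
loss = scaled_correlation_loss_3d(output, target, scale=30, avg=True)
# Each 30-frame window is still perfectly correlated, so loss is about -1,
# whereas one 60-frame window sees the offset and scores far worse.
```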
MSE Loss
MSE3d
MSE3d()
Bases: Module
Mean squared error loss for 3D (batch, time, neurons) predictions. Handles temporal lag between target and output.
Source code in openretina/modules/losses/mse.py
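A minimal NumPy sketch, assuming MSE3d crops the leading target frames and averages the squared error over all (batch, time, neurons) entries. Whether openretina averages or sums is an assumption here, and mse_3d is a hypothetical name.

```python
import numpy as np


def mse_3d(output, target):
    """Sketch: mean squared error after aligning the time axes."""
    lag = target.shape[1] - output.shape[1]
    return np.mean((output - target[:, lag:, :]) ** 2)


output = np.zeros((1, 2, 2))
target = np.full((1, 3, 2), 2.0)
target[:, 0, :] = 100.0  # leading frame has no matching prediction and is ignored
loss = mse_3d(output, target)  # only the last two frames (value 2.0) are compared
```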