
Loss Functions

Loss functions for training retinal models. All losses handle temporal lag between predictions and targets automatically (the model may produce fewer time steps than the target due to temporal convolutions).
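
The page does not show how the lag is handled. One plausible approach, assumed here, is to drop the earliest target frames so both tensors have the same length, because a valid temporal convolution needs a window of past frames before it emits its first prediction. `crop_target_to_output` is a hypothetical helper for illustration, not part of openretina.

```python
import torch


def crop_target_to_output(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Align a (batch, time, neurons) target with a possibly shorter output.

    Assumption: the frames the model cannot predict are at the START of the
    sequence, so the earliest `lag` target frames are discarded.
    """
    lag = target.shape[1] - output.shape[1]
    if lag < 0:
        raise ValueError("output has more time steps than target")
    return target[:, lag:, :]


# 4 target frames; a temporal kernel of size 3 leaves 2 output frames
target = torch.arange(8.0).reshape(1, 4, 2)
output = torch.zeros(1, 2, 2)
aligned = crop_target_to_output(output, target)
```

After cropping, `aligned` holds target frames 2 and 3, which line up with the two predicted frames.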

Poisson Losses

PoissonLoss3d

PoissonLoss3d(
    bias: float = 1e-16,
    per_neuron: bool = False,
    avg: bool = False,
)

Bases: Module

Poisson negative log-likelihood loss for 3D (batch, time, neurons) predictions.

Handles temporal lag between target and output automatically.

Source code in openretina/modules/losses/poisson.py
def __init__(self, bias: float = 1e-16, per_neuron: bool = False, avg: bool = False):
    super().__init__()
    self.bias = bias
    self.per_neuron = per_neuron
    self.avg = avg
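
A minimal sketch of the Poisson negative log-likelihood for (batch, time, neurons) tensors, dropping the constant log(target!) term and adding `bias` inside the logarithm to avoid log(0). Two parts are assumptions rather than documented behaviour: the lag is handled by cropping the earliest target frames, and `per_neuron=True` returns one value per neuron (reduced over batch and time) while `avg=True` uses a mean instead of a sum. `poisson_loss_3d` is a hypothetical function, not the library's `forward`.

```python
import torch


def poisson_loss_3d(output, target, bias=1e-16, per_neuron=False, avg=False):
    # Assumed lag handling: discard the earliest target frames
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    # Poisson NLL per element, without the constant log(target!) term
    loss = output - target * torch.log(output + bias)
    if per_neuron:
        # Assumed semantics: reduce over batch and time, keep the neuron axis
        loss = loss.reshape(-1, loss.shape[-1])
        return loss.mean(dim=0) if avg else loss.sum(dim=0)
    # Assumed semantics: avg switches the full reduction from sum to mean
    return loss.mean() if avg else loss.sum()
```

The loss for a single element, `rate - y * log(rate)`, is minimized at `rate = y`, so predicting the observed spike count scores better than any other rate.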

L1PoissonLoss3d

L1PoissonLoss3d(
    bias: float = 1e-16,
    per_neuron: bool = False,
    avg: bool = False,
    gamma_l1: float = 0.001,
)

Bases: Module

Poisson loss for 3D (batch, time, neurons) data with an L1 penalty on the output, weighted by gamma_l1. Useful for models that predict sparse firing rates.

Source code in openretina/modules/losses/poisson.py
def __init__(self, bias: float = 1e-16, per_neuron: bool = False, avg: bool = False, gamma_l1: float = 0.001):
    super().__init__()
    self.bias = bias
    self.per_neuron = per_neuron
    self.avg = avg
    self.gamma_l1 = gamma_l1
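
The combination can be sketched as a Poisson term plus `gamma_l1` times an L1 penalty on the predicted rates. The exact form of the penalty (here the mean absolute output) and the mean reduction are assumptions; `l1_poisson_loss_3d` is a hypothetical helper.

```python
import torch


def l1_poisson_loss_3d(output, target, bias=1e-16, gamma_l1=0.001):
    # Assumed lag handling: discard the earliest target frames
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    # Mean Poisson NLL (constant log(target!) term omitted)
    poisson = (output - target * torch.log(output + bias)).mean()
    # Assumed penalty form: mean absolute predicted rate, pushing rates toward 0
    l1 = output.abs().mean()
    return poisson + gamma_l1 * l1
```

Larger `gamma_l1` values trade fit quality for sparser predicted activity; `gamma_l1=0` reduces the sketch to a plain averaged Poisson loss.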

CelltypePoissonLoss3d

CelltypePoissonLoss3d(
    bias: float = 1e-16,
    per_neuron: bool = False,
    avg: bool = False,
)

Bases: Module

Poisson loss with inverse-frequency weighting by cell type, so that under-represented cell types contribute as much to the loss as common ones.

Source code in openretina/modules/losses/poisson.py
def __init__(self, bias: float = 1e-16, per_neuron: bool = False, avg: bool = False):
    super().__init__()
    self.bias = bias
    self.per_neuron = per_neuron
    self.avg = avg
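
Inverse-frequency weighting gives each neuron the weight 1 / (number of neurons of its type), so every cell type carries the same total weight. The page does not show how cell types are passed in; the sketch assumes a hypothetical `cell_types` tensor of integer labels with shape (neurons,), and the helper names are illustrative only.

```python
import torch


def inverse_frequency_weights(cell_types: torch.Tensor) -> torch.Tensor:
    # cell_types: (neurons,) integer labels (hypothetical input format)
    _, inverse, counts = torch.unique(cell_types, return_inverse=True, return_counts=True)
    # 1 / count of each neuron's own type: every type now sums to weight 1
    weights = 1.0 / counts[inverse].float()
    # Normalize so the weights over all neurons sum to 1
    return weights / len(counts)


def celltype_poisson_loss_3d(output, target, cell_types, bias=1e-16):
    # Assumed lag handling: discard the earliest target frames
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    # Per-neuron Poisson NLL, averaged over batch and time
    per_neuron = (output - target * torch.log(output + bias)).mean(dim=(0, 1))
    return (inverse_frequency_weights(cell_types) * per_neuron).sum()
```

With labels `[0, 0, 0, 1]`, each type-0 neuron gets weight 1/6 and the single type-1 neuron gets 1/2, so both types contribute half of the loss.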

Correlation Losses

CorrelationLoss3d

CorrelationLoss3d(
    bias: float = 1e-16,
    per_neuron: bool = False,
    avg: bool = False,
)

Bases: Module

Negative Pearson correlation loss for 3D (batch, time, neurons) data.

Returns negated correlation so minimizing the loss maximizes correlation.

Source code in openretina/modules/losses/correlation.py
def __init__(self, bias: float = 1e-16, per_neuron: bool = False, avg: bool = False):
    super().__init__()
    self.eps = bias
    self.per_neuron = per_neuron
    self.avg = avg

    # Placeholder to store last-computed per-neuron correlations
    self.register_buffer("_per_neuron_correlations", torch.tensor([]), persistent=False)
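
A minimal sketch of a negative Pearson correlation loss, assuming that batch and time samples are pooled per neuron and the per-neuron correlations are averaged. `eps` plays the role of `bias` and guards against division by zero. `correlation_loss_3d` is a hypothetical helper; the library class additionally stores per-neuron correlations in a buffer, which this sketch omits.

```python
import torch


def correlation_loss_3d(output, target, eps=1e-16):
    # Assumed lag handling: discard the earliest target frames
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    # Pool batch and time: one column of samples per neuron
    out = output.reshape(-1, output.shape[-1])
    tgt = target.reshape(-1, target.shape[-1])
    out_c = out - out.mean(dim=0, keepdim=True)
    tgt_c = tgt - tgt.mean(dim=0, keepdim=True)
    cov = (out_c * tgt_c).mean(dim=0)
    std = out_c.pow(2).mean(dim=0).sqrt() * tgt_c.pow(2).mean(dim=0).sqrt()
    corr = cov / (std + eps)  # (neurons,)
    # Negate so that minimizing the loss maximizes correlation
    return -corr.mean()
```

Correlation is invariant to affine rescaling of the prediction, so `2 * target + 1` scores a perfect -1; this makes the loss insensitive to the absolute scale of predicted rates.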

CelltypeCorrelationLoss3d

CelltypeCorrelationLoss3d(
    bias: float = 1e-16,
    per_neuron: bool = False,
    avg: bool = False,
)

Bases: Module

Correlation loss with inverse-frequency weighting by cell type.

Source code in openretina/modules/losses/correlation.py
def __init__(self, bias: float = 1e-16, per_neuron: bool = False, avg: bool = False):
    super().__init__()
    self.eps = bias
    self.per_neuron = per_neuron
    self.avg = avg
    # Placeholder to store last-computed per-neuron correlations
    self.register_buffer("_per_neuron_correlations", torch.tensor([]), persistent=False)
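
The weighting can be sketched by combining per-neuron correlations with inverse-frequency weights, so each cell type contributes equally to the total. As above, the `cell_types` argument (integer labels of shape (neurons,)) and the function name are hypothetical, and pooling batch and time per neuron is an assumption.

```python
import torch


def celltype_correlation_loss_3d(output, target, cell_types, eps=1e-16):
    # Assumed lag handling: discard the earliest target frames
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    out = output.reshape(-1, output.shape[-1])
    tgt = target.reshape(-1, target.shape[-1])
    out_c = out - out.mean(dim=0, keepdim=True)
    tgt_c = tgt - tgt.mean(dim=0, keepdim=True)
    # Per-neuron Pearson correlation
    corr = (out_c * tgt_c).mean(dim=0) / (
        out_c.pow(2).mean(dim=0).sqrt() * tgt_c.pow(2).mean(dim=0).sqrt() + eps
    )
    # Inverse-frequency weights: each cell type sums to 1 / number_of_types
    _, inverse, counts = torch.unique(cell_types, return_inverse=True, return_counts=True)
    weights = 1.0 / (counts[inverse].float() * len(counts))
    return -(weights * corr).sum()
```

With three well-predicted type-0 neurons and one anti-correlated type-1 neuron, an unweighted mean would report a loss of -0.5, while the weighted loss is about 0 because the rare type counts as much as the common one.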

ScaledCorrelationLoss3d

ScaledCorrelationLoss3d(
    bias=1e-16, scale=30, per_neuron=False, avg=False
)

Bases: Module

Correlation loss computed over non-overlapping temporal windows, then averaged.

The scale parameter sets the window size in frames.

Source code in openretina/modules/losses/correlation.py
def __init__(self, bias=1e-16, scale=30, per_neuron=False, avg=False):
    super().__init__()
    self.eps = bias
    self.scale = scale
    self.per_neuron = per_neuron
    self.avg = avg
    # Placeholder to store last-computed per-neuron correlations
    self.register_buffer("_per_neuron_correlations", torch.tensor([]), persistent=False)
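
The windowed variant splits the time axis into non-overlapping blocks of `scale` frames, computes a correlation inside each block, and averages. A minimal sketch, assuming trailing frames that do not fill a whole window are dropped; `scaled_correlation_loss_3d` is a hypothetical helper.

```python
import torch


def scaled_correlation_loss_3d(output, target, scale=30, eps=1e-16):
    # Assumed lag handling: discard the earliest target frames
    lag = target.shape[1] - output.shape[1]
    target = target[:, lag:, :]
    batch, time, neurons = output.shape
    n_windows = time // scale
    if n_windows == 0:
        raise ValueError("sequence is shorter than one window")
    # Assumption: drop trailing frames that do not fill a whole window
    usable = n_windows * scale
    out = output[:, :usable].reshape(batch, n_windows, scale, neurons)
    tgt = target[:, :usable].reshape(batch, n_windows, scale, neurons)
    # Pearson correlation within each window (reduce over the frame axis)
    out_c = out - out.mean(dim=2, keepdim=True)
    tgt_c = tgt - tgt.mean(dim=2, keepdim=True)
    cov = (out_c * tgt_c).mean(dim=2)
    std = out_c.pow(2).mean(dim=2).sqrt() * tgt_c.pow(2).mean(dim=2).sqrt()
    corr = cov / (std + eps)  # (batch, n_windows, neurons)
    return -corr.mean()
```

Because each window is centered separately, an offset that stays constant within a window (for example slow drift between windows) does not lower the score.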

MSE Loss

MSE3d

MSE3d()

Bases: Module

Mean squared error loss for 3D (batch, time, neurons) predictions. Handles temporal lag between target and output.

Source code in openretina/modules/losses/mse.py
def __init__(self):
    super().__init__()
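
The same idea can be wrapped as an `nn.Module`, mirroring the `Bases: Module` structure of the classes above. This is a hypothetical re-implementation (`MSE3dSketch`), not the library class; the `forward(output, target)` argument order and the lag handling by cropping the earliest target frames are assumptions.

```python
import torch
from torch import nn


class MSE3dSketch(nn.Module):
    """Hypothetical illustration of an MSE loss with temporal-lag handling."""

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Assumed lag handling: discard the earliest target frames
        lag = target.shape[1] - output.shape[1]
        # Mean of squared errors over batch, time and neurons
        return nn.functional.mse_loss(output, target[:, lag:, :])


loss_fn = MSE3dSketch()
loss = loss_fn(torch.ones(1, 3, 2), torch.zeros(1, 5, 2))
```

Here the 5-frame target is cropped to the 3 predicted frames before the squared errors are averaged.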