Skip to content

Evaluation API Reference

Functions for evaluating trained retinal models, including correlation metrics, oracle computations, and variance analysis.

Metrics

metrics

correlation_numpy

correlation_numpy(
    y1: ndarray,
    y2: ndarray,
    axis: None | int | tuple[int, ...] = -1,
    eps: float = 1e-08,
    **kwargs,
) -> ndarray

Compute the correlation between two NumPy arrays along the specified dimension(s).

Source code in openretina/eval/metrics.py
12
13
14
15
16
17
18
19
def correlation_numpy(
    y1: np.ndarray, y2: np.ndarray, axis: None | int | tuple[int, ...] = -1, eps: float = 1e-8, **kwargs
) -> np.ndarray:
    """Compute the correlation between two NumPy arrays along the specified dimension(s)."""
    y1 = (y1 - y1.mean(axis=axis, keepdims=True)) / (y1.std(axis=axis, keepdims=True, ddof=0) + eps)
    y2 = (y2 - y2.mean(axis=axis, keepdims=True)) / (y2.std(axis=axis, keepdims=True, ddof=0) + eps)
    corr = (y1 * y2).mean(axis=axis, **kwargs)
    return corr

MSE_numpy

MSE_numpy(
    y1: ndarray,
    y2: ndarray,
    axis: None | int | tuple[int, ...] = -1,
    **kwargs,
) -> ndarray

Compute the mean squared error between two NumPy arrays along the specified dimension(s).

Source code in openretina/eval/metrics.py
22
23
24
def MSE_numpy(y1: np.ndarray, y2: np.ndarray, axis: None | int | tuple[int, ...] = -1, **kwargs) -> np.ndarray:
    """Compute the mean squared error between two NumPy arrays along the specified dimension(s)."""
    return ((y1 - y2) ** 2).mean(axis=axis, **kwargs)

poisson_loss_numpy

poisson_loss_numpy(
    y_true: ndarray,
    y_pred: ndarray,
    eps: float = 1e-08,
    mean_axis: None | int | tuple[int, ...] = -1,
) -> ndarray

Compute the Poisson loss between two NumPy arrays.

Source code in openretina/eval/metrics.py
27
28
29
30
31
def poisson_loss_numpy(
    y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8, mean_axis: None | int | tuple[int, ...] = -1
) -> np.ndarray:
    """Compute the Poisson loss between two NumPy arrays."""
    return (y_pred - y_true * np.log(y_pred + eps)).mean(axis=mean_axis)

model_predictions

model_predictions(
    loader, model: Module, data_key, device
) -> tuple[ndarray, ndarray]

Computes model predictions for a given dataloader and model. Returns: target — ground truth, i.e. neuronal firing rates of the neurons; output — responses as predicted by the network.

Source code in openretina/eval/metrics.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def model_predictions(loader, model: torch.nn.Module, data_key, device) -> tuple[np.ndarray, np.ndarray]:
    """Run the model over every batch of `loader[data_key]` and collect predictions.

    Returns:
        target: ground truth, i.e. neuronal firing rates of the neurons,
            cropped at the front so its time axis matches the predictions.
        output: responses as predicted by the network.
    """
    target = torch.empty(0)
    output = torch.empty(0)
    # Unpack as (*inputs, responses) so extra group-assignment tensors, when present, are handled.
    for *inputs, responses in loader[data_key]:
        prediction = model(*tensors_to_device(inputs, device), data_key=data_key)
        output = torch.cat((output, prediction.detach().cpu()), dim=0)
        target = torch.cat((target, responses.detach().cpu()), dim=0)
    output_np = output.numpy()
    target_np = target.numpy()
    # The model may consume some leading frames (temporal lag); drop them from the target.
    lag = target_np.shape[1] - output_np.shape[1]
    return target_np[:, lag:, ...], output_np

corr_stop

corr_stop(
    model: Module,
    loader,
    avg: bool = True,
    device: str = "cpu",
)

Returns either the average correlation of all neurons or the correlations per neuron. Gets called by early stopping and the model performance evaluation

Source code in openretina/eval/metrics.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def corr_stop(model: torch.nn.Module, loader, avg: bool = True, device: str = "cpu"):
    """
    Returns either the average correlation of all neurons or the correlations per neuron.
    Gets called by early stopping and the model performance evaluation.

    Correlations are computed over the batch axis for every session in `loader`;
    NaN correlations trigger a warning and are replaced by zero.
    """
    n_neurons = 0
    correlations_sum = 0
    all_correlations = np.array([])

    for data_key in loader:
        with eval_state(model):
            target, output = model_predictions(loader, model, data_key, device)

        corrs = correlation_numpy(target, output, axis=0)

        nan_mask = np.isnan(corrs)
        if nan_mask.any():
            warnings.warn(f"{nan_mask.mean() * 100}% NaNs ")
        corrs[nan_mask] = 0

        if avg:
            n_neurons += output.shape[1]
            correlations_sum += corrs.sum()
        else:
            all_correlations = np.append(all_correlations, corrs)

    return correlations_sum / n_neurons if avg else all_correlations

corr_stop3d

corr_stop3d(
    model: Module,
    loader,
    avg: bool = True,
    device: str = "cpu",
)

Returns either the average correlation of all neurons or the correlations per neuron. Gets called by early stopping and the model performance evaluation

Source code in openretina/eval/metrics.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def corr_stop3d(model: torch.nn.Module, loader, avg: bool = True, device: str = "cpu"):
    """
    Returns either the average correlation of all neurons or the correlations per neuron.
    Gets called by early stopping and the model performance evaluation.

    For 3d (video) responses the correlation is computed along the time axis and
    then averaged over batches; NaN correlations trigger a warning and are
    replaced by zero.
    """
    n_neurons = 0
    correlations_sum = 0
    all_correlations = np.array([])

    for data_key in loader:
        with eval_state(model):
            target, output = model_predictions(loader, model, data_key, device)

        # Correlation over the time axis (1), then averaged across batches (axis 0).
        corrs = correlation_numpy(target, output, axis=1).mean(axis=0)

        nan_mask = np.isnan(corrs)
        if nan_mask.any():
            warnings.warn(f"{nan_mask.mean() * 100}% NaNs in corr_stop3d")
        corrs[nan_mask] = 0

        if avg:
            n_neurons += output.shape[-1]
            correlations_sum += corrs.sum()
        else:
            all_correlations = np.append(all_correlations, corrs)

    return correlations_sum / n_neurons if avg else all_correlations

explainable_vs_total_var

explainable_vs_total_var(
    repeated_outputs: Float[
        ndarray, "frames repeats neurons"
    ],
    eps: float = 1e-09,
) -> tuple[
    Float[ndarray, " neurons"], Float[ndarray, " neurons"]
]

Adapted from neuralpredictors. Compute the ratio of explainable to total variance per neuron. See Cadena et al., 2019: https://doi.org/10.1371/journal.pcbi.1006897

PARAMETER DESCRIPTION
repeated_outputs

numpy array with shape (images/time, repeats, neurons) containing the responses.

TYPE: array

RETURNS DESCRIPTION
tuple

A tuple containing: - var_ratio (array): Ratio of explainable to total variance per neuron - explainable_var (array): Explainable variance for each neuron

TYPE: tuple[Float[ndarray, ' neurons'], Float[ndarray, ' neurons']]

Source code in openretina/eval/metrics.py
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def explainable_vs_total_var(
    repeated_outputs: Float[np.ndarray, "frames repeats neurons"], eps: float = 1e-9
) -> tuple[Float[np.ndarray, " neurons"], Float[np.ndarray, " neurons"]]:
    """
    Adapted from neuralpredictors.
    Compute the ratio of explainable to total variance per neuron.
    See Cadena et al., 2019: https://doi.org/10.1371/journal.pcbi.1006897

    Args:
        repeated_outputs (array): numpy array with shape (images/time, repeats, neurons)
            containing the responses.
        eps (float): small constant keeping the explainable variance positive and the
            ratio's denominator non-zero.

    Returns:
        tuple: A tuple containing:
            - var_ratio (array): Ratio of explainable to total variance per neuron
            - explainable_var (array): Explainable variance for each neuron
    """
    # Total variance over both time and repeats; noise variance is the variance
    # across repeats, averaged over time.
    total_var = np.var(repeated_outputs, axis=(0, 1), ddof=1)
    noise_var = np.var(repeated_outputs, axis=1, ddof=1).mean(axis=0)
    # Floor at eps: in some bad cases the noise variance can exceed the total variance.
    explainable_var = np.maximum(total_var - noise_var, eps)
    var_ratio = explainable_var / (total_var + eps)
    return var_ratio, explainable_var

feve

feve(
    targets: Float[ndarray, "frames repeats neurons"],
    predictions: Float[ndarray, "frames repeats neurons"]
    | Float[ndarray, "frames neurons"],
) -> Float[ndarray, " neurons"]

Adapted from neuralpredictors. Compute the fraction of explainable variance explained per neuron

PARAMETER DESCRIPTION
targets

Neuron responses (ground truth) over time / different images across repetitions.

TYPE: array - like

predictions

Model predictions to the repeated images, either including or excluding repetitions.

TYPE: array - like

Dimensions: np.array(images/time, num_repeats, num_neurons) or np.array(images/time, num_neurons)

Returns: FEVe (np.array): the fraction of explainable variance explained per neuron

Source code in openretina/eval/metrics.py
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def feve(
    targets: Float[np.ndarray, "frames repeats neurons"],
    predictions: Float[np.ndarray, "frames repeats neurons"] | Float[np.ndarray, "frames neurons"],
) -> Float[np.ndarray, " neurons"]:
    """
    Adapted from neuralpredictors.
    Compute the fraction of explainable variance explained (FEVe) per neuron.

    Args:
        targets (array-like): Neuron responses (ground truth) over time / different images
            across repetitions, with shape (images/time, num_repeats, num_neurons).
        predictions (array-like): Model predictions to the repeated images, either including
            or excluding repetitions.
            Dimensions: np.array(images/time, num_repeats, num_neurons)
            or np.array(images/time, num_neurons).

    Returns:
        FEVe (np.array): the fraction of explainable variance explained per neuron,
        clipped to be non-negative.

    Raises:
        ValueError: if targets is not 3d, or if predictions cannot be matched
            to the shape of targets.
    """
    if len(targets.shape) != 3:
        raise ValueError(f"Targets must be 3d 'frames repeats neurons', but {targets.shape=}")

    # Bug fix: decide on the repeat-axis expansion from the dimensionality alone.
    # The previous guard (`predictions.shape[1] != targets.shape[1]`) skipped the
    # expansion whenever num_neurons happened to equal num_repeats, wrongly raising
    # a shape-mismatch error for valid 2d predictions.
    if predictions.ndim == 2:
        predictions = np.repeat(predictions[:, np.newaxis, :], targets.shape[1], axis=1)

    if targets.shape != predictions.shape:
        raise ValueError(
            f"Targets and predictions must have the same shape, got {targets.shape} and {predictions.shape}"
        )

    # Mean squared error per neuron, averaged over time and repeats.
    mse = np.mean((targets - predictions) ** 2, axis=(0, 1))

    var_ratio, explainable_var = explainable_vs_total_var(targets)
    # Invert the ratio to recover the total and noise variance.
    total_var = explainable_var / var_ratio
    noise_var = total_var - explainable_var

    fev_e = 1 - np.clip(mse - noise_var, 0, None) / explainable_var
    return np.clip(fev_e, 0, None)

crop_responses

crop_responses(
    responses: ndarray, predictions: ndarray
) -> tuple[ndarray, int]

Crop responses to match prediction length, accounting for temporal lag.

PARAMETER DESCRIPTION
responses

Array of responses, last axis is time.

TYPE: ndarray

predictions

Array of predictions, first axis is time.

TYPE: ndarray

RETURNS DESCRIPTION
tuple[ndarray, int]

Tuple of (cropped responses, lag).

Source code in openretina/eval/metrics.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
def crop_responses(responses: np.ndarray, predictions: np.ndarray) -> tuple[np.ndarray, int]:
    """Crop responses to match prediction length, accounting for temporal lag.

    Args:
        responses: Array of responses, last axis is time.
        predictions: Array of predictions, first axis is time.

    Returns:
        Tuple of (cropped responses, lag).

    Raises:
        ValueError: if the predictions are longer in time than the responses.
    """
    response_length = responses.shape[-1]
    prediction_length = predictions.shape[0]
    lag = response_length - prediction_length
    if lag < 0:
        raise ValueError(f"Lag is negative: {lag}")
    cropped = responses[..., lag:]
    return cropped, lag

Oracles

oracles

oracle_corr_jackknife

oracle_corr_jackknife(
    repeated_responses: Float[
        ndarray, "frames repeats neurons"
    ],
    cut_first_n_frames: int | None = None,
) -> tuple[
    Float[ndarray, " neurons"],
    Float[ndarray, " frames repeats neurons"],
]

Adapted from neuralpredictors. Compute the oracle correlations per neuron by averaging over repeated responses in a leave one out fashion. Note that oracle_corr_jackknife underestimates the true oracle correlation.

PARAMETER DESCRIPTION
repeated_responses

numpy array with shape (images/time, repeats, neuron responses).

TYPE: array - like

cut_first_n_frames

if provided, indicates how many frames to cut from the repeated responses.

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
tuple

A tuple containing: - oracle_score (array): Oracle correlation for each neuron - oracle (array): Oracle responses for each neuron

TYPE: tuple[Float[ndarray, ' neurons'], Float[ndarray, ' frames repeats neurons']]

Source code in openretina/eval/oracles.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def oracle_corr_jackknife(
    repeated_responses: Float[np.ndarray, "frames repeats neurons"],
    cut_first_n_frames: int | None = None,
) -> tuple[Float[np.ndarray, " neurons"], Float[np.ndarray, " frames repeats neurons"]]:
    """
    Adapted from neuralpredictors.
    Compute the oracle correlations per neuron by averaging over repeated responses
    in a leave-one-out fashion.
    Note that oracle_corr_jackknife underestimates the true oracle correlation.

    Args:
        repeated_responses (array-like): numpy array with shape (images/time, repeats, neuron responses).
        cut_first_n_frames (int): if provided, indicates how many frames to cut from the repeated responses.

    Returns:
        tuple: A tuple containing:
            - oracle_score (array): Oracle correlation for each neuron
            - oracle (array): Oracle responses for each neuron
    """
    if repeated_responses.ndim != 3:
        raise ValueError(f"Expected repeated responses to be 3d, but {repeated_responses.shape=}")

    responses = repeated_responses[cut_first_n_frames:, :, :]

    leave_one_out = []
    for frame_responses in responses:
        repeat_count = frame_responses.shape[0]
        # Leave-one-out mean: sum all repeats, subtract the current one,
        # divide by the remaining count. nan_to_num covers the single-repeat case.
        loo_mean = (frame_responses.sum(axis=0, keepdims=True) - frame_responses) / (repeat_count - 1)
        leave_one_out.append(np.nan_to_num(loo_mean))
    oracle = np.stack(leave_one_out)

    # Flatten (time, repeats) into one axis and correlate per neuron.
    n_frames, n_repeats, n_neurons = responses.shape
    oracle_score = correlation_numpy(
        responses.reshape(n_frames * n_repeats, n_neurons),
        oracle.reshape(n_frames * n_repeats, n_neurons),
        axis=0,
    )

    return oracle_score, oracle

global_mean_oracle

global_mean_oracle(
    responses: Float[ndarray, "frames repeats neurons"]
    | Float[ndarray, "frames neurons"],
    cut_first_n_frames: int | None = None,
) -> Float[ndarray, " neurons"]

Compute the oracle correlation between each neuron's response and the global mean response.

The global mean oracle correlation represents how well each neuron's activity can be predicted by the average response across all neurons at each time point.

PARAMETER DESCRIPTION
responses

Neural responses array. Can be either: - 3D array of shape (frames, repeats, neurons) - 2D array of shape (frames, neurons) which will be treated as single repeat

TYPE: ndarray

cut_first_n_frames

If provided, this many leading frames are discarded before computing the correlation.

TYPE: int DEFAULT: None

RETURNS DESCRIPTION
Float[ndarray, ' neurons']
  • 1D array of shape (neurons,) containing correlation values for each neuron
Note

The function automatically handles single-repeat data by adding a singleton dimension.

Source code in openretina/eval/oracles.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def global_mean_oracle(
    responses: Float[np.ndarray, "frames repeats neurons"] | Float[np.ndarray, "frames neurons"],
    cut_first_n_frames: int | None = None,
) -> Float[np.ndarray, " neurons"]:
    """
    Compute the oracle correlation between each neuron's response and the global mean response.

    The global mean oracle correlation represents how well each neuron's activity can be predicted
    by the average response across all neurons at each time point.

    Args:
        responses (np.ndarray): Neural responses array. Can be either:
            - 3D array of shape (frames, repeats, neurons)
            - 2D array of shape (frames, neurons), which is treated as a single repeat
        cut_first_n_frames (int, optional): if provided, this many leading frames are
            discarded before computing the correlation.

    Returns:
        - 1D array of shape (neurons,) containing correlation values for each neuron

    Note:
        The function automatically handles single-repeat data by adding a singleton dimension.
    """
    if responses.ndim == 2:
        # Single-repeat data: insert a singleton repeat axis.
        responses = responses[:, None, :]
    responses = responses[cut_first_n_frames:, :, :]
    n_frames, n_repeats, n_neurons = responses.shape

    # Mean over neurons at each (time, repeat), tiled back to one trace per neuron.
    global_mean_response = np.broadcast_to(responses.mean(axis=2, keepdims=True), responses.shape)

    # Flatten (time, repeats) into one axis and correlate per neuron.
    oracle_mean_corr = correlation_numpy(
        responses.reshape(n_frames * n_repeats, n_neurons),
        global_mean_response.reshape(n_frames * n_repeats, n_neurons),
        axis=0,
    )

    return oracle_mean_corr