Skip to content

Stimulus Optimization

Optimizer

optimize_stimulus

optimize_stimulus(
    stimulus: Tensor,
    optimizer_init_fn: Callable[[list[Tensor]], Optimizer],
    objective_object,
    optimization_stopper: OptimizationStopper,
    stimulus_regularization_loss: list[
        StimulusRegularizationLoss
    ]
    | StimulusRegularizationLoss
    | None = None,
    stimulus_postprocessor: list[StimulusPostprocessor]
    | StimulusPostprocessor
    | None = None,
) -> None

Optimize a stimulus to maximize a given objective while minimizing a regularizing function. The stimulus is modified in place.

Source code in openretina/insilico/stimulus_optimization/optimizer.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def optimize_stimulus(
    stimulus: Tensor,
    optimizer_init_fn: Callable[[list[torch.Tensor]], torch.optim.Optimizer],
    objective_object,
    optimization_stopper: OptimizationStopper,
    stimulus_regularization_loss: list[StimulusRegularizationLoss] | StimulusRegularizationLoss | None = None,
    stimulus_postprocessor: list[StimulusPostprocessor] | StimulusPostprocessor | None = None,
) -> None:
    """
    Optimize a stimulus to maximize a given objective while minimizing a regularizing function.
    The stimulus is modified in place.
    """
    optimizer = optimizer_init_fn([stimulus])

    for _ in range(optimization_stopper.max_iterations):
        # Gradient *descent* on the negated objective is ascent on the objective;
        # regularization terms are added with positive sign so they are minimized.
        loss = -objective_object.forward(stimulus)
        for reg_loss_module in convert_to_list(stimulus_regularization_loss):
            loss = loss + reg_loss_module.forward(stimulus)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Post-process through .data so the projection stays out of the autograd graph.
        for postprocessor in convert_to_list(stimulus_postprocessor):
            stimulus.data = postprocessor.process(stimulus.data)

        if optimization_stopper.early_stop(float(loss.item())):
            break

    # Leave the caller with a plain tensor, detached from the computation graph.
    stimulus.detach_()

Objective

ResponseReducer

ResponseReducer(axis: int = 0)

Bases: ABC

Abstract base class for reducing a response tensor (e.g. over time) to a scalar objective value.

Source code in openretina/insilico/stimulus_optimization/objective.py
14
15
def __init__(self, axis: int = 0):
    """Store the axis of the response tensor that the reducer collapses."""
    self._axis = axis  # axis to reduce over (default 0, per class docs the temporal axis)

SliceMeanReducer

SliceMeanReducer(axis: int, start: int, length: int)

Bases: ResponseReducer

Reduce by averaging over a temporal slice starting at start for length frames.

Source code in openretina/insilico/stimulus_optimization/objective.py
31
32
33
34
35
def __init__(self, axis: int, start: int, length: int):
    """Reduce by averaging over a temporal slice starting at `start` for `length` frames."""
    super().__init__(axis)
    self.start = start  # first index (along the reduction axis) of the averaged slice
    self.length = length  # number of entries in the averaged slice

AbstractObjective

AbstractObjective(model, data_key: str | None)

Bases: ABC

Base class for stimulus optimization objectives. Wraps a model and optional data_key for forward passes.

Source code in openretina/insilico/stimulus_optimization/objective.py
59
60
61
def __init__(self, model, data_key: str | None):
    """Wrap `model` and an optional `data_key` for use in objective forward passes."""
    self._model = model
    # NOTE(review): data_key presumably selects a session/readout of the model;
    # its exact semantics are not visible here — confirm against the model API.
    self._data_key = data_key

IncreaseObjective

IncreaseObjective(
    model,
    neuron_indices: list[int] | int,
    data_key: str | None,
    response_reducer: ResponseReducer,
)

Bases: AbstractObjective

Objective that maximizes the mean response of specified neurons.

Source code in openretina/insilico/stimulus_optimization/objective.py
79
80
81
82
def __init__(self, model, neuron_indices: list[int] | int, data_key: str | None, response_reducer: ResponseReducer):
    """Objective that targets the responses of the given neuron indices.

    A single int index is normalized to a one-element list so downstream
    code can always assume a list.
    """
    super().__init__(model, data_key)
    if isinstance(neuron_indices, int):
        self._neuron_indices = [neuron_indices]
    else:
        self._neuron_indices = neuron_indices
    self._response_reducer = response_reducer

InnerNeuronVisualizationObjective

InnerNeuronVisualizationObjective(
    model,
    data_key: str | None,
    response_reducer: ResponseReducer,
)

Bases: AbstractObjective

Objective for visualizing feature preferences of individual channels in intermediate model layers via forward hooks.

Source code in openretina/insilico/stimulus_optimization/objective.py
121
122
123
124
125
126
def __init__(self, model, data_key: str | None, response_reducer: ResponseReducer):
    """Set up feature-capturing hooks on `model` for inner-neuron visualization."""
    super().__init__(model, data_key)
    # Per the class description this hooks intermediate layers; features_dict
    # presumably maps layer names to captured activations — confirm in hook_model.
    self.features_dict = self.hook_model(model)
    self._response_reducer = response_reducer
    # Target layer/channel start unset; "" and -1 act as "not chosen yet" sentinels.
    self.layer_name = ""
    self.channel_id = -1

ContrastiveNeuronObjective

ContrastiveNeuronObjective(
    model,
    on_cluster_idc: list[int],
    off_cluster_idc_list: list[list[int]],
    data_key: str | None,
    response_reducer: ResponseReducer,
    temperature: float = 1.6,
)

Bases: AbstractObjective

Objective described in [Most discriminative stimuli for functional cell type clustering](https://openreview.net/forum?id=9W6KaAcYlr)

Source code in openretina/insilico/stimulus_optimization/objective.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def __init__(
    self,
    model,
    on_cluster_idc: list[int],
    off_cluster_idc_list: list[list[int]],
    data_key: str | None,
    response_reducer: ResponseReducer,
    temperature: float = 1.6,
):
    """Contrastive objective between one "on" cluster and several "off" clusters.

    Args:
        model: Model whose responses are contrasted.
        on_cluster_idc: Neuron indices belonging to the target cluster.
        off_cluster_idc_list: One list of neuron indices per competing cluster.
        data_key: Optional key forwarded to the base objective.
        response_reducer: Reduces a response tensor to a scalar per cluster.
        temperature: Temperature parameter of the contrastive term
            (default 1.6; see the referenced paper for its role).
    """
    super().__init__(model, data_key)
    self._on_cluster_idc = on_cluster_idc
    self._off_cluster_idc_list = off_cluster_idc_list
    self._response_reducer = response_reducer
    self._temperature = temperature

Regularizer

StimulusRegularizationLoss

Base class for regularization losses applied to the optimized stimulus. Default returns 0.

RangeRegularizationLoss

RangeRegularizationLoss(
    min_max_values: Iterable[
        tuple[float | None, float | None]
    ],
    max_norm: float | None,
    factor: float = 1.0,
)

Bases: StimulusRegularizationLoss

Penalizes stimulus values outside specified per-channel min/max ranges and optionally constrains total norm.

Source code in openretina/insilico/stimulus_optimization/regularizer.py
28
29
30
31
32
33
34
35
36
def __init__(
    self,
    min_max_values: Iterable[tuple[float | None, float | None]],
    max_norm: float | None,
    factor: float = 1.0,
):
    """Configure per-channel range bounds and an optional norm constraint.

    Args:
        min_max_values: One (min, max) pair per stimulus channel; a None
            entry disables that bound for the channel.
        max_norm: If not None, stimulus norms above this value are penalized.
        factor: Multiplier applied to the total regularization loss.
    """
    # Materialize the iterable so it can be re-iterated on every forward() call.
    self._min_max_values = list(min_max_values)
    self._max_norm = max_norm
    self._factor = factor

forward

forward(stimulus: Tensor) -> Tensor

Penalizes the stimulus if it is outside the range defined by min_max_values.

Source code in openretina/insilico/stimulus_optimization/regularizer.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def forward(self, stimulus: torch.Tensor) -> torch.Tensor:
    """Penalize stimulus values outside the per-channel [min, max] ranges,
    plus (optionally) any excess of the overall norm above ``max_norm``.

    Args:
        stimulus: Tensor whose second dimension indexes channels; channel ``i``
            is checked against ``self._min_max_values[i]``.

    Returns:
        A scalar tensor with the summed violations scaled by ``self._factor``.
        Always a ``torch.Tensor`` (the previous float ``0.0`` initialization
        leaked a plain float out when no bound applied), on the stimulus's device.
    """
    loss = torch.zeros((), device=stimulus.device, dtype=stimulus.dtype)
    for i, (min_val, max_val) in enumerate(self._min_max_values):
        stimulus_i = stimulus[:, i]
        if min_val is not None:
            # Positive wherever the stimulus dips below the lower bound.
            loss = loss + torch.sum(torch.relu(min_val - stimulus_i))
        if max_val is not None:
            # Positive wherever the stimulus exceeds the upper bound.
            loss = loss + torch.sum(torch.relu(stimulus_i - max_val))

    if self._max_norm is not None:
        # Only the part of the norm above max_norm is penalized.
        loss = loss + torch.relu(torch.norm(stimulus) - self._max_norm)

    return loss * self._factor

StimulusPostprocessor

Base class for stimulus clippers.

process

process(x: Tensor) -> Tensor

x.shape: batch x channels x time x n_rows x n_cols

Source code in openretina/insilico/stimulus_optimization/regularizer.py
60
61
62
def process(self, x: torch.Tensor) -> torch.Tensor:
    """Identity post-processing; subclasses override to clip/renormalize.

    x.shape: batch x channels x time x n_rows x n_cols
    """
    return x

ChangeNormJointlyClipRangeSeparately

ChangeNormJointlyClipRangeSeparately(
    min_max_values: Iterable[
        tuple[Optional[float], Optional[float]]
    ],
    norm: float | None,
)

Bases: StimulusPostprocessor

First change the norm and afterward clip the value of x to some specified range

Source code in openretina/insilico/stimulus_optimization/regularizer.py
68
69
70
71
72
73
74
def __init__(
    self,
    min_max_values: Iterable[tuple[Optional[float], Optional[float]]],
    norm: float | None,
):
    """Configure the target norm and per-channel clip ranges.

    Args:
        min_max_values: One (min, max) clip pair per channel; a None entry
            disables that bound for the channel.
        norm: Target norm for the stimulus; None presumably skips the norm
            change — confirm in process().
    """
    self._norm = norm
    # Materialize the iterable so it can be re-iterated on every process() call.
    self._min_max_values = list(min_max_values)

TemporalGaussianLowPassFilterProcessor

TemporalGaussianLowPassFilterProcessor(
    sigma: float, kernel_size: int, device: str = "cpu"
)

Bases: StimulusPostprocessor

Uses a 1d Gaussian filter to convolve the stimulus over the temporal dimension. This acts as a low pass filter.

Source code in openretina/insilico/stimulus_optimization/regularizer.py
107
108
109
110
111
112
113
114
def __init__(
    self,
    sigma: float,
    kernel_size: int,
    device: str = "cpu",
):
    """Precompute the 1d Gaussian kernel used for temporal low-pass filtering.

    Args:
        sigma: Standard deviation of the Gaussian kernel.
        kernel_size: Number of taps of the 1d kernel.
        device: Torch device the kernel is stored on (moved to the input's
            device at process() time if they differ).
    """
    kernel = _gaussian_1d_kernel(sigma, kernel_size)
    self._kernel = kernel.to(device)

process

process(
    x: Float[
        Tensor, "batch_dim channels time height width"
    ],
) -> Tensor

Apply a Gaussian low-pass filter to the stimulus tensor along the temporal dimension.

PARAMETER DESCRIPTION
x

Tensor of shape (batch_dim, channels, time_dim, height, width)

TYPE: Tensor

Returns: Tensor: The filtered stimulus tensor.

Source code in openretina/insilico/stimulus_optimization/regularizer.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def process(self, x: Float[torch.Tensor, "batch_dim channels time height width"]) -> torch.Tensor:
    """
    Apply a Gaussian low-pass filter to the stimulus tensor along the temporal dimension.

    Arguments:
        x (Tensor): Tensor of shape (batch_dim, channels, time_dim, height, width)
    Returns:
        Tensor: The filtered stimulus tensor.
    """
    num_channels = x.shape[1]

    # Shape the 1d temporal kernel as a depthwise 3d kernel of shape
    # (channels, 1, kernel_size, 1, 1), one copy per channel.
    kernel = self._kernel.to(x.device).view(1, 1, -1, 1, 1).repeat(num_channels, 1, 1, 1, 1)

    # groups=channels makes the convolution depthwise, and the singleton
    # spatial kernel dims restrict it to the time axis (dim 2).
    return F.conv3d(x, kernel, padding="same", groups=num_channels)

Optimization Stopper

OptimizationStopper

OptimizationStopper(max_iterations: Optional[int])

Base stopping criterion for stimulus optimization. Stops after max_iterations.

Source code in openretina/insilico/stimulus_optimization/optimization_stopper.py
 8
 9
10
11
12
def __init__(self, max_iterations: Optional[int]):
    """Stop after `max_iterations` optimization steps.

    A `None` limit is mapped to `sys.maxsize`, i.e. effectively unbounded.
    """
    self.max_iterations = sys.maxsize if max_iterations is None else max_iterations

EarlyStopper

EarlyStopper(
    max_iterations: Optional[int] = None,
    patience: int = 1,
    min_delta: float = 0.0,
)

Bases: OptimizationStopper

Stops optimization when loss stops improving for patience consecutive steps.

Source code in openretina/insilico/stimulus_optimization/optimization_stopper.py
21
22
23
24
25
26
def __init__(self, max_iterations: Optional[int] = None, patience: int = 1, min_delta: float = 0.0):
    """Early stopping when the loss stops improving.

    Args:
        max_iterations: Hard iteration cap; None means no cap (base class).
        patience: Consecutive non-improving steps tolerated before stopping.
        min_delta: Minimum loss decrease that counts as an improvement.
    """
    super().__init__(max_iterations)
    self._patience = patience
    self._min_delta = min_delta
    self._counter = 0  # consecutive steps without sufficient improvement
    self._min_loss = float("inf")  # lowest loss observed so far