Stimulus Optimization

Optimizer

optimize_stimulus(stimulus, optimizer_init_fn, objective_object, optimization_stopper, stimulus_regularization_loss=None, stimulus_postprocessor=None)

Optimize a stimulus to maximize a given objective while minimizing a regularizing function. The stimulus is modified in place.

Source code in openretina/insilico/stimulus_optimization/optimizer.py, lines 24-53
def optimize_stimulus(
    stimulus: Tensor,
    optimizer_init_fn: Callable[[list[torch.Tensor]], torch.optim.Optimizer],
    objective_object,
    optimization_stopper: OptimizationStopper,
    stimulus_regularization_loss: list[StimulusRegularizationLoss] | StimulusRegularizationLoss | None = None,
    stimulus_postprocessor: list[StimulusPostprocessor] | StimulusPostprocessor | None = None,
) -> None:
    """
    Optimize a stimulus to maximize a given objective while minimizing a regularizing function.
    The stimulus is modified in place.
    """
    optimizer = optimizer_init_fn([stimulus])

    for _ in range(optimization_stopper.max_iterations):
        objective = objective_object.forward(stimulus)
        # Maximizing the objective, minimizing the regularization loss
        loss = -objective
        for reg_loss_module in convert_to_list(stimulus_regularization_loss):
            regularization_loss = reg_loss_module.forward(stimulus)
            loss += regularization_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for postprocessor in convert_to_list(stimulus_postprocessor):
            stimulus.data = postprocessor.process(stimulus.data)
        if optimization_stopper.early_stop(float(loss.item())):
            break
    stimulus.detach_()  # Detach the tensor from the computation graph
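
A minimal usage sketch. The OptimizationStopper import path and its constructor are assumptions inferred from the signature above, and MeanResponseObjective is a hypothetical stand-in for a real objective (e.g. the ContrastiveNeuronObjective below):

from functools import partial

import torch

from openretina.insilico.stimulus_optimization.optimizer import optimize_stimulus
# Assumed import path; see the "Optimization Stopper" section below
from openretina.insilico.stimulus_optimization.optimization_stopper import OptimizationStopper


class MeanResponseObjective:
    """Hypothetical objective; any object with forward(stimulus) -> scalar tensor works."""

    def forward(self, stimulus: torch.Tensor) -> torch.Tensor:
        return stimulus.mean()


# batch x channels x time x n_rows x n_cols, with gradients enabled
stimulus = torch.zeros(1, 2, 30, 18, 16, requires_grad=True)

optimize_stimulus(
    stimulus,
    optimizer_init_fn=partial(torch.optim.SGD, lr=10.0),  # any torch optimizer factory
    objective_object=MeanResponseObjective(),
    optimization_stopper=OptimizationStopper(max_iterations=50),  # constructor signature assumed
)
# stimulus now holds the optimized values and is detached from the graph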

Objective

ContrastiveNeuronObjective

Bases: AbstractObjective

Objective described in [Most discriminative stimuli for functional cell type clustering](https://openreview.net/forum?id=9W6KaAcYlr)

Source code in openretina/insilico/stimulus_optimization/objective.py, lines 161-201
class ContrastiveNeuronObjective(AbstractObjective):
    """Objective described in [Most discriminative stimuli for functional cell type clustering]
    (https://openreview.net/forum?id=9W6KaAcYlr)"""

    def __init__(
        self,
        model,
        on_cluster_idc: list[int],
        off_cluster_idc_list: list[list[int]],
        data_key: str | None,
        response_reducer: ResponseReducer,
        temperature: float = 1.6,
    ):
        super().__init__(model, data_key)
        self._on_cluster_idc = on_cluster_idc
        self._off_cluster_idc_list = off_cluster_idc_list
        self._response_reducer = response_reducer
        self._temperature = temperature

    @staticmethod
    def contrastive_objective(on_score: torch.Tensor, all_scores: torch.Tensor, temperature: float) -> torch.Tensor:
        t = temperature
        obj = (
            (1 / t) * on_score
            - torch.logsumexp((1 / t) * all_scores, dim=0)
            + torch.log(torch.tensor(all_scores.size(0)))
        )
        return obj

    def forward(self, stimulus: torch.Tensor) -> torch.Tensor:
        responses = self.model_forward(stimulus)
        score_per_neuron = self._response_reducer.forward(responses)

        on_score = score_per_neuron[self._on_cluster_idc].mean()
        off_scores = [score_per_neuron[idc].mean() for idc in self._off_cluster_idc_list]
        obj = self.contrastive_objective(
            on_score,
            torch.stack([on_score] + off_scores),
            self._temperature,
        )
        return obj
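
The static contrastive_objective computes obj = on_score / t - logsumexp(all_scores / t) + log(N), i.e. the log of the softmax weight the on-cluster score receives among all cluster scores, shifted by log(N) so that a stimulus driving all clusters equally scores exactly zero. A small numeric sketch with hand-picked scores:

import torch

from openretina.insilico.stimulus_optimization.objective import ContrastiveNeuronObjective

on_score = torch.tensor(2.0)
all_scores = torch.tensor([2.0, 0.5, -1.0])  # on-cluster score first, then the off clusters
obj = ContrastiveNeuronObjective.contrastive_objective(on_score, all_scores, temperature=1.6)
print(float(obj))  # ~0.66: positive because the on cluster responds more than the others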

Regularizer

ChangeNormJointlyClipRangeSeparately

Bases: StimulusPostprocessor

First changes the norm of x jointly across all channels, then clips the values of each channel to its own specified range

Source code in openretina/insilico/stimulus_optimization/regularizer.py, lines 61-96
class ChangeNormJointlyClipRangeSeparately(StimulusPostprocessor):
    """First change the norm and afterward clip the value of x to some specified range"""

    def __init__(
        self,
        min_max_values: Iterable[tuple[Optional[float], Optional[float]]],
        norm: float | None,
    ):
        self._norm = norm
        self._min_max_values = list(min_max_values)

    def process(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[1] == len(self._min_max_values), (
            f"Expected {len(self._min_max_values)} channels in dim 1, got {x.shape=}"
        )

        if self._norm is not None:
            # Re-normalize
            x_norm = torch.linalg.vector_norm(x.view(len(x), -1), dim=-1)
            renorm = x * (self._norm / x_norm).view(len(x), *[1] * (x.dim() - 1))
        else:
            renorm = x

        # Clip
        clipped_array = []
        for i, (min_val, max_val) in enumerate(self._min_max_values):
            clipped = renorm[:, i]
            if min_val is not None or max_val is not None:
                clipped = torch.clamp(clipped, min=min_val, max=max_val)
            clipped_array.append(clipped)
        result = torch.stack(clipped_array, dim=1)

        return result

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._norm=}, {self._min_max_values=})"

RangeRegularizationLoss

Bases: StimulusRegularizationLoss

Source code in openretina/insilico/stimulus_optimization/regularizer.py, lines 23-50
class RangeRegularizationLoss(StimulusRegularizationLoss):
    def __init__(
        self,
        min_max_values: Iterable[tuple[float | None, float | None]],
        max_norm: float | None,
        factor: float = 1.0,
    ):
        self._min_max_values = list(min_max_values)
        self._max_norm = max_norm
        self._factor = factor

    def forward(self, stimulus: torch.Tensor) -> torch.Tensor:
        """Penalizes the stimulus if it is outside the range defined by min_max_values."""
        loss: torch.Tensor = 0.0  # type: ignore
        for i, (min_val, max_val) in enumerate(self._min_max_values):
            stimulus_i = stimulus[:, i]
            if min_val is not None:
                loss += torch.sum(torch.relu(min_val - stimulus_i))
            if max_val is not None:
                loss += torch.sum(torch.relu(stimulus_i - max_val))

        if self._max_norm is not None:
            # Add a loss such that the norm of the stimulus is lower than max_norm
            norm_penalty = torch.relu(torch.norm(stimulus) - self._max_norm)
            loss += norm_penalty

        loss *= self._factor
        return loss

forward(stimulus)

Penalizes the stimulus if it is outside the range defined by min_max_values.

Source code in openretina/insilico/stimulus_optimization/regularizer.py, lines 34-50
def forward(self, stimulus: torch.Tensor) -> torch.Tensor:
    """Penalizes the stimulus if it is outside the range defined by min_max_values."""
    loss: torch.Tensor = 0.0  # type: ignore
    for i, (min_val, max_val) in enumerate(self._min_max_values):
        stimulus_i = stimulus[:, i]
        if min_val is not None:
            loss += torch.sum(torch.relu(min_val - stimulus_i))
        if max_val is not None:
            loss += torch.sum(torch.relu(stimulus_i - max_val))

    if self._max_norm is not None:
        # Add a loss such that the norm of the stimulus is lower than max_norm
        norm_penalty = torch.relu(torch.norm(stimulus) - self._max_norm)
        loss += norm_penalty

    loss *= self._factor
    return loss
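
A numeric sketch of the range penalty (hand-picked values, one channel):

import torch

from openretina.insilico.stimulus_optimization.regularizer import RangeRegularizationLoss

reg = RangeRegularizationLoss(min_max_values=[(-1.0, 1.0)], max_norm=None, factor=1.0)
stimulus = torch.tensor([[[-1.5, 0.0, 2.0]]])  # batch x channels x time
# Only the out-of-range mass is penalized: relu(-1.0 - (-1.5)) + relu(2.0 - 1.0) = 0.5 + 1.0
print(reg.forward(stimulus))  # tensor(1.5000)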

StimulusPostprocessor

Base class for stimulus clippers.

Source code in openretina/insilico/stimulus_optimization/regularizer.py, lines 53-58
class StimulusPostprocessor:
    """Base class for stimulus clippers."""

    def process(self, x: torch.Tensor) -> torch.Tensor:
        """x.shape: batch x channels x time x n_rows x n_cols"""
        return x

process(x)

x.shape: batch x channels x time x n_rows x n_cols

Source code in openretina/insilico/stimulus_optimization/regularizer.py, lines 56-58
def process(self, x: torch.Tensor) -> torch.Tensor:
    """x.shape: batch x channels x time x n_rows x n_cols"""
    return x
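
Custom postprocessors only need to override process. A hypothetical example that zero-centers each stimulus in the batch:

import torch

from openretina.insilico.stimulus_optimization.regularizer import StimulusPostprocessor


class MeanSubtractPostprocessor(StimulusPostprocessor):
    """Hypothetical postprocessor: subtracts each stimulus' mean so it is zero-centered."""

    def process(self, x: torch.Tensor) -> torch.Tensor:
        # x.shape: batch x channels x time x n_rows x n_cols
        mean = x.view(len(x), -1).mean(dim=-1).view(len(x), 1, 1, 1, 1)
        return x - mean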

TemporalGaussianLowPassFilterProcessor

Bases: StimulusPostprocessor

Convolves the stimulus with a 1D Gaussian filter along the temporal dimension, acting as a low-pass filter.

Source code in openretina/insilico/stimulus_optimization/regularizer.py, lines 99-128
class TemporalGaussianLowPassFilterProcessor(StimulusPostprocessor):
    """Uses a 1d Gaussian filter to convolve the stimulus over the temporal dimension.
    This acts as a low pass filter."""

    def __init__(
        self,
        sigma: float,
        kernel_size: int,
        device: str = "cpu",
    ):
        kernel = _gaussian_1d_kernel(sigma, kernel_size)
        self._kernel = kernel.to(device)

    def process(self, x: Float[torch.Tensor, "batch_dim channels time height width"]) -> torch.Tensor:
        """
        Apply a Gaussian low-pass filter to the stimulus tensor along the temporal dimension.

        Arguments:
            x (Tensor): Tensor of shape (batch_dim, channels, time_dim, height, width)
        Returns:
            Tensor: The filtered stimulus tensor.
        """
        # Create the Gaussian kernel in the temporal dimension
        kernel = einops.repeat(self._kernel.to(x.device), "s -> c 1 s 1 1", c=x.shape[1])

        # Apply convolution in the temporal dimension (axis 2)
        # We need to ensure that the kernel is convolved only along the time dimension.
        filtered_stimulus = F.conv3d(x, kernel, padding="same", groups=x.shape[1])

        return filtered_stimulus

process(x)

Apply a Gaussian low-pass filter to the stimulus tensor along the temporal dimension.

Parameters:

x (Tensor, required): Tensor of shape (batch_dim, channels, time_dim, height, width)

Returns: Tensor: The filtered stimulus tensor.

Source code in openretina/insilico/stimulus_optimization/regularizer.py, lines 112-128
def process(self, x: Float[torch.Tensor, "batch_dim channels time height width"]) -> torch.Tensor:
    """
    Apply a Gaussian low-pass filter to the stimulus tensor along the temporal dimension.

    Arguments:
        x (Tensor): Tensor of shape (batch_dim, channels, time_dim, height, width)
    Returns:
        Tensor: The filtered stimulus tensor.
    """
    # Create the Gaussian kernel in the temporal dimension
    kernel = einops.repeat(self._kernel.to(x.device), "s -> c 1 s 1 1", c=x.shape[1])

    # Apply convolution in the temporal dimension (axis 2)
    # We need to ensure that the kernel is convolved only along the time dimension.
    filtered_stimulus = F.conv3d(x, kernel, padding="same", groups=x.shape[1])

    return filtered_stimulus
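
A short usage sketch (sigma and kernel_size are illustrative):

import torch

from openretina.insilico.stimulus_optimization.regularizer import TemporalGaussianLowPassFilterProcessor

smoother = TemporalGaussianLowPassFilterProcessor(sigma=1.0, kernel_size=5)
x = torch.randn(1, 2, 30, 18, 16)
y = smoother.process(x)
print(y.shape)  # torch.Size([1, 2, 30, 18, 16]); padding="same" preserves the input shape
# High temporal frequencies are attenuated, so y varies more smoothly over time than x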

Optimization Stopper