Custom Layers

Specialized neural network layers used as building blocks in core and readout modules.

Convolution Layers

STSeparableBatchConv3d

Spatio-temporal separable 3D convolution with batch-indexed temporal kernels.

STSeparableBatchConv3d

STSeparableBatchConv3d(
    in_channels: int,
    out_channels: int,
    log_speed_dict: dict,
    temporal_kernel_size: int,
    spatial_kernel_size: int,
    spatial_kernel_size2: int | None = None,
    stride: int | tuple[int, int, int] = 1,
    padding: int | str | tuple[int, ...] = 0,
    num_scans: int = 1,
    bias: bool = True,
    subsampling_factor: int = 3,
)

Bases: Module

Spatio-temporal separable convolution layer for processing 3D data.

This layer applies convolution separately in the spatial and temporal dimensions, which is efficient for spatio-temporal data like video or medical images.
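The saving from the separable factorisation is easy to quantify: a full 3D kernel needs `out * in * t * h * w` weights, while the factored form needs roughly `out * in * (t + h * w)` (a temporal factor plus a spatial factor). A back-of-envelope sketch, with channel and kernel sizes made up for illustration:

```python
# Hypothetical sizes, chosen only to illustrate the parameter count.
in_channels, out_channels = 16, 32
t, h, w = 21, 11, 11  # temporal and spatial kernel sizes

# Full 3D convolution: one weight per (out, in, t, h, w) position.
full_params = out_channels * in_channels * t * h * w

# Separable: a temporal factor (out, in, t) plus a spatial factor (out, in, h, w).
separable_params = out_channels * in_channels * (t + h * w)

ratio = full_params / separable_params  # roughly an 18x reduction for these sizes
```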

ATTRIBUTE DESCRIPTION
in_channels

Number of input channels.

TYPE: int

out_channels

Number of output channels.

TYPE: int

temporal_kernel_size

Size of the kernel in the temporal dimension.

TYPE: int

spatial_kernel_size

Size of the kernel in the spatial dimensions.

TYPE: int

spatial_kernel_size2

Size of the kernel in the second spatial dimension.

TYPE: int

stride

Stride of the convolution.

TYPE: int

padding

Padding added to all sides of the input.

TYPE: int

num_scans

Number of scans for batch processing.

TYPE: int

bias

If True, adds a learnable bias to the output.

TYPE: bool

Initializes the STSeparableBatchConv3d layer.

PARAMETER DESCRIPTION
in_channels

Number of channels in the input.

TYPE: int

out_channels

Number of channels produced by the convolution.

TYPE: int

log_speed_dict

Dictionary mapping data keys to log speeds.

TYPE: dict

temporal_kernel_size

Size of the temporal kernel.

TYPE: int

spatial_kernel_size

Size of the spatial kernel.

TYPE: int

spatial_kernel_size2

Size of the second spatial dimension of the kernel.

TYPE: int | None DEFAULT: None

stride

Stride of the convolution. Defaults to 1.

TYPE: int DEFAULT: 1

padding

Zero-padding added to all sides of the input. Defaults to 0.

TYPE: int DEFAULT: 0

num_scans

Number of scans to process in batch. Defaults to 1.

TYPE: int DEFAULT: 1

bias

If True, adds a learnable bias to the output. Defaults to True.

TYPE: bool DEFAULT: True

Source code in openretina/modules/layers/convolutions.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    log_speed_dict: dict,
    temporal_kernel_size: int,
    spatial_kernel_size: int,
    spatial_kernel_size2: int | None = None,
    stride: int | tuple[int, int, int] = 1,
    padding: int | str | tuple[int, ...] = 0,
    num_scans: int = 1,
    bias: bool = True,
    subsampling_factor: int = 3,
):
    """
    Initializes the STSeparableBatchConv3d layer.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        log_speed_dict (dict): Dictionary mapping data keys to log speeds.
        temporal_kernel_size (int): Size of the temporal kernel.
        spatial_kernel_size (int): Size of the spatial kernel.
        spatial_kernel_size2 (int, optional): Size of the second spatial dimension of the kernel.
        stride (int, optional): Stride of the convolution. Defaults to 1.
        padding (int, optional): Zero-padding added to all sides of the input. Defaults to 0.
        num_scans (int, optional): Number of scans to process in batch. Defaults to 1.
        bias (bool, optional): If True, adds a learnable bias to the output. Defaults to True.
    """
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.temporal_kernel_size = temporal_kernel_size
    self.spatial_kernel_size = spatial_kernel_size
    self.spatial_kernel_size2 = spatial_kernel_size2 if spatial_kernel_size2 is not None else spatial_kernel_size
    self.stride = stride
    self.padding = padding
    self.num_scans = num_scans
    self.subsampling_factor = subsampling_factor

    # Initialize temporal weights
    self.sin_weights, self.cos_weights = self.temporal_weights(
        temporal_kernel_size, in_channels, out_channels, subsampling_factor=self.subsampling_factor
    )

    # Initialize spatial weights
    self.weight_spatial = nn.Parameter(
        torch.randn(out_channels, in_channels, 1, self.spatial_kernel_size, self.spatial_kernel_size2) * 0.01
    )

    # Initialize bias if required
    self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
    # Initialize default log speed (batch adaptation term)
    self.register_buffer("_log_speed_default", torch.zeros(1))

    # Store log speeds for each data key
    for key, val in log_speed_dict.items():
        setattr(self, key, val)

forward

forward(input_: tuple[Tensor, str] | Tensor) -> Tensor

Forward pass of the STSeparableBatchConv3d layer.

PARAMETER DESCRIPTION
input_

Input tensor, or a tuple of the input tensor and its data key.

TYPE: tuple[Tensor, str] | Tensor

RETURNS DESCRIPTION
Tensor

torch.Tensor: The output of the convolution.

Source code in openretina/modules/layers/convolutions.py
def forward(self, input_: tuple[torch.Tensor, str] | torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the STSeparableBatchConv3d layer.

    Args:
        input_ (tuple): Tuple containing the input tensor and the data key.

    Returns:
        torch.Tensor: The output of the convolution.
    """
    if type(input_) is torch.Tensor:
        x = input_
        data_key: str | None = None
    else:
        x, data_key = input_

    # Compute temporal kernel based on the provided data key
    if data_key is None:
        log_speed = self._log_speed_default
    else:
        log_speed = getattr(self, "_".join(["log_speed", data_key]))
    self.weight_temporal = compute_temporal_kernel(
        log_speed,
        self.sin_weights,
        self.cos_weights,
        self.temporal_kernel_size,
        self.subsampling_factor,
    )

    # Assemble the complete weight tensor for convolution
    # o - output channels, i - input channels, t - temporal kernel size
    # x - empty dimension, h - spatial kernel size, w - second spatial kernel size
    self.weight = torch.einsum("oitxx,oixhw->oithw", self.weight_temporal, self.weight_spatial)

    # Perform the convolution
    self.conv = F.conv3d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding)
    return self.conv
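The einsum in `forward` combines the temporal factor of shape `(o, i, t, 1, 1)` with the spatial factor of shape `(o, i, 1, h, w)` into a full `(o, i, t, h, w)` kernel; because the contracted axes are singletons, each entry is simply the product of the two factors. A small numpy sketch with illustrative shapes:

```python
import numpy as np

o, i, t, h, w = 2, 3, 5, 4, 4
weight_temporal = np.random.randn(o, i, t, 1, 1)
weight_spatial = np.random.randn(o, i, 1, h, w)

# "oitxx,oixhw->oithw": the size-1 'x' axes are contracted away,
# leaving an outer product of the temporal and spatial factors.
weight = np.einsum("oitxx,oixhw->oithw", weight_temporal, weight_spatial)

# Equivalent to broadcasting the two factors and multiplying element-wise.
assert np.allclose(weight, weight_temporal * weight_spatial)
```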

TorchFullConv3D

TorchFullConv3D(
    in_channels: int,
    out_channels: int,
    log_speed_dict: dict,
    temporal_kernel_size: int,
    spatial_kernel_size: int,
    spatial_kernel_size2: Optional[int] = None,
    stride: int = 1,
    padding: int = 0,
    bias: bool = True,
    num_scans=1,
)

Bases: Module

Source code in openretina/modules/layers/convolutions.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    log_speed_dict: dict,
    temporal_kernel_size: int,
    spatial_kernel_size: int,
    spatial_kernel_size2: Optional[int] = None,
    stride: int = 1,
    padding: int = 0,
    bias: bool = True,
    num_scans=1,
):
    super().__init__()
    # Store log speeds for each data key
    for key, val in log_speed_dict.items():
        setattr(self, key, val)

    if spatial_kernel_size2 is None:
        spatial_kernel_size2 = spatial_kernel_size

    # Initialize default log speed (batch adaptation term)
    self.register_buffer("_log_speed_default", torch.zeros(1))

    self.conv = nn.Conv3d(
        in_channels,
        out_channels,
        (temporal_kernel_size, spatial_kernel_size, spatial_kernel_size2),
        stride=stride,
        padding=padding,
        bias=bias,
    )

TorchSTSeparableConv3D

TorchSTSeparableConv3D(
    in_channels: int,
    out_channels: int,
    log_speed_dict: dict,
    temporal_kernel_size: int,
    spatial_kernel_size: int,
    spatial_kernel_size2: Optional[int] = None,
    stride: int | tuple[int, int, int] = 1,
    padding: int | tuple[int, int, int] | str = 0,
    bias: bool = True,
    num_scans=1,
)

Bases: Module

Source code in openretina/modules/layers/convolutions.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    log_speed_dict: dict,
    temporal_kernel_size: int,
    spatial_kernel_size: int,
    spatial_kernel_size2: Optional[int] = None,
    stride: int | tuple[int, int, int] = 1,
    padding: int | tuple[int, int, int] | str = 0,
    bias: bool = True,
    num_scans=1,
):
    super().__init__()
    # Store log speeds for each data key
    for key, val in log_speed_dict.items():
        setattr(self, key, val)

    if spatial_kernel_size2 is None:
        spatial_kernel_size2 = spatial_kernel_size

    # Initialize default log speed (batch adaptation term)
    self.register_buffer("_log_speed_default", torch.zeros(1))

    self.space_conv = nn.Conv3d(
        in_channels,
        out_channels,
        (1, spatial_kernel_size, spatial_kernel_size2),
        stride=stride,
        padding=padding,
        bias=bias,
    )
    self.time_conv = nn.Conv3d(
        out_channels, out_channels, (temporal_kernel_size, 1, 1), stride=stride, padding=padding, bias=bias
    )

compute_temporal_kernel

compute_temporal_kernel(
    log_speed,
    sin_weights,
    cos_weights,
    length: int,
    subsampling_factor: int,
) -> Tensor

Computes the temporal kernel for the convolution.

PARAMETER DESCRIPTION
log_speed

Logarithm of the speed factor.

TYPE: Parameter

sin_weights

Sinusoidal weights.

TYPE: Parameter

cos_weights

Cosine weights.

TYPE: Parameter

length

Length of the temporal kernel.

TYPE: int

subsampling_factor

Factor by which the sine and cosine weights are subsampled.

TYPE: int

RETURNS DESCRIPTION
Tensor

torch.Tensor: The temporal kernel.

Source code in openretina/modules/layers/convolutions.py
def compute_temporal_kernel(log_speed, sin_weights, cos_weights, length: int, subsampling_factor: int) -> torch.Tensor:
    """
    Computes the temporal kernel for the convolution.

    Args:
        log_speed (torch.nn.Parameter): Logarithm of the speed factor.
        sin_weights (torch.nn.Parameter): Sinusoidal weights.
        cos_weights (torch.nn.Parameter): Cosine weights.
        length (int): Length of the temporal kernel.
        subsampling_factor (int): the factor by which to subsample the sin and cos weights

    Returns:
        torch.Tensor: The temporal kernel.
    """
    stretches = torch.exp(log_speed)
    sines, cosines = STSeparableBatchConv3d.temporal_basis(stretches, length, subsampling_factor)
    weights_temporal = torch.sum(sin_weights[:, :, :, None] * sines[None, None, ...], dim=2) + torch.sum(
        cos_weights[:, :, :, None] * cosines[None, None, ...], dim=2
    )
    return weights_temporal[..., None, None]
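The weighted sum above has a simple structure: each `(out, in)` pair holds coefficients over a bank of sine and cosine bases, and the temporal kernel is their linear combination. `temporal_basis` itself is not shown here, so the bases below are made up; the library stretches the real ones by `exp(log_speed)`. A numpy sketch of the combination step:

```python
import numpy as np

o, i, n_basis, length = 2, 3, 4, 20
time = np.linspace(0.0, 1.0, length)

# Made-up Fourier-style bases; the library's temporal_basis differs and is
# additionally stretched by exp(log_speed) per data key.
freqs = np.arange(1, n_basis + 1)[:, None]           # (n_basis, 1)
sines = np.sin(2 * np.pi * freqs * time[None, :])    # (n_basis, length)
cosines = np.cos(2 * np.pi * freqs * time[None, :])

sin_weights = np.random.randn(o, i, n_basis)
cos_weights = np.random.randn(o, i, n_basis)

# Mirrors compute_temporal_kernel: sum over the basis axis,
# then append two singleton spatial axes.
weights_temporal = (sin_weights[:, :, :, None] * sines[None, None]).sum(axis=2) \
                 + (cos_weights[:, :, :, None] * cosines[None, None]).sum(axis=2)
kernel = weights_temporal[..., None, None]
```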

Regularizers

Laplace

Laplace(
    padding: int | None = None,
    filter_size: int = 3,
    persistent_buffer: bool = True,
)

Bases: Module

Laplace filter for a stack of data. Utilized as the input weight regularizer.


Source code in openretina/modules/layers/regularizers.py
def __init__(
    self,
    padding: int | None = None,
    filter_size: int = 3,
    persistent_buffer: bool = True,
):
    """Laplace filter for a stack of data"""

    super().__init__()
    if filter_size == 3:
        kernel = LAPLACE_3x3
    elif filter_size == 5:
        kernel = LAPLACE_5x5
    elif filter_size == 7:
        kernel = LAPLACE_7x7
    else:
        raise ValueError(f"Unsupported filter size {filter_size}")

    self.register_buffer("filter", torch.from_numpy(kernel), persistent=persistent_buffer)
    self.padding_size = kernel.shape[-1] // 2 if padding is None else padding
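The intuition behind using a Laplace filter as a weight regularizer: correlating a filter with a discrete Laplacian measures its local curvature, so smooth filters incur little penalty and noisy ones are penalised. A numpy sketch; the standard 3x3 Laplacian below is an assumption, since the `LAPLACE_3x3` constant is not shown and may be scaled differently:

```python
import numpy as np

# Standard 3x3 discrete Laplacian (assumed; the library's LAPLACE_3x3 may differ in scale).
laplace = np.array([[0.0,  1.0, 0.0],
                    [1.0, -4.0, 1.0],
                    [0.0,  1.0, 0.0]])

def correlate2d_valid(img, kern):
    """Plain 'valid'-mode 2D correlation, written out for clarity."""
    kh, kw = kern.shape
    out = np.zeros((img.shape[0] - kh + 1, img.shape[1] - kw + 1))
    for y in range(out.shape[0]):
        for x in range(out.shape[1]):
            out[y, x] = (img[y:y + kh, x:x + kw] * kern).sum()
    return out

flat = np.ones((5, 5))            # a perfectly smooth "filter"
spike = np.zeros((3, 3))          # a single-pixel "filter"
spike[1, 1] = 1.0
# The Laplacian of a constant is zero, while a spike has large curvature.
smooth_response = correlate2d_valid(flat, laplace)
spike_response = correlate2d_valid(spike, laplace)
```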

LaplaceL2norm

LaplaceL2norm(padding=None)

Bases: Module

Normalized Laplace regularizer for a 2D convolutional layer. Returns |laplace(filters)| / |filters|.

Source code in openretina/modules/layers/regularizers.py
def __init__(self, padding=None):
    super().__init__()
    self.laplace = Laplace(padding=padding)

FlatLaplaceL23dnorm

FlatLaplaceL23dnorm(padding: int | None = None)

Bases: Module

Normalized Laplace regularizer for the spatial component of a separable 3D convolutional layer. Returns |laplace(filters)| / |filters|.

Source code in openretina/modules/layers/regularizers.py
def __init__(self, padding: int | None = None):
    super().__init__()
    self.laplace = Laplace(padding=padding)

GaussianLaplaceL2

GaussianLaplaceL2(kernel, padding=None)

Bases: Module

Laplace regularizer, with a Gaussian mask, for a single 2D convolutional layer.

PARAMETER DESCRIPTION
kernel

Size of the convolutional kernel of the filter that is getting regularized

padding

Controls the amount of zero-padding for the convolution operation.

TYPE: int DEFAULT: None

Source code in openretina/modules/layers/regularizers.py
def __init__(self, kernel, padding=None):
    """
    Args:
        kernel: Size of the convolutional kernel of the filter that is getting regularized
        padding (int): Controls the amount of zero-padding for the convolution operation.
    """
    super().__init__()

    self.laplace = Laplace(padding=padding)
    self.kernel = (kernel, kernel) if isinstance(kernel, int) else kernel
    sigma = min(*self.kernel) / 4
    self.gaussian2d = torch.from_numpy(gaussian2d(size=(*self.kernel,), sigma=sigma))

Scaling Layers

Bias3DLayer

Bias3DLayer(channels: int, initial: float = 0.0, **kwargs)

Bases: Module

Source code in openretina/modules/layers/scaling.py
def __init__(self, channels: int, initial: float = 0.0, **kwargs):
    super().__init__(**kwargs)

    self.bias = torch.nn.Parameter(torch.empty((1, channels, 1, 1, 1)).fill_(initial))

Scale2DLayer

Scale2DLayer(
    num_channels: int, initial: float = 1.0, **kwargs
)

Bases: Module

Source code in openretina/modules/layers/scaling.py
def __init__(self, num_channels: int, initial: float = 1.0, **kwargs):
    super().__init__(**kwargs)

    self.scale = torch.nn.Parameter(torch.empty((1, num_channels, 1, 1)).fill_(initial))

FiLM

FiLM(num_features: int, cond_dim: int)

Bases: Module

FiLM (Feature-wise Linear Modulation) is a neural network module that applies conditional scaling and shifting to input features.

This module takes input features and a conditioning tensor, computes scaling (gamma) and shifting (beta) parameters from the conditioning tensor, and applies these parameters to the input features. The result is a modulated output that can adapt based on the provided conditions.

PARAMETER DESCRIPTION
num_features

The number of features in the input tensor.

TYPE: int

cond_dim

The dimensionality of the conditioning tensor.

TYPE: int

RETURNS DESCRIPTION
Tensor

The modulated output tensor after applying the scaling and shifting.

Source code in openretina/modules/layers/scaling.py
def __init__(self, num_features: int, cond_dim: int):
    super(FiLM, self).__init__()
    self.num_features = num_features
    self.cond_dim = cond_dim

    self.fc_gamma = nn.Linear(cond_dim, num_features)
    self.fc_beta = nn.Linear(cond_dim, num_features)

    # To avoid perturbations in early epochs, we set these defaults to match the identity function
    nn.init.normal_(self.fc_gamma.weight, mean=0.0, std=0.01)
    nn.init.constant_(self.fc_gamma.bias, 1.0)

    nn.init.normal_(self.fc_beta.weight, mean=0.0, std=0.01)
    nn.init.constant_(self.fc_beta.bias, 0.0)
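With `fc_gamma` initialised to near-zero weights and bias 1, and `fc_beta` to near-zero weights and bias 0, the modulation `gamma * x + beta` starts out close to the identity, as the comment in the source notes. A numpy sketch of the FiLM computation at those initialisation targets (the module's `forward` is not shown above, so the exact broadcasting is an assumption):

```python
import numpy as np

num_features, cond_dim = 4, 3

# Linear-layer parameters exactly at their initialisation targets:
# gamma(cond) = 1 and beta(cond) = 0 for any conditioning vector.
W_gamma, b_gamma = np.zeros((num_features, cond_dim)), np.ones(num_features)
W_beta, b_beta = np.zeros((num_features, cond_dim)), np.zeros(num_features)

cond = np.random.randn(cond_dim)
gamma = W_gamma @ cond + b_gamma      # -> all ones
beta = W_beta @ cond + b_beta         # -> all zeros

x = np.random.randn(2, num_features)  # (batch, features)
out = gamma * x + beta                # feature-wise linear modulation
```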

GRU Layers

ConvGRUCell

ConvGRUCell(
    input_channels,
    rec_channels,
    input_kern: int,
    rec_kern: int,
    groups: int = 1,
    gamma_rec: int = 0,
    pad_input: bool = True,
    **kwargs,
)

Bases: Module

Convolutional GRU cell from: https://github.com/sinzlab/Sinz2018_NIPS/blob/master/nips2018/architectures/cores.py

Source code in openretina/modules/layers/gru.py
def __init__(
    self,
    input_channels,
    rec_channels,
    input_kern: int,
    rec_kern: int,
    groups: int = 1,
    gamma_rec: int = 0,
    pad_input: bool = True,
    **kwargs,
):
    super().__init__()

    input_padding = input_kern // 2 if pad_input else 0
    rec_padding = rec_kern // 2

    self.rec_channels = rec_channels
    self._shrinkage = 0 if pad_input else input_kern - 1
    self.groups = groups

    self.gamma_rec = gamma_rec
    self.reset_gate_input = nn.Conv2d(
        input_channels,
        rec_channels,
        input_kern,
        padding=input_padding,
        groups=self.groups,
    )
    self.reset_gate_hidden = nn.Conv2d(
        rec_channels,
        rec_channels,
        rec_kern,
        padding=rec_padding,
        groups=self.groups,
    )

    self.update_gate_input = nn.Conv2d(
        input_channels,
        rec_channels,
        input_kern,
        padding=input_padding,
        groups=self.groups,
    )
    self.update_gate_hidden = nn.Conv2d(
        rec_channels,
        rec_channels,
        rec_kern,
        padding=rec_padding,
        groups=self.groups,
    )

    self.out_gate_input = nn.Conv2d(
        input_channels,
        rec_channels,
        input_kern,
        padding=input_padding,
        groups=self.groups,
    )
    self.out_gate_hidden = nn.Conv2d(
        rec_channels,
        rec_channels,
        rec_kern,
        padding=rec_padding,
        groups=self.groups,
    )

    self.apply(self.init_conv)
    self.register_parameter("_prev_state", None)

GRU_Module

GRU_Module(
    input_channels,
    rec_channels,
    input_kern,
    rec_kern,
    groups: int = 1,
    gamma_rec: int = 0,
    pad_input: bool = True,
    **kwargs,
)

Bases: Module

A GRU module for video data to add between the core and the readout. Receives as input the output of a 3D core. Expected dimensions: (Batch, Channels, Frames, Height, Width) or (Channels, Frames, Height, Width). The input is fed frame by frame to a convolutional GRU cell, iterating over the frames dimension. The output has the same dimensions as the input.

Source code in openretina/modules/layers/gru.py
def __init__(
    self,
    input_channels,
    rec_channels,
    input_kern,
    rec_kern,
    groups: int = 1,
    gamma_rec: int = 0,
    pad_input: bool = True,
    **kwargs,
):
    """
    A GRU module for video data to add between the core and the readout.
    Receives as input the output of a 3Dcore. Expected dimensions:
        - (Batch, Channels, Frames, Height, Width) or (Channels, Frames, Height, Width)
    The input is fed sequentially to a convolutional GRU cell, based on the frames channel.
    The output has the same dimensions as the input.
    """
    super().__init__()
    self.gru = ConvGRUCell(
        input_channels,
        rec_channels,
        input_kern,
        rec_kern,
        groups=groups,
        gamma_rec=gamma_rec,
        pad_input=pad_input,
    )

forward

forward(input_)

Forward pass definition based on https://github.com/sinzlab/Sinz2018_NIPS/blob/3a99f7a6985ae8dec17a5f2c54f550c2cbf74263/nips2018/architectures/cores.py#L556. Modified to also accept 4-dimensional inputs (assuming no batch dimension is provided).

Source code in openretina/modules/layers/gru.py
def forward(self, input_):
    """
    Forward pass definition based on
    https://github.com/sinzlab/Sinz2018_NIPS/blob/3a99f7a6985ae8dec17a5f2c54f550c2cbf74263/nips2018/architectures/cores.py#L556
    Modified to also accept 4 dimensional inputs (assuming no batch dimension is provided).
    """
    x, data_key = input_
    if len(x.shape) not in [4, 5]:
        raise RuntimeError(
            f"Expected 4D (unbatched) or 5D (batched) input to ConvGRUCell, but got input of size: {x.shape}"
        )

    batch = True
    if len(x.shape) == 4:
        batch = False
        x = torch.unsqueeze(x, dim=0)

    states = []
    hidden = None
    frame_pos = 2

    for frame in range(x.shape[frame_pos]):
        slice_channel = [frame if frame_pos == i else slice(None) for i in range(len(x.shape))]
        hidden = self.gru(x[slice_channel], hidden)
        states.append(hidden)
    out = torch.stack(states, frame_pos)
    if not batch:
        out = torch.squeeze(out, dim=0)
    return out
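The frame loop in `forward` is independent of the cell itself: slice one frame from the time axis, feed it together with the previous hidden state, and stack the resulting states back along the time axis. A numpy sketch with a toy cell (a leaky accumulator standing in for `ConvGRUCell`):

```python
import numpy as np

def toy_cell(x_t, hidden):
    # Stand-in for ConvGRUCell: a leaky accumulator over frames.
    return x_t if hidden is None else 0.5 * hidden + 0.5 * x_t

x = np.random.randn(1, 8, 10, 6, 6)  # (Batch, Channels, Frames, Height, Width)
frame_pos = 2

states, hidden = [], None
for frame in range(x.shape[frame_pos]):
    hidden = toy_cell(x[:, :, frame], hidden)  # (Batch, Channels, Height, Width)
    states.append(hidden)
out = np.stack(states, axis=frame_pos)         # same dimensions as the input
```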

Ensemble

EnsembleModel

EnsembleModel(*members: Module)

Bases: Module

An ensemble model consisting of several individual ensemble members.

ATTRIBUTE DESCRIPTION
*members

PyTorch modules representing the members of the ensemble.

Initializes EnsembleModel.

Source code in openretina/modules/layers/ensemble.py
def __init__(self, *members: torch.nn.Module):
    """Initializes EnsembleModel."""
    super().__init__()
    self.members = self._module_container_cls(members)

forward

forward(x: Tensor, *args, **kwargs) -> Tensor

Calculates the forward pass through the ensemble.

The input is passed through all individual members of the ensemble and their outputs are averaged.

PARAMETER DESCRIPTION
x

A tensor representing the input to the ensemble.

TYPE: Tensor

*args

Additional arguments will be passed to all ensemble members.

DEFAULT: ()

**kwargs

Additional keyword arguments will be passed to all ensemble members.

DEFAULT: {}

RETURNS DESCRIPTION
Tensor

A tensor representing the ensemble's output.

Source code in openretina/modules/layers/ensemble.py
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Calculates the forward pass through the ensemble.

    The input is passed through all individual members of the ensemble and their outputs are averaged.

    Args:
        x: A tensor representing the input to the ensemble.
        *args: Additional arguments will be passed to all ensemble members.
        **kwargs: Additional keyword arguments will be passed to all ensemble members.

    Returns:
        A tensor representing the ensemble's output.
    """
    outputs = [m(x, *args, **kwargs) for m in self.members]
    mean_output = torch.stack(outputs, dim=0).mean(dim=0)
    return mean_output
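The ensemble output is just the element-wise mean of the members' outputs. A numpy sketch with two toy members standing in for trained models:

```python
import numpy as np

members = [lambda x: 2.0 * x, lambda x: 4.0 * x]  # toy stand-ins for ensemble members

x = np.random.randn(3, 5)
outputs = [m(x) for m in members]
mean_output = np.stack(outputs, axis=0).mean(axis=0)  # mean of 2x and 4x is 3x
```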

Reducers

WeightedChannelSumLayer

WeightedChannelSumLayer(
    init_channel_weights: tuple[float, ...],
    trainable: bool = False,
)

Bases: Module

A layer that reduces multi-channel input to a single channel by computing a weighted sum across the channel dimension using the provided weights. If the input has only a single channel, it is returned unchanged. One use case of this layer is converting a multi-color input into a grey-scale input for the model.

Source code in openretina/modules/layers/reducers.py
def __init__(self, init_channel_weights: tuple[float, ...], trainable: bool = False):
    super().__init__()

    # add the channel weights
    self.channel_weights = nn.Parameter(torch.tensor(init_channel_weights), requires_grad=trainable)

forward

forward(x: Tensor) -> Tensor

If the input is not already single-channel (i.e. greyscale), take a weighted sum over channels.

Source code in openretina/modules/layers/reducers.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """If the input is not already single-channel (i.e. greyscale), take a weighted sum over channels ."""

    if x.shape[1] == 1:
        return x

    weighted_input = x * (self.channel_weights.view(1, -1, 1, 1, 1))
    squashed = torch.sum(weighted_input, dim=1, keepdim=True)
    return squashed
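For example, reducing a two-colour stimulus to a single channel with weights (0.7, 0.3). A numpy sketch of the same broadcast-multiply-and-sum, assuming the 5D (batch, channels, time, height, width) layout the layer's `view(1, -1, 1, 1, 1)` implies:

```python
import numpy as np

channel_weights = np.array([0.7, 0.3])
x = np.random.randn(2, 2, 10, 4, 4)  # (batch, channels, time, height, width)

# Broadcast the per-channel weights, then sum the channel axis, keeping it as size 1.
weighted = x * channel_weights.reshape(1, -1, 1, 1, 1)
squashed = weighted.sum(axis=1, keepdims=True)
```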