Skip to content

Core Modules

Feature extraction modules that process spatio-temporal visual input into learned representations.

SimpleCoreWrapper

The primary convolutional core used in most models. Stacks spatio-temporal separable Conv3D layers with configurable regularization (Laplace, group sparsity, temporal smoothness).

SimpleCoreWrapper

SimpleCoreWrapper(
    channels: tuple[int, ...],
    temporal_kernel_sizes: tuple[int, ...],
    spatial_kernel_sizes: tuple[int, ...],
    gamma_input: float,
    gamma_temporal: float,
    gamma_in_sparse: float,
    gamma_hidden: float,
    dropout_rate: float = 0.0,
    cut_first_n_frames: int = 30,
    maxpool_every_n_layers: int | None = None,
    downsample_input_kernel_size: tuple[int, int, int]
    | None = None,
    input_padding: bool
    | int
    | str
    | tuple[int, int, int] = False,
    hidden_padding: bool
    | int
    | str
    | tuple[int, int, int] = True,
    color_squashing_weights: tuple[float, ...]
    | None = None,
    convolution_type: str = "custom_separable",
    n_neurons_dict: dict[str, int] | None = None,
)

Bases: Core

Source code in openretina/modules/core/base_core.py
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
def __init__(
    self,
    channels: tuple[int, ...],
    temporal_kernel_sizes: tuple[int, ...],
    spatial_kernel_sizes: tuple[int, ...],
    gamma_input: float,
    gamma_temporal: float,
    gamma_in_sparse: float,
    gamma_hidden: float,
    dropout_rate: float = 0.0,
    cut_first_n_frames: int = 30,
    maxpool_every_n_layers: int | None = None,
    downsample_input_kernel_size: tuple[int, int, int] | None = None,
    input_padding: bool | int | str | tuple[int, int, int] = False,
    hidden_padding: bool | int | str | tuple[int, int, int] = True,
    color_squashing_weights: tuple[float, ...] | None = None,
    convolution_type: str = "custom_separable",
    n_neurons_dict: dict[str, int] | None = None,  # accepted only for interface compatibility; unused
):
    """Build a stack of spatio-temporal Conv3D layers with per-layer norm, bias and ELU.

    Args:
        channels: Channel count per stage, starting with the input channel
            count; needs at least two entries (input and output), yielding
            ``len(channels) - 1`` conv layers.
        temporal_kernel_sizes: One temporal kernel size per conv layer;
            length must equal ``len(channels) - 1``.
        spatial_kernel_sizes: One spatial kernel size per conv layer; must
            have the same length as ``temporal_kernel_sizes``.
        gamma_input: Weight of the spatial Laplace input regularizer.
        gamma_temporal: Weight of the temporal smoothness regularizer.
        gamma_in_sparse: Weight of the input group-sparsity regularizer.
        gamma_hidden: Weight of the hidden group-sparsity regularizer.
        dropout_rate: If > 0, a ``Dropout3d`` is appended after each layer's
            nonlinearity.
        cut_first_n_frames: Number of leading time frames dropped from the
            core output.
        maxpool_every_n_layers: If set, a spatial (1, 2, 2) max-pool is
            appended after every layer whose index is a multiple of this
            value (including layer 0).
        downsample_input_kernel_size: If set, the input is average-pooled
            with this kernel before the conv stack.
        input_padding: Padding for the first conv layer. ``True`` maps to
            ``"same"``, ``False`` to 0; any other value is passed through.
        hidden_padding: Padding for all subsequent conv layers, same
            convention as ``input_padding``.
        color_squashing_weights: If set, input channels are collapsed to a
            single channel with these weights; requires ``channels[0] == 1``.
        convolution_type: Name resolved via ``get_conv_class`` to select the
            convolution implementation.
        n_neurons_dict: Ignored; kept so shared configs for other cores work.

    Raises:
        ValueError: If the channel/kernel configuration is inconsistent, or
            color squashing is requested with multi-channel input.
    """
    # Input validation
    if len(channels) < 2:
        raise ValueError(f"At least two channels required (input and output channel), {channels=}")
    if len(temporal_kernel_sizes) != len(channels) - 1:
        raise ValueError(
            f"{len(channels) - 1} layers, but only {len(temporal_kernel_sizes)} "
            f"temporal kernel sizes. {channels=} {temporal_kernel_sizes=}"
        )
    if len(temporal_kernel_sizes) != len(spatial_kernel_sizes):
        raise ValueError(
            f"Temporal and spatial kernel sizes must have the same length."
            f"{temporal_kernel_sizes=} {spatial_kernel_sizes=}"
        )
    if color_squashing_weights is not None and channels[0] != 1:
        raise ValueError(
            f"Number of input channels (set to {channels[0]}) must be 1 when squashing multi-channel (color)\
                  to single-channel (greyscale) input."
        )

    super().__init__()
    self.convolution_type = convolution_type
    self.gamma_input = gamma_input
    self.gamma_temporal = gamma_temporal
    self.gamma_in_sparse = gamma_in_sparse
    self.gamma_hidden = gamma_hidden
    self._cut_first_n_frames = cut_first_n_frames
    # Stored as a list (or None); consumed by avg_pool3d in forward().
    self._downsample_input_kernel_size = (
        list(downsample_input_kernel_size) if downsample_input_kernel_size is not None else None
    )
    self.color_squashing_weights = color_squashing_weights
    # Cutting frames without input padding shortens the temporal output in a
    # way that is easy to miss, so flag the combination loudly.
    if self._cut_first_n_frames and not input_padding:
        warnings.warn(
            (
                "Cutting frames from the core output can lead to unexpected results if the input is not padded."
                f"{self._cut_first_n_frames=}, {input_padding=}. Double check the core output shape."
            ),
            UserWarning,
            stacklevel=2,
        )

    # Regularizer modules applied to the input conv weights (spatial and
    # temporal Laplace penalties).
    self._input_weights_regularizer_spatial = FlatLaplaceL23dnorm(padding=0)
    self._input_weights_regularizer_temporal = Laplace1d(padding=0, persistent_buffer=False)

    self.features = torch.nn.Sequential()
    # Optional channel-collapsing layer, applied before the conv stack.
    self.color_squashing_layer = (
        WeightedChannelSumLayer(self.color_squashing_weights) if self.color_squashing_weights is not None else None
    )

    # One sub-sequential per layer: conv -> batchnorm -> bias -> ELU
    # (-> dropout) (-> maxpool).
    for layer_id, (num_in_channels, num_out_channels) in enumerate(zip(channels[:-1], channels[1:], strict=True)):
        layer: dict[str, torch.nn.Module] = OrderedDict()
        padding_to_use = input_padding if layer_id == 0 else hidden_padding
        # explicitly check against bools as the type can also be an int or a tuple
        if padding_to_use is True:
            padding: str | int | tuple[int, int, int] = "same"
        elif padding_to_use is False:
            padding = 0
        else:
            padding = padding_to_use

        conv_class = get_conv_class(self.convolution_type)
        layer["conv"] = conv_class(
            num_in_channels,
            num_out_channels,
            log_speed_dict={},
            temporal_kernel_size=temporal_kernel_sizes[layer_id],
            spatial_kernel_size=spatial_kernel_sizes[layer_id],
            bias=False,
            padding=padding,
        )

        # Conv has bias=False; a separate Bias3DLayer adds the bias after the norm.
        layer["norm"] = torch.nn.BatchNorm3d(num_out_channels, momentum=0.1, affine=True)
        layer["bias"] = Bias3DLayer(num_out_channels)
        layer["nonlin"] = torch.nn.ELU()
        if dropout_rate > 0.0:
            layer["dropout"] = torch.nn.Dropout3d(p=dropout_rate)
        if maxpool_every_n_layers is not None and (layer_id % maxpool_every_n_layers) == 0:
            layer["pool"] = torch.nn.MaxPool3d((1, 2, 2))
        self.features.add_module(f"layer{layer_id}", torch.nn.Sequential(layer))  # type: ignore

forward

forward(input_: Tensor) -> Tensor
Source code in openretina/modules/core/base_core.py
161
162
163
164
165
166
167
168
169
170
171
def forward(self, input_: torch.Tensor) -> torch.Tensor:
    """Run optional preprocessing, the conv stack, and leading-frame cutting.

    The first ``self._cut_first_n_frames`` frames are dropped from the
    output along the time axis to keep compatibility with hoefling model
    scores.
    """
    x = input_
    # Optionally collapse the color channels into one channel.
    squash = self.color_squashing_layer
    if squash is not None:
        x = squash(x)
    # Optionally downsample the input via average pooling.
    pool_kernel = self._downsample_input_kernel_size
    if pool_kernel is not None:
        x = torch.nn.functional.avg_pool3d(x, kernel_size=pool_kernel)  # type: ignore
    features_out = self.features(x)
    # Drop the leading frames along the temporal axis.
    return features_out[:, :, self._cut_first_n_frames :, :, :]

regularizer

regularizer() -> Tensor
Source code in openretina/modules/core/base_core.py
219
220
221
222
223
224
225
226
227
228
229
230
def regularizer(self) -> torch.Tensor:
    """Return the weighted sum of the core's regularization penalties.

    Each penalty function is evaluated lazily — only when its weight is
    non-zero — so unused regularizers cost nothing.

    Returns:
        A scalar tensor. Previously this returned a plain ``float`` 0.0 when
        every weight was zero, violating the annotated return type; the
        result is now always converted to a tensor.
    """
    res: torch.Tensor | float = 0.0
    for weight, reg_fn in [
        (self.gamma_input, self.spatial_laplace),
        (self.gamma_hidden, self.group_sparsity),
        (self.gamma_temporal, self.temporal_smoothness),
        (self.gamma_in_sparse, self.group_sparsity_0),
    ]:
        # lazy calculation of regularization functions
        if weight != 0.0:
            res = res + weight * reg_fn()
    # torch.as_tensor is a no-op for tensor inputs (device/grad preserved)
    # and wraps the float 0.0 accumulator when no weight was active.
    return torch.as_tensor(res)

save_weight_visualizations

save_weight_visualizations(
    folder_path: str,
    file_format: str = "jpg",
    state_suffix: str = "",
) -> None
Source code in openretina/modules/core/base_core.py
239
240
241
242
243
def save_weight_visualizations(self, folder_path: str, file_format: str = "jpg", state_suffix: str = "") -> None:
    """Write weight visualizations for each conv layer into per-layer subfolders.

    Creates ``<folder_path>/weights_layer_<i>`` for every layer in
    ``self.features`` and delegates the actual rendering to that layer's
    ``conv`` submodule.
    """
    for layer_idx, feature_layer in enumerate(self.features):
        layer_dir = os.path.join(folder_path, f"weights_layer_{layer_idx}")
        os.makedirs(layer_dir, exist_ok=True)
        feature_layer.conv.save_weight_visualizations(layer_dir, file_format, state_suffix)  # type: ignore

DummyCore

DummyCore

DummyCore(cut_first_n_frames: int | None = None, **kwargs)

Bases: Core

A dummy core that does nothing. Used for readout-only models, such as the LNP model.

Source code in openretina/modules/core/base_core.py
251
252
253
def __init__(self, cut_first_n_frames: int | None = None, **kwargs):
    """Initialize the no-op core.

    Args:
        cut_first_n_frames: Stored for interface parity with the real cores;
            how it is consumed is not visible in this excerpt.
        **kwargs: Accepted and ignored, so configs written for other cores
            can still construct a DummyCore.
    """
    super().__init__()
    self._cut_first_n_frames = cut_first_n_frames

ConvGRUCore

A recurrent core using convolutional GRU cells for temporal processing.

ConvGRUCore

ConvGRUCore(
    channels: tuple[int, ...],
    n_neurons_dict: dict[str, int] | None = None,
    temporal_kernel_sizes=(21,),
    spatial_kernel_sizes=(14,),
    gamma_hidden: float = 0.0,
    gamma_input: float = 0.0,
    gamma_in_sparse: float = 0.0,
    gamma_temporal: float = 0.0,
    final_nonlinearity: bool = True,
    bias: bool = True,
    input_padding: bool = False,
    hidden_padding: bool = True,
    batch_norm: bool = True,
    batch_norm_scale: bool = True,
    batch_norm_momentum: float = 0.1,
    laplace_padding: int | None = 0,
    batch_adaptation: bool = False,
    use_avg_reg: bool = False,
    nonlinearity: str = "ELU",
    conv_type: str = "custom_separable",
    use_gru: bool = False,
    use_projections: bool = False,
    gru_kwargs: dict[str, int | float] | None = None,
    **kwargs,
)

Bases: Core3d, Module

Source code in openretina/modules/core/gru_core.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def __init__(
    self,
    channels: tuple[int, ...],
    n_neurons_dict: dict[str, int] | None = None,
    temporal_kernel_sizes=(21,),
    spatial_kernel_sizes=(14,),
    gamma_hidden: float = 0.0,
    gamma_input: float = 0.0,
    gamma_in_sparse: float = 0.0,
    gamma_temporal: float = 0.0,
    final_nonlinearity: bool = True,
    bias: bool = True,
    input_padding: bool = False,
    hidden_padding: bool = True,
    batch_norm: bool = True,
    batch_norm_scale: bool = True,
    batch_norm_momentum: float = 0.1,
    laplace_padding: int | None = 0,
    batch_adaptation: bool = False,
    use_avg_reg: bool = False,
    nonlinearity: str = "ELU",
    conv_type: str = "custom_separable",
    use_gru: bool = False,
    use_projections: bool = False,
    gru_kwargs: dict[str, int | float] | None = None,
    **kwargs,
):
    """Build the convolutional core, optionally followed by a conv-GRU module.

    Args:
        channels: ``channels[0]`` is the input channel count; the remaining
            entries are the hidden channels, one conv layer each.
        n_neurons_dict: Per-session neuron counts; required when
            ``batch_adaptation`` is True so per-session adaptation (log-speed)
            terms can be learned.
        temporal_kernel_sizes: Temporal kernel size(s); a scalar is broadcast
            to every layer.
        spatial_kernel_sizes: Spatial kernel size(s); a scalar is broadcast
            to every layer.
        gamma_hidden: Weight of the hidden group-sparsity regularizer.
        gamma_input: Weight of the spatial Laplace input regularizer.
        gamma_in_sparse: Weight of the input group-sparsity regularizer.
        gamma_temporal: Weight of the temporal smoothness regularizer.
        final_nonlinearity: Forwarded to ``initialize_layers``.
        bias: Forwarded to ``initialize_layers``.
        input_padding: Padding flag for the first layer (see
            ``calculate_padding``).
        hidden_padding: Padding flag for the hidden layers.
        batch_norm: Forwarded to ``initialize_layers``.
        batch_norm_scale: Forwarded to ``initialize_layers``.
        batch_norm_momentum: Forwarded to ``initialize_layers``.
        laplace_padding: Padding for the Laplace regularizers.
        batch_adaptation: Whether to learn per-session adaptation terms.
        use_avg_reg: Stored; presumably toggles averaged regularization —
            consumed outside this method.
        nonlinearity: Name of the activation, forwarded to
            ``initialize_layers``.
        conv_type: Name resolved via ``get_conv_class``.
        use_gru: If True, append a ``GRU_Module`` after the conv stack.
        use_projections: Forwarded to ``initialize_layers``.
        gru_kwargs: Extra keyword arguments for ``GRU_Module``; may be None.
        **kwargs: Ignored; accepted for config compatibility.

    Raises:
        ValueError: If ``batch_adaptation`` is True but ``n_neurons_dict``
            is not provided.
    """
    super().__init__()
    # Set regularizers
    self._input_weights_regularizer_spatial = FlatLaplaceL23dnorm(padding=laplace_padding)
    self._input_weights_regularizer_temporal = TimeLaplaceL23dnorm(padding=laplace_padding)

    # Get convolution class
    self.conv_class = get_conv_class(conv_type)

    if n_neurons_dict is None:
        n_neurons_dict = {}
        if batch_adaptation:
            raise ValueError(
                "If batch_adaptation is True, n_neurons_dict must be provided to "
                "learn the adaptation terms per session."
            )

    self.gamma_input = gamma_input
    self.gamma_in_sparse = gamma_in_sparse
    self.gamma_hidden = gamma_hidden
    self.gamma_temporal = gamma_temporal
    self.input_channels = channels[0]
    self.hidden_channels = channels[1:]
    self.layers = len(self.hidden_channels)
    self.use_avg_reg = use_avg_reg

    # Broadcast scalar kernel sizes to one entry per layer.
    if not isinstance(temporal_kernel_sizes, (list, tuple)):
        temporal_kernel_sizes = [temporal_kernel_sizes] * self.layers
    if not isinstance(spatial_kernel_sizes, (list, tuple)):
        spatial_kernel_sizes = [spatial_kernel_sizes] * self.layers

    self.features = nn.Sequential()

    # Log speed dictionary
    log_speed_dict = self.generate_log_speed_dict(n_neurons_dict, batch_adaptation) if batch_adaptation else {}

    # Padding logic
    self.input_pad, self.hidden_pad = self.calculate_padding(input_padding, hidden_padding, spatial_kernel_sizes)

    # Initialize layers, including projection if applicable
    self.initialize_layers(
        self.input_channels,
        self.hidden_channels,
        temporal_kernel_sizes,
        spatial_kernel_sizes,
        log_speed_dict,
        batch_norm,
        batch_norm_momentum,
        bias,
        batch_norm_scale,
        final_nonlinearity,
        self.input_pad,
        self.hidden_pad,
        nonlinearity,
        use_projections,
    )

    self.apply(self.init_conv)

    # GRU integration
    if use_gru:
        print("Using GRU")
        # Fix: tolerate gru_kwargs=None (the documented default) instead of
        # crashing on ``**None`` when use_gru is enabled.
        self.features.add_module(
            "gru", GRU_Module(input_channels=self.hidden_channels[-1], **(gru_kwargs or {}))
        )  # type: ignore

forward

forward(input_, data_key=None)
Source code in openretina/modules/core/gru_core.py
113
114
115
116
117
118
119
120
121
122
123
124
def forward(self, input_, data_key=None):
    """Apply each feature layer in sequence, threading ``data_key`` along.

    Every layer receives a ``(tensor, data_key)`` tuple so session-aware
    layers can look up their per-session parameters.

    Args:
        input_: The input tensor fed to the first layer.
        data_key: Optional session identifier forwarded to every layer.

    Returns:
        The output of the last feature layer.
    """
    # NOTE(review): removed always-dead skip-connection scaffolding from the
    # original (``do_skip`` was hardcoded False, ``ret`` was never appended
    # to, and the dead branch referenced an undefined ``self.skip``).
    # Behavior is unchanged.
    for feature_layer in self.features:
        input_ = feature_layer((input_, data_key))
    return input_

regularizer

regularizer()
Source code in openretina/modules/core/gru_core.py
269
270
271
272
273
274
275
276
277
278
def regularizer(self):
    """Return the combined weighted regularization penalty.

    Only computed when the core uses ``STSeparableBatchConv3d``; any other
    convolution class yields no penalty (0).
    """
    if self.conv_class != STSeparableBatchConv3d:
        return 0
    penalty = self.group_sparsity() * self.gamma_hidden
    penalty = penalty + self.gamma_input * self.spatial_laplace()
    penalty = penalty + self.gamma_temporal * self.temporal_smoothness()
    penalty = penalty + self.group_sparsity0() * self.gamma_in_sparse
    return penalty