def __init__(
self,
channels: tuple[int, ...],
n_neurons_dict: dict[str, int] | None = None,
temporal_kernel_sizes: int | tuple[int, ...] = (21,),
spatial_kernel_sizes: int | tuple[int, ...] = (14,),
gamma_hidden: float = 0.0,
gamma_input: float = 0.0,
gamma_in_sparse: float = 0.0,
gamma_temporal: float = 0.0,
final_nonlinearity: bool = True,
bias: bool = True,
input_padding: bool = False,
hidden_padding: bool = True,
batch_norm: bool = True,
batch_norm_scale: bool = True,
batch_norm_momentum: float = 0.1,
laplace_padding: int | None = 0,
batch_adaptation: bool = False,
use_avg_reg: bool = False,
nonlinearity: str = "ELU",
conv_type: str = "custom_separable",
use_gru: bool = False,
use_projections: bool = False,
gru_kwargs: dict[str, int | float] | None = None,
**kwargs,
):
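"""Build a spatiotemporal convolutional core.

Args:
    channels: ``(input_channels, *hidden_channels)``; one convolutional
        layer is created per hidden entry.
    n_neurons_dict: Maps session keys to neuron counts. Required when
        ``batch_adaptation`` is True so that per-session adaptation
        terms can be learned.
    temporal_kernel_sizes: Temporal kernel size per layer; a single int
        is broadcast to all hidden layers.
    spatial_kernel_sizes: Spatial kernel size per layer; broadcast like
        ``temporal_kernel_sizes``.
    gamma_hidden, gamma_input, gamma_in_sparse, gamma_temporal: Weights
        of the corresponding regularization terms.
    conv_type: Key passed to ``get_conv_class`` to select the
        convolution implementation.
    use_gru: If True, append a ``GRU_Module`` after the convolutional
        stack.
    gru_kwargs: Extra keyword arguments forwarded to ``GRU_Module``.

The remaining flags (padding, batch norm, bias, nonlinearity) are
forwarded to ``initialize_layers``.
"""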
super().__init__()
# Input-layer smoothness regularizers: spatial (flat) and temporal Laplace penalties
self._input_weights_regularizer_spatial = FlatLaplaceL23dnorm(padding=laplace_padding)
self._input_weights_regularizer_temporal = TimeLaplaceL23dnorm(padding=laplace_padding)
# Get convolution class
self.conv_class = get_conv_class(conv_type)
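# n_neurons_dict is optional, but batch adaptation needs the per-session
# neuron counts, so fail early if they are missing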
if n_neurons_dict is None:
n_neurons_dict = {}
if batch_adaptation:
raise ValueError(
"If batch_adaptation is True, n_neurons_dict must be provided to "
"learn the adaptation terms per session."
)
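# Regularization weights for the input, input-sparsity, hidden, and temporal terms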
self.gamma_input = gamma_input
self.gamma_in_sparse = gamma_in_sparse
self.gamma_hidden = gamma_hidden
self.gamma_temporal = gamma_temporal
self.input_channels = channels[0]  # first entry: number of input channels
self.hidden_channels = channels[1:]  # remaining entries: one per hidden layer
self.layers = len(self.hidden_channels)
self.use_avg_reg = use_avg_reg
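# A single int kernel size is broadcast to one entry per hidden layer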
if not isinstance(temporal_kernel_sizes, (list, tuple)):
temporal_kernel_sizes = [temporal_kernel_sizes] * self.layers
if not isinstance(spatial_kernel_sizes, (list, tuple)):
spatial_kernel_sizes = [spatial_kernel_sizes] * self.layers
self.features = nn.Sequential()  # holds the stacked conv layers (and optional GRU)
# Per-session log-speed terms, only needed when batch adaptation is enabled
log_speed_dict = self.generate_log_speed_dict(n_neurons_dict, batch_adaptation) if batch_adaptation else {}
# Compute the spatial padding applied to the input layer and the hidden layers
self.input_pad, self.hidden_pad = self.calculate_padding(input_padding, hidden_padding, spatial_kernel_sizes)
# Initialize layers, including projection if applicable
self.initialize_layers(
self.input_channels,
self.hidden_channels,
temporal_kernel_sizes,
spatial_kernel_sizes,
log_speed_dict,
batch_norm,
batch_norm_momentum,
bias,
batch_norm_scale,
final_nonlinearity,
self.input_pad,
self.hidden_pad,
nonlinearity,
use_projections,
)
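# Recursively initialize all convolution weights via init_conv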
self.apply(self.init_conv)
# GRU integration: optionally append a recurrent module after the conv stack
if use_gru:
    print("Using GRU")
    # gru_kwargs defaults to None; fall back to an empty dict before unpacking
    self.features.add_module("gru", GRU_Module(input_channels=self.hidden_channels[-1], **(gru_kwargs or {})))
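# A minimal construction sketch (comment only; the enclosing class name
# `Core` and the session key are assumptions, not defined in this file):
#
#   core = Core(
#       channels=(2, 16, 16),            # 2 input channels, two hidden layers of 16
#       temporal_kernel_sizes=(21, 11),  # one entry per hidden layer
#       spatial_kernel_sizes=(11, 5),
#       n_neurons_dict={"session_1": 40},
#       batch_adaptation=True,
#   )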