Models API Reference
Complete neural network architectures for retinal response prediction. All models follow the Core + Readout pattern: a shared feature extraction core paired with per-session readouts.
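The Core + Readout pattern can be illustrated without a deep-learning framework. The sketch below is a conceptual stand-in, not openretina code. SharedCore, SessionReadout, and CoreReadoutSketch are hypothetical names, features are plain Python lists, and each readout is assumed to be a linear map from the shared features to one value per neuron of its session.

```python
from dataclasses import dataclass


@dataclass
class SharedCore:
    """Hypothetical core: one feature extractor shared by all sessions."""
    scale: float

    def __call__(self, stimulus: list[float]) -> list[float]:
        # Toy "feature extraction": a rectified, scaled copy of the input.
        return [max(0.0, self.scale * s) for s in stimulus]


@dataclass
class SessionReadout:
    """Hypothetical per-session readout: one weight vector per neuron."""
    weights: list[list[float]]  # shape: (n_neurons, n_features)

    def __call__(self, features: list[float]) -> list[float]:
        return [sum(w * f for w, f in zip(row, features)) for row in self.weights]


class CoreReadoutSketch:
    """Hypothetical model: a shared core plus readouts keyed by session (data_key)."""

    def __init__(self, core: SharedCore, readouts: dict[str, SessionReadout]):
        self.core = core
        self.readouts = readouts

    def forward(self, stimulus: list[float], data_key: str) -> list[float]:
        features = self.core(stimulus)            # shared across sessions
        return self.readouts[data_key](features)  # session-specific mapping


model = CoreReadoutSketch(
    SharedCore(scale=2.0),
    {
        "session_a": SessionReadout([[1.0, 0.0], [0.5, 0.5]]),  # 2 neurons
        "session_b": SessionReadout([[0.0, 1.0]]),              # 1 neuron
    },
)
print(model.forward([1.0, -1.0], "session_a"))  # [2.0, 1.0]
```

Because the core is shared, data from every session can update it, while each readout only learns the mapping for its own neurons.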
Loading Pre-trained Models
load_core_readout_from_remote
load_core_readout_from_remote(
model_name: str,
device: str,
cache_directory_path: str | PathLike | None = None,
) -> BaseCoreReadout
Download and load a pre-trained core-readout model by name. If the model cannot be loaded otherwise, falls back to the legacy ExampleCoreReadout format.
Source code in openretina/models/core_readout.py
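The fallback follows a common try-then-fall-back loading pattern. The sketch below is a minimal illustration that assumes hypothetical loader callables load_current and load_legacy. The actual openretina loaders, and the exceptions that trigger the fallback, may differ.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def load_with_fallback(
    path: str,
    load_current: Callable[[str], T],
    load_legacy: Callable[[str], T],
) -> T:
    """Try the current checkpoint format first, then the legacy loader (hypothetical helpers)."""
    try:
        return load_current(path)
    except Exception as current_error:  # assumption: any failure triggers the fallback
        try:
            return load_legacy(path)
        except Exception as legacy_error:
            # Report both failures so it is clear why neither format worked.
            raise RuntimeError(
                f"Could not load {path!r} in the current format ({current_error}) "
                f"or the legacy format ({legacy_error})"
            ) from legacy_error
```

Raising a combined error when both loaders fail avoids hiding the first failure behind the second.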
load_core_readout_model
load_core_readout_model(
model_path_or_name: str,
device: str,
cache_directory_path: str | PathLike | None = None,
) -> BaseCoreReadout
Load a core-readout model from a local path or remote name. Tries known remote names first, then local paths.
Source code in openretina/models/core_readout.py
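The lookup order checks known remote names first and then local paths. It can be sketched as a small resolver. KNOWN_REMOTE_NAMES and its entries are hypothetical placeholders. The real registry of pre-trained model names lives in openretina and is not reproduced here.

```python
from pathlib import Path

# Hypothetical placeholder names: NOT the actual openretina model registry.
KNOWN_REMOTE_NAMES = {"example_model_a", "example_model_b"}


def resolve_model_source(model_path_or_name: str) -> tuple[str, str]:
    """Return ("remote", name) or ("local", path), checking known remote names first."""
    if model_path_or_name in KNOWN_REMOTE_NAMES:
        return ("remote", model_path_or_name)
    local_path = Path(model_path_or_name)
    if local_path.exists():
        return ("local", str(local_path))
    raise FileNotFoundError(
        f"{model_path_or_name!r} is neither a known remote model name nor an existing path"
    )
```

Checking names first means a local file that happens to share a remote model's name is never picked up by accident.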
BaseCoreReadout
BaseCoreReadout(
core: Core,
readout: MultiReadoutBase,
learning_rate: float,
loss: Module | None = None,
validation_loss: Module | None = None,
data_info: dict[str, Any] | None = None,
)
Bases: LightningModule
Base module for models combining a shared core and a multi-session readout. All models following the Core + Readout pattern should inherit from this class.
This LightningModule encapsulates a model made of a shared core and a flexible multi-session readout, suitable for training architectures that span multiple sessions. It defines the training, validation, and testing steps, provides hooks for optimizer and scheduler configuration, and includes methods for handling data info and visualization.
Initializes a BaseCoreReadout module.
| PARAMETER | DESCRIPTION |
|---|---|
| `core` | The shared feature extraction core network. TYPE: `Core` |
| `readout` | The multi-session readout module mapping core features to neuron outputs per session. TYPE: `MultiReadoutBase` |
| `learning_rate` | Learning rate for network training. TYPE: `float` |
| `loss` | Loss function for training. Defaults to PoissonLoss3d if None. TYPE: `Module \| None` |
| `validation_loss` | Loss used to compute the correlation performance metric. Defaults to CorrelationLoss3d(avg=True) if None. TYPE: `Module \| None` |
| `data_info` | Dictionary of data-specific metadata, such as input_shape and per-session neuron counts. Defaults to an empty dict if None. TYPE: `dict[str, Any] \| None` |
Source code in openretina/models/core_readout.py
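The two default losses serve different purposes. The Poisson objective fits count-like responses, and the correlation metric evaluates how well predictions track the recorded traces. The sketch below shows both concepts on plain lists:

* A standard Poisson negative log-likelihood that drops the log-factorial term, which depends only on the target.
* A per-neuron Pearson correlation averaged over neurons.

This is not the code of PoissonLoss3d or CorrelationLoss3d. Their reductions, epsilon values, and sign conventions are assumptions here; for example, a correlation loss may return the negative correlation.

```python
import math


# Illustrative stand-ins, NOT openretina's PoissonLoss3d / CorrelationLoss3d.
def poisson_nll(pred: list[float], target: list[float], eps: float = 1e-8) -> float:
    """Mean Poisson negative log-likelihood, up to a term that depends only on the target."""
    return sum(p - t * math.log(p + eps) for p, t in zip(pred, target)) / len(pred)


def pearson(x: list[float], y: list[float]) -> float:
    """Pearson correlation between two equally long sequences."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def mean_correlation(pred: list[list[float]], target: list[list[float]]) -> float:
    """Average over neurons; each inner list is one neuron's response over time."""
    return sum(pearson(p, t) for p, t in zip(pred, target)) / len(pred)
```

For a single target value t, the Poisson term p - t * log(p) is smallest at p = t. Correlation is insensitive to scale, so a prediction that is a scaled copy of the response still scores 1.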
forward
forward(
x: Float[Tensor, "batch channels t h w"],
data_key: str | None = None,
) -> Tensor
Source code in openretina/models/core_readout.py
training_step
training_step(
batch: tuple[str, DataPoint], batch_idx: int
) -> Tensor
Source code in openretina/models/core_readout.py
validation_step
validation_step(
batch: tuple[str, DataPoint], batch_idx: int
) -> Tensor
Source code in openretina/models/core_readout.py
test_step
test_step(
batch: tuple[str, DataPoint],
batch_idx: int,
dataloader_idx: int = 0,
) -> Tensor
Source code in openretina/models/core_readout.py
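The step methods receive a batch of type tuple[str, DataPoint]: a session key paired with that session's data. The minimal sketch below shows how a step can route the batch to the matching readout and reduce it to a scalar loss. It assumes a hypothetical DataPoint with inputs and targets fields (the real DataPoint fields may be named differently). A toy per-session gain stands in for the model, and a squared error stands in for the configured loss. The same unpacking applies to validation_step and test_step.

```python
from typing import Callable, NamedTuple


class DataPoint(NamedTuple):
    """Hypothetical stand-in for openretina's DataPoint; field names are assumptions."""
    inputs: list[float]
    targets: list[float]


def step(
    batch: tuple[str, DataPoint],
    forward: Callable[[list[float], str], list[float]],
    loss_fn: Callable[[list[float], list[float]], float],
) -> float:
    # The session key selects which readout the forward pass uses.
    data_key, data_point = batch
    predictions = forward(data_point.inputs, data_key)
    return loss_fn(predictions, data_point.targets)


def squared_error(pred: list[float], target: list[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


# Toy forward pass: one gain per session instead of a real core and readout.
gains = {"session_a": 2.0, "session_b": 0.5}


def toy_forward(x: list[float], data_key: str) -> list[float]:
    return [gains[data_key] * v for v in x]
```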
configure_optimizers
configure_optimizers()
Source code in openretina/models/core_readout.py
save_weight_visualizations
save_weight_visualizations(
folder_path: str,
file_format: str = "jpg",
state_suffix: str = "",
) -> None
Save weight visualizations for core and readout modules.
| PARAMETER | DESCRIPTION |
|---|---|
| `folder_path` | Base directory in which to save the visualizations. TYPE: `str` |
| `file_format` | Image format for the saved files. Defaults to "jpg". TYPE: `str` |
| `state_suffix` | Optional suffix for identifying the model state. Defaults to "". TYPE: `str` |
Source code in openretina/models/core_readout.py
compute_readout_input_shape
compute_readout_input_shape(
core_in_shape: tuple[int, int, int, int], core: Core
) -> tuple[int, int, int, int]
Source code in openretina/models/core_readout.py
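compute_readout_input_shape maps the core's input shape to the shape of the feature map the readout receives. For a stack of convolutions, the standard per-axis output-size formula is out = floor((in + 2 * pad - kernel) / stride) + 1. The sketch below applies that formula layer by layer, and its layer-description format (Conv3dSpec) is hypothetical. The library function may instead infer the shape directly from the core, for example with a dummy forward pass, which would also account for any frames the core cuts.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Conv3dSpec:
    """Hypothetical layer description; tuples are ordered (time, height, width)."""
    out_channels: int
    kernel: tuple[int, int, int]
    padding: tuple[int, int, int] = (0, 0, 0)
    stride: tuple[int, int, int] = (1, 1, 1)


def conv_out_size(size: int, kernel: int, padding: int, stride: int) -> int:
    # Standard convolution output-size formula (no dilation).
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_shape(
    in_shape: tuple[int, int, int, int], layers: list[Conv3dSpec]
) -> tuple[int, int, int, int]:
    """Propagate (channels, time, height, width) through a stack of 3D convolutions."""
    channels, t, h, w = in_shape
    for layer in layers:
        t, h, w = (
            conv_out_size(size, k, p, s)
            for size, k, p, s in zip((t, h, w), layer.kernel, layer.padding, layer.stride)
        )
        channels = layer.out_channels
    return (channels, t, h, w)
```

Unpadded ("valid") temporal kernels shorten the time axis by kernel - 1 frames per layer. This is why the readout's time dimension is usually shorter than the stimulus.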
stimulus_shape
stimulus_shape(
time_steps: int, num_batches: int = 1
) -> tuple[int, int, int, int, int]
Source code in openretina/models/core_readout.py
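stimulus_shape returns a 5-tuple for a given number of time steps and batches. The sketch below assumes the method follows forward's (batch, channels, t, h, w) convention. It also assumes that channels, height, and width come from the model's stored input shape (channels, time, height, width). The hypothetical in_shape argument stands in for whatever the model stores internally.

```python
def stimulus_shape_sketch(
    in_shape: tuple[int, int, int, int], time_steps: int, num_batches: int = 1
) -> tuple[int, int, int, int, int]:
    """Hypothetical re-derivation: keep channels/height/width, replace the time axis."""
    channels, _, height, width = in_shape
    return (num_batches, channels, time_steps, height, width)


print(stimulus_shape_sketch((2, 50, 18, 16), time_steps=100))  # (1, 2, 100, 18, 16)
```

A shape like this can be passed to torch.zeros to build a blank stimulus of the right size for a model.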
update_model_data_info
update_model_data_info(data_info: dict[str, Any]) -> None
Updates the relevant model attributes when a (trained) model is loaded and then trained on new data only.
Source code in openretina/models/core_readout.py
UnifiedCoreReadout
UnifiedCoreReadout(
in_shape: Int[tuple, "channels time height width"],
n_neurons_dict: dict[str, int],
core: DictConfig,
readout: DictConfig,
hidden_channels: tuple[int, ...]
| Iterable[int]
| None = None,
learning_rate: float = 0.001,
loss: Module | DictConfig | None = None,
validation_loss: Module | DictConfig | None = None,
data_info: dict[str, Any] | None = None,
optimizer: DictConfig | None = None,
lr_scheduler: DictConfig | None = None,
)
Bases: BaseCoreReadout
A flexible core-readout model for multi-session neural data, designed for Hydra config workflows.
This class is the recommended entry point for defining core-readout models via config files using Hydra. It allows unified instantiation of arbitrary core and readout modules, specified via DictConfig, enabling rapid experimentation and extensibility. Supports all multi-session settings, custom core/readout combinations, and integration with configuration-driven pipelines (including hyperparameter optimization).
Initializes a UnifiedCoreReadout for multi-session configurable neural modeling via Hydra configs.
| PARAMETER | DESCRIPTION |
|---|---|
| `in_shape` | Input shape as (channels, time, height, width) for the core module. TYPE: `Int[tuple, "channels time height width"]` |
| `hidden_channels` | Sequence of hidden channel counts for the core; used when initializing the core config. TYPE: `tuple[int, ...] \| Iterable[int] \| None` |
| `n_neurons_dict` | Mapping from each session/dataset identifier to the number of neurons in that session. TYPE: `dict[str, int]` |
| `core` | Hydra config for instantiating the core module (should specify the class and its parameters). TYPE: `DictConfig` |
| `readout` | Hydra config for the readout module (specifies the type and custom session-aware parameters). TYPE: `DictConfig` |
| `learning_rate` | Learning rate for model training. Defaults to 0.001. TYPE: `float` |
| `loss` | Loss function for training. Defaults to PoissonLoss3d if None. TYPE: `Module \| DictConfig \| None` |
| `validation_loss` | Loss used to compute the correlation performance metric. Defaults to CorrelationLoss3d(avg=True) if None. TYPE: `Module \| DictConfig \| None` |
| `data_info` | Additional metadata dictionary, e.g. with the input shape and neuron mapping. TYPE: `dict[str, Any] \| None` |
| `optimizer` | Hydra config for optimizer instantiation. Defaults to AdamW if None. TYPE: `DictConfig \| None` |
| `lr_scheduler` | Hydra config for the learning rate scheduler. Defaults to ReduceLROnPlateau if None. TYPE: `DictConfig \| None` |
Source code in openretina/models/core_readout.py
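UnifiedCoreReadout builds its core and readout from DictConfig objects. In Hydra, a config node names the object to build in its _target_ key, and hydra.utils.instantiate passes the remaining keys, plus any extra keyword arguments, to that callable. The standard-library sketch below mimics only this basic behavior to show how a config maps to an object. It ignores Hydra features such as recursive instantiation, _partial_, and interpolation. The example target is an arbitrary standard-library class, not an openretina core.

```python
import importlib
from typing import Any


def instantiate_sketch(config: dict[str, Any], **overrides: Any) -> Any:
    """Minimal stand-in for hydra.utils.instantiate: import `_target_`, call it with the rest."""
    module_path, _, attr_name = config["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    kwargs = {k: v for k, v in config.items() if k != "_target_"}
    kwargs.update(overrides)  # e.g. values only known at runtime, such as n_neurons_dict
    return target(**kwargs)


# Usage: a config node naming a standard-library class.
delta = instantiate_sketch({"_target_": "datetime.timedelta", "days": 1}, hours=2)
# equivalent to datetime.timedelta(days=1, hours=2)
```

Keeping the class path in the config means an experiment can swap in a different core or readout by editing a config file rather than Python code.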
configure_optimizers
configure_optimizers()
Configure optimizers and schedulers using Hydra configs.
This method overrides BaseCoreReadout.configure_optimizers() so that the optimizer and learning rate scheduler are built from the optimizer and lr_scheduler configs via utility functions.
Source code in openretina/models/core_readout.py
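When lr_scheduler is None, the default is ReduceLROnPlateau, which lowers the learning rate once a monitored metric (typically a validation loss) stops improving. The plain-Python sketch below reproduces the core rule for a metric that should decrease. If the metric has not improved for more than patience consecutive epochs, the learning rate is multiplied by factor. The values shown are illustrative, not openretina's configured defaults. PyTorch's implementation adds a relative improvement threshold, cooldown, and per-parameter-group minimum rates, which are omitted here.

```python
class PlateauScheduler:
    """Simplified ReduceLROnPlateau (mode='min', zero threshold, no cooldown); illustrative only."""

    def __init__(self, lr: float, factor: float = 0.5, patience: int = 2, min_lr: float = 0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        if metric < self.best:  # improvement: remember it and reset the counter
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:  # plateau detected: decay the rate
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr
```

With patience=2, three non-improving epochs in a row trigger one reduction.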
ExampleCoreReadout
ExampleCoreReadout(
in_shape: Int[tuple, "channels time height width"],
hidden_channels: Iterable[int],
temporal_kernel_sizes: Iterable[int],
spatial_kernel_sizes: Iterable[int],
n_neurons_dict: dict[str, int],
core_gamma_input: float = 0.0,
core_gamma_hidden: float = 0.0,
core_gamma_in_sparse: float = 0.0,
core_gamma_temporal: float = 40.0,
core_input_padding: bool
| str
| int
| tuple[int, int, int] = False,
core_hidden_padding: bool
| str
| int
| tuple[int, int, int] = True,
readout_scale: bool = True,
readout_bias: bool = True,
readout_gaussian_masks: bool = True,
readout_gaussian_mean_scale: float = 6.0,
readout_gaussian_var_scale: float = 4.0,
readout_positive: bool = True,
readout_gamma: float = 0.4,
readout_gamma_masks: float = 0.0,
readout_reg_avg: bool = False,
learning_rate: float = 0.01,
cut_first_n_frames_in_core: int = 30,
dropout_rate: float = 0.0,
maxpool_every_n_layers: Optional[int] = None,
downsample_input_kernel_size: Optional[
tuple[int, int, int]
] = None,
convolution_type: str = "custom_separable",
color_squashing_weights: tuple[float, ...]
| None = None,
data_info: dict[str, Any] | None = None,
)
Bases: BaseCoreReadout
Example implementation of a custom Core-Readout model, using a convolutional core and a Gaussian readout.
This class serves as a guide for constructing custom Core-Readout models without using the unified Hydra configuration system and the UnifiedCoreReadout class. Use this model as a reference if you wish to instantiate or design core/readout units directly in code rather than through configuration files. For most workflows, especially those using Hydra, UnifiedCoreReadout is preferred for maximum flexibility.
N.B., this class is provided as a reference example.
Source code in openretina/models/core_readout.py
on_load_checkpoint
on_load_checkpoint(checkpoint) -> None
Supports loading legacy models whose readout layers use bias_param instead of bias.
Source code in openretina/models/core_readout.py
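A checkpoint hook like this typically renames entries in the loaded state dict before the weights are assigned. The minimal sketch below assumes that legacy entries are keys ending in .bias_param, stored under the checkpoint's "state_dict" entry, and that they should become .bias. The exact key layout of legacy openretina checkpoints is an assumption. Plain values replace tensors so that the example runs without PyTorch.

```python
from typing import Any


def rename_legacy_bias_keys(checkpoint: dict[str, Any]) -> None:
    """Rename '<prefix>.bias_param' to '<prefix>.bias' in checkpoint['state_dict'], in place.

    Assumption: legacy readout biases are stored under keys ending in '.bias_param'.
    """
    state_dict = checkpoint["state_dict"]
    for key in list(state_dict):  # copy the keys: the dict is modified inside the loop
        if key.endswith(".bias_param"):
            new_key = key[: -len(".bias_param")] + ".bias"
            state_dict[new_key] = state_dict.pop(key)
```

Renaming in place inside the hook lets the regular weight-loading path run unchanged afterwards.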
Sub-modules
- Core-Readout Models — Full module reference
- Linear-Nonlinear Models — LNP cascade models
- Sparse Autoencoder — Sparse representation models