Stimulus Optimization
TODO(review): This page was AI generated — verify every code snippet against the current OpenRetina API (import paths, class and argument names) before publishing.
This tutorial will guide you through using OpenRetina's stimulus optimization tools to find visual stimuli that optimally drive retinal neurons. You'll learn how to generate most exciting inputs (MEIs) and most discriminative stimuli (MDS) for understanding neural function.
Overview
Stimulus optimization in OpenRetina allows you to:
- Find optimal stimuli: Generate visual patterns that maximally activate specific neurons
- Create discriminative stimuli: Find patterns that differentiate between cell types
- Understand receptive fields: Reveal the visual features that neurons prefer
- Analyze population responses: Study how groups of neurons respond to optimized inputs
Key Concepts
Most Exciting Input (MEI)
A visual stimulus that maximally activates a target neuron or group of neurons.
Most Discriminative Stimulus (MDS)
A stimulus that maximally activates one group of neurons while minimizing responses in other groups, useful for understanding functional differences between cell types.
Gradient-based Optimization
Uses backpropagation through the neural network to iteratively improve stimuli based on objective functions.
Getting Started
Installation and Imports
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
# OpenRetina imports
from openretina.models import load_core_readout_from_remote
from openretina.insilico import (
optimize_stimulus,
IncreaseObjective,
ContrastiveNeuronObjective,
OptimizationStopper,
MeanReducer,
SliceMeanReducer
)
from openretina.insilico.stimulus_optimization.regularizer import (
ChangeNormJointlyClipRangeSeparately,
TemporalGaussianLowPassFilterProcessor
)
from openretina.utils.plotting import plot_stimulus_composition
Load a Pre-trained Model
# Pick the compute device: GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Download and load a pre-trained core+readout model from the OpenRetina model zoo.
model = load_core_readout_from_remote("hoefling_2024_base_low_res", device)
model.eval()  # Evaluation mode: freeze batch-norm statistics / disable dropout
print(f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")
print(f"Number of output neurons: {model.readout.n_neurons}")
Basic Stimulus Optimization
Step 1: Create Initial Stimulus
# Ask the model for the stimulus shape it expects for a clip of `time_steps`
# frames (presumably (batch, channels, time, height, width) — confirm against
# the model docs).
time_steps = 40
stimulus_shape = model.stimulus_shape(time_steps=time_steps, num_batches=1)
print(f"Stimulus shape: {stimulus_shape}")

# Random stimulus as the starting point for gradient-based optimization;
# requires_grad=True so gradients flow back into the stimulus itself.
torch.manual_seed(42)  # For reproducibility
stimulus = torch.randn(stimulus_shape, requires_grad=True, device=device)
print(f"Initial stimulus range: [{stimulus.min():.3f}, {stimulus.max():.3f}]")
Step 2: Define Objective Function
# Choose a neuron to optimize for (e.g., neuron index 50)
target_neuron = 50

# Response reducer that collapses the time axis to a scalar by averaging
# frames 10-19 (start=10, length=10) of the 40-frame response.
reducer = SliceMeanReducer(axis=0, start=10, length=10)

# Objective that maximizes the reduced response of the target neuron.
objective = IncreaseObjective(
    model=model,
    neuron_indices=target_neuron,
    data_key=None,  # For single-session models
    response_reducer=reducer
)
print(f"Objective: Maximize response of neuron {target_neuron}")
print(f"Using time frames {reducer.start} to {reducer.start + reducer.length - 1}")
Step 3: Set Up Regularization and Post-processing
# Stimulus post-processors keep the optimized stimulus physically plausible.

# 1. Rescale the stimulus to a fixed joint norm and clip each color channel
#    to its own range (presumably [UV, Green] — matches the model's 2 channels).
stimulus_clipper = ChangeNormJointlyClipRangeSeparately(
    min_max_values=[(-0.6, 6.2), (-0.9, 6.2)],  # Per-channel (min, max)
    norm=30.0  # Joint norm the stimulus is rescaled to
)

# 2. Temporal Gaussian low-pass filter suppresses unrealistically fast flicker.
temporal_filter = TemporalGaussianLowPassFilterProcessor(
    sigma=0.5,  # Gaussian kernel standard deviation (in frames)
    kernel_size=5,  # Temporal kernel size
    device=device
)

# Post-processors applied to the stimulus during optimization, in list order.
stimulus_postprocessors = [stimulus_clipper, temporal_filter]

# Scale down the random init and clip it so optimization starts in-range.
stimulus.data = stimulus_clipper.process(stimulus.data * 0.1)
print(f"Processed initial stimulus range: [{stimulus.min():.3f}, {stimulus.max():.3f}]")
Step 4: Run Optimization
# optimize_stimulus constructs the optimizer itself, so pass a factory
# (a partial binding the learning rate) rather than an optimizer instance.
optimizer_init_fn = partial(torch.optim.SGD, lr=100.0)

# Stop after a fixed number of iterations.
stopper = OptimizationStopper(max_iterations=50)

print("Starting optimization...")
initial_response = objective.forward(stimulus).item()
print(f"Initial objective value: {initial_response:.4f}")

# Gradient ascent on the stimulus; the model weights stay frozen.
optimize_stimulus(
    stimulus=stimulus,
    optimizer_init_fn=optimizer_init_fn,
    objective_object=objective,
    optimization_stopper=stopper,
    stimulus_regularization_loss=None,  # No additional regularization
    stimulus_postprocessor=stimulus_postprocessors
)

# Compare the objective before/after to confirm optimization improved it.
final_response = objective.forward(stimulus).item()
print(f"Final objective value: {final_response:.4f}")
print(f"Improvement: {final_response - initial_response:.4f}")
Step 5: Visualize Results
def plot_optimized_stimulus(stimulus, title="Optimized Stimulus"):
    """Visualize an optimized stimulus: temporal, spatial, and frequency views.

    Parameters
    ----------
    stimulus: batched stimulus tensor; only the first batch element is shown.
    title: figure title.
    """
    _, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Drop the batch dimension and move to CPU/NumPy for matplotlib.
    stim_array = stimulus[0].detach().cpu().numpy()

    # Highlight the time window averaged by the module-level `reducer`.
    highlight_window = (reducer.start, reducer.start + reducer.length - 1)
    plot_stimulus_composition(
        stimulus=stim_array,
        temporal_trace_ax=axes[0, 0],
        freq_ax=axes[0, 1],
        spatial_ax=axes[1, 0],
        highlight_x_list=[highlight_window]
    )

    # Space-time slice of channel 0 at a fixed vertical position.
    # NOTE(review): row index 8 assumes the stimulus is at least 9 pixels tall
    # and that channel 0 is the UV channel — confirm for other models.
    evolution_ax = axes[1, 1]
    evolution_ax.imshow(stim_array[0, :, 8, :].T, aspect='auto', cmap='coolwarm')
    evolution_ax.set_xlabel('Time')
    evolution_ax.set_ylabel('X Position')
    evolution_ax.set_title('Temporal Evolution (UV Channel)')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()
# Render the MEI we just optimized for the target neuron.
plot_optimized_stimulus(stimulus, f"MEI for Neuron {target_neuron}")
Advanced Optimization: Most Discriminative Stimuli
Understanding Cell Type Classification
# For the Höfling dataset, we can work with functional cell types
# Let's create groups based on ON/OFF responses
# Example cell type groupings (you would get these from your model's metadata)
def get_cell_type_groups(model):
    """Build example cell-type groupings by splitting the neurons into thirds.

    In a real analysis these indices would come from the model's metadata
    (actual functional cell-type labels); here they are placeholders for
    demonstration only.
    """
    n_neurons = model.readout.n_neurons
    # Boundary indices of the three contiguous groups. Note the second
    # boundary is (2*n)//3, matching floor division on the doubled count.
    lower = n_neurons // 3
    upper = (2 * n_neurons) // 3
    boundaries = {
        'OFF': (0, lower),
        'ON-OFF': (lower, upper),
        'ON': (upper, n_neurons),
    }
    return {name: list(range(lo, hi)) for name, (lo, hi) in boundaries.items()}
# Build the example groups and report their sizes.
cell_groups = get_cell_type_groups(model)
print("Cell type groups:")
for group_name, indices in cell_groups.items():
    print(f" {group_name}: {len(indices)} cells")
Create Contrastive Objective
# Select the group to excite and the groups to suppress.
target_group = cell_groups['ON'][:10]  # Target: first 10 ON cells
contrast_groups = [
    cell_groups['OFF'][:20],  # Contrast against OFF cells
    cell_groups['ON-OFF'][:15]  # Contrast against ON-OFF cells
]
print(f"Target group: {len(target_group)} ON cells")
print(f"Contrast groups: {[len(g) for g in contrast_groups]} cells")

# Objective that drives the target ("on") cluster while suppressing the
# contrast ("off") clusters — optimizing it yields a most discriminative stimulus.
contrastive_objective = ContrastiveNeuronObjective(
    model=model,
    on_cluster_idc=target_group,
    off_cluster_idc_list=contrast_groups,
    data_key=None,
    response_reducer=reducer,
    temperature=1.6  # Controls sharpness of the contrast
)
print("Created contrastive objective for most discriminative stimulus")
Optimize Discriminative Stimulus
# Fresh random stimulus for the MDS run (different seed from the MEI run).
torch.manual_seed(123)
mds_stimulus = torch.randn(stimulus_shape, requires_grad=True, device=device)
mds_stimulus.data = stimulus_clipper.process(mds_stimulus.data * 0.1)

print("Starting MDS optimization...")
initial_contrast = contrastive_objective.forward(mds_stimulus).item()
print(f"Initial contrastive objective: {initial_contrast:.4f}")

# Same optimizer settings as the MEI run, but with the contrastive objective.
optimize_stimulus(
    stimulus=mds_stimulus,
    optimizer_init_fn=optimizer_init_fn,
    objective_object=contrastive_objective,
    optimization_stopper=OptimizationStopper(max_iterations=50),
    stimulus_postprocessor=stimulus_postprocessors
)

final_contrast = contrastive_objective.forward(mds_stimulus).item()
print(f"Final contrastive objective: {final_contrast:.4f}")
print(f"Improvement: {final_contrast - initial_contrast:.4f}")

# Visualize the discriminative stimulus.
plot_optimized_stimulus(mds_stimulus, "Most Discriminative Stimulus (ON vs OFF/ON-OFF)")
Comparing MEI vs MDS
Side-by-Side Analysis
def compare_stimuli_responses(mei_stim, mds_stim, model, cell_groups):
    """Compare model responses to the MEI and MDS stimuli per cell-type group.

    Plots per-group response histograms for both stimuli and prints the
    group-mean responses.
    """
    model.eval()
    with torch.no_grad():
        # Drop the batch dimension from the model output.
        # NOTE(review): the indexing below (`mei_responses[indices]`) selects
        # along the FIRST axis of the de-batched response array. If the model
        # output after batch removal is (time, neurons), this indexes TIME
        # steps rather than neurons — confirm the output layout and whether
        # `mei_responses[:, indices]` was intended.
        mei_responses = model(mei_stim)[0].cpu().numpy()
        mds_responses = model(mds_stim)[0].cpu().numpy()

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # One histogram column per cell-type group, capped at three groups.
    for i, (group_name, indices) in enumerate(cell_groups.items()):
        if i >= 3:  # Limit to 3 groups for visualization
            break
        # MEI responses (top row)
        axes[0, i].hist(mei_responses[indices], bins=20, alpha=0.7, label='MEI')
        axes[0, i].set_title(f'MEI: {group_name} Cells')
        axes[0, i].set_xlabel('Response')
        axes[0, i].set_ylabel('Count')
        # MDS responses (bottom row)
        axes[1, i].hist(mds_responses[indices], bins=20, alpha=0.7,
                        label='MDS', color='orange')
        axes[1, i].set_title(f'MDS: {group_name} Cells')
        axes[1, i].set_xlabel('Response')
        axes[1, i].set_ylabel('Count')

    plt.tight_layout()
    plt.show()

    # Print per-group mean responses for a quick numeric comparison.
    print("\nResponse Summary:")
    print("=" * 50)
    for group_name, indices in cell_groups.items():
        mei_mean = np.mean(mei_responses[indices])
        mds_mean = np.mean(mds_responses[indices])
        print(f"{group_name:>8}: MEI={mei_mean:.3f}, MDS={mds_mean:.3f}")

# Run the comparison on the stimuli optimized above.
compare_stimuli_responses(stimulus, mds_stimulus, model, cell_groups)
Advanced Features
Custom Regularization
from openretina.insilico.stimulus_optimization.regularizer import RangeRegularizationLoss

# Differentiable penalty that discourages stimulus values outside the given
# range (a soft constraint, unlike the hard post-processing clip); the two
# entries presumably correspond to the two color channels — confirm.
range_regularizer = RangeRegularizationLoss(
    min_values=torch.tensor([-1.0, -1.0], device=device),
    max_values=torch.tensor([7.0, 7.0], device=device),
    penalty_weight=0.1
)

# Re-run the MEI optimization with the extra regularization term in the loss.
optimize_stimulus(
    stimulus=stimulus,
    optimizer_init_fn=optimizer_init_fn,
    objective_object=objective,
    optimization_stopper=OptimizationStopper(max_iterations=30),
    stimulus_regularization_loss=range_regularizer,  # Add regularization
    stimulus_postprocessor=stimulus_postprocessors
)
Multi-neuron Optimization
# Optimize one stimulus for several neurons at once: IncreaseObjective also
# accepts a list of neuron indices.
target_neurons = [10, 25, 40, 55]  # Multiple neuron indices
multi_objective = IncreaseObjective(
    model=model,
    neuron_indices=target_neurons,  # List of neurons
    data_key=None,
    response_reducer=reducer
)
print(f"Optimizing for {len(target_neurons)} neurons: {target_neurons}")

# Fresh random stimulus, initialized in-range like the single-neuron MEI.
multi_stimulus = torch.randn(stimulus_shape, requires_grad=True, device=device)
multi_stimulus.data = stimulus_clipper.process(multi_stimulus.data * 0.1)
optimize_stimulus(
    stimulus=multi_stimulus,
    optimizer_init_fn=optimizer_init_fn,
    objective_object=multi_objective,
    optimization_stopper=OptimizationStopper(max_iterations=50),
    stimulus_postprocessor=stimulus_postprocessors
)
plot_optimized_stimulus(multi_stimulus, f"MEI for Neurons {target_neurons}")
Parameter Sensitivity Analysis
def sensitivity_analysis(base_stimulus, objective, param_ranges):
    """Sweep optimization hyper-parameters and record the final objective.

    For every parameter name in `param_ranges`, runs a short (20-iteration)
    optimization per candidate value, each starting from a fresh copy of
    `base_stimulus`. Currently only 'learning_rate' changes the optimizer;
    any other name falls back to the module-level `optimizer_init_fn`.

    Returns a dict mapping each parameter name to the list of final
    objective values, in the same order as the candidate values.
    """
    results = {}
    for param_name, candidate_values in param_ranges.items():
        scores = []
        for candidate in candidate_values:
            # Independent starting point so runs don't contaminate each other.
            trial_stimulus = base_stimulus.clone().detach()
            trial_stimulus.requires_grad = True

            if param_name == 'learning_rate':
                build_optimizer = partial(torch.optim.SGD, lr=candidate)
            else:
                build_optimizer = optimizer_init_fn

            # Short optimization run; uses the module-level post-processors.
            optimize_stimulus(
                stimulus=trial_stimulus,
                optimizer_init_fn=build_optimizer,
                objective_object=objective,
                optimization_stopper=OptimizationStopper(max_iterations=20),
                stimulus_postprocessor=stimulus_postprocessors
            )

            score = objective.forward(trial_stimulus).item()
            scores.append(score)
            print(f"{param_name}={candidate}: final_objective={score:.4f}")
        results[param_name] = scores
    return results
# Sweep the SGD learning rate over two orders of magnitude.
param_ranges = {
    'learning_rate': [1.0, 10.0, 50.0, 100.0, 200.0]
}
sensitivity_results = sensitivity_analysis(stimulus, objective, param_ranges)

# Plot final objective vs. parameter value for each swept parameter.
plt.figure(figsize=(10, 6))
for param_name, values in sensitivity_results.items():
    plt.plot(param_ranges[param_name], values, 'o-', label=param_name)
plt.xlabel('Parameter Value')
plt.ylabel('Final Objective Value')
plt.title('Parameter Sensitivity Analysis')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Best Practices
Optimization Tips
- Start with small learning rates: High learning rates can lead to unstable optimization
- Use appropriate regularization: Prevent unrealistic stimuli with proper constraints
- Monitor convergence: Check that objective values are improving
- Use temporal smoothing: Apply low-pass filtering for realistic temporal dynamics
- Clip stimulus ranges: Ensure values stay within physiological bounds
Choosing Parameters
# Recommended parameter sets for different scenarios. These are plain
# reference dicts — plug the values into the corresponding optimizer,
# OptimizationStopper, TemporalGaussianLowPassFilterProcessor, and
# ChangeNormJointlyClipRangeSeparately arguments yourself.

# For detailed, high-resolution optimization
detailed_params = {
    'learning_rate': 10.0,
    'max_iterations': 100,
    'temporal_sigma': 0.3,  # Sharper temporal features
    'clip_range': [(-1.0, 8.0), (-1.5, 8.0)]
}

# For quick exploration
quick_params = {
    'learning_rate': 50.0,
    'max_iterations': 20,
    'temporal_sigma': 1.0,  # Smoother temporal features
    'clip_range': [(-0.5, 6.0), (-0.8, 6.0)]
}

# For biological realism
biological_params = {
    'learning_rate': 20.0,
    'max_iterations': 50,
    'temporal_sigma': 0.5,
    'clip_range': [(-0.6, 6.2), (-0.9, 6.2)]  # Based on Höfling et al. data
}
print("Use these parameter sets based on your optimization goals")
Troubleshooting
def diagnose_optimization(stimulus, objective, n_steps=10):
    """Diagnose optimization issues by tracking objective and gradient norms.

    Runs `n_steps` of plain SGD on a *copy* of `stimulus` — the caller's
    tensor is left untouched (the original version stepped the passed-in
    tensor in place, clobbering an already-optimized stimulus). Prints the
    objective value and gradient norm at each step, warns and stops early
    when gradients vanish (norm < 1e-6, usually a saturated response or an
    over-clipped stimulus), then plots progress and the value distribution.
    """
    # Work on an independent leaf tensor so diagnosis never mutates the
    # caller's stimulus.
    probe = stimulus.detach().clone().requires_grad_(True)
    optimizer = torch.optim.SGD([probe], lr=50.0)
    objectives = []
    for i in range(n_steps):
        obj = objective.forward(probe)
        loss = -obj  # Negative because we minimize loss but maximize objective
        optimizer.zero_grad()
        loss.backward()
        # Inspect the gradient magnitude before stepping.
        grad_norm = probe.grad.norm().item()
        optimizer.step()
        objectives.append(obj.item())
        print(f"Step {i}: objective={obj.item():.4f}, grad_norm={grad_norm:.6f}")
        if grad_norm < 1e-6:
            print("Warning: Very small gradients detected!")
            break

    # Plot progress
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(objectives, 'o-')
    plt.xlabel('Optimization Step')
    plt.ylabel('Objective Value')
    plt.title('Optimization Progress')
    plt.grid(True, alpha=0.3)
    plt.subplot(1, 2, 2)
    plt.hist(probe.detach().cpu().numpy().flatten(), bins=30)
    plt.xlabel('Stimulus Values')
    plt.ylabel('Count')
    plt.title('Stimulus Value Distribution')
    plt.tight_layout()
    plt.show()
# Diagnose optimization issues
# diagnose_optimization(stimulus, objective)
Saving and Loading Results
Save Optimized Stimuli
import os
from datetime import datetime
def save_optimization_results(stimulus, objective, metadata, save_dir="optimization_results"):
    """Persist an optimized stimulus plus its metadata to `save_dir`.

    Writes two timestamped files (`stimulus_<ts>.pt`, `metadata_<ts>.pt`).
    The `metadata` dict is updated in place with the final objective value,
    the timestamp, and the stimulus shape before being saved.

    Returns the (stimulus_path, metadata_path) pair.
    """
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_paths = {
        kind: os.path.join(save_dir, f"{kind}_{timestamp}.pt")
        for kind in ("stimulus", "metadata")
    }

    # Store the stimulus on CPU, detached from the autograd graph.
    torch.save(stimulus.detach().cpu(), file_paths["stimulus"])

    # Record the objective this stimulus achieves alongside bookkeeping info.
    metadata.update({
        'final_objective': objective.forward(stimulus).item(),
        'timestamp': timestamp,
        'stimulus_shape': list(stimulus.shape)
    })
    torch.save(metadata, file_paths["metadata"])

    print(f"Results saved:")
    print(f" Stimulus: {file_paths['stimulus']}")
    print(f" Metadata: {file_paths['metadata']}")
    return file_paths["stimulus"], file_paths["metadata"]
# Example: save the single-neuron MEI together with its optimization settings.
metadata = {
    'model_name': 'hoefling_2024_base_low_res',
    'target_neuron': target_neuron,
    'optimization_type': 'MEI',
    'learning_rate': 100.0,
    'max_iterations': 50
}
save_optimization_results(stimulus, objective, metadata)
Applications and Use Cases
1. Receptive Field Mapping
Use MEI optimization to understand what visual features each neuron prefers.
2. Cell Type Characterization
Use MDS optimization to find stimuli that best differentiate between functional cell types.
3. Model Validation
Compare optimized stimuli with known biological receptive field properties.
4. Feature Discovery
Discover unexpected visual features that strongly drive neural responses.
5. Stimulus Design
Create optimized stimuli for experimental validation in real retinal recordings.
Next Steps
After mastering stimulus optimization:
- Combine with real data: Use optimized stimuli to validate model predictions
- Explore inner representations: Optimize for intermediate model layers
- Population analysis: Study how populations of neurons respond to optimized stimuli
- Cross-species comparison: Compare optimal stimuli across different retinal models
Additional Resources
- In-silico Experiments Overview
- API Reference: Stimulus Optimization
- Example Notebook: Most Discriminative Stimuli
- Model Zoo: Available pre-trained models for optimization