Utilities API Reference
The utils package provides helper functions and utilities used throughout OpenRetina for file handling, visualization, model management, and data processing.
Overview
The utils module contains:
- File Utilities: Functions for downloading, caching, and file management
- Plotting: Visualization tools for stimuli, responses, and model components
- Model Utilities: Helper functions for model training and evaluation
- Data Handling: Tools for working with HDF5 files and data formats
- Miscellaneous: General utility functions for reproducibility and debugging
File Utilities
File Management
Unzips a file and removes the zip archive.
- If the ZIP contains only one file, extracts directly to the same directory.
- If the ZIP contains a single top-level folder named after the ZIP, extracts without duplicating the folder.
- If it contains multiple files, extracts into a folder named after the ZIP.
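The three extraction rules above can be sketched with the standard-library `zipfile` module. This is an illustration, not OpenRetina's implementation: the function name `unzip_and_remove_sketch` is hypothetical, and unlike a production version it does not guard against archive entries with unsafe paths.

```python
import zipfile
from pathlib import Path


def unzip_and_remove_sketch(zip_path):
    """Hypothetical sketch: extract a ZIP following the three rules, then delete it."""
    zip_path = Path(zip_path)
    parent = zip_path.parent
    folder_name = zip_path.stem  # folder name derived from the ZIP file name

    with zipfile.ZipFile(zip_path) as zf:
        names = zf.namelist()
        files = [n for n in names if not n.endswith("/")]  # skip directory entries
        top_level = {n.split("/")[0] for n in names}

        if len(files) == 1:
            # Rule 1: write the single file directly next to the ZIP
            result = parent / Path(files[0]).name
            result.write_bytes(zf.read(files[0]))
        elif top_level == {folder_name}:
            # Rule 2: the archive already wraps everything in <folder_name>/,
            # so extract in place instead of creating <folder_name>/<folder_name>/
            zf.extractall(parent)
            result = parent / folder_name
        else:
            # Rule 3: several files go into a new folder named after the ZIP
            result = parent / folder_name
            zf.extractall(result)

    zip_path.unlink()  # remove the archive once it is closed and extracted
    return result
```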
Cache Management
Plotting and Visualization
Stimulus Visualization
Video Processing
HDF5 Data Handling
File Operations
Loads a dataset from an HDF5 file.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `file_path` | `str` | Path to the HDF5 file. | required |
| `dataset_path` | `str` | Path to the dataset within the HDF5 file. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `data` | `ndarray` | Data of the loaded dataset. |
Data Structure Exploration
Converts an HDF5 file to a folder structure.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `file_path` | `str` | The path to the HDF5 file. | required |
| `output_dir` | `str` | The directory where the folder structure will be created. | required |
Model Utilities
Training Helpers
Video Analysis
Frequency Analysis
Scales the spatial and temporal components so that the spatial component lies in the range [-1, 1].

Parameters:

| Name | Description |
|---|---|
| `space_time_kernel` | Kernel of shape `[time, y_shape, x_shape]`. |
| `scaling_factor` | Optional scaling factor. |

Returns:
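The docstring does not say how the kernel is split into spatial and temporal parts. A minimal sketch, assuming a rank-1 SVD decomposition (a swapped-in technique, not necessarily what OpenRetina does) and assuming `scaling_factor` simply multiplies the temporal component; the function name `decompose_and_scale_sketch` is hypothetical. The key idea is that dividing the spatial part by its maximum absolute value and multiplying the temporal part by the same value leaves their outer product unchanged.

```python
import numpy as np


def decompose_and_scale_sketch(space_time_kernel, scaling_factor=1.0):
    """Hypothetical: split a [time, y, x] kernel into temporal and spatial parts."""
    n_t, n_y, n_x = space_time_kernel.shape
    # Flatten space so the kernel becomes a (time x space) matrix
    matrix = space_time_kernel.reshape(n_t, n_y * n_x)
    u, s, vt = np.linalg.svd(matrix, full_matrices=False)

    # Rank-1 approximation: kernel ~ outer(temporal, spatial)
    temporal = u[:, 0] * s[0]
    spatial = vt[0].reshape(n_y, n_x)

    # Rescale so the spatial component lies in [-1, 1]; the temporal component
    # absorbs the inverse factor, so their outer product is unchanged
    max_abs = np.abs(spatial).max()
    spatial = spatial / max_abs
    temporal = temporal * max_abs

    # Assumption: the optional scaling factor is applied to the temporal part
    return temporal * scaling_factor, spatial
```

For a separable kernel, `temporal[:, None, None] * spatial[None]` reproduces the input when `scaling_factor=1.0`.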
Filtering
Miscellaneous Utilities
Reproducibility
Controls randomness by seeding the random number generators. The NumPy and `random` modules must be imported.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `seed` | `int` | A non-negative integer that defines the random state. | `None` |
| `seed_torch` | `bool` | If `True`, also sets the random seed for PyTorch tensors, so the PyTorch module must be imported. | `True` |

Returns:

| Name | Type | Description |
|---|---|---|
| `seed` | `int` | Integer corresponding to the random state. |
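A minimal sketch of what such a seeding function typically does, assuming that a missing seed is replaced by a randomly drawn one (the docstring does not say). The name `set_seed_sketch` is hypothetical; in practice use `set_seed` from `openretina.utils.misc`.

```python
import random

import numpy as np


def set_seed_sketch(seed=None, seed_torch=True):
    """Hypothetical: seed Python, NumPy and (optionally) PyTorch RNGs."""
    if seed is None:
        # Assumption: draw a fresh seed when none is given
        seed = random.randint(0, 2**32 - 1)  # NumPy accepts seeds in [0, 2**32 - 1]
    random.seed(seed)
    np.random.seed(seed)
    if seed_torch:
        import torch  # only imported when PyTorch seeding is requested

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored without a GPU
    return seed
```

Returning the seed makes it possible to log a randomly drawn value and reproduce the run later.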
Output Capture
Usage Examples
File Management
```python
from openretina.utils.file_utils import get_local_file_path, get_cache_directory

# Download and cache a remote file
remote_url = "https://example.com/data/file.h5"
local_path = get_local_file_path(remote_url)
print(f"File cached at: {local_path}")

# Check the cache directory
cache_dir = get_cache_directory()
print(f"Cache directory: {cache_dir}")

# Download to a custom cache location
custom_cache = "/path/to/custom/cache"
local_path = get_local_file_path(remote_url, custom_cache)
```
Stimulus Visualization
```python
import matplotlib.pyplot as plt
import torch
from openretina.utils.plotting import plot_stimulus_composition, save_stimulus_to_mp4_video

# Create a sample stimulus (channels=2, time=50, height=16, width=18)
stimulus = torch.randn(2, 50, 16, 18)

# Plot the stimulus composition
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
plot_stimulus_composition(
    stimulus=stimulus.numpy(),
    temporal_trace_ax=axes[0, 0],
    freq_ax=axes[0, 1],
    spatial_ax=axes[1, 0],
    highlight_x_list=[(10, 20), (30, 40)],  # highlight specific time ranges
)
plt.show()

# Save the stimulus as a video
save_stimulus_to_mp4_video(
    stimulus=stimulus,
    filepath="stimulus_video.mp4",
    fps=30,
    start_at_frame=0,
)
```
HDF5 Data Handling
```python
from openretina.utils.h5_handling import load_dataset_from_h5, load_h5_into_dict, print_h5_structure

# Explore the HDF5 file structure
print_h5_structure("data/responses.h5")

# Load all datasets into a dictionary
data_dict = load_h5_into_dict("data/responses.h5")
print("Available datasets:")
for key, value in data_dict.items():
    if hasattr(value, "shape"):
        print(f"  {key}: {value.shape}")
    else:
        print(f"  {key}: {type(value)}")

# Load a single dataset by its path within the file, then keep the first 1000 entries
responses_data = load_dataset_from_h5("data/responses.h5", "session1/responses")
responses_subset = responses_data[:1000]
```
Model Training Utilities
```python
import torch
from openretina.utils.misc import set_seed
from openretina.utils.model_utils import eval_state

# Set a reproducible seed
set_seed(42, seed_torch=True)

# Use the eval_state context manager
model = torch.nn.Linear(10, 1)
model.train()  # model in training mode
with eval_state(model):
    # Model is temporarily in eval mode
    print(f"In context: training={model.training}")
    output = model(torch.randn(5, 10))
# Model is back in its original (training) mode
print(f"After context: training={model.training}")
```
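A context manager with this behaviour can be written in a few lines. The sketch below is not OpenRetina's `eval_state`; the name `eval_state_sketch` is hypothetical. It relies only on the `training` attribute and the `eval()` and `train(mode)` methods that `torch.nn.Module` provides, so it also works with any object exposing the same interface.

```python
from contextlib import contextmanager


@contextmanager
def eval_state_sketch(model):
    """Hypothetical: temporarily switch a torch-style module to eval mode."""
    was_training = model.training  # remember the mode on entry
    try:
        model.eval()
        yield model
    finally:
        # Restore the original mode even if the body raised an exception
        model.train(was_training)
```

Using `try`/`finally` ensures that an exception inside the `with` block does not leave the model stuck in eval mode.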
Video Analysis
```python
import matplotlib.pyplot as plt
import numpy as np
from openretina.utils.video_analysis import calculate_fft, weighted_main_frequency

# Analyze the frequency content of a temporal kernel
temporal_kernel = np.random.randn(50)  # 50-frame temporal kernel
sampling_freq = 30.0  # 30 Hz

# Calculate the FFT
frequencies, fft_magnitude = calculate_fft(
    temporal_kernel=temporal_kernel,
    sampling_frequency=sampling_freq,
    lowpass_cutoff=10.0,
)

# Find the dominant frequency
main_freq = weighted_main_frequency(frequencies, fft_magnitude)
print(f"Dominant frequency: {main_freq:.2f} Hz")

# Plot the kernel and its frequency spectrum
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(temporal_kernel)
plt.title("Temporal Kernel")
plt.xlabel("Frame")
plt.subplot(1, 2, 2)
plt.plot(frequencies, fft_magnitude)
plt.axvline(main_freq, color="red", linestyle="--", label=f"Main freq: {main_freq:.2f} Hz")
plt.title("Frequency Spectrum")
plt.xlabel("Frequency (Hz)")
plt.legend()
plt.tight_layout()
plt.show()
```
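To make the two steps concrete, here is a NumPy-only sketch of a low-pass-limited FFT followed by a "weighted main frequency". It assumes that `lowpass_cutoff` discards frequencies above the cutoff and that the weighted main frequency is the magnitude-weighted mean frequency (the spectral centroid); both are assumptions, and the `_sketch` function names are hypothetical rather than OpenRetina's code.

```python
import numpy as np


def calculate_fft_sketch(temporal_kernel, sampling_frequency, lowpass_cutoff):
    """Hypothetical: one-sided FFT magnitude, truncated at a low-pass cutoff."""
    n = len(temporal_kernel)
    frequencies = np.fft.rfftfreq(n, d=1.0 / sampling_frequency)
    magnitude = np.abs(np.fft.rfft(temporal_kernel))
    keep = frequencies <= lowpass_cutoff  # drop components above the cutoff
    return frequencies[keep], magnitude[keep]


def weighted_main_frequency_sketch(frequencies, magnitude):
    """Magnitude-weighted mean frequency (spectral centroid)."""
    return float(np.sum(frequencies * magnitude) / np.sum(magnitude))


# A pure 5 Hz sinusoid sampled at 30 Hz for 2 s (exactly 10 cycles)
fs = 30.0
t = np.arange(60) / fs
kernel = np.sin(2 * np.pi * 5.0 * t)
freqs, mag = calculate_fft_sketch(kernel, fs, lowpass_cutoff=10.0)
print(round(weighted_main_frequency_sketch(freqs, mag), 3))  # 5.0
```

A centroid is less sensitive to a single noisy peak than `argmax`, but for broadband kernels it can fall between peaks, so it reflects the overall balance of the spectrum rather than one component.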
Advanced Stimulus Plotting
```python
import matplotlib.pyplot as plt
import numpy as np
from openretina.utils.plotting import plot_stimulus_composition

# Create a more complex stimulus with patterns (channels, time, height, width)
time_steps, height, width = 100, 32, 32
stimulus = np.zeros((2, time_steps, height, width))

# The spatial grid does not change over time, so build it once
x = np.linspace(0, 4 * np.pi, width)
y = np.linspace(0, 4 * np.pi, height)
X, Y = np.meshgrid(x, y)

for t in range(time_steps):
    phase = t * 0.2
    # Moving grating in the UV channel
    stimulus[0, t] = np.sin(X + phase) * np.cos(Y)
    # Different pattern in the Green channel
    stimulus[1, t] = np.cos(X - phase) * np.sin(Y + phase)

# Plot with custom highlighting
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
highlight_periods = [(20, 30), (50, 60), (80, 90)]
plot_stimulus_composition(
    stimulus=stimulus,
    temporal_trace_ax=axes[0, 0],
    freq_ax=axes[0, 1],
    spatial_ax=axes[1, 0],
    highlight_x_list=highlight_periods,
)

# Add a custom analysis in the fourth subplot
mean_intensity = np.mean(stimulus, axis=(2, 3))  # average over spatial dimensions
axes[1, 1].plot(mean_intensity[0], label="UV Channel", alpha=0.7)
axes[1, 1].plot(mean_intensity[1], label="Green Channel", alpha=0.7)
axes[1, 1].set_xlabel("Time (frames)")
axes[1, 1].set_ylabel("Mean Intensity")
axes[1, 1].set_title("Temporal Evolution")
axes[1, 1].legend()

# Highlight the same periods
for start, end in highlight_periods:
    axes[1, 1].axvspan(start, end, alpha=0.2, color="red")

plt.suptitle("Comprehensive Stimulus Analysis", fontsize=16)
plt.tight_layout()
plt.show()
```
Batch File Processing
```python
import shutil
from pathlib import Path

from openretina.utils.file_utils import get_local_file_path
from openretina.utils.h5_handling import load_h5_into_dict


def process_multiple_files(file_urls, output_dir):
    """Download multiple files, inspect them by type, and copy them to output_dir."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    processed_files = []

    for i, url in enumerate(file_urls):
        print(f"Processing file {i + 1}/{len(file_urls)}: {url}")
        try:
            # Download (or reuse the cached copy of) the file
            local_path = Path(get_local_file_path(url))

            # Process based on file type
            if local_path.suffix == ".h5":
                data = load_h5_into_dict(local_path)
                print(f"  Loaded {len(data)} datasets")
            elif local_path.suffix in [".mp4", ".avi"]:
                print(f"  Video file: {local_path}")

            # Copy to the output directory
            output_path = output_dir / f"processed_{i}_{local_path.name}"
            shutil.copy2(local_path, output_path)
            processed_files.append(output_path)
        except Exception as e:
            print(f"  Error processing {url}: {e}")

    return processed_files


# Example usage
urls = [
    "https://example.com/data1.h5",
    "https://example.com/data2.h5",
    "https://example.com/video1.mp4",
]
# processed_files = process_multiple_files(urls, "output_data/")
```
Custom Output Capture
```python
import warnings

from openretina.utils.capture_output import CaptureOutputAndWarnings

# Capture both stdout and warnings
with CaptureOutputAndWarnings() as captured:
    print("This will be captured")
    warnings.warn("This warning will be captured")
    print("More output")

print("Captured stdout:")
print(captured.stdout)
print("\nCaptured warnings:")
print(captured.warnings)
```
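Capturing output and warnings together can be done with the standard library alone. The sketch below combines `contextlib.redirect_stdout` and `warnings.catch_warnings(record=True)` through an `ExitStack`; it is not OpenRetina's `CaptureOutputAndWarnings`. The class name `CaptureSketch` is hypothetical, and the `stdout` and `warnings` attribute names simply mirror the example above.

```python
import io
import warnings
from contextlib import ExitStack, redirect_stdout


class CaptureSketch:
    """Hypothetical stand-in: collect printed text and warning messages."""

    def __enter__(self):
        self._stack = ExitStack()
        self._buffer = io.StringIO()
        self._stack.enter_context(redirect_stdout(self._buffer))
        # record=True returns a list that receives every warning raised inside
        self._records = self._stack.enter_context(warnings.catch_warnings(record=True))
        warnings.simplefilter("always")  # record repeated warnings too
        return self

    def __exit__(self, *exc_info):
        self._stack.__exit__(*exc_info)  # restores stdout and warning filters
        self.stdout = self._buffer.getvalue()
        self.warnings = [str(w.message) for w in self._records]
        return False  # never suppress exceptions from the with-block
```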
Configuration and Constants
```python
# Access package constants
from openretina.utils.constants import *

# File utilities use these constants
from openretina.utils.file_utils import HUGGINGFACE_REPO_ID, GIN_BASE_URL

print(f"Default HuggingFace repo: {HUGGINGFACE_REPO_ID}")
print(f"GIN repository base URL: {GIN_BASE_URL}")
```
Performance Tips
- File Caching: Downloaded files are cached automatically to avoid repeated downloads.
- HDF5 Loading: Use `load_dataset_from_h5` to load only the datasets you need instead of reading the whole file with `load_h5_into_dict`.
- Video Rendering: Adjust the frame rate (`fps`) and other output settings for video output to your needs.
- Memory Management: Large visualizations can consume significant memory; consider reducing the resolution.
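The download-once caching idea from the first tip can be illustrated with a small sketch. It is not OpenRetina's `get_local_file_path`: the function name `cached_download_sketch`, the URL-hash naming scheme, the `.part` temporary file and the injected `download_fn(url, dest)` callback are all assumptions chosen so the example runs offline.

```python
import hashlib
from pathlib import Path
from urllib.parse import urlparse


def cached_download_sketch(url, cache_dir, download_fn):
    """Hypothetical: download `url` once and reuse the cached copy afterwards."""
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Key the cache on a hash of the URL, keeping the original file name readable
    key = hashlib.sha256(url.encode()).hexdigest()[:16]
    local_path = cache_dir / f"{key}_{Path(urlparse(url).path).name}"
    if not local_path.exists():
        tmp_path = local_path.with_name(local_path.name + ".part")
        download_fn(url, tmp_path)  # only called on a cache miss
        # Rename into place so an interrupted download never looks like a cache hit
        tmp_path.replace(local_path)
    return local_path
```

Hashing the full URL keeps files with the same name from different sources apart, while the readable suffix makes the cache folder easy to inspect.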
Troubleshooting
Common Issues
Download failures:
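```python
import requests
from openretina.utils.file_utils import get_local_file_path

url = "https://example.com/data/file.h5"  # the URL that failed to download

# Check the internet connection and URL accessibility
response = requests.head(url, allow_redirects=True, timeout=10)
print(f"Status: {response.status_code}")

# Use a custom cache directory if there are permission issues
custom_cache = "/tmp/openretina_cache"
local_path = get_local_file_path(url, custom_cache)
```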
Visualization memory issues:
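```python
# Assumes `stimulus` (channels, time, height, width) and `axes` from the examples above
axes_kwargs = dict(temporal_trace_ax=axes[0, 0], freq_ax=axes[0, 1], spatial_ax=axes[1, 0])

# Reduce stimulus size for plotting: keep every frame, halve the spatial resolution
stimulus_subset = stimulus[:, :, ::2, ::2]
plot_stimulus_composition(stimulus_subset, **axes_kwargs)

# Or plot only a specific time range
start_frame, end_frame = 20, 50
plot_stimulus_composition(stimulus[:, start_frame:end_frame], **axes_kwargs)
```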
HDF5 file access errors:
```python
import os

import h5py

file_path = "data/responses.h5"  # the file that fails to load

# Check that the file exists and is readable
assert os.path.exists(file_path), f"File not found: {file_path}"
assert os.access(file_path, os.R_OK), f"Cannot read file: {file_path}"

# Check file integrity: h5py raises OSError for corrupt or non-HDF5 files
try:
    with h5py.File(file_path, "r") as f:
        print("File opened successfully")
except OSError as e:
    print(f"File error: {e}")
```
Configuration
Utilities can be configured through environment variables:
```bash
# Set custom cache directory
export OPENRETINA_CACHE_DIR="/path/to/cache"

# Set download timeout
export OPENRETINA_DOWNLOAD_TIMEOUT="60"

# Set matplotlib backend for headless environments
export MPLBACKEND="Agg"
```
See Also
- Installation Guide: Setting up the environment
- Models API: Using utilities with models
- Data I/O API: Data handling utilities
- Plotting Examples: Visualization examples