Skip to content

Sparse Autoencoder Models

SparsityMSELoss

Source code in openretina/models/sparse_autoencoder.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class SparsityMSELoss:
    """Reconstruction (MSE) plus L1-sparsity loss for a sparse autoencoder.

    The total loss is ``mse + sparsity_factor * l1(z)``, where ``z`` is the
    hidden activation of the autoencoder.
    """

    def __init__(self, sparsity_factor: float):
        # Weight of the L1 term relative to the reconstruction term.
        self.sparsity_factor = sparsity_factor

    @staticmethod
    def mse_loss(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
        """Mean over all examples, sum over hidden neurons"""
        # Elementwise squared error; sum the last (neuron) dim, average the rest.
        elementwise = nn.functional.mse_loss(x, x_hat, reduction="none")
        return elementwise.sum(dim=-1).mean()

    @staticmethod
    def sparsity_loss(z: torch.Tensor) -> torch.Tensor:
        # The anthropic paper just sums over all neurons
        # Make sure the interpolation factor is small enough to not have a dominating sparsity loss
        return torch.abs(z).sum()

    def forward(
        self, x: torch.Tensor, z: torch.Tensor, x_hat: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (total, mse, sparsity) loss terms for one batch."""
        reconstruction = self.mse_loss(x, x_hat)
        sparsity = self.sparsity_loss(z)
        combined = reconstruction + self.sparsity_factor * sparsity
        return combined, reconstruction, sparsity

mse_loss(x, x_hat) staticmethod

Mean over all examples, sum over hidden neurons

Source code in openretina/models/sparse_autoencoder.py
13
14
15
16
17
18
@staticmethod
def mse_loss(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """Mean over all examples, sum over hidden neurons"""
    # Per-element squared errors, reduced by summing the neuron dim
    # and averaging over all remaining (example) dims.
    elementwise = nn.functional.mse_loss(x, x_hat, reduction="none")
    return elementwise.sum(dim=-1).mean()