Source code for pyvene.models.layers
from abc import ABCMeta, abstractmethod
import torch
[docs]
class InverseRotateLayer(torch.nn.Module):
    """Undo a rotation by multiplying with the transposed weight of `lin_layer`.

    The transpose is a true inverse only when the wrapped layer's weight is
    an orthogonal matrix; no check of that property is performed here.
    """

    def __init__(self, lin_layer):
        """Keep a reference to the layer whose `weight` will be transposed.

        Args:
            lin_layer: A module exposing a 2-D `weight` tensor. The weight is
                shared with this module, not copied.
        """
        super().__init__()
        self.lin_layer = lin_layer

    def forward(self, x):
        """Return `x @ weight.T`, computed in the dtype of the weight.

        Args:
            x: Tensor whose last dimension matches the last dimension of
                `lin_layer.weight`.

        Returns:
            The product as a tensor in the weight's dtype.
        """
        weight = self.lin_layer.weight
        # Cast like the forward rotate layers do, so that e.g. half-precision
        # activations do not trigger a dtype-mismatch error in matmul.
        return torch.matmul(x.to(weight.dtype), weight.T)
[docs]
class RotateLayer(torch.nn.Module):
    """Square `n x n` linear map whose weight can start out orthogonal."""

    def __init__(self, n, init_orth=True):
        """Allocate the `n x n` weight and optionally orthogonalise it.

        Args:
            n: Size of both weight dimensions.
            init_orth: When true, fill the weight with
                `torch.nn.init.orthogonal_`. When false, the weight keeps
                whatever `torch.empty` returned, which is only sensible if
                the values are overwritten later (for example from a saved
                checkpoint).
        """
        super().__init__()
        initial = torch.empty(n, n)
        if init_orth:
            # orthogonal_ works in place and hands back the same tensor.
            initial = torch.nn.init.orthogonal_(initial)
        self.weight = torch.nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        """Return `x @ weight` after casting `x` to the weight's dtype."""
        cast_input = x.to(self.weight.dtype)
        return cast_input @ self.weight
[docs]
class LowRankRotateLayer(torch.nn.Module):
    """Rectangular `n x m` linear map with optional orthogonal initialization."""

    def __init__(self, n, m, init_orth=True):
        """Allocate the `n x m` weight and optionally orthogonalise it.

        Args:
            n: Number of weight rows (input width). The original author notes
                the intended use is n > m; this is not enforced.
            m: Number of weight columns (output width).
            init_orth: When true, fill the weight with
                `torch.nn.init.orthogonal_`; otherwise it is left as returned
                by `torch.empty`.
        """
        super().__init__()
        initial = torch.empty(n, m)
        if init_orth:
            torch.nn.init.orthogonal_(initial)
        self.weight = torch.nn.Parameter(initial, requires_grad=True)

    def forward(self, x):
        """Return `x @ weight` after casting `x` to the weight's dtype."""
        cast_input = x.to(self.weight.dtype)
        return cast_input @ self.weight
[docs]
class SubspaceLowRankRotateLayer(torch.nn.Module):
    """`n x m` linear map that projects onto a chosen slice of its columns."""

    def __init__(self, n, m, init_orth=True):
        """Allocate the `n x m` weight and optionally orthogonalise it.

        Args:
            n: Number of weight rows (input width). The original author notes
                the intended use is n > m; this is not enforced.
            m: Number of weight columns.
            init_orth: When true, fill the weight with
                `torch.nn.init.orthogonal_`; otherwise it is left as returned
                by `torch.empty`.
        """
        super().__init__()
        initial = torch.empty(n, m)
        if init_orth:
            torch.nn.init.orthogonal_(initial)
        self.weight = torch.nn.Parameter(initial, requires_grad=True)

    def forward(self, x, l, r):
        """Multiply `x` by weight columns `l:r`.

        Args:
            x: Input tensor; it is cast to the weight's dtype first.
            l: Start column of the slice (inclusive).
            r: End column of the slice (exclusive).

        Returns:
            `x @ weight[:, l:r]`.
        """
        column_slice = self.weight[:, l:r]
        cast_input = x.to(self.weight.dtype)
        return cast_input @ column_slice
[docs]
class AutoencoderLayerBase(torch.nn.Module, metaclass=ABCMeta):
    """Abstract interface for autoencoder modules.

    Concrete subclasses must provide both `encode` and `decode`; the class
    itself cannot be instantiated while either remains abstract.
    """

    @abstractmethod
    def encode(self, x):
        """Map an input `x` to its latent representation."""

    @abstractmethod
    def decode(self, latent):
        """Map a latent representation back to the input space."""
[docs]
class AutoencoderLayer(AutoencoderLayerBase):
    """Autoencoder made of one Linear+ReLU encoder and one Linear decoder."""

    def __init__(self, input_dim, latent_dim, **kwargs):
        """Build the encoder and decoder stacks.

        Args:
            input_dim: Width of the input (and of the reconstruction).
            latent_dim: Width of the latent representation.
            **kwargs: Accepted but not used by this constructor.
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # Create the encoder's Linear before the decoder's so parameter
        # initialization draws random numbers in the same order as before.
        encoder_linear = torch.nn.Linear(input_dim, latent_dim, bias=True)
        self.encoder = torch.nn.Sequential(encoder_linear, torch.nn.ReLU())
        decoder_linear = torch.nn.Linear(latent_dim, input_dim, bias=True)
        self.decoder = torch.nn.Sequential(decoder_linear)

    def encode(self, x):
        """Return the latent for `x`.

        The input is cast to the encoder weight's dtype and the decoder bias
        is subtracted from it before the encoder is applied.
        """
        shifted = x.to(self.encoder[0].weight.dtype) - self.decoder[0].bias
        return self.encoder(shifted)

    def decode(self, latent):
        """Return the reconstruction produced by the decoder for `latent`."""
        return self.decoder(latent)

    def forward(self, base, return_latent=False):
        """Encode and then decode `base`.

        Args:
            base: Input tensor.
            return_latent: Selects the return shape, see below.

        Returns:
            If `return_latent` is false, the reconstruction cast back to the
            dtype `base` had on entry. Otherwise a dict with keys 'latent'
            and 'output'; in that case the reconstruction is returned as
            produced by the decoder, without the cast back.
        """
        incoming_dtype = base.dtype
        latent = self.encode(base.to(self.encoder[0].weight.dtype))
        reconstruction = self.decode(latent)
        if return_latent:
            return {'latent': latent, 'output': reconstruction}
        return reconstruction.to(incoming_dtype)