Source code for pyvene.models.interventions

import torch
import numpy as np
from abc import ABC, abstractmethod
from typing import Dict, Optional, Sequence, Union, List, Any

from .layers import RotateLayer, LowRankRotateLayer, SubspaceLowRankRotateLayer, AutoencoderLayer
from .basic_utils import sigmoid_boundary
from .intervention_utils import _can_use_fast, _do_intervention_by_swap

from dataclasses import dataclass
from transformers.activations import ACT2FN
from transformers.utils import ModelOutput


@dataclass
class InterventionOutput(ModelOutput):
    """Result container produced by an intervenable model.

    Groups together the original outputs, the intervened outputs and any
    activations that were collected during the run.
    """

    # Both fields are optional and default to ``None``.
    output: Optional[Any] = None
    latent: Optional[Any] = None
class Intervention(torch.nn.Module):
    """Base class for interventions on model representations.

    Recognised keyword arguments (all optional):

    * ``keep_last_dim`` -- stored as-is, defaults to ``False``.
    * ``use_fast`` -- stored as-is, defaults to ``False``.
    * ``subspace_partition`` -- list of subspaces. An entry of length two
      whose first element is an ``int`` is read as a ``[start, end)`` range
      and expanded into explicit indices; any other entry is kept unchanged.
    * ``embed_dim`` -- when not ``None``, registered as the ``embed_dim``
      buffer and also used as the initial ``interchange_dim`` buffer.
    * ``source_representation`` -- when not ``None``, registered as a buffer
      and the intervention is flagged as having a constant source.
    * ``hidden_source_representation`` -- when not ``None``, only flags the
      source as constant; no ``source_representation`` attribute is created.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.trainable = False
        self.is_source_constant = False

        self.keep_last_dim = kwargs.get("keep_last_dim", False)
        self.use_fast = kwargs.get("use_fast", False)
        self.subspace_partition = kwargs.get("subspace_partition")

        # Turn ``[start, end)`` pairs into explicit index lists.
        # NOTE: a two-element list of discrete int indices is
        # indistinguishable from a range and is expanded as one.
        if self.subspace_partition is not None:
            self.subspace_partition = [
                list(range(subspace[0], subspace[1]))
                if len(subspace) == 2 and isinstance(subspace[0], int)
                else subspace
                for subspace in self.subspace_partition
            ]

        embed_dim = kwargs.get("embed_dim")
        if embed_dim is not None:
            self.register_buffer("embed_dim", torch.tensor(embed_dim))
            self.register_buffer("interchange_dim", torch.tensor(embed_dim))
        else:
            self.embed_dim = None
            self.interchange_dim = None

        source_representation = kwargs.get("source_representation")
        if source_representation is not None:
            self.is_source_constant = True
            self.register_buffer("source_representation", source_representation)
        elif kwargs.get("hidden_source_representation") is not None:
            self.is_source_constant = True
        else:
            # Register an empty buffer instead of assigning a plain ``None``
            # attribute: ``register_buffer`` refuses names that already exist
            # as ordinary attributes, which made a later call to
            # ``set_source_representation`` raise ``KeyError``.
            self.register_buffer("source_representation", None)

    def set_source_representation(self, source_representation):
        """Register ``source_representation`` as a buffer and mark the source constant."""
        self.is_source_constant = True
        self.register_buffer("source_representation", source_representation)

    def set_interchange_dim(self, interchange_dim):
        """Store ``interchange_dim``, converting non-tensors (int, list) to a tensor."""
        if not isinstance(interchange_dim, torch.Tensor):
            interchange_dim = torch.tensor(interchange_dim)
        self.interchange_dim = interchange_dim

    @abstractmethod
    def forward(self, base, source, subspaces=None):
        """Apply the intervention; implemented by subclasses."""
        pass
class LocalistRepresentationIntervention(torch.nn.Module):
    """Marker base class: the intervention acts on a localist representation."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but deliberately not forwarded.
        super().__init__()
        self.is_repr_distributed = False
class DistributedRepresentationIntervention(torch.nn.Module):
    """Marker base class: the intervention acts on a distributed representation."""

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but deliberately not forwarded.
        super().__init__()
        self.is_repr_distributed = True
class TrainableIntervention(Intervention):
    """Base class that flags an intervention as trainable."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Applied after the parent constructor, so these values win even if
        # the parent marked the source as constant.
        self.is_source_constant = False
        self.trainable = True

    def tie_weight(self, linked_intervention):
        """Hook for tying weights to ``linked_intervention``; does nothing here."""
class ConstantSourceIntervention(Intervention):
    """Base class for interventions whose source is flagged as constant."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Always constant, whatever the parent constructor concluded.
        self.is_source_constant = True
class SourcelessIntervention(Intervention):
    """Base class for interventions that take no source."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # A sourceless intervention is flagged the same way as a constant one.
        self.is_source_constant = True
class BasisAgnosticIntervention(Intervention):
    """Base class for interventions that alter their basis in an uncontrolled way."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Marker attribute; only instances of this hierarchy define it.
        self.basis_agnostic = True
class SharedWeightsTrainableIntervention(TrainableIntervention):
    """Trainable intervention flagged as using shared weights."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Marker attribute; only instances of this hierarchy define it.
        self.shared_weights = True
class ZeroIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention):
    """Replace activations with zeros."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, base, source=None, subspaces=None, **kwargs):
        """Interchange ``base`` with an all-zero tensor of the same shape.

        ``source`` is accepted for interface compatibility and ignored.
        """
        zero_source = torch.zeros_like(base)
        return _do_intervention_by_swap(
            base,
            zero_source,
            "interchange",
            self.interchange_dim,
            subspaces,
            subspace_partition=self.subspace_partition,
            use_fast=self.use_fast,
        )

    def __str__(self):
        return "ZeroIntervention()"
class CollectIntervention(ConstantSourceIntervention):
    """Collect activations instead of modifying them."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, base, source=None, subspaces=None, **kwargs):
        """Delegate to ``_do_intervention_by_swap`` in ``"collect"`` mode."""
        swap_options = {
            "subspace_partition": self.subspace_partition,
            "use_fast": self.use_fast,
        }
        return _do_intervention_by_swap(
            base, source, "collect", self.interchange_dim, subspaces, **swap_options
        )

    def __str__(self):
        return "CollectIntervention()"
class SkipIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
    """Skip the computation of the layer being intervened on (inside the hook)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, base, source, subspaces=None, **kwargs):
        """Interchange ``base`` with ``source``.

        Here ``source`` is the base example's input to the hook.
        """
        swap_options = {
            "subspace_partition": self.subspace_partition,
            "use_fast": self.use_fast,
        }
        return _do_intervention_by_swap(
            base, source, "interchange", self.interchange_dim, subspaces, **swap_options
        )

    def __str__(self):
        return "SkipIntervention()"
class VanillaIntervention(Intervention, LocalistRepresentationIntervention):
    """Plain interchange intervention on the original representation."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, base, source, subspaces=None, **kwargs):
        """Interchange ``base`` with the effective source.

        A stored ``source_representation`` takes precedence over the
        ``source`` argument.
        """
        if self.source_representation is None:
            effective_source = source
        else:
            effective_source = self.source_representation
        return _do_intervention_by_swap(
            base,
            effective_source,
            "interchange",
            self.interchange_dim,
            subspaces,
            subspace_partition=self.subspace_partition,
            use_fast=self.use_fast,
        )

    def __str__(self):
        return "VanillaIntervention()"
class AdditionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
    """Activation-addition intervention on the original representation."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, base, source, subspaces=None, **kwargs):
        """Run ``_do_intervention_by_swap`` in ``"add"`` mode.

        A stored ``source_representation`` takes precedence over the
        ``source`` argument.
        """
        if self.source_representation is None:
            effective_source = source
        else:
            effective_source = self.source_representation
        return _do_intervention_by_swap(
            base,
            effective_source,
            "add",
            self.interchange_dim,
            subspaces,
            subspace_partition=self.subspace_partition,
            use_fast=self.use_fast,
        )

    def __str__(self):
        return "AdditionIntervention()"
class SubtractionIntervention(BasisAgnosticIntervention, LocalistRepresentationIntervention):
    """Activation-subtraction intervention on the original representation."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, base, source, subspaces=None, **kwargs):
        """Run ``_do_intervention_by_swap`` in ``"subtract"`` mode.

        A stored ``source_representation`` takes precedence over the
        ``source`` argument.
        """
        if self.source_representation is None:
            effective_source = source
        else:
            effective_source = self.source_representation
        return _do_intervention_by_swap(
            base,
            effective_source,
            "subtract",
            self.interchange_dim,
            subspaces,
            subspace_partition=self.subspace_partition,
            use_fast=self.use_fast,
        )

    def __str__(self):
        return "SubtractionIntervention()"
class RotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):
    """Interchange intervention carried out in a learned rotated space."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The rotation is constrained by torch's orthogonal parametrization.
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            RotateLayer(self.embed_dim)
        )

    def forward(self, base, source, subspaces=None, **kwargs):
        """Rotate both inputs, interchange, then map back with the transposed weight."""
        base_rot = self.rotate_layer(base)
        source_rot = self.rotate_layer(source)

        swapped = _do_intervention_by_swap(
            base_rot,
            source_rot,
            "interchange",
            self.interchange_dim,
            subspaces,
            subspace_partition=self.subspace_partition,
            use_fast=self.use_fast,
        )

        # Undo the rotation and restore the dtype of the incoming base.
        restored = torch.matmul(swapped, self.rotate_layer.weight.T)
        return restored.to(base.dtype)

    def __str__(self):
        return "RotatedSpaceIntervention()"
class BoundlessRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):
    """Rotated-space intervention whose swapped region is set by a learned boundary."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            RotateLayer(self.embed_dim)
        )
        # Learned boundary, expressed as a fraction of ``embed_dim``.
        self.intervention_boundaries = torch.nn.Parameter(
            torch.tensor([0.5]), requires_grad=True
        )
        self.temperature = torch.nn.Parameter(torch.tensor(50.0))
        # Fixed index vector 0..embed_dim-1 that the boundary is compared against.
        self.intervention_population = torch.nn.Parameter(
            torch.arange(0, self.embed_dim), requires_grad=False
        )

    def get_boundary_parameters(self):
        return self.intervention_boundaries

    def get_temperature(self):
        return self.temperature

    def set_temperature(self, temp: torch.Tensor):
        self.temperature.data = temp

    def set_intervention_boundaries(self, intervention_boundaries):
        # Replaces the parameter object rather than updating it in place.
        self.intervention_boundaries = torch.nn.Parameter(
            torch.tensor([intervention_boundaries]), requires_grad=True
        )

    def forward(self, base, source, subspaces=None, **kwargs):
        """Blend rotated base and source with a soft boundary mask, then rotate back.

        ``subspaces`` is accepted for interface compatibility and not used.
        """
        batch_size = base.shape[0]
        rotated_base = self.rotate_layer(base)
        rotated_source = self.rotate_layer(source)

        # Keep the boundary fraction inside [1e-3, 1] before scaling it.
        clamped = torch.clamp(self.intervention_boundaries, 1e-3, 1)
        population = self.intervention_population.repeat(batch_size, 1)
        boundary_mask = sigmoid_boundary(
            population,
            0.0,
            clamped[0] * int(self.embed_dim),
            self.temperature,
        )
        ones_column = torch.ones(batch_size, device=base.device).unsqueeze(dim=-1)
        boundary_mask = (ones_column * boundary_mask).to(rotated_base.dtype)

        # Soft interchange: mask == 1 takes the source, mask == 0 keeps the base.
        rotated_output = (
            (1.0 - boundary_mask) * rotated_base + boundary_mask * rotated_source
        )

        output = torch.matmul(rotated_output, self.rotate_layer.weight.T)
        return output.to(base.dtype)

    def __str__(self):
        return "BoundlessRotatedSpaceIntervention()"
class SigmoidMaskRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):
    """Rotated-space intervention with a learned per-dimension sigmoid mask."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rotate_layer = RotateLayer(self.embed_dim)
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
        # Mask logits start large so the sigmoid mask begins close to 1.
        self.masks = torch.nn.Parameter(
            torch.tensor([100.0] * self.embed_dim), requires_grad=True
        )
        self.temperature = torch.nn.Parameter(torch.tensor(50.0))

    def get_boundary_parameters(self):
        """Return the learned mask logits.

        This class has no ``intervention_boundaries`` attribute; the per-
        dimension ``masks`` parameter plays that role, so it is what gets
        returned (the previous attribute lookup raised ``AttributeError``).
        """
        return self.masks

    def get_temperature(self):
        return self.temperature

    def set_temperature(self, temp: torch.Tensor):
        self.temperature.data = temp

    def forward(self, base, source, subspaces=None, **kwargs):
        """Blend rotated base and source with the sigmoid mask, then rotate back.

        ``subspaces`` is accepted for interface compatibility and not used.
        """
        batch_size = base.shape[0]
        rotated_base = self.rotate_layer(base)
        rotated_source = self.rotate_layer(source)

        # Mask values in (0, 1), one per rotated dimension.
        boundary_mask = torch.sigmoid(self.masks / self.temperature)
        boundary_mask = (
            torch.ones(batch_size, device=base.device).unsqueeze(dim=-1) * boundary_mask
        )
        boundary_mask = boundary_mask.to(rotated_base.dtype)

        # Soft interchange: mask == 1 takes the source, mask == 0 keeps the base.
        rotated_output = (
            1.0 - boundary_mask
        ) * rotated_base + boundary_mask * rotated_source

        # Undo the rotation and restore the dtype of the incoming base.
        output = torch.matmul(rotated_output, self.rotate_layer.weight.T)
        return output.to(base.dtype)

    def __str__(self):
        return "SigmoidMaskRotatedSpaceIntervention()"
class SigmoidMaskIntervention(TrainableIntervention, LocalistRepresentationIntervention):
    """Intervention in the original basis with a learned sigmoid mask."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mask logits start at zero, i.e. a sigmoid value of 0.5 everywhere.
        self.mask = torch.nn.Parameter(
            torch.zeros(self.embed_dim), requires_grad=True)
        self.temperature = torch.nn.Parameter(torch.tensor(0.01))

    def get_temperature(self):
        return self.temperature

    def set_temperature(self, temp: torch.Tensor):
        self.temperature.data = temp

    def forward(self, base, source, subspaces=None, **kwargs):
        """Blend ``base`` and ``source`` element-wise with the sigmoid mask.

        ``subspaces`` is accepted for interface compatibility and not used.
        The result is returned in the dtype of ``base``.
        """
        # The temperature is used as a constant here, so no gradient flows to
        # it. ``detach()`` gives that without the copy (and the UserWarning)
        # that ``torch.tensor(parameter)`` produced on every call.
        temperature = self.temperature.detach()
        mask_sigmoid = torch.sigmoid(self.mask / temperature)

        # Soft interchange: mask == 1 takes the source, mask == 0 keeps the base.
        intervened_output = (1.0 - mask_sigmoid) * base + mask_sigmoid * source

        # The mask may be wider than a half-precision base and promote the
        # result; cast back, as the other interventions in this file do.
        return intervened_output.to(base.dtype)

    def __str__(self):
        return "SigmoidMaskIntervention()"
class LowRankRotatedSpaceIntervention(TrainableIntervention, DistributedRepresentationIntervention):
    """Interchange intervention in a learned low-rank rotated space."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
        )

    def forward(self, base, source, subspaces=None, **kwargs):
        """Add the rotated source-minus-base difference back onto ``base``.

        Without ``subspaces`` every rotated column is used. With
        ``subspaces`` only the selected columns contribute; the selection is
        shared across the batch on the fast path and per example otherwise.
        """
        rotated_base = self.rotate_layer(base)
        rotated_source = self.rotate_layer(source)

        if subspaces is None:
            output = base + torch.matmul(
                (rotated_source - rotated_base), self.rotate_layer.weight.T
            )
            return output.to(base.dtype)

        if self.use_fast or _can_use_fast(subspaces):
            # Fast path: the first example's subspaces apply to the whole batch.
            if self.subspace_partition is None:
                columns = subspaces[0]
            else:
                columns = [
                    index
                    for subspace in subspaces[0]
                    for index in self.subspace_partition[subspace]
                ]
            diff = rotated_source - rotated_base
            assert rotated_base.shape[0] == len(subspaces)
            selected_diff = diff[..., columns].unsqueeze(dim=1)
            selected_basis = self.rotate_layer.weight[..., columns].T
            output = base + torch.matmul(selected_diff, selected_basis).squeeze(
                dim=1
            )
        else:
            # Slow path: every example carries its own subspace selection.
            assert self.subspace_partition is not None
            diff = rotated_source - rotated_base
            assert rotated_base.shape[0] == len(subspaces)
            per_example_diff = []
            per_example_basis = []
            for example_i, example_subspaces in enumerate(subspaces):
                # Render the subspaces as column indices.
                columns = [
                    index
                    for subspace in example_subspaces
                    for index in self.subspace_partition[subspace]
                ]
                per_example_diff.append(diff[example_i, columns].unsqueeze(dim=0))
                per_example_basis.append(self.rotate_layer.weight[..., columns].T)
            stacked_diff = torch.stack(per_example_diff, dim=0)
            stacked_basis = torch.stack(per_example_basis, dim=0)
            output = base + torch.matmul(stacked_diff, stacked_basis).squeeze(
                dim=1
            )

        return output.to(base.dtype)

    def __str__(self):
        return "LowRankRotatedSpaceIntervention()"
class PCARotatedSpaceIntervention(BasisAgnosticIntervention, DistributedRepresentationIntervention):
    """Interchange intervention carried out in PCA space."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        pca = kwargs["pca"]
        pca_mean = kwargs["pca_mean"]
        pca_std = kwargs["pca_std"]

        def _frozen(values):
            # float32 parameter that is excluded from gradient updates
            return torch.nn.Parameter(
                torch.tensor(values, dtype=torch.float32), requires_grad=False
            )

        self.pca_components = _frozen(pca.components_)
        self.pca_mean = _frozen(pca_mean)
        self.pca_std = _frozen(pca_std)
        self.trainable = False

    def forward(self, base, source, subspaces=None, **kwargs):
        """Standardise, project onto the PCA components, interchange, and invert."""
        components = self.pca_components
        mean, std = self.pca_mean, self.pca_std

        rotated_base = torch.matmul((base - mean) / std, components.T)  # B * D_R
        rotated_source = torch.matmul((source - mean) / std, components.T)

        rotated_base = _do_intervention_by_swap(
            rotated_base,
            rotated_source,
            "interchange",
            self.interchange_dim,
            subspaces,
            subspace_partition=self.subspace_partition,
        )

        # Project back and undo the standardisation.
        output = torch.matmul(rotated_base, components)  # B * D
        return (output * std) + mean

    def __str__(self):
        return "PCARotatedSpaceIntervention()"
class NoiseIntervention(ConstantSourceIntervention, LocalistRepresentationIntervention):
    """Add fixed, seeded Gaussian noise to the base activations.

    Keyword arguments:

    * ``noise_level`` -- scale applied to the noise. The historical
      misspelling ``noise_leve`` is still accepted as a fallback. Defaults
      to ``0.13462981581687927``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Fixed seed: every instance draws the same noise tensor.
        rs = np.random.RandomState(1)
        # Honour the correctly spelled key first; earlier versions only read
        # the misspelled "noise_leve", so ``noise_level=...`` was ignored.
        if "noise_level" in kwargs:
            noise_level = kwargs["noise_level"]
        elif "noise_leve" in kwargs:
            noise_level = kwargs["noise_leve"]
        else:
            noise_level = 0.13462981581687927
        self.register_buffer('noise', torch.from_numpy(
            rs.randn(1, 4, self.embed_dim)))
        self.register_buffer('noise_level', torch.tensor(noise_level))

    def forward(self, base, source=None, subspaces=None, **kwargs):
        """Add the scaled noise to the first ``interchange_dim`` features.

        NOTE: ``base`` is modified in place and also returned. ``source`` and
        ``subspaces`` are accepted for interface compatibility and not used.
        """
        base[..., : self.interchange_dim] += self.noise * self.noise_level
        return base

    def __str__(self):
        return "NoiseIntervention()"
class AutoencoderIntervention(TrainableIntervention):
    """Intervene in the latent space of an autoencoder.

    Keyword arguments:

    * ``latent_dim`` -- required; passed to ``AutoencoderLayer``.
    * ``embed_dim`` -- passed to ``AutoencoderLayer`` as a tensor.

    Raises:
        ValueError: if ``latent_dim`` is missing from the keyword arguments.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "latent_dim" not in kwargs:
            raise ValueError('Missing latent_dim in kwargs.')

        if "embed_dim" in kwargs:
            self.embed_dim = torch.tensor(kwargs["embed_dim"])
        self.autoencoder = AutoencoderLayer(
            self.embed_dim, kwargs["latent_dim"])

    def forward(self, base, source, subspaces=None, **kwargs):
        """Encode both inputs, copy the selected latents from source, decode.

        ``subspaces`` is accepted for interface compatibility and not used.
        The result is returned in the original dtype of ``base``.
        """
        base_dtype = base.dtype
        # Encode both inputs in the autoencoder's own dtype. Previously only
        # ``base`` was cast, so a source in a different precision hit a dtype
        # mismatch inside the encoder.
        autoencoder_dtype = self.autoencoder.encoder[0].weight.dtype
        base = base.to(autoencoder_dtype)
        source = source.to(autoencoder_dtype)

        base_latent = self.autoencoder.encode(base)
        source_latent = self.autoencoder.encode(source)
        # In-place overwrite of the selected latent entries with the source's.
        base_latent[..., self.interchange_dim] = source_latent[..., self.interchange_dim]
        inv_output = self.autoencoder.decode(base_latent)

        return inv_output.to(base_dtype)

    def __str__(self):
        return "AutoencoderIntervention()"
class JumpReLUAutoencoderIntervention(TrainableIntervention):
    """Interchange intervention on the latent subspaces of a JumpReLU SAE.

    Keyword arguments:

    * ``embed_dim`` -- width of the activations being encoded.
    * ``low_rank_dimension`` -- required; number of SAE latents.

    ``keep_last_dim`` is always forced to ``True``.
    """

    def __init__(self, **kwargs):
        # Force keep_last_dim=True by merging it into the kwargs. Passing it
        # as an extra keyword next to ``**kwargs`` raised ``TypeError``
        # whenever the caller supplied ``keep_last_dim`` too.
        super().__init__(**{**kwargs, "keep_last_dim": True})

        # All weights start at zero because pre-trained SAE weights are
        # expected to be loaded afterwards; this class does not provide a
        # meaningful initialisation for training an SAE from scratch.
        latent_dim = kwargs["low_rank_dimension"]
        self.W_enc = torch.nn.Parameter(torch.zeros(self.embed_dim, latent_dim))
        self.W_dec = torch.nn.Parameter(torch.zeros(latent_dim, self.embed_dim))
        self.threshold = torch.nn.Parameter(torch.zeros(latent_dim))
        self.b_enc = torch.nn.Parameter(torch.zeros(latent_dim))
        self.b_dec = torch.nn.Parameter(torch.zeros(self.embed_dim))

    def encode(self, input_acts):
        """Return SAE latents: ReLU pre-activations gated by the learned threshold."""
        pre_acts = input_acts @ self.W_enc + self.b_enc
        # Only pre-activations strictly above the threshold pass through.
        mask = (pre_acts > self.threshold)
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def decode(self, acts):
        """Map latents back to activation space."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, base, source=None, subspaces=None, **kwargs):
        """Encode base and source, interchange latents, and decode the result.

        NOTE: ``source`` defaults to ``None`` in the signature but is encoded
        unconditionally, so a tensor must be supplied.
        """
        # Generate latents for the base and source runs.
        base_latent = self.encode(base)
        source_latent = self.encode(source)

        intervened_latent = _do_intervention_by_swap(
            base_latent,
            source_latent,
            "interchange",
            self.interchange_dim,
            subspaces,
            subspace_partition=self.subspace_partition,
            use_fast=self.use_fast,
        )

        # Decode the intervened latent.
        recon = self.decode(intervened_latent)
        return recon

    def __str__(self):
        return "JumpReLUAutoencoderIntervention()"