Analyzing Sparse Autoencoders (SAEs) from Gemma Scope#


__author__ = "Zhengxuan Wu"
__version__ = "09/23/2024"

Overview#

This tutorial aims to (1) reproduce and (2) extend some of the results from the Gemma Scope SAE tutorial notebook on interpreting SAE latents. It also demonstrates basic model steering with SAEs. The notebook is built as a showcase for the Gemma 2 2B model and its SAEs, but the same workflow extends to other models and their SAEs.

Note: This tutorial assumes SAEs are pretrained separately.

Set-up#

try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyvene

except ModuleNotFoundError:
    !pip install git+https://github.com/stanfordnlp/pyvene.git
from pyvene import (
    ConstantSourceIntervention,
    SourcelessIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
    CollectIntervention,
    JumpReLUAutoencoderIntervention
)

from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
import torch.nn as nn

# If you haven't logged in to the Hugging Face Hub yet, you need to do so.
# notebook_login()

Loading the model and its tokenizer#

torch.set_grad_enabled(False) # avoid blowing up mem

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b", # google/gemma-2b-it
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

We give the model the prompt “Would you be able to travel through time using a wormhole?” and print the generated output.

# The input text
prompt = "Would you be able to travel through time using a wormhole?"

# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special "Beginning of Sequence" or <bos> token to the start
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to("cuda")
print(inputs)

# Pass it in to the model and generate text
outputs = model.generate(input_ids=inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
tensor([[     2,  18925,    692,    614,   3326,    577,   5056,   1593,   1069,
           2177,    476,  47420,  18216, 235336]], device='cuda:0')
<bos>Would you be able to travel through time using a wormhole?

[Answer 1]

Yes, you can travel through time using a wormhole.

A wormhole is a theoretical object that connects two points in space-time. It is a tunnel through space-time that allows objects to travel from

Loading an SAE and creating SAE interventions#

pyvene can load SAEs as interventions for analyzing latents as well as model steering.
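
If you want to explore other layers, widths, or sparsity levels, you can first list what the Gemma Scope residual-stream repo provides (an optional snippet using huggingface_hub's list_repo_files; the filter below is just one way to find the params.npz files):

from huggingface_hub import list_repo_files

# Optional: list the SAEs available in the Gemma Scope residual-stream repo,
# e.g. to pick a different layer, width, or average L0.
sae_files = [f for f in list_repo_files("google/gemma-scope-2b-pt-res") if f.endswith("params.npz")]
print(sae_files[:5])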

LAYER = 20
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename=f"layer_{LAYER}/width_16k/average_l0_71/params.npz",
    force_download=False,
)
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}
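
Before wrapping these weights in an intervention, it helps to check what we actually downloaded. The cell below (an optional sanity check) just prints each parameter's shape; W_enc should be (embed_dim, num_latents) and W_dec the reverse.

# Optional: inspect the SAE parameter shapes we just loaded.
# W_enc projects residual-stream activations into the (much wider) latent space;
# W_dec projects latent activations back to the residual stream.
for name, value in pt_params.items():
    print(name, tuple(value.shape))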

Implementing SAEs as pyvene-native Interventions#

Create a pyvene-native intervention that wraps the SAE and collects its latent activations.

class JumpReLUSAECollectIntervention(
    CollectIntervention
):
  """Collect activations"""
  def __init__(self, **kwargs):
    # Note that we initialise these to zeros because we're loading in pre-trained weights.
    # If you want to train your own SAEs, you will need a proper (non-zero) initialisation.
    super().__init__(**kwargs, keep_last_dim=True)
    self.W_enc = nn.Parameter(torch.zeros(self.embed_dim, kwargs["low_rank_dimension"]))
    self.W_dec = nn.Parameter(torch.zeros(kwargs["low_rank_dimension"], self.embed_dim))
    self.threshold = nn.Parameter(torch.zeros(kwargs["low_rank_dimension"]))
    self.b_enc = nn.Parameter(torch.zeros(kwargs["low_rank_dimension"]))
    self.b_dec = nn.Parameter(torch.zeros(self.embed_dim))

  def encode(self, input_acts):
    # Project into the SAE latent space, then apply the JumpReLU:
    # keep a latent only if its pre-activation exceeds its learned threshold.
    pre_acts = input_acts @ self.W_enc + self.b_enc
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def forward(self, base, source=None, subspaces=None):
    # Return the latent activations so pyvene can collect them.
    acts = self.encode(base)
    return acts

Running the model with the SAE to collect activations via pyvene APIs#

sae = JumpReLUSAECollectIntervention(
    embed_dim=params['W_enc'].shape[0],
    low_rank_dimension=params['W_enc'].shape[1]
)
sae.load_state_dict(pt_params, strict=False)
sae.cuda()

# add the intervention to the model computation graph via the config
pv_model = pyvene.IntervenableModel({
   "component": f"model.layers[{LAYER}].output",
   "intervention": sae}, model=model)
sae_acts = pv_model.forward(
    {"input_ids": inputs}, return_dict=True).collected_activations[0]
"""
Results (from Gemma Scope) should be:
tensor([[7017,   47,   65,   70,   55,   72,   65,   75,   80,   72,   68,   93,
           86,   89]], device='cuda:0')
"""
(sae_acts > 1).sum(-1)
tensor([7017,   47,   65,   70,   55,   72,   65,   75,   80,   72,   68,   93,
          86,   89], device='cuda:0')
"""
Results (from Gemma Scope) should be:
tensor([[ 6631,  5482, 10376,  1670, 11023,  7562,  9407,  8399, 12935, 10004,
         10004, 10004, 12935,  3442]], device='cuda:0')
"""
values, inds = sae_acts.max(-1)

inds
tensor([ 6631,  5482, 10376,  1670, 11023,  7562,  9407,  8399, 12935, 10004,
        10004, 10004, 12935,  3442], device='cuda:0')
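
Latent 10004 shows up repeatedly in the argmax indices above, and it is the latent we steer with later. As a quick optional check (a minimal sketch, assuming sae_acts has shape (sequence_length, num_latents), which matches the .sum(-1) output above), we can print its activation on each prompt token:

# Show how strongly latent 10004 fires on each token of the prompt.
latent_idx = 10004
for tok_id, act in zip(inputs[0], sae_acts[:, latent_idx]):
    print(f"{tokenizer.decode(int(tok_id))!r}\t{act.item():.2f}")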

Gemma-2-2B-it steering with Gemma-2-2B SAEs#

We can also try to steer Gemma-2-2B-it by reusing the SAE trained on the Gemma-2-2B base model, and see whether it transfers.

torch.set_grad_enabled(False) # avoid blowing up mem

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", # google/gemma-2b-it
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

Implementing SAEs as pyvene-native Interventions for model steering#

The subspace notation built into pyvene lets us steer models by intervening on specific features.

class JumpReLUSAESteeringIntervention(
    SourcelessIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
  def __init__(self, **kwargs):
    # Note that we initialise these to zeros because we're loading in pre-trained weights.
    # If you want to train your own SAEs, you will need a proper (non-zero) initialisation.
    super().__init__(**kwargs, keep_last_dim=True)
    self.W_enc = nn.Parameter(torch.zeros(self.embed_dim, kwargs["low_rank_dimension"]))
    self.W_dec = nn.Parameter(torch.zeros(kwargs["low_rank_dimension"], self.embed_dim))
    self.threshold = nn.Parameter(torch.zeros(kwargs["low_rank_dimension"]))
    self.b_enc = nn.Parameter(torch.zeros(kwargs["low_rank_dimension"]))
    self.b_dec = nn.Parameter(torch.zeros(self.embed_dim))

  def encode(self, input_acts):
    pre_acts = input_acts @ self.W_enc + self.b_enc
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def decode(self, acts):
    return acts @ self.W_dec + self.b_dec

  def forward(self, base, source=None, subspaces=None):
    # Steering skips the encoder entirely: we add `mag` times the chosen decoder
    # row (that latent's direction in residual space) to the base activations.
    steering_vec = torch.tensor(subspaces["mag"]) * self.W_dec[subspaces["idx"]]
    return base + steering_vec

Loading the Gemma base model SAE weights.

sae = JumpReLUSAESteeringIntervention(
    embed_dim=params['W_enc'].shape[0],
    low_rank_dimension=params['W_enc'].shape[1]
)
sae.load_state_dict(pt_params, strict=False)
sae.cuda()

# add the intervention to the model computation graph via the config
pv_model = pyvene.IntervenableModel({
   "component": f"model.layers[{LAYER}].output",
   "intervention": sae}, model=model)
prompt = "Which dog breed do people think is cuter, poodle or doodle?"

prompt = tokenizer(prompt, return_tensors="pt").to("cuda")
_, reft_response = pv_model.generate(
    prompt, unit_locations=None, intervene_on_prompt=True, 
    subspaces=[{"idx": 10004, "mag": 100.0}],
    max_new_tokens=128, do_sample=True, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
Which dog breed do people think is cuter, poodle or doodle? 

It really depends on personal preference, but it's often a subjective matter. 

Here's a bit about each, to help you decide:

**Poodles:**

* Origin: France
* Types: Standard, Miniature, Toy
* Known for: Curly, hypoallergenic fur; intelligence and trainability.
* Appearance: Classic, distinguished look with a flowing coat and well-defined facial features.

**Doodles:**

* Origin (general) Space-travel, time-travel or a blend - depending on the specific dog's ancestry.  The term is used across a variety of breeds.

Here you go: a “Space-travel, time-travel” Doodle!
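
Steering strength matters: a small magnitude may have no visible effect, while a very large one can derail the generation entirely. To build intuition, you can sweep the magnitude (an optional sketch reusing the same generate call as above; the magnitudes are arbitrary choices and sampled outputs will vary between runs):

# Sweep the steering magnitude for the same latent and prompt.
for mag in [20.0, 50.0, 100.0, 200.0]:
    _, steered = pv_model.generate(
        prompt, unit_locations=None, intervene_on_prompt=True,
        subspaces=[{"idx": 10004, "mag": mag}],
        max_new_tokens=64, do_sample=True, early_stopping=True
    )
    print(f"--- mag = {mag} ---")
    print(tokenizer.decode(steered[0], skip_special_tokens=True))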

Interchange intervention with JumpReLU SAEs#

You can also swap the value of a specific latent dimension between examples. However, since an SAE usually maps a concept to a 1D subspace, swapping the value between examples and simply resetting the scalar to another value have similar effects.

sae = JumpReLUAutoencoderIntervention(
    embed_dim=params['W_enc'].shape[0],
    low_rank_dimension=params['W_enc'].shape[1]
)
sae.load_state_dict(pt_params, strict=False)
sae.cuda()

# add the intervention to the model computation graph via the config
pv_model = pyvene.IntervenableModel({
   "component": f"model.layers[{LAYER}].output",
   "intervention": sae}, model=model)
base = tokenizer(
    "Which dog breed do people think is cuter, poodle or doodle?", 
    return_tensors="pt").to("cuda")
source = tokenizer(
    "Origin (general) Space-travel, time-travel", 
    return_tensors="pt").to("cuda")

# run an interchange intervention 
original_outputs, intervened_outputs = pv_model(
  # the base input
  base=base, 
  # the source input
  sources=source, 
  # the location to intervene (swap last tokens)
  unit_locations={"sources->base": (11, 14)},
  # the SAE latent dimension mapping to the time travel concept ("10004")
  subspaces=[10004],
  output_original_output=True
)
logits_diff = intervened_outputs.logits[:,-1] - original_outputs.logits[:,-1]
values, indices = logits_diff.topk(k=10, sorted=True)
print("** topk logits diff **")
tokenizer.batch_decode(indices[0].unsqueeze(dim=-1))
** topk logits diff **
['PhysRevD',
 ' transporting',
 ' teleport',
 ' space',
 ' transit',
 ' transported',
 ' transporter',
 ' transpor',
 ' multiverse',
 ' universes']
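
To see how large these shifts are, you can also print each top token next to its logit increase (an optional addition reusing values and indices from the cell above):

# Pair each top token with the size of its logit increase.
for value, token in zip(values[0].tolist(), tokenizer.batch_decode(indices[0].unsqueeze(dim=-1))):
    print(f"{token!r}: {value:.2f}")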