Trainable Interventions#

__author__ = "Zhengxuan Wu"
__version__ = "11/28/2023"

Overview#

When interventions are static, you are mannually look for interesting counterfactual behaviors. What if interventions are trainable? and what if you train interventions with certain counterfactual behaviors?

We think, if you can train such interventions, you find a systematic way of affecting the causal circuits realized in the NNs. With certain types of interventions, e.g., basis respect ones (vanilla causal abstraction or DAS), you are doing causal abstraction.

In this tutorial, we show how you can train interventions with customized dataset.

Set-up#

try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyvene

except ModuleNotFoundError:
    !pip install git+https://github.com/frankaging/pyvene.git
[2024-01-11 01:23:38,042] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
import os, json, torch
import numpy as np
np.object = object
import pandas as pd
from torch.utils.data import DataLoader
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

from pyvene import embed_to_distrib, top_vals, format_token
from pyvene import (
    IntervenableModel,
    LowRankRotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig,
)
from pyvene import create_gpt2_lm

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    facet_wrap,
    theme,
    element_text,
    geom_bar,
    geom_hline,
    scale_y_log10,
)

from tutorial_intervention_training_utils import (
    visualize_program,
    make_supervised_counterfactual_data_module,
)
from torch.nn import CrossEntropyLoss

programs = json.load(open(os.path.join("./tutorial_data/", "seed_programs.json")))
programs_10 = json.load(
    open(os.path.join("./tutorial_data/", "selected_programs_10.json"))
)

# load fine-tuned model for this tutorial from HF
config, tokenizer, gpt2 = create_gpt2_lm("zhengxuanzenwu/gpt2-5token-solver")
tokenizer.pad_token = tokenizer.eos_token
_ = gpt2.to("cuda")
_ = gpt2.eval()

# what is this model finetuned for? this visualizes a program that this model can solve.
select_program = "07065a"
print("\nA program that the model can solve (Cs are input tokens):")
visualize_program(programs[select_program])
loaded model

A program that the model can solve (Cs are input tokens):
../../_images/c0fb1f5e9c3706572fead3ff8ce22d1f25cb88e40ccb50dee6d6222d91f4357b.svg

Aligning the output variable with a single dimension in the rotated basis#

We know the output is a boolean value, which can be represented using a single dimension (i.e., pos/neg lies on a line with a linearly separable boundary). Can we learn a 1-d DAS direction to align activations with the output variable?

We first need to sample (base, source(s)) pairs. In this case, it is easy, since the counterfactual behavior is that source output label overwrites the base one.

# load in some datasets we need
num_of_shots = 10
mode = "E"
counterfactual_data_module, _ = make_supervised_counterfactual_data_module(
    programs[select_program],
    800,
    num_of_shots,
    "op5",  # this is the last variable = output
    tokenizer,
    data_path="./tutorial_data/",
    program_uuid=select_program,
    mode=mode,
    n_test_sample=1000,
)

print("Model input looks like this (10-shot ICL):")
print(counterfactual_data_module["train_dataset"][0]["question"])

# dataset cleaning
columns_to_remove = [
    "question",
    "source_question",
    "answers",
    "base_answers",
    "source_answers",
]
train_dataset = (
    counterfactual_data_module["train_dataset"]
    .remove_columns(columns_to_remove)
    .with_format("torch")
)
eval_dataset = (
    counterfactual_data_module["eval_dataset"]
    .remove_columns(columns_to_remove)
    .with_format("torch")
)
test_dataset = (
    counterfactual_data_module["test_dataset"]
    .remove_columns(columns_to_remove)
    .with_format("torch")
)


def add_locations(example):
    example["source_0->base.0.pos"] = [129]  # the fixed last token location
    example["source_0->base.1.pos"] = [129]  # the fixed last token location
    example["subspaces"] = [0]  # the only subspace is a single column
    return example


train_dataset = train_dataset.map(add_locations).shuffle(seed=42)
eval_dataset = eval_dataset.map(add_locations)
test_dataset = test_dataset.map(add_locations)
train_dataloader = DataLoader(train_dataset, batch_size=64)
eval_dataloader = DataLoader(eval_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

print("\nTraining data for the intervention should contain these fields:")
print(train_dataset[0].keys())
Model input looks like this (10-shot ICL):
reportprint,itches,reenshots,Jean,reenshots=True
dated,net,dated,Kate,dated=False
Inst,Amid,Billy,While,Billy=True
mas,quickShip,mas,imei,mas=True
Excellent,embedreportprint,Ret,Theme,Excellent=True
engers,Rick,engers,Jerry,engers=True
debian,OSED,international,quickShip,international=True
Led,orous,Lead,Cash,Lead=True
mission,embedreportprint,mission,mbudsman,mission=True
ABC,store,BBC,ixty,CBC=False
ilitation,Brow,ilitation,ICLE,ilitation

Training data for the intervention should contain these fields:
dict_keys(['input_ids', 'attention_mask', 'source_input_ids', 'source_attention_mask', 'labels', 'source_0->base.0.pos', 'source_0->base.1.pos', 'subspaces'])

A 1-d DAS direction learning config#

No full rotation learning is needed here.

def single_d_low_rank_das_position_config(
    model_type, intervention_type, layer, intervention_types
):
    config = IntervenableConfig(
        model_type=model_type,
        representations=[
            RepresentationConfig(
                layer,              # layer
                intervention_type,  # intervention type
                "pos",              # intervention unit
                1,                  # max number of unit
                low_rank_dimension=1,         # a single das direction
                subspace_partition=[[0, 1]],  # dummy partition
            ),
        ],
        intervention_types=intervention_types,
    )
    return config


config = single_d_low_rank_das_position_config(
    type(gpt2), "block_output", 11, LowRankRotatedSpaceIntervention
)
intervenable = IntervenableModel(config, gpt2)
intervenable.set_device("cuda")
intervenable.disable_model_gradients()

Your own loss and metrics function#

def inputs_collator(inputs):
    for k, v in inputs.items():
        if "->" in k:
            inputs[k] = v.tolist()
        elif "subspace" in k:
            inputs[k] = [v.tolist()]
        elif v is not None and isinstance(v, torch.Tensor):
            inputs[k] = v.to("cuda")
    return inputs


def compute_loss(logits, labels):
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, 50257)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
    return loss


def compute_metrics(eval_preds, eval_labels):
    total_count = 0
    correct_count = 0
    for eval_pred, eval_label in zip(eval_preds, eval_labels):
        actual_test_labels = eval_label[:, -1]
        pred_test_labels = torch.argmax(eval_pred[:, -2], dim=-1)
        correct_labels = actual_test_labels == pred_test_labels
        total_count += len(correct_labels)
        correct_count += correct_labels.sum().tolist()
    accuracy = round(correct_count / total_count, 2)
    return {"accuracy": accuracy}["accuracy"]

Learning#

intervenable.train_alignment(
    train_dataloader=train_dataloader,
    compute_loss=compute_loss,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator,
)
Epoch: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [01:34<00:00,  9.43s/it, loss=0.24, acc=0.91]

Evaluation#

intervenable.eval_alignment(
    eval_dataloader=test_dataloader,
    compute_metrics=compute_metrics,
    inputs_collator=inputs_collator,
)
                                                                                                                                                                                                                                          
0.9763200000000002

The above means >97% of time, you can intervene on a single activation in the rotated basis to flip the model prediction