Nested Hierarchical Structure with MQNLI#

In this notebook, we use pyvene to analyze language models on a complex hierarchical task: multiply-quantified natural language inference (MQNLI). We begin by describing the causal structure that generates our data through semantic composition. Then, we analyze whether models that learn this task employ the same compositional structure to solve it.

__author__ = "Amir Zur"

Set-Up#

%load_ext autoreload
%autoreload 2
import pyvene
import os
import json 
import random
import copy
import itertools
import numpy as np
from tqdm import tqdm, trange
from collections import Counter

import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer
from datasets import Dataset
from sklearn.metrics import accuracy_score

from pyvene import CausalModel
# from pyvene.models.configuration_intervenable_model import RepresentationConfig
from pyvene import (
    IntervenableModel,
    RotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig
)

from pyvene.models.gpt2.modelings_intervenable_gpt2 import create_gpt2_lm
# from pyvene.models.gpt_neox.modelings_intervenable_gpt_neox import create_gpt_neox
TRAIN_DIR = 'mqnli_factual'
DAS_DIR = 'mqnli_das'
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
<torch._C.Generator at 0x169a1e510b0>

The MQNLI Task#

Multiply-quantified natural language inference (MQNLI) is a variant of natural language inference (NLI) whose examples have a highly structured, compositional form. Each datapoint is a pair of sentences, and each label is the logical relation between them (e.g., entailment, reverse entailment, no relation). Crucially, the logical relation can be computed compositionally: first compute the relation between the two sentences’ nouns, then the relation induced by their determiners, and then combine these intermediate results to find the relation between the sentences’ noun phrases (NPs).

In this section, we walk through the high-level causal structure of MQNLI with examples from the MQNLI dataset. We encourage first-time readers to read the original paper, Geiger, Cases, Karttunen, and Potts (2018), before getting started.
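
Before loading the full projectivity tables, here is a minimal toy sketch of the compositional idea. The mini-signature below (toy_relate, toy_every_some) is hypothetical and far smaller than the real MQNLI signatures loaded next; it only illustrates how word-level relations are projected up to a sentence-level relation.

EQV, ENT, IND = "equivalence", "entails", "independence"

def toy_relate(word_premise, word_hypothesis):
    # identical words are equivalent; otherwise assume independence
    return EQV if word_premise == word_hypothesis else IND

# hypothetical projectivity table for the determiner pair ("every", "some"):
# maps (noun relation, verb-phrase relation) -> sentence relation
toy_every_some = {EQV: {EQV: ENT, IND: IND}, IND: {EQV: IND, IND: IND}}

np_rel = toy_relate("dog", "dog")          # equivalence
vp_rel = toy_relate("climbed", "climbed")  # equivalence
print(toy_every_some[np_rel][vp_rel])      # entails: "every dog climbed ..." entails "some dog climbed ..."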

# JSON files generated from adapting MQNLI codebase
# https://github.com/atticusg/MultiplyQuantifiedData

class Hashabledict(dict):
    def __hash__(self):
        return hash(frozenset(self))

with open('tutorial_data/mqnli_q_projectivity.json') as f:
    determiner_signatures = json.load(f)
    determiner_signatures = Hashabledict({
        q1: Hashabledict({
            q2: Hashabledict({
                r1: Hashabledict({
                    r2: determiner_signatures[q1][q2][r1][r2]  
                    for r2 in determiner_signatures[q1][q2][r1]
                })  
                for r1 in determiner_signatures[q1][q2]
            })
            for q2 in determiner_signatures[q1]
        })
        for q1 in determiner_signatures
    })

with open('tutorial_data/mqnli_neg_signature.json') as f:
    negation_signature = Hashabledict(json.load(f))

with open('tutorial_data/mqnli_empty_signature.json') as f:
    emptystring_signature = Hashabledict(json.load(f))

with open('tutorial_data/mqnli_cont_signature.json') as f:
    compose_contradiction_signature = Hashabledict(json.load(f))

with open('tutorial_data/mqnli_neg_cont_signature.json') as f:
    compose_neg_contradiction_signature = Hashabledict(json.load(f))
parents = {
    "N_P_O": [],
    "N_H_O": [],
    "N_O": ["N_P_O", "N_H_O"],
    "Adj_P_O": [],
    "Adj_H_O": [],
    "Adj_O": ["Adj_P_O", "Adj_H_O"],
    "NP_O": ["Adj_O", "N_O"],
    "Q_P_O": [],
    "Q_H_O": [],
    "Q_O": ["Q_P_O", "Q_H_O"],
    "V_P": [],
    "V_H": [],
    "V": ["V_P", "V_H"],
    "Adv_P": [],
    "Adv_H": [],
    "Adv": ["Adv_P", "Adv_H"],
    "VP": ["Adv", "V"],
    "QP_O": ["Q_O", "NP_O", "VP"],
    "Neg_P": [],
    "Neg_H": [],
    "Neg": ["Neg_P", "Neg_H"],
    "NegP": ["Neg", "QP_O"],
    "N_P_S": [],
    "N_H_S": [],
    "N_S": ["N_P_S", "N_H_S"],
    "Adj_P_S": [],
    "Adj_H_S": [],
    "Adj_S": ["Adj_P_S", "Adj_H_S"],
    "NP_S": ["Adj_S", "N_S"],
    "Q_P_S": [],
    "Q_H_S": [],
    "Q_S": ["Q_P_S", "Q_H_S"],
    "QP_S": ["Q_S", "NP_S", "NegP"]
}
EMPTY = ""
IND = "independence"
EQV = "equivalence"
ENT = "entails"
REV = "reverse entails"
CON = "contradiction"
ALT = "alternation"
COV = "cover"
# all possible relation values from the original paper 
# (https://arxiv.org/pdf/1810.13033.pdf)
RELATIONS = [IND, EQV, ENT, REV, CON, ALT, COV]

Q_VALUES = [
    determiner_signatures[q1][q2]
    for q1 in determiner_signatures for q2 in determiner_signatures[q1]
]

values = {
    "N_P_O": ["tree", "rock"],
    "N_H_O": ["tree", "rock"],
    "N_O": [EQV, IND],
    "Adj_P_O": ["happy", "sad", EMPTY],
    "Adj_H_O": ["happy", "sad", EMPTY],
    "Adj_O": [EQV, IND, ENT, REV],
    "NP_O": [EQV, IND, ENT, REV],
    "Q_P_O": ["some", "every"],
    "Q_H_O": ["some", "every"],
    "Q_O": Q_VALUES,
    "V_P": ["climbed", "threw"],
    "V_H": ["climbed", "threw"],
    "V": [EQV, IND],
    "Adv_P": ["energetically", "joyfully", EMPTY],
    "Adv_H": ["energetically", "joyfully", EMPTY],
    "Adv": [EQV, IND, ENT, REV],
    "VP": [EQV, IND, ENT, REV],
    "QP_O": [EQV, IND, ENT, REV],
    "Neg_P": ["not", EMPTY],
    "Neg_H": ["not", EMPTY],
    "Neg": [
        negation_signature, 
        emptystring_signature, 
        compose_contradiction_signature, 
        compose_neg_contradiction_signature
    ],
    "NegP": RELATIONS,
    "N_P_S": ["child", "dog"],
    "N_H_S": ["child", "dog"],
    "N_S": [EQV, IND],
    "Adj_P_S": ["little", "cute", EMPTY],
    "Adj_H_S": ["little", "cute", EMPTY],
    "Adj_S": [EQV, IND, ENT, REV],
    "NP_S": [EQV, IND, ENT, REV],
    "Q_P_S": ["some", "every"],
    "Q_H_S": ["some", "every"],
    "Q_S": Q_VALUES,
    "QP_S": RELATIONS
}
# adapted from original code for MQNLI:
# https://github.com/atticusg/MultiplyQuantifiedData/blob/master/natural_logic_model.py

def adj_merge(adj_p, adj_h):
    if adj_p == adj_h:
        return EQV
    if adj_p == EMPTY:
        return REV
    if adj_h == EMPTY:
        return ENT
    return IND

adv_merge = adj_merge

def noun_phrase(adj, noun):
    #merges a noun relation with an adjective relation
    #or a verb relation with an adverb relation
    # makes sense: if the objects are the same, then adjective's relation holds
    # otherwise, they're independent
    if noun == EQV:
        return adj
    return IND

verb_phrase = noun_phrase

def negation_merge(neg_p, neg_h):
    #merges negation
    if neg_p == neg_h and neg_p == EMPTY:
        return Hashabledict(emptystring_signature)
    if neg_p == neg_h and neg_p != EMPTY:
        return Hashabledict(negation_signature)
    if neg_p == EMPTY:
        return Hashabledict(compose_contradiction_signature)
    if neg_p != EMPTY:
        return Hashabledict(compose_neg_contradiction_signature)

negation_phrase = lambda neg, qp: neg[qp]

noun_merge = lambda n_p, n_h: EQV if n_p == n_h else IND
verb_merge = noun_merge

quantifier_merge = lambda q_p, q_h: determiner_signatures[q_p][q_h]

quantifier_phrase = lambda q, np, vp: q[np][vp]

functions = {
    "N_P_O": lambda: "tree",
    "N_H_O": lambda: "tree",
    "N_O": noun_merge,
    "Adj_P_O": lambda: "happy",
    "Adj_H_O": lambda: "happy",
    "Adj_O": adj_merge,
    "NP_O": noun_phrase,
    "Q_P_O": lambda: "some",
    "Q_H_O": lambda: "some",
    "Q_O": quantifier_merge,
    "V_P": lambda: "climbed",
    "V_H": lambda: "climbed",
    "V": verb_merge,
    "Adv_P": lambda: "energetically",
    "Adv_H": lambda: "energetically",
    "Adv": adv_merge,
    "VP": verb_phrase,
    "QP_O": quantifier_phrase,
    "Neg_P": lambda: "not",
    "Neg_H": lambda: "not",
    "Neg": negation_merge,
    "NegP": negation_phrase,
    "N_P_S": lambda: "dog",
    "N_H_S": lambda: "dog",
    "N_S": noun_merge,
    "Adj_P_S": lambda: "cute",
    "Adj_H_S": lambda: "cute",
    "Adj_S": adj_merge,
    "NP_S": noun_phrase,
    "Q_P_S": lambda: "some",
    "Q_H_S": lambda: "some",
    "Q_S": quantifier_merge,
    "QP_S": quantifier_phrase
}
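
As a quick sanity check (not part of the original pipeline), we can try a few of the merge functions above on concrete values:

assert noun_merge("tree", "tree") == EQV and noun_merge("tree", "rock") == IND
assert adj_merge("happy", EMPTY) == ENT   # premise "happy tree" entails hypothesis "tree"
assert adj_merge(EMPTY, "happy") == REV   # premise "tree" reverse entails hypothesis "happy tree"
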
# coordinates to display the MQNLI tree
pos = {
    "N_P_O": (32, 0.3),
    "N_H_O": (30, 0.7),
    "N_O": (31, 1.3),
    "Adj_P_O": (28, 0),
    "Adj_H_O": (26, 0.5),
    "Adj_O": (27, 1),
    "NP_O": (29, 2),
    "Q_P_O": (24, 1.3),
    "Q_H_O": (22, 1.7),
    "Q_O": (23, 2.5),
    "V_P": (21, 0),
    "V_H": (19, 0.5),
    "V": (20, 1),
    "Adv_P": (17, -0.3),
    "Adv_H": (15, 0.2),
    "Adv": (16, 0.7),
    "VP": (18, 2),
    "QP_O": (25, 3),
    "Neg_P": (13, 2.5),
    "Neg_H": (11, 3),
    "Neg": (12, 3.5),
    "NegP": (14, 4),
    "N_P_S": (9, 2.2),
    "N_H_S": (7, 2.8),
    "N_S": (8, 3.3),
    "Adj_P_S": (5, 1.5),
    "Adj_H_S": (3, 2),
    "Adj_S": (4, 2.5),
    "NP_S": (6, 4.3),
    "Q_P_S": (2, 3.2),
    "Q_H_S": (0, 3.5),
    "Q_S": (1, 4),
    "QP_S": (10, 5)
}
variables = list(parents.keys())  # dicts preserve insertion order (Python 3.7+), so this matches the declaration order above
mqlni_model = CausalModel(variables, values, parents, functions, pos=pos)

We begin by visualizing the high-level structure of the MQNLI task. All leaf nodes consist of text entries (e.g., Adj_P_O might be “fun”, N_H_O might be “child”), with one copy corresponding to the premise sentence (P) and the other corresponding to the hypothesis (H). All other nodes take on relation values (e.g., entailment, reverse entailment, no relation), or mappings between relations (e.g., (entailment, entailment) → no entailment).

mqlni_model.print_structure()
[Figure: the MQNLI causal graph printed by print_structure(), with leaf text nodes composing into relation-valued nodes up to the root QP_S.]
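
To make the distinction concrete, we can inspect the value sets of a few nodes in the values dictionary defined above (a quick check):

print(values["Adj_P_O"])                # leaf node: raw text entries (including the empty string)
print(values["NP_O"])                   # internal node: relation values
print(type(values["Q_O"][0]).__name__)  # quantifier nodes hold signature mappings (Hashabledict)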

Example MQNLI Input#

Here we walk through an example MQNLI input, presented below.

Premise: Every dog climbed some tree.
Hypothesis: Some dog did not climb some tree.

The output for this sentence pair is contradiction, because the premise entails that there cannot be a dog that climbed no tree, which is exactly what the hypothesis asserts.

inputs = {
    # premise
    "Q_P_S": "every",
    "Adj_P_S": EMPTY,
    "N_P_S": "dog",
    "Neg_P": EMPTY,
    "Adv_P": EMPTY,
    "V_P": "climbed",
    "Q_P_O": "some",
    "Adj_P_O": EMPTY,
    "N_P_O": "tree",
    # hypothesis
    "Q_H_S": "some",
    "Adj_H_S": EMPTY,
    "N_H_S": "dog",
    "Neg_H": "not",
    "Adv_H": EMPTY,
    "V_H": "climbed",
    "Q_H_O": "some",
    "Adj_H_O": EMPTY,
    "N_H_O": "tree"
}

setting = mqlni_model.run_forward(inputs)
def print_premise(setting):
    print(
        setting["Q_P_S"],
        setting["Adj_P_S"],
        setting["N_P_S"],
        setting["Neg_P"],
        setting["Adv_P"],
        setting["V_P"],
        setting["Q_P_O"],
        setting["Adj_P_O"],
        setting["N_P_O"]
    )

def print_hypothesis(setting):
    print(
        setting["Q_H_S"],
        setting["Adj_H_S"],
        setting["N_H_S"],
        setting["Neg_H"],
        setting["Adv_H"],
        setting["V_H"],
        setting["Q_H_O"],
        setting["Adj_H_O"],
        setting["N_H_O"]
    )

print_premise(setting)
print_hypothesis(setting)

print(setting["QP_S"]) # the output is in the root node, QP_S
every  dog   climbed some  tree
some  dog not  climbed some  tree
contradiction

We display the structure of the MQNLI example below. Note that the intermediate relation values compose to produce the final output. For instance, since N_O and Adj_O both take on the value equivalence, NP_O is also equivalence.

# hide the signature-valued (dict) nodes, which don't display cleanly
display = {v: not isinstance(values[v][0], dict) for v in variables}
mqlni_model.print_setting(setting, display=display)
[Figure: the MQNLI causal graph annotated with this example's values; the intermediate nodes show the composed relations, and the root QP_S is contradiction.]
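
As a quick sanity check on the composition, we can recompute two of the intermediate nodes from their parents with the merge functions defined above and confirm they match the setting:

assert noun_phrase(setting["Adj_O"], setting["N_O"]) == setting["NP_O"]
assert verb_phrase(setting["Adv"], setting["V"]) == setting["VP"]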

MQNLI with an Intervention#

Below is a run of the MQNLI logic model where we intervene on the object quantifier phrase (QP_O) and swap its relation value from equivalence to contradiction. Note that this changes the final output from contradiction to entailment.

intervention = {**inputs, 'QP_O': 'contradiction'}
setting = mqlni_model.run_forward(intervention)
print_premise(setting)
print_hypothesis(setting)
print(setting["QP_S"])

mqlni_model.print_setting(setting, display=display)
every  dog   climbed some  tree
some  dog not  climbed some  tree
entails
[Figure: the same example after intervening on QP_O (set to contradiction); the root QP_S now takes the value entails.]

Training a Model on MQNLI#

In this section, we train a language model (GPT-2) on the MQNLI task. Importantly, we do not need to access the original MQNLI dataset to do this: having set up our causal model earlier in this notebook, we can generate datapoints by sampling from it directly!

dataset = mqlni_model.generate_factual_dataset(100, sampler=mqlni_model.sample_input_tree_balanced, return_tensors=False)

X = [example['input_ids'] for example in dataset]
y = [example['labels'] for example in dataset]
i = 0

print_premise(X[i])
print_hypothesis(X[i])
print(y[i]['QP_S'])
some little child  energetically climbed some happy tree
every  child not energetically climbed some happy tree
alternation

We want to make sure that we indeed sampled evenly from the causal structure. For now, we can just verify that every possible output value (i.e., the values of the root node QP_S) has enough datapoints.

Counter([n['QP_S'] for n in y])
Counter({'alternation': 16,
         'contradiction': 18,
         'reverse entails': 15,
         'independence': 14,
         'entails': 16,
         'equivalence': 10,
         'cover': 11})

Time to train a language model!

config, tokenizer, model = create_gpt2_lm()
loaded model
def premise_to_string(setting):
    return \
        setting["Q_P_S"] + ' ' + \
        setting["Adj_P_S"] + ' ' + \
        setting["N_P_S"] + ' ' + \
        setting["Neg_P"] + ' ' + \
        setting["Adv_P"] + ' ' + \
        setting["V_P"] + ' ' + \
        setting["Q_P_O"] + ' ' + \
        setting["Adj_P_O"] + ' ' + \
        setting["N_P_O"]

def hypothesis_to_string(setting):
    return \
        setting["Q_H_S"] + ' ' + \
        setting["Adj_H_S"] + ' ' + \
        setting["N_H_S"] + ' ' + \
        setting["Neg_H"] + ' ' + \
        setting["Adv_H"] + ' ' + \
        setting["V_H"] + ' ' + \
        setting["Q_H_O"] + ' ' + \
        setting["Adj_H_O"] + ' ' + \
        setting["N_H_O"]

def preprocess_input(setting):
    return f'Premise: {premise_to_string(setting)}\nHypothesis: {hypothesis_to_string(setting)}\nRelation: '

def preprocess_output(setting):
    return setting['QP_S']
IGNORE_INDEX = -100
MAX_LENGTH = 64

tokenizer.pad_token = tokenizer.eos_token

def preprocess(X, y):
    examples = [preprocess_input(x) for x in X]
    labels = [preprocess_output(label) for label in y]

    examples = tokenizer(
        examples, 
        padding='max_length', 
        max_length=MAX_LENGTH, 
        truncation=True, 
        return_tensors='pt'
    )
    labels = tokenizer(
        labels, 
        padding='max_length', 
        max_length=MAX_LENGTH, 
        truncation=True, 
        return_tensors='pt'
    )['input_ids'][:, 0] # get first token of label
    
    # put label at the last index
    examples['labels'] = torch.full_like(examples['input_ids'], IGNORE_INDEX)
    examples['labels'][:, -1] = labels

    return examples

train_dataset = preprocess(X, y)
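
To see exactly what the model is trained on, we can print one formatted prompt together with its label (a quick inspection using the helpers defined above; not required for training):

print(preprocess_input(X[0]) + preprocess_output(y[0]))
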
# set the wandb project where this run will be logged
os.environ["WANDB_PROJECT"]=TRAIN_DIR

# set WANDB_LOG_MODEL to "true" to save your trained model checkpoint to wandb
os.environ["WANDB_LOG_MODEL"]="false"

def accuracy_metric(x):
    labels = x.label_ids[:, -1]  # the label token sits at the last position
    # predictions = x.predictions[0].argmax(axis=-1)[:, -2]  # uncomment for gpt-neox
    predictions = x.predictions.argmax(axis=-1)[:, -2]  # logits at position -2 predict the token at position -1
    return {
        'accuracy': accuracy_score(y_true=labels, y_pred=predictions),
    }

train_ds = Dataset.from_dict(train_dataset)

batch_size = 8

training_args = TrainingArguments(
    output_dir=TRAIN_DIR,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=1e-05,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    report_to="wandb", # optional, remove if you don't want to log to wandb
    use_cpu=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=train_ds,
    compute_metrics=accuracy_metric
)

Train the model and log it in a Weights & Biases run.

import wandb 

_ = trainer.train()
Tracking run with wandb version 0.16.3
Run data is saved locally in c:\Users\amirz\Source\NLP\clones\pyvene\tutorials\advanced_tutorials\wandb\run-20240314_140315-cwmxs9sx
{'eval_loss': 4.242105484008789, 'eval_accuracy': 0.17, 'eval_runtime': 12.9558, 'eval_samples_per_second': 7.719, 'eval_steps_per_second': 1.003, 'epoch': 1.0}
{'train_runtime': 52.4216, 'train_samples_per_second': 1.908, 'train_steps_per_second': 0.248, 'train_loss': 5.367249708909255, 'epoch': 1.0}
trainer.save_model(TRAIN_DIR)

We evaluate the model on the training dataset and on a newly sampled test dataset. If the model performs as well on the test set as on the training set, it generalizes, meaning it has learned the MQNLI task (the short demonstration run above needs more data and epochs to reach high accuracy).

tokenizer.pad_token = tokenizer.eos_token
test_examples = mqlni_model.generate_factual_dataset(100, sampler=mqlni_model.sample_input_tree_balanced, return_tensors=False)
X = [example['input_ids'] for example in test_examples]
y = [example['labels'] for example in test_examples]
test_dataset = preprocess(X, y)
test_ds = Dataset.from_dict(test_dataset)
results = trainer.evaluate(train_ds)
results
{'eval_loss': 4.242105484008789,
 'eval_accuracy': 0.17,
 'eval_runtime': 12.6987,
 'eval_samples_per_second': 7.875,
 'eval_steps_per_second': 1.024,
 'epoch': 1.0}
results = trainer.evaluate(test_ds)
results
{'eval_loss': 4.242105484008789,
 'eval_accuracy': 0.17,
 'eval_runtime': 12.532,
 'eval_samples_per_second': 7.98,
 'eval_steps_per_second': 1.037,
 'epoch': 1.0}
wandb.finish()

Run summary:


eval/accuracy0.17
eval/loss4.24211
eval/runtime12.532
eval/samples_per_second7.98
eval/steps_per_second1.037
train/epoch1.0
train/global_step13
train/total_flos3266150400000.0
train/train_loss5.36725
train/train_runtime52.4216
train/train_samples_per_second1.908
train/train_steps_per_second0.248

View run derby-crumble-2 at: https://wandb.ai/amirzur1212/mqnli_factual/runs/cwmxs9sx
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: .\wandb\run-20240314_140315-cwmxs9sx\logs

Interpreting Our Model#

In this final section, we use distributed alignment search (DAS) from Geiger*, Wu*, Potts, Icard, and Goodman (2024) to find an alignment between the high-level causal structure of MQNLI and our low-level language model trained on the MQNLI dataset.

Here, we won’t search for an alignment over the entire causal structure. Instead, we will search for an alignment between the low-level model representations and the single high-level variable NegP. Efficiently searching for a full alignment over a complex, nested causal structure is an open problem that is worth pursuing!

config, tokenizer, model = create_gpt2_lm(name=TRAIN_DIR)
loaded model

Here we set up our alignment. We will search for a 128-dimensional subspace of the residual stream at layer 10 of the model (which has 12 transformer layers overall).

config = IntervenableConfig(
    model_type=type(model),
    representations=[
        RepresentationConfig(
            10,  # layer
            "block_output",  # intervention type
            "pos",  # intervention unit is now aligned with tokens
            1,  # max number of unit
            subspace_partition=[[0, 128]], 
            # intervention_link_key=0,
        )
    ],
    intervention_types=RotatedSpaceIntervention,
)
intervenable = IntervenableModel(config, model, use_fast=True)
# intervenable.set_device('cuda')
intervenable.disable_model_gradients()
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.

In case multiple location tags are passed only the first one will be considered

Crucially, our interpretability experiments rely on a counterfactual dataset: for a given input, we want to know what would have happened had the value of NegP changed while all else remained the same. Fortunately, this is precisely what our causal model can provide us with! We can generate a counterfactual dataset by sampling inputs that only vary from each other on the NegP node.

def sample_intervention(*args, **kwargs):
    return {
        'NegP' : random.choice(mqlni_model.values['NegP'])
    }

def intervention_id(*args, **kwargs):
    return 0

batch_size = 2 # specifies how many inputs we want per intervention that is sampled

dataset = mqlni_model.generate_counterfactual_dataset(
    100, intervention_id, batch_size, 
    sampler=mqlni_model.sample_input_tree_balanced, intervention_sampler=sample_intervention, return_tensors=False
)
print(dataset[0]['base_labels']['QP_S'])
print(dataset[0]['labels']['QP_S'])
alternation
independence
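
We can also print the base and source sentences behind this counterfactual example, reusing print_premise and print_hypothesis from earlier. The 'input_ids' and 'source_input_ids' entries hold the base and source settings (the same keys used by preprocess_counterfactual below); the source is the input whose activations will be patched into the base.

print_premise(dataset[0]['input_ids'])
print_hypothesis(dataset[0]['input_ids'])
print_premise(dataset[0]['source_input_ids'][0])
print_hypothesis(dataset[0]['source_input_ids'][0])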
# check that base labels are diverse (should be guaranteed by sampling from balanced input tree)
Counter([d['base_labels']['QP_S'] for d in dataset])
Counter({'alternation': 23,
         'reverse entails': 10,
         'cover': 15,
         'equivalence': 15,
         'contradiction': 9,
         'independence': 14,
         'entails': 14})
# check that counterfactual labels are diverse (note that this could be skewed by the node's effect on the final output)
Counter([d['labels']['QP_S'] for d in dataset])
Counter({'independence': 47,
         'reverse entails': 10,
         'entails': 9,
         'alternation': 20,
         'cover': 10,
         'contradiction': 3,
         'equivalence': 1})
IGNORE_INDEX = -100
MAX_LENGTH = 64

tokenizer.pad_token = tokenizer.eos_token

def preprocess_input(setting):
    return f'Premise: {premise_to_string(setting)}\nHypothesis: {hypothesis_to_string(setting)}\nRelation: '

def preprocess_output(setting):
    return setting['QP_S']

def tokenize(x):
    return tokenizer(
        x, 
        padding='max_length', 
        max_length=MAX_LENGTH, 
        truncation=True, 
        return_tensors='pt'
    )

def preprocess_counterfactual(data):
    preprocessed_data = []
    for d in data:
        base = preprocess_input(d['input_ids'])
        sources = [preprocess_input(d['source_input_ids'][0])]
        label = preprocess_output(d['labels'])
        base_label = preprocess_output(d['base_labels'])

        preprocessed = {}
        preprocessed['input'] = tokenize(base)
        preprocessed['source'] = [tokenize(sources[0])]
        # place label at last index
        label = tokenize(label)['input_ids'][:, 0]
        preprocessed['label'] = torch.full_like(preprocessed['input']['input_ids'], IGNORE_INDEX)
        preprocessed['label'][:, -1] = label
        # repeat for base label
        base_label = tokenize(base_label)['input_ids'][:, 0]
        preprocessed['base_label'] = torch.full_like(preprocessed['input']['input_ids'], IGNORE_INDEX)
        preprocessed['base_label'][:, -1] = base_label
        preprocessed['intervention_id'] = torch.tensor(d['intervention_id'])
        preprocessed_data.append(preprocessed)
    return preprocessed_data
train_dataset = preprocess_counterfactual(dataset)
dataloader = DataLoader(train_dataset, batch_size=2)

Set up our optimizer, loss function, and accuracy metric.

epochs = 3
batch_size = 8
gradient_accumulation_steps = 1

optimizer_params = []
# only the DAS rotation parameters are optimized; the language model weights stay frozen
for k, v in intervenable.interventions.items():
    optimizer_params += [{"params": v[0].rotate_layer.parameters()}]
    break
optimizer = torch.optim.Adam(optimizer_params, lr=0.001)


def compute_metrics(eval_preds, eval_labels):
    # the label token sits at the last position; the logits at position -2 predict it
    accuracy = accuracy_score(
        y_pred=eval_preds[..., -2].squeeze().clone().detach().cpu().numpy(), 
        y_true=eval_labels[..., -1].squeeze().clone().detach().cpu().numpy()
    )
    return {
        "accuracy": accuracy
    }


def compute_loss(outputs, labels):
    # Shift so that tokens < n predict n
    shift_logits = outputs[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return loss

Run DAS to find an alignment between our high-level causal model and the language model trained on MQNLI.

intervenable.model.train()  # .train() enables dropout; model gradients remain disabled
print("intervention trainable parameters: ", intervenable.count_parameters())
train_iterator = trange(0, int(epochs), desc="Epoch")

total_step = 0
for epoch in train_iterator:
    epoch_iterator = tqdm(
        DataLoader(
            train_dataset,
            batch_size=batch_size,
            drop_last=True
        ),
        desc=f"Epoch: {epoch}",
        position=0,
        leave=True
    )
    for batch in epoch_iterator:
        inputs = {k: v.to(intervenable.get_device()) for k, v in batch['input'].items()}
        sources = [{k: v.to(intervenable.get_device()) for k, v in s.items()} for s in batch['source']]
        _, counterfactual_outputs = intervenable(
            inputs,
            sources,
            # intervene at the second-to-last token position: the same position whose
            # logits are later read out as the prediction for the label token
            {"sources->base": ([[[MAX_LENGTH - 2]] * batch_size], [[[MAX_LENGTH - 2]] * batch_size])},
            subspaces=[[[0]] * batch_size],
        )

        eval_metrics = compute_metrics(
            counterfactual_outputs.logits.argmax(-1), batch["label"].to(intervenable.get_device())
        )

        # loss and backprop
        loss = compute_loss(
            counterfactual_outputs.logits, batch["label"].to(intervenable.get_device())
        )

        epoch_iterator.set_postfix({"loss": loss.item(), "acc": eval_metrics["accuracy"]})

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        if total_step % gradient_accumulation_steps == 0:
            optimizer.step()
            intervenable.set_zero_grad()
        total_step += 1
intervention trainable parameters:  589824
Epoch: 0: 100%|██████████| 12/12 [00:59<00:00,  4.93s/it, loss=6.88, acc=0]    

Epoch: 1: 100%|██████████| 12/12 [00:58<00:00,  4.91s/it, loss=5.86, acc=0]   

Epoch: 2: 100%|██████████| 12/12 [00:59<00:00,  4.98s/it, loss=6.37, acc=0]    

Epoch: 100%|██████████| 3/3 [02:57<00:00, 59.28s/it]
intervenable.save(DAS_DIR)
Directory 'mqnli_das' already exists.

Evaluate DAS Alignment#

Lastly, we evaluate the accuracy of the alignment found by DAS.

config, tokenizer, model = create_gpt2_lm(name=TRAIN_DIR)
intervenable = IntervenableModel.load(DAS_DIR, model)
WARNING:root:The key is provided in the config. Assuming this is loaded from a pretrained module.
loaded model

First, we evaluate the model’s performance on the factual MQNLI task. As before, we expect a well-trained model to complete this task with high accuracy.

tokenizer.pad_token = tokenizer.eos_token
examples = mqlni_model.generate_factual_dataset(100, sampler=mqlni_model.sample_input_tree_balanced, return_tensors=False)
X = [example['input_ids'] for example in examples]
y = [example['labels'] for example in examples]
test_factual_dataset = preprocess(X, y)
def accuracy_metric(x):
    labels = x.label_ids[:, -1]
    # predictions = x.predictions[0].argmax(axis=-1)[:, -2]  # take one index before label
    predictions = x.predictions.argmax(axis=-1)[:, -2]
    return {
        'accuracy': accuracy_score(y_true=labels, y_pred=predictions),
    }

test_factual_ds = Dataset.from_dict(test_factual_dataset)

batch_size = 8

training_args = TrainingArguments(
    output_dir="test_mqnli_trainer",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    learning_rate=1e-05,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    report_to="none",
    use_cpu=True,
)

trainer = Trainer(
    model=intervenable.model,
    args=training_args,
    compute_metrics=accuracy_metric
)
trainer.evaluate(test_factual_ds)
{'eval_loss': 4.302060127258301,
 'eval_accuracy': 0.17,
 'eval_runtime': 18.8671,
 'eval_samples_per_second': 5.3,
 'eval_steps_per_second': 0.689}

Next, we evaluate the DAS alignment on a newly sampled counterfactual dataset. If our alignment generalizes, we should see a high interchange intervention accuracy (IIA) on the counterfactual dataset. We compute IIA by comparing the model’s output under an interchange intervention to the causal model’s output under the same intervention applied to its high-level variables (i.e., we verify that the intervention changes the model’s output in the direction our causal model predicts).

batch_size = 8

dataset = mqlni_model.generate_counterfactual_dataset(
    1000, intervention_id, 1, 
    sampler=mqlni_model.sample_input_tree_balanced, intervention_sampler=sample_intervention, return_tensors=False
)
test_cf_dataset = preprocess_counterfactual(dataset)
def compute_metrics(eval_preds, eval_labels):
    accuracy = accuracy_score(
        y_pred=eval_preds[..., -2].squeeze().clone().detach().cpu().numpy(), 
        y_true=eval_labels[..., -1].squeeze().clone().detach().cpu().numpy()
    )
    return {
        "accuracy": accuracy
    }
intervenable.set_device('cuda')
# intervenable.disable_intervention_gradients()
# intervenable.disable_model_gradients()
# intervenable.model.eval()  # .eval() would disable dropout for deterministic evaluation
print("intervention trainable parameters: ", intervenable.count_parameters())

base_labels = None
counterfactual_labels = None
base_preds = None
counterfactual_preds = None

batch_size = 8

with torch.no_grad():
    for batch in tqdm(DataLoader(
        test_cf_dataset,
        batch_size=batch_size,
        drop_last=True
    )):
        inputs = {k: v.to(intervenable.get_device()) for k, v in batch['input'].items()}
        sources = [{k: v.to(intervenable.get_device()) for k, v in s.items()} for s in batch['source']]
        _, counterfactual_outputs = intervenable(
            inputs,
            sources,
            {"sources->base": ([[[MAX_LENGTH - 2]] * batch_size], [[[MAX_LENGTH - 2]] * batch_size])},
            subspaces=[[[0]] * batch_size],
        )

        if base_labels is None:
            # base_preds = base_outputs.logits.argmax(-1).clone().detach()
            # base_labels = batch['base_label']
            counterfactual_preds = counterfactual_outputs.logits.argmax(-1).clone().detach()
            counterfactual_labels = batch['label']
        else:
            # base_preds = torch.cat((base_preds, base_outputs.logits.argmax(-1).clone().detach()))
            # base_labels = torch.cat((base_labels, batch['base_label']))
            counterfactual_preds = torch.cat((counterfactual_preds, counterfactual_outputs.logits.argmax(-1).clone().detach()))
            counterfactual_labels = torch.cat((counterfactual_labels, batch['label']))
{
    # 'base_accuracy': compute_metrics(base_preds, base_labels)['accuracy'],
    'counterfactual_accuracy': compute_metrics(counterfactual_preds, counterfactual_labels)['accuracy']
}
{'counterfactual_accuracy': 0.125}