Causal Evaluation of Probes

This is a simple tutorial showing you how to collect activations from intervention points in a model. For the block_output component of pythia-70M, we'll compare, at each layer and position, the interchange intervention accuracy (IIA) of 1D distributed alignment search (DAS) with the accuracy of a logistic regression probe. The task we'll look at is gender prediction: gendered names are used in templates like “[name] walked because”, which elicit the associated gendered pronoun “he” or “she” as the model's next-token prediction.

__author__ = "Aryaman Arora"
__version__ = "01/10/2024"

Setup

try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyvene as pv

except ModuleNotFoundError:
    !pip install git+https://github.com/stanfordnlp/pyvene.git
    import pyvene as pv  # import again now that the package is installed
import pandas as pd
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)
import torch
import random
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    facet_wrap,
    theme,
    element_text,
    geom_bar,
    geom_hline,
    scale_y_log10,
    geom_line,
    geom_point,
    geom_text,
    ggtitle, xlab, ylab, 
    ggsave
)
from plotnine.scales import scale_y_reverse, scale_fill_cmap
from tqdm import tqdm
from collections import namedtuple

Load model and data

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = "EleutherAI/pythia-70m"  # or "EleutherAI/pythia-6.9b" (loaded in bfloat16 below)
tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer.pad_token = tokenizer.eos_token
gpt = AutoModelForCausalLM.from_pretrained(
    model,
    revision="main",
    torch_dtype=torch.bfloat16 if model == "EleutherAI/pythia-6.9b" else torch.float32,
).to(device)

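We have a list of 100 names for each gender, and we'll keep only the names that are a single token. We'll further keep only the examples on which the model's predictions agree with our labels, since some of these names might be ambiguous or the model might not have the expected behaviour. Concretely, the model must prefer the base pronoun on the base prompt and the source pronoun on the source prompt. This ensures that baseline IIA (with no intervention) is 0.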
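The filtering rule and the IIA metric can be sketched without the model. This is a minimal sketch, assuming we already have next-token logits for the base and source prompts (toy plain-Python lists here); `is_doable` and `iia` are hypothetical helper names that mirror the comparisons used in the dataset and evaluation code, not functions from pyvene. Because every kept pair has the base prompt preferring the base pronoun, an "intervention" that changes nothing scores IIA 0.

```python
def is_doable(logits_base, logits_src, base_label, src_label):
    """Keep a pair only if each prompt already prefers its own pronoun."""
    return (
        logits_base[base_label] > logits_base[src_label]
        and logits_src[src_label] > logits_src[base_label]
    )


def iia(outputs, pairs):
    """Fraction of examples whose output logits prefer the source label."""
    hits = sum(
        1
        for logits, (base_label, src_label) in zip(outputs, pairs)
        if logits[src_label] > logits[base_label]
    )
    return hits / len(pairs)


# Toy 2-token vocabulary: index 0 stands in for " he", index 1 for " she".
base_logits = [[2.0, 0.5], [0.1, 1.5]]
src_logits = [[0.3, 1.9], [2.2, 0.4]]
pairs = [(0, 1), (1, 0)]  # (base_label, src_label) per example

kept = [
    i
    for i, (b, s) in enumerate(pairs)
    if is_doable(base_logits[i], src_logits[i], b, s)
]
print(kept)  # → [0, 1]

# No intervention: the output is just the base output, so IIA is 0.
print(iia(base_logits, pairs))  # → 0.0
# An intervention that fully reproduced the source behaviour would give IIA 1.
print(iia(src_logits, pairs))  # → 1.0
```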

Example = namedtuple("Example", ["base", "src", "base_label", "src_label"])

names = {
    "he": [
        "James",
        "Robert",
        "John",
        "Michael",
        "David",
        "William",
        "Richard",
        "Joseph",
        "Thomas",
        "Christopher",
        "Charles",
        "Daniel",
        "Matthew",
        "Anthony",
        "Mark",
        "Donald",
        "Steven",
        "Andrew",
        "Paul",
        "Joshua",
        "Kenneth",
        "Kevin",
        "Brian",
        "George",
        "Timothy",
        "Ronald",
        "Jason",
        "Edward",
        "Jeffrey",
        "Ryan",
        "Jacob",
        "Gary",
        "Nicholas",
        "Eric",
        "Jonathan",
        "Stephen",
        "Larry",
        "Justin",
        "Scott",
        "Brandon",
        "Benjamin",
        "Samuel",
        "Gregory",
        "Alexander",
        "Patrick",
        "Frank",
        "Raymond",
        "Jack",
        "Dennis",
        "Jerry",
        "Tyler",
        "Aaron",
        "Jose",
        "Adam",
        "Nathan",
        "Henry",
        "Zachary",
        "Douglas",
        "Peter",
        "Kyle",
        "Noah",
        "Ethan",
        "Jeremy",
        "Walter",
        "Christian",
        "Keith",
        "Roger",
        "Terry",
        "Austin",
        "Sean",
        "Gerald",
        "Carl",
        "Harold",
        "Dylan",
        "Arthur",
        "Lawrence",
        "Jordan",
        "Jesse",
        "Bryan",
        "Billy",
        "Bruce",
        "Gabriel",
        "Joe",
        "Logan",
        "Alan",
        "Juan",
        "Albert",
        "Willie",
        "Elijah",
        "Wayne",
        "Randy",
        "Vincent",
        "Mason",
        "Roy",
        "Ralph",
        "Bobby",
        "Russell",
        "Bradley",
        "Philip",
        "Eugene",
    ],
    "she": [
        "Mary",
        "Patricia",
        "Jennifer",
        "Linda",
        "Elizabeth",
        "Barbara",
        "Susan",
        "Jessica",
        "Sarah",
        "Karen",
        "Lisa",
        "Nancy",
        "Betty",
        "Sandra",
        "Margaret",
        "Ashley",
        "Kimberly",
        "Emily",
        "Donna",
        "Michelle",
        "Carol",
        "Amanda",
        "Melissa",
        "Deborah",
        "Stephanie",
        "Dorothy",
        "Rebecca",
        "Sharon",
        "Laura",
        "Cynthia",
        "Amy",
        "Kathleen",
        "Angela",
        "Shirley",
        "Brenda",
        "Emma",
        "Anna",
        "Pamela",
        "Nicole",
        "Samantha",
        "Katherine",
        "Christine",
        "Helen",
        "Debra",
        "Rachel",
        "Carolyn",
        "Janet",
        "Maria",
        "Catherine",
        "Heather",
        "Diane",
        "Olivia",
        "Julie",
        "Joyce",
        "Victoria",
        "Ruth",
        "Virginia",
        "Lauren",
        "Kelly",
        "Christina",
        "Joan",
        "Evelyn",
        "Judith",
        "Andrea",
        "Hannah",
        "Megan",
        "Cheryl",
        "Jacqueline",
        "Martha",
        "Madison",
        "Teresa",
        "Gloria",
        "Sara",
        "Janice",
        "Ann",
        "Kathryn",
        "Abigail",
        "Sophia",
        "Frances",
        "Jean",
        "Alice",
        "Judy",
        "Isabella",
        "Julia",
        "Grace",
        "Amber",
        "Denise",
        "Danielle",
        "Marilyn",
        "Beverly",
        "Charlotte",
        "Natalie",
        "Theresa",
        "Diana",
        "Brittany",
        "Doris",
        "Kayla",
        "Alexis",
        "Lori",
        "Marie",
    ],
}

# keep only names that are a single token
names = {
    key: [name for name in names[key] if len(tokenizer.tokenize(name)) == 1]
    for key in names
}
print(len(names["he"]), len(names["she"]))


def sample_example(tokenizer):
    """Sample a (base, source) prompt pair whose names have different genders."""
    # sample the base label; the source label is the other gender
    base_label = random.choice(list(names.keys()))
    src_label = [key for key in names if key != base_label][0]

    # sample names
    base_name = random.choice(names[base_label])
    src_name = random.choice(names[src_label])

    # make pair
    base = tokenizer(f"<|endoftext|>{base_name} walked because", return_tensors="pt")
    src = tokenizer(f"<|endoftext|>{src_name} walked because", return_tensors="pt")

    # convert labels to the token ids of " he" / " she" (note the leading space)
    base_label = tokenizer.encode(" " + base_label)[0]
    src_label = tokenizer.encode(" " + src_label)[0]
    return Example(base, src, base_label, src_label)
47 10
sample_example(tokenizer)
Example(base={'input_ids': tensor([[    0, 37376,  7428,   984]]), 'attention_mask': tensor([[1, 1, 1, 1]])}, src={'input_ids': tensor([[    0, 44305,  7428,   984]]), 'attention_mask': tensor([[1, 1, 1, 1]])}, base_label=344, src_label=703)
def generate_n_doable_examples(n, model, tokenizer):
    """Sample pairs until we have n on which the model already predicts the expected pronouns."""
    examples = []
    pbar = tqdm(total=n)
    with torch.no_grad():  # inference only, so don't build a graph
        while len(examples) < n:
            ex = sample_example(tokenizer)
            # move inputs to the model's device
            for k, v in ex.base.items():
                if isinstance(v, torch.Tensor):
                    ex.base[k] = v.to(model.device)
            for k, v in ex.src.items():
                if isinstance(v, torch.Tensor):
                    ex.src[k] = v.to(model.device)
            # next-token logits at the last position
            logits_base = model(**ex.base).logits[0, -1]
            logits_src = model(**ex.src).logits[0, -1]
            # keep the pair only if each prompt prefers its own pronoun
            if (
                logits_base[ex.base_label] > logits_base[ex.src_label]
                and logits_src[ex.src_label] > logits_src[ex.base_label]
            ):
                examples.append(ex)
                pbar.update(1)
    pbar.close()
    return examples
# make dataset
total_steps = 100
trainset = generate_n_doable_examples(total_steps, gpt, tokenizer)
evalset = generate_n_doable_examples(50, gpt, tokenizer)
100%|██████████| 100/100 [00:05<00:00, 17.51it/s]
100%|██████████| 50/50 [00:02<00:00, 19.85it/s]

DAS

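This is the usual 1D DAS setup: at a given layer and position, we learn a single direction in the block output (pyvene's LowRankRotatedSpaceIntervention with low_rank_dimension=1) and swap the source run's value along that direction into the base run. We train with a batch size of 1.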
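To make the 1D case concrete: the intervention replaces the base run's coordinate along a direction u with the source run's coordinate, and leaves everything orthogonal to u untouched. This is a minimal NumPy sketch, assuming a fixed random unit vector rather than a trained one; `interchange_1d` is a hypothetical helper that illustrates the idea and is not pyvene's implementation. In the tutorial, the direction is what the optimizer learns through the intervention's rotate_layer parameters.

```python
import numpy as np


def interchange_1d(h_base, h_src, u):
    """Swap the component of h_base along u with the component of h_src.

    Everything orthogonal to u is kept from the base run.
    """
    u = u / np.linalg.norm(u)  # normalize so u is a unit vector
    return h_base + (h_src @ u - h_base @ u) * u


rng = np.random.default_rng(0)
d = 8  # toy hidden size
h_base, h_src = rng.normal(size=d), rng.normal(size=d)
u = rng.normal(size=d)

h_new = interchange_1d(h_base, h_src, u)
u_hat = u / np.linalg.norm(u)

# the coordinate along u now comes from the source run
print(np.isclose(h_new @ u_hat, h_src @ u_hat))
# the component orthogonal to u still comes from the base run
print(np.allclose(h_new - (h_new @ u_hat) * u_hat,
                  h_base - (h_base @ u_hat) * u_hat))
```

Training then adjusts u so that this swap makes the base run output the source run's pronoun.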

def intervention_config(intervention_site, layer, num_dims=1):
    config = pv.IntervenableConfig([
        {
            "layer": layer,
            "component": intervention_site,
            "intervention_type": pv.LowRankRotatedSpaceIntervention,
            "low_rank_dimension": num_dims,
        }
    ])
    return config
# loss function
loss_fct = torch.nn.CrossEntropyLoss()

def calculate_loss(logits, label):
    """Calculate cross entropy between logits and a single target label (can be batched)"""
    label = label.to(logits.device)
    return loss_fct(logits, label)
# intervention settings
stats = []
num_layers = gpt.config.num_hidden_layers

# loop over layers and positions
for layer in range(num_layers):
    for position in range(4):  # every prompt is 4 tokens: EOS, <name>, walked, because
        print(f"layer: {layer}, position: {position}")

        # set up intervenable model
        config = intervention_config("block_output", layer, 1)
        intervenable = pv.IntervenableModel(config, gpt)
        intervenable.set_device(device)
        intervenable.disable_model_gradients()

        # set up optimizer
        optimizer_params = []
        for k, v in intervenable.interventions.items():
            try:
                optimizer_params.append({"params": v[0].rotate_layer.parameters()})
            except:
                pass
        optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps,
        )

        # training loop
        iterator = tqdm(trainset)
        for example in iterator:
            # forward pass
            _, counterfactual_outputs = intervenable(
                example.base,
                [example.src],
                {"sources->base": position},
            )

            # loss
            logits = counterfactual_outputs.logits[:, -1]
            loss = calculate_loss(logits, torch.tensor([example.src_label]).to(device))
            iterator.set_postfix({"loss": f"{loss.item():.3f}"})

            # backward
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()  # reset gradients so they don't accumulate across steps

        # eval
        with torch.no_grad():
            iia = 0
            iterator = tqdm(evalset)
            for example in iterator:
                # forward
                _, counterfactual_outputs = intervenable(
                    example.base,
                    [example.src],
                    {"sources->base": position},
                )

                # calculate iia
                logits = counterfactual_outputs.logits[0, -1]
                if logits[example.src_label] > logits[example.base_label]:
                    iia += 1

            # stats
            iia = iia / len(evalset)
            stats.append({"layer": layer, "position": position, "iia": iia})
            print(f"iia: {iia:.3%}")
import os

os.makedirs("./tutorial_data", exist_ok=True)  # make sure the output folder exists
df = pd.DataFrame(stats)
df.to_csv("./tutorial_data/pyvene_gender_das.csv")
layer: 0, position: 0
100%|██████████| 100/100 [00:09<00:00, 10.66it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
iia: 0.000%
layer: 0, position: 1
100%|██████████| 100/100 [00:09<00:00, 10.97it/s, loss=1.268]
100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
iia: 98.000%
layer: 0, position: 2
100%|██████████| 100/100 [00:09<00:00, 10.93it/s, loss=4.130]
100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
iia: 0.000%
layer: 0, position: 3
100%|██████████| 100/100 [00:09<00:00, 10.83it/s, loss=4.276]
100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
iia: 0.000%
layer: 1, position: 0
100%|██████████| 100/100 [00:09<00:00, 10.94it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
iia: 0.000%
layer: 1, position: 1
100%|██████████| 100/100 [00:09<00:00, 10.93it/s, loss=1.231]
100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
iia: 100.000%
layer: 1, position: 2
100%|██████████| 100/100 [00:09<00:00, 10.94it/s, loss=4.422]
100%|██████████| 50/50 [00:03<00:00, 13.19it/s]
iia: 0.000%
layer: 1, position: 3
100%|██████████| 100/100 [00:09<00:00, 10.92it/s, loss=4.308]
100%|██████████| 50/50 [00:03<00:00, 13.13it/s]
iia: 0.000%
layer: 2, position: 0
100%|██████████| 100/100 [00:08<00:00, 11.26it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
iia: 0.000%
layer: 2, position: 1
100%|██████████| 100/100 [00:08<00:00, 11.27it/s, loss=1.305]
100%|██████████| 50/50 [00:03<00:00, 12.93it/s]
iia: 98.000%
layer: 2, position: 2
100%|██████████| 100/100 [00:08<00:00, 11.28it/s, loss=1.938]
100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
iia: 16.000%
layer: 2, position: 3
100%|██████████| 100/100 [00:08<00:00, 11.24it/s, loss=2.408]
100%|██████████| 50/50 [00:03<00:00, 13.02it/s]
iia: 10.000%
layer: 3, position: 0
100%|██████████| 100/100 [00:08<00:00, 11.40it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
iia: 0.000%
layer: 3, position: 1
100%|██████████| 100/100 [00:08<00:00, 11.37it/s, loss=3.477]
100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
iia: 52.000%
layer: 3, position: 2
100%|██████████| 100/100 [00:08<00:00, 11.34it/s, loss=2.225]
100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
iia: 6.000%
layer: 3, position: 3
100%|██████████| 100/100 [00:08<00:00, 11.42it/s, loss=1.945]
100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
iia: 10.000%
layer: 4, position: 0
100%|██████████| 100/100 [00:08<00:00, 11.59it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
iia: 0.000%
layer: 4, position: 1
100%|██████████| 100/100 [00:08<00:00, 11.62it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.38it/s]
iia: 0.000%
layer: 4, position: 2
100%|██████████| 100/100 [00:08<00:00, 11.67it/s, loss=4.034]
100%|██████████| 50/50 [00:03<00:00, 13.40it/s]
iia: 0.000%
layer: 4, position: 3
100%|██████████| 100/100 [00:08<00:00, 11.76it/s, loss=1.062]
100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
iia: 98.000%
layer: 5, position: 0
100%|██████████| 100/100 [00:08<00:00, 11.70it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
iia: 0.000%
layer: 5, position: 1
100%|██████████| 100/100 [00:08<00:00, 11.39it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 12.83it/s]
iia: 0.000%
layer: 5, position: 2
100%|██████████| 100/100 [00:08<00:00, 11.81it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.43it/s]
iia: 0.000%
layer: 5, position: 3
100%|██████████| 100/100 [00:08<00:00, 11.44it/s, loss=1.113]
100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
iia: 98.000%

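And this is the plot of IIA. In layers 0 and 1, gender is localized at the name token (position 1, IIA 98-100%). In layers 2 and 3 it seems to be spread across positions 1-3, with IIA at the name token dropping to 52% by layer 3. In the later layers (4 and 5), it is found entirely at position 3 ("because", IIA 98%).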
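The heatmap can also be read off numerically. This small sketch uses the IIA values printed in the training log above (hard-coded here rather than loaded from the CSV) to find the position with the highest IIA at each layer:

```python
# IIA values copied from the training log above: iia_log[layer] = [pos0, pos1, pos2, pos3]
iia_log = {
    0: [0.00, 0.98, 0.00, 0.00],
    1: [0.00, 1.00, 0.00, 0.00],
    2: [0.00, 0.98, 0.16, 0.10],
    3: [0.00, 0.52, 0.06, 0.10],
    4: [0.00, 0.00, 0.00, 0.98],
    5: [0.00, 0.00, 0.00, 0.98],
}
positions = ["EOS", "<name>", "walked", "because"]

# index of the maximum IIA in each row, mapped to its token label
best = {
    layer: positions[max(range(len(scores)), key=scores.__getitem__)]
    for layer, scores in iia_log.items()
}
print(best)
```

The best position moves from the name token to the final token between layers 3 and 4.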

df = pd.read_csv(f"./tutorial_data/pyvene_gender_das.csv")
df["layer"] = df["layer"].astype(int)
df["pos"] = df["position"].astype(int)
df["IIA"] = df["iia"].astype(float)

custom_labels = ["EOS", "<name>", "walked", "because"]
breaks = [0, 1, 2, 3]

plot = (
    ggplot(df, aes(x="layer", y="pos"))    
    + geom_tile(aes(fill="IIA"))
    + scale_fill_cmap("Purples") + xlab("layers")
    + scale_y_reverse(
        limits = (-0.5, 3.5), 
        breaks=breaks, labels=custom_labels) 
    + theme(figure_size=(5, 3)) + ylab("") 
    + theme(axis_text_y  = element_text(angle = 90, hjust = 1))
    + ggtitle("Trained Intervention (DAS)")
)
ggsave(
    plot, filename=f"./tutorial_data/pyvene_gender_das.pdf", dpi=200
)
print(plot)
[Figure: Trained Intervention (DAS). IIA by layer (x-axis) and position (y-axis: EOS, <name>, walked, because)]

Probing

We'll use pyvene's CollectIntervention, a dummy intervention that simply records the activations at the intervention site, and train a simple probe on them.

def probing_config(intervention_site, layer):
    """Generate intervention config."""

    # init
    config = pv.IntervenableConfig([{
        "layer": layer,
        "component": intervention_site,
        "intervention_type": pv.CollectIntervention,
    }])
    return config

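This loop collects block-output activations for the base and source prompts at each layer and position, trains a logistic regression probe to predict the pronoun, and evaluates it on the held-out examples.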

# intervention settings
stats = []
num_layers = gpt.config.num_hidden_layers

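# map pronoun token ids to probe classes (" he" = 344 -> 0, " she" = 703 -> 1 for this tokenizer)
label_mapping = {tokenizer.encode(" he")[0]: 0, tokenizer.encode(" she")[0]: 1}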

# loop over layers and positions
with torch.no_grad():
    for layer in range(num_layers):
        for position in range(4):
            print(f"layer: {layer}, position: {position}")

            # set up intervenable model
            config = probing_config("block_output", layer)
            intervenable = pv.IntervenableModel(config, gpt)
            intervenable.set_device(device)
            intervenable.disable_model_gradients()

            # training loop
            activations, labels = [], []
            iterator = tqdm(trainset)
            for example in iterator:
                # forward pass
                base_outputs, _ = intervenable(
                    example.base,
                    unit_locations={"base": position},
                )
                base_activations = base_outputs[1][0]

                src_outputs, _ = intervenable(
                    example.src,
                    unit_locations={"base": position},
                )
                src_activations = src_outputs[1][0]
                
                # collect activation
                activations.extend(
                    [base_activations.detach()[0].cpu().numpy(), src_activations.detach()[0].cpu().numpy()]
                )
                labels.extend([example.base_label, example.src_label])
            labels = [label_mapping[label] for label in labels]
            
            # train logistic regression
            lr = LogisticRegression(random_state=42, max_iter=1000).fit(
                activations, labels
            )

            # eval
            activations, labels = [], []
            iterator = tqdm(evalset)
            for example in iterator:
                # forward pass
                base_outputs, _ = intervenable(
                    example.base,
                    unit_locations={"base": position},
                )
                base_activations = base_outputs[1][0]

                src_outputs, _ = intervenable(
                    example.src,
                    unit_locations={"base": position},
                )
                src_activations = src_outputs[1][0]
                
                # collect activation
                activations.extend(
                    [base_activations.detach()[0].cpu().numpy(), src_activations.detach()[0].cpu().numpy()]
                )
                labels.extend([example.base_label, example.src_label])
            labels = [label_mapping[label] for label in labels]

            # stats
            acc = lr.score(activations, labels)
            f1 = f1_score(labels, lr.predict(activations))
            stats.append({"layer": layer, "position": position, "acc": acc, "f1": f1})
            print(f"acc: {acc:.3%}, f1: {f1:.3f}")
df = pd.DataFrame(stats)
df.to_csv(f"./tutorial_data/pyvene_gender_probe.csv")
layer: 0, position: 0
100%|██████████| 100/100 [00:11<00:00,  8.98it/s]
100%|██████████| 50/50 [00:05<00:00,  9.12it/s]
acc: 50.000%, f1: 0.000
layer: 0, position: 1
100%|██████████| 100/100 [00:10<00:00,  9.12it/s]
100%|██████████| 50/50 [00:05<00:00,  9.12it/s]
acc: 100.000%, f1: 1.000
layer: 0, position: 2
100%|██████████| 100/100 [00:10<00:00,  9.12it/s]
100%|██████████| 50/50 [00:05<00:00,  9.24it/s]
acc: 99.000%, f1: 0.990
layer: 0, position: 3
100%|██████████| 100/100 [00:10<00:00,  9.23it/s]
100%|██████████| 50/50 [00:05<00:00,  9.34it/s]
acc: 88.000%, f1: 0.875
layer: 1, position: 0
100%|██████████| 100/100 [00:11<00:00,  9.09it/s]
100%|██████████| 50/50 [00:05<00:00,  8.95it/s]
acc: 50.000%, f1: 0.000
layer: 1, position: 1
100%|██████████| 100/100 [00:10<00:00,  9.13it/s]
100%|██████████| 50/50 [00:05<00:00,  9.28it/s]
acc: 100.000%, f1: 1.000
layer: 1, position: 2
100%|██████████| 100/100 [00:10<00:00,  9.45it/s]
100%|██████████| 50/50 [00:05<00:00,  9.35it/s]
acc: 97.000%, f1: 0.971
layer: 1, position: 3
100%|██████████| 100/100 [00:10<00:00,  9.50it/s]
100%|██████████| 50/50 [00:05<00:00,  9.53it/s]
acc: 96.000%, f1: 0.962
layer: 2, position: 0
100%|██████████| 100/100 [00:10<00:00,  9.33it/s]
100%|██████████| 50/50 [00:05<00:00,  9.19it/s]
acc: 50.000%, f1: 0.000
layer: 2, position: 1
100%|██████████| 100/100 [00:10<00:00,  9.59it/s]
100%|██████████| 50/50 [00:05<00:00,  9.85it/s]
acc: 100.000%, f1: 1.000
layer: 2, position: 2
100%|██████████| 100/100 [00:10<00:00,  9.86it/s]
100%|██████████| 50/50 [00:05<00:00,  9.76it/s]
acc: 100.000%, f1: 1.000
layer: 2, position: 3
100%|██████████| 100/100 [00:10<00:00,  9.92it/s]
100%|██████████| 50/50 [00:05<00:00,  9.84it/s]
acc: 100.000%, f1: 1.000
layer: 3, position: 0
100%|██████████| 100/100 [00:10<00:00,  9.82it/s]
100%|██████████| 50/50 [00:05<00:00,  9.85it/s]
acc: 50.000%, f1: 0.000
layer: 3, position: 1
100%|██████████| 100/100 [00:10<00:00,  9.84it/s]
100%|██████████| 50/50 [00:05<00:00,  9.80it/s]
acc: 100.000%, f1: 1.000
layer: 3, position: 2
100%|██████████| 100/100 [00:10<00:00,  9.89it/s]
100%|██████████| 50/50 [00:05<00:00,  9.91it/s]
acc: 100.000%, f1: 1.000
layer: 3, position: 3
100%|██████████| 100/100 [00:10<00:00,  9.92it/s]
100%|██████████| 50/50 [00:05<00:00,  9.72it/s]
acc: 100.000%, f1: 1.000
layer: 4, position: 0
100%|██████████| 100/100 [00:10<00:00,  9.81it/s]
100%|██████████| 50/50 [00:05<00:00,  9.82it/s]
acc: 50.000%, f1: 0.000
layer: 4, position: 1
100%|██████████| 100/100 [00:10<00:00,  9.32it/s]
100%|██████████| 50/50 [00:05<00:00,  8.99it/s]
acc: 100.000%, f1: 1.000
layer: 4, position: 2
100%|██████████| 100/100 [00:10<00:00,  9.29it/s]
100%|██████████| 50/50 [00:05<00:00,  9.74it/s]
acc: 100.000%, f1: 1.000
layer: 4, position: 3
100%|██████████| 100/100 [00:11<00:00,  8.88it/s]
100%|██████████| 50/50 [00:05<00:00,  8.88it/s]
acc: 100.000%, f1: 1.000
layer: 5, position: 0
100%|██████████| 100/100 [00:10<00:00,  9.12it/s]
100%|██████████| 50/50 [00:05<00:00,  9.51it/s]
acc: 50.000%, f1: 0.000
layer: 5, position: 1
100%|██████████| 100/100 [00:10<00:00,  9.46it/s]
100%|██████████| 50/50 [00:05<00:00,  9.20it/s]
acc: 100.000%, f1: 1.000
layer: 5, position: 2
100%|██████████| 100/100 [00:10<00:00,  9.31it/s]
100%|██████████| 50/50 [00:05<00:00,  8.73it/s]
acc: 100.000%, f1: 1.000
layer: 5, position: 3
100%|██████████| 100/100 [00:10<00:00,  9.19it/s]
100%|██████████| 50/50 [00:05<00:00,  9.43it/s]
acc: 100.000%, f1: 1.000

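And the probe accuracy plot is below. Note the extremely high accuracy at every position from the name onward, even in layer 0 (88-100%), while at position 0 (EOS, identical across prompts) the probe is at chance. At the later positions, probe accuracy saturates at 100% by layer 2, well before DAS finds a causally effective direction there (98% IIA at "because" only from layer 4). For example, at layer 2, position 3, the probe is 100% accurate while DAS IIA is only 10%. High probe accuracy shows that gender is linearly decodable at a location, not that the model uses it there, which is why probes are unreliable for tracing causal effect.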
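To make the comparison concrete, this small sketch uses the numbers printed in the two logs above (hard-coded, not recomputed) and lists the cells where the probe is near-perfect but a 1D DAS intervention barely changes the output. The thresholds (probe accuracy >= 0.95, IIA < 0.2) are arbitrary choices for illustration.

```python
positions = ["EOS", "<name>", "walked", "because"]

# Copied from the logs above: rows are layers 0-5, columns are positions 0-3.
das_iia = [
    [0.00, 0.98, 0.00, 0.00],
    [0.00, 1.00, 0.00, 0.00],
    [0.00, 0.98, 0.16, 0.10],
    [0.00, 0.52, 0.06, 0.10],
    [0.00, 0.00, 0.00, 0.98],
    [0.00, 0.00, 0.00, 0.98],
]
probe_acc = [
    [0.50, 1.00, 0.99, 0.88],
    [0.50, 1.00, 0.97, 0.96],
    [0.50, 1.00, 1.00, 1.00],
    [0.50, 1.00, 1.00, 1.00],
    [0.50, 1.00, 1.00, 1.00],
    [0.50, 1.00, 1.00, 1.00],
]

# cells where gender is decodable by the probe but swapping it has little effect
mismatches = [
    (layer, positions[pos])
    for layer in range(len(probe_acc))
    for pos in range(len(positions))
    if probe_acc[layer][pos] >= 0.95 and das_iia[layer][pos] < 0.2
]
print(len(mismatches), mismatches)
```

Note that a low IIA can also mean 1D DAS simply failed to find the right direction with 100 training examples, so these cells are where the probe and the intervention disagree, not proof that the information is unused.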

df = pd.read_csv(f"./tutorial_data/pyvene_gender_probe.csv")
df["layer"] = df["layer"].astype(int)
df["pos"] = df["position"].astype(int)
df["ACC"] = df["acc"].astype(float)

custom_labels = ["EOS", "<name>", "walked", "because"]
breaks = [0, 1, 2, 3]

plot = (
    ggplot(df, aes(x="layer", y="pos", fill="ACC"))    
    + geom_tile()
    + scale_fill_cmap("Reds") + xlab("layers")
    + scale_y_reverse(
        limits = (-0.5, 3.5), 
        breaks=breaks, labels=custom_labels) 
    + theme(figure_size=(5, 3)) + ylab("") 
    + theme(axis_text_y  = element_text(angle = 90, hjust = 1))
    + ggtitle("Trained Linear Probe")
)
ggsave(
    plot, filename=f"./tutorial_data/pyvene_gender_probe.pdf", dpi=200
)
print(plot)
[Figure: Trained Linear Probe. Probe accuracy by layer (x-axis) and position (y-axis: EOS, <name>, walked, because)]