Causal Evaluation of Probes#
This is a simple tutorial showing you how to collect activations from intervention points in a model. We’ll compare the interchange intervention accuracy (IIA) of 1D DAS (distributed alignment search) at each layer and position for block_output
in pythia-70M with logistic regression probing accuracy. The task we’ll look at is gender prediction, where gendered names are used in templates like “[name] walked because”, which elicit the associated gendered pronoun “he” or “she” as the next-token prediction for this model.
__author__ = "Aryaman Arora"
__version__ = "01/10/2024"
Setup#
try:
# This library is our indicator that the required installs
# need to be done.
import pyvene as pv
except ModuleNotFoundError:
!pip install git+https://github.com/stanfordnlp/pyvene.git
import pandas as pd
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
get_linear_schedule_with_warmup,
)
import torch
import random
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
%config InlineBackend.figure_formats = ['svg']
from plotnine import (
ggplot,
geom_tile,
aes,
facet_wrap,
theme,
element_text,
geom_bar,
geom_hline,
scale_y_log10,
geom_line,
geom_point,
geom_text,
ggtitle, xlab, ylab,
ggsave
)
from plotnine.scales import scale_y_reverse, scale_fill_cmap
from tqdm import tqdm
from collections import namedtuple
Load model and data#
# Run on the first CUDA device when one is available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Hugging Face model id; swap in "EleutherAI/pythia-6.9b" for the larger model.
model = "EleutherAI/pythia-70m"
tokenizer = AutoTokenizer.from_pretrained(model)
# Use the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
gpt = AutoModelForCausalLM.from_pretrained(
    model,
    revision="main",
    # Only the 6.9b checkpoint is loaded in bfloat16. The id is compared
    # case-insensitively so that a differently-cased spelling (e.g. "...-6.9B")
    # does not silently fall back to float32.
    torch_dtype=(
        torch.bfloat16
        if model.lower() == "eleutherai/pythia-6.9b"
        else torch.float32
    ),
).to(device)
We have a list of 100 names for each gender, and we’ll keep only the names that are a single token long. We’ll further keep only examples on which the model’s predictions agree with our labels, since some of these names might be ambiguous or the model might not have the expected behaviour. This ensures that the baseline IIA (i.e., without any intervention) is 0.
# One counterfactual pair: tokenized base and source prompts plus the label
# associated with each of them.
Example = namedtuple("Example", "base src base_label src_label")

# Candidate first names, keyed by the pronoun each name is paired with.
# Every name is a single whitespace-free word, so each list is written as a
# space-separated string and split into a list of 100 names.
names = {
    "he": (
        "James Robert John Michael David William Richard Joseph Thomas Christopher "
        "Charles Daniel Matthew Anthony Mark Donald Steven Andrew Paul Joshua "
        "Kenneth Kevin Brian George Timothy Ronald Jason Edward Jeffrey Ryan "
        "Jacob Gary Nicholas Eric Jonathan Stephen Larry Justin Scott Brandon "
        "Benjamin Samuel Gregory Alexander Patrick Frank Raymond Jack Dennis Jerry "
        "Tyler Aaron Jose Adam Nathan Henry Zachary Douglas Peter Kyle "
        "Noah Ethan Jeremy Walter Christian Keith Roger Terry Austin Sean "
        "Gerald Carl Harold Dylan Arthur Lawrence Jordan Jesse Bryan Billy "
        "Bruce Gabriel Joe Logan Alan Juan Albert Willie Elijah Wayne "
        "Randy Vincent Mason Roy Ralph Bobby Russell Bradley Philip Eugene"
    ).split(),
    "she": (
        "Mary Patricia Jennifer Linda Elizabeth Barbara Susan Jessica Sarah Karen "
        "Lisa Nancy Betty Sandra Margaret Ashley Kimberly Emily Donna Michelle "
        "Carol Amanda Melissa Deborah Stephanie Dorothy Rebecca Sharon Laura Cynthia "
        "Amy Kathleen Angela Shirley Brenda Emma Anna Pamela Nicole Samantha "
        "Katherine Christine Helen Debra Rachel Carolyn Janet Maria Catherine Heather "
        "Diane Olivia Julie Joyce Victoria Ruth Virginia Lauren Kelly Christina "
        "Joan Evelyn Judith Andrea Hannah Megan Cheryl Jacqueline Martha Madison "
        "Teresa Gloria Sara Janice Ann Kathryn Abigail Sophia Frances Jean "
        "Alice Judy Isabella Julia Grace Amber Denise Danielle Marilyn Beverly "
        "Charlotte Natalie Theresa Diana Brittany Doris Kayla Alexis Lori Marie"
    ).split(),
}
# Keep only the names that the tokenizer splits into exactly one token,
# then report how many names remain for each pronoun.
names = {
    pronoun: [
        candidate
        for candidate in candidates
        if len(tokenizer.tokenize(candidate)) == 1
    ]
    for pronoun, candidates in names.items()
}
print(len(names["he"]), len(names["she"]))
def sample_example(tokenizer):
    """Sample one base/source prompt pair with different pronoun labels.

    A pronoun key is drawn at random from the module-level ``names`` dict for
    the base prompt; the source prompt uses the first key that differs from
    it. One name is then drawn at random for each side, and both prompts are
    built as ``"<|endoftext|>{name} walked because"``.

    Args:
        tokenizer: callable tokenizer that also provides ``encode``.

    Returns:
        An ``Example`` whose ``base`` / ``src`` are the tokenizer outputs for
        the two prompts (requested with ``return_tensors="pt"``) and whose
        ``base_label`` / ``src_label`` are the first token id of
        ``" " + pronoun`` for the respective side.
    """

    def encode_prompt(name):
        return tokenizer(f"<|endoftext|>{name} walked because", return_tensors="pt")

    def first_token_id(pronoun):
        return tokenizer.encode(" " + pronoun)[0]

    # pronouns: random for the base, a non-matching one for the source
    base_pronoun = random.choice(list(names.keys()))
    other_pronouns = [key for key in names if key != base_pronoun]
    src_pronoun = other_pronouns[0]

    # one name per side
    base_name = random.choice(names[base_pronoun])
    src_name = random.choice(names[src_pronoun])

    # tokenize the prompts first, then the labels
    base_inputs = encode_prompt(base_name)
    src_inputs = encode_prompt(src_name)
    return Example(
        base=base_inputs,
        src=src_inputs,
        base_label=first_token_id(base_pronoun),
        src_label=first_token_id(src_pronoun),
    )
47 10
# Draw a single example so the notebook displays what an `Example` looks like.
sample_example(tokenizer)
Example(base={'input_ids': tensor([[ 0, 37376, 7428, 984]]), 'attention_mask': tensor([[1, 1, 1, 1]])}, src={'input_ids': tensor([[ 0, 44305, 7428, 984]]), 'attention_mask': tensor([[1, 1, 1, 1]])}, base_label=344, src_label=703)
def generate_n_doable_examples(n, model, tokenizer):
    """Collect ``n`` sampled examples that ``model`` already gets right.

    Pairs are drawn repeatedly from ``sample_example``. A pair is kept only
    when, at the last position, the base prompt's logit for the base label
    exceeds its logit for the source label, and the source prompt's logit for
    the source label exceeds its logit for the base label.

    Note that the loop keeps sampling until ``n`` pairs have been accepted, so
    it does not terminate if the model never satisfies the condition.

    Args:
        n: number of examples to collect.
        model: called as ``model(**inputs)``; its output must expose
            ``.logits`` and the model must expose ``.device``.
        tokenizer: forwarded to ``sample_example``.

    Returns:
        A list of ``n`` ``Example`` tuples whose tensor inputs have been moved
        to ``model.device``.
    """
    examples = []
    # The forward passes only filter candidates, so run them without autograd
    # to avoid building (and holding on to) computation graphs. The context
    # manager also makes sure the progress bar is closed at the end.
    with tqdm(total=n) as progress, torch.no_grad():
        while len(examples) < n:
            ex = sample_example(tokenizer)

            # Move every tensor of both prompts onto the model's device.
            for inputs in (ex.base, ex.src):
                for key, value in list(inputs.items()):
                    if isinstance(value, torch.Tensor):
                        inputs[key] = value.to(model.device)

            logits_base = model(**ex.base).logits[0, -1]
            logits_src = model(**ex.src).logits[0, -1]

            # Keep the pair only if each prompt prefers its own label.
            base_ok = logits_base[ex.base_label] > logits_base[ex.src_label]
            src_ok = logits_src[ex.src_label] > logits_src[ex.base_label]
            if base_ok and src_ok:
                examples.append(ex)
                progress.update(1)
    return examples
# Build the datasets: `total_steps` training examples and a separately
# sampled evaluation set of 50 examples.
total_steps = 100
trainset = generate_n_doable_examples(n=total_steps, model=gpt, tokenizer=tokenizer)
evalset = generate_n_doable_examples(n=50, model=gpt, tokenizer=tokenizer)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 100/100 [00:05<00:00, 17.51it/s]
100%|██████████| 50/50 [00:02<00:00, 19.85it/s]
DAS#
This is the usual 1D DAS setup, trained with a batch size of 1.
def intervention_config(intervention_site, layer, num_dims=1):
    """Build a pyvene config holding one low-rank rotated-space intervention.

    Args:
        intervention_site: value for the ``"component"`` entry
            (e.g. ``"block_output"``).
        layer: value for the ``"layer"`` entry.
        num_dims: value for ``"low_rank_dimension"``; defaults to 1.

    Returns:
        A ``pv.IntervenableConfig`` wrapping that single intervention spec.
    """
    intervention_spec = {
        "layer": layer,
        "component": intervention_site,
        "intervention_type": pv.LowRankRotatedSpaceIntervention,
        "low_rank_dimension": num_dims,
    }
    return pv.IntervenableConfig([intervention_spec])
# Shared cross-entropy criterion used by `calculate_loss`.
loss_fct = torch.nn.CrossEntropyLoss()


def calculate_loss(logits, label):
    """Return the cross-entropy of ``logits`` against ``label``.

    ``label`` is moved to the device of ``logits`` before the loss is
    computed. Both arguments may carry a batch dimension.
    """
    target = label.to(logits.device)
    return loss_fct(logits, target)
import os

# Train and evaluate one 1D DAS intervention per (layer, position) pair on
# block_output, recording the interchange intervention accuracy (IIA) of each.
stats = []
num_layers = gpt.config.num_hidden_layers

# loop over layers and positions
for layer in range(num_layers):
    for position in range(4):
        print(f"layer: {layer}, position: {position}")

        # set up intervenable model
        config = intervention_config("block_output", layer, 1)
        intervenable = pv.IntervenableModel(config, gpt)
        intervenable.set_device(device)
        intervenable.disable_model_gradients()

        # set up optimizer: only the rotation layers of the interventions are trained
        optimizer_params = []
        for v in intervenable.interventions.values():
            try:
                optimizer_params.append({"params": v[0].rotate_layer.parameters()})
            except (AttributeError, TypeError, IndexError):
                # This entry exposes no rotate_layer to train; skip it rather
                # than hiding every possible error behind a bare `except`.
                continue
        optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps,
        )

        # training loop (batch size 1)
        iterator = tqdm(trainset)
        for example in iterator:
            # Gradients accumulate across backward() calls unless they are
            # cleared, so reset them at the start of every step.
            optimizer.zero_grad()

            # forward pass: patch the source representation into the base run
            _, counterfactual_outputs = intervenable(
                example.base,
                [example.src],
                {"sources->base": position},
            )

            # loss: the intervened model should now predict the source label
            logits = counterfactual_outputs.logits[:, -1]
            loss = calculate_loss(logits, torch.tensor([example.src_label]).to(device))
            iterator.set_postfix({"loss": f"{loss.item():.3f}"})

            # backward
            loss.backward()
            optimizer.step()
            scheduler.step()

        # eval
        with torch.no_grad():
            iia = 0
            iterator = tqdm(evalset)
            for example in iterator:
                # forward
                _, counterfactual_outputs = intervenable(
                    example.base,
                    [example.src],
                    {"sources->base": position},
                )

                # an example counts towards IIA when the source label
                # outscores the base label after the intervention
                logits = counterfactual_outputs.logits[0, -1]
                if logits[example.src_label] > logits[example.base_label]:
                    iia += 1

            # stats
            iia = iia / len(evalset)
            stats.append({"layer": layer, "position": position, "iia": iia})
            print(f"iia: {iia:.3%}")

df = pd.DataFrame(stats)
# to_csv does not create missing directories, so make sure the target exists.
os.makedirs("./tutorial_data", exist_ok=True)
df.to_csv("./tutorial_data/pyvene_gender_das.csv")
layer: 0, position: 0
100%|██████████| 100/100 [00:09<00:00, 10.66it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.26it/s]
iia: 0.000%
layer: 0, position: 1
100%|██████████| 100/100 [00:09<00:00, 10.97it/s, loss=1.268]
100%|██████████| 50/50 [00:03<00:00, 13.35it/s]
iia: 98.000%
layer: 0, position: 2
100%|██████████| 100/100 [00:09<00:00, 10.93it/s, loss=4.130]
100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
iia: 0.000%
layer: 0, position: 3
100%|██████████| 100/100 [00:09<00:00, 10.83it/s, loss=4.276]
100%|██████████| 50/50 [00:03<00:00, 13.28it/s]
iia: 0.000%
layer: 1, position: 0
100%|██████████| 100/100 [00:09<00:00, 10.94it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
iia: 0.000%
layer: 1, position: 1
100%|██████████| 100/100 [00:09<00:00, 10.93it/s, loss=1.231]
100%|██████████| 50/50 [00:03<00:00, 13.20it/s]
iia: 100.000%
layer: 1, position: 2
100%|██████████| 100/100 [00:09<00:00, 10.94it/s, loss=4.422]
100%|██████████| 50/50 [00:03<00:00, 13.19it/s]
iia: 0.000%
layer: 1, position: 3
100%|██████████| 100/100 [00:09<00:00, 10.92it/s, loss=4.308]
100%|██████████| 50/50 [00:03<00:00, 13.13it/s]
iia: 0.000%
layer: 2, position: 0
100%|██████████| 100/100 [00:08<00:00, 11.26it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.30it/s]
iia: 0.000%
layer: 2, position: 1
100%|██████████| 100/100 [00:08<00:00, 11.27it/s, loss=1.305]
100%|██████████| 50/50 [00:03<00:00, 12.93it/s]
iia: 98.000%
layer: 2, position: 2
100%|██████████| 100/100 [00:08<00:00, 11.28it/s, loss=1.938]
100%|██████████| 50/50 [00:03<00:00, 13.27it/s]
iia: 16.000%
layer: 2, position: 3
100%|██████████| 100/100 [00:08<00:00, 11.24it/s, loss=2.408]
100%|██████████| 50/50 [00:03<00:00, 13.02it/s]
iia: 10.000%
layer: 3, position: 0
100%|██████████| 100/100 [00:08<00:00, 11.40it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.34it/s]
iia: 0.000%
layer: 3, position: 1
100%|██████████| 100/100 [00:08<00:00, 11.37it/s, loss=3.477]
100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
iia: 52.000%
layer: 3, position: 2
100%|██████████| 100/100 [00:08<00:00, 11.34it/s, loss=2.225]
100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
iia: 6.000%
layer: 3, position: 3
100%|██████████| 100/100 [00:08<00:00, 11.42it/s, loss=1.945]
100%|██████████| 50/50 [00:03<00:00, 13.32it/s]
iia: 10.000%
layer: 4, position: 0
100%|██████████| 100/100 [00:08<00:00, 11.59it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
iia: 0.000%
layer: 4, position: 1
100%|██████████| 100/100 [00:08<00:00, 11.62it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.38it/s]
iia: 0.000%
layer: 4, position: 2
100%|██████████| 100/100 [00:08<00:00, 11.67it/s, loss=4.034]
100%|██████████| 50/50 [00:03<00:00, 13.40it/s]
iia: 0.000%
layer: 4, position: 3
100%|██████████| 100/100 [00:08<00:00, 11.76it/s, loss=1.062]
100%|██████████| 50/50 [00:03<00:00, 13.25it/s]
iia: 98.000%
layer: 5, position: 0
100%|██████████| 100/100 [00:08<00:00, 11.70it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.22it/s]
iia: 0.000%
layer: 5, position: 1
100%|██████████| 100/100 [00:08<00:00, 11.39it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 12.83it/s]
iia: 0.000%
layer: 5, position: 2
100%|██████████| 100/100 [00:08<00:00, 11.81it/s, loss=4.355]
100%|██████████| 50/50 [00:03<00:00, 13.43it/s]
iia: 0.000%
layer: 5, position: 3
100%|██████████| 100/100 [00:08<00:00, 11.44it/s, loss=1.113]
100%|██████████| 50/50 [00:03<00:00, 13.33it/s]
iia: 98.000%
And this is the plot of IIA. In layers 0 and 1 the gender is localized at the name (position 1); in layers 2 and 3 it seems to be represented across positions 1-3, and it sits entirely in position 3 in later layers.
# Reload the saved DAS results and coerce the plotted columns to numeric types.
df = pd.read_csv("./tutorial_data/pyvene_gender_das.csv")
df = df.assign(
    layer=df["layer"].astype(int),
    pos=df["position"].astype(int),
    IIA=df["iia"].astype(float),
)

# One y-axis tick per token position of the prompt.
breaks = [0, 1, 2, 3]
custom_labels = ["EOS", "<name>", "walked", "because"]

# Heatmap of IIA with layers on the x-axis and token positions on the y-axis.
plot = (
    ggplot(df, aes(x="layer", y="pos"))
    + geom_tile(aes(fill="IIA"))
    + scale_fill_cmap("Purples")
    + xlab("layers")
    + scale_y_reverse(limits=(-0.5, 3.5), breaks=breaks, labels=custom_labels)
    + theme(figure_size=(5, 3))
    + ylab("")
    + theme(axis_text_y=element_text(angle=90, hjust=1))
    + ggtitle("Trained Intervention (DAS)")
)

# Save the figure to disk, then display it.
ggsave(plot, filename="./tutorial_data/pyvene_gender_das.pdf", dpi=200)
print(plot)
/Users/aryamanarora/opt/miniconda3/lib/python3.9/site-packages/plotnine/ggplot.py:718: PlotnineWarning: Saving 5 x 3 in image.
/Users/aryamanarora/opt/miniconda3/lib/python3.9/site-packages/plotnine/ggplot.py:719: PlotnineWarning: Filename: ./tutorial_data/pyvene_gender_das.pdf
Probing#
We’ll use the dummy intervention CollectIntervention (pv.CollectIntervention)
to collect activations and train a simple probe.
def probing_config(intervention_site, layer):
    """Build a pyvene config holding one ``pv.CollectIntervention``.

    Args:
        intervention_site: value for the ``"component"`` entry
            (e.g. ``"block_output"``).
        layer: value for the ``"layer"`` entry.

    Returns:
        A ``pv.IntervenableConfig`` wrapping that single intervention spec.
    """
    collect_spec = {
        "layer": layer,
        "component": intervention_site,
        "intervention_type": pv.CollectIntervention,
    }
    return pv.IntervenableConfig([collect_spec])
This is the training and evaluation loop for the probes.
import os


def _collect_probe_data(intervenable, dataset, position, label_mapping):
    """Collect per-prompt activations and class labels for one position.

    For every example, the base prompt and then the source prompt are run
    through ``intervenable`` with ``unit_locations={"base": position}``. The
    value at ``outputs[1][0]`` (presumably the collected activation; see
    pyvene's CollectIntervention) is detached, indexed at ``[0]`` and
    converted to a NumPy array.

    Returns:
        ``(activations, labels)``: two parallel lists ordered base, source,
        base, source, ... with labels mapped through ``label_mapping``.
    """
    activations, labels = [], []
    for example in tqdm(dataset):
        sides = (
            (example.base, example.base_label),
            (example.src, example.src_label),
        )
        for inputs, label in sides:
            outputs, _ = intervenable(
                inputs,
                unit_locations={"base": position},
            )
            activations.append(outputs[1][0].detach()[0].cpu().numpy())
            labels.append(label_mapping[label])
    return activations, labels


# intervention settings
stats = []
num_layers = gpt.config.num_hidden_layers

# Map the pronoun token ids to class indices. The ids are derived from the
# tokenizer, encoded the same way as the example labels, instead of being
# hard-coded (for the tokenizer used here: 344 = " he", 703 = " she").
label_mapping = {
    tokenizer.encode(" he")[0]: 0,
    tokenizer.encode(" she")[0]: 1,
}

# loop over layers and positions; probing needs no gradients
with torch.no_grad():
    for layer in range(num_layers):
        for position in range(4):
            print(f"layer: {layer}, position: {position}")

            # set up intervenable model
            config = probing_config("block_output", layer)
            intervenable = pv.IntervenableModel(config, gpt)
            intervenable.set_device(device)
            intervenable.disable_model_gradients()

            # train logistic regression on the training activations
            activations, labels = _collect_probe_data(
                intervenable, trainset, position, label_mapping
            )
            lr = LogisticRegression(random_state=42, max_iter=1000).fit(
                activations, labels
            )

            # eval on activations collected from the evaluation set
            activations, labels = _collect_probe_data(
                intervenable, evalset, position, label_mapping
            )

            # stats
            acc = lr.score(activations, labels)
            f1 = f1_score(labels, lr.predict(activations))
            stats.append({"layer": layer, "position": position, "acc": acc, "f1": f1})
            print(f"acc: {acc:.3%}, f1: {f1:.3f}")

df = pd.DataFrame(stats)
# to_csv does not create missing directories, so make sure the target exists.
os.makedirs("./tutorial_data", exist_ok=True)
df.to_csv("./tutorial_data/pyvene_gender_probe.csv")
layer: 0, position: 0
100%|██████████| 100/100 [00:11<00:00, 8.98it/s]
100%|██████████| 50/50 [00:05<00:00, 9.12it/s]
acc: 50.000%, f1: 0.000
layer: 0, position: 1
100%|██████████| 100/100 [00:10<00:00, 9.12it/s]
100%|██████████| 50/50 [00:05<00:00, 9.12it/s]
acc: 100.000%, f1: 1.000
layer: 0, position: 2
100%|██████████| 100/100 [00:10<00:00, 9.12it/s]
100%|██████████| 50/50 [00:05<00:00, 9.24it/s]
acc: 99.000%, f1: 0.990
layer: 0, position: 3
100%|██████████| 100/100 [00:10<00:00, 9.23it/s]
100%|██████████| 50/50 [00:05<00:00, 9.34it/s]
acc: 88.000%, f1: 0.875
layer: 1, position: 0
100%|██████████| 100/100 [00:11<00:00, 9.09it/s]
100%|██████████| 50/50 [00:05<00:00, 8.95it/s]
acc: 50.000%, f1: 0.000
layer: 1, position: 1
100%|██████████| 100/100 [00:10<00:00, 9.13it/s]
100%|██████████| 50/50 [00:05<00:00, 9.28it/s]
acc: 100.000%, f1: 1.000
layer: 1, position: 2
100%|██████████| 100/100 [00:10<00:00, 9.45it/s]
100%|██████████| 50/50 [00:05<00:00, 9.35it/s]
acc: 97.000%, f1: 0.971
layer: 1, position: 3
100%|██████████| 100/100 [00:10<00:00, 9.50it/s]
100%|██████████| 50/50 [00:05<00:00, 9.53it/s]
acc: 96.000%, f1: 0.962
layer: 2, position: 0
100%|██████████| 100/100 [00:10<00:00, 9.33it/s]
100%|██████████| 50/50 [00:05<00:00, 9.19it/s]
acc: 50.000%, f1: 0.000
layer: 2, position: 1
100%|██████████| 100/100 [00:10<00:00, 9.59it/s]
100%|██████████| 50/50 [00:05<00:00, 9.85it/s]
acc: 100.000%, f1: 1.000
layer: 2, position: 2
100%|██████████| 100/100 [00:10<00:00, 9.86it/s]
100%|██████████| 50/50 [00:05<00:00, 9.76it/s]
acc: 100.000%, f1: 1.000
layer: 2, position: 3
100%|██████████| 100/100 [00:10<00:00, 9.92it/s]
100%|██████████| 50/50 [00:05<00:00, 9.84it/s]
acc: 100.000%, f1: 1.000
layer: 3, position: 0
100%|██████████| 100/100 [00:10<00:00, 9.82it/s]
100%|██████████| 50/50 [00:05<00:00, 9.85it/s]
acc: 50.000%, f1: 0.000
layer: 3, position: 1
100%|██████████| 100/100 [00:10<00:00, 9.84it/s]
100%|██████████| 50/50 [00:05<00:00, 9.80it/s]
acc: 100.000%, f1: 1.000
layer: 3, position: 2
100%|██████████| 100/100 [00:10<00:00, 9.89it/s]
100%|██████████| 50/50 [00:05<00:00, 9.91it/s]
acc: 100.000%, f1: 1.000
layer: 3, position: 3
100%|██████████| 100/100 [00:10<00:00, 9.92it/s]
100%|██████████| 50/50 [00:05<00:00, 9.72it/s]
acc: 100.000%, f1: 1.000
layer: 4, position: 0
100%|██████████| 100/100 [00:10<00:00, 9.81it/s]
100%|██████████| 50/50 [00:05<00:00, 9.82it/s]
acc: 50.000%, f1: 0.000
layer: 4, position: 1
100%|██████████| 100/100 [00:10<00:00, 9.32it/s]
100%|██████████| 50/50 [00:05<00:00, 8.99it/s]
acc: 100.000%, f1: 1.000
layer: 4, position: 2
100%|██████████| 100/100 [00:10<00:00, 9.29it/s]
100%|██████████| 50/50 [00:05<00:00, 9.74it/s]
acc: 100.000%, f1: 1.000
layer: 4, position: 3
100%|██████████| 100/100 [00:11<00:00, 8.88it/s]
100%|██████████| 50/50 [00:05<00:00, 8.88it/s]
acc: 100.000%, f1: 1.000
layer: 5, position: 0
100%|██████████| 100/100 [00:10<00:00, 9.12it/s]
100%|██████████| 50/50 [00:05<00:00, 9.51it/s]
acc: 50.000%, f1: 0.000
layer: 5, position: 1
100%|██████████| 100/100 [00:10<00:00, 9.46it/s]
100%|██████████| 50/50 [00:05<00:00, 9.20it/s]
acc: 100.000%, f1: 1.000
layer: 5, position: 2
100%|██████████| 100/100 [00:10<00:00, 9.31it/s]
100%|██████████| 50/50 [00:05<00:00, 8.73it/s]
acc: 100.000%, f1: 1.000
layer: 5, position: 3
100%|██████████| 100/100 [00:10<00:00, 9.19it/s]
100%|██████████| 50/50 [00:05<00:00, 9.43it/s]
acc: 100.000%, f1: 1.000
And the probe accuracy plot is below. Note the extremely high accuracy at all positions at and after the name! Accuracy is slightly lower in early layers at later positions, but it saturates much earlier than the IIA for DAS. This shows how unreliable probes are for tracing causal effect.
# Reload the saved probe results and coerce the plotted columns to numeric types.
df = pd.read_csv("./tutorial_data/pyvene_gender_probe.csv")
df = df.assign(
    layer=df["layer"].astype(int),
    pos=df["position"].astype(int),
    ACC=df["acc"].astype(float),
)

# One y-axis tick per token position of the prompt.
breaks = [0, 1, 2, 3]
custom_labels = ["EOS", "<name>", "walked", "because"]

# Heatmap of probe accuracy with layers on the x-axis and token positions on the y-axis.
plot = (
    ggplot(df, aes(x="layer", y="pos", fill="ACC"))
    + geom_tile()
    + scale_fill_cmap("Reds")
    + xlab("layers")
    + scale_y_reverse(limits=(-0.5, 3.5), breaks=breaks, labels=custom_labels)
    + theme(figure_size=(5, 3))
    + ylab("")
    + theme(axis_text_y=element_text(angle=90, hjust=1))
    + ggtitle("Trained Linear Probe")
)

# Save the figure to disk, then display it.
ggsave(plot, filename="./tutorial_data/pyvene_gender_probe.pdf", dpi=200)
print(plot)
/Users/aryamanarora/opt/miniconda3/lib/python3.9/site-packages/plotnine/ggplot.py:718: PlotnineWarning: Saving 5 x 3 in image.
/Users/aryamanarora/opt/miniconda3/lib/python3.9/site-packages/plotnine/ggplot.py:719: PlotnineWarning: Filename: ./tutorial_data/pyvene_gender_probe.pdf