Replicating the IOI paper#
__author__ = "Aryaman Arora"
__version__ = "1/24/2023"
Overview#
We’re going to try to replicate some results of the original IOI paper (Wang et al., 2022) using pyvene, as a demonstration of path patching and a verification of their results.
Setup#
try:
# This library is our indicator that the required installs
# need to be done.
import pyvene
except ModuleNotFoundError:
!pip install git+https://github.com/frankaging/pyvene.git
import random
import pandas as pd
from tutorial_ioi_utils import *
import pyvene as pv
import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = ['svg']
from plotnine import (
ggplot,
geom_tile,
aes,
scale_y_reverse,
scale_fill_cmap,
geom_text,
theme_bw,
xlim,
ylim,
scale_x_continuous
)
# Please avoid silencing warnings like this in general; it is done here only because the plotting code emits warnings.
import warnings
warnings.filterwarnings("ignore")
# Load the GPT-2 language model together with its config and tokenizer.
# No ``cache_dir`` is passed: the original cell pointed at one developer's
# home directory, which does not exist on other machines. Without it the
# Hugging Face libraries fall back to their default cache location.
config, tokenizer, gpt2 = pv.create_gpt2_lm()
# Put the model in evaluation mode; the return value (the model itself) is
# bound to ``_`` so the notebook does not echo it.
_ = gpt2.eval()

# Plot titles keyed by intervened component name.
# NOTE(review): not referenced in the visible part of this notebook.
titles = {
    "block_output": "single restored layer in GPT2-XL",
    "mlp_activation": "center of interval of 10 patched mlp layer",
    "attention_output": "center of interval of 10 patched attn layer",
}

# Colormap names keyed by intervened component name.
# NOTE(review): not referenced in the visible part of this notebook.
colors = {
    "block_output": "Purples",
    "mlp_activation": "Greens",
    "attention_output": "Reds",
}
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In [58], line 27
23 import warnings
25 warnings.filterwarnings("ignore")
---> 27 config, tokenizer, gpt2 = pv.create_gpt2_lm(cache_dir="/Users/aryamanarora/.cache/huggingface/hub/")
28 _ = gpt2.eval()
30 titles={
31 "block_output": "single restored layer in GPT2-XL",
32 "mlp_activation": "center of interval of 10 patched mlp layer",
33 "attention_output": "center of interval of 10 patched attn layer"
34 }
File ~/Documents/Code/pyvene/pyvene/models/gpt2/modelings_intervenable_gpt2.py:84, in create_gpt2_lm(name, config, cache_dir)
81 """Creates a GPT2 LM, config, and tokenizer from the given name and revision"""
82 from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
---> 84 tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
85 if config is None:
86 config = GPT2Config.from_pretrained(name)
File ~/opt/miniconda3/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1968, in PreTrainedTokenizerBase.from_pretrained(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, *init_inputs, **kwargs)
1965 if "tokenizer_file" in vocab_files:
1966 # Try to get the tokenizer config to see if there are versioned tokenizer files.
1967 fast_tokenizer_file = FULL_TOKENIZER_FILE
-> 1968 resolved_config_file = cached_file(
1969 pretrained_model_name_or_path,
1970 TOKENIZER_CONFIG_FILE,
1971 cache_dir=cache_dir,
1972 force_download=force_download,
1973 resume_download=resume_download,
1974 proxies=proxies,
1975 token=token,
1976 revision=revision,
1977 local_files_only=local_files_only,
1978 subfolder=subfolder,
1979 user_agent=user_agent,
1980 _raise_exceptions_for_missing_entries=False,
1981 _raise_exceptions_for_connection_errors=False,
1982 _commit_hash=commit_hash,
1983 )
1984 commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
1985 if resolved_config_file is not None:
File ~/opt/miniconda3/lib/python3.9/site-packages/transformers/utils/hub.py:429, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
426 user_agent = http_user_agent(user_agent)
427 try:
428 # Load from URL or cache if already cached
--> 429 resolved_file = hf_hub_download(
430 path_or_repo_id,
431 filename,
432 subfolder=None if len(subfolder) == 0 else subfolder,
433 repo_type=repo_type,
434 revision=revision,
435 cache_dir=cache_dir,
436 user_agent=user_agent,
437 force_download=force_download,
438 proxies=proxies,
439 resume_download=resume_download,
440 token=token,
441 local_files_only=local_files_only,
442 )
443 except GatedRepoError as e:
444 raise EnvironmentError(
445 "You are trying to access a gated repo.\nMake sure to request access at "
446 f"https://huggingface.co/{path_or_repo_id} and pass a token having permission to this repo either "
447 "by logging in with `huggingface-cli login` or by passing `token=<your_token>`."
448 ) from e
File ~/opt/miniconda3/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:118, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
115 if check_use_auth_token:
116 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 118 return fn(*args, **kwargs)
File ~/opt/miniconda3/lib/python3.9/site-packages/huggingface_hub/file_download.py:1232, in hf_hub_download(repo_id, filename, subfolder, repo_type, revision, endpoint, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, legacy_cache_layout)
1230 try:
1231 try:
-> 1232 metadata = get_hf_file_metadata(
1233 url=url,
1234 token=token,
1235 proxies=proxies,
1236 timeout=etag_timeout,
1237 )
1238 except EntryNotFoundError as http_error:
1239 # Cache the non-existence of the file and raise
1240 commit_hash = http_error.response.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
File ~/opt/miniconda3/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:118, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
115 if check_use_auth_token:
116 kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 118 return fn(*args, **kwargs)
File ~/opt/miniconda3/lib/python3.9/site-packages/huggingface_hub/file_download.py:1599, in get_hf_file_metadata(url, token, proxies, timeout)
1596 headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file
1598 # Retrieve metadata
-> 1599 r = _request_wrapper(
1600 method="HEAD",
1601 url=url,
1602 headers=headers,
1603 allow_redirects=False,
1604 follow_relative_redirects=True,
1605 proxies=proxies,
1606 timeout=timeout,
1607 )
1608 hf_raise_for_status(r)
1610 # Return
File ~/opt/miniconda3/lib/python3.9/site-packages/huggingface_hub/file_download.py:417, in _request_wrapper(method, url, max_retries, base_wait_time, max_wait_time, timeout, follow_relative_redirects, **params)
415 # 2. Force relative redirection
416 if follow_relative_redirects:
--> 417 response = _request_wrapper(
418 method=method,
419 url=url,
420 max_retries=max_retries,
421 base_wait_time=base_wait_time,
422 max_wait_time=max_wait_time,
423 timeout=timeout,
424 follow_relative_redirects=False,
425 **params,
426 )
428 # If redirection, we redirect only relative paths.
429 # This is useful in case of a renamed repository.
430 if 300 <= response.status_code <= 399:
File ~/opt/miniconda3/lib/python3.9/site-packages/huggingface_hub/file_download.py:452, in _request_wrapper(method, url, max_retries, base_wait_time, max_wait_time, timeout, follow_relative_redirects, **params)
449 return response
451 # 3. Exponential backoff
--> 452 return http_backoff(
453 method=method,
454 url=url,
455 max_retries=max_retries,
456 base_wait_time=base_wait_time,
457 max_wait_time=max_wait_time,
458 retry_on_exceptions=(Timeout, ProxyError),
459 retry_on_status_codes=(),
460 timeout=timeout,
461 **params,
462 )
File ~/opt/miniconda3/lib/python3.9/site-packages/huggingface_hub/utils/_http.py:258, in http_backoff(method, url, max_retries, base_wait_time, max_wait_time, retry_on_exceptions, retry_on_status_codes, **kwargs)
255 kwargs["data"].seek(io_obj_initial_pos)
257 # Perform request and return if status_code is not in the retry list.
--> 258 response = session.request(method=method, url=url, **kwargs)
259 if response.status_code not in retry_on_status_codes:
260 return response
File ~/opt/miniconda3/lib/python3.9/site-packages/requests/sessions.py:589, in Session.request(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)
584 send_kwargs = {
585 "timeout": timeout,
586 "allow_redirects": allow_redirects,
587 }
588 send_kwargs.update(settings)
--> 589 resp = self.send(prep, **send_kwargs)
591 return resp
File ~/opt/miniconda3/lib/python3.9/site-packages/requests/sessions.py:703, in Session.send(self, request, **kwargs)
700 start = preferred_clock()
702 # Send the request
--> 703 r = adapter.send(request, **kwargs)
705 # Total elapsed time of the request (approximately)
706 elapsed = preferred_clock() - start
File ~/opt/miniconda3/lib/python3.9/site-packages/huggingface_hub/utils/_http.py:63, in UniqueRequestIdAdapter.send(self, request, *args, **kwargs)
61 """Catch any RequestException to append request id to the error message for debugging."""
62 try:
---> 63 return super().send(request, *args, **kwargs)
64 except requests.RequestException as e:
65 request_id = request.headers.get(X_AMZN_TRACE_ID)
File ~/opt/miniconda3/lib/python3.9/site-packages/requests/adapters.py:486, in HTTPAdapter.send(self, request, stream, timeout, verify, cert, proxies)
483 timeout = TimeoutSauce(connect=timeout, read=timeout)
485 try:
--> 486 resp = conn.urlopen(
487 method=request.method,
488 url=url,
489 body=request.body,
490 headers=request.headers,
491 redirect=False,
492 assert_same_host=False,
493 preload_content=False,
494 decode_content=False,
495 retries=self.max_retries,
496 timeout=timeout,
497 chunked=chunked,
498 )
500 except (ProtocolError, OSError) as err:
501 raise ConnectionError(err, request=request)
File ~/opt/miniconda3/lib/python3.9/site-packages/urllib3/connectionpool.py:703, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)
700 self._prepare_proxy(conn)
702 # Make the request on the httplib connection object.
--> 703 httplib_response = self._make_request(
704 conn,
705 method,
706 url,
707 timeout=timeout_obj,
708 body=body,
709 headers=headers,
710 chunked=chunked,
711 )
713 # If we're going to release the connection in ``finally:``, then
714 # the response doesn't need to know about the connection. Otherwise
715 # it will also try to release it and we'll have a double-release
716 # mess.
717 response_conn = conn if not release_conn else None
File ~/opt/miniconda3/lib/python3.9/site-packages/urllib3/connectionpool.py:386, in HTTPConnectionPool._make_request(self, conn, method, url, timeout, chunked, **httplib_request_kw)
384 # Trigger any extra validation we need to do.
385 try:
--> 386 self._validate_conn(conn)
387 except (SocketTimeout, BaseSSLError) as e:
388 # Py2 raises this as a BaseSSLError, Py3 raises it as socket timeout.
389 self._raise_timeout(err=e, url=url, timeout_value=conn.timeout)
File ~/opt/miniconda3/lib/python3.9/site-packages/urllib3/connectionpool.py:1042, in HTTPSConnectionPool._validate_conn(self, conn)
1040 # Force connect early to allow us to validate the connection.
1041 if not getattr(conn, "sock", None): # AppEngine might not have `.sock`
-> 1042 conn.connect()
1044 if not conn.is_verified:
1045 warnings.warn(
1046 (
1047 "Unverified HTTPS request is being made to host '%s'. "
(...)
1052 InsecureRequestWarning,
1053 )
File ~/opt/miniconda3/lib/python3.9/site-packages/urllib3/connection.py:358, in HTTPSConnection.connect(self)
356 def connect(self):
357 # Add certificate verification
--> 358 self.sock = conn = self._new_conn()
359 hostname = self.host
360 tls_in_tls = False
File ~/opt/miniconda3/lib/python3.9/site-packages/urllib3/connection.py:174, in HTTPConnection._new_conn(self)
171 extra_kw["socket_options"] = self.socket_options
173 try:
--> 174 conn = connection.create_connection(
175 (self._dns_host, self.port), self.timeout, **extra_kw
176 )
178 except SocketTimeout:
179 raise ConnectTimeoutError(
180 self,
181 "Connection to %s timed out. (connect timeout=%s)"
182 % (self.host, self.timeout),
183 )
File ~/opt/miniconda3/lib/python3.9/site-packages/urllib3/util/connection.py:85, in create_connection(address, timeout, source_address, socket_options)
83 if source_address:
84 sock.bind(source_address)
---> 85 sock.connect(sa)
86 return sock
88 except socket.error as e:
KeyboardInterrupt:
Path patching config#
This is taken from the pyvene 101 tutorial. Basically, we’ll intervene at all positions for a single attention head, and restore the base-input activations for all downstream model components (the MLP output in the same layer, plus the attention and MLP outputs of every later layer). This will get the direct effect of the intervention on the logits.
def path_patching_config(
    layer, last_layer,
    component="head_attention_value_output", unit="h.pos"
):
    """Build the intervention config for path patching a single component.

    The config holds one intervened entry (``group_key`` 0) for ``component``
    at ``layer``, followed by restoring entries (``group_key`` 1):

    * ``mlp_output`` at ``layer`` itself, unless ``component`` starts with
      ``"mlp_"``;
    * ``attention_output`` and ``mlp_output`` for every layer in
      ``range(layer + 1, last_layer)``.

    Returns:
        A tuple ``(config, n_restores)`` where ``config`` is a
        ``pv.IntervenableConfig`` over the intervened entry plus all
        restoring entries, and ``n_restores`` is the number of restoring
        entries.
    """
    patched = {
        "layer": layer, "component": component, "unit": unit, "group_key": 0,
    }

    restored = []
    if not component.startswith("mlp_"):
        # The MLP of the patched layer sits after the patched component.
        restored.append(
            {"layer": layer, "component": "mlp_output", "group_key": 1})
    for later_layer in range(layer + 1, last_layer):
        for restored_component in ("attention_output", "mlp_output"):
            restored.append({
                "layer": later_layer,
                "component": restored_component,
                "group_key": 1,
            })

    return pv.IntervenableConfig([patched] + restored), len(restored)
Dataset + Utils#
Just sampling prompts for the IOI task.
# Prompt generator for the IOI task, built from the name / object / place /
# template pools.
test_distribution = PromptDistribution(
    names=NAMES, objects=OBJECTS, places=PLACES, templates=TEMPLATES,
)

# The two halves of the test set differ only in the base pattern ("ABB" vs
# "BAB"); every other sampling option is shared.
_sampling_options = dict(
    tokenizer=tokenizer,
    source_patterns=["DCE"],
    labels="name",
    samples_per_combination=25,
)
D_test = (
    test_distribution.sample_das(base_patterns=["ABB"], **_sampling_options)
    + test_distribution.sample_das(base_patterns=["BAB"], **_sampling_options)
)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
def get_last_token(logits, attention_mask):
    """Pick, for each sequence in the batch, the logits at its last position.

    The last position of row ``i`` is ``attention_mask[i].sum() - 1``, i.e.
    the number of attended tokens minus one. This is the final real token
    only if padding (mask 0) comes after the real tokens.

    Args:
        logits: tensor indexed as ``[batch, position, ...]``.
        attention_mask: tensor indexed as ``[batch, position]``.

    Returns:
        ``logits`` with the position axis removed: one entry per batch row.
    """
    final_positions = attention_mask.sum(1) - 1
    rows = torch.arange(logits.size(0))
    # Pairing row i with final_positions[i] selects one position per row.
    return logits[rows, final_positions]
# Show the first base/source pair only: the prompt, the shape of its token
# ids, and the decoded answer token stored alongside it.
for batch in D_test.batches(batch_size=1):
    first_answers = batch.patched_answer_tokens[0]
    # Index 1 is printed with the base prompt, index 0 with the source prompt.
    for example, answer_token in (
        (batch.base, first_answers[1]),
        (batch.source, first_answers[0]),
    ):
        print(example, example.tokens['input_ids'].shape, tokenizer.decode(answer_token))
    break
[<===PROMPT=== Then, Justin and Bryan went to the hospital. Justin gave a necklace to>] torch.Size([1, 15]) Bryan
[<===PROMPT=== Then, Courtney and Thomas went to the hospital. Ashley gave a necklace to>] torch.Size([1, 15]) None
We’re also implementing the logit diff metric, which checks the difference in logits between the two names in the sentence. Positive logit diff means a correct prediction is more likely (the IO, i.e. non-subject name).
def compute_logit_diff(logits: torch.Tensor, batch):
    """Return the per-example logit difference between the two names.

    For every example the difference is taken at the last attended position
    of the base prompt: the logit of the label token
    (``batch.patched_answer_tokens[:, 1]``) minus the logit of the first
    token of ``' ' + s_name`` for that prompt. A positive value means the
    label is preferred over the subject name.

    Relies on the notebook globals ``gpt2`` (for the target device) and
    ``tokenizer`` (to encode the subject name).

    Args:
        logits: model output logits for ``batch.base``.
        batch: a batch providing ``base.tokens['attention_mask']``,
            ``base.prompts[i].s_name`` and ``patched_answer_tokens``.

    Returns:
        A list with one scalar tensor per example in the batch.
    """
    base_logit = get_last_token(logits, batch.base.tokens['attention_mask'])
    base_label = batch.patched_answer_tokens[:, 1].to(gpt2.device)
    logit_diffs = []
    for batch_i in range(base_logit.size(0)):
        correct_name = base_label[batch_i]
        # Names are encoded with a leading space; only the first resulting
        # token id is compared.
        other_name = tokenizer.encode(' ' + batch.base.prompts[batch_i].s_name)[0]
        logit_diffs.append(base_logit[batch_i, correct_name] - base_logit[batch_i, other_name])
    return logit_diffs
They reported a baseline logit difference of 3.56 and a task accuracy of 99.3% for GPT-2 in the paper. Let’s check if this holds on our dataset (it pretty much does):
# Baseline (no intervention) logit difference and argmax accuracy on D_test.
eval_batch_size = 5
with torch.no_grad():
    logit_diffs = []
    argmax_acc = 0
    num_examples = 0
    # total=10 only sizes the progress bar (50 prompts / 5 per batch).
    for batch in tqdm(D_test.batches(batch_size=eval_batch_size), total=10):
        base_label = batch.patched_answer_tokens[:, 1].to(gpt2.device)
        base_logit = get_last_token(gpt2(**batch.base.tokens).logits, batch.base.tokens['attention_mask'])
        # The source prompts are not needed for the baseline, so no forward
        # pass is run on them.

        # Iterate over the rows actually present instead of assuming every
        # batch holds exactly `eval_batch_size` examples.
        for batch_i in range(base_logit.size(0)):
            other_name = tokenizer.encode(' ' + batch.base.prompts[batch_i].s_name)[0]
            correct_name = base_label[batch_i]

            # logit diff: label token vs. subject-name token
            logit_diffs.append(base_logit[batch_i, correct_name] - base_logit[batch_i, other_name])

            # baseline accuracy: is the label the top-scoring token?
            if base_logit[batch_i].argmax() == correct_name:
                argmax_acc += 1
            num_examples += 1

# Mean over examples, converted to a Python float for later use.
logit_diff = (sum(logit_diffs) / len(logit_diffs)).item()
print("avg logit diff:", logit_diff)
# Divide by the number of examples seen rather than a hard-coded 50.
print("argmax acc:", argmax_acc / num_examples)
100%|██████████| 10/10 [00:02<00:00, 4.54it/s]
avg logit diff: 3.7562549114227295
argmax acc: 0.94
Name mover heads#
We will replicate figure 3, which identifies heads which directly affect the logits.
# Path-patch every attention head in layers 8-11, one head at a time, and
# record the resulting metrics; `data` gets one row per (layer, head).
data = []
with torch.no_grad():
    for layer in range(8, 12):
        # One intervenable model per layer, reused for all heads of that layer.
        intervenable_config, num_restores = path_patching_config(layer, 12)
        intervenable = IntervenableModel(intervenable_config, gpt2, use_fast=True)

        for head in range(gpt2.config.n_head):
            eval_labels, eval_preds, logit_diffs = [], [], []

            for batch_dataset in tqdm(D_test.batches(batch_size=1), total=50):
                # prepare
                base_ids = batch_dataset.base.tokens["input_ids"]
                source_ids = batch_dataset.source.tokens["input_ids"]
                labels = batch_dataset.patched_answer_tokens[:, 1].to(gpt2.device)
                pos = list(range(base_ids.shape[-1]))

                # One (head, positions) location for the patched head, then a
                # positions-only location for each restored component. The
                # same structure is built separately for the source side and
                # the base side.
                source_locations, base_locations = (
                    [[[[head]], [pos]]] + [[pos]] * num_restores
                    for _ in range(2)
                )

                # inference
                # Two sources are supplied: the source prompt and the base
                # prompt itself (presumably matched to group_key 0 and 1
                # respectively -- confirm against pyvene's grouping rules).
                _, counterfactual_outputs = intervenable(
                    {"input_ids": base_ids},
                    [{"input_ids": source_ids}, {"input_ids": base_ids}],
                    {"sources->base": (source_locations, base_locations)},
                )

                logit_diffs.extend(compute_logit_diff(counterfactual_outputs.logits, batch_dataset))
                eval_labels.append(labels)
                last_token_logits = get_last_token(
                    counterfactual_outputs.logits,
                    batch_dataset.base.tokens['attention_mask'],
                ).unsqueeze(1)
                eval_preds.append(last_token_logits)

            # metrics
            eval_metrics = compute_metrics(eval_preds, eval_labels)
            mean_logit_diff = sum(logit_diffs) / len(logit_diffs)
            data.append({
                "layer": layer,
                "head": head,
                "logit_diff": mean_logit_diff,
                **eval_metrics,
            })
            print(data[-1])
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.
In case multiple location tags are passed only the first one will be considered
100%|██████████| 50/50 [00:07<00:00, 6.68it/s]
{'layer': 8, 'head': 0, 'logit_diff': tensor(3.7504), 'accuracy': 0.94, 'kl_div': tensor(-74.3469), 'label_logit': -74.34694038391113, 'label_prob': 0.5380359962582588}
100%|██████████| 50/50 [00:06<00:00, 7.44it/s]
{'layer': 8, 'head': 1, 'logit_diff': tensor(3.7568), 'accuracy': 0.94, 'kl_div': tensor(-74.4057), 'label_logit': -74.40572700500488, 'label_prob': 0.5395114549994469}
100%|██████████| 50/50 [00:06<00:00, 7.22it/s]
{'layer': 8, 'head': 2, 'logit_diff': tensor(3.8010), 'accuracy': 0.94, 'kl_div': tensor(-74.4281), 'label_logit': -74.42814888000488, 'label_prob': 0.5384641465544701}
100%|██████████| 50/50 [00:06<00:00, 7.38it/s]
{'layer': 8, 'head': 3, 'logit_diff': tensor(3.7179), 'accuracy': 0.92, 'kl_div': tensor(-74.5966), 'label_logit': -74.59663963317871, 'label_prob': 0.5215314196050167}
100%|██████████| 50/50 [00:08<00:00, 5.84it/s]
{'layer': 8, 'head': 4, 'logit_diff': tensor(3.7548), 'accuracy': 0.94, 'kl_div': tensor(-74.4089), 'label_logit': -74.40891792297363, 'label_prob': 0.5398365586996079}
100%|██████████| 50/50 [00:07<00:00, 6.36it/s]
{'layer': 8, 'head': 5, 'logit_diff': tensor(3.7737), 'accuracy': 0.94, 'kl_div': tensor(-74.4948), 'label_logit': -74.49480590820312, 'label_prob': 0.5443936404585838}
100%|██████████| 50/50 [00:08<00:00, 6.24it/s]
{'layer': 8, 'head': 6, 'logit_diff': tensor(3.8182), 'accuracy': 0.94, 'kl_div': tensor(-74.4737), 'label_logit': -74.47373725891113, 'label_prob': 0.5551303905248642}
100%|██████████| 50/50 [00:06<00:00, 7.79it/s]
{'layer': 8, 'head': 7, 'logit_diff': tensor(3.7553), 'accuracy': 0.94, 'kl_div': tensor(-74.4145), 'label_logit': -74.41451263427734, 'label_prob': 0.5396078166365623}
100%|██████████| 50/50 [00:07<00:00, 7.13it/s]
{'layer': 8, 'head': 8, 'logit_diff': tensor(3.7844), 'accuracy': 0.94, 'kl_div': tensor(-74.4013), 'label_logit': -74.40127014160156, 'label_prob': 0.536662351489067}
100%|██████████| 50/50 [00:07<00:00, 7.08it/s]
{'layer': 8, 'head': 9, 'logit_diff': tensor(3.7546), 'accuracy': 0.94, 'kl_div': tensor(-74.4106), 'label_logit': -74.41055702209472, 'label_prob': 0.5401303231716156}
100%|██████████| 50/50 [00:07<00:00, 6.89it/s]
{'layer': 8, 'head': 10, 'logit_diff': tensor(3.5075), 'accuracy': 0.94, 'kl_div': tensor(-74.2388), 'label_logit': -74.23884353637695, 'label_prob': 0.5428883665800095}
100%|██████████| 50/50 [00:06<00:00, 7.37it/s]
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.
In case multiple location tags are passed only the first one will be considered
{'layer': 8, 'head': 11, 'logit_diff': tensor(3.7405), 'accuracy': 0.92, 'kl_div': tensor(-74.6022), 'label_logit': -74.60216705322266, 'label_prob': 0.5177586142718792}
100%|██████████| 50/50 [00:06<00:00, 7.30it/s]
{'layer': 9, 'head': 0, 'logit_diff': tensor(3.7050), 'accuracy': 0.94, 'kl_div': tensor(-74.6208), 'label_logit': -74.620824508667, 'label_prob': 0.5102730639278888}
100%|██████████| 50/50 [00:06<00:00, 7.15it/s]
{'layer': 9, 'head': 1, 'logit_diff': tensor(3.7559), 'accuracy': 0.94, 'kl_div': tensor(-74.4151), 'label_logit': -74.41508010864258, 'label_prob': 0.5392078700661659}
100%|██████████| 50/50 [00:06<00:00, 7.28it/s]
{'layer': 9, 'head': 2, 'logit_diff': tensor(3.7019), 'accuracy': 0.94, 'kl_div': tensor(-74.6157), 'label_logit': -74.61571449279785, 'label_prob': 0.5136157578229904}
100%|██████████| 50/50 [00:08<00:00, 6.16it/s]
{'layer': 9, 'head': 3, 'logit_diff': tensor(3.7431), 'accuracy': 0.94, 'kl_div': tensor(-74.3672), 'label_logit': -74.3671482849121, 'label_prob': 0.5368358224630356}
100%|██████████| 50/50 [00:07<00:00, 6.46it/s]
{'layer': 9, 'head': 4, 'logit_diff': tensor(3.7632), 'accuracy': 0.94, 'kl_div': tensor(-74.4501), 'label_logit': -74.45014938354493, 'label_prob': 0.5293045191466809}
100%|██████████| 50/50 [00:07<00:00, 6.86it/s]
{'layer': 9, 'head': 5, 'logit_diff': tensor(3.7335), 'accuracy': 0.94, 'kl_div': tensor(-74.4766), 'label_logit': -74.47663719177245, 'label_prob': 0.5465028408169746}
100%|██████████| 50/50 [00:08<00:00, 5.85it/s]
{'layer': 9, 'head': 6, 'logit_diff': tensor(2.7914), 'accuracy': 0.7, 'kl_div': tensor(-76.0374), 'label_logit': -76.03742889404298, 'label_prob': 0.29187622375786304}
100%|██████████| 50/50 [00:07<00:00, 7.05it/s]
{'layer': 9, 'head': 7, 'logit_diff': tensor(3.7026), 'accuracy': 0.94, 'kl_div': tensor(-74.3311), 'label_logit': -74.33110733032227, 'label_prob': 0.5488018499314785}
100%|██████████| 50/50 [00:07<00:00, 7.05it/s]
{'layer': 9, 'head': 8, 'logit_diff': tensor(3.7859), 'accuracy': 0.94, 'kl_div': tensor(-74.8128), 'label_logit': -74.81278228759766, 'label_prob': 0.4744531024992466}
100%|██████████| 50/50 [00:07<00:00, 6.61it/s]
{'layer': 9, 'head': 9, 'logit_diff': tensor(1.4911), 'accuracy': 0.22, 'kl_div': tensor(-77.3908), 'label_logit': -77.39078643798828, 'label_prob': 0.12255394758656621}
100%|██████████| 50/50 [00:07<00:00, 6.70it/s]
{'layer': 9, 'head': 10, 'logit_diff': tensor(3.7551), 'accuracy': 0.94, 'kl_div': tensor(-74.3962), 'label_logit': -74.39617370605468, 'label_prob': 0.5393955698609352}
100%|██████████| 50/50 [00:07<00:00, 6.31it/s]
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.
In case multiple location tags are passed only the first one will be considered
{'layer': 9, 'head': 11, 'logit_diff': tensor(3.7567), 'accuracy': 0.94, 'kl_div': tensor(-74.4002), 'label_logit': -74.40017051696778, 'label_prob': 0.5373660787940026}
100%|██████████| 50/50 [00:07<00:00, 6.85it/s]
{'layer': 10, 'head': 0, 'logit_diff': tensor(3.1255), 'accuracy': 0.66, 'kl_div': tensor(-75.9131), 'label_logit': -75.91307647705078, 'label_prob': 0.3100175105780363}
100%|██████████| 50/50 [00:06<00:00, 7.76it/s]
{'layer': 10, 'head': 1, 'logit_diff': tensor(3.6291), 'accuracy': 0.92, 'kl_div': tensor(-74.7467), 'label_logit': -74.74669570922852, 'label_prob': 0.4808885481953621}
100%|██████████| 50/50 [00:06<00:00, 7.71it/s]
{'layer': 10, 'head': 2, 'logit_diff': tensor(3.6425), 'accuracy': 0.88, 'kl_div': tensor(-74.8445), 'label_logit': -74.84453872680665, 'label_prob': 0.4534783412516117}
100%|██████████| 50/50 [00:06<00:00, 7.57it/s]
{'layer': 10, 'head': 3, 'logit_diff': tensor(3.7048), 'accuracy': 0.92, 'kl_div': tensor(-74.5217), 'label_logit': -74.52171600341796, 'label_prob': 0.5153644406795501}
100%|██████████| 50/50 [00:06<00:00, 7.59it/s]
{'layer': 10, 'head': 4, 'logit_diff': tensor(3.7698), 'accuracy': 0.94, 'kl_div': tensor(-74.6806), 'label_logit': -74.68061622619629, 'label_prob': 0.5424765661358834}
100%|██████████| 50/50 [00:06<00:00, 7.19it/s]
{'layer': 10, 'head': 5, 'logit_diff': tensor(3.7580), 'accuracy': 0.94, 'kl_div': tensor(-74.3888), 'label_logit': -74.3887621307373, 'label_prob': 0.5407887950539589}
100%|██████████| 50/50 [00:07<00:00, 7.00it/s]
{'layer': 10, 'head': 6, 'logit_diff': tensor(3.4450), 'accuracy': 0.9, 'kl_div': tensor(-75.0385), 'label_logit': -75.0384928894043, 'label_prob': 0.4376738278567791}
100%|██████████| 50/50 [00:07<00:00, 6.58it/s]
{'layer': 10, 'head': 7, 'logit_diff': tensor(5.1090), 'accuracy': 0.98, 'kl_div': tensor(-72.2116), 'label_logit': -72.21165046691894, 'label_prob': 0.8099988362193108}
100%|██████████| 50/50 [00:07<00:00, 6.89it/s]
{'layer': 10, 'head': 8, 'logit_diff': tensor(3.7607), 'accuracy': 0.94, 'kl_div': tensor(-74.4309), 'label_logit': -74.43088127136231, 'label_prob': 0.5412813138961792}
100%|██████████| 50/50 [00:06<00:00, 7.34it/s]
{'layer': 10, 'head': 9, 'logit_diff': tensor(3.7759), 'accuracy': 0.92, 'kl_div': tensor(-74.4457), 'label_logit': -74.44571556091309, 'label_prob': 0.5376637886464596}
100%|██████████| 50/50 [00:06<00:00, 7.46it/s]
{'layer': 10, 'head': 10, 'logit_diff': tensor(3.1971), 'accuracy': 0.8, 'kl_div': tensor(-75.1706), 'label_logit': -75.17062294006348, 'label_prob': 0.38778901934623716}
100%|██████████| 50/50 [00:06<00:00, 7.67it/s]
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.
In case multiple location tags are passed only the first one will be considered
{'layer': 10, 'head': 11, 'logit_diff': tensor(3.7515), 'accuracy': 0.94, 'kl_div': tensor(-74.4513), 'label_logit': -74.4512720489502, 'label_prob': 0.5345661920309067}
100%|██████████| 50/50 [00:06<00:00, 7.49it/s]
{'layer': 11, 'head': 0, 'logit_diff': tensor(3.7527), 'accuracy': 0.94, 'kl_div': tensor(-74.4928), 'label_logit': -74.49275856018066, 'label_prob': 0.5453835587203503}
100%|██████████| 50/50 [00:06<00:00, 7.53it/s]
{'layer': 11, 'head': 1, 'logit_diff': tensor(3.6884), 'accuracy': 0.9, 'kl_div': tensor(-74.3526), 'label_logit': -74.35264060974121, 'label_prob': 0.5213113659620285}
100%|██████████| 50/50 [00:06<00:00, 7.46it/s]
{'layer': 11, 'head': 2, 'logit_diff': tensor(4.2981), 'accuracy': 0.88, 'kl_div': tensor(-74.8699), 'label_logit': -74.86986404418946, 'label_prob': 0.4666699156165123}
100%|██████████| 50/50 [00:06<00:00, 7.71it/s]
{'layer': 11, 'head': 3, 'logit_diff': tensor(3.7677), 'accuracy': 0.9, 'kl_div': tensor(-74.9174), 'label_logit': -74.91741744995117, 'label_prob': 0.47216988608241084}
100%|██████████| 50/50 [00:06<00:00, 7.60it/s]
{'layer': 11, 'head': 4, 'logit_diff': tensor(3.7532), 'accuracy': 0.94, 'kl_div': tensor(-74.3741), 'label_logit': -74.37413780212403, 'label_prob': 0.5394928365945816}
100%|██████████| 50/50 [00:06<00:00, 7.54it/s]
{'layer': 11, 'head': 5, 'logit_diff': tensor(3.7504), 'accuracy': 0.94, 'kl_div': tensor(-74.3879), 'label_logit': -74.38786392211914, 'label_prob': 0.5406103874742985}
100%|██████████| 50/50 [00:07<00:00, 6.85it/s]
{'layer': 11, 'head': 6, 'logit_diff': tensor(3.7768), 'accuracy': 0.94, 'kl_div': tensor(-74.6573), 'label_logit': -74.65733688354493, 'label_prob': 0.4942662340402603}
100%|██████████| 50/50 [00:07<00:00, 6.83it/s]
{'layer': 11, 'head': 7, 'logit_diff': tensor(3.7518), 'accuracy': 0.94, 'kl_div': tensor(-74.4400), 'label_logit': -74.44003486633301, 'label_prob': 0.5410507157444954}
100%|██████████| 50/50 [00:07<00:00, 6.28it/s]
{'layer': 11, 'head': 8, 'logit_diff': tensor(3.7459), 'accuracy': 0.94, 'kl_div': tensor(-74.9492), 'label_logit': -74.94922843933105, 'label_prob': 0.535863026380539}
100%|██████████| 50/50 [00:08<00:00, 6.13it/s]
{'layer': 11, 'head': 9, 'logit_diff': tensor(3.6901), 'accuracy': 0.94, 'kl_div': tensor(-74.6614), 'label_logit': -74.66142028808594, 'label_prob': 0.4914996309578419}
100%|██████████| 50/50 [00:07<00:00, 6.31it/s]
{'layer': 11, 'head': 10, 'logit_diff': tensor(4.5615), 'accuracy': 0.96, 'kl_div': tensor(-72.9968), 'label_logit': -72.99680526733398, 'label_prob': 0.7122889611124993}
100%|██████████| 50/50 [00:07<00:00, 6.59it/s]
{'layer': 11, 'head': 11, 'logit_diff': tensor(3.7668), 'accuracy': 0.94, 'kl_div': tensor(-76.5733), 'label_logit': -76.57328651428223, 'label_prob': 0.507421883046627}
# Heatmap of each head's change in mean logit difference, relative to the
# unpatched baseline stored in `logit_diff`.
df = pd.DataFrame(data)
# The column holds scalar tensors; turn them into plain Python floats.
df["logit_diff"] = [value.item() for value in df["logit_diff"]]
df["logit_diff_relative"] = df["logit_diff"].sub(logit_diff).div(logit_diff)
# Symmetric colour limits so that zero change sits at the colormap midpoint.
lim = df["logit_diff_relative"].abs().max()
df["formatted"] = df["logit_diff_relative"].map("{:.2f}".format)

plot = ggplot(df, aes(x="head", y="layer", fill="logit_diff_relative"))
for component in (
    geom_tile(),
    scale_fill_cmap("RdBu", limits=(-lim, lim)),
    scale_y_reverse(expand=[0, 0]),
    geom_text(aes(label="formatted"), size=8, color="black"),
    theme_bw(),
    scale_x_continuous(expand=[0, 0]),
):
    plot = plot + component
print(plot)