Replicating the IOI paper#

__author__ = "Aryaman Arora"
__version__ = "1/24/2023"

Overview#

We’re going to replicate some results of the original IOI paper (Wang et al., 2022) using pyvene, both as a demonstration of path patching and as a check on the paper’s findings.

Setup#

try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyvene

except ModuleNotFoundError:
    !pip install git+https://github.com/frankaging/pyvene.git
import random
import pandas as pd
from tutorial_ioi_utils import *
import pyvene as pv

import matplotlib.pyplot as plt

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    scale_y_reverse,
    scale_fill_cmap,
    geom_text,
    theme_bw,
    xlim,
    ylim,
    scale_x_continuous
)

# please try not to do this; the plots throw spurious warnings otherwise :(
import warnings

warnings.filterwarnings("ignore")

config, tokenizer, gpt2 = pv.create_gpt2_lm(cache_dir="/Users/aryamanarora/.cache/huggingface/hub/")
_ = gpt2.eval()

titles = {
    "block_output": "single restored layer in GPT2-XL",
    "mlp_activation": "center of interval of 10 patched mlp layer",
    "attention_output": "center of interval of 10 patched attn layer",
}

colors = {
    "block_output": "Purples",
    "mlp_activation": "Greens",
    "attention_output": "Reds",
}

Path patching config#

This is taken from the pyvene 101 tutorial. We intervene at all positions of a single attention head (patching in source activations), and restore the base activations of all downstream model components. This isolates the direct effect of that head on the logits: any indirect effect routed through later attention or MLP layers is patched back to its base value.

def path_patching_config(
    layer, last_layer,
    component="head_attention_value_output", unit="h.pos"
):
    # group 0: patch in source activations at one attention head
    intervening_component = [
        {"layer": layer, "component": component, "unit": unit, "group_key": 0}]
    # group 1: restore base activations at everything downstream, so the
    # patched head can only affect the logits directly
    restoring_components = []
    if not component.startswith("mlp_"):
        restoring_components += [
            {"layer": layer, "component": "mlp_output", "group_key": 1}]
    for i in range(layer + 1, last_layer):
        restoring_components += [
            {"layer": i, "component": "attention_output", "group_key": 1},
            {"layer": i, "component": "mlp_output", "group_key": 1},
        ]
    intervenable_config = pv.IntervenableConfig(
        intervening_component + restoring_components)
    return intervenable_config, len(restoring_components)
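
As a quick sanity check (a minimal sketch; layer 9 is an arbitrary choice), we can inspect what this produces:

# patch one head at layer 9; restore layer 9's MLP plus both
# components of layers 10 and 11, i.e. 5 restoring components
config_l9, num_restores = path_patching_config(9, gpt2.config.n_layer)
print(num_restores)  # 5
print(config_l9)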

Dataset + Utils#

Here we sample prompts for the IOI task. Base prompts follow the ABB/BAB name patterns (e.g. “Then, Justin and Bryan went to the hospital. Justin gave a necklace to”), while the corresponding source prompts fill all three name slots with unrelated names (DCE).

test_distribution = PromptDistribution(
    names=NAMES,
    objects=OBJECTS,
    places=PLACES,
    templates=TEMPLATES,
)

D_test = test_distribution.sample_das(
    tokenizer=tokenizer,
    base_patterns=[
        "ABB",
    ],
    source_patterns=["DCE"],
    labels="name",
    samples_per_combination=25,
) + test_distribution.sample_das(
    tokenizer=tokenizer,
    base_patterns=[
        "BAB",
    ],
    source_patterns=["DCE"],
    labels="name",
    samples_per_combination=25,
)
# GPT-2 has no pad token, so reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token

def get_last_token(logits, attention_mask):
    # index each sequence's logits at its final non-padding position
    last_token_indices = attention_mask.sum(1) - 1
    batch_indices = torch.arange(logits.size(0)).unsqueeze(1)
    return logits[batch_indices, last_token_indices.unsqueeze(1)].squeeze(1)
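
As a minimal sketch (toy tensors, not model outputs), here is what get_last_token does with right-padded inputs:

toy_logits = torch.arange(24, dtype=torch.float).reshape(2, 3, 4)  # batch 2, seq 3, vocab 4
toy_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second sequence has one pad token
print(get_last_token(toy_logits, toy_mask))
# rows at positions 2 and 1: [[ 8.,  9., 10., 11.], [16., 17., 18., 19.]]
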
for batch in D_test.batches(batch_size=1):
    print(batch.base, batch.base.tokens['input_ids'].shape, tokenizer.decode(batch.patched_answer_tokens[0][1]))
    print(batch.source, batch.source.tokens['input_ids'].shape, tokenizer.decode(batch.patched_answer_tokens[0][0]))
    break
[<===PROMPT=== Then, Justin and Bryan went to the hospital. Justin gave a necklace to>] torch.Size([1, 15])  Bryan
[<===PROMPT=== Then, Courtney and Thomas went to the hospital. Ashley gave a necklace to>] torch.Size([1, 15])  None

We also implement the logit-diff metric: the difference between the logits of the two names in the sentence. A positive logit diff means the correct answer (the indirect object, i.e. the non-subject name) is more likely.

def compute_logit_diff(logits: torch.Tensor, batch):
    base_logit = get_last_token(logits, batch.base.tokens['attention_mask'])
    base_label = batch.patched_answer_tokens[:, 1].to(gpt2.device)
    logit_diffs = []
    for batch_i in range(base_logit.size(0)):
        correct_name = base_label[batch_i]
        other_name = tokenizer.encode(' ' + batch.base.prompts[batch_i].s_name)[0]
        logit_diffs.append(base_logit[batch_i, correct_name] - base_logit[batch_i, other_name])
    return logit_diffs
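
As a quick sanity check (a minimal sketch on the first test batch; the exact value depends on the sampled prompts), the metric on an unpatched forward pass should come out clearly positive:

batch = next(iter(D_test.batches(batch_size=1)))
with torch.no_grad():
    clean_logits = gpt2(**batch.base.tokens).logits
# one tensor per example; positive means the IO name beats the subject name
print(compute_logit_diff(clean_logits, batch))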

The paper reports a baseline logit difference of 3.56 and a task accuracy of 99.3% for GPT-2 small. Let’s check whether this holds on our dataset (it pretty much does):

with torch.no_grad():
    logit_diffs = []
    argmax_acc = 0
    for batch in tqdm(D_test.batches(batch_size=5), total=10):
        base_label = batch.patched_answer_tokens[:, 1].to(gpt2.device)
        base_logit = get_last_token(gpt2(**batch.base.tokens).logits, batch.base.tokens['attention_mask'])
        for batch_i in range(5):
            correct_name = base_label[batch_i]
            other_name = tokenizer.encode(' ' + batch.base.prompts[batch_i].s_name)[0]

            # logit diff between the IO name and the subject name
            logit_diffs.append(
                (base_logit[batch_i, correct_name] - base_logit[batch_i, other_name]).item())

            # baseline accuracy: is the IO name the argmax prediction?
            if base_logit[batch_i].argmax() == correct_name:
                argmax_acc += 1

    logit_diff = sum(logit_diffs) / len(logit_diffs)
    print("avg logit diff:", logit_diff)
    print("argmax acc:", argmax_acc / len(logit_diffs))
100%|██████████| 10/10 [00:02<00:00,  4.54it/s]
avg logit diff: 3.7562549114227295
argmax acc: 0.94

Name mover heads#

We will now replicate Figure 3 of the paper, which identifies the heads that directly affect the logits.

data = []

with torch.no_grad():
    for layer in range(8, 12):
        intervenable_config, num_restores = path_patching_config(layer, 12)
        intervenable = pv.IntervenableModel(intervenable_config, gpt2, use_fast=True)

        for head in range(gpt2.config.n_head):
            eval_labels, eval_preds, logit_diffs = [], [], []
            for batch_dataset in tqdm(D_test.batches(batch_size=1), total=50):
                # prepare
                base_inputs = batch_dataset.base.tokens
                source_inputs = batch_dataset.source.tokens
                labels = batch_dataset.patched_answer_tokens[:, 1].to(gpt2.device)
                pos = list(range(base_inputs["input_ids"].shape[-1]))

                # inference: group 0 (the head) reads from the source input,
                # group 1 (everything downstream) is restored from the base input
                _, counterfactual_outputs = intervenable(
                    {"input_ids": base_inputs["input_ids"]},
                    [{"input_ids": source_inputs["input_ids"]}, {"input_ids": base_inputs["input_ids"]}],
                    {"sources->base": ((
                        [[[[head]], [pos]]] + [[pos]] * num_restores,
                        [[[[head]], [pos]]] + [[pos]] * num_restores
                    ))}
                )
                logit_diffs.extend(compute_logit_diff(counterfactual_outputs.logits, batch_dataset))
                eval_labels += [labels]
                last_token_logits = get_last_token(counterfactual_outputs.logits, batch_dataset.base.tokens['attention_mask']).unsqueeze(1)
                eval_preds += [last_token_logits]

            # metrics
            eval_metrics = compute_metrics(
                eval_preds, eval_labels,
            )
            mean_logit_diff = sum(logit_diffs) / len(logit_diffs)
            data.append({"layer": layer, "head": head, "logit_diff": mean_logit_diff, **eval_metrics})
            print(data[-1])
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.

In case multiple location tags are passed only the first one will be considered
100%|██████████| 50/50 [00:07<00:00,  6.68it/s]
{'layer': 8, 'head': 0, 'logit_diff': tensor(3.7504), 'accuracy': 0.94, 'kl_div': tensor(-74.3469), 'label_logit': -74.34694038391113, 'label_prob': 0.5380359962582588}
100%|██████████| 50/50 [00:06<00:00,  7.44it/s]
{'layer': 8, 'head': 1, 'logit_diff': tensor(3.7568), 'accuracy': 0.94, 'kl_div': tensor(-74.4057), 'label_logit': -74.40572700500488, 'label_prob': 0.5395114549994469}
100%|██████████| 50/50 [00:06<00:00,  7.22it/s]
{'layer': 8, 'head': 2, 'logit_diff': tensor(3.8010), 'accuracy': 0.94, 'kl_div': tensor(-74.4281), 'label_logit': -74.42814888000488, 'label_prob': 0.5384641465544701}
100%|██████████| 50/50 [00:06<00:00,  7.38it/s]
{'layer': 8, 'head': 3, 'logit_diff': tensor(3.7179), 'accuracy': 0.92, 'kl_div': tensor(-74.5966), 'label_logit': -74.59663963317871, 'label_prob': 0.5215314196050167}
100%|██████████| 50/50 [00:08<00:00,  5.84it/s]
{'layer': 8, 'head': 4, 'logit_diff': tensor(3.7548), 'accuracy': 0.94, 'kl_div': tensor(-74.4089), 'label_logit': -74.40891792297363, 'label_prob': 0.5398365586996079}
100%|██████████| 50/50 [00:07<00:00,  6.36it/s]
{'layer': 8, 'head': 5, 'logit_diff': tensor(3.7737), 'accuracy': 0.94, 'kl_div': tensor(-74.4948), 'label_logit': -74.49480590820312, 'label_prob': 0.5443936404585838}
100%|██████████| 50/50 [00:08<00:00,  6.24it/s]
{'layer': 8, 'head': 6, 'logit_diff': tensor(3.8182), 'accuracy': 0.94, 'kl_div': tensor(-74.4737), 'label_logit': -74.47373725891113, 'label_prob': 0.5551303905248642}
100%|██████████| 50/50 [00:06<00:00,  7.79it/s]
{'layer': 8, 'head': 7, 'logit_diff': tensor(3.7553), 'accuracy': 0.94, 'kl_div': tensor(-74.4145), 'label_logit': -74.41451263427734, 'label_prob': 0.5396078166365623}
100%|██████████| 50/50 [00:07<00:00,  7.13it/s]
{'layer': 8, 'head': 8, 'logit_diff': tensor(3.7844), 'accuracy': 0.94, 'kl_div': tensor(-74.4013), 'label_logit': -74.40127014160156, 'label_prob': 0.536662351489067}
100%|██████████| 50/50 [00:07<00:00,  7.08it/s]
{'layer': 8, 'head': 9, 'logit_diff': tensor(3.7546), 'accuracy': 0.94, 'kl_div': tensor(-74.4106), 'label_logit': -74.41055702209472, 'label_prob': 0.5401303231716156}
100%|██████████| 50/50 [00:07<00:00,  6.89it/s]
{'layer': 8, 'head': 10, 'logit_diff': tensor(3.5075), 'accuracy': 0.94, 'kl_div': tensor(-74.2388), 'label_logit': -74.23884353637695, 'label_prob': 0.5428883665800095}
100%|██████████| 50/50 [00:06<00:00,  7.37it/s]
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.

In case multiple location tags are passed only the first one will be considered
{'layer': 8, 'head': 11, 'logit_diff': tensor(3.7405), 'accuracy': 0.92, 'kl_div': tensor(-74.6022), 'label_logit': -74.60216705322266, 'label_prob': 0.5177586142718792}
100%|██████████| 50/50 [00:06<00:00,  7.30it/s]
{'layer': 9, 'head': 0, 'logit_diff': tensor(3.7050), 'accuracy': 0.94, 'kl_div': tensor(-74.6208), 'label_logit': -74.620824508667, 'label_prob': 0.5102730639278888}
100%|██████████| 50/50 [00:06<00:00,  7.15it/s]
{'layer': 9, 'head': 1, 'logit_diff': tensor(3.7559), 'accuracy': 0.94, 'kl_div': tensor(-74.4151), 'label_logit': -74.41508010864258, 'label_prob': 0.5392078700661659}
100%|██████████| 50/50 [00:06<00:00,  7.28it/s]
{'layer': 9, 'head': 2, 'logit_diff': tensor(3.7019), 'accuracy': 0.94, 'kl_div': tensor(-74.6157), 'label_logit': -74.61571449279785, 'label_prob': 0.5136157578229904}
100%|██████████| 50/50 [00:08<00:00,  6.16it/s]
{'layer': 9, 'head': 3, 'logit_diff': tensor(3.7431), 'accuracy': 0.94, 'kl_div': tensor(-74.3672), 'label_logit': -74.3671482849121, 'label_prob': 0.5368358224630356}
100%|██████████| 50/50 [00:07<00:00,  6.46it/s]
{'layer': 9, 'head': 4, 'logit_diff': tensor(3.7632), 'accuracy': 0.94, 'kl_div': tensor(-74.4501), 'label_logit': -74.45014938354493, 'label_prob': 0.5293045191466809}
100%|██████████| 50/50 [00:07<00:00,  6.86it/s]
{'layer': 9, 'head': 5, 'logit_diff': tensor(3.7335), 'accuracy': 0.94, 'kl_div': tensor(-74.4766), 'label_logit': -74.47663719177245, 'label_prob': 0.5465028408169746}
100%|██████████| 50/50 [00:08<00:00,  5.85it/s]
{'layer': 9, 'head': 6, 'logit_diff': tensor(2.7914), 'accuracy': 0.7, 'kl_div': tensor(-76.0374), 'label_logit': -76.03742889404298, 'label_prob': 0.29187622375786304}
100%|██████████| 50/50 [00:07<00:00,  7.05it/s]
{'layer': 9, 'head': 7, 'logit_diff': tensor(3.7026), 'accuracy': 0.94, 'kl_div': tensor(-74.3311), 'label_logit': -74.33110733032227, 'label_prob': 0.5488018499314785}
100%|██████████| 50/50 [00:07<00:00,  7.05it/s]
{'layer': 9, 'head': 8, 'logit_diff': tensor(3.7859), 'accuracy': 0.94, 'kl_div': tensor(-74.8128), 'label_logit': -74.81278228759766, 'label_prob': 0.4744531024992466}
100%|██████████| 50/50 [00:07<00:00,  6.61it/s]
{'layer': 9, 'head': 9, 'logit_diff': tensor(1.4911), 'accuracy': 0.22, 'kl_div': tensor(-77.3908), 'label_logit': -77.39078643798828, 'label_prob': 0.12255394758656621}
100%|██████████| 50/50 [00:07<00:00,  6.70it/s]
{'layer': 9, 'head': 10, 'logit_diff': tensor(3.7551), 'accuracy': 0.94, 'kl_div': tensor(-74.3962), 'label_logit': -74.39617370605468, 'label_prob': 0.5393955698609352}
100%|██████████| 50/50 [00:07<00:00,  6.31it/s]
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.

In case multiple location tags are passed only the first one will be considered
{'layer': 9, 'head': 11, 'logit_diff': tensor(3.7567), 'accuracy': 0.94, 'kl_div': tensor(-74.4002), 'label_logit': -74.40017051696778, 'label_prob': 0.5373660787940026}
100%|██████████| 50/50 [00:07<00:00,  6.85it/s]
{'layer': 10, 'head': 0, 'logit_diff': tensor(3.1255), 'accuracy': 0.66, 'kl_div': tensor(-75.9131), 'label_logit': -75.91307647705078, 'label_prob': 0.3100175105780363}
100%|██████████| 50/50 [00:06<00:00,  7.76it/s]
{'layer': 10, 'head': 1, 'logit_diff': tensor(3.6291), 'accuracy': 0.92, 'kl_div': tensor(-74.7467), 'label_logit': -74.74669570922852, 'label_prob': 0.4808885481953621}
100%|██████████| 50/50 [00:06<00:00,  7.71it/s]
{'layer': 10, 'head': 2, 'logit_diff': tensor(3.6425), 'accuracy': 0.88, 'kl_div': tensor(-74.8445), 'label_logit': -74.84453872680665, 'label_prob': 0.4534783412516117}
100%|██████████| 50/50 [00:06<00:00,  7.57it/s]
{'layer': 10, 'head': 3, 'logit_diff': tensor(3.7048), 'accuracy': 0.92, 'kl_div': tensor(-74.5217), 'label_logit': -74.52171600341796, 'label_prob': 0.5153644406795501}
100%|██████████| 50/50 [00:06<00:00,  7.59it/s]
{'layer': 10, 'head': 4, 'logit_diff': tensor(3.7698), 'accuracy': 0.94, 'kl_div': tensor(-74.6806), 'label_logit': -74.68061622619629, 'label_prob': 0.5424765661358834}
100%|██████████| 50/50 [00:06<00:00,  7.19it/s]
{'layer': 10, 'head': 5, 'logit_diff': tensor(3.7580), 'accuracy': 0.94, 'kl_div': tensor(-74.3888), 'label_logit': -74.3887621307373, 'label_prob': 0.5407887950539589}
100%|██████████| 50/50 [00:07<00:00,  7.00it/s]
{'layer': 10, 'head': 6, 'logit_diff': tensor(3.4450), 'accuracy': 0.9, 'kl_div': tensor(-75.0385), 'label_logit': -75.0384928894043, 'label_prob': 0.4376738278567791}
100%|██████████| 50/50 [00:07<00:00,  6.58it/s]
{'layer': 10, 'head': 7, 'logit_diff': tensor(5.1090), 'accuracy': 0.98, 'kl_div': tensor(-72.2116), 'label_logit': -72.21165046691894, 'label_prob': 0.8099988362193108}
100%|██████████| 50/50 [00:07<00:00,  6.89it/s]
{'layer': 10, 'head': 8, 'logit_diff': tensor(3.7607), 'accuracy': 0.94, 'kl_div': tensor(-74.4309), 'label_logit': -74.43088127136231, 'label_prob': 0.5412813138961792}
100%|██████████| 50/50 [00:06<00:00,  7.34it/s]
{'layer': 10, 'head': 9, 'logit_diff': tensor(3.7759), 'accuracy': 0.92, 'kl_div': tensor(-74.4457), 'label_logit': -74.44571556091309, 'label_prob': 0.5376637886464596}
100%|██████████| 50/50 [00:06<00:00,  7.46it/s]
{'layer': 10, 'head': 10, 'logit_diff': tensor(3.1971), 'accuracy': 0.8, 'kl_div': tensor(-75.1706), 'label_logit': -75.17062294006348, 'label_prob': 0.38778901934623716}
100%|██████████| 50/50 [00:06<00:00,  7.67it/s]
WARNING:root:Detected use_fast=True means the intervention location will be static within a batch.

In case multiple location tags are passed only the first one will be considered
{'layer': 10, 'head': 11, 'logit_diff': tensor(3.7515), 'accuracy': 0.94, 'kl_div': tensor(-74.4513), 'label_logit': -74.4512720489502, 'label_prob': 0.5345661920309067}
100%|██████████| 50/50 [00:06<00:00,  7.49it/s]
{'layer': 11, 'head': 0, 'logit_diff': tensor(3.7527), 'accuracy': 0.94, 'kl_div': tensor(-74.4928), 'label_logit': -74.49275856018066, 'label_prob': 0.5453835587203503}
100%|██████████| 50/50 [00:06<00:00,  7.53it/s]
{'layer': 11, 'head': 1, 'logit_diff': tensor(3.6884), 'accuracy': 0.9, 'kl_div': tensor(-74.3526), 'label_logit': -74.35264060974121, 'label_prob': 0.5213113659620285}
100%|██████████| 50/50 [00:06<00:00,  7.46it/s]
{'layer': 11, 'head': 2, 'logit_diff': tensor(4.2981), 'accuracy': 0.88, 'kl_div': tensor(-74.8699), 'label_logit': -74.86986404418946, 'label_prob': 0.4666699156165123}
100%|██████████| 50/50 [00:06<00:00,  7.71it/s]
{'layer': 11, 'head': 3, 'logit_diff': tensor(3.7677), 'accuracy': 0.9, 'kl_div': tensor(-74.9174), 'label_logit': -74.91741744995117, 'label_prob': 0.47216988608241084}
100%|██████████| 50/50 [00:06<00:00,  7.60it/s]
{'layer': 11, 'head': 4, 'logit_diff': tensor(3.7532), 'accuracy': 0.94, 'kl_div': tensor(-74.3741), 'label_logit': -74.37413780212403, 'label_prob': 0.5394928365945816}
100%|██████████| 50/50 [00:06<00:00,  7.54it/s]
{'layer': 11, 'head': 5, 'logit_diff': tensor(3.7504), 'accuracy': 0.94, 'kl_div': tensor(-74.3879), 'label_logit': -74.38786392211914, 'label_prob': 0.5406103874742985}
100%|██████████| 50/50 [00:07<00:00,  6.85it/s]
{'layer': 11, 'head': 6, 'logit_diff': tensor(3.7768), 'accuracy': 0.94, 'kl_div': tensor(-74.6573), 'label_logit': -74.65733688354493, 'label_prob': 0.4942662340402603}
100%|██████████| 50/50 [00:07<00:00,  6.83it/s]
{'layer': 11, 'head': 7, 'logit_diff': tensor(3.7518), 'accuracy': 0.94, 'kl_div': tensor(-74.4400), 'label_logit': -74.44003486633301, 'label_prob': 0.5410507157444954}
100%|██████████| 50/50 [00:07<00:00,  6.28it/s]
{'layer': 11, 'head': 8, 'logit_diff': tensor(3.7459), 'accuracy': 0.94, 'kl_div': tensor(-74.9492), 'label_logit': -74.94922843933105, 'label_prob': 0.535863026380539}
100%|██████████| 50/50 [00:08<00:00,  6.13it/s]
{'layer': 11, 'head': 9, 'logit_diff': tensor(3.6901), 'accuracy': 0.94, 'kl_div': tensor(-74.6614), 'label_logit': -74.66142028808594, 'label_prob': 0.4914996309578419}
100%|██████████| 50/50 [00:07<00:00,  6.31it/s]
{'layer': 11, 'head': 10, 'logit_diff': tensor(4.5615), 'accuracy': 0.96, 'kl_div': tensor(-72.9968), 'label_logit': -72.99680526733398, 'label_prob': 0.7122889611124993}
100%|██████████| 50/50 [00:07<00:00,  6.59it/s]
{'layer': 11, 'head': 11, 'logit_diff': tensor(3.7668), 'accuracy': 0.94, 'kl_div': tensor(-76.5733), 'label_logit': -76.57328651428223, 'label_prob': 0.507421883046627}
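
Before plotting, we can already read the outliers off the collected results (a small sketch; heads are labeled layer.head, and logit_diff is the mean over the 50 patched runs):

# rank heads by how far patching them moves the logit diff from baseline
ranked = sorted(data, key=lambda d: abs(d["logit_diff"].item() - logit_diff), reverse=True)
for d in ranked[:5]:
    print(f"{d['layer']}.{d['head']}: {d['logit_diff'].item():.2f} (baseline {logit_diff:.2f})")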

df = pd.DataFrame(data)
df["logit_diff"] = df["logit_diff"].apply(lambda x: x.item())
df["logit_diff_relative"] = (df["logit_diff"] - logit_diff) / logit_diff
lim = df["logit_diff_relative"].abs().max()
df["formatted"] = df["logit_diff_relative"].apply(lambda x: f"{x:.2f}")
plot = (
    ggplot(df, aes(x="head", y="layer", fill="logit_diff_relative")) + geom_tile()
    + scale_fill_cmap("RdBu", limits=(-lim, lim)) + scale_y_reverse(expand=[0, 0])
    + geom_text(aes(label="formatted"), size=8, color="black")
    + theme_bw() + scale_x_continuous(expand=[0, 0])
)
print(plot)
[Figure: heatmap of the relative change in logit difference for each attention head (x-axis) in layers 8–11 (y-axis).]
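
Heads 9.6, 9.9, and 10.0 show the largest drops in logit difference when path-patched, while 10.7 and 11.10 show the largest gains. These match the name mover heads and negative name mover heads identified by Wang et al. (2022), so the replication holds up.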