# Task Vectors
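#
# Demo: a model given in-context learning (ICL) demonstrations builds an
# internal representation of the task. We grab that representation (a "task
# vector") from the residual stream of an ICL run and patch it into a run
# with no demonstrations, so the model can perform the task zero-shot.
# (Cf. Hendel et al., 2023, "In-Context Learning Creates Task Vectors".)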
import torch
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
# Inference-only demo: no gradients needed.
torch.set_grad_enabled(False)
if __name__ == '__main__':
    model_name = "gpt2-xl"
    cache_dir = '/u/csc485h/fall/pub/tl_models_cache/'
    model = HookedTransformer.from_pretrained(model_name, cache_dir=cache_dir)
    def get_top_answer(logits):
        # Greedy decoding: the highest-logit token at the final position.
        max_idx = logits[0, -1].argmax()
        return model.to_string(max_idx)
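
    # Optional extra (not used by the demos below): inspect the top-k
    # candidate tokens instead of just the argmax, to gauge how confident
    # the model is in its answer.
    def get_top_k_answers(logits, k=5):
        values, indices = logits[0, -1].topk(k)
        return [model.to_string(i.item()) for i in indices]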
    def run_model(prompt, model):
        # Run the model on the prompt and print its greedy next-token answer.
        prompt_tokens = model.to_tokens(prompt)
        logits = model(prompt_tokens)
        print('model answer:', prompt, get_top_answer(logits), sep='\n')
    # Shortcut:
    r = lambda prompt: run_model(prompt, model)
    # Given a few ICL demonstrations, the model infers the task and
    # completes the pattern:
    # r('a b c -> c; d e f -> f; g h i ->')
    # r('a b c -> c; d e f -> f; x y z ->')
    # r('a -> A; b -> B; d ->')
    # Without the ICL demos, of course, the model can't perform the task:
    # r('g h i ->')
    # r('x y z ->')
    # r('m q r ->')
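
    # Optional sanity check: see how a prompt splits into tokens. The task
    # vector below is read from the activation at the prompt's final token.
    # model.to_str_tokens('a b c -> c; d e f -> f; g h i ->')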
    def run_model_with_task_vector(icl_prompt, no_context_prompt, model, layer):
        icl_prompt_tokens = model.to_tokens(icl_prompt)
        # Run the ICL prompt once and cache all intermediate activations.
        logits, activations = model.run_with_cache(icl_prompt_tokens)

        # Now, let's insert the "context vector" aka the "task vector" into the
        # model's inference process: copy the residual-stream activation from
        # the ICL run's final token into position 0 of the no-context run,
        # where every later position can attend to it.
        def patch_fn(x, hook):
            x[:, 0, :] = activations[hook.name][:, -1, :]
            return x

        forward_hook = (utils.get_act_name('resid_pre', layer), patch_fn)
        logits = model.run_with_hooks(model.to_tokens(no_context_prompt),
                                      fwd_hooks=[forward_hook])
        print(f'with no context, but with the task vector patched at layer {layer}, the model answers:',
              no_context_prompt, get_top_answer(logits), sep='\n')
    # Shortcut: extract the task vector from a fixed ICL prompt and patch it
    # into `prompt` at layer `l`:
    t = lambda prompt, l: run_model_with_task_vector(
        "a b c -> c; d e f -> f; g h i ->", prompt, model, l)
    # t('x y z ->', 4)
    # t('n o p ->', 4)
    # t('q r s ->', 4)
    # t('m q r ->', 4)
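
    # A minimal sketch of a layer sweep. Which layer carries a usable task
    # vector varies by model and task, so try several and compare answers:
    # for layer in range(0, model.cfg.n_layers, 4):
    #     t('x y z ->', layer)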