# Task Vectors

import torch
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

# Turn autograd off globally: the code visible in this script only runs
# forward passes, so no gradients need to be recorded.
torch.set_grad_enabled(False)

if __name__ == '__main__':
    
    # Name of the pretrained checkpoint handed to TransformerLens.
    model_name = "gpt2-xl"
    # Hard-coded absolute path passed as `cache_dir` when loading the model.
    # NOTE(review): presumably a shared cache of downloaded weights -- confirm
    # it is reachable on the machine running this script.
    cache_dir='/u/csc485h/fall/pub/tl_models_cache/'

    # Load the model once; the helper functions below reuse this `model` object.
    model = HookedTransformer.from_pretrained(model_name, cache_dir=cache_dir)

    def get_top_answer(logits):
        """Decode the highest-scoring entry of ``logits[0, -1]`` to text.

        Assumes ``logits`` is laid out as (batch, position, vocab), so that
        ``[0, -1]`` selects the final position of the first batch row --
        TODO confirm against the caller. Decoding uses the enclosing scope's
        ``model``.
        """
        final_position_logits = logits[0, -1]
        best_token = final_position_logits.argmax()
        return model.to_string(best_token)

    def run_model(prompt, model):
        """Run ``model`` on ``prompt`` and print its top next-token answer.

        Args:
            prompt: Text prompt to tokenize and feed to the model.
            model: Model exposing ``to_tokens`` and a forward call that can
                return logits (a ``HookedTransformer`` in this script).

        Returns:
            None. The prompt and the decoded answer are printed.
        """
        prompt_tokens = model.to_tokens(prompt)
        # Only the logits are used here, so do a plain forward pass instead of
        # run_with_cache, which would also store every intermediate activation.
        logits = model(prompt_tokens, return_type="logits")
        print('model answer:', prompt, get_top_answer(logits), sep='\n')

    # Shortcuts:
    def r(prompt):
        """Shorthand for ``run_model(prompt, model)`` with the loaded model."""
        return run_model(prompt, model)

    # r('a b c -> c; d e f -> f; g h i ->')
    # r('a b c -> c; d e f -> f; x y z ->')
    # r('a -> A; b -> B; d ->')

    # Without the ICL demos, of course, the model can't perform the task
    # r('g h i ->')
    # r('x y z ->')
    # r('m q r ->')

    def run_model_with_task_vector(icl_prompt, no_context_prompt, model, layer,
                                   *, source_pos=-1, target_pos=0):
        """Patch a "task vector" from an ICL prompt into a context-free run.

        The ``resid_pre`` activation of block ``layer`` is recorded while
        running ``icl_prompt``. Its slice at ``source_pos`` is then written
        into index ``target_pos`` of the same activation while running
        ``no_context_prompt``, and the patched run's top next-token answer is
        printed.

        Args:
            icl_prompt: Prompt containing the in-context demonstrations.
            no_context_prompt: Query prompt without demonstrations.
            model: Model providing ``to_tokens``, ``run_with_cache`` and
                ``run_with_hooks`` (a ``HookedTransformer`` in this script).
            layer: Index of the block whose ``resid_pre`` activation is patched.
            source_pos: Index along the second axis of the ICL activation to
                copy from (default ``-1``, i.e. the last one).
            target_pos: Index along the second axis of the no-context
                activation to overwrite (default ``0``).

        Returns:
            None. The no-context prompt and the decoded answer are printed.
        """
        hook_name = utils.get_act_name('resid_pre', layer)

        # Only this one activation is read below, so restrict the cache to it
        # rather than storing every activation in the model.
        icl_prompt_tokens = model.to_tokens(icl_prompt)
        _, activations = model.run_with_cache(icl_prompt_tokens, names_filter=hook_name)
        # Assumes the activation is (batch, position, d_model) -- TODO confirm.
        task_vector = activations[hook_name][:, source_pos, :]

        # Now, let's insert the "context vector" aka the "task vector" to the model's inference process
        def patch_fn(x, hook):
            # NOTE(review): with the default target_pos=0 this overwrites the
            # first position (presumably the BOS token if to_tokens prepends
            # one) rather than the last position of the query -- confirm that
            # this is the intended patch location.
            x[:, target_pos, :] = task_vector
            return x

        forward_hook = (hook_name, patch_fn)
        logits = model.run_with_hooks(model.to_tokens(no_context_prompt), fwd_hooks=[forward_hook])

        print(f'with no context, but with the task vector patched at layer {layer}, the model answers:',
            no_context_prompt, get_top_answer(logits), sep='\n')


    def t(prompt, l):
        """Shorthand: patch the task vector of a fixed ICL prompt at layer ``l``."""
        icl_prompt = "a b c -> c; d e f -> f; g h i ->"
        return run_model_with_task_vector(icl_prompt, prompt, model, l)

    # t('x y z ->', 4)
    # t('n o p ->', 4)
    # t('q r s ->', 4)
    # t('m q r ->', 4)