import torch
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

torch.set_grad_enabled(False)

if __name__ == '__main__':
    model_name = "gpt2-xl"
    cache_dir = '/u/csc485h/fall/pub/tl_models_cache/'
    model = HookedTransformer.from_pretrained(model_name, cache_dir=cache_dir)

    def get_top_answer(logits):
        # Decode the highest-logit token at the final position.
        max_idx = logits[0, -1].argmax()
        return model.to_string(max_idx)

    def run_model(prompt, model):
        prompt_tokens = model.to_tokens(prompt)
        logits, activations = model.run_with_cache(prompt_tokens)
        print('model answer:', prompt, get_top_answer(logits), sep='\n')

    # Shortcuts:
    r = lambda prompt: run_model(prompt, model)

    # r('a b c -> c; d e f -> f; g h i ->')
    # r('a b c -> c; d e f -> f; x y z ->')
    # r('a -> A; b -> B; d ->')

    # Without the ICL demos, of course, the model can't perform the task
    # r('g h i ->')
    # r('x y z ->')
    # r('m q r ->')

    def run_model_with_task_vector(icl_prompt, no_context_prompt, model, layer):
        icl_prompt_tokens = model.to_tokens(icl_prompt)
        logits, activations = model.run_with_cache(icl_prompt_tokens)

        # Now, let's insert the "context vector" (aka the "task vector") into the
        # model's inference process: overwrite the residual stream at the first
        # position of the no-context prompt with the residual stream of the ICL
        # prompt's final token, captured just before the given layer.
        def patch_fn(x, hook):
            x[:, 0, :] = activations[hook.name][:, -1, :]
            return x

        forward_hook = (utils.get_act_name('resid_pre', layer), patch_fn)
        logits = model.run_with_hooks(model.to_tokens(no_context_prompt),
                                      fwd_hooks=[forward_hook])
        print(f'with no context, but with the task vector patched at layer {layer}, '
              'the model answers:',
              no_context_prompt, get_top_answer(logits), sep='\n')

    t = lambda prompt, l: run_model_with_task_vector(
        "a b c -> c; d e f -> f; g h i ->", prompt, model, l)

    # t('x y z ->', 4)
    # t('n o p ->', 4)
    # t('q r s ->', 4)
    # t('m q r ->', 4)
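
    # A minimal sketch added for illustration (not part of the original script):
    # sweep the patch layer to see where, if anywhere, the task vector transfers
    # the ICL behaviour best. `sweep_layers` is a hypothetical helper name; it
    # simply calls run_model_with_task_vector at every layer of the loaded model.
    def sweep_layers(prompt, icl_prompt="a b c -> c; d e f -> f; g h i ->"):
        for layer in range(model.cfg.n_layers):
            run_model_with_task_vector(icl_prompt, prompt, model, layer)

    # sweep_layers('x y z ->')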