TransformerLens: A Crash Course
In A2, Question 3, you’ll be using transformer_lens to explore the mechanics of language models at a deeper level. While transformer_lens shares similarities with Hugging Face, allowing you to load pre-trained models and generate predictions, it offers a unique advantage: hook points. These hook points let you intercept and modify the model’s computations during inference, giving you greater control over, and insight into, the model’s internal processes.
Basic LM Functions
Step 1: Load a Pre-Trained Model
Load a pre-trained model using transformer_lens. Here, we use the gpt2-xl model. The process is very similar to Hugging Face and Quiz 6. Note that you don’t need to load the tokenizer separately: HookedTransformer handles that for you.
import torch
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
# Load the GPT-2 model
model_name = "gpt2-xl"
cache_dir='/u/csc485h/fall/pub/tl_models_cache/'
model = HookedTransformer.from_pretrained(model_name, cache_dir=cache_dir).to('cuda')
Step 2: Prepare the Prompt and Generate Top 5 Next Tokens
Now, you can tokenize the prompt and pass it through the model. Use model.to_tokens() to convert the prompt to tokens, then call the model on those tokens to get the logits.
# Define prompt and convert it to tokens
prompt = "The CN Tower is located in the city of"
tokens = model.to_tokens(prompt)
# Generate logits
with torch.no_grad():
    outputs = model(tokens)
logits = outputs[0, -1, :]  # Get logits for the last token position
# Get the top 5 tokens and their scores
top_k = 5
top_logits, top_indices = torch.topk(logits, top_k)
top_tokens = [model.tokenizer.decode([token_id.item()]) for token_id in top_indices]
# Print top 5 tokens and their scores
print("Top 5 Tokens:")
for i, token in enumerate(top_tokens):
    print(f"{i + 1}: {token} (logit: {top_logits[i].item()})")
You should now get the same results as in Quiz 6.
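The loop above prints raw logits. To read them as probabilities, apply a softmax over the vocabulary. The following is a minimal standard-library sketch of that conversion on a few made-up logit values; the toy vocabulary, the numbers, and the helper name `softmax` are illustrative assumptions, not model output. On the real `logits` tensor, `torch.softmax(logits, dim=-1)` performs the same computation over the full vocabulary.

```python
import math

def softmax(logits):
    """Convert a list of raw logits into probabilities that sum to 1."""
    # Subtract the max logit first for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits for a tiny 4-token vocabulary (illustrative only).
toy_vocab = [" Toronto", " Ottawa", " Montreal", " Vancouver"]
toy_logits = [4.0, 2.0, 1.0, 0.5]

probs = softmax(toy_logits)
ranked = sorted(zip(toy_vocab, probs), key=lambda pair: pair[1], reverse=True)
for rank, (token, p) in enumerate(ranked, start=1):
    print(f"{rank}: {token!r} (prob: {p:.3f})")
```

Softmax preserves the ordering of the logits, so the top 5 tokens by logit are also the top 5 by probability. Note that the probabilities are only meaningful when the softmax is taken over the whole vocabulary, not over the top 5 logits alone.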
Cache Model Activations and Intercept Inference
In this section, we’ll cover two key operations made possible by transformer_lens:
- Caching model activations during inference using run_with_cache().
- Modifying specific activations to intercept the inference process of the language model, such as altering the embedding of a particular phrase in a prompt or changing the patterns of particular attention heads.
Example 1: Cache All Activation Values with run_with_cache()
The run_with_cache method in transformer_lens allows you to store the intermediate activation values for each layer and then access them after running a prompt through the model.
# Run the model with caching
# We are using the same tokens and prompts
with torch.no_grad():
    outputs, cache = model.run_with_cache(tokens)
# Example: Accessing the attention pattern from layer 3
layer_3_activations = cache[utils.get_act_name('attn', 3)]
print("Attention Pattern at Layer 3:", layer_3_activations)
This code snippet stores all activation values as the model processes the input. You can then access any layer’s activations by indexing the cache. For example, cache[utils.get_act_name('attn', 3)] gives the attention pattern (the post-softmax attention weights) at layer 3.
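In TransformerLens naming, `'attn'` is an alias for the attention pattern, that is, the attention weights after the softmax, and the cached tensor has shape `[batch, n_heads, query_pos, key_pos]`. As a rough illustration of what one head's slice contains, the standard-library sketch below turns a made-up matrix of raw scores into a causal attention pattern. The score values and the helper name `causal_pattern` are assumptions for illustration, not TransformerLens code.

```python
import math

def causal_pattern(scores):
    """Row-wise softmax where query i may only attend to keys 0..i."""
    pattern = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]                 # causal mask: drop future keys
        m = max(visible)
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Future positions get exactly zero weight.
        pattern.append(weights + [0.0] * (len(row) - i - 1))
    return pattern

# Made-up pre-softmax scores for a 3-token prompt (illustrative only).
toy_scores = [
    [0.0, 9.0, 9.0],
    [1.0, 1.0, 9.0],
    [2.0, 0.0, 1.0],
]
toy_pattern = causal_pattern(toy_scores)
for row in toy_pattern:
    print([round(w, 3) for w in row])
```

Each row sums to 1 and every entry above the diagonal is 0. The same two properties are a quick sanity check when printing one head of the cached layer-3 pattern.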
Example 2: Modify the Embedding of “The CN Tower” During Inference
Now, let’s intercept the model’s inference process and corrupt a specific embedding. Suppose we want to modify the embedding for “The CN Tower” in the prompt, “The CN Tower is located in the city of”.
- Step 1: Define a corruption function to replace the embedding with a corrupted version.
- Step 2: Use run_with_hooks to insert the corrupted embedding at the embedding layer.
Here’s a sample implementation:
# Define the corruption function
def corrupt_embedding(x, hook):
    # Zero out positions 1, 2, 3, which correspond to "The CN Tower"
    # (position 0 is the BOS token that to_tokens() prepends by default)
    x[:, [1, 2, 3], :] = 0
    return x
# Run the model with the corrupted embedding
with torch.no_grad():
    corrupt_outputs = model.run_with_hooks(
        tokens,
        fwd_hooks=[(utils.get_act_name("embed"), corrupt_embedding)],
    )
corrupt_logits = corrupt_outputs[0, -1, :]
# Get the top 5 tokens and their scores
top_logits, top_indices = torch.topk(corrupt_logits, 5)
top_tokens = [model.tokenizer.decode([token_id.item()]) for token_id in top_indices]
# Print top 5 tokens and their scores
print("Top 5 Tokens After Corruption:")
for i, token in enumerate(top_tokens):
    print(f"{i + 1}: {token} (logit: {top_logits[i].item()})")
In this example, we:
- Define the corrupt_embedding function, which replaces the embedding for “The CN Tower” with zeros.
- Use run_with_hooks to run the model with this hook, corrupting the specific embedding during inference.
Two notes on the corruption method:
- Typically, embeddings are corrupted by adding random noise to simulate realistic perturbations. This is also the case in A2 Q3.
- For Quiz 7, however, we set the embedding values directly to zero to eliminate randomness. While this isn’t standard practice, it provides a simple, consistent approach suitable for a practice quiz question.
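The contrast between zeroing and noising can be sketched without a model. The standard-library example below treats an embedding as a list of per-position vectors and applies both corruptions to the same positions. The function names, the toy values, and the noise scale `std=0.5` are assumptions chosen for illustration; they are not taken from A2 Q3. Inside a real hook, the same idea would be written with tensor operations on `x`.

```python
import random

def zero_positions(embeddings, positions):
    """Return a copy with the vectors at `positions` set to all zeros."""
    out = [list(vec) for vec in embeddings]
    for p in positions:
        out[p] = [0.0] * len(out[p])
    return out

def noise_positions(embeddings, positions, std, seed):
    """Return a copy with Gaussian noise added to the vectors at `positions`."""
    rng = random.Random(seed)          # fixed seed makes the noise repeatable
    out = [list(vec) for vec in embeddings]
    for p in positions:
        out[p] = [v + rng.gauss(0.0, std) for v in out[p]]
    return out

# Toy "embeddings": 5 positions x 3 dimensions (illustrative values only).
toy_embed = [[float(pos + dim) for dim in range(3)] for pos in range(5)]
subject_positions = [1, 2, 3]          # same positions as in Example 2

zeroed = zero_positions(toy_embed, subject_positions)
noised = noise_positions(toy_embed, subject_positions, std=0.5, seed=0)

print("zeroed:", zeroed)
print("noised:", [[round(v, 2) for v in vec] for vec in noised])
```

Zeroing always gives the same result, which suits a quiz with one correct answer. Noise gives a different corrupted run for each seed, so results are usually averaged over several runs.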
Summary
With transformer_lens, you can:
- Cache activations using run_with_cache, allowing you to inspect and analyze activations at different layers.
- Corrupt specific embeddings during inference using run_with_hooks, enabling targeted modifications to see how they affect model predictions. This is particularly useful for exploring the effects of specific tokens or phrases on the model’s output.
Quiz 7
What are the top 5 most probable tokens after corruption? Select all five.
More Explanations
1. run_with_cache Caches All Residual Stream Values
When you use run_with_cache, it not only captures activations from attention and MLP layers but also stores all intermediate values in the residual stream across the model’s layers. This is extremely useful for later work involving activation patching, where you replace or modify certain intermediate values in the residual stream to observe how the change influences the model’s output. With these cached values, you can revisit any part of the residual stream for experiments, such as selectively patching layers or specific components in a layer.
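As a rough picture of the patching workflow described above, the sketch below uses a toy three-layer "model" built from plain Python functions rather than a real transformer. It caches every intermediate value on a clean run, then reruns a corrupted input while restoring one cached value. All names here (`toy_layers`, `run`, `patch`) are hypothetical and only mimic the idea behind `run_with_cache` and `run_with_hooks`.

```python
# Toy "model": each layer adds its result to a running scalar, loosely
# mimicking how blocks write into a residual stream.
toy_layers = [
    lambda r: r + 2 * r,      # layer 0
    lambda r: r + 1,          # layer 1
    lambda r: r + r % 3,      # layer 2
]

def run(x, patch=None):
    """Run the toy model and return (output, cache).

    `patch` is an optional {layer_index: value} dict. After the given layer
    runs, its output is replaced by the supplied value.
    """
    cache = {}
    r = x
    for i, layer in enumerate(toy_layers):
        r = layer(r)
        if patch is not None and i in patch:
            r = patch[i]                  # overwrite with the patched value
        cache[f"resid_post.{i}"] = r      # store the value after layer i
    return r, cache

clean_out, clean_cache = run(4)
corrupt_out, _ = run(7)

# Restore the clean value after layer 1 while running the corrupted input.
patched_out, _ = run(7, patch={1: clean_cache["resid_post.1"]})

print(clean_out, corrupt_out, patched_out)
```

In this toy, everything after layer 1 depends only on the value at that point, so the patch restores the clean output completely. In a real transformer, attention lets later layers read from other positions, so patching a single layer and position usually recovers only part of the clean behaviour.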
2. More Exploration: Beyond Attention Patterns
In Example 1, we focused on examining attention patterns at a particular layer. But you can explore much more than just the attention mechanism. With utils.get_act_name, you can access a variety of other activations, including the residual stream, MLP output, layer norms, and more. Here are a few examples to give you an idea:
from transformer_lens import utils
# Sample act_names for different components at layer 3
resid_post_name = utils.get_act_name("resid_post", layer=3)
# Layer norm hooks live inside the ln1 module, so pass it as layer_type
ln1_output_name = utils.get_act_name("normalized", layer=3, layer_type="ln1")
# "mlp_out" names the output of the MLP sublayer
mlp_name = utils.get_act_name("mlp_out", layer=3)
print("Residual Post Name:", resid_post_name)
print("Layer Norm 1 Name:", ln1_output_name)
print("MLP Output Name:", mlp_name)
Feel free to modify the code in Example 1 to explore these different activations. For example, you can access cache[resid_post_name] to see the residual stream at the end of block 3, after both the attention and MLP sublayers have written to it. Experimenting with these different activation points will help you build a more comprehensive understanding of the model’s inner workings and how the components interact in the forward pass.
3. Understanding fwd_hooks
The forward hooks (fwd_hooks) are at the core of how transformer_lens allows you to intercept and manipulate the model’s inference process. fwd_hooks is a list of tuples. Each tuple consists of an activation name (act_name) and a corresponding hook function (hook_function) to be applied at that point in the model.
Each tuple in fwd_hooks has the form:
("act_name", hook_function)
When the model reaches the location specified by act_name, hook_function is called with the current activation value and the hook point object. The function can modify the activation in place or return a replacement tensor before the model proceeds to the next computation.
For example:
def hook_function(value, hook):
    # Perform your desired manipulation here
    modified_value = value[...]
    return modified_value
# List of forward hooks
# There can be more than one hook, and each entry can use its own function
fwd_hooks = [
    (utils.get_act_name("resid_post", layer=3), hook_function),
    (utils.get_act_name("normalized", layer=3, layer_type="ln1"), hook_function),
    (utils.get_act_name("mlp_out", layer=3), hook_function),
]
intercepted_outputs = model.run_with_hooks(tokens, fwd_hooks=fwd_hooks)
When you run the model with run_with_hooks(tokens, fwd_hooks=fwd_hooks), it will process the tokens as usual but will call each hook function when it reaches the corresponding activation point. The hook function receives the current activation value and can modify it in place before passing it along. This approach allows for precise control over the forward pass, enabling you to selectively manipulate specific layers and components.
By adding multiple tuples to fwd_hooks, you can specify various points in the model to observe or modify activations, making it an invaluable tool for targeted analysis and experimentation.
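To make the dispatch concrete, here is a minimal sketch of a model with named hook points, written in plain Python. It is a toy stand-in under stated assumptions, not how TransformerLens is implemented: the class `ToyHookedModel`, its hook names, and its arithmetic "layers" are all hypothetical, and the second hook argument is just the hook point's name rather than a hook point object. It follows the convention that a hook's return value replaces the activation, while returning `None` leaves it unchanged.

```python
class ToyHookedModel:
    """A two-step pipeline with a named hook point after each step."""

    def __init__(self):
        # (hook point name, computation) pairs, applied in order.
        self.steps = [
            ("embed", lambda v: [2.0 * x for x in v]),
            ("mlp_out", lambda v: [x + 1.0 for x in v]),
        ]

    def run_with_hooks(self, inputs, fwd_hooks=()):
        # Group hook functions by the name of the point they attach to.
        hooks_by_name = {}
        for name, fn in fwd_hooks:
            hooks_by_name.setdefault(name, []).append(fn)

        value = list(inputs)
        for name, compute in self.steps:
            value = compute(value)
            for fn in hooks_by_name.get(name, []):
                result = fn(value, name)
                if result is not None:      # a returned value replaces the activation
                    value = result
        return value


seen = {}

def record(value, hook_name):
    """Observe only: copy the activation and return None."""
    seen[hook_name] = list(value)

def zero_first(value, hook_name):
    """Modify: return a new activation with position 0 zeroed."""
    return [0.0] + list(value[1:])

toy_model = ToyHookedModel()
plain = toy_model.run_with_hooks([1.0, 2.0])
hooked = toy_model.run_with_hooks(
    [1.0, 2.0],
    fwd_hooks=[("embed", zero_first), ("embed", record), ("mlp_out", record)],
)
print(plain, hooked, seen)
```

The `record` hook shows the observe-only use of a hook, which is essentially what caching does at every hook point, and `zero_first` shows the modifying use from Example 2. Hooks attached to the same point run in the order they appear in the list, so `record` sees the already-zeroed value.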