Knowledge Neuron Suppression
In this tutorial, we will explore knowledge neuron suppression using a simple code example with a BERT-based transformer model. We’ll suppress a specific neuron in the network and observe how it affects the model’s output when predicting masked tokens.
In particular, we will reproduce the example of suppressing the "plural" neuron that we saw in class.
Table of Contents
- Step 1: Preparation
- Step 2: Tokenizing the Input
- Step 3: Defining the Neuron of Interest
- Step 4: Defining Candidate Tokens
- Step 5: Computing the Original Probabilities
- Step 6: Defining the Patch Function
- Step 7: Running the Model with Hooks
- Step 8: Computing the Suppressed Probabilities
- Step 9: Comparing the Results
- Quiz 8
Step 1: Preparation
We first need to import the necessary libraries to work with the BERT model and PyTorch:
from collections import namedtuple
import torch
from transformer_lens import HookedEncoder
import transformer_lens.utils as utils

torch.set_grad_enabled(False)  # inference only; we never need gradients

# A knowledge neuron is identified by its layer and its index within that layer's MLP
neuron = namedtuple('neuron', ['layer', 'idx'])

model_name = "bert-base-cased"
cache_dir = '/u/csc485h/fall/pub/tl_models_cache/'
model = HookedEncoder.from_pretrained(model_name, cache_dir=cache_dir).to('cuda')
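Since we will be indexing into specific layers and neurons below, it can help to sanity-check the model's dimensions first (a quick sketch; n_layers and d_mlp are standard transformer_lens config attributes):

print(model.cfg.n_layers, model.cfg.d_mlp)  # bert-base has 12 layers with 3072 MLP neurons each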
Step 2: Tokenizing the Input
We create a simple masked sentence and tokenize it:
tokens = model.tokenizer('Raymond is selling [MASK] sketches.', return_tensors='pt')
mask_index = 4
The sentence Raymond is selling [MASK] sketches. is tokenized using the BERT tokenizer, and the [MASK] token represents the word that BERT has to predict. mask_index = 4 refers to the index of the [MASK] token in the input sequence: the tokenizer prepends a [CLS] token at position 0, so the sequence begins [CLS], Raymond, is, selling, [MASK], putting [MASK] at index 4.
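Rather than trusting a hard-coded index, you can locate the [MASK] token programmatically (a minimal sketch, assuming the standard Hugging Face mask_token_id tokenizer attribute):

# Find the [MASK] position instead of hard-coding it
mask_positions = (tokens['input_ids'][0] == model.tokenizer.mask_token_id).nonzero()
assert mask_positions.item() == mask_index  # exactly one [MASK], at index 4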
Step 3: Defining the Neuron of Interest
We choose a specific knowledge neuron, the one at layer 9 with index 1094:
kn = neuron(9, 1094)
This neuron will be the one we “suppress” in order to observe how it affects the model’s behavior.
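Before suppressing it, we can peek at this neuron's activation on our input (a sketch; run_with_cache comes from transformer_lens's hooked models, and mlp_post names the activations just after the MLP's nonlinearity):

_, cache = model.run_with_cache(tokens['input_ids'])
act = cache[utils.get_act_name('mlp_post', kn.layer)][0, mask_index, kn.idx]
print(f'activation of neuron {kn.idx} in layer {kn.layer} at [MASK]: {act.item():.3f}')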
Step 4: Defining Candidate Tokens
We define a list of possible candidate tokens (determiners, in this case) that could replace the [MASK] token:
dets = ['this', 'that', 'these', 'those', 'the', 'some']
det_ids = model.tokenizer.convert_tokens_to_ids(dets)
We convert these tokens into their corresponding token IDs using the BERT tokenizer. These will be used to measure the prediction probabilities before and after neuron suppression.
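Note that this only works because each of these determiners is a single token in BERT's vocabulary; convert_tokens_to_ids silently maps anything out-of-vocabulary to the [UNK] ID. A quick guard (a small sketch):

assert model.tokenizer.unk_token_id not in det_ids  # all candidates are in-vocabulary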
Step 5: Computing the Original Probabilities
We first compute the model’s prediction probabilities for the determiner tokens without any modification:
orig_probs = model(tokens['input_ids'])[0, mask_index].softmax(0)[det_ids]
This line retrieves the logits at the [MASK] position from the model output, applies a softmax to turn them into a probability distribution over the vocabulary, and then selects the probabilities of the candidate determiners by indexing with det_ids.
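It can also be instructive to look at the model's overall top predictions for the masked slot, not just our candidates (a sketch using standard torch and tokenizer calls):

all_probs = model(tokens['input_ids'])[0, mask_index].softmax(0)
top = all_probs.topk(5)
for p, i in zip(top.values, top.indices):
    print(model.tokenizer.convert_ids_to_tokens(i.item()), f'{p.item():.3e}')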
Step 6: Defining the Patch Function
Next, we define a patch function that will zero out the activation of the specific knowledge neuron (layer 9, index 1094):
def patch(x, hook):
    # x holds the MLP activations at the hooked layer, shape [batch, position, neuron];
    # zero out the target neuron at the [MASK] position only
    x[0, mask_index, kn.idx] = 0.
    return x
- x is the tensor of MLP activations at the hooked layer (not the whole model), with shape [batch, position, neuron].
- We set the activation of the knowledge neuron (at the [MASK] token position) to 0.
- This simulates the suppression of the knowledge neuron.
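The same pattern extends to suppressing several neurons at once; here is a sketch (patch_many and neurons_to_zero are hypothetical names, and all the neurons are assumed to live in the same hooked layer):

from functools import partial

def patch_many(x, hook, neurons_to_zero=()):
    # Zero every listed neuron at the [MASK] position
    for n in neurons_to_zero:
        x[0, mask_index, n.idx] = 0.
    return x

# Bind the neuron list before handing the hook to the model, e.g.:
hook_fn = partial(patch_many, neurons_to_zero=[kn])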
Step 7: Running the Model with Hooks
We use the model’s hooks to apply the patch during the forward pass:
logits = model.run_with_hooks(
    tokens['input_ids'],
    fwd_hooks=[(utils.get_act_name('mlp_post', kn.layer), patch)]
)
Here’s what happens:
- We run the model with a forward hook attached to layer 9’s MLP post-activation (mlp_post), i.e., the neuron activations just after the MLP’s nonlinearity.
- The patch function is applied to that activation tensor during the forward pass, zeroing the target neuron before its value can influence the rest of the computation.
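An equivalent way to write this uses the hooks context manager, which attaches the hooks for the duration of the with block and removes them afterwards (a sketch; model.hooks is provided by transformer_lens):

with model.hooks(fwd_hooks=[(utils.get_act_name('mlp_post', kn.layer), patch)]):
    logits = model(tokens['input_ids'])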
Step 8: Computing the Suppressed Probabilities
We compute the prediction probabilities again after suppressing the knowledge neuron:
probs = logits[0, mask_index].softmax(0)[det_ids]
This line is similar to the original probability calculation, but it’s done after modifying the neuron activations.
Step 9: Comparing the Results
Finally, we compare the original and suppressed probabilities for each determiner token:
print('det', 'original', 'after erase kn', 'delta', sep='\t')
for det, orig_prob, prob in zip(dets, orig_probs, probs):
    print(det, f'{orig_prob.item():.3e}', f'{prob.item():.3e}',
          f'{(prob - orig_prob) / orig_prob:+.2%}', sep='\t')
This block prints out:
- The determiner token (det).
- The original probability before neuron suppression.
- The probability after suppressing the neuron.
- The percentage change (delta) in probability, showing how much the suppression of the knowledge neuron affected the prediction.
Quiz 8
In the example code, you explored how suppressing a specific knowledge neuron (layer 9, neuron 1094) affects the prediction probabilities of determiners like “this”, “that”, “these”, “those”, “the”, and “some”.
Now, it’s your turn! Follow the steps below to compute and report the probability delta (the percentage change) for the determiner “a”, using the sentence Raymond is selling [MASK] sketches.
Instructions
- Add the determiner “a” to the list of determiners in the code.
- Compute the original probability of predicting “a” for the masked token.
- Suppress the knowledge neuron (layer 9, neuron 1094) as done in the example.
- Compute the probability of “a” after suppressing the neuron.
- Calculate the probability delta (the percentage change) using the same formula as the example code: delta = (prob_after − prob_orig) / prob_orig, expressed as a percentage.
- Report the delta for the determiner “a”.
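To get you started (this covers only the first instruction; the remaining steps follow the example above), adding the new determiner is a one-line change:

dets = ['this', 'that', 'these', 'those', 'the', 'some', 'a']
det_ids = model.tokenizer.convert_tokens_to_ids(dets)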