Knowledge Neuron Suppression

In this tutorial, we will explore knowledge neuron suppression using a simple code example with a BERT-based transformer model. We’ll suppress a specific neuron in the network and observe how it affects the model’s output when predicting masked tokens.

In particular, we will reproduce the example of suppressing the plural neuron $w^{(9)}_{1094}$ as shown in class.

[Figure 4]


Step 1: Preparation

We first need to import the necessary libraries to work with the BERT model and PyTorch:

from collections import namedtuple

import torch
from transformer_lens import HookedEncoder
import transformer_lens.utils as utils

# We only run inference, so gradient tracking can be disabled globally.
torch.set_grad_enabled(False)

# Identify a neuron by its layer and its index within that layer's MLP.
neuron = namedtuple('neuron', ['layer', 'idx'])

model_name = "bert-base-cased"
cache_dir = '/u/csc485h/fall/pub/tl_models_cache/'

model = HookedEncoder.from_pretrained(model_name, cache_dir=cache_dir).to('cuda')
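
Before moving on, it can help to sanity-check the loaded model's dimensions against the neuron coordinates we'll use. This is a minimal sketch assuming TransformerLens exposes the model configuration as model.cfg (bert-base-cased has 12 layers, each with a 3,072-wide MLP):

# Confirm the neuron coordinates used below are in range.
print(model.cfg.n_layers)  # 12 layers
print(model.cfg.d_mlp)     # 3072 neurons per MLP layer
assert 9 < model.cfg.n_layers and 1094 < model.cfg.d_mlp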

Step 2: Tokenizing the Input

We create a simple masked sentence and tokenize it:

tokens = model.tokenizer('Raymond is selling [MASK] sketches.', return_tensors='pt')
# Token positions: [CLS] Raymond is selling [MASK] sketches . [SEP],
# so the [MASK] token sits at index 4.
mask_index = 4
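
Rather than hard-coding the mask position, you can verify it by decoding the token IDs. A quick sketch using the tokenizer's standard Hugging Face methods:

# Decode the IDs to see where [MASK] landed.
id_list = tokens['input_ids'][0].tolist()
print(model.tokenizer.convert_ids_to_tokens(id_list))
# expected: ['[CLS]', 'Raymond', 'is', 'selling', '[MASK]', 'sketches', '.', '[SEP]']

# Or locate it programmatically:
mask_index = id_list.index(model.tokenizer.mask_token_id)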

Step 3: Defining the Neuron of Interest

We choose a specific knowledge neuron in layer 9 and neuron index 1094:

kn = neuron(9, 1094)

This neuron will be the one we “suppress” in order to observe how it affects the model’s behavior.

Step 4: Defining Candidate Tokens

We define a list of possible candidate tokens (determiners in this case) that could replace the [MASK] token:

dets = ['this', 'that', 'these', 'those', 'the', 'some']
det_ids = model.tokenizer.convert_tokens_to_ids(dets)

We convert these tokens into their corresponding token IDs using the BERT tokenizer. These will be used to measure the prediction probabilities before and after neuron suppression.
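
One caveat worth checking (a small sanity check, not part of the original example): convert_tokens_to_ids maps any word that isn't a single token in BERT's vocabulary to [UNK], which would silently corrupt the comparison. All six determiners here are single tokens, but it's cheap to verify:

# None of the determiners should have been mapped to [UNK].
assert model.tokenizer.unk_token_id not in det_ids
print(dict(zip(dets, det_ids)))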

Step 5: Computing the Original Probabilities

We first compute the model’s prediction probabilities for the determiner tokens without any modification:

orig_probs = model(tokens['input_ids'])[0, mask_index].softmax(0)[det_ids]

This line runs the model, takes the logits at the mask position, applies a softmax over the vocabulary to turn them into probabilities, and selects the entries corresponding to the candidate determiners.
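
For readability, here is the same one-liner unpacked into separate steps; this assumes, as above, that the HookedEncoder forward pass returns masked-language-model logits of shape [batch, position, vocab]:

logits = model(tokens['input_ids'])        # [batch, position, vocab]
mask_logits = logits[0, mask_index]        # logits at the [MASK] position
vocab_probs = mask_logits.softmax(dim=-1)  # probabilities over the whole vocabulary
orig_probs = vocab_probs[det_ids]          # keep only the candidate determiners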

Step 6: Defining the Patch Function

Next, we define a patch function that will zero out the activation of the specific knowledge neuron (layer 9, index 1094):

def patch(x, hook):
    # x holds the post-activation MLP outputs, shape [batch, position, d_mlp].
    # Zero the knowledge neuron's activation at the [MASK] position only.
    x[0, mask_index, kn.idx] = 0.
    return x

Step 7: Running the Model with Hooks

We use the model’s hooks to apply the patch during the forward pass:

logits = model.run_with_hooks(
    tokens['input_ids'],
    fwd_hooks = [(utils.get_act_name('mlp_post', kn.layer), patch)]
)

Here’s what happens: run_with_hooks performs an ordinary forward pass, but attaches the patch function to the activation named by utils.get_act_name('mlp_post', kn.layer), the post-activation output of layer 9’s MLP. As that activation is computed, patch zeroes out neuron 1094 at the mask position, and the rest of the forward pass proceeds with the suppressed value, so the returned logits reflect the intervention.
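
If you’re curious which hook point this name resolves to, you can print it; in the TransformerLens versions we’ve used, 'mlp_post' is an alias for the MLP’s hook_post activation, though the exact string may vary by version:

print(utils.get_act_name('mlp_post', kn.layer))
# e.g. 'blocks.9.mlp.hook_post'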

Step 8: Computing the Suppressed Probabilities

We compute the prediction probabilities again after suppressing the knowledge neuron:

probs = logits[0, mask_index].softmax(0)[det_ids]

This mirrors the original probability calculation, but it is computed from the logits produced with the neuron’s activation zeroed out.

Step 9: Comparing the Results

Finally, we compare the original and suppressed probabilities for each determiner token:

print('det', 'original', 'after erase kn', 'delta', sep='\t')
for det, orig_prob, prob in zip(dets, orig_probs, probs):
    print(det, f'{orig_prob.item():.3e}', f'{prob.item():.3e}', f'{(prob-orig_prob)/(orig_prob):+.2%}', sep='\t')

This block prints out a tab-separated table with one row per determiner: the original probability, the probability after erasing the knowledge neuron, and the relative change (delta) between the two.

Quiz 8

In the example code, you explored how suppressing a specific knowledge neuron (layer 9, neuron 1094) affects the prediction probabilities of determiners like “this”, “that”, “these”, “those”, “the”, and “some”.

Now, it’s your turn! Follow the steps below to compute and report the probability delta (the percentage change) for the determiner “a” using the sentence “Raymond is selling [MASK] sketches.”

Instructions

  1. Add the determiner “a” to the list of determiners in the code.
  2. Compute the original probability of predicting “a” for the masked token.
  3. Suppress the knowledge neuron (layer 9, neuron 1094) as done in the example.
  4. Compute the probability of “a” after suppressing the neuron.
  5. Calculate the probability delta (the percentage change) using the formula:
    $\text{delta} = \frac{\text{probability after suppression} - \text{original probability}}{\text{original probability}} \times 100$
  6. Report the delta for the determiner “a”. (A sketch of the required code changes follows below.)
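
If you want a starting point, the required changes to the example code are minimal. The sketch below reuses the variables defined earlier; the actual delta is for you to run and report:

dets = ['this', 'that', 'these', 'those', 'the', 'some', 'a']
det_ids = model.tokenizer.convert_tokens_to_ids(dets)

orig_probs = model(tokens['input_ids'])[0, mask_index].softmax(0)[det_ids]
logits = model.run_with_hooks(
    tokens['input_ids'],
    fwd_hooks=[(utils.get_act_name('mlp_post', kn.layer), patch)]
)
probs = logits[0, mask_index].softmax(0)[det_ids]

# Relative change for 'a' (the last entry in dets), as a percentage:
delta_a = (probs[-1] - orig_probs[-1]) / orig_probs[-1] * 100
print(f'{delta_a.item():+.2f}%')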