diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index 489b43310..7bf631b5c 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -230,7 +230,7 @@ jobs:
- "BERT"
- "Exploratory_Analysis_Demo"
# - "Grokking_Demo"
- # - "Head_Detector_Demo"
+ - "Head_Detector_Demo"
# - "Interactive_Neuroscope"
# - "LLaMA"
# - "LLaMA2_GPU_Quantized"
diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb
index b6d7e09da..6f27d1af0 100644
--- a/demos/Head_Detector_Demo.ipynb
+++ b/demos/Head_Detector_Demo.ipynb
@@ -1,2659 +1,2662 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
- "
\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "YquKKgs17NOv"
- },
- "source": [
- "# TransformerLens Head Detector Demo"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wKW2CqN-yZuY"
- },
- "source": [
- "A common technique in mechanistic interpretability of transformer-based neural networks is identification of specialized attention heads, based on the attention patterns elicited by one or more prompts. The most basic examples of such heads are: previous token head, duplicate token head, or induction head ([more info](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_Jzi6YHRHKP1JziwdE02qdYZ)). Usually, such heads are identified manually, by through visualizations of attention patterns layer by layer, head by head, and trying to recognize the patterns by eye.\n",
- "\n",
- "The purpose of the `TransformerLens.head_detector` feature is to automate a part of that workflow. The pattern characterizing a head of particular type/function is specified as a `Tensor` being a `seq_len x seq_len` [lower triangular matrix](https://en.wikipedia.org/wiki/Triangular_matrix). It can be either passed to the `detect_head` function directly or by giving a string identifying of several pre-defined detection patterns."
- ]
- },
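For intuition, a previous-token head's detection pattern is a lower triangular matrix with ones on the first sub-diagonal: each query position should attend to the position just before it. The NumPy sketch below mirrors the *idea* behind scoring a head against such a pattern (fraction of attention mass that lands inside the pattern); it is an illustration, not `detect_head`'s exact formula.

```python
import numpy as np

seq_len = 5

# Detection pattern for a previous-token head: ones on the first
# sub-diagonal, so it is lower triangular as detect_head requires.
detection_pattern = np.eye(seq_len, k=-1)

# Toy attention pattern for a head that attends perfectly to the
# previous token (position 0 has no previous token, so it attends
# to itself).
attn = detection_pattern.copy()
attn[0, 0] = 1.0

# Fraction of the head's attention mass falling inside the pattern.
score = (attn * detection_pattern).sum() / attn.sum()
print(score)
```

A head whose attention exactly matches the pattern everywhere except the unavoidable first row scores 4/5 = 0.8 here; an unrelated head would score near zero.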
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3a53LkPTAjzB"
- },
- "source": [
- "## How to use this notebook\n",
- "\n",
- "Go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n",
- "\n",
- "Tips for reading this Colab:\n",
- "\n",
- "* You can run all this code for yourself!\n",
- "* The graphs are interactive!\n",
- "* Use the table of contents pane in the sidebar to navigate\n",
- "* Collapse irrelevant sections with the dropdown arrows\n",
- "* Search the page using the search in the sidebar, not CTRL+F"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "nCWImh1S7fNx"
- },
- "source": [
- "## Setup (Ignore)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "4LZeYL3XAc7T",
- "outputId": "680da02d-5ca8-4ab3-bc24-f2827f0fcd95"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Running as a Colab notebook\n",
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Collecting git+https://github.com/TransformerLensOrg/TransformerLens.git\n",
- " Cloning https://github.com/TransformerLensOrg/TransformerLens.git to /tmp/pip-req-build-v3x96q_b\n",
- " Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/TransformerLens.git /tmp/pip-req-build-v3x96q_b\n",
- " Resolved https://github.com/TransformerLensOrg/TransformerLens.git to commit 0ffcc8ad647d9e991f4c2596557a9d7475617773\n",
- " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "Requirement already satisfied: datasets>=2.7.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.12.0)\n",
- "Requirement already satisfied: einops>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.6.1)\n",
- "Requirement already satisfied: fancy-einsum>=0.0.3 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.0.3)\n",
- "Requirement already satisfied: jaxtyping>=0.2.11 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.2.15)\n",
- "Requirement already satisfied: numpy>=1.23 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (1.24.3)\n",
- "Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (1.5.3)\n",
- "Requirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (13.3.4)\n",
- "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.0.0+cu118)\n",
- "Requirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.65.0)\n",
- "Requirement already satisfied: transformers>=4.25.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.28.1)\n",
- "Requirement already satisfied: wandb>=0.13.5 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.15.0)\n",
- "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (9.0.0)\n",
- "Requirement already satisfied: dill<0.3.7,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.3.6)\n",
- "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (2.27.1)\n",
- "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.2.0)\n",
- "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.70.14)\n",
- "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (2023.4.0)\n",
- "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.8.4)\n",
- "Requirement already satisfied: huggingface-hub<1.0.0,>=0.11.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.14.1)\n",
- "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (23.1)\n",
- "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.18.0)\n",
- "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (6.0)\n",
- "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.10/dist-packages (from jaxtyping>=0.2.11->transformer-lens==0.0.0) (2.13.3)\n",
- "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.10/dist-packages (from jaxtyping>=0.2.11->transformer-lens==0.0.0) (4.5.0)\n",
- "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer-lens==0.0.0) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer-lens==0.0.0) (2022.7.1)\n",
- "Requirement already satisfied: markdown-it-py<3.0.0,>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer-lens==0.0.0) (2.2.0)\n",
- "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer-lens==0.0.0) (2.14.0)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (3.12.0)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (1.11.1)\n",
- "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (3.1)\n",
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (3.1.2)\n",
- "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (2.0.0)\n",
- "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->transformer-lens==0.0.0) (3.25.2)\n",
- "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->transformer-lens==0.0.0) (16.0.2)\n",
- "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->transformer-lens==0.0.0) (2022.10.31)\n",
- "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->transformer-lens==0.0.0) (0.13.3)\n",
- "Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (8.1.3)\n",
- "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (3.1.31)\n",
- "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (5.9.5)\n",
- "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.21.1)\n",
- "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (0.4.0)\n",
- "Requirement already satisfied: pathtools in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (0.1.2)\n",
- "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.3.2)\n",
- "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (67.7.2)\n",
- "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.4.4)\n",
- "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (3.20.3)\n",
- "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer-lens==0.0.0) (1.16.0)\n",
- "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (23.1.0)\n",
- "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (2.0.12)\n",
- "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (6.0.4)\n",
- "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (4.0.2)\n",
- "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.9.2)\n",
- "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.3.3)\n",
- "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.3.1)\n",
- "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (4.0.10)\n",
- "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py<3.0.0,>=2.2.0->rich>=12.6.0->transformer-lens==0.0.0) (0.1.2)\n",
- "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (1.26.15)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (2022.12.7)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (3.4)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10->transformer-lens==0.0.0) (2.1.2)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->transformer-lens==0.0.0) (1.3.0)\n",
- "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (5.0.0)\n",
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Collecting git+https://github.com/neelnanda-io/neel-plotly.git\n",
- " Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-u8mujxc3\n",
- " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-u8mujxc3\n",
- " Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc096fdc575da978d3e56489f2347d95cd397e7\n",
- " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
- "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (0.6.1)\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (1.24.3)\n",
- "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (2.0.0+cu118)\n",
- "Requirement already satisfied: plotly in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (5.13.1)\n",
- "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (4.65.0)\n",
- "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (1.5.3)\n",
- "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->neel-plotly==0.0.0) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->neel-plotly==0.0.0) (2022.7.1)\n",
- "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly->neel-plotly==0.0.0) (8.2.2)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (3.12.0)\n",
- "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (4.5.0)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (1.11.1)\n",
- "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (3.1)\n",
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (3.1.2)\n",
- "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (2.0.0)\n",
- "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->neel-plotly==0.0.0) (3.25.2)\n",
- "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->neel-plotly==0.0.0) (16.0.2)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->neel-plotly==0.0.0) (1.16.0)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->neel-plotly==0.0.0) (2.1.2)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->neel-plotly==0.0.0) (1.3.0)\n",
- "\n",
- "## Installing the NodeSource Node.js 16.x repo...\n",
- "\n",
- "\n",
- "## Populating apt-get cache...\n",
- "\n",
- "+ apt-get update\n",
- "Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64 InRelease\n",
- "Hit:2 https://cloud.r-project.org/bin/linux/ubuntu focal-cran40/ InRelease\n",
- "Hit:3 https://deb.nodesource.com/node_16.x focal InRelease\n",
- "Get:4 http://security.ubuntu.com/ubuntu focal-security InRelease [114 kB]\n",
- "Hit:5 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu focal InRelease\n",
- "Hit:6 http://archive.ubuntu.com/ubuntu focal InRelease\n",
- "Get:7 http://archive.ubuntu.com/ubuntu focal-updates InRelease [114 kB]\n",
- "Hit:8 http://ppa.launchpad.net/cran/libgit2/ubuntu focal InRelease\n",
- "Hit:9 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease\n",
- "Get:10 http://archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]\n",
- "Hit:11 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu focal InRelease\n",
- "Hit:12 http://ppa.launchpad.net/ubuntugis/ppa/ubuntu focal InRelease\n",
- "Fetched 336 kB in 2s (202 kB/s)\n",
- "Reading package lists... Done\n",
- "\n",
- "## Confirming \"focal\" is supported...\n",
- "\n",
- "+ curl -sLf -o /dev/null 'https://deb.nodesource.com/node_16.x/dists/focal/Release'\n",
- "\n",
- "## Adding the NodeSource signing key to your keyring...\n",
- "\n",
- "+ curl -s https://deb.nodesource.com/gpgkey/nodesource.gpg.key | gpg --dearmor | tee /usr/share/keyrings/nodesource.gpg >/dev/null\n",
- "\n",
- "## Creating apt sources list file for the NodeSource Node.js 16.x repo...\n",
- "\n",
- "+ echo 'deb [signed-by=/usr/share/keyrings/nodesource.gpg] https://deb.nodesource.com/node_16.x focal main' > /etc/apt/sources.list.d/nodesource.list\n",
- "+ echo 'deb-src [signed-by=/usr/share/keyrings/nodesource.gpg] https://deb.nodesource.com/node_16.x focal main' >> /etc/apt/sources.list.d/nodesource.list\n",
- "\n",
- "## Running `apt-get update` for you...\n",
- "\n",
- "+ apt-get update\n",
- "Hit:1 https://cloud.r-project.org/bin/linux/ubuntu focal-cran40/ InRelease\n",
- "Hit:2 http://security.ubuntu.com/ubuntu focal-security InRelease\n",
- "Hit:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64 InRelease\n",
- "Hit:4 https://deb.nodesource.com/node_16.x focal InRelease\n",
- "Hit:5 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu focal InRelease\n",
- "Hit:6 http://archive.ubuntu.com/ubuntu focal InRelease\n",
- "Hit:7 http://archive.ubuntu.com/ubuntu focal-updates InRelease\n",
- "Get:8 http://archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]\n",
- "Hit:9 http://ppa.launchpad.net/cran/libgit2/ubuntu focal InRelease\n",
- "Hit:10 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease\n",
- "Hit:11 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu focal InRelease\n",
- "Hit:12 http://ppa.launchpad.net/ubuntugis/ppa/ubuntu focal InRelease\n",
- "Fetched 108 kB in 1s (73.2 kB/s)\n",
- "Reading package lists... Done\n",
- "\n",
- "## Run `\u001b[1msudo apt-get install -y nodejs\u001b[m` to install Node.js 16.x and npm\n",
- "## You may also need development tools to build native addons:\n",
- " sudo apt-get install gcc g++ make\n",
- "## To install the Yarn package manager, run:\n",
- " curl -sL https://dl.yarnpkg.com/debian/pubkey.gpg | gpg --dearmor | sudo tee /usr/share/keyrings/yarnkey.gpg >/dev/null\n",
- " echo \"deb [signed-by=/usr/share/keyrings/yarnkey.gpg] https://dl.yarnpkg.com/debian stable main\" | sudo tee /etc/apt/sources.list.d/yarn.list\n",
- " sudo apt-get update && sudo apt-get install yarn\n",
- "\n",
- "\n",
- "Reading package lists... Done\n",
- "Building dependency tree \n",
- "Reading state information... Done\n",
- "nodejs is already the newest version (16.20.0-deb-1nodesource1).\n",
- "0 upgraded, 0 newly installed, 0 to remove and 27 not upgraded.\n",
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Collecting git+https://github.com/TransformerLensOrg/PySvelte.git\n",
- " Cloning https://github.com/TransformerLensOrg/PySvelte.git to /tmp/pip-req-build-09ycdh0j\n",
- " Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/PySvelte.git /tmp/pip-req-build-09ycdh0j\n",
- " Resolved https://github.com/TransformerLensOrg/PySvelte.git to commit 8410eae58503df0a293857a61a1a11ca35f86525\n",
- " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
- "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (0.6.1)\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (1.24.3)\n",
- "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (2.0.0+cu118)\n",
- "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (2.12.0)\n",
- "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (4.28.1)\n",
- "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (4.65.0)\n",
- "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (1.5.3)\n",
- "Requirement already satisfied: typeguard~=2.0 in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (2.13.3)\n",
- "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (9.0.0)\n",
- "Requirement already satisfied: dill<0.3.7,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (0.3.6)\n",
- "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (2.27.1)\n",
- "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (3.2.0)\n",
- "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (0.70.14)\n",
- "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (2023.4.0)\n",
- "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (3.8.4)\n",
- "Requirement already satisfied: huggingface-hub<1.0.0,>=0.11.0 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (0.14.1)\n",
- "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (23.1)\n",
- "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (0.18.0)\n",
- "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (6.0)\n",
- "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->PySvelte==1.0.0) (2.8.2)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->PySvelte==1.0.0) (2022.7.1)\n",
- "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (3.12.0)\n",
- "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (4.5.0)\n",
- "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (1.11.1)\n",
- "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (3.1)\n",
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (3.1.2)\n",
- "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (2.0.0)\n",
- "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->PySvelte==1.0.0) (3.25.2)\n",
- "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->PySvelte==1.0.0) (16.0.2)\n",
- "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->PySvelte==1.0.0) (2022.10.31)\n",
- "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers->PySvelte==1.0.0) (0.13.3)\n",
- "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (23.1.0)\n",
- "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (2.0.12)\n",
- "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (6.0.4)\n",
- "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (4.0.2)\n",
- "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (1.9.2)\n",
- "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (1.3.3)\n",
- "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (1.3.1)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->PySvelte==1.0.0) (1.16.0)\n",
- "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->PySvelte==1.0.0) (1.26.15)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->PySvelte==1.0.0) (2022.12.7)\n",
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->PySvelte==1.0.0) (3.4)\n",
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->PySvelte==1.0.0) (2.1.2)\n",
- "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->PySvelte==1.0.0) (1.3.0)\n",
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Requirement already satisfied: typeguard==2.13.3 in /usr/local/lib/python3.10/dist-packages (2.13.3)\n",
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (4.5.0)\n"
- ]
- }
- ],
- "source": [
- "# NBVAL_IGNORE_OUTPUT\n",
- "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
- "import os\n",
- "\n",
- "DEVELOPMENT_MODE = True\n",
- "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
- "try:\n",
- " import google.colab\n",
- " IN_COLAB = True\n",
- " print(\"Running as a Colab notebook\")\n",
- "except:\n",
- " IN_COLAB = False\n",
- " print(\"Running as a Jupyter notebook - intended for development only!\")\n",
- " from IPython import get_ipython\n",
- "\n",
- " ipython = get_ipython()\n",
- " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
- " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
- " ipython.run_line_magic(\"autoreload\", \"2\")\n",
- "\n",
- "if IN_COLAB or IN_GITHUB:\n",
- " %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git\n",
- " # Install Neel's personal plotting utils\n",
- " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n",
- " # Install another version of node that makes PySvelte work way faster\n",
- " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
- " %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
- " # Needed for PySvelte to work, v3 came out and broke things...\n",
- " %pip install typeguard==2.13.3\n",
- " %pip install typing-extensions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "LBjE0qm6Ahyf"
- },
- "outputs": [],
- "source": [
- "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
- "import plotly.io as pio\n",
- "\n",
- "if IN_COLAB or not DEBUG_MODE:\n",
- " # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n",
- " pio.renderers.default = \"colab\"\n",
- "else:\n",
- " pio.renderers.default = \"png\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "id": "ScWILAgIGt5O"
- },
- "outputs": [],
- "source": [
- "import torch\n",
- "import einops\n",
- "import pysvelte\n",
- "from tqdm import tqdm\n",
- "\n",
- "import transformer_lens\n",
- "from transformer_lens import HookedTransformer, ActivationCache\n",
- "from neel_plotly import line, imshow, scatter"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "13A_MpOwJBaJ",
- "outputId": "8b84df9b-886f-4205-cd51-0dfaf48d72d6"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "device = 'cuda'\n"
- ]
- }
- ],
- "source": [
- "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
- "print(f\"{device = }\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wLp6sCvBnXRn"
- },
- "source": [
- "### Some plotting utils"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "id": "Gw7D7_IKkR3y"
- },
- "outputs": [],
- "source": [
- "# Util for plotting head detection scores\n",
- "\n",
- "def plot_head_detection_scores(\n",
- " scores: torch.Tensor,\n",
- " zmin: float = -1,\n",
- " zmax: float = 1,\n",
- " xaxis: str = \"Head\",\n",
- " yaxis: str = \"Layer\",\n",
- " title: str = \"Head Matches\"\n",
- ") -> None:\n",
- " imshow(scores, zmin=zmin, zmax=zmax, xaxis=xaxis, yaxis=yaxis, title=title)\n",
- "\n",
- "def plot_attn_pattern_from_cache(cache: ActivationCache, layer_i: int):\n",
- " attention_pattern = cache[\"pattern\", layer_i, \"attn\"].squeeze(0)\n",
- " attention_pattern = einops.rearrange(attention_pattern, \"heads seq1 seq2 -> seq1 seq2 heads\")\n",
- " print(f\"Layer {layer_i} Attention Heads:\")\n",
- " return pysvelte.AttentionMulti(tokens=model.to_str_tokens(prompt), attention=attention_pattern)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "eclSY10h7r4R"
- },
- "source": [
- "## Head detector"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "QSVGddQDk1M6"
- },
- "source": [
- "Utils: these will be in `transformer_lens.utils` after merging the fork to the main repo"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "id": "4zQYJUU4kgPu"
- },
- "outputs": [],
- "source": [
- "def is_square(x: torch.Tensor) -> bool:\n",
- " \"\"\"Checks if `x` is a square matrix.\"\"\"\n",
- " return x.ndim == 2 and x.shape[0] == x.shape[1]\n",
- "\n",
- "def is_lower_triangular(x: torch.Tensor) -> bool:\n",
- " \"\"\"Checks if `x` is a lower triangular matrix.\"\"\"\n",
- " if not is_square(x):\n",
- " return False\n",
- " return x.equal(x.tril())"
- ]
- },
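- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A quick sanity check of these helpers (a sketch; recall that `torch.tril` zeroes the strictly upper triangle):\n",
- "\n",
- "```py\n",
- "is_lower_triangular(torch.tril(torch.ones(3, 3)))  # True\n",
- "is_lower_triangular(torch.ones(3, 3))  # False: upper triangle is non-zero\n",
- "```"
- ]
- },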
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BCqH-TfXk49T"
- },
- "source": [
- "The code below is copy-pasted from the expanded (not yet merged) version of `transformer_lens.head_detector`.\n",
- "\n",
- "After merging the code below can be replaced with simply\n",
- "\n",
- "```py\n",
- "from transformer_lens.head_detector import *\n",
- "```\n",
- "\n",
- "(but please don't use star-imports in production ;))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "id": "5ikyL8-S7u2Z"
- },
- "outputs": [],
- "source": [
- "from collections import defaultdict\n",
- "import logging\n",
- "from typing import cast, Dict, List, Optional, Tuple, Union\n",
- "from typing_extensions import get_args, Literal\n",
- "\n",
- "import numpy as np\n",
- "import torch\n",
- "\n",
- "from transformer_lens import HookedTransformer, ActivationCache\n",
- "# from transformer_lens.utils import is_lower_triangular, is_square\n",
- "\n",
- "HeadName = Literal[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]\n",
- "HEAD_NAMES = cast(List[HeadName], get_args(HeadName))\n",
- "ErrorMeasure = Literal[\"abs\", \"mul\"]\n",
- "\n",
- "LayerHeadTuple = Tuple[int, int]\n",
- "LayerToHead = Dict[int, List[int]]\n",
- "\n",
- "INVALID_HEAD_NAME_ERR = (\n",
- " f\"detection_pattern must be a Tensor or one of head names: {HEAD_NAMES}; got %s\"\n",
- ")\n",
- "\n",
- "SEQ_LEN_ERR = (\n",
- " \"The sequence must be non-empty and must fit within the model's context window.\"\n",
- ")\n",
- "\n",
- "DET_PAT_NOT_SQUARE_ERR = \"The detection pattern must be a lower triangular matrix of shape (sequence_length, sequence_length); sequence_length=%d; got detection patern of shape %s\"\n",
- "\n",
- "\n",
- "def detect_head(\n",
- " model: HookedTransformer,\n",
- " seq: Union[str, List[str]],\n",
- " detection_pattern: Union[torch.Tensor, HeadName],\n",
- " heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,\n",
- " cache: Optional[ActivationCache] = None,\n",
- " *,\n",
- " exclude_bos: bool = False,\n",
- " exclude_current_token: bool = False,\n",
- " error_measure: ErrorMeasure = \"mul\",\n",
- ") -> torch.Tensor:\n",
- " \"\"\"Searches the model (or a set of specific heads, for circuit analysis) for a particular type of attention head.\n",
- " This head is specified by a detection pattern, a (sequence_length, sequence_length) tensor representing the attention pattern we expect that type of attention head to show.\n",
- " The detection pattern can be also passed not as a tensor, but as a name of one of pre-specified types of attention head (see `HeadName` for available patterns), in which case the tensor is computed within the function itself.\n",
- "\n",
- " There are two error measures available for quantifying the match between the detection pattern and the actual attention pattern.\n",
- "\n",
- " 1. `\"mul\"` (default) multiplies both tensors element-wise and divides the sum of the result by the sum of the attention pattern.\n",
- " Typically, the detection pattern should in this case contain only ones and zeros, which allows a straightforward interpretation of the score:\n",
- " how big fraction of this head's attention is allocated to these specific query-key pairs?\n",
- " Using values other than 0 or 1 is not prohibited but will raise a warning (which can be disabled, of course).\n",
- " 2. `\"abs\"` calculates the mean element-wise absolute difference between the detection pattern and the actual attention pattern.\n",
- " The \"raw result\" ranges from 0 to 2 where lower score corresponds to greater accuracy. Subtracting it from 1 maps that range to (-1, 1) interval,\n",
- " with 1 being perfect match and -1 perfect mismatch.\n",
- "\n",
- " **Which one should you use?** `\"abs\"` is likely better for quick or exploratory investigations. For precise examinations where you're trying to\n",
- " reproduce as much functionality as possible or really test your understanding of the attention head, you probably want to switch to `\"abs\"`.\n",
- "\n",
- " The advantage of `\"abs\"` is that you can make more precise predictions, and have that measured in the score.\n",
- " You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and your score will be better if your prediction is closer.\n",
- " The \"mul\" metric does not allow this, you'll get the same score if attention is 0.2, 0.8 or 0.5, 0.5 or 0.8, 0.2.\n",
- "\n",
- " Args:\n",
- " ----------\n",
- " model: Model being used.\n",
- " seq: String or list of strings being fed to the model.\n",
- " head_name: Name of an existing head in HEAD_NAMES we want to check. Must pass either a head_name or a detection_pattern, but not both!\n",
- " detection_pattern: (sequence_length, sequence_length) Tensor representing what attention pattern corresponds to the head we're looking for **or** the name of a pre-specified head. Currently available heads are: `[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]`.\n",
- " heads: If specific attention heads is given here, all other heads' score is set to -1. Useful for IOI-style circuit analysis. Heads can be spacified as a list tuples (layer, head) or a dictionary mapping a layer to heads within that layer that we want to analyze.\n",
- " cache: Include the cache to save time if you want.\n",
- " exclude_bos: Exclude attention paid to the beginning of sequence token.\n",
- " exclude_current_token: Exclude attention paid to the current token.\n",
- " error_measure: `\"mul\"` for using element-wise multiplication (default). `\"abs\"` for using absolute values of element-wise differences as the error measure.\n",
- "\n",
- " Returns:\n",
- " ----------\n",
- " A (n_layers, n_heads) Tensor representing the score for each attention head.\n",
- "\n",
- " Example:\n",
- " --------\n",
- " .. code-block:: python\n",
- "\n",
- " >>> from transformer_lens import HookedTransformer, utils\n",
- " >>> from transformer_lens.head_detector import detect_head\n",
- " >>> import plotly.express as px\n",
- "\n",
- " >>> def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
- " >>> px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
- "\n",
- " >>> model = HookedTransformer.from_pretrained(\"gpt2-small\")\n",
- " >>> sequence = \"This is a test sequence. This is a test sequence.\"\n",
- "\n",
- " >>> attention_score = detect_head(model, sequence, \"previous_token_head\")\n",
- " >>> imshow(attention_score, zmin=-1, zmax=1, xaxis=\"Head\", yaxis=\"Layer\", title=\"Previous Head Matches\")\n",
- " \"\"\"\n",
- "\n",
- " cfg = model.cfg\n",
- " tokens = model.to_tokens(seq).to(cfg.device)\n",
- " seq_len = tokens.shape[-1]\n",
- " \n",
- " # Validate error_measure\n",
- " \n",
- " assert error_measure in get_args(ErrorMeasure), f\"Invalid {error_measure=}; valid values are {get_args(ErrorMeasure)}\"\n",
- "\n",
- " # Validate detection pattern if it's a string\n",
- " if isinstance(detection_pattern, str):\n",
- " assert detection_pattern in HEAD_NAMES, (\n",
- " INVALID_HEAD_NAME_ERR % detection_pattern\n",
- " )\n",
- " if isinstance(seq, list):\n",
- " batch_scores = [detect_head(model, seq, detection_pattern) for seq in seq]\n",
- " return torch.stack(batch_scores).mean(0)\n",
- " detection_pattern = cast(\n",
- " torch.Tensor,\n",
- " eval(f\"get_{detection_pattern}_detection_pattern(tokens.cpu())\"),\n",
- " ).to(cfg.device)\n",
- "\n",
- " # if we're using \"mul\", detection_pattern should consist of zeros and ones\n",
- " if error_measure == \"mul\" and not set(detection_pattern.unique().tolist()).issubset(\n",
- " {0, 1}\n",
- " ):\n",
- " logging.warning(\n",
- " \"Using detection pattern with values other than 0 or 1 with error_measure 'mul'\"\n",
- " )\n",
- "\n",
- " # Validate inputs and detection pattern shape\n",
- " assert 1 < tokens.shape[-1] < cfg.n_ctx, SEQ_LEN_ERR\n",
- " assert (\n",
- " is_lower_triangular(detection_pattern) and seq_len == detection_pattern.shape[0]\n",
- " ), DET_PAT_NOT_SQUARE_ERR % (seq_len, detection_pattern.shape)\n",
- "\n",
- " if cache is None:\n",
- " _, cache = model.run_with_cache(tokens, remove_batch_dim=True)\n",
- "\n",
- " if heads is None:\n",
- " layer2heads = {\n",
- " layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers)\n",
- " }\n",
- " elif isinstance(heads, list):\n",
- " layer2heads = defaultdict(list)\n",
- " for layer, head in heads:\n",
- " layer2heads[layer].append(head)\n",
- " else:\n",
- " layer2heads = heads\n",
- "\n",
- " matches = -torch.ones(cfg.n_layers, cfg.n_heads)\n",
- "\n",
- " for layer, layer_heads in layer2heads.items():\n",
- " # [n_heads q_pos k_pos]\n",
- " layer_attention_patterns = cache[\"pattern\", layer, \"attn\"]\n",
- " for head in layer_heads:\n",
- " head_attention_pattern = layer_attention_patterns[head, :, :]\n",
- " head_score = compute_head_attention_similarity_score(\n",
- " head_attention_pattern,\n",
- " detection_pattern=detection_pattern,\n",
- " exclude_bos=exclude_bos,\n",
- " exclude_current_token=exclude_current_token,\n",
- " error_measure=error_measure,\n",
- " )\n",
- " matches[layer, head] = head_score\n",
- " return matches\n",
- "\n",
- "\n",
- "# Previous token head\n",
- "def get_previous_token_head_detection_pattern(\n",
- " tokens: torch.Tensor, # [batch (1) x pos]\n",
- ") -> torch.Tensor:\n",
- " \"\"\"Outputs a detection score for [previous token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=0O5VOHe9xeZn8Ertywkh7ioc).\n",
- "\n",
- " Args:\n",
- " tokens: Tokens being fed to the model.\n",
- " \"\"\"\n",
- " detection_pattern = torch.zeros(tokens.shape[-1], tokens.shape[-1])\n",
- " # Adds a diagonal of 1's below the main diagonal.\n",
- " detection_pattern[1:, :-1] = torch.eye(tokens.shape[-1] - 1)\n",
- " return torch.tril(detection_pattern)\n",
- "\n",
- "\n",
- "# Duplicate token head\n",
- "def get_duplicate_token_head_detection_pattern(\n",
- " tokens: torch.Tensor, # [batch (1) x pos]\n",
- ") -> torch.Tensor:\n",
- " \"\"\"Outputs a detection score for [duplicate token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=2UkvedzOnghL5UHUgVhROxeo).\n",
- "\n",
- " Args:\n",
- " sequence: String being fed to the model.\n",
- " \"\"\"\n",
- " # [pos x pos]\n",
- " token_pattern = tokens.repeat(tokens.shape[-1], 1).numpy()\n",
- "\n",
- " # If token_pattern[i][j] matches its transpose, then token j and token i are duplicates.\n",
- " eq_mask = np.equal(token_pattern, token_pattern.T).astype(int)\n",
- "\n",
- " np.fill_diagonal(\n",
- " eq_mask, 0\n",
- " ) # Current token is always a duplicate of itself. Ignore that.\n",
- " detection_pattern = eq_mask.astype(int)\n",
- " return torch.tril(torch.as_tensor(detection_pattern).float())\n",
- "\n",
- "\n",
- "# Induction head\n",
- "def get_induction_head_detection_pattern(\n",
- " tokens: torch.Tensor, # [batch (1) x pos]\n",
- ") -> torch.Tensor:\n",
- " \"\"\"Outputs a detection score for [induction heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_tFVuP5csv5ORIthmqwj0gSY).\n",
- "\n",
- " Args:\n",
- " sequence: String being fed to the model.\n",
- " \"\"\"\n",
- " duplicate_pattern = get_duplicate_token_head_detection_pattern(tokens)\n",
- "\n",
- " # Shift all items one to the right\n",
- " shifted_tensor = torch.roll(duplicate_pattern, shifts=1, dims=1)\n",
- "\n",
- " # Replace first column with 0's\n",
- " # we don't care about bos but shifting to the right moves the last column to the first,\n",
- " # and the last column might contain non-zero values.\n",
- " zeros_column = torch.zeros(duplicate_pattern.shape[0], 1)\n",
- " result_tensor = torch.cat((zeros_column, shifted_tensor[:, 1:]), dim=1)\n",
- " return torch.tril(result_tensor)\n",
- "\n",
- "\n",
- "def get_supported_heads() -> None:\n",
- " \"\"\"Returns a list of supported heads.\"\"\"\n",
- " print(f\"Supported heads: {HEAD_NAMES}\")\n",
- "\n",
- "\n",
- "def compute_head_attention_similarity_score(\n",
- " attention_pattern: torch.Tensor, # [q_pos k_pos]\n",
- " detection_pattern: torch.Tensor, # [seq_len seq_len] (seq_len == q_pos == k_pos)\n",
- " *,\n",
- " exclude_bos: bool,\n",
- " exclude_current_token: bool,\n",
- " error_measure: ErrorMeasure,\n",
- ") -> float:\n",
- " \"\"\"Compute the similarity between `attention_pattern` and `detection_pattern`.\n",
- "\n",
- " Args:\n",
- " attention_pattern: Lower triangular matrix (Tensor) representing the attention pattern of a particular attention head.\n",
- " detection_pattern: Lower triangular matrix (Tensor) representing the attention pattern we are looking for.\n",
- " exclude_bos: `True` if the beginning-of-sentence (BOS) token should be omitted from comparison. `False` otherwise.\n",
- " exclude_bcurrent_token: `True` if the current token at each position should be omitted from comparison. `False` otherwise.\n",
- " error_measure: \"abs\" for using absolute values of element-wise differences as the error measure. \"mul\" for using element-wise multiplication (legacy code).\n",
- " \"\"\"\n",
- " assert is_square(\n",
- " attention_pattern\n",
- " ), f\"Attention pattern is not square; got shape {attention_pattern.shape}\"\n",
- "\n",
- " # mul\n",
- "\n",
- " if error_measure == \"mul\":\n",
- " if exclude_bos:\n",
- " attention_pattern[:, 0] = 0\n",
- " if exclude_current_token:\n",
- " attention_pattern.fill_diagonal_(0)\n",
- " score = attention_pattern * detection_pattern\n",
- " return (score.sum() / attention_pattern.sum()).item()\n",
- "\n",
- " # abs\n",
- "\n",
- " abs_diff = (attention_pattern - detection_pattern).abs()\n",
- " assert (abs_diff - torch.tril(abs_diff).to(abs_diff.device)).sum() == 0\n",
- "\n",
- " size = len(abs_diff)\n",
- " if exclude_bos:\n",
- " abs_diff[:, 0] = 0\n",
- " if exclude_current_token:\n",
- " abs_diff.fill_diagonal_(0)\n",
- "\n",
- " return 1 - round((abs_diff.mean() * size).item(), 3)\n"
- ]
- },
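- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a quick sanity check of the `\"mul\"` measure, consider a hand-made 2x2 example (a sketch, not part of the library):\n",
- "\n",
- "```py\n",
- "attn = torch.tensor([[1.0, 0.0], [0.7, 0.3]])  # attention of one head\n",
- "det = torch.tensor([[0.0, 0.0], [1.0, 0.0]])   # previous-token detection pattern\n",
- "score = (attn * det).sum() / attn.sum()        # 0.7 / 2.0 = 0.35\n",
- "```"
- ]
- },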
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Bw4CZS-tCH7u"
- },
- "source": [
- "## Using Head Detector For Premade Heads\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "A0iDohDUmS_r"
- },
- "source": [
- "Load the model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "
\n",
+ ""
+ ]
},
- "id": "nIiUfx76I6a1",
- "outputId": "85bf4ea6-0c27-4f3f-dfe5-b173dd3b70e0"
- },
- "outputs": [
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Using pad_token, but it is not set yet.\n"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YquKKgs17NOv"
+ },
+ "source": [
+ "# TransformerLens Head Detector Demo"
+ ]
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Loaded pretrained model gpt2-small into HookedTransformer\n"
- ]
- }
- ],
- "source": [
- "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=device)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "cgqKW_kWmPWX"
- },
- "source": [
- "See what heads are supported out of the box"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "3i0kKYngmLru",
- "outputId": "72a44f58-8a6b-4551-bb38-ddc177f5fd25"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Supported heads: ('previous_token_head', 'duplicate_token_head', 'induction_head')\n"
- ]
- }
- ],
- "source": [
- "get_supported_heads()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "b88sGXh1mUvD"
- },
- "source": [
- "Let's test detecting previous token head in the following prompt."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "hncQfgF8CE_i",
- "outputId": "a925be04-74ed-4e9e-b02d-faf6d24026f0"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "
\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "prompt = \"The head detector feature for TransformerLens allows users to check for various common heads automatically, reducing the cost of discovery.\"\n",
- "head_scores = detect_head(model, prompt, \"previous_token_head\")\n",
- "plot_head_detection_scores(head_scores, title=\"Previous Head Matches\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "f_iFBhRRKQIF"
- },
- "source": [
- "We can see both L2H2 and L4H11 are doing a fair bit of previous token detection. Let's take a look and see if that pans out."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "id": "95mH2b43n0EZ"
- },
- "outputs": [],
- "source": [
- "_, cache = model.run_with_cache(prompt)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 339
- },
- "id": "S7bz-uZQKWpj",
- "outputId": "8bb120a6-3223-4b65-9c92-fd089d3f1d4e"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Layer 2 Attention Heads:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " "
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wKW2CqN-yZuY"
+ },
+ "source": [
+ "A common technique in mechanistic interpretability of transformer-based neural networks is identification of specialized attention heads, based on the attention patterns elicited by one or more prompts. The most basic examples of such heads are: previous token head, duplicate token head, or induction head ([more info](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_Jzi6YHRHKP1JziwdE02qdYZ)). Usually, such heads are identified manually, by through visualizations of attention patterns layer by layer, head by head, and trying to recognize the patterns by eye.\n",
+ "\n",
+ "The purpose of the `TransformerLens.head_detector` feature is to automate a part of that workflow. The pattern characterizing a head of particular type/function is specified as a `Tensor` being a `seq_len x seq_len` [lower triangular matrix](https://en.wikipedia.org/wiki/Triangular_matrix). It can be either passed to the `detect_head` function directly or by giving a string identifying of several pre-defined detection patterns."
+ ]
+ },
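+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For example, the pre-defined previous-token pattern for a 4-token sequence is just a sub-diagonal of ones (a sketch; `detect_head` builds this tensor for you when given the head name):\n",
+ "\n",
+ "```py\n",
+ "import torch\n",
+ "\n",
+ "pattern = torch.zeros(4, 4)\n",
+ "pattern[1:, :-1] = torch.eye(3)  # each position attends to its predecessor\n",
+ "```"
+ ]
+ },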
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3a53LkPTAjzB"
+ },
+ "source": [
+ "## How to use this notebook\n",
+ "\n",
+ "Go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n",
+ "\n",
+ "Tips for reading this Colab:\n",
+ "\n",
+ "* You can run all this code for yourself!\n",
+ "* The graphs are interactive!\n",
+ "* Use the table of contents pane in the sidebar to navigate\n",
+ "* Collapse irrelevant sections with the dropdown arrows\n",
+ "* Search the page using the search in the sidebar, not CTRL+F"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "nCWImh1S7fNx"
+ },
+ "source": [
+ "## Setup (Ignore)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "4LZeYL3XAc7T",
+ "outputId": "680da02d-5ca8-4ab3-bc24-f2827f0fcd95"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running as a Colab notebook\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting git+https://github.com/TransformerLensOrg/TransformerLens.git\n",
+ " Cloning https://github.com/TransformerLensOrg/TransformerLens.git to /tmp/pip-req-build-v3x96q_b\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/TransformerLens.git /tmp/pip-req-build-v3x96q_b\n",
+ " Resolved https://github.com/TransformerLensOrg/TransformerLens.git to commit 0ffcc8ad647d9e991f4c2596557a9d7475617773\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: datasets>=2.7.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.12.0)\n",
+ "Requirement already satisfied: einops>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.6.1)\n",
+ "Requirement already satisfied: fancy-einsum>=0.0.3 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.0.3)\n",
+ "Requirement already satisfied: jaxtyping>=0.2.11 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.2.15)\n",
+ "Requirement already satisfied: numpy>=1.23 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (1.24.3)\n",
+ "Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (1.5.3)\n",
+ "Requirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (13.3.4)\n",
+ "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.0.0+cu118)\n",
+ "Requirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.65.0)\n",
+ "Requirement already satisfied: transformers>=4.25.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (4.28.1)\n",
+ "Requirement already satisfied: wandb>=0.13.5 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.15.0)\n",
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (9.0.0)\n",
+ "Requirement already satisfied: dill<0.3.7,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.3.6)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (2.27.1)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.2.0)\n",
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.70.14)\n",
+ "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (2023.4.0)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (3.8.4)\n",
+ "Requirement already satisfied: huggingface-hub<1.0.0,>=0.11.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.14.1)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (23.1)\n",
+ "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (0.18.0)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.7.1->transformer-lens==0.0.0) (6.0)\n",
+ "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.10/dist-packages (from jaxtyping>=0.2.11->transformer-lens==0.0.0) (2.13.3)\n",
+ "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.10/dist-packages (from jaxtyping>=0.2.11->transformer-lens==0.0.0) (4.5.0)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer-lens==0.0.0) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.1.5->transformer-lens==0.0.0) (2022.7.1)\n",
+ "Requirement already satisfied: markdown-it-py<3.0.0,>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer-lens==0.0.0) (2.2.0)\n",
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=12.6.0->transformer-lens==0.0.0) (2.14.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (3.12.0)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (1.11.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (3.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (3.1.2)\n",
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->transformer-lens==0.0.0) (2.0.0)\n",
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->transformer-lens==0.0.0) (3.25.2)\n",
+ "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->transformer-lens==0.0.0) (16.0.2)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->transformer-lens==0.0.0) (2022.10.31)\n",
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.25.1->transformer-lens==0.0.0) (0.13.3)\n",
+ "Requirement already satisfied: Click!=8.0.0,>=7.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (8.1.3)\n",
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (3.1.31)\n",
+ "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (5.9.5)\n",
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.21.1)\n",
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (0.4.0)\n",
+ "Requirement already satisfied: pathtools in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (0.1.2)\n",
+ "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.3.2)\n",
+ "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (67.7.2)\n",
+ "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (1.4.4)\n",
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb>=0.13.5->transformer-lens==0.0.0) (3.20.3)\n",
+ "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer-lens==0.0.0) (1.16.0)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (23.1.0)\n",
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (2.0.12)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (6.0.4)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (4.0.2)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.9.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.3.3)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.7.1->transformer-lens==0.0.0) (1.3.1)\n",
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (4.0.10)\n",
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py<3.0.0,>=2.2.0->rich>=12.6.0->transformer-lens==0.0.0) (0.1.2)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (1.26.15)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (2022.12.7)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer-lens==0.0.0) (3.4)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10->transformer-lens==0.0.0) (2.1.2)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->transformer-lens==0.0.0) (1.3.0)\n",
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (5.0.0)\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting git+https://github.com/neelnanda-io/neel-plotly.git\n",
+ " Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-u8mujxc3\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-u8mujxc3\n",
+ " Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc096fdc575da978d3e56489f2347d95cd397e7\n",
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (0.6.1)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (1.24.3)\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (2.0.0+cu118)\n",
+ "Requirement already satisfied: plotly in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (5.13.1)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (4.65.0)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (1.5.3)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->neel-plotly==0.0.0) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->neel-plotly==0.0.0) (2022.7.1)\n",
+ "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from plotly->neel-plotly==0.0.0) (8.2.2)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (3.12.0)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (4.5.0)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (1.11.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (3.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (3.1.2)\n",
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->neel-plotly==0.0.0) (2.0.0)\n",
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->neel-plotly==0.0.0) (3.25.2)\n",
+ "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->neel-plotly==0.0.0) (16.0.2)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->neel-plotly==0.0.0) (1.16.0)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->neel-plotly==0.0.0) (2.1.2)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->neel-plotly==0.0.0) (1.3.0)\n",
+ "\n",
+ "## Installing the NodeSource Node.js 16.x repo...\n",
+ "\n",
+ "\n",
+ "## Populating apt-get cache...\n",
+ "\n",
+ "+ apt-get update\n",
+ "Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64 InRelease\n",
+ "Hit:2 https://cloud.r-project.org/bin/linux/ubuntu focal-cran40/ InRelease\n",
+ "Hit:3 https://deb.nodesource.com/node_16.x focal InRelease\n",
+ "Get:4 http://security.ubuntu.com/ubuntu focal-security InRelease [114 kB]\n",
+ "Hit:5 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu focal InRelease\n",
+ "Hit:6 http://archive.ubuntu.com/ubuntu focal InRelease\n",
+ "Get:7 http://archive.ubuntu.com/ubuntu focal-updates InRelease [114 kB]\n",
+ "Hit:8 http://ppa.launchpad.net/cran/libgit2/ubuntu focal InRelease\n",
+ "Hit:9 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease\n",
+ "Get:10 http://archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]\n",
+ "Hit:11 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu focal InRelease\n",
+ "Hit:12 http://ppa.launchpad.net/ubuntugis/ppa/ubuntu focal InRelease\n",
+ "Fetched 336 kB in 2s (202 kB/s)\n",
+ "Reading package lists... Done\n",
+ "\n",
+ "## Confirming \"focal\" is supported...\n",
+ "\n",
+ "+ curl -sLf -o /dev/null 'https://deb.nodesource.com/node_16.x/dists/focal/Release'\n",
+ "\n",
+ "## Adding the NodeSource signing key to your keyring...\n",
+ "\n",
+ "+ curl -s https://deb.nodesource.com/gpgkey/nodesource.gpg.key | gpg --dearmor | tee /usr/share/keyrings/nodesource.gpg >/dev/null\n",
+ "\n",
+ "## Creating apt sources list file for the NodeSource Node.js 16.x repo...\n",
+ "\n",
+ "+ echo 'deb [signed-by=/usr/share/keyrings/nodesource.gpg] https://deb.nodesource.com/node_16.x focal main' > /etc/apt/sources.list.d/nodesource.list\n",
+ "+ echo 'deb-src [signed-by=/usr/share/keyrings/nodesource.gpg] https://deb.nodesource.com/node_16.x focal main' >> /etc/apt/sources.list.d/nodesource.list\n",
+ "\n",
+ "## Running `apt-get update` for you...\n",
+ "\n",
+ "+ apt-get update\n",
+ "Hit:1 https://cloud.r-project.org/bin/linux/ubuntu focal-cran40/ InRelease\n",
+ "Hit:2 http://security.ubuntu.com/ubuntu focal-security InRelease\n",
+ "Hit:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64 InRelease\n",
+ "Hit:4 https://deb.nodesource.com/node_16.x focal InRelease\n",
+ "Hit:5 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu focal InRelease\n",
+ "Hit:6 http://archive.ubuntu.com/ubuntu focal InRelease\n",
+ "Hit:7 http://archive.ubuntu.com/ubuntu focal-updates InRelease\n",
+ "Get:8 http://archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]\n",
+ "Hit:9 http://ppa.launchpad.net/cran/libgit2/ubuntu focal InRelease\n",
+ "Hit:10 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease\n",
+ "Hit:11 http://ppa.launchpad.net/graphics-drivers/ppa/ubuntu focal InRelease\n",
+ "Hit:12 http://ppa.launchpad.net/ubuntugis/ppa/ubuntu focal InRelease\n",
+ "Fetched 108 kB in 1s (73.2 kB/s)\n",
+ "Reading package lists... Done\n",
+ "\n",
+ "## Run `\u001b[1msudo apt-get install -y nodejs\u001b[m` to install Node.js 16.x and npm\n",
+ "## You may also need development tools to build native addons:\n",
+ " sudo apt-get install gcc g++ make\n",
+ "## To install the Yarn package manager, run:\n",
+ " curl -sL https://dl.yarnpkg.com/debian/pubkey.gpg | gpg --dearmor | sudo tee /usr/share/keyrings/yarnkey.gpg >/dev/null\n",
+ " echo \"deb [signed-by=/usr/share/keyrings/yarnkey.gpg] https://dl.yarnpkg.com/debian stable main\" | sudo tee /etc/apt/sources.list.d/yarn.list\n",
+ " sudo apt-get update && sudo apt-get install yarn\n",
+ "\n",
+ "\n",
+ "Reading package lists... Done\n",
+ "Building dependency tree \n",
+ "Reading state information... Done\n",
+ "nodejs is already the newest version (16.20.0-deb-1nodesource1).\n",
+ "0 upgraded, 0 newly installed, 0 to remove and 27 not upgraded.\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting git+https://github.com/TransformerLensOrg/PySvelte.git\n",
+ " Cloning https://github.com/TransformerLensOrg/PySvelte.git to /tmp/pip-req-build-09ycdh0j\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/PySvelte.git /tmp/pip-req-build-09ycdh0j\n",
+ " Resolved https://github.com/TransformerLensOrg/PySvelte.git to commit 8410eae58503df0a293857a61a1a11ca35f86525\n",
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (0.6.1)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (1.24.3)\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (2.0.0+cu118)\n",
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (2.12.0)\n",
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (4.28.1)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (4.65.0)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (1.5.3)\n",
+ "Requirement already satisfied: typeguard~=2.0 in /usr/local/lib/python3.10/dist-packages (from PySvelte==1.0.0) (2.13.3)\n",
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (9.0.0)\n",
+ "Requirement already satisfied: dill<0.3.7,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (0.3.6)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (2.27.1)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (3.2.0)\n",
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (0.70.14)\n",
+ "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (2023.4.0)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (3.8.4)\n",
+ "Requirement already satisfied: huggingface-hub<1.0.0,>=0.11.0 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (0.14.1)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (23.1)\n",
+ "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (0.18.0)\n",
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets->PySvelte==1.0.0) (6.0)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->PySvelte==1.0.0) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->PySvelte==1.0.0) (2022.7.1)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (3.12.0)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (4.5.0)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (1.11.1)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (3.1)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (3.1.2)\n",
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->PySvelte==1.0.0) (2.0.0)\n",
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->PySvelte==1.0.0) (3.25.2)\n",
+ "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->PySvelte==1.0.0) (16.0.2)\n",
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->PySvelte==1.0.0) (2022.10.31)\n",
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers->PySvelte==1.0.0) (0.13.3)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (23.1.0)\n",
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (2.0.12)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (6.0.4)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (4.0.2)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (1.9.2)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (1.3.3)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->PySvelte==1.0.0) (1.3.1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->PySvelte==1.0.0) (1.16.0)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->PySvelte==1.0.0) (1.26.15)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->PySvelte==1.0.0) (2022.12.7)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets->PySvelte==1.0.0) (3.4)\n",
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->PySvelte==1.0.0) (2.1.2)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->PySvelte==1.0.0) (1.3.0)\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Requirement already satisfied: typeguard==2.13.3 in /usr/local/lib/python3.10/dist-packages (2.13.3)\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (4.5.0)\n"
+ ]
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "# NBVAL_IGNORE_OUTPUT\n",
+ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
+ "import os\n",
+ "\n",
+ "DEVELOPMENT_MODE = True\n",
+ "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n",
+ "try:\n",
+ " import google.colab\n",
+ " IN_COLAB = True\n",
+ " print(\"Running as a Colab notebook\")\n",
+        "except ImportError:\n",
+ " IN_COLAB = False\n",
+ " print(\"Running as a Jupyter notebook - intended for development only!\")\n",
+ " from IPython import get_ipython\n",
+ "\n",
+ " ipython = get_ipython()\n",
+ " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
+ " ipython.magic(\"load_ext autoreload\")\n",
+ " ipython.magic(\"autoreload 2\")\n",
+ "\n",
+ "if IN_COLAB or IN_GITHUB:\n",
+ " %pip install transformer_lens\n",
+ " # Install Neel's personal plotting utils\n",
+ " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n",
+ " # Install another version of node that makes PySvelte work way faster\n",
+ " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
+ " %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
+ " # Needed for PySvelte to work, v3 came out and broke things...\n",
+ " %pip install typeguard==2.13.3\n",
+ " %pip install typing-extensions"
]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "plot_attn_pattern_from_cache(cache, 2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 339
- },
- "id": "7OEhpa-HLZJq",
- "outputId": "a27f6688-a121-4ca3-efe8-9f3d71bc37cd"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Layer 4 Attention Heads:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " "
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LBjE0qm6Ahyf"
+ },
+ "outputs": [],
+ "source": [
+ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
+ "import plotly.io as pio\n",
+ "\n",
+ "if IN_COLAB or not DEVELOPMENT_MODE:\n",
+ " # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n",
+ " pio.renderers.default = \"colab\"\n",
+ "else:\n",
+ " pio.renderers.default = \"png\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ScWILAgIGt5O"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import einops\n",
+ "import pysvelte\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "import transformer_lens\n",
+ "from transformer_lens import ActivationCache\n",
+ "from transformer_lens.model_bridge import TransformerBridge\n",
+ "from neel_plotly import line, imshow, scatter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "13A_MpOwJBaJ",
+ "outputId": "8b84df9b-886f-4205-cd51-0dfaf48d72d6"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "device = 'cuda'\n"
+ ]
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "print(f\"{device = }\")"
]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "plot_attn_pattern_from_cache(cache, 4)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "f4Eiua9pMLok"
- },
- "source": [
- "As we expected, L2H2 is doing a lot of previous token detection, but doesn't appear to be a sharp previous token detection head. L4H11, on the other hand, is pretty much perfect. In fact, the only place it seems to be putting any other attention is the very first token, where it pays attention to the BOS (*beginning-of-sentence*) token.\n",
- "\n",
- "Mechanistic interpretability is still a very new field, and we don't know the best ways to measure things yet. Ignoring attention paid to BOS allows us to solve problems like the above, but may also give us artifically high results for a head like L4H10, which doesn't appear to be doing much of anything, but does have a bit of previous token attention going on if you squint carefully.\n",
- "\n",
- "As such, the head detector supports both an `exclude_bos` and `exclude_current_token` argument, which ignores all BOS attention and all current token attention respectively. By default these are `False`, but this is a pretty arbitrary decision, so feel free to try things out! You don't need a good reason to change these arguments - pick whatever best helps you find out useful things!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "C42HUkb_NRad",
- "outputId": "3669bb7c-0d18-45ba-8d62-18162fb70b89"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "head_scores = detect_head(model, prompt, \"previous_token_head\", exclude_bos=True, exclude_current_token=True)\n",
- "plot_head_detection_scores(head_scores, title=\"Previous Head Matches\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "oiWQjv9UNerH"
- },
- "source": [
- "Now we have a lot more detection, including L0H3 and L5H6 which were unremarkable before. Let's check them out!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 339
- },
- "id": "cCsB8fD8oH5i",
- "outputId": "e1ebdc6d-07a6-446d-91de-a5b197e999b5"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Layer 5 Attention Heads:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " "
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "wLp6sCvBnXRn"
+ },
+ "source": [
+ "### Some plotting utils"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "Gw7D7_IKkR3y"
+ },
+ "outputs": [],
+ "source": [
+ "# Util for plotting head detection scores\n",
+ "\n",
+ "def plot_head_detection_scores(\n",
+ " scores: torch.Tensor,\n",
+ " zmin: float = -1,\n",
+ " zmax: float = 1,\n",
+ " xaxis: str = \"Head\",\n",
+ " yaxis: str = \"Layer\",\n",
+ " title: str = \"Head Matches\"\n",
+ ") -> None:\n",
+ " imshow(scores, zmin=zmin, zmax=zmax, xaxis=xaxis, yaxis=yaxis, title=title)\n",
+ "\n",
+ "def plot_attn_pattern_from_cache(cache: ActivationCache, layer_i: int):\n",
+ " attention_pattern = cache[\"pattern\", layer_i, \"attn\"].squeeze(0)\n",
+ " attention_pattern = einops.rearrange(attention_pattern, \"heads seq1 seq2 -> seq1 seq2 heads\")\n",
+ " print(f\"Layer {layer_i} Attention Heads:\")\n",
+ " return pysvelte.AttentionMulti(tokens=model.to_str_tokens(prompt), attention=attention_pattern)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eclSY10h7r4R"
+ },
+ "source": [
+ "## Head detector"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QSVGddQDk1M6"
+ },
+ "source": [
+        "Utils: these will be in `transformer_lens.utils` once the fork is merged into the main repo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "4zQYJUU4kgPu"
+ },
+ "outputs": [],
+ "source": [
+ "def is_square(x: torch.Tensor) -> bool:\n",
+ " \"\"\"Checks if `x` is a square matrix.\"\"\"\n",
+ " return x.ndim == 2 and x.shape[0] == x.shape[1]\n",
+ "\n",
+ "def is_lower_triangular(x: torch.Tensor) -> bool:\n",
+ " \"\"\"Checks if `x` is a lower triangular matrix.\"\"\"\n",
+ " if not is_square(x):\n",
+ " return False\n",
+ " return x.equal(x.tril())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BCqH-TfXk49T"
+ },
+ "source": [
+ "The code below is copy-pasted from the expanded (not yet merged) version of `transformer_lens.head_detector`.\n",
+ "\n",
+        "After merging, the code below can be replaced with simply\n",
+ "\n",
+ "```py\n",
+ "from transformer_lens.head_detector import *\n",
+ "```\n",
+ "\n",
+ "(but please don't use star-imports in production ;))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5ikyL8-S7u2Z"
+ },
+ "outputs": [],
+ "source": [
+ "from collections import defaultdict\n",
+ "import logging\n",
+ "from typing import cast, Dict, List, Optional, Tuple, Union\n",
+ "from typing_extensions import get_args, Literal\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "\n",
+ "# from transformer_lens.utils import is_lower_triangular, is_square\n",
+ "\n",
+ "HeadName = Literal[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]\n",
+ "HEAD_NAMES = cast(List[HeadName], get_args(HeadName))\n",
+ "ErrorMeasure = Literal[\"abs\", \"mul\"]\n",
+ "\n",
+ "LayerHeadTuple = Tuple[int, int]\n",
+ "LayerToHead = Dict[int, List[int]]\n",
+ "\n",
+ "INVALID_HEAD_NAME_ERR = (\n",
+ " f\"detection_pattern must be a Tensor or one of head names: {HEAD_NAMES}; got %s\"\n",
+ ")\n",
+ "\n",
+ "SEQ_LEN_ERR = (\n",
+ " \"The sequence must be non-empty and must fit within the model's context window.\"\n",
+ ")\n",
+ "\n",
+        "DET_PAT_NOT_SQUARE_ERR = \"The detection pattern must be a lower triangular matrix of shape (sequence_length, sequence_length); sequence_length=%d; got detection pattern of shape %s\"\n",
+ "\n",
+ "\n",
+ "def detect_head(\n",
+ " model: TransformerBridge,\n",
+ " seq: Union[str, List[str]],\n",
+ " detection_pattern: Union[torch.Tensor, HeadName],\n",
+ " heads: Optional[Union[List[LayerHeadTuple], LayerToHead]] = None,\n",
+ " cache: Optional[ActivationCache] = None,\n",
+ " *,\n",
+ " exclude_bos: bool = False,\n",
+ " exclude_current_token: bool = False,\n",
+ " error_measure: ErrorMeasure = \"mul\",\n",
+ ") -> torch.Tensor:\n",
+ " \"\"\"Searches the model (or a set of specific heads, for circuit analysis) for a particular type of attention head.\n",
+ " This head is specified by a detection pattern, a (sequence_length, sequence_length) tensor representing the attention pattern we expect that type of attention head to show.\n",
+        "    Instead of a tensor, the detection pattern can also be passed as the name of one of the pre-specified types of attention head (see `HeadName` for available patterns), in which case the tensor is computed within the function itself.\n",
+ "\n",
+ " There are two error measures available for quantifying the match between the detection pattern and the actual attention pattern.\n",
+ "\n",
+ " 1. `\"mul\"` (default) multiplies both tensors element-wise and divides the sum of the result by the sum of the attention pattern.\n",
+ " Typically, the detection pattern should in this case contain only ones and zeros, which allows a straightforward interpretation of the score:\n",
+        "    what fraction of this head's attention is allocated to these specific query-key pairs?\n",
+ " Using values other than 0 or 1 is not prohibited but will raise a warning (which can be disabled, of course).\n",
+ " 2. `\"abs\"` calculates the mean element-wise absolute difference between the detection pattern and the actual attention pattern.\n",
+        "    The \"raw result\" ranges from 0 to 2, where a lower score corresponds to greater accuracy. Subtracting it from 1 maps that range to the [-1, 1] interval,\n",
+ " with 1 being perfect match and -1 perfect mismatch.\n",
+ "\n",
+        "    **Which one should you use?** `\"mul\"` is likely better for quick or exploratory investigations. For precise examinations where you're trying to\n",
+ " reproduce as much functionality as possible or really test your understanding of the attention head, you probably want to switch to `\"abs\"`.\n",
+ "\n",
+ " The advantage of `\"abs\"` is that you can make more precise predictions, and have that measured in the score.\n",
+ " You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and your score will be better if your prediction is closer.\n",
+        "    The \"mul\" metric does not allow this: you'll get the same score whether the attention is (0.2, 0.8), (0.5, 0.5), or (0.8, 0.2).\n",
+ "\n",
+ " Args:\n",
+ " ----------\n",
+ " model: Model being used.\n",
+ " seq: String or list of strings being fed to the model.\n",
+ " detection_pattern: (sequence_length, sequence_length) Tensor representing what attention pattern corresponds to the head we're looking for **or** the name of a pre-specified head. Currently available heads are: `[\"previous_token_head\", \"duplicate_token_head\", \"induction_head\"]`.\n",
+        "    heads: If specific attention heads are given here, all other heads' scores are set to -1. Useful for IOI-style circuit analysis. Heads can be specified as a list of (layer, head) tuples or as a dictionary mapping a layer to the heads within that layer that we want to analyze.\n",
+ " cache: Include the cache to save time if you want.\n",
+ " exclude_bos: Exclude attention paid to the beginning of sequence token.\n",
+ " exclude_current_token: Exclude attention paid to the current token.\n",
+ " error_measure: `\"mul\"` for using element-wise multiplication (default). `\"abs\"` for using absolute values of element-wise differences as the error measure.\n",
+ "\n",
+ " Returns:\n",
+ " ----------\n",
+ " A (n_layers, n_heads) Tensor representing the score for each attention head.\n",
+ "\n",
+ " Example:\n",
+ " --------\n",
+ " .. code-block:: python\n",
+ "\n",
+ " >>> from transformer_lens import utils\n",
+ " >>> from transformer_lens.model_bridge import TransformerBridge\n",
+ " >>> from transformer_lens.head_detector import detect_head\n",
+ " >>> import plotly.express as px\n",
+ "\n",
+ " >>> def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
+ " >>> px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
+ "\n",
+ " >>> model = TransformerBridge.boot_transformers(\"gpt2\")\n",
+ " >>> model.enable_compatibility_mode()\n",
+ " >>> sequence = \"This is a test sequence. This is a test sequence.\"\n",
+ "\n",
+ " >>> attention_score = detect_head(model, sequence, \"previous_token_head\")\n",
+ " >>> imshow(attention_score, zmin=-1, zmax=1, xaxis=\"Head\", yaxis=\"Layer\", title=\"Previous Head Matches\")\n",
+ " \"\"\"\n",
+ "\n",
+ " cfg = model.cfg\n",
+ " tokens = model.to_tokens(seq).to(cfg.device)\n",
+ " seq_len = tokens.shape[-1]\n",
+        "\n",
+        "    # Validate error_measure\n",
+        "\n",
+ " assert error_measure in get_args(ErrorMeasure), f\"Invalid {error_measure=}; valid values are {get_args(ErrorMeasure)}\"\n",
+ "\n",
+ " # Validate detection pattern if it's a string\n",
+ " if isinstance(detection_pattern, str):\n",
+ " assert detection_pattern in HEAD_NAMES, (\n",
+ " INVALID_HEAD_NAME_ERR % detection_pattern\n",
+ " )\n",
+ " if isinstance(seq, list):\n",
+        "            # Recurse per sequence, forwarding the scoring options, then average the scores.\n",
+        "            batch_scores = [detect_head(model, single_seq, detection_pattern, heads=heads, exclude_bos=exclude_bos, exclude_current_token=exclude_current_token, error_measure=error_measure) for single_seq in seq]\n",
+ " return torch.stack(batch_scores).mean(0)\n",
+ " detection_pattern = cast(\n",
+ " torch.Tensor,\n",
+ " eval(f\"get_{detection_pattern}_detection_pattern(tokens.cpu())\"),\n",
+ " ).to(cfg.device)\n",
+ "\n",
+ " # if we're using \"mul\", detection_pattern should consist of zeros and ones\n",
+ " if error_measure == \"mul\" and not set(detection_pattern.unique().tolist()).issubset(\n",
+ " {0, 1}\n",
+ " ):\n",
+ " logging.warning(\n",
+ " \"Using detection pattern with values other than 0 or 1 with error_measure 'mul'\"\n",
+ " )\n",
+ "\n",
+ " # Validate inputs and detection pattern shape\n",
+ " assert 1 < tokens.shape[-1] < cfg.n_ctx, SEQ_LEN_ERR\n",
+ " assert (\n",
+ " is_lower_triangular(detection_pattern) and seq_len == detection_pattern.shape[0]\n",
+ " ), DET_PAT_NOT_SQUARE_ERR % (seq_len, detection_pattern.shape)\n",
+ "\n",
+ " if cache is None:\n",
+ " _, cache = model.run_with_cache(tokens, remove_batch_dim=True)\n",
+ "\n",
+ " if heads is None:\n",
+ " layer2heads = {\n",
+ " layer_i: list(range(cfg.n_heads)) for layer_i in range(cfg.n_layers)\n",
+ " }\n",
+ " elif isinstance(heads, list):\n",
+ " layer2heads = defaultdict(list)\n",
+ " for layer, head in heads:\n",
+ " layer2heads[layer].append(head)\n",
+ " else:\n",
+ " layer2heads = heads\n",
+ "\n",
+ " matches = -torch.ones(cfg.n_layers, cfg.n_heads)\n",
+ "\n",
+ " for layer, layer_heads in layer2heads.items():\n",
+ " # [n_heads q_pos k_pos]\n",
+ " layer_attention_patterns = cache[\"pattern\", layer, \"attn\"]\n",
+ " for head in layer_heads:\n",
+ " head_attention_pattern = layer_attention_patterns[head, :, :]\n",
+ " head_score = compute_head_attention_similarity_score(\n",
+ " head_attention_pattern,\n",
+ " detection_pattern=detection_pattern,\n",
+ " exclude_bos=exclude_bos,\n",
+ " exclude_current_token=exclude_current_token,\n",
+ " error_measure=error_measure,\n",
+ " )\n",
+ " matches[layer, head] = head_score\n",
+ " return matches\n",
+ "\n",
+ "\n",
+ "# Previous token head\n",
+ "def get_previous_token_head_detection_pattern(\n",
+ " tokens: torch.Tensor, # [batch (1) x pos]\n",
+ ") -> torch.Tensor:\n",
+        "    \"\"\"Outputs a detection pattern for [previous token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=0O5VOHe9xeZn8Ertywkh7ioc).\n",
+ "\n",
+ " Args:\n",
+ " tokens: Tokens being fed to the model.\n",
+ " \"\"\"\n",
+ " detection_pattern = torch.zeros(tokens.shape[-1], tokens.shape[-1])\n",
+ " # Adds a diagonal of 1's below the main diagonal.\n",
+ " detection_pattern[1:, :-1] = torch.eye(tokens.shape[-1] - 1)\n",
+ " return torch.tril(detection_pattern)\n",
+ "\n",
+ "\n",
+ "# Duplicate token head\n",
+ "def get_duplicate_token_head_detection_pattern(\n",
+ " tokens: torch.Tensor, # [batch (1) x pos]\n",
+ ") -> torch.Tensor:\n",
+        "    \"\"\"Outputs a detection pattern for [duplicate token heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=2UkvedzOnghL5UHUgVhROxeo).\n",
+ "\n",
+ " Args:\n",
+        "      tokens: Tokens being fed to the model.\n",
+ " \"\"\"\n",
+ " # [pos x pos]\n",
+ " token_pattern = tokens.repeat(tokens.shape[-1], 1).numpy()\n",
+ "\n",
+ " # If token_pattern[i][j] matches its transpose, then token j and token i are duplicates.\n",
+ " eq_mask = np.equal(token_pattern, token_pattern.T).astype(int)\n",
+ "\n",
+ " np.fill_diagonal(\n",
+ " eq_mask, 0\n",
+ " ) # Current token is always a duplicate of itself. Ignore that.\n",
+ " detection_pattern = eq_mask.astype(int)\n",
+ " return torch.tril(torch.as_tensor(detection_pattern).float())\n",
+ "\n",
+ "\n",
+ "# Induction head\n",
+ "def get_induction_head_detection_pattern(\n",
+ " tokens: torch.Tensor, # [batch (1) x pos]\n",
+ ") -> torch.Tensor:\n",
+        "    \"\"\"Outputs a detection pattern for [induction heads](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=_tFVuP5csv5ORIthmqwj0gSY).\n",
+ "\n",
+ " Args:\n",
+ " sequence: String being fed to the model.\n",
+ " \"\"\"\n",
+ " duplicate_pattern = get_duplicate_token_head_detection_pattern(tokens)\n",
+ "\n",
+ " # Shift all items one to the right\n",
+ " shifted_tensor = torch.roll(duplicate_pattern, shifts=1, dims=1)\n",
+ "\n",
+ " # Replace first column with 0's\n",
+ " # we don't care about bos but shifting to the right moves the last column to the first,\n",
+ " # and the last column might contain non-zero values.\n",
+ " zeros_column = torch.zeros(duplicate_pattern.shape[0], 1)\n",
+ " result_tensor = torch.cat((zeros_column, shifted_tensor[:, 1:]), dim=1)\n",
+ " return torch.tril(result_tensor)\n",
+ "\n",
+ "\n",
+ "def get_supported_heads() -> None:\n",
+ " \"\"\"Returns a list of supported heads.\"\"\"\n",
+ " print(f\"Supported heads: {HEAD_NAMES}\")\n",
+ "\n",
+ "\n",
+ "def compute_head_attention_similarity_score(\n",
+ " attention_pattern: torch.Tensor, # [q_pos k_pos]\n",
+ " detection_pattern: torch.Tensor, # [seq_len seq_len] (seq_len == q_pos == k_pos)\n",
+ " *,\n",
+ " exclude_bos: bool,\n",
+ " exclude_current_token: bool,\n",
+ " error_measure: ErrorMeasure,\n",
+ ") -> float:\n",
+ " \"\"\"Compute the similarity between `attention_pattern` and `detection_pattern`.\n",
+ "\n",
+ " Args:\n",
+ " attention_pattern: Lower triangular matrix (Tensor) representing the attention pattern of a particular attention head.\n",
+ " detection_pattern: Lower triangular matrix (Tensor) representing the attention pattern we are looking for.\n",
+ " exclude_bos: `True` if the beginning-of-sentence (BOS) token should be omitted from comparison. `False` otherwise.\n",
+ " exclude_bcurrent_token: `True` if the current token at each position should be omitted from comparison. `False` otherwise.\n",
+ " error_measure: \"abs\" for using absolute values of element-wise differences as the error measure. \"mul\" for using element-wise multiplication (legacy code).\n",
+ " \"\"\"\n",
+ " assert is_square(\n",
+ " attention_pattern\n",
+ " ), f\"Attention pattern is not square; got shape {attention_pattern.shape}\"\n",
+ "\n",
+ " # mul\n",
+ "\n",
+ " if error_measure == \"mul\":\n",
+ " if exclude_bos:\n",
+ " attention_pattern[:, 0] = 0\n",
+ " if exclude_current_token:\n",
+ " attention_pattern.fill_diagonal_(0)\n",
+ " score = attention_pattern * detection_pattern\n",
+ " return (score.sum() / attention_pattern.sum()).item()\n",
+ "\n",
+ " # abs\n",
+ "\n",
+ " abs_diff = (attention_pattern - detection_pattern).abs()\n",
+ " assert (abs_diff - torch.tril(abs_diff).to(abs_diff.device)).sum() == 0\n",
+ "\n",
+ " size = len(abs_diff)\n",
+ " if exclude_bos:\n",
+ " abs_diff[:, 0] = 0\n",
+ " if exclude_current_token:\n",
+ " abs_diff.fill_diagonal_(0)\n",
+ "\n",
+ " return 1 - round((abs_diff.mean() * size).item(), 3)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Bw4CZS-tCH7u"
+ },
+ "source": [
+ "## Using Head Detector For Premade Heads\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "A0iDohDUmS_r"
+ },
+ "source": [
+ "Load the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "nIiUfx76I6a1",
+ "outputId": "85bf4ea6-0c27-4f3f-dfe5-b173dd3b70e0"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using pad_token, but it is not set yet.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loaded pretrained model gpt2-small into HookedTransformer\n"
+ ]
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "model = TransformerBridge.boot_transformers(\"gpt2\", device=device)\n",
+ "model.enable_compatibility_mode()"
]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "plot_attn_pattern_from_cache(cache, 5)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 339
- },
- "id": "wMbU3X5OoNO3",
- "outputId": "659cc27d-5fcc-4523-c767-500e1cf9e2ae"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Layer 0 Attention Heads:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " "
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cgqKW_kWmPWX"
+ },
+ "source": [
+ "See what heads are supported out of the box"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "3i0kKYngmLru",
+ "outputId": "72a44f58-8a6b-4551-bb38-ddc177f5fd25"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Supported heads: ('previous_token_head', 'duplicate_token_head', 'induction_head')\n"
+ ]
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "get_supported_heads()"
]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "plot_attn_pattern_from_cache(cache, 0)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Gtxc-sUlN0lo"
- },
- "source": [
- "Here, we see some interesting results. L5H6 does very little, but happens to react quite strongly to the first token of \"Trans|former\". (Capital letters? Current word detection? We don't know)\n",
- "\n",
- "L0H3 reacts almost entirely to the current token, but what little it does outside of this pays attention to the previous token. Again, it seems to be caring about the first token of \"Trans|former\".\n",
- "\n",
- "In order to more fully automate these heads, we'll need to discover more principled ways of expressing these scores. For now, you can see how while scores may be misleading, different scores lead us to interesting results."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Dzvf8UsiwR18"
- },
- "source": [
- "## Using Head Detector for Custom Heads\n",
- "\n",
- "These heads are great, but sometimes there are more than three things going on in Transformers. [citation needed] As a result, we may want to use our head detector for things that aren't pre-included in TransformerLens. Fortunately, the head detector provides support for this, via **detection patterns**.\n",
- "\n",
- "\n",
- "A detection pattern is simply a matrix of the same size as our attention pattern, which specifies the attention pattern exhibited by the kind of head we're looking for.\n",
- "\n",
- "There are two error measures available for quantifying the match between the detection pattern and the actual attention pattern. You can choose it by passing the right value to the `error_measure` argument.\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "aM8NGXj7wRs_"
- },
- "source": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "OFLBWOw3wUWb"
- },
- "source": [
- "\n",
- "### 1. `\"mul\"` (default) multiplies both tensors element-wise and divides the sum of the result by the sum of the attention pattern.\n",
- "\n",
- "Typically, the detection pattern should in this case contain only ones and zeros, which allows a straightforward interpretation of the score: how big fraction of this head's attention is allocated to these specific query-key pairs? Using values other than 0 or 1 is not prohibited but will raise a warning (which can be disabled, of course).\n",
- "\n",
- "
\n",
- "\n",
- "$$\n",
- "\\begin{pmatrix}\n",
- "1 & 0 & 0 & 0 \\\\\n",
- "0.5 & 0.5 & 0 & 0 \\\\\n",
- "0.2 & 0.3 & 0.5 & 0 \\\\\n",
- "0.1 & 0.15 & 0.5 & 0.25\n",
- "\\end{pmatrix}\n",
- "\\odot\n",
- "\\begin{pmatrix}\n",
- "0 & 0 & 0 & 0 \\\\\n",
- "1 & 0 & 0 & 0 \\\\\n",
- "0 & 1 & 0 & 0 \\\\\n",
- "0 & 0 & 1 & 0\n",
- "\\end{pmatrix}\n",
- "=\n",
- "\\begin{pmatrix}\n",
- "0 & 0 & 0 & 0 \\\\\n",
- "0.5 & 0 & 0 & 0 \\\\\n",
- "0 & 0.3 & 0 & 0 \\\\\n",
- "0 & 0 & 0.5 & 0\n",
- "\\end{pmatrix}\n",
- "$$\n",
- "\n",
- "
\n",
- "\n",
- "0.5, 0.3, and 0.5 all get multiplied by 1, so they get kept. All the others go to 0 and are removed. (Note: You can use values other than 0 or 1 when creating your own heads)\n",
- "\n",
- "Our total score would then be 1.3 / 4, or 0.325. If we ignore bos and current token, it would be 0.8 / 0.95 instead, or ~0.842. (This is a large difference, but the difference generally gets smaller as the matrices get bigger)\n",
- "\n",
- "This is how the head detector works under the hood - each existing head just has its own detection pattern. Thus, we can pass in our own detection pattern using the `detection_pattern` argument."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Mewy1K9fOmOk"
- },
- "source": [
- "\n",
- "### 2. `\"abs\"` calculates the mean element-wise absolute difference between the detection pattern and the actual attention pattern.\n",
- "\n",
- "The \"raw result\" ranges from 0 to 2 where lower score corresponds to greater accuracy. Subtracting it from 1 maps that range to (-1, 1) interval, with 1 being perfect match and -1 perfect mismatch.\n",
- "\n",
- "We take the attention pattern and compute its absolute element-wise difference with our detection pattern. Since every number in any of the two patterns has a value between -1 and 1, the maximum absolute difference of any pair is 2 and the minimum is 0:\n",
- "\n",
- "$$|-1-1|=|1-(-1)|=2$$\n",
- "\n",
- "$$|x-x|=0$$\n",
- "\n",
- "That number tells us how much our expectation and the real attention pattern diverge, i.e., the error.\n",
- "\n",
- "$$\n",
- "M_{diff}=\n",
- "\\left|\n",
- "\\begin{pmatrix}\n",
- "1 & 0 & 0 & 0\n",
- "\\\\\n",
- "0.5 & 0.5 & 0 & 0 \n",
- "\\\\\n",
- "0.2 & 0.3 & 0.5 & 0 \n",
- "\\\\\n",
- "0.1 & 0.15 & 0.5 & 0.25 \n",
- "\\end{pmatrix}\n",
- "-\n",
- "\\begin{pmatrix}\n",
- "0 & 0 & 0 & 0\n",
- "\\\\\n",
- "1 & 0 & 0 & 0 \n",
- "\\\\\n",
- "0 & 1 & 0 & 0 \n",
- "\\\\\n",
- "0 & 0 & 1 & 0 \n",
- "\\end{pmatrix}\n",
- "\\right|\n",
- "=\n",
- "\\begin{pmatrix}\n",
- "1 & 0 & 0 & 0\n",
- "\\\\\n",
- "0.5 & 0.5 & 0 & 0 \n",
- "\\\\\n",
- "0.2 & 0.7 & 0.5 & 0\n",
- "\\\\\n",
- "0.1 & 0.15 & 0.5 & 0.25 \n",
- "\\end{pmatrix}\n",
- "$$\n",
- "\n",
- "\n",
- "We take the mean and multiply it by the number of rows.\n",
- "\n",
- "We subtract the result from 1 in order to map the (0, 2) interval where lower is better to the (-1, 1) interval where higher is better.\n",
- "\n",
- "$$1 - \\text{n_rows} \\times \\text{mean}(M_{diff}) = 1 - 4 \\times 0.275 = 1 - 1.1 = -.1$$\n",
- "\n",
- "Our final score would then be -1. If we ignore `BOS` and current token, it would be 0.6625. (This is a large difference, but the difference generally gets smaller as the matrices get bigger.)\n",
- "\n",
- "This is how the head detector works under the hood - each existing head just has its own detection pattern. Thus, we can pass in our own detection pattern using the `detection_pattern` argument.\n",
- "\n",
- "I'm curious what's going on with this L0H3 result, where we mostly focus on the current token but occasionally focus on the \"Trans\" token in \"Trans|former\". Let's make a **current word head** detection pattern, which returns 1 for previous tokens that are part of the current word being looked at, and 0 for everything else."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "AS5yRsZgwtAl"
- },
- "source": [
- "### **Which one should you use?** \n",
- "\n",
- "`\"abs\"` is likely better for quick or exploratory investigations. For precise examinations where you're trying to reproduce as much functionality as possible or really test your understanding of the attention head, you probably want to switch to `\"abs\"`. \n",
- "\n",
- "The advantage of `\"abs\"` is that you can make more precise predictions, and have that measured in the score. You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and your score will be better if your prediction is closer. The \"mul\" metric does not allow this, you'll get the same score if attention is 0.2, 0.8 or 0.5, 0.5 or 0.8, 0.2."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KjeBZ9XBxsjb"
- },
- "source": [
- "Below we show how different scores these two measures can give on the same prompt. After that, we will proceed with using `\"abs\"` and will get back to `\"mul\"` at the end of the notebook."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
},
- "id": "eknvQfWmRr74",
- "outputId": "cf428f24-dd7d-4760-ed34-986f588c1411"
- },
- "outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "23 ['<|endoftext|>', 'The', ' following', ' lex', 'ical', ' sequence', ' has', ' been', ' optim', 'ised', ' for', ' the', ' maxim', 'isation', ' of', ' lo', 'qu', 'aciously', ' multit', 'oken', ' letter', ' combinations', '.']\n"
- ]
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "b88sGXh1mUvD"
+ },
+ "source": [
+ "Let's test detecting previous token head in the following prompt."
+ ]
},
{
- "data": {
- "text/plain": [
- "torch.Size([23, 23])"
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "hncQfgF8CE_i",
+ "outputId": "a925be04-74ed-4e9e-b02d-faf6d24026f0"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "prompt = \"The head detector feature for TransformerLens allows users to check for various common heads automatically, reducing the cost of discovery.\"\n",
+ "head_scores = detect_head(model, prompt, \"previous_token_head\")\n",
+ "plot_head_detection_scores(head_scores, title=\"Previous Head Matches\")"
]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "prompt = \"The following lexical sequence has been optimised for the maximisation of loquaciously multitoken letter combinations.\"\n",
- "tokens = model.to_str_tokens(prompt)\n",
- "print(len(tokens), tokens)\n",
- "detection_pattern = []\n",
- "for i in range(2):\n",
- " detection_pattern.append([0 for t in tokens]) # Ignore BOS token and first token.\n",
- "for i in range(2, len(tokens)):\n",
- " current_token = i\n",
- " previous_tokens_in_word = 0\n",
- " while not tokens[current_token].startswith(' '): # If the current token does not start with a space (and is not the first token) it's part of a word.\n",
- " previous_tokens_in_word += 1\n",
- " current_token -= 1\n",
- " # Hacky code that adds in some 1's where needed, and fills the rest of the row with 0's.\n",
- " detection_pattern.append([0 for j in range(i - previous_tokens_in_word)] + [1 for j in range(previous_tokens_in_word)] + [0 for j in range(i+1, len(tokens)+1)])\n",
- "detection_pattern = torch.as_tensor(detection_pattern).to(device)\n",
- "detection_pattern.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "id": "rOh5aUu80Ols"
- },
- "outputs": [],
- "source": [
- "_, cache = model.run_with_cache(prompt)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "cBbW960Tw7hI"
- },
- "source": [
- "`\"mul\"`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "bm9z5sR4Yu3A",
- "outputId": "b26da44a-4dcf-4489-a558-4801a7fcbcc4"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "head_scores = detect_head(\n",
- " model, \n",
- " prompt, \n",
- " detection_pattern=detection_pattern, \n",
- " exclude_bos=False, \n",
- " exclude_current_token=True, \n",
- " error_measure=\"mul\"\n",
- ")\n",
- "plot_head_detection_scores(head_scores, title=\"Current Word Head Matches (mul)\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "MgglfsyzxGFe"
- },
- "source": [
- "`\"abs\"`"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "oVzeaGEhxKrq",
- "outputId": "94dbc3b4-0b84-4c1f-be51-0427c96d0076"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "head_scores = detect_head(\n",
- " model, \n",
- " prompt, \n",
- " detection_pattern=detection_pattern, \n",
- " exclude_bos=False, \n",
- " exclude_current_token=True, \n",
- " error_measure=\"abs\"\n",
- ")\n",
- "plot_head_detection_scores(head_scores, title=\"Current Word Head Matches (abs)\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "t8AnWKW2Y9Jj"
- },
- "source": [
- "75% match for L0H3 - only 16% for L5H6. Let's check them out with our new sequence!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 339
- },
- "id": "11l4x8H0ZEEp",
- "outputId": "71e275d8-b882-4d95-a077-1b6390552e31"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Layer 5 Attention Heads:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " "
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f_iFBhRRKQIF"
+ },
+ "source": [
+ "We can see both L2H2 and L4H11 are doing a fair bit of previous token detection. Let's take a look and see if that pans out."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "id": "95mH2b43n0EZ"
+ },
+ "outputs": [],
+ "source": [
+ "_, cache = model.run_with_cache(prompt)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 339
+ },
+ "id": "S7bz-uZQKWpj",
+ "outputId": "8bb120a6-3223-4b65-9c92-fd089d3f1d4e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Layer 2 Attention Heads:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "plot_attn_pattern_from_cache(cache, 2)"
]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "plot_attn_pattern_from_cache(cache, 5)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 339
- },
- "id": "d4bTwsAgYljL",
- "outputId": "d424ea36-af30-41cd-b7d0-3f7e8c3c53b9"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Layer 0 Attention Heads:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " \n",
- "\n",
- " \n",
- " \n",
- " \n",
- " \n",
- " "
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 339
+ },
+ "id": "7OEhpa-HLZJq",
+ "outputId": "a27f6688-a121-4ca3-efe8-9f3d71bc37cd"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Layer 4 Attention Heads:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
],
- "text/plain": [
- ""
+ "source": [
+ "plot_attn_pattern_from_cache(cache, 4)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "f4Eiua9pMLok"
+ },
+ "source": [
+ "As we expected, L2H2 is doing a lot of previous token detection, but doesn't appear to be a sharp previous token detection head. L4H11, on the other hand, is pretty much perfect. In fact, the only place it seems to be putting any other attention is the very first token, where it pays attention to the BOS (*beginning-of-sentence*) token.\n",
+ "\n",
+ "Mechanistic interpretability is still a very new field, and we don't know the best ways to measure things yet. Ignoring attention paid to BOS allows us to solve problems like the above, but may also give us artifically high results for a head like L4H10, which doesn't appear to be doing much of anything, but does have a bit of previous token attention going on if you squint carefully.\n",
+ "\n",
+ "As such, the head detector supports both an `exclude_bos` and `exclude_current_token` argument, which ignores all BOS attention and all current token attention respectively. By default these are `False`, but this is a pretty arbitrary decision, so feel free to try things out! You don't need a good reason to change these arguments - pick whatever best helps you find out useful things!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "C42HUkb_NRad",
+ "outputId": "3669bb7c-0d18-45ba-8d62-18162fb70b89"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "head_scores = detect_head(model, prompt, \"previous_token_head\", exclude_bos=True, exclude_current_token=True)\n",
+ "plot_head_detection_scores(head_scores, title=\"Previous Head Matches\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oiWQjv9UNerH"
+ },
+ "source": [
+ "Now we have a lot more detection, including L0H3 and L5H6 which were unremarkable before. Let's check them out!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 339
+ },
+ "id": "cCsB8fD8oH5i",
+ "outputId": "e1ebdc6d-07a6-446d-91de-a5b197e999b5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Layer 5 Attention Heads:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "plot_attn_pattern_from_cache(cache, 5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 339
+ },
+ "id": "wMbU3X5OoNO3",
+ "outputId": "659cc27d-5fcc-4523-c767-500e1cf9e2ae"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Layer 0 Attention Heads:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "plot_attn_pattern_from_cache(cache, 0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Gtxc-sUlN0lo"
+ },
+ "source": [
+ "Here, we see some interesting results. L5H6 does very little, but happens to react quite strongly to the first token of \"Trans|former\". (Capital letters? Current word detection? We don't know)\n",
+ "\n",
+ "L0H3 reacts almost entirely to the current token, but what little it does outside of this pays attention to the previous token. Again, it seems to be caring about the first token of \"Trans|former\".\n",
+ "\n",
+ "In order to more fully automate these heads, we'll need to discover more principled ways of expressing these scores. For now, you can see how while scores may be misleading, different scores lead us to interesting results."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Dzvf8UsiwR18"
+ },
+ "source": [
+ "## Using Head Detector for Custom Heads\n",
+ "\n",
+ "These heads are great, but sometimes there are more than three things going on in Transformers. [citation needed] As a result, we may want to use our head detector for things that aren't pre-included in TransformerLens. Fortunately, the head detector provides support for this, via **detection patterns**.\n",
+ "\n",
+ "\n",
+ "A detection pattern is simply a matrix of the same size as our attention pattern, which specifies the attention pattern exhibited by the kind of head we're looking for.\n",
+ "\n",
+ "There are two error measures available for quantifying the match between the detection pattern and the actual attention pattern. You can choose it by passing the right value to the `error_measure` argument.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "aM8NGXj7wRs_"
+ },
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "OFLBWOw3wUWb"
+ },
+ "source": [
+ "\n",
+ "### 1. `\"mul\"` (default) multiplies both tensors element-wise and divides the sum of the result by the sum of the attention pattern.\n",
+ "\n",
+ "Typically, the detection pattern should in this case contain only ones and zeros, which allows a straightforward interpretation of the score: how big fraction of this head's attention is allocated to these specific query-key pairs? Using values other than 0 or 1 is not prohibited but will raise a warning (which can be disabled, of course).\n",
+ "\n",
+ "
\n",
+ "\n",
+ "$$\n",
+ "\\begin{pmatrix}\n",
+ "1 & 0 & 0 & 0 \\\\\n",
+ "0.5 & 0.5 & 0 & 0 \\\\\n",
+ "0.2 & 0.3 & 0.5 & 0 \\\\\n",
+ "0.1 & 0.15 & 0.5 & 0.25\n",
+ "\\end{pmatrix}\n",
+ "\\odot\n",
+ "\\begin{pmatrix}\n",
+ "0 & 0 & 0 & 0 \\\\\n",
+ "1 & 0 & 0 & 0 \\\\\n",
+ "0 & 1 & 0 & 0 \\\\\n",
+ "0 & 0 & 1 & 0\n",
+ "\\end{pmatrix}\n",
+ "=\n",
+ "\\begin{pmatrix}\n",
+ "0 & 0 & 0 & 0 \\\\\n",
+ "0.5 & 0 & 0 & 0 \\\\\n",
+ "0 & 0.3 & 0 & 0 \\\\\n",
+ "0 & 0 & 0.5 & 0\n",
+ "\\end{pmatrix}\n",
+ "$$\n",
+ "\n",
+ "
\n",
+ "\n",
+ "0.5, 0.3, and 0.5 all get multiplied by 1, so they get kept. All the others go to 0 and are removed. (Note: You can use values other than 0 or 1 when creating your own heads)\n",
+ "\n",
+ "Our total score would then be 1.3 / 4, or 0.325. If we ignore bos and current token, it would be 0.8 / 0.95 instead, or ~0.842. (This is a large difference, but the difference generally gets smaller as the matrices get bigger)\n",
+ "\n",
+ "This is how the head detector works under the hood - each existing head just has its own detection pattern. Thus, we can pass in our own detection pattern using the `detection_pattern` argument."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Mewy1K9fOmOk"
+ },
+ "source": [
+ "\n",
+ "### 2. `\"abs\"` calculates the mean element-wise absolute difference between the detection pattern and the actual attention pattern.\n",
+ "\n",
+ "The \"raw result\" ranges from 0 to 2 where lower score corresponds to greater accuracy. Subtracting it from 1 maps that range to (-1, 1) interval, with 1 being perfect match and -1 perfect mismatch.\n",
+ "\n",
+ "We take the attention pattern and compute its absolute element-wise difference with our detection pattern. Since every number in any of the two patterns has a value between -1 and 1, the maximum absolute difference of any pair is 2 and the minimum is 0:\n",
+ "\n",
+ "$$|-1-1|=|1-(-1)|=2$$\n",
+ "\n",
+ "$$|x-x|=0$$\n",
+ "\n",
+ "That number tells us how much our expectation and the real attention pattern diverge, i.e., the error.\n",
+ "\n",
+ "$$\n",
+ "M_{diff}=\n",
+ "\\left|\n",
+ "\\begin{pmatrix}\n",
+ "1 & 0 & 0 & 0\n",
+ "\\\\\n",
+ "0.5 & 0.5 & 0 & 0 \n",
+ "\\\\\n",
+ "0.2 & 0.3 & 0.5 & 0 \n",
+ "\\\\\n",
+ "0.1 & 0.15 & 0.5 & 0.25 \n",
+ "\\end{pmatrix}\n",
+ "-\n",
+ "\\begin{pmatrix}\n",
+ "0 & 0 & 0 & 0\n",
+ "\\\\\n",
+ "1 & 0 & 0 & 0 \n",
+ "\\\\\n",
+ "0 & 1 & 0 & 0 \n",
+ "\\\\\n",
+ "0 & 0 & 1 & 0 \n",
+ "\\end{pmatrix}\n",
+ "\\right|\n",
+ "=\n",
+ "\\begin{pmatrix}\n",
+ "1 & 0 & 0 & 0\n",
+ "\\\\\n",
+ "0.5 & 0.5 & 0 & 0 \n",
+ "\\\\\n",
+ "0.2 & 0.7 & 0.5 & 0\n",
+ "\\\\\n",
+ "0.1 & 0.15 & 0.5 & 0.25 \n",
+ "\\end{pmatrix}\n",
+ "$$\n",
+ "\n",
+ "\n",
+ "We take the mean and multiply it by the number of rows.\n",
+ "\n",
+ "We subtract the result from 1 in order to map the (0, 2) interval where lower is better to the (-1, 1) interval where higher is better.\n",
+ "\n",
+ "$$1 - \\text{n_rows} \\times \\text{mean}(M_{diff}) = 1 - 4 \\times 0.275 = 1 - 1.1 = -.1$$\n",
+ "\n",
+ "Our final score would then be -1. If we ignore `BOS` and current token, it would be 0.6625. (This is a large difference, but the difference generally gets smaller as the matrices get bigger.)\n",
+ "\n",
+ "This is how the head detector works under the hood - each existing head just has its own detection pattern. Thus, we can pass in our own detection pattern using the `detection_pattern` argument.\n",
+ "\n",
+ "I'm curious what's going on with this L0H3 result, where we mostly focus on the current token but occasionally focus on the \"Trans\" token in \"Trans|former\". Let's make a **current word head** detection pattern, which returns 1 for previous tokens that are part of the current word being looked at, and 0 for everything else."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AS5yRsZgwtAl"
+ },
+ "source": [
+ "### **Which one should you use?** \n",
+ "\n",
+ "`\"abs\"` is likely better for quick or exploratory investigations. For precise examinations where you're trying to reproduce as much functionality as possible or really test your understanding of the attention head, you probably want to switch to `\"abs\"`. \n",
+ "\n",
+ "The advantage of `\"abs\"` is that you can make more precise predictions, and have that measured in the score. You can predict, for instance, 0.2 attention to X, and 0.8 attention to Y, and your score will be better if your prediction is closer. The \"mul\" metric does not allow this, you'll get the same score if attention is 0.2, 0.8 or 0.5, 0.5 or 0.8, 0.2."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KjeBZ9XBxsjb"
+ },
+ "source": [
+ "Below we show how different scores these two measures can give on the same prompt. After that, we will proceed with using `\"abs\"` and will get back to `\"mul\"` at the end of the notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "eknvQfWmRr74",
+ "outputId": "cf428f24-dd7d-4760-ed34-986f588c1411"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "23 ['<|endoftext|>', 'The', ' following', ' lex', 'ical', ' sequence', ' has', ' been', ' optim', 'ised', ' for', ' the', ' maxim', 'isation', ' of', ' lo', 'qu', 'aciously', ' multit', 'oken', ' letter', ' combinations', '.']\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([23, 23])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "prompt = \"The following lexical sequence has been optimised for the maximisation of loquaciously multitoken letter combinations.\"\n",
+ "tokens = model.to_str_tokens(prompt)\n",
+ "print(len(tokens), tokens)\n",
+ "detection_pattern = []\n",
+ "for i in range(2):\n",
+ " detection_pattern.append([0 for t in tokens]) # Ignore BOS token and first token.\n",
+ "for i in range(2, len(tokens)):\n",
+ " current_token = i\n",
+ " previous_tokens_in_word = 0\n",
+ " while not tokens[current_token].startswith(' '): # If the current token does not start with a space (and is not the first token) it's part of a word.\n",
+ " previous_tokens_in_word += 1\n",
+ " current_token -= 1\n",
+ " # Hacky code that adds in some 1's where needed, and fills the rest of the row with 0's.\n",
+ " detection_pattern.append([0 for j in range(i - previous_tokens_in_word)] + [1 for j in range(previous_tokens_in_word)] + [0 for j in range(i+1, len(tokens)+1)])\n",
+ "detection_pattern = torch.as_tensor(detection_pattern).to(device)\n",
+ "detection_pattern.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "id": "rOh5aUu80Ols"
+ },
+ "outputs": [],
+ "source": [
+ "_, cache = model.run_with_cache(prompt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cBbW960Tw7hI"
+ },
+ "source": [
+ "`\"mul\"`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "bm9z5sR4Yu3A",
+ "outputId": "b26da44a-4dcf-4489-a558-4801a7fcbcc4"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "head_scores = detect_head(\n",
+ " model, \n",
+ " prompt, \n",
+ " detection_pattern=detection_pattern, \n",
+ " exclude_bos=False, \n",
+ " exclude_current_token=True, \n",
+ " error_measure=\"mul\"\n",
+ ")\n",
+ "plot_head_detection_scores(head_scores, title=\"Current Word Head Matches (mul)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MgglfsyzxGFe"
+ },
+ "source": [
+ "`\"abs\"`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "oVzeaGEhxKrq",
+ "outputId": "94dbc3b4-0b84-4c1f-be51-0427c96d0076"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "head_scores = detect_head(\n",
+ " model, \n",
+ " prompt, \n",
+ " detection_pattern=detection_pattern, \n",
+ " exclude_bos=False, \n",
+ " exclude_current_token=True, \n",
+ " error_measure=\"abs\"\n",
+ ")\n",
+ "plot_head_detection_scores(head_scores, title=\"Current Word Head Matches (abs)\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "t8AnWKW2Y9Jj"
+ },
+ "source": [
+ "75% match for L0H3 - only 16% for L5H6. Let's check them out with our new sequence!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 339
+ },
+ "id": "11l4x8H0ZEEp",
+ "outputId": "71e275d8-b882-4d95-a077-1b6390552e31"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Layer 5 Attention Heads:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "plot_attn_pattern_from_cache(cache, 5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 339
+ },
+ "id": "d4bTwsAgYljL",
+ "outputId": "d424ea36-af30-41cd-b7d0-3f7e8c3c53b9"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Layer 0 Attention Heads:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "plot_attn_pattern_from_cache(cache, 0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TMwvCbW-Rae_"
+ },
+ "source": [
+ "As we can see, L5H6 appears to be doing something totally different from what we expected, whereas L0H3 is mostly doing what we expected. By our original hypothesis, we would expect \"lo|qu|aciously\" to have a lot of attention paid to it, and likewise for \"combinations|.\", which didn't happen. However, our two-token words behaved exactly as expected. Could this be a two-token detector (that doesn't work on punctuation)? A \"current word\" detector that just doesn't understand an obscure word like \"loquaciously\"? The field is full of such questions, just waiting to be answered!\n",
+ "\n",
+ "So, why do this at all? For just a couple of sentences, it's easier to just look at the attention patterns directly and see what we get. But as we can see, heads react differently to different sentences. What we might want to do is give an entire dataset or distribution of sentences to our attention head and see that it consistently does what we want - that's something that would be much harder without this feature!\n",
+ "\n",
+ "So what if we gave it a whole distribution? Rather than actually create one, which is not the point of this demo, we're just going to repeat our last sentence a hundred times."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 560
+ },
+ "id": "LrgxDy7C7p-n",
+ "outputId": "fa0983a4-1d67-4903-a73a-2c09bbc891a7"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 100/100 [00:13<00:00, 7.64it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "scores = []\n",
+ "for i in tqdm(range(100)):\n",
+ " scores.append(detect_head(model, prompt, detection_pattern=detection_pattern, exclude_bos=False, exclude_current_token=True, error_measure=\"abs\"))\n",
+ "scores = torch.stack(scores).mean(dim=0)\n",
+ "plot_head_detection_scores(scores, title=\"Current Word Head Matches\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AGvX96qf18H3"
+ },
+ "source": [
+ "## Processing Many Prompts\n",
+ "\n",
+ "`detect_head` can also take more than one prompt. The resulting attention score is the mean of scores for each prompt."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {
+ "id": "xKsRWJyi4nKb"
+ },
+ "outputs": [],
+ "source": [
+ "prompts = [\n",
+ " \"This is the first the test prompt.\",\n",
+ " \"This is another test prompt, being just a sequence of tokens.\",\n",
+ " \"If you're interested in mechanistic interpretability, this is how the sausage REALLY is made.\"\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "bDCNbAKn8O6c",
+ "outputId": "3f3a69e9-6909-4d1b-ad14-52081f7b3fb3"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "head_scores = detect_head(model, prompts, \"previous_token_head\", error_measure=\"abs\")\n",
+ "plot_head_detection_scores(head_scores, title=\"Previous token head; average across 3 prompts\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vzsyyU892l1m"
+ },
+ "source": [
+ "L4H11 emerges again as the dominant head, exactly as expected."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "50VyBU3u408u"
+ },
+ "source": [
+ "What about duplicate token heads?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "a2Tvp9-a4oZC",
+ "outputId": "d1cd9693-cebc-4b03-edbf-320eeb8b4084"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "head_scores = detect_head(model, prompts, \"duplicate_token_head\", error_measure=\"abs\")\n",
+ "plot_head_detection_scores(head_scores, title=\"Duplicate token head; average across 3 prompts\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JeKiWg41423q"
+ },
+ "source": [
+ "Nothing. But this should be expected in hindsight, since our prompts don't contain many duplicate tokens. Let's try three other prompts that do."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {
+ "id": "afe4NxXM5ESp"
+ },
+ "outputs": [],
+ "source": [
+ "prompts = [\n",
+ " \"one two three one two three one two three\",\n",
+ " \"1 2 3 4 5 1 2 3 4 1 2 3 1 2 3 4 5 6 7\",\n",
+ " \"green ideas sleep furiously; green ideas don't sleep furiously\"\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "0LpotLqt5TRj",
+ "outputId": "d1ea2496-93e7-4e9c-c915-708d279cd699"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "head_scores = detect_head(model, prompts, \"duplicate_token_head\", exclude_bos=False, exclude_current_token=False, error_measure=\"abs\")\n",
+ "plot_head_detection_scores(head_scores, title=\"Duplicate token head; average across 3 prompts\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9cxe4I5V4wya"
+ },
+ "source": [
+ "Three or four heads seem to do something we would expect from a duplicate token head, but the signal is not very strong. You can tweak the `exclude_bos` and `exclude_current_token` flags if you want, but it doesn't change much."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GrNd4zSw6FxL"
+ },
+ "source": [
+ "Let's hunt for induction heads now!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "ixfjDS4n6Jd7",
+ "outputId": "ce750192-3c4f-40bf-f6b7-e7423ec38ada"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "head_scores = detect_head(model, prompts, \"induction_head\", exclude_bos=False, exclude_current_token=False, error_measure=\"abs\")\n",
+ "plot_head_detection_scores(head_scores, title=\"Induction head; average across 3 prompts\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JDL4I2hj6P5z"
+ },
+ "source": [
+ "Similarly weak, at least on average.\n",
+ "\n",
+ "Try running the notebook on different prompts and see if you can get high scores for duplicate token or induction heads."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Ys0pdGBI6min"
+ },
+ "source": [
+ "## Why not element-wise multiplication? Robustness against [Goodharting](https://en.wikipedia.org/wiki/Goodhart%27s_law)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TtFepOB474yd"
+ },
+ "source": [
+ "Initially, the error measure was not the mean element-wise absolute error (normalized by the number of rows) but the mean [element-wise product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29). However, this had problems, such as susceptibility to Goodharting: you can specify a pattern consisting of all ones and thereby achieve a perfect match for all layers and heads in the model.\n",
+ "\n",
+ "More generally, the element-wise product causes the score to go down as we narrow our hypothesis; we can get a maximum score by just predicting 1 for everything."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "id": "QAsossB28q7v"
+ },
+ "outputs": [],
+ "source": [
+ "prompt = \"The head detector feature for TransformerLens allows users to check for various common heads automatically, reducing the cost of discovery.\"\n",
+ "seq_len = len(model.to_str_tokens(prompt))\n",
+ "# torch.tril to make the pattern lower triangular\n",
+ "ones_detection_pattern = torch.tril(torch.ones(seq_len, seq_len).to(device))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "5BCWs0QX61gH",
+ "outputId": "3e950264-d6f7-4d5d-a0f3-570aa4a1e3e8"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "ones_head_scores = detect_head(\n",
+ " model, \n",
+ " prompt, \n",
+ " ones_detection_pattern, \n",
+ " exclude_bos=True, \n",
+ " exclude_current_token=True, \n",
+ ")\n",
+ "plot_head_detection_scores(ones_head_scores, title=\"Transformers Have Now Been Solved, We Can All Go Home\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PaYJhx6W8l90"
+ },
+ "source": [
+ "The new error measure also yields a uniform score, but this time it's uniformly extremely negative, because **not a single head in the model matches this pattern**.\n",
+ "\n",
+ "*(It's true that the scores descend below -9 whereas in theory they should remain within the (-1, 1) range. It's not yet clear if that matters for real-world uses.)*\n",
+ "\n",
+ "An alternative would be to demand that *predictions add up to 1 for each row* but that seems unnecessarily nitpicky considering that your score will get reduced in general for not doing that anyway.\n",
+ "\n",
+ "Mean squared error was also tried before converging on the absolute error. The problem with MSE is that scores shrink as attention gets more diffuse: an error of 1 stays 1, but an error of 0.5 becomes 0.25, etc."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 542
+ },
+ "id": "U722j1mJ9TbC",
+ "outputId": "8525577e-d060-4cf9-c355-e74cda383ae8"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "ones_head_scores = detect_head(\n",
+ " model, \n",
+ " prompt, \n",
+ " ones_detection_pattern, \n",
+ " exclude_bos=True, \n",
+ " exclude_current_token=True, \n",
+ " error_measure=\"abs\" # we specify the error measure here\n",
+ ")\n",
+ "plot_head_detection_scores(ones_head_scores, title=\"Transformers Have Not Been Solved Yet, Get Back To Work!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hcXmbyWJ6knn"
+ },
+ "source": [
+ "## Further improvements"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fkxMLVehaZOj"
+ },
+ "source": [
+ "**Performance for large distributions** isn't as good as it could be. The head detector could be rewritten to support taking in a list of sequences and performing these computations in parallel, but 1000 sequences per minute is certainly adequate for most use cases. If having this be faster would help your research, please write up an issue on TransformerLens, mention it on the Open Source Mechanistic Interpretability Slack, or e-mail jaybaileycs@gmail.com."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zXrc4KA5A5c9"
+ },
+ "source": [
+ "### Other\n",
+ "\n",
+ "- Extending to few-shot learning/translation heads\n",
+ "- More pre-specified heads?\n",
+ "- For inspiration, see [this post from Neel](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj/p/btasQF7wiCYPsr5qw)"
]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "plot_attn_pattern_from_cache(cache, 0)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TMwvCbW-Rae_"
- },
- "source": [
- "As we can see, L5H6 appears to be doing something totally different than we expected, whereas L0H3 is mostly doing what we expected - by our original hypothesis, we would expect \"lo|qu|aciously\" to have a lot of attention paid to, and \"combinations|.\" the same, which didn't happen. However, our two-token words were exactly as we expected. Could this be a two-token detector (that doesn't work on punctuation)? A \"current word\" detector that just doesn't understand an obscure word like \"loquaciously\"? The field is full of such problems, just waiting to be answered!\n",
- "\n",
- "So, why do this at all? For just a couple of sentences, it's easier to just look at the attention patterns directly and see what we get. But as we can see, heads react differently to different sentences. What we might want to do is give an entire dataset or distribution of sentences to our attention head and see that it consistently does what we want - that's something that would be much harder without this feature!\n",
- "\n",
- "So what if we gave it a whole distribution? Rather than actually create one, which is not the point of this demo, we're just going to repeat our last sentence a thousand times."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 560
- },
- "id": "LrgxDy7C7p-n",
- "outputId": "fa0983a4-1d67-4903-a73a-2c09bbc891a7"
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 100/100 [00:13<00:00, 7.64it/s]\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "scores = []\n",
- "for i in tqdm(range(100)):\n",
- " scores.append(detect_head(model, prompt, detection_pattern=detection_pattern, exclude_bos=False, exclude_current_token=True, error_measure=\"abs\"))\n",
- "scores = torch.stack(scores).mean(dim=0)\n",
- "plot_head_detection_scores(scores, title=\"Current Word Head Matches\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "AGvX96qf18H3"
- },
- "source": [
- "## Processing Many Prompts\n",
- "\n",
- "`detect_head` can also take more than one prompt. The resulting attention score is the mean of scores for each prompt."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "id": "xKsRWJyi4nKb"
- },
- "outputs": [],
- "source": [
- "prompts = [\n",
- " \"This is the first the test prompt.\",\n",
- " \"This is another test prompt, being just a sequence of tokens.\",\n",
- " \"If you're interested in mechanistic interpretability, this is how the sausage REALLY is made.\"\n",
- "]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "bDCNbAKn8O6c",
- "outputId": "3f3a69e9-6909-4d1b-ad14-52081f7b3fb3"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "head_scores = detect_head(model, prompts, \"previous_token_head\", error_measure=\"abs\")\n",
- "plot_head_detection_scores(head_scores, title=\"Previous token head; average across 3 prompts\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "vzsyyU892l1m"
- },
- "source": [
- "L4H11 emerges again as the dominant head, exactly as expected."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "50VyBU3u408u"
- },
- "source": [
- "What about duplicate token heads?"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "a2Tvp9-a4oZC",
- "outputId": "d1cd9693-cebc-4b03-edbf-320eeb8b4084"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "head_scores = detect_head(model, prompts, \"duplicate_token_head\", error_measure=\"abs\")\n",
- "plot_head_detection_scores(head_scores, title=\"Duplicate token head; average across 3 prompts\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "JeKiWg41423q"
- },
- "source": [
- "Nothing but this should be expected, in hindsight, since our prompts don't contain too many duplicate tokens. Let's try three other prompts that do."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "metadata": {
- "id": "afe4NxXM5ESp"
- },
- "outputs": [],
- "source": [
- "prompts = [\n",
- " \"one two three one two three one two three\",\n",
- " \"1 2 3 4 5 1 2 3 4 1 2 3 1 2 3 4 5 6 7\",\n",
- " \"green ideas sleep furiously; green ideas don't sleep furiously\"\n",
- "]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "0LpotLqt5TRj",
- "outputId": "d1ea2496-93e7-4e9c-c915-708d279cd699"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "head_scores = detect_head(model, prompts, \"duplicate_token_head\", exclude_bos=False, exclude_current_token=False, error_measure=\"abs\")\n",
- "plot_head_detection_scores(head_scores, title=\"Duplicate token head; average across 3 prompts\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9cxe4I5V4wya"
- },
- "source": [
- "3 or 4 heads seem to do something that we would expected from a duplicate token head but the signal is not very strong. You can tweak the `exclude_bos` and `exclude_current_token` flags if you want, but it doesn't change much."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "GrNd4zSw6FxL"
- },
- "source": [
- "Let's hunt for induction heads now!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "ixfjDS4n6Jd7",
- "outputId": "ce750192-3c4f-40bf-f6b7-e7423ec38ada"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "head_scores = detect_head(model, prompts, \"induction_head\", exclude_bos=False, exclude_current_token=False, error_measure=\"abs\")\n",
- "plot_head_detection_scores(head_scores, title=\"Duplicate token head; average across 3 prompts\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "JDL4I2hj6P5z"
- },
- "source": [
- "Similarly, at least on average.\n",
- "\n",
- "Try running the script on different prompts and see if you can get high values for duplicate token or induction heads."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Ys0pdGBI6min"
- },
- "source": [
- "## Why not element-wise multiplication - robustness against [Goodharting](https://en.wikipedia.org/wiki/Goodhart%27s_law)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TtFepOB474yd"
- },
- "source": [
- "Initially, the error measure was not the mean element-wise absolute value error (normalized to the number of rows) but the mean [element-wise product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)). However, it had its problems, such as susceptibility to Goodharting. You can specify a pattern consisting of all ones and in this way achieve a perfect match for all layers and heads in the model.\n",
- "\n",
- "More generally, using element-wise product causes the score to go down when we narrow our hypothesis. We can get a maximum score by just predicting 1 for everything. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {
- "id": "QAsossB28q7v"
- },
- "outputs": [],
- "source": [
- "prompt = \"The head detector feature for TransformerLens allows users to check for various common heads automatically, reducing the cost of discovery.\"\n",
- "seq_len = len(model.to_str_tokens(prompt))\n",
- "# torch.tril to make the pattern lower triangular\n",
- "ones_detection_pattern = torch.tril(torch.ones(seq_len, seq_len).to(device))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "5BCWs0QX61gH",
- "outputId": "3e950264-d6f7-4d5d-a0f3-570aa4a1e3e8"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
}
- ],
- "source": [
- "ones_head_scores = detect_head(\n",
- " model, \n",
- " prompt, \n",
- " ones_detection_pattern, \n",
- " exclude_bos=True, \n",
- " exclude_current_token=True, \n",
- ")\n",
- "plot_head_detection_scores(ones_head_scores, title=\"Transformers Have Now Been Solved, We Can All Go Home\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "PaYJhx6W8l90"
- },
- "source": [
- "The new error measure also achieves uniform score but this time its uniformly extremely negative because **not a single head in the model matches this pattern**.\n",
- "\n",
- "*(It's true that the scores descend below -9 whereas in theory they should remain within the (-1, 1) range. It's not yet clear if that matters for real-world uses.)*\n",
- "\n",
- "An alternative would be to demand that *predictions add up to 1 for each row* but that seems unnecessarily nitpicky considering that your score will get reduced in general for not doing that anyway.\n",
- "\n",
- "Mean squared errors have also bean tried before converging on the absolute ones. The problem with MSE is that the scores get lower as attention gets more diffuse. Error value of 1 would become 1, 0.5 would become 0.25 etc."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {
+ ],
+ "metadata": {
+ "accelerator": "GPU",
"colab": {
- "base_uri": "https://localhost:8080/",
- "height": 542
- },
- "id": "U722j1mJ9TbC",
- "outputId": "8525577e-d060-4cf9-c355-e74cda383ae8"
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.6"
}
- ],
- "source": [
- "ones_head_scores = detect_head(\n",
- " model, \n",
- " prompt, \n",
- " ones_detection_pattern, \n",
- " exclude_bos=True, \n",
- " exclude_current_token=True, \n",
- " error_measure=\"abs\" # we specify the error measure here\n",
- ")\n",
- "plot_head_detection_scores(ones_head_scores, title=\"Transformers Have Not Been Solved Yet, Get Back To Work!\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hcXmbyWJ6knn"
- },
- "source": [
- "## Further improvements"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "fkxMLVehaZOj"
- },
- "source": [
- "**Performance for large distributions** isn't as good as it could be. The head detector could be rewritten to support taking in a list of sequences and performing these computations in parallel, but 1000 sequences per minute is certainly adequate for most use cases. If having this be faster would help your research, please write up an issue on TransformerLens, mention it on the Open Source Mechanistic Interpretability Slack, or e-mail jaybaileycs@gmail.com."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "zXrc4KA5A5c9"
- },
- "source": [
- "### Other\n",
- "\n",
- "- Extending to few-shot learning/translation heads\n",
- "- More pre-specified heads?\n",
- "- For inspiration, see [this post from Neel](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj/p/btasQF7wiCYPsr5qw)"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "provenance": []
- },
- "gpuClass": "standard",
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
},
- "language_info": {
- "name": "python",
- "version": "3.10.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat": 4,
+ "nbformat_minor": 0
}