Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
512 changes: 512 additions & 0 deletions demos/GPT_OSS_Demo.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions demos/LLaMA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
"output_type": "stream",
"text": [
"/tmp/ipykernel_16979/572068249.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
" ipython.magic(\"load_ext autoreload\")\n",
" ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
"/tmp/ipykernel_16979/572068249.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
" ipython.magic(\"autoreload 2\")\n"
" ipython.run_line_magic(\"autoreload\", \"2\")\n"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions demos/Santa_Coder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
"output_type": "stream",
"text": [
"/tmp/ipykernel_35643/572068249.py:21: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
" ipython.magic(\"load_ext autoreload\")\n",
" ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
"/tmp/ipykernel_35643/572068249.py:22: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n",
" ipython.magic(\"autoreload 2\")\n"
" ipython.run_line_magic(\"autoreload\", \"2\")\n"
]
}
],
Expand Down
20 changes: 18 additions & 2 deletions tests/acceptance/test_activation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ def test_accumulated_resid_with_apply_ln():
# Run the model and cache all activations
_, cache = model.run_with_cache(tokens)

# Get accumulated resid and apply ln seperately (cribbed notebook code)
# Get accumulated resid and apply final ln separately
accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
ref_scaled_residual_stack = cache.apply_ln_to_stack(
accumulated_residual, layer=-1, pos_slice=-1
accumulated_residual, layer=-1, pos_slice=-1, recompute_ln=True
)

# Get scaled_residual_stack using apply_ln parameter
Expand All @@ -271,6 +271,22 @@ def test_accumulated_resid_with_apply_ln():
assert labels == expected_labels


@torch.no_grad
def test_apply_ln_recompute_ln_differs_from_cached():
    """The recomputed-LN path must not coincide with the cached-scale path.

    Sanity check that ``apply_ln_to_stack`` actually takes two different
    normalization routes for ``recompute_ln=True`` vs ``False`` when given
    an accumulated residual stack.
    """
    model = load_model("solu-2l")
    tokens, _ = get_ioi_tokens_and_answer_tokens(model)
    _, cache = model.run_with_cache(tokens)

    resid_stack = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
    recomputed = cache.apply_ln_to_stack(resid_stack, layer=-1, pos_slice=-1, recompute_ln=True)
    cached_scale = cache.apply_ln_to_stack(resid_stack, layer=-1, pos_slice=-1, recompute_ln=False)

    assert recomputed.shape == cached_scale.shape
    fully_matching = torch.isclose(recomputed, cached_scale, atol=1e-7).all()
    assert (
        not fully_matching
    ), "recompute_ln=True and recompute_ln=False should differ for accumulated residual stack"


@torch.no_grad
def test_decompose_resid_with_apply_ln():
# Load solu-2l
Expand Down
69 changes: 68 additions & 1 deletion tests/acceptance/test_evals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import pytest

from transformer_lens.evals import IOIDataset, ioi_eval
from transformer_lens.evals import (
IOIDataset,
ioi_eval,
make_mmlu_data_loader,
mmlu_eval,
)
from transformer_lens.HookedTransformer import HookedTransformer


Expand Down Expand Up @@ -70,3 +75,65 @@ def test_inverted_template(model):
results = ioi_eval(model, dataset=ds)
assert results["Logit Difference"] < -2.0
assert results["Accuracy"] <= 0.01


def test_mmlu_data_loader_single_subject():
    """
    Loading a single MMLU subject yields well-formed sample dicts.
    """
    samples = make_mmlu_data_loader(subjects="abstract_algebra", num_samples=5)
    assert len(samples) == 5
    for sample in samples:
        assert isinstance(sample, dict)
        # Every record must carry the full MMLU schema.
        for key in ("question", "choices", "answer", "subject"):
            assert key in sample
        assert len(sample["choices"]) == 4
        assert sample["subject"] == "abstract_algebra"


def test_mmlu_data_loader_multiple_subjects():
    """
    The loader returns num_samples entries for each requested subject.
    """
    requested = ["abstract_algebra", "anatomy"]
    samples = make_mmlu_data_loader(subjects=requested, num_samples=3)
    # 3 samples for each of the 2 subjects
    assert len(samples) == 6
    assert {entry["subject"] for entry in samples} == set(requested)


def test_mmlu_data_loader_invalid_subject():
    """
    An unknown subject name is rejected with a ValueError.
    """
    with pytest.raises(ValueError, match="Invalid subject"):
        make_mmlu_data_loader(subjects="invalid_subject_name")


def test_mmlu_eval_single_subject(model):
    """
    Evaluate one MMLU subject and sanity-check the result schema.
    Kept small (5 samples, tiny model fixture) for fast CI execution.
    """
    results = mmlu_eval(model, subjects="abstract_algebra", num_samples=5)
    # Result dict must expose the full score schema.
    for field in ("accuracy", "num_correct", "num_total", "subject_scores"):
        assert field in results
    assert 0 <= results["accuracy"] <= 1
    assert results["num_total"] == 5
    assert results["num_correct"] <= results["num_total"]
    assert "abstract_algebra" in results["subject_scores"]


def test_mmlu_eval_multiple_subjects(model):
    """
    Evaluation across several subjects aggregates per-subject scores.
    """
    requested = ["abstract_algebra", "anatomy"]
    results = mmlu_eval(model, subjects=requested, num_samples=3)
    # 3 samples for each of the 2 subjects
    assert results["num_total"] == 6
    scores = results["subject_scores"]
    assert len(scores) == 2
    for subject in requested:
        assert subject in scores
    for accuracy in scores.values():
        assert 0 <= accuracy <= 1
37 changes: 37 additions & 0 deletions tests/integration/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,43 @@ def skip_grad(output_grad: torch.Tensor, hook: Any):
model.remove_all_hook_fns(including_permanent=True)


def test_backward_hook_returning_bare_tensor():
    """Regression test for issue #1160.

    PyTorch's register_full_backward_hook raises when a hook returns a
    bare tensor instead of a tuple:
        RuntimeError: hook 'hook' has changed the size of value

    TransformerLens wraps such bare-tensor returns as (result,) before
    handing them back to PyTorch; this test pins that behavior.
    """
    counter = Counter()

    def bare_tensor_hook(grad: torch.Tensor, hook: Any):
        counter.inc()
        # Deliberately a bare tensor, NOT (grad,)
        return grad

    with model.hooks(bwd_hooks=[("blocks.0.hook_resid_post", bare_tensor_hook)]):
        model(prompt).sum().backward()
    assert counter.count == 1
    model.remove_all_hook_fns(including_permanent=True)


def test_backward_hook_returning_none():
    """A backward hook that returns None must not break the backward pass."""
    counter = Counter()

    def passive_hook(grad: torch.Tensor, hook: Any):
        counter.inc()
        return None

    with model.hooks(bwd_hooks=[("blocks.0.hook_resid_post", passive_hook)]):
        model(prompt).sum().backward()
    assert counter.count == 1
    model.remove_all_hook_fns(including_permanent=True)


def test_hook_context_manager_with_permanent_hook():
c = Counter()
model.add_perma_hook(embed, c.inc)
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/test_match_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,7 @@ def test_compare_huggingface_attention_match_local_implementation(
attn_output, hf_attn.c_proj.weight.T, hf_attn.c_proj.bias
)

assert torch.allclose(tl_out, hf_out, atol=1e-4)
# Tolerance accounts for float32 accumulation differences between
# TransformerLens and HuggingFace attention implementations across
# 12 layers. Empirically, worst-case diff is ~1.3e-3 on layer 11.
assert torch.allclose(tl_out, hf_out, atol=1e-3)
Loading
Loading