Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 122 additions & 18 deletions transformer_lens/model_bridge/generalized_components/bloom_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def __init__(
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Forward pass through BLOOM attention with hooks.

Uses the parent's hooked Q/K/V split path so that hook_q, hook_k, hook_v,
hook_attn_scores, and hook_pattern all fire correctly. ALiBi bias and
attention masking are handled in _reconstruct_attention.

BLOOM attention requires these arguments:
- hidden_states (first positional arg)
- residual (second positional arg)
Expand All @@ -84,32 +88,132 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
f"Original component not set for {self.name}. Call set_original_component() first."
)

# Apply hook_in to hidden_states (first positional argument)
# Extract hidden_states (first positional arg) and residual (second positional arg)
if len(args) > 0 and isinstance(args[0], torch.Tensor):
hooked_input = self.hook_in(args[0])
args = (hooked_input,) + args[1:]
hidden_states = args[0]
elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
hidden_states = kwargs["hidden_states"]
else:
raise ValueError("Could not find hidden_states in args or kwargs")

residual = args[1] if len(args) > 1 and isinstance(args[1], torch.Tensor) else None

# Apply input hook
hooked_input = self.hook_in(hidden_states)

# BLOOM attention requires residual as second positional arg
# The original BLOOM block passes it, so we just pass everything through
# No need to validate since the original component will handle it
# Run through split Q/K/V projections (these fire hook_q, hook_k, hook_v)
q_output = self.q(hooked_input)
k_output = self.k(hooked_input)
v_output = self.v(hooked_input)

# Call the original BLOOM attention component with all arguments
# BLOOM attention returns (hidden_states,) or (hidden_states, attention_weights)
output = self.original_component(*args, **kwargs)
# Reconstruct attention with ALiBi (fires hook_attn_scores, hook_pattern)
attn_output, attn_weights = self._reconstruct_attention(
q_output, k_output, v_output, **kwargs
)

# BLOOM's original attention applies dropout_add(dense_output, residual, ...)
# inside the attention module, not in the block. We must replicate this.
if residual is not None:
assert self.original_component is not None
hidden_dropout = getattr(self.original_component, "hidden_dropout", 0.0)
if self.training:
attn_output = torch.nn.functional.dropout(
attn_output, p=hidden_dropout, training=True
)
attn_output = attn_output + residual

# Apply hook_out to the hidden_states (first element of tuple)
if isinstance(output, tuple) and len(output) > 0:
processed_output = list(output)
if isinstance(output[0], torch.Tensor):
processed_output[0] = self.hook_out(output[0])
output = tuple(processed_output)
elif isinstance(output, torch.Tensor):
output = self.hook_out(output)
# Apply output hook
output = (attn_output, attn_weights)
output = self._process_output(output)

return output

def _reconstruct_attention(
    self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
) -> tuple:
    """Recompute BLOOM attention with ALiBi-fused scores, keeping hooks live.

    HF BLOOM fuses the ALiBi positional bias into the score matmul via
    ``alibi.baddbmm(Q, K^T)``; this method mirrors that computation so that
    ``hook_attn_scores`` and ``hook_pattern`` observe the same values the
    original module would produce.

    Returns:
        Tuple of (attention output ``[batch, seq, hidden]``, attention
        pattern ``[batch, heads, seq, seq]``).
    """
    assert self.original_component is not None
    assert self.config is not None
    n_heads = self.config.n_heads
    d_head: int = self.config.d_head

    # Normalize Q/K/V to [batch, heads, seq, head_dim].
    if q.ndim == 3:
        batch, seq, _ = q.shape
        q = q.view(batch, seq, n_heads, d_head).transpose(1, 2)
        k = k.view(batch, seq, n_heads, d_head).transpose(1, 2)
        v = v.view(batch, seq, n_heads, d_head).transpose(1, 2)
    elif q.ndim == 4:
        # Assumed layout [batch, seq, heads, head_dim] — swap to heads-first.
        batch, seq = q.shape[0], q.shape[1]
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    else:
        raise ValueError(f"Unexpected Q tensor shape: {q.shape}")

    # Fold heads into the batch dimension for baddbmm/bmm.
    q_flat = q.reshape(batch * n_heads, seq, d_head)
    k_flat = k.reshape(batch * n_heads, seq, d_head)
    v_flat = v.reshape(batch * n_heads, seq, d_head)

    scale = d_head ** (-0.5)

    alibi = kwargs.get("alibi", None)
    if alibi is None:
        scores = torch.bmm(q_flat, k_flat.transpose(-1, -2)) * scale
    else:
        # baddbmm computes beta * alibi + alpha * (Q @ K^T), fusing the
        # ALiBi bias into the matmul exactly as HF BLOOM does.
        # alibi shape: [batch*heads, 1, seq] or [batch*heads, seq, seq].
        scores = alibi.baddbmm(
            batch1=q_flat,
            batch2=k_flat.transpose(-1, -2),
            beta=1.0,
            alpha=scale,
        )

    # Back to [batch, heads, seq, key_len].
    scores = scores.view(batch, n_heads, seq, -1)

    attention_mask = kwargs.get("attention_mask", None)
    if attention_mask is not None:
        # Additive causal mask, trimmed to the key length (HF BLOOM style).
        scores = scores + attention_mask[:, :, :, : scores.shape[-1]]

    scores = self.hook_attn_scores(scores)

    # Softmax in float32 for numerical stability, then cast back to the
    # activation dtype (matches HF BLOOM).
    pattern = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    pattern = pattern.to(q.dtype)

    if hasattr(self.original_component, "attention_dropout"):
        pattern = self.original_component.attention_dropout(pattern)  # type: ignore[operator]

    pattern = self.hook_pattern(pattern)

    # Weighted sum over values: [batch*heads, seq, key_len] @ [batch*heads, key_len, head_dim].
    out = torch.bmm(pattern.reshape(batch * n_heads, seq, -1), v_flat)

    # Unfold heads and merge them back into the hidden dimension.
    out = out.view(batch, n_heads, seq, d_head)
    out = out.transpose(1, 2).contiguous().view(batch, seq, n_heads * d_head)

    # Output projection fires hook_result-adjacent machinery if present.
    if hasattr(self, "o") and self.o is not None:
        out = self.o(out)

    return (out, pattern)

def set_processed_weights(
self, weights: Mapping[str, torch.Tensor | None], verbose: bool = False
) -> None:
Expand Down
Loading