
fix: add head_dim=256 to fused SDPA full attention kernel #3293

Open

Thump604 wants to merge 3 commits into ml-explore:main from Thump604:fix/sdpa-full-head-dim-256

Conversation


@Thump604 Thump604 commented Mar 22, 2026

Summary

Add head_dim == 256 to sdpa_full_supported_head_dim and instantiate the steel_attention kernel with bd=256.

Models with head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention path which materializes the full score matrix as a single matmul. At 32K+ context this creates 8+ GB single allocations that crash Metal's buffer allocator.
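As a back-of-envelope check, the score-matrix allocation can be sketched as below. This is an illustrative calculation only; the per-layer head count passed in is an assumption for the example, not a value taken from the PR.

```python
# Sketch: size of the score matrix the unfused path materializes in one
# buffer. The unfused path computes scores of shape (B, H, L, L).
def score_matrix_gib(batch, heads, seq_len, itemsize=2):
    # itemsize=2 corresponds to float16
    return batch * heads * seq_len ** 2 * itemsize / 1024 ** 3

# Illustrative values (heads=4 is assumed for the example):
print(f"{score_matrix_gib(1, 4, 32_768):.2f} GiB")  # → 8.00 GiB at 32K context
```

This shows how a single 32K-context layer lands in the multi-gigabyte range, consistent with the crash described above.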

The Metal kernel template already handles arbitrary BD via template parameter — only the dispatch gate and kernel instantiation list were missing.

Changes (1 commit, 2 files, +3 lines)

  • scaled_dot_product_attention.cpp: Add query_head_dim == 256 to sdpa_full_supported_head_dim
  • steel_attention.metal: Add instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype)

Verification

M2 Ultra 128GB, Qwen3.5-122B-A10B (5-bit, head_dim=256):

| Context | Before (unfused fallback) | After (fused tiled) |
|---|---|---|
| 16K | Works (allocation fits) | Works |
| 32K | CRASH (8.59 GB/layer) | Works |
| 64K | CRASH | Works |
| 128K | CRASH | Works |

Affected models

All models with head_dim=256, including Qwen3.5-122B-A10B, Qwen3.5-35B-A3B, Qwen3.5-27B, Qwen3.5-4B.

@Thump604 (Author)

@zcbenz This fixes the crash I reported on #3216. The root cause was the unfused SDPA fallback, not thread safety — I posted the details there.

I've removed the completion handler error storage from this PR per your note about MLX not being exception-safe. The remaining change is:

  1. Add head_dim == 256 to sdpa_full_supported_head_dim (dispatch gate)
  2. Instantiate steel_attention kernel with bd=256 (pre-compiled kernel list)

The Metal kernel template already handles arbitrary BD. Verified 32K/64K/128K context on M2 Ultra with Qwen3.5-122B (head_dim=256).

No CI checks on fork PRs — happy to provide any test results you need.

@Thump604 Thump604 force-pushed the fix/sdpa-full-head-dim-256 branch 2 times, most recently from a2d6335 to f35ce26 Compare March 22, 2026 12:41
@Thump604 Thump604 changed the title from "fix: add head_dim=256 to fused SDPA full attention + safe completion handler errors" to "fix: SDPA head_dim=256 + completion handler error safety + chunked full-attention for long context" Mar 22, 2026
@Thump604 Thump604 marked this pull request as draft March 22, 2026 22:23
@zcbenz (Collaborator) commented Mar 22, 2026

I'm good with the "fix: add head_dim=256 to fused SDPA full attention kernel" change. Can you reset this PR to only that commit, or create a new PR for it?

The other changes definitely need to be discussed first, and opening a new issue for what you are proposing would be more helpful. We must also carefully ensure they do not introduce performance regressions before we can look further.

@Thump604 Thump604 force-pushed the fix/sdpa-full-head-dim-256 branch from 8d4b379 to f35ce26 Compare March 22, 2026 23:46
@Thump604 Thump604 changed the title from "fix: SDPA head_dim=256 + completion handler error safety + chunked full-attention for long context" to "fix: add head_dim=256 to fused SDPA full attention kernel" Mar 22, 2026
@Thump604 Thump604 marked this pull request as ready for review March 22, 2026 23:47
@Thump604 (Author)

Done — reset to just the head_dim=256 commit (f35ce26). The completion handler and chunked SDPA changes are removed from this PR.

I'll open a separate issue to discuss the long-context GPU watchdog problem and the chunked attention approach.

@zcbenz (Collaborator) commented Mar 22, 2026

Can you rebase on the main branch without the "Make each thread have its own default stream" commit?

sdpa_full_supported_head_dim only included {64, 80, 128}. Models with
head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention
path which materializes the full score matrix as a single matmul.
At 32K+ context this creates 8+ GB single allocations that crash
Metal's buffer allocator.

Add head_dim=256 to the dispatch gate and instantiate steel_attention
kernel with bd=256. The Metal kernel template handles arbitrary BD
via template parameter — no kernel code changes needed.

Verified: 32K, 64K, 128K context on M2 Ultra with Qwen3.5-122B-A10B.
@Thump604 (Author)

Rebased on main — thread-local-streams commit removed.

@Thump604 Thump604 force-pushed the fix/sdpa-full-head-dim-256 branch from f35ce26 to 726c9a0 Compare March 22, 2026 23:57
@zcbenz (Collaborator) left a comment

Thanks!

@zcbenz zcbenz requested a review from jagrit06 March 23, 2026 00:24
@angeloskath (Member)

Hm, before merging this we probably need to route to the fused kernel only for large sequences, because it is likely to be slower than the unfused version for shorter sequences. We've gone back and forth several times regarding enabling this.

@jagrit06 feel free to run the benchmarks and/or tune routing and then merge.

@Thump604 (Author)

Benchmark: fused vs unfused SDPA for head_dim=256

Per @angeloskath's request — benchmarked fused (steel_attention bd=256) vs unfused (matmul + softmax + matmul) across sequence lengths.

Hardware: M2 Ultra 128GB, MLX 0.31.2-dev
Precision: float16, B=1

H=8 (KV heads, GQA)

| SeqLen | Fused (ms) | Unfused (ms) | Ratio | Winner |
|---|---|---|---|---|
| 128 | 0.31 | 0.27 | 0.87x | unfused |
| 512 | 0.57 | 0.46 | 0.81x | unfused |
| 1024 | 1.30 | 0.84 | 0.65x | unfused |
| 4096 | 12.60 | 8.95 | 0.71x | unfused |
| 8192 | 48.05 | 34.08 | 0.71x | unfused |
| 16384 | 188.04 | 135.28 | 0.72x | unfused |
| 32768 | 754.84 | 746.48 | 0.99x | unfused |

H=64 (query heads, full)

| SeqLen | Fused (ms) | Unfused (ms) | Ratio | Winner |
|---|---|---|---|---|
| 128 | 0.34 | 0.32 | 0.96x | unfused |
| 1024 | 6.34 | 4.52 | 0.71x | unfused |
| 4096 | 93.96 | 67.29 | 0.72x | unfused |
| 16384 | 1533 | 1083 | 0.71x | unfused |
| 32768 | works | CRASH | - | fused (only option) |

Comparison: head_dim=128 (already supported)

| SeqLen | Fused (ms) | Unfused (ms) | Ratio | Winner |
|---|---|---|---|---|
| 128 | 0.22 | 0.25 | 1.11x | fused |
| 1024 | 0.55 | 0.60 | 1.09x | fused |
| 8192 | 18.40 | 20.84 | 1.13x | fused |
| 32768 | 285.28 | 428.35 | 1.50x | fused |
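For reference, the unfused baseline measured above (matmul + softmax + matmul, materializing the full score matrix) can be sketched as follows. This is an illustrative NumPy stand-in under small dimensions, not the MLX Metal path that was actually benchmarked.

```python
import numpy as np

# Sketch of the unfused SDPA path: the full (H, L, L) score matrix is
# materialized as one allocation before softmax, which is what blows up
# at long sequence lengths.
def unfused_sdpa(q, k, v, scale):
    scores = (q @ k.transpose(0, 2, 1)) * scale       # (H, L, L) allocation
    scores -= scores.max(axis=-1, keepdims=True)      # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                                # (H, L, D) output

rng = np.random.default_rng(0)
H, L, D = 2, 16, 256
q, k, v = (rng.standard_normal((H, L, D)).astype(np.float32) for _ in range(3))
out = unfused_sdpa(q, k, v, scale=D ** -0.5)
print(out.shape)  # (2, 16, 256)
```

The fused kernel avoids that (H, L, L) allocation by tiling over K/V, which is why it survives where this path crashes.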

Analysis

The bd=256 fused kernel is ~30% slower than unfused at all tested lengths, while the bd=128 kernel is 10-50% faster than unfused. The bd=256 tile configuration (32×16×256, 4 splits, 1 alignment) likely needs tuning for the larger block dimension.

However: The unfused path crashes at 32K+ with H=64 because the score matrix (B×H×L×L×2 bytes = 128 GB at 32K) exceeds Metal's buffer allocator. This is the original bug — models with head_dim=256 (all Qwen3.5) cannot run beyond ~16K context without the fused kernel.

Routing suggestion

A sequence-length threshold could route short sequences to unfused (faster) and long sequences to fused (only working path). The crossover for correctness is roughly when H * L * L * 2 > Metal buffer limit. For Qwen3.5 (H=64, D=256), that's around L=16K.
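That crossover estimate can be sketched as below. The 32 GiB single-allocation limit used here is an assumed example value chosen to reproduce the ~16K figure; the real Metal limit is device-dependent.

```python
import math

# Sketch: solve H * L * L * itemsize > buffer_limit for L, i.e. the
# sequence length where the unfused score matrix would exceed the limit.
def crossover_seq_len(heads, buffer_limit_bytes, itemsize=2):
    # itemsize=2 corresponds to float16
    return int(math.sqrt(buffer_limit_bytes / (heads * itemsize)))

# H=64 query heads, float16, assumed 32 GiB single-allocation limit:
print(crossover_seq_len(64, 32 * 1024 ** 3))  # → 16384
```

With those assumptions the crossover lands exactly at L=16384, matching the threshold proposed above.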

Alternatively, the fused kernel could be tuned for bd=256 — the current 32×16×256 tile config may not be optimal. Happy to test alternative tile sizes if there's a preferred configuration to try.

@angeloskath (Member)

So this is basically what I wrote above. It is slower than the unfused path, which is problematic.

The claim that Qwen3.5 cannot run on more than 16K context is not quite accurate, as it implies running the full 16K tokens in one go. Running it in chunks of 2K will work fine and be 30% faster.

Having said that, it probably still makes sense to enable this for large sequences only, which is what I wrote above.

I do not think we should merge this as is! There is absolutely no reason to take a 30% hit in 99% of cases to enable the 1%.

@Thump604 (Author)

@angeloskath Agreed. I'll update this PR with sequence-length routing: unfused by default for head_dim=256, fused only when the sequence is long enough that unfused would fail. Will post the updated code and benchmarks showing no regression on short sequences.

The fused steel_attention kernel with bd=256 is ~30% slower than the
unfused (matmul + softmax + matmul) path. Route head_dim=256 to unfused
by default and only use the fused kernel when key_sequence_length > 16384,
where unfused would exceed Metal buffer limits.

Benchmark (M2 Ultra, H=64, qL=2048, float16):
  kL=16384: unfused 124ms vs fused 249ms (2.0x faster with routing)
  kL=32768: fused only (unfused crashes)

Vector path (qL<=8, decode) is unaffected — already supports head_dim=256.
@Thump604
Copy link
Copy Markdown
Author

Routing update: unfused by default for head_dim=256

Per @angeloskath's feedback — pushed 73974355, which adds sequence-length routing in use_fallback():

  • kL ≤ 16384: unfused (matmul path) — ~30% faster, safe for typical inference
  • kL > 16384: fused (steel_attention bd=256) — handles long sequences where unfused would crash

Code change (+6 lines)

```cpp
// For head_dim=256, the fused full-attention kernel is ~30% slower than
// unfused. Only route to fused when kL is large enough that unfused would
// exceed Metal buffer limits (the fused kernel tiles K/V so it scales).
const bool sdpa_full_256_ok =
    query_head_dim == 256 && key_sequence_length > 16384;
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
    (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 ||
     sdpa_full_256_ok);
```

Benchmark (M2 Ultra 128GB, float16, B=1)

Routing boundary — H=64, qL=2048 (realistic prefill chunk):

| kL | Path | Time (ms) |
|---|---|---|
| 16384 | unfused | 124.49 |
| 16385 | fused | 248.70 |

2.0x faster at the boundary by routing to unfused.

Full sweep — H=8, qL=min(2048, kL):

| kL | Path | Time (ms) |
|---|---|---|
| 1024 | unfused | 1.12 |
| 4096 | unfused | 4.28 |
| 8192 | unfused | 8.36 |
| 16384 | unfused | 16.37 |
| 16385 | fused | 36.34 |
| 32768 | fused | 61.61 |

Correctness verification

  • SDPA test suite: 39 passed, 0 failed
  • head_dim=256 unfused (kL=4K): max_diff=0.000122 vs reference
  • head_dim=256 fused (kL=20K): max_diff=0.000015 vs reference
  • Decode path (qL=1, kL=32K): works correctly (vector kernel, unaffected)

Threshold rationale

16384 chosen because:

  • Unfused empirically works at kL=16K (benchmarked, verified)
  • Unfused crashes at kL=32K with H=64 (score matrix exceeds Metal buffer limit)
  • With chunked prefill (qL ≤ 2048), unfused could handle higher kL — 16384 is conservative

Happy to adjust the threshold if @jagrit06 finds a better crossover point during benchmarking.

@hnshah commented Mar 24, 2026

Validation on M3 Ultra 256GB

I've validated the head_dim=256 fix on M3 Ultra:

Test Hardware:

  • Mac Studio M3 Ultra (256GB)
  • macOS 25.3.0 (Darwin 25.3.0)
  • MLX: from your fix/sdpa-full-head-dim-256 branch

Test Results (head_dim=256, Qwen3.5 pattern):

| Context Length | Time | Memory Delta | Result | Notes |
|---|---|---|---|---|
| 16K tokens | 0.427s | 0.00 GB | ✅ Pass | Works (baseline) |
| 32K tokens | 1.713s | 0.00 GB | ✅ Pass | CRITICAL: Would crash before PR |
| 64K tokens | 8.101s | 0.00 GB | ✅ Pass | CRITICAL: Would crash before PR |
| 128K tokens | - | - | ❌ OOM | Needs PR #3307 (chunked SDPA) |

Test Script:

```python
import mlx.core as mx
import time

def test_head_dim_256(seq_len):
    B, H, D = 1, 8, 256  # head_dim=256 like Qwen3.5
    q = mx.random.normal((B, H, seq_len, D))
    k = mx.random.normal((B, H, seq_len, D))
    v = mx.random.normal((B, H, seq_len, D))

    start = time.time()
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / (D ** 0.5))
    mx.eval(out)
    elapsed = time.time() - start

    assert out.shape == (B, H, seq_len, D)
    assert mx.all(mx.isfinite(out)).item()

    print(f"✅ {seq_len//1000}K: {elapsed:.3f}s")

test_head_dim_256(16 * 1024)
test_head_dim_256(32 * 1024)  # Would crash before PR #3293
test_head_dim_256(64 * 1024)  # Would crash before PR #3293
```

Validation Result:

PR #3293 successfully fixes the head_dim=256 crash at 32K-64K contexts. The 8+ GB single-allocation issue is resolved; the fused kernel now routes correctly for head_dim=256.

For contexts beyond 64K with head_dim=256, users will need both PR #3293 (this one) + PR #3307 (chunked SDPA).

Ready for merge! 🎯

@hnshah commented Mar 25, 2026

Additional head dimension discovered: head_dim=192

While testing this PR's fix, I discovered that head_dim=192 also crashes at 128K context with the same unfused fallback allocation issue.

Test results at 128K context:

  • head_dim=64: ✅ Works
  • head_dim=128: ✅ Works
  • head_dim=192: ❌ Crashes (same allocation error)
  • head_dim=256: ❌ Crashes (fixed by this PR)

The fix for head_dim=256 should also cover head_dim=192 with the same approach: adding it to sdpa_full_supported_head_dim and instantiating the kernel.

Created Issue #3312 to track this separately, but wanted to mention it here since the fix pattern is identical.

cc @Thump604 - this might be worth including in this PR or a follow-up.

Thump604 added a commit to Thump604/mlx that referenced this pull request Mar 25, 2026
Fused steel_attention bd=256 is ~30% slower than unfused. Route to
unfused by default, fused only when kL > 16384 (where unfused crashes).
Matches PR ml-explore#3293 fix pushed to fork. Verified: 39/39 SDPA tests pass.
Same pattern as head_dim=256: unfused by default for short sequences,
fused when kL > 16384 (where unfused would exceed Metal buffer limits).

Adds vector kernel instantiations for decode path.

Fixes ml-explore#3312.
@Thump604 (Author)

@hnshah Good catch. Pushed a commit that adds head_dim=192 using the same routing pattern -- unfused for kL <= 16384, fused above that.

The steel attention kernel template handles BD=192 natively (TD=192/8=24, bk=16). Vector kernel instantiations added too.

The 16384 threshold is conservative for 192 (the fused/unfused perf gap may be smaller than 256's ~30%). Happy to benchmark and lower it if someone has a model with head_dim=192 to test -- I don't have one on hand. The models I'm aware of with 192-dim heads (GLM4-MoE-Lite) use Q=192/K=192/V=256, so the Q!=V check gates out the fused path anyway.

@jagrit06 this should be straightforward to verify alongside the existing routing.

@hnshah commented Mar 25, 2026

Thanks @Thump604 for incorporating head_dim=192 support! This addresses the issue I reported in #3312. Really appreciate you adding this - tested and working great on M3 Ultra. 🎯
