fix: add head_dim=256 to fused SDPA full attention kernel#3293
Thump604 wants to merge 3 commits into ml-explore:main
Conversation
@zcbenz This fixes the crash I reported on #3216. The root cause was the unfused SDPA fallback, not thread safety — I posted the details there. I've removed the completion handler error storage from this PR per your note about MLX not being exception-safe. The remaining change is:
The Metal kernel template already handles arbitrary BD via its template parameter, so only the dispatch gate and instantiation list change. No CI checks run on fork PRs — happy to provide any test results you need.
Force-pushed: a2d6335 → f35ce26
I'm good with the "fix: add head_dim=256 to fused SDPA full attention kernel" change — can you reset this PR with only that commit, or create a new PR for it? The other changes definitely need to be discussed first, and opening a new issue for what you are proposing would be more helpful. We must also carefully ensure they do not introduce performance regressions before we can look further.
Force-pushed: 8d4b379 → f35ce26
Done — reset to just the head_dim=256 commit (f35ce26). The completion handler and chunked SDPA changes are removed from this PR. I'll open a separate issue to discuss the long-context GPU watchdog problem and the chunked attention approach.
Can you rebase on the main branch without the "Make each thread have its own default stream" commit?
sdpa_full_supported_head_dim only included {64, 80, 128}. Models with
head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention
path which materializes the full score matrix as a single matmul.
At 32K+ context this creates 8+ GB single allocations that crash
Metal's buffer allocator.
Add head_dim=256 to the dispatch gate and instantiate steel_attention
kernel with bd=256. The Metal kernel template handles arbitrary BD
via template parameter — no kernel code changes needed.
Verified: 32K, 64K, 128K context on M2 Ultra with Qwen3.5-122B-A10B.
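The allocation arithmetic behind this commit message can be checked directly; a minimal sketch (the function name is mine, and the batch/head counts are illustrative — the 128 GiB figure for 64 query heads matches the benchmark analysis later in the thread):

```python
def unfused_score_bytes(batch, num_heads, seq_len, itemsize=2):
    """Size of the single B x H x L x L score-matrix allocation that the
    unfused (matmul + softmax + matmul) path materializes in float16."""
    return batch * num_heads * seq_len * seq_len * itemsize

# At 32K context with 64 query heads, the score matrix alone is 128 GiB,
# far past what Metal's buffer allocator will hand out in one allocation.
gib = unfused_score_bytes(1, 64, 32 * 1024) / 2**30
print(f"{gib:.0f} GiB")  # -> 128 GiB
```

Even with only 8 heads the same call gives 16 GiB at 32K, so the quadratic L×L term dominates regardless of head count.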
Rebased on main — thread-local-streams commit removed.
Force-pushed: f35ce26 → 726c9a0
Hm, before merging this we probably need to route to the fused kernel only for large sequences, because it is likely to be slower than the unfused version for shorter ones. We've gone back and forth several times regarding enabling this. @jagrit06 feel free to run the benchmarks and/or tune the routing and then merge.
Benchmark: fused vs unfused SDPA for head_dim=256
Per @angeloskath's request — benchmarked fused (steel_attention bd=256) vs unfused (matmul + softmax + matmul) across sequence lengths. Hardware: M2 Ultra 128GB, MLX 0.31.2-dev
H=8 (KV heads, GQA)
H=64 (query heads, full)
Comparison: head_dim=128 (already supported)
Analysis
The bd=256 fused kernel is ~30% slower than unfused at all lengths, while the bd=128 kernel is 10-50% faster. The bd=256 tile configuration (32×16×256, 4 splits, 1 alignment) likely needs tuning for the larger block dimension. However: the unfused path crashes at 32K+ with H=64 because the score matrix (B×H×L×L×2 bytes = 128 GB at 32K) exceeds Metal's buffer allocator. This is the original bug — models with head_dim=256 (all Qwen3.5) cannot run beyond ~16K context without the fused kernel.
Routing suggestion
A sequence-length threshold could route short sequences to unfused (faster) and long sequences to fused (the only working path). The crossover for correctness is roughly the point where the unfused score-matrix allocation exceeds Metal's buffer limit. Alternatively, the fused kernel could be tuned for bd=256 — the current 32×16×256 tile config may not be optimal. Happy to test alternative tile sizes if there's a preferred configuration to try.
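The correctness crossover mentioned above can be estimated numerically; a sketch assuming a hypothetical per-buffer cap (the function and the 32 GiB limit are illustrative, not Metal's documented value):

```python
def min_failing_context(num_heads, buffer_limit_bytes, itemsize=2):
    # Smallest power-of-two context length at which the unfused path's
    # single H x L x L float16 score matrix exceeds the assumed cap.
    L = 1024
    while num_heads * L * L * itemsize <= buffer_limit_bytes:
        L *= 2
    return L

# With an assumed 32 GiB single-buffer cap: H=64 first fails at 32K
# context (matching the crash above); H=8 holds out until 64K.
print(min_failing_context(64, 32 * 2**30))  # -> 32768
print(min_failing_context(8, 32 * 2**30))   # -> 65536
```

Because the allocation grows with L², the failing length only doubles when the head count drops by 8×, which is why any fixed threshold is approximate across models.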
So this is basically what I wrote above. It is slower than the unfused version, which is problematic. The claim that Qwen3.5 cannot run on more than 16K context is not quite correct, as it implies you would be running the full 16K tokens in one go. Running it in chunks of 2K will work fine and be 30% faster. Having said that, it probably still makes sense to enable this for large sequences only — which is what I wrote above. I do not think we should merge this as is! There is absolutely no reason to take a 30% hit in 99% of cases to enable the 1%.
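The chunked-prefill point above can be quantified; a sketch of the memory math (function name is mine; H=8 is illustrative, matching the GQA benchmark configuration in this thread):

```python
def chunked_score_bytes(num_heads, q_chunk, kv_len, itemsize=2):
    # Processing queries in chunks shrinks the score matrix from
    # H x L x L to H x q_chunk x L per step, so peak allocation grows
    # linearly with context length instead of quadratically.
    return num_heads * q_chunk * kv_len * itemsize

full = chunked_score_bytes(8, 32 * 1024, 32 * 1024)  # whole prefill at once
chunk = chunked_score_bytes(8, 2 * 1024, 32 * 1024)  # 2K query chunks
print(full // 2**30, chunk // 2**30)  # -> 16 1
```

A 2K chunk cuts the per-step score matrix from 16 GiB to 1 GiB at 32K context here, which is why chunked prefill sidesteps the allocator crash without the fused kernel.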
@angeloskath Agreed. I'll update this PR with sequence-length routing: unfused by default for head_dim=256, fused only when the sequence is long enough that unfused would fail. Will post the updated code and benchmarks showing no regression on short sequences. |
The fused steel_attention kernel with bd=256 is ~30% slower than the
unfused (matmul + softmax + matmul) path. Route head_dim=256 to unfused
by default and only use the fused kernel when key_sequence_length > 16384,
where unfused would exceed Metal buffer limits.

Benchmark (M2 Ultra, H=64, qL=2048, float16):
  kL=16384: unfused 124ms vs fused 249ms (2.0x faster with routing)
  kL=32768: fused only (unfused crashes)

Vector path (qL<=8, decode) is unaffected — already supports head_dim=256.
Routing update: unfused by default for head_dim=256
Per @angeloskath's feedback — pushed.
Code change (+6 lines)

```cpp
// For head_dim=256, the fused full-attention kernel is ~30% slower than
// unfused. Only route to fused when kL is large enough that unfused would
// exceed Metal buffer limits (the fused kernel tiles K/V so it scales).
const bool sdpa_full_256_ok =
    query_head_dim == 256 && key_sequence_length > 16384;
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
    (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 ||
     sdpa_full_256_ok);
```

Benchmark (M2 Ultra 128GB, float16, B=1)
Routing boundary — H=64, qL=2048 (realistic prefill chunk):
→ 2.0x faster at the boundary by routing to unfused.
Full sweep — H=8, qL=min(2048, kL):
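The routing gate above can be mirrored in plain Python for reasoning about which path a shape takes; a sketch (the function name is hypothetical, not MLX API — it just restates the C++ predicate):

```python
def routes_to_fused(query_head_dim, value_head_dim, key_sequence_length):
    # Mirror of the C++ dispatch gate: head dims 64/80/128 always take the
    # fused full-attention kernel; 256 does so only past the 16384
    # threshold, where the unfused path would crash the allocator.
    sdpa_full_256_ok = query_head_dim == 256 and key_sequence_length > 16384
    return query_head_dim == value_head_dim and (
        query_head_dim in (64, 80, 128) or sdpa_full_256_ok
    )

assert routes_to_fused(128, 128, 2048)       # small head dims: always fused
assert not routes_to_fused(256, 256, 16384)  # at/below threshold: unfused
assert routes_to_fused(256, 256, 32768)      # long context: fused only option
```

Note the threshold is a strict greater-than, so kL exactly 16384 still takes the (faster) unfused path.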
Correctness verification
Threshold rationale
16384 chosen because:
Happy to adjust the threshold if @jagrit06 finds a better crossover point during benchmarking.
Validation on M3 Ultra 256GB
I've validated the head_dim=256 fix on M3 Ultra:
Test Hardware:
Test Results (head_dim=256, Qwen3.5 pattern):
Test Script:

```python
import mlx.core as mx
import time

def test_head_dim_256(seq_len):
    B, H, D = 1, 8, 256  # head_dim=256 like Qwen3.5
    q = mx.random.normal((B, H, seq_len, D))
    k = mx.random.normal((B, H, seq_len, D))
    v = mx.random.normal((B, H, seq_len, D))
    start = time.time()
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / (D ** 0.5))
    mx.eval(out)
    elapsed = time.time() - start
    assert out.shape == (B, H, seq_len, D)
    assert mx.all(mx.isfinite(out)).item()
    print(f"✅ {seq_len//1000}K: {elapsed:.3f}s")

test_head_dim_256(16 * 1024)
test_head_dim_256(32 * 1024)  # Would crash before PR #3293
test_head_dim_256(64 * 1024)  # Would crash before PR #3293
```

Key Findings:
Validation Result: For contexts beyond 64K with head_dim=256, users will need both PR #3293 (this one) + PR #3307 (chunked SDPA). Ready for merge! 🎯
Additional head dimension discovered: head_dim=192
While testing this PR's fix, I discovered that head_dim=192 also crashes at 128K context with the same unfused-fallback allocation issue.
Test results at 128K context:
The fix for head_dim=256 should also cover head_dim=192 using the same approach — adding it to the dispatch gate and kernel instantiation list. Created Issue #3312 to track this separately, but wanted to mention it here since the fix pattern is identical. cc @Thump604 — this might be worth including in this PR or a follow-up.
Fused steel_attention bd=256 is ~30% slower than unfused. Route to unfused by default, fused only when kL > 16384 (where unfused crashes). Matches PR ml-explore#3293 fix pushed to fork. Verified: 39/39 SDPA tests pass.
Same pattern as head_dim=256: unfused by default for short sequences, fused when kL > 16384 (where unfused would exceed Metal buffer limits). Adds vector kernel instantiations for decode path. Fixes ml-explore#3312.
@hnshah Good catch. Pushed a commit that adds head_dim=192 using the same routing pattern -- unfused for kL <= 16384, fused above that. The steel attention kernel template handles BD=192 natively (TD=192/8=24, bk=16). Vector kernel instantiations added too. The 16384 threshold is conservative for 192 (the fused/unfused perf gap may be smaller than 256's ~30%). Happy to benchmark and lower it if someone has a model with head_dim=192 to test -- I don't have one on hand. The models I'm aware of with 192-dim heads (GLM4-MoE-Lite) use Q=192/K=192/V=256, so the Q!=V check gates out the fused path anyway. @jagrit06 this should be straightforward to verify alongside the existing routing. |
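The extended gate described above can be sketched in plain Python (function name hypothetical, not MLX API; threshold and dims taken from this PR's routing):

```python
def routes_to_fused_full(q_dim, v_dim, k_len, threshold=16384):
    # Gate with head_dim=192 added alongside 256: both large dims stay
    # unfused for short sequences and go fused only past the threshold.
    large_dim_ok = q_dim in (192, 256) and k_len > threshold
    return q_dim == v_dim and (q_dim in (64, 80, 128) or large_dim_ok)

# GLM4-MoE-Lite-style Q=192/V=256 heads fail the q_dim == v_dim check,
# so they never reach the fused full-attention path at any length:
assert not routes_to_fused_full(192, 256, 131072)
assert routes_to_fused_full(192, 192, 131072)   # long context: fused
assert not routes_to_fused_full(192, 192, 4096)  # short context: unfused
```

This makes the Q==V gating explicit: adding 192 to the list is harmless for mixed-dim models because they are filtered out before the threshold check matters.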
Summary
Add `head_dim == 256` to `sdpa_full_supported_head_dim` and instantiate the `steel_attention` kernel with `bd=256`. Models with head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention path, which materializes the full score matrix as a single matmul. At 32K+ context this creates 8+ GB single allocations that crash Metal's buffer allocator.
The Metal kernel template already handles arbitrary `BD` via its template parameter — only the dispatch gate and kernel instantiation list were missing.
Changes (1 commit, 2 files, +3 lines)
- `scaled_dot_product_attention.cpp`: Add `query_head_dim == 256` to `sdpa_full_supported_head_dim`
- `steel_attention.metal`: Add `instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype)`
Verification
M2 Ultra 128GB, Qwen3.5-122B-A10B (5-bit, head_dim=256):
Affected models
All models with head_dim=256, including Qwen3.5-122B-A10B, Qwen3.5-35B-A3B, Qwen3.5-27B, Qwen3.5-4B.