Fused SDPA full-attention for head_dim=192/256 (revival of #3293) #3660
Open
hojin12312 wants to merge 4 commits into
sdpa_full_supported_head_dim only included {64, 80, 128}. Models with
head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention
path which materializes the full score matrix as a single matmul.
At 32K+ context this creates 8+ GB single allocations that crash
Metal's buffer allocator.
Add head_dim=256 to the dispatch gate and instantiate steel_attention
kernel with bd=256. The Metal kernel template handles arbitrary BD
via template parameter — no kernel code changes needed.
Verified: 32K, 64K, 128K context on M2 Ultra with Qwen3.5-122B-A10B.
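
The allocation size quoted above can be checked with a little arithmetic. This is a minimal sketch, assuming batch size 1 and the H=64, qL=2048, float16 shape used in the benchmark later in this PR; `score_matrix_bytes` is a hypothetical helper for illustration, not an MLX function.

```python
def score_matrix_bytes(n_heads: int, q_len: int, k_len: int, itemsize: int = 2) -> int:
    """Bytes needed to materialize one (n_heads, q_len, k_len) score tensor.

    itemsize=2 corresponds to float16; batch size 1 is assumed.
    """
    return n_heads * q_len * k_len * itemsize


GIB = 1024 ** 3

# Assumed shape: H=64, qL=2048, float16 (the benchmark shape in this PR).
for k_len in (16384, 32768, 65536, 131072):
    size = score_matrix_bytes(n_heads=64, q_len=2048, k_len=k_len)
    print(f"kL={k_len:>6}: {size / GIB:.1f} GiB")
```

Under these assumptions the score tensor reaches 8 GiB at kL=32768 and doubles with each doubling of kL, which lines up with the "8+ GB single allocations" at 32K+ context.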
The fused steel_attention kernel with bd=256 is ~30% slower than the unfused (matmul + softmax + matmul) path. Route head_dim=256 to unfused by default and only use the fused kernel when key_sequence_length > 16384, where unfused would exceed Metal buffer limits.

Benchmark (M2 Ultra, H=64, qL=2048, float16):
  kL=16384: unfused 124ms vs fused 249ms (2.0x faster with routing)
  kL=32768: fused only (unfused crashes)

Vector path (qL<=8, decode) is unaffected; it already supports head_dim=256.
head_dim=192 follows the same pattern as head_dim=256: unfused by default for short sequences, fused when kL > 16384 (where unfused would exceed Metal buffer limits). Adds vector kernel instantiations for the decode path. Fixes ml-explore#3312.
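
The routing rule described in these commit messages can be summarized as a small decision function. This is an illustrative sketch only: the names `choose_full_attention_path`, `FUSED_ALWAYS`, and `FUSED_LONG_ONLY` are hypothetical and do not appear in MLX, and the sketch models only the prefill (full-attention) decision. The decode vector path (qL<=8) is handled separately and is not modeled here.

```python
# Head dims the fused full-attention kernel served before this change
# (per the first commit message).
FUSED_ALWAYS = {64, 80, 128}

# Head dims added by this PR: fused only past the crossover.
FUSED_LONG_ONLY = {192, 256}

# Crossover quoted in the commit messages.
FUSED_MIN_K_LEN = 16384


def choose_full_attention_path(head_dim: int, k_len: int) -> str:
    """Return "fused" or "unfused" for a prefill-sized query (qL > 8 assumed)."""
    if head_dim in FUSED_ALWAYS:
        return "fused"
    if head_dim in FUSED_LONG_ONLY and k_len > FUSED_MIN_K_LEN:
        # Unfused would need a very large score allocation here.
        return "fused"
    # Short sequences at 192/256 (unfused is faster there), and any
    # head_dim without a fused instantiation, take the unfused path.
    return "unfused"


for hd, kl in [(256, 16384), (256, 32768), (192, 65536), (128, 1024)]:
    print(hd, kl, choose_full_attention_path(hd, kl))
```

Note that the comparison is strict: at exactly kL=16384 the sketch still selects the unfused path, matching the "kL > 16384" wording above.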
- Add missing bd=192 steel_attention instantiation (use_fallback routes head_dim=192 to the fused full kernel, but only bd=256 was instantiated)
- Exclude head_dim >= 192 from the NAX dispatch branch: the NAX kernel family only instantiates bd=64/128, so those shapes go to the legacy steel kernel, which has the instantiations

Co-authored-by: Thump604 <thump@cosmiccooler.org>
Summary
Revival of #3293: adds head_dim=192/256 instantiations to the fused
steel_attention full-attention kernel and routes SDPA to it above the
kL>16384 dispatch crossover (below it the unfused path is faster and its
transient is bounded). Decode (sdpa_vector) already supports these head
dims; this closes the gap for prefill.
The three original commits by @Thump604 are preserved with authorship.
On top of them, one completeness commit for current main:
- adds the missing bd=192 steel instantiation that the routing already targeted
- excludes head_dim >= 192 from the NAX dispatch branch (the NAX kernel family has no 192/256 instantiations)
Motivation and real-world evidence: #3658. The unfused fallback
materializes a score transient that grows linearly with context length;
at 100K+ context on a 36 GB machine a serving runtime must shrink
prefill chunks to stay alive and GPU utilization collapses (26+ minute
single-turn prefills measured at ~133K with Gemma 4 26B).
Microbenchmarks (M4 Max, 36 GB)
n_q=16, n_kv=8 (GQA), head_dim=256, fp16; peak memory, inputs included.

Fused transient is ~16 MB regardless of kL (O(1)); unfused grows to
+10.3 GB at 131K. Outputs agree exactly.
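
For readers wondering why a fused kernel can keep its transient constant in kL, the sketch below shows the general online-softmax idea in plain Python for a single query vector. It is not the steel_attention kernel and makes no claim about its tiling; the function names, the block size, and the omission of the 1/sqrt(head_dim) scale are simplifications chosen for illustration.

```python
import math
import random


def naive_attention(q, keys, values):
    """Reference: materialize every score, then softmax and weighted sum."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def streaming_attention(q, keys, values, block=4):
    """Online softmax: visit keys block by block, keeping only a running
    max, a running normalizer, and an output accumulator."""
    dim = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        # Only `block` scores exist at any time, independent of len(keys).
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new running max.
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for d in range(dim):
                acc[d] += w * v[d]
        running_max = new_max
    return [a / running_sum for a in acc]


random.seed(0)
q = [random.gauss(0, 1) for _ in range(8)]
keys = [[random.gauss(0, 1) for _ in range(8)] for _ in range(37)]
values = [[random.gauss(0, 1) for _ in range(8)] for _ in range(37)]

ref = naive_attention(q, keys, values)
out = streaming_attention(q, keys, values, block=4)
max_err = max(abs(a - b) for a, b in zip(ref, out))
print("max abs err:", max_err)
```

In this float64 sketch the two results should agree to within floating-point rounding rather than bit-for-bit; the exact agreement reported above is a measurement of the actual kernels, not something this toy reproduces.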
Speed: above the crossover the fused kernel is 0.74–0.79x of unfused
per-kernel on M4 Max — consistent with @angeloskath's earlier
measurements, and the reason the routing keeps unfused below kL=16384
where its transient is affordable.
End-to-end (production serving, Mac Studio 36 GB)
oMLX server, chunked prefill (1024-token chunks), prefix KV cache
(4-bit quantized at rest), 30 GB memory ceiling with a 27 GB prefill
safety cap. The serving stack's prefill admission estimator was taught
the fused O(1) transient; everything else unchanged. Model: Gemma 4
26B-A4B MoE 3-bit, head_dim=256, 5/30 full-attention layers.
The honest summary: per-kernel the fused path is slower, but it converts
"context sizes that OOM or crawl" into "context sizes that just run".
For head_dim=256 models the practical context ceiling on this machine
moved +28K (cold) / +34K (multi-turn), and the failure mode past the
ceiling is now a clean rejection instead of a 26-minute crawl.
Validation
python/tests/test_fast_sdpa.py: 16 tests pass (1 skipped) on this branch;
test_fast.py: 23 pass
kL ∈ {32K..131K} (max abs err 0.0)
M3 Ultra
sweeps to 167K, cold prefills to 204K) with zero crashes
Credit: the kernel work is @Thump604's (#3293); this PR rebases it onto
current main with two completeness fixes and adds real-world serving
evidence (#3658).
🤖 Generated with Claude Code