Skip to content

Fused SDPA full-attention for head_dim=192/256 (revival of #3293)#3660

Open
hojin12312 wants to merge 4 commits into
ml-explore:mainfrom
hojin12312:sdpa-256-revival
Open

Fused SDPA full-attention for head_dim=192/256 (revival of #3293)#3660
hojin12312 wants to merge 4 commits into
ml-explore:mainfrom
hojin12312:sdpa-256-revival

Conversation

@hojin12312

Copy link
Copy Markdown

Summary

Revival of #3293: adds head_dim=192/256 instantiations to the fused
steel_attention full-attention kernel and routes SDPA to it above the
kL>16384 dispatch crossover (below it the unfused path is faster and its
transient is bounded). Decode (sdpa_vector) already supports these head
dims; this closes the gap for prefill.

The three original commits by @Thump604 are preserved with authorship.
On top of them, one completeness commit for current main:

  • add the missing bd=192 steel instantiation that the routing already
    targeted
  • exclude head_dim>=192 from the NAX dispatch branch (the NAX kernel
    family has no 192/256 instantiations)

Motivation and real-world evidence: #3658. The unfused fallback
materializes a score transient that grows linearly with context length;
at 100K+ context on a 36 GB machine a serving runtime must shrink
prefill chunks to stay alive and GPU utilization collapses (26+ minute
single-turn prefills measured at ~133K with Gemma 4 26B).

Microbenchmarks (M4 Max, 36 GB)

n_q=16, n_kv=8 (GQA), head_dim=256, fp16 — peak memory, inputs included:

qL kL inputs fused peak unfused peak max abs err
1024 32768 264M 272M 1,824M 0.0
2048 65536 528M 544M 5,696M 0.0
2048 131072 1,040M 1,056M 11,328M 0.0

Fused transient is ~16 MB regardless of kL (O(1)); unfused grows to
+10.3 GB at 131K. Outputs agree exactly.

Speed: above the crossover the fused kernel is 0.74–0.79x of unfused
per-kernel on M4 Max — consistent with @angeloskath's earlier
measurements, and the reason the routing keeps unfused below kL=16384
where its transient is affordable.

End-to-end (production serving, Mac Studio 36 GB)

oMLX server, chunked prefill (1024-token chunks), prefix KV cache
(4-bit quantized at rest), 30 GB memory ceiling with a 27 GB prefill
safety cap. The serving stack's prefill admission estimator was taught
the fused O(1) transient; everything else unchanged. Model: Gemma 4
26B-A4B MoE 3-bit, head_dim=256, 5/30 full-attention layers.

Scenario unfused (0.31.2) fused (this branch)
Max cold single-shot prefill ~156K (admission estimator rejects — the projected SDPA transient, not real memory, is the binding term) 184K completes (629 s); 204K is rejected gracefully at 182K, where resident fp16 KV genuinely exhausts the cap
Multi-turn conversation ceiling ~133K; beyond it chunk-shrink degrades a single turn's prefill to 26+ minutes 167K at 99% prefix-cache hit and 9–12 s/turn all the way up — the old 133K crawl zone passes at 8–9 s/turn; ends in a clean memory-cap rejection, no crash
One-shot ~25K-token paste on top of a 130K conversation impractical (crawl zone) 83 s, full prefix-cache hit, no memory event
Qwen3.6 35B-A3B MoE 3-bit (10 full-attn layers, hd=256), 145K cold 280 s 332 s (0.84x — where memory was never the binding constraint, the fused kernel's per-kernel cost shows; the trade is bounded memory)

The honest summary: per-kernel the fused path is slower, but it converts
"context sizes that OOM or crawl" into "context sizes that just run".
For head_dim=256 models the practical context ceiling on this machine
moved +28K (cold) / +34K (multi-turn), and the failure mode past the
ceiling is now a clean rejection instead of a 26-minute crawl.

Validation

  • python/tests/test_fast_sdpa.py: 16 tests pass (1 skipped) on this
    branch; test_fast.py: 23 pass
  • Output parity fused vs unfused verified at fp16/bf16, qL∈{1024,2048},
    kL∈{32K..131K} (max abs err 0.0)
  • Third-party validation history on the original PR (fix: add head_dim=256 to fused SDPA full attention kernel #3293): M2 Ultra,
    M3 Ultra
  • ~6 hours of production serving traffic on this build (multi-turn
    sweeps to 167K, cold prefills to 204K) with zero crashes

Credit: the kernel work is @Thump604's (#3293); this PR rebases it onto
current main with two completeness fixes and adds real-world serving
evidence (#3658).

🤖 Generated with Claude Code

Thump604 and others added 4 commits June 11, 2026 20:10
sdpa_full_supported_head_dim only included {64, 80, 128}. Models with
head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention
path which materializes the full score matrix as a single matmul.
At 32K+ context this creates 8+ GB single allocations that crash
Metal's buffer allocator.

Add head_dim=256 to the dispatch gate and instantiate steel_attention
kernel with bd=256. The Metal kernel template handles arbitrary BD
via template parameter — no kernel code changes needed.

Verified: 32K, 64K, 128K context on M2 Ultra with Qwen3.5-122B-A10B.
The fused steel_attention kernel with bd=256 is ~30% slower than the
unfused (matmul + softmax + matmul) path. Route head_dim=256 to unfused
by default and only use the fused kernel when key_sequence_length > 16384,
where unfused would exceed Metal buffer limits.

Benchmark (M2 Ultra, H=64, qL=2048, float16):
  kL=16384: unfused 124ms vs fused 249ms (2.0x faster with routing)
  kL=32768: fused only (unfused crashes)

Vector path (qL<=8, decode) is unaffected — already supports head_dim=256.
Same pattern as head_dim=256: unfused by default for short sequences,
fused when kL > 16384 (where unfused would exceed Metal buffer limits).

Adds vector kernel instantiations for decode path.

Fixes ml-explore#3312.
- Add missing bd=192 steel_attention instantiation (use_fallback routes
  head_dim=192 to the fused full kernel, but only bd=256 was instantiated)
- Exclude head_dim >= 192 from the NAX dispatch branch: the NAX kernel
  family only instantiates bd=64/128, so those shapes go to the legacy
  steel kernel which has the instantiations

Co-authored-by: Thump604 <thump@cosmiccooler.org>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants