Fix intermittent wrong bias gradient in fast::layer_norm VJP (Metal WAR hazard)#3630
Conversation
LayerNormVJP::eval_gpu computes the bias gradient with a reduction that reads the cotangent `g`, dispatched before the main vjp kernel. When the cotangent is donatable the kernel overwrites `g`'s buffer in place (gx or gw_temp alias g), which is a write-after-read hazard. The Metal command encoder uses concurrent dispatch and only auto-inserts barriers for read-after-write, so the reduction races the overwrite and the bias gradient is intermittently wrong (error on the order of the value magnitude); the gradients for x and w are unaffected, and RMSNorm has no bias so it is not hit. This test loops a batched LayerNorm bias-gradient VJP on the GPU and compares against a deterministic CPU reference. It is reliably red on the current code and turns green once a barrier is inserted after the reduction. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Insert an explicit barrier after the bias-gradient reduction in LayerNormVJP::eval_gpu. The reduction reads the cotangent `g`, and the vjp kernel dispatched after it overwrites `g`'s buffer in place when the cotangent is donated (gx or gw_temp alias g). The command encoder uses concurrent dispatch and its automatic barriers only cover read-after-write, so this write-after-read hazard was uncovered and the two dispatches could overlap, yielding an intermittently wrong bias gradient. The barrier keeps the reduction's read of `g` from racing the kernel's overwrite. This turns the regression test added in the previous commit green. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
zcbenz
left a comment
There was a problem hiding this comment.
The test just always passed on my machine. Does it fail on your machine without the barrier?
It also does not quite make sense to me, the kernels submitted to the same queue are executed sequentially so no memory barrier is needed (otherwise lots of ops would break).
|
Yes — it fails reliably for me without the barrier, on an Apple M4 (10-core, Metal 4). Reverting just the On the sequential point: the compute encoder is created with This case is WAR: the bias reduction only reads I also experimented with a more general fix here (tillahoffmann/jax-mps#169), but wanted to propose this minimal fix first. Another option might be reordering so the reduction isn't followed by an in-place overwrite of |
zcbenz
left a comment
There was a problem hiding this comment.
Thanks for the elaboration, I think it makes sense. However I can not reproduce the problem with your test case on either M3 Max or M5 Max, which is reasonable though since it highly relies on concurrent dispatches.
I'm good with the change, but we would need another review to merge.
|
Thanks that's a very nice catch! I think we should definitely do the general fix though. Basically if the current output is a previous input then we should put a barrier. Besides catching other similar issues the current solution is even pessimistic in the case we couldn't donate the Let me try to cook something up in this branch and I 'll ask you to verify @tillahoffmann . |
|
I think the general fix turned out pretty concise as well. I will run some benchmarks to make sure it doesn't have any massive effect by adding too many barriers (it shouldn't) and then after @tillahoffmann verifies that it also fixes the test in his case we can merge. |
|
Verified the general fix on an Apple M4 (10-core, Metal 4) — the machine that reproduces the race. Method: built the branch at
I also stress-tested the cross-encoder path, since the fix is intra-encoder only: ran the full suite with LGTM — the general intra-encoder fix fully resolves the bias-grad race here. Thanks for the quick turnaround! |
Summary
fast::layer_norm's backward returns an intermittently wrong gradient w.r.t. the bias on the Metal GPU. The gradients w.r.t. the input and weight are always correct, and the CPU backend is always correct and deterministic. The bias gradient varies across identical repeated calls, so it is a race rather than a numerical issue. RMSNorm has no bias and is unaffected.Root cause — a write-after-read (WAR) hazard
In
LayerNormVJP::eval_gpu(mlx/backend/metal/normalization.cpp):gbis computed by astrided_reduce_general_dispatchthat reads the cotangentg, dispatched before the main vjp kernel.g's buffer in place: when the cotangent is donatable,gx(gx.copy_shared_buffer(g)) orgw_temp(gw_temp.copy_shared_buffer(g)) aliasg, andgw_tempis bound as a kernel output.MTL::DispatchTypeConcurrent, so dispatches overlap unless separated by amemoryBarrier. The encoder's automatic barriers only cover read-after-write (set_input_arraychecksprev_outputs_, which records prior outputs only — seeCommandEncoder::maybeInsertBarrierinmlx/backend/metal/device.cpp). A prior dispatch's reads are never recorded, so this write-after-read hazard is not covered. The two dispatches overlap and the reduction reads a partially-overwritteng.A batched input makes the cotangent donatable, which is what triggers the aliasing. RMSNorm's only reduction reads
gw_temp(an output of its kernel), a read-after-write that the auto-barrier already handles — which is why only LayerNorm'sd/dbis affected.Fix
Insert an explicit
compute_encoder.barrier()after the bias-gradient reduction so its read ofgcompletes before the kernel overwrites the buffer. This is a targeted fix for a general gap (the encoder does not model WAR hazards); the barrier is the minimal, local correction and matches howbarrier()is used elsewhere as an escape hatch.Testing (red → green TDD)
Added
TEST_CASE("test layer norm vjp bias grad race")intests/gpu_tests.cpp(Metal-only). It runs a batched LayerNorm bias-gradient VJP on the GPU many times and compares against a deterministic CPU reference (with a CPU-vs-CPU self-check first).0.050,0.126,0.209(tolerance1e-5) across runs, with a different magnitude each time, confirming a race. Trips well within the 3000-iteration loop.pre-commit(clang-format) clean.The two commits are split test-then-fix to make the red→green explicit.