Skip to content

Fix intermittent wrong bias gradient in fast::layer_norm VJP (Metal WAR hazard)#3630

Merged
angeloskath merged 4 commits into
ml-explore:mainfrom
tillahoffmann:fix-layernorm-vjp-bias-war-hazard
Jun 12, 2026
Merged

Fix intermittent wrong bias gradient in fast::layer_norm VJP (Metal WAR hazard)#3630
angeloskath merged 4 commits into
ml-explore:mainfrom
tillahoffmann:fix-layernorm-vjp-bias-war-hazard

Conversation

@tillahoffmann

Copy link
Copy Markdown
Contributor

Summary

fast::layer_norm's backward returns an intermittently wrong gradient w.r.t. the bias on the Metal GPU. The gradients w.r.t. the input and weight are always correct, and the CPU backend is always correct and deterministic. The bias gradient varies across identical repeated calls, so it is a race rather than a numerical issue. RMSNorm has no bias and is unaffected.

Root cause — a write-after-read (WAR) hazard

In LayerNormVJP::eval_gpu (mlx/backend/metal/normalization.cpp):

  1. The bias gradient gb is computed by a strided_reduce_general_dispatch that reads the cotangent g, dispatched before the main vjp kernel.
  2. The main vjp kernel then writes g's buffer in place: when the cotangent is donatable, gx (gx.copy_shared_buffer(g)) or gw_temp (gw_temp.copy_shared_buffer(g)) alias g, and gw_temp is bound as a kernel output.
  3. The command encoder is created with MTL::DispatchTypeConcurrent, so dispatches overlap unless separated by a memoryBarrier. The encoder's automatic barriers only cover read-after-write (set_input_array checks prev_outputs_, which records prior outputs only — see CommandEncoder::maybeInsertBarrier in mlx/backend/metal/device.cpp). A prior dispatch's reads are never recorded, so this write-after-read hazard is not covered. The two dispatches overlap and the reduction reads a partially-overwritten g.

A batched input makes the cotangent donatable, which is what triggers the aliasing. RMSNorm's only reduction reads gw_temp (an output of its kernel), a read-after-write that the auto-barrier already handles — which is why only LayerNorm's d/db is affected.

Fix

Insert an explicit compute_encoder.barrier() after the bias-gradient reduction so its read of g completes before the kernel overwrites the buffer. This is a targeted fix for a general gap (the encoder does not model WAR hazards); the barrier is the minimal, local correction and matches how barrier() is used elsewhere as an escape hatch.

Testing (red → green TDD)

Added TEST_CASE("test layer norm vjp bias grad race") in tests/gpu_tests.cpp (Metal-only). It runs a batched LayerNorm bias-gradient VJP on the GPU many times and compares against a deterministic CPU reference (with a CPU-vs-CPU self-check first).

  • Red (test commit alone): reliably fails — observed worst diffs of 0.050, 0.126, 0.209 (tolerance 1e-5) across runs, with a different magnitude each time, confirming a race. Trips well within the 3000-iteration loop.
  • Green (with the fix): passes across repeated runs, all 3000 iterations matching CPU exactly.
  • Full C++ suite: 253/253 cases, 3443 assertions, 0 failures.

pre-commit (clang-format) clean.

The two commits are split test-then-fix to make the red→green explicit.

tillahoffmann and others added 2 commits June 4, 2026 22:23
LayerNormVJP::eval_gpu computes the bias gradient with a reduction that
reads the cotangent `g`, dispatched before the main vjp kernel. When the
cotangent is donatable the kernel overwrites `g`'s buffer in place (gx or
gw_temp alias g), which is a write-after-read hazard. The Metal command
encoder uses concurrent dispatch and only auto-inserts barriers for
read-after-write, so the reduction races the overwrite and the bias
gradient is intermittently wrong (error on the order of the value
magnitude); the gradients for x and w are unaffected, and RMSNorm has no
bias so it is not hit.

This test loops a batched LayerNorm bias-gradient VJP on the GPU and
compares against a deterministic CPU reference. It is reliably red on
the current code and turns green once a barrier is inserted after the
reduction.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Insert an explicit barrier after the bias-gradient reduction in
LayerNormVJP::eval_gpu. The reduction reads the cotangent `g`, and the
vjp kernel dispatched after it overwrites `g`'s buffer in place when the
cotangent is donated (gx or gw_temp alias g). The command encoder uses
concurrent dispatch and its automatic barriers only cover
read-after-write, so this write-after-read hazard was uncovered and the
two dispatches could overlap, yielding an intermittently wrong bias
gradient.

The barrier keeps the reduction's read of `g` from racing the kernel's
overwrite. This turns the regression test added in the previous commit
green.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test just always passed on my machine. Does it fail on your machine without the barrier?

It also does not quite make sense to me, the kernels submitted to the same queue are executed sequentially so no memory barrier is needed (otherwise lots of ops would break).

@tillahoffmann

Copy link
Copy Markdown
Contributor Author

Yes — it fails reliably for me without the barrier, on an Apple M4 (10-core, Metal 4). Reverting just the barrier() line (keeping the test), it fails every run with worst ≈ 0.087 vs the 1e-5 tolerance; with the barrier it passes. So I think "always passed" is scheduler-dependent — the two dispatches overlap on the M4 (under relatively heavy load) and don't seem to on yours.

On the sequential point: the compute encoder is created with MTL::DispatchTypeConcurrent (device.cpp:476), so consecutive dispatches can overlap unless MLX inserts a memoryBarrier itself. That only happens (maybeInsertBarrier, device.cpp:345) when an input is found in prev_outputs_ — buffers previously written (device.cpp:308). So the auto-barriers cover RAW/WAW but not WAR.

This case is WAR: the bias reduction only reads g, so g never enters prev_outputs_. The next kernel writes gx/gw_temp, which alias g when the cotangent is donated (normalization.cpp:337/355), so no barrier is emitted and the read and overwrite can overlap. It doesn't surface elsewhere because the uncovered shape is narrow — a read-only dispatch followed by a concurrent overwrite of the same buffer via donation. RMSNormVJP sidesteps it (reduction runs after the kernel, reads gw_temp, no bias).

I also experimented with a more general fix here (tillahoffmann/jax-mps#169), but wanted to propose this minimal fix first. Another option might be reordering so the reduction isn't followed by an in-place overwrite of g, or not donating g while a bias reduction is pending.

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the elaboration, I think it makes sense. However I can not reproduce the problem with your test case on either M3 Max or M5 Max, which is reasonable though since it highly relies on concurrent dispatches.

I'm good with the change, but we would need another review to merge.

@angeloskath

Copy link
Copy Markdown
Member

Thanks that's a very nice catch!

I think we should definitely do the general fix though. Basically if the current output is a previous input then we should put a barrier. Besides catching other similar issues the current solution is even pessimistic in the case we couldn't donate the g.

Let me try to cook something up in this branch and I 'll ask you to verify @tillahoffmann .

@angeloskath

Copy link
Copy Markdown
Member

I think the general fix turned out pretty concise as well.

I will run some benchmarks to make sure it doesn't have any massive effect by adding too many barriers (it shouldn't) and then after @tillahoffmann verifies that it also fixes the test in his case we can merge.

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a quite nice fix!

@tillahoffmann

Copy link
Copy Markdown
Contributor Author

Verified the general fix on an Apple M4 (10-core, Metal 4) — the machine that reproduces the race.

Method: built the branch at 3ebbadc and ran the regression test as a red→green check.

  • Red (no fix): rebuilt with both the LayerNorm barrier and the general fix removed (test-only, d2bf6e33). Reliably fails — 15/15 runs red, worst bias-grad diff 0.27 vs the 1e-5 tolerance.
  • Green (general fix): 0 failures across 58 invocations (30 isolated + 8 concurrent + 20 final), each looping 3000× internally ≈ 170k VJPs, all matching the CPU reference.
  • Full suite: 253/253 cases pass.

I also stress-tested the cross-encoder path, since the fix is intra-encoder only: ran the full suite with MLX_MAX_OPS_PER_BUFFER=1 MLX_MAX_MB_PER_BUFFER=1 (forces a commit after nearly every op, turning op-to-op deps into cross-encoder ones). 253/253 across 3 runs. And I tried to construct a cross-encoder WAR directly (independent reader + in-place donator sharing a buffer) — couldn't trigger it, because in-place overwrite requires is_donatable() (use_count == 1), which is mutually exclusive with having a second concurrent reader. The only read+overwrite-same-buffer case is intra-primitive (always intra-encoder), which this fix covers.

LGTM — the general intra-encoder fix fully resolves the bias-grad race here. Thanks for the quick turnaround!

@angeloskath angeloskath merged commit 269e099 into ml-explore:main Jun 12, 2026
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants