Fix Select::jvp tangent indexing (silent zero JVP through mx.where)#3643

Closed
nac7 wants to merge 1 commit into
ml-explore:mainfrom
nac7:fix/select-jvp-tangent-indexing
@nac7 nac7 commented Jun 8, 2026


Summary

Fixes #3627.

mx.jvp silently returned a zero tangent when differentiating through mx.where(cond, constant, traced_value), that is, when the constant is in the true branch and the traced value is in the false branch.

Root Cause

Select::jvp in mlx/primitives.cpp had three related indexing errors.

The JVP driver in transforms.cpp builds a compact tangent vector: tangents[i] is the tangent for argnums[i], not for argument number i. Every other primitive JVP follows this positional convention. Select::jvp did not:

  1. Wrong assertion: assert(tangents.size() == 3) assumed a full-length vector; corrected to assert(tangents.size() == argnums.size()).

  2. Wrong call sites: jvp_fun(argnums[0]) and jvp_fun(argnums[i]) passed argument numbers (1 or 2) as the positional index i. Inside the lambda, int arg = argnums[i] then accessed argnums out-of-bounds (UB in release mode). In practice the OOB read returned 0, hitting the arg == 0 branch and returning zeros_like. Corrected to jvp_fun(0) / jvp_fun(i).

  3. Wrong tangent access: tangents[1] and tangents[2] were hard-coded absolute argument indices instead of tangents[i]. Corrected to tangents[i].
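
To illustrate the convention, here is a minimal pure-Python model of the corrected indexing. It is a sketch under stated assumptions, not the code in mlx/primitives.cpp: values are plain scalars rather than arrays, and the names select, compact, and select_jvp are hypothetical helpers introduced only for this example.

```python
def select(cond, a, b):
    """Scalar stand-in for where(cond, a, b)."""
    return a if cond else b


def compact(full_tangents, argnums):
    """Model of the driver: keep one tangent per differentiated argument,
    in argnums order, so that result[i] belongs to argument argnums[i]."""
    return [full_tangents[arg] for arg in argnums]


def select_jvp(primals, tangents, argnums):
    """Hypothetical scalar model of the corrected JVP.

    tangents is compact: tangents[i] is the tangent of argument argnums[i].
    """
    assert len(primals) == 3
    assert len(tangents) == len(argnums)  # not a fixed length of 3
    cond = primals[0]

    def jvp_fun(i):
        # i is a position in argnums/tangents, never an argument number.
        arg = argnums[i]
        if arg == 0:
            return 0.0  # the condition contributes no tangent
        if arg == 1:
            return select(cond, tangents[i], 0.0)  # true branch
        return select(cond, 0.0, tangents[i])      # false branch

    out = jvp_fun(0)
    for i in range(1, len(argnums)):
        out += jvp_fun(i)
    return out


# where(cond, const, traced): only argument 2 carries a tangent.
full = [None, None, 2.0]
tangents = compact(full, [2])
print(select_jvp([False, 5.0, 3.0], tangents, [2]))  # 2.0 (false branch selected)
print(select_jvp([True, 5.0, 3.0], tangents, [2]))   # 0.0 (constant selected)
```

In this model the single tangent for argument 2 sits at position 0, which is why indexing by argument number (tangents[2], or argnums[2]) cannot work once the vector is compact.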

Tests

Added test_jvp_where_tangent_indexing in test_autograd.py covering all four variants:

| Variant | Description | Before | After |
| --- | --- | --- | --- |
| A | where(cond, traced, const) | correct | correct |
| B | where(cond, const, traced) | returned 0 (bug) | returns 2.0 |
| C | where(cond, traced, traced) | correct | correct |
| D | batched false-branch-only input | broken | correct |

The tangents vector passed to primitive jvp methods is compact:
tangents[i] corresponds to argnums[i], not to argument number i.
Select::jvp had three related errors:

- assert(tangents.size() == 3) assumed a full-length vector; corrected
  to assert(tangents.size() == argnums.size()).
- jvp_fun was called as jvp_fun(argnums[j]), passing argument numbers
  (1 or 2) as the positional index i; the lambda then accessed
  argnums[i] out-of-bounds. Corrected to jvp_fun(j).
- Inside jvp_fun, tangents[1] and tangents[2] were hard-coded absolute
  indices; corrected to tangents[i] (the positional index).

Together these caused mx.jvp to silently return zero when
mx.where(cond, constant, traced_value) was differentiated and only
the false branch carried a tangent.

zcbenz commented Jun 8, 2026

Collaborator

Thanks for the PR but this had been covered by #3633, please let me know if the PR was missing anything.

@zcbenz zcbenz closed this Jun 8, 2026
Successfully merging this pull request may close these issues.

[BUG] JVP silently returns a zero tangent when differentiating through mx.where