Fix CPU gather transposing column-contiguous slices#3647

Open
tillahoffmann wants to merge 2 commits into
ml-explore:mainfrom
tillahoffmann:fix-cpu-chained-gather

@tillahoffmann

Contributor

Summary

mx.take (gather) on the CPU backend returns a transposed result (the right elements in the wrong order) when the source is column-contiguous and the gathered slice spans more than one non-singleton dimension. The GPU backend is correct.

The gather "fast copy" path in mlx/backend/cpu/indexing.cpp decided whether each gathered slice could be copied as a single contiguous block based on the source's row_contiguous / col_contiguous flags. The gather output is always written in row-major order, but a column-contiguous source is contiguous in memory in column-major order — so copying a multi-dimensional slice as a raw block transposes it.

This surfaces via:

  • chained take through size-1 axes, which produces a column-contiguous intermediate, and
  • a direct take from a transposed (column-contiguous) source.
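
To illustrate why a raw block copy goes wrong here, the following is a NumPy-only sketch (it does not call MLX or the code in indexing.cpp). It assumes a small 2x3 slice stored column-major and compares the element order a gather should emit with the order a copy of the underlying memory block would emit. The variable names are made up for this example.

```python
import numpy as np

# A logical 2x3 slice whose memory layout is column-major (Fortran order).
src = np.asfortranarray(np.arange(6).reshape(2, 3))

# What gather must write: the slice's elements in row-major (logical) order.
expected = src.ravel(order="C")

# What a raw block copy writes: the elements in column-major order, which for
# a column-contiguous array is the order they sit in memory.
raw_block = src.ravel(order="F")

print("expected :", expected)   # row-major traversal
print("raw block:", raw_block)  # memory-order traversal
```

The two sequences hold the same values in different orders, which is the kind of mismatch shown in the "before" output of the repro.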

Repro

import mlx.core as mx
import numpy as np

# 1) chained take through size-1 axes
with mx.stream(mx.cpu):
    u = mx.array([1.0, 2.0], dtype=mx.float32).reshape(2, 1, 1)
    g = mx.take(u, mx.array([0, 1], dtype=mx.int32), axis=0)
    g = mx.take(g, mx.array([0, 0, 0], dtype=mx.int32), axis=1)
    g = mx.take(g, mx.array([0, 0, 0], dtype=mx.int32), axis=2)
    mx.eval(g)

# 2) direct take from a transposed (col-contiguous) source
with mx.stream(mx.cpu):
    a = mx.transpose(mx.arange(24).reshape(4, 3, 2), (2, 1, 0))  # [2,3,4] col-contiguous
    t = mx.take(a, mx.array([0, 1], dtype=mx.int32), axis=2)
    mx.eval(t)

Before this fix (disagrees with NumPy / the GPU backend):

chained take  mlx  : [1 1 1 2 2 2 1 1 1 2 2 2 1 1 1 2 2 2]
              numpy: [1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2]
col take      mlx  : [0 6 1 7 2 8 3 9 4 10 5 11]
              numpy: [0 6 2 8 4 10 1 7 3 9 5 11]

After this fix both match NumPy.
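
For reference, the "numpy" rows above can be reproduced with a sketch like the one below, which mirrors the two repro cases using np.take. The names chained_ref and col_ref are introduced only for this example.

```python
import numpy as np

# 1) chained take through size-1 axes
u = np.array([1.0, 2.0], dtype=np.float32).reshape(2, 1, 1)
g = np.take(u, [0, 1], axis=0)
g = np.take(g, [0, 0, 0], axis=1)
g = np.take(g, [0, 0, 0], axis=2)
chained_ref = g.ravel()  # flattened in row-major order, shape (2, 3, 3) -> 18 values

# 2) direct take from a transposed source
a = np.transpose(np.arange(24).reshape(4, 3, 2), (2, 1, 0))  # shape (2, 3, 4)
col_ref = np.take(a, [0, 1], axis=2).ravel()

print("chained take numpy:", chained_ref)
print("col take     numpy:", col_ref)
```

Assuming the MLX results from the repro are converted with np.array(g) and np.array(t), np.array_equal against these references should hold with the fix applied.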

Originally found while differential-testing a CPU reference against the GPU/MPS path; a downstream consumer reaches this through a dynamic_update_slice lowering that expands to a chained-take + mask + where.

Fix

Replace the flag-based heuristic with a direct check that the slice is genuinely row-major contiguous within the source: each non-singleton slice dimension's source stride must equal the product of the slice sizes of the dimensions inside it (size-1 dimensions are skipped, since their stride is irrelevant). This is the exact precondition for the contiguous std::copy; everything else falls back to the strided iterator. The check is also strictly more correct than the previous row-contiguous branch.
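
The check described above can be sketched in Python as follows. This is an illustration of the stated rule, not the C++ code in mlx/backend/cpu/indexing.cpp; the function name and the assumption that strides are measured in elements (not bytes) are specific to this sketch.

```python
from typing import Sequence


def slice_is_row_contiguous(
    src_strides: Sequence[int], slice_sizes: Sequence[int]
) -> bool:
    """Return True if a slice of the given sizes is row-major contiguous
    within a source that has the given strides (in elements)."""
    expected_stride = 1
    # Walk from the innermost dimension outwards.
    for stride, size in zip(reversed(src_strides), reversed(slice_sizes)):
        if size == 1:
            # A size-1 dimension is never stepped over, so its stride is irrelevant.
            continue
        if stride != expected_stride:
            return False
        # The next non-singleton dimension must step over everything inside it.
        expected_stride *= size
    return True


# Row-contiguous source of shape (2, 3, 4), one index taken along axis 0.
print(slice_is_row_contiguous((12, 4, 1), (1, 3, 4)))
# Column-contiguous source of shape (2, 3, 4) with strides (1, 2, 6),
# take along axis 2: the slice spans two non-singleton dimensions.
print(slice_is_row_contiguous((1, 2, 6), (2, 3, 1)))
# Same source, but the slice spans a single non-singleton dimension.
print(slice_is_row_contiguous((1, 2, 6), (2, 1, 1)))
```

Under these assumptions the second case is the one that must take the strided fallback, while the third remains eligible for a block copy even though the source is column-contiguous.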

Testing

  • Added a "test gather contiguity" C++ regression case (tests/ops_tests.cpp) covering both scenarios, forced onto the CPU device.
  • Full C++ test suite passes (253 cases / 3445 assertions).
  • Verified end-to-end through the Python wrapper: the repro above matches NumPy with the fix and mismatches without it.

🤖 Generated with Claude Code

@zcbenz zcbenz left a comment

Collaborator

Thanks for the fix, I think it works correctly, but without background knowledge I would need more reviews on this.

tillahoffmann and others added 2 commits June 11, 2026 08:29
The gather "fast copy" path used the row/col contiguous flags to decide
whether a per-index slice could be copied as a single contiguous block.
For a column-contiguous source the slice is contiguous in memory but in
column-major order, while the output is written in row-major order, so a
multi-dimensional slice came out transposed.

This surfaced on the CPU backend via chained `take` through size-1 axes
(which yield a col-contiguous intermediate) and via a direct `take` from a
transposed source; the GPU backend was correct.

Replace the flag-based heuristic with a direct check that the slice is
row-major contiguous within the source (each non-singleton slice dim's
source stride equals the product of the inner slice sizes), which is the
exact precondition for the contiguous copy. Falls back to the strided
iterator otherwise.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@zcbenz zcbenz force-pushed the fix-cpu-chained-gather branch from e5b97b3 to 8e20a28 Compare June 10, 2026 23:31