fix(shapeless): infer dynamic dim in Unflatten::output_shapes (#2607) #3649

Open
nac7 wants to merge 2 commits into ml-explore:main from nac7:fix/shapeless-unflatten-mid-dim

nac7 commented Jun 9, 2026

What

mx.matmul(x, W) where x has ndim > 2 and W has ndim ≤ 2
is implemented in ops.cpp as:

Flatten(x, 0, -2)  →  Matmul  →  Unflatten(out, 0, orig_shape)

orig_shape is x.shape()[:-1] at trace time (e.g. {1, 40}), which
is stored verbatim as state inside the Unflatten primitive.

In shapeless mode, compile_replace calls output_shapes on every
primitive to recompute output shapes for the actual inputs.
Unflatten::output_shapes returned the stored shape unchanged, so a
call with seq_len=248 would produce (1, 40, H) instead of
(1, 248, H) — silently wrong.
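
To make the failure concrete, here is a minimal pure-Python sketch of the shape arithmetic described above. It does not use MLX; `flatten_then_matmul` and `stale_unflatten_shape` are hypothetical helper names, and the sizes K=64 and H=32 are made up for illustration.

```python
from math import prod


def flatten_then_matmul(x_shape, w_shape):
    """Shape after Flatten(x, 0, -2) followed by a 2-D Matmul with W."""
    *lead, k = x_shape
    assert w_shape[0] == k, "inner dimensions must agree"
    # All leading dims collapse into one flat dim; Matmul maps K -> H.
    return (prod(lead), w_shape[1])


def stale_unflatten_shape(out2d_shape, stored_shape):
    """Pre-fix behaviour as described: the stored shape is used verbatim,
    so the actual flat size out2d_shape[0] is never consulted."""
    return (*stored_shape, out2d_shape[1])


# Trace with seq_len=40: the Unflatten state becomes x.shape[:-1] == (1, 40).
stored = (1, 40)

# Replay in shapeless mode with seq_len=248.
out2d = flatten_then_matmul((1, 248, 64), (64, 32))
print(out2d)                                 # (248, 32)
print(stale_unflatten_shape(out2d, stored))  # (1, 40, 32), expected (1, 248, 32)
```

The reported shape keeps the trace-time 40 even though the flat dimension is 248, which is the mismatch this PR addresses.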

Fix

mlx/primitives.cpp, Unflatten::output_shapes:

When inputs[0].shape()[axis_] ≠ product(shape_), scan shape_
right-to-left for the unique dimension whose removal leaves a divisor
of the actual flat size, then scale that dimension to absorb the new
size. If no such dimension exists, raise std::invalid_argument.

This matches the common transformer pattern where leading batch
dimensions are static and seq_len is dynamic.
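
The inference rule can be sketched in pure Python as follows. This is one reading of the description, not the C++ code in the PR: `infer_unflatten_shape` is a hypothetical name, the sketch takes the first qualifying dimension when scanning right-to-left (an assumption about how "unique" is resolved), it assumes a non-negative axis, and `ValueError` stands in for `std::invalid_argument`.

```python
from math import prod


def infer_unflatten_shape(in_shape, axis, stored_shape):
    """Output shape of Unflatten when the flat size may differ from the trace."""
    in_shape = tuple(in_shape)
    stored_shape = tuple(stored_shape)
    flat = in_shape[axis]
    new_shape = list(stored_shape)

    if flat != prod(stored_shape):
        # Scan right-to-left for a dim whose removal leaves a divisor of flat.
        for i in range(len(stored_shape) - 1, -1, -1):
            rest = prod(stored_shape[:i] + stored_shape[i + 1:])
            if rest > 0 and flat % rest == 0:
                new_shape[i] = flat // rest  # this dim absorbs the new size
                break
        else:
            raise ValueError(
                f"cannot unflatten size {flat} into shape {stored_shape}"
            )

    return in_shape[:axis] + tuple(new_shape) + in_shape[axis + 1:]


# Stored state (1, 40) from tracing; actual flat size 248 at call time.
print(infer_unflatten_shape((248, 32), 0, (1, 40)))   # (1, 248, 32)
# Higher-rank case: B=2, seq_len=623 gives a flat size of 1246.
print(infer_unflatten_shape((1246, 32), 0, (2, 40)))  # (2, 623, 32)
```

Scanning from the right means the trailing (sequence) dimension is preferred when more than one dimension could absorb the change, for example a flat size of 160 against a stored shape of (2, 40) yields (2, 80) rather than (4, 40).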

Tests

Two new tests in python/tests/test_export_import.py:

  • test_export_matmul_shapeless_mid_dim — the exact reproduction
    case from #2607 ("[BUG] shapeless matmul isn't"; B=1, varying
    seq_len ∈ {40, 248, 623}); checks
    both shape and numerical agreement with the reference function.
  • test_export_matmul_shapeless_batch_and_mid_dim — higher-rank
    variant (B=2, varying seq_len); ensures the fix generalises
    beyond the single-batch case.

Fixes #2607

When mx.matmul is called on a 3-D+ input with a 2-D weight matrix,
ops.cpp flattens the batch dims (Flatten), runs the 2-D Matmul, then
unflattens (Unflatten).  The Unflatten primitive stores the trace-time
batch shape (e.g. {1, 40}) as state.

In shapeless mode, compile_replace calls output_shapes on each
primitive to recompute output shapes for the new input.  The previous
implementation of Unflatten::output_shapes used the stored shape
verbatim, so a call with seq_len=248 would return (1, 40, H) instead
of (1, 248, H) -- silently wrong.

Fix: when the actual flat size differs from product(shape_), scan
shape_ right-to-left for the unique dimension whose removal yields a
divisor of the actual flat size and scale it accordingly.  This
handles the typical transformer pattern where batch dims are static
and seq_len is dynamic.  Throw if no valid dimension is found.

Adds two regression tests from issue ml-explore#2607:
  - test_export_matmul_shapeless_mid_dim   (B=1, varying seq_len)
  - test_export_matmul_shapeless_batch_and_mid_dim (B=2, varying seq_len)

Fixes ml-explore#2607

zcbenz commented Jun 10, 2026

Collaborator

Can you fix the lint error?

nac7 commented Jun 12, 2026

Author

Hi @zcbenz, fixed the lint error. Thanks for your review!