fix(shapeless): infer dynamic dim in Unflatten::output_shapes (#2607)#3649
Open
nac7 wants to merge 2 commits into
Open
fix(shapeless): infer dynamic dim in Unflatten::output_shapes (#2607)#3649nac7 wants to merge 2 commits into
nac7 wants to merge 2 commits into
Conversation
When mx.matmul is called on a 3-D+ input with a 2-D weight matrix,
ops.cpp flattens the batch dims (Flatten), runs the 2-D Matmul, then
unflattens (Unflatten). The Unflatten primitive stores the trace-time
batch shape (e.g. {1, 40}) as state.
In shapeless mode, compile_replace calls output_shapes on each
primitive to recompute output shapes for the new input. The previous
implementation of Unflatten::output_shapes used the stored shape
verbatim, so a call with seq_len=248 would return (1, 40, H) instead
of (1, 248, H) -- silently wrong.
Fix: when the actual flat size differs from product(shape_), scan
shape_ right-to-left for the unique dimension whose removal yields a
divisor of the actual flat size and scale it accordingly. This
handles the typical transformer pattern where batch dims are static
and seq_len is dynamic. Throw if no valid dimension is found.
Adds two regression tests from issue ml-explore#2607:
- test_export_matmul_shapeless_mid_dim (B=1, varying seq_len)
- test_export_matmul_shapeless_batch_and_mid_dim (B=2, varying seq_len)
Fixes ml-explore#2607
Collaborator
|
Can you fix the lint error? |
Author
|
Hi @zcbenz , fixed the lint error. Thanks for your review! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
mx.matmul(x, W)wherexhasndim > 2andWhasndim ≤ 2is implemented in
ops.cppas:orig_shapeisx.shape()[:-1]at trace time (e.g.{1, 40}), whichis stored verbatim as state inside the
Unflattenprimitive.In shapeless mode
compile_replacecallsoutput_shapeson everyprimitive to recompute output shapes for the actual inputs.
Unflatten::output_shapesreturned the stored shape unchanged, so acall with
seq_len=248would produce(1, 40, H)instead of(1, 248, H)— silently wrong.Fix
mlx/primitives.cpp—Unflatten::output_shapes:When
inputs[0].shape()[axis_] ≠ product(shape_), scanshape_right-to-left for the unique dimension whose removal leaves a divisor
of the actual flat size, then scale that dimension to absorb the new
size. If no such dimension exists, raise
std::invalid_argument.This matches the common transformer pattern where leading batch
dimensions are static and
seq_lenis dynamic.Tests
Two new tests in
python/tests/test_export_import.py:test_export_matmul_shapeless_mid_dim— the exact reproductioncase from [BUG] shapeless matmul isn't #2607 (
B=1, varyingseq_len∈ {40, 248, 623}); checksboth shape and numerical agreement with the reference function.
test_export_matmul_shapeless_batch_and_mid_dim— higher-rankvariant (
B=2, varyingseq_len); ensures the fix generalisesbeyond the single-batch case.
Fixes #2607