Skip to content

Implement batched matmul for large 1D dot products#3580

Open
Ved235 wants to merge 5 commits into
ml-explore:mainfrom
Ved235:main
Open

Implement batched matmul for large 1D dot products#3580
Ved235 wants to merge 5 commits into
ml-explore:mainfrom
Ved235:main

Conversation

@Ved235

@Ved235 Ved235 commented May 22, 2026

Copy link
Copy Markdown

Proposed changes

Addresses issue #3533. Adds routing logic in mlx/ops.cpp so that it divides the large 1D dot product into chunks so gemv parallelizes.

Benchmark

import mlx.core as mx
import numpy as np
import time

def bench(fn, rounds=100, label=""):
    for _ in range(3):
        r = fn()
        mx.eval(r)

    times = []
    for _ in range(rounds):
        mx.eval()  
        t0 = time.perf_counter()
        r = fn()
        mx.eval(r) 
        times.append(time.perf_counter() - t0)

    times.sort()
    median = times[len(times) // 2]
    best = times[0]
    worst = times[-1]
    print(f"{label}")
    print(f"median={median*1000:.3f}ms | min={best*1000:.3f}ms | max={worst*1000:.3f}ms")
    return r

a = mx.random.normal(shape=(50_000_000,), dtype=mx.float32)
b = mx.random.normal(shape=(50_000_000,), dtype=mx.float32)

a_np = np.array(a, copy=False)
b_np = np.array(b, copy=False)

ccc = bench(lambda: mx.inner(a, b), label="MLX native")

print(f"mx.inner : {float(ccc)}")

Using this benchmarking script the performance changes are:

median=15.393ms | min=15.323ms | max=15.769ms

to

median=1.741ms | min=1.682ms | max=1.835ms

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@Ved235

Ved235 commented May 23, 2026

Copy link
Copy Markdown
Author

@zcbenz would it possible for you to review this?

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This basically looks good to me, thanks!

Comment thread mlx/ops.cpp Outdated
@Ved235

Ved235 commented May 26, 2026

Copy link
Copy Markdown
Author

@zcbenz I have made the changes could you review and merge.

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like another review from maintainers before merging.

@Ved235

Ved235 commented May 28, 2026

Copy link
Copy Markdown
Author

I would like another review from maintainers before merging.

Ok sure, is there someone specific I should tag to request a review?

@zcbenz zcbenz requested review from angeloskath and jagrit06 May 28, 2026 22:49
@Ved235

Ved235 commented Jun 6, 2026

Copy link
Copy Markdown
Author

@angeloskath @jagrit06 it would be great if you could review these changes

@Ved235

Ved235 commented Jun 12, 2026

Copy link
Copy Markdown
Author

@zcbenz Its been several weeks since the changes and I believe that the changes are not very large in terms of number of lines, so would it be possible for this to be merged?

@zcbenz

zcbenz commented Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Sorry WWDC had largely disrupted our schedule, there is no wrong with this PR it is just I need another view on this since I lack the background knowledge. WWDC is over now and there is a large backlog so please give us more time.

@Ved235

Ved235 commented Jun 12, 2026

Copy link
Copy Markdown
Author

No worries and thanks for the update.

@angeloskath angeloskath left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Ved235 sorry for the super late reply, especially since it will be a negative response.

This is something we need to fix indeed but unfortunately this is not the way to fix it. Basically the ops should not really change based on the shape but the implementation should. The same way that we route to split-k kernel when the matrix K dimension is large.

I am gonna mark this as requested changes so we don't merge it by accident and make a new PR with a specialized kernel for this particular case. I am not sure how important it is but I think the kernel will end up being simple enough.

@Ved235

Ved235 commented Jun 13, 2026

Copy link
Copy Markdown
Author

Should I make a PR for this specialised kernel?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants