Skip to content

Optimize CPU masked_scatter for contiguous arrays#3670

Draft
AK-Khan02 wants to merge 1 commit into
ml-explore:mainfrom
AK-Khan02:ak/masked-scatter-cpu-perf
Draft

Optimize CPU masked_scatter for contiguous arrays#3670
AK-Khan02 wants to merge 1 commit into
ml-explore:mainfrom
AK-Khan02:ak/masked-scatter-cpu-perf

Conversation

@AK-Khan02

@AK-Khan02 AK-Khan02 commented Jun 13, 2026

Copy link
Copy Markdown

Summary

Closes #3669.

Adds a CPU fast path for masked_scatter when mask, src, and out are row-contiguous. The existing ContiguousIterator implementation remains the fallback for strided and broadcasted cases.

Why

The CPU implementation previously used the general iterator path for all cases. For contiguous arrays, direct pointer indexing avoids per-element iterator/stride bookkeeping while preserving the same source-consumption semantics.

Local benchmark

CPU-only local benchmark, float32:

  • 4M elements, 1% mask density: 5.746 ms -> 2.548 ms
  • 4M elements, 10% mask density: 8.320 ms -> 4.467 ms
  • 4M elements, 50% mask density: 23.816 ms -> 14.140 ms

Validation

PYTHONPATH=python/tests /tmp/mlx-mask-venv/bin/python -m unittest \
  python.tests.test_ops.TestOps.test_masked_scatter \
  python.tests.test_vmap.TestVmap.test_vmap_masked_scatter

@AK-Khan02

Copy link
Copy Markdown
Author

WIP: this currently adds a scoped contiguous CPU fast path for masked_scatter with measurable local speedups. I’m still exploring whether a chunked prefix-count approach or other approaches could safely optimize the larger/general path further, so feedback on the current fast-path shape is welcome.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Optimize CPU masked_scatter for contiguous inputs

1 participant