Skip to content

Add reflect and symmetric padding modes to mx.pad#3608

Open
katlun-lgtm wants to merge 3 commits into
ml-explore:mainfrom
katlun-lgtm:add-reflect-symmetric-pad
Open

Add reflect and symmetric padding modes to mx.pad#3608
katlun-lgtm wants to merge 3 commits into
ml-explore:mainfrom
katlun-lgtm:add-reflect-symmetric-pad

Conversation

@katlun-lgtm

Copy link
Copy Markdown

Summary

Adds numpy.pad-compatible "reflect" and "symmetric" modes to mx.pad. These join the existing "constant" and "edge" modes.

Both modes match numpy.pad semantics for arbitrary pad sizes — when the pad width exceeds the axis length the reflection repeats, exactly as NumPy does. (Earlier attempts at these modes were limited to pad < dim; this implementation removes that restriction.)

  • reflect — mirror padding that does not repeat the edge value (period 2(n-1)).
  • symmetric — mirror padding that does repeat the edge value (period 2n).

Implementation

reflect_pad (in mlx/ops.cpp) builds a per-axis index map with a triangle-wave reflection function and gathers with take — one take per padded axis. A degenerate axis of length 1 maps every coordinate to 0. No new primitive or kernel is introduced; it composes existing ops, so it works on every backend and is differentiable for free.

Files changed

  • mlx/ops.cppreflect_pad helper + reflect / symmetric dispatch branches in pad.
  • python/src/ops.cpp — extend the mode Literal and docstring.
  • python/tests/test_ops.pytest_pad_reflect_symmetric.
  • tests/ops_tests.cpp — reflect/symmetric CHECK cases incl. multi-reflect.

Testing

All run locally on an M3 Max:

  • Python test_pad_reflect_symmetric — 13 shape/pad-width cases × 2 modes compared element-for-element against numpy.pad (in-bounds, multi-reflect where pad ≫ axis, asymmetric per-axis, zero-width sides, and degenerate axes n==1, n==2). Exact match.
  • C++ tests/ops_tests.cpp "test pad" — 9 assertions pass.
  • Full C++ suite251 cases / 251 passed, 3442 assertions / 0 failed.
>>> import mlx.core as mx
>>> a = mx.array([1, 2, 3])
>>> mx.pad(a, 2, mode="reflect")
array([3, 2, 1, 2, 3, 2, 1], dtype=int32)
>>> mx.pad(a, 2, mode="symmetric")
array([2, 1, 1, 2, 3, 3, 2], dtype=int32)

@katlun-lgtm katlun-lgtm marked this pull request as ready for review May 30, 2026 20:22
@zcbenz

zcbenz commented Jun 1, 2026

Copy link
Copy Markdown
Collaborator

Thanks for the PR, there was actually #2862 which is still waiting for a thorough review to merge. I think your PR is more elegant and I'm more in favor of this one, just in case do you think we are missing anything from #2862?

katlun-lgtm and others added 2 commits June 13, 2026 09:10
Implements numpy.pad-compatible "reflect" and "symmetric" modes for
mx.pad, matching numpy semantics for arbitrary pad sizes (the reflection
repeats when the pad width exceeds the axis length).

- mlx/ops.cpp: reflect_pad helper builds a per-axis triangle-wave index
  map and gathers with take; one take per padded axis. reflect uses
  period 2(n-1) and skips the edge; symmetric uses period 2n and repeats
  the edge. n==1 maps to 0.
- python/src/ops.cpp: extend the pad mode Literal and docstring.
- python/tests/test_ops.py: test_pad_reflect_symmetric covers in-bounds,
  multi-reflect, asymmetric per-axis, zero-width sides, and degenerate
  axes (n==1, n==2), checked against numpy.pad.
- tests/ops_tests.cpp: reflect/symmetric CHECK cases incl. multi-reflect.
@zcbenz zcbenz force-pushed the add-reflect-symmetric-pad branch from 63a3e46 to 579bfd3 Compare June 13, 2026 00:11
@katlun-lgtm

Copy link
Copy Markdown
Author

Thanks @zcbenz! I went through #2862 carefully — functionally the two cover the same ground (reflect + symmetric, matching numpy.pad including the edge-excluded vs edge-included distinction), so nothing user-facing is missing.

Differences I found:

  • Implementation: [Feature] Additional padding modes #2862 builds the padding with a slice/concatenate/tile/slice_update chain (a separate function per mode). This PR instead computes a single triangle-wave index map per padded axis and does one take (gather), with both modes sharing one helper via an include_edge flag — fewer lines and a lighter op graph.
  • Multi-reflect (pad wider than the axis): handled in both; here it falls out of the modular index map rather than needing explicit tiling/reps logic.
  • Tests: the suite here is a superset — it adds degenerate axes (n == 1, n == 2), multi-reflect on both sides simultaneously, asymmetric per-axis widths, and 3D cases, all checked elementwise against numpy.pad.
  • The one thing [Feature] Additional padding modes #2862 had that this PR didn't — an ACKNOWLEDGMENTS.md entry — I've just added.

Happy to adjust naming or docs to whatever you prefer.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants