Skip to content

Add stft and istft to mlx.core.fft#3639

Open
eyupcanakman wants to merge 1 commit into
ml-explore:mainfrom
eyupcanakman:feat/fft-stft-istft
Open

Add stft and istft to mlx.core.fft#3639
eyupcanakman wants to merge 1 commit into
ml-explore:mainfrom
eyupcanakman:feat/fft-stft-istft

Conversation

@eyupcanakman

Copy link
Copy Markdown
Contributor

Proposed changes

Adds stft and istft to mlx.core.fft, closing #1004. Both are written in C++ so they are available from every language binding.

stft frames the last axis with as_strided, applies the window, and runs rfft (or fft for two-sided output). istft inverts that with an overlap-add normalized by the squared-window envelope. The signature follows torch.stft, except norm takes the same "backward"/"ortho"/"forward" string the rest of mlx.fft uses. Batched input works and the output has shape (..., n_freq, n_frames).

Tested against torch.stft and torch.istft over reflect, constant, and edge padding, one-sided and two-sided transforms, the ortho norm, batched input, and n_fft that is not a multiple of hop_length. The C++ tests cover the shapes, an overlap-add round-trip, and the error paths.

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Comment thread mlx/fft.cpp

namespace {

// Pad the last axis; mx::pad has no reflect mode, so build it from take.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would #3608 help?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, its reflect mode matches what this helper does and also handles pad >= n, so I can drop the helper and call pad directly once it lands. Want me to rebase on it after it merges, or keep the local version for now?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for verifying, let's just keep the local version for now.

@zcbenz zcbenz left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must admit I'm not familiar with short-time fft at all, but the implementation looks good, and the tests are well-written, so this PR looks good to me.

Comment thread mlx/fft.h
StreamOrDevice s = {});

/** Compute the inverse Short-Time Fourier Transform. */
MLX_API array istft(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move the function declarations and implementations to the end of the files, so the functions are listed in the same order?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants