Skip to content

Fix use-after-free when a custom Metal kernel is called with different dtypes in one graph#3662

Open
discobot wants to merge 1 commit into
ml-explore:mainfrom
discobot:fix/3347-custom-kernel-dtype-cache
Open

Fix use-after-free when a custom Metal kernel is called with different dtypes in one graph#3662
discobot wants to merge 1 commit into
ml-explore:mainfrom
discobot:fix/3347-custom-kernel-dtype-cache

Conversation

@discobot

Copy link
Copy Markdown

Proposed changes

Fixes #3347.

Calling the same mx.fast.metal_kernel with different input dtypes inside one lazy graph regenerates different source under the same kernel name, and the cache eviction (clear_library) releases a pipeline state that an uncommitted command buffer still references (mechanism detailed in #3347). This implements the approach @zcbenz suggested on #3434: put the input dtypes in the generated kernel name, so a given name always maps to exactly one source and dtype variants coexist in the cache instead of evicting each other.

Dtypes turn out not to be the only per-call fact write_signature bakes into the source: the address space choice (constant for inputs with size() < 8) and pass-by-reference for 0-dim inputs vary per call too. The generated name now encodes all of it — each input's dtype with a c/s suffix for those two cases, plus the output dtypes — and max_constant_array_size is hoisted to file scope so the naming and write_signature can't drift. The clear_library path is kept for genuinely different user sources reusing one name (covered by the existing test_custom_kernel_caching).

Added a regression test (test_custom_kernel_mixed_dtypes) that composes float16 and float32 invocations of one kernel in a single graph; with METAL_DEVICE_WRAPPER_TYPE=1 as CI sets, the unfixed code aborts on it with the command buffer references deallocated object assertion. Also updated the generated-name example in the custom Metal kernels docs (names now look like custom_kernel_myexp_float_float16_float16).

With the fix, test_fast.py (24 tests) and test_export_import.py (16 tests) pass under Metal API validation.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

…t dtypes in one graph

The source generated for mx.fast.metal_kernel embeds the input/output
dtypes and how each input is passed, but the kernel name only encoded the
template arguments. Calling the same kernel with a different input dtype in
one lazy graph regenerated different source under the same name, and the
custom kernel cache then cleared the old library while its pipeline state
was still referenced by the uncommitted command buffer, causing a
use-after-free (an abort under Metal API validation, sporadic NaNs
otherwise). Append the input/output dtypes and the input address space and
reference markers to the generated kernel name so a given name always maps
to exactly one source. Adds a regression test composing float16 and float32
calls of one kernel in a single graph. Fixes ml-explore#3347.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] mx.fast.metal_kernel: use-after-free when multiple custom kernels compose in lazy graph

1 participant