[JAX] [PyT] Tighten SWA checks in DPA, MHA and other APIs before passing onto cuDNN fused attn & unfused attn backends#2970
- Verify that the window is valid for the given mask type. If it is not, either force sentinel values or assert; when forcing sentinel values, warn the user.
- All ways of using attn (DPA, MHA, TL, fused attn APIs) now guarantee that the window size is not None and is set appropriately before it is passed downstream to internal APIs, primitives or classes.
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…tract responsibility can be handled by MHA and lower APIs Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…fused attn Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
/te-ci jax L0 L1
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
2c5a448 to c770934
for more information, see https://pre-commit.ci
Greptile Summary
This PR introduces a
Confidence Score: 5/5
Safe to merge: canonicalization is applied consistently at every public entry point and the cpp-extension boundary, double-calls are idempotent, and the PyTorch negative-right-value coerce bug is correctly fixed.
The new check_set_window_size branches cover all sentinel combinations correctly, including the previously problematic negative-right-value inputs that are now rejected rather than silently coerced. All module constructors apply the function before the module is frozen, and subsequent calls at lower layers see only already-canonical values, so no spurious warnings will fire in normal usage.
No files require special attention.
Important Files Changed
…cing Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…cing for PyTorch framework code Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L0 L1
        window_size = (-1, -1)
    # Coerce the right side window to -1.
    elif orig_window_size == (-1, 0):
        window_size = (-1, -1)
I wonder whether we should do this, or go the other direction and change the mask. Technically, this could be a valid combination, right? no_mask/padding + swa(left, 0) -> essentially causal + swa(left, 0)?
@cyanguwa I did think a bit about this especially since I had seen it in the PyT check_set_window_size() too.
Because there is a lot of downstream branching on the mask type in the primitives, and almost none on the SWA window size, I'd prefer not to coerce the mask and to coerce the SWA window size instead, with a warning. Coercing the mask can also make debugging difficult: we already change the masks internally for some of the CP patterns without the user being aware, so allowing the mask to be changed in multiple places, and often, just increases the chances of something going wrong.
Also, a smaller concern: if the mask is coerced, it can give the user an incorrect understanding of what is supported when they call is_fused_attn_kernel_available(). They may be asking about padding masks and get an answer for padding_causal instead. (The same can be argued for the SWA window, but I believe the mask has more ramifications in general.)
Could we make "no_mask/padding + swa(left,0)" officially supported then, as a first-class citizen, just like any other combination, including "causal/padding_causal/BRCM/PBRCM + swa(left,0)", which is what it's equivalent to anyway?
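To make the equivalence raised in this thread concrete, here is a minimal pure-Python sketch. It is not TE code: `swa_allowed` and `build_mask` are hypothetical helpers, the window is assumed to be a (left, right) pair with -1 meaning unbounded, and only the square case (equal query and key lengths, top-left aligned causal mask) is considered. Bottom-right aligned masks with unequal lengths are not covered.

```python
def swa_allowed(i, j, window):
    """Return True if query i may attend to key j under a (left, right) window.

    Assumption: -1 on either side means "unbounded" on that side.
    """
    left, right = window
    ok_left = left < 0 or j >= i - left
    ok_right = right < 0 or j <= i + right
    return ok_left and ok_right


def build_mask(seqlen, window, causal):
    """Build a seqlen x seqlen boolean mask (True = attention allowed)."""
    return [
        [swa_allowed(i, j, window) and (not causal or j <= i) for j in range(seqlen)]
        for i in range(seqlen)
    ]


seqlen, left = 6, 2
no_mask_swa = build_mask(seqlen, (left, 0), causal=False)
causal_swa = build_mask(seqlen, (left, 0), causal=True)

# A right window of 0 already forbids looking ahead, so adding the causal
# constraint changes nothing in the square case.
print(no_mask_swa == causal_swa)
```

Under these assumptions the two masks coincide, which is the sense in which "no_mask + swa(left, 0)" behaves like "causal + swa(left, 0)".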
Description
TE PyT attn uses check_set_window_size() to regulate the window_size based on the attn mask and the user-passed window_size. This is done higher in the stack, so that only a limited subset of "valid" window_size values propagates to the backends.
Via this PR, TE JAX attn mimics this behavior to clean up the SWA mechanism in the different TE JAX attn APIs through common updating logic. The new function check_set_window_size() does not constrain the user of the API, whether or not SWA is used; rather, it takes this complexity off the user. TE handles the checks and modifications to the window_size internally and warns the user about any canonicalization performed, when needed.
Type of change
Changes
Update APIs to canonicalize the SWA before passing to the backends
Update internal CP P2P Ring helpers to re-canonicalize when it changes the mask for internal computations
Update tests to validate the chosen SWA window before running the tests
nit: [PyT] Fixed a small QOL check in the right side window for causal cases so as to not actively coerce that value and instead assert it.
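The canonicalization idea in the changes above can be sketched roughly as follows. This is an illustration under stated assumptions, not the PR's check_set_window_size(): the function name `canonicalize_window`, the set `CAUSAL_MASKS`, and the exact rules for causal masks are hypothetical. Only the coercion of (-1, 0) to (-1, -1) for non-causal masks is taken from the diff quoted in the review thread.

```python
import warnings

# Assumption: mask names used here are illustrative labels for causal-style masks.
CAUSAL_MASKS = {
    "causal",
    "padding_causal",
    "causal_bottom_right",
    "padding_causal_bottom_right",
}


def canonicalize_window(mask_type, window_size=None):
    """Hypothetical sketch: return a non-None (left, right) window for a mask type."""
    is_causal = mask_type in CAUSAL_MASKS
    if window_size is None:
        # Assumed defaults: no look-ahead for causal masks, unbounded otherwise.
        return (-1, 0) if is_causal else (-1, -1)

    left, right = window_size
    if left < -1 or right < -1:
        raise ValueError(f"invalid window_size {window_size}")

    if is_causal:
        if right == -1:
            # Assumed rule: force the sentinel and tell the user.
            warnings.warn(f"window_size {window_size} changed to {(left, 0)} for {mask_type}")
            return (left, 0)
        if right != 0:
            # Assumed rule: reject instead of silently coercing.
            raise ValueError(f"right window must be 0 for {mask_type}, got {right}")
        return (left, 0)

    if (left, right) == (-1, 0):
        # From the quoted diff: coerce the right side window to -1.
        warnings.warn(f"window_size (-1, 0) changed to (-1, -1) for {mask_type}")
        return (-1, -1)
    return (left, right)
```

A property worth keeping in any such helper is idempotence: calling it again on an already canonical value returns the same value and emits no further warning, so re-canonicalizing at lower layers stays quiet.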
Testing
Ran local fused attn + dist fused attn tests on H100 x8 and GB200 x4, and they pass.
Next steps
A subsequent PR will fix any issues or missing ops in Unfused DPA SWA to make the SWA infrastructure and fallback paths more robust.
Checklist: