[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit #2974

Open

vthumbe1503 wants to merge 4 commits into NVIDIA:main from vthumbe1503:fsdp2_dcp_laod_fix

Conversation

@vthumbe1503
Collaborator

@vthumbe1503 vthumbe1503 commented May 11, 2026

Description

Fixes DCP Sync and Async checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all quantization recipes.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • DCP Sync Checkpoint loading

    • untyped_storage is now defined on the base QuantizedTensor to return an empty storage. untyped_storage refers to the backing storage underlying a tensor; since we create TE QuantizedTensors with make_wrapper_subclass, the wrapper has no backing storage of its own (the actual bytes live in the internal tensors), and data_ptr on our custom QuantizedTensor also returns 0.
    • The main issue is that FSDP2 maintains a sharded param tensor for checkpointing, which it builds by calling view(-1) on our quantized sharded model parameters; TE returns a dequantized 1D tensor for that view. So the sharded tensor FSDP2 keeps for checkpointing is BF16, while the quantized sharded param is our custom FP8 tensor. FSDP2 evaluates untyped_storage(BF16 sharded tensor reloaded from disk) == untyped_storage(quantized sharded parameter) to decide whether they are the same tensor. With untyped_storage now returning an empty storage, this comparison can never match the sharded tensor's storage.
  • DCP Async Checkpointing

    • Async checkpointing uses the to_new_empty function with device="cpu". This function previously returned Quantizer.make_empty without setting the device. For device="cpu" we now dequantize, so async checkpointing saves the BF16 data directly to disk and reloading works correctly (a minimal sketch follows below).
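
A minimal sketch of the behavior described in the last bullet above; the helper name to_new_empty and the quantizer interface are simplified assumptions based on this description, not the exact TE signatures.

import torch

# Sketch only: simplified stand-in for the async-checkpoint staging path described above.
def to_new_empty(qtensor, size, device):
    if torch.device(device).type == "cpu":
        # DCP async staging: hand back a plain high-precision CPU tensor so that
        # torch.save writes BF16 bytes which reload cleanly from disk.
        return torch.empty(size, dtype=torch.bfloat16, device="cpu")
    # Otherwise keep returning an empty quantized tensor from the tensor's quantizer,
    # passing the device through explicitly.
    return qtensor._quantizer.make_empty(size, device=device)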

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@greptile-apps
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR fixes DCP sync and async checkpoint loading for MXFP8/NVFP4 and other quantization recipes under FSDP2 with QuantizedModelInit. The core changes make QuantizedTensor subclasses safely round-trip through torch.save/torch.load(weights_only=True) by registering them with add_safe_globals, moving __reduce_ex__ helpers to module-level functions (avoiding pybind11 classmethod reductions), adding a Quantizer.__getstate__/__setstate__ pair that serializes the TE_DType enum as a plain int, and intercepting aten._to_copy.default to move internal FP8 buffers to CPU without dequantizing.

  • DCP sync fix: QuantizedTensor.untyped_storage() now returns a zero-byte storage, preventing FSDP2's same-tensor identity check from incorrectly matching the BF16 sharded checkpoint tensor against the FP8 parameter's empty storage.
  • DCP async fix: A new _to_copy.default dispatch handler preserves the QuantizedTensor subclass when FSDP2 stages buffers to CPU, and all __reduce_ex__ helpers now serialize fp8_dtype/fp4_dtype as int to stay free of pybind11 enum reductions that weights_only=True rejects (a rough sketch follows after this list).
  • Test cleanup: Numerous xfail markers for MXFP8/NVFP4/Float8BlockScaling are removed now that the underlying bugs are resolved.
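
A rough sketch of the aten._to_copy.default handling mentioned above; the real code lives in QuantizedTensor.__torch_dispatch__, and get_metadata()/make_like() are used here with assumed signatures.

import torch

def _to_copy_keep_subclass(qtensor, device):
    # Move every internal byte buffer (rowwise/columnwise data, scales) to the
    # target device and rebuild the same quantized subclass around them, instead
    # of falling back to a dequantized plain tensor.
    moved = {
        key: val.to(device=device) if isinstance(val, torch.Tensor) else val
        for key, val in qtensor.get_metadata().items()
    }
    return type(qtensor).make_like(qtensor, device=device, **moved)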

Confidence Score: 5/5

Safe to merge; the DCP checkpointing paths are well-scoped and the two edge cases noted are unlikely to be exercised in the FSDP2 checkpoint flow.

The changes are tightly targeted at the DCP checkpointing flow. The add_safe_globals list is specific (no builtins.getattr), the dtype-as-int serialization is internally consistent across __getstate__/__setstate__/__reduce_ex__, and the _to_copy handler correctly moves FP8 buffers to CPU for async staging. The two concerns — dtype-only casts now returning a QuantizedTensor instead of dequantizing, and transpose-only Float8Tensor pickling data=None — are both outside the FSDP2 checkpoint hot path and are low-risk for this PR's stated scope.

Files behind those two concerns: transformer_engine/pytorch/quantized_tensor.py (new _to_copy handler intercepts all _to_copy.default calls including dtype-only) and transformer_engine/pytorch/tensor/float8_tensor.py (removed CPU dequantization fallback in __reduce_ex__).

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Adds untyped_storage() returning empty storage, new _to_copy.default handler to preserve QuantizedTensor on device moves, Quantizer.__getstate__/__setstate__ for int-serialized dtype, and updates cpu() and make_like to pass device explicitly; the _to_copy handler now intercepts dtype-only casts that previously dequantized.
transformer_engine/pytorch/__init__.py Registers all QuantizedTensor subclasses, storage mixins, quantizer types, and module-level reconstructor functions with torch.serialization.add_safe_globals to allow weights_only=True DCP async-staging round-trips; imports for storage classes were already present so no NameError risk.
transformer_engine/pytorch/tensor/float8_tensor.py Moves _make_in_reduce_ex classmethod to module-level _make_float8_tensor_in_reduce_ex with fp8_dtype as int, removes the CPU-dequantization fallback in __reduce_ex__, and updates Float8CurrentScalingQuantizer.__getstate__ to call super; removing the fallback means a transpose-only CPU tensor would pickle data=None.
transformer_engine/pytorch/tensor/mxfp8_tensor.py Moves _make_in_reduce_ex classmethod to module-level function and serializes fp8_dtype as int in __reduce_ex__; straightforward refactor with no logic changes to data paths.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Moves _make_in_reduce_ex classmethod to module-level and serializes fp4_dtype as int; NVFP4Quantizer.__getstate__ now calls super().__getstate__() to leverage base-class dtype serialization before clearing amax_reduction_group.
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Removes the per-class untyped_storage() override (now inherited from QuantizedTensor) and migrates _make_in_reduce_ex to module-level with int-encoded fp8_dtype; clean change with no logic regressions.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Removes device key from get_metadata() since callers (make_like, _to_copy handler, _IdentityFunc) now pass device explicitly, eliminating potential duplicate-kwarg conflicts.
transformer_engine/pytorch/module/base.py Adds NVFP4Quantizer to the isinstance check that sets amax_reduction_group from the DTensor device mesh, enabling correct FSDP2 all-gather for NVFP4 sharded parameters.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Removes multiple xfail markers for MXFP8/NVFP4/Float8BlockScaling now that underlying bugs are fixed; Float8BlockScaling SM120-specific xfail is correctly collapsed to cover both sync and async.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py Removes xfail guards for Float8BlockScaling+fp8_init and NVFP4+fp8_init+TransformerLayer; correctness of numerical tolerances for NVFP4 was flagged in a prior review thread.

Sequence Diagram

sequenceDiagram
    participant FSDP2
    participant QT as QuantizedTensor
    participant Dispatch as __torch_dispatch__
    participant Storage as Internal Buffers

    Note over FSDP2,Storage: DCP Async Save
    FSDP2->>QT: "tensor.to(device=cpu)"
    QT->>Dispatch: aten._to_copy.default
    Dispatch->>Storage: move all internal buffers to CPU
    Dispatch-->>FSDP2: CPU QuantizedTensor (preserved type)
    FSDP2->>QT: torch.save(cpu_tensor, ...)
    QT->>QT: "__reduce_ex__ - _make_*_in_reduce_ex (fp8_dtype as int)"
    FSDP2->>QT: "torch.load(..., weights_only=True)"
    Note right of QT: add_safe_globals enables reconstruction
    QT-->>FSDP2: QuantizedTensor restored

    Note over FSDP2,Storage: DCP Sync Load - same-tensor check
    FSDP2->>QT: param.untyped_storage()
    QT-->>FSDP2: UntypedStorage(0 bytes)
    FSDP2->>FSDP2: "storage(bf16_ckpt) == storage(fp8_param)?"
    Note right of FSDP2: Always False - no false same-tensor match
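
For reference, the torch.save/torch.load leg of this flow can be exercised with a plain tensor subclass as a stand-in; TE registers its own classes and module-level reconstructors rather than the MyTensor used here, and add_safe_globals requires a recent PyTorch (2.4+).

import torch
from torch.serialization import add_safe_globals

class MyTensor(torch.Tensor):
    """Stand-in for a TE QuantizedTensor subclass."""

add_safe_globals([MyTensor])  # what TE's __init__.py now does for its quantized types

x = torch.randn(4).as_subclass(MyTensor)
torch.save(x, "ckpt.pt")
y = torch.load("ckpt.pt", weights_only=True)  # succeeds because MyTensor is allow-listed
assert isinstance(y, MyTensor)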

Reviews (4): Last reviewed commit: "address review comment"

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Contributor


P1 Empty storage breaks shared-storage detection in existing callers

QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.
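
A small self-contained illustration of the concern above, using generic names rather than the actual TE call sites: a freshly allocated zero-byte storage always fails size-based checks on untyped_storage(), so an in-place fast path gated on such a check can never be taken.

import torch

empty_storage = torch.UntypedStorage(0)   # what QuantizedTensor.untyped_storage() now returns
assert empty_storage.nbytes() == 0

expected_nbytes = 4096                    # hypothetical size an in-place fast path expects
fast_path_ok = empty_storage.nbytes() == expected_nbytes
print(fast_path_ok)                       # False -> always falls back to an out-of-place path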

Collaborator


The correct behavior for these functions is to fall back to the slow path for QuantizedTensors, unless there is a dedicated implementation to handle quantized data.

Member


Yeah, while I don't think we ever use QuantizedTensors in SplitAlongDim, the concat path sounds plausible to hit.

Comment on lines +820 to +828
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)
Contributor


P2 Typo: "neec" should be "need" — appears in both NVFP4 tolerance blocks.

Suggested change
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)

Comment on lines +867 to +875
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)
Contributor


P2 Same typo ("neec") in the second NVFP4 tolerance block for the post-training-step check.

Suggested change
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)

Comment on lines 243 to 244
# NVFP4 scale unpad/repad through FSDP2 introduces small numerical
# differences vs the manual dequantize-then-allgather path.
Contributor


P2 Tolerance relaxed 250× for NVFP4 allgather verification

The tolerance for _check_fp8_fsdp2_allgather on NVFP4Tensor jumped from atol=5e-4, rtol=5e-3 to atol=0.125, rtol=0.25. This test compares param.dequantize() against fp32_allgathered_params[name], validating round-trip numerical fidelity of the all-gather path. A 25% relative tolerance makes the check nearly a no-op for FP4 values. A comment citing the 4-bit mantissa precision ceiling would justify the new values.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
Member


Why do we need dequant + quant here?

Comment on lines +613 to +616
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Collaborator

@timmoon10 timmoon10 May 11, 2026


This fix seems ad hoc to me. It's not obvious why qtensor.new_empty(..., device="cpu") returns a quantized tensor while qtensor.new_empty(..., device="cuda") returns a plain tensor. I wonder if it would be cleaner to just return a plain tensor in all cases. Thoughts:

  • It's uncomfortable how new_empty and empty_like would have different behavior. I suppose we could interpret empty_like as "make a tensor that matches the input" and new_empty as "call torch.empty with defaults taken from input", but that would be a private interpretation that no one else follows.
  • Would this affect FSDP or CPU offloading?
  • Given the weirdness, would it be worthwhile raising a warning if new_empty is called outside of DCP?
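
For comparison, on plain tensors the two interpretations coincide by default, which is what makes a subclass-specific divergence surprising (plain-torch illustration only, unrelated to the TE implementation):

import torch

x = torch.empty(4, dtype=torch.bfloat16, device="cpu")

a = torch.empty_like(x)               # "make a tensor that matches the input"
b = x.new_empty((4,), device="cpu")   # "torch.empty with defaults taken from input"
print(a.dtype, b.dtype)               # both bfloat16 -> same result for plain tensors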

# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
Collaborator


An empty size is valid and corresponds to a 0-dimensional tensor with 1 entry (for the same reason that 2^0 = 1).

>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
Suggested change
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
target_size = size


# differences vs the manual dequantize-then-allgather path.
if isinstance(param, NVFP4Tensor):
tols = dict(atol=5e-4, rtol=5e-3)
tols = dict(atol=0.125, rtol=0.25)
Member


Why are the tolerances so much bigger? Is it also due to the dequant+quant path? If so, the comment above is no longer relevant and should be replaced with a better one (but I would still like an explanation why we cannot just load the nvfp4 values from the checkpoint).

Comment on lines +613 to +616
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Member


Ok, I see now why you want to dequantize. I don't think this is needed though - we should be able to create the QuantizedTensor on the CPU and save it, no? I remember that CPU offloading of activations faced a similar problem and already had to support some CPU ops on QuantizedTensor anyway.

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 force-pushed the fsdp2_dcp_laod_fix branch from 3589ffa to 4197bee on May 13, 2026 04:00
@vthumbe1503 vthumbe1503 requested a review from ksivaman as a code owner May 13, 2026 04:00
pre-commit-ci Bot and others added 2 commits May 13, 2026 04:01
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

Comment thread transformer_engine/pytorch/__init__.py Outdated
# allow-listed for ``torch.load(weights_only=True)`` (used
# internally by DCP async-staging) to accept the stream.
_TE_DType,
getattr,
Contributor


P1 security Session-wide getattr whitelisted for weights_only=True loading

getattr is registered as a safe global at module import time. add_safe_globals is process-wide in PyTorch, so any torch.load(…, weights_only=True) call made anywhere in a session that has imported transformer_engine.pytorch — including checkpoint loads for entirely different models — now has getattr available to the pickle stream. A malicious checkpoint loaded elsewhere could use getattr to access sensitive attributes of any already-constructed object reachable from the whitelisted globals (e.g. getattr(Float8Quantizer_instance, 'amax_reduction_group') to obtain a process group, or to build callable gadget chains). The weights_only=True flag is specifically a defence against untrusted pickle payloads; adding a general-purpose reflective accessor defeats that defence.

A targeted alternative: serialize _fp8_dtype as its integer value (int(self._fp8_dtype)) and reconstruct it in _make_in_reduce_ex via TE_DType(int_value), then add TE_DType to safe globals instead of getattr. This preserves the weights_only invariant without whitelisting a reflective accessor.
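
A sketch of that targeted alternative, with a stand-in enum; the real TE_DType is a pybind11 enum and the reconstructor is the module-level helper referenced elsewhere in this PR.

import enum
from torch.serialization import add_safe_globals

class TE_DType(enum.IntEnum):   # stand-in; not transformer_engine's actual pybind11 enum
    kFloat8E4M3 = 0
    kFloat8E5M2 = 1

def _make_float8_tensor_in_reduce_ex(data, fp8_dtype_int, *extra):
    fp8_dtype = TE_DType(fp8_dtype_int)   # rebuild the enum from the pickled int
    return data, fp8_dtype, extra         # the real helper reconstructs a Float8Tensor here

# Allow-list only the specific enum and reconstructor, never a reflective accessor like getattr.
add_safe_globals([TE_DType, _make_float8_tensor_in_reduce_ex])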

Collaborator Author


Fixed

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>