Skip to content

Register SyncBatchNorm as quantization module#1491

Open
5had3z wants to merge 1 commit into
NVIDIA:mainfrom
5had3z:fix/register-syncbn
Open

Register SyncBatchNorm as quantization module#1491
5had3z wants to merge 1 commit into
NVIDIA:mainfrom
5had3z:fix/register-syncbn

Conversation

@5had3z
Copy link
Copy Markdown

@5had3z 5had3z commented May 14, 2026

What does this PR do?

Type of change: Bug fix

Registers nn.SyncBatchNorm layer for quantization. If a model is setup for distributed training before PTQ, none of the SyncBatchNorm layers are recognised and quantized. On loading of a checkpoint there is now a mismatch between the modelopt state of a model that hasn't had DDP/SyncBN applied to it and the checkpoint trained with DDP/SyncBN.

Performing PTQ and then applying DDP/SyncBN for QAT works fine, but considering that unwrapping DDP is handled properly for either ordering of the steps, SyncBN conversion should be able to be performed in either order as well.

Usage

## train.py
model = get_model()
# DDP Setup
nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(
    model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
)
# PTQ
mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib)
mtq.print_quant_summary(model) # Missing SyncBN layers
# QAT
train(model)
# Save Checkpoint
torch.save(mto.modelopt_state(model), "modelopt.pt")
torch.save(model.module.state_dict(), "params.pt")

## inference.py
model = get_model()
# Below fails as nn.BatchNorm2d in current model state does not have state in checkpoint since 
# nn.SyncBatchNorm modules were skipped over.
model = restore_model_from_modelopt_state(model, torch.load("modelopt.pt", weights_only=False))
mode.load_state_dict(torch.load("params.pt")

Testing

Added nn.SyncBatchNorm to the quantization tests where other BatchNorm layers appear..

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ❌ - Models trained with missing norm from their modelopt state dict will now have this depending on initialization order.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ❌ - lmk if you want this
  • Did you get Claude approval on this PR?: ❌

Additional Information

Code for testing issue, run with python3 script.py or torchrun --nproc-per-node=2 script.py.

from pathlib import Path
import torch
import os
from torch import nn
import torch.distributed as dist
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.datasets.cifar import CIFAR10
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.transforms.v2 import ToTensor
import modelopt.torch.quantization as mtq
import modelopt.torch.opt as mto
from rich.progress import track

assert torch.cuda.is_available(), "NVIDIA GPUs required for distributed training"

torch.cuda.set_device(f"cuda:{os.environ.get('LOCAL_RANK', 0)}")
if dist.is_available() and int(os.environ.get("WORLD_SIZE", 1)) > 1:
    dist.init_process_group(backend="nccl")

model = resnet18(weights=ResNet18_Weights.DEFAULT).cuda()

if dist.is_initialized():  # SyncBN and DDP for training
    nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[dist.get_rank()], output_device=dist.get_rank()
    )


def calib(m: nn.Module):
    datapath = Path.cwd() / "data"
    datapath.mkdir(exist_ok=True)
    dataset = CIFAR10(datapath, train=False, download=True, transform=ToTensor())
    if dist.is_initialized():
        sampler = DistributedSampler(dataset)
    else:
        sampler = None
    dataloader = DataLoader(dataset, sampler=sampler, num_workers=4, batch_size=8)
    for img, tgt in track(dataloader):
        m(img.cuda())


# PTQ
mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib)
if not dist.is_initialized() or dist.get_rank() == 0:
    mtq.print_quant_summary(model)

# Do training

mto_ckpt = Path.cwd() / "opt.pt"
torch.save(mto.modelopt_state(model), mto_ckpt)
param_ckpt = Path.cwd() / "params.pt"
if isinstance(model, nn.parallel.DistributedDataParallel):
    params = model.module.state_dict()
else:
    params = model.state_dict()
torch.save(params, param_ckpt)

# Load 'single' GPU for inference
model = resnet18(weights=ResNet18_Weights.DEFAULT).cuda()
model = mto.restore_from_modelopt_state(
    model, torch.load(mto_ckpt, map_location="cuda", weights_only=False)
)
model.load_state_dict(torch.load(param_ckpt))

Summary by CodeRabbit

  • New Features

    • Added support for synchronized batch normalization layers in quantization workflows.
  • Tests

    • Extended quantization test coverage for synchronized batch normalization module types.
  • Chores

    • Updated quantization configurations to handle synchronized batch normalization layers consistently with other batch norm types.

Review Change Stack

Signed-off-by: Bryce Ferenczi <bryce.ferenczi@Arkeus.com>
@5had3z 5had3z requested review from a team as code owners May 14, 2026 04:48
@5had3z 5had3z requested a review from meenchen May 14, 2026 04:48
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 14, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 208dcebb-65ec-4555-bebc-b8880ec9d5bb

📥 Commits

Reviewing files that changed from the base of the PR and between 229ba61 and f702d84.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/nn/modules/quant_batchnorm.py
  • modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml
  • modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml
  • tests/unit/torch/quantization/test_quant_batchnorm.py

📝 Walkthrough

Walkthrough

This PR adds nn.SyncBatchNorm support to the quantization framework by registering the module type, disabling its quantization by default, applying that configuration to a specific model, and extending test coverage to verify the behavior.

Changes

nn.SyncBatchNorm Quantization Support

Layer / File(s) Summary
Core module registration and default disabling
modelopt/torch/quantization/nn/modules/quant_batchnorm.py, modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml, tests/unit/torch/quantization/test_quant_batchnorm.py
nn.SyncBatchNorm is registered in QuantModuleRegistry, added to the default disabled quantizers configuration, and included in three parametrized test cases (test_no_quant, test_fake_quant_per_tensor, test_fake_quant_per_channel).
Model-specific quantization configuration
modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml
The disable rule for nn.SyncBatchNorm is applied to the Step3.5-Flash model configuration alongside the existing disabled rules for other BatchNorm variants.

🎯 1 (Trivial) | ⏱️ ~3 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and accurately describes the main change: registering SyncBatchNorm for quantization support, which is the core objective across all modified files.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns detected. Changes are minimal registration and configuration updates with no unsafe deserialization, hardcoded credentials, dangerous APIs, or dependency issues.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Comment @coderabbitai help to get the list of available commands and usage tips.

@meenchen
Copy link
Copy Markdown
Contributor

@5had3z thanks for the PR. Could you share some background on the use cases for BatchNorm quantization?

@meenchen meenchen requested review from cjluo-nv and realAsma May 18, 2026 23:50
@5had3z
Copy link
Copy Markdown
Author

5had3z commented May 19, 2026

@meenchen SyncBatchNorm is disabled by default, just like the other batch norms. The main purpose of this change is to resolve the problem is illustrated in the example code where if nn.SyncBatchNorm.convert_sync_batchnorm is applied to the model for DDP training before the quantization conversion, none of the normalization layers are registered. When resuming for evaluation on a single GPU, restore_from_modelopt_state will register the conversion of the normalization layers (still disabled by default). Then when it tries to copy the saved modelopt state dict to the model, there is a mismatch in the current state dict (containing norm layers) and the saved state dict (missing norm layers) and an error is raised in torch.quantization.restore_quantizer_state where extra_keys is not empty.

EDIT: Sorry, I thought I had comments of where the problems occur in the example code. Maybe I tweaked something in the test script, re-copied it over and was missing the comments.

EDIT2: Ahh no yes I have comments in the first block, just not the second that is the full end-to-end test code.

Copy link
Copy Markdown
Collaborator

@cjluo-nv cjluo-nv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot review — DM the bot to share feedback.

Small, well-scoped bug fix that registers nn.SyncBatchNorm analogously to the existing BatchNorm1d/2d/3d registrations. The two YAML configs that enumerate BatchNorm variants are both updated consistently, and the existing parameterized batchnorm tests are extended to cover SyncBatchNorm across test_no_quant, test_fake_quant_per_tensor, and test_fake_quant_per_channel. Explicit registration is necessary since QuantModuleRegistry uses exact class matching. The PR body documents a clear repro and notes the back-compat implication for checkpoints saved before this fix.

Copy link
Copy Markdown
Contributor

@meenchen meenchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correctness — Incomplete Coverage of Exclusion Lists

[BLOCKER] SyncBatchNorm not added to _default_disabled_quantizer_cfg in config.py
The hardcoded list at modelopt/torch/quantization/config.py:211-213 excludes nn.BatchNorm{1,2,3}d but not nn.SyncBatchNorm. This list is used by FP8_DEFAULT_CFG, INT8_DEFAULT_CFG, and other built-in *_CFG dicts. After this PR, any user using these configs with a SyncBN-containing model will silently have SyncBN layers quantized — changing behavior for non-DDP users too.

[BLOCKER] 5 general PTQ recipes still have inline BatchNorm exclusions without SyncBatchNorm
The PR updates the $import unit (default_disabled_quantizers.yaml) and one model-specific recipe (Step3.5-Flash). But these recipes have inline exclusion lists that don't yet use the unit:

  • modelopt_recipes/general/ptq/fp8_default-fp8_kv.yml
  • modelopt_recipes/general/ptq/nvfp4_default-fp8_kv.yml
  • modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yml
  • modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yml
  • modelopt_recipes/general/ptq/nvfp4_omlp_only-fp8_kv.yml

Each needs a - parent_class: 'nn.SyncBatchNorm' entry. Without these updates, SyncBN layers will be quantized when these recipes are used — inconsistent behavior depending on recipe choice.

@5had3z
Copy link
Copy Markdown
Author

5had3z commented May 19, 2026

[BLOCKER] 5 general PTQ recipes still have inline BatchNorm exclusions without SyncBatchNorm

All of these configs should have imported the changes I have already made in modelopt_recipes/configs/ptq/units/default_disabled_quantizers.yaml. I'm not sure what else needs to change.

imports:
    default_disabled_quantizers: configs/ptq/units/default_disabled_quantizers

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants