From 9f099d42292bd1ecafcc0dfd626954be21952f30 Mon Sep 17 00:00:00 2001 From: Atharva Joshi Date: Tue, 9 Jun 2026 15:34:57 -0700 Subject: [PATCH] fix(cosmos3): pin VAE latent norm buffers to encode output device Under sharded placement (device_map="balanced"), vae.encode() runs on the VAE's own device while the mean/inv_std buffers were pinned to x.device, causing a cross-device RuntimeError. Compute raw_mu first, then pin the normalization buffers to its device so all tensors share one device. --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 5425b7b575eb..39012327b61c 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -450,9 +450,9 @@ def _encode_video(self, x: torch.Tensor) -> torch.Tensor: matches Wan2pt2VAEInterface; no autocast (WanVAE was trained with is_amp=False).""" in_dtype = x.dtype dtype = self.vae.dtype - mean = self._vae_latents_mean.to(device=x.device, dtype=dtype) - inv_std = self._vae_latents_inv_std.to(device=x.device, dtype=dtype) raw_mu = retrieve_latents(self.vae.encode(x.to(dtype)), sample_mode="argmax") + mean = self._vae_latents_mean.to(device=raw_mu.device, dtype=dtype) + inv_std = self._vae_latents_inv_std.to(device=raw_mu.device, dtype=dtype) return ((raw_mu - mean.view(1, -1, 1, 1, 1)) * inv_std.view(1, -1, 1, 1, 1)).to(in_dtype) def decode_sound(self, latent: torch.Tensor) -> torch.Tensor: