diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 5425b7b575eb..39012327b61c 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -450,9 +450,9 @@ def _encode_video(self, x: torch.Tensor) -> torch.Tensor: matches Wan2pt2VAEInterface; no autocast (WanVAE was trained with is_amp=False).""" in_dtype = x.dtype dtype = self.vae.dtype - mean = self._vae_latents_mean.to(device=x.device, dtype=dtype) - inv_std = self._vae_latents_inv_std.to(device=x.device, dtype=dtype) raw_mu = retrieve_latents(self.vae.encode(x.to(dtype)), sample_mode="argmax") + mean = self._vae_latents_mean.to(device=raw_mu.device, dtype=dtype) + inv_std = self._vae_latents_inv_std.to(device=raw_mu.device, dtype=dtype) return ((raw_mu - mean.view(1, -1, 1, 1, 1)) * inv_std.view(1, -1, 1, 1, 1)).to(in_dtype) def decode_sound(self, latent: torch.Tensor) -> torch.Tensor: