diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 465408d94693..0ba202f92741 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -586,7 +586,11 @@ def get_mod_params( scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int ) -> tuple[torch.Tensor, ...]: num_ada_params = scale_shift_table.shape[0] - ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape( + # Cast to temb's dtype at the use site (matching the original implementation): + # checkpoints store the scale_shift tables in fp32 alongside bf16 weights, so + # without the cast the fp32 tables promote the modulated hidden states and the + # following linear layers fail on mixed dtypes. + ada_values = scale_shift_table[None, None].to(device=temb.device, dtype=temb.dtype) + temb.reshape( batch_size, temb.shape[1], num_ada_params, -1 ) ada_params = ada_values.unbind(dim=2) @@ -1620,14 +1624,19 @@ def forward( ) # 6. Output layers (including unpatchification) - scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + scale_shift_values = ( + self.scale_shift_table[None, None].to(embedded_timestep.dtype) + embedded_timestep[:, :, None] + ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] hidden_states = self.norm_out(hidden_states) hidden_states = hidden_states * (1 + scale) + shift output = self.proj_out(hidden_states) - audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_scale_shift_values = ( + self.audio_scale_shift_table[None, None].to(audio_embedded_timestep.dtype) + + audio_embedded_timestep[:, :, None] + ) audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] audio_hidden_states = self.audio_norm_out(audio_hidden_states) diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index 8a00a0c6b452..cedfb219c937 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -302,13 +302,19 @@ def forward( if binary_attn_mask.ndim == 4: binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] - # Replace padding positions with learned registers using vectorized masking - mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1] + # Move the valid tokens to the front in their original order and fill the tail + # with registers indexed by absolute position, matching the original LTX + # implementation (`_replace_padded_with_learnable_registers`). A stable argsort + # of the inverted mask gathers valid tokens first while preserving their order. + order = torch.argsort(1 - binary_attn_mask, dim=1, stable=True) # [B, L] + front_aligned = torch.gather( + hidden_states, 1, order.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + ) + num_valid = binary_attn_mask.sum(dim=1, keepdim=True) # [B, 1] + positions = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0) # [1, L] + front_mask = (positions < num_valid).unsqueeze(-1) # [B, L, 1] registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D] - hidden_states = mask * hidden_states + (1 - mask) * registers_expanded - - # Flip sequence: embeddings move to front, registers to back (from left padding layout) - hidden_states = torch.flip(hidden_states, dims=[1]) + hidden_states = torch.where(front_mask, front_aligned, registers_expanded.to(hidden_states.dtype)) # Overwrite attention_mask with an all-zeros mask if using registers. attention_mask = torch.zeros_like(attention_mask) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index e0e858bb6916..42202303e4f3 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -115,6 +115,31 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestLTX2Transformer(LTX2TransformerTesterConfig, ModelTesterMixin): """Core model tests for LTX2 Video Transformer.""" + def test_fp32_scale_shift_tables_match_uniform_dtype(self): + # Published LTX-2 checkpoints store the AdaLN scale_shift tables in fp32 + # alongside bf16 weights. The tables are cast to the activation dtype at + # the use site (as in the original implementation), so a natively loaded + # mixed-dtype model must run and produce the same outputs as one whose + # tables were flattened to the weight dtype at load time. + torch.manual_seed(0) + model = self.model_class(**self.get_init_dict()).to(torch.bfloat16).to(torch_device).eval() + inputs = { + key: value.to(torch.bfloat16) if isinstance(value, torch.Tensor) and value.is_floating_point() else value + for key, value in self.get_dummy_inputs().items() + } + + with torch.no_grad(): + reference = model(**inputs) + + for name, param in model.named_parameters(): + if "scale_shift_table" in name: + param.data = param.data.float() + with torch.no_grad(): + mixed = model(**inputs) + + assert torch.equal(reference[0], mixed[0]) + assert torch.equal(reference[1], mixed[1]) + class TestLTX2TransformerMemory(LTX2TransformerTesterConfig, MemoryTesterMixin): """Memory optimization tests for LTX2 Video Transformer."""