From 1028827017e5bbce67eb4f03ecb16e4feccd2208 Mon Sep 17 00:00:00 2001 From: boffee Date: Wed, 10 Jun 2026 20:41:04 -0500 Subject: [PATCH 1/2] Cast LTX2 scale_shift tables to activation dtype at use site Published LTX-2 checkpoints store the AdaLN scale_shift tables in fp32 alongside bf16 weights. The original implementation casts the tables to the activation dtype at every use site, but the diffusers port casts device only, so loading a checkpoint with its native dtypes promotes the modulated hidden states to fp32 and the following linear layers raise "mat1 and mat2 must have the same dtype". Restore the use-site dtype cast; outputs are bit-identical to a model whose tables were flattened to the weight dtype at load time. Co-Authored-By: Claude Fable 5 --- .../models/transformers/transformer_ltx2.py | 15 ++++++++--- .../test_models_transformer_ltx2.py | 25 +++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 465408d94693..0ba202f92741 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -586,7 +586,11 @@ def get_mod_params( scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int ) -> tuple[torch.Tensor, ...]: num_ada_params = scale_shift_table.shape[0] - ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape( + # Cast to temb's dtype at the use site (matching the original implementation): + # checkpoints store the scale_shift tables in fp32 alongside bf16 weights, so + # without the cast the fp32 tables promote the modulated hidden states and the + # following linear layers fail on mixed dtypes. + ada_values = scale_shift_table[None, None].to(device=temb.device, dtype=temb.dtype) + temb.reshape( batch_size, temb.shape[1], num_ada_params, -1 ) ada_params = ada_values.unbind(dim=2) @@ -1620,14 +1624,19 @@ def forward( ) # 6. Output layers (including unpatchification) - scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + scale_shift_values = ( + self.scale_shift_table[None, None].to(embedded_timestep.dtype) + embedded_timestep[:, :, None] + ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] hidden_states = self.norm_out(hidden_states) hidden_states = hidden_states * (1 + scale) + shift output = self.proj_out(hidden_states) - audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_scale_shift_values = ( + self.audio_scale_shift_table[None, None].to(audio_embedded_timestep.dtype) + + audio_embedded_timestep[:, :, None] + ) audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] audio_hidden_states = self.audio_norm_out(audio_hidden_states) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index e0e858bb6916..42202303e4f3 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -115,6 +115,31 @@ def get_dummy_inputs(self) -> dict[str, torch.Tensor]: class TestLTX2Transformer(LTX2TransformerTesterConfig, ModelTesterMixin): """Core model tests for LTX2 Video Transformer.""" + def test_fp32_scale_shift_tables_match_uniform_dtype(self): + # Published LTX-2 checkpoints store the AdaLN scale_shift tables in fp32 + # alongside bf16 weights. The tables are cast to the activation dtype at + # the use site (as in the original implementation), so a natively loaded + # mixed-dtype model must run and produce the same outputs as one whose + # tables were flattened to the weight dtype at load time. + torch.manual_seed(0) + model = self.model_class(**self.get_init_dict()).to(torch.bfloat16).to(torch_device).eval() + inputs = { + key: value.to(torch.bfloat16) if isinstance(value, torch.Tensor) and value.is_floating_point() else value + for key, value in self.get_dummy_inputs().items() + } + + with torch.no_grad(): + reference = model(**inputs) + + for name, param in model.named_parameters(): + if "scale_shift_table" in name: + param.data = param.data.float() + with torch.no_grad(): + mixed = model(**inputs) + + assert torch.equal(reference[0], mixed[0]) + assert torch.equal(reference[1], mixed[1]) + class TestLTX2TransformerMemory(LTX2TransformerTesterConfig, MemoryTesterMixin): """Memory optimization tests for LTX2 Video Transformer.""" From fd2d329e208d37b5ac7adc2710557715d8f94ecf Mon Sep 17 00:00:00 2001 From: boffee Date: Fri, 12 Jun 2026 05:17:19 -0500 Subject: [PATCH 2/2] Fix LTX2 connector register layout to match the original LTX implementation The connector replaced left-padding positions with the tiled registers and then flipped the whole sequence, which put the prompt tokens at the front in reversed order and the register tile reversed within each block. The original LTX implementation (ltx-core _replace_padded_with_learnable_registers, also matched by ComfyUI) front-aligns the valid tokens in their original order and fills the tail with registers indexed by absolute position. Since the connector blocks apply RoPE, the reversed layout produces off-distribution embeddings; short prompts (e.g. negative prompts, whose context is mostly registers) are hit hardest, which manifests as overblown CFG: at cfg > 1 (or CFG++ samplers at cfg 1) the unconditional branch is computed from a mostly-register context with scrambled positions. Replace the fill+flip with a stable-argsort gather (valid tokens to the front, order preserved, per batch row) and fill the tail with the absolute-position register tile. Co-Authored-By: Claude Fable 5 --- src/diffusers/pipelines/ltx2/connectors.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index 8a00a0c6b452..cedfb219c937 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -302,13 +302,19 @@ def forward( if binary_attn_mask.ndim == 4: binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] - # Replace padding positions with learned registers using vectorized masking - mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1] + # Move the valid tokens to the front in their original order and fill the tail + # with registers indexed by absolute position, matching the original LTX + # implementation (`_replace_padded_with_learnable_registers`). A stable argsort + # of the inverted mask gathers valid tokens first while preserving their order. + order = torch.argsort(1 - binary_attn_mask, dim=1, stable=True) # [B, L] + front_aligned = torch.gather( + hidden_states, 1, order.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1]) + ) + num_valid = binary_attn_mask.sum(dim=1, keepdim=True) # [B, 1] + positions = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0) # [1, L] + front_mask = (positions < num_valid).unsqueeze(-1) # [B, L, 1] registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D] - hidden_states = mask * hidden_states + (1 - mask) * registers_expanded - - # Flip sequence: embeddings move to front, registers to back (from left padding layout) - hidden_states = torch.flip(hidden_states, dims=[1]) + hidden_states = torch.where(front_mask, front_aligned, registers_expanded.to(hidden_states.dtype)) # Overwrite attention_mask with an all-zeros mask if using registers. attention_mask = torch.zeros_like(attention_mask)