From 33c76774cc1ee693738d7efb861b7420f83fe972 Mon Sep 17 00:00:00 2001 From: HaozheZhang6 Date: Thu, 11 Jun 2026 17:26:34 +0000 Subject: [PATCH 1/2] Fix `Ideogram4MRoPE` collapsing under `torch.autocast` (compute rotary in float32) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ideogram4 builds image-token positions as IMAGE_POSITION_OFFSET (65536) + (t, h, w). `Ideogram4MRoPE.forward` casts its operands to float32, but the rotary matmul (and cos/sin) is on autocast's downcast list, so under torch.autocast("cuda", bfloat16) — common in training and pipeline code — it runs in bfloat16 anyway. bfloat16's step at 65536 is 512, so every image position in a <=512 grid rounds to the same value: all image tokens get identical rotary embeddings, spatial information is lost, and the decoded image degenerates to a flat color. Wrap the frequency computation in torch.autocast(enabled=False) so the rotary embeddings are always computed in float32, matching how transformers guards its RoPE modules. Added a regression test that fails on main and passes with the fix. Fixes #13920 --- .../transformers/transformer_ideogram4.py | 33 +++++++++++-------- .../test_models_transformer_ideogram4.py | 18 ++++++++++ 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ideogram4.py b/src/diffusers/models/transformers/transformer_ideogram4.py index 121118e3bd80..4c3761f8589d 100644 --- a/src/diffusers/models/transformers/transformer_ideogram4.py +++ b/src/diffusers/models/transformers/transformer_ideogram4.py @@ -70,20 +70,25 @@ def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso raise ValueError(f"`position_ids` must have shape (B, L, 3), got {tuple(position_ids.shape)}.") batch_size, seq_len, _ = position_ids.shape - pos = position_ids.permute(2, 0, 1).to(dtype=torch.float32) - inv_freq = self.inv_freq.to(dtype=torch.float32)[None, None, :, None].expand(3, batch_size, -1, 1) - freqs = inv_freq @ pos.unsqueeze(2) - freqs = freqs.transpose(2, 3) # (3, B, L, inv_freq_size) - - # Interleaved mrope: pull H freqs into idx 1 mod 3, W freqs into idx 2 mod 3. - freqs_t = freqs[0].clone() - for axis, offset in ((1, 1), (2, 2)): - length = self.mrope_section[axis] * 3 - idx = torch.arange(offset, length, 3, device=freqs_t.device) - freqs_t[..., idx] = freqs[axis][..., idx] - - emb = torch.cat((freqs_t, freqs_t), dim=-1) - return emb.cos(), emb.sin() + # Rotary frequencies must be computed in float32: Ideogram4's image positions start at + # IMAGE_POSITION_OFFSET (65536), so an ambient autocast would otherwise run the matmul and + # cos/sin in bfloat16, rounding every image position to the same value and collapsing the + # rotary embeddings (all spatial information is lost). + with torch.autocast(device_type=position_ids.device.type, enabled=False): + pos = position_ids.permute(2, 0, 1).to(dtype=torch.float32) + inv_freq = self.inv_freq.to(dtype=torch.float32)[None, None, :, None].expand(3, batch_size, -1, 1) + freqs = inv_freq @ pos.unsqueeze(2) + freqs = freqs.transpose(2, 3) # (3, B, L, inv_freq_size) + + # Interleaved mrope: pull H freqs into idx 1 mod 3, W freqs into idx 2 mod 3. + freqs_t = freqs[0].clone() + for axis, offset in ((1, 1), (2, 2)): + length = self.mrope_section[axis] * 3 + idx = torch.arange(offset, length, 3, device=freqs_t.device) + freqs_t[..., idx] = freqs[axis][..., idx] + + emb = torch.cat((freqs_t, freqs_t), dim=-1) + return emb.cos(), emb.sin() class Ideogram4AttnProcessor: diff --git a/tests/models/transformers/test_models_transformer_ideogram4.py b/tests/models/transformers/test_models_transformer_ideogram4.py index 31592ada64bc..d8e7318d501d 100644 --- a/tests/models/transformers/test_models_transformer_ideogram4.py +++ b/tests/models/transformers/test_models_transformer_ideogram4.py @@ -21,6 +21,7 @@ IMAGE_POSITION_OFFSET, LLM_TOKEN_INDICATOR, OUTPUT_IMAGE_INDICATOR, + Ideogram4MRoPE, ) from diffusers.utils.torch_utils import randn_tensor @@ -164,3 +165,20 @@ def test_gradient_checkpointing_is_applied(self): class TestIdeogram4TransformerAttention(Ideogram4TransformerTesterConfig, AttentionTesterMixin): """Attention processor tests for Ideogram 4 Transformer.""" + + +def test_ideogram4_mrope_is_autocast_invariant(): + # Ideogram4's image positions start at IMAGE_POSITION_OFFSET (65536), so the rotary matmul must + # run in float32: under an ambient autocast it would otherwise execute in bfloat16 and round every + # image position to the same value, collapsing all spatial information (the decoded image goes flat). + rope = Ideogram4MRoPE(head_dim=256, base=5_000_000, mrope_section=(24, 20, 20)).to(torch_device) + position_ids = torch.tensor([[[0, 0, 0], [0, 0, 1], [0, 63, 63]]], device=torch_device) + IMAGE_POSITION_OFFSET + + cos_ref, sin_ref = rope(position_ids) + with torch.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16): + cos_ac, sin_ac = rope(position_ids) + + # Distinct image positions must keep distinct embeddings, identical to the float32 computation. + assert not torch.equal(cos_ac[0, 0], cos_ac[0, 1]) + assert torch.equal(cos_ac, cos_ref) + assert torch.equal(sin_ac, sin_ref) From 717d13c2759716b2fc74a77970df51e25aae4b00 Mon Sep 17 00:00:00 2001 From: HaozheZhang6 Date: Fri, 12 Jun 2026 02:24:58 +0000 Subject: [PATCH 2/2] Compute the rotary frequencies in float64 instead of disabling autocast MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review: replace the torch.autocast(enabled=False) guard with a float64 computation, which autocast does not downcast — matching the float64 rope path used elsewhere (Flux). The autocast and float32 paths stay bit-identical (max|delta|=0). --- .../transformers/transformer_ideogram4.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ideogram4.py b/src/diffusers/models/transformers/transformer_ideogram4.py index 4c3761f8589d..03cc6c84a051 100644 --- a/src/diffusers/models/transformers/transformer_ideogram4.py +++ b/src/diffusers/models/transformers/transformer_ideogram4.py @@ -70,25 +70,24 @@ def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso raise ValueError(f"`position_ids` must have shape (B, L, 3), got {tuple(position_ids.shape)}.") batch_size, seq_len, _ = position_ids.shape - # Rotary frequencies must be computed in float32: Ideogram4's image positions start at - # IMAGE_POSITION_OFFSET (65536), so an ambient autocast would otherwise run the matmul and - # cos/sin in bfloat16, rounding every image position to the same value and collapsing the - # rotary embeddings (all spatial information is lost). - with torch.autocast(device_type=position_ids.device.type, enabled=False): - pos = position_ids.permute(2, 0, 1).to(dtype=torch.float32) - inv_freq = self.inv_freq.to(dtype=torch.float32)[None, None, :, None].expand(3, batch_size, -1, 1) - freqs = inv_freq @ pos.unsqueeze(2) - freqs = freqs.transpose(2, 3) # (3, B, L, inv_freq_size) - - # Interleaved mrope: pull H freqs into idx 1 mod 3, W freqs into idx 2 mod 3. - freqs_t = freqs[0].clone() - for axis, offset in ((1, 1), (2, 2)): - length = self.mrope_section[axis] * 3 - idx = torch.arange(offset, length, 3, device=freqs_t.device) - freqs_t[..., idx] = freqs[axis][..., idx] - - emb = torch.cat((freqs_t, freqs_t), dim=-1) - return emb.cos(), emb.sin() + # Rotary frequencies are computed in float64: Ideogram4's image positions start at + # IMAGE_POSITION_OFFSET (65536), which float32 cannot represent distinctly once an ambient + # autocast runs the matmul/cos/sin in bfloat16, collapsing every image position to the same + # embedding. float64 is not downcast by autocast, matching the float64 rope path Flux uses. + pos = position_ids.permute(2, 0, 1).to(dtype=torch.float64) + inv_freq = self.inv_freq.to(dtype=torch.float64)[None, None, :, None].expand(3, batch_size, -1, 1) + freqs = inv_freq @ pos.unsqueeze(2) + freqs = freqs.transpose(2, 3) # (3, B, L, inv_freq_size) + + # Interleaved mrope: pull H freqs into idx 1 mod 3, W freqs into idx 2 mod 3. + freqs_t = freqs[0].clone() + for axis, offset in ((1, 1), (2, 2)): + length = self.mrope_section[axis] * 3 + idx = torch.arange(offset, length, 3, device=freqs_t.device) + freqs_t[..., idx] = freqs[axis][..., idx] + + emb = torch.cat((freqs_t, freqs_t), dim=-1) + return emb.cos().float(), emb.sin().float() class Ideogram4AttnProcessor: