Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/en/api/loaders/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
## KandinskyLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin

## Ideogram4LoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.Ideogram4LoraLoaderMixin

## LoraBaseMixin

[[autodoc]] loaders.lora_base.LoraBaseMixin
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def text_encoder_attn_modules(text_encoder):
"QwenImageLoraLoaderMixin",
"ZImageLoraLoaderMixin",
"Flux2LoraLoaderMixin",
"Ideogram4LoraLoaderMixin",
"ErnieImageLoraLoaderMixin",
"CosmosLoraLoaderMixin",
]
Expand Down Expand Up @@ -128,6 +129,7 @@ def text_encoder_attn_modules(text_encoder):
HeliosLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
Ideogram4LoraLoaderMixin,
KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTX2LoraLoaderMixin,
Expand Down
85 changes: 85 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2883,3 +2883,88 @@ def get_alpha_scales(down_weight, alpha_key):

converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict


def _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict):
"""
Convert non-diffusers Ideogram4 LoRA state dict to diffusers format.

Handles:
- `diffusion_model.` / `conditional_transformer.` prefix removal
- `lora_down`/`lora_up` (kohya) -> `lora_A`/`lora_B`, with `.alpha` folded into the weights
- fused `attention.qkv` -> split `to_q`/`to_k`/`to_v`; `attention.o` -> `to_out.0`
- `feed_forward.w1`/`w2`/`w3` and `adaln_modulation` map one-to-one
"""
for prefix in ("diffusion_model.", "conditional_transformer."):
if any(k.startswith(prefix) for k in state_dict):
state_dict = {k.removeprefix(prefix): v for k, v in state_dict.items()}
break

is_kohya = any(".lora_down.weight" in k for k in state_dict)
down_suffix = ".lora_down.weight" if is_kohya else ".lora_A.weight"
up_suffix = ".lora_up.weight" if is_kohya else ".lora_B.weight"

def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha_tensor = state_dict.pop(alpha_key, None)
if alpha_tensor is None:
return 1.0, 1.0
# LoRA is scaled by `alpha / rank` in the forward pass; split the factor between down and up.
scale = alpha_tensor.item() / rank
scale_down, scale_up = scale, 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up

def pull(base):
"""Pop the scaled (lora_A, lora_B) pair for a module path, or return None if absent."""
down_key = base + down_suffix
if down_key not in state_dict:
return None
down = state_dict.pop(down_key)
up = state_dict.pop(base + up_suffix)
scale_down, scale_up = get_alpha_scales(down, base + ".alpha")
return down * scale_down, up * scale_up

num_layers = 0
for k in state_dict:
match = re.match(r"layers\.(\d+)\.", k)
if match:
num_layers = max(num_layers, int(match.group(1)) + 1)

converted_state_dict = {}
for i in range(num_layers):
layer_prefix = f"layers.{i}"

# Fused qkv -> split to_q / to_k / to_v (shared down/lora_A, chunk up/lora_B in thirds).
qkv = pull(f"{layer_prefix}.attention.qkv")
if qkv is not None:
down, up = qkv
up_q, up_k, up_v = torch.chunk(up, 3, dim=0)
for proj, up_proj in (("to_q", up_q), ("to_k", up_k), ("to_v", up_v)):
converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_A.weight"] = down.clone()
converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_B.weight"] = up_proj.contiguous()

# attention.o -> attention.to_out.0
out = pull(f"{layer_prefix}.attention.o")
if out is not None:
down, up = out
converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_A.weight"] = down
converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_B.weight"] = up

# feed_forward.{w1,w2,w3} and adaln_modulation map one-to-one.
for module in ("feed_forward.w1", "feed_forward.w2", "feed_forward.w3", "adaln_modulation"):
pair = pull(f"{layer_prefix}.{module}")
if pair is not None:
down, up = pair
converted_state_dict[f"{layer_prefix}.{module}.lora_A.weight"] = down
converted_state_dict[f"{layer_prefix}.{module}.lora_B.weight"] = up

if len(state_dict) > 0:
raise ValueError(
f"`state_dict` should be empty at this point but has {sorted(state_dict.keys())}. "
"This may be an unsupported Ideogram4 LoRA layout."
)

return {f"transformer.{k}": v for k, v in converted_state_dict.items()}
208 changes: 208 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
_convert_non_diffusers_anima_lora_to_diffusers,
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_ideogram4_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltx2_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
Expand Down Expand Up @@ -6018,6 +6019,213 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)


class Ideogram4LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Ideogram4Transformer2DModel`]. Specific to [`Ideogram4Pipeline`].
"""

_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME

@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)

allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

# ai-toolkit (ostris) saves Ideogram4 LoRAs under a `diffusion_model.` prefix with a fused
# `attention.qkv` projection; convert those to the diffusers layout before loading.
is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) or any(
".attention.qkv." in k for k in state_dict
)
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict)

out = (state_dict, metadata) if return_lora_metadata else state_dict
return out

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
adapter_name: str | None = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")

self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: dict | None = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}

if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: list[str] | None = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)


class ErnieImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`ErnieImageTransformer2DModel`]. Specific to [`ErnieImagePipeline`].
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from transformers.masking_utils import create_causal_mask

from ...image_processor import VaeImageProcessor
from ...loaders import Ideogram4LoraLoaderMixin
from ...models.autoencoders import AutoencoderKLFlux2
from ...models.transformers.transformer_ideogram4 import (
IMAGE_POSITION_OFFSET,
Expand Down Expand Up @@ -137,7 +138,7 @@ def _expand_tensor_to_effective_batch(
return torch.repeat_interleave(tensor, repeats=repeat_by, dim=0, output_size=tensor.shape[0] * repeat_by)


class Ideogram4Pipeline(DiffusionPipeline):
class Ideogram4Pipeline(DiffusionPipeline, Ideogram4LoraLoaderMixin):
r"""
Text-to-image pipeline for Ideogram4.

Expand Down
Loading