Add Support for LTX-2.3 Models#13217

Merged
dg845 merged 43 commits into main from ltx2-3-pipeline on Mar 19, 2026

Conversation

dg845 (Collaborator) commented Mar 6, 2026

What does this PR do?

This PR adds support for LTX-2.3 (official code, model weights), a new model in the LTX-2.X family of audio-video models. LTX-2.3 improves audio quality, visual quality, and prompt adherence compared to LTX-2.0.

T2V Example
import torch
from diffusers import LTX2Pipeline
from diffusers.pipelines.ltx2.export_utils import encode_video

model_id = "dg845/LTX-2.3-Diffusers"
device = "cuda"
dtype = torch.bfloat16
seed = 42

frame_rate = 24.0
width = 768
height = 512
num_inference_steps = 30

prompt = (
    "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
    "gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
    "before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
    "fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
    "shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
    "smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
    "distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
    "breath-taking, movie-like shot."
)
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)

pipe = LTX2Pipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()

generator = torch.Generator(device).manual_seed(seed)
video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=num_inference_steps,
    guidance_scale=3.0,
    stg_scale=1.0,
    modality_scale=3.0,
    guidance_rescale=0.7,
    audio_guidance_scale=7.0,
    audio_stg_scale=1.0,
    audio_modality_scale=3.0,
    audio_guidance_rescale=0.7,
    spatio_temporal_guidance_blocks=[28],
    generator=generator,
    output_type="np",
    return_dict=False,
)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_3_t2v.mp4",
)
I2V Example
import torch
from diffusers import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image

model_id = "dg845/LTX-2.3-Diffusers"
device = "cuda"
dtype = torch.bfloat16
seed = 42

frame_rate = 24.0
width = 768
height = 512
num_inference_steps = 30

prompt = (
    "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
    "gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
    "before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
    "fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
    "shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
    "smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
    "distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
    "breath-taking, movie-like shot."
)
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

pipe = LTX2ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()

generator = torch.Generator(device).manual_seed(seed)
video, audio = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=num_inference_steps,
    guidance_scale=3.0,
    stg_scale=1.0,
    modality_scale=3.0,
    guidance_rescale=0.7,
    audio_guidance_scale=7.0,
    audio_stg_scale=1.0,
    audio_modality_scale=3.0,
    audio_guidance_rescale=0.7,
    spatio_temporal_guidance_blocks=[28],
    generator=generator,
    output_type="np",
    return_dict=False,
)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_3_i2v.mp4",
)
FLF2V Example
import torch
from diffusers import LTX2ConditionPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.pipeline_ltx2_condition import LTX2VideoCondition
from diffusers.utils import load_image

model_id = "dg845/LTX-2.3-Diffusers"
device = "cuda"
dtype = torch.bfloat16
seed = 42

frame_rate = 24.0
width = 768
height = 512
num_inference_steps = 30

prompt = (
    "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are "
    "delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright "
    "sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, "
    "low-angle perspective."
)
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)

first_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
)
last_image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
)
first_cond = LTX2VideoCondition(frames=first_image, index=0, strength=1.0)
last_cond = LTX2VideoCondition(frames=last_image, index=-1, strength=1.0)
conditions = [first_cond, last_cond]

pipe = LTX2ConditionPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()

generator = torch.Generator(device).manual_seed(seed)
video, audio = pipe(
    conditions=conditions,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=num_inference_steps,
    guidance_scale=3.0,
    stg_scale=1.0,
    modality_scale=3.0,
    guidance_rescale=0.7,
    audio_guidance_scale=7.0,
    audio_stg_scale=1.0,
    audio_modality_scale=3.0,
    audio_guidance_rescale=0.7,
    spatio_temporal_guidance_blocks=[28],
    generator=generator,
    output_type="np",
    return_dict=False,
)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_3_flf2v.mp4",
)
I2V Two Stage Example
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.utils import load_image

model_id = "dg845/LTX-2.3-Diffusers"
device = "cuda"
dtype = torch.bfloat16
seed = 42

frame_rate = 24.0
width = 768
height = 512
num_inference_steps = 30

prompt = (
    "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
    "gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
    "before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
    "fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
    "shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
    "smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
    "distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
    "breath-taking, movie-like shot."
)
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

pipe = LTX2ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()

generator = torch.Generator(device).manual_seed(seed)
video_latent, audio_latent = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=num_inference_steps,
    guidance_scale=3.0,
    stg_scale=1.0,
    modality_scale=3.0,
    guidance_rescale=0.7,
    audio_guidance_scale=7.0,
    audio_stg_scale=1.0,
    audio_modality_scale=3.0,
    audio_guidance_rescale=0.7,
    spatio_temporal_guidance_blocks=[28],
    use_cross_timestep=True,
    generator=generator,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "dg845/LTX-2.3-Spatial-Upsampler-Diffusers",
    subfolder="latent_upsampler",
    torch_dtype=dtype,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

pipe.load_lora_weights(
    "Lightricks/LTX-2.3",
    adapter_name="stage_2_distilled",
    weight_name="ltx-2.3-22b-distilled-lora-384.safetensors",
)
pipe.set_adapters("stage_2_distilled", 1.0)
# Swap the scheduler so the Stage 2 distilled sigmas are used as-is (no dynamic shifting)
new_scheduler = FlowMatchEulerDiscreteScheduler.from_config(
    pipe.scheduler.config, use_dynamic_shifting=False, shift_terminal=None
)
pipe.scheduler = new_scheduler

video, audio = pipe(
    image=image,
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,  # For Stage 2 distilled, disable all guidance
    stg_scale=0.0,
    modality_scale=1.0,
    guidance_rescale=0.0,
    audio_guidance_scale=1.0,
    audio_stg_scale=0.0,
    audio_modality_scale=1.0,
    audio_guidance_rescale=0.0,
    spatio_temporal_guidance_blocks=None,
    use_cross_timestep=True,
    generator=generator,
    output_type="np",
    return_dict=False,
)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_3_i2v_two_stage.mp4",
)
I2V Distilled Example
import torch
from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.pipelines.ltx2.utils import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from diffusers.utils import load_image

model_id = "dg845/LTX-2.3-Distilled-Diffusers"
device = "cuda"
dtype = torch.bfloat16
seed = 42

frame_rate = 24.0
width = 768
height = 512

prompt = (
    "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in "
    "gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs "
    "before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small "
    "fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly "
    "shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a "
    "smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the "
    "distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a "
    "breath-taking, movie-like shot."
)
negative_prompt = None

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

pipe = LTX2ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.enable_model_cpu_offload(device=device)
pipe.vae.enable_tiling()

generator = torch.Generator(device).manual_seed(seed)
video_latent, audio_latent = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width,
    height=height,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=8,
    sigmas=DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,  # Disable all guidance for distilled inference
    stg_scale=0.0,
    modality_scale=1.0,
    guidance_rescale=0.0,
    audio_guidance_scale=1.0,
    audio_stg_scale=0.0,
    audio_modality_scale=1.0,
    audio_guidance_rescale=0.0,
    spatio_temporal_guidance_blocks=None,
    use_cross_timestep=True,
    generator=generator,
    output_type="latent",
    return_dict=False,
)

latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
    "dg845/LTX-2.3-Spatial-Upsampler-Diffusers",
    subfolder="latent_upsampler",
    torch_dtype=dtype,
)
upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
upsample_pipe.enable_model_cpu_offload(device=device)
upscaled_video_latent = upsample_pipe(
    latents=video_latent,
    output_type="latent",
    return_dict=False,
)[0]

video, audio = pipe(
    image=image,
    latents=upscaled_video_latent,
    audio_latents=audio_latent,
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=width * 2,
    height=height * 2,
    num_frames=121,
    frame_rate=frame_rate,
    num_inference_steps=3,
    noise_scale=STAGE_2_DISTILLED_SIGMA_VALUES[0],
    sigmas=STAGE_2_DISTILLED_SIGMA_VALUES,
    guidance_scale=1.0,  # Disable all guidance for distilled inference
    stg_scale=0.0,
    modality_scale=1.0,
    guidance_rescale=0.0,
    audio_guidance_scale=1.0,
    audio_stg_scale=0.0,
    audio_modality_scale=1.0,
    audio_guidance_rescale=0.0,
    spatio_temporal_guidance_blocks=None,
    use_cross_timestep=True,
    generator=generator,
    output_type="np",
    return_dict=False,
)

encode_video(
    video[0],
    fps=frame_rate,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_3_i2v_distilled.mp4",
)

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu
@sayakpaul

return hidden_states


class LTX2PerturbedAttnProcessor:

dg845 (Collaborator, Author):

Thanks! Looking at the code, it's unclear to me whether SkipLayerGuidance currently works for LTX-2.3 for the following reasons:

  1. Not attention backend agnostic: if I understand correctly, STG is implemented through AttentionProcessorSkipHook, which uses AttentionScoreSkipFunctionMode to intercept calls to torch.nn.functional.scaled_dot_product_attention to simply return the value:
    if func is torch.nn.functional.scaled_dot_product_attention:
    But I think other attention backends like flash-attn won't call that function and thus will not work with SkipLayerGuidance.
  2. LTX-2.3 does additional computation on the values: LTX-2.3 additionally processes the values using learned per-head gates before sending them to the attention output projection to_out. This is not supported by the current SkipLayerGuidance implementation.

I'm not sure whether these issues can be resolved with changes to the SkipLayerGuidance implementation or whether something like a new attention processor would make more sense here.
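
For illustration, here is a minimal sketch of the value-passthrough trick that STG relies on, extended with the per-head value gating described in point 2 above. This is not the PR's actual implementation, and the attribute names (to_v, to_out, value_gates) are assumptions:

import torch
import torch.nn as nn


class PerturbedAttnSketch(nn.Module):
    """Skip the attention score computation and return the (gated) values,
    independent of which attention backend the model normally uses."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads = heads
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)
        # Hypothetical learned per-head gates, as in point 2 above.
        self.value_gates = nn.Parameter(torch.ones(heads))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        b, s, d = hidden_states.shape
        value = self.to_v(hidden_states).view(b, s, self.heads, d // self.heads)
        # "Identity attention": pass the values straight through, but still
        # apply the per-head gating that a generic SDPA-interception hook
        # would miss.
        value = value * self.value_gates.view(1, 1, self.heads, 1)
        return self.to_out(value.reshape(b, s, d))

Because the perturbation lives in the processor itself rather than in an interception of torch.nn.functional.scaled_dot_product_attention, it behaves the same under any attention backend.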

dg845 (Collaborator, Author):

I have opened a PR with a possible modification to SkipLayerGuidance to allow it to better support LTX-2.3 at #13220.

Member:

This is a good callout! From my understanding, guider as a component doesn't change much. LTX-2 is probably an exception. If more models start to do their own form of SLG, we could think of giving them their own guider classes / attention processors. But for now, I think modifications to the existing SLG class make more sense.

Collaborator:

let's merge LTX-2.3 with a special custom attention processor in this PR first, ASAP.

the design from the other PR to refactor the guider is fundamentally wrong - the purpose of hooks (and guiders as well) is that they modify behavior from the outside, without the model needing to be aware of them & implement logic specific to them.
i will look into refactoring with guiders in the follow-up modular PR.

Member:

The point on guiders not being backend agnostic is a good thing to keep in mind.

@asomoza mentioned this pull request Mar 9, 2026
HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


dg845 commented Mar 10, 2026

LTX-2.3 diffusers converted checkpoint: dg845/LTX-2.3-Diffusers (may still have bugs).


dg845 commented Mar 11, 2026

I2V sample using the example above:

(video attachment: ltx2_3_i2v_stage_1.mp4)

This uses dg845/LTX-2.3-Diffusers with CFG + STG + modality guidance with the LTX-2.3 default guidance scales.

@dg845 dg845 marked this pull request as ready for review March 11, 2026 06:02
@dg845 dg845 changed the title [WIP] Add Support for LTX-2.3 Models Add Support for LTX-2.3 Models Mar 16, 2026
@dg845 dg845 requested review from sayakpaul and yiyixuxu March 16, 2026 06:58
@sayakpaul (Member) left a comment:

Thanks so much for the changes and for your patience.

The volume of changes (and of navigating them) is a bit overwhelming, TBH.

I have left a few comments. Let me know if they make sense. We could consider adding a test suite mirroring the existing LTX-2 pipeline tests, but swapping in the components specific to LTX-2.3?

LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder, LTX2VocoderWithBWE
Member:

Any issue in unifying LTX2Vocoder and LTX2VocoderWithBWE?

dg845 (Collaborator, Author):

I think there is no issue in principle. But because LTX2VocoderWithBWE contains two LTX2Vocoders as submodules, it felt more natural to wrap them in a new module (and it also parallels the original code). See the sketch below.
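
For readers following along, the wrapper shape being described might look roughly like the following; the submodule names and forward composition are illustrative assumptions, not the PR's code:

import torch.nn as nn


class VocoderWithBWESketch(nn.Module):
    """Compose two vocoder stages: a base vocoder plus a bandwidth-extension
    (BWE) stage that refines the base output."""

    def __init__(self, base_vocoder: nn.Module, bwe_vocoder: nn.Module):
        super().__init__()
        self.vocoder = base_vocoder  # audio latents -> base waveform
        self.bwe = bwe_vocoder       # extends the bandwidth of the base output

    def forward(self, latents):
        return self.bwe(self.vocoder(latents))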

"q_norm": "norm_q",
"k_norm": "norm_k",
# LTX-2.3
"audio_prompt_adaln_single": "audio_prompt_adaln",
Member:

Where did this pop up? Distillation checkpoint?

dg845 (Collaborator, Author):

The prompt_adaln and audio_prompt_adaln modules are used by both the full model and distilled model to calculate scale/shift modulation parameters for the text encoder_hidden_states for the video and audio modalities respectively. (I believe this is in place of the caption_projections, which were removed in LTX-2.3.)
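
As a rough illustration of the role described above (shapes, names, and the exact modulation form are assumptions, not the PR's code), an adaLN-style module producing scale/shift for the text states might look like:

import torch
import torch.nn as nn


class PromptAdaLNSketch(nn.Module):
    def __init__(self, cond_dim: int, text_dim: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, 2 * text_dim)

    def forward(self, encoder_hidden_states: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # cond: (batch, cond_dim) timestep/conditioning embedding.
        scale, shift = self.proj(cond).chunk(2, dim=-1)
        # Modulate the text tokens per batch element.
        return encoder_hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)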

resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
upsample_type: str = "spatiotemporal",
Member:

Should this go at the last of init params to prevent backwards breaking in case someone is using positional arguments?

dg845 (Collaborator, Author):

I put upsample_type there because it follows the argument ordering of LTX2VideoDownBlock3D, which already uses an analogous downsample_type argument. I think the positional-argument point is valid, but IMO there is less risk of breakage here since users are unlikely to call LTX2VideoUpBlock3d directly.

LTXVideoUpsampler3d(
out_channels * upscale_factor,
self.upsamplers = nn.ModuleList()

Member:

It seems like stride is the only factor that varies depending on upsample_type. So, maybe we could do something like:

if upsample_type == "spatial":
    stride = (1, 2, 2)
elif upsample_type == "temporal":
    stride = (2, 1, 1)
elif upsample_type == "spatiotemporal":
    stride = (2, 2, 2)

self.upsamplers.append(..., stride=stride)

WDYT?

Comment on lines +1211 to +1212
self_attention_mask=None,
audio_self_attention_mask=None,
Member:

Are these used by the other pipelines, such as I2V?

dg845 (Collaborator, Author):

I think they are not used by any currently implemented pipeline. They might be used in pipelines that are in the LTX-2 code but not yet implemented in diffusers.

)
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
# Use delta formulation as it works more nicely with multiple guidance terms
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
Member:

(Note to other reviewers: guidance is computed a bit later to account for everything that comes before the computation.)
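
To make the delta formulation concrete, here is a sketch of how the three additive corrections compose (mirroring the audio line noise_pred_audio_g = ... quoted further down; the exact STG delta form here is my assumption):

import torch


def combine_guidance_deltas(
    pred_cond: torch.Tensor,
    pred_uncond_text: torch.Tensor,
    pred_perturbed: torch.Tensor,
    pred_uncond_modality: torch.Tensor,
    guidance_scale: float,
    stg_scale: float,
    modality_scale: float,
) -> torch.Tensor:
    # Each guidance term is an independent additive delta on top of the
    # conditional prediction, so terms can be enabled or disabled without
    # nesting rescaled interpolations.
    cfg_delta = (guidance_scale - 1.0) * (pred_cond - pred_uncond_text)
    stg_delta = stg_scale * (pred_cond - pred_perturbed)
    modality_delta = (modality_scale - 1.0) * (pred_cond - pred_uncond_modality)
    return pred_cond + cfg_delta + stg_delta + modality_delta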


if self.do_modality_isolation_guidance:
with self.transformer.cache_context("uncond_modality"):
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
Member:

Do these calls vary from the previous ones in terms of the inputs? If so, it could be nice to add a small comment about it because the call arg list is pretty long.

dg845 (Collaborator, Author):

I believe there is already an existing comment:

# Turn off A2V and V2A cross attn to isolate video and audio modalities
isolate_modalities=True,

noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta

# Apply LTX-2.X guidance rescaling
if self.guidance_rescale > 0:
Member:

Are we unable to use the rescaling utility?
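
(For reference, the utility presumably being referred to is diffusers' rescale_noise_cfg, from the "Common Diffusion Noise Schedules and Sample Steps are Flawed" paper, which rescales the guided prediction toward the conditional prediction's statistics:)

import torch


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guided result to fix overexposure.
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Mix with the original guided result to avoid "plain looking" outputs.
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg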

return x


class SnakeBeta(nn.Module):
Member:

TIL.

Should this go to activations.py? Okay if not.

dg845 (Collaborator, Author):

I think ideally it should, although I'm not familiar enough with Snake/SnakeBeta to say whether this is a stable, widely reusable implementation. My impression is that it's more or less standard though (this implementation follows the original LTX-2 code, which itself follows the BigVGAN-V2 implementation).
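
For readers unfamiliar with it, a minimal sketch of the BigVGAN-style SnakeBeta activation referenced above (assuming the common log-scale parameterization; not the PR's exact code):

import torch
import torch.nn as nn


class SnakeBetaSketch(nn.Module):
    """x + (1/beta) * sin^2(alpha * x), with learned per-channel alpha
    (frequency) and beta (magnitude), both stored in log scale."""

    def __init__(self, channels: int, eps: float = 1e-9):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.zeros(channels))
        self.log_beta = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time)
        alpha = self.log_alpha.exp().view(1, -1, 1)
        beta = self.log_beta.exp().view(1, -1, 1)
        return x + (1.0 / (beta + self.eps)) * torch.sin(alpha * x).pow(2)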

dg845 and others added 2 commits March 16, 2026 18:37
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)

if self_attention_mask is not None and self_attention_mask.ndim == 3:
Collaborator:

are we using self_attention_mask? just looked over pipeline_ltx2 and it seems it's always passed as self_attention_mask=None

dg845 (Collaborator, Author):

We don't use self_attention_mask in any of the current diffusers LTX-2.X pipelines.

Collaborator:

let's remove this logic then


import numpy as np
import torch
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
Collaborator:

I think we should probably just create new pipelines for 2.3 - it's getting a bit overwhelming, no?

dg845 (Collaborator, Author):

Yeah, the pipelines are getting pretty complex. If I understand correctly, most of the code added to the pipelines in this PR (such as the multimodal guidance strategy with STG) is also supported by the original code for LTX-2.0 checkpoints. So the current official inference code for LTX-2.0 and LTX-2.3 checkpoints is mostly the same, and if we follow it, the 2.0 and 2.3 pipelines would also be mostly the same.

@yiyixuxu (Collaborator) left a comment:

thanks!

temporal_upsample: bool = False,
rational_spatial_scale: float | None = 2.0,
rational_spatial_scale: float = 2.0,
use_rational_resampler: bool = True,
Collaborator:

is this just a refactor?

dg845 (Collaborator, Author):

It's mostly a refactor: the LTX-2.0 checkpoint uses the rational resampler but the LTX-2.3 checkpoint does not, so I wanted to decouple these two arguments to make it easier to support checkpoints that don't use the rational resampler.


tin2tin commented Mar 19, 2026

In the two-stage and distilled examples, this line gave me an error: use_cross_timestep=True.


dg845 commented Mar 19, 2026

Hi @tin2tin, can you provide an example script where you get the error? When I tested the examples with use_cross_timestep=True the generation finished without any errors.


tin2tin commented Mar 19, 2026

I'm sorry. I realize my version of your branch might have been a day old. If you get no error, it's all good. Thank you for persistently working on improving this patch. The Diffusers community is eagerly awaiting it.

@dg845 dg845 merged commit 072d15e into main Mar 19, 2026
11 of 12 checks passed
@dg845 dg845 deleted the ltx2-3-pipeline branch March 19, 2026 21:58

tin2tin commented Mar 19, 2026

@dg845 Congratulations, and thank you and everybody helping out!

CalamitousFelicitousness added a commit to vladmandic/sdnext that referenced this pull request Apr 25, 2026
Distilled refine ran without the scheduler swap and identity guidance
kwargs prescribed by huggingface/diffusers#13217; only Dev got that
setup via the supports_canonical_stage2 branch. Distilled is already
trained at identity but still needs the recipe applied to avoid the
four-way composition double-dipping on top of the distilled sigma
schedule (oversaturation/striping).

Unify the branches under family == '2.x'; gate only the LoRA load and
unload on supports_canonical_stage2.
liutyi pushed commits to liutyi/sdnext referencing this pull request Apr 25, 2026