From fc9ae13e056a6775f2f73f9ce7a018bfc93dd26b Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 23 May 2026 01:24:32 +0800 Subject: [PATCH] feat: add LTX temporal latent upscaler support --- src/ltx_latent_upscaler.hpp | 71 +++++++++++++++---- src/ltx_vae.hpp | 2 +- src/stable-diffusion.cpp | 131 ++++++++++++++++++++++++++++-------- 3 files changed, 162 insertions(+), 42 deletions(-) diff --git a/src/ltx_latent_upscaler.hpp b/src/ltx_latent_upscaler.hpp index 1cdc02282..ea4a830c6 100644 --- a/src/ltx_latent_upscaler.hpp +++ b/src/ltx_latent_upscaler.hpp @@ -31,6 +31,7 @@ namespace LTXVUpsampler { float spatial_scale = 2.f; int spatial_up_num = 2; int spatial_down_den = 1; + int temporal_up_factor = 1; }; static inline bool has_tensor(const String2TensorStorage& tensor_storage_map, @@ -83,9 +84,13 @@ namespace LTXVUpsampler { if (detected_blocks > 0) { config.num_blocks_per_stage = detected_blocks; } - config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight"); - config.spatial_upsample = config.rational_resampler || has_tensor(tensor_storage_map, "upsampler.0.weight"); - config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight"); + config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight"); + int64_t upsampler_out_channels = get_tensor_ne0(tensor_storage_map, "upsampler.0.bias", 0); + config.spatial_upsample = config.rational_resampler || upsampler_out_channels == 4 * config.mid_channels; + config.temporal_upsample = upsampler_out_channels == 2 * config.mid_channels; + if (config.temporal_upsample) { + config.temporal_up_factor = 2; + } if (config.rational_resampler) { int64_t out_channels = get_tensor_ne(tensor_storage_map, "upsampler.conv.weight", @@ -207,6 +212,30 @@ namespace LTXVUpsampler { } }; + class TemporalPixelShuffleND : public UnaryBlock { + protected: + int upscale_factor; + + public: + explicit TemporalPixelShuffleND(int upscale_factor) + : upscale_factor(upscale_factor) {} + + ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override { + GGML_ASSERT(upscale_factor > 0); + GGML_ASSERT(x->ne[3] % upscale_factor == 0); + const int64_t W = x->ne[0]; + const int64_t H = x->ne[1]; + const int64_t F = x->ne[2]; + const int64_t C = x->ne[3] / upscale_factor; + + // x: [b, c*p, f, h, w] -> [b, c, f*p, h, w] + x = ggml_ext_cont(ctx->ggml_ctx, x); + x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, F, upscale_factor, C); + x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); + return ggml_reshape_4d(ctx->ggml_ctx, x, W, H, F * upscale_factor, C); + } + }; + class BlurDownsample : public GGMLBlock { protected: int64_t channels; @@ -308,8 +337,7 @@ namespace LTXVUpsampler { explicit LatentUpsampler(LatentUpsamplerConfig config) : config(std::move(config)) { GGML_ASSERT(this->config.dims == 3); - GGML_ASSERT(this->config.spatial_upsample); - GGML_ASSERT(!this->config.temporal_upsample); + GGML_ASSERT(this->config.spatial_upsample || this->config.temporal_upsample); blocks["initial_conv"] = std::shared_ptr(new Conv3d(this->config.in_channels, this->config.mid_channels, @@ -324,6 +352,13 @@ namespace LTXVUpsampler { blocks["upsampler"] = std::shared_ptr(new SpatialRationalResampler(this->config.mid_channels, this->config.spatial_up_num, this->config.spatial_down_den)); + } else if (this->config.temporal_upsample) { + blocks["upsampler.0"] = std::shared_ptr(new Conv3d(this->config.mid_channels, + this->config.temporal_up_factor * this->config.mid_channels, + {3, 3, 3}, + {1, 1, 1}, + {1, 1, 1})); + blocks["upsampler.1"] = std::shared_ptr(new TemporalPixelShuffleND(this->config.temporal_up_factor)); } else { blocks["upsampler.0"] = std::shared_ptr(new Conv2d(this->config.mid_channels, 4 * this->config.mid_channels, @@ -344,7 +379,7 @@ namespace LTXVUpsampler { ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) { // x: [b, c, f, h, w] - // return: [b, c, f, scaled_h, scaled_w] + // return: [b, c, scaled_f, scaled_h, scaled_w] auto initial_conv = std::dynamic_pointer_cast(blocks["initial_conv"]); auto initial_norm = std::dynamic_pointer_cast(blocks["initial_norm"]); auto final_conv = std::dynamic_pointer_cast(blocks["final_conv"]); @@ -363,6 +398,12 @@ namespace LTXVUpsampler { if (config.rational_resampler) { auto upsampler = std::dynamic_pointer_cast(blocks["upsampler"]); x = upsampler->forward(ctx, x); + } else if (config.temporal_upsample) { + auto upsample_conv = std::dynamic_pointer_cast(blocks["upsampler.0"]); + auto pixel_shuffle = std::dynamic_pointer_cast(blocks["upsampler.1"]); + x = upsample_conv->forward(ctx, x); // [b, c*2, f, h, w] + x = pixel_shuffle->forward(ctx, x); // [b, c, f*2, h, w] + x = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1, x->ne[2]); // x[:, :, 1:, :, :] } else { auto upsample_conv = std::dynamic_pointer_cast(blocks["upsampler.0"]); auto pixel_shuffle = std::dynamic_pointer_cast(blocks["upsampler.1"]); @@ -415,23 +456,24 @@ namespace LTXVUpsampler { } const auto& tensor_storage_map = model_loader.get_tensor_storage_map(); - bool has_regular_spatial = has_tensor(tensor_storage_map, "upsampler.0.weight"); + bool has_regular_upsampler = has_tensor(tensor_storage_map, "upsampler.0.weight"); bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight"); if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") || - (!has_regular_spatial && !has_rational_spatial)) { - LOG_ERROR("unsupported LTX latent upsampler weights: expected spatial upsampler tensors"); + (!has_regular_upsampler && !has_rational_spatial)) { + LOG_ERROR("unsupported LTX latent upsampler weights: expected upsampler tensors"); return false; } LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map); - if (config.dims != 3 || !config.spatial_upsample || config.temporal_upsample || - config.spatial_up_num < 1 || config.spatial_down_den < 1) { - LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f", + if (config.dims != 3 || (!config.spatial_upsample && !config.temporal_upsample) || + config.spatial_up_num < 1 || config.spatial_down_den < 1 || config.temporal_up_factor < 1) { + LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f temporal_factor=%d", config.dims, config.spatial_upsample, config.temporal_upsample, config.rational_resampler, - config.spatial_scale); + config.spatial_scale, + config.temporal_up_factor); return false; } @@ -454,11 +496,12 @@ namespace LTXVUpsampler { } model->load_fixed_tensors(); - LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, rational=%d", + LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, temporal_factor=%d, rational=%d", config.in_channels, config.mid_channels, config.num_blocks_per_stage, config.spatial_scale, + config.temporal_up_factor, config.rational_resampler); return true; } diff --git a/src/ltx_vae.hpp b/src/ltx_vae.hpp index 756741a43..5fdce6c28 100644 --- a/src/ltx_vae.hpp +++ b/src/ltx_vae.hpp @@ -153,7 +153,7 @@ namespace LTXVAE { GGML_ASSERT(x->ne[2] >= temporal_pad); - int end_idx = x->ne[2] - temporal_pad; + int end_idx = (int)x->ne[2] - temporal_pad; int start_idx = std::max(end_idx - pad, 0); // Save a contiguous copy of the last `pad` frames so the large `x` diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 566c3b4aa..e9be059c1 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -2153,19 +2153,41 @@ class StableDiffusionGGML { int vae_scale_factor = get_vae_scale_factor(); int W = width / vae_scale_factor; int H = height / vae_scale_factor; - int T = frames; - if (sd_version_is_ltxav(version)) { - T = ((T - 1) / 8) + 1; - } else if (sd_version_is_wan(version)) { - T = ((T - 1) / 4) + 1; - } - int C = get_latent_channel(); + int T = video_frames_to_latent_frames(frames); + int C = get_latent_channel(); if (video) { return sd::zeros({W, H, T, C, 1}); } return sd::zeros({W, H, C, 1}); } + int video_frames_to_latent_frames(int frames) { + int latent_frames = frames; + if (sd_version_is_ltxav(version)) { + latent_frames = ((frames - 1) / 8) + 1; + } else if (sd_version_is_wan(version)) { + latent_frames = ((frames - 1) / 4) + 1; + } + return latent_frames; + } + + int latent_frames_to_video_frames(int latent_frames) { + if (latent_frames <= 0) { + return latent_frames; + } + if (sd_version_is_ltxav(version)) { + return (latent_frames - 1) * 8 + 1; + } + if (sd_version_is_wan(version)) { + return (latent_frames - 1) * 4 + 1; + } + return latent_frames; + } + + int align_video_frames(int frames) { + return latent_frames_to_video_frames(video_frames_to_latent_frames(frames)); + } + sd::Tensor encode_to_vae_latents(const sd::Tensor& x) { auto latents = first_stage_model->encode(n_threads, x, vae_tiling_params, circular_x, circular_y); if (latents.empty()) { @@ -3000,16 +3022,12 @@ struct GenerationRequest { } GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) { - prompt = SAFE_STR(sd_vid_gen_params->prompt); - negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); - width = sd_vid_gen_params->width; - height = sd_vid_gen_params->height; - requested_frames = std::max(1, sd_vid_gen_params->video_frames); - if (sd_version_is_ltxav(sd_ctx->sd->version)) { - frames = ((requested_frames - 1 + 7) / 8) * 8 + 1; - } else { - frames = (requested_frames - 1) / 4 * 4 + 1; - } + prompt = SAFE_STR(sd_vid_gen_params->prompt); + negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt); + width = sd_vid_gen_params->width; + height = sd_vid_gen_params->height; + requested_frames = std::max(1, sd_vid_gen_params->video_frames); + frames = sd_ctx->sd->align_video_frames(requested_frames); clip_skip = sd_vid_gen_params->clip_skip; fps = std::max(1, sd_vid_gen_params->fps); vae_scale_factor = sd_ctx->sd->get_vae_scale_factor(); @@ -3567,6 +3585,30 @@ static sd::Tensor unpack_ltxav_audio_latent(const sd::Tensor& pack return audio_latent; } +static sd::Tensor make_ltxav_empty_audio_latent(int audio_length) { + if (audio_length <= 0) { + return {}; + } + constexpr int kLtxavAudioFrequencyBins = 16; + constexpr int kLtxavAudioChannels = 8; + return sd::zeros({kLtxavAudioFrequencyBins, audio_length, kLtxavAudioChannels, 1}); +} + +static sd::Tensor resize_ltxav_audio_latent(const sd::Tensor& audio_latent, + int target_audio_length) { + auto resized = make_ltxav_empty_audio_latent(target_audio_length); + if (resized.empty() || audio_latent.empty()) { + return resized; + } + GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4); + int copy_length = std::min(static_cast(audio_latent.shape()[1]), target_audio_length); + if (copy_length > 0) { + auto copied = sd::ops::slice(audio_latent, 1, 0, copy_length); + sd::ops::slice_assign(&resized, 1, 0, copy_length, copied); + } + return resized; +} + static int get_ltxav_num_audio_latents(int frames, int fps) { GGML_ASSERT(frames > 0); GGML_ASSERT(fps > 0); @@ -4396,10 +4438,8 @@ static std::optional prepare_video_generation_latents(sd } if (sd_version_is_ltxav(sd_ctx->sd->version)) { - constexpr int kLtxavAudioFrequencyBins = 16; - constexpr int kLtxavAudioChannels = 8; - latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps); - latents.audio_latent = sd::zeros({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1}); + latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps); + latents.audio_latent = make_ltxav_empty_audio_latent(latents.audio_length); } if (sd_version_is_ltxav(sd_ctx->sd->version)) { @@ -4749,9 +4789,9 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx, (int)vid.shape()[1], (int)vid.shape()[2], (int)vid.shape()[3]); - if (request.requested_frames > 0 && - vid.shape()[2] > request.requested_frames) { - vid = sd::ops::slice(vid, 2, 0, request.requested_frames); + if (request.frames > 0 && + vid.shape()[2] > request.frames) { + vid = sd::ops::slice(vid, 2, 0, request.frames); } sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t)); @@ -5118,9 +5158,46 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx, LOG_INFO("LTX latent spatial upscale completed, taking %.2fs", (upscale_end - upscale_start) * 1.0f / 1000); - x_t = std::move(upscaled_latent); - hires_request.width = static_cast(x_t.shape()[0]) * hires_request.vae_scale_factor; - hires_request.height = static_cast(x_t.shape()[1]) * hires_request.vae_scale_factor; + x_t = std::move(upscaled_latent); + hires_request.width = static_cast(x_t.shape()[0]) * hires_request.vae_scale_factor; + hires_request.height = static_cast(x_t.shape()[1]) * hires_request.vae_scale_factor; + int upscaled_latent_frames = static_cast(x_t.shape()[2]); + int upscaled_frames = sd_ctx->sd->latent_frames_to_video_frames(upscaled_latent_frames); + if (upscaled_frames != hires_request.frames) { + LOG_INFO("LTX latent upsampler output latent frames %d, frames %d -> %d", + upscaled_latent_frames, + hires_request.frames, + upscaled_frames); + hires_request.frames = upscaled_frames; + } + if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0) { + int target_audio_length = get_ltxav_num_audio_latents(hires_request.frames, hires_request.fps); + if (target_audio_length != latents.audio_length) { + int latent_channels = sd_ctx->sd->get_latent_channel(); + sd::Tensor video_latent = x_t; + sd::Tensor audio_latent = latents.audio_latent; + if (x_t.shape()[3] > latent_channels) { + video_latent = sd::ops::slice(x_t, 3, 0, latent_channels); + audio_latent = unpack_ltxav_audio_latent(x_t, latents.audio_length, latent_channels); + } + audio_latent = resize_ltxav_audio_latent(audio_latent, target_audio_length); + if (audio_latent.empty()) { + LOG_ERROR("failed to resize LTX audio latent for latent upscale: %d -> %d", + latents.audio_length, + target_audio_length); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->diffusion_model->free_params_buffer(); + } + return false; + } + x_t = pack_ltxav_audio_and_video_latents(video_latent, audio_latent); + latents.audio_latent = std::move(audio_latent); + LOG_INFO("LTX audio latent length adjusted for latent upscale: %d -> %d", + latents.audio_length, + target_audio_length); + latents.audio_length = target_audio_length; + } + } if ((request.hires.target_width > 0 || request.hires.target_height > 0) && (request.hires.target_width != hires_request.width || request.hires.target_height != hires_request.height)) { LOG_WARN("LTX latent spatial upsampler output is %dx%d; ignoring hires target %dx%d",