Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions src/ltx_latent_upscaler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace LTXVUpsampler {
float spatial_scale = 2.f;
int spatial_up_num = 2;
int spatial_down_den = 1;
int temporal_up_factor = 1;
};

static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
Expand Down Expand Up @@ -83,9 +84,13 @@ namespace LTXVUpsampler {
if (detected_blocks > 0) {
config.num_blocks_per_stage = detected_blocks;
}
config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight");
config.spatial_upsample = config.rational_resampler || has_tensor(tensor_storage_map, "upsampler.0.weight");
config.temporal_upsample = has_tensor(tensor_storage_map, "temporal_upsampler.0.weight");
config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight");
int64_t upsampler_out_channels = get_tensor_ne0(tensor_storage_map, "upsampler.0.bias", 0);
config.spatial_upsample = config.rational_resampler || upsampler_out_channels == 4 * config.mid_channels;
config.temporal_upsample = upsampler_out_channels == 2 * config.mid_channels;
if (config.temporal_upsample) {
config.temporal_up_factor = 2;
}
if (config.rational_resampler) {
int64_t out_channels = get_tensor_ne(tensor_storage_map,
"upsampler.conv.weight",
Expand Down Expand Up @@ -207,6 +212,30 @@ namespace LTXVUpsampler {
}
};

class TemporalPixelShuffleND : public UnaryBlock {
protected:
int upscale_factor;

public:
explicit TemporalPixelShuffleND(int upscale_factor)
: upscale_factor(upscale_factor) {}

ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) override {
GGML_ASSERT(upscale_factor > 0);
GGML_ASSERT(x->ne[3] % upscale_factor == 0);
const int64_t W = x->ne[0];
const int64_t H = x->ne[1];
const int64_t F = x->ne[2];
const int64_t C = x->ne[3] / upscale_factor;

// x: [b, c*p, f, h, w] -> [b, c, f*p, h, w]
x = ggml_ext_cont(ctx->ggml_ctx, x);
x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, F, upscale_factor, C);
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3));
return ggml_reshape_4d(ctx->ggml_ctx, x, W, H, F * upscale_factor, C);
}
};

class BlurDownsample : public GGMLBlock {
protected:
int64_t channels;
Expand Down Expand Up @@ -308,8 +337,7 @@ namespace LTXVUpsampler {
explicit LatentUpsampler(LatentUpsamplerConfig config)
: config(std::move(config)) {
GGML_ASSERT(this->config.dims == 3);
GGML_ASSERT(this->config.spatial_upsample);
GGML_ASSERT(!this->config.temporal_upsample);
GGML_ASSERT(this->config.spatial_upsample || this->config.temporal_upsample);

blocks["initial_conv"] = std::shared_ptr<GGMLBlock>(new Conv3d(this->config.in_channels,
this->config.mid_channels,
Expand All @@ -324,6 +352,13 @@ namespace LTXVUpsampler {
blocks["upsampler"] = std::shared_ptr<GGMLBlock>(new SpatialRationalResampler(this->config.mid_channels,
this->config.spatial_up_num,
this->config.spatial_down_den));
} else if (this->config.temporal_upsample) {
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv3d(this->config.mid_channels,
this->config.temporal_up_factor * this->config.mid_channels,
{3, 3, 3},
{1, 1, 1},
{1, 1, 1}));
blocks["upsampler.1"] = std::shared_ptr<GGMLBlock>(new TemporalPixelShuffleND(this->config.temporal_up_factor));
} else {
blocks["upsampler.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.mid_channels,
4 * this->config.mid_channels,
Expand All @@ -344,7 +379,7 @@ namespace LTXVUpsampler {

ggml_tensor* forward(GGMLRunnerContext* ctx, ggml_tensor* x) {
// x: [b, c, f, h, w]
// return: [b, c, f, scaled_h, scaled_w]
// return: [b, c, scaled_f, scaled_h, scaled_w]
auto initial_conv = std::dynamic_pointer_cast<Conv3d>(blocks["initial_conv"]);
auto initial_norm = std::dynamic_pointer_cast<VideoGroupNorm>(blocks["initial_norm"]);
auto final_conv = std::dynamic_pointer_cast<Conv3d>(blocks["final_conv"]);
Expand All @@ -363,6 +398,12 @@ namespace LTXVUpsampler {
if (config.rational_resampler) {
auto upsampler = std::dynamic_pointer_cast<SpatialRationalResampler>(blocks["upsampler"]);
x = upsampler->forward(ctx, x);
} else if (config.temporal_upsample) {
auto upsample_conv = std::dynamic_pointer_cast<Conv3d>(blocks["upsampler.0"]);
auto pixel_shuffle = std::dynamic_pointer_cast<TemporalPixelShuffleND>(blocks["upsampler.1"]);
x = upsample_conv->forward(ctx, x); // [b, c*2, f, h, w]
x = pixel_shuffle->forward(ctx, x); // [b, c, f*2, h, w]
x = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1, x->ne[2]); // x[:, :, 1:, :, :]
} else {
auto upsample_conv = std::dynamic_pointer_cast<Conv2d>(blocks["upsampler.0"]);
auto pixel_shuffle = std::dynamic_pointer_cast<PixelShuffleND>(blocks["upsampler.1"]);
Expand Down Expand Up @@ -415,23 +456,24 @@ namespace LTXVUpsampler {
}

const auto& tensor_storage_map = model_loader.get_tensor_storage_map();
bool has_regular_spatial = has_tensor(tensor_storage_map, "upsampler.0.weight");
bool has_regular_upsampler = has_tensor(tensor_storage_map, "upsampler.0.weight");
bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight");
if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") ||
(!has_regular_spatial && !has_rational_spatial)) {
LOG_ERROR("unsupported LTX latent upsampler weights: expected spatial upsampler tensors");
(!has_regular_upsampler && !has_rational_spatial)) {
LOG_ERROR("unsupported LTX latent upsampler weights: expected upsampler tensors");
return false;
}

LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map);
if (config.dims != 3 || !config.spatial_upsample || config.temporal_upsample ||
config.spatial_up_num < 1 || config.spatial_down_den < 1) {
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f",
if (config.dims != 3 || (!config.spatial_upsample && !config.temporal_upsample) ||
config.spatial_up_num < 1 || config.spatial_down_den < 1 || config.temporal_up_factor < 1) {
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f temporal_factor=%d",
config.dims,
config.spatial_upsample,
config.temporal_upsample,
config.rational_resampler,
config.spatial_scale);
config.spatial_scale,
config.temporal_up_factor);
return false;
}

Expand All @@ -454,11 +496,12 @@ namespace LTXVUpsampler {
}
model->load_fixed_tensors();

LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, rational=%d",
LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, temporal_factor=%d, rational=%d",
config.in_channels,
config.mid_channels,
config.num_blocks_per_stage,
config.spatial_scale,
config.temporal_up_factor,
config.rational_resampler);
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion src/ltx_vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ namespace LTXVAE {

GGML_ASSERT(x->ne[2] >= temporal_pad);

int end_idx = x->ne[2] - temporal_pad;
int end_idx = (int)x->ne[2] - temporal_pad;
int start_idx = std::max(end_idx - pad, 0);

// Save a contiguous copy of the last `pad` frames so the large `x`
Expand Down
131 changes: 104 additions & 27 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2153,19 +2153,41 @@ class StableDiffusionGGML {
int vae_scale_factor = get_vae_scale_factor();
int W = width / vae_scale_factor;
int H = height / vae_scale_factor;
int T = frames;
if (sd_version_is_ltxav(version)) {
T = ((T - 1) / 8) + 1;
} else if (sd_version_is_wan(version)) {
T = ((T - 1) / 4) + 1;
}
int C = get_latent_channel();
int T = video_frames_to_latent_frames(frames);
int C = get_latent_channel();
if (video) {
return sd::zeros<float>({W, H, T, C, 1});
}
return sd::zeros<float>({W, H, C, 1});
}

int video_frames_to_latent_frames(int frames) {
int latent_frames = frames;
if (sd_version_is_ltxav(version)) {
latent_frames = ((frames - 1) / 8) + 1;
} else if (sd_version_is_wan(version)) {
latent_frames = ((frames - 1) / 4) + 1;
}
return latent_frames;
}

int latent_frames_to_video_frames(int latent_frames) {
if (latent_frames <= 0) {
return latent_frames;
}
if (sd_version_is_ltxav(version)) {
return (latent_frames - 1) * 8 + 1;
}
if (sd_version_is_wan(version)) {
return (latent_frames - 1) * 4 + 1;
}
return latent_frames;
}

int align_video_frames(int frames) {
return latent_frames_to_video_frames(video_frames_to_latent_frames(frames));
}

sd::Tensor<float> encode_to_vae_latents(const sd::Tensor<float>& x) {
auto latents = first_stage_model->encode(n_threads, x, vae_tiling_params, circular_x, circular_y);
if (latents.empty()) {
Expand Down Expand Up @@ -3000,16 +3022,12 @@ struct GenerationRequest {
}

GenerationRequest(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
prompt = SAFE_STR(sd_vid_gen_params->prompt);
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
width = sd_vid_gen_params->width;
height = sd_vid_gen_params->height;
requested_frames = std::max(1, sd_vid_gen_params->video_frames);
if (sd_version_is_ltxav(sd_ctx->sd->version)) {
frames = ((requested_frames - 1 + 7) / 8) * 8 + 1;
} else {
frames = (requested_frames - 1) / 4 * 4 + 1;
}
prompt = SAFE_STR(sd_vid_gen_params->prompt);
negative_prompt = SAFE_STR(sd_vid_gen_params->negative_prompt);
width = sd_vid_gen_params->width;
height = sd_vid_gen_params->height;
requested_frames = std::max(1, sd_vid_gen_params->video_frames);
frames = sd_ctx->sd->align_video_frames(requested_frames);
clip_skip = sd_vid_gen_params->clip_skip;
fps = std::max(1, sd_vid_gen_params->fps);
vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
Expand Down Expand Up @@ -3567,6 +3585,30 @@ static sd::Tensor<float> unpack_ltxav_audio_latent(const sd::Tensor<float>& pack
return audio_latent;
}

static sd::Tensor<float> make_ltxav_empty_audio_latent(int audio_length) {
if (audio_length <= 0) {
return {};
}
constexpr int kLtxavAudioFrequencyBins = 16;
constexpr int kLtxavAudioChannels = 8;
return sd::zeros<float>({kLtxavAudioFrequencyBins, audio_length, kLtxavAudioChannels, 1});
}

static sd::Tensor<float> resize_ltxav_audio_latent(const sd::Tensor<float>& audio_latent,
int target_audio_length) {
auto resized = make_ltxav_empty_audio_latent(target_audio_length);
if (resized.empty() || audio_latent.empty()) {
return resized;
}
GGML_ASSERT(audio_latent.dim() == 3 || audio_latent.dim() == 4);
int copy_length = std::min(static_cast<int>(audio_latent.shape()[1]), target_audio_length);
if (copy_length > 0) {
auto copied = sd::ops::slice(audio_latent, 1, 0, copy_length);
sd::ops::slice_assign(&resized, 1, 0, copy_length, copied);
}
return resized;
}

static int get_ltxav_num_audio_latents(int frames, int fps) {
GGML_ASSERT(frames > 0);
GGML_ASSERT(fps > 0);
Expand Down Expand Up @@ -4396,10 +4438,8 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
}

if (sd_version_is_ltxav(sd_ctx->sd->version)) {
constexpr int kLtxavAudioFrequencyBins = 16;
constexpr int kLtxavAudioChannels = 8;
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
latents.audio_latent = sd::zeros<float>({kLtxavAudioFrequencyBins, latents.audio_length, kLtxavAudioChannels, 1});
latents.audio_length = get_ltxav_num_audio_latents(request->frames, request->fps);
latents.audio_latent = make_ltxav_empty_audio_latent(latents.audio_length);
}

if (sd_version_is_ltxav(sd_ctx->sd->version)) {
Expand Down Expand Up @@ -4749,9 +4789,9 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
(int)vid.shape()[1],
(int)vid.shape()[2],
(int)vid.shape()[3]);
if (request.requested_frames > 0 &&
vid.shape()[2] > request.requested_frames) {
vid = sd::ops::slice(vid, 2, 0, request.requested_frames);
if (request.frames > 0 &&
vid.shape()[2] > request.frames) {
vid = sd::ops::slice(vid, 2, 0, request.frames);
}

sd_image_t* result_images = (sd_image_t*)calloc(vid.shape()[2], sizeof(sd_image_t));
Expand Down Expand Up @@ -5118,9 +5158,46 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
LOG_INFO("LTX latent spatial upscale completed, taking %.2fs",
(upscale_end - upscale_start) * 1.0f / 1000);

x_t = std::move(upscaled_latent);
hires_request.width = static_cast<int>(x_t.shape()[0]) * hires_request.vae_scale_factor;
hires_request.height = static_cast<int>(x_t.shape()[1]) * hires_request.vae_scale_factor;
x_t = std::move(upscaled_latent);
hires_request.width = static_cast<int>(x_t.shape()[0]) * hires_request.vae_scale_factor;
hires_request.height = static_cast<int>(x_t.shape()[1]) * hires_request.vae_scale_factor;
int upscaled_latent_frames = static_cast<int>(x_t.shape()[2]);
int upscaled_frames = sd_ctx->sd->latent_frames_to_video_frames(upscaled_latent_frames);
if (upscaled_frames != hires_request.frames) {
LOG_INFO("LTX latent upsampler output latent frames %d, frames %d -> %d",
upscaled_latent_frames,
hires_request.frames,
upscaled_frames);
hires_request.frames = upscaled_frames;
}
if (sd_version_is_ltxav(sd_ctx->sd->version) && latents.audio_length > 0) {
int target_audio_length = get_ltxav_num_audio_latents(hires_request.frames, hires_request.fps);
if (target_audio_length != latents.audio_length) {
int latent_channels = sd_ctx->sd->get_latent_channel();
sd::Tensor<float> video_latent = x_t;
sd::Tensor<float> audio_latent = latents.audio_latent;
if (x_t.shape()[3] > latent_channels) {
video_latent = sd::ops::slice(x_t, 3, 0, latent_channels);
audio_latent = unpack_ltxav_audio_latent(x_t, latents.audio_length, latent_channels);
}
audio_latent = resize_ltxav_audio_latent(audio_latent, target_audio_length);
if (audio_latent.empty()) {
LOG_ERROR("failed to resize LTX audio latent for latent upscale: %d -> %d",
latents.audio_length,
target_audio_length);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
return false;
}
x_t = pack_ltxav_audio_and_video_latents(video_latent, audio_latent);
latents.audio_latent = std::move(audio_latent);
LOG_INFO("LTX audio latent length adjusted for latent upscale: %d -> %d",
latents.audio_length,
target_audio_length);
latents.audio_length = target_audio_length;
}
}
if ((request.hires.target_width > 0 || request.hires.target_height > 0) &&
(request.hires.target_width != hires_request.width || request.hires.target_height != hires_request.height)) {
LOG_WARN("LTX latent spatial upsampler output is %dx%d; ignoring hires target %dx%d",
Expand Down
Loading