diff --git a/README.md b/README.md index a3e207aa4..a32f8c5d7 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ API and command-line option may change frequently.*** - [Chroma](./docs/chroma.md) - [Chroma1-Radiance](./docs/chroma_radiance.md) - [Qwen Image](./docs/qwen_image.md) + - [LongCat Image](./docs/longcat_image.md) - [Z-Image](./docs/z_image.md) - [Ovis-Image](./docs/ovis_image.md) - [Anima](./docs/anima.md) @@ -48,6 +49,7 @@ API and command-line option may change frequently.*** - Image Edit Models - [FLUX.1-Kontext-dev](./docs/kontext.md) - [Qwen Image Edit series](./docs/qwen_image_edit.md) + - [LongCat Image Edit](./docs/longcat_image.md) - Video Models - [Wan2.1/Wan2.2](./docs/wan.md) - [LTX-2.3](./docs/ltx2.md) @@ -133,6 +135,7 @@ For runtime and parameter backend placement, see the [backend selection guide](. - [Chroma](./docs/chroma.md) - [🔥Qwen Image](./docs/qwen_image.md) - [🔥Qwen Image Edit series](./docs/qwen_image_edit.md) +- [🔥LongCat Image / LongCat Image Edit](./docs/longcat_image.md) - [🔥Wan2.1/Wan2.2](./docs/wan.md) - [🔥LTX-2.3](./docs/ltx2.md) - [🔥Z-Image](./docs/z_image.md) diff --git a/assets/longcat/example.png b/assets/longcat/example.png new file mode 100644 index 000000000..3020fb8f9 Binary files /dev/null and b/assets/longcat/example.png differ diff --git a/docs/longcat_image.md b/docs/longcat_image.md new file mode 100644 index 000000000..49cae193d --- /dev/null +++ b/docs/longcat_image.md @@ -0,0 +1,30 @@ +# How to Use + +LongCat-Image uses a LongCat diffusion transformer, the FLUX VAE, and Qwen2.5-VL as the LLM text encoder. + +## Download weights + +- Download LongCat Image + - safetensors: https://huggingface.co/Comfy-Org/LongCat-Image/tree/main/split_files/diffusion_models + - gguf: https://huggingface.co/vantagewithai/LongCat-Image-GGUF/tree/main/comfy +- Download LongCat Image Edit + - LongCat Image Edit Turbo: https://huggingface.co/meituan-longcat/LongCat-Image-Edit-Turbo + - gguf: https://huggingface.co/vantagewithai/LongCat-Image-Edit-GGUF/tree/main +- Download vae + - safetensors: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors +- Download qwen_2.5_vl 7b + - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/text_encoders + - gguf: https://huggingface.co/mradermacher/Qwen2.5-VL-7B-Instruct-GGUF/tree/main + - For image editing with GGUF text encoders, also download the matching mmproj file and pass it with `--llm_vision`. + +## Run + +LongCat uses quoted text for character-level text rendering. Put target text inside single quotes, double quotes, or Chinese quotes. + +### LongCat Image + +``` +.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\LongCat-Image-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p "a lovely cat holding a sign says 'longcat.cpp'" --cfg-scale 5.0 --sampling-method euler --flow-shift 3 -v --offload-to-cpu --diffusion-fa +``` + +longcat example diff --git a/src/anima.hpp b/src/anima.hpp index 486aec3ad..dc7e9f883 100644 --- a/src/anima.hpp +++ b/src/anima.hpp @@ -598,7 +598,8 @@ namespace Anima { {}, empty_ref_latents, false, - 1.0f); + 1.0f, + false); std::vector axis_thetas = { static_cast(theta) * calc_ntk_factor(t_extrapolation_ratio, axes_dim[0]), diff --git a/src/auto_encoder_kl.hpp b/src/auto_encoder_kl.hpp index 489f8fd30..13396e737 100644 --- a/src/auto_encoder_kl.hpp +++ b/src/auto_encoder_kl.hpp @@ -680,7 +680,7 @@ struct AutoEncoderKL : public VAE { } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; shift_factor = 0.0609f; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { scale_factor = 0.3611f; shift_factor = 0.1159f; } else if (sd_version_uses_flux2_vae(version)) { diff --git a/src/conditioner.hpp b/src/conditioner.hpp index e5a702b9e..3963f3abf 100644 --- a/src/conditioner.hpp +++ b/src/conditioner.hpp @@ -1747,7 +1747,8 @@ struct LLMEmbedder : public Conditioner { std::tuple, std::vector, std::vector> tokenize(std::string text, const std::pair& attn_range, size_t min_length = 0, - size_t max_length = 100000000) { + size_t max_length = 100000000, + bool spell_quotes = false) { std::vector> parsed_attention; if (attn_range.first >= 0 && attn_range.second > 0) { if (attn_range.first > 0) { @@ -1755,6 +1756,9 @@ struct LLMEmbedder : public Conditioner { } if (attn_range.second - attn_range.first > 0) { auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); + if (spell_quotes) { + new_parsed_attention = split_quotation_attention(new_parsed_attention); + } parsed_attention.insert(parsed_attention.end(), new_parsed_attention.begin(), new_parsed_attention.end()); @@ -1804,8 +1808,10 @@ struct LLMEmbedder : public Conditioner { int hidden_states_min_length, const std::vector>>& image_embeds, const std::set& out_layers, - int prompt_template_encode_start_idx) { - auto tokens_weights_mask = tokenize(prompt, prompt_attn_range, min_length); + int prompt_template_encode_start_idx, + bool spell_quotes = false, + int max_length = 100000000) { + auto tokens_weights_mask = tokenize(prompt, prompt_attn_range, min_length, max_length, spell_quotes); auto& tokens = std::get<0>(tokens_weights_mask); auto& weights = std::get<1>(tokens_weights_mask); auto& mask = std::get<2>(tokens_weights_mask); @@ -1866,6 +1872,7 @@ struct LLMEmbedder : public Conditioner { int prompt_template_encode_start_idx = 34; int min_length = 0; // pad tokens int hidden_states_min_length = 0; // zero pad hidden_states + bool spell_quotes = false; std::set out_layers; int64_t t0 = ggml_time_ms(); @@ -1938,6 +1945,71 @@ struct LLMEmbedder : public Conditioner { prompt += "<|im_end|>\n<|im_start|>assistant\n"; } + } else if (sd_version_is_longcat(version)) { + spell_quotes = true; + + if (llm->enable_vision && conditioner_params.ref_images != nullptr && !conditioner_params.ref_images->empty()) { + LOG_INFO("LongCatEditPipeline"); + prompt_template_encode_start_idx = 67; + min_length = 512 + prompt_template_encode_start_idx; + int image_embed_idx = 36 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + for (int i = 0; i < conditioner_params.ref_images->size(); i++) { + const auto& image = (*conditioner_params.ref_images)[i]; + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = static_cast(image.shape()[1]); + int width = static_cast(image.shape()[0]); + int h_bar = static_cast(std::round(height / factor) * factor); + int w_bar = static_cast(std::round(width / factor) * factor); + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, height, width, h_bar, w_bar); + + auto resized_image = clip_preprocess(image, w_bar, h_bar); + auto image_embed = llm->encode_image(n_threads, resized_image); + GGML_ASSERT(!image_embed.empty()); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + static_cast(image_embed.shape()[1]) + 6; + + img_prompt += "<|vision_start|>"; + int64_t num_image_tokens = image_embed.shape()[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; + } + + prompt = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; + } else { + prompt_template_encode_start_idx = 36; + min_length = 512 + prompt_template_encode_start_idx; + + prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"; + } + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; } else if (version == VERSION_FLUX2) { prompt_template_encode_start_idx = 0; hidden_states_min_length = 512; @@ -2012,7 +2084,8 @@ struct LLMEmbedder : public Conditioner { hidden_states_min_length, image_embeds, out_layers, - prompt_template_encode_start_idx); + prompt_template_encode_start_idx, + spell_quotes); std::vector> extra_hidden_states_vec; for (int i = 0; i < extra_prompts.size(); i++) { auto extra_hidden_states = encode_prompt(n_threads, @@ -2022,7 +2095,8 @@ struct LLMEmbedder : public Conditioner { hidden_states_min_length, image_embeds, out_layers, - prompt_template_encode_start_idx); + prompt_template_encode_start_idx, + spell_quotes); extra_hidden_states_vec.push_back(std::move(extra_hidden_states)); } diff --git a/src/flux.hpp b/src/flux.hpp index 2aac3be0c..85da3043a 100644 --- a/src/flux.hpp +++ b/src/flux.hpp @@ -446,7 +446,6 @@ namespace Flux { if (use_yak_mlp || use_mlp_silu_act) { mlp_mult_factor = 2; } - blocks["linear1"] = std::shared_ptr(new Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim * mlp_mult_factor, mlp_proj_bias)); blocks["linear2"] = std::shared_ptr(new Linear(hidden_size + mlp_hidden_dim, hidden_size, mlp_proj_bias)); blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); @@ -1225,6 +1224,9 @@ namespace Flux { flux_params.share_modulation = true; flux_params.ref_index_scale = 10.f; flux_params.use_mlp_silu_act = true; + } else if (sd_version_is_longcat(version)) { + flux_params.context_in_dim = 3584; + flux_params.vec_in_dim = 0; } int64_t head_dim = 0; int64_t actual_radiance_patch_size = -1; @@ -1412,7 +1414,6 @@ namespace Flux { } else if (version == VERSION_OVIS_IMAGE) { txt_arange_dims = {1, 2}; } - pe_vec = Rope::gen_flux_pe(static_cast(x->ne[1]), static_cast(x->ne[0]), flux_params.patch_size, @@ -1425,7 +1426,8 @@ namespace Flux { flux_params.theta, circular_y_enabled, circular_x_enabled, - flux_params.axes_dim); + flux_params.axes_dim, + sd_version_is_longcat(version)); int pos_len = static_cast(pe_vec.size() / flux_params.axes_dim_sum / 2); // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); diff --git a/src/ggml_extend.hpp b/src/ggml_extend.hpp index 28df9f1bf..9178a31e8 100644 --- a/src/ggml_extend.hpp +++ b/src/ggml_extend.hpp @@ -953,11 +953,17 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_group_norm_32(ggml_context* ctx, return ggml_group_norm(ctx, a, 32, eps); } +__STATIC_INLINE__ bool ggml_ext_is_padded_1d(const ggml_tensor* x) { + return x->nb[0] == ggml_type_size(x->type) && + x->nb[2] == x->nb[1] * x->ne[1] && + x->nb[3] == x->nb[2] * x->ne[2]; +} + __STATIC_INLINE__ ggml_tensor* ggml_ext_scale(ggml_context* ctx, ggml_tensor* x, float factor, bool inplace = false) { - if (!ggml_is_contiguous(x)) { + if (!ggml_ext_is_padded_1d(x)) { x = ggml_cont(ctx, x); } if (inplace) { @@ -3664,7 +3670,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( ggml_tensor* hc = ggml_transpose(ctx, hc_t); ggml_tensor* out = ggml_reshape_2d(ctx, ggml_cont(ctx, hc), up * vp, batch); - return ggml_scale(ctx, out, scale); + return ggml_ext_scale(ctx, out, scale); } else { int batch = (int)h->ne[3]; // 1. Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch] @@ -3747,7 +3753,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_lokr_forward( ggml_tensor* hc = ggml_transpose(ctx, hc_t); // ungroup ggml_tensor* out = ggml_reshape_4d(ctx, ggml_cont(ctx, hc), w_out, h_out, up * vp, batch); - return ggml_scale(ctx, out, scale); + return ggml_ext_scale(ctx, out, scale); } } diff --git a/src/model.cpp b/src/model.cpp index 9929605ec..8351a2be6 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -410,7 +410,7 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s } SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight, input_block_weight; + TensorStorage token_embedding_weight, input_block_weight, context_ebedding_weight; bool has_multiple_encoders = false; bool is_unet = false; @@ -428,7 +428,8 @@ SDVersion ModelLoader::get_sd_version() { bool has_attn_1024 = false; for (auto& [name, tensor_storage] : tensor_storage_map) { - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos || + tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) { is_flux = true; } if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) { @@ -522,6 +523,9 @@ SDVersion ModelLoader::get_sd_version() { tensor_storage.name == "unet.conv_in.weight") { input_block_weight = tensor_storage; } + if (tensor_storage.name == "model.diffusion_model.txt_in.weight" || tensor_storage.name == "model.diffusion_model.context_embedder.weight") { + context_ebedding_weight = tensor_storage; + } } if (is_wan) { LOG_DEBUG("patch_embedding_channels %d", patch_embedding_channels); @@ -552,16 +556,20 @@ SDVersion ModelLoader::get_sd_version() { } if (is_flux && !is_flux2) { - if (input_block_weight.ne[0] == 384) { - return VERSION_FLUX_FILL; - } - if (input_block_weight.ne[0] == 128) { - return VERSION_FLUX_CONTROLS; - } - if (input_block_weight.ne[0] == 196) { - return VERSION_FLEX_2; + if (context_ebedding_weight.ne[0] == 3584) { + return VERSION_LONGCAT; + } else { + if (input_block_weight.ne[0] == 384) { + return VERSION_FLUX_FILL; + } + if (input_block_weight.ne[0] == 128) { + return VERSION_FLUX_CONTROLS; + } + if (input_block_weight.ne[0] == 196) { + return VERSION_FLEX_2; + } + return VERSION_FLUX; } - return VERSION_FLUX; } if (is_flux2) { diff --git a/src/model.h b/src/model.h index c0a8d810f..fadeeefb0 100644 --- a/src/model.h +++ b/src/model.h @@ -47,6 +47,7 @@ enum SDVersion { VERSION_Z_IMAGE, VERSION_OVIS_IMAGE, VERSION_ERNIE_IMAGE, + VERSION_LONGCAT, VERSION_COUNT, }; @@ -141,6 +142,13 @@ static inline bool sd_version_is_z_image(SDVersion version) { return false; } +static inline bool sd_version_is_longcat(SDVersion version) { + if (version == VERSION_LONGCAT) { + return true; + } + return false; +} + static inline bool sd_version_is_ernie_image(SDVersion version) { if (version == VERSION_ERNIE_IMAGE) { return true; @@ -176,7 +184,8 @@ static inline bool sd_version_is_dit(SDVersion version) { version == VERSION_HIDREAM_O1 || sd_version_is_anima(version) || sd_version_is_z_image(version) || - sd_version_is_ernie_image(version)) { + sd_version_is_ernie_image(version) || + sd_version_is_longcat(version)) { return true; } return false; diff --git a/src/name_conversion.cpp b/src/name_conversion.cpp index 618c7f6e9..819066d00 100644 --- a/src/name_conversion.cpp +++ b/src/name_conversion.cpp @@ -508,6 +508,12 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { static std::unordered_map flux_name_map; if (flux_name_map.empty()) { + // --- time_embed (longcat) --- + flux_name_map["time_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; + flux_name_map["time_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight"; + flux_name_map["time_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias"; + // --- time_text_embed --- flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; @@ -561,6 +567,11 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { flux_name_map[block_prefix + "attn.norm_k.weight"] = dst_prefix + "img_attn.norm.key_norm.scale"; flux_name_map[block_prefix + "attn.norm_added_q.weight"] = dst_prefix + "txt_attn.norm.query_norm.scale"; flux_name_map[block_prefix + "attn.norm_added_k.weight"] = dst_prefix + "txt_attn.norm.key_norm.scale"; + // Comfy-Org/LongCat-Image stores already-converted RMSNorm tensors as *.weight. + flux_name_map[dst_prefix + "img_attn.norm.query_norm.weight"] = dst_prefix + "img_attn.norm.query_norm.scale"; + flux_name_map[dst_prefix + "img_attn.norm.key_norm.weight"] = dst_prefix + "img_attn.norm.key_norm.scale"; + flux_name_map[dst_prefix + "txt_attn.norm.query_norm.weight"] = dst_prefix + "txt_attn.norm.query_norm.scale"; + flux_name_map[dst_prefix + "txt_attn.norm.key_norm.weight"] = dst_prefix + "txt_attn.norm.key_norm.scale"; // ff flux_name_map[block_prefix + "ff.net.0.proj.weight"] = dst_prefix + "img_mlp.0.weight"; @@ -599,8 +610,11 @@ std::string convert_diffusers_dit_to_original_flux(std::string name) { flux_name_map[block_prefix + "attn.norm_q.weight"] = dst_prefix + "norm.query_norm.scale"; flux_name_map[block_prefix + "attn.norm_k.weight"] = dst_prefix + "norm.key_norm.scale"; - flux_name_map[block_prefix + "proj_out.weight"] = dst_prefix + "linear2.weight"; - flux_name_map[block_prefix + "proj_out.bias"] = dst_prefix + "linear2.bias"; + // Comfy-Org/LongCat-Image stores already-converted RMSNorm tensors as *.weight. + flux_name_map[dst_prefix + "norm.query_norm.weight"] = dst_prefix + "norm.query_norm.scale"; + flux_name_map[dst_prefix + "norm.key_norm.weight"] = dst_prefix + "norm.key_norm.scale"; + flux_name_map[block_prefix + "proj_out.weight"] = dst_prefix + "linear2.weight"; + flux_name_map[block_prefix + "proj_out.bias"] = dst_prefix + "linear2.bias"; } // --- final layers --- @@ -668,7 +682,7 @@ std::string convert_diffusion_model_name(std::string name, std::string prefix, S name = convert_diffusers_unet_to_original_sdxl(name); } else if (sd_version_is_sd3(version)) { name = convert_diffusers_dit_to_original_sd3(name); - } else if (sd_version_is_flux(version) || sd_version_is_flux2(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_flux2(version) || sd_version_is_longcat(version)) { name = convert_diffusers_dit_to_original_flux(name); } else if (sd_version_is_z_image(version)) { name = convert_diffusers_dit_to_original_lumina2(name); diff --git a/src/rope.hpp b/src/rope.hpp index f84fac885..4c9597607 100644 --- a/src/rope.hpp +++ b/src/rope.hpp @@ -111,6 +111,16 @@ namespace Rope { return txt_ids; } + __STATIC_INLINE__ std::vector> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) { + auto txt_ids = std::vector>(bs * context_len, std::vector(axes_dim_num, 0.0f)); + for (int i = 0; i < bs * context_len; i++) { + float token_index = static_cast(i % context_len); + txt_ids[i][1] = token_index; + txt_ids[i][2] = token_index; + } + return txt_ids; + } + __STATIC_INLINE__ std::vector> gen_flux_img_ids(int h, int w, int patch_size, @@ -122,7 +132,6 @@ namespace Rope { bool scale_rope = false) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; - std::vector> img_ids(h_len * w_len, std::vector(axes_dim_num, 0.0)); int h_start = h_offset; @@ -135,7 +144,6 @@ namespace Rope { std::vector row_ids = linspace(1.f * h_start, 1.f * h_start + h_len - 1, h_len); std::vector col_ids = linspace(1.f * w_start, 1.f * w_start + w_len - 1, w_len); - for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { img_ids[i * w_len + j][0] = 1.f * index; @@ -244,14 +252,16 @@ namespace Rope { __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, int bs, int axes_dim_num, + int start_index, const std::vector& ref_latents, bool increase_ref_index, float ref_index_scale, - bool scale_rope) { + bool scale_rope, + int base_offset = 0) { std::vector> ids; int curr_h_offset = 0; int curr_w_offset = 0; - int index = 1; + int index = start_index; for (ggml_tensor* ref : ref_latents) { int h_offset = 0; int w_offset = 0; @@ -270,8 +280,8 @@ namespace Rope { bs, axes_dim_num, static_cast(index * ref_index_scale), - h_offset, - w_offset, + h_offset + base_offset, + w_offset + base_offset, scale_rope); ids = concat_ids(ids, ref_ids, bs); @@ -294,13 +304,17 @@ namespace Rope { std::set txt_arange_dims, const std::vector& ref_latents, bool increase_ref_index, - float ref_index_scale) { - auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num); + float ref_index_scale, + bool is_longcat) { + int x_index = is_longcat ? 1 : 0; + + auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num, txt_arange_dims); + int offset = is_longcat ? context_len : 0; + auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, x_index, offset, offset); auto ids = concat_ids(txt_ids, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, x_index + 1, ref_latents, increase_ref_index, ref_index_scale, false, offset); ids = concat_ids(ids, refs_ids, bs); } return ids; @@ -319,7 +333,8 @@ namespace Rope { int theta, bool circular_h, bool circular_w, - const std::vector& axes_dim) { + const std::vector& axes_dim, + bool is_longcat) { std::vector> ids = gen_flux_ids(h, w, patch_size, @@ -329,7 +344,8 @@ namespace Rope { txt_arange_dims, ref_latents, increase_ref_index, - ref_index_scale); + ref_index_scale, + is_longcat); std::vector> wrap_dims; if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { int h_len = (h + (patch_size / 2)) / patch_size; @@ -394,7 +410,7 @@ namespace Rope { auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true); auto ids = concat_ids(txt_ids_repeated, img_ids, bs); if (ref_latents.size() > 0) { - auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true); + auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, 1, ref_latents, increase_ref_index, 1.f, true); ids = concat_ids(ids, refs_ids, bs); } return ids; diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index e9be059c1..9e8e4744e 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -62,6 +62,7 @@ const char* model_version_to_str[] = { "Z-Image", "Ovis Image", "Ernie Image", + "Longcat-Image", }; const char* sampling_methods_str[] = { @@ -594,6 +595,22 @@ class StableDiffusionGGML { "model.diffusion_model", version, sd_ctx_params->qwen_image_zero_cond_t); + } else if (sd_version_is_longcat(version)) { + bool enable_vision = false; + if (!vae_decode_only) { + enable_vision = true; + } + cond_stage_model = std::make_shared(backend_for(SDBackendModule::TE), + params_backend_for(SDBackendModule::TE), + tensor_storage_map, + version, + "", + enable_vision); + diffusion_model = std::make_shared(backend_for(SDBackendModule::DIFFUSION), + params_backend_for(SDBackendModule::DIFFUSION), + tensor_storage_map, + version, + sd_ctx_params->chroma_use_dit_mask); } else if (version == VERSION_HIDREAM_O1) { cond_stage_model = std::make_shared(backend_for(SDBackendModule::TE), params_backend_for(SDBackendModule::TE), @@ -1098,7 +1115,7 @@ class StableDiffusionGGML { } else { default_flow_shift = 3.f; } - } else if (sd_version_is_flux(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_longcat(version)) { pred_type = FLUX_FLOW_PRED; default_flow_shift = 1.0f; // TODO: validate @@ -1108,6 +1125,9 @@ class StableDiffusionGGML { break; } } + if (sd_version_is_longcat(version)) { + default_flow_shift = 3.0f; + } } else if (sd_version_is_flux2(version)) { pred_type = FLUX2_FLOW_PRED; } else { @@ -1645,7 +1665,7 @@ class StableDiffusionGGML { if (sd_version_is_sd3(version)) { latent_rgb_proj = sd3_latent_rgb_proj; latent_rgb_bias = sd3_latent_rgb_bias; - } else if (sd_version_is_flux(version) || sd_version_is_z_image(version)) { + } else if (sd_version_is_flux(version) || sd_version_is_z_image(version) || sd_version_is_longcat(version)) { latent_rgb_proj = flux_latent_rgb_proj; latent_rgb_bias = flux_latent_rgb_bias; } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version) || sd_version_is_anima(version)) { diff --git a/src/util.cpp b/src/util.cpp index 1921b3b23..77fc5429c 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -842,6 +842,139 @@ std::vector> parse_prompt_attention(const std::str return res; } +static size_t get_utf8_char_len(char c) { + unsigned char uc = static_cast(c); + if ((uc & 0x80) == 0) { + return 1; + } + if ((uc & 0xE0) == 0xC0) { + return 2; + } + if ((uc & 0xF0) == 0xE0) { + return 3; + } + if ((uc & 0xF8) == 0xF0) { + return 4; + } + return 1; +} + +static bool is_ascii_alpha(char c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); +} + +static bool starts_with_at(const std::string& text, size_t pos, const std::string& needle) { + return pos + needle.size() <= text.size() && text.compare(pos, needle.size(), needle) == 0; +} + +static bool is_word_internal_apostrophe(const std::string& text, size_t pos) { + return pos > 0 && pos + 1 < text.size() && + is_ascii_alpha(text[pos - 1]) && is_ascii_alpha(text[pos + 1]); +} + +static std::vector> split_quotation(const std::string& text) { + static const std::vector> quote_pairs = { + {"'", "'"}, + {"\"", "\""}, + {"\xE2\x80\x98", "\xE2\x80\x99"}, + {"\xE2\x80\x9C", "\xE2\x80\x9D"}, + }; + + std::vector> result; + size_t segment_start = 0; + size_t i = 0; + + auto push_segment = [&](size_t begin, size_t end, bool matched) { + if (end > begin) { + result.emplace_back(text.substr(begin, end - begin), matched); + } + }; + + while (i < text.size()) { + bool matched_quote = false; + for (const auto& quote_pair : quote_pairs) { + const std::string& open_quote = quote_pair.first; + const std::string& close_quote = quote_pair.second; + if (!starts_with_at(text, i, open_quote)) { + continue; + } + if (open_quote == "'" && is_word_internal_apostrophe(text, i)) { + continue; + } + + size_t search_pos = i + open_quote.size(); + size_t close_pos = std::string::npos; + bool invalid = false; + while (search_pos < text.size()) { + if (open_quote != close_quote && starts_with_at(text, search_pos, open_quote)) { + invalid = true; + break; + } + if (starts_with_at(text, search_pos, close_quote)) { + if (close_quote == "'" && is_word_internal_apostrophe(text, search_pos)) { + search_pos += close_quote.size(); + continue; + } + close_pos = search_pos; + break; + } + + size_t char_len = get_utf8_char_len(text[search_pos]); + if (search_pos + char_len > text.size()) { + char_len = 1; + } + search_pos += char_len; + } + if (invalid || close_pos == std::string::npos) { + continue; + } + + size_t quote_start = i; + push_segment(segment_start, quote_start, false); + i = close_pos + close_quote.size(); + push_segment(quote_start, i, true); + segment_start = i; + matched_quote = true; + break; + } + if (!matched_quote) { + size_t char_len = get_utf8_char_len(text[i]); + if (i + char_len > text.size()) { + char_len = 1; + } + i += char_len; + } + } + + push_segment(segment_start, text.size(), false); + return result; +} + +std::vector> split_quotation_attention( + const std::vector>& parsed_attention) { + std::vector> result; + for (const auto& item : parsed_attention) { + const std::string& text = item.first; + float weight = item.second; + for (const auto& part : split_quotation(text)) { + if (part.second) { + size_t i = 0; + while (i < part.first.size()) { + size_t char_len = get_utf8_char_len(part.first[i]); + if (i + char_len > part.first.size()) { + char_len = 1; + } + result.emplace_back(part.first.substr(i, char_len), weight); + i += char_len; + } + } else { + result.emplace_back(part.first, weight); + } + } + } + return result; +} + // namespace is needed to avoid conflicts with ggml_backend_extend.hpp namespace ggml_cpu { #include "ggml-cpu.h" diff --git a/src/util.h b/src/util.h index 9f6099597..c3b06b1d6 100644 --- a/src/util.h +++ b/src/util.h @@ -83,6 +83,8 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* fo std::string trim(const std::string& s); std::vector> parse_prompt_attention(const std::string& text); +std::vector> split_quotation_attention( + const std::vector>& parsed_attention); sd_progress_cb_t sd_get_progress_callback(); void* sd_get_progress_callback_data();