diff --git a/src/denoiser.hpp b/src/denoiser.hpp index ee2ef380c..3d884bc79 100644 --- a/src/denoiser.hpp +++ b/src/denoiser.hpp @@ -496,84 +496,26 @@ struct LTX2Scheduler : SigmaScheduler { parse_extra_sample_args(extra_sample_args); } - static std::string trim(std::string value) { - const char* whitespace = " \t\r\n"; - size_t begin = value.find_first_not_of(whitespace); - if (begin == std::string::npos) { - return ""; - } - size_t end = value.find_last_not_of(whitespace); - return value.substr(begin, end - begin + 1); - } - void parse_extra_sample_args(const char* extra_sample_args) { - if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') { - return; - } - - std::string raw(extra_sample_args); - size_t start = 0; - auto parse_arg = [&](const std::string& item) { - std::string token = trim(item); - if (token.empty()) { - return; - } - size_t eq = token.find('='); - if (eq == std::string::npos) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - return; - } - - std::string key = trim(token.substr(0, eq)); - std::string value = trim(token.substr(eq + 1)); - auto parse_float = [&](float* out) -> bool { - try { - size_t consumed = 0; - float parsed = std::stof(value, &consumed); - if (!trim(value.substr(consumed)).empty()) { - return false; - } - *out = parsed; - return true; - } catch (const std::exception&) { - return false; + for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "ltx2 scheduler arg")) { + if (key == "max_shift") { + if (!parse_strict_float(value, max_shift)) { + LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str()); } - }; - try { - if (key == "max_shift") { - if (!parse_float(&max_shift)) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - } else if (key == "base_shift") { - if (!parse_float(&base_shift)) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - } else if (key == "terminal") { - if (!parse_float(&terminal)) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - } else if (key == "stretch") { - std::string v = value; - std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); - if (v == "1" || v == "true" || v == "yes" || v == "on") { - stretch = true; - } else if (v == "0" || v == "false" || v == "no" || v == "off") { - stretch = false; - } else { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - } else { - LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str()); + } else if (key == "base_shift") { + if (!parse_strict_float(value, base_shift)) { + LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str()); } - } catch (const std::exception&) { - LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); - } - }; - - for (size_t pos = 0; pos <= raw.size(); ++pos) { - if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') { - parse_arg(raw.substr(start, pos - start)); - start = pos + 1; + } else if (key == "terminal") { + if (!parse_strict_float(value, terminal)) { + LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else if (key == "stretch") { + if (!parse_strict_bool(value, stretch)) { + LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str()); + } + } else { + LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str()); } } } @@ -1276,7 +1218,7 @@ static sd::Tensor sample_dpmpp_2m_v2(denoise_cb_t model, return x; } -using SamplerExtraArgs = std::vector>; +using SamplerExtraArgs = KeyValueArgs; static sd::Tensor sample_lcm(denoise_cb_t model, sd::Tensor x, @@ -1296,15 +1238,8 @@ static sd::Tensor sample_lcm(denoise_cb_t model, for (const auto& [key, value] : extra_sample_args) { float parsed = 0.0f; - try { - size_t consumed = 0; - parsed = std::stof(value, &consumed); - if (trim(value.substr(consumed)).size() != 0) { - LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str()); - continue; - } - } catch (const std::exception&) { - LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str()); + if (!parse_strict_float(value, parsed)) { + LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str()); continue; } if (key == "noise_clip_std") { @@ -1861,15 +1796,8 @@ static sd::Tensor sample_gradient_estimation(denoise_cb_t model, for (const auto& [key, value] : extra_sample_args) { float parsed = 0.0f; - try { - size_t consumed = 0; - parsed = std::stof(value, &consumed); - if (trim(value.substr(consumed)).size() != 0) { - LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str()); - continue; - } - } catch (const std::exception&) { - LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str()); + if (!parse_strict_float(value, parsed)) { + LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str()); continue; } if (key == "gamma") { @@ -1916,46 +1844,6 @@ static sd::Tensor sample_gradient_estimation(denoise_cb_t model, return x; } -static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) { - SamplerExtraArgs pairs; - - if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') { - return pairs; - } - - auto trim = [](std::string value) -> std::string { - const char* whitespace = " \t\r\n"; - size_t begin = value.find_first_not_of(whitespace); - if (begin == std::string::npos) { - return ""; - } - size_t end = value.find_last_not_of(whitespace); - return value.substr(begin, end - begin + 1); - }; - - std::string raw(extra_sample_args); - size_t start = 0; - - for (size_t pos = 0; pos <= raw.size(); ++pos) { - if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') { - std::string item = raw.substr(start, pos - start); - std::string token = trim(item); - - if (!token.empty()) { - size_t eq = token.find('='); - if (eq != std::string::npos) { - std::string key = trim(token.substr(0, eq)); - std::string value = trim(token.substr(eq + 1)); - pairs.emplace_back(std::move(key), std::move(value)); - } - } - start = pos + 1; - } - } - - return pairs; -} - // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t static sd::Tensor sample_k_diffusion(sample_method_t method, denoise_cb_t model, @@ -1965,7 +1853,7 @@ static sd::Tensor sample_k_diffusion(sample_method_t method, float eta, bool is_flow_denoiser, const char* extra_sample_args) { - SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args); + SamplerExtraArgs extra_args = parse_key_value_args(extra_sample_args, "extra sample arg"); switch (method) { case EULER_A_SAMPLE_METHOD: return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta); diff --git a/src/ltx_vae.hpp b/src/ltx_vae.hpp index b7a462fc5..756741a43 100644 --- a/src/ltx_vae.hpp +++ b/src/ltx_vae.hpp @@ -1251,65 +1251,22 @@ struct LTXVideoVAE : public VAE { temporal_tiling_enabled = enabled; } - static std::string trim_tiling_arg(std::string value) { - const char* whitespace = " \t\r\n"; - size_t begin = value.find_first_not_of(whitespace); - if (begin == std::string::npos) { - return ""; - } - size_t end = value.find_last_not_of(whitespace); - return value.substr(begin, end - begin + 1); - } - - static bool parse_tiling_int(const std::string& value, int& parsed) { - try { - size_t consumed = 0; - parsed = std::stoi(value, &consumed); - return trim_tiling_arg(value.substr(consumed)).empty(); - } catch (...) { - return false; - } - } - void set_tiling_params(const sd_tiling_params_t& params) override { temporal_tiling_enabled = params.temporal_tiling; temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES; temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP; - const char* extra_tiling_args = params.extra_tiling_args; - if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') { - return; - } - - std::string raw(extra_tiling_args); - size_t start = 0; - for (size_t pos = 0; pos <= raw.size(); ++pos) { - if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') { - continue; - } - - std::string token = trim_tiling_arg(raw.substr(start, pos - start)); - if (!token.empty()) { - size_t eq = token.find('='); - if (eq == std::string::npos) { - LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str()); - } else { - std::string key = trim_tiling_arg(token.substr(0, eq)); - std::string value = trim_tiling_arg(token.substr(eq + 1)); - int parsed = 0; - if (!parse_tiling_int(value, parsed)) { - LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str()); - } else if (key == "temporal_tile_frames") { - temporal_tile_frames = std::max(1, parsed); - } else if (key == "temporal_tile_overlap") { - temporal_tile_overlap = std::max(0, parsed); - } else { - LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str()); - } - } + for (const auto& [key, value] : parse_key_value_args(params.extra_tiling_args, "LTX VAE extra tiling arg")) { + int parsed = 0; + if (!parse_strict_int(value, parsed)) { + LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str()); + } else if (key == "temporal_tile_frames") { + temporal_tile_frames = std::max(1, parsed); + } else if (key == "temporal_tile_overlap") { + temporal_tile_overlap = std::max(0, parsed); + } else { + LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str()); } - - start = pos + 1; } } diff --git a/src/util.cpp b/src/util.cpp index 1c2e5e899..1921b3b23 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -1,8 +1,10 @@ #include "util.h" #include +#include #include #include #include +#include #include #include #include @@ -406,6 +408,88 @@ std::vector split_string(const std::string& str, char delimiter) { return result; } +KeyValueArgs parse_key_value_args(const char* args, const char* context) { + KeyValueArgs pairs; + + if (args == nullptr || args[0] == '\0') { + return pairs; + } + + std::string raw(args); + size_t start = 0; + for (size_t pos = 0; pos <= raw.size(); ++pos) { + if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') { + continue; + } + + std::string token = trim(raw.substr(start, pos - start)); + if (!token.empty()) { + size_t eq = token.find('='); + if (eq == std::string::npos) { + const char* log_context = context ? context : "key=value arg"; + LOG_WARN("ignoring malformed %s '%s'", log_context, token.c_str()); + } else { + std::string key = trim(token.substr(0, eq)); + std::string value = trim(token.substr(eq + 1)); + pairs.emplace_back(std::move(key), std::move(value)); + } + } + + start = pos + 1; + } + + return pairs; +} + +KeyValueArgs parse_key_value_args(const std::string& args, const char* context) { + return parse_key_value_args(args.c_str(), context); +} + +bool parse_strict_float(const std::string& text, float& value) { + try { + size_t consumed = 0; + float parsed = std::stof(text, &consumed); + if (!trim(text.substr(consumed)).empty()) { + return false; + } + value = parsed; + return true; + } catch (const std::exception&) { + return false; + } +} + +bool parse_strict_int(const std::string& text, int& value) { + try { + size_t consumed = 0; + int parsed = std::stoi(text, &consumed); + if (!trim(text.substr(consumed)).empty()) { + return false; + } + value = parsed; + return true; + } catch (const std::exception&) { + return false; + } +} + +bool parse_strict_bool(const std::string& text, bool& value) { + std::string lowered = trim(text); + std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + + if (lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on") { + value = true; + return true; + } + if (lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off") { + value = false; + return true; + } + return false; +} + static std::string build_progress_bar(int step, int steps) { std::string progress = " |"; int max_progress = 50; diff --git a/src/util.h b/src/util.h index 9843ae18f..9f6099597 100644 --- a/src/util.h +++ b/src/util.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "ggml-backend.h" @@ -65,6 +66,15 @@ class MmapWrapper { std::string path_join(const std::string& p1, const std::string& p2); std::vector split_string(const std::string& str, char delimiter); + +using KeyValueArgs = std::vector>; + +KeyValueArgs parse_key_value_args(const char* args, const char* context = "key=value arg"); +KeyValueArgs parse_key_value_args(const std::string& args, const char* context = "key=value arg"); +bool parse_strict_float(const std::string& text, float& value); +bool parse_strict_int(const std::string& text, int& value); +bool parse_strict_bool(const std::string& text, bool& value); + void pretty_progress(int step, int steps, float time); void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds);