diff --git a/examples/server/main.cpp b/examples/server/main.cpp index e13346b9e..b4145e1d3 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -368,6 +369,18 @@ int main(int argc, const char** argv) { return httplib::Server::HandlerResponse::Unhandled; }); + auto wait_for_generation = [](std::future& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) { + std::future_status ft_status; + do { + if (!ft.valid()) + break; + ft_status = ft.wait_for(std::chrono::milliseconds(1000)); + if (req.is_connection_closed()) { + sd_cancel_generation(sd_ctx, SD_CANCEL_ALL); + } + } while (ft_status != std::future_status::ready); + }; + // root svr.Get("/", [&](const httplib::Request&, httplib::Response& res) { if (!svr_params.serve_html_path.empty()) { @@ -510,11 +523,13 @@ int main(int argc, const char** argv) { sd_image_t* results = nullptr; int num_results = 0; - { + std::future ft = std::async(std::launch::async, [&]() { std::lock_guard lock(sd_ctx_mutex); results = generate_image(sd_ctx, &img_gen_params); num_results = gen_params.batch_count; - } + }); + + wait_for_generation(ft, sd_ctx, req); for (int i = 0; i < num_results; i++) { if (results[i].data == nullptr) { @@ -756,11 +771,13 @@ int main(int argc, const char** argv) { sd_image_t* results = nullptr; int num_results = 0; - { + std::future ft = std::async(std::launch::async, [&]() { std::lock_guard lock(sd_ctx_mutex); results = generate_image(sd_ctx, &img_gen_params); num_results = gen_params.batch_count; - } + }); + + wait_for_generation(ft, sd_ctx, req); json out; out["created"] = static_cast(std::time(nullptr)); @@ -1071,11 +1088,13 @@ int main(int argc, const char** argv) { sd_image_t* results = nullptr; int num_results = 0; - { + std::future ft = std::async(std::launch::async, [&]() { std::lock_guard lock(sd_ctx_mutex); results = generate_image(sd_ctx, &img_gen_params); num_results = gen_params.batch_count; - } + }); + + wait_for_generation(ft, sd_ctx, req); json out; out["images"] = json::array(); diff --git a/include/stable-diffusion.h b/include/stable-diffusion.h index 51b2b3291..cfcc40926 100644 --- a/include/stable-diffusion.h +++ b/include/stable-diffusion.h @@ -372,6 +372,15 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); +enum sd_cancel_mode_t +{ + SD_CANCEL_ALL, + SD_CANCEL_NEW_LATENTS, + SD_CANCEL_RESET +}; + +SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode); + SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out); diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index 717fec18e..d44400d56 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -23,6 +23,8 @@ #include "latent-preview.h" #include "name_conversion.h" +#include + const char* model_version_to_str[] = { "SD 1.x", "SD 1.x Inpaint", @@ -99,6 +101,9 @@ void suppress_pp(int step, int steps, float time, void* data) { /*=============================================== StableDiffusionGGML ================================================*/ +static_assert(std::atomic::is_always_lock_free, + "sd_cancel_mode_t must be lock-free"); + class StableDiffusionGGML { public: ggml_backend_t backend = nullptr; // general backend @@ -149,6 +154,8 @@ class StableDiffusionGGML { std::shared_ptr denoiser = std::make_shared(); + std::atomic cancellation_flag; + StableDiffusionGGML() = default; ~StableDiffusionGGML() { @@ -164,6 +171,18 @@ class StableDiffusionGGML { ggml_backend_free(backend); } + void set_cancel_flag(enum sd_cancel_mode_t flag) { + cancellation_flag.store(flag, std::memory_order_release); + } + + void reset_cancel_flag() { + set_cancel_flag(SD_CANCEL_RESET); + } + + enum sd_cancel_mode_t get_cancel_flag() { + return cancellation_flag.load(std::memory_order_acquire); + } + void init_backend() { #ifdef SD_USE_CUDA LOG_DEBUG("Using CUDA backend"); @@ -1869,6 +1888,12 @@ class StableDiffusionGGML { } auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* { + enum sd_cancel_mode_t cancel_flag = get_cancel_flag(); + if (cancel_flag != SD_CANCEL_RESET) { + LOG_DEBUG("cancelling latent decodings"); + return nullptr; + } + auto sd_preview_cb = sd_get_preview_callback(); auto sd_preview_cb_data = sd_get_preview_callback_data(); auto sd_preview_mode = sd_get_preview_mode(); @@ -3423,6 +3448,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat); } for (int b = 0; b < batch_count; b++) { + + if (sd_ctx->sd->get_cancel_flag() != SD_CANCEL_RESET) { + LOG_ERROR("cancelling generation"); + break; + } + int64_t sampling_start = ggml_time_ms(); int64_t cur_seed = seed + b; LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed); @@ -3484,6 +3515,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, LOG_INFO("decoding %zu latents", final_latents.size()); std::vector decoded_images; // collect decoded images for (size_t i = 0; i < final_latents.size(); i++) { + + if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) { + LOG_ERROR("cancelling latent decodings"); + break; + } + t1 = ggml_time_ms(); struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */); // print_ggml_tensor(img); @@ -3520,6 +3557,16 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, return result_images; } +void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) +{ + if (sd_ctx && sd_ctx->sd) { + if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) { + mode = SD_CANCEL_ALL; + } + sd_ctx->sd->set_cancel_flag(mode); + } +} + sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params; int width = sd_img_gen_params->width; @@ -3542,6 +3589,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g return nullptr; } + sd_ctx->sd->reset_cancel_flag(); + struct ggml_init_params params; params.mem_size = static_cast(1024 * 1024) * 1024; // 1G params.mem_buffer = nullptr;