Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[submodule "ggml"]
path = ggml
url = https://github.com/ggml-org/ggml.git
url = https://github.com/rmatif/ggml.git
branch = ace
142 changes: 141 additions & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <cmath>
#include <cctype>
#include <filesystem>
#include <functional>
Expand Down Expand Up @@ -149,7 +150,7 @@ struct SDCliParams {
options.manual_options = {
{"-M",
"--mode",
"run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen",
"run mode, one of [img_gen, vid_gen, audio_gen, upscale, convert], default: img_gen",
on_mode_arg},
{"",
"--preview",
Expand All @@ -170,6 +171,10 @@ struct SDCliParams {
return false;
}

if (mode == AUDIO_GEN && output_path == "output.png") {
output_path = "output.wav";
}

if (mode == CONVERT) {
if (output_path == "output.png") {
output_path = "output.gguf";
Expand Down Expand Up @@ -370,6 +375,87 @@ std::string format_frame_idx(std::string pattern, int frame_idx) {
return result;
}

static bool write_wav(const std::string& path, const sd_audio_t& audio) {
if (audio.data == nullptr || audio.sample_count == 0 || audio.channels == 0) {
return false;
}

FILE* f = fopen(path.c_str(), "wb");
if (!f) {
return false;
}

uint32_t sample_rate = audio.sample_rate;
uint16_t channels = static_cast<uint16_t>(audio.channels);
uint16_t bits_per_sample = 16;
uint32_t byte_rate = sample_rate * channels * (bits_per_sample / 8);
uint16_t block_align = channels * (bits_per_sample / 8);
uint32_t data_size = audio.sample_count * channels * (bits_per_sample / 8);
uint32_t chunk_size = 36 + data_size;

fwrite("RIFF", 1, 4, f);
fwrite(&chunk_size, 4, 1, f);
fwrite("WAVE", 1, 4, f);
fwrite("fmt ", 1, 4, f);
uint32_t subchunk1_size = 16;
uint16_t audio_format = 1;
fwrite(&subchunk1_size, 4, 1, f);
fwrite(&audio_format, 2, 1, f);
fwrite(&channels, 2, 1, f);
fwrite(&sample_rate, 4, 1, f);
fwrite(&byte_rate, 4, 1, f);
fwrite(&block_align, 2, 1, f);
fwrite(&bits_per_sample, 2, 1, f);
fwrite("data", 1, 4, f);
fwrite(&data_size, 4, 1, f);

for (uint32_t i = 0; i < audio.sample_count * audio.channels; ++i) {
float v = audio.data[i];
if (v > 1.0f) v = 1.0f;
if (v < -1.0f) v = -1.0f;
int16_t s = (int16_t)std::lrintf(v * 32767.0f);
fwrite(&s, sizeof(int16_t), 1, f);
}

fclose(f);
return true;
}

bool save_audio_result(const SDCliParams& cli_params,
const SDGenerationParams& gen_params,
const sd_audio_t& audio) {
(void)gen_params;
namespace fs = std::filesystem;
fs::path out_path = cli_params.output_path;

if (!out_path.parent_path().empty()) {
std::error_code ec;
fs::create_directories(out_path.parent_path(), ec);
if (ec) {
LOG_ERROR("failed to create directory '%s': %s",
out_path.parent_path().string().c_str(), ec.message().c_str());
return false;
}
}

fs::path base_path = out_path;
fs::path ext = out_path.has_extension() ? out_path.extension() : fs::path{};
if (!ext.empty())
base_path.replace_extension();

std::string ext_lower = ext.string();
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
if (ext_lower != ".wav") {
ext = ".wav";
}

fs::path audio_path = base_path;
audio_path += ext;
bool ok = write_wav(audio_path.string(), audio);
LOG_INFO("save result audio to '%s' (%s)", audio_path.string().c_str(), ok ? "success" : "failure");
return ok;
}

bool save_results(const SDCliParams& cli_params,
const SDContextParams& ctx_params,
const SDGenerationParams& gen_params,
Expand Down Expand Up @@ -501,6 +587,10 @@ int main(int argc, const char* argv[]) {
cli_params.preview_fps = gen_params.fps;
if (cli_params.preview_method == PREVIEW_PROJ)
cli_params.preview_fps /= 4;
if (cli_params.mode == AUDIO_GEN) {
cli_params.preview_method = PREVIEW_NONE;
cli_params.preview_noisy = false;
}

sd_set_log_callback(sd_log_cb, (void*)&cli_params);
log_verbose = cli_params.verbose;
Expand Down Expand Up @@ -540,6 +630,56 @@ int main(int argc, const char* argv[]) {
}
}

if (cli_params.mode == AUDIO_GEN) {
bool vae_decode_only = true;
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, false);

sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
if (sd_ctx == nullptr) {
LOG_INFO("new_sd_ctx_t failed");
return 1;
}

if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
}
if (gen_params.sample_params.guidance.txt_cfg == 7.0f) {
gen_params.sample_params.guidance.txt_cfg = 1.0f;
}

sd_audio_gen_params_t audio_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.lyrics.c_str(),
gen_params.keyscale.c_str(),
gen_params.language.c_str(),
gen_params.bpm,
gen_params.duration,
gen_params.timesignature,
gen_params.lm_seed,
gen_params.sample_params,
gen_params.seed,
};

sd_audio_t* audio = generate_audio(sd_ctx, &audio_params);
if (audio == nullptr) {
LOG_ERROR("audio generation failed");
free_sd_ctx(sd_ctx);
return 1;
}

bool ok = save_audio_result(cli_params, gen_params, *audio);
free(audio->data);
free(audio);
free_sd_ctx(sd_ctx);
return ok ? 0 : 1;
}

bool vae_decode_only = true;
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t end_image = {0, 0, 3, nullptr};
Expand Down
60 changes: 59 additions & 1 deletion examples/common/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ namespace fs = std::filesystem;
const char* modes_str[] = {
"img_gen",
"vid_gen",
"audio_gen",
"convert",
"upscale",
};
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale"
#define SD_ALL_MODES_STR "img_gen, vid_gen, audio_gen, convert, upscale"

enum SDMode {
IMG_GEN,
VID_GEN,
AUDIO_GEN,
CONVERT,
UPSCALE,
MODE_COUNT
Expand Down Expand Up @@ -1024,6 +1026,13 @@ struct SDGenerationParams {
std::string prompt;
std::string prompt_with_lora; // for metadata record only
std::string negative_prompt;
std::string lyrics;
std::string language = "en";
std::string keyscale = "C major";
float bpm = 120.f;
float duration = 120.f;
int timesignature = 2;
int lm_seed = 0;
int clip_skip = -1; // <= 0 represents unspecified
int width = -1;
int height = -1;
Expand Down Expand Up @@ -1090,6 +1099,18 @@ struct SDGenerationParams {
"--negative-prompt",
"the negative prompt (default: \"\")",
&negative_prompt},
{"",
"--lyrics",
"lyrics for ACE audio models",
&lyrics},
{"",
"--language",
"language for ACE audio lyrics (default: en)",
&language},
{"",
"--keyscale",
"keyscale for ACE audio (e.g. \"C major\")",
&keyscale},
{"-i",
"--init-img",
"path to the init image",
Expand Down Expand Up @@ -1131,6 +1152,14 @@ struct SDGenerationParams {
"--width",
"image width, in pixel space (default: 512)",
&width},
{"",
"--timesignature",
"time signature for ACE audio (default: 2)",
&timesignature},
{"",
"--lm-seed",
"seed for ACE audio semantic token generation (default: 0)",
&lm_seed},
{"",
"--steps",
"number of sample steps (default: 20)",
Expand Down Expand Up @@ -1176,6 +1205,14 @@ struct SDGenerationParams {
"--cfg-scale",
"unconditional guidance scale: (default: 7.0)",
&sample_params.guidance.txt_cfg},
{"",
"--bpm",
"tempo in BPM for ACE audio (default: 120)",
&bpm},
{"",
"--duration",
"duration in seconds for ACE audio (default: 120.0)",
&duration},
{"",
"--img-cfg-scale",
"image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)",
Expand Down Expand Up @@ -1573,6 +1610,13 @@ struct SDGenerationParams {

load_if_exists("prompt", prompt);
load_if_exists("negative_prompt", negative_prompt);
load_if_exists("lyrics", lyrics);
load_if_exists("language", language);
load_if_exists("keyscale", keyscale);
load_if_exists("bpm", bpm);
load_if_exists("duration", duration);
load_if_exists("timesignature", timesignature);
load_if_exists("lm_seed", lm_seed);
load_if_exists("cache_mode", cache_mode);
load_if_exists("cache_option", cache_option);
load_if_exists("cache_preset", cache_preset);
Expand Down Expand Up @@ -1744,6 +1788,13 @@ struct SDGenerationParams {
return false;
}

if (mode == AUDIO_GEN) {
if (duration <= 0.f) {
LOG_ERROR("error: audio duration must be greater than 0\n");
return false;
}
}

sd_cache_params_init(&cache_params);

auto parse_named_params = [&](const std::string& opt_str) -> bool {
Expand Down Expand Up @@ -1937,6 +1988,13 @@ struct SDGenerationParams {
<< " high_noise_loras: \"" << high_noise_loras_str << "\",\n"
<< " prompt: \"" << prompt << "\",\n"
<< " negative_prompt: \"" << negative_prompt << "\",\n"
<< " lyrics: \"" << lyrics << "\",\n"
<< " language: \"" << language << "\",\n"
<< " keyscale: \"" << keyscale << "\",\n"
<< " bpm: " << bpm << ",\n"
<< " duration: " << duration << ",\n"
<< " timesignature: " << timesignature << ",\n"
<< " lm_seed: " << lm_seed << ",\n"
<< " clip_skip: " << clip_skip << ",\n"
<< " width: " << width << ",\n"
<< " height: " << height << ",\n"
Expand Down
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated 203 files
27 changes: 27 additions & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ typedef struct {
uint8_t* data;
} sd_image_t;

typedef struct {
uint32_t sample_rate;
uint32_t channels;
uint32_t sample_count;
float* data;
} sd_audio_t;

typedef struct {
int* layers;
size_t layer_count;
Expand Down Expand Up @@ -304,6 +311,22 @@ typedef struct {
sd_cache_params_t cache;
} sd_img_gen_params_t;

typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
const char* prompt;
const char* negative_prompt;
const char* lyrics;
const char* keyscale;
const char* language;
float bpm;
float duration;
int timesignature;
int lm_seed;
sd_sample_params_t sample_params;
int64_t seed;
} sd_audio_gen_params_t;

typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
Expand Down Expand Up @@ -372,6 +395,10 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);

SD_API void sd_audio_gen_params_init(sd_audio_gen_params_t* sd_audio_gen_params);
SD_API char* sd_audio_gen_params_to_str(const sd_audio_gen_params_t* sd_audio_gen_params);
SD_API sd_audio_t* generate_audio(sd_ctx_t* sd_ctx, const sd_audio_gen_params_t* sd_audio_gen_params);

SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);

Expand Down
Loading
Loading