First Experience Developing with stable-diffusion.cpp, an Open-Source C++ Framework for AI Image Generation Models

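stable-diffusion.cpp is a lightweight, open-source framework for AIGC image-generation models, written in C++. It can run the models locally on ordinary consumer hardware to generate images. It can also be integrated into applications as a library, providing features similar to the web-based versions of Stable Diffusion.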

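The following demo is built from the stable-diffusion.cpp source and calls its API from C++. It loads a local model file and generates an image from a text prompt. The demo uses CUDA on the GPU for acceleration; without a suitable graphics card, it can run on the CPU instead.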

Project structure

stable_diffusion_cpp_starter
	- stable-diffusion.cpp
	- src
	  |- main.cpp
	- CMakeLists.txt

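There are two prerequisites:

  • Install the CUDA Toolkit on the system
  • In the CMakeLists.txt at the root of the stable-diffusion.cpp source tree, set the SD_CUBLAS option to ON (alternatively, pass -DSD_CUBLAS=ON when configuring the project with CMake)

If you do not have a CUDA-capable GPU, the computation runs on the CPU by default, and both steps can be skipped.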

CMakeLists.txt

cmake_minimum_required(VERSION 3.15)

project(stable_diffusion_cpp_starter)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_subdirectory(stable-diffusion.cpp)

include_directories(
    ${CMAKE_CURRENT_SOURCE_DIR}/stable-diffusion.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/stable-diffusion.cpp/thirdparty
)

file(GLOB SRC
    src/*.h
    src/*.cpp
)

add_executable(${PROJECT_NAME} ${SRC})

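# Locate the platform thread library explicitly instead of relying on CMAKE_THREAD_LIBS_INIT
find_package(Threads REQUIRED)

target_link_libraries(${PROJECT_NAME}
    stable-diffusion
    Threads::Threads  # pthread on Unix-like systems
)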

main.cpp

#include <stdio.h>
#include <stdlib.h>  // malloc, free
#include <string.h>
#include <time.h>
#include <iostream>
#include <random>
#include <string>
#include <vector>

#include "stable-diffusion.h"

#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"

#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"

#define STB_IMAGE_RESIZE_IMPLEMENTATION
#define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h"

const char* rng_type_to_str[] = {
    "std_default",
    "cuda",
};

// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
const char* sample_method_str[] = {
    "euler_a",
    "euler",
    "heun",
    "dpm2",
    "dpm++2s_a",
    "dpm++2m",
    "dpm++2mv2",
    "lcm",
};

// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
const char* schedule_str[] = {
    "default",
    "discrete",
    "karras",
    "ays",
};

const char* modes_str[] = {
    "txt2img",
    "img2img",
    "img2vid",
    "convert",
};

enum SDMode 
{
    TXT2IMG,
    IMG2IMG,
    IMG2VID,
    CONVERT,
    MODE_COUNT
};

struct SDParams 
{
    int n_threads = -1;
    SDMode mode   = TXT2IMG;

    std::string model_path;
    std::string vae_path;
    std::string taesd_path;
    std::string esrgan_path;
    std::string controlnet_path;
    std::string embeddings_path;
    std::string stacked_id_embeddings_path;
    std::string input_id_images_path;
    sd_type_t wtype = SD_TYPE_COUNT;
    std::string lora_model_dir;
    std::string output_path = "output.png";
    std::string input_path;
    std::string control_image_path;

    std::string prompt;
    std::string negative_prompt;
    float min_cfg     = 1.0f;
    float cfg_scale   = 7.0f;
    float style_ratio = 20.f;
    int clip_skip     = -1;  // <= 0 represents unspecified
    int width         = 512;
    int height        = 512;
    int batch_count   = 1;

    int video_frames         = 6;
    int motion_bucket_id     = 127;
    int fps                  = 6;
    float augmentation_level = 0.f;

    sample_method_t sample_method = EULER_A;
    schedule_t schedule           = DEFAULT;
    int sample_steps              = 20;
    float strength                = 0.75f;
    float control_strength        = 0.9f;
    rng_type_t rng_type           = CUDA_RNG;
    int64_t seed                  = 42;
    bool verbose                  = false;
    bool vae_tiling               = false;
    bool control_net_cpu          = false;
    bool normalize_input          = false;
    bool clip_on_cpu              = false;
    bool vae_on_cpu               = false;
    bool canny_preprocess         = false;
    bool color                    = false;
    int upscale_repeats           = 1;
};

static std::string sd_basename(const std::string& path)
{
    // Handle both '/' and '\\' separators, including mixed paths such as "dir/sub\\model.ckpt"
    size_t pos = path.find_last_of("/\\");
    if (pos != std::string::npos) {
        return path.substr(pos + 1);
    }
    return path;
}

std::string get_image_params(const SDParams& params, int64_t seed)
{
    std::string parameter_string = params.prompt + "\n";
    if (params.negative_prompt.size() != 0) {
        parameter_string += "Negative prompt: " + params.negative_prompt + "\n";
    }
    parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
    parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
    parameter_string += "Seed: " + std::to_string(seed) + ", ";
    parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
    parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
    parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", ";
    parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]);
    if (params.schedule == KARRAS) {
        parameter_string += " karras";
    }
    parameter_string += ", ";
    parameter_string += "Version: stable-diffusion.cpp";
    return parameter_string;
}

/* Enables Printing the log level tag in color using ANSI escape codes */
void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) 
{
    SDParams* params = (SDParams*)data;
    int tag_color;
    const char* level_str;
    FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout;

    if (!log || (!params->verbose && level <= SD_LOG_DEBUG)) 
        return;

    switch (level) 
    {
        case SD_LOG_DEBUG:
            tag_color = 37;
            level_str = "DEBUG";
            break;
        case SD_LOG_INFO:
            tag_color = 34;
            level_str = "INFO";
            break;
        case SD_LOG_WARN:
            tag_color = 35;
            level_str = "WARN";
            break;
        case SD_LOG_ERROR:
            tag_color = 31;
            level_str = "ERROR";
            break;
        default: /* Potential future-proofing */
            tag_color = 33;
            level_str = "?????";
            break;
    }

    if (params->color == true) 
        fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str);
    else 
        fprintf(out_stream, "[%-5s] ", level_str);
    fputs(log, out_stream);
    fflush(out_stream);
}

int main(int argc, const char* argv[]) 
{
    // set sd params
    const std::string model_path = "./v1-5-pruned-emaonly.ckpt";
    const std::string img_output_path = "./gen_img.png";
    const std::string prompt = "a cute little dog with flowers";

    SDParams params;
    params.model_path = model_path;
    params.output_path = img_output_path;
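    params.prompt = prompt;
    params.n_threads = get_num_physical_cores();  // use the physical core count instead of -1 (unspecified)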

    sd_set_log_callback(sd_log_cb, (void*)&params);

    if (params.mode == CONVERT) 
    {
        bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
        if (!success) 
        {
            fprintf(stderr,
                    "convert '%s'/'%s' to '%s' failed\n",
                    params.model_path.c_str(),
                    params.vae_path.c_str(),
                    params.output_path.c_str());
            return 1;
        } 
        else 
        {
            printf("convert '%s'/'%s' to '%s' success\n",
                   params.model_path.c_str(),
                   params.vae_path.c_str(),
                   params.output_path.c_str());
            return 0;
        }
    }

    if (params.mode == IMG2VID) 
    {
        fprintf(stderr, "SVD support is broken, do not use it!!!\n");
        return 1;
    }

    // prepare image buffer
    bool vae_decode_only          = true;
    uint8_t* input_image_buffer   = NULL;
    uint8_t* control_image_buffer = NULL;
    if (params.mode == IMG2IMG || params.mode == IMG2VID) 
    {
        vae_decode_only = false;

        int c              = 0;
        int width          = 0;
        int height         = 0;
        input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3);
        if (input_image_buffer == NULL) {
            fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
            return 1;
        }
        if (c < 3) 
        {
            fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
            free(input_image_buffer);
            return 1;
        }
        if (width <= 0) 
        {
            fprintf(stderr, "error: the width of image must be greater than 0\n");
            free(input_image_buffer);
            return 1;
        }
        if (height <= 0) 
        {
            fprintf(stderr, "error: the height of image must be greater than 0\n");
            free(input_image_buffer);
            return 1;
        }

        // Resize input image ...
        if (params.height != height || params.width != width) 
        {
            printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);
            int resized_height = params.height;
            int resized_width  = params.width;

            uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
            if (resized_image_buffer == NULL) 
            {
                fprintf(stderr, "error: allocate memory for resize input image\n");
                free(input_image_buffer);
                return 1;
            }
            stbir_resize(input_image_buffer, width, height, 0,
                         resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
                         3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
                         STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
                         STBIR_FILTER_BOX, STBIR_FILTER_BOX,
                         STBIR_COLORSPACE_SRGB, nullptr);

            // Save resized result
            free(input_image_buffer);
            input_image_buffer = resized_image_buffer;
        }
    }

    // init sd context
    sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
                                  params.vae_path.c_str(),
                                  params.taesd_path.c_str(),
                                  params.controlnet_path.c_str(),
                                  params.lora_model_dir.c_str(),
                                  params.embeddings_path.c_str(),
                                  params.stacked_id_embeddings_path.c_str(),
                                  vae_decode_only,
                                  params.vae_tiling,
                                  true,  // free_params_immediately
                                  params.n_threads,
                                  params.wtype,
                                  params.rng_type,
                                  params.schedule,
                                  params.clip_on_cpu,
                                  params.control_net_cpu,
                                  params.vae_on_cpu);

    if (sd_ctx == NULL) 
    {
        printf("new_sd_ctx_t failed\n");
        return 1;
    }

    sd_image_t* control_image = NULL;
    if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) 
    {
        int c                = 0;
        control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
        if (control_image_buffer == NULL) 
        {
            fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
            return 1;
        }
        control_image = new sd_image_t{(uint32_t)params.width,
                                       (uint32_t)params.height,
                                       3,
                                       control_image_buffer};
        if (params.canny_preprocess) 
        {  // apply preprocessor
            control_image->data = preprocess_canny(control_image->data,
                                                   control_image->width,
                                                   control_image->height,
                                                   0.08f,
                                                   0.08f,
                                                   0.8f,
                                                   1.0f,
                                                   false);
        }
    }

    // generate image
    sd_image_t* results;
    if (params.mode == TXT2IMG) 
    {
        results = txt2img(sd_ctx,
                          params.prompt.c_str(),
                          params.negative_prompt.c_str(),
                          params.clip_skip,
                          params.cfg_scale,
                          params.width,
                          params.height,
                          params.sample_method,
                          params.sample_steps,
                          params.seed,
                          params.batch_count,
                          control_image,
                          params.control_strength,
                          params.style_ratio,
                          params.normalize_input,
                          params.input_id_images_path.c_str());
    } 
    else 
    {
        sd_image_t input_image = {(uint32_t)params.width,
                                  (uint32_t)params.height,
                                  3,
                                  input_image_buffer};

        if (params.mode == IMG2VID) {
            results = img2vid(sd_ctx,
                              input_image,
                              params.width,
                              params.height,
                              params.video_frames,
                              params.motion_bucket_id,
                              params.fps,
                              params.augmentation_level,
                              params.min_cfg,
                              params.cfg_scale,
                              params.sample_method,
                              params.sample_steps,
                              params.strength,
                              params.seed);
            if (results == NULL) 
            {
                printf("generate failed\n");
                free_sd_ctx(sd_ctx);
                return 1;
            }
            size_t last            = params.output_path.find_last_of(".");
            std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
            for (int i = 0; i < params.video_frames; i++) 
            {
                if (results[i].data == NULL) 
                    continue;

                std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
                stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
                               results[i].data, 0, get_image_params(params, params.seed + i).c_str());
                printf("save result image to '%s'\n", final_image_path.c_str());
                free(results[i].data);
                results[i].data = NULL;
            }
            free(results);
            free_sd_ctx(sd_ctx);
            return 0;
        } 
        else 
        {
            results = img2img(sd_ctx,
                              input_image,
                              params.prompt.c_str(),
                              params.negative_prompt.c_str(),
                              params.clip_skip,
                              params.cfg_scale,
                              params.width,
                              params.height,
                              params.sample_method,
                              params.sample_steps,
                              params.strength,
                              params.seed,
                              params.batch_count,
                              control_image,
                              params.control_strength,
                              params.style_ratio,
                              params.normalize_input,
                              params.input_id_images_path.c_str());
        }
    }

    if (results == NULL) 
    {
        printf("generate failed\n");
        free_sd_ctx(sd_ctx);
        return 1;
    }

    int upscale_factor = 4;  // unused for RealESRGAN_x4plus_anime_6B.pth
    if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) 
    {
        upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
                                                        params.n_threads,
                                                        params.wtype);

        if (upscaler_ctx == NULL) 
            printf("new_upscaler_ctx failed\n");
        else 
        {
            for (int i = 0; i < params.batch_count; i++) 
            {
                if (results[i].data == NULL) 
                {
                    continue;
                }
                sd_image_t current_image = results[i];
                for (int u = 0; u < params.upscale_repeats; ++u) 
                {
                    sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
                    if (upscaled_image.data == NULL) 
                    {
                        printf("upscale failed\n");
                        break;
                    }
                    free(current_image.data);
                    current_image = upscaled_image;
                }
                results[i] = current_image;  // Set the final upscaled image as the result
            }
            free_upscaler_ctx(upscaler_ctx);
        }
    }

    size_t last            = params.output_path.find_last_of(".");
    std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
    for (int i = 0; i < params.batch_count; i++) 
    {
        if (results[i].data == NULL) 
            continue;

        std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
        stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
                       results[i].data, 0, get_image_params(params, params.seed + i).c_str());
        printf("save result image to '%s'\n", final_image_path.c_str());
        free(results[i].data);
        results[i].data = NULL;
    }
    free(results);
    free_sd_ctx(sd_ctx);
    free(control_image_buffer);
    delete control_image;  // allocated with new above; deleting NULL is a no-op
    free(input_image_buffer);

    return 0;
}
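The save loops at the end of main.cpp build file names by cutting params.output_path at its last '.'. They then append ".png" for the first image and "_<i+1>.png" for later ones. Because find_last_of(".") also matches dots in directory names, an output path without an extension misbehaves. For example, in "./gen_img" the last dot is at position 0, so the stem becomes empty and the image is written to ".png".

The sketch below shows one way to guard against that. strip_extension and batch_output_paths are hypothetical helpers, not part of stable-diffusion.cpp or its API.

```cpp
#include <string>
#include <vector>

// Hypothetical helper (not part of stable-diffusion.cpp): removes the extension
// only when the last '.' lies inside the file name itself, i.e. after the last
// '/' or '\\' separator and not as the leading dot of a hidden file.
std::string strip_extension(const std::string& path)
{
    const std::string::size_type dot = path.find_last_of('.');
    const std::string::size_type sep = path.find_last_of("/\\");
    const std::string::size_type name_start = (sep == std::string::npos) ? 0 : sep + 1;
    if (dot == std::string::npos || dot <= name_start) {
        return path;  // the file name has no extension to strip
    }
    return path.substr(0, dot);
}

// Hypothetical helper reproducing the naming scheme of the save loop in main.cpp:
// image 0 -> "<stem>.png", image i > 0 -> "<stem>_<i+1>.png".
std::vector<std::string> batch_output_paths(const std::string& output_path, int batch_count)
{
    const std::string stem = strip_extension(output_path);
    std::vector<std::string> paths;
    for (int i = 0; i < batch_count; i++) {
        paths.push_back(i > 0 ? stem + "_" + std::to_string(i + 1) + ".png" : stem + ".png");
    }
    return paths;
}
```

If adopted, both dummy_name computations in main.cpp could be replaced by a call such as batch_output_paths(params.output_path, params.batch_count), indexed with i. The img2vid branch would use params.video_frames as the count instead.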

Output

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce GTX 1060 with Max-Q Design, compute capability 6.1, VMM: yes
[INFO ] stable-diffusion.cpp:169  - loading model from './v1-5-pruned-emaonly.ckpt'
[INFO ] model.cpp:736  - load ./v1-5-pruned-emaonly.ckpt using checkpoint format
[INFO ] stable-diffusion.cpp:192  - Stable Diffusion 1.x
[INFO ] stable-diffusion.cpp:198  - Stable Diffusion weight type: f32
[INFO ] stable-diffusion.cpp:419  - total params memory size = 2719.24MB (VRAM 2719.24MB, RAM 0.00MB): clip 469.44MB(VRAM), unet 2155.33MB(VRAM), vae 94.47MB(VRAM), controlnet 0.00MB(VRAM), pmid 0.00MB(VRAM)
[INFO ] stable-diffusion.cpp:423  - loading model from './v1-5-pruned-emaonly.ckpt' completed, taking 18.72s
[INFO ] stable-diffusion.cpp:440  - running in eps-prediction mode
[INFO ] stable-diffusion.cpp:556  - Attempting to apply 0 LoRAs
[INFO ] stable-diffusion.cpp:1203 - apply_loras completed, taking 0.00s
ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 0.00 MiB to 1.40 MiB
ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 0.00 MiB to 1.40 MiB
[INFO ] stable-diffusion.cpp:1316 - get_learned_condition completed, taking 514 ms
[INFO ] stable-diffusion.cpp:1334 - sampling using Euler A method
[INFO ] stable-diffusion.cpp:1338 - generating image: 1/1 - seed 42
ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 0.00 MiB to 559.90 MiB
  |==================================================| 20/20 - 1.40s/it
[INFO ] stable-diffusion.cpp:1381 - sampling completed, taking 35.05s
[INFO ] stable-diffusion.cpp:1389 - generating 1 latent images completed, taking 35.07s
[INFO ] stable-diffusion.cpp:1392 - decoding 1 latents
ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 0.00 MiB to 1664.00 MiB
[INFO ] stable-diffusion.cpp:1402 - latent 1 decoded, taking 3.03s
[INFO ] stable-diffusion.cpp:1406 - decode_first_stage completed, taking 3.03s
[INFO ] stable-diffusion.cpp:1490 - txt2img completed in 38.64s
save result image to './gen_img.png'

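Notes:

  • You need to download the model files supported by stable-diffusion.cpp yourself; downloading a .ckpt checkpoint from the official Hugging Face site is recommended
  • Prompts should be written in English
  • Both text-to-image and image-guided (image-to-image) generation are supported; there are many parameters, so it is worth experimenting with them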
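The demo hard-codes the model path, output path and prompt in main(), so trying other parameters means editing the source and recompiling. Since main already receives argc and argv, one option is a small parser that fills in these settings from the command line.

The sketch below is a hypothetical example:

  • DemoOptions mirrors a subset of the SDParams fields, using the demo's defaults.
  • The flag names are invented for illustration; they are not the options of the stable-diffusion.cpp command-line tool.

```cpp
#include <cstdint>
#include <exception>
#include <string>

// Hypothetical subset of the SDParams fields from main.cpp; the defaults
// match the values hard-coded in the demo.
struct DemoOptions {
    std::string model_path  = "./v1-5-pruned-emaonly.ckpt";
    std::string output_path = "./gen_img.png";
    std::string prompt      = "a cute little dog with flowers";
    std::string negative_prompt;
    int width               = 512;
    int height              = 512;
    int sample_steps        = 20;
    float cfg_scale         = 7.0f;
    std::int64_t seed       = 42;
    int batch_count         = 1;
};

// Parses "--flag value" pairs (illustrative flag names, not the official CLI).
// Returns false on an unknown flag, a flag without a value, or a numeric
// value that std::stoi / std::stof / std::stoll rejects.
bool parse_args(int argc, const char* argv[], DemoOptions& opt)
{
    for (int i = 1; i < argc; i++) {
        const std::string flag = argv[i];
        if (i + 1 >= argc) {
            return false;  // every flag expects a value
        }
        const std::string value = argv[++i];
        try {
            if (flag == "--model") opt.model_path = value;
            else if (flag == "--output") opt.output_path = value;
            else if (flag == "--prompt") opt.prompt = value;
            else if (flag == "--negative-prompt") opt.negative_prompt = value;
            else if (flag == "--width") opt.width = std::stoi(value);
            else if (flag == "--height") opt.height = std::stoi(value);
            else if (flag == "--steps") opt.sample_steps = std::stoi(value);
            else if (flag == "--cfg-scale") opt.cfg_scale = std::stof(value);
            else if (flag == "--seed") opt.seed = std::stoll(value);
            else if (flag == "--batch-count") opt.batch_count = std::stoi(value);
            else return false;  // unknown flag
        } catch (const std::exception&) {  // std::invalid_argument or std::out_of_range
            return false;
        }
    }
    return true;
}
```

In main.cpp, the parsed values would then be copied into the matching SDParams fields (model_path, prompt, sample_steps, cfg_scale, seed and so on) before new_sd_ctx is called.

Note that std::stoi, std::stof and std::stoll reject input without a leading number, such as "abc", but they accept a numeric prefix such as "512px". A stricter parser would also check the end position those functions report.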

Source code

stable_diffusion_cpp_starter
