AIGC之图片生成——基于检索的图生成

项目代码:github.com/liangwq/Cha... AIGC之图片生成------基于clip内容检索

背景:

前面已经介绍了基于内容的图检索,今天我们来介绍基于检索的图生成。基于检索的图生成重点在于多模态的检索,生成图至少有两种应用: 1.大模型生成文案,抽取关键词,clip检索出合适的配图 2.基于文案检索出图,以检索图为基础继续加工 3.上面两部分的组合,做文案配图 在这部分不会详细介绍改图,只是给出了一个SD turbo的image to image的简单例子。理论上讲这部分的图加工可以很复杂,比如:分图层生成图、按物体修改、按色系修改、基于草图增删、组合合图...... 但这部分的目的在于介绍基于检索的方式来做图的生成,介绍基于图检索方式对图质量提升的好处。所以后面的一些修改组合技术可以另开章节介绍。

正文:

基于内容的检索系统在文本配图上应用

搭建qwen文本生成模型,代码如下:

python 复制代码
# Initialize model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat-int4", trust_remote_code=True,cache_dir="./")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat-int4", device_map="auto", trust_remote_code=True,cache_dir="./").eval()
model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat-int4", trust_remote_code=True,cache_dir="./")
def predict(history, max_length, top_p, temperature):
    stop = StopOnTokens()
    messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    print("\n\n====conversation====\n", messages)
    model_inputs = tokenizer.apply_chat_template(messages,
                                                 add_generation_prompt=True,
                                                 tokenize=True,
                                                 return_tensors="pt").to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1.2,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for new_token in streamer:
        if new_token != '':
            history[-1][1] += new_token
            yield history
with gr.Blocks() as demo:
    with gr.Tab("文本创作页面"):
    #gr.HTML("""<h1 align="center">ChatGLM3-6B Gradio Simple Demo</h1>""")
        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10, container=False)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("Submit")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("Clear History")
                max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
                temperature = gr.Slider(0.01, 1, value=0.6, step=0.01, label="Temperature", interactive=True)


        def user(query, history):
            return "", history + [[parse_text(query), ""]]


        submitBtn.click(user, [user_input, chatbot], [user_input, chatbot], queue=False).then(
            predict, [chatbot, max_length, top_p, temperature], chatbot
        )
        emptyBtn.click(lambda: None, None, chatbot, queue=False)

类似openai API接口请求的前端代码如下,openai API服务端代码可以直接看我项目库代码。

python 复制代码
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)
    print(st.session_state.messages)
    
        
    messages =[]
    messages.append({"role": "user", "content": prompt})
    history_mssg.append({"role": "user", "content":str(st.session_state.messages)+ prompt})
    #print(history_mssg)
    response = openai.ChatCompletion.create(model="Qwen", messages=history_mssg,#st.session_state.messages,
    stream=False,
    stop=[])
    msg = response.choices[0].message.content
    assistant_mssg = {"role": "assistant", "content": msg}
    st.session_state.messages.append({"role": "assistant", "content": msg})
    history_mssg.append(assistant_mssg)
    st.chat_message("assistant").write(msg)

在LLM的交互界面让模型生成你要的文本,生成完在你需要配图的地方, 1.输入:"抽取上面文字的关键词"得到LLM的回答就是你的检索图的query词(这部分后面可以做到按钮方式点击直接传到query地方检索); 2.然后输入"把上面的关键词翻译成英文,英文输出到一行",这部分就可以作为你后面image2image的prompt

基于内容的检索系统在图生成上的应用

这部分会用到stable turbo用来对检索出来的图,通过prompt方式来修改。 1.通过对生成文本抽取关键词,检索到图 2.从检索中图中选择一张复制到参考图位置 3.把抽取的关键词翻译成英文promt,用stable turbo改图 stable turbo具体代码如下:

python 复制代码
if SAFETY_CHECKER == "True":
    i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
        "stabilityai/sdxl-turbo",
        cache_dir = "./",
        torch_dtype=torch_dtype,
        variant="fp16" if torch_dtype == torch.float16 else "fp32",
    )
    t2i_pipe = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo",
        cache_dir = "./",
        torch_dtype=torch_dtype,
        variant="fp16" if torch_dtype == torch.float16 else "fp32",
    )
else:
    i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
        "stabilityai/sdxl-turbo",
        safety_checker=None,
        cache_dir = "./",
        torch_dtype=torch_dtype,
        variant="fp16" if torch_dtype == torch.float16 else "fp32",
    )
    t2i_pipe = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo",
        safety_checker=None,
        cache_dir = "./",
        torch_dtype=torch_dtype,
        variant="fp16" if torch_dtype == torch.float16 else "fp32",
    )


t2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
t2i_pipe.set_progress_bar_config(disable=True)
i2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
i2i_pipe.set_progress_bar_config(disable=True)


def resize_crop(image, size=512):
    image = image.convert("RGB")
    w, h = image.size
    image = image.resize((size, int(size * (h / w))), Image.BICUBIC)
    return image


async def predict_image(init_image, prompt, strength, steps, seed=1231231):
    if init_image is not None:
        init_image = resize_crop(init_image)
        generator = torch.manual_seed(seed)
        last_time = time.time()
    
        if int(steps * strength) < 1:
            steps = math.ceil(1 / max(0.10, strength))
            
        results = i2i_pipe(
            prompt=prompt,
            image=init_image,
            generator=generator,
            num_inference_steps=steps,
            guidance_scale=0.0,
            strength=strength,
            width=512,
            height=512,
            output_type="pil",
        )
    else:
        generator = torch.manual_seed(seed)
        last_time = time.time()
        results = t2i_pipe(
            prompt=prompt,
            generator=generator,
            num_inference_steps=steps,
            guidance_scale=0.0,
            width=512,
            height=512,
            output_type="pil",
        )
    print(f"Pipe took {time.time() - last_time} seconds")
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        gr.Warning("NSFW content detected.")
        return Image.new("RGB", (512, 512))
    return results.images[0]

界面可视化部分代码如下:

python 复制代码
    with gr.Tab("配图页面"):
        init_image_state = gr.State()
        with gr.Row():
            with gr.Column(scale=1):
                gallery = gr.Gallery(
                        label="Generated images", show_label=False,elem_id="gallery",show_share_button=True,columns=[1], rows=[5], object_fit="contain", height="auto")
                slider = gr.Slider(0, 10, step=1)
                input_image = gr.Image( type="pil")
                text_prompt = gr.Textbox(label="Search Word")

                with gr.Row():
                    text_button = gr.Button(value="Text Search")
                    image_button = gr.Button(value="Image Search")
        
            with gr.Column(elem_id="container",scale=4):
                with gr.Row():
                    prompt = gr.Textbox(
                        placeholder="Insert your prompt here:",
                        #scale=5,
                        lines=10,
                        container=False,
                    )
                    generate_bt = gr.Button("Generate")#, scale=1)
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(
                            sources=["upload", "webcam", "clipboard"],
                            label="Webcam",
                            type="pil",
                        )
                    with gr.Column():
                        image = gr.Image(type="filepath")
                        with gr.Accordion("Advanced options", open=False):
                            strength = gr.Slider(
                                label="Strength",
                                value=0.7,
                                minimum=0.0,
                                maximum=1.0,
                                step=0.001,
                            )
                            steps = gr.Slider(
                                label="Steps", value=2, minimum=1, maximum=10, step=1
                            )
                            seed = gr.Slider(
                                randomize=True,
                                minimum=0,
                                maximum=12013012031030,
                                label="Seed",
                                step=1,
                            )
            image_button.click(image_search_image, inputs = [input_image,slider], outputs =[gallery])      
            text_button.click(text_search_image, inputs = [text_prompt,slider], outputs =[gallery])
            
            inputs = [image_input, prompt, strength, steps, seed]
            generate_bt.click(fn=predict_image, inputs=inputs, outputs=image, show_progress=False)
            prompt.change(fn=predict_image, inputs=inputs, outputs=image, show_progress=False)
            steps.change(fn=predict_image, inputs=inputs, outputs=image, show_progress=False)
            seed.change(fn=predict_image, inputs=inputs, outputs=image, show_progress=False)
            strength.change(fn=predict_image, inputs=inputs, outputs=image, show_progress=False)

项目代码:github.com/liangwq/Cha...

小结:

这篇文章介绍了如何基于clip检索到的图给文章配图,进一步介绍了如何基于检索到的图做图的生成修改。文章虽然只是简单的介绍了基于image2image用关键词prompt方式来改图的方法,但这个只是给大家一个思路。实际上还有很多的基于检索到的图改图方法大家可以基于自己需要去尝试。这篇文章的目的在于强调合启发大家基于检索生成图的思考。到此你就拥有一个基于检索的配图、改图的简单工具。其实大家也看到了图库的重要性、以及检索准确性的重要性。如果图库质量好、检索质量好后面的创作任务事半功倍。所以真正的功夫还在数据,这里面可以搞的东西很多,后面会再用几篇文章简单介绍。

相关推荐
远望清一色几秒前
基于MATLAB边缘检测博文
开发语言·算法·matlab
千天夜2 分钟前
激活函数解析:神经网络背后的“驱动力”
人工智能·深度学习·神经网络
tyler_download2 分钟前
手撸 chatgpt 大模型:简述 LLM 的架构,算法和训练流程
算法·chatgpt
大数据面试宝典3 分钟前
用AI来写SQL:让ChatGPT成为你的数据库助手
数据库·人工智能·chatgpt
封步宇AIGC8 分钟前
量化交易系统开发-实时行情自动化交易-3.4.1.2.A股交易数据
人工智能·python·机器学习·数据挖掘
m0_5236742110 分钟前
技术前沿:从强化学习到Prompt Engineering,业务流程管理的创新之路
人工智能·深度学习·目标检测·机器学习·语言模型·自然语言处理·数据挖掘
HappyAcmen20 分钟前
IDEA部署AI代写插件
java·人工智能·intellij-idea
SoraLuna22 分钟前
「Mac玩转仓颉内测版7」入门篇7 - Cangjie控制结构(下)
算法·macos·动态规划·cangjie
我狠狠地刷刷刷刷刷26 分钟前
中文分词模拟器
开发语言·python·算法
鸽鸽程序猿26 分钟前
【算法】【优选算法】前缀和(上)
java·算法·前缀和