几行代码实现可视化的文生图采样过程

前言

Huggingface的StableDiffusionPipeline实现了端到端的文生图功能,用户输入文字,即可得到最后的图片。但是在某些场景下,如果需要获取中间过程的图片,用于优化长时间等待过程的用户体验,就只能通过修改源码进行实现了。

源码分析

对于SD的__call__()方法来说,前面的所有步骤都是准备工作,关键在于第七步denoising 这里。progress_bar 里面是一个循环,timesteps 用于控制采样步数,因此总共循环timesteps 步,而progress_bar 之外是图片尺寸提升以及安全特征检测等图像恢复工作。不难发现,latents 是采样过程中的潜在空空间向量,需要通过VAE进行解码才能变为现实空间的图片向量。所以为了将采样过程中间生成的图片也显示出来,只需要把最后出现的VAE也用于中间生成的潜在空间向量的解码即可。

python 复制代码
# diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

	# 7. Denoising loop
	num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
	self._num_timesteps = len(timesteps)
	with self.progress_bar(total=num_inference_steps) as progress_bar:
	    for i, t in enumerate(timesteps):
	        # expand the latents if we are doing classifier free guidance
	        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
	        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
	
	        # predict the noise residual
	        noise_pred = self.unet(
	            latent_model_input,
	            t,
	            encoder_hidden_states=prompt_embeds,
	            timestep_cond=timestep_cond,
	            cross_attention_kwargs=self.cross_attention_kwargs,
	            return_dict=False,
	        )[0]
	
	        # perform guidance
	        if self.do_classifier_free_guidance:
	            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
	            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
	
	        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
	            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
	            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
	
	        # compute the previous noisy sample x_t -> x_t-1
	        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
	        self.mid_res.append(latents)    
	        if callback_on_step_end is not None:
	            callback_kwargs = {}
	            for k in callback_on_step_end_tensor_inputs:
	                callback_kwargs[k] = locals()[k]
	            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
	
	            latents = callback_outputs.pop("latents", latents)
	            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
	            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
	
	        # call the callback, if provided
	        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
	            progress_bar.update()
	            if callback is not None and i % callback_steps == 0:
	                step_idx = i // getattr(self.scheduler, "order", 1)
	                callback(step_idx, t, latents)
	
	if not output_type == "latent":
	    # 图像从潜在空间转化为图片空间,4通道转3通道,并且图片尺寸变大:从64提升到512
	    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
	    # 图片特征安全性检测
	    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
	else:
	    image = latents
	    has_nsfw_concept = None
	
	if has_nsfw_concept is None:
	    do_denormalize = [True] * image.shape[0]
	else:
	    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
	# tensor转化为PIL.image
	image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
	
	# Offload all models
	self.maybe_free_model_hooks()
	
	if not return_dict:
	    return (image, has_nsfw_concept)
	
	return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

修改

如果每一采样的中间过程都进行解码,那势必增加整体推理时间,因此可以跳过采样初期,直到采样后期图片基本确定后,再按照一定的间隔进行解码。 i<20表示采样步骤小于20不考虑, i%2==0表示后期每两步进行一次解码,加入yield变成生成器。

python 复制代码
if i < 20 or i%2==0:
    continue
mid_img = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
yield self.image_processor.postprocess(mid_img, output_type=output_type, do_denormalize=[True] * mid_img.shape[0])

在原代码相应位置修改后结果如下:

python 复制代码
# diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
	with self.progress_bar(total=num_inference_steps) as progress_bar:
		for i, t in enumerate(timesteps):
			 # ...省略中间代码...
             # call the callback, if provided
             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                 progress_bar.update()
                 if callback is not None and i % callback_steps == 0:
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
             # 加入这一部分代码
             if i < 20 or i%2==0:
                 continue
             mid_img = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
             yield self.image_processor.postprocess(mid_img, output_type=output_type, do_denormalize=[True] * mid_img.shape[0])
	
	if not output_type == "latent":
	 		# ...省略后面代码...

效果

通过gradio构建一个可视化应用:

python 复制代码
# 导入相关的模块
from diffusers import StableDiffusionPipeline
import torch
import time
import gradio as gr
# 本地模型路径
model_id = "your_model_path"
pipe = StableDiffusionPipeline.from_pretrained(model_id)

# 定义一个事件监听函数,返回生成器对象
def stream_img(prompt):
    for mid in pipe(prompt=prompt, num_inference_steps=30):
        yield mid[0]

with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
            button = gr.Button('生成', size='sm')
        output_img = gr.Image(interactive=False, type="pil",  image_mode="RGB")
        # 输出到图片组件
        button.click(fn=stream_img, inputs=prompt, outputs=output_img, queue=True)
demo.queue(max_size=20)
demo.launch()

需要注意:button.click(queue=True)才能接受生成器对象,并且需要设置demo.queue(max_size)


可以看到图片生成的中间过程正确显示出来了。

相关推荐
云天徽上3 分钟前
【数据可视化】全国星巴克门店可视化
人工智能·机器学习·信息可视化·数据挖掘·数据分析
大嘴吧Lucy5 分钟前
大模型 | AI驱动的数据分析:利用自然语言实现数据查询到可视化呈现
人工智能·信息可视化·数据分析
艾思科蓝 AiScholar38 分钟前
【连续多届EI稳定收录&出版级别高&高录用快检索】第五届机械设计与仿真国际学术会议(MDS 2025)
人工智能·数学建模·自然语言处理·系统架构·机器人·软件工程·拓扑学
watersink1 小时前
面试题库笔记
大数据·人工智能·机器学习
Yuleave1 小时前
PaSa:基于大语言模型的综合学术论文搜索智能体
人工智能·语言模型·自然语言处理
数字化综合解决方案提供商1 小时前
【Rate Limiting Advanced插件】赋能AI资源高效分配
大数据·人工智能
一只码代码的章鱼2 小时前
机器学习2 (笔记)(朴素贝叶斯,集成学习,KNN和matlab运用)
人工智能·机器学习
周杰伦_Jay2 小时前
简洁明了:介绍大模型的基本概念(大模型和小模型、模型分类、发展历程、泛化和微调)
人工智能·算法·机器学习·生成对抗网络·分类·数据挖掘·transformer
SpikeKing2 小时前
LLM - 大模型 ScallingLaws 的指导模型设计与实验环境(PLM) 教程(4)
人工智能·llm·transformer·plm·scalinglaws
编码浪子2 小时前
Transformer的编码机制
人工智能·深度学习·transformer