几行代码实现可视化的文生图采样过程

前言

Huggingface的StableDiffusionPipeline实现了端到端的文生图功能,用户输入文字,即可得到最后的图片。但是在某些场景下,如果需要获取中间过程的图片,用于优化长时间等待过程的用户体验,就只能通过修改源码进行实现了。

源码分析

对于SD的__call__()方法来说,前面的所有步骤都是准备工作,关键在于第七步denoising 这里。progress_bar 里面是一个循环,timesteps 用于控制采样步数,因此总共循环timesteps 步,而progress_bar 之外是图片尺寸提升以及安全特征检测等图像恢复工作。不难发现,latents 是采样过程中的潜在空空间向量,需要通过VAE进行解码才能变为现实空间的图片向量。所以为了将采样过程中间生成的图片也显示出来,只需要把最后出现的VAE也用于中间生成的潜在空间向量的解码即可。

python 复制代码
# diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

	# 7. Denoising loop
	num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
	self._num_timesteps = len(timesteps)
	with self.progress_bar(total=num_inference_steps) as progress_bar:
	    for i, t in enumerate(timesteps):
	        # expand the latents if we are doing classifier free guidance
	        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
	        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
	
	        # predict the noise residual
	        noise_pred = self.unet(
	            latent_model_input,
	            t,
	            encoder_hidden_states=prompt_embeds,
	            timestep_cond=timestep_cond,
	            cross_attention_kwargs=self.cross_attention_kwargs,
	            return_dict=False,
	        )[0]
	
	        # perform guidance
	        if self.do_classifier_free_guidance:
	            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
	            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
	
	        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
	            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
	            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
	
	        # compute the previous noisy sample x_t -> x_t-1
	        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
	        self.mid_res.append(latents)    
	        if callback_on_step_end is not None:
	            callback_kwargs = {}
	            for k in callback_on_step_end_tensor_inputs:
	                callback_kwargs[k] = locals()[k]
	            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
	
	            latents = callback_outputs.pop("latents", latents)
	            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
	            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
	
	        # call the callback, if provided
	        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
	            progress_bar.update()
	            if callback is not None and i % callback_steps == 0:
	                step_idx = i // getattr(self.scheduler, "order", 1)
	                callback(step_idx, t, latents)
	
	if not output_type == "latent":
	    # 图像从潜在空间转化为图片空间,4通道转3通道,并且图片尺寸变大:从64提升到512
	    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
	    # 图片特征安全性检测
	    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
	else:
	    image = latents
	    has_nsfw_concept = None
	
	if has_nsfw_concept is None:
	    do_denormalize = [True] * image.shape[0]
	else:
	    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
	# tensor转化为PIL.image
	image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
	
	# Offload all models
	self.maybe_free_model_hooks()
	
	if not return_dict:
	    return (image, has_nsfw_concept)
	
	return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

修改

如果每一采样的中间过程都进行解码,那势必增加整体推理时间,因此可以跳过采样初期,直到采样后期图片基本确定后,再按照一定的间隔进行解码。 i<20表示采样步骤小于20不考虑, i%2==0表示后期每两步进行一次解码,加入yield变成生成器。

python 复制代码
if i < 20 or i%2==0:
    continue
mid_img = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
yield self.image_processor.postprocess(mid_img, output_type=output_type, do_denormalize=[True] * mid_img.shape[0])

在原代码相应位置修改后结果如下:

python 复制代码
# diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
	with self.progress_bar(total=num_inference_steps) as progress_bar:
		for i, t in enumerate(timesteps):
			 # ...省略中间代码...
             # call the callback, if provided
             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                 progress_bar.update()
                 if callback is not None and i % callback_steps == 0:
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
             # 加入这一部分代码
             if i < 20 or i%2==0:
                 continue
             mid_img = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
             yield self.image_processor.postprocess(mid_img, output_type=output_type, do_denormalize=[True] * mid_img.shape[0])
	
	if not output_type == "latent":
	 		# ...省略后面代码...

效果

通过gradio构建一个可视化应用:

python 复制代码
# 导入相关的模块
from diffusers import StableDiffusionPipeline
import torch
import time
import gradio as gr
# 本地模型路径
model_id = "your_model_path"
pipe = StableDiffusionPipeline.from_pretrained(model_id)

# 定义一个事件监听函数,返回生成器对象
def stream_img(prompt):
    for mid in pipe(prompt=prompt, num_inference_steps=30):
        yield mid[0]

with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
            button = gr.Button('生成', size='sm')
        output_img = gr.Image(interactive=False, type="pil",  image_mode="RGB")
        # 输出到图片组件
        button.click(fn=stream_img, inputs=prompt, outputs=output_img, queue=True)
demo.queue(max_size=20)
demo.launch()

需要注意:button.click(queue=True)才能接受生成器对象,并且需要设置demo.queue(max_size)


可以看到图片生成的中间过程正确显示出来了。

相关推荐
小于小于大橙子3 小时前
视觉SLAM数学基础
人工智能·数码相机·自动化·自动驾驶·几何学
封步宇AIGC4 小时前
量化交易系统开发-实时行情自动化交易-3.4.2.Okex行情交易数据
人工智能·python·机器学习·数据挖掘
封步宇AIGC5 小时前
量化交易系统开发-实时行情自动化交易-2.技术栈
人工智能·python·机器学习·数据挖掘
陌上阳光5 小时前
动手学深度学习68 Transformer
人工智能·深度学习·transformer
OpenI启智社区5 小时前
共筑开源技术新篇章 | 2024 CCF中国开源大会盛大开幕
人工智能·开源·ccf中国开源大会·大湾区
AI服务老曹5 小时前
建立更及时、更有效的安全生产优化提升策略的智慧油站开源了
大数据·人工智能·物联网·开源·音视频
YRr YRr5 小时前
PyTorch:torchvision中的dataset的使用
人工智能
love_and_hope5 小时前
Pytorch学习--神经网络--完整的模型训练套路
人工智能·pytorch·python·深度学习·神经网络·学习
思通数据6 小时前
AI与OCR:数字档案馆图像扫描与文字识别技术实现与项目案例
大数据·人工智能·目标检测·计算机视觉·自然语言处理·数据挖掘·ocr
兔老大的胡萝卜6 小时前
关于 3D Engine Design for Virtual Globes(三维数字地球引擎设计)
人工智能·3d