Preface
Hugging Face's StableDiffusionPipeline implements end-to-end text-to-image generation: you type in a prompt and get the final image. In some scenarios, however, you also want the intermediate images produced during sampling, for example to improve the user experience during a long wait, and that currently requires modifying the source code.
Source code analysis
In SD's __call__() method, everything before the denoising loop is preparation; the key part is step 7, the denoising loop. Inside progress_bar there is a loop over timesteps, which controls the number of sampling steps, so the loop runs once per timestep. After progress_bar comes the image-recovery work: upscaling back to full resolution, the safety check, and so on. It is easy to see that latents is the latent-space tensor used during sampling, and it has to be decoded by the VAE before it becomes an image tensor in pixel space. So to display the images generated mid-sampling, we only need to apply the VAE decode that runs at the end to the intermediate latent tensors as well.
```python
# diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            return_dict=False,
        )[0]

        # perform guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

        # compute the previous noisy sample x_t -> x_t-1
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]
            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

        # call the callback, if provided
        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

if not output_type == "latent":
    # decode from latent space back to image space: 4 channels -> 3 channels, spatial size grows from 64 to 512
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
    # run the safety checker on the decoded image
    image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
    image = latents
    has_nsfw_concept = None

if has_nsfw_concept is None:
    do_denormalize = [True] * image.shape[0]
else:
    do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

# convert the tensor to PIL.Image
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
    return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
```
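To see what that final decode step does in isolation, here is a minimal sketch of turning a single latent tensor into a PIL image with the same VAE and image-processor calls the pipeline uses at the end. It assumes `pipe` is a loaded StableDiffusionPipeline and `latents` is one of the latent tensors from the loop above:

```python
import torch

# minimal sketch: decode one latent tensor the way the pipeline does at the end
# (assumes `pipe` is a loaded StableDiffusionPipeline and `latents` is a latent from the loop)
with torch.no_grad():
    img_tensor = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]

# de-normalize and convert the tensor into a list of PIL images
pil_images = pipe.image_processor.postprocess(img_tensor, output_type="pil", do_denormalize=[True] * img_tensor.shape[0])
pil_images[0].save("intermediate.png")
```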
Modification
Decoding at every sampling step would noticeably increase the overall inference time, so we skip the early steps and only decode at a fixed interval once the image has largely settled in the later steps: i < 20 means steps before step 20 are ignored, i % 2 == 0 means that afterwards we only decode every other step, and adding yield turns the method into a generator.
```python
if i < 20 or i % 2 == 0:
    continue
mid_img = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
yield self.image_processor.postprocess(mid_img, output_type=output_type, do_denormalize=[True] * mid_img.shape[0])
```
After applying this change at the corresponding place in the original code, the result looks like this:
```python
# diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # ... intermediate code omitted ...

        # call the callback, if provided
        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

        # added: decode and yield the intermediate image
        if i < 20 or i % 2 == 0:
            continue
        mid_img = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
        yield self.image_processor.postprocess(mid_img, output_type=output_type, do_denormalize=[True] * mid_img.shape[0])

if not output_type == "latent":
    # ... remaining code omitted ...
```
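With this change, __call__ becomes a generator: calling pipe(...) no longer returns a StableDiffusionPipelineOutput directly but yields, at each decoded step, the list of PIL images produced by postprocess. A minimal sketch of consuming it without any UI (the prompt, step count and file names are arbitrary examples, and `pipe` is assumed to be a pipeline loaded from the modified source):

```python
# iterate over the generator returned by the modified pipeline and save every decoded step
# (prompt, step count and file names are arbitrary examples)
for step, imgs in enumerate(pipe(prompt="a photo of an astronaut riding a horse", num_inference_steps=30)):
    imgs[0].save(f"step_{step:02d}.png")  # each yielded item is a list of PIL images
```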
Result
Build a small visual app with gradio:
```python
# import the required modules
from diffusers import StableDiffusionPipeline
import torch
import time
import gradio as gr

# local model path
model_id = "your_model_path"
pipe = StableDiffusionPipeline.from_pretrained(model_id)

# event-handler function that returns a generator object
def stream_img(prompt):
    for mid in pipe(prompt=prompt, num_inference_steps=30):
        yield mid[0]

with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
            button = gr.Button('Generate', size='sm')
        output_img = gr.Image(interactive=False, type="pil", image_mode="RGB")
    # stream the yielded images to the image component
    button.click(fn=stream_img, inputs=prompt, outputs=output_img, queue=True)

demo.queue(max_size=20)
demo.launch()
```
Note that button.click() must be called with queue=True in order to accept a generator, and demo.queue(max_size=...) must also be set. With that in place, the intermediate stages of image generation are displayed correctly.