第七十二篇-V100-32G+WebUI+Flux.1-Schnell+Lora+文生图

写了一个简单的Flux工具,可以加载Lora,在这里做个笔记

安装

bash 复制代码
pip install torch diffusers gradio

代码

r_webui_3.py

bash 复制代码
import torch
import time
import argparse
import os
from datetime import datetime
from diffusers import FluxPipeline
import gradio as gr
from threading import Lock

# 全局模型管理器(避免重复加载)
class ModelManager:
    def __init__(self):
        self.pipeline = None
        self.current_lora_path = None
        self.lock = Lock()
        
    def load_model(self, progress=gr.Progress()):
        """加载基础模型"""
        if self.pipeline is not None:
            return "模型已加载,无需重复加载"
            
        with self.lock:
            setup_environment()
            progress(0.3, desc="正在加载FLUX模型...")
            self.pipeline, load_time = load_flux_model_mini("fp16", "cuda")
            return f"✅ 模型加载完成,耗时 {load_time:.2f} 秒"
    
    def load_lora(self, lora_path, lora_weight):
        """加载LoRA权重"""
        if not lora_path:
            return "未提供LoRA路径,跳过加载"
            
        with self.lock:
            if self.pipeline is None:
                return "❌ 请先加载基础模型"
                
            self.pipeline, load_time = load_lora_weights(self.pipeline, lora_path, lora_weight)
            self.current_lora_path = lora_path
            return f"✅ LoRA加载完成,耗时 {load_time:.2f} 秒"
    
    def generate(self, prompt, negative_prompt, steps, guidance, height, width, seed, 
                 lora_path, lora_weight):
        """执行图像生成"""
        with self.lock:
            if self.pipeline is None:
                return None, "❌ 请先加载基础模型"
            
            # 如果需要加载新的LoRA
            if lora_path and lora_path != self.current_lora_path:
                status = self.load_lora(lora_path, lora_weight)
                print(status)
            # 如果不需要LoRA且当前有加载的LoRA,移除它
            elif not lora_path and self.current_lora_path:
                self.pipeline.unload_lora_weights()
                self.current_lora_path = None
                print("已移除当前LoRA权重")
            try:
                image, gen_time = generate_image_optimized(
                    self.pipeline,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    height=height,
                    width=width,
                    num_inference_steps=steps,
                    guidance_scale=guidance,
                    generator=torch.Generator().manual_seed(seed) if seed and seed > 0 else None,
                )
                
                # 保存图像
                os.makedirs("outputs", exist_ok=True)
                timestamp = datetime.now().strftime("%H%M%S")
                lora_tag = "lora" if lora_path else "none"
                output_path = f"outputs/flux_{lora_tag}_{timestamp}.png"
                image.save(output_path)
                
                # 获取显存信息
                vram_info = ""
                if torch.cuda.is_available():
                    vram_info = f"💾 当前VRAM: {torch.cuda.memory_allocated() / 1024**3:.2f}GB"
                
                stats = f"""✅ 生成耗时: {gen_time:.2f}秒\n{vram_info}\n💾 保存路径: {output_path}"""
                
                # 合并状态和统计信息
                full_info = f"{stats}"
                return image, full_info
                
            except Exception as e:
                return None, f"❌ 生成失败: {str(e)}"
    
    def cleanup(self):
        """清理模型"""
        with self.lock:
            if self.pipeline is not None:
                del self.pipeline
                self.pipeline = None
                self.current_lora_path = None
                torch.cuda.empty_cache()
                return "✅ 模型已卸载,显存已释放"
            return "⚠️ 没有加载的模型需要清理"

# 初始化全局模型管理器
model_manager = ModelManager()

def setup_environment():
    """设置环境变量优化V100性能"""
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.backends.cudnn.benchmark = True
    
    if torch.cuda.is_available():
        return f"📊 GPU显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB PyTorch版本: {torch.__version__} CUDA版本: {torch.version.cuda}"
    return "⚠️ 未检测到CUDA设备"

def setup_xformers(pipeline):
    """安全启用xformers(带fallback)"""
    try:
        pipeline.enable_xformers_memory_efficient_attention()
        return "✅ 启用 xformers 内存高效注意力"
    except (ImportError, AttributeError) as e:
        return f"⚠️ xformers不可用 ({e}), 回退到默认注意力实现"

def load_flux_model_mini(precision="fp16", device="cuda"):
    """超轻量级加载 - V100 OOM专用"""
    print(f"🚀 正在加载FLUX.1-schnell模型(超轻量模式)...")
    start_time = time.time()
    
    torch_dtype = torch.float16
    
    pipeline = FluxPipeline.from_pretrained(
        "/models/flux-schnell",
        torch_dtype=torch_dtype,
        use_safetensors=True,
    )
    
    pipeline.enable_sequential_cpu_offload()
    pipeline.vae.enable_slicing()
    pipeline.vae.enable_tiling()
    pipeline.enable_attention_slicing(slice_size="max")
    
    setup_xformers(pipeline)
    
    load_time = time.time() - start_time
    print(f"✅ 模型加载完成! 耗时: {load_time:.2f}秒")
    
    if device == "cuda":
        print(f"💾 当前VRAM: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
    
    return pipeline, load_time

def load_lora_weights(pipeline, lora_path, lora_weight=1.0):
    """加载LoRA权重 - 显存优化版 + 兼容新API"""
    if not lora_path or not os.path.exists(lora_path):
        raise ValueError(f"LoRA路径不存在: {lora_path}")
    
    print(f"\n🔌 正在加载LoRA权重: {os.path.basename(lora_path)}")
    torch.cuda.empty_cache()
    
    pipeline.load_lora_weights(lora_path, adapter_name="default", prefix=None)
    pipeline.set_adapters(["default"], adapter_weights=[lora_weight])
    
    return pipeline, 0

def generate_image_optimized(pipeline, prompt, **kwargs):
    """显存优化的生成函数"""
    torch.cuda.empty_cache()
    kwargs.pop("max_sequence_length", None)
    
    start_time = time.time()
    result = pipeline(prompt=prompt, **kwargs)
    inference_time = time.time() - start_time
    
    return result.images[0], inference_time

# Gradio界面构建
def create_interface():
    # 创建温柔橙色系主题
    custom_theme = gr.themes.Soft(
        # 使用柔和的橙色作为主色调
        primary_hue="orange",
        secondary_hue="orange",
        neutral_hue="gray",
        
        # 字体设置
        text_size=gr.themes.sizes.text_md,
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"]
    )
    with gr.Blocks(title="Flux.1-Schnell 图像生成器", theme=custom_theme) as demo:
        gr.Markdown("# Flux.1-Schnell 图像生成器")
        
        with gr.Row():
            with gr.Column():
                # 模型控制区域
                gr.Markdown("")
                with gr.Row():
                    load_model_btn = gr.Button("加载模型", variant="primary")
                    cleanup_btn = gr.Button("卸载模型", variant="stop")
                    generate_btn = gr.Button("🎨生成图像", variant="primary", scale=2)
                with gr.Accordion("系统状态数据", open=True):
                    model_status = gr.Textbox(label="系统信息",value="未加载", interactive=False, lines=1)
                
                # 输入参数区域
                #gr.Markdown("### 📝 生成参数")
                prompt = gr.Textbox(
                    label="提示词 (Prompt)",
                    placeholder="输入你想要生成的图像描述...",
                    lines=3,
                    value="a beautiful woman"
                )
                
                negative_prompt = gr.Textbox(
                    label="负面提示词 (Negative Prompt)",
                    placeholder="输入你不希望出现的内容...",
                    lines=2,
                    value=""
                )
                
                with gr.Row():
                    with gr.Column():
                        height = gr.Slider(
                            label="图像高度", minimum=256, maximum=2048, 
                            value=512, step=64
                        )
                        width = gr.Slider(
                            label="图像宽度", minimum=256, maximum=2048, 
                            value=512, step=64
                        )
                    
                    with gr.Column():
                        steps = gr.Slider(
                            label="推理步数", minimum=1, maximum=100, 
                            value=4, step=4,
                            info="FLUX-schnell推荐值为4"
                        )
                        guidance = gr.Number(
                            label="引导比例", value=0.0, 
                            info="FLUX-schnell必须为0.0", interactive=False, visible=False
                        )
                        seed = gr.Number(
                            label="随机种子 (-1为随机)", value=-1, 
                            precision=0
                        )
                
                # LoRA设置
                with gr.Accordion("LoRA设置", open=False):
                    lora_path = gr.Textbox(
                        label="LoRA路径 (.safetensors)",
                        placeholder="/opt/ai-runner/ai-toolkit/loras/flux_lustly-ai_v1.safetensors",
                        value=""
                    )
                    lora_weight = gr.Slider(
                        label="LoRA权重", minimum=0.0, maximum=2.0, 
                        value=1.0, step=0.1
                    )
                    with gr.Row():
                        fill_lora_btn = gr.Button("填充默认路径", size="md")
                        load_lora_btn = gr.Button("加载LoRA权重", variant="primary", size="md")
                    lora_status = gr.Textbox(label="LoRA状态", interactive=False)
            
            with gr.Column():
                # 输出区域
                #gr.Markdown("### 🖼️ 生成结果")
                output_image = gr.Image(
                    label="生成的图像",
                    type="pil",
                    format="png"
                )
                generation_info = gr.Textbox(label="生成信息", interactive=False, lines=3)
        
        # 事件处理
        load_model_btn.click(
            fn=model_manager.load_model,
            outputs=[model_status]
        )
        
        cleanup_btn.click(
            fn=model_manager.cleanup,
            outputs=[model_status]
        )
        
        load_lora_btn.click(
            fn=lambda path, weight: model_manager.load_lora(path, weight),
            inputs=[lora_path, lora_weight],
            outputs=[lora_status]
        )
        
        fill_lora_btn.click(
            fn=lambda: "/opt/ai-runner/ai-toolkit/loras/flux_lustly-ai_v1.safetensors",
            outputs=[lora_path]
        )
        
        generate_btn.click(
            fn=model_manager.generate,
            inputs=[
                prompt, negative_prompt, steps, guidance, 
                height, width, seed, lora_path, lora_weight
            ],
            outputs=[output_image, generation_info]
        )
        
        # 页面加载时检查环境
        demo.load(
            fn=setup_environment,
            outputs=[model_status]
        )
    
    return demo

if __name__ == "__main__":
    # 启动Gradio应用
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=28001,
        show_error=True,
        share=False,
        debug=True
    )

运行

bash 复制代码
python r_webui_3.py

访问

bash 复制代码
http://192.168.31.222:28001/

效果

相关推荐
weixin_438077492 分钟前
CS336 Assignment 4 (data): Filtering Language Modeling Data 翻译和实现
人工智能·python·语言模型·自然语言处理
合方圆~小文2 分钟前
工业摄像头工作原理与核心特性
数据库·人工智能·模块测试
小郭团队3 分钟前
未来PLC会消失吗?会被嵌入式系统取代吗?
c语言·人工智能·python·嵌入式硬件·架构
yesyesido3 分钟前
智能文件格式转换器:文本/Excel与CSV无缝互转的在线工具
开发语言·python·excel
Aaron15883 分钟前
全频段SDR干扰源模块设计
人工智能·嵌入式硬件·算法·fpga开发·硬件架构·信息与通信·基带工程
摆烂咸鱼~4 分钟前
机器学习(9-2)
人工智能·机器学习
_200_6 分钟前
Lua 流程控制
开发语言·junit·lua
环黄金线HHJX.6 分钟前
拼音字母量子编程PQLAiQt架构”这一概念。结合上下文《QuantumTuan ⇆ QT:Qt》
开发语言·人工智能·qt·编辑器·量子计算
王夏奇6 分钟前
python在汽车电子行业中的应用1-基础知识概念
开发语言·python·汽车
子夜江寒7 分钟前
基于PyTorch的CBOW模型实现与词向量生成
pytorch·python