写了一个简单的 FLUX.1-schnell 图像生成 WebUI(基于 diffusers + Gradio),支持加载 LoRA,在这里做个笔记。
安装
bash
pip install torch diffusers transformers accelerate sentencepiece protobuf peft gradio
# 可选(未安装时代码会回退到默认注意力实现): pip install xformers
代码(模型默认从本地目录 /models/flux-schnell 加载,请先下载 FLUX.1-schnell 权重到该目录,或修改 load_flux_model_mini 中的路径)
r_webui_3.py
python
import torch
import time
import argparse
import os
from datetime import datetime
from diffusers import FluxPipeline
import gradio as gr
from threading import Lock
# Global model manager (keeps one pipeline alive so the model is not reloaded per request)
class ModelManager:
    """Own the single shared FLUX pipeline and the LoRA currently applied to it.

    Every public method acquires ``self.lock`` so Gradio callbacks never load,
    modify, run or free the pipeline at the same time. ``threading.Lock`` is
    not re-entrant, so code that already holds the lock must call the
    ``_*_locked`` helpers instead of the public methods (previously
    ``generate`` called ``load_lora`` while holding the lock and deadlocked).
    """

    # Adapter name registered by load_lora_weights(); needed again for set_adapters().
    _ADAPTER_NAME = "default"

    def __init__(self):
        self.pipeline = None
        self.current_lora_path = None
        # Weight currently applied to the adapter; None when no LoRA is active.
        self.current_lora_weight = None
        self.lock = Lock()

    def load_model(self, progress=gr.Progress()):
        """Load the base FLUX model once and return a status message for the UI."""
        with self.lock:
            # Checked under the lock so two quick clicks cannot both start a load.
            if self.pipeline is not None:
                return "模型已加载,无需重复加载"
            setup_environment()
            progress(0.3, desc="正在加载FLUX模型...")
            try:
                self.pipeline, load_time = load_flux_model_mini("fp16", "cuda")
            except Exception as e:  # UI boundary: report the failure instead of crashing the callback
                return f"❌ 模型加载失败: {e}"
        return f"✅ 模型加载完成,耗时 {load_time:.2f} 秒"

    def _unload_lora_locked(self):
        """Remove the active LoRA, if any. Caller must hold ``self.lock``."""
        if self.current_lora_path is not None:
            self.pipeline.unload_lora_weights()
        self.current_lora_path = None
        self.current_lora_weight = None

    def _load_lora_locked(self, lora_path, lora_weight):
        """Apply the LoRA at ``lora_path``. Caller must hold ``self.lock``.

        Any previously loaded LoRA is unloaded first: load_lora_weights()
        always registers the adapter under the same name, and diffusers
        rejects an adapter name that is already in use.

        Returns:
            A status message for the UI.

        Raises:
            Whatever load_lora_weights() raises (e.g. ValueError for a missing path).
        """
        self._unload_lora_locked()
        self.pipeline, load_time = load_lora_weights(self.pipeline, lora_path, lora_weight)
        self.current_lora_path = lora_path
        self.current_lora_weight = lora_weight
        return f"✅ LoRA加载完成,耗时 {load_time:.2f} 秒"

    def _sync_lora_locked(self, lora_path, lora_weight):
        """Bring the pipeline's LoRA state in line with the request. Caller must hold ``self.lock``."""
        if lora_path:
            if lora_path != self.current_lora_path:
                print(self._load_lora_locked(lora_path, lora_weight))
            elif lora_weight != self.current_lora_weight:
                # Same file, new strength: just rescale the existing adapter.
                self.pipeline.set_adapters([self._ADAPTER_NAME], adapter_weights=[lora_weight])
                self.current_lora_weight = lora_weight
        elif self.current_lora_path:
            # No LoRA requested but one is loaded: remove it.
            self._unload_lora_locked()
            print("已移除当前LoRA权重")

    def load_lora(self, lora_path, lora_weight):
        """Load a LoRA (UI button handler) and return a status message."""
        if not lora_path:
            return "未提供LoRA路径,跳过加载"
        with self.lock:
            if self.pipeline is None:
                return "❌ 请先加载基础模型"
            try:
                return self._load_lora_locked(lora_path, lora_weight)
            except Exception as e:  # UI boundary: show the reason in the status box
                return f"❌ LoRA加载失败: {e}"

    def generate(self, prompt, negative_prompt, steps, guidance, height, width, seed,
                 lora_path, lora_weight):
        """Generate one image with the requested LoRA state and save it under ``outputs/``.

        ``seed`` < 0 (or None) means random. Slider values may arrive as
        floats, so sizes and step counts are converted to int.

        Returns:
            ``(image, info)``, where ``image`` is None and ``info`` is an
            error message when anything fails.
        """
        with self.lock:
            if self.pipeline is None:
                return None, "❌ 请先加载基础模型"
            try:
                self._sync_lora_locked(lora_path, lora_weight)

                generator = None
                if seed is not None and seed >= 0:
                    generator = torch.Generator().manual_seed(int(seed))

                image, gen_time = generate_image_optimized(
                    self.pipeline,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    height=int(height),
                    width=int(width),
                    num_inference_steps=int(steps),
                    guidance_scale=float(guidance),
                    generator=generator,
                )

                # Save the image. Date + microseconds avoids the old "%H%M%S" names,
                # which repeated daily and within a second and overwrote earlier images.
                os.makedirs("outputs", exist_ok=True)
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
                lora_tag = "lora" if lora_path else "none"
                output_path = os.path.join("outputs", f"flux_{lora_tag}_{timestamp}.png")
                image.save(output_path)

                # Report current VRAM usage
                vram_info = ""
                if torch.cuda.is_available():
                    vram_info = f"💾 当前VRAM: {torch.cuda.memory_allocated() / 1024**3:.2f}GB"
                stats = f"✅ 生成耗时: {gen_time:.2f}秒\n{vram_info}\n💾 保存路径: {output_path}"
                return image, stats
            except Exception as e:  # UI boundary: surface the error in the info box
                return None, f"❌ 生成失败: {str(e)}"

    def cleanup(self):
        """Drop the pipeline and release cached GPU memory; return a status message."""
        import gc

        with self.lock:
            if self.pipeline is None:
                return "⚠️ 没有加载的模型需要清理"
            self.pipeline = None
            self.current_lora_path = None
            self.current_lora_weight = None
            # Collect any reference cycles still holding the pipeline's tensors
            # before handing cached CUDA blocks back to the driver.
            gc.collect()
            torch.cuda.empty_cache()
            return "✅ 模型已卸载,显存已释放"
# Single process-wide ModelManager instance. The Gradio event handlers wired up
# in create_interface() call its methods, so every request shares one pipeline.
model_manager = ModelManager()
def setup_environment():
    """Apply process-wide settings tuned for the V100 and describe the GPU.

    Returns:
        A one-line status string for the UI. When a CUDA device is visible it
        reports total GPU memory and the PyTorch/CUDA versions; otherwise it
        returns a warning.
    """
    # NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA
    # allocator is initialised, so this only takes effect if it runs before
    # CUDA is first used in the process — confirm call order if it matters.
    env_overrides = {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:128",
        "TOKENIZERS_PARALLELISM": "false",
    }
    os.environ.update(env_overrides)
    torch.backends.cudnn.benchmark = True

    if not torch.cuda.is_available():
        return "⚠️ 未检测到CUDA设备"

    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    return (
        f"📊 GPU显存: {total_gb:.2f}GB "
        f"PyTorch版本: {torch.__version__} "
        f"CUDA版本: {torch.version.cuda}"
    )
def setup_xformers(pipeline):
    """Try to enable xformers memory-efficient attention, falling back safely.

    Besides a missing package (ImportError) or a pipeline without the method
    (AttributeError), diffusers raises ValueError when CUDA is unavailable.
    It also re-raises runtime errors from the xformers kernel check. These
    were previously uncaught and aborted model loading, despite the
    "fallback" intent.

    Args:
        pipeline: A diffusers pipeline.

    Returns:
        A human-readable status message; never raises for the cases above.
    """
    try:
        pipeline.enable_xformers_memory_efficient_attention()
    except (ImportError, AttributeError, ValueError, RuntimeError) as e:
        return f"⚠️ xformers不可用 ({e}), 回退到默认注意力实现"
    return "✅ 启用 xformers 内存高效注意力"
def load_flux_model_mini(precision="fp16", device="cuda", model_path="/models/flux-schnell"):
    """Load FLUX.1-schnell with every memory-saving option enabled (V100 OOM mode).

    Args:
        precision: Weight dtype. Accepts "fp16" (default), "bf16" or "fp32",
            plus the aliases "float16"/"half", "bfloat16" and "float32".
            Previously this argument was ignored and fp16 was always used.
        device: Only used to decide whether to print CUDA memory usage after
            loading; placement is left to sequential CPU offload.
        model_path: Local directory (or hub id) passed to ``from_pretrained``.

    Returns:
        ``(pipeline, load_time_seconds)``.

    Raises:
        ValueError: if ``precision`` is not recognised.
    """
    dtype_by_name = {
        "fp16": torch.float16,
        "float16": torch.float16,
        "half": torch.float16,
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    precision_key = str(precision).lower()
    if precision_key not in dtype_by_name:
        raise ValueError(f"不支持的精度: {precision}")
    torch_dtype = dtype_by_name[precision_key]

    print("🚀 正在加载FLUX.1-schnell模型(超轻量模式)...")
    start_time = time.time()
    pipeline = FluxPipeline.from_pretrained(
        model_path,
        torch_dtype=torch_dtype,
        use_safetensors=True,
    )
    # Trade speed for memory: stream weights to the GPU layer by layer and
    # decode the VAE in slices/tiles.
    pipeline.enable_sequential_cpu_offload()
    pipeline.vae.enable_slicing()
    pipeline.vae.enable_tiling()
    pipeline.enable_attention_slicing(slice_size="max")
    # The status message used to be discarded; print it so fallbacks are visible.
    print(setup_xformers(pipeline))
    load_time = time.time() - start_time
    print(f"✅ 模型加载完成! 耗时: {load_time:.2f}秒")
    if device == "cuda" and torch.cuda.is_available():
        print(f"💾 当前VRAM: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
    return pipeline, load_time
def load_lora_weights(pipeline, lora_path, lora_weight=1.0):
    """Load LoRA weights into ``pipeline`` as the "default" adapter and set its scale.

    Args:
        pipeline: A diffusers pipeline with LoRA support
            (``load_lora_weights`` / ``set_adapters``); modified in place.
        lora_path: Path to the LoRA weights; must exist on disk.
        lora_weight: Adapter scale.

    Returns:
        ``(pipeline, load_time_seconds)``. The elapsed time is now measured;
        it used to be a constant 0, so the UI always showed "0.00 秒".

    Raises:
        ValueError: if ``lora_path`` is empty or does not exist.
    """
    if not lora_path or not os.path.exists(lora_path):
        raise ValueError(f"LoRA路径不存在: {lora_path}")
    print(f"\n🔌 正在加载LoRA权重: {os.path.basename(lora_path)}")
    # Free cached blocks first to give the LoRA load as much headroom as possible.
    torch.cuda.empty_cache()
    start_time = time.time()
    # NOTE(review): prefix=None presumably lets diffusers accept LoRA files whose
    # keys lack the "transformer." prefix (e.g. ai-toolkit exports) — confirm.
    pipeline.load_lora_weights(lora_path, adapter_name="default", prefix=None)
    pipeline.set_adapters(["default"], adapter_weights=[lora_weight])
    load_time = time.time() - start_time
    return pipeline, load_time
def generate_image_optimized(pipeline, prompt, **kwargs):
    """Run one pipeline call and report how long inference took.

    Cached CUDA memory is released before the call. Any
    ``max_sequence_length`` keyword is dropped, so the pipeline's own default
    applies; every other keyword is forwarded unchanged.

    Returns:
        ``(first_image, inference_seconds)``.
    """
    torch.cuda.empty_cache()
    call_kwargs = {key: value for key, value in kwargs.items() if key != "max_sequence_length"}

    t0 = time.time()
    output = pipeline(prompt=prompt, **call_kwargs)
    elapsed = time.time() - t0

    first_image = output.images[0]
    return first_image, elapsed
# Gradio UI construction
def create_interface():
    """Build the Gradio app and wire its events.

    The left column holds model controls, generation parameters and LoRA
    settings; the right column holds the generated image and run info.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    # Single source for the example LoRA path (shown as placeholder and filled by the button).
    default_lora_path = "/opt/ai-runner/ai-toolkit/loras/flux_lustly-ai_v1.safetensors"

    # Soft theme with a gentle orange palette
    custom_theme = gr.themes.Soft(
        primary_hue="orange",
        secondary_hue="orange",
        neutral_hue="gray",
        text_size=gr.themes.sizes.text_md,
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    )

    with gr.Blocks(title="Flux.1-Schnell 图像生成器", theme=custom_theme) as demo:
        gr.Markdown("# Flux.1-Schnell 图像生成器")
        with gr.Row():
            with gr.Column():
                # Model controls
                with gr.Row():
                    load_model_btn = gr.Button("加载模型", variant="primary")
                    cleanup_btn = gr.Button("卸载模型", variant="stop")
                    generate_btn = gr.Button("🎨生成图像", variant="primary", scale=2)
                with gr.Accordion("系统状态数据", open=True):
                    model_status = gr.Textbox(label="系统信息", value="未加载", interactive=False, lines=1)

                # Generation parameters
                prompt = gr.Textbox(
                    label="提示词 (Prompt)",
                    placeholder="输入你想要生成的图像描述...",
                    lines=3,
                    value="a beautiful woman",
                )
                negative_prompt = gr.Textbox(
                    label="负面提示词 (Negative Prompt)",
                    placeholder="输入你不希望出现的内容...",
                    lines=2,
                    value="",
                )
                with gr.Row():
                    with gr.Column():
                        height = gr.Slider(
                            label="图像高度", minimum=256, maximum=2048,
                            value=512, step=64,
                        )
                        width = gr.Slider(
                            label="图像宽度", minimum=256, maximum=2048,
                            value=512, step=64,
                        )
                    with gr.Column():
                        # step=1: with minimum=1 and step=4 the slider only allowed
                        # 1, 5, 9, ..., so the recommended value 4 was off-grid.
                        steps = gr.Slider(
                            label="推理步数", minimum=1, maximum=100,
                            value=4, step=1,
                            info="FLUX-schnell推荐值为4",
                        )
                        # Hidden, read-only: per its info text schnell requires 0.0.
                        guidance = gr.Number(
                            label="引导比例", value=0.0,
                            info="FLUX-schnell必须为0.0", interactive=False, visible=False,
                        )
                        seed = gr.Number(
                            label="随机种子 (-1为随机)", value=-1,
                            precision=0,
                        )

                # LoRA settings
                with gr.Accordion("LoRA设置", open=False):
                    lora_path = gr.Textbox(
                        label="LoRA路径 (.safetensors)",
                        placeholder=default_lora_path,
                        value="",
                    )
                    lora_weight = gr.Slider(
                        label="LoRA权重", minimum=0.0, maximum=2.0,
                        value=1.0, step=0.1,
                    )
                    with gr.Row():
                        fill_lora_btn = gr.Button("填充默认路径", size="md")
                        load_lora_btn = gr.Button("加载LoRA权重", variant="primary", size="md")
                    lora_status = gr.Textbox(label="LoRA状态", interactive=False)

            with gr.Column():
                # Output area
                output_image = gr.Image(
                    label="生成的图像",
                    type="pil",
                    format="png",
                )
                generation_info = gr.Textbox(label="生成信息", interactive=False, lines=3)

        # Event wiring
        load_model_btn.click(fn=model_manager.load_model, outputs=[model_status])
        cleanup_btn.click(fn=model_manager.cleanup, outputs=[model_status])
        load_lora_btn.click(
            fn=model_manager.load_lora,
            inputs=[lora_path, lora_weight],
            outputs=[lora_status],
        )
        fill_lora_btn.click(fn=lambda: default_lora_path, outputs=[lora_path])
        generate_btn.click(
            fn=model_manager.generate,
            inputs=[
                prompt, negative_prompt, steps, guidance,
                height, width, seed, lora_path, lora_weight,
            ],
            outputs=[output_image, generation_info],
        )
        # Check the environment when the page loads
        demo.load(fn=setup_environment, outputs=[model_status])

    return demo
if __name__ == "__main__":
    # CLI overrides for the previously hard-coded launch settings (defaults unchanged).
    parser = argparse.ArgumentParser(description="Flux.1-Schnell Gradio WebUI")
    parser.add_argument("--host", default="0.0.0.0",
                        help="address to bind (default: 0.0.0.0, i.e. all interfaces)")
    parser.add_argument("--port", type=int, default=28001,
                        help="port to listen on (default: 28001)")
    parser.add_argument("--share", action="store_true",
                        help="create a public Gradio share link")
    args = parser.parse_args()

    # Start the Gradio app
    demo = create_interface()
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        show_error=True,
        share=args.share,
        debug=True,
    )
运行
bash
python r_webui_3.py
访问
bash
http://<服务器IP>:28001/   (例如 http://192.168.31.222:28001/)
效果
