Environment
| Item | Configuration |
|---|---|
| GPU | NVIDIA GeForce RTX 4090 D (24 GB VRAM) |
| System memory | 64 GB |
| Operating system | Windows 10/11 |
| Python | 3.12 |
| PyTorch | 2.5.1+cu124 |
| Virtual environment | uv (v0.7.2) |
Model Architecture
DreamX-World-5B inference involves three separate models:
| Model | Size (bf16) | Source |
|---|---|---|
| UMT5-XXL text encoder | ~10.6 GB | Wan2.2-TI2V-5B |
| CausalWanModel diffusion model | ~10 GB | DreamX-World-5B/model.safetensors |
| Wan2.2 VAE | ~2.8 GB | Wan2.2-TI2V-5B |
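1. Windows Compatibility Fixes
1.1 Replacing flash_attention
Problem: flash_attn and triton do not support Windows, so installing them fails.
Cause: in wan/modules/model_2_2.py, WanSelfAttention and WanCrossAttention call flash_attention() directly with no fallback, so they always crash on Windows.
Solution: switch the import and the calls to attention(), which has a fallback (internally it degrades to torch.nn.functional.scaled_dot_product_attention).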
```python
# Before
from .attention import flash_attention
x = flash_attention(q, k, v, k_lens=context_lens)

# After
from .attention import attention
x = attention(q, k, v, k_lens=context_lens)
```
File: wan/modules/model_2_2.py (2 call sites)
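The fallback idea can be sketched independently of the repository. This is a minimal illustration, not the code in wan/modules/attention.py: `pick_backend` and `dispatch_attention` are hypothetical names, and the backend callables are supplied by the caller.

```python
import importlib.util
from typing import Callable, Dict


def pick_backend(preferred: str, fallback: str = "sdpa") -> str:
    """Return `preferred` if that package is importable, otherwise `fallback`."""
    if importlib.util.find_spec(preferred) is not None:
        return preferred
    return fallback


def dispatch_attention(q, k, v, backends: Dict[str, Callable],
                       preferred: str = "flash_attn"):
    """Call the kernel registered under the selected backend name."""
    return backends[pick_backend(preferred)](q, k, v)
```

On a machine where flash_attn cannot be installed, `pick_backend("flash_attn")` returns `"sdpa"`, so the call is routed to whatever function the caller registered under that key.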
1.2 Video Writing Compatibility
Problem: torchvision.io.write_video is incompatible with newer PyAV releases; frame.pict_type = "NONE" raises a TypeError.
Solution: write the video with imageio instead.
```python
# Before
from torchvision.io import write_video
write_video(output_path, video[0], fps=args.fps)

# After
import imageio
writer = imageio.get_writer(output_path, fps=args.fps, codec='libx264')
for frame in video[0].numpy().astype('uint8'):
    writer.append_data(frame)
writer.close()
```
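File: inference_ar_forcing.py
1.3 Model Weight Shape Fix
Problem: in model.safetensors, the cam_self_attn weights have dimension 768 (attn_compress=4, i.e. 3072/4), but with eprope: false the inference code sets attn_compress to 1. The model then expects dimension 3072 and loading fails with a size mismatch.
Solution: change eprope: false to eprope: true in the YAML config so that attn_compress=4 matches the weights.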
```yaml
# Before
eprope: false
# After
eprope: true
```
File: configs/dreamx-ar/causal_camera_forcing_5b.yaml
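A mismatch like this can be caught before calling load_state_dict by comparing shapes. The sketch below is illustrative only: the helper names, the weight key, and the example shapes are assumptions; only the 3072/4 = 768 relationship comes from the text above.

```python
def cam_attn_dim(hidden_dim: int, attn_compress: int) -> int:
    """Width of the compressed camera self-attention projection."""
    if hidden_dim % attn_compress:
        raise ValueError("hidden_dim must be divisible by attn_compress")
    return hidden_dim // attn_compress


def find_shape_mismatches(expected: dict, checkpoint: dict) -> list:
    """Return (name, expected_shape, checkpoint_shape) for every disagreeing key."""
    return [
        (name, shape, checkpoint[name])
        for name, shape in expected.items()
        if name in checkpoint and checkpoint[name] != shape
    ]


# Hypothetical key and shapes, for illustration only.
key = "blocks.0.cam_self_attn.q.weight"
checkpoint_shapes = {key: (cam_attn_dim(3072, 4), 3072)}   # what the file stores
model_shapes = {key: (cam_attn_dim(3072, 1), 3072)}        # what eprope: false builds
mismatches = find_shape_mismatches(model_shapes, checkpoint_shapes)
```

Here `mismatches` holds one entry, which is the situation that produced the size mismatch error.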
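2. GPU Memory (VRAM) Optimization: Sequential Load and Release
2.1 Problem Analysis
The original code keeps all three models resident in GPU memory:
| Component | VRAM usage |
|---|---|
| Text encoder | ~10.6 GB |
| Diffusion model | ~10 GB |
| KV cache (30 layers) | ~12 GB |
| VAE | ~2.8 GB |
| Total | ~35.4 GB |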
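The RTX 4090 has only 24 GB of VRAM, so VAE decoding was forced to run in less than 2 GB of remaining space and took 24 minutes.
2.2 Optimization: Sequential Offload
Each component independently goes through "load onto the GPU → use → offload back to the CPU → clear VRAM":
[0/3 VAE encode] load VAE → encode image → release Free VRAM: 22.5 GB
[1/3 text encoder] load encoder → encode text → release Free VRAM: 22.4 GB
[2/3 diffusion] load diffusion model → run inference → release model + caches Free VRAM: 22.4 GB
[3/3 VAE decode] load VAE → decode video → release Free VRAM: 21.6 GB
2.3 Implementation Details
pipeline/pipeline_causal_camera.py (core changes):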
```python
class CausalCameraInferencePipeline(torch.nn.Module):
    def __init__(self, ...):
        ...
        self.sequential_offload = False  # new flag

    def inference(self, ...):
        # Step 1: text encoding (load / use / release)
        if self.sequential_offload:
            self.text_encoder.to(gpu)
        conditional_dict = self.text_encoder(text_prompts=text_prompts)
        if self.sequential_offload:
            self.text_encoder.to(cpu)
            torch.cuda.empty_cache()
            gc.collect()

        # Step 2: diffusion inference (load / use / release)
        if self.sequential_offload:
            self.generator.to(gpu)
        # ... diffusion loop ...
        if self.sequential_offload:
            self._clear_caches()    # free the KV cache (~12 GB)
            self.generator.to(cpu)  # offload the diffusion model (~10 GB)
            torch.cuda.empty_cache()
            gc.collect()

        # Step 3: VAE decoding (load / use / release)
        if self.sequential_offload:
            self.vae.to(gpu)
        video = self.vae.decode_to_pixel(output)
        if self.sequential_offload:
            self.vae.to(cpu)
            torch.cuda.empty_cache()
            gc.collect()
```
inference_ar_forcing.py (entry-point changes):
```python
# The --device_map flag enables sequential load/release
if getattr(args, 'device_map', False):
    pipeline.sequential_offload = True

# VAE encoding is also loaded and released on its own
if getattr(pipeline, 'sequential_offload', False):
    pipeline.vae.to(device)
initial_latent = pipeline.vae.encode_to_latent(image_tensor)
if getattr(pipeline, 'sequential_offload', False):
    pipeline.vae.to(cpu)
    torch.cuda.empty_cache()
    gc.collect()
```
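The repeated load/use/release pattern could also be factored into a context manager. This is a sketch rather than the project's code: `on_device` and its `release` hook are hypothetical, and it works with any object that exposes a `.to(device)` method. In the real pipeline the hook would be `torch.cuda.empty_cache`.

```python
import gc
from contextlib import contextmanager


@contextmanager
def on_device(module, device, offload_device="cpu", release=None):
    """Move `module` to `device` for the duration of the block, then move it back."""
    module.to(device)
    try:
        yield module
    finally:
        # Runs even if the body raises, so the module never stays on the GPU.
        module.to(offload_device)
        gc.collect()          # drop unreachable objects first
        if release is not None:
            release()         # then hand cached blocks back, e.g. torch.cuda.empty_cache
```

A call site would then read `with on_device(self.vae, gpu, release=torch.cuda.empty_cache): video = self.vae.decode_to_pixel(output)`. The sketch runs `gc.collect()` before the release hook so that memory held only by garbage is already free when the cache is emptied.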
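2.4 Performance Comparison
| Metric | All resident on GPU | Released only before VAE | Sequential load/release |
|---|---|---|---|
| Diffusion inference | 15.4 s | 15.6 s | 15.7 s |
| VAE decoding | 24 min | 1.7 min | 39 s |
| Total time | 24 min | 2 min | 58 s |
| VRAM available to VAE | ~1-2 GB | ~21 GB | ~21 GB |
VAE decoding is about 37× faster (24 min → 39 s), and total time drops from 24 minutes to under 1 minute.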
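The speedup figures follow directly from the table. A quick check, assuming the rounded 24-minute value is used for both the VAE and the total baseline:

```python
def speedup(before_s: float, after_s: float) -> float:
    """Ratio of the old wall-clock time to the new one."""
    return before_s / after_s


baseline_s = 24 * 60                       # 24 minutes, all models resident on GPU
vae_speedup = speedup(baseline_s, 39)      # VAE decode, sequential load/release
total_speedup = speedup(baseline_s, 58)    # end-to-end, sequential load/release

print(f"VAE decode: {vae_speedup:.1f}x, total: {total_speedup:.1f}x")
```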
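3. System Memory (RAM) Optimization
3.1 Problem Analysis
Before optimization, peak memory during loading reached about 55 GB:
start: 0 GB
after loading the text encoder: 23.3 GB (model and state_dict both resident)
after loading the VAE: 23.2 GB
after random initialization of the generator: 35.8 GB (float32 random weights, ~12.6 GB)
after load_state_dict: 55.5 GB (data copied out of the mmap, old weights not yet freed)
after del sd + gc: 35.8 GB
peak: 55.1 GB
3.2 Optimizations
3.2.1 Create the text encoder directly in bfloat16
The original code creates the model in float32 (~21 GB) and only later converts it to bfloat16, which wastes ~10 GB.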
```python
# Before (utils/wan_wrapper.py)
self.text_encoder = umt5_xxl(dtype=torch.float32, ...)
# After
self.text_encoder = umt5_xxl(dtype=torch.bfloat16, ...)
```
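The ~21 GB and ~10 GB figures can be reproduced from element sizes alone. The sketch below assumes decimal gigabytes and infers the parameter count from the ~10.6 GB bf16 size quoted earlier; the function names are illustrative.

```python
BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2, "float16": 2}


def weights_gb(num_params: float, dtype: str) -> float:
    """Storage needed for `num_params` weights of the given dtype, in GB (1e9 bytes)."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1e9


def params_from_size(size_gb: float, dtype: str) -> float:
    """Parameter count implied by a checkpoint size in GB."""
    return size_gb * 1e9 / BYTES_PER_ELEMENT[dtype]


num_params = params_from_size(10.6, "bfloat16")   # about 5.3e9 parameters
fp32_gb = weights_gb(num_params, "float32")       # footprint if created in float32
extra_gb = fp32_gb - 10.6                         # what the float32 detour costs

print(f"float32: {fp32_gb:.1f} GB, extra vs bf16: {extra_gb:.1f} GB")
```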
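3.2.2 Use assign=True to avoid copying weights
By default, load_state_dict() copies the data from the state_dict into the model's existing parameters, so two full sets of weights coexist in memory. With assign=True the parameters are replaced by the state_dict tensors themselves, and the old weights can be freed right away.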
```python
# Before
pipeline.generator.load_state_dict(sd, strict=False)
# After
pipeline.generator.load_state_dict(sd, strict=False, assign=True)
```
The same applies to the text encoder:
```python
# Before (utils/wan_wrapper.py)
self.text_encoder.load_state_dict(torch.load(path, ...))
# After
_sd = torch.load(path, map_location='cpu', weights_only=False)
self.text_encoder.load_state_dict(_sd, assign=True)
del _sd
gc.collect()
```
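The difference between the two loading modes comes down to copy versus rebind. The following is only a pure-Python analogy, not PyTorch code: `Param`, `load_by_copy`, and `load_by_assign` are hypothetical names, and a `bytearray` stands in for a tensor's storage.

```python
class Param:
    """Hypothetical stand-in for a module parameter that owns a buffer."""

    def __init__(self, nbytes: int):
        self.data = bytearray(nbytes)  # plays the role of randomly initialized weights


def load_by_copy(param: Param, loaded: bytearray) -> None:
    """Analogue of the default mode: write into the existing buffer."""
    param.data[:] = loaded  # both buffers stay alive while `loaded` is referenced


def load_by_assign(param: Param, loaded: bytearray) -> None:
    """Analogue of assign=True: rebind the parameter to the loaded buffer."""
    param.data = loaded  # the old buffer loses its last reference here


checkpoint = bytearray(b"\x01" * 8)
copied, assigned = Param(8), Param(8)
load_by_copy(copied, checkpoint)
load_by_assign(assigned, checkpoint)
```

After the copy, `copied.data` and `checkpoint` are two equal but distinct buffers; after the assignment, `assigned.data` is the checkpoint buffer itself, so only one copy exists.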
3.2.3 Release the state_dict promptly and run garbage collection
Immediately after each load_state_dict call, del the state_dict and call gc.collect().
```python
# inference_ar_forcing.py: load_pipeline()
sd = load_file(args.checkpoint_path)
sd = {"model." + k: v for k, v in sd.items()}
pipeline.generator.load_state_dict(sd, strict=False, assign=True)
del sd        # drop the safetensors reference right away
gc.collect()  # trigger garbage collection

# Also collect after pipeline.to(bfloat16)
pipeline = pipeline.to(dtype=torch.bfloat16)
gc.collect()  # free the old copies left over from the float32 → bfloat16 conversion
```
3.3 Memory Comparison
| Stage | Before | After |
|---|---|---|
| After loading the text encoder | 23.3 GB | 23.3 GB |
| After random initialization of the generator | 35.8 GB | 35.7 GB |
| After load_state_dict | 55.5 GB | 35.7 GB |
| Peak | 55.1 GB | 35.7 GB |
| Savings | n/a | ~20 GB |
4. Summary of Modified Files
| File | Changes |
|---|---|
| configs/dreamx-ar/causal_camera_forcing_5b.yaml | eprope: false → true |
| wan/modules/model_2_2.py | flash_attention → attention (2 call sites) |
| utils/wan_wrapper.py | text encoder in bf16 + assign=True + gc.collect() |
| inference_ar_forcing.py | --device_map flag, sequential offload, assign=True, del sd + gc.collect(), imageio video writing |
| pipeline/pipeline_causal_camera.py | sequential_offload flag, per-stage load/release logic, _clear_caches() method, gc.collect() |
5. Usage
```powershell
& E:\DreamX-World\.venv\Scripts\python.exe E:\DreamX-World\inference_ar_forcing.py `
    --config_path E:\DreamX-World\configs\dreamx-ar\causal_camera_forcing_5b.yaml `
    --model_name E:\Wan2.2-TI2V-5B `
    --transformer_path E:\DreamX-World-5B `
    --checkpoint_path E:\DreamX-World-5B\model.safetensors `
    --data_path E:\DreamX-World\configs\dreamx\eval.json `
    --output_folder E:\DreamX-World\outputs_ar `
    --num_output_frames 21 `
    --seed 42 `
    --chunk_relative `
    --device_map
```
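| Parameter | Description |
|---|---|
| --device_map | Enables sequential load/release mode |
| --num_output_frames | Number of latent frames; must be divisible by 3 |
| --chunk_relative | Computes relative camera poses per chunk |
Generation length reference (16 FPS):
| --num_output_frames | Pixel frames | Duration |
|---|---|---|
| 21 | 81 | ~5 s |
| 63 | 249 | ~16 s |
| 123 | 489 | ~31 s |
| 243 | 969 | ~61 s |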
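The table rows follow from the frame mapping used in inference_ar_forcing.py, `(num_output_frames - 1) * 4 + 1`. The helper names below are illustrative, and the divisibility check assumes the block size of 3 from `num_frame_per_block` in the config.

```python
def pixel_frames(latent_frames: int, temporal_stride: int = 4) -> int:
    """Pixel frames produced for a given number of latent frames."""
    return (latent_frames - 1) * temporal_stride + 1


def duration_seconds(latent_frames: int, fps: int = 16) -> float:
    """Clip length in seconds at the given frame rate."""
    return pixel_frames(latent_frames) / fps


def check_latent_frames(latent_frames: int, frames_per_block: int = 3) -> None:
    """Raise if the latent frame count does not split evenly into blocks."""
    if latent_frames % frames_per_block:
        raise ValueError(
            f"--num_output_frames={latent_frames} is not divisible by {frames_per_block}")


for n in (21, 63, 123, 243):
    check_latent_frames(n)
    print(f"{n} latent frames: {pixel_frames(n)} pixel frames, {duration_seconds(n):.2f} s")
```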
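6. Final Performance
Configuration: RTX 4090 D (24 GB) + 64 GB RAM + Windows + --device_map + 21 frames
| Metric | Value |
|---|---|
| Diffusion inference | ~16 s |
| VAE decoding | ~40 s |
| Total time | ~1 min |
| Peak VRAM | ~12.6 GB (diffusion stage) |
| Peak RAM | ~36 GB |
| Output resolution | 704 x 1280 |
| Output frame rate | 16 FPS |
Appendix: Complete Source After Optimization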
A1. configs/dreamx-ar/causal_camera_forcing_5b.yaml
```yaml
model_kwargs:
  model_name: Wan2.2-TI2V-5B-Camera
  timestep_shift: 5.0
  model_root_path: ./  # Path containing wan_models/Wan2.2-TI2V-5B-Camera/ directory
  eprope: true
denoising_step_list:
  - 1000
  - 750
  - 500
  - 250
warp_denoising_step: true
num_train_timestep: 1000
timestep_shift: 5.0
guidance_scale: 3.0
seed: 0
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
num_frame_per_block: 3
is_wan_2_2_vae: true
context_noise: 0.1
adapter:
  type: "lora"
  rank: 256
  alpha: 256
  dropout: 0.0
  dtype: "bfloat16"
  verbose: false
```
A2. inference_ar_forcing.py
```python
"""
AR-Forcing inference script for chunk-wise causal video generation with camera control.

Input JSON format:
[
    {
        "image_path": "/path/to/image.png",
        "caption": "description text",
        "action_seq": ["wi", "s"],
        "action_speed_list": [6, 4]
    },
    ...
]

Usage:
    CUDA_VISIBLE_DEVICES=0 python inference_ar_forcing.py \
        --config_path configs/ar_forcing/causal_camera_forcing_5b.yaml \
        --base_checkpoint_path /path/to/baseline.pt \
        --data_path configs/dreamx/eval.json \
        --output_folder outputs_ar/ \
        --num_output_frames 21
"""
import argparse
import gc
import json
import os

import cv2
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
import imageio
import torch.nn.functional as F

from pipeline.pipeline_causal_camera import CausalCameraInferencePipeline
from utils.misc import set_seed
from utils.trajectory_processor import generate_trajectory_from_json, Camera
from utils.memory import cpu, gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
from utils.postprocess import postprocess_video_frames


# ─────────────────────────────────────────────────────────────────────────────
# Camera: cam_params (numpy) → PRoPE dict {viewmats, K}
# ─────────────────────────────────────────────────────────────────────────────
def _invert_SE3(mats):
    """Invert batch of 4x4 SE(3) matrices."""
    R_inv = mats[..., :3, :3].transpose(-1, -2)
    result = torch.zeros_like(mats)
    result[..., :3, :3] = R_inv
    result[..., :3, 3] = -torch.einsum("...ij,...j->...i", R_inv, mats[..., :3, 3])
    result[..., 3, 3] = 1.0
    return result


def get_relative_pose(cam_params):
    """Compute relative c2w poses (first frame as origin)."""
    abs_w2cs = [cp.w2c_mat for cp in cam_params]
    abs_c2ws = [cp.c2w_mat for cp in cam_params]
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ], dtype=np.float32)
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w] + [abs2rel @ c2w for c2w in abs_c2ws[1:]]
    return np.array(ret_poses, dtype=np.float32)


def cam_params_to_prope_dict(cam_params, device, dtype=torch.bfloat16, chunk_relative=False):
    """
    Convert Camera objects → PRoPE conditioning dict.

    Steps:
        1. Subsample to latent-aligned frames (1+4k pattern)
        2. Compute relative c2w poses (global or chunk-relative)
        3. Invert to w2c (viewmats)
        4. Expand each frame to 880 spatial tokens (22x40 patches)
        5. Build normalized intrinsic K matrices

    Args:
        chunk_relative: If True, compute relative poses per chunk (chunk_size=3).

    Returns:
        {'viewmats': [1, T_latent*880, 4, 4], 'K': [1, T_latent*880, 3, 3]}
    """
    num_frames = len(cam_params)
    latent_frame_count = 1 + (num_frames - 1) // 4
    aligned_indices = [0] + [1 + 4 * i for i in range(latent_frame_count - 1)]
    cam_params_sub = [cam_params[i] for i in aligned_indices]

    if chunk_relative:
        chunk_size = 3
        all_relative_poses = []
        for chunk_start in range(0, len(cam_params_sub), chunk_size):
            chunk_end = min(chunk_start + chunk_size, len(cam_params_sub))
            if chunk_start == 0:
                chunk_cams = cam_params_sub[chunk_start:chunk_end]
            else:
                reference_cam = cam_params_sub[chunk_start - 1]
                chunk_cams = [reference_cam] + cam_params_sub[chunk_start:chunk_end]
            chunk_poses = get_relative_pose(chunk_cams)
            if chunk_start == 0:
                all_relative_poses.append(chunk_poses)
            else:
                all_relative_poses.append(chunk_poses[1:])
        c2w_poses = np.concatenate(all_relative_poses, axis=0)
    else:
        c2w_poses = get_relative_pose(cam_params_sub)

    c2ws = torch.as_tensor(c2w_poses, dtype=dtype, device=device)
    viewmats = _invert_SE3(c2ws)  # [T_latent, 4, 4]
    # Expand to per-token: 880 = 22*40 spatial patches per frame
    viewmats = viewmats.unsqueeze(1).expand(-1, 880, -1, -1).reshape(1, -1, 4, 4)

    # Normalized intrinsics (fixed, matching training config)
    fx_norm = 969.6969696969696 / (960.0 * 2)  # ≈ 0.505
    fy_norm = 969.6969696969696 / (540.0 * 2)  # ≈ 0.898
    K = torch.zeros((1, 3, 3), dtype=dtype, device=device)
    K[:, 0, 0] = fx_norm
    K[:, 1, 1] = fy_norm
    K[:, 0, 2] = 0.5
    K[:, 1, 2] = 0.5
    K[:, 2, 2] = 1.0
    K = K.unsqueeze(1).expand(-1, viewmats.shape[1], -1, -1).reshape(1, -1, 3, 3)
    return {'viewmats': viewmats, 'K': K}


def parse_args():
    parser = argparse.ArgumentParser(description="AR-Forcing causal video generation")
    # Model and config paths
    parser.add_argument("--config_path", type=str, required=True,
                        help="Path to AR-forcing YAML config file")
    parser.add_argument("--model_name", type=str, default=None,
                        help="Path to the folder containing Wan2.2 base model weights (text encoder, tokenizer, VAE)")
    parser.add_argument("--transformer_path", type=str, default=None,
                        help="Path to the folder containing AR-forcing transformer config.json")
    parser.add_argument("--vae_path", type=str, default=None,
                        help="Path to VAE checkpoint file (overrides model_name/Wan2.2_VAE.pth)")
    parser.add_argument("--base_checkpoint_path", type=str, default=None,
                        help="Path to base .pt checkpoint (generator_ema key)")
    parser.add_argument("--checkpoint_path", type=str, default=None,
                        help="Path to additional checkpoint (.pt or .safetensors)")
    # Input/Output
    parser.add_argument("--data_path", type=str, required=True)
    parser.add_argument("--output_folder", type=str, required=True)
    # Generation parameters
    parser.add_argument("--num_output_frames", type=int, default=21)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--fps", type=int, default=16)
    # Post-processing
    parser.add_argument("--color_correction_strength", type=float, default=0.3)
    # Camera
    parser.add_argument("--chunk_relative", action="store_true",
                        help="Compute relative camera poses per chunk (chunk_size=3) instead of globally")
    # LoRA
    parser.add_argument("--lora_ckpt", type=str, default=None,
                        help="Path to LoRA checkpoint (requires adapter section in config)")
    # Sequential offload (the flag name is historical; accelerate's device_map is not used)
    parser.add_argument("--device_map", action="store_true",
                        help="Enable sequential offload: load each component onto the GPU, use it, then release it")
    return parser.parse_args()


def load_pipeline(args, config, device):
    """Load the CausalCameraInferencePipeline with checkpoints."""
    # Build explicit paths from --model_name and --transformer_path if provided
    text_encoder_path = None
    tokenizer_path = None
    vae_path = args.vae_path
    model_config_path = None
    if args.model_name:
        text_encoder_path = os.path.join(args.model_name, "models_t5_umt5-xxl-enc-bf16.pth")
        tokenizer_path = os.path.join(args.model_name, "google/umt5-xxl/")
        if vae_path is None:
            vae_path = os.path.join(args.model_name, "Wan2.2_VAE.pth")
    if args.transformer_path:
        model_config_path = os.path.join(args.transformer_path, "config.json")

    pipeline = CausalCameraInferencePipeline(
        config, device=device, num_output_frames=args.num_output_frames,
        model_config_path=model_config_path,
        text_encoder_path=text_encoder_path,
        tokenizer_path=tokenizer_path,
        vae_path=vae_path,
    )

    if args.base_checkpoint_path:
        state_dict = torch.load(args.base_checkpoint_path, map_location="cpu")
        checkpoint_key = "generator_ema" if "generator_ema" in state_dict else "generator"
        gen_sd = state_dict.get(checkpoint_key, state_dict)
        try:
            missing, unexpected = pipeline.generator.load_state_dict(gen_sd, strict=False, assign=True)
        except RuntimeError:
            fixed = {k.replace("model._fsdp_wrapped_module.", "model.", 1): v for k, v in gen_sd.items()}
            missing, unexpected = pipeline.generator.load_state_dict(fixed, strict=False, assign=True)
        print(f"Base checkpoint loaded --- missing: {len(missing)}, unexpected: {len(unexpected)}")
        del state_dict, gen_sd
        gc.collect()

    if args.checkpoint_path:
        if args.checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd = load_file(args.checkpoint_path)
            sd = {"model." + k: v for k, v in sd.items()}
        elif args.checkpoint_path.endswith(".pt"):
            raw = torch.load(args.checkpoint_path, map_location="cpu")
            sd = raw.get("generator_ema", raw.get("generator", raw))
            del raw
        else:
            import glob
            from safetensors.torch import load_file
            sd = {}
            for f in glob.glob(args.checkpoint_path + "/*.safetensors"):
                for k, v in load_file(f).items():
                    if args.base_checkpoint_path is None or 'cam_self_attn' in k:
                        sd['model.' + k] = v
        missing, unexpected = pipeline.generator.load_state_dict(sd, strict=False, assign=True)
        print(f"Checkpoint loaded --- missing: {len(missing)}, unexpected: {len(unexpected)}")
        del sd
        gc.collect()

    # LoRA support
    lora_ckpt_path = args.lora_ckpt
    if getattr(config, "adapter", None) and lora_ckpt_path:
        import peft
        from utils.lora_peft import configure_lora_for_model
        print(f"Applying LoRA with config: {config.adapter}")
        pipeline.generator.model = configure_lora_for_model(
            pipeline.generator.model,
            model_name="generator",
            lora_config=config.adapter,
        )
        print(f"Loading LoRA checkpoint from {lora_ckpt_path}")
        lora_checkpoint = torch.load(lora_ckpt_path, map_location="cpu")
        if isinstance(lora_checkpoint, dict) and "generator_lora" in lora_checkpoint:
            peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint["generator_lora"])
        else:
            peft.set_peft_model_state_dict(pipeline.generator.model, lora_checkpoint)
        print("LoRA weights loaded for generator")
    return pipeline


def main():
    args = parse_args()
    device = torch.device("cuda")
    set_seed(args.seed)
    torch.set_grad_enabled(False)
    print(f"Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB")
    low_memory = get_cuda_free_memory_gb(gpu) < 40

    # Load config
    config = OmegaConf.load(args.config_path)
    default_config = OmegaConf.load(
        os.path.join(os.path.dirname(args.config_path), "default_config.yaml"))
    config = OmegaConf.merge(default_config, config)

    # Load pipeline
    pipeline = load_pipeline(args, config, device)
    pipeline = pipeline.to(dtype=torch.bfloat16)
    gc.collect()
    if getattr(args, 'device_map', False):
        pipeline.sequential_offload = True
        print("Sequential offload: each component loaded → used → freed independently")
    elif low_memory:
        DynamicSwapInstaller.install_model(pipeline.text_encoder, device=gpu)
        pipeline.generator.to(device=gpu)
        pipeline.vae.to(device=gpu)
    else:
        pipeline.text_encoder.to(device=gpu)
        pipeline.generator.to(device=gpu)
        pipeline.vae.to(device=gpu)

    # Load JSON data
    with open(args.data_path, 'r') as f:
        items = json.load(f)
    print(f"Loaded {len(items)} items from {args.data_path}")

    # Image transform (fixed 704x1280)
    transform = transforms.Compose([
        transforms.Resize((704, 1280)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    num_pixel_frames = (args.num_output_frames - 1) * 4 + 1
    os.makedirs(args.output_folder, exist_ok=True)

    # ─── Inference loop ───
    for idx, item in enumerate(items):
        image_path = item['image_path']
        caption = item.get('caption', item.get('prompt', ''))
        action_seq = item['action_seq']
        action_speed_list = item['action_speed_list']
        task_id = item.get('task_id', str(idx))
        img_parent = os.path.basename(os.path.dirname(image_path))
        img_basename = os.path.splitext(os.path.basename(image_path))[0]
        output_name = f"{img_parent}_{img_basename}" if img_parent else img_basename
        output_path = os.path.join(args.output_folder, f'{task_id}_{output_name}.mp4')
        if os.path.exists(output_path):
            print(f"[{idx}] Skip (exists): {output_path}")
            continue
        print(f"[{idx}] Generating: {output_path}")

        # 1) Encode input image
        pil_image = Image.open(image_path).convert('RGB')
        image_tensor = transform(pil_image).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)
        if getattr(pipeline, 'sequential_offload', False):
            pipeline.vae.to(device)
        initial_latent = pipeline.vae.encode_to_latent(image_tensor).to(device=device, dtype=torch.bfloat16)
        if getattr(pipeline, 'sequential_offload', False):
            pipeline.vae.to(cpu)
            torch.cuda.empty_cache()
            gc.collect()
            print(f"[0/3 VAE encode] done. Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB")

        # 2) Build noise (first frame = encoded image)
        sampled_noise = torch.randn([1, args.num_output_frames, 48, 44, 80], device=device, dtype=torch.bfloat16)
        sampled_noise[:, 0] = initial_latent

        # 3) Build camera trajectory → PRoPE dict
        action_seq_lower = [a.lower() for a in action_seq]
        trajectory_spec = list(zip(action_seq_lower, action_speed_list))
        _, cam_params_np, _ = generate_trajectory_from_json(
            trajectory_spec=trajectory_spec,
            num_frames=num_pixel_frames,
            return_cam_params=True,
        )
        cam_objects = [Camera(cam_params_np[i].tolist()) for i in range(cam_params_np.shape[0])]
        control_camera = cam_params_to_prope_dict(cam_objects, device=device, chunk_relative=args.chunk_relative)

        # 4) Run inference
        video, latents = pipeline.inference(
            noise=sampled_noise,
            text_prompts=[caption],
            y=None,
            y_camera=control_camera,
            return_latents=True,
        )

        # 5) Post-process and save video
        video = rearrange(video, 'b t c h w -> b t h w c').cpu()
        video = 255.0 * video
        pipeline.vae.model.clear_cache()
        reference_frame = video[0, 0] if video.shape[1] > 0 else None
        video = postprocess_video_frames(
            video,
            reference_frame=reference_frame,
            color_correction_strength=args.color_correction_strength,
        )
        writer = imageio.get_writer(output_path, fps=args.fps, codec='libx264')
        for frame in video[0].numpy().astype('uint8'):
            writer.append_data(frame)
        writer.close()
        print(f" Saved: {output_path}")
    print("Done.")


if __name__ == "__main__":
    main()
```
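The `_invert_SE3` helper in the listing above relies on the closed-form inverse of a rigid transform. As a short derivation, with $R$ the rotation block and $t$ the translation column (notation introduced here, not taken from the source):

```latex
T = \begin{bmatrix} R & t \\ 0^{\top} & 1 \end{bmatrix},
\qquad
T^{-1} = \begin{bmatrix} R^{\top} & -R^{\top} t \\ 0^{\top} & 1 \end{bmatrix},
\qquad
T^{-1} T = \begin{bmatrix} R^{\top} R & R^{\top} t - R^{\top} t \\ 0^{\top} & 1 \end{bmatrix} = I .
```

The last step uses $R^{\top} R = I$, which holds only for a proper rotation. The code mirrors this: `R_inv` is the transpose of the upper-left 3x3 block, and the translation entry is `-R_inv @ t` computed with `einsum`.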
A3. pipeline/pipeline_causal_camera.py
python
from typing import List, Optional
import gc
import torch
import tqdm
from utils.wan_wrapper import WanDiffusionCameraWrapper, WanTextEncoder, WanVAEWrapper
from utils.memory import cpu, gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation
class CausalCameraInferencePipeline(torch.nn.Module):
def __init__(self, args, device, generator=None, text_encoder=None, vae=None,
model_config_path=None, text_encoder_path=None, tokenizer_path=None, vae_path=None,
**kwargs):
super().__init__()
model_kwargs = getattr(args, "model_kwargs", {})
model_name = model_kwargs["model_name"]
model_root_path = model_kwargs["model_root_path"]
self.generator = WanDiffusionCameraWrapper(
**model_kwargs, is_causal=True, model_config_path=model_config_path,
**kwargs) if generator is None else generator
self.text_encoder = WanTextEncoder(
model_name=model_name, model_root_path=model_root_path,
text_encoder_path=text_encoder_path, tokenizer_path=tokenizer_path,
) if text_encoder is None else text_encoder
self.vae = WanVAEWrapper(
model_root_path=model_root_path, vae_path=vae_path,
) if vae is None else vae
self.scheduler = self.generator.get_scheduler()
self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long)
if args.warp_denoising_step:
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
self.num_transformer_blocks = 30
self.frame_seq_length = 880
self.kv_cache1 = None
self.args = args
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
self.independent_first_frame = getattr(args, "independent_first_frame", False)
self.local_attn_size = self.generator.model.local_attn_size
self.sequential_offload = False
print(f"KV inference with {self.num_frame_per_block} frames per block")
if self.num_frame_per_block > 1:
self.generator.model.num_frame_per_block = self.num_frame_per_block
def inference(
self,
noise: torch.Tensor,
text_prompts: List[str],
y: torch.Tensor,
y_camera: torch.Tensor,
initial_latent: Optional[torch.Tensor] = None,
return_latents: bool = False,
profile: bool = True,
low_memory: bool = False,
) -> torch.Tensor:
batch_size, num_frames, num_channels, height, width = noise.shape
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
assert num_frames % self.num_frame_per_block == 0
num_blocks = num_frames // self.num_frame_per_block
else:
assert (num_frames - 1) % self.num_frame_per_block == 0
num_blocks = (num_frames - 1) // self.num_frame_per_block
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
num_output_frames = num_frames + num_input_frames
if self.sequential_offload:
self.text_encoder.to(gpu)
conditional_dict = self.text_encoder(text_prompts=text_prompts)
if self.sequential_offload:
self.text_encoder.to(cpu)
torch.cuda.empty_cache()
gc.collect()
print(f"[1/3 text encoder] done. Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB")
if low_memory:
gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
move_model_to_device_with_memory_preservation(
self.text_encoder, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
output = torch.zeros(
[batch_size, num_output_frames, num_channels, height, width],
device=noise.device, dtype=noise.dtype)
if profile:
init_start = torch.cuda.Event(enable_timing=True)
init_end = torch.cuda.Event(enable_timing=True)
diffusion_start = torch.cuda.Event(enable_timing=True)
diffusion_end = torch.cuda.Event(enable_timing=True)
vae_start = torch.cuda.Event(enable_timing=True)
vae_end = torch.cuda.Event(enable_timing=True)
block_times = []
block_start = torch.cuda.Event(enable_timing=True)
block_end = torch.cuda.Event(enable_timing=True)
init_start.record()
if self.sequential_offload:
self.generator.to(gpu)
print(f"[2/3 diffusion] generator loaded. Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB")
# Initialize KV cache
if self.kv_cache1 is None:
self._initialize_kv_cache(batch_size, noise.dtype, noise.device, num_frames)
self._initialize_crossattn_cache(batch_size, noise.dtype, noise.device)
else:
for block_index in range(self.num_transformer_blocks):
self.crossattn_cache[block_index]["is_init"] = False
for block_index in range(len(self.kv_cache1)):
self.kv_cache1[block_index]["global_end_index"] = torch.tensor([0], dtype=torch.long, device=noise.device)
self.kv_cache1[block_index]["local_end_index"] = torch.tensor([0], dtype=torch.long, device=noise.device)
# Cache initial latent frames
current_start_frame = 0
if initial_latent is not None:
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
if self.independent_first_frame:
assert (num_input_frames - 1) % self.num_frame_per_block == 0
num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
output[:, :1] = initial_latent[:, :1]
self.generator(
noisy_image_or_video=initial_latent[:, :1],
conditional_dict=conditional_dict,
y=y[:, :1] if y is not None else None,
y_camera=y_camera if isinstance(y_camera, dict) else y_camera[:, :1],
timestep=timestep * 0,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length,
)
current_start_frame += 1
else:
assert num_input_frames % self.num_frame_per_block == 0
num_input_blocks = num_input_frames // self.num_frame_per_block
for _ in range(num_input_blocks):
current_ref_latents = initial_latent[:, current_start_frame:current_start_frame + self.num_frame_per_block]
output[:, current_start_frame:current_start_frame + self.num_frame_per_block] = current_ref_latents
y_latents = y[:, current_start_frame:current_start_frame + self.num_frame_per_block]
y_camera_latents = y_camera if isinstance(y_camera, dict) else y_camera[:, current_start_frame:current_start_frame + self.num_frame_per_block]
self.generator(
noisy_image_or_video=current_ref_latents,
conditional_dict=conditional_dict,
y=y_latents, y_camera=y_camera_latents,
timestep=timestep * 0,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length,
)
current_start_frame += self.num_frame_per_block
if profile:
init_end.record()
torch.cuda.synchronize()
diffusion_start.record()
# Temporal denoising loop
all_num_frames = [self.num_frame_per_block] * num_blocks
if self.independent_first_frame and initial_latent is None:
all_num_frames = [1] + all_num_frames
first_frame_mask = torch.zeros_like(noise)
first_frame_mask[:, 1:] = 1
for i, current_num_frames in tqdm.tqdm(enumerate(all_num_frames)):
if profile:
block_start.record()
noisy_input = noise[:, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
latents = noisy_input
y_latents = y[:, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames] if y is not None else None
if isinstance(y_camera, dict):
start_index = (current_start_frame - num_input_frames) * self.frame_seq_length
end_index = start_index + self.frame_seq_length * current_num_frames
y_camera_latents = {
'viewmats': y_camera['viewmats'][:, start_index:end_index],
'K': y_camera['K'][:, start_index:end_index],
}
else:
y_camera_latents = y_camera[:, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
first_frame_mask_block = first_frame_mask[:, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
# Spatial denoising loop
for index, current_timestep in enumerate(self.denoising_step_list):
temp_ts = ((first_frame_mask_block[0, :, 0, ::2, ::2]) * current_timestep).flatten()
temp_ts = torch.cat([
temp_ts,
temp_ts.new_ones(self.frame_seq_length * current_num_frames - temp_ts.size(0)) * current_timestep
])
timestep = temp_ts.unsqueeze(0).expand(batch_size, temp_ts.size(0))
if index < len(self.denoising_step_list) - 1:
_, denoised_pred = self.generator(
noisy_image_or_video=latents,
conditional_dict=conditional_dict,
y=y_latents, y_camera=y_camera_latents,
timestep=timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length)
next_timestep = self.denoising_step_list[index + 1]
next_timestep = next_timestep * torch.ones(
[batch_size, current_num_frames], device=noise.device, dtype=torch.long)
if i == 0:
next_timestep[:, 0] = 0
latents = self.scheduler.add_noise(
denoised_pred.flatten(0, 1),
torch.randn_like(denoised_pred.flatten(0, 1)),
next_timestep.flatten()
).unflatten(0, denoised_pred.shape[:2])
latents = latents * first_frame_mask_block + noisy_input * (1 - first_frame_mask_block)
else:
_, denoised_pred = self.generator(
noisy_image_or_video=latents,
conditional_dict=conditional_dict,
y=y_latents, y_camera=y_camera_latents,
timestep=timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length)
denoised_pred = denoised_pred * first_frame_mask_block + noisy_input * (1 - first_frame_mask_block)
output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
# Rerun with context noise to update KV cache
context_timestep = torch.ones_like(timestep) * self.args.context_noise
self.generator(
noisy_image_or_video=denoised_pred,
conditional_dict=conditional_dict,
y=y_latents, y_camera=y_camera_latents,
timestep=context_timestep,
kv_cache=self.kv_cache1,
crossattn_cache=self.crossattn_cache,
current_start=current_start_frame * self.frame_seq_length,
)
if profile:
block_end.record()
torch.cuda.synchronize()
block_times.append(block_start.elapsed_time(block_end))
current_start_frame += current_num_frames
if profile:
diffusion_end.record()
torch.cuda.synchronize()
diffusion_time = diffusion_start.elapsed_time(diffusion_end)
init_time = init_start.elapsed_time(init_end)
vae_start.record()
if self.sequential_offload:
self._clear_caches()
self.generator.to(cpu)
torch.cuda.empty_cache()
gc.collect()
print(f"[2/3 diffusion] done. Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB")
self.vae.to(gpu)
print(f"[3/3 VAE decode] vae loaded. Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB")
else:
self._free_diffusion_memory()
video = self.vae.decode_to_pixel(output)
video = (video * 0.5 + 0.5).clamp(0, 1)
if self.sequential_offload:
self.vae.to(cpu)
torch.cuda.empty_cache()
gc.collect()
print(f"[3/3 VAE decode] done. Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB")
else:
self._reload_diffusion_to_gpu(output.device)
if profile:
vae_end.record()
torch.cuda.synchronize()
vae_time = vae_start.elapsed_time(vae_end)
total_time = init_time + diffusion_time + vae_time
print(f"Profiling: init={init_time:.0f}ms, diffusion={diffusion_time:.0f}ms, vae={vae_time:.0f}ms, total={total_time:.0f}ms")
for i, bt in enumerate(block_times):
print(f" Block {i}: {bt:.0f}ms")
return (video, output) if return_latents else video
    def _initialize_kv_cache(self, batch_size, dtype, device, num_frames=21):
        if self.local_attn_size != -1:
            kv_cache_size = self.local_attn_size * self.frame_seq_length
        else:
            # 18480 = 21 * 880, i.e. 21 frames at 880 tokens each
            # (compare seq_len = 880 * num_output_frames in utils/wan_wrapper.py).
            kv_cache_size = 18480
        num_head = 24
        self.kv_cache1 = [
            {
                "k": torch.zeros([batch_size, kv_cache_size, num_head, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, kv_cache_size, num_head, 128], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
            }
            for _ in range(self.num_transformer_blocks)
        ]

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        num_head = 24
        self.crossattn_cache = [
            {
                "k": torch.zeros([batch_size, 512, num_head, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 512, num_head, 128], dtype=dtype, device=device),
                "is_init": False,
            }
            for _ in range(self.num_transformer_blocks)
        ]

    def _clear_caches(self):
        # Drop every tensor reference so the allocator can reclaim the cache memory.
        if self.kv_cache1 is not None:
            for cache in self.kv_cache1:
                for key in list(cache.keys()):
                    if isinstance(cache[key], torch.Tensor):
                        cache[key] = None
            self.kv_cache1 = None
        if self.crossattn_cache is not None:
            for cache in self.crossattn_cache:
                for key in list(cache.keys()):
                    if isinstance(cache[key], torch.Tensor):
                        cache[key] = None
            self.crossattn_cache = None

    def _free_diffusion_memory(self):
        if self.kv_cache1 is not None:
            for cache in self.kv_cache1:
                for key in list(cache.keys()):
                    if isinstance(cache[key], torch.Tensor):
                        cache[key] = None
            self.kv_cache1 = None
        if self.crossattn_cache is not None:
            for cache in self.crossattn_cache:
                for key in list(cache.keys()):
                    if isinstance(cache[key], torch.Tensor):
                        cache[key] = None
            self.crossattn_cache = None
        self.generator.to(cpu)
        self.text_encoder.to(cpu)
        torch.cuda.empty_cache()
        print(f"Freed diffusion memory. Free VRAM: {get_cuda_free_memory_gb(gpu):.1f} GB")

    def _reload_diffusion_to_gpu(self, device):
        self.generator.to(device)
        torch.cuda.empty_cache()
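The load, use, offload, clear sequence that the pipeline repeats for each stage can also be written as a context manager. The block below is only a minimal sketch of that pattern, not code from the repository: `on_device` is a hypothetical helper name, a small `Linear` layer stands in for the VAE, and the module is assumed to fit on the target device while it is in use.

```python
import gc
from contextlib import contextmanager

import torch


@contextmanager
def on_device(module: torch.nn.Module, device: torch.device):
    """Hypothetical helper: keep `module` on `device` only inside the with-block."""
    module.to(device)
    try:
        yield module
    finally:
        # Always move the weights back to CPU, even if the block raised.
        module.to("cpu")
        # Collect Python garbage first so any tensors it frees go back to the
        # caching allocator before the allocator's cache is released.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Stand-in module; runs on CPU when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = torch.nn.Linear(4, 4)
with on_device(vae, device) as m:
    out = m(torch.ones(1, 4, device=device))
final_device = next(vae.parameters()).device.type  # module is back on CPU here
```

In the pipeline, a helper like this could wrap the decode step (for example `with on_device(self.vae, gpu): ...`), which keeps the offload in a `finally` block so an exception during decoding does not leave the model on the GPU.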
A4. utils/wan_wrapper.py
python
import gc
import os
import types
from typing import List, Optional

import torch

from utils.scheduler import SchedulerInterface, FlowMatchScheduler
from wan.modules.tokenizers import HuggingfaceTokenizer
from wan.modules.t5 import umt5_xxl
from wan.modules.vae_2_2 import _video_vae as _video_vae_2_2


class WanTextEncoder(torch.nn.Module):
    def __init__(self, model_name: str = "Wan2.2-TI2V-5B-Camera", model_root_path: str = "",
                 text_encoder_path: str = None, tokenizer_path: str = None):
        super().__init__()
        if text_encoder_path is None:
            text_encoder_path = os.path.join(model_root_path, f"wan_models/{model_name}/models_t5_umt5-xxl-enc-bf16.pth")
        if tokenizer_path is None:
            tokenizer_path = os.path.join(model_root_path, f"wan_models/{model_name}/google/umt5-xxl/")
        self.text_encoder = umt5_xxl(
            encoder_only=True, return_tokenizer=False,
            dtype=torch.bfloat16, device=torch.device('cpu')
        ).eval().requires_grad_(False)
        _sd = torch.load(text_encoder_path, map_location='cpu', weights_only=False)
        self.text_encoder.load_state_dict(_sd, assign=True)
        del _sd
        gc.collect()
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=512, clean='whitespace')

    @property
    def device(self):
        return torch.cuda.current_device()

    def forward(self, text_prompts: List[str]) -> dict:
        ids, mask = self.tokenizer(text_prompts, return_mask=True, add_special_tokens=True)
        ids = ids.to(self.device)
        mask = mask.to(self.device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.text_encoder(ids, mask)
        for u, v in zip(context, seq_lens):
            u[v:] = 0.0
        return {"prompt_embeds": context}


class WanVAEWrapper(torch.nn.Module):
    def __init__(self, model_root_path="", vae_path: str = None):
        super().__init__()
        mean = [
            -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
            -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
            -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
            -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
            -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
            0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
        ]
        std = [
            0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
            0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
            0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
            0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
            0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
            0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
        ]
        self.mean = torch.tensor(mean, dtype=torch.float32)
        self.std = torch.tensor(std, dtype=torch.float32)
        if vae_path is None:
            vae_path = os.path.join(model_root_path, "wan_models/Wan2.2-TI2V-5B/Wan2.2_VAE.pth")
        self.model = _video_vae_2_2(
            pretrained_path=vae_path,
            z_dim=48, temperal_downsample=[False, True, True]
        ).eval().requires_grad_(False)

    def encode_to_latent(self, pixel: torch.Tensor) -> torch.Tensor:
        device, dtype = pixel.device, pixel.dtype
        scale = [self.mean.to(device=device, dtype=dtype),
                 1.0 / self.std.to(device=device, dtype=dtype)]
        output = [self.model.encode(u.unsqueeze(0), scale).float().squeeze(0) for u in pixel]
        output = torch.stack(output, dim=0)
        return output.permute(0, 2, 1, 3, 4)

    def decode_to_pixel(self, latent: torch.Tensor, use_cache: bool = False) -> torch.Tensor:
        zs = latent.permute(0, 2, 1, 3, 4)
        device, dtype = latent.device, latent.dtype
        scale = [self.mean.to(device=device, dtype=dtype),
                 1.0 / self.std.to(device=device, dtype=dtype)]
        decode_fn = self.model.cached_decode if use_cache else self.model.decode
        output = []
        with torch.autocast(device_type=device.type, dtype=dtype):
            for u in zs:
                output.append(decode_fn(u.unsqueeze(0), scale).float().clamp_(-1, 1).squeeze(0))
        output = torch.stack(output, dim=0)
        return output.permute(0, 2, 1, 3, 4)


class WanDiffusionCameraWrapper(torch.nn.Module):
    def __init__(
        self,
        model_name="Wan2.2-TI2V-5B-Camera",
        timestep_shift=5.0,
        is_causal=True,
        local_attn_size=12,
        sink_size=3,
        model_root_path="",
        model_config_path: str = None,
        **kwargs,
    ):
        super().__init__()
        from wan.modules.causal_camera_model_2_2_prope_infinity import CausalWanModel
        num_output_frames = kwargs.get('num_output_frames', 21)
        # eprope=True selects attn_compress=4, which matches the 768-wide
        # cam_self_attn weights in model.safetensors (3072 / 4).
        eprope = kwargs.get('eprope', False)
        attn_compress = 4 if eprope else 1
        if model_config_path is None:
            model_config_path = os.path.join(model_root_path, f"wan_models/{model_name}/config.json")
        self.model = CausalWanModel.from_config(
            model_config_path,
            local_attn_size=local_attn_size, sink_size=sink_size, attn_compress=attn_compress)
        self.model.eval()
        self.scheduler = FlowMatchScheduler(shift=timestep_shift, sigma_min=0.0, extra_one_step=True)
        self.scheduler.set_timesteps(1000, training=True)
        self.seq_len = 880 * num_output_frames
        self._bind_scheduler_methods()

    def _bind_scheduler_methods(self):
        self.scheduler.convert_x0_to_noise = types.MethodType(
            SchedulerInterface.convert_x0_to_noise, self.scheduler)
        self.scheduler.convert_noise_to_x0 = types.MethodType(
            SchedulerInterface.convert_noise_to_x0, self.scheduler)
        self.scheduler.convert_velocity_to_x0 = types.MethodType(
            SchedulerInterface.convert_velocity_to_x0, self.scheduler)

    def get_scheduler(self) -> SchedulerInterface:
        return self.scheduler

    def _convert_flow_pred_to_x0(self, flow_pred, xt, timestep):
        # x0 = xt - sigma_t * flow_pred, computed in float64 and cast back.
        original_dtype = flow_pred.dtype
        flow_pred, xt, sigmas, timesteps = map(
            lambda x: x.double().to(flow_pred.device),
            [flow_pred, xt, self.scheduler.sigmas, self.scheduler.timesteps])
        # Pick the scheduler step closest to each requested timestep.
        timestep_id = torch.argmin(
            (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1)
        return (xt - sigma_t * flow_pred).to(original_dtype)

    def forward(
        self,
        noisy_image_or_video: torch.Tensor,
        conditional_dict: dict,
        y_camera,
        timestep: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        kv_cache: Optional[List[dict]] = None,
        crossattn_cache: Optional[List[dict]] = None,
        current_start: Optional[int] = None,
        cache_start: Optional[int] = None,
        cache_update_policy: str = "commit_detached",
    ) -> tuple[torch.Tensor, torch.Tensor]:
        prompt_embeds = conditional_dict["prompt_embeds"]
        skip_length = noisy_image_or_video.shape[-1] * noisy_image_or_video.shape[-2] // 4
        original_timestep = timestep[:, ::skip_length]
        y_camera_input = y_camera if (y_camera is None or isinstance(y_camera, dict)) else y_camera.permute(0, 2, 1, 3, 4)
        flow_pred = self.model(
            noisy_image_or_video.permute(0, 2, 1, 3, 4),
            t=timestep,
            context=prompt_embeds,
            y=y.permute(0, 2, 1, 3, 4) if y is not None else None,
            y_camera=y_camera_input,
            seq_len=self.seq_len,
            kv_cache=kv_cache,
            crossattn_cache=crossattn_cache,
            current_start=current_start,
            cache_start=cache_start,
            cache_update_policy=cache_update_policy,
        ).permute(0, 2, 1, 3, 4)
        pred_x0 = self._convert_flow_pred_to_x0(
            flow_pred.flatten(0, 1),
            noisy_image_or_video.flatten(0, 1),
            original_timestep.flatten(0, 1),
        ).unflatten(0, flow_pred.shape[:2])
        return flow_pred, pred_x0
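The conversion in `_convert_flow_pred_to_x0` computes x0 = xt - sigma_t * flow_pred after looking up the scheduler step nearest to each timestep. The check below is a small sketch of why that recovers the clean latent. It assumes the usual flow-matching convention xt = (1 - sigma) * x0 + sigma * noise with velocity target noise - x0; that convention is not shown in this file, and the 5-step schedule is made up for illustration (the real values come from `FlowMatchScheduler`).

```python
import torch

torch.manual_seed(0)

# Hypothetical 5-step schedule, for illustration only.
sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0], dtype=torch.float64)
timesteps = sigmas * 1000

x0 = torch.randn(2, 3, 4, 4, dtype=torch.float64)  # clean latents
noise = torch.randn_like(x0)
timestep = torch.tensor([740.0, 260.0], dtype=torch.float64)  # one per sample

# Nearest-timestep lookup, same expression as in the wrapper.
idx = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
sigma_t = sigmas[idx].reshape(-1, 1, 1, 1)

xt = (1 - sigma_t) * x0 + sigma_t * noise  # forward noising (assumed convention)
flow = noise - x0                          # ideal velocity target (assumed convention)
x0_rec = xt - sigma_t * flow               # the wrapper's conversion
max_err = (x0_rec - x0).abs().max().item()
```

With an exact velocity the reconstruction error is at floating-point level, which is also why the wrapper does this arithmetic in float64 before casting back to the model dtype.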
A5. wan/modules/model_2_2.py (modified parts only)
The file has 577 lines. Only the import and two function calls were changed; the rest of the code is untouched.
python
# Line 9: import change
from .attention import attention  # was: from .attention import flash_attention

# Line 146: call inside WanSelfAttention.forward()
x = attention(  # was: x = flash_attention(
    q=rope_apply(q, grid_sizes, freqs),
    k=rope_apply(k, grid_sizes, freqs),
    v=v,
    k_lens=seq_lens,
    window_size=self.window_size)

# Line 188: call inside WanCrossAttention.forward()
x = attention(q, k, v, k_lens=context_lens)  # was: x = flash_attention(q, k, v, k_lens=context_lens)
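For readers who want to see what a fallback of this kind looks like, here is a rough sketch. It is not the contents of `wan/modules/attention.py`: `attention_with_fallback` is a hypothetical name, the `[batch, seq_len, num_heads, head_dim]` layout is assumed, and the `k_lens` and `window_size` arguments that the real `attention()` accepts are left out for brevity.

```python
import torch
import torch.nn.functional as F

try:
    # Optional dependency; per this document it does not install on Windows.
    from flash_attn import flash_attn_func
    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False


def attention_with_fallback(q, k, v):
    """Hypothetical sketch. q, k, v: [batch, seq_len, num_heads, head_dim]."""
    if HAS_FLASH and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_func(q, k, v)
    # PyTorch SDPA expects [batch, num_heads, seq_len, head_dim].
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2).contiguous()


# Shapes mirror the cross-attention cache: 24 heads of size 128, 512 context tokens.
q = torch.randn(1, 16, 24, 128)
k = torch.randn(1, 512, 24, 128)
v = torch.randn(1, 512, 24, 128)
out_shape = tuple(attention_with_fallback(q, k, v).shape)
```

The output keeps the query's layout, so call sites that previously used `flash_attention()` would not need reshaping; handling padded keys (`k_lens`) would additionally require an attention mask in the SDPA branch.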