Phantom 根据图片和文字描述,自动生成一段视频,并且动作、场景等内容会按照文字描述来呈现

Phantom 根据图片和文字描述,自动生成一段视频,并且动作、场景等内容会按照文字描述来呈现

flyfish

视频生成的实践效果展示

Phantom 视频生成的实践
Phantom 视频生成的流程
Phantom 视频生成的命令

Wan2.1 图生视频 支持批量生成
Wan2.1 文生视频 支持批量生成、参数化配置和多语言提示词管理
Wan2.1 加速推理方法
Wan2.1 通过首尾帧生成视频

AnyText2 在图片里玩文字而且还是所想即所得
Python 实现从 MP4 视频文件中平均提取指定数量的帧

配置

json 复制代码
{
    "task": "s2v-1.3B",
    "size": "832*480",
    "frame_num": 81,
    "ckpt_dir": "./Wan2.1-T2V-1.3B",
    "phantom_ckpt": "./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth",
    "offload_model": false,
    "ulysses_size": 1,
    "ring_size": 1,
    "t5_fsdp": false,
    "t5_cpu": false,
    "dit_fsdp": false,
    "use_prompt_extend": false,
    "prompt_extend_method": "local_qwen",
    "prompt_extend_model": null,
    "prompt_extend_target_lang": "ch",
    "base_seed": 40,
    "sample_solver": "unipc",
    "sample_steps": null,
    "sample_shift": null,
    "sample_guide_scale": 5.0,
    "sample_guide_scale_img": 5.0,
    "sample_guide_scale_text": 7.5
}

参数

一、参数作用解析

1. 任务与模型路径
  • task: "s2v-1.3B"

    • 作用 :指定任务类型为 文本到视频生成(Text-to-Video,T2V)1.3B 表示使用的基础模型(如 Wan2.1-T2V-1.3B)参数规模为 130亿
  • ckpt_dir: "./Wan2.1-T2V-1.3B"

    • 作用 :指定基础模型的权重文件路径。根据你之前提供的文件夹内容,该路径下包含:
      • models_t5_umt5-xxl-enc-bf16.pthT5文本编码器的权重文件(用于处理文本提示)。
      • Wan2.1_VAE.pthVAE模型的权重(用于视频的时空压缩和重建)。
      • google/umt5-xxl文件夹:可能包含T5模型的结构定义或配置文件。
  • phantom_ckpt: "./Phantom-Wan-1.3B/Phantom-Wan-1.3B.pth"

    • 作用 :指定 Phantom跨模态对齐模型 的权重路径,用于锁定参考图像的主体特征(如颜色、轮廓),确保生成视频中主体与参考图像一致。
2. 视频生成配置
  • size: "832*480"

    • 作用 :生成视频的分辨率,格式为 宽度×高度 ,因此 832是宽度,480是高度。例如,常见的16:9分辨率中,宽度大于高度。
  • frame_num: 81

    • 作用:生成视频的总帧数。假设帧率为24fps,81帧约为3.375秒的视频(实际时长取决于帧率设置)。
3. 模型性能与资源配置
  • offload_model: false

    • 作用 :是否将模型参数卸载到CPU或磁盘以节省GPU内存。设为false时,模型全程运行在GPU内存中,速度更快但需更大显存。
  • ulysses_size: 1ring_size: 1

    • 作用 :与分布式训练(如FSDP)相关的参数,用于多卡并行计算。设为1时,表示 单卡运行,不启用分布式分片。
  • t5_fsdp: falset5_cpu: false

    • t5_fsdp :是否对T5文本编码器使用全分片数据并行(FSDP),false表示单卡加载T5模型。
    • t5_cpu :是否将T5模型放在CPU上运行,false表示运行在GPU上(推荐,速度更快)。
  • dit_fsdp: false

    • 作用 :是否对扩散Transformer(DIT,Diffusion Transformer)使用FSDP,false表示单卡运行。
4. 提示与生成控制
  • use_prompt_extend: false

    • 作用 :是否启用提示扩展功能(增强文本提示的语义丰富度)。设为false时,直接使用输入的文本提示,不进行扩展。
  • prompt_extend_method: "local_qwen"prompt_extend_model: null

    • 作用 :提示扩展方法指定为本地Qwen模型,但prompt_extend_model设为null表示未加载该模型,因此扩展功能实际未启用。
  • prompt_extend_target_lang: "ch"

    • 作用:提示扩展的目标语言为中文,若启用扩展功能,会将中文提示转换为更复杂的语义表示。
5. 随机种子与生成算法
  • base_seed: 40
    • 作用:随机种子,用于复现相同的生成结果。固定种子后,相同提示和参数下生成的视频内容一致。

二、分辨率宽度确认

  • 832*480 中,832是宽度,480是高度
    • 分辨率的表示规则为 宽度×高度 (Width×Height),例如:
      • 1080p是1920×1080(宽1920,高1080),
      • 这里的832×480接近16:9的比例(832÷480≈1.733,接近16:9的1.777)。

采样参数设置

一、核心参数解析

1. sample_solver: "unipc"
  • 作用 :指定采样算法(扩散模型生成视频的核心求解器)。
    • UniPC(Unified Predictor-Corrector):一种高效的数值积分方法,适用于扩散模型采样,兼顾速度与生成质量,支持动态调整步长,在较少步数下可实现较好效果。
    • 对比其他 solver:相比传统的 DDIM/PLMS 等算法,UniPC 在相同步数下生成细节更丰富,尤其适合视频生成的时空连贯性优化。
2. sample_steps: 50
  • 作用 :采样过程中执行的扩散步数(从噪声反向生成清晰样本的迭代次数)。
    • 数值影响
      • 50步:中等计算量,适合平衡速度与质量。步数不足可能导致细节模糊、动态不连贯;步数过高(如100+)会增加耗时,但收益可能边际递减。
      • 建议场景:若追求快速生成,可设为30-50;若需高保真细节(如复杂光影、精细纹理),可尝试60-80步。
3. sample_shift: 5.0
  • 作用 控制跨帧生成时的时间步长偏移或运动连贯性约束。
    • 在视频生成中,相邻帧的生成需考虑时间序列的连续性,sample_shift 可能用于调整帧间采样的时间相关性(如抑制突然运动或增强动态平滑度)。
    • 数值较高(如5.0)可能增强帧间约束,减少闪烁或跳跃,但可能限制剧烈动作的表现力;数值较低(如1.0-2.0)允许更自由的动态变化。
4. 引导尺度参数(guide_scale 系列)

引导尺度控制文本提示和参考图像对生成过程的约束强度,数值越高,生成结果越贴近输入条件,但可能导致多样性下降或过拟合。

  • sample_guide_scale: 5.0(通用引导尺度):

    • 全局控制文本+图像引导的综合强度,若未单独设置 img/text 参数,默认使用此值。
  • sample_guide_scale_img: 5.0(图像引导尺度):

    • 参考图像对生成的约束强度(适用于图生视频 s2v 任务)。
    • 5.0 含义:中等强度,生成内容会保留参考图像的视觉特征(如颜色、构图、主体形态),但允许一定程度的变化(如视角调整、动态延伸)。
  • sample_guide_scale_text: 7.5(文本引导尺度):

    • 文本提示对生成的约束强度,数值显著高于图像引导(7.5 > 5.0),表明:
      • 优先遵循文本描述:生成内容会严格匹配文本语义(如"夕阳下的海滩""机械恐龙奔跑"),可能牺牲部分图像参考的细节。
      • 风险与收益:高文本引导可能导致图像参考的视觉特征(如主体颜色、背景元素)被覆盖,需确保文本与图像语义一致(如文本描述需包含图像中的关键视觉元素)。

多组提示词+多参考图像输入

举例说明

json 复制代码
[
    {
        "prompt": "内容",
        "image_paths": ["examples/1.jpg","examples/3.jpg"]
    },
    {
        "prompt": "内容",
        "image_paths": ["examples/2.jpg","examples/3.jpg"]
    }
    ,
    {
        "prompt": "内容",
        "image_paths": ["examples/3.png","examples/8.jpg"]
    }
]

一、JSON输入结构解析

prompt.json包含3组生成任务,每组结构为:

json 复制代码
{
  "prompt": "文本提示词",       // 描述生成内容的语义(如"猫在草地跳跃")
  "image_paths": ["图1路径", "图2路径"]  // 用于主体对齐的参考图像列表(支持多图)
}

关键特点

  • 每组任务可包含 1~N张参考图像 (如examples/1.jpgexamples/3.jpg共同定义主体)。
  • 多图输入时,模型会自动融合多张图像的特征,适用于需要捕捉主体多角度、多姿态的场景(如生成人物行走视频时,用正面+侧面照片定义体型)。

二、处理多参考图像

1. 加载与预处理阶段(load_ref_images函数)
  • 输入image_paths列表(如["examples/1.jpg", "examples/3.jpg"])。
  • 处理逻辑
    1. 逐张加载图像,转换为RGB格式。
    2. 对每张图像进行保持比例缩放+中心填充 ,统一为目标尺寸(如832×480):
      • 若图像宽高比与目标尺寸不一致,先按比例缩放至最长边等于目标边,再用白色填充短边
    3. 输出ref_images列表,每张图像为PIL.Image对象,尺寸均为832×480
2. 模型生成阶段(Phantom_Wan_S2V.generate
  • 输入ref_images列表(多图) + prompt文本。
  • 核心逻辑
    1. 多图特征融合
      跨模态模型(Phantom-Wan)会提取每张参考图像的主体特征(如颜色、轮廓),并计算平均特征向量动态特征融合(根据图像顺序加权),形成对主体的综合描述。
    2. 动态对齐
      在生成视频的每一帧时,模型会同时参考所有输入图像的特征,确保主体在不同视角下的一致性(如正面图像约束面部特征,侧面图像约束身体比例)。

三、例子

场景1:复杂主体多角度定义
  • 需求:生成一个"机器人从左侧走向右侧"的视频,需要机器人正面和侧面外观一致。

  • 输入

    json 复制代码
    {
      "prompt": "银色机器人在灰色地面行走,头部有蓝色灯光",
      "image_paths": ["robot_front.jpg", "robot_side.jpg"]  // 正面+侧面图
    }
  • 效果

    • 视频中机器人正面视角时匹配robot_front.jpg的面部细节。
    • 转向侧面时匹配robot_side.jpg的身体轮廓和机械结构。
场景2:主体特征互补
  • 需求:修复单张图像缺失的细节(如证件照生成生活视频)。

  • 输入

    json 复制代码
    {
      "prompt": "穿蓝色衬衫的人在公园跑步,风吹动头发",
      "image_paths": ["id_photo.jpg", "hair_reference.jpg"]  // 证件照+发型参考图
    }
  • 效果

    • 主体面部和服装来自证件照,头发动态和颜色来自hair_reference.jpg,解决证件照中头发静止的问题。
场景3:多主体生成
  • 需求:生成"两个人握手"的视频,两人外观分别来自不同图像。

  • 输入

    json 复制代码
    {
      "prompt": "穿西装的男人和穿裙子的女人在会议室握手",
      "image_paths": ["man.jpg", "woman.jpg"]  // 两人的参考图像
    }
  • 效果

    • 模型自动识别图像中的两个主体,分别对齐到视频中的对应人物,确保两人外观与参考图像一致。

四、调用示例

1. 终端命令
bash 复制代码
python main.py --config_file config.json --prompt_file prompt.json
2. 生成结果示例

假设输入为:

json 复制代码
[
  {
    "prompt": "戴帽子的狗在雪地里打滚",
    "image_paths": ["dog_front.jpg", "dog_side.jpg"]
  }
]

生成的视频中:

  • 狗的头部特征(如眼睛、鼻子)来自dog_front.jpg
  • 身体姿态和帽子形状来自dog_side.jpg
  • 雪地、打滚动作由文本提示驱动生成。

多参考图像输入是Phantom-Wan实现复杂主体动态生成的核心能力之一,通过融合多张图像的特征。

完整代码

py 复制代码
import argparse
from datetime import datetime
import logging
import os
import sys
import warnings
import json
import time
from uuid import uuid4  # 新增:用于生成唯一标识符

warnings.filterwarnings('ignore')

import torch, random
import torch.distributed as dist
from PIL import Image, ImageOps

import phantom_wan
from phantom_wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from phantom_wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from phantom_wan.utils.utils import cache_video, cache_image, str2bool

def _validate_args(args):
    """参数验证函数"""
    # 基础检查
    assert args.ckpt_dir is not None, "请指定检查点目录"
    assert args.phantom_ckpt is not None, "请指定Phantom-Wan检查点"
    assert args.task in WAN_CONFIGS, f"不支持的任务: {args.task}"

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize)
    
    # 尺寸检查["832*480", "480*832"]
    assert args.size in SUPPORTED_SIZES[args.task], \
        f"任务{args.task}不支持尺寸{args.size},支持尺寸:{', '.join(SUPPORTED_SIZES[args.task])}"

def _parse_args():
    """参数解析函数"""
    parser = argparse.ArgumentParser(description="使用Phantom生成视频")
    parser.add_argument("--config_file", type=str, default="config.json", help="配置JSON文件路径")
    parser.add_argument("--prompt_file", type=str, default="prompt.json", help="提示词JSON文件路径")
    
    args = parser.parse_args()
    
    # 从配置文件加载参数
    with open(args.config_file, 'r') as f:
        config = json.load(f)
    for key, value in config.items():
        setattr(args, key, value)
    
    _validate_args(args)
    return args

def _init_logging(rank):
    """日志初始化函数"""
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)]
        )
    else:
        logging.basicConfig(level=logging.ERROR)

def load_ref_images(path, size):
    """加载参考图像并预处理"""
    h, w = size[1], size[0]  # 尺寸格式转换
    ref_images = []
    
    for image_path in path:
        with Image.open(image_path) as img:
            img = img.convert("RGB")
            img_ratio = img.width / img.height
            target_ratio = w / h

            # 保持比例缩放
            if img_ratio > target_ratio:
                new_width = w
                new_height = int(new_width / img_ratio)
            else:
                new_height = h
                new_width = int(new_height * img_ratio)
            
            img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
            
            # 中心填充至目标尺寸
            delta_w = w - img.size[0]
            delta_h = h - img.size[1]
            padding = (delta_w//2, delta_h//2, delta_w-delta_w//2, delta_h-delta_h//2)
            new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
            ref_images.append(new_img)
    
    return ref_images

def generate(args):
    """主生成函数"""
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    # 分布式环境配置
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size)
    
    # 模型并行配置
    if args.ulysses_size > 1 or args.ring_size > 1:
        assert args.ulysses_size * args.ring_size == world_size, "ulysses_size与ring_size乘积需等于总进程数"
        from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    # 提示词扩展初始化
    prompt_expander = None
    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model, 
                is_vl="i2v" in args.task
            )
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                is_vl="i2v" in args.task,
                device=rank
            )
        else:
            raise NotImplementedError(f"不支持的提示词扩展方法: {args.prompt_extend_method}")

    # 模型初始化(仅加载一次)
    cfg = WAN_CONFIGS[args.task]
    logging.info(f"初始化模型,任务类型: {args.task}")
    
    if "s2v" in args.task:
        # 视频生成(参考图像输入)
        wan = phantom_wan.Phantom_Wan_S2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            phantom_ckpt=args.phantom_ckpt,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )
    elif "t2v" in args.task or "t2i" in args.task:
        # 文本生成(图像/视频)
        wan = phantom_wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )
    else:
        # 图像生成视频(i2v)
        wan = phantom_wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
        )

    # 加载提示词列表
    with open(args.prompt_file, 'r') as f:
        prompts = json.load(f)
    
    total_generation_time = 0
    generation_counter = 0  # 新增:生成计数器防止文件名重复

    for prompt_info in prompts:
        prompt = prompt_info["prompt"]
        image_paths = prompt_info.get("image_paths", [])  # 处理可能不存在的键
        start_time = time.time()

        # 分布式环境同步种子
        if dist.is_initialized():
            base_seed = [args.base_seed] if rank == 0 else [None]
            dist.broadcast_object_list(base_seed, src=0)
            args.base_seed = base_seed[0]

        # 提示词扩展处理
        if args.use_prompt_extend and rank == 0:
            logging.info("正在扩展提示词...")
            if "s2v" in args.task or "i2v" in args.task and image_paths:
                img = Image.open(image_paths[0]).convert("RGB")
                prompt_output = prompt_expander(prompt, image=img, seed=args.base_seed)
            else:
                prompt_output = prompt_expander(prompt, seed=args.base_seed)
            
            if not prompt_output.status:
                logging.warning(f"提示词扩展失败: {prompt_output.message}, 使用原始提示词")
                input_prompt = prompt
            else:
                input_prompt = prompt_output.prompt
            
            # 分布式广播扩展后的提示词
            input_prompt = [input_prompt] if rank == 0 else [None]
            if dist.is_initialized():
                dist.broadcast_object_list(input_prompt, src=0)
            prompt = input_prompt[0]
            logging.info(f"扩展后提示词: {prompt}")

        # 执行生成
        logging.info(f"开始生成,提示词: {prompt}")
        if "s2v" in args.task:
            ref_images = load_ref_images(image_paths, SIZE_CONFIGS[args.size])
            video = wan.generate(
                prompt,
                ref_images,
                size=SIZE_CONFIGS[args.size],
                frame_num=args.frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale_img=args.sample_guide_scale_img,
                guide_scale_text=args.sample_guide_scale_text,
                seed=args.base_seed,
                offload_model=args.offload_model
            )
        elif "t2v" in args.task or "t2i" in args.task:
            video = wan.generate(
                prompt,
                size=SIZE_CONFIGS[args.size],
                frame_num=args.frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model
            )
        else:  # i2v任务
            img = Image.open(image_paths[0]).convert("RGB")
            video = wan.generate(
                prompt,
                img,
                max_area=MAX_AREA_CONFIGS[args.size],
                frame_num=args.frame_num,
                shift=args.sample_shift,
                sample_solver=args.sample_solver,
                sampling_steps=args.sample_steps,
                guide_scale=args.sample_guide_scale,
                seed=args.base_seed,
                offload_model=args.offload_model
            )

        # 计算生成时间
        generation_time = time.time() - start_time
        total_generation_time += generation_time
        logging.info(f"生成耗时: {generation_time:.2f}秒")

        # 主进程保存结果
        if rank == 0:
            generation_counter += 1  # 计数器递增
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            safe_prompt = prompt.replace(" ", "_").replace("/", "_")[:50]  # 安全文件名处理
            file_uuid = str(uuid4())[:8]  # 新增:添加UUID短标识
            suffix = '.png' if "t2i" in args.task else '.mp4'
            
            # 生成唯一文件名
            save_file = f"{args.task}_{args.size}_{args.ulysses_size}_{args.ring_size}_" \
                        f"{safe_prompt}_{timestamp}_{generation_counter}_{file_uuid}{suffix}"
            
            logging.info(f"保存结果到: {save_file}")
            if "t2i" in args.task:
                cache_image(
                    tensor=video.squeeze(1)[None],
                    save_file=save_file,
                    nrow=1,
                    normalize=True,
                    value_range=(-1, 1)
                )
            else:
                cache_video(
                    tensor=video[None],
                    save_file=save_file,
                    fps=cfg.sample_fps,
                    nrow=1,
                    normalize=True,
                    value_range=(-1, 1)
                )

    logging.info(f"总生成耗时: {total_generation_time:.2f}秒")
    logging.info("生成完成")

if __name__ == "__main__":
    args = _parse_args()
    generate(args)