最新阿里开源视频生成框架Tora部署

Tora是由阿里团队推出的一种基于轨迹导向的扩散变换器(Diffusion Transformer, DiT)技术的AI视频生成框架。

Tora在生成过程中可以接受多种形式的输入,包括文字描述、图片或物体移动的路线,并据此制作出既真实又流畅的视频。

通过引入轨迹控制机制,Tora能够更精确地控制视频中物体的运动模式,解决了现有模型难以生成运动精确且一致的视频这一问题。

Tora采用两阶段训练过程,首先使用密集光流进行训练,然后使用稀疏轨迹进行微调,以提高模型对各种类型轨迹数据的适应性。

Tora模型支持长达204帧、720p分辨率的视频制作,适用于影视制作、动画创作、虚拟现实(VR)、增强现实(AR)及游戏开发等多个领域。

GitHub项目地址:https://github.com/alibaba/Tora

一、环境安装

1、python环境

建议安装3.10及以上版本的Python。

2、pip库安装(以下命令需先克隆项目并在Tora项目根目录下执行:git clone https://github.com/alibaba/Tora.git,然后 cd Tora)

pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu118

cd modules/SwissArmyTransformer

pip install -e .

cd ../../sat

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

3、CogVideoX-5b模型下载

git lfs install

git clone https://www.modelscope.cn/AI-ModelScope/CogVideoX-5b.git

4、Tora t2v模型下载

https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/Tora_t2v/mp_rank_00_model_states.pt

二、功能测试

1、运行测试

(1)python代码调用测试

示例代码如下:
import argparse
import gc
import json
import math
import os
import pickle
from pathlib import Path
from typing import List, Union

import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
from einops import rearrange, repeat
from omegaconf import ListConfig
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import tqdm
from utils.flow_utils import process_traj
from utils.misc import vis_tensor

from sat import mpu
from sat.arguments import set_random_seed
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint

def read_from_cli():
    """Yield prompts typed interactively on standard input.

    Each yielded item is a ``(text, index)`` tuple: ``text`` is the entered line
    with surrounding whitespace stripped and ``index`` counts prompts from 0.
    The generator finishes cleanly once ``input`` raises ``EOFError`` (Ctrl-D).
    """
    cnt = 0
    while True:
        # Guard only the read itself, so an EOFError raised by the code that
        # consumes this generator is not silently swallowed here.
        try:
            x = input("Please input English text (Ctrl-D quit): ")
        except EOFError:
            return
        yield x.strip(), cnt
        cnt += 1

def read_from_file(p, rank=0, world_size=1):
    """Yield this worker's share of the prompts stored one per line in file ``p``.

    Lines are numbered from 0; only those whose number modulo ``world_size``
    equals ``rank`` are produced. Each item is ``(stripped_line, line_number)``.
    """
    with open(p, "r") as handle:
        for line_no, raw in enumerate(handle):
            # Round-robin sharding: every worker takes every ``world_size``-th line.
            if line_no % world_size == rank:
                yield raw.strip(), line_no

def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders`` as a list.

    Duplicates are removed through a set, so the order of the result is not
    tied to the order of the embedders.
    """
    unique_keys = {embedder.input_key for embedder in conditioner.embedders}
    return list(unique_keys)

def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    """Build the conditional and unconditional input batches for the conditioner.

    Args:
        keys: Embedder input keys to populate.
        value_dict: Source values; ``"prompt"`` and ``"negative_prompt"`` are read
            for the ``"txt"`` key, every other key is looked up directly.
        N: Shape of the text batch; the prompt is replicated ``math.prod(N)``
            times and reshaped to ``N``.
        T: When not ``None``, stored in the conditional batch as
            ``"num_video_frames"``.
        device: Accepted for interface compatibility; not used in this function.

    Returns:
        ``(batch, batch_uc)``. ``batch_uc`` holds the negative prompt under
        ``"txt"`` plus a clone of every tensor in ``batch`` it does not already
        contain.
    """

    def _tile(text):
        # Replicate one string into a nested list of shape ``N``.
        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()

    batch, batch_uc = {}, {}

    for key in keys:
        if key != "txt":
            batch[key] = value_dict[key]
            continue
        batch["txt"] = _tile(value_dict["prompt"])
        batch_uc["txt"] = _tile(value_dict["negative_prompt"])

    if T is not None:
        batch["num_video_frames"] = T

    # Tensor-valued entries are shared with the unconditional branch as copies.
    for key, value in batch.items():
        if key not in batch_uc and isinstance(value, torch.Tensor):
            batch_uc[key] = torch.clone(value)
    return batch, batch_uc

def draw_points(video, points):
    """
    Draw trajectories and their current positions onto video frames.

    Every trajectory is drawn as a polyline that is alpha-blended into all
    frames, and each frame additionally gets a filled circle at every point's
    position for that frame.

    Parameters:
        video (torch.tensor): Video tensor with shape [T, H, W, C], where T is the number of frames,
                            H is the height, W is the width, and C is the number of channels.
        points (list): Positions of points to be drawn, indexable as [N][T] with each entry
                            holding the x and y coordinates of one point in one frame.

    Returns:
        torch.tensor: The video tensor after drawing, with the same shape, device and dtype
                            as the input.
    """

    num_frames = video.shape[0]
    device, dtype = video.device, video.dtype
    frames = video.cpu().numpy().copy()

    # Rasterise all trajectories once onto a blank [H, W, C] canvas.
    traj = np.zeros(frames.shape[-3:], dtype=np.uint8)
    for track in points:
        for t in range(1, num_frames):
            cv2.line(traj, tuple(track[t - 1]), tuple(track[t]), (255, 1, 1), 2)

    # The line colour is non-zero in its last channel, so that channel alone
    # marks the drawn pixels. Mask and weighted overlay are the same for every
    # frame and are therefore computed once instead of per frame.
    mask = repeat(traj[..., -1] > 0, "h w -> h w c", c=3)
    alpha = 0.7
    overlay = traj[mask] * alpha

    for t in range(num_frames):
        frames[t][mask] = frames[t][mask] * (1 - alpha) + overlay
        for track in points:
            cv2.circle(frames[t], tuple(track[t]), 3, (160, 230, 100), -1)

    return torch.from_numpy(frames).to(device, dtype)

def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor,
    save_path: str,
    name: str,
    fps: int = 5,
    args=None,
    key=None,
    traj_points=None,
    prompt="",
):
    """Write every video of a batch to disk together with its prompt.

    For sample ``i`` the following files are written below ``save_path``:
    ``video/{name}_{i:06d}.mp4`` and ``prompt/{name}_{i:06d}.txt``; when
    ``traj_points`` is given, additionally ``traj/{name}_{i:06d}.pkl`` (the
    pickled points) and ``traj_video/{name}_{i:06d}.mp4`` (the video with the
    trajectories drawn on top).

    Args:
        video_batch: Iterable of videos, each laid out as ``t c h w``; values are
            scaled by 255 and clamped to ``[0, 255]`` before encoding.
        save_path: Output root directory, created if missing.
        name: File name prefix.
        fps: Frame rate of the encoded videos.
        args: Not used in this function.
        key: Not used in this function.
        traj_points: Optional trajectory points passed to ``draw_points``.
        prompt: Text written to the prompt file.
    """
    os.makedirs(save_path, exist_ok=True)
    root = Path(save_path)
    with_traj = traj_points is not None

    def _encode(target, frames):
        # One place for the encoder settings shared by all written videos.
        write_video(target, frames, fps=fps, video_codec="libx264", options={"crf": "18"})

    for i, vid in enumerate(video_batch):
        stem = f"{name}_{i:06d}"
        frames = rearrange(vid, "t c h w -> t h w c")
        frames = frames.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8)  # [T H W C]

        os.makedirs(root / "video", exist_ok=True)
        os.makedirs(root / "prompt", exist_ok=True)
        if with_traj:
            os.makedirs(root / "traj", exist_ok=True)
            os.makedirs(root / "traj_video", exist_ok=True)

        # The plain video is written in both cases.
        _encode(root / "video" / f"{stem}.mp4", frames)

        if with_traj:
            with open(root / "traj" / f"{stem}.pkl", "wb") as f:
                pickle.dump(traj_points, f)
            frames = draw_points(frames, traj_points)
            _encode(root / "traj_video" / f"{stem}.mp4", frames)

        with open(root / "prompt" / f"{stem}.txt", "w") as f:
            f.write(prompt)

def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    """Resize ``arr`` to cover ``image_size`` and crop it to exactly that size.

    ``arr.shape[2]`` and ``arr.shape[3]`` are used as height and width. The
    input is scaled with bicubic interpolation so that one side equals the
    target and the other keeps the aspect ratio, then a window of
    ``image_size`` (height, width) is cut out.

    Args:
        arr: Input tensor with at least four dimensions; dimension 0 is squeezed
            before cropping.
        image_size: Target ``(height, width)``.
        reshape_mode: ``"random"`` or ``"none"`` picks a random crop offset,
            ``"center"`` crops around the middle.

    Returns:
        The cropped tensor.

    Raises:
        NotImplementedError: For any other ``reshape_mode``.
    """
    target_h, target_w = image_size[0], image_size[1]
    src_h, src_w = arr.shape[2], arr.shape[3]

    # Relatively wider than the target: match heights; otherwise match widths.
    if src_w / src_h > target_w / target_h:
        new_size = [target_h, int(src_w * target_h / src_h)]
    else:
        new_size = [int(src_h * target_w / src_w), target_w]
    arr = resize(arr, size=new_size, interpolation=InterpolationMode.BICUBIC)

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h, delta_w = h - target_h, w - target_w

    if reshape_mode in ("random", "none"):
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top = delta_h // 2
        left = delta_w // 2
    else:
        raise NotImplementedError
    return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)

def sampling_main(args, model_cls):
    """Sample videos for every prompt and save them under ``args.output_dir``.

    Prompts are read from standard input (``args.input_type == "cli"``) or from
    ``args.input_file`` (``"txt"``). For each prompt an optional motion
    condition is prepared (a flow tensor loaded from disk, or one built from
    ``args.point_path``), the text conditioning is computed, and
    ``args.batch_size`` samples are drawn, decoded in small latent chunks and
    written with ``save_video_as_grid_and_mp4``.

    Args:
        args: Run configuration. Attributes read here: ``input_type``,
            ``input_file``, ``sampling_num_frames``, ``latent_channels``, ``seed``,
            ``flow_from_prompt``, ``no_flow_injection``, ``flow_path``,
            ``point_path``, ``vis_traj_features``, ``batch_size``, ``output_dir``
            and ``sampling_fps``. ``args.point_path`` is replaced by its parsed
            JSON value when it is given as a string.
        model_cls: Either a model class, instantiated through ``get_model``, or an
            already constructed model that is used as is.

    Raises:
        NotImplementedError: If ``args.input_type`` is neither ``"cli"`` nor
            ``"txt"``.
    """
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    # ``rank`` is logged for every prompt below, so it must be bound in both
    # input modes. The CLI reader does not shard prompts; 0 is a placeholder.
    rank = 0
    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    sample_func = model.sample
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            set_random_seed(args.seed)
            # Trajectory points only exist when the flow is built from ``args.point_path``.
            points = None
            if args.flow_from_prompt:
                # Each input line then carries "<prompt>\t<flow file>".
                text, flow_files = text.split("\t")
            # T is the video latent size, e.g. (13 - 1) * 4 + 1 = 49 frames.
            total_num_frames = (T - 1) * 4 + 1
            if args.no_flow_injection:
                video_flow = None
            elif args.flow_from_prompt:
                assert args.flow_path is not None, "Flow path must be provided if flow_from_prompt is True"
                p = os.path.join(args.flow_path, flow_files)
                print(f"Flow path: {p}")
                video_flow = (
                    torch.load(p, map_location="cpu", weights_only=True)[:total_num_frames].unsqueeze_(0).cuda()
                )
            elif args.flow_path:
                print(f"Flow path: {args.flow_path}")
                video_flow = torch.load(args.flow_path, map_location=device, weights_only=True)[
                    :total_num_frames
                ].unsqueeze_(0)
            elif args.point_path:
                if isinstance(args.point_path, str):
                    args.point_path = json.loads(args.point_path)
                print(f"Point path: {args.point_path}")
                video_flow, points = process_traj(args.point_path, total_num_frames, image_size, device=device)
                video_flow = video_flow.unsqueeze_(0)
            else:
                print("No flow injection")
                video_flow = None

            if video_flow is not None:
                model.to("cpu")  # move model to cpu, run vae on gpu only.
                tmp = rearrange(video_flow[0], "T H W C -> T C H W")
                video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]
                if args.vis_traj_features:
                    os.makedirs("samples/flow", exist_ok=True)
                    vis_tensor(tmp, *tmp.shape[-2:], "samples/flow/flow1_vis.gif")
                    imageio.mimwrite(
                        "samples/flow/flow2_vis.gif",
                        rearrange(video_flow[0], "T C H W -> T H W C").cpu(),
                        fps=8,
                        loop=0,
                    )
                del tmp
                # Map the flow image from [0, 255] to [-1, 1] before encoding it.
                video_flow = (
                    rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
                )
                torch.cuda.empty_cache()
                video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition
                model.first_stage_model.to(device)
                video_flow = model.encode_first_stage(video_flow, None)
                video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
                model.to(device)

            print("rank:", rank, "start to process", text, cnt)
            # TODO: broadcast image2video
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(item) for item in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                    video_flow=video_flow,
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the model from GPU to save GPU memory
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model
                first_stage_model = first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode latent serial to save GPU memory: the first chunk covers
                # latent frames 0..2, every later chunk the next two frames.
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    # The decoder cache is cleared only together with the last chunk.
                    clear_fake_cp_cache = i == loop_num - 1
                    recon = first_stage_model.decode(
                        latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                    )

                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                # Map decoder output from [-1, 1] to [0, 1].
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                save_path = args.output_dir
                name = str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:60] + f"_{index}_seed{args.seed}"
                if args.flow_from_prompt:
                    name = Path(flow_files).stem
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(
                        samples,
                        save_path,
                        name,
                        fps=args.sampling_fps,
                        traj_points=points,
                        prompt=text,
                    )
                # Per-sample tensors are released here, so the per-prompt cleanup
                # below never refers to names that were not bound.
                del samples_z, samples_x, samples, latent, recon, recons
            del video_flow, c, uc, batch, batch_uc
            gc.collect()

if __name__ == "__main__":
    # When the Open MPI local-rank variable is present, copy the Open MPI rank
    # variables to LOCAL_RANK / WORLD_SIZE / RANK.
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        for target, source in (
            ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"),
            ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE"),
            ("RANK", "OMPI_COMM_WORLD_RANK"),
        ):
            os.environ[target] = os.environ[source]

    # The local parser declares no options; every argument is forwarded to get_args.
    known, args_list = argparse.ArgumentParser(add_help=False).parse_known_args()
    args = argparse.Namespace(**vars(get_args(args_list)), **vars(known))
    del args.deepspeed_config

    # Configuration overrides applied before sampling.
    model_config = args.model_config
    model_config.first_stage_config.params.cp_size = 1
    transformer_args = model_config.network_config.params.transformer_args
    transformer_args.model_parallel_size = 1
    transformer_args.checkpoint_activations = False
    model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
    model_config.en_and_decode_n_samples_a_time = 1

    sampling_main(args, model_cls=SATVideoDiffusionEngine)

未完......

更多详细内容,欢迎关注:杰哥新技术

相关推荐
华新嘉华DTC创新营销1 小时前
华新嘉华:AI搜索优化重塑本地生活行业:智能推荐正取代“关键词匹配”
人工智能·百度·生活
SmartBrain2 小时前
DeerFlow 实践:华为IPD流程的评审智能体设计
人工智能·语言模型·架构
l1t3 小时前
利用DeepSeek实现服务器客户端模式的DuckDB原型
服务器·c语言·数据库·人工智能·postgresql·协议·duckdb
寒月霜华4 小时前
机器学习-数据标注
人工智能·机器学习
九章云极AladdinEdu5 小时前
超参数自动化调优指南:Optuna vs. Ray Tune 对比评测
运维·人工智能·深度学习·ai·自动化·gpu算力
人工智能训练师6 小时前
Ubuntu22.04如何安装新版本的Node.js和npm
linux·运维·前端·人工智能·ubuntu·npm·node.js
cxr8288 小时前
SPARC方法论在Claude Code基于规则驱动开发中的应用
人工智能·驱动开发·claude·智能体
研梦非凡8 小时前
ICCV 2025|从粗到细:用于高效3D高斯溅射的可学习离散小波变换
人工智能·深度学习·学习·3d
幂简集成8 小时前
Realtime API 语音代理端到端接入全流程教程(含 Demo,延迟 280ms)
人工智能·个人开发
龙腾-虎跃9 小时前
FreeSWITCH FunASR语音识别模块
人工智能·语音识别·xcode