最新阿里开源视频生成框架Tora部署

Tora是由阿里团队推出的一种基于轨迹导向的扩散变换器(Diffusion Transformer, DiT)技术的AI视频生成框架。

Tora在生成过程中可以接受多种形式的输入,包括文字描述、图片或物体移动的路线,并据此制作出既真实又流畅的视频。

通过引入轨迹控制机制,Tora能够更精确地控制视频中物体的运动模式,解决了现有模型难以生成运动精确且一致的视频的问题。

Tora采用两阶段训练过程,首先使用密集光流进行训练,然后使用稀疏轨迹进行微调,以提高模型对各种类型轨迹数据的适应性。

Tora模型支持长达204帧、720p分辨率的视频制作,适用于影视制作、动画创作、虚拟现实(VR)、增强现实(AR)及游戏开发等多个领域。

GitHub项目地址:https://github.com/alibaba/Tora

一、环境安装

1、python环境

建议安装Python 3.10及以上版本。

2、pip库安装

pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu118

(以下命令需先克隆上述Tora仓库,并在仓库根目录下执行)

cd modules/SwissArmyTransformer

pip install -e .

cd ../../sat

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

3、CogVideoX-5b模型下载

git lfs install

git clone https://www.modelscope.cn/AI-ModelScope/CogVideoX-5b.git

4、Tora t2v模型下载

https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/Tora_t2v/mp_rank_00_model_states.pt

二、功能测试

1、运行测试

(1)python代码调用测试

示例代码如下:
import argparse
import gc
import json
import math
import os
import pickle
from pathlib import Path
from typing import List, Union

import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
from einops import rearrange, repeat
from omegaconf import ListConfig
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import tqdm
from utils.flow_utils import process_traj
from utils.misc import vis_tensor

from sat import mpu
from sat.arguments import set_random_seed
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint

def read_from_cli():
    """Yield ``(prompt, index)`` pairs typed interactively on standard input.

    Every line returned by ``input`` is stripped of surrounding whitespace and
    paired with a zero-based counter. The generator ends quietly once
    ``EOFError`` is raised (for example when the user presses Ctrl-D).
    """
    index = 0
    try:
        while True:
            raw_text = input("Please input English text (Ctrl-D quit): ")
            yield raw_text.strip(), index
            index += 1
    except EOFError:
        # End of input is the regular way to stop; there is nothing to clean up.
        return

def read_from_file(p, rank=0, world_size=1):
    """Yield ``(line, line_index)`` for the lines of file *p* owned by *rank*.

    Lines are dealt out round-robin: the line with zero-based index ``i`` is
    yielded only when ``i % world_size == rank``. The yielded text is stripped
    of surrounding whitespace; the index is the position in the whole file.
    """
    with open(p, "r") as prompt_file:
        for line_index, raw_line in enumerate(prompt_file):
            if line_index % world_size == rank:
                yield raw_line.strip(), line_index

def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders`` as a list.

    Duplicates are removed through a set, so the order of the returned list is
    not guaranteed.
    """
    unique_keys = {embedder.input_key for embedder in conditioner.embedders}
    return list(unique_keys)

def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    """Build the conditional and unconditional input batches.

    Parameters:
        keys: Iterable of input keys to fill in.
        value_dict: Mapping that provides ``"prompt"`` and ``"negative_prompt"``
            for the ``"txt"`` key, and a value under every other requested key.
        N: Target shape of the text batch; the prompt is repeated
            ``math.prod(N)`` times and reshaped to ``N``.
        T: Optional value stored under ``"num_video_frames"`` in the
            conditional batch.
        device: Accepted for interface compatibility; not used by this function.

    Returns:
        tuple: ``(batch, batch_uc)``. ``batch_uc`` holds the negative prompt
        under ``"txt"`` plus a clone of every tensor entry of ``batch`` that it
        does not already contain.
    """
    cond_batch = {}
    uncond_batch = {}

    for key in keys:
        if key != "txt":
            # Non-text entries are passed through untouched.
            cond_batch[key] = value_dict[key]
            continue
        copies = math.prod(N)
        cond_batch["txt"] = np.repeat([value_dict["prompt"]], repeats=copies).reshape(N).tolist()
        uncond_batch["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=copies).reshape(N).tolist()

    if T is not None:
        cond_batch["num_video_frames"] = T

    # Tensor entries are shared with the unconditional batch as independent copies.
    for key, value in cond_batch.items():
        if key not in uncond_batch and isinstance(value, torch.Tensor):
            uncond_batch[key] = torch.clone(value)
    return cond_batch, uncond_batch

def draw_points(video, points):
    """
    Draw trajectories and their per-frame positions onto video frames.

    The full path of every point is rendered once onto a separate canvas and
    alpha-blended into each frame; the position of every point at frame ``t``
    is then marked with a filled circle on that frame.

    Parameters:
        video (torch.tensor): Video tensor with shape [T, H, W, C], where T is the number of frames,
                            H is the height, W is the width, and C is the number of channels.
        points (list): Positions of points to be drawn, indexable as [N, T, 2],
                            each point contains x and y coordinates.

    Returns:
        torch.tensor: The video tensor after drawing points, maintaining the same shape [T, H, W, C],
                            on the device and with the dtype of the input.
    """

    T = video.shape[0]
    N = len(points)
    device = video.device
    dtype = video.dtype
    video = video.cpu().numpy().copy()

    # Render all trajectories onto one canvas. The last colour channel is
    # non-zero so that it can serve as the "pixel was drawn" mask below.
    traj = np.zeros(video.shape[-3:], dtype=np.uint8)  # [H, W, C]
    for n in range(N):
        for t in range(1, T):
            cv2.line(traj, tuple(points[n][t - 1]), tuple(points[n][t]), (255, 1, 1), 2)

    # The canvas is not modified after this point, so the mask and the weighted
    # overlay are the same for every frame and are computed only once.
    mask = traj[..., -1] > 0
    mask = repeat(mask, "h w -> h w c", c=3)
    alpha = 0.7
    overlay = traj[mask] * alpha

    for t in range(T):
        video[t][mask] = video[t][mask] * (1 - alpha) + overlay
        for n in range(N):
            cv2.circle(video[t], tuple(points[n][t]), 3, (160, 230, 100), -1)
    video = torch.from_numpy(video).to(device, dtype)
    return video

def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor,
    save_path: str,
    name: str,
    fps: int = 5,
    args=None,
    key=None,
    traj_points=None,
    prompt="",
):
    """Write every video of a batch to disk together with its prompt.

    For the ``i``-th item of ``video_batch`` the following files are written
    below ``save_path`` (``<stem>`` is ``{name}_{i:06d}``):

    * ``video/<stem>.mp4``: the frames, scaled by 255 and converted to uint8.
    * ``prompt/<stem>.txt``: the ``prompt`` text.
    * ``traj/<stem>.pkl`` and ``traj_video/<stem>.mp4``: only when
      ``traj_points`` is given; the pickled points and the frames with the
      trajectories drawn by ``draw_points``.

    ``args`` and ``key`` are accepted for interface compatibility and are not
    used here.
    """

    def _encode_mp4(target, frames):
        # All clips share the same encoder settings.
        write_video(
            target,
            frames,
            fps=fps,
            video_codec="libx264",
            options={"crf": "18"},
        )

    os.makedirs(save_path, exist_ok=True)
    root = Path(save_path)
    with_traj = traj_points is not None

    for idx, clip in enumerate(video_batch):
        stem = f"{name}_{idx:06d}"
        frames = rearrange(clip, "t c h w -> t h w c")
        frames = frames.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8)  # [T H W C]

        os.makedirs(root / "video", exist_ok=True)
        os.makedirs(root / "prompt", exist_ok=True)
        if with_traj:
            os.makedirs(root / "traj", exist_ok=True)
            os.makedirs(root / "traj_video", exist_ok=True)

        _encode_mp4(root / "video" / f"{stem}.mp4", frames)

        if with_traj:
            with open(root / "traj" / f"{stem}.pkl", "wb") as f:
                pickle.dump(traj_points, f)
            frames = draw_points(frames, traj_points)
            _encode_mp4(root / "traj_video" / f"{stem}.mp4", frames)

        with open(root / "prompt" / f"{stem}.txt", "w") as f:
            f.write(prompt)

def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    """Resize ``arr`` so it covers ``image_size`` and crop it to exactly that size.

    Parameters:
        arr: Tensor whose dimensions 2 and 3 are read as height and width;
            dimension 0 is squeezed away before cropping
            (presumably shaped [1, C, H, W]; confirm against the caller).
        image_size: ``[height, width]`` of the result.
        reshape_mode: ``"random"`` or ``"none"`` pick a random crop offset,
            ``"center"`` crops around the middle.

    Returns:
        The cropped tensor with height ``image_size[0]`` and width ``image_size[1]``.

    Raises:
        NotImplementedError: For any other ``reshape_mode``.
    """
    target_h, target_w = image_size[0], image_size[1]
    src_h, src_w = arr.shape[2], arr.shape[3]

    # Scale along the side that keeps the other side at least as large as the
    # target, so the crop below never has to pad.
    if src_w / src_h > target_w / target_h:
        new_size = [target_h, int(src_w * target_h / src_h)]
    else:
        new_size = [int(src_h * target_w / src_w), target_w]
    arr = resize(arr, size=new_size, interpolation=InterpolationMode.BICUBIC)

    resized_h, resized_w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    spare_h = resized_h - target_h
    spare_w = resized_w - target_w

    if reshape_mode in ("random", "none"):
        top = np.random.randint(0, spare_h + 1)
        left = np.random.randint(0, spare_w + 1)
    elif reshape_mode == "center":
        top = spare_h // 2
        left = spare_w // 2
    else:
        raise NotImplementedError
    return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)

def sampling_main(args, model_cls):
    """Sample videos for every prompt and write them to ``args.output_dir``.

    Steps performed per prompt:

    1. Build the optional motion condition (``video_flow``) from a flow file
       named in the prompt line, from ``args.flow_path`` or from
       ``args.point_path``, and encode it with the first-stage model.
    2. Compute the conditional / unconditional embeddings of the prompt.
    3. Run ``model.sample`` ``args.batch_size`` times, decode the latents in
       small temporal chunks and save the result with
       ``save_video_as_grid_and_mp4``.

    Parameters:
        args: Parsed arguments; the attributes read here include
            ``input_type``, ``input_file``, ``sampling_num_frames``,
            ``latent_channels``, ``seed``, ``flow_from_prompt``,
            ``no_flow_injection``, ``flow_path``, ``point_path``,
            ``vis_traj_features``, ``batch_size``, ``output_dir`` and
            ``sampling_fps``.
        model_cls: Either a model class (instantiated via ``get_model``) or an
            already constructed model.

    Raises:
        NotImplementedError: If ``args.input_type`` is neither ``"cli"`` nor ``"txt"``.
    """
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    # Only the "txt" branch queries the data-parallel layout. Keep a defined
    # default so the progress message below also works for interactive input
    # (it previously raised NameError there).
    rank = 0
    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    sample_func = model.sample
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            set_random_seed(args.seed)
            # Trajectory points are only produced by the point_path branch;
            # every other branch saves the video without a trajectory overlay.
            points = None
            if args.flow_from_prompt:
                # The prompt line carries "<text>\t<flow file>".
                text, flow_files = text.split("\t")
            # T is the latent length; e.g. T = 13 gives (13 - 1) * 4 + 1 = 49 frames.
            total_num_frames = (T - 1) * 4 + 1
            if args.no_flow_injection:
                video_flow = None
            elif args.flow_from_prompt:
                assert args.flow_path is not None, "Flow path must be provided if flow_from_prompt is True"
                p = os.path.join(args.flow_path, flow_files)
                print(f"Flow path: {p}")
                video_flow = (
                    torch.load(p, map_location="cpu", weights_only=True)[:total_num_frames].unsqueeze_(0).cuda()
                )
            elif args.flow_path:
                print(f"Flow path: {args.flow_path}")
                video_flow = torch.load(args.flow_path, map_location=device, weights_only=True)[
                    :total_num_frames
                ].unsqueeze_(0)
            elif args.point_path:
                # The path may arrive as a JSON string; it is decoded once and
                # stored back on args so later prompts reuse the parsed value.
                if isinstance(args.point_path, str):
                    args.point_path = json.loads(args.point_path)
                print(f"Point path: {args.point_path}")
                video_flow, points = process_traj(args.point_path, total_num_frames, image_size, device=device)
                video_flow = video_flow.unsqueeze_(0)
            else:
                print("No flow injection")
                video_flow = None

            if video_flow is not None:
                model.to("cpu")  # move model to cpu, run vae on gpu only.
                tmp = rearrange(video_flow[0], "T H W C -> T C H W")
                video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]
                if args.vis_traj_features:
                    os.makedirs("samples/flow", exist_ok=True)
                    vis_tensor(tmp, *tmp.shape[-2:], "samples/flow/flow1_vis.gif")
                    imageio.mimwrite(
                        "samples/flow/flow2_vis.gif",
                        rearrange(video_flow[0], "T C H W -> T H W C").cpu(),
                        fps=8,
                        loop=0,
                    )
                del tmp
                # Map the flow image from [0, 255] to [-1, 1] before encoding.
                video_flow = (
                    rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
                )
                torch.cuda.empty_cache()
                video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition
                model.first_stage_model.to(device)
                video_flow = model.encode_first_stage(video_flow, None)
                video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
                model.to(device)

            print("rank:", rank, "start to process", text, cnt)
            # TODO: broadcast image2video
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                    video_flow=video_flow,
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the model from GPU to save GPU memory
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model
                first_stage_model = first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode latent serial to save GPU memory
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    # The first chunk covers three latent frames, every later
                    # chunk two.
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    # The decoder cache is cleared together with the last chunk.
                    if i == loop_num - 1:
                        clear_fake_cp_cache = True
                    else:
                        clear_fake_cp_cache = False
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                        )

                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                # Map decoder output from [-1, 1] to [0, 1].
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                save_path = args.output_dir
                name = str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:60] + f"_{index}_seed{args.seed}"
                if args.flow_from_prompt:
                    name = Path(flow_files).stem
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(
                        samples,
                        save_path,
                        name,
                        fps=args.sampling_fps,
                        traj_points=points,
                        prompt=text,
                    )
            # Drop the large per-prompt tensors before the next prompt is processed.
            del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
            gc.collect()

if __name__ == "__main__":
    # When the launcher exports OMPI_COMM_WORLD_* variables, mirror them to the
    # LOCAL_RANK / WORLD_SIZE / RANK names.
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        for target_name, source_name in (
            ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"),
            ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE"),
            ("RANK", "OMPI_COMM_WORLD_RANK"),
        ):
            os.environ[target_name] = os.environ[source_name]

    # This parser declares no options of its own, so every command-line
    # argument is handed on to get_args().
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()

    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config

    # Overrides applied to the model configuration before sampling.
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
    args.model_config.en_and_decode_n_samples_a_time = 1

    sampling_main(args, model_cls=SATVideoDiffusionEngine)

未完......

更多详细内容,欢迎关注:杰哥新技术

相关推荐
Billy_Zuo6 小时前
人工智能机器学习——模型评价及优化
android·人工智能·机器学习
Web3_Daisy7 小时前
克隆代币 + 捆绑开盘:多链环境下的低成本发币玩法
人工智能·web3·区块链·比特币
sali-tec7 小时前
C# 基于halcon的视觉工作流-章32-线线测量
开发语言·人工智能·算法·计算机视觉·c#
七牛云行业应用7 小时前
图灵奖得主萨顿演讲解读:深度学习的局限与AI新范式
人工智能·深度学习
IMER SIMPLE7 小时前
人工智能-python-深度学习-神经网络VGG(详解)
人工智能·python·深度学习
martinzh7 小时前
检索器江湖:那些让RAG神功大成的武林绝学
人工智能
Juchecar7 小时前
通过“单词补全”演示 Transformer 原理(Python代码可运行)
人工智能·python
禁默8 小时前
第六届机器学习与计算机应用国际学术会议
运维·人工智能·机器学习·自动化
念念01078 小时前
基于机器学习的P2P网贷平台信用违约预测模型
人工智能·机器学习
悟乙己8 小时前
机器学习超参数调优全方法介绍指南
人工智能·机器学习·超参数