Tora是由阿里团队推出的一种基于轨迹导向的扩散变换器(Diffusion Transformer, DiT)技术的AI视频生成框架。
Tora在生成过程中可以接受多种形式的输入,包括文字描述、图片或物体移动的路线,并据此制作出既真实又流畅的视频。
通过引入轨迹控制机制,Tora能够更精确地控制视频中物体的运动模式,解决了现有模型难以生成运动精确且一致的视频的问题。
Tora采用两阶段训练过程,首先使用密集光流进行训练,然后使用稀疏轨迹进行微调,以提高模型对各种类型轨迹数据的适应性。
Tora模型支持长达204帧、720p分辨率的视频制作,适用于影视制作、动画创作、虚拟现实(VR)、增强现实(AR)及游戏开发等多个领域。
GitHub项目地址:https://github.com/alibaba/Tora
一、环境安装
1、python环境
建议安装python版本在3.10以上。
2、pip库安装
pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu118
git clone https://github.com/alibaba/Tora.git
cd Tora
cd modules/SwissArmyTransformer
pip install -e .
cd ../../sat
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
3、CogVideoX-5b模型下载:
git lfs install
git clone https://www.modelscope.cn/AI-ModelScope/CogVideoX-5b.git
4、Tora t2v模型下载:
https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/Tora_t2v/mp_rank_00_model_states.pt
二、功能测试
1、运行测试:
(1)python代码调用测试
import argparse
import gc
import json
import math
import os
import pickle
from pathlib import Path
from typing import List, Union
import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
from einops import rearrange, repeat
from omegaconf import ListConfig
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import tqdm
from utils.flow_utils import process_traj
from utils.misc import vis_tensor
from sat import mpu
from sat.arguments import set_random_seed
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
def read_from_cli():
    """Yield prompts typed on standard input, paired with a running index.

    Each line is read with ``input()``, stripped of surrounding whitespace and
    yielded as ``(text, index)``; the index starts at 0 and grows by one per
    prompt.  The generator finishes quietly once ``input()`` raises
    ``EOFError`` (Ctrl-D on an interactive terminal).
    """
    cnt = 0
    while True:
        # Only the read itself can signal end-of-input, so keep the ``try``
        # around that single call instead of around the ``yield`` as well.
        try:
            x = input("Please input English text (Ctrl-D quit): ")
        except EOFError:
            return
        yield x.strip(), cnt
        cnt += 1
def read_from_file(p, rank=0, world_size=1):
    """Yield this rank's share of the lines in a prompt file.

    Lines are numbered from 0.  Line ``i`` is yielded only when
    ``i % world_size == rank``, so several workers reading the same file each
    receive a disjoint, interleaved subset.

    Parameters:
        p: Path of the text file to read.
        rank (int): Index of the current worker.
        world_size (int): Total number of workers sharing the file.

    Yields:
        tuple: ``(stripped_line, line_index)``; the index is the position in
        the whole file, not within this rank's subset.
    """
    # Read as UTF-8 explicitly so non-ASCII prompts do not depend on the
    # platform's default locale encoding.
    with open(p, "r", encoding="utf-8") as fin:
        for cnt, line in enumerate(fin):
            if cnt % world_size != rank:
                continue
            yield line.strip(), cnt
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    Duplicates are dropped and the remaining keys keep the order in which
    they first appear, so the result is the same on every run and every
    process (a plain ``set`` would reorder string keys under hash
    randomisation).
    """
    return list(dict.fromkeys(embedder.input_key for embedder in conditioner.embedders))
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    """Build the conditional and unconditional input batches for the conditioner.

    Parameters:
        keys: Iterable of conditioner input keys to populate.
        value_dict (dict): Source values.  For the ``"txt"`` key the entries
            ``"prompt"`` and ``"negative_prompt"`` are read; any other key is
            looked up in ``value_dict`` directly.
        N: Batch shape; the prompt is repeated ``math.prod(N)`` times and
            reshaped to ``N``.
        T: Optional value stored under ``"num_video_frames"`` in the
            conditional batch when it is not ``None``.
        device: Accepted for interface compatibility; not used here.

    Returns:
        tuple: ``(batch, batch_uc)``.  ``batch_uc`` holds the negative prompt
        under ``"txt"`` plus a clone of every tensor in ``batch`` that it does
        not already contain; non-tensor entries are not copied over.
    """
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    # Clone rather than alias so later in-place edits to one batch cannot
    # leak into the other.
    for key, value in batch.items():
        if key not in batch_uc and isinstance(value, torch.Tensor):
            batch_uc[key] = torch.clone(value)
    return batch, batch_uc
def draw_points(video, points):
    """
    Draw trajectory lines and per-frame point markers onto video frames.

    Every trajectory is rasterised once as a poly-line, blended into each
    frame with a fixed weight, and the point's position at frame ``t`` is
    then marked with a filled circle on that frame.

    Parameters:
        video (torch.tensor): Video tensor with shape [T, H, W, C], where T is the number of frames,
            H is the height, W is the width, and C is the number of channels.
        points (list): Positions of points to be drawn as a tensor with shape [N, T, 2],
            each point contains x and y coordinates.

    Returns:
        torch.tensor: The video tensor after drawing points, maintaining the same shape [T, H, W, C],
            on the input's device and with the input's dtype. The input tensor is not modified.
    """
    T = video.shape[0]
    N = len(points)
    device = video.device
    dtype = video.dtype
    video = video.cpu().numpy().copy()

    # Rasterise all trajectories onto one blank canvas shaped like a frame.
    traj = np.zeros(video.shape[-3:], dtype=np.uint8)  # [H, W, C]
    for n in range(N):
        for t in range(1, T):
            cv2.line(traj, tuple(points[n][t - 1]), tuple(points[n][t]), (255, 1, 1), 2)

    # The canvas is the same for every frame, so the mask and the weighted
    # overlay are computed once.  The line colour has a non-zero last channel,
    # which is what marks a pixel as belonging to a trajectory.
    mask = traj[..., -1] > 0  # [H, W]
    alpha = 0.7
    overlay = traj[mask] * alpha

    for t in range(T):
        frame = video[t]  # view into ``video``; edits land in the array
        frame[mask] = frame[mask] * (1 - alpha) + overlay
        for n in range(N):
            cv2.circle(frame, tuple(points[n][t]), 3, (160, 230, 100), -1)

    video = torch.from_numpy(video).to(device, dtype)
    return video
def _write_mp4(path, frames, fps):
    """Encode ``frames`` to ``path`` with the H.264 settings used by this script."""
    write_video(
        path,
        frames,
        fps=fps,
        video_codec="libx264",
        options={"crf": "18"},
    )


def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor,
    save_path: str,
    name: str,
    fps: int = 5,
    args=None,
    key=None,
    traj_points=None,
    prompt="",
):
    """Write every video of a batch to disk together with its prompt.

    For the ``i``-th video the following files are created under ``save_path``
    (``<stem>`` is ``f"{name}_{i:06d}"``):

    * ``video/<stem>.mp4``      - the video itself
    * ``prompt/<stem>.txt``     - the prompt text
    * ``traj/<stem>.pkl``       - pickled ``traj_points`` (only if given)
    * ``traj_video/<stem>.mp4`` - the video with the trajectory drawn on it
      (only if ``traj_points`` is given)

    Parameters:
        video_batch (torch.Tensor): Batch of videos; each item is rearranged
            from ``t c h w`` to ``t h w c``, scaled by 255 and converted to
            uint8 before encoding.
        save_path (str): Output root directory; created if missing.
        name (str): Base name of the output files.
        fps (int): Frame rate passed to the encoder.
        args: Unused; kept for interface compatibility.
        key: Unused; kept for interface compatibility.
        traj_points: Optional trajectory points forwarded to ``draw_points``.
        prompt (str): Text written to the prompt file.
    """
    os.makedirs(save_path, exist_ok=True)
    p = Path(save_path)

    for i, vid in enumerate(video_batch):
        stem = f"{name}_{i:06d}"
        x = rearrange(vid, "t c h w -> t h w c")
        x = x.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8)  # [T H W C]
        os.makedirs(p / "video", exist_ok=True)
        os.makedirs(p / "prompt", exist_ok=True)

        # The plain video is written in every case; trajectory artefacts are extra.
        _write_mp4(p / "video" / f"{stem}.mp4", x, fps)

        if traj_points is not None:
            os.makedirs(p / "traj", exist_ok=True)
            os.makedirs(p / "traj_video", exist_ok=True)
            with open(p / "traj" / f"{stem}.pkl", "wb") as f:
                pickle.dump(traj_points, f)
            x = draw_points(x, traj_points)
            _write_mp4(p / "traj_video" / f"{stem}.mp4", x, fps)

        # Explicit UTF-8 so non-ASCII prompts can be saved on any platform.
        with open(p / "prompt" / f"{stem}.txt", "w", encoding="utf-8") as f:
            f.write(prompt)
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    """Resize ``arr`` so it covers ``image_size``, then crop it to exactly that size.

    The input is scaled with bicubic interpolation so that one side matches
    the target and the other is at least as large; the surplus is then
    cropped away.

    Parameters:
        arr: Tensor whose dimensions 2 and 3 are read as height and width and
            whose dimension 0 is squeezed before cropping.
            NOTE(review): presumably a 4-D tensor with a leading singleton
            dimension - confirm against callers.
        image_size: ``[height, width]`` of the result.
        reshape_mode (str): ``"random"`` or ``"none"`` picks the crop offset
            with ``np.random.randint``; ``"center"`` centres the crop.

    Returns:
        The cropped tensor.

    Raises:
        NotImplementedError: For any other ``reshape_mode``.
    """
    target_h, target_w = image_size[0], image_size[1]
    src_h, src_w = arr.shape[2], arr.shape[3]

    if src_w / src_h > target_w / target_h:
        # Source is relatively wider: match the height, let the width overshoot.
        scaled_size = [target_h, int(src_w * target_h / src_h)]
    else:
        # Source is relatively taller (or equal): match the width instead.
        scaled_size = [int(src_h * target_w / src_w), target_w]
    arr = resize(arr, size=scaled_size, interpolation=InterpolationMode.BICUBIC)

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)
    delta_h = h - target_h
    delta_w = w - target_w

    if reshape_mode in ("random", "none"):
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError

    return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)
def sampling_main(args, model_cls):
    """Generate and save videos for every prompt of the configured input.

    Steps, per prompt:

    1. Pick the motion condition: none, a flow file named by the prompt line,
       a fixed flow file (``args.flow_path``) or a point trajectory
       (``args.point_path``) converted by ``process_traj``.
    2. Turn the flow into an image sequence and encode it with the first-stage
       model.
    3. Build the conditional / unconditional text conditioning.
    4. Sample ``args.batch_size`` videos, decode the latents piecewise and
       save them through ``save_video_as_grid_and_mp4``.

    The model is moved between CPU and GPU around the encode / sample /
    decode phases to limit peak GPU memory.

    Parameters:
        args: Parsed run configuration (attributes read here include
            ``input_type``, ``input_file``, ``sampling_num_frames``,
            ``latent_channels``, ``seed``, ``batch_size``, ``output_dir``,
            ``sampling_fps`` and the flow / point options).
        model_cls: Either a model class, instantiated through ``get_model``,
            or an already constructed model instance.

    Raises:
        NotImplementedError: If ``args.input_type`` is neither ``"cli"`` nor
            ``"txt"``.
    """
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    # Interactive mode has no data-parallel split; give ``rank`` a value up
    # front so the progress message further down works for both input types
    # (it used to be unbound in "cli" mode).
    rank, world_size = 0, 1
    if args.input_type == "cli":
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    sample_func = model.sample
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            set_random_seed(args.seed)
            if args.flow_from_prompt:
                # Prompt lines carry "<text>\t<flow file>" in this mode.
                text, flow_files = text.split("\t")
            total_num_frames = (T - 1) * 4 + 1  # T is the video latent size, 13 * 4 = 52

            # Only the point-trajectory branch produces points to draw.
            points = None
            if args.no_flow_injection:
                video_flow = None
            elif args.flow_from_prompt:
                assert args.flow_path is not None, "Flow path must be provided if flow_from_prompt is True"
                p = os.path.join(args.flow_path, flow_files)
                print(f"Flow path: {p}")
                video_flow = (
                    torch.load(p, map_location="cpu", weights_only=True)[:total_num_frames].unsqueeze_(0).cuda()
                )
            elif args.flow_path:
                print(f"Flow path: {args.flow_path}")
                video_flow = torch.load(args.flow_path, map_location=device, weights_only=True)[
                    :total_num_frames
                ].unsqueeze_(0)
            elif args.point_path:
                if isinstance(args.point_path, str):
                    # Parsed once; later iterations see the decoded value.
                    args.point_path = json.loads(args.point_path)
                print(f"Point path: {args.point_path}")
                video_flow, points = process_traj(args.point_path, total_num_frames, image_size, device=device)
                video_flow = video_flow.unsqueeze_(0)
            else:
                print("No flow injection")
                video_flow = None

            if video_flow is not None:
                model.to("cpu")  # move model to cpu, run vae on gpu only.
                tmp = rearrange(video_flow[0], "T H W C -> T C H W")
                video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]
                if args.vis_traj_features:
                    os.makedirs("samples/flow", exist_ok=True)
                    vis_tensor(tmp, *tmp.shape[-2:], "samples/flow/flow1_vis.gif")
                    imageio.mimwrite(
                        "samples/flow/flow2_vis.gif",
                        rearrange(video_flow[0], "T C H W -> T H W C").cpu(),
                        fps=8,
                        loop=0,
                    )
                del tmp
                # Map 0..255 image values to -1..1 before encoding.
                video_flow = (
                    rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
                )
                torch.cuda.empty_cache()
                video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition
                model.first_stage_model.to(device)
                video_flow = model.encode_first_stage(video_flow, None)
                video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
                model.to(device)

            print("rank:", rank, "start to process", text, cnt)
            # TODO: broadcast image2video
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(item) for item in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                    video_flow=video_flow,
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the model from GPU to save GPU memory
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model
                first_stage_model = first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode latent serial to save GPU memory
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    # First chunk covers latent frames 0..2, later chunks two frames each.
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    # Only the last chunk asks the decoder to clear its cache.
                    clear_fake_cp_cache = i == loop_num - 1
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                        )

                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                # Map decoder output from -1..1 to 0..1 for saving.
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                save_path = args.output_dir
                name = str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:60] + f"_{index}_seed{args.seed}"
                if args.flow_from_prompt:
                    name = Path(flow_files).stem
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(
                        samples,
                        save_path,
                        name,
                        fps=args.sampling_fps,
                        traj_points=points,
                        prompt=text,
                    )

            # Drop per-prompt tensors before the next prompt to keep memory down.
            del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
            gc.collect()
if __name__ == "__main__":
    # When the Open MPI variables are present, copy them to the generic
    # LOCAL_RANK / WORLD_SIZE / RANK environment variables.
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        for target, source in (
            ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"),
            ("WORLD_SIZE", "OMPI_COMM_WORLD_SIZE"),
            ("RANK", "OMPI_COMM_WORLD_RANK"),
        ):
            os.environ[target] = os.environ[source]

    # This parser declares no options of its own, so all command-line
    # arguments are passed on to ``get_args``.
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config

    # Configuration overrides applied before sampling starts.
    args.model_config.first_stage_config.params.cp_size = 1
    transformer_args = args.model_config.network_config.params.transformer_args
    transformer_args.model_parallel_size = 1
    transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
    args.model_config.en_and_decode_n_samples_a_time = 1

    sampling_main(args, model_cls=SATVideoDiffusionEngine)
未完......
更多详细内容,欢迎关注:杰哥新技术