Table of Contents
[LtxVAE inference code](#ltxvae-inference-code)
[repeat attempt](#repeat-attempt)
## LtxVAE inference code
```python
import math
import os
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image
from loguru import logger
from transformers import Wav2Vec2FeatureExtractor

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image
def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Load the conditioning image(s) as a {person_name: PIL.Image} dict."""

    def get_image(cond_image_path, use_face_crop):
        if use_face_crop:
            try:
                return process_image(cond_image_path)
            except Exception as e:
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    if os.path.isdir(cond_image_path_or_dir):
        import glob

        cond_image_list = sorted(glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg")))
        cond_image_dict = {
            os.path.splitext(os.path.basename(cond_image))[0]: get_image(cond_image, use_face_crop)
            for cond_image in cond_image_list
        }
    else:
        cond_image_dict = {
            os.path.splitext(os.path.basename(cond_image_path_or_dir))[0]: get_image(
                cond_image_path_or_dir, use_face_crop
            )
        }
    return cond_image_dict
def resize_and_centercrop(cond_image, target_size):
    """
    Resize an image or tensor so it covers target_size, then center-crop (no padding).
    """
    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width
    target_h, target_w = target_size

    # Scale so the image fully covers the target, then crop the overflow
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode="nearest").contiguous()
        # Crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # To tensor (1, C, H, W) and crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        # Add a frame axis: (1, C, 1, H, W)
        cropped_tensor = cropped_tensor[:, :, None, :, :]
    return cropped_tensor
class FlashHeadPipeline:
    def __init__(self, vae_dir):
        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=torch.bfloat16,
            device="cuda:0",
        )
        self.target_h = 448
        self.target_w = 448
        self.device = "cuda:0"
        self.frame_num = 33
        self.param_dtype = torch.bfloat16
        # Compile encode/decode so repeated VAE calls run faster after warm-up
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        cond_image_path_or_dir = "/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg"
        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)
        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        start = time.time()
        for i, (person_name, cond_image_pil) in enumerate(self.cond_image_dict.items()):
            # (1, C, 1, H, W), normalized to [-1, 1]
            cond_image_tensor = resize_and_centercrop(cond_image_pil, (self.target_h, self.target_w)).to(self.device, dtype=self.param_dtype)
            cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
            self.cond_image_tensor_dict[person_name] = cond_image_tensor
            # Repeat the single frame to frame_num frames and encode the whole clip
            video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
            self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames)  # (16, 9, 64, 64) / (128, 5, 16, 16)
            # if i == 0:
            #     self.reset_person_name(person_name)
        print("vae.encode time", time.time() - start, len(self.cond_image_dict), len(video_frames))


if __name__ == "__main__":
    vae_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(vae_dir)
```
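As a quick sanity check of the cached reference latents, the round trip can be decoded back to pixels. This is only a sketch: it assumes the `LtxVAE` wrapper exposes a `decode()` that mirrors the `encode()` used above and returns frames in `1 C T H W` layout in `[-1, 1]`; verify the actual method name and output layout before relying on it.

```python
# Hypothetical round-trip check. Assumes LtxVAE.decode() mirrors LtxVAE.encode()
# and returns a (1, C, T, H, W) tensor in [-1, 1]; adjust if the real API differs.
with torch.no_grad():
    for person_name, latent in flash_pipe.ref_img_latent_dict.items():
        frames = flash_pipe.vae.decode(latent)
        first = frames[0, :, 0]  # first decoded frame, (C, H, W)
        first = ((first.float() / 2 + 0.5).clamp(0, 1) * 255).to(torch.uint8)
        Image.fromarray(first.permute(1, 2, 0).cpu().numpy()).save(f"{person_name}_roundtrip.jpg")
```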
## repeat attempt

The same pipeline, but the reference latent is computed two ways and compared: Method A encodes a single frame and repeats the resulting latent along the temporal axis; Method B repeats the frame to `frame_num` frames before encoding.
```python
import math
import os
import time

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision.transforms as transforms
from einops import rearrange
from PIL import Image
from loguru import logger
from transformers import Wav2Vec2FeatureExtractor

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image
def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Load the conditioning image(s) as a {person_name: PIL.Image} dict."""

    def get_image(cond_image_path, use_face_crop):
        if use_face_crop:
            try:
                return process_image(cond_image_path)
            except Exception as e:
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    if os.path.isdir(cond_image_path_or_dir):
        import glob

        cond_image_list = sorted(glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg")))
        cond_image_dict = {
            os.path.splitext(os.path.basename(cond_image))[0]: get_image(cond_image, use_face_crop)
            for cond_image in cond_image_list
        }
    else:
        cond_image_dict = {
            os.path.splitext(os.path.basename(cond_image_path_or_dir))[0]: get_image(
                cond_image_path_or_dir, use_face_crop
            )
        }
    return cond_image_dict
def resize_and_centercrop(cond_image, target_size):
    """
    Resize an image or tensor so it covers target_size, then center-crop (no padding).
    """
    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width
    target_h, target_w = target_size

    # Scale so the image fully covers the target, then crop the overflow
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode="nearest").contiguous()
        # Crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # To tensor (1, C, H, W) and crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        # Add a frame axis: (1, C, 1, H, W)
        cropped_tensor = cropped_tensor[:, :, None, :, :]
    return cropped_tensor
class FlashHeadPipeline:
    def __init__(self, vae_dir):
        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=torch.bfloat16,
            device="cuda:0",
        )
        self.target_h = 448
        self.target_w = 448
        self.device = "cuda:0"
        self.frame_num = 33
        self.param_dtype = torch.bfloat16
        # Compile encode/decode so repeated VAE calls run faster after warm-up
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        cond_image_path_or_dir = "/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg"
        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)
        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        # Run twice: the first pass includes torch.compile warm-up, the second gives steady-state timing
        for run_idx in range(2):
            start = time.time()
            for i, (person_name, cond_image_pil) in enumerate(self.cond_image_dict.items()):
                # (1, C, 1, H, W), normalized to [-1, 1]
                cond_image_tensor = resize_and_centercrop(cond_image_pil, (self.target_h, self.target_w)).to(self.device, dtype=self.param_dtype)
                cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
                self.cond_image_tensor_dict[person_name] = cond_image_tensor
                # Method A: encode the single frame, then repeat the latent along the temporal axis
                # (5 is the latent length of a 33-frame clip, matching the (128, 5, 16, 16) note below)
                video_frames = cond_image_tensor.repeat(1, 1, 1, 1, 1)
                latent_A = self.vae.encode(video_frames)
                latent_A = latent_A.repeat(1, 5, 1, 1)
                # Method B: repeat the single frame to frame_num frames, then encode
                video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
                latent_B = self.vae.encode(video_frames)
                # Compare the two latents
                print((latent_A - latent_B).abs().mean())
                # video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
                # self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames)  # (16, 9, 64, 64) / (128, 5, 16, 16)
                # if i == 0:
                #     self.reset_person_name(person_name)
            print("vae.encode time", time.time() - start, len(self.cond_image_dict), len(video_frames))


if __name__ == "__main__":
    vae_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(vae_dir)
```
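For context on why `latent_A` and `latent_B` are not expected to match exactly: the LTX VAE compresses along the temporal axis, so after the first frame each latent step aggregates a group of input frames, and encoding one frame and tiling its latent is not the same computation as encoding the tiled frames. A minimal shape-bookkeeping sketch follows, assuming 8x temporal / 32x spatial compression and 128 latent channels; these factors are assumptions inferred from the `(128, 5, 16, 16)` comment in the code, not read from the VAE itself.

```python
# Shape bookkeeping only; the compression factors are assumptions, not read from LtxVAE.
def expected_latent_shape(frames, height, width, t_factor=8, s_factor=32, channels=128):
    # Causal video VAEs typically map the first frame to one latent frame and every
    # further t_factor frames to one more: t_latent = (frames - 1) // t_factor + 1
    t_latent = (frames - 1) // t_factor + 1
    return (channels, t_latent, height // s_factor, width // s_factor)

print(expected_latent_shape(33, 512, 512))  # (128, 5, 16, 16), matching the inline comment
print(expected_latent_shape(1, 512, 512))   # (128, 1, 16, 16): one frame -> one latent frame
```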