LtxVAE 学习笔记

目录

[LtxVAE 推理代码](<#LtxVAE 推理代码>)

[repeat 尝试:](<#repeat 尝试:>)


LtxVAE 推理代码

python 复制代码
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Load one or more conditioning images, keyed by file name.

    Args:
        cond_image_path_or_dir: Path of a single image file, or of a
            directory whose ``*.jpg`` files are all loaded in sorted order.
        use_face_crop: When true, each path is first handed to
            ``process_image``; if that raises, the error is logged and the
            file is opened as a plain RGB image instead.

    Returns:
        A dict mapping each file name, cut at its first ``.``, to the loaded
        image. Files whose names agree up to the first dot share a key, so
        the later one replaces the earlier one. A directory without
        ``*.jpg`` files yields an empty dict.
    """
    import glob

    def get_image(cond_image_path):
        if use_face_crop:
            try:
                return process_image(cond_image_path)
            except Exception as e:
                # Best effort: report the failure and fall back to the
                # uncropped image below.
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    def get_name(cond_image_path):
        # basename() honours the platform's path separator; splitting on "/"
        # left the whole directory prefix in the key for Windows paths.
        return os.path.basename(cond_image_path).split(".")[0]

    if os.path.isdir(cond_image_path_or_dir):
        cond_image_list = sorted(glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg")))
    else:
        cond_image_list = [cond_image_path_or_dir]

    return {get_name(path): get_image(path) for path in cond_image_list}

def resize_and_centercrop(cond_image, target_size):
    """Scale an image to cover ``target_size``, then center-crop it (no padding).

    The aspect ratio is kept: the input is scaled by the larger of the
    height and width ratios, so both sides reach at least the target size,
    and the surplus is cropped around the center.

    Args:
        cond_image: Either a ``torch.Tensor`` whose last two dimensions are
            height and width (3-D, or 4-D with a leading batch dimension),
            or an image object exposing ``height``, ``width`` and
            ``resize`` (a PIL image in this script).
        target_size: ``(target_h, target_w)``.

    Returns:
        For tensor input: a tensor of the same rank as the input, resized
        with nearest-neighbour interpolation and cropped to ``target_size``.
        For image input: a tensor of shape ``(1, C, 1, target_h, target_w)``
        built from the bilinearly resized image.
    """
    target_h, target_w = target_size

    is_tensor = isinstance(cond_image, torch.Tensor)
    if is_tensor:
        # Read only the trailing two dims so batched 4-D input is accepted;
        # unpacking the full shape into three names rejected it.
        orig_h, orig_w = cond_image.shape[-2:]
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    # The larger ratio guarantees both sides cover the target after scaling.
    scale = max(target_h / orig_h, target_w / orig_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    if is_tensor:
        # interpolate() with a 2-element size needs a batch axis; add one
        # for 3-D input and remember to remove it again afterwards.
        added_batch_dim = cond_image.dim() == 3
        if added_batch_dim:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(
            cond_image, size=(final_h, final_w), mode='nearest'
        ).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        if added_batch_dim:
            cropped_tensor = cropped_tensor.squeeze(0)
        return cropped_tensor

    # Image path: PIL's resize takes (width, height).
    resized_image = np.array(cond_image.resize((final_w, final_h), resample=Image.BILINEAR))
    # H W C array -> 1 C H W tensor.
    resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
    cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
    # Insert a length-1 axis after the channels: 1 C 1 H W.
    return cropped_tensor[:, :, None, :, :]

class FlashHeadPipeline:
    """Encode reference images into LtxVAE latents and print the time taken.

    On construction the VAE is loaded, its ``encode``/``decode`` are wrapped
    with ``torch.compile``, every conditioning image is resized to 448x448,
    repeated ``frame_num`` times along the frame axis and encoded.

    Attributes set by ``__init__``:
        cond_image_dict: name -> loaded image.
        cond_image_tensor_dict: name -> preprocessed tensor (1 C 1 H W).
        ref_img_latent_dict: name -> latent returned by ``vae.encode``.
    """

    DEFAULT_COND_IMAGE = '/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg'

    def __init__(self, vae_dir, cond_image_path_or_dir=DEFAULT_COND_IMAGE, device="cuda:0"):
        """Load the VAE and pre-compute the reference-image latents.

        Args:
            vae_dir: Passed to ``LtxVAE`` as ``pretrained_model_type_or_path``.
            cond_image_path_or_dir: Image file or directory of ``*.jpg``
                files to encode; defaults to the path the script used.
            device: Device for the VAE and the image tensors.
        """
        self.device = device
        self.param_dtype = torch.bfloat16
        self.target_h = 448
        self.target_w = 448
        self.frame_num = 33

        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=self.param_dtype,
            device=self.device,
        )
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)
        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        # Stays None when no image was loaded, so the report below cannot
        # hit an undefined name.
        video_frames = None
        # NOTE(review): the first call through a torch.compile wrapper
        # usually triggers compilation, so this timing likely includes it.
        start = time.time()
        for person_name, cond_image_pil in self.cond_image_dict.items():
            cond_image_tensor = resize_and_centercrop(
                cond_image_pil, (self.target_h, self.target_w)
            ).to(self.device, dtype=self.param_dtype)  # 1 C 1 H W
            # Linear map sending 0 to -1 and 255 to 1.
            cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
            self.cond_image_tensor_dict[person_name] = cond_image_tensor

            # Repeat the single frame along the frame axis (dim 2).
            video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
            # Author's shape note: (16, 9, 64, 64) / (128, 5, 16, 16) — not verified here.
            self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames)

        self._synchronize()
        print(
            'vae.encode time',
            time.time() - start,
            len(self.cond_image_dict),
            0 if video_frames is None else len(video_frames),
        )

    def _synchronize(self):
        """Wait for queued CUDA work so the wall-clock timing covers it."""
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize(self.device)


if __name__ == "__main__":
    # Script entry point: build the pipeline from the local LTX VAE
    # checkpoint directory; all the work happens in the constructor.
    vae_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(vae_dir=vae_dir)

repeat 尝试:

python 复制代码
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Load one or more conditioning images, keyed by file name.

    Args:
        cond_image_path_or_dir: Path of a single image file, or of a
            directory whose ``*.jpg`` files are all loaded in sorted order.
        use_face_crop: When true, each path is first handed to
            ``process_image``; if that raises, the error is logged and the
            file is opened as a plain RGB image instead.

    Returns:
        A dict mapping each file name, cut at its first ``.``, to the loaded
        image. Files whose names agree up to the first dot share a key, so
        the later one replaces the earlier one. A directory without
        ``*.jpg`` files yields an empty dict.
    """
    import glob

    def get_image(cond_image_path):
        if use_face_crop:
            try:
                return process_image(cond_image_path)
            except Exception as e:
                # Best effort: report the failure and fall back to the
                # uncropped image below.
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    def get_name(cond_image_path):
        # basename() honours the platform's path separator; splitting on "/"
        # left the whole directory prefix in the key for Windows paths.
        return os.path.basename(cond_image_path).split(".")[0]

    if os.path.isdir(cond_image_path_or_dir):
        cond_image_list = sorted(glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg")))
    else:
        cond_image_list = [cond_image_path_or_dir]

    return {get_name(path): get_image(path) for path in cond_image_list}

def resize_and_centercrop(cond_image, target_size):
    """Scale an image to cover ``target_size``, then center-crop it (no padding).

    The aspect ratio is kept: the input is scaled by the larger of the
    height and width ratios, so both sides reach at least the target size,
    and the surplus is cropped around the center.

    Args:
        cond_image: Either a ``torch.Tensor`` whose last two dimensions are
            height and width (3-D, or 4-D with a leading batch dimension),
            or an image object exposing ``height``, ``width`` and
            ``resize`` (a PIL image in this script).
        target_size: ``(target_h, target_w)``.

    Returns:
        For tensor input: a tensor of the same rank as the input, resized
        with nearest-neighbour interpolation and cropped to ``target_size``.
        For image input: a tensor of shape ``(1, C, 1, target_h, target_w)``
        built from the bilinearly resized image.
    """
    target_h, target_w = target_size

    is_tensor = isinstance(cond_image, torch.Tensor)
    if is_tensor:
        # Read only the trailing two dims so batched 4-D input is accepted;
        # unpacking the full shape into three names rejected it.
        orig_h, orig_w = cond_image.shape[-2:]
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    # The larger ratio guarantees both sides cover the target after scaling.
    scale = max(target_h / orig_h, target_w / orig_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    if is_tensor:
        # interpolate() with a 2-element size needs a batch axis; add one
        # for 3-D input and remember to remove it again afterwards.
        added_batch_dim = cond_image.dim() == 3
        if added_batch_dim:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(
            cond_image, size=(final_h, final_w), mode='nearest'
        ).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        if added_batch_dim:
            cropped_tensor = cropped_tensor.squeeze(0)
        return cropped_tensor

    # Image path: PIL's resize takes (width, height).
    resized_image = np.array(cond_image.resize((final_w, final_h), resample=Image.BILINEAR))
    # H W C array -> 1 C H W tensor.
    resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
    cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
    # Insert a length-1 axis after the channels: 1 C 1 H W.
    return cropped_tensor[:, :, None, :, :]

class FlashHeadPipeline:
    """Compare two ways of obtaining a multi-frame LtxVAE latent from one image.

    Method A encodes the single frame and repeats the resulting latent;
    method B repeats the frame ``frame_num`` times and encodes the clip.
    The mean absolute difference of the two latents is printed for every
    image, and the whole comparison runs twice with a timing line per run.

    Attributes set by ``__init__``:
        cond_image_dict: name -> loaded image.
        cond_image_tensor_dict: name -> preprocessed tensor (1 C 1 H W).
        ref_img_latent_dict: created empty; this experiment does not fill it.
    """

    DEFAULT_COND_IMAGE = '/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg'

    def __init__(self, vae_dir, cond_image_path_or_dir=DEFAULT_COND_IMAGE, device="cuda:0"):
        """Load the VAE and run the latent comparison.

        Args:
            vae_dir: Passed to ``LtxVAE`` as ``pretrained_model_type_or_path``.
            cond_image_path_or_dir: Image file or directory of ``*.jpg``
                files to compare on; defaults to the path the script used.
            device: Device for the VAE and the image tensors.
        """
        self.device = device
        self.param_dtype = torch.bfloat16
        self.target_h = 448
        self.target_w = 448
        self.frame_num = 33

        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=self.param_dtype,
            device=self.device,
        )
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)
        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        # Two identical passes. NOTE(review): presumably the second pass is
        # there to time encode without torch.compile warm-up — confirm.
        for _ in range(2):
            # Stays None when no image was loaded, so the report below
            # cannot hit an undefined name.
            video_frames = None
            start = time.time()
            for person_name, cond_image_pil in self.cond_image_dict.items():
                cond_image_tensor = resize_and_centercrop(
                    cond_image_pil, (self.target_h, self.target_w)
                ).to(self.device, dtype=self.param_dtype)  # 1 C 1 H W
                # Linear map sending 0 to -1 and 255 to 1.
                cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
                self.cond_image_tensor_dict[person_name] = cond_image_tensor

                # Method A: encode the single frame, then repeat the latent
                # 5 times along dim 1. A repeat count of 1 on every axis
                # keeps the shape and just yields a copy.
                # NOTE(review): 5 presumably equals the latent frame count
                # for frame_num = 33 — confirm against the VAE.
                video_frames = cond_image_tensor.repeat(1, 1, 1, 1, 1)
                latent_A = self.vae.encode(video_frames)
                latent_A = latent_A.repeat(1, 5, 1, 1)

                # Method B: repeat the frame frame_num times along the
                # frame axis (dim 2), then encode the whole clip.
                video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
                latent_B = self.vae.encode(video_frames)

                # Comparison: mean absolute difference between both latents.
                print((latent_A - latent_B).abs().mean())

            self._synchronize()
            print(
                'vae.encode time',
                time.time() - start,
                len(self.cond_image_dict),
                0 if video_frames is None else len(video_frames),
            )

    def _synchronize(self):
        """Wait for queued CUDA work so the wall-clock timing covers it."""
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize(self.device)


if __name__ == "__main__":
    # Script entry point: build the pipeline from the local LTX VAE
    # checkpoint directory; all the work happens in the constructor.
    vae_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(vae_dir=vae_dir)
相关推荐
IT_陈寒29 分钟前
Vite的热更新突然不香了,排查三小时差点砸键盘
前端·人工智能·后端
阿里云大数据AI技术2 小时前
构建高转化海外电商搜索:阿里云OpenSearch行业算法版的全链路智能优化策略实战
人工智能·搜索引擎
Awu12273 小时前
⚡从零开发 Agent CLI(五)实现一个可治理、可扩展的工具系统
前端·人工智能·claude
字节跳动视频云技术团队3 小时前
让 Agent 成为音视频工作台:AI MediaKit CLI + Skill 发布
人工智能·音视频开发
魏祖潇3 小时前
framework 整合实战——DDD/TDD/SDD 三件套在 framework 仓的真实落地
人工智能·后端
Token炼金师3 小时前
去噪扩散:从随机噪声到高保真图像的数学之路
人工智能·aigc
这个DBA有点耶3 小时前
AI写的SQL跑崩了生产库,这锅谁背?
数据库·人工智能·程序员
阿里云大数据AI技术4 小时前
阿里云 EMR AI 助手正式发布:从问答工具到全栈智能运维助手
运维·人工智能
Larcher5 小时前
从零搭建 MCP 服务——让 AI 拥有无限扩展能力
人工智能·程序员
zzzzzz3105 小时前
你的 AI 写的 React 烂透了?这个 8000+ Star 的开源工具能揪出 90% 的「Agent 屎山」
人工智能