LtxVAE 学习笔记

目录

[LtxVAE 推理代码](<#LtxVAE 推理代码>)

[repeat 尝试:](<#repeat 尝试:>)


LtxVAE 推理代码

python 复制代码
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Load one or more conditioning images, keyed by file name.

    Args:
        cond_image_path_or_dir: Path to a single image file, or to a directory.
            For a directory, every ``*.jpg`` directly inside it is loaded, in
            sorted path order.
        use_face_crop: When true, each image is first passed to
            ``process_image``. If that raises, the error is logged and the
            image is loaded as plain RGB instead (best-effort fallback).

    Returns:
        dict mapping each file's base name, truncated at its first ``.``, to
        the loaded image: whatever ``process_image`` returns, or an RGB
        ``PIL.Image``. An empty directory yields an empty dict.
    """

    def _key(path):
        # basename() honours the platform separator; the paths built by
        # os.path.join/glob are not guaranteed to contain "/".
        return os.path.basename(path).split(".")[0]

    def _load(path):
        if use_face_crop:
            try:
                return process_image(path)
            except Exception as e:
                # Deliberate best-effort: fall through to the uncropped image.
                logger.error(f"Error processing {path}: {e}")
        return Image.open(path).convert("RGB")

    if os.path.isdir(cond_image_path_or_dir):
        import glob
        image_paths = sorted(glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg")))
    else:
        image_paths = [cond_image_path_or_dir]

    return {_key(path): _load(path) for path in image_paths}

def resize_and_centercrop(cond_image, target_size):
    """Scale an image so it covers ``target_size``, then center-crop it.

    Both axes are scaled by the same factor (the larger of the two
    per-axis ratios), so the resized image is at least as large as the
    target in both dimensions and no padding is needed.

    Args:
        cond_image: Either a ``torch.Tensor`` whose last two axes are
            (H, W), given as 3-D or 4-D, or a PIL image. The PIL branch
            assumes ``np.array(image)`` is (H, W, C), e.g. an RGB image.
        target_size: ``(target_h, target_w)``.

    Returns:
        For tensor input: the cropped tensor with a size-1 leading axis
        squeezed away (so 3-D in gives 3-D out).
        For PIL input: a tensor of shape (1, C, 1, target_h, target_w)
        holding the raw pixel values from ``np.array``.
    """
    is_tensor = isinstance(cond_image, torch.Tensor)

    # Read H/W from the trailing axes so batched (4-D) tensors are accepted;
    # unpacking exactly three values would reject them.
    if is_tensor:
        orig_h, orig_w = cond_image.shape[-2:]
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size

    # The larger ratio guarantees both resized sides reach the target.
    scale = max(target_h / orig_h, target_w / orig_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    if is_tensor:
        if cond_image.dim() == 3:
            # interpolate() needs a batch axis.
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(
            cond_image, size=(final_h, final_w), mode='nearest'
        ).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        return cropped_tensor.squeeze(0)

    # PIL path: resize() takes (width, height), the reverse of the tensor API.
    resized_image = np.array(
        cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
    )
    # (H, W, C) -> (1, C, H, W)
    resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
    cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
    # Insert a singleton axis at position 2: (1, C, 1, H, W).
    return cropped_tensor[:, :, None, :, :]

class FlashHeadPipeline:
    """Small harness that loads the LTX VAE and encodes reference images.

    Construction does all the work: it builds the VAE, wraps its
    ``encode``/``decode`` with ``torch.compile``, loads the conditioning
    image(s) with face cropping enabled, and encodes each one as a clip of
    ``frame_num`` identical frames. The elapsed encode time is printed.

    Args:
        vae_dir: Passed to ``LtxVAE`` as ``pretrained_model_type_or_path``.
        cond_image_path_or_dir: Image file or directory of ``*.jpg`` files.
            Defaults to ``DEFAULT_COND_IMAGE``.
        device: Device string used for the VAE and the image tensors.
    """

    # Image used when the caller does not supply one.
    DEFAULT_COND_IMAGE = '/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg'

    def __init__(self, vae_dir, cond_image_path_or_dir=None, device="cuda:0"):
        self.device = device
        self.param_dtype = torch.bfloat16
        self.target_h = 448
        self.target_w = 448
        self.frame_num = 33

        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=self.param_dtype,
            device=self.device,
        )
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        if cond_image_path_or_dir is None:
            cond_image_path_or_dir = self.DEFAULT_COND_IMAGE
        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)

        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        # CUDA kernels run asynchronously; synchronize around the timed
        # region so the wall-clock reading covers the actual encode work.
        self._sync()
        start = time.time()
        # Tracked separately so the summary print works with zero images.
        batch_len = 0
        for person_name, cond_image_pil in self.cond_image_dict.items():
            # PIL input comes back as (1, C, 1, H, W).
            cond_image_tensor = resize_and_centercrop(
                cond_image_pil, (self.target_h, self.target_w)
            ).to(self.device, dtype=self.param_dtype)
            # Map pixel values from [0, 255] to [-1, 1].
            cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
            self.cond_image_tensor_dict[person_name] = cond_image_tensor

            # Tile the single frame along axis 2 to form a static clip.
            video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
            batch_len = len(video_frames)
            # Author's note on latent shapes: (16, 9, 64, 64) / (128, 5, 16, 16)
            # -- presumably two different VAEs; TODO confirm.
            self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames)
        self._sync()
        print('vae.encode time', time.time() - start, len(self.cond_image_dict), batch_len)

    def _sync(self):
        """Wait for queued CUDA work on ``self.device``; no-op off CUDA."""
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize(torch.device(self.device))


if __name__ == "__main__":
    # Local directory holding the LTX VAE weights.
    vae_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"

    # Building the pipeline is the whole experiment: __init__ loads the VAE,
    # encodes the reference image and prints the timing.
    flash_pipe = FlashHeadPipeline(vae_dir)

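repeat 尝试:

下面的脚本对比两种得到参考 latent 的方式:(A) 只编码单帧,再把得到的 latent 沿第 1 维(应为时间维)重复 5 次;(B) 先把单帧重复 33 次组成视频,再整体编码。脚本打印两者的平均绝对误差,并循环两次以分别观察首次(含 torch.compile 编译)与后续的耗时。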

python 复制代码
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Collect conditioning images into a dict keyed by file name.

    Args:
        cond_image_path_or_dir: A single image path, or a directory whose
            ``*.jpg`` files (non-recursive, sorted by path) are all loaded.
        use_face_crop: If true, try ``process_image`` on each path first;
            on any exception the error is logged and the image is opened
            as plain RGB instead.

    Returns:
        dict from the file's base name (cut at the first ``.``) to the
        loaded image. A directory without ``*.jpg`` files gives ``{}``.
    """

    def _name_of(path):
        # Use basename() rather than split("/") so the key is correct with
        # whatever separator os.path.join/glob produced.
        return os.path.basename(path).split(".")[0]

    def _read(path):
        if use_face_crop:
            try:
                return process_image(path)
            except Exception as e:
                # Best-effort: log and fall back to the uncropped image.
                logger.error(f"Error processing {path}: {e}")
        return Image.open(path).convert("RGB")

    if os.path.isdir(cond_image_path_or_dir):
        import glob
        sources = sorted(glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg")))
    else:
        sources = [cond_image_path_or_dir]

    return {_name_of(src): _read(src) for src in sources}

def resize_and_centercrop(cond_image, target_size):
    """Resize to cover ``target_size`` (aspect ratio kept) and center-crop.

    The image is scaled by the larger of the height and width ratios, so
    the result is never smaller than the target and nothing is padded; the
    surplus along the other axis is cropped around the center.

    Args:
        cond_image: A ``torch.Tensor`` with (H, W) as its last two axes
            (3-D or 4-D), or a PIL image. The PIL branch assumes
            ``np.array(image)`` has layout (H, W, C).
        target_size: ``(target_h, target_w)``.

    Returns:
        Tensor input: cropped tensor, with ``squeeze(0)`` applied.
        PIL input: tensor shaped (1, C, 1, target_h, target_w).
    """
    tensor_input = isinstance(cond_image, torch.Tensor)

    if tensor_input:
        # Trailing axes are H and W for both 3-D and 4-D input; a fixed
        # three-way unpack would fail on batched tensors.
        orig_h, orig_w = cond_image.shape[-2:]
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size

    # Scale by the larger ratio so both sides reach the target size.
    scale = max(target_h / orig_h, target_w / orig_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    if tensor_input:
        if cond_image.dim() == 3:
            # Add the batch axis interpolate() expects.
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(
            cond_image, size=(final_h, final_w), mode='nearest'
        ).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        return cropped_tensor.squeeze(0)

    # PIL's resize() takes (width, height).
    resized_image = np.array(
        cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
    )
    # (H, W, C) -> (1, C, H, W)
    resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
    cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
    # Add a singleton axis at position 2 -> (1, C, 1, H, W).
    return cropped_tensor[:, :, None, :, :]

class FlashHeadPipeline:
    """Experiment: can a single-frame latent stand in for a static clip?

    Construction loads the LTX VAE, wraps ``encode``/``decode`` with
    ``torch.compile``, loads the conditioning image(s) with face cropping
    enabled, and then, twice over, compares two ways of producing a
    reference latent, printing their mean absolute difference and the
    elapsed time of each pass.

    ``ref_img_latent_dict`` is created but left empty by this experiment.

    Args:
        vae_dir: Passed to ``LtxVAE`` as ``pretrained_model_type_or_path``.
        cond_image_path_or_dir: Image file or directory of ``*.jpg`` files.
            Defaults to ``DEFAULT_COND_IMAGE``.
        device: Device string used for the VAE and the image tensors.
    """

    # Image used when the caller does not supply one.
    DEFAULT_COND_IMAGE = '/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg'

    def __init__(self, vae_dir, cond_image_path_or_dir=None, device="cuda:0"):
        self.device = device
        self.param_dtype = torch.bfloat16
        self.target_h = 448
        self.target_w = 448
        self.frame_num = 33

        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=self.param_dtype,
            device=self.device,
        )
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        if cond_image_path_or_dir is None:
            cond_image_path_or_dir = self.DEFAULT_COND_IMAGE
        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)

        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        # Two passes: the first presumably includes torch.compile's
        # compilation cost, the second is closer to steady state.
        # The pass index is unused; the original reused ``i`` for both loops.
        for _ in range(2):
            # Synchronize so the timer brackets the GPU work itself.
            self._sync()
            start = time.time()
            # Tracked separately so the summary print works with zero images.
            batch_len = 0
            for person_name, cond_image_pil in self.cond_image_dict.items():
                # PIL input comes back as (1, C, 1, H, W).
                cond_image_tensor = resize_and_centercrop(
                    cond_image_pil, (self.target_h, self.target_w)
                ).to(self.device, dtype=self.param_dtype)
                # Map pixel values from [0, 255] to [-1, 1].
                cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
                self.cond_image_tensor_dict[person_name] = cond_image_tensor

                # Method A: encode the single frame, then tile the latent
                # five times along axis 1. The factor 5 presumably matches
                # the latent length for 33 input frames -- TODO confirm.
                # (The original's repeat(1,1,1,1,1) was a no-op copy.)
                latent_a = self.vae.encode(cond_image_tensor)
                latent_a = latent_a.repeat(1, 5, 1, 1)

                # Method B: tile the frame to ``frame_num`` frames along
                # axis 2, then encode the whole static clip.
                video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
                latent_b = self.vae.encode(video_frames)
                batch_len = len(video_frames)

                # Mean absolute difference between the two latents.
                print((latent_a - latent_b).abs().mean())
            self._sync()
            print('vae.encode time', time.time() - start, len(self.cond_image_dict), batch_len)

    def _sync(self):
        """Wait for queued CUDA work on ``self.device``; no-op off CUDA."""
        if str(self.device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize(torch.device(self.device))


if __name__ == "__main__":
    # Local directory holding the LTX VAE weights.
    vae_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"

    # Instantiating the pipeline runs the comparison; results are printed
    # from __init__.
    flash_pipe = FlashHeadPipeline(vae_dir)
相关推荐
共创splendid--与您携手24 分钟前
AI读取前端项目生成skill.md
前端·人工智能·ai
gis分享者2 小时前
AI数字营销实测体验,GEO效果查询功能体验
人工智能·csdn·geo·数字营销·实测体验·效果查询
莱歌数字2 小时前
轻出20%性能:三维拓扑优化如何重塑无人机电子设备散热格局
人工智能·科技·制造·cae·散热
猿小猴子2 小时前
主流 AI IDE 之一的「DeepSeek-Reasonix 」介绍
人工智能·ai·deepseek·reasonix
装不满的克莱因瓶2 小时前
链式法则如何传递参数误差 —— 深入理解神经网络中的梯度传播
人工智能·python·深度学习·神经网络·数学·机器学习·ai
Anastasiozzzz2 小时前
从有限状态机到智能体图:传统 FSM 与 Agent Graph的演进
java·人工智能·python·ai
程序员cxuan8 小时前
为每个任务配一套 harness:Claude Code 里的动态工作流
人工智能
程序员cxuan8 小时前
Claude Fable 5 来了
人工智能·后端·程序员
云边云科技_云网融合8 小时前
云边云科技亮相 2026 WOD 制造业数智化博览会 云网融合赋能制造焕新
人工智能·科技·安全·制造
Σίσυφος19008 小时前
激光三角 光平面标定-多高度误差分析
人工智能·计算机视觉·平面