LtxVAE 学习笔记

目录

[LtxVAE 推理代码](#ltxvae-推理代码)

[repeat 尝试](#repeat-尝试)


LtxVAE 推理代码

python
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Build a mapping of person name (file stem) -> conditioning image.

    Args:
        cond_image_path_or_dir: Path to one image file, or a directory
            containing ``*.jpg`` images.
        use_face_crop: When True, attempt face cropping via ``process_image``;
            on any failure, fall back to the raw RGB image.

    Returns:
        dict mapping file stem -> image (PIL RGB image, or whatever
        ``process_image`` returns when face cropping succeeds).
    """
    def get_image(cond_image_path, use_face_crop):
        if use_face_crop:
            try:
                return process_image(cond_image_path)
            except Exception as e:
                # Best-effort: log and fall through to the uncropped image.
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    def _stem(path):
        # os.path-based stem extraction is cross-platform (handles "\\" on
        # Windows) and only strips the final extension, unlike the fragile
        # path.split("/")[-1].split(".")[0] which truncates dotted names.
        return os.path.splitext(os.path.basename(path))[0]

    if os.path.isdir(cond_image_path_or_dir):
        import glob
        cond_image_list = sorted(
            glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg"))
        )
        return {_stem(p): get_image(p, use_face_crop) for p in cond_image_list}
    return {_stem(cond_image_path_or_dir): get_image(cond_image_path_or_dir, use_face_crop)}

def resize_and_centercrop(cond_image, target_size):
    """Scale-fill *cond_image* to cover *target_size*, then center-crop.

    The input is resized by the larger of the two scale ratios so the target
    area is fully covered (no padding), then center-cropped to exactly
    (target_h, target_w).

    Args:
        cond_image: A PIL image, or a CHW torch.Tensor.
        target_size: Tuple of (target_h, target_w).

    Returns:
        For tensor input: the cropped tensor with the leading batch dim
        squeezed. For PIL input: a uint8 tensor of shape (1, C, 1, H, W).
    """
    is_tensor = isinstance(cond_image, torch.Tensor)

    # Source dimensions.
    if is_tensor:
        _, src_h, src_w = cond_image.shape
    else:
        src_h, src_w = cond_image.height, cond_image.width

    tgt_h, tgt_w = target_size

    # Cover-fit: scale by the larger ratio so both dims reach the target.
    ratio = max(tgt_h / src_h, tgt_w / src_w)
    new_h = math.ceil(ratio * src_h)
    new_w = math.ceil(ratio * src_w)

    if is_tensor:
        batched = cond_image[None] if len(cond_image.shape) == 3 else cond_image
        resized = nn.functional.interpolate(
            batched, size=(new_h, new_w), mode='nearest'
        ).contiguous()
        return transforms.functional.center_crop(resized, target_size).squeeze(0)

    pil_resized = cond_image.resize((new_w, new_h), resample=Image.BILINEAR)
    as_tensor = (
        torch.from_numpy(np.array(pil_resized))[None, ...]
        .permute(0, 3, 1, 2)
        .contiguous()
    )
    cropped = transforms.functional.center_crop(as_tensor, target_size)
    # Insert a singleton frame axis: (1, C, H, W) -> (1, C, 1, H, W).
    return cropped[:, :, None, :, :]

class FlashHeadPipeline:
    """Loads the LTX VAE, prepares conditioning images, and precomputes
    their reference latents."""

    def __init__(
        self,
        vae_dir,
        cond_image_path_or_dir='/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg',
        device="cuda:0",
    ):
        """
        Args:
            vae_dir: Path to the pretrained LTX VAE weights.
            cond_image_path_or_dir: Image file or directory of ``*.jpg``
                conditioning images (previously hard-coded; default keeps
                the original behavior).
            device: Torch device string (previously hard-coded "cuda:0").
        """
        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=torch.bfloat16,
            device=device,
        )

        self.target_h = 448
        self.target_w = 448
        self.device = device
        self.frame_num = 33

        self.param_dtype = torch.bfloat16

        # Compile once; the first encode/decode call pays the compilation cost.
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)

        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}
        start = time.time()
        for person_name, cond_image_pil in self.cond_image_dict.items():
            # 1 C 1 H W, normalized from [0, 255] to [-1, 1].
            cond_image_tensor = resize_and_centercrop(
                cond_image_pil, (self.target_h, self.target_w)
            ).to(self.device, dtype=self.param_dtype)
            cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2

            self.cond_image_tensor_dict[person_name] = cond_image_tensor

            # Repeat the single frame along the temporal axis and encode the
            # full clip once up front.
            video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
            self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames)  # (16, 9, 64, 64) / (128, 5, 16, 16)

        # NOTE: the original print also referenced len(video_frames), which
        # raised NameError when no conditioning images were found.
        print('vae.encode time', time.time() - start, len(self.cond_image_dict))


if __name__ == "__main__":
    # Location of the pretrained LTX VAE checkpoint.
    vae_weights_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(vae_weights_dir)

repeat 尝试:

python
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Build a mapping of person name (file stem) -> conditioning image.

    Args:
        cond_image_path_or_dir: Path to one image file, or a directory
            containing ``*.jpg`` images.
        use_face_crop: When True, attempt face cropping via ``process_image``;
            on any failure, fall back to the raw RGB image.

    Returns:
        dict mapping file stem -> image (PIL RGB image, or whatever
        ``process_image`` returns when face cropping succeeds).
    """
    def get_image(cond_image_path, use_face_crop):
        if use_face_crop:
            try:
                return process_image(cond_image_path)
            except Exception as e:
                # Best-effort: log and fall through to the uncropped image.
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    def _stem(path):
        # os.path-based stem extraction is cross-platform (handles "\\" on
        # Windows) and only strips the final extension, unlike the fragile
        # path.split("/")[-1].split(".")[0] which truncates dotted names.
        return os.path.splitext(os.path.basename(path))[0]

    if os.path.isdir(cond_image_path_or_dir):
        import glob
        cond_image_list = sorted(
            glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg"))
        )
        return {_stem(p): get_image(p, use_face_crop) for p in cond_image_list}
    return {_stem(cond_image_path_or_dir): get_image(cond_image_path_or_dir, use_face_crop)}

def resize_and_centercrop(cond_image, target_size):
    """Scale-fill *cond_image* to cover *target_size*, then center-crop.

    The input is resized by the larger of the two scale ratios so the target
    area is fully covered (no padding), then center-cropped to exactly
    (target_h, target_w).

    Args:
        cond_image: A PIL image, or a CHW torch.Tensor.
        target_size: Tuple of (target_h, target_w).

    Returns:
        For tensor input: the cropped tensor with the leading batch dim
        squeezed. For PIL input: a uint8 tensor of shape (1, C, 1, H, W).
    """
    is_tensor = isinstance(cond_image, torch.Tensor)

    # Source dimensions.
    if is_tensor:
        _, src_h, src_w = cond_image.shape
    else:
        src_h, src_w = cond_image.height, cond_image.width

    tgt_h, tgt_w = target_size

    # Cover-fit: scale by the larger ratio so both dims reach the target.
    ratio = max(tgt_h / src_h, tgt_w / src_w)
    new_h = math.ceil(ratio * src_h)
    new_w = math.ceil(ratio * src_w)

    if is_tensor:
        batched = cond_image[None] if len(cond_image.shape) == 3 else cond_image
        resized = nn.functional.interpolate(
            batched, size=(new_h, new_w), mode='nearest'
        ).contiguous()
        return transforms.functional.center_crop(resized, target_size).squeeze(0)

    pil_resized = cond_image.resize((new_w, new_h), resample=Image.BILINEAR)
    as_tensor = (
        torch.from_numpy(np.array(pil_resized))[None, ...]
        .permute(0, 3, 1, 2)
        .contiguous()
    )
    cropped = transforms.functional.center_crop(as_tensor, target_size)
    # Insert a singleton frame axis: (1, C, H, W) -> (1, C, 1, H, W).
    return cropped[:, :, None, :, :]

class FlashHeadPipeline:
    """Experiment harness: compares two ways of producing reference latents
    with the LTX VAE (encode-then-repeat vs. repeat-then-encode)."""

    def __init__(
        self,
        vae_dir,
        cond_image_path_or_dir='/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg',
        device="cuda:0",
    ):
        """
        Args:
            vae_dir: Path to the pretrained LTX VAE weights.
            cond_image_path_or_dir: Image file or directory of ``*.jpg``
                conditioning images (previously hard-coded; default keeps
                the original behavior).
            device: Torch device string (previously hard-coded "cuda:0").
        """
        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=torch.bfloat16,
            device=device,
        )

        self.target_h = 448
        self.target_w = 448
        self.device = device
        self.frame_num = 33

        self.param_dtype = torch.bfloat16

        # Compile once; the first encode/decode call pays the compilation cost.
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)

        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        # Run twice: the first pass includes torch.compile warm-up.
        # NOTE: the original outer loop also used `i`, shadowing the inner
        # enumerate index — renamed to `trial` to avoid that.
        for trial in range(2):
            start = time.time()
            for person_name, cond_image_pil in self.cond_image_dict.items():
                # 1 C 1 H W, normalized from [0, 255] to [-1, 1].
                cond_image_tensor = resize_and_centercrop(
                    cond_image_pil, (self.target_h, self.target_w)
                ).to(self.device, dtype=self.param_dtype)
                cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2

                self.cond_image_tensor_dict[person_name] = cond_image_tensor

                # Method A: encode a single frame, then repeat the latent
                # along the temporal axis.
                video_frames = cond_image_tensor.repeat(1, 1, 1, 1, 1)
                latent_A = self.vae.encode(video_frames)
                latent_A = latent_A.repeat(1, 5, 1, 1)

                # Method B: repeat the frame first, then encode the clip.
                video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
                latent_B = self.vae.encode(video_frames)

                # Mean absolute difference between the two latents.
                print((latent_A - latent_B).abs().mean())

            # NOTE: the original print also referenced len(video_frames),
            # which raised NameError when no conditioning images were found.
            print('vae.encode time', time.time() - start, len(self.cond_image_dict))


if __name__ == "__main__":
    # Location of the pretrained LTX VAE checkpoint.
    vae_weights_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(vae_weights_dir)
相关推荐
juyou51182 小时前
清明踏青亲子研学升温,AI+数字乡村技术破解体验与安全管控痛点
大数据·人工智能·科技·ar·语音识别
Juicedata2 小时前
一文解锁 JuiceFS 在 AI 场景中的性能优化
人工智能·性能优化
木头程序员2 小时前
关于load_data_fashion_mnist函数运行原理以及运行速度慢解决方案
人工智能·python·深度学习·d2l
东离与糖宝2 小时前
2026 Java AI框架选型:Spring AI/LangChain4j企业级对比
java·人工智能
yunpeng.zhou2 小时前
深度理解agent与llm之间的关系、及mcp与skill的区别
人工智能·python·ai
CoderJia程序员甲2 小时前
GitHub 热榜项目 - 日榜(2026-04-03)
人工智能·ai·大模型·github·ai教程
TDengine (老段)2 小时前
TDengine IDMP 可视化 —— 趋势图
大数据·数据库·人工智能·物联网·时序数据库·tdengine·涛思数据
东离与糖宝2 小时前
Java AI工程化:PyTorch On Java+SpringBoot微服务部署(2025-2026最新实战)
java·人工智能
2601_955363152 小时前
技术赋能B端拓客:号码核验行业的迭代与价值升级
大数据·人工智能