LtxVAE 学习笔记

目录

[LtxVAE 推理代码](#ltxvae-推理代码)

[repeat 尝试](#repeat-尝试)


LtxVAE 推理代码

python 复制代码
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Build a {person_name: PIL.Image} mapping from an image file or a directory.

    Args:
        cond_image_path_or_dir: path to a single image file, or a directory
            that is scanned for ``*.jpg`` files.
        use_face_crop: when True, try ``process_image`` (face crop) first; on
            failure, log the error and fall back to loading the raw image.

    Returns:
        dict mapping each file's base name (extension stripped) to an RGB image.
    """
    def get_image(cond_image_path, use_face_crop):
        if use_face_crop:
            try:
                return process_image(cond_image_path)
            except Exception as e:
                # Best-effort: log and fall through to the un-cropped image.
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    def key_of(path):
        # Portable "basename without extension". The previous
        # split("/")[-1].split(".")[0] broke on Windows separators and
        # truncated names containing dots (e.g. "a.b.jpg" -> "a").
        return os.path.splitext(os.path.basename(path))[0]

    if os.path.isdir(cond_image_path_or_dir):
        import glob
        cond_image_list = sorted(glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg")))
        return {key_of(p): get_image(p, use_face_crop) for p in cond_image_list}
    return {key_of(cond_image_path_or_dir): get_image(cond_image_path_or_dir, use_face_crop)}

def resize_and_centercrop(cond_image, target_size):
    """Scale *cond_image* so it fully covers *target_size*, then center-crop.

    Accepts either a CHW ``torch.Tensor`` or a PIL image. A tensor input is
    returned as a CHW tensor; a PIL input is returned as a uint8 tensor with
    an extra singleton frame axis, shape (1, C, 1, H, W).
    """
    target_h, target_w = target_size

    is_tensor = isinstance(cond_image, torch.Tensor)
    if is_tensor:
        _, src_h, src_w = cond_image.shape
    else:
        src_h, src_w = cond_image.height, cond_image.width

    # Use the larger of the two ratios so both dimensions cover the target
    # (any overshoot is removed by the center crop below).
    scale = max(target_h / src_h, target_w / src_w)
    new_h = math.ceil(scale * src_h)
    new_w = math.ceil(scale * src_w)

    if is_tensor:
        batched = cond_image[None] if len(cond_image.shape) == 3 else cond_image
        resized = nn.functional.interpolate(
            batched, size=(new_h, new_w), mode='nearest'
        ).contiguous()
        return transforms.functional.center_crop(resized, target_size).squeeze(0)

    resized_pil = cond_image.resize((new_w, new_h), resample=Image.BILINEAR)
    as_tensor = torch.from_numpy(np.array(resized_pil))[None, ...].permute(0, 3, 1, 2).contiguous()
    cropped = transforms.functional.center_crop(as_tensor, target_size)
    # Insert the singleton frame axis: (1, C, 1, H, W).
    return cropped[:, :, None, :, :]

class FlashHeadPipeline:
    """Loads the LTX VAE, preprocesses reference images, and caches their latents.

    Args:
        vae_dir: directory of the pretrained LTX VAE weights.
        cond_image_path_or_dir: reference image file or directory of ``*.jpg``
            files (previously hard-coded; default keeps the old behavior).
        device: torch device string for the VAE and tensors.
        target_h / target_w: spatial size reference images are resized/cropped to.
        frame_num: number of times the single frame is tiled along the time axis
            before VAE encoding.
    """

    def __init__(
        self,
        vae_dir,
        cond_image_path_or_dir='/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg',
        device="cuda:0",
        target_h=448,
        target_w=448,
        frame_num=33,
    ):
        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=torch.bfloat16,
            device=device,
        )

        self.target_h = target_h
        self.target_w = target_w
        self.device = device
        self.frame_num = frame_num
        self.param_dtype = torch.bfloat16

        # Compile the VAE entry points once; the first encode/decode call pays
        # the compilation cost, later calls run the compiled graph.
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)

        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}
        start = time.time()
        for person_name, cond_image_pil in self.cond_image_dict.items():
            # (1, C, 1, H, W), values mapped from [0, 255] to [-1, 1].
            cond_image_tensor = resize_and_centercrop(
                cond_image_pil, (self.target_h, self.target_w)
            ).to(self.device, dtype=self.param_dtype)
            cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
            self.cond_image_tensor_dict[person_name] = cond_image_tensor

            # Tile the single frame along the time axis, then encode the clip.
            # Latent shape e.g. (16, 9, 64, 64) / (128, 5, 16, 16) per original note.
            video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
            self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames)

        # NOTE: the old print referenced `video_frames`, which is undefined when
        # the image dict is empty; this form is safe for any number of images.
        logger.info(
            "vae.encode time {:.3f}s for {} image(s)",
            time.time() - start,
            len(self.cond_image_dict),
        )


if __name__ == "__main__":
    # Local checkpoint directory for the pretrained LTX VAE weights.
    ltx_vae_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(ltx_vae_dir)

repeat 尝试:

python 复制代码
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    """Build a {person_name: PIL.Image} mapping from an image file or a directory.

    Args:
        cond_image_path_or_dir: path to a single image file, or a directory
            that is scanned for ``*.jpg`` files.
        use_face_crop: when True, try ``process_image`` (face crop) first; on
            failure, log the error and fall back to loading the raw image.

    Returns:
        dict mapping each file's base name (extension stripped) to an RGB image.
    """
    def get_image(cond_image_path, use_face_crop):
        if use_face_crop:
            try:
                return process_image(cond_image_path)
            except Exception as e:
                # Best-effort: log and fall through to the un-cropped image.
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    def key_of(path):
        # Portable "basename without extension". The previous
        # split("/")[-1].split(".")[0] broke on Windows separators and
        # truncated names containing dots (e.g. "a.b.jpg" -> "a").
        return os.path.splitext(os.path.basename(path))[0]

    if os.path.isdir(cond_image_path_or_dir):
        import glob
        cond_image_list = sorted(glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg")))
        return {key_of(p): get_image(p, use_face_crop) for p in cond_image_list}
    return {key_of(cond_image_path_or_dir): get_image(cond_image_path_or_dir, use_face_crop)}

def resize_and_centercrop(cond_image, target_size):
    """Scale *cond_image* so it fully covers *target_size*, then center-crop.

    Accepts either a CHW ``torch.Tensor`` or a PIL image. A tensor input is
    returned as a CHW tensor; a PIL input is returned as a uint8 tensor with
    an extra singleton frame axis, shape (1, C, 1, H, W).
    """
    target_h, target_w = target_size

    is_tensor = isinstance(cond_image, torch.Tensor)
    if is_tensor:
        _, src_h, src_w = cond_image.shape
    else:
        src_h, src_w = cond_image.height, cond_image.width

    # Use the larger of the two ratios so both dimensions cover the target
    # (any overshoot is removed by the center crop below).
    scale = max(target_h / src_h, target_w / src_w)
    new_h = math.ceil(scale * src_h)
    new_w = math.ceil(scale * src_w)

    if is_tensor:
        batched = cond_image[None] if len(cond_image.shape) == 3 else cond_image
        resized = nn.functional.interpolate(
            batched, size=(new_h, new_w), mode='nearest'
        ).contiguous()
        return transforms.functional.center_crop(resized, target_size).squeeze(0)

    resized_pil = cond_image.resize((new_w, new_h), resample=Image.BILINEAR)
    as_tensor = torch.from_numpy(np.array(resized_pil))[None, ...].permute(0, 3, 1, 2).contiguous()
    cropped = transforms.functional.center_crop(as_tensor, target_size)
    # Insert the singleton frame axis: (1, C, 1, H, W).
    return cropped[:, :, None, :, :]

class FlashHeadPipeline:
    """Loads the LTX VAE and compares two ways of building a reference latent.

    Method A: encode a single frame, then tile the latent along the time axis.
    Method B: tile the frame along the time axis, then encode the whole clip.
    The mean absolute difference between the two latents is printed per image.

    Args:
        vae_dir: directory of the pretrained LTX VAE weights.
        cond_image_path_or_dir: reference image file or directory of ``*.jpg``
            files (previously hard-coded; default keeps the old behavior).
        device: torch device string for the VAE and tensors.
    """

    def __init__(
        self,
        vae_dir,
        cond_image_path_or_dir='/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg',
        device="cuda:0",
    ):
        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=torch.bfloat16,
            device=device,
        )

        self.target_h = 448
        self.target_w = 448
        self.device = device
        self.frame_num = 33
        self.param_dtype = torch.bfloat16

        # Compile once; the first encode/decode call pays the compilation cost.
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, True)

        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        # Two passes: the first includes torch.compile warm-up, the second
        # measures steady-state. BUG FIX: the outer loop previously reused `i`,
        # shadowing the inner enumerate index.
        for trial in range(2):
            start = time.time()
            video_frames = None
            for person_name, cond_image_pil in self.cond_image_dict.items():
                # (1, C, 1, H, W), values mapped from [0, 255] to [-1, 1].
                cond_image_tensor = resize_and_centercrop(
                    cond_image_pil, (self.target_h, self.target_w)
                ).to(self.device, dtype=self.param_dtype)
                cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
                self.cond_image_tensor_dict[person_name] = cond_image_tensor

                # Method A: encode the single frame, then repeat the latent
                # along its time axis. NOTE(review): assumes the latent is 4-D
                # with time at dim 1 -- confirm against LtxVAE.encode.
                video_frames = cond_image_tensor.repeat(1, 1, 1, 1, 1)
                latent_A = self.vae.encode(video_frames)
                latent_A = latent_A.repeat(1, 5, 1, 1)

                # Method B: repeat the frame along the time axis, then encode
                # the full clip.
                video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
                latent_B = self.vae.encode(video_frames)

                # Compare: mean absolute difference between the two latents.
                print((latent_A - latent_B).abs().mean())

            # Safe even when no images were processed (video_frames stays None).
            print('vae.encode time', time.time() - start,
                  len(self.cond_image_dict),
                  0 if video_frames is None else len(video_frames))


if __name__ == "__main__":
    # Local checkpoint directory for the pretrained LTX VAE weights.
    ltx_vae_dir = r"/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(ltx_vae_dir)
相关推荐
AI机器学习算法1 天前
深度学习模型演进:6个里程碑式CNN架构
人工智能·深度学习·cnn·大模型·ai学习路线
Ztopcloud极拓云视角1 天前
从 OpenRouter 数据看中美 AI 调用量反转:统计口径、模型路由与多云应对方案
人工智能·阿里云·大模型·token·中美ai
AI医影跨模态组学1 天前
如何将深度学习MTSR与膀胱癌ITGB8/TGF-β/WNT机制建立关联,并进一步解释其与患者预后及肿瘤侵袭、免疫抑制的生物学联系
人工智能·深度学习·论文·医学影像
搬砖的前端1 天前
AI编辑器开源主模型搭配本地模型辅助对标GPT5.2/GPT5.4/Claude4.6(前端开发专属)
人工智能·开源·claude·mcp·trae·qwen3.6·ops4.6
Python私教1 天前
Hermes Agent 安全加固与生态扩展:2026-04-23 更新解析
人工智能
饼干哥哥1 天前
Kimi K2.6 干成了Claude Design国产版,一句话生成电影级的动态品牌网站
人工智能
肖有米XTKF86461 天前
带货者精品优选模式系统的平台解析
人工智能·信息可视化·团队开发·csdn开发云
天天进步20151 天前
打破沙盒限制:OpenWork 如何通过权限模型实现安全的系统级调用?
人工智能·安全
xcbrand1 天前
政府事业机构品牌策划公司找哪家
大数据·人工智能·python
骥龙1 天前
第十篇:合规与未来展望——构建AI智能体安全标准
人工智能·安全