LtxVAE 学习笔记

目录

[LtxVAE 推理代码](#ltxvae-推理代码)

[repeat 尝试:](#repeat-尝试)


LtxVAE 推理代码

python 复制代码
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    def get_image(cond_image_path, use_face_crop):
        if use_face_crop:
            try:
                image = process_image(cond_image_path)
                return image
            except Exception as e:
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    if os.path.isdir(cond_image_path_or_dir):
        import glob
        cond_image_list = glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg"))
        cond_image_list.sort()
        cond_image_dict = {cond_image.split("/")[-1].split(".")[0]: get_image(cond_image, use_face_crop) for cond_image in cond_image_list}
    else:
        cond_image_dict = {cond_image_path_or_dir.split("/")[-1].split(".")[0]: get_image(cond_image_path_or_dir, use_face_crop)}
    return cond_image_dict

def resize_and_centercrop(cond_image, target_size):
    """
    Resize image or tensor to the target size without padding.
    """

    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size
    
    # Calculate the scaling factor for resizing
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w
    
    # Compute the final size
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)
    
    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous() 
        # crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size) 
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # tensor and crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor[:, :, None, :, :] 

    return cropped_tensor

class FlashHeadPipeline:
    """Load the LTX VAE and pre-encode reference ("cond") images into latents.

    Each reference image goes through four steps:
    1. resized/center-cropped to ``target_h x target_w``;
    2. scaled from [0, 255] to [-1, 1];
    3. repeated ``frame_num`` times along dim 2 (the time axis) to form a static clip;
    4. encoded with the VAE.

    Results are stored in ``cond_image_tensor_dict`` and ``ref_img_latent_dict``,
    keyed by file stem.
    """

    DEFAULT_COND_IMAGE = '/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg'

    def __init__(self, vae_dir, cond_image_path_or_dir=DEFAULT_COND_IMAGE, device="cuda:0", use_face_crop=True):
        """Build the VAE, compile it, and encode all reference images.

        Args:
            vae_dir: path passed to ``LtxVAE`` as ``pretrained_model_type_or_path``.
            cond_image_path_or_dir: a reference image or a directory of ``*.jpg``.
            device: device for the VAE and the image tensors.
            use_face_crop: forwarded to ``get_cond_image_dict``.
        """
        self.device = device
        self.param_dtype = torch.bfloat16
        self.target_h = 448
        self.target_w = 448
        self.frame_num = 33

        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=self.param_dtype,
            device=self.device,
        )
        # torch.compile compiles lazily on the first call, so the first encode
        # below includes compilation time.
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, use_face_crop)
        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}
        self._encode_reference_images()

    def _prepare_cond_image(self, cond_image):
        """Resize/crop an image to (1, C, 1, H, W) on ``self.device`` and map 8-bit pixels to [-1, 1]."""
        tensor = resize_and_centercrop(cond_image, (self.target_h, self.target_w)).to(self.device, dtype=self.param_dtype)
        return (tensor / 255 - 0.5) * 2

    def _synchronize(self):
        """Wait for queued CUDA work so wall-clock timings cover the actual kernels."""
        if torch.cuda.is_available() and str(self.device).startswith("cuda"):
            torch.cuda.synchronize(self.device)

    def _encode_reference_images(self):
        """Encode every image in ``self.cond_image_dict`` and print the elapsed time."""
        start = time.perf_counter()
        # Latents are stored as fixed references; no autograd graph is needed.
        with torch.no_grad():
            for person_name, cond_image_pil in self.cond_image_dict.items():
                cond_image_tensor = self._prepare_cond_image(cond_image_pil)
                self.cond_image_tensor_dict[person_name] = cond_image_tensor
                # Static clip: the same frame repeated frame_num times along dim 2.
                video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
                self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames)
        # CUDA kernels run asynchronously; sync before reading the clock.
        self._synchronize()
        print('vae.encode time', time.perf_counter() - start, len(self.cond_image_dict), self.frame_num)


if __name__ == "__main__":
    # Script entry: constructing the pipeline loads the VAE and encodes the reference image(s).
    VAE_DIR = "/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(vae_dir=VAE_DIR)

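repeat 尝试:

对比两种获得静态参考视频潜变量的方式：方法 A 先用 VAE 编码单帧，再在潜变量的时间维上 repeat；方法 B 先把单帧在像素空间 repeat 成 33 帧视频，再整体编码。下面的代码打印两者的平均绝对误差。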

python 复制代码
import math
import os
import time
import numpy as np
import torch
import torch.distributed as dist
from einops import rearrange
import torch.nn as nn
from transformers import Wav2Vec2FeatureExtractor
import torchvision.transforms as transforms
from PIL import Image

from loguru import logger
import torch

from flash_head.ltx_video.ltx_vae import LtxVAE
from flash_head.utils.facecrop import process_image

def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
    def get_image(cond_image_path, use_face_crop):
        if use_face_crop:
            try:
                image = process_image(cond_image_path)
                return image
            except Exception as e:
                logger.error(f"Error processing {cond_image_path}: {e}")
        return Image.open(cond_image_path).convert("RGB")

    if os.path.isdir(cond_image_path_or_dir):
        import glob
        cond_image_list = glob.glob(os.path.join(cond_image_path_or_dir, "*.jpg"))
        cond_image_list.sort()
        cond_image_dict = {cond_image.split("/")[-1].split(".")[0]: get_image(cond_image, use_face_crop) for cond_image in cond_image_list}
    else:
        cond_image_dict = {cond_image_path_or_dir.split("/")[-1].split(".")[0]: get_image(cond_image_path_or_dir, use_face_crop)}
    return cond_image_dict

def resize_and_centercrop(cond_image, target_size):
    """
    Resize image or tensor to the target size without padding.
    """

    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size
    
    # Calculate the scaling factor for resizing
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w
    
    # Compute the final size
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)
    
    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous() 
        # crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size) 
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # tensor and crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor[:, :, None, :, :] 

    return cropped_tensor

class FlashHeadPipeline:
    """Experiment: two ways of getting the latent of a static reference clip.

    Method A encodes the single reference frame with the VAE, then repeats the
    latent along its time axis. Method B repeats the frame ``frame_num`` times in
    pixel space and encodes the whole clip; this is what the inference pipeline
    does, so its result is stored in ``ref_img_latent_dict``. The mean absolute
    difference between the two latents is printed for every image.

    The whole pass runs ``n_rounds`` times. ``torch.compile`` compiles lazily on
    the first call, so only rounds after the first give meaningful timings.
    """

    DEFAULT_COND_IMAGE = '/data/lbg/project/SoulX-FlashHead-api2/imgs/d11_960.jpg'

    def __init__(self, vae_dir, cond_image_path_or_dir=DEFAULT_COND_IMAGE, device="cuda:0", use_face_crop=True, n_rounds=2):
        """Build and compile the VAE, then run the A/B comparison ``n_rounds`` times.

        Args:
            vae_dir: path passed to ``LtxVAE`` as ``pretrained_model_type_or_path``.
            cond_image_path_or_dir: a reference image or a directory of ``*.jpg``.
            device: device for the VAE and the image tensors.
            use_face_crop: forwarded to ``get_cond_image_dict``.
            n_rounds: number of passes (the first one includes compilation).
        """
        self.device = device
        self.param_dtype = torch.bfloat16
        self.target_h = 448
        self.target_w = 448
        self.frame_num = 33

        self.vae = LtxVAE(
            pretrained_model_type_or_path=vae_dir,
            dtype=self.param_dtype,
            device=self.device,
        )
        self.vae.model.encode = torch.compile(self.vae.model.encode)
        self.vae.model.decode = torch.compile(self.vae.model.decode)

        self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, use_face_crop)
        self.cond_image_tensor_dict = {}
        self.ref_img_latent_dict = {}

        for _ in range(n_rounds):
            self._compare_encodings()

    def _prepare_cond_image(self, cond_image):
        """Resize/crop an image to (1, C, 1, H, W) on ``self.device`` and map 8-bit pixels to [-1, 1]."""
        tensor = resize_and_centercrop(cond_image, (self.target_h, self.target_w)).to(self.device, dtype=self.param_dtype)
        return (tensor / 255 - 0.5) * 2

    def _synchronize(self):
        """Wait for queued CUDA work so wall-clock timings cover the actual kernels."""
        if torch.cuda.is_available() and str(self.device).startswith("cuda"):
            torch.cuda.synchronize(self.device)

    def _compare_encodings(self):
        """Run methods A and B on every reference image, print their difference and the elapsed time."""
        start = time.perf_counter()
        with torch.no_grad():
            for person_name, cond_image_pil in self.cond_image_dict.items():
                cond_image_tensor = self._prepare_cond_image(cond_image_pil)
                self.cond_image_tensor_dict[person_name] = cond_image_tensor

                # Method A: encode the single frame only (defensive copy in case encode mutates its input).
                latent_A = self.vae.encode(cond_image_tensor.clone())

                # Method B: repeat the frame frame_num times in pixel space, then encode the clip.
                video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
                latent_B = self.vae.encode(video_frames)
                self.ref_img_latent_dict[person_name] = latent_B

                # Repeat A along the latent time axis to B's length instead of a hard-coded 5.
                # Assumes encode returns 4-D latents with time on dim 1, as the
                # original repeat(1, 5, 1, 1) implied — TODO confirm.
                latent_A = latent_A.repeat(1, latent_B.shape[1] // latent_A.shape[1], 1, 1)

                print(person_name, (latent_A - latent_B).abs().mean())
        # CUDA kernels run asynchronously; sync before reading the clock.
        self._synchronize()
        print('vae.encode time', time.perf_counter() - start, len(self.cond_image_dict), self.frame_num)


if __name__ == "__main__":
    # Script entry: constructing the pipeline loads the VAE and runs the A/B encoding comparison.
    VAE_DIR = "/data/lbg/models/flash_head_models/SoulX-FlashHead-1.3B/VAE_LTX/"
    flash_pipe = FlashHeadPipeline(vae_dir=VAE_DIR)
相关推荐
枫叶林FYL1 小时前
【机器学习与智慧医疗】T2DM-EWS: 2型糖尿病早期预警系统(多参数集成分类模型)完整实现
人工智能·机器学习·分类
南屹川1 小时前
【缓存技术】Redis实战:从缓存策略到分布式锁
人工智能
Li emily8 小时前
解决了加密货币api多币种订阅时的数据乱序问题
人工智能·python·api·fastapi
山川绿水8 小时前
bugku——PWN——overflow2
人工智能·web安全·网络安全
程序员cxuan8 小时前
微信读书官方发了 skills,把我给秀麻了。
人工智能·后端·程序员
fake_ss1988 小时前
AI时代学习全栈项目开发的新范式
java·人工智能·学习·架构·个人开发·学习方法
nassi_8 小时前
对AI工程问题的一些思考
大数据·人工智能·hadoop
AI技术控8 小时前
《Transformers are Inherently Succinct》论文解读:从“能表达什么”到“多紧凑地表达”
人工智能·python·深度学习·机器学习·自然语言处理
蔡俊锋8 小时前
AI记忆压缩术:从305GB到7.4GB的魔法
人工智能·ai·ai 记忆
Upsy-Daisy9 小时前
AI Agent 项目学习笔记(二):Spring AI 与 ChatClient 主链路解析
人工智能·笔记·学习