AIGC笔记--SVD中UNet加载预训练权重

1--加载方式

  1. 加载全参数(.ckpt)

  2. 加载LoRA(.safetensors)

2--简单实例

python 复制代码
import sys
sys.path.append("/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/Codes/v3d-vgen-motion")

import torch
from peft import LoraConfig
from safetensors import safe_open

from svd.models.i2v_svd_unet import UNetSpatioTemporalConditionModel
from svd.utils.util import zero_rank_print

if __name__ == "__main__":

    pretrained_model_path = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/Codes/svd_models/models/stable-video-diffusion-img2vid-xt"
    unet = UNetSpatioTemporalConditionModel.from_pretrained(pretrained_model_path, subfolder = "unet")

    # resume_checkpoint_path = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/Codes/v3d-vgen-motion/results/outputs_motionlora_realRota_0603_1024_stride5/test-0-2024-06-03T14-31-30/checkpoints/checkpoint-500.safetensors"
    resume_checkpoint_path = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-waimai-aigc/liujinfu/Codes/v3d-vgen-motion/results/outputs_motionFull_realRota_0529_stride5/test-0-2024-05-29T10-04-34/checkpoints/checkpoint-step-5000.ckpt"

    # Load pretrained unet weights
    if resume_checkpoint_path.endswith(".ckpt"):
        zero_rank_print(f"resume from checkpoint: {resume_checkpoint_path}")
        resume_checkpoint = torch.load(resume_checkpoint_path, map_location="cpu")
        # resume dit parameters
        print(f'resume_checkpoint keys: {resume_checkpoint.keys()}')
        state_dict = resume_checkpoint["state_dict"]
        m, u = unet.load_state_dict(state_dict, strict=False)
        zero_rank_print(f"dit missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0
        # resume global step
        resume_global_step = False
        if "global_step" in resume_checkpoint and resume_global_step:
            zero_rank_print(f"resume global_step: {resume_checkpoint['global_step']}")
            global_step = resume_checkpoint['global_step']    
            
    elif resume_checkpoint_path.endswith(".safetensors"):

        unet_lora_config = LoraConfig(
            r = 64, 
            lora_alpha = 64, # scaling = lora_alpha / r
            init_lora_weights = "gaussian", 
            target_modules = ["to_q","to_k","to_v","to_out.0"],
            lora_dropout = 0.1
        )
        unet.add_adapter(unet_lora_config)

        zero_rank_print(f"resume from safetensors: {resume_checkpoint_path}")
        
        state_dict = {}
        with safe_open(resume_checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                key_ = key.replace('unet.', '').replace('.weight', '')
                state_dict[key_] = f.get_tensor(key)

                u = 0
                try:
                    unet.get_submodule(key_+'.default').state_dict()['weight'].data.copy_(state_dict[key_])
                except:
                    u += 1
        assert u == 0, "resume unet params failed"

    print("All Done!")
相关推荐
墨风如雪8 小时前
DeepSeek OCR:用'眼睛'阅读长文本,AI记忆新纪元?
aigc
算家计算13 小时前
SAIL-VL2本地部署教程:2B/8B参数媲美大规模模型,为轻量级设备量身打造的多模态大脑
人工智能·开源·aigc
ECT-OS-JiuHuaShan18 小时前
《元推理框架技术白皮书》,人工智能领域的“杂交水稻“
人工智能·aigc·学习方法·量子计算·空间计算
Jagger_21 小时前
组织能力才是AI公司真正的壁垒:构建AI Native组织的完整指南
aigc
Mintopia21 小时前
🧩 隐私计算技术在 Web AIGC 数据处理中的应用实践
前端·javascript·aigc
程序员X小鹿1 天前
谷歌又出黑科技:支持图文混排的AI创意画布来了!1个想法,3秒出图,免费可用!(附教程)
aigc
万里鹏程转瞬至1 天前
开源项目分析:wan2.1 VACE 关键设计与实现代码解读
论文阅读·aigc
墨风如雪1 天前
告别代码苦海:Manus 1.5 让你的创意以光速落地
aigc
麦麦麦造2 天前
有了 MCP,为什么Claude 还要推出 Skills?
人工智能·aigc·ai编程
张晓~183399481212 天前
碰一碰发视频 系统源码 /PHP 语言开发方案
开发语言·线性代数·矩阵·aigc·php·音视频·文心一言