带有 embedding 同时训练的 LoRA 权重合并:合并后的模型再训练时 Loss 突然增加

正确写法:带有 embedding 同时训练的 LoRA 权重合并(使用 PEFT 的 merge_and_unload)

Python 代码:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_path = "/root/private_data/models/deepseek-ai/6epoch-merged"
adapter_path = "/root/private_data/output/sharedata/checkpoint-10500/"
output_path = "/root/private_data/models/deepseek-ai/merged_correct"

# Load the base model in bfloat16 on the first GPU, attach the trained adapter and
# fold it into the base weights in one chain. Loading through PeftModel lets PEFT map
# the adapter's saved tensors (LoRA A/B pairs and any fully-trained modules, such as
# an embedding trained alongside the LoRA) back onto the model before merging.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype="bfloat16",
        device_map="cuda:0",
    ),
    adapter_path,
).merge_and_unload()
model.save_pretrained(output_path)

# The tokenizer is copied unchanged from the base model.
# NOTE(review): if training added new tokens, the tokenizer saved with the adapter
# checkpoint is presumably the one matching the merged vocab — confirm.
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
tokenizer.save_pretrained(output_path)

错误写法 1:使用 llamafactory-cli export 合并,合并后的模型再训练时 Loss 突然增加

Bash 命令:
### Merge a LLaMA-Factory LoRA adapter into a new, standalone model.
### Note: --export_dir must point to a NEW directory (do not overwrite an existing model).
### When training again, do not reuse the original checkpoint directory for the new
### checkpoints, otherwise the LoRA weights will not line up (mismatch error).
### Not suitable for adapters trained together with the embedding layer: per the author,
### information is lost and the loss jumps when training continues from the merged model.

llamafactory-cli export \
 --model_name_or_path /root/private_data/models/deepseek-ai/6epoch-merged \
 --adapter_name_or_path /root/private_data/output/binding_sft_short_part_004-005_lora64_epoch1/checkpoint-9000 \
 --template alpaca \
 --finetuning_type lora \
 --export_dir /root/private_data/models/deepseek-ai/6epoch-merged-lora64_data005 \
 --export_size 4 \
 --export_device auto \
 --export_legacy_format false

错误写法 2:手动逐分片合并 safetensors 权重

Python 代码:
import json
import math
import os
import re
import shutil

import torch
from safetensors.torch import load_file, save_file


def _lora_scaling(lora_config: dict, module_name: str, rank: int) -> float:
    """Return the LoRA scaling factor for one module.

    Args:
        lora_config: parsed ``adapter_config.json``.
        module_name: dotted module name inside the base model, e.g.
            ``"model.layers.0.self_attn.q_proj"``; used to look up an
            ``alpha_pattern`` override.
        rank: the module's actual LoRA rank (taken from the stored lora_A tensor,
            so ``rank_pattern`` overrides are honoured without parsing them).

    Returns:
        ``alpha / sqrt(rank)`` when ``use_rslora`` is set, otherwise ``alpha / rank``.
    """
    alpha = lora_config["lora_alpha"]
    # alpha_pattern keys are treated as regexes that must match a dotted suffix of
    # the module name (an approximation of PEFT's own pattern matching).
    for pattern, pattern_alpha in (lora_config.get("alpha_pattern") or {}).items():
        if re.fullmatch(rf"(?:.*\.)?(?:{pattern})", module_name):
            alpha = pattern_alpha
            break
    if lora_config.get("use_rslora", False):
        return alpha / math.sqrt(rank)
    return alpha / rank


def _full_replacement_key(adapter_weights: dict, prefixed_key: str):
    """Return the adapter key holding a full trained copy of a base tensor, or None.

    ``prefixed_key`` is the base-model key with the PEFT prefix prepended, e.g.
    ``"base_model.model.model.embed_tokens.weight"``. Candidates, in order:

    * the prefixed key itself — the layout recent PEFT versions use when saving
      ``modules_to_save`` (the ``modules_to_save.<adapter>`` segment is stripped);
    * ``<module>.modules_to_save.default.<param>`` — the unstripped layout;
    * ``<module>.base_layer.<param>`` — a resized embedding / lm_head saved next to
      LoRA factors for the same layer.
    """
    module, _, param = prefixed_key.rpartition(".")
    candidates = (
        prefixed_key,
        f"{module}.modules_to_save.default.{param}",
        f"{module}.base_layer.{param}",
    )
    return next((k for k in candidates if k in adapter_weights), None)


def _lora_delta(
    adapter_weights: dict,
    lora_config: dict,
    peft_module_key: str,
    module_name: str,
    used_keys: set,
):
    """Return the float32 LoRA update for one base weight, or None if the adapter has none.

    Args:
        adapter_weights: tensors loaded from ``adapter_model.safetensors``.
        lora_config: parsed ``adapter_config.json``.
        peft_module_key: PEFT-prefixed module name, e.g.
            ``"base_model.model.model.layers.0.self_attn.q_proj"``.
        module_name: the same module name without the PEFT prefix (for ``alpha_pattern``).
        used_keys: set that receives the adapter keys consumed here.

    Linear layers are looked up as ``<module>.lora_A.weight`` / ``<module>.lora_B.weight``;
    embedding layers as ``<module>.lora_embedding_A`` / ``<module>.lora_embedding_B``.
    """
    layouts = (
        (f"{peft_module_key}.lora_A.weight", f"{peft_module_key}.lora_B.weight", False),
        (f"{peft_module_key}.lora_embedding_A", f"{peft_module_key}.lora_embedding_B", True),
    )
    for a_key, b_key, is_embedding in layouts:
        if a_key not in adapter_weights or b_key not in adapter_weights:
            continue
        lora_a = adapter_weights[a_key].float()
        lora_b = adapter_weights[b_key].float()
        # lora_A's first dimension is the module's actual rank (including any
        # rank_pattern override), so it is used instead of the config's global r.
        delta = (lora_b @ lora_a) * _lora_scaling(lora_config, module_name, lora_a.shape[0])
        # B @ A is (out, in). Embedding weights are stored as (num_embeddings, dim) and
        # fan_in_fan_out (Conv1D-style) weights as (in, out), so both need a transpose.
        if is_embedding or lora_config.get("fan_in_fan_out", False):
            delta = delta.T
        used_keys.update((a_key, b_key))
        return delta
    return None


def merge_full_optimized(
    base_model_path: str,
    adapter_path: str,
    output_merged_path: str,
    device: str = "cuda:0"
):
    """Merge a PEFT LoRA adapter (plus any fully-trained modules) into a sharded base model.

    Every ``*.safetensors`` shard of the base model is loaded onto ``device``, patched,
    and saved under the same file name in ``output_merged_path``:

    * tensors for which the adapter stores a full trained copy (e.g. ``embed_tokens`` /
      ``lm_head`` trained via ``modules_to_save``) are replaced by that copy;
    * tensors with LoRA factors get ``scaling * (B @ A)`` added, computed in float32
      and cast back to the shard's dtype.

    Non-weight files of the base model are copied first, ``vocab_size`` in
    ``config.json`` is updated when ``embed_tokens`` was replaced, and a new
    ``model.safetensors.index.json`` is written. Adapter tensors that were never used
    are reported, because that means part of the training result is not in the output.

    Args:
        base_model_path: directory holding the base model's safetensors shards and config.
        adapter_path: directory holding ``adapter_model.safetensors`` and ``adapter_config.json``.
        output_merged_path: directory the merged model is written to (created if missing).
        device: torch device used for loading and merging.

    Raises:
        ValueError: if the adapter enables DoRA (``use_dora``), whose merge also needs
            the learned magnitude vector.
        FileNotFoundError: if ``base_model_path`` contains no ``.safetensors`` shards.
    """
    print(f"🚀 [START] 开始全量合并流程 (LoRA + Embedding)...")

    # 1. Load the adapter tensors and parse the LoRA config.
    print(f"📥 加载 Adapter 权重...")
    adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors"), device=device)

    with open(os.path.join(adapter_path, "adapter_config.json"), 'r') as f:
        lora_config = json.load(f)

    if lora_config.get("use_dora", False):
        raise ValueError("DoRA adapter 还需要 magnitude 向量,本脚本无法正确合并,请改用 PeftModel.merge_and_unload()")

    r = lora_config["r"]
    alpha = lora_config["lora_alpha"]
    use_rslora = lora_config.get("use_rslora", False)
    # Default scaling, for display only; per-module overrides are resolved in _lora_delta.
    scaling = alpha / math.sqrt(r) if use_rslora else alpha / r
    print(f"✅ LoRA 配置读取成功: r={r}, alpha={alpha}, scaling={scaling}, use_rslora={use_rslora}")

    # 2. Prepare the output directory: copy every non-weight file of the base model first.
    os.makedirs(output_merged_path, exist_ok=True)

    print(f"📋 从基础模型复制配置文件...")
    for file in os.listdir(base_model_path):
        src = os.path.join(base_model_path, file)
        if os.path.isfile(src) and not any(file.endswith(ex) for ex in ['.safetensors', '.bin', '.pt']):
            shutil.copy2(src, os.path.join(output_merged_path, file))

    # Non-weight files of the adapter directory are not copied (adapter_config.json is
    # not part of a merged model). NOTE(review): if training added tokens, the tokenizer
    # saved with the adapter checkpoint may be the one that matches the merged vocab.

    # 3. Core merge, one shard at a time.
    safetensors_files = sorted(f for f in os.listdir(base_model_path) if f.endswith('.safetensors'))
    if not safetensors_files:
        raise FileNotFoundError(f"未在 {base_model_path} 中找到 .safetensors 权重文件")

    prefix = "base_model.model."  # key prefix PEFT puts in front of base-model names
    weight_map = {}
    total_size = 0
    new_num_tokens = None
    used_adapter_keys = set()
    merged_count = 0
    replaced_count = 0

    for st_file in safetensors_files:
        print(f"📦 处理分片: {st_file}...")
        state_dict = load_file(os.path.join(base_model_path, st_file), device=device)

        for key in list(state_dict.keys()):
            # --- A. Replace tensors the adapter stores in full (e.g. embed_tokens / lm_head) ---
            full_key = _full_replacement_key(adapter_weights, f"{prefix}{key}")
            if full_key is not None:
                state_dict[key] = adapter_weights[full_key].to(state_dict[key].dtype)
                used_adapter_keys.add(full_key)
                replaced_count += 1
                if "embed_tokens" in key:
                    new_num_tokens = state_dict[key].shape[0]
                print(f"💎 已替换模块: {key}")

            # --- B. Add the LoRA update (on top of a replaced base_layer copy, if any) ---
            if not key.endswith(".weight"):
                continue
            module_name = key[: -len(".weight")]
            delta = _lora_delta(adapter_weights, lora_config, f"{prefix}{module_name}", module_name, used_adapter_keys)
            if delta is not None:
                base_weight = state_dict[key]
                # Accumulate in float32, then return to the shard's original dtype.
                state_dict[key] = (base_weight.float() + delta).to(base_weight.dtype)
                merged_count += 1

        # Save the shard under its original file name.
        cpu_state_dict = {k: v.contiguous().cpu() for k, v in state_dict.items()}
        save_file(cpu_state_dict, os.path.join(output_merged_path, st_file), metadata={"format": "pt"})
        weight_map.update(dict.fromkeys(cpu_state_dict, st_file))
        # The index's total_size counts tensor bytes, not file bytes (headers excluded).
        total_size += sum(t.numel() * t.element_size() for t in cpu_state_dict.values())

        del state_dict, cpu_state_dict
        torch.cuda.empty_cache()

    print(f"🔧 LoRA 合并 {merged_count} 个权重, 整体替换 {replaced_count} 个权重")
    unused_keys = sorted(set(adapter_weights) - used_adapter_keys)
    if unused_keys:
        # Any unused adapter tensor is training output missing from the merged model
        # (e.g. an unrecognised key layout, or a module absent from the base shards).
        print(f"⚠️ 警告: {len(unused_keys)} 个 adapter 权重未被合并, 前几个: {unused_keys[:5]}")

    # 4. Update config.json when the embedding matrix was replaced (vocab may have grown).
    if new_num_tokens:
        cfg_path = os.path.join(output_merged_path, "config.json")
        if os.path.exists(cfg_path):
            with open(cfg_path, 'r+') as f:
                config = json.load(f)
                config["vocab_size"] = new_num_tokens
                f.seek(0)
                json.dump(config, f, indent=2)
                f.truncate()
            print(f"📝 已更新 config.json: vocab_size = {new_num_tokens}")
        else:
            print(f"⚠️ 警告: 未找到 config.json,无法更新 vocab_size")

    # 5. Write the shard index (overwrites any index copied from the base model).
    index_path = os.path.join(output_merged_path, "model.safetensors.index.json")
    with open(index_path, 'w') as f:
        json.dump({"metadata": {"total_size": total_size}, "weight_map": weight_map}, f, indent=2)

    print(f"✨ [DONE] 全量合并成功!输出目录: {output_merged_path}")
    print(f"📁 请确认以下文件存在:")
    for f in ["config.json", "tokenizer.json", "tokenizer_config.json", "model.safetensors.index.json"]:
        p = os.path.join(output_merged_path, f)
        status = "✅" if os.path.exists(p) else "❌"
        print(f"   {status} {f}")


if __name__ == "__main__":
    # Paths for one merge run; each key is one of merge_full_optimized's parameter names.
    CONFIG = dict(
        base_model_path="/root/private_data/models/deepseek-ai/6epoch-merged",
        adapter_path="/root/private_data/output/sharedata/checkpoint-10500/",
        output_merged_path="/root/private_data/models/deepseek-ai/6epoch-merged-sft",
        device="cuda:0",
    )

    # `**CONFIG` is dictionary unpacking: every key/value pair of CONFIG is passed as a
    # separate keyword argument, i.e. this is the same as calling
    # merge_full_optimized(base_model_path=..., adapter_path=..., output_merged_path=..., device=...).
    merge_full_optimized(**CONFIG)
相关推荐
xiezhr25 分钟前
折腾了半小时,终于让AI能帮我写飞书文档了
人工智能·agent·ai编程
ZhengEnCi10 小时前
09bad-斯坦福CS336作业一-构建优化器
人工智能
ZhengEnCi11 小时前
09bac-斯坦福CS336作业一-实现训练损失计算
人工智能
冬奇Lab11 小时前
Skill 系列(01):Skill 评测体系——如何量化一个 AI Skill 的质量
人工智能
IT_陈寒14 小时前
Redis内存爆了,原来我漏掉了这个致命配置
前端·人工智能·后端
用户35218024547516 小时前
🎆从 Prompt 到 Skill:让 Spring AI Agent 学会"装新技能"
人工智能·spring boot·ai编程
米小虾16 小时前
手把手教你搭建第一个生产级AI Agent:从选型到实战的完整指南
人工智能·agent
任沫16 小时前
Agent之Function Call
javascript·人工智能·go
米小虾16 小时前
2026年AI Agent全面爆发:从开源生态到企业级应用的进化之路
人工智能·agent
用户69190268133917 小时前
Vibe Coding 开发项目的基本范式
人工智能·设计模式·代码规范