# LoRA + embedding full-merge script.
# (Chat-paste artifact lines "bash" / "复制代码" removed — they are not Python.)
import torch
import json
import os
import shutil
from safetensors.torch import load_file, save_file
def merge_full_optimized(
    base_model_path: str,
    adapter_path: str,
    output_merged_path: str,
    device: str = "cuda:0",
) -> None:
    """Merge a LoRA adapter (plus fully-saved modules) into base safetensors shards.

    Walks every ``*.safetensors`` shard of the base model and, shard by shard
    (to bound peak memory):

    * applies ``W' = W + scaling * (B @ A)`` for every key with matching
      ``lora_A`` / ``lora_B`` tensors in the adapter;
    * replaces outright any key for which the adapter carries a PEFT
      ``modules_to_save`` copy (typically ``embed_tokens`` / ``lm_head``).

    It also copies the base model's non-weight files into the output
    directory, patches ``vocab_size`` in ``config.json`` when the embedding
    was resized, and rewrites ``model.safetensors.index.json``.

    Args:
        base_model_path: Directory holding the base model shards and configs.
        adapter_path: PEFT adapter directory containing
            ``adapter_model.safetensors`` and ``adapter_config.json``.
        output_merged_path: Destination directory (created if missing).
        device: Device used to hold tensors during the merge.

    Raises:
        FileNotFoundError / KeyError: if the adapter files or the required
            ``r`` / ``lora_alpha`` config entries are missing.
    """
    print(f"🚀 [START] 开始全量合并流程 (LoRA + Embedding)...")
    # 1. Load adapter weights and parse the LoRA configuration.
    print(f"📥 加载 Adapter 权重...")
    adapter_weights = load_file(
        os.path.join(adapter_path, "adapter_model.safetensors"), device=device
    )
    with open(os.path.join(adapter_path, "adapter_config.json"), "r") as f:
        lora_config = json.load(f)
    r = lora_config["r"]
    alpha = lora_config["lora_alpha"]
    # rsLoRA (PEFT `use_rslora=True`) scales by alpha/sqrt(r); classic LoRA
    # uses alpha/r. Using the wrong one silently corrupts the merge.
    scaling = alpha / (r ** 0.5) if lora_config.get("use_rslora") else alpha / r
    print(f"✅ LoRA 配置读取成功: r={r}, alpha={alpha}, scaling={scaling}")

    # 2. Prepare the output directory: copy every non-weight file from the
    #    base model so tokenizer/config travel along with the merged shards.
    os.makedirs(output_merged_path, exist_ok=True)
    print(f"📋 从基础模型复制配置文件...")
    for file in os.listdir(base_model_path):
        src = os.path.join(base_model_path, file)
        if os.path.isfile(src) and not any(
            file.endswith(ex) for ex in (".safetensors", ".bin", ".pt")
        ):
            shutil.copy2(src, os.path.join(output_merged_path, file))
    # The adapter's own files (adapter_config.json, ...) are deliberately NOT
    # copied: they are not part of a merged checkpoint.

    # 3. Core merge loop, one shard at a time.
    safetensors_files = sorted(
        f for f in os.listdir(base_model_path) if f.endswith(".safetensors")
    )
    weight_map = {}
    new_num_tokens = None
    applied_adapter_keys = set()  # adapter tensors actually consumed
    for st_file in safetensors_files:
        print(f"📦 处理分片: {st_file}...")
        src_path = os.path.join(base_model_path, st_file)
        dst_path = os.path.join(output_merged_path, st_file)
        state_dict = load_file(src_path, device=device)
        prefix = "base_model.model."
        for key in list(state_dict.keys()):
            # --- A. LoRA delta merge: W' = W + scaling * (B @ A) ---
            lora_a_key = f"{prefix}{key.replace('.weight', '.lora_A.weight')}"
            lora_b_key = f"{prefix}{key.replace('.weight', '.lora_B.weight')}"
            if lora_a_key in adapter_weights and lora_b_key in adapter_weights:
                orig_dtype = state_dict[key].dtype
                W = state_dict[key].to(torch.bfloat16)
                A = adapter_weights[lora_a_key].to(torch.bfloat16)
                B = adapter_weights[lora_b_key].to(torch.bfloat16)
                # Cast back so the merged shard keeps the base model's dtype
                # (the original code left these tensors as bfloat16).
                state_dict[key] = (W + (B @ A) * scaling).to(orig_dtype)
                applied_adapter_keys.update((lora_a_key, lora_b_key))
            # --- B. Full replacement of modules_to_save (embed / lm_head) ---
            embed_save_key = f"{prefix}{key}.modules_to_save.default.weight"
            if embed_save_key in adapter_weights:
                state_dict[key] = adapter_weights[embed_save_key].to(
                    state_dict[key].dtype
                )
                applied_adapter_keys.add(embed_save_key)
                if "embed_tokens" in key:
                    # Row count of the (possibly resized) embedding matrix.
                    new_num_tokens = state_dict[key].shape[0]
                print(f"💎 已替换模块: {key}")
        # Persist the shard; safetensors requires contiguous CPU tensors.
        cpu_state_dict = {k: v.contiguous().cpu() for k, v in state_dict.items()}
        save_file(cpu_state_dict, dst_path)
        for k in state_dict.keys():
            weight_map[k] = st_file
        del state_dict, cpu_state_dict
        torch.cuda.empty_cache()

    # Adapter tensors that never matched any base key usually mean the
    # adapter was trained against a different base model — surface it.
    unused = set(adapter_weights) - applied_adapter_keys
    if unused:
        print(
            f"⚠️ 警告: {len(unused)} adapter tensors were not merged, "
            f"e.g. {sorted(unused)[:5]}"
        )

    # 4. Patch vocab_size in config.json if the embedding was resized.
    if new_num_tokens:
        cfg_path = os.path.join(output_merged_path, "config.json")
        if os.path.exists(cfg_path):
            with open(cfg_path, "r+") as f:
                config = json.load(f)
                config["vocab_size"] = new_num_tokens
                f.seek(0)
                json.dump(config, f, indent=2)
                f.truncate()
            print(f"📝 已更新 config.json: vocab_size = {new_num_tokens}")
        else:
            print(f"⚠️ 警告: 未找到 config.json,无法更新 vocab_size")

    # 5. Rebuild the sharded-weights index (sizes taken from OUTPUT files,
    #    which matters when the embedding grew).
    total_size = sum(
        os.path.getsize(os.path.join(output_merged_path, f))
        for f in safetensors_files
    )
    index_path = os.path.join(output_merged_path, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(
            {"metadata": {"total_size": total_size}, "weight_map": weight_map},
            f,
            indent=2,
        )
    print(f"✨ [DONE] 全量合并成功!输出目录: {output_merged_path}")
    print(f"📁 请确认以下文件存在:")
    for f in ["config.json", "tokenizer.json", "tokenizer_config.json", "model.safetensors.index.json"]:
        p = os.path.join(output_merged_path, f)
        status = "✅" if os.path.exists(p) else "❌"
        print(f" {status} {f}")
if __name__ == "__main__":
    # Run configuration for a one-off merge. ``**CONFIG`` is dictionary
    # unpacking: every key/value pair becomes a keyword argument of
    # merge_full_optimized().
    CONFIG = dict(
        base_model_path="/root/private_data/models/deepseek-ai/6epoch-merged",
        adapter_path="/root/private_data/output/sharedata/checkpoint-10500/",
        output_merged_path="/root/private_data/models/deepseek-ai/6epoch-merged-sft",
        device="cuda:0",
    )
    merge_full_optimized(**CONFIG)