llama factory 扩充词表训练

文章目录

方式一

python 复制代码
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


model_name = "/root/autodl-tmp/LLaMA-Factory/ckpts/Qwen3-0.6B"
new_tokens = ["<|C-L|>", "<|S-L|>", "<|C-S|>", "<|S-S|>"]

output_dir = model_name + "_custom_tokens"

print("[DEBUG] output_dir: ", output_dir)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)

print("[DEBUG] 已加载原始模型和分词器")
print(f"[DEBUG] 原始词表大小: {len(tokenizer)}")


# 检查当前是否已存在
exist = [t for t in new_tokens if tokenizer.convert_tokens_to_ids(t) != tokenizer.unk_token_id]
print("[DEBUG] new_tokens already exist:", exist)

exist = []
for t in new_tokens:
    for st in list(t):
        if tokenizer.convert_tokens_to_ids(st) != tokenizer.unk_token_id:
            exist.append(st)

print("[DEBUG] tokens already exist:", set(exist))


# 把 new tokens 加入 tokenizer
added = tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})    # 把它们当作 additional_special_tokens(语义上更"特殊")
# added = tokenizer.add_tokens(new_tokens)       # 或者直接 add_tokens(普通 token,但也会被视为单 token)
print(f"[DEBUG] 成功添加 {added} 个新token")
print(f"[DEBUG] 扩展后词表大小: {len(tokenizer)}")


# 初始化 embedding
model.resize_token_embeddings(len(tokenizer))
emb = model.get_input_embeddings().weight.data
vocab_size, dim = emb.shape
print("[DEBUG] ", vocab_size, dim)

for new_id, new_token in zip(list(range(vocab_size - added, vocab_size)), new_tokens):

    mean_tensor = []
    for mean_id in list(new_token):
        mean_emb = emb[tokenizer.convert_tokens_to_ids(mean_id), :]
        mean_tensor.append(mean_emb)

    emb[new_id] = torch.stack(mean_tensor, 0).mean(0)

print("[DEBUG] ", emb.shape)

exist = [t for t in new_tokens if tokenizer.convert_tokens_to_ids(t) != tokenizer.unk_token_id]
print("[DEBUG] new_tokens already exist:", exist)

for t in new_tokens:
    toks = tokenizer.tokenize(t)
    ids = tokenizer.encode(t, add_special_tokens=False)
    print("[DEBUG] ", t, "-> tokens:", toks, "ids:", ids)
    assert len(ids) == 1, f"{t} 被拆成 {len(ids)} 个 token,需检查 tokenizer 类型"


tokenizer.save_pretrained(output_dir, push_to_hub=False)
model.save_pretrained(output_dir, push_to_hub=False)

方式二

在 train.yaml 中添加

yaml 复制代码
# # additional_target: embed_tokens,norm
# # additional_target: embed_tokens,lm_head,norm
new_special_tokens_config: /root/autodl-tmp/LLaMA-Factory/yamls/control_tokens.yaml
init_special_tokens: noise_init   # noise_init, desc_init, desc_init_w_noise
add_special_tokens: <|C-L|>, <|S-L|>, <|C-S|>, <|S-S|>

ref:https://github.com/hiyouga/LLaMA-Factory/pull/9267

注意

合并权重需要有

yaml 复制代码
skip_special_tokens: false

并且加载模型的时候也需要

python 复制代码
 self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, use_fast=True, trust_remote_code=True)

 if self.infer_backend == "huggingface":
     self.model = AutoModelForCausalLM.from_pretrained(
         self.model_path,
         dtype="auto",
         device_map="auto",
         trust_remote_code=True,
     )

 elif self.infer_backend == "vllm":
     from vllm import LLM, SamplingParams

     self.sampling_params = SamplingParams(
         temperature=0.01,           
         # top_p=0.9,                
         # top_k=1,                  
         max_tokens=8,               
         stop=[],                    
         skip_special_tokens=False,   # 保留特殊token
     )

     self.model = LLM(
         model=self.model_path,
         tensor_parallel_size=len(self.devices),       
         gpu_memory_utilization=0.9,                   
         max_model_len=4096,
         trust_remote_code=True,
         enable_prefix_caching=True,                   
         max_num_seqs=64,                              
     )
python 复制代码
content = self.tokenizer.decode(output_ids[index:], skip_special_tokens=False).strip("\n")
相关推荐
薛定谔的猫19826 小时前
十七、用 GPT2 中文对联模型实现经典上联自动对下联:
人工智能·深度学习·gpt2·大模型 训练 调优
机 _ 长8 小时前
YOLO26 改进 | 基于特征蒸馏 | 知识蒸馏 (Response & Feature-based Distillation)
python·深度学习·机器学习
龙山云仓9 小时前
No140:AI世间故事-对话康德——先验哲学与AI理性:范畴、道德律与自主性
大数据·人工智能·深度学习·机器学习·全文检索·lucene
jay神10 小时前
基于YOLOv8的木材表面缺陷检测系统
人工智能·深度学习·yolo·计算机视觉·毕业设计
songyuc11 小时前
【Llava】load_pretrained_model() 说明
人工智能·深度学习
名为沙丁鱼的猫72911 小时前
【MCP 协议层(Protocol layer)详解】:深入分析MCP Python SDK中协议层的实现机制
人工智能·深度学习·神经网络·机器学习·自然语言处理·nlp
小Tomkk12 小时前
PyTorch +YOLO + Label Studio + 图像识别 深度学习项目实战 (二)
pytorch·深度学习·yolo
龙腾亚太13 小时前
航空零部件加工变形难题破解:数字孪生 + 深度学习的精度控制实战
人工智能·深度学习·数字孪生·ai工程师·ai证书·转型ai
Coding茶水间13 小时前
基于深度学习的输电电力设备检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
开发语言·人工智能·深度学习·yolo·目标检测·机器学习