大模型参数高效微调技术实战(二)-Prompt Tuning

随着,ChatGPT 迅速爆火,引发了大模型的时代变革。然而对于普通大众来说,进行大模型的预训练或者全量微调遥不可及。由此,催生了各种参数高效微调技术,让科研人员或者普通开发者有机会尝试微调大模型。

因此,该技术值得我们进行深入分析其背后的机理,之前分享了大模型参数高效微调技术原理综述 的文章。下面给大家分享大模型参数高效微调技术实战 系列文章,该系列共六篇文章,相关代码均放置在GitHub:llm-action

本文为大模型参数高效微调技术实战的第二篇。

Prompt Tuning 简述

Prompt Tuning(论文:The Power of Scale for Parameter-Efficient Prompt Tuning ),该方法可以看作是 Prefix Tuning 的简化版本,它给每个任务定义了自己的Prompt,然后拼接到数据上作为输入,但只在输入层加入prompt tokens ,并且不需要加入 MLP 进行调整来解决难训练的问题。

更加详细的介绍可参考之前的文章:大模型参数高效微调技术原理综述(二)-BitFit、Prefix Tuning、Prompt Tuning

Prompt Tuning 微调实战

为了不影响阅读体验,详细的代码放置在GitHub:llm-action 项目中 peft_prompt_tuning_clm.ipynb文件,这里仅列出关键步骤。

第一步,引进必要的库,如:Prompt Tuning 配置类 PromptTuningConfig

python 复制代码
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType

第二步,创建 Prompt Tuning 微调方法对应的配置。

python 复制代码
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=8,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path=model_name_or_path,
)

参数说明:

  • prompt_tuning_init:提示嵌入的初始化方法。PEFT支持文本(TEXT)和随机(RANDOM)初始化。在原理篇中提到过 Prompt token 的初始化方法和长度对于模型性能有影响。与随机初始化和使用样本词汇表初始化相比,Prompt Tuning 采用类标签初始化模型的效果更好。不过随着模型参数规模的提升,这种gap最终会消失。因此,如果需要使用类标签和样本词汇表初始化需指定为TEXT。
  • prompt_tuning_init_text:用于初始化提示嵌入的文本,在使用文本(TEXT)初始化方法时使用。
  • task_type:指定任务类型。如:条件生成任务(SEQ_2_SEQ_LM),因果语言建模(CAUSAL_LM)等。
  • num_virtual_tokens:指定虚拟Token数。在原理篇中,提到过提示虚拟 Token 的长度在20左右时的表现已经不错(超过20之后,提升Prompt token长度,对模型的性能提升不明显了);同样的,这个gap也会随着模型参数规模的提升而减小(即对于超大规模模型而言,即使提示虚拟 Token 长度很短,对性能也不会有太大的影响)。

第三步,通过调用 get_peft_model 方法包装基础的 Transformer 模型。

python 复制代码
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

通过 print_trainable_parameters 方法可以查看可训练参数的数量(仅为8,192)以及占比(仅为0.00146%)。

csharp 复制代码
trainable params: 8,192 || all params: 559,222,784 || trainable%: 0.0014648902430985358

Prompt Tuning 模型类结构如下所示:

scss 复制代码
PeftModelForCausalLM(
  (base_model): BloomForCausalLM(
    (transformer): BloomModel(
      (word_embeddings): Embedding(250880, 1024)
      (word_embeddings_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (h): ModuleList(
        ...
      )
      (ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=1024, out_features=250880, bias=False)
  )
  (prompt_encoder): ModuleDict(
    (default): PromptEmbedding(
      (embedding): Embedding(8, 1024)
    )
  )
  (word_embeddings): Embedding(250880, 1024)
)

从模型类结构可以看到,Prompt Tuning 只在输入层加入 prompt virtual tokens,其他地方均没有变化,具体可查看 PromptEmbedding 的源码。

python 复制代码
class PromptEmbedding(torch.nn.Module):
    def __init__(self, config, word_embeddings):
        super().__init__()

        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        # 初始化 embedding 层
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        
        # 如果使用文本进行初始化,执行如下逻辑,PromptTuningConfig 配置类需要传入初始化文本。
        if config.prompt_tuning_init == PromptTuningInit.TEXT:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or iterate until num_text_tokens matches total_virtual_tokens
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]

            word_embedding_weights = word_embeddings(torch.LongTensor(init_token_ids)).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            # 初始化embedding层的权重
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    def forward(self, indices):
        # Just get embeddings
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings

第四步,模型训练的其余部分均无需更改,当模型训练完成之后,保存高效微调部分的模型权重以供模型推理即可。

python 复制代码
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
model.save_pretrained(peft_model_id)

输出的模型权重文件如下所示:

css 复制代码
/data/nfs/llm/model/bloomz-560m_PROMPT_TUNING_CAUSAL_LM
├── [ 500]  adapter_config.json
├── [ 33K]  adapter_model.bin
└── [ 111]  README.md

0 directories, 3 files

注意:这里只会保存经过训练的增量 PEFT 权重。其中,adapter_config.json 为 Prompt Tuning 配置文件;adapter_model.bin 为 Prompt Tuning 权重文件。

第五步,加载微调后的权重文件进行推理。

python 复制代码
from peft import PeftModel, PeftConfig

peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# 加载PEFT配置
config = PeftConfig.from_pretrained(peft_model_id)

# 加载基础模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
# 加载PEFT模型
model = PeftModel.from_pretrained(model, peft_model_id)

# Tokenizer编码
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")

# 模型推理
outputs = model.generate(
        input_ids=inputs["input_ids"], 
        attention_mask=inputs["attention_mask"], 
        max_new_tokens=10, 
        eos_token_id=3
    )

# Tokenizer 解码
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

至此,我们完成了Prompt Tuning的训练及推理。

结语

本文对 Prompt Tuning 的基本原理进行了简述;同时,讲解了使用 Prompt Tuning 技术进行模型训练及推理。下文将对 P-Tuning 微调技术进行实战讲解。

如果觉得我的文章能够能够给您带来帮助,期待您的点赞收藏加关注~~

相关推荐
老顾聊技术1 小时前
老顾深度解析【字节跳动的AI项目DeerFlow】源码之工程结构(六)
llm·agent
亚马逊云开发者1 小时前
在 Amazon Bedrock 中结合 RAG 与 MCP 高效缓解提示词膨胀问题
llm
数据智能老司机1 小时前
MCP 实战——MCP 服务器的身份验证与部署
llm·agent·mcp
数据智能老司机1 小时前
MCP 实战——高级服务器架构
llm·agent·mcp
DevYK13 小时前
企业级 Agent 开发实战(一) LangGraph 快速入门
后端·llm·agent
Ethan.Yuan13 小时前
【深度长文】Anthropic发布Prompt Engineering全新指南
大模型·llm·prompt·提示工程
AI大模型18 小时前
基于 Docker 的 LLaMA-Factory 全流程部署指南
docker·llm·llama
AI大模型19 小时前
强推!大模型学习书籍合集推荐 | (含PDF地址)
程序员·llm·agent
字节跳动安全中心20 小时前
智能体防御 | 一文了解3种系统提示词加固方法
安全·llm
聚客AI21 小时前
🧩万亿级Token训练!解密大模型预训练算力黑洞与RLHF对齐革命
人工智能·llm·强化学习