从零学习大模型(八)-----P-Tuning(上)

什么是P-Tuning?

P-Tuning 是一种参数高效的微调方法,旨在通过引入可学习的提示(Prompt)来改进预训练语言模型在特定任务中的表现。与传统的全参数微调相比,P-Tuning 只需要微调少量的额外参数,使得模型在执行任务时更加高效,并且能够适应多种下游任务。

在 P-Tuning 中,提示(Prompt)是一些附加到输入上的可学习向量,这些向量作为模型的输入,帮助调整模型的行为。通过对这些可学习提示进行训练,模型可以在特定任务上获得较好的表现,而不需要对原有的大量模型参数进行更新。

P-Tuning 的发展背景和动机

  • 语言模型在 NLP 任务中表现出了极大的潜力,但全参数微调方法需要对所有模型参数进行更新,对于大型模型(如 GPT-3)来说,参数数量非常庞大,导致计算开销巨大,训练成本高昂。因此,出现了各种参数高效微调方法,以减少对模型全参数的依赖,实现高效微调。
  • P-Tuning 的动机在于通过引入可学习的提示(Prompt)来优化模型性能,而不需要对整个模型进行大规模更新。P-Tuning 借鉴了 Prompt 设计的思想,将手工设计的静态 Prompt 转换为可学习的动态提示,使得模型可以更好地适应具体任务。与传统的手工设计 Prompt 相比,可学习的 Prompt 更加灵活,可以通过训练自动优化,从而提升模型在特定任务中的表现。
  • P-Tuning 的发展是对参数高效微调需求的回应,旨在降低微调大型模型的成本,同时保留模型的强大能力。在此背景下,P-Tuning 被提出,作为一种在维持高性能的同时,减少计算和存储成本的有效方法。

与其他参数高效微调方法的比较(如 LoRA、Prefix Tuning)

  • LoRA:LoRA 通过引入低秩适配矩阵的方式来微调模型的部分参数,特别适用于大型 Transformer 模型的微调。与 P-Tuning 相比,LoRA 主要关注减少更新参数的数量,保留大部分预训练参数不变,而 P-Tuning 则是通过添加可学习的提示来调整输入。
  • Prefix Tuning:Prefix Tuning 通过在模型输入前添加一组可学习的前缀来引导模型生成更符合任务需求的输出。与 P-Tuning 类似,Prefix Tuning 也是通过添加额外的可学习参数来进行微调,但它主要在输入的前缀部分进行优化,而 P-Tuning 则在输入的任意位置插入可学习的提示

P-Tuning 的基本概念与原理

Prompt 的概念及其在模型中的作用

Prompt 是在输入文本前或后添加的文本片段,用于引导预训练语言模型生成符合任务需求的输出。在自然语言处理任务中,Prompt 可以是一个问题、指令或描述性内容,使模型能够理解用户的意图,并产生相应的输出。传统上,Prompt 是手工设计的,但这种方式对不同任务的适应性较差,难以捕捉任务的细微差别。P-Tuning 引入了可学习的 Prompt,使得模型可以根据数据自动优化这些提示,从而更好地适应特定任务需求,显著提高了模型在各种任务上的表现。

P-Tuning 如何在语言模型中引入可学习的提示(Prompt)

  • 在 P-Tuning 中,可学习的提示(Prompt)是通过在输入嵌入中插入一系列可学习向量来实现的。这些向量作为额外的输入,结合原始的输入嵌入传递给语言模型,以便模型能够根据这些提示更好地理解和生成符合特定任务需求的输出。
  • 具体来说,P-Tuning 将一组可学习的嵌入插入到输入序列中,这些嵌入向量在训练过程中会随着任务的目标函数一起被优化,从而学会如何有效地引导模型生成更准确的结果。与传统的手工设计 Prompt 不同,这些嵌入是可训练的,可以自动适应数据的特性,使得模型能够更灵活地处理不同任务。
  • 例如,在一个情感分析任务中,P-Tuning 通过在输入文本前后加入可学习的提示,使得模型能够更好地捕捉文本中的情感特征。这些提示在训练过程中被优化,以最大化情感分类的准确率,从而提升模型的整体性能。

与传统 Prompt 的区别:手工设计 vs. 可学习 Prompt

  • 手工设计的 Prompt:手工设计的 Prompt 是由人类专家根据任务需求编写的静态文本片段,目的是为模型提供上下文信息,以便引导模型生成合适的输出。手工设计的 Prompt 通常需要深入理解任务,并且对于不同的任务需要手动调整,这种方式非常依赖人工经验,适应性较差,难以捕捉数据的复杂特性。
  • 可学习的 Prompt(P-Tuning):可学习的 Prompt 是由一组可训练的嵌入向量组成,这些向量会在模型的训练过程中自动优化,从而使模型更好地适应特定任务。这种动态的方式允许 Prompt 直接从数据中学习任务相关的信息,避免了手工设计的繁琐过程,具有更高的灵活性和适应性。通过将可学习的 Prompt 嵌入到输入序列中,模型可以更灵活地理解和生成符合任务需求的输出。
  • 主要区别:手工设计的 Prompt 依赖于人类的直觉和知识,而可学习的 Prompt 则依赖于数据驱动的优化过程。在 P-Tuning 中,可学习的 Prompt 能够根据不同任务的需求自动调整,使得模型可以更好地适应不同的任务和数据分布。与手工设计相比,可学习的 Prompt 具有更强的泛化能力,尤其在复杂或新颖的任务中表现出色。

P-Tuning 的结构与方法

插入可学习向量的方法

  • 在 P-Tuning 中,插入可学习向量的方式是将一组随机初始化的嵌入向量作为模型输入的一部分。具体来说,这些可学习向量被插入到输入序列的前面、后面或中间,形成一个扩展后的输入序列。
  • 这些向量和原始的输入嵌入一起传递给模型的嵌入层,并在训练过程中通过反向传播进行更新,从而使得模型能够根据任务需求调整这些向量的值。
  • 例如,对于一个文本分类任务,可以在输入文本的前面插入若干个可学习的嵌入向量,这些向量在训练过程中不断被优化,使得模型在分类时能够更好地捕捉输入文本的特征。
下面用代码展示了一个P-Tuning的过程
python 复制代码
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

# 加载预训练模型和分词器
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# 定义可学习的提示向量
class PTuningPrompt(nn.Module):
    def __init__(self, prompt_length, hidden_size):
        super(PTuningPrompt, self).__init__()
        self.prompt_embeddings = nn.Parameter(torch.randn(prompt_length, hidden_size))

    def forward(self, batch_size):
        # 返回 (batch_size, prompt_length, hidden_size) 的提示嵌入
        return self.prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

prompt_length = 10
hidden_size = model.config.hidden_size
p_tuning_prompt = PTuningPrompt(prompt_length, hidden_size)

# 模拟输入文本
input_text = ["This is a positive review.", "This is a negative review."]
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

# 获取提示嵌入
batch_size = inputs.input_ids.shape[0]
prompt_embeddings = p_tuning_prompt(batch_size)

# 将提示嵌入与原始输入嵌入结合
input_embeddings = model.embeddings(input_ids=inputs.input_ids)
extended_embeddings = torch.cat((prompt_embeddings, input_embeddings), dim=1)

# 前向传播
outputs = model(inputs_embeds=extended_embeddings)
print(outputs.last_hidden_state.shape)

训练和更新 Prompt 的策略

  • 优化目标:P-Tuning 中的可学习提示向量在训练过程中是通过反向传播来优化的,类似于传统模型参数的训练过程。优化目标通常是任务的损失函数(例如交叉熵损失),以确保提示向量能够帮助模型最大化任务的表现。
  • 提示向量的更新:在每次前向传播过程中,可学习的提示向量与输入文本一起传递到模型中,计算输出和任务目标之间的损失。然后通过反向传播来更新这些提示向量,使它们逐渐学习到对特定任务有用的信息。
  • 梯度计算 :与模型的其他参数一样,提示向量的梯度也通过反向传播计算得到。通过定义提示向量为 nn.Parameter,这些提示向量会在训练过程中自动被加入到梯度计算中,并使用优化器(例如 AdamW)进行更新。
  • 与其他参数的差异:在 P-Tuning 中,模型的核心参数保持冻结状态,只更新提示向量。这种做法使得在大型模型上进行微调时,可以大大减少需要更新的参数数量,降低计算成本和存储开销。
  • 示例代码:以下代码展示了如何训练和更新可学习提示向量。
python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer

# 加载预训练模型和分词器
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# 定义可学习的提示向量
class PTuningPrompt(nn.Module):
    def __init__(self, prompt_length, hidden_size):
        super(PTuningPrompt, self).__init__()
        self.prompt_embeddings = nn.Parameter(torch.randn(prompt_length, hidden_size))

    def forward(self, batch_size):
        # 返回 (batch_size, prompt_length, hidden_size) 的提示嵌入
        return self.prompt_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

prompt_length = 10
hidden_size = model.config.hidden_size
p_tuning_prompt = PTuningPrompt(prompt_length, hidden_size)

# 模拟输入文本
input_text = ["This is a positive review.", "This is a negative review."]
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

# 获取提示嵌入
batch_size = inputs.input_ids.shape[0]
prompt_embeddings = p_tuning_prompt(batch_size)

# 将提示嵌入与原始输入嵌入结合
input_embeddings = model.embeddings(input_ids=inputs.input_ids)
extended_embeddings = torch.cat((prompt_embeddings, input_embeddings), dim=1)

# 前向传播
outputs = model(inputs_embeds=extended_embeddings)

# 定义损失函数和优化器
loss_fn = nn.MSELoss()  # 示例损失函数
optimizer = optim.AdamW([p_tuning_prompt.prompt_embeddings], lr=1e-4)

# 模拟目标值
target = torch.randn_like(outputs.last_hidden_state)

# 计算损失
loss = loss_fn(outputs.last_hidden_state, target)

# 反向传播和更新提示向量
optimizer.zero_grad()
loss.backward()
optimizer.step()
  • 代码解读
    • 在这段代码中,提示向量 prompt_embeddings 被定义为可训练参数,并在训练过程中通过反向传播进行更新。
    • 损失函数用于计算模型输出和目标之间的差异,通过最小化这个损失,提示向量会逐渐学习如何更好地引导模型生成符合任务需求的输出。
    • 通过 optimizer 来更新提示向量,从而完成 P-Tuning 的训练过程。
相关推荐
PieroPc21 分钟前
Python 写的 智慧记 进销存 辅助 程序 导入导出 excel 可打印
开发语言·python·excel
迅易科技1 小时前
借助腾讯云质检平台的新范式,做工业制造企业质检的“AI慧眼”
人工智能·视觉检测·制造
古希腊掌管学习的神2 小时前
[机器学习]XGBoost(3)——确定树的结构
人工智能·机器学习
ZHOU_WUYI3 小时前
4.metagpt中的软件公司智能体 (ProjectManager 角色)
人工智能·metagpt
靴子学长3 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
梧桐树04294 小时前
python常用内建模块:collections
python
AI_NEW_COME4 小时前
知识库管理系统可扩展性深度测评
人工智能
Dream_Snowar4 小时前
速通Python 第三节
开发语言·python
海棠AI实验室5 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
hunteritself5 小时前
AI Weekly『12月16-22日』:OpenAI公布o3,谷歌发布首个推理模型,GitHub Copilot免费版上线!
人工智能·gpt·chatgpt·github·openai·copilot