从零学习大模型(九)-----P-Tuning(下)

代码展示P-Tuning的全过程

python 复制代码
import torch
from torch import nn
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

# 1. 数据准备
dataset = load_dataset("imdb")

# 2. 构建提示
def add_prompt(examples):
    examples['text'] = ["这段文本的情感是:'{}'".format(text) for text in examples['text']]
    return examples

dataset = dataset.map(add_prompt)

# 3. 模型选择
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# 4. 添加可训练的嵌入向量
class PromptEmbedding(nn.Module):
    def __init__(self, prompt_length, embedding_dim):
        super(PromptEmbedding, self).__init__()
        self.prompt_embedding = nn.Parameter(torch.randn(prompt_length, embedding_dim))

    def forward(self, x):
        prompt = self.prompt_embedding.unsqueeze(0).repeat(x.size(0), 1, 1)  # 扩展到batch大小
        return torch.cat((prompt, x), dim=1)

# 定义新模型
class P_Tuning_BERT(nn.Module):
    def __init__(self, base_model, prompt_length):
        super(P_Tuning_BERT, self).__init__()
        self.base_model = base_model
        self.prompt_embedding = PromptEmbedding(prompt_length, base_model.bert.config.hidden_size)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # 获取原始的输入嵌入
        embeddings = self.base_model.bert.embeddings(input_ids)
        # 添加prompt嵌入
        embeddings = self.prompt_embedding(embeddings)
        outputs = self.base_model.bert(inputs_embeds=embeddings, attention_mask=attention_mask)
        logits = self.base_model.classifier(outputs[1])  # 只取池化输出
        return (logits,)

# 设置P-Tuning模型
prompt_length = 5  # Prompt的长度
p_tuning_model = P_Tuning_BERT(model, prompt_length)

# 冻结原模型参数
for param in p_tuning_model.base_model.parameters():
    param.requires_grad = False

# 5. 数据预处理
def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 6. 微调过程
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=p_tuning_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
)

# 7. 训练模型
trainer.train()

# 8. 测试模型
trainer.evaluate()

# 9. 应用模型
def predict(text):
    p_tuning_model.eval()
    inputs = tokenizer("这段文本的情感是:'{}'".format(text), return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = p_tuning_model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    logits = outputs[0]
    predicted_class = torch.argmax(logits, dim=-1)
    return "积极" if predicted_class.item() == 1 else "消极"

# 测试应用
print(predict("这家餐厅的服务很好。"))

P-Tuning的实验结果

文本分类任务

  • P-Tuning:在小数据集上,F1-score为0.85,训练时间为1小时。
  • 全参数微调:在相同数据集上,F1-score为0.88,训练时间为3小时,但在验证集上过拟合(F1-score为0.80)。

对话生成任务

  • P-Tuning:生成的回复自然性评分为4.2/5,训练时间为2小时。
  • 全参数微调:生成的回复自然性评分为4.5/5,但训练时间为5小时。

P-Tuning的优点

1. 计算效率高

  • 参数更新少:P-Tuning仅更新与提示相关的嵌入向量,减少了训练过程中需要优化的参数数量。这意味着在同样的计算资源下,可以更快速地进行实验和模型调整。

2. 减少过拟合风险

  • 冻结预训练模型的参数:通过冻结大部分模型参数,P-Tuning降低了在小数据集上过拟合的风险。对于数据量有限的任务,P-Tuning能够更好地泛化。

3. 灵活性和适应性强

  • 任务适应性:可以通过简单地调整提示内容来适应不同的任务,无需修改整个模型架构。这使得在多任务场景中,P-Tuning能够快速切换和调整。
  • Prompt设计自由:研究者可以根据具体任务设计不同的提示,以探索对模型性能的影响。这种灵活性允许在多个任务之间共享同一模型,而只需修改提示。

4. 易于实现和部署

  • 实现简单:相较于全参数微调,P-Tuning的实现更加简便,尤其是在不需要重新训练整个模型的情况下。只需在输入中添加提示即可。
  • 资源需求低:由于只更新部分参数,P-Tuning对计算资源的需求较低,可以在较小的硬件上进行训练和部署。

5. 在小数据集上的表现良好

  • 数据效率高:P-Tuning特别适用于小数据集场景,在这些场景中,训练整个模型可能导致性能下降,而P-Tuning可以利用预训练的知识,有效提升模型的性能。

6. 提升模型的可解释性

  • 可解释性增强:由于P-Tuning强调了提示的作用,研究者可以更清晰地理解模型如何通过特定提示来做出不同的决策。这对于分析模型的行为和结果非常有帮助。

7. 迁移学习效果好

  • 知识迁移:P-Tuning能够有效地利用预训练模型中存储的知识,通过适当的提示,将这种知识迁移到新任务中。这使得在许多下游任务中,P-Tuning能够实现与全参数微调相当甚至更好的性能。

P-Tuning的局限性

1. 提示设计的依赖性

  • 提示的有效性:P-Tuning的性能高度依赖于提示的设计和选择。不同的提示可能会导致模型产生不同的预测结果。如果提示设计不当,可能会影响模型的理解和预测能力。
  • 提示选择的挑战:设计有效的提示需要领域知识和经验,这对于非专业人士来说可能是一个挑战。

2. 学习到的提示嵌入的复杂性

  • 提示嵌入的可解释性:虽然P-Tuning提供了一定的可解释性,但学习到的提示嵌入的具体意义和如何影响模型决策可能仍然不够清晰。研究者可能难以解读这些嵌入的具体作用。
  • 相似性问题:不同任务或数据集可能会导致提示嵌入相似性较高,导致模型在迁移到新任务时表现不佳。

3. 数据集和任务的限制

  • 适用性问题:P-Tuning在小数据集上表现良好,但在大规模和复杂任务中,可能无法完全发挥预训练模型的潜力。在某些情况下,全参数微调可能仍然是更优的选择。
  • 数据分布差异:如果训练和测试数据的分布差异较大,P-Tuning的效果可能受到影响,特别是如果提示未能充分捕捉任务的关键特征。

4. 对训练资源的需求

  • 额外的训练时间:尽管P-Tuning的训练参数较少,但学习提示嵌入仍然需要一定的训练时间和计算资源。在资源有限的情况下,可能仍需权衡使用全参数微调与P-Tuning的选择。

5. 任务特定性

  • 领域适应性:某些领域的特定任务可能不适合使用P-Tuning,尤其是在需要高度专业化的知识和上下文理解的情况下。全参数微调可能更好地适应这些特定的领域。

6. 模型性能的极限

  • 性能瓶颈:由于只更新部分参数,P-Tuning在某些情况下可能无法突破预训练模型的性能极限。在需要极高性能的任务中,全参数微调可能更能挖掘模型的潜力。

P-Tuning的未来发展方向

1. 大规模模型的适应性

  • 模型架构的调整:为适应更大规模的模型,P-Tuning可以通过调整提示嵌入的维度和数量来保持与模型的对齐。这意味着需要为每个新任务设计适当的嵌入结构。
  • 分层提示:对于大型模型,可以设计分层的提示结构,允许在不同层次上进行信息传递,从而使模型更有效地利用提示信息。

2. 多任务学习

  • 共享提示嵌入:在多任务设置中,可以设计共享的提示嵌入,以便在不同任务之间传递信息。这有助于提高模型的训练效率,并减少为每个任务单独训练提示的需求。
  • 动态提示调整:利用动态生成的提示来适应不同任务的需求。通过实时分析任务特征,生成适合特定任务的提示,从而增强模型的适应性。

3. 增强训练方法

  • 自适应学习率:为不同任务的提示嵌入设置不同的学习率,以便更好地适应每个任务的特性。这可以通过监控每个任务的性能来动态调整学习率。
  • 数据增强:结合数据增强技术,在训练过程中引入多样化的训练样本,从而提高模型在新任务和大规模数据集上的泛化能力。

4. 集成方法

  • 与其他技术结合:将P-Tuning与其他微调技术(如LoRA、Adapter等)结合使用,可以进一步提升模型的性能。这些技术可以帮助在不大幅增加模型参数的情况下,增强模型对新任务的适应性。
  • 知识蒸馏:通过知识蒸馏技术,将大型模型的知识迁移到较小的模型中,同时利用P-Tuning进行微调,可以在资源有限的情况下实现较好的性能。

5. 任务定制化

  • 针对性任务提示设计:针对特定任务或领域设计专门的提示嵌入,以确保它们能有效捕捉任务特征。这可能包括对领域特定的语言和上下文的理解。
  • 领域适应性:在特定领域(如医疗、法律等)中,通过细化提示以增强对领域术语和上下文的理解,提升模型在特定领域任务上的表现。
相关推荐
baiduopenmap12 分钟前
百度世界2024精选公开课:基于地图智能体的导航出行AI应用创新实践
前端·人工智能·百度地图
hopetomorrow13 分钟前
学习路之压力测试--jmeter安装教程
学习·jmeter·压力测试
hopetomorrow13 分钟前
学习路之PHP--使用GROUP BY 发生错误 SELECT list is not in GROUP BY clause .......... 解决
开发语言·学习·php
小任同学Alex15 分钟前
浦语提示词工程实践(LangGPT版,服务器上部署internlm2-chat-1_8b,踩坑很多才完成的详细教程,)
人工智能·自然语言处理·大模型
新加坡内哥谈技术21 分钟前
微软 Ignite 2024 大会
人工智能
/**书香门第*/42 分钟前
Cocos creator 3.8 支持的动画 7
学习·游戏·游戏引擎·游戏程序·cocos2d
江瀚视野1 小时前
Q3净利增长超预期,文心大模型调用量大增,百度未来如何分析?
人工智能
陪学1 小时前
百度遭初创企业指控抄袭,维权还是碰瓷?
人工智能·百度·面试·职场和发展·产品运营
QCN_1 小时前
湘潭大学人工智能考试复习1(软件工程)
人工智能
Landy_Jay1 小时前
深度学习:GPT-1的MindSpore实践
人工智能·gpt·深度学习