多头注意力&位置编码:完型填空任务

主题

完型填空任务是自然语言处理中的经典问题,要求模型根据上下文预测被遮盖的单词(如 [MASK])。本文对比了两种实现方式:基于 PyTorch 的自定义 Transformer 模型基于 Hugging Face Transformers 的预训练 BERT 模型,并探讨了它们的优缺点及改进方向

1. PyTorch 自定义 Transformer 模型

实现概述
  • 使用 PyTorch 构建了一个简单的 Transformer 模型,包括嵌入层、位置编码、多头注意力机制和全连接层。
  • 数据通过手动构建的词汇表进行索引化,训练时对 [MASK] 位置的单词进行预测。
代码核心
python 复制代码
class TransformerModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_heads, max_len):

        super(TransformerModel, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, embed_dim))

        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x, mask):

        x = self.embedding(x) + self.positional_encoding[:, :x.size(1), :]

        x = x.transpose(0, 1)

        x, _ = self.attention(x, x, x, key_padding_mask=~mask.squeeze(1).squeeze(1))

        x = x.transpose(0, 1)

        logits = self.fc(x)

        return logits
graph TD A(开始) --> B(加载训练和测试数据) B --> C(构建词汇表) C --> D(将训练样本转换为索引) D --> E(定义Transformer模型) E --> F(初始化模型/损失函数和优化器) F --> G(开始训练循环) G --> H(每个epoch中) H --> I(将输入数据传入模型) I --> J(计算MASK的损失) J --> K(反向传播并优化) K --> L{是否达到指定 epoch?} L -- 否 --> H L -- 是 --> M(测试模型) M --> N(对测试句子进行预测) N --> O(输出预测结果) O --> P(结束)
优点
  1. 灵活性:可以根据任务需求自由调整模型结构(如层数、嵌入维度、注意力头数)。
  2. 学习过程透明:从零开始构建模型,便于理解 Transformer 的内部机制。
缺点
  1. 训练成本高:需要从头训练,数据量不足时效果有限。
  2. 性能有限:模型未经过大规模预训练,难以捕获复杂的语言模式。
  3. 数据依赖强:需要大量标注数据才能达到较好的效果。
改进方向
  • 增加训练数据量,或使用数据增强技术。
  • 引入预训练权重(如加载预训练的 Transformer 模型)。
  • 使用动态学习率调度器和梯度裁剪优化训练过程。

然而没感受到明显的进步和提升,反倒是感受十几年前被BP梯度无法收敛支配的恐惧,一方面是造的训练数据有漏洞,另外就是训练数据的体量有点儿小。

2. Transformers 预训练 BERT 模型

实现概述
  • 使用 Hugging Face 的 transformers 库加载预训练的 bert-base-chinese 模型。
  • 直接利用 BERT 的 Masked Language Model (MLM) 任务能力,对 [MASK] 位置进行预测。
代码核心
模型下载
bash 复制代码
 git clone https://www.modelscope.cn/tiansz/bert-base-chinese.git

实现

python 复制代码
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

model = BertForMaskedLM.from_pretrained("bert-base-chinese")

sentence = "小明喜欢吃[MASK],因为很甜。"

inputs = tokenizer(sentence, return_tensors="pt")

mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

with torch.no_grad():

    outputs = model(**inputs)

    logits = outputs.logits

mask_token_logits = logits[0, mask_token_index, :]

top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
优点
  1. 高性能:BERT 已在大规模语料上预训练,能够捕获复杂的上下文关系。
  2. 易用性:无需从头训练,只需微调即可适应特定任务。
  3. 丰富的工具支持:Hugging Face 提供了大量预训练模型和工具,方便快速开发。
缺点
  1. 依赖预训练模型:需要下载预训练权重,可能受限于网络环境。
  2. 灵活性较低:模型结构固定,难以根据任务需求进行大幅修改。
  3. 计算资源需求高:BERT 模型较大,推理和微调需要更多的计算资源。
相关推荐
新智元17 小时前
刚刚,英伟达祭出下一代 GPU!狂飙百万 token 巨兽,投 1 亿爆赚 50 亿
人工智能·openai
霍格沃兹_测试17 小时前
从零开始搭建Qwen智能体:新手也能轻松上手指南
人工智能
SmartJavaAI17 小时前
Java调用Whisper和Vosk语音识别(ASR)模型,实现高效实时语音识别(附源码)
java·人工智能·whisper·语音识别
山东小木17 小时前
JBoltAI需求分析大师:基于SpringBoot的大模型智能需求文档生成解决方案
人工智能·spring boot·后端·需求分析·jboltai·javaai·aigs
君名余曰正则17 小时前
【竞赛系列】机器学习实操项目08——全球城市计算AI挑战赛(数据可视化分析)
人工智能·机器学习·信息可视化
算家计算17 小时前
一张图+一段音频=电影级视频!阿里Wan2.2-S2V-14B本地部署教程:实现丝滑口型同步
人工智能·开源·aigc
Moonbit17 小时前
MoonBit 再次走进清华:张宏波受邀参加「思源计划」与「程序设计训练课」
前端·后端·编程语言
XINVRY-FPGA17 小时前
XCVP1902-2MSEVSVA6865 AMD 赛灵思 XilinxVersal Premium FPGA
人工智能·嵌入式硬件·神经网络·fpga开发·云计算·腾讯云·fpga
RestCloud17 小时前
一站式数据集成:iPaaS 如何让开发者和业务人员都满意?
前端·后端·架构
算家计算17 小时前
多年AI顽疾被攻克!OpenAI前CTO团队破解AI随机性难题,大模型可靠性迎来飞跃
人工智能·llm·资讯