一、BERT
1. 核心概述
-
全称:Bidirectional Encoder Representations from Transformers
-
提出者:Google 团队(2018年发表)
-
核心贡献 :引入了深度双向架构 ,确立了 "预训练 (Pre-training) + 微调 (Fine-tuning)" 的新范式,在发布时刷新了11项 NLP 任务的 SOTA(State of the Art)记录。
-
核心思想:通过大规模无标注语料进行预训练,学习到包含丰富上下文信息的词向量表示(Contextualized Embeddings),再用少量的标注数据针对特定下游任务进行微调。
2. 模型架构
-
骨干网络 :Transformer 的 Encoder(编码器) 部分。
-
双向性 :与早期单向语言模型(如 GPT-1 从左到右生成)不同,BERT 基于 Self-Attention 机制,在处理每一个词时,都能同时看到其左侧和右侧的所有上下文信息。
-
经典版本:
-
BERT-Base:12层 (L=12),隐藏层维度768 (H=768),注意力头数12 (A=12),总参数量约 1.1亿。
-
BERT-Large:24层 (L=24),隐藏层维度1024 (H=1024),注意力头数16 (A=16),总参数量约 3.4亿。
-
3. 输入表示 (Input Representation)
BERT的输入是由三种 Embedding 直接相加构成的:
-
Token Embeddings(词向量):使用 WordPiece 算法将单词切分为子词(Subword,如 playing -> play + ##ing),以缓解 OOV(未登录词)问题。
-
Segment Embeddings(句子向量):用于区分输入的两个句子(例如句子A标记为0,句子B标记为1)。
-
Position Embeddings(位置向量) :与 Transformer 原生的正弦函数绝对位置编码不同,BERT 使用的是可学习的绝对位置编码。最大支持长度通常为 512。
关键的特殊标记 (Special Tokens):
-
CLS\]:永远放在序列的**第一个位置** 。其最后一层的输出向量被用来代表整个句子的语义,常用于**句子级别的分类任务**。
-
MASK\]:在预训练的 MLM 任务中,用于遮蔽真实的单词。
BERT 能够在无监督数据上大放异彩,归功于它设计的两个预训练任务:
任务一:Masked Language Model (MLM, 掩码语言模型)
-
目的:迫使模型通过上下文来预测被遮蔽的词,从而学习双向语境。
-
做法 :随机选择输入序列中 15% 的 Token 作为目标进行预测。
-
细节(为了缓解预训练和微调时的输入不一致问题,因为微调时不会出现 [MASK]) :
在这被选中的15%的 Token 中:
-
80% 的概率:替换为 [MASK](例如:my dog is hairy -> my dog is[MASK])。
-
10% 的概率:替换为随机的一个词(例如:my dog is hairy -> my dog is apple)。
-
10% 的概率:保持原词不变(例如:my dog is hairy -> my dog is hairy),但模型依然需要去预测它,以验证其正确性。
-
任务二:Next Sentence Prediction (NSP, 下一句预测)
-
目的:让模型学习句子之间的逻辑关系,对 QA(问答)和 NLI(自然语言推理)等任务有很大帮助。
-
做法:二分类任务。输入句子A和句子B,判断B是否真的是A的下一句。
-
50% 正样本:B 确实在真实文本中紧跟在 A 后面(IsNext)。
-
50% 负样本:B 是从语料库中随机抽取的一个句子(NotNext)。
-
5. 微调机制 (Fine-tuning)
BERT 在处理下游任务时非常灵活,只需要对输入和输出做少量修改,并微调所有参数即可:
-
单句分类(如情感分析):直接取句首 [CLS] 标记对应的隐藏层向量,接入一个全连接层进行分类。
-
句子对分类(如语义相似度/自然语言推理):句子 A 和 B 通过 [SEP] 拼接,同样取 [CLS] 的输出进行分类。
-
序列标注(如命名实体识别 NER):取每一个 Token 对应的最终隐藏层向量,接入分类器(或 CRF 层)预测每个词的标签。
-
阅读理解问答(如 SQuAD):输入问题和段落,预测答案在段落中的"起始位置"和"结束位置"。
二、代码
import torch
from torch import nn
import test_68transformer
def get_tokens_and_segments(tokens_a,tokens_b=None):
tokens=['<cls>']+tokens_a+['<sep>']
segments=[0]*(len(tokens_a)+2)
if tokens_b is not None:
tokens+=tokens_b+['<sep>']
segments+=[1]*(len(tokens_b)+1)
return tokens,segments
class BERTEncoder(nn.Module):
def __init__(self,vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,
max_len=1000,key_size=768,query_size=768,value_size=768,**kwargs):
super(BERTEncoder, self).__init__(**kwargs)
self.token_embedding=nn.Embedding(vocab_size,num_hiddens)#有多少需要编号的词*每个词多少维度
self.segment_embedding=nn.Embedding(2,num_hiddens)#两类0,1*每个label几个维度
self.blks=nn.Sequential()
for i in range(num_layers):
self.blks.add_module(f"{i}",test_68transformer.EncoderBlock(key_size,query_size,value_size,num_hiddens,norm_shape,
ffn_num_input,ffn_num_hiddens,num_heads,dropout,True))
self.position_embedding=nn.Parameter(torch.randn(1,max_len,num_hiddens))#1方便广播,与batch_size对应,max_len最长长度;num_hiddens
def forward(self,tokens,segments,valid_len):
X=self.token_embedding(tokens)+self.segment_embedding(segments)
X=X+self.position_embedding.data[:,:X.shape[1],:]
for blk in self.blks:
X=blk(X,valid_len)
return X
# print(encoder_X.shape)
class MaskLM(nn.Module):
#从X中挑出mask,然后对其做预测,即BERT的MLM(Masked Language Model,掩蔽语言模型
def __init__(self,vocab_size,num_hiddens,num_inputs=768,**kwargs):
super(MaskLM, self).__init__(**kwargs)
self.mlp=nn.Sequential(
nn.Linear(num_inputs,num_hiddens),
nn.ReLU(),
nn.LayerNorm(num_hiddens),
nn.Linear(num_hiddens,vocab_size))
def forward(self,X,pred_position):
#X的形状(batch_size,seq_len,num_hiddens)
#pred_position的形状(batch_size,num_pred_position)
#输出的形状(batch_size,num_pred_position,vocab_size)
num_pred_position=pred_position.shape[1]
pred_position=pred_position.reshape(-1)
batch_size=X.shape[0]
batch_idx=torch.arange(0,batch_size)
batch_idx=torch.repeat_interleave(batch_idx,num_pred_position)
masked_X=X[batch_idx,pred_position]
masked_X=masked_X.reshape((batch_size,num_pred_position,-1))
mlm_Y_hat=self.mlp(masked_X)
return mlm_Y_hat
vocab_size,num_hiddens,ffn_num_hiddens,num_heads=10000,768,1024,4
norm_shape,ffn_num_input,num_layers,dropout=[768],768,2,0.2
encoder=BERTEncoder(vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout)
tokens=torch.randint(0,vocab_size,(2,8))
segments=torch.tensor([[0,0,0,0,1,1,1,1],[0,0,0,0,1,1,1,1]])
encoder_X=encoder(tokens,segments,None)
mlm=MaskLM(vocab_size,num_hiddens)
mlm_positions=torch.tensor([[1,5,2],[6,1,5]])
mlm_Y_hat=mlm(encoder_X,mlm_positions)
# print(mlm_Y_hat.shape)
mlm_Y=torch.tensor([[7,8,9],[10,20,30]])
loss=nn.CrossEntropyLoss(reduction='none')
mlm_l=loss(mlm_Y_hat.reshape((-1,vocab_size)),mlm_Y.reshape(-1))
# print(mlm_l.shape)
#预测下一个句子
class NextSentencePred(nn.Module):
def __init__(self,num_inputs,**kwargs):
super(NextSentencePred, self).__init__(**kwargs)
self.output=nn.Linear(num_inputs,2)
def forward(self, X):
return self.output(X)
encoder_X=torch.flatten(encoder_X,start_dim=1)#X变成(batch_size,seq_len*num_hiddens)
nsp=NextSentencePred(encoder_X.shape[-1])#只对最后一层数值全连接操作
nsp_Y_hat=nsp(encoder_X)#输出维度变为(batch_size,2)
# print(nsp_Y_hat.shape)
nsp_y=torch.tensor([0,1])
nsp_l=loss(nsp_Y_hat,nsp_y)
# print(nsp_l.shape)
class BERTModel(nn.Module):
def __init__(self,vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,
num_heads,num_layers,dropout,max_len=1000,key_size=768,query_size=768,value_size=768,hid_in_feature=768,
mlm_in_feature=768,nsp_in_feature=768,**kwargs):
super(BERTModel, self).__init__()
self.encoder=BERTEncoder(vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout,
max_len=max_len,key_size=key_size,query_size=query_size,value_size=value_size)
self.hidden=nn.Sequential(nn.Linear(hid_in_feature,num_hiddens),nn.Tanh())
self.mlm=MaskLM(vocab_size,num_hiddens,mlm_in_feature)
self.nsp=NextSentencePred(nsp_in_feature)
def forward(self,tokens,segments,valid_len=None,pred_position=None):
encoded_X=self.encoder(tokens,segments,valid_len)
if pred_position is not None:
mlm_Y_hat=self.mlm(encoded_X,pred_position)
else:
mlm_Y_hat=None
nsp_Y_hat=self.nsp(self.hidden(encoded_X[:,0,:]))
return encoded_X,mlm_Y_hat,nsp_Y_hat