知识蒸馏(KD)详解一:认识一下BERT 模型

知识蒸馏(KD)详解一:认识一下BERT 模型

1 什么是BERT

简单来说,BERT其实就是一个预训练模型,是一个文本特征提取器。它的核心思想:

第一,就是在一个很大的预料库上进行预训练,得到语言的表征的一个预训练模型;

第二,下游在接一个具体任务的一些层,如文本分类、情感识别、问答等,和这个预训练模型进行一起微调训练,进而实现具体的任务。

论文参考:https://arxiv.org/pdf/1810.04805

2 模型

2.1 结构

模型分为以下几个部分:

  1. 输入 [CLS] Masked SentenceA [SEP] Masked Sentence B [SEP] 对应的token ids
    如:"[CLS] my dog is cute [SEP] he likes play [MASK] [SEP]",其中[MASK] 是被隐藏掉的词,用于训练预测。
    转换为对应的token ids : [ 101, 2026, 3899, 2003, 10140, 102, 2002, 7777, 2377, 103,
    102]
    ,其中[CLS]是101,[MASK]是103,[SEP]是102

  1. Embeddings 部分

    Token Embeddings :对应的是将30522个词所对应的id转为768维度的向量,如"my" 的id 是 2026,就变成一个768的向量

    Segment Embeddings:来分割句子A和B,前一个句子A所有token 是0,后一个句子B所有token 是1,前一个句子包含CLS,结果就是[0,0,0,0,0,0,1,1,1,1,1],然后embeddings 成768维向量

    Position Embeddings:位置编码,最大512个token输入,输出维度也是768维度,具体位置编码可查看transformer 结构中的实现。

    最终的embeddings 就是将三者相加起来。


  1. 骨干网络:Bidirectional Transformers 也就是 Transformer Encoder
    堆叠的N层transformer encoder,论文作者给出了两种不同参数的模型
    在这里插入图片描述

    可以看到我们上面例子768维度的embeddings就是论文中提到的一种输入输出维度,输出维度(batch, SeqLen, 768)

  1. 下游任务:NSP + Mask LM
    要想得到一个预训练模型,必须是骨干网络+下游任务。论文中的预训练模型就是由NSP + Mask LM下游任务得到的。
    NSP: Linear(in_features=768, out_features=2, bias=True) ,将骨骼网络输出对应CLS位置的向量转为一个2维的向量,来分类判断句子B是否句子A下句的概率,标签1是,0不是。
    Mask LM: (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (transform_act_fn): GELUActivation()
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=30522, bias=True)
    )
    除了CLS对应的输出向量外,其他token 对应的骨干网络输出向量经过以上层,得到30522个词对应的概率。
    然后对NSP和Mask LM 输出向量进行softmax 并argmax 找到概率最大的位置,即是输出token ids
    最终输出可写为如模型结构图中所示:

2.2 训练

我们要知道我们拿bert 是干什么的,目的是用来做特征提取的预训练,在结合自己的任务做微调训练,因此自然分为两部分:Pre-Train,Fine-Tuning


  1. Pre-Train

    前面已经说过,论文是拿NSP+MLM 预测头来预训练模型,这种形式标签可自动生成,不需要额外标注。
    NPS: NSP 的具体做法是,其中,50% 的概率将语义连贯的两个连续句子作为训练文本(同一篇文章连续两句话),另外 50% 的概率将完全随机抽取两个句子作为训练文本。如下实例:

    连续句对:[CLS]你人呢?[SEP]一整天都没看到你了。[SEP] label:1

    随机句对:[CLS]今天晚上吃烤鱼。[SEP]地球是圆的。[SEP] label:0

    预测的C 概率最大的位置对应即对应的label,1表示后面句子B是前面句子A的下一句,否则不是。
    MLM: 针对样本,每一个token都有15%的概率被选为mask,而对于每一个mask:80%的概率是[MASK],10%是原词,10%是其他词,而标签始终是原词。以此来构造mask。


  1. Fine-Tuning
    在Pre-Train 基础得到的不带预测头的骨干网络,在加上新的任务的预测头来进行微调训练。也就是基于BERT的预训练骨干网络和权重来构建新的训练任务。

2.3 损失

  1. MLM(Masked Language Modeling)损失
  • 设一句输入长度为 SS,词表大小 ∣V∣|V|∣V∣。从非特殊符号位置里随机抽出 15% 作为掩码集合 M\mathcal{M}M(80/10/10 只影响输入端替换标签始终是原词)。

  • 对每个被选中位置i∈Mi \in \mathcal{M}i∈M:

    • 模型给出词表 logits:zi∈R∣V∣{z}_i \in \mathbb{R}^{|V|}zi∈R∣V∣

    • 该位置的真实 token id:yi∈{0,...,∣V∣−1}y_i \in \{0,\dots,|V|-1\}yi∈{0,...,∣V∣−1}

    • 单点交叉熵:CEi=−log⁡softmax(zi)yi{CE}_i = -\log \mathrm{softmax}(\mathbf{z}i){y_i}CEi=−logsoftmax(zi)yi

  • MLM 总损失(常用平均):

LMLM=1∣M∣∑i∈MCEi{L}{\text{MLM}} = \frac{1}{|\mathcal{M}|}\sum{i\in \mathcal{M}} \mathrm{CE}_iLMLM=∣M∣1i∈M∑CEi

其他位置(未被选择)不计入损失。PyTorch 里通常把这些位置的 label 设为 -100 ,配合 ignore_index=-100


  1. NSP(Next Sentence Prediction)损失
  • 输入是句对 [A;B],标签 yNSP∈{0,1}y_{\text{NSP}}\in\{0,1\}yNSP∈{0,1}(1=IsNext , 0=NotNext)。

  • [CLS] 的(pooler 后)向量接一个线性层得到二类 logits:r∈R2{r}\in\mathbb{R}^{2}r∈R2。

  • NSP 损失(二分类交叉熵):

LNSP=−log⁡softmax(r) yNSP{L}{\text{NSP}} = -\log \mathrm{softmax}(\mathbf{r}){\,y_{\text{NSP}}}LNSP=−logsoftmax(r)yNSP


  1. 联合目标(原始 BERT)

L=LMLM+LNSP{L} = \mathcal{L}{\text{MLM}} + \mathcal{L}{\text{NSP}}L=LMLM+LNSP

论文中不额外加权;实现里通常直接相加(同量级)。


3 如何使用BERT

https://huggingface.co/docs/transformers/model_doc/bert

https://huggingface.co/google-bert/bert-base-uncased

huggingface中提供了很多bert 模型和使用方法,而我们将使用论文中提到的骨骼网络+NSP+MLM的模型结构来测试。

code:

python 复制代码
from transformers import AutoTokenizer, BertForPreTraining  
import torch  
  
device = "cuda:0" # cup  
# 手动步骤实现(更深入了解过程)  
model_name = "bert-base-uncased"  
tokenizer = AutoTokenizer.from_pretrained(model_name)  
model = BertForPreTraining.from_pretrained(model_name).eval().to(device)                     # .cuda()  
print(model)  
# 待分析的句子  
textA = "my dog is cute"  
textB = "he likes play [MASK]"  
  
# 返回包含 'input_ids', 'token_type_ids', 'attention_mask' 的字典  
inputs = tokenizer(textA, textB, return_tensors="pt", truncation=True).to(device)            # .cuda()  
#  转换成tokens 查看  
in_toks = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])  
print("[输入]:")  
print(inputs)  
print(in_toks)  
  
# 将输入传递给模型,得到输出  
with torch.no_grad(): # 推理时不需要计算梯度  
    outputs = model(**inputs, output_hidden_states=True)  
  
# MLM 输出  
mlm_logits = outputs.prediction_logits                                                    #(1, 11, 30522)  
mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]  # 找到mask的位置 9mask_pos = mask_pos[0].item()  
mask_probs = torch.softmax(mlm_logits[0, mask_pos], dim=-1)                               # 计算 mask 概率  
  
# top-5 预测  
mask_top5_tok_ids = torch.topk(mask_probs, k=5, dim=-1).indices                           # 取概率前5位置ids  
mask_top5_toks = tokenizer.convert_ids_to_tokens(mask_top5_tok_ids.cpu().tolist())  
print("[输出top5]:")# 转换成tokens  
print(mask_top5_toks)  
  
# 用top1 回填原句子  
filled_ids = inputs["input_ids"][0].clone()  
filled_ids[mask_pos] = mask_top5_tok_ids[0].item()  
filled = tokenizer.decode(filled_ids)  
print("\n[MLM] 回填后的句子:")  
print(filled)  
  
# nsp 输出  
nsp_logits = outputs.seq_relationship_logits  
nsp_probs = torch.softmax(nsp_logits, dim=-1)  
print(f"\n[NSP] 概率:IsNext={nsp_probs[0,0].item():.4f}, NotNext={ nsp_probs[0, 1].item():.4f}")  
  
# 骨干网络输出  
last_hidden = outputs.hidden_states[-1]  
print("\n[Encoder] last_hidden_state 形状:", tuple(last_hidden.shape))  
print("[Encoder] [CLS] 向量前5维:", last_hidden[0, 0, :5])  
  
print()

inputs_ids就是2.1中提到的输入tokens ids

token_type_ids 就是2.1中提到的用于Segment Embeddings 的ids

attention_mask 是针对于改句子是否有padding,padding的位置为0,因为训练时整个batch,要求等长,末尾需要补充padding

结果:

相关推荐
Cathyqiii4 小时前
生成对抗网络(GAN)
人工智能·深度学习·计算机视觉
ai产品老杨5 小时前
打通各大芯片厂商相互间的壁垒,省去繁琐重复的适配流程的智慧工业开源了
人工智能·开源·音视频·能源
小陈phd6 小时前
高级RAG策略学习(五)——llama_index实现上下文窗口增强检索RAG
人工智能
凯禾瑞华养老实训室8 小时前
人才教育导向下:老年生活照护实训室助力提升学生老年照护服务能力
人工智能
湫兮之风9 小时前
Opencv: cv::LUT()深入解析图像块快速查表变换
人工智能·opencv·计算机视觉
Christo39 小时前
TFS-2018《On the convergence of the sparse possibilistic c-means algorithm》
人工智能·算法·机器学习·数据挖掘
qq_508823409 小时前
金融量化指标--2Alpha 阿尔法
大数据·人工智能
黑金IT10 小时前
`.cursorrules` 与 `.cursorcontext`:Cursor AI 编程助手时代下的“双轨配置”指南
人工智能
dlraba80210 小时前
基于 OpenCV 的信用卡数字识别:从原理到实现
人工智能·opencv·计算机视觉