《动手学深度学习》-69预训练bert数据集实现

一、预训练bert数据集

在使用 BERT 进行预训练时,通常需要准备两个关键的任务数据:

  1. 下一句预测任务(NSP,Next Sentence Prediction)
  2. 遮蔽语言模型任务(MLM,Masked Language Model)

这两个任务有助于模型理解上下文关系以及填充缺失的词汇,从而增强其对自然语言的理解能力。

1. 数据加载和预处理

在准备数据集时,首先需要加载和预处理数据,将文本从原始数据中提取出来,并进行相应的处理,使其能够输入到 BERT 中。通常,数据集需要先进行词元化(tokenization),将每个句子拆解成小的词元(单词或子词)。

2. 下一句预测任务(NSP)

BERT 通过 下一句预测(NSP) 任务来训练模型判断两个句子是否连续。为此,我们需要从段落中获取相邻的句子,并创建标签,表示它们是否是连续的。

  • 相邻句子:从段落中提取相邻的句子,标签为 "是"。
  • 不相关句子:随机选择段落中的其他句子,将其与当前句子配对,标签为 "否"。

这个任务的目标是让模型学习判断两句话是否来自同一篇文章,并理解句子间的关系。

3. 遮蔽语言模型任务(MLM)

BERT 通过 遮蔽语言模型(MLM) 任务来训练模型根据上下文预测被遮蔽的词元。为了生成 MLM 数据,我们需要随机选择部分词元进行遮蔽,并记录被遮蔽的词元标签。

  • 随机遮蔽 :在每个句子中,随机选择约 15% 的词元进行遮蔽。大部分情况下(80%),这些词元会被 <mask> 标记替代。剩下的 10% 保留原词,另外 10% 用随机词替代。这种方式帮助模型学习如何基于上下文预测被遮蔽的词元。

通过这一任务,BERT 能够更好地理解语言的语法和上下文信息。

4. 填充数据

由于 BERT 的输入要求所有句子的长度一致,因此在生成数据时,需要对每个句子进行填充(padding)。所有序列会被扩展到指定的最大长度 max_len,并生成有效长度(valid_lens)以确保模型能够区分填充部分和实际内容。

填充过程包括:

  • 在句子的末尾添加特殊的 <pad> 标记,直到句子达到最大长度。
  • 对于标记 <pad> 的部分,模型的训练会使用权重(weight)来避免对填充部分进行学习。

二、代码

复制代码
import torch
import os
import random
import d2l
import test_69bert
# d2l.DATA_HUB['wikitext-2'] = (
#     'https://s3.amazonaws.com/research.metamind.io/wikitext/'
#     'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
def read_wiki(data_dir):
    file_name=os.path.join(data_dir,'wiki.train.tokens')#data_dir:这个路径现在是从你手动下载并存放数据集的目录中读取。
    with open(file_name,'r') as f:
        lines = f.readlines()
    paragraphs=[line.strip().lower().strip(' . ') for line in lines if len(line.strip(' . '))>=2]#筛选出连续句子超过2的,对句子去空格,改小写
    random.shuffle(paragraphs)
    return paragraphs
def _get_next_sentence(sentence,next_sentence,paragraphs):#匹配上下句,依靠生成的随机数决定是否上下文相关联,而不是靠人工或者语境
    if random.random()<0.5:
        is_next=True
    else:
        next_sentence=random.choice(random.choice(paragraphs))
        is_next=False
    return sentence,next_sentence,is_next
def _get_nsp_data_from_paragraph(paragraph,paragraphs,vocab,max_len):#生成用于判断上下文的数据集
    nsp_data_from_paragraph=[]
    for i in range(len(paragraph)-1):
        token_a,token_b,is_next = _get_next_sentence(paragraph[i],paragraph[i+1],paragraphs)
        if len(token_a)+len(token_b)+3>max_len:
            continue
        tokens,segments=test_69bert.get_tokens_and_segments(token_a,token_b)
        nsp_data_from_paragraph.append((tokens,segments,is_next))
    return nsp_data_from_paragraph
def _replace_mlm_tokens(tokens,candidate_pred_positions,num_mlm_preds,vocab):
    mlm_input_tokens=[token for token in tokens]
    pred_positions_and_labels=[]#用于存储被遮蔽的位置和对应的原始标签(即,原始词元)。这个列表将用于计算损失,帮助模型学习预测被遮蔽的词元。
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        masked_token=None
        if random.random()<0.8:
            masked_token='<mask>'
        else:
            if random.random()<0.5:
                masked_token=tokens[mlm_pred_position]
            else:
                masked_token=random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position]=masked_token#将 mlm_input_tokens 中相应位置的词元替换为 masked_token。
        pred_positions_and_labels.append(mlm_pred_position,tokens[mlm_pred_position])#存储了所有被遮蔽的词元位置和它们的标签(即原始的词元)
    return mlm_input_tokens,pred_positions_and_labels
def _get_mlm_data_from_tokens(tokens,vocab):
    candidate_pred_positions=[]
    for i,token in enumerate(tokens):
        if token in ['<cls>','<sep>']:
            continue
        candidate_pred_positions.append(i)
    num_mlm_preds=max(1,round(tokens)*0.15)
    mlm_input_tokens,pred_positions_and_labels=_replace_mlm_tokens(tokens,candidate_pred_positions,num_mlm_preds,vocab)
    pred_positions_and_labels=sorted(pred_positions_and_labels,key=lambda x:x[0])
    pred_positions=[v[0] for v in pred_positions_and_labels]
    mlm_pred_labels=[v[1] for v in pred_positions_and_labels]
    return vocab[mlm_input_tokens],pred_positions,mlm_pred_labels
def _pad_bert_input(example,max_len,vocab):
    max_num_mlm_pred=round(max_len*0.15)
    all_tokens_ids,all_segments,valid_lens=[],[],[]
    all_pred_positions,all_mlm_weights,all_mlm_labels=[],[],[]
    nsp_labels=[]
    for (token_ids,pred_positions,mlm_pred_label_ids,segmentms,is_next) in example:
        all_tokens_ids.append(torch.tensor(token_ids+[vocab['<pad>']]*(max_len-len(token_ids)),dtype=torch.long))
        all_segments.append(torch.tensor(segmentms+[0]*(max_len-len(segmentms)),dtype=torch.float32))
        valid_lens.append(torch.tensor(len(token_ids),dtype=torch.float32))
        all_pred_positions.append(torch.tensor(pred_positions+[0]*(max_num_mlm_pred-len(pred_positions)),dtype=torch.float32))
        all_mlm_weights.append(torch.tensor([1.0]*len(mlm_pred_label_ids)+[0.0]*(max_num_mlm_pred-len(pred_positions)),dtype=torch.float32))
        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids+[0]*(max_num_mlm_pred-len(mlm_pred_label_ids)),dtype=torch.int32))
        nsp_labels.append(torch.tensor(mlm_pred_label_ids+[0.0]*(max_num_mlm_pred-len(pred_positions)),dtype=torch.int32))
        all_mlm_labels.append(torch.tensor(mlm_pred_label_ids+[0]*(max_num_mlm_pred-len(mlm_pred_label_ids)),dtype=torch.long))
        return (all_tokens_ids,all_segments,valid_lens,all_pred_positions,all_mlm_weights,all_mlm_labels)
class _WikiTextDataset(torch.utils.data.Dataset):
    def __init__(self,paragraphs,max_len):
        paragraphs=[d2l.tokenize(paragraph,token='word')for paragraph in paragraphs]#对每个段落进行词元化,按单词分割
        sentences=[sentence for paragraph in paragraphs for sentence in paragraph]
        self.vocab=d2l.Vocab(sentences,min_freq=5,reserved_tokens=['<pad>','<mask>','<cls>','<sep>'])#创建一个词汇表,过滤频率低于5
        examples=[]
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(paragraph,paragraphs,self.vocab,max_len))
            examples=[(_get_mlm_data_from_tokens(tokens,self.vocab)+(segments,is_next))for tokens,segments,is_next in examples]#生成遮蔽语言模型(MLM) 任务的数据
            (self.all_token_ids, self.all_segments, self.valid_lens,
             self.all_pred_positions, self.all_mlm_weights,
             self.all_mlm_labels, self.nsp_labels) = _pad_bert_input(examples, max_len, self.vocab)
    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.all_pred_positions[idx],
                self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])
    def __len__(self):
        return len(self.all_token_ids)
#@save
def load_data_wiki(batch_size, max_len):
    """加载WikiText-2数据集"""
    # num_workers = d2l.get_dataloader_workers()
    # data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    data_dir=''
    paragraphs = read_wiki(data_dir)
    train_set = _WikiTextDataset(paragraphs, max_len)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                        shuffle=True)
    return train_iter, train_set.vocab
相关推荐
IT_陈寒2 小时前
Python开发者的效率革命:这5个技巧让你的代码提速50%!
前端·人工智能·后端
用户69371750013842 小时前
不卷AI速度,我卷自己的从容——北京程序员手记
android·前端·人工智能
love530love2 小时前
不用聊天软件 OpenClaw 手机浏览器远程访问控制:Tailscale 配置、设备配对与常见问题全解
人工智能·windows·python·智能手机·tailscale·openclaw·远程访问控制
lifallen2 小时前
从零推导多 Agent 协作网络 (Flow Agent)
人工智能·语言模型
CoovallyAIHub2 小时前
2.5GB 塞进浏览器:Mistral 开源实时语音识别,延迟不到半秒
深度学习·算法·计算机视觉
guoji77882 小时前
2026年Gemini 3 Pro vs 豆包2.0深度评测:海外顶流与国产黑马谁更强?
大数据·人工智能·架构
NAGNIP2 小时前
一文搞懂深度学习中的损失函数设计!
人工智能·算法
千桐科技2 小时前
大模型幻觉难解?2026深度解析:知识图谱如何成为LLM落地的“刚需”与高薪新赛道
人工智能·大模型·llm·知识图谱·大模型幻觉·qknow·行业深度ai应用
mygugu2 小时前
详细分析swanlab集成mmengine底层实现机制--源码分析
python·深度学习·可视化