一、预训练bert数据集
在使用 BERT 进行预训练时,通常需要准备两个关键的任务数据:
- 下一句预测任务(NSP,Next Sentence Prediction)
- 遮蔽语言模型任务(MLM,Masked Language Model)
这两个任务有助于模型理解上下文关系以及填充缺失的词汇,从而增强其对自然语言的理解能力。
1. 数据加载和预处理
在准备数据集时,首先需要加载和预处理数据,将文本从原始数据中提取出来,并进行相应的处理,使其能够输入到 BERT 中。通常,数据集需要先进行词元化(tokenization),将每个句子拆解成小的词元(单词或子词)。
2. 下一句预测任务(NSP)
BERT 通过 下一句预测(NSP) 任务来训练模型判断两个句子是否连续。为此,我们需要从段落中获取相邻的句子,并创建标签,表示它们是否是连续的。
- 相邻句子:从段落中提取相邻的句子,标签为 "是"。
- 不相关句子:随机选择段落中的其他句子,将其与当前句子配对,标签为 "否"。
这个任务的目标是让模型学习判断两句话是否来自同一篇文章,并理解句子间的关系。
3. 遮蔽语言模型任务(MLM)
BERT 通过 遮蔽语言模型(MLM) 任务来训练模型根据上下文预测被遮蔽的词元。为了生成 MLM 数据,我们需要随机选择部分词元进行遮蔽,并记录被遮蔽的词元标签。
- 随机遮蔽 :在每个句子中,随机选择约 15% 的词元进行遮蔽。大部分情况下(80%),这些词元会被
<mask>标记替代。剩下的 10% 保留原词,另外 10% 用随机词替代。这种方式帮助模型学习如何基于上下文预测被遮蔽的词元。
通过这一任务,BERT 能够更好地理解语言的语法和上下文信息。
4. 填充数据
由于 BERT 的输入要求所有句子的长度一致,因此在生成数据时,需要对每个句子进行填充(padding)。所有序列会被扩展到指定的最大长度 max_len,并生成有效长度(valid_lens)以确保模型能够区分填充部分和实际内容。
填充过程包括:
- 在句子的末尾添加特殊的
<pad>标记,直到句子达到最大长度。 - 对于标记
<pad>的部分,模型的训练会使用权重(weight)来避免对填充部分进行学习。
二、代码
import torch
import os
import random
import d2l
import test_69bert
# d2l.DATA_HUB['wikitext-2'] = (
# 'https://s3.amazonaws.com/research.metamind.io/wikitext/'
# 'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
def read_wiki(data_dir):
file_name=os.path.join(data_dir,'wiki.train.tokens')#data_dir:这个路径现在是从你手动下载并存放数据集的目录中读取。
with open(file_name,'r') as f:
lines = f.readlines()
paragraphs=[line.strip().lower().strip(' . ') for line in lines if len(line.strip(' . '))>=2]#筛选出连续句子超过2的,对句子去空格,改小写
random.shuffle(paragraphs)
return paragraphs
def _get_next_sentence(sentence,next_sentence,paragraphs):#匹配上下句,依靠生成的随机数决定是否上下文相关联,而不是靠人工或者语境
if random.random()<0.5:
is_next=True
else:
next_sentence=random.choice(random.choice(paragraphs))
is_next=False
return sentence,next_sentence,is_next
def _get_nsp_data_from_paragraph(paragraph,paragraphs,vocab,max_len):#生成用于判断上下文的数据集
nsp_data_from_paragraph=[]
for i in range(len(paragraph)-1):
token_a,token_b,is_next = _get_next_sentence(paragraph[i],paragraph[i+1],paragraphs)
if len(token_a)+len(token_b)+3>max_len:
continue
tokens,segments=test_69bert.get_tokens_and_segments(token_a,token_b)
nsp_data_from_paragraph.append((tokens,segments,is_next))
return nsp_data_from_paragraph
def _replace_mlm_tokens(tokens,candidate_pred_positions,num_mlm_preds,vocab):
mlm_input_tokens=[token for token in tokens]
pred_positions_and_labels=[]#用于存储被遮蔽的位置和对应的原始标签(即,原始词元)。这个列表将用于计算损失,帮助模型学习预测被遮蔽的词元。
random.shuffle(candidate_pred_positions)
for mlm_pred_position in candidate_pred_positions:
if len(pred_positions_and_labels) >= num_mlm_preds:
break
masked_token=None
if random.random()<0.8:
masked_token='<mask>'
else:
if random.random()<0.5:
masked_token=tokens[mlm_pred_position]
else:
masked_token=random.choice(vocab.idx_to_token)
mlm_input_tokens[mlm_pred_position]=masked_token#将 mlm_input_tokens 中相应位置的词元替换为 masked_token。
pred_positions_and_labels.append(mlm_pred_position,tokens[mlm_pred_position])#存储了所有被遮蔽的词元位置和它们的标签(即原始的词元)
return mlm_input_tokens,pred_positions_and_labels
def _get_mlm_data_from_tokens(tokens,vocab):
candidate_pred_positions=[]
for i,token in enumerate(tokens):
if token in ['<cls>','<sep>']:
continue
candidate_pred_positions.append(i)
num_mlm_preds=max(1,round(tokens)*0.15)
mlm_input_tokens,pred_positions_and_labels=_replace_mlm_tokens(tokens,candidate_pred_positions,num_mlm_preds,vocab)
pred_positions_and_labels=sorted(pred_positions_and_labels,key=lambda x:x[0])
pred_positions=[v[0] for v in pred_positions_and_labels]
mlm_pred_labels=[v[1] for v in pred_positions_and_labels]
return vocab[mlm_input_tokens],pred_positions,mlm_pred_labels
def _pad_bert_input(example,max_len,vocab):
max_num_mlm_pred=round(max_len*0.15)
all_tokens_ids,all_segments,valid_lens=[],[],[]
all_pred_positions,all_mlm_weights,all_mlm_labels=[],[],[]
nsp_labels=[]
for (token_ids,pred_positions,mlm_pred_label_ids,segmentms,is_next) in example:
all_tokens_ids.append(torch.tensor(token_ids+[vocab['<pad>']]*(max_len-len(token_ids)),dtype=torch.long))
all_segments.append(torch.tensor(segmentms+[0]*(max_len-len(segmentms)),dtype=torch.float32))
valid_lens.append(torch.tensor(len(token_ids),dtype=torch.float32))
all_pred_positions.append(torch.tensor(pred_positions+[0]*(max_num_mlm_pred-len(pred_positions)),dtype=torch.float32))
all_mlm_weights.append(torch.tensor([1.0]*len(mlm_pred_label_ids)+[0.0]*(max_num_mlm_pred-len(pred_positions)),dtype=torch.float32))
all_mlm_labels.append(torch.tensor(mlm_pred_label_ids+[0]*(max_num_mlm_pred-len(mlm_pred_label_ids)),dtype=torch.int32))
nsp_labels.append(torch.tensor(mlm_pred_label_ids+[0.0]*(max_num_mlm_pred-len(pred_positions)),dtype=torch.int32))
all_mlm_labels.append(torch.tensor(mlm_pred_label_ids+[0]*(max_num_mlm_pred-len(mlm_pred_label_ids)),dtype=torch.long))
return (all_tokens_ids,all_segments,valid_lens,all_pred_positions,all_mlm_weights,all_mlm_labels)
class _WikiTextDataset(torch.utils.data.Dataset):
def __init__(self,paragraphs,max_len):
paragraphs=[d2l.tokenize(paragraph,token='word')for paragraph in paragraphs]#对每个段落进行词元化,按单词分割
sentences=[sentence for paragraph in paragraphs for sentence in paragraph]
self.vocab=d2l.Vocab(sentences,min_freq=5,reserved_tokens=['<pad>','<mask>','<cls>','<sep>'])#创建一个词汇表,过滤频率低于5
examples=[]
for paragraph in paragraphs:
examples.extend(_get_nsp_data_from_paragraph(paragraph,paragraphs,self.vocab,max_len))
examples=[(_get_mlm_data_from_tokens(tokens,self.vocab)+(segments,is_next))for tokens,segments,is_next in examples]#生成遮蔽语言模型(MLM) 任务的数据
(self.all_token_ids, self.all_segments, self.valid_lens,
self.all_pred_positions, self.all_mlm_weights,
self.all_mlm_labels, self.nsp_labels) = _pad_bert_input(examples, max_len, self.vocab)
def __getitem__(self, idx):
return (self.all_token_ids[idx], self.all_segments[idx],
self.valid_lens[idx], self.all_pred_positions[idx],
self.all_mlm_weights[idx], self.all_mlm_labels[idx],
self.nsp_labels[idx])
def __len__(self):
return len(self.all_token_ids)
#@save
def load_data_wiki(batch_size, max_len):
"""加载WikiText-2数据集"""
# num_workers = d2l.get_dataloader_workers()
# data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
data_dir=''
paragraphs = read_wiki(data_dir)
train_set = _WikiTextDataset(paragraphs, max_len)
train_iter = torch.utils.data.DataLoader(train_set, batch_size,
shuffle=True)
return train_iter, train_set.vocab