BERT-pytorch源码实现,解决内存溢出问题

BERT-pytorch源码实现,解决内存溢出问题

相信大家很多人都在做BERT这个模型,但是,有些人可能就是直接从transfermer这个模型里直接导入数据,但是这种方法不方便我们修改模型,于是有些人就通过pytorch详细实现了BERT,但是博主发现,这些详细实现BERT的代码出现了内存溢出问题,博主就做了改进,下面代码,我们可以解决掉内存溢出问题,主要还是因为中间结果并没有完全释放,代码如下:

注:大家如果要解决内存溢出问题,关注del语句就可以了。

python 复制代码
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
         https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert
'''
import re
import math
import torch
import numpy as np
from random import *
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
 
import matplotlib.pyplot as plt
from data_process import get_data

setences,label,setences_test,label_test=get_data()
device = torch.device('cpu')

sentences=setences
#text = (
#    'Hello, how are you? I am Romeo.\n' # R
#    'Hello, Romeo My name is Juliet. Nice to meet you.\n' # J
#    'Nice meet you too. How are you today?\n' # R
#    'Great. My baseball team won the competition.\n' # J
#    'Oh Congratulations, Juliet\n' # R
#    'Thank you Romeo\n' # J
#    'Where are you going today?\n' # R
#    'I am going shopping. What about you?\n' # J
#    'I am going to visit my grandmother. she is not very well' # R
#)
#sentences = re.sub("[.,!?\\-]", '', text.lower()).split('\n') # filt
#print(sentences)

word_list = list(set(" ".join(setences).split())|set(" ".join(setences_test).split())) # ['hello', 'how', 'are', 'you',...]
word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}
for i, w in enumerate(word_list):
    word2idx[w] = i + 4
idx2word = {i: w for i, w in enumerate(word2idx)}
vocab_size = len(word2idx)
 
token_list = list()
for sentence in setences:
    arr = [word2idx[s] for s in sentence.split()]
    token_list.append(arr)







#print(token_list)
'''
[[12, 7, 22, 5, 39, 21, 15],
 [12, 15, 13, 35, 10, 27, 34, 14, 19, 5],
 [34, 19, 5, 17, 7, 22, 5, 8],
 [33, 13, 37, 32, 28, 11, 16],
 [30, 23, 27],
 [6, 5, 15],
 [36, 22, 5, 31, 8],
 [39, 21, 31, 18, 9, 20, 5],
 [39, 21, 31, 14, 29, 13, 4, 25, 10, 26, 38, 24]]
'''
# BERT Parameters
maxlen = 30
batch_size = 6
max_pred = 5 # max tokens of prediction
n_layers = 6
n_heads = 12
d_model = 768
d_ff = 768*4 # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 3
# sample IsNext and NotNext to be same in small batch size
def make_data():
    batch = []
    for i in range(len(setences)):
        tokens_a_index =  i
        tokens_a = token_list[tokens_a_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) 
 
        # MASK LM
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))  # 15 % of tokens in one sentence
        cand_maked_pos = [i for i, token in enumerate(input_ids)
                          if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]  # candidate masked position
        shuffle(cand_maked_pos)

        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:  # 80%
                input_ids[pos] = word2idx['[MASK]']  # make mask
            elif random() > 0.9:  # 10%
                index = randint(0, vocab_size - 1)  # random index in vocabulary
                while index < 4:  # can't involve 'CLS', 'SEP', 'PAD'
                    index = randint(0, vocab_size - 1)
                input_ids[pos] = index  # replace
 
        # Zero Paddings
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
 
        # Zero Padding (100% - 15%) tokens
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)
 
       
        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, label[tokens_a_index]])  # IsNext
    return batch
batch = make_data()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
'''>>> a = [1,2,3]
b = [4,5,6]
zipped = zip(a,b)     # 打包为元组的列表
[(1, 4), (2, 5), (3, 6)]
zip(*zipped)          # 与 zip 相反,可理解为解压,为zip的逆过程,可用于矩阵的转置
[(1, 2, 3), (4, 5, 6)]
'''
input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
    torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
    torch.LongTensor(masked_pos), torch.LongTensor(isNext)
 
 
class MyDataSet(Data.Dataset):
    def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):
        self.input_ids = input_ids
        self.segment_ids = segment_ids
        self.masked_tokens = masked_tokens
        self.masked_pos = masked_pos
        self.isNext = isNext
 
    def __len__(self):
        return len(self.input_ids)
 
    def __getitem__(self, idx):
        return self.input_ids[idx], self.segment_ids[idx], self.masked_tokens[idx], self.masked_pos[idx], self.isNext[
            idx]
 
loader = Data.DataLoader(MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext), batch_size, True)
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, seq_len = seq_q.size()   #[batch_size,maxlen]
    # eq(zero) is PAD token
    pad_attn_mask = seq_q.data.eq(0).unsqueeze(1)  # [batch_size, 1, seq_len]
    return pad_attn_mask.expand(batch_size, seq_len, seq_len)  # [batch_size, seq_len, seq_len]
 
def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)
 
    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long)
        # print("pos:",pos)
        '''pos: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])'''
        pos = pos.unsqueeze(0).expand_as(x).to(device)  # [seq_len] -> [batch_size, seq_len]
        # print("pos_batch:", pos)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        del pos,x, seg
        return self.norm(embedding)
 
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
 
    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, seq_len, seq_len]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        del attn,scores,Q, K, V,attn_mask
        return context
 
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size, seq_len, d_model], k: [batch_size, seq_len, d_model], v: [batch_size, seq_len, d_model]
        residual, batch_size = Q, Q.size(0)
        residual=residual.to(device)
       
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size, n_heads, seq_len, d_k]
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size, n_heads, seq_len, d_k]
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size, n_heads, seq_len, d_v]
 
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]
 
        # context: [batch_size, n_heads, seq_len, d_v], attn: [batch_size, n_heads, seq_len, seq_len]
        context = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)# context: [batch_size, seq_len, n_heads, d_v]
        output = nn.Linear(n_heads * d_v, d_model).to(device)(context)
        del context,attn_mask,q_s,k_s,v_s
        
        return nn.LayerNorm(d_model).to(device)(output + residual) # output: [batch_size, seq_len, d_model]
 
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
 
    def forward(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
       
        return self.fc2(gelu(self.fc1(x)))
 
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()
 
    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, seq_len, d_model]
        del enc_self_attn_mask,enc_inputs
        return enc_outputs

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Dropout(0.5),
            nn.Tanh(),
        )
        self.classifier = nn.Linear(d_model, 3)
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = gelu
        # fc2 is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)
        self.fc2.weight = embed_weight
 
    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids) # [bach_size, seq_len, d_model]
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids) # [batch_size, maxlen, maxlen]
        for layer in self.layers:
            # output: [batch_size, max_len, d_model]
            output = layer(output, enc_self_attn_mask)
        # it will be decided by first token(CLS)
        '''
         (fc): Sequential(
            (0): Linear(in_features=768, out_features=768, bias=True)
            (1): Dropout(p=0.5, inplace=False)
            (2): Tanh()
            )
          (classifier): Linear(in_features=768, out_features=2, bias=True)
          (linear): Linear(in_features=768, out_features=768, bias=True)
          (fc2): Linear(in_features=768, out_features=40, bias=False)
        '''
        # logits_clsf :根据[CLS]预测是否是连续的句子,[CLS]在第一维
        h_pooled = self.fc(output[:, 0]) # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled) # [batch_size, 2] predict isNext
 
        masked_pos = masked_pos[:, :, None].expand(-1, -1, d_model) # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked = self.activ2(self.linear(h_masked)) # [batch_size, max_pred, d_model]
        #logits_lm:预测mask的token
        logits_lm = self.fc2(h_masked) # [batch_size, max_pred, vocab_size]
        del h_masked,h_pooled,output,enc_self_attn_mask,masked_pos,input_ids,segment_ids
        return logits_lm, logits_clsf
model = BERT().to(device)
# print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.000001)
#out = torch.gather(input, dim, index)
index = torch.from_numpy(np.array([[1, 2, 0], [2, 0, 1]])).type(torch.LongTensor)
index = index[:, :, None].expand(-1, -1, 10)
loss_list=[]
for epoch in range(10):
    loss_sum=0
    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:
   
      
      logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)
      #logits_lm:[batch_size,max_pred,vocab_size] -> [batch_size*max_pred,vocab_size],batch_size*max_pred个词。每个词都有vocab_size种可能。
      loss_lm = criterion(logits_lm.view(-1, vocab_size), masked_tokens.view(-1)) # for masked LM
      loss_lm = (loss_lm.float()).mean()
    #  isNext=isNext.to(device)
      loss_clsf = criterion(logits_clsf, isNext) # for sentence classification
      loss = loss_lm + loss_clsf
      loss_sum=loss_sum+loss
      loss_list.append(float(loss))

      print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      del loss, logits_clsf, input_ids,segment_ids,masked_tokens,masked_pos,logits_lm,isNext,loss_clsf,loss_lm
     
    

# Predict mask tokens ans isNext



print('test')

token_list=[]

for sentence in setences_test:
    arr = [word2idx[s] for s in sentence.split()]
    token_list.append(arr)


def make_data_test():
    batch = []
  
    for i in range(len(setences_test)):
        tokens_a_index =  i
        tokens_a = token_list[tokens_a_index]
        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']]
        segment_ids = [0] * (1 + len(tokens_a) + 1) 
 
        # MASK LM
        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))  # 15 % of tokens in one sentence
        cand_maked_pos = [i for i, token in enumerate(input_ids)
                          if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]  # candidate masked position
        shuffle(cand_maked_pos)

        masked_tokens, masked_pos = [], []
        for pos in cand_maked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            if random() < 0.8:  # 80%
                input_ids[pos] = word2idx['[MASK]']  # make mask
            elif random() > 0.9:  # 10%
                index = randint(0, vocab_size - 1)  # random index in vocabulary
                while index < 4:  # can't involve 'CLS', 'SEP', 'PAD'
                    index = randint(0, vocab_size - 1)
                input_ids[pos] = index  # replace
 
        # Zero Paddings
        n_pad = maxlen - len(input_ids)
        input_ids.extend([0] * n_pad)
        segment_ids.extend([0] * n_pad)
 
        # Zero Padding (100% - 15%) tokens
        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)
 
       
        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, label_test[tokens_a_index]])  # IsNext
          
      
          
    return batch
# Proprecessing Finished
 
batch = make_data_test()
input_ids, segment_ids, masked_tokens, masked_pos, isNext = zip(*batch)
'''>>> a = [1,2,3]
b = [4,5,6]
zipped = zip(a,b)     # 打包为元组的列表
[(1, 4), (2, 5), (3, 6)]
zip(*zipped)          # 与 zip 相反,可理解为解压,为zip的逆过程,可用于矩阵的转置
[(1, 2, 3), (4, 5, 6)]
'''
input_ids, segment_ids, masked_tokens, masked_pos, isNext = \
    torch.LongTensor(input_ids), torch.LongTensor(segment_ids), torch.LongTensor(masked_tokens), \
    torch.LongTensor(masked_pos), torch.LongTensor(isNext)
 
predict_list=[]

for i in range(len(batch)):

    input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[0]

    print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])
 
    logits_lm, logits_clsf = model(torch.LongTensor([input_ids]), \
                     torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))


    logits_lm = logits_lm.data.max(2)[1][0].data.numpy()
    print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])
    print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])


 
    logits_clsf = logits_clsf.data.max(1)[1].data.numpy()[0]
    print('isNext : ', isNext )
    print('predict isNext : ', logits_clsf)
    predict_list.append(logits_clsf)


test_loss = 0
correct = 0
total = 0
target_num =[0,0,0]
predict_num = [0,0,0]
p=0
acc_num =[0,0,0]

for i in label_test:
    target_num[i]+=1

for i in predict_list:
       
      #  print(i.argmax())
        index=int(i)
        if index in [0,1,2]:
            predict_num[index]+=1
    #    print(id2word[index],id2word[p])
       
        if index==label_test[p]:
           
           
            acc_num[index]+=1
        p=p+1


#print(target_num)
#print(predict_num)
#print(acc_num)
recallz=0
precisionz=0
accuracyz=0
F1z=0

ps=0
rs=0
for i in range(3):
    if target_num[i]!=0:
        recallz=acc_num[i]/target_num[i]
    else:
        recallz=0
    if predict_num[i]!=0:
         precisionz=acc_num[i]/predict_num[i]
    else:
        precisionz=0
    ps=ps+precisionz
    rs=rs+recallz
    if recallz+precisionz!=0:
         F1z=2*recallz*precisionz/(recallz+precisionz)+F1z

#recall = [acc_num[i]/target_num[i] for i in range(3)]

#precision = [acc_num[i]/predict_num[i] for i in range(3)]

#F1 = [2*recall[i]*precision[i]/(recall[i]+precision[i]) for i in range(3)]

print()
accuracy = sum(acc_num)/sum(target_num) 


# 打印格式方便复制
print('recall:',rs/3)
print('precision:',ps/3)
print('F1:',F1z/3)
print('accuracy',accuracy)

plt.plot(loss_list,label='BERT')
plt.legend()
plt.title('loss-epoch')
plt.show()
相关推荐
Memene摸鱼日报14 小时前
「Memene 摸鱼日报 2025.9.16」OpenAI 推出 GPT-5-Codex 编程模型,xAI 发布 Grok 4 Fast
人工智能·aigc
xiaohouzi11223314 小时前
OpenCV的cv2.VideoCapture如何加GStreamer后端
人工智能·opencv·计算机视觉
用户1252055970814 小时前
解决Stable Diffusion WebUI训练嵌入式模型报错问题
人工智能
Juchecar14 小时前
一文讲清 nn.LayerNorm 层归一化
人工智能
小关会打代码14 小时前
计算机视觉案例分享之答题卡识别
人工智能·计算机视觉
Se7en25814 小时前
使用 NVIDIA Dynamo 部署 PD 分离推理服务
人工智能
海拥14 小时前
用 LazyLLM 搭建一个代码注释 / 文档 Agent 的实测体验
人工智能
天天进步201515 小时前
用Python打造专业级老照片修复工具:让时光倒流的数字魔法
人工智能·计算机视觉
文火冰糖的硅基工坊15 小时前
《投资-54》数字资产的形式有哪些?
人工智能·区块链
机器之心15 小时前
刚刚,OpenAI发布GPT-5-Codex:可独立工作超7小时,还能审查、重构大型项目
人工智能·openai