Generating Text Summaries or Titles with a BERT-Based Sequence-to-Sequence (Seq2Seq) Model

  1. Data preprocessing

    • A DataGenerator class loads and preprocesses the data, handling padding of variable-length sequences.
    • The input is the article body (content); the target is the title (title).
  2. Model construction

    • A Seq2Seq model is built on top of BERT and trained with a cross-entropy loss.
    • Generation uses beam search, expanding the top-k candidate tokens at each step.
  3. Training and evaluation

    • Training uses the Adam optimizer.
    • At the end of each epoch, an Evaluate callback generates sample titles to monitor output quality.

    import numpy as np
    import pandas as pd
    from tqdm import tqdm
    from bert4keras.bert import build_bert_model
    from bert4keras.tokenizer import Tokenizer, load_vocab
    from keras.layers import *
    from keras.models import Model
    from keras import backend as K
    from bert4keras.snippets import parallel_apply
    from keras.optimizers import Adam
    import keras
    import math
    from sklearn.model_selection import train_test_split
    from rouge import Rouge  # requires the rouge package (pip install rouge)

    Configuration parameters

    config_path = 'bert/chinese_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
    dict_path = 'bert/chinese_L-12_H-768_A-12/vocab.txt'

    max_input_len = 256
    max_output_len = 32
    batch_size = 16
    epochs = 10
    beam_size = 3
    learning_rate = 2e-5
    val_split = 0.1

    Enhanced data preprocessing

    class DataGenerator(keras.utils.Sequence):
        def __init__(self, data, batch_size=8, mode='train'):
            self.batch_size = batch_size
            self.mode = mode
            self.data = data
            self.indices = np.arange(len(data))

        def __len__(self):
            return math.ceil(len(self.data) / self.batch_size)

        def __getitem__(self, index):
            batch_indices = self.indices[index*self.batch_size : (index+1)*self.batch_size]
            batch = self.data.iloc[batch_indices]
            return self._process_batch(batch)

        def on_epoch_end(self):
            # Reshuffle between epochs so training batches differ
            if self.mode == 'train':
                np.random.shuffle(self.indices)

        def _process_batch(self, batch):
            batch_x, batch_y = [], []
            for _, row in batch.iterrows():
                content = row['content'][:max_input_len]
                title = row['title'][:max_output_len-2]  # leave room for [CLS] and [SEP]

                # Encoder input
                x, _ = tokenizer.encode(content, max_length=max_input_len)

                # Decoder input/output (teacher forcing: decoder input is shifted right)
                y, _ = tokenizer.encode(title, max_length=max_output_len)
                decoder_input = [tokenizer._token_start_id] + y[:-1]
                decoder_output = y

                batch_x.append(x)
                batch_y.append({'decoder_input': decoder_input, 'decoder_output': decoder_output})

            # Dynamic padding
            padded_x = self._pad_sequences(batch_x, maxlen=max_input_len)
            padded_decoder_input = self._pad_sequences(
                [y['decoder_input'] for y in batch_y],
                maxlen=max_output_len,
                padding='post'
            )
            padded_decoder_output = self._pad_sequences(
                [y['decoder_output'] for y in batch_y],
                maxlen=max_output_len,
                padding='post'
            )

            return [padded_x, padded_decoder_input], padded_decoder_output

        def _pad_sequences(self, sequences, maxlen, padding='pre'):
            padded = np.zeros((len(sequences), maxlen))
            for i, seq in enumerate(sequences):
                if len(seq) > maxlen:
                    seq = seq[:maxlen]
                if padding == 'pre':
                    padded[i, -len(seq):] = seq
                else:
                    padded[i, :len(seq)] = seq
            return padded
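
    A quick way to sanity-check the generator is to inspect one batch's shapes. A minimal sketch, assuming tokenizer has already been initialized as in the main pipeline below; the two-row DataFrame is purely illustrative:

    # Hedged smoke test: toy data just to confirm batch shapes
    toy = pd.DataFrame({
        'content': ['今天天气很好,适合出门散步。', '人工智能正在改变世界。'],
        'title': ['好天气', 'AI改变世界']
    })
    gen = DataGenerator(toy, batch_size=2, mode='val')
    (x, dec_in), dec_out = gen[0]
    print(x.shape, dec_in.shape, dec_out.shape)  # expected: (2, 256) (2, 32) (2, 32)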

    Improved model architecture

    def build_seq2seq_model():
        # Encoder
        encoder_inputs = Input(shape=(None,), name='Encoder-Input')
        encoder = build_bert_model(
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            model='encoder',
            return_keras_model=False,
        )
        encoder_outputs = encoder(encoder_inputs)

        # Decoder
        decoder_inputs = Input(shape=(None,), name='Decoder-Input')
        decoder = build_bert_model(
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            model='decoder',
            application='lm',
            return_keras_model=False,
        )
        decoder_outputs = decoder([decoder_inputs, encoder_outputs])

        # Full model
        model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

        # Custom loss that ignores padding positions (token id 0)
        def seq2seq_loss(y_true, y_pred):
            y_mask = K.cast(K.not_equal(y_true, 0), K.floatx())
            loss = K.sparse_categorical_crossentropy(
                y_true, y_pred, from_logits=True
            )
            return K.sum(loss * y_mask) / K.sum(y_mask)

        model.compile(Adam(learning_rate), loss=seq2seq_loss)
        return model
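
    The loss masks padding positions so that zero-id tokens do not dilute the average. A standalone numpy sketch of the same computation (illustrative only, independent of Keras):

    # Numpy illustration of the masked cross-entropy used above
    def masked_ce_demo():
        T, vocab = 4, 5
        logits = np.random.randn(T, vocab)
        z = logits - logits.max(axis=1, keepdims=True)            # stable log-softmax
        log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
        y_true = np.array([2, 3, 0, 0])                           # last two positions are padding
        mask = (y_true != 0).astype(float)
        token_loss = -log_probs[np.arange(T), y_true]
        return (token_loss * mask).sum() / mask.sum()             # averaged over real tokens only

    print(masked_ce_demo())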

    Improved beam search

    def beam_search(model, input_seq, beam_size=3):
        encoder_input = tokenizer.encode(input_seq, max_length=max_input_len)[0]
        encoder_output = model.get_layer('bert').predict(np.array([encoder_input]))

        sequences = [([tokenizer._token_start_id], 0.0)]
        for _ in range(max_output_len):
            all_candidates = []
            for seq, score in sequences:
                # Finished hypotheses are carried over unchanged
                if seq[-1] == tokenizer._token_end_id:
                    all_candidates.append((seq, score))
                    continue

                decoder_input = np.array([seq])
                logits = model.get_layer('bert_1').predict(
                    [decoder_input, encoder_output]
                )[0, -1, :]
                # The model outputs logits (the loss uses from_logits=True),
                # so convert to log-probabilities before accumulating scores
                z = logits - logits.max()
                log_probs = z - np.log(np.exp(z).sum())

                top_k = np.argsort(log_probs)[-beam_size:]
                for token in top_k:
                    all_candidates.append((seq + [int(token)], score + log_probs[token]))

            # Length normalization so longer hypotheses are not unfairly penalized
            ordered = sorted(all_candidates, key=lambda x: x[1] / (len(x[0]) + 1e-8), reverse=True)
            sequences = ordered[:beam_size]

        best_seq = sequences[0][0]
        return tokenizer.decode(best_seq[1:-1])  # strip [CLS] and [SEP]
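
    A minimal usage sketch, assuming the trained model and tokenizer from the main pipeline below are in scope (the sample text is made up):

    # Hedged example: generate a title for one piece of content
    sample_content = '今日,某市举办了大型科技展览,吸引了众多市民参观。'
    print('Generated title:', beam_search(model, sample_content, beam_size=beam_size))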

    Enhanced evaluation callback

    class AdvancedEvaluate(keras.callbacks.Callback):
        def __init__(self, val_data, sample_size=5):
            super().__init__()
            self.val_data = val_data
            self.rouge = Rouge()
            self.samples = val_data.sample(sample_size)

        def on_epoch_end(self, epoch, logs=None):
            # Print a few generated examples
            print("\nGenerated examples:")
            for _, row in self.samples.iterrows():
                generated = beam_search(self.model, row['content'], beam_size)
                print(f"Reference title: {row['title']}")
                print(f"Generated title: {generated}\n")

            # Compute ROUGE over the validation set (greedy decoding via beam_size=1)
            references = []
            hypotheses = []
            for _, row in self.val_data.iterrows():
                generated = beam_search(self.model, row['content'], beam_size=1)
                references.append(row['title'])
                hypotheses.append(generated)

            scores = self.rouge.get_scores(hypotheses, references, avg=True)
            print(f"Validation ROUGE-L: {scores['rouge-l']['f']:.4f}")

    Main pipeline

    if __name__ == "__main__":
        # Load the data
        full_data = pd.read_csv('train.tsv', sep='\t', names=['title', 'content'])
        train_data, val_data = train_test_split(full_data, test_size=val_split)

        # Initialize the tokenizer
        tokenizer = Tokenizer(dict_path, do_lower_case=True)

        # Build the model
        model = build_seq2seq_model()
        model.summary()

        # Data generators
        train_gen = DataGenerator(train_data, batch_size, mode='train')
        val_gen = DataGenerator(val_data, batch_size, mode='val')

        # Training configuration
        callbacks = [
            AdvancedEvaluate(val_data),
            keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2, verbose=1),
            keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)
        ]

        # Start training
        model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=epochs,
            callbacks=callbacks,
            workers=4,
            use_multiprocessing=True
        )
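
    After training, the saved checkpoint can be reloaded for inference. A minimal sketch, assuming best_model.h5 was written by the ModelCheckpoint callback above; rebuilding the architecture and loading only the weights sidesteps deserializing the custom loss:

    # Hedged inference sketch
    model = build_seq2seq_model()
    model.load_weights('best_model.h5')
    print(beam_search(model, '这里放一篇待生成标题的文章正文', beam_size=beam_size))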