基于BERT的序列到序列(Seq2Seq)模型,生成文本摘要或标题

  1. 数据预处理

    • 使用DataGenerator类加载并预处理数据,处理变长序列的padding。
    • 输入为内容(content),目标为标题(title)。
  2. 模型构建

    • 基于BERT构建Seq2Seq模型,使用交叉熵损失。
    • 采用Beam Search进行生成,支持Top-K采样。
  3. 训练与评估

    • 使用Adam优化器进行训练。
    • 每个epoch结束时通过Evaluate回调生成示例标题以观察效果。

    import numpy as np
    import pandas as pd
    from tqdm import tqdm
    from bert4keras.bert import build_bert_model
    from bert4keras.tokenizer import Tokenizer, load_vocab
    from keras.layers import *
    from keras.models import Model
    from keras import backend as K
    from bert4keras.snippets import parallel_apply
    from keras.optimizers import Adam
    import keras
    import math
    from sklearn.model_selection import train_test_split
    from rouge import Rouge # 需要安装rouge包

    配置参数

    config_path = 'bert/chinese_L-12_H-768_A-12/bert_config.json'
    checkpoint_path = 'bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
    dict_path = 'bert/chinese_L-12_H-768_A-12/vocab.txt'

    max_input_len = 256
    max_output_len = 32
    batch_size = 16
    epochs = 10
    beam_size = 3
    learning_rate = 2e-5
    val_split = 0.1

    数据预处理增强

    class DataGenerator(keras.utils.Sequence):
    def init(self, data, batch_size=8, mode='train'):
    self.batch_size = batch_size
    self.mode = mode
    self.data = data
    self.indices = np.arange(len(data))

    复制代码
     def __len__(self):
         return math.ceil(len(self.data) / self.batch_size)
    
     def __getitem__(self, index):
         batch_indices = self.indices[index*self.batch_size : (index+1)*self.batch_size]
         batch = self.data.iloc[batch_indices]
         return self._process_batch(batch)
    
     def on_epoch_end(self):
         if self.mode == 'train':
             np.random.shuffle(self.indices)
    
     def _process_batch(self, batch):
         batch_x, batch_y = [], []
         for _, row in batch.iterrows():
             content = row['content'][:max_input_len]
             title = row['title'][:max_output_len-2]  # 保留空间给[CLS]和[SEP]
             
             # 编码器输入
             x, _ = tokenizer.encode(content, max_length=max_input_len)
             
             # 解码器输入输出
             y, _ = tokenizer.encode(title, max_length=max_output_len)
             decoder_input = [tokenizer._token_start_id] + y[:-1]
             decoder_output = y
             
             batch_x.append(x)
             batch_y.append({'decoder_input': decoder_input, 'decoder_output': decoder_output})
         
         # 动态padding
         padded_x = self._pad_sequences([x for x in batch_x], maxlen=max_input_len)
         padded_decoder_input = self._pad_sequences(
             [y['decoder_input'] for y in batch_y], 
             maxlen=max_output_len,
             padding='post'
         )
         padded_decoder_output = self._pad_sequences(
             [y['decoder_output'] for y in batch_y],
             maxlen=max_output_len,
             padding='post'
         )
         
         return [padded_x, padded_decoder_input], padded_decoder_output
    
     def _pad_sequences(self, sequences, maxlen, padding='pre'):
         padded = np.zeros((len(sequences), maxlen))
         for i, seq in enumerate(sequences):
             if len(seq) > maxlen:
                 seq = seq[:maxlen]
             if padding == 'pre':
                 padded[i, -len(seq):] = seq
             else:
                 padded[i, :len(seq)] = seq
         return padded

    改进的模型架构

    def build_seq2seq_model():
    # 编码器
    encoder_inputs = Input(shape=(None,), name='Encoder-Input')
    encoder = build_bert_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='encoder',
    return_keras_model=False,
    )
    encoder_outputs = encoder(encoder_inputs)

    复制代码
     # 解码器
     decoder_inputs = Input(shape=(None,), name='Decoder-Input')
     decoder = build_bert_model(
         config_path=config_path,
         checkpoint_path=checkpoint_path,
         model='decoder',
         application='lm',
         return_keras_model=False,
     )
     decoder_outputs = decoder([decoder_inputs, encoder_outputs])
    
     # 连接模型
     model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
     
     # 自定义损失函数(忽略padding)
     def seq2seq_loss(y_true, y_pred):
         y_mask = K.cast(K.not_equal(y_true, 0), K.floatx())
         loss = K.sparse_categorical_crossentropy(
             y_true, y_pred, from_logits=True
         )
         return K.sum(loss * y_mask) / K.sum(y_mask)
    
     model.compile(Adam(learning_rate), loss=seq2seq_loss)
     return model

    改进的Beam Search

    def beam_search(model, input_seq, beam_size=3):
    encoder_input = tokenizer.encode(input_seq)[0]
    encoder_output = model.get_layer('bert').predict(np.array([encoder_input]))

    复制代码
     sequences = [[[tokenizer._token_start_id], 0.0]]
     for _ in range(max_output_len):
         all_candidates = []
         for seq, score in sequences:
             if seq[-1] == tokenizer._token_end_id:
                 all_candidates.append((seq, score))
                 continue
             
             decoder_input = np.array([seq])
             decoder_output = model.get_layer('bert_1').predict(
                 [decoder_input, encoder_output]
             )[:, -1, :]
             
             top_k = np.argsort(decoder_output[0])[-beam_size:]
             for token in top_k:
                 new_seq = seq + [token]
                 new_score = score + np.log(decoder_output[0][token])
                 all_candidates.append((new_seq, new_score))
         
         # 长度归一化
         ordered = sorted(all_candidates, key=lambda x: x[1]/(len(x[0])+1e-8), reverse=True)
         sequences = ordered[:beam_size]
     
     best_seq = sequences[0][0]
     return tokenizer.decode(best_seq[1:-1])  # 去除[CLS]和[SEP]

    增强的评估回调

    class AdvancedEvaluate(keras.callbacks.Callback):
    def init(self, val_data, sample_size=5):
    self.val_data = val_data
    self.rouge = Rouge()
    self.samples = val_data.sample(sample_size)

    复制代码
     def on_epoch_end(self, epoch, logs=None):
         # 生成示例
         print("\n生成示例:")
         for _, row in self.samples.iterrows():
             generated = beam_search(self.model, row['content'], beam_size)
             print(f"真实标题: {row['title']}")
             print(f"生成标题: {generated}\n")
         
         # 计算ROUGE分数
         references = []
         hypotheses = []
         for _, row in self.val_data.iterrows():
             generated = beam_search(self.model, row['content'], beam_size=1)
             references.append(row['title'])
             hypotheses.append(generated)
         
         scores = self.rouge.get_scores(hypotheses, references, avg=True)
         print(f"验证集ROUGE-L: {scores['rouge-l']['f']:.4f}")

    主流程

    if name == "main":
    # 加载数据
    full_data = pd.read_csv('train.tsv', sep='\t', names=['title', 'content'])
    train_data, val_data = train_test_split(full_data, test_size=val_split)

    复制代码
     # 初始化tokenizer
     tokenizer = Tokenizer(dict_path, do_lower_case=True)
    
     # 构建模型
     model = build_seq2seq_model()
     model.summary()
    
     # 数据生成器
     train_gen = DataGenerator(train_data, batch_size, mode='train')
     val_gen = DataGenerator(val_data, batch_size, mode='val')
    
     # 训练配置
     callbacks = [
         AdvancedEvaluate(val_data),
         keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2, verbose=1),
         keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True)
     ]
    
     # 开始训练
     model.fit(
         train_gen,
         validation_data=val_gen,
         epochs=epochs,
         callbacks=callbacks,
         workers=4,
         use_multiprocessing=True
     )
相关推荐
AI军哥11 分钟前
MySQL8的安装方法
人工智能·mysql·yolo·机器学习·deepseek
余弦的倒数25 分钟前
知识蒸馏和迁移学习的区别
人工智能·机器学习·迁移学习
Allen Bright25 分钟前
【机器学习-线性回归-2】理解线性回归中的连续值与离散值
人工智能·机器学习·线性回归
weixin_贾33 分钟前
最新AI-Python机器学习与深度学习技术在植被参数反演中的核心技术应用
python·机器学习·植被参数·遥感反演
张槊哲42 分钟前
函数的定义与使用(python)
开发语言·python
船长@Quant1 小时前
文档构建:Sphinx全面使用指南 — 实战篇
python·markdown·sphinx·文档构建
青松@FasterAI1 小时前
【程序员 NLP 入门】词嵌入 - 上下文中的窗口大小是什么意思? (★小白必会版★)
人工智能·自然语言处理
AIGC大时代1 小时前
高效使用DeepSeek对“情境+ 对象 +问题“型课题进行开题!
数据库·人工智能·算法·aigc·智能写作·deepseek
硅谷秋水1 小时前
GAIA-2:用于自动驾驶的可控多视图生成世界模型
人工智能·机器学习·自动驾驶
多巴胺与内啡肽.1 小时前
深度学习--自然语言处理统计语言与神经语言模型
深度学习·语言模型·自然语言处理