使用 T5 模型完成新闻摘要任务

前文

内容摘要是自然语言处理(NLP)的核心问题之一,实现这一能目标必须让模型具备语言理解和内容生成两大能力。本文使用新闻数据,通过微调 T5 模型来完成提取新闻摘要这一任务。

要实现这一任务需要使用 Seq2Seq 模型,这也是为什么选中 T5 的原因。 Text-to-Text Transfer Transformer (T5) 是一种基于 Transformer 的模型,构建在编码器-解码器架构上,在众多无监督和监督任务中进行了多任务混合预训练,其中每个任务都是文本到文本的形式。T5 在摘要、翻译等领域表现很出色。

数据

本文要使用的数据集是 XSum 数据集,该数据集每个样本由 BBC 文章 document 和对应的单句摘要 summary 组成。下面展示一个样本如下:

vbnet 复制代码
{
'document': 'The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed.\nRepair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water.\nTrains on the west coast mainline face disruption due to damage at the Lamington Viaduct.\nMany businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town.\nFirst Minister Nicola Sturgeon visited the area to inspect the damage.\nThe waters breached a retaining wall, flooding many commercial properties on Victoria Street - the main shopping thoroughfare.\nJeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit.\nHowever, she said more preventative work could have been carried out to ensure the retaining wall did not fail.\n"It is difficult but I do think there is so much publicity for Dumfries and the Nith - and I totally appreciate that - but it is almost like we\'re neglected or forgotten," she said.\n"That may not be true but it is perhaps my perspective over the last few days.\n"Why were you not ready to help us a bit more when the warning and the alarm alerts had gone out?"\nMeanwhile, a flood alert remains in place across the Borders because of the constant rain.\nPeebles was badly hit by problems, sparking calls to introduce more defences in the area.\nScottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs.\nThe Labour Party\'s deputy Scottish leader Alex Rowley was in Hawick on Monday to see the situation first hand.\nHe said it was important to get the flood protection plan right but backed calls to speed up the process.\n"I was quite taken aback by the amount of damage that has been done," he said.\n"Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses."\nHe said it was important that "immediate steps" were taken to protect the areas most vulnerable and a clear timetable put in place for flood prevention plans.\nHave you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. Email us on selkirk.news@bbc.co.uk or dumfries@bbc.co.uk.', 
'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.', 
'id': '35232142'
} 

为了和我们使用的大模型 T5 配套使用,我们处理数据的 tokenizer 也是加载自 T5 ,这样处理的结果符合模型的输入要求。 其中具体的工作主要是在 document 前面加上 "summarize: ", 然后将各个 documentsummary 处理成对应的 input_ids ,token_type_ids ,attention_mask

ini 复制代码
def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=MAX_TARGET_LENGTH, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

另外为了符合训练 Seq2Seq 模型的要求,我们需要一种特殊的数据处理器,它不仅可以将输入扩充到 batch 中的最长样本的长度,还可以填充标签。我们这里选用的是 DataCollatorForSeq2Seq

ini 复制代码
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")
train_dataset = tokenized_datasets["train"].to_tf_dataset(batch_size=BATCH_SIZE, columns=["input_ids", "attention_mask", "labels"], shuffle=True, collate_fn=data_collator)
test_dataset = tokenized_datasets["test"].to_tf_dataset(batch_size=BATCH_SIZE,  columns=["input_ids", "attention_mask", "labels"], shuffle=False, collate_fn=data_collator)
generation_dataset = tokenized_datasets["test"].shuffle().select(list(range(200))).to_tf_dataset(batch_size=BATCH_SIZE,   columns=["input_ids", "attention_mask", "labels"], shuffle=False, collate_fn=data_collator)

模型

为了更好评估我们的模型效果,我们会计算真实序列和预测序列之间的 ROUGE scoreROUGE(Recall-Oriented Understudy for Gisting Evaluation) 是一种用于自动评估文本摘要质量的指标。它主要用于衡量生成的摘要与真实摘要之间的相似性,是评估文本摘要生成系统性能的常用指标之一。

ini 复制代码
def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_l(decoded_labels, decoded_predictions)
    result = {"RougeL": result["f1_score"]}
    return result

因为显存有限,所以选择了 t5-small 这个小模型,只有 6 层结构, 大约 93M 的参数,属于 t5 家族中较小的模型,但是足够胜任本次的任务。

另外为了微调模型能达到最佳的效果,我们使用全部数据来进行预训练,并进行 5 个 epoch ,batch_size 是 8 个,需要 17G 左右的显存,耗时将近 4.5 小时

arduino 复制代码
Epoch 1/5
20405/20405 [==============================] - 3240s 158ms/step - loss: 2.6656 - val_loss: 2.3561 - RougeL: 0.2318
Epoch 2/5
20405/20405 [==============================] - 3200s 157ms/step - loss: 2.5085 - val_loss: 2.2873 - RougeL: 0.2393
Epoch 3/5
20405/20405 [==============================] - 3199s 157ms/step - loss: 2.4330 - val_loss: 2.2488 - RougeL: 0.2438
Epoch 4/5
20405/20405 [==============================] - 3212s 157ms/step - loss: 2.3785 - val_loss: 2.2172 - RougeL: 0.2468
Epoch 5/5
20405/20405 [==============================] - 3239s 159ms/step - loss: 2.3335 - val_loss: 2.1938 - RougeL: 0.2493

推理

推理的过程我们不用编写复杂的推理代码,Hugging Face Transformers 库中的 pipeline 工具已经针对各种任务提供了各种流水线工具,比如这里我们使用 summarization 工具,再结合我们经过微调的大模型 t5-small ,即可对给定的输入文档生成摘要。

ini 复制代码
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, framework="tf")
for i in range(5):
    summarizer(raw_datasets["test"][i]["document"], min_length=MIN_TARGET_LENGTH, max_length=MAX_TARGET_LENGTH)
    print(raw_datasets["test"][i]["summary"])

推理结果和原摘要进行对比,可以看出:

vbnet 复制代码
[{'summary_text': 'The Australian government is trying to verify the deaths of two Lebanese fighters who left the country to fight with Islamic State militants.'}]
The Australian government says new citizenship laws will help combat terrorism but experts worry too many people will be affected by the laws, including children.
[{'summary_text': "It's been a week since the murder of Lucky the donkey, who was found dead at her home in Fermanagh, County Down."}]
On the day of her burial, County Fermanagh murder victim Connie Leonard's legacy is already being felt.
[{'summary_text': 'Insurance claims for catastrophic storms in Australia have risen by more than A$8.8m (£8.6m) in the past year, a council has said.'}]
Severe storms that hit Australia during April and May have led to more than A$1.55bn ($1.18bn; £778m) in insurance losses so far.
[{'summary_text': "It's been a long time since I was selected by BBC Sport to pick my own team of the week for BBC Sport."}]
Manchester United did Chelsea's title rivals Tottenham a favour and kept up their own pursuit of the top four with a dominant win over the Premier League leaders.
[{'summary_text': 'A Belarusian politician has been found guilty of tax evasion after he held bank accounts in Poland and Lithuania.'}]
One of Belarus' most prominent human rights activists has been sentenced to four-and-a-half years in prison for tax evasion.

小结

其实对于简单的文本摘要任务,常规的 BERT 或者 T5-small 这种传统的模型也是可以胜任的,如果没有更高精度的业务需求没有必要一味追求 LLM ,毕竟对于个人或者中小企业来说目前使用 LLM 的成本较高。

相关推荐
云卓SKYDROID3 分钟前
无人机载重模块技术要点分析
人工智能·无人机·科普·高科技·云卓科技
云卓SKYDROID5 分钟前
无人机RTK技术要点与难点分析
人工智能·无人机·科普·高科技·云卓科技
麻雀无能为力1 小时前
CAU数据挖掘 支持向量机
人工智能·支持向量机·数据挖掘·中国农业大学计算机
智能汽车人1 小时前
Robot---能打羽毛球的机器人
人工智能·机器人·强化学习
埃菲尔铁塔_CV算法1 小时前
基于 TOF 图像高频信息恢复 RGB 图像的原理、应用与实现
人工智能·深度学习·数码相机·算法·目标检测·计算机视觉
ζั͡山 ั͡有扶苏 ั͡✾1 小时前
AI辅助编程工具对比分析:Cursor、Copilot及其他主流选择
人工智能·copilot·cursor
东临碣石821 小时前
【AI论文】数学推理能否提升大型语言模型(LLM)的通用能力?——探究大型语言模型推理能力的可迁移性
人工智能·语言模型·自然语言处理
未来智慧谷2 小时前
微软医疗AI诊断系统发布 多智能体协作实现疑难病例分析
人工智能·microsoft·医疗ai
野生技术架构师2 小时前
简述MCP的原理-AI时代的USB接口
人工智能·microsoft
Allen_LVyingbo2 小时前
Python常用医疗AI库以及案例解析(2025年版、上)
开发语言·人工智能·python·学习·健康医疗