Using a T5 Model for News Summarization

Background

Text summarization is one of the core problems in natural language processing (NLP): to solve it, a model needs both language understanding and content generation capabilities. This article uses news data to fine-tune a T5 model for the task of generating news summaries.

The task calls for a Seq2Seq model, which is why T5 was chosen. The Text-to-Text Transfer Transformer (T5) is a Transformer-based model built on an encoder-decoder architecture and pre-trained on a multi-task mixture of unsupervised and supervised tasks, each of which is cast in a text-to-text format. T5 performs very well on summarization, translation, and similar tasks.
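Before fine-tuning, it helps to see this text-to-text interface in action. Below is a minimal sketch, assuming the transformers library with its TensorFlow classes; the truncated input text is purely illustrative:

# A minimal sketch of T5's text-to-text interface: every task, including
# summarization, is phrased as "input text -> output text".
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Summarization is selected by prepending the task prefix "summarize: ".
text = "summarize: The full cost of damage in Newton Stewart is still being assessed..."
inputs = tokenizer(text, return_tensors="tf", max_length=512, truncation=True)
output_ids = model.generate(inputs["input_ids"], max_length=40)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))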

Data

The dataset used here is XSum, in which each sample consists of a BBC article (document) and a corresponding one-sentence summary (summary). A sample looks like this:

{
'document': 'The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed.\nRepair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water.\nTrains on the west coast mainline face disruption due to damage at the Lamington Viaduct.\nMany businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town.\nFirst Minister Nicola Sturgeon visited the area to inspect the damage.\nThe waters breached a retaining wall, flooding many commercial properties on Victoria Street - the main shopping thoroughfare.\nJeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit.\nHowever, she said more preventative work could have been carried out to ensure the retaining wall did not fail.\n"It is difficult but I do think there is so much publicity for Dumfries and the Nith - and I totally appreciate that - but it is almost like we\'re neglected or forgotten," she said.\n"That may not be true but it is perhaps my perspective over the last few days.\n"Why were you not ready to help us a bit more when the warning and the alarm alerts had gone out?"\nMeanwhile, a flood alert remains in place across the Borders because of the constant rain.\nPeebles was badly hit by problems, sparking calls to introduce more defences in the area.\nScottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs.\nThe Labour Party\'s deputy Scottish leader Alex Rowley was in Hawick on Monday to see the situation first hand.\nHe said it was important to get the flood protection plan right but backed calls to speed up the process.\n"I was quite taken aback by the amount of damage that has been done," he said.\n"Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses."\nHe said it was important that "immediate steps" were taken to protect the areas most vulnerable and a clear timetable put in place for flood prevention plans.\nHave you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. Email us on selkirk.news@bbc.co.uk or dumfries@bbc.co.uk.', 
'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.', 
'id': '35232142'
} 
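For reference, XSum can be loaded with the datasets library; a minimal sketch (the name raw_datasets matches its use in the inference section below):

# Load the XSum dataset; it comes with "train", "validation" and "test" splits.
from datasets import load_dataset

raw_datasets = load_dataset("xsum")
print(raw_datasets["train"][0])  # prints a sample like the one shown above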

To stay consistent with the T5 model we use, the tokenizer that processes the data is also loaded from T5, so that its output matches the model's expected input. The main work is to prepend "summarize: " to each document, and then convert each document and summary into the corresponding input_ids, attention_mask, and labels.

MAX_INPUT_LENGTH = 1024   # max input tokens per document (illustrative value)
MAX_TARGET_LENGTH = 128   # max tokens in a summary (illustrative value)
prefix = "summarize: "    # T5 task prefix for summarization

def preprocess_function(examples):
    # Prepend the task prefix and tokenize the documents.
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
    # Tokenize the reference summaries as labels.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=MAX_TARGET_LENGTH, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
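The preprocessing function is then mapped over every split; a brief sketch (the name tokenized_datasets matches its use below):

# Tokenize all splits in batched mode for speed.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)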

In addition, training a Seq2Seq model requires a special data collator that not only pads the inputs to the length of the longest sample in the batch, but also pads the labels. Here we use DataCollatorForSeq2Seq.

from transformers import DataCollatorForSeq2Seq

BATCH_SIZE = 8  # batch size used for training (see the Model section below)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")
train_dataset = tokenized_datasets["train"].to_tf_dataset(batch_size=BATCH_SIZE, columns=["input_ids", "attention_mask", "labels"], shuffle=True, collate_fn=data_collator)
test_dataset = tokenized_datasets["test"].to_tf_dataset(batch_size=BATCH_SIZE, columns=["input_ids", "attention_mask", "labels"], shuffle=False, collate_fn=data_collator)
generation_dataset = tokenized_datasets["test"].shuffle().select(list(range(200))).to_tf_dataset(batch_size=BATCH_SIZE, columns=["input_ids", "attention_mask", "labels"], shuffle=False, collate_fn=data_collator)
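Note that generation_dataset is deliberately a small 200-sample subset of the test split: generating full summaries for ROUGE evaluation is much slower than computing the loss, so scoring generation on a small subset after each epoch keeps training time manageable.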

Model

To better evaluate the model, we compute the ROUGE score between the reference and predicted sequences. ROUGE (Recall-Oriented Understudy for Gisting Evaluation) is a metric for automatically evaluating summary quality: it measures the similarity between a generated summary and the reference summary and is one of the most common metrics for summarization systems.

def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # The collator marks padded label positions with -100; restore them to the
    # pad token id so that the labels can be decoded.
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_l(decoded_labels, decoded_predictions)
    # Report only the ROUGE-L F1 score.
    result = {"RougeL": result["f1_score"]}
    return result
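The rouge_l used above is not defined in the original snippet. A plausible definition, assuming the keras_nlp library (its RougeL metric returns a dict containing the "f1_score" key used above):

# Assumed definition of rouge_l: the ROUGE-L metric from keras_nlp, which
# returns a dict with "precision", "recall" and "f1_score" entries.
import keras_nlp

rouge_l = keras_nlp.metrics.RougeL()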

Because GPU memory is limited, we pick t5-small, which has only 6 encoder layers and 6 decoder layers and roughly 60M parameters. It is one of the smaller members of the T5 family, but it is more than adequate for this task.

To get the best results from fine-tuning, we train on the full dataset for 5 epochs with a batch size of 8, which requires about 17 GB of GPU memory and takes nearly 4.5 hours.
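The training code itself is not shown in the original; the following is a sketch of what it could look like, assuming the transformers Keras integration (the Adam optimizer, the 2e-5 learning rate, and the use of KerasMetricCallback are assumptions):

import tensorflow as tf
from transformers.keras_callbacks import KerasMetricCallback

# Transformers TF models compute their own loss when labels are provided,
# so only an optimizer is passed to compile(). The learning rate is assumed.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5))

# Compute ROUGE-L on the 200-sample generation_dataset after each epoch.
metric_callback = KerasMetricCallback(
    metric_fn=metric_fn, eval_dataset=generation_dataset, predict_with_generate=True
)
model.fit(train_dataset, validation_data=test_dataset, epochs=5, callbacks=[metric_callback])

Training produces the following log: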

Epoch 1/5
20405/20405 [==============================] - 3240s 158ms/step - loss: 2.6656 - val_loss: 2.3561 - RougeL: 0.2318
Epoch 2/5
20405/20405 [==============================] - 3200s 157ms/step - loss: 2.5085 - val_loss: 2.2873 - RougeL: 0.2393
Epoch 3/5
20405/20405 [==============================] - 3199s 157ms/step - loss: 2.4330 - val_loss: 2.2488 - RougeL: 0.2438
Epoch 4/5
20405/20405 [==============================] - 3212s 157ms/step - loss: 2.3785 - val_loss: 2.2172 - RougeL: 0.2468
Epoch 5/5
20405/20405 [==============================] - 3239s 159ms/step - loss: 2.3335 - val_loss: 2.1938 - RougeL: 0.2493

Inference

For inference we do not need to write complex generation code: the pipeline utility in the Hugging Face Transformers library provides ready-made pipelines for many tasks. Here we use the summarization pipeline together with our fine-tuned t5-small model to generate a summary for each input document.

from transformers import pipeline

MIN_TARGET_LENGTH = 5  # minimum summary length (illustrative value)
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, framework="tf")
for i in range(5):
    # Print the generated summary, then the reference summary for comparison.
    print(summarizer(raw_datasets["test"][i]["document"], min_length=MIN_TARGET_LENGTH, max_length=MAX_TARGET_LENGTH))
    print(raw_datasets["test"][i]["summary"])

Comparing the inference results with the reference summaries (generated summary first, reference second):

[{'summary_text': 'The Australian government is trying to verify the deaths of two Lebanese fighters who left the country to fight with Islamic State militants.'}]
The Australian government says new citizenship laws will help combat terrorism but experts worry too many people will be affected by the laws, including children.
[{'summary_text': "It's been a week since the murder of Lucky the donkey, who was found dead at her home in Fermanagh, County Down."}]
On the day of her burial, County Fermanagh murder victim Connie Leonard's legacy is already being felt.
[{'summary_text': 'Insurance claims for catastrophic storms in Australia have risen by more than A$8.8m (£8.6m) in the past year, a council has said.'}]
Severe storms that hit Australia during April and May have led to more than A$1.55bn ($1.18bn; £778m) in insurance losses so far.
[{'summary_text': "It's been a long time since I was selected by BBC Sport to pick my own team of the week for BBC Sport."}]
Manchester United did Chelsea's title rivals Tottenham a favour and kept up their own pursuit of the top four with a dominant win over the Premier League leaders.
[{'summary_text': 'A Belarusian politician has been found guilty of tax evasion after he held bank accounts in Poland and Lithuania.'}]
One of Belarus' most prominent human rights activists has been sentenced to four-and-a-half years in prison for tax evasion.

Conclusion

For simple text-summarization tasks, conventional models such as BERT or t5-small are perfectly capable. Unless the business genuinely requires higher accuracy, there is no need to chase LLMs: for individuals and small or medium-sized companies, the cost of using an LLM is still high.
