使用 T5 模型完成新闻摘要任务

前文

内容摘要是自然语言处理(NLP)的核心问题之一,实现这一能目标必须让模型具备语言理解和内容生成两大能力。本文使用新闻数据,通过微调 T5 模型来完成提取新闻摘要这一任务。

要实现这一任务需要使用 Seq2Seq 模型,这也是为什么选中 T5 的原因。 Text-to-Text Transfer Transformer (T5) 是一种基于 Transformer 的模型,构建在编码器-解码器架构上,在众多无监督和监督任务中进行了多任务混合预训练,其中每个任务都是文本到文本的形式。T5 在摘要、翻译等领域表现很出色。

数据

本文要使用的数据集是 XSum 数据集,该数据集每个样本由 BBC 文章 document 和对应的单句摘要 summary 组成。下面展示一个样本如下:

vbnet 复制代码
{
'document': 'The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed.\nRepair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water.\nTrains on the west coast mainline face disruption due to damage at the Lamington Viaduct.\nMany businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town.\nFirst Minister Nicola Sturgeon visited the area to inspect the damage.\nThe waters breached a retaining wall, flooding many commercial properties on Victoria Street - the main shopping thoroughfare.\nJeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit.\nHowever, she said more preventative work could have been carried out to ensure the retaining wall did not fail.\n"It is difficult but I do think there is so much publicity for Dumfries and the Nith - and I totally appreciate that - but it is almost like we\'re neglected or forgotten," she said.\n"That may not be true but it is perhaps my perspective over the last few days.\n"Why were you not ready to help us a bit more when the warning and the alarm alerts had gone out?"\nMeanwhile, a flood alert remains in place across the Borders because of the constant rain.\nPeebles was badly hit by problems, sparking calls to introduce more defences in the area.\nScottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs.\nThe Labour Party\'s deputy Scottish leader Alex Rowley was in Hawick on Monday to see the situation first hand.\nHe said it was important to get the flood protection plan right but backed calls to speed up the process.\n"I was quite taken aback by the amount of damage that has been done," he said.\n"Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses."\nHe said it was important that "immediate steps" were taken to protect the areas most vulnerable and a clear timetable put in place for flood prevention plans.\nHave you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. Email us on selkirk.news@bbc.co.uk or dumfries@bbc.co.uk.', 
'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.', 
'id': '35232142'
} 

为了和我们使用的大模型 T5 配套使用,我们处理数据的 tokenizer 也是加载自 T5 ,这样处理的结果符合模型的输入要求。 其中具体的工作主要是在 document 前面加上 "summarize: ", 然后将各个 documentsummary 处理成对应的 input_ids ,token_type_ids ,attention_mask

ini 复制代码
def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=MAX_INPUT_LENGTH, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=MAX_TARGET_LENGTH, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

另外为了符合训练 Seq2Seq 模型的要求,我们需要一种特殊的数据处理器,它不仅可以将输入扩充到 batch 中的最长样本的长度,还可以填充标签。我们这里选用的是 DataCollatorForSeq2Seq

ini 复制代码
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")
train_dataset = tokenized_datasets["train"].to_tf_dataset(batch_size=BATCH_SIZE, columns=["input_ids", "attention_mask", "labels"], shuffle=True, collate_fn=data_collator)
test_dataset = tokenized_datasets["test"].to_tf_dataset(batch_size=BATCH_SIZE,  columns=["input_ids", "attention_mask", "labels"], shuffle=False, collate_fn=data_collator)
generation_dataset = tokenized_datasets["test"].shuffle().select(list(range(200))).to_tf_dataset(batch_size=BATCH_SIZE,   columns=["input_ids", "attention_mask", "labels"], shuffle=False, collate_fn=data_collator)

模型

为了更好评估我们的模型效果,我们会计算真实序列和预测序列之间的 ROUGE scoreROUGE(Recall-Oriented Understudy for Gisting Evaluation) 是一种用于自动评估文本摘要质量的指标。它主要用于衡量生成的摘要与真实摘要之间的相似性,是评估文本摘要生成系统性能的常用指标之一。

ini 复制代码
def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_l(decoded_labels, decoded_predictions)
    result = {"RougeL": result["f1_score"]}
    return result

因为显存有限,所以选择了 t5-small 这个小模型,只有 6 层结构, 大约 93M 的参数,属于 t5 家族中较小的模型,但是足够胜任本次的任务。

另外为了微调模型能达到最佳的效果,我们使用全部数据来进行预训练,并进行 5 个 epoch ,batch_size 是 8 个,需要 17G 左右的显存,耗时将近 4.5 小时

arduino 复制代码
Epoch 1/5
20405/20405 [==============================] - 3240s 158ms/step - loss: 2.6656 - val_loss: 2.3561 - RougeL: 0.2318
Epoch 2/5
20405/20405 [==============================] - 3200s 157ms/step - loss: 2.5085 - val_loss: 2.2873 - RougeL: 0.2393
Epoch 3/5
20405/20405 [==============================] - 3199s 157ms/step - loss: 2.4330 - val_loss: 2.2488 - RougeL: 0.2438
Epoch 4/5
20405/20405 [==============================] - 3212s 157ms/step - loss: 2.3785 - val_loss: 2.2172 - RougeL: 0.2468
Epoch 5/5
20405/20405 [==============================] - 3239s 159ms/step - loss: 2.3335 - val_loss: 2.1938 - RougeL: 0.2493

推理

推理的过程我们不用编写复杂的推理代码,Hugging Face Transformers 库中的 pipeline 工具已经针对各种任务提供了各种流水线工具,比如这里我们使用 summarization 工具,再结合我们经过微调的大模型 t5-small ,即可对给定的输入文档生成摘要。

ini 复制代码
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, framework="tf")
for i in range(5):
    summarizer(raw_datasets["test"][i]["document"], min_length=MIN_TARGET_LENGTH, max_length=MAX_TARGET_LENGTH)
    print(raw_datasets["test"][i]["summary"])

推理结果和原摘要进行对比,可以看出:

vbnet 复制代码
[{'summary_text': 'The Australian government is trying to verify the deaths of two Lebanese fighters who left the country to fight with Islamic State militants.'}]
The Australian government says new citizenship laws will help combat terrorism but experts worry too many people will be affected by the laws, including children.
[{'summary_text': "It's been a week since the murder of Lucky the donkey, who was found dead at her home in Fermanagh, County Down."}]
On the day of her burial, County Fermanagh murder victim Connie Leonard's legacy is already being felt.
[{'summary_text': 'Insurance claims for catastrophic storms in Australia have risen by more than A$8.8m (£8.6m) in the past year, a council has said.'}]
Severe storms that hit Australia during April and May have led to more than A$1.55bn ($1.18bn; £778m) in insurance losses so far.
[{'summary_text': "It's been a long time since I was selected by BBC Sport to pick my own team of the week for BBC Sport."}]
Manchester United did Chelsea's title rivals Tottenham a favour and kept up their own pursuit of the top four with a dominant win over the Premier League leaders.
[{'summary_text': 'A Belarusian politician has been found guilty of tax evasion after he held bank accounts in Poland and Lithuania.'}]
One of Belarus' most prominent human rights activists has been sentenced to four-and-a-half years in prison for tax evasion.

小结

其实对于简单的文本摘要任务,常规的 BERT 或者 T5-small 这种传统的模型也是可以胜任的,如果没有更高精度的业务需求没有必要一味追求 LLM ,毕竟对于个人或者中小企业来说目前使用 LLM 的成本较高。

相关推荐
SpikeKing几秒前
LLM - 大模型 ScallingLaws 的指导模型设计与实验环境(PLM) 教程(4)
人工智能·llm·transformer·plm·scalinglaws
编码浪子9 分钟前
Transformer的编码机制
人工智能·深度学习·transformer
IE0623 分钟前
深度学习系列76:流式tts的一个简单实现
人工智能·深度学习
GIS数据转换器28 分钟前
城市生命线安全保障:技术应用与策略创新
大数据·人工智能·安全·3d·智慧城市
一水鉴天2 小时前
为AI聊天工具添加一个知识系统 之65 详细设计 之6 变形机器人及伺服跟随
人工智能
m0_743106465 小时前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_743106465 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
井底哇哇8 小时前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证8 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩8 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer