Hugging Face微调语言模型:问答任务

本教程展示如何微调训练一个支持问答任务的模型。微调后的模型仍然是通过提取上下文的子串来回答问题的,而不是生成新的文本。

一、数据集下载

数据使用斯坦福问答数据集(SQuAD)

1、SQuAD 数据集

斯坦福问答数据集(Stanford Question Answering Dataset, SQuAD)是自然语言处理领域最具影响力的机器阅读理解数据集之一,被誉为"机器阅读理解界的ImageNet"。它由斯坦福大学研究人员于2016年发布,旨在推动问答系统从简单的关键词匹配向深层次文本理解演进。

  • SQuAD 1.1 ‌(2016年)
    包含536篇维基百科文章、约10万个问题-答案对,所有问题均可在上下文中找到明确答案。
  • SQuAD 2.0 ‌(2018年升级版)
    SQuAD 1.1基础上新增超过5万个"无法回答的问题"(unanswerable questions),这些问题是人为构造、看似合理但上下文中无对应答案的问题。这一设计显著提升了任务难度,迫使模型学会判断问题是否可答,更贴近真实应用场景。
python 复制代码
from datasets import load_dataset
squad_v2 = False #炼丹机器性能好,选择squad_v2数据集(True), 性能一般选择squad数据集(False)
datasets = load_dataset("squad_v2" if squad_v2 else "squad") 

打印数据集信息:

python 复制代码
print(datasets)
print(datasets['train'][10])

输出:

python 复制代码
DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})
python 复制代码
{'id': '5733bed24776f41900661188', 
 'title': 'University_of_Notre_Dame', 'context': 'The university is the major seat of the Congregation of Holy Cross (albeit not its official headquarters, which are in Rome). Its main seminary, Moreau Seminary, is located on the campus across St. Joseph lake from the Main Building. Old College, the oldest building on campus and located near the shore of St. Mary lake, houses undergraduate seminarians. Retired priests and brothers reside in Fatima House (a former retreat center), Holy Cross House, as well as Columba Hall near the Grotto. The university through the Moreau Seminary has ties to theologian Frederick Buechner. While not Catholic, Buechner has praised writers from Notre Dame and Moreau Seminary created a Buechner Prize for Preaching.', 
 'question': 'Where is the headquarters of the Congregation of the Holy Cross?',
 'answers': {'text': ['Rome'], 'answer_start': [119]}}
 

可以看到答案是text文本和在文本中的起始位置(这里是第119个字符)表示的。使用HTML格式输出数据集示例(jupyter用户可展示):

python 复制代码
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

# 展示    
show_random_elements(datasets["train"], num_examples = 1)

输出:

| | id | title | context | question | answers |

0 572673df5951b619008f730b Professional_wrestling Though they have not had the level of exposure as other wrestlers, bears have long been a part of professional wrestling. Usually declawed and muzzled, they often wrestled shoot matches against audience members, offered a cash reward if they could pin the bear. They also wrestled professionals in worked, often battle royal or handicap, matches (usually booked so the bear won). Though they have wrestled around the world and continue to do so, wrestling bears enjoyed their greatest popularity in the Southern United States, during the 1960s and 1970s. The practice of bear wrestling has met strong opposition from animal rights activists in recent decades, contributing to its lack of mainstream acceptance. As of 2006, it is banned in 20 U.S. states. Perhaps the most famous wrestling bears are Ginger, Victor, Hercules and Terrible Ted. How many states have banned bear wrestling as of 2006? {'text': ['20'], 'answer_start': [739]}

二、预处理数据

加载distilbert-base-uncased问答模型,,并使用Tokenizers进行分词处理:

python 复制代码
import transformers

model_checkpoint = "distilbert-base-uncased"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint) # "distilbert-base-uncased"

# Tokenizers 使用的是 FastTokenizer(Rust 实现,速度和功能性上有一定优势)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

使用tokenizer对文本进行编码,可以对多个句子进行编码(一个用于上下文,一个用于答案):

python 复制代码
tokenizer("What is your name?", "My name is Sylvain.")

输出:

复制代码
{'input_ids': [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 25353, 22144, 2378, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

在其他任务中,当文档的长度超过模型最大句子长度时,通常会截断它们。但在问答任务中,直接删除上下文,可能会导致答案丢失。

为了解决这个问题,允许数据集中的一个(长)句子拆分成多个句子,每个特征的长度都小于模型的最大长度。

python 复制代码
# The maximum length of a feature (question and context)
max_length = 384 
# The authorized overlap between two part of the context when splitting it is needed.
doc_stride = 128 # 截断时每次移动多少个 token,保留更多上下文信息
  • 处理超出最大长度的文本数据示例:

从训练集中找出一个超过最大长度(384)的文本:

python 复制代码
for i, example in enumerate(datasets["train"]):
    if len(tokenizer(example["question"], example["context"])["input_ids"]) > 384:
        # 挑选出来超过384(最大长度)的数据样例
        example = datasets["train"][i]  
        break
print(len(tokenizer(example["question"], example["context"])["input_ids"]))

问题和上下文长度超过384:

复制代码
396

截断的策略:

  • 直接截断超出部分: truncation=only_second
  • 仅截断上下文(context),保留问题(question): return_overflowing_tokens=True
python 复制代码
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=doc_stride # 128
)
  • example["question"]example["context"]:这是两个输入文本。通常在问答任务中,question 是用户提出的问题,而 context 是模型需要从中查找答案的上下文文本。这两个文本会被拼接在一起进行编码。
  • max_length=max_length:该参数用于指定最大序列长度。如果输入文本经过分词后超过这个长度,将会根据 truncation 参数的设置进行截断。
  • truncation="only_second":此参数控制如何截断文本。当输入包含两个文本(如问题和上下文)时,truncation="only_second" 表示只对第二个文本(即 context)进行截断,保留第一个文本(即 question)的完整内容。
  • return_overflowing_tokens=True:当输入文本被截断后,该参数决定是否返回被截断的部分。这在处理长文本时非常有用,因为它允许你获取所有可能的分块,以便后续处理。
  • stride=doc_stride:该参数用于设置滑动窗口的步长,即在截断时每次移动多少个 token。它与 return_overflowing_tokens=True 配合使用,可以实现对长文本的分块处理,从而保留更多上下文信息。

使用此策略截断后,Tokenizer 将返回多个 input_ids 列表。

python 复制代码
[len(x) for x in tokenized_example["input_ids"]]

输出:

复制代码
[384, 157]

解码两个特征,可以看到上下文重叠的部分:

python 复制代码
for x in tokenized_example["input_ids"]:
    print(tokenizer.decode(x))

输出:

复制代码
[CLS] how many wins does the notre dame men's basketball team have? [SEP] the men's basketball team has over 1, 600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 ncaa tournaments. former player austin carr holds the record for most points scored in a single game of the tournament with 61. although the team has never won the ncaa tournament, they were named by the helms athletic foundation as national champions twice. the team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending ucla's record 88 - game winning streak in 1974. the team has beaten an additional eight number - one teams, and those nine wins rank second, to ucla's 10, all - time in wins against the top team. the team plays in newly renovated purcell pavilion ( within the edmund p. joyce center ), which reopened for the beginning of the 2009 -- 2010 season. the team is coached by mike brey, who, as of the 2014 -- 15 season, his fifteenth at notre dame, has achieved a 332 - 165 record. in 2009 they were invited to the nit, where they advanced to the semifinals but were beaten by penn state who went on and beat baylor in the championship. the 2010 -- 11 team concluded its regular season ranked number seven in the country, with a record of 25 -- 5, brey's fifth straight 20 - win season, and a second - place finish in the big east. during the 2014 - 15 season, the team went 32 - 6 and won the acc conference tournament, later advancing to the elite 8, where the fighting irish lost on a missed buzzer - beater against then undefeated kentucky. led by nba draft picks jerian grant and pat connaughton, the fighting irish beat the eventual national champion duke blue devils twice during the season. the 32 wins were [SEP]
[CLS] how many wins does the notre dame men's basketball team have? [SEP] championship. the 2010 -- 11 team concluded its regular season ranked number seven in the country, with a record of 25 -- 5, brey's fifth straight 20 - win season, and a second - place finish in the big east. during the 2014 - 15 season, the team went 32 - 6 and won the acc conference tournament, later advancing to the elite 8, where the fighting irish lost on a missed buzzer - beater against then undefeated kentucky. led by nba draft picks jerian grant and pat connaughton, the fighting irish beat the eventual national champion duke blue devils twice during the season. the 32 wins were the most by the fighting irish team since 1908 - 09. [SEP]
  • 使用 offsets_mapping 获取原始的 input_ids

设置 return_offsets_mapping=True,将使得截断分割生成的多个 input_ids 列表中的 token,通过映射保留原始文本的 input_ids。

如下所示:第一个标记([CLS])的起始和结束字符都是(0, 0),因为它不对应问题/答案的任何部分,然后第二个标记与问题(question)的字符0到3相同.

python 复制代码
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride
)
print(tokenized_example["offset_mapping"][0][:100])

分词与截断 :将example["question"](问题)和example["context"](上下文)合并分词。通过truncation="only_second"指定仅截断过长的上下文(第二个参数),而问题部分完整保留。

处理长文本 :当上下文超过max_length时:

  • return_overflowing_tokens=True生成多个片段(避免信息丢失)
  • stride=doc_stride控制片段间的重叠步长(如步长=128,则相邻片段重叠128个token),确保截断边界的信息连续性。

偏移量映射return_offsets_mapping=True返回每个token在‌原始文本 ‌中的字符级位置(起始索引,结束索引)。例如:
[(0,2), (3,7), ...] 表示第一个token占据原始文本0-2字符位置。

输出解析tokenized_example["offset_mapping"][0][:100] 打印‌第一个片段‌中前100个token的原始文本位置映射,用于后续答案定位(如匹配答案在原文中的精确位置)。

输出:

复制代码
[(0, 0), (0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33), (34, 37), (37, 38), (38, 39), (40, 50), (51, 55), (56, 60), (60, 61), (0, 0), (0, 3), (4, 7), (7, 8), (8, 9), (10, 20), (21, 25), (26, 29), (30, 34), (35, 36), (36, 37), (37, 40), (41, 45), (45, 46), (47, 50), (51, 53), (54, 58), (59, 61), (62, 69), (70, 73), (74, 78), (79, 86), (87, 91), (92, 96), (96, 97), (98, 101), (102, 106), (107, 115), (116, 118), (119, 121), (122, 126), (127, 138), (138, 139), (140, 146), (147, 153), (154, 160), (161, 165), (166, 171), (172, 175), (176, 182), (183, 186), (187, 191), (192, 198), (199, 205), (206, 208), (209, 210), (211, 217), (218, 222), (223, 225), (226, 229), (230, 240), (241, 245), (246, 248), (248, 249), (250, 258), (259, 262), (263, 267), (268, 271), (272, 277), (278, 281), (282, 285), (286, 290), (291, 301), (301, 302), (303, 307), (308, 312), (313, 318), (319, 321), (322, 325), (326, 330), (330, 331), (332, 340), (341, 351), (352, 354), (355, 363), (364, 373), (374, 379), (379, 380), (381, 384), (385, 389), (390, 393), (394, 406), (407, 408), (409, 415), (416, 418)]

因此,通过映射关系可以找到答案在上下文中的起始和结束位置。

我们只需区分偏移的哪些部分对应于问题,哪些部分对应于上下文。

python 复制代码
first_token_id = tokenized_example["input_ids"][0][1]
offsets = tokenized_example["offset_mapping"][0][1]
print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example["question"][offsets[0]:offsets[1]])

输出:

复制代码
how How

借助tokenized_examplesequence_ids方法,可以方便的区分token的来源编号:

  • 对于特殊标记:返回None,
  • 对于正文Token:返回句子编号(从0开始编号)。

综上,可以很方便的在一个输入特征中找到答案的起始和结束 Token。

python 复制代码
sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)

输出:

复制代码
[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]

在自然语言处理任务中,将文本中的答案 span 映射到 tokenized 输入的起始和结束位置。它通过比较字符级别的起始和结束位置与tokenized 文本的 offset 映射,来确定答案在 token 序列中的位置。

python 复制代码
answers = example["answers"]
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])

# 当前span在文本中的起始标记索引。
token_start_index = 0
while sequence_ids[token_start_index] != 1:
    token_start_index += 1

# 当前span在文本中的结束标记索引。
token_end_index = len(tokenized_example["input_ids"][0]) - 1
while sequence_ids[token_end_index] != 1:
    token_end_index -= 1

# 检测答案是否超出span范围(如果超出范围,该特征将以CLS标记索引标记)。
offsets = tokenized_example["offset_mapping"][0]
if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
    # 将token_start_index和token_end_index移动到答案的两端。
    # 注意:如果答案是最后一个单词,我们可以移到最后一个标记之后(边界情况)。
    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
        token_start_index += 1
    start_position = token_start_index - 1
    while offsets[token_end_index][1] >= end_char:
        token_end_index -= 1
    end_position = token_end_index + 1
    print(start_position, end_position)
else:
    print("答案不在此特征中。")

输出:

复制代码
23 26

打印检查是否准确找到了起始位置:

python 复制代码
# 通过查找 offset mapping 位置,解码 context 中的答案 
print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
# 直接打印 数据集中的标准答案(answer["text"])
print(answers["text"][0])

输出:

python 复制代码
over 1, 600
over 1,600
  • 关于填充的策略

    • 对于没有超过最大长度的文本,填充补齐长度。

    • 对于需要左侧填充的模型,交换 question 和 context 顺序

python 复制代码
pad_on_right = tokenizer.padding_side == "right"

将所有内容整合到一个函数中,并将其应用到训练集。针对不可回答的情况(上下文过长,答案在另一个特征中),为开始和结束位置都设置了cls索引。如果allow_impossible_answers标志为False,可以简单地从训练集中丢弃这些示例。

python 复制代码
def prepare_train_features(examples):
    # 一些问题的左侧可能有很多空白字符,这对我们没有用,而且会导致上下文的截断失败
    # (标记化的问题将占用大量空间)。因此,我们删除左侧的空白字符。
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # 使用截断和填充对我们的示例进行标记化,但保留溢出部分,使用步幅(stride)。
    # 当上下文很长时,这会导致一个示例可能提供多个特征,其中每个特征的上下文都与前一个特征的上下文有一些重叠。
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # 由于一个示例可能给我们提供多个特征(如果它具有很长的上下文),我们需要一个从特征到其对应示例的映射。这个键就提供了这个映射关系。
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # 偏移映射将为我们提供从令牌到原始上下文中的字符位置的映射。这将帮助我们计算开始位置和结束位置。
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # 让我们为这些示例进行标记!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # 我们将使用 CLS 特殊 token 的索引来标记不可能的答案。
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # 获取与该示例对应的序列(以了解上下文和问题是什么)。
        sequence_ids = tokenized_examples.sequence_ids(i)

        # 一个示例可以提供多个跨度,这是包含此文本跨度的示例的索引。
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # 如果没有给出答案,则将cls_index设置为答案。
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # 答案在文本中的开始和结束字符索引。
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # 当前跨度在文本中的开始令牌索引。
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # 当前跨度在文本中的结束令牌索引。
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # 检测答案是否超出跨度(在这种情况下,该特征的标签将使用CLS索引)。
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # 否则,将token_start_index和token_end_index移到答案的两端。
                # 注意:如果答案是最后一个单词(边缘情况),我们可以在最后一个偏移之后继续。
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

使用 datasets.map 方法将 prepare_train_features 应用于所有训练、验证和测试数据:

  • batched: 批量处理数据
  • remove_columns: 因为预处理更改了样本的数量(一条示例被拆分成了多条示例),所以在应用它时需要删除旧列
  • load_from_cache_file:是否使用datasets库的自动缓存

datasets 库针对大规模数据,实现了高效缓存机制,能够自动检测传递给 map 的函数是否已更改(因此需要不使用缓存数据)。如果在调用 map 时设置 load_from_cache_file=False,可以强制重新应用预处理。

python 复制代码
tokenized_datasets = datasets.map(prepare_train_features,
                                  batched=True,
                                  remove_columns=datasets["train"].column_names)

三、微调模型

使用 AutoModelForQuestionAnswering问答模型。

python 复制代码
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

警告信息:正在丢弃一些权重(vocab_transformvocab_layer_norm 层),并随机初始化其他一些权重(pre_classifierclassifier 层)。在微调模型情况下是正常的,因为正在删除用于预训练模型的掩码语言建模任务的头部,并用一个新的头部替换它,对于这个新头部,没有预训练的权重,所以库会警告在用它进行推理之前应该对这个模型进行微调。

python 复制代码
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

1、训练超参数(TrainingArguments)

python 复制代码
batch_size = 32 #根据个人炼丹设备进行调整
model_dir = f"models/{model_checkpoint}-finetuned-squad"

args = TrainingArguments(
    output_dir=model_dir,
    eval_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

2、Data Collator(数据整理器)

数据整理器将训练数据整理为批次数据,用于模型训练时的批次处理。

python 复制代码
from transformers import default_data_collator

data_collator = default_data_collator

3、实例化训练器(Trainer)

为了减少训练时间(需要大量算力支持),训练模型过程中不计算模型评估指标。训练完成后,再单独进行模型评估。

python 复制代码
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)
trainer.train()

输出:

python 复制代码
{'train_runtime': 41287.5044, 'train_samples_per_second': 6.432, 'train_steps_per_second': 0.402, 'train_loss': 1.086345330872689, 'epoch': 3.0}

4、训练完成后,第一时间保存模型权重文件

python 复制代码
model_to_save = trainer.save_model(model_dir)

四、模型评估

1、查看模型输出

模型输出需要将模型的预测映射回上下文的部分。模型直接输出的是预测答案的起始位置结束位置loss

查看模型输出:

python 复制代码
import torch

for batch in trainer.get_eval_dataloader():
    break
batch = {k: v.to(trainer.args.device) for k, v in batch.items()}
with torch.no_grad():
    output = trainer.model(**batch)
output.keys()

输出:

python 复制代码
odict_keys(['loss', 'start_logits', 'end_logits'])

查看start_logitsend_logits属性:

python 复制代码
output.start_logits.shape, output.end_logits.shape

输出:

复制代码
torch.Size([16, 384]) torch.Size([16, 384])

模型输出的 start_logitsend_logits 的维度 torch.Size([64, 384]) 具体含义如下:

  • 第一个维度 64 ‌:表示批次大小(batch size),即一次输入处理的样本数量。在这里,每次处理 64 个问题-上下文对。
  • 第二个维度 384 ‌:表示序列长度(sequence length),即模型处理的最大 token 数量。对于 BERT 等模型,通常最大序列长度为 512,但实际对模型进行了截断,最大长度为 384。

具体来说,每个样本的输出维度解释如下:

  1. start_logits‌:每个位置作为答案起始位置的得分,维度为 [64, 384],表示每个样本中每个 token 作为答案开始位置的可能性分数。
  2. end_logits‌:每个位置作为答案结束位置的得分,维度为 [64, 384],表示每个样本中每个 token 作为答案结束位置的可能性分数。

查看start_logitsend_logits维度:

python 复制代码
output.start_logits.argmax(dim=-1), output.end_logits.argmax(dim=-1)

输出:

python 复制代码
tensor([89, 89, 86, 86, 92, 87, 86, 85, 87, 84, 87, 84, 88, 88, 95, 89],
       device='cuda:0') 
tensor([98, 98, 95, 95, 17, 12, 11, 10, 12,  9, 12, 93, 97, 97, 20, 14],
       device='cuda:0')

2、将模型的概率输出转换为文本答案

直接使用 start_logits的最大概率位置作为起始位置,end_logits的最大概率位置作为终止位置来形成答案是‌不推荐‌的做法。

这种方法存在几个关键问题:

(1)缺乏连贯性 ‌:单独的最大值可能对应不连续的 token 位置,导致答案片段不完整或无意义。

(2)忽略位置关系 ‌:startend位置之间需要满足逻辑关系,即起始位置必须小于等于结束位置。

(3)最优解不一定是单点 ‌:在问答任务中,最优的答案应该是 start_logit + end_logit 组合得分最大的连续子序列,而不是两个独立的最大值。

正确的做法是计算所有可能的起始-结束位置对的得分总和,选择得分最高的连续片段作为最终答案。具体来说,应该计算所有满足 start ≤ end 的位置对 (start, end) 的得分:start_logit[start] + end_logit[end],然后选择得分最高的那个片段。

这种方法能够确保答案的连贯性和语义完整性,是抽取式问答任务的标准做法。

为了对答案进行分类,对前n_best_size = 20个分值高的start_logitend_logit进行排列组合。

python 复制代码
import numpy as np

n_best_size = 20 # 选取前20个高分位置进行排列组合

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()

# 获取最佳的起始和结束位置的索引:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()

valid_answers = []

# 遍历起始位置和结束位置的索引组合
for start_index in start_indexes:
    for end_index in end_indexes:
        if start_index <= end_index:  # 需要进一步测试以检查答案是否在上下文中
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": ""  # 我们需要找到一种方法来获取与上下文中答案对应的原始子字符串
                }
            )

根据它们的得分对valid_answers进行排序,并仅保留最佳答案。唯一剩下的问题是如何检查给定的跨度是否在上下文中(而不是问题中),以及如何获取其中的文本。为此,需要向验证特征添加两个内容:

  • 生成该特征的示例的ID(因为每个示例可以生成多个特征,如前所示)
  • 偏移映射,它将为我们提供从标记索引到上下文中字符位置的映射
python 复制代码
def prepare_validation_features(examples):
    # 一些问题的左侧有很多空白,这些空白并不有用且会导致上下文截断失败(分词后的问题会占用很多空间)。
    # 因此我们移除这些左侧空白
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # 使用截断和可能的填充对我们的示例进行分词,但使用步长保留溢出的令牌。这导致一个长上下文的示例可能产生
    # 几个特征,每个特征的上下文都会稍微与前一个特征的上下文重叠。
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # 由于一个示例在上下文很长时可能会产生几个特征,我们需要一个从特征映射到其对应示例的映射。这个键就是为了这个目的。
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # 我们保留产生这个特征的示例ID,并且会存储偏移映射。
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # 获取与该示例对应的序列(以了解哪些是上下文,哪些是问题)。
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # 一个示例可以产生几个文本段,这里是包含该文本段的示例的索引。
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # 将不属于上下文的偏移映射设置为None,以便容易确定一个令牌位置是否属于上下文。
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

prepare_validation_features应用到整个验证集:

python 复制代码
validation_features = datasets["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=datasets["validation"].column_names
)

预测:

python 复制代码
raw_predictions = trainer.predict(validation_features)

Trainer会隐藏模型不使用的列(在这里是example_idoffset_mapping,需要它们进行后处理),所以需要将它们重新设置回来:

python 复制代码
validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))

由于在偏移映射中,当它对应于问题的一部分时,将其设置为None,因此可以轻松检查答案是否完全在上下文中。还可以从考虑中排除非常长的答案(可以调整的超参数)。

展开说下具体实现:

  • 首先从模型输出中获取起始和结束的逻辑值(logits),这些值表明答案在文本中可能开始和结束的位置。
  • 然后,它使用偏移映射(offset_mapping)来找到这些逻辑值在原始文本中的具体位置。
  • 接下来,代码遍历可能的开始和结束索引组合,排除那些不在上下文范围内或长度不合适的答案。
  • 对于有效的答案,它计算出一个分数(基于开始和结束逻辑值的和),并将答案及其分数存储起来。
  • 最后,它根据分数对答案进行排序,并返回得分最高的几个答案。
python 复制代码
max_answer_length = 30

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
offset_mapping = validation_features[0]["offset_mapping"]

# 第一个特征来自第一个示例。对于更一般的情况,我们需要将example_id匹配到一个示例索引
context = datasets["validation"][0]["context"]

# 收集最佳开始/结束逻辑的索引:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        # 不考虑超出范围的答案,原因是索引超出范围或对应于输入ID的部分不在上下文中。
        if (
            start_index >= len(offset_mapping)
            or end_index >= len(offset_mapping)
            or offset_mapping[start_index] is None
            or offset_mapping[end_index] is None
        ):
            continue
        # 不考虑长度小于0或大于max_answer_length的答案。
        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
            continue
        if start_index <= end_index: # 我们需要细化这个测试,以检查答案是否在上下文中
            start_char = offset_mapping[start_index][0]
            end_char = offset_mapping[end_index][1]
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": context[start_char: end_char]
                }
            )

valid_answers = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[:n_best_size]
valid_answers

输出:

复制代码
[{'score': 15.986347, 'text': 'Denver Broncos'},
 {'score': 14.585561,
  'text': 'Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 13.152991, 'text': 'Carolina Panthers'},
 {'score': 12.38233, 'text': 'Broncos'},
 {'score': 10.981544,
  'text': 'Broncos defeated the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 10.852013,
  'text': 'American Football Conference (AFC) champion Denver Broncos'},
 {'score': 10.635618,
  'text': 'The American Football Conference (AFC) champion Denver Broncos'},
 {'score': 10.283654, 'text': 'Denver'},
 {'score': 9.451225,
  'text': 'American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 9.234833,
  'text': 'The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 8.7582445,
  'text': 'Denver Broncos defeated the National Football Conference'},
 {'score': 8.187819,
  'text': 'Denver Broncos defeated the National Football Conference (NFC) champion Carolina'},
 {'score': 8.134832, 'text': 'Panthers'},
 {'score': 8.092252,
  'text': 'Denver Broncos defeated the National Football Conference (NFC)'},
 {'score': 7.7162285,
  'text': 'the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 7.595868,
  'text': 'Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24--10'},
 {'score': 7.382572,
  'text': 'National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 7.320059,
  'text': 'Denver Broncos defeated the National Football Conference (NFC'},
 {'score': 6.755249, 'text': 'Carolina'},
 {'score': 6.728976, 'text': 'champion Denver Broncos'}]

打印比较模型输出和标准答案(Ground-truth)是否一致:

python 复制代码
datasets["validation"][0]["answers"]

输出:

复制代码
{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 'answer_start': [177, 177, 177]}

**模型最高概率的输出与标准答案一致。**对于其他特征,需要建立一个示例与其对应特征的映射关系。由于一个示例可以生成多个特征,需要将由给定示例生成的所有特征中的所有答案汇集在一起,然后选择最佳答案。下面的代码构建了一个示例索引到其对应特征索引的映射关系:

python 复制代码
import collections

examples = datasets["validation"]
features = validation_features

example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
    features_per_example[example_id_to_index[feature["example_id"]]].append(i)

squad_v2 = True时,有一定概率出现不可能的答案(impossible answer)。上面的代码仅保留在上下文中的答案,还需要获取不可能答案的分数(其起始和结束索引对应于CLS标记的索引)。

当一个示例生成多个特征时,必须在所有特征中的不可能答案都预测出现不可能答案时(因为一个特征可能之所以能够预测出不可能答案,是因为答案不在它可以访问的上下文部分),这就是为什么一个示例中不可能答案的分数是该示例生成的每个特征中的不可能答案的分数的最小值。

python 复制代码
from tqdm.auto import tqdm

def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    all_start_logits, all_end_logits = raw_predictions
    # 构建一个从示例到其对应特征的映射。
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # 我们需要填充的字典。
    predictions = collections.OrderedDict()

    # 日志记录。
    print(f"正在后处理 {len(examples)} 个示例的预测,这些预测分散在 {len(features)} 个特征中。")

    # 遍历所有示例!
    for example_index, example in enumerate(tqdm(examples)):
        # 这些是与当前示例关联的特征的索引。
        feature_indices = features_per_example[example_index]

        min_null_score = None # 仅在squad_v2为True时使用。
        valid_answers = []
        
        context = example["context"]
        # 遍历与当前示例关联的所有特征。
        for feature_index in feature_indices:
            # 我们获取模型对这个特征的预测。
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # 这将允许我们将logits中的某些位置映射到原始上下文中的文本跨度。
            offset_mapping = features[feature_index]["offset_mapping"]

            # 更新最小空预测。
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # 浏览所有的最佳开始和结束logits,为 `n_best_size` 个最佳选择。
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # 不考虑超出范围的答案,原因是索引超出范围或对应于输入ID的部分不在上下文中。
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # 不考虑长度小于0或大于max_answer_length的答案。
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )
        
        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # 在极少数情况下我们没有一个非空预测,我们创建一个假预测以避免失败。
            best_answer = {"text": "", "score": 0.0}
        
        # 选择我们的最终答案:最佳答案或空答案(仅适用于squad_v2)
        if not squad_v2:
            predictions[example["id"]] = best_answer["text"]
        else:
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            predictions[example["id"]] = answer

    return predictions

在原始结果上应用后处理问答结果:

python 复制代码
final_predictions = postprocess_qa_predictions(datasets["validation"], validation_features, raw_predictions.predictions)

使用 datasets.load_metric 中加载 SQuAD v2 的评估指标

python 复制代码
from datasets import load_metric

metric = load_metric("squad_v2" if squad_v2 else "squad")

接下来,可以调用上面定义的函数进行评估。只需稍微调整一下预测和标签的格式,因为它期望的是一系列字典而不是一个大字典。在使用squad_v2数据集时,还需要设置no_answer_probability参数。

python 复制代码
if squad_v2:
    formatted_predictions = [{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in final_predictions.items()]
else:
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in datasets["validation"]]
metric.compute(predictions=formatted_predictions, references=references)

输出:

python 复制代码
{'exact_match': 74.88174077578051, 'f1': 83.6359321422016}
相关推荐
无名修道院3 小时前
AI大模型-面向开发者的开源框架:构建语言模型应用的实用指南
语言模型·agent·ai大模型
陈健平3 小时前
用 Kimi 2.5 Agent 从 0 搭建「宇宙吞噬,ps:和球球大作战这种差不多」对抗小游戏(Canvas 粒子特效 + AI Bot + 排行榜)
android·人工智能·agent·kimi2.5
切糕师学AI3 小时前
什么是边缘计算(Edge Computing)架构?
人工智能·架构·边缘计算
shenxianasi4 小时前
【论文精读】Florence: A New Foundation Model for Computer Vision
人工智能·机器学习·计算机视觉·自然语言处理·transformer
天天爱吃肉82184 小时前
新能源汽车多测试设备联调与多物理信息融合测试方法及数据价值挖掘
人工智能·嵌入式硬件·机器学习·汽车
EnCi Zheng4 小时前
04a. LayoutParser 安装指南
人工智能
心本无晴.4 小时前
ClawdBot:从桌面自动化到个人AI助手的进化之路
运维·人工智能·自动化
Li emily12 小时前
成功接入A股实时行情API获取实时市场数据
人工智能·python·金融·fastapi