【HuggingFace LLM】经典NLP微调任务之掩码自回归

正文

模型挑选与数据集一览

python 复制代码
from transformers import AutoModelForMaskedLM

model_checkpoint = "distilbert-base-uncased"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

distilbert_num_parameters = model.num_parameters() / 1_000_000
print(f"'>>> DistilBERT number of parameters: {round(distilbert_num_parameters)}M'")
print(f"'>>> BERT number of parameters: 110M'")

'>>> DistilBERT number of parameters: 67M'
'>>> BERT number of parameters: 110M'

选择通过蒸馏得到的学生模型DistilBERT,参数量相比BERT约为一半。

python 复制代码
from transformers import AutoTokenizer
import torch

text = "This is a great [MASK]."
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

由于input_ids中经过tokenizer之后,会有一个掩码符号[MASK],因此首先对其进行定位。

token_logits的形状是[1, 8, 30522](意味着vocabulary有这么多的维度)。

当输入经过模型后,选择[MASK]位置的topk概率token,通过解码后还原到了原文本中。

python 复制代码
from datasets import load_dataset

imdb_dataset = load_dataset("imdb")
>>>imdb_dataset
>>>DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

# 查看数据
sample = imdb_dataset["train"].shuffle(seed=42).select(range(3))

for row in sample:
    print(f"\n'>>> Review: {row['text']}'")
    print(f"'>>> Label: {row['label']}'")
    

#'>>> Review: This is your typical Priyadarshan movie--a bunch of loony characters out on some silly mission. His signature climax has the entire cast of the film coming together and fighting each other in some crazy moshpit over hidden money. Whether it is a winning lottery ticket in Malamaal Weekly, black money in Hera Pheri, "kodokoo" in Phir Hera Pheri, etc., etc., the director is becoming ridiculously predictable. Don\'t get me wrong; as clichéd and preposterous his movies may be, I usually end up enjoying the comedy. However, in most his previous movies there has actually been some good humor, (Hungama and Hera Pheri being noteworthy ones). Now, the hilarity of his films is fading as he is using the same formula over and over again.<br /><br />Songs are good. Tanushree Datta looks awesome. Rajpal Yadav is irritating, and Tusshar is not a whole lot better. Kunal Khemu is OK, and Sharman Joshi is the best.'
#'>>> Label: 0'

#'>>> Review: Okay, the story makes no sense, the characters lack any dimensionally, the best dialogue is ad-libs about the low quality of movie, the cinematography is dismal, and only editing saves a bit of the muddle, but Sam" Peckinpah directed the film. Somehow, his direction is not enough. For those who appreciate Peckinpah and his great work, this movie is a disappointment. Even a great cast cannot redeem the time the viewer wastes with this minimal effort.<br /><br />The proper response to the movie is the contempt that the director San Peckinpah, James Caan, Robert Duvall, Burt Young, Bo Hopkins, Arthur Hill, and even Gig Young bring to their work. Watch the great Peckinpah films. Skip this mess.'
#'>>> Label: 0'

#'>>> Review: I saw this movie at the theaters when I was about 6 or 7 years old. I loved it then, and have recently come to own a VHS version. <br /><br />My 4 and 6 year old children love this movie and have been asking again and again to watch it. <br /><br />I have enjoyed watching it again too. Though I have to admit it is not as good on a little TV.<br /><br />I do not have older children so I do not know what they would think of it. <br /><br />The songs are very cute. My daughter keeps singing them over and over.<br /><br />Hope this helps.'
#'>>> Label: 1'

大型电影评论数据集 (简称 IMDb)是一个电影评论语料库,常用于评估情感分析模型的性能

python 复制代码
# 检查训练集和测试集中的标签是否仅包含0/1?
def check_labels(atr='train'):
  train_len = imdb_dataset[atr].num_rows
  for batch in range(0, train_len, 1000):
    yield imdb_dataset[atr]['label'][batch: batch + 1000]

labels = check_labels('test')
# labels = check_labels('train')

flag = True
for label_batch in labels:
  zero_batch = [item for item in label_batch if item==0]
  one_batch = [item for item in label_batch if item==1]
  assert len(zero_batch) + len(one_batch) == len(label_batch), '存在其他标签'

上述代码是用于检查标签情况,是否在训练集和测试集中存在非1/0的标签值。

数据预处理

在掩码自回归预测任务中,我们把所有的样本都进行拼接 ,然后再根据max_length进行截断。这么做的目的是为了尽可能减少信息的丢失 ,如果按每个样本max_length截断的话,会丢失很多信息。

python 复制代码
def tokenize_function(examples):
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result


# Use batched=True to activate fast multithreading!
tokenized_datasets = imdb_dataset.map(
    tokenize_function, batched=True, remove_columns=["text", "label"]
)
  1. 不启用truncation=True防止截断;
  2. 并且在返回的数据集中添加一个word_ids,提供tokeninput_ids之间的一个\[001 分类任务#tokenizer\|映射关系]。

!NOTE Title

word_ids主要是为了:

  1. token[MASK]时,提供一个原输入中的input_ids作为标签;
  2. 识别哪些是源自同一个 预分词后的单词 ,一般源于同一个单词的token设为同一个标签;

这里选择max_length = 128作为最大截断数。

python 复制代码
def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute length of concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of max_len
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

pdb.set_trace()后,检测到每个batch输入的长度为1000。因此相当于把1000组input_ids进行了一个拼接,然后再将其按照一份max_length=128进行分割。

!NOTE

python 复制代码
>>>a = {'a': [[1,2], [3,4,5]]}
>>>b = {k:sum(a[k], []) for k in a.keys()}
>>>a
>>>Out[4]: {'a': [[1, 2], [3, 4, 5]]}
>>>b
>>>Out[5]: {'a': [1, 2, 3, 4, 5]}

使用sum(a, [])的方式去把一个多层嵌套的列表展平为一维列表

使用API微调

数据整理器
python 复制代码
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

其中,mlm_probability=0.15是一个指定要掩码的词元比例。

python 复制代码
samples = [lm_datasets["train"][i] for i in range(2)]
for sample in samples:
    _ = sample.pop("word_ids")

for chunk in data_collator(samples)["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")
    print(f"\n'>>> {tokenizer.convert_ids_to_tokens(chunk)}'")

!NOTE decodeconvert_ids_to_tokens的区别

二者的区别在于,decodeconvert_ids_to_tokens的高级封装,decode会把源自同一个单词的tokens合并 (例如hugging##face,会显示为huggingface);但是convert_ids_to_tokens会直接显示分开的tokens,而不会合并展示

全词掩码

打标签时,源自一个单词的tokens希望可以被打上同一个标签(同时被预测)。

python 复制代码
import collections
import numpy as np
import pdb
from transformers import default_data_collator

pdb.set_trace()
wwm_probability = 0.2

def whole_word_masking_data_collator(features):
    for feature in features:
        word_ids = feature.pop("word_ids")

        # Create a map between words and corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:  # None表示为特殊字符
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)

        # Randomly mask words
        mask = np.random.binomial(1, wwm_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        feature["labels"] = new_labels

    return default_data_collator(features)

samples = [lm_datasets["train"][i] for i in range(2)]
batch = whole_word_masking_data_collator(samples)

for chunk in batch["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")

调试过程如下,可展开显示。

python 复制代码
--Return--
None
> /tmp/ipython-input-1526888649.py(6)<cell line: 0>()
      4 from transformers import default_data_collator
      5 
----> 6 pdb.set_trace()
      7 wwm_probability = 0.2
      8 

ipdb> b 13
Breakpoint 1 at /tmp/ipython-input-1526888649.py:13
ipdb> c
> /tmp/ipython-input-1526888649.py(13)whole_word_masking_data_collator()
     11 def whole_word_masking_data_collator(features):
     12     for feature in features:
1--> 13         word_ids = feature.pop("word_ids")
     14 
     15         # Create a map between words and corresponding token indices

ipdb> p feature.shape
*** AttributeError: 'dict' object has no attribute 'shape'
ipdb> p list(feature.keys())
['input_ids', 'attention_mask', 'word_ids', 'labels']
ipdb> p len(feature['input_ids'])
128
ipdb> b 20
Breakpoint 2 at /tmp/ipython-input-1526888649.py:20
ipdb> c
> /tmp/ipython-input-1526888649.py(20)whole_word_masking_data_collator()
     18         current_word = None
     19         for idx, word_id in enumerate(word_ids):
2--> 20             if word_id is not None:
     21                 if word_id != current_word:
     22                     current_word = word_id

ipdb> p word_id
None
ipdb> c
> /tmp/ipython-input-1526888649.py(20)whole_word_masking_data_collator()
     18         current_word = None
     19         for idx, word_id in enumerate(word_ids):
2--> 20             if word_id is not None:
     21                 if word_id != current_word:
     22                     current_word = word_id

ipdb> p current_word_idx
*** NameError: name 'current_word_idx' is not defined
ipdb> p current_word_index
-1
ipdb> p mapping
defaultdict(<class 'list'>, {})
ipdb> b 23
Breakpoint 3 at /tmp/ipython-input-1526888649.py:23
ipdb> c
> /tmp/ipython-input-1526888649.py(23)whole_word_masking_data_collator()
     21                 if word_id != current_word:
     22                     current_word = word_id
3--> 23                     current_word_index += 1
     24                 mapping[current_word_index].append(idx)
     25 

ipdb> p current_word
0
ipdb> p idx
1
ipdb> c
> /tmp/ipython-input-1526888649.py(20)whole_word_masking_data_collator()
     18         current_word = None
     19         for idx, word_id in enumerate(word_ids):
2--> 20             if word_id is not None:
     21                 if word_id != current_word:
     22                     current_word = word_id

ipdb> p mapping
defaultdict(<class 'list'>, {0: [1]})
ipdb> cl
Clear all breaks? yes
Deleted breakpoint 1 at /tmp/ipython-input-1526888649.py:13
Deleted breakpoint 2 at /tmp/ipython-input-1526888649.py:20
Deleted breakpoint 3 at /tmp/ipython-input-1526888649.py:23
ipdb> b 27
Breakpoint 4 at /tmp/ipython-input-1526888649.py:27
ipdb> p mapping
defaultdict(<class 'list'>, {0: [1]})
ipdb> c
> /tmp/ipython-input-1526888649.py(27)whole_word_masking_data_collator()
     25 
     26         # Randomly mask words
4--> 27         mask = np.random.binomial(1, wwm_probability, (len(mapping),))
     28         input_ids = feature["input_ids"]
     29         labels = feature["labels"]

ipdb> p mapping
defaultdict(<class 'list'>, {0: [1], 1: [2], 2: [3], 3: [4], 4: [5], 5: [6], 6: [7], 7: [8], 8: [9], 9: [10], 10: [11], 11: [12], 12: [13], 13: [14], 14: [15], 15: [16], 16: [17], 17: [18], 18: [19], 19: [20], 20: [21], 21: [22], 22: [23], 23: [24], 24: [25], 25: [26], 26: [27], 27: [28], 28: [29], 29: [30], 30: [31], 31: [32], 32: [33], 33: [34], 34: [35], 35: [36], 36: [37], 37: [38], 38: [39], 39: [40], 40: [41], 41: [42], 42: [43], 43: [44], 44: [45], 45: [46], 46: [47], 47: [48], 48: [49], 49: [50], 50: [51], 51: [52], 52: [53], 53: [54], 54: [55], 55: [56], 56: [57], 57: [58], 58: [59], 59: [60], 60: [61], 61: [62], 62: [63], 63: [64], 64: [65], 65: [66], 66: [67], 67: [68], 68: [69], 69: [70], 70: [71], 71: [72], 72: [73], 73: [74], 74: [75], 75: [76], 76: [77], 77: [78], 78: [79], 79: [80], 80: [81], 81: [82], 82: [83], 83: [84], 84: [85], 85: [86], 86: [87], 87: [88], 88: [89], 89: [90], 90: [91], 91: [92], 92: [93], 93: [94], 94: [95], 95: [96], 96: [97], 97: [98], 98: [99], 99: [100], 100: [101], 101: [102], 102: [103], 103: [104], 104: [105], 105: [106], 106: [107], 107: [108, 109], 108: [110], 109: [111], 110: [112], 111: [113], 112: [114], 113: [115], 114: [116], 115: [117], 116: [118], 117: [119], 118: [120, 121], 119: [122], 120: [123], 121: [124], 122: [125], 123: [126], 124: [127]})
ipdb> b 28
Breakpoint 5 at /tmp/ipython-input-1526888649.py:28
ipdb> c
> /tmp/ipython-input-1526888649.py(28)whole_word_masking_data_collator()
     26         # Randomly mask words
4    27         mask = np.random.binomial(1, wwm_probability, (len(mapping),))
5--> 28         input_ids = feature["input_ids"]
     29         labels = feature["labels"]
     30         new_labels = [-100] * len(labels)

ipdb> p mask
array([0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0])
ipdb> b 32
Breakpoint 6 at /tmp/ipython-input-1526888649.py:32
ipdb> c
> /tmp/ipython-input-1526888649.py(32)whole_word_masking_data_collator()
     30         new_labels = [-100] * len(labels)
     31         for word_id in np.where(mask)[0]:
6--> 32             word_id = word_id.item()
     33             for idx in mapping[word_id]:
     34                 new_labels[idx] = labels[idx]

ipdb> p word_id
np.int64(1)
ipdb> cl
Clear all breaks? yes
Deleted breakpoint 4 at /tmp/ipython-input-1526888649.py:27
Deleted breakpoint 5 at /tmp/ipython-input-1526888649.py:28
Deleted breakpoint 6 at /tmp/ipython-input-1526888649.py:32
ipdb> c

'>>> [CLS] i [MASK] i am curious - [MASK] from [MASK] video store [MASK] of [MASK] the controversy that surrounded it [MASK] [MASK] was first released in 1967. i also heard [MASK] [MASK] first it was seized by u. s. customs if [MASK] ever tried to [MASK] this country, [MASK] being [MASK] fan of films considered " controversial " [MASK] really [MASK] [MASK] see this for myself. < br / > < br / > the plot is centered [MASK] a young [MASK] drama [MASK] named lena who [MASK] to learn everything she can about life. in particular she wants to [MASK] her attentions to making some sort of documentary [MASK] [MASK] [MASK] average swede [MASK] about certain political issues such'

'>>> as the vietnam war and race issues in the [MASK] states. in between asking politicians and ordinary denizens of stockholm about their [MASK] on politics, [MASK] [MASK] sex [MASK] her drama teacher [MASK] classmates, and married men. < br / > < br / > what kills me about i am curious - yellow [MASK] that 40 years ago, this was considered [MASK]. really [MASK] the sex and nudity scenes are few and far between, even then it ' [MASK] not [MASK] like some [MASK] [MASK] made porno. while [MASK] countrymen [MASK] find it shocking, in [MASK] sex and nudity are a major [MASK] [MASK] swedish cinema [MASK] even [MASK] [MASK] [MASK],'

由调试过程可知,每次输入的都是128,即最大截断输入的样本。

mapping作用是定位单词->tokens之间的映射关系

由于源于同一个单词的word_id会相同。

复制代码
raw -> 'EU    rejects German call to boycott British lamb .'
tokens -> ['[CLS]', 'EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'la', '##mb', '.', '[SEP]']
word_ids -> [None, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, None]

如上,源自同一个单词的la##mb会有相同的word_id

因此14-22代码逻辑在于:

  1. current_word_index遇到第一个token后,变成0(表示第0个单词);
  2. idx是在tokens中的对应token的索引,二者构成一个单词->token的映射
  3. 检索前后两个相邻word_id,若相同则持续在列表中添加表示一个单词->多个token的映射

!NOTE np.random.binomial

NumPy 中用于生成二项分布随机数 的核心函数,二项分布描述了n 次独立伯努利试验中成功次数的概率分布,举例说明:

python 复制代码
import numpy as np 
res = np.random.binomial(10, 0.5, 3) 
print(res) # 输出:[8, 6, 5]

np.random.binomial(n=10, p=0.5, size=3)的本质是:

  • n=10一次 "二项试验" 包含10 次独立的伯努利试验(比如抛 10 次硬币);
  • size=3 :总共做3 次独立的 "二项试验"(比如分 3 轮,每轮抛 10 次硬币);
  • p=0.5:每次伯努利试验的成功概率(比如抛硬币正面朝上的概率)。

结果[8, 6, 5]的准确含义:

这个结果不是 "3×10 次实验" 的整体统计,而是:

  1. 第一次二项试验:做了 10 次伯努利试验,成功 8 次;
  2. 第二次二项试验:做了 10 次伯努利试验,成功 6 次;
  3. 第三次二项试验:做了 10 次伯努利试验,成功 5 次。
api参数定义
python 复制代码
from transformers import TrainingArguments

batch_size = 64
# Show the training loss with every epoch
logging_steps = len(downsampled_dataset["train"]) // batch_size
model_name = model_checkpoint.split("/")[-1]

training_args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-imdb",
    overwrite_output_dir=True,
    eval_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    push_to_hub=True,
    fp16=True,
    logging_steps=logging_steps,
)

logging_steps开启以跟踪每个epoch的训练损失

!NOTE

step定义为1个epoch中经历batch的次数,例如1000样本,batch设为64,则step1000//64=15。因此这里设置logging_steps恰好为1个epoch(隔15个step)打印训练损失一次。

#重要 如果后续要使用全量掩码数据整理器替换原始,需要添加一行remove_unused_columns=False,不然word_ids列会被删除报错。

python 复制代码
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
    data_collator=whole_word_masking_data_collator,
    tokenizer=tokenizer,
)
困惑度指标

假设我们的测试集主要由语法正确的句子组成,那么衡量语言模型质量的一种方法是计算它在测试集所有句子中为下一个词分配的概率。高概率表明模型对未见过的例子并不感到"惊讶"或"困惑" ,并表明它已经学习了该语言的基本语法模式。困惑度有多种数学定义,但我们将使用将其定义为交叉熵损失的指数。因此,我们可以使用 Trainer.evaluate() 函数计算测试集上的交叉熵损失,然后对结果取指数,从而计算预训练模型的困惑度。

python 复制代码
import math

eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

越低的perplexity表明语言模型越好。

使用Accelerate微调

由于评估时动态随机掩码会导致每次的困惑度不同,因此改为手动静态掩码

python 复制代码
def insert_random_mask(batch):
    features = [dict(zip(batch, t)) for t in zip(*batch.values())]
    masked_inputs = data_collator(features)
    # Create a new "masked" column for each column in the dataset
    return {"masked_" + k: v.numpy() for k, v in masked_inputs.items()}

downsampled_dataset = downsampled_dataset.remove_columns(["word_ids"])
eval_dataset = downsampled_dataset["test"].map(
    insert_random_mask,
    batched=True,
    remove_columns=downsampled_dataset["test"].column_names,
)
eval_dataset = eval_dataset.rename_columns(
    {
        "masked_input_ids": "input_ids",
        "masked_attention_mask": "attention_mask",
        "masked_labels": "labels",
    }
)

"列优先" 批量数据转为 "行优先" 样本列表

输入 batch 格式:{"input_ids": [样本1_ids, 样本2_ids], "attention_mask": [样本1_mask, 样本2_mask], ...}

输出 features 格式:[{'input_ids': 样本1_ids, 'attention_mask': 样本1_mask, ...}, {'input_ids': 样本2_ids, ...}]

!NOTE

#重要 适配 data_collator 要求的输入格式:每个元素是一个样本字典

python 复制代码
from torch.utils.data import DataLoader
from transformers import default_data_collator

batch_size = 64
train_dataloader = DataLoader(
    downsampled_dataset["train"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    eval_dataset, batch_size=batch_size, collate_fn=default_data_collator
)

model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)

from accelerate import Accelerator
accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

from transformers import get_scheduler

num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

from huggingface_hub import get_full_repo_name
model_name = "distilbert-base-uncased-finetuned-imdb-accelerate"
repo_name = get_full_repo_name(model_name)
repo_name
from huggingface_hub import Repository
output_dir = model_name
repo = Repository(output_dir, clone_from=repo_name)

from tqdm.auto import tqdm
import torch
import math

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    losses = []
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)
            
        loss = outputs.loss
        losses.append(accelerator.gather(loss.repeat(batch_size)))

    losses = torch.cat(losses)
    losses = losses[: len(eval_dataset)]
    try:
        perplexity = math.exp(torch.mean(losses))
    except OverflowError:
        perplexity = float("inf")

    print(f">>> Epoch {epoch}: Perplexity: {perplexity}")

    # Save and upload
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
        repo.push_to_hub(
            commit_message=f"Training in progress epoch {epoch}", blocking=False
        )

问题

总结