Table of Contents
[1 Loading the pretrained model's tokenizer](#1 Loading the pretrained model's tokenizer)
[2 Loading a local dataset](#2 Loading a local dataset)
[3 Data preprocessing](#3 Data preprocessing)
[4 Creating the data loader](#4 Creating the data loader)
[5 Defining the downstream task model](#5 Defining the downstream task model)
[6 Test code](#6 Test code)
[7 Training code](#7 Training code)
[8 Saving and loading the trained model](#8 Saving and loading the trained model)
# Before loading the pretrained translation tokenizer, a third-party library must be installed
# -i is followed by the Tsinghua PyPI mirror
! pip install sentencepiece -i https://pypi.tuna.tsinghua.edu.cn/simple
# sentencepiece is an open-source subword tokenization tool that produces better token representations
1 Loading the pretrained model's tokenizer
python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('../data/model/opus-mt-en-ro/', use_fast=True)
print(tokenizer)
MarianTokenizer(name_or_path='../data/model/opus-mt-en-ro/', vocab_size=59543, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=False), added_tokens_decoder={ 0: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 1: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 59542: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), }
python
#Example text: inspect the tokenizer's output
text = [['hello, everyone today is a good day', 'It is late, please go home']]
tokenizer.batch_encode_plus(text)
{'input_ids': [[92, 778, 3, 1773, 879, 32, 8, 265, 431, 84, 32, 1450, 3, 709, 100, 540, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
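Note that wrapping both sentences in a single inner list makes `batch_encode_plus` treat them as one text pair, which is why the output above contains only one `input_ids` sequence. A minimal sketch (reusing the `tokenizer` loaded above) that encodes the two sentences as separate examples and decodes them back:
python
# Encode the two sentences as two independent examples
encoded = tokenizer.batch_encode_plus(
    ['hello, everyone today is a good day',
     'It is late, please go home'])
# Round-trip each sequence back to text to see the added special tokens
for ids in encoded['input_ids']:
    print(tokenizer.decode(ids))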
2 Loading a local dataset
python
from datasets import load_dataset
dataset = load_dataset('../data/datasets/wmt16-ro-en/')
dataset
DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 610320
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1999
    })
})
python
#Subsample the data: the full dataset is too large, so randomly select a subset
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(200))
dataset['test'] = dataset['test'].shuffle(1).select(range(200))
3 Data preprocessing
python
#Inspect the first training example
dataset['train'][0]
{'translation': {'en': 'For these reasons I voted in favour of the proposal for a new regulation that aims for greater clarity and transparency in the GSP system.', 'ro': 'Din aceste motive am votat în favoarea propunerii de nou regulament care își propune o mai mare claritate și transparență în sistemul SPG.'}}
python
def preprocess_function(data, tokenizer):
    """Data preprocessing function."""
    # Get the 'en' and 'ro' sentences separately
    en = [ex['en'] for ex in data['translation']]
    ro = [ex['ro'] for ex in data['translation']]
    # Tokenize the 'en' (source) text
    data = tokenizer.batch_encode_plus(en, max_length=128, truncation=True)
    # Tokenize the 'ro' (target) text and use its 'input_ids' as the labels
    with tokenizer.as_target_tokenizer():
        data['labels'] = tokenizer.batch_encode_plus(
            ro, max_length=128, truncation=True)['input_ids']
    return data
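`as_target_tokenizer` is deprecated in newer transformers releases; the same preprocessing can be written with the `text_target` argument. A minimal alternative sketch (assuming a recent transformers version; `preprocess_function_v2` is a hypothetical name, not used elsewhere):
python
def preprocess_function_v2(data, tokenizer):
    """Same preprocessing as above, using text_target instead of as_target_tokenizer."""
    en = [ex['en'] for ex in data['translation']]
    ro = [ex['ro'] for ex in data['translation']]
    # Passing text_target makes the tokenizer also return a 'labels' field for the target language
    return tokenizer(en, text_target=ro, max_length=128, truncation=True)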
python
#Apply the preprocessing function to the dataset with map
dataset = dataset.map(preprocess_function,
                      batched=True,
                      batch_size=1000,
                      num_proc=1,
                      remove_columns=['translation'],
                      fn_kwargs={'tokenizer': tokenizer})
python
#Inspect the first training example after preprocessing
print(dataset['train'][0])
{'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}
python
#Collate function: turns a list of examples into one padded batch
def collate_fn(data):
    # Find the longest label sequence in the batch
    max_length = max([len(i['labels']) for i in data])
    for i in data:
        # Number of padding positions needed for each label sequence, filled with -100
        # (-100 is ignored by CrossEntropyLoss)
        pads = [-100] * (max_length - len(i['labels']))
        # Append the padding to every label sequence
        i['labels'] = i['labels'] + pads
    # tokenizer.pad pads all remaining fields in the batch to the longest sequence length
    data = tokenizer.pad(
        encoded_inputs=data,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,  # do not round lengths up to a multiple
        return_tensors='pt')
    # Seq2seq models also need decoder inputs: decoder_input_ids
    # Add them as a new key, initially filled with the id of the vocabulary entry 'pad'
    data['decoder_input_ids'] = torch.full_like(data['labels'],
                                                tokenizer.get_vocab()['pad'],
                                                dtype=torch.long)
    # The first decoder token is the start token and is not predicted,
    # so the labels are copied in shifted one position to the right
    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]
    # Replace the -100 positions with the real <pad> token id
    data['decoder_input_ids'][data['decoder_input_ids'] == -100] = tokenizer.get_vocab()['<pad>']
    return data
python
tokenizer.get_vocab()['pad'], tokenizer.get_vocab()['<pad>']
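The two ids above differ: 'pad' is an ordinary subword in the vocabulary, while '<pad>' is the special padding token. For Marian models the conventional decoder start token is the pad token itself; a minimal sketch to check what the pretrained checkpoint specifies (assuming the same local path as above):
python
from transformers import AutoConfig

config = AutoConfig.from_pretrained('../data/model/opus-mt-en-ro/')
# Marian checkpoints conventionally start decoding from the pad token
print(config.decoder_start_token_id, config.pad_token_id)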
4 Creating the data loader
python
import torch

loader = torch.utils.data.DataLoader(dataset=dataset['train'],
                                     batch_size=8,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)
#Take one batch out of the loader to inspect it
for data in loader:
    break
data
{'input_ids': tensor([[ 12, 182, 381, 129, 13, 3177, 4, 397, 3490, 51, 4, 31307, 8305, 30, 196, 451, 1304, 30, 314, 57, 462, 5194, 14, 4, 6170, 1323, 13, 198, 13, 64, 239, 3473, 1151, 20, 1273, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [ 40, 16127, 56, 3024, 12, 76, 248, 13, 2043, 13500, 3, 85, 932, 10119, 3, 4077, 14, 4, 2040, 5589, 3551, 12, 123, 444, 4, 1586, 2716, 15373, 3, 193, 174, 154, 166, 11192, 279, 4391, 4166, 20, 85, 3524, 18, 33, 32, 381, 510, 20, 238, 14180, 2, 0], [ 67, 3363, 14, 8822, 3, 16751, 18, 244, 4704, 2028, 108, 4, 20738, 1058, 1136, 2936, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [ 20, 4243, 2, 10, 2587, 14, 102, 3, 12, 182, 129, 13, 22040, 238, 11617, 3, 372, 11, 3292, 46367, 21, 464, 732, 3, 37, 4, 2082, 14, 4, 1099, 211, 10197, 879, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [ 172, 2764, 3, 4413, 65, 141, 14, 3459, 5625, 234, 4877, 6319, 13, 4194, 7413, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [ 84, 32, 4925, 4, 4523, 13, 4, 16687, 14, 234, 24570, 13, 116, 20, 323, 37989, 745, 14, 7837, 25625, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [ 172, 63, 2763, 3, 12, 3177, 4, 10745, 11879, 14, 1137, 13, 3365, 4, 1991, 2981, 82, 4047, 82, 1116, 433, 4, 826, 14, 4, 692, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [44638, 2649, 88, 66, 4035, 18556, 3360, 37, 239, 3196, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'labels': tensor([[ 109, 519, 3, 3174, 19, 8966, 4429, 1129, 3620, 52, 6577, 42, 10241, 28833, 22883, 19, 5965, 24, 22724, 21, 11887, 28, 737, 24, 35130, 5, 43, 17205, 39, 8, 9, 1240, 170, 850, 71, 5449, 21, 17, 1769, 2, 0, -100, -100, -100, -100, -100, -100, -100], [ 40, 16127, 56, 10499, 270, 2187, 3595, 33759, 3, 3595, 259, 34901, 3, 17, 953, 19796, 2040, 5, 17085, 270, 3247, 5354, 666, 3188, 3, 43, 17, 3595, 259, 29, 57, 42, 2410, 38, 22, 279, 10978, 31, 55, 5, 519, 3503, 17, 16223, 266, 924, 2, 0], [ 135, 36886, 71, 8, 16368, 44, 4707, 3, 22432, 31, 552, 1211, 221, 11563, 44, 17422, 3, 44, 1136, 3038, 2, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [ 17, 1919, 2, 10, 3043, 5, 283, 3, 1064, 1897, 19, 13549, 39913, 3, 4892, 21, 3292, 463, 7397, 464, 732, 3, 39, 6021, 10429, 38, 43, 575, 2746, 2105, 2, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [ 172, 2957, 3, 4413, 8, 79, 916, 371, 473, 94, 9730, 1000, 7239, 9, 27722, 452, 28724, 93, 59, 122, 21, 4762, 7413, 2, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [ 5504, 14147, 21, 509, 881, 5575, 21, 411, 5434, 22, 2322, 39, 42, 2401, 136, 40673, 5, 23269, 49, 9272, 22, 2, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [ 127, 252, 2763, 3, 8966, 14761, 5558, 8, 1750, 5, 8, 3387, 3, 511, 50, 1434, 1369, 3, 394, 24, 18279, 3027, 3, 42, 1101, 291, 7627, 38, 6356, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [36966, 3, 39046, 29, 80, 5395, 253, 24, 300, 39, 2661, 405, 2, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]]), 'decoder_input_ids': tensor([[34426, 109, 519, 3, 3174, 19, 8966, 4429, 1129, 3620, 52, 6577, 42, 10241, 28833, 22883, 19, 5965, 24, 22724, 21, 11887, 28, 737, 24, 35130, 5, 43, 17205, 39, 8, 9, 1240, 170, 850, 71, 5449, 21, 17, 1769, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542], [34426, 40, 16127, 56, 10499, 270, 2187, 3595, 33759, 3, 3595, 259, 34901, 3, 17, 953, 19796, 2040, 5, 17085, 270, 3247, 5354, 666, 3188, 3, 43, 17, 3595, 259, 29, 57, 42, 2410, 38, 22, 279, 10978, 31, 55, 5, 519, 3503, 17, 16223, 266, 924, 2], [34426, 135, 36886, 71, 8, 16368, 44, 4707, 3, 22432, 31, 552, 1211, 221, 11563, 44, 17422, 3, 44, 1136, 3038, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [34426, 17, 1919, 2, 10, 3043, 5, 283, 3, 1064, 1897, 19, 13549, 39913, 3, 4892, 21, 3292, 463, 7397, 464, 732, 3, 39, 6021, 10429, 38, 43, 575, 2746, 2105, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [34426, 172, 2957, 3, 4413, 8, 79, 916, 371, 473, 94, 9730, 1000, 7239, 9, 27722, 452, 28724, 93, 59, 122, 21, 4762, 7413, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [34426, 5504, 14147, 21, 509, 
881, 5575, 21, 411, 5434, 22, 2322, 39, 42, 2401, 136, 40673, 5, 23269, 49, 9272, 22, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [34426, 127, 252, 2763, 3, 8966, 14761, 5558, 8, 1750, 5, 8, 3387, 3, 511, 50, 1434, 1369, 3, 394, 24, 18279, 3027, 3, 42, 1101, 291, 7627, 38, 6356, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542], [34426, 36966, 3, 39046, 29, 80, 5395, 253, 24, 300, 39, 2661, 405, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542]])}
python
for k, v in data.items():
    print(k, v.shape)
input_ids torch.Size([8, 50])
attention_mask torch.Size([8, 50])
labels torch.Size([8, 48])
decoder_input_ids torch.Size([8, 48])
5 Defining the downstream task model
python
#Translation is a standard seq2seq task. The 'LM' in AutoModelForSeq2SeqLM refers to the
#language-modeling head (lm_head) on top of the model; MarianModel is the encoder-decoder backbone used for translation.
from transformers import AutoModelForSeq2SeqLM, MarianModel


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The pretrained backbone must match the pretrained tokenizer loaded above
        self.pretrained = MarianModel.from_pretrained('../data/model/opus-mt-en-ro/')
        # Buffer with one bias value per vocabulary entry, added to the output logits
        self.register_buffer('final_logits_bias', torch.zeros(1, tokenizer.vocab_size))
        # Output layer: a linear projection from hidden size 512 to the vocabulary size
        self.fc = torch.nn.Linear(512, tokenizer.vocab_size, bias=False)
        # Initialize the output layer from the pretrained lm_head weights
        parameters = AutoModelForSeq2SeqLM.from_pretrained('../data/model/opus-mt-en-ro/')
        self.fc.load_state_dict(parameters.lm_head.state_dict())
        # Loss function (it could also be created in the training code instead)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        logits = self.pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 decoder_input_ids=decoder_input_ids)
        # Take the decoder's last hidden states
        logits = logits.last_hidden_state
        # Project to vocabulary logits and add the bias buffer created above
        logits = self.fc(logits) + self.final_logits_bias
        # Compute the loss.
        # flatten() acts like reshape here: logits go from 3-D to 2-D, labels from 2-D to 1-D
        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())
        return {'loss': loss, 'logits': logits}
python
model = Model()  #creating the model loads the pretrained weights again, which uses extra memory
#Check the total number of model parameters
print(sum(p.numel() for p in model.parameters()))
105634816
python
#Quick sanity check: run one batch through the model
outs = model(**data)
outs['loss'], outs['logits'].shape
(1.524869441986084, torch.Size([8, 48, 59543]))
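For comparison, the same batch can be run through `AutoModelForSeq2SeqLM`, which already bundles the Marian backbone, the `lm_head` projection and the `final_logits_bias` buffer that the custom `Model` assembles by hand, and computes the loss itself when `labels` are supplied. A minimal sketch, assuming the same checkpoint and the `data` batch above:
python
# Equivalent forward pass with the bundled seq2seq head
full_model = AutoModelForSeq2SeqLM.from_pretrained('../data/model/opus-mt-en-ro/')
outs_full = full_model(input_ids=data['input_ids'],
                       attention_mask=data['attention_mask'],
                       decoder_input_ids=data['decoder_input_ids'],
                       labels=data['labels'])
print(outs_full.loss, outs_full.logits.shape)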
6 Test code
python
data['input_ids'][0]
tensor([ 12, 182, 381, 129, 13, 3177, 4, 397, 3490, 51, 4, 31307, 8305, 30, 196, 451, 1304, 30, 314, 57, 462, 5194, 14, 4, 6170, 1323, 13, 198, 13, 64, 239, 3473, 1151, 20, 1273, 2, 0, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542, 59542])
python
def test(model):
    # Data loader for the test set
    loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],
                                              batch_size=8,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    predictions = []  # predicted translations
    references = []   # reference translations
    for i, data in enumerate(loader_test):
        # Only the forward pass of the downstream model is needed, no backpropagation
        with torch.no_grad():
            # outs is the downstream model's output, a dict containing loss and logits
            outs = model(**data)
        # Predictions: outs['logits'] has shape (batch_size, sequence length, vocab_size)
        pred = tokenizer.batch_decode(outs['logits'].argmax(dim=2))
        # Reference text, decoded from the decoder inputs
        label = tokenizer.batch_decode(data['decoder_input_ids'])
        # Collect predictions and references
        predictions.append(pred)
        references.append(label)
        if i % 2 == 0:
            print(i)
            input_ids = tokenizer.decode(data['input_ids'][0])
            print('input_ids=', input_ids)
            print('pred=', pred[0])
            print('label=', label[0])
        if i == 10:
            break
    # Wrap each reference in its own list (the format expected by corpus-level metrics)
    references = [[j] for j in references]
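The function collects `predictions` and `references` but never scores them. A minimal sketch of a corpus-level BLEU computation with sacrebleu (an assumed extra dependency, not part of the original code), operating on flat lists of prediction and reference strings with padding already stripped:
python
import sacrebleu  # assumed extra dependency: pip install sacrebleu

def corpus_bleu(flat_predictions, flat_references):
    """Corpus BLEU from flat lists of predicted and reference strings (one reference per sentence)."""
    # sacrebleu expects a list of reference streams; here there is a single stream
    return sacrebleu.corpus_bleu(flat_predictions, [flat_references]).score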
python
test(model)
0 input_ids= Do you have a super-talented and spontaneous child, who likes movies?</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> pred= Ave un copil super talenttalentat şi spontan, căruia îi plac filmele </s> ,,, - - - - - - - - - - - Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ai Ave Ave Ave label= pad Ai un copil super-talentat și spontan, căruia îi plac filmele?</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 2 input_ids= For the new study, researchers from Taiwan's Min-Sheng General Hospital used data collected since 2007 in a trial comparing two kinds of bariatric surgery - gastric bypass and sleeve gastrectomy - to medical treatments for type 2 diabetes in people who were mildly obese.</s> pred= În noul studiu, cercetătorii de la Spitalul General Min-Sheng din Taiwan au folosit datele colectate începând 2007 într-un studiual care care s compară două tipuri de chirurgie bariatrică - bypass gastric şi gatrectomie - mâneceoane - cu tratamentele medical pentru diabetului zaharat de tip 2 la persoanele care o ușoară.</s> label= pad Pentru noul studiu, cercetătorii de la Spitalul General Min-Sheng din Taiwan au folosit datele strânse după 2007 într-un trial în care se compară două tipuri de chirurgie bariatrică - bypass gastric și gastrectomie în manșon - în tratamentul medical al diabetului zaharat de tip 2 la persoane cu obezitate ușoară. 4 input_ids= After the public emotion passes, we will return to the reasons we have created this space, the freedom of movement that is not only for people but also for goods.</s> pred= Dupăa ce emo emotia publica, ne vom intoarce la motivele pentru care am creat acesttiul,, libertatea libertate de miscare, nu este numai pentru oameni, ci si pentru bunuri.</s> label= pad Dupa ce trece emotia publica, ne vom intoarce la motivele pentru care am creat spatiul acesta, aceasta libertate de miscare care nu e numai pentru oameni, ci și pentru bunuri. 
6 input_ids= I thought that what I was doing at the time was good for me and I think it was indeed good for me.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> pred= Crede ca crezut, ca ceea bine pentru mine si cred că mi fost într pentru mine,</s> ,, - - - - - - - - - Mi M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M M label= pad Asa am considerat atunci ca e bine pentru mine și cred ca a fost bine pentru mine.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 8 input_ids= In the eyes of the radical left wing, Brussels is an agent of international capitalism and a promoter of globalisation which imposed austerity to the poor.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> pred= În ochii aripiii radicale,-ul este un capitalismului interna și promo promotor al globalizării, a impus austeritate săraci săracilor.</s> ,,, = În În În În În În În În În În În În În În În În În În În În În În label= pad În ochii stângii radicale Bruxelles-ul este agentul capitalismului internațional și un promotor al globalizării care a impus austeritatea săracilor.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 10 input_ids= We must be more demanding.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> pred= Trebuie să fim mai exigentisi.</s> ,,,,,,,------ ar ar - - - - - - Ei Ei Ei ar ar ar ar ar ar ar ar,,,,,,,,, ar ar ar ar ar ar ar ar ar ar ar ar ar label= pad Trebuie sa fim mai pretentiosi.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
7 Training code
python
from transformers import AdamW
from transformers.optimization import get_scheduler
#Select the device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda', index=0)
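`from transformers import AdamW` has been removed in recent transformers versions; `torch.optim.AdamW` is a drop-in replacement. A minimal sketch of the alternative import (an assumption about the installed version, not the original code):
python
# Alternative if the installed transformers no longer exports AdamW
from torch.optim import AdamW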
python
#Training code
def train(model):
    # Optimizer and learning-rate schedule
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,  # no warmup: the learning rate starts decaying from the first step
                              num_training_steps=len(loader),
                              optimizer=optimizer)
    # Move the model to the device and switch to training mode
    model.to(device)
    model.train()
    for i, data in enumerate(loader):
        # Move every tensor in the batch to the device without unpacking it variable by variable
        for k in data.keys():
            data[k] = data[k].to(device)
        # Forward pass through the downstream model; outs is a dict
        outs = model(**data)
        # Get the loss from the output
        loss = outs['loss']
        # The value passed to backward() must be a scalar (a single number)
        loss.backward()
        # Clip gradients for more stable optimization
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # Update parameters and learning rate
        optimizer.step()
        scheduler.step()
        # Reset gradients
        optimizer.zero_grad()
        model.zero_grad()
        if i % 50 == 0:
            # Take the most likely token at each position
            outs = outs['logits'].argmax(dim=2)
            # Number of correctly predicted tokens
            correct = (data['decoder_input_ids'] == outs).sum().item()
            # Total number of tokens
            total = data['decoder_input_ids'].shape[1] * 8  # batch_size=8
            # Token-level accuracy
            accuracy = correct / total
            predictions = []
            references = []
            for j in range(8):
                pred = tokenizer.decode(outs[j])
                label = tokenizer.decode(data['decoder_input_ids'][j])
                predictions.append(pred)
                references.append(label)
            # Read the current learning rate from the optimizer state (the standard way)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), accuracy, lr)
python
train(model)
#For a translation task, token accuracy is not very meaningful; watch whether the loss keeps decreasing and whether the decoded predictions look correct
0 0.8231143355369568 0.005319148936170213 1.9992e-05 50 0.8883197903633118 0.0033333333333333335 1.9592e-05 100 1.1980996131896973 0.00684931506849315 1.9192000000000002e-05 150 0.8107476234436035 0.00646551724137931 1.8792000000000002e-05 200 1.0034911632537842 0.0022522522522522522 1.8392e-05 250 0.7016224265098572 0.004032258064516129 1.7992e-05 300 1.063989281654358 0.0047169811320754715 1.7592000000000004e-05 350 0.5443494915962219 0.00390625 1.7192e-05 400 0.8666834235191345 0.006097560975609756 1.6792e-05 450 0.6649302840232849 0.003048780487804878 1.6392e-05 500 0.46878930926322937 0.0 1.5992000000000002e-05 550 0.7509095072746277 0.00211864406779661 1.5592e-05 600 0.9289752244949341 0.0026595744680851063 1.5192000000000003e-05 650 1.041787028312683 0.005859375 1.4792000000000002e-05 700 0.9425965547561646 0.006024096385542169 1.4392000000000002e-05 750 0.9585253000259399 0.0 1.3992000000000001e-05 800 0.5976186394691467 0.011029411764705883 1.3592000000000001e-05 850 0.816501796245575 0.0015060240963855422 1.3192e-05 900 0.844330906867981 0.001488095238095238 1.2792e-05 950 0.706694483757019 0.0 1.2392000000000003e-05 1000 0.8571699857711792 0.004464285714285714 1.1992000000000001e-05 1050 0.6503761410713196 0.005319148936170213 1.1592000000000002e-05 1100 0.994631826877594 0.0026041666666666665 1.1192e-05 1150 0.8582945466041565 0.003787878787878788 1.0792000000000001e-05 1200 0.48289453983306885 0.004310344827586207 1.0392e-05 1250 0.6844062209129333 0.0030864197530864196 9.992e-06 1300 0.4874255061149597 0.005208333333333333 9.592e-06 1350 0.9279842972755432 0.001644736842105263 9.192000000000001e-06 1400 1.0108321905136108 0.001838235294117647 8.792e-06 1450 0.8091197609901428 0.0020161290322580645 8.392e-06 1500 0.7929010987281799 0.0 7.992e-06 1550 0.6564688086509705 0.0026041666666666665 7.592e-06 1600 1.0069290399551392 0.001736111111111111 7.192e-06 1650 0.6669528484344482 0.0 6.792000000000001e-06 1700 0.7156780362129211 0.004166666666666667 6.392000000000001e-06 1750 0.7411083579063416 0.0020491803278688526 5.992e-06 1800 0.7482017874717712 0.002631578947368421 5.592000000000001e-06 1850 0.8117333650588989 0.006172839506172839 5.1920000000000004e-06 1900 0.758994996547699 0.007936507936507936 4.792000000000001e-06 1950 0.5511546730995178 0.01 4.3920000000000005e-06 2000 0.7756522297859192 0.0033783783783783786 3.992e-06 2050 0.7603906393051147 0.004032258064516129 3.5920000000000005e-06 2100 1.021585464477539 0.014204545454545454 3.192e-06 2150 0.7125442624092102 0.005681818181818182 2.792e-06 2200 0.6212157011032104 0.002232142857142857 2.392e-06 2250 0.948815107345581 0.00423728813559322 1.992e-06 2300 0.8442288637161255 0.003472222222222222 1.5920000000000002e-06 2350 0.8308764696121216 0.004032258064516129 1.1920000000000002e-06 2400 0.9954149127006531 0.005319148936170213 7.920000000000001e-07 2450 0.7432130575180054 0.0037313432835820895 3.92e-07
8 Saving and loading the trained model
python
#Save the trained model
torch.save(model, '../data/model/翻译.model')
#Load the saved model
model_2 = torch.load('../data/model/翻译.model', map_location='cpu')
test(model_2)
0 input_ids= The Japanese culture courses that we have proposed have been designed as an invitation to the general public, to make the first steps on the road to self-discovery through art.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> pred= Cursurile de cultură japoneză pe care le le-am propus au fost concepute ca o invita adresată pentru publicului larg, pentru a face primii pași pe drumul spre autodeceriaștere prin artă.</s> </s> ... - În -- - - A - În Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs Curs label= pad Cursurile de cultură japoneză pe care vi le-am propus au fost concepute ca o invitație, adresată publicului larg, de a realiza primii pași pe drumul spre autocunoaștere prin artă.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 2 input_ids= According to him, at the request of Corneliu Vadim Tudor's family, the church will remain open from Wednesday to Thursday. Last night it was closed at 01.00 a.m., and reopened at 07.00. a.m.</s> pred= Potrivit acestuia, la cererea familiei Corneliu Corneliu Vadim Tudor, biserica va rămânemane deschisa de de perioada trecut miercuri până joi. la trecutăa a închisăchisa la ora 01:00 a redeschisa lata la la ora.00 a00 label= pad Potrivit acestuia, la cererea familiei lui Corneliu Vadim Tudor, biserica va ramane deschisa și în noaptea de miercuri spre joi, noaptea trecuta fiind inchisa la ora 01.00 și redeschisa dimineata, la 07.00. 4 input_ids= The market has rebounded somewhat this month, with the Dow now down nearly 10% from the May high.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> pred= Pia a revenitâat oarecum luna aceasta, iar Dow fiind acum cu aproape 10% mai nivelulul din luna.</s> luiuluiului. 
la În De De În Piața Piața Piața Piața Piața Piața Piața Piața Piața Piața Pia Piața Piața Piața Piața Pia Piața Pia Piața Pia label= pad Piața a reculat oarecum luna aceasta, indicele Dow fiind acum cu aproximativ 10% sub maximul din mai.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 6 input_ids= Addressing the issue of violence, the Dalai Lama also commented on George Bush's actions following 9/11 terrorist attacks, claiming that the US" violent response engendered a chain of uncontrollable events.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> pred= Ad la problema violen, Dalai Lama a comentat de asemenea acțiunile acțiunile lui George Bush în atacurile teroriste din la 9 septembrie, afirmând că răspunsul violent al SUA SUA a generat un lanț de evenimente decontrolabile.</s> Uul Da G aul La A George - La În La A În - - La label= pad Referitor la problema violenței, Dalai Lama a comentat de asemenea despre acțiunile lui George Bush după atacurile teroriste de pe 11 septembrie, afirmând că răspunsul violent din partea SUA a generat un lanț de evenimente incontrolabile.</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 8 input_ids= Before arriving at the Court of Appeal of Iasi, the van transporting the rapists from Vaslui was involved in an accident in Păcurari district as the car was entering Iasi city through the beltway.</s> pred= Înainte de a ajunge la Curtea de Apel dinşi, duba care viol au transportați violatorii din Vaslui a fost implicată într-un accident în în districtul Purari, în intrând în orașulși prin autostrad.</s> label= pad Înainte de a ajunge la Curtea de Apel Iași, duba cu care erau transportați violatorii din Vaslui a fost implicată într-un accident produs în cartierul Păcurari, mașina intrând în Iași pe centură. 10 input_ids= "If the people voted for us it does not mean they are stupid enough to vote for us again if they don't see anything," said Blăjuț, who stated that the main reason the electorate is no longer interested in politics is the hypocrisy of the elected persons.</s> pred= "Dacă oamenii au-au votat, nu înseamnă că sunt destulşti să voteze voteze voteze dacă nu văd nimic", a declarat Blăjuț, care a declarat că principalul motiv pentru care electoratul nu mai este interesat de politică este ipocrizia persoanelorșilor.</s> label= pad "Dacă oamenii ne-au votat, nu înseamnă că sunt proști să ne mai voteze dacă nu văd nimic", a spus Blăjuț, care a apreciat că principalul motiv pentru care electoratul nu mai este interesat de politică este ipocrizia aleșilor.
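Note that `test()` scores teacher-forced logits rather than performing real autoregressive decoding. For actual translation at inference time, the bundled seq2seq model's `generate()` method can be used; a minimal sketch with the original pretrained checkpoint (the fine-tuned custom `Model` has no `generate()` method, so this uses the off-the-shelf weights purely as an illustration):
python
from transformers import AutoModelForSeq2SeqLM

mt_model = AutoModelForSeq2SeqLM.from_pretrained('../data/model/opus-mt-en-ro/')
inputs = tokenizer('It is late, please go home', return_tensors='pt')
# Autoregressive decoding up to 128 tokens
generated = mt_model.generate(**inputs, max_length=128)
print(tokenizer.decode(generated[0], skip_special_tokens=True))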