# Imports for the poetry-generation tutorial: stdlib utilities,
# numpy, and the PyTorch training/data stack.
import math
import numpy as np
from collections import Counter
import torch
from torch import nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import tqdm
import random
import sys
# Print framework and interpreter versions so the run is reproducible.
print("Pytorch 版本:", torch.__version__)
print("Python 版本:", sys.version)
Output:
Pytorch 版本: 2.1.2
Python 版本: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
# Path to the raw poetry corpus.
DATA_PATH = '/kaggle/input/poetry/poetry.txt'

# Peek at the raw data first; each line is formatted as "title:content".
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()

for idx in range(0, 5):
    print(lines[idx])
print(f"origin_line_count = {len(lines)}")
# Minimum per-character frequency kept in the vocabulary.
MIN_WORD_FREQUENCY = 8

# Count character frequencies; Counter consumes each line as an
# iterable of single characters, so no manual splitting is needed.
counter = Counter()
for verse in poetry:
    counter.update(verse)

# Keep only characters that clear the frequency threshold.
tokens = [ch for ch, freq in counter.items() if freq >= MIN_WORD_FREQUENCY]
Code:
# Print the 5 most frequent characters.
# Fix: the original iterated counter.items(), which yields entries in
# insertion order, NOT by count — so it printed the first 5 characters
# encountered, contradicting the stated intent. Counter.most_common(5)
# returns the entries sorted by descending frequency.
for token, count in counter.most_common(5):
    print(token, "->", count)
# Pull one batch from a throwaway DataLoader and decode it back to text
# as a sanity check on the dataset/tokenizer round trip.
temp_dataloader = DataLoader(dataset=my_dataset, batch_size=8, shuffle=True)
one_batch_data = next(iter(temp_dataloader))
for line_ids in one_batch_data.tolist():
    decoded = tokenizer.decode(line_ids)
    # Strip padding tokens before printing.
    print("".join(ch for ch in decoded if ch != Tokenizer.PAD))
for epoch in range(1, EPOCH_NUM + 1):
    model.train()
    total_loss = 0
    data_progress = tqdm.tqdm(train_dataloader, desc="Train...")
    for step, data in enumerate(data_progress, 1):
        data = data.to(DEVICE)

        # Choose a random split point: tokens before it become the
        # encoder input (src), the rest feed the decoder.
        split_pos = random.randint(1, 20)
        src = data[:, :split_pos]
        # tgt drops the final token; tgt_y drops the first token after
        # the split (standard teacher-forcing shift by one).
        tgt, tgt_y = data[:, split_pos:-1], data[:, split_pos + 1:]

        # Run the Transformer, then project the decoder states to
        # vocabulary logits with the final linear layer.
        out = model(src, tgt)
        out = model.predictor(out)

        # view() requires a contiguous tensor (is_contiguous() == True);
        # permute() breaks contiguity, so contiguous() must be called
        # before flattening for the loss.
        loss = criteria(out.view(-1, out.size(-1)), tgt_y.permute(1, 0).contiguous().view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # Refresh the progress-bar caption with the running mean loss.
        data_progress.set_description(f"Train... [epoch {epoch}/{EPOCH_NUM}, loss {(total_loss / step):.5f}]")
Output:
Train... [epoch 1/50, loss 3.66444]: 100%|██████████| 381/381 [00:10<00:00, 35.40it/s]
Train... [epoch 2/50, loss 3.35216]: 100%|██████████| 381/381 [00:09<00:00, 39.61it/s]
Train... [epoch 3/50, loss 3.27860]: 100%|██████████| 381/381 [00:09<00:00, 39.44it/s]
Train... [epoch 4/50, loss 3.15286]: 100%|██████████| 381/381 [00:09<00:00, 39.10it/s]
Train... [epoch 5/50, loss 3.05621]: 100%|██████████| 381/381 [00:09<00:00, 39.32it/s]
Train... [epoch 6/50, loss 2.97613]: 100%|██████████| 381/381 [00:09<00:00, 39.42it/s]
Train... [epoch 7/50, loss 2.91857]: 100%|██████████| 381/381 [00:09<00:00, 38.83it/s]
Train... [epoch 8/50, loss 2.88052]: 100%|██████████| 381/381 [00:09<00:00, 39.59it/s]
Train... [epoch 9/50, loss 2.78789]: 100%|██████████| 381/381 [00:09<00:00, 39.19it/s]
Train... [epoch 10/50, loss 2.77379]: 100%|██████████| 381/381 [00:09<00:00, 38.24it/s]
......
Train... [epoch 41/50, loss 2.25991]: 100%|██████████| 381/381 [00:09<00:00, 39.89it/s]
Train... [epoch 42/50, loss 2.24437]: 100%|██████████| 381/381 [00:09<00:00, 39.72it/s]
Train... [epoch 43/50, loss 2.23779]: 100%|██████████| 381/381 [00:09<00:00, 39.09it/s]
Train... [epoch 44/50, loss 2.25092]: 100%|██████████| 381/381 [00:09<00:00, 39.16it/s]
Train... [epoch 45/50, loss 2.23653]: 100%|██████████| 381/381 [00:09<00:00, 39.90it/s]
Train... [epoch 46/50, loss 2.20175]: 100%|██████████| 381/381 [00:09<00:00, 39.51it/s]
Train... [epoch 47/50, loss 2.22046]: 100%|██████████| 381/381 [00:09<00:00, 39.83it/s]
Train... [epoch 48/50, loss 2.20892]: 100%|██████████| 381/381 [00:09<00:00, 39.84it/s]
Train... [epoch 49/50, loss 2.22276]: 100%|██████████| 381/381 [00:09<00:00, 39.35it/s]
Train... [epoch 50/50, loss 2.20212]: 100%|██████████| 381/381 [00:09<00:00, 39.75it/s]
7. 推理
直接推理
Code:
model.eval()
with torch.no_grad():
    word_ids = tokenizer.encode("清明时节")
    # Everything but the last two ids is src; the penultimate id seeds tgt.
    src = torch.LongTensor([word_ids[:-2]]).to(DEVICE)
    tgt = torch.LongTensor([word_ids[-2:-1]]).to(DEVICE)
    # Greedy decoding: predict one token at a time until <eos> or the
    # 64-step length cap is hit.
    for _ in range(64):
        out = model(src, tgt)
        # Only the last decoder position matters for the next token.
        logits = model.predictor(out[-1:])
        # Pick the highest-scoring vocabulary id.
        next_id = torch.argmax(logits, dim=2)
        # Append the prediction to the running target sequence.
        tgt = torch.cat([tgt, next_id], dim=1)
        # Stop on <eos>.
        if next_id == tokenizer.eos_id:
            break
    src_decode = "".join(w for w in tokenizer.decode(src[0].tolist()) if w != Tokenizer.PAD)
    print(f"src = {src}, src_decode = {src_decode}")
    tgt_decode = "".join(w for w in tokenizer.decode(tgt[0].tolist()) if w != Tokenizer.PAD)
    print(f"tgt = {tgt}, tgt_decode = {tgt_decode}")
def generate_random_poem(tokenizer, model, text):
    """Continue `text` into a poem, or generate one from a random seed.

    Args:
        tokenizer: project tokenizer exposing encode/decode/id_to_token/eos_id.
        model: trained Transformer; the sibling `predict` helper yields
            the next token id from (model, src, tgt).
        text: seed text; when None or empty a random vocabulary
            character is used instead.

    Returns:
        The decoded poem string with padding tokens stripped.
    """
    if not text:
        # Pick a random seed character, skipping special-token ids 0-3.
        # Fix: randint's upper bound is inclusive, so the original
        # randint(4, len(tokenizer)) could produce an id one past the
        # vocabulary (valid ids are 0..len(tokenizer) - 1).
        text = tokenizer.id_to_token(random.randint(4, len(tokenizer) - 1))
    model.eval()
    with torch.no_grad():
        word_ids = tokenizer.encode(text)
        # Everything but the last two ids is src; the penultimate id seeds tgt.
        src = torch.LongTensor([word_ids[:-2]]).to(DEVICE)
        tgt = torch.LongTensor([word_ids[-2:-1]]).to(DEVICE)
        # Greedy decoding: one token at a time until <eos> or 64 steps.
        for _ in range(64):
            y = predict(model, src, tgt)
            # Append the prediction to the running target sequence.
            tgt = torch.cat([tgt, y.view(1, 1)], dim=1)
            # Stop on <eos>.
            if y == tokenizer.eos_id:
                break
        result = torch.cat([src, tgt], dim=1)
        result_decode = "".join(w for w in tokenizer.decode(result[0].tolist()) if w != Tokenizer.PAD)
        return result_decode
# Generate five continuations of "清明" as a quick qualitative check.
for _ in range(0, 5):
    poetry_line = generate_random_poem(tokenizer, model, "清明")
    print(poetry_line)