参考资料:
- PyTorch 官方教程: https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
- 任务: 使用 RNN 将姓名按所属语言(国籍)进行分类
- 数据下载: https://download.pytorch.org/tutorial/data.zip (解压后应得到 ./data/names/*.txt)
导入库
python
import glob # 用于查找符合规则的文件名
import os
import unicodedata
import string
import torch
import torch.nn as nn
import torch.optim as optim
import random
GPU 配置
python
"Device configuration"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
导入数据
python
"导入数据"
def findFiles(path):
return glob.glob(path)
for i in findFiles('./data/names/*.txt'):
print(i)
output
python
./data/names/Korean.txt
./data/names/Irish.txt
./data/names/Portuguese.txt
./data/names/Vietnamese.txt
./data/names/Czech.txt
./data/names/Russian.txt
./data/names/Scottish.txt
./data/names/German.txt
./data/names/Polish.txt
./data/names/Spanish.txt
./data/names/English.txt
./data/names/French.txt
./data/names/Japanese.txt
./data/names/Dutch.txt
./data/names/Greek.txt
./data/names/Chinese.txt
./data/names/Italian.txt
./data/names/Arabic.txt
定义名字中允许出现的所有字符(字母表)
python
# Alphabet: the 52 ASCII letters followed by five punctuation marks (space . , ; ').
all_letters = "".join((string.ascii_letters, " .,;'"))
print(all_letters)  # abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
n_letters = len(all_letters)
print(n_letters)  # 57
将名字中的字符统一转化为 ASCII 形式(去掉重音符号等)
python
# Turn a Unicode string to plain ASCII
# Thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip accents from *s* and keep only characters found in ``all_letters``.

    NFD normalization splits an accented letter into its base letter plus
    combining marks (Unicode category 'Mn'); those marks are dropped, and so
    is every remaining character that is not part of the alphabet.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        if unicodedata.category(ch) == 'Mn':
            continue  # combining accent mark
        if ch not in all_letters:
            continue  # outside the model's alphabet
        kept.append(ch)
    return ''.join(kept)


print(unicodeToAscii('Ślusàrski'))  # Slusarski
category_lines = {}  # language name -> list of names in that language
all_categories = []  # language names, in the order their files were read


def readLines(filename):
    """Return the names stored one per line in the UTF-8 file *filename*.

    Every line is converted with ``unicodeToAscii``. Names that end up empty
    (blank lines, or lines made only of characters outside the alphabet) are
    dropped: an empty name would become a zero-length tensor, and the model
    reads ``rnn_out[-1]``, so it has no last time step to classify.
    """
    # ``with`` closes the file deterministically instead of leaking the handle.
    with open(filename, encoding='utf-8') as f:
        raw_lines = f.read().strip().split('\n')
    names = [unicodeToAscii(line) for line in raw_lines]
    return [name for name in names if name]


# Read every language file; its base name without extension is the category.
for filename in findFiles('./data/names/*.txt'):
    # basename drops the directory, splitext separates the ".txt" suffix
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

# Fail early with an actionable message instead of a KeyError further down.
if not all_categories:
    raise FileNotFoundError(
        "No name files matched './data/names/*.txt'; download and unzip "
        "https://download.pytorch.org/tutorial/data.zip first"
    )

n_categories = len(all_categories)
# 18 languages in total
print(len(category_lines))  # 18
# Inspect which languages were loaded
print(category_lines.keys())
"""
dict_keys(['Korean', 'Irish', 'Portuguese', 'Vietnamese', 'Czech', 'Russian', 'Scottish', 'German', 'Polish',
'Spanish', 'English', 'French', 'Japanese', 'Dutch', 'Greek', 'Chinese', 'Italian', 'Arabic'])
"""
print(category_lines['English'][:5])  # ['Abbas', 'Abbey', 'Abbott', 'Abdi', 'Abel']
名字转化为 Tensor
python
"将Name转为Tensor"
# Find letter index from all_letters, e.g. "a"=0
def letterToIndex(letter):
# 返回字母在字母表中的位置
# abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
return all_letters.find(letter)
# 将整个名字转为<line_length x 1 x n_letters>
def lineToTensor(line):
tensor = torch.zeros(len(line), 1, n_letters)
for li, letter in enumerate(line):
tensor[li][0][letterToIndex(letter)] = 1
return tensor
print(lineToTensor('Bryant').shape) # torch.Size([6, 1, 57])
print(lineToTensor('Mike').shape) # torch.Size([4, 1, 57])
print(lineToTensor('Mike'))
output
python
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.]]])
定义 RNN 分类网络
python
"定义网络"
class RNN_NAME(nn.Module):
def __init__(self, input_size, hidden_size, output_size, layers, batch_size):
super(RNN_NAME, self).__init__()
# 这里的input_size相当于是特征的个数, 这里即是所有字母的个数57
self.input_size = input_size
# 这里相当于是rnn输出的维数
self.hidden_size = hidden_size
# 这里是分类的个数, 这里相当于是18, 一共有18类
self.output_size = output_size
# 这里是rnn的层数
self.layers = layers
# batchsize的个数
self.batch_size = batch_size
self.gru = nn.GRU(input_size = self.input_size, hidden_size = self.hidden_size, num_layers = self.layers)
self.FC = nn.Linear(self.hidden_size, self.output_size)
def init_hidden(self):
return torch.zeros(self.layers, self.batch_size, self.hidden_size).to(device)
def forward(self, x):
self.batch_size = x.size(1)
self.hidden = self.init_hidden() # 初始化memory的内容
rnn_out, self.hidden = self.gru(x,self.hidden)
out = self.FC(rnn_out[-1])
return out
RNN 网络初始化
python
"初始化"
layers = 3
input_size = n_letters
hidden_size = 128
output_size = len(all_categories) # 18类
batch_size = 1 # 因为这里name的长度不相同, 所以将batch_size设为1
# 网络初始化
rnn_name = RNN_NAME(input_size, hidden_size, output_size, layers, batch_size).to(device)
# 测试网络输入输出
input_data = lineToTensor('Mike')
out_data = rnn_name(input_data.to(device))
print(out_data.size())
# torch.Size([1, 18])
print(out_data)
"""
tensor([[ 0.0051, -0.0081, -0.0906, 0.1054, 0.0679, -0.0187, -0.0497, -0.0349,
-0.0466, 0.0004, -0.0099, -0.0940, -0.0201, -0.0694, -0.0542, -0.0807,
0.0005, 0.1373]], grad_fn=<AddmmBackward0>)
"""
准备训练数据集
python
"准备训练数据"
# 打印最后显示的结果
def categoryFromOutput(output):
top_n, top_i = output.topk(1)
category_i = top_i[0].item()
return all_categories[category_i], category_i
# Get a random training example
def randomChoice(l):
    """Return a uniformly random element of the sequence *l*.

    Raises ValueError if *l* is empty.
    """
    position = random.randrange(len(l))
    return l[position]
def randomTrainingExample():
    """Draw one random (language, name) training sample.

    A language is picked uniformly first, then a name uniformly within that
    language, so every language is equally likely regardless of its size.

    Returns:
        ``(category, line, category_tensor, line_tensor)``: the language name,
        the name itself, a 1-element int64 tensor holding the language's
        index in ``all_categories``, and the one-hot encoding of the name
        produced by ``lineToTensor``.
    """
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    label = all_categories.index(category)
    category_tensor = torch.tensor([label], dtype=torch.long)
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor


# Show a few random samples
for i in range(10):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    print('category = ', category, '/ line = ', line)
"""
category = Scottish / line = Miller
category = Irish / line = O'Mooney
category = Dutch / line = Kools
category = French / line = Bouchard
category = French / line = Masson
category = English / line = May
category = Spanish / line = Holguin
category = Portuguese / line = Moreno
category = Greek / line = Nomikos
category = Czech / line = Cernohous
"""
计算预测精度
python
"计算精度"
def get_accuracy(logit, target):
corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
accuracy = 100.0 * corrects/batch_size
return accuracy.item()
网络训练
python
"训练"
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_name.parameters(),lr=0.005)
N_EPHOCS = 100
# 共有20074个数据
n_iters = 2000 # 训练n_iters个名字
print_every = 200 # 每print_every轮打印一次结果, 并传递一次误差, 更新系数
for epoch in range(N_EPHOCS):
train_running_loss = 0.0
loss = torch.tensor([0.0]).float().to(device)
train_count = 0
rnn_name.train()
# trainging round
for iter in range(1, n_iters+1):
# 获取样本
category, line, category_tensor, line_tensor = randomTrainingExample()
category_tensor = category_tensor.to(device)
line_tensor = line_tensor.to(device)
# 进行训练
optimizer.zero_grad()
# reset hidden states
rnn_name.hidden = rnn_name.init_hidden()
# forward+backward+optimize
outputs = rnn_name(line_tensor)
guess, guess_i = categoryFromOutput(outputs)
if guess == category:
train_count = train_count + 1
# print(iter)
loss = loss + criterion(outputs, category_tensor)
if iter % print_every == 0:
correct = '✓' if guess == category else '✗ ({})'.format(category)
print('{} /{} {}'.format(line, guess, correct))
# 进行反向传播(print_every次传一次误差)
loss.backward()
optimizer.step()
train_running_loss = train_running_loss + loss.detach().item()
loss = torch.tensor([0.0]).float().to(device)
# 查看准确率
train_acc = train_count / n_iters * 100
print('=======')
print('Epoch : {:0>2d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, train_running_loss/n_iters, train_acc))
print('=======')
"验证"
test_count = 0
test_num = 500 # 测试500个名字
rnn_name.eval()
for test_iter in range(1, test_num+1):
category, line, category_tensor, line_tensor = randomTrainingExample()
outputs = rnn_name(line_tensor.to(device))
guess, guess_i = categoryFromOutput(outputs)
if guess == category:
test_count = test_count + 1
print('Test Accuracy : {:<6.4f}%'.format(test_count/test_num*100))
"""
=======
Epoch : 71 | Loss : 0.2698 | Train Accuracy : 92.35 %
=======
Konda /Japanese ✓
Seighin /Irish ✓
Sedmik /Czech ✓
Jestkov /Russian ✓
So /Korean ✓
Laar /Dutch ✓
Mazza /Italian ✓
Abl /Czech ✓
Doyle /Irish ✓
Skwor /Czech ✓
=======
Epoch : 72 | Loss : 0.2583 | Train Accuracy : 91.40 %
=======
Martoyas /Greek ✗ (Russian)
Hallett /English ✓
Bermudez /Spanish ✓
Bonhomme /French ✓
Lam /Chinese ✗ (Vietnamese)
Kourempes /Greek ✓
Matsuki /Japanese ✓
Sokolofsky /Polish ✓
Kuipers /Dutch ✓
Aller /Dutch ✓
=======
Epoch : 73 | Loss : 0.2564 | Train Accuracy : 91.65 %
=======
Xiao /Chinese ✓
Lian /Chinese ✓
Prill /Czech ✓
Glatter /Czech ✓
Bureau /French ✓
Vallance /French ✗ (English)
Johnstone /Scottish ✓
Kassab /Arabic ✓
Atiyeh /Arabic ✓
Luong /Vietnamese ✓
=======
Epoch : 74 | Loss : 0.2610 | Train Accuracy : 91.60 %
=======
Poirier /French ✓
Hahn /German ✓
Botros /Arabic ✓
Adrichem /Dutch ✓
Rios /Portuguese ✗ (Spanish)
Zangari /Italian ✓
Mccallum /Scottish ✓
Ferro /Portuguese ✓
Kim /Vietnamese ✗ (Korean)
Chang /Korean ✓
=======
Epoch : 75 | Loss : 0.2792 | Train Accuracy : 90.85 %
=======
Monte /Italian ✓
Han /Vietnamese ✗ (Korean)
Ying /Chinese ✓
O'Driscoll /Irish ✓
Poirier /French ✓
Tsuruya /Japanese ✓
Ralph /Czech ✗ (English)
Bonheur /French ✓
Romagna /Italian ✓
Krakowski /Polish ✓
=======
Epoch : 76 | Loss : 0.3000 | Train Accuracy : 90.85 %
=======
Fay /French ✓
Plisek /Czech ✓
Travere /French ✓
Moon /Korean ✓
Aalst /Dutch ✓
Sassa /Japanese ✓
Holzmann /German ✓
Kruessel /Czech ✓
Truong /Vietnamese ✓
Kassab /Arabic ✓
=======
Epoch : 77 | Loss : 0.2602 | Train Accuracy : 92.00 %
=======
Janda /Polish ✓
Doan /Vietnamese ✓
O'Reilly /Irish ✓
Torres /Portuguese ✓
Esipovich /Russian ✓
Johnston /Scottish ✓
Viteri /Spanish ✗ (Italian)
Vandroogenbroeck /Dutch ✓
Montero /Spanish ✓
Wan /Chinese ✓
=======
Epoch : 78 | Loss : 0.2460 | Train Accuracy : 92.35 %
=======
Ha /Korean ✗ (Vietnamese)
Christie /Scottish ✓
Pereira /Portuguese ✓
Kuiper /Dutch ✓
Bautista /Spanish ✓
Takewaki /Japanese ✓
Ying /Chinese ✓
Sokoloff /Polish ✓
Raghailligh /Irish ✓
Edgell /English ✓
=======
Epoch : 79 | Loss : 0.2640 | Train Accuracy : 91.60 %
=======
Neal /English ✓
Kafka /Czech ✓
Hay /Scottish ✓
Pelletier /French ✓
Wilchek /Czech ✓
Marchesi /Italian ✓
Michel /Spanish ✓
Ling /Chinese ✓
Zou /Chinese ✓
Freitas /Portuguese ✓
=======
Epoch : 80 | Loss : 0.2579 | Train Accuracy : 92.60 %
=======
Kelly /Scottish ✓
Chou /Korean ✓
Unsworth /English ✓
Larue /French ✓
Tron /Vietnamese ✓
Herodes /Czech ✓
Fremut /Czech ✓
Thai /Vietnamese ✓
Clark /Irish ✗ (Scottish)
Halenkov /Russian ✓
=======
Epoch : 81 | Loss : 0.2453 | Train Accuracy : 92.40 %
=======
Than /Vietnamese ✓
Pfeifer /Czech ✓
Kennedy /Scottish ✓
Klerkx /Dutch ✓
Freitas /Portuguese ✓
Miller /Scottish ✓
Deeb /Arabic ✓
Almetov /Russian ✓
Li /Korean ✓
Perez /Spanish ✓
=======
Epoch : 82 | Loss : 0.2579 | Train Accuracy : 91.85 %
=======
Horiatis /Greek ✓
Clowes /English ✓
Santos /Portuguese ✓
Mochan /Irish ✓
Sokolof /Polish ✓
Ogterop /Dutch ✓
Schwartz /Czech ✓
Abadi /Arabic ✓
Winogrodzki /Polish ✓
Bazzhin /Russian ✓
=======
Epoch : 83 | Loss : 0.2397 | Train Accuracy : 92.25 %
=======
Ybarra /Spanish ✓
Jacobson /English ✓
Shi /Chinese ✓
Berg /Dutch ✓
Araya /Spanish ✓
Banos /Greek ✓
Winograd /Polish ✓
Schneider /German ✗ (Czech)
Calligaris /Italian ✓
Hasbulatov /Russian ✓
=======
Epoch : 84 | Loss : 0.2466 | Train Accuracy : 92.25 %
=======
Felkerzam /Polish ✗ (Russian)
Gomolka /Polish ✓
Belo /Portuguese ✓
Santos /Portuguese ✓
Rim /Korean ✓
Solomon /French ✓
Agnellutti /Italian ✓
Santana /Portuguese ✓
Ramaker /Dutch ✓
Palmisano /Italian ✓
=======
Epoch : 85 | Loss : 0.2390 | Train Accuracy : 92.80 %
=======
Moraitopoulos /Greek ✓
Rosenberger /German ✓
Torres /Portuguese ✓
Sabbagh /Arabic ✓
Aggelen /Dutch ✓
Chieu /Chinese ✓
Ukhtomsky /Russian ✓
Kouros /Greek ✓
Portyansky /Russian ✓
Dinh /Vietnamese ✓
=======
Epoch : 86 | Loss : 0.2173 | Train Accuracy : 93.15 %
=======
Jaskulski /Polish ✓
Mendez /Spanish ✓
Smyth /English ✓
Mo /Korean ✓
Seegers /Dutch ✓
Elford /English ✓
Lau /Chinese ✓
Sargent /French ✓
Zeman /Czech ✓
Ohme /German ✓
=======
Epoch : 87 | Loss : 0.2189 | Train Accuracy : 93.00 %
=======
Gwang /Korean ✓
Zhernakov /Russian ✓
Quraishi /Arabic ✓
Asker /Arabic ✓
Franco /Portuguese ✓
Bacon /Czech ✓
Weineltk /Czech ✓
Thuy /Vietnamese ✓
Close /Greek ✓
Xiang /Chinese ✓
=======
Epoch : 88 | Loss : 0.2497 | Train Accuracy : 91.95 %
=======
Basara /Arabic ✓
Esteves /Portuguese ✓
Metz /German ✓
Lopez /Spanish ✓
Bukoski /Polish ✓
Rao /Italian ✓
Klerk /Dutch ✓
Espinoza /Spanish ✓
Wehunt /German ✓
Giehl /German ✓
=======
Epoch : 89 | Loss : 0.2454 | Train Accuracy : 92.70 %
=======
Deniau /French ✓
Bover /Spanish ✗ (Italian)
Soho /Japanese ✓
Ron /Korean ✓
Dolezal /Czech ✓
Yamamura /Japanese ✓
Ramsay /Scottish ✓
Fromm /German ✓
Huerta /Spanish ✓
Colman /Irish ✓
=======
Epoch : 90 | Loss : 0.2432 | Train Accuracy : 92.55 %
=======
Fei /Chinese ✓
Dinh /Vietnamese ✓
Krawiec /Czech ✓
Urbanek /Czech ✓
Yan /Chinese ✓
Giehl /German ✓
Riain /Irish ✓
Zhestkov /Russian ✓
Almetiev /Russian ✓
Deniel /French ✓
=======
Epoch : 91 | Loss : 0.2318 | Train Accuracy : 92.45 %
=======
Macleod /Scottish ✓
Savatier /French ✓
Chong /Chinese ✗ (Korean)
Prchal /Czech ✓
Tsoumada /Greek ✓
Agosti /Italian ✓
Mojjis /Czech ✓
Mitsui /Japanese ✓
San nicolas /Spanish ✓
Kokoris /Greek ✓
=======
Epoch : 92 | Loss : 0.2163 | Train Accuracy : 92.25 %
=======
Marubeni /Japanese ✓
Suh /Korean ✓
Klerks /Dutch ✓
Roma /Spanish ✓
Lac /Vietnamese ✓
Khoury /Arabic ✓
Hoefler /German ✓
Gibson /Scottish ✓
Jasso /Spanish ✓
Oirschotten /Dutch ✓
=======
Epoch : 93 | Loss : 0.2329 | Train Accuracy : 92.90 %
=======
Kikuchi /Japanese ✓
Blecher /German ✓
Ying /Chinese ✓
Nahas /Arabic ✓
Araujo /Portuguese ✓
Koemans /Dutch ✓
Durant /French ✓
Martzenyuk /Russian ✓
Petru /Czech ✓
Wood /Scottish ✓
=======
Epoch : 94 | Loss : 0.2080 | Train Accuracy : 93.35 %
=======
Vespa /Italian ✓
Kowalczyk /Polish ✓
Nahas /Arabic ✓
Tresler /German ✓
Caivano /Italian ✓
Pho /Vietnamese ✓
Dael /Dutch ✓
Heidl /Czech ✓
Totolos /Greek ✓
Adlersflugel /German ✓
=======
Epoch : 95 | Loss : 0.2400 | Train Accuracy : 92.45 %
=======
Zwolenksy /Czech ✓
Chermak /Czech ✓
Sgro /Italian ✓
Lestrange /French ✓
Hwang /Korean ✓
Barros /Spanish ✗ (Portuguese)
Boyle /Scottish ✓
Schmeling /German ✓
Poplawski /Polish ✓
Pantelas /Greek ✓
=======
Epoch : 96 | Loss : 0.2150 | Train Accuracy : 93.10 %
=======
Asher /English ✓
Sai /Vietnamese ✓
Hakimi /Arabic ✓
Sheng /Chinese ✓
Lebedevich /Russian ✓
Castellano /Spanish ✓
Battaglia /Italian ✓
Cardozo /Portuguese ✓
Sinagra /Italian ✓
Bumgarner /German ✓
=======
Epoch : 97 | Loss : 0.2359 | Train Accuracy : 92.80 %
=======
Lyes /English ✓
Do /Vietnamese ✓
Busch /German ✓
Shimura /Japanese ✓
O'Connor /Irish ✓
Geiger /German ✓
Reier /German ✓
Paris /French ✓
Sokolof /Polish ✓
Limarev /Russian ✓
=======
Epoch : 98 | Loss : 0.2369 | Train Accuracy : 92.55 %
=======
Osladil /Czech ✓
Roy /French ✓
Ortega /Spanish ✓
Ruzzier /Italian ✓
Chou /Korean ✓
Mateus /Portuguese ✓
Ibanez /Spanish ✓
Cabello /Spanish ✓
Duguay /French ✓
Yi /Korean ✓
=======
Epoch : 99 | Loss : 0.2192 | Train Accuracy : 92.65 %
=======
Test Accuracy : 94.8000%
"""
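最终测试精度为 94.8%，略低于 95%。注意: 测试样本同样由 randomTrainingExample() 从全部(即训练用)数据中随机抽取，因此该数值反映的是模型对训练数据的拟合程度，而不是对未见过名字的泛化能力。
下面演示如何用训练好的模型对单个名字进行预测。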
python
"单独测试"
input_data = lineToTensor('WMN').to(device)
out_data = rnn_name(input_data)
print(out_data)
print(categoryFromOutput(out_data))
"""
tensor([[ 5.7034, 1.0659, -8.1210, 0.5218, -2.5818, 0.1362, 7.4420, 3.9547,
1.0824, -7.2981, 4.0036, -2.2944, 1.6835, -3.5377, -7.6178, 9.8920,
-0.5789, -1.0545]], device='cuda:0', grad_fn=<AddmmBackward0>)
('Chinese', 15)
"""
input_data = lineToTensor('Huang').to(device)
out_data = rnn_name(input_data)
print(out_data)
print(categoryFromOutput(out_data))
"""
tensor([[ 4.1755, -1.2160, -3.8454, 5.4587, -4.6412, 0.8137, 1.2896, 1.7429,
-1.3145, -2.1634, 1.0224, -2.7052, 2.1592, 1.1909, -8.0874, 10.2988,
-1.9252, -0.0729]], device='cuda:0', grad_fn=<AddmmBackward0>)
('Chinese', 15)
"""
完整代码
python
import glob # 用于查找符合规则的文件名
import os
import unicodedata
import string
import torch
import torch.nn as nn
import torch.optim as optim
import random
"Device configuration"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
"导入数据"
def findFiles(path):
return glob.glob(path)
for i in findFiles('./data/names/*.txt'):
print(i)
"""
./data/names/Korean.txt
./data/names/Irish.txt
./data/names/Portuguese.txt
./data/names/Vietnamese.txt
./data/names/Czech.txt
./data/names/Russian.txt
./data/names/Scottish.txt
./data/names/German.txt
./data/names/Polish.txt
./data/names/Spanish.txt
./data/names/English.txt
./data/names/French.txt
./data/names/Japanese.txt
./data/names/Dutch.txt
./data/names/Greek.txt
./data/names/Chinese.txt
./data/names/Italian.txt
./data/names/Arabic.txt
"""
# Alphabet: the 52 ASCII letters followed by five punctuation marks (space . , ; ').
all_letters = "".join((string.ascii_letters, " .,;'"))
print(all_letters)  # abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
n_letters = len(all_letters)
print(n_letters)  # 57
# Turn a Unicode string to plain ASCII
# Thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip accents from *s* and keep only characters found in ``all_letters``.

    NFD normalization splits an accented letter into its base letter plus
    combining marks (Unicode category 'Mn'); those marks are dropped, and so
    is every remaining character that is not part of the alphabet.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        if unicodedata.category(ch) == 'Mn':
            continue  # combining accent mark
        if ch not in all_letters:
            continue  # outside the model's alphabet
        kept.append(ch)
    return ''.join(kept)


print(unicodeToAscii('Ślusàrski'))  # Slusarski
category_lines = {}  # language name -> list of names in that language
all_categories = []  # language names, in the order their files were read


def readLines(filename):
    """Return the names stored one per line in the UTF-8 file *filename*.

    Every line is converted with ``unicodeToAscii``. Names that end up empty
    (blank lines, or lines made only of characters outside the alphabet) are
    dropped: an empty name would become a zero-length tensor, and the model
    reads ``rnn_out[-1]``, so it has no last time step to classify.
    """
    # ``with`` closes the file deterministically instead of leaking the handle.
    with open(filename, encoding='utf-8') as f:
        raw_lines = f.read().strip().split('\n')
    names = [unicodeToAscii(line) for line in raw_lines]
    return [name for name in names if name]


# Read every language file; its base name without extension is the category.
for filename in findFiles('./data/names/*.txt'):
    # basename drops the directory, splitext separates the ".txt" suffix
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

# Fail early with an actionable message instead of a KeyError further down.
if not all_categories:
    raise FileNotFoundError(
        "No name files matched './data/names/*.txt'; download and unzip "
        "https://download.pytorch.org/tutorial/data.zip first"
    )

n_categories = len(all_categories)
# 18 languages in total
print(len(category_lines))  # 18
# Inspect which languages were loaded
print(category_lines.keys())
"""
dict_keys(['Korean', 'Irish', 'Portuguese', 'Vietnamese', 'Czech', 'Russian', 'Scottish', 'German', 'Polish',
'Spanish', 'English', 'French', 'Japanese', 'Dutch', 'Greek', 'Chinese', 'Italian', 'Arabic'])
"""
print(category_lines['English'][:5])  # ['Abbas', 'Abbey', 'Abbott', 'Abdi', 'Abel']
"将Name转为Tensor"
# Find letter index from all_letters, e.g. "a"=0
def letterToIndex(letter):
# 返回字母在字母表中的位置
# abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,;'
return all_letters.find(letter)
# 将整个名字转为<line_length x 1 x n_letters>
def lineToTensor(line):
tensor = torch.zeros(len(line), 1, n_letters)
for li, letter in enumerate(line):
tensor[li][0][letterToIndex(letter)] = 1
return tensor
print(lineToTensor('Bryant').shape) # torch.Size([6, 1, 57])
print(lineToTensor('Mike').shape) # torch.Size([4, 1, 57])
print(lineToTensor('Mike'))
"""
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0.]]])
"""
"定义网络"
class RNN_NAME(nn.Module):
def __init__(self, input_size, hidden_size, output_size, layers, batch_size):
super(RNN_NAME, self).__init__()
# 这里的input_size相当于是特征的个数, 这里即是所有字母的个数57
self.input_size = input_size
# 这里相当于是rnn输出的维数
self.hidden_size = hidden_size
# 这里是分类的个数, 这里相当于是18, 一共有18类
self.output_size = output_size
# 这里是rnn的层数
self.layers = layers
# batchsize的个数
self.batch_size = batch_size
self.gru = nn.GRU(input_size = self.input_size, hidden_size = self.hidden_size, num_layers = self.layers)
self.FC = nn.Linear(self.hidden_size, self.output_size)
def init_hidden(self):
return torch.zeros(self.layers, self.batch_size, self.hidden_size).to(device)
def forward(self, x):
self.batch_size = x.size(1)
self.hidden = self.init_hidden() # 初始化memory的内容
rnn_out, self.hidden = self.gru(x,self.hidden)
out = self.FC(rnn_out[-1])
return out
"初始化"
layers = 3
input_size = n_letters
hidden_size = 128
output_size = len(all_categories) # 18类
batch_size = 1 # 因为这里name的长度不相同, 所以将batch_size设为1
# 网络初始化
rnn_name = RNN_NAME(input_size, hidden_size, output_size, layers, batch_size).to(device)
# 测试网络输入输出
input_data = lineToTensor('Mike')
out_data = rnn_name(input_data.to(device))
print(out_data.size())
# torch.Size([1, 18])
print(out_data)
"""
tensor([[ 0.0051, -0.0081, -0.0906, 0.1054, 0.0679, -0.0187, -0.0497, -0.0349,
-0.0466, 0.0004, -0.0099, -0.0940, -0.0201, -0.0694, -0.0542, -0.0807,
0.0005, 0.1373]], grad_fn=<AddmmBackward0>)
"""
"准备训练数据"
# 打印最后显示的结果
def categoryFromOutput(output):
top_n, top_i = output.topk(1)
category_i = top_i[0].item()
return all_categories[category_i], category_i
# Get a random training example
def randomChoice(l):
    """Return a uniformly random element of the sequence *l*.

    Raises ValueError if *l* is empty.
    """
    position = random.randrange(len(l))
    return l[position]
def randomTrainingExample():
    """Draw one random (language, name) training sample.

    A language is picked uniformly first, then a name uniformly within that
    language, so every language is equally likely regardless of its size.

    Returns:
        ``(category, line, category_tensor, line_tensor)``: the language name,
        the name itself, a 1-element int64 tensor holding the language's
        index in ``all_categories``, and the one-hot encoding of the name
        produced by ``lineToTensor``.
    """
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    label = all_categories.index(category)
    category_tensor = torch.tensor([label], dtype=torch.long)
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor


# Show a few random samples
for i in range(10):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    print('category = ', category, '/ line = ', line)
"""
category = Scottish / line = Miller
category = Irish / line = O'Mooney
category = Dutch / line = Kools
category = French / line = Bouchard
category = French / line = Masson
category = English / line = May
category = Spanish / line = Holguin
category = Portuguese / line = Moreno
category = Greek / line = Nomikos
category = Czech / line = Cernohous
"""
"计算精度"
def get_accuracy(logit, target):
corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
accuracy = 100.0 * corrects/batch_size
return accuracy.item()
"训练"
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn_name.parameters(),lr=0.005)
N_EPHOCS = 100
# 共有20074个数据
n_iters = 2000 # 训练n_iters个名字
print_every = 200 # 每print_every轮打印一次结果, 并传递一次误差, 更新系数
for epoch in range(N_EPHOCS):
train_running_loss = 0.0
loss = torch.tensor([0.0]).float().to(device)
train_count = 0
rnn_name.train()
# trainging round
for iter in range(1, n_iters+1):
# 获取样本
category, line, category_tensor, line_tensor = randomTrainingExample()
category_tensor = category_tensor.to(device)
line_tensor = line_tensor.to(device)
# 进行训练
optimizer.zero_grad()
# reset hidden states
rnn_name.hidden = rnn_name.init_hidden()
# forward+backward+optimize
outputs = rnn_name(line_tensor)
guess, guess_i = categoryFromOutput(outputs)
if guess == category:
train_count = train_count + 1
# print(iter)
loss = loss + criterion(outputs, category_tensor)
if iter % print_every == 0:
correct = '✓' if guess == category else '✗ ({})'.format(category)
print('{} /{} {}'.format(line, guess, correct))
# 进行反向传播(print_every次传一次误差)
loss.backward()
optimizer.step()
train_running_loss = train_running_loss + loss.detach().item()
loss = torch.tensor([0.0]).float().to(device)
# 查看准确率
train_acc = train_count / n_iters * 100
print('=======')
print('Epoch : {:0>2d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, train_running_loss/n_iters, train_acc))
print('=======')
"验证"
test_count = 0
test_num = 500 # 测试500个名字
rnn_name.eval()
for test_iter in range(1, test_num+1):
category, line, category_tensor, line_tensor = randomTrainingExample()
outputs = rnn_name(line_tensor.to(device))
guess, guess_i = categoryFromOutput(outputs)
if guess == category:
test_count = test_count + 1
print('Test Accuracy : {:<6.4f}%'.format(test_count/test_num*100))
"""
=======
Epoch : 71 | Loss : 0.2698 | Train Accuracy : 92.35 %
=======
Konda /Japanese ✓
Seighin /Irish ✓
Sedmik /Czech ✓
Jestkov /Russian ✓
So /Korean ✓
Laar /Dutch ✓
Mazza /Italian ✓
Abl /Czech ✓
Doyle /Irish ✓
Skwor /Czech ✓
=======
Epoch : 72 | Loss : 0.2583 | Train Accuracy : 91.40 %
=======
Martoyas /Greek ✗ (Russian)
Hallett /English ✓
Bermudez /Spanish ✓
Bonhomme /French ✓
Lam /Chinese ✗ (Vietnamese)
Kourempes /Greek ✓
Matsuki /Japanese ✓
Sokolofsky /Polish ✓
Kuipers /Dutch ✓
Aller /Dutch ✓
=======
Epoch : 73 | Loss : 0.2564 | Train Accuracy : 91.65 %
=======
Xiao /Chinese ✓
Lian /Chinese ✓
Prill /Czech ✓
Glatter /Czech ✓
Bureau /French ✓
Vallance /French ✗ (English)
Johnstone /Scottish ✓
Kassab /Arabic ✓
Atiyeh /Arabic ✓
Luong /Vietnamese ✓
=======
Epoch : 74 | Loss : 0.2610 | Train Accuracy : 91.60 %
=======
Poirier /French ✓
Hahn /German ✓
Botros /Arabic ✓
Adrichem /Dutch ✓
Rios /Portuguese ✗ (Spanish)
Zangari /Italian ✓
Mccallum /Scottish ✓
Ferro /Portuguese ✓
Kim /Vietnamese ✗ (Korean)
Chang /Korean ✓
=======
Epoch : 75 | Loss : 0.2792 | Train Accuracy : 90.85 %
=======
Monte /Italian ✓
Han /Vietnamese ✗ (Korean)
Ying /Chinese ✓
O'Driscoll /Irish ✓
Poirier /French ✓
Tsuruya /Japanese ✓
Ralph /Czech ✗ (English)
Bonheur /French ✓
Romagna /Italian ✓
Krakowski /Polish ✓
=======
Epoch : 76 | Loss : 0.3000 | Train Accuracy : 90.85 %
=======
Fay /French ✓
Plisek /Czech ✓
Travere /French ✓
Moon /Korean ✓
Aalst /Dutch ✓
Sassa /Japanese ✓
Holzmann /German ✓
Kruessel /Czech ✓
Truong /Vietnamese ✓
Kassab /Arabic ✓
=======
Epoch : 77 | Loss : 0.2602 | Train Accuracy : 92.00 %
=======
Janda /Polish ✓
Doan /Vietnamese ✓
O'Reilly /Irish ✓
Torres /Portuguese ✓
Esipovich /Russian ✓
Johnston /Scottish ✓
Viteri /Spanish ✗ (Italian)
Vandroogenbroeck /Dutch ✓
Montero /Spanish ✓
Wan /Chinese ✓
=======
Epoch : 78 | Loss : 0.2460 | Train Accuracy : 92.35 %
=======
Ha /Korean ✗ (Vietnamese)
Christie /Scottish ✓
Pereira /Portuguese ✓
Kuiper /Dutch ✓
Bautista /Spanish ✓
Takewaki /Japanese ✓
Ying /Chinese ✓
Sokoloff /Polish ✓
Raghailligh /Irish ✓
Edgell /English ✓
=======
Epoch : 79 | Loss : 0.2640 | Train Accuracy : 91.60 %
=======
Neal /English ✓
Kafka /Czech ✓
Hay /Scottish ✓
Pelletier /French ✓
Wilchek /Czech ✓
Marchesi /Italian ✓
Michel /Spanish ✓
Ling /Chinese ✓
Zou /Chinese ✓
Freitas /Portuguese ✓
=======
Epoch : 80 | Loss : 0.2579 | Train Accuracy : 92.60 %
=======
Kelly /Scottish ✓
Chou /Korean ✓
Unsworth /English ✓
Larue /French ✓
Tron /Vietnamese ✓
Herodes /Czech ✓
Fremut /Czech ✓
Thai /Vietnamese ✓
Clark /Irish ✗ (Scottish)
Halenkov /Russian ✓
=======
Epoch : 81 | Loss : 0.2453 | Train Accuracy : 92.40 %
=======
Than /Vietnamese ✓
Pfeifer /Czech ✓
Kennedy /Scottish ✓
Klerkx /Dutch ✓
Freitas /Portuguese ✓
Miller /Scottish ✓
Deeb /Arabic ✓
Almetov /Russian ✓
Li /Korean ✓
Perez /Spanish ✓
=======
Epoch : 82 | Loss : 0.2579 | Train Accuracy : 91.85 %
=======
Horiatis /Greek ✓
Clowes /English ✓
Santos /Portuguese ✓
Mochan /Irish ✓
Sokolof /Polish ✓
Ogterop /Dutch ✓
Schwartz /Czech ✓
Abadi /Arabic ✓
Winogrodzki /Polish ✓
Bazzhin /Russian ✓
=======
Epoch : 83 | Loss : 0.2397 | Train Accuracy : 92.25 %
=======
Ybarra /Spanish ✓
Jacobson /English ✓
Shi /Chinese ✓
Berg /Dutch ✓
Araya /Spanish ✓
Banos /Greek ✓
Winograd /Polish ✓
Schneider /German ✗ (Czech)
Calligaris /Italian ✓
Hasbulatov /Russian ✓
=======
Epoch : 84 | Loss : 0.2466 | Train Accuracy : 92.25 %
=======
Felkerzam /Polish ✗ (Russian)
Gomolka /Polish ✓
Belo /Portuguese ✓
Santos /Portuguese ✓
Rim /Korean ✓
Solomon /French ✓
Agnellutti /Italian ✓
Santana /Portuguese ✓
Ramaker /Dutch ✓
Palmisano /Italian ✓
=======
Epoch : 85 | Loss : 0.2390 | Train Accuracy : 92.80 %
=======
Moraitopoulos /Greek ✓
Rosenberger /German ✓
Torres /Portuguese ✓
Sabbagh /Arabic ✓
Aggelen /Dutch ✓
Chieu /Chinese ✓
Ukhtomsky /Russian ✓
Kouros /Greek ✓
Portyansky /Russian ✓
Dinh /Vietnamese ✓
=======
Epoch : 86 | Loss : 0.2173 | Train Accuracy : 93.15 %
=======
Jaskulski /Polish ✓
Mendez /Spanish ✓
Smyth /English ✓
Mo /Korean ✓
Seegers /Dutch ✓
Elford /English ✓
Lau /Chinese ✓
Sargent /French ✓
Zeman /Czech ✓
Ohme /German ✓
=======
Epoch : 87 | Loss : 0.2189 | Train Accuracy : 93.00 %
=======
Gwang /Korean ✓
Zhernakov /Russian ✓
Quraishi /Arabic ✓
Asker /Arabic ✓
Franco /Portuguese ✓
Bacon /Czech ✓
Weineltk /Czech ✓
Thuy /Vietnamese ✓
Close /Greek ✓
Xiang /Chinese ✓
=======
Epoch : 88 | Loss : 0.2497 | Train Accuracy : 91.95 %
=======
Basara /Arabic ✓
Esteves /Portuguese ✓
Metz /German ✓
Lopez /Spanish ✓
Bukoski /Polish ✓
Rao /Italian ✓
Klerk /Dutch ✓
Espinoza /Spanish ✓
Wehunt /German ✓
Giehl /German ✓
=======
Epoch : 89 | Loss : 0.2454 | Train Accuracy : 92.70 %
=======
Deniau /French ✓
Bover /Spanish ✗ (Italian)
Soho /Japanese ✓
Ron /Korean ✓
Dolezal /Czech ✓
Yamamura /Japanese ✓
Ramsay /Scottish ✓
Fromm /German ✓
Huerta /Spanish ✓
Colman /Irish ✓
=======
Epoch : 90 | Loss : 0.2432 | Train Accuracy : 92.55 %
=======
Fei /Chinese ✓
Dinh /Vietnamese ✓
Krawiec /Czech ✓
Urbanek /Czech ✓
Yan /Chinese ✓
Giehl /German ✓
Riain /Irish ✓
Zhestkov /Russian ✓
Almetiev /Russian ✓
Deniel /French ✓
=======
Epoch : 91 | Loss : 0.2318 | Train Accuracy : 92.45 %
=======
Macleod /Scottish ✓
Savatier /French ✓
Chong /Chinese ✗ (Korean)
Prchal /Czech ✓
Tsoumada /Greek ✓
Agosti /Italian ✓
Mojjis /Czech ✓
Mitsui /Japanese ✓
San nicolas /Spanish ✓
Kokoris /Greek ✓
=======
Epoch : 92 | Loss : 0.2163 | Train Accuracy : 92.25 %
=======
Marubeni /Japanese ✓
Suh /Korean ✓
Klerks /Dutch ✓
Roma /Spanish ✓
Lac /Vietnamese ✓
Khoury /Arabic ✓
Hoefler /German ✓
Gibson /Scottish ✓
Jasso /Spanish ✓
Oirschotten /Dutch ✓
=======
Epoch : 93 | Loss : 0.2329 | Train Accuracy : 92.90 %
=======
Kikuchi /Japanese ✓
Blecher /German ✓
Ying /Chinese ✓
Nahas /Arabic ✓
Araujo /Portuguese ✓
Koemans /Dutch ✓
Durant /French ✓
Martzenyuk /Russian ✓
Petru /Czech ✓
Wood /Scottish ✓
=======
Epoch : 94 | Loss : 0.2080 | Train Accuracy : 93.35 %
=======
Vespa /Italian ✓
Kowalczyk /Polish ✓
Nahas /Arabic ✓
Tresler /German ✓
Caivano /Italian ✓
Pho /Vietnamese ✓
Dael /Dutch ✓
Heidl /Czech ✓
Totolos /Greek ✓
Adlersflugel /German ✓
=======
Epoch : 95 | Loss : 0.2400 | Train Accuracy : 92.45 %
=======
Zwolenksy /Czech ✓
Chermak /Czech ✓
Sgro /Italian ✓
Lestrange /French ✓
Hwang /Korean ✓
Barros /Spanish ✗ (Portuguese)
Boyle /Scottish ✓
Schmeling /German ✓
Poplawski /Polish ✓
Pantelas /Greek ✓
=======
Epoch : 96 | Loss : 0.2150 | Train Accuracy : 93.10 %
=======
Asher /English ✓
Sai /Vietnamese ✓
Hakimi /Arabic ✓
Sheng /Chinese ✓
Lebedevich /Russian ✓
Castellano /Spanish ✓
Battaglia /Italian ✓
Cardozo /Portuguese ✓
Sinagra /Italian ✓
Bumgarner /German ✓
=======
Epoch : 97 | Loss : 0.2359 | Train Accuracy : 92.80 %
=======
Lyes /English ✓
Do /Vietnamese ✓
Busch /German ✓
Shimura /Japanese ✓
O'Connor /Irish ✓
Geiger /German ✓
Reier /German ✓
Paris /French ✓
Sokolof /Polish ✓
Limarev /Russian ✓
=======
Epoch : 98 | Loss : 0.2369 | Train Accuracy : 92.55 %
=======
Osladil /Czech ✓
Roy /French ✓
Ortega /Spanish ✓
Ruzzier /Italian ✓
Chou /Korean ✓
Mateus /Portuguese ✓
Ibanez /Spanish ✓
Cabello /Spanish ✓
Duguay /French ✓
Yi /Korean ✓
=======
Epoch : 99 | Loss : 0.2192 | Train Accuracy : 92.65 %
=======
Test Accuracy : 94.8000%
"""
"单独测试"
input_data = lineToTensor('WMN').to(device)
out_data = rnn_name(input_data)
print(out_data)
print(categoryFromOutput(out_data))
"""
tensor([[ 5.7034, 1.0659, -8.1210, 0.5218, -2.5818, 0.1362, 7.4420, 3.9547,
1.0824, -7.2981, 4.0036, -2.2944, 1.6835, -3.5377, -7.6178, 9.8920,
-0.5789, -1.0545]], device='cuda:0', grad_fn=<AddmmBackward0>)
('Chinese', 15)
"""
input_data = lineToTensor('Huang').to(device)
out_data = rnn_name(input_data)
print(out_data)
print(categoryFromOutput(out_data))
"""
tensor([[ 4.1755, -1.2160, -3.8454, 5.4587, -4.6412, 0.8137, 1.2896, 1.7429,
-1.3145, -2.1634, 1.0224, -2.7052, 2.1592, 1.1909, -8.0874, 10.2988,
-1.9252, -0.0729]], device='cuda:0', grad_fn=<AddmmBackward0>)
('Chinese', 15)
"""