
- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
博主简介:努力学习的22级本科生一枚 🌟;探索AI算法,C++,go语言的世界;在迷茫中寻找光芒🌸
博客主页:羊小猪~~-CSDN博客
内容简介:这一篇是NLP的入门项目,以AG_NEW新闻数据为例。
🌸箴言🌸:去寻找理想的"天空""之城
上一篇内容:【NLP入门系列三】NLP文本嵌入(以Embedding和EmbeddingBag为例)-CSDN博客💁💁💁💁: 如果在conda安装环境,由于nlp的核心包是torchtext,所以如果把握不好就重新创建一虚拟环境(小编的"难忘"经历)
文章目录
🤔 思路

1、准备
AG News 数据集 (也叫 AG's Corpus or AG News Dataset),这是一个广泛用于自然语言处理(NLP)任务中的文本分类数据集。
基本信息:
- 全称:AG News
- 来源:来源于 AG's corpus,由 A. Godin 在 2005 年构建。
- 用途 :主要用于短文本多类别分类任务
- 语言:英文
- 类别数:4 类新闻主题
- 训练样本数:120,000 条
- 测试样本数:7,600 条
类别标签(共 4 类)
标签 | 含义 |
---|---|
1 |
World (世界) |
2 |
Sports (体育) |
3 |
Business (商业) |
4 |
Science and Technology (科技) |
数据加载
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchtext
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# 检查设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
python
# 加载本地数据
train_df = pd.read_csv("./data/train.csv")
test_df = pd.read_csv("./data/test.csv")
# 合并标题和描述数据
train_df["text"] = train_df["Title"] + " " + train_df["Description"]
test_df["text"] = test_df["Title"] + " " + test_df["Description"]
# 查看数据格式
train_df
| | Class Index | Title | Description | text |
| 0 | 3 | Wall St. Bears Claw Back Into the Black (Reuters) | Reuters - Short-sellers, Wall Street's dwindli... | Wall St. Bears Claw Back Into the Black (Reute... |
| 1 | 3 | Carlyle Looks Toward Commercial Aerospace (Reu... | Reuters - Private investment firm Carlyle Grou... | Carlyle Looks Toward Commercial Aerospace (Reu... |
| 2 | 3 | Oil and Economy Cloud Stocks' Outlook (Reuters) | Reuters - Soaring crude prices plus worries\ab... | Oil and Economy Cloud Stocks' Outlook (Reuters... |
| 3 | 3 | Iraq Halts Oil Exports from Main Southern Pipe... | Reuters - Authorities have halted oil export\f... | Iraq Halts Oil Exports from Main Southern Pipe... |
| 4 | 3 | Oil prices soar to all-time record, posing new... | AFP - Tearaway world oil prices, toppling reco... | Oil prices soar to all-time record, posing new... |
| ... | ... | ... | ... | ... |
| 119995 | 1 | Pakistan's Musharraf Says Won't Quit as Army C... | KARACHI (Reuters) - Pakistani President Perve... | Pakistan's Musharraf Says Won't Quit as Army C... |
| 119996 | 2 | Renteria signing a top-shelf deal | Red Sox general manager Theo Epstein acknowled... | Renteria signing a top-shelf deal Red Sox gene... |
| 119997 | 2 | Saban not going to Dolphins yet | The Miami Dolphins will put their courtship of... | Saban not going to Dolphins yet The Miami Dolp... |
| 119998 | 2 | Today's NFL games | PITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ... | Today's NFL games PITTSBURGH at NY GIANTS Time... |
119999 | 2 | Nets get Carter from Raptors | INDIANAPOLIS -- All-Star Vince Carter was trad... | Nets get Carter from Raptors INDIANAPOLIS -- A... |
---|
120000 rows × 4 columns
构建词表
python
# 定义 Dataset
class AGNewsDataset(Dataset):
def __init__(self, dataframe):
self.labels = dataframe['Class Index'].tolist() # 列表数据
self.texts = dataframe['text'].tolist()
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return self.labels[idx], self.texts[idx]
# 加载数据
train_dataset = AGNewsDataset(train_df)
test_dataset = AGNewsDataset(test_df)
# 构建词表
tokenizer = get_tokenizer("basic_english") # 英文数据,设置英文分词
def yield_tokens(data_iter):
for _, text in data_iter:
yield tokenizer(text) # 构建词表
# 构建词表,设置索引
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
print("Vocab size:", len(vocab))
Vocab size: 95804
python
# 查看这些单词所在词典的索引
vocab(['here', 'is', 'an', 'example'])
[475, 21, 30, 5297]
python
'''
标签,原始是字符串类型,现在要转换成 数字 类型
文本数字化,需要一个函数进行转换(vocab)
'''
text_pipline = lambda x : vocab(tokenizer(x)) # 先分词。在数字化
label_pipline = lambda x : int(x) - 1 # 标签转化为数字
# 举例
text_pipline('here is the an example')
[475, 21, 2, 30, 5297]
2、生成数据批次和迭代器
python
# 采用embeddingbag嵌入方式,故需要构建数据,包括长度、标签、偏移量
'''
数据格式:长度(~, 1)
标签:一维
偏移量:一维
'''
def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for (_label, _text) in batch:
# 标签列表,注意字符串转换成数字
label_list.append(label_pipline(_label))
# 文本列表,注意要转入tensro数据
temp_text = torch.tensor(text_pipline(_text), dtype=torch.int64)
text_list.append(temp_text)
# 偏移量
offsets.append(temp_text.size(0))
# 全部转变成tensor变量
label_list = torch.tensor(label_list, dtype=torch.int64)
text_list = torch.cat(text_list)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
return label_list.to(device), text_list.to(device), offsets.to(device)
# 数据加载
batch_size = 16
train_dl = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=collate_batch
)
test_dl = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=collate_batch
)
3、定义与模型
模型定义
python
class TextModel(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class):
super().__init__()
self.embeddingBag = nn.EmbeddingBag(vocab_size, # 词典大小
embed_dim, # 嵌入维度
sparse=False)
self.fc = nn.Linear(embed_dim, num_class)
self.init_weights()
# 初始化权重
def init_weights(self):
initrange = 0.5
self.embeddingBag.weight.data.uniform_(-initrange, initrange) # 初始化权重范围
self.fc.weight.data.uniform_(-initrange, initrange)
self.fc.bias.data.zero_() # 偏置置为0
def forward(self, text, offsets):
embedding = self.embeddingBag(text, offsets)
return self.fc(embedding)
创建模型
python
# 查看数据类别
train_df.groupby('Class Index').count()
| | Title | Description | text |
| Class Index | | | |
| 1 | 30000 | 30000 | 30000 |
| 2 | 30000 | 30000 | 30000 |
| 3 | 30000 | 30000 | 30000 |
4 | 30000 | 30000 | 30000 |
---|
python
class_num = 4
vocab_len = len(vocab)
embed_dim = 64 # 嵌入到64维度中
model = TextModel(vocab_size=vocab_len, embed_dim=embed_dim, num_class=class_num).to(device=device)
4、创建训练和评估函数
训练函数
python
def train(model, dataset, optimizer, loss_fn):
size = len(dataset.dataset)
num_batch = len(dataset)
train_acc = 0
train_loss = 0
for _, (label, text, offset) in enumerate(dataset):
label, text, offset = label.to(device), text.to(device), offset.to(device)
predict_label = model(text, offset)
loss = loss_fn(predict_label, label)
# 求导与反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_acc += (predict_label.argmax(1) == label).sum().item()
train_loss += loss.item()
train_acc /= size
train_loss /= num_batch
return train_acc, train_loss
评估函数
python
def test(model, dataset, loss_fn):
size = len(dataset.dataset)
batch_size = len(dataset)
test_acc, test_loss = 0, 0
with torch.no_grad():
for _, (label, text, offset) in enumerate(dataset):
label, text, offset = label.to(device), text.to(device), offset.to(device)
predict = model(text, offset)
loss = loss_fn(predict, label)
test_acc += (predict.argmax(1) == label).sum().item()
test_loss += loss.item()
test_acc /= size
test_loss /= batch_size
return test_acc, test_loss
创建超参数
python
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.01) # 动态调整学习率
5、模型训练
python
import copy
epochs = 10
train_acc, train_loss, test_acc, test_loss = [], [], [], []
best_acc = 0
for epoch in range(epochs):
model.train()
epoch_train_acc, epoch_train_loss = train(model, train_dl, optimizer, loss_fn)
train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
model.eval()
epoch_test_acc, epoch_test_loss = test(model, test_dl, loss_fn)
test_acc.append(epoch_test_acc)
test_loss.append(epoch_test_loss)
if best_acc is not None and epoch_test_acc > best_acc:
# 动态调整学习率
scheduler.step()
best_acc = epoch_test_acc
best_model = copy.deepcopy(model) # 保存模型
# 当前学习率
lr = optimizer.state_dict()['param_groups'][0]['lr']
template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))
# 保存最佳模型到文件
path = './best_model.pth'
torch.save(best_model.state_dict(), path) # 保存模型参数
Epoch: 1, Train_acc:79.9%, Train_loss:0.562, Test_acc:86.9%, Test_loss:0.392, Lr:5.00E-01
Epoch: 2, Train_acc:89.7%, Train_loss:0.313, Test_acc:88.9%, Test_loss:0.346, Lr:5.00E-01
Epoch: 3, Train_acc:91.2%, Train_loss:0.269, Test_acc:89.6%, Test_loss:0.329, Lr:5.00E-01
Epoch: 4, Train_acc:92.0%, Train_loss:0.243, Test_acc:89.8%, Test_loss:0.319, Lr:5.00E-01
Epoch: 5, Train_acc:92.6%, Train_loss:0.224, Test_acc:90.2%, Test_loss:0.315, Lr:5.00E-03
Epoch: 6, Train_acc:93.3%, Train_loss:0.207, Test_acc:90.6%, Test_loss:0.297, Lr:5.00E-03
Epoch: 7, Train_acc:93.4%, Train_loss:0.204, Test_acc:90.7%, Test_loss:0.295, Lr:5.00E-03
Epoch: 8, Train_acc:93.4%, Train_loss:0.203, Test_acc:90.7%, Test_loss:0.294, Lr:5.00E-03
Epoch: 9, Train_acc:93.4%, Train_loss:0.202, Test_acc:90.8%, Test_loss:0.293, Lr:5.00E-03
Epoch:10, Train_acc:93.4%, Train_loss:0.201, Test_acc:90.7%, Test_loss:0.293, Lr:5.00E-03
6、结果展示
python
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 #分辨率
epoch_length = range(epochs)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epoch_length, train_acc, label='Train Accuaray')
plt.plot(epoch_length, test_acc, label='Test Accuaray')
plt.legend(loc='lower right')
plt.title('Accurary')
plt.subplot(1, 2, 2)
plt.plot(epoch_length, train_loss, label='Train Loss')
plt.plot(epoch_length, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.show()
7、预测
python
model.load_state_dict(torch.load("./best_model.pth"))
model.eval() # 模型评估
# 测试句子
test_sentence = "This is a news about Technology"
# 转换为 token
token_ids = vocab(tokenizer(test_sentence)) # 切割分词--> 词典序列
text = torch.tensor(token_ids, dtype=torch.long).to(device) # 转化为tensor
offsets = torch.tensor([0], dtype=torch.long).to(device)
# 测试,注意:不需要反向求导
with torch.no_grad():
output = model(text, offsets)
predicted_label = output.argmax(1).item()
# 输出结果
class_names = ["World", "Sports", "Business", "Science and Technology"]
print(f"预测类别: {class_names[predicted_label]}")
预测类别: Science and Technology