N3 - Pytorch文本分类入门


目录


文本分类的基本流程

常用的数据清洗方法

如何使用jieba实现英文分词

如何构建文本向量

代码实践

数据准备

使用AG News数据集进行文本分类。

AG News(AG's News Topic Classification Dataset)是一个广泛用于文本分类任务的数据集,尤其是在新闻领域。该数据集是由AG's Corpus of News Articles收集整理而来,包含了四个主要类别: 世界、体育、商业和科技。

python 复制代码
# 使用torchtext导入数据集
import torch
torch.utils.data.datapipes.utils.common.DILL_AVAILABLE = torch.utils._import_utils.dill_available()

from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')

我们通过打印数据内容查看一下数据集的格式

python 复制代码
for i, data in enumerate(train_iter):
    print(data)
    if i == 3:
        break

由此可见数据集的每一个条目是一个元组,包含新闻文章所属的类别和新闻文章的文本内容,其中类别是一个整数,从1到4,分别对应 世界、科技、体育和商业。

构建词典

要构建词典,需要一个分词器,将句子分成分散的词后,再创建词典。也就是上图文本分类任务中的:文本清洗、分词、文本向量化这三步做的事情。

python 复制代码
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
	for _, text in data_iter:
		yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>'])
# 给未知单词设置一个默认索引,当一个单词不在词库中,就取默认索引,将它表示为<unk>
vocab.set_default_index(vocab['<unk>'])

get_tokenizer用于获取分词器函数,分词器可以将一个字符串转换成一个单词的列表

python 复制代码
print(tokenizer('Here is the example'))

vocab是使用torchtext的函数构建出的字典对象,可以使用它直接将单词转换为对应的词典序号,然后可以将序号转换为词向量(例如使用one-hot编码)。

python 复制代码
print(vocab(['here', 'is', 'the', 'example']))

生成数据批次和迭代器

python 复制代码
import torch
from torch.utils.data import DataLoader

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def collate_batch(batch):
	label_list, text_list, offsets = [], [], [0]
	for _label, _text in batch:
		# 将一个批次的标签汇集起来
		label_list.append(label_pipeline(_label))
		# 将一个批次的文本转换成序号汇集起来
		processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
		text_list.append(processed_text)

		# 当前批次中每个句子的长度
		offsets.append(processed_text.size(0))
	label_list = torch.tensor(label_list, dtype=torch.int64)
	text_list = torch.cat(text_list, dim=0)
	offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # 把每个句子的长度累计求合,成为真正的偏移量
	return label_list.to(device), text_list.to(device), offsets.to(device)

# 生成DataLoader
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

模型设计

模型的结构如上图所示,对文本进行嵌入后,将句子的嵌入结果进行均值聚合,也就是使用EmbeddingBag mode为mean

python 复制代码
from torch import nn
class TextClassificationModel(nn.Module):
	def __init__(self, vocab_size, embed_dim, num_classes):
		super().__init__()
		self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
		self.fc = nn.Linear(embed_dim, num_class)
		self.init_weights()
	
	def init_weights(self):
		initrange = 0.5
		self.embedding.weight.data.uniform_(-initrange, initrange)
		self.fc.weight.data.uniform_(-initrange, initrange)
		self.fc.bias.data.zero_()

	def forward(self, text, offsets):
		embedded = self.embedding(text, offsets)
		return self.fc(embedded)

以上模型中

  • self.embedding 是词嵌入层。作用是将离散的单词表示 (这里直接是单词的词典序号)映射为固定大小的连续向量(也就是单词的向量化)。这些向量捕捉了单词之间的词义关系,并作为网络的输入。
  • self.embedding.weight 是词嵌入层的权重矩阵,它的形状为(vocab_size, embed_dim),其中vocab_size是词汇表的大小,embed_dim是嵌入向量的维度
  • self.embedding.weight.data 是权重矩阵的数据部分,对它进行操作也就直接操作了底层的权重张量
  • .uniform_(-initrange, initrange) 这代表执行了一个原地操作(in-place operation),用于将权重矩阵的值用一个均匀分布进行初始化。均匀分布的范围是[-initrange,initrange],其中initrange是一个正数。

模型创建

python 复制代码
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

创建模型对象

模型训练

定义训练函数与评估函数

python 复制代码
import time
def train(dataloader):
	model.train()
	total_acc, train_loss, total_count = 0, 0, 0
	log_interval = 500
	start_time = time.time()
	
	for idx, (label, text, offsets) in enumerate(dataloader):
		predicted_label = model(text, offsets)
		
		optimizer.zero_grad()
		loss = criterion(predicted_label, label)
		loss.backward()
		optimizer.step()

		total_acc += (predicted_label.argmax(1) == label).sum().item()
		train_loss += loss.item()
		total_count += label.size(0)

		if idx % log_interval == 0 and idx > 0:
			elapsed = time.time() - start_time
			print('| epoch {:1d} | {:4d}/{:4d} batches'
				  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader), total_acc/total_count, train_loss/total_count))
			total_acc, train_loss, total_count = 0, 0, 0
			start_time = time.time()

def evaluate(dataloader):
	model.eval()
	total_acc, train_loss, total_count = 0, 0, 0
	
	with torch.no_grad():
		for idx, (label, text, offsets) in enumerate(dataloader):
			predicted_label = model(text, offsets)
			loss = criterion(predicted_label, label)
			total_acc += (predicted_label.argmax(1) == label).sum().item()
			train_loss += loss.item()
			total_count += label.size(0)

	return total_acc/total_count, train_loss/total_count

开始训练

python 复制代码
from torch.utils.data import random_split
from torchtext.data.functional import to_map_style_dataset

# 超参数
EPOCHS = 10
LR = 5
BATCH_SIZE = 64

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter, test_iter = AG_NEWS()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)

split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
	epoch_start_time = time.time()
	train(train_dataloder)
	val_acc, val_loss = evaluate(valid_dataloader)
	
	if total_accu is not None and total_accu > val_acc:
		scheduler.step()
	else:
		total_accu = val_acc
	print('-'*69)
	print('| epoch {:1d} | time: {:4.2f}s | '
		  'valid_acc {:4.3f} valid_loss {:4.3f}'.format(epoch, time.time() - epoch_start_time, val_acc, val_loss))
	print('-'*69)

在上面的代码中,to_map_style_dataset的作用是将一个迭代的数据集(iterable-style dataset)转换为映射式的数据集(Map-style dataset)。这个转换使得我们可以通过索引更方便地访问数据集中的元素。

在pytorch中,数据集可以分成两种类型,Iterable-style和Map-style。Iterable-style数据集实现了__iter__()方法,可以迭代访问数据集中的元素,但不支持通过索引访问。而Map-style数据集实现了__getitem__()__len__()方法,可以直接通过索引访问特定元素,并能获取数据集的大小。

torchtext是pytorch的一个扩展库,专注于处理文本数据。torchtext.data.functional中的to_map_style_dataset函数可以帮助我们将一个Iterable-style的数据集转换为一个易于操作的Map-style的数据集。然后就可以通过索引直接访问数据集中的特定样本,从而简化训练、验证和测试过程中的数据处理。

训练过程如下:

bash 复制代码
| epoch 1 |  500/1782 batches| train_acc 0.907 train_loss 0.00429
| epoch 1 | 1000/1782 batches| train_acc 0.906 train_loss 0.00431
| epoch 1 | 1500/1782 batches| train_acc 0.909 train_loss 0.00421
---------------------------------------------------------------------
| epoch 1 time: 6.77s | valid_acc 0.913 valid_loss 0.004
---------------------------------------------------------------------
| epoch 2 |  500/1782 batches| train_acc 0.921 train_loss 0.00369
| epoch 2 | 1000/1782 batches| train_acc 0.920 train_loss 0.00375
| epoch 2 | 1500/1782 batches| train_acc 0.918 train_loss 0.00376
---------------------------------------------------------------------
| epoch 2 time: 6.80s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
| epoch 3 |  500/1782 batches| train_acc 0.930 train_loss 0.00323
| epoch 3 | 1000/1782 batches| train_acc 0.926 train_loss 0.00334
| epoch 3 | 1500/1782 batches| train_acc 0.925 train_loss 0.00343
---------------------------------------------------------------------
| epoch 3 time: 6.93s | valid_acc 0.860 valid_loss 0.006
---------------------------------------------------------------------
| epoch 4 |  500/1782 batches| train_acc 0.943 train_loss 0.00267
| epoch 4 | 1000/1782 batches| train_acc 0.945 train_loss 0.00263
| epoch 4 | 1500/1782 batches| train_acc 0.946 train_loss 0.00265
---------------------------------------------------------------------
| epoch 4 time: 6.83s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------
| epoch 5 |  500/1782 batches| train_acc 0.947 train_loss 0.00256
| epoch 5 | 1000/1782 batches| train_acc 0.947 train_loss 0.00258
| epoch 5 | 1500/1782 batches| train_acc 0.947 train_loss 0.00261
---------------------------------------------------------------------
| epoch 5 time: 6.76s | valid_acc 0.921 valid_loss 0.004
---------------------------------------------------------------------
| epoch 6 |  500/1782 batches| train_acc 0.948 train_loss 0.00253
| epoch 6 | 1000/1782 batches| train_acc 0.949 train_loss 0.00253
| epoch 6 | 1500/1782 batches| train_acc 0.951 train_loss 0.00241
---------------------------------------------------------------------
| epoch 6 time: 6.93s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------
| epoch 7 |  500/1782 batches| train_acc 0.949 train_loss 0.00248
| epoch 7 | 1000/1782 batches| train_acc 0.949 train_loss 0.00250
| epoch 7 | 1500/1782 batches| train_acc 0.949 train_loss 0.00248
---------------------------------------------------------------------
| epoch 7 time: 6.85s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------
| epoch 8 |  500/1782 batches| train_acc 0.948 train_loss 0.00247
| epoch 8 | 1000/1782 batches| train_acc 0.950 train_loss 0.00250
| epoch 8 | 1500/1782 batches| train_acc 0.951 train_loss 0.00243
---------------------------------------------------------------------
| epoch 8 time: 6.76s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------
| epoch 9 |  500/1782 batches| train_acc 0.951 train_loss 0.00239
| epoch 9 | 1000/1782 batches| train_acc 0.948 train_loss 0.00259
| epoch 9 | 1500/1782 batches| train_acc 0.951 train_loss 0.00244
---------------------------------------------------------------------
| epoch 9 time: 6.87s | valid_acc 0.926 valid_loss 0.004
---------------------------------------------------------------------

评估模型

python 复制代码
print('Checking the results of test dataset.')
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

总结与心得体会

文本分类任务,关键的是前面对文本的处理,合并,嵌入,最后的分类反而非常简直,直接使用了一层全连接层就可以达到不错的效果了。

在复现的过程中,由于使用的库版本不一致导致torchtext库部分代码无法正常运行,卡了好久,后面搜索了一些之前打卡的同学的博客,才找到解决方案。复现模型时尽量不要使用最新的版本,而是使用原来的版本,先运行起来,再改动。

相关推荐
jndingxin1 分钟前
OpenCV特征检测(1)检测图像中的线段的类LineSegmentDe()的使用
人工智能·opencv·计算机视觉
@月落11 分钟前
alibaba获得店铺的所有商品 API接口
java·大数据·数据库·人工智能·学习
z千鑫20 分钟前
【人工智能】如何利用AI轻松将java,c++等代码转换为Python语言?程序员必读
java·c++·人工智能·gpt·agent·ai编程·ai工具
MinIO官方账号38 分钟前
从 HDFS 迁移到 MinIO 企业对象存储
人工智能·分布式·postgresql·架构·开源
aWty_1 小时前
机器学习--K-Means
人工智能·机器学习·kmeans
草莓屁屁我不吃1 小时前
AI大语言模型的全面解读
人工智能·语言模型·自然语言处理·chatgpt
WPG大大通1 小时前
有奖直播 | onsemi IPM 助力汽车电气革命及电子化时代冷热管理
大数据·人工智能·汽车·方案·电气·大大通·研讨会
百锦再1 小时前
AI对汽车行业的冲击和比亚迪新能源汽车市场占比
人工智能·汽车
ws2019071 小时前
抓机遇,促发展——2025第十二届广州国际汽车零部件加工技术及汽车模具展览会
大数据·人工智能·汽车
Zhangci]1 小时前
Opencv图像预处理(三)
人工智能·opencv·计算机视觉