【NLP练习】Pytorch文本分类入门

Pytorch文本分类入门

python 复制代码
🍨 本文为🔗365天深度学习训练营 中的学习记录博客
🍖 原作者:K同学啊 | 接辅导、项目定制

一、前期准备

1. 环境安装

确保已经安装 torchtext 与 portalocker 库(例如执行 `pip install torchtext portalocker`),并注意 torchtext 的版本需要与已安装的 PyTorch 版本相匹配。

2. 加载数据

python 复制代码
#加载数据
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")   # suppress all warning messages

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

输出:

python 复制代码
device(type='cpu')
python 复制代码
from torchtext.datasets import AG_NEWS

train_iter = AG_NEWS(split='train')   # load the training split of the AG NEWS dataset

3. 构建词典

python 复制代码
#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# get_tokenizer returns the tokenizer callable for the 'basic_english' scheme.
tokenizer = get_tokenizer('basic_english')


def yield_tokens(data_iter):
    """Lazily yield the token list of every (label, text) pair in data_iter.

    The label of each pair is ignored; only the text is tokenized.
    """
    yield from (tokenizer(raw_text) for _label, raw_text in data_iter)


# Build the vocabulary from the training texts, with "<unk>" as a special token.
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),
    specials=["<unk>"],
)
# Tokens that are not in the vocabulary fall back to the index of "<unk>".
vocab.set_default_index(vocab["<unk>"])
python 复制代码
vocab(['here','is','an','example'])   # demo: look up the vocabulary index of each token

输出:

python 复制代码
[475, 21, 30, 5297]
python 复制代码
def text_pipeline(x):
    """Convert a raw text string into a list of vocabulary indices.

    The text is split by the module-level ``tokenizer`` and the resulting
    tokens are mapped to indices by the module-level ``vocab``.

    Defined with ``def`` rather than as a lambda bound to a name so that the
    function has a real ``__name__`` and can be pickled by reference.
    """
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a raw label into an int class index by subtracting 1.

    ``x`` may be anything ``int()`` accepts.  Assumes the dataset's labels
    start at 1, so the result is zero-based -- TODO confirm against the
    dataset documentation.
    """
    return int(x) - 1
text_pipeline('here is the an example')   # demo: map a sample sentence to its token indices

输出:

python 复制代码
[475, 21, 2, 30, 5297]

4. 生成数据批次和迭代器

python 复制代码
# 生成数据批次和迭代器
from torch.utils.data import DataLoader

def collate_batch(batch):
    """Collate (label, text) samples into the tensors the model consumes.

    Args:
        batch: iterable of ``(label, text)`` pairs.

    Returns:
        A tuple ``(labels, text, offsets)`` of int64 tensors moved to
        ``device``: the processed labels, the token indices of all samples
        concatenated into one 1-D tensor, and the start position of each
        sample inside that concatenated tensor.
    """
    labels = []
    token_tensors = []
    lengths = [0]  # leading 0 so the cumulative sum yields start positions

    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        token_ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        token_tensors.append(token_ids)
        lengths.append(token_ids.size(0))

    label_tensor = torch.tensor(labels, dtype=torch.int64)
    flat_text = torch.cat(token_tensors)
    # Drop the last length: the running sum of the rest is each sample's start.
    start_offsets = torch.tensor(lengths[:-1]).cumsum(dim=0)

    return (
        label_tensor.to(device),
        flat_text.to(device),
        start_offsets.to(device),
    )

# Data loader over the training iterator: batches of 8 samples, no shuffling,
# each batch assembled by collate_batch.
dataloader = DataLoader(
    train_iter,
    batch_size = 8,
    shuffle = False,
    collate_fn = collate_batch
)

二、准备模型

1. 定义模型

定义 TextClassificationModel 模型:首先用 nn.EmbeddingBag 对文本进行词嵌入,并对每个句子的词向量做均值聚合,然后通过一个全连接层(nn.Linear)输出各个类别的得分。

python 复制代码
from torch import nn

class TextClassificationModel(nn.Module):
    """Bag-of-embeddings text classifier.

    Token indices are embedded and pooled per sample by ``nn.EmbeddingBag``;
    the pooled vector is then projected to class scores by a linear layer.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        """Create the layers and initialise their weights.

        Args:
            vocab_size: number of entries in the vocabulary.
            embed_dim: size of each embedding vector.
            num_class: number of output classes.
        """
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Fill both weight matrices from U(-0.5, 0.5) and zero the linear bias."""
        bound = 0.5
        # Embedding first, then the linear layer, so the RNG is consumed in a
        # fixed order.
        for layer in (self.embedding, self.fc):
            layer.weight.data.uniform_(-bound, bound)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        """Return the class scores for a batch.

        Args:
            text: 1-D tensor holding the token indices of all samples,
                concatenated.
            offsets: 1-D tensor with the start position of each sample in
                ``text``.
        """
        pooled = self.embedding(text, offsets)
        return self.fc(pooled)

2. 定义实例

python 复制代码
# Instantiate the model
em_size = 64                                          # embedding dimension
vocab_size = len(vocab)                               # number of vocabulary entries
num_class = len({label for label, _ in train_iter})   # number of distinct labels
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练函数与评估函数

python 复制代码
#定义训练函数与评估函数
import time

def train(dataloader):
    """Train the module-level ``model`` for one pass over ``dataloader``.

    Relies on the module-level names ``model``, ``optimizer`` and
    ``criterion``, and reads the module-level ``epoch`` for the progress log.
    Every ``log_interval`` batches it prints the accuracy and loss of the
    batches seen since the previous log line, then resets the running totals.

    The original version kept a ``start_time``/``elapsed`` timer whose reset
    was misspelled (``staet_time``) and whose value was never printed; that
    dead code has been removed.

    Args:
        dataloader: iterable yielding ``(label, text, offsets)`` batches.
    """
    model.train()          # switch to training mode
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 500
    for idx, (label, text, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()                             # clear accumulated gradients
        loss = criterion(predicted_label, label)          # loss between predictions and true labels
        loss.backward()                                   # back-propagate
        optimizer.step()                                  # update the parameters
        # Accumulate accuracy and loss statistics for the current log window.
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            # train_loss is the sum of per-batch loss values; it is reported
            # divided by the number of samples seen in the window.
            print('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(
                epoch,
                idx,
                len(dataloader),
                total_acc/total_count,
                train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0

def evaluate(dataloader):
    """Evaluate the module-level ``model`` on ``dataloader`` without gradients.

    Relies on the module-level names ``model`` and ``criterion``.

    Args:
        dataloader: iterable yielding ``(label, text, offsets)`` batches.

    Returns:
        A tuple ``(accuracy, loss)``: the number of correct predictions and
        the sum of per-batch loss values, each divided by the number of
        samples seen.
    """
    model.eval()      # switch to evaluation mode
    correct = 0
    loss_sum = 0
    seen = 0
    with torch.no_grad():
        for label, text, offsets in dataloader:
            scores = model(text, offsets)
            batch_loss = criterion(scores, label)
            correct += (scores.argmax(1) == label).sum().item()
            loss_sum += batch_loss.item()
            seen += label.size(0)

    return correct / seen, loss_sum / seen

三、训练模型

1. 拆分数据集并运行模型

python 复制代码
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset

# Hyper-parameters
EPOCHS      = 10
LR          = 5
BATCH_SIZE  = 64

# Loss function, optimiser and learning-rate scheduler
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
# Each scheduler.step() call counts towards a step size of 1.0; gamma is 0.1.
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None   # starts as None; the epoch loop stores a validation accuracy here

# Load the data
train_iter, test_iter   = AG_NEWS() 
train_dataset   = to_map_style_dataset(train_iter)
test_dataset    = to_map_style_dataset(test_iter)
num_train       = int(len(train_dataset) * 0.95)   # 95% of the training data is used for training

# The remaining samples form the validation split.
split_train_ , split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
test_dataloader     = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)

# Train for EPOCHS epochs, validating after each one.
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    # Record the validation accuracy unless it fell below the recorded one;
    # in that case leave the record alone and step the scheduler instead.
    if total_accu is None or not (total_accu > val_acc):
        total_accu = val_acc
    else:
        scheduler.step()

    print('-' * 69)
    print(
        '| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(
            epoch, time.time() - epoch_start_time, val_acc, val_loss
        )
    )
    print('-' * 69)

输出:

python 复制代码
|epoch1| 500/1782 batches|train_acc0.909 train_loss0.00420
|epoch1|1000/1782 batches|train_acc0.909 train_loss0.00431
|epoch1|1500/1782 batches|train_acc0.910 train_loss0.00415
---------------------------------------------------------------------
| epoch 1 | time:17.55s | valid_acc 0.913 valid_loss 0.004
---------------------------------------------------------------------
|epoch2| 500/1782 batches|train_acc0.924 train_loss0.00355
|epoch2|1000/1782 batches|train_acc0.922 train_loss0.00366
|epoch2|1500/1782 batches|train_acc0.917 train_loss0.00376
---------------------------------------------------------------------
| epoch 2 | time:17.58s | valid_acc 0.914 valid_loss 0.004
---------------------------------------------------------------------
|epoch3| 500/1782 batches|train_acc0.929 train_loss0.00329
|epoch3|1000/1782 batches|train_acc0.929 train_loss0.00332
|epoch3|1500/1782 batches|train_acc0.929 train_loss0.00337
---------------------------------------------------------------------
| epoch 3 | time:19.67s | valid_acc 0.892 valid_loss 0.005
---------------------------------------------------------------------
|epoch4| 500/1782 batches|train_acc0.947 train_loss0.00258
|epoch4|1000/1782 batches|train_acc0.946 train_loss0.00257
|epoch4|1500/1782 batches|train_acc0.946 train_loss0.00266
---------------------------------------------------------------------
| epoch 4 | time:18.36s | valid_acc 0.915 valid_loss 0.004
---------------------------------------------------------------------
|epoch5| 500/1782 batches|train_acc0.951 train_loss0.00243
|epoch5|1000/1782 batches|train_acc0.949 train_loss0.00252
|epoch5|1500/1782 batches|train_acc0.947 train_loss0.00256
---------------------------------------------------------------------
| epoch 5 | time:17.92s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
|epoch6| 500/1782 batches|train_acc0.950 train_loss0.00245
|epoch6|1000/1782 batches|train_acc0.950 train_loss0.00246
|epoch6|1500/1782 batches|train_acc0.950 train_loss0.00245
---------------------------------------------------------------------
| epoch 6 | time:18.10s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
|epoch7| 500/1782 batches|train_acc0.950 train_loss0.00245
|epoch7|1000/1782 batches|train_acc0.951 train_loss0.00242
|epoch7|1500/1782 batches|train_acc0.951 train_loss0.00239
---------------------------------------------------------------------
| epoch 7 | time:18.08s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
|epoch8| 500/1782 batches|train_acc0.951 train_loss0.00238
|epoch8|1000/1782 batches|train_acc0.951 train_loss0.00241
|epoch8|1500/1782 batches|train_acc0.955 train_loss0.00228
---------------------------------------------------------------------
| epoch 8 | time:18.75s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
|epoch9| 500/1782 batches|train_acc0.952 train_loss0.00234
|epoch9|1000/1782 batches|train_acc0.953 train_loss0.00235
|epoch9|1500/1782 batches|train_acc0.951 train_loss0.00237
---------------------------------------------------------------------
| epoch 9 | time:18.50s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
|epoch10| 500/1782 batches|train_acc0.951 train_loss0.00234
|epoch10|1000/1782 batches|train_acc0.954 train_loss0.00231
|epoch10|1500/1782 batches|train_acc0.954 train_loss0.00234
---------------------------------------------------------------------
| epoch 10 | time:17.82s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------

2. 使用测试数据集评估模型

python 复制代码
print('Checking the results of test dataset.')
# Evaluate on the test data loader; only the accuracy is printed.
test_acc,test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

输出:

python 复制代码
Checking the results of test dataset.
test accuracy    0.908
相关推荐
XuecWu33 分钟前
Karpathy的AutoResearch与Gemini三层 Agent 架构后的相通设计逻辑
人工智能·深度学习·语言模型·自然语言处理
V搜xhliang02461 小时前
世界模型、强化学习PPOSAC
人工智能·深度学习·机器学习·语言模型·自然语言处理
闻道且行之3 小时前
PyTorch 深度学习开发 常见疑难报错与解决方案汇总
人工智能·pytorch·深度学习
Dxy12393102166 小时前
深度学习的优雅收尾:PyTorch中PolynomialLR的终极指南
人工智能·pytorch·深度学习
chushiyunen8 小时前
pycharm打包whl
人工智能·pytorch·python
墨染天姬8 小时前
【AI】PyTorch 框架
人工智能·pytorch·python
wuxuand8 小时前
2026时序分类综述A Comprehensive Review of Time Series Classification
人工智能·深度学习·分类·数据挖掘
小班得瑞9 小时前
pytorch使用小结
pytorch
Fleshy数模9 小时前
基于PyTorch实现食物图像分类:从数据加载到CNN训练全流程
pytorch·分类·cnn
开放知识图谱10 小时前
论文浅尝 | 简单即有效:图和大型语言模型在基于知识图谱的检索增强生成中的作用(ICLR2025)
人工智能·语言模型·自然语言处理·知识图谱