第N5周:Pytorch文本分类入门

程序完整解释

1. 程序目标

这是一个文本分类程序,使用PyTorch框架对AG_NEWS新闻数据集进行分类。AG_NEWS数据集包含4个类别的新闻文章:

1: World(世界新闻)

2: Sports(体育新闻)

3: Business(商业新闻)

4: Sci/Tech(科技新闻)

2. 核心架构

程序采用词嵌入平均 + 线性分类器的简单但有效的架构

程序如下:

python 复制代码
# ================================
# 一、数据加载与预处理
# ================================

import torch  # 深度学习框架
from torch import nn  # 神经网络模块
import torchvision  # 计算机视觉工具包(虽然这里没有使用)
from torchvision import transforms, datasets  # 数据转换和数据集工具
import os, PIL, pathlib, warnings  # 系统、图像处理、路径处理和警告控制

warnings.filterwarnings("ignore")  # 忽略所有警告信息,使输出更干净

# 设置设备:优先使用GPU(CUDA),如果没有则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 导入torchtext的AG_NEWS数据集(新闻分类数据集)
from torchtext.datasets import AG_NEWS

# 加载AG_NEWS数据集,返回训练集和测试集迭代器
# 注意:新版本torchtext中,AG_NEWS()直接返回(train_iter, test_iter)
train_iter, test_iter = AG_NEWS()

# 重新导入datasets模块(这里有些冗余,可以直接使用上面的导入)
import torchtext.datasets as datasets

# 另一种加载方式:指定数据存储路径和分割方式
# train_dataset, test_dataset = datasets.AG_NEWS(root='./data', split=("train", "test"))
# 注:上面这行被注释掉了,实际使用的是前面的train_iter, test_iter

# 将训练迭代器转换为可迭代对象
train_iter = iter(train_iter)

# 获取第一条训练数据:next()返回(label, text)对
label, text = next(train_iter)

print("Text: ", text)  # 打印文本内容
print("Label: ", label)  # 打印标签(1-4表示4个类别)

# 导入分词器和词汇表构建工具
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# 获取基础的英文分词器函数
# 'basic_english'会将文本转换为小写,并按空格和标点符号分割
tokenizer = get_tokenizer('basic_english')

# 定义生成器函数:遍历数据集,对每条文本进行分词
def yield_tokens(data_iter):
    """
    生成器函数:遍历数据集中的每条文本,使用分词器进行分词
    
    参数:
        data_iter: 数据迭代器,每个元素为(_, text)对
    
    生成:
        分词后的词汇列表
    """
    for _, text in data_iter:
        yield tokenizer(text)  # 对文本进行分词

# 重新创建训练迭代器(因为前面已经取出了第一条数据)
train_iter, test_iter = AG_NEWS()
train_iter = iter(train_iter)

# 从分词结果构建词汇表
# specials=["<unk>"]:添加特殊token,<unk>表示未知词
vocab = build_vocab_from_iterator(
    yield_tokens(train_iter),  # 生成器,提供分词后的词汇
    specials=["<unk>"]  # 特殊token
)

# 设置默认索引:当遇到词汇表中不存在的词时,使用<unk>的索引
vocab.set_default_index(vocab["<unk>"])

# 打印词汇表信息
print("词典大小:", len(vocab))  # 词汇表中唯一词汇的数量
print("词典内部映射:", vocab.get_stoi())  # 显示词汇到索引的映射(只显示部分)

# ================================
# 二、模型准备
# ================================

class TextClassificationModel(nn.Module):
    """
    文本分类模型:使用EmbeddingBag + 线性层的简单架构
    
    特点:
        - EmbeddingBag:高效处理变长文本,直接对嵌入向量求平均
        - 适用于文本分类任务
        - 参数少,训练快
    """
    
    def __init__(self, vocab_size, embed_dim, num_class):
        """
        初始化模型
        
        参数:
            vocab_size: 词汇表大小
            embed_dim: 词嵌入维度
            num_class: 分类类别数
        """
        super(TextClassificationModel, self).__init__()
        
        # EmbeddingBag层:将词索引映射为嵌入向量,并对batch内文本的嵌入向量求平均
        # mode='mean'(默认):对文本中所有词的嵌入向量求平均
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
        
        # 全连接层:将平均后的嵌入向量映射到类别空间
        self.fc = nn.Linear(embed_dim, num_class)
        
        # 初始化权重
        self.init_weights()
    
    def init_weights(self):
        """初始化模型权重"""
        initrange = 0.5  # 初始化范围
        
        # 均匀分布初始化嵌入层权重
        self.embedding.weight.data.uniform_(-initrange, initrange)
        
        # 均匀分布初始化全连接层权重
        self.fc.weight.data.uniform_(-initrange, initrange)
        
        # 全连接层偏置初始化为0
        self.fc.bias.data.zero_()
    
    def forward(self, text, offsets):
        """
        前向传播
        
        参数:
            text: 所有文本拼接后的词索引张量
            offsets: 每个文本在text张量中的起始位置
            
        返回:
            每个类别的得分(logits)
        """
        # EmbeddingBag处理:根据offsets将text分割成不同文本,对每个文本的词嵌入求平均
        embedded = self.embedding(text, offsets)
        
        # 全连接层:得到类别得分
        return self.fc(embedded)

# 模型参数设置
num_class = 4  # AG_NEWS有4个类别:World, Sports, Business, Sci/Tech
vocab_size = len(vocab)  # 词汇表大小
em_size = 64  # 词嵌入维度

# 创建模型实例并移动到指定设备(GPU/CPU)
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

import time  # 用于计时

def train(dataloader):
    """
    训练一个epoch
    
    参数:
        dataloader: 训练数据加载器
    """
    model.train()  # 设置为训练模式
    total_acc, train_loss, total_count = 0, 0, 0  # 累计准确率、损失、样本数
    log_interval = 500  # 每500个batch打印一次日志
    start_time = time.time()  # 记录开始时间
    
    # 遍历数据:每个batch包含(label, text, offsets)
    for idx, (label, text, offsets) in enumerate(dataloader):
        # 前向传播
        predicted_label = model(text, offsets)
        
        # 梯度清零
        optimizer.zero_grad()
        
        # 计算损失(交叉熵损失)
        loss = criterion(predicted_label, label)
        
        # 反向传播
        loss.backward()
        
        # 参数更新
        optimizer.step()
        
        # 计算准确率:预测正确的样本数
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        
        # 累计损失(乘以batch size)
        train_loss += loss.item() * label.size(0)
        
        # 累计样本数
        total_count += label.size(0)
        
        # 定期打印训练进度
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(
                      epoch, idx, len(dataloader),
                      total_acc / total_count, train_loss / total_count))
            
            # 重置累计值
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()  # 重置计时器

def evaluate(dataloader):
    """
    评估模型性能(在验证集或测试集上)
    
    参数:
        dataloader: 评估数据加载器
        
    返回:
        (准确率, 平均损失)
    """
    model.eval()  # 设置为评估模式(不计算梯度)
    total_acc, test_loss, total_count = 0, 0, 0
    
    # 不计算梯度,节省内存和计算
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            # 前向传播
            predicted_label = model(text, offsets)
            
            # 计算损失
            loss = criterion(predicted_label, label)
            
            # 计算准确率
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            
            # 累计损失
            test_loss += loss.item() * label.size(0)
            
            # 累计样本数
            total_count += label.size(0)
    
    # 返回平均准确率和平均损失
    return total_acc / total_count, test_loss / total_count

# ================================
# 三、模型训练
# ================================

from torch.utils.data.dataset import random_split  # 用于数据集分割
from torchtext.data.functional import to_map_style_dataset  # 转换数据集格式
from torch.utils.data import DataLoader  # 数据加载器

# 定义collate_batch函数(原代码中缺失,需要补充)
def collate_batch(batch):
    """
    自定义batch处理函数:将一个batch的数据整理成模型需要的格式
    
    参数:
        batch: 一个batch的数据,每个元素为(label, text)
        
    返回:
        (labels, texts, offsets)
    """
    label_list, text_list, offsets = [], [], [0]
    
    # 遍历batch中的每个样本
    for _label, _text in batch:
        label_list.append(_label)  # 收集标签
        
        # 将文本分词并转换为词汇表中的索引
        processed_text = torch.tensor(vocab(tokenizer(_text)), dtype=torch.int64)
        text_list.append(processed_text)
        
        # 计算offsets:每个文本在拼接后的起始位置
        offsets.append(processed_text.size(0))
    
    # 将所有文本拼接成一个长张量
    text_tensor = torch.cat(text_list)
    
    # 计算offsets(累积和)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    
    # 转换标签为张量
    label_tensor = torch.tensor(label_list, dtype=torch.int64)
    
    return label_tensor.to(device), text_tensor.to(device), offsets.to(device)

# 超参数设置
EPOCHS = 10  # 训练轮数
LR = 5.0  # 学习率(较高,配合SGD和学习率衰减)
BATCH_SIZE = 64  # 批大小

# 损失函数:交叉熵损失(适用于多分类)
criterion = nn.CrossEntropyLoss()

# 优化器:随机梯度下降
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

# 学习率调度器:每1个epoch,学习率乘以0.1
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

total_accu = None  # 用于记录最佳验证准确率

# 重新加载数据集(因为前面已经使用过train_iter)
train_iter, test_iter = AG_NEWS()

# 将迭代式数据集转换为映射式数据集(支持索引访问)
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)

# 将训练集分割为训练集和验证集(95%训练,5%验证)
num_train = int(len(train_dataset) * 0.95)
split_train, split_valid = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

# 创建数据加载器
train_dataloader = DataLoader(
    split_train, 
    batch_size=BATCH_SIZE, 
    shuffle=True,  # 训练时打乱数据
    collate_fn=collate_batch  # 使用自定义的batch处理函数
)

valid_dataloader = DataLoader(
    split_valid, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    collate_fn=collate_batch
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    collate_fn=collate_batch
)

# 训练循环
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    
    # 训练一个epoch
    train(train_dataloader)
    
    # 在验证集上评估
    val_acc, val_loss = evaluate(valid_dataloader)
    
    # 学习率调度:如果验证准确率没有提升,则降低学习率
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    
    # 打印epoch总结
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f}'.format(
              epoch, time.time() - epoch_start_time, val_acc, val_loss))
    print('-' * 69)

# ================================
# 四、测试评估
# ================================

print('Checking the results of test dataset.')
# 在测试集上评估最终模型性能
test_acc, test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))
相关推荐
喜欢吃豆1 小时前
Parquet 范式:大语言模型训练数据格式优化的基础解析
人工智能·语言模型·自然语言处理·大模型·parquet
AI松子6661 小时前
PyTorch-混合精度训练(amp)
人工智能·pytorch·python
MDLZH1 小时前
Pytorch性能调优简单总结
人工智能·pytorch·python
GIS数据转换器2 小时前
基于GIS的智慧旅游调度指挥平台
运维·人工智能·物联网·无人机·旅游·1024程序员节
沧澜sincerely3 小时前
数据挖掘概述
人工智能·数据挖掘
数数科技的数据干货4 小时前
从爆款到厂牌:解读游戏工业化的业务持续增长道路
运维·数据库·人工智能
amhjdx7 小时前
星巽短剧以科技赋能影视创新,构建全球短剧新生态!
人工智能·科技
听风南巷7 小时前
机器人全身控制WBC理论及零空间原理解析(数学原理解析版)
人工智能·数学建模·机器人
美林数据Tempodata8 小时前
“双新”指引,AI驱动:工业数智应用生产性实践创新
大数据·人工智能·物联网·实践中心建设·金基地建设