六、新闻主题分类任务

以一段新闻报道中的文本描述内容为输入,使用模型帮助我们判断它最有可能属于哪一种类型的新闻,这是典型的文本分类问题 。我们这里假定每种类型是互斥 的,即文本描述有且只有一种类型,例如一篇新闻不能即是娱乐类又是财经类,只能是一种类别。

一、数据下载与介绍

我们使用的是AG_NEWS数据集,已经被集成在了torchtext中,下面是下载数据集的代码:

注意:

如果没有torchtext时,使用pip安装时会有一个大坑。

torchtext安装时会检查pytorch的版本,如果版本不兼容,它会卸载你的torch,然后安装一个GPU版本的兼容的torch,这个过程是自动的,没有什么提示,或者大部分人不会具体去看提示,这里会非常坑。

我在刚开始安装torchtext后,怎么也无法使用GPU,我还是以为是显卡有问题了,搞了好久最后才发现是torch被变成了CPU版本,刚开始不知道,就卸载torch,然后重装CUDA版本的torch,但是没用,最后装上的还是CPU版本的torch(torchtext真是霸道!),往复了几次都不行,怎么装都是CUP版本的torch,巨坑!!!

怎么寻找正确的torchtext版本?

一个简单的规律是,torchtext的版本号比torch高一个子版本,然后主版本为0, 阶段版本号最好也是对应的。例如:

torch1.13.1 对应的 torchtext 应该torchtext 0.14.1
那么应该使用下面命令安装
pip install torchtext==0.14.1

上面的规律是对应torch主版本为1的,torch主版本为2的可以参考类似的规律。

感谢博客《更新 torchtext 造成的torch版本不匹配的问题》带来的解答。

python 复制代码
# 导入有关torch的工具包
import torch as tc
import torchtext
# 导入torchtext.datasets中的文本分类任务
from torchtext.datasets import AG_NEWS
import os

# 定义数据下载路径,当前路径的data文件夹
load_data_path = './Datasets/'
# 如果不存在该路径,则创建这个路径
if not os.path.exists(load_data_path):
    os.makedirs(load_data_path)

# 选取torchtext中的文本分类数据集'AG_NEWS'即新闻主题分类数据,保存在指定目录下
# 将数值映射后的训练和验证数据加载到内存中
train_data, test_data = AG_NEWS(
    root=load_data_path, split=('train', 'test'))

# AG_NEWS返回的数据是一个迭代器,每个元素都是一个元组,包含文本和标签
for (label, text) in train_data:
    print(f"Label: {label}, Text: {text}")
for (label, text) in test_data:
    print(f"Label: {label}, Text: {text}")

下载完成后,会有两个以.csv结尾的文件,

数据集中的内容如下:

"3","Fears for T N pension after talks","Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."

"4","The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com)","SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket."

  1. 训练集有12000个样本,测试集有7600个样本。
  2. 一共有四种标签{1,2,3,4}对应{World,Sports,Business,SCI/Tech}分别指世界性新闻、体育新闻、商业新闻和技术类新闻。
  3. 每条样本有三列,第一列是标签,说明该新闻属于哪一类;第二列是新闻标题;第三列是新闻简述。
  4. test.csv和train.csv中的格式相同

二、构建Dataset类,读取数据

我们使用上面的代码将数据集进行保存后,新建一个Python文件,开始构建读取数据的Dataset类,代码如下:

python 复制代码
#!------------------------第一步:数据读取,构建Dataset类--------------------------------
class AG_NEWS_Data(Dataset):
    def __init__(self, train=True) -> None:
        super().__init__()
        data_path = os.path.join(BASE_PATH, 'train.csv') if train else os.path.join(
            BASE_PATH, 'test.csv')  # 设置数据路径,本实验中只使用了训练集
        self.data = pd.read_csv(data_path, sep=',', header=None)  # 读取数据
        # print(self.data.head())

        sen_len = []  # 每条样本中文本句子长度
        self.contents = ''  # 所有样本分词后的内容
        token_number = 0  # 所有文本中有多少个不同的分词
        label_count = []  # 所有样本的label标签

        # * 计算每条样本的长度,取出每条样本的标签label,拼接所有样本内容到contents中
        for i in range(self.__len__()):
            content, label = self.__getitem__(i)
        # for content, label in data:
            sen_len.append(len(content.split(' ')))  # 每条样本的长度
            label_count.append(label)  # 取出每条样本的标签label
            self.contents += ' '+content  # 拼接样本内容到contents中

        vocab_dict = {v: idx for idx, v in enumerate(
            set(self.contents.split(' ')))}  # 获取所有分词集合
        token_number = len(vocab_dict)
        sen_len_distribution = {str(i): sen_len.count(i) for i in sorted(
            set(sen_len))}  # 句子长度分布的字典,如{'80':192,'81':689,...},即长度为80的句子有192个...
        label_n_distribution = {str(i): label_count.count(i) for i in set(
            label_count)}  # 标签数量分布的字典,如{'1':20000,'2':20000,...},每个标签对应的样本个数

        self.vocab_dict, self.token_number, self.sen_len_distribution, self.label_n_distribution = vocab_dict, token_number, sen_len_distribution, label_n_distribution

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        label = int(self.data.iloc[index, 0])  # 提取标签,并转换为int类型
        content = self.data.iloc[index, 1]+' ' + \
            self.data.iloc[index, 2]  # 拼接样本中的题目和内容文本
        content = content.lower()  # 将所有单词转换为小写类型
        # 使用正则表达式,只保留文本中的数字和单词,将其余信息替换为空格
        content = re.sub(r'[^\w\s]', ' ', content)
        content = re.sub(r'\s+', ' ', content)  # 将多个空格的位置替换为1个空格

        return content, label

三、构建网络模型

对网络中的每一层都要设置初始化权重值,权重值的初始换范围一般是一个小于1的数,可以接近零,但不能是0,是0的话,模型会变得特别难训练(大量的经验总结到的)。

只设置三层简单的线性层。

python 复制代码
#! -------------------第二步:构建网络模型,构建带有Embedding层的文本分类模型-----
class TextSentiment(nn.Module):
    """文本分类模型"""

    def __init__(self, vocab_size, embed_dim, num_class):
        """description:类的初始化函数

        Args:
            vocab_size (int): 整个语料包含的不同词汇总数
            embed_dim (int): 指定词嵌入的维度
            num_class (int): 文本分类的类别总数
        """
        super().__init__()
        # 实例化Embedding层,sparse=True代表每次对该层求解梯度,只更新部分权重
        self.embedding = nn.Embedding(
            vocab_size, embedding_dim=embed_dim, sparse=True)
        # 实例化线性层,参数分别是embed_dim和num_class
        self.fc1 = nn.Linear(in_features=LEN_STA*EMBED_DIM, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=num_class)
        # 为各层初始化权重
        self.init_weight()

    def init_weight(self):
        """初始化权重函数
        """
        # 指定初始权重的取值范围数
        init_range = 0.5
        # 各层的权重参数都是初始化为均匀分布
        self.embedding.weight.data.uniform_(-init_range, init_range)
        for fc in [self.fc1, self.fc2, self.fc3]:
            fc.weight.data.uniform_(-init_range, init_range)
            # 偏置初始化为0
            fc.bias.data.zero_()

    def forward(self, text):
        """正向计算过程

        Args:
            text (list): 文本数值映射后的结果

        Returns:
            tensor: 与类别数尺寸相同的张量,用以判断文本类别
        """
        # 获得embedding的结果embedded
        # 此时embedded的尺寸为(m,32)其中m是BACTH_SIZE大小的数据中的词汇总数,32为指定词嵌入的维度EMBED_DIM
        # print(text.shape)
        embedded = self.embedding(text)
        # embedded = F.avg_pool1d(embedded, kernel_size=3)
        x = embedded.view(embedded.size(0), -1)
        # print(embedded.shape)
        # print(len_sta*EMBED_DIM)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return x

四、序列化和长度标准化

规范输入句子的长度,并进行序列化,即将文本转换为tensor类型的整数,才可以进行Embedding操作。 可以使用one-hot编码进行序列化,这里为了方便直接使用了[0,1,2,3,4,...,]这种单纯的数字。

python 复制代码
def get_length_standard(rate=0.9):
    """计算文本内容标准化长度的函数,根据样本文本长度的分布情况(从小到大),取前rate的分割点处的长度作为标准长度.

    Args:
        rate (float, optional): Defaults to 0.9.

    Returns:
        int: 样本文本长度分布中前rate的分割点处的长度
    """
    value_sum = 0  # 统计当前符合条件的样本总数
    sample_len = len(AG_NEWS)  # 数据集总长度

    # 取出每个长度对应的样本数量key=句子长度,value=该长度下的样本数量
    for key, value in AG_NEWS.sen_len_distribution.items():
        value_sum += int(value)
        if (value_sum/sample_len >= rate):
            return int(key)


def get_sen_ser(sentence, len_sta):
    """对样本内容进行标准化和序列化的函数,多删少补(补0)

    Args:
        sentence (str): [description]
        len_sta ([type]): [description]

    Returns:
        [type]: [description]
    """
    # 对句子进行序列化
    vocab_list = [AG_NEWS.vocab_dict[v] for v in sentence.split(' ')]

    if (len(vocab_list) >= len_sta):
        return vocab_list[:len_sta]
    else:
        vocab_list.extend([0]*(len_sta-len(vocab_list)))
        return vocab_list

五、自定义生成Batch的函数

python 复制代码
#! --------------------------第四步:自定义生成batch的函数----------------------

def generate_batch(batch):
    """生成batch数据的函数

    Args:
        batch (list): 由样本张量和对应标签的元组组成的batch_size大小的列表,形如:[(sample1,label1),(sample2,label2),...]
    Returns:
        tensor: 样本张量和标签各自的列表形式(张量),形如:text=tensor([sample1,sample2,....]),label=tensor([label1,label2,....])
    """
    label = []  # 存储样本标签
    text = []  # 存储样本的文本
    for t, l in batch:
        # 从batch中获得标签张量
        text.append(get_sen_ser(t, len_sta=LEN_STA))  # 对文本进行标准化和序列化处理
        # 从batch中获得样本张量
        label.append(int(l)-1)  # 序列化标签
    # text = tc.cat(text)
    # text = torch.tensor(np.array(text), device=device)
    text = torch.tensor(text, device=device)
    return text, torch.tensor(label, device=device)

六、构建训练函数

python 复制代码
#!---------------------------第五步:构建训练函数----------------------------
def train(train_data):
    """模型训练函数"""
    # 初始化训练损失和准确率为0
    train_loss = 0
    train_acc = 0

    # 使用数据加载器生成BATCH_SIZE大小的数据进行批次训练
    # data就是N多个generate_batch函数处理后的BATCH_SIZE大小的数据生成器
    data = DataLoader(train_data, batch_size=BATCH_SIZE,
                      shuffle=True, collate_fn=generate_batch)  # 使用自定义的generate_batch函数

    # 对data进行循环遍历,使用每个batch的数据进行参数更新
    for text, label in data:
        # 1、设置优化器初始梯度为0
        optimizer.zero_grad()
        # 2、模型输入一个批次数据,获得输出
        label_pre = model(text)
        # 3、根据真实标签与模型输出计算损失
        loss = loss_F(label_pre, label)
        # 4、误差反向传播
        loss.backward()
        # 5、更新参数
        optimizer.step()

        # 将该批次的损失加到总损失中
        train_loss += loss.item()
        # 将该批次的准确率加到总准确率中
        train_acc += (label_pre.argmax(1) == label).sum().item()

    # 使用学习率调节器自动调整学习率
    scheduler.step()

    # 返回本轮训练的平均损失和平均准确率
    return train_loss/len(train_data), train_acc/len(train_data)

七、构建验证函数

python 复制代码
#!-----------------------------第六步:构建验证函数------------------------
def val(val_data):
    model.eval()

    # 初始化训练损失和准确率为0
    val_loss = 0
    val_acc = 0

    # 使用数据加载器生成BATCH_SIZE大小的数据进行批次训练
    # data就是N多个generate_batch函数处理后的BATCH_SIZE大小的数据生成器
    data = DataLoader(val_data, batch_size=BATCH_SIZE,
                      shuffle=True, collate_fn=generate_batch)  # 使用自定义的generate_batch函数
    with torch.no_grad():
        for text, label in data:
            label_pre = model(text)
            # 根据真实标签与模型输出计算损失
            loss = loss_F(label_pre, label)

            # 将该损失加入到总损失中
            val_loss += loss

            # 将该次的准确个数加入到总个数中
            val_acc += (label_pre.argmax(1) == label).sum().item()
    # 返回本轮训练的平均损失和平均准确率
    return val_loss/len(val_data), val_acc/len(val_data)

八、模型训练和验证

python 复制代码
if __name__ == '__main__':

    # 设置数据的存储路径
    BASE_PATH = r'H:\Pytorch学习\Datasets\datasets\AG_NEWS'
    # 检查显卡是否可用
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # 加载训练数据
    AG_NEWS = AG_NEWS_Data(train=True)  # 加载数据集
    generator = torch.Generator().manual_seed(2024)  # 设置随机数生成器和随机种子
    AG_NEWS_train, AG_NEWS_val = random_split(  # 划分训练集和验证集
        AG_NEWS, [0.7, 0.3], generator=generator)

    VOCAB_SIZE = len(AG_NEWS.vocab_dict)  # 获取train_data语料中包含的不同词汇总数
    BATCH_SIZE = 1000  # 指定BATCH_SIZE的大小
    EMBED_DIM = 32  # 指定词嵌入的维度
    NUN_CLASS = 4  # 类别总数
    LEARN_RATE = 0.005  # 学习率
    LEN_STA = get_length_standard(0.9)  # 每句话的规范长度,统一长度,多删少补
    EPOCH = 100  # 设置数据集迭代次数

    model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)  # 实例化模型
    loss_F = nn.CrossEntropyLoss().to(device)  # 设置损失函数
    optimizer = optim.SGD(model.parameters(), lr=LEARN_RATE)  # 设置优化函数
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=1, gamma=0.9)  # 设置学习率调整器

    # 进行模型训练和验证
    for epoch in range(EPOCH):
        train_loss, train_acc = train(AG_NEWS_train)
        print(
            f'epoch {epoch}:\ttrain_loss:{train_loss:.6f}\ttrain_acc:{train_acc:.6f}', end='\t')
        val_loss, val_acc = val(AG_NEWS_val)
        print(
            f'val_loss:{val_loss:.6f}\tval_acc:{val_acc:.6f}')

九、完整代码与输出结果

(一)完整代码

python 复制代码
import re
import os
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.data import DataLoader
import torch.optim as optim
import torch


#!------------------------第一步:数据读取,构建Dataset类-----------------------------
class AG_NEWS_Data(Dataset):
    def __init__(self, train=True) -> None:
        super().__init__()
        data_path = os.path.join(BASE_PATH, 'train.csv') if train else os.path.join(
            BASE_PATH, 'test.csv')  # 设置数据路径,本实验中只使用了训练集
        self.data = pd.read_csv(data_path, sep=',', header=None)  # 读取数据
        # print(self.data.head())

        sen_len = []  # 每条样本中文本句子长度
        self.contents = ''  # 所有样本分词后的内容
        token_number = 0  # 所有文本中有多少个不同的分词
        label_count = []  # 所有样本的label标签

        # * 计算每条样本的长度,取出每条样本的标签label,拼接所有样本内容到contents中
        for i in range(self.__len__()):
            content, label = self.__getitem__(i)
        # for content, label in data:
            sen_len.append(len(content.split(' ')))  # 每条样本的长度
            label_count.append(label)  # 取出每条样本的标签label
            self.contents += ' '+content  # 拼接样本内容到contents中

        vocab_dict = {v: idx for idx, v in enumerate(
            set(self.contents.split(' ')))}  # 获取所有分词集合
        token_number = len(vocab_dict)
        sen_len_distribution = {str(i): sen_len.count(i) for i in sorted(
            set(sen_len))}  # 句子长度分布的字典,如{'80':192,'81':689,...},即长度为80的句子有192个...
        label_n_distribution = {str(i): label_count.count(i) for i in set(
            label_count)}  # 标签数量分布的字典,如{'1':20000,'2':20000,...},每个标签对应的样本个数

        self.vocab_dict, self.token_number, self.sen_len_distribution, self.label_n_distribution = vocab_dict, token_number, sen_len_distribution, label_n_distribution

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        label = int(self.data.iloc[index, 0])  # 提取标签,并转换为int类型
        content = self.data.iloc[index, 1]+' ' + \
            self.data.iloc[index, 2]  # 拼接样本中的题目和内容文本
        content = content.lower()  # 将所有单词转换为小写类型
        # 使用正则表达式,只保留文本中的数字和单词,将其余信息替换为空格
        content = re.sub(r'[^\w\s]', ' ', content)
        content = re.sub(r'\s+', ' ', content)  # 将多个空格的位置替换为1个空格

        return content, label

#! --------------------第二步:构建网络模型,构建带有Embedding层的文本分类模型------
class TextSentiment(nn.Module):
    """文本分类模型"""

    def __init__(self, vocab_size, embed_dim, num_class):
        """description:类的初始化函数

        Args:
            vocab_size (int): 整个语料包含的不同词汇总数
            embed_dim (int): 指定词嵌入的维度
            num_class (int): 文本分类的类别总数
        """
        super().__init__()
        # 实例化Embedding层,sparse=True代表每次对该层求解梯度,只更新部分权重
        self.embedding = nn.Embedding(
            vocab_size, embedding_dim=embed_dim, sparse=True)
        # 实例化线性层,参数分别是embed_dim和num_class
        self.fc1 = nn.Linear(in_features=LEN_STA*EMBED_DIM, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=num_class)
        # 为各层初始化权重
        self.init_weight()

    def init_weight(self):
        """初始化权重函数
        """
        # 指定初始权重的取值范围数
        init_range = 0.5
        # 各层的权重参数都是初始化为均匀分布
        self.embedding.weight.data.uniform_(-init_range, init_range)
        for fc in [self.fc1, self.fc2, self.fc3]:
            fc.weight.data.uniform_(-init_range, init_range)
            # 偏置初始化为0
            fc.bias.data.zero_()

    def forward(self, text):
        """正向计算过程

        Args:
            text (list): 文本数值映射后的结果

        Returns:
            tensor: 与类别数尺寸相同的张量,用以判断文本类别
        """
        # 获得embedding的结果embedded
        # 此时embedded的尺寸为(m,32)其中m是BACTH_SIZE大小的数据中的词汇总数,32为指定词嵌入的维度EMBED_DIM
        # print(text.shape)
        embedded = self.embedding(text)
        # embedded = F.avg_pool1d(embedded, kernel_size=3)
        x = embedded.view(embedded.size(0), -1)
        # print(embedded.shape)
        # print(len_sta*EMBED_DIM)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return x


#! ----------------------第三步:将每个样本中的句子进行长度标准化和序列化-----------------
def get_length_standard(rate=0.9):
    """计算文本内容标准化长度的函数,根据样本文本长度的分布情况(从小到大),取前rate的分割点处的长度作为标准长度.

    Args:
        rate (float, optional): Defaults to 0.9.

    Returns:
        int: 样本文本长度分布中前rate的分割点处的长度
    """
    value_sum = 0  # 统计当前符合条件的样本总数
    sample_len = len(AG_NEWS)  # 数据集总长度

    # 取出每个长度对应的样本数量key=句子长度,value=该长度下的样本数量
    for key, value in AG_NEWS.sen_len_distribution.items():
        value_sum += int(value)
        if (value_sum/sample_len >= rate):
            return int(key)


def get_sen_ser(sentence, len_sta):
    """对样本内容进行标准化和序列化的函数,多删少补(补0)

    Args:
        sentence (str): [description]
        len_sta ([type]): [description]

    Returns:
        [type]: [description]
    """
    # 对句子进行序列化
    vocab_list = [AG_NEWS.vocab_dict[v] for v in sentence.split(' ')]

    if (len(vocab_list) >= len_sta):
        return vocab_list[:len_sta]
    else:
        vocab_list.extend([0]*(len_sta-len(vocab_list)))
        return vocab_list


#! --------------------------第四步:自定义生成batch的函数-------------------------
def generate_batch(batch):
    """生成batch数据的函数

    Args:
        batch (list): 由样本张量和对应标签的元组组成的batch_size大小的列表,形如:[(sample1,label1),(sample2,label2),...]
    Returns:
        tensor: 样本张量和标签各自的列表形式(张量),形如:text=tensor([sample1,sample2,....]),label=tensor([label1,label2,....])
    """
    label = []  # 存储样本标签
    text = []  # 存储样本的文本
    for t, l in batch:
        # 从batch中获得标签张量
        text.append(get_sen_ser(t, len_sta=LEN_STA))  # 对文本进行标准化和序列化处理
        # 从batch中获得样本张量
        label.append(int(l)-1)  # 序列化标签
    # text = tc.cat(text)
    # text = torch.tensor(np.array(text), device=device)
    text = torch.tensor(text, device=device)
    return text, torch.tensor(label, device=device)


#!-----------------------------------第五步:构建训练函数-------------------------
def train(train_data):
    """模型训练函数"""
    # 初始化训练损失和准确率为0
    train_loss = 0
    train_acc = 0

    # 使用数据加载器生成BATCH_SIZE大小的数据进行批次训练
    # data就是N多个generate_batch函数处理后的BATCH_SIZE大小的数据生成器
    data = DataLoader(train_data, batch_size=BATCH_SIZE,
                      shuffle=True, collate_fn=generate_batch)  # 使用自定义的generate_batch函数

    # 对data进行循环遍历,使用每个batch的数据进行参数更新
    for text, label in data:
        # 1、设置优化器初始梯度为0
        optimizer.zero_grad()
        # 2、模型输入一个批次数据,获得输出
        label_pre = model(text)
        # 3、根据真实标签与模型输出计算损失
        loss = loss_F(label_pre, label)
        # 4、误差反向传播
        loss.backward()
        # 5、更新参数
        optimizer.step()

        # 将该批次的损失加到总损失中
        train_loss += loss.item()
        # 将该批次的准确率加到总准确率中
        train_acc += (label_pre.argmax(1) == label).sum().item()

    # 使用学习率调节器自动调整学习率
    scheduler.step()

    # 返回本轮训练的平均损失和平均准确率
    return train_loss/len(train_data), train_acc/len(train_data)

#!-----------------------------第六步:构建验证函数--------------------------
def val(val_data):
    model.eval()

    # 初始化训练损失和准确率为0
    val_loss = 0
    val_acc = 0

    # 使用数据加载器生成BATCH_SIZE大小的数据进行批次训练
    # data就是N多个generate_batch函数处理后的BATCH_SIZE大小的数据生成器
    data = DataLoader(val_data, batch_size=BATCH_SIZE,
                      shuffle=True, collate_fn=generate_batch)  # 使用自定义的generate_batch函数
    with torch.no_grad():
        for text, label in data:
            label_pre = model(text)
            # 根据真实标签与模型输出计算损失
            loss = loss_F(label_pre, label)

            # 将该损失加入到总损失中
            val_loss += loss

            # 将该次的准确个数加入到总个数中
            val_acc += (label_pre.argmax(1) == label).sum().item()
    # 返回本轮训练的平均损失和平均准确率
    return val_loss/len(val_data), val_acc/len(val_data)


#! --------------------------第七步:进行模型训练和验证---------------------------------
if __name__ == '__main__':

    # 设置数据的存储路径
    BASE_PATH = r'H:\Pytorch学习\Datasets\datasets\AG_NEWS'
    # 检查显卡是否可用
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # 加载训练数据
    AG_NEWS = AG_NEWS_Data(train=True)  # 加载数据集
    generator = torch.Generator().manual_seed(2024)  # 设置随机数生成器和随机种子
    AG_NEWS_train, AG_NEWS_val = random_split(  # 划分训练集和验证集
        AG_NEWS, [0.7, 0.3], generator=generator)

    VOCAB_SIZE = len(AG_NEWS.vocab_dict)  # 获取train_data语料中包含的不同词汇总数
    BATCH_SIZE = 1000  # 指定BATCH_SIZE的大小
    EMBED_DIM = 32  # 指定词嵌入的维度
    NUN_CLASS = 4  # 类别总数
    LEARN_RATE = 0.005  # 学习率
    LEN_STA = get_length_standard(0.9)  # 每句话的规范长度,统一长度,多删少补
    EPOCH = 100  # 设置数据集迭代次数

    model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)  # 实例化模型
    loss_F = nn.CrossEntropyLoss().to(device)  # 设置损失函数
    optimizer = optim.SGD(model.parameters(), lr=LEARN_RATE)  # 设置优化函数
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=1, gamma=0.9)  # 设置学习率调整器

    # 进行模型训练和验证
    for epoch in range(EPOCH):
        train_loss, train_acc = train(AG_NEWS_train)
        print(
            f'epoch {epoch}:\ttrain_loss:{train_loss:.6f}\ttrain_acc:{train_acc:.6f}', end='\t')
        val_loss, val_acc = val(AG_NEWS_val)
        print(
            f'val_loss:{val_loss:.6f}\tval_acc:{val_acc:.6f}')

(二)输出结果

EPOCH = 100 ,设置数据集迭代了100次,结果如下,可以看出,模型能力有限,有预测能力,但只有一点点。

epoch 0: train_loss:0.017730 train_acc:0.267131 val_loss:0.003573 val_acc:0.266333

epoch 1: train_loss:0.002014 train_acc:0.274238 val_loss:0.001452 val_acc:0.284139

epoch 2: train_loss:0.001417 train_acc:0.289357 val_loss:0.001417 val_acc:0.289583

epoch 3: train_loss:0.001396 train_acc:0.292762 val_loss:0.001392 val_acc:0.293667

epoch 4: train_loss:0.001389 train_acc:0.294369 val_loss:0.001386 val_acc:0.298972

epoch 5: train_loss:0.001383 train_acc:0.298071 val_loss:0.001384 val_acc:0.297611

epoch 6: train_loss:0.001381 train_acc:0.300738 val_loss:0.001382 val_acc:0.303028

epoch 7: train_loss:0.001379 train_acc:0.303667 val_loss:0.001379 val_acc:0.302861

epoch 8: train_loss:0.001376 train_acc:0.304119 val_loss:0.001375 val_acc:0.303528

epoch 9: train_loss:0.001375 train_acc:0.304893 val_loss:0.001376 val_acc:0.300528

epoch 10: train_loss:0.001374 train_acc:0.307119 val_loss:0.001372 val_acc:0.308639

epoch 11: train_loss:0.001372 train_acc:0.308905 val_loss:0.001374 val_acc:0.303667

epoch 12: train_loss:0.001371 train_acc:0.310357 val_loss:0.001372 val_acc:0.309667

epoch 13: train_loss:0.001370 train_acc:0.311393 val_loss:0.001372 val_acc:0.309917

epoch 14: train_loss:0.001369 train_acc:0.311607 val_loss:0.001370 val_acc:0.308667

epoch 15: train_loss:0.001369 train_acc:0.311929 val_loss:0.001370 val_acc:0.312222

epoch 16: train_loss:0.001368 train_acc:0.312952 val_loss:0.001369 val_acc:0.309778

epoch 17: train_loss:0.001368 train_acc:0.313524 val_loss:0.001367 val_acc:0.314528

epoch 18: train_loss:0.001367 train_acc:0.313905 val_loss:0.001368 val_acc:0.315444

epoch 19: train_loss:0.001367 train_acc:0.314810 val_loss:0.001367 val_acc:0.315694

epoch 20: train_loss:0.001366 train_acc:0.315952 val_loss:0.001368 val_acc:0.313333

epoch 21: train_loss:0.001366 train_acc:0.317262 val_loss:0.001367 val_acc:0.314750

epoch 22: train_loss:0.001366 train_acc:0.315976 val_loss:0.001366 val_acc:0.316222

epoch 23: train_loss:0.001365 train_acc:0.317345 val_loss:0.001366 val_acc:0.316139

epoch 24: train_loss:0.001365 train_acc:0.315976 val_loss:0.001366 val_acc:0.316444

epoch 25: train_loss:0.001365 train_acc:0.316786 val_loss:0.001366 val_acc:0.314111

epoch 26: train_loss:0.001365 train_acc:0.316905 val_loss:0.001365 val_acc:0.318611

epoch 27: train_loss:0.001364 train_acc:0.318774 val_loss:0.001365 val_acc:0.316944

epoch 28: train_loss:0.001364 train_acc:0.319036 val_loss:0.001366 val_acc:0.314944

epoch 29: train_loss:0.001364 train_acc:0.318393 val_loss:0.001365 val_acc:0.316111

epoch 30: train_loss:0.001364 train_acc:0.319250 val_loss:0.001365 val_acc:0.316833

epoch 31: train_loss:0.001364 train_acc:0.318440 val_loss:0.001365 val_acc:0.317444

epoch 32: train_loss:0.001364 train_acc:0.319500 val_loss:0.001365 val_acc:0.316444

epoch 33: train_loss:0.001364 train_acc:0.319333 val_loss:0.001365 val_acc:0.315972

epoch 34: train_loss:0.001363 train_acc:0.319786 val_loss:0.001365 val_acc:0.315389

epoch 35: train_loss:0.001363 train_acc:0.319560 val_loss:0.001365 val_acc:0.316583

epoch 36: train_loss:0.001363 train_acc:0.320024 val_loss:0.001365 val_acc:0.316556

epoch 37: train_loss:0.001363 train_acc:0.320774 val_loss:0.001365 val_acc:0.316639

epoch 38: train_loss:0.001363 train_acc:0.320179 val_loss:0.001365 val_acc:0.315889

epoch 39: train_loss:0.001363 train_acc:0.320393 val_loss:0.001365 val_acc:0.315139

epoch 40: train_loss:0.001363 train_acc:0.320774 val_loss:0.001365 val_acc:0.316278

epoch 41: train_loss:0.001363 train_acc:0.320821 val_loss:0.001365 val_acc:0.315167

epoch 42: train_loss:0.001363 train_acc:0.321167 val_loss:0.001365 val_acc:0.315667

epoch 43: train_loss:0.001363 train_acc:0.320619 val_loss:0.001365 val_acc:0.316167

epoch 44: train_loss:0.001363 train_acc:0.320571 val_loss:0.001365 val_acc:0.316778

epoch 45: train_loss:0.001363 train_acc:0.321714 val_loss:0.001365 val_acc:0.316611

epoch 46: train_loss:0.001363 train_acc:0.321143 val_loss:0.001365 val_acc:0.316000

epoch 47: train_loss:0.001363 train_acc:0.321262 val_loss:0.001365 val_acc:0.316056

epoch 48: train_loss:0.001363 train_acc:0.321429 val_loss:0.001365 val_acc:0.315722

epoch 49: train_loss:0.001363 train_acc:0.321036 val_loss:0.001365 val_acc:0.315917

epoch 50: train_loss:0.001363 train_acc:0.321417 val_loss:0.001365 val_acc:0.315639

epoch 51: train_loss:0.001362 train_acc:0.321560 val_loss:0.001365 val_acc:0.315889

epoch 52: train_loss:0.001362 train_acc:0.321524 val_loss:0.001365 val_acc:0.316056

epoch 53: train_loss:0.001362 train_acc:0.321690 val_loss:0.001365 val_acc:0.315889

epoch 54: train_loss:0.001362 train_acc:0.321429 val_loss:0.001365 val_acc:0.316028

epoch 55: train_loss:0.001362 train_acc:0.321536 val_loss:0.001365 val_acc:0.316083

epoch 56: train_loss:0.001362 train_acc:0.321417 val_loss:0.001365 val_acc:0.315639

epoch 57: train_loss:0.001362 train_acc:0.321476 val_loss:0.001365 val_acc:0.315750

epoch 58: train_loss:0.001362 train_acc:0.321512 val_loss:0.001365 val_acc:0.315806

epoch 59: train_loss:0.001362 train_acc:0.321452 val_loss:0.001365 val_acc:0.315861

epoch 60: train_loss:0.001362 train_acc:0.321750 val_loss:0.001365 val_acc:0.316000

epoch 61: train_loss:0.001362 train_acc:0.321298 val_loss:0.001365 val_acc:0.315889

epoch 62: train_loss:0.001362 train_acc:0.321405 val_loss:0.001365 val_acc:0.316000

epoch 63: train_loss:0.001362 train_acc:0.321607 val_loss:0.001365 val_acc:0.315972

epoch 64: train_loss:0.001362 train_acc:0.321583 val_loss:0.001365 val_acc:0.316111

epoch 65: train_loss:0.001362 train_acc:0.321452 val_loss:0.001365 val_acc:0.316056

epoch 66: train_loss:0.001362 train_acc:0.321452 val_loss:0.001365 val_acc:0.316111

epoch 67: train_loss:0.001362 train_acc:0.321583 val_loss:0.001365 val_acc:0.316083

epoch 68: train_loss:0.001362 train_acc:0.321464 val_loss:0.001365 val_acc:0.316111

epoch 69: train_loss:0.001362 train_acc:0.321679 val_loss:0.001365 val_acc:0.316139

epoch 70: train_loss:0.001362 train_acc:0.321476 val_loss:0.001365 val_acc:0.316139

epoch 71: train_loss:0.001362 train_acc:0.321714 val_loss:0.001365 val_acc:0.316111

epoch 72: train_loss:0.001362 train_acc:0.321679 val_loss:0.001365 val_acc:0.316056

epoch 73: train_loss:0.001362 train_acc:0.321560 val_loss:0.001365 val_acc:0.316056

epoch 74: train_loss:0.001362 train_acc:0.321583 val_loss:0.001365 val_acc:0.316028

epoch 75: train_loss:0.001362 train_acc:0.321548 val_loss:0.001365 val_acc:0.316028

epoch 76: train_loss:0.001362 train_acc:0.321500 val_loss:0.001365 val_acc:0.316083

epoch 77: train_loss:0.001362 train_acc:0.321548 val_loss:0.001365 val_acc:0.316056

epoch 78: train_loss:0.001362 train_acc:0.321571 val_loss:0.001365 val_acc:0.316056

epoch 79: train_loss:0.001362 train_acc:0.321548 val_loss:0.001365 val_acc:0.316028

epoch 80: train_loss:0.001362 train_acc:0.321631 val_loss:0.001365 val_acc:0.316028

epoch 81: train_loss:0.001362 train_acc:0.321512 val_loss:0.001365 val_acc:0.316028

epoch 82: train_loss:0.001362 train_acc:0.321536 val_loss:0.001365 val_acc:0.316056

epoch 83: train_loss:0.001362 train_acc:0.321583 val_loss:0.001365 val_acc:0.316056

epoch 84: train_loss:0.001362 train_acc:0.321512 val_loss:0.001365 val_acc:0.316056

epoch 85: train_loss:0.001362 train_acc:0.321560 val_loss:0.001365 val_acc:0.316056

epoch 86: train_loss:0.001362 train_acc:0.321583 val_loss:0.001365 val_acc:0.316056

epoch 87: train_loss:0.001362 train_acc:0.321536 val_loss:0.001365 val_acc:0.316056

epoch 88: train_loss:0.001362 train_acc:0.321548 val_loss:0.001365 val_acc:0.316056

epoch 89: train_loss:0.001362 train_acc:0.321571 val_loss:0.001365 val_acc:0.316056

epoch 90: train_loss:0.001362 train_acc:0.321536 val_loss:0.001365 val_acc:0.316056

epoch 91: train_loss:0.001362 train_acc:0.321560 val_loss:0.001365 val_acc:0.316056

epoch 92: train_loss:0.001362 train_acc:0.321560 val_loss:0.001365 val_acc:0.316083

epoch 93: train_loss:0.001362 train_acc:0.321583 val_loss:0.001365 val_acc:0.316083

epoch 94: train_loss:0.001362 train_acc:0.321571 val_loss:0.001365 val_acc:0.316056

epoch 95: train_loss:0.001362 train_acc:0.321583 val_loss:0.001365 val_acc:0.316056

epoch 96: train_loss:0.001362 train_acc:0.321548 val_loss:0.001365 val_acc:0.316056

epoch 97: train_loss:0.001362 train_acc:0.321536 val_loss:0.001365 val_acc:0.316056

epoch 98: train_loss:0.001362 train_acc:0.321548 val_loss:0.001365 val_acc:0.316056

epoch 99: train_loss:0.001362 train_acc:0.321560 val_loss:0.001365 val_acc:0.316056

相关推荐
m0_7482329231 分钟前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理
靴子学长8 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
AIGCmagic社区16 小时前
AI多模态技术介绍:理解多模态大语言模型的原理
人工智能·语言模型·自然语言处理
开放知识图谱19 小时前
论文浅尝 | HippoRAG:神经生物学启发的大语言模型的长期记忆(Neurips2024)
人工智能·语言模型·自然语言处理
i查拉图斯特拉如是1 天前
基于MindSpore NLP的PEFT微调
人工智能·自然语言处理
野蛮的大西瓜2 天前
BigBlueButton视频会议 vs 钉钉视频会议系统的详细对比
人工智能·自然语言处理·自动化·音视频·实时音视频·信息与通信·视频编解码
Hugging Face2 天前
欢迎 PaliGemma 2 – 来自 Google 的新视觉语言模型
人工智能·语言模型·自然语言处理
宝贝儿好2 天前
【NLP】第七章:Transformer原理及实操
人工智能·深度学习·自然语言处理·transformer
新加坡内哥谈技术2 天前
OpenAI发布全新AI模型 o3 与 o3-mini:推理与编码能力迎来重大突破. AGI 来临
大数据·人工智能·语言模型·自然语言处理
三月七(爱看动漫的程序员)2 天前
Knowledge Graph Prompting for Multi-Document Question Answering
人工智能·gpt·学习·语言模型·自然语言处理·机器人·知识图谱