【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

【Pytorch】学习记录分享9------PyTorch新闻数据集文本分类任务

      • [1. 认为主流程code](#1. 认为主流程code)
      • [2. NLP 对话和预测基本均属于分类任务详细见](#2. NLP 对话和预测基本均属于分类任务详细见)
      • [3. Tensorborad](#3. Tensorborad)

1. 认为主流程code

python 复制代码
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

###制定参数 --model TextRNN
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  #TextCNN, TextRNN,
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # 保证每次结果一样

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter,writer)

RNN

python 复制代码
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x
        out = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden state
        return out

TextRNN h_t 为RNN提取出来的特征

2. NLP 对话和预测基本均属于分类任务详细见

Pytorch学习记录分享9-PyTorch新闻数据集文本分类任务实战

3. Tensorborad

数据可视化操作 code repo

相关推荐
走在路上的菜鸟3 分钟前
Android学Dart学习笔记第二十三节 类-扩展类型
android·笔记·学习·flutter
墨_浅-4 分钟前
教育/培训行业智能体应用分类及知识库检索模型微调
人工智能·分类·数据挖掘
愤怒学习的白菜13 分钟前
0 trivial:UVM的空壳平台
学习·uvm·ic验证
快乐非自愿14 分钟前
Java函数式接口——渐进式学习
java·开发语言·学习
心动啊12136 分钟前
负载均衡 + Nginx的基本使用
学习·nginx·负载均衡
菜鸟‍36 分钟前
【课程学习】
学习·信息与通信
暗然而日章37 分钟前
C++基础:Stanford CS106L学习笔记 11 Lambdas表达式
c++·笔记·学习
lxh011341 分钟前
2025/12/19学习记录
学习
辞旧 lekkk43 分钟前
【c++】c++11(上)
开发语言·c++·学习·萌新
走在路上的菜鸟1 小时前
Android学Dart学习笔记第二十一节 类-点的简写
android·笔记·学习·flutter