【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

【Pytorch】学习记录分享9------PyTorch新闻数据集文本分类任务

      • [1. 认为主流程code](#1. 认为主流程code)
      • [2. NLP 对话和预测基本均属于分类任务详细见](#2. NLP 对话和预测基本均属于分类任务详细见)
      • [3. Tensorborad](#3. Tensorborad)

1. 认为主流程code

python 复制代码
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

###制定参数 --model TextRNN
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  #TextCNN, TextRNN,
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # 保证每次结果一样

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter,writer)

RNN

python 复制代码
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x
        out = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden state
        return out

TextRNN h_t 为RNN提取出来的特征

2. NLP 对话和预测基本均属于分类任务详细见

Pytorch学习记录分享9-PyTorch新闻数据集文本分类任务实战

3. Tensorborad

数据可视化操作 code repo

相关推荐
rannn_1114 分钟前
【苍穹外卖|Day4】套餐页面开发(新增套餐、分页查询、删除套餐、修改套餐、起售停售)
java·spring boot·后端·学习
张人玉18 分钟前
VisionPro 定位与卡尺测量学习笔记
笔记·学习·计算机视觉·vsionprp
觉醒大王1 小时前
强女思维:着急,是贪欲外显的相。
java·论文阅读·笔记·深度学习·学习·自然语言处理·学习方法
哈__1 小时前
CANN内存管理与资源优化
人工智能·pytorch
YCY^v^2 小时前
JeecgBoot 项目运行指南
java·学习
云小逸2 小时前
【nmap源码解析】Nmap OS识别核心模块深度解析:osscan2.cc源码剖析(1)
开发语言·网络·学习·nmap
是小蟹呀^2 小时前
从稀疏到自适应:人脸识别中稀疏表示的核心演进
人工智能·分类
JustDI-CM2 小时前
AI学习笔记-提示词工程
人工智能·笔记·学习
悟纤2 小时前
学习与专注音乐流派 (Study & Focus Music):AI 音乐创作终极指南 | Suno高级篇 | 第33篇
大数据·人工智能·深度学习·学习·suno·suno api
爱写bug的野原新之助2 小时前
加密摘要算法MD5、SHA、HMAC:学习笔记
笔记·学习