【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

【Pytorch】学习记录分享9------PyTorch新闻数据集文本分类任务

      • [1. 认为主流程code](#1. 认为主流程code)
      • [2. NLP 对话和预测基本均属于分类任务详细见](#2. NLP 对话和预测基本均属于分类任务详细见)
      • [3. Tensorborad](#3. Tensorborad)

1. 认为主流程code

python 复制代码
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

###制定参数 --model TextRNN
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  #TextCNN, TextRNN,
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # 保证每次结果一样

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter,writer)

RNN

python 复制代码
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x
        out = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden state
        return out

TextRNN h_t 为RNN提取出来的特征

2. NLP 对话和预测基本均属于分类任务详细见

Pytorch学习记录分享9-PyTorch新闻数据集文本分类任务实战

3. Tensorborad

数据可视化操作 code repo

相关推荐
我叫唧唧波6 小时前
Python+AI 全栈学习笔记
人工智能·python·学习
城北徐宫7 小时前
Linux信号深度解剖:5种产生、3张表、4次切换
linux·c++·学习
三品吉他手会点灯8 小时前
C语言学习笔记 - 43.运算符与表达式 - 运算符1 - 运算符的分类和简单介绍
c语言·笔记·学习·算法
吃好睡好便好9 小时前
芒种时节如何保健
学习·生活
lizhihai_9910 小时前
股市学习心得-A股服务器/算力服务器龙头
大数据·运维·服务器·人工智能·科技·学习
烛之武10 小时前
Pytorch学习笔记(1)
pytorch·笔记·学习
飞翔中文网12 小时前
Java学习笔记之反射
java·笔记·学习
知南x13 小时前
【DPDK核心知识了解】(2) 内核旁路与硬件交互
学习
零陵上将军_xdr13 小时前
后端转全栈学习-Day4-JavaScript 基础-2
开发语言·javascript·学习
一楼的猫13 小时前
叙事指纹93.2%的技术确认与AI写作同质化——网文创作的差异化路径分析
人工智能·学习·机器学习·写作·ai写作