【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

【Pytorch】学习记录分享9------PyTorch新闻数据集文本分类任务

      • [1. 认为主流程code](#1. 认为主流程code)
      • [2. NLP 对话和预测基本均属于分类任务详细见](#2. NLP 对话和预测基本均属于分类任务详细见)
      • [3. Tensorborad](#3. Tensorborad)

1. 认为主流程code

python 复制代码
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

###制定参数 --model TextRNN
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  #TextCNN, TextRNN,
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # 保证每次结果一样

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter,writer)

RNN

python 复制代码
class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x
        out = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden state
        return out

TextRNN h_t 为RNN提取出来的特征

2. NLP 对话和预测基本均属于分类任务详细见

Pytorch学习记录分享9-PyTorch新闻数据集文本分类任务实战

3. Tensorborad

数据可视化操作 code repo

相关推荐
下午见。6 分钟前
C语言结构体入门:定义、访问与传参全解析
c语言·笔记·学习
im_AMBER9 分钟前
React 16
前端·笔记·学习·react.js·前端框架
民乐团扒谱机3 小时前
实验室安全教育与管理平台学习记录(七)网络安全
学习·安全·web安全
蒙奇D索大3 小时前
【11408学习记录】考研英语长难句精析:三步拆解真题复杂结构,轻松攻克阅读难关!
笔记·学习·考研·改行学it
zd2005724 小时前
AI辅助数据分析和学习了没?
人工智能·学习
洛白白4 小时前
“职场心态与心穷
经验分享·学习·生活·学习方法
_dindong5 小时前
笔试强训:Week-4
数据结构·c++·笔记·学习·算法·哈希算法·散列表
~~李木子~~5 小时前
Windows软件自动扫描与分类工具 - 技术文档
windows·分类·数据挖掘
DKPT6 小时前
如何设置JVM参数避开直接内存溢出的坑?
java·开发语言·jvm·笔记·学习
一 乐6 小时前
智慧党建|党务学习|基于SprinBoot+vue的智慧党建学习平台(源码+数据库+文档)
java·前端·数据库·vue.js·spring boot·学习