PyTorch深度学习模型训练流程:(一、分类)

自己写了个封装PyTorch深度学习训练流程的函数,实现了根据输入参数训练模型并可视化训练过程的功能,可以方便快捷地检验一个模型的效果,有助于提高选择模型架构、优化超参数等工作的效率。发出来供大家参考,如有不足之处,欢迎批评讨论。

分类是人工智能的一个非常重要的应用,这篇文章分享的函数适用于实现分类的深度学习模型,包括以下功能:

  1. 根据输入的数据集、模型、优化器、损失函数等参数训练一个分类模型;
  2. 使用visdom可视化训练过程,实时输出精确度曲线、损失曲线、混淆矩阵和ROC曲线;
  3. 支持二分类和多分类;
  4. 输入数据集支持形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型,可以方便灵活地调用torch内置数据集或自定义数据集;
  5. 支持使用GPU加速深度学习模型的训练。

废话不多说,先来看下输出效果:
二分类
多分类

深度学习的完整流程通常包括以下几个步骤:

  1. 收集数据
  2. 数据预处理
  3. 选择模型
  4. 训练模型
  5. 评估模型
  6. 超参数调优
  7. 测试模型

本函数封装了训练模型和评估模型的步骤,包括:

  1. 若数据集为(X,y)形式则分离训练集和测试集(测试集占20%),数据标准化,封装训练集和测试集;
  2. 将训练集和测试集设置为加载器;
  3. 遍历训练集加载器,计算每一批次的输出和损失,并反向传播更新神经网络参数;
  4. 每迭代100次评估一下模型,用测试集数据计算并画出精确度曲线、损失曲线、混淆矩阵和ROC曲线。

代码如下:

python 复制代码
from functools import partial
import numpy as np
import pandas as pd
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, r2_score

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from visdom import Visdom

from typing import Union, Optional
from sklearn.base import TransformerMixin
from torch.optim.optimizer import Optimizer


def classify(
        data: tuple[Union[np.ndarray, Dataset], Union[np.ndarray, Dataset]],
        model: nn.Module,
        optimizer: Optimizer,
        criterion: nn.Module,
        scaler: Optional[TransformerMixin] = None,
        batch_size: int = 64,
        epochs: int = 10,
        device: Optional[torch.device] = None
) -> nn.Module:
    """
    分类任务的训练函数。
    :param data: 形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型
    :param model: 分类模型
    :param optimizer: 优化器
    :param criterion: 损失函数
    :param scaler: 数据标准化器
    :param batch_size: 批大小
    :param epochs: 训练轮数
    :param device: 训练设备
    :return: 训练好的分类模型
    """
    if isinstance(data[0], np.ndarray):
        X, y = data
        # 处理类别
        classes = np.unique(y)
        classes_str = [str(i) for i in classes]
        num_classes = len(classes)
        # 分离训练集和测试集,指定随机种子以便复现
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        # 数据标准化
        if scaler is not None:
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        # 转换为tensor
        X_train = torch.from_numpy(X_train.astype(np.float32))
        X_test = torch.from_numpy(X_test.astype(np.float32))
        y_train = torch.from_numpy(y_train.astype(np.int64))
        y_test = torch.from_numpy(y_test.astype(np.int64))
        # 将X和y封装成TensorDataset
        train_dataset = TensorDataset(X_train, y_train)
        test_dataset = TensorDataset(X_test, y_test)

    elif isinstance(data[0], Dataset):
        train_dataset, test_dataset = data
        classes = list(train_dataset.class_to_idx.values())
        classes_str = train_dataset.classes
        num_classes = len(classes)
    else:
        raise ValueError('Unsupported data type')

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )

    model.to(device)
    vis = Visdom()
    # 训练模型
    for epoch in range(epochs):
        for step, (batch_x_train, batch_y_train) in enumerate(train_loader):
            batch_x_train = batch_x_train.to(device)
            batch_y_train = batch_y_train.to(device)
            # 前向传播
            output = model(batch_x_train)
            loss = criterion(output, batch_y_train)
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            niter = epoch * len(train_loader) + step + 1  # 计算迭代次数
            if niter % 100 == 0:
                # 评估模型
                model.eval()
                with torch.no_grad():
                    eval_dict = {
                        'test_loss': [],
                        'test_acc': [],
                        'test_cm': [],
                        'test_roc': [],
                    }
                    for batch_x_test, batch_y_test in test_loader:
                        batch_x_test = batch_x_test.to(device)
                        batch_y_test = batch_y_test.to(device)
                        test_output = model(batch_x_test)
                        predicted = torch.argmax(test_output, 1)
                        test_predicted_tuple = (batch_y_test.numpy(), predicted.numpy())
                        # 计算并记录损失、精确度、混淆矩阵、ROC曲线
                        eval_dict['test_loss'].append(criterion(test_output, batch_y_test))
                        eval_dict['test_acc'].append(accuracy_score(*test_predicted_tuple))
                        eval_dict['test_cm'].append(confusion_matrix(*test_predicted_tuple, labels=classes))
                        if num_classes == 2:
                            # eval_dict['test_roc']形状为(len,(fpr,tpr),3)
                            eval_dict['test_roc'].append(roc_curve(*test_predicted_tuple)[:2])
                        else:
                            # 多分类ROC曲线需要one-hot编码
                            y_test_one_hot, predicted_one_hot = map(partial(label_binarize, classes=classes), test_predicted_tuple)
                            fpr_list = []
                            tpr_list = []
                            for i in range(num_classes):
                                fpr, tpr, _ = roc_curve(y_test_one_hot[:, i], predicted_one_hot[:, i])
                                # 无(fpr,tpr)数据点时,插值填充(0,0)数据点
                                if len(fpr) != 3:
                                    fpr = np.insert(fpr, 0, 0)
                                    tpr = np.insert(tpr, 0, 0)
                                fpr_list.append(fpr)
                                tpr_list.append(tpr)
                            # eval_dict['test_roc']形状为(len,(fpr,tpr),num_classes,3)
                            eval_dict['test_roc'].append((fpr_list, tpr_list))

                    # 画出损失曲线
                    vis.line(
                        X=torch.ones((1, 2)) * (niter // 100),
                        Y=torch.stack((loss, torch.mean(torch.tensor(eval_dict['test_loss'])))).unsqueeze(0),
                        win='loss',
                        update='append',
                        opts=dict(title='Loss', legend=['train_loss', 'test_loss']),
                    )
                    # 画出精确度曲线
                    train_acc = accuracy_score(batch_y_train.numpy(), torch.argmax(output, 1).numpy())
                    vis.line(
                        X=torch.ones((1, 2)) * (niter // 100),
                        Y=torch.tensor((train_acc, np.mean(eval_dict['test_acc']))).unsqueeze(0),
                        win='accuracy',
                        update='append',
                        opts=dict(title='Accuracy', legend=['train_acc', 'test_acc'], ytickmin=0, ytickmax=1),
                    )
                    # 画出混淆矩阵
                    vis.heatmap(
                        X=np.add.reduce(eval_dict['test_cm']),
                        win='confusion_matrix',
                        opts=dict(title='Confusion Matrix', columnnames=classes_str, rownames=classes_str),
                    )
                    # 画出ROC曲线
                    test_roc_arr = np.array(eval_dict['test_roc'])
                    zeros_df = pd.DataFrame({'fpr': [0], 'tpr': [0]})  # 用于填充的(0,0)数据点
                    ones_df = pd.DataFrame({'fpr': [1], 'tpr': [1]})  # 用于填充的(1,1)数据点
                    if num_classes == 2:
                        plot_arr = test_roc_arr[:, :, 1]  # 提取(fpr,tpr)数据点,形状为(len,(fpr,tpr))
                        cats = pd.qcut(plot_arr[:, 0], q=10, labels=False, duplicates='drop')  # 按fpr大小分成10个数据一样多的区间
                        groups = pd.DataFrame(plot_arr, columns=['fpr', 'tpr']).groupby(cats).mean()  # 计算每个区间的平均值,形状为(10,(fpr,tpr))
                        plot_df = pd.concat([zeros_df, groups, ones_df])  # 头添加(0,0),尾添加(1,1)数据点,形状为(12,(fpr,tpr))
                        x = plot_df['fpr']
                        Y = plot_df['tpr']
                    else:
                        plot_df_list = []
                        plot_arr = test_roc_arr[:, :, :, 1].swapaxes(1, 2)  # 提取(fpr,tpr)数据点并换轴,形状为(len,num_classes,(fpr,tpr))
                        for i in range(num_classes):
                            cats = pd.qcut(plot_arr[:, i, 0], q=10, labels=False, duplicates='drop')
                            groups = pd.DataFrame(plot_arr[:, i, :], columns=['fpr', 'tpr']).groupby(cats).mean()  # 形状为(10,(fpr,tpr))
                            plot_df = pd.concat([zeros_df, groups, ones_df])  # 形状为(12,(fpr,tpr))
                            add_num = 12 - len(plot_df)
                            # 长度不足12时,插值填充(0,0)数据点
                            if add_num > 0:
                                plot_df = pd.concat([zeros_df] * add_num + [plot_df])
                            plot_df_list.append(plot_df)  # 形状为(num_classes,12,(fpr,tpr))
                        plot_arr_sum = np.stack(plot_df_list, axis=1)  # 形状为(12,num_classes,(fpr,tpr))
                        x = plot_arr_sum[:, :, 0]
                        Y = plot_arr_sum[:, :, 1]
                    vis.line(
                        X=x,
                        Y=Y,
                        win='ROC',
                        opts=dict(title='ROC', legend=classes_str),
                    )
    return model

注意:代码运行前要先在命令行输入python -m visdom.server,在浏览器中打开提供的链接:

成功运行的效果如下:

相关推荐
说私域1 分钟前
淘宝直播与开源链动2+1模式AI智能名片S2B2C商城小程序的融合发展研究
人工智能·小程序·开源
陈敬雷-充电了么-CEO兼CTO2 分钟前
复杂任务攻坚:多模态大模型推理技术从 CoT 数据到 RL 优化的突破之路
人工智能·python·神经网络·自然语言处理·chatgpt·aigc·智能体
阿维同学16 分钟前
自动驾驶关键算法深度研究
人工智能·算法·自动驾驶
盼小辉丶20 分钟前
TensorFlow深度学习实战——基于自编码器构建句子向量
人工智能·深度学习·tensorflow
YOLO大师28 分钟前
华为OD机试 2025B卷 - 小明减肥(C++&Python&JAVA&JS&C语言)
c++·python·华为od·华为od机试·华为od2025b卷·华为机试2025b卷·华为od机试2025b卷
xiao5kou4chang6kai440 分钟前
【Python-GEE】如何利用Landsat时间序列影像通过调和回归方法提取农作物特征并进行分类
python·gee·森林监测·洪涝灾害·干旱评估·植被变化
kaikaile199544 分钟前
使用Python进行数据可视化的初学者指南
开发语言·python·信息可视化
Par@ish1 小时前
【网络安全】恶意 Python 包“psslib”仿冒 passlib,可导致 Windows 系统关闭
windows·python·web安全
倔强的石头1061 小时前
Bright Data MCP+Trae :快速构建电商导购助手垂直智能体
大数据·人工智能
意疏1 小时前
【Python篇】PyCharm 安装与基础配置指南
开发语言·python·pycharm