【深度学习实战—6】:基于Pytorch的血细胞图像分类(通用型图像分类程序)

✨博客主页:米开朗琪罗~🎈

✨博主爱好:羽毛球🏸

✨年轻人要:Living for the moment(活在当下)!💪

🏆推荐专栏:【图像处理】【千锤百炼Python】【深度学习】【排序算法

目录

图像分类是搞深度学习一定要掌握的一个视觉任务,本文章将基于血细胞数据集实现图像分类!

本文程序已解耦,可当做通用型图像分类框架使用。

数据集下载地址:Blood Cell Images

😺一、数据集介绍

从 kaggle 上下载到数据集后解压可以得到两个文件夹,分别是dataset-masterdataset2-master

其中dataset-master的 JPEGImages 中包含了血细胞的原始图像,而且没有对血细胞进行分类,在 Annotations 文件夹内包含了对应 JPEGImages 中的每张图像血细胞的.xml格式的定位标签,也就是说,该文件夹是用来做目标检测的。

而在dataset2-master中的 images 文件夹中,包含了TRAINTESTTEST_SIMPLE三种文件夹,且这三种文件夹下包含了血细胞的四种类别,分别是:EOSINOPHIL、LYMPHOCYTE、MONOCYTE、NEUTROPHIL。

但需要注意的是,在TRAINTEST文件夹下的图像,是已经经过数据增强之后的了,而TEST_SIMPLE文件夹下的图像并没有经过数据增强,因此我们将TRAINTESTTEST_SIMPLE三种文件夹分别用作训练集、验证集和测试集。即:

  • TRAIN------train(训练集)
  • TEST------val(验证集)
  • TEST_SIMPLE------test(测试集)

😺二、工程文件夹目录

我的工程文件夹目录如下,可以看到有很多的py文件,每个py文件具有不同的功能,这么写的好处是未来修改程序更加方便,而且每个py程序都没有很长。如果全部写到一个py程序里,则会显得很臃肿,修改起来也不轻松。

对每个文件的解释如下:

  • checkpoints:存放训练的模型权重;
  • datasets:存放数据集。并对数据集划分;
  • log_dir:存放训练日志。包括训练、验证时候的损失与精度情况;
  • option.py:存放整个工程下需要用到的所有参数;
  • utils.py:存放各种函数。包括文件夹创建、绘制精度与损失变化情况、结果预测等;
  • getdata.py:构建数据管道。其中定义了计算数据集中所有图形的均值和方差函数;
  • model.py:构建神经网络模型;
  • train.py:训练模型;
  • evaluate.py:评估训练模型。有三种预测方式可以选择,分别是:对单张图像进行预测,对多张图像进行预测,对整个目录下的图片进行预测;
  • pth2onnx:将pth模型转换到onnx模型;
  • onnx_inference.py:使用.onnx模型对数据进行推理。

😺三、option.py

为了方便了解这些参数代表什么意思,在help中,全部使用了中文解释。

python 复制代码
import argparse


def get_args():
    parser = argparse.ArgumentParser(description='all argument')
    parser.add_argument('--device', type=str, default='cuda', help='可以选择cuda或者cpu训练,苹果电脑m1芯片也可以选择mps加速训练')
    parser.add_argument('--loadsize', type=int, default=224, help='统一图像尺寸')
    parser.add_argument('--epochs', type=int, default=3, help='总的训练次数')
    parser.add_argument('--batch_size', type=int, default=16, help='每次喂多少数据给到网络')
    parser.add_argument('--lr', type=float, default=1e-2, help='初始学习率')
    parser.add_argument('--dataset_train', type=str, default='./datasets/train', help='训练集路径')
    parser.add_argument('--dataset_val', type=str, default="./datasets/val", help='验证集路径')
    parser.add_argument('--dataset_test', type=str, default="./datasets/test", help='测试集路径')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='模型存放路径')
    parser.add_argument('--log_dir', type=str, default='./log_dir', help='训练日志保存的路径')
    parser.add_argument('--logging_txt', type=str, default='./log_dir/logging.txt', help='训练日志位置')
    parser.add_argument('--pretrained', type=bool, default=False, help='是否要继续上次的训练')
    parser.add_argument('--which_epoch', type=str, default='best.pth', help='如果继续训练,需要加载哪一个模型')
    parser.add_argument('--test_model_path', type=str, default='./checkpoints/best.pth', help='选择一个模型用于测试')
    parser.add_argument('--onnx_path', type=str, default='./checkpoints/best.onnx', help='.onnx模型的存放路径')
    parser.add_argument('--test_img_path', type=str, default='./datasets/test/EOSINOPHIL/_0_5239.jpeg', help='选择一张测试图像')
    parser.add_argument('--test_dir_path', type=str, default='./datasets/test', help='选择一个测试路径')
    return parser.parse_args()

😺四、getdata.py

getdata.py中各函数的解释:

  • data_augmentation:该函数用作数据增强,最常使用的是transforms.Resize()transforms.ToTensor()transforms.Normalize()。由于数据集中已经对原始图像进行了数据增强,因此部分参数在下面注释掉了。
    • transforms.Resize():将图像统一尺寸。
    • transforms.ToTensor():维度变换。从 HWC 到 CWH 。
    • transforms.Normalize():图像归一化。归一化的参数需要从get_mean_and_std函数计算得到。
  • MyData:构建数据管道。返回一个字典。
  • imshow:图像可视化。可在构建数据管道后,可视化部分数据。
  • get_mean_and_std:计算图像均值和方差。计算结果放到transforms.Normalize()中。
python 复制代码
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
from option import get_args
opt = get_args()


def data_augmentation():

    data_transform = {
        'train': transforms.Compose([
            # transforms.RandomRotation(45),  # 随机旋转,角度在-45到45度之间
            # transforms.RandomHorizontalFlip(p=0.5),  # 以0.5的概率水平翻转
            # transforms.RandomVerticalFlip(p=0.5),  # 以0.5的概率垂直翻转
            # transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 参数依次为亮度、对比度、饱和度、色相
            # transforms.RandomGrayscale(p=0.025),  # 以0.025的概率变为灰度图像,3通道即R=G=B
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),  # HWC -> CHW
            transforms.Normalize([0.6786, 0.6413, 0.6605], [0.2599, 0.2595, 0.2569])  # 使用均值和标准差标准化三个通道的数据
        ]),
        'val': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.6786, 0.6413, 0.6605], [0.2599, 0.2595, 0.2569])
        ]),
        'test': transforms.Compose([
            transforms.Resize((opt.loadsize, opt.loadsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.6786, 0.6413, 0.6605], [0.2599, 0.2595, 0.2569])
        ])
    }
    return data_transform


def MyData():

    data_transform = data_augmentation()

    # 读取数据集
    image_datasets = {
        'train': ImageFolder(opt.dataset_train, data_transform['train']),
        'val': ImageFolder(opt.dataset_test, data_transform['val']),
        'test': ImageFolder(opt.dataset_test, data_transform['test'])
    }
    # 构建管道
    dataloaders = {
        'train': DataLoader(image_datasets['train'], batch_size=opt.batch_size, shuffle=True),
        'val': DataLoader(image_datasets['val'], batch_size=opt.batch_size, shuffle=True),
        'test': DataLoader(image_datasets['test'], batch_size=opt.batch_size, shuffle=True)
    }
    return dataloaders


"""
图像可视化
"""
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.6786, 0.6413, 0.6605])
    std = np.array([0.2599, 0.2595, 0.2569])
    inp = inp * std + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.show()


# 计算数据集所有图像的均值和方差
def get_mean_and_std(dataset):
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


if __name__ == '__main__':
    mena_std_transform = transforms.Compose([transforms.ToTensor()])
    dataset = ImageFolder(opt.dataset_train, transform=mena_std_transform)
    print(dataset.class_to_idx)		# 每个类别的索引
    mean, std = get_mean_and_std(dataset)
    print(mean)
    print(std)
    dataloader = MyData()
    inputs, classes = next(iter(dataloader['train']))
    out = make_grid(inputs, nrow=4)     # nrow参数可以选择显示的列数
    class_names = ['EOSINOPHIL', 'LYMPHOCYTE', 'MONOCYTE', 'NEUTROPHIL']
    imshow(out, title=[class_names[x] for x in classes])

运行main函数可以得到:

python 复制代码
类别索引:  {'EOSINOPHIL': 0, 'LYMPHOCYTE': 1, 'MONOCYTE': 2, 'NEUTROPHIL': 3}
==> Computing mean and std..
tensor([0.6786, 0.6413, 0.6605])
tensor([0.2599, 0.2595, 0.2569])

将 opt.batchsize 设为8后,可以得到下图:

😺五、utils.py

utils.py中各函数的解释:

  • make_dir:创建文件夹。
  • draw_number:绘制损失与精度的变化情况。
  • visual_image_single:单张图像可视化预测。
  • visual_image_multi:多张图像可视化预测。
  • get_confusion_matrix:输出混淆矩阵。用于对整个文件夹进行预测的情况。
  • plot_confusion_matrix:混淆矩阵可视化。
  • get_roc_auc:绘制ROC曲线。
  • visual_img_dir:对整个文件夹进行预测。并得到分类报告、准确率、精确率、召回率、F1得分
python 复制代码
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
from PIL import Image
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
from scipy import interp
from itertools import cycle
from option import get_args
opt = get_args()


"""
创建文件夹
"""
def make_dir():
    if os.path.exists(opt.log_dir) == True:
        pass
    else:
        os.mkdir(opt.log_dir)
    if os.path.exists(opt.checkpoints) == True:
        pass
    else:
        os.mkdir(opt.checkpoints)



"""
绘制损失与精度的变化情况
"""
def draw_number(epochs, train_loss_plt, train_acc_plt, val_loss_plt, val_acc_plt):

    color = ['red', 'blue', 'green', 'orange']
    marker = ['o', '*', 'p', '+']
    linestyle = ['-', '--', '-.', ':']

    plt.plot(epochs, train_loss_plt, color=color[0], marker=marker[0], linestyle=linestyle[0], label="trainingsets-loss")
    plt.plot(epochs, train_acc_plt, color=color[1], marker=marker[1], linestyle=linestyle[1], label="trainingsets-acc")
    plt.plot(epochs, val_loss_plt, color=color[2], marker=marker[2], linestyle=linestyle[2], label="validationsets-loss")
    plt.plot(epochs, val_acc_plt, color=color[3], marker=marker[3], linestyle=linestyle[3], label="validationsets-acc")

    plt.legend()
    plt.xlabel("epochs")
    plt.ylabel("value")
    plt.title("Loss and accuracy changes in training and validation sets")
    plt.savefig("Loss_Accuracy.jpg")
    plt.show()


"""
单张图像可视化预测
"""
def visual_image_single(img_path, transform_test, model, class_names):
    image = Image.open(img_path).convert('RGB')
    img = transform_test(image)
    img = img.unsqueeze_(0)
    out = model(img)
    pred_softmax = F.softmax(out, dim=1)        # 对 logit 分数做 softmax 运算
    top_n = torch.topk(pred_softmax, len(class_names))
    confs = top_n[0].cpu().detach().numpy().squeeze().tolist()      # 所有类别的预测概率
    confs_max = max(confs)      # 最大概率值
    confs_max_position = confs.index(confs_max)     # 最大概率值所在的位置
    print('Pre:{}   Conf:{:.3f}'.format(class_names[confs_max_position], confs_max))
    plt.axis('off')
    plt.title('Pre:{}   Conf:{:.3f}'.format(class_names[confs_max_position], confs_max))
    plt.imshow(image)
    plt.show()


"""
多张图像可视化预测
"""
def visual_image_multi(dataloader, model, class_names):
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            for i in range(len(images)):
                plt.subplot(4, 4, i + 1)
                plt.title("Prediction:{}\nTarget:{}".format(class_names[predicted[i]], class_names[labels[i]]), fontsize=8)
                img = images[i].swapaxes(0, 1)
                img = img.swapaxes(1, 2)
                plt.imshow(img)
                plt.axis('off')
            plt.show()


"""
对整个文件夹进行预测, 并输出混淆矩阵
"""
def get_confusion_matrix(trues, preds, labels):
    conf_matrix = confusion_matrix(trues, preds, labels=[i for i in range(len(labels))])
    return conf_matrix

def plot_confusion_matrix(conf_matrix, labels):
    plt.imshow(conf_matrix, cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    indices = range(conf_matrix.shape[0])
    plt.xticks(indices, labels)
    plt.yticks(indices, labels)
    plt.colorbar()
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    # 显示数据
    for first_index in range(conf_matrix.shape[0]):
        for second_index in range(conf_matrix.shape[1]):
          plt.text(first_index, second_index, conf_matrix[first_index, second_index])
    plt.savefig('heatmap_confusion_matrix.jpg')
    plt.show()


def get_roc_auc(trues, preds, labels):
    nb_classes = len(labels)
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(nb_classes):
        fpr[i], tpr[i], _ = roc_curve(trues[:, i], preds[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    fpr["micro"], tpr["micro"], _ = roc_curve(trues.ravel(), preds.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(nb_classes)])) 

    mean_tpr = np.zeros_like(all_fpr)
    for i in range(nb_classes):
        mean_tpr += interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= nb_classes
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
    lw = 2
    plt.figure()
    plt.plot(fpr["micro"], tpr["micro"],label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]),color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr["macro"], tpr["macro"],label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),color='navy', linestyle=':', linewidth=4)
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green'])
    for i, color in zip(range(nb_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=lw, label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Some extension of Receiver operating characteristic to multi-class')
    plt.legend(loc="lower right")
    plt.savefig("ROC_多分类.jpg")
    plt.show()

def visual_img_dir(dataloader, model, class_names):
    """
    normalize: True:显示百分比, False: 显示个数
    """
    y_pred = []
    y_true = []
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            y_pred.extend(predicted.cpu().numpy())
            y_true.extend(labels.cpu().numpy())

        accuracy = accuracy_score(y_true, y_pred)  # 准确率 值所有判断正确的数据(TP+TN)占总量的比例。
        precision = precision_score(y_true, y_pred, average='macro')  # 精确率 所有被判定为正类(TP+FP)中,真实的正类(TP)占的比例。
        recall = recall_score(y_true, y_pred, average='macro')  # 召回率 所有真实为正类(TP+FN)中,被判定为正类(TP)占的比例。
        f1 = f1_score(y_true, y_pred, average='macro')  # f1-score 它赋予Precision score和Recall Score相同的权重,以衡量其准确性方面的性能,使其成为准确性指标的替代方案(它不需要我们知道样本总数)。
        conf_matrix = get_confusion_matrix(y_true, y_pred, labels=class_names)
        print('分类报告:\n', classification_report(y_true, y_pred))  # 分类报告
        print("[accuracy:{:.4f}]  [precision:{:.4f}]  [recall:{:.4f}]  [f1:{:.4f}]".format(accuracy, precision, recall, f1))
        plot_confusion_matrix(conf_matrix, labels=class_names)

        test_trues = label_binarize(y_true, classes=[i for i in range(len(class_names))])
        test_preds = label_binarize(y_pred, classes=[i for i in range(len(class_names))])
        get_roc_auc(test_trues, test_preds, class_names)

😺六、model.py

我们可以自定义一个分类网络,也可以使用现有的经典分类网络,如resnet50,在使用resnet50时,可以选择冻结部分网络层,即冻结的网络层不可再被训练,仅使用其网络结构,网络参数是早已学习好的;也可以选择冻结所有层;也可以选择不冻结任何层。在迁移学习的时候,需要注意最后的分类层。血细胞分类共有4类,而resnet50最后的全连接层有1000个神经元输出,所以需要修改最后一层全连接层,将其输出改为4。

python 复制代码
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchsummary import summary
from option import get_args
opt = get_args()

class My_CNN(nn.Module):
    def __init__(self):
        super(My_CNN, self).__init__()
        self.conv1_1 = nn.Sequential(nn.Conv2d(3, 16, (3, 3), 1, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(16))
        self.conv1_2 = nn.Sequential(nn.Conv2d(16, 32, (3, 3), 2, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(32))
        self.conv2_1 = nn.Sequential(nn.Conv2d(32, 32, (3, 3), 1, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(32))
        self.conv2_2 = nn.Sequential(nn.Conv2d(32, 64, (3, 3), 2, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(64))
        self.conv3_1 = nn.Sequential(nn.Conv2d(64, 64, (3, 3), 1, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(64))
        self.conv3_2 = nn.Sequential(nn.Conv2d(64, 128, (3, 3), 2, 1),
                                     nn.ReLU(),
                                     nn.BatchNorm2d(128))

        self.linear_1 = nn.Linear(28 * 28 * 128, 80)
        self.linear_2 = nn.Linear(80, 4)

    def forward(self, x):
        in_size = x.size(0)
        x = self.conv1_1(x)
        x = self.conv1_2(x)
        x = self.conv2_1(x)
        x = self.conv2_2(x)
        x = self.conv3_1(x)
        x = self.conv3_2(x)
        x = x.view(in_size, -1)
        x = self.linear_1(x)
        out = self.linear_2(x)
        return out


"""
使用预训练模型 1 ------------微调模型
使用预训练的模型来初始化网络,而非随机初始化网络,并且权重可以随着训练的进行而发生改变,步骤如下:
--(1)替换输出层。将模型的最后一个全连接层替换为新的全连接层;
--(2)训练输出层。新的输出层会将前面的层所提取出的低级特征映射到我们所期望的类别的概率;
--(3)训练输出层之前的层。也就是将这些层的权重标记为需要求导。

固定模型的参数 2 ------------微调模型
固定预训练模型的参数,将模型除了输出层之外的所有层看作一个特征提取器。在训练模型的时候,这些层的权重不参与训练,不可优化。
"""
def ResNet():
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    """
    可选择仅冻结某一层或者全部冻结
    """
    # for name, layer in model.named_children():  # 仅冻结layer1层
    #     if name == "layer1":
    #         for param in layer.parameters():
    #             param.requires_grad = False
    #
    # for param in model.parameters():    # 冻结所有层,锁定模型所有参数,所有层设置为不可训练的模式。
    #     param.requires_grad = False

    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 4)
    return model

if __name__ == '__main__':
    model = ResNet()
    print(summary(model.to(opt.device), (3, opt.loadsize, opt.loadsize), opt.batch_size))

😺七、train.py

train.py解释如下:

  • make_dir():从 utils 中调用函数,目的是如果当前工程目录下不存在相应的文件夹(log_dircheckpoints),则主动创建,如果已经存在,则不做处理。
  • file = open(opt.logging_txt, 'w'):创建.txt文件,后续将写入训练过程的相关信息,包括损失与精度的变化情况。
  • writer = SummaryWriter():SummaryWriter 类将条目直接写入指定文件夹中的事件文件,以供 TensorBoard 使用。在程序运行时,会在工程目录下自动新建一个 run 文件夹,用于存储训练过程。在 run 文件夹下使用终端,输入tensorboard --logdir=run可以在网页中查看网络训练过程。
  • train_best:定义训练过程的函数。
python 复制代码
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.nn as nn
from model import My_CNN, ResNet
from getdata import MyData
from utils import draw_number, EarlyStopping, make_dir
from option import get_args
opt = get_args()

make_dir()
file = open(opt.logging_txt, 'w')
writer = SummaryWriter()

def train_best(model, num_epoch, dataloaders, optimizer, loss_function):

    model.to(opt.device)
    train_loss_plt, train_acc_plt, val_loss_plt, val_acc_plt = [], [], [], []  # 将训练和验证过程的损失和精度保留下来,用于绘制折线图

    for epoch in range(start_epoch, opt.epochs):
        print("---------开始第{}/{}轮训练---------".format(epoch, opt.epochs))
        for phase in ['train', 'val']:

            loss_sum, acc_sum = 0, 0
            step = 0            # 将数据全部取完, 记录每一个batch
            all_step = 0        # 记录取了多少个数据

            for (inputs, labels) in tqdm(dataloaders[phase], position=0):
                if phase == 'train':
                    model.train()
                if phase == 'val':
                    model.eval()
                inputs = inputs.to(opt.device)
                labels = labels.to(opt.device)
                optimizer.zero_grad()  # 梯度清零,防止累加

                a = inputs.size(0)  # 每一批次拿了多少张图像
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                _, pred = torch.max(outputs, 1)  # 返回每一行的最大值和其索引
                loss.backward()
                optimizer.step()

                loss_sum += loss.item() * inputs.size(0)  # 损失
                acc_sum += torch.sum(pred == labels.data)

                step += 1
                all_step += a

                print("[Epoch: {}/{}]  [step = {}]  [{}_loss = {:.3f}, {}_acc = {:.3f}]".
                      format(epoch, opt.epochs, all_step, phase, loss_sum / all_step, phase, acc_sum.double() / all_step))

            # 保留每一个epoch后的训练损失与精度
            if phase == 'train':
                train_loss = loss_sum / len(dataloaders[phase].dataset)
                train_acc = acc_sum.double() / len(dataloaders[phase].dataset)
                train_acc = np.float32(train_acc.cpu().numpy())
                train_loss_plt.append(train_loss)
                train_acc_plt.append(train_acc)

            else:
                val_loss = loss_sum / len(dataloaders[phase].dataset)
                val_acc = acc_sum.double() / len(dataloaders[phase].dataset)
                val_acc = np.float32(val_acc.cpu().numpy())
                val_loss_plt.append(val_loss)
                val_acc_plt.append(val_acc)

                writer.add_scalars('loss', {'train': train_loss, 'val': val_loss}, global_step=epoch + 1 - start_epoch)
                writer.add_scalars('acc', {'train': train_acc, 'val': val_acc}, global_step=epoch + 1 - start_epoch)
                writer.close()

        print("EPOCH = {}/{}  train_loss = {:.3f}, train_acc = {:.3f}, val_loss = {:.3f}, val_acc = {:.3f} \n".
              format(epoch, num_epoch, train_loss, train_acc, val_loss, val_acc))
        file.write("EPOCH = {}/{}  train_loss = {:.3f}, train_acc = {:.3f}, val_loss = {:.3f}, val_acc = {:.3f} \n".
              format(epoch, num_epoch, train_loss, train_acc, val_loss, val_acc))

        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        if epoch % 2 == 0:
            torch.save(state, opt.checkpoints + 'model_{}.pth'.format(epoch))

    draw_number(np.arange(0, opt.epoch-start_epoch, 1), train_loss_plt, train_acc_plt, val_loss_plt, val_acc_plt)


if __name__ == '__main__':
    model = ResNet()
    # model = nn.DataParallel(model)      # 多卡并行训练解开这句注释
    model.to(opt.device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    if opt.pretrained:
        checkpoint = torch.load(opt.checkpoints + opt.which_epoch)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

    dataloaders = MyData()

    train_best(model, opt.epochs, dataloaders, optimizer, loss_function)

😺八、evaluate.py

evaluate.py需要注意:

  • class_names:必须要和数据管道的标签对应。也就是getdata.py运行得到的类别索引。

    类别索引: {'EOSINOPHIL': 0, 'LYMPHOCYTE': 1, 'MONOCYTE': 2, 'NEUTROPHIL': 3}

  • main 函数内的visual_image_single:将每次弹出一张预测结果

  • main 函数内的visual_image_multi:将每次弹出opt.batch_size张预测结果,可以通过修改opt.batch_size改变预测数量,同时可以跳转到utils.py里的visual_image_multi函数中,通过修改plt.subplot()中的参数,可以控制预测结果的排列分布,例如 4 行 4 列 或者 2 行 8 列 等。

  • main 函数内的visual_img_dir:将得到ROC曲线图,混淆矩阵图、各种评估指标等。

python 复制代码
from model import My_CNN, ResNet
from getdata import MyData, data_augmentation
import torch.utils.data
from option import get_args
from utils import visual_image_single, visual_image_multi, visual_img_dir


opt = get_args()

model = ResNet()
ckpt = torch.load(opt.test_model_path, map_location='cpu')
model.load_state_dict(ckpt, strict=False)
model.eval()

data_transform = data_augmentation()        # 测试单张图像使用
transform_test = data_transform['test']

dataloaders = MyData()                      # 测试多张图像和文件夹使用
dataloader = dataloaders['test']

class_names = ['EOSINOPHIL', 'LYMPHOCYTE', 'MONOCYTE', 'NEUTROPHIL']

if __name__ == '__main__':

    # visual_image_single(opt.test_img_path, transform_test, model, class_names)
    # visual_image_multi(dataloader, model, class_names)
    visual_img_dir(dataloader, model, class_names=class_names)

程序运行结果如下所示:
visual_image_single

visual_image_multi

visual_img_dir

python 复制代码
分类报告:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.00      0.00      0.00         6
           2       0.00      0.00      0.00         4
           3       0.68      1.00      0.81        48

    accuracy                           0.68        71
   macro avg       0.17      0.25      0.20        71
weighted avg       0.46      0.68      0.55        71

[accuracy:0.6761]  [precision:0.1690]  [recall:0.2500]  [f1:0.2017]

😺九、pth2onnx.py

evaluate.py需要注意:

模型转换时,需要指定模型的输入大小,即input变量。

python 复制代码
import torch
from torch.autograd import Variable
import onnx
from model import My_CNN, ResNet
from option import get_args
opt = get_args()

model = ResNet()
ckpt = torch.load(opt.test_model_path, map_location='cpu')
model.load_state_dict(ckpt, strict=False)
model.eval()
input_name = ['input']
output_name = ['output']
input = Variable(torch.randn(1, 3, opt.loadsize, opt.loadsize))

torch.onnx.export(model, input, opt.onnx_path, input_names=input_name, output_names=output_name, verbose=True)

# check .onnx model
onnx_model = onnx.load(opt.onnx_path)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))

程序运行后就可以在checkpoints文件夹下发现.onnx文件。

😺十、onnx_inference.py

使用onnx模型进行推理。

注意在推理前,要把opt.batch_size改为 1。

python 复制代码
import numpy as np
import onnxruntime
import time
from getdata import MyData
from option import get_args
opt = get_args()

def infer_test(model_path, data_loader, device):
    if device == 'cpu':
        print("using CPUExecutionProvider")
        session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    else:
        print("using CUDAExecutionProvider")
        session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider'])

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    total = 0.0
    correct = 0
    start_time = time.time()
    for batch, data in enumerate(data_loader):
        X, y = data
        X = X.numpy()
        y = y.numpy()

        output = session.run([output_name], {input_name: X})[0]
        y_pred = np.argmax(output, axis=1)

        if y[0] == y_pred[0]:
            correct += 1
        total += 1
    end_time = time.time()
    print(end_time - start_time)
    print("accuracy is {}%".format(correct / total * 100.0))


def main():
    input_model_path = opt.onnx_path
    device = input("cpu or gpu?")
    dataloaders = MyData()
    infer_test(input_model_path, dataloaders['test'], device)


if __name__ == "__main__":
    main()

推理结果如下所示:

python 复制代码
cpu or gpu?cpu
using CPUExecutionProvider
1.8580236434936523
accuracy is 67.6056338028169%
相关推荐
岑梓铭22 分钟前
(CentOs系统虚拟机)Standalone模式下安装部署“基于Python编写”的Spark框架
linux·python·spark·centos
边缘计算社区25 分钟前
首个!艾灵参编的工业边缘计算国家标准正式发布
大数据·人工智能·边缘计算
游客52036 分钟前
opencv中的各种滤波器简介
图像处理·人工智能·python·opencv·计算机视觉
一位小说男主36 分钟前
编码器与解码器:从‘乱码’到‘通话’
人工智能·深度学习
Eric.Lee202139 分钟前
moviepy将图片序列制作成视频并加载字幕 - python 实现
开发语言·python·音视频·moviepy·字幕视频合成·图像制作为视频
KeyPan40 分钟前
【IMU:视觉惯性SLAM系统】
计算机视觉
Dontla44 分钟前
vscode怎么设置anaconda python解释器(anaconda解释器、vscode解释器)
ide·vscode·python
深圳南柯电子1 小时前
深圳南柯电子|电子设备EMC测试整改:常见问题与解决方案
人工智能
Kai HVZ1 小时前
《OpenCV计算机视觉》--介绍及基础操作
人工智能·opencv·计算机视觉