- 🍨 This post is a study log from the 🔗365-Day Deep Learning Training Camp
- 🍖 Original author: K同学啊 | tutoring and custom projects available
- 🚀 Source: K同学的学习圈子
Table of Contents
- Preface
- 1 My environment
- 2 Implementing DenseNet in PyTorch
  - 2.1 Preparation
    - 2.1.1 Importing libraries
    - 2.1.2 Setting up the GPU (use the GPU if available, otherwise the CPU)
    - 2.1.3 Importing the data
    - 2.1.4 Visualizing the data
    - 2.1.5 Image transforms
    - 2.1.6 Splitting the dataset
    - 2.1.7 Loading the data
    - 2.1.8 Inspecting the data
  - 2.2 Building the DenseNet_SE model
  - 2.3 Training the model
    - 2.3.1 Setting hyperparameters
    - 2.3.2 Writing the training function
    - 2.3.3 Writing the test function
    - 2.3.4 Training
  - 2.4 Visualizing the results
  - 2.5 Predicting a specified image
  - 2.6 Model evaluation
- 3 Implementing DenseNet in TensorFlow
- 4 Key concepts explained
  - 4.1 SE-Net explained
- 5 Summary
Preface
Keywords: DenseNet_SE in PyTorch, DenseNet_SE in TensorFlow, SE-Net explained
1 My environment
- OS: Windows 11
- Language: Python 3.8.6
- IDE: PyCharm 2020.2.3
- Deep learning environment:
  - torch == 1.9.1+cu111
  - torchvision == 0.10.1+cu111
  - TensorFlow 2.10.1
- GPU: NVIDIA GeForce RTX 4070
2 Implementing DenseNet in PyTorch
2.1 Preparation
2.1.1 Importing libraries
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import time
import copy
import re
from collections import OrderedDict
from pathlib import Path
from PIL import Image
from torchvision import transforms, datasets
import torchsummary as summary
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese labels
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
plt.rcParams['figure.dpi'] = 100  # figure resolution

import warnings
warnings.filterwarnings('ignore')  # suppress warnings that don't need to be printed
```
2.1.2 Setting up the GPU (use the GPU if available, otherwise the CPU)
```python
"""Preparation: set up the GPU"""
# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))
```
Output
Using cuda device
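2.1.3 Importing the data
```python
'''Preparation: import the data'''
data_dir = r"D:\DeepLearning\data\monkeypox_recognition"
data_dir = Path(data_dir)
# Each subdirectory of data_dir holds the images of one class
data_paths = list(data_dir.glob('*'))
# path.name is the last path component on any OS (splitting on "\\" only works on Windows)
classeNames = [path.name for path in data_paths]
print(classeNames)
```
Output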
['Monkeypox', 'Others']
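Splitting a path string on `"\\"` only recovers the folder name when the path uses Windows separators. The sketch below uses only `pathlib`'s pure path classes (no disk access; both paths are made-up examples) to compare `.name` with the backslash split on each path flavor.

```python
from pathlib import PurePosixPath, PureWindowsPath

# Hypothetical class-folder paths, one per path flavor
win_path = PureWindowsPath(r"D:\DeepLearning\data\monkeypox_recognition\Monkeypox")
posix_path = PurePosixPath("/home/user/data/monkeypox_recognition/Monkeypox")

# .name is the final path component regardless of the separator
print(win_path.name)    # Monkeypox
print(posix_path.name)  # Monkeypox

# Splitting the string on a backslash only works for Windows-style paths
print(str(win_path).split("\\")[-1])    # Monkeypox
print(str(posix_path).split("\\")[-1])  # /home/user/data/monkeypox_recognition/Monkeypox
```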
2.1.4 Visualizing the data
```python
'''Preparation: visualize the data'''
subfolder = Path(data_dir) / "Monkeypox"
image_files = list(p.resolve() for p in subfolder.glob('*') if p.suffix in [".jpg", ".png", ".jpeg"])
plt.figure(figsize=(10, 6))
# Show the first 12 images in a 3x4 grid
for i, image_file in enumerate(image_files[:12]):
    ax = plt.subplot(3, 4, i + 1)
    img = Image.open(str(image_file))
    plt.imshow(img)
    plt.axis("off")
# Display the figure
plt.tight_layout()
plt.show()
```
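The filter `p.suffix in [".jpg", ".png", ".jpeg"]` is case-sensitive, so a file named like `IMG_01.JPG` would be skipped silently. A small sketch of a case-insensitive filter; `is_image` is a hypothetical helper and the file names are made up (pure paths, no disk access):

```python
from pathlib import PurePath

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}

def is_image(path: PurePath) -> bool:
    """Return True if the file extension is an image suffix, ignoring case."""
    return path.suffix.lower() in IMAGE_SUFFIXES

# Hypothetical file names
files = [PurePath("M01_01.jpg"), PurePath("M01_02.JPG"), PurePath("notes.txt"), PurePath("x.Png")]
case_sensitive = [p.name for p in files if p.suffix in [".jpg", ".png", ".jpeg"]]
case_insensitive = [p.name for p in files if is_image(p)]
print(case_sensitive)    # ['M01_01.jpg']
print(case_insensitive)  # ['M01_01.jpg', 'M01_02.JPG', 'x.Png']
```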
2.1.5 Image transforms
```python
'''Preparation: image transforms'''
total_datadir = data_dir
# For more on transforms.Compose, see: https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every input image to the same size
    transforms.ToTensor(),  # convert a PIL Image or numpy.ndarray to a tensor scaled to [0, 1]
    transforms.Normalize(  # standardize each channel as (x - mean) / std, which helps the model converge
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # the standard per-channel mean/std of the ImageNet training set
])
total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)
```
Output
Dataset ImageFolder
Number of datapoints: 2142
Root location: D:\DeepLearning\data\monkeypox_recognition
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
{'Monkeypox': 0, 'Others': 1}
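What `ToTensor()` and `Normalize()` do to a single pixel can be worked out by hand: `ToTensor()` divides 8-bit values by 255, then `Normalize()` computes `(x - mean) / std` per channel. A stdlib-only sketch; `normalize_pixel` is a hypothetical helper, not a torchvision function:

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet per-channel means (R, G, B)
STD = (0.229, 0.224, 0.225)   # ImageNet per-channel standard deviations

def normalize_pixel(rgb):
    """Mimic ToTensor() + Normalize() for one 8-bit RGB pixel (hypothetical helper)."""
    scaled = [v / 255.0 for v in rgb]                            # ToTensor: [0, 255] -> [0, 1]
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]   # Normalize: (x - mean) / std

black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])  # [-2.118, -2.036, -1.804]
print([round(v, 3) for v in white])  # [2.249, 2.429, 2.64]
```

So after normalization the inputs span roughly [-2.1, 2.6] instead of [0, 1], centered near zero for typical natural images.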
2.1.6 Splitting the dataset
```python
'''Preparation: split the dataset'''
train_size = int(0.8 * len(total_data))  # training set size: 80% of the data, truncated to an integer
test_size = len(total_data) - train_size  # test set size: whatever remains
# torch.utils.data.random_split() randomly splits total_data into subsets with the given
# lengths [train_size, test_size]; the results are assigned to train_dataset and test_dataset.
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_dataset={}\ntest_dataset={}".format(train_dataset, test_dataset))
print("train_size={}\ntest_size={}".format(train_size, test_size))
```
Output
train_dataset=<torch.utils.data.dataset.Subset object at 0x000002A96E08E0D0>
test_dataset=<torch.utils.data.dataset.Subset object at 0x000002A96E04E640>
train_size=1713
test_size=429
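The sizes follow from `int()` truncation: 80% of 2142 is 1713.6, truncated to 1713, leaving 429 for testing. In recent PyTorch versions `random_split` also accepts a `generator` argument (for example `torch.Generator().manual_seed(42)`) to make the split reproducible. The stdlib sketch below reproduces only the size arithmetic plus a seeded shuffle-and-cut; it uses `random.Random`, not PyTorch's RNG, so the actual indices would differ from PyTorch's.

```python
import random

def split_sizes(total, train_frac=0.8):
    """Return (train_size, test_size) the same way the code above computes them."""
    train_size = int(train_frac * total)  # truncates toward zero
    return train_size, total - train_size

def seeded_split(indices, sizes, seed=0):
    """Shuffle indices with a fixed seed and cut them into consecutive chunks (sketch only)."""
    rng = random.Random(seed)
    shuffled = list(indices)
    rng.shuffle(shuffled)
    out, start = [], 0
    for n in sizes:
        out.append(shuffled[start:start + n])
        start += n
    return out

train_size, test_size = split_sizes(2142)
print(train_size, test_size)  # 1713 429

train_idx, test_idx = seeded_split(range(2142), [train_size, test_size], seed=42)
print(len(train_idx), len(test_idx))  # 1713 429
```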
2.1.7 Loading the data
```python
'''Preparation: load the data'''
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)
```
2.1.8 Inspecting the data
```python
'''Preparation: inspect the data'''
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
```
Output
Shape of X [N, C, H, W]: torch.Size([32, 3, 224, 224])
Shape of y: torch.Size([32]) torch.int64
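Because 1713 and 429 are not multiples of 32, each loader ends with a smaller batch (DataLoader's default is `drop_last=False`). A quick stdlib check of the batch counts and last-batch sizes; `batch_layout` is a hypothetical helper:

```python
import math

def batch_layout(num_samples, batch_size):
    """Return (number of batches, size of the last batch) when drop_last=False."""
    num_batches = math.ceil(num_samples / batch_size)
    last = num_samples - (num_batches - 1) * batch_size
    return num_batches, last

print(batch_layout(1713, 32))  # (54, 17)
print(batch_layout(429, 32))   # (14, 13)
```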
2.2 Building the DenseNet_SE model
```python
"""Build the DenseNet_SE network"""
# DenseNet is implemented here with PyTorch.
# First, the basic unit inside a DenseBlock: BN + ReLU + 1x1 Conv + BN + ReLU + 3x3 Conv,
# followed by dropout during training (when drop_rate > 0).
class _DenseLayer(nn.Sequential):
    """Basic unit of DenseBlock (using bottleneck layer)"""
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        # Dense connectivity: concatenate the input with the new features along the channel axis
        return torch.cat([x, new_features], 1)


# The DenseBlock: layers are densely connected, so the number of input features grows linearly
class _DenseBlock(nn.Sequential):
    """DenseBlock"""
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


# The Transition layer: essentially a 1x1 convolution followed by 2x2 average pooling
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


# SE (Squeeze-and-Excitation) module
class Squeeze_excitation_layer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(Squeeze_excitation_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # squeeze: global average pooling to 1x1
        self.fc = nn.Sequential(  # excitation: bottleneck MLP giving one weight in (0, 1) per channel
            nn.Linear(channel, channel // reduction, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)  # scale: reweight each channel of x


# Finally, the DenseNet_SE network
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 3 or 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
            (i.e. bn_size * k features in the bottleneck layer)
        compression (float) - fraction of channels kept by each transition layer
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=24, bn_size=4, compression=0.5, drop_rate=0,
                 num_classes=1000):
        super(DenseNet, self).__init__()
        # First Conv2d
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        ]))
        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features += num_layers * growth_rate
            if i != len(block_config) - 1:
                transition = _Transition(num_input_features=num_features,
                                         num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), transition)
                num_features = int(num_features * compression)
        # SE layer after the last dense block
        self.features.add_module('SE-module', Squeeze_excitation_layer(num_features))
        # Final bn+relu
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.features.add_module('relu5', nn.ReLU(inplace=True))
        # classification layer
        self.classifier = nn.Linear(num_features, num_classes)
        # params initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        features = self.features(x)
        # The 7x7 average pool assumes 224x224 inputs (7x7 feature maps at this point)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out


model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth'}


def densenet121(pretrained=False, **kwargs):
    """DenseNet121"""
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
    if pretrained:
        # '.'s are no longer allowed in module names, but previous _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # The ImageNet checkpoint has no weights for the added SE module, so load non-strictly;
        # the SE layer keeps its random initialization.
        model.load_state_dict(state_dict, strict=False)
    return model


"""Build the densenet121 model"""
# model = densenet121().to(device)
model = densenet121(True).to(device)  # use the pretrained weights
print(model)
summary.summary(model, (3, 224, 224))  # print the parameter counts and related statistics
```
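The channel bookkeeping in `DenseNet.__init__` can be traced by hand: each dense block adds `num_layers * growth_rate` channels and each transition keeps `compression` of them. A stdlib sketch that replays that loop for the settings passed by `densenet121()`; `trace_channels` is a hypothetical helper:

```python
def trace_channels(num_init_features=64, growth_rate=32,
                   block_config=(6, 12, 24, 16), compression=0.5):
    """Replay the num_features updates from DenseNet.__init__ and record each stage."""
    num_features = num_init_features
    stages = []
    for i, num_layers in enumerate(block_config):
        num_features += num_layers * growth_rate            # after dense block i+1
        stages.append(("denseblock%d" % (i + 1), num_features))
        if i != len(block_config) - 1:
            num_features = int(num_features * compression)  # after transition i+1
            stages.append(("transition%d" % (i + 1), num_features))
    return stages

for name, channels in trace_channels():
    print(name, channels)
```

The final value, 1024, is the `channel` argument the SE layer receives and the input size of the classifier.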
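Compared with the plain DenseNet, the only change in this model is an SE_layer added after the last dense block (before the final BN + ReLU):
```python
# SE_layer
self.features.add_module('SE-module', Squeeze_excitation_layer(num_features))
```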
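To make the squeeze, excitation and scale steps of `Squeeze_excitation_layer.forward` concrete, here is a dependency-free sketch on nested lists (one sample, feature map indexed as `x[channel][row][col]`). The weight matrices are hypothetical inputs standing in for the two `nn.Linear` layers, with biases for their `bias=True` terms.

```python
import math

def linear(vec, weight, bias):
    """y = W @ vec + b, with weight given as a list of rows."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]

def se_forward(x, w1, b1, w2, b2):
    """Squeeze-and-Excitation on one feature map x[c][h][w]."""
    # Squeeze: global average pooling, one number per channel
    z = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in x]
    # Excitation: FC -> ReLU -> FC -> sigmoid gives a weight in (0, 1) per channel
    hidden = [max(0.0, h) for h in linear(z, w1, b1)]
    s = [1.0 / (1.0 + math.exp(-a)) for a in linear(hidden, w2, b2)]
    # Scale: multiply every value of channel c by s[c]
    return [[[v * s_c for v in row] for row in ch] for ch, s_c in zip(x, s)]

# 2 channels of 2x2, reduced to 1 hidden unit; all-zero weights give sigmoid(0) = 0.5
x = [[[1.0, 2.0], [3.0, 4.0]], [[-1.0, 0.0], [1.0, 2.0]]]
w1, b1 = [[0.0, 0.0]], [0.0]
w2, b2 = [[0.0], [0.0]], [0.0, 0.0]
print(se_forward(x, w1, b1, w2, b2))  # every value halved
```

In the real module the same computation runs batched on tensors; with `channel=1024` and `reduction=16` the hidden layer has 64 units.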
Output of `summary.summary(model, (3, 224, 224))`:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
BatchNorm2d-5 [-1, 64, 56, 56] 128
ReLU-6 [-1, 64, 56, 56] 0
Conv2d-7 [-1, 128, 56, 56] 8,192
BatchNorm2d-8 [-1, 128, 56, 56] 256
ReLU-9 [-1, 128, 56, 56] 0
Conv2d-10 [-1, 32, 56, 56] 36,864
BatchNorm2d-11 [-1, 96, 56, 56] 192
ReLU-12 [-1, 96, 56, 56] 0
Conv2d-13 [-1, 128, 56, 56] 12,288
BatchNorm2d-14 [-1, 128, 56, 56] 256
ReLU-15 [-1, 128, 56, 56] 0
Conv2d-16 [-1, 32, 56, 56] 36,864
BatchNorm2d-17 [-1, 128, 56, 56] 256
ReLU-18 [-1, 128, 56, 56] 0
Conv2d-19 [-1, 128, 56, 56] 16,384
BatchNorm2d-20 [-1, 128, 56, 56] 256
ReLU-21 [-1, 128, 56, 56] 0
Conv2d-22 [-1, 32, 56, 56] 36,864
BatchNorm2d-23 [-1, 160, 56, 56] 320
ReLU-24 [-1, 160, 56, 56] 0
Conv2d-25 [-1, 128, 56, 56] 20,480
BatchNorm2d-26 [-1, 128, 56, 56] 256
ReLU-27 [-1, 128, 56, 56] 0
Conv2d-28 [-1, 32, 56, 56] 36,864
BatchNorm2d-29 [-1, 192, 56, 56] 384
ReLU-30 [-1, 192, 56, 56] 0
Conv2d-31 [-1, 128, 56, 56] 24,576
BatchNorm2d-32 [-1, 128, 56, 56] 256
ReLU-33 [-1, 128, 56, 56] 0
Conv2d-34 [-1, 32, 56, 56] 36,864
BatchNorm2d-35 [-1, 224, 56, 56] 448
ReLU-36 [-1, 224, 56, 56] 0
Conv2d-37 [-1, 128, 56, 56] 28,672
BatchNorm2d-38 [-1, 128, 56, 56] 256
ReLU-39 [-1, 128, 56, 56] 0
Conv2d-40 [-1, 32, 56, 56] 36,864
BatchNorm2d-41 [-1, 256, 56, 56] 512
ReLU-42 [-1, 256, 56, 56] 0
Conv2d-43 [-1, 128, 56, 56] 32,768
AvgPool2d-44 [-1, 128, 28, 28] 0
BatchNorm2d-45 [-1, 128, 28, 28] 256
ReLU-46 [-1, 128, 28, 28] 0
Conv2d-47 [-1, 128, 28, 28] 16,384
BatchNorm2d-48 [-1, 128, 28, 28] 256
ReLU-49 [-1, 128, 28, 28] 0
Conv2d-50 [-1, 32, 28, 28] 36,864
BatchNorm2d-51 [-1, 160, 28, 28] 320
ReLU-52 [-1, 160, 28, 28] 0
Conv2d-53 [-1, 128, 28, 28] 20,480
BatchNorm2d-54 [-1, 128, 28, 28] 256
ReLU-55 [-1, 128, 28, 28] 0
Conv2d-56 [-1, 32, 28, 28] 36,864
BatchNorm2d-57 [-1, 192, 28, 28] 384
ReLU-58 [-1, 192, 28, 28] 0
Conv2d-59 [-1, 128, 28, 28] 24,576
BatchNorm2d-60 [-1, 128, 28, 28] 256
ReLU-61 [-1, 128, 28, 28] 0
Conv2d-62 [-1, 32, 28, 28] 36,864
BatchNorm2d-63 [-1, 224, 28, 28] 448
ReLU-64 [-1, 224, 28, 28] 0
Conv2d-65 [-1, 128, 28, 28] 28,672
BatchNorm2d-66 [-1, 128, 28, 28] 256
ReLU-67 [-1, 128, 28, 28] 0
Conv2d-68 [-1, 32, 28, 28] 36,864
BatchNorm2d-69 [-1, 256, 28, 28] 512
ReLU-70 [-1, 256, 28, 28] 0
Conv2d-71 [-1, 128, 28, 28] 32,768
BatchNorm2d-72 [-1, 128, 28, 28] 256
ReLU-73 [-1, 128, 28, 28] 0
Conv2d-74 [-1, 32, 28, 28] 36,864
BatchNorm2d-75 [-1, 288, 28, 28] 576
ReLU-76 [-1, 288, 28, 28] 0
Conv2d-77 [-1, 128, 28, 28] 36,864
BatchNorm2d-78 [-1, 128, 28, 28] 256
ReLU-79 [-1, 128, 28, 28] 0
Conv2d-80 [-1, 32, 28, 28] 36,864
BatchNorm2d-81 [-1, 320, 28, 28] 640
ReLU-82 [-1, 320, 28, 28] 0
Conv2d-83 [-1, 128, 28, 28] 40,960
BatchNorm2d-84 [-1, 128, 28, 28] 256
ReLU-85 [-1, 128, 28, 28] 0
Conv2d-86 [-1, 32, 28, 28] 36,864
BatchNorm2d-87 [-1, 352, 28, 28] 704
ReLU-88 [-1, 352, 28, 28] 0
Conv2d-89 [-1, 128, 28, 28] 45,056
BatchNorm2d-90 [-1, 128, 28, 28] 256
ReLU-91 [-1, 128, 28, 28] 0
Conv2d-92 [-1, 32, 28, 28] 36,864
BatchNorm2d-93 [-1, 384, 28, 28] 768
ReLU-94 [-1, 384, 28, 28] 0
Conv2d-95 [-1, 128, 28, 28] 49,152
BatchNorm2d-96 [-1, 128, 28, 28] 256
ReLU-97 [-1, 128, 28, 28] 0
Conv2d-98 [-1, 32, 28, 28] 36,864
BatchNorm2d-99 [-1, 416, 28, 28] 832
ReLU-100 [-1, 416, 28, 28] 0
Conv2d-101 [-1, 128, 28, 28] 53,248
BatchNorm2d-102 [-1, 128, 28, 28] 256
ReLU-103 [-1, 128, 28, 28] 0
Conv2d-104 [-1, 32, 28, 28] 36,864
BatchNorm2d-105 [-1, 448, 28, 28] 896
ReLU-106 [-1, 448, 28, 28] 0
Conv2d-107 [-1, 128, 28, 28] 57,344
BatchNorm2d-108 [-1, 128, 28, 28] 256
ReLU-109 [-1, 128, 28, 28] 0
Conv2d-110 [-1, 32, 28, 28] 36,864
BatchNorm2d-111 [-1, 480, 28, 28] 960
ReLU-112 [-1, 480, 28, 28] 0
Conv2d-113 [-1, 128, 28, 28] 61,440
BatchNorm2d-114 [-1, 128, 28, 28] 256
ReLU-115 [-1, 128, 28, 28] 0
Conv2d-116 [-1, 32, 28, 28] 36,864
BatchNorm2d-117 [-1, 512, 28, 28] 1,024
ReLU-118 [-1, 512, 28, 28] 0
Conv2d-119 [-1, 256, 28, 28] 131,072
AvgPool2d-120 [-1, 256, 14, 14] 0
BatchNorm2d-121 [-1, 256, 14, 14] 512
ReLU-122 [-1, 256, 14, 14] 0
Conv2d-123 [-1, 128, 14, 14] 32,768
BatchNorm2d-124 [-1, 128, 14, 14] 256
ReLU-125 [-1, 128, 14, 14] 0
Conv2d-126 [-1, 32, 14, 14] 36,864
BatchNorm2d-127 [-1, 288, 14, 14] 576
ReLU-128 [-1, 288, 14, 14] 0
Conv2d-129 [-1, 128, 14, 14] 36,864
BatchNorm2d-130 [-1, 128, 14, 14] 256
ReLU-131 [-1, 128, 14, 14] 0
Conv2d-132 [-1, 32, 14, 14] 36,864
BatchNorm2d-133 [-1, 320, 14, 14] 640
ReLU-134 [-1, 320, 14, 14] 0
Conv2d-135 [-1, 128, 14, 14] 40,960
BatchNorm2d-136 [-1, 128, 14, 14] 256
ReLU-137 [-1, 128, 14, 14] 0
Conv2d-138 [-1, 32, 14, 14] 36,864
BatchNorm2d-139 [-1, 352, 14, 14] 704
ReLU-140 [-1, 352, 14, 14] 0
Conv2d-141 [-1, 128, 14, 14] 45,056
BatchNorm2d-142 [-1, 128, 14, 14] 256
ReLU-143 [-1, 128, 14, 14] 0
Conv2d-144 [-1, 32, 14, 14] 36,864
BatchNorm2d-145 [-1, 384, 14, 14] 768
ReLU-146 [-1, 384, 14, 14] 0
Conv2d-147 [-1, 128, 14, 14] 49,152
BatchNorm2d-148 [-1, 128, 14, 14] 256
ReLU-149 [-1, 128, 14, 14] 0
Conv2d-150 [-1, 32, 14, 14] 36,864
BatchNorm2d-151 [-1, 416, 14, 14] 832
ReLU-152 [-1, 416, 14, 14] 0
Conv2d-153 [-1, 128, 14, 14] 53,248
BatchNorm2d-154 [-1, 128, 14, 14] 256
ReLU-155 [-1, 128, 14, 14] 0
Conv2d-156 [-1, 32, 14, 14] 36,864
BatchNorm2d-157 [-1, 448, 14, 14] 896
ReLU-158 [-1, 448, 14, 14] 0
Conv2d-159 [-1, 128, 14, 14] 57,344
BatchNorm2d-160 [-1, 128, 14, 14] 256
ReLU-161 [-1, 128, 14, 14] 0
Conv2d-162 [-1, 32, 14, 14] 36,864
BatchNorm2d-163 [-1, 480, 14, 14] 960
ReLU-164 [-1, 480, 14, 14] 0
Conv2d-165 [-1, 128, 14, 14] 61,440
BatchNorm2d-166 [-1, 128, 14, 14] 256
ReLU-167 [-1, 128, 14, 14] 0
Conv2d-168 [-1, 32, 14, 14] 36,864
BatchNorm2d-169 [-1, 512, 14, 14] 1,024
ReLU-170 [-1, 512, 14, 14] 0
Conv2d-171 [-1, 128, 14, 14] 65,536
BatchNorm2d-172 [-1, 128, 14, 14] 256
ReLU-173 [-1, 128, 14, 14] 0
Conv2d-174 [-1, 32, 14, 14] 36,864
BatchNorm2d-175 [-1, 544, 14, 14] 1,088
ReLU-176 [-1, 544, 14, 14] 0
Conv2d-177 [-1, 128, 14, 14] 69,632
BatchNorm2d-178 [-1, 128, 14, 14] 256
ReLU-179 [-1, 128, 14, 14] 0
Conv2d-180 [-1, 32, 14, 14] 36,864
BatchNorm2d-181 [-1, 576, 14, 14] 1,152
ReLU-182 [-1, 576, 14, 14] 0
Conv2d-183 [-1, 128, 14, 14] 73,728
BatchNorm2d-184 [-1, 128, 14, 14] 256
ReLU-185 [-1, 128, 14, 14] 0
Conv2d-186 [-1, 32, 14, 14] 36,864
BatchNorm2d-187 [-1, 608, 14, 14] 1,216
ReLU-188 [-1, 608, 14, 14] 0
Conv2d-189 [-1, 128, 14, 14] 77,824
BatchNorm2d-190 [-1, 128, 14, 14] 256
ReLU-191 [-1, 128, 14, 14] 0
Conv2d-192 [-1, 32, 14, 14] 36,864
BatchNorm2d-193 [-1, 640, 14, 14] 1,280
ReLU-194 [-1, 640, 14, 14] 0
Conv2d-195 [-1, 128, 14, 14] 81,920
BatchNorm2d-196 [-1, 128, 14, 14] 256
ReLU-197 [-1, 128, 14, 14] 0
Conv2d-198 [-1, 32, 14, 14] 36,864
BatchNorm2d-199 [-1, 672, 14, 14] 1,344
ReLU-200 [-1, 672, 14, 14] 0
Conv2d-201 [-1, 128, 14, 14] 86,016
BatchNorm2d-202 [-1, 128, 14, 14] 256
ReLU-203 [-1, 128, 14, 14] 0
Conv2d-204 [-1, 32, 14, 14] 36,864
BatchNorm2d-205 [-1, 704, 14, 14] 1,408
ReLU-206 [-1, 704, 14, 14] 0
Conv2d-207 [-1, 128, 14, 14] 90,112
BatchNorm2d-208 [-1, 128, 14, 14] 256
ReLU-209 [-1, 128, 14, 14] 0
Conv2d-210 [-1, 32, 14, 14] 36,864
BatchNorm2d-211 [-1, 736, 14, 14] 1,472
ReLU-212 [-1, 736, 14, 14] 0
Conv2d-213 [-1, 128, 14, 14] 94,208
BatchNorm2d-214 [-1, 128, 14, 14] 256
ReLU-215 [-1, 128, 14, 14] 0
Conv2d-216 [-1, 32, 14, 14] 36,864
BatchNorm2d-217 [-1, 768, 14, 14] 1,536
ReLU-218 [-1, 768, 14, 14] 0
Conv2d-219 [-1, 128, 14, 14] 98,304
BatchNorm2d-220 [-1, 128, 14, 14] 256
ReLU-221 [-1, 128, 14, 14] 0
Conv2d-222 [-1, 32, 14, 14] 36,864
BatchNorm2d-223 [-1, 800, 14, 14] 1,600
ReLU-224 [-1, 800, 14, 14] 0
Conv2d-225 [-1, 128, 14, 14] 102,400
BatchNorm2d-226 [-1, 128, 14, 14] 256
ReLU-227 [-1, 128, 14, 14] 0
Conv2d-228 [-1, 32, 14, 14] 36,864
BatchNorm2d-229 [-1, 832, 14, 14] 1,664
ReLU-230 [-1, 832, 14, 14] 0
Conv2d-231 [-1, 128, 14, 14] 106,496
BatchNorm2d-232 [-1, 128, 14, 14] 256
ReLU-233 [-1, 128, 14, 14] 0
Conv2d-234 [-1, 32, 14, 14] 36,864
BatchNorm2d-235 [-1, 864, 14, 14] 1,728
ReLU-236 [-1, 864, 14, 14] 0
Conv2d-237 [-1, 128, 14, 14] 110,592
BatchNorm2d-238 [-1, 128, 14, 14] 256
ReLU-239 [-1, 128, 14, 14] 0
Conv2d-240 [-1, 32, 14, 14] 36,864
BatchNorm2d-241 [-1, 896, 14, 14] 1,792
ReLU-242 [-1, 896, 14, 14] 0
Conv2d-243 [-1, 128, 14, 14] 114,688
BatchNorm2d-244 [-1, 128, 14, 14] 256
ReLU-245 [-1, 128, 14, 14] 0
Conv2d-246 [-1, 32, 14, 14] 36,864
BatchNorm2d-247 [-1, 928, 14, 14] 1,856
ReLU-248 [-1, 928, 14, 14] 0
Conv2d-249 [-1, 128, 14, 14] 118,784
BatchNorm2d-250 [-1, 128, 14, 14] 256
ReLU-251 [-1, 128, 14, 14] 0
Conv2d-252 [-1, 32, 14, 14] 36,864
BatchNorm2d-253 [-1, 960, 14, 14] 1,920
ReLU-254 [-1, 960, 14, 14] 0
Conv2d-255 [-1, 128, 14, 14] 122,880
BatchNorm2d-256 [-1, 128, 14, 14] 256
ReLU-257 [-1, 128, 14, 14] 0
Conv2d-258 [-1, 32, 14, 14] 36,864
BatchNorm2d-259 [-1, 992, 14, 14] 1,984
ReLU-260 [-1, 992, 14, 14] 0
Conv2d-261 [-1, 128, 14, 14] 126,976
BatchNorm2d-262 [-1, 128, 14, 14] 256
ReLU-263 [-1, 128, 14, 14] 0
Conv2d-264 [-1, 32, 14, 14] 36,864
BatchNorm2d-265 [-1, 1024, 14, 14] 2,048
ReLU-266 [-1, 1024, 14, 14] 0
Conv2d-267 [-1, 512, 14, 14] 524,288
AvgPool2d-268 [-1, 512, 7, 7] 0
BatchNorm2d-269 [-1, 512, 7, 7] 1,024
ReLU-270 [-1, 512, 7, 7] 0
Conv2d-271 [-1, 128, 7, 7] 65,536
BatchNorm2d-272 [-1, 128, 7, 7] 256
ReLU-273 [-1, 128, 7, 7] 0
Conv2d-274 [-1, 32, 7, 7] 36,864
BatchNorm2d-275 [-1, 544, 7, 7] 1,088
ReLU-276 [-1, 544, 7, 7] 0
Conv2d-277 [-1, 128, 7, 7] 69,632
BatchNorm2d-278 [-1, 128, 7, 7] 256
ReLU-279 [-1, 128, 7, 7] 0
Conv2d-280 [-1, 32, 7, 7] 36,864
BatchNorm2d-281 [-1, 576, 7, 7] 1,152
ReLU-282 [-1, 576, 7, 7] 0
Conv2d-283 [-1, 128, 7, 7] 73,728
BatchNorm2d-284 [-1, 128, 7, 7] 256
ReLU-285 [-1, 128, 7, 7] 0
Conv2d-286 [-1, 32, 7, 7] 36,864
BatchNorm2d-287 [-1, 608, 7, 7] 1,216
ReLU-288 [-1, 608, 7, 7] 0
Conv2d-289 [-1, 128, 7, 7] 77,824
BatchNorm2d-290 [-1, 128, 7, 7] 256
ReLU-291 [-1, 128, 7, 7] 0
Conv2d-292 [-1, 32, 7, 7] 36,864
BatchNorm2d-293 [-1, 640, 7, 7] 1,280
ReLU-294 [-1, 640, 7, 7] 0
Conv2d-295 [-1, 128, 7, 7] 81,920
BatchNorm2d-296 [-1, 128, 7, 7] 256
ReLU-297 [-1, 128, 7, 7] 0
Conv2d-298 [-1, 32, 7, 7] 36,864
BatchNorm2d-299 [-1, 672, 7, 7] 1,344
ReLU-300 [-1, 672, 7, 7] 0
Conv2d-301 [-1, 128, 7, 7] 86,016
BatchNorm2d-302 [-1, 128, 7, 7] 256
ReLU-303 [-1, 128, 7, 7] 0
Conv2d-304 [-1, 32, 7, 7] 36,864
BatchNorm2d-305 [-1, 704, 7, 7] 1,408
ReLU-306 [-1, 704, 7, 7] 0
Conv2d-307 [-1, 128, 7, 7] 90,112
BatchNorm2d-308 [-1, 128, 7, 7] 256
ReLU-309 [-1, 128, 7, 7] 0
Conv2d-310 [-1, 32, 7, 7] 36,864
BatchNorm2d-311 [-1, 736, 7, 7] 1,472
ReLU-312 [-1, 736, 7, 7] 0
Conv2d-313 [-1, 128, 7, 7] 94,208
BatchNorm2d-314 [-1, 128, 7, 7] 256
ReLU-315 [-1, 128, 7, 7] 0
Conv2d-316 [-1, 32, 7, 7] 36,864
BatchNorm2d-317 [-1, 768, 7, 7] 1,536
ReLU-318 [-1, 768, 7, 7] 0
Conv2d-319 [-1, 128, 7, 7] 98,304
BatchNorm2d-320 [-1, 128, 7, 7] 256
ReLU-321 [-1, 128, 7, 7] 0
Conv2d-322 [-1, 32, 7, 7] 36,864
BatchNorm2d-323 [-1, 800, 7, 7] 1,600
ReLU-324 [-1, 800, 7, 7] 0
Conv2d-325 [-1, 128, 7, 7] 102,400
BatchNorm2d-326 [-1, 128, 7, 7] 256
ReLU-327 [-1, 128, 7, 7] 0
Conv2d-328 [-1, 32, 7, 7] 36,864
BatchNorm2d-329 [-1, 832, 7, 7] 1,664
ReLU-330 [-1, 832, 7, 7] 0
Conv2d-331 [-1, 128, 7, 7] 106,496
BatchNorm2d-332 [-1, 128, 7, 7] 256
ReLU-333 [-1, 128, 7, 7] 0
Conv2d-334 [-1, 32, 7, 7] 36,864
BatchNorm2d-335 [-1, 864, 7, 7] 1,728
ReLU-336 [-1, 864, 7, 7] 0
Conv2d-337 [-1, 128, 7, 7] 110,592
BatchNorm2d-338 [-1, 128, 7, 7] 256
ReLU-339 [-1, 128, 7, 7] 0
Conv2d-340 [-1, 32, 7, 7] 36,864
BatchNorm2d-341 [-1, 896, 7, 7] 1,792
ReLU-342 [-1, 896, 7, 7] 0
Conv2d-343 [-1, 128, 7, 7] 114,688
BatchNorm2d-344 [-1, 128, 7, 7] 256
ReLU-345 [-1, 128, 7, 7] 0
Conv2d-346 [-1, 32, 7, 7] 36,864
BatchNorm2d-347 [-1, 928, 7, 7] 1,856
ReLU-348 [-1, 928, 7, 7] 0
Conv2d-349 [-1, 128, 7, 7] 118,784
BatchNorm2d-350 [-1, 128, 7, 7] 256
ReLU-351 [-1, 128, 7, 7] 0
Conv2d-352 [-1, 32, 7, 7] 36,864
BatchNorm2d-353 [-1, 960, 7, 7] 1,920
ReLU-354 [-1, 960, 7, 7] 0
Conv2d-355 [-1, 128, 7, 7] 122,880
BatchNorm2d-356 [-1, 128, 7, 7] 256
ReLU-357 [-1, 128, 7, 7] 0
Conv2d-358 [-1, 32, 7, 7] 36,864
BatchNorm2d-359 [-1, 992, 7, 7] 1,984
ReLU-360 [-1, 992, 7, 7] 0
Conv2d-361 [-1, 128, 7, 7] 126,976
BatchNorm2d-362 [-1, 128, 7, 7] 256
ReLU-363 [-1, 128, 7, 7] 0
Conv2d-364 [-1, 32, 7, 7] 36,864
AdaptiveAvgPool2d-365 [-1, 1024, 1, 1] 0
Linear-366 [-1, 64] 65,600
ReLU-367 [-1, 64] 0
Linear-368 [-1, 1024] 66,560
Sigmoid-369 [-1, 1024] 0
Squeeze_excitation_layer-370 [-1, 1024, 7, 7] 0
BatchNorm2d-371 [-1, 1024, 7, 7] 2,048
ReLU-372 [-1, 1024, 7, 7] 0
Linear-373 [-1, 1000] 1,025,000
================================================================
Total params: 8,111,016
Trainable params: 8,111,016
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 294.99
Params size (MB): 30.94
Estimated Total Size (MB): 326.50
----------------------------------------------------------------
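The parameter count above can be cross-checked. Each `nn.Linear(in, out, bias=True)` has `in * out + out` parameters, so the SE layer (1024 -> 64 -> 1024) adds 65,600 + 66,560 = 132,160 parameters; subtracting that from the total leaves 7,978,856, the parameter count torchvision lists for DenseNet-121. A stdlib sketch of the arithmetic; `linear_params` is a hypothetical helper:

```python
def linear_params(in_features, out_features, bias=True):
    """Number of learnable parameters in nn.Linear(in_features, out_features, bias)."""
    return in_features * out_features + (out_features if bias else 0)

channel, reduction = 1024, 16
se_params = linear_params(channel, channel // reduction) + linear_params(channel // reduction, channel)
print(se_params)                       # 132160

total_from_summary = 8_111_016         # "Total params" line of the summary
print(total_from_summary - se_params)  # 7978856, the plain DenseNet-121 count

classifier_params = linear_params(1024, 1000)
print(classifier_params)               # 1025000, matches the Linear-373 row
```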
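2.3 Training the model
2.3.1 Setting hyperparameters
```python
"""Training: set hyperparameters"""
loss_fn = nn.CrossEntropyLoss()  # cross-entropy loss, the standard loss for image classification: measures how far the outputs are from the true labels
learn_rate = 1e-4  # learning rate
optimizer1 = torch.optim.SGD(model.parameters(), lr=learn_rate)  # SGD (stochastic gradient descent) optimizer, which updates the model parameters during training
optimizer2 = torch.optim.Adam(model.parameters(), lr=learn_rate)
lr_opt = optimizer2  # Adam is the optimizer actually used below
model_opt = optimizer2
# Used with the official dynamic learning-rate API (option 2):
# multiply the initial lr by 0.92 once every 4 epochs
lambda1 = lambda epoch: 0.92 ** (epoch // 4)
# optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(lr_opt, lr_lambda=lambda1)  # choose the scheduling method
```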
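`LambdaLR` sets each parameter group's learning rate to `initial_lr * lr_lambda(epoch)`, where `epoch` counts the `scheduler.step()` calls so far. The stdlib sketch below tabulates what `lambda1` implies for 20 epochs with `learn_rate = 1e-4`: a step decay by a factor of 0.92 every 4 epochs.

```python
learn_rate = 1e-4
lambda1 = lambda epoch: 0.92 ** (epoch // 4)  # same rule as above

# Learning rate in effect during each epoch (scheduler.step() is called once per epoch)
schedule = [learn_rate * lambda1(epoch) for epoch in range(20)]
for epoch, lr in enumerate(schedule, start=1):
    print("Epoch {:2d}: lr = {:.2E}".format(epoch, lr))
```

Epochs 1-4 run at 1.00E-04, epochs 5-8 at 9.20E-05, and epochs 17-20 at about 7.16E-05 (1e-4 * 0.92^4).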
2.3.2 Writing the training function
```python
"""Training: the training function"""
# Training loop
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # size of the training set (1713 images here)
    num_batches = len(dataloader)  # number of batches (ceil(1713 / 32) = 54 here)
    train_loss, train_acc = 0, 0  # running loss and count of correct predictions
    for X, y in dataloader:  # X: image batch, y: true labels
        X, y = X.to(device), y.to(device)  # move the data to the GPU
        # Compute the prediction error
        pred = model(X)  # network output
        loss = loss_fn(pred, y)  # gap between the network output and the true labels
        # Backpropagation
        optimizer.zero_grad()  # clear old gradients
        loss.backward()  # backpropagate to compute the current gradients
        optimizer.step()  # update the parameters using the gradients
        # Record accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss
```
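`train()` divides the summed per-batch mean losses by the number of batches, while the correct-prediction count is divided by the number of samples. When the last batch is smaller (17 images here), the loss average weights those samples slightly more than an exact per-sample mean would. A stdlib illustration with made-up batch losses; both helpers are hypothetical:

```python
def mean_of_batch_means(batch_means):
    """What train()/test() report: the average of per-batch mean losses."""
    return sum(batch_means) / len(batch_means)

def per_sample_mean(batch_means, batch_sizes):
    """Exact mean over all samples: weight each batch mean by its batch size."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical losses for two full batches of 32 and a final batch of 17
batch_means = [0.30, 0.20, 0.70]
batch_sizes = [32, 32, 17]
print(round(mean_of_batch_means(batch_means), 4))          # 0.4
print(round(per_sample_mean(batch_means, batch_sizes), 4))  # 0.3444
```

Weighting by batch size, as in `per_sample_mean`, would make the reported loss independent of how samples happen to fall into batches.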
2.3.3 Writing the test function
```python
"""Training: the test function"""
# The test function mirrors the training function, but since it does not update the
# network weights by gradient descent, it takes no optimizer
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # size of the test set (429 images here)
    num_batches = len(dataloader)  # number of batches (ceil(429 / 32) = 14 here)
    test_loss, test_acc = 0, 0
    # No gradients are needed during evaluation, which saves memory and computation
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            # Compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)
            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()  # count correct predictions
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
```
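`(target_pred.argmax(1) == target).type(torch.float).sum().item()` counts how many rows of logits in a batch have their largest value at the true class index. The same count in plain Python, on made-up logits for the two classes (`0 = Monkeypox`, `1 = Others`); `argmax` and `count_correct` are hypothetical helpers:

```python
def argmax(row):
    """Index of the largest value (first occurrence on ties)."""
    return max(range(len(row)), key=row.__getitem__)

def count_correct(logits, labels):
    """Number of rows whose argmax equals the label, i.e. the per-batch correct count."""
    return sum(1 for row, y in zip(logits, labels) if argmax(row) == y)

logits = [[2.0, -1.0], [0.1, 0.3], [1.5, 0.2], [-0.5, 0.5]]
labels = [0, 1, 1, 1]
correct = count_correct(logits, labels)
print(correct, correct / len(labels))  # 3 0.75
```

Summing these counts over all batches and dividing once by `len(dataloader.dataset)`, as `test()` does, gives an exact accuracy even when the last batch is smaller.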
2.3.4 Training
```python
"""Training: the main loop"""
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_test_acc = 0
for epoch in range(epochs):
    milliseconds_t1 = int(time.time() * 1000)
    # Update the learning rate (when using a custom schedule)
    # adjust_learning_rate(lr_opt, epoch, learn_rate)
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, model_opt)
    scheduler.step()  # update the learning rate (when using the official scheduler API)
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    # Current learning rate
    lr = lr_opt.state_dict()['param_groups'][0]['lr']
    milliseconds_t2 = int(time.time() * 1000)
    template = ('Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E}')
    if best_test_acc < epoch_test_acc:
        best_test_acc = epoch_test_acc
        # Keep a copy of the best model so far
        best_model = copy.deepcopy(model)
        template = (
            'Epoch:{:2d}, duration:{}ms, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E},Update the best model')
    print(
        template.format(epoch + 1, milliseconds_t2 - milliseconds_t1, epoch_train_acc * 100, epoch_train_loss,
                        epoch_test_acc * 100, epoch_test_loss, lr))
# Save the best model's parameters to a file
PATH = './best_model.pth'  # file name for the saved parameters
torch.save(best_model.state_dict(), PATH)
print('Done')
```
Output
Epoch: 1, duration:12559ms, Train_acc:56.9%, Train_loss:4.304, Test_acc:72.0%,Test_loss:2.320, Lr:1.00E-04,Update the best model
Epoch: 2, duration:11948ms, Train_acc:68.6%, Train_loss:1.381, Test_acc:73.0%,Test_loss:0.831, Lr:1.00E-04,Update the best model
Epoch: 3, duration:11949ms, Train_acc:74.2%, Train_loss:0.674, Test_acc:76.9%,Test_loss:0.561, Lr:1.00E-04,Update the best model
Epoch: 4, duration:12024ms, Train_acc:77.8%, Train_loss:0.532, Test_acc:74.4%,Test_loss:0.516, Lr:1.00E-04
Epoch: 5, duration:11876ms, Train_acc:80.5%, Train_loss:0.465, Test_acc:80.4%,Test_loss:0.472, Lr:1.00E-04,Update the best model
Epoch: 6, duration:11869ms, Train_acc:84.1%, Train_loss:0.409, Test_acc:82.3%,Test_loss:0.404, Lr:1.00E-04,Update the best model
Epoch: 7, duration:12088ms, Train_acc:84.1%, Train_loss:0.378, Test_acc:83.2%,Test_loss:0.355, Lr:1.00E-04,Update the best model
Epoch: 8, duration:12025ms, Train_acc:86.0%, Train_loss:0.348, Test_acc:85.3%,Test_loss:0.349, Lr:1.00E-04,Update the best model
Epoch: 9, duration:12019ms, Train_acc:86.2%, Train_loss:0.334, Test_acc:85.5%,Test_loss:0.360, Lr:1.00E-04,Update the best model
Epoch:10, duration:12027ms, Train_acc:88.3%, Train_loss:0.290, Test_acc:88.8%,Test_loss:0.260, Lr:1.00E-04,Update the best model
Epoch:11, duration:11865ms, Train_acc:88.9%, Train_loss:0.273, Test_acc:86.7%,Test_loss:0.311, Lr:1.00E-04
Epoch:12, duration:12054ms, Train_acc:90.0%, Train_loss:0.259, Test_acc:89.3%,Test_loss:0.271, Lr:1.00E-04,Update the best model
Epoch:13, duration:11983ms, Train_acc:90.3%, Train_loss:0.236, Test_acc:88.8%,Test_loss:0.272, Lr:1.00E-04
Epoch:14, duration:11980ms, Train_acc:90.1%, Train_loss:0.246, Test_acc:90.0%,Test_loss:0.229, Lr:1.00E-04,Update the best model
Epoch:15, duration:11936ms, Train_acc:91.4%, Train_loss:0.217, Test_acc:90.2%,Test_loss:0.256, Lr:1.00E-04,Update the best model
Epoch:16, duration:11935ms, Train_acc:93.8%, Train_loss:0.170, Test_acc:91.4%,Test_loss:0.237, Lr:1.00E-04,Update the best model
Epoch:17, duration:11980ms, Train_acc:93.7%, Train_loss:0.178, Test_acc:87.6%,Test_loss:0.353, Lr:1.00E-04
Epoch:18, duration:12344ms, Train_acc:92.8%, Train_loss:0.179, Test_acc:92.3%,Test_loss:0.190, Lr:1.00E-04,Update the best model
Epoch:19, duration:12301ms, Train_acc:95.3%, Train_loss:0.128, Test_acc:89.3%,Test_loss:0.275, Lr:1.00E-04
Epoch:20, duration:11914ms, Train_acc:95.3%, Train_loss:0.129, Test_acc:92.8%,Test_loss:0.218, Lr:1.00E-04,Update the best model
Done
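The `Update the best model` markers follow from the strict `best_test_acc < epoch_test_acc` comparison. Replaying that rule on the Test_acc column of the log above (rounded percentages, so this checks the rule, not the raw values):

```python
# Test_acc (%) for epochs 1-20, copied from the training log above
test_acc = [72.0, 73.0, 76.9, 74.4, 80.4, 82.3, 83.2, 85.3, 85.5, 88.8,
            86.7, 89.3, 88.8, 90.0, 90.2, 91.4, 87.6, 92.3, 89.3, 92.8]

best_test_acc, best_epoch, updates = 0, None, []
for epoch, acc in enumerate(test_acc, start=1):
    if best_test_acc < acc:  # strict comparison: a tie does not replace the saved model
        best_test_acc, best_epoch = acc, epoch
        updates.append(epoch)

print(updates)                    # [1, 2, 3, 5, 6, 7, 8, 9, 10, 12, 14, 15, 16, 18, 20]
print(best_epoch, best_test_acc)  # 20 92.8
```

So in this run the best model is the one from the final epoch.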
2.4 Visualizing the results
```python
"""Training: visualize the results"""
epochs_range = range(epochs)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()
```
2.5 Predicting a specified image
```python
def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # show the image being predicted
    plt.show()
    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)  # add a batch dimension: [C, H, W] -> [1, C, H, W]
    model.eval()
    with torch.no_grad():  # inference only, no gradients needed
        output = model(img)
    _, pred = torch.max(output, 1)
    pred_class = classes[pred.item()]
    print(f'Prediction: {pred_class}')


# Load the saved parameters into the model
model.load_state_dict(torch.load(PATH, map_location=device))
"""Predict a specified image"""
classes = list(total_data.class_to_idx)
# Predict one image from the dataset
predict_one_image(image_path=str(Path(data_dir) / "Monkeypox/M01_01_01.jpg"),
                  model=model,
                  transform=train_transforms,
                  classes=classes)
```
Output
Prediction: Monkeypox
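`classes = list(total_data.class_to_idx)` works because `ImageFolder` assigns indices in sorted folder order and Python dicts keep insertion order, so a class's list position equals its index. Inverting the mapping explicitly makes that assumption visible; a sketch using the mapping printed earlier by `total_data.class_to_idx` (`label_for` is a hypothetical helper):

```python
class_to_idx = {'Monkeypox': 0, 'Others': 1}  # as printed by total_data.class_to_idx

# Invert the mapping so a predicted index maps back to its class name
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

def label_for(pred_index):
    """Class name for a predicted index (e.g. the value of pred.item())."""
    return idx_to_class[pred_index]

print(label_for(0))  # Monkeypox
print(label_for(1))  # Others
```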
2.6 Model evaluation
```python
"""Model evaluation"""
best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
# Check that this matches the best accuracy recorded during training
print(epoch_test_acc, epoch_test_loss)
```
Output
0.9277389277389277 0.21906232248459542
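The evaluation accuracy corresponds to a whole number of correct predictions: 0.9277389277389277 × 429 = 398, i.e. 398 of 429 test images, which agrees with the 92.8% logged at epoch 20. The loss (0.219) differs slightly from the logged 0.218; one likely reason is that `test_dl` shuffles and `test()` averages per-batch means, so the makeup of the smaller last batch changes between runs while the correct count does not. A quick arithmetic check:

```python
test_size = 429
reported_acc = 0.9277389277389277  # printed by the evaluation cell

correct = round(reported_acc * test_size)
print(correct)                                      # 398
print(correct / test_size == reported_acc)          # True
print("{:.1f}%".format(100 * correct / test_size))  # 92.8%
```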
3 Implementing DenseNet in TensorFlow
3.1. Importing libraries
```python
from PIL import Image
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
# Support Chinese characters in plots
plt.rcParams['font.sans-serif'] = ['SimHei']  # font that can render Chinese labels
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
import tensorflow as tf
from keras import layers, models, Input
from keras.layers import (Activation, BatchNormalization, Flatten, Dense, Conv2D, MaxPooling2D,
                          ZeroPadding2D, GlobalMaxPooling2D, AveragePooling2D, Dropout,
                          GlobalAveragePooling2D)
from keras.models import Model
from keras import regularizers
from tensorflow import keras
from keras.callbacks import ModelCheckpoint
import warnings
warnings.filterwarnings('ignore')  # suppress warnings that don't need to be printed
```
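3.2. Setting up the GPU (skip this step if you are using a CPU)
python
'''Preparation - set up the GPU (skip this step if you are using a CPU)'''
# Check whether TensorFlow was built with CUDA support, then list the visible GPUs
print(tf.test.is_built_with_cuda())
gpus = tf.config.list_physical_devices("GPU")
print(gpus)
if gpus:
    gpu0 = gpus[0]  # if there are several GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0], "GPU")
Output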
True
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
3.3. Importing the data
python
'''Preparation - import the data'''
data_dir = r"D:\DeepLearning\data\monkeypox_recognition"
data_dir = Path(data_dir)
3.4. Inspecting the data
python
'''Preparation - inspect the data'''
image_count = len(list(data_dir.glob('*/*.jpg')))
print("Total number of images:", image_count)
image_list = list(data_dir.glob('Monkeypox/*.jpg'))
image = Image.open(str(image_list[1]))
# Inspect the attributes of the image instance
print(image.format, image.size, image.mode)
plt.imshow(image)
plt.axis("off")
plt.show()
Output:
Total number of images: 2142
JPEG (224, 224) RGB
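`data_dir.glob('*/*.jpg')` counts every JPEG exactly one directory below `data_dir`, i.e. one sub-folder per class. A per-class breakdown is a quick way to spot class imbalance before splitting. The sketch below uses only `pathlib`; the helper `count_per_class` is hypothetical, and the demo builds a throwaway directory tree rather than reading the real dataset.

```python
import tempfile
from pathlib import Path


def count_per_class(root, pattern="*.jpg"):
    """Map each class sub-folder of `root` to the number of files matching `pattern`."""
    root = Path(root)
    return {d.name: len(list(d.glob(pattern)))
            for d in sorted(root.iterdir()) if d.is_dir()}


# Demo on a temporary tree: 3 "Monkeypox" images and 2 "Others" images
with tempfile.TemporaryDirectory() as tmp:
    for cls, n in [("Monkeypox", 3), ("Others", 2)]:
        (Path(tmp) / cls).mkdir()
        for i in range(n):
            (Path(tmp) / cls / f"{i}.jpg").touch()
    counts = count_per_class(tmp)
    print(counts)
```

On the real data, `count_per_class(data_dir)` should sum to the total printed above, assuming all images use the lowercase `.jpg` extension (glob matching is case-sensitive on Linux).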
3.5. Loading the data
python
'''Data preprocessing - load the data'''
batch_size = 32
img_height = 224
img_width = 224
"""
For a detailed introduction to image_dataset_from_directory(), see: https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)
Output:
Found 2142 files belonging to 2 classes.
Using 1714 files for training.
Found 2142 files belonging to 2 classes.
Using 428 files for validation.
['Monkeypox', 'Others']
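The 1714/428 split follows from `validation_split=0.2`: as far as I know, Keras 2.x reserves `int(validation_split * num_files)` files for validation and uses the rest for training, and because both calls pass the same `seed`, the two subsets do not overlap. The arithmetic can be checked directly; `split_sizes` is a hypothetical helper, not a Keras function.

```python
def split_sizes(num_files, validation_split):
    """Reproduce the train/validation file counts of image_dataset_from_directory.

    Assumes the validation count is truncated with int(), as in Keras 2.x.
    """
    num_val = int(validation_split * num_files)
    return num_files - num_val, num_val


train_n, val_n = split_sizes(2142, 0.2)
print(train_n, val_n)  # 0.2 * 2142 = 428.4, truncated to 428
```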
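3.6. Checking the data again
python
'''Data preprocessing - check the data again'''
# image_batch is a tensor of shape (32, 224, 224, 3): a batch of 32 images of size 224x224x3 (the last dimension is the RGB color channels).
# labels_batch is a tensor of shape (32,): the labels of those 32 images.
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
Output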
(32, 224, 224, 3)
(32,)
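The first batch is full, but not every batch is: assuming the incomplete final batch is kept (the datasets here do not drop a remainder), there are `ceil(n / batch_size)` batches and the last one holds whatever is left. For the 1714 training and 428 validation images reported above (`batch_layout` is a hypothetical helper):

```python
import math


def batch_layout(num_samples, batch_size):
    """Return (number of batches, size of the last batch) when the remainder is kept."""
    num_batches = math.ceil(num_samples / batch_size)
    last = num_samples - (num_batches - 1) * batch_size
    return num_batches, last


for name, n in [("train", 1714), ("val", 428)]:
    batches, last = batch_layout(n, 32)
    print(f"{name}: {batches} batches, last batch has {last} images")
```

So code that hard-codes a batch dimension of 32 would fail on the last batch of each epoch.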
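3.7. Configuring the dataset
python
'''Data preprocessing - configure the dataset'''
AUTOTUNE = tf.data.AUTOTUNE
# cache(): keep the loaded batches in memory after the first epoch
# shuffle(1000): shuffle with a 1000-element buffer (training set only)
# prefetch(): prepare the next batches while the current one is being consumed
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)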
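`shuffle(1000)` does not shuffle the whole dataset in one go: it fills a buffer with up to 1000 elements, samples each output uniformly from that buffer, and refills the freed slot from the input. Because `train_ds` already yields batches, the elements here are batches; with 54 training batches and a 1000-element buffer, the whole batch order is shuffled. The generator below is a plain-Python sketch of that buffering idea, not tf.data's actual implementation; `buffered_shuffle` is a hypothetical name.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Yield items in a buffered random order, similar in spirit to tf.data's shuffle.

    Each output is drawn uniformly from a buffer of up to `buffer_size` items,
    and the freed slot is refilled with the next input item.
    """
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)  # drain whatever is left in random order
    yield from buffer


order = list(buffered_shuffle(range(54), buffer_size=1000, seed=123))
print(order[:10])
```

With `buffer_size=1` the order is unchanged, which shows why a buffer much smaller than the dataset only shuffles locally.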
3.8. Visualizing the data
python
'''Data preprocessing - visualize the data'''
plt.figure(figsize=(10, 5))
for images, labels in train_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]], fontsize=10)
        plt.axis("off")
# Show the images
plt.show()
3.9. Building the DenseNet network
python
"""Build the DenseNet network"""
def conv_fn(x, growth_rate):
    # Bottleneck layer: BN-ReLU-Conv(1x1, 4*growth_rate) then BN-ReLU-Conv(3x3, growth_rate)
    x1 = keras.layers.BatchNormalization()(x)
    x1 = keras.layers.Activation('relu')(x1)
    x1 = keras.layers.Conv2D(4 * growth_rate, 1, 1, padding="same", use_bias=False)(x1)
    x1 = keras.layers.BatchNormalization()(x1)
    x1 = keras.layers.Activation("relu")(x1)
    x1 = keras.layers.Conv2D(growth_rate, 3, 1, padding="same", use_bias=False)(x1)
    # Concatenate the growth_rate new feature maps onto the input along the channel axis
    return keras.layers.Concatenate(axis=3)([x, x1])


def dense_block(x, block, growth_rate=32):
    # A dense block stacks `block` bottleneck layers; each one adds growth_rate channels
    for i in range(block):
        x = conv_fn(x, growth_rate)
    return x


k = keras.backend


def trans_block(x, theta):
    # Transition layer: BN-ReLU-Conv(1x1) scales the channels by theta, then 2x2 average pooling halves H and W
    x1 = keras.layers.BatchNormalization()(x)
    x1 = keras.layers.Activation("relu")(x1)
    x1 = keras.layers.Conv2D(int(k.int_shape(x)[3] * theta), 1, 1, use_bias=False)(x1)
    x1 = keras.layers.AveragePooling2D(pool_size=(2, 2), strides=2, padding="valid")(x1)
    return x1


def densenet(input_shape, block, n_classes=1000):
    x_input = keras.layers.Input(shape=input_shape)
    # Stem: 7x7 conv, stride 2 (224 -> 112), then 3x3 max pooling, stride 2 (112 -> 56), giving 56*56*64
    # (unlike the original DenseNet stem, there is no ReLU between this BN and the pooling)
    x = keras.layers.Conv2D(64, kernel_size=(7, 7), strides=2, padding="same", use_bias=False)(x_input)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling2D(pool_size=3, strides=2, padding="same")(x)
    x = dense_block(x, block[0])
    x = trans_block(x, 0.5)  # 28*28
    x = dense_block(x, block[1])
    x = trans_block(x, 0.5)  # 14*14
    x = dense_block(x, block[2])
    x = trans_block(x, 0.5)  # 7*7
    x = dense_block(x, block[3])
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation("relu")(x)
    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = keras.layers.Dense(n_classes, activation="softmax")(x)
    model = keras.models.Model(inputs=[x_input], outputs=[outputs])
    return model


# Note: n_classes keeps its default of 1000 here, although this dataset has only 2 classes
model_121 = densenet([224, 224, 3], [6, 12, 24, 16])  # DenseNet-121
model_169 = densenet([224, 224, 3], [6, 12, 32, 32])  # DenseNet-169
model_201 = densenet([224, 224, 3], [6, 12, 48, 32])  # DenseNet-201
model_264 = densenet([224, 224, 3], [6, 12, 64, 48])  # DenseNet-264
model = model_121
model.summary()
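Each `conv_fn` call concatenates `growth_rate` (32) new feature maps onto its input, so a dense block with L layers turns c channels into c + 32L; `trans_block` then multiplies the channel count by `theta` and halves the spatial size. The sketch below walks that recurrence for a `block` configuration and also recomputes the parameter count of a bias-free `Conv2D` (kernel_h * kernel_w * in_channels * out_channels). The helper names are hypothetical, and the starting spatial size of 56 assumes the 224x224 input used above.

```python
def densenet_channels(block, stem_channels=64, growth_rate=32, theta=0.5, size=56):
    """Trace (spatial size, channels) at the end of each dense block of densenet()."""
    c, trace = stem_channels, []
    for i, layers in enumerate(block):
        c += layers * growth_rate       # dense block: each conv_fn adds growth_rate maps
        trace.append((size, c))
        if i < len(block) - 1:          # a trans_block follows every block but the last
            c = int(c * theta)
            size //= 2
    return trace


def conv_params(kernel, c_in, c_out):
    """Number of weights in a Conv2D with a square kernel and use_bias=False."""
    return kernel * kernel * c_in * c_out


print(densenet_channels([6, 12, 24, 16]))  # DenseNet-121
print(conv_params(7, 3, 64), conv_params(3, 128, 32))
```

For DenseNet-121 the trace is 56x56x256, 28x28x512, 14x14x1024 and 7x7x1024, which agrees with the summary (for example `concatenate_5` is 56x56x256 and `concatenate_41` is 14x14x1024), and the stem convolution's 9408 parameters equal 7 * 7 * 3 * 64.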
The resulting network structure is as follows:
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 224, 224, 3 0 []
)]
conv2d (Conv2D) (None, 112, 112, 64 9408 ['input_1[0][0]']
)
batch_normalization (BatchNorm (None, 112, 112, 64 256 ['conv2d[0][0]']
alization) )
max_pooling2d (MaxPooling2D) (None, 56, 56, 64) 0 ['batch_normalization[0][0]']
batch_normalization_1 (BatchNo (None, 56, 56, 64) 256 ['max_pooling2d[0][0]']
rmalization)
activation (Activation) (None, 56, 56, 64) 0 ['batch_normalization_1[0][0]']
conv2d_1 (Conv2D) (None, 56, 56, 128) 8192 ['activation[0][0]']
batch_normalization_2 (BatchNo (None, 56, 56, 128) 512 ['conv2d_1[0][0]']
rmalization)
activation_1 (Activation) (None, 56, 56, 128) 0 ['batch_normalization_2[0][0]']
conv2d_2 (Conv2D) (None, 56, 56, 32) 36864 ['activation_1[0][0]']
concatenate (Concatenate) (None, 56, 56, 96) 0 ['max_pooling2d[0][0]',
'conv2d_2[0][0]']
batch_normalization_3 (BatchNo (None, 56, 56, 96) 384 ['concatenate[0][0]']
rmalization)
activation_2 (Activation) (None, 56, 56, 96) 0 ['batch_normalization_3[0][0]']
conv2d_3 (Conv2D) (None, 56, 56, 128) 12288 ['activation_2[0][0]']
batch_normalization_4 (BatchNo (None, 56, 56, 128) 512 ['conv2d_3[0][0]']
rmalization)
activation_3 (Activation) (None, 56, 56, 128) 0 ['batch_normalization_4[0][0]']
conv2d_4 (Conv2D) (None, 56, 56, 32) 36864 ['activation_3[0][0]']
concatenate_1 (Concatenate) (None, 56, 56, 128) 0 ['concatenate[0][0]',
'conv2d_4[0][0]']
batch_normalization_5 (BatchNo (None, 56, 56, 128) 512 ['concatenate_1[0][0]']
rmalization)
activation_4 (Activation) (None, 56, 56, 128) 0 ['batch_normalization_5[0][0]']
conv2d_5 (Conv2D) (None, 56, 56, 128) 16384 ['activation_4[0][0]']
batch_normalization_6 (BatchNo (None, 56, 56, 128) 512 ['conv2d_5[0][0]']
rmalization)
activation_5 (Activation) (None, 56, 56, 128) 0 ['batch_normalization_6[0][0]']
conv2d_6 (Conv2D) (None, 56, 56, 32) 36864 ['activation_5[0][0]']
concatenate_2 (Concatenate) (None, 56, 56, 160) 0 ['concatenate_1[0][0]',
'conv2d_6[0][0]']
batch_normalization_7 (BatchNo (None, 56, 56, 160) 640 ['concatenate_2[0][0]']
rmalization)
activation_6 (Activation) (None, 56, 56, 160) 0 ['batch_normalization_7[0][0]']
conv2d_7 (Conv2D) (None, 56, 56, 128) 20480 ['activation_6[0][0]']
batch_normalization_8 (BatchNo (None, 56, 56, 128) 512 ['conv2d_7[0][0]']
rmalization)
activation_7 (Activation) (None, 56, 56, 128) 0 ['batch_normalization_8[0][0]']
conv2d_8 (Conv2D) (None, 56, 56, 32) 36864 ['activation_7[0][0]']
concatenate_3 (Concatenate) (None, 56, 56, 192) 0 ['concatenate_2[0][0]',
'conv2d_8[0][0]']
batch_normalization_9 (BatchNo (None, 56, 56, 192) 768 ['concatenate_3[0][0]']
rmalization)
activation_8 (Activation) (None, 56, 56, 192) 0 ['batch_normalization_9[0][0]']
conv2d_9 (Conv2D) (None, 56, 56, 128) 24576 ['activation_8[0][0]']
batch_normalization_10 (BatchN (None, 56, 56, 128) 512 ['conv2d_9[0][0]']
ormalization)
activation_9 (Activation) (None, 56, 56, 128) 0 ['batch_normalization_10[0][0]']
conv2d_10 (Conv2D) (None, 56, 56, 32) 36864 ['activation_9[0][0]']
concatenate_4 (Concatenate) (None, 56, 56, 224) 0 ['concatenate_3[0][0]',
'conv2d_10[0][0]']
batch_normalization_11 (BatchN (None, 56, 56, 224) 896 ['concatenate_4[0][0]']
ormalization)
activation_10 (Activation) (None, 56, 56, 224) 0 ['batch_normalization_11[0][0]']
conv2d_11 (Conv2D) (None, 56, 56, 128) 28672 ['activation_10[0][0]']
batch_normalization_12 (BatchN (None, 56, 56, 128) 512 ['conv2d_11[0][0]']
ormalization)
activation_11 (Activation) (None, 56, 56, 128) 0 ['batch_normalization_12[0][0]']
conv2d_12 (Conv2D) (None, 56, 56, 32) 36864 ['activation_11[0][0]']
concatenate_5 (Concatenate) (None, 56, 56, 256) 0 ['concatenate_4[0][0]',
'conv2d_12[0][0]']
batch_normalization_13 (BatchN (None, 56, 56, 256) 1024 ['concatenate_5[0][0]']
ormalization)
activation_12 (Activation) (None, 56, 56, 256) 0 ['batch_normalization_13[0][0]']
conv2d_13 (Conv2D) (None, 56, 56, 128) 32768 ['activation_12[0][0]']
average_pooling2d (AveragePool (None, 28, 28, 128) 0 ['conv2d_13[0][0]']
ing2D)
batch_normalization_14 (BatchN (None, 28, 28, 128) 512 ['average_pooling2d[0][0]']
ormalization)
activation_13 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_14[0][0]']
conv2d_14 (Conv2D) (None, 28, 28, 128) 16384 ['activation_13[0][0]']
batch_normalization_15 (BatchN (None, 28, 28, 128) 512 ['conv2d_14[0][0]']
ormalization)
activation_14 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_15[0][0]']
conv2d_15 (Conv2D) (None, 28, 28, 32) 36864 ['activation_14[0][0]']
concatenate_6 (Concatenate) (None, 28, 28, 160) 0 ['average_pooling2d[0][0]',
'conv2d_15[0][0]']
batch_normalization_16 (BatchN (None, 28, 28, 160) 640 ['concatenate_6[0][0]']
ormalization)
activation_15 (Activation) (None, 28, 28, 160) 0 ['batch_normalization_16[0][0]']
conv2d_16 (Conv2D) (None, 28, 28, 128) 20480 ['activation_15[0][0]']
batch_normalization_17 (BatchN (None, 28, 28, 128) 512 ['conv2d_16[0][0]']
ormalization)
activation_16 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_17[0][0]']
conv2d_17 (Conv2D) (None, 28, 28, 32) 36864 ['activation_16[0][0]']
concatenate_7 (Concatenate) (None, 28, 28, 192) 0 ['concatenate_6[0][0]',
'conv2d_17[0][0]']
batch_normalization_18 (BatchN (None, 28, 28, 192) 768 ['concatenate_7[0][0]']
ormalization)
activation_17 (Activation) (None, 28, 28, 192) 0 ['batch_normalization_18[0][0]']
conv2d_18 (Conv2D) (None, 28, 28, 128) 24576 ['activation_17[0][0]']
batch_normalization_19 (BatchN (None, 28, 28, 128) 512 ['conv2d_18[0][0]']
ormalization)
activation_18 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_19[0][0]']
conv2d_19 (Conv2D) (None, 28, 28, 32) 36864 ['activation_18[0][0]']
concatenate_8 (Concatenate) (None, 28, 28, 224) 0 ['concatenate_7[0][0]',
'conv2d_19[0][0]']
batch_normalization_20 (BatchN (None, 28, 28, 224) 896 ['concatenate_8[0][0]']
ormalization)
activation_19 (Activation) (None, 28, 28, 224) 0 ['batch_normalization_20[0][0]']
conv2d_20 (Conv2D) (None, 28, 28, 128) 28672 ['activation_19[0][0]']
batch_normalization_21 (BatchN (None, 28, 28, 128) 512 ['conv2d_20[0][0]']
ormalization)
activation_20 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_21[0][0]']
conv2d_21 (Conv2D) (None, 28, 28, 32) 36864 ['activation_20[0][0]']
concatenate_9 (Concatenate) (None, 28, 28, 256) 0 ['concatenate_8[0][0]',
'conv2d_21[0][0]']
batch_normalization_22 (BatchN (None, 28, 28, 256) 1024 ['concatenate_9[0][0]']
ormalization)
activation_21 (Activation) (None, 28, 28, 256) 0 ['batch_normalization_22[0][0]']
conv2d_22 (Conv2D) (None, 28, 28, 128) 32768 ['activation_21[0][0]']
batch_normalization_23 (BatchN (None, 28, 28, 128) 512 ['conv2d_22[0][0]']
ormalization)
activation_22 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_23[0][0]']
conv2d_23 (Conv2D) (None, 28, 28, 32) 36864 ['activation_22[0][0]']
concatenate_10 (Concatenate) (None, 28, 28, 288) 0 ['concatenate_9[0][0]',
'conv2d_23[0][0]']
batch_normalization_24 (BatchN (None, 28, 28, 288) 1152 ['concatenate_10[0][0]']
ormalization)
activation_23 (Activation) (None, 28, 28, 288) 0 ['batch_normalization_24[0][0]']
conv2d_24 (Conv2D) (None, 28, 28, 128) 36864 ['activation_23[0][0]']
batch_normalization_25 (BatchN (None, 28, 28, 128) 512 ['conv2d_24[0][0]']
ormalization)
activation_24 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_25[0][0]']
conv2d_25 (Conv2D) (None, 28, 28, 32) 36864 ['activation_24[0][0]']
concatenate_11 (Concatenate) (None, 28, 28, 320) 0 ['concatenate_10[0][0]',
'conv2d_25[0][0]']
batch_normalization_26 (BatchN (None, 28, 28, 320) 1280 ['concatenate_11[0][0]']
ormalization)
activation_25 (Activation) (None, 28, 28, 320) 0 ['batch_normalization_26[0][0]']
conv2d_26 (Conv2D) (None, 28, 28, 128) 40960 ['activation_25[0][0]']
batch_normalization_27 (BatchN (None, 28, 28, 128) 512 ['conv2d_26[0][0]']
ormalization)
activation_26 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_27[0][0]']
conv2d_27 (Conv2D) (None, 28, 28, 32) 36864 ['activation_26[0][0]']
concatenate_12 (Concatenate) (None, 28, 28, 352) 0 ['concatenate_11[0][0]',
'conv2d_27[0][0]']
batch_normalization_28 (BatchN (None, 28, 28, 352) 1408 ['concatenate_12[0][0]']
ormalization)
activation_27 (Activation) (None, 28, 28, 352) 0 ['batch_normalization_28[0][0]']
conv2d_28 (Conv2D) (None, 28, 28, 128) 45056 ['activation_27[0][0]']
batch_normalization_29 (BatchN (None, 28, 28, 128) 512 ['conv2d_28[0][0]']
ormalization)
activation_28 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_29[0][0]']
conv2d_29 (Conv2D) (None, 28, 28, 32) 36864 ['activation_28[0][0]']
concatenate_13 (Concatenate) (None, 28, 28, 384) 0 ['concatenate_12[0][0]',
'conv2d_29[0][0]']
batch_normalization_30 (BatchN (None, 28, 28, 384) 1536 ['concatenate_13[0][0]']
ormalization)
activation_29 (Activation) (None, 28, 28, 384) 0 ['batch_normalization_30[0][0]']
conv2d_30 (Conv2D) (None, 28, 28, 128) 49152 ['activation_29[0][0]']
batch_normalization_31 (BatchN (None, 28, 28, 128) 512 ['conv2d_30[0][0]']
ormalization)
activation_30 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_31[0][0]']
conv2d_31 (Conv2D) (None, 28, 28, 32) 36864 ['activation_30[0][0]']
concatenate_14 (Concatenate) (None, 28, 28, 416) 0 ['concatenate_13[0][0]',
'conv2d_31[0][0]']
batch_normalization_32 (BatchN (None, 28, 28, 416) 1664 ['concatenate_14[0][0]']
ormalization)
activation_31 (Activation) (None, 28, 28, 416) 0 ['batch_normalization_32[0][0]']
conv2d_32 (Conv2D) (None, 28, 28, 128) 53248 ['activation_31[0][0]']
batch_normalization_33 (BatchN (None, 28, 28, 128) 512 ['conv2d_32[0][0]']
ormalization)
activation_32 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_33[0][0]']
conv2d_33 (Conv2D) (None, 28, 28, 32) 36864 ['activation_32[0][0]']
concatenate_15 (Concatenate) (None, 28, 28, 448) 0 ['concatenate_14[0][0]',
'conv2d_33[0][0]']
batch_normalization_34 (BatchN (None, 28, 28, 448) 1792 ['concatenate_15[0][0]']
ormalization)
activation_33 (Activation) (None, 28, 28, 448) 0 ['batch_normalization_34[0][0]']
conv2d_34 (Conv2D) (None, 28, 28, 128) 57344 ['activation_33[0][0]']
batch_normalization_35 (BatchN (None, 28, 28, 128) 512 ['conv2d_34[0][0]']
ormalization)
activation_34 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_35[0][0]']
conv2d_35 (Conv2D) (None, 28, 28, 32) 36864 ['activation_34[0][0]']
concatenate_16 (Concatenate) (None, 28, 28, 480) 0 ['concatenate_15[0][0]',
'conv2d_35[0][0]']
batch_normalization_36 (BatchN (None, 28, 28, 480) 1920 ['concatenate_16[0][0]']
ormalization)
activation_35 (Activation) (None, 28, 28, 480) 0 ['batch_normalization_36[0][0]']
conv2d_36 (Conv2D) (None, 28, 28, 128) 61440 ['activation_35[0][0]']
batch_normalization_37 (BatchN (None, 28, 28, 128) 512 ['conv2d_36[0][0]']
ormalization)
activation_36 (Activation) (None, 28, 28, 128) 0 ['batch_normalization_37[0][0]']
conv2d_37 (Conv2D) (None, 28, 28, 32) 36864 ['activation_36[0][0]']
concatenate_17 (Concatenate) (None, 28, 28, 512) 0 ['concatenate_16[0][0]',
'conv2d_37[0][0]']
batch_normalization_38 (BatchN (None, 28, 28, 512) 2048 ['concatenate_17[0][0]']
ormalization)
activation_37 (Activation) (None, 28, 28, 512) 0 ['batch_normalization_38[0][0]']
conv2d_38 (Conv2D) (None, 28, 28, 256) 131072 ['activation_37[0][0]']
average_pooling2d_1 (AveragePo (None, 14, 14, 256) 0 ['conv2d_38[0][0]']
oling2D)
batch_normalization_39 (BatchN (None, 14, 14, 256) 1024 ['average_pooling2d_1[0][0]']
ormalization)
activation_38 (Activation) (None, 14, 14, 256) 0 ['batch_normalization_39[0][0]']
conv2d_39 (Conv2D) (None, 14, 14, 128) 32768 ['activation_38[0][0]']
batch_normalization_40 (BatchN (None, 14, 14, 128) 512 ['conv2d_39[0][0]']
ormalization)
activation_39 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_40[0][0]']
conv2d_40 (Conv2D) (None, 14, 14, 32) 36864 ['activation_39[0][0]']
concatenate_18 (Concatenate) (None, 14, 14, 288) 0 ['average_pooling2d_1[0][0]',
'conv2d_40[0][0]']
batch_normalization_41 (BatchN (None, 14, 14, 288) 1152 ['concatenate_18[0][0]']
ormalization)
activation_40 (Activation) (None, 14, 14, 288) 0 ['batch_normalization_41[0][0]']
conv2d_41 (Conv2D) (None, 14, 14, 128) 36864 ['activation_40[0][0]']
batch_normalization_42 (BatchN (None, 14, 14, 128) 512 ['conv2d_41[0][0]']
ormalization)
activation_41 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_42[0][0]']
conv2d_42 (Conv2D) (None, 14, 14, 32) 36864 ['activation_41[0][0]']
concatenate_19 (Concatenate) (None, 14, 14, 320) 0 ['concatenate_18[0][0]',
'conv2d_42[0][0]']
batch_normalization_43 (BatchN (None, 14, 14, 320) 1280 ['concatenate_19[0][0]']
ormalization)
activation_42 (Activation) (None, 14, 14, 320) 0 ['batch_normalization_43[0][0]']
conv2d_43 (Conv2D) (None, 14, 14, 128) 40960 ['activation_42[0][0]']
batch_normalization_44 (BatchN (None, 14, 14, 128) 512 ['conv2d_43[0][0]']
ormalization)
activation_43 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_44[0][0]']
conv2d_44 (Conv2D) (None, 14, 14, 32) 36864 ['activation_43[0][0]']
concatenate_20 (Concatenate) (None, 14, 14, 352) 0 ['concatenate_19[0][0]',
'conv2d_44[0][0]']
batch_normalization_45 (BatchN (None, 14, 14, 352) 1408 ['concatenate_20[0][0]']
ormalization)
activation_44 (Activation) (None, 14, 14, 352) 0 ['batch_normalization_45[0][0]']
conv2d_45 (Conv2D) (None, 14, 14, 128) 45056 ['activation_44[0][0]']
batch_normalization_46 (BatchN (None, 14, 14, 128) 512 ['conv2d_45[0][0]']
ormalization)
activation_45 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_46[0][0]']
conv2d_46 (Conv2D) (None, 14, 14, 32) 36864 ['activation_45[0][0]']
concatenate_21 (Concatenate) (None, 14, 14, 384) 0 ['concatenate_20[0][0]',
'conv2d_46[0][0]']
batch_normalization_47 (BatchN (None, 14, 14, 384) 1536 ['concatenate_21[0][0]']
ormalization)
activation_46 (Activation) (None, 14, 14, 384) 0 ['batch_normalization_47[0][0]']
conv2d_47 (Conv2D) (None, 14, 14, 128) 49152 ['activation_46[0][0]']
batch_normalization_48 (BatchN (None, 14, 14, 128) 512 ['conv2d_47[0][0]']
ormalization)
activation_47 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_48[0][0]']
conv2d_48 (Conv2D) (None, 14, 14, 32) 36864 ['activation_47[0][0]']
concatenate_22 (Concatenate) (None, 14, 14, 416) 0 ['concatenate_21[0][0]',
'conv2d_48[0][0]']
batch_normalization_49 (BatchN (None, 14, 14, 416) 1664 ['concatenate_22[0][0]']
ormalization)
activation_48 (Activation) (None, 14, 14, 416) 0 ['batch_normalization_49[0][0]']
conv2d_49 (Conv2D) (None, 14, 14, 128) 53248 ['activation_48[0][0]']
batch_normalization_50 (BatchN (None, 14, 14, 128) 512 ['conv2d_49[0][0]']
ormalization)
activation_49 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_50[0][0]']
conv2d_50 (Conv2D) (None, 14, 14, 32) 36864 ['activation_49[0][0]']
concatenate_23 (Concatenate) (None, 14, 14, 448) 0 ['concatenate_22[0][0]',
'conv2d_50[0][0]']
batch_normalization_51 (BatchN (None, 14, 14, 448) 1792 ['concatenate_23[0][0]']
ormalization)
activation_50 (Activation) (None, 14, 14, 448) 0 ['batch_normalization_51[0][0]']
conv2d_51 (Conv2D) (None, 14, 14, 128) 57344 ['activation_50[0][0]']
batch_normalization_52 (BatchN (None, 14, 14, 128) 512 ['conv2d_51[0][0]']
ormalization)
activation_51 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_52[0][0]']
conv2d_52 (Conv2D) (None, 14, 14, 32) 36864 ['activation_51[0][0]']
concatenate_24 (Concatenate) (None, 14, 14, 480) 0 ['concatenate_23[0][0]',
'conv2d_52[0][0]']
batch_normalization_53 (BatchN (None, 14, 14, 480) 1920 ['concatenate_24[0][0]']
ormalization)
activation_52 (Activation) (None, 14, 14, 480) 0 ['batch_normalization_53[0][0]']
conv2d_53 (Conv2D) (None, 14, 14, 128) 61440 ['activation_52[0][0]']
batch_normalization_54 (BatchN (None, 14, 14, 128) 512 ['conv2d_53[0][0]']
ormalization)
activation_53 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_54[0][0]']
conv2d_54 (Conv2D) (None, 14, 14, 32) 36864 ['activation_53[0][0]']
concatenate_25 (Concatenate) (None, 14, 14, 512) 0 ['concatenate_24[0][0]',
'conv2d_54[0][0]']
batch_normalization_55 (BatchN (None, 14, 14, 512) 2048 ['concatenate_25[0][0]']
ormalization)
activation_54 (Activation) (None, 14, 14, 512) 0 ['batch_normalization_55[0][0]']
conv2d_55 (Conv2D) (None, 14, 14, 128) 65536 ['activation_54[0][0]']
batch_normalization_56 (BatchN (None, 14, 14, 128) 512 ['conv2d_55[0][0]']
ormalization)
activation_55 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_56[0][0]']
conv2d_56 (Conv2D) (None, 14, 14, 32) 36864 ['activation_55[0][0]']
concatenate_26 (Concatenate) (None, 14, 14, 544) 0 ['concatenate_25[0][0]',
'conv2d_56[0][0]']
batch_normalization_57 (BatchN (None, 14, 14, 544) 2176 ['concatenate_26[0][0]']
ormalization)
activation_56 (Activation) (None, 14, 14, 544) 0 ['batch_normalization_57[0][0]']
conv2d_57 (Conv2D) (None, 14, 14, 128) 69632 ['activation_56[0][0]']
batch_normalization_58 (BatchN (None, 14, 14, 128) 512 ['conv2d_57[0][0]']
ormalization)
activation_57 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_58[0][0]']
conv2d_58 (Conv2D) (None, 14, 14, 32) 36864 ['activation_57[0][0]']
concatenate_27 (Concatenate) (None, 14, 14, 576) 0 ['concatenate_26[0][0]',
'conv2d_58[0][0]']
batch_normalization_59 (BatchN (None, 14, 14, 576) 2304 ['concatenate_27[0][0]']
ormalization)
activation_58 (Activation) (None, 14, 14, 576) 0 ['batch_normalization_59[0][0]']
conv2d_59 (Conv2D) (None, 14, 14, 128) 73728 ['activation_58[0][0]']
batch_normalization_60 (BatchN (None, 14, 14, 128) 512 ['conv2d_59[0][0]']
ormalization)
activation_59 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_60[0][0]']
conv2d_60 (Conv2D) (None, 14, 14, 32) 36864 ['activation_59[0][0]']
concatenate_28 (Concatenate) (None, 14, 14, 608) 0 ['concatenate_27[0][0]',
'conv2d_60[0][0]']
batch_normalization_61 (BatchN (None, 14, 14, 608) 2432 ['concatenate_28[0][0]']
ormalization)
activation_60 (Activation) (None, 14, 14, 608) 0 ['batch_normalization_61[0][0]']
conv2d_61 (Conv2D) (None, 14, 14, 128) 77824 ['activation_60[0][0]']
batch_normalization_62 (BatchN (None, 14, 14, 128) 512 ['conv2d_61[0][0]']
ormalization)
activation_61 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_62[0][0]']
conv2d_62 (Conv2D) (None, 14, 14, 32) 36864 ['activation_61[0][0]']
concatenate_29 (Concatenate) (None, 14, 14, 640) 0 ['concatenate_28[0][0]',
'conv2d_62[0][0]']
batch_normalization_63 (BatchN (None, 14, 14, 640) 2560 ['concatenate_29[0][0]']
ormalization)
activation_62 (Activation) (None, 14, 14, 640) 0 ['batch_normalization_63[0][0]']
conv2d_63 (Conv2D) (None, 14, 14, 128) 81920 ['activation_62[0][0]']
batch_normalization_64 (BatchN (None, 14, 14, 128) 512 ['conv2d_63[0][0]']
ormalization)
activation_63 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_64[0][0]']
conv2d_64 (Conv2D) (None, 14, 14, 32) 36864 ['activation_63[0][0]']
concatenate_30 (Concatenate) (None, 14, 14, 672) 0 ['concatenate_29[0][0]',
'conv2d_64[0][0]']
batch_normalization_65 (BatchN (None, 14, 14, 672) 2688 ['concatenate_30[0][0]']
ormalization)
activation_64 (Activation) (None, 14, 14, 672) 0 ['batch_normalization_65[0][0]']
conv2d_65 (Conv2D) (None, 14, 14, 128) 86016 ['activation_64[0][0]']
batch_normalization_66 (BatchN (None, 14, 14, 128) 512 ['conv2d_65[0][0]']
ormalization)
activation_65 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_66[0][0]']
conv2d_66 (Conv2D) (None, 14, 14, 32) 36864 ['activation_65[0][0]']
concatenate_31 (Concatenate) (None, 14, 14, 704) 0 ['concatenate_30[0][0]',
'conv2d_66[0][0]']
batch_normalization_67 (BatchN (None, 14, 14, 704) 2816 ['concatenate_31[0][0]']
ormalization)
activation_66 (Activation) (None, 14, 14, 704) 0 ['batch_normalization_67[0][0]']
conv2d_67 (Conv2D) (None, 14, 14, 128) 90112 ['activation_66[0][0]']
batch_normalization_68 (BatchN (None, 14, 14, 128) 512 ['conv2d_67[0][0]']
ormalization)
activation_67 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_68[0][0]']
conv2d_68 (Conv2D) (None, 14, 14, 32) 36864 ['activation_67[0][0]']
concatenate_32 (Concatenate) (None, 14, 14, 736) 0 ['concatenate_31[0][0]',
'conv2d_68[0][0]']
batch_normalization_69 (BatchN (None, 14, 14, 736) 2944 ['concatenate_32[0][0]']
ormalization)
activation_68 (Activation) (None, 14, 14, 736) 0 ['batch_normalization_69[0][0]']
conv2d_69 (Conv2D) (None, 14, 14, 128) 94208 ['activation_68[0][0]']
batch_normalization_70 (BatchN (None, 14, 14, 128) 512 ['conv2d_69[0][0]']
ormalization)
activation_69 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_70[0][0]']
conv2d_70 (Conv2D) (None, 14, 14, 32) 36864 ['activation_69[0][0]']
concatenate_33 (Concatenate) (None, 14, 14, 768) 0 ['concatenate_32[0][0]',
'conv2d_70[0][0]']
batch_normalization_71 (BatchN (None, 14, 14, 768) 3072 ['concatenate_33[0][0]']
ormalization)
activation_70 (Activation) (None, 14, 14, 768) 0 ['batch_normalization_71[0][0]']
conv2d_71 (Conv2D) (None, 14, 14, 128) 98304 ['activation_70[0][0]']
batch_normalization_72 (BatchN (None, 14, 14, 128) 512 ['conv2d_71[0][0]']
ormalization)
activation_71 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_72[0][0]']
conv2d_72 (Conv2D) (None, 14, 14, 32) 36864 ['activation_71[0][0]']
concatenate_34 (Concatenate) (None, 14, 14, 800) 0 ['concatenate_33[0][0]',
'conv2d_72[0][0]']
batch_normalization_73 (BatchN (None, 14, 14, 800) 3200 ['concatenate_34[0][0]']
ormalization)
activation_72 (Activation) (None, 14, 14, 800) 0 ['batch_normalization_73[0][0]']
conv2d_73 (Conv2D) (None, 14, 14, 128) 102400 ['activation_72[0][0]']
batch_normalization_74 (BatchN (None, 14, 14, 128) 512 ['conv2d_73[0][0]']
ormalization)
activation_73 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_74[0][0]']
conv2d_74 (Conv2D) (None, 14, 14, 32) 36864 ['activation_73[0][0]']
concatenate_35 (Concatenate) (None, 14, 14, 832) 0 ['concatenate_34[0][0]',
'conv2d_74[0][0]']
batch_normalization_75 (BatchN (None, 14, 14, 832) 3328 ['concatenate_35[0][0]']
ormalization)
activation_74 (Activation) (None, 14, 14, 832) 0 ['batch_normalization_75[0][0]']
conv2d_75 (Conv2D) (None, 14, 14, 128) 106496 ['activation_74[0][0]']
batch_normalization_76 (BatchN (None, 14, 14, 128) 512 ['conv2d_75[0][0]']
ormalization)
activation_75 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_76[0][0]']
conv2d_76 (Conv2D) (None, 14, 14, 32) 36864 ['activation_75[0][0]']
concatenate_36 (Concatenate) (None, 14, 14, 864) 0 ['concatenate_35[0][0]',
'conv2d_76[0][0]']
batch_normalization_77 (BatchN (None, 14, 14, 864) 3456 ['concatenate_36[0][0]']
ormalization)
activation_76 (Activation) (None, 14, 14, 864) 0 ['batch_normalization_77[0][0]']
conv2d_77 (Conv2D) (None, 14, 14, 128) 110592 ['activation_76[0][0]']
batch_normalization_78 (BatchN (None, 14, 14, 128) 512 ['conv2d_77[0][0]']
ormalization)
activation_77 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_78[0][0]']
conv2d_78 (Conv2D) (None, 14, 14, 32) 36864 ['activation_77[0][0]']
concatenate_37 (Concatenate) (None, 14, 14, 896) 0 ['concatenate_36[0][0]',
'conv2d_78[0][0]']
batch_normalization_79 (BatchN (None, 14, 14, 896) 3584 ['concatenate_37[0][0]']
ormalization)
activation_78 (Activation) (None, 14, 14, 896) 0 ['batch_normalization_79[0][0]']
conv2d_79 (Conv2D) (None, 14, 14, 128) 114688 ['activation_78[0][0]']
batch_normalization_80 (BatchN (None, 14, 14, 128) 512 ['conv2d_79[0][0]']
ormalization)
activation_79 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_80[0][0]']
conv2d_80 (Conv2D) (None, 14, 14, 32) 36864 ['activation_79[0][0]']
concatenate_38 (Concatenate) (None, 14, 14, 928) 0 ['concatenate_37[0][0]',
'conv2d_80[0][0]']
batch_normalization_81 (BatchN (None, 14, 14, 928) 3712 ['concatenate_38[0][0]']
ormalization)
activation_80 (Activation) (None, 14, 14, 928) 0 ['batch_normalization_81[0][0]']
conv2d_81 (Conv2D) (None, 14, 14, 128) 118784 ['activation_80[0][0]']
batch_normalization_82 (BatchN (None, 14, 14, 128) 512 ['conv2d_81[0][0]']
ormalization)
activation_81 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_82[0][0]']
conv2d_82 (Conv2D) (None, 14, 14, 32) 36864 ['activation_81[0][0]']
concatenate_39 (Concatenate) (None, 14, 14, 960) 0 ['concatenate_38[0][0]',
'conv2d_82[0][0]']
batch_normalization_83 (BatchN (None, 14, 14, 960) 3840 ['concatenate_39[0][0]']
ormalization)
activation_82 (Activation) (None, 14, 14, 960) 0 ['batch_normalization_83[0][0]']
conv2d_83 (Conv2D) (None, 14, 14, 128) 122880 ['activation_82[0][0]']
batch_normalization_84 (BatchN (None, 14, 14, 128) 512 ['conv2d_83[0][0]']
ormalization)
activation_83 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_84[0][0]']
conv2d_84 (Conv2D) (None, 14, 14, 32) 36864 ['activation_83[0][0]']
concatenate_40 (Concatenate) (None, 14, 14, 992) 0 ['concatenate_39[0][0]',
'conv2d_84[0][0]']
batch_normalization_85 (BatchN (None, 14, 14, 992) 3968 ['concatenate_40[0][0]']
ormalization)
activation_84 (Activation) (None, 14, 14, 992) 0 ['batch_normalization_85[0][0]']
conv2d_85 (Conv2D) (None, 14, 14, 128) 126976 ['activation_84[0][0]']
batch_normalization_86 (BatchN (None, 14, 14, 128) 512 ['conv2d_85[0][0]']
ormalization)
activation_85 (Activation) (None, 14, 14, 128) 0 ['batch_normalization_86[0][0]']
conv2d_86 (Conv2D) (None, 14, 14, 32) 36864 ['activation_85[0][0]']
concatenate_41 (Concatenate) (None, 14, 14, 1024 0 ['concatenate_40[0][0]',
) 'conv2d_86[0][0]']
batch_normalization_87 (BatchN (None, 14, 14, 1024 4096 ['concatenate_41[0][0]']
ormalization) )
activation_86 (Activation) (None, 14, 14, 1024 0 ['batch_normalization_87[0][0]']
)
conv2d_87 (Conv2D) (None, 14, 14, 512) 524288 ['activation_86[0][0]']
average_pooling2d_2 (AveragePo (None, 7, 7, 512) 0 ['conv2d_87[0][0]']
oling2D)
batch_normalization_88 (BatchN (None, 7, 7, 512) 2048 ['average_pooling2d_2[0][0]']
ormalization)
activation_87 (Activation) (None, 7, 7, 512) 0 ['batch_normalization_88[0][0]']
conv2d_88 (Conv2D) (None, 7, 7, 128) 65536 ['activation_87[0][0]']
batch_normalization_89 (BatchN (None, 7, 7, 128) 512 ['conv2d_88[0][0]']
ormalization)
activation_88 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_89[0][0]']
conv2d_89 (Conv2D) (None, 7, 7, 32) 36864 ['activation_88[0][0]']
concatenate_42 (Concatenate) (None, 7, 7, 544) 0 ['average_pooling2d_2[0][0]',
'conv2d_89[0][0]']
batch_normalization_90 (BatchN (None, 7, 7, 544) 2176 ['concatenate_42[0][0]']
ormalization)
activation_89 (Activation) (None, 7, 7, 544) 0 ['batch_normalization_90[0][0]']
conv2d_90 (Conv2D) (None, 7, 7, 128) 69632 ['activation_89[0][0]']
batch_normalization_91 (BatchN (None, 7, 7, 128) 512 ['conv2d_90[0][0]']
ormalization)
activation_90 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_91[0][0]']
conv2d_91 (Conv2D) (None, 7, 7, 32) 36864 ['activation_90[0][0]']
concatenate_43 (Concatenate) (None, 7, 7, 576) 0 ['concatenate_42[0][0]',
'conv2d_91[0][0]']
batch_normalization_92 (BatchN (None, 7, 7, 576) 2304 ['concatenate_43[0][0]']
ormalization)
activation_91 (Activation) (None, 7, 7, 576) 0 ['batch_normalization_92[0][0]']
conv2d_92 (Conv2D) (None, 7, 7, 128) 73728 ['activation_91[0][0]']
batch_normalization_93 (BatchN (None, 7, 7, 128) 512 ['conv2d_92[0][0]']
ormalization)
activation_92 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_93[0][0]']
conv2d_93 (Conv2D) (None, 7, 7, 32) 36864 ['activation_92[0][0]']
concatenate_44 (Concatenate) (None, 7, 7, 608) 0 ['concatenate_43[0][0]',
'conv2d_93[0][0]']
batch_normalization_94 (BatchN (None, 7, 7, 608) 2432 ['concatenate_44[0][0]']
ormalization)
activation_93 (Activation) (None, 7, 7, 608) 0 ['batch_normalization_94[0][0]']
conv2d_94 (Conv2D) (None, 7, 7, 128) 77824 ['activation_93[0][0]']
batch_normalization_95 (BatchN (None, 7, 7, 128) 512 ['conv2d_94[0][0]']
ormalization)
activation_94 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_95[0][0]']
conv2d_95 (Conv2D) (None, 7, 7, 32) 36864 ['activation_94[0][0]']
concatenate_45 (Concatenate) (None, 7, 7, 640) 0 ['concatenate_44[0][0]',
'conv2d_95[0][0]']
batch_normalization_96 (BatchN (None, 7, 7, 640) 2560 ['concatenate_45[0][0]']
ormalization)
activation_95 (Activation) (None, 7, 7, 640) 0 ['batch_normalization_96[0][0]']
conv2d_96 (Conv2D) (None, 7, 7, 128) 81920 ['activation_95[0][0]']
batch_normalization_97 (BatchN (None, 7, 7, 128) 512 ['conv2d_96[0][0]']
ormalization)
activation_96 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_97[0][0]']
conv2d_97 (Conv2D) (None, 7, 7, 32) 36864 ['activation_96[0][0]']
concatenate_46 (Concatenate) (None, 7, 7, 672) 0 ['concatenate_45[0][0]',
'conv2d_97[0][0]']
batch_normalization_98 (BatchN (None, 7, 7, 672) 2688 ['concatenate_46[0][0]']
ormalization)
activation_97 (Activation) (None, 7, 7, 672) 0 ['batch_normalization_98[0][0]']
conv2d_98 (Conv2D) (None, 7, 7, 128) 86016 ['activation_97[0][0]']
batch_normalization_99 (BatchN (None, 7, 7, 128) 512 ['conv2d_98[0][0]']
ormalization)
activation_98 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_99[0][0]']
conv2d_99 (Conv2D) (None, 7, 7, 32) 36864 ['activation_98[0][0]']
concatenate_47 (Concatenate) (None, 7, 7, 704) 0 ['concatenate_46[0][0]',
'conv2d_99[0][0]']
batch_normalization_100 (Batch (None, 7, 7, 704) 2816 ['concatenate_47[0][0]']
Normalization)
activation_99 (Activation) (None, 7, 7, 704) 0 ['batch_normalization_100[0][0]']
conv2d_100 (Conv2D) (None, 7, 7, 128) 90112 ['activation_99[0][0]']
batch_normalization_101 (Batch (None, 7, 7, 128) 512 ['conv2d_100[0][0]']
Normalization)
activation_100 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_101[0][0]']
conv2d_101 (Conv2D) (None, 7, 7, 32) 36864 ['activation_100[0][0]']
concatenate_48 (Concatenate) (None, 7, 7, 736) 0 ['concatenate_47[0][0]',
'conv2d_101[0][0]']
batch_normalization_102 (Batch (None, 7, 7, 736) 2944 ['concatenate_48[0][0]']
Normalization)
activation_101 (Activation) (None, 7, 7, 736) 0 ['batch_normalization_102[0][0]']
conv2d_102 (Conv2D) (None, 7, 7, 128) 94208 ['activation_101[0][0]']
batch_normalization_103 (Batch (None, 7, 7, 128) 512 ['conv2d_102[0][0]']
Normalization)
activation_102 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_103[0][0]']
conv2d_103 (Conv2D) (None, 7, 7, 32) 36864 ['activation_102[0][0]']
concatenate_49 (Concatenate) (None, 7, 7, 768) 0 ['concatenate_48[0][0]',
'conv2d_103[0][0]']
batch_normalization_104 (Batch (None, 7, 7, 768) 3072 ['concatenate_49[0][0]']
Normalization)
activation_103 (Activation) (None, 7, 7, 768) 0 ['batch_normalization_104[0][0]']
conv2d_104 (Conv2D) (None, 7, 7, 128) 98304 ['activation_103[0][0]']
batch_normalization_105 (Batch (None, 7, 7, 128) 512 ['conv2d_104[0][0]']
Normalization)
activation_104 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_105[0][0]']
conv2d_105 (Conv2D) (None, 7, 7, 32) 36864 ['activation_104[0][0]']
concatenate_50 (Concatenate) (None, 7, 7, 800) 0 ['concatenate_49[0][0]',
'conv2d_105[0][0]']
batch_normalization_106 (Batch (None, 7, 7, 800) 3200 ['concatenate_50[0][0]']
Normalization)
activation_105 (Activation) (None, 7, 7, 800) 0 ['batch_normalization_106[0][0]']
conv2d_106 (Conv2D) (None, 7, 7, 128) 102400 ['activation_105[0][0]']
batch_normalization_107 (Batch (None, 7, 7, 128) 512 ['conv2d_106[0][0]']
Normalization)
activation_106 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_107[0][0]']
conv2d_107 (Conv2D) (None, 7, 7, 32) 36864 ['activation_106[0][0]']
concatenate_51 (Concatenate) (None, 7, 7, 832) 0 ['concatenate_50[0][0]',
'conv2d_107[0][0]']
batch_normalization_108 (Batch (None, 7, 7, 832) 3328 ['concatenate_51[0][0]']
Normalization)
activation_107 (Activation) (None, 7, 7, 832) 0 ['batch_normalization_108[0][0]']
conv2d_108 (Conv2D) (None, 7, 7, 128) 106496 ['activation_107[0][0]']
batch_normalization_109 (Batch (None, 7, 7, 128) 512 ['conv2d_108[0][0]']
Normalization)
activation_108 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_109[0][0]']
conv2d_109 (Conv2D) (None, 7, 7, 32) 36864 ['activation_108[0][0]']
concatenate_52 (Concatenate) (None, 7, 7, 864) 0 ['concatenate_51[0][0]',
'conv2d_109[0][0]']
batch_normalization_110 (Batch (None, 7, 7, 864) 3456 ['concatenate_52[0][0]']
Normalization)
activation_109 (Activation) (None, 7, 7, 864) 0 ['batch_normalization_110[0][0]']
conv2d_110 (Conv2D) (None, 7, 7, 128) 110592 ['activation_109[0][0]']
batch_normalization_111 (Batch (None, 7, 7, 128) 512 ['conv2d_110[0][0]']
Normalization)
activation_110 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_111[0][0]']
conv2d_111 (Conv2D) (None, 7, 7, 32) 36864 ['activation_110[0][0]']
concatenate_53 (Concatenate) (None, 7, 7, 896) 0 ['concatenate_52[0][0]',
'conv2d_111[0][0]']
batch_normalization_112 (Batch (None, 7, 7, 896) 3584 ['concatenate_53[0][0]']
Normalization)
activation_111 (Activation) (None, 7, 7, 896) 0 ['batch_normalization_112[0][0]']
conv2d_112 (Conv2D) (None, 7, 7, 128) 114688 ['activation_111[0][0]']
batch_normalization_113 (Batch (None, 7, 7, 128) 512 ['conv2d_112[0][0]']
Normalization)
activation_112 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_113[0][0]']
conv2d_113 (Conv2D) (None, 7, 7, 32) 36864 ['activation_112[0][0]']
concatenate_54 (Concatenate) (None, 7, 7, 928) 0 ['concatenate_53[0][0]',
'conv2d_113[0][0]']
batch_normalization_114 (Batch (None, 7, 7, 928) 3712 ['concatenate_54[0][0]']
Normalization)
activation_113 (Activation) (None, 7, 7, 928) 0 ['batch_normalization_114[0][0]']
conv2d_114 (Conv2D) (None, 7, 7, 128) 118784 ['activation_113[0][0]']
batch_normalization_115 (Batch (None, 7, 7, 128) 512 ['conv2d_114[0][0]']
Normalization)
activation_114 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_115[0][0]']
conv2d_115 (Conv2D) (None, 7, 7, 32) 36864 ['activation_114[0][0]']
concatenate_55 (Concatenate) (None, 7, 7, 960) 0 ['concatenate_54[0][0]',
'conv2d_115[0][0]']
batch_normalization_116 (Batch (None, 7, 7, 960) 3840 ['concatenate_55[0][0]']
Normalization)
activation_115 (Activation) (None, 7, 7, 960) 0 ['batch_normalization_116[0][0]']
conv2d_116 (Conv2D) (None, 7, 7, 128) 122880 ['activation_115[0][0]']
batch_normalization_117 (Batch (None, 7, 7, 128) 512 ['conv2d_116[0][0]']
Normalization)
activation_116 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_117[0][0]']
conv2d_117 (Conv2D) (None, 7, 7, 32) 36864 ['activation_116[0][0]']
concatenate_56 (Concatenate) (None, 7, 7, 992) 0 ['concatenate_55[0][0]',
'conv2d_117[0][0]']
batch_normalization_118 (Batch (None, 7, 7, 992) 3968 ['concatenate_56[0][0]']
Normalization)
activation_117 (Activation) (None, 7, 7, 992) 0 ['batch_normalization_118[0][0]']
conv2d_118 (Conv2D) (None, 7, 7, 128) 126976 ['activation_117[0][0]']
batch_normalization_119 (Batch (None, 7, 7, 128) 512 ['conv2d_118[0][0]']
Normalization)
activation_118 (Activation) (None, 7, 7, 128) 0 ['batch_normalization_119[0][0]']
conv2d_119 (Conv2D) (None, 7, 7, 32) 36864 ['activation_118[0][0]']
concatenate_57 (Concatenate) (None, 7, 7, 1024) 0 ['concatenate_56[0][0]',
'conv2d_119[0][0]']
global_average_pooling2d (Glob (None, 1024) 0 ['concatenate_57[0][0]']
alAveragePooling2D)
dense (Dense) (None, 16) 16400 ['global_average_pooling2d[0][0]'
]
activation_119 (Activation) (None, 16) 0 ['dense[0][0]']
dense_1 (Dense) (None, 1024) 17408 ['activation_119[0][0]']
activation_120 (Activation) (None, 1024) 0 ['dense_1[0][0]']
reshape (Reshape) (None, 1, 1, 1024) 0 ['activation_120[0][0]']
tf.math.multiply (TFOpLambda) (None, 7, 7, 1024) 0 ['concatenate_57[0][0]',
'reshape[0][0]']
batch_normalization_120 (Batch (None, 7, 7, 1024) 4096 ['tf.math.multiply[0][0]']
Normalization)
activation_121 (Activation) (None, 7, 7, 1024) 0 ['batch_normalization_120[0][0]']
global_average_pooling2d_1 (Gl (None, 1024) 0 ['activation_121[0][0]']
obalAveragePooling2D)
dense_2 (Dense) (None, 1000) 1025000 ['global_average_pooling2d_1[0][0
]']
==================================================================================================
Total params: 8,096,312
Trainable params: 8,012,664
Non-trainable params: 83,648
__________________________________________________________________________________________________
3.10. Compile the model
python
# Set the initial learning rate
initial_learning_rate = 1e-4
opt = tf.keras.optimizers.Adam(learning_rate=initial_learning_rate)
model.compile(optimizer=opt,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
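With `sparse_categorical_crossentropy`, each label is an integer class index rather than a one-hot vector. For one sample the loss is $-\log p_y$, the negative log of the predicted probability of the true class. The NumPy sketch below illustrates this. It is not Keras' implementation: `sparse_ce` and `one_hot_ce` are hypothetical helper names, and the probabilities are made up. The sketch checks that the sparse form gives the same value as cross-entropy on one-hot labels.

```python
import numpy as np

def sparse_ce(y_true, y_prob, eps=1e-7):
    """Mean of -log(p[true class]); y_true holds integer class indices."""
    y_prob = np.clip(y_prob, eps, 1.0)
    picked = y_prob[np.arange(len(y_true)), y_true]  # probability of the true class per sample
    return float(np.mean(-np.log(picked)))

def one_hot_ce(y_onehot, y_prob, eps=1e-7):
    """The same loss computed from one-hot encoded labels."""
    y_prob = np.clip(y_prob, eps, 1.0)
    return float(np.mean(-np.sum(y_onehot * np.log(y_prob), axis=1)))

# Two samples, three classes (made-up softmax outputs)
y_prob = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6]])
y_true = np.array([0, 2])        # integer labels
y_onehot = np.eye(3)[y_true]     # the same labels, one-hot encoded

print(sparse_ce(y_true, y_prob), one_hot_ce(y_onehot, y_prob))  # both ≈ 0.4338
```

The loss only gathers the probability at the label index. That is also why training runs here even though `dense_2` in the summary above has 1000 output units: the integer labels just need to be valid indices into that output.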
3.11. Train the model
python
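'''Train the model'''
epochs = 20
# Save weights only when val_accuracy improves; this callback prints the
# "val_accuracy improved ... saving model to best_model.h5" lines in the log below
checkpointer = tf.keras.callbacks.ModelCheckpoint(
    'best_model.h5', monitor='val_accuracy', verbose=1,
    save_best_only=True, save_weights_only=True)
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[checkpointer]
)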
The training log is as follows:
Epoch 1/20
54/54 [==============================] - ETA: 0s - loss: 4.1581 - accuracy: 0.5618
Epoch 1: val_accuracy improved from -inf to 0.33956, saving model to best_model.h5
54/54 [==============================] - 24s 236ms/step - loss: 4.1581 - accuracy: 0.5618 - val_loss: 6.6434 - val_accuracy: 0.3396
Epoch 2/20
54/54 [==============================] - ETA: 0s - loss: 1.3906 - accuracy: 0.7042
Epoch 2: val_accuracy improved from 0.33956 to 0.63944, saving model to best_model.h5
54/54 [==============================] - 12s 217ms/step - loss: 1.3906 - accuracy: 0.7042 - val_loss: 5.0245 - val_accuracy: 0.6394
Epoch 3/20
54/54 [==============================] - ETA: 0s - loss: 0.7259 - accuracy: 0.7392
Epoch 3: val_accuracy did not improve from 0.63944
54/54 [==============================] - 11s 211ms/step - loss: 0.7259 - accuracy: 0.7392 - val_loss: 2.8552 - val_accuracy: 0.5729
Epoch 4/20
54/54 [==============================] - ETA: 0s - loss: 0.5109 - accuracy: 0.7940
Epoch 4: val_accuracy did not improve from 0.63944
54/54 [==============================] - 11s 211ms/step - loss: 0.5109 - accuracy: 0.7940 - val_loss: 1.6427 - val_accuracy: 0.6336
Epoch 5/20
54/54 [==============================] - ETA: 0s - loss: 0.3891 - accuracy: 0.8431
Epoch 5: val_accuracy improved from 0.63944 to 0.69137, saving model to best_model.h5
54/54 [==============================] - 12s 218ms/step - loss: 0.3891 - accuracy: 0.8431 - val_loss: 0.9914 - val_accuracy: 0.6914
Epoch 6/20
54/54 [==============================] - ETA: 0s - loss: 0.3434 - accuracy: 0.8635
Epoch 6: val_accuracy did not improve from 0.69137
54/54 [==============================] - 11s 213ms/step - loss: 0.3434 - accuracy: 0.8635 - val_loss: 0.7353 - val_accuracy: 0.6826
Epoch 7/20
54/54 [==============================] - ETA: 0s - loss: 0.2720 - accuracy: 0.8950
Epoch 7: val_accuracy did not improve from 0.69137
54/54 [==============================] - 11s 213ms/step - loss: 0.2720 - accuracy: 0.8950 - val_loss: 0.9839 - val_accuracy: 0.6120
Epoch 8/20
54/54 [==============================] - ETA: 0s - loss: 0.2083 - accuracy: 0.9277
Epoch 8: val_accuracy improved from 0.69137 to 0.74504, saving model to best_model.h5
54/54 [==============================] - 12s 218ms/step - loss: 0.2083 - accuracy: 0.9277 - val_loss: 0.8169 - val_accuracy: 0.7450
Epoch 9/20
54/54 [==============================] - ETA: 0s - loss: 0.2032 - accuracy: 0.9247
Epoch 9: val_accuracy improved from 0.74504 to 0.80980, saving model to best_model.h5
54/54 [==============================] - 12s 217ms/step - loss: 0.2032 - accuracy: 0.9247 - val_loss: 0.4398 - val_accuracy: 0.8098
Epoch 10/20
54/54 [==============================] - ETA: 0s - loss: 0.1558 - accuracy: 0.9411
Epoch 10: val_accuracy did not improve from 0.80980
54/54 [==============================] - 11s 212ms/step - loss: 0.1558 - accuracy: 0.9411 - val_loss: 0.6900 - val_accuracy: 0.7853
Epoch 11/20
54/54 [==============================] - ETA: 0s - loss: 0.1223 - accuracy: 0.9568
Epoch 11: val_accuracy did not improve from 0.80980
54/54 [==============================] - 11s 213ms/step - loss: 0.1223 - accuracy: 0.9568 - val_loss: 0.7019 - val_accuracy: 0.7433
Epoch 12/20
54/54 [==============================] - ETA: 0s - loss: 0.0909 - accuracy: 0.9673
Epoch 12: val_accuracy improved from 0.80980 to 0.82205, saving model to best_model.h5
54/54 [==============================] - 12s 218ms/step - loss: 0.0909 - accuracy: 0.9673 - val_loss: 0.5862 - val_accuracy: 0.8221
Epoch 13/20
54/54 [==============================] - ETA: 0s - loss: 0.1773 - accuracy: 0.9288
Epoch 13: val_accuracy did not improve from 0.82205
54/54 [==============================] - 11s 212ms/step - loss: 0.1773 - accuracy: 0.9288 - val_loss: 0.7781 - val_accuracy: 0.7905
Epoch 14/20
54/54 [==============================] - ETA: 0s - loss: 0.1375 - accuracy: 0.9481
Epoch 14: val_accuracy improved from 0.82205 to 0.85998, saving model to best_model.h5
54/54 [==============================] - 12s 218ms/step - loss: 0.1375 - accuracy: 0.9481 - val_loss: 0.3867 - val_accuracy: 0.8600
Epoch 15/20
54/54 [==============================] - ETA: 0s - loss: 0.0727 - accuracy: 0.9755
Epoch 15: val_accuracy improved from 0.85998 to 0.91482, saving model to best_model.h5
54/54 [==============================] - 12s 224ms/step - loss: 0.0727 - accuracy: 0.9755 - val_loss: 0.2605 - val_accuracy: 0.9148
Epoch 16/20
54/54 [==============================] - ETA: 0s - loss: 0.0412 - accuracy: 0.9912
Epoch 16: val_accuracy improved from 0.91482 to 0.91890, saving model to best_model.h5
54/54 [==============================] - 12s 220ms/step - loss: 0.0412 - accuracy: 0.9912 - val_loss: 0.1958 - val_accuracy: 0.9189
Epoch 17/20
54/54 [==============================] - ETA: 0s - loss: 0.0466 - accuracy: 0.9848
Epoch 17: val_accuracy did not improve from 0.91890
54/54 [==============================] - 11s 213ms/step - loss: 0.0466 - accuracy: 0.9848 - val_loss: 0.2973 - val_accuracy: 0.8991
Epoch 18/20
54/54 [==============================] - ETA: 0s - loss: 0.0786 - accuracy: 0.9697
Epoch 18: val_accuracy did not improve from 0.91890
54/54 [==============================] - 11s 213ms/step - loss: 0.0786 - accuracy: 0.9697 - val_loss: 1.5921 - val_accuracy: 0.7170
Epoch 19/20
54/54 [==============================] - ETA: 0s - loss: 0.0757 - accuracy: 0.9778
Epoch 19: val_accuracy improved from 0.91890 to 0.92065, saving model to best_model.h5
54/54 [==============================] - 12s 218ms/step - loss: 0.0757 - accuracy: 0.9778 - val_loss: 0.2539 - val_accuracy: 0.9207
Epoch 20/20
54/54 [==============================] - ETA: 0s - loss: 0.1000 - accuracy: 0.9656
Epoch 20: val_accuracy did not improve from 0.92065
54/54 [==============================] - 11s 213ms/step - loss: 0.1000 - accuracy: 0.9656 - val_loss: 1.0522 - val_accuracy: 0.6914
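The `val_accuracy improved ... saving model to best_model.h5` lines are what a Keras `ModelCheckpoint` callback with `save_best_only=True` prints. That callback overwrites the file only when the monitored value strictly improves. The short sketch below replays that rule on the `val_accuracy` values copied from the log above, rounded to four decimals, to show which epochs were saved. The function `replay_save_best_only` is a hypothetical illustration of the rule, not Keras' own code.

```python
# val_accuracy per epoch, copied (4 decimals) from the training log above
val_accuracy = [0.3396, 0.6394, 0.5729, 0.6336, 0.6914, 0.6826, 0.6120,
                0.7450, 0.8098, 0.7853, 0.7433, 0.8221, 0.7905, 0.8600,
                0.9148, 0.9189, 0.8991, 0.7170, 0.9207, 0.6914]

def replay_save_best_only(values):
    """Return the 1-based epochs at which a save_best_only checkpoint
    monitoring a metric in 'max' mode would write the file, plus the best value."""
    best = float("-inf")
    saved = []
    for epoch, value in enumerate(values, start=1):
        if value > best:  # strictly better than the best seen so far
            best = value
            saved.append(epoch)
    return saved, best

saved_epochs, best_value = replay_save_best_only(val_accuracy)
print("saved at epochs:", saved_epochs)  # [1, 2, 5, 8, 9, 12, 14, 15, 16, 19]
print("best val_accuracy:", best_value, "| last epoch:", val_accuracy[-1])  # 0.9207 | 0.6914
```

So `best_model.h5` holds the epoch-19 weights (val_accuracy ≈ 0.92). The model left in memory after `fit` returns is the epoch-20 one (≈ 0.69). Given how much validation accuracy swings between epochs, the last-epoch numbers alone understate the best checkpoint.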
3.12. Model evaluation
python
'''Model evaluation'''
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(len(loss))
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
3.13. Image prediction
python
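'''Predict on specified images'''
import numpy as np

# Load the best weights saved during training (best_model.h5) and predict with them
model.load_weights('best_model.h5')
plt.figure(figsize=(10, 5))  # figure width 10, height 5
plt.suptitle("Prediction results", fontsize=10)
for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)
        # Show the image
        plt.imshow(images[i].numpy().astype("uint8"))
        # Add a batch dimension: (H, W, C) -> (1, H, W, C)
        img_array = tf.expand_dims(images[i], 0)
        # Predict the class of the image
        predictions = model.predict(img_array, verbose=0)
        plt.title(class_names[np.argmax(predictions)], fontsize=10)
        plt.axis("off")
plt.show()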
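The loop above calls `model.predict` once per image. Because `model.predict` also accepts a whole batch, one alternative is to predict the batch once and take `argmax` along the class axis. The sketch below shows only that decoding step in NumPy, under these assumptions:

- `decode_batch` is a hypothetical helper.
- The class names and scores are placeholders.
- `probs` stands in for the `(batch, num_classes)` array that `model.predict(images)` would return.

```python
import numpy as np

def decode_batch(probs, class_names):
    """Map a (batch, num_classes) score array to predicted class names.

    argmax over axis=1 picks the highest-scoring class for each image,
    matching what np.argmax(predictions) does for a single image.
    """
    indices = np.argmax(probs, axis=1)
    return [class_names[i] for i in indices]

# Hypothetical class names and scores, for illustration only
class_names = ["class_a", "class_b"]
probs = np.array([[0.9, 0.1],
                  [0.3, 0.7],
                  [0.4, 0.6]])
print(decode_batch(probs, class_names))  # ['class_a', 'class_b', 'class_b']
```

In the real pipeline this would be used roughly as `decode_batch(model.predict(images, verbose=0), class_names)`, which runs one forward pass for the whole batch.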
4 Key concepts explained
4.1 SE-Net explained
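SE-Net is the winning model of ImageNet 2017 (the final ImageNet competition), released by the WMW team.

- **Cost:** the SE block adds little complexity, few parameters and little computation.
- **Portability:** the idea behind SENet is simple and extends easily to existing architectures such as Inception and ResNet.
- **Focus:** much prior work improved performance along the spatial dimension, e.g. Inception; SENet instead focuses on the relationships between feature channels.

Its strategy is to learn the importance of each feature channel automatically. That importance is then used to strengthen useful features and suppress features that contribute little to the current task. This is known as "feature recalibration".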
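The SE block is flexible because it can be applied directly to existing network structures. Taking Inception and ResNet as examples, we only need to add an SE block after an Inception module or a residual module, as shown in the figure below: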
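The SE block itself is shown in the figure above. Given an input $x$ with $c_1$ feature channels, a series of transformations $F_{tr}$ (such as convolutions) produces a feature with $c_2$ channels. Unlike a conventional convolutional neural network, the SE block then recalibrates this feature through the following three operations.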
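(1) Squeeze: compress the features along the spatial dimensions, encoding the whole spatial extent of each channel into a single global feature, $z_c = \frac{1}{H \times W}\sum_{i=1}^{H}\sum_{j=1}^{W} u_c(i, j)$, where $u_c$ is the $c$-th channel of the $H \times W$ feature map. This real number has, in a sense, a global receptive field, and the number of output channels equals the number of input channels. For example, a feature map of shape (1, 32, 32, 10) is compressed to (1, 1, 1, 10). This operation is usually implemented with global average pooling.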
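(2) Excitation: once the global descriptor $z$ is obtained, the Excitation step captures the relationships between feature channels. It uses a gating mechanism similar to the gates in recurrent neural networks: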
$$s = F_{ex}(z, W) = \sigma(g(z, W)) = \sigma\left(W_2\,\mathrm{ReLU}(W_1 z)\right)$$
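This uses a bottleneck of two fully connected layers, narrow in the middle and wide at both ends:

- The first FC layer ($W_1$) reduces the dimension by a reduction ratio $r$ and is followed by ReLU.
- The second FC layer ($W_2$) restores the original dimension.

The purpose of Excitation is to produce a weight for each feature channel, i.e. to learn one activation value per channel. The sigmoid $\sigma$ keeps each value between 0 and 1.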
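(3) Scale: the Excitation output weights are treated as the importance of each feature channel after feature selection. They are multiplied channel by channel onto the earlier features ($\tilde{x}_c = s_c \cdot u_c$), which recalibrates the original features along the channel dimension. This makes the model more discriminative with respect to each channel's features, similar to an attention mechanism.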
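The three steps can be sketched as a NumPy forward pass over a batch of channels-last feature maps, the `(N, H, W, C)` layout used in the TensorFlow summary above. This is an illustration of the computation under stated assumptions, not the code used to build the model:

- The weights are random stand-ins for trained parameters, and biases are omitted.
- The reduction ratio `r` is a free parameter.
- Because the code multiplies row vectors (`z @ w1`), `w1` and `w2` are the transposes of $W_1$ and $W_2$ in the Excitation formula.

```python
import numpy as np

def se_block(u, w1, w2):
    """Squeeze-and-Excitation forward pass (biases omitted).

    u:  feature maps, shape (N, H, W, C)
    w1: reduction weights, shape (C, C // r)
    w2: expansion weights, shape (C // r, C)
    """
    # (1) Squeeze: global average pooling over H and W -> (N, C)
    z = u.mean(axis=(1, 2))
    # (2) Excitation: FC -> ReLU -> FC -> sigmoid, weights in (0, 1), shape (N, C)
    hidden = np.maximum(z @ w1, 0.0)
    s = 1.0 / (1.0 + np.exp(-(hidden @ w2)))
    # (3) Scale: reshape to (N, 1, 1, C) and multiply channel-wise via broadcasting
    return u * s[:, None, None, :], s

rng = np.random.default_rng(0)
N, H, W, C, r = 2, 7, 7, 64, 16
u = rng.standard_normal((N, H, W, C))
w1 = rng.standard_normal((C, C // r)) * 0.1
w2 = rng.standard_normal((C // r, C)) * 0.1

out, s = se_block(u, w1, w2)
print(out.shape, s.shape)  # (2, 7, 7, 64) (2, 64)
```

The output keeps the input shape, so the block can be dropped between existing layers. Every value in a channel is multiplied by the same weight `s[n, c]`.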
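The process can be summarized briefly as follows.
As the architecture diagram shows, an SE step is simply added after the residual branch:
1. First, apply global pooling to obtain a global view.
2. Apply two fully connected (FC) layers: the first reduces the dimension and the second increases it back to the original dimension
(one down and one up, so the dimension is unchanged and the block can be inserted after almost any stage).
3. Apply a sigmoid activation so that the weights lie between 0 and 1.
4. Finally, the Scale operation multiplies these weights back onto the original residual features (Residual * Weight).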
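The two FC layers only see a length-C vector, so the SE block's parameter cost depends on C and the reduction ratio r, not on the spatial size. Including biases, the first layer has C·(C/r) + C/r parameters and the second has (C/r)·C + C. The small calculation below checks this against the `dense` and `dense_1` rows of the TensorFlow summary above, where C = 1024 and the hidden size is 16 (i.e. r = 64). `se_param_count` is a hypothetical helper written for this check.

```python
def se_param_count(channels, hidden):
    """Parameters of the two Dense layers in an SE block, including biases."""
    reduce_fc = channels * hidden + hidden    # Dense(hidden): C -> C/r
    expand_fc = hidden * channels + channels  # Dense(channels): C/r -> C
    return reduce_fc, expand_fc

fc1, fc2 = se_param_count(channels=1024, hidden=16)
print(fc1, fc2, fc1 + fc2)  # 16400 17408 33808
```

Together that is 33,808 parameters, about 0.42% of the model's 8,096,312 total. This is consistent with the claim that the SE block adds few parameters.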
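SE blocks are easy to embed in other networks. To verify their effect, the authors added SE blocks to other popular networks such as ResNet and Inception and tested them on ImageNet, as shown in the table below: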
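First, consider the effect of network depth on SE. The table above shows results for ResNet-50, ResNet-101, ResNet-152 and their SE-embedded versions:

- **Original:** the results reported by the original ResNet authors.
- **Our re-implementation:** for a fair comparison, the experiments were re-run to obtain these results.
- **SE-module:** results with SE blocks embedded, trained with the same settings as the Our re-implementation column.

The red numbers in parentheses are the accuracy improvements relative to Our re-implementation.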
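The table shows that SE-ResNets consistently outperform their counterparts without SE at every depth. This indicates that SE blocks improve performance regardless of network depth. Notably, SE-ResNet-50 comes close to the accuracy of ResNet-101, and SE-ResNet-101 even surpasses the deeper ResNet-152.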
The figure above shows the ImageNet training curves of ResNet-50 and ResNet-152 and their SE-embedded counterparts. The networks with SE blocks clearly converge to lower error rates.
5 Summary
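An ordinary convolution fuses features only within a local region, so its receptive field is limited. Enlarging it by designing more channel features inevitably increases the amount of computation considerably. SENet's innovation is to focus on the relationships between channels, so that the model can automatically learn how important each channel's features are.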
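In short, the three steps work as follows:

1. **Squeeze:** global average pooling condenses each channel's entire feature map into a single value. Each channel is left with one feature, so the result has size c.
2. **Excitation:** dimension reduction, ReLU, dimension restoration and sigmoid model the interdependencies between channels and compute how important each channel is. The size is still c, and each element is the importance of the corresponding channel, closer to 1 the more important it is.
3. **Scale:** the feature map produced by the earlier layers is rescaled by the Excitation output (size c) reshaped into a 1×1×c tensor. Important channels are largely preserved, while less useful ones are suppressed.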