学习基于pytorch的VGG图像分类 day2

注:本系列博客在于汇总CSDN的精华帖,类似自用笔记,不做学习交流,方便以后的复习回顾,博文中的引用都注明出处,并点赞收藏原博主.

目录

VGG网络搭建(模型文件)

1.字典文件配置

2.提取特征网络结构

[3. VGG类的定义](#3. VGG类的定义)

4.VGG网络实例化


VGG网络搭建(模型文件)

1.字典文件配置

python 复制代码
#字典文件,对应各个配置,数字对应卷积核的个数,'M'对应最大液化(即maxpool)
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

2.提取特征网络结构

python 复制代码
#提取特征网络结构
def make_features(cfg: list): #传入对应的列表
    layers = [] #定义一个空列表,存放每层的结果
    in_channels = 3 #输入为RGB彩色图片,输入通道为3
    for v in cfg: #通过for循环遍历列表
        if v == "M":                                                    #maxpool size = 2,stride = 2
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)] #创建最大池化下载量程,池化核为2,布局也为2
        else:                                                           #conv padding = 1,stride = 1
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) #创建卷积操作(输入特征矩阵深度,输出特征矩阵深度(卷积核个数),卷积核为3,填充为1,stride默认为1(不用写))
            layers += [conv2d, nn.ReLU(True)] #使用ReLU激活函数
            in_channels = v #输出深度改变成v
    return nn.Sequential(*layers) #通过Sequential函数将列表以非关键字参数的形式传入(*代表非关键字传入)

3. VGG类的定义

python 复制代码
class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False): #(通过make_features生成的提取特征网络结构,分类的类别个数,是否对网络权重初始化)
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential( #生成分类网络
            nn.Linear(512*7*7, 4096), #全连接层上下的节点个数
            nn.ReLU(True),  #ReLU函数激活
            nn.Dropout(p=0.5), #Dropout函数减少过拟合,以50%的比例随机失活神经元
            nn.Linear(4096, 4096), #第一层和第二层
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes) #第二层和第三层,总计3层全连接层,最后连接到输出层,输出num_classes的所需个数
        )
        if init_weights: #初始化权重函数
            self._initialize_weights()

    def forward(self, x): #正向传播 x就是输入的图像数据 
        # N x 3 x 224 x 224
        x = self.features(x) #用features提取特征网络结构
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1) #对输出进行一个展平处理,(start_dim定义从哪个维度开始展平处理)
        # N x 512*7*7
        x = self.classifier(x) #输入到分类网络结构
        return x

    def _initialize_weights(self):
        for m in self.modules(): #遍历网络的每一个子模块
            if isinstance(m, nn.Conv2d): #遍历到卷积层
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight) #使用xavier函数初始化,初始化卷积核的权重
                if m.bias is not None: #卷积核采用偏置
                    nn.init.constant_(m.bias, 0) #将偏执初始化为0
            elif isinstance(m, nn.Linear): #遍历到全连接层,下面同理
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

4.VGG网络实例化

python 复制代码
#实例化VGG网络结构
def vgg(model_name="vgg16", **kwargs):
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg), **kwargs) #通过VGG这个类实现实例化网络,(**可变长度的字典变量)
    return model

内容参考来源:

​​​​​​使用pytorch搭建VGG网络_哔哩哔哩_bilibili

相关推荐
Ronin-Lotus8 分钟前
嵌入式硬件篇---ADC模拟-数字转换
笔记·stm32·单片机·嵌入式硬件·学习·低代码·模块测试
编程小猹15 分钟前
学习golang语言时遇到的难点语法
学习·golang·xcode
promising-w32 分钟前
单片机基础模块学习——数码管
单片机·嵌入式硬件·学习
不爱学英文的码字机器1 小时前
我的2024:创作历程与成长总结
学习·程序人生·交友
AI街潜水的八角1 小时前
工业缺陷检测实战——基于深度学习YOLOv10神经网络PCB缺陷检测系统
pytorch·深度学习·yolo
Sean_summer1 小时前
1.21学习
学习
东京老树根2 小时前
Excel 技巧17 - 如何计算倒计时,并添加该倒计时的数据条(★)
笔记·学习·excel
不想写代码的我2 小时前
梁山派入门指南3——串口使用详解,包括串口发送数据、重定向、中断接收不定长数据、DMA+串口接收不定长数据,以及对应的bsp文件和使用示例
单片机·学习·gd32·梁山派
虾球xz2 小时前
游戏引擎学习第84天
学习·游戏引擎
m0_748240543 小时前
AutoSar架构学习笔记
笔记·学习·架构