Pytorch代码入门学习之分类任务(一):搭建网络框架

目录

一、网络框架介绍

二、导包

三、定义卷积神经网络

[3.1 代码展示](#3.1 代码展示)

[3.2 定义网络的目的](#3.2 定义网络的目的)

[3.3 Pytorch搭建网络](#3.3 Pytorch搭建网络)

四、测试网络效果


一、网络框架介绍

网络理解:

将32*32大小的灰度图片(下述的代码中输入为32*32大小的RGB彩色图片),输入到网络中;经过第一次卷积C1,变成了6通道、28*28大小的一个特征向量;通过一次下采样S2,变成了6通道、14*14大小的一个特征向量,其宽高相当于折损了一般;经过第二次卷积C3,变成了16通道、10*10大小的一个特征向量;通过第二次下采样S4,变成了16通道、5*5大小的一个特征向量;最后三层全连接输出。

①Convolutious(卷积):涉及到输入、输出与很多参数的设置,需要初始化。

②Subsampling(下采样):该网络中使用的是最大池化下采样的方法,最大池化下采样的和维2*2大小。

**最大池化:**Max Pooling,取窗口内的最大值作为输出。

③Full Connection(全连接):需要初始化。

二、导包

python 复制代码
import torch  # torch基础库
import torch.nn as nn  # torch神经网络库
import torch.nn.functional as F

三、定义卷积神经网络

3.1 代码展示

python 复制代码
class Net(nn.Module):
    # 初始化
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    # 前向传播
    def forward(self,x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x,(2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(-1,x.size()[1:].numel())
        x = F.relu(self.fc1(x))  # 进入全连接层需要进行激活函数
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # 最后一层为输出层,要输出结果,不需要进行激活
        return x

3.2 定义网络的目的

希望网络有科学系参数,通过输入数据的训练让相关参数不断更新、梯度下降到一个合适的值,之后输入新的图片,可以进行分类或者预测。

3.3 Pytorch搭建网络

Pytorch搭建网络通常会采用类进行管理,可取名为Net(该名字可以更换),通常需要继承nn.Model类(相当于在Net中将Model定义好的方法直接进行使用)。搭建网络通常包括两个函数:

**①初始化函数(含有默认参数):**实例化这个类的时候会自动执行的一部分,这里面放网络需要初始化的内容。

def init(self)

**A. super(Net,self).init():**在该函数中通常需要进行多继承操作,相当于把Model类里面继承的类以及全部的类的方法都继承下来,供Net去使用;

**B. nn.Conv2d(3,6,5):**2d卷积核的函数,只涉及三个参数,其余参数使用默认值;第一个参数为输入的通道数,第二个参数为输出特征向量的通道数,第三个参数为卷积核大小(使用output公式进行计算 W-F+1=28,W=32,F=5 );

:其中W是指宽高,F是指所求的ColorSize的大小,P是指Padding---像图片外面补边,让它去遍历,默认为0;S是指步长,卷积核遍历图片的步长,默认为1;

**C. nn.Linear(16*5*5,120):**全连接层的初始化,涉及两个参数(输入特征的维数大小和输出特征的维数大小),全连接层需要对特征做一个拉平,将每一个特征拉平,将上一个特征向量拉为一条直线,送给全连接层;

**②前向传播函数:**需实现前向回归逻辑,相当于完成整个网络运行的逻辑,x是指输入,相当于上图中的input。

def forward(self,x)

**A. F.relu(x):**relu激活函数,激活之后网络具有非线性的分离能力;

B. tensor[batch,channel,H,W]: channel是指通道数,例如RBG三通道这些概念、H是指高,W是指宽,batch是指有几批这样的数据;

**C. F.max_pool2d(x,(2,2)):**最大池化下采样对x进行处理;

**D. x.view(-1,x.size()[1:].numel()):**进行拉平、展平之后给全连接层,对当前的输入数据x进行一个形式转换,输入行和列,这里所对应的列等于self.fc1 = nn.Linear(16*5*5,120)这里所对应的行,为x.size切片之后数据的乘积;行信息根据批次信息自动生成,-1让程序自动生成这个行;为什么要切1,对于tensor信息来说,将batch切掉,channel、H、W相乘等于16*5*5;

注意: Pytorch处理的都是张量(张量是神经网络所使用的主要数据结构)数据。

四、测试网络效果

相当于打印网络初始化部分,也可以与网络结构相对应检查一下。

python 复制代码
net = Net()
print(net)

参考:Pytorch逐行代码入门学习_哔哩哔哩_bilibili

相关推荐
2501_941507942 小时前
【YOLOv26】教育环境中危险物品实时检测系统_基于深度学习的校园安全解决方案
深度学习·安全·yolo
Hgfdsaqwr8 小时前
Django全栈开发入门:构建一个博客系统
jvm·数据库·python
开发者小天9 小时前
python中For Loop的用法
java·服务器·python
老百姓懂点AI9 小时前
[RAG实战] 向量数据库选型与优化:智能体来了(西南总部)AI agent指挥官的长短期记忆架构设计
python
沃达德软件10 小时前
人工智能治安管控系统
图像处理·人工智能·深度学习·目标检测·计算机视觉·目标跟踪·视觉检测
劈星斩月10 小时前
神经网络之感知机(Perceptron)
神经网络·感知机·perceptron
喵手11 小时前
Python爬虫零基础入门【第九章:实战项目教学·第15节】搜索页采集:关键词队列 + 结果去重 + 反爬友好策略!
爬虫·python·爬虫实战·python爬虫工程化实战·零基础python爬虫教学·搜索页采集·关键词队列
Suchadar11 小时前
if判断语句——Python
开发语言·python
ʚB҉L҉A҉C҉K҉.҉基҉德҉^҉大11 小时前
自动化机器学习(AutoML)库TPOT使用指南
jvm·数据库·python
喵手12 小时前
Python爬虫零基础入门【第九章:实战项目教学·第14节】表格型页面采集:多列、多行、跨页(通用表格解析)!
爬虫·python·python爬虫实战·python爬虫工程化实战·python爬虫零基础入门·表格型页面采集·通用表格解析