Pytorch代码入门学习之分类任务(一):搭建网络框架

目录

一、网络框架介绍

二、导包

三、定义卷积神经网络

[3.1 代码展示](#3.1 代码展示)

[3.2 定义网络的目的](#3.2 定义网络的目的)

[3.3 Pytorch搭建网络](#3.3 Pytorch搭建网络)

四、测试网络效果


一、网络框架介绍

网络理解:

将32*32大小的灰度图片(下述的代码中输入为32*32大小的RGB彩色图片),输入到网络中;经过第一次卷积C1,变成了6通道、28*28大小的一个特征向量;通过一次下采样S2,变成了6通道、14*14大小的一个特征向量,其宽高相当于折损了一般;经过第二次卷积C3,变成了16通道、10*10大小的一个特征向量;通过第二次下采样S4,变成了16通道、5*5大小的一个特征向量;最后三层全连接输出。

①Convolutious(卷积):涉及到输入、输出与很多参数的设置,需要初始化。

②Subsampling(下采样):该网络中使用的是最大池化下采样的方法,最大池化下采样的和维2*2大小。

**最大池化:**Max Pooling,取窗口内的最大值作为输出。

③Full Connection(全连接):需要初始化。

二、导包

python 复制代码
import torch  # torch基础库
import torch.nn as nn  # torch神经网络库
import torch.nn.functional as F

三、定义卷积神经网络

3.1 代码展示

python 复制代码
class Net(nn.Module):
    # 初始化
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    # 前向传播
    def forward(self,x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x,(2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        x = x.view(-1,x.size()[1:].numel())
        x = F.relu(self.fc1(x))  # 进入全连接层需要进行激活函数
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # 最后一层为输出层,要输出结果,不需要进行激活
        return x

3.2 定义网络的目的

希望网络有科学系参数,通过输入数据的训练让相关参数不断更新、梯度下降到一个合适的值,之后输入新的图片,可以进行分类或者预测。

3.3 Pytorch搭建网络

Pytorch搭建网络通常会采用类进行管理,可取名为Net(该名字可以更换),通常需要继承nn.Model类(相当于在Net中将Model定义好的方法直接进行使用)。搭建网络通常包括两个函数:

**①初始化函数(含有默认参数):**实例化这个类的时候会自动执行的一部分,这里面放网络需要初始化的内容。

def init(self)

**A. super(Net,self).init():**在该函数中通常需要进行多继承操作,相当于把Model类里面继承的类以及全部的类的方法都继承下来,供Net去使用;

**B. nn.Conv2d(3,6,5):**2d卷积核的函数,只涉及三个参数,其余参数使用默认值;第一个参数为输入的通道数,第二个参数为输出特征向量的通道数,第三个参数为卷积核大小(使用output公式进行计算 W-F+1=28,W=32,F=5 );

:其中W是指宽高,F是指所求的ColorSize的大小,P是指Padding---像图片外面补边,让它去遍历,默认为0;S是指步长,卷积核遍历图片的步长,默认为1;

**C. nn.Linear(16*5*5,120):**全连接层的初始化,涉及两个参数(输入特征的维数大小和输出特征的维数大小),全连接层需要对特征做一个拉平,将每一个特征拉平,将上一个特征向量拉为一条直线,送给全连接层;

**②前向传播函数:**需实现前向回归逻辑,相当于完成整个网络运行的逻辑,x是指输入,相当于上图中的input。

def forward(self,x)

**A. F.relu(x):**relu激活函数,激活之后网络具有非线性的分离能力;

B. tensor[batch,channel,H,W]: channel是指通道数,例如RBG三通道这些概念、H是指高,W是指宽,batch是指有几批这样的数据;

**C. F.max_pool2d(x,(2,2)):**最大池化下采样对x进行处理;

**D. x.view(-1,x.size()[1:].numel()):**进行拉平、展平之后给全连接层,对当前的输入数据x进行一个形式转换,输入行和列,这里所对应的列等于self.fc1 = nn.Linear(16*5*5,120)这里所对应的行,为x.size切片之后数据的乘积;行信息根据批次信息自动生成,-1让程序自动生成这个行;为什么要切1,对于tensor信息来说,将batch切掉,channel、H、W相乘等于16*5*5;

注意: Pytorch处理的都是张量(张量是神经网络所使用的主要数据结构)数据。

四、测试网络效果

相当于打印网络初始化部分,也可以与网络结构相对应检查一下。

python 复制代码
net = Net()
print(net)

参考:Pytorch逐行代码入门学习_哔哩哔哩_bilibili

相关推荐
IT古董34 分钟前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
摸鱼仙人~2 小时前
Attention Free Transformer (AFT)-2020论文笔记
论文阅读·深度学习·transformer
python算法(魔法师版)2 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
小王子10243 小时前
设计模式Python版 组合模式
python·设计模式·组合模式
kakaZhui3 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20254 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
Mason Lin4 小时前
2025年1月22日(网络编程 udp)
网络·python·udp
清弦墨客4 小时前
【蓝桥杯】43697.机器人塔
python·蓝桥杯·程序算法
RZer6 小时前
Hypium+python鸿蒙原生自动化安装配置
python·自动化·harmonyos
davenian7 小时前
DeepSeek-R1 论文. Reinforcement Learning 通过强化学习激励大型语言模型的推理能力
人工智能·深度学习·语言模型·deepseek