14.神经网络的基本骨架 - nn.Module 的使用

神经网络的基本骨架 - nn.Module 的使用

Pytorch官网左侧:Python API(相当于package,提供了一些不同的工具)

关于神经网络的工具主要在torch.nn里

网站地址:torch.nn --- PyTorch 1.8.1 documentation

Containers

Containers 包含6个模块:

  • Module
  • Sequential
  • ModuleList
  • ModuleDict
  • ParameterList
  • ParameterDict

其中最常用的是 Module 模块 (为所有神经网络提供基本骨架)

复制代码
CLASS torch.nn.Module  #搭建的Model都必须继承该类

模板:

复制代码
import torch.nn as nn
import torch.nn.functional as F
 
class Model(nn.Module):   #搭建的神经网络 Model继承了 Module类(父类)
    def __init__(self):   #初始化函数
        super(Model, self).__init__()   #必须要这一步,调用父类的初始化函数
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
 
    def forward(self, x):   #前向传播(为输入和输出中间的处理过程),x为输入
        x = F.relu(self.conv1(x))   #conv为卷积,relu为非线性处理
        return F.relu(self.conv2(x))

代码中比较重要:

前向传播 forward(在所有子类中进行重写)

反向传播 backward

实战

先介绍pycharm的实用工具,使用 Code ---> Generate ---> Override Methods 可以自动补全代码

例子:

复制代码
import torch
from torch import nn
 
 
class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
 
    # def __init__(self):
    #     super(Tudui, self).__init__()
 
    def forward(self,input):
        output = input + 1
        return output
 
tudui = Tudui()   #拿Tudui模板创建出的神经网络
x = torch.tensor(1.0)  #将1.0这个数转换成tensor类型
output = tudui(x)
print(output)

上面的代码根据网站所提供的案例模版得到

运行结果:

debug看流程

在下列语句前打断点:

复制代码
tudui = Tudui()   #整个程序的开始

然后点击蜘蛛,点击 Step into My Code,可以看到代码每一步的执行过程

i() #整个程序的开始

复制代码
然后点击蜘蛛,点击 Step into My Code,可以看到代码每一步的执行过程

[外链图片转存中...(img-8Rp3mCOt-1724861486484)]
相关推荐
音视频牛哥16 分钟前
超清≠清晰:视频系统里的分辨率陷阱与秩序真相
人工智能·机器学习·计算机视觉·音视频·大牛直播sdk·rtsp播放器rtmp播放器·smartmediakit
johnny23318 分钟前
AI视频创作工具汇总:MoneyPrinterTurbo、KrillinAI、NarratoAI、ViMax
人工智能·音视频
Coovally AI模型快速验证1 小时前
当视觉语言模型接收到相互矛盾的信息时,它会相信哪个信号?
人工智能·深度学习·算法·机器学习·目标跟踪·语言模型
居7然1 小时前
Attention注意力机制:原理、实现与优化全解析
人工智能·深度学习·大模型·transformer·embedding
Scabbards_1 小时前
KGGEN: 用语言模型从纯文本中提取知识图
人工智能·语言模型·自然语言处理
LeonDL1681 小时前
【通用视觉框架】基于C#+Winform+OpencvSharp开发的视觉框架软件,全套源码,开箱即用
人工智能·c#·winform·opencvsharp·机器视觉软件框架·通用视觉框架·机器视觉框架
AI纪元故事会1 小时前
《目标检测全解析:从R-CNN到DETR,六大经典模型深度对比与实战指南》
人工智能·yolo·目标检测·r语言·cnn
Shang180989357262 小时前
T41LQ 一款高性能、低功耗的系统级芯片(SoC) 适用于各种AIoT应用智能安防、智能家居方案优选T41L
人工智能·驱动开发·嵌入式硬件·fpga开发·信息与通信·信号处理·t41lq
Bony-2 小时前
用于糖尿病视网膜病变图像生成的GAN
人工智能·神经网络·生成对抗网络
罗西的思考2 小时前
【Agent】 ACE(Agentic Context Engineering)源码阅读笔记---(3)关键创新
人工智能·算法