深度学习2-pyTorch学习-第一个神经网络

第一个PyTorch神经网络代码

面向对象编程OOP

首先,我们来理解一下类(Class)和对象(Object):

  • 类(Class)是创建对象的蓝图或模板。它定义了一组属性(变量)和方法(函数),这些属性和方法属于该类创建的任何对象。

  • 对象(Object)是类的实例。当根据类创建对象时,每个对象都拥有类中定义的属性和方法。

self是类的实例(对象)的引用。

fc1是实例的一个属性,

什么是方法的应用

x.y 表示访问对象的属性y。

x.y(z1,z2),这表示调用对象x的方法y,并传入参数z1,z2。注意,这里的y是一个方法(函数)。

在python中,函数也是一等公民,可以像变量一样传递,当传递函数的时候,是不需要括号的:

复制代码
# 当你想:
# 1. 执行函数/方法 → 用括号
# 2. 传递函数/方法本身 → 不用括号

# 例子:
func = torch.relu          # 把relu函数赋值给func(不用括号)
result = func(tensor)      # 使用func计算(用括号)
result2 = torch.relu(tensor)  # 直接使用(用括号)

def _init(self)

是一个特殊的方法,被称为"构造方法"或"初始化方法"。当你创建一个类的实例(对象)时,Python会自动调用这个方法。它的作用是初始化对象的属性。

self

在Python类的方法中是一个指向对象实例本身的引用。它是一个约定俗成的名称(你可以用其他名字,但强烈建议使用self)。当你调用一个对象的方法时,Python会自动将对象实例作为第一个参数传递给该方法,这个参数就是self。

神经网络的训练和部署一般流程

训练模型是机器学习和深度学习中的核心过程,旨在通过大量数据学习模型参数,以便模型能够对新的、未见过的数据做出准确的预测。

训练模型通常包括以下几个步骤:

  1. 数据准备:
    收集和处理数据,包括清洗、标准化和归一化。
    将数据分为训练集、验证集和测试集。
  2. 定义模型:
    选择模型架构,例如决策树、神经网络等。
    初始化模型参数(权重和偏置)。
  3. 选择损失函数:
    根据任务类型(如分类、回归)选择合适的损失函数。
  4. 选择优化器:
    选择一个优化算法,如SGD、Adam等,来更新模型参数。
  5. 前向传播:
    在每次迭代中,将输入数据通过模型传递,计算预测输出。
  6. 计算损失:
    使用损失函数评估预测输出与真实标签之间的差异。
  7. 反向传播:
    利用自动求导计算损失相对于模型参数的梯度。
  8. 参数更新:
    根据计算出的梯度和优化器的策略更新模型参数。
  9. 迭代优化:
    重复步骤5-8,直到模型在验证集上的性能不再提升或达到预定的迭代次数。
    评估和测试:
  10. 使用测试集评估模型的最终性能,确保模型没有过拟合。
    模型调优:
  11. 根据模型在测试集上的表现进行调参,如改变学习率、增加正则化等。
  12. 部署模型:
    将训练好的模型部署到生产环境中,用于实际的预测任务。

第一个神经

网络的代码如下:

复制代码
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的全连接神经网络
class SimplenNN(nn.Module):
    def __init__(self):
        super(SimplenNN, self).__init__()
        self.fc1 = nn.Linear(2,2) # 输入层到隐藏层
        self.fc2 = nn.Linear(2,1) # 隐藏层到输出层

    def forward(self, x):
        x = torch.relu(self.fc1(x)) # ReLu激活函数
        x = self.fc2(x)
        return x

# 创建网络实例
model = SimplenNN()
# 打印模型结构
print(model)

模块和类

nn是一个模块(Module),它包含了许多类、函数和其他模块。nn.Module是一个类,它是所有的神经网络的模块的基类。

nn.Module也是一个类,它继承自nn.Module。其层次结构如下:

复制代码
torch.nn (模块,包含很多类)
├── nn.Module (类:所有神经网络组件的基类)
│   ├── nn.Linear (类:线性层,继承自Module)
│   ├── nn.Conv2d (类:卷积层,继承自Module)
│   ├── nn.ReLU (类:激活函数,继承自Module)
│   └── ... 其他200多个类

继承

类似于class 子类名(父类名)的格式就代表了一种继承。

class SimplenNN()定义了一个类,该类继承了nn.Module

super(SimplenNN, self).init () 作用是继承父类的功能。

nn.Linear(2,2)仅仅是创建对象,

神经网络基类 nn.Module

nn.Module是PyTorch中所有神经网络模块的基类,它提供了一些必要的机制,例如

  • 将模型迁移到GPU
  • 保存加载模型
  • 提供一些必要的魔法方法
  • 管理网络的参数(通过parameters())

网络的前向传播与计算损失函数

复制代码
import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的全连接神经网络
class SimplenNN(nn.Module):
    def __init__(self):
        super(SimplenNN, self).__init__()
        self.fc1 = nn.Linear(2,2) # 输入层到隐藏层
        self.fc2 = nn.Linear(2,1) # 隐藏层到输出层

    def forward(self, x):
        x = torch.relu(self.fc1(x)) # ReLu激活函数
        x = self.fc2(x)
        return x

# 创建网络实例
model = SimplenNN()
# 打印模型结构
print(model)

# 随机输入
x = torch.randn(1,2)

# 前向传播
output = model(x)
print(output)

# 定义损失函数
criterion = nn.MSELoss() # 均方根误差

#假设目标值
target = torch.randn(1,1)
print(target)

#计算损失函数
loss = criterion(output, target)
print(loss)

输出一下网络的输出、真实值、损失函数值

复制代码
tensor([[0.1357]], grad_fn=<AddmmBackward0>)
tensor([[-1.4453]])
tensor(2.4993, grad_fn=<MseLossBackward0>)

可以3个数对应上了。

网络的训练

下面展示1次训练的代码,首先是定义优化器,然后依次:清空梯度、反向传播、更新参数:

复制代码
#定义优化器(使用Adam优化器)
optimizer = optim.Adam(model.parameters())

# 训练步骤
optimizer.zero_grad() #清空梯度
loss.backward()       #反向传播
optimizer.step()      #更新参数
相关推荐
我的xiaodoujiao2 小时前
使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 30--开源电商商城系统项目实战--配置测试环境地址
python·学习·测试工具·pytest
YJlio2 小时前
Active Directory 工具学习笔记(10.2):AdExplorer 实战(二)— 对象 / 属性 / 搜索 / 快照
java·笔记·学习
Allen_LVyingbo2 小时前
多模态知识图谱赋能大学医疗AI精准教学研究(上)
学习·知识图谱·健康医疗
青衫码上行2 小时前
【JavaWeb学习 | 第19篇】Filter过滤器
java·学习·servlet·tomcat
IT·小灰灰2 小时前
DeepSeek-V3.2:开源大模型的里程碑式突破与硅基流动平台实战指南
大数据·人工智能·python·深度学习·算法·数据挖掘·开源
stereohomology2 小时前
用大模型学习everything 1.5a的特殊用法
学习·everything
【建模先锋】2 小时前
精品数据分享 | 锂电池数据集(六)基于深度迁移学习的锂离子电池实时个性化健康状态预测
人工智能·深度学习·机器学习·迁移学习·锂电池寿命预测·锂电池数据集·寿命预测
集30411 小时前
C++多线程学习笔记
c++·笔记·学习
知南x11 小时前
【正点原子STM32MP157 可信任固件TF-A学习篇】(2) STM32MP1 中的 TF-A
stm32·嵌入式硬件·学习·stm32mp157