如何理解神经网络训练的循环过程

正向传播 → 反向传播 → 参数更新

这个过程是一个完整的训练迭代,然后不断重复这一轮又一轮,直到模型收敛(损失函数趋于最小),这就是神经网络训练的核心流程。


🧠 详细解释整个循环过程:

我们把整个训练流程拆解成一个"学习周期",它包括以下三个核心步骤:


1️⃣ 正向传播(Forward Propagation)

  • 输入数据:x
  • 计算每一层的输出,最终得到模型预测值 y(/hat)
  • 根据真实值 y 计算当前损失(Loss):比如均方误差、交叉熵等

目的

  • 得到当前参数下的预测结果
  • 为反向传播提供计算梯度所需的所有中间变量

2️⃣ 反向传播(Backward Propagation)

  • 利用链式法则(Chain Rule)从输出层开始,逐层向前计算每个参数(权重 W 、偏置 b)对损失的影响(即梯度)
  • 所有梯度保存下来供后续使用

目的

  • 知道每个参数怎么影响损失函数
  • 为参数更新提供依据

3️⃣ 参数更新(Parameter Update)

  • 使用梯度下降或其他优化算法(如 Adam、SGD with momentum)来更新参数:
    W : = W − α ⋅ ∂ L ∂ W W := W - \alpha \cdot \frac{\partial L}{\partial W} W:=W−α⋅∂W∂L
    b : = b − α ⋅ ∂ L ∂ b b := b - \alpha \cdot \frac{\partial L}{\partial b} b:=b−α⋅∂b∂L
    其中 α 是学习率

目的

  • 调整参数,使下一次预测更准确
  • 让损失函数逐步减小

🔁 循环进行:Epoch × Batch

整个流程通常是在两个嵌套循环中进行的:

text 复制代码
for epoch in range(总训练轮数):
    for batch in 数据集:
        正向传播 → 计算预测和损失
        反向传播 → 计算梯度
        参数更新 → 梯度下降优化参数
  • epoch:完整遍历一遍所有训练数据
  • batch:每次使用的数据子集(mini-batch)

📈 整个训练过程中,我们期望看到的是:

阶段 表现
刚开始训练 损失大,预测不准
中期训练 损失逐渐下降,预测变好
接近收敛 损失稳定在一个较小值,模型表现良好

🎯 总结一句话:

神经网络的训练就是一个不断"预测 → 算误差 → 算梯度 → 调参数"的循环过程。通过一次次的正向传播、反向传播和参数更新,模型逐步学会如何做出更准确的预测,直到损失函数达到一个我们认为满意的最小值。


实践

使用 小批量梯度下降(Mini-batch Gradient Descent) 作为优化算法,以 MNIST 手写数字分类任务为例,构建一个简单的神经网络进行训练。

✅ 使用工具:PyTorch

  • 自动求导机制支持反向传播;
  • DataLoader 支持 mini-batch;
  • SGD 优化器实现小批量梯度下降;

🧠 代码实现

python 复制代码
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ================== 1. 数据准备 ==================
# 数据预处理:标准化 + 张量转换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载训练数据和测试数据
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# ================== 2. 模型定义 ==================
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # 展平图像
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleNet()

# ================== 3. 损失函数 ==================
criterion = nn.CrossEntropyLoss()

# ================== 4. 小批量梯度下降优化器 ==================
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用 SGD,mini-batch 已在 DataLoader 中设置

# ================== 5. 训练循环 ==================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # DataLoader可以实现数据分批次,train_loader是已经分批后的数据,一个images就是一个batch
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # 正向传播:输入 -> 输出
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播:计算梯度
        optimizer.zero_grad()
        loss.backward()

        # 参数更新:使用小批量梯度下降(SGD)
        # 每次 optimizer.step() 就是一次参数更新 
        optimizer.step()
		
		# 记录了一个 epoch 中所有 batch 的 loss 总和,最后除以 batch 数量,得到平均 loss 并打印出来。
        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

部分参数的辅助理解可以看我另一篇文章:关于epoch、batch_size等参数含义,及optimizer.step()的含义及数学过程

📌 说明

部分 内容
数据加载 使用 DataLoader 构造 mini-batch(batch_size=64)
正向传播 outputs = model(images)
损失计算 loss = criterion(outputs, labels)
反向传播 loss.backward()
参数更新 optimizer.step(),使用的是 SGD 算法

✅ 小批量梯度下降的关键点

  • 每次更新参数只使用一个 batch 的样本(如 64 张图片)
  • 比随机梯度下降更稳定,比批量梯度下降更快
  • PyTorch 的 DataLoader + SGD 优化器天然支持 mini-batch

相关推荐
山顶夕景1 小时前
【RL】Does RLVR enable LLMs to self-improve?
深度学习·llm·强化学习·rlvr
cg50172 小时前
基于 Bert 基本模型进行 Fine-tuned
人工智能·深度学习·bert
6***x5453 小时前
C在机器学习中的ML.NET应用
人工智能·机器学习
甄心爱学习4 小时前
数据挖掘-聚类方法
人工智能·算法·机器学习
长桥夜波5 小时前
机器学习日报21
人工智能·机器学习
AndrewHZ5 小时前
【图像处理基石】如何使用大模型进行图像处理工作?
图像处理·人工智能·深度学习·算法·llm·stablediffusion·可控性
人邮异步社区6 小时前
如何有效地利用AI辅助编程,提高编程效率?
人工智能·深度学习·ai编程
星星上的吴彦祖6 小时前
多模态感知驱动的人机交互决策研究综述
python·深度学习·计算机视觉·人机交互
全息数据6 小时前
WSL2 中将 Ubuntu 20.04 升级到 22.04 的详细步骤
深度学习·ubuntu·wsl2
Jay20021117 小时前
【机器学习】10 正则化 - 减小过拟合
人工智能·机器学习