Epoch、Batch、Step 之间的关系

文章目录

Epoch

  • 定义 :一个 epoch 表示模型对整个训练集进行一次完整的遍历,即所有样本都经历一次前向传播和反向传播的训练过程。

    当我们说 "训练了 10 个 epoch",意味着模型已经从头到尾扫过了训练集 10 次。

Batch

  • 定义 :训练时通常不会将整个数据集一次性输入到模型,而是将数据分成若干小批量(mini-batch)逐步进行训练,其中每个 batch 包含一定数量的样本(batch size)

  • 公式 :假设数据集大小为 N N N, batch size 为 B B B, 则一个 epoch 内的 batch 数量为:

    Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉

    这里使用向上取整 ⌈ ⋅ ⌉ \lceil \cdot \rceil ⌈⋅⌉ 是因为数据集大小 N N N 可能无法被 B B B 整除。大多数深度学习框架在加载数据时可以自动处理最后一个不完整的 batch。例如,在 PyTorch 的 DataLoader 中,通过设置参数 drop_last 决定是否丢弃最后那个不完整的 batch(如果 drop_last=True,则会丢弃最后不足一个 batch 的样本,以确保所有 batch 大小一致)。

Step

  • 定义 :在训练中,一次对参数的更新过程 被称为一个 step。也就是说,执行一次前向传播(forward)、反向传播(backward)以及参数更新(optimizer.step()),就算完成了 1 个 step。

  • 公式 :对应于上面的定义,一个 epoch 内 step 的数量与该 epoch 内的 batch 数量相同。当训练了 E E E 个 epoch 时,总的 step 数为:

    Total Steps = Number of Batches per Epoch × E = ⌈ N B ⌉ × E \text{Total Steps} = \text{Number of Batches per Epoch} \times E = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=Number of Batches per Epoch×E=⌈BN⌉×E

关系总结

概念 定义 公式
Epoch 整个训练集完整遍历一次 -
Batch 一小组样本,用于一次参数更新前的前/后向传播 Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉
Step 一次完整的参数更新过程(前向+反向传播+更新参数) Total Steps = ⌈ N B ⌉ × E \text{Total Steps} = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=⌈BN⌉×E

举例说明

假设:

  • 数据集大小 N = 10 , 000 N = 10,000 N=10,000
  • batch size B = 32 B = 32 B=32
  • epoch 数 E = 5 E = 5 E=5
  1. 计算 1 个 epoch 中的 batch 数量:

    Number of Batches per Epoch = ⌈ 10 , 000 32 ⌉ = ⌈ 312.5 ⌉ = 313 \text{Number of Batches per Epoch} = \left\lceil \frac{10,000}{32} \right\rceil = \left\lceil 312.5 \right\rceil = 313 Number of Batches per Epoch=⌈3210,000⌉=⌈312.5⌉=313

  2. 计算总步数:

    Total Steps = ⌈ 10 , 000 32 ⌉ × 5 = 313 × 5 = 1565 \text{Total Steps} = \left\lceil \frac{10,000}{32} \right\rceil \times 5 = 313 \times 5 = 1565 Total Steps=⌈3210,000⌉×5=313×5=1565

图示(以前 2 个 epoch 为例):

bash 复制代码
Epoch 1
  └── Batch 1 → Step 1  (前向+反向传播+更新参数)
  └── Batch 2 → Step 2
  └── Batch 3 → Step 3
  └── ... 
  └── Batch 313 → Step 313
  
Epoch 2
  └── Batch 1 → Step 314
  └── Batch 2 → Step 315
  └── Batch 3 → Step 316
  └── ... 
  └── Batch 313 → Step 626
  • 每个 epoch 包含多个 batch,每个 batch 对应 1 次参数更新(即 1 个 step)。
  • 我们可以用 epochs 控制训练回合数或用 max_steps 控制训练总 step 数(AI 画图的 UI 界面中常出现这个超参数选项)。

代码示例

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 数据集参数
N = 10000  # 数据集总样本数
B = 32     # batch_size
E = 5      # epochs

# 创建一个示例数据集
X = torch.randn(N, 10)        # 假设输入维度为 10
y = torch.randint(0, 2, (N,)) # 二分类标签

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=B, shuffle=True, drop_last=False)

model = nn.Sequential(
    nn.Linear(10, 50),
    nn.ReLU(),
    nn.Linear(50, 2)
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

steps_per_epoch = len(dataloader)   # 一个 epoch 内 batch 的数量
total_steps = steps_per_epoch * E

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

current_step = 0

# 这里设置成从 1 开始只是为了不在 print 中额外设置,实际写代码的时候不需要纠结这一点
for epoch in range(1, E + 1):
    for batch_idx, (inputs, targets) in enumerate(dataloader, start=1): 
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # 清零梯度(这一步放在反向传播前和参数更新之后都可以)
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(inputs)
        
        # 计算损失
        loss = criterion(outputs, targets)

        # 反向传播(计算梯度)
        loss.backward()

        # 参数更新
        optimizer.step()
        
        current_step += 1
        # 每 50 步打印一次
        if batch_idx % 50 == 0:
            # 如果不需要打印累积 step,可以去除 current_step 项直接使用 batch_idx
            print(f"Epoch [{epoch}/{E}], Batch [{batch_idx}/{steps_per_epoch}], "
                  f"Step [{current_step}/{total_steps}], Loss: {loss.item():.4f}")

# 可以看到:
# - epoch 从 1 到 E
# - batch 从 1 到 steps_per_epoch
# - step 累计从 1 到 total_steps

输出

复制代码
Epoch [1/5], Batch [50/313], Step [50/1565], Loss: 0.7307
Epoch [1/5], Batch [100/313], Step [100/1565], Loss: 0.6950
Epoch [1/5], Batch [150/313], Step [150/1565], Loss: 0.7380
Epoch [1/5], Batch [200/313], Step [200/1565], Loss: 0.7046
Epoch [1/5], Batch [250/313], Step [250/1565], Loss: 0.6798
Epoch [1/5], Batch [300/313], Step [300/1565], Loss: 0.7319
Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058
Epoch [2/5], Batch [100/313], Step [413/1565], Loss: 0.7026
Epoch [2/5], Batch [150/313], Step [463/1565], Loss: 0.6650
Epoch [2/5], Batch [200/313], Step [513/1565], Loss: 0.6923
Epoch [2/5], Batch [250/313], Step [563/1565], Loss: 0.6889
Epoch [2/5], Batch [300/313], Step [613/1565], Loss: 0.6896
...

思考一下 :为什么输出 Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058 中的 step 是 363?

实践中相关的概念

  1. 学习率调度器(Scheduler)
    常见的学习率更新方式有两种:

    • 以 step 为基础 :在每个 step 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
              # 在每个 step 后更新学习率
              scheduler.step()
    • 以 epoch 为基础 :在每个 epoch 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
          # 在每个 epoch 后更新学习率
          scheduler.step()
  2. 早停(Early Stopping)
    可以基于 epoch 或 step 来监控验证集性能,若在一定 patience(耐心值)内验证性能没有提高,则提前停止训练来避免过拟合。

    python 复制代码
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 5
    
    for epoch in range(E):
        train_one_epoch(model, dataloader, optimizer, criterion)
        
        val_loss = validate(model, val_dataloader, criterion)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # 保存最佳模型参数
            torch.save(model.state_dict(), "best_model.pth")
        else:
            patience_counter += 1
            if patience_counter > patience:
                print(f"No improvement in validation for {patience} epochs, stopping early.")
                break
  3. Batch Size 与显存
    更大的 batch size 意味着每个 step 会处理更多数据,占用更多显存。当遇到 GPU 内存不足(Out of Memory,OOM)错误时,可以尝试减小 batch size,如果仍想达成大 batch size 的效果,使用梯度累积。

进一步阅读

SGD、BGD、MBGD 之间的区别

相关推荐
Moshow郑锴5 小时前
人工智能中的(特征选择)数据过滤方法和包裹方法
人工智能
TY-20255 小时前
【CV 目标检测】Fast RCNN模型①——与R-CNN区别
人工智能·目标检测·目标跟踪·cnn
liuhenghui52016 小时前
神经网络 常见分类
ai
CareyWYR6 小时前
苹果芯片Mac使用Docker部署MinerU api服务
人工智能
失散137 小时前
自然语言处理——02 文本预处理(下)
人工智能·自然语言处理
mit6.8247 小时前
[1Prompt1Story] 滑动窗口机制 | 图像生成管线 | VAE变分自编码器 | UNet去噪神经网络
人工智能·python
sinat_286945197 小时前
AI应用安全 - Prompt注入攻击
人工智能·安全·prompt
迈火8 小时前
ComfyUI-3D-Pack:3D创作的AI神器
人工智能·gpt·3d·ai·stable diffusion·aigc·midjourney
Moshow郑锴9 小时前
机器学习的特征工程(特征构造、特征选择、特征转换和特征提取)详解
人工智能·机器学习
CareyWYR10 小时前
每周AI论文速递(250811-250815)
人工智能