Epoch、Batch、Step 之间的关系

文章目录

Epoch

  • 定义 :一个 epoch 表示模型对整个训练集进行一次完整的遍历,即所有样本都经历一次前向传播和反向传播的训练过程。

    当我们说 "训练了 10 个 epoch",意味着模型已经从头到尾扫过了训练集 10 次。

Batch

  • 定义 :训练时通常不会将整个数据集一次性输入到模型,而是将数据分成若干小批量(mini-batch)逐步进行训练,其中每个 batch 包含一定数量的样本(batch size)

  • 公式 :假设数据集大小为 N N N, batch size 为 B B B, 则一个 epoch 内的 batch 数量为:

    Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉

    这里使用向上取整 ⌈ ⋅ ⌉ \lceil \cdot \rceil ⌈⋅⌉ 是因为数据集大小 N N N 可能无法被 B B B 整除。大多数深度学习框架在加载数据时可以自动处理最后一个不完整的 batch。例如,在 PyTorch 的 DataLoader 中,通过设置参数 drop_last 决定是否丢弃最后那个不完整的 batch(如果 drop_last=True,则会丢弃最后不足一个 batch 的样本,以确保所有 batch 大小一致)。

Step

  • 定义 :在训练中,一次对参数的更新过程 被称为一个 step。也就是说,执行一次前向传播(forward)、反向传播(backward)以及参数更新(optimizer.step()),就算完成了 1 个 step。

  • 公式 :对应于上面的定义,一个 epoch 内 step 的数量与该 epoch 内的 batch 数量相同。当训练了 E E E 个 epoch 时,总的 step 数为:

    Total Steps = Number of Batches per Epoch × E = ⌈ N B ⌉ × E \text{Total Steps} = \text{Number of Batches per Epoch} \times E = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=Number of Batches per Epoch×E=⌈BN⌉×E

关系总结

概念 定义 公式
Epoch 整个训练集完整遍历一次 -
Batch 一小组样本,用于一次参数更新前的前/后向传播 Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉
Step 一次完整的参数更新过程(前向+反向传播+更新参数) Total Steps = ⌈ N B ⌉ × E \text{Total Steps} = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=⌈BN⌉×E

举例说明

假设:

  • 数据集大小 N = 10 , 000 N = 10,000 N=10,000
  • batch size B = 32 B = 32 B=32
  • epoch 数 E = 5 E = 5 E=5
  1. 计算 1 个 epoch 中的 batch 数量:

    Number of Batches per Epoch = ⌈ 10 , 000 32 ⌉ = ⌈ 312.5 ⌉ = 313 \text{Number of Batches per Epoch} = \left\lceil \frac{10,000}{32} \right\rceil = \left\lceil 312.5 \right\rceil = 313 Number of Batches per Epoch=⌈3210,000⌉=⌈312.5⌉=313

  2. 计算总步数:

    Total Steps = ⌈ 10 , 000 32 ⌉ × 5 = 313 × 5 = 1565 \text{Total Steps} = \left\lceil \frac{10,000}{32} \right\rceil \times 5 = 313 \times 5 = 1565 Total Steps=⌈3210,000⌉×5=313×5=1565

图示(以前 2 个 epoch 为例):

bash 复制代码
Epoch 1
  └── Batch 1 → Step 1  (前向+反向传播+更新参数)
  └── Batch 2 → Step 2
  └── Batch 3 → Step 3
  └── ... 
  └── Batch 313 → Step 313
  
Epoch 2
  └── Batch 1 → Step 314
  └── Batch 2 → Step 315
  └── Batch 3 → Step 316
  └── ... 
  └── Batch 313 → Step 626
  • 每个 epoch 包含多个 batch,每个 batch 对应 1 次参数更新(即 1 个 step)。
  • 我们可以用 epochs 控制训练回合数或用 max_steps 控制训练总 step 数(AI 画图的 UI 界面中常出现这个超参数选项)。

代码示例

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 数据集参数
N = 10000  # 数据集总样本数
B = 32     # batch_size
E = 5      # epochs

# 创建一个示例数据集
X = torch.randn(N, 10)        # 假设输入维度为 10
y = torch.randint(0, 2, (N,)) # 二分类标签

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=B, shuffle=True, drop_last=False)

model = nn.Sequential(
    nn.Linear(10, 50),
    nn.ReLU(),
    nn.Linear(50, 2)
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

steps_per_epoch = len(dataloader)   # 一个 epoch 内 batch 的数量
total_steps = steps_per_epoch * E

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

current_step = 0

# 这里设置成从 1 开始只是为了不在 print 中额外设置,实际写代码的时候不需要纠结这一点
for epoch in range(1, E + 1):
    for batch_idx, (inputs, targets) in enumerate(dataloader, start=1): 
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # 清零梯度(这一步放在反向传播前和参数更新之后都可以)
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(inputs)
        
        # 计算损失
        loss = criterion(outputs, targets)

        # 反向传播(计算梯度)
        loss.backward()

        # 参数更新
        optimizer.step()
        
        current_step += 1
        # 每 50 步打印一次
        if batch_idx % 50 == 0:
            # 如果不需要打印累积 step,可以去除 current_step 项直接使用 batch_idx
            print(f"Epoch [{epoch}/{E}], Batch [{batch_idx}/{steps_per_epoch}], "
                  f"Step [{current_step}/{total_steps}], Loss: {loss.item():.4f}")

# 可以看到:
# - epoch 从 1 到 E
# - batch 从 1 到 steps_per_epoch
# - step 累计从 1 到 total_steps

输出

复制代码
Epoch [1/5], Batch [50/313], Step [50/1565], Loss: 0.7307
Epoch [1/5], Batch [100/313], Step [100/1565], Loss: 0.6950
Epoch [1/5], Batch [150/313], Step [150/1565], Loss: 0.7380
Epoch [1/5], Batch [200/313], Step [200/1565], Loss: 0.7046
Epoch [1/5], Batch [250/313], Step [250/1565], Loss: 0.6798
Epoch [1/5], Batch [300/313], Step [300/1565], Loss: 0.7319
Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058
Epoch [2/5], Batch [100/313], Step [413/1565], Loss: 0.7026
Epoch [2/5], Batch [150/313], Step [463/1565], Loss: 0.6650
Epoch [2/5], Batch [200/313], Step [513/1565], Loss: 0.6923
Epoch [2/5], Batch [250/313], Step [563/1565], Loss: 0.6889
Epoch [2/5], Batch [300/313], Step [613/1565], Loss: 0.6896
...

思考一下 :为什么输出 Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058 中的 step 是 363?

实践中相关的概念

  1. 学习率调度器(Scheduler)
    常见的学习率更新方式有两种:

    • 以 step 为基础 :在每个 step 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
              # 在每个 step 后更新学习率
              scheduler.step()
    • 以 epoch 为基础 :在每个 epoch 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
          # 在每个 epoch 后更新学习率
          scheduler.step()
  2. 早停(Early Stopping)
    可以基于 epoch 或 step 来监控验证集性能,若在一定 patience(耐心值)内验证性能没有提高,则提前停止训练来避免过拟合。

    python 复制代码
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 5
    
    for epoch in range(E):
        train_one_epoch(model, dataloader, optimizer, criterion)
        
        val_loss = validate(model, val_dataloader, criterion)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # 保存最佳模型参数
            torch.save(model.state_dict(), "best_model.pth")
        else:
            patience_counter += 1
            if patience_counter > patience:
                print(f"No improvement in validation for {patience} epochs, stopping early.")
                break
  3. Batch Size 与显存
    更大的 batch size 意味着每个 step 会处理更多数据,占用更多显存。当遇到 GPU 内存不足(Out of Memory,OOM)错误时,可以尝试减小 batch size,如果仍想达成大 batch size 的效果,使用梯度累积。

进一步阅读

SGD、BGD、MBGD 之间的区别

相关推荐
Mr数据杨4 分钟前
灾害推文识别与应急信息筛选优化
机器学习·数据分析·kaggle
weisian1518 分钟前
基础篇--概念原理-1-Token是什么?——从原理到实战,一篇讲透
人工智能·职场和发展·token
大模型最新论文速读13 分钟前
Select to Think:蒸馏 token 排序能力,效果平均提升24%
论文阅读·人工智能·深度学习·机器学习·自然语言处理
Studying 开龙wu13 分钟前
深度学习PyTorch 实战九:YOLOv1目标检测从标注-训练-预测
pytorch·深度学习·yolo
老了,不知天命16 分钟前
鳶尾花項目JAVA
java·开发语言·机器学习
维元码簿19 分钟前
Claude Code 深度拆解:多 Agent 协作 3 — Task 状态机、SendMessage 与消息邮箱
ai·agent·claude code·ai coding
二哈赛车手23 分钟前
新人笔记---实现简易版的rag的bm25检索(利用ES),以及RAG上传时的ES与向量数据库双写
java·数据库·笔记·spring·elasticsearch·ai
无忧智库24 分钟前
跨行业数据要素可信流通体系建设:打破信任壁垒的完整工程方法论(WORD)
大数据·人工智能
mit6.82424 分钟前
NitroGen: AI 自动玩游戏
人工智能
小王毕业啦26 分钟前
2007-2024年 省级-农林牧渔总产值、农业总产值数据(xlsx)
大数据·人工智能·数据挖掘·数据分析·社科数据·实证分析·经管数据