Epoch、Batch、Step 之间的关系

文章目录

Epoch

  • 定义 :一个 epoch 表示模型对整个训练集进行一次完整的遍历,即所有样本都经历一次前向传播和反向传播的训练过程。

    当我们说 "训练了 10 个 epoch",意味着模型已经从头到尾扫过了训练集 10 次。

Batch

  • 定义 :训练时通常不会将整个数据集一次性输入到模型,而是将数据分成若干小批量(mini-batch)逐步进行训练,其中每个 batch 包含一定数量的样本(batch size)

  • 公式 :假设数据集大小为 N N N, batch size 为 B B B, 则一个 epoch 内的 batch 数量为:

    Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉

    这里使用向上取整 ⌈ ⋅ ⌉ \lceil \cdot \rceil ⌈⋅⌉ 是因为数据集大小 N N N 可能无法被 B B B 整除。大多数深度学习框架在加载数据时可以自动处理最后一个不完整的 batch。例如,在 PyTorch 的 DataLoader 中,通过设置参数 drop_last 决定是否丢弃最后那个不完整的 batch(如果 drop_last=True,则会丢弃最后不足一个 batch 的样本,以确保所有 batch 大小一致)。

Step

  • 定义 :在训练中,一次对参数的更新过程 被称为一个 step。也就是说,执行一次前向传播(forward)、反向传播(backward)以及参数更新(optimizer.step()),就算完成了 1 个 step。

  • 公式 :对应于上面的定义,一个 epoch 内 step 的数量与该 epoch 内的 batch 数量相同。当训练了 E E E 个 epoch 时,总的 step 数为:

    Total Steps = Number of Batches per Epoch × E = ⌈ N B ⌉ × E \text{Total Steps} = \text{Number of Batches per Epoch} \times E = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=Number of Batches per Epoch×E=⌈BN⌉×E

关系总结

概念 定义 公式
Epoch 整个训练集完整遍历一次 -
Batch 一小组样本,用于一次参数更新前的前/后向传播 Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉
Step 一次完整的参数更新过程(前向+反向传播+更新参数) Total Steps = ⌈ N B ⌉ × E \text{Total Steps} = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=⌈BN⌉×E

举例说明

假设:

  • 数据集大小 N = 10 , 000 N = 10,000 N=10,000
  • batch size B = 32 B = 32 B=32
  • epoch 数 E = 5 E = 5 E=5
  1. 计算 1 个 epoch 中的 batch 数量:

    Number of Batches per Epoch = ⌈ 10 , 000 32 ⌉ = ⌈ 312.5 ⌉ = 313 \text{Number of Batches per Epoch} = \left\lceil \frac{10,000}{32} \right\rceil = \left\lceil 312.5 \right\rceil = 313 Number of Batches per Epoch=⌈3210,000⌉=⌈312.5⌉=313

  2. 计算总步数:

    Total Steps = ⌈ 10 , 000 32 ⌉ × 5 = 313 × 5 = 1565 \text{Total Steps} = \left\lceil \frac{10,000}{32} \right\rceil \times 5 = 313 \times 5 = 1565 Total Steps=⌈3210,000⌉×5=313×5=1565

图示(以前 2 个 epoch 为例):

bash 复制代码
Epoch 1
  └── Batch 1 → Step 1  (前向+反向传播+更新参数)
  └── Batch 2 → Step 2
  └── Batch 3 → Step 3
  └── ... 
  └── Batch 313 → Step 313
  
Epoch 2
  └── Batch 1 → Step 314
  └── Batch 2 → Step 315
  └── Batch 3 → Step 316
  └── ... 
  └── Batch 313 → Step 626
  • 每个 epoch 包含多个 batch,每个 batch 对应 1 次参数更新(即 1 个 step)。
  • 我们可以用 epochs 控制训练回合数或用 max_steps 控制训练总 step 数(AI 画图的 UI 界面中常出现这个超参数选项)。

代码示例

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 数据集参数
N = 10000  # 数据集总样本数
B = 32     # batch_size
E = 5      # epochs

# 创建一个示例数据集
X = torch.randn(N, 10)        # 假设输入维度为 10
y = torch.randint(0, 2, (N,)) # 二分类标签

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=B, shuffle=True, drop_last=False)

model = nn.Sequential(
    nn.Linear(10, 50),
    nn.ReLU(),
    nn.Linear(50, 2)
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

steps_per_epoch = len(dataloader)   # 一个 epoch 内 batch 的数量
total_steps = steps_per_epoch * E

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

current_step = 0

# 这里设置成从 1 开始只是为了不在 print 中额外设置,实际写代码的时候不需要纠结这一点
for epoch in range(1, E + 1):
    for batch_idx, (inputs, targets) in enumerate(dataloader, start=1): 
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # 清零梯度(这一步放在反向传播前和参数更新之后都可以)
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(inputs)
        
        # 计算损失
        loss = criterion(outputs, targets)

        # 反向传播(计算梯度)
        loss.backward()

        # 参数更新
        optimizer.step()
        
        current_step += 1
        # 每 50 步打印一次
        if batch_idx % 50 == 0:
            # 如果不需要打印累积 step,可以去除 current_step 项直接使用 batch_idx
            print(f"Epoch [{epoch}/{E}], Batch [{batch_idx}/{steps_per_epoch}], "
                  f"Step [{current_step}/{total_steps}], Loss: {loss.item():.4f}")

# 可以看到:
# - epoch 从 1 到 E
# - batch 从 1 到 steps_per_epoch
# - step 累计从 1 到 total_steps

输出

复制代码
Epoch [1/5], Batch [50/313], Step [50/1565], Loss: 0.7307
Epoch [1/5], Batch [100/313], Step [100/1565], Loss: 0.6950
Epoch [1/5], Batch [150/313], Step [150/1565], Loss: 0.7380
Epoch [1/5], Batch [200/313], Step [200/1565], Loss: 0.7046
Epoch [1/5], Batch [250/313], Step [250/1565], Loss: 0.6798
Epoch [1/5], Batch [300/313], Step [300/1565], Loss: 0.7319
Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058
Epoch [2/5], Batch [100/313], Step [413/1565], Loss: 0.7026
Epoch [2/5], Batch [150/313], Step [463/1565], Loss: 0.6650
Epoch [2/5], Batch [200/313], Step [513/1565], Loss: 0.6923
Epoch [2/5], Batch [250/313], Step [563/1565], Loss: 0.6889
Epoch [2/5], Batch [300/313], Step [613/1565], Loss: 0.6896
...

思考一下 :为什么输出 Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058 中的 step 是 363?

实践中相关的概念

  1. 学习率调度器(Scheduler)
    常见的学习率更新方式有两种:

    • 以 step 为基础 :在每个 step 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
              # 在每个 step 后更新学习率
              scheduler.step()
    • 以 epoch 为基础 :在每个 epoch 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
          # 在每个 epoch 后更新学习率
          scheduler.step()
  2. 早停(Early Stopping)
    可以基于 epoch 或 step 来监控验证集性能,若在一定 patience(耐心值)内验证性能没有提高,则提前停止训练来避免过拟合。

    python 复制代码
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 5
    
    for epoch in range(E):
        train_one_epoch(model, dataloader, optimizer, criterion)
        
        val_loss = validate(model, val_dataloader, criterion)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # 保存最佳模型参数
            torch.save(model.state_dict(), "best_model.pth")
        else:
            patience_counter += 1
            if patience_counter > patience:
                print(f"No improvement in validation for {patience} epochs, stopping early.")
                break
  3. Batch Size 与显存
    更大的 batch size 意味着每个 step 会处理更多数据,占用更多显存。当遇到 GPU 内存不足(Out of Memory,OOM)错误时,可以尝试减小 batch size,如果仍想达成大 batch size 的效果,使用梯度累积。

进一步阅读

SGD、BGD、MBGD 之间的区别

相关推荐
十里清风2 分钟前
不同类型的语义相似度损失函数(SentenceTransformerLoss)
深度学习
叶子2024229 分钟前
守护进程实验——autoDL
人工智能·算法·机器学习
陈奕昆12 分钟前
4.3 HarmonyOS NEXT AI驱动的交互创新:智能助手、实时语音与AR/MR开发实战
人工智能·交互·harmonyos
张较瘦_31 分钟前
[论文阅读] 人工智能 | 用大语言模型抓虫:如何让网络协议实现与RFC规范对齐
论文阅读·人工智能·语言模型
qb_jiajia36 分钟前
微软认证考试科目众多?该如何选择?
人工智能·microsoft·微软·云计算
pen-ai1 小时前
【统计方法】蒙特卡洛
人工智能·机器学习·概率论
说私域1 小时前
基于开源AI智能名片链动2+1模式S2B2C商城小程序的生态农庄留存运营策略研究
人工智能·小程序·开源·零售
摘取一颗天上星️1 小时前
大模型微调技术全景图:从全量更新到参数高效适配
人工智能·深度学习·机器学习
要努力啊啊啊1 小时前
策略梯度核心:Advantage 与 GAE 原理详解
论文阅读·人工智能·深度学习·自然语言处理
AI航海家(Ethan)1 小时前
RAG技术解析:实现高精度大语言模型知识增强
人工智能·语言模型·自然语言处理