Epoch、Batch、Step 之间的关系

文章目录

Epoch

  • 定义 :一个 epoch 表示模型对整个训练集进行一次完整的遍历,即所有样本都经历一次前向传播和反向传播的训练过程。

    当我们说 "训练了 10 个 epoch",意味着模型已经从头到尾扫过了训练集 10 次。

Batch

  • 定义 :训练时通常不会将整个数据集一次性输入到模型,而是将数据分成若干小批量(mini-batch)逐步进行训练,其中每个 batch 包含一定数量的样本(batch size)

  • 公式 :假设数据集大小为 N N N, batch size 为 B B B, 则一个 epoch 内的 batch 数量为:

    Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉

    这里使用向上取整 ⌈ ⋅ ⌉ \lceil \cdot \rceil ⌈⋅⌉ 是因为数据集大小 N N N 可能无法被 B B B 整除。大多数深度学习框架在加载数据时可以自动处理最后一个不完整的 batch。例如,在 PyTorch 的 DataLoader 中,通过设置参数 drop_last 决定是否丢弃最后那个不完整的 batch(如果 drop_last=True,则会丢弃最后不足一个 batch 的样本,以确保所有 batch 大小一致)。

Step

  • 定义 :在训练中,一次对参数的更新过程 被称为一个 step。也就是说,执行一次前向传播(forward)、反向传播(backward)以及参数更新(optimizer.step()),就算完成了 1 个 step。

  • 公式 :对应于上面的定义,一个 epoch 内 step 的数量与该 epoch 内的 batch 数量相同。当训练了 E E E 个 epoch 时,总的 step 数为:

    Total Steps = Number of Batches per Epoch × E = ⌈ N B ⌉ × E \text{Total Steps} = \text{Number of Batches per Epoch} \times E = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=Number of Batches per Epoch×E=⌈BN⌉×E

关系总结

概念 定义 公式
Epoch 整个训练集完整遍历一次 -
Batch 一小组样本,用于一次参数更新前的前/后向传播 Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉
Step 一次完整的参数更新过程(前向+反向传播+更新参数) Total Steps = ⌈ N B ⌉ × E \text{Total Steps} = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=⌈BN⌉×E

举例说明

假设:

  • 数据集大小 N = 10 , 000 N = 10,000 N=10,000
  • batch size B = 32 B = 32 B=32
  • epoch 数 E = 5 E = 5 E=5
  1. 计算 1 个 epoch 中的 batch 数量:

    Number of Batches per Epoch = ⌈ 10 , 000 32 ⌉ = ⌈ 312.5 ⌉ = 313 \text{Number of Batches per Epoch} = \left\lceil \frac{10,000}{32} \right\rceil = \left\lceil 312.5 \right\rceil = 313 Number of Batches per Epoch=⌈3210,000⌉=⌈312.5⌉=313

  2. 计算总步数:

    Total Steps = ⌈ 10 , 000 32 ⌉ × 5 = 313 × 5 = 1565 \text{Total Steps} = \left\lceil \frac{10,000}{32} \right\rceil \times 5 = 313 \times 5 = 1565 Total Steps=⌈3210,000⌉×5=313×5=1565

图示(以前 2 个 epoch 为例):

bash 复制代码
Epoch 1
  └── Batch 1 → Step 1  (前向+反向传播+更新参数)
  └── Batch 2 → Step 2
  └── Batch 3 → Step 3
  └── ... 
  └── Batch 313 → Step 313
  
Epoch 2
  └── Batch 1 → Step 314
  └── Batch 2 → Step 315
  └── Batch 3 → Step 316
  └── ... 
  └── Batch 313 → Step 626
  • 每个 epoch 包含多个 batch,每个 batch 对应 1 次参数更新(即 1 个 step)。
  • 我们可以用 epochs 控制训练回合数或用 max_steps 控制训练总 step 数(AI 画图的 UI 界面中常出现这个超参数选项)。

代码示例

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 数据集参数
N = 10000  # 数据集总样本数
B = 32     # batch_size
E = 5      # epochs

# 创建一个示例数据集
X = torch.randn(N, 10)        # 假设输入维度为 10
y = torch.randint(0, 2, (N,)) # 二分类标签

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=B, shuffle=True, drop_last=False)

model = nn.Sequential(
    nn.Linear(10, 50),
    nn.ReLU(),
    nn.Linear(50, 2)
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

steps_per_epoch = len(dataloader)   # 一个 epoch 内 batch 的数量
total_steps = steps_per_epoch * E

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

current_step = 0

# 这里设置成从 1 开始只是为了不在 print 中额外设置,实际写代码的时候不需要纠结这一点
for epoch in range(1, E + 1):
    for batch_idx, (inputs, targets) in enumerate(dataloader, start=1): 
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # 清零梯度(这一步放在反向传播前和参数更新之后都可以)
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(inputs)
        
        # 计算损失
        loss = criterion(outputs, targets)

        # 反向传播(计算梯度)
        loss.backward()

        # 参数更新
        optimizer.step()
        
        current_step += 1
        # 每 50 步打印一次
        if batch_idx % 50 == 0:
            # 如果不需要打印累积 step,可以去除 current_step 项直接使用 batch_idx
            print(f"Epoch [{epoch}/{E}], Batch [{batch_idx}/{steps_per_epoch}], "
                  f"Step [{current_step}/{total_steps}], Loss: {loss.item():.4f}")

# 可以看到:
# - epoch 从 1 到 E
# - batch 从 1 到 steps_per_epoch
# - step 累计从 1 到 total_steps

输出

Epoch [1/5], Batch [50/313], Step [50/1565], Loss: 0.7307
Epoch [1/5], Batch [100/313], Step [100/1565], Loss: 0.6950
Epoch [1/5], Batch [150/313], Step [150/1565], Loss: 0.7380
Epoch [1/5], Batch [200/313], Step [200/1565], Loss: 0.7046
Epoch [1/5], Batch [250/313], Step [250/1565], Loss: 0.6798
Epoch [1/5], Batch [300/313], Step [300/1565], Loss: 0.7319
Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058
Epoch [2/5], Batch [100/313], Step [413/1565], Loss: 0.7026
Epoch [2/5], Batch [150/313], Step [463/1565], Loss: 0.6650
Epoch [2/5], Batch [200/313], Step [513/1565], Loss: 0.6923
Epoch [2/5], Batch [250/313], Step [563/1565], Loss: 0.6889
Epoch [2/5], Batch [300/313], Step [613/1565], Loss: 0.6896
...

思考一下 :为什么输出 Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058 中的 step 是 363?

实践中相关的概念

  1. 学习率调度器(Scheduler)
    常见的学习率更新方式有两种:

    • 以 step 为基础 :在每个 step 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
              # 在每个 step 后更新学习率
              scheduler.step()
    • 以 epoch 为基础 :在每个 epoch 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
          # 在每个 epoch 后更新学习率
          scheduler.step()
  2. 早停(Early Stopping)
    可以基于 epoch 或 step 来监控验证集性能,若在一定 patience(耐心值)内验证性能没有提高,则提前停止训练来避免过拟合。

    python 复制代码
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 5
    
    for epoch in range(E):
        train_one_epoch(model, dataloader, optimizer, criterion)
        
        val_loss = validate(model, val_dataloader, criterion)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # 保存最佳模型参数
            torch.save(model.state_dict(), "best_model.pth")
        else:
            patience_counter += 1
            if patience_counter > patience:
                print(f"No improvement in validation for {patience} epochs, stopping early.")
                break
  3. Batch Size 与显存
    更大的 batch size 意味着每个 step 会处理更多数据,占用更多显存。当遇到 GPU 内存不足(Out of Memory,OOM)错误时,可以尝试减小 batch size,如果仍想达成大 batch size 的效果,使用梯度累积。

进一步阅读

SGD、BGD、MBGD 之间的区别

相关推荐
CM莫问1 小时前
python实战(十一)——情感分析
人工智能
chenchihwen1 小时前
TimesFM(Time Series Foundation Model)时间序列预测的数据研究(3)
人工智能·python
听风吹等浪起1 小时前
改进系列(6):基于DenseNet网络添加TripletAttention注意力层实现的番茄病害图像分类
网络·人工智能·神经网络·算法·分类
夜半被帅醒1 小时前
【人工智能解读】神经网络(CNN)的特点及其应用场景器学习(Machine Learning, ML)的基本概念
人工智能·神经网络·机器学习
一只小灿灿3 小时前
计算机视觉中的图像滤波与增强算法
人工智能·算法·计算机视觉
程序猿000001号3 小时前
使用PyTorch Lightning简化深度学习模型开发
pytorch·深度学习·目标检测
Illusionna.3 小时前
Word2Vec 模型 PyTorch 实现并复现论文中的数据集
人工智能·pytorch·算法·自然语言处理·nlp·matplotlib·word2vec
平凡灵感码头3 小时前
机器学习算法概览
人工智能·算法·机器学习
UQI-LIUWJ3 小时前
论文结论:GPTs and Hallucination Why do large language models hallucinate
人工智能·语言模型·自然语言处理