Epoch、Batch、Step 之间的关系

文章目录

Epoch

  • 定义 :一个 epoch 表示模型对整个训练集进行一次完整的遍历,即所有样本都经历一次前向传播和反向传播的训练过程。

    当我们说 "训练了 10 个 epoch",意味着模型已经从头到尾扫过了训练集 10 次。

Batch

  • 定义 :训练时通常不会将整个数据集一次性输入到模型,而是将数据分成若干小批量(mini-batch)逐步进行训练,其中每个 batch 包含一定数量的样本(batch size)

  • 公式 :假设数据集大小为 N N N, batch size 为 B B B, 则一个 epoch 内的 batch 数量为:

    Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉

    这里使用向上取整 ⌈ ⋅ ⌉ \lceil \cdot \rceil ⌈⋅⌉ 是因为数据集大小 N N N 可能无法被 B B B 整除。大多数深度学习框架在加载数据时可以自动处理最后一个不完整的 batch。例如,在 PyTorch 的 DataLoader 中,通过设置参数 drop_last 决定是否丢弃最后那个不完整的 batch(如果 drop_last=True,则会丢弃最后不足一个 batch 的样本,以确保所有 batch 大小一致)。

Step

  • 定义 :在训练中,一次对参数的更新过程 被称为一个 step。也就是说,执行一次前向传播(forward)、反向传播(backward)以及参数更新(optimizer.step()),就算完成了 1 个 step。

  • 公式 :对应于上面的定义,一个 epoch 内 step 的数量与该 epoch 内的 batch 数量相同。当训练了 E E E 个 epoch 时,总的 step 数为:

    Total Steps = Number of Batches per Epoch × E = ⌈ N B ⌉ × E \text{Total Steps} = \text{Number of Batches per Epoch} \times E = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=Number of Batches per Epoch×E=⌈BN⌉×E

关系总结

概念 定义 公式
Epoch 整个训练集完整遍历一次 -
Batch 一小组样本,用于一次参数更新前的前/后向传播 Number of Batches per Epoch = ⌈ N B ⌉ \text{Number of Batches per Epoch} = \left\lceil \frac{N}{B} \right\rceil Number of Batches per Epoch=⌈BN⌉
Step 一次完整的参数更新过程(前向+反向传播+更新参数) Total Steps = ⌈ N B ⌉ × E \text{Total Steps} = \left\lceil \frac{N}{B} \right\rceil \times E Total Steps=⌈BN⌉×E

举例说明

假设:

  • 数据集大小 N = 10 , 000 N = 10,000 N=10,000
  • batch size B = 32 B = 32 B=32
  • epoch 数 E = 5 E = 5 E=5
  1. 计算 1 个 epoch 中的 batch 数量:

    Number of Batches per Epoch = ⌈ 10 , 000 32 ⌉ = ⌈ 312.5 ⌉ = 313 \text{Number of Batches per Epoch} = \left\lceil \frac{10,000}{32} \right\rceil = \left\lceil 312.5 \right\rceil = 313 Number of Batches per Epoch=⌈3210,000⌉=⌈312.5⌉=313

  2. 计算总步数:

    Total Steps = ⌈ 10 , 000 32 ⌉ × 5 = 313 × 5 = 1565 \text{Total Steps} = \left\lceil \frac{10,000}{32} \right\rceil \times 5 = 313 \times 5 = 1565 Total Steps=⌈3210,000⌉×5=313×5=1565

图示(以前 2 个 epoch 为例):

bash 复制代码
Epoch 1
  └── Batch 1 → Step 1  (前向+反向传播+更新参数)
  └── Batch 2 → Step 2
  └── Batch 3 → Step 3
  └── ... 
  └── Batch 313 → Step 313
  
Epoch 2
  └── Batch 1 → Step 314
  └── Batch 2 → Step 315
  └── Batch 3 → Step 316
  └── ... 
  └── Batch 313 → Step 626
  • 每个 epoch 包含多个 batch,每个 batch 对应 1 次参数更新(即 1 个 step)。
  • 我们可以用 epochs 控制训练回合数或用 max_steps 控制训练总 step 数(AI 画图的 UI 界面中常出现这个超参数选项)。

代码示例

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 数据集参数
N = 10000  # 数据集总样本数
B = 32     # batch_size
E = 5      # epochs

# 创建一个示例数据集
X = torch.randn(N, 10)        # 假设输入维度为 10
y = torch.randint(0, 2, (N,)) # 二分类标签

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=B, shuffle=True, drop_last=False)

model = nn.Sequential(
    nn.Linear(10, 50),
    nn.ReLU(),
    nn.Linear(50, 2)
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

steps_per_epoch = len(dataloader)   # 一个 epoch 内 batch 的数量
total_steps = steps_per_epoch * E

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

current_step = 0

# 这里设置成从 1 开始只是为了不在 print 中额外设置,实际写代码的时候不需要纠结这一点
for epoch in range(1, E + 1):
    for batch_idx, (inputs, targets) in enumerate(dataloader, start=1): 
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # 清零梯度(这一步放在反向传播前和参数更新之后都可以)
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(inputs)
        
        # 计算损失
        loss = criterion(outputs, targets)

        # 反向传播(计算梯度)
        loss.backward()

        # 参数更新
        optimizer.step()
        
        current_step += 1
        # 每 50 步打印一次
        if batch_idx % 50 == 0:
            # 如果不需要打印累积 step,可以去除 current_step 项直接使用 batch_idx
            print(f"Epoch [{epoch}/{E}], Batch [{batch_idx}/{steps_per_epoch}], "
                  f"Step [{current_step}/{total_steps}], Loss: {loss.item():.4f}")

# 可以看到:
# - epoch 从 1 到 E
# - batch 从 1 到 steps_per_epoch
# - step 累计从 1 到 total_steps

输出

Epoch [1/5], Batch [50/313], Step [50/1565], Loss: 0.7307
Epoch [1/5], Batch [100/313], Step [100/1565], Loss: 0.6950
Epoch [1/5], Batch [150/313], Step [150/1565], Loss: 0.7380
Epoch [1/5], Batch [200/313], Step [200/1565], Loss: 0.7046
Epoch [1/5], Batch [250/313], Step [250/1565], Loss: 0.6798
Epoch [1/5], Batch [300/313], Step [300/1565], Loss: 0.7319
Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058
Epoch [2/5], Batch [100/313], Step [413/1565], Loss: 0.7026
Epoch [2/5], Batch [150/313], Step [463/1565], Loss: 0.6650
Epoch [2/5], Batch [200/313], Step [513/1565], Loss: 0.6923
Epoch [2/5], Batch [250/313], Step [563/1565], Loss: 0.6889
Epoch [2/5], Batch [300/313], Step [613/1565], Loss: 0.6896
...

思考一下 :为什么输出 Epoch [2/5], Batch [50/313], Step [363/1565], Loss: 0.7058 中的 step 是 363?

实践中相关的概念

  1. 学习率调度器(Scheduler)
    常见的学习率更新方式有两种:

    • 以 step 为基础 :在每个 step 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
              # 在每个 step 后更新学习率
              scheduler.step()
    • 以 epoch 为基础 :在每个 epoch 结束后更新学习率。

      python 复制代码
      scheduler = ...
      
      for epoch in range(E):
          for batch_idx, (inputs, targets) in enumerate(dataloader):
              # 前向、后向、更新参数
              ...
          # 在每个 epoch 后更新学习率
          scheduler.step()
  2. 早停(Early Stopping)
    可以基于 epoch 或 step 来监控验证集性能,若在一定 patience(耐心值)内验证性能没有提高,则提前停止训练来避免过拟合。

    python 复制代码
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 5
    
    for epoch in range(E):
        train_one_epoch(model, dataloader, optimizer, criterion)
        
        val_loss = validate(model, val_dataloader, criterion)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # 保存最佳模型参数
            torch.save(model.state_dict(), "best_model.pth")
        else:
            patience_counter += 1
            if patience_counter > patience:
                print(f"No improvement in validation for {patience} epochs, stopping early.")
                break
  3. Batch Size 与显存
    更大的 batch size 意味着每个 step 会处理更多数据,占用更多显存。当遇到 GPU 内存不足(Out of Memory,OOM)错误时,可以尝试减小 batch size,如果仍想达成大 batch size 的效果,使用梯度累积。

进一步阅读

SGD、BGD、MBGD 之间的区别

相关推荐
wit_@38 分钟前
【深入解析】 RNN 算法:原理、应用与实现
python·rnn·深度学习·神经网络
元宇宙时间39 分钟前
DPIN与CESS Network达成全球战略合作,推动DePIN与AI领域创新突破
人工智能
雨后的路39 分钟前
小雨:2024年,有哪些有趣的智能体?附文章总结/收藏/提醒助手教程
人工智能·程序员
格林威43 分钟前
工业网口相机:如何通过调整网口参数设置,优化图像传输和网络性能,达到最大帧率
网络·人工智能·数码相机·opencv·计算机视觉·c#
goomind44 分钟前
Transformer之Decoder
人工智能·深度学习·llm·nlp·transformer
BTColdman11 小时前
Plume :RWAfi 叙事引领者,全新加密时代的新蓝筹生态
人工智能·区块链
Dream25121 小时前
【神经网络基础】
人工智能·深度学习·神经网络
白白糖1 小时前
深度学习 Pytorch 张量的线性代数运算
人工智能·pytorch·深度学习
hao_wujing1 小时前
通过视觉语言模型蒸馏进行 3D 形状零件分割
人工智能·语言模型·自然语言处理
AI-智能2 小时前
NLP入门书籍《掌握NLP:从基础到大语言模型》免费下载pdf
人工智能·自然语言处理·程序员·llm·prompt·ai编程·ai大模型