1. 为什么GPU反而比CPU慢?
在深度学习的固有印象中,GPU通常被认为是训练加速的神器。然而,在使用鸢尾花(Iris)这样的小型数据集和简单的多层感知机(MLP)模型进行实验时,我们可能会观察到一个反直觉的现象:CPU的训练速度反而比GPU快得多。
1.1 实验数据对比
- CPU (i9-12900KF): 训练耗时约 3秒。
- GPU (RTX 3080 Ti): 训练耗时约 11秒。
1.2 核心原因分析
对于"玩具级别"的小任务,GPU的并行计算优势无法发挥,反而被以下三个开销拖累:
- 数据传输开销 (Data Transfer Overhead)
- CPU内存 <-> GPU显存: 在GPU计算前,数据和模型需要从主机内存复制到显存;计算后,结果(如loss值)又需要传回内存。
- 同步瓶颈 : 尤其是在训练循环中频繁调用
loss.item(),这会强制GPU等待数据传回CPU,打断了流水线。
- 核心启动开销 (Kernel Launch Overhead)
- GPU执行每个操作(如一次加法、一个激活函数)都需要启动一个"核心"(Kernel)。
- 对于极小的计算量,启动核心的时间可能比实际计算的时间还要长。
- 性能浪费 (Performance Waste)
- GPU拥有成千上万个计算单元,专为大规模并行设计。
- 当数据量太小时,绝大多数计算单元处于闲置状态,类似于"用卡车运送一瓶水"。
2. 深入探究:数据传输开销与优化
为了验证数据传输是主要的性能瓶颈,我们可以尝试对训练循环中的 loss.item() 操作进行优化。
2.1 优化方案一:完全移除过程记录
如果我们注释掉 losses.append(loss.item()) 和打印逻辑:
- 结果 : 训练时间骤降至 2.86秒。
- 结论: 证明了从GPU回传数据到CPU确实是最大的耗时来源。
- 缺点: 无法实时监控训练过程,也无法绘制损失曲线。
2.2 优化方案二:降低记录频率
如果我们尝试每隔200个epoch才记录一次loss:
- 结果 : 训练时间约为 10.38秒,改善并不明显。
- 原因 :
loss.item()是一个同步操作。一旦调用,Python代码必须等待GPU完成当前的计算并将标量值传回CPU才能继续执行下一行。即使减少了记录次数,这种强制的同步依然会频繁打断GPU的连续执行流,导致性能无法显著提升。
2.3 什么时候该用GPU?
GPU真正发挥威力的场景包括:
- 大型数据集: 如ImageNet,成千上万的高维图片。
- 大型模型: 如ResNet、Transformer,拥有百万级参数。
- 高计算密度: 大量的矩阵乘法和卷积运算。
3. Python黑魔法:__call__方法
在阅读PyTorch代码时,你经常会看到 model(x) 这种写法,而不是 model.forward(x)。这背后的原理是Python的 __call__ 魔术方法。
3.1 什么是__call__?
__call__ 是一个特殊方法,它允许类的实例像函数一样被调用。
3.2 代码示例
不带参数的调用:
class Counter:
def __init__(self):
self.count = 0
def __call__(self):
self.count += 1
return self.count
c = Counter()
print(c()) # 输出: 1
print(c()) # 输出: 2
带参数的调用:
class Adder:
def __call__(self, a, b):
return a + b
add = Adder()
print(add(3, 5)) # 输出: 8
4. PyTorch中的__call__机制
4.1 为什么不直接用forward?
在PyTorch中,所有的层(如 nn.Linear)和模型都继承自 nn.Module。
当我们执行 output = model(input) 时:
- 实际上调用的是
nn.Module定义的__call__方法。 __call__方法内部会处理很多额外逻辑(如运行 Hooks 钩子函数)。- 最后才调用我们定义的
forward方法。
4.2 最佳实践
始终使用 model(x)而不是 **model.forward(x)**。
这是PyTorch的设计规范,确保了框架功能的完整性(如自动梯度计算、钩子机制等)都能正常工作。
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(4, 10)
def forward(self, x):
# 这里的 self.fc1(x) 也是触发了 nn.Linear 的 __call__
out = self.fc1(x)
return out