day29异常处理@浙大疏锦行
为什么深度学习需要异常处理?
- 训练成本高: 模型训练通常需要数小时甚至数天。你绝不希望跑到 99% 时因为读取了一张损坏的图片而导致程序崩溃,前功尽弃。
- 数据不完美: 真实世界的数据集(Dataset)往往包含坏数据(格式错误、路径不存在、维度对不上)。
- 硬件资源: 需要优雅地处理 GPU 显存不足或设备不可用的情况。
以下是 Python 异常处理的核心语法和在深度学习中的常用场景。
1. 核心语法结构
异常处理的基本逻辑是:"尝试做某事,如果出错了就捕获它,不管怎样最后都要清理现场。"
Python
try:
# 可能发生错误的代码块
# 例如:加载模型、读取文件、矩阵运算
result = 10 / 0
except ZeroDivisionError as e:
# 当捕捉到特定错误(这里是除以零)时执行的代码
print(f"出错了:{e}")
except Exception as e:
# 当捕捉到其他所有类型的错误时执行
print(f"发生了未知错误:{e}")
else:
# 如果 try 块没有报错,则执行这里
print("一切顺利!结果是:", result)
finally:
# 无论是否报错,都会执行(通常用于关闭文件、释放显存等)
print("清理完成。")
2. 深度学习中的常见实战场景
场景一:数据加载(跳过坏数据)
这是 DL 中最常用的场景。在 DataLoader 中读取成千上万张图片时,遇到损坏的文件应该跳过并记录,而不是崩溃。
Python
import os
from PIL import Image
def load_image(image_path):
try:
img = Image.open(image_path)
img.verify() # 验证文件是否完整
return img
except (FileNotFoundError, OSError) as e:
# 捕获文件找不到或文件损坏的错误
print(f"Warning: 跳过损坏或缺失的图片 {image_path}. 错误信息: {e}")
return None
except Exception as e:
# 兜底捕获其他未知错误
print(f"Error: 处理 {image_path} 时发生未知错误: {e}")
return None
# 使用
data = load_image("dataset/bad_image.jpg")
if data is not None:
pass # 继续处理
场景二:张量维度检查(Assert 与 Raise)
在编写模型(Model)的前向传播(Forward)时,维度不匹配(Shape Mismatch)是噩梦。你可以主动抛出异常或使用断言。
raise: 主动报错,阻止代码继续运行。assert: 调试用的断言,如果条件为假则报错(DL 常用)。
Python
import torch
def calculate_loss(prediction, target):
# 检查维度是否一致
if prediction.shape != target.shape:
# 主动抛出 ValueError,提示信息要写清楚
raise ValueError(f"维度不匹配!预测值形状 {prediction.shape} vs 目标值形状 {target.shape}")
return torch.mean((prediction - target)**2)
# 或者使用更简洁的 assert(通常用于检查函数输入)
def forward(x):
# 假设输入必须是 (Batch, 3, 224, 224)
assert x.shape[1] == 3, f"输入通道数必须是 3,但得到了 {x.shape[1]}"
# ... 模型逻辑
场景三:处理 GPU/CPU 设备问题
在代码跨设备迁移时,经常需要处理设备不存在的情况。
Python
import torch
try:
# 尝试使用第一个 GPU
device = torch.device("cuda:0")
dummy_tensor = torch.zeros(1).to(device)
except AssertionError:
print("未检测到 GPU,或者 CUDA 不可用。")
device = torch.device("cpu")
except RuntimeError as e:
print(f"GPU 运行时错误(可能是显存不足):{e}")
device = torch.device("cpu")
print(f"当前使用设备: {device}")
3. 常用异常类型速查表
在深度学习调试中,你最常遇到以下几种异常:
| 异常名称 | 常见原因 (DL 上下文) |
|---|---|
ImportError / ModuleNotFoundError |
忘记安装库(如 No module named 'torch')。 |
IndexError |
访问列表或 Tensor 时越界(如 Batch Size 设错了)。 |
KeyError |
读取字典中不存在的键(常见于加载模型 state_dict 键名不匹配)。 |
ValueError |
传参的值不合法(如把负数传给了要求正数的函数)。 |
TypeError |
类型错误(如把 list 传给了需要 torch.Tensor 的函数,或数据类型是 Float 但需要 Long)。 |
RuntimeError |
最常见。通常涉及 PyTorch/TensorFlow 内部错误,如矩阵乘法维度对不上、显存溢出(OOM)。 |
4. 最佳实践建议
- 不要滥用裸露的
except:- 坏做法:
except: pass(这会吞掉所有错误,包括你按 Ctrl+C 想停止程序的操作,导致程序无法终止且不知道错哪了)。 - 好做法: 总是捕获具体的错误
except ValueError:,或者至少把错误打印出来except Exception as e: print(e)。
- 坏做法:
- 善用
finally做清理- 如果你使用了进度条(如
tqdm),如果循环报错崩溃,进度条往往会"卡"在屏幕上乱掉。用finally可以手动关闭它。
- 如果你使用了进度条(如
- 结合 Logging(日志)
- 在跑长时间实验时,不要只用
print,建议在except块中配合logging.error()将错误写入文件,这样即使程序半夜崩了,第二天早上你也能在日志里看到是第几轮(Epoch)崩的。
- 在跑长时间实验时,不要只用