PyTorch Lightning 中 TorchMetrics
目录
- [1. 背景与动机](#1. 背景与动机)
- [2. 核心概念](#2. 核心概念)
- [3. 基础用法](#3. 基础用法)
- [4. 回归任务指标详解](#4. 回归任务指标详解)
- [5. MetricCollection 管理多指标](#5. MetricCollection 管理多指标)
- [6. 在 Lightning 中的最佳实践](#6. 在 Lightning 中的最佳实践)
- [7. 常见错误与调试](#7. 常见错误与调试)
- [8. 进阶主题](#8. 进阶主题)
- [9. 完整示例](#9. 完整示例)
- [10. 总结与速查表](#10. 总结与速查表)
1. 背景与动机
1.1 为什么需要 TorchMetrics?
传统做法的问题:
python
# ❌ 手动计算指标的问题
def compute_accuracy(preds, targets):
correct = (preds == targets).sum()
total = len(targets)
return correct / total
# 问题1: 每个batch都要重新计算,无法跨batch累积
# 问题2: 分布式训练时需要手动同步
# 问题3: 代码重复,容易出错
# 问题4: 无法处理边界情况(如除零错误)
TorchMetrics 的优势:
| 特性 | 手动实现 | TorchMetrics |
|---|---|---|
| 状态管理 | 需自己维护累积变量 | ✅ 自动管理状态 |
| 分布式训练 | 需手动同步 | ✅ 自动跨GPU同步 |
| 数值稳定性 | 容易溢出/精度问题 | ✅ 优化过的计算 |
| 代码复用 | 每个项目重写 | ✅ 标准化API |
| Lightning集成 | 需额外适配 | ✅ 原生支持 |
2. 核心概念
2.1 Metric 的三阶段生命周期
TorchMetrics 的核心设计理念是状态累积 → 最终计算:
┌─────────────────────────────────────────────────────┐
│ Metric 生命周期 │
├─────────────────────────────────────────────────────┤
│ │
│ 1. update(preds, targets) ← 每个 batch 调用 │
│ │ │
│ ├─ 累积中间状态(sum, count, etc.) │
│ └─ 不进行最终计算(避免重复计算) │
│ │
│ 2. compute() ← epoch 结束时调用 │
│ │ │
│ ├─ 基于累积状态计算最终结果 │
│ └─ 返回标量指标值 │
│ │
│ 3. reset() ← 开始新 epoch 前调用 │
│ │ │
│ └─ 清空所有累积状态 │
│ │
└─────────────────────────────────────────────────────┘
示例说明:
python
from torchmetrics import Accuracy
import torch
# 初始化指标
acc = Accuracy(task='multiclass', num_classes=3)
# === Epoch 1, Batch 1 ===
preds1 = torch.tensor([0, 1, 2, 1])
targets1 = torch.tensor([0, 1, 2, 2])
acc.update(preds1, targets1) # 内部累积:3 correct, 4 total
# === Epoch 1, Batch 2 ===
preds2 = torch.tensor([0, 0, 1])
targets2 = torch.tensor([0, 0, 1])
acc.update(preds2, targets2) # 继续累积:6 correct, 7 total
# === Epoch 1 结束 ===
final_acc = acc.compute() # 计算:6/7 = 0.8571
print(f"Accuracy: {final_acc}") # 0.8571
# === 准备 Epoch 2 ===
acc.reset() # 清空状态,重新开始
2.2 Metric 的内部状态
每个 Metric 内部维护的状态示例:
python
# 以 MeanSquaredError 为例
class MeanSquaredError:
def __init__(self):
self.sum_squared_error = 0.0 # 累积的平方误差和
self.total_samples = 0 # 累积的样本数
def update(self, preds, target):
self.sum_squared_error += ((preds - target) ** 2).sum()
self.total_samples += target.numel()
def compute(self):
return self.sum_squared_error / self.total_samples
def reset(self):
self.sum_squared_error = 0.0
self.total_samples = 0
关键理解:
- ✅
update()只累积,不计算最终值(高效) - ✅
compute()才进行除法等操作(避免重复计算) - ✅
reset()清空状态,准备下一轮
3. 基础用法
3.1 单个指标的完整流程
python
import torch
from torchmetrics import MeanSquaredError
# ===== 步骤1: 初始化指标 =====
mse = MeanSquaredError()
# ===== 步骤2: 训练循环中更新 =====
for epoch in range(num_epochs):
for batch_idx, (data, target) in enumerate(train_loader):
# 前向传播
preds = model(data)
# 更新指标(不计算最终值)
mse.update(preds, target)
# ===== 步骤3: Epoch结束时计算 =====
epoch_mse = mse.compute()
print(f"Epoch {epoch}, MSE: {epoch_mse:.4f}")
# ===== 步骤4: 重置状态 =====
mse.reset()
3.2 常用指标快速索引
分类任务
python
from torchmetrics import (
Accuracy, # 准确率
Precision, # 精确率
Recall, # 召回率
F1Score, # F1分数
ConfusionMatrix, # 混淆矩阵
AUROC, # ROC曲线下面积
)
# 二分类
acc = Accuracy(task='binary')
f1 = F1Score(task='binary')
# 多分类
acc = Accuracy(task='multiclass', num_classes=10)
f1 = F1Score(task='multiclass', num_classes=10, average='macro')
# 多标签
acc = Accuracy(task='multilabel', num_labels=5)
回归任务(重点⭐)
python
from torchmetrics import (
MeanSquaredError, # MSE - 均方误差
MeanAbsoluteError, # MAE - 平均绝对误差
R2Score, # R² - 决定系数(拟合优度)
MeanAbsolutePercentageError, # MAPE - 平均绝对百分比误差
ExplainedVariance, # 解释方差
PearsonCorrCoef, # 皮尔逊相关系数
)
# 基本用法
mse = MeanSquaredError()
mae = MeanAbsoluteError()
r2 = R2Score()
# RMSE(均方根误差)
rmse = MeanSquaredError(squared=False) # 注意参数!
4. 回归任务指标详解
4.1 核心拟合优度指标
📊 R² Score (决定系数)
含义: 模型解释目标变量方差的比例
公式:
R² = 1 - (SS_res / SS_tot)
其中:
SS_res = Σ(y_true - y_pred)² # 残差平方和
SS_tot = Σ(y_true - y_mean)² # 总平方和
取值范围:
R² = 1:完美预测R² = 0:模型性能 = 预测均值R² < 0:模型比预测均值还差(过拟合或严重错误)
代码示例:
python
from torchmetrics import R2Score
import torch
r2 = R2Score()
# 示例数据
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])
r2.update(y_pred, y_true)
score = r2.compute()
print(f"R² Score: {score:.4f}") # 0.9486
# 解读:模型解释了 94.86% 的方差
使用建议:
- ✅ 最常用的回归评估指标
- ✅ 适合比较不同模型
- ⚠️ 对异常值敏感
- ⚠️ 样本量小时可能不稳定
📏 MSE & RMSE
MSE (Mean Squared Error):
MSE = (1/n) Σ(y_true - y_pred)²
RMSE (Root Mean Squared Error):
RMSE = √MSE
对比:
| 指标 | 单位 | 解释性 | 对异常值 |
|---|---|---|---|
| MSE | 平方单位 | 较差 | 非常敏感(平方放大误差) |
| RMSE | 原始单位 | 好 | 敏感 |
代码示例:
python
from torchmetrics import MeanSquaredError
# MSE
mse = MeanSquaredError()
mse.update(y_pred, y_true)
print(f"MSE: {mse.compute():.4f}")
# RMSE(注意 squared=False)
rmse = MeanSquaredError(squared=False)
rmse.update(y_pred, y_true)
print(f"RMSE: {rmse.compute():.4f}")
📐 MAE & MAPE
MAE (Mean Absolute Error):
MAE = (1/n) Σ|y_true - y_pred|
特点:
- ✅ 对异常值不敏感(相比MSE)
- ✅ 单位与原始数据一致
- ✅ 易于解释
MAPE (Mean Absolute Percentage Error):
MAPE = (100/n) Σ|(y_true - y_pred) / y_true|
特点:
- ✅ 输出百分比,便于理解
- ⚠️ 当
y_true = 0时会除零 - ⚠️ 对小值误差敏感
代码示例:
python
from torchmetrics import MeanAbsoluteError, MeanAbsolutePercentageError
mae = MeanAbsoluteError()
mape = MeanAbsolutePercentageError()
# 假设房价预测(单位:万元)
y_true = torch.tensor([100.0, 150.0, 200.0, 250.0])
y_pred = torch.tensor([95.0, 160.0, 195.0, 260.0])
mae.update(y_pred, y_true)
mape.update(y_pred, y_true)
print(f"MAE: {mae.compute():.2f} 万元") # 7.50 万元
print(f"MAPE: {mape.compute():.2f}%") # 4.25%
4.2 指标选择指南
不同场景的推荐:
python
# 场景1: 通用回归任务(推荐组合)
metrics = {
'R2Score': R2Score(), # 拟合优度
'RMSE': MeanSquaredError(squared=False), # 误差大小
'MAE': MeanAbsoluteError() # 鲁棒性评估
}
# 场景2: 对异常值敏感的任务(如金融预测)
metrics = {
'R2Score': R2Score(),
'MAE': MeanAbsoluteError(), # 主要指标
'MAPE': MeanAbsolutePercentageError()
}
# 场景3: 需要严格惩罚大误差的任务
metrics = {
'R2Score': R2Score(),
'MSE': MeanSquaredError(), # 平方惩罚
'RMSE': MeanSquaredError(squared=False)
}
5. MetricCollection 管理多指标
5.1 为什么使用 MetricCollection?
不使用 MetricCollection(繁琐):
python
# ❌ 需要手动管理每个指标
class MyModel(pl.LightningModule):
def __init__(self):
super().__init__()
# 为每个阶段创建指标
self.train_r2 = R2Score()
self.train_mse = MeanSquaredError()
self.train_mae = MeanAbsoluteError()
self.val_r2 = R2Score()
self.val_mse = MeanSquaredError()
self.val_mae = MeanAbsoluteError()
self.test_r2 = R2Score()
self.test_mse = MeanSquaredError()
self.test_mae = MeanAbsoluteError()
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
# 需要手动更新每个指标
self.val_r2.update(y_hat, y)
self.val_mse.update(y_hat, y)
self.val_mae.update(y_hat, y)
def on_validation_epoch_end(self):
# 手动计算、记录、重置
self.log('val_r2', self.val_r2.compute())
self.log('val_mse', self.val_mse.compute())
self.log('val_mae', self.val_mae.compute())
self.val_r2.reset()
self.val_mse.reset()
self.val_mae.reset()
使用 MetricCollection(优雅✨):
python
from torchmetrics import MetricCollection
class MyModel(pl.LightningModule):
def __init__(self):
super().__init__()
# ✅ 一次定义所有指标
metrics = MetricCollection({
'R2Score': R2Score(),
'MSE': MeanSquaredError(),
'RMSE': MeanSquaredError(squared=False),
'MAE': MeanAbsoluteError(),
'MAPE': MeanAbsolutePercentageError(),
})
# ✅ 自动克隆到不同阶段,添加前缀
self.train_metrics = metrics.clone(prefix='train_')
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
# ✅ 一次更新所有指标
self.val_metrics.update(y_hat, y)
def on_validation_epoch_end(self):
# ✅ 一次计算、记录、重置所有指标
metrics = self.val_metrics.compute()
self.log_dict(metrics)
self.val_metrics.reset()
5.2 MetricCollection 核心特性
特性1: 自动前缀管理
python
metrics = MetricCollection({
'R2Score': R2Score(),
'MSE': MeanSquaredError(),
})
# 为不同阶段添加前缀
train_metrics = metrics.clone(prefix='train_')
val_metrics = metrics.clone(prefix='val_')
# 结果:
# train_metrics.compute() → {'train_R2Score': 0.85, 'train_MSE': 0.12}
# val_metrics.compute() → {'val_R2Score': 0.82, 'val_MSE': 0.15}
特性2: 批量操作
python
# 批量更新
metrics.update(preds, targets) # 等价于对每个指标调用 update()
# 批量计算
results = metrics.compute() # 返回字典
# 批量重置
metrics.reset()
特性3: 动态添加/删除指标
python
# 初始化
metrics = MetricCollection({
'R2Score': R2Score(),
})
# 运行时添加
metrics.add_metrics({'MSE': MeanSquaredError()})
# 删除指标
del metrics['R2Score']
5.3 完整示例:回归任务最佳实践
python
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import MetricCollection, R2Score, MeanSquaredError, MeanAbsoluteError
class RegressionModel(pl.LightningModule):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.save_hyperparameters()
# 模型结构
self.model = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
# ===== 配置指标集合 =====
metrics = MetricCollection({
'R2Score': R2Score(),
'MSE': MeanSquaredError(),
'RMSE': MeanSquaredError(squared=False),
'MAE': MeanAbsoluteError(),
})
# 为不同阶段创建独立实例
self.train_metrics = metrics.clone(prefix='train_')
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
# 记录 loss
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
# 更新指标(不计算)
self.train_metrics.update(y_hat, y)
return loss
def on_train_epoch_end(self):
"""训练 epoch 结束时统一处理指标"""
# 计算所有指标
metrics = self.train_metrics.compute()
# 批量记录到 logger
self.log_dict(metrics, on_epoch=True)
# 重置状态
self.train_metrics.reset()
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
self.log('val_loss', loss, on_epoch=True, prog_bar=True)
self.val_metrics.update(y_hat, y)
return loss
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute()
self.log_dict(metrics, on_epoch=True, prog_bar=True)
self.val_metrics.reset()
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.functional.mse_loss(y_hat, y)
self.log('test_loss', loss, on_epoch=True)
self.test_metrics.update(y_hat, y)
return loss
def on_test_epoch_end(self):
"""测试结束,打印详细结果"""
metrics = self.test_metrics.compute()
self.log_dict(metrics, on_epoch=True)
self.test_metrics.reset()
# 打印格式化的测试结果
print("\n" + "="*60)
print("测试集评估结果")
print("="*60)
for key, value in metrics.items():
clean_key = key.replace('test_', '')
print(f" {clean_key:15s}: {value:.6f}")
print("="*60)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
输出示例:
============================================================
测试集评估结果
============================================================
R2Score : 0.856234
MSE : 0.012345
RMSE : 0.111111
MAE : 0.089012
============================================================
其中根据pytorch-lightning中torchmetrics中的记载https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metriccollection Similarly it can also reduce the amount of code required to log multiple metrics inside your LightningModule. In most cases we just have to replace
self.logwithself.log_dict. 后续也在第二个章节self.log和self.log_dict的区别记载了这两个的区别
6. 在 Lightning 中的最佳实践
6.1 完整的生命周期管理
python
class MyLightningModule(pl.LightningModule):
def __init__(self):
super().__init__()
self.val_metrics = MetricCollection({...})
# ✅ 步骤1: 在 *_step 中只更新(update)
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
self.val_metrics.update(y_hat, y) # 仅更新,不计算
# ✅ 步骤2: 在 on_*_epoch_end 中计算(compute)
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute() # 统一计算
self.log_dict(metrics, on_epoch=True)
self.val_metrics.reset() # 重置状态
为什么这样设计?
每个 batch: update() ← 轻量级操作,只累积状态
每个 batch: update()
每个 batch: update()
...
Epoch 结束: compute() ← 一次性计算,避免重复开销
reset() ← 清空状态
6.2 与 EarlyStopping 和 ModelCheckpoint 配合
python
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# ===== 在 MInterface 中 =====
class MInterface(pl.LightningModule):
def __init__(self):
# 使用下划线前缀(与 loss 命名一致)
self.val_metrics = metrics.clone(prefix='val_')
def validation_step(self, batch, batch_idx):
# 确保 val_loss 使用原命名
self.log('val_loss', loss, on_epoch=True, prog_bar=True)
# 指标会自动命名为 val_R2Score, val_MSE 等
self.val_metrics.update(y_hat, y)
# ===== 在 main.py 中配置回调 =====
callbacks = [
# 基于 loss 的早停
EarlyStopping(
monitor='val_loss', # 监控 loss(命名一致)
patience=10,
mode='min'
),
# 也可以基于 R² 保存最佳模型
ModelCheckpoint(
monitor='val_R2Score', # 监控 R²
mode='max', # R² 越大越好
filename='best-{epoch}-{val_R2Score:.4f}'
)
]
命名规则总结:
| 指标类型 | 命名格式 | 示例 |
|---|---|---|
| 手动记录的 loss | 阶段_loss |
train_loss, val_loss |
| MetricCollection 指标 | prefix + 指标名 |
val_R2Score, val_MSE |
6.3 进度条显示配置
python
def validation_step(self, batch, batch_idx):
# 只在进度条显示最重要的指标
self.log('val_loss', loss, prog_bar=True) # ✅ 显示
self.val_metrics.update(y_hat, y)
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute()
# 选择性显示在进度条
self.log('val_R2Score', metrics['val_R2Score'], prog_bar=True) # ✅ 显示
# 其他指标只记录到 logger,不显示在进度条
for key, value in metrics.items():
if key != 'val_R2Score':
self.log(key, value, prog_bar=False) # ❌ 不显示
self.val_metrics.reset()
效果:
Epoch 10: 100%|██████| 50/50 [00:05<00:00, 9.2it/s, loss=0.123, val_R2Score=0.856]
7. 常见错误与调试
7.1 错误1: 忘记调用 reset()
python
# ❌ 错误代码
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute()
self.log_dict(metrics)
# 忘记 reset()!
# 结果:下一个 epoch 的指标会继续累积,导致错误结果
症状:
- 指标值逐渐趋于稳定(因为累积了所有历史数据)
- 第一个 epoch 的指标值正常,后续 epoch 异常
修复:
python
# ✅ 正确代码
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute()
self.log_dict(metrics)
self.val_metrics.reset() # 必须重置!
7.2 错误2: 在 *_step 中调用 compute()
python
# ❌ 错误代码(性能差)
def validation_step(self, batch, batch_idx):
self.val_metrics.update(y_hat, y)
metrics = self.val_metrics.compute() # ❌ 每个 batch 都计算!
self.log_dict(metrics)
# 问题:
# - 每个 batch 都计算一次,重复计算浪费时间
# - 指标值会在 batch 之间波动(不是 epoch 级别的)
修复:
python
# ✅ 正确代码
def validation_step(self, batch, batch_idx):
self.val_metrics.update(y_hat, y) # 只更新
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute() # epoch 结束统一计算
self.log_dict(metrics)
self.val_metrics.reset()
7.3 错误3: 指标维度不匹配
python
# ❌ 常见维度问题
y_hat = model(x) # shape: (batch_size, 1)
y = batch[1] # shape: (batch_size,)
self.metrics.update(y_hat, y) # ❌ 维度不匹配!
症状:
RuntimeError: The size of tensor a (64) must match the size of tensor b (1) at non-singleton dimension 1
修复:
python
# ✅ 方法1: squeeze 预测值
y_hat = model(x).squeeze() # (batch_size, 1) → (batch_size,)
self.metrics.update(y_hat, y)
# ✅ 方法2: 在 loss 计算时统一处理
loss = F.mse_loss(y_hat, y.view_as(y_hat))
self.metrics.update(y_hat.squeeze(), y)
7.4 错误4: 命名冲突导致 Callback 找不到指标
python
# 问题:使用了斜杠前缀
self.val_metrics = metrics.clone(prefix='val/') # val/R2Score
# Callback 配置
EarlyStopping(monitor='val_loss') # ❌ 找不到!实际名称是 'val/loss'
修复:
python
# ✅ 方案1: 统一使用下划线
self.log('val_loss', loss) # val_loss
self.val_metrics = metrics.clone(prefix='val_') # val_R2Score
# ✅ 方案2: 如果使用斜杠,Callback 也要相应修改
self.log('val/loss', loss)
EarlyStopping(monitor='val/loss')
7.5 调试技巧
python
# 技巧1: 打印指标状态
def on_validation_epoch_end(self):
print(f"Metrics before compute: {self.val_metrics}")
metrics = self.val_metrics.compute()
print(f"Computed metrics: {metrics}")
self.val_metrics.reset()
# 技巧2: 检查内部状态
print(self.val_metrics['R2Score'].sum_squared_obs) # R2Score 的内部累积值
# 技巧3: 单独测试指标
r2 = R2Score()
for batch in dataloader:
r2.update(pred, target)
print(f"Batch R2: {r2.compute()}") # 临时查看,正式代码中删除
r2.reset()
8. 进阶主题
8.1 自定义指标
场景: 需要计算调整 R²(Adjusted R²)
python
from torchmetrics import Metric
import torch
class AdjustedR2Score(Metric):
"""
调整 R² 分数
公式:Adjusted R² = 1 - (1 - R²) * (n - 1) / (n - p - 1)
其中 n 是样本数,p 是特征数
"""
def __init__(self, n_features):
super().__init__()
self.n_features = n_features
# 定义状态变量
self.add_state("sum_squared_residuals",
default=torch.tensor(0.0),
dist_reduce_fx="sum")
self.add_state("sum_squared_total",
default=torch.tensor(0.0),
dist_reduce_fx="sum")
self.add_state("num_samples",
default=torch.tensor(0),
dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""每个 batch 调用,累积状态"""
# 残差平方和
self.sum_squared_residuals += torch.sum((target - preds) ** 2)
# 总平方和
self.sum_squared_total += torch.sum((target - target.mean()) ** 2)
# 样本数
self.num_samples += target.numel()
def compute(self):
"""Epoch 结束时调用,计算最终值"""
# 计算 R²
r2 = 1 - (self.sum_squared_residuals / self.sum_squared_total)
# 计算调整 R²
n = self.num_samples
p = self.n_features
adjusted_r2 = 1 - (1 - r2) * (n - 1) / (n - p - 1)
return adjusted_r2
# 使用示例
metrics = MetricCollection({
'R2Score': R2Score(),
'AdjustedR2': AdjustedR2Score(n_features=10), # 假设 10 个特征
})
自定义指标的关键点:
- ✅ 继承
Metric类 - ✅ 使用
add_state()定义状态变量 - ✅ 实现
update()方法(累积) - ✅ 实现
compute()方法(计算) - ✅
dist_reduce_fx指定分布式训练的聚合方式
8.2 分布式训练中的指标同步
TorchMetrics 自动处理分布式训练,无需手动同步:
python
# ✅ 自动处理(无需额外代码)
# 在多GPU训练时:
# 1. 每个GPU独立累积状态(update)
# 2. compute() 时自动跨GPU同步
# 3. 返回全局聚合结果
metrics = MetricCollection({
'R2Score': R2Score(), # 自动支持分布式
'MSE': MeanSquaredError(),
})
# 在 Trainer 中启用多GPU
trainer = pl.Trainer(
accelerator='gpu',
devices=4, # 4个GPU
strategy='ddp' # 分布式数据并行
)
内部机制:
python
# TorchMetrics 内部的处理(简化版)
def compute(self):
# 步骤1: 收集所有GPU的状态
all_sum = torch.distributed.all_reduce(self.sum_squared_residuals)
all_count = torch.distributed.all_reduce(self.num_samples)
# 步骤2: 基于全局状态计算指标
return all_sum / all_count
8.3 Metric 的序列化与恢复
python
# 保存指标状态(用于断点续训)
checkpoint = {
'model_state': model.state_dict(),
'metrics_state': self.val_metrics.state_dict(), # 保存指标状态
}
torch.save(checkpoint, 'checkpoint.ckpt')
# 恢复
checkpoint = torch.load('checkpoint.ckpt')
model.load_state_dict(checkpoint['model_state'])
self.val_metrics.load_state_dict(checkpoint['metrics_state']) # 恢复指标状态
9. 完整示例
9.1 完整的 model_interface.py
python
import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torchmetrics import (
MetricCollection,
R2Score,
MeanSquaredError,
MeanAbsoluteError,
MeanAbsolutePercentageError,
)
class MInterface(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.config = config
self.save_hyperparameters(config)
# 加载模型(省略具体实现)
self.model = self._build_model()
# ===== 配置指标 =====
self.configure_metrics()
def configure_metrics(self):
"""配置评估指标"""
task_type = self.config['model']['task_type']
if task_type == 'regression':
metrics = MetricCollection({
'R2Score': R2Score(),
'MSE': MeanSquaredError(),
'RMSE': MeanSquaredError(squared=False),
'MAE': MeanAbsoluteError(),
'MAPE': MeanAbsolutePercentageError(),
})
self.train_metrics = metrics.clone(prefix='train_')
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
self.train_metrics.update(y_hat.squeeze(), y.squeeze())
return loss
def on_train_epoch_end(self):
metrics = self.train_metrics.compute()
self.log_dict(metrics, on_epoch=True)
self.train_metrics.reset()
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
self.log('val_loss', loss, on_epoch=True, prog_bar=True)
self.val_metrics.update(y_hat.squeeze(), y.squeeze())
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute()
self.log_dict(metrics, on_epoch=True, prog_bar=True)
self.val_metrics.reset()
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
self.log('test_loss', loss, on_epoch=True)
self.test_metrics.update(y_hat.squeeze(), y.squeeze())
def on_test_epoch_end(self):
metrics = self.test_metrics.compute()
self.log_dict(metrics, on_epoch=True)
# 打印详细结果
self._print_test_results(metrics)
self.test_metrics.reset()
def _print_test_results(self, metrics):
"""打印格式化的测试结果"""
print("\n" + "="*70)
print(f"{'测试集评估结果':^70}")
print("="*70)
# 分类显示
gof_metrics = {k: v for k, v in metrics.items() if 'R2' in k}
error_metrics = {k: v for k, v in metrics.items() if 'R2' not in k}
if gof_metrics:
print(f"\n{'拟合优度指标':^70}")
print("-"*70)
for k, v in gof_metrics.items():
print(f" {k.replace('test_', ''):20s}: {v:>12.6f}")
if error_metrics:
print(f"\n{'误差指标':^70}")
print("-"*70)
for k, v in error_metrics.items():
print(f" {k.replace('test_', ''):20s}: {v:>12.6f}")
print("="*70 + "\n")
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.config['model']['lr'])
def _build_model(self):
# 实际项目中的模型构建逻辑
pass
10. 总结与速查表
10.1 核心概念速查
| 概念 | 说明 | 使用时机 |
|---|---|---|
| update() | 累积中间状态 | 每个 batch 的 *_step 中 |
| compute() | 计算最终指标 | epoch 结束的 on_*_epoch_end 中 |
| reset() | 清空状态 | compute() 后立即调用 |
| MetricCollection | 批量管理指标 | 有多个指标时 |
| clone(prefix) | 为不同阶段创建副本 | train/val/test 分离 |
10.2 回归指标选择速查
python
# 通用回归(推荐)
{
'R2Score': R2Score(), # 拟合优度
'RMSE': MeanSquaredError(squared=False), # 主要误差
'MAE': MeanAbsoluteError() # 鲁棒性检查
}
# 对异常值敏感的场景(金融、医疗)
{
'R2Score': R2Score(),
'MAE': MeanAbsoluteError(), # 主要指标
'MAPE': MeanAbsolutePercentageError()
}
# 需要严格惩罚大误差
{
'R2Score': R2Score(),
'MSE': MeanSquaredError(), # 平方惩罚
'RMSE': MeanSquaredError(squared=False)
}
10.3 命名规范速查
python
# ✅ 推荐命名(下划线风格)
self.log('train_loss', loss)
self.log('val_loss', loss)
self.train_metrics = metrics.clone(prefix='train_') # train_R2Score
self.val_metrics = metrics.clone(prefix='val_') # val_R2Score
# Callback 配置
EarlyStopping(monitor='val_loss') # 匹配
ModelCheckpoint(monitor='val_R2Score') # 匹配
10.4 常见错误速查
| 错误 | 症状 | 解决方案 |
|---|---|---|
忘记 reset() |
指标值趋于稳定 | 在 on_*_epoch_end 中添加 reset() |
在 *_step 中 compute() |
性能差,指标波动 | 只在 on_*_epoch_end 中调用 |
| 维度不匹配 | RuntimeError | 使用 .squeeze() 统一维度 |
| 命名不一致 | Callback 找不到指标 | 统一使用下划线命名 |
10.5 最小可运行示例
python
# 复制此代码即可运行
import pytorch_lightning as pl
import torch
from torchmetrics import MetricCollection, R2Score, MeanSquaredError
class SimpleModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = torch.nn.Linear(10, 1)
metrics = MetricCollection({
'R2': R2Score(),
'MSE': MeanSquaredError()
})
self.val_metrics = metrics.clone(prefix='val_')
def forward(self, x):
return self.model(x)
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x).squeeze()
self.val_metrics.update(y_hat, y)
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute()
self.log_dict(metrics)
self.val_metrics.reset()
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
11. 参考资源
官方文档
- TorchMetrics 官方文档: https://torchmetrics.readthedocs.io/
- PyTorch Lightning 官方文档: https://pytorch-lightning.readthedocs.io/
进阶阅读
- Metric API 设计原理
- 自定义分布式指标
- 高级回调函数集成
好问题!让我详细解释 self.log() 和 self.log_dict() 的区别:
self.log和self.log_dict的区别
核心区别
python
# self.log() - 记录单个标量值
self.log('val_loss', loss_value)
# self.log_dict() - 记录多个标量值(字典)
self.log_dict({'val_loss': loss_value, 'val_acc': acc_value})
1. 详细对比
| 特性 | self.log() | self.log_dict() |
|---|---|---|
| 输入类型 | 单个键值对 | 字典 |
| 适用场景 | 记录单个指标 | 批量记录多个指标 |
| 调用次数 | 每个指标调用一次 | 一次调用记录所有指标 |
| 代码简洁度 | 多个指标时冗长 | 多个指标时简洁 |
2. 使用示例对比
场景1: 记录单个 loss
python
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
# ✅ 方法1: 使用 self.log()
self.log('val_loss', loss, on_epoch=True)
# ✅ 方法2: 使用 self.log_dict()
self.log_dict({'val_loss': loss}, on_epoch=True)
# 结果:两种方法完全等价
场景2: 记录多个指标(重点⭐)
python
def on_validation_epoch_end(self):
# 假设有多个指标
metrics = {
'val_R2Score': 0.856,
'val_MSE': 0.012,
'val_RMSE': 0.111,
'val_MAE': 0.089,
}
# ❌ 方法1: 使用 self.log() - 繁琐
self.log('val_R2Score', metrics['val_R2Score'])
self.log('val_MSE', metrics['val_MSE'])
self.log('val_RMSE', metrics['val_RMSE'])
self.log('val_MAE', metrics['val_MAE'])
# ✅ 方法2: 使用 self.log_dict() - 推荐
self.log_dict(metrics)
# 一行代码完成!
3. 与 MetricCollection 配合
这就是为什么 MetricCollection 和 self.log_dict() 是完美搭配:
python
class MInterface(pl.LightningModule):
def __init__(self):
super().__init__()
# MetricCollection 返回字典
metrics = MetricCollection({
'R2Score': R2Score(),
'MSE': MeanSquaredError(),
'MAE': MeanAbsoluteError(),
})
self.val_metrics = metrics.clone(prefix='val_')
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
# 更新所有指标
self.val_metrics.update(y_hat, y)
def on_validation_epoch_end(self):
# compute() 返回字典
metrics = self.val_metrics.compute()
# 返回格式: {'val_R2Score': 0.85, 'val_MSE': 0.12, 'val_MAE': 0.08}
# ✅ 完美配合!一行搞定
self.log_dict(metrics, on_epoch=True)
# ❌ 如果用 self.log() 需要循环
# for key, value in metrics.items():
# self.log(key, value, on_epoch=True)
self.val_metrics.reset()
4. 参数对比
两个方法支持相同的参数,但使用方式不同:
self.log() 的参数
python
self.log(
name='val_loss', # 指标名称
value=loss, # 指标值(标量)
prog_bar=True, # 是否显示在进度条
logger=True, # 是否记录到logger
on_step=False, # 是否记录每个step
on_epoch=True, # 是否记录每个epoch
reduce_fx='mean', # 聚合方式
sync_dist=False # 是否跨GPU同步
)
self.log_dict() 的参数
python
self.log_dict(
dictionary={ # 指标字典
'val_loss': loss,
'val_acc': acc
},
prog_bar=True, # 应用到所有指标
logger=True,
on_step=False,
on_epoch=True,
reduce_fx='mean',
sync_dist=False
)
关键区别:
self.log_dict()的参数会应用到字典中的所有指标- 如果需要为不同指标设置不同参数,必须分别调用
self.log()
5. 实际场景选择指南
使用 self.log() 的场景
python
# 场景1: 单个指标
def training_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
self.log('train_loss', loss, on_step=True, on_epoch=True)
return loss
# 场景2: 需要不同的显示设置
def validation_step(self, batch, batch_idx):
loss = self.compute_loss(batch)
acc = self.compute_acc(batch)
# loss 显示在进度条
self.log('val_loss', loss, prog_bar=True)
# acc 不显示在进度条
self.log('val_acc', acc, prog_bar=False)
使用 self.log_dict() 的场景
python
# 场景1: MetricCollection(最常见)
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute()
self.log_dict(metrics, on_epoch=True)
self.val_metrics.reset()
# 场景2: 多个相同配置的指标
def validation_step(self, batch, batch_idx):
results = {
'val_loss': loss,
'val_mse': mse,
'val_mae': mae
}
# 所有指标都显示在进度条
self.log_dict(results, prog_bar=True)
# 场景3: 手动构建指标字典
def test_step(self, batch, batch_idx):
metrics = {
'test_r2': self.compute_r2(),
'test_mse': self.compute_mse(),
'test_mae': self.compute_mae()
}
self.log_dict(metrics)
6. 混合使用示例
在实际项目中,通常会混合使用:
python
class MInterface(pl.LightningModule):
def __init__(self):
super().__init__()
metrics = MetricCollection({
'R2Score': R2Score(),
'MSE': MeanSquaredError(),
'MAE': MeanAbsoluteError(),
})
self.val_metrics = metrics.clone(prefix='val_')
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.mse_loss(y_hat, y)
# ✅ 单独记录 loss(需要在进度条显示)
self.log('val_loss', loss, on_epoch=True, prog_bar=True)
# 更新指标集合
self.val_metrics.update(y_hat, y)
def on_validation_epoch_end(self):
metrics = self.val_metrics.compute()
# ✅ 批量记录指标(不显示在进度条)
self.log_dict(metrics, on_epoch=True, prog_bar=False)
# 或者:选择性显示某个指标
self.log('val_R2Score', metrics['val_R2Score'], prog_bar=True)
self.val_metrics.reset()
7. 常见误区
误区1: log_dict 可以记录嵌套字典
python
# ❌ 错误:log_dict 不支持嵌套字典
metrics = {
'validation': {
'loss': 0.1,
'acc': 0.9
}
}
self.log_dict(metrics) # ❌ 报错!
# ✅ 正确:扁平化字典
metrics = {
'val_loss': 0.1,
'val_acc': 0.9
}
self.log_dict(metrics) # ✅ OK
误区2: log_dict 中的值必须是标量
python
# ❌ 错误:值必须是标量,不能是张量
metrics = {
'val_preds': torch.tensor([1, 2, 3]) # ❌ 不是标量
}
self.log_dict(metrics) # ❌ 报错!
# ✅ 正确:转换为标量
metrics = {
'val_mean_pred': torch.tensor([1, 2, 3]).mean() # ✅ 标量
}
self.log_dict(metrics) # ✅ OK
8. 性能对比
python
import time
# 性能测试:记录100个指标
metrics = {f'metric_{i}': i * 0.01 for i in range(100)}
# 方法1: 循环调用 self.log()
start = time.time()
for key, value in metrics.items():
self.log(key, value)
time1 = time.time() - start
# 方法2: 一次调用 self.log_dict()
start = time.time()
self.log_dict(metrics)
time2 = time.time() - start
# 结果:self.log_dict() 通常更快(内部优化)
print(f"self.log() 循环: {time1:.4f}s")
print(f"self.log_dict(): {time2:.4f}s")
总结速查表
| 特性 | self.log() | self.log_dict() |
|---|---|---|
| 输入 | name, value |
dictionary |
| 适用场景 | 1-2个指标 | 3+个指标 |
| 与MetricCollection | 需要循环 | ✅ 完美配合 |
| 进度条个性化 | ✅ 支持 | ❌ 统一设置 |
| 代码简洁度 | 多指标时差 | ✅ 简洁 |
| 性能 | 多次调用 | ✅ 一次调用 |
推荐做法:
python
# loss - 使用 self.log()(需要特殊配置)
self.log('val_loss', loss, prog_bar=True)
# 多个指标 - 使用 self.log_dict()(简洁)
self.log_dict(self.val_metrics.compute())