PyTorch Lightning 中 TorchMetrics

PyTorch Lightning 中 TorchMetrics


目录

  • [1. 背景与动机](#1. 背景与动机)
  • [2. 核心概念](#2. 核心概念)
  • [3. 基础用法](#3. 基础用法)
  • [4. 回归任务指标详解](#4. 回归任务指标详解)
  • [5. MetricCollection 管理多指标](#5. MetricCollection 管理多指标)
  • [6. 在 Lightning 中的最佳实践](#6. 在 Lightning 中的最佳实践)
  • [7. 常见错误与调试](#7. 常见错误与调试)
  • [8. 进阶主题](#8. 进阶主题)
  • [9. 完整示例](#9. 完整示例)
  • [10. 总结与速查表](#10. 总结与速查表)

1. 背景与动机

1.1 为什么需要 TorchMetrics?

传统做法的问题:

python 复制代码
# ❌ 手动计算指标的问题
def compute_accuracy(preds, targets):
    correct = (preds == targets).sum()
    total = len(targets)
    return correct / total

# 问题1: 每个batch都要重新计算,无法跨batch累积
# 问题2: 分布式训练时需要手动同步
# 问题3: 代码重复,容易出错
# 问题4: 无法处理边界情况(如除零错误)

TorchMetrics 的优势:

特性 手动实现 TorchMetrics
状态管理 需自己维护累积变量 ✅ 自动管理状态
分布式训练 需手动同步 ✅ 自动跨GPU同步
数值稳定性 容易溢出/精度问题 ✅ 优化过的计算
代码复用 每个项目重写 ✅ 标准化API
Lightning集成 需额外适配 ✅ 原生支持

2. 核心概念

2.1 Metric 的三阶段生命周期

TorchMetrics 的核心设计理念是状态累积 → 最终计算

复制代码
┌─────────────────────────────────────────────────────┐
│                   Metric 生命周期                      │
├─────────────────────────────────────────────────────┤
│                                                     │
│  1. update(preds, targets)  ← 每个 batch 调用      │
│     │                                               │
│     ├─ 累积中间状态(sum, count, etc.)             │
│     └─ 不进行最终计算(避免重复计算)                │
│                                                     │
│  2. compute()               ← epoch 结束时调用      │
│     │                                               │
│     ├─ 基于累积状态计算最终结果                      │
│     └─ 返回标量指标值                               │
│                                                     │
│  3. reset()                 ← 开始新 epoch 前调用   │
│     │                                               │
│     └─ 清空所有累积状态                             │
│                                                     │
└─────────────────────────────────────────────────────┘

示例说明:

python 复制代码
from torchmetrics import Accuracy
import torch

# 初始化指标
acc = Accuracy(task='multiclass', num_classes=3)

# === Epoch 1, Batch 1 ===
preds1 = torch.tensor([0, 1, 2, 1])
targets1 = torch.tensor([0, 1, 2, 2])
acc.update(preds1, targets1)  # 内部累积:3 correct, 4 total

# === Epoch 1, Batch 2 ===
preds2 = torch.tensor([0, 0, 1])
targets2 = torch.tensor([0, 0, 1])
acc.update(preds2, targets2)  # 继续累积:6 correct, 7 total

# === Epoch 1 结束 ===
final_acc = acc.compute()  # 计算:6/7 = 0.8571
print(f"Accuracy: {final_acc}")  # 0.8571

# === 准备 Epoch 2 ===
acc.reset()  # 清空状态,重新开始

2.2 Metric 的内部状态

每个 Metric 内部维护的状态示例:

python 复制代码
# 以 MeanSquaredError 为例
class MeanSquaredError:
    def __init__(self):
        self.sum_squared_error = 0.0  # 累积的平方误差和
        self.total_samples = 0         # 累积的样本数
    
    def update(self, preds, target):
        self.sum_squared_error += ((preds - target) ** 2).sum()
        self.total_samples += target.numel()
    
    def compute(self):
        return self.sum_squared_error / self.total_samples
    
    def reset(self):
        self.sum_squared_error = 0.0
        self.total_samples = 0

关键理解:

  • update() 只累积,不计算最终值(高效)
  • compute() 才进行除法等操作(避免重复计算)
  • reset() 清空状态,准备下一轮

3. 基础用法

3.1 单个指标的完整流程

python 复制代码
import torch
from torchmetrics import MeanSquaredError

# ===== 步骤1: 初始化指标 =====
mse = MeanSquaredError()

# ===== 步骤2: 训练循环中更新 =====
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # 前向传播
        preds = model(data)
        
        # 更新指标(不计算最终值)
        mse.update(preds, target)
    
    # ===== 步骤3: Epoch结束时计算 =====
    epoch_mse = mse.compute()
    print(f"Epoch {epoch}, MSE: {epoch_mse:.4f}")
    
    # ===== 步骤4: 重置状态 =====
    mse.reset()

3.2 常用指标快速索引

分类任务
python 复制代码
from torchmetrics import (
    Accuracy,           # 准确率
    Precision,          # 精确率
    Recall,             # 召回率
    F1Score,            # F1分数
    ConfusionMatrix,    # 混淆矩阵
    AUROC,              # ROC曲线下面积
)

# 二分类
acc = Accuracy(task='binary')
f1 = F1Score(task='binary')

# 多分类
acc = Accuracy(task='multiclass', num_classes=10)
f1 = F1Score(task='multiclass', num_classes=10, average='macro')

# 多标签
acc = Accuracy(task='multilabel', num_labels=5)
回归任务(重点⭐)
python 复制代码
from torchmetrics import (
    MeanSquaredError,              # MSE - 均方误差
    MeanAbsoluteError,             # MAE - 平均绝对误差
    R2Score,                       # R² - 决定系数(拟合优度)
    MeanAbsolutePercentageError,   # MAPE - 平均绝对百分比误差
    ExplainedVariance,             # 解释方差
    PearsonCorrCoef,               # 皮尔逊相关系数
)

# 基本用法
mse = MeanSquaredError()
mae = MeanAbsoluteError()
r2 = R2Score()

# RMSE(均方根误差)
rmse = MeanSquaredError(squared=False)  # 注意参数!

4. 回归任务指标详解

4.1 核心拟合优度指标

📊 R² Score (决定系数)

含义: 模型解释目标变量方差的比例

公式:

复制代码
R² = 1 - (SS_res / SS_tot)

其中:
SS_res = Σ(y_true - y_pred)²  # 残差平方和
SS_tot = Σ(y_true - y_mean)²  # 总平方和

取值范围:

  • R² = 1:完美预测
  • R² = 0:模型性能 = 预测均值
  • R² < 0:模型比预测均值还差(过拟合或严重错误)

代码示例:

python 复制代码
from torchmetrics import R2Score
import torch

r2 = R2Score()

# 示例数据
y_true = torch.tensor([3.0, -0.5, 2.0, 7.0])
y_pred = torch.tensor([2.5, 0.0, 2.0, 8.0])

r2.update(y_pred, y_true)
score = r2.compute()
print(f"R² Score: {score:.4f}")  # 0.9486

# 解读:模型解释了 94.86% 的方差

使用建议:

  • ✅ 最常用的回归评估指标
  • ✅ 适合比较不同模型
  • ⚠️ 对异常值敏感
  • ⚠️ 样本量小时可能不稳定

📏 MSE & RMSE

MSE (Mean Squared Error):

复制代码
MSE = (1/n) Σ(y_true - y_pred)²

RMSE (Root Mean Squared Error):

复制代码
RMSE = √MSE

对比:

指标 单位 解释性 对异常值
MSE 平方单位 较差 非常敏感(平方放大误差)
RMSE 原始单位 敏感

代码示例:

python 复制代码
from torchmetrics import MeanSquaredError

# MSE
mse = MeanSquaredError()
mse.update(y_pred, y_true)
print(f"MSE: {mse.compute():.4f}")

# RMSE(注意 squared=False)
rmse = MeanSquaredError(squared=False)
rmse.update(y_pred, y_true)
print(f"RMSE: {rmse.compute():.4f}")

📐 MAE & MAPE

MAE (Mean Absolute Error):

复制代码
MAE = (1/n) Σ|y_true - y_pred|

特点:

  • ✅ 对异常值不敏感(相比MSE)
  • ✅ 单位与原始数据一致
  • ✅ 易于解释

MAPE (Mean Absolute Percentage Error):

复制代码
MAPE = (100/n) Σ|(y_true - y_pred) / y_true|

特点:

  • ✅ 输出百分比,便于理解
  • ⚠️ 当 y_true = 0 时会除零
  • ⚠️ 对小值误差敏感

代码示例:

python 复制代码
from torchmetrics import MeanAbsoluteError, MeanAbsolutePercentageError

mae = MeanAbsoluteError()
mape = MeanAbsolutePercentageError()

# 假设房价预测(单位:万元)
y_true = torch.tensor([100.0, 150.0, 200.0, 250.0])
y_pred = torch.tensor([95.0, 160.0, 195.0, 260.0])

mae.update(y_pred, y_true)
mape.update(y_pred, y_true)

print(f"MAE: {mae.compute():.2f} 万元")    # 7.50 万元
print(f"MAPE: {mape.compute():.2f}%")      # 4.25%

4.2 指标选择指南

不同场景的推荐:

python 复制代码
# 场景1: 通用回归任务(推荐组合)
metrics = {
    'R2Score': R2Score(),        # 拟合优度
    'RMSE': MeanSquaredError(squared=False),  # 误差大小
    'MAE': MeanAbsoluteError()   # 鲁棒性评估
}

# 场景2: 对异常值敏感的任务(如金融预测)
metrics = {
    'R2Score': R2Score(),
    'MAE': MeanAbsoluteError(),  # 主要指标
    'MAPE': MeanAbsolutePercentageError()
}

# 场景3: 需要严格惩罚大误差的任务
metrics = {
    'R2Score': R2Score(),
    'MSE': MeanSquaredError(),   # 平方惩罚
    'RMSE': MeanSquaredError(squared=False)
}

5. MetricCollection 管理多指标

5.1 为什么使用 MetricCollection?

不使用 MetricCollection(繁琐):

python 复制代码
# ❌ 需要手动管理每个指标
class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # 为每个阶段创建指标
        self.train_r2 = R2Score()
        self.train_mse = MeanSquaredError()
        self.train_mae = MeanAbsoluteError()
        
        self.val_r2 = R2Score()
        self.val_mse = MeanSquaredError()
        self.val_mae = MeanAbsoluteError()
        
        self.test_r2 = R2Score()
        self.test_mse = MeanSquaredError()
        self.test_mae = MeanAbsoluteError()
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        
        # 需要手动更新每个指标
        self.val_r2.update(y_hat, y)
        self.val_mse.update(y_hat, y)
        self.val_mae.update(y_hat, y)
    
    def on_validation_epoch_end(self):
        # 手动计算、记录、重置
        self.log('val_r2', self.val_r2.compute())
        self.log('val_mse', self.val_mse.compute())
        self.log('val_mae', self.val_mae.compute())
        
        self.val_r2.reset()
        self.val_mse.reset()
        self.val_mae.reset()

使用 MetricCollection(优雅✨):

python 复制代码
from torchmetrics import MetricCollection

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
        # ✅ 一次定义所有指标
        metrics = MetricCollection({
            'R2Score': R2Score(),
            'MSE': MeanSquaredError(),
            'RMSE': MeanSquaredError(squared=False),
            'MAE': MeanAbsoluteError(),
            'MAPE': MeanAbsolutePercentageError(),
        })
        
        # ✅ 自动克隆到不同阶段,添加前缀
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        
        # ✅ 一次更新所有指标
        self.val_metrics.update(y_hat, y)
    
    def on_validation_epoch_end(self):
        # ✅ 一次计算、记录、重置所有指标
        metrics = self.val_metrics.compute()
        self.log_dict(metrics)
        self.val_metrics.reset()

5.2 MetricCollection 核心特性

特性1: 自动前缀管理
python 复制代码
metrics = MetricCollection({
    'R2Score': R2Score(),
    'MSE': MeanSquaredError(),
})

# 为不同阶段添加前缀
train_metrics = metrics.clone(prefix='train_')
val_metrics = metrics.clone(prefix='val_')

# 结果:
# train_metrics.compute() → {'train_R2Score': 0.85, 'train_MSE': 0.12}
# val_metrics.compute()   → {'val_R2Score': 0.82, 'val_MSE': 0.15}
特性2: 批量操作
python 复制代码
# 批量更新
metrics.update(preds, targets)  # 等价于对每个指标调用 update()

# 批量计算
results = metrics.compute()  # 返回字典

# 批量重置
metrics.reset()
特性3: 动态添加/删除指标
python 复制代码
# 初始化
metrics = MetricCollection({
    'R2Score': R2Score(),
})

# 运行时添加
metrics.add_metrics({'MSE': MeanSquaredError()})

# 删除指标
del metrics['R2Score']

5.3 完整示例:回归任务最佳实践

python 复制代码
import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics import MetricCollection, R2Score, MeanSquaredError, MeanAbsoluteError

class RegressionModel(pl.LightningModule):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.save_hyperparameters()
        
        # 模型结构
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
        # ===== 配置指标集合 =====
        metrics = MetricCollection({
            'R2Score': R2Score(),
            'MSE': MeanSquaredError(),
            'RMSE': MeanSquaredError(squared=False),
            'MAE': MeanAbsoluteError(),
        })
        
        # 为不同阶段创建独立实例
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        
        # 记录 loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        
        # 更新指标(不计算)
        self.train_metrics.update(y_hat, y)
        
        return loss
    
    def on_train_epoch_end(self):
        """训练 epoch 结束时统一处理指标"""
        # 计算所有指标
        metrics = self.train_metrics.compute()
        
        # 批量记录到 logger
        self.log_dict(metrics, on_epoch=True)
        
        # 重置状态
        self.train_metrics.reset()
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.val_metrics.update(y_hat, y)
        
        return loss
    
    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        self.log_dict(metrics, on_epoch=True, prog_bar=True)
        self.val_metrics.reset()
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        
        self.log('test_loss', loss, on_epoch=True)
        self.test_metrics.update(y_hat, y)
        
        return loss
    
    def on_test_epoch_end(self):
        """测试结束,打印详细结果"""
        metrics = self.test_metrics.compute()
        self.log_dict(metrics, on_epoch=True)
        self.test_metrics.reset()
        
        # 打印格式化的测试结果
        print("\n" + "="*60)
        print("测试集评估结果")
        print("="*60)
        for key, value in metrics.items():
            clean_key = key.replace('test_', '')
            print(f"  {clean_key:15s}: {value:.6f}")
        print("="*60)
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

输出示例:

复制代码
============================================================
测试集评估结果
============================================================
  R2Score        : 0.856234
  MSE            : 0.012345
  RMSE           : 0.111111
  MAE            : 0.089012
============================================================

其中根据pytorch-lightning中torchmetrics中的记载https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metriccollection Similarly it can also reduce the amount of code required to log multiple metrics inside your LightningModule. In most cases we just have to replace self.log with self.log_dict. 后续也在第二个章节self.log和self.log_dict的区别记载了这两个的区别

6. 在 Lightning 中的最佳实践

6.1 完整的生命周期管理

python 复制代码
class MyLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.val_metrics = MetricCollection({...})
    
    # ✅ 步骤1: 在 *_step 中只更新(update)
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.val_metrics.update(y_hat, y)  # 仅更新,不计算
    
    # ✅ 步骤2: 在 on_*_epoch_end 中计算(compute)
    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()  # 统一计算
        self.log_dict(metrics, on_epoch=True)
        self.val_metrics.reset()  # 重置状态

为什么这样设计?

复制代码
每个 batch:  update()  ← 轻量级操作,只累积状态
每个 batch:  update()
每个 batch:  update()
            ...
Epoch 结束:  compute() ← 一次性计算,避免重复开销
            reset()   ← 清空状态

6.2 与 EarlyStopping 和 ModelCheckpoint 配合

python 复制代码
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# ===== 在 MInterface 中 =====
class MInterface(pl.LightningModule):
    def __init__(self):
        # 使用下划线前缀(与 loss 命名一致)
        self.val_metrics = metrics.clone(prefix='val_')
    
    def validation_step(self, batch, batch_idx):
        # 确保 val_loss 使用原命名
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        
        # 指标会自动命名为 val_R2Score, val_MSE 等
        self.val_metrics.update(y_hat, y)

# ===== 在 main.py 中配置回调 =====
callbacks = [
    # 基于 loss 的早停
    EarlyStopping(
        monitor='val_loss',  # 监控 loss(命名一致)
        patience=10,
        mode='min'
    ),
    
    # 也可以基于 R² 保存最佳模型
    ModelCheckpoint(
        monitor='val_R2Score',  # 监控 R²
        mode='max',             # R² 越大越好
        filename='best-{epoch}-{val_R2Score:.4f}'
    )
]

命名规则总结:

指标类型 命名格式 示例
手动记录的 loss 阶段_loss train_loss, val_loss
MetricCollection 指标 prefix + 指标名 val_R2Score, val_MSE

6.3 进度条显示配置

python 复制代码
def validation_step(self, batch, batch_idx):
    # 只在进度条显示最重要的指标
    self.log('val_loss', loss, prog_bar=True)  # ✅ 显示
    self.val_metrics.update(y_hat, y)

def on_validation_epoch_end(self):
    metrics = self.val_metrics.compute()
    
    # 选择性显示在进度条
    self.log('val_R2Score', metrics['val_R2Score'], prog_bar=True)  # ✅ 显示
    
    # 其他指标只记录到 logger,不显示在进度条
    for key, value in metrics.items():
        if key != 'val_R2Score':
            self.log(key, value, prog_bar=False)  # ❌ 不显示
    
    self.val_metrics.reset()

效果:

复制代码
Epoch 10: 100%|██████| 50/50 [00:05<00:00,  9.2it/s, loss=0.123, val_R2Score=0.856]

7. 常见错误与调试

7.1 错误1: 忘记调用 reset()

python 复制代码
# ❌ 错误代码
def on_validation_epoch_end(self):
    metrics = self.val_metrics.compute()
    self.log_dict(metrics)
    # 忘记 reset()!

# 结果:下一个 epoch 的指标会继续累积,导致错误结果

症状:

  • 指标值逐渐趋于稳定(因为累积了所有历史数据)
  • 第一个 epoch 的指标值正常,后续 epoch 异常

修复:

python 复制代码
# ✅ 正确代码
def on_validation_epoch_end(self):
    metrics = self.val_metrics.compute()
    self.log_dict(metrics)
    self.val_metrics.reset()  # 必须重置!

7.2 错误2: 在 *_step 中调用 compute()

python 复制代码
# ❌ 错误代码(性能差)
def validation_step(self, batch, batch_idx):
    self.val_metrics.update(y_hat, y)
    metrics = self.val_metrics.compute()  # ❌ 每个 batch 都计算!
    self.log_dict(metrics)

# 问题:
# - 每个 batch 都计算一次,重复计算浪费时间
# - 指标值会在 batch 之间波动(不是 epoch 级别的)

修复:

python 复制代码
# ✅ 正确代码
def validation_step(self, batch, batch_idx):
    self.val_metrics.update(y_hat, y)  # 只更新

def on_validation_epoch_end(self):
    metrics = self.val_metrics.compute()  # epoch 结束统一计算
    self.log_dict(metrics)
    self.val_metrics.reset()

7.3 错误3: 指标维度不匹配

python 复制代码
# ❌ 常见维度问题
y_hat = model(x)  # shape: (batch_size, 1)
y = batch[1]      # shape: (batch_size,)

self.metrics.update(y_hat, y)  # ❌ 维度不匹配!

症状:

复制代码
RuntimeError: The size of tensor a (64) must match the size of tensor b (1) at non-singleton dimension 1

修复:

python 复制代码
# ✅ 方法1: squeeze 预测值
y_hat = model(x).squeeze()  # (batch_size, 1) → (batch_size,)
self.metrics.update(y_hat, y)

# ✅ 方法2: 在 loss 计算时统一处理
loss = F.mse_loss(y_hat, y.view_as(y_hat))
self.metrics.update(y_hat.squeeze(), y)

7.4 错误4: 命名冲突导致 Callback 找不到指标

python 复制代码
# 问题:使用了斜杠前缀
self.val_metrics = metrics.clone(prefix='val/')  # val/R2Score

# Callback 配置
EarlyStopping(monitor='val_loss')  # ❌ 找不到!实际名称是 'val/loss'

修复:

python 复制代码
# ✅ 方案1: 统一使用下划线
self.log('val_loss', loss)  # val_loss
self.val_metrics = metrics.clone(prefix='val_')  # val_R2Score

# ✅ 方案2: 如果使用斜杠,Callback 也要相应修改
self.log('val/loss', loss)
EarlyStopping(monitor='val/loss')

7.5 调试技巧

python 复制代码
# 技巧1: 打印指标状态
def on_validation_epoch_end(self):
    print(f"Metrics before compute: {self.val_metrics}")
    metrics = self.val_metrics.compute()
    print(f"Computed metrics: {metrics}")
    self.val_metrics.reset()

# 技巧2: 检查内部状态
print(self.val_metrics['R2Score'].sum_squared_obs)  # R2Score 的内部累积值

# 技巧3: 单独测试指标
r2 = R2Score()
for batch in dataloader:
    r2.update(pred, target)
    print(f"Batch R2: {r2.compute()}")  # 临时查看,正式代码中删除
    r2.reset()

8. 进阶主题

8.1 自定义指标

场景: 需要计算调整 R²(Adjusted R²)

python 复制代码
from torchmetrics import Metric
import torch

class AdjustedR2Score(Metric):
    """
    调整 R² 分数
    
    公式:Adjusted R² = 1 - (1 - R²) * (n - 1) / (n - p - 1)
    其中 n 是样本数,p 是特征数
    """
    def __init__(self, n_features):
        super().__init__()
        self.n_features = n_features
        
        # 定义状态变量
        self.add_state("sum_squared_residuals", 
                       default=torch.tensor(0.0), 
                       dist_reduce_fx="sum")
        self.add_state("sum_squared_total", 
                       default=torch.tensor(0.0), 
                       dist_reduce_fx="sum")
        self.add_state("num_samples", 
                       default=torch.tensor(0), 
                       dist_reduce_fx="sum")
    
    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """每个 batch 调用,累积状态"""
        # 残差平方和
        self.sum_squared_residuals += torch.sum((target - preds) ** 2)
        
        # 总平方和
        self.sum_squared_total += torch.sum((target - target.mean()) ** 2)
        
        # 样本数
        self.num_samples += target.numel()
    
    def compute(self):
        """Epoch 结束时调用,计算最终值"""
        # 计算 R²
        r2 = 1 - (self.sum_squared_residuals / self.sum_squared_total)
        
        # 计算调整 R²
        n = self.num_samples
        p = self.n_features
        adjusted_r2 = 1 - (1 - r2) * (n - 1) / (n - p - 1)
        
        return adjusted_r2

# 使用示例
metrics = MetricCollection({
    'R2Score': R2Score(),
    'AdjustedR2': AdjustedR2Score(n_features=10),  # 假设 10 个特征
})

自定义指标的关键点:

  1. ✅ 继承 Metric
  2. ✅ 使用 add_state() 定义状态变量
  3. ✅ 实现 update() 方法(累积)
  4. ✅ 实现 compute() 方法(计算)
  5. dist_reduce_fx 指定分布式训练的聚合方式

8.2 分布式训练中的指标同步

TorchMetrics 自动处理分布式训练,无需手动同步:

python 复制代码
# ✅ 自动处理(无需额外代码)
# 在多GPU训练时:
# 1. 每个GPU独立累积状态(update)
# 2. compute() 时自动跨GPU同步
# 3. 返回全局聚合结果

metrics = MetricCollection({
    'R2Score': R2Score(),  # 自动支持分布式
    'MSE': MeanSquaredError(),
})

# 在 Trainer 中启用多GPU
trainer = pl.Trainer(
    accelerator='gpu',
    devices=4,  # 4个GPU
    strategy='ddp'  # 分布式数据并行
)

内部机制:

python 复制代码
# TorchMetrics 内部的处理(简化版)
def compute(self):
    # 步骤1: 收集所有GPU的状态
    all_sum = torch.distributed.all_reduce(self.sum_squared_residuals)
    all_count = torch.distributed.all_reduce(self.num_samples)
    
    # 步骤2: 基于全局状态计算指标
    return all_sum / all_count

8.3 Metric 的序列化与恢复

python 复制代码
# 保存指标状态(用于断点续训)
checkpoint = {
    'model_state': model.state_dict(),
    'metrics_state': self.val_metrics.state_dict(),  # 保存指标状态
}
torch.save(checkpoint, 'checkpoint.ckpt')

# 恢复
checkpoint = torch.load('checkpoint.ckpt')
model.load_state_dict(checkpoint['model_state'])
self.val_metrics.load_state_dict(checkpoint['metrics_state'])  # 恢复指标状态

9. 完整示例

9.1 完整的 model_interface.py

python 复制代码
import pytorch_lightning as pl
import torch
from torch.nn import functional as F
from torchmetrics import (
    MetricCollection,
    R2Score,
    MeanSquaredError,
    MeanAbsoluteError,
    MeanAbsolutePercentageError,
)

class MInterface(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters(config)
        
        # 加载模型(省略具体实现)
        self.model = self._build_model()
        
        # ===== 配置指标 =====
        self.configure_metrics()
    
    def configure_metrics(self):
        """配置评估指标"""
        task_type = self.config['model']['task_type']
        
        if task_type == 'regression':
            metrics = MetricCollection({
                'R2Score': R2Score(),
                'MSE': MeanSquaredError(),
                'RMSE': MeanSquaredError(squared=False),
                'MAE': MeanAbsoluteError(),
                'MAPE': MeanAbsolutePercentageError(),
            })
            
            self.train_metrics = metrics.clone(prefix='train_')
            self.val_metrics = metrics.clone(prefix='val_')
            self.test_metrics = metrics.clone(prefix='test_')
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.train_metrics.update(y_hat.squeeze(), y.squeeze())
        
        return loss
    
    def on_train_epoch_end(self):
        metrics = self.train_metrics.compute()
        self.log_dict(metrics, on_epoch=True)
        self.train_metrics.reset()
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.val_metrics.update(y_hat.squeeze(), y.squeeze())
    
    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        self.log_dict(metrics, on_epoch=True, prog_bar=True)
        self.val_metrics.reset()
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        
        self.log('test_loss', loss, on_epoch=True)
        self.test_metrics.update(y_hat.squeeze(), y.squeeze())
    
    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        self.log_dict(metrics, on_epoch=True)
        
        # 打印详细结果
        self._print_test_results(metrics)
        
        self.test_metrics.reset()
    
    def _print_test_results(self, metrics):
        """打印格式化的测试结果"""
        print("\n" + "="*70)
        print(f"{'测试集评估结果':^70}")
        print("="*70)
        
        # 分类显示
        gof_metrics = {k: v for k, v in metrics.items() if 'R2' in k}
        error_metrics = {k: v for k, v in metrics.items() if 'R2' not in k}
        
        if gof_metrics:
            print(f"\n{'拟合优度指标':^70}")
            print("-"*70)
            for k, v in gof_metrics.items():
                print(f"  {k.replace('test_', ''):20s}: {v:>12.6f}")
        
        if error_metrics:
            print(f"\n{'误差指标':^70}")
            print("-"*70)
            for k, v in error_metrics.items():
                print(f"  {k.replace('test_', ''):20s}: {v:>12.6f}")
        
        print("="*70 + "\n")
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.config['model']['lr'])
    
    def _build_model(self):
        # 实际项目中的模型构建逻辑
        pass

10. 总结与速查表

10.1 核心概念速查

概念 说明 使用时机
update() 累积中间状态 每个 batch 的 *_step
compute() 计算最终指标 epoch 结束的 on_*_epoch_end
reset() 清空状态 compute() 后立即调用
MetricCollection 批量管理指标 有多个指标时
clone(prefix) 为不同阶段创建副本 train/val/test 分离

10.2 回归指标选择速查

python 复制代码
# 通用回归(推荐)
{
    'R2Score': R2Score(),                           # 拟合优度
    'RMSE': MeanSquaredError(squared=False),        # 主要误差
    'MAE': MeanAbsoluteError()                      # 鲁棒性检查
}

# 对异常值敏感的场景(金融、医疗)
{
    'R2Score': R2Score(),
    'MAE': MeanAbsoluteError(),                     # 主要指标
    'MAPE': MeanAbsolutePercentageError()
}

# 需要严格惩罚大误差
{
    'R2Score': R2Score(),
    'MSE': MeanSquaredError(),                      # 平方惩罚
    'RMSE': MeanSquaredError(squared=False)
}

10.3 命名规范速查

python 复制代码
# ✅ 推荐命名(下划线风格)
self.log('train_loss', loss)
self.log('val_loss', loss)
self.train_metrics = metrics.clone(prefix='train_')  # train_R2Score
self.val_metrics = metrics.clone(prefix='val_')      # val_R2Score

# Callback 配置
EarlyStopping(monitor='val_loss')      # 匹配
ModelCheckpoint(monitor='val_R2Score') # 匹配

10.4 常见错误速查

错误 症状 解决方案
忘记 reset() 指标值趋于稳定 on_*_epoch_end 中添加 reset()
*_stepcompute() 性能差,指标波动 只在 on_*_epoch_end 中调用
维度不匹配 RuntimeError 使用 .squeeze() 统一维度
命名不一致 Callback 找不到指标 统一使用下划线命名

10.5 最小可运行示例

python 复制代码
# 复制此代码即可运行
import pytorch_lightning as pl
import torch
from torchmetrics import MetricCollection, R2Score, MeanSquaredError

class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(10, 1)
        
        metrics = MetricCollection({
            'R2': R2Score(),
            'MSE': MeanSquaredError()
        })
        self.val_metrics = metrics.clone(prefix='val_')
    
    def forward(self, x):
        return self.model(x)
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x).squeeze()
        self.val_metrics.update(y_hat, y)
    
    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        self.log_dict(metrics)
        self.val_metrics.reset()
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

11. 参考资源

官方文档

进阶阅读

  • Metric API 设计原理
  • 自定义分布式指标
  • 高级回调函数集成

好问题!让我详细解释 self.log()self.log_dict() 的区别:

self.log和self.log_dict的区别

核心区别

python 复制代码
# self.log() - 记录单个标量值
self.log('val_loss', loss_value)

# self.log_dict() - 记录多个标量值(字典)
self.log_dict({'val_loss': loss_value, 'val_acc': acc_value})

1. 详细对比

特性 self.log() self.log_dict()
输入类型 单个键值对 字典
适用场景 记录单个指标 批量记录多个指标
调用次数 每个指标调用一次 一次调用记录所有指标
代码简洁度 多个指标时冗长 多个指标时简洁

2. 使用示例对比

场景1: 记录单个 loss

python 复制代码
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.mse_loss(y_hat, y)
    
    # ✅ 方法1: 使用 self.log()
    self.log('val_loss', loss, on_epoch=True)
    
    # ✅ 方法2: 使用 self.log_dict()
    self.log_dict({'val_loss': loss}, on_epoch=True)
    
    # 结果:两种方法完全等价

场景2: 记录多个指标(重点⭐)

python 复制代码
def on_validation_epoch_end(self):
    # 假设有多个指标
    metrics = {
        'val_R2Score': 0.856,
        'val_MSE': 0.012,
        'val_RMSE': 0.111,
        'val_MAE': 0.089,
    }
    
    # ❌ 方法1: 使用 self.log() - 繁琐
    self.log('val_R2Score', metrics['val_R2Score'])
    self.log('val_MSE', metrics['val_MSE'])
    self.log('val_RMSE', metrics['val_RMSE'])
    self.log('val_MAE', metrics['val_MAE'])
    
    # ✅ 方法2: 使用 self.log_dict() - 推荐
    self.log_dict(metrics)
    
    # 一行代码完成!

3. 与 MetricCollection 配合

这就是为什么 MetricCollectionself.log_dict() 是完美搭配:

python 复制代码
class MInterface(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
        # MetricCollection 返回字典
        metrics = MetricCollection({
            'R2Score': R2Score(),
            'MSE': MeanSquaredError(),
            'MAE': MeanAbsoluteError(),
        })
        self.val_metrics = metrics.clone(prefix='val_')
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        
        # 更新所有指标
        self.val_metrics.update(y_hat, y)
    
    def on_validation_epoch_end(self):
        # compute() 返回字典
        metrics = self.val_metrics.compute()
        # 返回格式: {'val_R2Score': 0.85, 'val_MSE': 0.12, 'val_MAE': 0.08}
        
        # ✅ 完美配合!一行搞定
        self.log_dict(metrics, on_epoch=True)
        
        # ❌ 如果用 self.log() 需要循环
        # for key, value in metrics.items():
        #     self.log(key, value, on_epoch=True)
        
        self.val_metrics.reset()

4. 参数对比

两个方法支持相同的参数,但使用方式不同:

self.log() 的参数

python 复制代码
self.log(
    name='val_loss',           # 指标名称
    value=loss,                # 指标值(标量)
    prog_bar=True,             # 是否显示在进度条
    logger=True,               # 是否记录到logger
    on_step=False,             # 是否记录每个step
    on_epoch=True,             # 是否记录每个epoch
    reduce_fx='mean',          # 聚合方式
    sync_dist=False            # 是否跨GPU同步
)

self.log_dict() 的参数

python 复制代码
self.log_dict(
    dictionary={               # 指标字典
        'val_loss': loss,
        'val_acc': acc
    },
    prog_bar=True,             # 应用到所有指标
    logger=True,
    on_step=False,
    on_epoch=True,
    reduce_fx='mean',
    sync_dist=False
)

关键区别:

  • self.log_dict() 的参数会应用到字典中的所有指标
  • 如果需要为不同指标设置不同参数,必须分别调用 self.log()

5. 实际场景选择指南

使用 self.log() 的场景

python 复制代码
# 场景1: 单个指标
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    self.log('train_loss', loss, on_step=True, on_epoch=True)
    return loss

# 场景2: 需要不同的显示设置
def validation_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    acc = self.compute_acc(batch)
    
    # loss 显示在进度条
    self.log('val_loss', loss, prog_bar=True)
    
    # acc 不显示在进度条
    self.log('val_acc', acc, prog_bar=False)

使用 self.log_dict() 的场景

python 复制代码
# 场景1: MetricCollection(最常见)
def on_validation_epoch_end(self):
    metrics = self.val_metrics.compute()
    self.log_dict(metrics, on_epoch=True)
    self.val_metrics.reset()

# 场景2: 多个相同配置的指标
def validation_step(self, batch, batch_idx):
    results = {
        'val_loss': loss,
        'val_mse': mse,
        'val_mae': mae
    }
    # 所有指标都显示在进度条
    self.log_dict(results, prog_bar=True)

# 场景3: 手动构建指标字典
def test_step(self, batch, batch_idx):
    metrics = {
        'test_r2': self.compute_r2(),
        'test_mse': self.compute_mse(),
        'test_mae': self.compute_mae()
    }
    self.log_dict(metrics)

6. 混合使用示例

在实际项目中,通常会混合使用:

python 复制代码
class MInterface(pl.LightningModule):
    def __init__(self):
        super().__init__()
        metrics = MetricCollection({
            'R2Score': R2Score(),
            'MSE': MeanSquaredError(),
            'MAE': MeanAbsoluteError(),
        })
        self.val_metrics = metrics.clone(prefix='val_')
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        
        # ✅ 单独记录 loss(需要在进度条显示)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        
        # 更新指标集合
        self.val_metrics.update(y_hat, y)
    
    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        
        # ✅ 批量记录指标(不显示在进度条)
        self.log_dict(metrics, on_epoch=True, prog_bar=False)
        
        # 或者:选择性显示某个指标
        self.log('val_R2Score', metrics['val_R2Score'], prog_bar=True)
        
        self.val_metrics.reset()

7. 常见误区

误区1: log_dict 可以记录嵌套字典

python 复制代码
# ❌ 错误:log_dict 不支持嵌套字典
metrics = {
    'validation': {
        'loss': 0.1,
        'acc': 0.9
    }
}
self.log_dict(metrics)  # ❌ 报错!

# ✅ 正确:扁平化字典
metrics = {
    'val_loss': 0.1,
    'val_acc': 0.9
}
self.log_dict(metrics)  # ✅ OK

误区2: log_dict 中的值必须是标量

python 复制代码
# ❌ 错误:值必须是标量,不能是张量
metrics = {
    'val_preds': torch.tensor([1, 2, 3])  # ❌ 不是标量
}
self.log_dict(metrics)  # ❌ 报错!

# ✅ 正确:转换为标量
metrics = {
    'val_mean_pred': torch.tensor([1, 2, 3]).mean()  # ✅ 标量
}
self.log_dict(metrics)  # ✅ OK

8. 性能对比

python 复制代码
import time

# 性能测试:记录100个指标
metrics = {f'metric_{i}': i * 0.01 for i in range(100)}

# 方法1: 循环调用 self.log()
start = time.time()
for key, value in metrics.items():
    self.log(key, value)
time1 = time.time() - start

# 方法2: 一次调用 self.log_dict()
start = time.time()
self.log_dict(metrics)
time2 = time.time() - start

# 结果:self.log_dict() 通常更快(内部优化)
print(f"self.log() 循环: {time1:.4f}s")
print(f"self.log_dict(): {time2:.4f}s")

总结速查表

特性 self.log() self.log_dict()
输入 name, value dictionary
适用场景 1-2个指标 3+个指标
与MetricCollection 需要循环 ✅ 完美配合
进度条个性化 ✅ 支持 ❌ 统一设置
代码简洁度 多指标时差 ✅ 简洁
性能 多次调用 ✅ 一次调用

推荐做法:

python 复制代码
# loss - 使用 self.log()(需要特殊配置)
self.log('val_loss', loss, prog_bar=True)

# 多个指标 - 使用 self.log_dict()(简洁)
self.log_dict(self.val_metrics.compute())
相关推荐
苛子2 小时前
谷云科技发布API × AI 战略是什么?
大数据·人工智能
GEO AI搜索优化助手2 小时前
数据共振:GEO与SEO的算法协同与智能决策系统
人工智能·算法·搜索引擎·生成式引擎优化·ai优化·geo搜索优化
3824278272 小时前
python:selenium,CSS位置偏移反爬案例
css·python·selenium
我可以将你更新哟2 小时前
【PyQT-4】QListWidget列表控件、QComboBox下拉列表控件、QTableWidget表格控件
开发语言·python·pyqt
七夜zippoe2 小时前
Python上下文管理器与with语句深度应用:从入门到企业级实战
python·异常处理·with·contextlib·exitstack
TheSumSt2 小时前
Python丨课程笔记Part1:Python基础入门部分
开发语言·笔记·python·学习方法
张彦峰ZYF2 小时前
持续改进 RAG 应用效果:从“能用”到“好用”的系统化方法
人工智能·rag·rag进阶
yumgpkpm2 小时前
Cloudera CDP 7.3(国产CMP 鲲鹏版)平台与银行五大平台的技术对接方案
大数据·人工智能·hive·zookeeper·flink·kafka·cloudera
亚里仕多德2 小时前
启航-泽木鸟家居:打造未来之家
大数据·人工智能