【YOLOv11 小目标检测改进】NWD 归一化 Wasserstein 距离损失函数训练详解

摘要:本文介绍 NWD(Normalized Wasserstein Distance)损失函数在 YOLOv11 中的集成方法。NWD 将边界框建模为 2D 高斯分布,用 Wasserstein 距离替代传统 IoU 衡量检测框相似度,有效解决了 IoU 对小目标位置偏移过度敏感的问题。实验表明,NWD + CIoU 混合损失可在小目标检测场景获得稳定提升。


一、YOLOv11 改进目录

本文属于 YOLOv11 损失函数改进系列 之一:

序号 改进方向 内容 适用场景
1 检测头改进 添加 P2 层小目标检测头 超小目标(<16×16像素)
2 下采样改进 SPDConv 空间转深度卷积 小目标特征保留
3 损失函数改进 Focal Loss 替代 BCE 分类损失 类别不均衡
4 损失函数改进 NWD Loss 替代/混合 CIoU 小目标回归精度提升
5 训练策略 Multi-scale training / Mosaic 优化 小目标泛化性

二、NWD 原理介绍

2.1 为什么 IoU 对小目标不公平?

传统目标检测使用 CIoU(Complete IoU) 作为边界框回归损失。对于大目标,IoU 表现良好;但对于小目标,IoU 存在两个致命缺陷:

问题一:小偏移引起大落差

假设一个 6×6 像素的小目标,仅偏移 3 个像素:

python 复制代码
中心点偏移 3px:
├─ 大目标 (200×200): IoU 从 0.97 → 0.94  (下降 3%)
└─ 小目标 (6×6):    IoU 从 0.53 → 0.06  (下降 89%!)

同样的位置偏移,对小目标的 IoU 影响是大目标的 30 倍

问题二:无重叠区域时梯度为零

当预测框与真实框没有交集时,IoU = 0,梯度也为零,模型无法判断"应该往哪个方向移动"。

2.2 NWD 核心思想:把框看作高斯分布

NWD 的出发点很简单:不要用硬边界(重叠/不重叠),而是用概率分布来建模边界框

一个边界框 (cx, cy, w, h) 被建模为一个 二维高斯分布

python 复制代码
N(μ, Σ)

其中:
  均值 μ = [cx, cy]ᵀ           ← 框的中心
  协方差 Σ = [[(w/2)², 0],
              [0, (h/2)²]]     ← 框的尺寸决定分布"胖瘦"

直观理解:边界框不是"有或无"的硬区域,而是一个中心最热、边缘衰减的概率分布

2.3 数学公式

第一步:计算 2-Wasserstein 距离

对于两个高斯分布,2-Wasserstein 距离有优雅的闭式解:

python 复制代码
W₂²(N_a, N_b) = ||μ_a - μ_b||²₂ + ||Σ_a^{1/2} - Σ_b^{1/2}||²_F

展开得:
W₂² = (cx₁-cx₂)² + (cy₁-cy₂)² + (w₁/2-w₂/2)² + (h₁/2-h₂/2)²
    = ||(cx₁, cy₁, w₁/2, h₁/2) - (cx₂, cy₂, w₂/2, h₂/2)||²

Wasserstein 距离度量的是将一个分布"搬运"到另一个分布所需的最小能量。即使两个框不重叠,也能产生有意义的距离值。

第二步:归一化为相似度

Matlab 复制代码
NWD(N_a, N_b) = exp(-√W₂² / C)

其中 C 是与数据集相关的常数(典型值 ~12.8,反映数据集中物体的平均尺度)。

第三步:转换为损失函数

Matlab 复制代码
L_NWD = 1 - NWD      (NWD ∈ (0, 1] 为相似度)

混合损失(推荐):

Matlab 复制代码
L_box = (1 - α) × L_CIoU + α × L_NWD

其中 α = nwd_ratio,建议取 0.3~0.5。

2.4 直观对比

三、适用场景

3.1 推荐使用 NWD 的场景

场景 推荐度 说明
小/微小目标检测 ⭐⭐⭐⭐⭐ NWD 核心优势场景
无人机/卫星图像 ⭐⭐⭐⭐⭐ VisDrone、AI-TOD、DOTA 等数据集
工业缺陷检测 ⭐⭐⭐⭐ 裂纹、凹坑等小缺陷
交通标志/行人 ⭐⭐⭐⭐ KITTI、CityPersons 中的小目标
医学图像 ⭐⭐⭐⭐ 细胞检测、病灶检测
遥感目标检测 ⭐⭐⭐⭐ 舰船、飞机等小目标
一般目标检测(COCO) ⭐⭐⭐ 小目标有提升,中大目标持平
大目标为主的数据集 ⭐⭐ 优势不明显,纯 CIoU 可能更好

3.2 判断标准

如果数据集中 大量目标的面积 < 图像面积的 1%(例如在 640×640 图像中 < 32×32 像素),建议使用 NWD。

四、论文与代码链接

资源 链接
论文 A Normalized Gaussian Wasserstein Distance for Tiny Object Detection
官方代码 (MMDetection) github.com/jwwangchn/NWD
本实现 (YOLOv11) ultralytics/utils/metrics.pyultralytics/utils/loss.py
训练脚本 train_nwd.py

五、详细实现步骤

5.1 修改 ultralytics/utils/metrics.py

bbox_iou 函数后添加 wasserstein_loss 函数:

Matlab 复制代码
def wasserstein_loss(
    pred_boxes: torch.Tensor,
    target_boxes: torch.Tensor,
    xywh: bool = False,
    constant: float = 12.8,
    eps: float = 1e-7,
) -> torch.Tensor:
    """Compute Normalized Wasserstein Distance (NWD) loss for bounding boxes."""
    if xywh:
        cx1, cy1, w1, h1 = pred_boxes.chunk(4, -1)
        cx2, cy2, w2, h2 = target_boxes.chunk(4, -1)
    else:
        # xyxy → cx, cy, w, h
        x1, y1, x2, y2 = pred_boxes.chunk(4, -1)
        cx1, cy1 = (x1 + x2) / 2, (y1 + y2) / 2
        w1, h1 = (x2 - x1).clamp(eps), (y2 - y1).clamp(eps)
        x1, y1, x2, y2 = target_boxes.chunk(4, -1)
        cx2, cy2 = (x1 + x2) / 2, (y1 + y2) / 2
        w2, h2 = (x2 - x1).clamp(eps), (y2 - y1).clamp(eps)

    # 2-Wasserstein distance
    w_dist = (
        (cx1 - cx2).pow(2) + (cy1 - cy2).pow(2)
        + (w1 / 2 - w2 / 2).pow(2) + (h1 / 2 - h2 / 2).pow(2)
    )

    # Normalized Wasserstein Distance
    nwd = torch.exp(-torch.sqrt(w_dist + eps) / constant)

    return 1.0 - nwd

5.2 修改 ultralytics/utils/loss.py

步骤 2.1 --- 导入 wasserstein_loss

Matlab 复制代码
from .metrics import bbox_iou, probiou, wasserstein_loss  # 添加 wasserstein_loss

步骤 2.2 --- 修改 BboxLoss.__init__,增加 nwd_ratio 参数:

Matlab 复制代码
class BboxLoss(nn.Module):
    def __init__(self, reg_max: int = 16, nwd_ratio: float = 0.0, nwd_constant: float = 12.8):
        super().__init__()
        self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
        self.nwd_ratio = nwd_ratio
        self.nwd_constant = nwd_constant

步骤 2.3 --- 修改 BboxLoss.forward,混合 NWD 损失:

Matlab 复制代码
# 原有 CIoU 损失
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

# 新增:NWD 损失混合
if self.nwd_ratio > 0:
    nwd_loss = wasserstein_loss(
        pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, constant=self.nwd_constant
    )
    loss_iou = (1.0 - self.nwd_ratio) * loss_iou + \
               self.nwd_ratio * ((nwd_loss * weight.squeeze(-1)).sum() / target_scores_sum)

步骤 2.4 --- 修改 v8DetectionLoss.__init__,传递 nwd_ratio

Matlab 复制代码
self.bbox_loss = BboxLoss(m.reg_max, nwd_ratio=getattr(h, "nwd_ratio", 0.0)).to(device)

5.3 修改 ultralytics/cfg/default.yaml

添加新的超参数:

Matlab 复制代码
nwd_ratio: 0.0   # (float) NWD loss mixing ratio. 0.0 = pure CIoU, 1.0 = pure NWD

5.4 训练

Matlab 复制代码
# 基线训练(纯 CIoU)
yolo detect train data=VOC_YOLO/data.yaml model=yolo11n.yaml nwd_ratio=0.0

# NWD 混合训练(推荐 nwd_ratio=0.3~0.5)
yolo detect train data=VOC_YOLO/data.yaml model=yolo11n.yaml nwd_ratio=0.5

# 或使用训练脚本
python train_nwd.py

5.5 结果对比

训练完成后,对比两个实验的关键指标:

实验 nwd_ratio 说明
baseline 0.0 纯 CIoU(对照)
nwd_exp 0.5 CIoU + NWD 混合

关注指标

  • mAP@50mAP@50-95:整体精度
  • AP_small:小目标精度的提升
  • • 小目标类别(如 Dent、Crack)的逐类 AP

六、超参数调优建议

参数 说明 推荐值
nwd_ratio NWD 混合比例 小目标多 → 0.50.7;均衡 → 0.30.5
nwd_constant (C) 归一化常数 默认 12.8(AI-TOD);可根据数据集物体平均尺度调整
box 框损失权重 使用 NWD 时可能需要微调,建议从默认 7.5 开始

进阶建议

    1. NWD + P2 检测头:配合 P2 层使用,P2 层输出 160×160 特征图,对小目标更敏感,与 NWD 形成互补
    1. NWD + Focal Loss:分类用 Focal Loss 处理类别不均衡,回归用 NWD 提升小目标精度
    1. NWD 替换 Label Assignment:除了作为回归损失,NWD 还可以替代 IoU 用于标签分配(TaskAlignedAssigner),实现更合理的正负样本匹配

参考资料

相关推荐
狗哥哥1 天前
知乎回答二次创作转AI 漫画/视频思路分享
人工智能
极速蜗牛1 天前
我在 Taro 小程序项目里实践的 API First + AI 编程方式
前端·人工智能·后端
桜吹雪1 天前
所有智能体架构(3):Planning(计划任务)
javascript·人工智能·langchain
武子康1 天前
调查研究-176 taste-skill:AI 编程时代,前端开发最缺的不是代码,而是品味
人工智能·openai·claude
码语智行1 天前
工具调用MCP_Server 开发梳理
人工智能
lili00121 天前
2026 企业 AI 选型新范式:OpenRouter Fusion 证明多模型融合性价比远超单模型,企业该如何重构技术栈? - 微元算力(weytoken)
java·人工智能·python·重构·ai编程
shushangyun_1 天前
汽车服务行业B2B平台+AI解决方案哪家专业:2026年最新测评
java·运维·网络·数据库·人工智能·汽车
A.说学逗唱的Coke1 天前
【大模型专题】Spring AI Alibaba × Skill 整合实战:让 AI 真正“会干活
java·人工智能·spring
米小虾1 天前
AI Agent 记忆系统:从对话记录到认知架构
人工智能·agent
-山中问答-1 天前
【智能体工具使用实战08】实战项目:代码仓库健康度分析Agent
人工智能·智能体·工具调用·工程实战