loss.py
utils\loss.py
目录
[2.def smooth_BCE(eps=0.1):](#2.def smooth_BCE(eps=0.1):)
[3.class BCEBlurWithLogitsLoss(nn.Module):](#3.class BCEBlurWithLogitsLoss(nn.Module):)
[4.class FocalLoss(nn.Module):](#4.class FocalLoss(nn.Module):)
[5.class QFocalLoss(nn.Module):](#5.class QFocalLoss(nn.Module):)
[6.class ComputeLoss:](#6.class ComputeLoss:)
[7.class ComputeLoss_NEW:](#7.class ComputeLoss_NEW:)
1.所需的库和模块
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.metrics import bbox_iou
from utils.torch_utils import de_parallel
2.def smooth_BCE(eps=0.1):
python
# 这段代码定义了一个名为 smooth_BCE 的函数,它用于计算平滑的二元交叉熵(Binary Cross-Entropy, BCE)损失函数的两个参数。在机器学习中,二元交叉熵损失函数常用于二分类问题,它衡量的是模型预测的概率分布与真实标签之间的差异。
# 这行代码定义了一个名为 smooth_BCE 的函数,它接受一个参数。
# eps :默认值为 0.1 。 eps 是一个很小的正数,用于平滑损失函数,防止在计算过程中出现数值不稳定的问题,比如除以零。
def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
# return positive, negative label smoothing BCE targets
# 在原始的二元交叉熵损失函数中,正例的标签是 1 ,负例的标签是 0 。但是,为了避免在概率接近 0 或 1 时的数值不稳定问题,通常会对标签进行平滑处理,使得标签值介于 0 和 1 之间。这样做可以提高模型的泛化能力,避免在训练过程中出现过拟合。
# 1.0 - 0.5 * eps :这是平滑后的正例标签值。由于 eps 是一个很小的正数,所以这个值会非常接近 1 ,但不是完全等于 1 。
# 0.5 * eps :这是平滑后的负例标签值。同样,由于 eps 很小,这个值也会非常接近 0 ,但不是完全等于 0 。
return 1.0 - 0.5 * eps, 0.5 * eps
# smooth_BCE 函数是一个辅助函数,用于生成平滑二元交叉熵损失函数的两个参数。通过引入一个小的正数 eps ,可以避免在预测值接近 0 或 1 时梯度消失的问题,从而提高模型的训练稳定性。函数返回的两个值分别对应于平滑后的损失函数中的两个关键点,即当预测值为 0 和 1 时的损失值。
3.class BCEBlurWithLogitsLoss(nn.Module):
python
# 这段代码定义了一个名为 BCEBlurWithLogitsLoss 的 PyTorch 自定义损失函数类。这个类继承自 nn.Module ,是 PyTorch 中所有神经网络模块的基类。这个损失函数是 BCEWithLogitsLoss 的一个变种,它通过引入一个参数 alpha 来减少标签缺失的影响。
class BCEBlurWithLogitsLoss(nn.Module):
# BCEwithLogitLoss() with reduced missing label effects. BCEwithLogitLoss() 减少了缺失标签的影响。
# 这段代码是 BCEBlurWithLogitsLoss 类的构造函数( __init__ 方法),它负责初始化类的实例。
# 定义了类的构造函数,它接受一个参数。
# 1.alpha :该参数有一个默认值 0.05 ,这个参数将在类的实例化时被传递,并且用于后续的损失计算中,以调整损失值,减少标签缺失的影响。
def __init__(self, alpha=0.05):
# 调用了父类 nn.Module 的构造函数。在 Python 中,当创建一个类的子类时,通常需要调用父类的构造函数以确保父类被正确初始化。 super() 函数用于调用父类的方法。
super().__init__()
# 创建了一个 BCEWithLogitsLoss 损失函数的实例,并将其赋值给 self.loss_fcn 属性。 BCEWithLogitsLoss 是 PyTorch 中的一个损失函数,它结合了二元交叉熵损失和 Sigmoid 激活函数,适用于二元分类问题。参数 reduction='none' 指定了损失函数的约简(reduction)方式,这里设置为 'none' 表示不对损失进行任何约简,即不对损失值进行求和或平均,而是保留每个样本的损失值。
self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
# 将传入的 alpha 参数值保存到类的 self.alpha 属性中。 alpha 将在类的 forward 方法中使用,用于计算调整因子,以减少标签缺失的影响。
self.alpha = alpha
# 这个构造函数初始化了 BCEBlurWithLogitsLoss 类的两个主要属性。一个 BCEWithLogitsLoss 损失函数实例和一个 alpha 参数。这些属性将在类的其他方法中使用,特别是在 forward 方法中,用于计算和返回调整后的损失值。
# 这段代码是 BCEBlurWithLogitsLoss 类的 forward 方法,它是 PyTorch 模型中进行前向传播计算的核心方法。 forward 方法定义了如何根据输入的预测值 pred 和真实标签 true 来计算损失。
# 定义了 forward 方法,它接受两个参数。
# 1.pred :模型的预测输出,即 logits。
# 2.true :对应的真实标签。
def forward(self, pred, true):
# 使用在 __init__ 方法中初始化的 BCEWithLogitsLoss 损失函数计算 预测值 和 真实标签 之间的原始损失。
loss = self.loss_fcn(pred, true)
# 将预测的 logits 值通过 Sigmoid 函数转换为概率值。Sigmoid 函数的输出范围在 0 到 1 之间,表示预测为正类的概率。
pred = torch.sigmoid(pred) # prob from logits
# 计算 预测概率 和 真实标签 之间的差异。这个差异 dx 用于后续的调整因子计算,目的是减少只有标签缺失时的影响。
dx = pred - true # reduce only missing label effects
# 这是一行被注释掉的代码,如果取消注释,将会计算 预测概率 和 真实标签 之间的绝对差异,这样可以同时减少 标签缺失 和 错误标签 的影响。
# dx = (pred - true).abs() # reduce missing label and false label effects
# 计算一个调整因子 alpha_factor 。这个因子基于 dx 和类属性 self.alpha 计算得出。这里的公式是一个指数衰减函数,用于减少 dx 接近 1 时的损失影响,即减少当预测概率接近 1 但真实标签为 0 时(或反之)的损失惩罚。
alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
# 将原始损失 loss 乘以调整因子 alpha_factor ,以此来调整每个样本的损失值。
loss *= alpha_factor
# 返回所有样本损失值的平均值作为最终的损失输出。
return loss.mean()
# 这个方法的目的是减少在训练过程中由于标签缺失导致的影响,通过引入 alpha 参数和计算 alpha_factor 来调整损失值,使得模型在面对不准确的或缺失的标签时更加鲁棒。这种方法特别适用于那些标签可能不完整或不准确的数据集。
# 这个自定义损失函数的目的是在训练过程中减少由于标签缺失造成的影响,通过引入一个可调参数 alpha 来实现这一点。这种类型的损失函数可能在处理不平衡数据集或者标签不完整的情况时特别有用。
4.class FocalLoss(nn.Module):
python
# 这段代码定义了一个名为 FocalLoss 的 PyTorch 自定义损失函数类。这个类继承自 nn.Module ,并且封装了焦点损失(Focal Loss),它是一种用于解决类别不平衡问题的损失函数。焦点损失通过增加难以分类样本的权重和减少容易分类样本的权重来实现这一点。
# 定义了一个名为 FocalLoss 的类,它继承自 nn.Module 。
class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
# 这是类的构造函数,它接受三个参数。
# 1.loss_fcn :一个损失函数,例如 nn.BCEWithLogitsLoss。
# 2.gamma :调节 易分类 和 难分类 样本权重的平衡参数,默认值为 1.5。
# 3.alpha :平衡 正负样本 权重的参数,默认值为 0.25 。
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
# 调用父类的构造函数。
super().__init__()
# 将传入的损失函数赋值给 self.loss_fcn 属性。
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
# 分别将 gamma 和 alpha 参数保存为类的属性。
self.gamma = gamma
self.alpha = alpha
# 保存原始损失函数的约简方式。
self.reduction = loss_fcn.reduction
# 设置 self.loss_fcn 的约简方式为 'none' ,这是必要的,因为焦点损失需要对每个元素单独应用。
self.loss_fcn.reduction = 'none' # required to apply FL to each element
# 这段代码是 FocalLoss 类的 forward 方法,它实现了焦点损失(Focal Loss)的计算逻辑。焦点损失是一种专门设计来解决类别不平衡问题,特别是在目标检测和分类任务中常用的损失函数。
# 定义了 forward 方法,它接受两个参数。
# 1.pred :模型的预测输出,即 logits 。
# 2.true :对应的真实标签。
def forward(self, pred, true):
# 使用 self.loss_fcn (在构造函数中传入的损失函数)计算预测值和真实标签之间的原始损失。
loss = self.loss_fcn(pred, true)
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
# 将预测的 logits 值通过 Sigmoid 函数转换为概率值。
pred_prob = torch.sigmoid(pred) # prob from logits
# 计算 p_t ,这是一个介于 0 和 1 之间的值,表示每个样本的预测概率,其中 true 为 1 表示正样本, true 为 0 表示负样本。
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
# 计算 alpha_factor ,这是一个权重因子,用于平衡正负样本的权重。对于正样本,权重为 self.alpha ;对于负样本,权重为 1 - self.alpha 。
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
# 计算 modulating_factor ,这是一个调节因子,用于增加难分类样本的权重和减少易分类样本的权重。 self.gamma 是一个大于 0 的值,通常用来控制调节因子的强度。
modulating_factor = (1.0 - p_t) ** self.gamma
# 将原始损失乘以 alpha_factor 和 modulating_factor ,以应用焦点损失。这种乘法操作使得模型在训练时更加关注那些难以分类的样本(即预测概率接近于 0.5 的样本)。
loss *= alpha_factor * modulating_factor
# 根据 self.reduction 的值,返回损失的平均值、总和或不进行约简的损失值。这是在构造函数中设置的,可以是 'mean' 、 'sum' 或 'none' 。
if self.reduction == 'mean':
# 如果 self.reduction 设置为 'mean' ,则返回所有样本损失值的平均值。
return loss.mean()
elif self.reduction == 'sum':
# 如果 self.reduction 设置为 'sum' ,则返回所有样本损失值的总和。
return loss.sum()
else: # 'none'
# 如果 self.reduction 设置为 'none' ,则返回不进行约简的损失值。
return loss
# 这个方法的目的是通过对损失函数进行加权,使得模型在训练过程中更加关注那些难以分类的样本,从而提高模型对少数类别的识别能力。这种方法特别适用于那些类别分布不均匀的数据集。
# 这个 FocalLoss 类可以与任何 PyTorch 损失函数一起使用,通过调整 gamma 和 alpha 参数,可以在训练过程中对难分类样本给予更多的关注,从而提高模型对少数类别的识别能力。
5.class QFocalLoss(nn.Module):
python
# 这段代码定义了一个名为 QFocalLoss 的 PyTorch 自定义损失函数类,它是焦点损失(Focal Loss)的一个变种,专门用于处理类别不平衡问题,并且在计算调节因子时考虑了预测概率与真实标签之间的差异。
# 定义了一个名为 QFocalLoss 的类,它继承自 nn.Module 。
class QFocalLoss(nn.Module):
# Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
# 这是类的构造函数,它接受三个参数。
# 1.loss_fcn :一个损失函数,例如 nn.BCEWithLogitsLoss。
# 2.gamma :调节易分类和难分类样本权重的平衡参数,默认值为 1.5。
# 3.alpha :平衡正负样本权重的参数,默认值为 0.25 。
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
# 调用父类的构造函数。
super().__init__()
# 将传入的损失函数赋值给 self.loss_fcn 属性。
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
# 分别将 gamma 和 alpha 参数保存为类的属性。
self.gamma = gamma
self.alpha = alpha
# 保存原始损失函数的约简方式。
self.reduction = loss_fcn.reduction
# 设置 self.loss_fcn 的约简方式为 'none' ,这是必要的,因为焦点损失需要对每个元素单独应用。
self.loss_fcn.reduction = 'none' # required to apply FL to each element
# 这段代码是 QFocalLoss 类的 forward 方法,它实现了质量焦点损失(Quality Focal Loss)的计算逻辑。质量焦点损失是焦点损失的一个变种,它通过考虑 预测概率 与 真实标签 之间的差异来调整损失值,从而更加关注那些难以分类的样本。
# 定义了 forward 方法,它接受两个参数
# 1.pred :模型的预测输出,即 logits 。
# 2.true :对应的真实标签。
def forward(self, pred, true):
# 使用 self.loss_fcn (在构造函数中传入的损失函数)计算预测值和真实标签之间的原始损失。
loss = self.loss_fcn(pred, true)
# 将预测的 logits 值通过 Sigmoid 函数转换为概率值。
pred_prob = torch.sigmoid(pred) # prob from logits
# 计算 alpha_factor ,这是一个权重因子,用于平衡正负样本的权重。对于正样本,权重为 self.alpha ;对于负样本,权重为 1 - self.alpha 。
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
# 计算 modulating_factor ,这是一个调节因子,用于增加 预测概率 与 真实标签 差异较大的样本的权重。这里使用绝对值来计算差异,并将其提升到 self.gamma 次幂。
modulating_factor = torch.abs(true - pred_prob) ** self.gamma
# 将 原始损失 乘以 alpha_factor 和 modulating_factor ,以应用质量焦点损失。这种乘法操作使得模型在训练时更加关注那些 预测概率 与 真实标签 差异较大的样本。
loss *= alpha_factor * modulating_factor
# 根据 self.reduction 的值,返回损失的平均值、总和或不进行约简的损失值。这是在构造函数中设置的,可以是 'mean' 、 'sum' 或 'none' 。
if self.reduction == 'mean':
# 如果 self.reduction 设置为 'mean' ,则返回所有样本损失值的平均值。
return loss.mean()
elif self.reduction == 'sum':
# 如果 self.reduction 设置为 'sum' ,则返回所有样本损失值的总和。
return loss.sum()
else: # 'none'
# 如果 self.reduction 设置为 'none' ,则返回不进行约简的损失值。
return loss
# 这个方法的目的是通过对损失函数进行加权,使得模型在训练过程中更加关注那些难以分类的样本,从而提高模型对少数类别的识别能力。这种方法特别适用于那些类别分布不均匀的数据集。通过引入 预测概率 与 真实标签 之间的差异作为调节因子,质量焦点损失进一步增强了模型对难以分类样本的关注,提高了模型的分类性能。
# QFocalLoss 类通过引入预测概率与真实标签之间的差异作为调节因子,进一步增强了焦点损失的功能,使得模型在训练过程中更加关注那些难以分类且预测概率与真实标签差异较大的样本,从而提高模型的分类性能。
6.class ComputeLoss:
python
# 这段代码定义了一个名为 ComputeLoss 的类,它用于计算目标检测模型的损失。这个类结合了多种损失函数,包括边界框回归损失、目标性损失(objectness loss)和分类损失。
# 定义了一个名为 ComputeLoss 的新类。
class ComputeLoss:
# 这是一个类属性,用于标记是否需要根据目标的 IoU(交并比)对目标性损失进行排序。默认值为 False 。
sort_obj_iou = False
# Compute losses
# 这段代码是 ComputeLoss 类的构造函数 __init__ ,它初始化了用于计算目标检测模型损失的各种组件和参数。
# ComputeLoss 类的构造函数,它接受两个参数。
# 1.model :一个目标检测模型。
# 2.autobalance :一个布尔值,用于自动平衡不同层的损失,默认为 False 。
def __init__(self, model, autobalance=False):
# 通过访问模型的参数来确定模型运行的设备(CPU或GPU),并将其保存在 device 变量中。
device = next(model.parameters()).device # get model device
# 获取模型的超参数,这些超参数通常包含了训练过程中需要的各种设置,如学习率、优化器设置等。
h = model.hyp # hyperparameters
# Define criteria
# 定义了一个用于 分类损失 的二元交叉熵损失函数 BCEcls ,其中 pos_weight 参数用于平衡 正负样本 ,其值从超参数 h 中的 cls_pw 获取。
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
# 定义了一个用于 目标性损失 的二元交叉熵损失函数 BCEobj ,其中 pos_weight 参数用于平衡 正负样本 ,其值从超参数 h 中的 obj_pw 获取。
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
# 使用 smooth_BCE 函数创建标签平滑的正负标签目标, eps 参数从超参数 h 中的 label_smoothing 获取,用于减少模型对于类别的置信度,防止过拟合。
# def smooth_BCE(eps=0.1): -> 用于生成平滑二元交叉熵损失函数的两个参数。通过引入一个小的正数 eps ,可以避免在预测值接近 0 或 1 时梯度消失的问题,从而提高模型的训练稳定性。 -> return 1.0 - 0.5 * eps, 0.5 * eps
self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
# Focal loss
# 从超参数 h 中获取焦点损失的 gamma 参数,用于调节焦点损失的强度。
g = h['fl_gamma'] # focal loss gamma
# 如果 gamma 大于0,则应用焦点损失。
if g > 0:
# 对分类损失和目标性损失应用焦点损失,以提高模型对于难分类样本的关注。
# class FocalLoss(nn.Module):
# -> 用于解决类别不平衡问题的损失函数。焦点损失通过增加难以分类样本的权重和减少容易分类样本的权重来实现这一点。
# -> def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
# 通过 de_parallel 函数获取模型的最后一层(检测模块),这通常是目标检测模型中负责生成最终预测结果的部分。
# def de_parallel(model):
# -> 将一个可能处于并行状态(例如使用 PyTorch 的 DataParallel 或 DistributedDataParallel 包装过的模型)转换回单个 GPU 或 CPU 上的模型。如果 model 是并行模型, 将返回原始的、未并行化的模型 model.module 。如果 model 不是并行模型,直接返回 model 。
# -> return model.module if is_parallel(model) else model
m = de_parallel(model).model[-1] # Detect() module
# 根据模型的层数 m.nl 设置不同层的损失权重,如果没有匹配的层数,则使用默认的权重列表。
self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
# 如果启用自动平衡,则找到模型中stride为16的层的索引,否则设置为0。
self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index
# 保存 分类损失函数 、 目标性损失函数 、 IoU比例因子 、 超参数 和 自动平衡标志 。
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
# 保存模型的 类别数量 。
self.nc = m.nc # number of classes
# 保存模型的 层数 。
self.nl = m.nl # number of layers
# 保存模型使用的 锚点 ( anchor )。
self.anchors = m.anchors
# 保存模型运行的 设备 。
self.device = device
# 这个构造函数初始化了 ComputeLoss 类的所有必要组件,以便在训练过程中计算和返回目标检测模型的损失。
# 这段代码是 ComputeLoss 类的 __call__ 方法,它是类的主方法,用于计算给定预测 p 和目标 targets 的损失。
# 定义 __call__ 方法,它接受两个参数。
# 1.p :预测结果。
# 2.targets :目标。
def __call__(self, p, targets): # predictions, targets
# 获取 批量大小 ,即当前批次中的图像数量。
bs = p[0].shape[0] # batch size
# 初始化一个包含三个元素的张量,用于存储边 界框损失 、 目标性损失 和 分类损失 ,并将它们初始化为0。
loss = torch.zeros(3, device=self.device) # [box, obj, cls] losses
# 调用 build_targets 方法构建目标,这将返回 分类目标 tcls 、 边界框目标 tbox 和 索引 indices 。
tcls, tbox, indices = self.build_targets(p, targets) # targets
# 这段代码是 ComputeLoss 类的 __call__ 方法中的一部分,它负责计算目标检测模型中每一层的损失。
# Losses
# 开始一个循环,遍历模型输出的所有层的预测结果 p 。 i 是层的索引, pi 是第 i 层的预测结果。
for i, pi in enumerate(p): # layer index, layer predictions
# 将 indices 列表中第 i 个元素(即第 i 层的索引信息)解包到三个变量中。
# b :图像索引,表示哪些图像包含目标。
# gj :网格的 y 坐标索引,表示锚点在特征图的 y 轴位置。
# gi :网格的 x 坐标索引,表示锚点在特征图的 x 轴位置。
b, gj, gi = indices[i] # image, anchor, gridy, gridx
# 初始化一个形状为 (pi.shape[0], pi.shape[2], pi.shape[3]) 的张量 tobj ,用于存储目标性损失的目标值(即 IoU 比率),并将其数据类型和设备设置与 pi 相同。
tobj = torch.zeros((pi.shape[0], pi.shape[2], pi.shape[3]), dtype=pi.dtype, device=self.device) # tgt obj
# 获取当前层中 标签的数量 ,即当前批次中目标的数量。
n_labels = b.shape[0] # number of labels
# 如果 n_labels 大于0,即如果当前层有目标,则继续计算损失。
if n_labels:
# pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0
# 将第 i 层的预测结果 pi 分割成四个部分 :边界框的 x 和 y 坐标 pxy 、宽度和高度 pwh 、目标性分数(通常用于表示是否为目标的置信度)和分类概率 pcls 。
pxy, pwh, _, pcls = pi[b, :, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions
# Regression
# pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
# pwh = (0.0 + (pwh - 1.09861).sigmoid() * 4) * anchors[i]
# pwh = (0.33333 + (pwh - 1.09861).sigmoid() * 2.66667) * anchors[i]
# pwh = (0.25 + (pwh - 1.38629).sigmoid() * 3.75) * anchors[i]
# pwh = (0.20 + (pwh - 1.60944).sigmoid() * 4.8) * anchors[i]
# pwh = (0.16667 + (pwh - 1.79175).sigmoid() * 5.83333) * anchors[i]
# 对边界框的 x 和 y 坐标进行 Sigmoid 激活函数处理,并进行缩放和偏移,以将预测的坐标转换到 原始图像空间 。
pxy = pxy.sigmoid() * 1.6 - 0.3
# 对边界框的宽度和高度进行 Sigmoid 激活函数处理,并进行缩放,然后乘以对应的锚点尺寸 self.anchors[i] ,以将预测的尺寸转换到 原始图像空间 。
pwh = (0.2 + pwh.sigmoid() * 4.8) * self.anchors[i]
# 将边界框的坐标和尺寸连接起来,形成预测的边界框 pbox 。
pbox = torch.cat((pxy, pwh), 1) # predicted box
# 计算 预测的边界框 pbox 和 目标边界框 tbox[i] 之间的交并比(IoU),使用 CIoU(Complete IoU)公式,并去除多余的维度。
iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target)
# 计算 边界框损失 ,即 1 减去 IoU 的平均值,并累加到总损失的边界框损失部分 loss[0] 。
loss[0] += (1.0 - iou).mean() # box loss
# 这部分代码主要负责计算目标检测模型中每一层的边界框损失,这是目标检测任务中非常重要的一部分,因为它直接影响模型对目标位置的预测准确性。通过优化这个损失,模型可以学习更准确地定位目标。
# 这段代码处理的是目标检测任务中的目标性损失(objectness loss),即预测目标的存在与否。
# Objectness
# iou.detach() 将IoU值从计算图中分离出来,这样在后续的操作中不会对计算图产生影响。 .clamp(0) 将IoU值限制在0到1的范围内,确保所有的IoU值都是非负的。 .type(tobj.dtype) 将IoU值的数据类型转换为与目标性损失张量 tobj 相同的数据类型,以保证后续操作的数据类型一致。
iou = iou.detach().clamp(0).type(tobj.dtype)
# 一个条件判断,如果设置了 self.sort_obj_iou 为 True ,则按照IoU值对目标性损失进行排序。
if self.sort_obj_iou:
# 如果需要排序,则使用 argsort() 方法对IoU值进行排序,并获取排序后的索引 j 。
j = iou.argsort()
# 使用排序后的索引 j ,对 图像索引 b 、 网格的y坐标索引 gj 、 网格的x坐标索引 gi 以及 IoU值本身 进行排序。
b, gj, gi, iou = b[j], gj[j], gi[j], iou[j]
# 一个条件判断,用于处理 self.gr (目标性损失的权重因子)小于1的情况。
if self.gr < 1:
# 如果 self.gr 小于1,则根据 self.gr 调整IoU值,以减少高IoU值的影响,增加低IoU值的影响。这是一种平衡目标性损失的技术,使得模型不会对高置信度的预测过于自信。
iou = (1.0 - self.gr) + self.gr * iou
# 将 调整后的IoU值 赋给 目标性损失张量 tobj 对应的位置。这样,每个预测的目标性损失就与其对应的IoU值相关联,用于后续的目标性损失计算。
tobj[b, gj, gi] = iou # iou ratio
# 这段代码的目的是计算和调整目标性损失的目标值,即每个预测框与真实框的IoU比率。通过这种方式,模型可以学习更好地区分目标和背景,提高目标检测的准确性。
# 这段代码处理的是目标检测任务中的分类损失(classification loss),即预测目标的类别。
# Classification
# 一个条件判断,如果类别数量 self.nc 大于1,即如果任务是多分类问题,则计算 分类损失 。
if self.nc > 1: # cls loss (only if multiple classes)
# 使用 torch.full_like 函数创建一个与预测的分类概率 pcls 形状相同的 目标张量 t ,并将其填充为负样本标签 self.cn 。这个张量将用于存储分类损失的目标值。
t = torch.full_like(pcls, self.cn, device=self.device) # targets
# 对于每个标签,将目标张量 t 中对应位置的值设置为正样本标签 self.cp 。这里, range(n_labels) 生成一个从0到 n_labels-1 的序列, tcls[i] 是第 i 层的分类目标。
t[range(n_labels), tcls[i]] = self.cp
# 计算预测的 分类概率 pcls 和 目标张量 t 之间的二元交叉熵损失(BCE loss),并将其累加到总损失的 分类损失 部分 loss[2] 。
loss[2] += self.BCEcls(pcls, t) # cls loss
# 这段代码的目的是计算分类损失,使得模型可以学习正确地预测目标的类别。通过优化这个损失,模型可以提高其对不同类别的识别能力,从而提高目标检测的准确性。
# 这段代码处理的是目标检测任务中的目标性损失(objectness loss),即预测目标的存在与否。
# 计算 目标性损失。 self.BCEobj 是一个二元交叉熵损失函数,用于评估预测的目标性分数(通常是一个表示目标存在概率的分数)与真实目标性标签之间的差异。 pi[:, 4] 提取第 i 层预测中的目标性分数,而 tobj 是对应的目标性标签(IoU比率)。
obji = self.BCEobj(pi[:, 4], tobj)
# 将计算出的 目标性损失 obji 乘以平衡因子 self.balance[i] ,并将结果累加到总损失的第二项 loss[1] 中。平衡因子用于调整不同层或不同目标性分数的损失贡献,以解决类别不平衡问题。
loss[1] += obji * self.balance[i] # obj loss
# 一个条件判断,如果设置了 self.autobalance 为 True ,则启用自动平衡损失权重的功能。
if self.autobalance:
# 如果启用自动平衡,则根据当前的目标性损失 obji 动态调整第 i 层的平衡因子。 obji.detach().item() 获取当前目标性损失的数值,并将其从计算图中分离出来,以避免影响梯度计算。这个操作通过减少损失较小的层的权重和增加损失较大的层的权重,动态调整不同层的损失贡献,以优化训练过程。
self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
# 这段代码的目的是计算目标性损失,并根据损失值动态调整平衡因子,以解决训练过程中的类别不平衡问题,并优化模型对目标存在概率的预测。通过这种方式,模型可以更加关注难以分类的目标,提高目标检测的准确性。
# 这段代码是 ComputeLoss 类的 __call__ 方法的最后部分,它负责完成损失的计算,包括自动平衡损失权重、应用超参数中的损失权重,以及返回最终的损失值。
# 这是一个条件判断,检查是否启用了自动平衡损失权重的功能。
if self.autobalance:
# 如果启用了自动平衡,那么对每个损失层的平衡因子进行归一化处理。这里, self.ssi 是stride为16的层的索引, self.balance 是一个包含每层损失权重的列表。通过将每个层的权重除以stride为16层的权重,实现自动平衡。这样做的目的是确保不同层的损失权重相对稳定,不会因为某些层的损失过大而导致训练不稳定。
self.balance = [x / self.balance[self.ssi] for x in self.balance]
# 将 边界框损失 ( loss[0] )乘以超参数中定义的 边界框损失权重 ( self.hyp['box'] )。这样可以调整边界框损失对总损失的贡献程度。
loss[0] *= self.hyp['box']
# 将 目标性损失 ( loss[1] )乘以超参数中定义的 目标性损失权重 ( self.hyp['obj'] )。这样可以调整目标性损失对总损失的贡献程度。
loss[1] *= self.hyp['obj']
# 将 分类损失 ( loss[2] )乘以超参数中定义的 分类损失权重 ( self.hyp['cls'] )。这样可以调整分类损失对总损失的贡献程度。
loss[2] *= self.hyp['cls']
# 计算 总损失 ( loss.sum() ),即将边界框损失、目标性损失和分类损失相加,然后乘以 批量大小 ( bs ),得到最终的损失值。 loss.detach() 是从计算图中分离出来的损失张量的副本,这样在返回损失值时不会影响梯度计算。这个方法返回两个值 : 最终的总损失 和 分离出来的损失张量 。
return loss.sum() * bs, loss.detach() # [box, obj, cls] losses
# 这段代码的目的是完成损失的计算和调整,确保不同部分的损失在总损失中的贡献是平衡的,并且返回最终的损失值以供优化器使用。通过这种方式,模型可以在训练过程中同时关注边界框预测的准确性、目标性分数的准确性和分类的准确性。
# 这个方法综合了边界框回归损失、目标性损失和分类损失的计算,并根据模型的超参数和平衡因子进行了调整。最终返回的是总损失和损失张量的副本,以便在训练过程中使用。
# 在目标检测任务中,损失函数是衡量模型预测与真实标注之间差异的关键部分,它指导模型参数的更新。目标检测的损失函数通常由三个主要部分组成 :
# 边界框损失 ( Localization Loss ) 、 分类损失 ( Classification Loss )和 目标性损失 ( Confidence Loss )。
# 下面分别解释这三部分的定义、计算方式及其作用 :
# 边界框损失 ( Localization Loss )
# 定义 :边界框损失衡量的是模型预测的边界框与真实边界框之间的差异。
# 计算 :常用的边界框损失包括平滑L1损失(Smooth L1 loss)和L2损失(Mean Squared Error)。平滑L1损失在误差较小时使用L2损失,在误差较大时使用L1损失,以减小异常大的预测框对训练的影响。
# 作用 :边界框损失确保模型能够准确地定位目标对象的位置。
# 分类损失 ( Classification Loss )
# 定义 :分类损失衡量的是模型对每个框的分类结果与真实标签之间的差异。
# 计算 :常用的分类损失是交叉熵损失(Cross-Entropy Loss),它衡量预测的类别概率分布与真实类别标签之间的差异。对于多类分类问题,交叉熵损失可以表示为 :L = -∑y_i·log(p_i) 。其中 p_i 是模型预测的类别概率, y_i 是真实类别标签(通常为one-hot编码)。
# 作用 :分类损失帮助模型正确地识别目标的类别。
# 目标性损失 ( Confidence Loss )
# 定义 :目标性损失衡量的是模型对预测框的置信度(即是否包含目标物体)与真实标签之间的差异。
# 计算 :通常采用二元交叉熵损失来评估一个框是否包含目标。对于每个检测框,模型会预测一个 objectness 分数,表示该框是否包含目标。二元交叉熵损失可以表示为:BCE(y, p) = -y·log(p) - (1-y)·log(1-p) 。其中 y 是目标类别标签(0或1), p 是框包含目标的概率。
# 作用 :目标性损失使模型能够区分哪些区域包含目标对象,哪些不包含。
# 综合损失函数
# 在目标检测模型中,最终的损失函数通常是多种损失的加权和 : L = λ_1·Localization_Loss + λ_2·Classification_Loss + λ_3·Confidence_Loss 。其中 λ_1 、 λ_2 、λ_3 是损失函数各部分的权重,通常需要根据具体任务进行调节。
# 这些损失函数共同作用,指导模型在目标检测任务中取得更好的性能,包括更准确的目标定位和分类。通过优化这些损失函数,模型可以更好地解决检测任务中的挑战,比如类别不平衡、正负样本不平衡、小目标检测困难等问题。
# 这段代码是 ComputeLoss 类中的 build_targets 方法,它负责根据模型的预测和真实目标构建训练目标。
# 定义 build_targets 方法,接受两个参数。
# 1.p :这是一个列表或元组,包含了模型每一层的预测输出。每一部分 pi 通常包含了该层网络对于 边界框 、 目标性分数 和 类别概率 的预测。
# 2.targets :这是一个包含了真实标注信息的张量,其形状通常是 [N, 7] ,其中 N 是标注框的数量。每一行代表一个标注框,包含了以下信息 : image 标注框所属的图像索引。 class 标注框中目标的类别。 x 、 y 、 w 、 h :标注框的中心坐标 (x, y) 、其宽度 w 和高度 h 和 锚点索引 。
def build_targets(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h) 为 compute_loss() 构建目标,输入目标(图像、类、x、y、w、h)。
# 这段代码是 build_targets 方法的开始部分,它负责初始化一些变量,这些变量将被用于后续的目标构建过程。
# 获取 targets 张量的第一个维度的大小,即真实目标(标注框)的数量。这里 nt 代表 "number of targets"。
nt = targets.shape[0] # number of anchors, targets
# 初始化三个空列表,它们将被用来存储 分类目标 tcls 、 边界框目标 tbox 和 索引 indices 。这些目标和索引将用于损失函数的计算。
tcls, tbox, indices = [], [], []
# 创建一个长度为6、所有元素为1的张量 gain ,并确保它位于正确的设备上(CPU或GPU)。这个张量将被用于将目标从原始图像空间归一化到网格空间。
gain = torch.ones(6, device=self.device) # normalized to gridspace gain
# 设置一个名为 g 的变量,其值为0.3。这个值将被用作偏置值,用于确定 锚点的偏移量 。
g = 0.3 # bias
# 创建一个包含 偏移量坐标 的张量 off 。这些偏移量将被用于调整锚点的位置,以考虑不同锚点的中心点可能不在网格的精确中心。
# 确保 off 张量位于正确的设备上(CPU或GPU)。将 off 张量中的每个元素乘以偏置值 g ,以调整偏移量的大小。
off = torch.tensor(
[
[0, 0],
[1, 0],
[0, 1],
[-1, 0],
[0, -1], # j,k,l,m
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
],
device=self.device).float() * g # offsets
# off 张量中的元素表示 锚点中心 相对于 网格单元中心 的 可能偏移 。例如, [0, 0] 表示没有偏移, [1, 0] 表示向右偏移一个网格单元, [0, 1] 表示向下偏移一个网格单元,依此类推。这些偏移量有助于在构建目标时更精细地调整锚点的位置。
# 这段代码是 build_targets 方法中的一部分,它负责为模型的每一层构建目标。
# 开始一个循环,遍历模型输出的所有层。 self.nl 是模型的层数。
for i in range(self.nl):
# 获取第 i 层预测的形状。
shape = p[i].shape
# 更新 gain 张量的第2到第6个元素,以适应当前层的分辨率。 shape[3] 和 shape[2] 分别是第 i 层预测的 宽度 和 高度 。
gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain
# Match targets to anchors
# 这行代码将 真实目标 targets 按照 gain 进行缩放,以将目标从 原始图像空间 转换到 网格空间 。
t = targets * gain # shape(3,n,7)
# 一个条件判断,如果存在真实目标(即 nt 大于0),则进行匹配和过滤。
if nt:
# Matches
# 计算目标宽度和高度与第 i 层锚点的比率。
r = t[..., 4:6] / self.anchors[i] # wh ratio
# 找到与锚点匹配的目标。它比较 每个目标 与 每个锚点 的宽度和高度比率的最大值,如果这个最大值小于超参数 self.hyp['anchor_t'] ,则认为目标与锚点匹配。
j = torch.max(r, 1 / r).max(1)[0] < self.hyp['anchor_t'] # compare
# j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
# 过滤出 匹配的目标 。
t = t[j] # filter
# 这段代码是 build_targets 方法中处理偏移(offsets)的部分,它负责调整锚点(anchors)的位置以更好地匹配目标的真实位置。
# Offsets
# 提取目标张量 t 中的 网格坐标 (中心点的 x 和 y 坐标)。
gxy = t[:, 2:4] # grid xy
# 计算网格坐标的逆 gxi 主要是为了确定目标在特征图网格中的位置,并且用于计算偏移量,以便更精确地对齐锚点(anchors)和目标。以下是 gxi 的作用和为何要计算它的原因 :
# 确定目标的精确位置 : gxy 提供了目标在特征图上的网格坐标,但有时候需要知道目标相对于网格单元边缘的位置信息,这正是 gxi 提供的。 gxi 是通过从特征图的尺寸中减去 gxy 得到的,它表示目标中心到最近网格边缘的距离。
# 计算偏移量 :在目标检测中,锚点(anchors)通常预先定义在特征图的网格上。通过计算 gxi ,可以确定目标中心相对于锚点中心的位置偏移,这对于调整锚点以更好地匹配目标边界框是必要的。
# 处理边界情况 :当目标恰好位于网格单元的边界上时, gxi 可以帮助决定是使用当前网格单元还是相邻的网格单元来预测目标。这有助于处理那些位于网格边缘的目标,提高检测的准确性。
# 增强模型的鲁棒性 :通过考虑 gxi ,模型可以更好地处理那些位置不完全对齐到网格中心的目标。这增强了模型对于目标位置变化的鲁棒性。
# 优化锚点匹配 :在目标检测中,选择最佳匹配的锚点对于性能至关重要。 gxi 提供的额外位置信息可以帮助模型更精确地选择或调整锚点,以适应目标的实际位置。
# 总之,计算网格坐标的逆 gxi 是为了提供更精细的目标位置信息,这对于精确地定位目标、优化锚点匹配策略以及提高目标检测模型的整体性能都是非常重要的。
# 计算 网格坐标的逆 ,即从特征图的尺寸中减去网格坐标,得到另一个表示位置的坐标。
# 这行代码是在 build_targets 方法中计算网格坐标的逆。
# gain[[2, 3]] : gain 是一个长度为6的张量,其中包含了归一化到网格空间的增益值。 gain[[2, 3]] 提取了 gain 张量的第2和第3个元素,即特征图的宽度和高度。
# gxy : gxy 是目标的网格坐标,即目标在特征图中的中心点坐标(x, y)。
# gxi = gain[[2, 3]] - gxy :计算网格坐标的逆,即从 特征图的尺寸中 减去 网格坐标 ,得到另一个表示位置的坐标。这个逆坐标可以用于计算目标在特征图中的相对位置,有助于后续的偏移量计算。
# 通过计算网格坐标的逆,可以更准确地确定目标在特征图中的位置,从而为构建目标检测任务中的边界框目标和分类目标提供必要的信息。
# gxi 是通过从特征图的尺寸中减去 gxy 得到的,它表示 目标中心 到 最近网格边缘 的 距离 。
gxi = gain[[2, 3]] - gxy # inverse
# 对于每个目标,检查网格坐标 gxy 是否接近于整数(即是否有小数部分小于偏置值 g ),并且坐标值大于1。这用于确定哪些目标需要在 x 和 y 方向上进行偏移。
# 这行代码是在 build_targets 方法中用于确定目标的网格坐标是否需要偏移的一部分。
# gxy % 1 < g : gxy 是一个张量,包含了目标在特征图上的网格坐标(x, y)。 gxy % 1 计算每个坐标的小数部分。 gxy % 1 < g 检查每个坐标的小数部分是否小于一个给定的阈值 g (在这里是0.3)。如果一个坐标的小数部分小于这个阈值,这意味着该坐标更接近于当前的网格单元而不是下一个网格单元。
# (gxy > 1) :这个表达式检查 gxy 中的每个坐标值是否大于1。由于 gxy 表示的是归一化到 [0, 1] 区间内的坐标,这个条件检查坐标是否在网格单元的边界之外。
# (gxy % 1 < g) & (gxy > 1) :这个表达式结合了上述两个条件,使用逻辑与操作符 & 。它将返回一个布尔张量,其中 True 表示对应的坐标既满足小数部分小于 g 又大于1,这通常意味着坐标位于网格单元的边界上或者稍微超出。
# ((gxy % 1 < g) & (gxy > 1)).T : .T 是转置操作,它将布尔张量从形状 [N, 2] (其中 N 是目标的数量,2 表示 x 和 y 坐标)转置成 [2, N] ,使得第一个维度包含所有目标的 x 坐标条件,第二个维度包含所有目标的 y 坐标条件。
# j, k = ... :这里将转置后的布尔张量分解为两个独立的张量 j 和 k ,其中 j 包含所有目标的 x 坐标条件, k 包含所有目标的 y 坐标条件。
# 这两个张量 j 和 k 将被用于后续的索引操作,以选择那些需要根据网格坐标偏移的目标。这种偏移有助于确保锚点更精确地与目标的中心对齐,特别是在目标跨越多个网格单元时。
j, k = ((gxy % 1 < g) & (gxy > 1)).T
# 类似于上一步,但是针对 gxi 坐标。
# 这行代码与之前解释的 j, k = ((gxy % 1 < g) & (gxy > 1)).T 类似,但它用于计算 网格坐标的逆 gxi 的偏移条件。
# gxi % 1 < g : gxi 是一个张量,包含了目标在特征图上的网格坐标的逆。 gxi % 1 计算每个坐标的小数部分。 gxi % 1 < g 检查每个坐标的小数部分是否小于一个给定的阈值 g (在这里是0.3)。如果一个坐标的小数部分小于这个阈值,这意味着该坐标更接近于当前的网格单元而不是下一个网格单元。
# (gxi > 1) :这个表达式检查 gxi 中的每个坐标值是否大于1。由于 gxi 表示的是归一化到 [0, 1] 区间内的坐标,这个条件检查坐标是否在网格单元的边界之外。
# (gxi % 1 < g) & (gxi > 1) :这个表达式结合了上述两个条件,使用逻辑与操作符 & 。它将返回一个布尔张量,其中 True 表示对应的坐标既满足小数部分小于 g 又大于1,这通常意味着坐标位于网格单元的边界上或者稍微超出。
# ((gxi % 1 < g) & (gxi > 1)).T : .T 是转置操作,它将布尔张量从形状 [N, 2] (其中 N 是目标的数量,2 表示 x 和 y 坐标)转置成 [2, N] ,使得第一个维度包含所有目标的 x 坐标条件,第二个维度包含所有目标的 y 坐标条件。
# l, m = ... :这里将转置后的布尔张量分解为两个独立的张量 l 和 m ,其中 l 包含所有目标的 x 坐标条件, m 包含所有目标的 y 坐标条件。
# 这两个张量 l 和 m 将被用于后续的索引操作,以选择那些需要根据网格坐标的逆偏移的目标。这种偏移有助于确保锚点更精确地与目标的中心对齐,特别是在目标跨越多个网格单元时。
l, m = ((gxi % 1 < g) & (gxi > 1)).T
# 将上述得到的索引堆叠起来,形成一个包含多个索引的张量 j ,这些索引将用于选择 需要偏移的目标 。
# 这行代码是用于将多个张量合并成一个单一的张量,通常用于索引操作。
# torch.ones_like(j) :创建一个与 j 形状相同、所有元素为1的张量。这个张量将作为索引的第一个元素,通常用于表示批次维度或类似的用途。
# torch.stack((...)) : torch.stack 函数接受一个张量序列,并沿着一个新的维度将它们堆叠起来。与 torch.cat 不同, torch.stack 会增加一个新的维度,而 torch.cat 会在现有的维度上进行连接。
# torch.stack((torch.ones_like(j), j, k, l, m)) :这个函数调用将五个张量 ( torch.ones_like(j) , j , k , l , m )堆叠起来,形成一个新张量。这些张量包含了用于索引和选择特定目标或锚点的索引信息。
# j = ... : 将 torch.stack 的结果赋值给变量 j 。这个新的 j 张量现在包含了所有需要的索引信息,可以用于后续的索引操作。
# 这种索引信息通常用于从预测张量中选择特定的目标或锚点,或者在构建目标时选择特定的位置。例如,在目标检测中,可能需要根据这些索引来选择与真实目标最匹配的锚点,或者在特征图上定位目标的位置。通过将这些索引信息合并到一个张量中,可以简化后续的操作,使其更加高效和易于管理。
# j 、 k 、 l 和 m 都是形状为 [N] 的张量,其中 N 是目标的数量。 torch.ones_like(j) 也是一个形状为 [N] 的张量。
# 当使用 torch.stack 将这些张量堆叠起来时,会增加一个新的维度。因此, torch.stack((torch.ones_like(j), j, k, l, m)) 得到的新的 j 的形状将是 [5, N] 。
j = torch.stack((torch.ones_like(j), j, k, l, m))
# 将目标张量 t 重复5次,然后根据索引 j 选择出需要偏移的目标。
# 这行代码执行了两个操作 :重复(repeating)和索引(indexing)。
# t.repeat((5, 1, 1)) :
# t 是一个包含了目标信息的张量,其形状通常是 [N, 6] ,其中 N 是目标的数量,6 表示每个目标的信息(例如,包含 图像索引、类别、x、y、宽度、高度等)。
# repeat((5, 1, 1)) 将张量 t 沿着第一个维度(即目标的数量)重复5次。这样做的目的是为每个目标创建多个副本,以便为每个锚点配置生成多个目标信息。重复后,张量的形状变为 [5N, 6] 。
# t[j] :
# j 是一个索引张量,它包含了用于从重复后的张量 t 中选择特定行的索引。
# t[j] 使用 j 中的索引来选择 t 的特定行,这通常是为了选择那些与锚点匹配的目标。
# 综合来看, t = t.repeat((5, 1, 1))[j] 这行代码的作用是 :首先,将每个目标的信息重复5次,为每个目标创建多个副本。然后,使用索引 j 从这些重复的目标中选择特定的行。
# 这种操作在处理目标检测任务时很常见,特别是在锚点匹配过程中。每个锚点可能与多个目标相关联,通过重复目标信息并选择匹配的锚点,可以为每个锚点构建相应的目标信息,进而用于损失函数的计算。这样做可以提高目标检测模型的灵活性和准确性,因为它允许模型为每个锚点配置学习特定的目标信息。
t = t.repeat((5, 1, 1))[j]
# 创建一个与 gxy 形状相同的零张量,并将其与偏移张量 off 相加,然后根据索引 j 选择出对应的偏移值。
# 这行代码计算了用于调整锚点位置的偏移量。
# torch.zeros_like(gxy) :创建一个与 gxy 形状相同、所有元素为0的张量。 gxy 是一个包含了目标在特征图上的网格坐标(x, y)的张量。
# torch.zeros_like(gxy)[None] :在 torch.zeros_like(gxy) 的基础上增加一个新的维度,使其形状从 [N, 2] 变为 [1, N, 2] ,其中 N 是目标的数量。
# off[:, None] : off 是一个包含了偏移量的张量,其形状为 [5, 2] 。通过在 off 的第二个维度上增加一个新的维度,将其形状变为 [5, 1, 2] 。
# torch.zeros_like(gxy)[None] + off[:, None] :将形状为 [1, N, 2] 的零张量与形状为 [5, 1, 2] 的偏移量张量相加。由于它们的形状在广播(broadcasting)规则下是兼容的,相加后得到的张量形状为 [5, N, 2] 。这个张量包含了每个目标的5个偏移量。
# (torch.zeros_like(gxy)[None] + off[:, None])[j] :使用索引张量 j 从形状为 [5, N, 2] 的偏移量张量中选择特定的偏移量。 j 是一个包含了用于选择特定目标和偏移量的索引的张量。
# 综合来看, offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] 这行代码的作用是 :首先,为每个目标创建5个偏移量。 然后,使用索引 j 选择与特定目标和锚点匹配的偏移量。
# 这种操作在处理目标检测任务时很常见,特别是在锚点匹配过程中。通过为每个目标创建多个偏移量并选择匹配的偏移量,可以为每个锚点配置学习特定的偏移信息,进而用于调整锚点的位置,提高目标检测模型的定位精度。
# j 是一个形状为 [5, N] 的张量,它包含了用于选择特定目标和偏移量的索引。
# 当使用 j 作为索引时,实际上是在对每个 N 个目标的每个 5 个偏移量进行索引。因此, j 中的每个值 j[i, n] 都是从 [5, N, 2] 张量中选择第 i 个偏移量对应的目标 n 。
# 对于每个目标,都选择了 5 个偏移量,所以最终 offsets 的形状是 [5,N, 2] 。这里的 [5,N] 表示对每个目标都有 5 个偏移量的选择,因此总的选择数量是 5 倍于目标数量 N 。最终 offsets 的形状是 [5,N, 2] 。
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
# 如果没有目标( nt 等于0),则只使用第一个目标,并将偏移量设置为0。
else:
t = targets[0]
offsets = 0
# 这段代码的目的是为每个目标计算偏移量,以便在构建目标时调整锚点的位置。偏移量是基于目标的网格坐标和特征图的尺寸计算得出的,这样可以确保锚点更精确地与目标的真实边界框对齐。这对于提高目标检测模型的定位精度是非常重要的。
# 这段代码的目的是为模型的每一层构建目标,包括 分类目标 、 边界框目标 和 对应的索引 。这些目标将用于计算损失函数,指导模型的训练过程。通过这种方式,模型可以学习预测正确的边界框、目标类别和目标性分数。
# 这段代码是 build_targets 方法中定义目标检测任务中所需的各种索引和值的部分。
# Define
# 将 目标张量 t 按照第二个维度(dim=1)分割成三个部分。 t 张量包含了每个目标的完整信息,分割后得到 : bc 包含图像索引和类别的信息。 gxy 包含目标在网格中的中心坐标(x, y)。 gwh 包含目标的宽度和高度(w, h)。
bc, gxy, gwh = t.chunk(3, 1) # (image, class), grid xy, grid wh
# 将 bc 张量转置( .T ),并将数据类型转换为长整型( long() ),得到两个张量 : b 图像索引,表示每个目标属于哪张图片。 c 类别索引,表示每个目标的类别。
b, c = bc.long().T # image, class
# 计算 调整偏移后 的 目标网格坐标 。 gxy 是目标的原始网格坐标, offsets 是之前计算的偏移量。相减后得到调整后的网格坐标,并转换为长整型。
# 这行代码是用于计算调整偏移后的网格坐标。
# gxy - offsets : gxy 是一个包含了目标在特征图上的网格坐标(x, y)的张量。 offsets 是一个包含了为每个目标计算的偏移量的张量。 gxy - offsets 从 gxy 中减去 offsets ,得到调整偏移后的网格坐标。
# .long() :将调整偏移后的网格坐标转换为长整型(long)。这通常是因为网格坐标需要作为索引使用,而索引必须是整数。
# gij = ... :将计算得到的调整偏移后的网格坐标赋值给变量 gij 。这个新的 gij 张量包含了每个目标在特征图上的精确网格位置,可以用于后续的索引操作。
# 通过计算调整偏移后的网格坐标,可以更精确地确定目标在特征图上的位置,这对于构建目标检测任务中的边界框目标和分类目标是必要的。这种调整有助于提高目标检测模型的定位精度,特别是在目标跨越多个网格单元时。
# 审视 gij = (gxy - offsets).long() 这行代码中涉及的张量的形状 :
# gxy : gxy 是从目标张量 t 中提取的网格坐标(x, y),其形状为 [N, 2] ,其中 N 是目标的数量。
# offsets : offsets 是通过 (torch.zeros_like(gxy)[None] + off[:, None])[j] 计算得到的偏移量张量。根据之前的解释, torch.zeros_like(gxy)[None] 的形状为 [1, N, 2] , off[:, None] 的形状为 [5, 1, 2] 。相加后得到 [5, N, 2] 的张量,再通过索引 j (形状为 [5, N] )选择特定的偏移量,最终 offsets 的形状为 [5,N, 2] 。
# gij : gij 是通过 gxy - offsets 计算得到的调整偏移后的网格坐标。由于 gxy 的形状是 [N, 2] ,而 offsets 的形状是 [5,N, 2] ,需要确保这两个张量在相减时是兼容的。这通常意味着 gxy 需要被重复或扩展以匹配 offsets 的形状。在实际操作中, gxy 通常会被重复 5 次,以形成形状为 [5,N, 2] 的张量,然后与 offsets 相减。因此, gij 的形状也是 [5,N, 2] 。
# 综上所述, gxy 的形状是 [N, 2] , offsets 的形状是 [5,N, 2] , gij 的形状也是 [5,N, 2] 。这里 N 是目标的数量,而 [5,N] 表示对每个目标都有 5 个偏移量的选择。
gij = (gxy - offsets).long()
# 将调整后的网格坐标张量 gij 转置( .T ),得到两个张量 : gi 网格的 x 坐标索引。 gj 网格的 y 坐标索引。
# gij 的形状是 [5, N, 2] ,那么在代码 gi, gj = gij.T 中 :
# gij.T :这是 gij 张量的转置操作,它将交换 gij 的最后一个维度(2)和第一个维度(5)。因此, gij.T 的形状从 [5, N, 2] 变为 [2, N, 5] 。
# gi, gj = gij.T :转置后的张量 gij.T 被分解为两个张量 gi 和 gj 。由于 gij.T 的形状是 [2, N, 5] , gi 和 gj 的形状将是 [N, 5] 。
# 因此, gi 和 gj 的形状是 [N, 5] 。这里 gi 和 gj 分别包含了转置后张量中每个目标的 x 和 y 坐标。
gi, gj = gij.T # grid indices
# 这些索引和值是构建目标检测损失函数时必需的,因为它们指定了每个目标在特征图中的确切位置,以及它们的类别。这些信息将用于创建用于训练的目标张量,包括 边界框目标 tbox 和 分类目标 tcls ,以及用于索引预测张量的 indices 。通过这种方式,模型可以学习如何准确地预测边界框的位置和目标的类别。
# 这段代码是 build_targets 方法中的最后部分,它负责将构建好的目标检测所需的数据添加到相应的列表中,并返回这些列表。
# Append
# 将 图像索引 b 、 调整后的网格 y 坐标索引 gj 和 网格 x 坐标索引 gi 添加到 indices 列表中。 clamp_ 方法用于确保索引值在有效范围内,即它们不会小于 0 或大于特征图的高度和宽度减去 1。这里 shape[2] 和 shape[3] 分别是特征图的高度和宽度。
indices.append((b, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, grid_y, grid_x indices
# 将调整后的 边界框目标 添加到 tbox 列表中。 gxy - gij 计算得到边界框的中心坐标调整值, gwh 是边界框的宽度和高度。 torch.cat 方法将这些值连接起来形成每个边界框的目标值,并且沿着第一个维度(dim=1)进行拼接。
# 分析 tbox.append(torch.cat((gxy - gij, gwh), 1)) 这行代码中涉及的张量的形状 :
# gxy - gij : gxy 是从目标张量 t 中提取的网格坐标(x, y),其形状为 [N, 2] ,其中 N 是目标的数量。 gij 是调整偏移后的网格坐标,其形状为 [5,N, 2] 。 为了使 gxy 和 gij 的形状兼容, gxy 可以看作是 [1, N, 2],即在第一个维度上有一个隐含的1,通常会被重复 5 次,以形成形状为 [5,N, 2] 的张量。 因此, gxy - gij 的结果是一个形状为 [5,N, 2] 的张量,表示每个目标的偏移后的中心坐标。
# gwh : gwh 是从目标张量 t 中提取的宽度和高度,其形状为 [N, 2] 。 为了使 gwh 与 gxy - gij 的形状兼容, gwh 可以看作是 [1, N, 2],即在第一个维度上有一个隐含的1,通常会被重复 5 次,以形成形状为 [5,N, 2] 的张量。
# torch.cat((gxy - gij, gwh), 1) : torch.cat 函数将两个张量沿着指定的维度连接起来。在这里,将 gxy - gij 和 gwh 沿着第二个维度(dim=1)连接起来。 由于 gxy - gij 和 gwh 都是形状为 [5,N, 2] 的张量,连接后的结果是一个形状为 [5,N, 4] 的张量,表示每个目标的偏移后的中心坐标和宽度高度。
# tbox.append(...) : tbox 是一个列表,用于存储每个层的边界框目标。 在这行代码中,将形状为 [5,N, 4] 的张量添加到 tbox 列表中。
# 综上所述, gxy - gij 和 gwh 的形状都是 [5,N, 2] ,而 torch.cat((gxy - gij, gwh), 1) 的结果形状是 [5,N, 4] 。这里 N 是目标的数量,而 [5,N] 表示对每个目标都有 5 个偏移量的选择。
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
# 将类别目标 c 添加到 tcls 列表中。 c 包含了每个目标的类别信息。
tcls.append(c) # class
# 在循环结束后,方法返回三个列表 : tcls (包含每个目标的类别信息), tbox (包含每个目标的边界框信息),以及 indices (包含每个目标在特征图中的位置索引)。
return tcls, tbox, indices
# 这些返回的列表将被用于损失函数的计算,其中 tcls 和 tbox 分别作为分类损失和边界框损失的目标值, indices 用于从模型的预测中索引出对应目标的预测值,以便与目标值进行比较并计算损失。这个过程对于训练目标检测模型至关重要,因为它确保了模型能够学习如何准确地预测目标的类别和位置。
# 这个方法的目的是为模型的每层预测构建训练目标,包括分类目标、边界框目标和对应的索引。这些目标将用于计算损失函数,指导模型的训练过程。通过这种方式,模型可以学习预测正确的边界框、目标类别和目标性分数。
# 这个 ComputeLoss 类是一个复杂的损失计算工具,它集成了多种损失函数和自动平衡机制,用于提高目标检测模型的性能。通过调整超参数和损失权重,可以优化模型以适应不同的数据集和任务。
7.class ComputeLoss_NEW:
python
# 这段代码定义了一个名为 ComputeLoss_NEW 的类,它是用于计算目标检测模型损失的类。这个类类似于 ComputeLoss 类的基本结构,并添加了一些新的功能和细节。
# 定义了一个名为 ComputeLoss_NEW 的新类,用于计算目标检测模型的损失。
class ComputeLoss_NEW:
# 这是一个类属性,用于标记是否需要根据目标的 IoU(交并比)对目标性损失进行排序。默认值为 False 。
sort_obj_iou = False
# Compute losses
# 这段代码是 ComputeLoss_NEW 类的构造函数 __init__ ,它初始化类实例并设置计算目标检测模型损失所需的参数和属性。
# 这是类的构造函数,它接受两个参数。
# 1.model :目标检测模型实例。
# 2.autobalance :一个布尔值,用于自动平衡不同层的损失,默认为 False 。
def __init__(self, model, autobalance=False):
# 获取模型参数所在的设备(CPU或GPU),以便在后续计算中使用相同的设备。
device = next(model.parameters()).device # get model device
# 获取模型的超参数,这些参数通常包含训练过程中需要的各种设置,如学习率、优化器设置等。
h = model.hyp # hyperparameters
# Define criteria
# 初始化一个用于 分类损失 的二元交叉熵损失函数,带有正样本权重。 h['cls_pw'] 是分类损失中正样本的权重。
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
# 初始化一个用于 目标性损失 的二元交叉熵损失函数,带有正样本权重。 h['obj_pw'] 是目标性损失中正样本的权重。
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
# 使用 smooth_BCE 函数创建标签平滑的正负标签目标, eps 参数控制标签平滑的程度。
self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
# Focal loss
# 从超参数中获取焦点损失的 gamma 参数,用于调节焦点损失的强度。
g = h['fl_gamma'] # focal loss gamma
# 如果 gamma 大于0,则启用焦点损失。
if g > 0:
# 对分类损失和目标性损失应用焦点损失,以提高模型对于难分类样本的关注。
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
# 获取模型的最后一层(检测模块)。
m = de_parallel(model).model[-1] # Detect() module
# 设置不同层的损失权重, m.nl 是模型的层数。
self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
# 如果启用自动平衡,则找到stride为16的层的索引。
self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index
# 保存 分类损失函数 、 目标性损失函数 、 IoU比例因子 、 超参数 和 自动平衡标志 。
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
# 保存模型的类别数量。
self.nc = m.nc # number of classes
# 保存模型的层数。
self.nl = m.nl # number of layers
# 保存模型使用的锚点。
self.anchors = m.anchors
# 保存模型参数所在的设备。
self.device = device
# 初始化一个基础的二元交叉熵损失函数,用于计算损失,设置 reduction='none' 以保留每个样本的损失值。
self.BCE_base = nn.BCEWithLogitsLoss(reduction='none')
# 构造函数 __init__ 初始化了 ComputeLoss_NEW 类的实例,设置了计算损失所需的各种参数和属性,包括损失函数、标签平滑、焦点损失、层权重平衡等。这些设置将用于后续的损失计算,以指导模型的训练过程。
# 这段代码是 ComputeLoss_NEW 类的 __call__ 方法,它负责计算目标检测模型的损失。
# 这是一个特殊方法,允许类的实例像函数一样被调用。它接受两个参数。
# 1.p :模型的预测结果。
# 2.targets :真实目标标签。
def __call__(self, p, targets): # predictions, targets
# 这段代码是 ComputeLoss_NEW 类的 __call__ 方法的一部分,它负责初始化损失计算过程。
# 调用 build_targets 方法,传入模型的 预测结果 p 和 真实目标 targets ,以构建训练目标。
# tcls 是 分类目标 ,包含了每个锚点对应的类别标签。
# tbox 是 边界框目标 ,包含了每个锚点对应的目标边界框。
# indices 是 索引 ,包含了每个锚点对应的图像、网格位置等信息。
tcls, tbox, indices = self.build_targets(p, targets) # targets
# 获取 批量大小 ,即当前批次中的图像数量。 p[0] 是模型输出的第一层预测结果, .shape[0] 获取该张量的第一个维度大小,即批量大小。
bs = p[0].shape[0] # batch size
# 获取 真实目标标签的数量 。 targets 是一个包含了所有真实目标信息的张量, .shape[0] 获取该张量的第一个维度大小,即标签数量。
n_labels = targets.shape[0] # number of labels
# 初始化一个包含三个元素的张量,用于存储三种损失 : 边界框损失 、 目标性损失 和 分类损失 。这些损失将被累积并用于后续的反向传播。 torch.zeros(3, device=self.device) 创建一个形状为 [3] 的张量,所有元素初始化为0,且位于与模型参数相同的设备上。
loss = torch.zeros(3, device=self.device) # [box, obj, cls] losses
# 这段代码的目的是为损失计算做准备,通过构建目标和初始化损失张量,为后续的损失计算奠定基础。
# 这段代码是 ComputeLoss_NEW 类的 __call__ 方法的一部分,它负责计算每个预测层的损失,并将结果存储在 all_loss 列表中。
# Compute all losses
# 初始化一个空列表 all_loss ,用于存储每一层的损失计算结果。
all_loss = []
# 遍历模型输出的每层预测结果 p 。 i 是层的索引, pi 是第 i 层的预测结果。
for i, pi in enumerate(p): # layer index, layer predictions
# 从 indices 列表中获取第 i 层的索引信息,包括 图像索引 b 、 锚点索引 gj 、 网格的 y 坐标 gj 和 网格的 x 坐标 gi 。
b, gj, gi = indices[i] # image, anchor, gridy, gridx
# 如果存在真实目标标签( n_labels 大于0),则继续计算损失。
if n_labels:
# 将第 i 层的 预测结果 pi 分割成四个部分 : 边界框的 x 和 y 坐标 pxy 、 宽度和高度 pwh 、 目标性分数 pobj 和 分类概率 pcls 。
pxy, pwh, pobj, pcls = pi[b, :, gj, gi].split((2, 2, 1, self.nc), 2) # target-subset of predictions
# Regression
# 计算 预测的边界框 pbox 。首先对 pxy 和 pwh 应用 Sigmoid 函数,然后根据模型的锚点 self.anchors[i] 调整尺寸。
pbox = torch.cat((pxy.sigmoid() * 1.6 - 0.3, (0.2 + pwh.sigmoid() * 4.8) * self.anchors[i]), 2)
# 计算 预测边界框 pbox 和 目标边界框 tbox[i] 之间的交并比(IoU),使用 CIoU(Complete IoU)公式。
iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(predicted_box, target_box)
# 将 IoU 值转换为 目标性损失 的目标值 obj_target ,确保值在 [0, 1] 范围内。
obj_target = iou.detach().clamp(0).type(pi.dtype) # objectness targets
# 将计算得到的损失值添加到 all_loss 列表中。这些损失值包括 :
# 边界框损失 : (1.0 - iou) * self.hyp['box'] 。
# 目标性损失 : self.BCE_base(pobj.squeeze(), torch.ones_like(obj_target)) * self.hyp['obj'] 。
# 分类损失 : self.BCE_base(pcls, F.one_hot(tcls[i], self.nc).float()).mean(2) * self.hyp['cls'] 。
# 目标性损失的目标值: obj_target 。
# 有效性标志 : tbox[i][..., 2] > 0.0 ,表示目标框的宽度是否大于0,用于过滤无效目标。
all_loss.append([(1.0 - iou) * self.hyp['box'],
self.BCE_base(pobj.squeeze(), torch.ones_like(obj_target)) * self.hyp['obj'],
self.BCE_base(pcls, F.one_hot(tcls[i], self.nc).float()).mean(2) * self.hyp['cls'],
obj_target,
tbox[i][..., 2] > 0.0]) # valid
# 这段代码的目的是为模型的每一层预测结果计算损失,并存储这些损失值,以便后续的损失聚合和反向传播。通过这种方式,模型可以学习预测正确的边界框、目标类别和目标性分数。
# 这段代码是 ComputeLoss_NEW 类的 __call__ 方法的一部分,它负责从所有计算的损失中选择每个标签的最低损失,以用于最终的损失计算。
# Lowest 3 losses per label
# 设置每个标签的最大匹配数量为4,这意味着对于每个标签,将选择损失最低的4个预测结果。
n_assign = 4 # top n matches
# all_loss 列表包含了所有层的损失, zip(*all_loss) 将这些损失按维度组合起来。 torch.cat(x, 1) 将这些损失沿着第二个维度(dim=1)连接起来,形成一个大的损失张量。
cat_loss = [torch.cat(x, 1) for x in zip(*all_loss)]
# 创建一个与 cat_loss[0] 形状相同、所有元素为0的布尔张量 ij ,用于标记每个标签的 最低损失 。
ij = torch.zeros_like(cat_loss[0]).bool() # top 3 mask
# 计算总损失,这里是将 边界框损失 cat_loss[0] 和 分类损失 cat_loss[2] 相加。
sum_loss = cat_loss[0] + cat_loss[2]
# torch.argsort(sum_loss, dim=1) 对总损失 sum_loss 按列(dim=1)进行排序,返回索引。 .T 转置索引,使其按行排列。 [:n_assign] 选择每个标签损失最低的前4个预测结果的索引。
for col in torch.argsort(sum_loss, dim=1).T[:n_assign]:
# ij[range(n_labels), col] = True
# 使用索引 col 更新 ij 张量,标记每个标签的最低损失。 cat_loss[4] 包含了有效性标志,用于确定哪些目标框是有效的。
ij[range(n_labels), col] = cat_loss[4][range(n_labels), col]
# 计算边界框损失,选择 ij 标记为True的损失,计算它们的平均值,并乘以层数 self.nl 。
loss[0] = cat_loss[0][ij].mean() * self.nl # box loss
# 计算分类损失,选择 ij 标记为True的损失,计算它们的平均值,并乘以层数 self.nl 。
loss[2] = cat_loss[2][ij].mean() * self.nl # cls loss
# 这段代码的目的是为每个标签选择损失最低的预测结果,并计算这些结果的平均损失,以用于模型的训练。这种方法有助于模型关注那些最有可能正确的预测,从而提高模型的性能。
# 这段代码是 ComputeLoss_NEW 类的 __call__ 方法的一部分,它负责计算目标性损失(objectness loss)。
# Obj loss
# 遍历每一层的预测结果 p 和对应的布尔张量 ij 的分块。 ij.chunk(self.nl, 1) 将 ij 张量分成 self.nl (层数)个部分,每个部分对应一层的损失。 enumerate 提供层的索引 i 。
for i, (h, pi) in enumerate(zip(ij.chunk(self.nl, 1), p)): # layer index, layer predictions
# 从 indices 列表中获取第 i 层的索引信息,包括 图像索引 b 、 锚点索引 gj 、 网格的 y 坐标 gj 和 网格的 x 坐标 gi 。
b, gj, gi = indices[i] # image, anchor, gridy, gridx
# 初始化一个形状为 (pi.shape[0], pi.shape[2], pi.shape[3]) 的张量 tobj ,用于存储目标性损失的目标值。这个张量被初始化为0,并且具有与 pi 相同的数据类型和设备。
tobj = torch.zeros((pi.shape[0], pi.shape[2], pi.shape[3]), dtype=pi.dtype, device=self.device) # obj
# 如果存在真实目标标签( n_labels 大于0),则继续计算目标性损失。
if n_labels: # if any labels
# 将 all_loss 列表中存储的目标性损失的目标值赋给 tobj 张量。这里 all_loss[i][3] 是第 i 层的目标性损失目标值, h 是 ij 张量的布尔索引。
tobj[b[h], gj[h], gi[h]] = all_loss[i][3][h]
# 计算目标性损失,使用二元交叉熵损失函数 self.BCEobj 。 pi[:, 4] 是模型预测的目标性分数。计算得到的损失乘以平衡因子 self.balance[i] 和超参数中的权重 self.hyp['obj'] ,然后累加到 loss[1] 。
loss[1] += self.BCEobj(pi[:, 4], tobj) * (self.balance[i] * self.hyp['obj'])
# 返回总损失( loss.sum() * bs ),即所有损失的总和乘以批量大小 bs ,以及损失张量的副本 loss.detach() 。
return loss.sum() * bs, loss.detach() # [box, obj, cls] losses
# 这段代码的目的是计算每一层的目标性损失,并将其累加到总损失中。通过这种方式,模型可以学习预测每个锚点是否包含目标对象。这种方法有助于模型在训练过程中关注那些最有可能正确的预测,从而提高模型的性能。
# 这个方法综合了边界框回归损失、目标性损失和分类损失的计算,并根据模型的超参数和平衡因子进行了调整。最终返回的是总损失和损失张量的副本,以便在训练过程中使用。通过这种方式,模型可以学习预测正确的边界框、目标类别和目标性分数。
# 这段代码定义了 ComputeLoss_NEW 类的 build_targets 方法,其目的是为模型的损失计算构建目标。
# 这是 build_targets 方法的定义,它接受两个参数。
# 1.p :模型的预测。
# 2.targets :真实目标。
def build_targets(self, p, targets):
# Build targets for compute_loss(), input targets(image,class,x,y,w,h)
# 这段代码是 ComputeLoss_NEW 类的 build_targets 方法的一部分,它负责初始化构建目标的过程。
# 获取传入的 targets 张量的第一个维度的大小,即真实目标(或锚点)的数量。这里 nt 代表 "number of targets"。
nt = targets.shape[0] # number of anchors, targets
# 初始化三个空列表,分别用于存储 分类目标 tcls 、 边界框目标 tbox 和 索引 indices 。这些列表将被填充用于后续损失计算的目标值。
tcls, tbox, indices = [], [], []
# 创建一个长度为6、所有元素为1的张量 gain ,并确保它位于模型参数相同的设备上(CPU或GPU)。这个张量将用于将目标从原始图像空间归一化到特征图空间(grid space)。
gain = torch.ones(6, device=self.device) # normalized to gridspace gain
# 设置一个名为 g 的变量,其值为0.3。这个值将被用作确定网格偏移的阈值。
g = 0.3 # bias
# 创建一个包含偏移量坐标的张量 off 。这些偏移量将被用于调整锚点的位置,以更好地匹配目标的真实位置。
# torch.tensor([...], device=self.device) 创建一个包含特定偏移模式的张量,并确保它位于模型参数相同的设备上。 .float() 将张量的数据类型转换为浮点数,以便于后续的数学运算。
# off 张量是通过 torch.tensor 创建的,其中包含一个列表的列表,每个内部列表包含两个元素。这个列表的列表定义了五个偏移量,每个偏移量有两个坐标(x 和 y)。因此, off 张量的形状是 [5, 2] 。
off = torch.tensor(
[
[0, 0],
[1, 0],
[0, 1],
[-1, 0],
[0, -1], # j,k,l,m
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
],
device=self.device).float() # offsets
# off 张量中的元素表示锚点中心相对于网格单元中心的可能偏移。例如, [0, 0] 表示没有偏移, [1, 0] 表示向右偏移一个网格单元, [0, 1] 表示向下偏移一个网格单元,依此类推。这些偏移量有助于在构建目标时更精细地调整锚点的位置。
# 这段代码是 ComputeLoss_NEW 类的 build_targets 方法的一部分,它负责为每个预测层匹配锚点和目标,并计算相关的增益值。
# 遍历模型输出的每一层预测结果, self.nl 是模型中预测层的数量。
for i in range(self.nl):
# 获取第 i 层预测结果 p[i] 的形状,这通常是一个四维张量,形状为 [batch_size, num_anchors, grid_height, grid_width] 。
shape = p[i].shape
# 更新 gain 张量中的值,使其包含当前层特征图的宽度和高度。 shape[3] 和 shape[2] 分别是特征图的高度和宽度,这些值用于将目标从图像空间归一化到特征图空间。
gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain
# Match targets to anchors
# 将所有目标 targets 按 gain 进行缩放,以匹配特征图的尺寸。 targets 的形状是 [3, n, 7] ,其中 n 是目标的数量,7 表示目标的七个属性(包括 图像索引 、 类别 、 x 、 y 、 宽度 、 高度 和 锚点索引 )。
t = targets * gain # shape(3,n,7)
# 如果存在目标( nt 大于0),则执行锚点匹配。
if nt:
# # Matches
# 计算目标宽度和高度与第 i 层锚点的比率。 t[..., 4:6] 提取目标的宽度和高度, self.anchors[i] 是第 i 层的锚点尺寸。
r = t[..., 4:6] / self.anchors[i] # wh ratio
# 找到与锚点匹配的目标。 torch.max(r, 1 / r) 计算目标宽度和高度与锚点尺寸的最大比率, max(1)[0] 取这个最大比率的最大值。如果这个值小于超参数 self.hyp['anchor_t'] ,则认为目标与锚点匹配。
a = torch.max(r, 1 / r).max(1)[0] < self.hyp['anchor_t'] # compare
# 这是一个被注释掉的代码行,如果取消注释,它将使用宽度和高度的交并比(IoU)来确定目标和锚点之间的匹配程度。
# a = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
# 这也是一个被注释掉的代码行,如果取消注释,它将过滤出与锚点匹配的目标。
# t = t[a] # filter
# 这段代码的目的是为每个预测层匹配锚点和目标,并计算相关的增益值,以便后续构建目标和计算损失。通过这种方式,模型可以学习预测正确的边界框和类别。
# 这段代码是 ComputeLoss_NEW 类的 build_targets 方法的一部分,它负责计算目标的网格坐标偏移量,并处理不匹配的目标。
# Offsets
# 提取目标张量 t 中的网格坐标(x, y),这些坐标表示目标中心在特征图上的位置。
gxy = t[:, 2:4] # grid xy
# 计算网格坐标的逆,即从特征图的尺寸中减去网格坐标,得到另一个表示位置的坐标。
gxi = gain[[2, 3]] - gxy # inverse
# 计算需要在 x 和 y 方向上偏移的网格坐标。 gxy % 1 得到网格坐标的小数部分, gxy % 1 < g 检查小数部分是否小于偏置值 g , gxy > 1 检查网格坐标是否大于1(即是否接近下一个网格单元)。
j, k = ((gxy % 1 < g) & (gxy > 1)).T
# 类似于上一步,但是针对 gxi 坐标。
l, m = ((gxi % 1 < g) & (gxi > 1)).T
# 将上述得到的索引堆叠起来,并与匹配的目标 a 进行逻辑与操作,以确保只有匹配的目标才会被考虑。
j = torch.stack((torch.ones_like(j), j, k, l, m)) & a
# 将目标张量 t 重复5次,为每个可能的偏移量创建副本。
t = t.repeat((5, 1, 1))
# 创建一个与 gxy 形状相同的零张量,并将其与偏移张量 off 相加,得到偏移量。
# 在表达式 offsets = torch.zeros_like(gxy)[None] + off[:, None] 中,首先需要了解 gxy 和 off 的形状 : gxy 是一个形状为 [N, 2] 的张量,其中 N 是目标的数量。 off 是一个形状为 [5, 2] 的张量,包含了五个偏移量。
# 接下来,分析表达式的每个部分 :
# torch.zeros_like(gxy) :创建一个与 gxy 形状相同、所有元素为0的张量,其形状为 [N, 2] 。
# torch.zeros_like(gxy)[None] :在 torch.zeros_like(gxy) 的基础上增加一个新的维度,使其形状变为 [1, N, 2] 。
# off[:, None] :在 off 的第二个维度上增加一个新的维度,使其形状变为 [5, 1, 2] 。
# torch.zeros_like(gxy)[None] + off[:, None] :将形状为 [1, N, 2] 的零张量与形状为 [5, 1, 2] 的偏移量张量相加。由于它们的形状在广播(broadcasting)规则下是兼容的,相加后得到的张量形状为 [5, N, 2] 。
# 因此, offsets 的形状是 [5, N, 2] 。
offsets = torch.zeros_like(gxy)[None] + off[:, None] # offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
# 将不匹配的目标(即 ~j 为 True 的目标)的宽度和高度设置为0,这样在计算损失时这些目标会被忽略。
t[..., 4:6][~j] = 0.0 # move unsuitable targets far away 将不合适的目标移走 t = t.repeat((5, 1, 1))[j]
# 如果没有目标( nt 等于0),则只使用第一个目标,并将偏移量设置为0。
else:
t = targets[0]
offsets = 0
# 这段代码的目的是为每个目标计算可能的偏移量,并处理不匹配的目标。通过这种方式,模型可以更精确地定位目标,并且减少不相关目标对损失计算的影响。
# 这段代码是 ComputeLoss_NEW 类的 build_targets 方法的一部分,它负责从匹配的目标中提取和处理信息,以构建用于损失计算的目标。
# Define
# t 是一个包含了匹配目标信息的张量,形状为 [N, 7] ,其中 N 是匹配目标的数量,7 表示目标的属性(图像索引、类别、x、y、宽度、高度和 锚点索引 )。
# chunk(3, 2) 将 t 张量沿着第二个维度(dim=2)分割成3部分,每部分包含2个元素。
# bc 包含图像索引和类别信息, gxy 包含目标的网格坐标(x, y), gwh 包含目标的宽度和高度。
bc, gxy, gwh = t.chunk(3, 2) # (image, class), grid xy, grid wh
# bc.long() 将 bc 张量的数据类型转换为长整型。 transpose(0, 2) 将 bc 张量的第一个维度和第三个维度交换,使得图像索引和类别信息分别位于不同的维度。 contiguous() 确保张量在内存中是连续存储的,这对于某些 PyTorch 操作是必要的。
# 这行代码执行了以下几个操作,用于处理和转换包含 图像索引 和 类别信息 的张量 bc 。
# bc.long() :将 bc 张量的数据类型转换为长整型(long)。这是因为图像索引和类别标签通常是整数,长整型可以确保在后续操作中保持数据的完整性。
# .transpose(0, 2) : transpose 函数用于交换张量的维度。在这里, .transpose(0, 2) 交换了 bc 张量的第一个维度(通常是批次或目标维度)和第三个维度(类别信息维度)。这样做的目的是为了将类别信息和图像索引分别放在不同的维度上,便于后续处理。
# .contiguous() : contiguous 方法用于确保张量在内存中是连续存储的。在 PyTorch 中,某些操作(如 transpose )可能会导致张量在内存中不连续,这可能会影响后续操作的性能。调用 .contiguous() 可以确保张量在内存中是连续的,从而提高效率。
# 综上所述,这行代码的目的是将包含 图像索引 和 类别信息 的张量 bc 转换为长整型,交换其维度,确保其在内存中连续存储,以便后续操作可以高效地进行。最终, b 和 c 分别包含了 图像索引 和 类别标签 ,它们可以用于构建目标检测任务中的目标类别和边界框目标。
# bc 的形状将是 [N, 2] ,其中第一列是图像索引,第二列是类别。其中 N 是目标的数量。
b, c = bc.long().transpose(0, 2).contiguous() # image, class ❌⚠️ bc 只有两个维度,那么代码中的 .transpose(0, 2) 应该被替换为 .transpose(0, 1) ,这样才能正确地交换第一个和第二个维度。
# gxy - offsets 计算调整偏移后的目标网格坐标。 long() 将结果转换为长整型,因为网格坐标索引需要是整数。
# 在表达式 gij = (gxy - offsets).long() 中,首先需要了解 gxy 和 offsets 的形状 : gxy 是一个形状为 [N, 2] 的张量,其中 N 是目标的数量。 offsets 是一个形状为 [5, N, 2] 的张量,包含了每个目标的五个偏移量。
# 接下来,分析表达式的每个部分 :
# gxy - offsets :由于 gxy 的形状是 [N, 2] ,而 offsets 的形状是 [5, N, 2] ,需要确保这两个张量在相减时是兼容的。这通常意味着 gxy 需要被重复或扩展以匹配 offsets 的形状。在实际操作中, gxy 可以看作是 [1, N, 2],即在第一个维度上有一个隐含的1,通常会被重复 5 次,以形成形状为 [5,N, 2] 的张量,然后与 offsets 相减。
# 因此, gxy - offsets 的结果是一个形状为 [5,N, 2] 的张量,表示每个目标的偏移后的网格坐标。
# .long() :将相减后的结果转换为长整型(long),因为网格坐标需要作为索引使用,所以其最终形状为 [5,N, 2] 。
# 因此, gxy 的形状是 [N, 2] , offsets 的形状是 [5, N, 2] , gij 的形状是 [5,N, 2] 。这里 N 是目标的数量,而 [5,N] 表示对每个目标都有5个偏移量。
gij = (gxy - offsets).long()
# transpose(0, 2) 将 gij 张量的第一个维度和第三个维度交换,使得网格的 y 坐标和 x 坐标分别位于不同的维度。 contiguous() 确保张量在内存中是连续存储的。
# .transpose(0, 2) :这个操作将 gij 张量的第一个维度(5)和第三个维度(2)交换。因此, gij 的形状从 [5, N, 2] 变为 [2, N, 5] 。
# .contiguous() :这个操作确保张量在内存中是连续存储的,这对于某些 PyTorch 操作是必要的。
# gi, gj = ... : 交换后的张量被分解为两个张量 gi 和 gj ,它们分别包含网格的 x 坐标和 y 坐标。由于交换后的张量形状为 [2, N, 5] , gi 和 gj 的形状将是 [N, 5] 。
# 因此, gi 和 gj 的形状是 [N, 5] 。
gi, gj = gij.transpose(0, 2).contiguous() # grid indices
# 这段代码的目的是将匹配的目标信息分解成可用于构建损失目标的组件,包括图像索引、类别、调整偏移后的网格坐标等。这些信息将用于计算每个锚点的目标类别、边界框和网格索引,进而用于损失函数的计算。通过这种方式,模型可以学习预测正确的边界框和类别。
# 这段代码是 ComputeLoss_NEW 类的 build_targets 方法的一部分,它负责将处理后的目标信息添加到相应的列表中,以便后续用于损失计算。
# Append
# b 是图像索引, gj 和 gi 分别是网格的 y 坐标和 x 坐标索引。 clamp_(0, shape[2] - 1) 和 clamp_(0, shape[3] - 1) 确保网格索引在特征图的边界内,防止索引超出范围。 这个元组 (b, gj, gi) 包含了每个目标的图像索引和网格位置,被添加到 indices 列表中,用于后续的损失计算。
indices.append((b, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, grid_y, grid_x indices
# gxy - gij 计算调整偏移后的目标中心坐标。 torch.cat((gxy - gij, gwh), 2) 将调整后的中心坐标和宽度高度连接起来,形成完整的目标边界框。
# permute(1, 0, 2) 调整张量维度的顺序,使其形状为 (num_targets, 4) ,其中 4 表示边界框的四个坐标值(x, y, w, h)。 contiguous() 确保张量在内存中是连续存储的,这对于某些 PyTorch 操作是必要的。 这个处理后的边界框张量被添加到 tbox 列表中,用于后续的损失计算。
# gxy 形状是 [N, 2] , gij 形状是 [5,N, 2], gxy - gij 的操作过程如下 :
# 将 gxy 的形状从 [N, 2] 扩展为 [1, N, 2]。 将扩展后的 gxy 与 gij 进行逐元素减法。 因此,gxy - gij 的结果张量的形状与 gij 的形状相同,即 [5, N, 2] 。
# 在PyTorch中, torch.cat 函数用于将两个或多个张量沿着指定的维度拼接在一起。在该行代码中,希望将形状为 [5, N, 2] 的张量 gxy - gij 和形状为 [N, 2] 的张量 gwh 沿着第三个维度(即维度2)拼接。为了使用 torch.cat ,需要确保除了拼接的维度外,其他维度的大小必须相同或可广播。
# 在该行代码中,张量 gxy - gij 的形状是 [5, N, 2],而张量 gwh 的形状是 [N, 2]。为了使它们在除了拼接维度外的其他维度上兼容,可以将张量 gwh 的形状扩展为 [1, N, 2],这样它就可以与张量 gxy - gij 沿着第三个维度拼接了。
# 拼接后的张量形状计算如下 :
# 沿着第一个维度(维度0),大小为5(来自张量 gxy - gij )。
# 沿着第二个维度(维度1),大小为N(来自张量 gxy - gij 和 gwh )。
# 沿着第三个维度(维度2),大小为2(来自张量 gxy - gij )+ 2(来自张量 gwh )= 4。
# 因此, torch.cat((gxy - gij, gwh), 2) 的结果张量的形状是 [5, N, 4] 。
# torch.cat((gxy - gij, gwh), 2) 的结果张量的形状为 [5, N, 4] ,那么在表达式 tbox.append(torch.cat((gxy - gij, gwh), 2).permute(1, 0, 2).contiguous()) 中 :
# .permute(1, 0, 2) :这个操作将拼接后的张量的维度重新排列。原来的形状是 [5, N, 4] ,经过 .permute(1, 0, 2) 后,形状变为 [N, 5, 4] 。
# .contiguous() :这个操作确保张量在内存中是连续存储的,这对于某些 PyTorch 操作是必要的。
# 因此,最终 tbox 列表中添加的张量的形状是 [N, 5, 4] 。这里 N 表示目标的数量, 5 表示每个目标的五个偏移量, 4 表示每个目标的四个边界框坐标 (x,y,w,h) 。
tbox.append(torch.cat((gxy - gij, gwh), 2).permute(1, 0, 2).contiguous()) # box
# c 是目标的类别索引。 这个类别索引被添加到 tcls 列表中,用于后续的分类损失计算。
tcls.append(c) # class
# 这段代码的目的是将每个目标的图像索引、网格位置、边界框和类别信息收集起来,以便在损失计算中使用。这些信息对于目标检测模型的训练至关重要,因为它们定义了模型需要预测的内容。通过这种方式,模型可以学习如何准确地预测目标的位置和类别。
# 这段代码包含了一个被注释掉的部分,它原本的目的是检查构建的目标边界框是否唯一,即没有重复的目标。
# # Unique
# b.view(-1, 1) 将 图像索引 b 转换为列向量。 tbox[i].view(-1, 4) 将第 i 层的边界框 tbox[i] 转换为形状为 [-1, 4] 的张量,其中 -1 表示自动计算的维度, 4 表示边界框的四个坐标值(x, y, w, h)。 torch.cat(..., 1) 将图像索引和边界框沿着第二个维度(dim=1)连接起来,形成一个新张量。 .shape[0] 获取连接后张量的第一个维度的大小,即总的组合数量。
# n1 = torch.cat((b.view(-1, 1), tbox[i].view(-1, 4)), 1).shape[0]
# tbox[i].view(-1, 4) 同上,将边界框转换为形状为 [-1, 4] 的张量。 .unique(dim=0) 返回边界框张量中唯一的行,即唯一的边界框。 .shape[0] 获取唯一边界框张量的第一个维度的大小,即唯一的边界框数量。
# n2 = tbox[i].view(-1, 4).unique(dim=0).shape[0]
# 打印出总的组合数量 n1 和唯一边界框数量 n2 ,以及它们之间的差异 n1-n2 ,这可以帮助识别是否有重复的目标边界框。
# print(f'targets-unique {n1}-{n2} diff={n1-n2}')
# 这段代码的目的是确保构建的目标边界框是唯一的,没有重复。在目标检测任务中,确保目标的唯一性是很重要的,因为重复的目标可能会导致模型训练过程中的误差。通过检查和打印出重复的目标数量,可以帮助调试和优化目标构建过程。
# 这是方法的返回语句,它返回三个列表。 tcls (包含每个目标的类别信息) , tbox (包含每个目标的边界框信息) ,以及 indices (包含每个目标的图像索引和网格位置信息) 。
return tcls, tbox, indices
# 这个方法的目的是为模型的每一层预测构建目标,包括分类目标、边界框目标和对应的索引。这些目标将用于计算损失函数,指导模型的训练过程。通过这种方式,模型可以学习预测正确的边界框、目标类别和目标性分数。
# 这个类通过构建目标和计算损失,为目标检测模型的训练提供了必要的工具。通过这种方式,模型可以学习预测正确的边界框、目标类别和目标性分数。
# 在pytorch中,a形状为[N,2]、b形状为[5,N, 2],a-b的结果张量的形状。
# 在PyTorch中,进行张量减法操作时,需要遵循广播(broadcasting)规则。广播规则允许在某些维度上扩展较小张量的形状,以便与较大张量进行逐元素操作。给定的张量形状为 :
# a : [N, 2] 。
# b : [5, N, 2] 。
# 希望计算 a - b。根据广播规则,可以将a视为在第一个维度上具有大小1的张量,即 [1, N, 2]。这样,a和b的形状就兼容了,可以进行逐元素减法。
# 具体来说,a - b 的操作过程如下 :
# 将a的形状从 [N, 2] 扩展为 [1, N, 2]。
# 将扩展后的a与b进行逐元素减法。
# 因此,a - b 的结果张量的形状与b的形状相同,即 [5, N, 2]。
# 所以,a - b 的结果张量的形状是 [5, N, 2]。
# 在pytorch中,a形状为[5, N, 2]、b形状为[N, 2],torch.cat((a, b), 2)的结果张量的形状。
# 在PyTorch中, torch.cat 函数用于将两个或多个张量沿着指定的维度拼接在一起。在本例中,希望将形状为 [5, N, 2] 的张量a和形状为 [N, 2] 的张量b沿着第三个维度(即维度2)拼接。
# 为了使用 torch.cat ,需要确保除了拼接的维度外,其他维度的大小必须相同或可广播。在这个例子中,张量a的形状是 [5, N, 2],而张量b的形状是 [N, 2]。为了使它们在除了拼接维度外的其他维度上兼容,可以将张量b的形状扩展为 [1, N, 2],这样它就可以与张量a沿着第三个维度拼接了。
# 拼接后的张量形状计算如下 :
# 沿着第一个维度(维度0),大小为5(来自张量a)。
# 沿着第二个维度(维度1),大小为N(来自张量a和b)。
# 沿着第三个维度(维度2),大小为2(来自张量a)+ 2(来自张量b)= 4。
# 因此, torch.cat((a, b), 2) 的结果张量的形状是 [5, N, 4]。
# 在pytorch中,a形状为[N,2]、b形状为[5,N, 2],a*b的结果张量的形状。
# 在PyTorch中,进行张量乘法操作时,同样需要遵循广播(broadcasting)规则。广播规则允许在某些维度上扩展较小张量的形状,以便与较大张量进行逐元素操作。给定的张量形状为 :
# a :[N, 2] 。
# b :[5, N, 2]。
# 希望计算 a * b。根据广播规则,可以将a视为在第一个维度上具有大小1的张量,即 [1, N, 2]。这样,a和b的形状就兼容了,可以进行逐元素乘法。
# 具体来说,a * b 的操作过程如下 :
# 将a的形状从 [N, 2] 扩展为 [1, N, 2]。
# 将扩展后的a与b进行逐元素乘法。
# 因此,a * b 的结果张量的形状与b的形状相同,即 [5, N, 2]。
# 所以,a * b 的结果张量的形状是 [5, N, 2]。
# 在pytorch中,a形状为[N,2]、b形状为[5,N, 2],a@b的结果张量的形状。
# 在PyTorch中,使用 @ 符号表示矩阵乘法,即 torch.matmul 函数。矩阵乘法要求两个张量的形状必须满足特定的条件:第一个张量的最后一个维度的大小必须等于第二个张量的倒数第二个维度的大小。
# 给定的张量形状为 :
# a :[N, 2] 。
# b :[5, N, 2]。
# 希望计算 a @ b。根据矩阵乘法的规则,a的最后一个维度(大小为2)必须与b的倒数第二个维度(大小为N)相匹配。但是,在这个例子中,a的最后一个维度是2,而b的倒数第二个维度是N,所以这两个张量不能直接进行矩阵乘法。
# 因此,a @ b 的操作是不合法的,会抛出一个错误。所以,a @ b 的结果张量的形状是 不合法 。
# 在PyTorch中,两个张量进行矩阵乘法(使用 torch.matmul 或者 @ 操作符)时,必须满足以下条件 :
# 维度匹配 :第一个张量的最后一个维度的大小必须与第二个张量的倒数第二个维度的大小相匹配。这是矩阵乘法的基本规则,即第一个矩阵的列数必须等于第二个矩阵的行数。
# 广播兼容 :除了进行乘法的维度外,其他维度的大小必须完全相同,或者其中一个张量在该维度上的大小为1,这样就可以进行广播。
# 如果这些条件不满足,PyTorch会抛出一个错误,因为无法执行矩阵乘法。例如,如果张量a的形状是[m, n],张量b的形状是[p, q],那么为了能够进行矩阵乘法,必须有n == p。
# 总结来说,两个张量矩阵相乘必须要满足的"跳进"(条件)是 :
# 第一个张量的列数必须等于第二个张量的行数。
# 其他维度必须兼容,要么完全相同,要么其中一个维度的大小为1,以满足广播规则。