【YOLO】YOLOv5 Training Strategies (2): Automatic Anchors (autoanchor)

1. Default Anchor Boxes

The default anchor boxes can be found in yolov5x.yaml:

yaml
# YOLOv5 ships with predefined anchor boxes. They were computed for the COCO dataset,
# but they also work as a starting point for other detection datasets.
# The values assume 640x640 input images and are the default anchor sizes.
# Note that large feature maps detect small objects (they preserve more fine detail), so the
# anchors on the largest feature map are the smallest; the smallest feature map detects
# large objects, so its anchors are the largest.
anchors:
  - [10,13, 16,30, 33,23]       # P3/8    anchors on the largest feature map (smallest objects)
  - [30,61, 62,45, 59,119]      # P4/16   anchors on the medium feature map
  - [116,90, 156,198, 373,326]  # P5/32   anchors on the smallest feature map (largest objects)
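
Inside the model these values are not stored in pixels: the Detect head keeps its anchors in grid units (pixels divided by the layer stride), which is why check_anchors below multiplies by m.stride before measuring fit and divides by it again when writing new anchors back. A minimal sketch of that conversion, using the numbers from the yaml above and the strides 8/16/32 of P3/P4/P5:

python
import numpy as np

# Default anchors in pixels on a 640x640 input, one row of three (w, h) pairs per detection layer.
anchors_px = np.array([
    [[10, 13], [16, 30], [33, 23]],       # P3, stride 8
    [[30, 61], [62, 45], [59, 119]],      # P4, stride 16
    [[116, 90], [156, 198], [373, 326]],  # P5, stride 32
], dtype=float)
strides = np.array([8., 16., 32.]).reshape(3, 1, 1)

# Grid-space anchors as stored on the Detect head (pixels / stride).
anchors_grid = anchors_px / strides
print(anchors_grid[0])  # roughly [[1.25, 1.625], [2.0, 3.75], [4.125, 2.875]]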

2. Automatic Anchors

Before training starts, YOLOv5 computes the best possible recall (BPR) of the dataset's labels against the default anchors. If the BPR is at least 0.98, the anchors are left unchanged; if it is below 0.98, new anchors are computed from the dataset, and they replace the original ones only if they turn out to fit better.

The code is in yolov5/utils/autoanchor.py.

python
def check_anchors(dataset, detect, thr=4.0, imgsz=640):
    """
    # Check anchor fit to data, recompute if necessary
    检查 anchor你和数据情况, 如果不合理重新计算anchors
    输入:
        dataset - 数据集对象
        detect - 模型的最后一层,检测头
        thr - 阈值?
        imgsz - 图片大小
    """
    m = detect
    # per-image shapes rescaled so the longest side equals imgsz, e.g. 640 * [1280, 960] / 1280
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    # random augmentation scale in [0.9, 1.1] to mimic training-time size jitter
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))
    # label (w, h) in pixels: normalized label wh (columns 3:5) times the jittered image size
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()

    # compute bpr (best possible recall) and aat (anchors above threshold)
    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1 / thr).float().sum(1).mean()   # anchors above threshold
        bpr = (best > 1 / thr).float().mean()       # best possible recall
        return bpr, aat

    # compute the BPR with the current anchors (converted back to pixel space)
    stride = m.stride.to(m.anchors.device).view(-1, 1, 1)   # model strides
    anchors = m.anchors.clone() * stride                    # current anchors
    bpr, aat = metric(anchors.cpu().view(-1, 2))

    s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
    # if BPR > 0.98 the current anchors already fit the dataset well;
    # otherwise recompute them and keep whichever set is better
    if bpr > 0.98:
        logging.info(f'{s}Current anchors are a good fit to dataset ✅')
    else:
        logging.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
        na = m.anchors.numel() // 2  # number of anchors
        # recompute anchors with k-means clustering followed by genetic evolution
        anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(anchors)[0]
        # replace the anchors only if the new set achieves a higher BPR
        if new_bpr > bpr:  # replace anchors
            anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchors[:] = anchors.clone().view_as(m.anchors)
            check_anchor_order(m)  # must be in pixel-space (not grid-space)
            m.anchors /= stride
            s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
        else:
            s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
        logging.info(s)
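
To make the ratio metric concrete, below is a small self-contained sketch with made-up label and anchor sizes (not taken from any real dataset). For every label wh and every anchor it computes the per-side ratio r, folds over- and under-sized anchors together with min(r, 1/r), keeps the worse of the two sides, and counts a label as recallable when its best anchor is within a factor of thr=4 on both sides (best > 1/4). bpr is the fraction of recallable labels, aat the average number of passing anchors per label.

python
import torch

thr = 4.0
wh = torch.tensor([[100., 50.], [12., 14.], [400., 300.]])  # toy label (w, h) in pixels
k = torch.tensor([[10., 13.], [62., 45.], [373., 326.]])    # toy anchors in pixels

r = wh[:, None] / k[None]                   # (labels, anchors, 2) per-side ratios
x = torch.min(r, 1 / r).min(2)[0]           # worst side, symmetric in over-/under-sizing
best = x.max(1)[0]                          # best anchor per label
aat = (x > 1 / thr).float().sum(1).mean()   # anchors above threshold, averaged over labels
bpr = (best > 1 / thr).float().mean()       # best possible recall
print(best, aat.item(), bpr.item())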

3. Computing New Anchors for the Training Set with k-means

The code is also in yolov5/utils/autoanchor.py.

python
def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            dataset: path to data.yaml, or a loaded dataset
            n: number of anchors
            img_size: image size used for training
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using genetic algorithm
            verbose: print all results

        Return:
            k: kmeans evolved anchors

        Usage:
            from utils.autoanchor import *; _ = kmean_anchors()
    """
    from scipy.cluster.vq import kmeans

    npr = np.random
    thr = 1 / thr

    def metric(k, wh):  # compute metrics
        r = wh[:, None] / k[None]
        x = torch.min(r, 1 / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k, verbose=True):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
            f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
            f'past_thr={x[x > thr].mean():.3f}-mean: '
        for x in k:
            s += '%i,%i, ' % (round(x[0]), round(x[1]))
        if verbose:
            logging.info(s[:-2])
        return k

    if isinstance(dataset, str):  # *.yaml file
        with open(dataset, errors='ignore') as f:
            data_dict = yaml.safe_load(f)  # model dict
        from utils.dataloaders import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        logging.info(f'{PREFIX}WARNING ⚠️ Extremely small objects found: {i} of {len(wh0)} labels are <3 pixels in size')
    wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32)  # filter > 2 pixels
    # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1

    # Kmeans init
    try:
        logging.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
        assert n <= len(wh)  # apply overdetermined constraint
        s = wh.std(0)  # sigmas for whitening
        k = kmeans(wh / s, n, iter=30)[0] * s  # points
        assert n == len(k)  # kmeans may return fewer points than requested if wh is insufficient or too similar
    except Exception:
        logging.warning(f'{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init')
        k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size  # random init
    wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
    k = print_results(k, verbose=False)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0],400)
    # ax[1].hist(wh[wh[:, 1]<100, 1],400)
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, anchor array shape, mutation prob, sigma
    pbar = tqdm(range(gen), bar_format=TQDM_BAR_FORMAT)  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = anchor_fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
            if verbose:
                print_results(k, verbose)

    return print_results(k).astype(np.float32)
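
In practice check_anchors calls kmean_anchors automatically, but it can also be run by hand to bake new anchors into a model yaml. A usage sketch, assuming the dataset yaml path below exists and the standard layout of 3 anchors x 3 detection layers (n=9):

python
# Run from the yolov5 repository root; data/coco128.yaml is only an example path.
from utils.autoanchor import kmean_anchors

k = kmean_anchors(dataset='data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=False)
# k is an (n, 2) array of (w, h) anchors in pixels, sorted from small to large area.
# Copy the values into the anchors: section of the model yaml, three pairs per layer:
# the first three go to P3/8, the next three to P4/16, the last three to P5/32.
print(k.round().astype(int).reshape(3, 3, 2))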
