CornerNet, Part 2: Data Processing and Training

Recap

With the model architecture reproduced in the previous article, the next step is to implement the corresponding data preprocessing, training, and inference; only then is the reproduction truly complete.

CornerNet-Reproduction/
├── models/
│   ├── __init__.py
│   ├── backbone.py          # Hourglass network
│   ├── corner_pool.py       # Corner Pooling implementation
│   ├── cornernet.py         # main network
│   └── loss.py              # loss functions
├── utils/
│   ├── __init__.py
│   ├── image.py             # image utilities
│   ├── decode.py            # inference decoding
│   └── visualization.py     # visualization tools
├── data/
│   ├── __init__.py
│   ├── coco.py              # COCO dataset
│   └── transforms.py        # data augmentation
├── configs/
│   └── cornernet_config.py  # config file
├── train.py                 # training script
├── test.py                  # test script
├── demo.py                  # demo script
└── requirements.txt         # dependencies

Data Augmentation

data/transforms.py

(1) Imports, plus the mean and standard deviation used for image normalization, computed over the COCO dataset:

python
import cv2
import numpy as np
import random
import torch


CORNERNET_MEAN = (0.40789654, 0.44719302, 0.47026115)  # RGB
CORNERNET_STD = (0.28863828, 0.27408164, 0.27809835)

In fact, torchvision's v2 transforms already handle images and labels jointly, so strictly speaking there is no need to hand-write these fairly involved augmentations. To sharpen my own understanding of data augmentation, though, I implemented the pipeline from scratch.
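For reference, a minimal sketch of what the built-in route might look like, assuming torchvision ≥ 0.16 (the v2 namespace and tv_tensors wrappers); the transforms chosen here are illustrative, not a drop-in replacement for the pipeline below:

python
# Sketch: torchvision v2 transforms propagate box coordinates automatically.
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

tfs = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),   # uint8 [0, 255] -> float [0, 1]
    v2.Normalize(mean=CORNERNET_MEAN, std=CORNERNET_STD),  # constants from above
])

img = tv_tensors.Image(torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8))
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10.0, 20.0, 100.0, 200.0]]),
    format="XYXY", canvas_size=(480, 640),
)
img, boxes = tfs(img, boxes)  # boxes are flipped in sync with the image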

python
# Compose: chain transforms, threading (image, target) through each one
class Compose:
    def __init__(self, transforms, strict=True):
        self.transforms = transforms
        self.strict = strict
    def __call__(self, image, target):
        for t in self.transforms:
            try:
                image, target = t(image, target)
            except Exception:
                if self.strict:
                    raise
        return image, target

Random horizontal flip; the boxes are flipped in sync with the image.

python
class RandomHorizontalFlip:
    def __init__(self, prob=0.5):
        self.prob = prob
    def __call__(self, image, target):
        if random.random() >= self.prob:
            return image, target
        image = np.ascontiguousarray(np.fliplr(image))
        w = image.shape[1]
        boxes = target["boxes"].copy()
        boxes[:, [0, 2]] = w - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return image, target
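A quick worked check of the coordinate swap: for boxes in [x1, y1, x2, y2] form, a flip maps x1' = w - x2 and x2' = w - x1, which preserves x1' < x2'. With hypothetical numbers:

python
import numpy as np

w = 100
boxes = np.array([[10.0, 5.0, 30.0, 40.0]])  # x1, y1, x2, y2
boxes[:, [0, 2]] = w - boxes[:, [2, 0]]      # x1' = w - x2, x2' = w - x1
print(boxes)  # [[70.  5. 90. 40.]] -- x1 < x2 still holds, y untouched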

Fixed-size random crop

python
class RandomCropFixSize:
    """
    把图放进 size×size 画布中(不足则 padding),然后随机裁剪 size×size
    这比"先强行 resize 再 crop"更贴近 CornerNet/CenterNet 的做法(带 border 概念)
    """
    def __init__(self, size=511, pad_value=0):
        self.size = int(size)
        self.pad_value = int(pad_value)

    def __call__(self, image, target):
        h, w = image.shape[:2]
        size = self.size

        out = np.full((max(h, size), max(w, size), 3), self.pad_value, dtype=image.dtype)
        out[:h, :w] = image

        new_h, new_w = out.shape[:2]
        x0 = random.randint(0, new_w - size)
        y0 = random.randint(0, new_h - size)
        image = out[y0:y0 + size, x0:x0 + size]

        boxes = target["boxes"].copy()
        boxes[:, [0, 2]] -= x0
        boxes[:, [1, 3]] -= y0
        boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, size)
        boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, size)

        bw = boxes[:, 2] - boxes[:, 0]
        bh = boxes[:, 3] - boxes[:, 1]
        keep = (bw > 1) & (bh > 1)
        target["boxes"] = boxes[keep]
        target["labels"] = target["labels"][keep]
        return image, target
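A seeded smoke test of the crop (hypothetical values). Boxes shifted by the crop offset are clipped to the canvas, and any box that shrinks below 1 px is dropped along with its label:

python
import random
import numpy as np

random.seed(0)  # make the random offsets reproducible
img = np.zeros((600, 800, 3), dtype=np.uint8)
target = {
    "boxes": np.array([[50.0, 60.0, 200.0, 220.0],
                       [780.0, 580.0, 799.0, 599.0]], dtype=np.float32),
    "labels": np.array([3, 7], dtype=np.int64),
}
img, target = RandomCropFixSize(size=511)(img, target)
print(img.shape)             # (511, 511, 3)
print(len(target["boxes"]))  # boxes that fell outside the crop are gone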

At validation/inference time we resize proportionally and pad instead:

python
class CenterPadResize:
    """
    验证/推理:按比例缩放,使最长边= size,然后居中 pad 到 size×size
    记录 meta:scale + border(用于还原到原图坐标)
    """
    def __init__(self, size=511, pad_value=0):
        self.size = int(size)
        self.pad_value = int(pad_value)

    def __call__(self, image, target):
        orig_h, orig_w = image.shape[:2]
        size = self.size

        scale = size / max(orig_h, orig_w)
        new_w = int(orig_w * scale + 0.5)
        new_h = int(orig_h * scale + 0.5)
        resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        out = np.full((size, size, 3), self.pad_value, dtype=resized.dtype)
        x0 = (size - new_w) // 2
        y0 = (size - new_h) // 2
        out[y0:y0 + new_h, x0:x0 + new_w] = resized

        if len(target.get("boxes", [])) > 0:
            boxes = target["boxes"].copy() * scale
            boxes[:, [0, 2]] += x0
            boxes[:, [1, 3]] += y0
            boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, size)
            boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, size)
            target["boxes"] = boxes

        target.setdefault("meta", {})
        target["meta"].update({
            "orig_size": (int(orig_h), int(orig_w)),
            "input_size": int(size),
            "scale": float(scale),
            "border": (float(x0), float(y0)),
            "resized_size": (int(new_h), int(new_w)),
        })
        return out, target
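The meta dict is what makes this transform invertible. A sketch of the inverse mapping (boxes_to_original is a hypothetical helper, not part of the file): subtract the pad border, divide by the scale, then clip to the original image:

python
import numpy as np

def boxes_to_original(boxes, meta):
    """Hypothetical helper: undo CenterPadResize on (N, 4) XYXY boxes."""
    x0, y0 = meta["border"]
    scale = meta["scale"]
    orig_h, orig_w = meta["orig_size"]
    boxes = boxes.copy()
    boxes[:, [0, 2]] = (boxes[:, [0, 2]] - x0) / scale  # undo pad, then scale
    boxes[:, [1, 3]] = (boxes[:, [1, 3]] - y0) / scale
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, orig_w)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, orig_h)
    return boxes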

Image normalization

python
class Normalize:
    def __init__(self, mean=CORNERNET_MEAN, std=CORNERNET_STD):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def __call__(self, image, target):
        image = image.astype(np.float32) / 255.0
        image = (image - self.mean) / self.std
        return image, target

Conversion to tensor

python
class ToTensor:
    def __call__(self, image, target):
        x = torch.from_numpy(image).permute(2, 0, 1).contiguous()
        return x, target

Finally, we just keep training and validation pipelines separate. The complete file is shown below:

python
import cv2
import numpy as np
import random
import torch


CORNERNET_MEAN = (0.40789654, 0.44719302, 0.47026115)  # RGB
CORNERNET_STD = (0.28863828, 0.27408164, 0.27809835)


class Compose:
    def __init__(self, transforms, strict=True):
        self.transforms = transforms
        self.strict = strict

    def __call__(self, image, target):
        for t in self.transforms:
            try:
                image, target = t(image, target)
            except Exception:
                if self.strict:
                    raise
        return image, target


class RandomHorizontalFlip:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() >= self.prob:
            return image, target
        image = np.ascontiguousarray(np.fliplr(image))
        w = image.shape[1]
        boxes = target["boxes"].copy()
        boxes[:, [0, 2]] = w - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return image, target


class RandomScale:
    """
    Common CornerNet choice: pick a scale at random from a discrete set
    """
    def __init__(self, scales=(0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4)):
        self.scales = scales

    def __call__(self, image, target):
        s = random.choice(self.scales)
        h, w = image.shape[:2]
        new_w = max(1, int(w * s + 0.5))
        new_h = max(1, int(h * s + 0.5))
        image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
        if len(target["boxes"]) > 0:
            target["boxes"] = target["boxes"] * s
        return image, target


class RandomCropFixSize:
    """
    Place the image on a size×size canvas (padding if smaller), then randomly crop size×size.
    This is closer to the CornerNet/CenterNet recipe (with its notion of a border) than naively resizing and then cropping.
    """
    def __init__(self, size=511, pad_value=0):
        self.size = int(size)
        self.pad_value = int(pad_value)

    def __call__(self, image, target):
        h, w = image.shape[:2]
        size = self.size

        out = np.full((max(h, size), max(w, size), 3), self.pad_value, dtype=image.dtype)
        out[:h, :w] = image

        new_h, new_w = out.shape[:2]
        x0 = random.randint(0, new_w - size)
        y0 = random.randint(0, new_h - size)
        image = out[y0:y0 + size, x0:x0 + size]

        boxes = target["boxes"].copy()
        boxes[:, [0, 2]] -= x0
        boxes[:, [1, 3]] -= y0
        boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, size)
        boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, size)

        bw = boxes[:, 2] - boxes[:, 0]
        bh = boxes[:, 3] - boxes[:, 1]
        keep = (bw > 1) & (bh > 1)
        target["boxes"] = boxes[keep]
        target["labels"] = target["labels"][keep]
        return image, target


class CenterPadResize:
    """
    Validation/inference: scale so the longer side equals size, then center-pad to size×size.
    Records meta (scale + border) so predictions can be mapped back to original-image coordinates.
    """
    def __init__(self, size=511, pad_value=0):
        self.size = int(size)
        self.pad_value = int(pad_value)

    def __call__(self, image, target):
        orig_h, orig_w = image.shape[:2]
        size = self.size

        scale = size / max(orig_h, orig_w)
        new_w = int(orig_w * scale + 0.5)
        new_h = int(orig_h * scale + 0.5)
        resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        out = np.full((size, size, 3), self.pad_value, dtype=resized.dtype)
        x0 = (size - new_w) // 2
        y0 = (size - new_h) // 2
        out[y0:y0 + new_h, x0:x0 + new_w] = resized

        if len(target.get("boxes", [])) > 0:
            boxes = target["boxes"].copy() * scale
            boxes[:, [0, 2]] += x0
            boxes[:, [1, 3]] += y0
            boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, size)
            boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, size)
            target["boxes"] = boxes

        target.setdefault("meta", {})
        target["meta"].update({
            "orig_size": (int(orig_h), int(orig_w)),
            "input_size": int(size),
            "scale": float(scale),
            "border": (float(x0), float(y0)),
            "resized_size": (int(new_h), int(new_w)),
        })
        return out, target


class Normalize:
    def __init__(self, mean=CORNERNET_MEAN, std=CORNERNET_STD):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def __call__(self, image, target):
        image = image.astype(np.float32) / 255.0
        image = (image - self.mean) / self.std
        return image, target


class ToTensor:
    def __call__(self, image, target):
        x = torch.from_numpy(image).permute(2, 0, 1).contiguous()
        return x, target


def get_train_transforms(input_size=511):
    return Compose([
        RandomScale(),
        RandomHorizontalFlip(0.5),
        RandomCropFixSize(size=input_size, pad_value=0),
        Normalize(),
        ToTensor(),
    ], strict=True)


def get_val_transforms(input_size=511):
    return Compose([
        CenterPadResize(size=input_size, pad_value=0),
        Normalize(),
        ToTensor(),
    ], strict=True)
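A minimal smoke test of the two pipelines on dummy data (assuming the file above is importable as data.transforms):

python
import numpy as np
from data.transforms import get_train_transforms, get_val_transforms

img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
target = {
    "boxes": np.array([[40.0, 50.0, 300.0, 400.0]], dtype=np.float32),
    "labels": np.array([0], dtype=np.int64),
}

x, t = get_train_transforms(511)(img.copy(), {k: v.copy() for k, v in target.items()})
print(x.shape, x.dtype)  # torch.Size([3, 511, 511]) torch.float32

x, t = get_val_transforms(511)(img.copy(), {k: v.copy() for k, v in target.items()})
print(t["meta"]["scale"])  # 511 / 640, about 0.798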

Dataset Loading

Before augmentation we also need to read the dataset, the key point being that images and labels stay synchronized.

Gaussian kernel generation: in corner detection, if only the single corner pixel counts as a positive sample, the criterion is too strict and training is hard to keep stable. We therefore size the Gaussian kernel dynamically from the object size, so that values decay with distance from the corner and the transition is smooth.

python
import os
import cv2
import math
import numpy as np
import torch
from torch.utils.data import Dataset
from pycocotools.coco import COCO


def gaussian_radius(det_size, min_overlap=0.7):
    height, width = det_size
    height = float(height)
    width = float(width)
    if height <= 0 or width <= 0:
        return 0.0

    a1 = 1
    b1 = height + width
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = max(0.0, b1 ** 2 - 4 * a1 * c1)
    r1 = (b1 + math.sqrt(sq1)) / 2

    a2 = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    sq2 = max(0.0, b2 ** 2 - 4 * a2 * c2)
    r2 = (b2 + math.sqrt(sq2)) / 2

    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    sq3 = max(0.0, b3 ** 2 - 4 * a3 * c3)
    r3 = (b3 + math.sqrt(sq3)) / 2

    return min(r1, r2, r3)


def gaussian2D(shape, sigma=1):
    m, n = [(ss - 1.0) / 2.0 for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h


def draw_gaussian(heatmap, center, radius):
    radius = int(radius)
    x, y = int(center[0]), int(center[1])
    height, width = heatmap.shape[:2]

    if radius <= 0:
        if 0 <= x < width and 0 <= y < height:
            heatmap[y, x] = 1.0
        return heatmap

    diameter = 2 * radius + 1
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)

    left = min(x, radius)
    right = min(width - x, radius + 1)
    top = min(y, radius)
    bottom = min(height - y, radius + 1)

    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if masked_gaussian.shape == masked_heatmap.shape:
        np.maximum(masked_heatmap, masked_gaussian, out=masked_heatmap)
        masked_heatmap[top, left] = 1.0
    return heatmap
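A quick sanity check (run in the same file, so gaussian_radius and draw_gaussian are in scope): the radius grows with object size, and the drawn corner itself is a hard positive at exactly 1.0:

python
import numpy as np

r = gaussian_radius((64, 64), min_overlap=0.7)
print(r)  # about 17.5 for a 64×64 feature-map box

hm = np.zeros((128, 128), dtype=np.float32)
draw_gaussian(hm, (40, 60), int(r))  # center is (x=40, y=60)
print(hm[60, 40])        # 1.0 -- the corner pixel itself
print(hm[60, 45] < 1.0)  # True -- values decay away from the corner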

Putting these utilities together with the dataset class, the complete file is as follows:

python
import os
import cv2
import math
import numpy as np
import torch
from torch.utils.data import Dataset
from pycocotools.coco import COCO


def gaussian_radius(det_size, min_overlap=0.7):
    height, width = det_size
    height = float(height)
    width = float(width)
    if height <= 0 or width <= 0:
        return 0.0

    a1 = 1
    b1 = height + width
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = max(0.0, b1 ** 2 - 4 * a1 * c1)
    r1 = (b1 + math.sqrt(sq1)) / 2

    a2 = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    sq2 = max(0.0, b2 ** 2 - 4 * a2 * c2)
    r2 = (b2 + math.sqrt(sq2)) / 2

    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    sq3 = max(0.0, b3 ** 2 - 4 * a3 * c3)
    r3 = (b3 + math.sqrt(sq3)) / 2

    return min(r1, r2, r3)


def gaussian2D(shape, sigma=1):
    m, n = [(ss - 1.0) / 2.0 for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h


def draw_gaussian(heatmap, center, radius):
    radius = int(radius)
    x, y = int(center[0]), int(center[1])
    height, width = heatmap.shape[:2]

    if radius <= 0:
        if 0 <= x < width and 0 <= y < height:
            heatmap[y, x] = 1.0
        return heatmap

    diameter = 2 * radius + 1
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)

    left = min(x, radius)
    right = min(width - x, radius + 1)
    top = min(y, radius)
    bottom = min(height - y, radius + 1)

    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if masked_gaussian.shape == masked_heatmap.shape:
        np.maximum(masked_heatmap, masked_gaussian, out=masked_heatmap)
        masked_heatmap[top, left] = 1.0
    return heatmap


class COCODatasetCornerNet(Dataset):
    """
    训练/验证都使用 COCO instances_*.json(这点对"严格对齐"很关键)
    需要 data_dir:
      - annotations/instances_train2017.json
      - train2017/*.jpg
    """
    def __init__(
        self,
        data_dir,
        split="train2017",
        transforms=None,
        input_size=511,
        output_size=128,
        down_ratio=4,  # downsampling ratio
        max_objs=128,  # max objects per image
    ):
        self.data_dir = data_dir
        self.split = split
        self.transforms = transforms  # transform pipeline
        self.input_size = int(input_size)
        self.output_size = int(output_size)
        self.down_ratio = int(down_ratio)
        self.max_objs = int(max_objs)

        ann_file = os.path.join(data_dir, "annotations", f"instances_{split}.json")  # annotation file
        if not os.path.exists(ann_file):
            raise FileNotFoundError(f"COCO annotation not found: {ann_file}")

        self.coco = COCO(ann_file)  # parse with pycocotools
        self.img_ids = sorted(self.coco.getImgIds())  # all image IDs, sorted

        self.cat_ids = sorted(self.coco.getCatIds())  # COCO category IDs, sorted: 1-90 with gaps, 80 classes in total
        self.cat_to_idx = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}  # map the sparse IDs to contiguous indices: {1: 0, 2: 1, ..., 90: 79}
        self.idx_to_cat = {i: cat_id for cat_id, i in self.cat_to_idx.items()}  # inverse map: {0: 1, 1: 2, ..., 79: 90}
        self.num_classes = len(self.cat_ids)

        # Filter out images without annotations (more stable training; for validation this is optional)
        valid = []
        for img_id in self.img_ids:
            ann_ids = self.coco.getAnnIds(imgIds=img_id)  # all annotation IDs for this image
            if len(ann_ids) > 0:  # keep only images that actually contain objects
                valid.append(img_id)
        self.img_ids = valid

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]  # image metadata
        img_path = os.path.join(self.data_dir, self.split, img_info["file_name"])
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Failed to read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # BGR -> RGB

        ann_ids = self.coco.getAnnIds(imgIds=img_id)  # annotation IDs for this image
        anns = self.coco.loadAnns(ann_ids)  # load the annotations themselves

        h_img, w_img = image.shape[:2]  # image height and width
        boxes = []
        labels = []
        for ann in anns:
            if ann.get("iscrowd", 0) == 1:  # skip crowd annotations
                continue
            x, y, w, h = ann["bbox"]  # COCO bbox format: [x, y, w, h]
            if w < 1 or h < 1:
                continue

            # Convert [x, y, w, h] to [x1, y1, x2, y2], clipped to the image
            x1 = max(0.0, float(x))
            y1 = max(0.0, float(y))
            x2 = min(float(w_img), float(x + w))
            y2 = min(float(h_img), float(y + h))
            # Drop degenerate boxes
            if x2 <= x1 or y2 <= y1:
                continue
            # Keep the valid boxes and labels
            boxes.append([x1, y1, x2, y2])
            labels.append(self.cat_to_idx[int(ann["category_id"])])

        boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)  # keep shape (N, 4) even when empty
        labels = np.asarray(labels, dtype=np.int64)

        target = {"boxes": boxes, "labels": labels, "image_id": int(img_id)}  #创建 target dict
        if self.transforms is not None:
            image, target = self.transforms(image, target)  ## 如缩放、翻转等

        # 输出 target(严格版)
        gt = self._gen_targets(target)   # 调用内部方法生成最终标签
        return image, gt   # 返回图片和真实的标签

    # The heart of CornerNet: target generation

    def _gen_targets(self, target):
        boxes = target["boxes"]
        labels = target["labels"]

        H = self.output_size
        W = self.output_size

        # One heatmap per class
        tl_hm = np.zeros((self.num_classes, H, W), dtype=np.float32)
        br_hm = np.zeros((self.num_classes, H, W), dtype=np.float32)

        # Flattened TL/BR corner indices, used later to gather predictions
        tl_inds = np.zeros((self.max_objs,), dtype=np.int64)
        br_inds = np.zeros((self.max_objs,), dtype=np.int64)

        # Sub-pixel offsets
        tl_regs = np.zeros((self.max_objs, 2), dtype=np.float32)
        br_regs = np.zeros((self.max_objs, 2), dtype=np.float32)
        # Mask (1 or 0) marking which of the max_objs slots hold a valid object
        reg_mask = np.zeros((self.max_objs,), dtype=np.float32)
        # Fill the arrays object by object
        num = min(len(boxes), self.max_objs)  # number of objects, capped at max_objs
        for i in range(num):
            # Class of the current object
            cls = int(labels[i])
            x1, y1, x2, y2 = [float(v) for v in boxes[i]]  # box coordinates
            if x2 <= x1 + 1 or y2 <= y1 + 1:  # skip boxes that are too small
                continue
            # CornerNet convention: the BR corner uses (x2 - 1, y2 - 1)
            br_x_pix = x2 - 1.0
            br_y_pix = y2 - 1.0
            # Map image coordinates to feature-map coordinates (still floats here)
            tl_x = x1 / self.down_ratio
            tl_y = y1 / self.down_ratio
            br_x = br_x_pix / self.down_ratio
            br_y = br_y_pix / self.down_ratio
            # Clip to [0, size - 1); the small eps keeps coords strictly below the edge
            eps = 1e-4
            tl_x = float(np.clip(tl_x, 0.0, W - 1 - eps))
            tl_y = float(np.clip(tl_y, 0.0, H - 1 - eps))
            br_x = float(np.clip(br_x, 0.0, W - 1 - eps))
            br_y = float(np.clip(br_y, 0.0, H - 1 - eps))

            # Integer pixel positions on the feature map
            tl_xi, tl_yi = int(tl_x), int(tl_y)
            br_xi, br_yi = int(br_x), int(br_y)

            # Box size on the feature map
            box_w = (x2 - x1) / self.down_ratio
            box_h = (y2 - y1) / self.down_ratio

            # Gaussian radius derived from the box size
            radius = gaussian_radius((math.ceil(box_h), math.ceil(box_w)))
            radius = int(max(0, radius))
            # Splat the Gaussians onto the class heatmaps
            draw_gaussian(tl_hm[cls], (tl_xi, tl_yi), radius)
            draw_gaussian(br_hm[cls], (br_xi, br_yi), radius)

            # Flattened corner index: (y, x) -> y * W + x
            # e.g. on a 128×128 map, (y=12, x=25) -> 12 * 128 + 25 = 1561
            tl_inds[i] = tl_yi * W + tl_xi
            br_inds[i] = br_yi * W + br_xi

            # Offsets: float coordinate minus integer coordinate
            tl_regs[i, 0] = tl_x - tl_xi
            tl_regs[i, 1] = tl_y - tl_yi
            br_regs[i, 0] = br_x - br_xi
            br_regs[i, 1] = br_y - br_yi

            # Mark this slot as a valid object
            reg_mask[i] = 1.0

        return {
            "tl_heatmaps": torch.from_numpy(tl_hm),
            "br_heatmaps": torch.from_numpy(br_hm),
            "tl_inds": torch.from_numpy(tl_inds),
            "br_inds": torch.from_numpy(br_inds),
            "tl_regs": torch.from_numpy(tl_regs),
            "br_regs": torch.from_numpy(br_regs),
            "reg_mask": torch.from_numpy(reg_mask),
            "image_id": torch.tensor(int(target.get("image_id", -1)), dtype=torch.int64),
            "meta": target.get("meta", {}),
        }


# Collate function: packs each batch into stacked tensors for the model.
def collate_fn(batch):
    # batch looks like: [(image1, target1), (image2, target2), (image3, target3)]
    images = torch.stack([b[0] for b in batch], dim=0)  # all images share a shape (the transforms guarantee this), otherwise stack would fail
    keys = ["tl_heatmaps", "br_heatmaps", "tl_inds", "br_inds", "tl_regs", "br_regs", "reg_mask", "image_id"]
    targets = {k: torch.stack([b[1][k] for b in batch], dim=0) for k in keys}
    targets["meta"] = [b[1].get("meta", {}) for b in batch]  # kept as a list, not stacked
    return images, targets
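Wiring it into a DataLoader, a sketch with a placeholder data path, plus a glimpse of how the flattened tl_inds are consumed at loss time (gathering per-object values from a head output flattened over the spatial dims):

python
import torch
from torch.utils.data import DataLoader
from data.transforms import get_train_transforms

dataset = COCODatasetCornerNet(
    data_dir="/path/to/coco",  # placeholder
    split="train2017",
    transforms=get_train_transforms(511),
)
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    num_workers=4, collate_fn=collate_fn)

images, targets = next(iter(loader))
print(images.shape)                  # (4, 3, 511, 511)
print(targets["tl_heatmaps"].shape)  # (4, num_classes, 128, 128)

# Gather per-object values (here from a dummy 2-channel offset head):
B, C, H, W = images.shape[0], 2, 128, 128
pred = torch.randn(B, C, H, W).view(B, C, H * W)
inds = targets["tl_inds"].unsqueeze(1).expand(B, C, -1)  # (B, C, max_objs)
per_obj = pred.gather(2, inds).permute(0, 2, 1)          # (B, max_objs, 2)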

Heatmap visualization:
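The original post shows a rendered heatmap at this point. A sketch of producing one with matplotlib, reusing the dataset object from the snippet above (class heatmaps are max-reduced for display):

python
import matplotlib.pyplot as plt

image, gt = dataset[0]
tl = gt["tl_heatmaps"].max(dim=0).values.numpy()  # collapse classes -> (128, 128)
br = gt["br_heatmaps"].max(dim=0).values.numpy()

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(tl, cmap="hot"); axes[0].set_title("top-left corners")
axes[1].imshow(br, cmap="hot"); axes[1].set_title("bottom-right corners")
plt.show()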

