手撕ultralytics,换用Lightning训练yolo模型

YOLO 模型作为目标检测的一座高峰不必多说,快又好用。一般来说是用叫做 ultralytics 的 Python 库使用和训练 YOLO 模型。库写得非常好,能很简便地用一个函数启用模型训练。

python 复制代码
from ultralytics import YOLO

model = YOLO("yolo12l.pt")

results = model.train(
    data="/mnt/sda/data/20250312_SARDet100K/sar100k.yaml",
    epochs=100,
    imgsz=640,
)

但如果有更高的自定义需求,这种一键训练的方式就不够用了。如果能把训练代码写成以下标准的 PyTorch 训练形式,那添加自定义修改就方便多了。

python 复制代码
train_loader = DataLoader(train_dataset, batch_size=..., shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=...)

for epoch in range(...):
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            val_loss = criterion(pred, y)

经过几周的鏖战,终于是把 ultralytics 手撕得差不多,摆脱了 model.train() 的束缚。现在能自由训练目标检测了。

总览

ultralytics 库的逻辑写得很紧凑,完全改写是相当困难的。比较现实的修改方法是借用和继承原库的一些库和方法,使用符合 ultralytics 的数据形式。

中途还遇到了个奇怪的问题。使用 torchvision 的数据增强方法会损坏 YOLO 预训练权重性能,必须用 ultralytics 的数据增强。即使是很小心地控制变量、只选择两者都有的数据增强方法,肉眼完全看不出图像和标注框差异,实验都只能得出一样的结果。那就这样吧。

Lightning 是一个辅助编写 PyTorch 训练代码的库,可以把像是训练循环封装成一个函数,不论是编写还是查阅都会轻松许多。即使没接触过 Lightning 也没关系,后文看函数名也能知道我写的啥逻辑。

本文尽可能简化代码逻辑,主要起示例作用。

数据准备

Dataset

需要构造出一个符合 ultralytics 吸怪的数据集。这个数据集需要是一个字典,包含这些键:

  • img,图片矩阵。用 Image.open() 读出来后除以 255 就能符合要求了
  • bboxes,标注框,以 xywh 形式存储的 List[List] 对象
  • cls,类别,纯数字
  • bbox_format,这个填 "xywh" 就行
  • normalized,填 True
  • ori_shape,原始图片大小
  • ratio_pad,不清楚,填 None 就可以

具体实现看代码。

  • __init__() ,写有数据增强逻辑
  • __len__(),让数据集能被获取长度
  • update_labels_info(),从 ultralytics 摘抄过来用于辅助生成 label 数据的函数
  • __getitem__() 进行实际的数据构造。重点看这个函数的代码
python 复制代码
from torch.utils.data import Dataset
from ultralytics.data.augment import (
    Compose,
    Format,
    LetterBox,
    RandomPerspective,
    RandomHSV,
    RandomFlip,
)
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.instance import Instances

class MyDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

        pre_transform = RandomPerspective(
            degrees=0.0,
            translate=0.0,
            scale=0.5,
            shear=0.0,
            perspective=0.0,
            pre_transform=LetterBox(new_shape=(512, 512), scaleup=False),
        )
        self.transforms = Compose(
            [
                pre_transform,
                RandomHSV(hgain=0.015, sgain=0.7, vgain=0.4),
                RandomFlip(direction="vertical", p=0.0),
                RandomFlip(direction="horizontal", p=0.5),
            ]
        )
        self.transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=False,
                return_keypoint=False,
                return_obb=False,
                batch_idx=True,
                mask_ratio=4,
                mask_overlap=True,
                bgr=0.0,
            )
        )

    def __len__(self):
        return len(self.dataset)

    def update_labels_info(self, label: Dict) -> Dict:
        """
        Update label format for different tasks.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 1000
        if len(segments) > 0:
            # make sure segments interpolate correctly if original length is greater than segment_resamples
            max_len = max(len(s) for s in segments)
            segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
            # list[np.array(segment_resamples, 2)] * num_samples
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)

        bboxes = bboxes if bboxes.size else np.zeros((0, 4), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)

        return label

    def __getitem__(self, idx):
        image_path, annotations = self.dataset[idx]

        with Image.open(image_path) as img:
            this_img = img.convert("RGB")

        original_size = this_img.size

        boxes = []
        classes = []
        for one_box in annotations:
            bbox = one_box["bbox"]
            category_id = one_box["category_id"]
            x, y, w, h = bbox
            boxes.append([x, y, w, h])
            classes.append([category_id])

        bboxes = np.array(boxes, dtype=np.float32)
        cls = np.array(classes, dtype=np.float32)

        label = {
            'img': np.array(this_img),
            'bboxes': bboxes,
            'cls': cls,
            'bbox_format': 'xywh',
            'normalized': True,
            'ori_shape': original_size,
            'ratio_pad': None,
        }
        label = self.update_labels_info(label)
        label = self.transforms(label)

        label["img"] = label["img"] / 255.0

        return label

DataLoader

从 ultralytics 摘抄 collate_fn(),之后要传入到 DataLoader 代替默认 collator。

python 复制代码
def collate_fn(batch: List[Dict]) -> Dict:
    """
    Collate data samples into batches.

    Args:
        batch (List[dict]): List of dictionaries containing sample data.

    Returns:
        (dict): Collated batch with stacked tensors.
    """
    new_batch = {}
    batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
    keys = batch[0].keys()
    values = list(zip(*[list(b.values()) for b in batch]))
    for i, k in enumerate(keys):
        value = values[i]
        if k in {"img", "text_feats"}:
            value = torch.stack(value, 0)
        elif k == "visuals":
            value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
        if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
            value = torch.cat(value, 0)
        new_batch[k] = value
    new_batch["batch_idx"] = list(new_batch["batch_idx"])
    for i in range(len(new_batch["batch_idx"])):
        new_batch["batch_idx"][i] += i  # add target image index for build_targets()
    new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
    
    return new_batch

实例化 dataloader。

python 复制代码
from torch.utils.data import DataLoader

train_dataset = MyDataset(train_dataset)

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True,
)

模型定义

这一步会给原来的 YOLO 模型套层壳,方便后面使用。

  • __init__(),用比较别扭的方式初始化模型并加载预训练权重
  • forward(),输入图像进行正向传播。注意,在 train 状态下,会输出 loss_out;在 eval 状态下,会输出 (inference_out, loss_out)
  • get_loss(),输入 batch 数据和 loss_out,输出 loss
  • get_bboxes,输入 inference_out,输出 bboxes。会用 non_max_suppression 处理 bbox
python 复制代码
from types import SimpleNamespace
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import ops

class YOLOModule(DetectionModel):
    def __init__(self, num_class, channels, model="yolo11n.pt", pretrained=False):
        model = YOLO(model)
        cfg = model.yaml

        args = model.args
        args.update(
            {
                "box": 7.5,
                "cls": 0.5,
                "dfl": 1.5,
            }
        )
        self.args = SimpleNamespace(**args)
        self.overrides = args

        super().__init__(cfg, nc=num_class, ch=channels, verbose=False)
        if pretrained:
            self.load(model.model)

    def forward(self, x):
        preds = self.predict(x)
        return preds

    def get_loss(self, batch, preds):
        return self.loss(batch, preds)[0]

    def get_bboxes(self, preds):
        preds = ops.non_max_suppression(
            preds,
            conf_thres=0.25,
            iou_thres=0.7,
            max_det=300,
            return_idxs=False,
        )
        return preds

训练代码 / Lightning Module 定义

以下代码主要看 training_step()validation_step() 的逻辑,看是如何得到最终的 loss 的(Lightning 会帮忙调用 loss.backward() 等函数)。

python 复制代码
class LightningModel(BaseModule):

    def __init__(self, model):
        super().__init__()

        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch['img']
        batch_size = x.shape[0]
        loss_out = self(x)

        loss = self.model.get_loss(
            batch=batch,
            preds=loss_out,
        )

        box_loss, cls_loss, dfl_loss = loss / batch_size
        loss = box_loss + cls_loss + dfl_loss
        self.log('train/loss', loss, on_step=True, on_epoch=True, batch_size=batch_size, sync_dist=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch['img']
        batch_size = x.shape[0]
        inference_out, loss_out = self(x)

        loss = self.model.get_loss(
            batch=batch,
            preds=loss_out,
        )

        box_loss, cls_loss, dfl_loss = loss / batch_size
        loss = box_loss + cls_loss + dfl_loss
        self.log('val/loss', loss, on_step=False, on_epoch=True, batch_size=batch_size, sync_dist=True)

        return loss
相关推荐
地平线开发者24 分钟前
理想汽车智驾方案介绍专题 1 端到端+VLM 方案介绍
算法·自动驾驶
地平线开发者42 分钟前
征程 6 | UCP 任务优先级/抢占简介与实操
算法·自动驾驶
杰克尼1 小时前
912. 排序数组
算法
jndingxin2 小时前
OpenCV直线段检测算法类cv::line_descriptor::LSDDetector
人工智能·opencv·算法
秋说2 小时前
【PTA数据结构 | C语言版】阶乘的递归实现
c语言·数据结构·算法
小指纹3 小时前
巧用Bitset!优化dp
数据结构·c++·算法·代理模式·dp·bitset
爱Java&Java爱我4 小时前
数组:从键盘上输入10个数,合法值为1、2或3,不是这三个数则为非法数字,试编辑统计每个整数和非法数字的个数
java·开发语言·算法
是店小二呀5 小时前
【算法-BFS 解决最短路问题】探索BFS在图论中的应用:最短路径问题的高效解法
算法·图论·宽度优先
qq_513970445 小时前
力扣 hot100 Day46
算法·leetcode
满分观察网友z6 小时前
递归与迭代的优雅之舞:我在评论区功能中悟出的“树”之道(104. 二叉树的最大深度)
后端·算法