YOLOv8 + ConvLSTM 训练 MOT16 数据集

使用《基于MOT16数据集做目标检测的预处理(类别合并与清理)》(CSDN 博客)中修改后的标签进行训练,代码如下,效果尚待验证。

python 复制代码
import os
import cv2
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict

# ===================== Global configuration =====================
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
IMG_SIZE = 640       # side of the square network input, in pixels
SEQ_LEN = 5          # frames per clip consumed by the ConvLSTMs
NC = 3               # number of valid classes (ids 0, 1, 2)
NA = 3               # anchors per grid cell at every scale
NO = 5 + NC          # outputs per anchor: x, y, w, h, objectness, then one score per class
BATCH_SIZE = 8
LR = BATCH_SIZE / 2 * 1e-3   # linear LR scaling: 1e-3 per 2 samples in a batch
CONF_THR = 0.25      # confidence threshold; not referenced by the visible training code
IOU_THR = 0.45       # IoU threshold; not referenced by the visible training code

# Anchor (w, h) sizes in input-image pixels, one row per detection scale.
# NOTE: these are the YOLOv3/YOLOv5 default anchors. Upstream YOLOv8 is
# anchor-free; this model keeps an anchor-based head.
ANCHORS = torch.tensor(
    (
        ((10, 13), (16, 30), (33, 23)),       # P3
        ((30, 61), (62, 45), (59, 119)),      # P4
        ((116, 90), (156, 198), (373, 326)),  # P5
    ),
    device=DEVICE,
)
STRIDES = torch.tensor((8, 16, 32), device=DEVICE)  # downsampling factor of P3 / P4 / P5

# ===================== YOLOv8 基础模块 =====================
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> SiLU (or identity) building block.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: explicit padding; ``None`` selects ``k // 2`` ("same" padding for
            odd kernels at stride 1). ``0`` means no padding.
        g: number of groups.
        act: apply SiLU when True, identity otherwise.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        # Compare against None: a truthiness test would turn an explicit
        # p=0 into k // 2 and silently pad anyway.
        padding = k // 2 if p is None else p
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 expand unit with an optional residual connection.

    The residual add is only used when ``shortcut`` is requested and the input
    and output channel counts match, so that ``x + y`` is shape-valid.

    Args:
        c1: input channels.
        c2: output channels.
        shortcut: request the residual connection.
        g: groups of the 3x3 convolution.
        e: hidden-channel expansion ratio relative to ``c2``.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        y = self.cv2(self.cv1(x))
        if not self.add:
            return y
        return x + y

class C2f(nn.Module):
    """CSP-style block: split, run ``n`` chained bottlenecks, concatenate all.

    ``cv1`` projects the input to ``2 * c`` channels and splits them into two
    halves. Each bottleneck consumes the most recently produced tensor. The two
    halves plus every bottleneck output are concatenated and projected to
    ``c2`` by ``cv2``.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__()
        self.c = int(c2 * e)
        hidden = self.c
        self.cv1 = Conv(c1, 2 * hidden, 1, 1)
        self.cv2 = Conv((2 + n) * hidden, c2, 1)
        self.m = nn.ModuleList([Bottleneck(hidden, hidden, shortcut, g) for _ in range(n)])

    def forward(self, x):
        parts = list(self.cv1(x).chunk(2, 1))
        for block in self.m:
            parts.append(block(parts[-1]))
        return self.cv2(torch.cat(parts, 1))

class SPPF(nn.Module):
    """Fast spatial pyramid pooling.

    The input is reduced to ``c1 // 2`` channels. The same stride-1 max-pool is
    then applied three times in a row, and the reduced map plus its three
    pooled versions are concatenated and projected to ``c2`` channels.
    """

    def __init__(self, c1, c2, k=5):
        super().__init__()
        mid = c1 // 2
        self.cv1 = Conv(c1, mid, 1, 1)
        self.cv2 = Conv(4 * mid, c2, 1, 1)
        self.m = nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        pyramid = [self.cv1(x)]
        for _ in range(3):
            pyramid.append(self.m(pyramid[-1]))
        return self.cv2(torch.cat(pyramid, dim=1))

# ===================== ConvLSTM 模块 =====================
class ConvLSTMCell(nn.Module):
    """One step of a convolutional LSTM.

    A single convolution over ``cat([x, h])`` produces the pre-activations of
    the input, forget and output gates and the candidate cell state, in that
    channel order.

    Args:
        in_channels: channels of the per-step input ``x``.
        hidden_channels: channels of the hidden state ``h`` and cell state ``c``.
        k: convolution kernel size.
        p: convolution padding (``p = k // 2`` keeps the spatial size for odd ``k``).
    """

    def __init__(self, in_channels, hidden_channels, k=3, p=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels + hidden_channels, 4 * hidden_channels, k, 1, p)

    def forward(self, x, h, c):
        """Return the next ``(h, c)`` given input ``x`` and previous ``(h, c)``."""
        stacked = torch.cat([x, h], dim=1)
        in_gate, forget_gate, out_gate, candidate = self.conv(stacked).chunk(4, dim=1)
        c_next = torch.sigmoid(forget_gate) * c + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next

class ConvLSTM(nn.Module):
    """Run a ConvLSTMCell over the time axis of a (B, T, C, H, W) tensor.

    Hidden and cell states start at zero and share the input's channel count.
    Returns the stacked hidden states of every step, shape (B, T, C, H, W).
    """

    def __init__(self, channels):
        super().__init__()
        self.cell = ConvLSTMCell(channels, channels)

    def forward(self, x_seq):
        batch, _, chans, height, width = x_seq.shape
        state_shape = (batch, chans, height, width)
        h = torch.zeros(state_shape, device=x_seq.device)
        c = torch.zeros(state_shape, device=x_seq.device)
        hidden_states = []
        for frame in x_seq.unbind(dim=1):
            h, c = self.cell(frame, h, c)
            hidden_states.append(h)
        return torch.stack(hidden_states, dim=1)

# ===================== 主干网络 YOLOv8n + ConvLSTM =====================
class YOLOv8n_ConvLSTM(nn.Module):
    """YOLOv8n-style detector whose P3/P4/P5 features are fused over time by ConvLSTMs.

    Input:
        imgs_seq: tensor of shape (B, T, 3, H, W).

    Output:
        List of three raw head maps ``[P3, P4, P5]``. Each has shape
        (B, NA * NO, H_i, W_i), with strides 8 / 16 / 32 relative to the input.
        Only the last time step of each ConvLSTM feeds the neck and head.
    """

    def __init__(self):
        super().__init__()
        self.na = NA
        self.nc = NC
        self.no = NO

        # Backbone (index: role)
        self.backbone = nn.Sequential(
            Conv(3, 16, 3, 2),     # 0: /2
            Conv(16, 32, 3, 2),    # 1: /4
            C2f(32, 32, n=1),      # 2
            Conv(32, 64, 3, 2),    # 3: /8
            C2f(64, 64, n=2),      # 4: P3 tap, 64 ch (80x80 at 640 input)
            Conv(64, 128, 3, 2),   # 5: /16
            C2f(128, 128, n=2),    # 6: P4 tap, 128 ch (40x40)
            Conv(128, 256, 3, 2),  # 7: /32
            C2f(256, 256, n=1),    # 8
            SPPF(256, 256)         # 9: P5 tap, 256 ch (20x20)
        )

        # One ConvLSTM per scale, applied after the backbone taps.
        self.lstm_p3 = ConvLSTM(64)
        self.lstm_p4 = ConvLSTM(128)
        self.lstm_p5 = ConvLSTM(256)

        # Neck (PAN): [0]/[1] top-down, [2]..[5] bottom-up.
        self.neck = nn.ModuleList([
            C2f(256 + 128, 128, n=1, shortcut=False),
            C2f(128 + 64, 64, n=1, shortcut=False),
            Conv(64, 64, 3, 2),
            C2f(64 + 128, 128, n=1, shortcut=False),
            Conv(128, 128, 3, 2),
            C2f(128 + 256, 256, n=1, shortcut=False),
        ])

        # Head: one 1x1 conv per scale producing NA * NO channels.
        self.head = nn.ModuleList([
            nn.Conv2d(64, self.na * self.no, 1),
            nn.Conv2d(128, self.na * self.no, 1),
            nn.Conv2d(256, self.na * self.no, 1)
        ])

    def extract_p3_p4_p5(self, x):
        """Run the backbone on (N, 3, H, W) frames and return (P3, P4, P5).

        Each scale is tapped after its C2f stage (index 4 for P3, 6 for P4,
        and the SPPF output for P5), matching where YOLOv8 takes its skip
        features, so that refinement is part of the returned maps.
        """
        p3 = self.backbone[:5](x)    # stem .. C2f(64)      -> 64 ch, stride 8
        p4 = self.backbone[5:7](p3)  # Conv + C2f(128)      -> 128 ch, stride 16
        p5 = self.backbone[7:](p4)   # Conv + C2f + SPPF    -> 256 ch, stride 32
        return p3, p4, p5

    def forward(self, imgs_seq):
        B, T, _, _, _ = imgs_seq.shape

        # Run the shared backbone once over all B*T frames instead of T
        # separate passes. In train mode, BatchNorm statistics are therefore
        # computed over the B*T frames together.
        p3, p4, p5 = self.extract_p3_p4_p5(imgs_seq.reshape(B * T, *imgs_seq.shape[2:]))

        # Restore the time axis, fuse each scale over time, keep the last step.
        p3 = self.lstm_p3(p3.reshape(B, T, *p3.shape[1:]))[:, -1]
        p4 = self.lstm_p4(p4.reshape(B, T, *p4.shape[1:]))[:, -1]
        p5 = self.lstm_p5(p5.reshape(B, T, *p5.shape[1:]))[:, -1]

        # Top-down path. Upsample to the skip tensor's exact size (nearest),
        # so inputs whose side is not a multiple of 32 still line up.
        p4_td = self.neck[0](torch.cat([F.interpolate(p5, size=p4.shape[-2:]), p4], dim=1))
        p3_out = self.neck[1](torch.cat([F.interpolate(p4_td, size=p3.shape[-2:]), p3], dim=1))
        out1 = self.head[0](p3_out)

        # Bottom-up path. Reuse p4_td rather than recomputing neck[0]: a second
        # pass gives the same tensor but double-updates BatchNorm running stats.
        p4_out = self.neck[3](torch.cat([self.neck[2](p3_out), p4_td], dim=1))
        out2 = self.head[1](p4_out)

        p5_out = self.neck[5](torch.cat([self.neck[4](p4_out), p5], dim=1))
        out3 = self.head[2](p5_out)

        return [out1, out2, out3]

# ===================== CIoU 计算 =====================
def bbox_iou(box1, box2, xywh=True, CIoU=True, eps=1e-7):
    """Element-wise IoU or Complete-IoU between two broadcast-compatible box tensors.

    Args:
        box1, box2: tensors whose last axis holds box coordinates,
            ``(cx, cy, w, h)`` when ``xywh`` is True, else ``(x1, y1, x2, y2)``.
            Only the first four entries of the last axis are read. Leading
            dimensions must broadcast against each other.
        xywh: coordinate format of both inputs.
        CIoU: when True, subtract the normalised centre-distance and
            aspect-ratio penalties from the IoU (Complete-IoU).
        eps: small constant guarding divisions.

    Returns:
        Tensor of the broadcast leading shape. The coordinate axis is removed
        for both coordinate formats.
    """
    if xywh:
        b1_x1 = box1[..., 0] - box1[..., 2] / 2
        b1_x2 = box1[..., 0] + box1[..., 2] / 2
        b1_y1 = box1[..., 1] - box1[..., 3] / 2
        b1_y2 = box1[..., 1] + box1[..., 3] / 2

        b2_x1 = box2[..., 0] - box2[..., 2] / 2
        b2_x2 = box2[..., 0] + box2[..., 2] / 2
        b2_y1 = box2[..., 1] - box2[..., 3] / 2
        b2_y2 = box2[..., 1] + box2[..., 3] / 2
    else:
        # Index rather than chunk, so this branch also drops the coordinate
        # axis and returns the same shape as the xywh branch.
        b1_x1, b1_y1, b1_x2, b1_y2 = (box1[..., k] for k in range(4))
        b2_x1, b2_y1, b2_x2, b2_y2 = (box2[..., k] for k in range(4))

    w1 = b1_x2 - b1_x1
    h1 = b1_y2 - b1_y1
    w2 = b2_x2 - b2_x1
    h2 = b2_y2 - b2_y1

    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
    union = w1 * h1 + w2 * h2 - inter + eps
    iou = inter / union

    if CIoU:
        # Diagonal of the smallest enclosing box, squared.
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
        c2 = cw ** 2 + ch ** 2 + eps
        # Squared distance between box centres (sums of corners are 2 * centre).
        rho2 = ((b1_x1 + b1_x2 - b2_x1 - b2_x2) ** 2 + (b1_y1 + b1_y2 - b2_y1 - b2_y2) ** 2) / 4

        # Aspect-ratio consistency term; alpha is treated as a constant weight.
        v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps)), 2)
        with torch.no_grad():
            alpha = v / (v - iou + 1 + eps)
        return iou - (rho2 / c2 + v * alpha)
    return iou

# ===================== YOLO 损失函数(含类别越界保护) =====================
class YOLOLoss(nn.Module):
    """Anchor-based YOLO loss (box CIoU + objectness BCE + class BCE) over three scales.

    Inputs:
        preds: list of raw head outputs, one per scale, each of shape
            (B, NA * NO, H_i, W_i). Values are logits (no activation applied).
        targets: tensor of shape (B, M, 5) whose rows are ``(cx, cy, w, h, cls)``.
            The box is normalised to [0, 1] by the image size. Rows whose sum
            is below 1e-6 are padding and are ignored. ``cls`` is clamped to
            [0, NC - 1].

    Matching (YOLOv5-style): at every scale a target is assigned to the grid
    cell containing its centre. Within that cell it goes to each anchor whose
    width and height are both within a factor ``anchor_t`` of the target's.
    Centres outside the grid are skipped.

    The decoded prediction and the target are compared in grid-cell units of
    that scale, so both sides of the CIoU use the same coordinate system:
    xy is the offset inside the cell, wh is measured in cells.

    Returns:
        ``(total, box, obj, cls)``. Each component is already multiplied by
        its lambda weight; ``total`` is their sum.
    """

    def __init__(self, anchor_t=4.0):
        super().__init__()
        self.lambda_box = 0.05
        self.lambda_obj = 1.0
        self.lambda_cls = 0.5
        self.anchor_t = anchor_t  # max allowed w/h ratio between a target and an anchor
        self.anchors = ANCHORS    # per-scale anchor (w, h) in input pixels
        self.strides = STRIDES
        self.na = NA
        self.nc = NC
        self.no = NO

    def forward(self, preds, targets):
        device = preds[0].device
        loss_box = torch.zeros((), device=device)
        loss_obj = torch.zeros((), device=device)
        loss_cls = torch.zeros((), device=device)

        # Flatten the non-padding boxes of the whole batch into (N, 5) plus
        # their batch index.
        targets = torch.as_tensor(targets, dtype=torch.float32, device=device)
        b_all, m_all = (targets.sum(-1) >= 1e-6).nonzero(as_tuple=True)
        tgt = targets[b_all, m_all]
        cls_all = tgt[:, 4].clamp(0, self.nc - 1).long()

        for i, pred in enumerate(preds):
            B, _, H, W = pred.shape
            # (B, H, W, na, no), computed in float32 for stable losses.
            pred = pred.float().view(B, self.na, self.no, H, W).permute(0, 3, 4, 1, 2)
            # Anchors expressed in grid cells of this scale.
            anchor_grid = self.anchors[i].to(device).float() / self.strides[i].to(device).float()
            t_obj = torch.zeros(pred.shape[:4], device=device)

            grid_wh = torch.tensor([W, H], dtype=torch.float32, device=device)
            gxy = tgt[:, :2] * grid_wh          # target centre in cells
            gwh = tgt[:, 2:4] * grid_wh         # target size in cells
            gij = gxy.floor().long()            # responsible cell (x, y)
            inside = (gij[:, 0] >= 0) & (gij[:, 0] < W) & (gij[:, 1] >= 0) & (gij[:, 1] < H)

            ratio = gwh[:, None, :] / anchor_grid[None, :, :]
            match = (torch.max(ratio, 1.0 / ratio).max(dim=-1).values < self.anchor_t) & inside[:, None]
            t_idx, a_idx = match.nonzero(as_tuple=True)

            if t_idx.numel() > 0:
                b_idx = b_all[t_idx]
                gx, gy = gij[t_idx, 0], gij[t_idx, 1]
                ps = pred[b_idx, gy, gx, a_idx]  # (N_pos, no) logits of responsible anchors

                # Decode the prediction into the same cell-unit space as the target.
                p_xy = ps[:, :2].sigmoid() * 2.0 - 0.5
                p_wh = (ps[:, 2:4].sigmoid() * 2.0) ** 2 * anchor_grid[a_idx]
                t_xy = gxy[t_idx] - gij[t_idx]
                iou = bbox_iou(torch.cat([p_xy, p_wh], dim=-1),
                               torch.cat([t_xy, gwh[t_idx]], dim=-1), CIoU=True)
                loss_box = loss_box + (1.0 - iou).mean()

                t_obj[b_idx, gy, gx, a_idx] = 1.0

                # Class loss only on positive anchors (one-hot targets).
                t_cls = torch.zeros_like(ps[:, 5:])
                t_cls[torch.arange(t_idx.numel(), device=device), cls_all[t_idx]] = 1.0
                loss_cls = loss_cls + F.binary_cross_entropy_with_logits(ps[:, 5:], t_cls)

            # Objectness over every anchor of every cell, from logits for numerical stability.
            loss_obj = loss_obj + F.binary_cross_entropy_with_logits(
                pred[..., 4], t_obj, reduction='sum') / (B * H * W)

        loss_box = loss_box * self.lambda_box
        loss_obj = loss_obj * self.lambda_obj
        loss_cls = loss_cls * self.lambda_cls
        total_loss = loss_box + loss_obj + loss_cls
        return total_loss, loss_box, loss_obj, loss_cls

# ===================== 数据集 =====================
class MOTVideoDataset(Dataset):
    """Sliding-window clip dataset over MOT16-style sequences.

    Each sample is ``seq_len`` consecutive frames of one sequence, plus the
    boxes of the last frame in the window.

    Expected layout per sequence directory::

        <root>/<seq>/img1/000001.jpg, ...
        <root>/<seq>/gt/new_gt.txt   # comma-separated rows; columns 0, 2..5, 7 are read
                                     # as frame id, left, top, width, height, class

    Boxes are converted from pixel (left, top, w, h) to ``(cx, cy, w, h)``
    normalised by the ORIGINAL frame width/height (read from the first
    annotated frame). This keeps them valid after the frames are resized to
    IMG_SIZE x IMG_SIZE.

    Args:
        root: directory containing the sequence folders.
        seq_len: frames per clip.
        max_seqs: number of sequences to use, in sorted name order
            (``None`` = all).
        max_boxes: number of label rows returned per sample. Extra boxes are
            dropped; fewer are zero-padded.
    """

    def __init__(self, root="MOT16/train", seq_len=SEQ_LEN, max_seqs=1, max_boxes=200):
        self.root = root
        self.seq_len = seq_len
        self.max_seqs = max_seqs
        self.max_boxes = max_boxes
        self.data, self.labels = self._load()

    @staticmethod
    def _frame_size(img_dir, fid):
        """Return (width, height) of the sequence's original frames, read from frame ``fid``."""
        path = os.path.join(img_dir, f"{fid:06d}.jpg")
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"cannot read image: {path}")
        height, width = img.shape[:2]
        return width, height

    @staticmethod
    def _group_boxes(gts, img_w, img_h):
        """Convert annotation rows to normalised (cx, cy, w, h, cls) boxes grouped by frame id."""
        x = gts[:, 2] / img_w
        y = gts[:, 3] / img_h
        w = gts[:, 4] / img_w
        h = gts[:, 5] / img_h
        cls = gts[:, 7].astype(int)
        # If your class ids start at 1, subtract 1 here: cls = cls - 1
        boxes = np.stack([x + w / 2, y + h / 2, w, h, cls], axis=1).astype(np.float32)
        frame_dict = defaultdict(list)
        for fid, box in zip(gts[:, 0].astype(int), boxes):
            frame_dict[int(fid)].append(box)
        return frame_dict

    def _load(self):
        """Index every window of ``seq_len`` consecutive annotated frames.

        Returns:
            ``(samples, label_cache)``. ``samples[k]`` is the list of frame
            paths of clip k. ``label_cache[k]`` is an (N, 5) float32 array of
            boxes for that clip's last frame.
        """
        samples, label_cache = [], []
        seqs = sorted(d for d in os.listdir(self.root) if os.path.isdir(os.path.join(self.root, d)))
        for seq in seqs[:self.max_seqs]:
            img_dir = os.path.join(self.root, seq, "img1")
            gt_path = os.path.join(self.root, seq, "gt", "new_gt.txt")
            if not os.path.exists(gt_path):
                continue
            # ndmin=2 keeps a single-row annotation file two-dimensional.
            gts = np.loadtxt(gt_path, delimiter=",", ndmin=2)
            if gts.size == 0:
                continue
            img_w, img_h = self._frame_size(img_dir, int(gts[:, 0].min()))
            frame_dict = self._group_boxes(gts, img_w, img_h)

            fids = sorted(frame_dict)
            # len - seq_len + 1 windows, so the final window is included.
            for i in range(len(fids) - self.seq_len + 1):
                window = fids[i:i + self.seq_len]
                # Frames without annotations are missing from fids; skip
                # windows that would silently jump over them in time.
                if window[-1] - window[0] != self.seq_len - 1:
                    continue
                samples.append([os.path.join(img_dir, f"{f:06d}.jpg") for f in window])
                label_cache.append(np.asarray(frame_dict[window[-1]], dtype=np.float32).reshape(-1, 5))
        return samples, label_cache

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(frames, labels)``.

        ``frames`` is a float32 tensor (seq_len, 3, IMG_SIZE, IMG_SIZE) in
        [0, 1] with cv2's BGR channel order. ``labels`` is a
        (max_boxes, 5) float32 tensor, zero-padded.
        """
        frames = []
        for path in self.data[idx]:
            img = cv2.imread(path)
            if img is None:
                raise FileNotFoundError(f"cannot read image: {path}")
            img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
            frames.append(img.transpose(2, 0, 1))
        img_seq = np.stack(frames).astype(np.float32) / 255.0

        label = self.labels[idx]
        pad = np.zeros((self.max_boxes, 5), dtype=np.float32)
        n = min(len(label), self.max_boxes)
        pad[:n] = label[:n]
        return torch.from_numpy(img_seq), torch.from_numpy(pad)

# ===================== 训练入口 =====================
if __name__ == "__main__":
    # Let cuDNN benchmark and cache the fastest conv algorithms (input size is fixed).
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    model = YOLOv8n_ConvLSTM().to(DEVICE)
    criterion = YOLOLoss().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    dataset = MOTVideoDataset()
    # Fail fast with an actionable message. Otherwise the shuffling sampler
    # raises an opaque "num_samples=0" error when no annotations were found.
    if len(dataset) == 0:
        raise RuntimeError(
            "MOTVideoDataset is empty: expected <root>/<seq>/gt/new_gt.txt and "
            "<root>/<seq>/img1/*.jpg under the dataset root (default MOT16/train)"
        )
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )

    print(f"Start Training | Seq_len:{SEQ_LEN}, BatchSize:{BATCH_SIZE}, LR:{LR:.6f}, Classes:{NC}")
    for epoch in range(30):
        model.train()
        total_loss = 0.0
        part_sums = [0.0, 0.0, 0.0]  # running box / obj / cls components
        for imgs, labels in dataloader:
            imgs = imgs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            preds = model(imgs)
            loss, l_box, l_obj, l_cls = criterion(preds, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            for k, part in enumerate((l_box, l_obj, l_cls)):
                part_sums[k] += float(part)

        n_batches = len(dataloader)
        avg_loss = total_loss / n_batches
        print(f"Epoch {epoch:02d} | Avg Loss: {avg_loss:.4f}")
        print(f"         box {part_sums[0] / n_batches:.4f} | "
              f"obj {part_sums[1] / n_batches:.4f} | cls {part_sums[2] / n_batches:.4f}")

    torch.save(model.state_dict(), "yolov8_convlstm_bs8_final.pth")
    print("Training Done! Model saved -> yolov8_convlstm_bs8_final.pth")
相关推荐
CV-deeplearning18 小时前
YOLO26 正式发布!6 大任务一战封神,n 模型 mAP 40.9 跑 1.7ms,从检测到分割到姿态一条龙
yolo·目标检测·计算机视觉·ultralytics·yolo26
stsdddd19 小时前
YOLO系列目标检测数据集大全【第十五期】
yolo·目标检测·目标跟踪
stsdddd1 天前
YOLO系列目标检测数据集大全【第十六期】
yolo·目标检测·目标跟踪
hans汉斯1 天前
【人工智能与机器人研究】基于分层控制的多智能体编队协同控制
网络·人工智能·学习·yolo·机器人
动物园猫1 天前
无人机植物病害目标检测数据集分享(适用于YOLO系列深度学习分类检测任务)
yolo·目标检测·无人机
YOLO数据集集合1 天前
无人机航拍光伏板状态识别数据集 | 太阳能板异常检测、智能巡检、深度学习模型训练素材第10340期
人工智能·深度学习·yolo·目标检测·无人机
探物 AI2 天前
把 MambaOut 塞进 YOLOv11:会有什么样的反应
python·yolo·计算机视觉
快乐得小萝卜2 天前
部署:YOLO V11 TensorRT 推理&前后处理
yolo
断眉的派大星2 天前
YOLO26 完整学习笔记:从 Anchor-Free、TAL、STAL 到端到端无 NMS 部署
人工智能·笔记·学习·yolo·目标检测·计算机视觉·目标跟踪