基于MOT16数据集做目标检测的预处理(类别合并与清理)-CSDN博客
使用修改后的标签进行训练,代码如下,效果待验证。
python
import os
import cv2
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
# ===================== Global configuration =====================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
IMG_SIZE = 640  # side length every frame is resized to
SEQ_LEN = 5  # number of frames per clip
NC = 3  # number of valid classes (ids 0, 1, 2)
NA = 3  # anchors per grid cell
NO = 5 + NC  # outputs per anchor: 5 (xywh + conf) + class scores
BATCH_SIZE = 8
# Learning rate scaled linearly with the batch size (reference batch size 2).
LR = 1e-3 * (BATCH_SIZE / 2)
# Detection thresholds; not referenced by the code visible in this file.
CONF_THR = 0.25
IOU_THR = 0.45
# Anchor boxes (width, height) and strides, one entry per output scale.
# NOTE(review): the original comment called these "YOLOv8n anchors"; the values
# look like the widely used YOLOv5 COCO anchors in pixels -- confirm the source.
ANCHORS = torch.tensor(
    [
        [[10, 13], [16, 30], [33, 23]],
        [[30, 61], [62, 45], [59, 119]],
        [[116, 90], [156, 198], [373, 326]],
    ],
    device=DEVICE,
)
STRIDES = torch.tensor([8, 16, 32], device=DEVICE)
# ===================== YOLOv8 基础模块 =====================
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> activation building block.

    Args:
        c1: number of input channels.
        c2: number of output channels.
        k: kernel size.
        s: stride.
        p: padding. ``None`` selects ``k // 2``; any explicit value,
            including ``0``, is used as given.
        g: number of convolution groups.
        act: apply SiLU when truthy, otherwise the identity.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        # Compare against None rather than truthiness so that an explicit
        # p=0 (no padding) is not silently replaced by k // 2.
        padding = k // 2 if p is None else p
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        """Return ``act(bn(conv(x)))``."""
        return self.act(self.bn(self.conv(x)))
class Bottleneck(nn.Module):
    """A 1x1 Conv followed by a 3x3 Conv, with an optional residual add.

    Args:
        c1: input channels.
        c2: output channels.
        shortcut: request the residual connection; it is only used when
            ``c1 == c2``.
        g: groups of the 3x3 convolution.
        e: ratio of hidden channels to ``c2``.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Apply both convolutions, adding the input back when enabled."""
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
class C2f(nn.Module):
    """Split-and-concatenate block built from ``n`` chained Bottlenecks.

    Args:
        c1: input channels.
        c2: output channels.
        n: number of Bottleneck modules.
        shortcut: forwarded to every Bottleneck.
        g: forwarded to every Bottleneck.
        e: ratio used to derive the per-branch width ``self.c``.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__()
        self.c = int(c2 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        blocks = [Bottleneck(self.c, self.c, shortcut, g) for _ in range(n)]
        self.m = nn.ModuleList(blocks)

    def forward(self, x):
        """Concatenate both halves of ``cv1(x)`` with every Bottleneck output."""
        feats = list(self.cv1(x).chunk(2, 1))
        # Each Bottleneck consumes the most recently produced feature map.
        for block in self.m:
            feats.append(block(feats[-1]))
        return self.cv2(torch.cat(feats, 1))
class SPPF(nn.Module):
    """Pooling pyramid: one max-pool applied three times in a row.

    Args:
        c1: input channels.
        c2: output channels.
        k: max-pool kernel size (stride 1, padding ``k // 2``).
    """

    def __init__(self, c1, c2, k=5):
        super().__init__()
        hidden = c1 // 2
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        """Fuse ``cv1(x)`` with its one-, two- and three-times pooled versions."""
        pyramid = [self.cv1(x)]
        for _ in range(3):
            pyramid.append(self.m(pyramid[-1]))
        return self.cv2(torch.cat(pyramid, dim=1))
# ===================== ConvLSTM 模块 =====================
class ConvLSTMCell(nn.Module):
    """One convolutional LSTM step.

    A single convolution over ``cat(x, h)`` yields the four gate
    pre-activations, split along the channel axis in the order
    input, forget, output, candidate.

    Args:
        in_channels: channels of the input ``x``.
        hidden_channels: channels of the hidden and cell states.
        k: kernel size of the gate convolution.
        p: padding of the gate convolution.
    """

    def __init__(self, in_channels, hidden_channels, k=3, p=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels + hidden_channels, 4 * hidden_channels, k, 1, p
        )

    def forward(self, x, h, c):
        """Return the next ``(h, c)`` given input ``x`` and the previous state."""
        gates = self.conv(torch.cat([x, h], dim=1))
        in_gate, forget_gate, out_gate, candidate = torch.chunk(gates, 4, dim=1)
        c_next = torch.sigmoid(forget_gate) * c + torch.sigmoid(in_gate) * torch.tanh(candidate)
        h_next = torch.sigmoid(out_gate) * torch.tanh(c_next)
        return h_next, c_next
class ConvLSTM(nn.Module):
    """Run one shared ConvLSTMCell over a sequence of feature maps.

    Args:
        channels: channel count of both the input and the hidden state.
    """

    def __init__(self, channels):
        super().__init__()
        self.cell = ConvLSTMCell(channels, channels)

    def forward(self, x_seq):
        """Return the hidden state after every time step.

        Args:
            x_seq: tensor of shape (B, T, C, H, W).

        Returns:
            Tensor of shape (B, T, C, H, W) holding ``h`` for each step.
        """
        B, T, C, H, W = x_seq.shape
        # new_zeros inherits both device and dtype from the input, so the
        # initial state also matches half/double inputs (plain torch.zeros
        # would always create float32 states).
        h = x_seq.new_zeros((B, C, H, W))
        c = x_seq.new_zeros((B, C, H, W))
        outputs = []
        for t in range(T):
            h, c = self.cell(x_seq[:, t], h, c)
            outputs.append(h)
        return torch.stack(outputs, dim=1)
# ===================== 主干网络 YOLOv8n + ConvLSTM =====================
class YOLOv8n_ConvLSTM(nn.Module):
    """Detector with a convolutional backbone, per-scale ConvLSTMs, a
    top-down/bottom-up neck and three 1x1 convolution heads.

    The backbone runs on every frame of a clip. For each of the three
    scales the per-frame features are passed through a ConvLSTM, and only
    the hidden state of the last time step is fed to the neck.
    """

    def __init__(self):
        super().__init__()
        self.na = NA
        self.nc = NC
        self.no = NO
        # Backbone. The spatial sizes in the comments assume a 640x640 input.
        self.backbone = nn.Sequential(
            Conv(3, 16, 3, 2),
            Conv(16, 32, 3, 2),
            C2f(32, 32, n=1),
            Conv(32, 64, 3, 2),     # stride 8, 64 ch (80x80)
            C2f(64, 64, n=2),
            Conv(64, 128, 3, 2),    # stride 16, 128 ch (40x40)
            C2f(128, 128, n=2),
            Conv(128, 256, 3, 2),   # stride 32, 256 ch (20x20)
            C2f(256, 256, n=1),
            SPPF(256, 256)
        )
        # One ConvLSTM per scale, applied to the P3/P4/P5 features.
        self.lstm_p3 = ConvLSTM(64)
        self.lstm_p4 = ConvLSTM(128)
        self.lstm_p5 = ConvLSTM(256)
        # Neck
        self.neck = nn.ModuleList([
            C2f(256 + 128, 128, n=1, shortcut=False),
            C2f(128 + 64, 64, n=1, shortcut=False),
            Conv(64, 64, 3, 2),
            C2f(64 + 128, 128, n=1, shortcut=False),
            Conv(128, 128, 3, 2),
            C2f(128 + 256, 256, n=1, shortcut=False),
        ])
        # Head: one 1x1 convolution per scale with na * no output channels.
        self.head = nn.ModuleList([
            nn.Conv2d(64, self.na * self.no, 1),
            nn.Conv2d(128, self.na * self.no, 1),
            nn.Conv2d(256, self.na * self.no, 1)
        ])

    def extract_p3_p4_p5(self, x):
        """Run the backbone on one batch of frames and return three scales.

        The first two features are taken right after the stride-8 and
        stride-16 Conv layers (backbone indices 3 and 5), i.e. before the
        C2f block that follows each of them; the third is the SPPF output.
        """
        x1 = self.backbone[:4](x)
        x2 = self.backbone[4:6](x1)
        x3 = self.backbone[6:](x2)
        return x1, x2, x3

    def forward(self, imgs_seq):
        """Predict raw head outputs for the last frame of each clip.

        Args:
            imgs_seq: tensor of shape (B, T, 3, H, W).

        Returns:
            List ``[out1, out2, out3]`` of raw head outputs, ordered from
            the highest-resolution scale to the lowest, each with
            ``na * no`` channels.
        """
        _, T, _, _, _ = imgs_seq.shape
        p3_list, p4_list, p5_list = [], [], []
        for t in range(T):
            p3, p4, p5 = self.extract_p3_p4_p5(imgs_seq[:, t])
            p3_list.append(p3)
            p4_list.append(p4)
            p5_list.append(p5)
        # Temporal fusion: keep only the hidden state of the last time step.
        p3 = self.lstm_p3(torch.stack(p3_list, dim=1))[:, -1]
        p4 = self.lstm_p4(torch.stack(p4_list, dim=1))[:, -1]
        p5 = self.lstm_p5(torch.stack(p5_list, dim=1))[:, -1]
        # Top-down path.
        td_p4 = self.neck[0](torch.cat([F.interpolate(p5, scale_factor=2), p4], dim=1))
        td_p3 = self.neck[1](torch.cat([F.interpolate(td_p4, scale_factor=2), p3], dim=1))
        out1 = self.head[0](td_p3)
        # Bottom-up path. td_p4 is reused here rather than evaluating
        # neck[0] a second time on the same input: the values and gradients
        # are the same, one C2f forward is saved, and neck[0]'s BatchNorm
        # running statistics are updated once per step instead of twice.
        bu_p4 = self.neck[3](torch.cat([self.neck[2](td_p3), td_p4], dim=1))
        out2 = self.head[1](bu_p4)
        bu_p5 = self.neck[5](torch.cat([self.neck[4](bu_p4), p5], dim=1))
        out3 = self.head[2](bu_p5)
        return [out1, out2, out3]
# ===================== CIoU 计算 =====================
def bbox_iou(box1, box2, xywh=True, CIoU=True, eps=1e-7):
    """Element-wise IoU (or Complete-IoU) between two sets of boxes.

    Args:
        box1, box2: broadcastable tensors whose last axis holds the box.
        xywh: when True the last axis is (cx, cy, w, h) and the result has
            no box axis; when False it is (x1, y1, x2, y2) and, because the
            corners come from ``chunk``, the result keeps a trailing axis
            of size 1.
        CIoU: return Complete-IoU instead of plain IoU.
        eps: small constant guarding the divisions.

    Returns:
        Tensor of IoU or CIoU values.
    """
    if xywh:
        # Centre/size -> corner coordinates.
        half_w1, half_h1 = box1[..., 2] / 2, box1[..., 3] / 2
        half_w2, half_h2 = box2[..., 2] / 2, box2[..., 3] / 2
        b1_x1, b1_x2 = box1[..., 0] - half_w1, box1[..., 0] + half_w1
        b1_y1, b1_y2 = box1[..., 1] - half_h1, box1[..., 1] + half_h1
        b2_x1, b2_x2 = box2[..., 0] - half_w2, box2[..., 0] + half_w2
        b2_y1, b2_y2 = box2[..., 1] - half_h2, box2[..., 1] + half_h2
    else:
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)

    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1

    # Intersection extents are clamped at zero for disjoint boxes.
    overlap_w = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0)
    overlap_h = (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
    inter = overlap_w * overlap_h
    union = w1 * h1 + w2 * h2 - inter + eps
    iou = inter / union
    if not CIoU:
        return iou

    # Squared diagonal of the smallest box enclosing both inputs.
    enclose_w = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
    enclose_h = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
    diag2 = enclose_w**2 + enclose_h**2 + eps
    # Squared distance between the two box centres.
    centre_dist2 = ((b1_x1 + b1_x2 - b2_x1 - b2_x2)**2 + (b1_y1 + b1_y2 - b2_y1 - b2_y2)**2) / 4
    # Aspect-ratio consistency term.
    v = (4 / math.pi**2) * torch.pow(torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps)), 2)
    with torch.no_grad():
        # The trade-off weight is computed without tracking gradients.
        alpha = v / (v - iou + 1 + eps)
    return iou - (centre_dist2 / diag2 + v * alpha)
# ===================== YOLO 损失函数(含类别越界保护) =====================
class YOLOLoss(nn.Module):
    """Anchor-based detection loss (box CIoU + objectness BCE + class BCE).

    Inputs to ``forward``:
        preds: list of raw head outputs, one per scale, each of shape
            (B, na * no, H, W).
        targets: indexable by batch index; each row is
            (cx, cy, w, h, cls) with coordinates normalised to [0, 1] of
            the network input. Rows whose elements sum to less than 1e-6
            are treated as padding. Class ids are clamped to
            [0, nc - 1].

    Predicted boxes are decoded into the same normalised image
    coordinates as the targets before CIoU is computed.
    """

    def __init__(self):
        super().__init__()
        self.lambda_box = 0.05
        self.lambda_obj = 1.0
        self.lambda_cls = 0.5
        self.anchors = ANCHORS
        self.strides = STRIDES
        self.na = NA
        self.nc = NC
        self.no = NO

    def _decode_boxes(self, pred, level):
        """Decode raw outputs (B, H, W, na, no) into normalised (cx, cy, w, h)."""
        _, H, W, _, _ = pred.shape
        device, dtype = pred.device, pred.dtype
        grid_x = torch.arange(W, device=device, dtype=dtype).view(1, 1, W, 1)
        grid_y = torch.arange(H, device=device, dtype=dtype).view(1, H, 1, 1)
        # Offset of the box centre relative to its grid cell, in (-0.5, 1.5).
        offset = torch.sigmoid(pred[..., :2]) * 2.0 - 0.5
        cx = (offset[..., 0] + grid_x) / W
        cy = (offset[..., 1] + grid_y) / H
        # Anchors are given in input pixels; W * stride and H * stride are
        # the input width/height, so the ratio is the normalised anchor size.
        stride = self.strides[level].to(device=device, dtype=dtype)
        anchors = self.anchors[level].to(device=device, dtype=dtype)
        input_wh = torch.stack([W * stride, H * stride])
        wh = (torch.sigmoid(pred[..., 2:4]) * 2.0) ** 2 * (anchors / input_wh)
        return torch.cat([cx.unsqueeze(-1), cy.unsqueeze(-1), wh], dim=-1)

    def _build_targets(self, targets, B, H, W, device, dtype):
        """Rasterise the label rows onto a (B, H, W, na) grid."""
        t_obj = torch.zeros((B, H, W, self.na), device=device, dtype=dtype)
        t_cls = torch.zeros((B, H, W, self.na, self.nc), device=device, dtype=dtype)
        t_box = torch.zeros((B, H, W, self.na, 4), device=device, dtype=dtype)
        for b in range(B):
            for box in targets[b]:
                if box.sum() < 1e-6:
                    continue  # zero padding row
                # Out-of-range class ids are clamped instead of raising.
                cls_id = int(torch.clamp(box[4], 0, self.nc - 1))
                gx = int(box[0] * W)
                gy = int(box[1] * H)
                if 0 <= gx < W and 0 <= gy < H:
                    # Every anchor of the responsible cell is a positive.
                    t_obj[b, gy, gx, :] = 1.0
                    t_cls[b, gy, gx, :, cls_id] = 1.0
                    t_box[b, gy, gx, :] = box[:4].to(device=device, dtype=dtype)
        return t_obj, t_cls, t_box

    def forward(self, preds, targets):
        """Return ``(total_loss, loss_box, loss_obj, loss_cls)``.

        Each component is summed over the scales and already multiplied by
        its lambda weight.
        """
        loss_box = loss_obj = loss_cls = 0.0
        for i, pred in enumerate(preds):
            B, _, H, W = pred.shape
            pred = pred.view(B, self.na, self.no, H, W).permute(0, 3, 4, 1, 2)
            pred_box = self._decode_boxes(pred, i)
            t_obj, t_cls, t_box = self._build_targets(
                targets, B, H, W, pred.device, pred.dtype
            )
            # Box loss is averaged over positive anchors only.
            iou = bbox_iou(pred_box, t_box, CIoU=True)
            loss_box += ((1.0 - iou) * t_obj).sum() / (t_obj.sum() + 1e-6)
            # BCE on the raw logits: the same quantity as BCE on the
            # sigmoid outputs, but numerically stable and autocast-safe.
            loss_obj += F.binary_cross_entropy_with_logits(
                pred[..., 4], t_obj, reduction='sum') / (B * H * W)
            loss_cls += F.binary_cross_entropy_with_logits(
                pred[..., 5:], t_cls, reduction='sum') / (B * H * W)
        loss_box *= self.lambda_box
        loss_obj *= self.lambda_obj
        loss_cls *= self.lambda_cls
        total_loss = loss_box + loss_obj + loss_cls
        return total_loss, loss_box, loss_obj, loss_cls
# ===================== 数据集 =====================
class MOTVideoDataset(Dataset):
    """Clips of consecutive labelled frames from MOT-style sequences.

    Each sample is ``seq_len`` frames plus the boxes of the LAST frame of
    the clip, as ``(cx, cy, w, h, cls)`` rows normalised to [0, 1] and
    zero-padded to ``max_boxes`` rows.

    Args:
        root: directory containing one sub-directory per sequence, each
            with ``img1/`` frames and ``gt/new_gt.txt`` labels.
        seq_len: number of frames per clip.
        max_boxes: label rows per sample; extra boxes are dropped.
        max_seqs: number of sequences (in sorted name order) to load;
            ``None`` loads all of them.
    """

    def __init__(self, root="MOT16/train", seq_len=SEQ_LEN, max_boxes=20, max_seqs=1):
        self.root = root
        self.seq_len = seq_len
        self.max_boxes = max_boxes
        self.max_seqs = max_seqs
        self.data, self.labels = self._load()

    @staticmethod
    def _frame_size(img_dir, fid):
        """Return ``(width, height)`` of frame ``fid`` in ``img_dir``."""
        path = os.path.join(img_dir, f"{fid:06d}.jpg")
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"cannot read image: {path}")
        height, width = img.shape[:2]
        return width, height

    def _load(self):
        """Build the list of clips (image paths) and their label arrays."""
        samples = []
        label_cache = []
        # Sorted so that the selected sequences do not depend on the
        # arbitrary order returned by os.listdir.
        seqs = sorted(
            d for d in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, d))
        )
        for seq in seqs[:self.max_seqs]:
            img_dir = os.path.join(self.root, seq, "img1")
            gt_path = os.path.join(self.root, seq, "gt", "new_gt.txt")
            if not os.path.exists(gt_path):
                continue
            # ndmin=2 keeps a single-row file two-dimensional.
            gts = np.loadtxt(gt_path, delimiter=",", ndmin=2)
            if gts.size == 0:
                continue
            # Boxes are stored in source-image pixels, so they must be
            # normalised by the source resolution, not by IMG_SIZE.
            # Assumes every frame of a sequence has the same resolution.
            img_w, img_h = self._frame_size(img_dir, int(gts[0][0]))
            frame_dict = defaultdict(list)
            for row in gts:
                fid = int(row[0])
                x = row[2] / img_w
                y = row[3] / img_h
                w = row[4] / img_w
                h = row[5] / img_h
                cls = int(row[7])
                # If class ids in the label file start at 1, shift them
                # to start at 0 here (cls - 1).
                frame_dict[fid].append([x + w / 2, y + h / 2, w, h, cls])
            # Only frames that have at least one label row appear here, so
            # a clip may skip over unlabelled frames.
            fids = sorted(frame_dict.keys())
            # "+ 1" so that the window ending on the last frame is included.
            for start in range(len(fids) - self.seq_len + 1):
                window = fids[start:start + self.seq_len]
                samples.append(
                    [os.path.join(img_dir, f"{f:06d}.jpg") for f in window]
                )
                label_cache.append(np.array(frame_dict[window[-1]]))
        return samples, label_cache

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(frames, labels)`` for clip ``idx``.

        ``frames`` is a float32 tensor (seq_len, 3, IMG_SIZE, IMG_SIZE)
        scaled to [0, 1]; ``labels`` is a float32 tensor (max_boxes, 5).
        """
        frames = []
        for path in self.data[idx]:
            img = cv2.imread(path)
            if img is None:
                raise FileNotFoundError(f"cannot read image: {path}")
            img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
            frames.append(img.transpose(2, 0, 1) / 255.0)
        img_seq = np.array(frames, dtype=np.float32)
        label = self.labels[idx]
        pad = np.zeros((self.max_boxes, 5), dtype=np.float32)
        n = min(len(label), self.max_boxes)
        pad[:n] = label[:n]
        return torch.from_numpy(img_seq), torch.from_numpy(pad)
# ===================== 训练入口 =====================
if __name__ == "__main__":
    # Let cuDNN benchmark convolution algorithms when a GPU is in use.
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    model = YOLOv8n_ConvLSTM().to(DEVICE)
    criterion = YOLOLoss().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    dataset = MOTVideoDataset()
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    print(f"Start Training | Seq_len:{SEQ_LEN}, BatchSize:{BATCH_SIZE}, LR:{LR:.6f}, Classes:{NC}")

    for epoch in range(30):
        model.train()
        running_loss = 0.0
        for imgs, labels in dataloader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # The criterion also returns the three weighted components;
            # only the total is used for optimisation and logging.
            loss, l_box, l_obj, l_cls = criterion(model(imgs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Mean of the per-batch losses over the epoch.
        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch:02d} | Avg Loss: {avg_loss:.4f}")

    # Only the final weights are saved, once all epochs have finished.
    torch.save(model.state_dict(), "yolov8_convlstm_bs8_final.pth")
    print("Training Done! Model saved -> yolov8_convlstm_bs8_final.pth")