osnet模型和yolo模型的微调:冻结训练

1.osnet

1.1 源码

model_zoo:https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html

osnet源码:https://github.com/KaiyangZhou/deep-person-reid

1.2 分为三步冻结部分层进行微调:

1.stage1:使用torchreid的默认配置,冻结全部网络,只解冻训练classifier + BNNeck。

2.stage2:使用torchreid的默认配置,冻结全部网络,只解冻 classifier + BNNeck,解冻高层 backbone(conv5)

3.stage3:使用torchreid的默认配置,解冻所有层

1.3 自动化三阶段OSNet模型训练脚本 train_osnet_3stage_auto.py

将以下代码放至./model/deep-person-reid下的train_osnet_3stage_auto.py

复制代码
"""自动化三阶段OSNet模型训练脚本 train_osnet_3stage_auto.py
此脚本会自动依次执行Stage 1 -> Stage 2 -> Stage 3的训练,无需手动干预
"""

import torch
import torchreid
from torchreid.utils import set_random_seed
import os
import glob

# ===============================
# 0. 基础配置
# ===============================
set_random_seed(42)

# 数据集配置
DATASET_NAME = 'ten_object'  # 数据集 pubfigface
DATASET_ROOT = '/ReID/datasets/'  # 数据集根路径,数据集路径在以下注册,所以这里只需要根路径,别搞错了
BASE_SAVE_DIR = './log/osnet_ten_object_auto1'  # 基础保存目录

MODEL_NAME = 'osnet_x1_0' # 选用预训练模型,一般不用改
IMG_HEIGHT = 256
IMG_WIDTH = 128
BATCH_SIZE = 64

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# 初始化数据管理器
print("正在初始化数据管理器...")
datamanager = torchreid.data.ImageDataManager(
    root=DATASET_ROOT,
    sources=DATASET_NAME,
    targets=DATASET_NAME,
    height=IMG_HEIGHT,
    width=IMG_WIDTH,
    batch_size_train=BATCH_SIZE,
    batch_size_test=BATCH_SIZE,
    transforms=['random_flip', 'random_crop'],
    num_instances=4
)

def find_latest_model(stage_num, base_dir):
    """查找指定阶段的最新模型文件"""
    # 实际保存路径就是 base_dir/model/,不需要再加 stage 后缀
    model_dir = os.path.join(base_dir, "model")
    pattern = os.path.join(model_dir, "model.pth.tar-*")
    model_files = glob.glob(pattern)
    
    if not model_files:
        print(f"在 {model_dir} 中未找到模型文件")
        return None
    
    # 按epoch数字排序,返回最大epoch的模型
    def extract_epoch(path):
        try:
            # 提取文件名中的epoch数字
            basename = os.path.basename(path)
            # 从文件名中提取epoch数字,例如从 "model.pth.tar-40" 提取 40
            parts = basename.split('-')
            if len(parts) > 0:
                epoch_part = parts[-1].replace('.pth.tar', '')
                epoch = int(epoch_part)
                return epoch
        except:
            return 0
    
    latest_model = max(model_files, key=extract_epoch)
    return latest_model

def get_stage_config(stage_num, prev_model_path=None):
    """获取指定阶段的配置"""
    epochs_map = {1: 40, 2: 80, 3: 40}
    
    config = {
        'save_dir': f'{BASE_SAVE_DIR}_stage{stage_num}',
        'epochs': epochs_map.get(stage_num, 40),
        'lr_backbone': 0.0,
        'lr_head': 1e-3
    }
    
    if stage_num == 1:
        config['lr_backbone'] = 0.0
        config['lr_head'] = 1e-3
    elif stage_num == 2:
        config['lr_backbone'] = 1e-5
        config['lr_head'] = 1e-4
    elif stage_num == 3:
        config['lr_backbone'] = 3e-4
        config['lr_head'] = 3e-4
    
    return config

def build_and_load_model(stage_num, prev_model_path=None):
    """构建模型并加载相应阶段的权重"""
    print(f"构建阶段 {stage_num} 的模型...")
    
    model = torchreid.models.build_model(
        name=MODEL_NAME,
        num_classes=datamanager.num_train_pids,
        loss='softmax',
        pretrained=False  # 关闭自动下载
    )
    model = model.to(DEVICE)

    # 确定要加载的权重
    if stage_num == 1:
        # 第一阶段:加载ImageNet预训练权重
        local_weights = '../osnet_x1_0_imagenet.pth'
        print(f"加载基础预训练权重: {local_weights}")
    else:
        # 后续阶段:加载前一阶段的模型权重
        if prev_model_path and os.path.exists(prev_model_path):
            local_weights = prev_model_path
            print(f"加载前一阶段模型权重: {local_weights}")
        else:
            raise FileNotFoundError(f"找不到第 {stage_num-1} 阶段的模型文件: {prev_model_path}")
    
    # 加载权重
    state_dict = torch.load(local_weights, map_location=DEVICE, weights_only=False)
    
    # 兼容 torchreid / checkpoint 两种格式
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    
    # 删除 classifier 权重(对于第1阶段),避免 size mismatch
    if stage_num == 1:
        for key in list(state_dict.keys()):
            if key.startswith('classifier'):
                del state_dict[key]
        print("移除了预训练模型中的分类头权重")
    
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print('加载模型权重成功')
    print('缺失的键:', missing)
    print('意外的键:', unexpected)
    
    return model

def setup_training_parameters(model, stage_num):
    """设置训练参数(冻结/解冻策略)"""
    print(f"设置阶段 {stage_num} 的训练参数...")
    
    # 先全部冻结
    for param in model.parameters():
        param.requires_grad = False

    # Stage 1: 只训练 classifier + BNNeck
    if stage_num == 1:
        for name, param in model.named_parameters():
            if 'classifier' in name or 'bn' in name:
                param.requires_grad = True

        # 必须重新初始化分类头
        model.classifier.reset_parameters()
        lr_backbone = 0.0
        lr_head = 1e-3

    # Stage 2: 解冻 classifier + BNNeck + 高层 backbone(conv5)
    elif stage_num == 2:
        for name, param in model.named_parameters():
            if 'classifier' in name or 'bn' in name:
                param.requires_grad = True

        # 解冻高层 backbone(conv5)
        for name, param in model.named_parameters():
            if 'conv5' in name:
                param.requires_grad = True

        lr_backbone = 1e-5
        lr_head = 1e-4

    # Stage 3: 全模型训练
    elif stage_num == 3:
        for param in model.parameters():
            param.requires_grad = True

        lr_backbone = 3e-4
        lr_head = 3e-4

    else:
        raise ValueError("STAGE must be 1, 2, or 3")

    # 设置优化器参数
    params = list(model.named_parameters())
    backbone_params = [p for n, p in params if 'classifier' not in n]
    head_params = [p for n, p in params if 'classifier' in n]

    optimizer = torch.optim.Adam(
        [
            {'params': backbone_params, 'lr': lr_backbone},
            {'params': head_params, 'lr': lr_head},
        ],
        weight_decay=5e-4
    )

    # 学习率调度
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=20, gamma=0.1
    )

    return optimizer, scheduler

def train_stage(stage_num, prev_model_path=None):
    """训练单个阶段"""
    print(f"\n{'='*50}")
    print(f"开始第 {stage_num} 阶段训练")
    print(f"{'='*50}")
    
    # 构建模型并加载权重
    model = build_and_load_model(stage_num, prev_model_path)
    
    # 设置训练参数
    optimizer, scheduler = setup_training_parameters(model, stage_num)
    
    # 获取阶段配置
    config = get_stage_config(stage_num)
    
    # 创建Trainer
    engine = torchreid.engine.ImageSoftmaxEngine(
        datamanager,
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        label_smooth=True
    )

    # 开始训练
    print(f"开始训练,保存到: {config['save_dir']}")
    engine.run(
        save_dir=config['save_dir'],
        max_epoch=config['epochs'],
        eval_freq=5,
        print_freq=20
    )
    
    # 等待一段时间让文件系统完成写入
    import time
    print("等待模型文件写入完成...")
    time.sleep(5)  # 等待5秒让文件系统完成写入
    
    # 查找并返回此阶段训练完成的模型路径
    latest_model = find_latest_model(stage_num, config['save_dir'])
    if latest_model:
        print(f"阶段 {stage_num} 训练完成,最新模型: {latest_model}")
    else:
        print(f"警告: 未找到阶段 {stage_num} 的模型文件")
        # 再次尝试查找,以防文件还在写入过程中
        time.sleep(10)
        latest_model = find_latest_model(stage_num, config['save_dir'])
        if latest_model:
            print(f"阶段 {stage_num} 训练完成,最新模型: {latest_model}")
        else:
            print(f"再次检查后仍未找到阶段 {stage_num} 的模型文件")
        
    return latest_model

def main():
    """主函数:执行三阶段训练"""
    print("开始自动化三阶段训练流程")
    print("="*60)
    
    prev_model_path = None
    
    # 依次执行三个阶段
    for stage in range(1, 4):
        prev_model_path = train_stage(stage, prev_model_path)
        
        if not prev_model_path:
            print(f"错误: 阶段 {stage} 训练失败,无法继续后续阶段")
            return
    
    print("\n" + "="*60)
    print("所有三个阶段训练完成!")
    print(f"最终模型保存在: {BASE_SAVE_DIR}_stage3")
    print("训练流程结束")

if __name__ == "__main__":
    main()

1.4 如若自制数据集,格式如下

命名规则:

如:0001_c001_00016450_0.jpg,第一个下划线前是id号0001,第二个下划线前是相机号c0001,然后后面是图片名字,最后是扩展名.jpg

ID数应大于等于50,否则效果可能会下降

格式要求:

每个目标至少需要4张图片,4张图片需要两个不同的相机号

train中放训练集,包含cam1和cam2的图片

query是目标集,用来验证用的,gallery是搜索集,也就是使用query中的图片在gallery中进行搜索

注意!!!:query中的id和gallery中的id必须一样(不然搜不到),cam号必须不一样(算法要求),且必须与train中的id不一样(不然即参与训练又参与验证会导致map升高)

复制代码
# 示例数据集格式:
aircraft_reid/
├── train/
│   ├── 0001_c001_00016450.jpg
│   ├── 0001_c002_00016915.jpg
│   ├── 0001_c002_00014680.jpg
|
├── query/
│   ├── 0002_c001_00030600.jpg
│       
└── gallery/
    ├── 0002_c002_00030600.jpg
    │   
    ├── 0005_c002_00075750.jpg

1.5 数据集划分

可使用如下脚本进行yolo格式的数据集的划分,划分前数据集格式和分类数据集一致,每个文件夹为一类物体的所有图片,有若干文件夹

更改其中的DATA_YAML,OUTPUT_DIR,PREFIX

复制代码
"""
根据YOLO格式的标签,从图片中裁剪出目标,并按照指定前缀和类别名保存到对应文件夹中。
"""
import os
import yaml
import cv2

# ===================== 配置 =====================

DATA_YAML = "classes3.yaml" # 数据集的yaml文件路径,yaml文件为yolo格式如下
"""
path: /data/VOCdevkit/ # 数据集目录
train:
  - train/images
val:
  - valid/images
test:
  - test/images
# Classes
names:
  0: class0
  1: class1
  2: class2

"""
OUTPUT_DIR = "classes3_output"  # 输出目录
PREFIX = "classes3"  # 你指定的前缀,例如 classes3

IMG_EXTS = [".jpg", ".jpeg", ".png", ".bmp"]

# =================================================


def load_yaml(yaml_path):
    with open(yaml_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    return data


def is_image(file):
    return os.path.splitext(file)[1].lower() in IMG_EXTS


def find_image_path(img_dir, base_name):
    for ext in IMG_EXTS:
        img_path = os.path.join(img_dir, base_name + ext)
        if os.path.exists(img_path):
            return img_path
    return None


def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def yolo_to_xyxy(img_w, img_h, x, y, w, h):
    """
    YOLO格式 -> 像素坐标
    """
    x1 = int((x - w / 2) * img_w)
    y1 = int((y - h / 2) * img_h)
    x2 = int((x + w / 2) * img_w)
    y2 = int((y + h / 2) * img_h)

    # 边界裁剪
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(img_w - 1, x2)
    y2 = min(img_h - 1, y2)

    return x1, y1, x2, y2


def get_all_image_label_pairs(root_path):
    """
    通用版本:
    支持任意结构,只要满足 images / labels 对应关系
    """

    pairs = []

    for dirpath, _, filenames in os.walk(root_path):

        # 关键:路径中包含 labels
        if "labels" not in dirpath:
            continue

        # 推导对应的 images 路径
        img_dir = dirpath.replace(os.sep + "labels", os.sep + "images")

        if not os.path.exists(img_dir):
            # print(f"⚠️ 没有对应 images: {img_dir}")
            continue

        for file in filenames:
            if not file.endswith(".txt"):
                continue

            base = os.path.splitext(file)[0]
            label_path = os.path.join(dirpath, file)
            img_path = find_image_path(img_dir, base)

            if img_path:
                pairs.append((img_path, label_path))
            else:
                print(f"⚠️ 图片不存在: {base}")

    return pairs


def main():
    data = load_yaml(DATA_YAML)

    root_path = data["path"]
    names = data["names"]  # {0: 'M1A1', 1: 'T72', ...}

    print(f"数据集路径: {root_path}")
    print(f"类别: {names}")

    pairs = get_all_image_label_pairs(root_path)
    print(f"共找到 {len(pairs)} 对图片标签")

    # 每个类别计数
    counters = {int(k): 0 for k in names.keys()}

    for img_path, label_path in pairs:
        img = cv2.imread(img_path)
        if img is None:
            continue

        h, w = img.shape[:2]

        with open(label_path, "r") as f:
            lines = f.readlines()

        if len(lines) == 0:
            continue

        for line in lines:
            parts = line.strip().split()
            if len(parts) != 5:
                continue

            cls_id = int(parts[0])
            x, y, bw, bh = map(float, parts[1:])

            # ===== 新增:类别过滤 =====
            if cls_id not in names:
                print(f"⚠️ 跳过未知类别 {cls_id} -> {label_path}")
                continue

            x1, y1, x2, y2 = yolo_to_xyxy(w, h, x, y, bw, bh)

            crop = img[y1:y2, x1:x2]
            if crop.size == 0:
                continue

            cls_name = names[cls_id]

            # 输出文件夹
            out_dir = os.path.join(OUTPUT_DIR, f"{PREFIX}_{cls_name}")
            ensure_dir(out_dir)

            # 文件名
            idx = counters[cls_id]
            save_path = os.path.join(out_dir,
                                     f"{PREFIX}_{cls_name}_{idx:06d}.jpg")

            cv2.imwrite(save_path, crop)
            counters[cls_id] += 1

    print("处理完成!")


if __name__ == "__main__":
    main()

1.6 制作osnet数据集

在1.4生成的众多文件夹内同目录下创建以下脚本,制作osnet数据集,更改其中的SRC_DIR,OUT_DIR

复制代码
# 所有种类是同种类时
import os
import shutil
import random

# ===================== 配置 =====================
SRC_DIR = "classes3_output"
OUT_DIR = "reid_dataset_classes3"

TRAIN_RATIO = 0.7   # 70% 类别用于训练
random.seed(0)
# ===============================================

def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

def main():
    classes = [d for d in os.listdir(SRC_DIR) if os.path.isdir(os.path.join(SRC_DIR, d))]
    classes.sort()
    print(f"总类别数: {len(classes)}")

    random.shuffle(classes)
    split_idx = int(len(classes) * TRAIN_RATIO)
    train_classes = classes[:split_idx]
    test_classes = classes[split_idx:]

    # 创建输出目录
    train_dir = os.path.join(OUT_DIR, "train")
    query_dir = os.path.join(OUT_DIR, "query")
    gallery_dir = os.path.join(OUT_DIR, "gallery")
    ensure_dir(train_dir)
    ensure_dir(query_dir)
    ensure_dir(gallery_dir)

    # ===== 处理 train 类 =====
    train_id = 1
    for cls in train_classes:
        cls_path = os.path.join(SRC_DIR, cls)
        imgs = [f for f in os.listdir(cls_path) if f.endswith(".jpg")]
        for i, img_name in enumerate(imgs):
            cam_id = 1 if i % 2 == 0 else 2
            new_name = f"{train_id:04d}_c{cam_id:03d}_{i:08d}.jpg"
            shutil.copy(os.path.join(cls_path, img_name), os.path.join(train_dir, new_name))
        train_id += 1

    # ===== 处理 query/gallery 类 =====
    test_id = 1001
    for cls in test_classes:
        cls_path = os.path.join(SRC_DIR, cls)
        imgs = [f for f in os.listdir(cls_path) if f.endswith(".jpg")]
        random.shuffle(imgs)

        # 至少两张用于 query/gallery
        query_img = imgs[0]
        gallery_img = imgs[1]

        shutil.copy(os.path.join(cls_path, query_img), os.path.join(query_dir, f"{test_id:04d}_c001_00000000.jpg"))
        shutil.copy(os.path.join(cls_path, gallery_img), os.path.join(gallery_dir, f"{test_id:04d}_c002_00000001.jpg"))

        # 剩余放 gallery
        for i, img_name in enumerate(imgs[2:]):
            new_name = f"{test_id:04d}_c002_{i+2:08d}.jpg"
            shutil.copy(os.path.join(cls_path, img_name), os.path.join(gallery_dir, new_name))

        test_id += 1

    print("✅ 方案一数据集划分完成!")

if __name__ == "__main__":
    main()

1.7 注册数据集:

osnet不支持像yolo那样以文件夹名为数据集,需注册,

deep-person-reid/torchreid/data/datasets/下新建如classes3.py文件,内容如下:

更改其中的:self.train_dirself.query_dirself.gallery_dir路径

更改其中的:dataset_dir

更改其中的:class classes3类名和对应的super(classes3, self).__init__(train, query, gallery, **kwargs)

复制代码
# classes3.py
from __future__ import absolute_import
import os.path as osp
import glob
import re

from ..dataset import ImageDataset


class classes3(ImageDataset):
    dataset_dir = 'reid_dataset_classes3' # osnet的数据集路径

    def __init__(self, root='', **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'gallery')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, relabel=True)
        query = self.process_dir(self.query_dir, relabel=False)
        gallery = self.process_dir(self.gallery_dir, relabel=False)

        super(classes3, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d+)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid < 0:
                continue
            pid_container.add(pid)

        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid < 0:
                continue
            camid -= 1  # camid 从 0 开始
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid))

        return data

在 datasets/init.py 中注册

文件路径:model/deep-person-reid/torchreid/data/datasets/__init__.py添加:

复制代码
from .image.classes3 import classes3

添加到 factory:找到:__image_datasets = {.....},在其中添加如下

复制代码
__image_datasets = {
    ...
    'classes3': classes3,
}

1.8 更改训练脚本:train_osnet_3stage_auto.py

完成后训练脚本的路径改为

复制代码
DATASET_NAME = 'reid_dataset_classes3'

DATASET_ROOT改为classes3数据集的根路径

2.yolo的冻结训练

只需加一个参数freeze

freeze 值 含义
22 只训练 Detect Head(极端小数据)
10 Backbone 全冻结(推荐起点)
15 Backbone + 部分 Neck
0 不冻结

代码

复制代码
from ultralytics import YOLO
import os
os.environ["YOLO_DISABLE_AMP_CHECK"] = "1"

model = YOLO("./yolov8s.pt")
# 加载模型
results = model.train(data="truck.yaml", epochs=50, workers=16,imgsz=640, batch=64, device='0,1',resume=False, freeze=22, patience=20)# cache=True

# python -m torch.distributed.run --nproc_per_node=2 ./train_detect.py
相关推荐
向哆哆7 小时前
人脸眼部特征检测数据集(千张图片已划分、已标注)适用于YOLO系列深度学习分类检测任务
深度学习·yolo·分类
Dev7z8 小时前
基于YOLOv8面向家居场景的火焰烟雾图像识别系统
人工智能·yolo
童话名剑9 小时前
YOLO v6(学习笔记)
yolo·目标检测·yolov6
前网易架构师-高司机17 小时前
带标注的瓶盖识别数据集,识别率99.5%,可识别瓶盖,支持yolo,coco json,pascal voc xml格式
人工智能·yolo·数据集·瓶盖
一勺汤19 小时前
YOLO26 改进、魔改| 部分通道注意力模块PAT,以轻量化并行结构融合局部卷积与增强型通道注意力,提升小目标、遮挡目标的检测效果。
yolo·注意力机制·轻量化·小目标·yolo26·yolo26改进·复杂场景
fl1768311 天前
智慧工业玻璃瓶容器缺陷检测数据集VOC+YOLO格式2149张28类别
yolo
_假正经1 天前
YOLOV8/11分割与分类输出参数说明
人工智能·yolo·分类
JicasdC123asd2 天前
CGNet上下文引导网络改进YOLOv26下采样特征保留能力
网络·yolo
Coovally AI模型快速验证2 天前
检测+跟踪一体化!4.39M参数、8.3W功耗,轻量化模型让无人机在露天矿实时巡检
算法·yolo·无人机·智能巡检·智慧矿山