yolov11剪枝、蒸馏、加注意力

这段代码是一个用于深度学习模型训练的Python脚本,特别是针对目标检测任务,使用了YOLO(You Only Look Once)算法。代码中包含了多个步骤,每个步骤都是模型训练过程中的一个阶段。以下是对代码的详细解释:

完整版代码在GitHub上:yolov11剪枝蒸馏

1. **导入必要的库和模块**:

  • `from ultralytics import YOLO`:导入了ultralytics提供的YOLO模型库。

  • `import os`:导入操作系统接口模块,用于文件和目录操作。

  • `from utils.yolo.attention import add_attention`:导入一个自定义模块,用于给模型添加注意力机制。

2. **设置环境变量和路径**:

  • `os.environ["CUDA_VISIBLE_DEVICES"]="0,1"`:这行代码被注释掉了,它的作用是设置CUDA环境变量,指定使用哪几个GPU设备。在这里指定了0号和1号GPU。

  • `root = os.getcwd()`:获取当前工作目录的路径。

  • `name_yaml`、`name_pretrain`等变量定义了配置文件和预训练模型文件的路径。

3. **定义训练步骤**:

  • `step1_train()`:加载预训练模型并开始训练。

  • `step2_Constraint_train()`:在约束条件下进行训练,例如可能涉及到正则化或其他约束条件。

  • `step3_pruning()`:使用自定义的`do_pruning`函数对模型进行剪枝,以减少模型的复杂度。

  • `step4_finetune()`:微调剪枝后的模型。

  • `step5_distillation()`:使用知识蒸馏技术,将一个训练好的大模型(教师模型)的知识传递给一个较小的模型(学生模型)。

4. **训练函数参数解释**:

  • `data`:指定数据配置文件的路径。

  • `device`:指定训练使用的设备,如GPU。

  • `imgsz`:指定输入图像的大小。

  • `epochs`:指定训练的轮数。

  • `batch`:指定每批训练的样本数量。

  • `workers`:指定用于数据加载的工作线程数量。

  • `save_period`:指定保存模型的周期。

  • `name`:指定模型保存的路径。

  • `amp`:指定是否使用自动混合精度训练。

  • `Distillation`:指定知识蒸馏的教师模型。

  • `loss_type`:指定损失函数的类型。

  • `layers`:指定进行蒸馏的层。

python 复制代码
from ultralytics import YOLO
import os
from utils.yolo.attention import add_attention
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

root = os.getcwd()
## 配置文件路径
name_yaml = os.path.join(root, "data.yaml")
name_pretrain = os.path.join(root, "runs/segment/ori/weights/best.pt")
## 原始训练路径
# path_train = os.path.join(root, "runs/detect/VOC")
name_train = "runs/segment/ori/weights/best.pt"
## 约束训练路径、剪枝模型文件
path_constraint_train = os.path.join(root, "runs/segment/Constraint")
name_prune_before = os.path.join(path_constraint_train, "weights/last.pt")
name_prune_after = os.path.join(path_constraint_train, "weights/prune.pt")
## 微调路径
path_fineturn = os.path.join(root, "runs/detect/VOC_finetune")

def step1_train():
    model = YOLO(name_pretrain)
    model.train(data=name_yaml, device="0", imgsz=720, epochs=50, batch=2, workers=0, save_period=1)  # train the model


## 2024.3.4添加【amp=False】
def step2_Constraint_train():
    model = YOLO(name_train)
    model.train(data=name_yaml, device="0", imgsz=640, epochs=50, batch=2, amp=False, workers=0, save_period=1,
                name=path_constraint_train)  # train the model


def step3_pruning():
    from utils.yolo.LL_pruning import do_pruning
    do_pruning(name_prune_before, name_prune_after)


def step4_finetune():
    model = YOLO(name_prune_after)  # load a pretrained model (recommended for training)
    for param in model.parameters():
        param.requires_grad = True
    model.train(data=name_yaml, device="0", imgsz=640, epochs=200, batch=2, workers=0, name=path_fineturn)  # train the model

def step5_distillation():
    layers = ["6", "8", "13", "16", "19", "22"]
    model_t = YOLO('runs/segment/ori/weights/best.pt')  # the teacher model
    model_s = YOLO('runs/segment/Constraint/weights/prune.pt')  # the student model
    model_s = add_attention(model_s)
    """
    Attributes:
        Distillation: the distillation model
    """
    model_s.train(data="data.yaml", Distillation=model_t.model, loss_type='mgd',layers=layers, amp=False, imgsz=1280, epochs=300,
                  batch=2, device=0, workers=0, lr0=0.001)


if __name__ == '__main__':
    # step1_train()
    # step2_Constraint_train()
    # step3_pruning()
    # step4_finetune()
    step5_distillation()
相关推荐
IT古董8 分钟前
【深度学习】常见模型-Transformer模型
人工智能·深度学习·transformer
沐雪架构师1 小时前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
python算法(魔法师版)2 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
kakaZhui2 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
利刃大大2 小时前
【回溯+剪枝】找出所有子集的异或总和再求和 && 全排列Ⅱ
c++·算法·深度优先·剪枝
struggle20253 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥3 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
云空4 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析
AIGC大时代4 小时前
对比DeepSeek、ChatGPT和Kimi的学术写作关键词提取能力
论文阅读·人工智能·chatgpt·数据分析·prompt
山晨啊85 小时前
2025年美赛B题-结合Logistic阻滞增长模型和SIR传染病模型研究旅游可持续性-成品论文
人工智能·机器学习