完整的 YOLO26 自定义模块注册 & 训练步骤

1. 定义自定义模块文件(基础)
  • ultralytics/nn/modules/ 下创建模块文件(如 SEModule.py),定义模块类(如 SEBlock):
  • ✅ 关键:__init__ 方法要适配 YOLO 的参数传递逻辑 (需接收 c1, c2, *args,其中 c1 是上一层通道,c2 是当前通道);
  • ✅ 避坑:模块的 forward 方法输出张量的尺寸 / 通道需和输入一致(SEBlock 是注意力模块,满足这一点)。
2. 导入模块到 tasks.py(注册核心)
  • ultralytics/nn/tasks.py 开头添加导入语句:
python 复制代码
from ultralytics.nn.modules.SEModule import SEBlock  # 路径要和文件名称严格一致
  • tasks.pybase_modules = frozenset({...}) 中添加 SEBlock

✅ 关键:只有加入 base_modules,YOLO 的 parse_model 函数才会识别该模块并自动处理参数;

✅ 避坑:添加时注意逗号分隔 ,避免语法错误(如 Conv, SEBlock, C2f)。

3. 配置 YAML 文件(维度匹配核心)
  • ultralytics/cfg/models/26/ 下创建 yolo26-SE.yaml

✅ 关键 1:新增模块的参数要和模块 __init__ 匹配(如 SEBlock[128, 16]);

✅ 关键 2:修正 Concat 层索引 ------ 新增模块会导致 backbone 层索引后移,需确保 head 部分的 Concat 只拼接同尺寸特征层

✅ 避坑:标注每一层的尺寸 / 通道,避免拼接时出现「尺寸不匹配」错误。

4. 训练脚本调用(最终执行)
  • train.py 中加载自定义 YAML 并训练:
python 复制代码
from ultralytics import YOLO
model = YOLO("ultralytics/cfg/models/26/yolo26-SE.yaml")  # 路径要完整
model.train(data="coco8.yaml", epochs=10, batch=8)  # 补充训练参数

核心避坑点(我踩过的关键坑)

  1. 参数不匹配 :YOLO 的 parse_model 会给 base_modules 内的模块自动追加 c1 参数,因此自定义模块的 __init__ 需接收 c1, c2, *args(而非仅 c1, r);
  2. 索引错乱:新增模块会导致 backbone 层索引后移,必须修正 head 中 Concat 层的引用索引,确保拼接同尺寸特征;
  3. 语法错误 :修改 tasks.py 时,注意标点(逗号、冒号)和缩进,新增模块后建议用 python -m py_compile tasks.py 检查语法。

实验:

第一步,编辑模块ECA.py

/ultralytics/nn/ECA.py

python 复制代码
import torch
import torch.nn as nn
# 1. 修正:从正确路径导入C3k2和register_module
from ultralytics.nn.modules.block import C3k2  # C3k2在block.py中,不是tasks.py
# from ultralytics.nn.tasks import register_module  # 正确的register_module路径
# 或备选路径:from ultralytics.utils.torch_utils import register_module

# 2. 你的ECA模块(保留不变,正确)
# 新增:ECA模块
class ECAModule(nn.Module):
    def __init__(self, channel, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

    # 对池化结果进行一系列变换:
    # 1. 移除最后一个维度
    # 2. 转置最后两个维度
    # 3. 应用卷积层
    # 4. 再次转置最后两个维度
    # 5. 添加最后一个维度
# 新增:C3k2_ECA
    # 使用sigmoid函数处理结果,得到权重
class C3k2_ECA(C3k2):
    # 将权重扩展到与输入x相同的形状,并与输入x逐元素相乘
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=2, eca_k=3):
        super().__init__(c1, c2, n, shortcut, g, e, k)
        c_ = int(c2 * e)
        self.eca = ECAModule(channel=c_, k_size=eca_k)

    def forward(self, x):
        out = super().forward(x)
        out = self.eca(out)
        return out

# 4. 核心修正:手动注册模块(适配所有Ultralytics版本)
# 方式1:添加到block模块字典(优先)
# import ultralytics.nn.modules.block
# ultralytics.nn.modules.block.C3k2_ECA = C3k2_ECA  # 把C3k2_ECA加入block模块

# # 方式2:添加到tasks的模型字典(兜底,二选一即可)
# import ultralytics.nn.tasks
# ultralytics.nn.tasks.MODELS["C3k2_ECA"] = C3k2_ECA
# 测试:验证模块是否注册成功(可选,运行无报错则说明成功)
if __name__ == "__main__":
    # 创建C3k2_ECA实例,测试前向传播
    model = C3k2_ECA(c1=128, c2=256, n=2, e=0.5)
    x = torch.randn(1, 128, 80, 80)  # 模拟P2层特征
    out = model(x)
    print(f"C3k2_ECA输出形状:{out.shape}")  # 应输出 torch.Size([1, 256, 80, 80])
    print("C3k2_ECA模块注册+运行成功!")

第二步,在ultralytics/nn/tasks.py中添加模块代码

python 复制代码
from ultralytics.nn.ECA import C3k2_ECA

第三步,在ultralytics/nn/tasks.py中的 base_modules = frozenset( 中添加

python 复制代码
     A2C2f,
     C3k2_ECA, #   新增
  }

第四步,在ultralytics/cfg/models 中增加一个yaml文件

ultralytics/cfg/models/yolo26-ECA.yaml

python 复制代码
# Parameters
nc: 80 # number of classes
end2end: True # whether to use end-to-end mode
reg_max: 1 # DFL bins
scales: # model compound scaling constants, i.e. 'model=yolo26n.yaml' will call yolo26.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 260 layers, 2,572,280 parameters, 2,572,280 gradients, 6.1 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 260 layers, 10,009,784 parameters, 10,009,784 gradients, 22.8 GFLOPs
  m: [0.50, 1.00, 512] # summary: 280 layers, 21,896,248 parameters, 21,896,248 gradients, 75.4 GFLOPs
  l: [1.00, 1.00, 512] # summary: 392 layers, 26,299,704 parameters, 26,299,704 gradients, 93.8 GFLOPs
  x: [1.00, 1.50, 512] # summary: 392 layers, 58,993,368 parameters, 58,993,368 gradients, 209.5 GFLOPs

# YOLO26n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2_ECA, [256, False, 0.25]]  #这里把C3K2改成C3K2_ECA
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5, 3, True]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO26n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, True]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, True]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, True]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 1, C3k2, [1024, True, 0.5, True]] # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)

第五步 修改train.py

python 复制代码
from ultralytics import YOLO

model = YOLO("ultralytics/cfg/models/26/yolo26-ECA.yaml")  # build a new model from scratch

 
# Train the model
results = model.train(data="coco128.yaml", epochs=3000, imgsz=640)
相关推荐
Sylvia33.2 小时前
火星数据:棒球数据API
java·前端·人工智能
nihao5612 小时前
OpenClaw 保姆级安装部署教程
人工智能
X54先生(人文科技)2 小时前
碳硅协同开发篇-ELR诞生记章
人工智能·ai编程·ai写作·程序员创富
小王毕业啦2 小时前
2010-2024年 上市公司-突破性创新和渐进性创新(数据+代码+文献)
大数据·人工智能·数据挖掘·数据分析·数据统计·社科数据·经管数据
美酒没故事°2 小时前
手摸手在扣子平台搭建周报智能体[特殊字符]
人工智能·ai
若谷老师3 小时前
21.WSL中部署gnina分子对接程序ds
linux·人工智能·ubuntu·卷积神经网络·gnina·smina
诗词在线3 小时前
孟浩然诗作数字化深度实战:诗词在线的意象挖掘、检索优化与多场景部署
大数据·人工智能·算法
冬奇Lab3 小时前
一天一个开源项目(第23篇):PageLM - 开源 AI 教育平台,把学习材料变成互动资源
人工智能·开源