YOLO中task.py改复杂的模块

比如改一个多输入模块,我们需要记录输入1的通道,输入2的通道,Conv_reduce的输入通道

YOLO中这个模块接受层1和层2的作为输入,那么层1和层2的输出通道肯定是知道的,所以现在只需要在yaml里面标记整个模块的输出通道即可。

python 复制代码
class AF(nn.Module):
    def __init__(self,c1,c2,dim1,dim2):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_atten = nn.Sequential(
            nn.Conv2d(c1, c1,1),
            nn.Sigmoid()
        )
        self.conv_redu = nn.Conv2d(c1, c2, kernel_size=1, bias=False)

        self.conv1 = nn.Conv2d(dim1, 1, 1, 1)
        self.conv2 = nn.Conv2d(dim2, 1, 1, 1)
        self.nonlin = nn.Sigmoid()

    def forward(self, x):
        output =  torch.cat(x,1)
        att = self.conv_atten(self.avg_pool(output))
        #print(att.shape)
        output = output * att
        output = self.conv_redu(output)
        #print(output.shape)
        att = self.conv1(x[0]) + self.conv2(x[1])
        att = self.nonlin(att)
        #print(att.shape)
        output = output * att
        return output
html 复制代码
  - [[-1, 6], 1, AF, [32]] # cat backbone P4

例如这条yaml,接受第6层和上一层的输入,输出通道数为32。这里参数为什么是一个?因为这里只需要给出输出通道数即可,其他参数可以再网络的记录中得到。

python 复制代码
        elif m is AF:
            c1 = sum(ch[x] for x in f)
            c3 = ch[f[0]]
            c4 = ch[f[1]]
            c2 = args[0]
            args = [c1,c2,c3,c4]
            print(args)

f是一个表表示来自那一层,这里的f里面就保存的内容相当于【-1,6】的索引,ch是每一层的输出通道数,ch[层索引]不就得到某层的输出通道了。这里随便借助一个中间变量,c1,c2,c3,c4,记录参数后,合成列表【c1,c2,c3,c4】

python 复制代码
torch.nn.Sequential(*(m(*args))

m相当于类名称,加入类名为AF,不就相当于AF(c1,c2,c3,c4)吗

相关推荐
音沐mu.13 小时前
【69】果蔬新鲜度数据集(有v5/v8模型)/YOLO果蔬新鲜度检测
yolo·目标检测·数据集·果蔬新鲜度数据集·果蔬新鲜度检测
Muyuan199813 小时前
26.Paper RAG Agent 展示面收口:截图与项目表达更新记录
人工智能·python·django·fastapi
AI技术增长13 小时前
Pytorch图像去噪实战(十四):条件扩散模型图像去噪,让Diffusion根据带噪图恢复干净图
人工智能·pytorch·python
li星野13 小时前
FastAPI 项目加入 WebSocket 支持
python·websocket·fastapi
tangweiguo0305198713 小时前
LangGraph 入门:多智能体工作流实战(阿里云百炼)
人工智能·python·langchain
帅次13 小时前
Android AI 面试速刷版
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理·数据分析
生物信息与育种13 小时前
全基因组重测序及群体遗传与进化分析技术服务指南
人工智能·深度学习·算法·数据分析·r语言
MediaTea13 小时前
Scikit-learn:preprocessing 模块
人工智能·深度学习·机器学习·计算机视觉·scikit-learn
Ares-Wang13 小时前
Flask》》Flask-Caching缓存插件
python·缓存·flask
明如正午13 小时前
转换pdf文件为md文件【markitdown+pdf4llm】
python·pdf·markitdown·pdf4llm