YoLo World代码块解读

MaxSigmoidAttnBlock

分别处理图像与文本特征,计算这两者的相关性,得到整个句子所有word中最大的相关性数值作为attention作用于图像特征中。

python 复制代码
class MaxSigmoidAttnBlock(nn.Module):
    """Max Sigmoid attention block."""

    def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False):
        """Initializes MaxSigmoidAttnBlock with specified arguments."""
        super().__init__()
        self.nh = nh
        self.hc = c2 // nh
        self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None
        self.gl = nn.Linear(gc, ec)
        self.bias = nn.Parameter(torch.zeros(nh))
        self.proj_conv = Conv(c1, c2, k=3, s=1, act=False)
        self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0

    def forward(self, x, guide):
        """Forward process."""
        bs, _, h, w = x.shape

        guide = self.gl(guide)
        guide = guide.view(bs, -1, self.nh, self.hc)
        embed = self.ec(x) if self.ec is not None else x
        embed = embed.view(bs, self.nh, self.hc, h, w)

        aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide) #文本与图像的交叉注意力
        aw = aw.max(dim=-1)[0] #整个句子最大相关性数值作为attention作用于图像特征中。
        aw = aw / (self.hc**0.5)
        aw = aw + self.bias[None, :, None, None]
        aw = aw.sigmoid() * self.scale

        x = self.proj_conv(x)
        x = x.view(bs, self.nh, -1, h, w)
        x = x * aw.unsqueeze(2)
        return x.view(bs, -1, h, w)

WorldDetect

基本和Detect一样,主要是分类分支需要通过BNContrastiveHead将图像特征于文字特征计算相关性得到每个grid中所有类别的置信度。

python 复制代码
class WorldDetect(Detect):
    def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
        """Initialize YOLOv8 detection layer with nc classes and layer channels ch."""
        super().__init__(nc, ch)
        c3 = max(ch[0], min(self.nc, 100))
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)

    def forward(self, x, text):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
            #cv4可以获得类别编码,[512,80,80]*[4,512]=[4,80,80]
相关推荐
想进部的张同学10 小时前
hilinux-3599---设备学习---以及部署yolo
学习·yolo·海思
Java后端的Ai之路14 小时前
【人工智能领域】-YOLO目标检测算法全解析(含大白话解释)
人工智能·yolo·目标检测·cnn
Coovally AI模型快速验证16 小时前
超越Sora的开源思路:如何用预训练组件高效训练你的视频扩散模型?(附训练代码)
人工智能·算法·yolo·计算机视觉·音视频·无人机
FL162386312917 小时前
监控视角工地建筑施工工程车辆检测数据集VOC+YOLO格式8345张10类别
yolo
duyinbi751718 小时前
【计算机视觉实践】:基于YOLOv8-BIMAFPN的海洋漏油事件检测与分类系统实现_2
yolo·计算机视觉·分类
头发够用的程序员19 小时前
Ultralytics 代码库深度解读【六】:数据加载机制深度解析
人工智能·pytorch·python·深度学习·yolo·边缘计算·模型部署
X_Cosmic19 小时前
从零开始:YOLO11 训练 DOTA OBB 遥感数据旋转框目标检测
python·yolo·目标检测
棒棒的皮皮20 小时前
【深度学习】YOLO模型精度优化全攻略
人工智能·深度学习·yolo·计算机视觉
棒棒的皮皮21 小时前
【深度学习】YOLO实战之模型训练
人工智能·深度学习·yolo·计算机视觉
数据光子21 小时前
【YOLO数据集】遛狗未牵绳目标检测
人工智能·python·yolo·目标检测·计算机视觉