YoLo World代码块解读

MaxSigmoidAttnBlock

分别处理图像与文本特征,计算这两者的相关性,得到整个句子所有word中最大的相关性数值作为attention作用于图像特征中。

python 复制代码
class MaxSigmoidAttnBlock(nn.Module):
    """Max Sigmoid attention block."""

    def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False):
        """Initializes MaxSigmoidAttnBlock with specified arguments."""
        super().__init__()
        self.nh = nh
        self.hc = c2 // nh
        self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None
        self.gl = nn.Linear(gc, ec)
        self.bias = nn.Parameter(torch.zeros(nh))
        self.proj_conv = Conv(c1, c2, k=3, s=1, act=False)
        self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0

    def forward(self, x, guide):
        """Forward process."""
        bs, _, h, w = x.shape

        guide = self.gl(guide)
        guide = guide.view(bs, -1, self.nh, self.hc)
        embed = self.ec(x) if self.ec is not None else x
        embed = embed.view(bs, self.nh, self.hc, h, w)

        aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide) #文本与图像的交叉注意力
        aw = aw.max(dim=-1)[0] #整个句子最大相关性数值作为attention作用于图像特征中。
        aw = aw / (self.hc**0.5)
        aw = aw + self.bias[None, :, None, None]
        aw = aw.sigmoid() * self.scale

        x = self.proj_conv(x)
        x = x.view(bs, self.nh, -1, h, w)
        x = x * aw.unsqueeze(2)
        return x.view(bs, -1, h, w)

WorldDetect

基本和Detect一样,主要是分类分支需要通过BNContrastiveHead将图像特征于文字特征计算相关性得到每个grid中所有类别的置信度。

python 复制代码
class WorldDetect(Detect):
    def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
        """Initialize YOLOv8 detection layer with nc classes and layer channels ch."""
        super().__init__(nc, ch)
        c3 = max(ch[0], min(self.nc, 100))
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)

    def forward(self, x, text):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
            #cv4可以获得类别编码,[512,80,80]*[4,512]=[4,80,80]
相关推荐
喵叔哟5 小时前
01-YOLO最新版到底新在哪
yolo
无人装备硬件开发爱好者5 小时前
RV1126B 边缘端 AI 实战:YOLOv8+DNTR 微小目标跟踪监测全栈实现 1
人工智能·yolo·目标跟踪
2501_941322035 小时前
基于YOLOv8的汽车车损检测与评估系统_16种损伤类型识别
yolo·汽车
LASDAaaa12316 小时前
电力巡检实战:基于YOLOv8-SEG-P6的输电线路鸟类检测与识别技术详解
yolo
Piar1231sdafa6 小时前
YOLOv5-AIFI改进_爆炸物检测与识别系统_实现与应用
yolo
zy_destiny7 小时前
【工业场景】用YOLOv26实现4种输电线隐患检测
人工智能·深度学习·算法·yolo·机器学习·计算机视觉·输电线隐患识别
雍凉明月夜7 小时前
深度学习之目标检测yolo算法Ⅴ-YOLOv8
深度学习·yolo·目标检测
智驱力人工智能7 小时前
货车违规变道检测 高速公路安全治理的工程实践 货车变道检测 高速公路货车违规变道抓拍系统 城市快速路货车压实线识别方案
人工智能·opencv·算法·安全·yolo·目标检测·边缘计算
2501_941652777 小时前
改进YOLOv5-BiFPN-SDI实现牙齿龋齿检测与分类_深度学习_计算机视觉_原创
深度学习·yolo·分类
zy_destiny8 小时前
【工业场景】用YOLOv26实现8种道路隐患检测
人工智能·深度学习·算法·yolo·机器学习·计算机视觉·目标跟踪