Building CBAM, SENet, STN, Transformer, MobileViT, SimpleViT and ViT Models from Scratch (PyTorch Code Examples)

Building Deep Learning Models from Scratch (PyTorch Code Examples, Part 2)

    • CBAM
      • [Key Features of CBAM](#Key Features of CBAM)
      • [Basic Structure of CBAM](#Basic Structure of CBAM)
    • SENet
      • [Key Features of SENet](#Key Features of SENet)
      • [Basic Structure of SENet](#Basic Structure of SENet)
    • STN
      • [Key Features of STN](#Key Features of STN)
      • [Basic Structure of STN](#Basic Structure of STN)
    • Transformer
      • [Key Features of the Transformer](#Key Features of the Transformer)
      • [Basic Structure of the Transformer](#Basic Structure of the Transformer)
    • MobileViT
      • [Key Features of MobileViT](#Key Features of MobileViT)
      • [Basic Structure of MobileViT](#Basic Structure of MobileViT)
      • [Detailed Structure of the MobileViT Block](#Detailed Structure of the MobileViT Block)
    • SimpleViT
      • [Key Features of SimpleViT](#Key Features of SimpleViT)
      • [Basic Structure of SimpleViT](#Basic Structure of SimpleViT)
      • [Patch Embedding](#Patch Embedding)
      • [Positional Encoding](#Positional Encoding)
      • [Transformer Encoder](#Transformer Encoder)
      • [Classification Head](#Classification Head)
    • ViT

CBAM

CBAM (Convolutional Block Attention Module) is an attention mechanism that can be inserted into existing convolutional neural networks (CNNs) to strengthen the model's focus on important features. By applying attention along both the channel and the spatial dimensions, CBAM improves the network's representational power and performance. Its key features and implementation details are as follows:

Key Features of CBAM

  • Channel attention: computes an importance weight for each channel, amplifying or suppressing the features of specific channels.
  • Spatial attention: computes an importance weight for each spatial position, amplifying or suppressing the features of specific regions.
  • Lightweight: CBAM can be dropped into existing CNN architectures without significantly increasing computational cost.

Basic Structure of CBAM

CBAM consists of two main modules, a channel attention module and a spatial attention module; a minimal sketch of how the two are chained is given after the list below.

  1. Channel attention module
    Max pooling: apply global max pooling to the input feature map.
    Average pooling: apply global average pooling to the input feature map.
    Shared MLP: two fully connected layers (implemented as 1x1 convolutions in the code below) learn the per-channel importance weights.
    Sigmoid activation: the MLP output is passed through a Sigmoid to produce the weights.
  2. Spatial attention module
    Max pooling: take the maximum over the channel dimension of the input feature map.
    Average pooling: take the mean over the channel dimension of the input feature map.
    Convolution: a 7x7 convolution learns the per-position importance weights.
    Sigmoid activation: the convolution output is passed through a Sigmoid to produce the weights.
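Before the full ResNet integration below, here is a minimal sketch of how the two modules are typically chained (channel attention first, then spatial attention). It assumes the ChannelAttention and SpatialAttention classes defined in the full listing that follows, and the wrapper name CBAMBlock is purely illustrative.
python
import torch.nn as nn

## Hypothetical wrapper that chains the two attention modules from the listing below
class CBAMBlock(nn.Module):
    def __init__(self, channels, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, ratio)  # channel weights, shape [N, C, 1, 1]
        self.sa = SpatialAttention(kernel_size)      # spatial weights, shape [N, 1, H, W]

    def forward(self, x):
        x = self.ca(x) * x  # re-weight channels
        x = self.sa(x) * x  # re-weight spatial positions
        return x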
python
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

## Channel attention module
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
                               nn.ReLU(),
                               nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

## Spatial attention module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

## Residual block used by ResNet18 and ResNet34
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.ca = ChannelAttention(planes) # channel attention
        self.sa = SpatialAttention() # spatial attention

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

## Bottleneck residual block used by ResNet50, ResNet101 and ResNet152
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)

        self.ca = ChannelAttention(planes * 4)
        self.sa = SpatialAttention()

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.ca(out) * out
        out = self.sa(out) * out

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

## ResNet backbone assembled from the blocks above
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

if __name__ == '__main__':
    resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
    resnet34 = ResNet(BasicBlock, [3, 4, 6, 3])
    resnet50 = ResNet(Bottleneck, [3, 4, 6, 3])
    resnet101 = ResNet(Bottleneck, [3, 4, 23, 3])
    resnet152 = ResNet(Bottleneck, [3, 8, 36, 3])

    x = torch.randn(1, 3, 224, 224)
    output = resnet18(x)
    torch.onnx.export(resnet18,x,'resnet18_cbam.onnx')

    # Load pretrained weights (the new attention modules keep their random initialization)
    pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
    new_state_dict = resnet18.state_dict()
    new_state_dict.update(pretrained_state_dict)
    resnet18.load_state_dict(new_state_dict)

SENet

SENet (Squeeze-and-Excitation Network) improves convolutional neural networks (CNNs) by adding an attention mechanism over feature channels. By dynamically recalibrating the importance of each channel, the model learns to focus on the most informative features, improving its representational power and generalization. Its key features and implementation details are as follows:

Key Features of SENet

  • Squeeze: global average pooling compresses each feature channel into a single global descriptor that captures channel-wise global information.
  • Excitation: two fully connected (FC) layers learn an importance weight for each channel; a Sigmoid maps the weights into the range [0, 1].
  • Scale: the learned weights are multiplied with the original feature map, amplifying or suppressing specific channels.

Basic Structure of SENet

SENet adds attention by inserting an SE module into a conventional convolutional block. The SE module performs the following steps (a small broadcasting sketch follows the list):

  • Squeeze: apply global average pooling to the input feature map to obtain one global descriptor per channel.
  • Excitation: learn the per-channel importance weights with two fully connected layers.
  • Scale: multiply the learned weights with the original feature map to obtain the recalibrated features.
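To make the Scale step concrete, the minimal sketch below (purely illustrative, independent of the listing that follows) shows how per-channel weights of shape [N, C, 1, 1] broadcast against a feature map of shape [N, C, H, W]:
python
import torch

feat = torch.randn(2, 8, 4, 4)    # feature map [N, C, H, W]
weights = torch.rand(2, 8, 1, 1)  # per-channel weights in [0, 1], e.g. the Sigmoid output of an SE module
scaled = feat * weights           # broadcasting scales every spatial position of a channel by the same weight
print(scaled.shape)               # torch.Size([2, 8, 4, 4])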
python
#coding:utf8
import torch
import torch.nn as nn

## SE (Squeeze-and-Excitation) module
class SELayer(nn.Module):
    def __init__(self, channel,reduction = 16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_features=channel, out_features=channel // reduction)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(in_features=channel // reduction, out_features=channel)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        n,c,_,_ = x.size() # [N,C,H,W]
        x = self.avg_pool(x)
        x = x.view(n,c)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        x = x.view(n,c,1,1)
        x = identity * x

        return x

## Residual block with an SE layer
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, outchannels):
        super(ResidualBlock, self).__init__()

        self.channel_equal_flag = True
        if in_channels == outchannels:
            self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=outchannels, kernel_size=3, padding=1, stride=1, bias=False)
        else:
            self.conv1x1 = nn.Conv2d(in_channels=in_channels, out_channels=outchannels, kernel_size=1, stride=2, bias=False)
            self.bn1x1 = nn.BatchNorm2d(num_features=outchannels)
            self.channel_equal_flag = False

            self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=outchannels,kernel_size=3,padding=1, stride=2, bias=False)

        self.bn1 = nn.BatchNorm2d(num_features=outchannels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(in_channels=outchannels, out_channels=outchannels, kernel_size=3,padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=outchannels)

        self.selayer = SELayer(channel=outchannels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        if not self.channel_equal_flag:
            identity = self.conv1x1(identity)
            identity = self.bn1x1(identity)
            identity = self.relu(identity)

        x = self.selayer(x) ## plug-and-play SE module

        out = identity + x
        return out

## SENet-18 (a ResNet-18 backbone with SE blocks)
class Model(nn.Module):
    def __init__(self, num_classes=1000):
        super(Model, self).__init__()
        # conv1
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64,kernel_size=7,stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.relu = nn.ReLU(inplace=True)

        # conv2_x
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2,padding=1)
        self.conv2_1 = ResidualBlock(in_channels=64, outchannels=64)
        self.conv2_2 = ResidualBlock(in_channels=64, outchannels=64)

        # conv3_x
        self.conv3_1 = ResidualBlock(in_channels=64, outchannels=128)
        self.conv3_2 = ResidualBlock(in_channels=128, outchannels=128)

        # conv4_x
        self.conv4_1 = ResidualBlock(in_channels=128, outchannels=256)
        self.conv4_2 = ResidualBlock(in_channels=256, outchannels=256)

        # conv5_x
        self.conv5_1 = ResidualBlock(in_channels=256, outchannels=512)
        self.conv5_2 = ResidualBlock(in_channels=512, outchannels=512)

        # avg_pool
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1,1)) # [N, C, H, W] = [N, 512, 1, 1]

        # fc
        self.fc = nn.Linear(in_features=512, out_features=num_classes)  # [N, num_classes]

    def forward(self, x):
        # conv1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # conv2_x
        x = self.maxpool(x)
        x = self.conv2_1(x)
        x = self.conv2_2(x)

        # conv3_x
        x = self.conv3_1(x)
        x = self.conv3_2(x)

        # conv4_x
        x = self.conv4_1(x)
        x = self.conv4_2(x)

        # conv5_x
        x = self.conv5_1(x)
        x = self.conv5_2(x)

        # avgpool + fc
        x = self.avg_pool(x)

        x = torch.flatten(x, 1)

        x = self.fc(x)

        return x


if __name__ == '__main__':
    # build the model
    model = Model(num_classes=10)

    # define the input [N, C, H, W]
    input = torch.ones([10, 3, 224, 224])
    output = model(input)
    torch.onnx.export(model,input,'senet.onnx')

STN

STN (Spatial Transformer Network) is a technique for performing learned spatial transformations inside a convolutional neural network (CNN). An STN can dynamically apply spatial transformations such as translation, rotation and scaling to its input, improving robustness and generalization. It does this through a learnable transformation module that adaptively adjusts the position and pose of the input, allowing the network to capture features more effectively.

Key Features of STN

  • Localization network: predicts the transformation parameters; usually a small CNN (a small fully connected network in the code below).
  • Grid generator: generates a sampling grid from the predicted transformation parameters.
  • Sampler: samples the input image according to the generated grid to produce the transformed image.

Basic Structure of STN

The basic structure of an STN consists of three parts (a minimal identity-transform example follows the list):

  • Localization network: takes the input image and outputs the transformation parameters.
  • Grid generator: builds a sampling grid from the transformation parameters.
  • Sampler: samples the input image over the grid to produce the transformed image.
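As a minimal illustration (independent of the listing below), the grid generator and sampler map directly onto F.affine_grid and F.grid_sample; with the identity parameters theta = [[1, 0, 0], [0, 1, 0]] the sampled output reproduces the input:
python
import torch
import torch.nn.functional as F

img = torch.randn(1, 3, 32, 32)
# identity affine parameters: no rotation, scaling or translation
theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])         # shape [N, 2, 3]
grid = F.affine_grid(theta, img.shape, align_corners=False)  # grid generator
out = F.grid_sample(img, grid, align_corners=False)          # sampler
print(torch.allclose(out, img, atol=1e-5))                   # True, up to interpolation error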
python
import torch
from torch import nn
from torch.nn import functional as F

class STN(nn.Module):
    def __init__(self, c,h,w,mode='stn'):
        assert mode in ['stn', 'cnn']

        super(STN, self).__init__()
        self.mode = mode
        self.local_net = LocalNetwork(c,h,w)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        self.fc = nn.Sequential(
            nn.Linear(in_features=16*8*8, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=10)
        )

    def forward(self, img):
        '''
        :param img: (b, c, h, w)
        :return: (b, c, h, w), (b,)
        '''
        batch_size,c,h,w = img.shape
        img = self.local_net(img)

        conv_output = self.conv(img).view(batch_size, -1)
        predict = self.fc(conv_output)
        return img, predict


class LocalNetwork(nn.Module):
    def __init__(self,c,h,w):
        super(LocalNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features=c*h*w,
                      out_features=20),
            nn.Tanh(),
            nn.Linear(in_features=20, out_features=6),
            nn.Tanh(),
        )

    def forward(self, img):
        '''
        :param img: (b, c, h, w)
        :return: (b, c, h, w)
        '''
        batch_size, c, h, w = img.shape

        theta = self.fc(img.view(batch_size, -1)).view(batch_size, 2, 3)

        ## affine grid generation and sampling
        grid = F.affine_grid(theta, torch.Size((batch_size, c, h, w)))
        img_transform = F.grid_sample(img, grid)

        return img_transform


if __name__ == '__main__':
    net = STN(3, 32, 32)
    x = torch.randn(1, 3, 32, 32)

    feature,predict = net(x)

    print(feature.shape)

Transformer

The Transformer is a deep learning model built on the self-attention mechanism, first proposed by Vaswani et al. in the 2017 paper "Attention Is All You Need". It has been remarkably successful in natural language processing (NLP), especially in machine translation, text generation and question answering. Its key features and implementation details are as follows:

Key Features of the Transformer

  • Self-attention: lets the model attend to different positions of a sequence while processing it, capturing long-range dependencies.
  • Feed-forward network: the features at every position are processed by the same position-wise feed-forward network.
  • Positional encoding: the Transformer itself carries no notion of order, so positional encodings are added to preserve the order of the sequence.
  • Multi-head attention: several attention heads capture different kinds of dependencies, increasing the model's expressive power.

Basic Structure of the Transformer

The Transformer consists of an encoder and a decoder, each built from a stack of identical layers. (The attention formula shared by all of the attention sub-layers is given after the list below.)

  1. Encoder
    Multi-head self-attention: self-attention over the input sequence.
    Feed-forward network: a non-linear transformation applied to the features at every position.
    Residual connections: added around each sub-layer to mitigate vanishing gradients.
    Layer normalization: applied with each sub-layer to stabilize training.
  2. Decoder
    Multi-head self-attention: (masked) self-attention over the target sequence.
    Multi-head encoder-decoder attention: cross-attention between the encoder output and the target sequence.
    Feed-forward network: a non-linear transformation applied to the features at every position.
    Residual connections: added around each sub-layer.
    Layer normalization: applied with each sub-layer.
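The core computation of every attention sub-layer is scaled dot-product attention, which the attention function in the listing below implements directly (masked positions are set to -1e9 before the softmax):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

Multi-head attention applies this formula h times in parallel to linearly projected copies of Q, K and V and concatenates the results, as done in the MultiHeadedAttention class.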
python
# coding:utf8
# Step 1: import the required libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy

import matplotlib.pyplot as plt

# Step 2: define the Transformer class, a standard Encoder-Decoder architecture
class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(Transformer, self).__init__()
        # encoder and decoder are passed in at construction time, which keeps things flexible
        self.encoder = encoder
        self.decoder = decoder
        # input and output embeddings
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        # the final Linear + softmax after the Decoder
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        # take the masked source sequence and the target sequence:
        # first encode the input with encode, then decode with decode
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        # arguments: the src embedding and src_mask
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        # arguments: the target embedding, the Encoder output (memory), and the two masks
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

# Step 3: the Generator class, the final output layer (Linear + softmax) that predicts a token from the Decoder hidden state
class Generator(nn.Module):
    """d_model是Decoder输出的大小,vocab是词典大小"""

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    # a linear projection followed by log-softmax
    def forward(self, x):
        return F.log_softmax(self.proj(x), dim=-1)

# Step 4: the LayerNorm, SublayerConnection and FeedForward classes
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # layer weights in PyTorch must be nn.Parameter rather than plain Tensors, so wrap the initialized tensors accordingly
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

# Both the self-attention and the feed-forward sub-layers are wrapped the same way: LayerNorm, then the sub-layer (self-attention/dense), then dropout, then a residual connection. This is packaged here as SublayerConnection.
class SublayerConnection(nn.Module):
    """
    LayerNorm + sublayer (self-attention/dense) + dropout + residual connection
    For simplicity the LayerNorm is applied first (pre-norm); this differs slightly from the original paper, where LayerNorm comes last
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # apply the residual connection to any sub-layer of the same size
        return x + self.dropout(sublayer(self.norm(x)))

# Position-wise feed-forward layer
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

# Step 5: multi-head attention, built on scaled dot-product attention
def attention(query, key, value, mask=None, dropout=None):
    d_k = query.size(-1) # feature dimension
    # a (-1, -1, n, d_k) matrix times a (-1, -1, d_k, n) matrix gives a (-1, -1, n, n) score matrix, where n is the sequence length
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None: # positions where the mask is 0 are disallowed
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn # a (-1, -1, n, n) matrix times a (-1, -1, n, d_k) matrix gives a (-1, -1, n, d_k) output

# MultiHeadedAttention takes the number of heads and the model dimension (all heads concatenated)
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # assume d_v = d_k here
        self.d_k = d_model // h  # per-head dimension, e.g. 512 / 8 = 64
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)  ## linear projections for wq, wk, wv and the output
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # the same mask is applied to every head
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) apply the linear projections and split d_model into h heads of size d_k = d_model / h, giving shape (nbatches, self.h, n, self.d_k)
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) compute scaled dot-product attention with the attention function:
        # Q = (nbatches, self.h, n, self.d_k) times K^T = (nbatches, self.h, self.d_k, n) gives
        # A = (nbatches, self.h, n, n); A times V = (nbatches, self.h, n, self.d_k) gives B = (nbatches, self.h, n, self.d_k)
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) multi-head attention: use view to concatenate the h vectors of size d_k (e.g. 8 x 64) back into a single d_model vector (e.g. 512),
        # then apply one more linear projection (512, 512); the shape is unchanged.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

# Step 6: the Encoder is a stack of N EncoderLayers
def clones(module, N):
    "克隆N个完全相同的SubLayer,使用了copy.deepcopy"
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        # clone N layers
        self.layers = clones(layer, N)
        # plus one final LayerNorm
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)  # a final LayerNorm is applied after the N EncoderLayers

# EncoderLayer: self-attention followed by feed-forward
class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # the first SublayerConnection wraps the multi-head attention module, the second wraps the feed-forward module
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

# Step 7: the Decoder, a stack of N identical Decoder layers
class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

# DecoderLayer: self-attention, source attention and feed-forward
class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

# build the upper-triangular mask matrix
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

#---------- plot the mask matrix ----------#
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(6)[0])
plt.savefig('mask.png')

# Step 8: input data
# word embedding
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model) # vocab is the vocabulary size
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

# positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # computed in log space for numerical stability
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].clone().detach()
        return self.dropout(x)

#---------- plot the positional encoding ----------#
## sequence length 100, assuming d_model = 20
plt.figure(figsize=(15, 5))
pe = PositionalEncoding(20, 0)
y = pe.forward(torch.zeros(1, 100, 20))  

plt.plot(np.arange(100), y[0, :, :].data.numpy())
plt.legend(["dim %d"%p for p in [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19]])

plt.savefig('positioncode.png')

# Step 9: assemble the full network
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = Transformer(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn),
                             c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab))
    return model

# Test a small model: source and target sentence length 10, with 2 Encoder and 2 Decoder layers.
if __name__ == '__main__':
    model = make_model(11, 11, N=2)
    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) # essentially a vector of token indices
    src_mask = torch.ones(1, 1, 10)
    torch.onnx.export(model, (src, src, src_mask, src_mask), 'transformer.onnx')
    print(model)
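For completeness, here is a hedged sketch of greedy decoding with the model built by make_model; the function name greedy_decode and the arguments max_len and start_symbol are illustrative, not part of the listing above. It encodes the source once, then repeatedly feeds the tokens generated so far back into decode and picks the most probable next token:
python
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # encode the source once, then decode token by token
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src)
    for _ in range(max_len - 1):
        tgt_mask = subsequent_mask(ys.size(1)).type_as(src)
        out = model.decode(memory, src_mask, ys, tgt_mask)
        prob = model.generator(out[:, -1])  # log-probabilities of the next token
        next_word = prob.argmax(dim=1).item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src).fill_(next_word)], dim=1)
    return ys

# example (the model above is untrained, so the generated tokens are meaningless):
# print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))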

MobileViT

MobileViT is a lightweight vision Transformer designed to run efficiently on mobile devices. It combines the strengths of convolutional neural networks (CNNs) and Transformers through a new building block, the MobileViT Block, achieving strong performance at low computational cost. MobileViT performs well on image classification, object detection and semantic segmentation, and is particularly suitable for resource-constrained devices.

Key Features of MobileViT

  • MobileViT Block: combines convolution and Transformer layers, fusing local and global information to increase the model's expressive power.
  • Lightweight design: efficient convolutions and a compact Transformer keep computational cost low while maintaining accuracy.
  • Multi-scale feature extraction: features are extracted at multiple scales, improving the model's ability to capture structures of different sizes.

Basic Structure of MobileViT

MobileViT consists of the following parts:

  • Stem: initial convolution layers that extract low-level features.
  • MobileViT Blocks: the core modules that combine convolution and Transformer layers.
  • Global pooling: a global pooling layer that reduces the feature map to a vector.
  • Classification head: the final classifier that outputs the class scores.

Detailed Structure of the MobileViT Block

The MobileViT Block is the core module of MobileViT and has the following structure:

  • Local representation: local features extracted by convolutions.
  • Global representation: global features extracted by a Transformer.
  • Fusion: the local and global features are fused.

The concrete steps are as follows (a small shape-check sketch is given after the list):

Local feature extraction:

  • Extract local features with convolution layers.

Global feature extraction:

  • Flatten the local features and reshape them into a sequence of patches.
  • Run a Transformer over the sequence to extract global features.

Feature fusion:

  • Reshape the global features back into a feature map.
  • Fuse them with the local features (by concatenation followed by a convolution in the code below).
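The unfold and fold between the local and global branches are plain einops rearranges; the minimal shape check below (the sizes are illustrative, chosen to match the XS configuration used at the end of this section) shows that the fold exactly inverts the unfold:
python
import torch
from einops import rearrange

x = torch.randn(1, 96, 32, 32)   # feature map [B, d, H, W], patch size 2x2
ph, pw = 2, 2
seq = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=ph, pw=pw)
print(seq.shape)                 # torch.Size([1, 4, 256, 96]): 4 pixels per patch, 256 patches, dimension 96
back = rearrange(seq, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=32 // ph, w=32 // pw, ph=ph, pw=pw)
print(torch.equal(back, x))      # True: the fold exactly inverts the unfold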
python
import torch
import torch.nn as nn

from einops import rearrange
from einops.layers.torch import Reduce

# helpers
def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(
            t, 'b p n (h d) -> b p h n d', h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b p h n d -> b p n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

# MobileNetV2 Block
class MV2Block(nn.Module):
    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expansion == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride,
                          1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        out = self.conv(x)
        if self.use_res_connect:
            out = out + x
        return out

class MobileViTBlock(nn.Module):
    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        # Global representations
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d',
                      ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                      h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)

        # Fusion
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        x = self.conv4(x)
        return x

# MobileViT model definition
class MobileViT(nn.Module):
    def __init__(
        self,
        image_size,
        dims,
        channels,
        num_classes,
        expansion=4,
        kernel_size=3,
        patch_size=(2, 2),
        depths=(2, 4, 3)
    ):
        super().__init__()
        assert len(dims) == 3, 'dims must be a tuple of 3'
        assert len(depths) == 3, 'depths must be a tuple of 3'

        ih, iw = image_size
        ph, pw = patch_size
        assert ih % ph == 0 and iw % pw == 0

        init_dim, *_, last_dim = channels

        self.conv1 = conv_nxn_bn(3, init_dim, stride=2)

        self.stem = nn.ModuleList([])
        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))

        self.trunk = nn.ModuleList([])
        self.trunk.append(nn.ModuleList([
            MV2Block(channels[3], channels[4], 2, expansion),
            MobileViTBlock(dims[0], depths[0], channels[5],
                           kernel_size, patch_size, int(dims[0] * 2))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[5], channels[6], 2, expansion),
            MobileViTBlock(dims[1], depths[1], channels[7],
                           kernel_size, patch_size, int(dims[1] * 4))
        ]))

        self.trunk.append(nn.ModuleList([
            MV2Block(channels[7], channels[8], 2, expansion),
            MobileViTBlock(dims[2], depths[2], channels[9],
                           kernel_size, patch_size, int(dims[2] * 4))
        ]))

        self.to_logits = nn.Sequential(
            conv_1x1_bn(channels[-2], last_dim),
            Reduce('b c h w -> b c', 'mean'),
            nn.Linear(channels[-1], num_classes, bias=False)
        )

    def forward(self, x):
        x = self.conv1(x)

        for conv in self.stem:
            x = conv(x)

        for conv, attn in self.trunk:
            x = conv(x)
            x = attn(x)

        return self.to_logits(x)

if __name__ == '__main__':
    mbvit_xs = MobileViT(
        image_size=(256, 256),
        dims=[96, 120, 144],
        channels=[16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
        num_classes=1000
    )

    img = torch.randn(1, 3, 256, 256)
    pred = mbvit_xs(img)
    print(pred.shape)
    torch.onnx.export(mbvit_xs, img, 'mbvit_xs.onnx')

SimpleViT

SimpleViT is a simplified Vision Transformer (ViT) that lowers the complexity and computational cost of the original ViT while retaining much of its performance. The core idea of ViT is to treat an image as a sequence of patches and process that sequence with a Transformer. SimpleViT typically trims the number of layers, the hidden dimensions or other components to fit smaller compute budgets or faster inference.

Key Features of SimpleViT

  • Simplified architecture: reduces the depth and hidden dimensions of the original ViT, lowering model complexity.
  • Efficiency: suitable for resource-constrained environments such as mobile or embedded devices.
  • Easy to implement: the code is relatively simple and easy to understand and modify.

Basic Structure of SimpleViT

SimpleViT consists of the following parts:

  • Patch Embedding: splits the image into a sequence of patches and maps each patch to an embedding vector.
  • Positional Encoding: adds position information to each patch so the model knows where each patch came from.
  • Transformer Encoder: a stack of Transformer encoder layers that processes the patch embeddings.
  • Classification Head: the final classifier that outputs the class scores.

Patch Embedding

Patch Embedding splits the image into fixed-size patches and converts each patch into an embedding vector (a quick check of the patch arithmetic is given after the list):

  • Split the image: cut the image into a grid of fixed-size patches.
  • Linear projection: flatten each patch and map it to an embedding vector with a linear layer.
  • No class token: unlike the original ViT, SimpleViT does not prepend a cls token; classification instead uses the mean over all patch tokens (see the Classification Head below).
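As a quick sanity check of the patch arithmetic, using the 256x256 image and 32x32 patch size from the example at the end of this section:
python
image_size, patch_size, channels = 256, 32, 3
num_patches = (image_size // patch_size) ** 2   # 8 * 8 = 64 patches
patch_dim = channels * patch_size * patch_size  # 3 * 32 * 32 = 3072, projected to dim = 1024 by nn.Linear
print(num_patches, patch_dim)                   # 64 3072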

Positional Encoding

The Transformer itself carries no notion of order, so a positional encoding is added to preserve the position of each patch. The encoding is generated with sine and cosine functions.
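For reference, the standard 1D sinusoidal encoding from the Transformer paper is

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$

The posemb_sincos_2d function in the listing below applies the same idea separately to each patch's row and column coordinates and concatenates the four resulting sin/cos groups, which is why the embedding dimension must be a multiple of 4.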

Transformer Encoder

The Transformer Encoder is a stack of identical encoder layers, each containing multi-head self-attention and a feed-forward network:

  • Multi-head self-attention: self-attention over the patch embeddings.
  • Feed-forward network: a non-linear transformation of the features at every position.
  • Residual connections: added around each sub-layer to mitigate vanishing gradients.
  • Layer normalization: applied inside each sub-layer (pre-norm in the code below) to stabilize training.

Classification Head

The classification head maps the encoder output to the final class scores:

  • Pool the tokens: average the Transformer Encoder output over all patch tokens (SimpleViT has no cls token).
  • Linear projection: apply LayerNorm and a linear layer to turn the pooled vector into class scores.
python
import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

## sin/cos positional embedding
def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype
    # e.g. h=8, w=8, dim=1024
    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1) # shape=1024/4=256
    omega = 1. / (temperature ** omega)

    print('y.shape',y.shape) # [8,8]
    print('x.shape',x.shape) # [8,8]
    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :] 
    print('omega.shape',omega.shape) # [256]
    print('y.shape',y.shape) # [64, 256]
    print('x.shape',x.shape) # [64, 256]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    return pe.type(dtype)

## feed-forward module: two linear layers
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

## attention module
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

## Transformer (encoder) definition
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim)
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

## SimpleViT model definition
class SimpleViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b h w (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.to_latent = nn.Identity()
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        *_, h, w, dtype = *img.shape, img.dtype

        x = self.to_patch_embedding(img)
        pe = posemb_sincos_2d(x)

        print('x shape',x.shape) # [1,8,8,1024]
        print('pe shape',pe.shape) # [64,1024]

        x = rearrange(x, 'b ... d -> b (...) d') + pe

        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

model = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024, ## token dimension
    depth = 6, ## number of Transformer blocks
    heads = 16, ## number of attention heads
    mlp_dim = 2048 ## hidden dimension of the MLP
)

if __name__ == '__main__':
    img = torch.randn(1, 3, 256, 256)
    preds = model(img)
    print(preds.shape)

    torch.onnx.export(model, img, 'Simple_ViT.onnx')

ViT

The full ViT below differs from SimpleViT in a few points that are all visible in the code: it prepends a learnable classification (cls) token to the patch sequence, uses a learned positional embedding instead of the fixed sin/cos encoding, applies dropout to the embeddings and attention weights, and can classify either from the cls token or from the mean of all tokens.

python
#coding:utf8
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

## pre-normalization wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

## feed-forward module: two linear layers
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

## attention module
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads ## 8*64=512
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5 ## 1/sqrt(64)=1/8

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) ## with the default dim=1024, inner_dim * 3 = 512*3

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1) ## generate q, k, v from the input x, each of dimension inner_dim = dim_head * heads = 512
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale ## q*k/sqrt(d)
        attn = self.attend(dots) ## softmax(q*k/sqrt(d))
        attn = self.dropout(attn)

        out = torch.matmul(attn, v) ## softmax(q*k/sqrt(d))*v
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

## Transformer (encoder) definition
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x ## self-attention block
            x = ff(x) + x ## feed-forward block
        return x

## ViT model definition
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size) ## image size
        patch_height, patch_width = pair(patch_size) ## patch size

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width) ## number of patches
        patch_dim = channels * patch_height * patch_width ## dimension of a flattened patch
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        ## Split the image into patches and flatten them into a sequence. For a 256x256 input (b, 3, 256, 256) split into 64 patches of 32x32 pixels, the rearrange first views it as (b, 3, 8x32, 8x32) and then produces (b, 8x8, 32x32x3), i.e. (b, 64, 3072).
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim), ## project the flattened patch to dimension dim, e.g. 32*32*3 -> 1024
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) ## num_patches=64, dim=1024; +1 for the extra cls token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) ## extra classification token
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        ## a fully connected classification head on top of the encoder
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        print('x shape', x.shape)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1) ## concatenate along the token dimension; the extra cls token gives shape (b, 65, 1024)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] ## take the features of token 0 (the cls token) or the mean over all tokens

        x = self.to_latent(x)
        return self.mlp_head(x)


model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024, ## token dimension
    depth = 6, ## number of Transformer blocks
    heads = 16, ## number of attention heads
    mlp_dim = 2048 ## hidden dimension of the MLP
)

if __name__ == '__main__':
    img = torch.randn(1, 3, 256, 256)
    preds = model(img)
    print(preds.shape)

    torch.onnx.export(model, img, 'ViT.onnx')