详解Visual Transformer (ViT)网络模型

1. 简介

ViT是2020年Google团队提出的将Transformer应用在图像分类的模型。

ViT原论文中最核心的结论是,当拥有足够多的数据进行预训练的时候,ViT的表现就会超过CNN,突破transformer缺少归纳偏置的限制,可以在下游任务中获得较好的迁移效果。

但是当训练数据集不够大的时候,ViT的表现通常比同等大小的ResNets要差一些,因为Transformer和CNN相比缺少归纳偏置(inductive bias),即一种先验知识,提前做好的假设。

CNN具有两种归纳偏置,一种是局部性 (locality/two-dimensional neighborhood structure),即图片上相邻的区域具有相似的特征;一种是平移不变形 (translation equivariance), <math xmlns="http://www.w3.org/1998/Math/MathML"> f ( g ( x ) ) = g ( f ( x ) ) f(g(x))=g(f(x)) </math>f(g(x))=g(f(x)),其中g代表卷积操作,f代表平移操作。

当CNN具有以上两种归纳偏置,就有了很多先验信息,需要相对少的数据就可以学习一个比较好的模型。

2. ViT模型架构

ViT的工作流程,如下:

  • 将一张图片分成patches
  • 将patches铺平
  • 将铺平后的patches的线性映射到更低维的空间
  • 添加位置embedding编码信息
  • 将图像序列数据送入标准Transformer encoder中去
  • 在较大的数据集上预训练
  • 在下游数据集上微调用于图像分类

模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder
  • MLP Head(最终用于分类的层结构)

Embedding层

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵 [num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。

对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。如下图所示,首先将一张图片按给定大小分成一堆Patches。以ViT-B /16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[16, 16, 3] -> [768]

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

Transformer Encoder

Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention 、Dropout和MLP Block几部分组成。

MLP Head

上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。

3. ViT工作原理

4. 模型搭建参数

论文的Table1中有给出三个模型(Base/ Large/ Huge)的参数

  • Layers就是Transformer Encoder中重复堆叠Encoder Block的次数 L。
  • Hidden Size就是对应通过Embedding层(Patch Embedding + Class Embedding + Position Embedding)后每个token的dim(序列向量的长度)
  • MLP Size是Transformer Encoder中MLP Block第一个全连接的节点个数(是token长度的4倍)
  • Heads代表Transformer中Multi-Head Attention的heads数。

5. ViT进行迁移学习

一、下载源码和预训练模型

  1. 官方源码
  1. 预训练模型下载
python 复制代码
# 通过timm下载(最简单)
import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# 通过HuggingFace下载
from transformers import ViTModel
model = ViTModel.from_pretrained('google/vit-base-patch16-224')

二、训练前的修改步骤

1. 修改分类头

python 复制代码
import torch.nn as nn
import timm

# 加载预训练模型
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# 获取原始分类头特征维度
num_features = model.head.in_features

# 替换为自己的分类头(假设你的数据集有10类)
model.head = nn.Linear(num_features, 10)

# 或者更复杂的分类头
model.head = nn.Sequential(
    nn.Linear(num_features, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, 10)
)

2. 调整输入尺寸

python 复制代码
# 如果图像尺寸不是224x224,可以选择:
# 方案1:resize图像到224x224
from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 方案2:使用支持其他尺寸的ViT变体
model = timm.create_model('vit_base_patch16_384', pretrained=True)  # 384x384

3. 配置训练策略

python 复制代码
# 冻结部分层(可选)
# 冻结所有层
for param in model.parameters():
    param.requires_grad = False
    
# 只解冻分类头
for param in model.head.parameters():
    param.requires_grad = True

# 或者解冻最后几层
for name, param in model.named_parameters():
    if 'blocks.11' in name or 'head' in name:  # 解冻最后一个block和分类头
        param.requires_grad = True

三、完整的迁移学习示例

python 复制代码
import torch
import torch.nn as nn
import timm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 1. 数据准备
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载你的数据集
train_dataset = datasets.ImageFolder('path/to/train', transform=transform_train)
test_dataset = datasets.ImageFolder('path/to/test', transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 2. 加载预训练模型并修改
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# 获取类别数
num_classes = len(train_dataset.classes)

# 替换分类头
model.head = nn.Linear(model.head.in_features, num_classes)

# 3. 训练配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 不同层使用不同学习率
optimizer = torch.optim.Adam([
    {'params': model.patch_embed.parameters(), 'lr': 1e-5},
    {'params': model.blocks.parameters(), 'lr': 1e-5},
    {'params': model.head.parameters(), 'lr': 1e-4}
], weight_decay=1e-4)

criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

# 4. 训练循环
for epoch in range(50):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    scheduler.step()
    
    # 验证
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f'Epoch {epoch}: Accuracy: {100 * correct / total:.2f}%')
  1. 学习率:通常比从头训练小10-100倍(1e-4到1e-5)
  2. Batch Size:根据GPU内存调整,ViT通常需要较小的batch size(16-64)
  3. Epochs:迁移学习通常20-50个epochs就够了
  4. 数据增强:对特定场景很重要,可以提高泛化能力

详解VIT(Vision Transformer)模型原理, 代码级讲解

ViT(Visual Transformer)最通俗易懂的讲解(有代码)

【Transformer系列】深入浅出理解ViT(Vision Transformer)模型

相关推荐
CoovallyAIHub2 天前
仿生学突破:SILD模型如何让无人机在电力线迷宫中发现“隐形威胁”
深度学习·算法·计算机视觉
CoovallyAIHub2 天前
从春晚机器人到零样本革命:YOLO26-Pose姿态估计实战指南
深度学习·算法·计算机视觉
CoovallyAIHub2 天前
Le-DETR:省80%预训练数据,这个实时检测Transformer刷新SOTA|Georgia Tech & 北交大
深度学习·算法·计算机视觉
CoovallyAIHub2 天前
强化学习凭什么比监督学习更聪明?RL的“聪明”并非来自算法,而是因为它学会了“挑食”
深度学习·算法·计算机视觉
CoovallyAIHub2 天前
YOLO-IOD深度解析:打破实时增量目标检测的三重知识冲突
深度学习·算法·计算机视觉
用户1474853079743 天前
AI-动手深度学习环境搭建-d2l
深度学习
OpenBayes贝式计算3 天前
解决视频模型痛点,TurboDiffusion 高效视频扩散生成系统;Google Streetview 涵盖多个国家的街景图像数据集
人工智能·深度学习·机器学习
OpenBayes贝式计算3 天前
OCR教程汇总丨DeepSeek/百度飞桨/华中科大等开源创新技术,实现OCR高精度、本地化部署
人工智能·深度学习·机器学习
在人间耕耘4 天前
HarmonyOS Vision Kit 视觉AI实战:把官方 Demo 改造成一套能长期复用的组件库
人工智能·深度学习·harmonyos