Vlm-vit模型

Vision Transformer (ViT) 解析

一、核心思想:图像即序列

ViT 的本质是将图像视为一系列 "图像单词"(image tokens),完全沿用 NLP 中 Transformer 的架构来处理视觉数据:

  1. 将输入图像分割为固定大小的非重叠图像块 (patches)
  2. 将每个图像块线性投影为向量 (embedding)
  3. 添加位置编码以保留空间信息
  4. 输入 Transformer 编码器进行全局特征学习
  5. 用特殊的 [CLS] token 进行分类预测

这与 CNN 的局部卷积 + 池化的分层特征提取范式形成鲜明对比。

二、ViT 完整架构详解

1. 输入处理模块

图像分块 (Patch Embedding)

  • 输入:RGB 图像,尺寸为 H×W×3(高度 × 宽度 × 通道数)
  • 分块:使用大小为 P×P 的非重叠窗口分割图像,得到 N = (H/P) × (W/P) 个图像块
    • 标准设置:P=16,对于 224×224 图像,N=14×14=196 个块
  • 展平:每个图像块展平为长度为 P²×3 的向量(例如 16×16×3=768)
  • 线性投影:通过全连接层将展平向量映射到 D 维嵌入空间(通常 D=768)

数学表达

python 复制代码
x_patch = flatten(patch)  # 形状: (P²×3)
z_0^p = Linear(x_patch)   # 形状: (D), p=1..N

特殊 Token 与位置编码

  1. [CLS] Token :在图像块序列前添加一个可学习的嵌入向量 z_0^0,用于最终分类,它会聚合整个图像的全局信息
  2. 位置编码 (Positional Embedding) :由于 Transformer 是顺序无关的,必须添加位置信息以保留空间结构
    • 可学习位置编码:与图像块嵌入维度相同 (D),通过训练学习
    • 固定位置编码:如正弦余弦编码(原 Transformer 使用)
    • 最终输入序列:z_0 = [z_0^0; z_0^1; z_0^2; ...; z_0^N] + E_pos,形状为 (N+1)×D
2. Transformer 编码器模块

ViT 使用标准 Transformer 编码器,由L 个相同层堆叠而成,每一层包含两个子层:

|--------------|-------------------|---------------------------|
| 子层 | 功能 | 结构 |
| 多头自注意力 (MSA) | 捕捉所有图像块间的全局依赖关系 | 多个并行的自注意力头,输出拼接后线性变换 |
| 多层感知机 (MLP) | 对每个位置的特征独立进行非线性变换 | 两层全连接 + GELU 激活 + Dropout |

残差连接与层归一化

  • 每个子层都有残差连接:LayerNorm(x + Sublayer(x))
  • 这是 Transformer 训练稳定性的关键,允许模型深度堆叠

多头自注意力计算

python 复制代码
# 线性投影得到Q, K, V
Q = Linear(z_l-1)  # (N+1)×D
K = Linear(z_l-1)  # (N+1)×D
V = Linear(z_l-1)  # (N+1)×D

# 分割为M个头
Q = split_heads(Q, M)  # M×(N+1)×(D/M)
K = split_heads(K, M)  # M×(N+1)×(D/M)
V = split_heads(V, M)  # M×(N+1)×(D/M)

# 自注意力计算
attn_scores = softmax((Q @ K^T) / sqrt(D/M))  # M×(N+1)×(N+1)
attn_output = attn_scores @ V                 # M×(N+1)×(D/M)

# 拼接并线性变换
attn_output = concat_heads(attn_output)       # (N+1)×D
z_l^1 = Linear(attn_output)                   # (N+1)×D

MLP 计算

python 复制代码
z_l^2 = MLP(LayerNorm(z_l^1 + z_l-1))  # 残差连接+层归一化
z_l = LayerNorm(z_l^2 + z_l^1)         # 第二个残差连接+层归一化
3. 分类头模块
  • 取最后一层编码器输出中 **[CLS] token** 的表示:z_L^0
  • 通过 MLP Head(通常是两层全连接)映射到类别空间:
python 复制代码
y = MLP_Head(z_L^0)  # 形状: (num_classes)

三、ViT 变体与尺度

Google 提供了多种尺度的 ViT 模型,参数如下:

|-----------|--------|-----------|-----------|-----------|------|
| 模型 | 层数 (L) | 隐藏层维度 (D) | 注意力头数 (M) | 图像块大小 (P) | 参数数量 |
| ViT-Base | 12 | 768 | 12 | 16 | 86M |
| ViT-Large | 24 | 1024 | 16 | 16 | 307M |
| ViT-Huge | 32 | 1280 | 16 | 16 | 632M |

训练策略

  1. 大规模预训练:在 JFT-300M 等超大数据集上预训练,再微调至下游任务
  2. 混合精度训练:使用 FP16 加速训练并减少内存占用
  3. 数据增强:大量使用 RandAugment 等增强技术提升泛化能力
相关推荐
沛沛老爹2 小时前
Web开发者转型AI:Agent Skills版本控制与管理实战——从Git到AI技能仓库
java·前端·人工智能·git·架构·rag
李莫若2 小时前
2026权威评测AI学术写作工具全面对比:AIPaperGPT以一站式服务与强保障体系成为全能冠军
人工智能
weixin_462446232 小时前
使用 Chainlit +langchain+ LangGraph + MCP + Ollama 构建可视化 AI 工具 Agent(完整实战)
人工智能·langchain·agent·ai聊天·mcp server
小郭团队2 小时前
1_5_五段式SVPWM (传统算法反正切+DPWM1)算法理论与 MATLAB 实现详解
人工智能·嵌入式硬件·算法·dsp开发
有Li2 小时前
DACG:用于放射学报告生成的双重注意力和上下文引导模型/文献速递-基于人工智能的医学影像技术
论文阅读·人工智能·文献·医学生
时间会给答案scidag2 小时前
Spring AI Alibaba 学习day01
人工智能·学习·spring
ghie90902 小时前
基于粒子滤波的多目标检测前跟踪(TBD)MATLAB实现
人工智能·目标检测·matlab
分布式存储与RustFS2 小时前
RustFS在AI场景下的实测:从GPU到存储的完整加速方案
开发语言·人工智能·rust·对象存储·企业存储·rustfs·minio国产化替代
Deepoch2 小时前
Deepoc具身模型开发板:半导体制造智能化的技术引擎
人工智能·开发板·半导体·具身模型·deepoc