ViT模型技术学习

前言

最近多模态模型特别火,模型也越来越小,MiniCPM-2.6只有8B,里面采用的图片编码器是SigLipViT模型,一起从头学习ViT和Transformer!本文记录一下学习过程,所以是自上而下的写,从ViT拆到Transformer。

用Transformer来做图像分类?!

  1. Vision Transformer (ViT)出自ICLR 2021的论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》,使用之前做文本任务的Transformer来做图片分类任务
  2. ViT模型的构成主要包含图像切片、图像映射、Transformer模块和分类头

ViT整体工作流程

假设输入图片尺寸image_size是 224 × 224 224 \times 224 224×224,子图大小(patch_size)为16,图片编码维度(hidden_dim)为768,

当1张224*224的图输入ViT后(批大小batch_size=1)会经历:

  1. 图片切片 -- 图片首先被分割为 16 × 16 16 \times 16 16×16大小的子图,总共 ( 224 / / 16 ) × ( 224 / / 16 ) = 14 × 14 = 196 (224//16) \times (224//16)=14 \times 14=196 (224//16)×(224//16)=14×14=196个
  2. 图片映射 -- 子图被分别送到Linear Projection这个模块进行映射,得到大小为[1,768,196]的向量
  3. 变换一下维度便于输入Transformer,所有子图拼成的图片隐向量维度为[1,196,768];
  4. 分类token -- 在输入Transformer前,为了与bert架构统一,也使用一个类似[CLS]的标记,在图片隐向量前面插入一个class_token,最终输入Transformer的向量大小为[1,197,768]
  5. 位置编码 -- 随机初始化的pos_embedding大小也是[1,197,768],加到图片向量上
  6. 输入Transformer,编码器输入输出维度一致,输出的维度是[1,197,768]
  7. 输出分类结果 -- 取class_token对应的输出向量输入分类头

*需要注意的是:分类任务不一定要取class_token对应的向量,也可在最后一个Transformer块的输出接一个global average pooling层再接MLP分类层,特定学习率参数情况下效果类似;ViT是为了和bert架构统一所以加入了class_token

ViT源代码拆解

1. VisionTransformer类的forward()

在torchvision代码中可以找到ViT的torch官方实现

python 复制代码
def forward(self, x: torch.Tensor):
    # 图片切片、图片编码并把图片向量调整为transformer能接受的维度
    x = self._process_input(x)
    n = x.shape[0] # n是batch size
    # 给这个batch的n个图片向量最前面,都加入一个class_token,类似[CLS]
    batch_class_token = self.class_token.expand(n, -1, -1)
    x = torch.cat([batch_class_token, x], dim=1)
                   
    # 图片向量用Transformer的block进行处理
    x = self.encoder(x)
    # 取class_token对应的向量,x是[1,197,768],x[:,0]表示x[:,0,:]
    x = x[:, 0]
    # 输入分类头进行分类任务
    x = self.heads(x)
    return x

2.图片切片与编码------VisionTransformer类的_process_input()

  • ViT框架图里面的Linear Projection模块实际上是用一个nn.Con2d隐式实现的
  • nn.Con2d起到的作用和单独把一个个子图放到Linear层编码是一样的
  • 所以实际上图片编码后的维度为[1,768,14,14]--> [1,196,768]
python 复制代码
    ........
    self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)
def _process_input(self, x: torch.Tensor) -> torch.Tensor:    n, c, h, w = x.shape 
    # 图片维度为(n, c, h, w),n是batchsize,c是图像通道数一般为3,h/w是图像高宽
    p = self.patch_size  # 图片切片大小,例如为16,子图大小为patch_size*patch_size 
    n_h = h // p         # 图片切片,高度维度切的片数
    n_w = w // p         # 图片切片,高度维度切的片数
    # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
    x = self.conv_proj(x)
    # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)),进行展平操作
    x = x.reshape(n, self.hidden_dim, n_h * n_w)    
    # Transformer期望的输入维度是(N,S,E),N是batchsize,S是序列长度,E是文本编码隐向量维度
    # 所以把维度变换一下,permute(0,2,1)表示把第0维放最前面,第2维放中间,第1维放后面
    x = x.permute(0, 2, 1) # 得到(n, (n_h * n_w), hidden_dim), n_h * n_w是子图数,类似文本序列长度
    return x

其中,对于卷积操作而言

python 复制代码
self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)
  • 默认hidden_dim=768,patch_size=16,卷积核个数也就是输出的特征图通道数为768
  • 卷积核大小为16,步长也是16,可以保证卷积扫描的时候每次正好对一个子图做运算,子图互相之间不重叠,一个卷积核卷积运算的次数为(224//16)* (224//16)正好是14*14,每个运算值对应一个子图
  • 有768个卷积核,所以输出的大小为(n,768,14,14),对RGB图像而言卷积核也是个[3,16,16]的矩阵
  • RGB图像的卷积如下,RGB分别计算后相加,这只是1/768个卷积核的计算结果,所有结果拼接为矩阵

3. Transformer的Encoder

3.1 Encoder的forward()

ViT使用的是Encoder for sequence to sequence translation

python 复制代码
    ......
    super().__init__()
    # Note that batch_size is on the first dim because
    # we have batch_first=True in nn.MultiAttention() by default
    self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
    self.dropout = nn.Dropout(dropout)
    layers: OrderedDict[str, nn.Module] = OrderedDict()
    for i in range(num_layers):
        layers[f"encoder_layer_{i}"] = EncoderBlock(
        	num_heads, 
        	hidden_dim, 
        	mlp_dim, 
        	dropout,
        	attention_dropout,  
        	norm_layer,)
    self.layers = nn.Sequential(layers)
    self.ln = norm_layer(hidden_dim)
def forward(self, input: torch.Tensor):
    torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
    input = input + self.pos_embedding
    return self.ln(self.layers(self.dropout(input)))

3.2 Transformer的Encoder Block

主要包含self_attention结构,在self-attention中每个patch和patch之间计算相似度,学习patch间的关系

python 复制代码
        ......
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)  # 层归一化,是对单个样本在其特征维度(最后一个维度)上进行的归一化
        self.self_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=attention_dropout, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(
            input.dim() == 3,
            f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}"
        )
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input  # 残差连接

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

3.3 Transformer的Encoder Block的MultiheadAttention

关于代码:

  • 多头注意力模块nn.MultiHeadAttention,forward方法在torch.nn.functional中,在这里之前ViT代码中已经统一把向量变换为(L,N,E)的形状

  • q(L,N,E) k(S,N,E) v(S,N,E) output(L,N,E)

  • L is the target

    length, S is the sequence length, H is the number of attention heads,

    N is the batch size, and E is the embedding dimension

  • nn.MultiHeadAttention的attention有一版注释的代码也在源文件中,搜索" multihead attention"往下翻

  • *torchtext.nn.modules.multiheadattention的多头注意力模块代码更简洁一些

下面是torchtext.nn.modules.multiheadattention的多头注意力模块代码:

python 复制代码
# 假设这是在一个类的方法中定义的
    ......
    if self.batch_first:  # 如果是batch_first的先从(N, L, E)变为(L, N, E)形式
        query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)

    # 获取维度信息
    tgt_len, src_len, bsz, embed_dim = (
        query.size(-3),
        key.size(-3),
        query.size(-2),
        query.size(-1)
    )

    # 分别乘qkv矩阵得到qkv
    q, k, v = self.in_proj_container(query, key, value)

    # 确保query的embed_dim可以被head数整除
    assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads"
    head_dim = q.size(-1) // self.nhead
    q = q.reshape(tgt_len, bsz * self.nhead, head_dim)

    # 确保key的embed_dim可以被head数整除
    assert k.size(-1) % self.nhead == 0, "key's embed_dim must be divisible by the number of heads"
    head_dim = k.size(-1) // self.nhead
    k = k.reshape(src_len, bsz * self.nhead, head_dim)

    # 确保value的embed_dim可以被head数整除
    assert v.size(-1) % self.nhead == 0, "value's embed_dim must be divisible by the number of heads"
    head_dim = v.size(-1) // self.nhead
    v = v.reshape(src_len, bsz * self.nhead, head_dim)

    # 计算注意力输出和权重
    attn_output, attn_output_weights = self.attention_layer(
        q, k, v,
        attn_mask=attn_mask,
        bias_k=bias_k,
        bias_v=bias_v
    )

    # 将输出重新调整为原始形状
    attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
    attn_output = self.out_proj(attn_output)

    # 如果是batch_first从(L, N, E)变回去(N, L, E),编码器输入输出形状保持一致
    if self.batch_first:
        attn_output = attn_output.transpose(-3, -2)

    return attn_output, attn_output_weights

3.4 Transformer的Encoder Block的ScaledDotProduct

  • torchtext.nn.modules.multiheadattention的self.attention_layer是ScaledDotProduct
  • query: (L, N * H, E / H) , key: (S, N * H, E / H),self-attantion中L=E
  • 计算注意力权重 : matmul(query,key)
  • 权重归一化:对 attn_output_weights 进行 softmax 归一化时,希望确保每个查询位置(L)对所有键位置(S)的注意力权重之和为 1。因此,我们需要沿着最后一个维度 S 进行 softmax 归一化,即 dim=-1
  • 加权求和:matmul(att_output_weights, value)
python 复制代码
# Scale query
# 变成(N*H,L,E/H)
query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
query = query * (float(head_dim) ** -0.5)
# Dot product of q, k
#(N*H,L,E/H) ×  (N*H, E/H, S),matmul计算最后2维,也就是[N*H,:,:]×[N*H,:,:],得到[N*H,L,S]
attn_output_weights = torch.matmul(query, key.transpose(-2, -1))
attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) # (N*H, L, S)
attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_output_weights, value) # (N*H, L, E/H)

self-attention的直观解释-b站视频

Attention的解释有一个b站上搬运的视频非常直观,attention可以关注到全局上信息的关联,卷积只能关注到局部的信息

  • 假设图片有4个像素,RGB三个通道,表示起来x是[4,3]的矩阵

  • 如果隐空间维度hidden_dim=2,输入x乘以[3,2]的Wq/Wk/Wv矩阵可以得到[4,2]的Q/K/V向量

  • 计算相似性度量 𝑄 ∙ 𝐾 𝑇 𝑄∙𝐾^𝑇 Q∙KT

  • 注意到每次是向量的点积运算,例如Q的第4行表示q4,K的第4列表示k4,计算的实际上是向量相似度,得到的 𝑄 ∙ 𝐾 𝑇 𝑄∙𝐾^𝑇 Q∙KT是每个像素间的相似度矩阵

  • 在self-attention的计算中涉及到除以√𝑑放缩,否则维度越大雅可比矩阵接近零矩阵梯度消失,详细原理可以在文末知乎专栏中找到

  • 计算softmax进行归一化,因为每一行是一个像素和其它像素的相似度,所以预期是每行概率值相加为1,对列做softmax:

  • 最后乘以V矩阵完成注意力计算:

  • 左边的相似度矩阵可以理解为权重,乘以V矩阵类似加权平均

  • 例如0.23表示第1个像素关注第1个像素的程度,0.33表示第1个像素关注第2个像素的程度

参考链接

  1. b站attention视频讲解:https://www.bilibili.com/video/BV1Ke411X7t7
  2. 知乎解释为什么attention需要除以√𝑑放缩:https://zhuanlan.zhihu.com/p/503321685
相关推荐
AngelPP3 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年4 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼4 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS4 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
warm3snow4 小时前
Claude Code 黑客马拉松:5 个获奖项目,没有一个是"纯码农"做的
ai·大模型·llm·agent·skill·mcp
天翼云开发者社区5 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈5 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang6 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx
shengjk17 小时前
NanoClaw 深度剖析:一个"AI 原生"架构的个人助手是如何运转的?
人工智能