前言
最近多模态模型特别火,模型也越来越小,MiniCPM-2.6只有8B,里面采用的图片编码器是SigLipViT模型,一起从头学习ViT和Transformer!本文记录一下学习过程,所以是自上而下的写,从ViT拆到Transformer。
用Transformer来做图像分类?!
- Vision Transformer (ViT)出自ICLR 2021的论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》,使用之前做文本任务的Transformer来做图片分类任务
- ViT模型的构成主要包含图像切片、图像映射、Transformer模块和分类头
ViT整体工作流程
假设输入图片尺寸image_size是 224 × 224 224 \times 224 224×224,子图大小(patch_size)为16,图片编码维度(hidden_dim)为768,
当1张224*224的图输入ViT后(批大小batch_size=1)会经历:
- 图片切片 -- 图片首先被分割为 16 × 16 16 \times 16 16×16大小的子图,总共 ( 224 / / 16 ) × ( 224 / / 16 ) = 14 × 14 = 196 (224//16) \times (224//16)=14 \times 14=196 (224//16)×(224//16)=14×14=196个
- 图片映射 -- 子图被分别送到Linear Projection这个模块进行映射,得到大小为[1,768,196]的向量
- 变换一下维度便于输入Transformer,所有子图拼成的图片隐向量维度为[1,196,768];
- 分类token -- 在输入Transformer前,为了与bert架构统一,也使用一个类似[CLS]的标记,在图片隐向量前面插入一个class_token,最终输入Transformer的向量大小为[1,197,768]
- 位置编码 -- 随机初始化的pos_embedding大小也是[1,197,768],加到图片向量上
- 输入Transformer,编码器输入输出维度一致,输出的维度是[1,197,768]
- 输出分类结果 -- 取class_token对应的输出向量输入分类头
*需要注意的是:分类任务不一定要取class_token对应的向量,也可在最后一个Transformer块的输出接一个global average pooling层再接MLP分类层,特定学习率参数情况下效果类似;ViT是为了和bert架构统一所以加入了class_token
ViT源代码拆解
1. VisionTransformer类的forward()
在torchvision代码中可以找到ViT的torch官方实现
python
def forward(self, x: torch.Tensor):
# 图片切片、图片编码并把图片向量调整为transformer能接受的维度
x = self._process_input(x)
n = x.shape[0] # n是batch size
# 给这个batch的n个图片向量最前面,都加入一个class_token,类似[CLS]
batch_class_token = self.class_token.expand(n, -1, -1)
x = torch.cat([batch_class_token, x], dim=1)
# 图片向量用Transformer的block进行处理
x = self.encoder(x)
# 取class_token对应的向量,x是[1,197,768],x[:,0]表示x[:,0,:]
x = x[:, 0]
# 输入分类头进行分类任务
x = self.heads(x)
return x
2.图片切片与编码------VisionTransformer类的_process_input()
- ViT框架图里面的Linear Projection模块实际上是用一个nn.Con2d隐式实现的
- nn.Con2d起到的作用和单独把一个个子图放到Linear层编码是一样的
- 所以实际上图片编码后的维度为[1,768,14,14]--> [1,196,768]
python
........
self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)
def _process_input(self, x: torch.Tensor) -> torch.Tensor: n, c, h, w = x.shape
# 图片维度为(n, c, h, w),n是batchsize,c是图像通道数一般为3,h/w是图像高宽
p = self.patch_size # 图片切片大小,例如为16,子图大小为patch_size*patch_size
n_h = h // p # 图片切片,高度维度切的片数
n_w = w // p # 图片切片,高度维度切的片数
# (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
x = self.conv_proj(x)
# (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)),进行展平操作
x = x.reshape(n, self.hidden_dim, n_h * n_w)
# Transformer期望的输入维度是(N,S,E),N是batchsize,S是序列长度,E是文本编码隐向量维度
# 所以把维度变换一下,permute(0,2,1)表示把第0维放最前面,第2维放中间,第1维放后面
x = x.permute(0, 2, 1) # 得到(n, (n_h * n_w), hidden_dim), n_h * n_w是子图数,类似文本序列长度
return x
其中,对于卷积操作而言
python
self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)
- 默认hidden_dim=768,patch_size=16,卷积核个数也就是输出的特征图通道数为768
- 卷积核大小为16,步长也是16,可以保证卷积扫描的时候每次正好对一个子图做运算,子图互相之间不重叠,一个卷积核卷积运算的次数为(224//16)* (224//16)正好是14*14,每个运算值对应一个子图
- 有768个卷积核,所以输出的大小为(n,768,14,14),对RGB图像而言卷积核也是个[3,16,16]的矩阵
- RGB图像的卷积如下,RGB分别计算后相加,这只是1/768个卷积核的计算结果,所有结果拼接为矩阵
3. Transformer的Encoder
3.1 Encoder的forward()
ViT使用的是Encoder for sequence to sequence translation
python
......
super().__init__()
# Note that batch_size is on the first dim because
# we have batch_first=True in nn.MultiAttention() by default
self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
input = input + self.pos_embedding
return self.ln(self.layers(self.dropout(input)))
3.2 Transformer的Encoder Block
主要包含self_attention结构,在self-attention中每个patch和patch之间计算相似度,学习patch间的关系
python
......
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim) # 层归一化,是对单个样本在其特征维度(最后一个维度)上进行的归一化
self.self_attention = nn.MultiheadAttention(
hidden_dim, num_heads, dropout=attention_dropout, batch_first=True
)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: torch.Tensor):
torch._assert(
input.dim() == 3,
f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}"
)
x = self.ln_1(input)
x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.dropout(x)
x = x + input # 残差连接
y = self.ln_2(x)
y = self.mlp(y)
return x + y
3.3 Transformer的Encoder Block的MultiheadAttention
关于代码:
-
多头注意力模块nn.MultiHeadAttention,forward方法在torch.nn.functional中,在这里之前ViT代码中已经统一把向量变换为(L,N,E)的形状
-
q(L,N,E) k(S,N,E) v(S,N,E) output(L,N,E)
-
L is the target
length, S is the sequence length, H is the number of attention heads,
N is the batch size, and E is the embedding dimension
-
nn.MultiHeadAttention的attention有一版注释的代码也在源文件中,搜索" multihead attention"往下翻
-
*torchtext.nn.modules.multiheadattention的多头注意力模块代码更简洁一些
下面是torchtext.nn.modules.multiheadattention的多头注意力模块代码:
python
# 假设这是在一个类的方法中定义的
......
if self.batch_first: # 如果是batch_first的先从(N, L, E)变为(L, N, E)形式
query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
# 获取维度信息
tgt_len, src_len, bsz, embed_dim = (
query.size(-3),
key.size(-3),
query.size(-2),
query.size(-1)
)
# 分别乘qkv矩阵得到qkv
q, k, v = self.in_proj_container(query, key, value)
# 确保query的embed_dim可以被head数整除
assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads"
head_dim = q.size(-1) // self.nhead
q = q.reshape(tgt_len, bsz * self.nhead, head_dim)
# 确保key的embed_dim可以被head数整除
assert k.size(-1) % self.nhead == 0, "key's embed_dim must be divisible by the number of heads"
head_dim = k.size(-1) // self.nhead
k = k.reshape(src_len, bsz * self.nhead, head_dim)
# 确保value的embed_dim可以被head数整除
assert v.size(-1) % self.nhead == 0, "value's embed_dim must be divisible by the number of heads"
head_dim = v.size(-1) // self.nhead
v = v.reshape(src_len, bsz * self.nhead, head_dim)
# 计算注意力输出和权重
attn_output, attn_output_weights = self.attention_layer(
q, k, v,
attn_mask=attn_mask,
bias_k=bias_k,
bias_v=bias_v
)
# 将输出重新调整为原始形状
attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)
# 如果是batch_first从(L, N, E)变回去(N, L, E),编码器输入输出形状保持一致
if self.batch_first:
attn_output = attn_output.transpose(-3, -2)
return attn_output, attn_output_weights
3.4 Transformer的Encoder Block的ScaledDotProduct
- torchtext.nn.modules.multiheadattention的self.attention_layer是ScaledDotProduct
- query: (L, N * H, E / H) , key: (S, N * H, E / H),self-attantion中L=E
- 计算注意力权重 : matmul(query,key)
- 权重归一化:对 attn_output_weights 进行 softmax 归一化时,希望确保每个查询位置(L)对所有键位置(S)的注意力权重之和为 1。因此,我们需要沿着最后一个维度 S 进行 softmax 归一化,即 dim=-1
- 加权求和:matmul(att_output_weights, value)
python
# Scale query
# 变成(N*H,L,E/H)
query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
query = query * (float(head_dim) ** -0.5)
# Dot product of q, k
#(N*H,L,E/H) × (N*H, E/H, S),matmul计算最后2维,也就是[N*H,:,:]×[N*H,:,:],得到[N*H,L,S]
attn_output_weights = torch.matmul(query, key.transpose(-2, -1))
attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) # (N*H, L, S)
attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_output_weights, value) # (N*H, L, E/H)
self-attention的直观解释-b站视频
Attention的解释有一个b站上搬运的视频非常直观,attention可以关注到全局上信息的关联,卷积只能关注到局部的信息
-
假设图片有4个像素,RGB三个通道,表示起来x是[4,3]的矩阵
-
如果隐空间维度hidden_dim=2,输入x乘以[3,2]的Wq/Wk/Wv矩阵可以得到[4,2]的Q/K/V向量
-
计算相似性度量 𝑄 ∙ 𝐾 𝑇 𝑄∙𝐾^𝑇 Q∙KT
-
注意到每次是向量的点积运算,例如Q的第4行表示q4,K的第4列表示k4,计算的实际上是向量相似度,得到的 𝑄 ∙ 𝐾 𝑇 𝑄∙𝐾^𝑇 Q∙KT是每个像素间的相似度矩阵
-
在self-attention的计算中涉及到除以√𝑑放缩,否则维度越大雅可比矩阵接近零矩阵梯度消失,详细原理可以在文末知乎专栏中找到
-
计算softmax进行归一化,因为每一行是一个像素和其它像素的相似度,所以预期是每行概率值相加为1,对列做softmax:
-
最后乘以V矩阵完成注意力计算:
-
左边的相似度矩阵可以理解为权重,乘以V矩阵类似加权平均
-
例如0.23表示第1个像素关注第1个像素的程度,0.33表示第1个像素关注第2个像素的程度
参考链接
- b站attention视频讲解:https://www.bilibili.com/video/BV1Ke411X7t7
- 知乎解释为什么attention需要除以√𝑑放缩:https://zhuanlan.zhihu.com/p/503321685