一、基础概念
1)nn.Module 是什么
在 PyTorch 里,所有神经网络模块几乎都继承自 nn.Module。
比如:
- 一个卷积层
nn.Conv2d(...) - 一个全连接层
nn.Linear(...) - 一个激活函数模块
- 甚至整个大模型
本质上都可以看成"一个模块"。
你这里写的是:
python
class RGBViTBranch(nn.Module):
意思就是:
我现在要自己定义一个神经网络模块,这个模块叫
RGBViTBranch。
它以后就能像普通层一样使用,例如:
python
branch = RGBViTBranch()
cls_token, patch_tokens = branch(x)
2)__init__ 是什么
python
def __init__(self, ...):
这是 Python 类的初始化函数。
你创建对象时:
python
branch = RGBViTBranch()
就会自动执行 __init__ 里的内容。
它主要负责:
- 定义这个模块里有哪些层
- 加载预训练模型
- 设置参数
- 做初始化工作
3)super().__init__() 是什么
这句很重要。因为你继承的是 nn.Module,所以你必须先调用父类 nn.Module 的初始化逻辑。
它会帮你建立很多内部机制,比如:
- 记录有哪些参数
- 支持
.parameters() - 支持
.to(device) - 支持
.train()/.eval() - 支持反向传播
你可以简单理解为:
"先把 PyTorch 模块该有的底层功能装好,再开始定义我自己的东西。"
4)forward() 是什么
在 PyTorch 里,forward() 就是:
这个模块真正怎么计算
比如:
python
def forward(self, x):
...
return output
当你写:
python
y = model(x)
PyTorch 实际上会自动去调用:
python
model.forward(x)
所以 forward 就是"前向传播"的定义。
5)nn.Identity() 是什么
python
vit.heads = nn.Identity()
nn.Identity() 是 PyTorch 自带的一个模块,作用非常简单:
输入什么,就原样输出什么,啥也不做。
等价于:
python
class Identity(nn.Module):
def forward(self, x):
return x
所以这里的意思是:原来 vit.heads 是分类头,负责把特征变成类别分数;
现在我们不想让它分类,只想拿特征,所以把它换成"空操作"。
二、ViT
1)CNN 是怎么看图像的
先对比一下 CNN。
CNN 会把图像当成一个二维网格:
- 高 H
- 宽 W
- 通道 C
比如:[B, 3, 224, 224]
表示:
B:batch size,多少张图一起进来3:RGB 三个通道224,224:图像大小
CNN 会用卷积核在图像上滑动,逐步提特征。
2)ViT 的核心思想
ViT 不像 CNN 那样直接做卷积堆叠,而是先把图像切成小块,再把这些小块当成"token"送进 Transformer。
这和 NLP 很像:
- NLP:一句话切成一个个单词 token
- ViT:一张图切成一个个 patch token
3)224×224 的图像怎么切 patch
ViT-B/16 里的 16 表示 patch size = 16。
所以一张 224×224 的图,会被切成:
- 横向:224 / 16 = 14
- 纵向:224 / 16 = 14
总 patch 数:14 × 14 = 196
所以图像会变成 196 个 patch。
4)每个 patch 怎么变成向量
每个 patch 原本大小是: 16 × 16 × 3
这是一小块图像数据。
ViT 会把每个 patch 映射成一个固定维度的向量,比如 768 维。
于是整张图就变成[B, 196, 768]
含义是:
- B 张图
- 每张图 196 个 patch
- 每个 patch 是 768 维特征
5)为什么会多一个 cls token
一张 224×224 的图,被切成 16×16 的小块后,会得到 14×14=196 个 patch。
你可以把这 196 个 patch 想成:
- patch1:看左上角那一小块
- patch2:看旁边那一小块
- ...
- patch196:看右下角那一小块
每个 patch 都只知道自己那一块局部信息。那我怎么得出"这整棵树是什么树种"这个总体结论?这时候就需要一个专门负责"汇总大家意见"的角色。这个角色就是:
cls token
它一开始其实什么都不知道 ,只是一个可学习的向量。
但经过每一层 self-attention,它都会去"听"别的 patch 在说什么,把信息一点点收集过来。
最后它就变成了:
"我综合了整张图所有局部信息后,形成的整图总结。"
所以它不是凭空来的,也不是人为写死的,而是在 Transformer 一层一层的信息交互中,被训练成一个'全局摘要'。
它到底"看"到了什么
在 self-attention 里,每个 token 都可以和别的 token 交互。
所以 cls token 会去看:
- patch1
- patch2
- patch3
- ...
- patch196
它会给不同 patch 不同权重。比如:
- 叶子纹理相关 patch 权重大
- 背景块权重小
- 阴影块权重适中
- 无信息块几乎忽略
这样经过很多层后,cls token 就逐渐从"空白总结员"变成"很会抓重点的全局特征"。
所以 ViT 会在 196 个 patch token 前面,额外拼一个特殊 token:cls token
这样 token 总数就变成:197 = 1 + 196
所以输入 Transformer 的形状变成:[B, 197, 768]
其中:
- 第 0 个 token:
cls token - 后面 196 个 token:patch token
最后输出时:
cls token常用来表示整张图的全局特征patch tokens表示各局部 patch 的特征
6)位置编码 pos_embedding 是什么
Transformer 本身不懂顺序。
在 NLP 中,它不知道"我"在前还是"你"在前;
在图像里,它也不知道某个 patch 是左上角还是右下角。
所以要加一个 位置编码(position embedding),告诉模型:
这个 patch 在图上的哪个位置。
ViT 会给每个 token 加上位置向量。
7)Transformer Encoder 做了什么
ViT 中的 encoder 本质上是很多层 Transformer block 叠起来。
每层 block 主要做两件事:
- Self-Attention
- 每个 token 都可以看别的 token
- 建立全局关系
- MLP
- 对每个 token 的特征做进一步变换
所以经过很多层后:
cls token会综合所有 patch 的信息,变成全局图像表示patch tokens会变成更高级的局部语义特征
四、初始化代码详解
1)类定义
python
class RGBViTBranch(nn.Module):
定义一个 PyTorch 模块,名字叫 RGBViTBranch。
2)初始化函数
python
def __init__(
self,
fine_tune_last_n_blocks=4,
freeze_patch_embed=True,
freeze_pos_embed=True
):
这里定义了 3 个可调参数。
参数1:fine_tune_last_n_blocks=4
意思是:
ViT 有很多 Transformer block,我只放开最后 4 个 block 训练,前面的都冻结。
为什么这样做?
因为你通常是用预训练 ViT:
- 前面的层学到的是比较通用的特征
- 后面的层更适合根据你的任务微调
这样做有几个好处:
- 防止小数据集过拟合
- 训练更稳定
- 显存和计算压力更小
- 不容易把预训练知识毁掉
参数2:freeze_patch_embed=True
意思是:
patch embedding 层默认冻结,不训练。
参数3:freeze_pos_embed=True
意思是:
class token 和位置编码默认冻结,不训练。
3)调用父类初始化
python
super().__init__()
这句前面讲了,是为了让这个类真正成为一个 PyTorch 模块。
4)加载预训练 ViT
python
vit = load_vit_b16_backbone_safely(
local_weight_path=LOCAL_VIT_WEIGHT_PATH,
allow_auto_download=ALLOW_AUTO_DOWNLOAD
)
这不是 PyTorch 自带函数,自己写的函数。它的作用是:
- 创建一个
vit_b_16模型 - 尝试从本地加载权重
- 本地没有再自动下载
- 最后返回这个 ViT 模型
5)去掉原分类头
python
vit.heads = nn.Identity()
原来的 ViT 末尾有分类头,比如把 768 维特征变成 1000 类 logits。
但你这里不是拿它直接分类,所以把分类头删除,换成"原样输出"。
所以现在这个 ViT 相当于只保留"骨干特征提取部分"。
6)保存为成员变量
python
self.backbone = vit
把这个 ViT 保存到当前模块里,后面 forward() 和其他函数里都能用。
7)冻结部分层
python
self.freeze_layers(
fine_tune_last_n_blocks=fine_tune_last_n_blocks,
freeze_patch_embed=freeze_patch_embed,
freeze_pos_embed=freeze_pos_embed
)
这一步就是:
按你的策略,决定哪些层参与训练,哪些层不参与训练。
五、freeze_layers() 详细拆解
冻结到底是什么含义
冻结不是"这一层不参与前向传播",而是:
前向照常算,反向时不更新参数。
也就是说:
- 冻结层仍然参与特征提取
- 只是参数不变
1)先全部冻结
python
for p in self.backbone.parameters():
p.requires_grad = False
self.backbone.parameters() 会拿到这个 ViT 模型的所有参数。
p.requires_grad = False 的意思是:
这个参数不参与梯度计算,不更新。
也就是冻结。
所以这一步相当于:
先一刀切,全部锁死。
2)放开 encoder 最后的 layer norm
python
for p in self.backbone.encoder.ln.parameters():
p.requires_grad = True
这里把 encoder 最后的 layer norm 放开。为什么?
因为最后的归一化层离输出最近,适当让它可训练,有时会帮助特征适配你的任务。
你可以先理解为:
给输出端留一点可调空间。
3)放开最后 n 个 Transformer block
python
total_blocks = len(self.backbone.encoder.layers)
start_idx = max(0, total_blocks - fine_tune_last_n_blocks)
for i in range(start_idx, total_blocks):
for p in self.backbone.encoder.layers[i].parameters():
p.requires_grad = True
这里是关键。
total_blocks
ViT encoder 里有多层 Transformer block。
比如 ViT-B/16 通常有 12 层 block。
所以:total_blocks = len(self.backbone.encoder.layers)
就是得到 block 总数,比如 12。
start_idx
如果你要放开最后 4 层,那么起始位置就是:12 - 4 = 8
所以开放:第 8 层、第 9 层、第 10 层、第 11 层
这一段整体意思
只训练 ViT 最后几层,前面大部分层保持预训练状态不动。
这叫 partial fine-tuning(部分微调)。
4)决定是否放开 patch embedding
python
if not freeze_patch_embed:
for p in self.backbone.conv_proj.parameters():
p.requires_grad = True
这里的 conv_proj 就是 patch embedding 层。
在 torchvision 的 ViT 里,它通常是一个卷积层,用来做 patchify + 投影。
为什么卷积能做 patchify?
比如 kernel size = 16,stride = 16:
- 每次正好取一个 16×16 patch
- 不重叠地扫完整张图
- 每个 patch 被投影成 embedding
所以它本质上就在做:
"把图像切块,并把每个块变成 token 向量"
这段的意思:
如果 freeze_patch_embed=False,就允许 patch embedding 层训练;否则保持冻结。
5)决定是否冻结 class token 和位置编码
python
if not freeze_pos_embed:
self.backbone.class_token.requires_grad = True
self.backbone.encoder.pos_embedding.requires_grad = True
else:
self.backbone.class_token.requires_grad = False
self.backbone.encoder.pos_embedding.requires_grad = False
class_token
这是那个额外加在最前面的全局 token。它本身也是一个可学习参数。
pos_embedding
这是位置编码,也是可学习参数。
为什么有时冻结,有时不冻结?
因为这两个参数和模型原始输入分布绑定较强。
如果你的输入尺寸和预训练时相同(224×224),而且数据量又不大,冻结往往更稳。
如果你的任务和原任务差异很大,或者你想让模型更充分适配,也可以放开。
六、forward() 详解
输入说明
输入 x 是一批 RGB 图像:
- B:batch size
- 3:RGB 通道
- 224×224:图像尺寸
1)取 batch size
python
n = x.shape[0]
x.shape[0] 就是 batch size,也就是 B。
2)图像转成 patch tokens
python
x = self.backbone._process_input(x) # [B, 196, 768]
这是 torchvision ViT 里的内部函数。_process_input(x) = "把图像变成 patch token 序列"
它做的事大致包括:
- 把图像切成 patch
- 做 patch embedding
- 调整维度格式
结果输出:[B, 196, 768]
也就是:
- 每张图被切成 196 个 patch
- 每个 patch 变成 768 维特征
3)扩展 cls_token
python
cls_token = self.backbone.class_token.expand(n, -1, -1) # [B,1,768]
原始的 class_token 通常是一个可学习参数,形状大概是:[1, 1, 768]
它表示"一个模板 cls token"。
但现在一个 batch 有 n 张图,所以要给每张图都配一个 cls token。
于是用 expand 扩成:[B, 1, 768]
expand(n, -1, -1) 怎么理解
- 第一个维度变成 n
- 后面两个维度保持不变
如果原来是 [1,1,768],扩展后就是 [B,1,768]
4)把 cls token 拼到 patch tokens 前面
python
x = torch.cat([cls_token, x], dim=1) # [B,197,768]
torch.cat 是 PyTorch 里的拼接函数。
这里在第 1 维拼接,也就是 token 维。
原来:
cls_token:[B,1,768]x:[B,196,768]
拼接后:[B,197,768]
含义:
- 第 0 个 token 是
cls token - 后面 196 个是 patch token
5)送入 Transformer Encoder
python
x = self.backbone.encoder(x)
这里的 encoder 内部已经包括了:
- 加位置编码
pos_embedding - dropout
- 多层 Transformer block
- 每层 block 里做 self-attention
- 再做 MLP
- 最后的 layer norm
所以经过这一步后:x.shape = [B,197,768]
形状没变,但每个 token 都变成了"融合上下文后的高级特征"。
self attention具体解析如下连接:
6)把输出拆成 cls 和 patch
python
cls_token = x[:, :1, :] # [B,1,768]
patch_tokens = x[:, 1:, :] # [B,196,768]
这里是张量切片。
x[:, :1, :]
意思是:
- 第 0 维
::所有 batch - 第 1 维
:1:取第一个 token - 第 2 维
::取全部特征维度
所以得到:[B,1,768]。这就是最终的 cls token。
x[:, 1:, :]
意思是:
- 所有 batch
- 从第 1 个 token 到最后
- 所有特征维度
所以得到:[B,196,768]。这就是所有 patch token。
7)返回结果
python
return cls_token, patch_tokens
所以这个分支最后输出的是两种特征:
cls_token:全局图像表示patch_tokens:局部 patch 表示
七、极简版结构图
输入图像 [B,3,224,224]
↓
切成 196 个 patch
↓
每个 patch → 768维
↓
得到 [B,196,768]
↓
前面拼一个 cls token
↓
得到 [B,197,768]
↓
加位置编码,进 Transformer Encoder
↓
输出 [B,197,768]
↓
拆成:
cls_token = [B,1,768]
patch_tokens= [B,196,768]
代码
python
class RGBViTBranch(nn.Module):
def __init__(
self,
fine_tune_last_n_blocks=4, #只放开最后 4 个 block 参与训练
freeze_patch_embed=True, # 是否冻结 patch embedding 层。
freeze_pos_embed=True # 是否冻结class token, pos embedding 层。
):
super().__init__() # 初始化 nn.Module 的内部机制
vit = load_vit_b16_backbone_safely(
local_weight_path=LOCAL_VIT_WEIGHT_PATH, # 优先从本地加载 vit_b_16 权重
allow_auto_download=ALLOW_AUTO_DOWNLOAD # 允许自动下载
)
# 去掉原分类头
vit.heads = nn.Identity() #不是直接拿它分类,而是要提取特征,所以把原分类头替换成:
self.backbone = vit
self.freeze_layers(
fine_tune_last_n_blocks=fine_tune_last_n_blocks,
freeze_patch_embed=freeze_patch_embed,
freeze_pos_embed=freeze_pos_embed
)
def freeze_layers(self, fine_tune_last_n_blocks=4, freeze_patch_embed=True, freeze_pos_embed=True):
# 先全部冻结
for p in self.backbone.parameters():
p.requires_grad = False
# 放开 encoder 最后的 ln
for p in self.backbone.encoder.ln.parameters():
p.requires_grad = True
# 放开最后 n 个 Transformer block
total_blocks = len(self.backbone.encoder.layers)
start_idx = max(0, total_blocks - fine_tune_last_n_blocks)
for i in range(start_idx, total_blocks):
for p in self.backbone.encoder.layers[i].parameters():
p.requires_grad = True
# patch embedding
if not freeze_patch_embed:
for p in self.backbone.conv_proj.parameters():
p.requires_grad = True
# pos embedding / class token
if not freeze_pos_embed:
self.backbone.class_token.requires_grad = True
self.backbone.encoder.pos_embedding.requires_grad = True
else:
self.backbone.class_token.requires_grad = False
self.backbone.encoder.pos_embedding.requires_grad = False
def forward(self, x):
"""
x: [B, 3, 224, 224]
返回:
cls_token: [B, 1, 768]
patch_tokens:[B, 196, 768]
"""
n = x.shape[0]
# torchvision vit 内部 patchify
x = self.backbone._process_input(x) # [B, 196, 768]
cls_token = self.backbone.class_token.expand(n, -1, -1) # [B,1,768]
x = torch.cat([cls_token, x], dim=1) # [B,197,768]
x = self.backbone.encoder(x) # encoder 内部已包含 pos_embedding + dropout + blocks + ln
cls_token = x[:, :1, :] # [B,1,768]
patch_tokens = x[:, 1:, :] # [B,196,768]
return cls_token, patch_tokens