pytorch+vit基础结构

一、基础概念

1)nn.Module 是什么

在 PyTorch 里,所有神经网络模块几乎都继承自 nn.Module

比如:

  • 一个卷积层 nn.Conv2d(...)
  • 一个全连接层 nn.Linear(...)
  • 一个激活函数模块
  • 甚至整个大模型

本质上都可以看成"一个模块"。

你这里写的是:

python 复制代码
class RGBViTBranch(nn.Module):

意思就是:

我现在要自己定义一个神经网络模块,这个模块叫 RGBViTBranch

它以后就能像普通层一样使用,例如:

python 复制代码
branch = RGBViTBranch()
cls_token, patch_tokens = branch(x)

2)__init__ 是什么

python 复制代码
def __init__(self, ...):

这是 Python 类的初始化函数。

你创建对象时:

python 复制代码
branch = RGBViTBranch()

就会自动执行 __init__ 里的内容。

它主要负责:

  • 定义这个模块里有哪些层
  • 加载预训练模型
  • 设置参数
  • 做初始化工作

3)super().__init__() 是什么

这句很重要。因为你继承的是 nn.Module,所以你必须先调用父类 nn.Module 的初始化逻辑。

它会帮你建立很多内部机制,比如:

  • 记录有哪些参数
  • 支持 .parameters()
  • 支持 .to(device)
  • 支持 .train() / .eval()
  • 支持反向传播

你可以简单理解为:

"先把 PyTorch 模块该有的底层功能装好,再开始定义我自己的东西。"

4)forward() 是什么

在 PyTorch 里,forward() 就是:

这个模块真正怎么计算

比如:

python 复制代码
def forward(self, x):
    ...
    return output

当你写:

python 复制代码
y = model(x)

PyTorch 实际上会自动去调用:

python 复制代码
model.forward(x)

所以 forward 就是"前向传播"的定义。

5)nn.Identity() 是什么

python 复制代码
vit.heads = nn.Identity()

nn.Identity() 是 PyTorch 自带的一个模块,作用非常简单:

输入什么,就原样输出什么,啥也不做。

等价于:

python 复制代码
class Identity(nn.Module):
    def forward(self, x):
        return x

所以这里的意思是:原来 vit.heads 是分类头,负责把特征变成类别分数;

现在我们不想让它分类,只想拿特征,所以把它换成"空操作"。

二、ViT

1)CNN 是怎么看图像的

先对比一下 CNN。

CNN 会把图像当成一个二维网格:

  • 高 H
  • 宽 W
  • 通道 C

比如:[B, 3, 224, 224]

表示:

  • B:batch size,多少张图一起进来
  • 3:RGB 三个通道
  • 224,224:图像大小

CNN 会用卷积核在图像上滑动,逐步提特征。

2)ViT 的核心思想

ViT 不像 CNN 那样直接做卷积堆叠,而是先把图像切成小块,再把这些小块当成"token"送进 Transformer。

这和 NLP 很像:

  • NLP:一句话切成一个个单词 token
  • ViT:一张图切成一个个 patch token

3)224×224 的图像怎么切 patch

ViT-B/16 里的 16 表示 patch size = 16。

所以一张 224×224 的图,会被切成:

  • 横向:224 / 16 = 14
  • 纵向:224 / 16 = 14

总 patch 数:14 × 14 = 196

所以图像会变成 196 个 patch

4)每个 patch 怎么变成向量

每个 patch 原本大小是: 16 × 16 × 3

这是一小块图像数据。

ViT 会把每个 patch 映射成一个固定维度的向量,比如 768 维。

于是整张图就变成[B, 196, 768]

含义是:

  • B 张图
  • 每张图 196 个 patch
  • 每个 patch 是 768 维特征

5)为什么会多一个 cls token

一张 224×224 的图,被切成 16×16 的小块后,会得到 14×14=196 个 patch。

你可以把这 196 个 patch 想成:

  • patch1:看左上角那一小块
  • patch2:看旁边那一小块
  • ...
  • patch196:看右下角那一小块

每个 patch 都只知道自己那一块局部信息。那我怎么得出"这整棵树是什么树种"这个总体结论?这时候就需要一个专门负责"汇总大家意见"的角色。这个角色就是:

cls token

它一开始其实什么都不知道 ,只是一个可学习的向量。

但经过每一层 self-attention,它都会去"听"别的 patch 在说什么,把信息一点点收集过来。

最后它就变成了:

"我综合了整张图所有局部信息后,形成的整图总结。"

所以它不是凭空来的,也不是人为写死的,而是在 Transformer 一层一层的信息交互中,被训练成一个'全局摘要'

它到底"看"到了什么

在 self-attention 里,每个 token 都可以和别的 token 交互。

所以 cls token 会去看:

  • patch1
  • patch2
  • patch3
  • ...
  • patch196

它会给不同 patch 不同权重。比如:

  • 叶子纹理相关 patch 权重大
  • 背景块权重小
  • 阴影块权重适中
  • 无信息块几乎忽略

这样经过很多层后,cls token 就逐渐从"空白总结员"变成"很会抓重点的全局特征"。


所以 ViT 会在 196 个 patch token 前面,额外拼一个特殊 token:cls token

这样 token 总数就变成:197 = 1 + 196

所以输入 Transformer 的形状变成:[B, 197, 768]

其中:

  • 第 0 个 token:cls token
  • 后面 196 个 token:patch token

最后输出时:

  • cls token 常用来表示整张图的全局特征
  • patch tokens 表示各局部 patch 的特征

6)位置编码 pos_embedding 是什么

Transformer 本身不懂顺序。

在 NLP 中,它不知道"我"在前还是"你"在前;

在图像里,它也不知道某个 patch 是左上角还是右下角。

所以要加一个 位置编码(position embedding),告诉模型:

这个 patch 在图上的哪个位置。

ViT 会给每个 token 加上位置向量。

7)Transformer Encoder 做了什么

ViT 中的 encoder 本质上是很多层 Transformer block 叠起来。

每层 block 主要做两件事:

  1. Self-Attention
    • 每个 token 都可以看别的 token
    • 建立全局关系
  2. MLP
    • 对每个 token 的特征做进一步变换

所以经过很多层后:

  • cls token 会综合所有 patch 的信息,变成全局图像表示
  • patch tokens 会变成更高级的局部语义特征

四、初始化代码详解

1)类定义

python 复制代码
class RGBViTBranch(nn.Module):

定义一个 PyTorch 模块,名字叫 RGBViTBranch。

2)初始化函数

python 复制代码
def __init__(
    self,
    fine_tune_last_n_blocks=4,
    freeze_patch_embed=True,
    freeze_pos_embed=True
):

这里定义了 3 个可调参数。

参数1:fine_tune_last_n_blocks=4

意思是:

ViT 有很多 Transformer block,我只放开最后 4 个 block 训练,前面的都冻结。

为什么这样做?

因为你通常是用预训练 ViT:

  • 前面的层学到的是比较通用的特征
  • 后面的层更适合根据你的任务微调

这样做有几个好处:

  • 防止小数据集过拟合
  • 训练更稳定
  • 显存和计算压力更小
  • 不容易把预训练知识毁掉

参数2:freeze_patch_embed=True

意思是:

patch embedding 层默认冻结,不训练。


参数3:freeze_pos_embed=True

意思是:

class token 和位置编码默认冻结,不训练。


3)调用父类初始化

python 复制代码
super().__init__()

这句前面讲了,是为了让这个类真正成为一个 PyTorch 模块。

4)加载预训练 ViT

python 复制代码
vit = load_vit_b16_backbone_safely(
    local_weight_path=LOCAL_VIT_WEIGHT_PATH,
    allow_auto_download=ALLOW_AUTO_DOWNLOAD
)

这不是 PyTorch 自带函数,自己写的函数。它的作用是:

  • 创建一个 vit_b_16 模型
  • 尝试从本地加载权重
  • 本地没有再自动下载
  • 最后返回这个 ViT 模型

5)去掉原分类头

python 复制代码
vit.heads = nn.Identity()

原来的 ViT 末尾有分类头,比如把 768 维特征变成 1000 类 logits。

但你这里不是拿它直接分类,所以把分类头删除,换成"原样输出"。

所以现在这个 ViT 相当于只保留"骨干特征提取部分"。


6)保存为成员变量

python 复制代码
self.backbone = vit

把这个 ViT 保存到当前模块里,后面 forward() 和其他函数里都能用。

7)冻结部分层

python 复制代码
self.freeze_layers(
    fine_tune_last_n_blocks=fine_tune_last_n_blocks,
    freeze_patch_embed=freeze_patch_embed,
    freeze_pos_embed=freeze_pos_embed
)

这一步就是:

按你的策略,决定哪些层参与训练,哪些层不参与训练。

五、freeze_layers() 详细拆解

冻结到底是什么含义

冻结不是"这一层不参与前向传播",而是:

前向照常算,反向时不更新参数。

也就是说:

  • 冻结层仍然参与特征提取
  • 只是参数不变

1)先全部冻结

python 复制代码
for p in self.backbone.parameters():
    p.requires_grad = False

self.backbone.parameters() 会拿到这个 ViT 模型的所有参数。

p.requires_grad = False 的意思是:

这个参数不参与梯度计算,不更新。

也就是冻结。

所以这一步相当于:

先一刀切,全部锁死。

2)放开 encoder 最后的 layer norm

python 复制代码
for p in self.backbone.encoder.ln.parameters():
    p.requires_grad = True

这里把 encoder 最后的 layer norm 放开。为什么?

因为最后的归一化层离输出最近,适当让它可训练,有时会帮助特征适配你的任务。

你可以先理解为:

给输出端留一点可调空间。

3)放开最后 n 个 Transformer block

python 复制代码
total_blocks = len(self.backbone.encoder.layers)
start_idx = max(0, total_blocks - fine_tune_last_n_blocks)
for i in range(start_idx, total_blocks):
    for p in self.backbone.encoder.layers[i].parameters():
        p.requires_grad = True

这里是关键。


total_blocks

ViT encoder 里有多层 Transformer block。

比如 ViT-B/16 通常有 12 层 block。

所以:total_blocks = len(self.backbone.encoder.layers)

就是得到 block 总数,比如 12。

start_idx

如果你要放开最后 4 层,那么起始位置就是:12 - 4 = 8

所以开放:第 8 层、第 9 层、第 10 层、第 11 层

这一段整体意思

只训练 ViT 最后几层,前面大部分层保持预训练状态不动。

这叫 partial fine-tuning(部分微调)

4)决定是否放开 patch embedding

python 复制代码
if not freeze_patch_embed:
    for p in self.backbone.conv_proj.parameters():
        p.requires_grad = True

这里的 conv_proj 就是 patch embedding 层。

在 torchvision 的 ViT 里,它通常是一个卷积层,用来做 patchify + 投影。


为什么卷积能做 patchify?

比如 kernel size = 16,stride = 16:

  • 每次正好取一个 16×16 patch
  • 不重叠地扫完整张图
  • 每个 patch 被投影成 embedding

所以它本质上就在做:

"把图像切块,并把每个块变成 token 向量"

这段的意思:

如果 freeze_patch_embed=False,就允许 patch embedding 层训练;否则保持冻结。

5)决定是否冻结 class token 和位置编码

python 复制代码
if not freeze_pos_embed:
    self.backbone.class_token.requires_grad = True
    self.backbone.encoder.pos_embedding.requires_grad = True
else:
    self.backbone.class_token.requires_grad = False
    self.backbone.encoder.pos_embedding.requires_grad = False
  • class_token

这是那个额外加在最前面的全局 token。它本身也是一个可学习参数。

  • pos_embedding

这是位置编码,也是可学习参数。

为什么有时冻结,有时不冻结?

因为这两个参数和模型原始输入分布绑定较强。

如果你的输入尺寸和预训练时相同(224×224),而且数据量又不大,冻结往往更稳。

如果你的任务和原任务差异很大,或者你想让模型更充分适配,也可以放开。

六、forward() 详解


输入说明

输入 x 是一批 RGB 图像:

  • B:batch size
  • 3:RGB 通道
  • 224×224:图像尺寸

1)取 batch size

python 复制代码
n = x.shape[0]

x.shape[0] 就是 batch size,也就是 B。

2)图像转成 patch tokens

python 复制代码
x = self.backbone._process_input(x)   # [B, 196, 768]

这是 torchvision ViT 里的内部函数。_process_input(x) = "把图像变成 patch token 序列"

它做的事大致包括:

  1. 把图像切成 patch
  2. 做 patch embedding
  3. 调整维度格式

结果输出:[B, 196, 768]

也就是:

  • 每张图被切成 196 个 patch
  • 每个 patch 变成 768 维特征

3)扩展 cls_token

python 复制代码
cls_token = self.backbone.class_token.expand(n, -1, -1)  # [B,1,768]

原始的 class_token 通常是一个可学习参数,形状大概是:[1, 1, 768]

它表示"一个模板 cls token"。

但现在一个 batch 有 n 张图,所以要给每张图都配一个 cls token。

于是用 expand 扩成:[B, 1, 768]

expand(n, -1, -1) 怎么理解

  • 第一个维度变成 n
  • 后面两个维度保持不变

如果原来是 [1,1,768],扩展后就是 [B,1,768]

4)把 cls token 拼到 patch tokens 前面

python 复制代码
x = torch.cat([cls_token, x], dim=1)  # [B,197,768]

torch.cat 是 PyTorch 里的拼接函数。

这里在第 1 维拼接,也就是 token 维。

原来:

  • cls_token: [B,1,768]
  • x: [B,196,768]

拼接后:[B,197,768]

含义:

  • 第 0 个 token 是 cls token
  • 后面 196 个是 patch token

5)送入 Transformer Encoder

python 复制代码
x = self.backbone.encoder(x)

这里的 encoder 内部已经包括了:

  • 加位置编码 pos_embedding
  • dropout
  • 多层 Transformer block
  • 每层 block 里做 self-attention
  • 再做 MLP
  • 最后的 layer norm

所以经过这一步后:x.shape = [B,197,768]

形状没变,但每个 token 都变成了"融合上下文后的高级特征"。

self attention具体解析如下连接:

VIT 中 Transformer的一层attention-CSDN博客

6)把输出拆成 cls 和 patch

python 复制代码
cls_token = x[:, :1, :]        # [B,1,768]
patch_tokens = x[:, 1:, :]     # [B,196,768]

这里是张量切片。

x[:, :1, :]

意思是:

  • 第 0 维 ::所有 batch
  • 第 1 维 :1:取第一个 token
  • 第 2 维 ::取全部特征维度

所以得到:[B,1,768]。这就是最终的 cls token


x[:, 1:, :]

意思是:

  • 所有 batch
  • 从第 1 个 token 到最后
  • 所有特征维度

所以得到:[B,196,768]。这就是所有 patch token。

7)返回结果

python 复制代码
return cls_token, patch_tokens

所以这个分支最后输出的是两种特征:

  • cls_token:全局图像表示
  • patch_tokens:局部 patch 表示

七、极简版结构图

输入图像 [B,3,224,224]

切成 196 个 patch

每个 patch → 768维

得到 [B,196,768]

前面拼一个 cls token

得到 [B,197,768]

加位置编码,进 Transformer Encoder

输出 [B,197,768]

拆成:

cls_token = [B,1,768]

patch_tokens= [B,196,768]

代码

python 复制代码
class RGBViTBranch(nn.Module):
    def __init__(
        self,
        fine_tune_last_n_blocks=4, #只放开最后 4 个 block 参与训练
        freeze_patch_embed=True, # 是否冻结 patch embedding 层。
        freeze_pos_embed=True # 是否冻结class token, pos embedding 层。
    ):
        super().__init__() # 初始化 nn.Module 的内部机制

        vit = load_vit_b16_backbone_safely( 
            local_weight_path=LOCAL_VIT_WEIGHT_PATH, # 优先从本地加载 vit_b_16 权重
            allow_auto_download=ALLOW_AUTO_DOWNLOAD # 允许自动下载
        )

        # 去掉原分类头
        vit.heads = nn.Identity()  #不是直接拿它分类,而是要提取特征,所以把原分类头替换成:
        self.backbone = vit

        self.freeze_layers(
            fine_tune_last_n_blocks=fine_tune_last_n_blocks,
            freeze_patch_embed=freeze_patch_embed,
            freeze_pos_embed=freeze_pos_embed
        )

    def freeze_layers(self, fine_tune_last_n_blocks=4, freeze_patch_embed=True, freeze_pos_embed=True):
        # 先全部冻结
        for p in self.backbone.parameters():
            p.requires_grad = False

        # 放开 encoder 最后的 ln
        for p in self.backbone.encoder.ln.parameters():
            p.requires_grad = True

        # 放开最后 n 个 Transformer block
        total_blocks = len(self.backbone.encoder.layers)
        start_idx = max(0, total_blocks - fine_tune_last_n_blocks)
        for i in range(start_idx, total_blocks):
            for p in self.backbone.encoder.layers[i].parameters():
                p.requires_grad = True

        # patch embedding
        if not freeze_patch_embed:
            for p in self.backbone.conv_proj.parameters():
                p.requires_grad = True

        # pos embedding / class token
        if not freeze_pos_embed:
            self.backbone.class_token.requires_grad = True
            self.backbone.encoder.pos_embedding.requires_grad = True
        else:
            self.backbone.class_token.requires_grad = False
            self.backbone.encoder.pos_embedding.requires_grad = False

    def forward(self, x):
        """
        x: [B, 3, 224, 224]
        返回:
            cls_token:   [B, 1, 768]
            patch_tokens:[B, 196, 768]
        """
        n = x.shape[0]

        # torchvision vit 内部 patchify
        x = self.backbone._process_input(x)   # [B, 196, 768]

        cls_token = self.backbone.class_token.expand(n, -1, -1)  # [B,1,768]
        x = torch.cat([cls_token, x], dim=1)                     # [B,197,768]

        x = self.backbone.encoder(x)   # encoder 内部已包含 pos_embedding + dropout + blocks + ln

        cls_token = x[:, :1, :]        # [B,1,768]
        patch_tokens = x[:, 1:, :]     # [B,196,768]

        return cls_token, patch_tokens
相关推荐
InfinteJustice2 小时前
CSS如何创建响应式导航栏菜单_结合Flexbox与媒体查询
jvm·数据库·python
nervermore9902 小时前
人工智能学习专栏
人工智能
粉嘟小飞妹儿2 小时前
Python环境PyTorch分布式训练初始化失败_检查MASTER_ADDR与端口
jvm·数据库·python
人工智能AI技术2 小时前
预训练与微调:大模型基础工作模式解析
人工智能
粉嘟小飞妹儿2 小时前
PHP怎么使用Eloquent Attribute Synthesis属性合成_Laravel多源数据融合【指南】
jvm·数据库·python
字节跳动的猫2 小时前
2026 四款 AI:开发场景适配全面解析
前端·人工智能·开源
m0_640309302 小时前
用Symfony构建AI驱动的Web应用实战
jvm·数据库·python
Warren982 小时前
Windows本地部署n8n完整教程(基于Docker,新手友好)
运维·windows·python·测试工具·docker·容器·可用性测试
kisdiem2 小时前
DeepSeek-OCR 2:给人工智能更像人类的眼睛
人工智能·深度学习·计算机视觉