【胶囊网络 - 简明教程】02-1 胶囊网络 - 整体架构设计

【胶囊网络 - 简明教程】 02-1 整体架构设计

目录

  • [1. 架构总览](#1. 架构总览)
  • [2. 编码器设计](#2. 编码器设计)
  • [3. 解码器设计](#3. 解码器设计)
  • [4. 数据流分析](#4. 数据流分析)
  • [5. 设计哲学](#5. 设计哲学)

1. 架构总览

1.1 完整 Capsule Network 结构

基于本项目(Capsule_Network.ipynb)实现的架构,完整的胶囊网络由以下核心组件构成:

架构可视化

图 1:完整的 Capsule Network 架构(来源:原论文)

这张图展示了 CapsNet 的完整数据流:

  • 左侧:编码器部分,从输入图像到 DigitCaps 的层次化特征提取
  • 右侧:解码器部分,从胶囊表示重建原始图像
  • 中间:动态路由机制,连接 PrimaryCaps 和 DigitCaps

解码器 Decoder
编码器 Encoder
输入层
输入图像

28×28×1 灰度
Conv1 卷积层

256@9×9 stride=1

输出:20×20×256
PrimaryCaps

8 个胶囊

输出:1152×8
DigitCaps

10 个数字胶囊

输出:16×10
FC1: 160→512
FC2: 512→1024
FC3: 1024→784

重建图像

1.2 核心组件功能

组件 功能 输入维度 输出维度
Conv1 提取基础特征 28×28×1 20×20×256
PrimaryCaps 形成初级胶囊 20×20×256 1152×8
DigitCaps 高级语义表示 1152×8 16×10
Decoder 重建验证 160 784

2. 编码器设计

2.1 第一层:卷积层(ConvLayer)

设计目标

卷积层负责从原始像素中提取低级特征,如边缘、角点、纹理等。

关键参数
python 复制代码
class ConvLayer(nn.Module):
    def __init__(self, in_channels=1, out_channels=256):
        self.conv = nn.Conv2d(
            in_channels=1,      # 输入:灰度图像
            out_channels=256,   # 输出:256 个特征图
            kernel_size=9,      # 9×9 卷积核
            stride=1,           # 步长为 1
            padding=0           # 无填充
        )
维度计算详解
复制代码
输入尺寸:28×28×1

卷积操作:
- 卷积核大小:9×9
- 步长:1
- 填充:0

输出尺寸计算公式:
output_size = (input_size - kernel_size + 2*padding) / stride + 1
            = (28 - 9 + 0) / 1 + 1
            = 20

输出维度:20×20×256
可视化理解

输出特征图
卷积操作
输入
输入图像

28×28×1
9×9 卷积核

滑动扫描
特征图 1

垂直边缘
特征图 2

水平边缘
特征图 3

45°斜边
...
特征图 256

复杂纹理

激活函数选择
python 复制代码
features = F.relu(self.conv(x))

为什么使用 ReLU?

  1. 非线性:引入非线性,增强表达能力
  2. 稀疏性:负值区域输出为 0,产生稀疏激活
  3. 计算高效:简单的阈值操作
  4. 梯度稳定:正区域梯度恒为 1,缓解梯度消失

2.2 第二层:主胶囊层(PrimaryCaps)

设计目标

PrimaryCaps 层是从标量到向量的转折点,将卷积特征重新组织为胶囊向量。

关键结构
python 复制代码
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32):
        # 创建 8 个并行的卷积层
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=256, out_channels=32,
                      kernel_size=9, stride=2, padding=0)
            for _ in range(num_capsules)])
为什么是 8 个胶囊?
复制代码
设计考量:
✓ 足够多的胶囊可以编码丰富的特征组合
✓ 每个胶囊学习不同的特征子空间
✓ 8 是 2 的幂次,便于 GPU 并行计算
✓ 实验验证:8 个胶囊在 MNIST 上表现良好
维度变换过程

图 2:编码器架构示意图(来源:原论文)

这张图清晰地展示了编码器的三层结构:

  • Conv1:256 个 9×9 卷积核,提取基础特征
  • PrimaryCaps:8 个胶囊,每个输出 1152 维向量
  • DigitCaps:10 个数字胶囊,每个输出 16 维向量

步骤 4 Squash
步骤 3 堆叠
步骤 2 重塑
步骤 1 卷积
32@9×9 stride=2
展平
8 个胶囊
squash 函数
20×20×256
6×6×32
1152×1
1152×8
1152×8

归一化向量

Squash 函数实现
python 复制代码
def squash(tensor, dim=-1):
    # 计算向量模长的平方
    squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
    
    # 计算缩放因子:scale = ||v||² / (1 + ||v||²)
    scale = squared_norm / (1.0 + squared_norm)
    
    # 归一化向量:v_normalized = v / ||v||
    normalized = tensor / torch.sqrt(squared_norm + 1e-9)
    
    # 应用缩放:v_output = scale * v_normalized
    return scale * normalized
Squash 函数的几何意义

很小≈0
很大
输入向量 v
计算模长||v||
输出接近 0

特征不存在
输出接近 1

特征存在
方向保持不变
模长在 0-1 区间


2.3 第三层:数字胶囊层(DigitCaps)

设计目标

DigitCaps 层是编码器的最高层,每个胶囊对应一个数字类别(0-9)。

关键特性
python 复制代码
class DigitCaps(nn.Module):
    def __init__(self, in_caps=8, in_dim=1152, out_caps=10, out_dim=16):
        # 可学习的路由权重矩阵
        self.weight = nn.Parameter(torch.randn(1, out_caps, in_caps, out_dim, in_dim))
核心参数
参数 含义
in_caps 8 输入胶囊数量(PrimaryCaps 的输出)
in_dim 1152 每个输入胶囊的维度
out_caps 10 输出胶囊数量(对应 0-9)
out_dim 16 每个输出胶囊的维度
路由机制概览
复制代码
PrimaryCaps (8 个胶囊)
    ↓
动态路由 (3 次迭代)
    ↓
DigitCaps (10 个胶囊)

每个 DigitCaps 接收来自所有 PrimaryCaps 的加权输入
权重通过动态路由算法迭代更新
16 维向量的意义
复制代码
DigitCaps 输出的 16 维向量编码了:
- v[0:2]: 位置信息(x, y 坐标)
- v[2:4]: 尺寸信息(宽,高)
- v[4:6]: 旋转角度
- v[6:8]: 笔画粗细
- v[8:15]: 其他变形参数

向量的模长表示"是该数字"的概率

3. 解码器设计

3.1 解码器的作用

解码器有两个核心功能:

  1. 正则化编码器:通过重建任务迫使编码器保留更多信息
  2. 可视化验证:通过重建质量评估胶囊的表示能力

3.2 网络结构

图 3:解码器架构示意图(来源:原论文)

解码器的工作流程:

  1. 掩码选择:选择模长最大的胶囊(预测类别)
  2. 展平:将 10×16 向量展平为 160 维
  3. 全连接层:三层全连接逐步恢复维度
  4. 重建输出:生成 28×28 的重建图像
python 复制代码
class Decoder(nn.Module):
    def __init__(self):
        self.fc1 = nn.Linear(160, 512)    # 160 = 10 个胶囊 × 16 维
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 784)   # 784 = 28×28

3.3 维度变换

复制代码
输入:10 个 16 维向量 → 展平为 160 维

FC1: 160 → 512
ReLU 激活

FC2: 512 → 1024
ReLU 激活

FC3: 1024 → 784
Sigmoid 激活(输出像素值在 [0, 1])

输出:784 维 → 重塑为 28×28 图像

3.4 掩码机制

训练时掩码
python 复制代码
# 使用真实标签选择对应的胶囊
mask = torch.zeros_like(capsules_output)
mask[range(batch_size), labels] = 1
masked_output = capsules_output * mask
推理时掩码
python 复制代码
# 选择模长最大的胶囊
lengths = torch.norm(capsules_output, dim=-1)
max_length_idx = torch.argmax(lengths, dim=-1)
mask = torch.zeros_like(capsules_output)
mask[range(batch_size), max_length_idx] = 1
masked_output = capsules_output * mask

3.5 重建损失的作用

复制代码
总损失 = 分类损失 + 0.0005 × 重建损失

重建损失的意义:
✓ 迫使胶囊编码更多视觉信息
✓ 防止过拟合
✓ 提高泛化能力
✓ 提供可解释性

4. 数据流分析

4.1 前向传播完整流程

复制代码
Step 1: 输入图像
形状:[batch_size, 1, 28, 28]
含义:一批 28×28 灰度图像

    ↓

Step 2: ConvLayer
操作:Conv2d(1→256, 9×9) + ReLU
形状:[batch_size, 256, 20, 20]
含义:256 个 20×20 特征图

    ↓

Step 3: PrimaryCaps
操作:8 个 Conv2d(256→32, 9×9, stride=2) + Squash
形状:[batch_size, 1152, 8]
含义:8 个 1152 维胶囊向量

    ↓

Step 4: DigitCaps
操作:动态路由(3 次迭代)+ Squash
形状:[batch_size, 10, 16]
含义:10 个 16 维数字胶囊

    ↓

Step 5: 分类输出
操作:计算 10 个胶囊的模长
形状:[batch_size, 10]
含义:每个数字的概率

    ↓

Step 6: 解码器输入
操作:掩码选择 + 展平
形状:[batch_size, 160]
含义:选中的胶囊向量

    ↓

Step 7: Decoder
操作:FC1(512) → FC2(1024) → FC3(784)
形状:[batch_size, 784]
含义:重建的图像像素

    ↓

Step 8: 重建图像
形状:[batch_size, 1, 28, 28]
含义:重建的 28×28 图像

4.2 梯度反向传播

复制代码
损失函数
    ↓
分类损失(Margin Loss)
重建损失(MSE Loss)
    ↓
梯度流向 Decoder
    ↓
梯度流向 DigitCaps
    ↓
梯度流向 PrimaryCaps
    ↓
梯度流向 ConvLayer

5. 设计哲学

5.1 层次化表示

复制代码
设计原则:从具体到抽象

像素级 (Pixel Level)
    ↓ 28×28 原始像素
特征级 (Feature Level)
    ↓ 256 个卷积特征
部件级 (Part Level)
    ↓ 8 个主胶囊
对象级 (Object Level)
    ↓ 10 个数字胶囊

5.2 信息瓶颈

复制代码
信息流:
28×28 = 784 像素
    ↓ 压缩
20×20×256 = 102,400 特征
    ↓ 选择
1152×8 = 9,216 维
    ↓ 抽象
16×10 = 160 维
    ↓ 重建
784 像素

瓶颈处的 160 维编码了最本质的特征

5.3 端到端训练

复制代码
优势:
✓ 所有参数联合优化
✓ 特征表示自学习
✓ 无需人工设计特征
✓ 适应数据分布

5.4 可解释性设计

复制代码
可解释性体现:
1. 胶囊向量模长 → 存在概率
2. 胶囊向量方向 → 姿态参数
3. 解码器重建 → 可视化验证
4. 动态路由权重 → 注意力分布

6. 总结

架构设计要点

层级 核心创新 技术要点
ConvLayer 特征提取 大卷积核、多通道
PrimaryCaps 标量→向量 并行卷积、Squash
DigitCaps 动态路由 迭代路由、类别表示
Decoder 重建验证 全连接、掩码机制

设计亮点

层次化编码 :从像素到语义的渐进抽象

向量表示 :同时编码存在性和姿态

动态路由 :自适应的特征组合

重建监督:双重任务提升泛化

相关推荐
小雨中_1 小时前
2.6 时序差分方法(Temporal Difference, TD)
人工智能·python·深度学习·机器学习·自然语言处理
落羽的落羽2 小时前
【Linux系统】磁盘ext文件系统与软硬链接
linux·运维·服务器·数据库·c++·人工智能·机器学习
民乐团扒谱机2 小时前
【硬科普】位置与动量为什么是傅里叶变换对?从正则对易关系到时空弯曲,一次讲透
人工智能·线性代数·正则·量子力学·傅里叶变换·对易算符
AI实战架构笔记2 小时前
大数据预测分析在房地产行业的市场动态监测
大数据·ai
陈天伟教授2 小时前
人工智能应用- 推荐算法:05.推荐算法的社会争议
算法·机器学习·推荐算法
七夜zippoe2 小时前
图神经网络实战:从社交网络到推荐系统的工业级应用
网络·人工智能·pytorch·python·神经网络·cora
啊阿狸不会拉杆2 小时前
《计算机视觉:模型、学习和推理》第 1 章 - 绪论
人工智能·python·学习·算法·机器学习·计算机视觉·模型
X54先生(人文科技)2 小时前
叙事响应:《当预言泛起涟漪——碳硅智能时代的叙事开篇》
人工智能·ai编程·ai写作
硅谷秋水2 小时前
具身智能中的生成多智体协作:系统性综述
人工智能·深度学习·机器学习·语言模型·机器人