在可微分结构搜索中,实现"结构的轮询穷举"并非指在训练过程中逐个尝试每一种结构(那会回到强化学习或进化算法的老路,效率极低),而是通过连续松弛技术,将离散的搜索空间转化为连续的优化空间。
简单来说,我们不再是"非此即彼"地选择算子,而是让程序同时包含所有可能的算子,通过权重参数来控制每个算子的"重要性"。训练过程中,这些权重会像普通神经网络的参数一样通过梯度下降来更新。
以下是具体的实现逻辑和步骤:
1. 核心思想:从离散到连续
在传统的穷举中,假设一条边上有 3 个候选算子(如 Conv3x3, Conv5x5, MaxPool),你只能选一个试。
可微分方法的实现方式是:
让这三个算子并行运行,然后将它们的输出结果加权求和。
输出=w1⋅Conv3x3(x)+w2⋅Conv5x5(x)+w3⋅MaxPool(x)输出=w1⋅Conv3x3(x)+w2⋅Conv5x5(x)+w3⋅MaxPool(x)
这里的 w1,w2,w3w1,w2,w3 就是结构参数。如果训练到最后 w1w1 最大,我们就认为结构选择了 Conv3x3。
2. 具体实现步骤
第一步:定义超网络
这是物理实现的基础。你需要构建一个包含所有候选算子的"超级网络"。对于网络中的每一个节点连接(边),不再放置单一的操作,而是放置一个包含所有候选操作的混合模块。
代码逻辑示意:
class MixedOp(nn.Module):
def __init__(self, C, stride):
super(MixedOp, self).__init__()
self._ops = nn.ModuleList()
# 将所有候选算子实例化并放入列表
for primitive in PRIMITIVES:
op = OPS[primitive](C, stride)
self._ops.append(op)
def forward(self, x, weights):
# 核心逻辑:加权求和
# weights 是一个向量,长度等于候选算子数量
output = sum(w * op(x) for w, op in zip(weights, self._ops))
return output
第二步:参数连续化与 Softmax
为了让"选择结构"这个动作可微分,结构参数 αα 必须通过 Softmax 归一化为概率分布(权重 ww)。
wi=exp(αi)∑jexp(αj)wi=∑jexp(αj)exp(αi)
- αα 是可学习的参数(类似于注意力机制中的 Key)。
- 初始时,αα 全为0,意味着所有算子的权重相等(平均主义)。
- 随着训练,梯度会更新 αα,使得表现好的算子权重变大。
第三步:双层优化
这是最关键的一步。我们有两类参数需要学习:
- 网络权重 WW:控制卷积核的大小、数值等(传统深度学习参数)。
- 结构权重 αα:控制选哪个算子(结构搜索参数)。
如果同时优化两者,会导致坍塌(网络倾向于选择参数少的、简单的算子,如 Skip Connection,因为它们容易训练)。DARTS 采用了交替训练的策略来模拟"轮询"和"筛选":
- 更新 WW(训练阶段): 固定 αα,在训练集上计算 Loss,更新网络权重 WW。此时,程序实际上是在尝试当前的混合结构。
- 更新 αα(验证阶段): 固定 WW,在验证集上计算 Loss,更新结构参数 αα。
- 如果某个算子能让验证集 Loss 变小,它的 αα 梯度就会是负的,经过优化 αα 变大,该算子在下一次前向传播中的权重就会增加。
伪代码流程:
# optimizer: 优化器
# train_data: 训练集
# val_data: 验证集
for epoch in range(epochs):
# 1. 更新网络权重 W (结构参数 alpha 视为常数)
# 这一步相当于让当前结构"学会"如何处理数据
for step, (train_x, train_y) in enumerate(train_data):
optimizer.zero_grad()
# 前向传播时传入当前的 softmax(alpha)
logits = model(train_x, softmax(alpha))
loss = criterion(logits, train_y)
loss.backward()
optimizer.step() # 更新 W
# 2. 更新结构参数 alpha (网络权重 W 视为常数)
# 这一步相当于"评价"哪个结构更好
for step, (val_x, val_y) in enumerate(val_data):
architect_optimizer.zero_grad()
# 在验证集上跑,寻找最佳结构
logits = model(val_x, softmax(alpha))
val_loss = criterion(logits, val_y)
val_loss.backward()
architect_optimizer.step() # 更新 alpha
第四步:离散化与穷举的终结
经过多轮训练后,每个位置上的 αα 向量会变得极度不平衡(例如,Conv3x3 的 αα 远大于其他)。此时,我们执行最后的"离散化"操作:
- 生成最终结构: 对于每个混合节点,保留 αα 值最大的那个算子。
- 丢弃其他算子: 此时,复杂的超网络变成了一个轻量级的单路径网络。
- 重训练: 这个被选中的网络结构会被重新从头训练,以达到最佳性能。
总结:这为什么叫"可微分"且实现了"穷举"?
-
为何像穷举?
在每一次前向传播中,所有算子都参与了计算。这相当于程序不是在选 A 或 B,而是在计算"A 和 B 的某种混合"。通过权重的变化,程序在连续空间中"扫描"了所有可能的组合。每一个 αα 的更新,都是在微调结构,向着最优结构靠近。
-
为何是可微分?
传统方法选结构是离散动作(选了就是选了,不可导),而现在选结构的动作变成了 Softmax 乘法,这是可导的。因此,我们可以用链式法则直接算出"选这个算子会对最终结果产生什么影响",并据此调整 αα。
这种方法的代价:
由于初始阶段所有算子都要算一遍,显存占用会非常大(显存占用 = 候选算子数量 ×× 单算子显存)。因此,后来出现了 ProxylessNAS 等改进方法,在训练时只采样一条路径进行更新,进一步解决了显存瓶颈。