可微分结构搜索, 可微分算子选择 —— 让程序“结构”也可学习 , 具体怎么实现结构的轮询穷举

在可微分结构搜索中,实现"结构的轮询穷举"并非指在训练过程中逐个尝试每一种结构(那会回到强化学习或进化算法的老路,效率极低),而是通过连续松弛技术,将离散的搜索空间转化为连续的优化空间。

简单来说,我们不再是"非此即彼"地选择算子,而是让程序同时包含所有可能的算子,通过权重参数来控制每个算子的"重要性"。训练过程中,这些权重会像普通神经网络的参数一样通过梯度下降来更新。

以下是具体的实现逻辑和步骤:

1. 核心思想:从离散到连续

在传统的穷举中,假设一条边上有 3 个候选算子(如 Conv3x3, Conv5x5, MaxPool),你只能选一个试。

可微分方法的实现方式是:

让这三个算子并行运行,然后将它们的输出结果加权求和。

输出=w1⋅Conv3x3(x)+w2⋅Conv5x5(x)+w3⋅MaxPool(x)输出=w1​⋅Conv3x3(x)+w2​⋅Conv5x5(x)+w3​⋅MaxPool(x)

这里的 w1,w2,w3w1​,w2​,w3​ 就是结构参数。如果训练到最后 w1w1​ 最大,我们就认为结构选择了 Conv3x3。

2. 具体实现步骤

第一步:定义超网络

这是物理实现的基础。你需要构建一个包含所有候选算子的"超级网络"。对于网络中的每一个节点连接(边),不再放置单一的操作,而是放置一个包含所有候选操作的混合模块。

代码逻辑示意:

复制代码
class MixedOp(nn.Module):
    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        # 将所有候选算子实例化并放入列表
        for primitive in PRIMITIVES:
            op = OPS[primitive](C, stride)
            self._ops.append(op)

    def forward(self, x, weights):
        # 核心逻辑:加权求和
        # weights 是一个向量,长度等于候选算子数量
        output = sum(w * op(x) for w, op in zip(weights, self._ops))
        return output
第二步:参数连续化与 Softmax

为了让"选择结构"这个动作可微分,结构参数 αα 必须通过 Softmax 归一化为概率分布(权重 ww)。

wi=exp⁡(αi)∑jexp⁡(αj)wi​=∑j​exp(αj​)exp(αi​)​

  • αα 是可学习的参数(类似于注意力机制中的 Key)。
  • 初始时,αα 全为0,意味着所有算子的权重相等(平均主义)。
  • 随着训练,梯度会更新 αα,使得表现好的算子权重变大。
第三步:双层优化

这是最关键的一步。我们有两类参数需要学习:

  1. 网络权重 WW:控制卷积核的大小、数值等(传统深度学习参数)。
  2. 结构权重 αα:控制选哪个算子(结构搜索参数)。

如果同时优化两者,会导致坍塌(网络倾向于选择参数少的、简单的算子,如 Skip Connection,因为它们容易训练)。DARTS 采用了交替训练的策略来模拟"轮询"和"筛选":

  1. 更新 WW(训练阶段): 固定 αα,在训练集上计算 Loss,更新网络权重 WW。此时,程序实际上是在尝试当前的混合结构。
  2. 更新 αα(验证阶段): 固定 WW,在验证集上计算 Loss,更新结构参数 αα。
    • 如果某个算子能让验证集 Loss 变小,它的 αα 梯度就会是负的,经过优化 αα 变大,该算子在下一次前向传播中的权重就会增加。

伪代码流程:

复制代码
# optimizer: 优化器
# train_data: 训练集
# val_data: 验证集

for epoch in range(epochs):
    # 1. 更新网络权重 W (结构参数 alpha 视为常数)
    # 这一步相当于让当前结构"学会"如何处理数据
    for step, (train_x, train_y) in enumerate(train_data):
        optimizer.zero_grad()
        # 前向传播时传入当前的 softmax(alpha)
        logits = model(train_x, softmax(alpha)) 
        loss = criterion(logits, train_y)
        loss.backward()
        optimizer.step() # 更新 W

    # 2. 更新结构参数 alpha (网络权重 W 视为常数)
    # 这一步相当于"评价"哪个结构更好
    for step, (val_x, val_y) in enumerate(val_data):
        architect_optimizer.zero_grad()
        # 在验证集上跑,寻找最佳结构
        logits = model(val_x, softmax(alpha))
        val_loss = criterion(logits, val_y)
        val_loss.backward()
        architect_optimizer.step() # 更新 alpha
第四步:离散化与穷举的终结

经过多轮训练后,每个位置上的 αα 向量会变得极度不平衡(例如,Conv3x3 的 αα 远大于其他)。此时,我们执行最后的"离散化"操作:

  1. 生成最终结构: 对于每个混合节点,保留 αα 值最大的那个算子。
  2. 丢弃其他算子: 此时,复杂的超网络变成了一个轻量级的单路径网络。
  3. 重训练: 这个被选中的网络结构会被重新从头训练,以达到最佳性能。

总结:这为什么叫"可微分"且实现了"穷举"?

  1. 为何像穷举?

    在每一次前向传播中,所有算子都参与了计算。这相当于程序不是在选 A 或 B,而是在计算"A 和 B 的某种混合"。通过权重的变化,程序在连续空间中"扫描"了所有可能的组合。每一个 αα 的更新,都是在微调结构,向着最优结构靠近。

  2. 为何是可微分?

    传统方法选结构是离散动作(选了就是选了,不可导),而现在选结构的动作变成了 Softmax 乘法,这是可导的。因此,我们可以用链式法则直接算出"选这个算子会对最终结果产生什么影响",并据此调整 αα。

这种方法的代价:

由于初始阶段所有算子都要算一遍,显存占用会非常大(显存占用 = 候选算子数量 ×× 单算子显存)。因此,后来出现了 ProxylessNAS 等改进方法,在训练时只采样一条路径进行更新,进一步解决了显存瓶颈。

相关推荐
惊鸿一博1 小时前
自动驾驶的 BEV 特征(Bird’s Eye View Feature)
人工智能·机器学习·自动驾驶
碳基硅坊2 小时前
Mac Studio M3 Ultra 运行大模型实测:Qwen3.6 vs 6款主流模型工具调用对比
人工智能·qwen·qwen3.6
TeDi TIVE8 小时前
开源模型应用落地-工具使用篇-Spring AI-高阶用法(九)
人工智能·spring·开源
MY_TEUCK8 小时前
Sealos 平台部署实战指南:结合 Cursor 与版本发布流程
java·人工智能·学习·aigc
三毛的二哥8 小时前
BEV:典型BEV算法总结
人工智能·算法·计算机视觉·3d
j_xxx404_9 小时前
大语言模型 (LLM) 零基础入门:核心原理、训练机制与能力全解
人工智能·ai·transformer
飞哥数智坊9 小时前
全新 SOLO 在日常办公中的实际体验
人工智能·solo
<-->9 小时前
Megatron(全称 Megatron-LM,由 NVIDIA 开发)和 DeepSpeed(由 Microsoft 开发)
人工智能·pytorch·python·深度学习·transformer
朝新_9 小时前
【Spring AI 】图像与语音模型实战
java·人工智能·spring
Yuanxl9039 小时前
神经网络-Sequential 应用与实战
人工智能·深度学习·神经网络