可微分结构搜索, 可微分算子选择 —— 让程序“结构”也可学习 , 具体怎么实现结构的轮询穷举

在可微分结构搜索中,实现"结构的轮询穷举"并非指在训练过程中逐个尝试每一种结构(那会回到强化学习或进化算法的老路,效率极低),而是通过连续松弛技术,将离散的搜索空间转化为连续的优化空间。

简单来说,我们不再是"非此即彼"地选择算子,而是让程序同时包含所有可能的算子,通过权重参数来控制每个算子的"重要性"。训练过程中,这些权重会像普通神经网络的参数一样通过梯度下降来更新。

以下是具体的实现逻辑和步骤:

1. 核心思想:从离散到连续

在传统的穷举中,假设一条边上有 3 个候选算子(如 Conv3x3, Conv5x5, MaxPool),你只能选一个试。

可微分方法的实现方式是:

让这三个算子并行运行,然后将它们的输出结果加权求和。

输出=w1⋅Conv3x3(x)+w2⋅Conv5x5(x)+w3⋅MaxPool(x)输出=w1​⋅Conv3x3(x)+w2​⋅Conv5x5(x)+w3​⋅MaxPool(x)

这里的 w1,w2,w3w1​,w2​,w3​ 就是结构参数。如果训练到最后 w1w1​ 最大,我们就认为结构选择了 Conv3x3。

2. 具体实现步骤

第一步:定义超网络

这是物理实现的基础。你需要构建一个包含所有候选算子的"超级网络"。对于网络中的每一个节点连接(边),不再放置单一的操作,而是放置一个包含所有候选操作的混合模块。

代码逻辑示意:

复制代码
class MixedOp(nn.Module):
    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        # 将所有候选算子实例化并放入列表
        for primitive in PRIMITIVES:
            op = OPS[primitive](C, stride)
            self._ops.append(op)

    def forward(self, x, weights):
        # 核心逻辑:加权求和
        # weights 是一个向量,长度等于候选算子数量
        output = sum(w * op(x) for w, op in zip(weights, self._ops))
        return output
第二步:参数连续化与 Softmax

为了让"选择结构"这个动作可微分,结构参数 αα 必须通过 Softmax 归一化为概率分布(权重 ww)。

wi=exp⁡(αi)∑jexp⁡(αj)wi​=∑j​exp(αj​)exp(αi​)​

  • αα 是可学习的参数(类似于注意力机制中的 Key)。
  • 初始时,αα 全为0,意味着所有算子的权重相等(平均主义)。
  • 随着训练,梯度会更新 αα,使得表现好的算子权重变大。
第三步:双层优化

这是最关键的一步。我们有两类参数需要学习:

  1. 网络权重 WW:控制卷积核的大小、数值等(传统深度学习参数)。
  2. 结构权重 αα:控制选哪个算子(结构搜索参数)。

如果同时优化两者,会导致坍塌(网络倾向于选择参数少的、简单的算子,如 Skip Connection,因为它们容易训练)。DARTS 采用了交替训练的策略来模拟"轮询"和"筛选":

  1. 更新 WW(训练阶段): 固定 αα,在训练集上计算 Loss,更新网络权重 WW。此时,程序实际上是在尝试当前的混合结构。
  2. 更新 αα(验证阶段): 固定 WW,在验证集上计算 Loss,更新结构参数 αα。
    • 如果某个算子能让验证集 Loss 变小,它的 αα 梯度就会是负的,经过优化 αα 变大,该算子在下一次前向传播中的权重就会增加。

伪代码流程:

复制代码
# optimizer: 优化器
# train_data: 训练集
# val_data: 验证集

for epoch in range(epochs):
    # 1. 更新网络权重 W (结构参数 alpha 视为常数)
    # 这一步相当于让当前结构"学会"如何处理数据
    for step, (train_x, train_y) in enumerate(train_data):
        optimizer.zero_grad()
        # 前向传播时传入当前的 softmax(alpha)
        logits = model(train_x, softmax(alpha)) 
        loss = criterion(logits, train_y)
        loss.backward()
        optimizer.step() # 更新 W

    # 2. 更新结构参数 alpha (网络权重 W 视为常数)
    # 这一步相当于"评价"哪个结构更好
    for step, (val_x, val_y) in enumerate(val_data):
        architect_optimizer.zero_grad()
        # 在验证集上跑,寻找最佳结构
        logits = model(val_x, softmax(alpha))
        val_loss = criterion(logits, val_y)
        val_loss.backward()
        architect_optimizer.step() # 更新 alpha
第四步:离散化与穷举的终结

经过多轮训练后,每个位置上的 αα 向量会变得极度不平衡(例如,Conv3x3 的 αα 远大于其他)。此时,我们执行最后的"离散化"操作:

  1. 生成最终结构: 对于每个混合节点,保留 αα 值最大的那个算子。
  2. 丢弃其他算子: 此时,复杂的超网络变成了一个轻量级的单路径网络。
  3. 重训练: 这个被选中的网络结构会被重新从头训练,以达到最佳性能。

总结:这为什么叫"可微分"且实现了"穷举"?

  1. 为何像穷举?

    在每一次前向传播中,所有算子都参与了计算。这相当于程序不是在选 A 或 B,而是在计算"A 和 B 的某种混合"。通过权重的变化,程序在连续空间中"扫描"了所有可能的组合。每一个 αα 的更新,都是在微调结构,向着最优结构靠近。

  2. 为何是可微分?

    传统方法选结构是离散动作(选了就是选了,不可导),而现在选结构的动作变成了 Softmax 乘法,这是可导的。因此,我们可以用链式法则直接算出"选这个算子会对最终结果产生什么影响",并据此调整 αα。

这种方法的代价:

由于初始阶段所有算子都要算一遍,显存占用会非常大(显存占用 = 候选算子数量 ×× 单算子显存)。因此,后来出现了 ProxylessNAS 等改进方法,在训练时只采样一条路径进行更新,进一步解决了显存瓶颈。

相关推荐
智联视频超融合平台几秒前
AI赋能传统电厂:2025能源革命的智慧引擎
人工智能·能源
qcx238 分钟前
【系统学AI】23 AI 时代产品运营与获客全景:CRM SaaS 大变局 + 增长新范式(2026 调研报告)
人工智能·产品运营·产品设计·ai agent·ai native
叶修_A9 分钟前
【COZE-08】Prompt工程进阶 - 结构化输出与思维链
大数据·人工智能·prompt
John_ToDebug11 分钟前
开源与人性:DeepSeek 战略的底层逻辑
人工智能·经验分享·ai
IronMurphy13 分钟前
【算法五十五】240. 搜索二维矩阵 II
线性代数·矩阵
老吴胡编14 分钟前
eknife 2026.05.28 v0.0.5 更新 —— 支持 PDF 文档合并
人工智能·嵌入式硬件·个人开发
OCR_1337162127514 分钟前
技术实测|2026三款主流OCR横向对比:SDK15、PaddleOCR、GLM-OCR选型指南
大数据·人工智能
深蓝电商API15 分钟前
当爬虫遇见大模型:AI驱动的智能数据采集新范式
人工智能·爬虫
陈天伟教授16 分钟前
图解人工智能(37)人工智能应用-车牌识别
人工智能·深度学习
Agent手记16 分钟前
电商智能客服的退换货自动处理流程如何配置?——2026企业级Agent全链路实战指南
人工智能·ai