人工智能【第41篇】神经架构搜索NAS入门:自动设计神经网络

作者的话 :在前面的文章中,我们学习了各种神经网络架构------CNN、RNN、Transformer等。但这些架构都是人类专家手工设计 的,需要大量试错和经验。能否让AI自动设计神经网络 ?**神经架构搜索(Neural Architecture Search, NAS)**就是要回答这个问题------通过自动化方法发现超越人类设计的网络架构。从AlphaGo的启示到EfficientNet的成功,NAS正在改变深度学习的发展方式。本文将带你深入理解NAS的原理与实战!


一、为什么需要神经架构搜索?

1.1 手工设计网络的问题

传统流程的问题

复制代码
手工设计神经网络:
  1. 确定基本架构(ResNet? DenseNet?)
  2. 选择层数(18? 50? 101?)
  3. 调整通道数(64? 128? 256?)
  4. 设计连接方式(跳跃连接?密集连接?)
  5. 实验验证(训练7天...效果不好...)
  6. 回到步骤2,重复...

时间成本:数周到数月
计算成本:数千GPU小时
人力成本:资深专家经验

1.2 手工设计的局限

局限 说明 例子
经验依赖 需要深厚专业知识 为什么ResNet用3×3卷积?
试错成本高 每个设计都要完整训练 训练一次数天
局部最优 人倾向于熟悉的架构 都是ResNet变种
扩展性差 新任务需要重新设计 ImageNet到移动端

1.3 NAS的成功案例

模型 方法 成果
NASNet RL + CIFAR10 超越人类设计
AmoebaNet 进化算法 ImageNet SOTA
EfficientNet 复合缩放 SOTA + 高效
MobileNetV3 NAS + 人工调优 移动端最优

二、搜索空间

2.1 搜索空间的分类

宏搜索(Macro Search)

  • 搜索整个网络结构
  • 每一层都可以不同
  • 灵活性高,搜索空间大

微搜索(Micro Search)

  • 搜索基本单元(cell/block)

  • 堆叠相同单元构建网络

  • 搜索空间小,迁移性好

    宏搜索 vs 微搜索

    宏搜索:
    Layer 1: Conv3x3, 32 filters
    Layer 2: Conv5x5, 64 filters
    Layer 3: MaxPool 3x3
    ...(每层都不同)

    微搜索(NASNet):
    Normal Cell: [搜索得到的基本单元]
    Reduction Cell: [搜索得到的基本单元]
    网络 = Normal × N + Reduction + Normal × N + ...

2.2 DARTS搜索空间

复制代码
# 有向无环图(DAG)表示
# 节点 = 特征图
# 边 = 操作

可选操作:
  - none: 不连接
  - skip_connect: 跳跃连接
  - conv_3x3: 3x3卷积
  - conv_5x5: 5x5卷积
  - dil_conv_3x3: 3x3空洞卷积
  - sep_conv_3x3: 3x3可分离卷积
  - avg_pool_3x3: 3x3平均池化
  - max_pool_3x3: 3x3最大池化

三、搜索策略

3.1 基于强化学习的NAS

Zoph & Le, 2017(开创性工作)

复制代码
控制器(RNN)→ 生成架构描述 → 训练子网络 → 得到准确率 → 奖励RNN
         ↑__________________________________________________|

核心思想

  • 使用RNN作为控制器
  • 逐层生成网络架构
  • 子网络准确率作为奖励
  • 使用策略梯度更新控制器

3.2 基于进化算法的NAS

复制代码
进化算法流程:
  1. 初始化:随机生成P个架构(种群)
  2. 评估:训练并评估每个架构的适应度(准确率)
  3. 选择:选择适应度高的架构
  4. 变异:对选中架构进行变异(改层、改连接)
  5. 重复2-4,直到找到最优架构

3.3 基于梯度的NAS(DARTS)

DARTS核心思想

将离散架构选择松弛为连续权重,使用梯度下降优化。

复制代码
传统NAS:硬选择一条边
  edge_op = select_one_from([op1, op2, op3, op4])

DARTS:软加权所有边
  edge_output = softmax(α1) * op1(x) + softmax(α2) * op2(x) + ...
  # α是可学习的架构参数

双层优化问题

复制代码
min_α L_val(w*(α), α)
s.t. w*(α) = argmin_w L_train(w, α)

其中:
- α:架构参数(决定选择哪些操作)
- w:网络权重
- 内层:固定架构,优化权重
- 外层:固定权重,优化架构

四、高效NAS方法

4.1 ENAS:权重共享

核心思想

所有子网络共享同一套权重,避免重复训练。

复制代码
ENAS关键创新:
  1. 使用一个大的计算图包含所有可能操作
  2. 每个子网络是这个图的一个子图
  3. 所有子网络共享参数
  4. 训练一个子网络 = 更新共享参数

4.2 ProxylessNAS:直接搜索目标硬件

核心思想

直接在目标硬件上搜索,使用路径二值化减少内存占用。

复制代码
ProxylessNAS特点:
  - 训练时软选择所有路径
  - 推理时硬选择最优路径
  - 直接针对目标硬件优化延迟

4.3 Once-for-All:训练一个超级网络

复制代码
Once-for-All流程:
  1. 构建最大网络(最多层、最大宽度、所有操作)
  2. 使用渐进收缩训练:
     - 先训练最大网络
     - 逐步引入更小的子网络
     - 所有子网络共享权重
  3. 搜索:从训练好的网络中采样评估,无需重新训练
  4. 部署:直接提取子网络

五、性能评估加速

方法 描述 时间 准确性
完整训练 从头训练到收敛 数天 最准确
早停 训练少量epoch 数小时 较准确
代理任务 小数据集/浅网络 分钟 有相关性
权重共享 ENAS/Once-for-All 需要精心设计

六、实战项目:DARTS实现CIFAR-10搜索

6.1 核心代码实现

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

# 操作定义
OPS = {
    'none': lambda C, stride: Zero(stride),
    'skip_connect': lambda C, stride: Identity() if stride == 1 else nn.Conv2d(C, C, 1, stride, 0, bias=False),
    'sep_conv_3x3': lambda C, stride: SepConv(C, C, 3, stride, 1),
    'sep_conv_5x5': lambda C, stride: SepConv(C, C, 5, stride, 2),
    'dil_conv_3x3': lambda C, stride: DilConv(C, C, 3, stride, 2, 2),
    'avg_pool_3x3': lambda C, stride: nn.AvgPool2d(3, stride, 1),
    'max_pool_3x3': lambda C, stride: nn.MaxPool2d(3, stride, 1),
    'conv_3x3': lambda C, stride: ReLUConvBN(C, C, 3, stride, 1),
}

class MixedOp(nn.Module):
    """混合操作:加权所有候选操作"""
    def __init__(self, C, stride):
        super().__init__()
        self.ops = nn.ModuleList([op(C, stride) for op in OPS.values()])
    
    def forward(self, x, weights):
        """weights: 架构参数(softmax后)"""
        return sum(w * op(x) for w, op in zip(weights, self.ops))

6.2 DARTS单元实现

复制代码
class Cell(nn.Module):
    """DARTS搜索单元"""
    def __init__(self, n_nodes, C_prev_prev, C_prev, C, reduction=False):
        super().__init__()
        self.n_nodes = n_nodes
        self.reduction = reduction
        
        # 预处理输入
        self.preprocess0 = nn.Sequential(
            nn.Conv2d(C_prev_prev, C, 1, 1, 0, bias=False),
            nn.BatchNorm2d(C)
        )
        self.preprocess1 = nn.Sequential(
            nn.Conv2d(C_prev, C, 1, 1, 0, bias=False),
            nn.BatchNorm2d(C)
        )
        
        # 构建所有边
        self.edges = nn.ModuleDict()
        for i in range(n_nodes):
            for j in range(2 + i):
                stride = 2 if reduction and j < 2 else 1
                edge_key = f"{j}->{i+2}"
                self.edges[edge_key] = MixedOp(C, stride)
    
    def forward(self, s0, s1, weights):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)
        
        states = [s0, s1]
        offset = 0
        
        for i in range(self.n_nodes):
            node_inputs = []
            for j in range(2 + i):
                edge_key = f"{j}->{i+2}"
                op = self.edges[edge_key]
                w = weights[offset + j]
                node_inputs.append(op(states[j], w))
            
            s = sum(node_inputs)
            states.append(s)
            offset += 2 + i
        
        return torch.cat(states[-self.n_nodes:], dim=1)

6.3 网络定义

复制代码
class Network(nn.Module):
    """DARTS搜索网络"""
    def __init__(self, C=16, n_classes=10, n_layers=8, n_nodes=4):
        super().__init__()
        
        # 初始卷积
        self.stem = nn.Sequential(
            nn.Conv2d(3, C, 3, 1, 1, bias=False),
            nn.BatchNorm2d(C)
        )
        
        # 构建层
        self.cells = nn.ModuleList()
        C_prev_prev, C_prev, C_curr = C, C, C
        reduction_layers = [n_layers // 3, 2 * n_layers // 3]
        
        for i in range(n_layers):
            reduction = i in reduction_layers
            if reduction:
                C_curr *= 2
            
            cell = Cell(n_nodes, C_prev_prev, C_prev, C_curr, reduction)
            self.cells.append(cell)
            C_prev_prev = C_prev
            C_prev = n_nodes * C_curr
        
        # 全局平均池化和分类器
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, n_classes)
        
        # 架构参数
        n_edges = sum(2 + i for i in range(n_nodes))
        self.alphas_normal = nn.Parameter(torch.zeros(n_edges, len(OPS)))
        self.alphas_reduce = nn.Parameter(torch.zeros(n_edges, len(OPS)))
    
    def forward(self, x):
        s0 = s1 = self.stem(x)
        
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                weights = F.softmax(self.alphas_normal, dim=-1)
            
            s0, s1 = s1, cell(s0, s1, weights)
        
        out = self.global_pooling(s1)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits

6.4 训练代码

复制代码
class DARTSTrainer:
    """DARTS训练器"""
    def __init__(self, model, train_loader, val_loader):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        
        # 优化器分离
        self.w_optimizer = torch.optim.SGD(
            model.weights(), lr=0.025, momentum=0.9, weight_decay=3e-4
        )
        self.alpha_optimizer = torch.optim.Adam(
            model.alphas(), lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3
        )
    
    def train(self, epochs=50):
        for epoch in range(epochs):
            # 阶段1:训练网络权重(内层优化)
            self._train_weights()
            
            # 阶段2:训练架构参数(外层优化)
            self._train_alphas()
            
            if epoch % 10 == 0:
                arch = self.model.discretize()
                acc = self.evaluate(arch)
                print(f"Epoch {epoch}, Arch: {arch}, Acc: {acc:.4f}")
    
    def _train_weights(self):
        """训练网络权重"""
        self.model.train()
        for x, y in self.train_loader:
            x, y = x.cuda(), y.cuda()
            
            self.w_optimizer.zero_grad()
            output = self.model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            self.w_optimizer.step()
    
    def _train_alphas(self):
        """训练架构参数"""
        self.model.train()
        for x, y in self.val_loader:
            x, y = x.cuda(), y.cuda()
            
            self.alpha_optimizer.zero_grad()
            output = self.model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            self.alpha_optimizer.step()
            break  # 每个epoch只更新一次架构参数

七、方法对比总结

方法 搜索方式 评估方式 优点 缺点
NASNet RL 完整训练 效果好 2000 GPU天
ENAS RL 权重共享 1 GPU天 可能偏置
DARTS 梯度 完整训练 可微分 内存大
ProxylessNAS 梯度 路径二值化 直接硬件 实现复杂
Once-for-All 预训练 权重共享 快速部署 训练成本高

八、总结

8.1 NAS的核心要点

  1. 核心问题:手工设计网络耗时且依赖经验,自动发现超越人类设计的架构
  2. 三要素:搜索空间、搜索策略、性能评估
  3. 主要方法
    • 强化学习:NASNet,RNN控制器+策略梯度
    • 进化算法:AmoebaNet,变异+选择
    • 梯度优化:DARTS,连续松弛+双层优化
    • 高效方法:ENAS、ProxylessNAS、Once-for-All

8.2 学习路径总结

复制代码
深度学习架构演进:
  ├── 手工设计(AlexNet → VGG → ResNet → DenseNet)
  ├── 手工设计 + 缩放(EfficientNet的复合缩放)
  └── 自动搜索(NAS)

NAS方法演进:
  ├── 基于强化学习(NASNet)
  ├── 基于进化算法(AmoebaNet)
  ├── 基于梯度(DARTS)
  ├── 权重共享(ENAS)
  ├── 直接硬件搜索(ProxylessNAS)
  └── 超级网络(Once-for-All)

下一篇预告:【第42篇】AutoML入门:自动化机器学习全流程

我们将探讨如何让AI自动完成从数据预处理到模型部署的整个机器学习流程,实现真正的"零代码AI开发"!


本文为系列第41篇,详细介绍了神经架构搜索的原理与实战。有任何问题欢迎在评论区交流!

标签:神经架构搜索、NAS、DARTS、AutoML、网络架构设计、深度学习自动化

相关推荐
威联通网络存储1 天前
QNAP全闪存架构:化解制造车间AOI数据I/O瓶颈
nas
威联通安全存储1 天前
资产治理:QNAP 存算融合架构理顺工程机械装配车间异构图纸流转
nas
威联通安全存储1 天前
QNAP存算一体架构:筑牢制造MES数据容灾防线
nas
nebula-AI2 天前
人工智能导论:模型与算法(未来发展与趋势)
人工智能·神经网络·算法·机器学习·量子计算·automl·类脑计算
威联通网络存储4 天前
制造协同:QNAP 软硬件架构化解汽车冲压车间大文件传输难题
nas
威联通安全存储4 天前
工业数据湖:QNAP 存算架构治理智能制造车间非结构化资产
nas
威联通安全存储9 天前
QNAP 存算融合架构:构建智造工厂 IIoT 数据湖
nas
威联通安全存储10 天前
威联通 AI 视觉边缘检测:半导体缺陷识别中的 GPU 直通与全闪存 I/O 协同
nas
晨陌y12 天前
EinVault宠物健康追踪器:NAS部署实录,随时记录疫苗体重和日常
宠物·nas