一、 设计流程与核心哲学
-
从简单开始
- 不要一上来就上ResNet、Transformer。先建立一个简单的基准模型,比如只有一两层的全连接网络或小型CNN。
- 目的:验证你的数据管道是否正确,确保模型能够学习(哪怕只是轻微过拟合),并建立一个性能底线。如果简单模型都学不好,复杂模型大概率也学不好。
-
优先搞定数据和损失函数
- 数据是天花板:模型性能的上限由你的数据质量决定。花大量时间在数据清洗、增强和预处理上,回报率最高。
- 损失函数是导航:你的损失函数必须精确地定义你希望模型"优化什么"。分类任务用交叉熵,回归任务用MSE/MAE,生成任务可能用对抗损失等。选对损失函数是关键第一步。
-
过度拟合一个小数据集
- 在正式训练前,找一个极小的数据子集(比如每个类别几张图片),让模型去训练,并确保它能达到接近100%的训练准确率。
- 目的:这被称为"合理性检查"。如果模型在小数据上都无法过拟合,说明模型架构能力不足、存在bug或优化器设置有问题。
-
迭代式开发与评估
- 遵循一个循环:构建模型 -> 训练 -> 分析误差 -> 提出假设 -> 修改。
- 在验证集/测试集上分析模型在哪里出错,能为你提供最直接的改进方向。是欠拟合还是过拟合?是对某些类别识别不好?还是对图像旋转敏感?
二、 架构选择技巧
-
遵循经过验证的范式
- 计算机视觉 :从 CNN 开始。优先考虑使用现代架构如 ResNet , EfficientNet , MobileNet(作为backbone),它们内置了残差连接、通道注意力等高效机制。
- 自然语言处理/序列建模 :从 RNN (LSTM/GRU) 或 Transformer 开始。对于大多数任务,Transformer(尤其是预训练模型如BERT, GPT)已成为主流。
- 图数据 :使用 图神经网络。
-
善用"现代"构建模块
- 残差连接:几乎是深层网络的必需品,能有效解决梯度消失/爆炸问题,让网络更容易训练得更深。
- 批量归一化:加速训练、提高稳定性、降低对初始化的敏感度。通常放在卷积/全连接层之后,激活函数之前。
- 注意力机制:让模型学会"关注"重要的部分。从Transformer中的自注意力到CNN中的SE模块,都非常有效。
- Dropout:防止过拟合的有效正则化手段,在全连接层后使用效果更明显。
-
选择正确的激活函数
- 默认推荐 :ReLU 及其变体(如 Leaky ReLU , PReLU)。它们解决了Sigmoid/Tanh的梯度消失问题。
- 输出层:二分类用Sigmoid,多分类用Softmax,回归用线性激活。
三、 训练与调参技巧
-
优化器选择
- Adam/AdamW:通常是默认的、效果不错的起点,对学习率不那么敏感。
- SGD with Momentum :在精心调参(特别是学习率调度)后,往往能达到比Adam更好的最终性能,但可能需要更多技巧。AdamW 解决了Adam的权重衰减问题,是目前更推荐的选择。
-
学习率是关键
- 学习率调度:使用动态学习率。常见策略有:步长衰减、余弦退火、预热。
- 学习率预热:在训练开始时使用一个极小的学习率,逐步增大到初始学习率,有助于稳定训练。
- 一周期策略:一种有效的方法,先增大学习率再减小。
- 找不到合适的学习率? 进行学习率搜索,绘制学习率与损失的关系图,选择一个损失下降最快的点。
-
正则化以防止过拟合
- 数据增强:最有效的正则化方法!通过对训练数据进行随机变换(旋转、裁剪、颜色抖动等)来增加数据的多样性和数量。
- 权重衰减:即L2正则化,给损失函数加上权重的平方和,惩罚过大权重。
- 早停:在验证集性能不再提升时停止训练。
- Dropout:如上所述。
四、 高级策略与调试
-
利用预训练模型
- 在你有中等规模的数据集时,迁移学习 是王道。使用在ImageNet、WikiText等大型数据集上预训练好的模型,在你的任务上进行微调,能极大加快收敛速度并提升性能。
-
自动化超参数搜索
- 当手动调参遇到瓶颈时,可以使用自动化工具,如网格搜索、随机搜索、贝叶斯优化等。随机搜索 通常比网格搜索更高效。
-
可视化与监控
- 监控损失和准确率曲线:关注训练集和验证集的差距,判断过拟合/欠拟合。
- 可视化激活和权重:看看网络到底学到了什么。
- 使用梯度裁剪:如果训练中出现梯度爆炸(损失突然变成NaN),梯度裁剪可以稳定训练。
总结:一个简洁的清单
当你开始一个新项目时,可以遵循这个清单:
- 数据:清洗、增强、标准化。
- 模型:从一个简单模型开始,快速验证。
- 损失与优化:选择适合任务的损失函数,用AdamW作为优化器。
- 学习率:使用一个带预热的调度器。
- 训练与监控:在小数据集上过拟合,然后在全数据集上训练,密切监控训练/验证曲线。
- 正则化:如果出现过拟合,增加数据增强、Dropout或权重衰减。
- 迭代:分析错误,提出假设,升级模型架构(如使用ResNet),并重复过程。
- 最终提升:尝试模型集成、测试时增强等。