神经网络中的Adam

Adam(Adaptive Moment Estimation)是一种广泛使用的优化算法,结合了RMSprop和动量(Momentum)的优点。它通过计算梯度的一阶矩估计(mean)和二阶矩估计(uncentered variance),为每个参数提供自适应学习率。Adam由Diederik P. Kingma和Jimmy Ba在2014年的论文《Adam: A Method for Stochastic Optimization》中提出。

Adam的核心思想

Adam的主要特点是:

  • **自适应学习率**:根据参数的梯度一阶矩(均值)和二阶矩(方差)自动调整学习率。

  • **动量**:引入类似于传统动量的概念来加速SGD在相关方向上的进展,并抑制震荡。

  • **偏置校正**:对初期的矩估计进行偏差校正,以应对开始阶段估计不准确的问题。

更新规则

对于时间步 \( t \),某个参数 \( w \) 的更新过程如下:

  1. **计算梯度**:

\[ g_t = \nabla_{w} J(w_{t-1}) \]

这里,\( g_t \) 是损失函数 \( J \) 对参数 \( w \) 在时间步 \( t \) 的梯度。

  1. **计算一阶矩估计(均值)**:

\[ m_t = \beta_1 m_{t-1} + (1 - \beta_1)g_t \]

  1. **计算二阶矩估计(未中心化的方差)**:

\[ v_t = \beta_2 v_{t-1} + (1 - \beta_2)g_t^2 \]

其中,\( \beta_1 \) 和 \( \beta_2 \) 分别是用于控制一阶矩和二阶矩估计的指数衰减率,默认情况下分别设置为 0.9 和 0.999。

  1. **偏差校正**:

\[ \hat{m}_t = \frac{m_t}{1-\beta_1^t} \]

\[ \hat{v}_t = \frac{v_t}{1-\beta_2^t} \]

这一步是为了修正初始时刻的偏差,因为在训练初期,\( m_t \) 和 \( v_t \) 可能会偏向零。

  1. **参数更新**:

\[ w_t = w_{t-1} - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t \]

其中,\( \eta \) 是学习率,\( \epsilon \) 是一个很小的常数(例如 \( 10^{-8} \)),用于确保数值稳定性,避免除以零的情况。

特点与优势

  • **高效性**:Adam通常比其他自适应学习率方法如Adagrad或RMSprop更快收敛。

  • **适用于非平稳目标**:由于其使用了移动平均,因此更适合处理随着时间变化的目标函数。

  • **不需要手动调节学习率**:相比标准SGD,Adam减少了对超参数(特别是学习率)精细调节的需求。

实践中的应用

Adam因其良好的性能和易用性,在深度学习领域得到了广泛应用。无论是图像识别、自然语言处理还是强化学习等领域,Adam都是首选的优化器之一。下面是一个使用TensorFlow/Keras实现Adam的例子:

```python

import tensorflow as tf

from tensorflow.keras.models import Sequential

from tensorflow.keras.layers import Dense

创建模型

model = Sequential([Dense(1, input_shape=(8,))])

使用Adam优化器

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

编译模型

model.compile(optimizer=optimizer, loss='mse')

假设我们有一些数据x_train和y_train

model.fit(x_train, y_train, epochs=10)

```

在这个例子中,`learning_rate` 参数可以调整以适应特定任务的需求,但默认值通常已经足够有效。此外,Keras的Adam优化器还允许进一步定制化,比如调整 \( \beta_1 \) 和 \( \beta_2 \) 的值等。

相关推荐
小明_GLC11 分钟前
Falcon-TST: A Large-Scale Time Series Foundation Model
论文阅读·人工智能·深度学习·transformer
Python_Study202511 分钟前
制造业数据采集系统选型指南:从技术挑战到架构实践
大数据·网络·数据结构·人工智能·架构
一只大侠的侠15 分钟前
【工业AI热榜】LSTM+GRU融合实战:设备故障预测准确率99.3%,附开源数据集与完整代码
人工智能·gru·lstm
weisian15123 分钟前
入门篇--知名企业-26-华为-2--华为VS阿里:两种科技路径的较量与共生
人工智能·科技·华为·阿里
棒棒的皮皮29 分钟前
【深度学习】YOLO模型精度优化 Checklist
人工智能·深度学习·yolo·计算机视觉
微尘hjx29 分钟前
【数据集 01】家庭室内烟火数据集(按比例划分训练、验证、测试)包含训练好的yolo11/yolov8模型
深度学习·yolov8·yolo11·训练模型·烟火数据集·家庭火灾数据集·火灾数据集
高洁0136 分钟前
CLIP 的双编码器架构是如何优化图文关联的?(2)
python·深度学习·机器学习·知识图谱
线束线缆组件品替网36 分钟前
Bulgin 防水圆形线缆在严苛环境中的工程实践
人工智能·数码相机·自动化·软件工程·智能电视
Cherry的跨界思维43 分钟前
【AI测试全栈:Vue核心】22、从零到一:Vue3+ECharts构建企业级AI测试可视化仪表盘项目实战
vue.js·人工智能·echarts·vue3·ai全栈·测试全栈·ai测试全栈
冬奇Lab44 分钟前
【Cursor进阶实战·07】OpenSpec实战:告别“凭感觉“,用规格驱动AI编程
人工智能·ai编程