神经网络基础-神经网络补充概念-49-adam优化算法

概念

Adam(Adaptive Moment Estimation)是一种优化算法,结合了动量梯度下降法和RMSProp的优点,用于在训练神经网络等深度学习模型时自适应地调整学习率。Adam算法在深度学习中广泛应用,通常能够加速收敛并提高模型性能。

Adam算法综合了动量(momentum)和均方梯度的移动平均(RMSProp)来更新模型参数。与传统的梯度下降法不同,Adam维护了一个每个参数的动量变量和均方梯度的移动平均变量,并在每个迭代步骤中使用这些变量来调整学习率。

步骤

1初始化参数:初始化模型的参数。

2初始化动量变量和均方梯度的移动平均:初始化动量变量为零向量,初始化均方梯度的移动平均为零向量。

3计算梯度:计算当前位置的梯度。

4更新动量变量:计算动量变量的移动平均。

python 复制代码
momentum = beta1 * momentum + (1 - beta1) * gradient

其中,beta1 是用于计算动量变量移动平均的超参数。

5更新均方梯度的移动平均:计算均方梯度的移动平均。

python 复制代码
moving_average = beta2 * moving_average + (1 - beta2) * gradient^2

其中,beta2 是用于计算均方梯度的移动平均的超参数

6修正偏差

对动量变量和均方梯度的移动平均进行偏差修正,以减轻初始迭代的影响。

python 复制代码
corrected_momentum = momentum / (1 - beta1^t)
corrected_moving_average = moving_average / (1 - beta2^t)

7更新参数

python 复制代码
parameter = parameter - learning_rate * corrected_momentum / (sqrt(corrected_moving_average) + epsilon)

其中,epsilon 是一个小的常数,防止分母为零。

8重复迭代:重复执行步骤 3 到 7,直到达到预定的迭代次数(epochs)或收敛条件。

代码实现

python 复制代码
import numpy as np
import matplotlib.pyplot as plt

# 生成随机数据
np.random.seed(0)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)

# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]

# 初始化参数
theta = np.random.randn(2, 1)

# 学习率
learning_rate = 0.1

# Adam参数
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
momentum = np.zeros_like(theta)
moving_average = np.zeros_like(theta)

# 迭代次数
n_iterations = 1000

# Adam优化
for iteration in range(n_iterations):
    gradients = 2 / 100 * X_b.T.dot(X_b.dot(theta) - y)
    momentum = beta1 * momentum + (1 - beta1) * gradients
    moving_average = beta2 * moving_average + (1 - beta2) * gradients**2
    corrected_momentum = momentum / (1 - beta1**(iteration+1))
    corrected_moving_average = moving_average / (1 - beta2**(iteration+1))
    theta = theta - learning_rate * corrected_momentum / (np.sqrt(corrected_moving_average) + epsilon)

# 绘制数据和拟合直线
plt.scatter(X, y)
plt.plot(X, X_b.dot(theta), color='red')
plt.xlabel('X')
plt.ylabel('y')
plt.title('Linear Regression
相关推荐
KuaFuAI11 分钟前
百度“秒哒”能开始内测了?李彦宏:假!
人工智能·百度·aigc·码上飞·ai产品榜·一句话生成一个应用
羑悻的小杀马特15 分钟前
计算机视觉:撕裂时空的视觉算法革命狂潮
人工智能·算法·计算机视觉
gentle_ice15 分钟前
leetcode——搜索二维矩阵II(java)
java·算法·leetcode·矩阵
l1m0_16 分钟前
什么是波士顿矩阵,怎么制作?AI工具一键生成战略分析图!
人工智能·ai·信息可视化·矩阵·aigc·波士顿矩阵
Icomi_18 分钟前
【PyTorch】3.张量类型转换
c语言·c++·人工智能·pytorch·python·深度学习·神经网络
OTWOL27 分钟前
八种排序算法【C语言实现】
c语言·算法·排序算法
Doopny@35 分钟前
求阶乘(信息学奥赛一本通-2019)
算法
GISer Liu1 小时前
深入理解Transformer中的解码器原理(Decoder)与掩码机制
开发语言·人工智能·python·深度学习·机器学习·llm·transformer
金融OG1 小时前
6. 马科维茨资产组合模型+政策意图AI金融智能体(DeepSeek-V3)增强方案(理论+Python实战)
大数据·人工智能·python·算法·机器学习·数学建模·金融
PaLu-LI1 小时前
ORB-SLAM2源码学习:Initializer.cc(11): Initializer::ReconstructH用H矩阵恢复R, t和三维点
c++·人工智能·学习·ubuntu·计算机视觉·矩阵