深入理解前向传播、反向传播和计算图

1. 什么是前向传播?

前向传播(Forward Propagation)是神经网络的推理过程。它将输入数据逐层传递,通过每一层的神经元计算,最终生成输出。

前向传播的公式

假设我们有一个简单的三层神经网络(输入层、一个隐藏层和输出层),网络的每一层计算如下:

z ( 1 ) = W ( 1 ) ⋅ X + b ( 1 ) z^{(1)} = W^{(1)} \cdot X + b^{(1)} z(1)=W(1)⋅X+b(1)
a ( 1 ) = σ ( z ( 1 ) ) a^{(1)} = \sigma(z^{(1)}) a(1)=σ(z(1))
z ( 2 ) = W ( 2 ) ⋅ a ( 1 ) + b ( 2 ) z^{(2)} = W^{(2)} \cdot a^{(1)} + b^{(2)} z(2)=W(2)⋅a(1)+b(2)
y ^ = a ( 2 ) = softmax ( z ( 2 ) ) \hat{y} = a^{(2)} = \text{softmax}(z^{(2)}) y^=a(2)=softmax(z(2))

其中,(W) 和 (b) 分别是权重矩阵和偏置,(\sigma) 是激活函数,(\hat{y}) 是网络的输出。

代码示例

我们用 Python 和 NumPy 来实现前向传播的过程:

python 复制代码
import numpy as np

# 输入数据和网络参数
X = np.array([[0.5, 1.2]])
W1 = np.array([[0.4, 0.3], [0.2, 0.7]])
b1 = np.array([[0.1, 0.2]])
W2 = np.array([[0.6], [0.8]])
b2 = np.array([[0.3]])

# 激活函数
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# 前向传播
z1 = np.dot(X, W1) + b1
a1 = sigmoid(z1)
z2 = np.dot(a1, W2) + b2
output = sigmoid(z2)

print("网络输出:", output)

2. 什么是反向传播?

反向传播(Backpropagation)是神经网络的训练过程,它通过计算损失函数的梯度来更新权重,从而最小化损失。

反向传播的原理

反向传播通过链式法则计算梯度。假设我们的损失函数是均方误差(MSE):

L = 1 2 ( y ^ − y ) 2 L = \frac{1}{2} (\hat{y} - y)^2 L=21(y^−y)2

对于每个权重 (W),梯度更新规则为:

∂ L ∂ W ( 2 ) = δ ( 2 ) ⋅ a ( 1 ) \frac{\partial L}{\partial W^{(2)}} = \delta^{(2)} \cdot a^{(1)} ∂W(2)∂L=δ(2)⋅a(1)

其中,(\delta^{(2)}) 是输出层的误差:

δ ( 2 ) = y ^ − y \delta^{(2)} = \hat{y} - y δ(2)=y^−y

代码示例

我们可以用 Python 代码实现一个简单的反向传播:

python 复制代码
# 假设真实标签
y = np.array([[1]])

# 计算损失
loss = 0.5 * (output - y) ** 2

# 反向传播
d_output = output - y
d_z2 = d_output * output * (1 - output)
d_W2 = np.dot(a1.T, d_z2)

d_a1 = np.dot(d_z2, W2.T)
d_z1 = d_a1 * a1 * (1 - a1)
d_W1 = np.dot(X.T, d_z1)

# 更新权重
learning_rate = 0.1
W2 -= learning_rate * d_W2
W1 -= learning_rate * d_W1

print("更新后的 W1:", W1)
print("更新后的 W2:", W2)

3. 理解计算图

计算图(Computational Graph)是表示神经网络中计算过程的图形结构。通过计算图,我们可以直观地理解前向传播和反向传播。

在计算图中,节点表示操作(如加法、乘法、激活函数)或变量(如输入、权重),边表示数据流动。通过前向传播,计算图中的数据从输入流向输出;通过反向传播,梯度从输出反向传播到输入。

4. 实战案例:手写数字识别

让我们用一个简单的手写数字识别的案例来巩固这些概念。我们将使用一个小型的神经网络,通过前向传播预测数字,通过反向传播调整权重。

python 复制代码
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# 加载数据
digits = load_digits()
X = digits.data / 16.0
y = digits.target.reshape(-1, 1)

# 独热编码标签
encoder = OneHotEncoder(sparse=False)
y_onehot = encoder.fit_transform(y)

# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X, y_onehot, test_size=0.3)

# 初始化网络参数
input_size = X_train.shape[1]
hidden_size = 64
output_size = y_train.shape[1]

W1 = np.random.randn(input_size, hidden_size)
b1 = np.zeros((1, hidden_size))
W2 = np.random.randn(hidden_size, output_size)
b2 = np.zeros((1, output_size))

# 训练过程
epochs = 1000
learning_rate = 0.01

for epoch in range(epochs):
    # 前向传播
    z1 = np.dot(X_train, W1) + b1
    a1 = sigmoid(z1)
    z2 = np.dot(a1, W2) + b2
    output = sigmoid(z2)
    
    # 计算损失
    loss = np.mean(0.5 * (output - y_train) ** 2)
    
    # 反向传播
    d_output = output - y_train
    d_z2 = d_output * output * (1 - output)
    d_W2 = np.dot(a1.T, d_z2)
    
    d_a1 = np.dot(d_z2, W2.T)
    d_z1 = d_a1 * a1 * (1 - a1)
    d_W1 = np.dot(X_train.T, d_z1)
    
    # 更新权重
    W2 -= learning_rate * d_W2
    W1 -= learning_rate * d_W1
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

通过上述代码,我们可以训练一个简单的神经网络来识别手写数字。在每个训练周期,网络通过前向传播计算输出,通过反向传播调整权重。

5. 总结

在本文中,我们深入探讨了前向传播、反向传播和计算图的概念,并通过代码示例和图示帮助理解这些复杂的过程。希望这些内容能帮助你更好地理解神经网络的工作原理。


通过这些详细的解释、代码示例和图示,你的读者应该能够深入理解前向传播、反向传播和计算图在神经网络中的作用。如果需要进一步调整内容,随时可以进行修改。

相关推荐
yannan2019031313 分钟前
【算法】(Python)动态规划
python·算法·动态规划
埃菲尔铁塔_CV算法15 分钟前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR15 分钟前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
MarkHD18 分钟前
第十一天 线性代数基础
线性代数·决策树·机器学习
打羽毛球吗️21 分钟前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
蒙娜丽宁23 分钟前
《Python OpenCV从菜鸟到高手》——零基础进阶,开启图像处理与计算机视觉的大门!
python·opencv·计算机视觉
光芒再现dev24 分钟前
已解决,部署GPTSoVITS报错‘AsyncRequest‘ object has no attribute ‘_json_response_data‘
运维·python·gpt·语言模型·自然语言处理
好喜欢吃红柚子38 分钟前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
小馒头学python42 分钟前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
神奇夜光杯1 小时前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长