深入理解前向传播、反向传播和计算图

1. 什么是前向传播?

前向传播(Forward Propagation)是神经网络的推理过程。它将输入数据逐层传递,通过每一层的神经元计算,最终生成输出。

前向传播的公式

假设我们有一个简单的三层神经网络(输入层、一个隐藏层和输出层),网络的每一层计算如下:

z ( 1 ) = W ( 1 ) ⋅ X + b ( 1 ) z^{(1)} = W^{(1)} \cdot X + b^{(1)} z(1)=W(1)⋅X+b(1)
a ( 1 ) = σ ( z ( 1 ) ) a^{(1)} = \sigma(z^{(1)}) a(1)=σ(z(1))
z ( 2 ) = W ( 2 ) ⋅ a ( 1 ) + b ( 2 ) z^{(2)} = W^{(2)} \cdot a^{(1)} + b^{(2)} z(2)=W(2)⋅a(1)+b(2)
y ^ = a ( 2 ) = softmax ( z ( 2 ) ) \hat{y} = a^{(2)} = \text{softmax}(z^{(2)}) y^=a(2)=softmax(z(2))

其中,(W) 和 (b) 分别是权重矩阵和偏置,(\sigma) 是激活函数,(\hat{y}) 是网络的输出。

代码示例

我们用 Python 和 NumPy 来实现前向传播的过程:

python 复制代码
import numpy as np

# 输入数据和网络参数
X = np.array([[0.5, 1.2]])
W1 = np.array([[0.4, 0.3], [0.2, 0.7]])
b1 = np.array([[0.1, 0.2]])
W2 = np.array([[0.6], [0.8]])
b2 = np.array([[0.3]])

# 激活函数
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# 前向传播
z1 = np.dot(X, W1) + b1
a1 = sigmoid(z1)
z2 = np.dot(a1, W2) + b2
output = sigmoid(z2)

print("网络输出:", output)

2. 什么是反向传播?

反向传播(Backpropagation)是神经网络的训练过程,它通过计算损失函数的梯度来更新权重,从而最小化损失。

反向传播的原理

反向传播通过链式法则计算梯度。假设我们的损失函数是均方误差(MSE):

L = 1 2 ( y ^ − y ) 2 L = \frac{1}{2} (\hat{y} - y)^2 L=21(y^−y)2

对于每个权重 (W),梯度更新规则为:

∂ L ∂ W ( 2 ) = δ ( 2 ) ⋅ a ( 1 ) \frac{\partial L}{\partial W^{(2)}} = \delta^{(2)} \cdot a^{(1)} ∂W(2)∂L=δ(2)⋅a(1)

其中,(\delta^{(2)}) 是输出层的误差:

δ ( 2 ) = y ^ − y \delta^{(2)} = \hat{y} - y δ(2)=y^−y

代码示例

我们可以用 Python 代码实现一个简单的反向传播:

python 复制代码
# 假设真实标签
y = np.array([[1]])

# 计算损失
loss = 0.5 * (output - y) ** 2

# 反向传播
d_output = output - y
d_z2 = d_output * output * (1 - output)
d_W2 = np.dot(a1.T, d_z2)

d_a1 = np.dot(d_z2, W2.T)
d_z1 = d_a1 * a1 * (1 - a1)
d_W1 = np.dot(X.T, d_z1)

# 更新权重
learning_rate = 0.1
W2 -= learning_rate * d_W2
W1 -= learning_rate * d_W1

print("更新后的 W1:", W1)
print("更新后的 W2:", W2)

3. 理解计算图

计算图(Computational Graph)是表示神经网络中计算过程的图形结构。通过计算图,我们可以直观地理解前向传播和反向传播。

在计算图中,节点表示操作(如加法、乘法、激活函数)或变量(如输入、权重),边表示数据流动。通过前向传播,计算图中的数据从输入流向输出;通过反向传播,梯度从输出反向传播到输入。

4. 实战案例:手写数字识别

让我们用一个简单的手写数字识别的案例来巩固这些概念。我们将使用一个小型的神经网络,通过前向传播预测数字,通过反向传播调整权重。

python 复制代码
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# 加载数据
digits = load_digits()
X = digits.data / 16.0
y = digits.target.reshape(-1, 1)

# 独热编码标签
encoder = OneHotEncoder(sparse=False)
y_onehot = encoder.fit_transform(y)

# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X, y_onehot, test_size=0.3)

# 初始化网络参数
input_size = X_train.shape[1]
hidden_size = 64
output_size = y_train.shape[1]

W1 = np.random.randn(input_size, hidden_size)
b1 = np.zeros((1, hidden_size))
W2 = np.random.randn(hidden_size, output_size)
b2 = np.zeros((1, output_size))

# 训练过程
epochs = 1000
learning_rate = 0.01

for epoch in range(epochs):
    # 前向传播
    z1 = np.dot(X_train, W1) + b1
    a1 = sigmoid(z1)
    z2 = np.dot(a1, W2) + b2
    output = sigmoid(z2)
    
    # 计算损失
    loss = np.mean(0.5 * (output - y_train) ** 2)
    
    # 反向传播
    d_output = output - y_train
    d_z2 = d_output * output * (1 - output)
    d_W2 = np.dot(a1.T, d_z2)
    
    d_a1 = np.dot(d_z2, W2.T)
    d_z1 = d_a1 * a1 * (1 - a1)
    d_W1 = np.dot(X_train.T, d_z1)
    
    # 更新权重
    W2 -= learning_rate * d_W2
    W1 -= learning_rate * d_W1
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

通过上述代码,我们可以训练一个简单的神经网络来识别手写数字。在每个训练周期,网络通过前向传播计算输出,通过反向传播调整权重。

5. 总结

在本文中,我们深入探讨了前向传播、反向传播和计算图的概念,并通过代码示例和图示帮助理解这些复杂的过程。希望这些内容能帮助你更好地理解神经网络的工作原理。


通过这些详细的解释、代码示例和图示,你的读者应该能够深入理解前向传播、反向传播和计算图在神经网络中的作用。如果需要进一步调整内容,随时可以进行修改。

相关推荐
CountingStars6194 分钟前
目标检测常用评估指标(metrics)
人工智能·目标检测·目标跟踪
tangjunjun-owen12 分钟前
第四节:GLM-4v-9b模型的tokenizer源码解读
人工智能·glm-4v-9b·多模态大模型教程
冰蓝蓝17 分钟前
深度学习中的注意力机制:解锁智能模型的新视角
人工智能·深度学习
橙子小哥的代码世界25 分钟前
【计算机视觉基础CV-图像分类】01- 从历史源头到深度时代:一文读懂计算机视觉的进化脉络、核心任务与产业蓝图
人工智能·计算机视觉
黄公子学安全35 分钟前
Java的基础概念(一)
java·开发语言·python
新加坡内哥谈技术1 小时前
苏黎世联邦理工学院与加州大学伯克利分校推出MaxInfoRL:平衡内在与外在探索的全新强化学习框架
大数据·人工智能·语言模型
程序员一诺1 小时前
【Python使用】嘿马python高级进阶全体系教程第10篇:静态Web服务器-返回固定页面数据,1. 开发自己的静态Web服务器【附代码文档】
后端·python
小木_.1 小时前
【Python 图片下载器】一款专门为爬虫制作的图片下载器,多线程下载,速度快,支持续传/图片缩放/图片压缩/图片转换
爬虫·python·学习·分享·批量下载·图片下载器
fanstuck2 小时前
Prompt提示工程上手指南(七)Prompt编写实战-基于智能客服问答系统下的Prompt编写
人工智能·数据挖掘·openai
lovelin+v175030409662 小时前
安全性升级:API接口在零信任架构下的安全防护策略
大数据·数据库·人工智能·爬虫·数据分析