基于神经网络的弹弹堂类游戏弹道快速预测

目录

[一、 目的... 1](#一、 目的... 1)

[1.1 输入与输出.... 1](#1.1 输入与输出.... 1)

[1.2 隐网络架构设计.... 1](#1.2 隐网络架构设计.... 1)

[1.3 激活函数与损失函数.... 1](#1.3 激活函数与损失函数.... 1)

[二、 训练... 2](#二、 训练... 2)

[2.1 数据加载与预处理.... 2](#2.1 数据加载与预处理.... 2)

[2.2 训练过程.... 2](#2.2 训练过程.... 2)

[2.3 训练参数与设置.... 2](#2.3 训练参数与设置.... 2)

[三、 测试与分析... 2](#三、 测试与分析... 2)

[3.1 性能对比.... 2](#3.1 性能对比.... 2)

[3.2 训练过程差异.... 3](#3.2 训练过程差异.... 3)

[四、 训练过程中的损失变化... 3](#四、 训练过程中的损失变化... 3)

[五、 代码... 6](#五、 代码... 6)


一、目的

在机器学习中,神经网络是解决回归和分类问题的强大工具。本文通过对比全连接神经网络(SimpleNN)在不同激活函数下的表现,探索不同激活函数对模型训练过程和最终性能的影响。本实验通过使用PyTorch框架,首先使用ReLU激活函数,之后将激活函数切换为tanh,分析这两种激活函数在回归问题中的差异。

1.1 输入与输出

本实验中的神经网络模型输入的是来自MATLAB文件(data.mat)的数据集,其中包括4个输入特征和1个输出标签。数据通过标准化处理后输入神经网络,网络模型通过学习特征和标签之间的关系来预测输出。最终网络输出为一个连续值,即回归问题中的预测值。

1.2 隐网络架构设计

SimpleNN模型:

本实验使用了一个简单的前馈神经网络模型,包含一个输入层、一个隐藏层和一个输出层。输入层的节点数与特征数量相同,输出层的节点数与标签数量相同。隐藏层的节点数设置为10。激活函数用于隐藏层的神经元,以增加模型的非线性表达能力。

在此实验中,我们首先使用了ReLU激活函数进行训练,然后将激活函数替换为tanh进行对比分析。

1.3 激活函数与损失函数

  1. 激活函数选择
  • ReLU(Rectified Linear Unit):

    是一种常用的激活函数,其输出为正输入或零。ReLU有助于缓解梯度消失问题,并加速神经网络的训练。

  • tanh(双曲正切函数):

    是一种平滑的非线性激活函数,其输出范围为-1到1。与ReLU相比,tanh的输出范围较小,并且存在梯度消失的风险,但它能够处理负值输入,适用于某些回归任务。

  1. 损失函数选择
    本实验使用均方误差(MSE)作为损失函数,用于回归任务中度量模型预测与真实输出之间的差异。

二、训练

2.1 数据加载与预处理

数据集来自MATLAB的.mat文件。输入特征(4个)和输出标签(1个)首先被提取,并通过MinMaxScaler进行归一化处理。数据集被随机分割为训练集和测试集,其中50个样本用于测试,剩余的用于训练。

2.2 训练过程

网络通过3000次迭代进行训练。在每一次迭代中,模型使用训练数据进行前向传播,计算预测结果与真实标签之间的损失。然后进行反向传播,更新网络的参数。训练的停止条件为损失低于设定阈值(1e-14)。

2.3 训练参数与设置

训练过程中使用的主要参数如下:

  • 学习率: 0.001
  • 训练轮次: 最大3000次,或提前停止
  • 损失函数: 均方误差(MSE)损失函数
  • 优化器: Adam

三、测试与分析

3.1 性能对比

  • 使用ReLU激活函数时:

    在训练过程中,模型的损失函数逐渐下降,表现出良好的学习效果。最终损失值趋近于0,表明网络能够较好地拟合训练数据。测试时,模型能够有效地预测测试集的数据,偏差较小。

  • 使用tanh激活函数时:

    与ReLU相比,使用tanh激活函数时,损失下降的速度较慢,且网络训练的初期出现较大的波动。这可能与tanh的输出范围(-1到1)有关,导致梯度消失问题,尤其是在多层网络中。

3.2 训练过程差异

  1. 收敛速度
  • ReLU: 在训练初期收敛较快,且表现出较好的梯度更新能力。在训练过程中,模型的准确性和损失函数下降速度较为平稳。

  • tanh: 收敛速度较慢,且在训练初期存在较大的梯度波动。由于其在负输入下的饱和特性,可能导致梯度更新较慢,尤其是在深层网络中。

  1. 偏差分析
  • 使用ReLU时: 偏差较小,模型预测与实际值之间的差异较少,说明模型具有较好的预测能力。

  • 使用tanh时: 偏差稍大,尤其是在某些测试样本上。虽然损失函数已经较低,但由于tanh的输出范围限制,模型在某些输入上可能无法达到完全准确的预测。


四、训练过程中的损失变化

图 1: ReLU训练损失曲线
图 2: ReLU测试数据集结果图
图 3: tanh训练损失曲线
图 4: tanh测试数据集结果图

|----------------------------------------------------------------------------|
| |
| 图****1 relu 训练损失曲线 |

|----------------------------------------------------------------------------|
| |
| 图****2 relu 测试数据集结果图 |

|----------------------------------------------------------------------------|
| |
| 图****3 tanh 训练损失曲线 |

|----------------------------------------------------------------------------|
| |
| 图****4 tanh 测试数据集结果图 |

  • 代码

|--------------------------------------------------------------------------------------------------------------------------|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from sklearn.preprocessing import MinMaxScaler |
| import matplotlib.pyplot as plt |
| import scipy.io |
| # 加载 .mat 文件(替换为实际的路径) |
| data = scipy.io.loadmat( 'E: \\ Learn Project \\ matlab_pjt \\ data.mat' ) |
| # 获取数据 |
| data = data[ 'data' ] |
| # 输入和输出数据 |
| inputs = data[:, : 4 ] # 输入特征 |
| outputs = data[:, 4 :] # 输出标签 |
| # 随机分割数据为训练集和测试集 |
| test_size = 50 # 测试集大小 |
| indices = np.random.permutation( len (inputs)) |
| train_indices = indices[test_size:] |
| test_indices = indices[:test_size] |
| input_train = inputs[train_indices] |
| output_train = outputs[train_indices] |
| input_test = inputs[test_indices] |
| output_test = outputs[test_indices] |
| # 数据归一化 |
| scaler_input = MinMaxScaler() |
| scaler_output = MinMaxScaler() |
| input_train_scaled = scaler_input.fit_transform(input_train) |
| output_train_scaled = scaler_output.fit_transform(output_train) |
| input_test_scaled = scaler_input.transform(input_test) |
| output_test_scaled = scaler_output.transform(output_test) |
| # 转换为 PyTorch 张量 |
| X_train_tensor = torch.tensor(input_train_scaled, dtype =torch.float32) |
| y_train_tensor = torch.tensor(output_train_scaled, dtype =torch.float32) |
| X_test_tensor = torch.tensor(input_test_scaled, dtype =torch.float32) |
| y_test_tensor = torch.tensor(output_test_scaled, dtype =torch.float32) |
| # 定义简单的神经网络 |
| class SimpleNN(nn.Module): |
| def init ( self , input_size, hidden_size, output_size): |
| super (SimpleNN, self ). init () |
| self .layer1 = nn.Linear(input_size, hidden_size) |
| self .layer2 = nn.Linear(hidden_size, output_size) |
| def forward ( self , x): |
| x = torch.relu( self .layer1(x)) # 激活函数 ReLU |
| x = self .layer2(x) # 输出层 |
| return x |
| # 网络参数 |
| input_size = input_train.shape[ 1 ] |
| hidden_size = 10 |
| output_size = output_train.shape[ 1 ] |
| # 创建模型 |
| model = SimpleNN(input_size, hidden_size, output_size) |
| # 损失函数和优化器 |
| criterion = nn.MSELoss() # 均方误差损失 |
| optimizer = optim.Adam(model.parameters(), lr = 0.001 ) |
| # 训练模型 |
| epochs = 3000 |
| loss_history = [] # 保存损失变化 |
| for epoch in range (epochs): |
| optimizer.zero_grad() # 清空梯度 |
| output = model(X_train_tensor) # 前向传播 |
| loss = criterion(output, y_train_tensor) # 计算损失 |
| loss.backward() # 反向传播 |
| optimizer.step() # 更新参数 |
| # 记录损失 |
| loss_history.append(loss.item()) |
| # 停止条件 |
| if loss.item() < 1e-14 : |
| print ( f" 训练提前停止,当前迭代: { epoch } " ) |
| break |
| # 绘制训练损失图 |
| plt.plot(loss_history) |
| plt.xlabel( 'Epoch' ) |
| plt.ylabel( 'Loss (MSE)' ) |
| plt.title( 'Training Loss History' ) |
| plt.show() |
| # 测试模型 |
| with torch.no_grad(): |
| model.eval() # 设置模型为评估模式 |
| y_test_pred_scaled = model(X_test_tensor) # 预测 |
| y_test_pred = scaler_output.inverse_transform(y_test_pred_scaled.numpy()) # 反归一化 |
| # 计算每个样本的偏差 |
| deviation = np.sqrt(np.sum((output_test - y_test_pred) ** 2 , axis = 1 )) # 欧几里得距离 |
| # 绘制偏差图 |
| plt.plot(deviation, marker = 'o' , color = 'red' ) |
| plt.xlabel( 'Sample Index' ) |
| plt.ylabel( 'Deviation' ) |
| plt.title( 'Test Deviation' ) |
| plt.show() |

相关推荐
Tttian6221 小时前
Python办公自动化(3)对Excel的操作
开发语言·python·excel
xyliiiiiL1 小时前
ZGC初步了解
java·jvm·算法
爱的叹息1 小时前
RedisTemplate 的 6 个可配置序列化器属性对比
算法·哈希算法
独好紫罗兰2 小时前
洛谷题单2-P5713 【深基3.例5】洛谷团队系统-python-流程图重构
开发语言·python·算法
每次的天空3 小时前
Android学习总结之算法篇四(字符串)
android·学习·算法
闪电麦坤953 小时前
C#:base 关键字
开发语言·c#
Mason Lin3 小时前
2025年3月29日(matlab -ss -lti)
开发语言·matlab
请来次降维打击!!!3 小时前
优选算法系列(5.位运算)
java·前端·c++·算法
qystca3 小时前
蓝桥云客 刷题统计
算法·模拟
别NULL3 小时前
机试题——统计最少媒体包发送源个数
c++·算法·媒体