BP 神经网络学习 MATLAB 函数详解及应用

BP 神经网络学习 MATLAB 函数详解及应用

一、引言

BP(Back Propagation)神经网络是一种广泛应用于模式识别、数据预测、函数逼近等众多领域的人工神经网络。MATLAB 作为一款功能强大的科学计算软件,为 BP 神经网络的实现提供了丰富而便捷的函数。本文将深入探讨 BP 神经网络学习算法在 MATLAB 中的相关函数,并通过大量的应用实例展示其在实际问题中的应用。

二、BP 神经网络概述

(一)BP 神经网络的结构

BP 神经网络通常由输入层、一个或多个隐藏层和输出层组成。输入层接收外部输入数据,每个神经元对应一个输入特征。隐藏层对输入数据进行非线性变换和特征提取,输出层输出网络的预测结果。神经元之间通过合适的权重连接,信息在网络中从前向后传播,误差则从输出层反向传播来调整权重。

(二)BP 神经网络的学习过程

  1. 前向传播
    输入数据通过网络的各层,经过加权求和和激活函数处理,最终得到输出层的输出。对于输入层神经元 i i i,其输出 x i x_i xi 就是输入数据的第 i i i 个特征值。对于隐藏层和输出层的神经元 j j j,其输入 n e t j = ∑ i w i j x i + b j net_j=\sum_{i}w_{ij}x_i + b_j netj=∑iwijxi+bj(其中 w i j w_{ij} wij 是连接上一层第 i i i 个神经元和当前层第 j j j 个神经元的权重, b j b_j bj 是当前层第 j j j 个神经元的偏置),输出 y j = f ( n e t j ) y_j = f(net_j) yj=f(netj)( f f f 为激活函数)。
  2. 反向传播
    计算输出层与真实输出之间的误差,然后将误差从输出层反向传播到隐藏层,根据误差调整网络的权重和偏置。误差的反向传播是基于链式法则计算梯度来实现的,通过不断迭代调整权重和偏置,使网络的输出尽可能接近真实输出。

三、MATLAB 中 BP 神经网络的相关函数

(一)创建神经网络对象 - newff 函数

  1. 基本语法
    net = newff(P,T,S,TF,BTF,BLF,PF,IPF,OPF,DDF)

    其中:

    • P:输入数据的最小最大值矩阵,一般通过 minmax 函数获取,格式为 [min(P) ; max(P)]
    • T:目标输出数据,用于训练网络。
    • S:一个向量,表示各层神经元的数量,例如 [n1,n2,...,nm],其中 n 1 n1 n1 是输入层神经元数量, n m nm nm 是输出层神经元数量,中间是隐藏层神经元数量。
    • TF:每层的激活函数,可以是字符串或单元数组,如 {'tansig','logsig','purelin'}
    • BTF:训练函数,例如 'trainlm'(Levenberg - Marquardt 算法)、'traingd'(梯度下降算法)等。
    • BLF:性能函数,用于评估网络的性能,如 'mse'(均方误差)。
    • PF:绘图函数,用于在训练过程中绘制相关信息。
    • IPF:输入处理函数。
    • OPF:输出处理函数。
    • DDF:权重和偏置的导数函数。
  2. 示例代码

    以下是创建一个简单的 BP 神经网络的示例,用于拟合一个非线性函数 y = 2 x 2 + 3 x + 1 y = 2x^2 + 3x + 1 y=2x2+3x+1,这里假设输入数据范围在 [ 0 , 5 ] [0, 5] [0,5],生成一些训练数据:

matlab 复制代码
% 生成训练数据
x = 0:0.1:5;
y = 2*x.^2 + 3*x + 1;

% 数据预处理
P = [x'];
T = [y'];
net = newff(minmax(P),T,[5,1],{'tansig','purelin'});

在这个示例中,输入层有 1 个神经元(因为输入数据是一维的),隐藏层有 5 个神经元,输出层有 1 个神经元,隐藏层激活函数使用 tansig,输出层激活函数使用 purelin

(二)训练神经网络 - train 函数

  1. 基本语法
    [net,tr] = train(net,P,T)

    其中 net 是要训练的神经网络对象,P 是输入数据,T 是目标数据,tr 是训练记录,包含训练过程中的步数、性能等信息。

  2. 示例代码

    继续上面的示例,使用梯度下降算法(traingd)训练网络:

matlab 复制代码
net.trainFcn = 'traingd';
net.trainParam.epochs = 100; % 训练轮数
net.trainParam.lr = 0.1; % 学习率

[net,tr] = train(net,P,T);

这里设置了训练轮数为 100,学习率为 0.1。训练完成后,可以查看训练记录 tr 的相关信息,如训练过程中的误差变化。

(三)模拟网络输出 - sim 函数

  1. 基本语法
    Y = sim(net,P)

    用于使用训练好的网络 net 对输入数据 P 进行模拟输出。

  2. 示例代码

    在训练完上述网络后,可以使用新的数据进行预测:

matlab 复制代码
new_x = 0.2:0.1:5.2;
new_P = [new_x'];
output = sim(net,new_P);
plot(new_x,output,'r',x,y,'o'); % 绘制预测结果和原始数据
legend('预测结果','原始数据');

这将绘制出网络对新数据的预测结果,并与原始数据进行对比。

四、不同应用场景下的 BP 神经网络实例

(一)函数逼近

  1. 复杂函数逼近
    考虑逼近函数 f ( x ) = sin ⁡ ( x ) + 0.5 x 2 − 0.2 x f(x)=\sin(x) + 0.5x^2 - 0.2x f(x)=sin(x)+0.5x2−0.2x,在区间 [ − 5 , 5 ] [-5,5] [−5,5] 上生成训练数据:
matlab 复制代码
x_data = -5:0.2:5;
y_data = sin(x_data) + 0.5*x_data.^2 - 0.2*x_data;
P = [x_data'];
T = [y_data'];

net = newff(minmax(P),T,[10,1],{'tansig','purelin'});
net.trainFcn = 'trainlm';
net.trainParam.epochs = 200;
net.trainParam.goal = 0.001;

[net,tr] = train(net,P,T);

new_x_data = -5.2:0.2:5.2;
new_P = [new_x_data'];
approx_output = sim(net,new_P);
plot(x_data,y_data,'b',new_x_data,approx_output,'r');
legend('原始函数','逼近结果');

在这个例子中,通过调整网络结构和训练参数,使用 Levenberg - Marquardt 算法训练网络来逼近复杂函数,最后绘制出原始函数和逼近结果的对比图。

(二)模式识别 - 手写数字识别(简化示例)

  1. 数据准备
    假设已经有一个简单的手写数字图像数据集,图像大小为 8 × 8 8\times8 8×8,将其转换为向量形式作为输入。每个数字类别对应一个目标输出向量(例如,数字 0 可以用 [ 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ] [1,0,0,0,0,0,0,0,0,0] [1,0,0,0,0,0,0,0,0,0] 表示)。
matlab 复制代码
% 假设已经加载图像数据和标签,存储在 image_data 和 labels 中
image_data = reshape(image_data, [], size(image_data,3)); % 转换为向量
input_data = double(image_data) / 255; % 归一化

% 目标输出数据处理
num_classes = 10; % 数字 0 - 9
target_data = zeros(size(image_data,3), num_classes);
for i = 1:size(image_data,3)
    target_data(i,labels(i)+1) = 1;
end

net = newff(minmax(input_data),target_data,[64,20,10],{'logsig','logsig','softmax'});
net.trainFcn = 'traingda'; % 自适应学习率算法
net.trainParam.epochs = 100;
net.trainParam.lr = 0.01;

[net,tr] = train(net,input_data,target_data);

% 在测试集上测试(假设已经有测试数据 test_image_data 和 test_labels)
test_image_data = reshape(test_image_data, [], size(test_image_data,3));
test_input_data = double(test_image_data) / 255;
outputs = sim(net,test_input_data);
[~, predicted_labels] = max(outputs,[],2);
accuracy = sum(predicted_labels == test_labels) / length(test_labels);
disp(['手写数字识别准确率: ', num2str(accuracy * 100), '%']);

这个示例展示了如何使用 BP 神经网络进行手写数字识别,包括数据预处理、网络创建、训练和测试过程,使用自适应学习率算法来提高训练效果。

(三)时间序列预测 - 股票价格预测

  1. 数据准备与处理
    假设已经有某股票的历史价格数据 stock_prices,这里使用过去 10 天的价格来预测第 11 天的价格。
matlab 复制代码
input_data = [];
target_data = [];
window_size = 10;
for i = 1:length(stock_prices) - window_size
    input_data = [input_data; stock_prices(i:i + window_size - 1)'];
    target_data = [target_data; stock_prices(i + window_size)];
end

net = newff(minmax(input_data),target_data,[10,5,1],{'tansig','tansig','purelin'});
net.trainFcn = 'trainlm';
net.trainParam.epochs = 50;
net.trainParam.goal = 0.01;

[net,tr] = train(net,input_data,target_data);

% 预测未来价格
last_10_prices = stock_prices(end - 9:end);
next_price = sim(net,[last_10_prices']);
disp(['预测的下一天股票价格: ', num2str(next_price)]);

在这个股票价格预测的例子中,构建了合适的输入和目标数据,创建并训练 BP 神经网络,最后使用训练好的网络对下一天的股票价格进行预测。

五、BP 神经网络训练参数的选择与优化

(一)学习率的选择

学习率决定了每次权重更新的步长。如果学习率过大,可能导致网络无法收敛甚至发散;如果学习率过小,训练过程会非常缓慢。可以通过试验不同的值来选择合适的学习率,例如在上述的例子中,可以尝试从 0.001 0.001 0.001 到 0.5 0.5 0.5 之间的值,观察训练过程中的误差变化和收敛速度。

(二)训练轮数的确定

训练轮数(epochs)是训练过程的迭代次数。过少的轮数可能导致网络未充分训练,而过多的轮数可能导致过拟合。可以通过观察训练过程中的性能指标(如均方误差)在验证集上的变化来确定合适的训练轮数。如果在验证集上误差开始上升,说明可能出现了过拟合,此时的训练轮数可能已经过多。

(三)隐藏层神经元数量的调整

隐藏层神经元数量对网络的性能有很大影响。过少的神经元可能无法学习到数据中的复杂模式,而过多的神经元可能导致过拟合。可以通过逐步增加或减少神经元数量,观察网络在训练集和验证集上的性能来调整。

六、总结

本文详细介绍了 BP 神经网络学习算法在 MATLAB 中的主要函数,包括 newfftrainsim 函数的用法,并通过函数逼近、手写数字识别和股票价格预测等多个应用实例展示了如何使用这些函数解决实际问题。同时,还讨论了训练参数的选择和优化方法。BP 神经网络在 MATLAB 中的实现为解决各种复杂的机器学习和数据处理问题提供了有力的工具,通过合理选择网络结构和训练参数,可以在不同领域取得较好的应用效果。在实际应用中,需要根据具体问题的特点不断尝试和改进,以提高网络的性能和泛化能力。

相关推荐
飞Link2 分钟前
智能体时代的“紧箍咒”:深度解析 Agent 治理架构与 AI 杀伤开关
人工智能·架构
飞Link8 分钟前
2000 亿砸向算力:字节跳动 AI 基建跨越,后端与运维的“万亿 Token”生死战
运维·人工智能
zhangfeng113321 分钟前
小龙虾 wordbuddy 安装浏览器控制器 agent-browser npm install -g agent-browse
前端·人工智能·npm·node.js
阿里云大数据AI技术21 分钟前
一条 SQL 生成广告:Hologres 如何实现素材生成到投放分析一体化
人工智能·sql
liudanzhengxi29 分钟前
GitSubmodule避坑全攻略
人工智能·新人首发
用户4252108006031 分钟前
Claude Code Linux 服务器部署与配置
人工智能
OJAC11134 分钟前
学过Python却不敢投AI岗,他最后拿下12K offer
人工智能
Bigger34 分钟前
因为看不懂小棉袄的画,我写了个 AI 程序帮我“翻译”她的世界
前端·人工智能·ai编程
CeshirenTester37 分钟前
LangChain的工具调用 vs 原生Skill API:性能差在哪儿?
java·人工智能·langchain
爱问的艾文1 小时前
八周带你手搓AI应用-第二周-让AI更像人-第1天-流式输出改造
人工智能