基于MLP回归的鸢尾花花瓣长度预测

基于MLP回归的鸢尾花花瓣长度预测

基于MLP回归的鸢尾花花瓣长度预测

1.作者介绍

郝梦月,女,西安工程大学电子信息学院,2024级研究生

研究方向:模式识别与智能系统

电子邮件:479997163@qq.com

王晓睿,男,西安工程大学电子信息学院,2024级研究生,张宏伟人工智能课题组

研究方向:智能视觉检测与工业自动化技术

电子邮件:3234002295@qq.com

2.MLP算法原理

2.1 MLP简介

多层感知器(MLP)是为了创建决策边界,把多个感知器合并成为一个更大的网络。MLP一般至少由三层组成,其中第一层为数据集的每个输入特征,都有一个节点,最后一层有每个类标签的结点。它的显著特性是:如果网络足够大,那么他们可以表示任意的数学函数;且非线性映射能力和适应性强,能处理各种复杂的模式识别问题。

MLP神经网络属于前馈神经网络的一种。在网络训练过程中,需要通过反向传播算法计算梯度,将误差从输出层反向传播回输入层,用于更新网络参数。

2.2 MLP组成

(1)输入层:用于接收输入数据并将其传递到隐藏层。输入层中的神经元数量等于输入特征的数量。

(2)隐藏层:由一层或多层神经元组成,用于执行计算并转换输入数据。可以通过调整每层中的隐藏层和神经元的数量,以优化网络性能。

(3)激活函数:对隐藏层中每个神经元的输出应用非线性变换。常见的激活函数包括 Sigmoid、ReLU、tanh 等。

(4)输出层:网络的最终输出,例如分类标签或回归目标。输出层中的神经元数量取决于具体的数据,例如分类问题中的类别数量。

(5)权重和偏差:可调节参数,决定相邻层神经元之间的连接强度以及每个神经元的偏差。这些参数在训练过程中学习,以尽量减少网络预测与实际目标值之间的差异。

(6)损失函数:衡量网络预测与实际目标值之间的差异。MLP 的常见损失函数包括回归任务的均方误差和分类任务的交叉熵。

3.基于MLP回归的鸢尾花花瓣长度预测

3.1 数据集介绍

(1)scikit-learn(sklearn)库中常用的内置数据集。

(2)鸢尾花数据集(Iris Dataset)包含150个样本,将其分为3类,每类50个样本,每个样本有4个特征,分别是:花萼的长度和宽度、花瓣的长度和宽度,单位为厘米。

3.2 实验步骤

(1)导入必要的库。

(2)加载数据集并对其进行测试集和训练集的划分。本实验将数据集的20%划分为测试集,80%作为训练集。设置随机种子为30,确保每次运行代码时划分的结果一致。

(3)将数据进行标准化处理,消除量纲之间的差异,使得模型训练更加稳定和高效。

(4)调整参数,最终设置了三个隐藏层,每层10个神经元。激活函数采用relu(线性整流函数),优化算法采用adam(适应性矩估计),最大迭代次数设置为1000次。

(5)训练模型并计算相关参数MSE和R2。

(6)可视化实验数据。

3.3 实验代码

python 复制代码
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data  # 特征矩阵 (150, 4)
y = iris.data[:, 2]  # 花瓣长度作为目标变量
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=30)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 创建MLP回归模型
mlp = MLPRegressor(hidden_layer_sizes=(10, 10, 10),  # 三个隐藏层,每层10个神经元
                   activation='relu',          # 激活函数
                   solver='adam',              # 优化器
                   max_iter=1000,             # 最大迭代次数
                   random_state=30)

# 训练模型
mlp.fit(X_train, y_train)
# 在测试集上进行预测
y_pred = mlp.predict(X_test)
# 计算均方误差和R²分数
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Mean Squared Error: {mse:.4f}")
print(f"R² Score: {r2:.4f}")

# 绘制真实值与预测值的对比图
plt.scatter(y_test, y_pred, color='blue')
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red', linestyle='--')
plt.xlabel('True Petal Length')
plt.ylabel('Predicted Petal Length')
plt.title('True vs Predicted Petal Length')
plt.show()

3.4 实验结果

MSE越小且R2接近1,说明模型拟合良好。其中,MSE衡量预测误差;R2衡量模型的解释能力。

横轴表示真实的花瓣长度,纵轴表示模型预测的花瓣长度。如果预测完全准确,所有数据点应落在对角线上。本模型所有预测的数据点紧密围绕在对角线周围,说明该模型性能较好。

相关推荐
那就摆吧12 分钟前
U-Net vs. 传统CNN:为什么医学图像分割需要跳过连接?
人工智能·神经网络·cnn·u-net·医学图像
深度学习实战训练营23 分钟前
中英混合的语音识别XPhoneBERT 监督的音频到音素的编码器结合 f0 特征LID
人工智能·音视频·语音识别
WADesk---瓜子31 分钟前
用 AI 自动生成口型同步视频,短视频内容也能一人完成
人工智能·音视频·语音识别·流量运营·用户运营
星环科技TDH社区版39 分钟前
AI Agent 的 10 种应用场景:物联网、RAG 与灾难响应
人工智能·物联网
时序之心1 小时前
ICML 2025 | 深度剖析时序 Transformer:为何有效,瓶颈何在?
人工智能·深度学习·transformer
希艾席帝恩1 小时前
拥抱智慧物流时代:数字孪生技术的应用与前景
大数据·人工智能·低代码·数字化转型·业务系统
Bar_artist1 小时前
离线智能破局,架构创新突围:RockAI与中国AI的“另一条车道”
大数据·人工智能
双向331 小时前
高性能MCP服务器架构设计:并发、缓存与监控
人工智能
weixin_464078071 小时前
机器学习sklearn:处理缺失值
人工智能·机器学习·sklearn
2202_756749691 小时前
04 基于sklearn的机械学习-梯度下降(上)
人工智能·算法·机器学习