作业二.自定义数据集使用scikit-learn中的包实现线性回归方法对其进行拟合

from sklearn.linear_model import LinearRegression

from sklearn.model_selection import train_test_split

from sklearn.metrics import mean_squared_error

import numpy as np

import matplotlib.pyplot as plt

np.random.seed(0)

加载自定义数据集

X = 2 * np.random.rand(100, 1)

y = 4 + 3 * X + np.random.randn(100, 1)

将数据集划分为训练集和测试集

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

创建线性回归模型对象并拟合训练数据

model = LinearRegression()

model.fit(X_train, y_train)

使用训练好的模型对测试集进行预测

y_pred = model.predict(X_test)

计算预测误差

mse = mean_squared_error(y_test, y_pred)

print("均方误差:", mse)

plt.scatter(X_test, y_test, color='blue')

plt.plot(X_test, y_pred, color='red')

plt.show()

相关推荐
Cricyta Sevina几秒前
Java Collection 集合进阶知识笔记
java·笔记·python·collection集合
胡萝卜3.03 分钟前
深入C++可调用对象:从function包装到bind参数适配的技术实现
开发语言·c++·人工智能·机器学习·bind·function·包装器
Echo_NGC22373 分钟前
【KL 散度】深入理解 Kullback-Leibler Divergence:AI 如何衡量“像不像”的问题
人工智能·算法·机器学习·散度·kl
小a杰.4 分钟前
Flutter 设计系统构建指南
开发语言·javascript·ecmascript
BD_Marathon10 分钟前
【JavaWeb】Servlet_url-pattern的一些特殊写法问题
java·开发语言·servlet
零度@22 分钟前
Java中Map的多种用法
java·前端·python
中文很快乐23 分钟前
java开发--开发工具全面介绍--新手养成记
java·开发语言·java开发·开发工具介绍·idea开发工具
IMPYLH25 分钟前
Lua 的 Coroutine(协程)模块
开发语言·笔记·后端·中间件·游戏引擎·lua
看见繁华26 分钟前
C++ 高级
开发语言·c++
550A30 分钟前
如何修改kagglehub的数据集默认下载路径
python