使用scikit-learn中的线性回归包对自定义数据集进行拟合

1. 导入必要的库

首先,需要导入所需的库,包括pandas用于数据处理,numpy用于数值计算,以及scikit-learn中的线性回归模型。

python 复制代码
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt

2. 加载自定义数据集

假设有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。将使用pandas读取这个文件。

python 复制代码
# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')

# 假设数据集中有两列:'feature'为特征,'target'为标签
X = data[['feature']].values  # 特征需要是二维数组
y = data['target'].values     # 标签

3. 分割数据集

为了评估模型的性能,将数据集分为训练集和测试集。

python 复制代码
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

4. 创建并训练线性回归模型

使用scikit-learn中的LinearRegression类创建线性回归模型,并使用训练集进行训练。

python 复制代码
# 创建线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X_train, y_train)

5. 进行预测并评估模型

使用测试集进行预测,并评估模型的性能。将使用均方误差(MSE)和决定系数(R²)作为评估指标。

python 复制代码
# 进行预测
y_pred = model.predict(X_test)

# 计算均方误差和决定系数
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f'Mean Squared Error: {mse:.2f}')
print(f'R² Score: {r2:.2f}')

6. 可视化结果

为了更直观地了解模型的拟合效果,可以绘制散点图来显示真实值和预测值。

python 复制代码
# 可视化结果
plt.scatter(X_test, y_test, color='black', label='Actual data')
plt.plot(X_test, y_pred, color='blue', linewidth=3, label='Fitted line')

plt.xlabel('Feature')
plt.ylabel('Target')
plt.title('Linear Regression Fit')
plt.legend()
plt.show()
相关推荐
Dxy12393102161 天前
Python检查JSON格式错误的多种方法
前端·python·json
Lightning-py1 天前
ASCII,十进制,十六进制,八进制和二进制转换表
python
laplace01231 天前
deque+yield+next语法
人工智能·笔记·python·agent·rag
福大大架构师每日一题1 天前
2026-01-15:下一个特殊回文数。用go语言,给定一个整数 n,求出一个比 n 更大的最小整数,该整数需要满足两条规则: 1. 它的十进制表示从左到右与从右到左完全一致(即读起来是对称的)。 2
python·算法·golang
芝士爱知识a1 天前
[2026深度测评] AI期权交易平台推荐榜单:AlphaGBM领跑,量化交易的新范式
开发语言·数据结构·人工智能·python·alphagbm·ai期权工具
overmind1 天前
oeasy Python 113 内置函数sorted中使用 reverse和key
开发语言·python
AC赳赳老秦1 天前
2026主权AI趋势:DeepSeek搭建企业自有可控AI环境,保障数据安全实战
大数据·数据库·人工智能·python·科技·rabbitmq·deepseek
小小张说故事1 天前
OpenCV Python技术文档
python·opencv
PD我是你的真爱粉1 天前
Redis持久化、内存管理、慢查询与发布订阅
redis·python·mybatis
查无此人byebye1 天前
实战DDPM扩散模型:MNIST手写数字生成+FID分数计算(完整可运行版)
人工智能·pytorch·python·深度学习·音视频