【机器学习案列】基于随机森林和xgboost的二手车价格回归预测

一、项目分析

1.1 项目任务

kaggle二手车价格回归预测项目,目的根据各种属性预测二手车的价格

1.2 评估准则

评估的标准是均方根误差:

1.3 数据介绍

数据连接https://www.kaggle.com/competitions/playground-series-s4e9/data?select=train.csv

其中:

  • id:唯一标识符(或编号)
  • brand:品牌
  • model:型号
  • model_year:车型年份
  • mileage(注意这里可能是拼写错误,应该是mileage而不是milage):里程数
  • fuel_type:燃油类型
  • engine:发动机
  • transmission:变速器
  • ext_col:车身颜色(外部)
  • int_col:内饰颜色(内部)
  • accident:事故记录
  • clean_title:清洁标题(通常指车辆是否有清晰的产权记录,无抵押、无重大事故等)
  • price:价格

二、读取数据

2.1 导入相应的库

python 复制代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split, GridSearchCV
import xgboost as xgb

2.2 读取数据

python 复制代码
file_path = '/kaggle/input/playground-series-s4e9/train.csv'
df = pd.read_csv(file_path)

df.head()
df.shape()


三、Exploratory Data Analysis(EDA)

3.1 车型年份与价格的关系

python 复制代码
plt.figure(figsize=(10, 6))
sns.scatterplot(x='model_year', y='price', data=df)
plt.title('Model Year vs Price')
plt.xlabel('Model Year')
plt.ylabel('Price')
plt.show()

3.2 滞留量与价格的关系

python 复制代码
plt.figure(figsize=(10, 6))
sns.scatterplot(x='milage', y='price', data=df)
plt.title('Milage vs Price')
plt.xlabel('Milage')
plt.ylabel('Price')
plt.show()

3.3 热图检查数值特征之间的关系

python 复制代码
num_df = df.select_dtypes(include=['float64', 'int64'])
plt.figure(figsize=(12, 8))
corr_matrix = num_df.corr()
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap="coolwarm", linewidths=0.5, annot_kws={"size": 10})
plt.title('Correlation Matrix', fontsize=16)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

3.4 按品牌统计图表

python 复制代码
plt.figure(figsize=(12, 6))
sns.countplot(data=df, x='brand', order=df['brand'].value_counts().index)
plt.title('Count of Cars by Brand', fontsize=16)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

3.5 箱线图

python 复制代码
plt.figure(figsize=(12, 6))
sns.boxplot(data=df, x='fuel_type', y='milage')
plt.title('Mileage by Fuel Type', fontsize=16)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

1.6 各品牌平均里程数

python 复制代码
plt.figure(figsize=(12, 6))
sns.barplot(data=df, x='brand', y='milage', estimator=np.mean, ci=None)
plt.title('Average Mileage by Brand', fontsize=16)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

四、 数据预测处理

4.1 检查每个特征是否具有不同的值

python 复制代码
for i in df.columns:
    if df[i].nunique()<2:
        print(f'{i} has only one unique value. ')

clean_title has only one unique value.

"Clean "功能只有一个唯一值,所以我们可以将其删除。

python 复制代码
df.drop(['id','clean_title'],axis=1,inplace=True)
df.shape

(188533, 11)

4.2 缺失值处理

python 复制代码
df.isnull().sum().sum()

7535

python 复制代码
df.dropna(inplace=True)
df.isnull().sum().sum()

0

没有缺失的值,所以我们可以继续了。

4.3

使用一热编码将分类变量转换为数值格式

python 复制代码
df = pd.get_dummies(df, columns=['brand', 'model', 'fuel_type', 'transmission', 'ext_col', 'int_col', 'accident','engine' ], drop_first=True)

五、数据预测

5.1 数据样本和标签分离

python 复制代码
X = df.drop('price', axis=1)
y = df['price']

5.2 切分数据集

python 复制代码
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

5.3 模型训练和评估

5.3.1 Xgboost回归模型

python 复制代码
xgb_model = xgb.XGBRegressor(
    n_estimators=100,      
    max_depth=5,           
    learning_rate=0.1,     
    subsample=0.8,        
    random_state=42        
)

xgb_model.fit(X_train, y_train)

y_pred_xgb = xgb_model.predict(X_test)

rmse_xgb = np.sqrt(mean_squared_error(y_test, y_pred_xgb))
print(f'XGBoost Root Mean Squared Error: {rmse_xgb}')

XGBoost Root Mean Squared Error: 67003.09126576487

5.3.2 Random Forest回归模型

python 复制代码
rf_model = RandomForestRegressor(
    n_estimators=100,     
    max_depth=10,         
    min_samples_split=2,
    min_samples_leaf=1,    
    random_state=42      
)

rf_model.fit(X_train, y_train)

y_pred_rf = rf_model.predict(X_test)

rmse_rf = np.sqrt(mean_squared_error(y_test, y_pred_rf))
print(f'Random Forest Root Mean Squared Error: {rmse_rf}')

Random Forest Root Mean Squared Error: 68418.85393408517

参考文献:

1 https://www.kaggle.com/code/muhammaadmuzammil008/eda-random-forest-xgboost

相关推荐
Coder_Boy_5 小时前
技术发展的核心规律是「加法打底,减法优化,重构平衡」
人工智能·spring boot·spring·重构
会飞的老朱7 小时前
医药集团数智化转型,智能综合管理平台激活集团管理新效能
大数据·人工智能·oa协同办公
聆风吟º8 小时前
CANN runtime 实战指南:异构计算场景中运行时组件的部署、调优与扩展技巧
人工智能·神经网络·cann·异构计算
Codebee10 小时前
能力中心 (Agent SkillCenter):开启AI技能管理新时代
人工智能
聆风吟º11 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
uesowys11 小时前
Apache Spark算法开发指导-One-vs-Rest classifier
人工智能·算法·spark
AI_567811 小时前
AWS EC2新手入门:6步带你从零启动实例
大数据·数据库·人工智能·机器学习·aws
User_芊芊君子11 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能12 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
qq_1601448712 小时前
亲测!2026年零基础学AI的入门干货,新手照做就能上手
人工智能