特征工程四-2:使用GridSearchCV 进行超参数网格搜索(Hyperparameter Tuning)的用途

1. GridSearchCV 的作用

GridSearchCV(网格搜索交叉验证)用于:

  • 自动搜索 给定参数范围内的最佳超参数组合。
  • 交叉验证评估 每个参数组合的性能,避免过拟合。
  • 返回最佳模型,可直接用于预测或分析。

2. 代码逐行解析

(1) 创建 GridSearchCV 对象
复制代码
grid = GridSearchCV(
    model,       # 要优化的模型(如 RandomForest、SVM 等)
    params,      # 待搜索的参数网格(字典或列表格式)
    error_score=0.  # 如果某组参数拟合报错,将该组合得分设为 0
)
  • model :已经定义的模型实例(如 model = RandomForestClassifier())。

  • params:参数网格,格式示例:

    复制代码
    params = {
        'n_estimators': [50, 100],  # 决策树数量
        'max_depth': [5, 10]        # 树的最大深度
    }
  • error_score=0.

    当某组参数导致模型拟合失败(如不兼容参数)时,将该参数组合的验证得分设为 0,避免程序中断。

(2) 执行网格搜索
复制代码
grid.fit(X, y)  # 用数据 X 和标签 y 拟合模型
  • params 中的所有参数组合进行尝试,并通过交叉验证(默认 5 折)评估性能。
  • 最终确定 最佳参数组合,并重新训练模型(用最佳参数在整个数据集上训练)。

3. 关键输出

完成 fit 后,可通过以下属性获取结果:

  • 最佳参数

    复制代码
    print(grid.best_params_)
    # 输出示例:{'max_depth': 10, 'n_estimators': 100}
  • 最佳模型的交叉验证得分

    复制代码
    print(grid.best_score_)
  • 最佳模型实例(可直接用于预测):

    复制代码
    best_model = grid.best_estimator_
    best_model.predict(X_test)

4. 注意事项

  1. 参数网格设计

    • 范围过大可能导致计算耗时,建议先用粗网格筛选,再细化。

    • 示例:

      复制代码
      params = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}  # SVM 参数
  2. 交叉验证控制

    • 可通过 cv 参数调整折数(如 cv=10)。
    • 使用 scoring 指定评估指标(如 scoring='accuracy')。
  3. 替代方案

    • 如果参数空间较大,可用 RandomizedSearchCV(随机搜索,更快)。

5. 完整示例

复制代码
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# 定义模型和参数网格

# 实例化机器学习模型
rf = RandomForestClassifier()
lr = LogisticRegression()
knn = KNeighborsClassifier()
dt = DecisionTreeClassifier()

# 逻辑回归
lr_params/ = {'C': [1e-1, 1e0, 1e1, 1e2], 'penalty': ['l1', 'l2']}
# KNN
knn_params = {'n_neighbors': [1, 3, 5, 7]}
# 决策树
dt_params = {'max_depth': [None, 1, 3, 5, 7]}
# 随机森林
rf_params = {'n_estimators': [10, 50, 100], 'max_depth': [None, 1, 3, 5, 7]}
# model=rf/lr/knn/dt,params=lr_params/knn_params/dt_params/rf_params

# 网格搜索
grid = GridSearchCV(model, params, error_score=0.)
grid.fit(X_train, y_train)

# 输出最佳参数
print("Best parameters:", grid.best_params_)

总结

  • 用途:自动化超参数优化,提升模型性能。
  • 核心参数model(模型)、params(参数网格)、error_score(容错处理)。
  • 输出 :通过 best_params_best_score_ 等获取最佳结果。

适用于任何 Scikit-learn 兼容模型(分类、回归等)。

相关推荐
Caaacy_YU25 分钟前
多模态大模型研究每日简报【2025-08-21】
论文阅读·人工智能·机器学习·计算机视觉
画中有画30 分钟前
使用AI来实现拼多多自动化运营脚本
运维·人工智能·自动化·ai编程·rpa·自动化脚本
钮钴禄·爱因斯晨32 分钟前
AIGC浪潮下,风靡全球的Mcp到底是什么?一文讲懂,技术小白都知道!!
开发语言·人工智能·深度学习·神经网络·生成对抗网络·aigc
大模型真好玩33 分钟前
深入浅出LangChain AI Agent智能体开发教程(九)—LangChain从0到1搭建知识库
人工智能·python·mcp
xcLeigh40 分钟前
文心一言4.5开源模型实战:ERNIE-4.5-0.3B轻量化部署与效能突破
人工智能·开源·大模型·文心一言·ernie·轻量化部署
居7然2 小时前
解锁工业级Prompt设计,打造高准确率AI应用
人工智能·prompt·提示词
星期天要睡觉2 小时前
机器学习——网格搜索(GridSearchCV)超参数优化
人工智能·机器学习
元宇宙时间5 小时前
RWA加密金融高峰论坛&星链品牌全球发布 —— 稳定币与Web3的香港新篇章
人工智能·web3·区块链
天涯海风8 小时前
检索增强生成(RAG) 缓存增强生成(CAG) 生成中检索(RICHES) 知识库增强语言模型(KBLAM)
人工智能·缓存·语言模型
lxmyzzs9 小时前
基于深度学习CenterPoint的3D目标检测部署实战
人工智能·深度学习·目标检测·自动驾驶·ros·激光雷达·3d目标检测