机器学习——网格搜索(GridSearchCV)超参数优化


网格搜索(Grid Search)详细教学

1. 什么是网格搜索?

在机器学习模型中,算法的**超参数(Hyperparameters)**对模型的表现起着决定性作用。比如:

  • KNN 的邻居数量 n_neighbors

  • SVM 的惩罚系数 C 和核函数参数 gamma

  • 随机森林的决策树数量 n_estimators

这些超参数不会在训练过程中自动学习得到,而是需要我们人为设定。网格搜索(Grid Search)是一种最常见的超参数优化方法:
它通过
遍历给定参数网格中的所有组合
,使用交叉验证来评估每组参数的效果,最终选出表现最优的一组。

通俗理解:

👉 网格搜索 = 穷举法找最佳参数。


2. 网格搜索的核心思想

  1. 定义参数范围(网格) :例如 C=[0.1, 1, 10]gamma=[0.01, 0.1, 1]

  2. 训练所有组合 :即 (C=0.1, gamma=0.01)(C=0.1, gamma=0.1)...直到 (C=10, gamma=1)

  3. 交叉验证评估:每组参数都会在 k 折交叉验证下计算平均性能指标(如准确率、F1 分数)。

  4. 选择最佳参数:选出指标最优的一组参数作为最终模型配置。


3. 为什么要用网格搜索?

  • 超参数选择自动化:不用凭感觉拍脑袋。

  • 保证找到最优解:只要网格覆盖范围足够大,就不会遗漏最佳参数组合。

  • 结合交叉验证:结果更加稳健,避免过拟合或欠拟合。

但缺点也明显:

  • 计算开销大:参数范围和组合越多,训练越耗时。

  • 不适合大规模搜索:参数维度高时可能出现"维度灾难"。


4. Scikit-Learn 中的网格搜索工具

sklearn.model_selection.GridSearchCV 是最常用的网格搜索实现。

4.1 函数原型

复制代码
GridSearchCV(
    estimator,          # 基础模型,如SVC()、RandomForestClassifier()
    param_grid,         # 参数字典或列表,定义搜索空间
    scoring=None,       # 评估指标(accuracy、f1、roc_auc等)
    n_jobs=None,        # 并行任务数,-1表示使用所有CPU
    cv=None,            # 交叉验证折数,如cv=5
    verbose=0,          # 日志等级,1=简单进度条,2=详细
    refit=True,         # 是否在找到最优参数后重新训练整个模型
    return_train_score=False  # 是否返回训练集得分
)

GridSearchCV 常用参数表:

分类 参数 类型 说明 常用取值
核心 estimator estimator 对象 基础模型,必须实现 fit / predict SVC()RandomForestClassifier()
param_grid dict / list 要搜索的参数空间,键=参数名,值=候选值列表 {'C':[0.1,1,10], 'gamma':[0.01,0.1,1]}
评估 scoring str / callable 模型评估指标 accuracyf1_macroroc_aucneg_mean_squared_error
cv int / 生成器 交叉验证方式 5(5折交叉验证)、KFold(10)
refit bool / str 用最佳参数在全训练集上重新训练 True(默认)、'f1_macro'(多指标时指定)
效率 n_jobs int 并行任务数,-1=使用所有CPU -14
pre_dispatch int / str 并行调度策略 '2*n_jobs'(默认)
日志 verbose int 输出日志等级 0=无输出,1=进度,2=详细
错误处理 error_score str / numeric 参数报错时的分数 np.nan(默认)、0
调试 return_train_score bool 是否返回训练集得分(用于过拟合分析) False(默认)、True

5. 网格搜索实战案例

5.1 示例数据集

以鸢尾花(Iris)分类为例,使用 SVM 模型。

复制代码
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, train_test_split

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 定义模型
svc = SVC()

5.2 设置参数网格

复制代码
param_grid = {
    'C': [0.1, 1, 10, 100],          # 惩罚系数
    'gamma': [1, 0.1, 0.01, 0.001],  # 核函数参数
    'kernel': ['rbf', 'linear']      # 核函数类型
}

5.3 执行网格搜索

复制代码
grid = GridSearchCV(
    estimator=svc,
    param_grid=param_grid,
    scoring='accuracy',
    cv=5,
    verbose=2,
    n_jobs=-1
)
grid.fit(X_train, y_train)

5.4 输出结果

复制代码
print("最佳参数:", grid.best_params_)
print("最佳得分:", grid.best_score_)
print("测试集准确率:", grid.best_estimator_.score(X_test, y_test))

结果示例


6. 网格搜索的可视化

我们可以把不同参数组合的表现绘制出来,直观查看最优解在哪个区域:

复制代码
import matplotlib.pyplot as plt

results = pd.DataFrame(grid.cv_results_)

# 只绘制 C 与 gamma 的得分热力图(kernel=rbf)
scores = results[results.param_kernel == 'rbf'].pivot(
    index='param_gamma',
    columns='param_C',
    values='mean_test_score'
)

plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('C')
plt.ylabel('gamma')
plt.colorbar()
plt.xticks(np.arange(len(scores.columns)), scores.columns)
plt.yticks(np.arange(len(scores.index)), scores.index)
plt.title('Grid Search Accuracy Heatmap')
plt.show()

7. 网格搜索的进阶技巧

  1. 缩小搜索范围:先用较粗粒度搜索,再在最优附近细化搜索。

  2. 并行计算n_jobs=-1 可利用多核 CPU。

  3. 随机搜索(RandomizedSearchCV):当参数空间太大时,可考虑随机抽样搜索,更高效。

  4. 贝叶斯优化 :如 OptunaHyperopt,比网格搜索更智能。


8. 注意事项

  • 参数空间不要过大,否则计算量爆炸。

  • 交叉验证的折数 cv 不宜过大,通常 5 或 10。

  • 选择合适的评分指标 scoring,分类问题常用 accuracyf1_macro,回归问题用 neg_mean_squared_error 等。

  • 最终模型建议用 grid.best_estimator_,而不是手动再初始化。


9. 总结

  • **网格搜索(Grid Search)**是一种系统化的超参数优化方法,通过遍历参数网格+交叉验证,找到表现最优的参数组合。

  • sklearn 中,GridSearchCV 是核心工具。

  • 它简单易用,但计算成本高,不适合大规模问题。

  • 实际应用中常结合粗到细搜索、随机搜索、贝叶斯优化来提升效率。

相关推荐
渲吧云渲染17 小时前
SaaS模式重构工业软件竞争规则,助力中小企业快速实现数字化转型
大数据·人工智能·sass
算家云17 小时前
DeepSeek-OCR本地部署教程:DeepSeek突破性开创上下文光学压缩,10倍效率重构文本处理范式
人工智能·计算机视觉·算家云·模型部署教程·镜像社区·deepseek-ocr
AgeClub17 小时前
1.2亿老人需助听器:本土品牌如何以AI破局,重构巨头垄断市场?
人工智能
PPIO派欧云18 小时前
PPIO上线Qwen-VL-8B/30B、GLM-4.5-Air等多款中小尺寸模型
人工智能
chenchihwen19 小时前
AI代码开发宝库系列:FAISS向量数据库
数据库·人工智能·python·faiss·1024程序员节
张登杰踩20 小时前
工业产品表面缺陷检测方法综述:从传统视觉到深度学习
人工智能·深度学习
sponge'20 小时前
opencv学习笔记6:SVM分类器
人工智能·机器学习·支持向量机·1024程序员节
zandy101120 小时前
2025年AI IDE的深度评测与推荐:从单一功能效率转向生态壁垒
ide·人工智能
旋转小马20 小时前
XGBoost完整学习指南:从数据清洗到模型调参
机器学习·scikit-learn·xgboost·1024程序员节