机器学习——网格搜索(GridSearchCV)超参数优化


网格搜索(Grid Search)详细教学

1. 什么是网格搜索?

在机器学习模型中,算法的**超参数(Hyperparameters)**对模型的表现起着决定性作用。比如:

  • KNN 的邻居数量 n_neighbors

  • SVM 的惩罚系数 C 和核函数参数 gamma

  • 随机森林的决策树数量 n_estimators

这些超参数不会在训练过程中自动学习得到,而是需要我们人为设定。网格搜索(Grid Search)是一种最常见的超参数优化方法:
它通过
遍历给定参数网格中的所有组合
,使用交叉验证来评估每组参数的效果,最终选出表现最优的一组。

通俗理解:

👉 网格搜索 = 穷举法找最佳参数。


2. 网格搜索的核心思想

  1. 定义参数范围(网格) :例如 C=[0.1, 1, 10]gamma=[0.01, 0.1, 1]

  2. 训练所有组合 :即 (C=0.1, gamma=0.01)(C=0.1, gamma=0.1)...直到 (C=10, gamma=1)

  3. 交叉验证评估:每组参数都会在 k 折交叉验证下计算平均性能指标(如准确率、F1 分数)。

  4. 选择最佳参数:选出指标最优的一组参数作为最终模型配置。


3. 为什么要用网格搜索?

  • 超参数选择自动化:不用凭感觉拍脑袋。

  • 保证找到最优解:只要网格覆盖范围足够大,就不会遗漏最佳参数组合。

  • 结合交叉验证:结果更加稳健,避免过拟合或欠拟合。

但缺点也明显:

  • 计算开销大:参数范围和组合越多,训练越耗时。

  • 不适合大规模搜索:参数维度高时可能出现"维度灾难"。


4. Scikit-Learn 中的网格搜索工具

sklearn.model_selection.GridSearchCV 是最常用的网格搜索实现。

4.1 函数原型

复制代码
GridSearchCV(
    estimator,          # 基础模型,如SVC()、RandomForestClassifier()
    param_grid,         # 参数字典或列表,定义搜索空间
    scoring=None,       # 评估指标(accuracy、f1、roc_auc等)
    n_jobs=None,        # 并行任务数,-1表示使用所有CPU
    cv=None,            # 交叉验证折数,如cv=5
    verbose=0,          # 日志等级,1=简单进度条,2=详细
    refit=True,         # 是否在找到最优参数后重新训练整个模型
    return_train_score=False  # 是否返回训练集得分
)

GridSearchCV 常用参数表:

分类 参数 类型 说明 常用取值
核心 estimator estimator 对象 基础模型,必须实现 fit / predict SVC()RandomForestClassifier()
param_grid dict / list 要搜索的参数空间,键=参数名,值=候选值列表 {'C':[0.1,1,10], 'gamma':[0.01,0.1,1]}
评估 scoring str / callable 模型评估指标 accuracyf1_macroroc_aucneg_mean_squared_error
cv int / 生成器 交叉验证方式 5(5折交叉验证)、KFold(10)
refit bool / str 用最佳参数在全训练集上重新训练 True(默认)、'f1_macro'(多指标时指定)
效率 n_jobs int 并行任务数,-1=使用所有CPU -14
pre_dispatch int / str 并行调度策略 '2*n_jobs'(默认)
日志 verbose int 输出日志等级 0=无输出,1=进度,2=详细
错误处理 error_score str / numeric 参数报错时的分数 np.nan(默认)、0
调试 return_train_score bool 是否返回训练集得分(用于过拟合分析) False(默认)、True

5. 网格搜索实战案例

5.1 示例数据集

以鸢尾花(Iris)分类为例,使用 SVM 模型。

复制代码
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, train_test_split

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 定义模型
svc = SVC()

5.2 设置参数网格

复制代码
param_grid = {
    'C': [0.1, 1, 10, 100],          # 惩罚系数
    'gamma': [1, 0.1, 0.01, 0.001],  # 核函数参数
    'kernel': ['rbf', 'linear']      # 核函数类型
}

5.3 执行网格搜索

复制代码
grid = GridSearchCV(
    estimator=svc,
    param_grid=param_grid,
    scoring='accuracy',
    cv=5,
    verbose=2,
    n_jobs=-1
)
grid.fit(X_train, y_train)

5.4 输出结果

复制代码
print("最佳参数:", grid.best_params_)
print("最佳得分:", grid.best_score_)
print("测试集准确率:", grid.best_estimator_.score(X_test, y_test))

结果示例


6. 网格搜索的可视化

我们可以把不同参数组合的表现绘制出来,直观查看最优解在哪个区域:

复制代码
import matplotlib.pyplot as plt

results = pd.DataFrame(grid.cv_results_)

# 只绘制 C 与 gamma 的得分热力图(kernel=rbf)
scores = results[results.param_kernel == 'rbf'].pivot(
    index='param_gamma',
    columns='param_C',
    values='mean_test_score'
)

plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot)
plt.xlabel('C')
plt.ylabel('gamma')
plt.colorbar()
plt.xticks(np.arange(len(scores.columns)), scores.columns)
plt.yticks(np.arange(len(scores.index)), scores.index)
plt.title('Grid Search Accuracy Heatmap')
plt.show()

7. 网格搜索的进阶技巧

  1. 缩小搜索范围:先用较粗粒度搜索,再在最优附近细化搜索。

  2. 并行计算n_jobs=-1 可利用多核 CPU。

  3. 随机搜索(RandomizedSearchCV):当参数空间太大时,可考虑随机抽样搜索,更高效。

  4. 贝叶斯优化 :如 OptunaHyperopt,比网格搜索更智能。


8. 注意事项

  • 参数空间不要过大,否则计算量爆炸。

  • 交叉验证的折数 cv 不宜过大,通常 5 或 10。

  • 选择合适的评分指标 scoring,分类问题常用 accuracyf1_macro,回归问题用 neg_mean_squared_error 等。

  • 最终模型建议用 grid.best_estimator_,而不是手动再初始化。


9. 总结

  • **网格搜索(Grid Search)**是一种系统化的超参数优化方法,通过遍历参数网格+交叉验证,找到表现最优的参数组合。

  • sklearn 中,GridSearchCV 是核心工具。

  • 它简单易用,但计算成本高,不适合大规模问题。

  • 实际应用中常结合粗到细搜索、随机搜索、贝叶斯优化来提升效率。

相关推荐
新手村领路人几秒前
飞桨paddlepaddle旧版本2.4.2安装
人工智能·paddlepaddle
Elastic 中国社区官方博客6 分钟前
带地图的 RAG:多模态 + 地理空间 在 Elasticsearch 中
大数据·人工智能·elasticsearch·搜索引擎·ai·语言模型·全文检索
云卓SKYDROID7 分钟前
无人机云台电压类型及测量方法
人工智能·目标跟踪·无人机·高科技·航线系统
云雾J视界22 分钟前
AI时代技术面试重构:谷歌如何用Vibe Coding与抗作弊革命重塑招聘
人工智能·google·面试·重构·谷歌·ai工具·技术面试
BFT白芙堂22 分钟前
GRASP 实验室研究 论文解读 | 机器人交互:基于神经网络引导变分推理的快速失配估计
人工智能·神经网络·机器学习·mvc·人机交互·科研教育机器人·具身智能平台
深蓝学院24 分钟前
智源研究院新研究:突破物理世界智能边界的RoboBrain 2.0,将重构具身AI能力天花板
人工智能·重构
做萤石二次开发的哈哈26 分钟前
萤石安全生产监管解决方案:构建企业安全智能化防护网
大数据·人工智能
万米商云28 分钟前
碎片化采购是座金矿:数字化正重构电子元器件分销的价值链
大数据·人工智能·电子元器件·供应链采购
GoldenSpider.AI29 分钟前
马斯克访谈深度解读:机器人、AI芯片与人类文明的未来
人工智能·机器人·starlink·spacex·tesla·elon musk·optimus
伊莲娜生活30 分钟前
大健康时代下的平台电商:VTN平台以科研创新重构健康美丽消费生态
人工智能·物联网·重构