绘制神经网络的决策边界

python 复制代码
def plot_decision_boundary(model, X, y):
    # Set min and max values and give it some padding
    x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
    y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
    h = 0.01
    # Generate a grid of points with distance h between them
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # Predict the function value for the whole grid
    Z = model(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # Plot the contour and training examples
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.ylabel('x2')
    plt.xlabel('x1')
    plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)

这段代码 plot_decision_boundary 用于绘制神经网络的决策边界。它会根据模型的预测结果,直观地展示出模型对不同输入数据的分类结果,以及模型的决策边界。下面,我们逐步解释每行代码的含义。

1. 设置绘图边界

python 复制代码
x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
  • X[0, :]X[1, :] 分别表示输入数据的第 1 个特征(通常是 x 1 x_1 x1)和第 2 个特征(通常是 x 2 x_2 x2)。X 的形状是 2 × m 2 \times m 2×m,其中 m m m 是样本数。
  • min()max() 分别找到数据在 x 1 x_1 x1 和 x 2 x_2 x2 方向上的最小值和最大值,然后各加减 1,以增加一点边界范围,保证图形不会紧贴边界。

2. 生成网格点

python 复制代码
h = 0.01
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  • h = 0.01:定义网格点的步长,表示网格点之间的距离。步长越小,网格越密,决策边界也会越平滑。
  • np.meshgrid():生成二维的网格点坐标矩阵 xxyyxx 包含的是 x 1 x_1 x1 坐标,yy 包含的是 x 2 x_2 x2 坐标。它们的大小分别为 N × M N \times M N×M,即整个边界区域内所有可能的坐标点。

3. 模型预测

python 复制代码
Z = model(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
  • xx.ravel()yy.ravel():将二维的网格点展平成一维向量,以便传入模型进行预测。
  • np.c_:将展平的 xxyy 坐标点拼接在一起,形成一个 P × 2 P \times 2 P×2 的矩阵,其中 P P P 是网格点的总数量,每一行是一个输入点的坐标。
  • model(np.c_[xx.ravel(), yy.ravel()]):传入展平后的网格点,通过模型 model 对这些网格点进行预测,返回预测的分类结果(通常是 0 或 1)。
  • Z.reshape(xx.shape):将预测结果 Z 重新整形为与网格 xx 相同的形状,以便用于绘图。

4. 绘制决策边界和数据点

python 复制代码
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  • plt.contourf():根据 Z 中的分类结果绘制决策边界。该函数将网格点的预测结果 Z 映射为不同颜色的区域(由 cmap 控制),形成决策边界。即在分类任务中,不同区域会被不同的颜色填充。
python 复制代码
plt.ylabel('x2')
plt.xlabel('x1')
  • 设置横轴和纵轴的标签为 x 1 x_1 x1 和 x 2 x_2 x2,表示两个输入特征。
python 复制代码
plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
  • plt.scatter():绘制输入数据点,X[0, :]X[1, :] 分别是输入数据的两个特征,c=y 表示根据标签 y y y 为数据点着色。
  • cmap=plt.cm.Spectral:使用 Spectral 颜色映射来为不同类别的数据点着色。

总结

plot_decision_boundary 函数的工作流程是:

  1. 根据输入数据 X X X 设置绘制区域的边界。
  2. 使用 np.meshgrid 在边界范围内生成一个密集的网格点矩阵,覆盖整个决策空间。
  3. 对这些网格点进行模型预测,生成每个点的分类结果。
  4. 使用 plt.contourf() 绘制决策边界,显示模型在不同区域的分类结果。
  5. 最后,使用 plt.scatter() 绘制实际的训练数据点,便于直观地查看模型的分类效果与实际数据的分布情况。

这样,决策边界和训练数据会同时显示在一张图上,便于观察模型的分类效果。

相关推荐
CareyWYR18 分钟前
每周AI论文速递(2506202-250606)
人工智能
YYXZZ。。22 分钟前
PyTorch——优化器(9)
pytorch·深度学习·计算机视觉
点云SLAM22 分钟前
PyTorch 中contiguous函数使用详解和代码演示
人工智能·pytorch·python·3d深度学习·contiguous函数·张量内存布局优化·张量操作
小天才才34 分钟前
【自然语言处理】大模型时代的数据标注(主动学习)
人工智能·机器学习·语言模型·自然语言处理
音程36 分钟前
预训练语言模型T5-11B的简要介绍
人工智能·语言模型·自然语言处理
苏苏susuus1 小时前
机器学习:集成学习概念和分类、随机森林、Adaboost、GBDT
机器学习·分类·集成学习
人肉推土机1 小时前
AI Agent 架构设计:ReAct 与 Self-Ask 模式对比与分析
人工智能·大模型·llm·agent
新知图书1 小时前
OpenCV为图像添加边框
人工智能·opencv·计算机视觉
大模型真好玩1 小时前
可视化神器WandB,大模型训练的必备工具!
人工智能·python·mcp
databook1 小时前
当机器学习遇见压缩感知:用少量数据重建完整世界
python·机器学习·scikit-learn