绘制神经网络的决策边界

python 复制代码
def plot_decision_boundary(model, X, y):
    # Set min and max values and give it some padding
    x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
    y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
    h = 0.01
    # Generate a grid of points with distance h between them
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # Predict the function value for the whole grid
    Z = model(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # Plot the contour and training examples
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.ylabel('x2')
    plt.xlabel('x1')
    plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)

这段代码 plot_decision_boundary 用于绘制神经网络的决策边界。它会根据模型的预测结果,直观地展示出模型对不同输入数据的分类结果,以及模型的决策边界。下面,我们逐步解释每行代码的含义。

1. 设置绘图边界

python 复制代码
x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
  • X[0, :]X[1, :] 分别表示输入数据的第 1 个特征(通常是 x 1 x_1 x1)和第 2 个特征(通常是 x 2 x_2 x2)。X 的形状是 2 × m 2 \times m 2×m,其中 m m m 是样本数。
  • min()max() 分别找到数据在 x 1 x_1 x1 和 x 2 x_2 x2 方向上的最小值和最大值,然后各加减 1,以增加一点边界范围,保证图形不会紧贴边界。

2. 生成网格点

python 复制代码
h = 0.01
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  • h = 0.01:定义网格点的步长,表示网格点之间的距离。步长越小,网格越密,决策边界也会越平滑。
  • np.meshgrid():生成二维的网格点坐标矩阵 xxyyxx 包含的是 x 1 x_1 x1 坐标,yy 包含的是 x 2 x_2 x2 坐标。它们的大小分别为 N × M N \times M N×M,即整个边界区域内所有可能的坐标点。

3. 模型预测

python 复制代码
Z = model(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
  • xx.ravel()yy.ravel():将二维的网格点展平成一维向量,以便传入模型进行预测。
  • np.c_:将展平的 xxyy 坐标点拼接在一起,形成一个 P × 2 P \times 2 P×2 的矩阵,其中 P P P 是网格点的总数量,每一行是一个输入点的坐标。
  • model(np.c_[xx.ravel(), yy.ravel()]):传入展平后的网格点,通过模型 model 对这些网格点进行预测,返回预测的分类结果(通常是 0 或 1)。
  • Z.reshape(xx.shape):将预测结果 Z 重新整形为与网格 xx 相同的形状,以便用于绘图。

4. 绘制决策边界和数据点

python 复制代码
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  • plt.contourf():根据 Z 中的分类结果绘制决策边界。该函数将网格点的预测结果 Z 映射为不同颜色的区域(由 cmap 控制),形成决策边界。即在分类任务中,不同区域会被不同的颜色填充。
python 复制代码
plt.ylabel('x2')
plt.xlabel('x1')
  • 设置横轴和纵轴的标签为 x 1 x_1 x1 和 x 2 x_2 x2,表示两个输入特征。
python 复制代码
plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
  • plt.scatter():绘制输入数据点,X[0, :]X[1, :] 分别是输入数据的两个特征,c=y 表示根据标签 y y y 为数据点着色。
  • cmap=plt.cm.Spectral:使用 Spectral 颜色映射来为不同类别的数据点着色。

总结

plot_decision_boundary 函数的工作流程是:

  1. 根据输入数据 X X X 设置绘制区域的边界。
  2. 使用 np.meshgrid 在边界范围内生成一个密集的网格点矩阵,覆盖整个决策空间。
  3. 对这些网格点进行模型预测,生成每个点的分类结果。
  4. 使用 plt.contourf() 绘制决策边界,显示模型在不同区域的分类结果。
  5. 最后,使用 plt.scatter() 绘制实际的训练数据点,便于直观地查看模型的分类效果与实际数据的分布情况。

这样,决策边界和训练数据会同时显示在一张图上,便于观察模型的分类效果。

相关推荐
LO嘉嘉VE13 分钟前
学习笔记二十一:深度学习
笔记·深度学习·学习
vvoennvv1 小时前
【Python TensorFlow】 TCN-GRU时间序列卷积门控循环神经网络时序预测算法(附代码)
python·rnn·神经网络·机器学习·gru·tensorflow·tcn
YJlio1 小时前
[编程达人挑战赛] 用 PowerShell 写了一个“电脑一键初始化脚本”:从混乱到可复制的开发环境
数据库·人工智能·电脑
玦尘、1 小时前
《统计学习方法》第4章——朴素贝叶斯法【学习笔记】
笔记·机器学习
RoboWizard1 小时前
PCIe 5.0 SSD有无独立缓存对性能影响大吗?Kingston FURY Renegade G5!
人工智能·缓存·电脑·金士顿
霍格沃兹测试开发学社-小明2 小时前
测试左移2.0:在开发周期前端筑起质量防线
前端·javascript·网络·人工智能·测试工具·easyui
懒麻蛇2 小时前
从矩阵相关到矩阵回归:曼特尔检验与 MRQAP
人工智能·线性代数·矩阵·数据挖掘·回归
xwill*2 小时前
RDT-1B: A DIFFUSION FOUNDATION MODEL FOR BIMANUAL MANIPULATION
人工智能·pytorch·python·深度学习
网安INF2 小时前
机器学习入门:深入理解线性回归
人工智能·机器学习·线性回归
陈奕昆2 小时前
n8n实战营Day2课时2:Loop+Merge节点进阶·Excel批量校验实操
人工智能·python·excel·n8n