绘制神经网络的决策边界

python 复制代码
def plot_decision_boundary(model, X, y):
    # Set min and max values and give it some padding
    x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
    y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
    h = 0.01
    # Generate a grid of points with distance h between them
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # Predict the function value for the whole grid
    Z = model(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # Plot the contour and training examples
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.ylabel('x2')
    plt.xlabel('x1')
    plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)

这段代码 plot_decision_boundary 用于绘制神经网络的决策边界。它会根据模型的预测结果,直观地展示出模型对不同输入数据的分类结果,以及模型的决策边界。下面,我们逐步解释每行代码的含义。

1. 设置绘图边界

python 复制代码
x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
  • X[0, :]X[1, :] 分别表示输入数据的第 1 个特征(通常是 x 1 x_1 x1)和第 2 个特征(通常是 x 2 x_2 x2)。X 的形状是 2 × m 2 \times m 2×m,其中 m m m 是样本数。
  • min()max() 分别找到数据在 x 1 x_1 x1 和 x 2 x_2 x2 方向上的最小值和最大值,然后各加减 1,以增加一点边界范围,保证图形不会紧贴边界。

2. 生成网格点

python 复制代码
h = 0.01
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  • h = 0.01:定义网格点的步长,表示网格点之间的距离。步长越小,网格越密,决策边界也会越平滑。
  • np.meshgrid():生成二维的网格点坐标矩阵 xxyyxx 包含的是 x 1 x_1 x1 坐标,yy 包含的是 x 2 x_2 x2 坐标。它们的大小分别为 N × M N \times M N×M,即整个边界区域内所有可能的坐标点。

3. 模型预测

python 复制代码
Z = model(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
  • xx.ravel()yy.ravel():将二维的网格点展平成一维向量,以便传入模型进行预测。
  • np.c_:将展平的 xxyy 坐标点拼接在一起,形成一个 P × 2 P \times 2 P×2 的矩阵,其中 P P P 是网格点的总数量,每一行是一个输入点的坐标。
  • model(np.c_[xx.ravel(), yy.ravel()]):传入展平后的网格点,通过模型 model 对这些网格点进行预测,返回预测的分类结果(通常是 0 或 1)。
  • Z.reshape(xx.shape):将预测结果 Z 重新整形为与网格 xx 相同的形状,以便用于绘图。

4. 绘制决策边界和数据点

python 复制代码
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  • plt.contourf():根据 Z 中的分类结果绘制决策边界。该函数将网格点的预测结果 Z 映射为不同颜色的区域(由 cmap 控制),形成决策边界。即在分类任务中,不同区域会被不同的颜色填充。
python 复制代码
plt.ylabel('x2')
plt.xlabel('x1')
  • 设置横轴和纵轴的标签为 x 1 x_1 x1 和 x 2 x_2 x2,表示两个输入特征。
python 复制代码
plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
  • plt.scatter():绘制输入数据点,X[0, :]X[1, :] 分别是输入数据的两个特征,c=y 表示根据标签 y y y 为数据点着色。
  • cmap=plt.cm.Spectral:使用 Spectral 颜色映射来为不同类别的数据点着色。

总结

plot_decision_boundary 函数的工作流程是:

  1. 根据输入数据 X X X 设置绘制区域的边界。
  2. 使用 np.meshgrid 在边界范围内生成一个密集的网格点矩阵,覆盖整个决策空间。
  3. 对这些网格点进行模型预测,生成每个点的分类结果。
  4. 使用 plt.contourf() 绘制决策边界,显示模型在不同区域的分类结果。
  5. 最后,使用 plt.scatter() 绘制实际的训练数据点,便于直观地查看模型的分类效果与实际数据的分布情况。

这样,决策边界和训练数据会同时显示在一张图上,便于观察模型的分类效果。

相关推荐
雪兽软件1 小时前
人工智能和大数据如何改变企业?
大数据·人工智能
UMS攸信技术3 小时前
汽车电子行业数字化转型的实践与探索——以盈趣汽车电子为例
人工智能·汽车
ws2019073 小时前
聚焦汽车智能化与电动化︱AUTO TECH 2025 华南展,以展带会,已全面启动,与您相约11月广州!
大数据·人工智能·汽车
堇舟4 小时前
斯皮尔曼相关(Spearman correlation)系数
人工智能·算法·机器学习
爱写代码的小朋友4 小时前
使用 OpenCV 进行人脸检测
人工智能·opencv·计算机视觉
Cici_ovo5 小时前
摄像头点击器常见问题——摄像头视窗打开慢
人工智能·单片机·嵌入式硬件·物联网·计算机视觉·硬件工程
QQ39575332375 小时前
中阳智能交易系统:创新金融科技赋能投资新时代
人工智能·金融
yyfhq5 小时前
dcgan
深度学习·机器学习·生成对抗网络
这个男人是小帅5 小时前
【图神经网络】 AM-GCN论文精讲(全网最细致篇)
人工智能·pytorch·深度学习·神经网络·分类
放松吃羊肉6 小时前
【约束优化】一次搞定拉格朗日,对偶问题,弱对偶定理,Slater条件和KKT条件
人工智能·机器学习·支持向量机·对偶问题·约束优化·拉格朗日·kkt