绘制神经网络的决策边界

python 复制代码
def plot_decision_boundary(model, X, y):
    # Set min and max values and give it some padding
    x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
    y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
    h = 0.01
    # Generate a grid of points with distance h between them
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # Predict the function value for the whole grid
    Z = model(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # Plot the contour and training examples
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.ylabel('x2')
    plt.xlabel('x1')
    plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)

这段代码 plot_decision_boundary 用于绘制神经网络的决策边界。它会根据模型的预测结果,直观地展示出模型对不同输入数据的分类结果,以及模型的决策边界。下面,我们逐步解释每行代码的含义。

1. 设置绘图边界

python 复制代码
x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
  • X[0, :]X[1, :] 分别表示输入数据的第 1 个特征(通常是 x 1 x_1 x1)和第 2 个特征(通常是 x 2 x_2 x2)。X 的形状是 2 × m 2 \times m 2×m,其中 m m m 是样本数。
  • min()max() 分别找到数据在 x 1 x_1 x1 和 x 2 x_2 x2 方向上的最小值和最大值,然后各加减 1,以增加一点边界范围,保证图形不会紧贴边界。

2. 生成网格点

python 复制代码
h = 0.01
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  • h = 0.01:定义网格点的步长,表示网格点之间的距离。步长越小,网格越密,决策边界也会越平滑。
  • np.meshgrid():生成二维的网格点坐标矩阵 xxyyxx 包含的是 x 1 x_1 x1 坐标,yy 包含的是 x 2 x_2 x2 坐标。它们的大小分别为 N × M N \times M N×M,即整个边界区域内所有可能的坐标点。

3. 模型预测

python 复制代码
Z = model(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
  • xx.ravel()yy.ravel():将二维的网格点展平成一维向量,以便传入模型进行预测。
  • np.c_:将展平的 xxyy 坐标点拼接在一起,形成一个 P × 2 P \times 2 P×2 的矩阵,其中 P P P 是网格点的总数量,每一行是一个输入点的坐标。
  • model(np.c_[xx.ravel(), yy.ravel()]):传入展平后的网格点,通过模型 model 对这些网格点进行预测,返回预测的分类结果(通常是 0 或 1)。
  • Z.reshape(xx.shape):将预测结果 Z 重新整形为与网格 xx 相同的形状,以便用于绘图。

4. 绘制决策边界和数据点

python 复制代码
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  • plt.contourf():根据 Z 中的分类结果绘制决策边界。该函数将网格点的预测结果 Z 映射为不同颜色的区域(由 cmap 控制),形成决策边界。即在分类任务中,不同区域会被不同的颜色填充。
python 复制代码
plt.ylabel('x2')
plt.xlabel('x1')
  • 设置横轴和纵轴的标签为 x 1 x_1 x1 和 x 2 x_2 x2,表示两个输入特征。
python 复制代码
plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
  • plt.scatter():绘制输入数据点,X[0, :]X[1, :] 分别是输入数据的两个特征,c=y 表示根据标签 y y y 为数据点着色。
  • cmap=plt.cm.Spectral:使用 Spectral 颜色映射来为不同类别的数据点着色。

总结

plot_decision_boundary 函数的工作流程是:

  1. 根据输入数据 X X X 设置绘制区域的边界。
  2. 使用 np.meshgrid 在边界范围内生成一个密集的网格点矩阵,覆盖整个决策空间。
  3. 对这些网格点进行模型预测,生成每个点的分类结果。
  4. 使用 plt.contourf() 绘制决策边界,显示模型在不同区域的分类结果。
  5. 最后,使用 plt.scatter() 绘制实际的训练数据点,便于直观地查看模型的分类效果与实际数据的分布情况。

这样,决策边界和训练数据会同时显示在一张图上,便于观察模型的分类效果。

相关推荐
水如烟5 小时前
孤能子视角:“组织行为学–组织文化“
人工智能
大山同学5 小时前
图片补全-Context Encoder
人工智能·机器学习·计算机视觉
薛定谔的猫19825 小时前
十七、用 GPT2 中文对联模型实现经典上联自动对下联:
人工智能·深度学习·gpt2·大模型 训练 调优
壮Sir不壮5 小时前
2026年奇点:Clawdbot引爆个人AI代理
人工智能·ai·大模型·claude·clawdbot·moltbot·openclaw
PaperRed ai写作降重助手6 小时前
高性价比 AI 论文写作软件推荐:2026 年预算友好型
人工智能·aigc·论文·写作·ai写作·智能降重
玉梅小洋6 小时前
Claude Code 从入门到精通(七):Sub Agent 与 Skill 终极PK
人工智能·ai·大模型·ai编程·claude·ai工具
-嘟囔着拯救世界-6 小时前
【保姆级教程】Win11 下从零部署 Claude Code:本地环境配置 + VSCode 可视化界面全流程指南
人工智能·vscode·ai·编辑器·html5·ai编程·claude code
正见TrueView6 小时前
程一笑的价值选择:AI金玉其外,“收割”老人败絮其中
人工智能
Imm7776 小时前
中国知名的车膜品牌推荐几家
人工智能·python
风静如云6 小时前
Claude Code:进入dash模式
人工智能