绘制神经网络的决策边界

python 复制代码
def plot_decision_boundary(model, X, y):
    # Set min and max values and give it some padding
    x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
    y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
    h = 0.01
    # Generate a grid of points with distance h between them
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # Predict the function value for the whole grid
    Z = model(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    # Plot the contour and training examples
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.ylabel('x2')
    plt.xlabel('x1')
    plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)

这段代码 plot_decision_boundary 用于绘制神经网络的决策边界。它会根据模型的预测结果,直观地展示出模型对不同输入数据的分类结果,以及模型的决策边界。下面,我们逐步解释每行代码的含义。

1. 设置绘图边界

python 复制代码
x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
  • X[0, :]X[1, :] 分别表示输入数据的第 1 个特征(通常是 x 1 x_1 x1)和第 2 个特征(通常是 x 2 x_2 x2)。X 的形状是 2 × m 2 \times m 2×m,其中 m m m 是样本数。
  • min()max() 分别找到数据在 x 1 x_1 x1 和 x 2 x_2 x2 方向上的最小值和最大值,然后各加减 1,以增加一点边界范围,保证图形不会紧贴边界。

2. 生成网格点

python 复制代码
h = 0.01
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  • h = 0.01:定义网格点的步长,表示网格点之间的距离。步长越小,网格越密,决策边界也会越平滑。
  • np.meshgrid():生成二维的网格点坐标矩阵 xxyyxx 包含的是 x 1 x_1 x1 坐标,yy 包含的是 x 2 x_2 x2 坐标。它们的大小分别为 N × M N \times M N×M,即整个边界区域内所有可能的坐标点。

3. 模型预测

python 复制代码
Z = model(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
  • xx.ravel()yy.ravel():将二维的网格点展平成一维向量,以便传入模型进行预测。
  • np.c_:将展平的 xxyy 坐标点拼接在一起,形成一个 P × 2 P \times 2 P×2 的矩阵,其中 P P P 是网格点的总数量,每一行是一个输入点的坐标。
  • model(np.c_[xx.ravel(), yy.ravel()]):传入展平后的网格点,通过模型 model 对这些网格点进行预测,返回预测的分类结果(通常是 0 或 1)。
  • Z.reshape(xx.shape):将预测结果 Z 重新整形为与网格 xx 相同的形状,以便用于绘图。

4. 绘制决策边界和数据点

python 复制代码
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  • plt.contourf():根据 Z 中的分类结果绘制决策边界。该函数将网格点的预测结果 Z 映射为不同颜色的区域(由 cmap 控制),形成决策边界。即在分类任务中,不同区域会被不同的颜色填充。
python 复制代码
plt.ylabel('x2')
plt.xlabel('x1')
  • 设置横轴和纵轴的标签为 x 1 x_1 x1 和 x 2 x_2 x2,表示两个输入特征。
python 复制代码
plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
  • plt.scatter():绘制输入数据点,X[0, :]X[1, :] 分别是输入数据的两个特征,c=y 表示根据标签 y y y 为数据点着色。
  • cmap=plt.cm.Spectral:使用 Spectral 颜色映射来为不同类别的数据点着色。

总结

plot_decision_boundary 函数的工作流程是:

  1. 根据输入数据 X X X 设置绘制区域的边界。
  2. 使用 np.meshgrid 在边界范围内生成一个密集的网格点矩阵,覆盖整个决策空间。
  3. 对这些网格点进行模型预测,生成每个点的分类结果。
  4. 使用 plt.contourf() 绘制决策边界,显示模型在不同区域的分类结果。
  5. 最后,使用 plt.scatter() 绘制实际的训练数据点,便于直观地查看模型的分类效果与实际数据的分布情况。

这样,决策边界和训练数据会同时显示在一张图上,便于观察模型的分类效果。

相关推荐
Chaos_Wang_17 分钟前
NLP高频面试题(三十)——LLama系列模型介绍,包括LLama LLama2和LLama3
人工智能·自然语言处理·llama
databook20 分钟前
线性判别分析(LDA):降维与分类的完美结合
python·机器学习·scikit-learn
新智元22 分钟前
美国 CS 专业卷上天,满分学霸惨遭藤校全拒!父亲大受震撼引爆热议
人工智能·openai
新智元24 分钟前
美国奥数题撕碎 AI 数学神话,顶级模型现场翻车!最高得分 5%,DeepSeek 唯一逆袭
人工智能·openai
Baihai_IDP34 分钟前
「DeepSeek-V3 技术解析」:无辅助损失函数的负载均衡
人工智能·llm·deepseek
硅谷秋水41 分钟前
大语言模型智体的综述:方法论、应用和挑战(下)
人工智能·深度学习·机器学习·语言模型·自然语言处理
TGITCIC1 小时前
BERT与Transformer到底选哪个-下部
人工智能·gpt·大模型·aigc·bert·transformer
Lx3521 小时前
AutoML逆袭:普通开发者如何玩转大模型调参
人工智能
IT古董1 小时前
【漫话机器学习系列】185.神经网络参数的标准初始化(Normalized Initialization of Neural Network Parameter
人工智能
嘻嘻哈哈开森1 小时前
Java开发工程师转AI工程师
人工智能·后端