python
def plot_decision_boundary(model, X, y):
# Set min and max values and give it some padding
x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
h = 0.01
# Generate a grid of points with distance h between them
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
# Predict the function value for the whole grid
Z = model(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# Plot the contour and training examples
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
plt.ylabel('x2')
plt.xlabel('x1')
plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
这段代码 plot_decision_boundary
用于绘制神经网络的决策边界。它会根据模型的预测结果,直观地展示出模型对不同输入数据的分类结果,以及模型的决策边界。下面,我们逐步解释每行代码的含义。
1. 设置绘图边界
python
x_min, x_max = X[0, :].min() - 1, X[0, :].max() + 1
y_min, y_max = X[1, :].min() - 1, X[1, :].max() + 1
X[0, :]
和X[1, :]
分别表示输入数据的第 1 个特征(通常是 x 1 x_1 x1)和第 2 个特征(通常是 x 2 x_2 x2)。X
的形状是 2 × m 2 \times m 2×m,其中 m m m 是样本数。min()
和max()
分别找到数据在 x 1 x_1 x1 和 x 2 x_2 x2 方向上的最小值和最大值,然后各加减 1,以增加一点边界范围,保证图形不会紧贴边界。
2. 生成网格点
python
h = 0.01
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
h = 0.01
:定义网格点的步长,表示网格点之间的距离。步长越小,网格越密,决策边界也会越平滑。np.meshgrid()
:生成二维的网格点坐标矩阵xx
和yy
,xx
包含的是 x 1 x_1 x1 坐标,yy
包含的是 x 2 x_2 x2 坐标。它们的大小分别为 N × M N \times M N×M,即整个边界区域内所有可能的坐标点。
3. 模型预测
python
Z = model(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
xx.ravel()
和yy.ravel()
:将二维的网格点展平成一维向量,以便传入模型进行预测。np.c_
:将展平的xx
和yy
坐标点拼接在一起,形成一个 P × 2 P \times 2 P×2 的矩阵,其中 P P P 是网格点的总数量,每一行是一个输入点的坐标。model(np.c_[xx.ravel(), yy.ravel()])
:传入展平后的网格点,通过模型model
对这些网格点进行预测,返回预测的分类结果(通常是 0 或 1)。Z.reshape(xx.shape)
:将预测结果Z
重新整形为与网格xx
相同的形状,以便用于绘图。
4. 绘制决策边界和数据点
python
plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
plt.contourf()
:根据Z
中的分类结果绘制决策边界。该函数将网格点的预测结果Z
映射为不同颜色的区域(由cmap
控制),形成决策边界。即在分类任务中,不同区域会被不同的颜色填充。
python
plt.ylabel('x2')
plt.xlabel('x1')
- 设置横轴和纵轴的标签为 x 1 x_1 x1 和 x 2 x_2 x2,表示两个输入特征。
python
plt.scatter(X[0, :], X[1, :], c=y, cmap=plt.cm.Spectral)
plt.scatter()
:绘制输入数据点,X[0, :]
和X[1, :]
分别是输入数据的两个特征,c=y
表示根据标签 y y y 为数据点着色。cmap=plt.cm.Spectral
:使用Spectral
颜色映射来为不同类别的数据点着色。
总结
plot_decision_boundary
函数的工作流程是:
- 根据输入数据 X X X 设置绘制区域的边界。
- 使用
np.meshgrid
在边界范围内生成一个密集的网格点矩阵,覆盖整个决策空间。 - 对这些网格点进行模型预测,生成每个点的分类结果。
- 使用
plt.contourf()
绘制决策边界,显示模型在不同区域的分类结果。 - 最后,使用
plt.scatter()
绘制实际的训练数据点,便于直观地查看模型的分类效果与实际数据的分布情况。
这样,决策边界和训练数据会同时显示在一张图上,便于观察模型的分类效果。