用Python手写一个能识花的感知器模型——Iris分类实战详解

描述

感知器是最简单的线性二分类模型。尽管它很基础,但在实际场景中仍有用处,尤其是在:

  1. 教学/入门:帮助学生理解线性分类、梯度思想与决策边界。
  2. 低算力嵌入式设备:当模型必须非常小且推断简单(例如设备上用一个线性规则快速判断两类状态)时,感知器可作为简单筛选器。
  3. 数据快速原型:在需要快速判断某两个易分离类别时(如质量控制中"合格/不合格"的某种材质特征),感知器能快速给出一个基线性能。
  4. 可解释性需求强的应用:感知器决策来自线性组合,容易解释权重对特征的影响。

在本文示例里,我们用萼片长度(sepal length)和花瓣长度(petal length)来区分 setosa 与 versicolor。这两维在前 100 个样本中本身就比较有区分力,所以非常适合作为教学与演示数据集。

题解答案

  1. 载入 sklearn.datasets.load_iris,取前 100 个样本(0--99),对应两类(0:setosa,1:versicolor)。

  2. 从每个样本提取第 1 列(萼片长度)与第 3 列(花瓣长度)作为特征矩阵 (X \in \mathbb{R}^{100 \times 2})。

  3. 将标签 y 中的 0 替换为 -1,1 替换为 +1,适应感知器的符号输出。

  4. 实现一个感知器类(从零写,包含 fitpredictnet_input),记录每一轮(epoch)错误分类的数量。

  5. 训练后:

    • 绘制样本散点图(两类不同标记)和决策边界(可视化)。
    • 绘制训练过程中的错误数随迭代变化图(用以查看是否收敛)。
  6. 给出测试/训练准确率与模型权重解读,并分析复杂度与空间占用。

题解代码

下面给出完整、清晰、注释充足的 Python 代码。运行前请确保已安装 numpy, scikit-learn, matplotlib

python 复制代码
# perceptron_iris_demo.py
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

class Perceptron:
    """
    简单的感知器实现(批量感知器,逐样本更新)。
    参数:
      eta: 学习率 (float)
      n_iter: 训练轮数 (int)
    属性:
      w_: 权重向量(包含偏置),形状 (n_features + 1,)
      errors_: 每个 epoch 的错误分类数列表
    """
    def __init__(self, eta=0.01, n_iter=10):
        self.eta = eta
        self.n_iter = n_iter

    def fit(self, X, y):
        # X: (n_samples, n_features), y: { -1, 1 }
        n_samples, n_features = X.shape
        # 初始化权重(一个额外的偏置权重)
        self.w_ = np.zeros(n_features + 1)
        self.errors_ = []

        for epoch in range(self.n_iter):
            errors = 0
            for xi, target in zip(X, y):
                update = self.eta * (target - self.predict(xi))
                # 如果 update != 0,说明预测有误(或不完全等于目标),按感知器规则更新
                if update != 0.0:
                    # w_j = w_j + update * x_j
                    self.w_[1:] += update * xi
                    # bias: w_0 = w_0 + update
                    self.w_[0] += update
                    errors += 1
            self.errors_.append(errors)
            # 可选:打印每轮的错误数,便于调试
            # print(f"Epoch {epoch+1}/{self.n_iter}, errors: {errors}")
        return self

    def net_input(self, X):
        # 线性组合:w_0 * 1 + sum(w_j * x_j)
        return np.dot(X, self.w_[1:]) + self.w_[0]

    def predict(self, X):
        # 对单个样本或批量进行预测
        # 返回 +1 或 -1
        net = self.net_input(X)
        # 当 X 是单样本时 net 是标量;当是数组时是数组
        return np.where(net >= 0.0, 1, -1)


def main():
    # 1. 加载数据并预处理
    iris = load_iris()
    data = iris.data
    target = iris.target

    # 只取前100个样本(setosa 和 versicolor),并取第1列和第3列作为特征(索引0和2)
    X = data[0:100, [0, 2]]  # (100, 2)
    y = target[0:100]        # 0或1

    # 将 0 -> -1,1 -> +1
    y = np.where(y == 0, -1, 1)

    # 2. 数据可视化(散点图)
    index_0 = np.where(y == -1)
    index_1 = np.where(y == 1)

    plt.figure(figsize=(6, 4))
    plt.scatter(X[index_0, 0], X[index_0, 1], marker='x', label='setosa (-1)')
    plt.scatter(X[index_1, 0], X[index_1, 1], marker='o', label='versicolor (+1)')
    plt.xlabel('萼片长度 (sepal length)')
    plt.ylabel('花瓣长度 (petal length)')
    plt.legend(loc='lower right')
    plt.title('Iris 子集(前100个样本)- 特征散点图')
    plt.show()

    # 3. 训练感知器并记录错误数
    ppn = Perceptron(eta=0.1, n_iter=10)
    ppn.fit(X, y)

    # 4. 绘制每轮的错误数量(学习曲线)
    plt.figure(figsize=(6, 4))
    plt.plot(range(1, len(ppn.errors_) + 1), ppn.errors_, marker='o')
    plt.xlabel('训练轮数 (epoch)')
    plt.ylabel('错误分类数')
    plt.title('感知器训练过程 - 错误数随迭代的变化')
    plt.grid(True)
    plt.show()

    # 5. 查看训练结果:权重和训练集准确率
    print("训练得到的权重(包含偏置w_0):", ppn.w_)
    y_pred = ppn.predict(X)
    accuracy = np.mean(y_pred == y)
    print(f"训练集准确率: {accuracy * 100:.2f}%")

    # 6. 可选:绘制决策边界(二维特征空间)
    # 创建网格来绘制决策边界
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
                         np.linspace(y_min, y_max, 200))
    grid = np.c_[xx.ravel(), yy.ravel()]
    Z = ppn.predict(grid)
    Z = Z.reshape(xx.shape)

    plt.figure(figsize=(6, 4))
    plt.contourf(xx, yy, Z, alpha=0.2)
    plt.scatter(X[index_0, 0], X[index_0, 1], marker='x', label='setosa (-1)')
    plt.scatter(X[index_1, 0], X[index_1, 1], marker='o', label='versicolor (+1)')
    plt.xlabel('萼片长度 (sepal length)')
    plt.ylabel('花瓣长度 (petal length)')
    plt.legend(loc='lower right')
    plt.title('感知器决策边界(灰色区域表示分类)')
    plt.show()

if __name__ == "__main__":
    main()

题解代码分析

数据加载与预处理

python 复制代码
iris = load_iris()
data = iris.data
target = iris.target
X = data[0:100, [0, 2]]
y = target[0:100]
y = np.where(y == 0, -1, 1)
  • load_iris() 返回一个包含 data(150×4)和 target (150,) 的字典样对象。
  • 只取前 100 个样本 (0..99),这两个类别是线性可分性较好的组合:setosa 与 versicolor。
  • 选择第 1 列(萼片长度,index=0)与第 3 列(花瓣长度,index=2)作为演示用的二维特征,便于可视化。
  • 将标签从 {0,1} 映射到 {-1, +1},因为经典感知器输出是符号函数(sign)。

感知器类(Perceptron)

python 复制代码
class Perceptron:
    def __init__(self, eta=0.01, n_iter=10):
        ...
    def fit(self, X, y):
        ...
    def net_input(self, X):
        ...
    def predict(self, X):
        ...
  • w_ 含偏置,长度为 n_features + 1w_[0] 是偏置(bias/intercept)。
  • fit 中采用逐样本(online)更新规则:w <- w + eta * (target - predict(x)) * x(对偏置 x0=1 用同样的更新)。
  • 这里 predict 使用 np.where(net >= 0.0, 1, -1),决定阈值为 0。
  • errors_ 保存每轮的错误数,用来绘制学习曲线与判断是否收敛。

可视化:散点图、错误曲线、决策边界

  • 散点图:直观查看两类在所选二维特征空间中的分布。
  • 错误曲线:如果感知器能线性分离数据,错误数通常会下降到 0 并维持;若不能完全线性分离,错误数会在某个水平徘徊。
  • 决策边界:我们用网格对整个特征空间进行预测并绘制等高面(实质是分类分隔面的可视化),这样就能看到模型如何把空间切分为两类。

示例测试及结果

运行方式 :将上面的代码保存为 perceptron_iris_demo.py,在支持的 Python 环境下运行:

bash 复制代码
python perceptron_iris_demo.py

预期输出与解释

  1. 首先弹出一个散点图窗口:可以看到 setosa(x 标记)与 versicolor(o 标记)大体上能用一条线分开(setosa 的花瓣长度普遍更短)。

  2. 然后弹出训练错误数随 epoch 变化的折线图。对于前100个样本(两个类别),在合理的学习率和轮数下,通常会在若干 epoch 后收敛到 0 错误或极低错误数(因为这两个类别在这两个维度上几乎线性可分)。

  3. 控制台会输出类似:

    训练得到的权重(包含偏置w_0): [ -3.2 0.8 1.5 ]
    训练集准确率: 100.00%

(这里的数字会因随机初始化/实现差异与超参不同而不同,上面只是示意)。

  1. 最后会弹出一个决策边界图,可以直观看到分类线把平面分为两块,与散点分布吻合。

示例说明:如果训练集准确率接近 100%,说明感知器在这两个类别与选定特征下很好地学习到了线性分隔边界。若准确率显著低于 100%,可能原因包括学习率过小/过大、epoch 太少或数据并非线性可分(在选取不同特征时尤为常见)。

时间复杂度

设 (n) 为样本数,(d) 为特征维度,(T) 为训练轮数(epochs)。

  • 训练时间复杂度:每个 epoch 中对每个样本做一次预测(内积 (O(d)))并在错误时更新权重(更新也是 (O(d)))。总体为 (O(T \cdot n \cdot d))。

    • 对于本文:(n=100, d=2),所以开销极小,适合交互式演示或嵌入式场景。
  • 预测时间复杂度:对单样本预测是一次内积,复杂度 (O(d))。批量预测 (O(n \cdot d))。

空间复杂度

  • 权重向量存储占用 (O(d))(包含一个偏置)。
  • 不考虑数据本身(若需要在内存中保存数据,则为 (O(n \cdot d)))。
  • 训练过程额外只保存 errors_ 长度为 (T) 的列表,空间开销 (O(T))(通常远小于 (n\cdot d))。

因此总体空间复杂度主要被数据存储支配:(O(n \cdot d))。

小结

  • 本文以通俗且实操的方式展示了如何用 Iris 数据集的前 100 个样本(setosa 与 versicolor)训练一个从零实现的感知器,过程包含数据提取、标签映射、训练、学习曲线和决策边界可视化。
  • 感知器适合线性可分的二分类问题:若数据线性可分,经过若干迭代错误数会降到 0;否则会震荡或停留在某个非零错误数。
  • 实际应用场景很多:教学、轻量级终端判别、快速原型等。若需要更高性能或非线性判别,应该考虑使用支持向量机(SVM)、逻辑回归或神经网络等更强的模型。
  • 最后,代码写得尽量清晰,适合直接运行并修改超参数(学习率 eta、迭代次数 n_iter)或特征选择来观察不同设置的效果。
相关推荐
少林and叔叔3 小时前
基于yolov5.7.0的人工智能算法的下载、开发环境搭建(pycharm)与运行测试
人工智能·pytorch·python·yolo·目标检测·pycharm
心.c3 小时前
深拷贝浅拷贝
开发语言·前端·javascript·ecmascript
曦樂~3 小时前
【Qt】启动新窗口--C/S传输信息
开发语言·qt
源代码•宸3 小时前
Qt6 学习——一个Qt桌面应用程序
开发语言·c++·经验分享·qt·学习·软件构建·windeployqt
寻找华年的锦瑟3 小时前
Qt-UDP
开发语言·qt·udp
橘颂TA3 小时前
【QSS】软件界面的美工操作——Qt 界面优化
开发语言·qt·c/c++·界面设计
合作小小程序员小小店4 小时前
旧版本附近停车场推荐系统demo,基于python+flask+协同推荐(基于用户信息推荐),开发语言python,数据库mysql,
人工智能·python·flask·sklearn·推荐算法
动能小子ohhh4 小时前
Langchain从零开始到应用落地案例[AI智能助手]【3】---使用Paddle-OCR识别优化可识别图片进行解析回答
人工智能·python·pycharm·langchain·ocr·paddle·1024程序员节
Evand J4 小时前
【MATLAB例程】二维环境定位,GDOP和CRLB的计算,锚点数=4的情况(附代码下载链接)
开发语言·matlab·定位·toa·crlb·gdop