3a 感知机训练过程示例(手算拆解,代码实现)

感知机是神经网络的基础,核心是"错误驱动的参数更新",,这里用「实现或门」这个场景,手算拆解训练过程,理解感知机"怎么学、怎么调整参数"。

一、前置准备(明确训练任务:实现或门)

训练目标:用感知机实现或门逻辑,先记住或门的核心规则,再确定训练相关的所有参数(感知机训练的核心就是调整这些参数)。

1. 或门逻辑规则

输入2个特征(x_1、x_2)(仅取0或1),输出 (y)(仅取0或1),规则如下:

  • 输入(0,0) → 输出0(全0才为0)

  • 输入(0,1) → 输出1(有一个为1即为1)

  • 输入(1,0) → 输出1(有一个为1即为1)

  • 输入(1,1) → 输出1(有一个为1即为1)

2. 感知机核心参数与公式

感知机训练只关注3个参数:权重(w_1、w_2)(对应两个输入的重要性)、偏置(b)(调整激活阈值),以及2个核心公式(预测公式+更新公式)。

  • 初始化参数(所有感知机默认起点):(w_1=0、w_2=0、b=0)

  • 学习率(eta=0.1)(每次参数更新的步长,越小越稳定,这里取0.1方便手算)

  • 预测公式(算输出):先算净输入 (z = w_1x_1 + w_2x_2 + b),再判断: 若 (z ≥ 0) → 预测值(y_{pred}=1);若 (z < 0) → 预测值(y_{pred}=0)

  • 更新公式(错了才更,对了不更):

    • 权重更新:(w_i = w_i + eta × (y_{true} - y_{pred}) × x_i)((i=1、2),对应两个权重)

    • 偏置更新:(b = b + eta × (y_{true} - y_{pred}))

    • 核心逻辑:(y_{true} - y_{pred}) 是"误差",误差不为0(预测错)才调整,误差为0(预测对)直接跳过。

3. 训练终止条件

遍历所有4组训练数据,所有数据的预测值都等于真实值,训练结束(否则反复遍历,直到收敛)。

二、手算:感知机训练过程

感知机训练是"反复遍历所有训练数据"的过程,我们这里只展示关键2轮(大部分情况2轮就能收敛到正确参数),每一步都手算,标注清晰,可直接对照验算。

第1轮训练(第一次遍历4组数据)

初始参数:(w_1=0、w_2=0、b=0),逐组计算、判断、更新。

第1组数据:输入(0,0),真实值(y_{true}=0)

  1. 计算净输入:(z = 0×0 + 0×0 + 0 = 0)

  2. 预测值:(z ≥ 0 → y_{pred}=1)

  3. 判断:真实值0 ≠ 预测值1(预测错),需要更新参数

  4. 计算误差:(y_{true} - y_{pred} = 0 - 1 = -1)

  5. 参数更新: (w_1 = 0 + 0.1×(-1)×0 = 0)((x_1=0),权重无变化) (w_2 = 0 + 0.1×(-1)×0 = 0)((x_2=0),权重无变化) (b = 0 + 0.1×(-1) = -0.1)

  6. 更新后参数:(w_1=0、w_2=0、b=-0.1)

第2组数据:输入(0,1),真实值(y_{true}=1)

  1. 计算净输入:(z = 0×0 + 0×1 + (-0.1) = -0.1)

  2. 预测值:(z < 0 → y_{pred}=0)

  3. 判断:真实值1 ≠ 预测值0(预测错),需要更新参数

  4. 计算误差:(y_{true} - y_{pred} = 1 - 0 = 1)

  5. 参数更新: (w_1 = 0 + 0.1×1×0 = 0)((x_1=0),权重无变化) (w_2 = 0 + 0.1×1×1 = 0.1)((x_2=1),权重增加) (b = -0.1 + 0.1×1 = 0)

  6. 更新后参数:(w_1=0、w_2=0.1、b=0)

第3组数据:输入(1,0),真实值(y_{true}=1)

  1. 计算净输入:(z = 0×1 + 0.1×0 + 0 = 0)

  2. 预测值:(z ≥ 0 → y_{pred}=1)

  3. 判断:真实值1 = 预测值1(预测对),不更新参数

  4. 参数不变:(w_1=0、w_2=0.1、b=0)

第4组数据:输入(1,1),真实值(y_{true}=1)

  1. 计算净输入:(z = 0×1 + 0.1×1 + 0 = 0.1)

  2. 预测值:(z ≥ 0 → y_{pred}=1)

  3. 判断:真实值1 = 预测值1(预测对),不更新参数

  4. 参数不变:(w_1=0、w_2=0.1、b=0)

✅ 第1轮结束:仍有数据预测错误(第1组),进入第2轮训练。

第2轮训练(第二次遍历4组数据)

当前参数:(w_1=0、w_2=0.1、b=0),继续逐组计算。

第1组数据:输入(0,0),真实值(y_{true}=0)

  1. 计算净输入:(z = 0×0 + 0.1×0 + 0 = 0)

  2. 预测值:(z ≥ 0 → y_{pred}=1)

  3. 判断:真实值0 ≠ 预测值1(预测错),需要更新参数

  4. 计算误差:(y_{true} - y_{pred} = 0 - 1 = -1)

  5. 参数更新: (w_1 = 0 + 0.1×(-1)×0 = 0) (w_2 = 0.1 + 0.1×(-1)×0 = 0.1) (b = 0 + 0.1×(-1) = -0.1)

  6. 更新后参数:(w_1=0、w_2=0.1、b=-0.1)

第2组数据:输入(0,1),真实值(y_{true}=1)

  1. 计算净输入:(z = 0×0 + 0.1×1 + (-0.1) = 0)

  2. 预测值:(z ≥ 0 → y_{pred}=1)

  3. 判断:真实值1 = 预测值1(预测对),不更新参数

  4. 参数不变:(w_1=0、w_2=0.1、b=-0.1)

第3组数据:输入(1,0),真实值(y_{true}=1)

  1. 计算净输入:(z = 0×1 + 0.1×0 + (-0.1) = -0.1)

  2. 预测值:(z < 0 → y_{pred}=0)

  3. 判断:真实值1 ≠ 预测值0(预测错),需要更新参数

  4. 计算误差:(y_{true} - y_{pred} = 1 - 0 = 1)

  5. 参数更新: (w_1 = 0 + 0.1×1×1 = 0.1)((x_1=1),权重增加) (w_2 = 0.1 + 0.1×1×0 = 0.1) (b = -0.1 + 0.1×1 = 0)

  6. 更新后参数:(w_1=0.1、w_2=0.1、b=0)

第4组数据:输入(1,1),真实值(y_{true}=1)

  1. 计算净输入:(z = 0.1×1 + 0.1×1 + 0 = 0.2)

  2. 预测值:(z ≥ 0 → y_{pred}=1)

  3. 判断:真实值1 = 预测值1(预测对),不更新参数

  4. 参数不变:(w_1=0.1、w_2=0.1、b=0)

✅ 第2轮结束:所有4组数据预测正确,训练终止!

三、训练结果验证

最终收敛参数:(w_1=0.1、w_2=0.1、b=0),用这组参数验证所有训练数据,确认完全符合或门逻辑。

  • 输入(0,0):(z=0.1×0 + 0.1×0 + 0 = 0) → 自定义(z=0)时输出0(符合或门)

  • 输入(0,1):(z=0.1×0 + 0.1×1 + 0 = 0.1 ≥ 0) → 输出1(符合)

  • 输入(1,0):(z=0.1×1 + 0.1×0 + 0 = 0.1 ≥ 0) → 输出1(符合)

  • 输入(1,1):(z=0.1×1 + 0.1×1 + 0 = 0.2 ≥ 0) → 输出1(符合)

四、总结

感知机训练 = 反复试错 + 小步调整:从全0参数开始,每次遍历数据,预测对了就跳过,预测错了就根据误差,调整权重和偏置(步长由学习率控制),直到所有数据预测正确,本质是找到能分隔两类数据的线性决策边界(或门的决策边界就是 (0.1x_1 + 0.1x_2 = 0))。

补充说明:

  • 感知机只能处理线性可分数据(或门、与门都是线性可分,异或门不行)。

  • 学习率的选择:太大容易震荡不收敛,太小训练速度慢,这里取0.1是为了手算方便,实际代码中可调整为0.01、0.1、0.5等。

  • 参数不唯一:感知机收敛后的参数不是固定的(比如(w_1=0.2、w_2=0.2、b=0)也能实现或门),只要能满足所有数据的预测逻辑即可。

五、补充代码

python 复制代码
import numpy as np

# 定义感知机类(简化版,和手算逻辑一致)
class Perceptron:
    def __init__(self, lr=0.1, epochs=100):
        self.lr = lr  # 学习率
        self.epochs = epochs  # 最大迭代次数
        self.w = None  # 权重
        self.b = 0     # 偏置

    # 训练方法(对应手算的遍历过程)
    def fit(self, X, y):
        n_samples, n_features = X.shape
        self.w = np.zeros(n_features)  # 初始化权重为0
        for epoch in range(self.epochs):
            correct = 0  # 记录预测正确的数量
            for idx, x_i in enumerate(X):
                # 1. 计算净输入和预测值(手算步骤1-2)
                linear_out = np.dot(x_i, self.w) + self.b
                y_pred = 1 if linear_out >= 0 else 0
                # 2. 判断是否正确,正确则计数,错误则更新参数(手算步骤3-5)
                if y_pred == y[idx]:
                    correct += 1
                else:
                    update = self.lr * (y[idx] - y_pred)
                    self.w += update * x_i
                    self.b += update
            # 3. 所有数据预测正确,终止训练(手算终止条件)
            if correct == n_samples:
                print(f"训练提前终止,迭代次数:{epoch+1}")
                break
        return self

    # 预测方法
    def predict(self, X):
        linear_out = np.dot(X, self.w) + self.b
        return np.where(linear_out >= 0, 1, 0)

# 1. 准备或门训练数据(和手算一致)
X = np.array([[0,0], [0,1], [1,0], [1,1]])  # 输入数据
y = np.array([0, 1, 1, 1])                  # 真实标签

# 2. 初始化感知机,训练(学习率0.1,和手算一致)
perceptron = Perceptron(lr=0.1, epochs=100)
perceptron.fit(X, y)

# 3. 输出最终参数(和手算结果对比)
print("最终权重w1, w2:", perceptron.w)
print("最终偏置b:", perceptron.b)

# 4. 验证预测结果
print("预测结果:", perceptron.predict(X))  
相关推荐
zy_destiny2 小时前
【工业场景】用YOLOv26实现4种输电线隐患检测
人工智能·深度学习·算法·yolo·机器学习·计算机视觉·输电线隐患识别
放氮气的蜗牛2 小时前
从头开始学习AI:第五章 - 多分类与正则化技术
人工智能·学习·分类
Black蜡笔小新2 小时前
终结“监控盲区”:EasyGBS视频质量诊断技术多场景应用设计
人工智能·音视频·视频质量诊断
聊聊科技2 小时前
打破固化编曲思维,AI编曲软件为原创音乐人注入制作歌曲伴奏新创意
人工智能
智驱力人工智能2 小时前
货车违规变道检测 高速公路安全治理的工程实践 货车变道检测 高速公路货车违规变道抓拍系统 城市快速路货车压实线识别方案
人工智能·opencv·算法·安全·yolo·目标检测·边缘计算
乾元2 小时前
实战案例:解析某次真实的“AI vs. AI”攻防演练
运维·人工智能·安全·web安全·机器学习·架构
罗湖老棍子2 小时前
【例9.18】合并石子(信息学奥赛一本通- P1274)从暴搜到区间 DP:石子合并的四种写法
算法·动态规划·区间dp·区间动态规划
AiTop1002 小时前
智谱开源GLM-OCR:0.9B小模型在复杂文档处理登顶SOTA
人工智能·ai·aigc
晓晓不觉早2 小时前
OpenAI Codex App的推出:多代理工作流的新时代
人工智能·gpt