RLS(递归最小二乘)算法详解

RLS(递归最小二乘)算法详解

RLS是在线学习 的"王者算法",能实时更新模型,等价于增量伪逆,但数值更稳定。以下从数学原理到代码实现完整讲解。


一、核心思想:从批量到递归

批量最小二乘(使用场景)

给定 nnn 个地点的模式矩阵 P ∈ ℝ^(n×32) 和标签 Y=InY = I_nY=In(单位矩阵),求解:
W∗=argmin∣∣PW−Y∣∣2解:W=P+T(伪逆) W* = argmin ||P W - Y||² 解:W = P⁺ᵀ (伪逆) W∗=argmin∣∣PW−Y∣∣2解:W=P+T(伪逆)
问题 :每次新增地点,需重新计算整个伪逆O(n3)O(n³)O(n3)。

RLS的递归思想

已知 nnn 个样本的最优解 WnW_nWn,当第 n+1n+1n+1 个样本 (pnew,ynew)(p_new, y_new)(pnew,ynew) 到来时:
Wn+1=Wn+ΔW W_{n+1} = W_n + ΔW Wn+1=Wn+ΔW

只基于新样本历史信息 (如协方差矩阵),无需重新计算所有历史样本


二、数学推导(最小二乘的递归形式)

目标函数

在时刻 nnn,最小化累计平方误差:
J(W)=Σi=1n∣∣piW−yi∣∣2+λ∣∣W∣∣2 J(W) = Σ_{i=1}^n ||p_i W - y_i||² + λ ||W||² J(W)=Σi=1n∣∣piW−yi∣∣2+λ∣∣W∣∣2

其中 λ 是正则化系数(防止过拟合)。

关键:协方差矩阵逆的递推

定义 Sn=(PnTPn+λI)−1S_n = (P_nᵀ P_n + λI)^{-1}Sn=(PnTPn+λI)−1(32×32矩阵),则RLS更新包含两个步骤:

步骤1:计算增益向量

kn=Sn−1pnewT/(1+pnewSn−1pnewT)[32×1] k_n = S_{n-1} p_{new}ᵀ / (1 + p_{new} S_{n-1} p_{new}ᵀ) [32×1] kn=Sn−1pnewT/(1+pnewSn−1pnewT)[32×1]
直观理解 :knk_nkn 是"新信息的重要性权重"。如果 pnewp_{new}pnew 与历史数据相似,增益小;如果完全新颖,增益大。

步骤2:更新伪逆矩阵

Wn=Wn−1+kn(ynew−pnewWn−1) W_n = W_{n-1} + k_n (y_{new} - p_{new} W_{n-1}) Wn=Wn−1+kn(ynew−pnewWn−1)

直观理解:新权重 = 旧权重 + 增益 × 预测误差。

步骤3:更新协方差逆矩阵

Sn=Sn−1−knpnewSn−1 S_n = S_{n-1} - k_n p_{new} S_{n-1} Sn=Sn−1−knpnewSn−1
直观理解:新协方差 = 旧协方差 - 已学习的信息。


三、RLS算法实现(Python代码)

python 复制代码
class RLSAssociativeMemory:
    def __init__(self, feature_dim=32, lambda_reg=1.0):
        self.d = feature_dim
        self.n = 0
        
        # 协方差矩阵逆(S),初始化为正则化单位矩阵
        self.S = np.eye(feature_dim) / lambda_reg  # [32, 32]
        
        # 伪逆矩阵W(增量构建)
        self.W = None  # 初始为空
        
        # 正则化参数(遗忘因子)
        self.lambda_reg = lambda_reg
    
    def add_place(self, p_new, place_id):
        """
        增量添加新地点
        p_new: 地点特征向量 [32] 或 [1, 32]
        place_id: 地点索引(0到n-1)
        """
        # 1. 预处理输入
        p_new = np.array(p_new).reshape(1, self.d)  # [1, 32]
        y_new = np.zeros((1, self.n + 1))  # 期望输出(单位矩阵的新列)
        y_new[0, self.n] = 1.0  # 新地点对应位置为1
        
        # 2. 计算增益向量 k_n
        # k = S @ p.T / (1 + p @ S @ p.T)
        k = self.S @ p_new.T  # [32, 1]
        denom = 1.0 + p_new @ k  # 标量
        k = k / denom  # 最终增益
        
        # 3. 更新伪逆矩阵W
        if self.W is None:
            # 第一个地点
            self.W = k  # [32, 1]
        else:
            # 预测误差 (y_new - p_new @ W_old)
            # y_new是[1, n+1],需要扩展W_old
            pred = p_new @ self.W  # [1, n]
            error = np.zeros((1, self.n + 1))
            error[0, :self.n] = -pred  # 旧部分误差
            error[0, self.n] = 1.0    # 新部分误差
            
            # 更新: W = [W_old, 0] + k @ error
            W_extended = np.hstack([self.W, np.zeros((self.d, 1))])
            self.W = W_extended + k @ error  # [32, n+1]
        
        # 4. 更新协方差逆矩阵 S_n
        self.S = self.S - k @ p_new @ self.S
        
        # 5. 记录地点
        self.n += 1
        
        return self.W
    
    def query(self, query_vec, threshold=0.8):
        """
        查询地点索引
        query_vec: [32]
        """
        query_vec = query_vec.reshape(1, self.d)
        scores = query_vec @ self.W  # [1, n]
        
        pred_idx = np.argmax(scores)
        confidence = scores[0, pred_idx]
        
        if confidence < threshold:
            return None, confidence  # 未知地点
        
        return pred_idx, confidence

# 测试
memory = RLSAssociativeMemory(feature_dim=32)

# 添加50个地点(流式)
patterns = np.random.randn(50, 32)
for i in range(50):
    memory.add_place(patterns[i], place_id=i)
    
    # 每10个地点验证一次
    if i % 10 == 0 and i > 0:
        idx, conf = memory.query(patterns[i])
        print(f"添加地点{i}后,查询索引: {idx}, 置信度: {conf:.6f}")

四、RLS与增量伪逆的关系

RLS的 SnS_nSn 就是 PnTPnP_nᵀ P_nPnTPn 的伪逆,而 WnW_nWn 就是 Pn+TP_n⁺ᵀPn+T。

数学等价
Wn=SnPnT W_n = S_n P_nᵀ Wn=SnPnT

当新增样本 pnewp_{new}pnew 时,RLS的递归更新 精确等价于 重新计算 Pn+1+TP_{n+1}⁺ᵀPn+1+T,但复杂度从 O(n³) 降到 O(d²)。


五、RLS的关键优势

1. 数值稳定性

伪逆的直接计算在样本接近线性相关时数值不稳定。RLS通过协方差矩阵逆的递推增益调节,避免了矩阵求逆的病态问题。

2. 计算效率

  • 添加地点:O(d²) = O(32²) = 恒定时间
  • 重新计算伪逆:O(n³) = O(50³) = 125,000次运算
  • 查询:两者都是O(n·d)

3. 遗忘旧数据(可选项)

引入遗忘因子 λ(0 < λ ≤ 1),让旧地点影响衰减:
Sn=(Sn−1−knpnewSn−1)/λ S_n = (S_{n-1} - k_n p_new S_{n-1}) / λ Sn=(Sn−1−knpnewSn−1)/λ

适合动态环境(旧地点可能被重新探索并更新)。


六、RLS在眼动场景中的具体应用

python 复制代码
class BrainRLSMemory:
    def __init__(self, feature_dim=32):
        self.rls = RLSAssociativeMemory(feature_dim)
        self.place_sequences = {}  # 地点ID → 眼动序列
    
    def observe_and_learn(self, flc_vector, gaze_position):
        """
        观察新地点并学习
        """
        place_id = self.rls.n  # 新地点ID
        
        # 1. 增量学习地点特征
        self.rls.add_place(flc_vector, place_id)
        
        # 2. 存储眼动序列(用于预测下一注视点)
        self.place_sequences[place_id] = {
            'sequence': [gaze_position],
            'confidence': 0.0
        }
        
        return place_id
    
    def recognize_and_predict_saccade(self, current_flc):
        """
        识别当前地点并预测下一眼动
        """
        # 1. 查询地点
        place_id, confidence = self.rls.query(current_flc)
        
        if place_id is not None:
            # 2. 从历史序列预测下一注视点
            history = self.place_sequences[place_id]['sequence']
            if len(history) > 1:
                # 用最后两个位置计算眼动向量
                last_gaze = history[-1]
                prev_gaze = history[-2]
                saccade_vector = last_gaze - prev_gaze
                
                return {
                    'place_id': place_id,
                    'saccade': saccade_vector,
                    'confidence': confidence
                }
        
        # 未知地点或数据不足
        return self.exploratory_saccade()
    
    def exploratory_saccade(self):
        """随机探索"""
        return {
            'place_id': None,
            'saccade': np.random.randn(2) * 0.1,  # 小范围随机
            'confidence': 0.0
        }

七、RLS的局限性

  1. 内存占用 :需维护 S 矩阵(32×32),空间复杂度 O(d²)
  2. 初始正则化λ 选择敏感,太大收敛慢,太小不稳定
  3. 无删除操作:无法主动遗忘特定地点(需用遗忘因子间接实现)

八、最终建议

在你的场景中(50个地点,32维特征):

  • 离线训练 :直接用 torch.linalg.pinv()(0.12ms,代码简单)
  • 在线学习:使用RLS(0.03ms/次,适合机器人持续探索)

RLS的核心价值:将批量伪逆的O(n³)问题转化为O(d²)的在线更新,完美契合生物"边探索边学习"的特性。

相关推荐
南方的狮子先生2 小时前
【C++】C++文件读写
java·开发语言·数据结构·c++·算法·1024程序员节
阿里云云原生2 小时前
阿里云 FunctionAI 技术详解:基于 Serverless 的企业级 AI 原生应用基础设施构建
人工智能·阿里云·serverless
感智教育2 小时前
2025 年世界职业院校技能大赛汽车制造与维修赛道备赛方案
人工智能·汽车·制造
Alex艾力的IT数字空间2 小时前
完整事务性能瓶颈分析案例:支付系统事务雪崩优化
开发语言·数据结构·数据库·分布式·算法·中间件·php
8Qi82 小时前
Stable Diffusion详解
人工智能·深度学习·stable diffusion·图像生成
激动的小非2 小时前
电商数据分析报告
大数据·人工智能·数据分析
玖剹2 小时前
二叉树递归题目(一)
c语言·c++·算法·leetcode
ChoSeitaku2 小时前
线代强化NO6|矩阵|例题|小结
算法·机器学习·矩阵
carver w2 小时前
transformer 手写数字识别
人工智能·深度学习·transformer