【自然语言处理 NLP】前沿架构与多模态 状态空间模型(Mamba/SSM)深度实现

目录

[​编辑6 非Transformer架构](#编辑6 非Transformer架构)

[6.1 状态空间模型(Mamba/SSM)深度实现](#6.1 状态空间模型(Mamba/SSM)深度实现)

[6.1.1 连续状态空间模型(S4)的离散化](#6.1.1 连续状态空间模型(S4)的离散化)

[2 结构化伪代码](#2 结构化伪代码)

[6.1.1 结构化伪代码实现](#6.1.1 结构化伪代码实现)

[3 代码实现](#3 代码实现)

脚本1:HiPPO初始化与可视化

脚本2:S4离散化实现

脚本3:卷积模式(FFT加速)

脚本4:循环模式(自回归生成)

脚本5:双模式切换系统


6 非Transformer架构

6.1 状态空间模型(Mamba/SSM)深度实现
6.1.1 连续状态空间模型(S4)的离散化

1.1 连续状态空间模型基础

1.1.1 状态空间表示

连续时间动态系统通过状态空间方程描述输入序列到输出序列的映射。设输入信号为 u(t) \\in \\mathbb{R},隐藏状态为 x(t) \\in \\mathbb{R}\^N,输出为 y(t) \\in \\mathbb{R},系统演化遵循线性常微分方程组:

\\dot{x}(t) = Ax(t) + Bu(t)

y(t) = Cx(t) + Du(t)

其中 A \\in \\mathbb{R}\^{N \\times N} 为状态矩阵,B \\in \\mathbb{R}\^{N \\times 1} 为输入矩阵,C \\in \\mathbb{R}\^{1 \\times N} 为输出矩阵,D \\in \\mathbb{R} 为直馈系数。该表征将序列建模转化为连续时间信号处理,通过适当选择 A 矩阵结构,可实现对长程依赖的有效捕获。

1.1.2 HiPPO初始化理论

HiPPO(High-order Polynomial Projection Operator)理论为状态空间模型提供了一种函数逼近框架,使得状态向量 x(t) 能够压缩历史输入信息。该理论基于在滑动时间窗口上对输入信号进行正交多项式投影,通过最小化重构误差实现最优记忆。

针对尺度不变记忆机制,HiPPO-LegS(Legendre Scale-invariant)采用Legendre多项式基函数,其状态转移遵循特定微分代数结构。定义度量函数为 \\mu(t) = \\frac{1}{t},正交基为 scaled Legendre polynomials,则HiPPO矩阵具备解析形式:

A_{nk} = -\\begin{cases} (2n+1)\^{1/2}(2k+1)\^{1/2} \& \\text{if } n \> k \\\\ n+1 \& \\text{if } n = k \\\\ 0 \& \\text{if } n \< k \\end{cases}

B_n = (2n+1)\^{1/2}

该结构确保状态矩阵 A 具备特定负定性,使得历史输入以多项式衰减方式被记忆,而非指数遗忘。HiPPO初始化将上述理论矩阵作为SSM参数初始值,赋予模型对长序列的归纳偏置。

1.2 S4离散化框架

1.2.1 离散化方法论

将连续SSM转换为离散形式需采用数值积分策略,常见方法包括双线性变换(Bilinear Transform)与零阶保持(Zero-Order Hold, ZOH)。设离散化步长为 \\Delta,则连续到离散的映射关系如下。

双线性变换基于梯形积分规则,将连续系统矩阵映射为离散等价形式:

\\bar{A} = (I - \\Delta/2 \\cdot A)\^{-1}(I + \\Delta/2 \\cdot A)

\\bar{B} = (I - \\Delta/2 \\cdot A)\^{-1}\\Delta B

零阶保持假设输入在采样间隔内保持常数,提供更精确的离散化:

\\bar{A} = e\^{\\Delta A}

\\bar{B} = (\\Delta A)\^{-1}(e\^{\\Delta A} - I)\\Delta B

离散化后的状态空间方程转换为递推形式:

x_k = \\bar{A}x_{k-1} + \\bar{B}u_k

y_k = \\bar{C}x_k + \\bar{D}u_k

其中 \\bar{C} = C\\bar{D} = D

1.2.2 双模式架构设计

S4模型的核心创新在于其双重计算模式:卷积模式适用于并行训练,循环模式支持自回归生成。两种模式共享同一组离散化参数,但计算路径截然不同。

卷积模式利用线性时不变系统的卷积特性,将输出表示为输入与系统脉冲响应的卷积。脉冲响应 \\bar{K} 由状态矩阵完全确定:

\\bar{K} = (C\\bar{A}\^0\\bar{B}, C\\bar{A}\^1\\bar{B}, \\dots, C\\bar{A}\^{L-1}\\bar{B})

该形式允许通过快速傅里叶变换(FFT)在 O(L \\log L) 复杂度内完成序列到序列的映射,其中 L 为序列长度。

循环模式直接执行状态递推,适用于流式处理与自回归采样。通过逐时间步更新隐藏状态,实现 O(N) 每步计算复杂度,其中 N 为状态维度。

1.3 计算模式实现

1.3.1 卷积模式与FFT加速

卷积模式的核心在于计算全局卷积核 \\bar{K} \\in \\mathbb{R}\^L。利用卷积定理,时域卷积等价于频域逐点乘积:

y = u \* \\bar{K} = \\mathcal{F}\^{-1}(\\mathcal{F}(u) \\odot \\mathcal{F}(\\bar{K}_{\\text{padded}}))

其中 \\mathcal{F} 表示离散傅里叶变换,\\odot 为逐元素乘法。为实现该计算,需构造长度为 2L-1 的零填充卷积核,通过FFT计算后截取前 L 个有效输出。

针对长序列,采用重叠相加(Overlap-Add)或重叠保留(Overlap-Save)算法将全局卷积分解为分块处理,降低内存占用并保持计算效率。

1.3.2 循环模式与自回归生成

循环模式通过迭代执行状态更新方程实现序列生成。对于每一步 k,系统接收当前输入 u_k,更新内部状态,并产生输出 y_k

x_k \\leftarrow \\bar{A}x_{k-1} + \\bar{B}u_k

y_k \\leftarrow Cx_k + Du_k

该模式的关键优化在于状态矩阵 \\bar{A} 的结构性分解。S4采用对角化 plus 低秩(Diagonal Plus Low-Rank, DPLR)或纯对角(Diagonal)结构,使得矩阵幂运算可通过逐元素指数运算实现:

\\bar{A} = \\Lambda = \\text{diag}(\\lambda_1, \\lambda_2, \\dots, \\lambda_N)

\\bar{A}\^k B = \\Lambda\^k B = (\\lambda_1\^k B_1, \\lambda_2\^k B_2, \\dots, \\lambda_N\^k B_N)\^\\top

此结构将矩阵向量乘法复杂度降至 O(N),并实现稳定的梯度传播。

1.3.3 模式切换机制

双模式切换机制根据计算上下文动态选择执行路径。训练阶段采用卷积模式以最大化并行度与GPU利用率;推理阶段切换至循环模式以支持流式处理。

切换控制器维护离散化参数集合 \\theta = \\{\\bar{A}, \\bar{B}, C, D\\},根据模式标志位 m \\in \\{\\text{conv, recurrent}\\} 路由计算图。关键实现细节包括:

  • 状态初始化管理:卷积模式下隐藏状态隐式存在,循环模式需显式维护状态缓存。切换时执行状态同步,将卷积模式的累积效应转换为循环模式的初始状态 x_0

  • 核函数预计算:卷积模式离线计算全局卷积核 \\bar{K},存储于内存以避免重复计算;循环模式在线计算单步转移。

  • 梯度流控制:训练时通过卷积模式计算整体梯度,确保时间并行性;循环模式仅用于推理或梯度检查点。


2 结构化伪代码

6.1.1 结构化伪代码实现

算法 1:HiPPO-LegS 初始化

该算法用于初始化状态矩阵 A 与输入矩阵 B,赋予模型尺度不变的记忆偏置。

  • 输入:状态维度 N

  • 过程

    1. 初始化 A \\in \\mathbb{R}\^{N \\times N} 为全零矩阵,B \\in \\mathbb{R}\^N 为全零向量。

    2. 对于 n0N-1

      • 设置 B\[n\] = (2n+1)\^{1/2}

      • 对于 k0n-1

        • 设置 A\[n,k\] = - (2n+1)\^{1/2}(2k+1)\^{1/2}
      • 设置 A\[n,n\] = -(n+1)

  • 输出:矩阵 A, B


算法 2:连续到离散转换(ZOH 方法)

利用零阶保持(Zero-Order Hold)将连续系统参数转换为离散时间步参数。

  • 输入:连续矩阵 A, B,步长 \\Delta

  • 过程

    1. I \\leftarrow \\text{identity}(\\text{dim}(A))

    2. \\bar{A} \\leftarrow \\exp(\\Delta \\cdot A)

    3. A_{\\text{inv}} \\leftarrow \\text{pseudoinverse}(A)

    4. \\bar{B} \\leftarrow A_{\\text{inv}} \\cdot (\\bar{A} - I) \\cdot \\Delta \\cdot B

  • 输出:离散矩阵 \\bar{A}, \\bar{B}


算法 3:卷积模式前向传播(FFT 加速)

适用于训练阶段,通过 FFT 将状态转换变为全局卷积操作。

  • 输入:输入序列 u,离散参数 \\{\\bar{A}, \\bar{B}, C, D\\}

  • 过程

    1. L \\leftarrow \\text{length}(u)

    2. 构造卷积核 \\bar{K}:对于 k \\in \[0, L-1\]\\bar{K}\[k\] = C \\cdot \\bar{A}\^k \\cdot \\bar{B}

    3. u\\bar{K} 进行零填充(Padding)至长度 2L-1

    4. U_{\\text{freq}} \\leftarrow \\text{FFT}(u_{\\text{padded}})

    5. K_{\\text{freq}} \\leftarrow \\text{FFT}(\\bar{K}_{\\text{padded}})

    6. Y_{\\text{freq}} \\leftarrow U_{\\text{freq}} \\odot K_{\\text{freq}} (逐元素乘法)

    7. y \\leftarrow \\text{IFFT}(Y_{\\text{freq}})\[0:L-1\] + D \\cdot u

  • 输出:响应序列 y


算法 4:循环模式前向传播(自回归生成)

适用于推理阶段,通过隐藏状态逐步递推。

  • 输入:输入序列 u,初始状态 x_0,离散参数 \\{\\bar{A}, \\bar{B}, C, D\\}

  • 过程

    1. x \\leftarrow x_0

    2. 对于每个时间步 k0L-1

      • 更新状态:x \\leftarrow \\bar{A} \\cdot x + \\bar{B} \\cdot u\[k\]

      • 计算输出:y\[k\] \\leftarrow C \\cdot x + D \\cdot u\[k\]

  • 输出:序列 y,最终状态 x


算法 5:双模式切换控制器

根据上下文(训练或推理)自动选择计算路径。

  • 输入:输入序列 u,模式 \\text{mode},参数 \\theta = \\{\\bar{A}, \\bar{B}, C, D\\}

  • 过程

    1. 初始化状态缓存 x_{\\text{cache}} 为零向量。

    2. \\text{mode} = \\text{'convolution'}

      • y \\leftarrow \\text{Conv\\_Mode\\_Forward}(u, \\theta)

      • x_{\\text{cache}} \\leftarrow \\text{Extract\\_Final\\_State}(y, \\theta)

    3. \\text{mode} = \\text{'recurrent'}

      • y, x_{\\text{cache}} \\leftarrow \\text{Recurrent\\_Mode\\_Forward}(u, x_{\\text{cache}}, \\theta)
  • 输出:输出序列 y,更新后的状态缓存 x_{\\text{cache}}

3 代码实现

脚本1:HiPPO初始化与可视化

该脚本实现HiPPO-LegS矩阵的构造,可视化状态矩阵的谱特性与记忆核函数。运行方式:python hippo_initializer.py

Python

复制代码
"""
脚本1:HiPPO初始化与可视化
内容:实现HiPPO-LegS矩阵构造,分析其数学特性与可视化
使用方式:直接运行 python hippo_initializer.py
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import expm, eigvals
import seaborn as sns

class HiPPOLegS:
    """
    HiPPO-LegS初始化实现
    基于Legendre多项式的尺度不变记忆机制
    """
    
    def __init__(self, N):
        """
        初始化HiPPO-LegS矩阵
        
        Args:
            N: 状态维度(Legendre多项式阶数)
        """
        self.N = N
        self.A, self.B = self._build_hippo_matrices()
        
    def _build_hippo_matrices(self):
        """
        构造HiPPO-LegS矩阵A和B
        
        基于Legendre多项式的正交投影:
        - A矩阵为严格下三角加负对角
        - B向量为Legendre多项式归一化系数
        """
        A = np.zeros((self.N, self.N))
        B = np.zeros(self.N)
        
        for n in range(self.N):
            # B[n] = sqrt(2n+1)
            B[n] = np.sqrt(2 * n + 1)
            
            for k in range(n):
                # A[n,k] = -sqrt((2n+1)(2k+1)) for k < n
                A[n, k] = -np.sqrt((2 * n + 1) * (2 * k + 1))
            
            # A[n,n] = -(n+1)
            A[n, n] = -(n + 1)
            
        return A, B
    
    def measure_function(self, t):
        """
        HiPPO-LegS的度量函数 mu(t) = 1/t
        定义在正实数上的概率密度
        """
        return 1.0 / t if t > 0 else 0.0
    
    def visualize_matrices(self):
        """
        可视化HiPPO矩阵结构与特征值分布
        """
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        # 1. 热力图展示A矩阵结构
        sns.heatmap(self.A, ax=axes[0,0], cmap='RdBu_r', center=0, 
                   square=True, cbar_kws={'label': 'Value'})
        axes[0,0].set_title('HiPPO-LegS State Matrix $A$')
        axes[0,0].set_xlabel('State Dimension $k$')
        axes[0,0].set_ylabel('State Dimension $n$')
        
        # 2. B向量可视化
        axes[0,1].bar(range(self.N), self.B, color='steelblue', alpha=0.7)
        axes[0,1].set_title('Input Vector $B$ (Legendre Coefficients)')
        axes[0,1].set_xlabel('Index $n$')
        axes[0,1].set_ylabel('$B_n = \sqrt{2n+1}$')
        axes[0,1].grid(True, alpha=0.3)
        
        # 3. 特征值分布(稳定性分析)
        eigenvalues = eigvals(self.A)
        axes[1,0].scatter(np.real(eigenvalues), np.imag(eigenvalues), 
                         c='red', s=100, alpha=0.6, edgecolors='black')
        axes[1,0].axvline(x=0, color='k', linestyle='--', alpha=0.3)
        axes[1,0].set_title('Eigenvalue Distribution of $A$ (Stability)')
        axes[1,0].set_xlabel('Real Part')
        axes[1,0].set_ylabel('Imaginary Part')
        axes[1,0].grid(True, alpha=0.3)
        
        # 4. 记忆核函数(与历史输入的交互)
        t_vals = np.linspace(0.1, 10, 100)
        # 近似记忆核:基于连续系统的脉冲响应
        kernel_approx = np.array([np.sum([self.B[i] * np.exp(self.A[i,i] * t) 
                                         for i in range(min(5, self.N))]) 
                                 for t in t_vals])
        axes[1,1].plot(t_vals, kernel_approx, 'b-', linewidth=2, label='Memory Kernel')
        axes[1,1].set_title('Approximate Memory Decay Kernel')
        axes[1,1].set_xlabel('Time $t$')
        axes[1,1].set_ylabel('Kernel Amplitude')
        axes[1,1].set_yscale('log')
        axes[1,1].grid(True, alpha=0.3)
        axes[1,1].legend()
        
        plt.tight_layout()
        plt.savefig('hippo_visualization.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"HiPPO-LegS矩阵构造完成:维度N={self.N}")
        print(f"A矩阵条件数: {np.linalg.cond(self.A):.2f}")
        print(f"特征值实部范围: [{np.min(np.real(eigenvalues)):.4f}, {np.max(np.real(eigenvalues)):.4f}]")
        print("可视化结果已保存至 hippo_visualization.png")

if __name__ == "__main__":
    # 构造16维HiPPO-LegS系统
    hippo = HiPPOLegS(N=16)
    hippo.visualize_matrices()
    
    # 输出前5x5子矩阵供检查
    print("\nA矩阵前5x5子块:")
    print(np.round(hippo.A[:5, :5], 4))
    print("\nB向量前5个元素:")
    print(np.round(hippo.B[:5], 4))

脚本2:S4离散化实现

该脚本实现连续SSM到离散RNN/CNN的双模式离散化,支持Bilinear与ZOH方法。运行方式:python s4_discretization.py

复制代码
"""
脚本2:S4离散化实现
内容:连续状态空间模型的离散化(ZOH与Bilinear方法)
使用方式:直接运行 python s4_discretization.py
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import expm, inv, solve
from scipy.integrate import odeint

class S4Discretization:
    """
    S4离散化模块
    实现连续SSM到离散RNN/CNN的转换
    """
    
    def __init__(self, A, B, C, D=0.0):
        """
        初始化连续SSM参数
        
        Args:
            A: 连续状态矩阵 [N, N]
            B: 连续输入矩阵 [N, 1] 或 [N]
            C: 连续输出矩阵 [1, N] 或 [N]
            D: 直馈系数(标量)
        """
        self.A = np.array(A, dtype=np.float64)
        self.B = np.array(B, dtype=np.float64).reshape(-1, 1)
        self.C = np.array(C, dtype=np.float64).reshape(1, -1)
        self.D = float(D)
        self.N = self.A.shape[0]
        
    def zoh_discretization(self, delta):
        """
        零阶保持(ZOH)离散化
        
        数学原理:
        \bar{A} = exp(delta * A)
        \bar{B} = A^{-1} * (exp(delta * A) - I) * delta * B
        
        Args:
            delta: 离散化步长(采样间隔)
            
        Returns:
            A_bar, B_bar: 离散化后的系统矩阵
        """
        # 计算矩阵指数:\bar{A} = exp(\Delta A)
        A_bar = expm(delta * self.A)
        
        # 计算 \bar{B} = A^{-1} (exp(\Delta A) - I) \Delta B
        I = np.eye(self.N)
        
        # 处理A的逆(使用伪逆确保数值稳定性)
        A_inv = np.linalg.pinv(self.A)
        
        # 计算 (exp(\Delta A) - I)
        exp_diff = A_bar - I
        
        # \bar{B} = A^{-1} (exp(\Delta A) - I) \Delta B
        B_bar = A_inv @ exp_diff @ (delta * self.B)
        
        return A_bar, B_bar.flatten()
    
    def bilinear_discretization(self, delta):
        """
        双线性变换(Tustin方法/梯形积分)
        
        数学原理:
        \bar{A} = (I - delta/2 * A)^{-1} (I + delta/2 * A)
        \bar{B} = (I - delta/2 * A)^{-1} delta * B
        
        Args:
            delta: 离散化步长
            
        Returns:
            A_bar, B_bar: 离散化后的系统矩阵
        """
        I = np.eye(self.N)
        half_delta = delta / 2.0
        
        # 计算 (I - \Delta/2 * A)^{-1}
        left_term = inv(I - half_delta * self.A)
        
        # \bar{A} = (I - \Delta/2 A)^{-1} (I + \Delta/2 A)
        A_bar = left_term @ (I + half_delta * self.A)
        
        # \bar{B} = (I - \Delta/2 A)^{-1} \Delta B
        B_bar = left_term @ (delta * self.B)
        
        return A_bar, B_bar.flatten()
    
    def compare_methods(self, delta_range):
        """
        比较不同离散化方法对系统动态的影响
        
        Args:
            delta_range: 步长范围数组
        """
        stability_zoh = []
        stability_bilinear = []
        
        for delta in delta_range:
            A_zoh, _ = self.zoh_discretization(delta)
            A_bil, _ = self.bilinear_discretization(delta)
            
            # 计算特征值模长(稳定性判据:|lambda| < 1)
            eig_zoh = np.max(np.abs(np.linalg.eigvals(A_zoh)))
            eig_bil = np.max(np.abs(np.linalg.eigvals(A_bil)))
            
            stability_zoh.append(eig_zoh)
            stability_bilinear.append(eig_bil)
        
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        
        # 1. 稳定性对比
        axes[0].plot(delta_range, stability_zoh, 'b-o', label='ZOH', markersize=6)
        axes[0].plot(delta_range, stability_bilinear, 'r-s', label='Bilinear', markersize=6)
        axes[0].axhline(y=1.0, color='k', linestyle='--', label='Stability Boundary')
        axes[0].set_xlabel('Discretization Step $\Delta$')
        axes[0].set_ylabel('Max Eigenvalue Magnitude $|\lambda|_{\max}$')
        axes[0].set_title('Stability Comparison: ZOH vs Bilinear')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        axes[0].set_yscale('log')
        
        # 2. 步长delta=0.1时的矩阵结构对比
        delta_test = 0.1
        A_zoh, B_zoh = self.zoh_discretization(delta_test)
        A_bil, B_bil = self.bilinear_discretization(delta_test)
        
        im1 = axes[1].imshow(np.abs(A_zoh - A_bil), cmap='hot', aspect='auto')
        axes[1].set_title(f'Absolute Difference: $|A_{{ZOH}} - A_{{Bilinear}}|$ ($\Delta={delta_test}$)')
        axes[1].set_xlabel('State Dimension')
        axes[1].set_ylabel('State Dimension')
        plt.colorbar(im1, ax=axes[1])
        
        plt.tight_layout()
        plt.savefig('discretization_comparison.png', dpi=300)
        plt.show()
        
        return np.array(stability_zoh), np.array(stability_bilinear)

    def visualize_state_transition(self, delta=0.1, steps=50):
        """
        可视化离散化后的状态转移动态
        """
        A_bar, B_bar = self.zoh_discretization(delta)
        
        # 模拟单位脉冲输入下的状态演化
        x = np.zeros((steps, self.N))
        x[0] = B_bar  # 初始脉冲
        
        for k in range(1, steps):
            x[k] = A_bar @ x[k-1]
        
        fig, axes = plt.subplots(2, 1, figsize=(12, 8))
        
        # 1. 状态轨迹热力图
        im = axes[0].imshow(x.T, aspect='auto', cmap='viridis', origin='lower')
        axes[0].set_title(f'State Evolution under Impulse Input (ZOH, $\Delta={delta}$)')
        axes[0].set_xlabel('Time Step $k$')
        axes[0].set_ylabel('State Dimension $n$')
        plt.colorbar(im, ax=axes[0])
        
        # 2. 各维度衰减曲线
        for i in range(min(5, self.N)):
            axes[1].plot(x[:, i], label=f'State {i}', linewidth=2)
        axes[1].set_title('Individual State Trajectories')
        axes[1].set_xlabel('Time Step $k$')
        axes[1].set_ylabel('State Value')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        axes[1].set_yscale('symlog')
        
        plt.tight_layout()
        plt.savefig('state_transition.png', dpi=300)
        plt.show()

if __name__ == "__main__":
    # 构造测试用的HiPPO-like矩阵(简化版)
    N = 8
    A_cont = -np.diag(np.arange(1, N+1)) + np.diag(np.ones(N-1), k=1) * 0.5
    B_cont = np.ones(N)
    C_cont = np.ones(N) / N
    
    s4 = S4Discretization(A_cont, B_cont, C_cont, D=0.0)
    
    # 对比不同离散化方法
    deltas = np.logspace(-3, 0, 50)
    stab_zoh, stab_bil = s4.compare_methods(deltas)
    
    # 可视化状态转移
    s4.visualize_state_transition(delta=0.05, steps=100)
    
    # 数值验证:检查离散化后系统的稳定性
    delta_stable = 0.01
    A_z, B_z = s4.zoh_discretization(delta_stable)
    eigenvals = np.linalg.eigvals(A_z)
    print(f"\nZOH离散化验证 ($\Delta={delta_stable}$):")
    print(f"特征值最大模长: {np.max(np.abs(eigenvals)):.6f} (<1.0 稳定)")
    print(f"离散化A矩阵条件数: {np.linalg.cond(A_z):.2f}")

脚本3:卷积模式(FFT加速)

该脚本实现S4的卷积模式,利用FFT计算全局卷积核,支持长序列处理。运行方式:python convolution_mode.py

复制代码
"""
脚本3:卷积模式(FFT加速)
内容:实现S4卷积模式,基于FFT的全局卷积计算
使用方式:直接运行 python convolution_mode.py
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import expm
import time

class S4ConvolutionMode:
    """
    S4卷积模式实现
    利用卷积定理通过FFT加速全局序列建模
    """
    
    def __init__(self, A, B, C, D=0.0, Delta=0.001):
        """
        初始化卷积模式参数
        
        Args:
            A, B, C, D: 连续SSM参数
            Delta: 离散化步长
        """
        self.N = A.shape[0]
        self.D = D
        
        # 执行ZOH离散化
        self.A_bar = expm(Delta * A)
        
        # 计算离散化B(确保数值稳定性)
        I = np.eye(self.N)
        A_inv = np.linalg.pinv(A)
        self.B_bar = (A_inv @ (self.A_bar - I) @ (Delta * B.reshape(-1, 1))).flatten()
        self.C = C.reshape(1, -1)
        
    def compute_kernel(self, L):
        """
        计算全局卷积核 \bar{K} = (C\bar{A}^0\bar{B}, C\bar{A}^1\bar{B}, ..., C\bar{A}^{L-1}\bar{B})
        
        利用对角化加速计算(假设A可对角化或采用DPLR结构)
        
        Args:
            L: 序列长度(卷积核长度)
            
        Returns:
            K: 卷积核向量 [L]
        """
        K = np.zeros(L)
        
        # 方法1:直接计算(适用于小L)
        # 对于大型系统,应使用谱分解或DPLR结构
        current_power = np.eye(self.N)
        Ab = self.A_bar
        
        # 利用矩阵幂运算优化:计算 C A^k B
        for k in range(L):
            K[k] = self.C @ current_power @ self.B_bar
            current_power = current_power @ Ab
        
        return K
    
    def fft_convolution(self, u):
        """
        基于FFT的快速卷积实现
        
        数学原理:
        y = K * u (卷积) = IFFT(FFT(K_padded) * FFT(u_padded))
        
        Args:
            u: 输入序列 [L] 或 [batch, L]
            
        Returns:
            y: 输出序列(与u同shape)
        """
        # 处理输入维度
        single_input = False
        if u.ndim == 1:
            u = u.reshape(1, -1)
            single_input = True
        
        batch_size, L = u.shape
        
        # 计算卷积核
        K = self.compute_kernel(L)
        
        # FFT卷积需要长度至少 2L-1
        L_fft = 2 * L - 1
        
        # 零填充至FFT长度
        K_padded = np.zeros(L_fft)
        K_padded[:L] = K
        
        u_padded = np.zeros((batch_size, L_fft))
        u_padded[:, :L] = u
        
        # FFT变换
        K_fft = np.fft.rfft(K_padded)
        u_fft = np.fft.rfft(u_padded, axis=1)
        
        # 频域乘积
        y_fft = u_fft * K_fft[np.newaxis, :]
        
        # IFFT逆变换
        y_full = np.fft.irfft(y_fft, n=L_fft, axis=1)
        
        # 取前L个有效值(因果卷积)
        y = y_full[:, :L]
        
        # 添加D项(逐元素乘积)
        y = y + self.D * u
        
        if single_input:
            y = y.flatten()
            
        return y
    
    def naive_convolution(self, u):
        """
        朴素卷积实现(用于验证FFT正确性)
        """
        L = len(u) if u.ndim == 1 else u.shape[1]
        K = self.compute_kernel(L)
        
        if u.ndim == 1:
            y = np.convolve(K, u, mode='full')[:L]
            y = y + self.D * u
        else:
            y = np.array([np.convolve(K, u[i], mode='full')[:L] for i in range(u.shape[0])])
            y = y + self.D * u
            
        return y
    
    def benchmark_speed(self, lengths=[128, 256, 512, 1024, 2048, 4096, 8192]):
        """
        对比FFT卷积与朴素卷积的计算效率
        """
        fft_times = []
        naive_times = []
        
        for L in lengths:
            u = np.random.randn(L)
            
            # FFT方法计时
            start = time.time()
            for _ in range(100):
                _ = self.fft_convolution(u)
            fft_times.append((time.time() - start) / 100)
            
            # 朴素方法计时(仅对较短序列)
            if L <= 2048:
                start = time.time()
                for _ in range(100):
                    _ = self.naive_convolution(u)
                naive_times.append((time.time() - start) / 100)
            else:
                naive_times.append(np.nan)
        
        # 可视化
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        
        # 1. 时间对比
        axes[0].plot(lengths, fft_times, 'b-o', label='FFT Convolution', markersize=8, linewidth=2)
        valid_naive = [i for i, t in enumerate(naive_times) if not np.isnan(t)]
        if valid_naive:
            axes[0].plot([lengths[i] for i in valid_naive], 
                        [naive_times[i] for i in valid_naive], 
                        'r-s', label='Naive Convolution', markersize=8, linewidth=2)
        axes[0].set_xlabel('Sequence Length $L$')
        axes[0].set_ylabel('Time (seconds)')
        axes[0].set_title('Computational Efficiency: FFT vs Naive')
        axes[0].set_xscale('log', base=2)
        axes[0].set_yscale('log')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # 2. 复杂度分析
        theoretical_fft = [L * np.log2(L) for L in lengths]
        theoretical_naive = [L**2 for L in lengths]
        
        axes[1].plot(lengths, theoretical_fft, 'b--', label='$O(L \log L)$', linewidth=2)
        axes[1].plot(lengths, theoretical_naive, 'r--', label='$O(L^2)$', linewidth=2)
        axes[1].set_xlabel('Sequence Length $L$')
        axes[1].set_ylabel('Theoretical Operations')
        axes[1].set_title('Computational Complexity')
        axes[1].set_xscale('log', base=2)
        axes[1].set_yscale('log')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('convolution_benchmark.png', dpi=300)
        plt.show()
        
        return np.array(fft_times), np.array(naive_times)
    
    def visualize_frequency_response(self, L=1024):
        """
        可视化卷积核的频域响应特性
        """
        K = self.compute_kernel(L)
        
        # 计算频率响应
        K_fft = np.fft.fft(K)
        freqs = np.fft.fftfreq(L, d=1.0)
        
        # 排序频率
        idx = np.argsort(freqs)
        freqs_sorted = freqs[idx]
        magnitude = np.abs(K_fft[idx])
        phase = np.angle(K_fft[idx])
        
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))
        
        # 1. 时域卷积核
        axes[0,0].plot(K[:100], 'b-', linewidth=2)  # 显示前100个点
        axes[0,0].set_title('Impulse Response $K$ (First 100 steps)')
        axes[0,0].set_xlabel('Time Step $k$')
        axes[0,0].set_ylabel('Amplitude')
        axes[0,0].grid(True, alpha=0.3)
        
        # 2. 频域幅度响应
        axes[0,1].plot(freqs_sorted, 20*np.log10(magnitude + 1e-10), 'r-', linewidth=2)
        axes[0,1].set_title('Frequency Response (Magnitude)')
        axes[0,1].set_xlabel('Frequency')
        axes[0,1].set_ylabel('Magnitude (dB)')
        axes[0,1].grid(True, alpha=0.3)
        
        # 3. 相位响应
        axes[1,0].plot(freqs_sorted, phase, 'g-', linewidth=2)
        axes[1,0].set_title('Phase Response')
        axes[1,0].set_xlabel('Frequency')
        axes[1,0].set_ylabel('Phase (radians)')
        axes[1,0].grid(True, alpha=0.3)
        
        # 4. 零极点图(近似)
        # 计算系统函数H(z)的零点(通过卷积核的DFT近似)
        poles = np.roots(K[:64]) if len(K) >= 64 else np.roots(K)
        axes[1,1].scatter(np.real(poles), np.imag(poles), c='blue', s=50, alpha=0.6)
        axes[1,1].add_patch(plt.Circle((0,0), 1, fill=False, color='red', linestyle='--'))
        axes[1,1].set_title('Pole-Zero Plot (Approximation)')
        axes[1,1].set_xlabel('Real')
        axes[1,1].set_ylabel('Imaginary')
        axes[1,1].grid(True, alpha=0.3)
        axes[1,1].axis('equal')
        
        plt.tight_layout()
        plt.savefig('frequency_response.png', dpi=300)
        plt.show()

if __name__ == "__main__":
    # 构造测试系统(使用简化的HiPPO-like参数)
    N = 16
    A = -np.diag(np.linspace(1, N, N)) + np.diag(np.ones(N-1), k=-1) * 0.5
    B = np.ones(N) / np.sqrt(N)
    C = np.ones(N) / np.sqrt(N)
    
    s4_conv = S4ConvolutionMode(A, B, C, D=0.01, Delta=0.001)
    
    # 生成测试信号:正弦波叠加噪声
    L = 1024
    t = np.linspace(0, 4*np.pi, L)
    u_test = np.sin(2*np.pi*t) + 0.5*np.sin(8*np.pi*t) + 0.1*np.random.randn(L)
    
    # 执行卷积
    y_fft = s4_conv.fft_convolution(u_test)
    y_naive = s4_conv.naive_convolution(u_test)
    
    # 验证数值等价性
    error = np.max(np.abs(y_fft - y_naive))
    print(f"FFT与朴素卷积数值误差: {error:.2e}")
    
    # 可视化处理效果
    fig, axes = plt.subplots(2, 1, figsize=(12, 6))
    axes[0].plot(u_test, 'b-', alpha=0.7, label='Input Signal')
    axes[0].set_title('Input Signal (Multi-frequency Sine + Noise)')
    axes[0].set_xlabel('Time Step')
    axes[0].set_ylabel('Amplitude')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    axes[1].plot(y_fft, 'r-', linewidth=2, label='SSM Output (Filtered)')
    axes[1].set_title('Output Signal (Low-pass Characteristic)')
    axes[1].set_xlabel('Time Step')
    axes[1].set_ylabel('Amplitude')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('convolution_demo.png', dpi=300)
    plt.show()
    
    # 性能基准测试
    fft_times, naive_times = s4_conv.benchmark_speed()
    
    # 频率响应分析
    s4_conv.visualize_frequency_response()
    
    print("\n卷积模式验证完成。所有可视化结果已保存。")

脚本4:循环模式(自回归生成)

该脚本实现S4的循环模式,支持逐步状态更新与自回归生成。运行方式:python recurrent_mode.py

复制代码
"""
脚本4:循环模式(自回归生成)
内容:实现S4循环模式,支持逐步状态更新与自回归序列生成
使用方式:直接运行 python recurrent_mode.py
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import expm

class S4RecurrentMode:
    """
    S4循环模式实现
    适用于自回归生成与流式处理
    """
    
    def __init__(self, A, B, C, D=0.0, Delta=0.001):
        """
        初始化循环模式参数
        
        Args:
            A, B, C, D: 连续SSM参数
            Delta: 离散化步长
        """
        self.N = A.shape[0]
        self.D = D
        
        # 离散化(ZOH)
        self.A_bar = expm(Delta * A)
        
        I = np.eye(self.N)
        A_inv = np.linalg.pinv(A)
        self.B_bar = (A_inv @ (self.A_bar - I) @ (Delta * B.reshape(-1, 1))).flatten()
        self.C = C.reshape(-1)
        
        # 状态缓存(用于流式处理)
        self.state_cache = None
        self.current_step = 0
        
    def reset_state(self, batch_size=1):
        """
        重置/初始化隐藏状态
        """
        self.state_cache = np.zeros((batch_size, self.N))
        self.current_step = 0
        
    def step(self, u_t):
        """
        单步前向传播(循环模式核心)
        
        数学实现:
        x_t = A_bar * x_{t-1} + B_bar * u_t
        y_t = C * x_t + D * u_t
        
        Args:
            u_t: 当前时间步输入 [batch_size] 或 标量
            
        Returns:
            y_t: 当前时间步输出
            x_t: 当前状态(可选返回)
        """
        if self.state_cache is None:
            self.reset_state(batch_size=1 if np.isscalar(u_t) else len(u_t))
        
        # 确保输入维度正确
        if np.isscalar(u_t):
            u_t = np.array([u_t])
        
        u_t = u_t.reshape(-1, 1)  # [batch, 1]
        
        # 状态更新:x_t = A_bar @ x_{t-1} + B_bar * u_t
        # 利用广播机制处理batch
        self.state_cache = (self.state_cache @ self.A_bar.T + 
                           u_t @ self.B_bar.reshape(1, -1))
        
        # 输出计算:y_t = C @ x_t + D * u_t
        y_t = self.state_cache @ self.C + self.D * u_t.flatten()
        
        self.current_step += 1
        
        return y_t[0] if len(y_t) == 1 else y_t
    
    def generate_sequence(self, u_sequence):
        """
        处理完整输入序列(循环展开)
        
        Args:
            u_sequence: 输入序列 [L] 或 [batch, L]
            
        Returns:
            y_sequence: 输出序列(与输入同shape)
            final_state: 最终状态
        """
        if u_sequence.ndim == 1:
            u_sequence = u_sequence.reshape(1, -1)
            single_batch = True
        else:
            single_batch = False
            
        batch_size, L = u_sequence.shape
        
        # 重置状态
        self.reset_state(batch_size)
        
        y_sequence = np.zeros((batch_size, L))
        
        # 逐步执行
        for t in range(L):
            y_t = self.step(u_sequence[:, t])
            y_sequence[:, t] = y_t if np.ndim(y_t) > 0 else [y_t]
        
        if single_batch:
            y_sequence = y_sequence.flatten()
            
        return y_sequence, self.state_cache.copy()
    
    def autoregressive_generate(self, initial_input, steps, feedback_scale=0.1):
        """
        自回归生成:使用模型输出作为下一步输入
        
        Args:
            initial_input: 初始输入值或序列
            steps: 生成步数
            feedback_scale: 反馈系数(控制自回归强度)
            
        Returns:
            generated_sequence: 生成的序列
        """
        self.reset_state(batch_size=1)
        
        generated = []
        current_u = initial_input if np.isscalar(initial_input) else initial_input[-1]
        
        # 若有初始序列,先预热状态
        if not np.isscalar(initial_input) and len(initial_input) > 1:
            for u in initial_input[:-1]:
                _ = self.step(u)
        
        # 自回归生成
        for _ in range(steps):
            y_t = self.step(current_u)
            generated.append(y_t)
            
            # 输出反馈作为下一步输入(可添加噪声增加多样性)
            current_u = y_t * feedback_scale + 0.01 * np.random.randn()
        
        return np.array(generated)
    
    def compare_with_convolution(self, u, conv_model):
        """
        验证循环模式与卷积模式数值等价性
        """
        # 循环模式结果
        y_rec, _ = self.generate_sequence(u)
        
        # 卷积模式结果(假设conv_model提供fft_convolution方法)
        y_conv = conv_model.fft_convolution(u)
        
        error = np.max(np.abs(y_rec - y_conv))
        
        # 可视化对比
        fig, axes = plt.subplots(2, 1, figsize=(12, 8))
        
        L = len(u)
        t = np.arange(L)
        
        axes[0].plot(t, y_rec, 'b-', label='Recurrent Mode', linewidth=2, alpha=0.8)
        axes[0].plot(t, y_conv, 'r--', label='Convolution Mode', linewidth=2, alpha=0.8)
        axes[0].set_title(f'Mode Comparison (Max Error: {error:.2e})')
        axes[0].set_xlabel('Time Step')
        axes[0].set_ylabel('Output')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # 误差曲线
        axes[1].semilogy(t, np.abs(y_rec - y_conv) + 1e-10, 'g-', linewidth=2)
        axes[1].set_title('Absolute Difference (Log Scale)')
        axes[1].set_xlabel('Time Step')
        axes[1].set_ylabel('$|y_{rec} - y_{conv}|$')
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('mode_comparison.png', dpi=300)
        plt.show()
        
        return error
    
    def visualize_hidden_dynamics(self, u):
        """
        可视化隐藏状态的动态演化
        """
        batch_size = 1 if u.ndim == 1 else u.shape[0]
        L = len(u) if u.ndim == 1 else u.shape[1]
        
        self.reset_state(batch_size)
        
        # 记录状态历史
        state_history = np.zeros((L, self.N))
        
        for t in range(L):
            if u.ndim == 1:
                _ = self.step(u[t])
            else:
                _ = self.step(u[0, t])  # 取第一个batch
            state_history[t] = self.state_cache[0]
        
        fig, axes = plt.subplots(2, 1, figsize=(12, 8))
        
        # 1. 状态轨迹热力图
        im = axes[0].imshow(state_history.T, aspect='auto', cmap='coolwarm', 
                           origin='lower', interpolation='nearest')
        axes[0].set_title('Hidden State Evolution Over Time')
        axes[0].set_xlabel('Time Step $t$')
        axes[0].set_ylabel('State Dimension $n$')
        plt.colorbar(im, ax=axes[0])
        
        # 2. 选定维度的轨迹
        dims_to_plot = min(4, self.N)
        for i in range(dims_to_plot):
            axes[1].plot(state_history[:, i], label=f'State dim {i}', linewidth=2)
        axes[1].set_title('Individual State Trajectories')
        axes[1].set_xlabel('Time Step $t$')
        axes[1].set_ylabel('State Value')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('hidden_dynamics.png', dpi=300)
        plt.show()
        
        return state_history

if __name__ == "__main__":
    # 构造测试系统
    N = 16
    # 构造对角化友好的矩阵(确保稳定性)
    eigenvals = -np.linspace(0.1, 2.0, N) + 1j * np.linspace(-1, 1, N) * 0.5
    A = np.diag(eigenvals).real  # 简化:使用实对角阵
    B = np.ones(N) / np.sqrt(N)
    C = np.ones(N) / np.sqrt(N)
    
    s4_rec = S4RecurrentMode(A, B, C, D=0.01, Delta=0.01)
    
    # 测试1:序列处理
    L = 200
    t = np.linspace(0, 10, L)
    u_input = np.sin(t) * np.exp(-0.1 * t)  # 衰减正弦
    
    y_output, final_state = s4_rec.generate_sequence(u_input)
    
    fig, axes = plt.subplots(2, 1, figsize=(12, 6))
    axes[0].plot(t, u_input, 'b-', label='Input', linewidth=2)
    axes[0].set_title('Input Signal')
    axes[0].set_xlabel('Time')
    axes[0].set_ylabel('Amplitude')
    axes[0].grid(True, alpha=0.3)
    axes[0].legend()
    
    axes[1].plot(t, y_output, 'r-', label='Output', linewidth=2)
    axes[1].set_title('Recurrent Mode Output')
    axes[1].set_xlabel('Time Step')
    axes[1].set_ylabel('Amplitude')
    axes[1].grid(True, alpha=0.3)
    axes[1].legend()
    
    plt.tight_layout()
    plt.savefig('recurrent_demo.png', dpi=300)
    plt.show()
    
    # 测试2:自回归生成
    print("\n执行自回归生成...")
    init_seq = u_input[:20]
    generated = s4_rec.autoregressive_generate(init_seq, steps=100, feedback_scale=0.5)
    
    plt.figure(figsize=(10, 4))
    plt.plot(range(len(init_seq)), init_seq, 'b-o', label='Initial Context', markersize=6)
    plt.plot(range(len(init_seq), len(init_seq)+len(generated)), generated, 'r-s', 
             label='Autoregressive Gen', markersize=6)
    plt.axvline(x=len(init_seq)-1, color='k', linestyle='--', alpha=0.5, label='Generation Start')
    plt.title('Autoregressive Generation Demo')
    plt.xlabel('Time Step')
    plt.ylabel('Value')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('autoregressive_gen.png', dpi=300)
    plt.show()
    
    # 测试3:隐藏状态动态可视化
    state_hist = s4_rec.visualize_hidden_dynamics(u_input)
    
    print(f"\n循环模式验证完成。")
    print(f"最终状态范数: {np.linalg.norm(final_state):.4f}")
    print(f"自回归序列均值: {np.mean(generated):.4f}, 方差: {np.var(generated):.4f}")

脚本5:双模式切换系统

该脚本实现完整的双模式切换控制器,集成卷积与循环模式,支持训练-推理工作流。运行方式:python dual_mode_system.py

复制代码
"""
脚本5:双模式切换系统
内容:实现S4双模式切换控制器,集成卷积模式(训练)与循环模式(推理)
使用方式:直接运行 python dual_mode_system.py
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import expm
import time

# 假设前序脚本已导入,或在此处重新定义核心类(为独立性考虑,包含必要代码)

class S4DualModeSystem:
    """
    S4双模式切换系统
    集成卷积模式(并行训练)与循环模式(流式推理)
    """
    
    def __init__(self, N, Delta=0.001, D=0.0, mode='conv'):
        """
        初始化双模式系统
        
        Args:
            N: 状态维度
            Delta: 离散化步长
            D: 直馈系数
            mode: 初始模式 'conv' 或 'recurrent'
        """
        self.N = N
        self.Delta = Delta
        self.D = D
        self.mode = mode
        
        # 初始化HiPPO-like参数(简化版)
        self._init_parameters()
        
        # 离散化参数(预计算)
        self._discretize()
        
        # 状态管理
        self.state_cache = np.zeros(N)
        self.is_state_dirty = False  # 标记状态是否需要同步
        
    def _init_parameters(self):
        """
        初始化连续SSM参数(使用简化的HiPPO结构)
        """
        # 构造稳定的连续系统矩阵
        # A = -diag(1,2,...,N) + 次对角线耦合
        self.A = -np.diag(np.linspace(1, self.N, self.N))
        if self.N > 1:
            coupling = np.diag(np.ones(self.N-1) * 0.5, k=-1)
            self.A += coupling
        
        self.B = np.ones(self.N) / np.sqrt(self.N)
        self.C = np.ones(self.N) / np.sqrt(self.N)
        
    def _discretize(self):
        """
        预计算离散化参数(ZOH方法)
        """
        # A_bar = exp(Delta * A)
        self.A_bar = expm(self.Delta * self.A)
        
        # B_bar = A^{-1} (exp(Delta A) - I) Delta B
        I = np.eye(self.N)
        A_inv = np.linalg.pinv(self.A)
        self.B_bar = (A_inv @ (self.A_bar - I) @ (self.Delta * self.B.reshape(-1, 1))).flatten()
        
        # C和D保持不变
        self.C_row = self.C.reshape(1, -1)
        
    def _compute_convolution_kernel(self, L):
        """
        计算卷积核(用于卷积模式)
        """
        K = np.zeros(L)
        power = np.eye(self.N)
        
        for k in range(L):
            K[k] = self.C_row @ power @ self.B_bar
            power = power @ self.A_bar
            
        return K
    
    def switch_mode(self, new_mode, input_sequence=None):
        """
        模式切换控制器
        
        关键逻辑:
        - conv -> recurrent: 需要计算最终状态作为初始状态
        - recurrent -> conv: 清除状态缓存
        
        Args:
            new_mode: 目标模式 'conv' 或 'recurrent'
            input_sequence: 若从conv切到recurrent,需要历史序列计算初始状态
        """
        if new_mode == self.mode:
            return
        
        if self.mode == 'conv' and new_mode == 'recurrent':
            # 卷积到循环:需要从卷积历史提取最终状态
            if input_sequence is not None:
                # 通过循环前向计算最终状态
                self.state_cache = self._compute_final_state(input_sequence)
            self.is_state_dirty = False
            
        elif self.mode == 'recurrent' and new_mode == 'conv':
            # 循环到卷积:状态被丢弃,下次卷积重新计算
            self.state_cache = np.zeros(self.N)
            self.is_state_dirty = True
            
        self.mode = new_mode
        print(f"模式切换完成: {self.mode}")
        
    def _compute_final_state(self, u):
        """
        通过循环展开计算序列的最终状态(用于模式切换时状态同步)
        """
        x = np.zeros(self.N)
        for u_t in u:
            x = self.A_bar @ x + self.B_bar * u_t
        return x
    
    def forward(self, u):
        """
        统一前向接口,根据当前模式自动路由
        
        Args:
            u: 输入序列 [L] 或 单步输入(标量)
            
        Returns:
            输出序列或单步输出
        """
        if self.mode == 'conv':
            return self._conv_forward(u)
        else:
            return self._recurrent_forward(u)
    
    def _conv_forward(self, u):
        """
        卷积模式前向传播(FFT加速)
        """
        L = len(u) if np.ndim(u) == 1 else u.shape[1]
        
        # 计算卷积核
        K = self._compute_convolution_kernel(L)
        
        # FFT卷积
        L_fft = 2 * L - 1
        K_padded = np.zeros(L_fft)
        K_padded[:L] = K
        
        u_padded = np.zeros(L_fft)
        u_padded[:L] = u
        
        K_fft = np.fft.rfft(K_padded)
        u_fft = np.fft.rfft(u_padded)
        
        y_fft = u_fft * K_fft
        y_full = np.fft.irfft(y_fft, n=L_fft)
        y = y_full[:L] + self.D * u
        
        # 更新状态缓存(供后续切换到循环模式使用)
        self.state_cache = self._compute_final_state(u)
        
        return y
    
    def _recurrent_forward(self, u):
        """
        循环模式前向传播(支持单步或序列)
        """
        if np.isscalar(u) or (isinstance(u, np.ndarray) and u.ndim == 0):
            # 单步模式
            self.state_cache = self.A_bar @ self.state_cache + self.B_bar * u
            y = self.C_row @ self.state_cache + self.D * u
            return y.item() if y.size == 1 else y
            
        elif isinstance(u, np.ndarray) and u.ndim == 1:
            # 序列模式(循环展开)
            L = len(u)
            y_seq = np.zeros(L)
            
            for t in range(L):
                self.state_cache = self.A_bar @ self.state_cache + self.B_bar * u[t]
                y_seq[t] = self.C_row @ self.state_cache + self.D * u[t]
                
            return y_seq
        
        else:
            raise ValueError("Input must be scalar or 1D array")
    
    def train_inference_pipeline(self, train_data, test_initial):
        """
        完整的训练-推理流程演示
        
        Args:
            train_data: 训练序列(长序列,使用卷积模式)
            test_initial: 推理初始上下文(短序列,使用循环模式)
        """
        print("=== 阶段1: 训练模式(卷积) ===")
        self.switch_mode('conv')
        
        # 训练前向(并行处理长序列)
        start_time = time.time()
        train_output = self.forward(train_data)
        train_time = time.time() - start_time
        
        print(f"卷积模式处理长度 {len(train_data)} 序列耗时: {train_time:.4f}s")
        
        # 模拟训练后的状态同步
        print("\n=== 阶段2: 模式切换 ===")
        self.switch_mode('recurrent', input_sequence=train_data)
        
        print("\n=== 阶段3: 推理模式(循环/自回归) ===")
        # 预热:处理初始上下文
        if len(test_initial) > 0:
            _ = self.forward(test_initial)
        
        # 自回归生成
        generated = []
        current = test_initial[-1] if len(test_initial) > 0 else 0.0
        
        gen_steps = 50
        start_time = time.time()
        for _ in range(gen_steps):
            out = self.forward(current)
            generated.append(out)
            current = out  # 自回归反馈
        gen_time = time.time() - start_time
        
        print(f"循环模式生成 {gen_steps} 步耗时: {gen_time:.4f}s")
        
        return train_output, np.array(generated)
    
    def visualize_dual_mode(self, sequence):
        """
        可视化同一序列在两种模式下的处理结果,验证数值一致性
        """
        # 卷积模式
        self.switch_mode('conv')
        y_conv = self.forward(sequence)
        
        # 循环模式(重置状态后处理相同序列)
        self.switch_mode('recurrent')
        self.state_cache = np.zeros(self.N)  # 强制重置
        y_rec = self.forward(sequence)
        
        # 绘图对比
        fig, axes = plt.subplots(3, 1, figsize=(12, 10))
        
        t = np.arange(len(sequence))
        
        # 1. 输入序列
        axes[0].plot(t, sequence, 'k-', linewidth=1.5, label='Input')
        axes[0].set_title('Input Sequence')
        axes[0].set_xlabel('Time')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # 2. 双模式输出对比
        axes[1].plot(t, y_conv, 'b-', label='Convolution Mode', linewidth=2, alpha=0.8)
        axes[1].plot(t, y_rec, 'r--', label='Recurrent Mode', linewidth=2, alpha=0.8)
        axes[1].set_title('Dual-Mode Output Comparison')
        axes[1].set_xlabel('Time Step')
        axes[1].set_ylabel('Output')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        
        # 3. 数值误差
        error = np.abs(y_conv - y_rec)
        axes[2].semilogy(t, error + 1e-15, 'g-', linewidth=2)
        axes[2].set_title(f'Numerical Error (Max: {np.max(error):.2e})')
        axes[2].set_xlabel('Time Step')
        axes[2].set_ylabel('Absolute Error')
        axes[2].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('dual_mode_comparison.png', dpi=300)
        plt.show()
        
        return y_conv, y_rec, error

if __name__ == "__main__":
    # 初始化双模式系统
    N = 32
    Delta = 0.01
    system = S4DualModeSystem(N=N, Delta=Delta, D=0.01, mode='conv')
    
    # 生成测试数据
    np.random.seed(42)
    L_train = 1024
    L_context = 50
    
    # 训练数据:随机游走信号
    train_seq = np.cumsum(np.random.randn(L_train) * 0.1)
    
    # 测试上下文:正弦波
    context_seq = np.sin(np.linspace(0, 4*np.pi, L_context))
    
    # 执行完整流程
    print("执行训练-推理管道...")
    train_out, gen_out = system.train_inference_pipeline(train_seq, context_seq)
    
    # 可视化训练结果
    fig, axes = plt.subplots(2, 1, figsize=(12, 6))
    
    axes[0].plot(train_seq, 'b-', alpha=0.7, label='Train Input')
    axes[0].plot(train_out, 'r-', alpha=0.7, label='Train Output')
    axes[0].set_title('Training Phase (Convolution Mode)')
    axes[0].set_xlabel('Time Step')
    axes[0].set_ylabel('Value')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # 绘制生成结果
    gen_t = np.arange(len(gen_out))
    axes[1].plot(gen_t, gen_out, 'g-o', markersize=4, linewidth=2, label='Generated')
    axes[1].axvline(x=0, color='k', linestyle='--', alpha=0.5, label='Generation Start')
    axes[1].set_title('Inference Phase (Recurrent Mode - Autoregressive)')
    axes[1].set_xlabel('Generation Step')
    axes[1].set_ylabel('Value')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('train_inference_pipeline.png', dpi=300)
    plt.show()
    
    # 验证双模式数值一致性
    print("\n验证双模式数值一致性...")
    test_seq = np.sin(np.linspace(0, 10*np.pi, 200)) * np.exp(-0.01*np.arange(200))
    y_c, y_r, err = system.visualize_dual_mode(test_seq)
    
    print(f"\n双模式系统验证完成。")
    print(f"最大数值误差: {np.max(err):.2e} (应接近机器精度)")
    print(f"生成序列统计: 均值={np.mean(gen_out):.4f}, 标准差={np.std(gen_out):.4f}")
相关推荐
Westward-sun.2 小时前
OpenCV 实战:SIFT 指纹特征匹配与可视化(补充版)
人工智能·opencv·计算机视觉
财经资讯数据_灵砚智能2 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(日间)2026年4月7日
大数据·人工智能·python·信息可视化·语言模型·自然语言处理·ai编程
凌峰的博客2 小时前
基于注意力流的鲁棒信息隐写方法:从扩散隐写到Attention Flow的新探索
人工智能
初心未改HD2 小时前
从Java转行大模型应用,扣子工作流学习
人工智能
Gary jie2 小时前
AI上下文管理与记忆架构详解
人工智能·机器学习·架构·openclaw
大树882 小时前
【无标题】
大数据·运维·服务器·人工智能
我材不敲代码2 小时前
基于dlib+OpenCV的人脸疲劳检测 + 年龄性别识别实战
人工智能·opencv·计算机视觉
victory04312 小时前
2026年4月7日nanoGPT训练记录
人工智能
人工智能AI技术2 小时前
AI Agent 的 Harness 机制学习思考
人工智能