KAN 网络深度解析

目录

  • [第一篇:Kolmogorov-Arnold 表示定理](#第一篇:Kolmogorov-Arnold 表示定理)
  • [第二篇:KAN 架构与 B 样条实现](#第二篇:KAN 架构与 B 样条实现)
  • [第三篇:KAN 的前沿发展与应用](#第三篇:KAN 的前沿发展与应用)
  • 参考文献

第一篇:Kolmogorov-Arnold 表示定理

1. 引言

自多层感知机(MLP)诞生以来,"固定激活函数 + 可学习线性权重"的范式统治了深度学习数十年。ReLU、GELU、SiLU 等激活函数被广泛使用,但它们的选择往往是经验性的。

KAN(Kolmogorov-Arnold Networks, Liu et al., 2024) 回到了一个根本性的数学问题:什么是神经网络表示函数的最优方式? 答案来自 1957 年的 Kolmogorov-Arnold 表示定理------任何多元连续函数都可以表示为一元函数的有限组合。KAN 将这一定理转化为实用的神经网络架构:

  1. 可学习的激活函数 :激活函数放在网络的(而非节点)上,且是可学习的
  2. 节点仅做求和:每个节点计算输入的简单求和
  3. B 样条参数化:激活函数用 B 样条表示,实现灵活且可微的函数逼近

2. Kolmogorov-Arnold 表示定理

2.1 定理陈述

定理 (Kolmogorov, 1957; Arnold, 1958):对于任意连续函数 f : 0 , 1 n → R f: 0, 1^n \to \mathbb{R} f:0,1n→R,存在连续的一元函数 ϕ q , p : 0 , 1 → R \phi_{q,p}: 0, 1 \to \mathbb{R} ϕq,p:0,1→R 和 Φ q : 0 , 1 → R \Phi_q: 0, 1 \to \mathbb{R} Φq:0,1→R,使得:

f ( x 1 , x 2 , ... , x n ) = ∑ q = 0 2 n Φ q ( ∑ p = 1 n ϕ q , p ( x p ) ) f(x_1, x_2, \ldots, x_n) = \sum_{q=0}^{2n} \Phi_q \left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right) f(x1,x2,...,xn)=q=0∑2nΦq(p=1∑nϕq,p(xp))

其中 q q q 的范围是 0 , 1 , ... , 2 n 0, 1, \ldots, 2n 0,1,...,2n。

2.2 定理的深刻含义

含义 1:高维函数的分解

任何多元函数都可以分解为内层一元函数的求和 + 外层一元函数的组合 。这将 n n n 维问题分解为 O ( n ) O(n) O(n) 个一维问题。

含义 2:维度的诅咒被打破

表面上, n n n 维函数需要 O ( n n ) O(n^n) O(nn) 个参数来表示(网格方法)。Kolmogorov-Arnold 定理表明,只需要 O ( n ) O(n) O(n) 个一元函数即可,每个一元函数只需要 O ( 1 ) O(1) O(1) 个参数(在连续函数空间中)。

含义 3:激活函数应该放在边上

传统 MLP 将可学习参数放在线性层的权重上,激活函数是固定的。Kolmogorov-Arnold 定理暗示:可学习的应该是一元函数(激活函数),而非线性权重

2.3 定理的构造性证明

Kolmogorov 的证明是构造性的,但构造的函数 ϕ q , p \phi_{q,p} ϕq,p 和 Φ q \Phi_q Φq 高度病态(非光滑、分形性质),无法直接用于实际计算。

KAN 的贡献:用 B 样条参数化的光滑函数替代病态函数,将定理转化为实用架构。


3. 从 MLP 到 KAN 的范式转变

3.1 MLP 的架构

标准 MLP 的第 l l l 层:

x ( l + 1 ) = σ ( W ( l ) x ( l ) + b ( l ) ) \mathbf{x}^{(l+1)} = \sigma(\mathbf{W}^{(l)} \mathbf{x}^{(l)} + \mathbf{b}^{(l)}) x(l+1)=σ(W(l)x(l)+b(l))

其中:

  • W ( l ) ∈ R n l + 1 × n l \mathbf{W}^{(l)} \in \mathbb{R}^{n_{l+1} \times n_l} W(l)∈Rnl+1×nl:可学习权重矩阵
  • σ \sigma σ:固定的激活函数(如 ReLU)
  • 每个节点:线性组合 + 固定非线性

3.2 KAN 的架构

KAN 的第 l l l 层:

x j ( l + 1 ) = ∑ i = 1 n l ϕ l , j , i ( x i ( l ) ) x_j^{(l+1)} = \sum_{i=1}^{n_l} \phi_{l,j,i}(x_i^{(l)}) xj(l+1)=i=1∑nlϕl,j,i(xi(l))

其中:

  • ϕ l , j , i : R → R \phi_{l,j,i}: \mathbb{R} \to \mathbb{R} ϕl,j,i:R→R:可学习的一元函数
  • 每条边:独立的可学习激活函数
  • 每个节点:仅做求和(无固定激活)

3.3 架构对比

特性 MLP KAN
可学习参数位置 节点(权重矩阵) 边(激活函数)
激活函数 固定(ReLU, GELU) 可学习(B 样条)
节点操作 线性组合 + 固定非线性 仅求和
理论基础 通用逼近定理 Kolmogorov-Arnold 定理
参数效率 低(需要更多层/宽) 高(单层即可逼近复杂函数)

3.4 为什么 KAN 更高效?

MLP 的问题:固定激活函数限制了每层的表达能力。为了逼近复杂函数,MLP 需要更多的层或更宽的层。

KAN 的优势:每条边上的可学习激活函数可以自适应地拟合数据中的非线性模式。单层 KAN 就能表示 MLP 需要多层才能表示的函数。

形式化:对于某些函数类,KAN 的逼近速率指数级快于 MLP。


4. 逼近论分析

4.1 逼近速率对比

定理 (Liu et al., 2024):对于某些光滑函数类 F \mathcal{F} F,KAN 的逼近误差为:

inf ⁡ f θ ∈ KAN ∥ f − f θ ∥ ∞ = O ( n − k ) \inf_{f_\theta \in \text{KAN}} \|f - f_\theta\|_\infty = O(n^{-k}) fθ∈KANinf∥f−fθ∥∞=O(n−k)

其中 n n n 是参数量, k k k 是样条的阶数。而 MLP 的逼近误差为:

inf ⁡ f θ ∈ MLP ∥ f − f θ ∥ ∞ = O ( n − 2 / d ) \inf_{f_\theta \in \text{MLP}} \|f - f_\theta\|_\infty = O(n^{-2/d}) fθ∈MLPinf∥f−fθ∥∞=O(n−2/d)

其中 d d d 是输入维度。当 d d d 很大时,KAN 的逼近速率远快于 MLP。

4.2 为什么 KAN 更适合科学计算?

科学计算中的函数通常具有组合结构(如物理定律的嵌套形式)。Kolmogorov-Arnold 定理天然捕捉了这种组合结构,而 MLP 需要从头学习。

例子 :考虑 f ( x , y ) = sin ⁡ ( x 2 + y 2 ) f(x, y) = \sin(x^2 + y^2) f(x,y)=sin(x2+y2)。

  • MLP:需要学习 sin ⁡ \sin sin、 x 2 x^2 x2、 y 2 y^2 y2、加法的组合,需要多层
  • KAN:自然地分解为 ϕ 1 ( x ) = x 2 \phi_1(x) = x^2 ϕ1(x)=x2、 ϕ 2 ( y ) = y 2 \phi_2(y) = y^2 ϕ2(y)=y2、 Φ ( z ) = sin ⁡ ( z ) \Phi(z) = \sin(z) Φ(z)=sin(z)

5. Kolmogorov-Arnold 定理数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    Kolmogorov-Arnold 表示定理 数学总结                                     ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 定理陈述:                                                                           ║
║     f(x₁,...,xₙ) = Σ_{q=0}^{2n} Φ_q(Σ_{p=1}^{n} φ_{q,p}(x_p))                       ║
║     任意连续多元函数 = 一元函数的有限组合                                                ║
║                                                                                        ║
║  2. 含义:                                                                               ║
║     - 高维函数可分解为 O(n) 个一维问题                                                   ║
║     - 打破维度诅咒: O(n) vs O(n^n)                                                      ║
║     - 可学习的应是一元函数 (激活函数), 而非线性权重                                       ║
║                                                                                        ║
║  3. MLP vs KAN:                                                                         ║
║     MLP: x_{l+1} = σ(W·x_l + b)    (固定 σ, 可学习 W)                                  ║
║     KAN: x_j^{l+1} = Σ_i φ_{l,j,i}(x_i^l)    (可学习 φ, 仅求和)                       ║
║                                                                                        ║
║  4. 逼近速率:                                                                           ║
║     KAN: O(n^{-k}),  k = 样条阶数                                                      ║
║     MLP: O(n^{-2/d}),  d = 输入维度                                                     ║
║     KAN 在高维问题上指数级更快                                                          ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第二篇:KAN 架构与 B 样条实现

1. 引言

KAN 的核心实现问题是如何参数化可学习的一元函数 ϕ ( x ) \phi(x) ϕ(x)。KAN 使用 B 样条(B-spline) 来实现这一点------B 样条既能灵活逼近任意函数,又具有良好的可微性和局部控制性。


2. B 样条基础

2.1 B 样条的定义

B 样条基函数 B i , k ( x ) B_{i,k}(x) Bi,k(x) 是定义在节点向量上的分段多项式。

给定节点向量 t = ( t 0 , t 1 , ... , t m ) \mathbf{t} = (t_0, t_1, \ldots, t_{m}) t=(t0,t1,...,tm), k k k 阶 B 样条基函数递归定义为:

零阶 ( k = 1 k = 1 k=1):

B i , 1 ( x ) = { 1 if t i ≤ x < t i + 1 0 otherwise B_{i,1}(x) = \begin{cases} 1 & \text{if } t_i \leq x < t_{i+1} \\ 0 & \text{otherwise} \end{cases} Bi,1(x)={10if ti≤x<ti+1otherwise

递归 ( k > 1 k > 1 k>1):

B i , k ( x ) = x − t i t i + k − 1 − t i B i , k − 1 ( x ) + t i + k − x t i + k − t i + 1 B i + 1 , k − 1 ( x ) B_{i,k}(x) = \frac{x - t_i}{t_{i+k-1} - t_i} B_{i,k-1}(x) + \frac{t_{i+k} - x}{t_{i+k} - t_{i+1}} B_{i+1,k-1}(x) Bi,k(x)=ti+k−1−tix−tiBi,k−1(x)+ti+k−ti+1ti+k−xBi+1,k−1(x)

其中约定 0 / 0 = 0 0/0 = 0 0/0=0。

2.2 B 样条的性质

性质 说明
局部支撑 B i , k ( x ) B_{i,k}(x) Bi,k(x) 仅在 [ t i , t i + k ) [t_i, t_{i+k}) [ti,ti+k) 上非零
非负性 B i , k ( x ) ≥ 0 B_{i,k}(x) \geq 0 Bi,k(x)≥0
归一性 ∑ i B i , k ( x ) = 1 \sum_i B_{i,k}(x) = 1 ∑iBi,k(x)=1
可微性 B i , k ∈ C k − 2 B_{i,k} \in C^{k-2} Bi,k∈Ck−2( k − 2 k-2 k−2 次连续可微)
局部控制 移动一个节点只影响附近的基函数

2.3 B 样条曲线

B 样条曲线是基函数的加权组合:

ϕ ( x ) = ∑ i = 0 n c i B i , k ( x ) \phi(x) = \sum_{i=0}^{n} c_i B_{i,k}(x) ϕ(x)=i=0∑nciBi,k(x)

其中 c i c_i ci 是控制点 (可学习参数), n + 1 n+1 n+1 是控制点数量。


3. KAN 的激活函数参数化

3.1 B 样条激活函数

KAN 使用 B 样条参数化每条边上的激活函数:

ϕ l , j , i ( x ) = w b ⋅ b ( x ) + w s ⋅ ∑ m = 0 G + k − 1 c m B m , k ( x ) \phi_{l,j,i}(x) = w_b \cdot b(x) + w_s \cdot \sum_{m=0}^{G+k-1} c_m B_{m,k}(x) ϕl,j,i(x)=wb⋅b(x)+ws⋅m=0∑G+k−1cmBm,k(x)

其中:

  • b ( x ) = silu ( x ) = x / ( 1 + e − x ) b(x) = \text{silu}(x) = x / (1 + e^{-x}) b(x)=silu(x)=x/(1+e−x):残差基函数(确保基本非线性)
  • B m , k ( x ) B_{m,k}(x) Bm,k(x): k k k 阶 B 样条基函数
  • c m c_m cm:可学习的控制点
  • G G G:网格大小(grid size)
  • w b , w s w_b, w_s wb,ws:可学习的缩放因子

3.2 网格自适应

初始网格 :在 − R , R -R, R −R,R 上均匀分布 G + 1 G + 1 G+1 个节点,其中 R R R 是输入范围的估计。

网格细化 :训练过程中,可以动态增加网格大小 G G G,以提高逼近精度。

网格更新公式

t new = refine ( t old ) \mathbf{t}{\text{new}} = \text{refine}(\mathbf{t}{\text{old}}) tnew=refine(told)

新节点插入到旧节点的中点,控制点通过插值更新。

3.3 参数量分析

对于一个 KAN 层,连接 n l n_l nl 个输入节点和 n l + 1 n_{l+1} nl+1 个输出节点:

参数 数量 说明
控制点 c m c_m cm n l × n l + 1 × ( G + k ) n_l \times n_{l+1} \times (G + k) nl×nl+1×(G+k) 每条边 G + k G + k G+k 个
缩放因子 w b , w s w_b, w_s wb,ws 2 × n l × n l + 1 2 \times n_l \times n_{l+1} 2×nl×nl+1 每条边 2 个
总计 n l × n l + 1 × ( G + k + 2 ) n_l \times n_{l+1} \times (G + k + 2) nl×nl+1×(G+k+2) ---

对比 MLP 层: n l × n l + 1 n_l \times n_{l+1} nl×nl+1 个权重 + n l + 1 n_{l+1} nl+1 个偏置。

当 G + k > 1 G + k > 1 G+k>1 时,KAN 每层参数量多于 MLP,但逼近能力更强。


4. 完整可运行实现

4.1 B 样条实现

python 复制代码
"""
KAN (Kolmogorov-Arnold Networks) --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import List, Tuple, Optional
from dataclasses import dataclass


@dataclass
class KANConfig:
    """KAN 配置"""
    layers: List[int] = None       # 各层宽度, 如 [2, 5, 1]
    grid_size: int = 5             # B 样条网格大小 G
    spline_order: int = 3          # B 样条阶数 k
    grid_range: Tuple[float, float] = (-2.0, 2.0)  # 网格范围
    base_activation: str = "silu"  # 残差基函数

    def __post_init__(self):
        if self.layers is None:
            self.layers = [2, 5, 1]


def compute_bspline_basis(
    x: torch.Tensor,       # (B,) 输入
    grid: torch.Tensor,     # (G + 2k + 1,) 节点向量
    k: int,                 # 样条阶数
) -> torch.Tensor:
    """
    计算 B 样条基函数值。

    使用 Cox-de Boor 递归公式:
    B_{i,1}(x) = 1 if t_i <= x < t_{i+1}, else 0
    B_{i,k}(x) = ((x - t_i) / (t_{i+k-1} - t_i)) * B_{i,k-1}(x)
               + ((t_{i+k} - x) / (t_{i+k} - t_{i+1})) * B_{i+1,k-1}(x)

    返回: (B, n_basis) 基函数值
    """
    # 节点数
    n_nodes = len(grid)

    # 零阶基函数
    # B_{i,1}(x) = 1 if t_i <= x < t_{i+1}
    bases = []
    for i in range(n_nodes - 1):
        left = grid[i]
        right = grid[i + 1]
        basis = ((x >= left) & (x < right)).float()
        bases.append(basis)

    bases = torch.stack(bases, dim=-1)  # (B, n_nodes - 1)

    # 递归计算高阶基函数
    for order in range(2, k + 1):
        new_bases = []
        n_cur = bases.shape[-1]

        for i in range(n_cur - 1):
            # 左项: (x - t_i) / (t_{i+order-1} - t_i) * B_{i,order-1}
            left_node = grid[i]
            left_node_next = grid[i + order - 1]
            denom_left = left_node_next - left_node
            if denom_left > 0:
                left_term = ((x - left_node) / denom_left) * bases[:, i]
            else:
                left_term = torch.zeros_like(x)

            # 右项: (t_{i+order} - x) / (t_{i+order} - t_{i+1}) * B_{i+1,order-1}
            right_node = grid[i + order]
            right_node_next = grid[i + 1]
            denom_right = right_node - right_node_next
            if denom_right > 0:
                right_term = ((right_node - x) / denom_right) * bases[:, i + 1]
            else:
                right_term = torch.zeros_like(x)

            new_bases.append(left_term + right_term)

        bases = torch.stack(new_bases, dim=-1)

    return bases  # (B, n_basis)

4.2 KAN Layer 实现

python 复制代码
class KANLayer(nn.Module):
    """KAN 层: 每条边是可学习的 B 样条激活函数"""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        grid_size: int = 5,
        spline_order: int = 3,
        grid_range: Tuple[float, float] = (-2.0, 2.0),
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # 构建节点向量
        # 均匀网格 + 两端扩展 k 个节点
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = torch.arange(
            grid_range[0] - spline_order * h,
            grid_range[1] + (spline_order + 1) * h,
            h
        )
        self.register_buffer("grid", grid)

        n_basis = len(grid) - spline_order - 1

        # 每条边的可学习控制点
        # shape: (out_features, in_features, n_basis)
        self.spline_weight = nn.Parameter(
            torch.randn(out_features, in_features, n_basis) * 0.1
        )

        # 残差基函数的缩放因子
        self.base_weight = nn.Parameter(
            torch.randn(out_features, in_features) * 0.1
        )

        # B 样条的缩放因子
        self.spline_scale = nn.Parameter(
            torch.ones(out_features, in_features)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, in_features)
        返回: (B, out_features)

        每个输出节点 = 所有输入节点的激活函数之和
        """
        B = x.shape[0]

        # 计算每个输入的 B 样条基函数值
        # x: (B, in_features) → 对每个特征分别计算
        basis_values = []  # in_features 个 (B, n_basis)
        for i in range(self.in_features):
            basis_i = compute_bspline_basis(x[:, i], self.grid, self.spline_order)
            basis_values.append(basis_i)

        basis_values = torch.stack(basis_values, dim=1)  # (B, in_features, n_basis)

        # B 样条激活: sum(c_m * B_m(x))
        # spline_weight: (out, in, n_basis)
        # basis_values: (B, in, n_basis)
        spline_activation = torch.einsum(
            "oin,bin->boi", self.spline_weight, basis_values
        )  # (B, out_features, in_features)

        # 残差基函数 (SiLU)
        if self.base_weight.requires_grad:
            base_activation = F.silu(x)  # (B, in_features)
            base_contribution = torch.einsum(
                "oi,bi->boi", self.base_weight, base_activation
            )  # (B, out_features, in_features)
        else:
            base_contribution = 0

        # 合并: phi(x) = w_b * silu(x) + w_s * spline(x)
        activation = base_contribution + self.spline_scale.unsqueeze(0) * spline_activation

        # 节点求和: output_j = sum_i phi_{j,i}(x_i)
        output = activation.sum(dim=-1)  # (B, out_features)

        return output

4.3 完整 KAN 网络

python 复制代码
class KAN(nn.Module):
    """Kolmogorov-Arnold Network"""

    def __init__(self, config: KANConfig):
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList()
        for i in range(len(config.layers) - 1):
            self.layers.append(
                KANLayer(
                    in_features=config.layers[i],
                    out_features=config.layers[i + 1],
                    grid_size=config.grid_size,
                    spline_order=config.spline_order,
                    grid_range=config.grid_range,
                )
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

    def get_activation_values(
        self, layer_idx: int, edge_idx: Tuple[int, int], x: torch.Tensor
    ) -> torch.Tensor:
        """获取指定边的激活函数值 (用于可视化)"""
        layer = self.layers[layer_idx]
        i, j = edge_idx  # 输入 i, 输出 j

        basis = compute_bspline_basis(x, layer.grid, layer.spline_order)
        spline_val = (layer.spline_weight[j, i] * basis).sum(dim=-1)
        base_val = layer.base_weight[j, i] * F.silu(x)

        return base_val + layer.spline_scale[j, i] * spline_val

4.4 实验代码

python 复制代码
def experiment_kan_function_fitting():
    """KAN 函数拟合实验"""
    torch.manual_seed(42)
    device = torch.device("cpu")

    # 目标函数: f(x, y) = sin(x^2 + y^2)
    def target_fn(x):
        return torch.sin(x[:, 0:1] ** 2 + x[:, 1:2] ** 2)

    # 生成训练数据
    n_train = 1000
    x_train = torch.randn(n_train, 2, device=device) * 1.5
    y_train = target_fn(x_train)

    # 创建 KAN
    config = KANConfig(layers=[2, 5, 1], grid_size=8, spline_order=3)
    model = KAN(config).to(device)

    # 训练
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    losses = []

    for step in range(2000):
        idx = torch.randint(0, n_train, (256,))
        x_batch = x_train[idx]
        y_batch = y_train[idx]

        y_pred = model(x_batch)
        loss = F.mse_loss(y_pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        if (step + 1) % 500 == 0:
            print(f"Step {step+1} | Loss: {np.mean(losses[-100:]):.6f}")

    # 测试
    x_test = torch.randn(500, 2, device=device) * 1.5
    y_test = target_fn(x_test)
    with torch.no_grad():
        y_pred = model(x_test)
    test_mse = F.mse_loss(y_pred, y_test).item()
    print(f"\n测试 MSE: {test_mse:.6f}")

    return model, losses


def compare_kan_vs_mlp():
    """对比 KAN 和 MLP 的拟合能力"""
    torch.manual_seed(42)

    # 目标函数: f(x) = sin(3x) * exp(-x^2)
    def target_fn(x):
        return torch.sin(3 * x) * torch.exp(-x ** 2)

    # 训练数据
    x_train = torch.linspace(-3, 3, 200).unsqueeze(1)
    y_train = target_fn(x_train)

    # KAN
    config = KANConfig(layers=[1, 5, 1], grid_size=10, spline_order=3)
    kan = KAN(config)
    optimizer_kan = torch.optim.Adam(kan.parameters(), lr=1e-2)

    # MLP (类似参数量)
    mlp = nn.Sequential(
        nn.Linear(1, 20), nn.SiLU(),
        nn.Linear(20, 20), nn.SiLU(),
        nn.Linear(20, 1)
    )
    optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-3)

    # 训练
    for step in range(3000):
        # KAN
        y_pred_kan = kan(x_train)
        loss_kan = F.mse_loss(y_pred_kan, y_train)
        optimizer_kan.zero_grad()
        loss_kan.backward()
        optimizer_kan.step()

        # MLP
        y_pred_mlp = mlp(x_train)
        loss_mlp = F.mse_loss(y_pred_mlp, y_train)
        optimizer_mlp.zero_grad()
        loss_mlp.backward()
        optimizer_mlp.step()

        if (step + 1) % 1000 == 0:
            print(f"Step {step+1} | KAN Loss: {loss_kan.item():.6f} | MLP Loss: {loss_mlp.item():.6f}")

    # 参数量
    kan_params = sum(p.numel() for p in kan.parameters())
    mlp_params = sum(p.numel() for p in mlp.parameters())
    print(f"\nKAN 参数量: {kan_params}")
    print(f"MLP 参数量: {mlp_params}")

5. KAN 架构数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    KAN 架构与 B 样条 数学总结                                              ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. B 样条基函数 (Cox-de Boor 递归):                                                    ║
║     B_{i,1}(x) = 1 if t_i ≤ x < t_{i+1}, else 0                                       ║
║     B_{i,k}(x) = ((x-t_i)/(t_{i+k-1}-t_i))·B_{i,k-1}(x)                              ║
║                  + ((t_{i+k}-x)/(t_{i+k}-t_{i+1}))·B_{i+1,k-1}(x)                     ║
║                                                                                        ║
║  2. KAN 激活函数:                                                                       ║
║     φ(x) = w_b·silu(x) + w_s·Σ_m c_m·B_m(x)                                          ║
║     w_b, w_s, c_m 均可学习                                                             ║
║                                                                                        ║
║  3. KAN 层:                                                                             ║
║     x_j^{l+1} = Σ_i φ_{l,j,i}(x_i^l)                                                  ║
║     节点仅求和, 非线性来自边上的激活函数                                                 ║
║                                                                                        ║
║  4. 参数量:                                                                             ║
║     KAN 层: n_in × n_out × (G + k + 2)                                                 ║
║     MLP 层: n_in × n_out + n_out                                                       ║
║     KAN 每层参数更多, 但需要更少的层                                                    ║
║                                                                                        ║
║  5. B 样条性质:                                                                         ║
║     - 局部支撑: 移动一个节点只影响附近                                                  ║
║     - 非负性: B_{i,k}(x) ≥ 0                                                           ║
║     - 归一性: Σ_i B_{i,k}(x) = 1                                                       ║
║     - 可微性: B_{i,k} ∈ C^{k-2}                                                        ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

第三篇:KAN 的前沿发展与应用

1. 引言

KAN 自 2024 年 5 月发表以来,迅速引发了大量 follow-up 工作,涵盖理论分析、架构改进和应用拓展。


2. KAN 的理论分析

2.1 KAN 的通用逼近定理

定理 (Liu et al., 2024):对于任意连续函数 f : 0 , 1 n → R f: 0, 1^n \to \mathbb{R} f:0,1n→R 和任意 ϵ > 0 \epsilon > 0 ϵ>0,存在一个 KAN 网络 f θ f_\theta fθ 使得:

∥ f − f θ ∥ ∞ < ϵ \|f - f_\theta\|_\infty < \epsilon ∥f−fθ∥∞<ϵ

这比 Kolmogorov-Arnold 定理更强------原始定理中的函数是病态的,而 KAN 使用光滑的 B 样条。

2.2 KAN 的逼近速率

定理 :对于 k k k 次连续可微的函数类 C k C^k Ck,使用 k k k 阶 B 样条的 KAN 的逼近误差为:

∥ f − f θ ∥ ∞ = O ( G − k ) \|f - f_\theta\|_\infty = O(G^{-k}) ∥f−fθ∥∞=O(G−k)

其中 G G G 是网格大小。增加网格大小可以指数级减小误差。

2.3 KAN 的可解释性

KAN 的一个重要优势是可解释性------每条边上的激活函数可以独立可视化,揭示数据中的函数关系。

符号回归:KAN 可以自动发现数学公式。通过分析学习到的激活函数形状,可以推断出底层的数学关系。


3. KAN 的架构变体

3.1 FourierKAN

用傅里叶基函数替代 B 样条:

ϕ ( x ) = ∑ m = 0 M ( a m cos ⁡ ( m ω x ) + b m sin ⁡ ( m ω x ) ) \phi(x) = \sum_{m=0}^{M} \left( a_m \cos(m \omega x) + b_m \sin(m \omega x) \right) ϕ(x)=m=0∑M(amcos(mωx)+bmsin(mωx))

优势:更适合周期性函数。

3.2 WaveletKAN

用小波基函数替代 B 样条:

ϕ ( x ) = ∑ m c m ψ ( x − t m s ) \phi(x) = \sum_{m} c_m \psi\left(\frac{x - t_m}{s}\right) ϕ(x)=m∑cmψ(sx−tm)

优势:多尺度分析能力,适合具有局部特征的函数。

3.3 ReLU-KAN

用分段线性函数(ReLU 的组合)替代 B 样条:

ϕ ( x ) = ∑ m c m ReLU ( x − t m ) \phi(x) = \sum_{m} c_m \text{ReLU}(x - t_m) ϕ(x)=m∑cmReLU(x−tm)

优势:计算效率更高,与现有硬件更兼容。

3.4 ChebyKAN

用切比雪夫多项式替代 B 样条:

ϕ ( x ) = ∑ m = 0 M c m T m ( x ) \phi(x) = \sum_{m=0}^{M} c_m T_m(x) ϕ(x)=m=0∑McmTm(x)

其中 T m T_m Tm 是 m m m 阶切比雪夫多项式。

优势 :在 − 1 , 1 -1, 1 −1,1 上具有最优的多项式逼近性质。


4. KAN 的应用场景

4.1 科学计算

KAN 在科学计算中表现出色,因为它天然捕捉函数的组合结构:

任务 KAN 优势
物理定律发现 可解释的激活函数揭示物理关系
PDE 求解 高效逼近高维函数
分子动力学 捕捉原子间相互作用的函数形式

4.2 符号回归

KAN 可以自动发现数学公式:

  1. 训练 KAN 拟合数据
  2. 分析激活函数形状
  3. 用符号表达式替换 B 样条
  4. 简化得到最终公式

4.3 机器学习

任务 KAN 表现
函数拟合 参数效率更高
回归 更少的参数达到相同精度
分类 与 MLP 相当,但更可解释

5. KAN 与 MLP 的全面对比

维度 MLP KAN
理论基础 通用逼近定理 Kolmogorov-Arnold 定理
参数位置 节点(权重矩阵) 边(激活函数)
激活函数 固定 可学习
逼近速率 O ( n − 2 / d ) O(n^{-2/d}) O(n−2/d) O ( n − k ) O(n^{-k}) O(n−k)
参数效率
可解释性 高(可视化激活函数)
计算效率 高(矩阵乘法) 中(B 样条计算)
硬件支持 成熟 待优化
适用场景 通用 科学计算、符号回归

6. 前沿研究方向

6.1 KAN + Transformer

将 KAN 替换 Transformer 中的 MLP 层:

Attention → KAN → Attention → KAN → ⋯ \text{Attention} \to \text{KAN} \to \text{Attention} \to \text{KAN} \to \cdots Attention→KAN→Attention→KAN→⋯

初步实验表明,在某些任务上 KAN-Transformer 优于标准 Transformer。

6.2 KAN + GNN

将 KAN 应用于图神经网络的消息传递:

m i j = ϕ KAN ( x i , x j , e i j ) m_{ij} = \phi_{\text{KAN}}(x_i, x_j, e_{ij}) mij=ϕKAN(xi,xj,eij)

6.3 KAN 的理论深化

  1. 泛化理论:分析 KAN 的泛化误差界
  2. 优化理论:分析 KAN 训练的收敛性
  3. 信息论:分析 KAN 的信息压缩能力

7. 前沿发展数学公式总结

复制代码
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║                    KAN 前沿发展 数学总结                                                   ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║                                                                                        ║
║  1. 通用逼近:                                                                           ║
║     ∀f ∈ C([0,1]^n), ∀ε > 0, ∃KAN f_θ: ‖f - f_θ‖_∞ < ε                               ║
║                                                                                        ║
║  2. 逼近速率:                                                                           ║
║     KAN (B 样条): O(G^{-k}),  G=网格大小, k=样条阶数                                    ║
║     MLP: O(n^{-2/d}),  d=输入维度                                                       ║
║                                                                                        ║
║  3. KAN 变体:                                                                           ║
║     FourierKAN:  φ(x) = Σ (a_m·cos(mωx) + b_m·sin(mωx))                               ║
║     ChebyKAN:    φ(x) = Σ c_m·T_m(x)                                                   ║
║     ReLU-KAN:    φ(x) = Σ c_m·ReLU(x - t_m)                                            ║
║                                                                                        ║
║  4. 符号回归流程:                                                                       ║
║     训练 KAN → 分析激活函数 → 符号替换 → 简化公式                                        ║
║                                                                                        ║
║  5. KAN + Transformer:                                                                  ║
║     Attention → KAN → Attention → KAN → ...                                             ║
║     初步实验优于标准 Transformer                                                         ║
║                                                                                        ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝

参考文献

核心论文

  1. Liu, Z., Wang, Y., Vaidya, S., Ruehle, F., Halverson, J., Soljačić, M., Hou, T. Y., & Tegmark, M. (2024). KAN: Kolmogorov-Arnold Networks. arXiv:2404.19756.
  2. Kolmogorov, A. N. (1957). On the representation of continuous functions of many variables by superposition of continuous functions of one variable and addition. Doklady Akademii Nauk SSSR, 114, 953-956.
  3. Arnold, V. I. (1958). On the representation of continuous functions of three variables by superpositions of continuous functions of two variables. Doklady Akademii Nauk SSSR, 114, 679-681.

B 样条理论

  1. De Boor, C. (1978). A Practical Guide to Splines. Springer-Verlag.
  2. Schoenberg, I. J. (1946). Contributions to the problem of approximation of equidistant data by analytic functions. Quarterly of Applied Mathematics, 4, 45-99.

KAN 变体

  1. Xu, Z. (2024). FourierKAN-GCF: Fourier Kolmogorov-Arnold Network for Graph Collaborative Filtering. arXiv:2406.01006.
  2. Bodner, A. D., Tepsich, A. S., Spinosa, J. P., & Pinto, J. (2024). Wav-KAN: Wavelet Kolmogorov-Arnold Networks. arXiv:2405.12832.

逼近论

  1. Cybenko, G. (1989). Approximation by superpositions of a sigmoidal function. Mathematics of Control, Signals and Systems, 2(4), 303-314.
  2. Hornik, K. (1991). Approximation capabilities of multilayer feedforward networks. Neural Networks, 4(2), 251-257.
相关推荐
j7~1 小时前
【算法】专题一:双指针之移动零,复写零,快乐数
数据结构·c++·算法·双指针·快乐数·移动零·复写零
阿里matlab建模师2 小时前
【机场停机位分配】matlab实现基于遗传算法的机场停机位分配优化研究
开发语言·算法·数学建模·matlab·全国大学生数学建模竞赛
小雨下雨的雨7 小时前
井字棋AI机器人实现详解 - Minimax算法实战-鸿蒙PC Electron框架完成
前端·人工智能·算法·华为·electron·鸿蒙
xieliyu.10 小时前
Java算法精讲:双指针(三)
java·开发语言·算法
一条小锦吕*10 小时前
基于Spring Boot + 数据可视化 + 协同过滤算法的推荐系统设计与实现(源码+论文+部署全讲解)
spring boot·算法·信息可视化
如竟没有火炬12 小时前
最大矩阵——单调栈
数据结构·python·线性代数·算法·leetcode·矩阵
8Qi812 小时前
LeetCode 1143 & 718:最长公共子序列 / 最长重复子数组
算法·leetcode·职场和发展·动态规划
绿算技术13 小时前
万卡推理集群存储选型分析:从核心架构到应用视角
大数据·科技·算法·架构