目录
- [第一篇:Kolmogorov-Arnold 表示定理](#第一篇:Kolmogorov-Arnold 表示定理)
- [第二篇:KAN 架构与 B 样条实现](#第二篇:KAN 架构与 B 样条实现)
- [第三篇:KAN 的前沿发展与应用](#第三篇:KAN 的前沿发展与应用)
- 参考文献
第一篇:Kolmogorov-Arnold 表示定理
1. 引言
自多层感知机(MLP)诞生以来,"固定激活函数 + 可学习线性权重"的范式统治了深度学习数十年。ReLU、GELU、SiLU 等激活函数被广泛使用,但它们的选择往往是经验性的。
KAN(Kolmogorov-Arnold Networks, Liu et al., 2024) 回到了一个根本性的数学问题:什么是神经网络表示函数的最优方式? 答案来自 1957 年的 Kolmogorov-Arnold 表示定理------任何多元连续函数都可以表示为一元函数的有限组合。KAN 将这一定理转化为实用的神经网络架构:
- 可学习的激活函数 :激活函数放在网络的边(而非节点)上,且是可学习的
- 节点仅做求和:每个节点计算输入的简单求和
- B 样条参数化:激活函数用 B 样条表示,实现灵活且可微的函数逼近
2. Kolmogorov-Arnold 表示定理
2.1 定理陈述
定理 (Kolmogorov, 1957; Arnold, 1958):对于任意连续函数 f : 0 , 1 n → R f: 0, 1^n \to \mathbb{R} f:0,1n→R,存在连续的一元函数 ϕ q , p : 0 , 1 → R \phi_{q,p}: 0, 1 \to \mathbb{R} ϕq,p:0,1→R 和 Φ q : 0 , 1 → R \Phi_q: 0, 1 \to \mathbb{R} Φq:0,1→R,使得:
f ( x 1 , x 2 , ... , x n ) = ∑ q = 0 2 n Φ q ( ∑ p = 1 n ϕ q , p ( x p ) ) f(x_1, x_2, \ldots, x_n) = \sum_{q=0}^{2n} \Phi_q \left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right) f(x1,x2,...,xn)=q=0∑2nΦq(p=1∑nϕq,p(xp))
其中 q q q 的范围是 0 , 1 , ... , 2 n 0, 1, \ldots, 2n 0,1,...,2n。
2.2 定理的深刻含义
含义 1:高维函数的分解
任何多元函数都可以分解为内层一元函数的求和 + 外层一元函数的组合 。这将 n n n 维问题分解为 O ( n ) O(n) O(n) 个一维问题。
含义 2:维度的诅咒被打破
表面上, n n n 维函数需要 O ( n n ) O(n^n) O(nn) 个参数来表示(网格方法)。Kolmogorov-Arnold 定理表明,只需要 O ( n ) O(n) O(n) 个一元函数即可,每个一元函数只需要 O ( 1 ) O(1) O(1) 个参数(在连续函数空间中)。
含义 3:激活函数应该放在边上
传统 MLP 将可学习参数放在线性层的权重上,激活函数是固定的。Kolmogorov-Arnold 定理暗示:可学习的应该是一元函数(激活函数),而非线性权重。
2.3 定理的构造性证明
Kolmogorov 的证明是构造性的,但构造的函数 ϕ q , p \phi_{q,p} ϕq,p 和 Φ q \Phi_q Φq 高度病态(非光滑、分形性质),无法直接用于实际计算。
KAN 的贡献:用 B 样条参数化的光滑函数替代病态函数,将定理转化为实用架构。
3. 从 MLP 到 KAN 的范式转变
3.1 MLP 的架构
标准 MLP 的第 l l l 层:
x ( l + 1 ) = σ ( W ( l ) x ( l ) + b ( l ) ) \mathbf{x}^{(l+1)} = \sigma(\mathbf{W}^{(l)} \mathbf{x}^{(l)} + \mathbf{b}^{(l)}) x(l+1)=σ(W(l)x(l)+b(l))
其中:
- W ( l ) ∈ R n l + 1 × n l \mathbf{W}^{(l)} \in \mathbb{R}^{n_{l+1} \times n_l} W(l)∈Rnl+1×nl:可学习权重矩阵
- σ \sigma σ:固定的激活函数(如 ReLU)
- 每个节点:线性组合 + 固定非线性
3.2 KAN 的架构
KAN 的第 l l l 层:
x j ( l + 1 ) = ∑ i = 1 n l ϕ l , j , i ( x i ( l ) ) x_j^{(l+1)} = \sum_{i=1}^{n_l} \phi_{l,j,i}(x_i^{(l)}) xj(l+1)=i=1∑nlϕl,j,i(xi(l))
其中:
- ϕ l , j , i : R → R \phi_{l,j,i}: \mathbb{R} \to \mathbb{R} ϕl,j,i:R→R:可学习的一元函数
- 每条边:独立的可学习激活函数
- 每个节点:仅做求和(无固定激活)
3.3 架构对比
| 特性 | MLP | KAN |
|---|---|---|
| 可学习参数位置 | 节点(权重矩阵) | 边(激活函数) |
| 激活函数 | 固定(ReLU, GELU) | 可学习(B 样条) |
| 节点操作 | 线性组合 + 固定非线性 | 仅求和 |
| 理论基础 | 通用逼近定理 | Kolmogorov-Arnold 定理 |
| 参数效率 | 低(需要更多层/宽) | 高(单层即可逼近复杂函数) |
3.4 为什么 KAN 更高效?
MLP 的问题:固定激活函数限制了每层的表达能力。为了逼近复杂函数,MLP 需要更多的层或更宽的层。
KAN 的优势:每条边上的可学习激活函数可以自适应地拟合数据中的非线性模式。单层 KAN 就能表示 MLP 需要多层才能表示的函数。
形式化:对于某些函数类,KAN 的逼近速率指数级快于 MLP。
4. 逼近论分析
4.1 逼近速率对比
定理 (Liu et al., 2024):对于某些光滑函数类 F \mathcal{F} F,KAN 的逼近误差为:
inf f θ ∈ KAN ∥ f − f θ ∥ ∞ = O ( n − k ) \inf_{f_\theta \in \text{KAN}} \|f - f_\theta\|_\infty = O(n^{-k}) fθ∈KANinf∥f−fθ∥∞=O(n−k)
其中 n n n 是参数量, k k k 是样条的阶数。而 MLP 的逼近误差为:
inf f θ ∈ MLP ∥ f − f θ ∥ ∞ = O ( n − 2 / d ) \inf_{f_\theta \in \text{MLP}} \|f - f_\theta\|_\infty = O(n^{-2/d}) fθ∈MLPinf∥f−fθ∥∞=O(n−2/d)
其中 d d d 是输入维度。当 d d d 很大时,KAN 的逼近速率远快于 MLP。
4.2 为什么 KAN 更适合科学计算?
科学计算中的函数通常具有组合结构(如物理定律的嵌套形式)。Kolmogorov-Arnold 定理天然捕捉了这种组合结构,而 MLP 需要从头学习。
例子 :考虑 f ( x , y ) = sin ( x 2 + y 2 ) f(x, y) = \sin(x^2 + y^2) f(x,y)=sin(x2+y2)。
- MLP:需要学习 sin \sin sin、 x 2 x^2 x2、 y 2 y^2 y2、加法的组合,需要多层
- KAN:自然地分解为 ϕ 1 ( x ) = x 2 \phi_1(x) = x^2 ϕ1(x)=x2、 ϕ 2 ( y ) = y 2 \phi_2(y) = y^2 ϕ2(y)=y2、 Φ ( z ) = sin ( z ) \Phi(z) = \sin(z) Φ(z)=sin(z)
5. Kolmogorov-Arnold 定理数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ Kolmogorov-Arnold 表示定理 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 定理陈述: ║
║ f(x₁,...,xₙ) = Σ_{q=0}^{2n} Φ_q(Σ_{p=1}^{n} φ_{q,p}(x_p)) ║
║ 任意连续多元函数 = 一元函数的有限组合 ║
║ ║
║ 2. 含义: ║
║ - 高维函数可分解为 O(n) 个一维问题 ║
║ - 打破维度诅咒: O(n) vs O(n^n) ║
║ - 可学习的应是一元函数 (激活函数), 而非线性权重 ║
║ ║
║ 3. MLP vs KAN: ║
║ MLP: x_{l+1} = σ(W·x_l + b) (固定 σ, 可学习 W) ║
║ KAN: x_j^{l+1} = Σ_i φ_{l,j,i}(x_i^l) (可学习 φ, 仅求和) ║
║ ║
║ 4. 逼近速率: ║
║ KAN: O(n^{-k}), k = 样条阶数 ║
║ MLP: O(n^{-2/d}), d = 输入维度 ║
║ KAN 在高维问题上指数级更快 ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
第二篇:KAN 架构与 B 样条实现
1. 引言
KAN 的核心实现问题是如何参数化可学习的一元函数 ϕ ( x ) \phi(x) ϕ(x)。KAN 使用 B 样条(B-spline) 来实现这一点------B 样条既能灵活逼近任意函数,又具有良好的可微性和局部控制性。
2. B 样条基础
2.1 B 样条的定义
B 样条基函数 B i , k ( x ) B_{i,k}(x) Bi,k(x) 是定义在节点向量上的分段多项式。
给定节点向量 t = ( t 0 , t 1 , ... , t m ) \mathbf{t} = (t_0, t_1, \ldots, t_{m}) t=(t0,t1,...,tm), k k k 阶 B 样条基函数递归定义为:
零阶 ( k = 1 k = 1 k=1):
B i , 1 ( x ) = { 1 if t i ≤ x < t i + 1 0 otherwise B_{i,1}(x) = \begin{cases} 1 & \text{if } t_i \leq x < t_{i+1} \\ 0 & \text{otherwise} \end{cases} Bi,1(x)={10if ti≤x<ti+1otherwise
递归 ( k > 1 k > 1 k>1):
B i , k ( x ) = x − t i t i + k − 1 − t i B i , k − 1 ( x ) + t i + k − x t i + k − t i + 1 B i + 1 , k − 1 ( x ) B_{i,k}(x) = \frac{x - t_i}{t_{i+k-1} - t_i} B_{i,k-1}(x) + \frac{t_{i+k} - x}{t_{i+k} - t_{i+1}} B_{i+1,k-1}(x) Bi,k(x)=ti+k−1−tix−tiBi,k−1(x)+ti+k−ti+1ti+k−xBi+1,k−1(x)
其中约定 0 / 0 = 0 0/0 = 0 0/0=0。
2.2 B 样条的性质
| 性质 | 说明 |
|---|---|
| 局部支撑 | B i , k ( x ) B_{i,k}(x) Bi,k(x) 仅在 [ t i , t i + k ) [t_i, t_{i+k}) [ti,ti+k) 上非零 |
| 非负性 | B i , k ( x ) ≥ 0 B_{i,k}(x) \geq 0 Bi,k(x)≥0 |
| 归一性 | ∑ i B i , k ( x ) = 1 \sum_i B_{i,k}(x) = 1 ∑iBi,k(x)=1 |
| 可微性 | B i , k ∈ C k − 2 B_{i,k} \in C^{k-2} Bi,k∈Ck−2( k − 2 k-2 k−2 次连续可微) |
| 局部控制 | 移动一个节点只影响附近的基函数 |
2.3 B 样条曲线
B 样条曲线是基函数的加权组合:
ϕ ( x ) = ∑ i = 0 n c i B i , k ( x ) \phi(x) = \sum_{i=0}^{n} c_i B_{i,k}(x) ϕ(x)=i=0∑nciBi,k(x)
其中 c i c_i ci 是控制点 (可学习参数), n + 1 n+1 n+1 是控制点数量。
3. KAN 的激活函数参数化
3.1 B 样条激活函数
KAN 使用 B 样条参数化每条边上的激活函数:
ϕ l , j , i ( x ) = w b ⋅ b ( x ) + w s ⋅ ∑ m = 0 G + k − 1 c m B m , k ( x ) \phi_{l,j,i}(x) = w_b \cdot b(x) + w_s \cdot \sum_{m=0}^{G+k-1} c_m B_{m,k}(x) ϕl,j,i(x)=wb⋅b(x)+ws⋅m=0∑G+k−1cmBm,k(x)
其中:
- b ( x ) = silu ( x ) = x / ( 1 + e − x ) b(x) = \text{silu}(x) = x / (1 + e^{-x}) b(x)=silu(x)=x/(1+e−x):残差基函数(确保基本非线性)
- B m , k ( x ) B_{m,k}(x) Bm,k(x): k k k 阶 B 样条基函数
- c m c_m cm:可学习的控制点
- G G G:网格大小(grid size)
- w b , w s w_b, w_s wb,ws:可学习的缩放因子
3.2 网格自适应
初始网格 :在 − R , R -R, R −R,R 上均匀分布 G + 1 G + 1 G+1 个节点,其中 R R R 是输入范围的估计。
网格细化 :训练过程中,可以动态增加网格大小 G G G,以提高逼近精度。
网格更新公式:
t new = refine ( t old ) \mathbf{t}{\text{new}} = \text{refine}(\mathbf{t}{\text{old}}) tnew=refine(told)
新节点插入到旧节点的中点,控制点通过插值更新。
3.3 参数量分析
对于一个 KAN 层,连接 n l n_l nl 个输入节点和 n l + 1 n_{l+1} nl+1 个输出节点:
| 参数 | 数量 | 说明 |
|---|---|---|
| 控制点 c m c_m cm | n l × n l + 1 × ( G + k ) n_l \times n_{l+1} \times (G + k) nl×nl+1×(G+k) | 每条边 G + k G + k G+k 个 |
| 缩放因子 w b , w s w_b, w_s wb,ws | 2 × n l × n l + 1 2 \times n_l \times n_{l+1} 2×nl×nl+1 | 每条边 2 个 |
| 总计 | n l × n l + 1 × ( G + k + 2 ) n_l \times n_{l+1} \times (G + k + 2) nl×nl+1×(G+k+2) | --- |
对比 MLP 层: n l × n l + 1 n_l \times n_{l+1} nl×nl+1 个权重 + n l + 1 n_{l+1} nl+1 个偏置。
当 G + k > 1 G + k > 1 G+k>1 时,KAN 每层参数量多于 MLP,但逼近能力更强。
4. 完整可运行实现
4.1 B 样条实现
python
"""
KAN (Kolmogorov-Arnold Networks) --- 完整可运行实现
依赖: torch >= 2.0, numpy, matplotlib
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import List, Tuple, Optional
from dataclasses import dataclass
@dataclass
class KANConfig:
"""KAN 配置"""
layers: List[int] = None # 各层宽度, 如 [2, 5, 1]
grid_size: int = 5 # B 样条网格大小 G
spline_order: int = 3 # B 样条阶数 k
grid_range: Tuple[float, float] = (-2.0, 2.0) # 网格范围
base_activation: str = "silu" # 残差基函数
def __post_init__(self):
if self.layers is None:
self.layers = [2, 5, 1]
def compute_bspline_basis(
x: torch.Tensor, # (B,) 输入
grid: torch.Tensor, # (G + 2k + 1,) 节点向量
k: int, # 样条阶数
) -> torch.Tensor:
"""
计算 B 样条基函数值。
使用 Cox-de Boor 递归公式:
B_{i,1}(x) = 1 if t_i <= x < t_{i+1}, else 0
B_{i,k}(x) = ((x - t_i) / (t_{i+k-1} - t_i)) * B_{i,k-1}(x)
+ ((t_{i+k} - x) / (t_{i+k} - t_{i+1})) * B_{i+1,k-1}(x)
返回: (B, n_basis) 基函数值
"""
# 节点数
n_nodes = len(grid)
# 零阶基函数
# B_{i,1}(x) = 1 if t_i <= x < t_{i+1}
bases = []
for i in range(n_nodes - 1):
left = grid[i]
right = grid[i + 1]
basis = ((x >= left) & (x < right)).float()
bases.append(basis)
bases = torch.stack(bases, dim=-1) # (B, n_nodes - 1)
# 递归计算高阶基函数
for order in range(2, k + 1):
new_bases = []
n_cur = bases.shape[-1]
for i in range(n_cur - 1):
# 左项: (x - t_i) / (t_{i+order-1} - t_i) * B_{i,order-1}
left_node = grid[i]
left_node_next = grid[i + order - 1]
denom_left = left_node_next - left_node
if denom_left > 0:
left_term = ((x - left_node) / denom_left) * bases[:, i]
else:
left_term = torch.zeros_like(x)
# 右项: (t_{i+order} - x) / (t_{i+order} - t_{i+1}) * B_{i+1,order-1}
right_node = grid[i + order]
right_node_next = grid[i + 1]
denom_right = right_node - right_node_next
if denom_right > 0:
right_term = ((right_node - x) / denom_right) * bases[:, i + 1]
else:
right_term = torch.zeros_like(x)
new_bases.append(left_term + right_term)
bases = torch.stack(new_bases, dim=-1)
return bases # (B, n_basis)
4.2 KAN Layer 实现
python
class KANLayer(nn.Module):
"""KAN 层: 每条边是可学习的 B 样条激活函数"""
def __init__(
self,
in_features: int,
out_features: int,
grid_size: int = 5,
spline_order: int = 3,
grid_range: Tuple[float, float] = (-2.0, 2.0),
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
# 构建节点向量
# 均匀网格 + 两端扩展 k 个节点
h = (grid_range[1] - grid_range[0]) / grid_size
grid = torch.arange(
grid_range[0] - spline_order * h,
grid_range[1] + (spline_order + 1) * h,
h
)
self.register_buffer("grid", grid)
n_basis = len(grid) - spline_order - 1
# 每条边的可学习控制点
# shape: (out_features, in_features, n_basis)
self.spline_weight = nn.Parameter(
torch.randn(out_features, in_features, n_basis) * 0.1
)
# 残差基函数的缩放因子
self.base_weight = nn.Parameter(
torch.randn(out_features, in_features) * 0.1
)
# B 样条的缩放因子
self.spline_scale = nn.Parameter(
torch.ones(out_features, in_features)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (B, in_features)
返回: (B, out_features)
每个输出节点 = 所有输入节点的激活函数之和
"""
B = x.shape[0]
# 计算每个输入的 B 样条基函数值
# x: (B, in_features) → 对每个特征分别计算
basis_values = [] # in_features 个 (B, n_basis)
for i in range(self.in_features):
basis_i = compute_bspline_basis(x[:, i], self.grid, self.spline_order)
basis_values.append(basis_i)
basis_values = torch.stack(basis_values, dim=1) # (B, in_features, n_basis)
# B 样条激活: sum(c_m * B_m(x))
# spline_weight: (out, in, n_basis)
# basis_values: (B, in, n_basis)
spline_activation = torch.einsum(
"oin,bin->boi", self.spline_weight, basis_values
) # (B, out_features, in_features)
# 残差基函数 (SiLU)
if self.base_weight.requires_grad:
base_activation = F.silu(x) # (B, in_features)
base_contribution = torch.einsum(
"oi,bi->boi", self.base_weight, base_activation
) # (B, out_features, in_features)
else:
base_contribution = 0
# 合并: phi(x) = w_b * silu(x) + w_s * spline(x)
activation = base_contribution + self.spline_scale.unsqueeze(0) * spline_activation
# 节点求和: output_j = sum_i phi_{j,i}(x_i)
output = activation.sum(dim=-1) # (B, out_features)
return output
4.3 完整 KAN 网络
python
class KAN(nn.Module):
"""Kolmogorov-Arnold Network"""
def __init__(self, config: KANConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList()
for i in range(len(config.layers) - 1):
self.layers.append(
KANLayer(
in_features=config.layers[i],
out_features=config.layers[i + 1],
grid_size=config.grid_size,
spline_order=config.spline_order,
grid_range=config.grid_range,
)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = layer(x)
return x
def get_activation_values(
self, layer_idx: int, edge_idx: Tuple[int, int], x: torch.Tensor
) -> torch.Tensor:
"""获取指定边的激活函数值 (用于可视化)"""
layer = self.layers[layer_idx]
i, j = edge_idx # 输入 i, 输出 j
basis = compute_bspline_basis(x, layer.grid, layer.spline_order)
spline_val = (layer.spline_weight[j, i] * basis).sum(dim=-1)
base_val = layer.base_weight[j, i] * F.silu(x)
return base_val + layer.spline_scale[j, i] * spline_val
4.4 实验代码
python
def experiment_kan_function_fitting():
"""KAN 函数拟合实验"""
torch.manual_seed(42)
device = torch.device("cpu")
# 目标函数: f(x, y) = sin(x^2 + y^2)
def target_fn(x):
return torch.sin(x[:, 0:1] ** 2 + x[:, 1:2] ** 2)
# 生成训练数据
n_train = 1000
x_train = torch.randn(n_train, 2, device=device) * 1.5
y_train = target_fn(x_train)
# 创建 KAN
config = KANConfig(layers=[2, 5, 1], grid_size=8, spline_order=3)
model = KAN(config).to(device)
# 训练
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
losses = []
for step in range(2000):
idx = torch.randint(0, n_train, (256,))
x_batch = x_train[idx]
y_batch = y_train[idx]
y_pred = model(x_batch)
loss = F.mse_loss(y_pred, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
if (step + 1) % 500 == 0:
print(f"Step {step+1} | Loss: {np.mean(losses[-100:]):.6f}")
# 测试
x_test = torch.randn(500, 2, device=device) * 1.5
y_test = target_fn(x_test)
with torch.no_grad():
y_pred = model(x_test)
test_mse = F.mse_loss(y_pred, y_test).item()
print(f"\n测试 MSE: {test_mse:.6f}")
return model, losses
def compare_kan_vs_mlp():
"""对比 KAN 和 MLP 的拟合能力"""
torch.manual_seed(42)
# 目标函数: f(x) = sin(3x) * exp(-x^2)
def target_fn(x):
return torch.sin(3 * x) * torch.exp(-x ** 2)
# 训练数据
x_train = torch.linspace(-3, 3, 200).unsqueeze(1)
y_train = target_fn(x_train)
# KAN
config = KANConfig(layers=[1, 5, 1], grid_size=10, spline_order=3)
kan = KAN(config)
optimizer_kan = torch.optim.Adam(kan.parameters(), lr=1e-2)
# MLP (类似参数量)
mlp = nn.Sequential(
nn.Linear(1, 20), nn.SiLU(),
nn.Linear(20, 20), nn.SiLU(),
nn.Linear(20, 1)
)
optimizer_mlp = torch.optim.Adam(mlp.parameters(), lr=1e-3)
# 训练
for step in range(3000):
# KAN
y_pred_kan = kan(x_train)
loss_kan = F.mse_loss(y_pred_kan, y_train)
optimizer_kan.zero_grad()
loss_kan.backward()
optimizer_kan.step()
# MLP
y_pred_mlp = mlp(x_train)
loss_mlp = F.mse_loss(y_pred_mlp, y_train)
optimizer_mlp.zero_grad()
loss_mlp.backward()
optimizer_mlp.step()
if (step + 1) % 1000 == 0:
print(f"Step {step+1} | KAN Loss: {loss_kan.item():.6f} | MLP Loss: {loss_mlp.item():.6f}")
# 参数量
kan_params = sum(p.numel() for p in kan.parameters())
mlp_params = sum(p.numel() for p in mlp.parameters())
print(f"\nKAN 参数量: {kan_params}")
print(f"MLP 参数量: {mlp_params}")
5. KAN 架构数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ KAN 架构与 B 样条 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. B 样条基函数 (Cox-de Boor 递归): ║
║ B_{i,1}(x) = 1 if t_i ≤ x < t_{i+1}, else 0 ║
║ B_{i,k}(x) = ((x-t_i)/(t_{i+k-1}-t_i))·B_{i,k-1}(x) ║
║ + ((t_{i+k}-x)/(t_{i+k}-t_{i+1}))·B_{i+1,k-1}(x) ║
║ ║
║ 2. KAN 激活函数: ║
║ φ(x) = w_b·silu(x) + w_s·Σ_m c_m·B_m(x) ║
║ w_b, w_s, c_m 均可学习 ║
║ ║
║ 3. KAN 层: ║
║ x_j^{l+1} = Σ_i φ_{l,j,i}(x_i^l) ║
║ 节点仅求和, 非线性来自边上的激活函数 ║
║ ║
║ 4. 参数量: ║
║ KAN 层: n_in × n_out × (G + k + 2) ║
║ MLP 层: n_in × n_out + n_out ║
║ KAN 每层参数更多, 但需要更少的层 ║
║ ║
║ 5. B 样条性质: ║
║ - 局部支撑: 移动一个节点只影响附近 ║
║ - 非负性: B_{i,k}(x) ≥ 0 ║
║ - 归一性: Σ_i B_{i,k}(x) = 1 ║
║ - 可微性: B_{i,k} ∈ C^{k-2} ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
第三篇:KAN 的前沿发展与应用
1. 引言
KAN 自 2024 年 5 月发表以来,迅速引发了大量 follow-up 工作,涵盖理论分析、架构改进和应用拓展。
2. KAN 的理论分析
2.1 KAN 的通用逼近定理
定理 (Liu et al., 2024):对于任意连续函数 f : 0 , 1 n → R f: 0, 1^n \to \mathbb{R} f:0,1n→R 和任意 ϵ > 0 \epsilon > 0 ϵ>0,存在一个 KAN 网络 f θ f_\theta fθ 使得:
∥ f − f θ ∥ ∞ < ϵ \|f - f_\theta\|_\infty < \epsilon ∥f−fθ∥∞<ϵ
这比 Kolmogorov-Arnold 定理更强------原始定理中的函数是病态的,而 KAN 使用光滑的 B 样条。
2.2 KAN 的逼近速率
定理 :对于 k k k 次连续可微的函数类 C k C^k Ck,使用 k k k 阶 B 样条的 KAN 的逼近误差为:
∥ f − f θ ∥ ∞ = O ( G − k ) \|f - f_\theta\|_\infty = O(G^{-k}) ∥f−fθ∥∞=O(G−k)
其中 G G G 是网格大小。增加网格大小可以指数级减小误差。
2.3 KAN 的可解释性
KAN 的一个重要优势是可解释性------每条边上的激活函数可以独立可视化,揭示数据中的函数关系。
符号回归:KAN 可以自动发现数学公式。通过分析学习到的激活函数形状,可以推断出底层的数学关系。
3. KAN 的架构变体
3.1 FourierKAN
用傅里叶基函数替代 B 样条:
ϕ ( x ) = ∑ m = 0 M ( a m cos ( m ω x ) + b m sin ( m ω x ) ) \phi(x) = \sum_{m=0}^{M} \left( a_m \cos(m \omega x) + b_m \sin(m \omega x) \right) ϕ(x)=m=0∑M(amcos(mωx)+bmsin(mωx))
优势:更适合周期性函数。
3.2 WaveletKAN
用小波基函数替代 B 样条:
ϕ ( x ) = ∑ m c m ψ ( x − t m s ) \phi(x) = \sum_{m} c_m \psi\left(\frac{x - t_m}{s}\right) ϕ(x)=m∑cmψ(sx−tm)
优势:多尺度分析能力,适合具有局部特征的函数。
3.3 ReLU-KAN
用分段线性函数(ReLU 的组合)替代 B 样条:
ϕ ( x ) = ∑ m c m ReLU ( x − t m ) \phi(x) = \sum_{m} c_m \text{ReLU}(x - t_m) ϕ(x)=m∑cmReLU(x−tm)
优势:计算效率更高,与现有硬件更兼容。
3.4 ChebyKAN
用切比雪夫多项式替代 B 样条:
ϕ ( x ) = ∑ m = 0 M c m T m ( x ) \phi(x) = \sum_{m=0}^{M} c_m T_m(x) ϕ(x)=m=0∑McmTm(x)
其中 T m T_m Tm 是 m m m 阶切比雪夫多项式。
优势 :在 − 1 , 1 -1, 1 −1,1 上具有最优的多项式逼近性质。
4. KAN 的应用场景
4.1 科学计算
KAN 在科学计算中表现出色,因为它天然捕捉函数的组合结构:
| 任务 | KAN 优势 |
|---|---|
| 物理定律发现 | 可解释的激活函数揭示物理关系 |
| PDE 求解 | 高效逼近高维函数 |
| 分子动力学 | 捕捉原子间相互作用的函数形式 |
4.2 符号回归
KAN 可以自动发现数学公式:
- 训练 KAN 拟合数据
- 分析激活函数形状
- 用符号表达式替换 B 样条
- 简化得到最终公式
4.3 机器学习
| 任务 | KAN 表现 |
|---|---|
| 函数拟合 | 参数效率更高 |
| 回归 | 更少的参数达到相同精度 |
| 分类 | 与 MLP 相当,但更可解释 |
5. KAN 与 MLP 的全面对比
| 维度 | MLP | KAN |
|---|---|---|
| 理论基础 | 通用逼近定理 | Kolmogorov-Arnold 定理 |
| 参数位置 | 节点(权重矩阵) | 边(激活函数) |
| 激活函数 | 固定 | 可学习 |
| 逼近速率 | O ( n − 2 / d ) O(n^{-2/d}) O(n−2/d) | O ( n − k ) O(n^{-k}) O(n−k) |
| 参数效率 | 低 | 高 |
| 可解释性 | 低 | 高(可视化激活函数) |
| 计算效率 | 高(矩阵乘法) | 中(B 样条计算) |
| 硬件支持 | 成熟 | 待优化 |
| 适用场景 | 通用 | 科学计算、符号回归 |
6. 前沿研究方向
6.1 KAN + Transformer
将 KAN 替换 Transformer 中的 MLP 层:
Attention → KAN → Attention → KAN → ⋯ \text{Attention} \to \text{KAN} \to \text{Attention} \to \text{KAN} \to \cdots Attention→KAN→Attention→KAN→⋯
初步实验表明,在某些任务上 KAN-Transformer 优于标准 Transformer。
6.2 KAN + GNN
将 KAN 应用于图神经网络的消息传递:
m i j = ϕ KAN ( x i , x j , e i j ) m_{ij} = \phi_{\text{KAN}}(x_i, x_j, e_{ij}) mij=ϕKAN(xi,xj,eij)
6.3 KAN 的理论深化
- 泛化理论:分析 KAN 的泛化误差界
- 优化理论:分析 KAN 训练的收敛性
- 信息论:分析 KAN 的信息压缩能力
7. 前沿发展数学公式总结
╔══════════════════════════════════════════════════════════════════════════════════════════╗
║ KAN 前沿发展 数学总结 ║
╠══════════════════════════════════════════════════════════════════════════════════════════╣
║ ║
║ 1. 通用逼近: ║
║ ∀f ∈ C([0,1]^n), ∀ε > 0, ∃KAN f_θ: ‖f - f_θ‖_∞ < ε ║
║ ║
║ 2. 逼近速率: ║
║ KAN (B 样条): O(G^{-k}), G=网格大小, k=样条阶数 ║
║ MLP: O(n^{-2/d}), d=输入维度 ║
║ ║
║ 3. KAN 变体: ║
║ FourierKAN: φ(x) = Σ (a_m·cos(mωx) + b_m·sin(mωx)) ║
║ ChebyKAN: φ(x) = Σ c_m·T_m(x) ║
║ ReLU-KAN: φ(x) = Σ c_m·ReLU(x - t_m) ║
║ ║
║ 4. 符号回归流程: ║
║ 训练 KAN → 分析激活函数 → 符号替换 → 简化公式 ║
║ ║
║ 5. KAN + Transformer: ║
║ Attention → KAN → Attention → KAN → ... ║
║ 初步实验优于标准 Transformer ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════════════════╝
参考文献
核心论文
- Liu, Z., Wang, Y., Vaidya, S., Ruehle, F., Halverson, J., Soljačić, M., Hou, T. Y., & Tegmark, M. (2024). KAN: Kolmogorov-Arnold Networks. arXiv:2404.19756.
- Kolmogorov, A. N. (1957). On the representation of continuous functions of many variables by superposition of continuous functions of one variable and addition. Doklady Akademii Nauk SSSR, 114, 953-956.
- Arnold, V. I. (1958). On the representation of continuous functions of three variables by superpositions of continuous functions of two variables. Doklady Akademii Nauk SSSR, 114, 679-681.
B 样条理论
- De Boor, C. (1978). A Practical Guide to Splines. Springer-Verlag.
- Schoenberg, I. J. (1946). Contributions to the problem of approximation of equidistant data by analytic functions. Quarterly of Applied Mathematics, 4, 45-99.
KAN 变体
- Xu, Z. (2024). FourierKAN-GCF: Fourier Kolmogorov-Arnold Network for Graph Collaborative Filtering. arXiv:2406.01006.
- Bodner, A. D., Tepsich, A. S., Spinosa, J. P., & Pinto, J. (2024). Wav-KAN: Wavelet Kolmogorov-Arnold Networks. arXiv:2405.12832.
逼近论
- Cybenko, G. (1989). Approximation by superpositions of a sigmoidal function. Mathematics of Control, Signals and Systems, 2(4), 303-314.
- Hornik, K. (1991). Approximation capabilities of multilayer feedforward networks. Neural Networks, 4(2), 251-257.