⌈ 传知代码 ⌋ 挑战MLP,KAN网络解析使用

前情提要

本文是传知代码平台中的相关前沿知识与技术的分享~

接下来我们即将进入一个全新的空间,对技术有一个全新的视角~

本文所涉及所有资源均在传知代码平台可获取

以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!

以下内容干货满满,跟上步伐吧~


💡本章重点

  • 挑战MLP,KAN网络解析使用

🍞一. 概述

自深度学习出现至今,所有的网络几乎都研自于MLP,虽然MLP的效果较为不错,MLP没有非线性的能力,用其训练人工智能需要大量的参数的堆叠和激活函数的作用,这使得模型有了过拟合的风险。本项目旨在通过新型网络KAN替代往常深度学习网络中MLP层,借助KAN网络其本身的非线性拟合能力和良好的可解释性,探究其在神经网络中替代MLP的可能,并研发新型的深度学习网络架构,以提高深度学习的性能和可解释性。


🍞二. 核心逻辑

KAN网络解析

KAN网络基于柯尔莫哥洛夫-阿诺德表示定理。证明了实分析中的如下表示定理:如果一个函数是多元连续函数,则该函数可以写成有限数量的单变量连续函数的两层嵌套叠加。其最简洁的数学表达式就是:

KAN网络通过将传统深度学习中普通的权重数据值替换为由多个基础函数所构成的样条函数,经过与输入数据的非线性组合,进而得出输出结果。KAN网络经过训练之后能够输出其所拟合的有多个基础函数所组合而成的复杂函数,有非常好的可解释性,加之其本身就是各种非线性函数的组合,从而使其拥有了非线性变换的能力,可以仅通过几百的参数量就达到MLP几万参数量的效果,所以一直被认为是MLP的替代品。以下是KAN的工作原理:

B-spline

原始的KAN使用B-spline(B样条)来构建。B-spline是基础样条(Basic Spline)的缩写。对于B-spline,函数在其定义域内、在结点(Knot)都具有相同的连续性。

样条是 KAN 学习机制的核心,它们取代了神经网络中通常使用的传统权重参数。

样条的灵活性使其能够通过调整其形状来适应性地建模数据中的复杂关系,从而最小化近似误差,增强了网络从高维数据集中学习细微模式的能力。

KAN 中样条的通用公式可以用 B-样条来表示:

这里,spline(x)表示样条函数,ci 是训练期间优化的系数,而 Bi (x)是定义在网格上的B-样条的基函数。网格点定义了每个基函数 Bi 活跃并显著影响形状和平滑度的区间。可以简单的视为影响网络准确性的超参数。更多的网格意味着更多的控制和更高的精度,同时也意味着需要学习更多的参数。

​训练模拟图如下:

与MLP的比较

MLP属于前馈神经网络,由输入层、一个或多个隐藏层以及输出层组成。每一层由多个神经元组成,每个神经元都与下一层的所有神经元相连接。MLP能够学习和模拟复杂的非线性函数,这使得它在各种机器学习任务中都有广泛的应用,如分类、回归、聚类等。但是MLP在实现复杂的函数逼近时需要大量的参数,这导致了过高的计算成本和存储需求。由于参数数量的增加,MLP容易在训练数据上过拟合,即在训练数据上表现良好,但在未见过的新数据上表现不佳。且MLP最令人所诟病的一点就是MLP的结构通常不具备良好的可解释性,这使得理解模型的决策过程变得困难,这在需要模型解释的应用场景中是一个重大缺陷。且随着时代的发展,MLP已经被研究的差不多了,较难有很好的创新了。

KAN网络相比较MLP能够使用更少的参数量实现更高的精度。KAN网络在结构上更具可解释性,这对于数学和物理研究中的辅助模型尤其重要。因此本项目旨在通过KAN网络完全替代MLP在传统神经网络模型中的决定性作用,研究出一种基于KAN网络的神经网络模型,可用于提升神经网络提取目标特征的能力。

设想存在一个深度为L的网络结构,每层包含相同数量的节点N,且该网络在G个区间上定义(这意味着有G+1个离散的网格点)。每个样条函数的阶数为k(一般情况下,k取值为3)。这样计算下来,整个网络总共包含了若干个参数。相较之下,一个同样深度为L且每层宽度为N的多层感知机(MLP)所需的参数数量则较少,这表明MLP在参数效率上似乎优于KAN(核自适应网络)。然而,值得庆幸的是,KAN通常可以在较小的N值下工作,这不仅减少了参数的数量,还提高了模型的泛化能力,并且有助于增强模型的可解释性。


🍞三.KAN和MLP的选择

KAN(核自适应网络)的主要局限性在于其训练过程较为缓慢,相较于拥有相同参数数量的MLP(多层感知机),KAN的训练速度往往要慢上10倍。尽管如此,这种训练速度的劣势更多地被视为一个可以通过工程技术解决的挑战,而非KAN本身的根本性缺陷。在许多情况下,KAN的表现有望与MLP持平甚至超越,这赋予了探索KAN应用的价值。以下所示的决策树可以辅助判断在何种情况下应当考虑使用KAN。简而言之,如果解释性和/或模型的准确性对您来说至关重要,且训练速度的缓慢不是您的主要顾虑,那么我们建议您尝试使用KAN。


🍞四.AlexNet+KAN

模型原理图:

定义模型(重点)

python 复制代码
class AlexNet_kan(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet_kan, self).__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(3,96,kernel_size=11,stride=4,padding=0),#[96,54,54]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2) #[96,26,26]
        )
        self.conv2=nn.Sequential(
            nn.Conv2d(96,256,kernel_size=5,padding=2),#[256,26,26] 
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2) #[256,12,12]
        )
        self.conv3=nn.Sequential(
            nn.Conv2d(256,384,kernel_size=3,padding=1),#[384,12,12]
            nn.ReLU(inplace=True),
            nn.Conv2d(384,384,kernel_size=3,padding=1),#[384,12,12]
            nn.ReLU(inplace=True),
            nn.Conv2d(384,256,kernel_size=3,padding=1),#[256,12,12]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3,stride=2) #[256,5,5]
        )
        self.fc=nn.Sequential(
            KAN([256*5*5,4096]),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            KAN([4096,4096]),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            KAN([4096,num_classes])
        )

    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        x=x.view(x.size(0),-1)
        x=self.fc(x)
        return x

前向加噪

python 复制代码
import torch
from config import *

# 前向diffusion计算参数
betas=torch.linspace(0.0001,0.02,T) # (T,)
alphas=1-betas  # (T,)
alphas_cumprod=torch.cumprod(alphas,dim=-1) # alpha_t累乘 (T,)    [a1,a2,a3,....] ->  [a1,a1*a2,a1*a2*a3,.....]
alphas_cumprod_prev=torch.cat((torch.tensor([1.0]),alphas_cumprod[:-1]),dim=-1) # alpha_t-1累乘 (T,),  [1,a1,a1*a2,a1*a2*a3,.....]
variance=(1-alphas)*(1-alphas_cumprod_prev)/(1-alphas_cumprod)  # denoise用的方差   (T,)

# 执行前向加噪
def forward_add_noise(x,t): # batch_x: (batch,channel,height,width), batch_t: (batch_size,)
    noise=torch.randn_like(x)   # 为每张图片生成第t步的高斯噪音   (batch,channel,height,width)
    batch_alphas_cumprod=alphas_cumprod[t].view(x.size(0),1,1,1) 
    x=torch.sqrt(batch_alphas_cumprod)*x+torch.sqrt(1-batch_alphas_cumprod)*noise # 基于公式直接生成第t步加噪后图片
    return x,noise

模型效果对比

VisionTransformer+KAN

在本次实验中使用了VisionTransformer和KAN结合的VisionTransformer_kan模型对比原模型进行训练比较。


🫓总结

综上,我们基本了解了**"一项全新的技术啦"** :lollipop: ~~

恭喜你的内功又双叒叕得到了提高!!!

感谢你们的阅读:satisfied:

后续还会继续更新:heartbeat:,欢迎持续关注:pushpin:哟~

:dizzy:如果有错误❌,欢迎指正呀:dizzy:

:sparkles:如果觉得收获满满,可以点点赞👍支持一下哟~:sparkles:

【传知科技 -- 了解更多新知识】

相关推荐
XH华2 小时前
初识C语言之二维数组(下)
c语言·算法
南宫生2 小时前
力扣-图论-17【算法学习day.67】
java·学习·算法·leetcode·图论
不想当程序猿_3 小时前
【蓝桥杯每日一题】求和——前缀和
算法·前缀和·蓝桥杯
落魄君子3 小时前
GA-BP分类-遗传算法(Genetic Algorithm)和反向传播算法(Backpropagation)
算法·分类·数据挖掘
菜鸡中的奋斗鸡→挣扎鸡3 小时前
滑动窗口 + 算法复习
数据结构·算法
Lenyiin3 小时前
第146场双周赛:统计符合条件长度为3的子数组数目、统计异或值为给定值的路径数目、判断网格图能否被切割成块、唯一中间众数子序列 Ⅰ
c++·算法·leetcode·周赛·lenyiin
郭wes代码3 小时前
Cmd命令大全(万字详细版)
python·算法·小程序
scan7243 小时前
LILAC采样算法
人工智能·算法·机器学习
菌菌的快乐生活4 小时前
理解支持向量机
算法·机器学习·支持向量机
大山同学4 小时前
第三章线性判别函数(二)
线性代数·算法·机器学习