YoloV8改进策略:卷积篇Kan行天下之小波Kan

python 复制代码
'''
Based on https://github.com/zavareh1/Wav-KAN
This is sample code for the simulations in the paper:
Bozorgasl, Zavareh and Chen, Hao, Wav-KAN: Wavelet Kolmogorov-Arnold Networks (May, 2024)

https://arxiv.org/abs/2405.12832
and also available at:
https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4835325
We used the efficient-KAN notation and some parts of its code: https://github.com/Blealtan/efficient-kan

'''
import math

import torch
import torch.nn as nn


class WaveletConvND(nn.Module):
    """Wavelet branch of a Wav-KAN convolution (reference, per-channel loop).

    Every (output channel, input channel) pair owns a learnable scale and
    translation.  The input is shifted and scaled per pair, passed through the
    selected mother wavelet, and each output channel's response is reduced by
    its own single-output convolution.  A final 1x1 convolution mixes the
    output channels.

    Args:
        conv_class: Convolution constructor, called positionally as
            ``conv_class(in, out, kernel_size, stride, padding, dilation,
            groups=..., bias=...)`` (for example ``nn.Conv2d``).
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel_size: Forwarded to the per-output convolutions.
        padding: Forwarded to the per-output convolutions.
        stride: Forwarded to the per-output convolutions.
        dilation: Forwarded to every convolution, including the 1x1 mixer.
        ndim: Number of trailing singleton axes appended to the scale and
            translation parameters (the number of spatial dimensions).
        wavelet_type: ``'mexican_hat'``, ``'morlet'``, ``'dog'``, ``'meyer'``
            or ``'shannon'``.  It is only checked when ``forward`` runs.
    """

    def __init__(self, conv_class, input_dim, output_dim, kernel_size,
                 padding=0, stride=1, dilation=1,
                 ndim: int = 2, wavelet_type='mexican_hat'):
        super().__init__()

        # One scale/translation per (output, input) pair, broadcast over the
        # batch axis and over every spatial axis.
        pair_shape = (1, output_dim, input_dim) + (1,) * ndim
        self.scale = nn.Parameter(torch.ones(pair_shape))
        self.translation = nn.Parameter(torch.zeros(pair_shape))

        self.ndim = ndim
        self.wavelet_type = wavelet_type
        self.input_dim = input_dim
        self.output_dim = output_dim

        # One single-output convolution per output channel.
        self.wavelet_weights = nn.ModuleList([
            conv_class(input_dim, 1, kernel_size, stride, padding, dilation,
                       groups=1, bias=False)
            for _ in range(output_dim)
        ])

        # 1x1 convolution that mixes the per-output responses.
        self.wavelet_out = conv_class(output_dim, output_dim, 1, 1, 0, dilation,
                                      groups=1, bias=False)

        # Same initialisation for every convolution, in creation order.
        for conv in (*self.wavelet_weights, self.wavelet_out):
            nn.init.kaiming_uniform_(conv.weight, nonlinearity='linear')

    @staticmethod
    def _forward_mexican_hat(x):
        """Mexican-hat wavelet: ``c * (x^2 - 1) * exp(-x^2 / 2)``."""
        normalisation = 2 / (math.sqrt(3) * math.pi ** 0.25)
        return normalisation * ((x ** 2) - 1) * torch.exp(-0.5 * x ** 2)

    @staticmethod
    def _forward_morlet(x):
        """Real Morlet wavelet: a Gaussian envelope times ``cos(omega0 * x)``."""
        central_frequency = 5.0
        gaussian = torch.exp(-0.5 * x ** 2)
        return gaussian * torch.cos(central_frequency * x)

    @staticmethod
    def _forward_dog(x):
        """Derivative-of-Gaussian wavelet: ``-x * exp(-x^2 / 2)``."""
        gaussian = torch.exp(-0.5 * x ** 2)
        return -x * gaussian

    @staticmethod
    def _forward_meyer(x):
        """Meyer-style wavelet: ``sin(pi * |x|)`` times a smooth cut-off of ``|x|``."""
        pi = math.pi
        magnitude = torch.abs(x)

        def smooth_step(t):
            # Polynomial transition used inside the cut-off.
            return t ** 4 * (35 - 84 * t + 70 * t ** 2 - 20 * t ** 3)

        # Cut-off: 1 for |x| <= 1/2, 0 for |x| >= 1, cosine taper in between.
        taper = torch.cos(pi / 2 * smooth_step(2 * magnitude - 1))
        upper = torch.where(magnitude >= 1, torch.zeros_like(magnitude), taper)
        cutoff = torch.where(magnitude <= 1 / 2, torch.ones_like(magnitude), upper)

        return torch.sin(pi * magnitude) * cutoff

    def _forward_shannon(self, x):
        """Shannon wavelet: ``sin(x) / x`` weighted by a Hamming window.

        The window has length ``x.size(2)`` and is laid out along axis 2 of
        ``x``, broadcasting over every other axis.
        """
        # torch.sinc(t) is sin(pi * t) / (pi * t), so this evaluates sin(x) / x.
        sinc = torch.sinc(x / math.pi)

        length = x.size(2)
        window_shape = (1, 1, length) + (1,) * self.ndim
        # The window limits the infinite support of the sinc function.
        window = torch.hamming_window(length, periodic=False, dtype=x.dtype,
                                      device=x.device).view(*window_shape)
        return sinc * window

    def _wavelet_response(self, x_scaled):
        """Evaluate the wavelet named by ``self.wavelet_type`` on ``x_scaled``.

        Raises:
            ValueError: If ``self.wavelet_type`` is not a supported name.
        """
        candidates = (
            ('mexican_hat', self._forward_mexican_hat),
            ('morlet', self._forward_morlet),
            ('dog', self._forward_dog),
            ('meyer', self._forward_meyer),
            ('shannon', self._forward_shannon),
        )
        for name, transform in candidates:
            if self.wavelet_type == name:
                return transform(x_scaled)
        raise ValueError("Unsupported wavelet type")

    def forward(self, x):
        """Apply the wavelet convolution to ``x``.

        ``x`` is expected to be ``(batch, input_dim, *spatial)``.  A new axis
        is inserted at position 1 so the input broadcasts against the
        ``(1, output_dim, input_dim, 1, ...)`` parameters.

        Raises:
            ValueError: If ``self.wavelet_type`` is not a supported name.
        """
        shifted = (x.unsqueeze(1) - self.translation) / self.scale
        response = self._wavelet_response(shifted)

        # Slice ``idx`` along axis 1 is the response for output channel ``idx``.
        per_output = [
            self.wavelet_weights[idx](channel_response)
            for idx, channel_response in enumerate(response.unbind(dim=1))
        ]
        return self.wavelet_out(torch.cat(per_output, dim=1))


class WaveletConvNDFastPlusOne(WaveletConvND):
    """Wavelet convolution that reduces input channels with an (N+1)-D conv.

    The wavelet response has shape ``(batch, output_dim, input_dim, *spatial)``.
    Instead of looping over output channels, a single grouped convolution of
    one higher dimensionality treats ``output_dim`` as its channel axis and
    ``input_dim`` as an extra leading "spatial" axis that the kernel spans
    completely, so that axis collapses to size 1 and is squeezed away.

    Args:
        conv_class: N-D convolution constructor used for the 1x1 mixer.
        conv_class_d_plus_one: (N+1)-D convolution constructor (for example
            ``nn.Conv3d`` when ``conv_class`` is ``nn.Conv2d``).
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel_size: Int, or a tuple/list with one entry per spatial dimension.
        padding: Int, or a tuple/list with one entry per spatial dimension.
        stride: Int, or a tuple/list with one entry per spatial dimension.
        dilation: Int, or a tuple/list with one entry per spatial dimension.
        ndim: Number of spatial dimensions; must be 1 or 2.
        wavelet_type: ``'mexican_hat'``, ``'morlet'``, ``'dog'``, ``'meyer'``
            or ``'shannon'``.  It is only checked when ``forward`` runs.

    Raises:
        AssertionError: If ``ndim`` is 3 or more.
    """

    def __init__(self, conv_class, conv_class_d_plus_one, input_dim, output_dim, kernel_size,
                 padding=0, stride=1, dilation=1,
                 ndim: int = 2, wavelet_type='mexican_hat'):
        # Deliberately bypass WaveletConvND.__init__ (it would build the
        # per-output convolution list) and initialise nn.Module directly.
        super(WaveletConvND, self).__init__()

        assert ndim < 3, "fast_plus_one version supports only 1D and 2D convs"

        _shapes = (1, output_dim, input_dim) + tuple(1 for _ in range(ndim))

        self.scale = nn.Parameter(torch.ones(*_shapes))
        self.translation = nn.Parameter(torch.zeros(*_shapes))

        self.ndim = ndim
        self.wavelet_type = wavelet_type

        self.input_dim = input_dim
        self.output_dim = output_dim

        def with_channel_axis(lead, value):
            # Prefix the setting for the extra (input-channel) axis.  Tuples and
            # lists are taken as per-dimension values; anything else is
            # replicated across the ``ndim`` spatial dimensions.
            if isinstance(value, (tuple, list)):
                return (lead,) + tuple(value)
            return (lead,) + (value,) * ndim

        # The kernel spans the whole input-channel axis; stride 1, no padding
        # and no dilation there, so that axis reduces to a single position.
        kernel_size_plus = with_channel_axis(input_dim, kernel_size)
        stride_plus = with_channel_axis(1, stride)
        padding_plus = with_channel_axis(0, padding)
        dilation_plus = with_channel_axis(1, dilation)

        # groups=output_dim gives every output channel its own kernel.
        self.wavelet_weights = conv_class_d_plus_one(output_dim,
                                                     output_dim,
                                                     kernel_size_plus,
                                                     stride_plus,
                                                     padding_plus,
                                                     dilation_plus,
                                                     groups=output_dim,
                                                     bias=False)

        # 1x1 convolution that mixes the output channels.
        self.wavelet_out = conv_class(output_dim, output_dim, 1, 1, 0, dilation, groups=1, bias=False)

        nn.init.kaiming_uniform_(self.wavelet_weights.weight, nonlinearity='linear')
        nn.init.kaiming_uniform_(self.wavelet_out.weight, nonlinearity='linear')

    def forward(self, x):
        """Apply the wavelet convolution to ``x``.

        ``x`` is expected to be ``(batch, input_dim, *spatial)``.

        Raises:
            ValueError: If ``self.wavelet_type`` is not a supported name.
        """
        # Broadcast against the (1, output_dim, input_dim, 1, ...) parameters.
        x_expanded = x.unsqueeze(1)

        x_scaled = (x_expanded - self.translation) / self.scale

        if self.wavelet_type == 'mexican_hat':
            wavelet = self._forward_mexican_hat(x_scaled)
        elif self.wavelet_type == 'morlet':
            wavelet = self._forward_morlet(x_scaled)
        elif self.wavelet_type == 'dog':
            wavelet = self._forward_dog(x_scaled)
        elif self.wavelet_type == 'meyer':
            wavelet = self._forward_meyer(x_scaled)
        elif self.wavelet_type == 'shannon':
            wavelet = self._forward_shannon(x_scaled)
        else:
            raise ValueError("Unsupported wavelet type")

        # The (N+1)-D convolution collapses the input-channel axis (axis 2) to
        # size 1; drop it to get back to (batch, output_dim, *spatial).
        y = self.wavelet_weights(wavelet).squeeze(2)
        y = self.wavelet_out(y)
        return y


class WaveletConvNDFast(WaveletConvND):
    """Wavelet convolution that reduces input channels with one grouped conv.

    The ``(batch, output_dim, input_dim, *spatial)`` wavelet response is
    flattened to ``output_dim * input_dim`` channels and fed to a single
    convolution with ``groups=output_dim``, so each output channel reads only
    its own block of ``input_dim`` channels.  A 1x1 convolution then mixes the
    output channels.

    Args:
        conv_class: Convolution constructor, called positionally as
            ``conv_class(in, out, kernel_size, stride, padding, dilation,
            groups=..., bias=...)``.
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel_size: Forwarded to the grouped convolution.
        padding: Forwarded to the grouped convolution.
        stride: Forwarded to the grouped convolution.
        dilation: Forwarded to both convolutions.
        ndim: Number of spatial dimensions.
        wavelet_type: ``'mexican_hat'``, ``'morlet'``, ``'dog'``, ``'meyer'``
            or ``'shannon'``.  It is only checked when ``forward`` runs.
    """

    def __init__(self, conv_class, input_dim, output_dim, kernel_size,
                 padding=0, stride=1, dilation=1,
                 ndim: int = 2, wavelet_type='mexican_hat'):
        # Deliberately bypass WaveletConvND.__init__ (it would build the
        # per-output convolution list) and initialise nn.Module directly.
        super(WaveletConvND, self).__init__()

        # One scale/translation per (output, input) pair, broadcast over the
        # batch axis and over every spatial axis.
        pair_shape = (1, output_dim, input_dim) + (1,) * ndim
        self.scale = nn.Parameter(torch.ones(pair_shape))
        self.translation = nn.Parameter(torch.zeros(pair_shape))

        self.ndim = ndim
        self.wavelet_type = wavelet_type
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Group ``o`` covers flattened channels [o * input_dim, (o + 1) * input_dim).
        self.wavelet_weights = conv_class(
            output_dim * input_dim, output_dim,
            kernel_size, stride, padding, dilation,
            groups=output_dim, bias=False,
        )
        # 1x1 convolution that mixes the output channels.
        self.wavelet_out = conv_class(
            output_dim, output_dim, 1, 1, 0, dilation,
            groups=1, bias=False,
        )

        for conv in (self.wavelet_weights, self.wavelet_out):
            nn.init.kaiming_uniform_(conv.weight, nonlinearity='linear')

    def forward(self, x):
        """Apply the wavelet convolution to ``x``.

        ``x`` is expected to be ``(batch, input_dim, *spatial)``.

        Raises:
            ValueError: If ``self.wavelet_type`` is not a supported name.
        """
        # Broadcast against the (1, output_dim, input_dim, 1, ...) parameters.
        shifted = (x.unsqueeze(1) - self.translation) / self.scale

        candidates = (
            ('mexican_hat', self._forward_mexican_hat),
            ('morlet', self._forward_morlet),
            ('dog', self._forward_dog),
            ('meyer', self._forward_meyer),
            ('shannon', self._forward_shannon),
        )
        for name, transform in candidates:
            if self.wavelet_type == name:
                response = transform(shifted)
                break
        else:
            raise ValueError("Unsupported wavelet type")

        # Merge the (output, input) axes into one channel axis for the grouped conv.
        grouped = self.wavelet_weights(response.flatten(1, 2))
        return self.wavelet_out(grouped)


class WavKANConvNDLayer(nn.Module):
    """Grouped Wav-KAN convolution layer for N spatial dimensions.

    The input channels are split into ``groups`` equal chunks.  Each chunk goes
    through a plain convolution (base branch) and a wavelet convolution
    (wavelet branch); the two results are summed and normalised.  The group
    outputs are concatenated along the channel axis.

    Args:
        conv_class: N-D convolution constructor (for example ``nn.Conv2d``).
        conv_class_plus1: (N+1)-D convolution constructor; only used when
            ``wav_version='fast_plus_one'``.
        norm_class: Normalisation constructor, called as
            ``norm_class(output_dim // groups, **norm_kwargs)``.
        input_dim: Number of input channels; must be divisible by ``groups``.
        output_dim: Number of output channels; must be divisible by ``groups``.
        kernel_size: Forwarded to the convolutions.
        groups: Number of channel groups (positive integer).
        padding: Forwarded to the convolutions.
        stride: Forwarded to the convolutions.
        dilation: Forwarded to the convolutions.
        wav_version: ``'base'``, ``'fast'`` or ``'fast_plus_one'``; selects the
            wavelet-branch implementation.
        ndim: Number of spatial dimensions.
        dropout: Dropout probability; a dropout layer is created only when it
            is positive and ``ndim`` is 1, 2 or 3.
        wavelet_type: ``'mexican_hat'``, ``'morlet'``, ``'dog'``, ``'meyer'``
            or ``'shannon'``.
        **norm_kwargs: Extra keyword arguments for ``norm_class``.

    Raises:
        ValueError: If ``wavelet_type`` or ``wav_version`` is not supported,
            if ``groups`` is not positive, or if ``input_dim`` / ``output_dim``
            is not divisible by ``groups``.
    """

    def __init__(self, conv_class, conv_class_plus1, norm_class, input_dim, output_dim, kernel_size,
                 groups=1, padding=0, stride=1, dilation=1, wav_version: str = 'base',
                 ndim: int = 2, dropout=0.0, wavelet_type='mexican_hat', **norm_kwargs):
        super(WavKANConvNDLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.ndim = ndim
        self.norm_kwargs = norm_kwargs

        # Explicit raises: an ``assert`` is stripped under ``python -O`` and
        # would surface as AssertionError rather than ValueError.
        if wavelet_type not in ('mexican_hat', 'morlet', 'dog', 'meyer', 'shannon'):
            raise ValueError(f"Unsupported wavelet type: {wavelet_type}")
        # Fail at construction instead of leaving ``wavelet_conv`` undefined
        # and crashing with AttributeError on the first forward pass.
        if wav_version not in ('base', 'fast', 'fast_plus_one'):
            raise ValueError(f"Unsupported wav_version: {wav_version}")
        self.wavelet_type = wavelet_type

        self.dropout = None
        if dropout > 0:
            if ndim == 1:
                self.dropout = nn.Dropout1d(p=dropout)
            elif ndim == 2:
                self.dropout = nn.Dropout2d(p=dropout)
            elif ndim == 3:
                self.dropout = nn.Dropout3d(p=dropout)
        if groups <= 0:
            raise ValueError('groups must be a positive integer')
        if input_dim % groups != 0:
            raise ValueError('input_dim must be divisible by groups')
        if output_dim % groups != 0:
            raise ValueError('output_dim must be divisible by groups')

        group_in = input_dim // groups
        group_out = output_dim // groups

        # Base branch: one ordinary convolution per group.
        self.base_conv = nn.ModuleList([conv_class(group_in,
                                                   group_out,
                                                   kernel_size,
                                                   stride,
                                                   padding,
                                                   dilation,
                                                   groups=1,
                                                   bias=False) for _ in range(groups)])

        def build_wavelet_branch():
            # Build one wavelet convolution of the requested flavour.
            shared = dict(stride=stride, padding=padding, dilation=dilation,
                          ndim=ndim, wavelet_type=wavelet_type)
            if wav_version == 'base':
                return WaveletConvND(conv_class, group_in, group_out, kernel_size, **shared)
            if wav_version == 'fast':
                return WaveletConvNDFast(conv_class, group_in, group_out, kernel_size, **shared)
            return WaveletConvNDFastPlusOne(conv_class, conv_class_plus1,
                                            group_in, group_out, kernel_size, **shared)

        # Wavelet branch: one wavelet convolution per group.
        self.wavelet_conv = nn.ModuleList([build_wavelet_branch() for _ in range(groups)])

        self.layer_norm = nn.ModuleList([norm_class(group_out, **norm_kwargs) for _ in range(groups)])

        self.base_activation = nn.SiLU()

    def forward_wavkan(self, x, group_ind):
        """Run the base and wavelet branches for one channel group.

        Args:
            x: The chunk of the input that belongs to group ``group_ind``.
            group_ind: Index of the group, selecting its convolutions and norm.

        Returns:
            The normalised sum of the wavelet and base branch outputs.
        """
        # NOTE(review): the activation is applied before BOTH branches, so the
        # wavelet branch sees the activated input as well - confirm intended.
        x = self.base_activation(x)
        base_output = self.base_conv[group_ind](x)

        # Dropout affects only the wavelet branch input.
        if self.dropout is not None:
            x = self.dropout(x)

        wavelet_output = self.wavelet_conv[group_ind](x)

        combined_output = wavelet_output + base_output

        # Normalise with this group's norm layer.
        return self.layer_norm[group_ind](combined_output)

    def forward(self, x):
        """Split ``x`` along the channel axis, process each group, and concatenate."""
        group_inputs = torch.split(x, self.inputdim // self.groups, dim=1)
        # No defensive clones: forward_wavkan performs no in-place operations
        # on its input, and torch.cat copies the results anyway.
        outputs = [self.forward_wavkan(chunk, group_ind)
                   for group_ind, chunk in enumerate(group_inputs)]
        return torch.cat(outputs, dim=1)


class WavKANConv3DLayer(WavKANConvNDLayer):
    """3-D Wav-KAN convolution layer built on ``nn.Conv3d``.

    Passes ``None`` as the (N+1)-D convolution class and defaults to the
    ``'fast'`` wavelet implementation with ``nn.BatchNorm3d`` normalisation.
    Extra keyword arguments are forwarded to ``norm_layer``.
    """

    def __init__(self, input_dim, output_dim, kernel_size, groups=1, padding=0, stride=1, dilation=1,
                 dropout=0.0, wavelet_type='mexican_hat', norm_layer=nn.BatchNorm3d,
                 wav_version: str = 'fast', **norm_kwargs):
        super().__init__(
            nn.Conv3d, None, norm_layer,
            input_dim, output_dim, kernel_size,
            groups=groups,
            padding=padding,
            stride=stride,
            dilation=dilation,
            ndim=3,
            dropout=dropout,
            wavelet_type=wavelet_type,
            wav_version=wav_version,
            **norm_kwargs,
        )


class WavKANConv2DLayer(WavKANConvNDLayer):
    """2-D Wav-KAN convolution layer built on ``nn.Conv2d``.

    Supplies ``nn.Conv3d`` as the (N+1)-D convolution class and defaults to the
    ``'fast_plus_one'`` wavelet implementation with ``nn.BatchNorm2d``
    normalisation.  Extra keyword arguments are forwarded to ``norm_layer``.
    """

    def __init__(self, input_dim, output_dim, kernel_size, groups=1, padding=0, stride=1, dilation=1,
                 dropout=0.0, wavelet_type='mexican_hat', norm_layer=nn.BatchNorm2d,
                 wav_version: str = 'fast_plus_one', **norm_kwargs):
        super().__init__(
            nn.Conv2d, nn.Conv3d, norm_layer,
            input_dim, output_dim, kernel_size,
            groups=groups,
            padding=padding,
            stride=stride,
            dilation=dilation,
            ndim=2,
            dropout=dropout,
            wavelet_type=wavelet_type,
            wav_version=wav_version,
            **norm_kwargs,
        )


class WavKANConv1DLayer(WavKANConvNDLayer):
    """1-D Wav-KAN convolution layer built on ``nn.Conv1d``.

    Supplies ``nn.Conv2d`` as the (N+1)-D convolution class and defaults to the
    ``'fast'`` wavelet implementation with ``nn.BatchNorm1d`` normalisation.
    Extra keyword arguments are forwarded to ``norm_layer``.
    """

    def __init__(self, input_dim, output_dim, kernel_size, groups=1, padding=0, stride=1, dilation=1,
                 dropout=0.0, wavelet_type='mexican_hat', norm_layer=nn.BatchNorm1d,
                 wav_version: str = 'fast', **norm_kwargs):
        super().__init__(
            nn.Conv1d, nn.Conv2d, norm_layer,
            input_dim, output_dim, kernel_size,
            groups=groups,
            padding=padding,
            stride=stride,
            dilation=dilation,
            ndim=1,
            dropout=dropout,
            wavelet_type=wavelet_type,
            wav_version=wav_version,
            **norm_kwargs,
        )

运行结果

python 复制代码
YOLOv8l summary: 658 layers, 46147104 parameters, 0 gradients, 164.1 GFLOPs
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 15/15 [00:01<00:00,  9.66it/s]
                   all        230       1412       0.97      0.969      0.987      0.748
                   c17         40        131      0.998      0.992      0.995      0.823
                    c5         19         68          1      0.984      0.994      0.832
            helicopter         13         43      0.977      0.981      0.987      0.571
                  c130         20         85      0.989      0.988      0.994      0.659
                   f16         11         57          1       0.92      0.968      0.669
                    b2          2          2      0.912          1      0.995      0.823
                 other         13         86       0.99      0.907      0.971      0.538
                   b52         21         70      0.981      0.971      0.987      0.842
                  kc10         12         62      0.996      0.968      0.988      0.858
               command         12         40      0.994          1      0.995      0.824
                   f15         21        123      0.962      0.992      0.994      0.674
                 kc135         24         91      0.975      0.989      0.981      0.701
                   a10          4         27          1      0.556      0.874       0.42
                    b1          5         20          1       0.97      0.995      0.709
                   aew          4         25      0.952          1      0.995      0.789
                   f22          3         17      0.985          1      0.995      0.751
                    p3          6        105          1      0.975      0.995        0.8
                    p8          1          1      0.859          1      0.995      0.895
                   f35          5         32      0.977      0.969      0.993      0.584
                   f18         13        125      0.976      0.992      0.986      0.824
                   v22          5         41      0.981          1      0.995       0.69
                 su-27          5         31      0.986          1      0.995      0.847
                 il-38         10         27      0.959          1      0.995      0.819
                tu-134          1          1      0.872          1      0.995      0.895
                 su-33          1          2      0.958          1      0.995      0.796
                 an-70          1          2       0.91          1      0.995      0.728
                 tu-22          8         98      0.998          1      0.995      0.831
相关推荐
飞Link9 分钟前
LangChain Core 架构深度剖析与 LCEL 高阶实战
人工智能·架构·langchain
liangdabiao12 分钟前
Seedance 2.0 Skill 一键写好剧本上线了coze的技能商店了,免费
人工智能
喵飞云智AI研发社21 分钟前
本土AI企业发力 喵飞科技AIGC开年分享会助力天津数字化转型
人工智能·科技·aigc
于过23 分钟前
AgentMiddleware is All You Need
人工智能·langchain·llm
LLM精进之路24 分钟前
频域+特征融合:深度学习的黄金组合,顶会顶刊的快速通道
人工智能·计算机视觉·目标跟踪
大橙子打游戏29 分钟前
我做了一个 A2A 协议的 Postman —— A2A-Forge 开源了
人工智能
L-影34 分钟前
Agent中的ReAct:类型、作用与避坑指南(下篇)
人工智能·ai·react
itwangyang52034 分钟前
AIDD-人工智能药物发现与设计-利用深度学习从头设计药物,实现逆转疾病相关转录表型
人工智能·深度学习
陈天伟教授35 分钟前
人工智能应用- 预测新冠病毒传染性:07. 预测不同类型病毒的传播能力
人工智能·神经网络·安全·cnn·xss
初学者↑1 小时前
知识图谱可视化系统源码分享
人工智能·知识图谱