时间序列(Time-Series)FourierCorrelation.py代码解析

coding=utf-8

author=maziqing

email=maziqing.mzq@alibaba-inc.com

#这行导入了NumPy库,通常用于科学计算中的数组操作。

import numpy as np

#这行导入了PyTorch库,是一个常用于深度学习的库。

import torch

#这行导入了PyTorch中的nn模块,它包含了构建神经网络所需的类和方法

import torch.nn as nn

#这行定义了一个名为get_frequency_modes的函数,它接受序列长度seq_len,模式数modes,和模式选择方法mode_select_method作为参数。

def get_frequency_modes(seq_len, modes=64, mode_select_method='random'):

"""

get modes on frequency domain:

'random' means sampling randomly;

'else' means sampling the lowest modes;

"""

#这行限制modes的数值不超过seq_len的一半。

modes = min(modes, seq_len // 2)

#这个条件语句根据mode_select_method参数选择不同的模式。如果选择了'random',它会随机选择模式;否则,它会选择最低的模式。

if mode_select_method == 'random':

index = list(range(0, seq_len // 2))

np.random.shuffle(index)

index = index[:modes]

else:

index = list(range(0, modes))

#这两行首先对索引进行排序,然后返回索引列表。

index.sort()

return index

########## fourier layer

#定义了一个基于PyTorch的FourierBlock类,用于傅里叶变换的深度学习模型的一部分。

class FourierBlock(nn.Module):

def init(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'):

#在类的构造函数中,调用了父类nn.Module的构造函数,并且初始化了FourierBlock。

super(FourierBlock, self).init()

print('fourier enhanced block used!')

"""

1D Fourier block. It performs representation learning on frequency domain,

it does FFT, linear transform, and Inverse FFT.

"""

get modes on frequency domain

#在FourierBlock的实例中,调用get_frequency_modes函数来获取频率模式的索引。

self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)

print('modes={}, index={}'.format(modes, self.index))

#计算缩放因子来初始化网络权重。

self.scale = (1 / (in_channels * out_channels))

#这两行代码初始化了傅里叶层的两组参数weights1和weights2,它们用于在频率域进行线性变换。

self.weights1 = nn.Parameter(

self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float))

self.weights2 = nn.Parameter(

self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float))

Complex multiplication

#定义了一个复数乘法的函数,它接受一个特定的排列顺序order,输入x和权重weights

def compl_mul1d(self, order, x, weights):

#设置两个标志,用于跟踪输入x和权重weights是否为复数

x_flag = True

w_flag = True

##如果输入x不是复数,那么创建一个复数版本,其虚部为零

if not torch.is_complex(x):

x_flag = False

x = torch.complex(x, torch.zeros_like(x).to(x.device))

#如果权重weights不是复数,那么创建一个复数版本,其虚部为零。

if not torch.is_complex(weights):

w_flag = False

weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))

#这是复数乘法的实际执行,使用了torch.einsum来进行张量乘法和加法。

if x_flag or w_flag:

return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),

torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))

#如果输入和权重都不是复数,则仅进行实数部分的乘法。

else:

return torch.einsum(order, x.real, weights.real)

#定义了forward函数,这是模型在前向传播时调用的函数。它接受查询q、键k、值v和一个掩码mask作为输入。

def forward(self, q, k, v, mask):

size = [B, L, H, E]

#获取输入查询q的形状,包括批次大小B、序列长度L、头的数量H和嵌入维度E。

B, L, H, E = q.shape

#对查询张量进行排列,改变其维度顺序。

x = q.permute(0, 2, 3, 1)

Compute Fourier coefficients

#使用torch.fft.rfft函数对输入进行实数快速傅里叶变换(RFFT)。

x_ft = torch.fft.rfft(x, dim=-1)

Perform Fourier neural operations

#初始化一个用于存储傅里叶变换结果的零张量,其形状适配了RFFT的输出。

out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)

#遍历频率模式的索引。如果索引超出了傅里叶变换结果的范围,则跳过当前迭代。

for wi, i in enumerate(self.index):

if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:

continue

#对每个频率模式执行复数乘法,并将结果存储在out_ft张量中。

out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i],

torch.complex(self.weights1, self.weights2)[:, :, :, wi])

Return to time domain

#使用torch.fft.irfft函数对傅里叶变换的结果进行逆变换,从频域回到时域。

x = torch.fft.irfft(out_ft, n=x.size(-1))

#返回处理后的结果x和一个占位符None,因为通常注意力机制会返回一个额外的输出,例如注意力权重,但在这里并未使用。

return (x, None)

########## Fourier Cross Former

#定义了一个新的神经网络模块类

class FourierCrossAttention(nn.Module):

#这是类的构造函数,用于初始化参数。in_channels和out_channels分别表示输入和输出的通道数。seq_len_q和seq_len_kv分别是查询(Query)和键/值(Key/Value)的序列长度。modes是频率域关注的模式数,mode_select_method是选择这些模式的方法,默认为'random'。activation是用于注意力权重的激活函数,默认为'tanh'。policy和num_heads是模型的超参数,其中num_heads表示多头注意力的头数。

def init(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',

activation='tanh', policy=0, num_heads=8):

#调用父类nn.Module的构造函数进行初始化。

super(FourierCrossAttention, self).init()

print(' fourier enhanced cross attention used!')

"""

1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.

"""

#保存输入输出通道数和激活函数,以便后续使用。

self.activation = activation

self.in_channels = in_channels

self.out_channels = out_channels

get modes for queries and keys (& values) on frequency domain

#通过调用get_frequency_modes函数来确定在频率域中关注哪些频率分量,对于查询和键/值分别存储索引。

self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)

self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)

print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q))

print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv))

#计算缩放因子,用于权重初始化

self.scale = (1 / (in_channels * out_channels))

self.weights1 = nn.Parameter(

self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))

self.weights2 = nn.Parameter(

self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))

Complex multiplication

#初始化两个可训练参数self.weights1和self.weights2,它们将在傅立叶交叉注意力机制中使用。

#定义了一个复杂数乘法函数compl_mul1d,用于处理复数张量的元素级乘法。

def compl_mul1d(self, order, x, weights):

x_flag = True

w_flag = True

if not torch.is_complex(x):

x_flag = False

x = torch.complex(x, torch.zeros_like(x).to(x.device))

if not torch.is_complex(weights):

w_flag = False

weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))

if x_flag or w_flag:

return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),

torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))

else:

return torch.einsum(order, x.real, weights.real)

#定义了前向传播函数forward,这是数据流经网络模块时的主要入口点。

def forward(self, q, k, v, mask):

size = [B, L, H, E]

B, L, H, E = q.shape

xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L]

xk = k.permute(0, 2, 3, 1)

xv = v.permute(0, 2, 3, 1)

Compute Fourier coefficients

xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)

xq_ft = torch.fft.rfft(xq, dim=-1)

for i, j in enumerate(self.index_q):

if j >= xq_ft.shape[3]:

continue

xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]

xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)

xk_ft = torch.fft.rfft(xk, dim=-1)

for i, j in enumerate(self.index_kv):

if j >= xk_ft.shape[3]:

continue

xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

perform attention mechanism on frequency domain

xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))

if self.activation == 'tanh':

xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())

elif self.activation == 'softmax':

xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)

xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))

else:

raise Exception('{} actiation function is not implemented'.format(self.activation))

xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2))

out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)

for i, j in enumerate(self.index_q):

if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:

continue

out_ft[:, :, :, j] = xqkvw[:, :, :, i]

Return to time domain

out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))

return (out, None)

相关推荐
天天向上杰1 分钟前
通义灵码AI程序员
人工智能·aigc·ai编程
sendnews12 分钟前
AI赋能教育,小猿搜题系列产品携手DeepSeek打造个性化学习新体验
人工智能
查理零世14 分钟前
【蓝桥杯集训·每日一题2025】 AcWing 6134. 哞叫时间II python
python·算法·蓝桥杯
悠然的笔记本15 分钟前
机器学习,我们主要学习什么?
机器学习
紫雾凌寒24 分钟前
解锁机器学习核心算法|神经网络:AI 领域的 “超级引擎”
人工智能·python·神经网络·算法·机器学习·卷积神经网络
WBingJ34 分钟前
2月17日深度学习日记
人工智能
zhengyawen66635 分钟前
深度学习之图像分类(一)
人工智能·深度学习·分类
sun lover36 分钟前
conda简单命令
python·conda
莫莫莫i39 分钟前
拆解微软CEO纳德拉战略蓝图:AI、量子计算、游戏革命如何改写未来规则!
人工智能·微软·量子计算
C#Thread42 分钟前
机器视觉--图像的运算(加法)
图像处理·人工智能·计算机视觉