时间序列(Time-Series)FourierCorrelation.py代码解析

coding=utf-8

author=maziqing

email=maziqing.mzq@alibaba-inc.com

#这行导入了NumPy库,通常用于科学计算中的数组操作。

import numpy as np

#这行导入了PyTorch库,是一个常用于深度学习的库。

import torch

#这行导入了PyTorch中的nn模块,它包含了构建神经网络所需的类和方法

import torch.nn as nn

#这行定义了一个名为get_frequency_modes的函数,它接受序列长度seq_len,模式数modes,和模式选择方法mode_select_method作为参数。

def get_frequency_modes(seq_len, modes=64, mode_select_method='random'):

"""

get modes on frequency domain:

'random' means sampling randomly;

'else' means sampling the lowest modes;

"""

#这行限制modes的数值不超过seq_len的一半。

modes = min(modes, seq_len // 2)

#这个条件语句根据mode_select_method参数选择不同的模式。如果选择了'random',它会随机选择模式;否则,它会选择最低的模式。

if mode_select_method == 'random':

index = list(range(0, seq_len // 2))

np.random.shuffle(index)

index = index[:modes]

else:

index = list(range(0, modes))

#这两行首先对索引进行排序,然后返回索引列表。

index.sort()

return index

########## fourier layer

#定义了一个基于PyTorch的FourierBlock类,用于傅里叶变换的深度学习模型的一部分。

class FourierBlock(nn.Module):

def init(self, in_channels, out_channels, seq_len, modes=0, mode_select_method='random'):

#在类的构造函数中,调用了父类nn.Module的构造函数,并且初始化了FourierBlock。

super(FourierBlock, self).init()

print('fourier enhanced block used!')

"""

1D Fourier block. It performs representation learning on frequency domain,

it does FFT, linear transform, and Inverse FFT.

"""

get modes on frequency domain

#在FourierBlock的实例中,调用get_frequency_modes函数来获取频率模式的索引。

self.index = get_frequency_modes(seq_len, modes=modes, mode_select_method=mode_select_method)

print('modes={}, index={}'.format(modes, self.index))

#计算缩放因子来初始化网络权重。

self.scale = (1 / (in_channels * out_channels))

#这两行代码初始化了傅里叶层的两组参数weights1和weights2,它们用于在频率域进行线性变换。

self.weights1 = nn.Parameter(

self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float))

self.weights2 = nn.Parameter(

self.scale * torch.rand(8, in_channels // 8, out_channels // 8, len(self.index), dtype=torch.float))

Complex multiplication

#定义了一个复数乘法的函数,它接受一个特定的排列顺序order,输入x和权重weights

def compl_mul1d(self, order, x, weights):

#设置两个标志,用于跟踪输入x和权重weights是否为复数

x_flag = True

w_flag = True

##如果输入x不是复数,那么创建一个复数版本,其虚部为零

if not torch.is_complex(x):

x_flag = False

x = torch.complex(x, torch.zeros_like(x).to(x.device))

#如果权重weights不是复数,那么创建一个复数版本,其虚部为零。

if not torch.is_complex(weights):

w_flag = False

weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))

#这是复数乘法的实际执行,使用了torch.einsum来进行张量乘法和加法。

if x_flag or w_flag:

return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),

torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))

#如果输入和权重都不是复数,则仅进行实数部分的乘法。

else:

return torch.einsum(order, x.real, weights.real)

#定义了forward函数,这是模型在前向传播时调用的函数。它接受查询q、键k、值v和一个掩码mask作为输入。

def forward(self, q, k, v, mask):

size = [B, L, H, E]

#获取输入查询q的形状,包括批次大小B、序列长度L、头的数量H和嵌入维度E。

B, L, H, E = q.shape

#对查询张量进行排列,改变其维度顺序。

x = q.permute(0, 2, 3, 1)

Compute Fourier coefficients

#使用torch.fft.rfft函数对输入进行实数快速傅里叶变换(RFFT)。

x_ft = torch.fft.rfft(x, dim=-1)

Perform Fourier neural operations

#初始化一个用于存储傅里叶变换结果的零张量,其形状适配了RFFT的输出。

out_ft = torch.zeros(B, H, E, L // 2 + 1, device=x.device, dtype=torch.cfloat)

#遍历频率模式的索引。如果索引超出了傅里叶变换结果的范围,则跳过当前迭代。

for wi, i in enumerate(self.index):

if i >= x_ft.shape[3] or wi >= out_ft.shape[3]:

continue

#对每个频率模式执行复数乘法,并将结果存储在out_ft张量中。

out_ft[:, :, :, wi] = self.compl_mul1d("bhi,hio->bho", x_ft[:, :, :, i],

torch.complex(self.weights1, self.weights2)[:, :, :, wi])

Return to time domain

#使用torch.fft.irfft函数对傅里叶变换的结果进行逆变换,从频域回到时域。

x = torch.fft.irfft(out_ft, n=x.size(-1))

#返回处理后的结果x和一个占位符None,因为通常注意力机制会返回一个额外的输出,例如注意力权重,但在这里并未使用。

return (x, None)

########## Fourier Cross Former

#定义了一个新的神经网络模块类

class FourierCrossAttention(nn.Module):

#这是类的构造函数,用于初始化参数。in_channels和out_channels分别表示输入和输出的通道数。seq_len_q和seq_len_kv分别是查询(Query)和键/值(Key/Value)的序列长度。modes是频率域关注的模式数,mode_select_method是选择这些模式的方法,默认为'random'。activation是用于注意力权重的激活函数,默认为'tanh'。policy和num_heads是模型的超参数,其中num_heads表示多头注意力的头数。

def init(self, in_channels, out_channels, seq_len_q, seq_len_kv, modes=64, mode_select_method='random',

activation='tanh', policy=0, num_heads=8):

#调用父类nn.Module的构造函数进行初始化。

super(FourierCrossAttention, self).init()

print(' fourier enhanced cross attention used!')

"""

1D Fourier Cross Attention layer. It does FFT, linear transform, attention mechanism and Inverse FFT.

"""

#保存输入输出通道数和激活函数,以便后续使用。

self.activation = activation

self.in_channels = in_channels

self.out_channels = out_channels

get modes for queries and keys (& values) on frequency domain

#通过调用get_frequency_modes函数来确定在频率域中关注哪些频率分量,对于查询和键/值分别存储索引。

self.index_q = get_frequency_modes(seq_len_q, modes=modes, mode_select_method=mode_select_method)

self.index_kv = get_frequency_modes(seq_len_kv, modes=modes, mode_select_method=mode_select_method)

print('modes_q={}, index_q={}'.format(len(self.index_q), self.index_q))

print('modes_kv={}, index_kv={}'.format(len(self.index_kv), self.index_kv))

#计算缩放因子,用于权重初始化

self.scale = (1 / (in_channels * out_channels))

self.weights1 = nn.Parameter(

self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))

self.weights2 = nn.Parameter(

self.scale * torch.rand(num_heads, in_channels // num_heads, out_channels // num_heads, len(self.index_q), dtype=torch.float))

Complex multiplication

#初始化两个可训练参数self.weights1和self.weights2,它们将在傅立叶交叉注意力机制中使用。

#定义了一个复杂数乘法函数compl_mul1d,用于处理复数张量的元素级乘法。

def compl_mul1d(self, order, x, weights):

x_flag = True

w_flag = True

if not torch.is_complex(x):

x_flag = False

x = torch.complex(x, torch.zeros_like(x).to(x.device))

if not torch.is_complex(weights):

w_flag = False

weights = torch.complex(weights, torch.zeros_like(weights).to(weights.device))

if x_flag or w_flag:

return torch.complex(torch.einsum(order, x.real, weights.real) - torch.einsum(order, x.imag, weights.imag),

torch.einsum(order, x.real, weights.imag) + torch.einsum(order, x.imag, weights.real))

else:

return torch.einsum(order, x.real, weights.real)

#定义了前向传播函数forward,这是数据流经网络模块时的主要入口点。

def forward(self, q, k, v, mask):

size = [B, L, H, E]

B, L, H, E = q.shape

xq = q.permute(0, 2, 3, 1) # size = [B, H, E, L]

xk = k.permute(0, 2, 3, 1)

xv = v.permute(0, 2, 3, 1)

Compute Fourier coefficients

xq_ft_ = torch.zeros(B, H, E, len(self.index_q), device=xq.device, dtype=torch.cfloat)

xq_ft = torch.fft.rfft(xq, dim=-1)

for i, j in enumerate(self.index_q):

if j >= xq_ft.shape[3]:

continue

xq_ft_[:, :, :, i] = xq_ft[:, :, :, j]

xk_ft_ = torch.zeros(B, H, E, len(self.index_kv), device=xq.device, dtype=torch.cfloat)

xk_ft = torch.fft.rfft(xk, dim=-1)

for i, j in enumerate(self.index_kv):

if j >= xk_ft.shape[3]:

continue

xk_ft_[:, :, :, i] = xk_ft[:, :, :, j]

perform attention mechanism on frequency domain

xqk_ft = (self.compl_mul1d("bhex,bhey->bhxy", xq_ft_, xk_ft_))

if self.activation == 'tanh':

xqk_ft = torch.complex(xqk_ft.real.tanh(), xqk_ft.imag.tanh())

elif self.activation == 'softmax':

xqk_ft = torch.softmax(abs(xqk_ft), dim=-1)

xqk_ft = torch.complex(xqk_ft, torch.zeros_like(xqk_ft))

else:

raise Exception('{} actiation function is not implemented'.format(self.activation))

xqkv_ft = self.compl_mul1d("bhxy,bhey->bhex", xqk_ft, xk_ft_)

xqkvw = self.compl_mul1d("bhex,heox->bhox", xqkv_ft, torch.complex(self.weights1, self.weights2))

out_ft = torch.zeros(B, H, E, L // 2 + 1, device=xq.device, dtype=torch.cfloat)

for i, j in enumerate(self.index_q):

if i >= xqkvw.shape[3] or j >= out_ft.shape[3]:

continue

out_ft[:, :, :, j] = xqkvw[:, :, :, i]

Return to time domain

out = torch.fft.irfft(out_ft / self.in_channels / self.out_channels, n=xq.size(-1))

return (out, None)

相关推荐
XinZong6 分钟前
【OpenAI】获取OpenAI API Key的多种方式全攻略:从入门到精通,再到详解教程!
人工智能
没有余地 EliasJie8 分钟前
深度学习图像视觉 RKNN Toolkit2 部署 RK3588S边缘端 过程全记录
人工智能·嵌入式硬件·深度学习
亚图跨际15 分钟前
Python和R基因组及蛋白质组学和代谢组学
python·r语言·生物医学
fanyamin22 分钟前
编程语言的局限
开发语言·python
梦醒沉醉24 分钟前
神经网络的正则化(二)
深度学习·神经网络
努力更新中1 小时前
Python浪漫之随机绘制不同颜色的气球
开发语言·python
HelpLook HelpLook1 小时前
高新技术行业中的知识管理:关键性、挑战、策略及工具应用
人工智能·科技·aigc·客服·知识库搭建
__lost1 小时前
Python 将彩色视频转换为黑白视频(MP4-格式可选)
python·opencv·音视频
青松@FasterAI1 小时前
【RAG 项目实战 05】重构:封装代码
人工智能·深度学习·自然语言处理·nlp
chnyi6_ya1 小时前
论文笔记:Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
论文阅读·人工智能·自然语言处理