Code
python
import torch
import numpy as np


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()
        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
        self.head_size = hidden_dim // heads
        self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads
        self.win = win
        self.hidden = hidden_dim
        # Upper-triangular mask used for the causal (prefix) max inside each window.
        self.mask = torch.triu(torch.ones([win, win])).to("cuda")
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

    def forward(self, input_data, state=None):
        # self.head.to("cuda")
        b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win
        window = torch.ones([1, w]).to("cuda")
        out = self.head(input_data)
        # Broadcast each feature value across the window: [b, s, hidden, 1] @ [1, w] -> [b, s, hidden, w].
        out = out.unsqueeze(-1) @ window
        out = out.permute([0, 2, 1, 3])
        one_list = []
        if state is None:
            # -inf is the identity element for max, so the initial state does not affect the first window.
            state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
            state = state.to("cuda")
        for i in range(0, s, w):
            state.reshape([state.shape[0], -1])
            j = w + i
            one = out[:, :, i:j]
            _, _, r, c = one.shape
            # Mask out future positions inside the window (the last window may be shorter than win).
            if r != self.win:
                one = torch.where(self.mask[:r, :] == 1, one, torch.Tensor([-float('inf')]).to("cuda"))
            else:
                one = torch.where(self.mask == 1, one, torch.Tensor([-float('inf')]).to("cuda"))
            if i == 0:
                one = torch.concat([one, state @ window], axis=2)
                state, _ = torch.max(one, axis=2, keepdim=True)
            else:
                state1, _ = torch.max(one, axis=2, keepdim=True)
                # state = torch.sin(self.state(state1.reshape([state1.shape[0], -1]))*state.reshape([state.shape[0], -1]))
                state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([state1.shape[0], -1, state1.shape[1]]))
                state = state1.permute([0, 2, 1]).unsqueeze(-2) + state
                # state = state.reshape(state1.shape)
                one = torch.concat([one, state], axis=2)
                state, _ = torch.max(one, axis=2, keepdim=True)
            one = state.reshape([b, k, h, w])
            # Carry only the last column (the max over the whole window) into the next iteration.
            state = state[..., -1:]
            if r != self.win:
                one = one[..., :r]
            one = one.permute([0, 3, 1, 2])
            one_list.append(one)
        out = torch.concat(one_list, 1)
        out = out.reshape([b, s, -1])
        return out, state
class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()
        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # Gated feed-forward: up-projection modulated by a ReLU gate, then projected back down.
        x1 = self.ffn1(x)
        x2 = self.relu(self.gate(x))
        x = x1 * x2
        x = self.ffn2(x)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        # self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.ffn(x1) + x)  # Feed-Forward with residual connection
        return x, state


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)
        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size)
        self.head_state = torch.nn.Linear(hidden_size, num_layers)

    def forward(self, x, state=None, seq_len=None):
        x = self.em(x)
        # Positional indices; sequences of 1024 positions or more use a two-level (quotient / remainder) encoding.
        if x.shape[1] >= 1024:
            pos = self.pos(torch.arange(0, x.shape[1]).to("cuda") // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).to("cuda") % 1024).unsqueeze(0) + pos
        else:
            pos = self.pos(torch.arange(0, x.shape[1]).to("cuda")).unsqueeze(0)
        if state is None:
            state = [None] * len(self.decoder_layers)
        i = 0
        for decoder_layer in self.decoder_layers:
            x1, state[i] = decoder_layer(pos + x, state[i])
            x = x1 + x
            i += 1
        state_data = self.head_state((torch.concat(state, -1).squeeze(-2)).permute([0, 2, 1]))
        return self.head(x), state, state_data


if __name__ == '__main__':
    # The modules create CUDA tensors internally, so both the model and the input must live on the GPU.
    net = SamOut(235, 256, 16, 4).to("cuda")
    net(torch.randint(0, 200, [2, 3000]).to("cuda"))
Walkthrough
python
import torch
import numpy as np
These two lines import PyTorch and NumPy, used for deep learning and numerical computation respectively (NumPy is imported but never actually used in this snippet).
python
class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()
This defines a PyTorch module named MaxState. It inherits from torch.nn.Module, the base class for all custom models.
python
assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
The assert statement ensures that hidden_dim is evenly divisible by heads, the usual requirement for splitting a hidden vector across multiple heads.
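As a quick sanity check, here is what that split looks like with the sizes used in the __main__ block at the end of the post (hidden size 256, 16 heads):
python
# Head split with the sizes from the demo at the end of the post: 256 hidden units, 16 heads.
hidden_dim, heads = 256, 16
assert hidden_dim % heads == 0
head_size = hidden_dim // heads
print(head_size)  # 16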
python
self.head_size = hidden_dim // heads
self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head_num = heads
self.win = win
self.hidden = hidden_dim
self.mask = torch.triu(torch.ones([win, win])).to("cuda")
self.layer_nor = torch.nn.LayerNorm(hidden_dim)
These lines initialize the module's attributes: two bias-free linear layers, the number of heads, the window size, the hidden size, an upper-triangular mask, and a layer normalization (layer_nor is defined here but not used in forward).
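The mask is simply the upper triangle (including the diagonal) of a matrix of ones. A small CPU sketch with an illustrative window of 4 (the actual code uses win=8, set in DecoderLayer):
python
import torch

# Illustrative window of 4; DecoderLayer constructs MaxState with win=8.
mask = torch.triu(torch.ones([4, 4]))
print(mask)
# tensor([[1., 1., 1., 1.],
#         [0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.]])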
python
def forward(self, input_data, state=None):
The forward method defines the module's forward pass; state is an optional running state carried across calls.
python
b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win
Unpacks the batch size and sequence length from the input tensor, and reads the number of heads, head size, and window size from the module's attributes.
python
window = torch.ones([1, w]).to("cuda")
Creates a [1, w] row vector of ones on the GPU; it is used to broadcast feature values across the window width.
python
out = self.head(input_data)
Applies the linear projection self.head to the input.
python
out = out.unsqueeze(-1) @ window
Matrix-multiplying the unsqueezed output ([b, s, hidden, 1]) by the window of ones ([1, w]) copies every feature value w times, producing a [b, s, hidden, w] tensor.
python
out = out.permute([0, 2, 1, 3])
permute reorders the dimensions to [b, hidden, s, w] so that the sequence axis can be sliced into windows; a shape sketch follows.
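A small CPU shape check of these two steps, with illustrative sizes (batch 2, sequence 3, hidden 4, window 5):
python
import torch

# Illustrative sizes: batch=2, seq=3, hidden=4, win=5.
out = torch.randn(2, 3, 4)                 # [b, s, hidden]
window = torch.ones(1, 5)                  # [1, w]
out = out.unsqueeze(-1) @ window           # [b, s, hidden, 1] @ [1, w] -> [b, s, hidden, w]
out = out.permute([0, 2, 1, 3])            # -> [b, hidden, s, w]
print(out.shape)                           # torch.Size([2, 4, 3, 5])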
python
if state is None:
    state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
    state = state.to("cuda")
If no state is passed in, it is initialized to a [b, hidden, 1, 1] tensor filled with -inf and moved to the GPU. Since -inf is the identity element for max, this initial state has no effect on the first window.
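A tiny illustration of why -inf works as an initial state (made-up numbers):
python
import torch

# -inf never wins a max, so it leaves the first window's result untouched.
state = torch.full([1, 1], float("-inf"))
window_vals = torch.tensor([[0.3, -1.2, 2.5]])
running, _ = torch.max(torch.concat([window_vals, state], dim=1), dim=1, keepdim=True)
print(running)  # tensor([[2.5000]])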
python
for i in range(0, s, w):
    # ... (loop body omitted)
The loop walks over the sequence in steps of the window size w. Inside it, the upper-triangular mask and torch.max compute a causal (prefix) maximum for each position within its window, and a running state is carried from one window into the next; a simplified sketch follows.
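A simplified CPU sketch of the idea, with illustrative sizes and a single feature channel. torch.cummax stands in for the mask-plus-max trick, and the sketch leaves out the learned self.state transform that the real forward mixes into the carried state:
python
import torch

s, w = 10, 4                                            # illustrative sequence and window sizes
x = torch.randn(1, s)
state = torch.full([1, 1], float("-inf"))               # running max carried across windows
chunks = []
for i in range(0, s, w):
    one = torch.concat([state, x[:, i:i + w]], dim=1)   # prepend the carried state
    prefix_max, _ = torch.cummax(one, dim=1)            # causal max inside the window
    chunks.append(prefix_max[:, 1:])                    # per-position outputs for this window
    state = prefix_max[:, -1:]                          # carry the window's max forward
out = torch.concat(chunks, dim=1)
print(out.shape)  # torch.Size([1, 10])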
python
return out, state
Returns the per-position outputs, reshaped back to [b, s, hidden], together with the final carried state.
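A shape check for MaxState on its own; it hard-codes "cuda" internally, so a GPU is required, and all sizes here are illustrative:
python
import torch

# Requires CUDA: MaxState creates "cuda" tensors internally.
m = MaxState(hidden_dim=64, heads=4, win=8).to("cuda")
x = torch.randn(2, 20, 64).to("cuda")
out, state = m(x)
print(out.shape, state.shape)  # torch.Size([2, 20, 64]) torch.Size([2, 64, 1, 1])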
Next come the FeedForward, DecoderLayer, and SamOut classes. They follow the same pattern as MaxState: each is a custom torch.nn.Module. FeedForward is a gated feed-forward block, DecoderLayer pairs MaxState with that block plus a residual connection and layer normalization, and SamOut stacks the decoder layers on top of token and positional embeddings and adds the output heads.
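For reference, the gated feed-forward computation boils down to the following standalone sketch (hidden size 8 is an illustrative value):
python
import torch

# Gated FFN: up-projection modulated by a ReLU gate, then projected back down.
hidden_size = 8
ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
gate = torch.nn.Linear(hidden_size, hidden_size * 2)
ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
x = torch.randn(2, 5, hidden_size)
y = ffn2(ffn1(x) * torch.relu(gate(x)))
print(y.shape)  # torch.Size([2, 5, 8])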
python
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4).to("cuda")
    net(torch.randint(0, 200, [2, 3000]).to("cuda"))
Finally, this block instantiates SamOut (vocabulary size 235, hidden size 256, 16 heads, 4 layers) and runs a single forward pass on randomly generated token ids.
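As a hedged usage sketch, the forward pass returns the logits, the per-layer states, and the state projection; with an illustrative batch of 2 and sequence length of 300, the expected shapes are:
python
import torch

# Requires CUDA; batch and sequence sizes are illustrative.
net = SamOut(voc_size=235, hidden_size=256, num_heads=16, num_layers=4).to("cuda")
tokens = torch.randint(0, 235, [2, 300]).to("cuda")
logits, state, state_data = net(tokens)
print(logits.shape)                # torch.Size([2, 300, 235])
print(len(state), state[0].shape)  # 4 torch.Size([2, 256, 1, 1])
print(state_data.shape)            # torch.Size([2, 4, 4])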