SamOut 最新版本:通过 state 逐层控制加速收敛

代码

python 复制代码
import torch
import numpy as np


class MaxState(torch.nn.Module):
    """Windowed causal max-pooling token mixer with a carried state.

    The sequence is processed in windows of ``win`` tokens. Inside a window,
    position ``t`` receives the running (causal) maximum of the projected inputs
    at positions ``0..t`` of that window. From the second window on, the
    window's running maxima are mixed by ``self.state`` and added to the state
    carried from the previous window, and that sum also takes part in the max.
    The result at the last column of each window is carried forward.

    Args:
        hidden_dim: Model width; must be divisible by ``heads``.
        heads: Number of heads (only used to reshape the output per head).
        win: Window length.
    """

    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads
        self.win = win
        self.hidden = hidden_dim
        # Upper-triangular [win, win] mask: mask[t, c] == 1 iff c >= t, so column c
        # of a window only "sees" time steps 0..c (causal running max).
        # Registered as a non-persistent buffer so it follows .to()/.cuda()/.half()
        # with the module and is NOT added to state_dict (old checkpoints still load).
        self.register_buffer("mask", torch.triu(torch.ones([win, win])), persistent=False)
        # NOTE(review): not used in forward; kept so the state_dict keys are unchanged.
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

    def forward(self, input_data, state=None):
        """Mix tokens with a windowed causal max.

        Args:
            input_data: Tensor of shape [batch, seq, hidden_dim].
            state: Optional carried state of shape [batch, hidden_dim, 1, 1], as
                returned by a previous call. ``None`` starts from ``-inf``.

        Returns:
            ``(out, state)``: ``out`` has shape [batch, seq, hidden_dim]; ``state``
            has shape [batch, hidden_dim, 1, 1].
        """
        b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win

        # [b, s, hidden] -> [b, s, hidden, w]: replicate every projected value into
        # w columns. This is equivalent to the former matmul with a ones[1, w] row,
        # but without the hard-coded "cuda" device and float32 dtype.
        out = self.head(input_data).unsqueeze(-1).expand(-1, -1, -1, w)
        # -> [b, hidden, s, w]: time becomes dim 2, the axis reduced by max below.
        out = out.permute([0, 2, 1, 3])

        if state is None:
            # -inf is the neutral element of max, so the first window is unaffected.
            state = torch.full([b, out.shape[1], 1, 1], float("-inf"), device=out.device, dtype=out.dtype)

        one_list = []
        for i in range(0, s, w):
            one = out[:, :, i:i + w]
            r = one.shape[2]  # == w, except possibly for the last (shorter) window
            # Causal mask: time step t may contribute to column c only if t <= c.
            one = one.masked_fill(self.mask[:r] != 1, float("-inf"))

            if i == 0:
                # The incoming state enters every column unchanged.
                carry = state.expand(-1, -1, -1, w)
            else:
                # Per-column running max of this window, mixed across the hidden
                # dimension by self.state, then offset by the previous carry.
                running = one.max(dim=2, keepdim=True).values           # [b, hidden, 1, w]
                mixed = self.state(running.squeeze(2).transpose(1, 2))  # [b, w, hidden]
                carry = mixed.transpose(1, 2).unsqueeze(2) + state      # [b, hidden, 1, w]

            # Column c = max(running max of steps 0..c, carry[c]).
            state = torch.cat([one, carry], dim=2).max(dim=2, keepdim=True).values

            # Split hidden into heads, drop the padding columns of a short last
            # window, and move time back to dim 1: [b, r, k, h].
            one_list.append(state.reshape([b, k, h, w])[..., :r].permute([0, 3, 1, 2]))

            # The last column has seen every row of the window: carry it forward.
            state = state[..., -1:]

        out = torch.cat(one_list, 1)
        return out.reshape([b, s, -1]), state


class FeedForward(torch.nn.Module):
    """Gated feed-forward block computing ``ffn2(ffn1(x) * relu(gate(x)))``.

    The inner width is ``2 * hidden_size``; the output width equals the input width.

    Args:
        hidden_size: Width of the input and output features.
    """

    def __init__(self, hidden_size):
        super().__init__()

        inner_size = hidden_size * 2
        # Creation order is kept (ffn1, ffn2, gate) so seeded initialisation is unchanged.
        self.ffn1 = torch.nn.Linear(hidden_size, inner_size)
        self.ffn2 = torch.nn.Linear(inner_size, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, inner_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply the gated projection to the last dimension of ``x``."""
        gated = self.ffn1(x) * self.relu(self.gate(x))
        return self.ffn2(gated)


class DecoderLayer(torch.nn.Module):
    """One SamOut block: MaxState token mixing followed by a gated feed-forward.

    Args:
        hidden_size: Model width.
        num_heads: Number of heads passed to MaxState (must divide ``hidden_size``).
        win: MaxState window length. Defaults to 8, the value that was previously
            hard-coded, so existing callers are unaffected.
    """

    def __init__(self, hidden_size, num_heads, win=8):
        super(DecoderLayer, self).__init__()
        self.self_attention = MaxState(hidden_size, num_heads, win)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        """Run the block.

        Args:
            x: Input features, shape [batch, seq, hidden_size].
            state: Optional carried MaxState state (``None`` starts fresh).
            seq_len: Unused; accepted for call-site compatibility.

        Returns:
            ``(x, state)``: the block output and the updated MaxState state.
        """
        x1, state = self.self_attention(x, state)

        # The residual comes from the block input x (not from x1), followed by LayerNorm.
        x = self.layer_norm(self.ffn(x1) + x)

        return x, state


class SamOut(torch.nn.Module):
    """SamOut language model: embeddings, a stack of DecoderLayers and an LM head.

    Args:
        voc_size: Vocabulary size (token id 3 is the embedding's padding index).
        hidden_size: Model width.
        num_heads: Heads per DecoderLayer.
        num_layers: Number of DecoderLayers.
    """

    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        # 1024-entry position table; longer sequences combine two lookups (see forward).
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size)
        # Projects each layer's carried state to num_layers outputs (returned as state_data).
        self.head_state = torch.nn.Linear(hidden_size, num_layers)

    def forward(self, x, state=None, seq_len=None):
        """Run the model.

        Args:
            x: Integer token ids, shape [batch, seq].
            state: Optional list holding one carried state per layer, as returned
                by a previous call. ``None`` starts fresh. A passed-in list is
                updated in place.
            seq_len: Unused; accepted for call-site compatibility.

        Returns:
            ``(logits, state, state_data)``: logits of shape [batch, seq, voc_size],
            the per-layer state list, and ``head_state`` applied to the stacked
            layer states.
        """
        x = self.em(x)
        seq = x.shape[1]
        # Build the position indices on the embeddings' device. Previously the long
        # branch stayed on CPU and the short branch was forced to "cuda".
        positions = torch.arange(seq, device=x.device)
        if seq >= 1024:
            # Two-level encoding: pos(t // 1024) + pos(t % 1024).
            # NOTE(review): at most 1024 * 1024 positions are supported. Positions
            # t < 1024 also get pos(0) added here, so they are encoded differently
            # than in sequences shorter than 1024 — confirm this is intended.
            pos = self.pos(positions // 1024).unsqueeze(0)
            pos = self.pos(positions % 1024).unsqueeze(0) + pos
        else:
            pos = self.pos(positions).unsqueeze(0)

        if state is None:
            state = [None] * len(self.decoder_layers)

        for i, decoder_layer in enumerate(self.decoder_layers):
            # Position embeddings are re-added to the input of every layer.
            x1, state[i] = decoder_layer(pos + x, state[i])
            x = x1 + x

        # Layer states are [batch, hidden, 1, 1]: stack on the last dim ->
        # [batch, hidden, 1, layers] -> [batch, layers, hidden] before projecting.
        state_data = self.head_state((torch.concat(state, -1).squeeze(-2)).permute([0, 2, 1]))
        return self.head(x), state, state_data


if __name__ == '__main__':
    # Smoke test: use whichever device is available instead of assuming CUDA,
    # and place the model and its input on the same device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = SamOut(235, 256, 16, 4).to(device)
    with torch.no_grad():
        logits, layer_states, state_data = net(torch.randint(0, 200, [2, 3000], device=device))
    print(tuple(logits.shape), tuple(state_data.shape))

解析

python 复制代码
import torch
import numpy as np

这两行代码导入了 PyTorch 和 NumPy。PyTorch 用于构建模型;NumPy 虽然被导入,但在后面的代码中并未实际使用。

python 复制代码
class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

这里定义了一个名为MaxState的PyTorch模块。它继承自torch.nn.Module,这是所有自定义模型的基类。

python 复制代码
    assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

这里使用assert语句确保hidden_dim能够被heads整除,这是多头注意力机制的一个要求。

python 复制代码
    self.head_size = hidden_dim // heads
    self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
    self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
    self.head_num = heads
    self.win = win
    self.hidden = hidden_dim
    self.mask = torch.triu(torch.ones([win, win])).to("cuda")
    self.layer_nor = torch.nn.LayerNorm(hidden_dim)

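这里初始化了一些类的属性,包括线性层、头数、窗口大小、隐藏层大小、上三角矩阵掩码以及层归一化。掩码 mask 是形状为 [win, win] 的上三角全 1 矩阵:当且仅当 c ≥ t 时 mask[t, c] 为 1,用来保证窗口内第 c 列只看到第 0 到 c 个时间步(因果)。层归一化 layer_nor 在 forward 中并未被使用。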

python 复制代码
    def forward(self, input_data, state=None):

定义了forward方法,这是模型的前向传播过程。

python 复制代码
        b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win

从输入数据的形状中取出批次大小 b 和序列长度 s,并从模块属性中取出头数 k、每个头的维度 h 和窗口大小 w。

python 复制代码
        window = torch.ones([1, w]).to("cuda")

创建一个形状为 [1, w] 的全 1 张量(注意它是二维张量),并将其移动到 GPU 上。

python 复制代码
        out = self.head(input_data)

对输入数据进行线性变换。

python 复制代码
        out = out.unsqueeze(-1) @ window

先在最后增加一维,再与形状为 [1, w] 的全 1 张量做矩阵乘法。这相当于把每个数值沿新维度复制 w 份,得到形状为 [b, s, hidden, w] 的张量。

python 复制代码
        out = out.permute([0, 2, 1, 3])

把维度顺序调整为 [b, hidden, s, w],使时间维位于第 2 维,便于之后沿时间维取最大值。

python 复制代码
        one_list = []
        if state is None:
            state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
            state = state.to("cuda")

如果没有传入状态,则用 -inf 初始化一个形状为 [b, hidden, 1, 1] 的状态张量,并将其移动到 GPU 上。-inf 是取最大值运算的单位元,不会影响结果。

python 复制代码
        for i in range(0, s, w):
            # ... (省略中间代码)

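以 w 为步长遍历序列,每次处理一个窗口;最后一个窗口可能不足 w 个位置。

- 在窗口内,借助上三角掩码,第 c 列只保留第 0 到 c 个时间步的值,再沿时间维取最大值,得到因果的累计最大值。
- 从第二个窗口开始,窗口内的累计最大值会先经过线性层 self.state 变换,再与上一个窗口传来的状态相加,结果也参与取最大值。
- 每个窗口最后一列的结果作为状态,传递给下一个窗口。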

python 复制代码
        return out, state

返回最终输出(形状为 [b, s, hidden])和状态(形状为 [b, hidden, 1, 1])。

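接下来是 FeedForward、DecoderLayer 和 SamOut 三个类的定义。它们和 MaxState 一样,都是继承自 torch.nn.Module 的自定义 PyTorch 模块:

- FeedForward 是带 ReLU 门控的前馈网络。
- DecoderLayer 由 MaxState 和 FeedForward 组成,带有残差连接和层归一化。
- SamOut 负责词嵌入和位置嵌入,并依次堆叠多个 DecoderLayer。它最后输出三样东西:词表上的 logits、各层状态,以及由 head_state 投影得到的状态数据。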

python 复制代码
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net(torch.randint(0, 200, [2, 3000]))

最后,这段代码实例化了 SamOut,参数为词表大小 235、隐藏维度 256、16 个头、4 层。然后用随机生成的、形状为 [2, 3000] 的 token 序列进行了一次前向传播。

相关推荐
川石课堂软件测试10 分钟前
性能测试|docker容器下搭建JMeter+Grafana+Influxdb监控可视化平台
运维·javascript·深度学习·jmeter·docker·容器·grafana
985小水博一枚呀18 分钟前
【深度学习滑坡制图|论文解读3】基于融合CNN-Transformer网络和深度迁移学习的遥感影像滑坡制图方法
人工智能·深度学习·神经网络·cnn·transformer
AltmanChan19 分钟前
大语言模型安全威胁
人工智能·安全·语言模型
985小水博一枚呀23 分钟前
【深度学习滑坡制图|论文解读2】基于融合CNN-Transformer网络和深度迁移学习的遥感影像滑坡制图方法
人工智能·深度学习·神经网络·cnn·transformer·迁移学习
数据与后端架构提升之路32 分钟前
从神经元到神经网络:深度学习的进化之旅
人工智能·神经网络·学习
爱技术的小伙子38 分钟前
【ChatGPT】如何通过逐步提示提高ChatGPT的细节描写
人工智能·chatgpt
深度学习实战训练营2 小时前
基于CNN-RNN的影像报告生成
人工智能·深度学习
昨日之日20064 小时前
Moonshine - 新型开源ASR(语音识别)模型,体积小,速度快,比OpenAI Whisper快五倍 本地一键整合包下载
人工智能·whisper·语音识别
浮生如梦_4 小时前
Halcon基于laws纹理特征的SVM分类
图像处理·人工智能·算法·支持向量机·计算机视觉·分类·视觉检测
深度学习lover4 小时前
<项目代码>YOLOv8 苹果腐烂识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·苹果腐烂识别