SamOut 最新版本:state 逐层控制,加速收敛

代码

python 复制代码
import torch
import numpy as np


class MaxState(torch.nn.Module):
    """Chunked causal running-max mixer used by SamOut in place of attention.

    The input is linearly projected and the sequence is processed in chunks of
    ``win`` positions. Inside a chunk, output position ``q`` is the max over
    chunk positions ``0..q`` of the projection, combined with a state carried
    over from earlier chunks. For every chunk after the first, the per-column
    chunk max is mixed across features by ``self.state`` and offset by the
    carried state before taking the max.

    Args:
        hidden_dim: Feature size of input and output.
        heads: Number of heads; ``hidden_dim`` must be divisible by it. It is
            only used to reshape the result; the max is taken per feature.
        win: Chunk (window) length.
    """

    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads
        self.win = win
        self.hidden = hidden_dim
        # Upper-triangular (win, win) mask: mask[p, q] == 1 iff p <= q, so
        # column q of a chunk only sees chunk positions 0..q (causality).
        # A non-persistent buffer follows .to()/.cuda() with the module and,
        # like the former plain attribute, is not stored in state_dict.
        self.register_buffer("mask", torch.triu(torch.ones([win, win])), persistent=False)
        # NOTE(review): not used in forward; kept so existing checkpoints
        # containing layer_nor.weight / layer_nor.bias still load.
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

    def forward(self, input_data, state=None):
        """Run the chunked running-max over a sequence.

        Args:
            input_data: Tensor of shape (b, s, hidden_dim).
            state: Optional carried state of shape (b, hidden_dim, 1, 1), as
                returned by a previous call. ``None`` starts from ``-inf``.

        Returns:
            ``(out, state)`` where ``out`` has shape (b, s, hidden_dim) and
            ``state`` has shape (b, hidden_dim, 1, 1).
        """
        b, s = input_data.shape[0], input_data.shape[1]
        k, h, w = self.head_num, self.head_size, self.win

        out = self.head(input_data)
        # Same device/dtype as the projection, so CPU, GPU and half work.
        window = out.new_ones([1, w])

        # (b, s, hidden) -> (b, s, hidden, w): copy every value into w columns.
        out = out.unsqueeze(-1) @ window
        # -> (b, hidden, s, w) so chunks can be sliced along dim 2.
        out = out.permute([0, 2, 1, 3])

        if state is None:
            state = out.new_full([out.shape[0], out.shape[1], 1, 1], float("-inf"))

        if s == 0:
            # Nothing to process; avoid torch.cat on an empty list.
            return out.new_zeros([b, 0, out.shape[1]]), state

        chunks = []
        for i in range(0, s, w):
            one = out[:, :, i:i + w]
            r = one.shape[2]  # < w only for a short final chunk
            # Hide positions p > q from column q (mask[:w] is the full mask).
            one = one.masked_fill(self.mask[:r] != 1, float("-inf"))

            if i == 0:
                # First chunk: the carried state is simply one more candidate row.
                one = torch.cat([one, state @ window], dim=2)
            else:
                # Per-column chunk max, mixed across features by self.state,
                # then offset by the carried state from the previous chunk.
                state1, _ = torch.max(one, dim=2, keepdim=True)  # (b, hidden, 1, w)
                state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([b, -1, state1.shape[1]]))  # (b, w, hidden)
                state = state1.permute([0, 2, 1]).unsqueeze(-2) + state  # (b, hidden, 1, w)
                one = torch.cat([one, state], dim=2)

            state, _ = torch.max(one, dim=2, keepdim=True)  # (b, hidden, 1, w)

            chunk_out = state.reshape([b, k, h, w])
            # Carry only the last column (max over the whole chunk) forward.
            state = state[..., -1:]
            # Only the first r columns correspond to real positions.
            chunk_out = chunk_out[..., :r]
            chunks.append(chunk_out.permute([0, 3, 1, 2]))  # (b, r, k, h)

        out = torch.cat(chunks, 1)
        out = out.reshape([b, s, -1])

        return out, state


class FeedForward(torch.nn.Module):
    """Gated position-wise feed-forward network.

    Computes ``ffn2(ffn1(x) * relu(gate(x)))``. The input is projected to twice
    ``hidden_size`` along two paths. The ReLU-activated path gates the other
    path multiplicatively, and the product is projected back to
    ``hidden_size``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        inner_size = hidden_size * 2
        # Creation order (ffn1, ffn2, gate) is significant: it fixes seeded
        # initialisation and state_dict key order.
        self.ffn1 = torch.nn.Linear(hidden_size, inner_size)
        self.ffn2 = torch.nn.Linear(inner_size, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, inner_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply the gated FFN over the last dimension of ``x``."""
        value = self.ffn1(x)
        gate = self.relu(self.gate(x))
        return self.ffn2(value * gate)


class DecoderLayer(torch.nn.Module):
    """One SamOut block: a MaxState mixer followed by a gated FFN.

    The FFN output of the mixer result is added to the layer input and passed
    through LayerNorm.

    Args:
        hidden_size: Feature size.
        num_heads: Head count forwarded to MaxState.
        win: MaxState chunk length (defaults to the previously hard-coded 8).
    """

    def __init__(self, hidden_size, num_heads, win=8):
        super(DecoderLayer, self).__init__()
        self.self_attention = MaxState(hidden_size, num_heads, win)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        """Return ``(layer_output, new_state)``; ``seq_len`` is currently unused."""
        x1, state = self.self_attention(x, state)

        # Feed-forward on the mixer output, residual to the layer input, then LayerNorm.
        x = self.layer_norm(self.ffn(x1) + x)

        return x, state


class SamOut(torch.nn.Module):
    """SamOut language model: embeddings, a stack of DecoderLayers and output heads.

    Args:
        voc_size: Vocabulary size (token id 3 is the embedding padding_idx).
        hidden_size: Model width.
        num_heads: Head count for each DecoderLayer.
        num_layers: Number of DecoderLayers.
    """

    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size)
        self.head_state = torch.nn.Linear(hidden_size, num_layers)

    def forward(self, x, state=None, seq_len=None):
        """Run the model on token ids ``x`` of shape (b, s).

        Args:
            x: Integer token ids.
            state: Optional per-layer states from a previous call (a sequence
                with one entry per layer). The caller's object is not modified.
            seq_len: Currently unused.

        Returns:
            ``(logits, state, state_data)``. ``logits`` is ``self.head`` applied
            to the final hidden states, ``state`` is the new per-layer state
            list, and ``state_data`` is ``self.head_state`` applied to the
            concatenated layer states.
        """
        x = self.em(x)
        seq = x.shape[1]
        # Build indices on the same device as the embeddings (CPU or GPU).
        idx = torch.arange(seq, device=x.device)
        if seq >= 1024:
            # NOTE(review): long sequences encode position p as
            # pos[p // 1024] + pos[p % 1024], while shorter ones use pos[p]
            # alone, so positions < 1024 are encoded differently depending on
            # total length. Kept for checkpoint compatibility. Sequences longer
            # than 1024 * 1024 would index past the table.
            pos = (self.pos(idx % 1024) + self.pos(idx // 1024)).unsqueeze(0)
        else:
            pos = self.pos(idx).unsqueeze(0)

        if state is None:
            state = [None] * len(self.decoder_layers)
        else:
            state = list(state)  # copy so the caller's list is not overwritten

        for i, decoder_layer in enumerate(self.decoder_layers):
            x1, state[i] = decoder_layer(pos + x, state[i])
            x = x1 + x

        # Each MaxState layer returns (b, hidden, 1, 1); concatenating on the
        # last dim and squeezing gives (b, hidden, num_layers) -> (b, num_layers, hidden).
        state_data = self.head_state(torch.cat(state, -1).squeeze(-2).permute([0, 2, 1]))
        return self.head(x), state, state_data


if __name__ == '__main__':
    # Smoke test: one forward pass, with model and input on the same device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = SamOut(235, 256, 16, 4).to(device)
    with torch.no_grad():
        logits, states, state_data = net(torch.randint(0, 200, [2, 3000], device=device))
    print(logits.shape, state_data.shape)

解析

python 复制代码
import torch
import numpy as np

这两行代码导入了PyTorch库和NumPy库,它们分别用于深度学习和数值计算。

python 复制代码
class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

这里定义了一个名为MaxState的PyTorch模块。它继承自torch.nn.Module,这是所有自定义模型的基类。

python 复制代码
    assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

这里使用 assert 语句确保 hidden_dim 能够被 heads 整除,这样隐藏维度才能均分到各个头上(后续会把结果 reshape 为 [b, k, h, w])。

python 复制代码
    self.head_size = hidden_dim // heads
    self.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
    self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
    self.head_num = heads
    self.win = win
    self.hidden = hidden_dim
    self.mask = torch.triu(torch.ones([win, win])).to("cuda")
    self.layer_nor = torch.nn.LayerNorm(hidden_dim)

这里初始化了一些类的属性,包括两个线性层(head 和 state)、头数、窗口大小、隐藏层大小、上三角矩阵掩码以及层归一化(注意:layer_nor 在 forward 中并未被使用)。

python 复制代码
    def forward(self, input_data, state=None):

定义了forward方法,这是模型的前向传播过程。

python 复制代码
        b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win

从输入数据的形状中取得批次大小 b 和序列长度 s,并从模块属性中取得头数 k、头大小 h 和窗口大小 w。

python 复制代码
        window = torch.ones([1, w]).to("cuda")

创建一个形状为 [1, w] 的全 1 张量(w 为窗口大小),并将其移动到 GPU 上。

python 复制代码
        out = self.head(input_data)

对输入数据进行线性变换。

python 复制代码
        out = out.unsqueeze(-1) @ window

先在输出数据末尾增加一维,再与窗口张量做矩阵乘法。这相当于把每个特征值复制 w 份,得到形状为 [b, s, hidden, w] 的张量。

python 复制代码
        out = out.permute([0, 2, 1, 3])

调整输出数据的维度顺序,得到形状为 [b, hidden, s, w] 的张量,便于之后沿序列维度按窗口切块。

python 复制代码
        one_list = []
        if state is None:
            state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")
            state = state.to("cuda")

如果状态为空,则初始化状态张量,并将其移动到GPU上。

python 复制代码
        for i in range(0, s, w):
            # ... (省略中间代码)

对序列进行迭代,每次迭代处理一个窗口大小的数据。

python 复制代码
        return out, state

返回最终输出和状态。

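接下来是 FeedForward、DecoderLayer 和 SamOut 类的定义。它们和 MaxState 类一样,都是继承自 torch.nn.Module 的自定义 PyTorch 模块:FeedForward 是带门控的前馈网络,DecoderLayer 由 MaxState 和 FeedForward 组成,SamOut 则是堆叠多个 DecoderLayer 的完整模型。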

python 复制代码
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net(torch.randint(0, 200, [2, 3000]))

最后,这段代码以词表大小 235、隐藏维度 256、16 个头、4 层实例化了 SamOut 类,并用形状为 [2, 3000] 的随机 token 序列进行了一次前向传播。

相关推荐
weixin_437497771 小时前
读书笔记:Context Engineering 2.0 (上)
人工智能·nlp
喝拿铁写前端1 小时前
前端开发者使用 AI 的能力层级——从表面使用到工程化能力的真正分水岭
前端·人工智能·程序员
goodfat2 小时前
Win11如何关闭自动更新 Win11暂停系统更新的设置方法【教程】
人工智能·禁止windows更新·win11优化工具
北京领雁科技2 小时前
领雁科技反洗钱案例白皮书暨人工智能在反洗钱系统中的深度应用
人工智能·科技·安全
落叶,听雪2 小时前
河南建站系统哪个好
大数据·人工智能·python
清月电子2 小时前
杰理AC109N系列AC1082 AC1074 AC1090 芯片停产替代及资料说明
人工智能·单片机·嵌入式硬件·物联网
Dev7z2 小时前
非线性MPC在自动驾驶路径跟踪与避障控制中的应用及Matlab实现
人工智能·matlab·自动驾驶
七月shi人2 小时前
AI浪潮下,前端路在何方
前端·人工智能·ai编程
橙汁味的风3 小时前
1隐马尔科夫模型HMM与条件随机场CRF
人工智能·深度学习·机器学习
itwangyang5203 小时前
AIDD-人工智能药物设计-AI 制药编码之战:预测癌症反应,选对方法是关键
人工智能