samout超级加速

python 复制代码
import torch
import numpy as np


class MaxState(torch.nn.Module):
    """Token mixer that replaces attention with a running (cumulative) max.

    For every channel independently, position t receives the maximum over
    positions <= t of ``(head0(x) + head1(x)) / sqrt(head_size)``. That running
    max is gated by ``head2(x)``, and ``head1(x)`` is subtracted:

        out = head2(x) * cummax((head0(x) + head1(x)) / sqrt(head_size)) - head1(x)

    Only current and earlier positions contribute, so the mixer is causal.

    Args:
        hidden_dim: model width; must be divisible by ``heads``.
        heads: number of heads. Because the cumulative max runs per channel,
            the head count only affects the ``1/sqrt(head_size)`` scale.
        win: not used by this implementation (kept for call compatibility).
    """

    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        # Not used in forward(); kept so existing checkpoints (state_dict keys)
        # still load and parameter initialisation order is unchanged.
        self.layer_nor = torch.nn.LayerNorm(hidden_dim)

        self.head_num = heads

        self.hidden = hidden_dim

    def forward(self, input_data, state=None):
        """Mix tokens with a causal running max.

        Args:
            input_data: tensor of shape (batch, seq, hidden_dim).
            state: opaque value, returned unchanged.

        Returns:
            ``(output, state)`` where output has the same shape as input_data.
        """
        out0 = self.head0(input_data)
        out1 = self.head1(input_data)
        out2 = self.head2(input_data)

        # cummax is computed independently for every channel, so splitting the
        # hidden axis into (heads, head_size), permuting, and merging back is a
        # pure layout round-trip. Scanning dim 1 (the sequence axis) directly
        # yields identical values without the extra reshape/permute copies.
        running_max = torch.cummax((out0 + out1) / self.head_size ** 0.5, dim=1).values

        return out2 * running_max - out1, state


class KAttention(torch.nn.Module):
    """Causal multi-head scaled dot-product self-attention (no output projection).

    Args:
        hidden_dim: model width; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, hidden_dim, heads):
        super(KAttention, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.q = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads

    def forward(self, x, state=None):
        """Attend each position to itself and all earlier positions.

        Args:
            x: tensor of shape (batch, seq, hidden_dim).
            state: opaque value, returned unchanged.

        Returns:
            ``(output, state)`` with output of shape (batch, seq, hidden_dim).
        """
        b, s, h, d = x.shape[0], x.shape[1], self.head_num, self.head_size
        q = self.q(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        k = self.k(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        v = self.v(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        qk = (q @ k.transpose(-2, -1)) / d ** 0.5
        # Lower-triangular (incl. diagonal) mask: query i may see keys j <= i.
        # Built on x's device instead of a module-level global, and applied with
        # masked_fill so the score dtype (e.g. fp16) is preserved.
        causal = torch.ones(s, s, dtype=torch.bool, device=x.device).tril()
        qk = qk.masked_fill(~causal, float('-inf'))
        qkv = torch.nn.functional.softmax(qk, -1) @ v
        qkv = qkv.permute([0, 2, 1, 3]).reshape([b, s, -1])
        return qkv, state


class FeedForward(torch.nn.Module):
    """Gated feed-forward block computing ``ffn2(ffn1(x) * relu(gate(x)))``.

    The inner width is twice ``hidden_size``; the output width equals the input width.
    """

    def __init__(self, hidden_size):
        super().__init__()
        inner = hidden_size * 2
        # Registration order (ffn1, ffn2, gate, relu) is intentional: it fixes the
        # parameter-initialisation order under a seeded RNG and the state_dict order.
        self.ffn1 = torch.nn.Linear(hidden_size, inner)
        self.ffn2 = torch.nn.Linear(inner, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, inner)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        projected = self.ffn1(x)
        gated = projected * self.relu(self.gate(x))
        return self.ffn2(gated)




class DecoderLayer(torch.nn.Module):
    """One decoder block: MaxState token mixer, gated FFN, residual + LayerNorm.

    Computes ``layer_norm(ffn(mixer(x)) + x)``; the residual connection skips
    the mixer and the FFN together.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        # Mixers tried previously: MaskMultiHeadAttention and KAttention.
        # The third argument (8) is MaxState's `win`, which MaxState as written does not use.
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        """Return ``(new hidden states, mixer state)``; ``seq_len`` is unused."""
        mixed, new_state = self.self_attention(x, state)
        summed = self.ffn(mixed) + x
        return self.layer_norm(summed), new_state


class SamOut(torch.nn.Module):
    """Decoder-only language model built from stacked DecoderLayer blocks.

    Before every layer, the learned position embeddings are concatenated with
    the hidden states and projected back to ``hidden_size`` by ``down[i]``;
    the layer's output is then added to that projected input.

    Args:
        voc_size: vocabulary size (embedding rows and logit width).
        hidden_size: model width.
        num_heads: heads passed to each DecoderLayer.
        num_layers: number of decoder layers.
    """

    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        # Token id 3 is the embedding's padding_idx.
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        # 1024 learned positions; longer sequences are handled in pos_forward.
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)

        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])

    def state_forward(self, state, pos, x):
        """Run the hidden states through every decoder layer.

        Args:
            state: list with one state per layer, or None to start each at None.
                A provided list is updated in place.
            pos: position embeddings broadcastable over the batch, e.g. the
                (1, seq, hidden) tensor from pos_forward.
            x: hidden states of shape (batch, seq, hidden).

        Returns:
            ``(x, state)`` after the last layer.
        """
        if state is None:
            state = [None] * len(self.decoder_layers)
        # Broadcast pos over the batch as a view; this keeps pos's device and
        # dtype instead of adding it to float32 zeros on a hard-coded device.
        pos = pos.expand(x.shape[0], -1, -1)
        for i, (down, decoder_layer) in enumerate(zip(self.down, self.decoder_layers)):
            x = down(torch.concat([pos, x], -1))

            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
        return x, state

    def pos_forward(self, x):
        """Return position embeddings of shape (1, seq, hidden) for x's length.

        Sequences shorter than 1024 index the table directly. For longer ones,
        position p is encoded as ``pos[p % 1024] + pos[p // 1024]``, which only
        stays in range while seq_len <= 1024 * 1024.
        """
        idx = torch.arange(0, x.shape[1], device=self.pos.weight.device)
        if x.shape[1] >= 1024:
            emb = self.pos(idx % 1024) + self.pos(idx // 1024)
        else:
            emb = self.pos(idx)
        return emb.unsqueeze(0)

    def forward(self, x0):
        """Compute logits for token ids ``x0``; returns ``(logits, state)``."""
        return self.one_forward(x0, state=None)

    def one_forward(self, x, state=None, seq_len=None):
        """Embed tokens, run all layers, project to vocabulary logits.

        ``seq_len`` is unused. Returns ``(logits, state)``.
        """
        x = self.em(x)

        pos = self.pos_forward(x)

        x, state = self.state_forward(state, pos, x)

        return self.head(x), state


# Prefer the GPU when present; fall back to CPU so the smoke test below (and
# any code reading this module-level name) still works without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

if __name__ == '__main__':
    # Smoke test: vocab 235, hidden 256, 16 heads, 4 layers; batch 2, seq 104.
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    with torch.no_grad():
        logits, _ = net(torch.randint(0, 200, [2, 8 * 13]).to(device))
    print(logits.shape)
相关推荐
云和数据.ChenGuang26 分钟前
Django 应用安装脚本 – 如何将应用添加到 INSTALLED_APPS 设置中 原创
数据库·django·sqlite
woshilys1 小时前
sql server 查询对象的修改时间
运维·数据库·sqlserver
Hacker_LaoYi1 小时前
SQL注入的那些面试题总结
数据库·sql
2401_857439692 小时前
SSM 架构下 Vue 电脑测评系统:为电脑性能评估赋能
开发语言·php
建投数据2 小时前
建投数据与腾讯云数据库TDSQL完成产品兼容性互认证
数据库·腾讯云
SoraLuna2 小时前
「Mac畅玩鸿蒙与硬件47」UI互动应用篇24 - 虚拟音乐控制台
开发语言·macos·ui·华为·harmonyos
xlsw_2 小时前
java全栈day20--Web后端实战(Mybatis基础2)
java·开发语言·mybatis
Hacker_LaoYi3 小时前
【渗透技术总结】SQL手工注入总结
数据库·sql
岁月变迁呀3 小时前
Redis梳理
数据库·redis·缓存
独行soc3 小时前
#渗透测试#漏洞挖掘#红蓝攻防#护网#sql注入介绍06-基于子查询的SQL注入(Subquery-Based SQL Injection)
数据库·sql·安全·web安全·漏洞挖掘·hw