Simple-STNDT使用Transformer进行Spike信号的表征学习(二)模型结构

文章目录

    • [1. 位置编码](#1. 位置编码)
    • [1.2 EncoderLayer](#1.2 EncoderLayer)
    • [1.3 Encoder](#1.3 Encoder)
    • [1.4 STNDT](#1.4 STNDT)

1. 位置编码

model.py

py 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
import math
UNMASKED_LABEL = -100

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to an input, then apply dropout.

    The codes are precomputed once and stored in the ``pe`` buffer with shape
    ``(trial_length, 1, d_model)``; they are indexed by dimension 0 of the
    input and broadcast over dimension 1.

    Args:
        trial_length: number of positions for which codes are precomputed.
        d_model: size of the last (feature) dimension; odd sizes are supported.
        dropout: dropout probability applied after the codes are added.
    """

    def __init__(self, trial_length, d_model, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # One frequency per even feature index: exp(-ln(10000) * 2k / d_model).
        even_index = torch.arange(0, d_model, 2).float()
        frequencies = torch.exp(even_index * (-math.log(10000.0) / d_model))
        steps = torch.arange(0, trial_length, dtype=torch.float).unsqueeze(1)
        angles = steps * frequencies

        table = torch.zeros(trial_length, d_model)
        table[:, 0::2] = torch.sin(angles)
        # There are d_model // 2 odd columns; for odd d_model that is one fewer
        # than the number of frequencies, so the last frequency is dropped.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``."""
        length = x.size(0)
        return self.dropout(x + self.pe[:length, :])

1.2 EncoderLayer

model.py

核心编码层:在时间维度自注意力的基础上,加入了空间(神经元维度)注意力编码。

py 复制代码
class STNTransformerEncoderLayer(TransformerEncoderLayer):
    """Transformer encoder layer with an additional spatial attention branch.

    The layer runs three stages:

    1. the inherited pre-norm temporal self-attention and feed-forward blocks
       on ``src``;
    2. a second multi-head attention over ``spatial_src``, of which only the
       attention weights are kept;
    3. a mixing step that applies those weights to the temporal features,
       followed by its own feed-forward block (the ``ts_*`` modules).

    Args:
        d_model: feature size of ``src`` (embedding size of the temporal attention).
        d_model_s: feature size of ``spatial_src`` (embedding size of the
            spatial attention); per the original note this is the number of
            time steps, e.g. 160.
        num_heads: number of heads used by both attention modules.
        dim_feedforward: hidden width of both feed-forward blocks.
        dropout: dropout probability for the inherited and the ``ts_*`` dropouts.
        activation: activation spec forwarded to ``TransformerEncoderLayer``.
    """

    def __init__(self, d_model, d_model_s, num_heads=2, dim_feedforward=128, dropout=0.1,
                 activation='relu'):
        super().__init__(
            d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.num_heads = num_heads
        self.num_input = d_model
        # Number of time steps (e.g. 160); used as the spatial attention width.
        self.d_model_s = d_model_s

        # Spatial branch.
        self.spatial_self_attn = MultiheadAttention(embed_dim=d_model_s, num_heads=num_heads)
        self.spatial_norm1 = nn.LayerNorm(d_model_s)

        # Spatio-temporal mixing block.
        self.ts_norm1 = nn.LayerNorm(d_model)
        self.ts_norm2 = nn.LayerNorm(d_model)
        self.ts_linear1 = nn.Linear(d_model, dim_feedforward)
        self.ts_linear2 = nn.Linear(dim_feedforward, d_model)
        self.ts_dropout1 = nn.Dropout(dropout)
        self.ts_dropout2 = nn.Dropout(dropout)
        self.ts_dropout3 = nn.Dropout(dropout)

    def attend(self, src, context_mask=None, **kwargs):
        """Run temporal self-attention on ``src``.

        Returns:
            ``(attention output, attention weights, zero scalar tensor)``; the
            trailing zero is a placeholder on ``src``'s device.
        """
        output, weights = self.self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        placeholder = torch.zeros((), dtype=torch.float, device=src.device)
        return output, weights, placeholder

    def spatial_attend(self, src, context_mask=None, **kwargs):
        r"""
        Attends over spatial dimension
        Args:
            src: spatiotemporal neural population input
            context_mask: spatial context mask
        Returns:
            spatiotemporal neural population activity transformed by spatial
            attention, the spatial attention weights, and a zero scalar tensor
        """
        output, weights = self.spatial_self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        placeholder = torch.zeros((), dtype=torch.float, device=src.device)
        return output, weights, placeholder

    def forward(self, src, spatial_src, src_mask=None, spatial_src_mask=None, src_key_padding_mask=None):
        """Apply temporal attention, then re-weight the result with spatial attention.

        Args:
            src: temporal input; assumed (time, batch, neurons) -- inferred
                from the permutes below, confirm against the caller.
            spatial_src: spatial input; assumed (neurons, batch, time) -- confirm.
            src_mask: attention mask for the temporal attention.
            spatial_src_mask: attention mask for the spatial attention.
            src_key_padding_mask: key padding mask for the temporal attention.

        Returns:
            Tensor with the same layout as ``src``.
        """
        # Temporal self-attention (pre-norm) with residual connection.
        attended, _, _ = self.attend(
            self.norm1(src), context_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        temporal = src + self.dropout1(attended)

        # Temporal feed-forward (pre-norm) with residual connection.
        hidden = self.activation(self.linear1(self.norm2(temporal)))
        temporal = temporal + self.dropout2(self.linear2(self.dropout(hidden)))

        # Spatial attention: only the attention weights are used downstream.
        _, spatial_weights, _ = self.spatial_attend(
            self.spatial_norm1(spatial_src), context_mask=spatial_src_mask, key_padding_mask=None
        )

        # Mix: move dim 0 of the normed temporal features last, apply the
        # spatial weights per batch entry, then restore the original layout.
        batch_major = self.ts_norm1(temporal).permute(1, 2, 0)
        mixed = torch.bmm(spatial_weights, batch_major).permute(2, 0, 1)
        mixed = temporal + self.ts_dropout1(mixed)

        # Final feed-forward (pre-norm) with residual connection.
        hidden = self.activation(self.ts_linear1(self.ts_norm2(mixed)))
        return mixed + self.ts_dropout3(self.ts_linear2(self.ts_dropout2(hidden)))

1.3 Encoder

model.py

py 复制代码
class STNTransformerEncoder(TransformerEncoder):
    """Stack of spatio-temporal encoder layers.

    The first layer receives the caller-provided ``spatial_src``; every later
    layer receives the previous layer's output with dimensions 0 and 2
    swapped as its spatial input.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__(encoder_layer, num_layers, norm)

    def forward(self, src, spatial_src, mask=None, spatial_mask=None):
        """Run all layers in order and apply the optional final norm.

        Args:
            src: temporal input passed to the first layer.
            spatial_src: spatial input passed to the first layer only.
            mask: forwarded to every layer as ``src_mask``.
            spatial_mask: forwarded to every layer as ``spatial_src_mask``.

        Returns:
            Output of the last layer, normalised when ``self.norm`` is set.
        """
        spatial_input = spatial_src
        for layer in self.layers:
            src = layer(src, spatial_input, src_mask=mask, spatial_src_mask=spatial_mask)
            # Later layers derive their spatial view from the latest output.
            spatial_input = src.permute(2, 1, 0)

        if self.norm is None:
            return src
        return self.norm(src)

1.4 STNDT

model.py

py 复制代码
class SpatioTemporalNDT(nn.Module):
    """Spatio-temporal transformer trained to reconstruct masked spike counts.

    The input is viewed in two ways -- time-major for temporal attention and
    neuron-major for spatial attention. Both views get positional encodings,
    are passed through the spatio-temporal encoder, and the result is decoded
    linearly into per-neuron rates scored with a Poisson negative log-likelihood.

    Args:
        trial_length: number of time steps per trial.
        num_neurons: number of neurons (feature size of the temporal view).
        temperature: stored as ``self.temperature``; not used in this class.
        c_lambda: stored as ``self.contrast_lambda``; not used in this class.
        dropout: dropout probability applied to the encoder output.
        pos_drop: dropout probability of both positional encoders.
        enc_layers: number of encoder layers.
        log_rates: forwarded to ``PoissonNLLLoss(log_input=...)``; when true
            the decoder output is interpreted as log-rates.
        enc_heads: number of attention heads per encoder layer.
        enc_dff: feed-forward width of the encoder layers.
        enc_drop: dropout probability inside the encoder layers.
        context_forward: how many future time steps each step may attend to
            (non-negative).
        context_backward: how many past time steps each step may attend to
            (non-negative).
    """

    def __init__(self, trial_length, num_neurons, temperature=0.1, c_lambda=0.3,
                 dropout=0.2, pos_drop=0.1, enc_layers=1, log_rates=True,
                 enc_heads=2, enc_dff=128, enc_drop=0.1,
                 context_forward=13, context_backward=79
                 ) -> None:
        super().__init__()

        # Lazily built temporal attention mask; see _get_mask.
        self.src_mask = None
        self.context_forward = context_forward
        self.context_backward = context_backward
        self.num_input = num_neurons
        self.num_spatial_input = trial_length
        self.embedder = nn.Identity()
        self.spatial_embedder = nn.Identity()
        self.scale = math.sqrt(num_neurons)
        self.spatial_scale = math.sqrt(trial_length)
        self.src_pos_encoder = PositionalEncoding(trial_length, num_neurons, pos_drop)
        self.spatial_pos_encoder = PositionalEncoding(num_neurons, trial_length, pos_drop)

        self.projector = nn.Identity()
        self.spatial_projector = nn.Identity()
        self.n_views = 2
        self.temperature = temperature
        self.contrast_lambda = c_lambda
        self.cel = nn.CrossEntropyLoss(reduction='none')
        self.mse = nn.MSELoss(reduction='mean')

        encoder_layer = STNTransformerEncoderLayer(
            d_model=self.num_input,
            d_model_s=self.num_spatial_input,
            num_heads=enc_heads,
            dim_feedforward=enc_dff,
            dropout=enc_drop
        )
        self.transformer_encoder = STNTransformerEncoder(encoder_layer, enc_layers, nn.LayerNorm(self.num_input))

        self.rate_dropout = nn.Dropout(dropout)
        self.src_decoder = nn.Linear(num_neurons, self.num_input)
        self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=log_rates)

    def _get_mask(self, src, do_convert=True):
        """Return the additive temporal attention mask for ``src``.

        Step ``i`` may attend to step ``j`` when
        ``i - context_backward <= j <= i + context_forward``. Allowed entries
        are ``0.0`` and blocked entries are ``-inf``.

        The mask is cached in ``self.src_mask``. Because that attribute is a
        plain tensor rather than a registered buffer, ``Module.to`` does not
        move it, so the cache is moved to ``src``'s device here and rebuilt
        when the sequence length changes.

        Args:
            src: time-major input; only ``src.size(0)`` and ``src.device`` are read.
            do_convert: accepted for interface compatibility; currently unused,
                the mask is always returned in additive float form.
        """
        size = src.size(0)
        cached = self.src_mask
        if cached is not None and cached.size(-1) == size:
            if cached.device != src.device:
                cached = cached.to(src.device)
                self.src_mask = cached
            return cached

        ones = torch.ones(size, size, device=src.device)
        # triu(diagonal=-k) keeps j >= i - k; transposing turns it into j <= i + k.
        forward_ok = (torch.triu(ones, diagonal=-self.context_forward) == 1).transpose(0, 1)
        backward_ok = torch.triu(ones, diagonal=-self.context_backward) == 1
        allowed = (forward_ok & backward_ok).float()
        self.src_mask = binary_mask_to_attn_mask(allowed)
        return self.src_mask

    def forward(self, src: torch.Tensor, mask_labels: torch.Tensor):
        """Encode ``src`` and score the predicted rates on the masked entries.

        Args:
            src: spike counts laid out as (batch, time, neurons), as implied by
                the permutes and the positional-encoder sizes below.
            mask_labels: targets with the same shape as the returned rates;
                entries equal to ``UNMASKED_LABEL`` are excluded from the loss.

        Returns:
            ``(loss, rates)`` where ``loss`` is the mean Poisson NLL over the
            entries whose label is not ``UNMASKED_LABEL`` and ``rates`` is the
            decoder output as (batch, time, neurons). If every label equals
            ``UNMASKED_LABEL`` the selection is empty and the mean is NaN.
        """
        src = src.float()
        # Neuron-major view for the spatial attention branch.
        spatial_src = src.permute(2, 0, 1)
        spatial_src = self.spatial_embedder(spatial_src) * self.spatial_scale
        spatial_src = self.spatial_pos_encoder(spatial_src)
        # Time-major view for the temporal attention branch.
        src = src.permute(1, 0, 2)
        src = self.embedder(src) * self.scale
        src = self.src_pos_encoder(src)
        src_mask = self._get_mask(src)
        spatial_src_mask = None
        encoder_output = self.transformer_encoder(src, spatial_src, src_mask, spatial_src_mask)
        encoder_output = self.rate_dropout(encoder_output)
        decoder_output = self.src_decoder(encoder_output)

        # Back to batch-first so the rates line up with mask_labels.
        decoder_rates = decoder_output.permute(1, 0, 2)
        decoder_loss = self.classifier(decoder_rates, mask_labels)
        masked_decoder_loss = decoder_loss[mask_labels != UNMASKED_LABEL]
        masked_decoder_loss = masked_decoder_loss.mean()

        return masked_decoder_loss, decoder_rates


def binary_mask_to_attn_mask(x):
    """Convert a 0/1 mask into an additive float attention mask.

    Entries equal to 0 become ``-inf`` and entries equal to 1 become ``0.0``;
    any other value is kept as its float conversion. The input is not modified.
    """
    blocked = x == 0
    allowed = x == 1
    additive = x.float().masked_fill(blocked, float('-inf'))
    return additive.masked_fill(allowed, float(0.0))

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906391

相关推荐
朝九晚五ฺ5 小时前
【Linux探索学习】第十四弹——进程优先级:深入理解操作系统中的进程优先级
linux·运维·学习
猫爪笔记7 小时前
前端:HTML (学习笔记)【1】
前端·笔记·学习·html
pq113_67 小时前
ftdi_sio应用学习笔记 3 - GPIO
笔记·学习·ftdi_sio
澄澈i7 小时前
设计模式学习[8]---原型模式
学习·设计模式·原型模式
爱米的前端小笔记8 小时前
前端八股自学笔记分享—页面布局(二)
前端·笔记·学习·面试·求职招聘
alikami8 小时前
【前端】前端学习
学习
一只小菜鸡..8 小时前
241118学习日志——[CSDIY] [ByteDance] 后端训练营 [06]
学习
Hacker_Oldv10 小时前
网络安全的学习路线
学习·安全·web安全
蒟蒻的贤10 小时前
vue学习11.21
javascript·vue.js·学习
高 朗10 小时前
【GO基础学习】基础语法(2)切片slice
开发语言·学习·golang·slice