Simple-STNDT使用Transformer进行Spike信号的表征学习(二)模型结构

文章目录

    • [1.1 位置编码](#1.1 位置编码)
    • [1.2 EncoderLayer](#1.2 EncoderLayer)
    • [1.3 Encoder](#1.3 Encoder)
    • [1.4 STNDT](#1.4 STNDT)

1.1 位置编码

model.py

py 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
import math
# Label value marking positions that are excluded when the masked loss is computed.
UNMASKED_LABEL = -100

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a sequence-first tensor, then apply dropout.

    Args:
        trial_length: Number of positions for which codes are precomputed.
        d_model: Width of each position code (size of the input's last dimension).
        dropout: Dropout probability applied to the summed result.
    """

    def __init__(self, trial_length, d_model, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        steps = torch.arange(0, trial_length, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # One angle per (position, frequency) pair: (trial_length, ceil(d_model / 2)).
        angles = steps * inv_freq

        table = torch.zeros(trial_length, d_model)
        table[:, 0::2] = torch.sin(angles)
        # An odd width has one fewer cosine column than sine columns, so the
        # last frequency is dropped for the cosine half in that case.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (trial_length, 1, d_model) so it broadcasts over dim 1 of the input.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``; dim 0 of ``x`` indexes positions."""
        num_steps = x.size(0)
        return self.dropout(x + self.pe[:num_steps, :])

1.2 EncoderLayer

model.py

核心编码层:在时间维度自注意力的基础上,加入了空间维度的自注意力。

py 复制代码
class STNTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer combining temporal self-attention with spatial attention weights.

    The temporal path reuses the sub-modules created by ``TransformerEncoderLayer``
    (``self_attn``, ``linear1``/``linear2``, ``norm1``/``norm2``, dropouts) in a
    pre-norm arrangement. A second attention module runs over ``spatial_src``;
    only its attention weights are kept and used to re-mix the last dimension of
    the temporal features before a final feed-forward block.

    Args:
        d_model: Embedding width of the temporal path.
        d_model_s: Embedding width of the spatial attention.
        num_heads: Head count used by both attention modules.
        dim_feedforward: Hidden width of both feed-forward blocks.
        dropout: Dropout probability for the temporal and mixing paths.
        activation: Activation spec forwarded to ``TransformerEncoderLayer``.
    """

    def __init__(self, d_model, d_model_s, num_heads=2, dim_feedforward=128, dropout=0.1,
                 activation='relu'):
        super().__init__(
            d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.num_heads = num_heads
        self.num_input = d_model
        # d_model_s: number of time steps (e.g. 160), used for the spatial self-attention
        self.d_model_s = d_model_s

        # Spatial attention and its input normalisation.
        self.spatial_self_attn = MultiheadAttention(embed_dim=d_model_s, num_heads=num_heads)
        self.spatial_norm1 = nn.LayerNorm(d_model_s)

        # Sub-modules of the spatio-temporal mixing stage.
        self.ts_norm1 = nn.LayerNorm(d_model)
        self.ts_norm2 = nn.LayerNorm(d_model)
        self.ts_linear1 = nn.Linear(d_model, dim_feedforward)
        self.ts_linear2 = nn.Linear(dim_feedforward, d_model)
        self.ts_dropout1 = nn.Dropout(dropout)
        self.ts_dropout2 = nn.Dropout(dropout)
        self.ts_dropout3 = nn.Dropout(dropout)

    def attend(self, src, context_mask=None, **kwargs):
        """Run temporal self-attention on ``src``.

        Returns:
            ``(attention output, attention weights, scalar zero tensor)``; the
            trailing zero is a placeholder on ``src``'s device.
        """
        output, weights = self.self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        placeholder = torch.zeros((), device=src.device, dtype=torch.float)
        return output, weights, placeholder

    def spatial_attend(self, src, context_mask=None, **kwargs):
        r"""
        Attends over spatial dimension
        Args:
            src: spatiotemporal neural population input
            context_mask: spatial context mask
        Returns:
            ``(attention output, attention weights, scalar zero tensor)`` from the
            spatial attention module
        """
        output, weights = self.spatial_self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        placeholder = torch.zeros((), device=src.device, dtype=torch.float)
        return output, weights, placeholder

    def forward(self, src, spatial_src, src_mask=None, spatial_src_mask=None, src_key_padding_mask=None):
        """Apply the temporal block, then mix features with spatial attention weights.

        Args:
            src: Temporal input; assumed (time, batch, d_model) -- TODO confirm with caller.
            spatial_src: Spatial input; assumed (d_model, batch, d_model_s) so the
                spatial attention weights can multiply the permuted ``src``.
            src_mask: Attention mask for the temporal attention.
            spatial_src_mask: Attention mask for the spatial attention.
            src_key_padding_mask: Key padding mask for the temporal attention.

        Returns:
            Tensor with the same layout as ``src``.
        """
        # Temporal branch (pre-norm): attention and feed-forward, each with a residual.
        normed = self.norm1(src)
        temporal_out, _, _ = self.attend(
            normed, context_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        hidden = src + self.dropout1(temporal_out)
        feed = self.linear1(self.norm2(hidden))
        feed = self.linear2(self.dropout(self.activation(feed)))
        hidden = hidden + self.dropout2(feed)

        # Spatial branch: only the attention weights are used downstream.
        _, spatial_weights, _ = self.spatial_attend(
            self.spatial_norm1(spatial_src),
            context_mask=spatial_src_mask,
            key_padding_mask=None,
        )

        # Spatio-temporal mixture: batch-multiply the weights with the features
        # laid out batch-first, then restore the original layout.
        batch_first = self.ts_norm1(hidden).permute(1, 2, 0)
        mixed = torch.bmm(spatial_weights, batch_first).permute(2, 0, 1)
        mixed = hidden + self.ts_dropout1(mixed)

        feed = self.ts_linear1(self.ts_norm2(mixed))
        feed = self.ts_linear2(self.ts_dropout2(self.activation(feed)))
        return mixed + self.ts_dropout3(feed)

1.3 Encoder

model.py

py 复制代码
class STNTransformerEncoder(TransformerEncoder):
    """Stack of spatio-temporal encoder layers.

    The first layer receives the caller's ``spatial_src``; each later layer
    receives the previous layer's output with dims 0 and 2 swapped as its
    spatial input.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__(encoder_layer, num_layers, norm)

    def forward(self, src, spatial_src, mask=None, spatial_mask=None):
        """Run every layer in order and apply the optional final norm.

        Args:
            src: Temporal input passed to the first layer.
            spatial_src: Spatial input passed to the first layer only.
            mask: Forwarded to each layer as ``src_mask``.
            spatial_mask: Forwarded to each layer as ``spatial_src_mask``.

        Returns:
            Output of the last layer, normalised when ``self.norm`` is set.
        """
        hidden = src
        spatial_input = spatial_src
        for layer in self.layers:
            hidden = layer(hidden, spatial_input, src_mask=mask, spatial_src_mask=spatial_mask)
            # The next layer's spatial view is derived from this layer's output.
            spatial_input = hidden.permute(2, 1, 0)

        if self.norm is None:
            return hidden
        return self.norm(hidden)

1.4 STNDT

model.py

py 复制代码
class SpatioTemporalNDT(nn.Module):
    """Spatio-temporal transformer that reconstructs masked spike counts.

    The input is viewed two ways: time-major for temporal attention and
    neuron-major for spatial attention. Both views are scaled, given positional
    encodings and passed to the encoder stack; a linear decoder produces the
    rate predictions scored with a Poisson negative log-likelihood.

    Args:
        trial_length: Number of time steps per trial.
        num_neurons: Number of neurons (temporal embedding width).
        temperature: Stored as ``self.temperature``; not used in this class's forward.
        c_lambda: Stored as ``self.contrast_lambda``; not used in this class's forward.
        dropout: Dropout probability applied to the encoder output.
        pos_drop: Dropout probability inside both positional encoders.
        enc_layers: Number of encoder layers.
        log_rates: Passed to ``PoissonNLLLoss`` as ``log_input``.
        enc_heads: Attention head count in each encoder layer.
        enc_dff: Feed-forward hidden width in each encoder layer.
        enc_drop: Dropout probability inside each encoder layer.
        context_forward: How many future time steps each step may attend to.
        context_backward: How many past time steps each step may attend to.
    """

    def __init__(self, trial_length, num_neurons, temperature=0.1, c_lambda=0.3,
                 dropout=0.2, pos_drop=0.1, enc_layers=1, log_rates=True,
                 enc_heads=2, enc_dff=128, enc_drop=0.1,
                 context_forward=13, context_backward=79
                 ) -> None:
        super().__init__()

        # Cached temporal attention mask; rebuilt when length or device changes.
        self.src_mask = None
        self.context_forward = context_forward
        self.context_backward = context_backward
        self.num_input = num_neurons
        self.num_spatial_input = trial_length
        self.embedder = nn.Identity()
        self.spatial_embedder = nn.Identity()
        self.scale = math.sqrt(num_neurons)
        self.spatial_scale = math.sqrt(trial_length)
        self.src_pos_encoder = PositionalEncoding(trial_length, num_neurons, pos_drop)
        self.spatial_pos_encoder = PositionalEncoding(num_neurons, trial_length, pos_drop)

        self.projector = nn.Identity()
        self.spatial_projector = nn.Identity()
        self.n_views = 2
        self.temperature = temperature
        self.contrast_lambda = c_lambda
        self.cel = nn.CrossEntropyLoss(reduction='none')
        self.mse = nn.MSELoss(reduction='mean')

        encoder_layer = STNTransformerEncoderLayer(
            d_model=self.num_input,
            d_model_s=self.num_spatial_input,
            num_heads=enc_heads,
            dim_feedforward=enc_dff,
            dropout=enc_drop
        )
        self.transformer_encoder = STNTransformerEncoder(encoder_layer, enc_layers, nn.LayerNorm(self.num_input))

        self.rate_dropout = nn.Dropout(dropout)
        self.src_decoder = nn.Linear(num_neurons, self.num_input)
        self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=log_rates)

    def _get_mask(self, src, do_convert=True):
        """Return the additive temporal attention mask for ``src``.

        Query step ``i`` may attend to key steps ``j`` with
        ``i - context_backward <= j <= i + context_forward``; allowed entries
        are ``0.0`` and all others ``-inf``.

        The mask is cached on ``self.src_mask`` and reused while the sequence
        length and device still match ``src``.

        Args:
            src: Time-major input; only ``src.size(0)`` and ``src.device`` are read.
            do_convert: Accepted for interface compatibility; currently ignored.
        """
        size = src.size(0)
        cached = self.src_mask
        if cached is not None and cached.size(0) == size and cached.device == src.device:
            return cached

        # Build on the input's device so the mask can be combined with
        # attention scores without a cross-device error.
        ones = torch.ones(size, size, device=src.device)
        ahead_ok = (torch.triu(ones, diagonal=-self.context_forward) == 1).transpose(0, 1)
        behind_ok = torch.triu(ones, diagonal=-self.context_backward) == 1
        mask = (ahead_ok & behind_ok).float()
        self.src_mask = binary_mask_to_attn_mask(mask)
        return self.src_mask

    def forward(self, src: torch.Tensor, mask_labels: torch.Tensor):
        """Encode ``src`` and score the predictions at the labelled positions.

        Args:
            src: Input counts, expected as (batch, trial_length, num_neurons).
            mask_labels: Targets with the same layout as ``src``; entries equal
                to ``UNMASKED_LABEL`` are excluded from the loss.

        Returns:
            ``(masked_decoder_loss, decoder_rates)``: the mean Poisson NLL over
            the selected positions (NaN if none is selected) and the decoder
            output laid out like ``src``.
        """
        src = src.float()

        # Neuron-major view for the spatial attention path.
        spatial_src = src.permute(2, 0, 1)
        spatial_src = self.spatial_embedder(spatial_src) * self.spatial_scale
        spatial_src = self.spatial_pos_encoder(spatial_src)

        # Time-major view for the temporal attention path.
        src = src.permute(1, 0, 2)
        src = self.embedder(src) * self.scale
        src = self.src_pos_encoder(src)

        src_mask = self._get_mask(src)
        spatial_src_mask = None
        encoder_output = self.transformer_encoder(src, spatial_src, src_mask, spatial_src_mask)
        encoder_output = self.rate_dropout(encoder_output)
        decoder_output = self.src_decoder(encoder_output)

        # Back to batch-first so predictions line up with mask_labels.
        decoder_rates = decoder_output.permute(1, 0, 2)
        decoder_loss = self.classifier(decoder_rates, mask_labels)
        masked_decoder_loss = decoder_loss[mask_labels != UNMASKED_LABEL]
        masked_decoder_loss = masked_decoder_loss.mean()

        return masked_decoder_loss, decoder_rates


def binary_mask_to_attn_mask(x):
    """Convert a 0/1 mask into an additive float mask.

    Entries equal to 0 become ``-inf``, entries equal to 1 become ``0.0``;
    any other value is only cast to float.
    """
    blocked = x == 0
    allowed = x == 1
    additive = x.float().masked_fill(blocked, float('-inf'))
    return additive.masked_fill(allowed, float(0.0))

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906391

相关推荐
数据智能老司机3 小时前
PyTorch 深度学习——使用神经网络来拟合数据
pytorch·深度学习
数据智能老司机3 小时前
PyTorch 深度学习——用于图像的扩散模型
pytorch·深度学习
数据智能老司机3 小时前
PyTorch 深度学习——Transformer 是如何工作的
pytorch·深度学习
数据智能老司机1 天前
PyTorch 深度学习——使用张量表示真实世界数据
pytorch·深度学习
数据智能老司机1 天前
PyTorch 深度学习——它始于一个张量
pytorch·深度学习
Narrastory3 天前
明日香 - Pytorch 快速入门保姆级教程(三)
pytorch·深度学习
Narrastory6 天前
明日香 - Pytorch 快速入门保姆级教程(一)
人工智能·pytorch·深度学习
Narrastory6 天前
明日香 - Pytorch 快速入门保姆级教程(二)
人工智能·pytorch·深度学习
warm3snow10 天前
AI 核心技能系列:12 篇文章带你系统掌握大模型岗位必备技能
ai·transformer·agent·skill·mcp·fine-tunning
西岸行者11 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习