Simple-STNDT使用Transformer进行Spike信号的表征学习(二)模型结构

文章目录

    • [1.1 位置编码](#1.1 位置编码)
    • [1.2 EncoderLayer](#1.2 EncoderLayer)
    • [1.3 Encoder](#1.3 Encoder)
    • [1.4 STNDT](#1.4 STNDT)

1.1 位置编码

model.py

py 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
import math
# Label value marking positions excluded from the loss: SpatioTemporalNDT.forward
# keeps only the entries where mask_labels != UNMASKED_LABEL.
UNMASKED_LABEL = -100

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added along dim 0 of the input.

    Args:
        trial_length: number of positions for which encodings are precomputed.
        d_model: size of the last (feature) dimension of the encoded input.
        dropout: dropout probability applied after the encoding is added.
    """

    def __init__(self, trial_length, d_model, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        steps = torch.arange(0, trial_length, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = steps * freqs  # (trial_length, ceil(d_model / 2))

        table = torch.zeros(trial_length, d_model)
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so the last frequency is dropped for the cosines.
        table[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Stored as (trial_length, 1, d_model) so it broadcasts over dim 1 of the input.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Add the encodings of the first ``x.size(0)`` positions to ``x``, then apply dropout."""
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len, :])

1.2 EncoderLayer

model.py

核心编码层:在标准 Transformer 编码层(时间维度自注意力)的基础上,加入了空间注意力编码。

py 复制代码
class STNTransformerEncoderLayer(TransformerEncoderLayer):
    """Transformer encoder layer that re-mixes its temporal output with spatial attention weights.

    The layer runs the inherited pre-norm self-attention + feed-forward stack on
    ``src``, runs a second multi-head attention on ``spatial_src`` only to obtain
    its attention weights, and then applies those weights to the temporal output
    before a second feed-forward stack.

    Args:
        d_model: embedding size of the temporal attention (last dim of ``src``).
        d_model_s: embedding size of the spatial attention (last dim of ``spatial_src``).
        num_heads: number of heads used by both attention modules.
        dim_feedforward: hidden width of both feed-forward stacks.
        dropout: dropout probability used throughout the layer.
        activation: activation spec forwarded to ``TransformerEncoderLayer``.
    """

    def __init__(self, d_model, d_model_s, num_heads=2,  dim_feedforward=128, dropout=0.1,
                 activation='relu'):
        super().__init__(
            d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.num_heads = num_heads
        self.num_input = d_model
        # d_model_s: number of time steps (e.g. 160), used for the spatial self-attention
        self.d_model_s = d_model_s

        # Spatial branch: attention over spatial_src plus its input norm.
        self.spatial_self_attn = MultiheadAttention(embed_dim=d_model_s, num_heads=num_heads)
        self.spatial_norm1 = nn.LayerNorm(d_model_s)

        # Spatio-temporal mixing branch: norms, feed-forward and dropouts.
        self.ts_norm1 = nn.LayerNorm(d_model)
        self.ts_norm2 = nn.LayerNorm(d_model)
        self.ts_linear1 = nn.Linear(d_model, dim_feedforward)
        self.ts_linear2 = nn.Linear(dim_feedforward, d_model)
        self.ts_dropout1 = nn.Dropout(dropout)
        self.ts_dropout2 = nn.Dropout(dropout)
        self.ts_dropout3 = nn.Dropout(dropout)

    def attend(self, src, context_mask=None, **kwargs):
        """Self-attend over ``src`` with the inherited attention module.

        Returns:
            ``(attention output, attention weights, zero scalar tensor on src's device)``.
        """
        attended, weights = self.self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        zero = torch.tensor(0, device=src.device, dtype=torch.float)
        return attended, weights, zero

    def spatial_attend(self, src, context_mask=None, **kwargs):
        r"""
        Attends over spatial dimension
        Args:
            src: spatiotemporal neural population input
            context_mask: spatial context mask
        Returns:
            ``(attention output, attention weights, zero scalar tensor on src's device)``
        """
        attended, weights = self.spatial_self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        zero = torch.tensor(0, device=src.device, dtype=torch.float)
        return attended, weights, zero

    def forward(self, src, spatial_src, src_mask=None, spatial_src_mask=None, src_key_padding_mask=None):
        """Apply temporal attention, then mix the result with spatial attention weights.

        Args:
            src: temporal input; assumed (time, batch, d_model) — TODO confirm with caller.
            spatial_src: spatial input; assumed (d_model, batch, time) so that its
                attention weights are (batch, d_model, d_model) — TODO confirm with caller.
            src_mask: attention mask for the temporal attention.
            spatial_src_mask: attention mask for the spatial attention.
            src_key_padding_mask: key padding mask for the temporal attention.

        Returns:
            Tensor with the same layout as ``src``.
        """
        # Temporal branch: pre-norm self-attention with residual connection.
        time_attended, _, _ = self.attend(
            self.norm1(src), context_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = src + self.dropout1(time_attended)

        # Temporal branch: pre-norm feed-forward with residual connection.
        hidden = self.activation(self.linear1(self.norm2(src)))
        src = src + self.dropout2(self.linear2(self.dropout(hidden)))

        # Spatial branch: only the attention weights are used downstream.
        _, neuron_weights, _ = self.spatial_attend(
            self.spatial_norm1(spatial_src), context_mask=spatial_src_mask, key_padding_mask=None
        )

        # Mixing: batch-multiply the spatial weights with the (batch, feature, time)
        # view of the normalised temporal output, then restore the src layout.
        batch_first_view = self.ts_norm1(src).permute(1, 2, 0)
        mixed = torch.bmm(neuron_weights, batch_first_view).permute(2, 0, 1)
        mixed = src + self.ts_dropout1(mixed)

        # Second pre-norm feed-forward with residual connection.
        ff_hidden = self.activation(self.ts_linear1(self.ts_norm2(mixed)))
        ff_out = self.ts_linear2(self.ts_dropout2(ff_hidden))
        return mixed + self.ts_dropout3(ff_out)

1.3 Encoder

model.py

py 复制代码
class STNTransformerEncoder(TransformerEncoder):
    """Stack of spatio-temporal encoder layers fed with a temporal and a spatial view.

    Args:
        encoder_layer: layer instance that ``TransformerEncoder`` replicates.
        num_layers: number of stacked layers.
        norm: optional module applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__(encoder_layer, num_layers, norm)

    def forward(self, src, spatial_src, mask=None, spatial_mask=None):
        """Run ``src`` through every layer and the optional final norm.

        The first layer receives the caller-supplied ``spatial_src``; every later
        layer receives the previous layer's output with dims 0 and 2 swapped.
        """
        spatial_view = spatial_src
        for layer in self.layers:
            src = layer(src, spatial_view, src_mask=mask, spatial_src_mask=spatial_mask)
            # Spatial input for the next layer (if any) is derived from this output.
            spatial_view = src.permute(2, 1, 0)

        return src if self.norm is None else self.norm(src)

1.4 STNDT

model.py

py 复制代码
class SpatioTemporalNDT(nn.Module):
    """Spatio-temporal transformer trained with a masked Poisson reconstruction loss.

    The input is viewed twice: once with time as the sequence axis (temporal view)
    and once with neurons as the sequence axis (spatial view). Both views get a
    sinusoidal positional encoding and are fed to an ``STNTransformerEncoder``;
    a linear decoder maps the encoder output back to one value per neuron.

    Args:
        trial_length: number of time steps per trial; also the embedding size of
            the spatial view.
        num_neurons: number of neurons; also the embedding size of the temporal view.
        temperature: stored as ``self.temperature``; not used by this class.
        c_lambda: stored as ``self.contrast_lambda``; not used by this class.
        dropout: dropout probability applied to the encoder output.
        pos_drop: dropout probability of both positional encoders.
        enc_layers: number of encoder layers.
        log_rates: forwarded to ``nn.PoissonNLLLoss`` as ``log_input``.
        enc_heads: number of attention heads per encoder layer.
        enc_dff: feed-forward width of each encoder layer.
        enc_drop: dropout probability inside each encoder layer.
        context_forward: how many future time steps each position may attend to.
        context_backward: how many past time steps each position may attend to.
    """

    def __init__(self, trial_length, num_neurons, temperature=0.1, c_lambda=0.3,
                 dropout=0.2, pos_drop=0.1, enc_layers=1, log_rates=True,
                 enc_heads=2,  enc_dff=128, enc_drop=0.1,
                 context_forward=13, context_backward=79,
                 ) -> None:
        super().__init__()

        # Cached temporal attention mask; rebuilt by _get_mask when stale.
        self.src_mask = None
        self.context_forward = context_forward
        self.context_backward = context_backward

        self.num_input = num_neurons
        self.num_spatial_input = trial_length
        self.embedder = nn.Identity()
        self.spatial_embedder = nn.Identity()
        self.scale = math.sqrt(num_neurons)
        self.spatial_scale = math.sqrt(trial_length)
        self.src_pos_encoder = PositionalEncoding(trial_length, num_neurons, pos_drop)
        self.spatial_pos_encoder = PositionalEncoding(num_neurons, trial_length, pos_drop)

        self.projector = nn.Identity()
        self.spatial_projector = nn.Identity()
        self.n_views = 2
        self.temperature = temperature
        self.contrast_lambda = c_lambda
        self.cel = nn.CrossEntropyLoss(reduction='none')
        self.mse = nn.MSELoss(reduction='mean')

        encoder_layer = STNTransformerEncoderLayer(
            d_model=self.num_input,
            d_model_s=self.num_spatial_input,
            num_heads=enc_heads,
            dim_feedforward=enc_dff,
            dropout=enc_drop
        )
        self.transformer_encoder = STNTransformerEncoder(encoder_layer, enc_layers, nn.LayerNorm(self.num_input))

        self.rate_dropout = nn.Dropout(dropout)
        self.src_decoder = nn.Linear(num_neurons, self.num_input)
        self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=log_rates)

    def _get_mask(self, src, do_convert=True):
        """Return the additive temporal attention mask for ``src``.

        Position ``i`` may attend to positions ``j`` with
        ``i - context_backward <= j <= i + context_forward``; allowed entries are
        0.0 and all others ``-inf``.

        The mask is cached on ``self.src_mask`` and reused only while its size and
        device still match ``src``; otherwise it is rebuilt on ``src.device``.

        Args:
            src: temporal input whose dim 0 is the sequence length.
            do_convert: unused; kept so existing callers keep working.
        """
        size = src.size(0)
        cached = self.src_mask
        if cached is not None and cached.size(0) == size and cached.device == src.device:
            return cached

        # Build directly on the input's device so attention never mixes devices.
        index = torch.arange(size, device=src.device)
        offset = index.unsqueeze(0) - index.unsqueeze(1)  # offset[i, j] = j - i
        allowed = (offset <= self.context_forward) & (offset >= -self.context_backward)

        self.src_mask = binary_mask_to_attn_mask(allowed.float())
        return self.src_mask

    def forward(self, src: torch.Tensor, mask_labels: torch.Tensor):
        """Encode ``src`` and compute the Poisson loss on the labelled positions.

        Args:
            src: input activity, indexed as (batch, trial_length, num_neurons)
                judging by the permutes below — TODO confirm against the data loader.
            mask_labels: targets with the same layout as ``src``; entries equal to
                ``UNMASKED_LABEL`` are excluded from the loss.

        Returns:
            ``(masked_decoder_loss, decoder_rates)``: the mean loss over the
            labelled entries and the decoder output in the layout of ``src``.
            If no entry is labelled, the mean is taken over an empty tensor (NaN).
        """
        src = src.float()

        # Spatial view: neurons become the sequence axis, time the embedding axis.
        spatial_src = src.permute(2, 0, 1)
        spatial_src = self.spatial_embedder(spatial_src) * self.spatial_scale
        spatial_src = self.spatial_pos_encoder(spatial_src)

        # Temporal view: time becomes the sequence axis, neurons the embedding axis.
        src = src.permute(1, 0, 2)
        src = self.embedder(src) * self.scale
        src = self.src_pos_encoder(src)

        src_mask = self._get_mask(src)
        spatial_src_mask = None  # the spatial attention is unmasked
        encoder_output = self.transformer_encoder(src, spatial_src, src_mask, spatial_src_mask)
        encoder_output = self.rate_dropout(encoder_output)
        decoder_output = self.src_decoder(encoder_output)

        # Back to the layout of the original input before comparing with the labels.
        decoder_rates = decoder_output.permute(1, 0, 2)
        decoder_loss = self.classifier(decoder_rates, mask_labels)
        masked_decoder_loss = decoder_loss[mask_labels != UNMASKED_LABEL]
        masked_decoder_loss = masked_decoder_loss.mean()

        return masked_decoder_loss, decoder_rates


def binary_mask_to_attn_mask(x):
    """Convert a 0/1 mask into an additive float attention mask.

    Entries equal to 0 become ``-inf`` and entries equal to 1 become ``0.0``;
    any other value is kept as its float conversion.
    """
    blocked = x == 0
    allowed = x == 1
    additive = x.float().masked_fill(blocked, float('-inf'))
    return additive.masked_fill(allowed, float(0.0))

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906391

相关推荐
warm3snow3 天前
AI 核心技能系列:12 篇文章带你系统掌握大模型岗位必备技能
ai·transformer·agent·skill·mcp·fine-tunning
西岸行者3 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
homelook4 天前
Transformer与电池管理系统(BMS)的结合是当前 智能电池管理 的前沿研究方向
人工智能·深度学习·transformer
悠哉悠哉愿意4 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码4 天前
嵌入式学习路线
学习
毛小茛4 天前
计算机系统概论——校验码
学习
babe小鑫4 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms4 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下4 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。4 天前
2026.2.25监控学习
学习