Table of Contents

- 1.1 Positional Encoding
- 1.2 EncoderLayer
- 1.3 Encoder
- 1.4 STNDT
1.1 Positional Encoding
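This is the standard sinusoidal positional encoding from the original Transformer paper, precomputed once as a fixed table and added to the input:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)
$$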
model.py
```py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
import math

UNMASKED_LABEL = -100  # label value marking positions that are NOT predicted


class PositionalEncoding(nn.Module):
    def __init__(self, trial_length, d_model, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute the sinusoidal table, shape (trial_length, d_model)
        pe = torch.zeros(trial_length, d_model)
        position = torch.arange(0, trial_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 0:
            pe[:, 1::2] = torch.cos(position * div_term)
        else:
            # Odd d_model: the cosine channels have one fewer column
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (trial_length, 1, d_model), sequence-first
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model); add the encoding for the first seq_len positions
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
```
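As a quick sanity check (the dimensions here are illustrative, not from the original post), the module expects sequence-first input of shape `(seq_len, batch, d_model)`:

```py
pos_enc = PositionalEncoding(trial_length=160, d_model=98, dropout=0.1)
x = torch.zeros(160, 8, 98)   # (seq_len, batch, d_model)
print(pos_enc(x).shape)       # torch.Size([160, 8, 98])
```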
1.2 EncoderLayer
model.py
The core encoder layer, which augments the standard temporal self-attention with spatial (neuron-wise) self-attention.
```py
class STNTransformerEncoderLayer(TransformerEncoderLayer):
    def __init__(self, d_model, d_model_s, num_heads=2, dim_feedforward=128, dropout=0.1,
                 activation='relu'):
        super().__init__(
            d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation
        )
        self.num_heads = num_heads
        self.num_input = d_model
        self.d_model_s = d_model_s  # number of time steps (e.g. 160), embed dim of the spatial self-attention
        self.spatial_self_attn = MultiheadAttention(embed_dim=d_model_s, num_heads=num_heads)
        self.spatial_norm1 = nn.LayerNorm(d_model_s)
        self.ts_norm1 = nn.LayerNorm(d_model)
        self.ts_norm2 = nn.LayerNorm(d_model)
        self.ts_linear1 = nn.Linear(d_model, dim_feedforward)
        self.ts_linear2 = nn.Linear(dim_feedforward, d_model)
        self.ts_dropout1 = nn.Dropout(dropout)
        self.ts_dropout2 = nn.Dropout(dropout)
        self.ts_dropout3 = nn.Dropout(dropout)

    def attend(self, src, context_mask=None, **kwargs):
        # Temporal self-attention over the time dimension
        attn_res = self.self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        return (*attn_res, torch.tensor(0, device=src.device, dtype=torch.float))

    def spatial_attend(self, src, context_mask=None, **kwargs):
        r"""
        Attends over the spatial (neuron) dimension.
        Args:
            src: spatiotemporal neural population input, shape (num_neurons, batch, trial_length)
            context_mask: spatial context mask
        Returns:
            spatiotemporal neural population activity transformed by spatial attention
        """
        attn_res = self.spatial_self_attn(src, src, src, attn_mask=context_mask, **kwargs)
        return (*attn_res, torch.tensor(0, device=src.device, dtype=torch.float))

    def forward(self, src, spatial_src, src_mask=None, spatial_src_mask=None, src_key_padding_mask=None):
        # Temporal branch: pre-norm attention + feedforward, each with a residual
        residual = src
        src = self.norm1(src)
        t_out, t_weights, _ = self.attend(src, context_mask=src_mask, key_padding_mask=src_key_padding_mask)
        src = residual + self.dropout1(t_out)
        residual = src
        src = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = residual + self.dropout2(src2)

        # Spatial branch: only the neuron-neuron attention weights are used below
        spatial_src = self.spatial_norm1(spatial_src)
        spatial_out, spatial_weights, _ = self.spatial_attend(spatial_src, context_mask=spatial_src_mask, key_padding_mask=None)

        # Spatio-temporal feature mixture: reweight the temporal features by the
        # spatial attention map. spatial_weights: (batch, N, N);
        # src: (T, batch, N) -> (batch, N, T) for bmm, then back to (T, batch, N)
        ts_residual = src
        src = self.ts_norm1(src)
        ts_out = torch.bmm(spatial_weights, src.permute(1, 2, 0)).permute(2, 0, 1)
        ts_out = ts_residual + self.ts_dropout1(ts_out)
        ts_residual = ts_out
        ts_out = self.ts_norm2(ts_out)
        ts_out = self.ts_linear2(self.ts_dropout2(self.activation(self.ts_linear1(ts_out))))
        ts_out = ts_residual + self.ts_dropout3(ts_out)
        return ts_out
```
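To see how the two views fit together, here is a small smoke test (batch size, T = 160 time steps, and N = 98 neurons are illustrative values, not from the original post). The temporal stream is `(T, batch, N)`, while the spatial stream is the same tensor viewed as `(N, batch, T)`:

```py
layer = STNTransformerEncoderLayer(d_model=98, d_model_s=160, num_heads=2)
src = torch.randn(160, 8, 98)        # temporal view: (T, batch, N)
spatial_src = src.permute(2, 1, 0)   # spatial view: (N, batch, T)
out = layer(src, spatial_src)
print(out.shape)                     # torch.Size([160, 8, 98])
```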
1.3 Encoder
model.py
```py
class STNTransformerEncoder(TransformerEncoder):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__(encoder_layer, num_layers, norm)

    def forward(self, src, spatial_src, mask=None, spatial_mask=None):
        for i, mod in enumerate(self.layers):
            if i == 0:
                src = mod(src, spatial_src, src_mask=mask, spatial_src_mask=spatial_mask)
            else:
                # After the first layer, rebuild the spatial view (N, B, T)
                # from the current temporal output (T, B, N)
                src = mod(src, src.permute(2, 1, 0), src_mask=mask, spatial_src_mask=spatial_mask)
        if self.norm is not None:
            src = self.norm(src)
        return src
```
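A quick usage sketch (dimensions again illustrative): `TransformerEncoder` clones the layer `num_layers` times, and the overridden `forward` only needs an explicit spatial view for the first layer; later layers derive it from their own output via the permute above.

```py
layer = STNTransformerEncoderLayer(d_model=98, d_model_s=160, num_heads=2)
encoder = STNTransformerEncoder(layer, num_layers=2, norm=nn.LayerNorm(98))
src = torch.randn(160, 8, 98)
out = encoder(src, src.permute(2, 1, 0))
print(out.shape)   # torch.Size([160, 8, 98])
```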
1.4 STNDT
model.py
```py
class SpatioTemporalNDT(nn.Module):
    def __init__(self, trial_length, num_neurons, temperature=0.1, c_lambda=0.3,
                 dropout=0.2, pos_drop=0.1, enc_layers=1, log_rates=True,
                 enc_heads=2, enc_dff=128, enc_drop=0.1
                 ) -> None:
        super().__init__()
        self.src_mask = None
        self.num_input = num_neurons
        self.num_spatial_input = trial_length
        self.embedder = nn.Identity()
        self.spatial_embedder = nn.Identity()
        self.scale = math.sqrt(num_neurons)
        self.spatial_scale = math.sqrt(trial_length)
        # Temporal positions are encoded over time steps, spatial positions over neurons
        self.src_pos_encoder = PositionalEncoding(trial_length, num_neurons, pos_drop)
        self.spatial_pos_encoder = PositionalEncoding(num_neurons, trial_length, pos_drop)
        self.projector = nn.Identity()
        self.spatial_projector = nn.Identity()
        # Contrastive-loss components (not used in this simplified forward pass)
        self.n_views = 2
        self.temperature = temperature
        self.contrast_lambda = c_lambda
        self.cel = nn.CrossEntropyLoss(reduction='none')
        self.mse = nn.MSELoss(reduction='mean')
        encoder_layer = STNTransformerEncoderLayer(
            d_model=self.num_input,
            d_model_s=self.num_spatial_input,
            num_heads=enc_heads,
            dim_feedforward=enc_dff,
            dropout=enc_drop
        )
        self.transformer_encoder = STNTransformerEncoder(encoder_layer, enc_layers, nn.LayerNorm(self.num_input))
        self.rate_dropout = nn.Dropout(dropout)
        self.src_decoder = nn.Linear(num_neurons, self.num_input)
        self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=log_rates)

    def _get_mask(self, src, do_convert=True):
        # Banded temporal attention mask: each step may attend up to 13 steps
        # forward and 79 steps backward; cached after the first call
        if self.src_mask is not None:
            return self.src_mask
        size = src.size(0)
        context_forward = 13
        context_backward = 79
        mask = (torch.triu(torch.ones(size, size), diagonal=-context_forward) == 1).transpose(0, 1)
        back_mask = (torch.triu(torch.ones(size, size), diagonal=-context_backward) == 1)
        mask = mask & back_mask
        mask = mask.float()
        mask = binary_mask_to_attn_mask(mask)
        self.src_mask = mask
        return self.src_mask

    def forward(self, src: torch.Tensor, mask_labels: torch.Tensor):
        # src: (batch, trial_length, num_neurons) binned spike counts
        src = src.float()
        spatial_src = src.permute(2, 0, 1)  # (num_neurons, batch, trial_length)
        spatial_src = self.spatial_embedder(spatial_src) * self.spatial_scale
        spatial_src = self.spatial_pos_encoder(spatial_src)
        src = src.permute(1, 0, 2)  # (trial_length, batch, num_neurons)
        src = self.embedder(src) * self.scale
        src = self.src_pos_encoder(src)
        src_mask = self._get_mask(src)
        spatial_src_mask = None
        encoder_output = self.transformer_encoder(src, spatial_src, src_mask, spatial_src_mask)
        encoder_output = self.rate_dropout(encoder_output)
        decoder_output = self.src_decoder(encoder_output)
        decoder_rates = decoder_output.permute(1, 0, 2)  # back to (batch, trial_length, num_neurons)
        # Poisson NLL only on masked positions; unmasked bins carry UNMASKED_LABEL
        decoder_loss = self.classifier(decoder_rates, mask_labels)
        masked_decoder_loss = decoder_loss[mask_labels != UNMASKED_LABEL]
        masked_decoder_loss = masked_decoder_loss.mean()
        return masked_decoder_loss, decoder_rates


def binary_mask_to_attn_mask(x):
    # 1 -> 0.0 (attend), 0 -> -inf (blocked), as expected by MultiheadAttention
    return x.float().masked_fill(x == 0, float('-inf')).masked_fill(x == 1, float(0.0))
```
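Finally, an end-to-end sketch (shapes and mask fraction are illustrative assumptions). `mask_labels` holds the true spike counts at masked positions and `UNMASKED_LABEL` everywhere else; note that in real masked-modeling training the masked input bins would also be zeroed or corrupted, while this sketch only exercises the shapes and the loss plumbing:

```py
B, T, N = 8, 160, 98
model = SpatioTemporalNDT(trial_length=T, num_neurons=N)
spikes = torch.poisson(torch.ones(B, T, N))            # fake binned spike counts
mask_labels = torch.full((B, T, N), float(UNMASKED_LABEL))
masked = torch.rand(B, T, N) < 0.25                    # mask ~25% of positions
mask_labels[masked] = spikes[masked]                   # predict only masked bins
loss, rates = model(spikes, mask_labels)
print(loss.item(), rates.shape)                        # scalar loss, torch.Size([8, 160, 98])
```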
Next post: https://blog.csdn.net/weixin_46866349/article/details/139906391