Human Skeleton Recognition Paper Reading: ST-TR (Skeleton-Based Action Recognition via Spatial and Temporal Transformer Networks)

Contents

Summary

Abstract

[1 Introduction](#1 Introduction)

[2 Spatial-Temporal Transformer Network](#2 Spatial-Temporal Transformer Network)

[2.1 Spatial Self-Attention (SSA)](#2.1 Spatial Self-Attention (SSA))

[2.2 Temporal Self-Attention (TSA)](#2.2 Temporal Self-Attention (TSA))

[2.3 Two-Stream Spatial-Temporal Transformer Network](#2.3 Two-Stream Spatial-Temporal Transformer Network)

[2.4 Implementation of SSA and TSA](#2.4 Implementation of SSA and TSA)

[3 ST-TR Network Architecture](#3 ST-TR Network Architecture)

Conclusion


Summary

The paper read this week is "Skeleton-based action recognition via spatial and temporal transformer networks". ST-GCN, studied in the previous weeks, and its improved successors 2s-AGCN and DGNN are effective at modeling the spatial and temporal dependencies of non-Euclidean data such as skeleton graphs, but they still cannot adequately extract and encode the latent information contained in 3D skeletons. To address this, the paper proposes a new Spatial-Temporal Transformer network (ST-TR) that uses the Transformer self-attention operator to model dependencies between joints: a Spatial Self-Attention (SSA) module captures intra-frame interactions between different body parts, and a Temporal Self-Attention (TSA) module models inter-frame correlations. SSA and TSA are then combined in a two-stream network, which consistently improves recognition accuracy on skeleton data.

Abstract

The title of the paper read this week is "Skeleton-based action recognition via spatial and temporal transformer networks". In the previous weeks, ST-GCN and its ST-GCN-based improvements 2s-AGCN and DGNN have proven effective at handling the spatial and temporal dependencies of non-Euclidean data such as skeletal graphs. However, they still cannot effectively extract and encode the latent information in 3D skeletons. Therefore, this paper proposes a new spatial-temporal Transformer network (ST-TR), which uses the Transformer self-attention operator to model the dependency relationships between joints: a Spatial Self-Attention (SSA) module captures the intra-frame interactions between different body parts, and a Temporal Self-Attention (TSA) module models inter-frame correlations. SSA and TSA are then combined in a two-stream network, consistently improving recognition results on skeleton data.

Paper link 🔗: Skeleton-based action recognition via spatial and temporal transformer networks

1 Introduction

Today, the most widely used methods for skeleton-based action recognition are GNNs, and GCNs in particular, since they represent non-Euclidean data effectively and can capture both spatial (intra-frame) and temporal (inter-frame) information. GCN models were first introduced to skeleton-based action recognition with ST-GCN, which handles spatial information by operating on the bone connections of the skeleton and captures temporal information through additional temporal edges that link each joint across consecutive frames. Although ST-GCN has been shown to perform well on skeleton data, it has several structural limitations; some of them have been addressed by methods such as 2s-AGCN and DGNN, but the following problems remain:

  • The graph topology that represents the human body is fixed for all layers and all actions. This may prevent the extraction of rich representations of skeleton motion over time, especially when the graph edges are directed and information can only flow along predefined paths;
  • Both the spatial and the temporal convolutions are implemented as standard 2D convolutions. They are therefore restricted to operate on local neighborhoods and are limited, to some extent, by the kernel size;
  • Correlations between joints that are not connected in the human skeleton, which matter for actions such as "clapping", tend to be underestimated.

To address all of these problems and limitations, this paper adopts a modified Transformer self-attention operator, applied directly to the skeleton joints as shown in the figure below:

The steps are as follows (a minimal sketch of this computation is given after the list):

  1. For every body joint, compute a query $q$, a key $k$ and a value vector $v$;
  2. Take the dot product of the joint's query with the keys of all other nodes ($q \cdot k$), which expresses the strength of the connection between each pair of nodes;
  3. Scale the value of each node by its correlation with the current node;
  4. Sum the weighted values to obtain the new feature of the node.
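
A minimal PyTorch sketch of these four steps for a single frame is given below; the dimensions, the $\sqrt{d_k}$ scaling and the random weight matrices are illustrative stand-ins rather than the paper's released implementation (which appears in Section 2.1).

```python
import torch
import torch.nn.functional as F

# Minimal sketch of the four steps above for a single frame.
# V joints, C_in input channels; d_k/d_v and the random weights are illustrative only.
V, C_in, d_k, d_v = 25, 3, 8, 8
x = torch.randn(V, C_in)                  # features of the V joints in one frame

W_q = torch.randn(C_in, d_k)              # trainable projections, shared by all joints
W_k = torch.randn(C_in, d_k)
W_v = torch.randn(C_in, d_v)

q, k, v = x @ W_q, x @ W_k, x @ W_v       # step 1: query, key, value per joint
scores = (q @ k.t()) / d_k ** 0.5         # step 2: pairwise dot products -> (V, V)
weights = F.softmax(scores, dim=-1)       # step 3: scale each node by its correlation
z = weights @ v                           # step 4: weighted sum -> new joint features (V, d_v)
```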

The sequential and hierarchical structure of human skeleton sequences, together with the flexibility of Transformer self-attention in modeling long-range dependencies, makes the Transformer a natural solution to ST-GCN's weaknesses. This paper aims to apply the Transformer to a spatial-temporal skeleton architecture, specifically to the joints that represent the human skeleton. The goal is to model long-range interactions in human actions by covering the spatial dimension with a Spatial Self-Attention (SSA) module and the temporal dimension with a Temporal Self-Attention (TSA) module, and thereby achieve strong performance.

2 Spatial-Temporal Transformer Network

The Spatial-Temporal Transformer (ST-TR) network proposed in this paper is an architecture that uses Transformer self-attention to operate on both the spatial and the temporal dimension. This is achieved with two modules, the Spatial Self-Attention (SSA) module and the Temporal Self-Attention (TSA) module, each of which focuses on extracting correlations along one of the two dimensions.

The idea behind the original Transformer self-attention is to encode both short-range and long-range correlations between the words of a sentence. The same approach can be applied to skeleton-based action recognition, where the correlations between nodes are crucial along both the spatial and the temporal dimension. The joints that make up the skeleton are treated as a bag of tokens, and Transformer self-attention is used to extract node embeddings that encode the relations with the surrounding joints, just like the words of a phrase in NLP.

Unlike standard graph convolution, where only neighboring nodes are compared, ST-TR discards any predefined skeleton structure and lets the Transformer self-attention automatically discover the joint relations that are relevant for predicting the current action. The operation behaves like a graph convolution whose kernel values are dynamically predicted from the discovered joint relations. The same idea is applied at the sequence level: the evolution of each joint throughout the action is analyzed, and long-range relations across different frames are built, similarly to how relations between phrases are built in natural language processing. The resulting operator obtains dynamic representations that extend over both the spatial and the temporal dimension.

2.1 Spatial Self-Attention (SSA)

The Spatial Self-Attention module applies self-attention inside each frame to extract low-level features that embed the relations between body parts. This is done by computing, independently for every frame, the correlation between each pair of joints, as shown in the figure below:

Given the frame at time $t$, for every node $v_i$ of the skeleton (the corresponding equations are written out after the list):

  1. First, a query vector $q_i^t$, a key vector $k_i^t$ and a value vector $v_i^t$ are computed by applying trainable linear transformations to the node features $f_i^t$ (the parameters $W_q$, $W_k$, $W_v$ are shared by all nodes);
  2. Then, for every pair of body nodes $(v_i, v_j)$, a query-key dot product yields a weight $\alpha_{ij}^t$ that expresses the strength of the correlation between the two nodes;
  3. Finally, the resulting scores $\alpha_{ij}^t$ are used to weight each joint value $v_j^t$, and the weighted sum gives the new embedding $z_i^t$ of node $v_i$ (with $C_{out}$ output channels).
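
Written out, the three steps correspond to the following equations (as I reconstruct them from the paper's SSA formulation):

$$
\mathbf{q}_i^t=\mathbf{W}_q\,\mathbf{f}_i^t,\qquad \mathbf{k}_i^t=\mathbf{W}_k\,\mathbf{f}_i^t,\qquad \mathbf{v}_i^t=\mathbf{W}_v\,\mathbf{f}_i^t
$$

$$
\alpha_{ij}^t=\mathbf{q}_i^t\cdot\mathbf{k}_j^t,\qquad \mathbf{z}_i^t=\sum_{j}\operatorname{softmax}_j\!\left(\frac{\alpha_{ij}^t}{\sqrt{d_k}}\right)\mathbf{v}_j^t
$$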

The code is as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
multi_matmul = False

class spatial_attention(nn.Module):
    def __init__(self, in_channels, kernel_size, dk, dv, Nh, complete, relative, layer, A, more_channels, drop_connect,
                 adjacency, num, num_point,
                 shape=25, stride=1,
                 last_graph=False, data_normalization=True, skip_conn=True, visualization=True):
        super(spatial_attention, self).__init__()
        self.in_channels = in_channels
        self.complete = complete
        self.kernel_size = 1
        self.dk = dk
        self.dv = dv
        self.num = num
        self.layer = layer
        self.more_channels = more_channels
        self.drop_connect = drop_connect
        self.visualization = visualization
        self.data_normalization = data_normalization
        self.skip_conn = skip_conn
        self.adjacency = adjacency
        self.Nh = Nh
        self.num_point=num_point
        self.A = A[0] + A[1] + A[2]
        if self.adjacency:
            self.mask = nn.Parameter(torch.ones(self.A.size()))
        self.shape = shape
        self.relative = relative
        self.last_graph = last_graph
        self.stride = stride
        self.padding = (self.kernel_size - 1) // 2

        assert self.Nh != 0, "integer division or modulo by zero, Nh >= 1"
        assert self.dk % self.Nh == 0, "dk should be divisible by Nh (e.g. out_channels: 20, dk: 40, Nh: 4)"
        assert self.dv % self.Nh == 0, "dv should be divisible by Nh (e.g. out_channels: 20, dv: 4, Nh: 4)"
        assert stride in [1, 2], str(stride) + " Up to 2 strides are allowed."


        if (self.more_channels):

            self.qkv_conv = nn.Conv2d(self.in_channels, (2 * self.dk + self.dv) * self.Nh // self.num,
                                      kernel_size=self.kernel_size,
                                      stride=stride,
                                      padding=self.padding)
        else:
            self.qkv_conv = nn.Conv2d(self.in_channels, 2 * self.dk + self.dv, kernel_size=self.kernel_size,
                                      stride=stride,
                                      padding=self.padding)
        if (self.more_channels):

            self.attn_out = nn.Conv2d(self.dv * self.Nh // self.num, self.dv, kernel_size=1, stride=1)
        else:
            self.attn_out = nn.Conv2d(self.dv, self.dv, kernel_size=1, stride=1)

        if self.relative:
            # Initialize the two parameter sets used for relative positional encoding:
            # a single weight repeated along the diagonal,
            # and V^2 - V parameters for the positions outside the diagonal
            if self.more_channels:
                self.key_rel = nn.Parameter(torch.randn(((self.num_point ** 2) - self.num_point, self.dk // self.num), requires_grad=True))
            else:
                self.key_rel = nn.Parameter(torch.randn(((self.num_point ** 2) - self.num_point, self.dk // Nh), requires_grad=True))
            if self.more_channels:
                self.key_rel_diagonal = nn.Parameter(torch.randn((1, self.dk // self.num), requires_grad=True))
            else:
                self.key_rel_diagonal = nn.Parameter(torch.randn((1, self.dk // self.Nh), requires_grad=True))

    def forward(self, x, label, name):
        # Input x
        # (batch_size, channels, 1, joints)
        B, _, T, V = x.size()

        # Compute queries, keys and values
        # flat_q, flat_k, flat_v
        # (batch_size, Nh, dvh or dkh, joints)
        # dvh = dv / Nh, dkh = dk / Nh
        # q, k, v obtained by doing 2D convolution on the input (q=XWq, k=XWk, v=XWv)
        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)

        # Compute the attention scores via q*k
        # (batch_size, Nh, joints, dkh)*(batch_size, Nh, dkh, joints) = (batch_size, Nh, joints, joints)
        # If memory is an issue, the multiplication can be split into chunks (multi_matmul)
        if (multi_matmul):
            for i in range(0, 5):
                flat_q_5 = flat_q[:, :, :, (5 * i):(5 * (i + 1))]
                product = torch.matmul(flat_q_5.transpose(2, 3), flat_k)
                if (i == 0):
                    logits = product
                else:
                    logits = torch.cat((logits, product), dim=2)
        else:
            logits = torch.matmul(flat_q.transpose(2, 3), flat_k)

        # The adjacency matrix is weighted and added to the Transformer attention logits
        # to inject information about the original skeleton structure
        if (self.adjacency):
            self.A = self.A.cuda(device)
            logits = logits.reshape(-1, V, V)
            M, V, V = logits.shape
            A = self.A
            A *= self.mask
            A = A.unsqueeze(0).expand(M, V, V)
            logits = logits+A
            logits = logits.reshape(B, self.Nh, V, V)

        # With or without relative positional encoding
        if self.relative:
            rel_logits = self.relative_logits(q)
            logits_sum = torch.add(logits, rel_logits)

        # Compute the attention weights
        if self.relative:
            weights = F.softmax(logits_sum, dim=-1)
        else:
            weights = F.softmax(logits, dim=-1)

        # DropConnect implementation to avoid overfitting
        if (self.drop_connect and self.training):
            mask = torch.bernoulli((0.5) * torch.ones(B * self.Nh * V, device=device))
            mask = mask.reshape(B, self.Nh, V).unsqueeze(2).expand(B, self.Nh, V, V)
            weights = weights * mask
            weights = weights / (weights.sum(3, keepdim=True) + 1e-8)



        # attn_out
        # (batch, Nh, joints, dvh)
        # weights*V
        # (batch, Nh, joints, joints)*(batch, Nh, joints, dvh)=(batch, Nh, joints, dvh)
        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))

        if not self.more_channels:
            attn_out = torch.reshape(attn_out, (B, self.Nh, T, V, self.dv // self.Nh))
        else:
            attn_out = torch.reshape(attn_out, (B, self.Nh, T, V, self.dv // self.num))

        attn_out = attn_out.permute(0, 1, 4, 2, 3)

        # combine_heads_2d, the heads are merged only after each Z has been computed separately
        # (batch, Nh*dv, 1, joints)
        attn_out = self.combine_heads_2d(attn_out)

        # Multiply by W0; output shape (batch, out_channels, 1, joints) where out_channels = dv
        attn_out = self.attn_out(attn_out)
        return attn_out

    def compute_flat_qkv(self, x, dk, dv, Nh):
        qkv = self.qkv_conv(x)
        # T=1 in this case, since each frame is considered separately
        N, _, T, V = qkv.size()

        # If self.more_channels=True, each head is assigned dk*self.Nh//self.num channels
        if self.more_channels:
            q, k, v = torch.split(qkv, [dk * self.Nh // self.num, dk * self.Nh // self.num, dv * self.Nh // self.num],
                                  dim=1)
        else:
            q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)
        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)

        dkh = dk // Nh
        q = q*(dkh ** -0.5)
        if self.more_channels:
            flat_q = torch.reshape(q, (N, Nh, dk // self.num, T * V))
            flat_k = torch.reshape(k, (N, Nh, dk // self.num, T * V))
            flat_v = torch.reshape(v, (N, Nh, dv // self.num, T * V))
        else:
            flat_q = torch.reshape(q, (N, Nh, dkh, T * V))
            flat_k = torch.reshape(k, (N, Nh, dkh, T * V))
            flat_v = torch.reshape(v, (N, Nh, dv // self.Nh, T * V))
        return flat_q, flat_k, flat_v, q, k, v

    def split_heads_2d(self, x, Nh):
        B, channels, T, V = x.size()
        ret_shape = (B, Nh, channels // Nh, T, V)
        split = torch.reshape(x, ret_shape)
        return split

    def combine_heads_2d(self, x):
        batch, Nh, dv, T, V = x.size()
        ret_shape = (batch, Nh * dv, T, V)
        return torch.reshape(x, ret_shape)

    def relative_logits(self, q):
        B, Nh, dk, T, V = q.size()
        q = torch.transpose(q, 2, 4).transpose(2, 3)
        q_first = q.unsqueeze(4).expand((B, Nh, T, V, V - 1, dk))
        q_first = torch.reshape(q_first, (B * Nh * T, -1, dk))

        # q is multiplied with the embedding of the diagonal parameter
        q = torch.reshape(q, (B * Nh * T, V, dk))
        # key_rel_diagonal: (1, dk) -> (V, dk)
        param_diagonal = self.key_rel_diagonal.expand((V, dk))
        rel_logits = self.relative_logits_1d(q_first, q, self.key_rel, param_diagonal, T, V, Nh)
        return rel_logits

    def relative_logits_1d(self, q_first, q, rel_k, param_diagonal, T, V, Nh):
        # Compute the relative logits along one dimension
        # (B*Nh*1,V^2-V, self.dk // Nh)*(V^2 - V, self.dk // Nh)

        # (B*Nh*1, V^2-V)
        rel_logits = torch.einsum('bmd,md->bm', q_first, rel_k)
        # (B*Nh*1, V)
        rel_logits_diagonal = torch.einsum('bmd,md->bm', q, param_diagonal)

        # Reshape to obtain Srel
        rel_logits = self.rel_to_abs(rel_logits, rel_logits_diagonal)

        rel_logits = torch.reshape(rel_logits, (-1, Nh, V, V))
        return rel_logits

    def rel_to_abs(self, rel_logits, rel_logits_diagonal):
        B, L = rel_logits.size()
        B, V = rel_logits_diagonal.size()

        # (B, V-1, V) -> (B, V, V)
        rel_logits = torch.reshape(rel_logits, (B, V - 1, V))
        row_pad = torch.zeros(B, 1, V).to(rel_logits)
        rel_logits = torch.cat((rel_logits, row_pad), dim=1)

        # Concatenate the other embedding on the left
        # (B, V, V) -> (B, V, V+1) -> (B, V+1, V)
        rel_logits_diagonal = torch.reshape(rel_logits_diagonal, (B, V, 1))
        rel_logits = torch.cat((rel_logits_diagonal, rel_logits), dim=2)
        rel_logits = torch.reshape(rel_logits, (B, V + 1, V))

        # slice
        flat_sliced = rel_logits[:, :V, :]
        final_x = torch.reshape(flat_sliced, (B, V, V))
        return final_x

```

Multi-head attention is obtained by repeating this embedding-extraction process $N_h$ times, each time with a different set of learnable parameters. The resulting set of node embeddings $\{z_i^{t,1}, \dots, z_i^{t,N_h}\}$, all referring to the same node $v_i$, is then combined through a learnable transformation, i.e. $z_i^t = \operatorname{concat}(z_i^{t,1}, \dots, z_i^{t,N_h})\,W_O$, and constitutes the output features of SSA.

Note that:

  • On a fully connected graph, SSA acts like a graph convolution operation, but the kernel values (i.e. the relation scores between nodes) are dynamically predicted in SSA from the skeleton pose (see the toy comparison below);
  • The correlation structure over the skeleton is therefore not fixed for all actions, but adapts to each individual sample.
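
The difference can be made concrete with a toy comparison (a sketch with made-up dimensions, not the paper's code): in a graph convolution the mixing matrix is the fixed, shared adjacency, whereas in SSA it is predicted from the current pose.

```python
import torch
import torch.nn.functional as F

V, C = 25, 64
x = torch.randn(V, C)                     # joint features of one frame
A = torch.eye(V)                          # fixed, predefined skeleton adjacency (same for every sample)
W = torch.randn(C, C)

y_gcn = A @ x @ W                         # graph convolution: the mixing weights are static

Wq, Wk = torch.randn(C, C), torch.randn(C, C)
attn = F.softmax((x @ Wq) @ (x @ Wk).t() / C ** 0.5, dim=-1)
y_ssa = attn @ x @ W                      # SSA: the "kernel" attn depends on the input pose
```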

2.2 Temporal Self-Attention (TSA)

The Temporal Self-Attention (TSA) module studies the dynamics of each joint separately across all frames, i.e. each individual joint is treated independently, and the correlations between frames are computed by comparing how the embedding of the same body joint changes along the temporal dimension, as shown in the figure below:

The TSA formulation is symmetric to the SSA one:

where:

  • $v_t^i$ and $v_u^i$ denote the same joint $i$ at two different time instants $t$ and $u$;
  • $\alpha_{tu}^i$ is the correlation score;
  • $q_t^i$ is the query associated with $v_t^i$, while $k_u^i$ and $v_u^i$ are the key and value associated with joint $v_u^i$ (all computed with trainable linear transformations, as in SSA);
  • $z_t^i$ is the resulting node embedding.

Note that TSA uses the opposite notation to SSA: the subscript denotes time, while the superscript denotes the joint. Multi-head attention, as described for SSA, is also applied in TSA.

By extracting inter-frame relations between temporal nodes, the TSA module can learn to relate distant frames (for example, a node in the first frame with a node in the last frame), capturing discriminative features that the standard ST-GCN convolution cannot capture, since that convolution is limited by its kernel size. The corresponding equations are written out below.
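
For reference, the TSA equations mirror the SSA ones with the roles of joints and frames swapped (notation again reconstructed from the paper; $i$ is a fixed joint, $t$ and $u$ index frames):

$$
\alpha_{tu}^{i}=\mathbf{q}_t^{i}\cdot\mathbf{k}_u^{i},\qquad \mathbf{z}_t^{i}=\sum_{u}\operatorname{softmax}_u\!\left(\frac{\alpha_{tu}^{i}}{\sqrt{d_k}}\right)\mathbf{v}_u^{i}
$$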

2.3 Two-Stream Spatial-Temporal Transformer Network

To combine the SSA and TSA modules, a two-stream architecture named ST-TR is adopted. In this formulation, the two streams differ in how the proposed self-attention is applied: SSA operates in the spatial stream (called S-TR), while TSA operates in the temporal stream (called T-TR), as shown in the figure below:

In both the S-TR and the T-TR stream:

  1. Node features are first passed through a three-layer residual network that extracts low-level features, where each layer processes the input along the spatial dimension with a graph convolution (GCN) and along the temporal dimension with a standard 2D convolution (TCN);
  2. Then, SSA and TSA replace the GCN and TCN feature-extraction sub-modules in the subsequent layers of the S-TR and T-TR streams respectively; each stream, together with its feature-extraction layers, is trained end to end separately;
  3. Finally, the outputs of the two sub-networks are fused by summing their softmax scores to obtain the final prediction (a minimal sketch of this fusion follows the list).
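
A minimal sketch of the fusion in step 3 is shown below; the logits are random stand-ins for the outputs of the full S-TR and T-TR networks.

```python
import torch
import torch.nn.functional as F

# Hypothetical per-stream class scores for a batch of 2 samples and 60 classes.
s_tr_logits = torch.randn(2, 60)          # output of the spatial stream (S-TR)
t_tr_logits = torch.randn(2, 60)          # output of the temporal stream (T-TR)

# Each stream is trained end to end on its own; at test time the softmax
# scores are simply summed to produce the final prediction.
fused = F.softmax(s_tr_logits, dim=-1) + F.softmax(t_tr_logits, dim=-1)
prediction = fused.argmax(dim=-1)
```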

In S-TR, self-attention is applied at the skeleton level by the SSA module, which attends to the spatial relations between joints. The output of the SSA module is then fed to a 2D convolutional module (TCN) with kernel $1 \times K_t$ to extract temporally related features, as follows:

The code of the 2D convolution unit is as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math


class Unit2D(nn.Module):
    def __init__(self,
                 D_in,
                 D_out,
                 kernel_size,
                 stride=1,
                 dim=2,
                 dropout=0,
                 bias=True):
        super(Unit2D, self).__init__()
        pad = int((kernel_size - 1) / 2)
        print("Pad Temporal ", pad)

        if dim == 2:
            self.conv = nn.Conv2d(
                D_in,
                D_out,
                kernel_size=(kernel_size, 1),
                padding=(pad, 0),
                stride=(stride, 1),
                bias=bias)
        elif dim == 3:
            print("Pad Temporal ", pad)
            self.conv = nn.Conv2d(
                D_in,
                D_out,
                kernel_size=(1, kernel_size),
                padding=(0, pad),
                stride=(1, stride),
                bias=bias)
        else:
            raise ValueError()

        self.bn = nn.BatchNorm2d(D_out)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout, inplace=True)

        # initialize
        conv_init(self.conv)

    def forward(self, x):
        x = self.dropout(x)
        x = self.relu(self.bn(self.conv(x)))
        return x


def conv_init(module):
    # he_normal
    n = module.out_channels
    for k in module.kernel_size:
        n = n*k
    module.weight.data.normal_(0, math.sqrt(2. / n))


def import_class(name):
    components = name.split('.')
    mod = __import__(components[0])
    for comp in components[1:]:
        mod = getattr(mod, comp)
    return mod

```

Following the original Transformer, the input goes through a batch-normalization layer, and a skip connection adds the input to the output of the SSA module.

T-TR, in contrast, focuses on discovering inter-frame temporal relations. Analogously to the S-TR stream, each T-TR layer contains a standard graph convolution sub-module followed by the proposed Temporal Self-Attention module:

TSA operates on a graph that connects the same joint across all the frames of the sequence.

2.4 Implementation of SSA and TSA

The matrix implementation of SSA is adapted from Transformer implementations that operate on pixels, as shown in the figure below:

The input $f$ is reshaped by moving $T$ into the batch dimension, so that self-attention operates on each time frame separately. SSA is implemented as matrix multiplications, where $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ are the query, key and value matrices and $\otimes$ denotes matrix multiplication.

  • First, the input is a tensor of shape $(C_{in}, T, V)$, where $C_{in}$ is the number of input features, $T$ is the number of frames and $V$ is the number of nodes;
  • After batch normalization, the input is rearranged into a $(T, V, C_{in})$ matrix; the $T$ dimension is moved inside the batch dimension, which effectively shares the attention parameters along the temporal dimension;
  • Then, the transformation $\mathbf{Q} = \mathbf{X}\mathbf{W}_q$, $\mathbf{K} = \mathbf{X}\mathbf{W}_k$, $\mathbf{V} = \mathbf{X}\mathbf{W}_v$ followed by $\operatorname{softmax}\!\left(\mathbf{Q}\mathbf{K}^{T}/\sqrt{d_k}\right)\otimes\mathbf{V}$ is applied to each frame separately;
  • Each head produces its own $\mathbf{Q}^h$, $\mathbf{K}^h$, $\mathbf{V}^h$ from separate weights $\mathbf{W}_q^h$, $\mathbf{W}_k^h$, $\mathbf{W}_v^h$, where $N_h$ is the number of heads, and a learnable linear transformation $\mathbf{W}_O$ combines the outputs of the heads;
  • Finally, the output of the spatial Transformer is rearranged back into an output tensor of shape $(C_{out}, T, V)$ (a compact sketch of these operations follows the list).
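
The bullet points above reduce to a few tensor operations; the sketch below uses a single head and made-up dimensions purely to make the reshaping explicit (the released multi-head implementation based on 1x1 convolutions follows).

```python
import torch
import torch.nn.functional as F

C_in, C_out, T, V = 3, 64, 300, 25
x = torch.randn(C_in, T, V)

# Move T into the batch dimension: every frame is processed independently,
# which also shares the attention parameters across time.
frames = x.permute(1, 2, 0)                                       # (T, V, C_in)

W_q = torch.randn(C_in, C_out)
W_k = torch.randn(C_in, C_out)
W_v = torch.randn(C_in, C_out)

Q, K, Vm = frames @ W_q, frames @ W_k, frames @ W_v               # (T, V, C_out) each
attn = F.softmax(Q @ K.transpose(1, 2) / C_out ** 0.5, dim=-1)    # (T, V, V)
out = attn @ Vm                                                   # (T, V, C_out)

out = out.permute(2, 0, 1)                                        # back to (C_out, T, V)
```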

The code is as follows:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable
from net import conv_init
from spatial_transformer import spatial_attention
import numpy as np

scale_norm = False


class gcn_unit_attention(nn.Module):
    def __init__(self, in_channels, out_channels, incidence, num, dv_factor, dk_factor, Nh, complete, relative,
                 only_attention, layer, more_channels, drop_connect, data_normalization, skip_conn, adjacency, num_point, padding=0,
                 kernel_size=1,
                 stride=1, bn_flag=True,
                 t_dilation=1, last_graph=False, visualization=True):
        super().__init__()
        # Initialize the adjacency matrix
        self.incidence = incidence
        self.incidence = incidence
        self.relu = nn.ReLU()
        self.visualization=visualization
        self.in_channels = in_channels
        self.more_channels = more_channels
        self.drop_connect = drop_connect
        self.data_normalization=data_normalization
        self.skip_conn=skip_conn
        self.num_point=num_point
        self.adjacency = adjacency
        print("Nh ", Nh)
        print("Dv ", dv_factor)
        print("Dk ", dk_factor)

        self.last_graph = last_graph
        # Whether to use only the attention mechanism
        if (not only_attention):
            self.out_channels = out_channels - int((out_channels) * dv_factor)
        else:
            self.out_channels = out_channels
        # Data normalization layers
        self.data_bn = nn.BatchNorm1d(self.in_channels * self.num_point)
        self.bn = nn.BatchNorm2d(out_channels)
        self.only_attention = only_attention
        self.bn_flag = bn_flag
        self.layer = layer

        # Convert the adjacency matrix (non-trainable)
        self.incidence = Variable(self.incidence.clone(), requires_grad=False).view(-1, self.incidence.size()[-1],
                                                                                    self.incidence.size()[-1])

        # Each Conv2d unit implements a 2D convolution (filter size 1x1) to weight each individual partition
        # There is one convolution unit per partition
        # This is done only when the spatial Transformer and the graph convolution are concatenated
        # Graph convolution layers
        if (not self.only_attention):
            self.g_convolutions = nn.ModuleList(

                [nn.Conv2d(in_channels, self.out_channels, kernel_size=(kernel_size, 1), padding=(padding, 0),
                           stride=(stride, 1), dilation=(t_dilation, 1)) for i in
                 range(self.incidence.size()[0])]
            )
            for conv in self.g_convolutions:
                conv_init(conv)

            self.attention_conv = spatial_attention(in_channels=self.in_channels, kernel_size=1,
                                                dk=int(out_channels * dk_factor),
                                                dv=int(out_channels * dv_factor), Nh=Nh, complete=complete,
                                                relative=relative,
                                                stride=stride, layer=self.layer, A=self.incidence, num=num,
                                                more_channels=self.more_channels,
                                                drop_connect=self.drop_connect,
                                                data_normalization=self.data_normalization, skip_conn=self.skip_conn,
                                                adjacency=self.adjacency, visualization=self.visualization, num_point=self.num_point)
        else:
            self.attention_conv = spatial_attention(in_channels=self.in_channels, kernel_size=1,
                                                dk=int(out_channels * dk_factor),
                                                dv=int(out_channels), Nh=Nh, complete=complete,
                                                relative=relative,
                                                stride=stride, last_graph=self.last_graph, layer=self.layer,
                                                A=self.incidence, num=num, more_channels=self.more_channels,
                                                drop_connect=self.drop_connect,
                                                data_normalization=self.data_normalization, skip_conn=self.skip_conn,
                                                adjacency=self.adjacency, visualization=self.visualization, num_point=self.num_point)


    def forward(self, x, label, name):
        # N: number of samples, equal to the batch size
        # C: number of channels, in our case 3 (coordinates x, y, z)
        # T: number of frames
        # V: number of nodes
        N, C, T, V = x.size()
        x_sum = x
        # Data normalization
        if (self.data_normalization):
            x = x.permute(0, 1, 3, 2).reshape(N, C * V, T)
            x = self.data_bn(x)
            x = x.reshape(N, C, V, T).permute(0, 1, 3, 2)

        # Move the adjacency matrix to the same device as the input
        self.incidence = self.incidence.cuda(x.get_device())

        # Adjacency matrix (partitions) used by the graph-convolution branch below
        incidence = self.incidence

        # N, T, C, V > NT, C, 1, V
        xa = x.permute(0, 2, 1, 3).reshape(-1, C, 1, V)

        # An alternative normalization tried on the data, called "ScaleNorm"
        if scale_norm:
            self.scale = ScaleNorm(scale=C ** 0.5)
            xa = self.scale(xa)

        # S-TR
        attn_out = self.attention_conv(xa, label, name)
        # N, T, C, V > N, C, T, V
        attn_out = attn_out.reshape(N, T, -1, V).permute(0, 2, 1, 3)

        if (not self.only_attention):

            # For each partition, multiply the input by the partition matrix and weight the result with a 1x1 convolution
            for i, partition in enumerate(incidence):
                # print(partition)
                # NCTxV
                xp = x.reshape(-1, V)
                # (NCTxV)*(VxV)
                xp = xp.mm(partition.float())
                # NxCxTxV
                xp = xp.reshape(N, C, T, V)

                if i == 0:
                    y = self.g_convolutions[i](xp)
                else:
                    y = y + self.g_convolutions[i](xp)

            # Concatenate the two branches along the channel dimension
            y = torch.cat((y, attn_out), dim=1)
        else:
            if self.skip_conn and self.in_channels == self.out_channels:
                y = attn_out + x_sum
            else:
                y = attn_out
        if (self.bn_flag):
            y = self.bn(y)

        y = self.relu(y)

        return y


class ScaleNorm(nn.Module):
    """ScaleNorm"""

    def __init__(self, scale, eps=1e-5):
        super(ScaleNorm, self).__init__()
        self.scale = scale

        self.eps = eps

    def forward(self, x):
        norm = self.scale / torch.norm(x, dim=1, keepdim=True).clamp(min=self.eps)
        return x * norm

```

TSA is implemented in the same way as SSA, the only difference being that the dimension $T$ takes the role of $V$ and vice versa: to be processed by each TSA module, the input is reshaped into a $(V, T, C)$ matrix so that each joint is processed separately along the temporal dimension:

The matrix shapes change accordingly, with attention now computed over the $T$ frames of each joint rather than over the $V$ joints of each frame, as shown in the figure below:
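
To make the difference concrete, here is a toy sketch of the two reshapes (mirroring `xa = x.permute(0, 2, 1, 3).reshape(-1, C, 1, V)` in the SSA code above and the permute in the TSA code below):

```python
import torch

N, C, T, V = 8, 64, 300, 25               # batch, channels, frames, joints
x = torch.randn(N, C, T, V)

# SSA: frames move into the batch, attention runs over the V joints of each frame.
x_ssa = x.permute(0, 2, 1, 3).reshape(N * T, C, 1, V)

# TSA: joints move into the batch, attention runs over the T frames of each joint.
x_tsa = x.permute(0, 3, 1, 2).reshape(N * V, C, 1, T)
```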

The code is as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F_func
from net import Unit2D
import math
import numpy as np
import time

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
dropout = False
scale_norm = False
save = False
multi_matmul = False

class tcn_unit_attention(nn.Module):

    def forward(self, x):
        # Input x
        # (batch_size, channels, time, joints)
        N, C, T, V = x.size()
        x_sum = x

        # Data normalization
        if (self.data_normalization):
            x = x.permute(0, 1, 3, 2).reshape(N, C * V, T)
            x = self.data_bn(x)
            x = x.reshape(N, C, V, T).permute(0, 1, 3, 2)

        # The joint dimension is moved into the batch so that each joint is processed separately along time
        x = x.permute(0, 3, 1, 2).reshape(-1, C, 1, T)

        if scale_norm:
            self.scale = ScaleNorm(scale=C ** 0.5)
            x = self.scale(x)

        flat_q, flat_k, flat_v, q, k, v = self.compute_flat_qkv(x, self.dk, self.dv, self.Nh)
        B, self.Nh, C, T = flat_q.size()

        # Compute the scores obtained via q*k
        # (batch_size, Nh, time, dkh)*(batch_size, Nh,dkh, time) =  (batch_size, Nh, time, time)

        if (multi_matmul):
            for i in range(0, 5):
                flat_q_5 = flat_q[:, :, :, (60 * i):(60 * (i + 1))]
                product = torch.matmul(flat_q_5.transpose(2, 3), flat_k)
                if (i == 0):
                    logits = product
                else:
                    logits = torch.cat((logits, product), dim=2)
        else:
            logits = torch.matmul(flat_q.transpose(2, 3), flat_k)

        if self.relative:
            rel_logits = self.relative_logits(q)
            logits_sum = torch.add(logits, rel_logits)

        # Compute the attention weights
        if self.relative:

            weights = F_func.softmax(logits_sum, dim=-1)

        else:
            weights = F_func.softmax(logits, dim=-1)

        # Apply DropConnect
        if (self.drop_connect and self.training):
            mask = torch.bernoulli((0.5) * torch.ones(B * self.Nh * T, device=device))
            mask = mask.reshape(B, self.Nh, T).unsqueeze(2).expand(B, self.Nh, T, T)
            weights = weights * mask
            weights = weights / (weights.sum(3, keepdim=True) + 1e-8)

        # attn_out
        # (batch, Nh, time, dvh)
        # weights*V
        # (batch, Nh, time, time)*(batch, Nh, time, dvh)=(batch, Nh, time, dvh)

        attn_out = torch.matmul(weights, flat_v.transpose(2, 3))
        if not self.more_channels:
            attn_out = torch.reshape(attn_out, (B, self.Nh, 1, T, self.dv // self.Nh))
        else:
            attn_out = torch.reshape(attn_out, (B, self.Nh, 1, T, self.dv // self.num))

        attn_out = attn_out.permute(0, 1, 4, 2, 3)

        # combine_heads_2d, the heads are merged only after each Z has been computed separately
        # (batch, Nh*dv, time, 1)
        attn_out = self.combine_heads_2d(attn_out)

        # Multiply by W0; output shape (batch, out_channels, time, joints) where out_channels = dv
        attn_out = self.attn_out(attn_out)
        attn_out = attn_out.reshape(N, V, -1, T).permute(0, 2, 3, 1)

        # Residual connection
        if self.skip_conn:
            if dropout:
                attn_out = self.dropout(attn_out)

                if (not self.only_temporal_att):
                    x = self.tcn_conv(x_sum)
                    result = torch.cat((x, attn_out), dim=1)
                else:
                    result = attn_out

                result = result+(x_sum if (self.down is None) else self.down(x_sum))


            else:
                if (not self.only_temporal_att):
                    x = self.tcn_conv(x_sum)
                    result = torch.cat((x, attn_out), dim=1)
                else:
                    result = attn_out

                result = result+(x_sum if (self.down is None) else self.down(x_sum))


        else:
            result = attn_out

        if (self.bn_flag):
            result = self.bn(result)
        result = self.relu(result)
        return result

    def compute_flat_qkv(self, x, dk, dv, Nh):
        qkv = self.qkv_conv(x)

        # In this case V=1, since the temporal Transformer is applied to each joint separately
        N, C, V1, T1 = qkv.size()
        if self.more_channels:
            q, k, v = torch.split(qkv, [dk * self.Nh // self.num, dk * self.Nh // self.num, dv * self.Nh // self.num],
                                  dim=1)
        else:
            q, k, v = torch.split(qkv, [dk, dk, dv], dim=1)

        q = self.split_heads_2d(q, Nh)
        k = self.split_heads_2d(k, Nh)
        v = self.split_heads_2d(v, Nh)

        dkh = dk // Nh
        q = q* (dkh ** -0.5)
        if self.more_channels:

            flat_q = torch.reshape(q, (N, Nh, dk // self.num, V1 * T1))
            flat_k = torch.reshape(k, (N, Nh, dk // self.num, V1 * T1))
            flat_v = torch.reshape(v, (N, Nh, dv // self.num, V1 * T1))
        else:
            flat_q = torch.reshape(q, (N, Nh, dkh, V1 * T1))
            flat_k = torch.reshape(k, (N, Nh, dkh, V1 * T1))
            flat_v = torch.reshape(v, (N, Nh, dv // self.Nh, V1 * T1))
        return flat_q, flat_k, flat_v, q, k, v

    def split_heads_2d(self, x, Nh):
        B, channels, F, V = x.size()
        ret_shape = (B, Nh, channels // Nh, F, V)
        split = torch.reshape(x, ret_shape)
        return split

    def combine_heads_2d(self, x):
        batch, Nh, dv, F, V = x.size()
        ret_shape = (batch, Nh * dv, F, V)
        return torch.reshape(x, ret_shape)

    def relative_logits(self, q):
        B, Nh, dk, _, T = q.size()
        # B, Nh, V, T, dk -> B, Nh, F, 1, dk
        q = q.permute(0, 1, 3, 4, 2)
        q = q.reshape(B, Nh, T, dk)
        rel_logits = self.relative_logits_1d(q, self.key_rel)
        return rel_logits

    def relative_logits_1d(self, q, rel_k):
        # compute relative logits along one dimension
        # (B, Nh,  1, V, channels // Nh)*(2 * K - 1, self.dk // Nh)
        # (B, Nh,  1, V, 2 * K - 1)
        rel_logits = torch.einsum('bhld,md->bhlm', q, rel_k)
        rel_logits = self.rel_to_abs(rel_logits)
        B, Nh, L, L = rel_logits.size()
        return rel_logits

    def rel_to_abs(self, x):
        B, Nh, L, _ = x.size()
        print(x.shape)
        col_pad = torch.zeros((B, Nh, L, 1)).to(x)
        x = torch.cat((x, col_pad), dim=3)
        flat_x = torch.reshape(x, (B, Nh, L * 2 * L))
        flat_pad = torch.zeros((B, Nh, L - 1)).to(x)
        flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)

        final_x = torch.reshape(flat_x_padded, (B, Nh, L + 1, 2 * L - 1))
        final_x = final_x[:, :, :L, L - 1:]
        return final_x


class ScaleNorm(nn.Module):
    """ScaleNorm"""

    def __init__(self, scale, eps=1e-5):
        super(ScaleNorm, self).__init__()
        self.scale = scale

        self.eps = eps

    def forward(self, x):
        norm = self.scale / torch.norm(x, dim=1, keepdim=True).clamp(min=self.eps)
        return x * norm

```

3 ST-TR Network Architecture

The network structure of ST-TR follows that of ST-GCN, shown below, and is likewise built from TCN-GCN units. The code is as follows:

```python
class TCN_GCN_unit(nn.Module):
    def __init__(self,
                 in_channel,
                 out_channel,
                 A,
                 attention,
                 only_attention,
                 tcn_attention,
                 only_temporal_attention,
                 relative,
                 device,
                 attention_3,
                 dv,
                 dk,
                 Nh,
                 num,
                 dim_block1,
                 dim_block2,
                 dim_block3,
                 num_point,
                 weight_matrix,
                 more_channels,
                 drop_connect,
                 starting_ch,
                 all_layers,
                 adjacency,
                 data_normalization,
                 visualization,
                 skip_conn,
                 layer=0,
                 kernel_size=9,
                 stride=1,
                 dropout=0.5,
                 use_local_bn=False,
                 mask_learning=False,
                 last=False,
                 last_graph=False,
                 agcn = False
                 ):
        super(TCN_GCN_unit, self).__init__()
        half_out_channel = out_channel / 2
        self.A = A

        self.V = A.shape[-1]
        self.C = in_channel
        self.last = last
        self.data_normalization = data_normalization
        self.skip_conn = skip_conn
        self.num_point = num_point
        self.adjacency = adjacency
        self.last_graph = last_graph
        self.layer = layer
        self.stride = stride
        self.drop_connect = drop_connect
        self.visualization = visualization
        self.device = device
        self.all_layers = all_layers
        self.more_channels = more_channels

        if (out_channel >= starting_ch and attention or (self.all_layers and attention)):

            self.gcn1 = gcn_unit_attention(in_channel, out_channel, dv_factor=dv, dk_factor=dk, Nh=Nh,
                                           complete=True,
                                           relative=relative, only_attention=only_attention, layer=layer, incidence=A,
                                           bn_flag=True, last_graph=self.last_graph, more_channels=self.more_channels,
                                           drop_connect=self.drop_connect, adjacency=self.adjacency, num=num,
                                           data_normalization=self.data_normalization, skip_conn=self.skip_conn,
                                           visualization=self.visualization, num_point=self.num_point)
        else:

            if not agcn:
                self.gcn1 = unit_gcn(
                    in_channel,
                    out_channel,
                    A,
                    use_local_bn=use_local_bn,
                    mask_learning=mask_learning)
            else:
                self.gcn1 = unit_agcn(
                    in_channel,
                    out_channel,
                    A,
                    use_local_bn=use_local_bn,
                    mask_learning=mask_learning)

        if (out_channel >= starting_ch and tcn_attention or (self.all_layers and tcn_attention)):

            if out_channel <= starting_ch and self.all_layers:
                self.tcn1 = tcn_unit_attention_block(out_channel, out_channel, dv_factor=dv,
                                                     dk_factor=dk, Nh=Nh,
                                                     relative=relative, only_temporal_attention=only_temporal_attention,
                                                     dropout=dropout,
                                                     kernel_size_temporal=9, stride=stride,
                                                     weight_matrix=weight_matrix, bn_flag=True, last=self.last,
                                                     layer=layer,
                                                     device=self.device, more_channels=self.more_channels,
                                                     drop_connect=self.drop_connect, n=num,
                                                     data_normalization=self.data_normalization,
                                                     skip_conn=self.skip_conn,
                                                     visualization=self.visualization, dim_block1=dim_block1,
                                                     dim_block2=dim_block2, dim_block3=dim_block3, num_point=self.num_point)
            else:
                self.tcn1 = tcn_unit_attention(out_channel, out_channel, dv_factor=dv,
                                               dk_factor=dk, Nh=Nh,
                                               relative=relative, only_temporal_attention=only_temporal_attention,
                                               dropout=dropout,
                                               kernel_size_temporal=9, stride=stride,
                                               weight_matrix=weight_matrix, bn_flag=True, last=self.last,
                                               layer=layer,
                                               device=self.device, more_channels=self.more_channels,
                                               drop_connect=self.drop_connect, n=num,
                                               data_normalization=self.data_normalization, skip_conn=self.skip_conn,
                                               visualization=self.visualization, num_point=self.num_point)



        else:
            self.tcn1 = Unit2D(
                out_channel,
                out_channel,
                kernel_size=kernel_size,
                dropout=dropout,
                stride=stride)
        if ((in_channel != out_channel) or (stride != 1)):
            self.down1 = Unit2D(
                in_channel, out_channel, kernel_size=1, stride=stride)
        else:
            self.down1 = None

    def forward(self, x, label, name):
        # N, C, T, V = x.size()
        x = self.tcn1(self.gcn1(x, label, name)) + (x if
                                                    (self.down1 is None) else self.down1(x))

        return x


class TCN_GCN_unit_multiscale(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 kernel_size=9,
                 stride=1,
                 **kwargs):
        super(TCN_GCN_unit_multiscale, self).__init__()
        self.unit_1 = TCN_GCN_unit(
            in_channels,
            out_channels // 2,
            A,
            kernel_size=kernel_size,
            stride=stride,
            **kwargs)
        self.unit_2 = TCN_GCN_unit(
            in_channels,
            out_channels - out_channels // 2,
            A,
            kernel_size=kernel_size * 2 - 1,
            stride=stride,
            **kwargs)

    def forward(self, x):
        return torch.cat((self.unit_1(x), self.unit_2(x)), dim=1)

```

The code of the overall ST-TR network is as follows:

```python
class Model(nn.Module):

    def forward(self, x, label, name):
        N, C, T, V, M = x.size()
        print(x.shape)
        if (self.concat_original):
            x_coord = x
            x_coord = x_coord.permute(0, 4, 1, 2, 3).reshape(N * M, C, T, V)

        # Data normalization
        if self.use_data_bn:
            if self.M_dim_bn:
                x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
            else:
                x = x.permute(0, 4, 3, 1, 2).contiguous().view(N * M, V * C, T)
            x = self.data_bn(x)
            # to (N*M, C, T, V)
            x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(
                N * M, C, T, V)
        else:
            # from (N, C, T, V, M) to (N*M, C, T, V)
            x = x.permute(0, 4, 1, 2, 3).contiguous().view(N * M, C, T, V)

        # Backbone
        if not self.all_layers:
            x = self.gcn0(x, label, name)
            x = self.tcn0(x)

        for i, m in enumerate(self.backbone):
            if i == 3 and self.concat_original:
                x = m(torch.cat((x, x_coord), dim=1), label, name)
            else:
                x = m(x, label, name)

        # Pooling and classification
        # V pooling
        x = F.avg_pool2d(x, kernel_size=(1, V))

        # M pooling
        c = x.size(1)
        t = x.size(2)
        x = x.view(N, M, c, t).mean(dim=1).view(N, c, t)

        # T pooling
        x = F.avg_pool1d(x, kernel_size=x.size()[2])

        # C fcn
        x = self.fcn(x)
        x = F.avg_pool1d(x, x.size()[2:])
        x = x.view(N, self.num_class)
        return x

```

Conclusion

This paper proposes a novel approach that introduces Transformer self-attention into skeleton-based activity recognition as an alternative to graph convolution, showing that the Spatial Self-Attention (SSA) module provides a more flexible and dynamic representation. Likewise, the Temporal Self-Attention (TSA) module overcomes the strict locality of standard convolutions and can extract long-range dependencies within an action. The two-stream ST-TR network composed of SSA and TSA effectively extracts the spatial and temporal features of action data, achieving excellent results on all datasets when joint coordinates are used as input, as well as when bone information is added. Since configurations that use only the self-attention modules were shown to be suboptimal in the paper, a possible direction for future work is to search for a unified Transformer architecture that can replace graph convolution across a variety of tasks and achieve even better recognition performance.
