import math
import pdb
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
def import_class(name):
    """Resolve a dotted path such as ``'package.module.ClassName'`` to an object.

    The longest prefix of ``name`` that is an importable module is imported
    with ``importlib.import_module``; the remaining components are then looked
    up with ``getattr``. Unlike importing only the top-level package, this also
    resolves submodules that the package's ``__init__`` does not import itself.

    Raises:
        ImportError: no prefix of ``name`` is an importable module, or a
            module on the path fails to import for an unrelated reason.
        AttributeError: a trailing component is not an attribute of the
            object before it.
    """
    import importlib

    components = name.split('.')
    last_error = None
    for split in range(len(components), 0, -1):
        module_name = '.'.join(components[:split])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            # Fall back to a shorter prefix only when the missing module is
            # this prefix itself (or one of its parents). A missing dependency
            # *inside* an existing module must surface, not be masked.
            missing = err.name
            if missing is None or not (module_name == missing
                                       or module_name.startswith(missing + '.')):
                raise
            last_error = err
            continue
        for comp in components[split:]:
            obj = getattr(obj, comp)
        return obj
    raise last_error
def conv_branch_init(conv, branches):
    """Re-initialise a branch convolution that belongs to a multi-branch layer.

    ``conv.weight`` is drawn from N(0, std) with
    ``std = sqrt(2 / (d0 * d1 * d2 * branches))``, where d0, d1 and d2 are the
    first three dimensions of the weight. ``conv.bias`` is set to zero, so the
    convolution must have a bias.
    """
    w = conv.weight
    d0, d1, d2 = (w.size(axis) for axis in range(3))
    std = math.sqrt(2. / (d0 * d1 * d2 * branches))
    nn.init.normal_(w, 0, std)
    nn.init.constant_(conv.bias, 0)
def conv_init(conv):
    """Kaiming-normal (``fan_out``) weights and zero bias for a conv layer.

    Either parameter is skipped when it is ``None`` (e.g. ``bias=False``).
    """
    weight, bias = conv.weight, conv.bias
    if weight is not None:
        nn.init.kaiming_normal_(weight, mode='fan_out')
    if bias is not None:
        nn.init.constant_(bias, 0)
def bn_init(bn, scale):
    """Set a normalisation layer's affine weight to ``scale`` and its bias to 0."""
    for param, value in ((bn.weight, scale), (bn.bias, 0)):
        nn.init.constant_(param, value)
def weights_init(m):
    """Per-module initialiser meant for ``nn.Module.apply``.

    Dispatch is by substring of the module's class name:

    * ``'Conv'``: Kaiming-normal (``fan_out``) weight if the module has a
      ``weight`` attribute; zero bias when ``bias`` is a tensor.
    * ``'BatchNorm'``: weight ~ N(1.0, 0.02) and zero bias, each only when
      present and not ``None``.

    Because matching is by name, wrapper modules whose names contain ``Conv``
    also match; they are left alone only if they lack ``weight``/``bias``
    attributes of their own.
    """
    name = type(m).__name__
    if 'Conv' in name:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if isinstance(getattr(m, 'bias', None), torch.Tensor):
            nn.init.constant_(m.bias, 0)
        return
    if 'BatchNorm' in name:
        if getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.fill_(0)
class TemporalConv(nn.Module):
    """Temporal convolution followed by batch norm (no activation).

    A ``(kernel_size, 1)`` ``Conv2d`` convolves only the third axis of a 4-D
    input. Padding is ``dilation * (kernel_size - 1) // 2``, which keeps that
    axis length unchanged for odd kernels at stride 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, dilation=1):
        super().__init__()
        # Receptive field of the dilated kernel, minus one.
        span = dilation * (kernel_size - 1)
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=(kernel_size, 1),
                              padding=(span // 2, 0),
                              stride=(stride, 1),
                              dilation=(dilation, 1))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelfAttention(nn.Module):
    """Multi-head self-attention over the last axis of an (N, C, T, V) map.

    For each sample, head and index along the third axis (``T``), the ``V``
    positions attend to one another. Queries, keys and values come from 1x1
    convolutions. A final 1x1 convolution mixes the concatenated heads.
    Output shape is (N, out_channels, T, V).
    """

    def __init__(self, in_channels, out_channels, num_heads=4):
        super(SelfAttention, self).__init__()
        # Message: "out_channels must be a multiple of num_heads".
        assert out_channels % num_heads == 0, "out_channels 必须是 num_heads 的倍数"
        self.num_heads = num_heads
        self.head_dim = out_channels // num_heads
        self.conv_q = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv_k = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv_v = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.fc_out = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def _split_heads(self, t):
        """Reshape (N, heads*head_dim, T, V) into (N, heads, T, V, head_dim)."""
        n, _, t_len, v_len = t.shape
        return t.reshape(n, self.num_heads, self.head_dim, t_len, v_len).permute(0, 1, 3, 4, 2)

    def forward(self, x):
        N, _, T, V = x.shape
        q = self._split_heads(self.conv_q(x))
        k = self._split_heads(self.conv_k(x))
        v = self._split_heads(self.conv_v(x))
        # Contract over head_dim: (N, h, T, V, d) @ (N, h, T, d, V) -> (N, h, T, V, V).
        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(attention_scores, dim=-1)
        out = torch.matmul(attention_weights, v)  # (N, h, T, V, d)
        # Put head_dim next to the head axis before merging channels; reshaping
        # the (N, h, T, V, d) layout directly would scramble channels.
        out = out.permute(0, 1, 4, 2, 3).reshape(N, self.num_heads * self.head_dim, T, V)
        return self.fc_out(out)
class MultiScale_TemporalConv(nn.Module):
    """Multi-branch temporal convolution with a residual connection.

    ``out_channels`` is split evenly across ``len(dilations) + 2`` branches:

    * one dilated ``TemporalConv`` branch per entry of ``dilations``;
    * a temporal max-pool branch;
    * a strided 1x1 branch.

    Branch outputs are concatenated on dim 1 and the residual is added.

    Args:
        in_channels: input channel count.
        out_channels: output channel count; must be a multiple of the number
            of branches.
        kernel_size: int shared by all dilated branches, or a list/tuple with
            one kernel size per dilation.
        stride: temporal stride used by every branch and the residual path.
        dilations: one dilated branch per entry (read-only, so a tuple default
            is used).
        residual: if False, no residual. Otherwise identity when channels
            match and stride is 1, else a ``TemporalConv`` projection.
        residual_kernel_size: kernel size of that projection.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=(1, 2, 3, 4),
                 residual=True,
                 residual_kernel_size=1):
        super().__init__()
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'
        # One branch per dilation, plus the max-pool and 1x1 branches below.
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        # A list/tuple gives one kernel size per dilated branch; a scalar is shared.
        if isinstance(kernel_size, (list, tuple)):
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size] * len(dilations)

        # Dilated branches: 1x1 bottleneck -> BN -> ReLU -> TemporalConv.
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
                TemporalConv(branch_channels, branch_channels, kernel_size=ks,
                             stride=stride, dilation=dilation),
            )
            for ks, dilation in zip(kernel_size, dilations)
        ])

        # Max-pool branch. NOTE(review): an earlier comment asked why a second
        # BN follows the pooling; it normalises the pooled activations.
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
            nn.BatchNorm2d(branch_channels),
        ))
        # Strided 1x1 branch.
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)),
            nn.BatchNorm2d(branch_channels),
        ))

        # Residual connection
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)

        # Disabled experiment: self-attention on the concatenated branch output.
        # self.selfattention = SelfAttention(out_channels, out_channels)

        # initialize
        self.apply(weights_init)

    def forward(self, x):
        # Input dim: (N, C, T, V)
        res = self.residual(x)
        # Concatenate every branch's output along the channel axis (dim=1).
        out = torch.cat([branch(x) for branch in self.branches], dim=1)
        # Disabled experiment: add self-attention over the multi-scale output.
        # out = self.selfattention(out) + out
        out += res
        return out
class CTRGC(nn.Module):
    """Channel-wise topology refinement graph convolution.

    From the input ``x`` (4-D; the einsum in ``forward`` indexes it as
    n, c, t, v), it builds a per-sample, per-channel joint-relation map. The
    steps are:

    * temporally averaged features of two 1x1 projections (``conv1``,
      ``conv2``);
    * their pairwise differences, passed through tanh;
    * a lift to ``out_channels`` with ``conv4``.

    The map is scaled by ``alpha`` and added to the shared adjacency ``A``,
    then used to aggregate the ``conv3`` features over joints.

    Returns:
        y: aggregated features, (N, out_channels, T, V).
        graph: the refined relation map before ``alpha``/``A``,
            (N, out_channels, V, V).
    """

    def __init__(self, in_channels, out_channels, rel_reduction=8, mid_reduction=1):
        super(CTRGC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if in_channels == 3 or in_channels == 9:
            # Thin raw inputs: in_channels // rel_reduction would be 0 or 1,
            # so fixed widths are used instead.
            self.rel_channels = 8
            self.mid_channels = 16
        else:
            self.rel_channels = in_channels // rel_reduction
            self.mid_channels = in_channels // mid_reduction  # not used in this block
        # conv1 and conv2 are both applied to the block input x, so both take
        # in_channels. conv1 was previously hard-coded to 6 input channels,
        # which made forward() fail for every in_channels != 6.
        self.conv1 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)
        self.conv4 = nn.Conv2d(self.rel_channels, self.out_channels, kernel_size=1)
        self.tanh = nn.Tanh()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)

    def forward(self, x, A=None, alpha=1):
        """Refine the topology and aggregate features.

        Args:
            x: input features; assumed (N, C_in, T, V).
            A: optional (V, V) adjacency added to every sample and channel.
            alpha: scale applied to the learned refinement.
        """
        x1, x2, x3 = self.conv1(x), self.conv2(x), self.conv3(x)
        # Average over T, then pairwise joint differences: (N, rel, V, V).
        graph = self.tanh(x1.mean(-2).unsqueeze(-1) - x2.mean(-2).unsqueeze(-2))
        graph = self.conv4(graph)  # (N, out_channels, V, V)
        graph_c = graph * alpha + (A.unsqueeze(0).unsqueeze(0) if A is not None else 0)  # N,C,V,V
        y = torch.einsum('ncuv,nctv->nctu', graph_c, x3)
        return y, graph
class unit_tcn(nn.Module):
    """Temporal ``(kernel_size, 1)`` convolution followed by batch norm.

    The convolution is Kaiming-initialised and the BN scale is set to 1.
    ``self.relu`` is constructed but not applied in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super().__init__()
        half = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(half, 0),
            stride=(stride, 1),
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        out = self.conv(x)
        return self.bn(out)
class unit_gcn(nn.Module):
    """Graph convolution with one ``CTRGC`` per adjacency subset.

    The subset outputs are summed, batch-normalised, added to a residual
    (1x1 conv + BN when channel counts differ, identity otherwise, or nothing
    when ``residual`` is False) and passed through ReLU.

    Args:
        A: numpy array whose first axis indexes adjacency subsets; ``A[i]`` is
            passed to the i-th ``CTRGC``.
        adaptive: if True, the adjacency is a learnable parameter (``PA``);
            otherwise a fixed tensor (``A``).

    Returns (from ``forward``):
        y: output features.
        graphs: the per-subset refined graphs from each ``CTRGC``, stacked on
            dim 1.
    """

    def __init__(self, in_channels, out_channels, A, coff_embedding=4, adaptive=True, residual=True):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels  # not used in this block
        self.out_c = out_channels
        self.in_c = in_channels
        self.adaptive = adaptive
        self.num_subset = A.shape[0]
        self.convs = nn.ModuleList()
        for i in range(self.num_subset):
            self.convs.append(CTRGC(in_channels, out_channels))
        if residual:
            if in_channels != out_channels:
                self.down = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1),
                    nn.BatchNorm2d(out_channels)
                )
            else:
                self.down = lambda x: x
        else:
            self.down = lambda x: 0
        if self.adaptive:
            self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
        else:
            # Fixed adjacency kept as a plain attribute (not a buffer) so the
            # state_dict keys stay unchanged; it is moved to x's device in forward.
            self.A = torch.from_numpy(A.astype(np.float32))
        self.alpha = nn.Parameter(torch.zeros(1))
        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)  # not used in this block
        self.relu = nn.ReLU(inplace=True)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        # Near-zero BN scale so the graph branch starts close to the residual.
        bn_init(self.bn, 1e-6)

    def forward(self, x):
        y = None
        graph_list = []
        if self.adaptive:
            A = self.PA
        else:
            # Follow the input's device. `.cuda(x.get_device())` failed for CPU
            # inputs, where get_device() returns -1.
            A = self.A.to(x.device)
        for i in range(self.num_subset):
            z, graph = self.convs[i](x, A[i], self.alpha)
            graph_list.append(graph)
            y = z + y if y is not None else z
        y = self.bn(y)
        y += self.down(x)
        y = self.relu(y)
        return y, torch.stack(graph_list, 1)
class TCN_GCN_unit(nn.Module):
    """One spatial-temporal block: ``unit_gcn`` then a multi-scale TCN.

    ``forward`` returns ``relu(tcn1(gcn(x)) + residual(x))`` together with the
    graphs produced by the GCN. ``dilations`` is only read (its length and
    elements), so an immutable tuple default is used.
    """

    def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, kernel_size=5,
                 dilations=(1, 2)):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive)
        # self.tcn1 = TemporalConv(out_channels, out_channels, stride=stride)
        # The TCN always keeps its own internal residual; the `residual`
        # argument only controls the block-level skip from x below.
        self.tcn1 = MultiScale_TemporalConv(out_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                            dilations=dilations,
                                            residual=True)
        self.relu = nn.ReLU(inplace=True)
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        z, graph = self.gcn1(x)
        y = self.relu(self.tcn1(z) + self.residual(x))
        return y, graph
class Model(nn.Module):
    """Skeleton action-recognition network built from ``TCN_GCN_unit`` blocks.

    ``forward`` accepts either a 5-D tensor unpacked as (N, C, T, V, M), or a
    3-D tensor (N, T, V*C) treated as a single person (M = 1). It returns the
    class logits and a flattened graph summary averaged over persons and
    channels.

    Args:
        graph: dotted import path of the graph class (required).
        graph_args: keyword arguments for that class (defaults to none).
    """

    def __init__(self, num_class=155, num_point=17, num_person=2, graph=None, graph_args=None, in_channels=3,
                 drop_out=0, adaptive=True):
        super(Model, self).__init__()
        if graph is None:
            raise ValueError('Model requires `graph`: a dotted import path of the graph class')
        Graph = import_class(graph)
        self.graph = Graph(**({} if graph_args is None else graph_args))
        A = self.graph.A  # subset-stacked adjacency; unit_gcn uses A.shape[0] and A[i]

        self.num_class = num_class
        self.num_point = num_point
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)

        base_channel = 64
        # self.l1 = TCN_GCN_unit(in_channels, base_channel, A, residual=False, adaptive=adaptive)
        # self.l2 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        # self.l3 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        # self.l4 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        # self.l5 = TCN_GCN_unit(base_channel, base_channel * 2, A, stride=2, adaptive=adaptive)
        # self.l6 = TCN_GCN_unit(base_channel * 2, base_channel * 2, A, adaptive=adaptive)
        # self.l7 = TCN_GCN_unit(base_channel * 2, base_channel * 2, A, adaptive=adaptive)
        # self.l8 = TCN_GCN_unit(base_channel * 2, base_channel * 4, A, stride=2, adaptive=adaptive)
        # self.l9 = TCN_GCN_unit(base_channel * 4, base_channel * 4, A, adaptive=adaptive)
        # self.l10 = TCN_GCN_unit(base_channel * 4, base_channel * 4, A, adaptive=adaptive)
        # self.fc = nn.Linear(base_channel * 4, num_class)
        self.l1 = TCN_GCN_unit(in_channels, base_channel, A, residual=False, adaptive=adaptive)
        self.l2 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        self.l3 = TCN_GCN_unit(base_channel, base_channel * 2, A, stride=2, adaptive=adaptive)
        self.l4 = TCN_GCN_unit(base_channel * 2, base_channel * 2, A, adaptive=adaptive)
        self.l5 = TCN_GCN_unit(base_channel * 2, base_channel * 2, A, adaptive=adaptive)
        self.l6 = TCN_GCN_unit(base_channel * 2, base_channel * 4, A, stride=2, adaptive=adaptive)
        self.l7 = TCN_GCN_unit(base_channel * 4, base_channel * 4, A, adaptive=adaptive)

        self.fc = nn.Linear(base_channel * 4, num_class)
        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)
        if drop_out:
            self.drop_out = nn.Dropout(drop_out)
        else:
            self.drop_out = lambda x: x

    def partDivison(self, graph):
        """Average a 4-D graph tensor over body-part joint groups on its last axis.

        Args:
            graph: tensor of shape (n, k, u, v) whose last axis is indexed by
                joint id (ids up to 16 are used).

        Returns:
            Tensor of shape (n, k, u, 6): one column per part (head, left arm,
            right arm, torso, left leg, right leg), each the mean over that
            part's joints.

        Raises:
            ValueError: ``graph`` is not 4-D.
        """
        if graph.dim() != 4:
            raise ValueError('expected a 4-D graph tensor (n, k, u, v), got shape {}'.format(tuple(graph.shape)))
        head = [0, 1, 2, 3, 4, 5, 6]  # nose, eyes, and ears
        # NOTE(review): ids 5 and 6 are also used by the arm and torso parts;
        # confirm the head grouping against the skeleton layout.
        left_arm = [5, 7, 9]
        right_arm = [6, 8, 10]
        torso = [5, 6, 11, 12]
        left_leg = [11, 13, 15]
        right_leg = [12, 14, 16]
        # Flat list of parts. The former extra nesting ([[head, ...]]) built a
        # single ragged index, which torch cannot index with.
        part_list = [head, left_arm, right_arm, torso, left_leg, right_leg]
        return torch.cat([graph[:, :, :, part].mean(dim=-1, keepdim=True) for part in part_list], -1)

    def forward(self, x):
        if torch.isnan(x).any() or torch.isinf(x).any():
            print("Input data contains NaN or Inf.")
        if len(x.shape) == 3:
            # (N, T, V*C) single-person input -> (N, C, T, V, 1)
            N, T, VC = x.shape
            x = x.view(N, T, self.num_point, -1).permute(0, 3, 1, 2).contiguous().unsqueeze(-1)
        N, C, T, V, M = x.size()

        # One BatchNorm1d feature per (person, joint, channel).
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        # Fold persons into the batch: (N*M, C, T, V).
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

        x, _ = self.l1(x)
        x, _ = self.l2(x)
        x, _ = self.l3(x)
        x, _ = self.l4(x)
        x, _ = self.l5(x)
        x, _ = self.l6(x)
        x, graph = self.l7(x)

        # N*M,C,T,V -> average over time/joints, then over persons.
        c_new = x.size(1)
        x = x.view(N, M, c_new, -1)
        x = x.mean(3).mean(1)
        x = self.drop_out(x)

        # Graphs from l7, regrouped as (N, M, subsets, c_new, V, V), averaged over
        # persons and then channels, and flattened per sample.
        graph2 = graph.view(N, M, -1, c_new, V, V).mean(1).mean(2).view(N, -1)
        return self.fc(x), graph2
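# NOTE: The lines below appear to be residue of the web page this code was
# copied from (a title, an author/date line, and the page's "相关推荐" /
# "Related articles" list). They are kept as comments so the module parses.
# atttention1111
# yyfhq2024-11-16 16:25
# 相关推荐
# Wishell201511 分钟前
# 为什么深度学习和神经网络要使用 GPU?起名字什么的好难27 分钟前
# conda虚拟环境安装pytorch gpu版18号房客34 分钟前
# 计算机视觉-人工智能(AI)入门教程一百家方案36 分钟前
# 「下载」智慧产业园区-数字孪生建设解决方案:重构产业全景图,打造虚实结合的园区数字化底座云起无垠42 分钟前
# “AI+Security”系列第4期(一)之“洞” 见未来:AI 驱动的漏洞挖掘新范式Auc241 小时前
# 使用scrapy框架爬取微博热搜榜QQ_7781329741 小时前
# 基于深度学习的图像超分辨率重建梦想画家1 小时前
# Python Polars快速入门指南:LazyFrames清 晨1 小时前
# Web3 生态全景:创新与发展之路程序猿000001号1 小时前
# 使用Python的Seaborn库进行数据可视化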