PointPillarV2VNet
python
# -*- coding: utf-8 -*-
# Author: Hao Xiang <haxiang@g.ucla.edu>, Runsheng Xu <rxx3386@ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib
import torch.nn as nn
from opencood.models.sub_modules.pillar_vfe import PillarVFE
from opencood.models.sub_modules.point_pillar_scatter import PointPillarScatter
from opencood.models.sub_modules.base_bev_backbone import BaseBEVBackbone
from opencood.models.sub_modules.downsample_conv import DownsampleConv
from opencood.models.sub_modules.naive_compress import NaiveCompressor
from opencood.models.fuse_modules.v2v_fuse import V2VNetFusion
class PointPillarV2VNet(nn.Module):
    """PointPillar-based V2VNet for cooperative BEV object detection.

    Pipeline: pillar VFE -> scatter to BEV -> BEV backbone ->
    (optional shrink / compression) -> V2VNet fusion -> cls/reg heads.
    """

    def __init__(self, args):
        super(PointPillarV2VNet, self).__init__()
        self.max_cav = args['max_cav']
        # Pillar feature encoder: each point carries 4 features (x, y, z, intensity).
        self.pillar_vfe = PillarVFE(args['pillar_vfe'],
                                    num_point_features=4,
                                    voxel_size=args['voxel_size'],
                                    point_cloud_range=args['lidar_range'])
        # Scatter pillar features back onto a dense BEV grid.
        self.scatter = PointPillarScatter(args['point_pillar_scatter'])
        self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64)
        # Optionally downsample the BEV feature map for efficient computation.
        self.shrink_flag = 'shrink_header' in args
        if self.shrink_flag:
            self.shrink_conv = DownsampleConv(args['shrink_header'])
        # Optionally apply naive channel compression before fusion.
        self.compression = args['compression'] > 0
        if self.compression:
            self.naive_compressor = NaiveCompressor(256, args['compression'])
        self.fusion_net = V2VNetFusion(args['v2vfusion'])
        # Detection heads operate on the 256-channel fused map.
        self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'],
                                  kernel_size=1)
        # 7 regression targets (x, y, z, h, w, l, yaw) per anchor —
        # presumably; confirm against the loss/post-processing code.
        self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'],
                                  kernel_size=1)
        if args['backbone_fix']:
            self.backbone_fix()

    def backbone_fix(self):
        """Freeze every module except the fusion net for time-delay finetuning."""
        frozen = [self.pillar_vfe, self.scatter, self.backbone]
        if self.compression:
            frozen.append(self.naive_compressor)
        if self.shrink_flag:
            frozen.append(self.shrink_conv)
        frozen.extend([self.cls_head, self.reg_head])
        for module in frozen:
            for param in module.parameters():
                param.requires_grad = False

    def forward(self, data_dict):
        """Run detection on a batch of (possibly multi-agent) point clouds.

        Parameters
        ----------
        data_dict : dict
            Must contain 'processed_lidar' (voxel tensors), 'record_len'
            (agents per sample) and 'pairwise_t_matrix'.

        Returns
        -------
        dict
            'psm': classification logits, 'rm': box regression map.
        """
        lidar = data_dict['processed_lidar']
        record_len = data_dict['record_len']
        pairwise_t_matrix = data_dict['pairwise_t_matrix']
        batch_dict = {'voxel_features': lidar['voxel_features'],
                      'voxel_coords': lidar['voxel_coords'],
                      'voxel_num_points': lidar['voxel_num_points'],
                      'record_len': record_len}
        batch_dict = self.pillar_vfe(batch_dict)   # (n, 4) -> (n, c)
        batch_dict = self.scatter(batch_dict)      # (n, c) -> (N, C, H, W)
        batch_dict = self.backbone(batch_dict)
        features = batch_dict['spatial_features_2d']
        # Downsample the feature map to reduce memory, if configured.
        if self.shrink_flag:
            features = self.shrink_conv(features)
        # Compress channels before fusion, if configured.
        if self.compression:
            features = self.naive_compressor(features)
        fused_feature = self.fusion_net(features,
                                        record_len,
                                        pairwise_t_matrix)
        return {'psm': self.cls_head(fused_feature),
                'rm': self.reg_head(fused_feature)}
以下是对提供的代码逐行解析:
python
# -*- coding: utf-8 -*-
# Author: Hao Xiang <haxiang@g.ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib
这些是文件的元信息注释,指定了文件的编码方式为 UTF-8,提供了作者信息和许可证信息。
python
"""
Implementation of V2VNet Fusion
"""
这是一个文档字符串,说明了该文件的目的,即实现了 V2VNet Fusion。
python
import torch
import torch.nn as nn
from opencood.models.sub_modules.torch_transformation_utils import \
get_discretized_transformation_matrix, get_transformation_matrix, \
warp_affine, get_rotated_roi
from opencood.models.sub_modules.convgru import ConvGRU
导入所需的模块和类,包括 PyTorch 库和自定义模块中的一些工具函数和 ConvGRU 类。
python
class V2VNetFusion(nn.Module):
def __init__(self, args):
super(V2VNetFusion, self).__init__()
定义了一个名为 V2VNetFusion 的类,它继承自 nn.Module,表示它是一个 PyTorch 模型。
python
in_channels = args['in_channels']
H, W = args['conv_gru']['H'], args['conv_gru']['W']
kernel_size = args['conv_gru']['kernel_size']
num_gru_layers = args['conv_gru']['num_layers']
self.discrete_ratio = args['voxel_size'][0]
self.downsample_rate = args['downsample_rate']
self.num_iteration = args['num_iteration']
self.gru_flag = args['gru_flag']
self.agg_operator = args['agg_operator']
初始化了一些参数,包括输入通道数、卷积 GRU 的尺寸和参数、离散比率、下采样率、迭代次数、GRU 标志和聚合操作符。
python
self.msg_cnn = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3,
stride=1, padding=1)
self.conv_gru = ConvGRU(input_size=(H, W),
input_dim=in_channels * 2,
hidden_dim=[in_channels],
kernel_size=kernel_size,
num_layers=num_gru_layers,
batch_first=True,
bias=True,
return_all_layers=False)
self.mlp = nn.Linear(in_channels, in_channels)
定义了消息卷积层 msg_cnn、卷积 GRU conv_gru 和全连接层 mlp。
python
def regroup(self, x, record_len):
cum_sum_len = torch.cumsum(record_len, dim=0)
split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
return split_x
定义了一个辅助函数 regroup,用于将张量 x 按照记录长度 record_len 进行划分。
python
def forward(self, x, record_len, pairwise_t_matrix):
定义了前向传播函数 forward,它接受输入数据 x、记录长度 record_len 和从每个 cav 到 ego 的配对变换矩阵 pairwise_t_matrix。
python
_, C, H, W = x.shape
B, L = pairwise_t_matrix.shape[:2]
获取输入数据 x 的形状,并获取配对变换矩阵的大小。
python
split_x = self.regroup(x, record_len)
将输入数据 x 按记录长度划分成多个子数据集。
python
pairwise_t_matrix = get_discretized_transformation_matrix(
pairwise_t_matrix.reshape(-1, L, 4, 4), self.discrete_ratio,
self.downsample_rate).reshape(B, L, L, 2, 3)
将配对变换矩阵进行离散化处理,得到变换后的矩阵。
【变换矩阵】通常是连续的
- 连续空间的性质:通常情况下,我们对于空间的描述是连续的,比如三维空间中的坐标轴是连续的实数轴。因此,变换矩阵描述的是从一个连续的空间到另一个连续的空间的变换,因此它们自然是连续的。
- 变换的流畅性:在实际应用中,我们通常希望变换过程是连续且平滑的,这样可以避免在变换过程中出现不连续或突变的情况,从而保持图像或对象的连续性和自然性。
- 数学表达的连续性:变换矩阵的数学表达通常是连续函数或者由连续函数组成,比如线性变换、仿射变换等。这些变换通常使用连续函数来描述,因此它们在整个定义域上是连续的。
以下是进行离散化处理的主要原因:
- 计算效率:在计算机中,离散化处理可以将连续的数据转换为离散的数据,这样可以更容易地进行计算和处理。离散化处理通常涉及到将连续空间转换为离散的像素或单元格,这样的处理更适合于计算机上的处理。
- 实现简易性:离散化处理可以使算法的实现更加简单,因为它减少了对连续变量的处理和计算。在图形学中,比如对三维物体的变换,将连续的空间转换为离散的像素可以更容易地应用在屏幕上。
- 数值稳定性:在数值计算中,对于某些变换或计算,连续的数据可能会导致数值不稳定性或数值误差的累积。离散化可以一定程度上减少这些问题的影响,因为它将连续的数据分割成离散的部分,降低了计算过程中的复杂性。
- 采样与量化:在信号处理中,离散化处理是常见的,因为信号通常以离散的形式进行采样和量化。这种处理使得信号可以更容易地存储、传输和处理。
python
roi_mask = get_rotated_roi((B * L, L, 1, H, W),
pairwise_t_matrix.reshape(B * L * L, 2, 3))
roi_mask = roi_mask.reshape(B, L, L, 1, H, W)
batch_node_features = split_x
根据变换矩阵获取 ROI(感兴趣区域)掩码,并将输入数据划分成子数据集。
python
for l in range(self.num_iteration):
循环执行迭代次数。
python
batch_updated_node_features = []
初始化用于存储更新后节点特征的列表。
python
for b in range(B):
遍历每个批次。
python
N = record_len[b]
t_matrix = pairwise_t_matrix[b][:N, :N, :, :]
updated_node_features = []
获取当前批次中记录的数量,以及对应的配对变换矩阵。
python
for i in range(N):
遍历每个节点。
python
mask = roi_mask[b, :N, i, ...]
current_t_matrix = t_matrix[:, i, :, :]
current_t_matrix = get_transformation_matrix(
current_t_matrix, (H, W))
获取当前节点的掩码和变换矩阵。
python
neighbor_feature = warp_affine(batch_node_features[b],
current_t_matrix,
(H, W))
ego_agent_feature = batch_node_features[b][i].unsqueeze(
0).repeat(N, 1, 1, 1)
neighbor_feature = torch.cat(
[neighbor_feature, ego_agent_feature], dim=1)
message = self.msg_cnn(neighbor_feature) * mask
获取邻居特征并进行拼接,并通过消息卷积层计算消息。
python
if self.agg_operator=="avg":
agg_feature = torch.mean(message, dim=0)
elif self.agg_operator=="max":
agg_feature = torch.max(message, dim=0)[0]
else:
raise ValueError("agg_operator has wrong value")
根据聚合操作符计算聚合特征。
python
cat_feature = torch.cat(
[batch_node_features[b][i, ...], agg_feature], dim=0)
将原始特征和聚合特征进行拼接。
python
if self.gru_flag:
gru_out = \
self.conv_gru(cat_feature.unsqueeze(0).unsqueeze(0))[
0][
0].squeeze(0).squeeze(0)
else:
gru_out = batch_node_features[b][i, ...] + agg_feature
使用 GRU 网络更新特征。
python
updated_node_features.append(gru_out.unsqueeze(0))
batch_updated_node_features.append(
torch.cat(updated_node_features, dim=0))
将更新后的特征添加到列表中。
python
batch_node_features = batch_updated_node_features
更新批次中的节点特征。
python
out = torch.cat(
[itm[0, ...].unsqueeze(0) for itm in batch_node_features], dim=0)
将更新后的特征进行拼接。
python
out = self.mlp(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
通过全连接层对特征进行处理。
python
return out
返回处理后的特征。
V2VNetFusion

python
# -*- coding: utf-8 -*-
# Author: Hao Xiang <haxiang@g.ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib
"""
Implementation of V2VNet Fusion
"""
import torch
import torch.nn as nn
from opencood.models.sub_modules.torch_transformation_utils import \
get_discretized_transformation_matrix, get_transformation_matrix, \
warp_affine, get_rotated_roi
from opencood.models.sub_modules.convgru import ConvGRU
class V2VNetFusion(nn.Module):
    """V2VNet feature fusion via iterative spatial message passing.

    Each agent warps its neighbors' BEV features into its own frame,
    aggregates them with a message CNN, and updates its node feature
    with a ConvGRU (or a residual add) for ``num_iteration`` rounds.
    """

    def __init__(self, args):
        """Build the fusion module from a config dict.

        Parameters
        ----------
        args : dict
            Expects 'in_channels', 'conv_gru' (H, W, kernel_size,
            num_layers), 'voxel_size', 'downsample_rate',
            'num_iteration', 'gru_flag', 'agg_operator'.
        """
        super(V2VNetFusion, self).__init__()
        in_channels = args['in_channels']
        H, W = args['conv_gru']['H'], args['conv_gru']['W']
        kernel_size = args['conv_gru']['kernel_size']
        num_gru_layers = args['conv_gru']['num_layers']
        # Voxel size along x is used to discretize the transformation
        # matrices onto the (downsampled) BEV grid.
        self.discrete_ratio = args['voxel_size'][0]
        self.downsample_rate = args['downsample_rate']
        self.num_iteration = args['num_iteration']
        self.gru_flag = args['gru_flag']
        self.agg_operator = args['agg_operator']
        # Message CNN: consumes [neighbor || ego] features (2C channels).
        self.msg_cnn = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3,
                                 stride=1, padding=1)
        # ConvGRU update: input is [node || aggregated message] (2C channels).
        self.conv_gru = ConvGRU(input_size=(H, W),
                                input_dim=in_channels * 2,
                                hidden_dim=[in_channels],
                                kernel_size=kernel_size,
                                num_layers=num_gru_layers,
                                batch_first=True,
                                bias=True,
                                return_all_layers=False)
        self.mlp = nn.Linear(in_channels, in_channels)

    def regroup(self, x, record_len):
        """Split the stacked agent features back into per-sample chunks.

        ``x`` is (sum(record_len), C, H, W); returns a tuple of tensors,
        one (L_i, C, H, W) chunk per batch sample.
        """
        cum_sum_len = torch.cumsum(record_len, dim=0)
        split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
        return split_x

    def forward(self, x, record_len, pairwise_t_matrix):
        """
        Fusion forwarding.

        Parameters
        ----------
        x : torch.Tensor
            Input data, (sum(record_len), C, H, W).
        record_len : torch.Tensor
            Number of valid agents per sample, shape (B).
            NOTE(review): documented elsewhere as a list, but
            ``torch.cumsum`` in :meth:`regroup` requires a tensor.
        pairwise_t_matrix : torch.Tensor
            The transformation matrix from each cav to ego,
            shape: (B, L, L, 4, 4).

        Returns
        -------
        torch.Tensor
            Fused ego feature, (B, C, H, W).
        """
        _, C, H, W = x.shape
        B, L = pairwise_t_matrix.shape[:2]
        # split x: [(L1, C, H, W), (L2, C, H, W), ...]
        split_x = self.regroup(x, record_len)
        # Discretize the continuous SE(3) transforms into 2x3 affine
        # matrices on the BEV pixel grid -> (B, L, L, 2, 3).
        pairwise_t_matrix = get_discretized_transformation_matrix(
            pairwise_t_matrix.reshape(-1, L, 4, 4), self.discrete_ratio,
            self.downsample_rate).reshape(B, L, L, 2, 3)
        # Validity mask of each warped field of view: (B*L, L, 1, H, W).
        roi_mask = get_rotated_roi((B * L, L, 1, H, W),
                                   pairwise_t_matrix.reshape(B * L * L, 2, 3))
        roi_mask = roi_mask.reshape(B, L, L, 1, H, W)
        batch_node_features = split_x
        # Iteratively update the node features for num_iteration rounds.
        for l in range(self.num_iteration):
            batch_updated_node_features = []
            # Iterate each sample in the batch.
            for b in range(B):
                # Number of valid agents in this sample.
                N = record_len[b]
                # (N, N, 2, 3); t_matrix[i, j] -> transform from i to j.
                t_matrix = pairwise_t_matrix[b][:N, :N, :, :]
                updated_node_features = []
                # Update each node i from all its neighbors.
                for i in range(N):
                    # (N, 1, H, W) validity mask for warps into frame i.
                    mask = roi_mask[b, :N, i, ...]
                    current_t_matrix = t_matrix[:, i, :, :]
                    current_t_matrix = get_transformation_matrix(
                        current_t_matrix, (H, W))
                    # Warp every agent's feature into node i's frame: (N, C, H, W).
                    neighbor_feature = warp_affine(batch_node_features[b],
                                                   current_t_matrix,
                                                   (H, W))
                    # Ego feature repeated once per neighbor: (N, C, H, W).
                    ego_agent_feature = batch_node_features[b][i].unsqueeze(
                        0).repeat(N, 1, 1, 1)
                    # (N, 2C, H, W)
                    neighbor_feature = torch.cat(
                        [neighbor_feature, ego_agent_feature], dim=1)
                    # Messages, masked to the valid warped region: (N, C, H, W).
                    message = self.msg_cnn(neighbor_feature) * mask
                    # Aggregate messages across neighbors -> (C, H, W).
                    if self.agg_operator=="avg":
                        agg_feature = torch.mean(message, dim=0)
                    elif self.agg_operator=="max":
                        agg_feature = torch.max(message, dim=0)[0]
                    else:
                        raise ValueError("agg_operator has wrong value")
                    # (2C, H, W): [current node state || aggregated message].
                    cat_feature = torch.cat(
                        [batch_node_features[b][i, ...], agg_feature], dim=0)
                    # GRU update (or plain residual add) -> (C, H, W).
                    if self.gru_flag:
                        # ConvGRU expects (b, t, c, h, w); unwrap the single
                        # layer's single-timestep output afterwards.
                        gru_out = \
                            self.conv_gru(cat_feature.unsqueeze(0).unsqueeze(0))[
                                0][
                                0].squeeze(0).squeeze(0)
                    else:
                        gru_out = batch_node_features[b][i, ...] + agg_feature
                    updated_node_features.append(gru_out.unsqueeze(0))
                # (N, C, H, W) updated states for this sample.
                batch_updated_node_features.append(
                    torch.cat(updated_node_features, dim=0))
            batch_node_features = batch_updated_node_features
        # Keep only the ego node (index 0) of each sample: (B, C, H, W).
        out = torch.cat(
            [itm[0, ...].unsqueeze(0) for itm in batch_node_features], dim=0)
        # Per-pixel linear projection over channels: (B, C, H, W).
        out = self.mlp(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return out
以下是解析
python
class V2VNetFusion(nn.Module):
def __init__(self, args):
super(V2VNetFusion, self).__init__()
定义了一个名为 V2VNetFusion 的类,它继承自 nn.Module,表示它是一个 PyTorch 模型
python
in_channels = args['in_channels']
H, W = args['conv_gru']['H'], args['conv_gru']['W']
kernel_size = args['conv_gru']['kernel_size']
num_gru_layers = args['conv_gru']['num_layers']
self.discrete_ratio = args['voxel_size'][0]
self.downsample_rate = args['downsample_rate']
self.num_iteration = args['num_iteration']
self.gru_flag = args['gru_flag']
self.agg_operator = args['agg_operator']
self.msg_cnn = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3,
stride=1, padding=1)
self.conv_gru = ConvGRU(input_size=(H, W),
input_dim=in_channels * 2,
hidden_dim=[in_channels],
kernel_size=kernel_size,
num_layers=num_gru_layers,
batch_first=True,
bias=True,
return_all_layers=False)
self.mlp = nn.Linear(in_channels, in_channels)
定义了模型的子模块,包括消息卷积层 msg_cnn、ConvGRU 层 conv_gru 和全连接层 mlp。
def regroup(self, x, record_len):
cum_sum_len = torch.cumsum(record_len, dim=0)
split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
return split_x
定义了一个辅助函数 regroup,用于将输入数据 x 按照记录长度 record_len 进行划分。
def forward(self, x, record_len, pairwise_t_matrix):
...
定义了模型的前向传播函数 forward,接受输入数据 x、记录长度 record_len 和从每个 cav 到 ego 的配对变换矩阵 pairwise_t_matrix。
_, C, H, W = x.shape
B, L = pairwise_t_matrix.shape[:2]
获取输入数据 x 的形状,并获取配对变换矩阵的大小。
python
split_x = self.regroup(x, record_len)
将输入数据 x 按记录长度划分成多个子数据集。
pairwise_t_matrix = get_discretized_transformation_matrix(
pairwise_t_matrix.reshape(-1, L, 4, 4), self.discrete_ratio,
self.downsample_rate).reshape(B, L, L, 2, 3)
将配对变换矩阵进行离散化处理,得到变换后的矩阵。
roi_mask = get_rotated_roi((B * L, L, 1, H, W),
pairwise_t_matrix.reshape(B * L * L, 2, 3))
roi_mask = roi_mask.reshape(B, L, L, 1, H, W)
batch_node_features = split_x
根据变换矩阵获取 ROI(感兴趣区域)掩码,并将输入数据划分成子数据集。
for l in range(self.num_iteration):
循环执行迭代次数。
batch_updated_node_features = []
# iterate each batch
for b in range(B):
# number of valid agent
N = record_len[b]
t_matrix = pairwise_t_matrix[b][:N, :N, :, :]
updated_node_features = []
获取当前批次中记录的数量,以及对应的配对变换矩阵。
for i in range(N):
mask = roi_mask[b, :N, i, ...]
current_t_matrix = t_matrix[:, i, :, :]
current_t_matrix = get_transformation_matrix(
current_t_matrix, (H, W))
获取当前节点的掩码和变换矩阵。
neighbor_feature = warp_affine(batch_node_features[b],
current_t_matrix,
(H, W))
ego_agent_feature = batch_node_features[b][i].unsqueeze(
0).repeat(N, 1, 1, 1)
neighbor_feature = torch.cat(
[neighbor_feature, ego_agent_feature], dim=1)
message = self.msg_cnn(neighbor_feature) * mask
获取邻居特征并进行拼接,并通过消息卷积层计算消息。
if self.agg_operator=="avg":
agg_feature = torch.mean(message, dim=0)
elif self.agg_operator=="max":
agg_feature = torch.max(message, dim=0)[0]
else:
raise ValueError("agg_operator has wrong value")
根据聚合操作符计算聚合特征。
cat_feature = torch.cat(
[batch_node_features[b][i, ...], agg_feature], dim=0)
将原始特征和聚合特征进行拼接。
if self.gru_flag:
gru_out = \
self.conv_gru(cat_feature.unsqueeze(0).unsqueeze(0))[
0][
0].squeeze(0).squeeze(0)
else:
gru_out = batch_node_features[b][i, ...] + agg_feature
updated_node_features.append(gru_out.unsqueeze(0))
使用 GRU 网络更新特征。
batch_updated_node_features.append(
torch.cat(updated_node_features, dim=0))
batch_node_features = batch_updated_node_features
更新批次中的节点特征。
out = torch.cat(
[itm[0, ...].unsqueeze(0) for itm in batch_node_features], dim=0)
将更新后的特征进行拼接。
out = self.mlp(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
通过全连接层对特征进行处理。
return out
返回处理后的特征。
ConvGRU
ConvGRU(卷积门控循环单元)是一种结合了卷积神经网络(CNN)和循环神经网络(RNN)思想的模型,用于处理时序数据。它类似于传统的循环神经网络,但具有更好地捕捉空间信息的能力。
ConvGRU的核心是ConvGRUCell,它是ConvGRU的基本单元,类似于LSTM中的LSTMCell。与LSTM类似,GRU也有更新门(update gate)和重置门(reset gate)。这些门控制着隐藏状态中的信息流动,使得模型能够根据当前输入和前一个隐藏状态选择性地更新或保留信息。更新门决定了多少旧信息应该被保留,而重置门决定了多少旧信息应该被忽略。这样,GRU能够在处理序列数据时更好地捕捉长期依赖关系。卷积操作用于捕捉输入数据中的空间结构。在ConvGRUCell中,卷积操作能够更好地处理图像等具有空间关系的数据。
在每个时间步,ConvGRUCell接收当前输入和前一个隐藏状态作为输入。
- 首先,通过卷积操作处理当前输入,以捕获输入数据的空间结构。
- 然后,利用GRU的门控机制,根据当前输入和前一个隐藏状态生成更新门和重置门,从而控制信息的流动。
- 最后,根据门控制的信息流动,结合当前输入和前一个隐藏状态,生成新的隐藏状态。
python
class ConvGRU(nn.Module):
    """Multi-layer convolutional GRU over a sequence of 2D feature maps."""

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size,
                 num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        """
        :param input_size: (int, int)
            Height and width of input tensor as (height, width).
        :param input_dim: int e.g. 256
            Number of channels of input tensor.
        :param hidden_dim: int or list of int
            Number of channels of the hidden state, per layer.
        :param kernel_size: (int, int) or list thereof
            Size of the convolutional kernel, per layer.
        :param num_layers: int
            Number of stacked ConvGRU layers.
        :param batch_first: bool
            If True the input is (b, t, c, h, w), otherwise (t, b, c, h, w).
        :param bias: bool
            Whether or not to add the bias.
        :param return_all_layers: bool
            If True return hidden states for all layers, else last layer only.
        """
        super(ConvGRU, self).__init__()
        # Make sure that both `kernel_size` and
        # `hidden_dim` are lists having len == num_layers.
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')
        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers
        cell_list = []
        for i in range(0, self.num_layers):
            # Layer 0 consumes the raw input; deeper layers consume the
            # previous layer's hidden state.
            cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]
            cell_list.append(ConvGRUCell(input_size=(self.height, self.width),
                                         input_dim=cur_input_dim,
                                         hidden_dim=self.hidden_dim[i],
                                         kernel_size=self.kernel_size[i],
                                         bias=self.bias))
        # Convert python list to pytorch module list so cells are registered.
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        :param input_tensor: (b, t, c, h, w) or (t, b, c, h, w)
            depending on ``batch_first``.
        :param hidden_state: must be None; stateful operation is not
            implemented.
        :return: layer_output_list, last_state_list
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
        # Stateful ConvGRU (caller-provided hidden state) is unsupported.
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            hidden_state = self._init_hidden(batch_size=input_tensor.size(0),
                                             device=input_tensor.device,
                                             dtype=input_tensor.dtype)
        layer_output_list = []
        last_state_list = []
        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor
        for layer_idx in range(self.num_layers):
            h = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                # Feed the current input frame and previous hidden state
                # to this layer's cell to get the next hidden state.
                h = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :],  # (b,t,c,h,w)
                    h_cur=h)
                output_inner.append(h)
            # Stack the per-timestep states and feed them to the next layer.
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output
            layer_output_list.append(layer_output)
            last_state_list.append([h])
        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]
        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, device=None, dtype=None):
        # Ask each cell for its zero-initialized hidden state.
        init_states = []
        for i in range(self.num_layers):
            init_states.append(
                self.cell_list[i].init_hidden(batch_size).to(device).to(dtype))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        # NOTE(review): defined but never called within this class.
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all(
                    [isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        # Broadcast a scalar parameter to a per-layer list.
        if not isinstance(param, list):
            param = [param] * num_layers
        return param
上述代码定义了一个名为 ConvGRU 的类,用于实现卷积门控循环单元(Convolutional Gated Recurrent Unit,ConvGRU)的模型。以下是对代码的逐行解析:
class ConvGRU(nn.Module):
def __init__(self, input_size, input_dim, hidden_dim, kernel_size,
num_layers,
batch_first=False, bias=True, return_all_layers=False):
super(ConvGRU, self).__init__()
定义了 ConvGRU 类,继承自 nn.Module。构造函数中接收了输入尺寸 input_size、输入通道数 input_dim、隐藏状态通道数 hidden_dim、卷积核大小 kernel_size、层数 num_layers 等参数。
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
if not len(kernel_size) == len(hidden_dim) == num_layers:
raise ValueError('Inconsistent list length.')
对 kernel_size 和 hidden_dim 进行扩展,确保它们的长度与层数一致。
self.height, self.width = input_size
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.kernel_size = kernel_size
self.num_layers = num_layers
self.batch_first = batch_first
self.bias = bias
self.return_all_layers = return_all_layers
初始化了类中的一些属性,包括输入的高度和宽度、输入通道数、隐藏状态通道数、卷积核大小、层数、是否以 batch 为第一维、是否使用偏置、是否返回所有层的状态。
cell_list = []
for i in range(0, self.num_layers):
cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]
cell_list.append(ConvGRUCell(input_size=(self.height, self.width),
input_dim=cur_input_dim,
hidden_dim=self.hidden_dim[i],
kernel_size=self.kernel_size[i],
bias=self.bias))
创建了一个列表 cell_list,用于存储每一层的 ConvGRUCell 实例。
self.cell_list = nn.ModuleList(cell_list)
将 cell_list 转换为 nn.ModuleList 类型,以确保每个 ConvGRUCell 实例都被正确注册为模型的子模块。
def forward(self, input_tensor, hidden_state=None):
if not self.batch_first:
input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
如果不是以 batch 为第一维,需要对输入数据进行转置,将 batch 维度置于第一维。
if hidden_state is not None:
raise NotImplementedError()
else:
hidden_state = self._init_hidden(batch_size=input_tensor.size(0),
device=input_tensor.device,
dtype=input_tensor.dtype)
如果传入了隐藏状态,则抛出 NotImplementedError;否则,根据输入数据的大小和设备信息初始化隐藏状态。
layer_output_list = []
last_state_list = []
定义了用于存储每一层输出和最终状态的列表。
seq_len = input_tensor.size(1)
cur_layer_input = input_tensor
获取输入序列的长度,并初始化当前层的输入。
for layer_idx in range(self.num_layers):
h = hidden_state[layer_idx]
output_inner = []
for t in range(seq_len):
h = self.cell_list[layer_idx](
input_tensor=cur_layer_input[:, t, :, :, :],
h_cur=h)
output_inner.append(h)
遍历每一层,对于每一时刻,利用当前层的 ConvGRUCell 实例更新隐藏状态,并将更新后的隐藏状态存入列表中。
layer_output = torch.stack(output_inner, dim=1)
cur_layer_input = layer_output
layer_output_list.append(layer_output)
last_state_list.append([h])
将每一层的输出存入列表 layer_output_list 中,并将最后一个时刻的隐藏状态存入 last_state_list 中。
if not self.return_all_layers:
layer_output_list = layer_output_list[-1:]
last_state_list = last_state_list[-1:]
如果设置了不返回所有层的状态,则只保留最后一层的状态。
return layer_output_list, last_state_list
返回所有层的输出和最终状态。
def _init_hidden(self, batch_size, device=None, dtype=None):
init_states = []
for i in range(self.num_layers):
init_states.append(
self.cell_list[i].init_hidden(batch_size).to(device).to(dtype))
return init_states
辅助方法,用于初始化隐藏状态。
@staticmethod
def _extend_for_multilayer(param, num_layers):
if not isinstance(param, list):
param = [param] * num_layers
return param
辅助方法,用于将单个参数扩展成多层的形式。