【代码解读】Opencood框架之PointPillarV2VNet

PointPillarV2VNet

PointPillarV2VNet

python 复制代码
# -*- coding: utf-8 -*-
# Author: Hao Xiang <haxiang@g.ucla.edu>, Runsheng Xu <rxx3386@ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib


import torch.nn as nn

from opencood.models.sub_modules.pillar_vfe import PillarVFE
from opencood.models.sub_modules.point_pillar_scatter import PointPillarScatter
from opencood.models.sub_modules.base_bev_backbone import BaseBEVBackbone
from opencood.models.sub_modules.downsample_conv import DownsampleConv
from opencood.models.sub_modules.naive_compress import NaiveCompressor
from opencood.models.fuse_modules.v2v_fuse import V2VNetFusion


class PointPillarV2VNet(nn.Module):
    """PointPillar encoder with V2VNet fusion for cooperative 3D detection.

    Pipeline: PillarVFE -> PointPillarScatter -> BaseBEVBackbone
    -> (optional shrink conv) -> (optional naive compressor)
    -> V2VNetFusion -> classification / regression heads.

    Parameters
    ----------
    args : dict
        Model configuration. Required keys: 'max_cav', 'pillar_vfe',
        'voxel_size', 'lidar_range', 'point_pillar_scatter',
        'base_bev_backbone', 'v2vfusion', 'anchor_number'.
        Optional keys: 'shrink_header', 'compression', 'backbone_fix'.
    """

    def __init__(self, args):
        super(PointPillarV2VNet, self).__init__()

        self.max_cav = args['max_cav']
        # Pillar VFE: encodes raw points (n, 4) into pillar features.
        self.pillar_vfe = PillarVFE(args['pillar_vfe'],
                                    num_point_features=4,
                                    voxel_size=args['voxel_size'],
                                    point_cloud_range=args['lidar_range'])
        self.scatter = PointPillarScatter(args['point_pillar_scatter'])
        self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64)

        # used to downsample the feature map for efficient computation
        self.shrink_flag = 'shrink_header' in args
        if self.shrink_flag:
            self.shrink_conv = DownsampleConv(args['shrink_header'])

        # Robustness: a missing 'compression' key means "disabled" instead
        # of raising KeyError (mirrors the 'shrink_header' in args guard).
        self.compression = args.get('compression', 0) > 0
        if self.compression:
            self.naive_compressor = NaiveCompressor(256, args['compression'])

        self.fusion_net = V2VNetFusion(args['v2vfusion'])

        self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'],
                                  kernel_size=1)
        self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'],
                                  kernel_size=1)
        # Robustness: default to a trainable backbone when key is absent.
        if args.get('backbone_fix', False):
            self.backbone_fix()

    def backbone_fix(self):
        """Freeze all modules except the fusion net, for fine-tuning
        under time delay."""
        for p in self.pillar_vfe.parameters():
            p.requires_grad = False

        for p in self.scatter.parameters():
            p.requires_grad = False

        for p in self.backbone.parameters():
            p.requires_grad = False

        if self.compression:
            for p in self.naive_compressor.parameters():
                p.requires_grad = False
        if self.shrink_flag:
            for p in self.shrink_conv.parameters():
                p.requires_grad = False

        for p in self.cls_head.parameters():
            p.requires_grad = False
        for p in self.reg_head.parameters():
            p.requires_grad = False

    def forward(self, data_dict):
        """Run the full detection pipeline on one collated batch.

        Parameters
        ----------
        data_dict : dict
            Must contain 'processed_lidar' (with 'voxel_features',
            'voxel_coords', 'voxel_num_points'), 'record_len' and
            'pairwise_t_matrix'.

        Returns
        -------
        dict
            {'psm': classification map, 'rm': regression map}.
        """
        voxel_features = data_dict['processed_lidar']['voxel_features']
        voxel_coords = data_dict['processed_lidar']['voxel_coords']
        voxel_num_points = data_dict['processed_lidar']['voxel_num_points']
        record_len = data_dict['record_len']

        pairwise_t_matrix = data_dict['pairwise_t_matrix']

        batch_dict = {'voxel_features': voxel_features,
                      'voxel_coords': voxel_coords,
                      'voxel_num_points': voxel_num_points,
                      'record_len': record_len}
        # n, 4 -> n, c
        batch_dict = self.pillar_vfe(batch_dict)
        # n, c -> N, C, H, W
        batch_dict = self.scatter(batch_dict)
        batch_dict = self.backbone(batch_dict)

        spatial_features_2d = batch_dict['spatial_features_2d']
        # downsample feature to reduce memory
        if self.shrink_flag:
            spatial_features_2d = self.shrink_conv(spatial_features_2d)
        # compressor
        if self.compression:
            spatial_features_2d = self.naive_compressor(spatial_features_2d)
        fused_feature = self.fusion_net(spatial_features_2d,
                                        record_len,
                                        pairwise_t_matrix)

        psm = self.cls_head(fused_feature)
        rm = self.reg_head(fused_feature)

        output_dict = {'psm': psm,
                       'rm': rm}

        return output_dict

以下是对提供的代码逐行解析:

python 复制代码
# -*- coding: utf-8 -*-
# Author: Hao Xiang <haxiang@g.ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib

这些是文件的元信息注释,指定了文件的编码方式为 UTF-8,提供了作者信息和许可证信息。

python 复制代码
"""
Implementation of V2VNet Fusion
"""

这是一个文档字符串,说明了该文件的目的,即实现了 V2VNet Fusion。

python 复制代码
import torch
import torch.nn as nn

from opencood.models.sub_modules.torch_transformation_utils import \
    get_discretized_transformation_matrix, get_transformation_matrix, \
    warp_affine, get_rotated_roi
from opencood.models.sub_modules.convgru import ConvGRU
导入所需的模块和类,包括 PyTorch 库和自定义模块中的一些工具函数和 ConvGRU 类。
python 复制代码
class V2VNetFusion(nn.Module):
    def __init__(self, args):
        super(V2VNetFusion, self).__init__()
定义了一个名为 V2VNetFusion 的类,它继承自 nn.Module,表示它是一个 PyTorch 模型。
python 复制代码
        in_channels = args['in_channels']
        H, W = args['conv_gru']['H'], args['conv_gru']['W']
        kernel_size = args['conv_gru']['kernel_size']
        num_gru_layers = args['conv_gru']['num_layers']

        self.discrete_ratio = args['voxel_size'][0]
        self.downsample_rate = args['downsample_rate']
        self.num_iteration = args['num_iteration']
        self.gru_flag = args['gru_flag']
        self.agg_operator = args['agg_operator']

初始化了一些参数,包括输入通道数、卷积 GRU 的尺寸和参数、离散比率、下采样率、迭代次数、GRU 标志和聚合操作符。

python 复制代码
        self.msg_cnn = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3,
                                 stride=1, padding=1)
        self.conv_gru = ConvGRU(input_size=(H, W),
                                input_dim=in_channels * 2,
                                hidden_dim=[in_channels],
                                kernel_size=kernel_size,
                                num_layers=num_gru_layers,
                                batch_first=True,
                                bias=True,
                                return_all_layers=False)
        self.mlp = nn.Linear(in_channels, in_channels)
定义了消息卷积层 msg_cnn、卷积 GRU conv_gru 和全连接层 mlp。
python 复制代码
    def regroup(self, x, record_len):
        cum_sum_len = torch.cumsum(record_len, dim=0)
        split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
        return split_x

定义了一个辅助函数 regroup,用于将张量 x 按照记录长度 record_len 进行划分。

python 复制代码
    def forward(self, x, record_len, pairwise_t_matrix):

定义了前向传播函数 forward,它接受输入数据 x、记录长度 record_len 和从每个 cav 到 ego 的配对变换矩阵 pairwise_t_matrix。

python 复制代码
        _, C, H, W = x.shape
        B, L = pairwise_t_matrix.shape[:2]

获取输入数据 x 的形状,并获取配对变换矩阵的大小。

python 复制代码
        split_x = self.regroup(x, record_len)

将输入数据 x 按记录长度划分成多个子数据集。

python 复制代码
        pairwise_t_matrix = get_discretized_transformation_matrix(
            pairwise_t_matrix.reshape(-1, L, 4, 4), self.discrete_ratio,
            self.downsample_rate).reshape(B, L, L, 2, 3)

将配对变换矩阵进行离散化处理,得到变换后的矩阵。

【变换矩阵】通常是连续的

  • 连续空间的性质:通常情况下,我们对于空间的描述是连续的,比如三维空间中的坐标轴是连续的实数轴。因此,变换矩阵描述的是从一个连续的空间到另一个连续的空间的变换,因此它们自然是连续的。
  • 变换的流畅性:在实际应用中,我们通常希望变换过程是连续且平滑的,这样可以避免在变换过程中出现不连续或突变的情况,从而保持图像或对象的连续性和自然性。
  • 数学表达的连续性:变换矩阵的数学表达通常是连续函数或者由连续函数组成,比如线性变换、仿射变换等。这些变换通常使用连续函数来描述,因此它们在整个定义域上是连续的。

以下是进行离散化处理的主要原因:

  • 计算效率:在计算机中,离散化处理可以将连续的数据转换为离散的数据,这样可以更容易地进行计算和处理。离散化处理通常涉及到将连续空间转换为离散的像素或单元格,这样的处理更适合于计算机上的处理。
  • 实现简易性:离散化处理可以使算法的实现更加简单,因为它减少了对连续变量的处理和计算。在图形学中,比如对三维物体的变换,将连续的空间转换为离散的像素可以更容易地应用在屏幕上。
  • 数值稳定性:在数值计算中,对于某些变换或计算,连续的数据可能会导致数值不稳定性或数值误差的累积。离散化可以一定程度上减少这些问题的影响,因为它将连续的数据分割成离散的部分,降低了计算过程中的复杂性。
  • 采样与量化:在信号处理中,离散化处理是常见的,因为信号通常以离散的形式进行采样和量化。这种处理使得信号可以更容易地存储、传输和处理。
python 复制代码
        roi_mask = get_rotated_roi((B * L, L, 1, H, W),
                                   pairwise_t_matrix.reshape(B * L * L, 2, 3))
        roi_mask = roi_mask.reshape(B, L, L, 1, H, W)
        batch_node_features = split_x

根据变换矩阵获取 ROI(感兴趣区域)掩码,并将输入数据划分成子数据集。

python 复制代码
        for l in range(self.num_iteration):

循环执行迭代次数。

python 复制代码
            batch_updated_node_features = []

初始化用于存储更新后节点特征的列表。

python 复制代码
            for b in range(B):

遍历每个批次。

python 复制代码
                N = record_len[b]
                t_matrix = pairwise_t_matrix[b][:N, :N, :, :]
                updated_node_features = []

获取当前批次中记录的数量,以及对应的配对变换矩阵。

python 复制代码
                for i in range(N):

遍历每个节点。

python 复制代码
                    mask = roi_mask[b, :N, i, ...]
                    current_t_matrix = t_matrix[:, i, :, :]
                    current_t_matrix = get_transformation_matrix(
                        current_t_matrix, (H, W))

获取当前节点的掩码和变换矩阵。

python 复制代码
                    neighbor_feature = warp_affine(batch_node_features[b],
                                                   current_t_matrix,
                                                   (H, W))
                    ego_agent_feature = batch_node_features[b][i].unsqueeze(
                        0).repeat(N, 1, 1, 1)
                    neighbor_feature = torch.cat(
                        [neighbor_feature, ego_agent_feature], dim=1)
                    message = self.msg_cnn(neighbor_feature) * mask

获取邻居特征并进行拼接,并通过消息卷积层计算消息。

python 复制代码
                    if self.agg_operator=="avg":
                        agg_feature = torch.mean(message, dim=0)
                    elif self.agg_operator=="max":
                        agg_feature = torch.max(message, dim=0)[0]
                    else:
                        raise ValueError("agg_operator has wrong value")

根据聚合操作符计算聚合特征。

python 复制代码
                    cat_feature = torch.cat(
                        [batch_node_features[b][i, ...], agg_feature], dim=0)

将原始特征和聚合特征进行拼接。

python 复制代码
                    if self.gru_flag:
                        gru_out = \
                            self.conv_gru(cat_feature.unsqueeze(0).unsqueeze(0))[
                                0][
                                0].squeeze(0).squeeze(0)
                    else:
                        gru_out = batch_node_features[b][i, ...] + agg_feature

使用 GRU 网络更新特征。

python 复制代码
                    updated_node_features.append(gru_out.unsqueeze(0))
                batch_updated_node_features.append(
                    torch.cat(updated_node_features, dim=0))

将更新后的特征添加到列表中。

python 复制代码
            batch_node_features = batch_updated_node_features

更新批次中的节点特征。

python 复制代码
        out = torch.cat(
            [itm[0, ...].unsqueeze(0) for itm in batch_node_features], dim=0)

将更新后的特征进行拼接。

python 复制代码
        out = self.mlp(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

通过全连接层对特征进行处理。

python 复制代码
        return out

返回处理后的特征。

V2VNetFusion

python 复制代码
# -*- coding: utf-8 -*-
# Author: Hao Xiang <haxiang@g.ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib


"""
Implementation of V2VNet Fusion
"""

import torch
import torch.nn as nn

from opencood.models.sub_modules.torch_transformation_utils import \
    get_discretized_transformation_matrix, get_transformation_matrix, \
    warp_affine, get_rotated_roi
from opencood.models.sub_modules.convgru import ConvGRU


class V2VNetFusion(nn.Module):
    """V2VNet fusion module.

    Iteratively exchanges spatially-warped feature messages between all
    connected agents in a sample and updates each agent's node feature
    (with a ConvGRU when ``gru_flag`` is set, otherwise by residual
    addition); finally projects the ego node feature with a linear layer.
    """

    def __init__(self, args):
        super(V2VNetFusion, self).__init__()

        in_channels = args['in_channels']
        H, W = args['conv_gru']['H'], args['conv_gru']['W']
        kernel_size = args['conv_gru']['kernel_size']
        num_gru_layers = args['conv_gru']['num_layers']

        # voxel size along x, used to discretize continuous transforms
        self.discrete_ratio = args['voxel_size'][0]
        self.downsample_rate = args['downsample_rate']
        self.num_iteration = args['num_iteration']
        self.gru_flag = args['gru_flag']
        self.agg_operator = args['agg_operator']

        # fuses concatenated [neighbor, ego] (2C channels) into a C-channel
        # message
        self.msg_cnn = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3,
                                 stride=1, padding=1)
        self.conv_gru = ConvGRU(input_size=(H, W),
                                input_dim=in_channels * 2,
                                hidden_dim=[in_channels],
                                kernel_size=kernel_size,
                                num_layers=num_gru_layers,
                                batch_first=True,
                                bias=True,
                                return_all_layers=False)
        self.mlp = nn.Linear(in_channels, in_channels)

    def regroup(self, x, record_len):
        """Split x along dim 0 into per-sample chunks whose sizes are
        given by record_len."""
        cum_sum_len = torch.cumsum(record_len, dim=0)
        split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
        return split_x

    def forward(self, x, record_len, pairwise_t_matrix):
        """
        Fusion forwarding.

        Parameters
        ----------
        x : torch.Tensor
            Stacked features of all agents in the batch,
            shape (sum(record_len), C, H, W).

        record_len : list
            shape: (B); number of connected agents per sample.

        pairwise_t_matrix : torch.Tensor
            The transformation matrix from each cav to ego,
            shape: (B, L, L, 4, 4)

        Returns
        -------
        torch.Tensor
            Fused ego feature, shape (B, C, H, W).
        """
        _, C, H, W = x.shape
        B, L = pairwise_t_matrix.shape[:2]

        # split x:[(L1, C, H, W), (L2, C, H, W)]
        split_x = self.regroup(x, record_len)
        # (B,L,L,2,3) — discretized 2D affine transforms in pixel units
        pairwise_t_matrix = get_discretized_transformation_matrix(
            pairwise_t_matrix.reshape(-1, L, 4, 4), self.discrete_ratio,
            self.downsample_rate).reshape(B, L, L, 2, 3)
        # (B*L,L,1,H,W) — validity mask of each warped field of view
        roi_mask = get_rotated_roi((B * L, L, 1, H, W),
                                   pairwise_t_matrix.reshape(B * L * L, 2, 3))
        roi_mask = roi_mask.reshape(B, L, L, 1, H, W)
        batch_node_features = split_x

        # iteratively update the features for num_iteration times
        for l in range(self.num_iteration):

            batch_updated_node_features = []
            # iterate each batch
            for b in range(B):

                # number of valid agent
                N = record_len[b]
                # (N,N,4,4)
                # t_matrix[i, j]-> from i to j
                t_matrix = pairwise_t_matrix[b][:N, :N, :, :]
                updated_node_features = []
                # update each node i
                for i in range(N):
                    # (N,1,H,W)
                    mask = roi_mask[b, :N, i, ...]

                    current_t_matrix = t_matrix[:, i, :, :]
                    current_t_matrix = get_transformation_matrix(
                        current_t_matrix, (H, W))

                    # (N,C,H,W) — all neighbors warped into agent i's frame
                    neighbor_feature = warp_affine(batch_node_features[b],
                                                   current_t_matrix,
                                                   (H, W))
                    # (N,C,H,W) — agent i's own feature, repeated per neighbor
                    ego_agent_feature = batch_node_features[b][i].unsqueeze(
                        0).repeat(N, 1, 1, 1)
                    #(N,2C,H,W)
                    neighbor_feature = torch.cat(
                        [neighbor_feature, ego_agent_feature], dim=1)
                    # (N,C,H,W) — mask zeroes out-of-view regions
                    message = self.msg_cnn(neighbor_feature) * mask

                    # (C,H,W) — aggregate messages over the N senders
                    if self.agg_operator=="avg":
                        agg_feature = torch.mean(message, dim=0)
                    elif self.agg_operator=="max":
                        agg_feature = torch.max(message, dim=0)[0]
                    else:
                        raise ValueError("agg_operator has wrong value")
                    # (2C, H, W)
                    cat_feature = torch.cat(
                        [batch_node_features[b][i, ...], agg_feature], dim=0)
                    # (C,H,W) — ConvGRU expects (b, t, c, h, w), hence the
                    # two unsqueeze/squeeze pairs around the call
                    if self.gru_flag:
                        gru_out = \
                            self.conv_gru(cat_feature.unsqueeze(0).unsqueeze(0))[
                                0][
                                0].squeeze(0).squeeze(0)
                    else:
                        gru_out = batch_node_features[b][i, ...] + agg_feature
                    updated_node_features.append(gru_out.unsqueeze(0))
                # (N,C,H,W)
                batch_updated_node_features.append(
                    torch.cat(updated_node_features, dim=0))
            batch_node_features = batch_updated_node_features
        # (B,C,H,W) — keep only node 0 (the ego agent) of each sample
        out = torch.cat(
            [itm[0, ...].unsqueeze(0) for itm in batch_node_features], dim=0)
        # (B,C,H,W) — per-pixel linear projection over the channel dim
        out = self.mlp(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        return out

以下是对上述 V2VNetFusion 代码的逐行解析:

python 复制代码
class V2VNetFusion(nn.Module):
    def __init__(self, args):
        super(V2VNetFusion, self).__init__()

定义了一个名为 V2VNetFusion 的类,它继承自 nn.Module,表示它是一个 PyTorch 模型

python 复制代码
    in_channels = args['in_channels']
    H, W = args['conv_gru']['H'], args['conv_gru']['W']
    kernel_size = args['conv_gru']['kernel_size']
    num_gru_layers = args['conv_gru']['num_layers']

    self.discrete_ratio = args['voxel_size'][0]
    self.downsample_rate = args['downsample_rate']
    self.num_iteration = args['num_iteration']
    self.gru_flag = args['gru_flag']
    self.agg_operator = args['agg_operator']

    self.msg_cnn = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3,
                             stride=1, padding=1)
    self.conv_gru = ConvGRU(input_size=(H, W),
                            input_dim=in_channels * 2,
                            hidden_dim=[in_channels],
                            kernel_size=kernel_size,
                            num_layers=num_gru_layers,
                            batch_first=True,
                            bias=True,
                            return_all_layers=False)
    self.mlp = nn.Linear(in_channels, in_channels)

定义了模型的子模块,包括消息卷积层 msg_cnn、ConvGRU 层 conv_gru 和全连接层 mlp。

    def regroup(self, x, record_len):
        cum_sum_len = torch.cumsum(record_len, dim=0)
        split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu())
        return split_x
定义了一个辅助函数 regroup,用于将输入数据 x 按照记录长度 record_len 进行划分。

def forward(self, x, record_len, pairwise_t_matrix):
    ...

定义了模型的前向传播函数 forward,接受输入数据 x、记录长度 record_len 和从每个 cav 到 ego 的配对变换矩阵 pairwise_t_matrix。

        _, C, H, W = x.shape
        B, L = pairwise_t_matrix.shape[:2]
获取输入数据 x 的形状,并获取配对变换矩阵的大小。
python 复制代码
        split_x = self.regroup(x, record_len)
将输入数据 x 按记录长度划分成多个子数据集。

    pairwise_t_matrix = get_discretized_transformation_matrix(
        pairwise_t_matrix.reshape(-1, L, 4, 4), self.discrete_ratio,
        self.downsample_rate).reshape(B, L, L, 2, 3)

将配对变换矩阵进行离散化处理,得到变换后的矩阵。

        roi_mask = get_rotated_roi((B * L, L, 1, H, W),
                                   pairwise_t_matrix.reshape(B * L * L, 2, 3))
        roi_mask = roi_mask.reshape(B, L, L, 1, H, W)
        batch_node_features = split_x
根据变换矩阵获取 ROI(感兴趣区域)掩码,并将输入数据划分成子数据集。

    for l in range(self.num_iteration):

循环执行迭代次数。

            batch_updated_node_features = []
            # iterate each batch
            for b in range(B):

                # number of valid agent
                N = record_len[b]
                t_matrix = pairwise_t_matrix[b][:N, :N, :, :]
                updated_node_features = []
获取当前批次中记录的数量,以及对应的配对变换矩阵。

            for i in range(N):
                mask = roi_mask[b, :N, i, ...]
                current_t_matrix = t_matrix[:, i, :, :]
                current_t_matrix = get_transformation_matrix(
                    current_t_matrix, (H, W))

获取当前节点的掩码和变换矩阵。

                    neighbor_feature = warp_affine(batch_node_features[b],
                                                   current_t_matrix,
                                                   (H, W))
                    ego_agent_feature = batch_node_features[b][i].unsqueeze(
                        0).repeat(N, 1, 1, 1)
                    neighbor_feature = torch.cat(
                        [neighbor_feature, ego_agent_feature], dim=1)
                    message = self.msg_cnn(neighbor_feature) * mask
获取邻居特征并进行拼接,并通过消息卷积层计算消息。

                if self.agg_operator=="avg":
                    agg_feature = torch.mean(message, dim=0)
                elif self.agg_operator=="max":
                    agg_feature = torch.max(message, dim=0)[0]
                else:
                    raise ValueError("agg_operator has wrong value")

根据聚合操作符计算聚合特征。

                    cat_feature = torch.cat(
                        [batch_node_features[b][i, ...], agg_feature], dim=0)
将原始特征和聚合特征进行拼接。

                if self.gru_flag:
                    gru_out = \
                        self.conv_gru(cat_feature.unsqueeze(0).unsqueeze(0))[
                            0][
                            0].squeeze(0).squeeze(0)
                else:
                    gru_out = batch_node_features[b][i, ...] + agg_feature
                updated_node_features.append(gru_out.unsqueeze(0))

使用 GRU 网络更新特征。

                batch_updated_node_features.append(
                    torch.cat(updated_node_features, dim=0))
            batch_node_features = batch_updated_node_features
更新批次中的节点特征。

    out = torch.cat(
        [itm[0, ...].unsqueeze(0) for itm in batch_node_features], dim=0)

将更新后的特征进行拼接。

        out = self.mlp(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
通过全连接层对特征进行处理。

    return out

返回处理后的特征。

ConvGRU

ConvGRU(卷积门控循环单元)是一种结合了卷积神经网络(CNN)和循环神经网络(RNN)思想的模型,用于处理时序数据。它类似于传统的循环神经网络,但具有更好地捕捉空间信息的能力。

ConvGRU的核心是ConvGRUCell,它是ConvGRU的基本单元,类似于LSTM中的LSTMCell。与LSTM类似,GRU也有更新门(update gate)和重置门(reset gate)。这些门控制着隐藏状态中的信息流动,使得模型能够根据当前输入和前一个隐藏状态选择性地更新或保留信息。更新门决定了多少旧信息应该被保留,而重置门决定了多少旧信息应该被忽略。这样,GRU能够在处理序列数据时更好地捕捉长期依赖关系。卷积操作用于捕捉输入数据中的空间结构。在ConvGRUCell中,卷积操作能够更好地处理图像等具有空间关系的数据。

在每个时间步,ConvGRUCell接收当前输入和前一个隐藏状态作为输入。

  • 首先,通过卷积操作处理当前输入,以捕获输入数据的空间结构。
  • 然后,利用GRU的门控机制,根据当前输入和前一个隐藏状态生成更新门和重置门,从而控制信息的流动。
  • 最后,根据门控制的信息流动,结合当前输入和前一个隐藏状态,生成新的隐藏状态。
python 复制代码
class ConvGRU(nn.Module):
    """Multi-layer convolutional GRU over a sequence of 2D feature maps.

    Stacks ConvGRUCell layers; each layer consumes the full output
    sequence of the previous one.
    """

    def __init__(self, input_size, input_dim, hidden_dim, kernel_size,
                 num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        """
        :param input_size: (int, int)
            Height and width of input tensor as (height, width).
        :param input_dim: int e.g. 256
            Number of channels of input tensor.
        :param hidden_dim: int e.g. 1024
            Number of channels of hidden state.
        :param kernel_size: (int, int)
            Size of the convolutional kernel.
        :param num_layers: int
            Number of ConvGRU layers
        :param batch_first: bool
            if the first position of array is batch or not
        :param bias: bool
            Whether or not to add the bias.
        :param return_all_layers: bool
            if return hidden and cell states for all layers
        """
        super(ConvGRU, self).__init__()

        # Make sure that both `kernel_size` and
        # `hidden_dim` are lists having len == num_layers
        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.height, self.width = input_size
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        cell_list = []
        for i in range(0, self.num_layers):
            # layer 0 sees the raw input; deeper layers see the previous
            # layer's hidden state
            cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]
            cell_list.append(ConvGRUCell(input_size=(self.height, self.width),
                                         input_dim=cur_input_dim,
                                         hidden_dim=self.hidden_dim[i],
                                         kernel_size=self.kernel_size[i],
                                         bias=self.bias))

        # convert python list to pytorch module
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input_tensor, hidden_state=None):
        """
        :param input_tensor: (b, t, c, h, w) or (t,b,c,h,w)
            depends on if batch first or not
        :param hidden_state: must be None; stateful use is not implemented
        :return: layer_output_list, last_state_list
            output sequences and final hidden states (last layer only
            unless return_all_layers is set)
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w)
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        # Stateful operation (caller-provided hidden state) is not supported
        if hidden_state is not None:
            raise NotImplementedError()
        else:
            hidden_state = self._init_hidden(batch_size=input_tensor.size(0),
                                             device=input_tensor.device,
                                             dtype=input_tensor.dtype)

        layer_output_list = []
        last_state_list = []

        seq_len = input_tensor.size(1)
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):
            h = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                # input current hidden state, then compute the next hidden
                # state through the ConvGRUCell forward function
                h = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :],  # (b,t,c,h,w)
                    h_cur=h)
                output_inner.append(h)

            # stack per-step hiddens back into a (b, t, c, h, w) sequence
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append([h])

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(self, batch_size, device=None, dtype=None):
        """Build zero-initialized hidden states for every layer, moved to
        the input's device and dtype."""
        init_states = []
        for i in range(self.num_layers):
            init_states.append(
                self.cell_list[i].init_hidden(batch_size).to(device).to(dtype))
        return init_states

    @staticmethod
    def _check_kernel_size_consistency(kernel_size):
        """Validate that kernel_size is a tuple or a list of tuples."""
        if not (isinstance(kernel_size, tuple) or
                (isinstance(kernel_size, list) and all(
                    [isinstance(elem, tuple) for elem in kernel_size]))):
            raise ValueError('`kernel_size` must be tuple or list of tuples')

    @staticmethod
    def _extend_for_multilayer(param, num_layers):
        """Replicate a scalar param into a per-layer list when needed."""
        if not isinstance(param, list):
            param = [param] * num_layers
        return param

上述代码定义了一个名为 ConvGRU 的类,用于实现卷积门控循环单元(Convolutional Gated Recurrent Unit,ConvGRU)的模型。以下是对代码的逐行解析:

class ConvGRU(nn.Module):
    def __init__(self, input_size, input_dim, hidden_dim, kernel_size,
                 num_layers,
                 batch_first=False, bias=True, return_all_layers=False):
        super(ConvGRU, self).__init__()

定义了 ConvGRU 类,继承自 nn.Module。构造函数中接收了输入尺寸 input_size、输入通道数 input_dim、隐藏状态通道数 hidden_dim、卷积核大小 kernel_size、层数 num_layers 等参数。

        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
        if not len(kernel_size) == len(hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')
对 kernel_size 和 hidden_dim 进行扩展,确保它们的长度与层数一致。

    self.height, self.width = input_size
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.kernel_size = kernel_size
    self.num_layers = num_layers
    self.batch_first = batch_first
    self.bias = bias
    self.return_all_layers = return_all_layers

初始化了类中的一些属性,包括输入的高度和宽度、输入通道数、隐藏状态通道数、卷积核大小、层数、是否以 batch 为第一维、是否使用偏置、是否返回所有层的状态。

        cell_list = []
        for i in range(0, self.num_layers):
            cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]
            cell_list.append(ConvGRUCell(input_size=(self.height, self.width),
                                         input_dim=cur_input_dim,
                                         hidden_dim=self.hidden_dim[i],
                                         kernel_size=self.kernel_size[i],
                                         bias=self.bias))
创建了一个列表 cell_list,用于存储每一层的 ConvGRUCell 实例。

    self.cell_list = nn.ModuleList(cell_list)

将 cell_list 转换为 nn.ModuleList 类型,以确保每个 ConvGRUCell 实例都被正确注册为模型的子模块。

    def forward(self, input_tensor, hidden_state=None):
        if not self.batch_first:
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
如果不是以 batch 为第一维,需要对输入数据进行转置,将 batch 维度置于第一维。

    if hidden_state is not None:
        raise NotImplementedError()
    else:
        hidden_state = self._init_hidden(batch_size=input_tensor.size(0),
                                         device=input_tensor.device,
                                         dtype=input_tensor.dtype)

如果传入了隐藏状态,则抛出 NotImplementedError;否则,根据输入数据的大小和设备信息初始化隐藏状态。

        layer_output_list = []
        last_state_list = []
定义了用于存储每一层输出和最终状态的列表。

    seq_len = input_tensor.size(1)
    cur_layer_input = input_tensor

获取输入序列的长度,并初始化当前层的输入。

        for layer_idx in range(self.num_layers):
            h = hidden_state[layer_idx]
            output_inner = []
            for t in range(seq_len):
                h = self.cell_list[layer_idx](
                    input_tensor=cur_layer_input[:, t, :, :, :],
                    h_cur=h)
                output_inner.append(h)
遍历每一层,对于每一时刻,利用当前层的 ConvGRUCell 实例更新隐藏状态,并将更新后的隐藏状态存入列表中。

        layer_output = torch.stack(output_inner, dim=1)
        cur_layer_input = layer_output

        layer_output_list.append(layer_output)
        last_state_list.append([h])

将每一层的输出存入列表 layer_output_list 中,并将最后一个时刻的隐藏状态存入 last_state_list 中。

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]
如果设置了不返回所有层的状态,则只保留最后一层的状态。

    return layer_output_list, last_state_list

返回所有层的输出和最终状态。

    def _init_hidden(self, batch_size, device=None, dtype=None):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(
                self.cell_list[i].init_hidden(batch_size).to(device).to(dtype))
        return init_states
辅助方法,用于初始化隐藏状态。

@staticmethod
def _extend_for_multilayer(param, num_layers):
    if not isinstance(param, list):
        param = [param] * num_layers
    return param

辅助方法,用于将单个参数扩展成多层的形式。

相关推荐
YoseZang4 分钟前
【机器学习和深度学习】分类问题通用评价指标:精确率、召回率、准确率和混淆矩阵
深度学习·机器学习·分类算法
微臣愚钝4 分钟前
《Generative Adversarial Nets》-GAN:生成对抗网络,一场伪造者与鉴定师的终极博弈
人工智能·深度学习
木卯9 分钟前
5种创建型设计模式笔记(Python实现)
python·设计模式
IT古董9 分钟前
【漫话机器学习系列】128.预处理之训练集与测试集(Preprocessing Traning And Test Sets)
深度学习·机器学习·自然语言处理
JokerSZ.21 分钟前
复现:latent diffusion(LDM)stable diffusion
人工智能·深度学习·stable diffusion·生成模型
T0uken24 分钟前
【深度学习】Pytorch:更换激活函数
人工智能·pytorch·深度学习
张琪杭25 分钟前
pytorch tensor创建tensor
人工智能·pytorch·python
山西茄子31 分钟前
DeepStream推理dewarped所有surfaces
人工智能·深度学习·计算机视觉·deepstream
星星点点洲35 分钟前
【RAG】RAG 系统的基本搭建流程(ES关键词检索示例)
python·elasticsearch
带娃的IT创业者1 小时前
《Python实战进阶》No18: 使用 Apache Spark 进行分布式计算
python·spark·apache