【Python/Pytorch - 网络模型】-- 手把手搭建E3D LSTM网络


文章目录

  • [00 写在前面](#00-写在前面)
  • [01 基于Pytorch版本的E3D LSTM代码](#01-基于pytorch版本的e3d-lstm代码)
  • [02 论文下载](#02-论文下载)

00 写在前面

文末的测试代码比较重要:通过它可以大致观察 tensor 在网络前向传播过程中各个维度的变化情况,方便把模型改成适合自己数据集的形式。

需要github上的数据集以及可运行的代码,可以私聊!

01 基于Pytorch版本的E3D LSTM代码

python 复制代码
# 库函数调用
from functools import reduce
from src.utils import nice_print, mem_report, cpu_stats
import copy
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F

# E3DLSTM模型代码
class E3DLSTM(nn.Module):
    """Stack of ``E3DLSTMCell`` layers unrolled over an input sequence.

    Args:
        input_shape: Per-sample shape of one sequence element; index 0 is the
            channel count fed to the first layer.
        hidden_size: Channel count of every layer's state tensors.
        num_layers: Number of stacked cells.
        kernel_size: Kernel size handed to every cell.
        tau: Length of the cell-state history each cell is initialised with.
    """

    def __init__(self, input_shape, hidden_size, num_layers, kernel_size, tau):
        super().__init__()

        self._tau = tau
        self._cells = []

        # One list object is shared by all cells and updated in place below.
        layer_shape = list(input_shape)
        for layer_idx in range(num_layers):
            layer = E3DLSTMCell(layer_shape, hidden_size, kernel_size)
            # Layers above the first receive a hidden state as their input,
            # so their input channel count is hidden_size.
            layer_shape[0] = hidden_size
            # Assigning the cell as an attribute registers it as a submodule;
            # the plain list alone would not.
            setattr(self, f"cell{layer_idx}", layer)
            self._cells.append(layer)

    def forward(self, input):
        """Run the whole sequence through the stack.

        Args:
            input: Tensor iterated along dim 0 (one element per time step);
                dim 1 is read as the batch size.

        Returns:
            The top layer's hidden state of every step, concatenated on dim 1.
        """
        n_batch = input.size(1)
        histories = []  # per-layer cell-state history, carried across steps
        hiddens = []  # per-layer hidden state, carried across steps
        top_hiddens = []

        for t, layer_input in enumerate(input):
            for idx, layer in enumerate(self._cells):
                if t == 0:
                    # NOTE(review): this also resets ``m`` for every layer at
                    # the first step, discarding the ``m`` produced by the
                    # layer below -- confirm this is intended.
                    initial_history, m, initial_hidden = layer.init_hidden(
                        n_batch, self._tau, input.device
                    )
                    histories.append(initial_history)
                    hiddens.append(initial_hidden)

                # History and hidden state come from this layer's previous
                # step; ``m`` comes from whichever cell ran last.
                history, m, hidden = layer(
                    layer_input, histories[idx], m, hiddens[idx]
                )
                histories[idx] = history
                hiddens[idx] = hidden
                # The hidden state feeds the next layer up.
                layer_input = hidden

            top_hiddens.append(hidden)

        # Concatenate the per-step outputs along the channel dimension.
        return torch.cat(top_hiddens, dim=1)


class E3DLSTMCell(nn.Module):
    """Single Eidetic 3D LSTM cell.

    The cell works on three state tensors (created by ``init_hidden``):

    * ``c_history`` -- the last ``tau`` cell states stacked on dim 0, newest
      last: ``(tau, batch, hidden_size, *input_shape[1:])``.
    * ``m`` -- spatiotemporal memory: ``(batch, hidden_size, *input_shape[1:])``.
    * ``h`` -- hidden state, same shape as ``m``.

    Args:
        input_shape: Per-sample input shape; index 0 is the input channel
            count, the remaining entries are reused for the state shapes.
        hidden_size: Channel count of all state tensors and gate outputs.
        kernel_size: Kernel size forwarded to every ``ConvDeconv3d``.
    """

    def __init__(self, input_shape, hidden_size, kernel_size):
        super().__init__()

        in_channels = input_shape[0]
        self._input_shape = input_shape
        self._hidden_size = hidden_size

        # Input gate. These two layers are also the templates for all other
        # gates: input->hidden carries a bias, state->hidden does not.
        # Every ``deepcopy`` below therefore starts from the same initial
        # weights as its template.
        self.weight_xi = ConvDeconv3d(in_channels, hidden_size, kernel_size)
        self.weight_hi = ConvDeconv3d(hidden_size, hidden_size, kernel_size, bias=False)

        # Input modulation gate.
        self.weight_xg = copy.deepcopy(self.weight_xi)
        self.weight_hg = copy.deepcopy(self.weight_hi)

        # Recall gate.
        self.weight_xr = copy.deepcopy(self.weight_xi)
        self.weight_hr = copy.deepcopy(self.weight_hi)

        # A single group normalises each sample over all channels and
        # positions, independent of the spatial size.
        self.group_norm = nn.GroupNorm(1, hidden_size)

        # Gates of the spatiotemporal memory ``m``.
        self.weight_xi_prime = copy.deepcopy(self.weight_xi)
        self.weight_mi_prime = copy.deepcopy(self.weight_hi)

        self.weight_xg_prime = copy.deepcopy(self.weight_xi)
        self.weight_mg_prime = copy.deepcopy(self.weight_hi)

        self.weight_xf_prime = copy.deepcopy(self.weight_xi)
        self.weight_mf_prime = copy.deepcopy(self.weight_hi)

        # Output gate.
        self.weight_xo = copy.deepcopy(self.weight_xi)
        self.weight_ho = copy.deepcopy(self.weight_hi)
        self.weight_co = copy.deepcopy(self.weight_hi)
        self.weight_mo = copy.deepcopy(self.weight_hi)

        # 1x1x1 convolution fusing the concatenated (c, m) back to hidden_size.
        self.weight_111 = nn.Conv3d(hidden_size + hidden_size, hidden_size, 1)

    def self_attention(self, r, c_history):
        """Position-wise attention of ``r`` over ``c_history``.

        Not called by ``forward``, which uses ``self_attention_fast``.

        NOTE(review): ``view`` only reinterprets the existing memory layout;
        it does not move the channel axis last, and ``c_history`` as built by
        ``init_hidden`` has ``tau`` (not batch) on dim 0. Verify the layout
        before using this method.
        """
        batch_size = r.size(0)
        channels = r.size(1)
        r_flatten = r.view(batch_size, -1, channels)
        c_history_flatten = c_history.view(batch_size, -1, channels)

        # (B, X, C) x (B, Y, C) -> (B, X, Y), normalised over Y.
        scores = torch.einsum("bxc,byc->bxy", r_flatten, c_history_flatten)
        attention = F.softmax(scores, dim=2)

        return torch.einsum("bxy,byc->bxc", attention, c_history_flatten).view(*r.shape)

    def self_attention_fast(self, r, c_history):
        """Recall from the cell-state history with one weight per history slot.

        Args:
            r: Recall gate, ``(batch, C, T, W, H)``.
            c_history: History of cell states, ``(tau, batch, C, T, W, H)``.

        Returns:
            Weighted sum of the ``tau`` history entries, same shape as ``r``.
        """
        # Scaled dot-product where the "dot product" contracts whole tensors.
        scaling_factor = 1 / (reduce(operator.mul, r.shape[-3:], 1) ** 0.5)
        scores = torch.einsum("bctwh,lbctwh->bl", r, c_history) * scaling_factor

        # Normalise over the history axis ``l`` so every sample gets its own
        # weights summing to 1; dim=0 would normalise across the batch and
        # couple unrelated samples.
        attention = F.softmax(scores, dim=1)
        return torch.einsum("bl,lbctwh->bctwh", attention, c_history)

    def forward(self, x, c_history, m, h):
        """Advance the cell by one step.

        Args:
            x: Input for this step, ``(batch, in_channels, ...)``.
            c_history: Cell-state history, newest entry last on dim 0.
            m: Spatiotemporal memory.
            h: Hidden state from this cell's previous step.

        Returns:
            Tuple ``(c_history, m, h)`` with the updated states; the returned
            history has the same length as the one passed in.
        """
        # ``nn.GroupNorm.forward`` takes the input only; the group count and
        # channel count were fixed at construction time.
        norm = self.group_norm

        r = torch.sigmoid(norm(self.weight_xr(x) + self.weight_hr(h)))
        i = torch.sigmoid(norm(self.weight_xi(x) + self.weight_hi(h)))
        g = torch.tanh(norm(self.weight_xg(x) + self.weight_hg(h)))

        recall = self.self_attention_fast(r, c_history)

        # New cell state: gated input plus the normalised sum of the most
        # recent cell state and what was recalled from the history.
        c = i * g + norm(c_history[-1] + recall)

        i_prime = torch.sigmoid(norm(self.weight_xi_prime(x) + self.weight_mi_prime(m)))
        g_prime = torch.tanh(norm(self.weight_xg_prime(x) + self.weight_mg_prime(m)))
        f_prime = torch.sigmoid(norm(self.weight_xf_prime(x) + self.weight_mf_prime(m)))

        m = i_prime * g_prime + f_prime * m
        o = torch.sigmoid(
            norm(
                self.weight_xo(x)
                + self.weight_ho(h)
                + self.weight_co(c)
                + self.weight_mo(m)
            )
        )
        h = o * torch.tanh(self.weight_111(torch.cat([c, m], dim=1)))

        # FIFO update: drop the oldest entry (index 0), append the new cell
        # state at the end, so ``c_history[-1]`` is always the latest state.
        c_history = torch.cat([c_history[1:], c[None, :]], dim=0)

        return (c_history, m, h)

    def init_hidden(self, batch_size, tau, device=None):
        """Create all-zero initial states.

        Args:
            batch_size: Number of samples per batch.
            tau: Number of history slots in ``c_history``.
            device: Device to allocate on (``None`` uses the torch default).

        Returns:
            Tuple ``(c_history, m, h)`` of zero tensors.
        """
        memory_shape = list(self._input_shape)
        # States have hidden_size channels regardless of the input channels.
        memory_shape[0] = self._hidden_size
        c_history = torch.zeros(tau, batch_size, *memory_shape, device=device)
        m = torch.zeros(batch_size, *memory_shape, device=device)
        h = torch.zeros(batch_size, *memory_shape, device=device)

        return (c_history, m, h)


class ConvDeconv3d(nn.Module):
    """3D convolution whose result is resized back to the input's last three dims.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output.
        *vargs: Extra positional arguments forwarded to ``nn.Conv3d``.
        **kwargs: Extra keyword arguments forwarded to ``nn.Conv3d``.
    """

    def __init__(self, in_channels, out_channels, *vargs, **kwargs):
        super().__init__()

        self.conv3d = nn.Conv3d(in_channels, out_channels, *vargs, **kwargs)

    def forward(self, input):
        """Convolve ``input`` and restore its original last three dimensions."""
        target_size = input.shape[-3:]
        convolved = self.conv3d(input)
        # Nearest-neighbour resize to the size the input had before the
        # convolution.
        return F.interpolate(convolved, size=target_size, mode="nearest")

class Out(nn.Module):
    """Output head: one 3x3x3 convolution with stride 1 and padding 1.

    Args:
        in_channels: Channel count of the incoming tensor.
        out_channels: Channel count of the result.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        projected = self.conv(x)
        return projected

class E3DLSTM_NET(nn.Module):
    """E3D-LSTM encoder followed by a convolutional decoder and output head.

    Args:
        input_shape: Per-sample shape of one sequence element, passed to the
            encoder.
        hidden_size: Hidden channel count of the encoder.
        num_layers: Number of stacked encoder cells.
        kernel_size: Kernel size shared by the encoder cells and the decoder.
        tau: History length passed to the encoder.
        time_steps: Sequence length; the decoder expects
            ``hidden_size * time_steps`` input channels, so the input
            sequence must have exactly this many steps.
        output_shape: ``output_shape[0]`` is the decoder's output channel
            count.
    """

    def __init__(self, input_shape, hidden_size, num_layers, kernel_size, tau, time_steps, output_shape):
        super().__init__()

        self.input_shape = input_shape
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.tau = tau
        self.time_steps = time_steps
        self.output_shape = output_shape
        self.dtype = torch.float32

        self.encoder = E3DLSTM(
            input_shape, hidden_size, num_layers, kernel_size, tau
        ).type(self.dtype)
        # The encoder concatenates one hidden state per step along the
        # channels, hence hidden_size * time_steps input channels.
        # NOTE(review): padding=(0, 2, 2) is fixed regardless of kernel_size;
        # it only preserves height/width for a 5x5 spatial kernel.
        self.decoder = nn.Conv3d(
            hidden_size * time_steps, output_shape[0], kernel_size, padding=(0, 2, 2)
        ).type(self.dtype)
        # The head consumes the decoder output, so its input channel count
        # must equal the decoder's output channels (previously hard-coded
        # to 4, which only worked when output_shape[0] == 4).
        self.out = Out(output_shape[0], 1)

    def forward(self, input_seq):
        """Encode the sequence, decode it, and apply the output head."""
        return self.out(self.decoder(self.encoder(input_seq)))

# 测试代码
if __name__ == '__main__':
    # Demo configuration: build the network and push one all-ones batch
    # through it.
    tau, hidden_size, lstm_layers, time_steps = 2, 64, 4, 29
    kernel = (3, 5, 5)
    input_shape = (16, 4, 16, 16)
    output_shape = (16, 1, 16, 16)

    model = E3DLSTM_NET(
        input_shape=input_shape,
        hidden_size=hidden_size,
        num_layers=lstm_layers,
        kernel_size=kernel,
        tau=tau,
        time_steps=time_steps,
        output_shape=output_shape,
    )
    # Leading dims are (sequence length, batch); the rest is input_shape.
    x = torch.ones([time_steps, 2, *input_shape])
    print('finished!')
    f = model(x)
    print(f)

02 论文下载

论文:Eidetic 3D LSTM: A Model for Video Prediction and Beyond

Github链接:e3d_lstm

相关推荐
肥猪猪爸5 分钟前
使用卡尔曼滤波器估计pybullet中的机器人位置
数据结构·人工智能·python·算法·机器人·卡尔曼滤波·pybullet
LZXCyrus34 分钟前
【杂记】vLLM如何指定GPU单卡/多卡离线推理
人工智能·经验分享·python·深度学习·语言模型·llm·vllm
Enougme37 分钟前
Appium常用的使用方法(一)
python·appium
懷淰メ43 分钟前
PyQt飞机大战游戏(附下载地址)
开发语言·python·qt·游戏·pyqt·游戏开发·pyqt5
我感觉。1 小时前
【机器学习chp4】特征工程
人工智能·机器学习·主成分分析·特征工程
hummhumm1 小时前
第 22 章 - Go语言 测试与基准测试
java·大数据·开发语言·前端·python·golang·log4j
YRr YRr1 小时前
深度学习神经网络中的优化器的使用
人工智能·深度学习·神经网络
DieYoung_Alive1 小时前
一篇文章了解机器学习(下)
人工智能·机器学习
夏沫的梦1 小时前
生成式AI对产业的影响与冲击
人工智能·aigc
hummhumm1 小时前
第 28 章 - Go语言 Web 开发入门
java·开发语言·前端·python·sql·golang·前端框架