【Python/Pytorch - 网络模型】-- 手把手搭建E3D LSTM网络


文章目录

  • [00 写在前面](#00 写在前面)
  • [01 基于Pytorch版本的E3D LSTM代码](#01 基于Pytorch版本的E3D LSTM代码)
  • [02 论文下载](#02 论文下载)

00 写在前面

文末的测试代码比较重要:通过它可以大致观察 tensor 在网络前向传播过程中各个维度的变化情况,便于把模型修改成适合自己数据集的形式。

需要github上的数据集以及可运行的代码,可以私聊!

01 基于Pytorch版本的E3D LSTM代码

python 复制代码
# 库函数调用
from functools import reduce
from src.utils import nice_print, mem_report, cpu_stats
import copy
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F

# E3DLSTM模型代码
class E3DLSTM(nn.Module):
    """Stack of ``E3DLSTMCell`` layers unrolled over a sequence.

    Args:
        input_shape: Per-step input shape without the batch axis; element 0
            is the channel count of the first layer's input.
        hidden_size: Number of hidden channels of every cell.
        num_layers: Number of stacked cells.
        kernel_size: Convolution kernel size forwarded to every cell.
        tau: Number of past cell states each cell keeps in its history.
    """

    def __init__(self, input_shape, hidden_size, num_layers, kernel_size, tau):
        super().__init__()

        self._tau = tau
        self._cells = []

        # One list object is handed to every cell and edited in place, so
        # after the first layer its channel entry equals ``hidden_size``.
        layer_shape = list(input_shape)
        for layer_idx in range(num_layers):
            layer = E3DLSTMCell(layer_shape, hidden_size, kernel_size)
            self._cells.append(layer)
            # Attribute assignment is what registers the cell as a submodule
            # (``self._cells`` is a plain list).
            setattr(self, f"cell{layer_idx}", layer)
            # The hidden state of this layer is the input of the next one.
            layer_shape[0] = hidden_size

    def forward(self, input):
        """Run the stack over ``input`` of shape (seq_len, batch, *input_shape).

        Returns:
            The last layer's hidden state of every step, concatenated along
            dimension 1.
        """
        n_batch = input.size(1)
        histories = []  # per-layer cell-state history
        hiddens = []  # per-layer hidden state from the previous step
        per_step_outputs = []

        for t, frame in enumerate(input):
            is_first_step = t == 0
            for k, layer in enumerate(self._cells):
                if is_first_step:
                    # Fresh zero states; note that ``m`` is re-initialised for
                    # every layer during the first step.
                    initial_history, m, initial_h = layer.init_hidden(
                        n_batch, self._tau, input.device
                    )
                    histories.append(initial_history)
                    hiddens.append(initial_h)

                # History and hidden state come from the previous time step of
                # the same layer; ``m`` is whatever the previous call produced.
                histories[k], m, h = layer(frame, histories[k], m, hiddens[k])
                hiddens[k] = h
                # Feed this layer's hidden state to the layer above.
                frame = h

            per_step_outputs.append(h)

        # Stack the per-step outputs along the channel axis.
        return torch.cat(per_step_outputs, dim=1)


class E3DLSTMCell(nn.Module):
    """One Eidetic 3D LSTM cell.

    Args:
        input_shape: Shape of the cell input without the batch axis; element 0
            is the number of input channels, the rest are reused for the
            state tensors created by ``init_hidden``.
        hidden_size: Number of channels of the hidden/cell/memory states.
        kernel_size: Kernel size forwarded to every ``ConvDeconv3d`` gate.
    """

    def __init__(self, input_shape, hidden_size, kernel_size):
        super().__init__()

        in_channels = input_shape[0]
        self._input_shape = input_shape
        self._hidden_size = hidden_size

        # ``copy.deepcopy`` gives every gate its own parameters; they start
        # out with the same values as ``weight_xi`` / ``weight_hi``.

        # memory gates: input, cell (input modulation), recall
        self.weight_xi = ConvDeconv3d(in_channels, hidden_size, kernel_size)
        self.weight_hi = ConvDeconv3d(hidden_size, hidden_size, kernel_size, bias=False)

        self.weight_xg = copy.deepcopy(self.weight_xi)
        self.weight_hg = copy.deepcopy(self.weight_hi)

        self.weight_xr = copy.deepcopy(self.weight_xi)
        self.weight_hr = copy.deepcopy(self.weight_hi)

        # A single group spanning all channels; one instance is shared by
        # every normalisation in ``forward``.
        self.group_norm = nn.GroupNorm(1, hidden_size)

        # gates of the spatiotemporal memory ``m``
        self.weight_xi_prime = copy.deepcopy(self.weight_xi)
        self.weight_mi_prime = copy.deepcopy(self.weight_hi)

        self.weight_xg_prime = copy.deepcopy(self.weight_xi)
        self.weight_mg_prime = copy.deepcopy(self.weight_hi)

        self.weight_xf_prime = copy.deepcopy(self.weight_xi)
        self.weight_mf_prime = copy.deepcopy(self.weight_hi)

        # output gate
        self.weight_xo = copy.deepcopy(self.weight_xi)
        self.weight_ho = copy.deepcopy(self.weight_hi)
        self.weight_co = copy.deepcopy(self.weight_hi)
        self.weight_mo = copy.deepcopy(self.weight_hi)

        # 1x1x1 convolution fusing the concatenated [c, m] back to hidden_size
        self.weight_111 = nn.Conv3d(hidden_size + hidden_size, hidden_size, 1)

    def self_attention(self, r, c_history):
        """Position-wise recall of ``c_history`` queried by ``r``.

        Args:
            r: Recall gate of shape (B, C, T, H, W).
            c_history: Past cell states of shape (tau, B, C, T, H, W).

        Returns:
            Tensor with the shape of ``r``.
        """
        batch_size, channels = r.shape[:2]
        # (B, C, T, H, W) -> (B, THW, C). The channel axis has to be moved
        # explicitly; a bare ``view`` would only reinterpret memory and mix
        # channels with spatial positions.
        r_flatten = r.flatten(start_dim=2).transpose(1, 2)
        # (tau, B, C, T, H, W) -> (B, tau, T, H, W, C) -> (B, tau*THW, C)
        c_history_flatten = c_history.permute(1, 0, 3, 4, 5, 2).reshape(
            batch_size, -1, channels
        )

        # (B, THW, C) x (B, tau*THW, C) -> (B, THW, tau*THW)
        scores = torch.einsum("bxc,byc->bxy", r_flatten, c_history_flatten)
        attention = F.softmax(scores, dim=2)

        recalled = torch.einsum("bxy,byc->bxc", attention, c_history_flatten)
        # (B, THW, C) -> (B, C, T, H, W)
        return recalled.transpose(1, 2).reshape(r.shape)

    def self_attention_fast(self, r, c_history):
        """Recall that weighs whole history entries instead of positions.

        Each of the ``tau`` history entries gets one scalar score per sample
        (full contraction over C, T, H, W, scaled by 1/sqrt(T*H*W)).

        Args:
            r: Recall gate of shape (B, C, T, H, W).
            c_history: Past cell states of shape (tau, B, C, T, H, W).

        Returns:
            Tensor with the shape of ``r``.
        """
        scaling_factor = 1 / (reduce(operator.mul, r.shape[-3:], 1) ** 0.5)
        # scores: (B, tau)
        scores = torch.einsum("bctwh,lbctwh->bl", r, c_history) * scaling_factor

        # Normalise over the history axis so every sample gets its own
        # distribution; dim=0 would couple the samples of a batch.
        attention = F.softmax(scores, dim=1)
        return torch.einsum("bl,lbctwh->bctwh", attention, c_history)

    def forward(self, x, c_history, m, h):
        """Advance the cell by one step.

        Args:
            x: Input of this step.
            c_history: Cell-state history, oldest entry first.
            m: Spatiotemporal memory.
            h: Hidden state from the previous step.

        Returns:
            Tuple ``(c_history, m, h)`` with the updated states.
        """
        # nn.GroupNorm takes only the tensor; passing a shape as well raises
        # a TypeError.
        norm = self.group_norm

        r = torch.sigmoid(norm(self.weight_xr(x) + self.weight_hr(h)))
        i = torch.sigmoid(norm(self.weight_xi(x) + self.weight_hi(h)))
        g = torch.tanh(norm(self.weight_xg(x) + self.weight_hg(h)))

        recall = self.self_attention_fast(r, c_history)

        # The newest history entry acts as the previous cell state.
        c = i * g + norm(c_history[-1] + recall)

        i_prime = torch.sigmoid(norm(self.weight_xi_prime(x) + self.weight_mi_prime(m)))
        g_prime = torch.tanh(norm(self.weight_xg_prime(x) + self.weight_mg_prime(m)))
        f_prime = torch.sigmoid(norm(self.weight_xf_prime(x) + self.weight_mf_prime(m)))

        m = i_prime * g_prime + f_prime * m
        o = torch.sigmoid(
            norm(
                self.weight_xo(x)
                + self.weight_ho(h)
                + self.weight_co(c)
                + self.weight_mo(m)
            )
        )
        h = o * torch.tanh(self.weight_111(torch.cat([c, m], dim=1)))

        # FIFO update: drop the oldest entry, append the new cell state.
        c_history = torch.cat([c_history[1:], c.unsqueeze(0)], dim=0)

        return (c_history, m, h)

    def init_hidden(self, batch_size, tau, device=None):
        """Create zero-filled ``(c_history, m, h)`` states.

        ``c_history`` has a leading axis of length ``tau``; every state uses
        ``hidden_size`` channels and the non-channel dims of ``input_shape``.
        """
        memory_shape = list(self._input_shape)
        memory_shape[0] = self._hidden_size
        c_history = torch.zeros(tau, batch_size, *memory_shape, device=device)
        m = torch.zeros(batch_size, *memory_shape, device=device)
        h = torch.zeros(batch_size, *memory_shape, device=device)

        return (c_history, m, h)


class ConvDeconv3d(nn.Module):
    """3D convolution whose result is resized back to the input's size.

    The convolution may shrink the last three dimensions; nearest-neighbour
    interpolation restores them to those of the input, so only the channel
    count changes.
    """

    def __init__(self, in_channels, out_channels, *vargs, **kwargs):
        super().__init__()

        # Remaining positional/keyword arguments go straight to nn.Conv3d.
        self.conv3d = nn.Conv3d(in_channels, out_channels, *vargs, **kwargs)

    def forward(self, input):
        """Convolve ``input`` and upsample back to its last three dims."""
        target_size = input.shape[-3:]
        convolved = self.conv3d(input)
        return F.interpolate(convolved, size=target_size, mode="nearest")

class Out(nn.Module):
    """Output head: one 3x3x3 convolution (stride 1, padding 1).

    With these settings the convolution changes only the channel count.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Apply the convolution to ``x``."""
        projected = self.conv(x)
        return projected

class E3DLSTM_NET(nn.Module):
    """E3D-LSTM encoder followed by a convolutional decoder and output head.

    Args:
        input_shape: Per-step input shape without the batch axis (channels
            first), forwarded to the encoder.
        hidden_size: Hidden channels of every encoder cell.
        num_layers: Number of stacked encoder cells.
        kernel_size: Kernel size shared by the encoder gates and the decoder.
        tau: Length of the cell-state history kept by every encoder cell.
        time_steps: Sequence length; the encoder concatenates one hidden
            state per step along the channel axis, so the decoder expects
            ``hidden_size * time_steps`` input channels.
        output_shape: Only ``output_shape[0]`` is read; it is the number of
            channels produced by the decoder.
    """

    def __init__(self, input_shape, hidden_size, num_layers, kernel_size, tau, time_steps, output_shape):
        super().__init__()

        self.input_shape = input_shape
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.kernel_size = kernel_size
        self.tau = tau
        self.time_steps = time_steps
        self.output_shape = output_shape
        self.dtype = torch.float32

        self.encoder = E3DLSTM(
            input_shape, hidden_size, num_layers, kernel_size, tau
        ).type(self.dtype)
        # The padding is fixed at (0, 2, 2): it keeps the last two dims
        # unchanged only for a 5x5 kernel in those dims, and never pads the
        # first of the three.
        self.decoder = nn.Conv3d(
            hidden_size * time_steps, output_shape[0], kernel_size, padding=(0, 2, 2)
        ).type(self.dtype)
        # The head consumes the decoder output, so its input channels must
        # equal the decoder's output channels; a fixed 4 fails with a channel
        # mismatch whenever output_shape[0] != 4.
        self.out = Out(output_shape[0], 1)

    def forward(self, input_seq):
        """Encode ``input_seq``, decode it, and apply the output head."""
        return self.out(self.decoder(self.encoder(input_seq)))

# 测试代码
if __name__ == '__main__':
    # Smoke test: build the network and push a constant tensor through it to
    # see how the tensor dimensions propagate.
    time_steps = 29
    batch_size = 2
    input_shape = (16, 4, 16, 16)
    output_shape = (16, 1, 16, 16)

    hidden_size = 64
    lstm_layers = 4
    kernel = (3, 5, 5)
    tau = 2

    model = E3DLSTM_NET(
        input_shape,
        hidden_size,
        lstm_layers,
        kernel,
        tau,
        time_steps,
        output_shape,
    )
    # (seq_len, batch, *input_shape)
    x = torch.ones(time_steps, batch_size, *input_shape)
    print('finished!')
    f = model(x)
    print(f)

02 论文下载

论文:Eidetic 3D LSTM: A Model for Video Prediction and Beyond

Github链接:e3d_lstm

相关推荐
Jack_pirate8 分钟前
深度学习中的特征到底是什么?
人工智能·深度学习
微凉的衣柜22 分钟前
微软在AI时代的战略布局和挑战
人工智能·深度学习·microsoft
GocNeverGiveUp35 分钟前
机器学习1-简单神经网络
人工智能·机器学习
Schwertlilien38 分钟前
图像处理-Ch2-空间域的图像增强
人工智能
一只敲代码的猪1 小时前
Llama 3 模型系列解析(一)
大数据·python·llama
智慧化智能化数字化方案1 小时前
深入解读数据资产化实践指南(2024年)
大数据·人工智能·数据资产管理·数据资产入表·数据资产化实践指南
哦哦~9211 小时前
深度学习驱动的油气开发技术与应用
大数据·人工智能·深度学习·学习
Hello_WOAIAI1 小时前
批量将 Word 文件转换为 HTML:Python 实现指南
python·html·word
智慧化智能化数字化方案1 小时前
120页PPT讲解ChatGPT如何与财务数字化转型的业财融合
人工智能·chatgpt
winfredzhang1 小时前
使用Python开发PPT图片提取与九宫格合并工具
python·powerpoint·提取·九宫格·照片