PixelSNAIL论文代码学习(2)——门控残差网络的实现

文章目录

引言

  • 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接
  • 这段时间看他的代码,还是挺痛苦的,因为我对于深度学习的框架尚且不是很熟练 ,而且这个作者很厉害,很多东西都是自己实现的,所以看起来十分费力,本来想逐行分析,结果发现逐行分析不现实,所以这里按照模块进行分析。
  • 今天就专门来学习一下他门门控控残差模块如何实现。

正文

门控残差网络介绍

  • 介绍

    • 通过门来控制每一个残差模块,门通常是由sigmoid函数组成
    • 作用:有效建模复杂函数,有助于缓解梯度消失和爆炸的问题
  • 基本步骤

    • 卷积操作:对输入矩阵执行卷积操作
    • 非线性激活:应用非线性激活函数,激活卷积操作的输出
    • 第二次卷积操作:对上一个层的输出进行二次卷积
    • 门控操作 :将二次卷积的输出分为a和b两个部分,并且通过sigmoid函数进行门控 a , b = S p l i t ( c 2 ) G a t e : g = a × s i g m o i d ( b ) a,b = Split(c_2) \\ Gate:g = a \times sigmoid(b) a,b=Split(c2)Gate:g=a×sigmoid(b)
      • 这里一般是沿着最后一个通道,将原来的矩阵拆解成a和b,然后在相乘,确保每一个矩阵有一个门控参数
    • 将门控输出 g g g和原始输入 x x x相加
  • 具体流程图如下

    • x: 输入
    • c1: 第一次卷积操作(Conv1)
    • a1: 非线性激活函数(例如 ReLU)
    • c2: 第二次卷积操作(Conv2),输出通道数是输入通道数的两倍
    • split: 将c2 分为两部分 a 和 b
    • a, b: 由 c2 分割得到的两部分
    • sigmoid: 对b 应用 sigmoid 函数
    • gated: 执行门控操作 a×sigmoid(b)
    • y: 输出,由原始输入 x 和门控输出相加得到
  • 这里参考一下论文中的图片,可以看到和基本的门控神经网络是近似的,只不过增加了一些辅助输入还有条件矩阵

门控残差网络具体实现代码

  • 具体和上面描述的差不多,这里增加了两个额外的参数,分别是辅助输入a和条件矩阵b

  • 注意,这里的二维卷积就是加上了简单的权重归一化的普通二维卷积。

  • 辅助输入a

    • 用途:提供额外的信息,帮助网络更好地执行任务,比如说在多模态场景或者多任务学习中,会通过a提供主输入x相关联的信息
    • 操作:如果提供了a,那么在第一次卷积之后,会经过全连接层与c1相加
  • 条件矩阵h

    • 用途:主要用于条件生成任务,因为条件生成任务的网络行为会受到某些条件和上下文影响。比如,在文本生成图像中,h会是一个文本描述的嵌入
    • 操作:如果提供了 h,那么 h 会被投影到一个与 c2 具有相同维度的空间中,并与 c2 相加。这是通过一个全连接层实现的,该层的权重是 hw。
python 复制代码
def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs):
    xs = int_shape(x)
    num_filters = xs[-1]

    # 执行第一次卷积
    c1 = conv(nonlinearity(x), num_filters)

    # 查看是否有辅助输入a
    if a is not None:  # add short-cut connection if auxiliary input 'a' is given
        c1 += nin(nonlinearity(a), num_filters)

    # 执行非线性单元
    c1 = nonlinearity(c1)
    if dropout_p > 0:
        c1 = tf.nn.dropout(c1, keep_prob=1. - dropout_p)

    # 执行第二次卷积
    c2 = conv(c1, num_filters * 2, init_scale=0.1)

    # add projection of h vector if included: conditional generation
    # 如果有辅助输入h,那么就将h投影到c2的维度上
    if h is not None:
        with tf.variable_scope(get_name('conditional_weights', counters)):
            hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32,
                                   initializer=tf.random_normal_initializer(0, 0.05), trainable=True)
        if init:
            hw = hw.initialized_value()
        c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters])

    # Is this 3,2 or 2,3 ?
    a, b = tf.split(c2, 2, 3)
    c3 = a * tf.nn.sigmoid(b)
    return x + c3

使用pytorch实现

  • tensorflow的模型定义过程和pytorch的定义过程就是不一样,tensorflow中的conv2d只需要给出输出的channel,直接输入需要卷积的部分即可。但是使用pytorch,需要进行给定输入的 channel,然后在给出输出的filter_size,很麻烦。
  • 除此之外,在定义模型的层的过程中,我们不能在forward中定义层,只能在init函数中定义层。
python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm

class GatedResNet(nn.Module):
    def __init__(self, num_filters, nonlinearity=F.elu, dropout_p=0.0):
        super(GatedResNet, self).__init__()
        self.num_filters = num_filters
        self.nonlinearity = nonlinearity
        self.dropout_p = dropout_p

        # 第一卷积层
        self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1)
#         self.conv1 = weight_norm(self.conv1)

        # 第二卷积层,输出通道是 2 * num_filters,用于门控机制
        self.conv2 = nn.Conv2d(num_filters, 2 * num_filters, kernel_size=3, padding=1)
#         self.conv2 = weight_norm(self.conv2)

        
        # 条件权重用于 h,初始化在前向传播过程中
        self.hw = None

    def forward(self, x, a=None, h=None):
        c1 = self.conv1(self.nonlinearity(x))

        # 检查是否有辅助输入 'a'
        if a is not None:
            c1 += a  # 或使用 NIN 使维度兼容

        c1 = self.nonlinearity(c1)
        if self.dropout_p > 0:
            c1 = F.dropout(c1, p=self.dropout_p, training=self.training)

        c2 = self.conv2(c1)
        print('the shape of c2',c2.shape)

        # 如果有辅助输入 h,则加入 h 的投影
        if h is not None:
            if self.hw is None:
                self.hw = nn.Parameter(torch.randn(h.size(1),  self.num_filters) * 0.05)
            print(self.hw.shape)
            c2 +=  (h @ self.hw).view(h.size(0), 1, 1, self.num_filters)
            

        # 将通道分为两组:'a' 和 'b'
        a, b = c2.chunk(2, dim=1)
        c3 = a * torch.sigmoid(b)

        return x + c3

# 测试
x = torch.randn(16, 32, 32, 32)  # [批次大小,通道数,高度,宽度]
a = torch.randn(16, 32, 32, 32)  # 和 x 维度相同的辅助输入
h = torch.randn(16, 64)  # 可选的条件变量
model = GatedResNet(32)
out = model(x, a , h)

总结

  • 遇到了很多问题,是因为经验不够,而且很多东西都不了解,然后改的很痛苦,而且现在完全还没有跑起来,完整的组件都没有搭建完成,这里还需要继续努力。
  • 关于门控残差网络这里,这里学到了很多,知道了具体的运作流程,也知道他是专门针对序列数据,防止出现梯度爆炸的。以后可以多用用看。
相关推荐
zquwei16 分钟前
SpringCloudGateway+Nacos注册与转发Netty+WebSocket
java·网络·分布式·后端·websocket·网络协议·spring
爱吃西瓜的小菜鸡16 分钟前
【C语言】判断回文
c语言·学习·算法
Aimin202230 分钟前
路由器做WPAD、VPN、透明代理中之间一个
网络
小A15939 分钟前
STM32完全学习——SPI接口的FLASH(DMA模式)
stm32·嵌入式硬件·学习
群联云防护小杜1 小时前
如何给负载均衡平台做好安全防御
运维·服务器·网络·网络协议·安全·负载均衡
岁岁岁平安1 小时前
spring学习(spring-DI(字符串或对象引用注入、集合注入)(XML配置))
java·学习·spring·依赖注入·集合注入·基本数据类型注入·引用数据类型注入
武昌库里写JAVA1 小时前
Java成长之路(一)--SpringBoot基础学习--SpringBoot代码测试
java·开发语言·spring boot·学习·课程设计
qq_589568101 小时前
数据可视化echarts学习笔记
学习·信息可视化·echarts
爱码小白1 小时前
网络编程(王铭东老师)笔记
服务器·网络·笔记
蜜獾云1 小时前
linux firewalld 命令详解
linux·运维·服务器·网络·windows·网络安全·firewalld