自定义神经网络时的注意事项

问题描述

`

通过继承tf.keras.Model自定义神经网络模型时遇到的一系列问题。
代码如下,

c在这里插入代码片 复制代码
class STFT_ConV2D(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pre_layer = tf.keras.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(768, activation='relu')
        ])

        self.add = tf.keras.layers.Add()
        self.output_dense = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, inputs):
        x, y = inputs
        x = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_x)(x)
        x = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_x)(x)
        x = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_x)(x)
        x = self.pre_layer(x)

        y = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_y)(y)
        y = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_y)(y)
        y = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_y)(y)
        y = self.pre_layer(y)
        output = self.add([x, y])
        output = self.output_dense(output)
        return output

产生的bug为,

markup 复制代码
  ValueError: Exception encountered when calling layer 'sequential' (type Sequential).
        
  Input 0 of layer "dense" is incompatible with the layer: expected axis -1 of input shape to have value 11368, but received input with shape (None, 210680)

x输入和y输入都使用了成员变量pre_layer,共享了pre_layer层,也就共享了pre_layer层的参数矩阵和结构。
由于x先经过三层卷积层后shape由原来的shape=(360, 256, 109, 1)变成了shape=(360, 203, 56, 1)
再经过pre_layer层里的Flatten时,除" batchsize "轴(axis=0)外,其余轴被铺平,输出shape=(360,11368)。接着处理y输入,经过三层卷积层后,shape由原来的shape=(360, 511, 513, 1)变成了shape=(360,458, 460, 1),之后执行到y = self.pre_layer(y)时,如果执行成功,则输出shape=(360,21068),此时与x的shape=(360,11368)维度冲突,从而产生异常。

要点归纳:

  1. 通过继承tf.keras.Model写神经网络模型时,每一个神经网络层只能被同一个输入占有。
  2. 所有tf.keras.layers下的层对象不能直接出现在call()方法中,必须以成员变量的形式在构造器中定义,然后在call()方法中通过self.成员变量的方式调用
  3. 卷积层tf.keras.layers.Conv2D()当神经网络第一层时,必须通过参数input_shape指定输入shape,该shape中不能包含" batchsize "轴,例如输入x的shape为(a, b, c, d),其中a代表样本数,b代表行像素,c代表列像素,d代表通道数。则应该指定input_shape=x.shape[1:],去除a所在轴,以免卷积层对该轴造成影响。

解决方案:

python 复制代码
class STFT_ConV2D(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conV2d_x1 = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_x)
        self.conV2d_x2 = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_x)
        self.conV2d_x3 = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_x)

        self.conV2d_y1 = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_y)
        self.conV2d_y2 = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_y)
        self.conV2d_y3 = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_y)

        self.flatten_x = tf.keras.layers.Flatten()
        self.flatten_y = tf.keras.layers.Flatten()

        self.dense_x = tf.keras.layers.Dense(768, activation='relu')
        self.dense_y = tf.keras.layers.Dense(768, activation='relu')

        self.add = tf.keras.layers.Add()
        self.output_dense = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, inputs):
        # x.shape = (360, 256, 109, 1) , y.shape = (360, 511, 513, 1)
        # inputs = (x, y)
        x, y = inputs  
        x = self.conV2d_x1(x) # (360, 249, 102, 3)
        x = self.conV2d_x2(x) # (360, 234, 87, 3)
        x = self.conV2d_x3(x) # (360, 203, 56, 1)
        x = self.flatten_x(x) # (360, 11368)
        x = self.dense_x(x)  # (360, 768)

        y = self.conV2d_y1(y)
        y = self.conV2d_y2(y)
        y = self.conV2d_y3(y)
        y = self.flatten_y(y)
        y = self.dense_y(y)

        output = self.add([x, y]) # (360, 768)
        output = self.output_dense(output)
        return output
相关推荐
聆风吟º40 分钟前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
User_芊芊君子1 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
智驱力人工智能1 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5772 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
h64648564h2 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切2 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
学电子她就能回来吗4 小时前
深度学习速成:损失函数与反向传播
人工智能·深度学习·学习·计算机视觉·github
爱吃泡芙的小白白4 小时前
突破传统:CNN卷积层(普通/空洞)核心技术演进与实战指南
人工智能·神经网络·cnn·卷积层·空洞卷积·普通卷积
Coder_Boy_4 小时前
TensorFlow小白科普
人工智能·深度学习·tensorflow·neo4j
大模型玩家七七4 小时前
梯度累积真的省显存吗?它换走的是什么成本
java·javascript·数据库·人工智能·深度学习