Tensorflow2.0笔记 - 自定义Layer和Model

本笔记主要记录如何在tensorflow中实现自定的Layer和Model。详细内容请参考代码中的链接。

复制代码
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__
#关于自定义layer和自定义Model的相关介绍,参考下面的链接:
#https://tf.wiki/zh_hans/basic/models.html
#https://blog.csdn.net/lzs781/article/details/104741958


#自定义Dense层,继承自Layer
class MyDense(layers.Layer):
    #需要实现__init__和call方法
    def __init__(self, input_dim, output_dim):
        super(MyDense, self).__init__()
        self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))
        self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))

    def call(self, inputs, training=None):
        out = inputs@self.kernel + self.bias
        return out

#自定义Model,继承自Model
class MyModel(keras.Model):
    #需要实现__init__和call方法
    def __init__(self):
        super(MyModel, self).__init__()
        #自定义5层MyDense自定义Layer
        self.fc1 = MyDense(28*28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, trainning=None):
        x = self.fc1(inputs) #会调用MyDense的call方法
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)
        return x

customModel = MyModel()
customModel.build(input_shape=[None, 28*28])
customModel.summary()

运行结果:

相关推荐
飞Link4 分钟前
瑞萨联姻 Irida Labs:嵌入式开发者如何玩转“端侧视觉 AI”新范式?
人工智能
RSTJ_162516 分钟前
PYTHON+AI LLM DAY THREETY-SEVEN
开发语言·人工智能·python
郝学胜-神的一滴20 分钟前
深度学习优化核心:梯度下降与网络训练全解析
数据结构·人工智能·python·深度学习·算法·机器学习
Aision_29 分钟前
Agent 为什么需要 Checkpoint?
人工智能·python·gpt·langchain·prompt·aigc·agi
清水白石00833 分钟前
《Python性能深潜:从对象分配开销到“小对象风暴”的破解之道(含实战与最佳实践)》
开发语言·python
小贺儿开发36 分钟前
《唐朝诡事录之长安》——盛世马球
人工智能·unity·ai·shader·绘画·影视·互动
秋937 分钟前
ESP32 与 Air780E 4G 模块配合做 MQTT 数据传输
人工智能
DeepFlow 零侵扰全栈可观测1 小时前
运动战:AI 时代 IT 运维的决胜之道——DeepFlow 业务全链路可观测性的落地实践
运维·网络·人工智能·arcgis·云计算
链上日记1 小时前
AgentWin:AI Agent驱动的Web4智能金融新纪元
人工智能·金融
Brilliantwxx1 小时前
【C++】 vector(代码实现+坑点讲解)
开发语言·c++·笔记·算法