Tensorflow2.0笔记 - 自定义Layer和Model

本笔记主要记录如何在tensorflow中实现自定的Layer和Model。详细内容请参考代码中的链接。

复制代码
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__
#关于自定义layer和自定义Model的相关介绍,参考下面的链接:
#https://tf.wiki/zh_hans/basic/models.html
#https://blog.csdn.net/lzs781/article/details/104741958


#自定义Dense层,继承自Layer
class MyDense(layers.Layer):
    #需要实现__init__和call方法
    def __init__(self, input_dim, output_dim):
        super(MyDense, self).__init__()
        self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))
        self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))

    def call(self, inputs, training=None):
        out = inputs@self.kernel + self.bias
        return out

#自定义Model,继承自Model
class MyModel(keras.Model):
    #需要实现__init__和call方法
    def __init__(self):
        super(MyModel, self).__init__()
        #自定义5层MyDense自定义Layer
        self.fc1 = MyDense(28*28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, trainning=None):
        x = self.fc1(inputs) #会调用MyDense的call方法
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)
        return x

customModel = MyModel()
customModel.build(input_shape=[None, 28*28])
customModel.summary()

运行结果:

相关推荐
Mininglamp_27183 小时前
会中 AI Skill 架构设计解析:3 种人设 × 7 种能力的技术实现
人工智能·语音识别·硬件·ai agent·skill
墨神谕4 小时前
人工智能(三)— 神经网络的训练
人工智能·神经网络·机器学习
APIshop4 小时前
Python 获取 1688 商品采集 API 接口 | 工厂货源自动化对接商品信息 | 无需选品
运维·python·自动化
deepin_sir4 小时前
10 - 函数
开发语言·python
RyFit4 小时前
Java + AI 实战:Spring AI 从入门到企业级落地
java·人工智能·spring
Raink老师4 小时前
【AI面试临阵磨枪-69】如何设计一个支持百万级工具的 Agent 系统?如何快速路由与选择工具?
人工智能·面试·职场和发展
oort1234 小时前
My Name:开发者部署平台OORT.sh—— AI时代的开发者部署平台,是Vibe Coding闺蜜
人工智能
charlee444 小时前
《GIS基础原理与技术实践》配套案例(Python版)
python·conda·numpy·gis·环境配置
枫叶林FYL4 小时前
项目十:事件溯源仓储管理系统(WMS)仿真实现
开发语言·python
君为先-bey4 小时前
CogVideoX——Transformer从文本到视频的扩散模型
深度学习·音视频·transformer·扩散模型