Tensorflow2.0笔记 - 自定义Layer和Model

本笔记主要记录如何在tensorflow中实现自定的Layer和Model。详细内容请参考代码中的链接。

复制代码
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

tf.__version__
#关于自定义layer和自定义Model的相关介绍,参考下面的链接:
#https://tf.wiki/zh_hans/basic/models.html
#https://blog.csdn.net/lzs781/article/details/104741958


#自定义Dense层,继承自Layer
class MyDense(layers.Layer):
    #需要实现__init__和call方法
    def __init__(self, input_dim, output_dim):
        super(MyDense, self).__init__()
        self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))
        self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))

    def call(self, inputs, training=None):
        out = inputs@self.kernel + self.bias
        return out

#自定义Model,继承自Model
class MyModel(keras.Model):
    #需要实现__init__和call方法
    def __init__(self):
        super(MyModel, self).__init__()
        #自定义5层MyDense自定义Layer
        self.fc1 = MyDense(28*28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, trainning=None):
        x = self.fc1(inputs) #会调用MyDense的call方法
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)
        return x

customModel = MyModel()
customModel.build(input_shape=[None, 28*28])
customModel.summary()

运行结果:

相关推荐
BlackWolfSky1 分钟前
鸿蒙高级课程笔记1—应用DFX能力介绍
笔记·华为·harmonyos
Faker66363aaa2 分钟前
YOLO11改进蚊虫目标检测模型,AttheHead注意力机制提升检测精度
人工智能·目标检测·计算机视觉
郝学胜-神的一滴2 分钟前
基于30年教学沉淀的清华大学AI通识经典:《人工智能的底层逻辑》
人工智能·程序人生·机器学习·scikit-learn·sklearn
OPEN-Source3 分钟前
大模型实战:把 LangChain / LlamaIndex 工作流接入监控与告警体系
人工智能·langchain·企业微信·rag
青春不朽5124 分钟前
Scikit-learn 入门指南
python·机器学习·scikit-learn
得物技术5 分钟前
大模型网关:大模型时代的智能交通枢纽|得物技术
人工智能·ai
共享家95277 分钟前
嵌入模型(Embedding)的全方位指南
人工智能·机器学习
进击的小头13 分钟前
FIR滤波器实战:音频信号降噪
c语言·python·算法·音视频
ViiTor_AI16 分钟前
AI 有声书旁白来了:AI 配音如何重塑有声书制作模式
人工智能
2501_9416527716 分钟前
验证码识别与分类任务_gfl_x101-32x4d_fpn_ms-2x_coco模型训练与优化
人工智能·数据挖掘