Tensorflow2.0笔记 - BatchNormalization

本笔记记录BN层相关的代码。关于BatchNormalization,可以自行百度,或参考这里:

一文读懂Batch Normalization - 知乎神经网络基础系列: 《深度学习中常见激活函数的原理和特点》《过拟合: dropout原理和在模型中的多种应用》深度模型的基础结构是MLP,模型训练和调参的复杂性随着模型深度的增加而加大,这使得算法工程师在业务中...https://zhuanlan.zhihu.com/p/594944859#:~:text=Batch,normalization%E7%9A%84%E6%80%9D%E8%B7%AF%E5%BE%88%E7%AE%80%E5%8D%95%EF%BC%8C%E5%AF%B9%E8%BE%93%E5%85%A5%E7%9A%84%E6%95%B0%E6%8D%AE%E5%9C%A8%E6%AF%8F%E4%B8%AA%E7%BB%B4%E5%BA%A6%E4%B8%8A%E8%BF%9B%E8%A1%8C%E6%A0%87%E5%87%86%E5%8C%96%E5%A4%84%E7%90%86%EF%BC%8C%E5%86%8D%E8%BF%9B%E8%A1%8C%E7%BA%BF%E6%80%A7%E5%8F%98%E6%8D%A2%EF%BC%8C%E4%BB%A5%E7%BC%93%E8%A7%A3%E5%9B%A0%E6%A0%87%E5%87%86%E5%8C%96%E5%AF%BC%E8%87%B4%E7%9A%84%E6%95%B0%E6%8D%AE%E8%A1%A8%E5%BE%81%E8%83%BD%E5%8A%9B%E7%9A%84%E4%B8%8B%E9%99%8D%E3%80%82Batch Normalization(BN)超详细解析_batchnorm在预测阶段需要计算吗-CSDN博客文章浏览阅读3.7w次,点赞109次,收藏458次。单层视角神经网络可以看成是上图形式,对于中间的某一层,其前面的层可以看成是对输入的处理,后面的层可以看成是损失函数。一次反向传播过程会同时更新所有层的权重W1,W2,...,WL,前面层权重的更新会改变当前层输入的分布,而跟据反向传播的计算方式,我们知道,对Wk的更新是在假定其输入不变的情况下进行的。如果假定第k层的输入节点只有2个,对第k层的某个输出节点而言,相当于一个线性模型y=w1x1+w2x..._batchnorm在预测阶段需要计算吗https://blog.csdn.net/weixin_44023658/article/details/105844861

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Input

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__

#下面的x中的数据按照均值是1,标准差为0.5的正态分布排布,通过BatchNormalization进行修正
x = tf.random.normal([2, 4, 4, 3], mean=1., stddev=0.5)
net = layers.BatchNormalization(axis=3)
out = net(x, training=True)

print("Variables:\n", net.variables)

#设置优化器
optimizer = optimizers.Adam(learning_rate=1e-4)

for i in range(100):
    with tf.GradientTape() as tape:
        #进行100次BN层前向传播,moving_mean和moving_variance会变化
        out = net(x, training=True)
        #自定义损失函数是均值和1的距离
        loss = tf.reduce_mean(tf.pow(out, 2)) - 1
    #进行梯度更新, gamma和beta会变化
    grads = tape.gradient(loss, net.trainable_variables)
    optimizer.apply_gradients(zip(grads, net.trainable_variables))

print("Trained 100 times:\n", net.variables)

运行结果:

相关推荐
bryant_meng39 分钟前
【python】OpenCV—Image Moments
开发语言·python·opencv·moments·图片矩
车载诊断技术1 小时前
电子电气架构 --- 什么是EPS?
网络·人工智能·安全·架构·汽车·需求分析
KevinRay_1 小时前
Python超能力:高级技巧让你的代码飞起来
网络·人工智能·python·lambda表达式·列表推导式·python高级技巧
跃跃欲试-迪之1 小时前
animatediff 模型网盘分享
人工智能·stable diffusion
Captain823Jack2 小时前
nlp新词发现——浅析 TF·IDF
人工智能·python·深度学习·神经网络·算法·自然语言处理
被制作时长两年半的个人练习生2 小时前
【AscendC】ReduceSum中指定workLocal大小时如何计算
人工智能·算子开发·ascendc
资源补给站2 小时前
大恒相机开发(2)—Python软触发调用采集图像
开发语言·python·数码相机
Captain823Jack2 小时前
w04_nlp大模型训练·中文分词
人工智能·python·深度学习·神经网络·算法·自然语言处理·中文分词
Black_mario2 小时前
链原生 Web3 AI 网络 Chainbase 推出 AVS 主网, 拓展 EigenLayer AVS 应用场景
网络·人工智能·web3
PieroPc3 小时前
Python 自动化 打开网站 填表登陆 例子
运维·python·自动化