Tensorflow2.0笔记 - 均方差MSE和交叉熵CROSS ENTROPHY作为损失函数

本笔记主要记录使用MSE和交叉熵作为loss function时的梯度计算方法。

复制代码
import tensorflow as tf
import numpy as np

tf.__version__


#softmax函数使用
#参考资料:https://blog.csdn.net/u013230189/article/details/82835717
#简单例子:
#假设输出的LOGITS SCORE为:
#2.0
#1.0
#0.1
#使用softmax后,可以按照score得分高低来转换成概率大小,所有输出值相加为1:
#2.0     0.7
#1.0  -> 0.2
#0.1     0.1

#MSE均方差loss函数及其梯度计算
#参考资料:https://zhuanlan.zhihu.com/p/35707643
#下面例子中:x表示两个样本数据,每个数据是一个长度为4的tensor:[2,4]
x = tf.random.normal([2,4])
#输入数据的维度是4,输出节点我们定义为3维,表示3分类结果
w = tf.random.normal([4,3])
#bias初始化为0
b = tf.zeros([3])
#输出的label值,表示两个样本的真实label的class是2和0
y = tf.constant([2,0])

with tf.GradientTape() as tape:
    tape.watch([w,b])
    #使用softmax计算概率
    prob = tf.nn.softmax(x@w +b, axis=1)
    #使用MSE计算loss
    loss = tf.reduce_mean(tf.losses.MSE(tf.one_hot(y, depth=3), prob))
#求解损失函数的梯度
grads = tape.gradient(loss, [w,b])
print("Gradients of w:\n", grads[0].numpy())
print("Gradients of b:\n", grads[1].numpy())

#交叉熵loss函数及其梯度计算
#参考资料:https://zhuanlan.zhihu.com/p/38241764
#下面例子中:x表示两个样本数据,每个数据是一个长度为4的tensor:[2,4]
x = tf.random.normal([2,4])
#输入数据的维度是4,输出节点我们定义为3维,表示3分类结果
w = tf.random.normal([4,3])
#bias初始化为0
b = tf.zeros([3])
#输出的label值,表示两个样本的真实label的class是2和0
y = tf.constant([2,0])

with tf.GradientTape() as tape:
    tape.watch([w,b])
    #计算logits
    logits = x@w + b
    #使用交叉熵计算loss
    loss = tf.reduce_mean(tf.losses.categorical_crossentropy(tf.one_hot(y, depth=3), logits, from_logits=True))

#求解损失函数的梯度
grads = tape.gradient(loss, [w, b])
print("Gradients of w:\n", grads[0].numpy())
print("Gradients of b:\n", grads[1].numpy())

运行结果:

相关推荐
听风吟丶40 分钟前
Java 8 Stream API 高级实战:从数据处理到性能优化的深度解析
开发语言·python
AA陈超2 小时前
ASC学习笔记0014:手动添加一个新的属性集
c++·笔记·学习·ue5
澳鹏Appen2 小时前
数据集月度精选 | 高质量具身智能数据集:打开机器人“感知-决策-动作”闭环的钥匙
人工智能·机器人·具身智能
文人sec2 小时前
pytest1-接口自动化测试场景
软件测试·python·单元测试·pytest
Chunyyyen2 小时前
【第二十二周】自然语言处理的学习笔记06
笔记·学习·自然语言处理
q***71014 小时前
开源模型应用落地-工具使用篇-Spring AI-Function Call(八)
人工智能·spring·开源
极限实验室4 小时前
Coco AI 参选 Gitee 2025 最受欢迎开源软件!您的每一票,都是对中国开源的硬核支持
人工智能·开源
secondyoung4 小时前
Mermaid流程图高效转换为图片方案
c语言·人工智能·windows·vscode·python·docker·流程图
iFlow_AI4 小时前
iFlow CLI Hooks 「从入门到实战」应用指南
开发语言·前端·javascript·人工智能·ai·iflow·iflow cli
Shang180989357264 小时前
THC63LVD1027D一款10位双链路LVDS信号中继器芯片,支持WUXGA分辨率视频数据传输THC63LVD1027支持30位数据通道方案
人工智能·考研·信息与通信·信号处理·thc63lvd1027d·thc63lvd1027