Tensorflow2.0笔记 - 均方差MSE和交叉熵CROSS ENTROPHY作为损失函数

本笔记主要记录使用MSE和交叉熵作为loss function时的梯度计算方法。

复制代码
import tensorflow as tf
import numpy as np

tf.__version__


#softmax函数使用
#参考资料:https://blog.csdn.net/u013230189/article/details/82835717
#简单例子:
#假设输出的LOGITS SCORE为:
#2.0
#1.0
#0.1
#使用softmax后,可以按照score得分高低来转换成概率大小,所有输出值相加为1:
#2.0     0.7
#1.0  -> 0.2
#0.1     0.1

#MSE均方差loss函数及其梯度计算
#参考资料:https://zhuanlan.zhihu.com/p/35707643
#下面例子中:x表示两个样本数据,每个数据是一个长度为4的tensor:[2,4]
x = tf.random.normal([2,4])
#输入数据的维度是4,输出节点我们定义为3维,表示3分类结果
w = tf.random.normal([4,3])
#bias初始化为0
b = tf.zeros([3])
#输出的label值,表示两个样本的真实label的class是2和0
y = tf.constant([2,0])

with tf.GradientTape() as tape:
    tape.watch([w,b])
    #使用softmax计算概率
    prob = tf.nn.softmax(x@w +b, axis=1)
    #使用MSE计算loss
    loss = tf.reduce_mean(tf.losses.MSE(tf.one_hot(y, depth=3), prob))
#求解损失函数的梯度
grads = tape.gradient(loss, [w,b])
print("Gradients of w:\n", grads[0].numpy())
print("Gradients of b:\n", grads[1].numpy())

#交叉熵loss函数及其梯度计算
#参考资料:https://zhuanlan.zhihu.com/p/38241764
#下面例子中:x表示两个样本数据,每个数据是一个长度为4的tensor:[2,4]
x = tf.random.normal([2,4])
#输入数据的维度是4,输出节点我们定义为3维,表示3分类结果
w = tf.random.normal([4,3])
#bias初始化为0
b = tf.zeros([3])
#输出的label值,表示两个样本的真实label的class是2和0
y = tf.constant([2,0])

with tf.GradientTape() as tape:
    tape.watch([w,b])
    #计算logits
    logits = x@w + b
    #使用交叉熵计算loss
    loss = tf.reduce_mean(tf.losses.categorical_crossentropy(tf.one_hot(y, depth=3), logits, from_logits=True))

#求解损失函数的梯度
grads = tape.gradient(loss, [w, b])
print("Gradients of w:\n", grads[0].numpy())
print("Gradients of b:\n", grads[1].numpy())

运行结果:

相关推荐
Dxy123931021610 小时前
Python在图片上画圆形:从入门到实战
开发语言·python
小江的记录本10 小时前
【系统设计】《2026高频经典系统设计题》(秒杀系统、短链接系统、订单系统、支付系统、IM系统、RAG系统设计)(完整版)
java·后端·python·安全·设计模式·架构·系统架构
物联网软硬件开发-轨物科技10 小时前
【轨物方案】光伏清洁-检测一体化机器人系统
数据库·人工智能·机器人
m0_3776182310 小时前
HTML怎么显示速率限制重置时间_HTML X-RateLimit-Reset解析【说明】
jvm·数据库·python
果汁华10 小时前
Chrome DevTools MCP:让 AI 编码助手拥有浏览器调试超能力
前端·人工智能·chrome devtools
u01091476010 小时前
C#怎么实现OAuth2.0授权_C#如何对接第三方快捷登录【核心】
jvm·数据库·python
2301_7775993710 小时前
如何显著提升 Google Sheets 数据库批量更新脚本的执行效率
jvm·数据库·python
杰梵10 小时前
聚酯切片DSC热分析应用报告
人工智能·算法
2201_7610405910 小时前
bootstrap怎么给div添加自定义的边框样式
jvm·数据库·python
Java后端的Ai之路10 小时前
当大模型开始“水土不服“:从通才到专才的进化论——Fine-tuning 企业级实战全攻略
人工智能·python·langchain·rag·lcel