Tensorflow2.0笔记 - 均方差MSE和交叉熵CROSS ENTROPHY作为损失函数

本笔记主要记录使用MSE和交叉熵作为loss function时的梯度计算方法。

复制代码
import tensorflow as tf
import numpy as np

tf.__version__


#softmax函数使用
#参考资料:https://blog.csdn.net/u013230189/article/details/82835717
#简单例子:
#假设输出的LOGITS SCORE为:
#2.0
#1.0
#0.1
#使用softmax后,可以按照score得分高低来转换成概率大小,所有输出值相加为1:
#2.0     0.7
#1.0  -> 0.2
#0.1     0.1

#MSE均方差loss函数及其梯度计算
#参考资料:https://zhuanlan.zhihu.com/p/35707643
#下面例子中:x表示两个样本数据,每个数据是一个长度为4的tensor:[2,4]
x = tf.random.normal([2,4])
#输入数据的维度是4,输出节点我们定义为3维,表示3分类结果
w = tf.random.normal([4,3])
#bias初始化为0
b = tf.zeros([3])
#输出的label值,表示两个样本的真实label的class是2和0
y = tf.constant([2,0])

with tf.GradientTape() as tape:
    tape.watch([w,b])
    #使用softmax计算概率
    prob = tf.nn.softmax(x@w +b, axis=1)
    #使用MSE计算loss
    loss = tf.reduce_mean(tf.losses.MSE(tf.one_hot(y, depth=3), prob))
#求解损失函数的梯度
grads = tape.gradient(loss, [w,b])
print("Gradients of w:\n", grads[0].numpy())
print("Gradients of b:\n", grads[1].numpy())

#交叉熵loss函数及其梯度计算
#参考资料:https://zhuanlan.zhihu.com/p/38241764
#下面例子中:x表示两个样本数据,每个数据是一个长度为4的tensor:[2,4]
x = tf.random.normal([2,4])
#输入数据的维度是4,输出节点我们定义为3维,表示3分类结果
w = tf.random.normal([4,3])
#bias初始化为0
b = tf.zeros([3])
#输出的label值,表示两个样本的真实label的class是2和0
y = tf.constant([2,0])

with tf.GradientTape() as tape:
    tape.watch([w,b])
    #计算logits
    logits = x@w + b
    #使用交叉熵计算loss
    loss = tf.reduce_mean(tf.losses.categorical_crossentropy(tf.one_hot(y, depth=3), logits, from_logits=True))

#求解损失函数的梯度
grads = tape.gradient(loss, [w, b])
print("Gradients of w:\n", grads[0].numpy())
print("Gradients of b:\n", grads[1].numpy())

运行结果:

相关推荐
步步为营DotNet几秒前
Semantic Kernel 在.NET AI 开发中的深度探索与实践
人工智能·.net
AC赳赳老秦2 分钟前
技术文章素材收集自动化:用 OpenClaw 自动爬取行业资讯、技术热点、优质文章
运维·开发语言·python·自动化·wpf·deepseek·openclaw
安全指北针2 分钟前
AI检测 vs 传统SIEM:2026年安全运营效率实测对比
人工智能·安全
SilentSamsara3 分钟前
模型评估与超参调优:交叉验证、Optuna 与模型选择策略
人工智能·python·深度学习·机器学习·青少年编程
一次旅行3 分钟前
【AI工具】Odysseus:GitHub 6万星自托管AI工作空间,隐私优先的本地化AI体验
人工智能·github
网络研究院4 分钟前
利用人工智能破解中世纪密码
人工智能·研究·历史·语言·情报
东方佑4 分钟前
如何证明自然语言是条件随机、递归自指后的分形
人工智能
辰海Coding6 分钟前
MiniSpring框架学习笔记-JDBC 访问框架:如何抽取 JDBC 模板并隔离数据库?
java·数据库·笔记·学习·spring
叫我:松哥8 分钟前
基于LSTM与ARIMA的城市空气质量分析与预测系统
人工智能·python·rnn·算法·机器学习·flask·lstm
指尖在键盘上舞动9 分钟前
RKNN 模型部署:onnx转rknn后精度下降 —— 精度调优与问题排查
python·ubuntu·rk3588·rknn·onnx·npu