Tensorflow2.0笔记 - 均方差MSE和交叉熵CROSS ENTROPHY作为损失函数

本笔记主要记录使用MSE和交叉熵作为loss function时的梯度计算方法。

复制代码
import tensorflow as tf
import numpy as np

tf.__version__


#softmax函数使用
#参考资料:https://blog.csdn.net/u013230189/article/details/82835717
#简单例子:
#假设输出的LOGITS SCORE为:
#2.0
#1.0
#0.1
#使用softmax后,可以按照score得分高低来转换成概率大小,所有输出值相加为1:
#2.0     0.7
#1.0  -> 0.2
#0.1     0.1

#MSE均方差loss函数及其梯度计算
#参考资料:https://zhuanlan.zhihu.com/p/35707643
#下面例子中:x表示两个样本数据,每个数据是一个长度为4的tensor:[2,4]
x = tf.random.normal([2,4])
#输入数据的维度是4,输出节点我们定义为3维,表示3分类结果
w = tf.random.normal([4,3])
#bias初始化为0
b = tf.zeros([3])
#输出的label值,表示两个样本的真实label的class是2和0
y = tf.constant([2,0])

with tf.GradientTape() as tape:
    tape.watch([w,b])
    #使用softmax计算概率
    prob = tf.nn.softmax(x@w +b, axis=1)
    #使用MSE计算loss
    loss = tf.reduce_mean(tf.losses.MSE(tf.one_hot(y, depth=3), prob))
#求解损失函数的梯度
grads = tape.gradient(loss, [w,b])
print("Gradients of w:\n", grads[0].numpy())
print("Gradients of b:\n", grads[1].numpy())

#交叉熵loss函数及其梯度计算
#参考资料:https://zhuanlan.zhihu.com/p/38241764
#下面例子中:x表示两个样本数据,每个数据是一个长度为4的tensor:[2,4]
x = tf.random.normal([2,4])
#输入数据的维度是4,输出节点我们定义为3维,表示3分类结果
w = tf.random.normal([4,3])
#bias初始化为0
b = tf.zeros([3])
#输出的label值,表示两个样本的真实label的class是2和0
y = tf.constant([2,0])

with tf.GradientTape() as tape:
    tape.watch([w,b])
    #计算logits
    logits = x@w + b
    #使用交叉熵计算loss
    loss = tf.reduce_mean(tf.losses.categorical_crossentropy(tf.one_hot(y, depth=3), logits, from_logits=True))

#求解损失函数的梯度
grads = tape.gradient(loss, [w, b])
print("Gradients of w:\n", grads[0].numpy())
print("Gradients of b:\n", grads[1].numpy())

运行结果:

相关推荐
workflower11 小时前
需求-技术需求
python·测试用例·需求分析·软件需求
HAREWORK_FFF11 小时前
非技术岗位与AI岗位的能力映射与转型成功概率评估
人工智能
tq108611 小时前
agent 记忆 = markdown + json + git
人工智能·git
石臻臻的杂货铺11 小时前
Codex + Claude Code + 一个编排器:独立开发者的「一人军团」实战手册
人工智能
壹通GEO11 小时前
GEO数据分析不再难:1键生成归因热力图+预警报告
人工智能·数据挖掘·数据分析
ding_zhikai11 小时前
【Web应用开发笔记】Django笔记3-2:部署我的简陋网页
笔记·后端·python·django
肾透侧视攻城狮11 小时前
《TensorFlow生态全景图:核心组件、扩展工具与工业级应用深度解读》
人工智能·深度学习·tensorflow生态系统·tfcore/.js/lite·tf extended/hub·tf serving·生态系统优势对比
山岚的运维笔记11 小时前
SQL Server笔记 -- 第86章:查询存储
笔记·python·sql·microsoft·sqlserver·flask
两万五千个小时11 小时前
构建mini Claude Code:11 - 从「被动等待」到「主动找活」
人工智能·python·架构
朴实赋能11 小时前
当情绪可以被看见:AI手环如何成为青少年心理的“预警哨”?
人工智能·发疯可耻但有用·男生女生一起愁愁愁·不上称的倔强·沉默的忧郁#情绪消费·爱你老己·新型校园攀