Tensorflow2.0笔记 - tensor排序操作

本笔记主要记录sort,argsort,以及top_k操作,加上一个求Top K准确度的例子。

复制代码
import tensorflow as tf
import numpy as np

tf.__version__


#sort,argsort

#对1维的tensor进行排序
tensor = tf.random.shuffle(tf.range(10))
print(tensor)
#升序
print("======tf.sort(direction='ASCENDING'):", tf.sort(tensor, direction='ASCENDING'))
#降序
print("======tf.sort(direction='DESCENDING'):", tf.sort(tensor, direction='DESCENDING'))
#argsort,返回排序后元素对应原始数据元素的index
print("======tf.argsort(direction='DESCENDING'):", tf.argsort(tensor, direction='DESCENDING'))
args = tf.argsort(tensor, direction='DESCENDING')
print("======Max element:", tensor[args[0]])

#多维tensor排序
tensor = tf.random.uniform([3,3], maxval=10, dtype=tf.int32)
print(tensor)

#不带参数,默认升序
print("======tf.sort():", tf.sort(tensor))
#降序
print("======tf.sort(direction='DESCENDING'):", tf.sort(tensor, direction='DESCENDING'))
#argsort
print("======tf.argsort(direction='DESCENDING'):", tf.argsort(tensor, direction='DESCENDING'))

#top_k得到前最大/最小值
tensor = tf.random.uniform([3,3], maxval=10, dtype=tf.int32)
print(tensor)
#top_k返回值主要有indices和values
#indices返回top k个元素的下标数据
#values返回top k个元素的值
#得到最大的前两个元素
topN = tf.math.top_k(tensor, 2)
print("=====Top 2 indices:", topN.indices)
print("=====Top 2 values :", topN.values)

#top-k accuracy
#假设下面的tensor表示各个类别的预测概率信息,真实的标签类别是2(下标)
# tensor = tf.convert_to_tensor([0.1, 0.2, 0.3, 0.4])
# 那么top-1是0.4,对应标签是3,真实标签类别是2,预测错误,top-1预测准确率是0%
# top-2表示返回前两个最有可能的值[0.4,0.3],对应标签是[3,2],top-2预测准确率100%
# 同理,top-3预测准确率100%

#举例说明
#假设下面的tensor为两个样本的预测结果
prob = tf.constant([[0.1, 0.2, 0.7], [0.2, 0.65, 0.15]])
print("=====Probabilities:", prob)
#标签信息,第一个样本真实类别是2, 第二个样本真实类别是0
target = tf.constant([2, 0])

#使用top_k获得预测结果的indices,这个结果就是对应的类别信息
predictedClasses = tf.math.top_k(prob, 3).indices
predictedClasses = tf.transpose(predictedClasses, perm=[1, 0])
#转置后的矩阵,第一行表示两个个样本top 1的预测值(最有可能的类别),第二行表示top 2的预测值(第二可能的类别)
print(predictedClasses)
#将真实值broadcast_to一个3*2的矩阵(1x2 => 3x2)
target = tf.broadcast_to(target, [3,2])
print(target)

#接下来就可以对比preditecdClasses和target
#Predicted       Actual
#[2, 1]          [2,0]   => top1准确度: 1/2 = 50%
#[1, 0]          [2,0]   => top2准确度: 
#                           样本1(第一列前两个元素)和真实的target里有一个能对上,预测正确,计数1
#                           样本2(第2列前两个元素)和真实target的类别有一个能对上,预测正确,计数1
#                           最终结果是: 1+1 / 2(总样本数) = 100%
#[2, 1]          [2,0]   => top3准确度: 100%


#实例,返回topk的准确率函数
#output: 网络输出的预测概率结果,[b, N],batchsize个预测值
#target: 真实的类别,[b]
#topk: 表示要返回哪些topk结果,假设topk = [1, 2, 3],表示要返回top1, top2和top3三个准确度结果
def topKAccuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batchSize = target.shape[0]

    pred = tf.math.top_k(output, maxk).indices
    pred = tf.transpose(pred, perm=[1, 0])
    real = tf.broadcast_to(target, pred.shape)
    #转换为0(False),1(True)二值表示的结果
    correct = tf.equal(pred, real) #correct是一个[k, b]大小的tensor

    result = []
    for k in topk:
        #取出前k行求和除以样本数量
        #取出前k行用reshape进行flatten
        correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
        #求和
        correct_k = tf.reduce_sum(correct_k)
        accuracy = float(correct_k / batchSize)
        result.append(accuracy)
    return result

#模拟一个10个样本,6个类别的预测结果
output = tf.random.normal([10, 6])
print("=====>Original Output:\n", output.numpy())
#softmax处理,让指定axis的数据转换成元素相加结果为1的数据(概率)
output = tf.math.softmax(output, axis=1)
print("=====>Probability(Softmax Output):\n", output.numpy())
print("=====>Argmax:\n", tf.argmax(output, axis=1).numpy())
#模拟一个真实类别信息,10个样本的真实标签
target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
print("=====>Labels:\n", target.numpy())

accuracies = topKAccuracy(output, target, topk=(1,2,3,4,5,6))
print("Top1 - Top6 Accuracy:\n", accuracies)

运行结果:

相关推荐
微爱帮监所写信寄信8 小时前
微爱帮监狱寄信写信工具用户头像安全审核体系
人工智能
熬夜敲代码的小N8 小时前
AI文本分类实战:从数据预处理到模型部署全流程解析
人工智能·分类·数据挖掘
沛沛老爹8 小时前
Web开发者快速上手AI Agent:Dify本地化部署与提示词优化实战
前端·人工智能·rag·faq·文档细粒度
国科安芯8 小时前
低轨卫星边缘计算节点的抗辐照MCU选型分析
人工智能·单片机·嵌入式硬件·架构·边缘计算·安全威胁分析·安全性测试
美团技术团队8 小时前
2025 美团技术团队热门技术文章汇总
人工智能
GEO AI搜索优化助手8 小时前
生成式AI搜索的跨行业革命与商业模式重构
大数据·人工智能·搜索引擎·重构·生成式引擎优化·ai优化·geo搜索优化
张拭心8 小时前
"氛围编程"程序员被解雇了
android·前端·人工智能
我是人机不吃鸭梨8 小时前
Flutter AI 集成革命(2025版):从 Gemini 模型到智能表单验证器的终极方案
开发语言·javascript·人工智能·flutter·microsoft·架构
晨光32118 小时前
Day42 图像数据与显存
深度学习
编码小哥8 小时前
OpenCV阈值分割技术:全局阈值与自适应阈值
人工智能·opencv·计算机视觉