Tensorflow2.0笔记 - tensor排序操作

本笔记主要记录sort,argsort,以及top_k操作,加上一个求Top K准确度的例子。

import tensorflow as tf
import numpy as np

tf.__version__


#sort,argsort

#对1维的tensor进行排序
tensor = tf.random.shuffle(tf.range(10))
print(tensor)
#升序
print("======tf.sort(direction='ASCENDING'):", tf.sort(tensor, direction='ASCENDING'))
#降序
print("======tf.sort(direction='DESCENDING'):", tf.sort(tensor, direction='DESCENDING'))
#argsort,返回排序后元素对应原始数据元素的index
print("======tf.argsort(direction='DESCENDING'):", tf.argsort(tensor, direction='DESCENDING'))
args = tf.argsort(tensor, direction='DESCENDING')
print("======Max element:", tensor[args[0]])

#多维tensor排序
tensor = tf.random.uniform([3,3], maxval=10, dtype=tf.int32)
print(tensor)

#不带参数,默认升序
print("======tf.sort():", tf.sort(tensor))
#降序
print("======tf.sort(direction='DESCENDING'):", tf.sort(tensor, direction='DESCENDING'))
#argsort
print("======tf.argsort(direction='DESCENDING'):", tf.argsort(tensor, direction='DESCENDING'))

#top_k得到前最大/最小值
tensor = tf.random.uniform([3,3], maxval=10, dtype=tf.int32)
print(tensor)
#top_k返回值主要有indices和values
#indices返回top k个元素的下标数据
#values返回top k个元素的值
#得到最大的前两个元素
topN = tf.math.top_k(tensor, 2)
print("=====Top 2 indices:", topN.indices)
print("=====Top 2 values :", topN.values)

#top-k accuracy
#假设下面的tensor表示各个类别的预测概率信息,真实的标签类别是2(下标)
# tensor = tf.convert_to_tensor([0.1, 0.2, 0.3, 0.4])
# 那么top-1是0.4,对应标签是3,真实标签类别是2,预测错误,top-1预测准确率是0%
# top-2表示返回前两个最有可能的值[0.4,0.3],对应标签是[3,2],top-2预测准确率100%
# 同理,top-3预测准确率100%

#举例说明
#假设下面的tensor为两个样本的预测结果
prob = tf.constant([[0.1, 0.2, 0.7], [0.2, 0.65, 0.15]])
print("=====Probabilities:", prob)
#标签信息,第一个样本真实类别是2, 第二个样本真实类别是0
target = tf.constant([2, 0])

#使用top_k获得预测结果的indices,这个结果就是对应的类别信息
predictedClasses = tf.math.top_k(prob, 3).indices
predictedClasses = tf.transpose(predictedClasses, perm=[1, 0])
#转置后的矩阵,第一行表示两个个样本top 1的预测值(最有可能的类别),第二行表示top 2的预测值(第二可能的类别)
print(predictedClasses)
#将真实值broadcast_to一个3*2的矩阵(1x2 => 3x2)
target = tf.broadcast_to(target, [3,2])
print(target)

#接下来就可以对比preditecdClasses和target
#Predicted       Actual
#[2, 1]          [2,0]   => top1准确度: 1/2 = 50%
#[1, 0]          [2,0]   => top2准确度: 
#                           样本1(第一列前两个元素)和真实的target里有一个能对上,预测正确,计数1
#                           样本2(第2列前两个元素)和真实target的类别有一个能对上,预测正确,计数1
#                           最终结果是: 1+1 / 2(总样本数) = 100%
#[2, 1]          [2,0]   => top3准确度: 100%


#实例,返回topk的准确率函数
#output: 网络输出的预测概率结果,[b, N],batchsize个预测值
#target: 真实的类别,[b]
#topk: 表示要返回哪些topk结果,假设topk = [1, 2, 3],表示要返回top1, top2和top3三个准确度结果
def topKAccuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batchSize = target.shape[0]

    pred = tf.math.top_k(output, maxk).indices
    pred = tf.transpose(pred, perm=[1, 0])
    real = tf.broadcast_to(target, pred.shape)
    #转换为0(False),1(True)二值表示的结果
    correct = tf.equal(pred, real) #correct是一个[k, b]大小的tensor

    result = []
    for k in topk:
        #取出前k行求和除以样本数量
        #取出前k行用reshape进行flatten
        correct_k = tf.cast(tf.reshape(correct[:k], [-1]), dtype=tf.float32)
        #求和
        correct_k = tf.reduce_sum(correct_k)
        accuracy = float(correct_k / batchSize)
        result.append(accuracy)
    return result

#模拟一个10个样本,6个类别的预测结果
output = tf.random.normal([10, 6])
print("=====>Original Output:\n", output.numpy())
#softmax处理,让指定axis的数据转换成元素相加结果为1的数据(概率)
output = tf.math.softmax(output, axis=1)
print("=====>Probability(Softmax Output):\n", output.numpy())
print("=====>Argmax:\n", tf.argmax(output, axis=1).numpy())
#模拟一个真实类别信息,10个样本的真实标签
target = tf.random.uniform([10], maxval=6, dtype=tf.int32)
print("=====>Labels:\n", target.numpy())

accuracies = topKAccuracy(output, target, topk=(1,2,3,4,5,6))
print("Top1 - Top6 Accuracy:\n", accuracies)

运行结果:

相关推荐
YSGZJJ9 分钟前
股指期货的套保策略如何精准选择和规避风险?
人工智能·区块链
无脑敲代码,bug漫天飞11 分钟前
COR 损失函数
人工智能·机器学习
幽兰的天空15 分钟前
Python 中的模式匹配:深入了解 match 语句
开发语言·python
HPC_fac130520678161 小时前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
网易独家音乐人Mike Zhou4 小时前
【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)
c语言·python·单片机·物联网·算法·嵌入式·iot
安静读书4 小时前
Python解析视频FPS(帧率)、分辨率信息
python·opencv·音视频
小陈phd4 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
冰帝海岸4 小时前
01-spring security认证笔记
java·笔记·spring
Guofu_Liao5 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
小二·5 小时前
java基础面试题笔记(基础篇)
java·笔记·python