Tensorflow 2.X Debug中的Tensor.numpy问题 @tf.function

我在调试YOLOv3模型过程中想查看get_pred函数下面的get_anchors_and_decode函数里grid_shape的数值

python 复制代码
#---------------------------------------------------#
#   将预测值的每个特征层调成真实值
#---------------------------------------------------#
def get_anchors_and_decode(feats, anchors, num_classes, input_shape, calc_loss=False):
    num_anchors = len(anchors)
    #------------------------------------------#
    #   grid_shape指的是特征层的高和宽
    #------------------------------------------#
    grid_shape = K.shape(feats)[1:3]
    # print(grid_shape)
    #--------------------------------------------------------------------#
    #   获得各个特征点的坐标信息。生成的shape为(13, 13, num_anchors, 2)
    #--------------------------------------------------------------------#

    # with tf.compat.v1.Session() as sess:
    #     sess.run(tf.compat.v1.global_variables_initializer())
    #     grid_shape_value = sess.run(grid_shape)
    print("Grid Shape: ", grid_shape)


    grid_x  = K.tile(K.reshape(K.arange(0, stop=grid_shape[1]), [1, -1, 1, 1]), [grid_shape[0], 1, num_anchors, 1])
    grid_y  = K.tile(K.reshape(K.arange(0, stop=grid_shape[0]), [-1, 1, 1, 1]), [1, grid_shape[1], num_anchors, 1])
    grid    = K.cast(K.concatenate([grid_x, grid_y]), K.dtype(feats))
    #---------------------------------------------------------------#
    #   将先验框进行拓展,生成的shape为(13, 13, num_anchors, 2)
    #---------------------------------------------------------------#
    anchors_tensor = K.reshape(K.constant(anchors), [1, 1, num_anchors, 2])
    anchors_tensor = K.tile(anchors_tensor, [grid_shape[0], grid_shape[1], 1, 1])

    #---------------------------------------------------#
    #   将预测结果调整成(batch_size,13,13,3,85)
    #   85可拆分成4 + 1 + 80
    #   4代表的是中心宽高的调整参数
    #   1代表的是框的置信度
    #   80代表的是种类的置信度
    #---------------------------------------------------#
    feats           = K.reshape(feats, [-1, grid_shape[0], grid_shape[1], num_anchors, num_classes + 5])
    #------------------------------------------#
    #   对先验框进行解码,并进行归一化
    #------------------------------------------#
    box_xy          = (K.sigmoid(feats[..., :2]) + grid) / K.cast(grid_shape[::-1], K.dtype(feats))
    box_wh          = K.exp(feats[..., 2:4]) * anchors_tensor / K.cast(input_shape[::-1], K.dtype(feats))
    #------------------------------------------#
    #   获得预测框的置信度
    #------------------------------------------#
    box_confidence  = K.sigmoid(feats[..., 4:5])
    box_class_probs = K.sigmoid(feats[..., 5:])
    
    #---------------------------------------------------------------------#
    #   在计算loss的时候返回grid, feats, box_xy, box_wh
    #   在预测的时候返回box_xy, box_wh, box_confidence, box_class_probs
    #---------------------------------------------------------------------#
    if calc_loss == True:
        return grid, feats, box_xy, box_wh
    return box_xy, box_wh, box_confidence, box_class_probs
python 复制代码
@tf.function    
def get_pred(self, image_data, input_image_shape):
        out_boxes, out_scores, out_classes = self.yolo_model([image_data, input_image_shape], training=False)
        return out_boxes, out_scores, out_classes

直接打印的话只有这个变量的属性,看不到具体数值:

python 复制代码
Tensor("yolo_eval/strided_slice:0", shape=(2,), dtype=int32)

如果试图用grid_shape.numpy()查看数值呢,又会报错:

python 复制代码
AttributeError: 'Tensor' object has no attribute 'numpy'

根本原因在于get_pred函数调用不是直接调用的,而是通过@tf.function这个特殊的tf装饰器调用的

@tf.function是一个加速程序运行,提升效率的装饰器

把@tf.function去掉不影响程序大致运行逻辑,会取消一些并行加速计算,但是这么做对调试就方便很多

去掉@tf.function之后的get_pred函数:

python 复制代码
def get_pred(self, image_data, input_image_shape):
        out_boxes, out_scores, out_classes = self.yolo_model([image_data, input_image_shape], training=False)
        return out_boxes, out_scores, out_classes

此时,调试运行至get_pred下面调用的get_anchors_and_decode函数,在python控制台输入grid_shape.numpy(),回车:

可以直接查看Tensor的数值了!!!

特别需要注意的一点,在YOLO模型初始化的时候也会进入函数内部,这个时候grid_shape.numpy()依然报错:

需要在真正执行模型预测的时候,进入函数,grid_shape.numpy()才不会报错:

相关推荐
却道天凉_好个秋18 分钟前
OpenCV(二十一):HSV与HSL
人工智能·opencv·计算机视觉
从后端到QT20 分钟前
标量-向量-矩阵-基础知识
人工智能·机器学习·矩阵
新智元21 分钟前
65 岁图灵巨头离职创业!LeCun 愤然与小扎决裂,Meta 巨震
人工智能·openai
机器之心24 分钟前
全球第二、国内第一!钉钉发布DeepResearch多智能体框架,已在真实企业部署
人工智能·openai
新智元30 分钟前
翻译界的 ChatGPT 时刻!Meta 发布新模型,几段示例学会冷门新语言
人工智能·openai
沉默媛32 分钟前
什么是Hinge损失函数
人工智能·损失函数
北青网快讯44 分钟前
声网AI技术赋能,智能客服告别机械式应答
人工智能
机器之心1 小时前
TypeScript超越Python成GitHub上使用最广语言,AI是主要驱动力
人工智能·openai
nju_spy1 小时前
周志华《机器学习导论》第 15 章 规则学习(符号主义学习)
人工智能·机器学习·数理逻辑·序贯覆盖·规则学习·ripper·一阶规则学习
许泽宇的技术分享1 小时前
当 AI 工作流需要“人类智慧“:深度解析 Microsoft Agent Framework 的人工接入机制
人工智能·microsoft