modelscope实时手机检测-通用 模型导出Onnx、修改模型输出维度

目录

说明

代码

其他

参考


说明

本模型为高性能热门应用系列检测模型中的 实时手机检测模型,基于面向工业落地的高性能检测框架DAMOYOLO,其精度和速度超越当前经典的YOLO系列方法。用户使用的时候,仅需要输入一张图像,便可以获得图像中所有手机的坐标信息,并可用于打电话检测等后续应用场景。

本模型为实时手机检测模型,基于检测框架DAMOYOLO-S模型,DAMO-YOLO是一个面向工业落地的目标检测框架,兼顾模型速度与精度,其训练的模型效果超越了目前的一众YOLO系列方法,并且仍然保持极高的推理速度。

代码

复制代码
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

model_id = 'damo/cv_tinynas_object-detection_damoyolo_phone'
input_location = 'image_phone.jpg'

phone_detection = pipeline(Tasks.domain_specific_object_detection, model=model_id)
result = phone_detection(input_location)
print("result is : ", result)

#================模型导出============================
from modelscope.models import Model
import torch

model_id = 'damo/cv_tinynas_object-detection_damoyolo_phone'
model = Model.from_pretrained(model_id)
model.eval()
input = torch.randn((1,3,640,640)).float()
type(model).__call__ = type(model).forward
torch.onnx.export(model, input,'damoyolo_phone.onnx',input_names=["input"],output_names=["output"],opset_version=13)


#================修改模型输出维度=======================
import onnx
from onnx import helper, TensorProto,shape_inference

# 加载原始模型
model = onnx.load("damoyolo_phone.onnx")

# 获取原始输出节点信息
original_output = model.graph.output[0]
print(f"Original output shape: {original_output.type.tensor_type.shape}")

# 创建Transpose节点
transpose_node = helper.make_node(
        "Transpose",
        inputs=[original_output.name],  # 使用原始输出作为输入
        outputs=["transposed_output"],
        perm=[0, 2, 1]  # 交换第1和第2维度
)

# 创建新的输出Tensor
new_output = helper.make_tensor_value_info(
        "transposed_output",
        TensorProto.FLOAT,
        [1, 5, 8400]  # 新的形状
)

# 添加新节点到计算图
model.graph.node.append(transpose_node)

# 替换输出
model.graph.output.remove(original_output)
model.graph.output.insert(0, new_output)

# 检查并修复模型结构
model = shape_inference.infer_shapes(model)

# 保存修改后的模型
onnx.save(model, "model_modified.onnx")
print("Model modified and saved successfully!")
复制代码
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

model_id = 'damo/cv_tinynas_object-detection_damoyolo_phone'
input_location = 'image_phone.jpg'

phone_detection = pipeline(Tasks.domain_specific_object_detection, model=model_id)
result = phone_detection(input_location)
print("result is : ", result)

#================模型导出============================
from modelscope.models import Model
import torch

model_id = 'damo/cv_tinynas_object-detection_damoyolo_phone'
model = Model.from_pretrained(model_id)
model.eval()
input = torch.randn((1,3,640,640)).float()
type(model).__call__ = type(model).forward
torch.onnx.export(model, input,'damoyolo_phone.onnx',input_names=["input"],output_names=["output"],opset_version=13)


#================修改模型输出维度=======================
import onnx
from onnx import helper, TensorProto,shape_inference

# 加载原始模型
model = onnx.load("damoyolo_phone.onnx")

# 获取原始输出节点信息
original_output = model.graph.output[0]
print(f"Original output shape: {original_output.type.tensor_type.shape}")

# 创建Transpose节点
transpose_node = helper.make_node(
        "Transpose",
        inputs=[original_output.name],  # 使用原始输出作为输入
        outputs=["transposed_output"],
        perm=[0, 2, 1]  # 交换第1和第2维度
)

# 创建新的输出Tensor
new_output = helper.make_tensor_value_info(
        "transposed_output",
        TensorProto.FLOAT,
        [1, 5, 8400]  # 新的形状
)

# 添加新节点到计算图
model.graph.node.append(transpose_node)

# 替换输出
model.graph.output.remove(original_output)
model.graph.output.insert(0, new_output)

# 检查并修复模型结构
model = shape_inference.infer_shapes(model)

# 保存修改后的模型
onnx.save(model, "model_modified.onnx")
print("Model modified and saved successfully!")

其他

模型修改后推理实现C# OnnxRuntime DAMO-YOLO 手机检测

参考

https://modelscope.cn/models/iic/cv_tinynas_object-detection_damoyolo_phone/summary

相关推荐
李小白661 分钟前
python 函数
开发语言·python
没书读了35 分钟前
考研复习-线性代数强化-向量组和方程组特征值
python·线性代数·机器学习
做运维的阿瑞1 小时前
Python核心架构深度解析:从解释器原理到GIL机制全面剖析
开发语言·python·架构·系统架构
软件算法开发1 小时前
基于黑翅鸢优化的LSTM深度学习网络模型(BKA-LSTM)的一维时间序列预测算法matlab仿真
深度学习·算法·lstm·时间序列预测·黑翅鸢优化·bka-lstm
AI数据皮皮侠3 小时前
中国上市公司数据(2000-2023年)
大数据·人工智能·python·深度学习·机器学习
ggaofeng5 小时前
深度学习基本函数
人工智能·深度学习
Dxy12393102167 小时前
python如何通过链接下载保存视频
python·spring·音视频
Terio_my7 小时前
Java bean 数据校验
java·开发语言·python
无咎.lsy8 小时前
裸K初级篇 - (一)蜡烛突破信号
python
可触的未来,发芽的智生10 小时前
新奇特:神经网络的集团作战思维,权重共享层的智慧
人工智能·python·神经网络·算法·架构