Pytorch Android 对象识别

背景

过去一段时间,一直在摸索PyTorch这个深度机器学习框架,摸索的方向主要为"对象是识别",以用于智能家居生态中的人文关怀。例如:视力不佳,在家找东西眼神吃力。基于此种原因,通过请教从事AI算法的大师,推荐我使用pytorch,易于上手,方便终端部署运行。

本文章从客户端研发角度,介绍PyTorch在Android系统中集成使用方法。

PyTorch介绍

官网 :github.com/pytorch/pyt...

摘自:pytorch.org/docs/stable...

原文:PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.PyTorch是一个优化过的张量库,用于GPU和CPU进行深度学习

入门

环境准备

核心开发环境, IntelliJ IDEA工具是为了使用YOLOV5工程训练自己的模型

  1. 安装Python 3.9
  2. Android Studio Giraffe | 2022.3.1 Patch 2
  3. IntelliJ IDEA 2021.2.4 (Community Edition) 【用于编辑python】

模型准备

下载源码

git clone github.com/ultralytics...

IntelliJ IDEA导入源码工程

注意阅读 README.md

  1. 安装依赖 : pip install -r requirements.txt

  2. 导出模型

  • 进入根目录
  • 执行导出命令:python3 export.py --weights yolov5s.pt --optimize --include torchscript
  • 最终结果会输出在根目录下,名为"yolov5s.torchscript"

注意点:

  1. 安装依赖,可能无法完全安装,需设置镜像
  2. 执行export.py时,会先下载yolov5s.pt模型保存至根目录,但是大概率无法下载成功,可自行手动下载,然后讲文件放在根目录下即可

Android客户端安装

源码下载

git clone github.com/pytorch/and...

Android Studio 导入源码工程

  1. 导入 "android-demo-app/ObjectDetection" 工程
  2. 复制"yolov5s.torchscript"文件至"android-demo-app/ObjectDetection/app/src/main/assets" 文件夹下
  3. 文件修改

根目录/app/src/main/java/org/pytorch/demo/objectdetection

MainActivity.java 复制代码
mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "yolov5s.torchscript"));

根目录/app

build.gradle 复制代码
......
android {
    namespace 'org.pytorch.demo.objectdetection'
......
}

dependencies {
......
implementation 'org.pytorch:pytorch_android_lite:2.1.0'
implementation 'org.pytorch:pytorch_android_torchvision_lite:2.1.0'
......
}

根目录

build.gradle 复制代码
......
dependencies {
    classpath "com.android.tools.build:gradle:8.1.2"
......

根目录/gradle/wrapper/

gradle-wrapper.properties 复制代码
distributionUrl=https://services.gradle.org/distributions/gradle-8.0-bin.zip

效果

自定义模型训练

目标

仅识别儿童洞洞书和玩具熊

数据准备

使用相机设置1:1的分辨率,对洞洞书和玩具熊进行多角度拍照

数据标注

labelstud 官网 :labelstud.io/guide/get_s...

部分效果

具体使用说明,可自行搜索

标注导出

在1023工程中,选择"Export"菜单,对话框中选择YOLO格式 标注目录结构

yolov5工程准备

工程目录结构

导入标注数据

  1. "根目录/data" 路径下,创建自己的标注数据文件夹,本篇定义为"harvey1022"
  2. 将刚刚导出的标注文件夹"images"和"labels"复制到"harvey1022"文件夹中

配置标注数据

  1. "根目录/data" 路径下,创建自己的"xxx.yaml"配置文件,本篇定义为"harvey1022.yaml"
  2. 编写配置内容

注意:names 中的序号一定要和打标时的序号一致

yaml 复制代码
path: data/harvey1022 
train: images
val: images
test:

names:
  0: dongdongshu
  1: wanjuxiong

配置模型参数

  1. "根目录/models" 路径下,创建自己的模型参数文件"xxxx.yaml",本篇定义为"yolov5s-harvey1022.yaml", 来自yolov5s.yaml文件
  2. 定义类别数量为2

定义2缘于模型只会识别2种物体(玩具熊,洞洞书)

yaml 复制代码
# Parameters
nc: 2  # number of classes
depth_multiple: 0.33  # model depth multiple
......

开始训练

命令: python3 train.py --img 640 --epochs 150 --device cpu --data data/harvey1022.yaml --weights yolov5s.pt --cfg models/yolov5s-harvey1022.yaml

导出手机模型

注意:runs目录存在于根目录,开始训练时,会自动创建。每次训练完成,训练好的模型所在文件位置可能会发生变化,例如,runs/train/exp/weights/best.pt ,runs/train/exp2/weights/best.pt

命令: python3 export.py --weights runs/train/exp/weights/best.pt --optimize --include torchscript

结果: 最终会在根目录下生成一个"best.torchscript"文件,即训练好的模型

集成best.torchscript模型

  1. 将best.torchscript文件复制到ObjectDetection工程中的assets文件下
  2. 更改MainActivity.java文件中的模型名称为"best.torchscript"
  3. 修改 PrePostProcessor.java 中的模型解析参数

修改前

arduino 复制代码
private static int mOutputColumn = 85; // left, top, right, bottom, score and 80 class probability

修改后,因为模型只会识别2种物体,所以left+top+right+bottom+score+nc(2)=7

arduino 复制代码
private static int mOutputColumn = 7; // left, top, right, bottom, score and 80 class probability
  1. 编译,安装

QA

手机端运行模型,使用到两个库是否可以自行编译?

可以,git clone github.com/pytorch/pyt... , 通过Android Studio 导入pytorch/android源码即可编译。

小插曲:官方的test_app使用的是非lite版,因此如果要使用test_app来验证编译好的库,需要编译非lite版本。

相关推荐
wL魔法师3 小时前
【LLM】大模型训练中的稳定性问题
人工智能·pytorch·深度学习·llm
技术小黑9 小时前
Transformer系列 | Pytorch复现Transformer
pytorch·深度学习·transformer
DogDaoDao10 小时前
神经网络稀疏化设计构架方法和原理深度解析
人工智能·pytorch·深度学习·神经网络·大模型·剪枝·网络稀疏
西猫雷婶11 小时前
pytorch基本运算-Python控制流梯度运算
人工智能·pytorch·python·深度学习·神经网络·机器学习
ACEEE12221 天前
Stanford CS336 | Assignment 2 - FlashAttention-v2 Pytorch & Triotn实现
人工智能·pytorch·python·深度学习·机器学习·nlp·transformer
深耕AI2 天前
【PyTorch训练】准确率计算(代码片段拆解)
人工智能·pytorch·python
nuczzz2 天前
pytorch非线性回归
人工智能·pytorch·机器学习·ai
~-~%%2 天前
Moe机制与pytorch实现
人工智能·pytorch·python
Garfield20052 天前
绕过 FlashAttention-2 限制:在 Turing 架构上使用 PyTorch 实现 FlashAttention
pytorch·flashattention·turing·图灵架构·t4·2080ti
深耕AI2 天前
【PyTorch训练】为什么要有 loss.backward() 和 optimizer.step()?
人工智能·pytorch·python