mediapipe 训练自有图像数据分类

参考:

https://developers.google.com/mediapipe/solutions/customization/image_classifier

https://colab.research.google.com/github/googlesamples/mediapipe/blob/main/examples/customization/image_classifier.ipynb#scrollTo=plvO-YmcQn5g

安装:

pip install mediapipe-model-maker  -i http://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com --use-pep517

版本错误情况

1)RuntimeError: File loading is not yet supported on Windows

其中mediapipe版本要大于等于0.10.0;下图中的要升级;不然后续用python 加载文件会报:

2)ImportError: cannot import name 'array_record_module' from 'array_record.python' ;参考:https://blog.csdn.net/LQ_001/article/details/130991571;原因:包依赖关系出现问题,原来版本 tensorflow-datasets==4.9.0

pip install tensorflow-datasets==4.8.3

1、训练代码

import os
import tensorflow as tf
assert tf.__version__.startswith('2')

from mediapipe_model_maker import image_classifier

import matplotlib.pyplot as plt




image_path = os.path.join(os.path.dirname(r"C:\Users\loong\Downloads\mediapipe\flower_photos\flower_photos"), 'flower_photos')   ## down data  :https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


#Review data

labels = []
for i in os.listdir(image_path):
  if os.path.isdir(os.path.join(image_path, i)):
    labels.append(i)
print(labels)

##plt 
NUM_EXAMPLES = 5

for label in labels:
  label_dir = os.path.join(image_path, label)
  example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
  fig, axs = plt.subplots(1, NUM_EXAMPLES, figsize=(10,2))
  for i in range(NUM_EXAMPLES):
    axs[i].imshow(plt.imread(os.path.join(label_dir, example_filenames[i])))
    axs[i].get_xaxis().set_visible(False)
    axs[i].get_yaxis().set_visible(False)
  fig.suptitle(f'Showing {NUM_EXAMPLES} examples for {label}')

plt.show()
#Create dataset;训练集、测试集

data = image_classifier.Dataset.from_folder(image_path)
train_data, remaining_data = data.split(0.8)
test_data, validation_data = remaining_data.split(0.5)


## retrain model 训练模型

spec = image_classifier.SupportedModels.MOBILENET_V2    ##有几个预训练模型,需要联网下载
hparams = image_classifier.HParams(export_dir="exported_model")  ##指定模型保存位置
options = image_classifier.ImageClassifierOptions(supported_model=spec, hparams=hparams)
model = image_classifier.ImageClassifier.create(
    train_data = train_data,
    validation_data = validation_data,
    options=options,
)

## 验证模型
loss, acc = model.evaluate(test_data)
print(f'Test loss:{loss}, Test accuracy:{acc}')

##保存模型
model.export_model()

默认训练是10epcos

自定义训练参数

dropout_rate=0.07放到model_options里的,官方代码不大一样可能版本问题

## 训练参数自定义更改

hparams=image_classifier.HParams(epochs=10, export_dir="exported_model_2")
model_options = image_classifier.ModelOptions(dropout_rate=0.07)
options = image_classifier.ImageClassifierOptions(supported_model=spec,model_options=model_options, hparams=hparams)
model_2 = image_classifier.ImageClassifier.create(
    train_data = train_data,
    validation_data = validation_data,
    options=options,
)

相关的一些参数


##模型压缩
from mediapipe_model_maker import quantization

quantization_config = quantization.QuantizationConfig.for_int8(train_data)
model.export_model(model_name="model_int8.tflite", quantization_config=quantization_config)

从8M缩小到3M左右

其他:

查看训练tebsorboard:

注意ValueError: Duplicate plugins for name projector错误,参考https://blog.csdn.net/weixin_44966641/article/details/123292034;我这里是换了个conda环境重新安装个新的tensorflow解决

tensorboard --logdir=.

日志存放默认地址

2、加载推理

参考:https://blog.csdn.net/weixin_42357472/article/details/131322076

import mediapipe as mp

BaseOptions = mp.tasks.BaseOptions
ImageClassifier = mp.tasks.vision.ImageClassifier
ImageClassifierOptions = mp.tasks.vision.ImageClassifierOptions
VisionRunningMode = mp.tasks.vision.RunningMode

options = ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path=r"C:\User**ediapipe\model.tflite"),
    max_results=5,
    running_mode=VisionRunningMode.IMAGE)   ##加载模型

classifier = ImageClassifier.create_from_options(options)


# Load the input image from an image file.
mp_image = mp.Image.create_from_file(r"C:\Users\loong\Downloads\sun2.jpg")

# Perform image classification on the provided single image.
classification_result = classifier.classify(mp_image)
classification_result


相关推荐
学习前端的小z几秒前
【AIGC】如何通过ChatGPT轻松制作个性化GPTs应用
人工智能·chatgpt·aigc
埃菲尔铁塔_CV算法28 分钟前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR28 分钟前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️35 分钟前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
好喜欢吃红柚子1 小时前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
小馒头学python1 小时前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
神奇夜光杯1 小时前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
正义的彬彬侠1 小时前
《XGBoost算法的原理推导》12-14决策树复杂度的正则化项 公式解析
人工智能·决策树·机器学习·集成学习·boosting·xgboost
Debroon1 小时前
RuleAlign 规则对齐框架:将医生的诊断规则形式化并注入模型,无需额外人工标注的自动对齐方法
人工智能
羊小猪~~1 小时前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习