mediapipe 训练自有图像数据分类

参考:

https://developers.google.com/mediapipe/solutions/customization/image_classifier

https://colab.research.google.com/github/googlesamples/mediapipe/blob/main/examples/customization/image_classifier.ipynb#scrollTo=plvO-YmcQn5g

安装:

复制代码
pip install mediapipe-model-maker  -i http://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com --use-pep517

版本错误情况

1)RuntimeError: File loading is not yet supported on Windows

其中mediapipe版本要大于等于0.10.0;下图中的要升级;不然后续用python 加载文件会报:

2)ImportError: cannot import name 'array_record_module' from 'array_record.python' ;参考:https://blog.csdn.net/LQ_001/article/details/130991571;原因:包依赖关系出现问题,原来版本 tensorflow-datasets==4.9.0

pip install tensorflow-datasets==4.8.3

1、训练代码

复制代码
import os
import tensorflow as tf
assert tf.__version__.startswith('2')

from mediapipe_model_maker import image_classifier

import matplotlib.pyplot as plt




image_path = os.path.join(os.path.dirname(r"C:\Users\loong\Downloads\mediapipe\flower_photos\flower_photos"), 'flower_photos')   ## down data  :https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz


#Review data

labels = []
for i in os.listdir(image_path):
  if os.path.isdir(os.path.join(image_path, i)):
    labels.append(i)
print(labels)

##plt 
NUM_EXAMPLES = 5

for label in labels:
  label_dir = os.path.join(image_path, label)
  example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
  fig, axs = plt.subplots(1, NUM_EXAMPLES, figsize=(10,2))
  for i in range(NUM_EXAMPLES):
    axs[i].imshow(plt.imread(os.path.join(label_dir, example_filenames[i])))
    axs[i].get_xaxis().set_visible(False)
    axs[i].get_yaxis().set_visible(False)
  fig.suptitle(f'Showing {NUM_EXAMPLES} examples for {label}')

plt.show()
复制代码
#Create dataset;训练集、测试集

data = image_classifier.Dataset.from_folder(image_path)
train_data, remaining_data = data.split(0.8)
test_data, validation_data = remaining_data.split(0.5)


## retrain model 训练模型

spec = image_classifier.SupportedModels.MOBILENET_V2    ##有几个预训练模型,需要联网下载
hparams = image_classifier.HParams(export_dir="exported_model")  ##指定模型保存位置
options = image_classifier.ImageClassifierOptions(supported_model=spec, hparams=hparams)
model = image_classifier.ImageClassifier.create(
    train_data = train_data,
    validation_data = validation_data,
    options=options,
)

## 验证模型
loss, acc = model.evaluate(test_data)
print(f'Test loss:{loss}, Test accuracy:{acc}')

##保存模型
model.export_model()

默认训练是10epcos

自定义训练参数

dropout_rate=0.07放到model_options里的,官方代码不大一样可能版本问题

复制代码
## 训练参数自定义更改

hparams=image_classifier.HParams(epochs=10, export_dir="exported_model_2")
model_options = image_classifier.ModelOptions(dropout_rate=0.07)
options = image_classifier.ImageClassifierOptions(supported_model=spec,model_options=model_options, hparams=hparams)
model_2 = image_classifier.ImageClassifier.create(
    train_data = train_data,
    validation_data = validation_data,
    options=options,
)

相关的一些参数


复制代码
##模型压缩
from mediapipe_model_maker import quantization

quantization_config = quantization.QuantizationConfig.for_int8(train_data)
model.export_model(model_name="model_int8.tflite", quantization_config=quantization_config)

从8M缩小到3M左右

其他:

查看训练tebsorboard:

注意ValueError: Duplicate plugins for name projector错误,参考https://blog.csdn.net/weixin_44966641/article/details/123292034;我这里是换了个conda环境重新安装个新的tensorflow解决

复制代码
tensorboard --logdir=.

日志存放默认地址

2、加载推理

参考:https://blog.csdn.net/weixin_42357472/article/details/131322076

复制代码
import mediapipe as mp

BaseOptions = mp.tasks.BaseOptions
ImageClassifier = mp.tasks.vision.ImageClassifier
ImageClassifierOptions = mp.tasks.vision.ImageClassifierOptions
VisionRunningMode = mp.tasks.vision.RunningMode

options = ImageClassifierOptions(
    base_options=BaseOptions(model_asset_path=r"C:\User**ediapipe\model.tflite"),
    max_results=5,
    running_mode=VisionRunningMode.IMAGE)   ##加载模型

classifier = ImageClassifier.create_from_options(options)


# Load the input image from an image file.
mp_image = mp.Image.create_from_file(r"C:\Users\loong\Downloads\sun2.jpg")

# Perform image classification on the provided single image.
classification_result = classifier.classify(mp_image)
classification_result


相关推荐
化作星辰39 分钟前
深度学习_神经网络激活函数
人工智能·深度学习·神经网络
陈天伟教授42 分钟前
人工智能技术- 语音语言- 03 ChatGPT 对话、写诗、写小说
人工智能·chatgpt
llilian_161 小时前
智能数字式毫秒计在实际生活场景中的应用 数字式毫秒计 智能毫秒计
大数据·网络·人工智能
打码人的日常分享1 小时前
基于信创体系政务服务信息化建设方案(PPT)
大数据·服务器·人工智能·信息可视化·架构·政务
硬汉嵌入式2 小时前
专为 MATLAB 优化的 AI 助手MATLAB Copilot
人工智能·matlab·copilot
北京盛世宏博2 小时前
如何利用技术手段来甄选一套档案馆库房安全温湿度监控系统
服务器·网络·人工智能·选择·档案温湿度
搞科研的小刘选手2 小时前
【EI稳定】检索第六届大数据经济与信息化管理国际学术会议(BDEIM 2025)
大数据·人工智能·经济
半吊子全栈工匠2 小时前
软件产品的10个UI设计技巧及AI 辅助
人工智能·ui
机器之心3 小时前
真机RL!最强VLA模型π*0.6来了,机器人在办公室开起咖啡厅
人工智能·openai
机器之心3 小时前
马斯克Grok 4.1低调发布!通用能力碾压其他一切模型
人工智能·openai