前言
在上篇文章 TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调中我们学习了如何使用官方模型,以及使用自己的数据微调模型。
但是吧,代码一直是跑在Python里,而我本身是做前端开发的。我是很想让它在前端进行浏览器里进行运行。
谷歌贴心的为我们准备了 TensorFlow.js 。TensorFlow.js 是一个 JavaScript 库,用于在浏览器和 Node.js 训练和部署机器学习模型。
这篇文章我们来学习如何在前端运行模型,模型的话就用上一篇文章里训练的花朵分类模型。
官方文档:TensorFlow.js 官方文档
注: 下面是我的采坑心得,我这是第一次学习,第一次搞。你要是按照我的步骤遇到了其他问题,不要问我,我也不会。
建议按顺序观看,这是一个小系列,适合像我这样的初学者入门
图片分类案例学习:TensorFlow案例学习:对服装图像进行分类
使用官方模型,并进行微调:TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调
将模型转换,在前端使用:TensorFlow学习:在web前端如何使用Keras 模型
学习
处理模型
Keras 模型(通常通过 Python API 创建)可能保存成多种格式之一。完整模型格式可以转换成 TensorFlow.js Layers
格式,这种格式可以直接加载到 TensorFlow.js
中进行推断或进一步训练。
目标 TensorFlow.js Layers
格式是一个包含 model.json
文件和一组二进制格式的分片权重文件的目录。model.json
文件包含模型拓扑(又称"架构"或"计算图":它是对层及其连接方式的描述)和权重文件的清单。
我们上一篇文章训练出的模型就是Keras 模型,这里需要对其进行转换。
安装
javascript
pip install tensorflowjs
就是这一步上来就被搞惨了。
最开始下载,还没下载多少就超时了,直接下载不了。后来查到可以使用国内镜像下载
javascript
pip install tensorflowjs -i https://pypi.tuna.tsinghua.edu.cn/simple
下载速度是变快了很多,结果下载到最后又来一个依赖冲突,最后又下载失败了。最终解决完这个问题是因为在web端使用,也需要下载
javascript
npm install @tensorflow/tfjs
当时我在想这两个是不是一个东西啊,问了一下gpt,npm下载的这个还真的可以用来进行模型转换。
这里我建议,即使你pip下载成功了,最好还是使用npm下载的这个进行模型转换。因为这样可以保证tensorflowjs版本一致,避免因为版本问题导致最后使用时又出问题
转换
这个也是个坑啊,文档是这样说的
但是在上一篇文章中,我最后保存的不是.h5
格式啊,然后又回去跑模型,最后model.save('my_model.h5')
,将模型保存为.h5
格式。再然后转换模型
javascript
tensorflowjs_converter --input_format=keras flower_model.h5 flower_js_model
看样子是成功了,结果还真没成功,在前端加载时又报错了。没办法,百度查、翻文档。然后看见了这个
还真需要用这个,不过上面的代码有点问题,不需要有\
符号。正确代码应该是
javascript
tensorflowjs_converter --input_format=tf_saved_model flower_model web_model
这里要注意:
- 我们还是使用的npm下载的依赖,不是pip下载的依赖
--input_format=tf_saved_model
,指定输入格式flower_model web_model
这是两个路径,前面的是模型的逻辑,后面那个是转换完成后的输出路径- 这里加载的模型不是
.h5
,就是.pb
文件所在的文件夹,记住是文件夹,不是目录
总之就是将flower_model
下的模型进行转换,将转换后的模型输出到web_model
目录下
在前端使用
这里要特别注意对输入图片的处理,一开始就是因为输入图片处理的不正确,导致模型在预测时结果不正确。后来各种查资料,才解决,使用代码如下:
javascript
<template>
<div class="page-container">
<div class="first-title">
官方文档:
<a href="https://tensorflow.google.cn/js/models?hl=zh-cn"
>https://tensorflow.google.cn/js/models?hl=zh-cn</a
>
</div>
<div class="img-list">
<img
v-for="img in imageList"
:key="img.name"
:src="img.url"
:id="img.name"
:class="activeImg == img.name ? 'img-item img-item-active' : 'img-item'"
@click="changeImg(img)"
/>
</div>
<div style="margin-top: 20px">结果是:{{ result }}</div>
</div>
</template>
<script setup>
import { ref, onMounted } from "vue";
import * as tf from "@tensorflow/tfjs";
// 图片
const imageList = ref([]);
// 当前选中的图片
const activeImg = ref("f1");
// 结果
const result = ref("");
// 图片列表
const IMAGES = [
{
name: "f1",
url: "../assets/f1.jpg",
},
{
name: "f2",
url: "../assets/f2.jpg",
},
{
name: "f3",
url: "../assets/f3.jpg",
},
{
name: "f4",
url: "../assets/f4.jpg",
},
];
const IMAGENET_CLASSES = ["雏菊", "蒲公英", "玫瑰", "向日葵", "郁金香"];
onMounted(() => {
imageList.value = [];
IMAGES.forEach((item) => {
import(item.url).then((img) => {
imageList.value.push({
name: item.name,
url: img.default,
});
});
});
});
// 切换图片
const changeImg = async (img) => {
activeImg.value = img.name;
// 识别图片
await identify(img.name);
};
// 识别图片
const identify = async (id) => {
const imageElement = await document.getElementById(id);
console.log("图片", imageElement);
// 载入模型
const model = await tf.loadGraphModel("../../public/web_model/model.json");
console.log("模型:", model);
// 图像预处理
const imageTensor = preprocessImage(imageElement);
// 对图片进行预测
const predictions = await model.predict(imageTensor);
console.log("predictions:", predictions);
// 获取预测结果
const predictedIndex = tf.argMax(predictions, 1).dataSync()[0];
const predictedLabel = IMAGENET_CLASSES[predictedIndex];
result.value = predictedLabel;
console.log("结果:", predictedLabel, predictedIndex);
};
// 图像预处理
const preprocessImage = (img) => {
// 将图像转换为张量对象并将像素值转换为浮点数类型
const tensor = tf.browser.fromPixels(img).toFloat();
// 张量的轴上添加一个维度,以适应模型的输入要求
const expandedDims = tensor.expandDims();
// 调整图像的尺寸为224x224,尺寸是模型的要求
const resizedImg = tf.image.resizeBilinear(expandedDims, [224, 224]);
// 将像素值归一化到范围[0, 1]之间
const normalizedImg = tf.div(resizedImg, 255.0);
// 返回归一化后的图像张量
return normalizedImg;
};
</script>
<style lang="scss" scoped>
.img-list {
display: flex;
.img-item {
width: 240px;
height: 180px;
border-radius: 5px;
cursor: pointer;
padding: 10px;
}
.img-item-active {
border: 2px solid red;
}
}
</style>
最终效果