TensorFlow学习:在web前端如何使用Keras 模型

前言

在上篇文章 TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调中我们学习了如何使用官方模型,以及使用自己的数据微调模型。

但是吧,代码一直是跑在Python里,而我本身是做前端开发的。我是很想让它在前端进行浏览器里进行运行。

谷歌贴心的为我们准备了 TensorFlow.jsTensorFlow.js 是一个 JavaScript 库,用于在浏览器和 Node.js 训练和部署机器学习模型。

这篇文章我们来学习如何在前端运行模型,模型的话就用上一篇文章里训练的花朵分类模型。

官方文档:TensorFlow.js 官方文档

注: 下面是我的采坑心得,我这是第一次学习,第一次搞。你要是按照我的步骤遇到了其他问题,不要问我,我也不会。

建议按顺序观看,这是一个小系列,适合像我这样的初学者入门

配置环境:windows环境下tensorflow安装

图片分类案例学习:TensorFlow案例学习:对服装图像进行分类

使用官方模型,并进行微调:TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调

将模型转换,在前端使用:TensorFlow学习:在web前端如何使用Keras 模型

学习

处理模型

Keras 模型(通常通过 Python API 创建)可能保存成多种格式之一。完整模型格式可以转换成 TensorFlow.js Layers 格式,这种格式可以直接加载到 TensorFlow.js 中进行推断或进一步训练。

目标 TensorFlow.js Layers 格式是一个包含 model.json 文件和一组二进制格式的分片权重文件的目录。model.json 文件包含模型拓扑(又称"架构"或"计算图":它是对层及其连接方式的描述)和权重文件的清单。

我们上一篇文章训练出的模型就是Keras 模型,这里需要对其进行转换。

安装

javascript 复制代码
pip install tensorflowjs

就是这一步上来就被搞惨了。

最开始下载,还没下载多少就超时了,直接下载不了。后来查到可以使用国内镜像下载

javascript 复制代码
 pip install tensorflowjs  -i https://pypi.tuna.tsinghua.edu.cn/simple

下载速度是变快了很多,结果下载到最后又来一个依赖冲突,最后又下载失败了。最终解决完这个问题是因为在web端使用,也需要下载

javascript 复制代码
npm install @tensorflow/tfjs

当时我在想这两个是不是一个东西啊,问了一下gpt,npm下载的这个还真的可以用来进行模型转换。

这里我建议,即使你pip下载成功了,最好还是使用npm下载的这个进行模型转换。因为这样可以保证tensorflowjs版本一致,避免因为版本问题导致最后使用时又出问题

转换

这个也是个坑啊,文档是这样说的

但是在上一篇文章中,我最后保存的不是.h5格式啊,然后又回去跑模型,最后model.save('my_model.h5'),将模型保存为.h5格式。再然后转换模型

javascript 复制代码
tensorflowjs_converter --input_format=keras flower_model.h5 flower_js_model


看样子是成功了,结果还真没成功,在前端加载时又报错了。没办法,百度查、翻文档。然后看见了这个

还真需要用这个,不过上面的代码有点问题,不需要有\ 符号。正确代码应该是

javascript 复制代码
tensorflowjs_converter --input_format=tf_saved_model   flower_model web_model

这里要注意:

  • 我们还是使用的npm下载的依赖,不是pip下载的依赖
  • --input_format=tf_saved_model,指定输入格式
  • flower_model web_model这是两个路径,前面的是模型的逻辑,后面那个是转换完成后的输出路径
  • 这里加载的模型不是.h5,就是.pb文件所在的文件夹,记住是文件夹,不是目录

总之就是将flower_model下的模型进行转换,将转换后的模型输出到web_model目录下

在前端使用

这里要特别注意对输入图片的处理,一开始就是因为输入图片处理的不正确,导致模型在预测时结果不正确。后来各种查资料,才解决,使用代码如下:

javascript 复制代码
<template>
  <div class="page-container">
    <div class="first-title">
      官方文档:
      <a href="https://tensorflow.google.cn/js/models?hl=zh-cn"
        >https://tensorflow.google.cn/js/models?hl=zh-cn</a
      >
    </div>
    <div class="img-list">
      <img
        v-for="img in imageList"
        :key="img.name"
        :src="img.url"
        :id="img.name"
        :class="activeImg == img.name ? 'img-item img-item-active' : 'img-item'"
        @click="changeImg(img)"
      />
    </div>
    <div style="margin-top: 20px">结果是:{{ result }}</div>
  </div>
</template>

<script setup>
import { ref, onMounted } from "vue";
import * as tf from "@tensorflow/tfjs";

// 图片
const imageList = ref([]);
// 当前选中的图片
const activeImg = ref("f1");
// 结果
const result = ref("");
// 图片列表
const IMAGES = [
  {
    name: "f1",
    url: "../assets/f1.jpg",
  },
  {
    name: "f2",
    url: "../assets/f2.jpg",
  },
  {
    name: "f3",
    url: "../assets/f3.jpg",
  },
  {
    name: "f4",
    url: "../assets/f4.jpg",
  },
];
const IMAGENET_CLASSES = ["雏菊", "蒲公英", "玫瑰", "向日葵", "郁金香"];

onMounted(() => {
  imageList.value = [];

  IMAGES.forEach((item) => {
    import(item.url).then((img) => {
      imageList.value.push({
        name: item.name,
        url: img.default,
      });
    });
  });
});

// 切换图片
const changeImg = async (img) => {
  activeImg.value = img.name;
  // 识别图片
  await identify(img.name);
};

// 识别图片
const identify = async (id) => {
  const imageElement = await document.getElementById(id);
  console.log("图片", imageElement);
  // 载入模型
  const model = await tf.loadGraphModel("../../public/web_model/model.json");
  console.log("模型:", model);

  // 图像预处理
  const imageTensor = preprocessImage(imageElement);

  // 对图片进行预测
  const predictions = await model.predict(imageTensor);

  console.log("predictions:", predictions);

  // 获取预测结果
  const predictedIndex = tf.argMax(predictions, 1).dataSync()[0];
  const predictedLabel = IMAGENET_CLASSES[predictedIndex];
  result.value = predictedLabel;
  console.log("结果:", predictedLabel, predictedIndex);
};

// 图像预处理
const preprocessImage = (img) => {
  // 将图像转换为张量对象并将像素值转换为浮点数类型
  const tensor = tf.browser.fromPixels(img).toFloat();
  // 张量的轴上添加一个维度,以适应模型的输入要求
  const expandedDims = tensor.expandDims();
  // 调整图像的尺寸为224x224,尺寸是模型的要求
  const resizedImg = tf.image.resizeBilinear(expandedDims, [224, 224]);
  // 将像素值归一化到范围[0, 1]之间
  const normalizedImg = tf.div(resizedImg, 255.0);
  // 返回归一化后的图像张量
  return normalizedImg;
};
</script>

<style lang="scss" scoped>
.img-list {
  display: flex;

  .img-item {
    width: 240px;
    height: 180px;
    border-radius: 5px;
    cursor: pointer;
    padding: 10px;
  }

  .img-item-active {
    border: 2px solid red;
  }
}
</style>

最终效果

相关推荐
小鸡吃米…4 天前
基于 TensorFlow 的图像识别
人工智能·python·tensorflow
小鸡吃米…4 天前
TensorFlow - 构建计算图
人工智能·python·tensorflow
A懿轩A4 天前
【2026 最新】TensorFlow 安装配置详细指南 同时讲解安装CPU和GPU版本 小白也能轻松上手!逐步带图超详细展示(Windows 版)
人工智能·windows·python·深度学习·tensorflow
小鸡吃米…5 天前
TensorFlow 实现异或(XOR)运算
人工智能·python·tensorflow·neo4j
小鸡吃米…5 天前
TensorFlow 实现梯度下降优化
人工智能·python·tensorflow·neo4j
甄心爱学习5 天前
【LR逻辑回归】原理以及tensorflow实现
算法·tensorflow·逻辑回归
小鸡吃米…6 天前
TensorFlow 实现多层感知机学习
人工智能·python·tensorflow
小鸡吃米…6 天前
TensorFlow 优化器
人工智能·python·tensorflow
小鸡吃米…7 天前
TensorFlow 模型导出
python·tensorflow·neo4j
Jonathan Star8 天前
Ant Design (antd) Form 组件中必填项的星号(*)从标签左侧移到右侧
人工智能·python·tensorflow