Integrating TensorFlow with Spring Boot to Build an Image Detection Service

1. What is TensorFlow?

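TensorFlow takes its name from tensors (Tensor) flowing (Flow) through a computational graph. It is built on automatic differentiation over that computational graph: besides computing gradients for you, it provides a wide range of common operations (ops, the nodes of the computational graph), common loss functions, and optimization algorithms.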

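  • TensorFlow is an open-source software library for high-performance numerical computation. Its flexible architecture lets users deploy computation to a variety of platforms (CPU, GPU, TPU) and devices (desktops, server clusters, mobile devices, edge devices, and so on). www.tensorflow.org/tutorials/?...

  • TensorFlow is an open-source machine learning library for research and production. It offers APIs for beginners and experts alike to develop for desktop, mobile, web, and cloud environments.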

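  • TensorFlow computes with data flow graphs, so the first step is to build a data flow graph and then feed data into it (data exists in the form of tensors). Nodes in the graph represent mathematical operations, and edges represent the multi-dimensional data arrays, that is, tensors, passed between nodes. During training, tensors keep flowing from one node of the graph to the next, which is where the name TensorFlow comes from. Tensors come in several ranks. A rank-0 tensor is a scalar, a single number such as 1. A rank-1 tensor is a vector, such as the one-dimensional [1, 2, 3]. A rank-2 tensor is a matrix, such as the two-dimensional [[1, 2, 3],[4, 5, 6],[7, 8, 9]]. Rank-3 (three-dimensional) tensors and higher follow the same pattern. The computation is the process of tensors flowing from one end of the graph to the other, a vivid picture of how complex data structures are moved, transferred, analyzed, and processed in an artificial neural network.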

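In machine learning, numerical values usually come in four forms:

  • Scalar: a single number, the smallest unit of computation, such as "1" or "3.2".
  • Vector: a one-dimensional array of scalars, such as [1, 3.2, 4.6].
  • Matrix: a two-dimensional array of scalars.
  • Tensor: a collection of data stored as a (usually) multi-dimensional array, which can be thought of as a higher-dimensional matrix.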
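
To make the idea of rank concrete, here is a minimal sketch in plain Java that treats nested arrays as stand-ins for tensors. It does not use the TensorFlow API; the class name `TensorRanks` and its helper methods are hypothetical names chosen for this illustration.

```java
import java.lang.reflect.Array;
import java.util.Arrays;

public class TensorRanks {

    // Rank = number of array dimensions; a plain (boxed) number has rank 0.
    static int rank(Object value) {
        int rank = 0;
        Class<?> type = value.getClass();
        while (type.isArray()) {
            rank++;
            type = type.getComponentType();
        }
        return rank;
    }

    // Shape of a rectangular nested array, e.g. [3, 3] for a 3x3 matrix.
    static int[] shape(Object value) {
        int[] dims = new int[rank(value)];
        Object current = value;
        for (int i = 0; i < dims.length; i++) {
            dims[i] = Array.getLength(current);
            if (dims[i] == 0) {
                break; // nothing to descend into
            }
            current = Array.get(current, 0);
        }
        return dims;
    }

    public static void main(String[] args) {
        Object scalar = 3.2f;
        Object vector = new float[] {1f, 3.2f, 4.6f};
        Object matrix = new float[][] {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
        Object cube = new float[2][3][4];

        for (Object t : new Object[] {scalar, vector, matrix, cube}) {
            System.out.println("rank=" + rank(t) + " shape=" + Arrays.toString(shape(t)));
        }
    }
}
```

A scalar reports rank 0 with an empty shape, while the matrix reports rank 2 with shape [3, 3], matching the definitions above.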

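Basic TensorFlow concepts

  • Graph: describes the computation. TensorFlow represents a computation as a graph.
  • Tensor: TensorFlow represents data as tensors. Each tensor is a multi-dimensional array.
  • Operation: the nodes in the graph are ops. An op takes zero or more tensors as input, performs its computation, and produces zero or more tensors.
  • Session: a graph has to be run inside a session.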
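
The split between building a graph and running it in a session can be sketched without TensorFlow at all. The toy below is an assumption-laden illustration in plain Java: `TinyGraph`, `Node`, `placeholder`, `constant`, `op`, and `run` are hypothetical names, not TensorFlow APIs. Building the nodes computes nothing; values are produced only when `run` is called with fed inputs.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.DoubleBinaryOperator;

public class TinyGraph {

    // A node knows how to compute its value once inputs are fed.
    interface Node {
        double eval(Map<String, Double> feeds);
    }

    // A named input whose value is supplied at run time.
    static Node placeholder(String name) {
        return feeds -> {
            Double value = feeds.get(name);
            if (value == null) {
                throw new IllegalArgumentException("No value fed for " + name);
            }
            return value;
        };
    }

    static Node constant(double value) {
        return feeds -> value;
    }

    // An op takes input nodes and produces an output node.
    static Node op(DoubleBinaryOperator f, Node a, Node b) {
        return feeds -> f.applyAsDouble(a.eval(feeds), b.eval(feeds));
    }

    // Plays the role of a session: executes the graph for the fetched node.
    static double run(Node fetch, Map<String, Double> feeds) {
        return fetch.eval(feeds);
    }

    public static void main(String[] args) {
        // Graph definition: y = 2 * x + 1 (nothing is computed yet).
        Node x = placeholder("x");
        Node y = op((a, b) -> a + b, op((a, b) -> a * b, constant(2), x), constant(1));

        // Execution: feed x = 3 and fetch y.
        Map<String, Double> feeds = new HashMap<>();
        feeds.put("x", 3.0);
        System.out.println(run(y, feeds)); // prints 7.0
    }
}
```

The service code later in this article follows the same pattern with the real API: it builds a `Graph`, then calls `session.runner().feed(...).fetch(...).run()`.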

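Typical workflow for writing TensorFlow code

  • Define variables and placeholders
  • Write the model equation from the underlying math
  • Define the loss function (cost)
  • Define the optimizer, for example gradient descent with GradientDescentOptimizer
  • Train inside a session with a for loop
  • Save the result with a saver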
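
The workflow above can be sketched in plain Java for the simplest possible model, a line y = w * x + b fitted by gradient descent on a mean squared error cost. This is a hand-written stand-in for what an optimizer such as GradientDescentOptimizer automates, not TensorFlow code; the class `LinearFit`, the sample data, the learning rate, and the epoch count are all assumptions chosen for the illustration.

```java
public class LinearFit {

    // Mean squared error of the line y = w * x + b over the data set.
    static double cost(double w, double b, double[] xs, double[] ys) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double error = (w * xs[i] + b) - ys[i];
            sum += error * error;
        }
        return sum / xs.length;
    }

    // Runs batch gradient descent and returns the learned {w, b}.
    static double[] train(double[] xs, double[] ys, double learningRate, int epochs) {
        double w = 0.0; // variables to learn
        double b = 0.0;
        int n = xs.length;
        for (int epoch = 0; epoch < epochs; epoch++) { // training loop
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < n; i++) {
                double error = (w * xs[i] + b) - ys[i]; // model equation minus target
                gradW += 2 * error * xs[i] / n;         // d(cost)/dw
                gradB += 2 * error / n;                 // d(cost)/db
            }
            w -= learningRate * gradW; // gradient descent update
            b -= learningRate * gradB;
        }
        return new double[] {w, b}; // "saving" here just means returning the parameters
    }

    public static void main(String[] args) {
        // Sample data generated from y = 2x + 1.
        double[] xs = {0, 1, 2, 3, 4};
        double[] ys = {1, 3, 5, 7, 9};
        double[] params = train(xs, ys, 0.05, 2000);
        System.out.printf("w=%.3f b=%.3f cost=%.6f%n",
            params[0], params[1], cost(params[0], params[1], xs, ys));
    }
}
```

With this data the learned parameters end up close to w = 2 and b = 1. In TensorFlow the gradients are derived automatically from the graph instead of being written by hand.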

2. Environment setup

Integration steps

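  1. Model building: first, define and train the deep learning model in TensorFlow. This may involve choosing a suitable network architecture, optimizer, and loss function.
  2. Training data preparation: next, prepare the data used to train and validate the model. This may include data cleaning, labeling, and preprocessing.
  3. REST API design: to interact with the TensorFlow model, create a REST API in Spring Boot. Spring Boot's built-in support, such as Spring MVC or Spring WebFlux, is enough for this.
  4. Model deployment: once training is finished, deploy the model into the Spring Boot application. The trained model is exported, typically from the Python side, as a SavedModel or a frozen graph (.pb file), and the Spring Boot application then loads and runs it through the TensorFlow Java API. This article uses a frozen graph.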

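A few points deserve attention during integration. First, firewall settings can interfere with the network access TensorFlow needs, for example when downloading models or data, so make sure the firewall allows those connections to avoid interrupted downloads or training runs. Second, watch version compatibility. Spring Boot and TensorFlow each follow their own release cycle, and using versions that work together avoids a lot of unnecessary trouble. The project below uses tensorflow-core-platform 0.5.0 and compiles for Java 11.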

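Model download

Building and training the model involves Python code, which is skipped here. If you are interested, you can download the source code and train the model yourself. We will simply download a pre-trained model.

After downloading, extract the archive under the resources directory. The folder name has to match the paths configured in application.yaml, which point to inception-v3/ (with a hyphen).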

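3. Code project

Goal

Implement image detection: the service classifies an uploaded image with the pre-trained Inception v3 model.

pom.xml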

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>springboot-demo</artifactId>
        <groupId>com.et</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>Tensorflow</artifactId>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
    </properties>
    <dependencies>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.5.0</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>jmimemagic</groupId>
            <artifactId>jmimemagic</artifactId>
            <version>0.1.2</version>
        </dependency>
        <dependency>
            <groupId>jakarta.platform</groupId>
            <artifactId>jakarta.jakartaee-api</artifactId>
            <version>9.0.0</version>
        </dependency>
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.16.1</version>
        </dependency>
        <dependency>
            <groupId>org.springframework.restdocs</groupId>
            <artifactId>spring-restdocs-mockmvc</artifactId>
            <scope>test</scope>
        </dependency>

    </dependencies>
</project>

controller

package com.et.tf.api;

import java.io.IOException;

import com.et.tf.service.ClassifyImageService;
import net.sf.jmimemagic.Magic;
import net.sf.jmimemagic.MagicMatch;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
@RequestMapping("/api")
public class AppController {
    @Autowired
    ClassifyImageService classifyImageService;


    @PostMapping(value = "/classify")
    @CrossOrigin(origins = "*")
    public ClassifyImageService.LabelWithProbability classifyImage(@RequestParam MultipartFile file) throws IOException {
        checkImageContents(file);
        return classifyImageService.classifyImage(file.getBytes());
    }

    @RequestMapping(value = "/")
    public String index() {
        return "index";
    }

    private void checkImageContents(MultipartFile file) {
        MagicMatch match;
        try {
            match = Magic.getMagicMatch(file.getBytes());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        String mimeType = match.getMimeType();
        if (!mimeType.startsWith("image")) {
            throw new IllegalArgumentException("Not an image type: " + mimeType);
        }
    }

}
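
The controller relies on jmimemagic to detect the MIME type from the uploaded bytes rather than trusting the file name. The idea behind that kind of check can be sketched with a hand-rolled signature test. The sketch below is a simplified stand-in, not how jmimemagic is implemented: the class `ImageSniffer` and its methods are hypothetical, and it only recognizes the JPEG and PNG signatures.

```java
public class ImageSniffer {

    // Leading "magic" bytes of the two formats this sketch recognizes.
    private static final int[] JPEG = {0xFF, 0xD8, 0xFF};
    private static final int[] PNG = {0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A};

    // True if the data begins with the given signature bytes.
    static boolean startsWith(byte[] data, int[] signature) {
        if (data.length < signature.length) {
            return false;
        }
        for (int i = 0; i < signature.length; i++) {
            if ((data[i] & 0xFF) != signature[i]) {
                return false;
            }
        }
        return true;
    }

    // Returns a MIME type guessed from the content, ignoring any file name.
    static String sniffMimeType(byte[] data) {
        if (startsWith(data, JPEG)) {
            return "image/jpeg";
        }
        if (startsWith(data, PNG)) {
            return "image/png";
        }
        return "application/octet-stream";
    }

    public static void main(String[] args) {
        byte[] jpegHeader = {(byte) 0xFF, (byte) 0xD8, (byte) 0xFF, (byte) 0xE0};
        byte[] text = "hello".getBytes();
        System.out.println(sniffMimeType(jpegHeader)); // prints image/jpeg
        System.out.println(sniffMimeType(text));       // prints application/octet-stream
    }
}
```

Checking content instead of the extension matters here because the uploaded bytes are passed straight to an image decoding op in the service.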

service

package com.et.tf.service;

import jakarta.annotation.PreDestroy;
import java.util.Arrays;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.op.OpScope;
import org.tensorflow.op.Scope;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TString;
import org.tensorflow.types.family.TType;

// Inspired by https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@Service
@Slf4j
public class ClassifyImageService {

    private final Session session;
    private final List<String> labels;
    private final String outputLayer;

    private final int W;
    private final int H;
    private final float mean;
    private final float scale;

    public ClassifyImageService(
        Graph inceptionGraph, List<String> labels, @Value("${tf.outputLayer}") String outputLayer,
        @Value("${tf.image.width}") int imageW, @Value("${tf.image.height}") int imageH,
        @Value("${tf.image.mean}") float mean, @Value("${tf.image.scale}") float scale
    ) {
        this.labels = labels;
        this.outputLayer = outputLayer;
        this.H = imageH;
        this.W = imageW;
        this.mean = mean;
        this.scale = scale;
        this.session = new Session(inceptionGraph);
    }

    public LabelWithProbability classifyImage(byte[] imageBytes) {
        long start = System.currentTimeMillis();
        try (Tensor image = normalizedImageToTensor(imageBytes)) {
            float[] labelProbabilities = classifyImageProbabilities(image);
            int bestLabelIdx = maxIndex(labelProbabilities);
            LabelWithProbability labelWithProbability =
                new LabelWithProbability(labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f, System.currentTimeMillis() - start);
            log.debug(String.format(
                    "Image classification [%s %.2f%%] took %d ms",
                    labelWithProbability.getLabel(),
                    labelWithProbability.getProbability(),
                    labelWithProbability.getElapsed()
                )
            );
            return labelWithProbability;
        }
    }

    private float[] classifyImageProbabilities(Tensor image) {
        try (Tensor result = session.runner().feed("input", image).fetch(outputLayer).run().get(0)) {
            final Shape resultShape = result.shape();
            final long[] rShape = resultShape.asArray();
            if (resultShape.numDimensions() != 2 || rShape[0] != 1) {
                throw new RuntimeException(
                    String.format(
                        "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                        Arrays.toString(rShape)
                    ));
            }
            int nlabels = (int) rShape[1];
            FloatDataBuffer resultFloatBuffer = result.asRawTensor().data().asFloats();
            float[] dst = new float[nlabels];
            resultFloatBuffer.read(dst);
            return dst;
        }
    }

    private int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; ++i) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    private Tensor normalizedImageToTensor(byte[] imageBytes) {
        // The graph and every input tensor hold native memory, so all of them are
        // declared here and released by try-with-resources, including the image tensor.
        try (Graph g = new Graph();
             TString imageTensor = TString.tensorOfBytes(NdArrays.scalarOfObject(imageBytes));
             TInt32 batchTensor = TInt32.scalarOf(0);
             TInt32 sizeTensor = TInt32.vectorOf(H, W);
             TFloat32 meanTensor = TFloat32.scalarOf(mean);
             TFloat32 scaleTensor = TFloat32.scalarOf(scale);
        ) {
            GraphBuilder b = new GraphBuilder(g);
            // Python tutorial: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/label_image
            // Some constants specific to the pre-trained model at:
            // https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
            //
            // - The model was trained with images scaled to 299x299 pixels.
            // - The colors, represented as R, G, B in 1-byte each were converted to
            //   float using (value - Mean)/Scale.

            // Since the graph is being constructed once per execution here, we can use a constant for the
            // input image. If the graph were to be re-used for multiple input images, a placeholder would
            // have been more appropriate.
            final Output input = b.constant("input", imageTensor);
            final Output output =
                b.div(
                    b.sub(
                        b.resizeBilinear(
                            b.expandDims(
                                b.cast(b.decodeJpeg(input, 3), DataType.DT_FLOAT),
                                b.constant("make_batch", batchTensor)
                            ),
                            b.constant("size", sizeTensor)
                        ),
                        b.constant("mean", meanTensor)
                    ),
                    b.constant("scale", scaleTensor)
                );
            try (Session s = new Session(g)) {
                return s.runner().fetch(output.op().name()).run().get(0);
            }
        }
    }

    static class GraphBuilder {
        final Scope scope;

        GraphBuilder(Graph g) {
            this.g = g;
            this.scope = new OpScope(g);
        }

        Output div(Output x, Output y) {
            return binaryOp("Div", x, y);
        }

        Output sub(Output x, Output y) {
            return binaryOp("Sub", x, y);
        }

        Output resizeBilinear(Output images, Output size) {
            return binaryOp("ResizeBilinear", images, size);
        }

        Output expandDims(Output input, Output dim) {
            return binaryOp("ExpandDims", input, dim);
        }

        Output cast(Output value, DataType dtype) {
            return g.opBuilder("Cast", "Cast", scope).addInput(value).setAttr("DstT", dtype).build().output(0);
        }

        Output decodeJpeg(Output contents, long channels) {
            return g.opBuilder("DecodeJpeg", "DecodeJpeg", scope)
                .addInput(contents)
                .setAttr("channels", channels)
                .build()
                .output(0);
        }

        Output<? extends TType> constant(String name, Tensor t) {
            return g.opBuilder("Const", name, scope)
                .setAttr("dtype", t.dataType())
                .setAttr("value", t)
                .build()
                .output(0);
        }

        private Output binaryOp(String type, Output in1, Output in2) {
            return g.opBuilder(type, type, scope).addInput(in1).addInput(in2).build().output(0);
        }

        private final Graph g;
    }

    @PreDestroy
    public void close() {
        session.close();
    }

    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class LabelWithProbability {
        private String label;
        private float probability;
        private long elapsed;
    }
}

application.yaml

tf:
    frozenModelPath: inception-v3/inception_v3_2016_08_28_frozen.pb
    labelsPath: inception-v3/imagenet_slim_labels.txt
    outputLayer: InceptionV3/Predictions/Reshape_1
    image:
        width: 299
        height: 299
        mean: 0
        scale: 255

logging.level.net.sf.jmimemagic: WARN
spring:
  servlet:
    multipart:
      max-file-size: 5MB
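
The `tf.image` settings feed the preprocessing step in the service, which converts each color channel with (value - mean) / scale after resizing to width x height. The plain Java sketch below works through that arithmetic for the configured values (mean 0, scale 255, 299 x 299). The class `PixelNormalization` and its methods are hypothetical helpers written for this illustration; the real service performs the same arithmetic with TensorFlow ops.

```java
public class PixelNormalization {

    // Same formula as the service comment: (value - mean) / scale.
    static float normalize(int channelValue, float mean, float scale) {
        return (channelValue - mean) / scale;
    }

    // Number of floats in one resized RGB image, i.e. a [1, height, width, 3] tensor.
    static long inputElementCount(int width, int height, int channels) {
        return 1L * width * height * channels;
    }

    public static void main(String[] args) {
        float mean = 0f;    // tf.image.mean
        float scale = 255f; // tf.image.scale

        // With mean 0 and scale 255, byte values 0..255 map onto the range 0.0..1.0.
        System.out.println(normalize(0, mean, scale));   // prints 0.0
        System.out.println(normalize(255, mean, scale)); // prints 1.0
        System.out.println(normalize(128, mean, scale)); // roughly 0.5

        System.out.println(inputElementCount(299, 299, 3)); // prints 268203
    }
}
```

These values have to match how the model was trained, which is why they live in configuration next to the model path rather than being hard-coded.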

Application.java

package com.et.tf;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.tensorflow.Graph;
import org.tensorflow.proto.framework.GraphDef;

@SpringBootApplication
@Slf4j
public class Application {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    @Bean
    public Graph tfModelGraph(@Value("${tf.frozenModelPath}") String tfFrozenModelPath) throws IOException {
        Resource graphResource = getResource(tfFrozenModelPath);

        Graph graph = new Graph();
        // Close the model stream once the serialized GraphDef has been parsed.
        try (java.io.InputStream in = graphResource.getInputStream()) {
            graph.importGraphDef(GraphDef.parseFrom(in));
        }
        log.info("Loaded Tensorflow model");
        return graph;
    }

    // Looks for the path on the file system first, then falls back to the classpath.
    // Used for both the model file and the labels file, so it takes a plain path argument.
    private Resource getResource(String path) {
        Resource resource = new FileSystemResource(path);
        if (!resource.exists()) {
            resource = new ClassPathResource(path);
        }
        if (!resource.exists()) {
            throw new IllegalArgumentException(String.format("File %s does not exist", path));
        }
        return resource;
    }

    @Bean
    public List<String> tfModelLabels(@Value("${tf.labelsPath}") String labelsPath) throws IOException {
        Resource labelsRes = getResource(labelsPath);
        try (java.io.InputStream in = labelsRes.getInputStream()) {
            // One label per line; an optional "index:" prefix is stripped.
            List<String> labels = IOUtils.readLines(in, StandardCharsets.UTF_8).stream()
                .map(label -> label.substring(label.contains(":") ? label.indexOf(":") + 1 : 0))
                .collect(Collectors.toList());
            log.info("Loaded {} model labels", labels.size());
            return labels;
        }
    }
}
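
The label loading in `tfModelLabels` keeps everything after the first colon on a line, or the whole line when there is no colon. A small standalone sketch of that rule is shown below. The class `LabelParser` and the sample lines are made up for illustration and are not taken from the real labels file.

```java
import java.util.List;
import java.util.stream.Collectors;

public class LabelParser {

    // Drops an optional "index:" prefix, keeping the text after the first colon.
    static String stripIndexPrefix(String line) {
        int colon = line.indexOf(':');
        return colon >= 0 ? line.substring(colon + 1) : line;
    }

    static List<String> parse(List<String> lines) {
        return lines.stream().map(LabelParser::stripIndexPrefix).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Hypothetical sample lines, with and without an index prefix.
        List<String> lines = List.of("0:background", "tench", "2:goldfish");
        System.out.println(parse(lines)); // prints [background, tench, goldfish]
    }
}
```

The position of each label in the resulting list has to line up with the index of the corresponding probability in the model output, since the service looks labels up by index.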

The above shows only the key code. For the complete source, see the code repository below.

Code repository

4. Testing

Start the Spring Boot application

Test image classification

Open http://127.0.0.1:8080/, upload an image, and click the classify button.
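
The endpoint can also be called without the web page. The sketch below uses the JDK 11 `java.net.http.HttpClient` to POST an image to `/api/classify` as the multipart field `file`, which is the parameter name the controller declares. The class `ClassifyClient`, the boundary string, and the part's `application/octet-stream` content type are assumptions made for this example, and it assumes the application is running locally on port 8080.

```java
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

public class ClassifyClient {

    // Builds a multipart/form-data body with a single part named "file".
    static byte[] multipartBody(String boundary, String fileName, byte[] content) {
        byte[] header = ("--" + boundary + "\r\n"
            + "Content-Disposition: form-data; name=\"file\"; filename=\"" + fileName + "\"\r\n"
            + "Content-Type: application/octet-stream\r\n\r\n").getBytes(StandardCharsets.UTF_8);
        byte[] footer = ("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8);

        ByteArrayOutputStream out = new ByteArrayOutputStream();
        out.write(header, 0, header.length);
        out.write(content, 0, content.length);
        out.write(footer, 0, footer.length);
        return out.toByteArray();
    }

    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.out.println("Usage: java ClassifyClient <image-file>");
            return;
        }
        Path image = Path.of(args[0]);
        String boundary = "----classify-demo-boundary"; // arbitrary, must not occur in the file
        byte[] body = multipartBody(boundary, image.getFileName().toString(), Files.readAllBytes(image));

        HttpRequest request = HttpRequest.newBuilder(URI.create("http://127.0.0.1:8080/api/classify"))
            .header("Content-Type", "multipart/form-data; boundary=" + boundary)
            .POST(HttpRequest.BodyPublishers.ofByteArray(body))
            .build();

        HttpResponse<String> response =
            HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
        // The body is the serialized LabelWithProbability (label, probability, elapsed).
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```

Keep the upload under the 5MB limit set by `spring.servlet.multipart.max-file-size` in application.yaml.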

5. References
