基于Tensorflow Serving + Flask的花卉识别编程实践

TensorFlow Serving 主要是用于部署和提供机器学习模型的预测服务,而不是一个完整的静态网页的 Web 服务器,不提供静态网页服务功能。而在生产应用中,都会需要网页服务,如何解决Tensorflow Serving的不足呢?

可以使用反向代理服务器来解决这个问题。反向代理服务器位于客户端和一或多个后端服务器之间,作为客户端请求的中介。反向代理服务器接受来自客户端的请求,然后将这些请求转发到一个或多个后端服务器进行处理。处理完成后,反向代理服务器将结果返回给客户端。常用的应用程序服务器,均可配置为Tensorflow Serving服务器的反向代理服务器,如Nginx、Apache、Flask、HAProxy、Caddy 或 Traefik等。

编者之前常用Apache2,尝试配置Apache2作为Tensoflow Serving的反向代理服务器,但花了大量时间,尝试了多种方法均没有解决CORS(Cross-Origin Resource Sharing,跨源资源共享)问题,Apache2始终因为CORS而不能访问Tensorflow Serving。最后尝试了Flask,它非常简单易用,大概5分钟,一次尝试及成功。

Flask 是一个用 Python 编写的轻量级 Web 框架,特别适用于开发中小型应用程序或服务,成为开发小型 Web 应用、RESTful API 和微服务的理想选择。

下面仍然以花卉识别为例来介绍Tensorflow Serving + Flask的编程实践。

文末附完整源代码。

一、Flask 安装及部署

1. 使用 pip 安装 Flask:

pip install Flask

2. Flask 目录结构

以花卉识别为例,它的目录结构规划如下所示:

your_project/
├── app_server.py						# 主应用程序文件,包含 Flask 应用的代码
	├── static/							# 存放静态文件(如 CSS、JavaScript、图像)
│     └──conf        					# 存放配置文件
│          └── label_map.json       	# 花卉标签
│     └──css        					# 存放 CSS 文件
│          └── styles.css       		# 示例 CSS 文件
│     └── js         					# 存放 JavaScript 文件
│          └── ai_flower_flask.js	    # 花卉识别 JavaScript 文件
├── templates/            				# 存放模板文件(如 HTML 文件)
      └── index.html        			# HTML 文件
说明
  • app_server.py:主 Flask 应用程序文件,包含所有路由和视图函数。
  • static/:存放静态文件夹,Web 服务器将这些文件提供给客户端。
  • templates/:存放模板文件夹,Flask 使用这些模板来渲染 HTML 页面。

3. 创建 Flask 应用

编写服务端程序app_server.py,其中的关键代码解释如下:

(1)创建 Flask 应用实例

python 复制代码
app = Flask(__name__, static_folder='static', template_folder='templates')
  • static_folder:指定静态文件夹的位置(如 CSS、JavaScript、图像)。
  • template_folder:指定模板文件夹的位置(如 HTML 文件)。

(2)配置 CORS

python 复制代码
CORS(app)
  • 允许跨域请求,配置了所有路由都可以接受来自其他域的请求。

(3)路由和视图函数
1)首页路由

 ```python
 @app.route('/')
 def index():
     return render_template('index.html')
 ```
 渲染 `templates` 文件夹中的 `index.html` 文件并返回给客户端。

2) 静态文件路由

 ```python
 @app.route('/static/<path:filename>')
 def serve_static(filename):
     return send_from_directory(app.static_folder, filename)
 ```
 从 `static` 文件夹中提供静态文件(如图片、样式表等)。

3) 花卉识别路由

 ```python
 @app.route('/predict', methods=['POST'])
 def predict():
     try:
         data = request.json
         response = requests.post('http://localhost:8501/v1/models/ai_flower:predict', json=data)
         predictions = response.json()
         return jsonify(predictions)
     except Exception as e:
         return jsonify({'error': str(e)}), 500
 ```
 - 从 POST 请求中获取 JSON 数据。
 - 将数据转发到 TensorFlow Serving 的预测 API。
 - 返回 TensorFlow Serving 的预测结果或错误信息。

4)运行应用

python 复制代码
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)
  • 启动 Flask 开发服务器,监听所有 IP 地址 (0.0.0.0) 的 5000 端口。
  • debug=True:启用调试模式,便于开发和调试。

4. 运行 Flask 服务器

编写完成app_server.py程序后,执行该程序:

python app_server.py

然后通过浏览器访问 http://<your_server_ip>:5000,将看到 index.html 页面。

二、前端代码

index.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Classification</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <!--link rel="stylesheet" href="css/style.css"-->
</head>
<body>
    <h1>Image Classification</h1>
    <input type="file" id="file-input" />
    <button id="upload-button">Upload and Predict</button>
    <div id="result"></div>
    <script src="/static/js/ai_flower_flask.js" defer></script>
</body>
</html>

ai_flower_flask.js

这段代码用于处理用户上传的图像文件,将其通过 TensorFlow.js 处理后发送到 Flask 后端进行预测,然后将预测结果显示给用户。以下是代码的详细解释:

  1. 加载标签映射

    javascript 复制代码
    async function loadLabelMap() {
        console.log('Label URL:', labelUrl); // 打印 labelUrl
        const response = await fetch(labelUrl);
        if (response.ok) {
            labelMap = await response.json();
        } else {
            console.error('Failed to load label_map.json');
            console.log(labelUrl);
        }
    }
    
    // 调用函数加载标签映射
    loadLabelMap();
    • loadLabelMap 函数:异步加载标签映射文件 label_map.json。成功加载后,将其内容存储在 labelMap 中。如果加载失败,则在控制台输出错误信息。
    • loadLabelMap():调用该函数以初始化 labelMap
  2. 处理图像上传和预测

    javascript 复制代码
    document.getElementById('upload-button').addEventListener('click', async () => {
        const fileInput = document.getElementById('file-input');
        const file = fileInput.files[0];
    
        if (!file) {
            alert("Please select an image file.");
            return;
        }
    
        // 读取图像文件作为数据 URL
        const reader = new FileReader();
        reader.onload = async (event) => {
            const imageDataUrl = event.target.result;
            
            // 从数据 URL 创建一个 HTMLImageElement
            const image = new Image();
            image.src = imageDataUrl;
            image.onload = async () => {
                // 预处理图像到所需的输入大小和格式
                const tensorImg = tf.browser.fromPixels(image).toFloat();
                const resizedImg = tf.image.resizeBilinear(tensorImg, [224, 224]); 
                const normalizedImg = resizedImg.div(255.0);
                const batchedImg = normalizedImg.expandDims(0);
    
                // 将张量转换为数组
                const tensorArray = await batchedImg.array();
    
                // 使用 Fetch API 通过 REST API 将张量数据发送到 Flask 反向代理
                try {
                    const response = await fetch(modelUrl, {
                        method: 'POST',
                        headers: { 'Content-Type': 'application/json' },
                        body: JSON.stringify({ instances: tensorArray })
                    });
                    const prediction = await response.json();
                    console.log(prediction);
                    
                    // 处理并显示预测结果
                    const outputTensor = prediction.predictions[0];
                    
                    let results = '';
                    const top_k = 5;
                    const topIndices = Array.from(outputTensor)
                        .map((confidence, index) => ({ confidence, index }))
                        .sort((a, b) => b.confidence - a.confidence)
                        .slice(0, top_k);
    
                    topIndices.forEach(({ confidence, index }) => {
                        const labelId = index + 1; // 标签编号从 1 开始
                        const labelName = labelMap[labelId.toString()] || 'Unknown';
                        results += `<p>Label: ${labelName}, Confidence: ${confidence.toFixed(4)}</p>`;
                    });
    
                    document.getElementById('result').innerHTML = results;
                } catch (error) {
                    console.error('Error during fetch:', error);
                }
            };
        };
        reader.readAsDataURL(file);
    });
    • 上传按钮点击事件

      • 获取用户选择的文件。如果没有文件,弹出提示。
      • 使用 FileReader 读取文件,将其转换为数据 URL。
      • 创建一个 HTMLImageElement,将数据 URL 作为其 src 属性,加载图像后处理。
    • 图像预处理

      • 使用 TensorFlow.js 的 tf.browser.fromPixels 将图像数据转换为张量。
      • 使用 tf.image.resizeBilinear 将图像调整为模型所需的输入大小 [224, 224]
      • 归一化图像像素值至 [0, 1] 范围。
      • 扩展维度以匹配批量输入的格式。
    • 发送预测请求

      • 将张量数据转换为数组,并通过 fetch 发送到 Flask 后端进行预测。
      • 获取预测结果,并提取前 top_k 个预测结果。
      • 将预测结果与标签名称映射到一起,并显示在页面上。

完整源代码

相关推荐
林的快手几秒前
209.长度最小的子数组
java·数据结构·数据库·python·算法·leetcode
从以前15 分钟前
准备考试:解决大学入学考试问题
数据结构·python·算法
Ven%31 分钟前
如何修改pip全局缓存位置和全局安装包存放路径
人工智能·python·深度学习·缓存·自然语言处理·pip
枫欢32 分钟前
将现有环境192.168.1.100中的svn迁移至新服务器192.168.1.4;
服务器·python·svn
测试杂货铺1 小时前
UI自动化测试实战实例
自动化测试·软件测试·python·selenium·测试工具·测试用例·pytest
余~~185381628001 小时前
NFC 碰一碰发视频源码搭建技术详解,支持OEM
开发语言·人工智能·python·音视频
苏三有春2 小时前
PyQt实战——使用python提取JSON数据(十)
python·json·pyqt
allnlei2 小时前
自定义 Celery的logging模块
python·celery
帅逼码农2 小时前
python爬虫代码
开发语言·爬虫·python·安全架构
跟德姆(dom)一起学AI2 小时前
0基础跟德姆(dom)一起学AI 自然语言处理05-文本特征处理
人工智能·python·深度学习·自然语言处理