TensorFlow Serving 主要是用于部署和提供机器学习模型的预测服务,而不是一个完整的静态网页的 Web 服务器,不提供静态网页服务功能。而在生产应用中,都会需要网页服务,如何解决Tensorflow Serving的不足呢?
可以使用反向代理服务器来解决这个问题。反向代理服务器位于客户端和一或多个后端服务器之间,作为客户端请求的中介。反向代理服务器接受来自客户端的请求,然后将这些请求转发到一个或多个后端服务器进行处理。处理完成后,反向代理服务器将结果返回给客户端。常用的应用程序服务器,均可配置为Tensorflow Serving服务器的反向代理服务器,如Nginx、Apache、Flask、HAProxy、Caddy 或 Traefik等。
编者之前常用Apache2,尝试配置Apache2作为Tensoflow Serving的反向代理服务器,但花了大量时间,尝试了多种方法均没有解决CORS(Cross-Origin Resource Sharing,跨源资源共享)问题,Apache2始终因为CORS而不能访问Tensorflow Serving。最后尝试了Flask,它非常简单易用,大概5分钟,一次尝试及成功。
Flask 是一个用 Python 编写的轻量级 Web 框架,特别适用于开发中小型应用程序或服务,成为开发小型 Web 应用、RESTful API 和微服务的理想选择。
下面仍然以花卉识别为例来介绍Tensorflow Serving + Flask的编程实践。
文末附完整源代码。
一、Flask 安装及部署
1. 使用 pip 安装 Flask:
pip install Flask
2. Flask 目录结构
以花卉识别为例,它的目录结构规划如下所示:
your_project/
├── app_server.py # 主应用程序文件,包含 Flask 应用的代码
├── static/ # 存放静态文件(如 CSS、JavaScript、图像)
│ └──conf # 存放配置文件
│ └── label_map.json # 花卉标签
│ └──css # 存放 CSS 文件
│ └── styles.css # 示例 CSS 文件
│ └── js # 存放 JavaScript 文件
│ └── ai_flower_flask.js # 花卉识别 JavaScript 文件
├── templates/ # 存放模板文件(如 HTML 文件)
└── index.html # HTML 文件
说明
app_server.py
:主 Flask 应用程序文件,包含所有路由和视图函数。static/
:存放静态文件夹,Web 服务器将这些文件提供给客户端。templates/
:存放模板文件夹,Flask 使用这些模板来渲染 HTML 页面。
3. 创建 Flask 应用
编写服务端程序app_server.py,其中的关键代码解释如下:
(1)创建 Flask 应用实例:
python
app = Flask(__name__, static_folder='static', template_folder='templates')
static_folder
:指定静态文件夹的位置(如 CSS、JavaScript、图像)。template_folder
:指定模板文件夹的位置(如 HTML 文件)。
(2)配置 CORS:
python
CORS(app)
- 允许跨域请求,配置了所有路由都可以接受来自其他域的请求。
(3)路由和视图函数 :
1)首页路由:
```python
@app.route('/')
def index():
return render_template('index.html')
```
渲染 `templates` 文件夹中的 `index.html` 文件并返回给客户端。
2) 静态文件路由:
```python
@app.route('/static/<path:filename>')
def serve_static(filename):
return send_from_directory(app.static_folder, filename)
```
从 `static` 文件夹中提供静态文件(如图片、样式表等)。
3) 花卉识别路由:
```python
@app.route('/predict', methods=['POST'])
def predict():
try:
data = request.json
response = requests.post('http://localhost:8501/v1/models/ai_flower:predict', json=data)
predictions = response.json()
return jsonify(predictions)
except Exception as e:
return jsonify({'error': str(e)}), 500
```
- 从 POST 请求中获取 JSON 数据。
- 将数据转发到 TensorFlow Serving 的预测 API。
- 返回 TensorFlow Serving 的预测结果或错误信息。
4)运行应用:
python
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=True)
- 启动 Flask 开发服务器,监听所有 IP 地址 (
0.0.0.0
) 的 5000 端口。 debug=True
:启用调试模式,便于开发和调试。
4. 运行 Flask 服务器
编写完成app_server.py程序后,执行该程序:
python app_server.py
然后通过浏览器访问 http://<your_server_ip>:5000,将看到 index.html 页面。
二、前端代码
index.html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Image Classification</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!--link rel="stylesheet" href="css/style.css"-->
</head>
<body>
<h1>Image Classification</h1>
<input type="file" id="file-input" />
<button id="upload-button">Upload and Predict</button>
<div id="result"></div>
<script src="/static/js/ai_flower_flask.js" defer></script>
</body>
</html>
ai_flower_flask.js
这段代码用于处理用户上传的图像文件,将其通过 TensorFlow.js 处理后发送到 Flask 后端进行预测,然后将预测结果显示给用户。以下是代码的详细解释:
-
加载标签映射:
javascriptasync function loadLabelMap() { console.log('Label URL:', labelUrl); // 打印 labelUrl const response = await fetch(labelUrl); if (response.ok) { labelMap = await response.json(); } else { console.error('Failed to load label_map.json'); console.log(labelUrl); } } // 调用函数加载标签映射 loadLabelMap();
loadLabelMap
函数:异步加载标签映射文件label_map.json
。成功加载后,将其内容存储在labelMap
中。如果加载失败,则在控制台输出错误信息。loadLabelMap()
:调用该函数以初始化labelMap
。
-
处理图像上传和预测:
javascriptdocument.getElementById('upload-button').addEventListener('click', async () => { const fileInput = document.getElementById('file-input'); const file = fileInput.files[0]; if (!file) { alert("Please select an image file."); return; } // 读取图像文件作为数据 URL const reader = new FileReader(); reader.onload = async (event) => { const imageDataUrl = event.target.result; // 从数据 URL 创建一个 HTMLImageElement const image = new Image(); image.src = imageDataUrl; image.onload = async () => { // 预处理图像到所需的输入大小和格式 const tensorImg = tf.browser.fromPixels(image).toFloat(); const resizedImg = tf.image.resizeBilinear(tensorImg, [224, 224]); const normalizedImg = resizedImg.div(255.0); const batchedImg = normalizedImg.expandDims(0); // 将张量转换为数组 const tensorArray = await batchedImg.array(); // 使用 Fetch API 通过 REST API 将张量数据发送到 Flask 反向代理 try { const response = await fetch(modelUrl, { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ instances: tensorArray }) }); const prediction = await response.json(); console.log(prediction); // 处理并显示预测结果 const outputTensor = prediction.predictions[0]; let results = ''; const top_k = 5; const topIndices = Array.from(outputTensor) .map((confidence, index) => ({ confidence, index })) .sort((a, b) => b.confidence - a.confidence) .slice(0, top_k); topIndices.forEach(({ confidence, index }) => { const labelId = index + 1; // 标签编号从 1 开始 const labelName = labelMap[labelId.toString()] || 'Unknown'; results += `<p>Label: ${labelName}, Confidence: ${confidence.toFixed(4)}</p>`; }); document.getElementById('result').innerHTML = results; } catch (error) { console.error('Error during fetch:', error); } }; }; reader.readAsDataURL(file); });
-
上传按钮点击事件:
- 获取用户选择的文件。如果没有文件,弹出提示。
- 使用
FileReader
读取文件,将其转换为数据 URL。 - 创建一个
HTMLImageElement
,将数据 URL 作为其src
属性,加载图像后处理。
-
图像预处理:
- 使用 TensorFlow.js 的
tf.browser.fromPixels
将图像数据转换为张量。 - 使用
tf.image.resizeBilinear
将图像调整为模型所需的输入大小[224, 224]
。 - 归一化图像像素值至
[0, 1]
范围。 - 扩展维度以匹配批量输入的格式。
- 使用 TensorFlow.js 的
-
发送预测请求:
- 将张量数据转换为数组,并通过
fetch
发送到 Flask 后端进行预测。 - 获取预测结果,并提取前
top_k
个预测结果。 - 将预测结果与标签名称映射到一起,并显示在页面上。
- 将张量数据转换为数组,并通过
-