搭建自己的AI模型应用网站:JavaScript + Flask-Python + ONNX

1. 前言

本文作者以一个前端新手视角,部署自己的神经网络模型作为后端,搭建自己的网站实现应用的实战经历。目前实现的网页应用有:

欢迎大家试用感受,本文将以博客基于GAN的序列号码预测中训练的pytorch模型为例,进行后端和前端搭建,并构建网站,以下是最终成果展示。

网址:http://www.funsound.cn:5002

2. 相关内容

相关知识点和工具语言如下,希望读者有一定的了解

  • 腾讯云服务器
  • Html + JavaScript 进行UI设计
  • pytorch 模型,onnx 模型导出
  • python flask 后端
  • 多进程服务实现并发访问

3. 后端工作

3.1 pytorch 模型转 onnx 模型

ONNX 模型是通用的NN格式,采用onnx格式将在服务器cpu推理上速度更快。

python 复制代码
# 实例化生成器模型
generator = Generator(input_dim, output_dim)

# 加载训练好的生成器模型权重
generator.load_state_dict(torch.load('models/generator_model.pth'))
generator.eval()  # 设置生成器为评估模式

# 导出模型为 ONNX 格式
generator.export_onnx('models/generator_model.onnx', (batch_size, input_dim))

加载onnx模型进行推理

python 复制代码
# 加载 ONNX 模型
ort_session = ort.InferenceSession('models/generator_model.onnx')
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
input_noise = np.random.randn(batch_size, input_dim).astype(np.float32)
generated_numbers = ort_session.run([output_name], {input_name: input_noise})[0]

基于onnx推理的CP号码生成算法封装成 【generator. LOTTO_GENERATOR】

3.2 多进程onnx服务

网站访问往往是一个多路并发访问场景,面对众多用户的请求,送入待处理,后端采用多进程进行调度。

python 复制代码
if __name__ == "__main__":
    from generator import LOTTO_GENERATOR # 我们的gan网络生成算法

    # 初始化worker数量
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    workers = init_workers(nj=nj, backends=backends)

    # 获取并打印所有worker的状态
    res = get_workers_state(workers)
    print(res)

    # 提交100个任务
    worker_dir = "demo"
    for _ in range(100):
        task_id = generate_random_string(length=11)  # 生成长度为11的随机字符串作为task_id
        task_dir = f"{worker_dir}/{task_id}"  # 任务目录
        task_inp = generate_random_number_string(length=8)  # 生成长度为8的随机数字字符串作为任务输入
        task_prgs = f'{task_dir}/progress.txt'  # 任务进度文件路径
        task_rst = f'{task_dir}/result.txt'  # 任务结果文件路径
        
        os.system(f'mkdir -p {task_dir}')  # 创建任务目录
        params = {
            'task_id': task_id,
            'task_inp': task_inp,
            'task_prgs': task_prgs,
            'task_rst': task_rst
        }
        submit_task(workers=workers, params=params)  # 提交任务
        time.sleep(0.01)  # 等待10毫秒后提交下一个任务

注意代码中多进程服务处理用户请求采用异步方式,用户提交任务后获取task_id, 主进程不会阻塞, 用户根据task_id来追踪自己的任务进度(task_prgs)和结果(task_rst)。
其中调度方式根据子进程的忙碌情况决定,选取最闲的子进程处理用户请求

python 复制代码
def submit_task(workers, params: dict):
    # 找到任务最少的worker
    min_task_worker = min(workers, key=lambda worker: worker.queue.qsize() + worker.working.value)
    min_task_worker.queue.put(params)  # 将任务提交到最少任务的worker队列中
    print(f'assign the task to worker-{min_task_worker.wid}'

3.3 基于Flask搭建http访问接口

我们的后端代码如下,例如我们的ip 是 100.100.123,端口试用5002,则构建了以下http访问接口:

http一般格式: 【http://IP地址:端口/路由】

python 复制代码
from flask import Flask, jsonify,render_template,request
from generator import LOTTO_GENERATOR
from workers import *
import datetime
import json 

def get_now_time():
    current_time = datetime.datetime.now()
    return current_time.strftime('%Y-%m-%d %H:%M:%S')

def task_log(text,log_file="TASK.LOG"):
    with open(log_file,'a+') as f:
        print(text,file=f)


app = Flask(__name__)
USER_DIR = "user_data"
TASK_MAP = {}

"""
主页
"""
@app.route('/')
def index():
    return render_template('index.html')


@app.route('/lotto', methods=['POST'])
def predict():
    # 获取客户端信息
    ip = request.remote_addr
    data = request.get_json()

    task_id = ip + "_" + generate_random_string(20)
    user_id = ip
    task_inp = data['luck_num'] # 8位数字字符串 或者 空字符串
    task_dir = "%s/%s/%s" % (USER_DIR, user_id, task_id)
    task_prgs = f'{task_dir}/progress.txt'  # 任务进度文件路径
    task_rst = f'{task_dir}/result.txt'  # 任务结果文件路径
    task_log(f"TIME:{get_now_time()}")
    task_log(f"TASK_ID:{task_id}")
    task_log("")

    # 生成临时文件
    if not os.path.exists(task_dir): os.makedirs(task_dir)
    with open(task_prgs,'wt') as f:
        json.dump([0.0,'start'],f,indent=4)
    TASK_MAP[task_id] = {'task_dir': task_dir,
                         'task_prgs': task_prgs,
                         'task_rst': task_rst, }
    
    # 提交任务
    params = {
            'task_id': task_id,
            'task_inp': task_inp,
            'task_prgs': task_prgs,
            'task_rst': task_rst
        }
    submit_task(workers=workers, params=params)  # 提交任务
    return task_id


"""
获得引擎状态
"""
@app.route('/get_worker_state', methods=['GET'])
def get_worker_state():
    ip = request.remote_addr
    res = {}
    for worker in workers:
        res[worker.wid] = worker.queue.qsize() + worker.working.value
    return res


"""
获得任务进度
"""
@app.route('/get_task_prgs', methods=['POST'])
def get_task_prgs():
    ip = request.remote_addr
    data = request.get_json()
    task_id = data['task_id']
    if task_id not in TASK_MAP:
        return [-1, 'no such task id']
    else:
        task_prgs = TASK_MAP[task_id]['task_prgs']
        with open(task_prgs, 'rt') as f:
            content = json.load(f)
        return content

"""
获得任务结果
"""
@app.route('/get_task_rst', methods=['POST'])
def get_task_rst():
    ip = request.remote_addr
    data = request.get_json()
    task_id = data['task_id']
    if task_id not in TASK_MAP:
        return {}
    else:
        task_rst = TASK_MAP[task_id]['task_rst']
        with open(task_rst, 'rt') as f:
            content = json.load(f)
        return content

if __name__ == '__main__':

    # 初始化worker数量
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    workers = init_workers(nj=nj, backends=backends)
    
    app.run(host='0.0.0.0', port=5002)

这样后端就搭建起来啦,这里有4个onnx 模型在后台监听

3.4 python客户端测试

python 复制代码
import requests
import time
import json

# 定义服务端地址
server_url = 'http://110.110.123:5002' # 你的服务器和端口
headers = {'Content-Type': 'application/json'}

# 检查服务器 Worker 状态
def check_worker_status():
    response = requests.get(f'{server_url}/get_worker_state')
    if response.status_code == 200:
        worker_status = response.json()
        idle_workers = [wid for wid, status in worker_status.items() if status == 0]
        if idle_workers:
            print("Idle workers available:", idle_workers)
            return True
        else:
            print("No idle workers available.")
            return False
    else:
        print("Failed to get worker status.")
        return False

# 提交任务
def submit_task(json_data):
    if not check_worker_status():
        print("No idle workers available. Task submission failed.")
        return None

    response = requests.post(f'{server_url}/lotto', json=json_data)
    if response.status_code == 200:
        task_id = response.text
        print(f"Task submitted successfully. Task ID: {task_id}")
        return task_id
    else:
        print("Failed to submit task.")
        return None

# 轮询任务进度
def poll_task_progress(task_id):
    while True:
        json_data = {'task_id':task_id}
        response = requests.post(f'{server_url}/get_task_prgs', json=json_data)
        if response.status_code == 200:
            progress, text = response.json()
            print(f"Progress: {progress}, Status: {text}")
            if progress == 1:
                print("Task completed successfully.")
                return True
            elif progress == -1:
                print(f"Task failed: {text}")
                return False
        else:
            print("Failed to get task progress.")
            return False
        time.sleep(3)  # 每3秒查询一次

# 获取任务结果
def get_task_result(task_id):
    json_data = {'task_id':task_id}
    response = requests.post(f'{server_url}/get_task_rst', json=json_data)
    if response.status_code == 200:
        result = response.json()
        print("Task result:", result)
        return result
    else:
        print("Failed to get task result.")
        return None


# 主函数
def main():
    json_data = {'luck_num':""}
    # json_data = {'luck_num':"12345678"}

    # 提交TTS任务
    task_id = submit_task(json_data)
    if not task_id:
        return
        
    # 轮询任务进度
    if poll_task_progress(task_id):
        # 获取任务结果
        result = get_task_result(task_id)

if __name__ == "__main__":
    main()

访问成功

4. 前端工作

4.1 JavaScript 访问 http 函数

JavaScript 调用 http端口如下:

html 复制代码
<script>

        /* 提交任务 */
        function submitTask() {
            var button = document.querySelector("button");
            button.disabled = true;
            button.innerText = "正在生成...";

            var useLuckyNumber = document.getElementById("use_lucky_number").checked;
            var luckInput = document.getElementById("luck_input");
            var luckNum = useLuckyNumber ? luckInput.value : "";
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/lotto", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var taskId = xhr.responseText;
                    checkProgress(taskId);
                } else if (xhr.readyState == 4) {
                    button.disabled = false;
                    button.innerText = "生成";
                    alert("任务提交失败,请重试。");
                }
            };
            xhr.send(JSON.stringify({luck_num: luckNum}));
        }

        /* 检查任务进度 */
        function checkProgress(taskId) {
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/get_task_prgs", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var response = JSON.parse(xhr.responseText);
                    var progress = response[0];
                    var status = response[1];
                    // document.getElementById("progress").innerText = "进度: " + progress + ", 状态: " + status;
                    if (progress == 1) {
                        getResult(taskId);
                    } else if (progress == -1) {
                        var button = document.querySelector("button");
                        button.disabled = false;
                        button.innerText = "生成";
                        alert("任务失败: " + status);
                    } else {
                        setTimeout(function() { checkProgress(taskId); }, 3000);
                    }
                }
            };
            xhr.send(JSON.stringify({task_id: taskId}));
        }

        /* 获取任务结果 */
        function getResult(taskId) {
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/get_task_rst", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var response = JSON.parse(xhr.responseText);
                    displayResult(response);
                    var button = document.querySelector("button");
                    button.disabled = false;
                    button.innerText = "生成";
                }
            };
            xhr.send(JSON.stringify({task_id: taskId}));
        }

        /* 显示任务结果 */
        function displayResult(response) {
            var frontNumbers = response.front_numbers;
            var backNumbers = response.back_numbers;
            var resultContainer = document.getElementById("result");
            resultContainer.innerHTML = ""; // 清空之前的结果

            for (var i = 0; i < frontNumbers.length; i++) {
                var lotterySet = document.createElement("div");
                lotterySet.className = "lottery-set";
                
                frontNumbers[i].forEach(function(number) {
                    var numberBall = document.createElement("div");
                    numberBall.className = "number-ball front-ball";
                    numberBall.innerText = number;
                    lotterySet.appendChild(numberBall);
                });

                backNumbers[i].forEach(function(number) {
                    var numberBall = document.createElement("div");
                    numberBall.className = "number-ball back-ball";
                    numberBall.innerText = number;
                    lotterySet.appendChild(numberBall);
                });

                resultContainer.appendChild(lotterySet);
            }
        }
    </script>

4.2 制作网页index.html

注意到Flask提供了网页渲染功能,这样我们可以设计我们的主页

python 复制代码
@app.route('/')
def index():
    return render_template('index.html')

把上述JS脚本放入index.html 就可以访问后端服务啦,具体html的UI显示,由于代码量很大这里不与展示了,感兴趣同学可以根据上述python客户端的访问逻辑试用GPT为你编写index.html,手机端访问效果如下:

5. 最后

上述是个人搭建自己网站部署AI应用的简单过程,完整源码后期整理上传,欢迎大家留言关注~

相关推荐
奋飛11 分钟前
TypeScript系列:第六篇 - 编写高质量的TS类型
javascript·typescript·ts·declare·.d.ts
_WndProc16 分钟前
【Python】Flask网页
开发语言·python·flask
笑衬人心。17 分钟前
初学Spring AI 笔记
人工智能·笔记·spring
互联网搬砖老肖18 分钟前
Python 中如何使用 Conda 管理版本和创建 Django 项目
python·django·conda
sunbyte20 分钟前
50天50个小项目 (Vue3 + Tailwindcss V4) ✨ | ThemeClock(主题时钟)
前端·javascript·css·vue.js·前端框架·tailwindcss
luofeiju27 分钟前
RGB下的色彩变换:用线性代数解构色彩世界
图像处理·人工智能·opencv·线性代数
小飞悟28 分钟前
🎯 什么是模块化?CommonJS 和 ES6 Modules 到底有什么区别?小白也能看懂
前端·javascript·设计
浏览器API调用工程师_Taylor29 分钟前
AOP魔法:一招实现登录弹窗的全局拦截与动态处理
前端·javascript·vue.js
测试者家园29 分钟前
基于DeepSeek和crewAI构建测试用例脚本生成器
人工智能·python·测试用例·智能体·智能化测试·crewai
FogLetter30 分钟前
初识图片懒加载:让网页像"懒人"一样聪明加载
前端·javascript