搭建自己的AI模型应用网站:JavaScript + Flask-Python + ONNX

1. 前言

本文作者以一个前端新手视角,部署自己的神经网络模型作为后端,搭建自己的网站实现应用的实战经历。目前实现的网页应用有:

欢迎大家试用感受,本文将以博客基于GAN的序列号码预测中训练的pytorch模型为例,进行后端和前端搭建,并构建网站,以下是最终成果展示。

网址:http://www.funsound.cn:5002

2. 相关内容

相关知识点和工具语言如下,希望读者有一定的了解

  • 腾讯云服务器
  • Html + JavaScript 进行UI设计
  • pytorch 模型,onnx 模型导出
  • python flask 后端
  • 多进程服务实现并发访问

3. 后端工作

3.1 pytorch 模型转 onnx 模型

ONNX 模型是通用的NN格式,采用onnx格式将在服务器cpu推理上速度更快。

python 复制代码
# 实例化生成器模型
generator = Generator(input_dim, output_dim)

# 加载训练好的生成器模型权重
generator.load_state_dict(torch.load('models/generator_model.pth'))
generator.eval()  # 设置生成器为评估模式

# 导出模型为 ONNX 格式
generator.export_onnx('models/generator_model.onnx', (batch_size, input_dim))

加载onnx模型进行推理

python 复制代码
# 加载 ONNX 模型
ort_session = ort.InferenceSession('models/generator_model.onnx')
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
input_noise = np.random.randn(batch_size, input_dim).astype(np.float32)
generated_numbers = ort_session.run([output_name], {input_name: input_noise})[0]

基于onnx推理的CP号码生成算法封装成 【generator. LOTTO_GENERATOR】

3.2 多进程onnx服务

网站访问往往是一个多路并发访问场景,面对众多用户的请求,送入待处理,后端采用多进程进行调度。

python 复制代码
if __name__ == "__main__":
    from generator import LOTTO_GENERATOR # 我们的gan网络生成算法

    # 初始化worker数量
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    workers = init_workers(nj=nj, backends=backends)

    # 获取并打印所有worker的状态
    res = get_workers_state(workers)
    print(res)

    # 提交100个任务
    worker_dir = "demo"
    for _ in range(100):
        task_id = generate_random_string(length=11)  # 生成长度为11的随机字符串作为task_id
        task_dir = f"{worker_dir}/{task_id}"  # 任务目录
        task_inp = generate_random_number_string(length=8)  # 生成长度为8的随机数字字符串作为任务输入
        task_prgs = f'{task_dir}/progress.txt'  # 任务进度文件路径
        task_rst = f'{task_dir}/result.txt'  # 任务结果文件路径
        
        os.system(f'mkdir -p {task_dir}')  # 创建任务目录
        params = {
            'task_id': task_id,
            'task_inp': task_inp,
            'task_prgs': task_prgs,
            'task_rst': task_rst
        }
        submit_task(workers=workers, params=params)  # 提交任务
        time.sleep(0.01)  # 等待10毫秒后提交下一个任务

注意代码中多进程服务处理用户请求采用异步方式,用户提交任务后获取task_id, 主进程不会阻塞, 用户根据task_id来追踪自己的任务进度(task_prgs)和结果(task_rst)。
其中调度方式根据子进程的忙碌情况决定,选取最闲的子进程处理用户请求

python 复制代码
def submit_task(workers, params: dict):
    # 找到任务最少的worker
    min_task_worker = min(workers, key=lambda worker: worker.queue.qsize() + worker.working.value)
    min_task_worker.queue.put(params)  # 将任务提交到最少任务的worker队列中
    print(f'assign the task to worker-{min_task_worker.wid}'

3.3 基于Flask搭建http访问接口

我们的后端代码如下,例如我们的ip 是 100.100.123,端口试用5002,则构建了以下http访问接口:

http一般格式: 【http://IP地址:端口/路由】

python 复制代码
from flask import Flask, jsonify,render_template,request
from generator import LOTTO_GENERATOR
from workers import *
import datetime
import json 

def get_now_time():
    current_time = datetime.datetime.now()
    return current_time.strftime('%Y-%m-%d %H:%M:%S')

def task_log(text,log_file="TASK.LOG"):
    with open(log_file,'a+') as f:
        print(text,file=f)


app = Flask(__name__)
USER_DIR = "user_data"
TASK_MAP = {}

"""
主页
"""
@app.route('/')
def index():
    return render_template('index.html')


@app.route('/lotto', methods=['POST'])
def predict():
    # 获取客户端信息
    ip = request.remote_addr
    data = request.get_json()

    task_id = ip + "_" + generate_random_string(20)
    user_id = ip
    task_inp = data['luck_num'] # 8位数字字符串 或者 空字符串
    task_dir = "%s/%s/%s" % (USER_DIR, user_id, task_id)
    task_prgs = f'{task_dir}/progress.txt'  # 任务进度文件路径
    task_rst = f'{task_dir}/result.txt'  # 任务结果文件路径
    task_log(f"TIME:{get_now_time()}")
    task_log(f"TASK_ID:{task_id}")
    task_log("")

    # 生成临时文件
    if not os.path.exists(task_dir): os.makedirs(task_dir)
    with open(task_prgs,'wt') as f:
        json.dump([0.0,'start'],f,indent=4)
    TASK_MAP[task_id] = {'task_dir': task_dir,
                         'task_prgs': task_prgs,
                         'task_rst': task_rst, }
    
    # 提交任务
    params = {
            'task_id': task_id,
            'task_inp': task_inp,
            'task_prgs': task_prgs,
            'task_rst': task_rst
        }
    submit_task(workers=workers, params=params)  # 提交任务
    return task_id


"""
获得引擎状态
"""
@app.route('/get_worker_state', methods=['GET'])
def get_worker_state():
    ip = request.remote_addr
    res = {}
    for worker in workers:
        res[worker.wid] = worker.queue.qsize() + worker.working.value
    return res


"""
获得任务进度
"""
@app.route('/get_task_prgs', methods=['POST'])
def get_task_prgs():
    ip = request.remote_addr
    data = request.get_json()
    task_id = data['task_id']
    if task_id not in TASK_MAP:
        return [-1, 'no such task id']
    else:
        task_prgs = TASK_MAP[task_id]['task_prgs']
        with open(task_prgs, 'rt') as f:
            content = json.load(f)
        return content

"""
获得任务结果
"""
@app.route('/get_task_rst', methods=['POST'])
def get_task_rst():
    ip = request.remote_addr
    data = request.get_json()
    task_id = data['task_id']
    if task_id not in TASK_MAP:
        return {}
    else:
        task_rst = TASK_MAP[task_id]['task_rst']
        with open(task_rst, 'rt') as f:
            content = json.load(f)
        return content

if __name__ == '__main__':

    # 初始化worker数量
    nj = 4
    backends = [LOTTO_GENERATOR() for _ in range(nj)]
    workers = init_workers(nj=nj, backends=backends)
    
    app.run(host='0.0.0.0', port=5002)

这样后端就搭建起来啦,这里有4个onnx 模型在后台监听

3.4 python客户端测试

python 复制代码
import requests
import time
import json

# 定义服务端地址
server_url = 'http://110.110.123:5002' # 你的服务器和端口
headers = {'Content-Type': 'application/json'}

# 检查服务器 Worker 状态
def check_worker_status():
    response = requests.get(f'{server_url}/get_worker_state')
    if response.status_code == 200:
        worker_status = response.json()
        idle_workers = [wid for wid, status in worker_status.items() if status == 0]
        if idle_workers:
            print("Idle workers available:", idle_workers)
            return True
        else:
            print("No idle workers available.")
            return False
    else:
        print("Failed to get worker status.")
        return False

# 提交任务
def submit_task(json_data):
    if not check_worker_status():
        print("No idle workers available. Task submission failed.")
        return None

    response = requests.post(f'{server_url}/lotto', json=json_data)
    if response.status_code == 200:
        task_id = response.text
        print(f"Task submitted successfully. Task ID: {task_id}")
        return task_id
    else:
        print("Failed to submit task.")
        return None

# 轮询任务进度
def poll_task_progress(task_id):
    while True:
        json_data = {'task_id':task_id}
        response = requests.post(f'{server_url}/get_task_prgs', json=json_data)
        if response.status_code == 200:
            progress, text = response.json()
            print(f"Progress: {progress}, Status: {text}")
            if progress == 1:
                print("Task completed successfully.")
                return True
            elif progress == -1:
                print(f"Task failed: {text}")
                return False
        else:
            print("Failed to get task progress.")
            return False
        time.sleep(3)  # 每3秒查询一次

# 获取任务结果
def get_task_result(task_id):
    json_data = {'task_id':task_id}
    response = requests.post(f'{server_url}/get_task_rst', json=json_data)
    if response.status_code == 200:
        result = response.json()
        print("Task result:", result)
        return result
    else:
        print("Failed to get task result.")
        return None


# 主函数
def main():
    json_data = {'luck_num':""}
    # json_data = {'luck_num':"12345678"}

    # 提交TTS任务
    task_id = submit_task(json_data)
    if not task_id:
        return
        
    # 轮询任务进度
    if poll_task_progress(task_id):
        # 获取任务结果
        result = get_task_result(task_id)

if __name__ == "__main__":
    main()

访问成功

4. 前端工作

4.1 JavaScript 访问 http 函数

JavaScript 调用 http端口如下:

html 复制代码
<script>

        /* 提交任务 */
        function submitTask() {
            var button = document.querySelector("button");
            button.disabled = true;
            button.innerText = "正在生成...";

            var useLuckyNumber = document.getElementById("use_lucky_number").checked;
            var luckInput = document.getElementById("luck_input");
            var luckNum = useLuckyNumber ? luckInput.value : "";
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/lotto", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var taskId = xhr.responseText;
                    checkProgress(taskId);
                } else if (xhr.readyState == 4) {
                    button.disabled = false;
                    button.innerText = "生成";
                    alert("任务提交失败,请重试。");
                }
            };
            xhr.send(JSON.stringify({luck_num: luckNum}));
        }

        /* 检查任务进度 */
        function checkProgress(taskId) {
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/get_task_prgs", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var response = JSON.parse(xhr.responseText);
                    var progress = response[0];
                    var status = response[1];
                    // document.getElementById("progress").innerText = "进度: " + progress + ", 状态: " + status;
                    if (progress == 1) {
                        getResult(taskId);
                    } else if (progress == -1) {
                        var button = document.querySelector("button");
                        button.disabled = false;
                        button.innerText = "生成";
                        alert("任务失败: " + status);
                    } else {
                        setTimeout(function() { checkProgress(taskId); }, 3000);
                    }
                }
            };
            xhr.send(JSON.stringify({task_id: taskId}));
        }

        /* 获取任务结果 */
        function getResult(taskId) {
            var xhr = new XMLHttpRequest();
            xhr.open("POST", "/get_task_rst", true);
            xhr.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
            xhr.onreadystatechange = function () {
                if (xhr.readyState == 4 && xhr.status == 200) {
                    var response = JSON.parse(xhr.responseText);
                    displayResult(response);
                    var button = document.querySelector("button");
                    button.disabled = false;
                    button.innerText = "生成";
                }
            };
            xhr.send(JSON.stringify({task_id: taskId}));
        }

        /* 显示任务结果 */
        function displayResult(response) {
            var frontNumbers = response.front_numbers;
            var backNumbers = response.back_numbers;
            var resultContainer = document.getElementById("result");
            resultContainer.innerHTML = ""; // 清空之前的结果

            for (var i = 0; i < frontNumbers.length; i++) {
                var lotterySet = document.createElement("div");
                lotterySet.className = "lottery-set";
                
                frontNumbers[i].forEach(function(number) {
                    var numberBall = document.createElement("div");
                    numberBall.className = "number-ball front-ball";
                    numberBall.innerText = number;
                    lotterySet.appendChild(numberBall);
                });

                backNumbers[i].forEach(function(number) {
                    var numberBall = document.createElement("div");
                    numberBall.className = "number-ball back-ball";
                    numberBall.innerText = number;
                    lotterySet.appendChild(numberBall);
                });

                resultContainer.appendChild(lotterySet);
            }
        }
    </script>

4.2 制作网页index.html

注意到Flask提供了网页渲染功能,这样我们可以设计我们的主页

python 复制代码
@app.route('/')
def index():
    return render_template('index.html')

把上述JS脚本放入index.html 就可以访问后端服务啦,具体html的UI显示,由于代码量很大这里不与展示了,感兴趣同学可以根据上述python客户端的访问逻辑试用GPT为你编写index.html,手机端访问效果如下:

5. 最后

上述是个人搭建自己网站部署AI应用的简单过程,完整源码后期整理上传,欢迎大家留言关注~

相关推荐
sp_fyf_20243 分钟前
【大语言模型】ACL2024论文-19 SportsMetrics: 融合文本和数值数据以理解大型语言模型中的信息融合
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理
CoderIsArt6 分钟前
基于 BP 神经网络整定的 PID 控制
人工智能·深度学习·神经网络
编程修仙14 分钟前
Collections工具类
linux·windows·python
开源社19 分钟前
一场开源视角的AI会议即将在南京举办
人工智能·开源
FreeIPCC20 分钟前
谈一下开源生态对 AI人工智能大模型的促进作用
大数据·人工智能·机器人·开源
芝麻团坚果31 分钟前
对subprocess启动的子进程使用VSCode python debugger
linux·ide·python·subprocess·vscode debugger
机器之心38 分钟前
全球十亿级轨迹点驱动,首个轨迹基础大模型来了
人工智能·后端
z千鑫39 分钟前
【人工智能】PyTorch、TensorFlow 和 Keras 全面解析与对比:深度学习框架的终极指南
人工智能·pytorch·深度学习·aigc·tensorflow·keras·codemoss
EterNity_TiMe_40 分钟前
【论文复现】神经网络的公式推导与代码实现
人工智能·python·深度学习·神经网络·数据分析·特征分析
Stara05111 小时前
Git推送+拉去+uwsgi+Nginx服务器部署项目
git·python·mysql·nginx·gitee·github·uwsgi