【django】模型部署过程

模型部署

  1. 准备工作
    模型保存格式
    确保你的模型已保存为可加载的格式:
    ● TensorFlow/Keras:.h5 或 SavedModel 格式
    ● PyTorch:.pt 或 .pth 文件
    ● Scikit-learn:使用 joblib 或 pickle 保存(推荐 joblib)

示例:保存 Scikit-learn 模型

from sklearn.ensemble import RandomForestClassifier

import joblib

model = RandomForestClassifier()

model.fit(X_train, y_train)

joblib.dump(model, 'my_model.joblib')

  1. 项目结构规划

    建议的 Django 项目结构:

    myproject/

    ├── myapp/

    │ ├── models/ # 存放模型文件

    │ │ └── my_model.joblib

    │ ├── views.py # 处理请求和模型调用

    │ ├── urls.py # 定义API路由

    │ └── ...

    ├── myproject/

    │ ├── settings.py

    │ └── urls.py # 主路由

    └── manage.py

  2. 模型加载与初始化

    在 Django 中全局加载模型

    在 myapp/apps.py 或 views.py 中初始化模型,避免每次请求重复加载。

myapp/views.py

from django.http import JsonResponse

from django.views.decorators.csrf import csrf_exempt

import joblib

import os

全局加载模型

model_path = os.path.join(os.path.dirname(file ), 'models/my_model.joblib')

model = joblib.load(model_path)

@csrf_exempt # 若需跨域访问可临时禁用CSRF(生产环境需谨慎)

def predict(request):

if request.method == 'POST':

try:

获取输入数据(假设发送JSON)

data = json.loads(request.body)

features = data['features']

复制代码
        # 调用模型预测
        prediction = model.predict([features])[0]
        
        return JsonResponse({'prediction': prediction})
    except Exception as e:
        return JsonResponse({'error': str(e)}, status=400)
return JsonResponse({'error': '仅支持POST请求'}, status=405)
  1. 配置路由
    在 myapp/urls.py 中添加API路由
    from django.urls import path
    from . import views

urlpatterns = [

path('predict/', views.predict, name='predict'),

]

在项目主路由 myproject/urls.py 中引入

from django.urls import include, path

urlpatterns = [

path('api/', include('myapp.urls')),

]

  1. 测试API

    使用 curl 或 Postman 发送POST请求测试:

    curl -X POST http://localhost:8000/api/predict/

    -H "Content-Type: application/json"

    -d '{"features": [1.2, 3.4, 5.6]}'

    预期响应:

    {"prediction": 0}

  2. 高级优化

    异步处理(Celery + Redis)

    如果模型推理耗时较长,可用 Celery 异步任务避免阻塞请求:

tasks.py(Celery任务)

from celery import shared_task

from myapp.views import model # 复用全局加载的模型

@shared_task

def async_predict(features):

return model.predict([features])[0]

views.py 修改为异步调用

@csrf_exempt

def predict(request):

if request.method == 'POST':

data = json.loads(request.body)

task = async_predict.delay(data['features'])

return JsonResponse({'task_id': task.id}, status=202)

缓存模型输出

使用 Django 缓存减少重复计算:

from django.core.cache import cache

def predict(request):

data = json.loads(request.body)

features = tuple(data['features']) # 转换为可哈希类型

复制代码
# 检查缓存
if cache.has_key(features):
    return JsonResponse({'prediction': cache.get(features)})

# 计算并缓存
prediction = model.predict([features])[0]
cache.set(features, prediction, timeout=3600)  # 缓存1小时
return JsonResponse({'prediction': prediction})
  1. 关键注意事项
  2. 线程安全:
    from threading import Lock
    model_lock = Lock()

def predict(request):

with model_lock:

prediction = model.predict(...)

○ 如果模型非线程安全(如某些 TensorFlow 旧版本),需加锁或使用单例模式。

  1. 性能优化:

○ 使用 gunicorn 或 uvicorn 替代 Django 自带的开发服务器。

○ 启用 GPU 加速(如 TensorFlow/PyTorch 的 GPU 版本)。

  1. 输入验证:

def validate_features(features):

if len(features) != 3:

raise ValueError("必须提供3个特征")

if not all(isinstance(x, (int, float)) for x in features):

raise ValueError("特征必须为数字")

○ 严格校验前端传入的数据格式和范围,防止恶意输入。

  1. 依赖管理:

tensorflow2.12.0
scikit-learn1.2.2

joblib==1.2.0

○ 在 requirements.txt 中明确指定模型库版本:

完整示例:图像分类模型集成

假设有一个图像分类模型(如 ResNet),可按以下方式处理文件上传:

views.py

from django.core.files.storage import default_storage

from tensorflow.keras.preprocessing import image

import numpy as np

def predict_image(request):

if request.method == 'POST':

file = request.FILES['image']

file_path = default_storage.save('tmp/' + file.name, file)

复制代码
    # 预处理图像
    img = image.load_img(file_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0) / 255.0
    
    # 预测
    prediction = model.predict(img_array)
    class_idx = np.argmax(prediction)
    
    return JsonResponse({'class': class_idx})

通过以上步骤,你可以将训练好的模型无缝集成到 Django 中,并通过 RESTful API 提供服务。根据实际需求调整代码结构和优化策略。

相关推荐
小白—人工智能14 分钟前
数据可视化 —— 多边图应用(大全)
python·信息可视化·数据可视化
noravinsc21 分钟前
使用django实现windows任务调度管理
python·django·sqlite
hvinsion22 分钟前
【Python 开源】你的 Windows 关机助手——PyQt5 版定时关机工具
windows·python·开源·定时关机
只因在人海中多看了你一眼22 分钟前
Django从零搭建卖家中心注册页面实战
python·django
亿牛云爬虫专家28 分钟前
Pyppeteer实战:基于Python的无头浏览器控制新选择
python·数据采集·爬虫代理·代理ip·无头浏览器·小红书·pyppeteer
小森776733 分钟前
(四)机器学习---逻辑回归及其Python实现
人工智能·python·算法·机器学习·逻辑回归·线性回归
生信碱移36 分钟前
入门级宏基因组数据分析教程,从实验到分析与应用
人工智能·经验分享·python·神经网络·数据挖掘·数据分析·数据可视化
码农不惑1 小时前
Django的定制以及admin
数据库·python·django·sqlite
风象南1 小时前
SpringBoot项目如何用ServiceLocatorFactoryBean优雅切换支付渠道?
java·spring boot·后端
杂学者1 小时前
python 办公自动化------ excel文件的操作,读取、写入
python·excel