【django】模型部署过程

模型部署

  1. 准备工作
    模型保存格式
    确保你的模型已保存为可加载的格式:
    ● TensorFlow/Keras:.h5 或 SavedModel 格式
    ● PyTorch:.pt 或 .pth 文件
    ● Scikit-learn:使用 joblib 或 pickle 保存(推荐 joblib)

示例:保存 Scikit-learn 模型

from sklearn.ensemble import RandomForestClassifier

import joblib

model = RandomForestClassifier()

model.fit(X_train, y_train)

joblib.dump(model, 'my_model.joblib')

  1. 项目结构规划

    建议的 Django 项目结构:

    myproject/

    ├── myapp/

    │ ├── models/ # 存放模型文件

    │ │ └── my_model.joblib

    │ ├── views.py # 处理请求和模型调用

    │ ├── urls.py # 定义API路由

    │ └── ...

    ├── myproject/

    │ ├── settings.py

    │ └── urls.py # 主路由

    └── manage.py

  2. 模型加载与初始化

    在 Django 中全局加载模型

    在 myapp/apps.py 或 views.py 中初始化模型,避免每次请求重复加载。

myapp/views.py

from django.http import JsonResponse

from django.views.decorators.csrf import csrf_exempt

import joblib

import os

全局加载模型

model_path = os.path.join(os.path.dirname(file ), 'models/my_model.joblib')

model = joblib.load(model_path)

@csrf_exempt # 若需跨域访问可临时禁用CSRF(生产环境需谨慎)

def predict(request):

if request.method == 'POST':

try:

获取输入数据(假设发送JSON)

data = json.loads(request.body)

features = data['features']

        # 调用模型预测
        prediction = model.predict([features])[0]
        
        return JsonResponse({'prediction': prediction})
    except Exception as e:
        return JsonResponse({'error': str(e)}, status=400)
return JsonResponse({'error': '仅支持POST请求'}, status=405)
  1. 配置路由
    在 myapp/urls.py 中添加API路由
    from django.urls import path
    from . import views

urlpatterns = [

path('predict/', views.predict, name='predict'),

]

在项目主路由 myproject/urls.py 中引入

from django.urls import include, path

urlpatterns = [

path('api/', include('myapp.urls')),

]

  1. 测试API

    使用 curl 或 Postman 发送POST请求测试:

    curl -X POST http://localhost:8000/api/predict/

    -H "Content-Type: application/json"

    -d '{"features": [1.2, 3.4, 5.6]}'

    预期响应:

    {"prediction": 0}

  2. 高级优化

    异步处理(Celery + Redis)

    如果模型推理耗时较长,可用 Celery 异步任务避免阻塞请求:

tasks.py(Celery任务)

from celery import shared_task

from myapp.views import model # 复用全局加载的模型

@shared_task

def async_predict(features):

return model.predict([features])[0]

views.py 修改为异步调用

@csrf_exempt

def predict(request):

if request.method == 'POST':

data = json.loads(request.body)

task = async_predict.delay(data['features'])

return JsonResponse({'task_id': task.id}, status=202)

缓存模型输出

使用 Django 缓存减少重复计算:

from django.core.cache import cache

def predict(request):

data = json.loads(request.body)

features = tuple(data['features']) # 转换为可哈希类型

# 检查缓存
if cache.has_key(features):
    return JsonResponse({'prediction': cache.get(features)})

# 计算并缓存
prediction = model.predict([features])[0]
cache.set(features, prediction, timeout=3600)  # 缓存1小时
return JsonResponse({'prediction': prediction})
  1. 关键注意事项
  2. 线程安全:
    from threading import Lock
    model_lock = Lock()

def predict(request):

with model_lock:

prediction = model.predict(...)

○ 如果模型非线程安全(如某些 TensorFlow 旧版本),需加锁或使用单例模式。

  1. 性能优化:

○ 使用 gunicorn 或 uvicorn 替代 Django 自带的开发服务器。

○ 启用 GPU 加速(如 TensorFlow/PyTorch 的 GPU 版本)。

  1. 输入验证:

def validate_features(features):

if len(features) != 3:

raise ValueError("必须提供3个特征")

if not all(isinstance(x, (int, float)) for x in features):

raise ValueError("特征必须为数字")

○ 严格校验前端传入的数据格式和范围,防止恶意输入。

  1. 依赖管理:

tensorflow2.12.0
scikit-learn1.2.2

joblib==1.2.0

○ 在 requirements.txt 中明确指定模型库版本:

完整示例:图像分类模型集成

假设有一个图像分类模型(如 ResNet),可按以下方式处理文件上传:

views.py

from django.core.files.storage import default_storage

from tensorflow.keras.preprocessing import image

import numpy as np

def predict_image(request):

if request.method == 'POST':

file = request.FILES['image']

file_path = default_storage.save('tmp/' + file.name, file)

    # 预处理图像
    img = image.load_img(file_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0) / 255.0
    
    # 预测
    prediction = model.predict(img_array)
    class_idx = np.argmax(prediction)
    
    return JsonResponse({'class': class_idx})

通过以上步骤,你可以将训练好的模型无缝集成到 Django 中,并通过 RESTful API 提供服务。根据实际需求调整代码结构和优化策略。

相关推荐
დ旧言~20 分钟前
【Python】使用库
python
小杨40435 分钟前
springboot框架启动流程二(源码分析)
spring boot·后端·架构
星尘库42 分钟前
基于SpringBoot的失物招领平台的设计与实现
vue.js·spring boot·后端·小程序
哥是黑大帅1 小时前
Milvus向量数据库部署
数据库·python·milvus
云卷️1 小时前
微服务面试题及原理
java·后端·微服务·云原生·架构
补三补四1 小时前
Django与数据库
数据库·python·django
上海研博数据1 小时前
codewave初识
后端
lczdyx2 小时前
Transformer 代码剖析6 - 位置编码 (pytorch实现)
人工智能·pytorch·python·深度学习·transformer
云天徽上2 小时前
【目标检测】目标检测中的数据增强终极指南:从原理到实战,用Python解锁模型性能提升密码(附YOLOv5实战代码)
人工智能·python·yolo·目标检测·机器学习·计算机视觉