【django】模型部署过程

模型部署

  1. 准备工作
    模型保存格式
    确保你的模型已保存为可加载的格式:
    ● TensorFlow/Keras:.h5 或 SavedModel 格式
    ● PyTorch:.pt 或 .pth 文件
    ● Scikit-learn:使用 joblib 或 pickle 保存(推荐 joblib)

示例:保存 Scikit-learn 模型

from sklearn.ensemble import RandomForestClassifier

import joblib

model = RandomForestClassifier()

model.fit(X_train, y_train)

joblib.dump(model, 'my_model.joblib')

  1. 项目结构规划

    建议的 Django 项目结构:

    myproject/

    ├── myapp/

    │ ├── models/ # 存放模型文件

    │ │ └── my_model.joblib

    │ ├── views.py # 处理请求和模型调用

    │ ├── urls.py # 定义API路由

    │ └── ...

    ├── myproject/

    │ ├── settings.py

    │ └── urls.py # 主路由

    └── manage.py

  2. 模型加载与初始化

    在 Django 中全局加载模型

    在 myapp/apps.py 或 views.py 中初始化模型,避免每次请求重复加载。

myapp/views.py

from django.http import JsonResponse

from django.views.decorators.csrf import csrf_exempt

import joblib

import os

全局加载模型

model_path = os.path.join(os.path.dirname(file ), 'models/my_model.joblib')

model = joblib.load(model_path)

@csrf_exempt # 若需跨域访问可临时禁用CSRF(生产环境需谨慎)

def predict(request):

if request.method == 'POST':

try:

获取输入数据(假设发送JSON)

data = json.loads(request.body)

features = data['features']

复制代码
        # 调用模型预测
        prediction = model.predict([features])[0]
        
        return JsonResponse({'prediction': prediction})
    except Exception as e:
        return JsonResponse({'error': str(e)}, status=400)
return JsonResponse({'error': '仅支持POST请求'}, status=405)
  1. 配置路由
    在 myapp/urls.py 中添加API路由
    from django.urls import path
    from . import views

urlpatterns = [

path('predict/', views.predict, name='predict'),

]

在项目主路由 myproject/urls.py 中引入

from django.urls import include, path

urlpatterns = [

path('api/', include('myapp.urls')),

]

  1. 测试API

    使用 curl 或 Postman 发送POST请求测试:

    curl -X POST http://localhost:8000/api/predict/

    -H "Content-Type: application/json"

    -d '{"features": [1.2, 3.4, 5.6]}'

    预期响应:

    {"prediction": 0}

  2. 高级优化

    异步处理(Celery + Redis)

    如果模型推理耗时较长,可用 Celery 异步任务避免阻塞请求:

tasks.py(Celery任务)

from celery import shared_task

from myapp.views import model # 复用全局加载的模型

@shared_task

def async_predict(features):

return model.predict([features])[0]

views.py 修改为异步调用

@csrf_exempt

def predict(request):

if request.method == 'POST':

data = json.loads(request.body)

task = async_predict.delay(data['features'])

return JsonResponse({'task_id': task.id}, status=202)

缓存模型输出

使用 Django 缓存减少重复计算:

from django.core.cache import cache

def predict(request):

data = json.loads(request.body)

features = tuple(data['features']) # 转换为可哈希类型

复制代码
# 检查缓存
if cache.has_key(features):
    return JsonResponse({'prediction': cache.get(features)})

# 计算并缓存
prediction = model.predict([features])[0]
cache.set(features, prediction, timeout=3600)  # 缓存1小时
return JsonResponse({'prediction': prediction})
  1. 关键注意事项
  2. 线程安全:
    from threading import Lock
    model_lock = Lock()

def predict(request):

with model_lock:

prediction = model.predict(...)

○ 如果模型非线程安全(如某些 TensorFlow 旧版本),需加锁或使用单例模式。

  1. 性能优化:

○ 使用 gunicorn 或 uvicorn 替代 Django 自带的开发服务器。

○ 启用 GPU 加速(如 TensorFlow/PyTorch 的 GPU 版本)。

  1. 输入验证:

def validate_features(features):

if len(features) != 3:

raise ValueError("必须提供3个特征")

if not all(isinstance(x, (int, float)) for x in features):

raise ValueError("特征必须为数字")

○ 严格校验前端传入的数据格式和范围,防止恶意输入。

  1. 依赖管理:

tensorflow2.12.0
scikit-learn1.2.2

joblib==1.2.0

○ 在 requirements.txt 中明确指定模型库版本:

完整示例:图像分类模型集成

假设有一个图像分类模型(如 ResNet),可按以下方式处理文件上传:

views.py

from django.core.files.storage import default_storage

from tensorflow.keras.preprocessing import image

import numpy as np

def predict_image(request):

if request.method == 'POST':

file = request.FILES['image']

file_path = default_storage.save('tmp/' + file.name, file)

复制代码
    # 预处理图像
    img = image.load_img(file_path, target_size=(224, 224))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0) / 255.0
    
    # 预测
    prediction = model.predict(img_array)
    class_idx = np.argmax(prediction)
    
    return JsonResponse({'class': class_idx})

通过以上步骤,你可以将训练好的模型无缝集成到 Django 中,并通过 RESTful API 提供服务。根据实际需求调整代码结构和优化策略。

相关推荐
linweidong1 小时前
Go开发简历优化指南
分布式·后端·golang·高并发·简历优化·go面试·后端面经
敢敢变成了憨憨1 小时前
java操作服务器文件(把解析过的文件迁移到历史文件夹地下)
java·服务器·python
敲键盘的小夜猫2 小时前
Milvus向量Search查询综合案例实战(下)
数据库·python·milvus
咖啡啡不加糖2 小时前
雪花算法:分布式ID生成的优雅解决方案
java·分布式·后端
简简单单做算法2 小时前
基于mediapipe深度学习的虚拟画板系统python源码
人工智能·python·深度学习·mediapipe·虚拟画板
姑苏洛言2 小时前
基于微信公众号小程序的课表管理平台设计与实现
前端·后端
烛阴3 小时前
比UUID更快更小更强大!NanoID唯一ID生成神器全解析
前端·javascript·后端
why1513 小时前
字节golang后端二面
开发语言·后端·golang
还是鼠鼠3 小时前
单元测试-断言&常见注解
java·开发语言·后端·单元测试·maven
cainiao0806053 小时前
Spring Boot 4.0实战:构建高并发电商系统
java·spring boot·后端