Deploying a Python-Trained PyTorch Model (MNIST) in Java Spring Boot with DJL
Training models in Java with DJL: https://blog.csdn.net/xundh/category_11361043.html?spm=1001.2014.3001.5515
Training the PyTorch Model in Python
This project trains the model with PyTorch==1.10.0. The relevant packages in the conda environment:
# Name          Version    Build                     Channel
pytorch         1.10.0     py3.9_cuda11.3_cudnn8_0   pytorch
pytorch-mutex   1.0        cuda                      pytorch
requests        2.28.1     pypi_0                    pypi
scipy           1.9.3      pypi_0                    pypi
setuptools      65.6.3     pyhd8ed1ab_0              https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
six             1.16.0     pypi_0                    pypi
tbb             2021.7.0   h91493d7_1                conda-forge
tk              8.6.12     h8ffe710_0                https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
torchaudio      0.10.0     py39_cu113                pytorch
torchsummary    1.5.1      pypi_0                    pypi
torchvision     0.11.0     py39_cu113                pytorch
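Change the model-saving code so that it exports a TorchScript model via tracing, which the DJL PyTorch engine can load:
import os
import torch
# model, device and val_ac (validation accuracy) are defined earlier in the training script
model.eval()  # switch to evaluation mode (disables dropout, uses running BatchNorm statistics)
example = torch.rand(1, 1, 28, 28).to(device)  # example input of shape (batch, channel, height, width)
traced_script_module = torch.jit.trace(model, example)  # record the forward pass as TorchScript
os.makedirs('models', exist_ok=True)  # save() fails if the target directory does not exist
traced_script_module.save('models/{}_model.pt'.format(val_ac))  # file name includes the validation accuracy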
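The trace fixes the model input to a single float tensor of shape (1, 1, 28, 28), so whatever the Java side feeds the model has to match that layout. The following is a minimal pure-Java sketch of that pre- and post-processing, not the author's code: the class and method names are hypothetical, and it assumes the training pipeline only scaled pixels to [0, 1] (as `transforms.ToTensor()` does). If training also applied a `Normalize` transform, the same mean and standard deviation must be applied here.

```java
/**
 * Hypothetical helper (not part of DJL or the original project): converts a 28x28
 * grayscale image into the flat NCHW float layout [1, 1, 28, 28] used when tracing,
 * and turns the model's 10 output values into a predicted digit.
 */
final class MnistTensorSketch {
    static final int SIZE = 28;

    // Flatten pixels[row][col] (0..255) into NCHW order, scaled to [0, 1].
    // Assumption: training used ToTensor() only; add the same Normalize(mean, std) if it did more.
    static float[] toNchw(int[][] pixels) {
        if (pixels.length != SIZE) {
            throw new IllegalArgumentException("expected 28 rows");
        }
        float[] data = new float[1 * 1 * SIZE * SIZE];
        for (int h = 0; h < SIZE; h++) {
            if (pixels[h].length != SIZE) {
                throw new IllegalArgumentException("expected 28 columns");
            }
            for (int w = 0; w < SIZE; w++) {
                // General NCHW index is ((n * C + c) * H + h) * W + w; here n = c = 0.
                data[h * SIZE + w] = pixels[h][w] / 255.0f;
            }
        }
        return data;
    }

    // Numerically stable softmax: subtract the max before exponentiating.
    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] exp = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            exp[i] = Math.exp(logits[i] - max);
            sum += exp[i];
        }
        float[] probs = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) (exp[i] / sum);
        }
        return probs;
    }

    // Index of the largest value, i.e. the predicted class (digit 0..9).
    static int argmax(float[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[][] img = new int[SIZE][SIZE];
        img[0][1] = 255;
        float[] input = toNchw(img);
        System.out.println(input.length); // 784
        System.out.println(input[1]);     // 1.0
        float[] logits = {0.1f, 2.0f, -1.0f, 0.5f, 0f, 0f, 0f, 0f, 0f, 3.5f};
        System.out.println(argmax(softmax(logits))); // 9
    }
}
```

In a DJL project, this logic would typically live in a custom `Translator`: `processInput` could wrap the array with `NDManager.create(float[], Shape)` using `new Shape(1, 1, 28, 28)`, and `processOutput` would apply the softmax/argmax step to the model's 10 output values.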
Calling the Model from Java Spring Boot with DJL
pom.xml
<!-- DJL dependencies -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.19.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>1.12.1-0.19.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.19.0</version>
</dependency>
<dependency>
<groupId>ai.djl.opencv</groupId>
<artifactId>opencv</artifactId>