jax踩坑指南——人类早期驯服jax实录

文章目录

  • 前言
  • 一、jax和cuda,cudnn,nvidia-driver的对应关系
  • 二、第一坑-特定jax版本可能隐藏的不兼容
  • 三、第二坑-与pytorch的兼容
  • 四、第三坑-nvidia库版本低

前言

被jax折磨疯了,记录一下中间遇到的各种坑。jax这个新框架,比torch还娇贵,从nvidia-driver到cuda再到cudnn,你胆敢有一个不兼容我就撂挑子给你看!
cuda与nvidia-driver版本对应关系
pytorch和python版本对应关系
pytorch和cuda版本对应关系
cuda和cudnn版本对应关系
cuda和cudnn与jax版本对应关系

一、jax和cuda,cudnn,nvidia-driver的对应关系

一般我们 git clone的项目都会有一个要求的jax版本,比如:

bash 复制代码
pip install "jax[cuda12_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

很显然这个项目要求的是cuda12+jax0.4.19,但我的建议是不要直接使用这条指令去安装jax,而应该选择更具体的版本,打开jax库的链接:https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

里面会有细致的版本对应:

可以看到不仅有cuda的版本也有cudnn的版本,二者都要对应才可以。

二、第一坑-特定jax版本可能隐藏的不兼容

这里就出现第一个坑了,由于我的nvidia-driver是535版本,我去cuda与nvidia-driver版本对应关系看了一下只有cuda-12.1能装,所以兴奋的装了cuda12.1+cudnn8.9.7之后,安装jax

bash 复制代码
pip install --upgrade jax==0.4.19 jaxlib==0.4.19+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

在验证jax安装是否成功时

bash 复制代码
import jax
print(jax.devices())

得到报错:

bash 复制代码
CUDA backend failed to initialize: Found CUDA version 12010, but JAX was built against version 12020, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

我还特意查了jax官方文档,显示兼容的,但就是不work,估计是0.4.19版本必须要cuda>=12.2吧

到这里我只能选择升级显卡驱动,一不做二不休,直接升级到560版本,这次从cuda12.1-12.6都支持

三、第二坑-与pytorch的兼容

我大意的选择了cuda-12.3+cudnn8.9.7

然后发现我还需要安装pytorch,随便看一个pytorch版本:https://pytorch.org/get-started/previous-versions/

很好,cuda-12.1和cuda12.4都有,就是没有cuda-12.3,重装!

四、第三坑-nvidia库版本低

cuda-12.4和cudnn-8.9.7终于装完了,运行代码,继续报错:

bash 复制代码
CUDA backend failed to initialize: Found cuSOLVER version 11405, but JAX was built against version 11502, which is newer. The copy of cuSOLVER that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

我懵了,我跑去问gpt,gpt说是因为系统存在多个cuda版本,系统在加载旧的 CUDA 11.8 的动态库。我跟着它操作了一大通没有任何卵用。最后在pip list指令下发现是我虚拟环境里nvidia库版本太低

bash 复制代码
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvcc-cu12     12.6.85                  pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.18.1                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.4.99                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi

含泪一个个升级,最后升级成这样

bash 复制代码
nvidia-cublas-cu12       12.4.5.8
nvidia-cuda-cupti-cu12   12.4.127
nvidia-cuda-nvcc-cu12    12.6.85
nvidia-cuda-nvrtc-cu12   12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.5.2.141
nvidia-cusparse-cu12     12.4.1.24
nvidia-nccl-cu12         2.18.1
nvidia-nvjitlink-cu12    12.4.99
nvidia-nvtx-cu12         12.4.127

这次终于大功告成!

验证jax:

bash 复制代码
import jax
print(jax.devices())

得到

bash 复制代码
[cuda(id=0)]

代码也能正常执行。

总结:最后版本cuda12.4+cudnn897+jax0.4.19

相关推荐
柴 基1 分钟前
PyTorch 使用指南
人工智能·pytorch·python
前端Hardy8 分钟前
Python 打造 Excel 到 JSON 转换工具:从开发到打包全攻略
前端·后端·python
荔枝吻27 分钟前
【保姆级喂饭教程】Python依赖管理工具大全:Virtualenv、venv、Pipenv、Poetry、pdm、Rye、UV、Conda、Pixi等
python·uv·环境管理
mortimer38 分钟前
从 `__init__` 的重复劳动中解放出来:使用 dataclass 重构简化python
python
Blossom.1181 小时前
基于深度学习的图像分类:使用ShuffleNet实现高效分类
人工智能·python·深度学习·目标检测·机器学习·分类·数据挖掘
Lenyiin1 小时前
《LeetCode 热题 100》整整 100 题量大管饱题解套餐 中
java·c++·python·leetcode·面试·刷题·lenyiin
WJ.Polar1 小时前
Python与Mysql
开发语言·数据库·python·mysql
pk_xz1234561 小时前
社区资源媒体管理系统设计与实现
网络·python·深度学习·算法·数据挖掘·媒体
普郎特3 小时前
大白话帮你彻底理解 aiohttp 的 ClientSession 与 ClientResponse 对象
爬虫·python
空中湖3 小时前
PyTorch武侠演义 第一卷:初入江湖 第7章:矿洞中的计算禁制
人工智能·pytorch·python