这篇笔记是 mujoco 系列的番外第一篇,主要介绍如何在电脑上配置 JAX 和 MJX(Muojco + JAX),之所以介绍这方面内容是因为在实际使用中发现模型推理可能只消耗 0.05s,但 mujoco 仿真环境模拟一步有时就需要消耗 15s,如果你的问题本身不大的话造成这种情况的原因通常是因为没有开启加速或者没有选对积分器导致的。
从主线学习博客中有涉及到一部分 JAX 的内容,我看了下后面的官方教程中有一篇专门介绍如何使用 JAX 进行加速,因此开一篇番外来记录如何部署 JAX 并在 mujoco 中使用是相当有必要的。
1. 检查环境
首先要确保你的电脑上已经安装好了 Nvidia 显卡驱动,因为 JAX 和 cuda 以及 cudnn 版本关联很紧密,版本差异都可能导致 JAX 运行时无法识别 GPU。
关于显卡驱动安装和cudnn安装不是这篇博客的核心,网上有大量的资源可以借用,但一定要注意这两个版本是和你的显卡硬件配得上号的,并且 安装的 cudnn 版本是 runtime 版的:
如何安装 runtime 的 cudnn 可以参考这篇博客:【Linux安装cuda和cudnn实战教程】
1.1 检查 cuda 版本
【Note】:nvidia-smi
命令只能告诉你当前硬件与驱动最高支持的cuda版本是多少,并不是你当前正在使用的cuda版本,一定要用下面的命令查看真正在使用的版本号。
使用下面的命令:
bash
(mujoco) $ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
上面的 release 12.1
就是你的cuda版本号。
1.2 检查 cudnn 版本
在执行下面命令的时候需要注意你当前使用的 cuda 版本,特别是当你安装了多个版本的时候。
bash
(mujoco) $ cat /usr/local/cuda-12.1/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
#define CUDNN_MAJOR 9
#define CUDNN_MINOR 10
#define CUDNN_PATCHLEVEL 1
--
#define CUDNN_VERSION (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
/* cannot use constexpr here since this is a C-only file */
上面的输出中 CUDNN_MAJOR
、CUDNN_MINOR
、CUDNN_PATCHLEVEL
拼接起来的到的版本号就是 9.10.1
1.3 检查 python 版本
这里需要检查的是你的 conda 环境对应的 python 版本,因此需要在conda 环境下执行命令:
python
(mujoco) $ python --version
python 3.10.17
综上所述,我的电脑上 cuda 版本号为 12.1
,cudnn 版本号为 9.10.1
,python 版本号为 3.10.17
2. 安装 brax 【可选】
如果你也是跟着我一起做 mujoco 系列博客的话建议安装执行这一章内容,因为后面会用到,如果你不需要 brax 库的话可以直接跳过这一章节。
brax 库在安装的时候会自动将 jax 更新到最新,但自动安装的版本只需要 jaxlib
和 jax>=0.4.2
即可,因此先安装 brax 库,然后在卸载自动安装的最新 jax 和 jaxlib
bash
(mujoco) $ pip install brax
(mujoco) $ pip uninstall jax
(mujoco) $ pip uninstall jaxlib
这样就可以在使用 brax 的前提下又能用指定版本的 jax 库了。
3. 下载并安装 JAX 包
3.1 常规情况
通常情况下使用 JAX 官方文档中的安装命令就可以直接适配到当前 cuda 和cudnn 最优版本,执行命令后会下载并安装很多其他依赖:
bash
(mujoco) $ conda install "jaxlib=*=*cuda*" jax -c conda-forge
比如我这里安装完成后可以看到以下内容:
bash
(mujoco) $ pip list | grep jax
jax 0.5.2
jax-cuda12-pjrt 0.5.2
jax-cuda12-plugin 0.5.2
jaxlib 0.5.2
jaxopt 0.8.5
3.2 旧版本安装
在 0.5.0
之前的版本需要指定cuda和cudnn版本号,因此如果你已经明确需要安装旧版本的 jax 库的话需要按照以下流程操作:
- 如果你的环境中已经有了 jax 库则先卸载
bash
(mujoco) $ pip list | grep jax
jax 0.6.1
jaxlib 0.6.1
(mujoco) $ pip uninstall jax
(mujoco) $ pip uninstall jaxlib
- 前往 Google JAX 库下载对应版本的whl包:
打开上面的链接找到自己对应的轮子,然后下载下来:

让后安装这个 jaxlib
:
bash
(mujoco) $ pip install jaxlib-0.4.29+cuda12.cudnn91-cp310-cp310-manylinux2014_x86_64.whl
校验安装的版本,我这里输出的是 0.4.29
:
bash
(mujoco) $ pip list | grep jax
jaxlib 0.4.29+cuda12.cudnn91
然后再安装匹配的 jax
库:
bash
(mujoco) $ pip install jax==0.4.29
此时你的conda环境中应该有以下两个库了:
bash
(mujoco) $ pip list | grep jax
jax 0.4.29
jaxlib 0.4.29+cuda12.cudnn91
【Note】旧版本很容易出现cuda不兼容的情况,除非你有十足的把握还是建议用 3.1 的方式进行安装。
4. 验证 JAX
使用下面的 python 代码验证 JAX 是否安装成功:
python
import jax
print(jax.devices())
如果输出结果类似如下则说明安装成功:
bash
[cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3)]
如果运行后的输出包含了以下内容则说明版本没有配对上:
python
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
在上面的代码输出正确内容后可以运行一个非常简单的 mujoco 示例看看 JAX 加速是否真正启动:
【Note】:这部分演示建议在 jupyter 中进行。
4.1 导入必要的包
python
import mediapy as media
import os
import jax
import mujoco
from mujoco import mjx
import time
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
4.2 准备模型和环境
python
xml = """
<mujoco>
<worldbody>
<light name="top" pos="0 0 1"/>
<body name="box_and_sphere" euler="0 0 -30">
<joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
<geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
<geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
</body>
</worldbody>
</mujoco>
"""
# CPU model, data
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)
# GPU model, data
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)
print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())
4.3 运行 CPU 示例
python
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True
duration = 3.8
framerate = 60
frames = []
mujoco.mj_resetData(mj_model, mj_data)
start_time = time.time()
while mj_data.time < duration:
mujoco.mj_step(mj_model, mj_data)
if len(frames) < mj_data.time * framerate:
renderer.update_scene(mj_data, scene_option=scene_option)
pixels = renderer.render()
frames.append(pixels)
print(f"Total rendering cost time {time.time() - start_time:.2f}")
media.show_video(frames, fps=framerate)

4.4 运行 GPU 示例
python
jit_step = jax.jit(mjx.step)
frames = []
mujoco.mj_resetData(mj_model, mj_data)
mjx_data = mjx.put_data(mj_model, mj_data)
start_time = time.time()
while mjx_data.time < duration:
mjx_data = jit_step(mjx_model, mjx_data)
if len(frames) < mjx_data.time * framerate:
mj_data = mjx.get_data(mj_model, mjx_data)
renderer.update_scene(mj_data, scene_option=scene_option)
pixels = renderer.render()
frames.append(pixels)
print(f"Total rendering cost time {time.time() - start_time:.2f}")
media.show_video(frames, fps=framerate)

此时你可能会发现 GPU 耗时反而比 CPU 更长,这是因为对于单线程任务而言其效率提升不明显,MJX 的优势在与大规模并行运算。后续的主线博客会有更详细的介绍。