Mujoco 学习系列(番外一)MJX 部署与配置

这篇笔记是 mujoco 系列的番外第一篇,主要介绍如何在电脑上配置 JAX 和 MJX(Muojco + JAX),之所以介绍这方面内容是因为在实际使用中发现模型推理可能只消耗 0.05s,但 mujoco 仿真环境模拟一步有时就需要消耗 15s,如果你的问题本身不大的话造成这种情况的原因通常是因为没有开启加速或者没有选对积分器导致的。

从主线学习博客中有涉及到一部分 JAX 的内容,我看了下后面的官方教程中有一篇专门介绍如何使用 JAX 进行加速,因此开一篇番外来记录如何部署 JAX 并在 mujoco 中使用是相当有必要的。


1. 检查环境

首先要确保你的电脑上已经安装好了 Nvidia 显卡驱动,因为 JAX 和 cuda 以及 cudnn 版本关联很紧密,版本差异都可能导致 JAX 运行时无法识别 GPU

关于显卡驱动安装和cudnn安装不是这篇博客的核心,网上有大量的资源可以借用,但一定要注意这两个版本是和你的显卡硬件配得上号的,并且 安装的 cudnn 版本是 runtime 版的

如何安装 runtime 的 cudnn 可以参考这篇博客:【Linux安装cuda和cudnn实战教程】

1.1 检查 cuda 版本

【Note】:nvidia-smi 命令只能告诉你当前硬件与驱动最高支持的cuda版本是多少,并不是你当前正在使用的cuda版本,一定要用下面的命令查看真正在使用的版本号。

使用下面的命令:

bash 复制代码
(mujoco) $ nvcc -V

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

上面的 release 12.1 就是你的cuda版本号。

1.2 检查 cudnn 版本

在执行下面命令的时候需要注意你当前使用的 cuda 版本,特别是当你安装了多个版本的时候。

bash 复制代码
(mujoco) $ cat /usr/local/cuda-12.1/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
#define CUDNN_MAJOR 9
#define CUDNN_MINOR 10
#define CUDNN_PATCHLEVEL 1
--
#define CUDNN_VERSION (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)

/* cannot use constexpr here since this is a C-only file */

上面的输出中 CUDNN_MAJORCUDNN_MINORCUDNN_PATCHLEVEL 拼接起来的到的版本号就是 9.10.1

1.3 检查 python 版本

这里需要检查的是你的 conda 环境对应的 python 版本,因此需要在conda 环境下执行命令:

python 复制代码
(mujoco) $ python --version
python 3.10.17

综上所述,我的电脑上 cuda 版本号为 12.1,cudnn 版本号为 9.10.1,python 版本号为 3.10.17


2. 安装 brax 【可选】

如果你也是跟着我一起做 mujoco 系列博客的话建议安装执行这一章内容,因为后面会用到,如果你不需要 brax 库的话可以直接跳过这一章节。

brax 库在安装的时候会自动将 jax 更新到最新,但自动安装的版本只需要 jaxlibjax>=0.4.2 即可,因此先安装 brax 库,然后在卸载自动安装的最新 jax 和 jaxlib

bash 复制代码
(mujoco) $ pip install brax
(mujoco) $ pip uninstall jax
(mujoco) $ pip uninstall jaxlib

这样就可以在使用 brax 的前提下又能用指定版本的 jax 库了。


3. 下载并安装 JAX 包

3.1 常规情况

通常情况下使用 JAX 官方文档中的安装命令就可以直接适配到当前 cuda 和cudnn 最优版本,执行命令后会下载并安装很多其他依赖:

bash 复制代码
(mujoco) $ conda install "jaxlib=*=*cuda*" jax -c conda-forge

比如我这里安装完成后可以看到以下内容:

bash 复制代码
(mujoco) $ pip list | grep jax
jax                                  0.5.2
jax-cuda12-pjrt                      0.5.2
jax-cuda12-plugin                    0.5.2
jaxlib                               0.5.2
jaxopt                               0.8.5

3.2 旧版本安装

0.5.0 之前的版本需要指定cuda和cudnn版本号,因此如果你已经明确需要安装旧版本的 jax 库的话需要按照以下流程操作:

  1. 如果你的环境中已经有了 jax 库则先卸载
bash 复制代码
(mujoco) $ pip list | grep jax
jax                            0.6.1
jaxlib                         0.6.1
(mujoco) $ pip uninstall jax
(mujoco) $ pip uninstall jaxlib
  1. 前往 Google JAX 库下载对应版本的whl包:

打开上面的链接找到自己对应的轮子,然后下载下来:

让后安装这个 jaxlib

bash 复制代码
(mujoco) $ pip install jaxlib-0.4.29+cuda12.cudnn91-cp310-cp310-manylinux2014_x86_64.whl

校验安装的版本,我这里输出的是 0.4.29

bash 复制代码
(mujoco) $ pip list | grep jax
jaxlib                   0.4.29+cuda12.cudnn91

然后再安装匹配的 jax 库:

bash 复制代码
(mujoco) $ pip install jax==0.4.29

此时你的conda环境中应该有以下两个库了:

bash 复制代码
(mujoco) $ pip list | grep jax
jax                      0.4.29
jaxlib                   0.4.29+cuda12.cudnn91

【Note】旧版本很容易出现cuda不兼容的情况,除非你有十足的把握还是建议用 3.1 的方式进行安装。


4. 验证 JAX

使用下面的 python 代码验证 JAX 是否安装成功:

python 复制代码
import jax
print(jax.devices())

如果输出结果类似如下则说明安装成功:

bash 复制代码
[cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3)]

如果运行后的输出包含了以下内容则说明版本没有配对上:

python 复制代码
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

在上面的代码输出正确内容后可以运行一个非常简单的 mujoco 示例看看 JAX 加速是否真正启动:

【Note】:这部分演示建议在 jupyter 中进行。

4.1 导入必要的包

python 复制代码
import mediapy as media
import os
import jax
import mujoco
from mujoco import mjx
import time

xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

4.2 准备模型和环境

python 复制代码
xml = """
<mujoco>
  <worldbody>
    <light name="top" pos="0 0 1"/>
    <body name="box_and_sphere" euler="0 0 -30">
      <joint name="swing" type="hinge" axis="1 -1 0" pos="-.2 -.2 -.2"/>
      <geom name="red_box" type="box" size=".2 .2 .2" rgba="1 0 0 1"/>
      <geom name="green_sphere" pos=".2 .2 .2" size=".1" rgba="0 1 0 1"/>
    </body>
  </worldbody>
</mujoco>
"""

# CPU model, data
mj_model = mujoco.MjModel.from_xml_string(xml)
mj_data = mujoco.MjData(mj_model)
renderer = mujoco.Renderer(mj_model)

# GPU model, data
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.put_data(mj_model, mj_data)

print(mj_data.qpos, type(mj_data.qpos))
print(mjx_data.qpos, type(mjx_data.qpos), mjx_data.qpos.devices())

4.3 运行 CPU 示例

python 复制代码
scene_option = mujoco.MjvOption()
scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

duration = 3.8
framerate = 60

frames = []
mujoco.mj_resetData(mj_model, mj_data)

start_time = time.time()
while mj_data.time < duration:
  mujoco.mj_step(mj_model, mj_data)
  if len(frames) < mj_data.time * framerate:
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)
print(f"Total rendering cost time {time.time() - start_time:.2f}")

media.show_video(frames, fps=framerate)

4.4 运行 GPU 示例

python 复制代码
jit_step = jax.jit(mjx.step)

frames = []
mujoco.mj_resetData(mj_model, mj_data)
mjx_data = mjx.put_data(mj_model, mj_data)
start_time = time.time()
while mjx_data.time < duration:
  mjx_data = jit_step(mjx_model, mjx_data)
  if len(frames) < mjx_data.time * framerate:
    mj_data = mjx.get_data(mj_model, mjx_data)
    renderer.update_scene(mj_data, scene_option=scene_option)
    pixels = renderer.render()
    frames.append(pixels)
print(f"Total rendering cost time {time.time() - start_time:.2f}")

media.show_video(frames, fps=framerate)

此时你可能会发现 GPU 耗时反而比 CPU 更长,这是因为对于单线程任务而言其效率提升不明显,MJX 的优势在与大规模并行运算。后续的主线博客会有更详细的介绍。

相关推荐
一洽客服系统7 分钟前
技术为器,服务为本:AI时代的客服价值重构
人工智能
jz_ddk1 小时前
[学习] C语言多维指针探讨(代码示例)
linux·c语言·开发语言·学习·算法
moonsims2 小时前
无人机桥梁3D建模的拍摄频率
人工智能
LaughingZhu4 小时前
PH热榜 | 2025-05-29
前端·人工智能·经验分享·搜索引擎·产品运营
视觉语言导航5 小时前
俄罗斯无人机自主任务规划!UAV-CodeAgents:基于多智能体ReAct和视觉语言推理的可扩展无人机任务规划
人工智能·深度学习·无人机·具身智能
世润5 小时前
深度学习-梯度消失和梯度爆炸
人工智能·深度学习
小彭律师5 小时前
LSTM+Transformer混合模型架构文档
人工智能·lstm·transformer
-曾牛5 小时前
使用Spring AI集成Perplexity AI实现智能对话(详细配置指南)
java·人工智能·后端·spring·llm·大模型应用·springai
艾莉丝努力练剑6 小时前
深入详解编译与链接:翻译环境和运行环境,翻译环境:预编译+编译+汇编+链接,运行环境
c语言·开发语言·汇编·学习
归去_来兮7 小时前
长短期记忆(LSTM)网络模型
人工智能·深度学习·lstm·时序模型