使用 Docker 容器运行 Grok-1| 提供可用的镜像

概述

最近源神开源了 Grok-1 大模型,想着跑起来看看是什么样子。Grok 的 GitHub 里写的非常清楚了,首先 clone 代码,然后下载模型(大概 300 个 G),然后执行:

shell 复制代码
pip install -r requirements.txt
python run.py

听起来很简单,就像把大象塞进冰箱需要几步一样。但是实际上模型要依赖 jax、jaxlib,这俩对环境要求还是比较苛刻的,所以尝试在服务器上运行了一下,各种报错,无奈只能使用容器一个个环境的尝试,最后成功构建出一个可以运行的镜像(下面会展示宿主机和容器的环境)。这个镜像是适用于我们的环境的,在别的环境下不知道能否正常运行,所以欢迎你使用后给出一点反馈。

我做了什么

首先模型文件非常大,不适合每次都 docker cp 进基础环境的容器中,而且如果这个容器经过调试后可用,那么 commit 时也会把模型顺带着保存,那么这个镜像的体积可就太大了。所以模型文件,使用 -v 挂载进容器的 /root 下。

而代码比较小,大概 900MB,调试中免不了要修改一些代码,并且这些是希望调试好后直接保存进容器的,所以我将程序代码通过 docker cp 复制到了容器里,并且提交的镜像里也有,方便你直接使用。

然后就是安装各种环境,遇到一个报错解决一个。

GitHub

项目地址:github.com/mayooot/gro...

欢迎 ✨

快速启动

首先拉取镜像,大概 8 个 G。

shell 复制代码
docker pull mayooot/grok-docker:v1

然后要将下载的模型文件 ckpt-0 目录挂载进容器,下载教程可以参考这篇文章:Grok-1 本地部署过程

最后启动容器。

  • 注意要将 $your-dir/ckpt-0 替换成你的实际模型地址。
  • 共享内存设置为了 600g,应该是够用的,如果不够,请自行调整。
  • 要跑起来模型大概需要 8 张 A800/A100。所以这里使用 --gpus all 将所有 gpu 挂载进去。
shell 复制代码
docker run -d -it \
--network=host \
--shm-size 600g \
--name=grok-docker \
--gpus all \
-v $your-dir/ckpt-0:/root/ckpt-0 \
mayooot/grok-docker:v1

训练

程序代码已经存在于容器中,并且修改了模型的加载路径,所以只要你正确的把 ckpt-0 挂载进容器,那么直接执行下面代码,然后等待结果。

shell 复制代码
docker exec -it grok-docker bash
cd /root/grok-1/
python run.py

运行结果:

环境

宿主机环境

  • OS: Ubuntu 20.04.4
  • Physical Storage: 1TB
  • Physical Memory: 2TB
  • GPU: 8 * NVIDIA A100 80GB
  • Docker: 24.0.5
  • Nvidia Driver: 525.85.12

容器环境

yaml 复制代码
$ cat /etc/issue
Ubuntu 22.04.1 LTS \n \l

$ python --version
Python 3.10.8

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

$ pip show jax
Name: jax
Version: 0.4.26
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /root/miniconda3/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, flax, optax, orbax-checkpoint

$ pip show jaxlib
Name: jaxlib
Version: 0.4.26+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /root/miniconda3/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, optax, orbax-checkpoint
相关推荐
东胜物联几秒前
探寻5G工业网关市场,5G工业网关品牌解析
人工智能·嵌入式硬件·5g
皓74111 分钟前
服饰电商行业知识管理的创新实践与知识中台的重要性
大数据·人工智能·科技·数据分析·零售
川石课堂软件测试41 分钟前
性能测试|docker容器下搭建JMeter+Grafana+Influxdb监控可视化平台
运维·javascript·深度学习·jmeter·docker·容器·grafana
985小水博一枚呀1 小时前
【深度学习滑坡制图|论文解读3】基于融合CNN-Transformer网络和深度迁移学习的遥感影像滑坡制图方法
人工智能·深度学习·神经网络·cnn·transformer
AltmanChan1 小时前
大语言模型安全威胁
人工智能·安全·语言模型
985小水博一枚呀1 小时前
【深度学习滑坡制图|论文解读2】基于融合CNN-Transformer网络和深度迁移学习的遥感影像滑坡制图方法
人工智能·深度学习·神经网络·cnn·transformer·迁移学习
数据与后端架构提升之路1 小时前
从神经元到神经网络:深度学习的进化之旅
人工智能·神经网络·学习
爱技术的小伙子1 小时前
【ChatGPT】如何通过逐步提示提高ChatGPT的细节描写
人工智能·chatgpt
深度学习实战训练营2 小时前
基于CNN-RNN的影像报告生成
人工智能·深度学习
昨日之日20064 小时前
Moonshine - 新型开源ASR(语音识别)模型,体积小,速度快,比OpenAI Whisper快五倍 本地一键整合包下载
人工智能·whisper·语音识别