本文基于CANN开源社区的shmem仓库进行技术解读
CANN组织地址:https://atomgit.com/cann
shmem仓库地址:https://atomgit.com/cann/shmem
前言
多机多卡训练时,不同设备之间需要频繁交换数据。传统的消息传递方式(如MPI)需要发送方和接收方都参与,效率不高。
SHMEM(Shared Memory)提供了一种更高效的方式------单边通信,发送方可以直接访问接收方的内存,无需接收方主动配合。
什么是SHMEM
SHMEM是基于OpenSHMEM标准的共享内存通信库:
传统双边通信(MPI):
进程A发送 → 网络 → 进程B接收
(需要双方配合)
SHMEM单边通信:
进程A直接写入进程B的内存
(只需发送方操作)
架构:
应用程序
↓
SHMEM API
↓
RDMA/RoCE(远程直接内存访问)
↓
网络硬件
↓
远程设备内存
核心概念
1. PE(Processing Element)
SHMEM中的处理单元,通常对应一个进程或设备:
python
import shmem
# 初始化SHMEM
shmem.init()
# 获取PE信息
my_pe = shmem.my_pe() # 我是第几号PE
n_pes = shmem.n_pes() # 总共多少个PE
print(f"我是PE {my_pe},共有{n_pes}个PE")
# 清理
shmem.finalize()
2. 对称内存
所有PE都能访问的内存区域:
python
# 分配对称内存
# 所有PE都会分配相同大小的内存
data = shmem.malloc(1024 * sizeof(float))
# 每个PE都可以访问其他PE的这块内存
3. 单边操作
不需要接收方参与的操作:
PE 0: shmem_put(data, target_pe=1) # 直接写入PE 1的内存
PE 1: (不需要任何操作,数据自动到达)
核心API
1. Put操作(写)
将本地数据写到远程PE:
c
// C语言示例
#include <shmem.h>
int main() {
shmem_init();
int me = shmem_my_pe();
int npes = shmem_n_pes();
// 分配对称内存
float *data = shmem_malloc(100 * sizeof(float));
// 初始化数据
for (int i = 0; i < 100; i++) {
data[i] = me * 100 + i;
}
// PE 0 将数据写入 PE 1
if (me == 0) {
float local_data[100];
for (int i = 0; i < 100; i++) {
local_data[i] = i * 2.0f;
}
shmem_float_put(data, local_data, 100, 1); // 写入PE 1
}
// 同步
shmem_barrier_all();
// PE 1 检查数据
if (me == 1) {
printf("PE 1 received: %f, %f, ...\n", data[0], data[1]);
}
shmem_free(data);
shmem_finalize();
return 0;
}
2. Get操作(读)
从远程PE读取数据:
c
// PE 0 从 PE 1 读取数据
if (me == 0) {
float remote_data[100];
shmem_float_get(remote_data, data, 100, 1); // 从PE 1读取
printf("Read from PE 1: %f, %f, ...\n", remote_data[0], remote_data[1]);
}
3. 原子操作
保证并发安全的操作:
c
// 原子加法
long *counter = shmem_malloc(sizeof(long));
*counter = 0;
shmem_barrier_all();
// 所有PE原子地增加counter
shmem_long_atomic_add(counter, 1, 0); // 在PE 0的counter上加1
shmem_barrier_all();
// PE 0 查看结果
if (me == 0) {
printf("Counter = %ld (should be %d)\n", *counter, npes);
}
常用原子操作:
c
// 原子加
shmem_int_atomic_add(&target, value, pe);
// 原子比较并交换
old = shmem_int_atomic_compare_swap(&target, cond, value, pe);
// 原子交换
old = shmem_int_atomic_swap(&target, value, pe);
// 原子取值并加
old = shmem_int_atomic_fetch_add(&target, value, pe);
4. 同步操作
c
// 全局屏障:所有PE都到达这个点才继续
shmem_barrier_all();
// 等待之前的操作完成
shmem_quiet();
// 等待特定PE的操作完成
shmem_fence();
5. 集合操作
c
// Broadcast:广播数据
long *data = shmem_malloc(sizeof(long));
if (me == 0) {
*data = 42;
}
shmem_broadcast64(data, data, 1, 0, 0, 0, npes, pSync);
// Reduction:规约操作
long *sum = shmem_malloc(sizeof(long));
*sum = me;
shmem_long_sum_to_all(sum, sum, 1, 0, 0, npes, pWrk, pSync);
// 结果:sum = 0 + 1 + 2 + ... + (npes-1)
// Collect:收集数据
long *collected = shmem_malloc(npes * sizeof(long));
long local_value = me * 10;
shmem_collect64(collected, &local_value, 1, 0, 0, npes, pSync);
使用场景
场景一:大模型训练
多机多卡训练时交换梯度和参数:
c
// 伪代码
void sync_gradients() {
int me = shmem_my_pe();
int npes = shmem_n_pes();
// 每个PE计算自己的梯度
float *local_grad = compute_gradients();
// 使用SHMEM进行AllReduce
float *global_grad = shmem_malloc(grad_size * sizeof(float));
// 方法1:使用集合操作
shmem_float_sum_to_all(global_grad, local_grad, grad_size,
0, 0, npes, pWrk, pSync);
// 方法2:手动实现Ring AllReduce
int next_pe = (me + 1) % npes;
int prev_pe = (me - 1 + npes) % npes;
for (int step = 0; step < npes - 1; step++) {
// 发送给下一个PE
shmem_float_put(recv_buffer, send_buffer, chunk_size, next_pe);
shmem_quiet();
// 累加接收到的数据
for (int i = 0; i < chunk_size; i++) {
send_buffer[i] += recv_buffer[i];
}
}
}
场景二:分布式推理
模型参数分布在多个设备上:
c
// 设备0有模型的前半部分
// 设备1有模型的后半部分
if (me == 0) {
// 前向传播前半部分
float *intermediate = forward_part1(input);
// 将中间结果发送给设备1
shmem_float_put(remote_input, intermediate, size, 1);
shmem_quiet();
}
if (me == 1) {
// 等待设备0的数据
shmem_barrier_all();
// 前向传播后半部分
float *output = forward_part2(remote_input);
}
场景三:参数服务器
实现分布式参数服务器:
c
// PE 0 作为参数服务器
if (me == 0) {
float *params = shmem_malloc(param_size * sizeof(float));
initialize_params(params);
// 等待worker更新
while (training) {
shmem_barrier_all();
// 参数已被worker更新
}
}
// 其他PE作为worker
if (me > 0) {
// 读取参数
float *local_params = malloc(param_size * sizeof(float));
shmem_float_get(local_params, params, param_size, 0);
// 计算梯度
float *grads = compute_gradients(local_params);
// 更新参数(原子操作)
for (int i = 0; i < param_size; i++) {
shmem_float_atomic_add(¶ms[i], -learning_rate * grads[i], 0);
}
shmem_barrier_all();
}
性能优化
1. 批量操作
c
// 不好:多次小传输
for (int i = 0; i < 1000; i++) {
shmem_float_put(&remote[i], &local[i], 1, target_pe);
}
// 好:一次大传输
shmem_float_put(remote, local, 1000, target_pe);
2. 非阻塞操作
c
// 使用非阻塞操作重叠通信和计算
shmem_float_put_nbi(remote, local, size, target_pe);
// 在通信进行时做其他计算
do_computation();
// 等待通信完成
shmem_quiet();
3. 内存对齐
c
// 对齐到缓存行,提升性能
#define CACHE_LINE_SIZE 64
float *data = shmem_align(CACHE_LINE_SIZE, size * sizeof(float));
与HCCL的对比
| 特性 | SHMEM | HCCL |
|---|---|---|
| 通信模式 | 单边通信 | 集合通信 |
| 编程模型 | 显式内存操作 | 隐式同步 |
| 灵活性 | 高 | 中 |
| 易用性 | 中 | 高 |
| 适用场景 | 细粒度控制 | 标准并行模式 |
使用建议:
HCCL:
- 数据并行训练
- 标准的AllReduce/Broadcast
- 快速开发
SHMEM:
- 需要细粒度控制
- 自定义通信模式
- 性能极致优化
调试技巧
1. 检查对称性
c
// 确保所有PE分配相同大小的对称内存
void *ptr = shmem_malloc(size);
if (ptr == NULL) {
fprintf(stderr, "PE %d: shmem_malloc failed\n", shmem_my_pe());
shmem_global_exit(1);
}
2. 同步检查
c
// 在关键点添加屏障,确保同步
shmem_barrier_all();
// 检查数据一致性
if (me == 0) {
for (int pe = 1; pe < npes; pe++) {
float remote_val;
shmem_float_get(&remote_val, &data[0], 1, pe);
printf("PE %d data[0] = %f\n", pe, remote_val);
}
}
3. 性能测试
c
#include <sys/time.h>
double get_time() {
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec + tv.tv_usec * 1e-6;
}
// 测试带宽
double start = get_time();
for (int i = 0; i < iterations; i++) {
shmem_float_put(remote, local, size, target_pe);
shmem_quiet();
}
double end = get_time();
double bandwidth = (size * sizeof(float) * iterations) / (end - start) / 1e9;
printf("Bandwidth: %.2f GB/s\n", bandwidth);
常见问题
问题1:内存不对称
c
// 错误:不同PE分配不同大小
if (me == 0) {
data = shmem_malloc(1000 * sizeof(float));
} else {
data = shmem_malloc(500 * sizeof(float)); // 错误!
}
// 正确:所有PE分配相同大小
data = shmem_malloc(1000 * sizeof(float));
问题2:缺少同步
c
// 错误:没有同步就访问
shmem_float_put(remote, local, size, target_pe);
// 立即访问remote(可能还没完成)
// 正确:等待完成
shmem_float_put(remote, local, size, target_pe);
shmem_quiet(); // 等待完成
问题3:死锁
c
// 错误:循环依赖
if (me == 0) {
shmem_float_get(data, remote, size, 1);
}
if (me == 1) {
shmem_float_get(data, remote, size, 0);
}
// 两个PE互相等待,死锁
// 正确:避免循环依赖
if (me == 0) {
shmem_float_put(remote, data, size, 1);
}
shmem_barrier_all();
应用场景总结
场景一:高性能计算
科学计算中的大规模并行。
场景二:深度学习
多机多卡训练的底层通信。
场景三:图计算
分布式图处理。
场景四:数据分析
大规模数据并行处理。
总结
SHMEM是基于OpenSHMEM的共享内存通信库:
- 单边通信模式,效率高
- 支持Put/Get/原子操作
- 提供同步和集合操作
- 适合需要细粒度控制的场景
- 是HCCL的底层支撑
对于需要深入理解分布式通信的开发者,SHMEM是重要的工具。
相关链接
shmem仓库地址:https://atomgit.com/cann/shmem
CANN组织地址:https://atomgit.com/cann
hccl仓库地址:https://atomgit.com/cann/hccl
OpenSHMEM官网:http://openshmem.org