当大模型参数量突破千亿,推理成本成为落地的最大障碍。MX 格式通过"共享指数"这一简单而优雅的设计,在存储和计算两端同时实现了数倍的效率提升。本文从浮点数的基本原理出发,逐步拆解 MX 的设计思路、量化过程、硬件优势与精度代价。
1. 浮点数:一切的起点
计算机表示小数的方式和科学计数法如出一辙:
value=(−1)sign×2exponent−bias×(1+mantissa) value = (-1)^{sign} \times 2^{exponent - bias} \times (1 + mantissa) value=(−1)sign×2exponent−bias×(1+mantissa)
三个字段各司其职:
| 字段 | 作用 | 类比 |
|---|---|---|
| Sign (符号) | 正/负 | 正负号 |
| Exponent (指数) | 数量级 | 10^n 中的 n |
| Mantissa (尾数) | 有效数字 | 1.234 中的 234 |
指数位越多,能表示的范围越大;尾数位越多,精度越高。 这是一个零和博弈------总位宽固定时,范围和精度不可兼得。
2. FP16 与 BF16:同样 16 位,不同的取舍
FP16 (IEEE 754 Half Precision)
[1 sign] [5 exponent] [10 mantissa] = 16 bits
- 范围:±65504
- 精度:约 3.3 位十进制
- 特点:精度好,但范围太小。大模型训练中梯度动辄超过 65504,导致 overflow
BF16 (Brain Floating Point)
[1 sign] [8 exponent] [7 mantissa] = 16 bits
- 范围:与 FP32 相同(±3.4×10³⁸)
- 精度:约 2.1 位十进制
- 特点:Google Brain 团队的实用主义产物。8 位指数 = FP32 同款,天然不溢出;与 FP32 互转只需截断/补零低 16 位尾数
为什么 BF16 成为主流?
答案在于容错性。深度学习对个别数值的微小误差不敏感(网络本身就在做近似),但对溢出零容忍------一个 inf 就能毁掉整个 batch。BF16 用 3 位尾数精度换来了 FP32 级别的动态范围,这笔交易在 AI 场景下极其划算。
3. 再进一步:共享指数的思想
观察:相邻数据的指数高度冗余
在实际的权重矩阵和激活张量中,相邻元素的数量级往往接近。如果一组 32 个 BF16 数的指数分别是:
[5, 5, 6, 5, 5, 6, 5, 5, 6, 6, 5, 5, ...]
每个数都花 8 bits 存指数,32 个数就是 256 bits 的指数。但它们之间的差异很小------能不能只存一个代表性的指数?
MX 格式的核心设计
一组 32 个元素共享一个 8-bit 指数,每个元素只保留符号 + 低精度尾数。
传统 BF16 (32 个元素):
32 × [1S + 8E + 7M] = 32 × 16 = 512 bits
MXFP4 (32 个元素):
1 × [8-bit 共享指数] + 32 × [4-bit 尾数] = 8 + 128 = 136 bits
压缩比:512 / 136 = 3.76×
4. 量化过程详解
以一组简单的数为例:[0.75, -3.5, 12.0, -0.25]
Step 1:确定共享指数
将每个数写成 1.xxx × 2^E 的形式:
| 值 | 浮点表示 | 指数 |
|---|---|---|
| 0.75 | 1.5 × 2⁻¹ | -1 |
| -3.5 | -1.75 × 2¹ | 1 |
| 12.0 | 1.5 × 2³ | 3 (最大) |
| -0.25 | -1.0 × 2⁻² | -2 |
E_shared = 3,取最大指数以保证最大值不溢出。
Step 2:归一化
所有值除以 2^E_shared = 8:
| 原始值 | 归一化 | MXFP4 量化 |
|---|---|---|
| 0.75 | 0.094 | → 0 |
| -3.5 | -0.438 | → -0.5 |
| 12.0 | 1.5 | → 1.5 (精确) |
| -0.25 | -0.031 | → 0 |
Step 3:反量化验证
| 还原值 | 原始值 | 相对误差 |
|---|---|---|
| 0 | 0.75 | 100% |
| -4.0 | -3.5 | 14% |
| 12.0 | 12.0 | 0% |
| 0 | -0.25 | 100% |
结论明确:最大值完美保留,与最大值相差越远的小值,精度损失越严重。
这不是 bug,而是 feature------或者说,是一个有意为之的 trade-off。
5. MX 格式家族
MX 不是一种格式,而是 OCP (Open Compute Project) 标准化的一族格式:
| 格式 | 元素位宽 | 元素结构 | 有效位宽 | 典型场景 |
|---|---|---|---|---|
| MXFP8 | 8 bits | E4M3 / E5M2 | 8.25 | 训练 |
| MXFP6 | 6 bits | E3M2 / E2M3 | 6.25 | 微调 |
| MXFP4 | 4 bits | E2M1 | 4.25 | 推理 |
| MXINT8 | 8 bits | 有符号整数 | 8.25 | 推理部署 |
有效位宽 = 元素位宽 + 8/32 = 元素位宽 + 0.25(共享指数的均摊成本)
MXINT8:特殊但重要
MXINT8 中的每个元素是纯 INT8 整数,没有自己的指数。共享指数充当"缩放因子"(scale factor)。
value_i = 2^(E_shared - 127) × int8_i
这本质上就是业界广泛使用的 per-group dynamic quantization 的硬件标准化。PyTorch 中的 torch.quantize_per_channel 做的事和这几乎一样,只是 MX 把它固化到了硬件指令集中。
MXFP vs MXINT:均匀分布 vs 对数分布
MXINT8 数轴 (均匀分布):
──┼──┼──┼──┼──┼──┼──┼──┼──→
-128 -64 0 64 127
间距相同,大值小值精度一样
MXFP8 数轴 (对数分布):
──┼┼┼┼──┼─┼─┼───┼──┼──┼────→
0 0.1 0.5 1 4 448
小值处密集,大值处稀疏
如果数据分布均匀,MXINT8 更优;如果数据集中在零附近(大多数权重如此),MXFP8 更优。
6. 共享指数矩阵乘法:硬件的真正收益
大模型 90% 以上的计算量是矩阵乘法 (GEMM)。共享指数的最大价值不在存储压缩,而在于从根本上简化了乘法器硬件。
数学推导
矩阵乘法的核心是点积。对于一组 32 个元素的点积:
∑k=031Ak×Bk \sum_{k=0}^{31} A_k \times B_k k=0∑31Ak×Bk
传统浮点下,每次乘法都需要:指数相加 → 尾数相乘 → 对齐移位 → 累加。
共享指数下:
∑k=031(2EA⋅mAk)×(2EB⋅mBk)=2EA+EB×∑k=031mAk×mBk \sum_{k=0}^{31} (2^{E_A} \cdot m_{A_k}) \times (2^{E_B} \cdot m_{B_k}) = 2^{E_A + E_B} \times \sum_{k=0}^{31} m_{A_k} \times m_{B_k} k=0∑31(2EA⋅mAk)×(2EB⋅mBk)=2EA+EB×k=0∑31mAk×mBk
指数加法提到求和号外面,只执行 1 次。里面的 32 次乘加全部变成低精度整数运算。
硬件成本对比
| 传统 FP16 GEMM | MXFP4 GEMM | |
|---|---|---|
| 指数运算 | 每次乘法 1 次 (32 次) | 整组 1 次 |
| 乘法器 | 10-bit × 10-bit | 1-bit × 1-bit |
| 乘法器面积 | 基准 | ~1/100 |
| 累加器 | FP32 | INT 累加 + 最终转浮点 |
| 整体能效 | 基准 | 4-8× 提升 |
MXFP4 的 1-bit 尾数乘法本质上就是一个 AND 门------几乎不占面积。这意味着同样的芯片面积可以放下多得多的 MAC 单元。
7. 精度分析:BF16 转 MX 会丢什么?
两个损失来源
来源 1:尾数截断
BF16 有 7 位尾数,MXFP4 只有 1 位。即使指数完美匹配,7→1 的量化也不可避免地丢失信息。
来源 2:指数对齐导致的有效位右移
组内元素被迫使用统一的共享指数。对于指数较小的元素,其尾数需要右移 (E_shared - E_i) 位来补偿,右移出去的位就丢失了。
组内最大指数 E_shared = 5
元素 x (指数=3): 右移 2 位, 低 2 位丢失
元素 y (指数=-2): 右移 7 位, 几乎完全丢失
什么时候精度够用?
经验法则:组内动态范围(最大值/最小值)不超过尾数能表示的范围时,精度损失可控。
| 格式 | 尾数位 | 可表达动态范围 | 适用条件 |
|---|---|---|---|
| MXFP8 (E4M3) | 3 bits | ~16× | 大多数层 |
| MXFP6 (E2M3) | 3 bits | ~16× | 大多数层 |
| MXFP4 (E2M1) | 1 bit | ~3× | 分布均匀的层 |
Outlier 是最大的敌人。 一组中如果有一个异常大的值,共享指数被它绑架,其余 31 个元素的精度全部受损。这也是为什么 LLM 量化研究中,处理 activation outlier(如 SmoothQuant、GPTQ 等方法)是核心课题。
8. 工程实践中的注意事项
组大小的选择
组越小 → 共享指数越多 → 组内动态范围越小 → 精度越好,但指数存储开销越大。
| 组大小 | 指数均摊开销 | 精度 | 硬件复杂度 |
|---|---|---|---|
| 32 (标准) | 0.25 bits/元素 | 基准 | 基准 |
| 16 | 0.5 bits/元素 | 更好 | 更复杂 |
| 8 | 1 bit/元素 | 最好 | 最复杂 |
MX 标准选定 32 作为默认组大小,是精度和效率的平衡点。
混合精度策略
实践中几乎不会全网络使用同一种 MX 格式:
- 权重: MXFP4 (静态分布,量化友好)
- 激活: MXFP8 (动态分布,需要更多精度)
- 累加器: FP32 (避免误差累积)
- 关键层 (如 attention 的 softmax): 保持 BF16/FP16
与 CIM (Compute-In-Memory) 架构的天然适配
CIM 架构将权重固化在存储单元中,使用模拟或数字计算完成乘加。MX 格式在这个场景下优势更为突出:
- MXFP4 权重只需要 4 bits/cell,存储密度翻倍
- 共享指数可以在读出电路统一处理,不增加单元复杂度
- 低精度乘法天然匹配 CIM 的模拟计算精度限制
9. 总结
| 格式 | 核心理念 | 适用场景 |
|---|---|---|
| FP16 | 高精度,小范围 | 推理(非 LLM) |
| BF16 | 大范围,够用的精度 | 训练主流 |
| MXFP8 | 共享指数 + 8bit 元素 | 低精度训练/高精度推理 |
| MXFP4 | 共享指数 + 4bit 元素 | 极致推理压缩 |
| MXINT8 | 共享指数 + 整数元素 | 推理部署 |
从 FP32 → BF16 → MX,每一步都在用精度换效率。 关键在于找到"够用的精度"------神经网络的容错性让这条路走得比想象中更远。MX 格式的贡献在于,它把"共享指数"这个朴素的想法标准化了,让芯片设计者可以放心地围绕它构建硬件。
参考:OCP Microscaling Formats (MX) Specification v1.0, 2023