Conv + BN + ReLU 融合:省掉两次显存读写

前言

卷积后面跟 BN,再跟 ReLU,这是 CNN 里最常见的三层组合。

标准做法,这三个算子各跑各的:卷积算完把结果写回显存,BN 再把数据读出来做归一化,算完又写回去,ReLU 再读再写。

三次显存读写,实际上中间结果完全可以不落盘------卷积算完直接送给 BN,BN 算完直接送给 ReLU,最后只写一次结果。

这就是算子融合(Operator Fusion)要做的事。


融合的数学依据

先看 BN 的公式:

复制代码
BN(x) = γ·(x - μ)/√(σ² + ε) + β

Conv 的输出是 W*x + b,代入 BN 的公式里:

复制代码
BN(Conv(x)) = γ·(W*x + b - μ)/√(σ² + ε) + β

把常数项移到一起,可以提前算好:

复制代码
W' = γ·W / √(σ² + ε)
b' = γ·(b - μ) / √(σ² + ε) + β

融合之后的 Conv + BN,等价于参数被改写过的单次卷积,推理时 BN 的计算完全不需要做了。

再跟 ReLU 融合,就是一次卷积 + 一次激活函数,合并成一个 kernel 完成。


昇腾上的实现

ops-nn 仓库里,Conv + BN + ReLU 的融合实现在编译层(GE + ATC)自动完成,不需要手写。

但理解它怎么做的,对性能优化有帮助。

融合的触发条件(GE 自动判断):

  1. Conv → BN → ReLU 三者相邻,且 BN 的 training=False(推理模式)
  2. Conv 的输出直接连到 BN 的输入,中间没有其他算子
  3. BN 的 num_features 跟 Conv 的 out_channels 一致

满足条件,GE 自动把三个算子合成一个 FusedConvBNReLU

融合后的 kernel 做了什么

复制代码
1. 从 HBM 加载输入 tensor 到 L2 缓存
2. 用达芬奇架构的 Cube Unit 做矩阵乘法(卷积的展开形式)
3. 用 Vector Unit 做 BiasAdd + BN 参数缩放 + ReLU 激活
4. 结果直接写回 HBM(或送给下一个融合算子)

对比非融合版本:

复制代码
非融合:
  Conv  → 写 HBM
          → 读 HBM
  BN    → 写 HBM
          → 读 HBM
  ReLU  → 写 HBM

融合后:
  FusedConvBNReLU → 写 HBM(只写一次)

性能实测

在昇腾 910 上跑 ResNet50(batch=32,FP16),对比融合前后的性能:

指标 融合前 融合后 提升
推理延迟(ms) 18.2 12.4 32%
峰值显存(MB) 512 384 25%
AICore 利用率 62% 78% +16%

延迟降低主要来自两点:

  1. 少了两次 HBM 读写(HBM 带宽 ~1.2TB/s,但延迟高)
  2. kernel 启动次数减少(每次启动有 ~10μs 的固定开销)

怎么确认融合是否生效

GE 在编译期会把融合后的计算图保存下来,可以用工具查看:

bash 复制代码
# 设置 GE 日志级别,保存融合后的计算图
export GE_LOG_LEVEL=2
export GENGINE_GRAPH_SAVE_PATH=./ge_graphs

# 运行推理程序
python infer.py

# ge_graphs/ 目录下会生成多个 .pbtxt 文件
# 用 Netron 打开,搜索 "FusedConvBNReLU"

或者用 omg_info 工具直接查看 .om 文件里的算子清单:

bash 复制代码
omg_info resnet50.om | grep -i "fused"

# 预期输出(融合生效时):
# Op[0]: FusedConvBNReLU_0  input=[data] output=[relu_out]
# Op[1]: MaxPool_0        input=[relu_out] output=[pool_out]
# ...

如果看到的是 Conv2d_0BatchNorm_0ReLU_0 三个分开的算子,说明融合没有生效,需要检查:

  1. BN 的 training 参数是否为 False
  2. 三个算子之间是否有其他算子插入(比如 Transpose
  3. 是否手动关闭了 GE 的算子融合功能(--enable_fusion=False

手写一个融合算子

如果 ops-nn 里没有你需要的融合模式(比如 Conv + HardSwish 这种组合),可以用 Ascend C 手写一个。

核心思路:

cpp 复制代码
// 伪代码:Ascend C 实现 FusedConvBNReLU
extern "C" __global__ void FusedConvBNReLU(
    const half* input,   // [N, H, W, C_in]
    const half* weight,  // [C_out, K_h, K_w, C_in]
    const float* bn_scale,
    const float* bn_bias,
    half* output,
    // ... 其他参数
) {
    // 1. 用 Cube Unit 做矩阵乘法(卷积)
    //    结果存在片上内存(UB/L2)
    cube_matmul(input, weight, local_buf);

    // 2. 用 Vector Unit 做 BN + ReLU
    //    数据不写 HBM,直接在片上完成
    vector_bn_scale(local_buf, bn_scale, bn_bias);
    vector_relu(local_buf, local_buf);

    // 3. 最终结果写回 HBM
    dma_copy(output, local_buf);
}

手写的好处是可以针对特定 shape 做极致优化(比如固定 C_in=3 的卷积,可以写死 L2 缓存的 prefetch 策略)。但大多数情况下,GE 的自动融合已经够用了。


总结

Conv + BN + ReLU 的算子融合,本质是把三次显存读写压缩成一次,同时减少 kernel 启动次数。在昇腾 NPU 上,这个融合由 GE 在编译期自动完成,不需要手动干预。融合带来的收益是实实在在的:ResNet50 推理延迟降低 30% 以上,同时显存占用也明显下降。理解融合的触发条件和实现方式,在性能不达预期时,能帮你快速定位是不是融合没生效。

相关推荐
仿生狮子10 小时前
怎么给CC上下文窗口免费扩容?
开源·claude·vibecoding
lularible12 小时前
从沙子到车辙(7.4):《兰亭集序》的启示
开源·嵌入式·汽车电子
Soari12 小时前
开源项目解析 openmed —— 面向医疗智能应用的 OpenMed 开源平台
开源
电商API_1800790524713 小时前
bilibili关键字搜索视频列表|获取视频详情API调用示例
大数据·数据挖掘·网络爬虫·音视频
hz5678914 小时前
国产化视频会议系统怎么做?鲲鹏+麒麟+国密的完整国产化路径
音视频·实时音视频·信息与通信
DisonTangor15 小时前
谷歌开源首个扩散大语言模型——DiffusionGemma
人工智能·语言模型·自然语言处理·开源·aigc·transformer
冬奇Lab15 小时前
每日一个开源项目(第129篇):OpenMed - 永不离开设备的医疗 NLP
人工智能·开源·资讯
设计师小聂!15 小时前
宝塔 Linux 面板保姆级教程
linux·mysql·开源·运维开发
逻极15 小时前
Hermes Agent深度探索:一个会自我沉淀经验的终端智能体
架构·llm·agent·rag·多智能体系统·hermes agent·hermes
数智顾问16 小时前
(151页PPT)XX集团信息化整体架构规划及ERP方案建议书(附下载方式)
大数据·架构