Relay算子注册(在pytorch.py端调用)

1. Relay算子注册 (C++层)

(a) 算子属性注册

路径 : src/relay/op/nn/nn.cc

cpp 复制代码
RELAY_REGISTER_OP("hardswish")
  .set_num_inputs(1)
  .add_argument("data", "Tensor", "Input tensor.")
  .set_support_level(3)
  .add_type_rel("Identity", Identity);
(b) 调用节点构造

路径 : src/relay/op/nn/activation.cc

cpp 复制代码
TVM_REGISTER_GLOBAL("relay.op._make.hardswish")
  .set_body_typed([](Expr data) {
    static const Op& op = Op::Get("hardswish");
    return Call(op, {data}, Attrs(), {});
  });

2. TOPI计算实现 (C++层)

© TOPI注册入口

路径 : src/topi/elemwise.cc

cpp 复制代码
TVM_REGISTER_GLOBAL("topi.hardswish")
  .set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = hardswish(args[0]);
  });
(d) 数学内核实现

路径 : include/tvm/topi/nn.h

cpp 复制代码
inline Tensor hardswish(const Tensor& x, std::string name = "T_hardswish") {
  auto three = make_const(x->dtype, 3);
  auto six = make_const(x->dtype, 6);
  return compute(
    x->shape,
    [&](const Array<Var>& i) {
      return x(i) * max(min(x(i) + three, six), 0) / six;
    },
    name, kElementWise
  );
}

3. Python接口层

(e) Relay Python API

路径 : python/tvm/relay/op/nn/_nn.py

python 复制代码
def hardswish(data):
    return _make.hardswish(data)
(f) TOPI通用接口

路径 : python/tvm/topi/nn.py

python 复制代码
@tvm.target.generic_func
def hardswish(x):
    return cpp.hardswish(x)

4. 计算调度注册

(g) Compute注册

路径 : python/tvm/relay/op/strategy/generic.py

python 复制代码
@register_compute("hardswish")
def hardswish_compute(attrs, inputs, out_type):
    return [topi.hardswish(inputs[0])]
(h) 调度策略

路径: `python/tvm/relay/op/op.py**

python 复制代码
register_broadcast_schedule("hardswish")
register_shape_func("hardswish", False, elemwise_shape_func)

5. 硬件专用实现

(i) NPU支持声明

路径: `src/relay/backend/contrib/npu/src/op_map.cc**

cpp 复制代码
const std::vector<std::string> _NPU_OP = {
  ...,
  "hardswish"  // 添加算子名
};
(j) NPU内核实现

路径: `python/tvm/relay/backend/contrib/npu/ops.py**

python 复制代码
def custom_hardswish(x):
    x1 = custom_add(x, te.extern_scalar_value(3.0))
    x2 = custom_relu(x1)
    return npu_hardwish(x2, ...)
(k) NPU策略注册

路径: `python/tvm/relay/op/strategy/npu.py**

python 复制代码
@hardswish.register("npu")
def hardswish_npu(x):
    return npu_api.custom_hardswish(x)

6. 前端框架对接

(l) PyTorch转换器

路径: `python/tvm/relay/frontend/pytorch.py**

python 复制代码
def _hardswish():
    def _impl(inputs, input_types):
        return _op.hardswish(inputs[0])
    return _impl

关键文件路径总结

功能模块 关键路径
Relay核心注册 src/relay/op/nn/{nn.cc, activation.cc}
TOPI计算 {include,src}/topi/{nn.h, elemwise.cc}
Python接口 python/tvm/{relay/op/nn/_nn.py, topi/nn.py}
策略注册 python/tvm/relay/op/strategy/{generic.py, npu.py}
硬件后端 src/relay/backend/contrib/npu/
前端对接 python/tvm/relay/frontend/pytorch.py

开发流程示意图

Relay注册 TOPI实现 Python接口 硬件后端 前端框架

通过这种清晰的路径划分,TVM实现了:

  1. 模块化开发:各层级代码物理隔离
  2. 可扩展性:新增硬件只需在对应目录添加实现
  3. 维护性:相关功能的代码集中存放
相关推荐
历程里程碑2 分钟前
Linux 18 进程控制
linux·运维·服务器·开发语言·数据结构·c++·笔记
AI科技3 分钟前
原创音乐人提升写歌数量,AI编曲软件实现创作周期大幅缩短
人工智能
亲爱的非洲野猪4 分钟前
从约束到互联:LLM生态中Rules、Tools、Skills与MCP的演进史
人工智能
jay神4 分钟前
基于MobileNet花卉识别系统
人工智能·深度学习·计算机视觉·毕业设计·花卉识别
云卓SKYDROID5 分钟前
无人机故障诊断技术模块要点!
人工智能·无人机·高科技·云卓科技·故障模块
m0_603888715 分钟前
VEQ Modality-Adaptive Quantization for MoE Vision-Language Models
人工智能·ai·语言模型·自然语言处理·论文速览
智驱力人工智能6 分钟前
无人机目标检测 低空安全治理的工程实践与价值闭环 无人机缺陷识别 农业无人机作物长势分析系统 森林防火无人机火点实时识别
人工智能·opencv·安全·yolo·目标检测·无人机·边缘计算
zhangfeng11337 分钟前
大语言模型llm 量化模型 跑在 边缘设备小显存显卡 GGUF GGML PyTorch (.pth, .bin, SafeTensors)
人工智能·pytorch·深度学习·语言模型
纤纡.7 分钟前
深度学习环境搭建:CUDA+PyTorch+TorchVision+Torchaudio 一站式安装教程
人工智能·pytorch·深度学习
方见华Richard8 分钟前
《认知几何学:思维如何弯曲意义空间》补充材料
人工智能·经验分享·交互·原型模式·空间计算