《Relay IR的基石:expr.h 中的表达式类型系统剖析》

TVM Relay源码深度解读

文章目录

  • [TVM Relay源码深度解读](#TVM Relay源码深度解读)
  • [一 、从Constant看Relay表达式的设计哲学](#一 、从Constant看Relay表达式的设计哲学)
    • [1. 类定义概述](#1. 类定义概述)
    • [2. `ConstantNode` 详解](#2. ConstantNode 详解)
      • [1. 核心成员](#1. 核心成员)
      • [2. 关键方法](#2. 关键方法)
      • [3. 类型系统注册](#3. 类型系统注册)
    • [3. `Constant` 详解](#3. Constant 详解)
      • [1. 核心功能](#1. 核心功能)
  • [二. 核心内容概述](#二. 核心内容概述)
    • [(1) Relay表达式基类](#(1) Relay表达式基类)
    • [(2) 具体表达式类型](#(2) 具体表达式类型)
      • [1. 表达式类型 VarNode举例子](#1. 表达式类型 VarNode举例子)
        • [1. 核心设计理念](#1. 核心设计理念)
        • [2. 关键成员解析](#2. 关键成员解析)
          • [(1) 核心字段](#(1) 核心字段)
          • [(2) 特殊方法](#(2) 特殊方法)
        • [3. 变量标识系统](#3. 变量标识系统)
          • [(1) vid (Unique ID)](#(1) vid (Unique ID))
          • [(2) name_hint 与 vid 的关系](#(2) name_hint 与 vid 的关系)
        • [4. 类型系统整合](#4. 类型系统整合)
          • [(1) 类型注解流程](#(1) 类型注解流程)
          • [(2) 类型推导规则](#(2) 类型推导规则)
        • [5. 内存模型与跨语言交互](#5. 内存模型与跨语言交互)
          • [(1) C++ 层构造](#(1) C++ 层构造)
          • [(2) Python 绑定](#(2) Python 绑定)
          • [**(3) 对象生命周期**](#(3) 对象生命周期)
        • [6. 关键应用场景](#6. 关键应用场景)
          • [(1) 函数参数定义](#(1) 函数参数定义)
          • [(2) 优化 Pass 中的变量处理](#(2) 优化 Pass 中的变量处理)
          • [(3) 类型检查](#(3) 类型检查)
        • [7. 设计亮点总结](#7. 设计亮点总结)
        • [8. 典型问题分析](#8. 典型问题分析)
    • [(3) TVM_DECLARE_BASE_OBJECT_INFO 宏详解](#(3) TVM_DECLARE_BASE_OBJECT_INFO 宏详解)
      • [1. 宏的参数](#1. 宏的参数)
      • [2. 静态断言检查(防止非法继承)](#2. 静态断言检查(防止非法继承))
      • [2. 运行时类型索引(RuntimeTypeIndex)](#2. 运行时类型索引(RuntimeTypeIndex))
      • [3. 动态分配类型索引(_GetOrAllocRuntimeTypeIndex)](#3. 动态分配类型索引(_GetOrAllocRuntimeTypeIndex))
      • 通俗版解释:TVM的类型身份证系统
        • [1. 为什么要办身份证?](#1. 为什么要办身份证?)
        • [2. 办证过程(宏的作用)](#2. 办证过程(宏的作用))
        • [3. 特殊班级(FINAL版)](#3. 特殊班级(FINAL版))
        • [4. 实际有什么用?](#4. 实际有什么用?)
        • 举个栗子🌰
        • 一句话总结
      • [(4) 遍历接口](#(4) 遍历接口)
        • [1. C++ 场景示例](#1. C++ 场景示例)
          • [(1) 模型序列化(保存为JSON)](#(1) 模型序列化(保存为JSON))
          • [(2) 优化Pass中的常量修改](#(2) 优化Pass中的常量修改)
          • [(3) 调试打印](#(3) 调试打印)
        • [2. Python 场景示例](#2. Python 场景示例)
          • [(1) 直接属性访问](#(1) 直接属性访问)
          • [(2) 模型保存与加载](#(2) 模型保存与加载)
          • [(3) 自定义属性访问器](#(3) 自定义属性访问器)

一 、从Constant看Relay表达式的设计哲学

在TVM的Relay IR中,即使是看似简单的常量表达式relay.const(1),其背后也隐藏着整个类型系统的精妙设计。让我们从include/tvm/relay/expr.h中的Constant类入手,逐步拆解..."

1. 类定义概述

类名 继承关系 角色 关键特性
ConstantNode public ExprNode 常量表达式的实际数据存储 包含常量数据(NDArray)、类型信息,并实现属性访问、哈希和相等比较逻辑。
Constant public RelayExpr 常量表达式的智能指针封装 提供用户友好的构造函数和访问方法,隐藏内存管理细节。

2. ConstantNode 详解

cpp 复制代码
class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;

  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;

  /*! \return Whether it is scalar(rank-0 tensor) */
  bool is_scalar() const { return data->ndim == 0; }

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("mdata", &mdata);
    v->Visit("_checked_type_", &checked_type_);
  }

  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
    return equal(data, other->data);
  }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }

  static constexpr const char* _type_key = "relay.Constant";
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};

1. 核心成员

  • data (runtime::NDArray)

    • 存储常量张量的实际数据(如权重、偏置等),TVM 使用 NDArray 统一表示多维数组。
    • 示例:卷积层的权重矩阵会被存储在这里。
  • tensor_type()

    • 根据 data 的维度(shape)和数据类型(dtype)自动生成对应的 TensorType
    • 用途:类型推断时确定常量的类型。
  • is_scalar()

    • 判断常量是否为标量(0维张量),如 data->ndim == 0

2. 关键方法

  • VisitAttrs

    • 实现属性的序列化/反序列化,支持以下字段:

      cpp 复制代码
      v->Visit("data", &data);          // 张量数据
      v->Visit("span", &span);         // 源码位置信息
      v->Visit("mdata", &mdata);       // 元数据(如调试信息)
      v->Visit("_checked_type_", &checked_type_);  // 类型检查后的类型
  • SEqualReduceSHashReduce

    • 结构化相等比较 :比较两个 ConstantNodedata 是否相同(用于优化中的常量折叠)。
    • 哈希计算 :基于 data 生成哈希值(用于快速查找重复常量)。

3. 类型系统注册

cpp 复制代码
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
  • _type_key = "relay.Constant":唯一标识常量节点类型。
  • FINAL:禁止继承,确保常量节点的行为不可被修改。

3. Constant 详解

cpp 复制代码
class Constant : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param data The data of the constant tensor.
   * \param span The source span of the expression.
   */
  TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());

  TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};

1. 核心功能

  • 构造函数

    cpp 复制代码
    explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());
    • 接收 NDArray 数据,构造一个常量表达式。

    • 示例

      python 复制代码
      # Python 前端等价代码
      data = np.array([1, 2, 3], dtype="float32")
      const_expr = relay.Constant(tvm.nd.array(data))
  • 智能指针方法

    cpp 复制代码
    TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);

    展开后提供:

    • operator->():直接访问 ConstantNode 成员(如 const_expr->data)。
    • get():获取底层 ConstantNode 指针。
    • 自动内存管理(通过 ObjectRef 的引用计数)。

二. 核心内容概述

在TVM源码中,include/tvm/relay/expr.hRelay IR(中间表示)的核心头文件 ,定义了所有Relay表达式的基础数据结构和类型系统。它是实现TVM高层计算图表示的关键组成部分。以下是该文件的详细解析:
相关重要文件

文件路径 关联内容
include/tvm/relay/type.h 类型系统(TensorType等)
include/tvm/relay/op.h 运算符定义
include/tvm/relay/adt.h 代数数据类型支持
src/relay/ir/expr.cc 表达式方法的实现

include/tvm/relay/expr.h文件主要包含:

  • (1) Relay表达式基类RelayExpr/RelayExprNode
  • (2) 所有具体表达式类型的声明(如变量、常量、函数调用等)
  • (3) 表达式类型的遍历和转换接口
  • (4) 类型系统和属性访问的支持

(1) Relay表达式基类

cpp 复制代码
class RelayExprNode : public BaseExprNode { /*...*/ };
class RelayExpr : public BaseExpr { /*...*/ };
  • 角色:所有Relay表达式的公共基类
  • 功能
    • 提供类型系统支持(通过checked_type_字段)
    • 实现属性访问(VisitAttrs
    • 支持结构化相等比较(SEqualReduce

1. RelayExprNode 和 RelayExpr 的区别与用法

RelayExprNode 是 Relay 表达式的实际实现类,是一个 C++ 类,包含了表达式的所有数据和功能实现。它是所有 Relay 表达式类型的基类。

RelayExpr 是一个智能指针(relay::Expr),它指向 RelayExprNode 或其子类的实例。它提供了对 RelayExprNode 的安全访问和管理。

2. 主要区别

特性 RelayExprNode RelayExpr
类型 C++ 类 智能指针(std::shared_ptr 的封装)
生命周期管理 需要手动管理 自动管理
使用方式 通常不直接使用,作为实现细节 用户主要交互的接口
继承关系 作为基类定义表达式结构 作为访问接口

3. 使用模式

在 TVM 中,通常的模式是:

  1. 定义一个继承自 RelayExprNode 的具体表达式节点类
  2. 使用 RelayExpr 作为这些节点的引用
例子1:常量表达式
cpp 复制代码
// 创建一个常量表达式
auto const_node = relay::ConstantNode::make(tvm::runtime::NDArray::Zeros(...));
RelayExpr const_expr = const_node;

// 通常更简洁的写法
RelayExpr const_expr = relay::Constant(tvm::runtime::NDArray::Zeros(...));
例子2:变量表达式
cpp 复制代码
// 创建一个变量表达式
auto var_node = relay::VarNode::make("x", relay::Type());
RelayExpr var_expr = var_node;

// 或者更简洁地
RelayExpr var_expr = relay::Var("x", relay::Type());
例子3:函数应用
cpp 复制代码
// 创建函数应用表达式
RelayExpr func = ...; // 某个函数
RelayExpr arg = ...;  // 某个参数
auto call_node = relay::CallNode::make(func, {arg});
RelayExpr call_expr = call_node;

// 或者
RelayExpr call_expr = relay::Call(func, {arg});

4. 实际使用建议

  1. 用户代码 :在大多数情况下,你应该使用 RelayExpr 而不是直接操作 RelayExprNode

  2. 扩展 Relay :如果你想定义新的表达式类型,需要继承 RelayExprNode 并实现相应接口。

  3. 类型转换 :可以使用 as<T> 方法将 RelayExpr 向下转换为特定类型的节点指针:

cpp 复制代码
RelayExpr expr = ...;
if (const auto* call = expr.as<CallNode>()) {
  // 现在可以访问 CallNode 的特定成员
  call->op;
  call->args;
}
  1. 创建新表达式 :TVM 提供了辅助函数来创建表达式,通常以节点类型名去掉 "Node" 命名(如 relay::Var() 创建 VarNodeRelayExpr)。

这种分离设计使得 Relay IR 既灵活又安全,同时保持了良好的性能特性

(2) 具体表达式类型

表达式类型 说明 关键成员/方法
VarNode 变量(输入/中间结果) String name_hint, Type type_annotation, Id vid
ConstantNode 常量张量(如模型权重) runtime::NDArray data, tensor_type(), is_scalar()
CallNode 函数/运算符调用 Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args
LetNode Let绑定(实现变量作用域) Var var, Expr value, Expr body
TupleNode 元组结构(多返回值) Array<Expr> fields
TupleGetItemNode 从元组中获取元素 Expr tuple, int index
IfNode 条件表达式 Expr cond, Expr true_branch, Expr false_branch
OpNode 基本运算符(如add/concat) 通过Op::Get("op_name")获取
FunctionNode 函数定义(在function.h中声明,但属于表达式) Array<Var> params, Expr body, Type ret_type, Array<TypeVar> type_params
RefCreateNode 创建可变引用(用于状态更新) Expr value
RefReadNode 读取引用值 Expr ref
RefWriteNode 更新引用值 Expr ref, Expr value
ConstructorNode 代数数据类型(ADT)的构造器(在adt.h中声明) String tag, Array<Type> inputs
MatchNode 模式匹配(ADT处理) Expr data, Array<Clause> clauses
TempExprNode 临时表达式(用于优化过程中的中间表示) 通常作为优化Pass的中间载体
GlobalVarNode 全局函数引用(跨模块调用) String name_hint
SeqExprNode 顺序执行多个表达式(类似语句块) Array<Binding> bindings, Expr body

1. 表达式类型 VarNode举例子

include/tvm/relay/expr.h

cpp 复制代码
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
 public:
  /*!
   * \brief The unique identifier of the Var.
   *
   * vid will be preserved for the same Var during type inference
   * and other rewritings, while the VarNode might be recreated
   * to attach additional information.
   * This property can be used to keep track of parameter Var
   * information across passes.
   */
  Id vid;
  /*!
   * \brief type annotaion of the variable.
   * This field records user provided type annotation of the Var.
   * This field is optional and can be None.
   */
  Type type_annotation;

  /*! \return The name hint of the variable */
  const String& name_hint() const { return vid->name_hint; }

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("vid", &vid);
    v->Visit("type_annotation", &type_annotation);
    v->Visit("span", &span);
    v->Visit("mdata", &mdata);
    v->Visit("_checked_type_", &checked_type_);
  }

  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
    return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(type_annotation);
    hash_reduce.FreeVarHashImpl(this);
  }

  static constexpr const char* _type_key = "relay.Var";
  TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode);
};

class Var : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param name_hint The name hint of a variable.
   * \param type_annotation The type annotation of a variable.
   * \param span The source span of the expression.
   */
  TVM_DLL Var(String name_hint, Type type_annotation, Span span = Span(), MetaData mdata = MetaData())
      : Var(Id(name_hint), type_annotation, span, mdata) {}

  /*!
   * \brief The constructor
   * \param vid The unique id of a variable.
   * \param type_annotation The type annotation of a variable.
   * \param span The source span of the expression.
   */
  TVM_DLL Var(Id vid, Type type_annotation, Span span = Span(), MetaData mdata = MetaData());

  TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
1. 核心设计理念

VarNodeVar 共同实现了 Relay IR 的变量系统 ,采用 TVM 标准的 Object-ObjectRef 设计模式

  • VarNode:存储实际数据的节点类 (继承自 ExprNode
  • Var:管理 VarNode智能指针包装类 (继承自 Expr

2. 关键成员解析
(1) 核心字段
成员 类型 作用
vid Id 唯一标识符,跨 Pass 保持不变(即使节点被重建)
type_annotation Type 用户显式指定的类型注解(可空)
name_hint() String 通过 vid->name_hint 获取的可读名称(非唯一)
span Span 源码位置信息(用于错误定位)
mdata MetaData 扩展元数据
(2) 特殊方法
方法 功能
SEqualReduce 结构化相等比较(用于优化 Pass 的重复检测)
SHashReduce 哈希计算(支持快速查找)
VisitAttrs 属性序列化/反序列化

3. 变量标识系统
(1) vid (Unique ID)
cpp 复制代码
class IdNode : public Object {
 public:
  String name_hint;
  // ... 其他元数据
};
  • 核心特性
    • 通过 Id(name_hint) 构造,但系统会保证其唯一性
    • 即使优化 Pass 重建变量节点,vid 保持不变
    • 用于跨 Pass 跟踪参数变量(如梯度更新时识别同一参数)
(2) name_hint 与 vid 的关系
python 复制代码
x = relay.var("input", shape=(1,3))  # 实际创建:
                                      # vid = Id("input_0x7f") (自动去重)
                                      # name_hint = "input" (用户友好)

4. 类型系统整合
(1) 类型注解流程
graph TD A[用户构造] -->|relay.var(..., dtype="float32")| B(type_annotation) B --> C[类型检查] C -->|更新| D(_checked_type_)
(2) 类型推导规则
  • type_annotation 存在:必须与实际使用类型兼容
  • 若为空:从上下文推断类型

5. 内存模型与跨语言交互
(1) C++ 层构造
cpp 复制代码
// 方式1:通过 name_hint
Var x("data", TensorType({1,3}, DataType::Float(32)));

// 方式2:直接指定 Id
Var x(Id("data_0x7f"), TensorType({1,3}, DataType::Float(32)));
(2) Python 绑定
python 复制代码
# Python 前端接口
x = relay.var(
    name="input",
    shape=(1,3),
    dtype="float32",
    span=SourceSpan(...)
)
(3) 对象生命周期
sequenceDiagram Python->>C++: relay.var() 创建请求 C++->>Heap: 分配 VarNode C++->>Python: 返回 Var(ObjectRef) Python->>C++: 析构时触发引用计数-1

6. 关键应用场景
(1) 函数参数定义
python 复制代码
def build_linear():
    x = relay.var("x", shape=(1,3))
    w = relay.var("w", shape=(3,2))
    b = relay.var("b", shape=(2,))
    y = relay.add(relay.matmul(x, w), b)
    return relay.Function([x, w, b], y)
(2) 优化 Pass 中的变量处理
cpp 复制代码
// 在 ConstantFolding 中识别变量引用
if (const VarNode* var = expr.as<VarNode>()) {
    if (var_map.count(var->vid)) {
        // 替换为已知常量
    }
}
(3) 类型检查
cpp 复制代码
// 检查变量类型是否匹配
bool CheckType(const VarNode* var, const Type& expected) {
    return var->checked_type().as<TensorType>()->dtype == expected;
}

7. 设计亮点总结
  1. 稳定性vid 保证变量在优化过程中的持久标识
  2. 灵活性type_annotation 支持显式/隐式类型指定
  3. 安全性TVM_DECLARE_FINAL_OBJECT_INFO 防止错误继承
  4. 可调试性spanname_hint 增强错误可读性
  5. 性能SEqualReduce/SHashReduce 优化图操作效率

8. 典型问题分析

Q: 为什么需要同时存在 vidname_hint

A: 分工不同:

  • name_hint:面向用户,提供可读性(允许重复)
  • vid:面向系统,保证唯一性和跨Pass一致性

Q: 何时会重建 VarNode

A: 典型场景:

  • 类型推断后附加 _checked_type_
  • 优化 Pass 中克隆表达式时保留原 vid 但新建节点

(3) TVM_DECLARE_BASE_OBJECT_INFO 宏详解

这个宏是 TVM 类型系统的核心 ,用于在 C++ 中动态注册和管理对象的类型信息。它的核心作用是: 为每个类自动生成类型注册代码,使其能被 TVM 运行时识别和操作


1. 宏的参数

cpp 复制代码
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
  • TypeName:当前类名(如 ConstantNode
  • ParentType:父类名(如 ExprNode

2. 静态断言检查(防止非法继承)

cpp 复制代码
static_assert(!ParentType::_type_final, "ParentObj marked as final");
  • 作用 :如果父类被标记为 final(通过 _type_final),则禁止子类继承。

2. 运行时类型索引(RuntimeTypeIndex)

cpp 复制代码
static uint32_t RuntimeTypeIndex() {
  // 检查子类槽位配置是否合法
  static_assert(TypeName::_type_child_slots == 0 || 
                ParentType::_type_child_slots == 0 ||
                TypeName::_type_child_slots < ParentType::_type_child_slots,
               "子类槽位数不能超过父类限制");

  // 如果已预分配类型ID,直接返回
  if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {
    return TypeName::_type_index;
  }
  // 否则动态分配
  return _GetOrAllocRuntimeTypeIndex();
}
  • 功能 :返回类的唯一类型 ID(uint32_t)。
  • 优化 :优先使用预分配的 _type_index(性能更高),否则动态分配。

3. 动态分配类型索引(_GetOrAllocRuntimeTypeIndex)

cpp 复制代码
static uint32_t _GetOrAllocRuntimeTypeIndex() {
  static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex(
      TypeName::_type_key,         // 类型名称字符串(如 "relay.Constant")
      TypeName::_type_index,       // 预分配的类型ID
      ParentType::RuntimeTypeIndex(), // 父类类型ID
      TypeName::_type_child_slots, // 为子类预留的槽位数
      TypeName::_type_child_slots_can_overflow // 是否允许超额
  );
  return tidx;
}
  • 作用:向 TVM 运行时注册类型,并分配唯一 ID。
  • 关键参数
    • _type_child_slots:限制子类数量(防止类型爆炸)。
    • _type_child_slots_can_overflow:为 true 时允许突破限制。

通俗版解释:TVM的类型身份证系统

你可以把TVM的类型系统想象成一个学校的学生管理系统 ,而TVM_DECLARE_BASE_OBJECT_INFO就是给学生(类)办身份证的机器:


1. 为什么要办身份证?
  • 每个学生(类)需要唯一学号(类型ID)
  • 需要知道他的班主任是谁(父类)
  • 防止有人冒充转校生(非法继承)
2. 办证过程(宏的作用)
cpp 复制代码
// 给"小明同学"办证,班主任是"李老师"
TVM_DECLARE_BASE_OBJECT_INFO(小明, 李老师)

这个宏会自动做三件事:

  1. 检查家世清白

    cpp 复制代码
    static_assert(!李老师::是final班, "班主任明确不收新学生!");
    • 如果班主任声明"我们班不接收转学生",就报错
  2. 分配学号

    • 优先用预留的VIP学号(_type_index
    • 没有就现场摇号(_GetOrAllocRuntimeTypeIndex
  3. 登记亲属关系

    cpp 复制代码
    学号 = 教务处.登记(
     姓名:"小明",
     班主任:李老师.学号,
     可带小弟人数:3  // _type_child_slots
    );
3. 特殊班级(FINAL版)
cpp 复制代码
TVM_DECLARE_FINAL_OBJECT_INFO(学霸班, 实验班)
  • 相当于在班级门口挂**"禁止转入"**牌子
  • 其他班同学想转学过来会直接报错
4. 实际有什么用?
  • 查身份证快obj->IsInstance<小明>() 比查户口本快
  • 安全转班obj.as<小明>() 能安全转换类型
  • 防止冒名顶替:禁止随便认爹(错误继承)

举个栗子🌰
python 复制代码
# Python前端定义一个"汉堡店"类
@register_relay_node("food.HamburgerShop")
class HamburgerShopNode(ExprNode):
    _type_key = "food.HamburgerShop"
    _type_child_slots = 2  # 允许开2家分店

C++层通过这个宏:

  1. 给汉堡店分配类型ID(比如9527)
  2. 记录它的父类是ExprNode
  3. 允许最多2个子类(比如CheeseBurgerShopChickenBurgerShop

一句话总结

这个宏就是TVM给类发身份证+建家族档案的工具,让系统能:

  • ✅ 快速识别"你是谁"(类型检查)
  • ✅ 知道"你爸是谁"(继承关系)
  • ❌ 防止"乱认亲戚"(非法继承)

(4) 遍历接口

cpp 复制代码
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("mdata", &mdata);
    v->Visit("_checked_type_", &checked_type_);
  }

VisitAttrs 是 TVM 中用于统一序列化、反序列化和属性访问 的核心接口。以下是 ConstantNode 使用该函数的具体示例,涵盖 C++ 和 Python 场景:


1. C++ 场景示例
(1) 模型序列化(保存为JSON)
cpp 复制代码
// 创建常量节点
runtime::NDArray arr = runtime::NDArray::Empty({2, 2}, DLDataType{kDLFloat, 32, 1}, DLContext{kDLCPU, 0});
ConstantNode* const_node = new ConstantNode();
const_node->data = arr;

// 序列化为JSON
JSONAttrVisitor visitor;
const_node->VisitAttrs(&visitor);  // 触发以下调用:
                                   // visitor.Visit("data", &data)
                                   // visitor.Visit("span", &span)...
std::string json = visitor.GetJSON();

输出JSON片段

json 复制代码
{
  "type_key": "relay.Constant",
  "data": {"b64": "AABAA...", "dtype": "float32", "shape": [2, 2]},
  "span": null,
  "_checked_type_": "TensorType([2,2], float32)"
}
(2) 优化Pass中的常量修改
cpp 复制代码
class ConstantMutator : public AttrMutator {
 public:
  void VisitAttrs(AttrVisitor* v) override {
    if (v->IsMutator()) {  // 检查是否为修改模式
      runtime::NDArray new_data = ...; // 生成新数据
      v->Visit("data", &new_data);    // 修改data字段
    }
  }
};

// 调用示例:
ConstantMutator mutator;
const_node->VisitAttrs(&mutator);  // 修改常量数据
(3) 调试打印
cpp 复制代码
class DebugPrinter : public AttrVisitor {
 public:
  void Visit(const char* key, runtime::NDArray* data) override {
    std::cout << key << ": shape=" << data.Shape();
  }
};

DebugPrinter printer;
const_node->VisitAttrs(&printer);  // 输出:data: shape=[2,2]

2. Python 场景示例
(1) 直接属性访问
python 复制代码
import tvm
from tvm import relay

# 创建常量
data = tvm.nd.array(np.zeros((2,2), dtype="float32"))
const = relay.Constant(data)

# Python属性访问(背后调用VisitAttrs)
print(const.data)      # 访问NDArray → 触发Visit("data", &data)
print(const.span)      # 访问源码位置 → Visit("span", &span)

输出

复制代码
<tvm.nd.NDArray shape=(2, 2), dtype=float32>
None  # 未设置span时的默认值
(2) 模型保存与加载
python 复制代码
# 保存模型(触发序列化)
mod = tvm.IRModule.from_expr(const)
mod.save("const.json")  # 内部调用VisitAttrs

# 加载模型(触发反序列化)
loaded_mod = tvm.ir.load_json("const.json")
loaded_const = loaded_mod["main"].body
assert isinstance(loaded_const, relay.Constant)
(3) 自定义属性访问器
python 复制代码
class MyVisitor(tvm.ir.AttrVisitor):
    def visit(self, name, value):
        print(f"Attribute {name} has type {type(value)}")

visitor = MyVisitor()
const.visit_attrs(visitor)  # 显式调用VisitAttrs

输出

复制代码
Attribute data has type <class 'tvm.runtime.ndarray.NDArray'>
Attribute span has type <class 'tvm.ir.Span'>
...

cpp 复制代码
class Constant;
/*!
 * \brief Constant tensor type.
 */
class ConstantNode : public ExprNode {
 public:
  /*! \brief The data of the tensor */
  runtime::NDArray data;

  /*! \return The corresponding tensor type of the data */
  TensorType tensor_type() const;

  /*! \return Whether it is scalar(rank-0 tensor) */
  bool is_scalar() const { return data->ndim == 0; }

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("data", &data);
    v->Visit("span", &span);
    v->Visit("mdata", &mdata);
    v->Visit("_checked_type_", &checked_type_);
  }

  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
    return equal(data, other->data);
  }

  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }

  static constexpr const char* _type_key = "relay.Constant";
  TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};

class Constant : public Expr {
 public:
  /*!
   * \brief The constructor
   * \param data The data of the constant tensor.
   * \param span The source span of the expression.
   */
  TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span(), MetaData mdata = MetaData());

  TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};

以下是关于 ConstantNodeConstant 类的详细解释与概括,结合它们在 TVM Relay IR 中的作用和实现设计:



相关推荐
带娃的IT创业者9 分钟前
《AI大模型应知应会100篇》第22篇:系统提示词(System Prompt)设计与优化
人工智能·prompt
绝顶大聪明13 分钟前
【图像轮廓特征查找】图像处理(OpenCV) -part8
图像处理·人工智能·opencv
liruiqiang0514 分钟前
神经网络优化 - 小批量梯度下降之批量大小的选择
人工智能·深度学习·神经网络·机器学习·梯度下降
AI大模型顾潇14 分钟前
[特殊字符] Prompt如何驱动大模型对本地文件实现自主变更:Cline技术深度解析
前端·人工智能·llm·微调·prompt·编程·ai大模型
Blossom.11823 分钟前
量子计算与经典计算融合:开启计算新时代
人工智能·深度学习·opencv·物联网·生活·边缘计算·量子计算
AI技术学长37 分钟前
深度学习-python猫狗识别tensorflow2.0
人工智能·深度学习·计算机视觉·图像识别·计算机技术·tensorflow2·猫狗识别
6confim40 分钟前
掌握 Cursor:AI 编程助手的高效使用技巧
前端·人工智能·后端
offerwa40 分钟前
LLM多模态能力应用实战指南
人工智能
offerwa41 分钟前
知识图谱与大模型结合实践指南
人工智能
offerwa42 分钟前
大模型Agent系统设计与实现指南
人工智能