在 Apache TVM 的 Relay IR 中,基础节点(Var
、Const
、Call
、Function
和 Expr
)是构建计算图的核心数据结构。以下是对它们的详细解析,包括定义、作用、内部组成及相互关系:
1. Expr
(表达式基类)
作用
- 所有 Relay IR 节点的基类,提供统一的类型系统和遍历接口。
- 支持递归访问、变换和类型检查。
关键组成
字段/方法 |
说明 |
checked_type_ |
表达式的推断类型(如 TensorType 、TupleType )。 |
span |
源代码位置信息(用于调试和错误报告)。 |
VisitAttrs(visitor) |
递归访问所有属性和子节点(用于序列化、优化等)。 |
Mutate() |
生成表达式的副本(用于变换和优化)。 |
子类关系
Expr Var Const Call Function Tuple Let If
2. Var
(变量)
作用
- 表示计算图中的 输入变量 或 中间变量(符号化张量)。
- 类似于深度学习模型中的输入占位符或中间激活值。
关键组成
字段 |
说明 |
name_hint |
变量名称(字符串标识符,如 "x" )。 |
type_annotation |
变量的显式类型注解(可选,如 TensorType({1,3}, float32) )。 |
vid |
内部唯一 ID(用于优化和去重)。 |
示例
cpp
复制代码
// 定义一个浮点型张量变量
Var x("x", TensorType({1, 3}, DataType::Float(32)));
3. Const
(常量)
作用
- 表示 不可变的数据(如模型权重、超参数)。
- 在计算图中作为叶子节点存在。
关键组成
字段 |
说明 |
data |
存储的常量值(runtime::NDArray 或 relay::ConstantNode )。 |
checked_type_ |
常量的类型(通常从 data 自动推断)。 |
示例
cpp
复制代码
// 定义一个常量张量
NDArray weight = NDArray::Empty({3, 3}, DataType::Float(32), {kDLCPU, 0});
Const weight_const(weight);
4. Call
(函数调用)
作用
- 表示对 算子(Operator) 或 函数(Function) 的调用。
- 是构建计算图的核心节点(如
add
、conv2d
)。
关键组成
字段 |
说明 |
op |
调用的目标(Op 、Function 或 GlobalVar )。 |
args |
参数列表(Array<Expr> ,可以是 Var 、Const 或其他 Call )。 |
attrs |
调用的属性(如卷积的 strides 、padding )。 |
示例
cpp
复制代码
// 调用加法算子
Expr a = Var("a", TensorType({1}, DataType::Float(32)));
Expr b = Var("b", TensorType({1}, DataType::Float(32)));
Expr add_call = Call(Op::Get("add"), {a, b});
5. Function
(函数定义)
作用
- 封装可复用的计算单元(类似 Lambda 表达式)。
- 用于表示模型中的子图或复合算子(如
conv2d + relu
融合)。
关键组成
字段 |
说明 |
params |
输入参数列表(Array<Var> )。 |
body |
函数体的表达式(Expr )。 |
ret_type |
返回值的类型(如 TensorType )。 |
type_params |
泛型类型参数(支持多态,类似 C++ 模板)。 |
示例
cpp
复制代码
// 定义一个简单的加法函数
Var x("x", TensorType({1}, DataType::Float(32)));
Var y("y", TensorType({1}, DataType::Float(32)));
Expr body = Call(Op::Get("add"), {x, y});
Function add_func({x, y}, body, TensorType({1}, DataType::Float(32)));
6. 节点间的协作关系
计算图示例
复制代码
z = (x + y) * 2
对应的 Relay IR 结构:
- 变量 :
x
、y
(Var
节点)。
- 常量 :
2
(Const
节点)。
- 调用 :
add(x, y)
和 multiply(add_result, 2)
(Call
节点)。
- 函数 :封装整个计算(
Function
节点)。
代码实现
cpp
复制代码
Var x("x", TensorType({1}, DataType::Float(32)));
Var y("y", TensorType({1}, DataType::Float(32)));
Expr add = Call(Op::Get("add"), {x, y});
Expr two = Const(NDArray::FromVector({2.0f}));
Expr mul = Call(Op::Get("multiply"), {add, two});
Function func({x, y}, mul, TensorType({1}, DataType::Float(32)));
7. 类型系统支持
所有 Expr
节点都关联类型信息:
Var
/Const
:通过 type_annotation
或 checked_type_
指定张量类型。
Call
:根据算子的类型规则推断返回类型(如 add(Tensor, Tensor) -> Tensor
)。
Function
:通过 ret_type
声明返回值类型。
总结
节点 |
角色 |
关键特性 |
Expr |
所有节点的基类 |
提供类型检查和遍历接口。 |
Var |
输入/中间变量 |
符号化表示,支持类型注解。 |
Const |
常量数据 |
存储不可变值(如权重)。 |
Call |
算子或函数调用 |
构建计算图的核心节点,依赖 op 和 args 。 |
Function |
可复用的计算单元 |
封装参数、计算体和返回类型,支持多态。 |
这些基础节点共同构成了 Relay IR 的 静态计算图,通过组合它们可以表示复杂的深度学习模型,并为后续优化和代码生成提供基础。