CANN图引擎GE的编译优化与高效执行机制深度解析
前言
Graph Engine(GE)是CANN框架的核心图编译和执行引擎,负责将深度学习计算图转换为高效的NPU执行计划。GE提供了一套完整的图优化、编译和执行框架,通过图级别的优化实现性能提升。本文将深入剖析GE的架构设计、优化机制和执行流程。
相关链接:
- CANN组织链接:https://atomgit.com/cann
- GE仓库链接:https://atomgit.com/cann/ge
一、GE架构概述
1.1 设计理念
GE遵循以下核心设计理念:
- 计算图抽象:将模型表示为数据流图
- 多层IR体系:不同阶段的中间表示支持不同级别的优化
- Pass流水线:模块化的优化Pass可灵活组合
- 后端解耦:图优化与后端执行分离
- 延迟编译:支持运行时JIT编译
1.2 架构层次
GE/
├── 前端接口层
│ ├── Graph API - 图构建接口
│ ├── Tensor API - 张量定义接口
│ └── Op API - 算子定义接口
├── 中间表示层
│ ├── Graph IR - 计算图中间表示
│ ├── IR Builder - IR构建器
│ └── IR Printer - IR打印/调试
├── 图优化层
│ ├── 优化Pass框架
│ ├── 图变换Pass
│ │ ├── 算子融合 (Operator Fusion)
│ │ ├── 常量折叠 (Constant Folding)
│ │ ├── 死代码消除 (DCE)
│ │ ├── 公共子表达式消除 (CSE)
│ │ └── 循环展开 (Loop Unrolling)
│ ├── 内存优化Pass
│ │ ├── 内存复用 (Memory Reuse)
│ │ ├── 内存分配 (Memory Allocation)
│ │ └── 缓冲区融合 (Buffer Fusion)
│ └── 布局优化Pass
│ ├── 数据布局转换 (Layout Transform)
│ └── 格式优化 (Format Optimization)
├── 编译层
│ ├── 图分区 (Graph Partitioning)
│ ├── 流水线调度 (Pipeline Scheduling)
│ ├── 代码生成 (Code Generation)
│ └── 二进制生成 (Binary Generation)
└── 执行层
├── 图执行器 (Graph Executor)
├── 流管理 (Stream Management)
├── 事件同步 (Event Synchronization)
└── 内存管理 (Memory Management)
二、核心API详解
2.1 图构建API
cpp
#include "ge/ge_api.h"
/**
 * @brief Computation-graph class.
 * Holds the operators of one dataflow graph together with its data edges
 * and control (ordering-only) edges.
 */
class Graph {
public:
// Constructor; the graph name defaults to "graph".
Graph(const std::string& name = "graph");
// Returns the graph name.
const std::string& GetName() const { return name_; }
// Adds an operator node to the graph.
Status AddOp(OpPtr op);
// Adds a data edge from output |src_index| of |src_op| to input
// |dst_index| of |dst_op|.
Status AddDataEdge(const OpPtr& src_op, uint32_t src_index,
const OpPtr& dst_op, uint32_t dst_index);
// Adds a control edge (execution ordering only, no data flow).
Status AddControlEdge(const OpPtr& src_op, const OpPtr& dst_op);
// Returns all operators currently stored in the graph.
const std::vector<OpPtr>& GetAllOps() const { return ops_; }
// Returns the graph's input (source) nodes.
std::vector<OpPtr> GetInputNodes() const;
// Returns the graph's output (sink) nodes.
std::vector<OpPtr> GetOutputNodes() const;
// Validates graph structure; returns a failure Status on inconsistency.
Status Validate() const;
// Returns the operators in topological (execution) order.
std::vector<OpPtr> TopologicalSort() const;
private:
std::string name_;
std::vector<OpPtr> ops_;
std::vector<DataEdge> data_edges_;
std::vector<ControlEdge> control_edges_;
};
/**
 * @brief Operator node class.
 * Describes one operator: its name, type, input/output tensor descriptors
 * and a free-form attribute map.
 */
class Op {
public:
// Constructor: |name| is the unique node name, |type| the operator kind
// (e.g. "Conv2D").
Op(const std::string& name, const std::string& type);
// Returns the node name.
const std::string& GetName() const { return name_; }
// Returns the operator type string.
const std::string& GetType() const { return type_; }
// Appends an input tensor descriptor.
Status AddInputDesc(const TensorDesc& desc);
// Appends an output tensor descriptor.
Status AddOutputDesc(const TensorDesc& desc);
// Returns the descriptor of input |index|.
const TensorDesc& GetInputDesc(uint32_t index) const;
// Returns the descriptor of output |index|.
const TensorDesc& GetOutputDesc(uint32_t index) const;
// Stores attribute |name| = |value|.
template<typename T>
Status SetAttr(const std::string& name, const T& value);
// Reads attribute |name| into |value|.
template<typename T>
Status GetAttr(const std::string& name, T* value) const;
// Returns the producer ops feeding this op's inputs.
std::vector<OpPtr> GetInputs() const;
// Returns the consumer ops reading this op's outputs.
std::vector<OpPtr> GetOutputs() const;
private:
std::string name_;
std::string type_;
std::vector<TensorDesc> input_descs_;
std::vector<TensorDesc> output_descs_;
std::unordered_map<std::string, Attribute> attrs_;
};
/**
 * @brief Tensor descriptor.
 * Bundles a tensor's shape, element data type and memory format.
 */
class TensorDesc {
public:
// Constructor.
TensorDesc(const Shape& shape, DataType dtype, Format format);
// Returns the shape.
const Shape& GetShape() const { return shape_; }
// Returns the element data type.
DataType GetDataType() const { return dtype_; }
// Returns the memory layout format (e.g. NCHW).
Format GetFormat() const { return format_; }
// Replaces the shape.
void SetShape(const Shape& shape) { shape_ = shape; }
// Replaces the memory layout format.
void SetFormat(Format format) { format_ = format; }
// Returns the storage size in bytes.
size_t GetBytes() const;
// Returns the rank (number of dimensions) of the shape.
int GetDimNum() const { return shape_.GetDimNum(); }
private:
Shape shape_;
DataType dtype_;
Format format_;
};
/**
 * @brief Tensor shape: an ordered list of dimension sizes.
 * A dimension equal to kUnknownDim (-1) denotes a dynamic/unknown size.
 */
class Shape {
public:
  // Constructs a shape from its dimension list.
  Shape(const std::vector<int64_t>& dims);
  // Returns all dimensions.
  const std::vector<int64_t>& GetDims() const { return dims_; }
  // Returns the rank (number of dimensions).
  // FIX: make the size_t -> int conversion explicit instead of relying on
  // an implicit narrowing conversion.
  int GetDimNum() const { return static_cast<int>(dims_.size()); }
  // Returns the total element count (product of all dimensions).
  int64_t GetTotalNum() const;
  // Returns the size of dimension |index|; no bounds checking is performed.
  int64_t GetDim(int index) const { return dims_[index]; }
  // Sets the size of dimension |index|.
  void SetDim(int index, int64_t value);
  // Element-wise shape equality.
  bool operator==(const Shape& other) const;
  // Sentinel marking an unknown/dynamic dimension.
  static constexpr int64_t kUnknownDim = -1;
private:
  std::vector<int64_t> dims_;
};
2.2 图优化API
cpp
#include "ge/graph_optimizer.h"
/**
 * @brief Graph optimizer.
 * Runs a user-assembled sequence of optimization passes over a graph and
 * accumulates the results in an OptimizationReport.
 */
class GraphOptimizer {
public:
// Constructor.
GraphOptimizer();
// Appends a pass; passes run in the order they were added.
Status AddPass(PassPtr pass);
// Runs every registered pass on |graph|, in order.
Status Optimize(Graph* graph);
// Returns the report accumulated by the last Optimize() call.
const OptimizationReport& GetReport() const { return report_; }
private:
std::vector<PassPtr> passes_;
OptimizationReport report_;
};
/**
 * @brief Optimization pass base class.
 * Concrete passes implement Run() to transform the graph in place and
 * GetType() to report their category.
 */
class Pass {
public:
// Constructor: |name| identifies the pass in logs and reports.
Pass(const std::string& name) : name_(name) {}
// Virtual destructor: passes are deleted through base pointers.
virtual ~Pass() = default;
// Transforms |graph| in place; returns a failure Status on error.
virtual Status Run(Graph* graph) = 0;
// Returns the pass name.
const std::string& GetName() const { return name_; }
// Returns the pass category (see PassType).
virtual PassType GetType() const = 0;
protected:
std::string name_;
};
// Pass category enumeration.
enum PassType {
PASS_GRAPH_TRANSFORM, // graph transformation
PASS_MEMORY_OPT, // memory optimization
PASS_LAYOUT_OPT, // layout optimization
PASS_FUSION, // operator fusion
PASS_PARTITION // graph partitioning
};
/**
 * @brief Operator fusion pass.
 * Fuses runs of consecutive operators matching a registered pattern into a
 * single fused operator (e.g. Conv2D + BatchNorm + ReLU).
 */
class OperatorFusionPass : public Pass {
public:
  /// Describes one fusable operator sequence.
  struct FusionPattern {
    std::string name;                   // pattern name (for logs/reports)
    std::vector<std::string> op_types;  // expected op-type sequence
    std::function<bool(const std::vector<OpPtr>&)> matcher;  // extra predicate
    std::string fused_op_type;          // type of the replacement op
  };
  OperatorFusionPass() : Pass("OperatorFusion") {}
  /// Repeatedly applies every registered pattern until the graph reaches a
  /// fixed point (one fusion per iteration; the loop re-scans after each).
  Status Run(Graph* graph) override {
    bool changed = true;
    while (changed) {
      changed = false;
      for (const auto& pattern : patterns_) {
        if (TryApplyPattern(graph, pattern)) {
          changed = true;
          break;  // graph mutated; restart the scan from the first pattern
        }
      }
    }
    return Status::OK();
  }
  PassType GetType() const override { return PASS_FUSION; }
  // Registers a fusion pattern; patterns are tried in registration order.
  void RegisterPattern(const FusionPattern& pattern) {
    patterns_.push_back(pattern);
  }
private:
  // Applies the first match of |pattern|, if any. Returns true on mutation.
  bool TryApplyPattern(Graph* graph, const FusionPattern& pattern) {
    std::vector<std::vector<OpPtr>> matches;
    FindPatternMatches(graph, pattern, &matches);
    if (matches.empty()) {
      return false;
    }
    // Simplification: fuse only the first match; Run() re-scans for the rest.
    return ApplyFusion(graph, matches[0], pattern);
  }
  // Collects every window of ops whose types equal pattern.op_types and
  // that satisfies pattern.matcher.
  // NOTE(review): this scans consecutive windows of GetAllOps(), which
  // assumes the op list is stored in execution order and that adjacent list
  // entries are actually connected -- confirm against Graph's contract.
  void FindPatternMatches(
      Graph* graph,
      const FusionPattern& pattern,
      std::vector<std::vector<OpPtr>>* matches) {
    auto ops = graph->GetAllOps();
    for (size_t i = 0; i + pattern.op_types.size() <= ops.size(); ++i) {
      std::vector<OpPtr> candidate(
          ops.begin() + i,
          ops.begin() + i + pattern.op_types.size());
      // Reject the window as soon as one op type differs.
      bool type_match = true;
      for (size_t j = 0; j < candidate.size(); ++j) {
        if (candidate[j]->GetType() != pattern.op_types[j]) {
          type_match = false;
          break;
        }
      }
      if (type_match && pattern.matcher(candidate)) {
        matches->push_back(candidate);
      }
    }
  }
  // Replaces |ops| with a single fused op named after its members.
  bool ApplyFusion(
      Graph* graph,
      const std::vector<OpPtr>& ops,
      const FusionPattern& pattern) {
    std::string fused_name = "Fused_" + ops[0]->GetName();
    for (size_t i = 1; i < ops.size(); ++i) {
      fused_name += "_" + ops[i]->GetName();
    }
    auto fused_op = std::make_shared<Op>(fused_name, pattern.fused_op_type);
    // The fused op keeps the first op's first input and the last op's first
    // output; multi-input/multi-output fusion is not handled here.
    fused_op->AddInputDesc(ops[0]->GetInputDesc(0));
    fused_op->AddOutputDesc(ops.back()->GetOutputDesc(0));
    // BUG FIX: |ops| holds OpPtr smart pointers, so the original
    // `for (auto* op : ops)` failed to compile; iterate by const reference.
    for (const auto& op : ops) {
      (void)op;  // TODO: merge the attributes of each fused member op.
    }
    return graph->ReplaceSubgraph(ops, fused_op);
  }
  std::vector<FusionPattern> patterns_;
};
/**
 * @brief Constant-folding pass.
 * Evaluates operators whose inputs are all compile-time constants and
 * replaces each with a precomputed Const operator.
 */
class ConstantFoldingPass : public Pass {
public:
  ConstantFoldingPass() : Pass("ConstantFolding") {}
  Status Run(Graph* graph) override {
    // Work on a snapshot of the op list; ReplaceOp() mutates the graph.
    auto snapshot = graph->GetAllOps();
    for (auto& candidate : snapshot) {
      // Skip nodes that are already constants, and nodes with at least one
      // non-constant producer -- neither can be folded.
      if (IsConstantOp(candidate) || !AllInputsAreConstants(candidate)) {
        continue;
      }
      // Evaluate the op at compile time; on success swap in a Const node.
      Tensor folded_value;
      if (EvaluateConstantOp(candidate, &folded_value).ok()) {
        graph->ReplaceOp(candidate,
                         CreateConstantOp(candidate->GetName(), folded_value));
      }
    }
    return Status::OK();
  }
  PassType GetType() const override { return PASS_GRAPH_TRANSFORM; }
private:
  // A node is constant iff its type is "Const".
  bool IsConstantOp(const OpPtr& op) { return op->GetType() == "Const"; }
  // True when every producer feeding |op| is a Const node.
  bool AllInputsAreConstants(const OpPtr& op) {
    for (const auto& producer : op->GetInputs()) {
      if (!IsConstantOp(producer)) {
        return false;
      }
    }
    return true;
  }
  // Computes |op|'s output value at compile time.
  Status EvaluateConstantOp(const OpPtr& op, Tensor* result) {
    // Simplified: a real implementation dispatches on the op type.
    return Status::OK();
  }
  // Wraps |value| in a new Const op named |name|.
  OpPtr CreateConstantOp(const std::string& name, const Tensor& value) {
    auto const_op = std::make_shared<Op>(name, "Const");
    // A real implementation would attach |value| as the const payload.
    return const_op;
  }
};
/**
 * @brief Common-subexpression-elimination pass.
 * Ops with identical structural signatures (same type, same producers)
 * compute the same value; all duplicates are rewired to one canonical op.
 */
class CommonSubexpressionEliminationPass : public Pass {
public:
  CommonSubexpressionEliminationPass() : Pass("CSE") {}
  Status Run(Graph* graph) override {
    // Phase 1: bucket every op by its structural signature. No mutation
    // happens here, so iterating the live op list is safe.
    std::unordered_map<std::string, std::vector<OpPtr>> buckets;
    for (auto& op : graph->GetAllOps()) {
      buckets[ComputeOpHash(op)].push_back(op);
    }
    // Phase 2: within each bucket the first op becomes canonical; every
    // later duplicate is replaced by a reference to it.
    for (auto& [signature, dup_list] : buckets) {
      for (size_t k = 1; k < dup_list.size(); ++k) {
        graph->ReplaceOp(dup_list[k], dup_list[0]);
      }
    }
    return Status::OK();
  }
  PassType GetType() const override { return PASS_GRAPH_TRANSFORM; }
private:
  // Signature = op type plus the names of all producer ops, in order.
  std::string ComputeOpHash(const OpPtr& op) {
    std::string signature = op->GetType();
    for (const auto& producer : op->GetInputs()) {
      signature += "_" + producer->GetName();
    }
    return signature;
  }
};
/**
 * @brief Memory-reuse pass.
 * Computes a lifetime interval for every operator's output and lets tensors
 * with disjoint lifetimes share the same physical buffer.
 */
class MemoryReusePass : public Pass {
public:
  MemoryReusePass() : Pass("MemoryReuse") {}
  Status Run(Graph* graph) override {
    // 1. Build the data-dependency graph.
    DependencyGraph dep_graph;
    BuildDependencyGraph(graph, &dep_graph);
    // 2. Compute every tensor's [birth, death] interval in topological order.
    std::unordered_map<OpPtr, Lifetime> lifetimes;
    ComputeLifetimes(graph, &dep_graph, &lifetimes);
    // 3. Partition tensors whose lifetimes never overlap into groups.
    std::vector<BufferGroup> groups;
    PartitionByLifetime(lifetimes, &groups);
    // 4. Give each group a single shared allocation.
    AllocateSharedMemory(graph, groups);
    return Status::OK();
  }
  PassType GetType() const override { return PASS_MEMORY_OPT; }
private:
  // Lifetime of an op's output, as indices into the topological order.
  // FIX: members were previously uninitialized; brace-init to zero.
  struct Lifetime {
    int birth = 0;  // step at which the tensor is produced
    int death = 0;  // last step at which the tensor is consumed
  };
  // A set of ops whose outputs can share one buffer of max_size bytes.
  struct BufferGroup {
    std::vector<OpPtr> ops;
    size_t max_size = 0;
  };
  void BuildDependencyGraph(Graph* graph, DependencyGraph* dep_graph) {
    // TODO: populate the data-dependency relations.
  }
  // Assigns each op a lifetime: birth = its topological index, death = the
  // topological index of its last consumer.
  void ComputeLifetimes(
      Graph* graph,
      const DependencyGraph* dep_graph,
      std::unordered_map<OpPtr, Lifetime>* lifetimes) {
    auto ops = graph->TopologicalSort();
    for (size_t i = 0; i < ops.size(); ++i) {
      Lifetime lt;
      lt.birth = static_cast<int>(i);  // explicit size_t -> int conversion
      const int last_use = FindLastUse(ops[i], ops);
      // BUG FIX: an output with no consumer yielded death == -1 (< birth),
      // a malformed interval; clamp so death >= birth always holds.
      lt.death = (last_use < lt.birth) ? lt.birth : last_use;
      (*lifetimes)[ops[i]] = lt;
    }
  }
  // Returns the topological index of the last op consuming |op|'s output,
  // or -1 if it is never consumed.
  int FindLastUse(const OpPtr& op, const std::vector<OpPtr>& ops) {
    int last_use = -1;
    for (size_t i = 0; i < ops.size(); ++i) {
      if (HasDependency(ops[i], op)) {
        last_use = static_cast<int>(i);  // explicit size_t -> int conversion
      }
    }
    return last_use;
  }
  // True if |op| directly consumes |target|'s output.
  bool HasDependency(const OpPtr& op, const OpPtr& target) {
    auto inputs = op->GetInputs();
    for (auto& input : inputs) {
      if (input == target) {
        return true;
      }
    }
    return false;
  }
  void PartitionByLifetime(
      const std::unordered_map<OpPtr, Lifetime>& lifetimes,
      std::vector<BufferGroup>* groups) {
    // TODO: group tensors whose lifetime intervals do not overlap.
  }
  void AllocateSharedMemory(
      Graph* graph,
      const std::vector<BufferGroup>& groups) {
    // TODO: allocate one shared buffer per group.
  }
};
2.3 图执行API
cpp
#include "ge/graph_executor.h"
/**
 * @brief Graph executor.
 * Loads a compiled graph, binds input/output tensors, and runs the graph
 * either synchronously or asynchronously on a stream.
 */
class GraphExecutor {
public:
// Constructor.
GraphExecutor();
// Destructor.
~GraphExecutor();
// Loads a compiled graph for execution.
Status LoadGraph(const CompiledGraph& compiled_graph);
// Binds input data by name.
Status SetInput(const std::string& input_name, const Tensor& data);
// Binds input data by index.
Status SetInput(uint32_t index, const Tensor& data);
// Reads output data by name.
Status GetOutput(const std::string& output_name, Tensor* data);
// Reads output data by index.
Status GetOutput(uint32_t index, Tensor* data);
// Runs the graph synchronously (blocks until completion).
Status Execute();
// Enqueues the graph on |stream| and returns without waiting.
Status ExecuteAsync(aclrtStream stream);
// Returns statistics collected during the last execution.
const ExecutionStats& GetStats() const { return stats_; }
private:
CompiledGraph compiled_graph_;
std::vector<Tensor> inputs_;
std::vector<Tensor> outputs_;
ExecutionStats stats_;
};
/**
 * @brief Graph compiler.
 * Lowers an (optionally optimized) Graph into an executable CompiledGraph
 * and supports saving/loading the compiled artifact.
 */
class GraphCompiler {
public:
  // Constructor.
  GraphCompiler();
  /// Compilation options.
  /// CONSISTENCY FIX: the fields below now cover every option the usage
  /// examples set (const folding, DCE, core type, memory optimization,
  /// multi-stream); those fields were previously missing from this struct.
  /// All additions carry defaults, so existing callers are unaffected.
  struct CompileOptions {
    bool enable_fusion = true;                // enable operator fusion
    bool enable_optimization = true;          // enable the optimization pipeline
    bool enable_precision_mode = false;       // allow reduced precision (e.g. FP16)
    bool enable_const_folding = true;         // enable constant folding
    bool enable_dead_code_elimination = true; // enable dead-code elimination
    bool enable_memory_optimize = false;      // enable memory optimization
    int memory_optimize_level = 0;            // memory-optimization aggressiveness
    bool enable_multi_stream = false;         // allow multi-stream execution
    int stream_num = 1;                       // number of execution streams
    std::string target_device = "NPU";        // target device kind
    std::string target_core_type = "AI_CORE"; // target compute core type
    int opt_level = 3;                        // optimization level (0-3)
  };
  // Compiles |graph| into |compiled_graph| according to |options|.
  Status Compile(
      const Graph& graph,
      const CompileOptions& options,
      CompiledGraph* compiled_graph);
  // Serializes a compiled graph to |path|.
  Status Save(const CompiledGraph& graph, const std::string& path);
  // Deserializes a compiled graph from |path|.
  Status Load(const std::string& path, CompiledGraph* graph);
private:
  // Runs the optimization pipeline selected by |options|.
  Status ApplyOptimizations(
      const Graph& input,
      const CompileOptions& options,
      Graph* optimized);
  // Generates device code/binary for the optimized graph.
  Status GenerateExecutable(
      const Graph& optimized,
      const CompileOptions& options,
      CompiledGraph* executable);
};
/**
 * @brief Stream manager.
 * Thin wrapper over the ACL runtime for creating/destroying streams and
 * events and for synchronizing between them.
 */
class StreamManager {
public:
// Creates an execution stream.
Status CreateStream(aclrtStream* stream);
// Destroys an execution stream.
Status DestroyStream(aclrtStream stream);
// Creates a synchronization event.
Status CreateEvent(aclrtEvent* event);
// Destroys a synchronization event.
Status DestroyEvent(aclrtEvent event);
// Records |event| on |stream|.
Status RecordEvent(aclrtEvent event, aclrtStream stream);
// Makes |stream| wait until |event| has been recorded.
Status WaitEvent(aclrtEvent event, aclrtStream stream);
// Blocks until all work queued on |stream| completes.
Status StreamSynchronize(aclrtStream stream);
// Blocks until all outstanding work on the device completes.
Status DeviceSynchronize();
private:
std::vector<aclrtStream> streams_;
std::vector<aclrtEvent> events_;
};
三、应用实践
3.1 构建和优化计算图
cpp
#include "ge/ge_api.h"
// Builds a small example CNN front-end:
// Data -> Conv2D -> BatchNorm -> ReLU -> MaxPool.
// Shapes assume one 3x224x224 NCHW image; each layer's output descriptor is
// written out by hand to match its strides/pads.
Graph BuildConvNet() {
Graph graph("ConvNet");
// Data input node: 1x3x224x224 NCHW float.
auto data = std::make_shared<Op>("data", "Data");
TensorDesc data_desc({1, 3, 224, 224}, DT_FLOAT, FORMAT_NCHW);
data->AddOutputDesc(data_desc);
graph.AddOp(data);
// Convolution layer 1: 64 filters of 7x7, stride 2, pad 3 -> 1x64x112x112.
auto conv1 = std::make_shared<Op>("conv1", "Conv2D");
conv1->AddInputDesc(data_desc);
TensorDesc conv1_weight_desc({64, 3, 7, 7}, DT_FLOAT, FORMAT_NCHW);
conv1->AddInputDesc(conv1_weight_desc);
TensorDesc conv1_bias_desc({64}, DT_FLOAT, FORMAT_ND);
conv1->AddInputDesc(conv1_bias_desc);
TensorDesc conv1_output_desc({1, 64, 112, 112}, DT_FLOAT, FORMAT_NCHW);
conv1->AddOutputDesc(conv1_output_desc);
conv1->SetAttr("strides", std::vector<int>{2, 2});
conv1->SetAttr("pads", std::vector<int>{3, 3, 3, 3});
conv1->SetAttr("dilations", std::vector<int>{1, 1});
conv1->SetAttr("groups", 1);
graph.AddOp(conv1);
// Batch-normalization layer 1 (shape-preserving).
auto bn1 = std::make_shared<Op>("bn1", "BatchNorm");
bn1->AddInputDesc(conv1_output_desc);
TensorDesc bn1_output_desc({1, 64, 112, 112}, DT_FLOAT, FORMAT_NCHW);
bn1->AddOutputDesc(bn1_output_desc);
bn1->SetAttr("epsilon", 1e-5f);
bn1->SetAttr("momentum", 0.9f);
graph.AddOp(bn1);
// ReLU activation 1 (shape-preserving; reuses the BN output descriptor).
auto relu1 = std::make_shared<Op>("relu1", "ReLU");
relu1->AddInputDesc(bn1_output_desc);
relu1->AddOutputDesc(bn1_output_desc);
graph.AddOp(relu1);
// Max-pooling layer 1: 3x3 window, stride 2, pad 1 -> 1x64x56x56.
auto pool1 = std::make_shared<Op>("pool1", "MaxPool");
pool1->AddInputDesc(bn1_output_desc);
TensorDesc pool1_output_desc({1, 64, 56, 56}, DT_FLOAT, FORMAT_NCHW);
pool1->AddOutputDesc(pool1_output_desc);
pool1->SetAttr("kernels", std::vector<int>{3, 3});
pool1->SetAttr("strides", std::vector<int>{2, 2});
pool1->SetAttr("pads", std::vector<int>{1, 1, 1, 1});
graph.AddOp(pool1);
// Wire the layers together with data edges (output 0 -> input 0).
graph.AddDataEdge(data, 0, conv1, 0);
graph.AddDataEdge(conv1, 0, bn1, 0);
graph.AddDataEdge(bn1, 0, relu1, 0);
graph.AddDataEdge(relu1, 0, pool1, 0);
return graph;
}
/**
 * @brief Builds the standard optimization pipeline and runs it on |graph|.
 * Pipeline: operator fusion (Conv+BN+ReLU) -> constant folding -> CSE ->
 * memory reuse. Prints a summary report on success.
 */
void OptimizeGraph(Graph* graph) {
  GraphOptimizer optimizer;
  // Operator fusion: register the Conv2D + BatchNorm + ReLU pattern.
  auto fusion_pass = std::make_shared<OperatorFusionPass>();
  OperatorFusionPass::FusionPattern pattern;
  pattern.name = "ConvBNRelu";
  pattern.op_types = {"Conv2D", "BatchNorm", "ReLU"};
  pattern.matcher = [](const std::vector<OpPtr>& ops) {
    // Simplified example: every structural match is considered fusable.
    return true;
  };
  pattern.fused_op_type = "FusedConvBNRelu";
  fusion_pass->RegisterPattern(pattern);
  // Assemble the pipeline in execution order.
  optimizer.AddPass(fusion_pass);
  optimizer.AddPass(std::make_shared<ConstantFoldingPass>());
  optimizer.AddPass(std::make_shared<CommonSubexpressionEliminationPass>());
  optimizer.AddPass(std::make_shared<MemoryReusePass>());
  // Run everything and report the outcome.
  const Status status = optimizer.Optimize(graph);
  if (!status.ok()) {
    std::cerr << "Optimization failed: " << status.ErrorMessage() << std::endl;
    return;
  }
  std::cout << "Optimization completed successfully" << std::endl;
  auto report = optimizer.GetReport();
  std::cout << "Passes executed: " << report.passes_executed << std::endl;
  std::cout << "Ops fused: " << report.ops_fused << std::endl;
  std::cout << "Memory saved: " << report.memory_saved << " bytes" << std::endl;
}
// 编译和执行图
void CompileAndExecuteGraph(const Graph& graph) {
GraphCompiler compiler;
// 配置编译选项
GraphCompiler::CompileOptions options;
options.enable_fusion = true;
options.enable_optimization = true;
options.opt_level = 3;
options.target_device = "NPU";
// 编译图
CompiledGraph compiled_graph;
Status s = compiler.Compile(graph, options, &compiled_graph);
if (!s.ok()) {
std::cerr << "Compilation failed: " << s.ErrorMessage() << std::endl;
return;
}
// 保存编译后的图
compiler.Save(compiled_graph, "conv_net.om");
// 创建执行器
GraphExecutor executor;
executor.LoadGraph(compiled_graph);
// 设置输入
Tensor input_tensor({1, 3, 224, 224}, DT_FLOAT, FORMAT_NCHW);
// 填充输入数据...
executor.SetInput(0, input_tensor);
// 执行
s = executor.Execute();
if (!s.ok()) {
std::cerr << "Execution failed: " << s.ErrorMessage() << std::endl;
return;
}
// 获取输出
Tensor output_tensor;
executor.GetOutput(0, &output_tensor);
// 打印执行统计
auto stats = executor.GetStats();
std::cout << "Execution time: " << stats.execution_time_us << " us" << std::endl;
std::cout << "Memory used: " << stats.memory_used << " bytes" << std::endl;
}
3.2 自定义优化Pass
cpp
/**
* @brief 自定义布局转换Pass
* 优化数据布局以提高缓存命中率
*/
class LayoutTransformPass : public Pass {
public:
LayoutTransformPass() : Pass("LayoutTransform") {}
Status Run(Graph* graph) override {
// 分析每个算子的最优输入输出格式
std::unordered_map<OpPtr, FormatPreference> prefs;
AnalyzeFormatPreferences(graph, &prefs);
// 插入必要的格式转换
InsertFormatTransforms(graph, prefs);
// 尽可能消除冗余的转换
EliminateRedundantTransforms(graph);
return Status::OK();
}
PassType GetType() const override { return PASS_LAYOUT_OPT; }
private:
struct FormatPreference {
Format input_format;
Format output_format;
};
void AnalyzeFormatPreferences(
Graph* graph,
std::unordered_map<OpPtr, FormatPreference>* prefs) {
for (auto& op : graph->GetAllOps()) {
FormatPreference pref;
// 根据算子类型确定最优格式
if (op->GetType() == "Conv2D" || op->GetType() == "FusedConvBNRelu") {
// 卷积算子偏好NCHW
pref.input_format = FORMAT_NCHW;
pref.output_format = FORMAT_NCHW;
} else if (op->GetType() == "MatMul" || op->GetType() == "FullyConnection") {
// 矩阵乘法偏好NC1HWC0(NPU专用格式)
pref.input_format = FORMAT_NC1HWC0;
pref.output_format = FORMAT_NC1HWC0;
} else {
// 默认NCHW
pref.input_format = FORMAT_NCHW;
pref.output_format = FORMAT_NCHW;
}
(*prefs)[op] = pref;
}
}
void InsertFormatTransforms(
Graph* graph,
const std::unordered_map<OpPtr, FormatPreference>& prefs) {
for (auto& op : graph->GetAllOps()) {
auto pref = prefs.find(op);
if (pref == prefs.end()) continue;
// 检查输入格式
auto inputs = op->GetInputs();
for (size_t i = 0; i < inputs.size(); ++i) {
Format input_format = op->GetInputDesc(i).GetFormat();
if (input_format != pref->second.input_format) {
// 插入格式转换
auto trans_op = CreateFormatTransformOp(
inputs[i], input_format, pref->second.input_format);
graph->InsertOpBefore(op, trans_op, i);
}
}
// 检查输出格式
auto outputs = op->GetOutputs();
for (size_t i = 0; i < outputs.size(); ++i) {
Format output_format = op->GetOutputDesc(i).GetFormat();
if (output_format != pref->second.output_format) {
// 插入格式转换
auto trans_op = CreateFormatTransformOp(
op, output_format, pref->second.output_format);
graph->InsertOpAfter(op, trans_op, i);
}
}
}
}
void EliminateRedundantTransforms(Graph* graph) {
// 移除相邻且相互抵消的格式转换
bool changed = true;
while (changed) {
changed = false;
auto ops = graph->GetAllOps();
for (auto& op : ops) {
if (op->GetType() != "FormatTransform") continue;
auto consumers = op->GetOutputs();
for (auto& consumer : consumers) {
if (consumer->GetType() == "FormatTransform") {
// 检查是否可以抵消
Format src_fmt = op->GetInputDesc(0).GetFormat();
Format mid_fmt = op->GetOutputDesc(0).GetFormat();
Format dst_fmt = consumer->GetOutputDesc(0).GetFormat();
if (src_fmt == dst_fmt) {
// 移除这两个转换
graph->BypassOp(op);
graph->BypassOp(consumer);
changed = true;
break;
}
}
}
if (changed) break;
}
}
}
OpPtr CreateFormatTransformOp(
const OpPtr& input,
Format src_format,
Format dst_format) {
std::string name = "FormatTransform_" + src_format + "_to_" + dst_format;
auto trans_op = std::make_shared<Op>(name, "FormatTransform");
trans_op->AddInputDesc(input->GetOutputDesc(0));
TensorDesc output_desc = input->GetOutputDesc(0);
output_desc.SetFormat(dst_format);
trans_op->AddOutputDesc(output_desc);
trans_op->SetAttr("src_format", static_cast<int>(src_format));
trans_op->SetAttr("dst_format", static_cast<int>(dst_format));
return trans_op;
}
};
/**
 * @brief Loop-unrolling pass.
 * Replicates the body of small, fixed-trip-count loops and removes the loop
 * construct, trading code size for reduced control overhead.
 */
class LoopUnrollingPass : public Pass {
public:
  LoopUnrollingPass(int max_unroll_factor = 4)
      : Pass("LoopUnrolling"), max_unroll_factor_(max_unroll_factor) {}
  Status Run(Graph* graph) override {
    // Detect loop structures, then unroll the ones that qualify.
    std::vector<LoopSubgraph> loops;
    FindLoops(graph, &loops);
    for (const auto& loop : loops) {
      if (ShouldUnroll(loop)) {
        UnrollLoop(graph, loop);
      }
    }
    return Status::OK();
  }
  PassType GetType() const override { return PASS_GRAPH_TRANSFORM; }
private:
  struct LoopSubgraph {
    OpPtr loop_var;
    OpPtr loop_body;
    // BUG FIX: was uninitialized -- FindLoops() never set it, so
    // ShouldUnroll() read an indeterminate value (UB). Zero means
    // "unknown" and safely fails the unroll test.
    int trip_count = 0;
    std::vector<OpPtr> body_ops;
  };
  void FindLoops(Graph* graph, std::vector<LoopSubgraph>* loops) {
    // Simplified detection: treat Loop/While ops as loop heads.
    for (auto& op : graph->GetAllOps()) {
      if (op->GetType() == "Loop" || op->GetType() == "While") {
        LoopSubgraph loop;
        loop.loop_var = op;
        // TODO: extract the loop body ops and the trip count.
        loops->push_back(loop);
      }
    }
  }
  // Unroll only loops with a known, small trip count and a small body.
  bool ShouldUnroll(const LoopSubgraph& loop) {
    return loop.trip_count > 0 &&
           loop.trip_count <= max_unroll_factor_ &&
           loop.body_ops.size() <= 10;
  }
  // Clones the loop body trip_count times, then deletes the loop construct.
  void UnrollLoop(Graph* graph, const LoopSubgraph& loop) {
    for (int i = 0; i < loop.trip_count; ++i) {
      for (auto& op : loop.body_ops) {
        auto cloned_op = CloneOp(op);
        graph->AddOp(cloned_op);
        // TODO: rewire the data edges of the cloned body.
      }
    }
    graph->RemoveOp(loop.loop_var);
  }
  OpPtr CloneOp(const OpPtr& op) {
    // NOTE(review): the "_clone" suffix repeats across iterations, so
    // unrolled copies share a name -- consider a per-iteration suffix.
    auto cloned = std::make_shared<Op>(
        op->GetName() + "_clone",
        op->GetType());
    // TODO: copy tensor descriptors and attributes.
    return cloned;
  }
  int max_unroll_factor_;
};
3.3 图分区和分布式编译
cpp
/**
 * @brief Graph-partition pass.
 * Splits the computation graph across num_devices devices for distributed
 * execution: estimates per-op costs, greedily balances compute load,
 * bridges cross-partition edges with Send/Recv ops, and emits one
 * standalone subgraph per device.
 */
class GraphPartitionPass : public Pass {
public:
  GraphPartitionPass(int num_devices = 4)
      : Pass("GraphPartition"), num_devices_(num_devices) {}
  Status Run(Graph* graph) override {
    // 1. Estimate per-op compute/memory cost.
    std::unordered_map<OpPtr, Cost> costs;
    AnalyzeCosts(graph, &costs);
    // 2. Assign ops to devices.
    std::vector<GraphPartition> partitions;
    PartitionGraph(graph, costs, &partitions);
    // 3. Insert communication ops on cross-partition edges.
    InsertCommunicationOps(graph, partitions);
    // 4. Materialize one subgraph per partition.
    GeneratePartitionedGraphs(graph, partitions);
    return Status::OK();
  }
  PassType GetType() const override { return PASS_PARTITION; }
private:
  // FIX: members were previously uninitialized; brace-init to zero.
  struct Cost {
    float compute_cost = 0.0f;  // estimated FLOPs
    float memory_cost = 0.0f;   // estimated bytes touched
    float comm_cost = 0.0f;     // cross-partition traffic (filled later)
  };
  struct GraphPartition {
    int device_id = 0;
    std::vector<OpPtr> ops;
    std::vector<OpPtr> send_ops;
    std::vector<OpPtr> recv_ops;
  };
  void AnalyzeCosts(
      Graph* graph,
      std::unordered_map<OpPtr, Cost>* costs) {
    for (auto& op : graph->GetAllOps()) {
      Cost cost;
      cost.compute_cost = EstimateFLOPs(op);
      cost.memory_cost = EstimateMemoryUsage(op);
      cost.comm_cost = 0.0f;  // unknown until partitioning is done
      (*costs)[op] = cost;
    }
  }
  // Rough FLOP estimate from the op type and its shapes.
  float EstimateFLOPs(const OpPtr& op) {
    std::string type = op->GetType();
    if (type == "Conv2D") {
      auto output_shape = op->GetOutputDesc(0).GetShape();
      auto kernel_shape = op->GetInputDesc(1).GetShape();
      int64_t N = output_shape.GetDim(0);
      int64_t C = output_shape.GetDim(1);
      int64_t H = output_shape.GetDim(2);
      int64_t W = output_shape.GetDim(3);
      // Kernel layout assumed [C_out, C_in, K_h, K_w] -- matches the
      // weight descriptors in the build examples; confirm for other models.
      int64_t C_in = kernel_shape.GetDim(1);
      int64_t K_h = kernel_shape.GetDim(2);
      int64_t K_w = kernel_shape.GetDim(3);
      // BUG FIX: the original omitted the input-channel factor, badly
      // underestimating convolution cost.
      return static_cast<float>(N * C * H * W * C_in * K_h * K_w);
    } else if (type == "MatMul") {
      auto shape_a = op->GetInputDesc(0).GetShape();
      auto shape_b = op->GetInputDesc(1).GetShape();
      int64_t M = shape_a.GetDim(0);
      int64_t K = shape_a.GetDim(1);
      int64_t N = shape_b.GetDim(1);
      return static_cast<float>(M * K * N);
    }
    return 1000.0f;  // flat default for un-modeled op types
  }
  // Sums the byte sizes of all input and output tensors.
  float EstimateMemoryUsage(const OpPtr& op) {
    float total = 0.0f;
    for (uint32_t i = 0; i < op->GetInputDescNum(); ++i) {
      total += op->GetInputDesc(i).GetBytes();
    }
    for (uint32_t i = 0; i < op->GetOutputDescNum(); ++i) {
      total += op->GetOutputDesc(i).GetBytes();
    }
    return total;
  }
  // Greedy load balancing: the heaviest op goes to the least-loaded device.
  // (A production implementation could use Kernighan-Lin or similar.)
  void PartitionGraph(
      Graph* graph,
      const std::unordered_map<OpPtr, Cost>& costs,
      std::vector<GraphPartition>* partitions) {
    partitions->resize(num_devices_);
    std::vector<float> device_loads(num_devices_, 0.0f);
    // Sort a copy of the op list by descending compute cost.
    auto ops = graph->GetAllOps();
    std::sort(ops.begin(), ops.end(),
              [&costs](const OpPtr& a, const OpPtr& b) {
                return costs.at(a).compute_cost > costs.at(b).compute_cost;
              });
    for (auto& op : ops) {
      // Pick the device with the smallest accumulated load.
      int min_device = 0;
      float min_load = device_loads[0];
      for (int i = 1; i < num_devices_; ++i) {
        if (device_loads[i] < min_load) {
          min_load = device_loads[i];
          min_device = i;
        }
      }
      (*partitions)[min_device].ops.push_back(op);
      device_loads[min_device] += costs.at(op).compute_cost;
    }
  }
  // Inserts a Send on the producer side and a Recv on the consumer side of
  // every edge that crosses a partition boundary.
  void InsertCommunicationOps(
      Graph* graph,
      std::vector<GraphPartition>& partitions) {
    for (size_t i = 0; i < partitions.size(); ++i) {
      // BUG FIX: compare int against int -- the original compared the
      // signed partition id with the unsigned loop index.
      const int this_partition = static_cast<int>(i);
      for (auto& op : partitions[i].ops) {
        auto outputs = op->GetOutputs();
        for (auto& output : outputs) {
          int output_partition = FindPartition(output, partitions);
          if (output_partition != this_partition && output_partition >= 0) {
            auto send_op = CreateSendOp(op, output, this_partition, output_partition);
            auto recv_op = CreateRecvOp(op, output, this_partition, output_partition);
            partitions[i].send_ops.push_back(send_op);
            partitions[output_partition].recv_ops.push_back(recv_op);
            graph->InsertOpAfter(op, send_op, 0);
            graph->InsertOpBefore(output, recv_op, 0);
          }
        }
      }
    }
  }
  // Returns the partition index owning |op|, or -1 if unassigned.
  int FindPartition(
      const OpPtr& op,
      const std::vector<GraphPartition>& partitions) {
    for (size_t i = 0; i < partitions.size(); ++i) {
      if (std::find(partitions[i].ops.begin(),
                    partitions[i].ops.end(), op)
          != partitions[i].ops.end()) {
        return static_cast<int>(i);  // explicit size_t -> int conversion
      }
    }
    return -1;
  }
  // Builds a Send op forwarding |src|'s first output to |dst_device|.
  OpPtr CreateSendOp(
      const OpPtr& src,
      const OpPtr& dst,
      int src_device,
      int dst_device) {
    std::string name = "Send_" + src->GetName() + "_to_" + dst->GetName();
    auto send_op = std::make_shared<Op>(name, "Send");
    send_op->AddInputDesc(src->GetOutputDesc(0));
    send_op->AddOutputDesc(src->GetOutputDesc(0));
    send_op->SetAttr("src_rank", src_device);
    send_op->SetAttr("dst_rank", dst_device);
    return send_op;
  }
  // Builds the matching Recv op on the consumer side.
  OpPtr CreateRecvOp(
      const OpPtr& src,
      const OpPtr& dst,
      int src_device,
      int dst_device) {
    std::string name = "Recv_" + src->GetName() + "_from_" + dst->GetName();
    auto recv_op = std::make_shared<Op>(name, "Recv");
    recv_op->AddInputDesc(src->GetOutputDesc(0));
    recv_op->AddOutputDesc(src->GetOutputDesc(0));
    recv_op->SetAttr("src_rank", src_device);
    recv_op->SetAttr("dst_rank", dst_device);
    return recv_op;
  }
  // Copies each partition's ops (plus its Send/Recv ops) into a standalone
  // per-device subgraph.
  void GeneratePartitionedGraphs(
      Graph* graph,
      const std::vector<GraphPartition>& partitions) {
    for (size_t i = 0; i < partitions.size(); ++i) {
      Graph subgraph("Partition_" + std::to_string(i));
      subgraph.SetDeviceId(static_cast<int>(i));  // explicit size_t -> int
      for (auto& op : partitions[i].ops) {
        subgraph.AddOp(op);
      }
      for (auto& op : partitions[i].send_ops) {
        subgraph.AddOp(op);
      }
      for (auto& op : partitions[i].recv_ops) {
        subgraph.AddOp(op);
      }
      partitioned_graphs_.push_back(subgraph);
    }
  }
  int num_devices_;
  std::vector<Graph> partitioned_graphs_;
};
四、性能优化技巧
4.1 编译时优化配置
cpp
// Configure advanced compilation options.
// NOTE(review): several fields set below (const folding, DCE, core type,
// memory optimization, streams) must exist on CompileOptions -- confirm
// against its declaration.
GraphCompiler::CompileOptions options;
// Enable every optimization tier.
options.opt_level = 3; // maximum optimization level
// Selectively enabled optimizations.
options.enable_fusion = true;
options.enable_precision_mode = true; // allow FP16
options.enable_const_folding = true;
options.enable_dead_code_elimination = true;
// Target device configuration.
options.target_device = "NPU";
options.target_core_type = "AI_CORE";
// Memory optimization.
options.enable_memory_optimize = true;
options.memory_optimize_level = 2;
// Parallelism configuration.
options.enable_multi_stream = true;
options.stream_num = 4;
4.2 融合策略配置
cpp
// Custom fusion strategy.
// NOTE(review): FusionStrategy and GraphOptimizer::SetFusionStrategy are
// not declared in the API excerpts above -- confirm they exist in the real
// GE interface.
FusionStrategy strategy;
// Fusion patterns to enable.
strategy.enable_conv_bn_relu = true;
strategy.enable_conv_bias_add_relu = true;
strategy.enable_matmul_add_relu = true;
strategy.enable_batch_norm_act = true;
// Fusion limits.
strategy.max_fused_ops = 10;
strategy.max_fusion_depth = 5;
// Install the strategy on the optimizer.
optimizer.SetFusionStrategy(strategy);
五、总结
GE图引擎是CANN框架的核心组件,通过多层IR和模块化Pass架构实现了灵活高效的图优化。从算子融合、内存优化到图分区,GE提供了全方位的优化能力,使得深度学习模型能够在NPU上获得最佳性能。
相关链接:
- CANN组织链接:https://atomgit.com/cann
- GE仓库链接:https://atomgit.com/cann/ge