Introduction
GE (Graph Engine) is CANN's core graph compiler and runtime. It lowers deep-learning computation graphs into efficient execution plans for the hardware. This article walks through GE's architecture, its graph optimization techniques, and the model compilation pipeline, so that developers can understand GE and apply it to model performance tuning.
Related links:
- CANN organization: https://atomgit.com/cann
- ge repository: https://atomgit.com/cann/ge
1. GE Graph Engine Architecture Overview
1.1 Core Architecture
GE uses a layered architecture. From top to bottom, the core components are:
```
┌─────────────────────────────────────────────────────┐
│ Frontend Interface Layer                            │
│ PyTorch Frontend | TensorFlow Frontend | ONNX       │
├─────────────────────────────────────────────────────┤
│ Graph Construction Layer                            │
│ IR Builder | Graph Parser | Validator               │
├─────────────────────────────────────────────────────┤
│ Graph Optimization Layer                            │
│ Pass Manager | Optimization Passes                  │
├──────────────────┬──────────────────┬───────────────┤
│ Operator Fusion  │ Memory Opt.      │ Scheduling    │
├──────────────────┴──────────────────┴───────────────┤
│ Code Generation Layer                               │
│ Kernel Selection | Binary Generation                │
├─────────────────────────────────────────────────────┤
│ Execution Engine Layer                              │
│ Stream Scheduler | Memory Manager | Executor        │
└─────────────────────────────────────────────────────┘
```
1.2 Main Functional Modules
| Module | Description | Key Capabilities |
|---|---|---|
| Graph construction | Parses frontend models and builds the compute-graph IR | Multi-framework support, type inference |
| Graph optimization | Applies optimization passes to improve performance | Operator fusion, constant folding, dead-code elimination |
| Memory management | Optimizes memory allocation and reuse | Memory reuse, in-place optimization |
| Scheduling optimization | Optimizes operator execution order | Multi-stream parallelism, asynchronous execution |
| Code generation | Produces the final executable code | Kernel selection, binary generation |
1.3 Directory Layout
```
ge/
├── include/
│   ├── ge/ge_api.h          # public API header
│   ├── ge/graph.h           # compute graph definition
│   ├── ge/ir_builder.h      # IR builder interface
│   └── ge/pass_manager.h    # pass manager
├── src/
│   ├── ir/                  # intermediate representation
│   │   ├── node.cc
│   │   ├── edge.cc
│   │   └── graph.cc
│   ├── builder/             # graph construction
│   │   ├── ir_builder.cc
│   │   └── graph_builder.cc
│   ├── passes/              # optimization passes
│   │   ├── fusion/
│   │   │   ├── conv_bn_fusion.cc
│   │   │   └── matmul_add_fusion.cc
│   │   ├── memory/
│   │   │   ├── memory_optimize.cc
│   │   │   └── inplace_opt.cc
│   │   └── transform/
│   │       ├── const_fold.cc
│   │       └── dead_elim.cc
│   ├── scheduler/           # schedulers
│   │   ├── multi_stream_sched.cc
│   │   └── memory_scheduler.cc
│   └── executor/            # execution engine
│       ├── graph_executor.cc
│       └── stream_executor.cc
├── tests/
├── examples/
└── docs/
```
2. Building Computation Graphs
2.1 Basic Graph Construction
```cpp
#include <cstdint>
#include <memory>
#include <vector>

#include "ge/ge_api.h"
#include "ge/ir_builder.h"

using namespace ge;

class GraphBuilder {
 public:
  // Build a simple Conv -> BiasAdd -> Relu -> MaxPool graph by hand
  GraphPtr BuildSimpleGraph() {
    // Create the compute graph
    auto graph = Graph::Make("SimpleGraph");

    // Create the data nodes (input, weight, bias)
    auto data1 = OpData::Make("data1", DT_FLOAT, {1, 3, 224, 224});
    auto data2 = OpData::Make("data2", DT_FLOAT, {64, 3, 7, 7});
    auto data3 = OpData::Make("data3", DT_FLOAT, {64});

    // Create the convolution operator
    auto conv = Op::Make("Conv2D");
    conv->SetAttr("strides", std::vector<int64_t>{1, 1, 1, 1});
    conv->SetAttr("pads", std::vector<int64_t>{0, 0, 0, 0});
    conv->SetAttr("dilations", std::vector<int64_t>{1, 1, 1, 1});
    conv->SetAttr("groups", 1);
    conv->SetAttr("data_format", "NCHW");

    // Create the BiasAdd operator
    auto bias_add = Op::Make("BiasAdd");
    bias_add->SetAttr("data_format", "NCHW");

    // Create the Relu operator
    auto relu = Op::Make("Relu");

    // Create the pooling operator
    auto pool = Op::Make("MaxPool");
    pool->SetAttr("ksize", std::vector<int64_t>{1, 1, 2, 2});
    pool->SetAttr("strides", std::vector<int64_t>{1, 1, 2, 2});
    pool->SetAttr("pads", std::vector<int64_t>{0, 0, 0, 0});
    pool->SetAttr("data_format", "NCHW");

    // Add the operators to the graph
    auto conv_out = graph->AddNode(conv);
    auto bias_out = graph->AddNode(bias_add);
    auto relu_out = graph->AddNode(relu);
    auto pool_out = graph->AddNode(pool);

    // Add the data nodes to the graph
    graph->AddDataNode(data1);
    graph->AddDataNode(data2);
    graph->AddDataNode(data3);

    // Wire up the data edges
    graph->AddDataEdge(data1, 0, conv_out, 0);     // data1 -> conv (feature map)
    graph->AddDataEdge(data2, 0, conv_out, 1);     // data2 -> conv (weight)
    graph->AddDataEdge(conv_out, 0, bias_out, 0);  // conv -> bias_add
    graph->AddDataEdge(data3, 0, bias_out, 1);     // data3 -> bias_add (bias)
    graph->AddDataEdge(bias_out, 0, relu_out, 0);  // bias_add -> relu
    graph->AddDataEdge(relu_out, 0, pool_out, 0);  // relu -> pool

    // Mark the graph output
    graph->SetOutput({pool_out});
    return graph;
  }

  // Build the same graph with the IR Builder
  GraphPtr BuildGraphWithIRBuilder() {
    // Create an IR builder
    auto builder = IRBuilder::Create("MyGraph");

    // Add the input
    auto input = builder->AddInput("input", DT_FLOAT, {1, 3, 224, 224});

    // Add constants (weight and bias)
    auto weight = builder->AddConstant("weight", DT_FLOAT, {64, 3, 7, 7},
                                       InitWeightData());
    auto bias = builder->AddConstant("bias", DT_FLOAT, {64},
                                     InitBiasData());

    // Convolution layer
    auto conv = builder->AddOp("Conv2D", {input, weight});
    builder->SetAttr(conv, "strides", std::vector<int64_t>{1, 1, 1, 1});
    builder->SetAttr(conv, "pads", std::vector<int64_t>{0, 0, 0, 0});

    // BiasAdd
    auto bias_add = builder->AddOp("BiasAdd", {conv, bias});

    // ReLU
    auto relu = builder->AddOp("Relu", {bias_add});

    // MaxPool
    auto pool = builder->AddOp("MaxPool", {relu});
    builder->SetAttr(pool, "ksize", std::vector<int64_t>{1, 1, 2, 2});
    builder->SetAttr(pool, "strides", std::vector<int64_t>{1, 1, 2, 2});

    // Mark the output
    builder->SetOutput({pool});
    return builder->Build();
  }

 private:
  std::vector<float> InitWeightData() {
    return std::vector<float>(64 * 3 * 7 * 7, 0.1f);
  }
  std::vector<float> InitBiasData() {
    return std::vector<float>(64, 0.0f);
  }
};
```
2.2 Importing Graphs from ONNX
```cpp
#include <iostream>
#include <map>
#include <string>

#include "ge/onnx_parser.h"

class ONNXGraphImporter {
 public:
  GraphPtr ImportFromONNX(const std::string& onnx_path) {
    // Create the ONNX parser
    auto parser = ONNXParser::Create();

    // Configure the parse options
    ParseOptions options;
    options.enable_fusion = true;
    options.enable_shape_inference = true;
    options.enable_const_fold = true;

    // Parse the ONNX model
    auto parse_result = parser->Parse(onnx_path, options);
    if (!parse_result.success) {
      std::cerr << "Failed to parse ONNX: " << parse_result.error_msg << std::endl;
      return nullptr;
    }

    // Fetch the resulting graph
    auto graph = parse_result.graph;

    // Print graph information
    PrintGraphInfo(graph);
    return graph;
  }

  // Import with shape inference enabled
  GraphPtr ImportWithShapeInference(const std::string& onnx_path,
                                    const std::map<std::string, Shape>& input_shapes) {
    auto parser = ONNXParser::Create();

    // Pin the input shapes
    parser->SetInputShapes(input_shapes);

    // Enable shape inference
    ParseOptions options;
    options.enable_shape_inference = true;
    options.infer_dynamic_shape = true;

    auto result = parser->Parse(onnx_path, options);
    return result.success ? result.graph : nullptr;
  }

 private:
  void PrintGraphInfo(GraphPtr graph) {
    std::cout << "=== Graph Information ===" << std::endl;
    std::cout << "Graph name: " << graph->GetName() << std::endl;
    std::cout << "Number of nodes: " << graph->GetNodeCount() << std::endl;
    std::cout << "Number of inputs: " << graph->GetInputCount() << std::endl;
    std::cout << "Number of outputs: " << graph->GetOutputCount() << std::endl;

    // Print the input shapes
    auto inputs = graph->GetInputs();
    std::cout << "Inputs:" << std::endl;
    for (const auto& input : inputs) {
      auto shape = input->GetShape();
      std::cout << "  " << input->GetName() << ": [";
      for (size_t i = 0; i < shape.size(); ++i) {
        std::cout << shape[i];
        if (i < shape.size() - 1) std::cout << ", ";
      }
      std::cout << "]" << std::endl;
    }
  }
};
```
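A call site for the importer is then a few lines. This is a hypothetical usage sketch: the model path and the assumption that Shape is constructible from an initializer list are illustrative, not part of any shipped API:
```cpp
#include <iostream>
#include <map>

// Hypothetical usage of the ONNXGraphImporter defined above.
int main() {
  ONNXGraphImporter importer;
  // Shape's initializer-list constructor is assumed for illustration
  std::map<std::string, Shape> input_shapes = {
      {"input", Shape({1, 3, 224, 224})}};
  auto graph = importer.ImportWithShapeInference("model.onnx", input_shapes);
  if (graph == nullptr) {
    std::cerr << "ONNX import failed" << std::endl;
    return -1;
  }
  return 0;
}
```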
3. Graph Optimization Techniques
3.1 Operator Fusion
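The algebra behind Conv+BN fusion is worth stating before the code. In inference mode, BatchNorm is a per-channel affine transform with learned scale $\gamma_o$, shift $\beta_o$, and running statistics $\mu_o$, $\sigma_o^2$, so it folds into the convolution that feeds it:

$$W_o' = W_o \cdot \frac{\gamma_o}{\sqrt{\sigma_o^2 + \epsilon}}, \qquad b_o' = (b_o - \mu_o) \cdot \frac{\gamma_o}{\sqrt{\sigma_o^2 + \epsilon}} + \beta_o$$

where $o$ indexes output channels. `ComputeFusedWeight` below applies the first formula channel by channel; `ComputeFusedBias` (omitted) applies the second.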
```cpp
#include <cmath>
#include <iostream>
#include <vector>

#include "ge/pass_manager.h"

class FusionOptimization {
 public:
  // Conv + BN fusion
  class ConvBNFusionPass : public GraphPass {
   public:
    bool Run(GraphPtr graph) override {
      bool modified = false;
      // Visit every Conv node in the graph
      auto conv_nodes = graph->GetNodesByType("Conv2D");
      for (auto conv_node : conv_nodes) {
        // Look for a BatchNorm immediately after the Conv
        auto bn_node = FindSuccessorBatchNorm(conv_node);
        if (bn_node == nullptr) continue;
        // Check whether the pair is fusible
        if (!CanFuseConvBN(conv_node, bn_node)) continue;
        // Perform the fusion
        FuseConvBN(conv_node, bn_node);
        modified = true;
        std::cout << "Fused Conv+BN: " << conv_node->GetName()
                  << " + " << bn_node->GetName() << std::endl;
      }
      return modified;
    }

   private:
    NodePtr FindSuccessorBatchNorm(NodePtr conv_node) {
      auto outputs = conv_node->GetOutNodes();
      for (auto output : outputs) {
        if (output->GetType() == "BatchNormalization") {
          return output;
        }
      }
      return nullptr;
    }

    bool CanFuseConvBN(NodePtr conv, NodePtr bn) {
      // The Conv output must have exactly one consumer
      if (conv->GetOutNodes().size() != 1) return false;
      // Only inference-mode BN can be folded
      bool training = false;
      if (!bn->GetAttr("training", training)) return false;
      if (training) return false;
      return true;
    }

    void FuseConvBN(NodePtr conv, NodePtr bn) {
      // Fetch the Conv weight and (optional) bias
      auto conv_weight = GetNodeWeight(conv, "filter");
      auto conv_bias = GetNodeBias(conv);
      // Fetch the BN parameters
      auto bn_gamma = GetNodeWeight(bn, "scale");
      auto bn_beta = GetNodeWeight(bn, "bias");
      auto bn_mean = GetNodeWeight(bn, "mean");
      auto bn_var = GetNodeWeight(bn, "variance");
      float epsilon = 1e-5f;
      bn->GetAttr("epsilon", epsilon);
      // Compute the fused weight and bias
      auto fused_weight = ComputeFusedWeight(conv_weight, bn_gamma, bn_mean, bn_var, epsilon);
      auto fused_bias = ComputeFusedBias(conv_bias, conv_weight, bn_gamma, bn_beta, bn_mean, bn_var, epsilon);
      // Update the Conv node in place
      UpdateNodeWeight(conv, "filter", fused_weight);
      if (conv_bias) {
        UpdateNodeBias(conv, fused_bias);
      } else {
        SetNodeBias(conv, fused_bias);
      }
      // Redirect the BN output to the Conv output
      ReplaceNode(bn, conv);
    }

    std::vector<float> ComputeFusedWeight(const std::vector<float>& conv_w,
                                          const std::vector<float>& gamma,
                                          const std::vector<float>& mean,
                                          const std::vector<float>& var,
                                          float eps) {
      // fused_w[o, ...] = conv_w[o, ...] * gamma[o] / sqrt(var[o] + eps),
      // applied per output channel o
      std::vector<float> fused(conv_w.size());
      const size_t out_channels = gamma.size();
      const size_t per_channel = conv_w.size() / out_channels;
      for (size_t o = 0; o < out_channels; ++o) {
        float scale = gamma[o] / std::sqrt(var[o] + eps);
        for (size_t i = 0; i < per_channel; ++i) {
          fused[o * per_channel + i] = conv_w[o * per_channel + i] * scale;
        }
      }
      return fused;
    }
  };

  // MatMul + Add fusion
  class MatMulAddFusionPass : public GraphPass {
   public:
    bool Run(GraphPtr graph) override {
      auto matmul_nodes = graph->GetNodesByType("MatMul");
      bool modified = false;
      for (auto matmul_node : matmul_nodes) {
        auto add_node = FindSuccessorAdd(matmul_node);
        if (add_node == nullptr) continue;
        // The other Add input must be a constant (the bias)
        auto add_input = GetOtherInput(add_node, matmul_node);
        if (!IsConstant(add_input)) continue;
        // Fuse: attach the bias to the MatMul output
        FuseMatMulAdd(matmul_node, add_node, add_input);
        modified = true;
      }
      return modified;
    }

   private:
    NodePtr FindSuccessorAdd(NodePtr matmul) {
      auto outputs = matmul->GetOutNodes();
      for (auto output : outputs) {
        if (output->GetType() == "Add") {
          return output;
        }
      }
      return nullptr;
    }

    NodePtr GetOtherInput(NodePtr add_node, NodePtr matmul_node) {
      auto inputs = add_node->GetInNodes();
      for (auto input : inputs) {
        if (input != matmul_node) {
          return input;
        }
      }
      return nullptr;
    }

    bool IsConstant(NodePtr node) {
      return node != nullptr && node->GetType() == "Constant";
    }

    void FuseMatMulAdd(NodePtr matmul, NodePtr add, NodePtr bias) {
      // Fetch the bias data
      auto bias_data = GetNodeWeight(bias);
      // Attach the bias to the MatMul node
      matmul->SetAttr("has_bias", true);
      matmul->SetAttr("bias", bias_data);
      // Replace the Add with the (now biased) MatMul
      ReplaceNode(add, matmul);
    }
  };
};
```
3.2 Memory Optimization
```cpp
#include <algorithm>
#include <map>
#include <vector>

class MemoryOptimization {
 public:
  // Memory reuse optimization
  class MemoryReusePass : public GraphPass {
   public:
    bool Run(GraphPtr graph) override {
      // Liveness analysis over the scheduled graph
      auto live_ranges = AnalyzeLiveRanges(graph);
      // Compute a reuse plan
      auto reuse_plan = ComputeMemoryReuse(live_ranges);
      // Apply the plan to the graph
      ApplyMemoryReuse(graph, reuse_plan);
      return true;
    }

   private:
    struct LiveRange {
      int start;    // schedule position where the tensor is produced
      int end;      // schedule position of its last read
      size_t size;  // tensor size in bytes
    };

    std::map<NodePtr, LiveRange> AnalyzeLiveRanges(GraphPtr graph) {
      std::map<NodePtr, LiveRange> live_ranges;
      // Topological order doubles as the execution schedule
      auto sorted_nodes = TopologicalSort(graph);
      int schedule_id = 0;
      for (auto node : sorted_nodes) {
        // Estimate the output tensor size
        size_t tensor_size = EstimateTensorSize(node);
        LiveRange range;
        range.start = schedule_id;
        range.end = schedule_id;  // refined below
        range.size = tensor_size;
        live_ranges[node] = range;
        schedule_id++;
      }
      // Find the last use of every tensor
      ComputeLastUses(live_ranges, graph);
      return live_ranges;
    }

    void ComputeLastUses(std::map<NodePtr, LiveRange>& live_ranges, GraphPtr graph) {
      // Walk the schedule in order; later uses overwrite earlier ones, so
      // every tensor ends up mapped to its final read position
      auto sorted_nodes = TopologicalSort(graph);
      std::map<NodePtr, int> last_uses;
      for (int i = 0; i < static_cast<int>(sorted_nodes.size()); ++i) {
        for (auto input : sorted_nodes[i]->GetInNodes()) {
          if (live_ranges.find(input) != live_ranges.end()) {
            last_uses[input] = i;
          }
        }
      }
      // Extend each live range to its last use
      for (auto& [node, range] : live_ranges) {
        auto it = last_uses.find(node);
        if (it != last_uses.end()) {
          range.end = it->second;
        }
      }
    }

    std::map<NodePtr, void*> ComputeMemoryReuse(const std::map<NodePtr, LiveRange>& live_ranges) {
      std::map<NodePtr, void*> reuse_plan;
      // Simplified greedy strategy: take the first block whose current
      // residents do not overlap the new range
      std::vector<void*> memory_blocks;
      for (const auto& [node, range] : live_ranges) {
        bool reused = false;
        // Look for a reusable block
        for (auto block : memory_blocks) {
          if (CanReuseBlock(block, range, reuse_plan, live_ranges)) {
            reuse_plan[node] = block;
            reused = true;
            break;
          }
        }
        // Nothing reusable: allocate a fresh block
        if (!reused) {
          void* new_block = AllocateMemory(range.size);
          memory_blocks.push_back(new_block);
          reuse_plan[node] = new_block;
        }
      }
      return reuse_plan;
    }

    bool CanReuseBlock(void* block, const LiveRange& range,
                       const std::map<NodePtr, void*>& reuse_plan,
                       const std::map<NodePtr, LiveRange>& live_ranges) {
      // The block is reusable only if no tensor already assigned to it is
      // live anywhere in [range.start, range.end]. (A production planner
      // would also check that the block is large enough.)
      for (const auto& [node, assigned_block] : reuse_plan) {
        if (assigned_block != block) continue;
        const auto& other = live_ranges.at(node);
        if (other.start <= range.end && range.start <= other.end) {
          return false;
        }
      }
      return true;
    }
  };

  // In-place optimization
  class InplaceOptimizationPass : public GraphPass {
   public:
    bool Run(GraphPtr graph) override {
      // Find nodes that can run in place
      auto candidates = FindInplaceCandidates(graph);
      // Mark them
      for (auto candidate : candidates) {
        ApplyInplace(graph, candidate);
      }
      return !candidates.empty();
    }

   private:
    std::vector<NodePtr> FindInplaceCandidates(GraphPtr graph) {
      std::vector<NodePtr> candidates;
      // Elementwise ops such as ReLU are natural in-place candidates
      auto relu_nodes = graph->GetNodesByType("Relu");
      for (auto relu : relu_nodes) {
        if (CanInplace(relu)) {
          candidates.push_back(relu);
        }
      }
      return candidates;
    }

    bool CanInplace(NodePtr node) {
      auto inputs = node->GetInNodes();
      if (inputs.empty()) return false;
      // Overwriting the input is only safe if this node is its sole consumer
      if (inputs[0]->GetOutNodes().size() != 1) return false;
      // Input and output shapes must match
      auto input_shape = inputs[0]->GetShape();
      auto output_shape = node->GetShape();
      return input_shape == output_shape;
    }

    void ApplyInplace(GraphPtr graph, NodePtr node) {
      // Mark the node as in-place
      node->SetAttr("inplace", true);
      // Record which input's memory the output reuses
      auto inputs = node->GetInNodes();
      if (!inputs.empty()) {
        node->SetAttr("reuse_input", inputs[0]->GetName());
      }
    }
  };
};
```
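The greedy interval strategy above is easier to see on a toy schedule. The snippet below is a standalone sketch (plain integers instead of GE types) showing how tensors with non-overlapping live ranges end up sharing a block:
```cpp
#include <cstdio>
#include <vector>

// Standalone illustration of greedy live-range based memory reuse.
// Each tensor has a closed [start, end] live interval; tensors whose
// intervals do not overlap may share a memory block.
struct Range { int start, end; };

int main() {
  // A chain A -> B -> C -> D: A dies once B is produced, so C can take
  // A's block and D can take B's block (classic double buffering).
  std::vector<Range> ranges = {{0, 1}, {1, 2}, {2, 3}, {3, 4}};
  std::vector<std::vector<int>> block_residents;   // block id -> tensor ids
  std::vector<int> assignment(ranges.size(), -1);  // tensor id -> block id

  for (int t = 0; t < (int)ranges.size(); ++t) {
    for (int b = 0; b < (int)block_residents.size() && assignment[t] < 0; ++b) {
      bool free = true;
      for (int other : block_residents[b]) {
        // Overlap test on closed intervals
        if (ranges[other].start <= ranges[t].end &&
            ranges[t].start <= ranges[other].end) { free = false; break; }
      }
      if (free) { block_residents[b].push_back(t); assignment[t] = b; }
    }
    if (assignment[t] < 0) {  // no block is free: allocate a new one
      block_residents.push_back({t});
      assignment[t] = (int)block_residents.size() - 1;
    }
  }
  // Expected: tensors 0 and 2 share block 0; tensors 1 and 3 share block 1
  for (int t = 0; t < (int)assignment.size(); ++t)
    std::printf("tensor %d -> block %d\n", t, assignment[t]);
  return 0;
}
```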
3.3 Constant Folding
```cpp
#include <map>
#include <string>
#include <vector>

class ConstantFoldingPass : public GraphPass {
 public:
  bool Run(GraphPtr graph) override {
    bool modified = false;
    // Collect all constant nodes
    auto constants = graph->GetNodesByType("Constant");
    // Cache their values
    std::map<NodePtr, Tensor> constant_values;
    for (auto const_node : constants) {
      constant_values[const_node] = EvaluateConstant(const_node);
    }
    // Try to fold every node
    auto nodes = graph->GetNodes();
    for (auto node : nodes) {
      if (TryFoldNode(node, constant_values)) {
        modified = true;
      }
    }
    // Clean up constants that are no longer referenced
    if (modified) {
      RemoveUnusedConstants(graph);
    }
    return modified;
  }

 private:
  bool TryFoldNode(NodePtr node, std::map<NodePtr, Tensor>& constant_values) {
    // Constants themselves (and nodes without inputs) are not candidates
    auto inputs = node->GetInNodes();
    if (node->GetType() == "Constant" || inputs.empty()) return false;
    // Every input must be a known constant
    for (auto input : inputs) {
      if (constant_values.find(input) == constant_values.end()) {
        return false;
      }
    }
    // All inputs are constants: evaluate the node at compile time
    auto folded_result = EvaluateNode(node, constant_values);
    if (folded_result.GetDataCount() == 0) return false;  // op not supported
    // Create a new constant node holding the result
    auto new_const = CreateConstantNode(node->GetName() + "_folded",
                                        folded_result);
    // Replace the original node
    ReplaceNode(node, new_const);
    // Update the cache
    constant_values[new_const] = folded_result;
    return true;
  }

  Tensor EvaluateNode(NodePtr node, const std::map<NodePtr, Tensor>& inputs) {
    auto op_type = node->GetType();
    if (op_type == "Add") {
      return EvaluateAdd(node, inputs);
    } else if (op_type == "Mul") {
      return EvaluateMul(node, inputs);
    } else if (op_type == "Sub") {
      return EvaluateSub(node, inputs);  // analogous to EvaluateAdd, omitted
    } else if (op_type == "Div") {
      return EvaluateDiv(node, inputs);  // analogous to EvaluateMul, omitted
    }
    // Unsupported op: return an empty tensor so the caller skips it
    return Tensor();
  }

  Tensor EvaluateAdd(NodePtr node, const std::map<NodePtr, Tensor>& inputs) {
    auto in_nodes = node->GetInNodes();
    auto input1 = inputs.at(in_nodes[0]);
    auto input2 = inputs.at(in_nodes[1]);
    // Elementwise add; assumes equal shapes (no broadcasting)
    Tensor result = input1;
    for (size_t i = 0; i < result.GetDataCount(); ++i) {
      result.GetData<float>()[i] += input2.GetData<float>()[i];
    }
    return result;
  }

  Tensor EvaluateMul(NodePtr node, const std::map<NodePtr, Tensor>& inputs) {
    auto in_nodes = node->GetInNodes();
    auto input1 = inputs.at(in_nodes[0]);
    auto input2 = inputs.at(in_nodes[1]);
    Tensor result = input1;
    for (size_t i = 0; i < result.GetDataCount(); ++i) {
      result.GetData<float>()[i] *= input2.GetData<float>()[i];
    }
    return result;
  }

  void RemoveUnusedConstants(GraphPtr graph) {
    auto constants = graph->GetNodesByType("Constant");
    for (auto const_node : constants) {
      if (const_node->GetOutNodes().empty()) {
        graph->RemoveNode(const_node);
      }
    }
  }
};
```
4. Multi-Stream Parallel Scheduling
4.1 Stream Scheduler
```cpp
#include <algorithm>
#include <map>
#include <vector>

class MultiStreamScheduler {
 public:
  struct ScheduleConfig {
    int num_streams = 4;
    bool enable_pipeline = true;
    int pipeline_depth = 3;
    bool enable_memory_opt = true;
  };

  // Produce a multi-stream execution plan
  ScheduleResult Schedule(GraphPtr graph, const ScheduleConfig& config) {
    ScheduleResult result;
    // 1. Partition the graph
    auto partitions = PartitionGraph(graph, config.num_streams);
    // 2. Analyze cross-partition dependencies
    auto dependencies = AnalyzeDependencies(partitions);
    // 3. Generate the execution plan
    if (config.enable_pipeline) {
      result = GeneratePipelineSchedule(partitions, dependencies, config);
    } else {
      result = GenerateParallelSchedule(partitions, dependencies);
    }
    // 4. Optional memory optimization
    if (config.enable_memory_opt) {
      OptimizeMemoryUsage(result);
    }
    return result;
  }

 private:
  // Graph partitioning
  std::vector<GraphPartition> PartitionGraph(GraphPtr graph, int num_partitions) {
    std::vector<GraphPartition> partitions(num_partitions);
    // Simplified strategy: assign whole levels round-robin across partitions
    auto levels = ComputeGraphLevels(graph);
    for (size_t i = 0; i < levels.size(); ++i) {
      int partition_id = i % num_partitions;
      for (auto node : levels[i]) {
        partitions[partition_id].AddNode(node);
      }
    }
    return partitions;
  }

  // Compute the topological level of every node
  std::vector<std::vector<NodePtr>> ComputeGraphLevels(GraphPtr graph) {
    std::map<NodePtr, int> node_levels;
    std::vector<NodePtr> sorted_nodes = TopologicalSort(graph);
    // A node's level is one past the deepest of its inputs
    for (auto node : sorted_nodes) {
      int max_input_level = -1;
      for (auto input : node->GetInNodes()) {
        if (node_levels.find(input) != node_levels.end()) {
          max_input_level = std::max(max_input_level, node_levels[input]);
        }
      }
      node_levels[node] = max_input_level + 1;
    }
    // Group nodes by level
    std::map<int, std::vector<NodePtr>> level_map;
    for (const auto& [node, level] : node_levels) {
      level_map[level].push_back(node);
    }
    std::vector<std::vector<NodePtr>> levels;
    for (const auto& [level, nodes] : level_map) {
      levels.push_back(nodes);
    }
    return levels;
  }

  // Build a pipelined schedule
  ScheduleResult GeneratePipelineSchedule(const std::vector<GraphPartition>& partitions,
                                          const DependencyMap& dependencies,
                                          const ScheduleConfig& config) {
    ScheduleResult result;
    result.num_streams = config.num_streams;
    // One stream per partition
    for (size_t i = 0; i < partitions.size(); ++i) {
      StreamSchedule stream_sched;
      stream_sched.stream_id = i;
      stream_sched.nodes = partitions[i].GetNodes();
      // Record dependencies on other streams
      for (size_t j = 0; j < partitions.size(); ++j) {
        if (i != j && HasDependency(partitions[i], partitions[j], dependencies)) {
          stream_sched.dependencies.push_back(j);
        }
      }
      result.stream_schedules.push_back(stream_sched);
    }
    return result;
  }
};
```
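Used standalone, the scheduler is driven entirely through ScheduleConfig. A hypothetical call site (graph as built in Section 2; ScheduleResult fields as declared above):
```cpp
#include <iostream>

// Hypothetical standalone use of the scheduler sketched above.
MultiStreamScheduler scheduler;
MultiStreamScheduler::ScheduleConfig config;
config.num_streams = 4;         // degree of stream-level parallelism
config.enable_pipeline = true;  // overlap partitions via events
config.pipeline_depth = 3;

auto plan = scheduler.Schedule(graph, config);
for (const auto& s : plan.stream_schedules) {
  std::cout << "stream " << s.stream_id << ": " << s.nodes.size()
            << " nodes, waits on " << s.dependencies.size()
            << " other stream(s)" << std::endl;
}
```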
4.2 Asynchronous Execution
```cpp
#include <vector>
#include "acl/acl.h"  // ACL runtime: streams and events

class AsyncGraphExecutor {
 public:
  struct ExecutionConfig {
    int num_streams = 4;
    bool enable_profiling = false;
    int queue_depth = 16;
  };

  void ExecuteAsync(GraphPtr graph, const ExecutionConfig& config) {
    // 1. Build the execution plan
    MultiStreamScheduler scheduler;
    auto schedule = scheduler.Schedule(graph, {config.num_streams});
    // 2. Set up the execution environment (creates streams_,
    //    sizes stream_events_, allocates buffers)
    InitializeExecution(schedule, config);
    // 3. Launch every stream asynchronously
    for (const auto& stream_sched : schedule.stream_schedules) {
      ExecuteStreamAsync(stream_sched);
    }
    // 4. Wait for all streams to finish
    WaitForCompletion();
  }

 private:
  void ExecuteStreamAsync(const StreamSchedule& stream_sched) {
    auto stream = streams_[stream_sched.stream_id];
    // Wait for the streams we depend on. Note: this assumes the schedules
    // are launched in dependency order, so those events exist already.
    for (int dep_stream_id : stream_sched.dependencies) {
      auto dep_event = stream_events_[dep_stream_id];
      aclrtStreamWaitEvent(stream, dep_event);
    }
    // Launch every node on this stream
    for (auto node : stream_sched.nodes) {
      ExecuteNodeAsync(node, stream);
    }
    // Record a completion event that other streams can wait on
    aclrtEvent event;
    aclrtCreateEvent(&event);
    aclrtRecordEvent(event, stream);
    stream_events_[stream_sched.stream_id] = event;
  }

  void ExecuteNodeAsync(NodePtr node, aclrtStream stream) {
    // Dispatch on the node type
    auto op_type = node->GetType();
    if (op_type == "Conv2D") {
      ExecuteConv2DAsync(node, stream);
    } else if (op_type == "MatMul") {
      ExecuteMatMulAsync(node, stream);
    } else if (op_type == "Add") {
      ExecuteAddAsync(node, stream);
    } else {
      // Generic operator launch
      ExecuteOpAsync(node, stream);
    }
  }

  std::vector<aclrtStream> streams_;
  std::vector<aclrtEvent> stream_events_;
};
```
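InitializeExecution and WaitForCompletion are left abstract above. A minimal sketch of what they have to do, using real ACL runtime calls (aclrtCreateStream, aclrtSynchronizeStream, aclrtDestroyEvent) and eliding error handling, might look like this:
```cpp
#include <vector>
#include "acl/acl.h"

// Hedged sketch of the stream lifecycle the executor above assumes.
// Production code must check every aclError return value.
class StreamPool {
 public:
  explicit StreamPool(int num_streams) {
    streams_.resize(num_streams);
    events_.resize(num_streams, nullptr);
    for (auto& s : streams_) {
      aclrtCreateStream(&s);  // one device stream per schedule lane
    }
  }
  ~StreamPool() {
    for (auto e : events_) {
      if (e != nullptr) aclrtDestroyEvent(e);
    }
    for (auto s : streams_) {
      aclrtDestroyStream(s);
    }
  }
  // Block the host until every stream has drained
  void WaitForCompletion() {
    for (auto s : streams_) {
      aclrtSynchronizeStream(s);
    }
  }

  std::vector<aclrtStream> streams_;
  std::vector<aclrtEvent> events_;
};
```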
5. Graph Compilation and Code Generation
5.1 Compilation Pipeline
```cpp
#include <memory>
#include <string>

class GraphCompiler {
 public:
  struct CompileOptions {
    bool enable_optimization = true;
    bool enable_fusion = true;
    int opt_level = 3;  // 0-3
    std::string target_arch = "latest";
    bool generate_debug_info = false;
  };

  CompileResult Compile(GraphPtr graph, const CompileOptions& options) {
    CompileResult result;
    // 1. Validate the graph
    if (!ValidateGraph(graph)) {
      result.success = false;
      result.error_msg = "Graph validation failed";
      return result;
    }
    // 2. Optimize the graph
    auto optimized_graph = graph;
    if (options.enable_optimization) {
      optimized_graph = OptimizeGraph(graph, options);
    }
    // 3. Select kernels
    auto kernel_selection = SelectKernels(optimized_graph, options.target_arch);
    // 4. Generate executable code
    auto executable = GenerateCode(optimized_graph, kernel_selection, options);
    // 5. Emit the final binary
    result.binary = GenerateBinary(executable);
    result.success = true;
    return result;
  }

 private:
  GraphPtr OptimizeGraph(GraphPtr graph, const CompileOptions& options) {
    auto pass_manager = PassManager::Create();
    // Register the optimization passes
    if (options.enable_fusion) {
      pass_manager->RegisterPass(std::make_shared<ConvBNFusionPass>());
      pass_manager->RegisterPass(std::make_shared<MatMulAddFusionPass>());
    }
    pass_manager->RegisterPass(std::make_shared<ConstantFoldingPass>());
    pass_manager->RegisterPass(std::make_shared<MemoryReusePass>());
    pass_manager->RegisterPass(std::make_shared<DeadCodeEliminationPass>());
    // Run them
    auto result = pass_manager->Run(graph);
    return result.optimized_graph;
  }

  KernelSelection SelectKernels(GraphPtr graph, const std::string& arch) {
    KernelSelection selection;
    auto nodes = graph->GetNodes();
    for (auto node : nodes) {
      auto kernel = SelectBestKernel(node, arch);
      selection[node->GetId()] = kernel;
    }
    return selection;
  }

  KernelInfo SelectBestKernel(NodePtr node, const std::string& arch) {
    // Query every kernel implementation available for this op and arch
    auto available_kernels = QueryAvailableKernels(node->GetType(), arch);
    // Pick the one with the best performance score
    KernelInfo best_kernel;
    float best_score = -1.0f;
    for (const auto& kernel : available_kernels) {
      float score = EvaluateKernel(kernel, node);
      if (score > best_score) {
        best_score = score;
        best_kernel = kernel;
      }
    }
    return best_kernel;
  }

  ExecutableGraph GenerateCode(GraphPtr graph,
                               const KernelSelection& selection,
                               const CompileOptions& options) {
    ExecutableGraph exec_graph;
    // Emit an executable node for every graph node
    auto nodes = graph->GetNodes();
    for (auto node : nodes) {
      auto kernel = selection.at(node->GetId());
      auto exec_node = GenerateExecutableNode(node, kernel);
      exec_graph.AddNode(exec_node);
    }
    // Emit the execution plan
    auto exec_plan = GenerateExecutionPlan(graph, selection);
    exec_graph.SetExecutionPlan(exec_plan);
    return exec_graph;
  }
};
```
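Driving the compiler end to end then takes a handful of lines. Everything here is this article's illustrative API, with the graph coming from the Section 2.1 builder (the shipped GE public API routes compilation through its session interface instead):
```cpp
#include <iostream>

// Hypothetical end-to-end driver for the classes sketched in this article.
int main() {
  // Build (Section 2.1)
  GraphBuilder builder;
  auto graph = builder.BuildGraphWithIRBuilder();

  // Compile: validate -> optimize -> select kernels -> codegen -> binary
  GraphCompiler compiler;
  GraphCompiler::CompileOptions options;
  options.opt_level = 3;
  options.enable_fusion = true;  // enables the Conv+BN / MatMul+Add passes
  auto compiled = compiler.Compile(graph, options);
  if (!compiled.success) {
    std::cerr << "Compile failed: " << compiled.error_msg << std::endl;
    return -1;
  }
  // compiled.binary can now be serialized, or loaded for execution (Section 5.2)
  return 0;
}
```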
5.2 Executable Graph
```cpp
#include <any>
#include <map>
#include <string>
#include <vector>

class ExecutableGraph {
 public:
  struct ExecutableNode {
    std::string name;
    std::string kernel_name;
    std::vector<void*> input_buffers;
    std::vector<void*> output_buffers;
    std::map<std::string, std::any> attributes;
  };

  void Execute() {
    // Run the nodes in planned order
    for (const auto& node : execution_plan_) {
      ExecuteNode(node);
    }
  }

  void ExecuteAsync(aclrtStream stream) {
    for (const auto& node : execution_plan_) {
      ExecuteNodeAsync(node, stream);
    }
  }

 private:
  void ExecuteNode(const ExecutableNode& node) {
    // Look up the kernel entry point (synchronous variant)
    auto kernel_func = GetKernelFunction(node.kernel_name);
    // Gather the input and output buffers as kernel arguments
    std::vector<void*> args;
    for (auto buf : node.input_buffers) args.push_back(buf);
    for (auto buf : node.output_buffers) args.push_back(buf);
    // Invoke the kernel
    kernel_func(args.data(), args.size());
  }

  void ExecuteNodeAsync(const ExecutableNode& node, aclrtStream stream) {
    // Asynchronous variant: the kernel is enqueued on the given stream
    auto kernel_func = GetKernelFunction(node.kernel_name);
    std::vector<void*> args;
    for (auto buf : node.input_buffers) args.push_back(buf);
    for (auto buf : node.output_buffers) args.push_back(buf);
    kernel_func(args.data(), args.size(), stream);
  }

  std::vector<ExecutableNode> nodes_;
  std::vector<ExecutableNode> execution_plan_;
};
```
6. Performance Analysis and Debugging
6.1 Profiling Tools
```cpp
#include <algorithm>
#include <chrono>
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>

class GraphProfiler {
 public:
  void ProfileGraph(GraphPtr graph) {
    // 1. Collect raw performance data
    auto raw_data = CollectPerformanceData(graph);
    // 2. Identify the bottlenecks
    auto bottlenecks = IdentifyBottlenecks(raw_data);
    // 3. Generate optimization suggestions
    auto suggestions = GenerateOptimizationSuggestions(bottlenecks);
    // 4. Emit the report
    PrintProfileReport(raw_data, bottlenecks, suggestions);
  }

 private:
  PerformanceData CollectPerformanceData(GraphPtr graph) {
    PerformanceData data;
    // Execute the graph and time every node
    auto nodes = graph->GetNodes();
    for (auto node : nodes) {
      auto node_stats = ProfileNode(node);
      data.node_stats[node->GetId()] = node_stats;
    }
    // Aggregate whole-graph metrics
    data.total_time = CalculateTotalTime(data);
    data.memory_usage = CalculateMemoryUsage(graph);
    return data;
  }

  NodePerformanceStats ProfileNode(NodePtr node) {
    NodePerformanceStats stats;
    // Average over many iterations
    const int iterations = 100;
    std::vector<float> times;
    for (int i = 0; i < iterations; ++i) {
      auto start = std::chrono::high_resolution_clock::now();
      // Host-side timing: ExecuteNode must block until the device op
      // finishes, otherwise this only measures launch overhead
      ExecuteNode(node);
      auto end = std::chrono::high_resolution_clock::now();
      float time_us = std::chrono::duration<float, std::micro>(end - start).count();
      times.push_back(time_us);
    }
    // Summary statistics
    stats.avg_time = std::accumulate(times.begin(), times.end(), 0.0f) / iterations;
    stats.min_time = *std::min_element(times.begin(), times.end());
    stats.max_time = *std::max_element(times.begin(), times.end());
    float variance = 0.0f;
    for (auto t : times) {
      variance += (t - stats.avg_time) * (t - stats.avg_time);
    }
    stats.std_dev = std::sqrt(variance / iterations);
    return stats;
  }

  std::vector<BottleneckInfo> IdentifyBottlenecks(const PerformanceData& data) {
    std::vector<BottleneckInfo> bottlenecks;
    float total_time = data.total_time;
    for (const auto& [node_id, stats] : data.node_stats) {
      float time_ratio = stats.avg_time / total_time;
      if (time_ratio > 0.1f) {  // more than 10% of total time
        BottleneckInfo info;
        info.node_id = node_id;
        info.time_ratio = time_ratio;
        info.suggestion = GenerateBottleneckSuggestion(stats);
        bottlenecks.push_back(info);
      }
    }
    // Sort by time share, descending
    std::sort(bottlenecks.begin(), bottlenecks.end(),
              [](const BottleneckInfo& a, const BottleneckInfo& b) {
                return a.time_ratio > b.time_ratio;
              });
    return bottlenecks;
  }

  void PrintProfileReport(const PerformanceData& data,
                          const std::vector<BottleneckInfo>& bottlenecks,
                          const std::vector<OptimizationSuggestion>& suggestions) {
    std::cout << "=== Graph Performance Profile Report ===" << std::endl;
    std::cout << "Total execution time: " << data.total_time << " us" << std::endl;
    std::cout << "Memory usage: " << data.memory_usage / (1024 * 1024) << " MB" << std::endl;
    std::cout << "\n=== Top Bottlenecks ===" << std::endl;
    for (size_t i = 0; i < std::min(bottlenecks.size(), size_t(5)); ++i) {
      const auto& b = bottlenecks[i];
      std::cout << i + 1 << ". Node " << b.node_id
                << " (" << b.time_ratio * 100 << "% time)" << std::endl;
      std::cout << "   " << b.suggestion << std::endl;
    }
    std::cout << "\n=== Optimization Suggestions ===" << std::endl;
    for (const auto& s : suggestions) {
      std::cout << "- " << s.description << std::endl;
      std::cout << "  Expected speedup: " << s.expected_speedup << "x" << std::endl;
    }
  }
};
```
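Host clocks around an asynchronous launch mostly capture launch latency. For device-accurate numbers, ACL events can bracket the work on a stream. A sketch under the assumption that launch_op enqueues the node's kernel on the stream, using aclrtRecordEvent and aclrtEventElapsedTime (check every aclError in real code):
```cpp
#include "acl/acl.h"

// Sketch: device-side timing with ACL events. launch_op is a placeholder
// for whatever enqueues the node's kernel on the given stream.
float TimeOnDevice(aclrtStream stream, void (*launch_op)(aclrtStream)) {
  aclrtEvent start = nullptr;
  aclrtEvent stop = nullptr;
  aclrtCreateEvent(&start);
  aclrtCreateEvent(&stop);

  aclrtRecordEvent(start, stream);  // marker before the kernel
  launch_op(stream);                // asynchronous kernel launch
  aclrtRecordEvent(stop, stream);   // marker after the kernel

  aclrtSynchronizeEvent(stop);      // wait until the stop marker is reached
  float ms = 0.0f;
  aclrtEventElapsedTime(&ms, start, stop);  // device time between markers

  aclrtDestroyEvent(start);
  aclrtDestroyEvent(stop);
  return ms * 1000.0f;  // microseconds, matching ProfileNode's unit
}
```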
7. Summary
This article has walked through the architecture and optimization techniques of the CANN GE graph engine, covering:
- Graph construction: building compute graphs from scratch and importing models from ONNX
- Graph optimization: operator fusion, memory optimization, constant folding
- Parallel scheduling: multi-stream parallelism and pipelined execution
- Code generation: kernel selection and executable-graph generation
- Performance analysis: bottleneck identification and optimization suggestions
Applied judiciously, GE's optimization techniques let developers substantially improve model execution efficiency and build production-grade, high-performance inference systems.
Related links:
- CANN organization: https://atomgit.com/cann
- ge repository: https://atomgit.com/cann/ge