目录
- 如何在Triton中添加自定义转换Pass:以Linalg>EmitC为例
- MLIR/Triton中转换模式如何处理操作依赖关系
- MLIR转换过程中的中间态与IR演进
在Triton中添加自定义转换Pass:以Linalg>EmitC为例
1. 了解Triton的编译架构
Triton使用MLIR作为其编译基础设施,编译流程大致为:
- Triton IR (前端) → TritonGPU → SCF/Linalg/MemRef → LLVM → PTX
在添加转换pass前,需要明确你要在哪个阶段插入自定义转换。
2. 创建自定义转换Pass
2.1 创建必要的文件结构
首先在Triton代码库中创建适当的目录结构:
bash
# 在lib/Conversion下创建新目录
mkdir -p lib/Conversion/LinalgToEmitC
touch lib/Conversion/LinalgToEmitC/LinalgToEmitC.cpp
touch include/triton/Conversion/LinalgToEmitC/Passes.td
touch include/triton/Conversion/LinalgToEmitC/Passes.h
2.2 定义转换Pass (Passes.td)
在include/triton/Conversion/LinalgToEmitC/Passes.td中定义pass:
tablegen
#include "mlir/Dialect/Pass/PassBase.td"
def TritonLinalgToEmitC : Pass<"triton-linalg-to-emitc", "Operation"> {
let summary = "Convert Linalg operations to EmitC dialect";
let description = [{
This pass converts supported Linalg operations to EmitC operations
to facilitate C code generation.
}];
let constructor = "mlir::triton::createLinalgToEmitCPass()";
}
2.3 实现转换逻辑 (LinalgToEmitC.cpp)
cpp
#include "triton/Conversion/LinalgToEmitC/Passes.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
using namespace mlir;
namespace {
// 定义转换模式:Linalg.Generic -> EmitC.CallOp
class GenericOpConversion : public OpConversionPattern<linalg::GenericOp> {
public:
using OpConversionPattern<linalg::GenericOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
linalg::GenericOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// 1. 检查操作是否可转换
if (!op.hasTensorSemantics())
return failure();
// 2. 准备转换
Location loc = op.getLoc();
SmallVector<Value> operands = adaptor.getOperands();
// 3. 构建EmitC调用
// 这里需要具体实现如何将linalg.generic转换为emitc.call
// 通常需要解析iterator types, indexing maps等
// 4. 创建EmitC函数调用
auto funcType = FunctionType::get(
getContext(),
TypeRange(operands),
op.getResultTypes()
);
// 5. 创建调用操作
auto emitcCall = rewriter.create<emitc::CallOp>(
loc,
"custom_kernel_function", // C函数名
funcType.getResults(),
operands
);
rewriter.replaceOp(op, emitcCall.getResults());
return success();
}
};
// 主转换pass
struct LinalgToEmitCPass : public TritonLinalgToEmitCBase<LinalgToEmitCPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
RewritePatternSet patterns(context);
// 1. 配置转换目标
target.addLegalDialect<emitc::EmitCDialect, func::FuncDialect>();
target.addIllegalDialect<linalg::LinalgDialect>();
// 2. 添加转换模式
patterns.add<GenericOpConversion>(context);
// 3. 应用转换
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
// 创建pass实例
std::unique_ptr<Pass> mlir::triton::createLinalgToEmitCPass() {
return std::make_unique<LinalgToEmitCPass>();
}
2.4 暴露Pass API (Passes.h)
cpp
#ifndef TRITON_CONVERSION_LINALGTOEMITC_PASSES_H
#define TRITON_CONVERSION_LINALGTOEMITC_PASSES_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace triton {
std::unique_ptr<Pass> createLinalgToEmitCPass();
} // namespace triton
} // namespace mlir
#endif // TRITON_CONVERSION_LINALGTOEMITC_PASSES_H
3. 集成到构建系统
修改lib/Conversion/CMakeLists.txt,添加新的转换目录:
cmake
add_subdirectory(LinalgToEmitC)
在lib/Conversion/LinalgToEmitC/CMakeLists.txt中:
cmake
add_mlir_library(TritonLinalgToEmitC
LinalgToEmitC.cpp
DEPENDS
TritonConversionPassIncGen
LINK_LIBS PUBLIC
MLIREmitC
MLIRLinalg
MLIRPass
TritonIR
)
4. 在Triton编译流程中注册Pass
修改tools/triton-opt/triton-opt.cpp,添加你的pass:
cpp
#include "triton/Conversion/LinalgToEmitC/Passes.h"
void registerLinalgToEmitC() {
mlir::registerPass([]() -> std::unique_ptr<mlir::Pass> {
return mlir::triton::createLinalgToEmitCPass();
});
}
// 在main函数中调用注册
int main(int argc, char **argv) {
// ...
registerLinalgToEmitC();
// ...
}
5. 实现更复杂的转换模式
上面的示例只处理了linalg.generic,完整的转换需要处理更多操作,如:
linalg.matmullinalg.filllinalg.dotlinalg.conv
每种操作需要特定的转换逻辑,例如linalg.matmul可能转换为:
cpp
class MatmulOpConversion : public OpConversionPattern<linalg::MatmulOp> {
LogicalResult matchAndRewrite(
linalg::MatmulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// 获取输入和输出
Value lhs = adaptor.getInputs()[0];
Value rhs = adaptor.getInputs()[1];
Value out = adaptor.getOutputs()[0];
// 创建EmitC调用,调用自定义的matmul函数
auto callOp = rewriter.create<emitc::CallOp>(
op.getLoc(), "custom_matmul", out.getType(),
ValueRange{lhs, rhs, out});
rewriter.replaceOp(op, callOp.getResults());
return success();
}
};
6. 处理类型转换
在复杂的转换中,可能需要自定义类型转换:
cpp
struct LinalgToEmitCTypeConverter : public TypeConverter {
LinalgToEmitCTypeConverter() {
// 注册类型转换规则
addConversion([](TensorType type) {
// 转换tensor类型到适合emitc的类型
return type;
});
}
};
7. 测试你的转换
创建测试文件test/Conversion/LinalgToEmitC/basic.mlir:
mlir
// RUN: triton-opt %s -triton-linalg-to-emitc | FileCheck %s
func.func @test_matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>)
outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
// CHECK-LABEL: func.func @test_matmul
// CHECK: emitc.call "custom_matmul"
8. 高级技巧
- 处理内存布局:Linalg操作通常有特定的内存布局要求,需要在emitc调用中正确处理
- 性能优化:考虑添加tile、fuse等优化,生成高效的C代码
- 自定义代码生成 :结合EmitC的
emitc.constant和emitc.include等操作,生成完整的C函数
9. 调试技巧
- 使用
-debug和-debug-only=linalg-to-emitc选项查看pass执行细节 - 使用MLIR的打印方法
op->print(llvm::errs())在转换过程中打印IR - 使用断点和GDB进行交互式调试
MLIR/Triton中转换模式如何处理操作依赖关系
1. MLIR转换框架自动处理的依赖
1.1 SSA依赖自动维护
MLIR的转换框架会自动处理SSA(Static Single Assignment)形式的依赖关系:
cpp
// 当一个操作被替换时,所有使用其结果的操作会自动更新
// 例如: 将 %0 = linalg.matmul ... 转换为 %0 = emitc.call ...
// 任何使用%0的操作会自动指向新的值
rewriter.replaceOp(op, newResults);
1.2 转换过程中的拓扑排序
applyPartialConversion/applyFullConversion会自动按拓扑顺序处理操作:
cpp
// MLIR会确保按照依赖顺序应用转换模式
// 先转换不依赖于其他要转换操作的操作
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
2. 显式依赖分析技术
2.1 使用Operation的操作数和结果
最基础的方法是直接查询操作的使用链:
cpp
class MatmulOpConversion : public OpConversionPattern<linalg::MatmulOp> {
LogicalResult matchAndRewrite(
linalg::MatmulOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// 获取输入操作 (依赖的操作)
Value lhs = adaptor.getInputs()[0];
Value rhs = adaptor.getInputs()[1];
// 获取依赖于当前操作的操作
auto uses = op->getUses(); // 获取所有使用此操作结果的操作
// 通常不需要手动处理这些依赖,因为rewriter会自动处理
// 但在复杂转换中可能需要分析
for (auto &use : uses) {
Operation *user = use.getOwner();
// 分析依赖操作
}
// 创建新操作
auto emitcCall = rewriter.create<emitc::CallOp>(...);
rewriter.replaceOp(op, emitcCall);
return success();
}
};
2.2 依赖切片(Slices)
对于更复杂的依赖分析,可以使用切片API:
cpp
#include "mlir/Analysis/SliceAnalysis.h"
void analyzeDependencies(linalg::GenericOp op) {
// 获取前向切片 (依赖于此操作的操作)
SetVector<Operation *> forwardSlice;
getForwardSlice(op, &forwardSlice);
// 获取后向切片 (此操作依赖的操作)
SetVector<Operation *> backwardSlice;
getBackwardSlice(op, &backwardSlice);
// 处理依赖
for (Operation *depOp : backwardSlice) {
// 分析依赖操作
}
}
2.3 访问Linalg特定依赖
Linalg操作有特殊的依赖信息:
cpp
void analyzeLinalgDependencies(linalg::GenericOp genericOp) {
// 获取索引映射,这些定义了数据访问模式
auto indexingMaps = genericOp.getIndexingMaps();
// 分析操作中的计算块
Block &body = genericOp.getRegion().front();
// 分析块参数的依赖关系
for (auto arg : body.getArguments()) {
// 这些参数对应于输入/输出张量的元素
// 可以分析它们在计算块中的使用
}
// 获取迭代类型 (parallel, reduction等)
auto iteratorTypes = genericOp.getIteratorTypes();
}
3. 处理Linalg到EmitC转换中的依赖
3.1 保持数据流依赖
在Linalg到EmitC转换中,需要保持原始操作的数据流依赖:
cpp
class GenericOpConversion : public OpConversionPattern<linalg::GenericOp> {
LogicalResult matchAndRewrite(
linalg::GenericOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// 1. 收集所有输入/输出张量 (保持依赖)
SmallVector<Value> inputs = adaptor.getInputs();
SmallVector<Value> outputs = adaptor.getOutputs();
// 2. 为EmitC创建参数列表,保持顺序
SmallVector<Value> emitcArgs;
emitcArgs.append(inputs.begin(), inputs.end());
emitcArgs.append(outputs.begin(), outputs.end());
// 3. 创建EmitC调用,保持依赖关系
StringRef functionName = determineFunctionName(op); // 基于op的语义
auto emitcCall = rewriter.create<emitc::CallOp>(
op.getLoc(), functionName, op.getResultTypes(), emitcArgs);
// 4. 替换操作 - MLIR自动更新所有依赖
rewriter.replaceOp(op, emitcCall.getResults());
return success();
}
// 基于Linalg操作的语义确定函数名
StringRef determineFunctionName(linalg::GenericOp op) const {
// 分析计算块和迭代类型
// 例如:如果检测到是矩阵乘法,返回"matmul"
// 如果是元素级操作,返回"elementwise_add"等
if (isMatmulLike(op)) return "matmul";
if (isElementwise(op)) return "elementwise_op";
return "custom_kernel";
}
};
3.2 处理循环嵌套依赖
Linalg操作通常包含隐式循环,需要保持循环依赖:
cpp
void convertLoopDependencies(linalg::GenericOp op, PatternRewriter &rewriter) {
// 获取迭代空间维度
auto iteratorTypes = op.getIteratorTypes();
size_t numLoops = iteratorTypes.size();
// 收集循环边界
SmallVector<int64_t> loopBounds;
for (auto operand : op.getOperands()) {
if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
// 从张量形状推断循环边界
for (int64_t dim : tensorType.getShape()) {
loopBounds.push_back(dim);
}
}
}
// 在EmitC中需要生成相应循环结构
// 可能需要创建多个emitc.call来表示循环
}
4. 高级依赖处理技术
4.1 分阶段转换保持依赖
复杂转换通常需要分阶段进行,每阶段保持特定依赖:
cpp
// 阶段1: 先转换数据准备操作
void firstPhaseConversion(ModuleOp module) {
RewritePatternSet phase1Patterns(context);
phase1Patterns.add<BufferAllocationConversion>(context);
(void)applyPatternsAndFoldGreedily(module, std::move(phase1Patterns));
}
// 阶段2: 再转换计算操作,此时依赖已经准备好
void secondPhaseConversion(ModuleOp module) {
RewritePatternSet phase2Patterns(context);
phase2Patterns.add<LinalgToEmitCConversion>(context);
(void)applyPatternsAndFoldGreedily(module, std::move(phase2Patterns));
}
4.2 自定义转换策略
对于特别复杂的依赖,可能需要自定义转换策略:
cpp
struct LinalgToEmitCStrategy : public ConversionPattern {
// 重写匹配和重写方法,手动控制转换顺序
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// 1. 首先分析依赖图
DependencyAnalysis analysis(op);
// 2. 按依赖顺序排序需要转换的操作
auto orderedOps = analysis.getTopologicalOrder();
// 3. 按顺序转换操作
for (Operation *depOp : orderedOps) {
if (depOp->hasTrait<OpTrait::IsTerminator>()) continue;
// 应用特定转换
if (succeeded(convertSingleOp(depOp, rewriter))) {
// 跟踪已转换的操作
}
}
return success();
}
};
5. 调试依赖关系
5.1 打印依赖图
在开发转换时,可以打印依赖关系帮助调试:
cpp
void debugDependencies(Operation *op) {
llvm::dbgs() << "Dependency analysis for: " << *op << "\n";
// 打印使用当前操作的操作
for (Operation *user : op->getUsers()) {
llvm::dbgs() << " Used by: " << *user << "\n";
}
// 打印当前操作使用的值
for (Value operand : op->getOperands()) {
if (Operation *defOp = operand.getDefiningOp()) {
llvm::dbgs() << " Depends on: " << *defOp << "\n";
}
}
}
5.2 可视化依赖
MLIR提供工具生成依赖图的DOT表示:
bash
# 生成依赖图
triton-opt input.mlir -mlir-print-op-graph | dot -Tpng -o graph.png
6. Linalg到EmitC转换中的实际依赖处理
在Linalg到EmitC的实际转换中,保持依赖的关键点:
- 输入/输出张量的顺序:保持原始Linalg操作中输入和输出的顺序,以维护数据依赖
- 迭代空间结构:EmitC调用需要反映Linalg操作的迭代类型(parallel/reduction)
- 内存访问模式:Linalg的索引映射定义了内存访问模式,这必须在生成的C代码中保持
- 副作用处理:Linalg操作可能有特定副作用语义,需要在EmitC中正确反映
一个更完整的Linalg Generic转换示例:
cpp
class GenericOpConversion : public OpConversionPattern<linalg::GenericOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
linalg::GenericOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
// 1. 分析依赖和访问模式
auto iteratorTypes = llvm::to_vector<4>(op.getIteratorTypesArray());
auto indexingMaps = op.getIndexingMaps();
// 2. 确定需要生成的C函数签名
// 根据输入/输出数量和类型
SmallVector<Type> argTypes;
for (Value input : adaptor.getInputs()) {
argTypes.push_back(input.getType());
}
for (Value output : adaptor.getOutputs()) {
argTypes.push_back(output.getType());
}
// 3. 生成函数名,编码依赖信息
std::string funcName = "linalg_generic_";
funcName += std::to_string(iteratorTypes.size()) + "d_";
funcName += std::to_string(op.getNumInputs()) + "in_";
funcName += std::to_string(op.getNumOutputs()) + "out";
// 4. 创建EmitC调用,保持依赖
SmallVector<Value> args = adaptor.getOperands();
auto resultTypes = llvm::to_vector<4>(op.getResultTypes());
// 5. 生成额外的属性,如循环边界和访问模式
DictionaryAttr callAttrs = rewriter.getDictionaryAttr({
rewriter.getNamedAttr("iterator_types",
rewriter.getStrArrayAttr(iteratorTypes)),
rewriter.getNamedAttr("indexing_maps",
rewriter.getArrayAttr(indexingMaps))
});
auto emitcCall = rewriter.create<emitc::CallOp>(
loc, funcName, resultTypes, args, callAttrs, /*templateArgs=*/nullptr);
rewriter.replaceOp(op, emitcCall.getResults());
return success();
}
};
在实际应用中,Linalg到EmitC的转换通常不是一对一的操作替换,而是需要分析整个计算图的依赖关系,可能将多个Linalg操作融合成一个EmitC调用,以生成更高效的C代码。MLIR的转换框架提供了必要的基础设施来处理这些复杂的依赖关系。
MLIR转换过程中的中间态与IR演进
在MLIR/Triton的转换过程中完全允许且通常必然存在两种或多种dialect共存的中间状态。这是MLIR转换框架的核心设计原则之一,使复杂的转换可以通过一系列小步骤完成。
1. MLIR转换框架中的中间态
MLIR转换采用渐进式方法,一个操作一个操作地进行转换,中间必然会产生混合dialect的IR。这种设计有以下特点:
- 增量转换:每次转换一个操作,IR处于混合状态
- 转换目标(ConversionTarget):定义哪些dialect/operation是合法的
- 部分转换(Partial Conversion):允许中间状态存在,只要最终达到目标
- 依赖处理:MLIR自动处理SSA依赖,确保引用一致性
2. Linalg到EmitC转换中的IR演进示例
假设我们有以下初始IR(包含两个Linalg操作):
mlir
func.func @example(%in1: tensor<4x4xf32>, %in2: tensor<4x4xf32>, %out: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul {cast = #linalg.type_fn<cast_signed>}
ins(%in1, %in2 : tensor<4x4xf32>, tensor<4x4xf32>)
outs(%out : tensor<4x4xf32>) -> tensor<4x4xf32>
%1 = linalg.fill ins(%cst = arith.constant 0.0 : f32 : f32)
outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
第一步:转换第一个linalg.matmul操作
假设模式匹配首先转换了linalg.matmul操作:
cpp
rewriter.replaceOp(matmulOp, emitcCallOp.getResults());
转换后的IR状态:
mlir
func.func @example(%in1: tensor<4x4xf32>, %in2: tensor<4x4xf32>, %out: tensor<4x4xf32>) -> tensor<4x4xf32> {
// Linalg.matmul 已转换为 EmitC.call
%0 = emitc.call @custom_matmul(%in1, %in2, %out) :
(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
// linalg.fill 仍然存在,但已更新为使用新的%0
%1 = linalg.fill ins(%cst = arith.constant 0.0 : f32 : f32)
outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
关键变化:
linalg.matmul被替换为emitc.call- 所有使用原
%0的操作(此处是linalg.fill)自动更新为使用新操作的结果 - IR现在是混合状态:同时包含EmitC和Linalg dialect
第二步:转换linalg.fill操作
接着,转换框架处理linalg.fill:
cpp
rewriter.replaceOp(fillOp, emitcFillCall.getResults());
最终转换后的IR:
mlir
func.func @example(%in1: tensor<4x4xf32>, %in2: tensor<4x4xf32>, %out: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = emitc.call @custom_matmul(%in1, %in2, %out) :
(tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
// linalg.fill 也已转换为 EmitC
%1 = emitc.call @memset_zero(%0) :
(tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<4x4xf32>
}
3. 转换框架如何管理中间状态
3.1 ConversionTarget配置
cpp
ConversionTarget target(*context);
// 允许EmitC和必要的辅助dialect
target.addLegalDialect<emitc::EmitCDialect, func::FuncDialect, arith::ArithDialect>();
// 标记Linalg为非法,需要转换
target.addIllegalDialect<linalg::LinalgDialect>();
// 允许特定操作保持不变
target.addLegalOp<arith::ConstantOp>();
3.2 依赖处理机制
在中间状态中,MLIR通过以下方式维护依赖关系:
cpp
// 当replaceOp被调用时:
rewriter.replaceOp(oldOp, newValues);
// 内部自动执行:
// 1. 创建新操作,使用相同结果类型
// 2. 更新所有使用oldOp结果的操作,指向新值
// 3. 删除oldOp
3.3 中间状态的合法性检查
在部分转换中,框架会检查中间状态是否合法:
cpp
// 在每一步转换后,框架检查:
if (!target.isLegal(updatedOp)) {
// 如果操作仍不合法,继续应用模式
}
4. 调试中间状态
在开发转换pass时,可以观察中间状态:
cpp
void MyConversionPass::runOnOperation() {
// ...
if (failed(applyPartialConversion(...))) {
// 打印中间状态
getOperation()->dump();
signalPassFailure();
}
}
使用命令行选项查看详细转换过程:
bash
triton-opt input.mlir -triton-linalg-to-emitc -debug-only=dialect-conversion
输出会显示每一步转换前后的IR状态。
5. 处理复杂依赖的中间状态
当操作之间存在复杂依赖时,中间状态管理更为关键。例如,考虑依赖链:A→B→C,其中A和C需要转换,B不需要:
mlir
// 初始状态
%a = linalg.matmul ... // 需要转换
%b = tensor.extract_slice %a ... // 不需要转换
%c = linalg.fill %b ... // 需要转换
转换过程:
-
转换A:
%a = emitc.call @matmul...- B自动更新,继续使用新的
%a - C仍然依赖B
- B自动更新,继续使用新的
-
转换C:
%c = emitc.call @fill...- 没有其他操作依赖C,所以没有额外更新
中间状态:
mlir
%a = emitc.call @matmul... // 已转换
%b = tensor.extract_slice %a ... // 未转换,但正确引用%a
%c = emitc.call @fill %b ... // 已转换
6. 高级转换策略:分阶段转换
对于特别复杂的转换,可以采用分阶段策略,每个阶段产生不同的中间状态:
cpp
void runOnOperation() override {
MLIRContext *context = &getContext();
// 阶段1: 只转换特定Linalg操作
{
ConversionTarget target(*context);
target.addLegalDialect<emitc::EmitCDialect, linalg::LinalgDialect>();
target.addIllegalOp<linalg::MatmulOp>(); // 只转换matmul
RewritePatternSet patterns(context);
patterns.add<MatmulToEmitC>(context);
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
// 此时IR包含EmitC和剩余Linalg操作
}
// 阶段2: 转换其他Linalg操作
{
ConversionTarget target(*context);
target.addLegalDialect<emitc::EmitCDialect>();
target.addIllegalDialect<linalg::LinalgDialect>();
RewritePatternSet patterns(context);
patterns.add<GenericLinalgToEmitC>(context);
if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
}
7. 实际注意事项
- 转换顺序重要性:MLIR不保证操作的处理顺序,应设计转换使其顺序无关
- 资源管理:在中间状态中,确保资源(如内存)正确管理
- 副作用处理:确保转换保持操作的副作用语义
- 验证中间状态:使用MLIR的验证机制确保中间IR的有效性
cpp
// 验证中间IR
getOperation()->walk([](Operation *op) {
if (failed(op->verify())) {
op->emitError("Invalid operation in intermediate state");
}
});
总结
在MLIR/Triton转换过程中,混合dialect的中间状态不仅是允许的,而且是转换框架的核心设计。当第一个操作转换完成后,IR会变成部分转换的状态,MLIR会自动处理依赖关系,确保所有引用正确更新。这种渐进式转换方法使复杂的dialect转换可以分解为可管理的小步骤,同时保持IR的一致性和正确性。