Mirage-LLM编译成大Kernel

Ref

  1. Mirage github
  2. Mirage 博客
  3. 关于uGraph
  4. deepwiki

源码

mirage/src/kernel/customized.cc

实现自定义内核算子(KNCustomizedOp)的核心文件,主要负责将内核级图和线程块级图连接起来

Graph::customized() - 自己定义算子

cpp 复制代码
std::vector<DTensor> Graph::customized(std::vector<DTensor> const &inputs,
                                       threadblock::Graph const &bgraph) {
  KNOperator *op = create_customized_op(inputs, bgraph);
  assert(op != nullptr);
  operators.push_back(op);
  return op->output_tensors;
}

int Graph::customized(std::vector<DTensor const *> _inputs,
                      DTensor **outputs,
                      mirage::threadblock::Graph const *bgraph) {
  std::vector<DTensor> inputs;
  for (auto const &t : _inputs) {
    inputs.push_back(t == nullptr ? DTensor::EMPTY_TENSOR : *t);
  }
  KNOperator *op = create_customized_op(inputs, *bgraph);
  assert(op != nullptr);
  operators.push_back(op);
  for (size_t i = 0; i < op->output_tensors.size(); i++) {
    outputs[i] = &op->output_tensors[i];
  }
  return op->output_tensors.size();
}

mirage/src/transpiler/transpile.cc

mirage/src/kernel/chunk.cc

chunk算子用于对于给定张量在指定维度上进行切分,注意这里貌似只涉及了描述chunk的行为,具体的可能是通过Transpiler代码生成器进行翻译成CUDA代码

算子创建

cpp 复制代码
KNOperator *
    Graph::create_chunk_op(DTensor const &input, int chunk_size, int dim) {
  if (dim < 0 || dim >= input.num_dims || chunk_size <= 0) {
    return nullptr;
  }
  if (input.dim[dim] % chunk_size != 0) {
    return nullptr;
  }
  if (!this->can_allocate(input)) {
    return nullptr;
  }

  KNChunkOp *op = new KNChunkOp(this, input, chunk_size, dim);
  return op;
}

运行chunk算子

cpp 复制代码
std::vector<DTensor>
    Graph::chunk(DTensor const &input, int chunk_size, int dim) {
  KNOperator *op = create_chunk_op(input, chunk_size, dim);
  assert(op != nullptr);
  operators.push_back(op);
  assert(op->output_tensors.size() > 0);
  return op->output_tensors;
}

创建chunk算子并加入算子库operators之后,并调用该算子运行得到结果

相关推荐
西岸行者5 天前
学习笔记:SKILLS 能帮助更好的vibe coding
笔记·学习
悠哉悠哉愿意5 天前
【单片机学习笔记】串口、超声波、NE555的同时使用
笔记·单片机·学习
别催小唐敲代码5 天前
嵌入式学习路线
学习
毛小茛5 天前
计算机系统概论——校验码
学习
babe小鑫5 天前
大专经济信息管理专业学习数据分析的必要性
学习·数据挖掘·数据分析
winfreedoms5 天前
ROS2知识大白话
笔记·学习·ros2
在这habit之下5 天前
Linux Virtual Server(LVS)学习总结
linux·学习·lvs
我想我不够好。5 天前
2026.2.25监控学习
学习
im_AMBER5 天前
Leetcode 127 删除有序数组中的重复项 | 删除有序数组中的重复项 II
数据结构·学习·算法·leetcode
CodeJourney_J5 天前
从“Hello World“ 开始 C++
c语言·c++·学习