TPU 架构与 Pallas Kernel 编程入门:从内存层次结构到 FlashAttention做过 GPU kernel 优化的人对以下编程模型肯定不会陌生:写一个 CUDA kernel分发到流式多处理器(SM)上执行,缓存层次结构自行负责数据搬运。而TPU 则完全不同,除非明确告诉编译器要把哪些数据块搬到哪里,否则kernel 根本无法编译。实际操作确实和听起来一样繁琐,所以JAX 的Pallas 就是解决的这个问题:以 tile 为单位描述计算,无需手动指定输入张量各部分的搬运路径,编译器自动生成所需的数据移动操作。