§ 00 · The fly that draws the blueprintPrologue
§ 00 · 画蓝图的那只蝇序言
There is a particular vertigo that comes from reading a modern GPU kernel and realizing how much of it is layout, not arithmetic. A 4096-cube FP16 GEMM is, viewed line-by-line, mostly bookkeeping — which thread loads which 8 elements, where in shared memory they land, in what order the MFMA instructions consume them, when the next K-block's prefetch overlaps with the current MFMA tail. The actual multiplication is two lines. Everything else is layout.
读一个现代 GPU kernel 会有一种特殊的眩晕感: 你会意识到这段代码大部分不是算术, 是 layout。 一个 4096 立方的 FP16 GEMM, 逐行看下来, 绝大部分是簿记 —— 哪个 thread 加载哪 8 个元素、 共享内存里它们落在哪、 MFMA 指令以什么顺序消费它们、 下一个 K 块的 prefetch 在当前 MFMA 的尾巴上以多大比例重叠。 真正的乘法是两行。 其余全是 layout。
FlyDSL — the Fly dialect plus its Python front-end — is one answer to that observation. It is the Python DSL and MLIR-native compiler stack that AMD's ROCm group uses to author GPU kernels with explicit layouts and tiling. The name unpacks to Flexible Layout pYthon DSL. The design is unmistakably descended from NVIDIA CuTe: layout = (shape, stride), composition / divide / product, thread-value layouts, tiled copies, MMA atoms. What FlyDSL adds is the MLIR machinery — a real !fly.layout type, a pass pipeline that lowers layout algebra to address arithmetic, copy atoms to ROCDL buffer instructions, and MMA atoms to MFMA — plus the Python tracing layer that lets you write all of this without leaving the editor.
FlyDSL —— Fly dialect 加上它的 Python 前端 —— 是对这个观察的一种回应。 它是 AMD ROCm 组用来写 GPU kernel 的 Python DSL + MLIR-native 编译器栈, kernel 中 layout 和 tiling 都是显式的。 名字展开就是 Flexible Layout pYthon DSL。 设计明显继承自 NVIDIA CuTe: layout = (shape, stride), composition / divide / product, thread-value layout, tiled copy, MMA atom。 FlyDSL 多加的是 MLIR 那一层 —— 一个真实的 !fly.layout 类型、 一整套把 layout 代数 lower 到地址算术、 把 copy atom lower 到 ROCDL buffer 指令、 把 MMA atom lower 到 MFMA 的 pass 流水线 —— 再加上 Python tracing 层, 让你不用离开编辑器就能写完这些。
This entry reads FlyDSL through its examples/ directory — four files that form a strict pedagogical ladder. Example 01 is a vector add. Example 02 introduces tiled copy. Example 03 wires up a single-block MFMA GEMM. Example 04 is a production-grade preshuffle GEMM with double-buffered LDS, two-stage software pipelining, and hand-written instruction scheduling. The four together demonstrate every machinery that real FlyDSL kernels use; everything in kernels/ — paged attention, MoE GEMM, flash attention — is a recombination of the same pieces.
这一篇通过 examples/ 目录来读 FlyDSL —— 四个文件构成一条严格的教学阶梯。 例 01 是向量加。 例 02 引入 tiled copy。 例 03 接好一个单 block 的 MFMA GEMM。 例 04 是生产级的 preshuffle GEMM, 带双 buffer 的 LDS、 两阶段软件流水线、 手写指令调度。 四个加起来演示了真实 FlyDSL kernel 用到的全部机械; kernels/ 目录里所有东西 —— paged attention、 MoE GEMM、 flash attention —— 都是同一批零件的重新组合。
If you only have an hour, read § M1 (layout algebra) and § M3 (tiledCopy). Those two cover the FlyDSL mental model. Add § M5 (preshuffle GEMM) for the production patterns. Skip this prologue if you already know why you opened the file.
如果你只有一小时, 读 § M1 (layout 代数) 和 § M3 (tiledCopy) 这两节, 它们覆盖了 FlyDSL 的心智模型。 再加 § M5 (preshuffle GEMM) 就能看到生产 pattern。 如果你已经知道为什么要打开这个文件, 跳过这段序言就行。
What this writeup is not: it is not a build guide (the README at FlyDSL-lab/README.md handles that), and it is not a reference (docs/layout_system_guide.md has the full Quick Reference table). It is a reader's report on what each example teaches, where the abstractions earn their keep, and where the seams show.
这篇文章不是什么: 它不是编译指南 (这事 FlyDSL-lab/README.md 已经管了), 也不是 API 手册 (docs/layout_system_guide.md 有完整的 Quick Reference 表)。 它是一份读者报告 —— 每个例子教了什么, 抽象在哪里赚回了它的复杂度, 接缝在哪里能看出来。
M0 · The four-example ladderCompass of four examples
M0 · 四例阶梯四个例子的罗盘
The repo contains thousands of lines of production kernels under kernels/ and an exhaustive test suite under tests/, but its pedagogical core is just four files under examples/. Each one adds exactly one new concept to the previous, and reading them in order is the fastest way to internalize FlyDSL.
repo 里 kernels/ 下面有数千行 production kernel, tests/ 下面有详尽的测试套件, 但它的教学核心只有 examples/ 下面 4 个文件。 每一个都比上一个多引入恰好一个新概念, 按顺序读它们是吃透 FlyDSL 最快的路径。
| File | LOC | New concept | What it proves | ||
|---|---|---|---|---|---|
| 文件 | 行数 | 新概念 | 展示什么 | ||
| 01-vectorAdd.py | 137 | @flyc.kernel + logical_divide + copy_atom_call |
The minimum viable FlyDSL kernel: trace, compile, launch, capture in a CUDA Graph. | 最小可用的 FlyDSL kernel: trace、 编译、 launch、 被 CUDA Graph 捕获。 | |
| 02-tiledCopy.py | 69 | make_layout_tv + partition_S/D + BufferCopy128b |
Thread/value layouts that fan a tile across a wavefront, vectorized. | 用 thread / value layout 把一个 tile 向量化地分发给一个 wavefront。 | |
| 03-tiledMma.py | 95 | MFMA atom + make_tiled_copy_A/B/C + retile |
Copy layouts derived from the MMA atom; same register read two ways. | copy layout 从 MMA atom 反推; 同一段寄存器两种视角。 | |
| 04-preshuffle_gemm.py | 211 | preshuffle B + LDS swizzle + 2-stage pipeline + sched_* |
preshuffle B + LDS swizzle + 两阶段流水 + sched_* |
A real GEMM: 4096³ FP16 with hand-tuned instruction scheduling. | 一个真实的 GEMM: 4096³ FP16, 带手调指令调度。 |
The ladder is steep on the last step. Example 04 is roughly the union of every optimization that separates a 30% kernel from a 95% one on CDNA3 / CDNA4: preshuffled weight layout so MFMA needs no LDS transpose, XOR swizzle so LDS reads hit 32 banks evenly, two-stage ping-pong so the prefetch of K-block n+1 overlaps with the MFMA of K-block n, and a scheduler block that tells the LLVM AMDGPU backend exactly how to interleave v_mfma, ds_read, buffer_load, and ds_write instructions inside the hot loop. None of this is in 01–03. Reading them in order is the only way to keep the cognitive load manageable.
阶梯最后一级很陡。 例 04 大致是 CDNA3 / CDNA4 上把一个 30% kernel 拉到 95% 所有优化的并集: preshuffle 过的权重 layout 让 MFMA 不需要 LDS 转置、 XOR swizzle 让 LDS 读均匀打到 32 个 bank、 两阶段 ping-pong 让 K 块 n+1 的 prefetch 与 K 块 n 的 MFMA 重叠、 一段 scheduler 告诉 LLVM AMDGPU backend 在 hot loop 里如何精确交错 v_mfma、 ds_read、 buffer_load、 ds_write。 这些 01–03 都没有。 按顺序读才能保持认知负担在可控范围。
fly dialect with stock gpu/arith/scf/memref. A nine-stage pass pipeline progressively lowers layout algebra to address arithmetic and copy/MMA atoms to ROCDL intrinsics, then the standard MLIR ROCDL path produces a HSA fatbin. Everything past stage 9 is cached on disk and keyed on the type signature plus the Constexpr values plus a source hash — change a Constexpr[int] value and you get a new cache key.fly dialect 和官方的 gpu / arith / scf / memref 混在一起。 一个九阶段的 pass 流水线把 layout 代数渐进 lower 到地址算术、 把 copy / MMA atom 渐进 lower 到 ROCDL intrinsic, 之后标准的 MLIR ROCDL 路径产出 HSA fatbin。 第 9 阶段以后的东西全部落盘缓存, 缓存 key 由类型签名 + Constexpr 值 + 源码 hash 组成 —— 改一个 Constexpr[int] 的值就会生成新的 cache key。M1 · The grammar before the kernelsLayout algebra in ten minutes
M1 · 读 kernel 之前的语法十分钟 layout 代数
Every line of every FlyDSL kernel rests on the same five-word vocabulary: shape, stride, layout, divide, slice. The function names get longer (zipped_divide, logical_product, raked_product, composition) but they all decorate the same primitive idea — a layout is a function from a coordinate tuple to a linear index, and the algebra on layouts is the algebra on those functions. Internalize the primitive and the rest is a vocabulary exercise.
每一行 FlyDSL kernel 都建立在同样的五个词上: shape、 stride、 layout、 divide、 slice。 函数名会越长 (zipped_divide、 logical_product、 raked_product、 composition) 但它们都装饰同一个原始概念 —— layout 是一个从坐标 tuple 到线性 index 的函数, layout 上的代数就是这些函数上的代数。 把这个原始概念吃透, 剩下的就是词汇练习。
The atom · Shape, Stride, Layout
原子 · Shape, Stride, Layout
A !fly.layout is a pair of integer tuples: a shape and a stride. Given coordinates (c0, c1, ..., cn-1), the layout computes a linear index as
一个 !fly.layout 是两个整数 tuple 的对: shape 和 stride。 给定坐标 (c0, c1, ..., cn-1), layout 这样计算线性 index:
Strides are independent of shape, which is the whole point. A logical 8×16 matrix can be column-major (1, 8), row-major (16, 1), padded (1, 9), swizzled, transposed, or have its rows broadcast — all of these are different strides on the same shape. FlyDSL kernels never compute addresses by hand; they construct layouts, divide and compose them, then call copy or gemm on a fragment that holds the resulting (memory, layout) pair.
stride 和 shape 是独立的, 这正是关键。 一个逻辑上 8×16 的矩阵可以是列主 (1, 8)、 行主 (16, 1)、 padded 的 (1, 9)、 swizzled、 转置、 或行广播 —— 这些都是同一个 shape 上的不同 stride。 FlyDSL kernel 从来不手算地址; 它们构造 layout、 切分组合 layout, 然后对一个持有 (memory, layout) 对的 fragment 调用 copy 或 gemm。
Three things to remember about shapes:
关于 shape 有三点要记住:
- Shapes can be nested.
fx.make_shape(9, (4, 8))is a rank-2 shape whose second mode is itself a rank-2 tuple. Nested shapes are how multi-level tiling (block → warp → thread → instruction) gets expressed without an explosion of indices. - Static and dynamic values mix freely. A shape entry can be a Python integer (folded at compile time) or an SSA value (a runtime
fx.Int32). The MLIR type system tracks which is which. - "Layout" is also the name of the primitive type.
!fly.layout<(8,16):(1,8)>is a real MLIR type — operations consume and produce values of this type. This is what separates FlyDSL from a string-templated layout library: the layout algebra is a typed first-class IR, not a build-time string substitution.
- shape 可以嵌套。
fx.make_shape(9, (4, 8))是一个 rank-2 的 shape, 其第二维本身又是 rank-2 的 tuple。 嵌套 shape 是分层 tiling (block → warp → thread → instruction) 不用索引爆炸就能表达的关键。 - 静态值和动态值可以自由混用。 shape 的某一维可以是 Python int (编译期折掉) 或 SSA 值 (运行期
fx.Int32)。 MLIR 类型系统会追踪哪个是哪个。 - "Layout" 也是这个 primitive 类型的名字。
!fly.layout<(8,16):(1,8)>是一个真实的 MLIR 类型 —— 各 op 消费和产出这个类型的值。 这正是把 FlyDSL 和"模板字符串拼接 layout 库"区分开的关键: layout 代数是 typed 一等公民 IR, 不是 build-time 字符串替换。
The two operations you will actually use
你实际会用的两个操作
The Quick Reference lists 14 product / divide variants. In examples/, exactly three appear: logical_divide, zipped_divide, and flat_divide. All three split a layout into a "tile" and a "rest of the grid" — they differ only in how the result is shaped.
Quick Reference 表里列了 14 个 product / divide 变体。 在 examples/ 里实际出现的只有 3 个: logical_divide、 zipped_divide、 flat_divide。 三个都是把一个 layout 切成"tile"和"剩余网格"两部分 —— 只是结果的形状不同。
Once you've divided, you pick a tile with fx.slice — passing None for the modes you want to keep and an index for the modes you want to collapse:
divide 之后用 fx.slice 选 tile —— 想保留的维度传 None, 想折掉的维度传 index:
These four lines are the entire grammar of "give me this block's tile" in FlyDSL. Every kernel in the repo starts with some form of this pattern. The first half-page of any kernel is layout construction; the back half is where actual loads, MFMAs, and stores happen.
这四行就是 FlyDSL 里"把这个 block 的 tile 给我"的全部语法。 repo 里每个 kernel 都从这种 pattern 开始。 任何 kernel 的前半页都是 layout 构造; 后半页才是实际的 load、 MFMA、 store。
Thread × value · the layout on top of the layout
Thread × value · layout 之上的 layout
A tile is a chunk of data. To execute on a GPU, you need a second layout that maps each (thread, value-index) pair to a coordinate inside the tile. This is the thread/value (TV) layout, and it is FlyDSL's central abstraction for fanning data across a wavefront.
一个 tile 是一坨数据。 要在 GPU 上执行, 你需要第二个 layout, 把每对 (thread, value-index) 映射到 tile 内的一个坐标。 这就是 thread / value (TV) layout, 也是 FlyDSL 把数据分发给一个 wavefront 的核心抽象。
The TV layout is the magic. Once you have it, partition_S(tensor) on a thread slice gives you the fragment of tensor that this thread will touch — already shaped (V, VM, VN) with the right indexing math folded in. You never see the address arithmetic. The compiler does. The pass fly-layout-lowering at stage 3 of the pipeline materializes all of this into concrete index expressions.
TV layout 是关键。 有了它之后, 在一个 thread slice 上调 partition_S(tensor), 你拿到的就是这个 thread 要碰的那块 tensor —— 形状已经是 (V, VM, VN), 正确的索引算术已经折进去了。 你看不到地址算术, 编译器在背后做。 流水线第 3 阶段的 fly-layout-lowering pass 把这一切物化成具体的 index 表达式。
Layout is not a description of memory; it is a function. make_layout(shape, stride) defines a map coord ↦ index. composition, divide, and product are function composition on that map. Once you read FlyDSL through this lens, the gap between "what the code says" and "what the kernel does" closes by an order of magnitude.
layout 不是对内存的描述, 是一个函数。 make_layout(shape, stride) 定义了一个映射 coord ↦ index。 composition、 divide、 product 是这个映射上的函数组合。 一旦你用这个视角去读 FlyDSL, "代码说什么"和"kernel 在做什么"之间的距离会瞬间缩小一个数量级。
Everything else in the layout API — composition, complement, coalesce, right_inverse, raked_product, tiled_product — is constructed from these primitives. The categorical foundations paper for CuTe layouts works it out formally; for reading kernels, the four pictures above are enough.
layout API 里所有其他东西 —— composition、 complement、 coalesce、 right_inverse、 raked_product、 tiled_product —— 都是从这些 primitive 构造出来的。 CuTe 的范畴论基础那篇文章把这套形式化推出来了; 读 kernel 的话上面这四张图就够。
M2 · The two-line program01 · vectorAdd
M2 · 两行的程序01 · vectorAdd
The first example computes C = A + B for length-128 FP32 vectors. It is the minimum viable FlyDSL kernel — small enough that you can hold the whole pipeline in your head, large enough to exercise every required machinery. There are two functions:
第一个例子计算长度为 128 的 FP32 向量 C = A + B。 这是最小可用的 FlyDSL kernel —— 小到整条流水线能一次性放在脑子里, 又大到把必需的机械全部用一遍。 一共两个函数:
The structure that all FlyDSL kernels share is visible here in miniature. Three phases: layout construction, copy in, compute + copy out. Phase one builds the per-thread view of the data. Phase two moves it from global memory into register fragments. Phase three runs the arithmetic and writes the result back. Once you can see this skeleton, every kernel in kernels/ becomes a variation on it — a bigger compute phase, more sophisticated copies, an LDS staging buffer between phase two and three, but the same three-phase shape.
所有 FlyDSL kernel 都共用的结构在这里以最小规模呈现。 三个阶段: 构造 layout、 拷入、 计算 + 拷出。 阶段一建立数据的 per-thread 视图; 阶段二把数据从 global 搬到 register fragment; 阶段三跑算术然后写回。 看到这个骨架, kernels/ 里每个 kernel 都变成它的变体 —— 更大的计算阶段、 更复杂的 copy、 阶段二和三之间多一层 LDS staging buffer, 但是同一个三阶段结构。
logical_divide(A, layout(64,1)) produces a shape ((64,), (2,)) where the first mode is "what fits in one block" and the second mode is "how many blocks." slice(.., (None, bid)) then collapses the second mode for the current block. The second division splits the per-block tile into per-thread slots; with a trivial divisor of 1, each thread gets one element. The same cascade appears in every kernel — the difference is the divisor sizes and how many levels you stack.logical_divide(A, layout(64,1)) 产出 ((64,), (2,)), 第一维是"一个 block 装得下多少", 第二维是"一共多少 block"。 然后 slice(.., (None, bid)) 把第二维折成当前 block。 第二次切分把 per-block tile 切成 per-thread 槽位; divisor 取 1, 每个 thread 拿一个元素。 这种级联在每个 kernel 里都出现 —— 差别只是 divisor 多大、 叠几层。The two copy atoms · BufferCopy vs UniversalCopy
两种 copy atom · BufferCopy 与 UniversalCopy
Look at the two atoms FlyDSL constructs for this kernel. They are not the same:
看这个 kernel 里 FlyDSL 构造的两个 atom。 它们不一样:
UniversalCopy32b lowers to a generic flat_load_dword — pointer + offset, no extras. BufferCopy32b lowers to buffer_load_dword, which uses a four-SGPR buffer resource descriptor set up earlier by make_buffer_tensor. The buffer descriptor carries a num_records field, and the AMD GPU hardware checks every access against it: out-of-bounds loads return zero without raising a page fault, and out-of-bounds stores are silently dropped. This is the same OOB-by-SRD trick that gcnasm's vector_add kernel makes a centerpiece of.
UniversalCopy32b 会 lower 成普通的 flat_load_dword —— 指针 + offset, 没多余东西。 BufferCopy32b 会 lower 成 buffer_load_dword, 它用一个由四个 SGPR 组成的 buffer resource descriptor, 这个 descriptor 在前面 make_buffer_tensor 时已经建好。 descriptor 里有 num_records 字段, AMD GPU 硬件每次访问都对它检查: 越界 load 返回 0 而不触发 page fault, 越界 store 静默丢弃。 这正是 gcnasm vector_add kernel 中作为核心的"OOB-by-SRD"技巧。
The example deliberately uses both — A via buffer, B via universal — to make the contrast visible. In production code, you would use BufferCopy for both: it costs the same to issue and gets you branch-free boundary handling for free.
这个例子故意两种都用 —— A 走 buffer、 B 走 universal —— 把对比露出来。 生产代码里两者都会用 BufferCopy: 发射成本一样, 但能免费拿到无分支的边界处理。
The launch wrapper · @flyc.jit
启动包装器 · @flyc.jit
The annotations on vectorAdd tell the compiler what to do with each parameter: fx.Int32 is dynamic (it becomes a kernel argument), fx.Constexpr[int] is static (it gets baked in and contributes to the JIT cache key). Change a Constexpr value and you get a fresh compile; change a dynamic value and the cached binary is reused. This is the same Triton-style autotune model with type annotations doing the work that tl.constexpr does in Triton.
vectorAdd 上的类型注解告诉编译器每个参数怎么处理: fx.Int32 是动态的 (会变成 kernel 实参), fx.Constexpr[int] 是静态的 (会被烤进去, 同时参与 JIT cache key)。 改一个 Constexpr 值会触发重编, 改动态值则复用缓存。 这是 Triton 风格的 autotune 模型, 只是 tl.constexpr 的活由类型注解来干。
The example's second test captures the kernel into a torch.cuda.CUDAGraph and replays it. This works because FlyDSL's launch path goes through fly-gpu-stream-inject, a pass that threads the stream argument into the actual launch instead of consulting a TLS variable. For an inference engine that batches kernels into a captured graph for replay (vLLM, SGLang), this is the difference between FlyDSL kernels being usable and being a special case.
这个例子的第二个测试把 kernel capture 进 torch.cuda.CUDAGraph 然后 replay。 能跑通是因为 FlyDSL 的启动路径过的是 fly-gpu-stream-inject —— 一个把用户传入的 stream 直接编进 launch 的 pass, 不依赖 TLS 变量。 对把 kernel 批量塞进 captured graph 做回放的推理引擎 (vLLM、 SGLang), 这就是"FlyDSL kernel 能用"和"FlyDSL kernel 需要特殊处理"的区别。
Run it on a MI350X:
在 MI350X 上跑:
M3 · Fanning a tile across threads02 · tiledCopy
M3 · 把一个 tile 扇到一群 thread02 · tiledCopy
Example 02 is the smallest kernel in the repo (69 lines) and the most useful for understanding FlyDSL's central abstraction. It copies a 24×120 FP32 matrix from A to B using block tiles of (8, 24) and only four threads per block. The thread count is deliberately small — small enough that you can draw the full thread-value layout on one sheet of paper and verify by hand that it covers the tile correctly.
例 02 是 repo 里最小的 kernel (69 行), 也是理解 FlyDSL 核心抽象最有用的一个。 它用 (8, 24) 的 block tile 把 24×120 的 FP32 矩阵从 A 拷到 B, 每个 block 只有 4 个 thread。 thread 数故意做这么小 —— 小到完整的 thread-value layout 可以画在一张纸上, 用手验它确实覆盖了整个 tile。
The arithmetic of the TV layout
TV layout 的算账
An (8, 24) tile has 192 elements. Four threads issuing 128-bit (= 8 × FP32) loads can move 32 elements per round, so the tile takes six rounds to drain. The TV layout says how those six rounds tile the (8, 24) plane:
一个 (8, 24) tile 有 192 个元素。 4 个 thread 每次发 128-bit (= 8 × FP32) load, 一轮能搬 32 个, 所以这个 tile 需要 6 轮才搬完。 TV layout 说的就是这 6 轮如何在 (8, 24) 平面上铺开:
- thr_layout = ((4,1), (1,1)) — four threads, laid out 4 along M, 1 along N.
- val_layout = ((1,8), (1,1)) — each thread reads a 1×8 block of values.
- one execution covers
(4×1) × (1×8) = (4, 8)of the tile. - the tile is (8, 24), so the execution repeats
(8/4, 24/8) = (2, 3)times — that's the (VM, VN) shape thatpartition_Sreturns.
- thr_layout = ((4,1), (1,1)) —— 4 个 thread, 沿 M 排 4 个、 沿 N 排 1 个。
- val_layout = ((1,8), (1,1)) —— 每个 thread 一次读 1×8 个值。
- 一次执行覆盖 tile 里
(4×1) × (1×8) = (4, 8)。 - tile 是 (8, 24), 所以执行重复
(8/4, 24/8) = (2, 3)次 —— 这就是partition_S返回值里的 (VM, VN) 形状。
BufferCopy128b instructions — one per thread, each fetching its 8 consecutive FP32s. The (VM=2, VN=3) outer indexing means the execution repeats six times to cover the full (8, 24) tile, giving every thread a fragment of shape (V=8, VM=2, VN=3) = 48 values. partition_S hands you that exact fragment shape — no address arithmetic, just an indexable tensor.BufferCopy128b 指令 —— 每个 thread 一条, 各取自己连续的 8 个 FP32。 外层 (VM=2, VN=3) 的索引意味着执行要重复六次才能覆盖整个 (8, 24) tile, 让每个 thread 拿到形状为 (V=8, VM=2, VN=3) = 48 个值的 fragment。 partition_S 直接把这个 fragment 形状给你 —— 没有地址算术, 只有一个可索引的 tensor。partition_S, partition_D, fragment · the three calls that matter
partition_S, partition_D, fragment · 三个关键调用
Three method calls do all the work that would otherwise be a page of address math:
三个方法调用做完了原本需要一整页地址算术才能完成的事:
| Call | Returns | Used for | ||
|---|---|---|---|---|
| 调用 | 返回 | 用途 | ||
thr_copy.partition_S(bA) |
per-thread view of source, shape (V, VM, VN) |
source 的 per-thread 视图, 形状 (V, VM, VN) |
where the load reads | load 从哪里读 |
thr_copy.partition_D(bB) |
per-thread view of destination, same shape | destination 的 per-thread 视图, 同形状 | where the load writes | load 写到哪里 |
fx.make_fragment_like(part) |
register tensor with same logical shape | 同逻辑形状的寄存器 tensor | the in-register staging area | 寄存器里的暂存区 |
Then fx.copy(atom, src, dst) issues the actual instructions, walking the (VM, VN) outer modes and emitting one atom call per (VM, VN) coordinate. Six iterations, one BufferCopy128b each, four threads in parallel, 192 elements total — exactly the picture above.
然后 fx.copy(atom, src, dst) 发射真正的指令, 沿外层 (VM, VN) 维走一遍, 每个 (VM, VN) 坐标发一次 atom 调用。 6 次迭代, 每次一条 BufferCopy128b, 4 个 thread 并行, 共 192 个元素 —— 正好是上图。
In 01-vectorAdd, the partitioning is hand-rolled with two logical_divide + slice calls. Here it is delegated to tiled_copy + partition_S/D. The new abstraction earns its keep when you have more than a trivial number of threads and want vectorization: it would be tedious to compute a 128-bit-aligned load offset for each of 256 threads by hand, but make_layout_tv + partition_S compiles down to exactly that — without you ever writing the byte-offset arithmetic.
在 01-vectorAdd 里, partition 是用两次 logical_divide + slice 手写的。 这里则交给 tiled_copy + partition_S/D。 当 thread 数不再是 trivial 的几个、 又要做向量化时, 这个新抽象就赚回了它的复杂度: 256 个 thread 每个手算 128-bit 对齐的 load 偏移会很烦, make_layout_tv + partition_S 编出来正是这个 —— 而你从头到尾没写过一行字节偏移算术。
Run it:
跑一下:
M4 · First contact with MFMA03 · tiledMma & MFMA
M4 · 第一次接触 MFMA03 · tiledMma 与 MFMA
Example 03 is the first kernel that actually multiplies. It computes C(64, 64) = A(64, 8) · B(64, 8)ᵀ on FP32, in a single block with 256 threads (four wavefronts), entirely in registers — no LDS, no K-loop, no software pipeline. The point is to introduce MFMA atoms and the way copy layouts are derived from a tiled MMA, without yet bringing in the optimization machinery of example 04.
例 03 是第一个真正做乘法的 kernel。 它在 FP32 上计算 C(64, 64) = A(64, 8) · B(64, 8)ᵀ, 单 block 256 thread (4 个 wavefront), 全程在寄存器里 —— 没有 LDS、 没有 K-loop、 没有软件流水。 目的是介绍 MFMA atom 以及 copy layout 如何从一个 tiled MMA 反推, 还不引入例 04 的那些优化机械。
The MFMA atom · a single-instruction matrix multiply
MFMA atom · 一条指令一个矩阵乘
The CDNA matrix core executes Matrix Fused Multiply-Add instructions — v_mfma_f32_16x16x4f32 in this case. One instruction, 64 lanes, computes C(16, 16) += A(16, 4) · B(4, 16) in a fixed lane-distribution pattern dictated by the hardware. The constructor argument MFMA(16, 16, 4, fx.Float32) picks that exact variant; FlyDSL knows the lane layouts the instruction demands and exposes them as mma_atom.tv_layout_A/B/C.
CDNA matrix core 执行 Matrix Fused Multiply-Add 指令 —— 这里是 v_mfma_f32_16x16x4f32。 一条指令, 64 个 lane, 按硬件固定的 lane 分布 pattern 计算 C(16, 16) += A(16, 4) · B(4, 16)。 构造参数 MFMA(16, 16, 4, fx.Float32) 选出这个具体变体; FlyDSL 知道这条指令要求的 lane layout, 通过 mma_atom.tv_layout_A/B/C 暴露出来。
make_tiled_mma(atom, layout((2, 2, 1), (1, 2, 0))) says: tile this atom by laying four wavefronts out as 2 along M, 2 along N, 1 along K. The stride (1, 2, 0) maps the wave id to (M, N, K) indices — wave 0 → (M=0, N=0), wave 1 → (M=1, N=0), wave 2 → (M=0, N=1), wave 3 → (M=1, N=1). All four waves cooperate along the K dimension (stride 0 means K is broadcast across waves).
make_tiled_mma(atom, layout((2, 2, 1), (1, 2, 0))) 表示: 把这个 atom 用 4 个 wavefront 按 M 方向 2 个、 N 方向 2 个、 K 方向 1 个 平铺。 stride (1, 2, 0) 把 wave id 映射到 (M, N, K) 索引 —— wave 0 → (M=0, N=0), wave 1 → (M=1, N=0), wave 2 → (M=0, N=1), wave 3 → (M=1, N=1)。 4 个 wave 在 K 维上一起 reduce (stride 0 表示 K 在 wave 间广播)。
(1, 2, 0): wave 0 → (0,0), wave 1 → (1,0), wave 2 → (0,1), wave 3 → (1,1). The same wave-layout idiom scales: example 04 uses (1, 4, 1) with a much larger atom permutation.(1, 2, 0): wave 0 → (0,0), wave 1 → (1,0), wave 2 → (0,1), wave 3 → (1,1)。 这个 wave layout 习语可以扩展: 例 04 用 (1, 4, 1), 配合一个大得多的 atom permutation。make_tiled_copy_A · copy layouts inherited from the MMA
make_tiled_copy_A · 从 MMA 反推 copy layout
The MFMA instruction demands a specific lane-to-element layout for its A and B inputs and for its C output. Get this wrong and you load correct data into the wrong lane — the kernel runs without crashing and produces noise. The conventional fix is to read the MFMA documentation and hand-write copy layouts that match. FlyDSL inverts the direction:
MFMA 指令对 A 和 B 输入、 以及 C 输出, 都要求特定的 lane-to-element layout。 写错的话, 你会把正确的数据装到错误的 lane —— kernel 不崩溃, 但产出噪声。 常规修法是去读 MFMA 文档, 手写匹配的 copy layout。 FlyDSL 把这个方向反过来:
make_tiled_copy_A reads the MMA atom's tv_layout_A and constructs a copy whose per-thread fragment shape is exactly what the MFMA will consume. The compiler guarantees lane correctness; you never reason about which lane gets which element. The same trick gives you tiled_copy_B for the second operand and tiled_copy_C for the accumulator store.
make_tiled_copy_A 读 MMA atom 的 tv_layout_A, 构造一个 copy, 它的 per-thread fragment 形状正好是 MFMA 要消费的。 编译器保证 lane 正确, 你完全不必推哪个 lane 拿哪个元素。 同样的把戏给你 tiled_copy_B 用于第二个操作数, tiled_copy_C 用于累加器写回。
retile · one register, two views
retile · 一段寄存器, 两种视角
The same register fragment is read by two different consumers: the copy instructions (which want it in "vector load" shape) and the MFMA instruction (which wants it in "lane assignment" shape). The two shapes are isomorphic — same total bytes, same number of registers — but indexed differently. retile gives you the second view without reallocation:
同一段 register fragment 被两个消费者读: copy 指令 (想要它是"向量 load"的形状), 以及 MFMA 指令 (想要它是"lane 分配"的形状)。 两种形状同构 —— 总字节数相同、 寄存器数相同 —— 只是索引方式不同。 retile 给你第二个视角而不需要重新分配:
Without retile, you'd allocate one fragment for "after the load," then issue copies between it and a second "for the MMA" fragment — moving registers around unnecessarily. retile is FlyDSL's recognition that the data already in your registers can be addressed two ways at zero cost. The MLIR pass pipeline collapses both views into the same VGPR allocation; the only IR-level operation is a type reinterpret.
没有 retile 你要分配一个 "load 完之后"用的 fragment, 再发 copy 指令搬到第二个 "给 MMA 用"的 fragment —— 白白挪寄存器。 retile 是 FlyDSL 对一个事实的承认: 寄存器里的数据可以用两种方式寻址, 零成本。 MLIR pass 流水线把两个视角折叠成同一份 VGPR 分配; IR 层只有一个类型重新解释。
Run it:
跑一下:
Example 03 is the first example whose abstraction-to-mechanism ratio is dramatic. Less than 80 lines of layout code produce a working tiled MFMA — the equivalent hand-written ROCDL would be 300+ lines of register juggling and the lane-correctness check would be done with a tensor of fp32 ones loaded against a known-good reference. The abstractions are not just convenient; they make the kernel verifiable by inspection.
例 03 是第一个抽象与机制比例显著的例子。 不到 80 行 layout 代码产出一个可跑的 tiled MFMA —— 等价的手写 ROCDL 要 300+ 行寄存器周旋, 而 lane 正确性还得用一个 FP32 全 1 的 tensor 对照已知好的参考来验。 这些抽象不仅是方便, 它们让 kernel 看一眼就能验证正确性。
M5 · The production kernel04 · preshuffle GEMM
M5 · 生产 kernel04 · preshuffle GEMM
The fourth example is where it stops being a tutorial. 211 lines, a 4096³ FP16 GEMM, BLOCK = (128, 128, 64), two-stage LDS double buffer with XOR swizzle, double-buffered B in registers, a hand-written instruction scheduler in the hot loop, and a host-side weight preshuffle that flattens an MFMA-specific lane layout into a contiguous global memory pattern. Every optimization that production CDNA GEMMs use is here in compact form. The example earns its place at the top of the ladder.
第四个例子已经不是 tutorial 了。 211 行, 4096³ FP16 GEMM, BLOCK = (128, 128, 64), 两阶段 LDS 双缓冲 + XOR swizzle、 寄存器里双 buffer B、 hot loop 里手写的指令调度器, 还有一个 host 端权重预洗, 把 MFMA 特有的 lane 布局摊平到 global 内存中的连续 pattern。 production CDNA GEMM 用到的每一项优化都以紧凑形式出现在这里。 它当之无愧地坐在阶梯顶端。
The four pieces that make a fast GEMM
让 GEMM 跑快的四件套
| Optimization | What it costs | What it buys | |||
|---|---|---|---|---|---|
| 优化 | 代价 | 收益 | |||
| Preshuffle B (host-side) | Preshuffle B (host 端) | One memory reshape at load time | 加载时一次内存 reshape | MFMA gets lane-correct data from raw buffer_load_dwordx4 — no LDS transpose |
MFMA 直接从原始 buffer_load_dwordx4 拿到 lane-correct 数据 —— 不用 LDS 转置 |
| XOR swizzle on LDS A | LDS A 的 XOR swizzle | One fly.composed_layout wrapper |
一个 fly.composed_layout 包装 |
Bank conflicts on 64-lane ds_read drop from 8× to 1× |
64 lane ds_read 的 bank conflict 从 8× 降到 1× |
| 2-stage software pipeline | 两阶段软件流水线 | Double the LDS for A · double the B registers | A 的 LDS 翻倍, B 的寄存器翻倍 | Next K-block's buffer_load overlaps current MFMA |
下一个 K 块的 buffer_load 与当前 MFMA 重叠 |
| Explicit instruction schedule | 显式指令调度 | ~30 lines of fx.rocdl.sched_* |
大约 30 行的 fx.rocdl.sched_* |
~30% lift in peak FLOPs by filling issue slots | 填满 issue slot, 峰值 FLOPs 提升约 30% |
shuffle_weight in tests/utils.py) so that a plain buffer_load_dwordx4 off the global B already lands MFMA-lane-correct in VGPRs. The kernel then constructs a composed !fly.layout whose strides describe the chunked physical layout — same logical (N, K) shape, different stride math. For inference weights that never change, this trade saves the entire LDS round-trip on B.tests/utils.py 里的 shuffle_weight), 让一条普通的 buffer_load_dwordx4 从 global B 加载出来就已经是 MFMA lane-correct 的 VGPR 数据。 kernel 然后构造一个 composed !fly.layout, 它的 stride 描述这种 chunked 物理布局 —— 同样的逻辑 (N, K) shape, 不同的 stride 算术。 对推理时不变的权重, 这个 trade 省掉 B 的整个 LDS 来回。The LDS · XOR swizzle and double buffer
LDS · XOR swizzle 与双缓冲
The XOR swizzle (3, 3, 3) is the CuTe convention for "bank-conflict-free LDS access" — three bits of upper-order address XOR'd into three bits of lower-order address, in groups of three. With it, the 64-lane ds_read_b128 hits each of the 32 LDS banks exactly twice instead of all hitting one. Without the swizzle, the same access pattern would serialize over 8 cycles instead of completing in 1.
XOR swizzle (3, 3, 3) 是 CuTe 里"LDS 无 bank conflict 访问"的惯例 —— 把 3 位高位地址 XOR 到 3 位低位地址, 三位一组。 有了它, 64 lane 的 ds_read_b128 正好命中 32 个 LDS bank 各 2 次, 而不是全部撞在 1 个 bank 上。 没有 swizzle 的话, 同样的访问 pattern 会被串行化成 8 周期, 而不是 1 周期完成。
The trailing dimension STAGES_A = 2 is the double buffer — half the LDS holds the K-block currently being consumed by MFMA, half holds the K-block being filled by the next iteration's buffer_load → ds_write. Both halves are addressed by the same swizzle, which means a single composed layout serves the whole ping-pong without per-stage stride math.
尾维 STAGES_A = 2 就是双缓冲 —— LDS 一半装当前被 MFMA 消费的 K 块, 另一半装下一轮 buffer_load → ds_write 正在填的 K 块。 两半用同一个 swizzle 寻址, 这意味着一个 composed layout 就能服务整个 ping-pong, 不需要按 stage 算 stride。
The pipeline · read stage, write stage, ping-pong
流水线 · read stage, write stage, ping-pong
buffer_loads that prefetch K-block n+1 while the MFMA on track 5 is still grinding through K-block n. Track 3's ds_write stores the prefetched A into the other LDS stage (write = read ^ 1). The barriers separate iterations because ds_write from this iter must complete before next iter's ds_read begins. The B operand is fully register-resident with two stages — there's no ds_write B track because preshuffle eliminated the LDS round-trip on B entirely.buffer_load, 它们在轨道 5 上的 MFMA 还在嗑 K 块 n 的时候预取 K 块 n+1。 轨道 3 的 ds_write 把预取的 A 写到另一半 LDS stage (write = read ^ 1)。 barrier 把迭代分开, 因为本迭代的 ds_write 必须在下一迭代的 ds_read 开始前完成。 B 操作数完全驻留在寄存器中, 两个 stage —— 没有 ds_write B 轨道, 因为 preshuffle 已经把 B 的 LDS 来回完全消掉了。The soffset=next_k * gA_k_stride argument is worth pausing on. It is not byte arithmetic on a global pointer — it's the scalar offset field of the ROCDL buffer_load instruction. The buffer resource descriptor is built once at kernel entry; subsequent loads from different K-positions just vary the SGPR-held scalar offset. No v_add on a VGPR address per lane, no per-thread index computation. gcnasm's vector_add lives on the same mechanism.
soffset=next_k * gA_k_stride 这个参数值得停一下。 它不是全局指针的字节算术 —— 它是 ROCDL buffer_load 指令的 scalar offset 字段。 buffer resource descriptor 在 kernel 入口建一次; 后续从不同 K 位置 load 只是变 SGPR 里的 scalar offset。 没有 per-lane 在 VGPR 地址上做 v_add, 没有 per-thread 索引计算。 gcnasm 的 vector_add 就吃这个机制。
hot_loop_scheduler · telling the backend where to put each instruction
hot_loop_scheduler · 告诉 backend 每条指令放哪里
The last bit is the most foreign to anyone coming from Triton or CUDA C. After the algorithmic body of run_pipeline_stage, the kernel calls hot_loop_scheduler(), which emits a sequence of fx.rocdl.sched_* intrinsics:
最后这块对 Triton 或 CUDA C 出身的人最陌生。 在 run_pipeline_stage 的算法主体后面, kernel 调 hot_loop_scheduler(), 它发出一序列 fx.rocdl.sched_* intrinsic:
These are not real instructions. They are scheduling hints that the LLVM AMDGPU backend respects: sched_mfma(2) says "place exactly 2 MFMA instructions at this point in the schedule, in addition to whatever has already been scheduled here." The compiler picks which MFMAs (from the surrounding code) to put here, but the count and ordering are pinned.
这些不是真实指令。 它们是 调度提示, LLVM AMDGPU backend 会尊重它们: sched_mfma(2) 说"在调度的这个位置, 在已经放在这里的指令之外, 再放正好 2 条 MFMA 指令"。 编译器自己挑哪些MFMA (从周围代码里挑) 放在这里, 但数量和顺序是钉死的。
Why it matters: the CDNA matrix core has a long latency. A single MFMA can take 16–32 cycles to retire, and the wave can issue other instructions during the wait — ds_read, buffer_load, ds_write. Without hints, LLVM's scheduler does an okay job but tends to cluster same-class instructions. With hints, you can interleave them so that each issue slot during the MFMA latency holds something useful. The 8 leading iterations of sched_main_iter(with_vmem=True) spread buffer_loads evenly across the early hot-loop body; the trailing 7 with with_dswr=True spread ds_writes across the tail. This pattern is the difference between 60% and 90%+ of MFMA peak in real kernels.
为什么重要: CDNA matrix core 的延迟很长。 一条 MFMA 退役要 16–32 周期, wave 在等待期间可以发其他指令 —— ds_read、 buffer_load、 ds_write。 不给提示, LLVM scheduler 会干得还行, 但倾向于把同类指令聚在一起。 给提示之后, 你可以把它们交错排开, 让 MFMA 延迟期间每个 issue slot 都装着有用的东西。 开头 8 次 sched_main_iter(with_vmem=True) 把 buffer_load 均匀撒在 hot loop 前段; 末尾 7 次 with_dswr=True 把 ds_write 撒在 hot loop 后段。 真实 kernel 上, 这个 pattern 就是 60% 和 90%+ MFMA peak 的差距。
Schedulers are fragile. Changing BLOCK_K, swapping MFMA shape, or even adding a single instruction inside the hot loop can invalidate the count assumptions and tank performance. kernels/preshuffle_gemm.py in production uses a different scheduler for each (BM, BN, BK, MFMA-shape) tuple, often tuned with the ATT trace analyzer (capture-kernel-trace). For a first pass at a new GEMM, write it without a scheduler and add one only after you've profiled.
scheduler 很脆弱。 改 BLOCK_K、 换 MFMA 形状、 甚至在 hot loop 里多加一条指令, 都可能让数量假设失效、 性能暴跌。 production 的 kernels/preshuffle_gemm.py 给每个 (BM, BN, BK, MFMA-shape) 元组各用一份 scheduler, 通常用 ATT trace 分析器 (capture-kernel-trace) 调出来的。 写新 GEMM 的第一遍, 先不放 scheduler, 等 profile 完再加。
Run it:
跑一下:
The kernel runs and is correct. To measure how close to peak it gets, swap the example for the production kernels/preshuffle_gemm.py and use tests/kernels/test_preshuffle_gemm.py as the harness — that's the path the AMD team uses for tuning.
kernel 跑通了, 结果正确。 要看它离 peak 多近, 把这个例子换成 production 的 kernels/preshuffle_gemm.py, 用 tests/kernels/test_preshuffle_gemm.py 当 harness —— 这是 AMD 团队调优时走的路径。
M6 · Underneath the @decoratorsCompile pipeline · Python → fatbin
M6 · @ 装饰器之下编译流水线 · Python → fatbin
The Python you wrote was traced into MLIR. What happens next is worth knowing in outline, because most FlyDSL debugging happens at the IR level — when a kernel produces NaN or a wrong number, the first move is FLYDSL_DUMP_IR=1 and reading module-after-fly-layout-lowering.mlir.
你写的 Python 被 trace 成 MLIR。 接下来发生的事值得大概知道, 因为 FlyDSL 大部分调试都发生在 IR 层 —— kernel 出 NaN 或者数算错时, 第一步是 FLYDSL_DUMP_IR=1, 然后读 module-after-fly-layout-lowering.mlir。
Tracing · @flyc.kernel
Tracing · @flyc.kernel
The @flyc.kernel decorator wraps the Python function in a tracer. On first call with a given type signature, the tracer rewrites the function's AST to intercept each fx.* call, then executes the function in a context where each operation emits MLIR ops instead of computing values. By the end of the trace, you have a gpu.module with a single gpu.func kernel inside, expressed in a mix of the fly, gpu, arith, scf, memref, and vector dialects.
@flyc.kernel 装饰器把 Python 函数包进一个 tracer。 第一次以某个类型签名调用时, tracer 改写函数的 AST 来拦截每个 fx.* 调用, 然后在一个特殊上下文里执行函数 —— 每个操作发射 MLIR op 而不是计算具体值。 trace 结束时你得到一个 gpu.module, 里面有一个 gpu.func kernel, 用 fly、 gpu、 arith、 scf、 memref、 vector dialect 的混合体表达。
Two facts about tracing are load-bearing for kernel authors:
tracing 有两件事对 kernel 作者很关键:
- Python
ifon traced values is illegal. The trace runs once. If a branch depends on a kernel-time value (e.g.,tid), there's no way to materialize both branches. Usefx.const_expr(cond)when the condition is compile-time, orscf.if(via FlyDSL'sif_else) when it must be runtime. - Python
forover a list unrolls. Usefx.range_constexpr(n)for fully-unrolled compile-time loops, plainrange(start, stop, step, init=[acc])for anscf.forwith loop-carried values.
- 对 traced 值用 Python
if不合法。 trace 只跑一次。 如果分支依赖一个 kernel-time 值 (比如tid), 没办法把两条分支都物化。 编译期条件用fx.const_expr(cond); 必须运行期的话用scf.if(FlyDSL 的if_else)。 - Python
for遍历一个列表会展开。 完全展开的编译期循环用fx.range_constexpr(n); 带循环携带值的scf.for用range(start, stop, step, init=[acc])。
The pass pipeline · three stages, one fatbin
Pass 流水线 · 三段式, 一份 fatbin
The pipeline is built by RocmBackend._pipeline_parts() in python/flydsl/compiler/backends/rocm.py and split into three named stages. The split point matters: with FLYDSL_COMPILE_LLVM_DIR, FlyDSL can stop after Stage A, dump the IR, and hand it off to an external mlir-opt + llc toolchain.
Pass pipeline 在 python/flydsl/compiler/backends/rocm.py 里的 RocmBackend._pipeline_parts() 构建, 切成三段。 切点是有意义的: 配 FLYDSL_COMPILE_LLVM_DIR 时, FlyDSL 可以只跑到 Stage A, dump 出 IR, 交给外部 mlir-opt + llc 工具链接着跑。
| # | Pass | What it turns into what | |
|---|---|---|---|
| # | Pass | 把什么变成什么 | |
Stage A · pre_binary_fragments — Fly → ROCDLStage A · pre_binary_fragments —— Fly → ROCDL | |||
| 1 | fly-rewrite-func-signature |
lower DSL types (IntTuple, Layout, ComposedLayout, CoordTensor, MemRef) at function / SCF boundaries to packed LLVM structs |
把 DSL 类型 (IntTuple、 Layout、 ComposedLayout、 CoordTensor、 MemRef) 在函数和 SCF 边界处 lower 成打包好的 LLVM struct |
| 2 | fly-canonicalize |
folds !fly.layout algebra at compile time when shapes are static |
静态 shape 时, 在编译期折叠 !fly.layout 代数 |
| 3 | fly-layout-lowering |
turns fly.crd2idx, partitions, divides into concrete arith + vector ops |
把 fly.crd2idx、 partition、 divide 变成具体的 arith + vector op |
| 4 | fly-int-swizzle-simplify |
algebraically simplifies the swizzle-shaped arith sequences emitted by applySwizzle |
代数地化简 applySwizzle 发出的 swizzle 形态的 arith 序列 |
| 5 | canonicalize |
standard MLIR canonicalization (constant folding, etc.) | 标准 MLIR canonicalization (常量折叠等) |
| 6 | fly-convert-atom-call-to-ssa-form |
lifts copy_atom_call / mma_atom_call into their SSA counterparts; register tensors become vector SSA values |
把 copy_atom_call / mma_atom_call 提升成对应的 SSA 形式; 寄存器 tensor 变成 vector SSA 值 |
| 7 | fly-promote-regmem-to-vectorssa |
promotes fly.make_ptr(register) memory semantics to vector SSA (depends on pass 6) |
把 fly.make_ptr(register) 的 memory 语义升级成 vector SSA (依赖 pass 6) |
| 8 | convert-fly-to-rocdl |
copy atoms become rocdl.buffer_load/store; MMA atoms become rocdl.mfma.* |
copy atom 变成 rocdl.buffer_load / store; MMA atom 变成 rocdl.mfma.* |
| 9 | canonicalize |
second canonicalization round after ROCDL lowering | ROCDL lower 之后的第二轮 canonicalize |
| 10 | gpu.module(convert-scf-to-cf, cse, convert-gpu-to-rocdl{chipset=gfxNNN ...}, fly-rocdl-cluster-attr) |
inside gpu.module: SCF→CF, CSE, GPU intrinsics → ROCDL, then inject amdgpu-cluster-dims into the llvm.func passthrough |
在 gpu.module 里面: SCF → CF、 CSE、 GPU intrinsic → ROCDL, 然后把 amdgpu-cluster-dims 注入到 llvm.func 的 passthrough |
Stage B · binary_prep_fragments — → LLVMStage B · binary_prep_fragments —— → LLVM | |||
| 11 | rocdl-attach-target {chip=gfxNNN ...} |
tags the module with the target chip plus fast / unsafe-math / wave64 options |
给 module 打上目标芯片标签, 外加 fast / unsafe-math / wave64 等选项 |
| 12-13 | convert-scf-to-cf → convert-cf-to-llvm |
host-side SCF → CF → LLVM | host 侧 SCF → CF → LLVM |
| 14 | gpu-to-llvm{use-bare-pointers-...=true} |
GPU types and host launcher → LLVM dialect | GPU 类型和 host launcher → LLVM dialect |
| 15-17 | convert-vector/arith/func-to-llvm |
final lower of the remaining vector / arith / func ops to LLVM IR | 把剩下的 vector / arith / func op 终极 lower 到 LLVM IR |
| 18 | reconcile-unrealized-casts (+ optional ensure-debug-info-scope-on-llvm-func) |
final cast cleanup; debug-info scope inserted when FLYDSL_DEBUG_ENABLE_DEBUG_INFO=1 |
最后一轮 cast 清理; FLYDSL_DEBUG_ENABLE_DEBUG_INFO=1 时顺手插入 debug info scope |
Stage C · binary_fragmentStage C · binary_fragment | |||
| 19 | gpu-module-to-binary {format=fatbin opts="..."} |
invokes LLVM AMDGPU backend → HSA fatbin | 调 LLVM AMDGPU backend → HSA fatbin |
fly-layout-lowering is where most kernel-level bugs surface. If your partition_S produces a fragment of unexpected shape, it's because the layout algebra simplifies to something other than what you intended. Dump the IR after that pass to see what FlyDSL actually believes the indexing is. FLYDSL_DEBUG_PRINT_AFTER_ALL=1 dumps after every pass; FLYDSL_DUMP_DIR=./ir sends them to a directory you can grep through.
fly-layout-lowering 是大部分 kernel 级 bug 浮出水面的地方。 如果 partition_S 产出的 fragment 形状不对, 是因为 layout 代数简化成了你预期之外的东西。 dump 这个 pass 之后的 IR, 看 FlyDSL 实际认为索引是什么。 FLYDSL_DEBUG_PRINT_AFTER_ALL=1 每个 pass 后都 dump; FLYDSL_DUMP_DIR=./ir 把它们写到一个目录, 方便 grep。
An aside on what's not in the pipeline: gpu-kernel-outlining is gone. Outlining happens during Python tracing — @flyc.kernel emits gpu.func directly into a gpu.container_module, so by the time MlirCompiler.compile() takes the module, the kernel is already outlined.
顺带说一下流水线里没有的东西: gpu-kernel-outlining 不在了。 outlining 现在发生在 Python tracing 阶段 —— @flyc.kernel 直接把 gpu.func 发到 gpu.container_module 里, 所以等 MlirCompiler.compile() 拿到 module 时, kernel 已经 outline 过了。
The JIT cache
JIT 缓存
The compiled fatbin is keyed on the type signature, the Constexpr values, and a source hash of the kernel function (closure included). It is stored under ~/.flydsl/cache/. The cache invalidates on source changes to the kernel function itself, but not on changes to C++ passes or helper modules that aren't part of the traced closure. So:
编译出的 fatbin 用类型签名 + Constexpr 值 + kernel 函数 (含 closure) 的源码 hash 作为 key。 落在 ~/.flydsl/cache/ 下面。 修改 kernel 函数本身的源码会让缓存失效, 但修改 C++ pass 或不属于 traced closure 的 helper module不会。 所以:
§ R · Six traps from real debuggingReefs
§ R · 真实调试踩过的六个坑暗礁
The patterns FlyDSL gets wrong, the mistakes that survive review, the workarounds that take three days to discover.
FlyDSL 容易出问题的 pattern, 能逃过 review 的失误, 要花三天才能想到的解法。
-
1 · Branch-only values kill MLIR dominance.1 · 分支限定值破坏 MLIR dominance。Defining a value inside one branch of a Python
ifand using it after is the most common compile-time failure. The IR has no merge point for the two branches' definitions, so you get a dominance error. Hoist the value above the branch or return it as an explicit merged result.在 Pythonif的一个分支里定义一个值, 在 if 之后再用 —— 这是最常见的编译期失败。 IR 里两个分支的定义没有汇合点, 所以你拿到 dominance 错误。 把这个值提到分支之前定义, 或者显式作为合并后的返回值。 -
2 · SmemPtr._view_cache survives scf.for.2 · SmemPtr._view_cache 跨 scf.for 残留。If a shared-memory view is constructed inside
scf.forand consumed after, the cached view points to an SSA value that doesn't dominate the consumer. Clear the cache:SmemPtr._view_cache = Noneon loop exit.如果共享内存的 view 在scf.for里构造, 在循环外消费, 缓存里的 view 指向一个不 dominate 消费者的 SSA 值。 退出循环时清掉:SmemPtr._view_cache = None。 -
3 · Closure values don't trigger cache invalidation.3 · closure 值不触发缓存失效。JIT cache hashes the kernel function and its closure. If you change a Python constant defined outside the kernel and the kernel doesn't reference it through its closure, the cache won't invalidate. Either pass the value as a
Constexprargument, or clear the cache manually.JIT 缓存 hash 的是 kernel 函数和它的 closure。 如果你改一个 kernel 之外的 Python 常量, 而 kernel 没有通过 closure 引用它, 缓存不会失效。 要么把这个值作为Constexpr参数传进去, 要么手动清缓存。 -
4 · BufferCopy soffset is in elements, not bytes.4 · BufferCopy soffset 单位是 element, 不是 byte。The
soffsetargument tofx.copywith a BufferCopy atom is in elements of the atom's dtype, not bytes.soffset=next_k * gA_k_stridein example 04 multiplies a K-iter by the stride that already accounts for the element size. Triple-check this when chasing addressing-off-by-1.fx.copy配 BufferCopy atom 时,soffset参数的单位是 atom dtype 的元素数, 不是字节。 例 04 里soffset=next_k * gA_k_stride是 K-iter 乘上一个已经把元素大小算进去的 stride。 排查地址 off-by-1 时多查几遍这点。 -
5 · MFMA traversal order changes register pressure.5 · MFMA traversal order 影响寄存器压力。The
traversal_orderargument tofx.gemmchooses how the inner three nested loops over (M-atom, N-atom, K-atom) are ordered.KNMtypically minimizes A-fragment register pressure;MNKminimizes accumulator pressure. The wrong choice can spill VGPRs and crater performance.fx.gemm的traversal_order参数决定 (M-atom, N-atom, K-atom) 三层内嵌循环的顺序。KNM通常最小化 A 的 fragment 寄存器压力;MNK最小化累加器压力。 选错会 spill VGPR, 性能崩盘。 -
6 · Stale schedulers after a tile change.6 · 改 tile 之后 scheduler 悄悄失效。A
hot_loop_schedulertuned for (BM=128, BN=128, BK=64) breaks silently when you change the tile — performance drops, the kernel still gives correct numbers. Profile any tile-size change with the kernel trace analyzer before assuming the scheduler still fits.为 (BM=128, BN=128, BK=64) 调过的hot_loop_scheduler, 在你改了 tile 之后会悄悄失效 —— 性能掉, kernel 数还是对的。 任何 tile 改动之后, 都用 kernel trace analyzer profile 一下再假设 scheduler 还合身。
§ A · Where this lands for AMD workAMD-specific notes
§ A · 这件事对 AMD 工作意味着什么AMD 相关说明
FlyDSL is, in 2026, the cleanest abstraction layer for writing high-performance kernels on AMD CDNA. Triton-ROCm exists and produces correct kernels, but its scheduling model is opaque — you cannot pin "two MFMAs then one ds_read then two MFMAs" the way fx.rocdl.sched_* lets you. Composable Kernel (CK) gives you that control but in C++ templates that take 30 minutes to compile. FlyDSL is the middle: Python ergonomics, MLIR backbone, AMD-specific scheduling primitives.
2026 年, 在 AMD CDNA 上写高性能 kernel, FlyDSL 是最干净的抽象层。 Triton-ROCm 也存在, 也能产出正确 kernel, 但它的调度模型不透明 —— 你没法像 fx.rocdl.sched_* 那样钉死"两条 MFMA 然后一条 ds_read 再两条 MFMA"。 Composable Kernel (CK) 给你这个控制权, 但写在编译 30 分钟的 C++ 模板里。 FlyDSL 处在中间: Python 人体工学、 MLIR 骨架、 AMD 专属调度 primitive。
Three observations worth carrying:
值得带走的三条观察:
- The layout API is the AMD-specific contribution. CuTe gave the world layout algebra; FlyDSL gives it a typed MLIR home with passes that lower it. The fact that
!fly.layoutis a first-class type means an MLIR-aware tool (a verifier, an autotuner, an optimization pass) can reason about layouts in the IR, not just consume strings. - Preshuffle is a recurring idea, not a one-off. The
shuffle_weighttrick in example 04 — reshape on the host so the kernel needs no transpose — appears across the production kernels for every weight tensor whose layout is more constrained than its compute pattern. MoE GEMM, blockscale GEMM, INT4 GEMM all use it. If you're writing a new AMD GEMM and not preshuffling, ask yourself why. - Scheduling is where peak FLOPs hide. A FlyDSL kernel without
sched_*hints will hit 60–70% of peak. With a tuned scheduler it can hit 90%+. The gap is purely instruction-mix-in-the-hot-loop, and it cannot be recovered after the fact by tweaking LLVM flags. Production kernels in this repo treathot_loop_scheduleras a separate artifact of the kernel, tuned with ATT traces fromcapture-kernel-traceand reviewed independently of the algorithmic body.
- layout API 是 AMD 这边特有的贡献。 CuTe 把 layout 代数带给世界; FlyDSL 给它一个 typed MLIR 之家, 加一整套 lower 它的 pass。
!fly.layout是一等公民类型这件事意味着, 任何 MLIR-aware 工具 (verifier、 autotuner、 优化 pass) 都可以在 IR 里对 layout 进行推理, 而不是只能消费字符串。 - Preshuffle 是反复出现的思路, 不是一次性技巧。 例 04 的
shuffle_weight—— host 端 reshape, kernel 就不需要转置 —— 在 production kernel 里, 对每一个 layout 比计算 pattern 更受限的权重 tensor 都出现。 MoE GEMM、 blockscale GEMM、 INT4 GEMM 都用它。 写新 AMD GEMM 又不 preshuffle 的话, 问问自己为什么。 - 调度才是 peak FLOPs 藏的地方。 没有
sched_*提示的 FlyDSL kernel 大概能跑到 60–70% peak; 有调好的 scheduler 能跑到 90%+。 这段差距纯粹是 hot loop 内的指令编排, 事后调 LLVM flag 救不回来。 这个 repo 里 production kernel 把hot_loop_scheduler当作 kernel 的一件独立 artifact, 用capture-kernel-trace出来的 ATT trace 调, 独立于算法主体 review。
For a new kernel on CDNA3/CDNA4: (1) start from an example, (2) get correctness first, (3) profile with capture-kernel-trace + kernel-trace-analysis, (4) identify the top stall — usually lgkmcnt waits on ds_read, or vmcnt waits on buffer_load — (5) add the appropriate optimization (LDS swizzle for the former, prefetch + scheduler for the latter), (6) re-profile, repeat. The Skills in this repo (gemm-optimization, lds-optimization, prefetch-data-load) are tuned shortcuts for each of these steps.
写一个新的 CDNA3 / CDNA4 kernel: (1) 从一个 example 起步, (2) 先做正确, (3) 用 capture-kernel-trace + kernel-trace-analysis profile, (4) 找到 top stall —— 通常是 ds_read 上的 lgkmcnt 等待, 或 buffer_load 上的 vmcnt 等待 —— (5) 加对应优化 (前者上 LDS swizzle, 后者上 prefetch + scheduler), (6) 重新 profile, 重复。 这个 repo 里的 Skills (gemm-optimization、 lds-optimization、 prefetch-data-load) 是这些步骤各自的捷径。
§ ∞ · What four examples teachEpilogue
§ ∞ · 四个例子教了什么尾声
The shape of the four examples is the shape of every FlyDSL kernel. Construct a layout. Slice it to the current block. Partition the slice to the current thread. Allocate a register fragment. Issue a copy. Maybe move through LDS. Issue a compute. Maybe overlap with prefetch. Write back. The vocabulary that walks through this scales from a 16-line vector add to a 1500-line production paged-attention kernel without ever changing form.
这四个例子的形状就是每个 FlyDSL kernel 的形状。 构造 layout。 切到当前 block。 把切片再分给当前 thread。 分配 register fragment。 发 copy。 可能过一层 LDS。 发 compute。 可能与 prefetch 重叠。 写回。 这套词汇从一个 16 行的向量加到 1500 行的 production paged attention kernel, 形状从不变。
What is genuinely surprising — and what kept me reading the repo after the four examples — is how much of "GPU kernel optimization" turns out to be the same optimization repeated against different layouts. Preshuffle B for an MFMA-friendly load. Swizzle A in LDS for bank-free ds_read. Double-buffer everything so memory and compute overlap. Pin the instruction schedule so the matrix core doesn't idle. Every kernel I read in kernels/ applied some subset of these. The variations were in the geometry — tile sizes, MFMA shapes, atom permutations — but the optimization vocabulary was constant.
真正让我惊讶的事 —— 也是看完四个例子之后我继续读这个 repo 的原因 —— 是"GPU kernel 优化"绝大部分其实是同一套优化在不同 layout 上重复。 Preshuffle B 让 MFMA 友好加载。 LDS 上 swizzle A 拿到无 bank conflict 的 ds_read。 给所有东西做双缓冲, 让 memory 和 compute 重叠。 钉死指令调度, 让 matrix core 不空转。 我在 kernels/ 里读过的每个 kernel 都用了这套的某个子集。 变化在几何上 —— tile 大小、 MFMA 形状、 atom permutation —— 但优化的词汇是不变的。
That constancy is what FlyDSL captures. By making layout algebra a typed primitive and copy/MMA atoms a composable algebra, it turns "GPU kernel optimization" from a list of tricks into a calculus that can be reasoned about, transformed, and verified. Whether that calculus is the right abstraction for the next generation of AMD silicon (CDNA5, the unified RDNA/CDNA path) is an open question — but for what runs on MI300/MI350/MI450 today, it's the cleanest tool I've used.
FlyDSL 抓住的就是这种不变性。 把 layout 代数做成 typed primitive、 把 copy / MMA atom 做成可组合代数, 它把"GPU kernel 优化"从一份技巧清单变成一套可推理、 可变换、 可验证的演算。 这套演算是不是下一代 AMD 硅 (CDNA5、 统一 RDNA / CDNA 路线) 的正确抽象, 这是个开放问题 —— 但对今天跑在 MI300 / MI350 / MI450 上的东西, 这是我用过最干净的工具。
The next entry in this series is likely aiter — the production kernel library that the FlyDSL kernels target, and the test scaffolding that this repo borrows. Reading FlyDSL alone gives you the language; reading aiter shows you the dictionary.
系列下一篇大概率写 aiter —— FlyDSL kernel 瞄准的那个 production kernel 库, 也是 repo 这边借用的测试基础设施。 只读 FlyDSL 你拿到的是语言, 读 aiter 才看到字典。
§ References · the docs to read alongsideReferences
§ 参考资料 · 与这篇并读参考资料
-
FlyDSL · github.com/ROCm/FlyDSLThe source.
docs/layout_system_guide.mdhas the complete Quick Reference for layout algebra;docs/kernel_authoring_guide.mdhas the practical patterns;kernels/has the production reference implementations.源代码。docs/layout_system_guide.md有 layout 代数完整 Quick Reference;docs/kernel_authoring_guide.md是实战 pattern;kernels/是 production 参考实现。 -
CuTe · NVIDIA CUTLASS · cute/The intellectual parent. The layout algebra, the copy/MMA atom design, and the
partition_S/Didiom are all CuTe ideas. Reading the CuTe docs alongside FlyDSL clarifies which choices are universal and which are AMD-specific.思想上的母本。 layout 代数、 copy / MMA atom 设计、partition_S/Didiom 都是 CuTe 的概念。 CuTe 文档和 FlyDSL 对照着读, 可以看清哪些选择是通用的、 哪些是 AMD 特有的。 -
Categorical Foundations for CuTe Layouts · arxiv 2601.05972Colfax Research's formal treatment of layout algebra as a category. Sufficient to derive every algebraic identity FlyDSL relies on, and the right reading if you want to extend the algebra (e.g., custom
productvariants).Colfax Research 把 layout 代数当作一个 category 来形式化。 足够推出 FlyDSL 依赖的所有代数恒等式; 想扩展代数 (比如自定义product变体) 该读这本。 -
AMD CDNA3 ISA Reference · amd.com · CDNA3 ISAAuthoritative on every instruction FlyDSL's lowering pipeline emits. § 8 (MUBUF) for
buffer_load, § 10 (MFMA) for the matrix core, § 6 (s_waitcnt) for the vmcnt/lgkmcnt that scheduling controls.FlyDSL lowering 流水线发出的每条指令的权威手册。 § 8 (MUBUF) 对应buffer_load, § 10 (MFMA) 对应 matrix core, § 6 (s_waitcnt) 对应 scheduling 控制的 vmcnt / lgkmcnt。 -
MLIR documentation · mlir.llvm.orgFor reading
FLYDSL_DUMP_IRoutput and understanding thegpu,arith,scf,memref,vector, androcdldialects that FlyDSL composes with.读FLYDSL_DUMP_IR输出和理解 FlyDSL 组合用的gpu、arith、scf、memref、vector、rocdldialect 时用。 -
FlyDSL Tile Programming · the in-repo skillFlyDSL Tile Programming · repo 内 skillIf working in this repo with Claude Code, the
flydsl-tile-programmingandflydsl-kernel-authoringskills are tuned references for kernel construction patterns. Use them before re-deriving idioms from scratch.如果在这个 repo 里用 Claude Code 工作,flydsl-tile-programming和flydsl-kernel-authoring这两个 skill 是 kernel 构造 pattern 的精调参考。 从头推 idiom 之前先用它们。