Skip to the content.

Blackwell 上的 SVDQuant W4A4 算子实现 —— 基于 FA4 骨架的 warp 专用化、TMEM 与 2-CTA 持久化内核导览

如何让 Blackwell kernel 在复杂的流水线同步状态空间里不死锁 —— 借用 FlashAttention-4 的同步骨架(显式 per-warp 流水线状态 + warp 专用化 + 持久化 tile scheduler),而不是自己从头写一套 状态机。以本仓库 gemm_w4a4 为例:从 1-CTA 的 CUTLASS 例程直译版 重构成 FA4 衍生的 2-CTA 持久化内核,以及那一行藏在”看上去能跑” smoke 背后、价值 +198 % TF 的 SMEM 账目 bug。

代码:ultism/svdquant-kernels。 这篇文章的大量细节都活在仓库源码里 —— 行号、PTX、内核 docstring、 gotcha 文档 —— 建议配合仓库(以及一个能浏览仓库的 AI)一起读。

1. 前言

各 shape 的逐 shape MFU vs nunchaku —— 加粗青色 cell 是我们领先的位置

数字是 MFU(占该芯片 dense NVFP4 峰值的百分比)。两边不是同代芯片: 我们在 B200(SM_100,dense FP4 峰值 10 PFLOPS)上跑;nunchaku 的 NVFP4 在 __CUDA_ARCH__ >= 1200 上 gated,没有 SM_100 二进制,只能跑在 SM_120a/121a 的 RTX PRO 6000(4 PFLOPS 峰值)上 —— 两套 tensor core ISA、两条工具链、两代 Blackwell。MFU 已经把芯片峰值除掉了,但这张表 不是用来判断哪一边代码”写得好”的,它只是一个实现品质参考点: “成熟手写 inline PTX 在它自己的目标芯片上能跑多快”。同一台 B200、 剥掉 LoRA 和仿射、用 CUTLASS 的 dense_blockscaled_gemm_persistent.py 在 2-CTA 256×256 上跑出来的本地天花板是 45 %–63 % MFU;才是还真 值得追的剩余空间。

这个算子是 SVDQuant 里 compute-bound 的那一半:NVFP4 scaled-MMA + 小规模 的低秩 LoRA 残差 + 按列仿射。数学一行写得下;实现几乎用尽了 SM_100 / SM_103 比上一代多出来的全部原语。

仓库里有两版内核共存。v1cute_kernels/gemm_w4a4/kernel.py,1-CTA、 单体 @cute.kernel、stock cutlass.pipeline.PipelineState)在生产 shape 上卡在 ~27 % MFU;想把它升到 2-CTA 走 cta_group=TWO 的第一阶段尝试 得到的基本是零增益(28 % vs 27 %)。v2_fa4cute_kernels/gemm_w4a4/kernel_v2_fa4.py,FA4 衍生的 warp 专用化、 三 pipeline、2-CTA 持久化)是出货面,跑出上面那组数字的就是它。

整个项目里单行改动 ROI 最高的一处:把 2-CTA 模式下每个 CTA 对 LoRA-up 权重块的 SMEM 字节数估算减半。kernel 在 trace time 解一道 SMEM 预算的题 ——”每个 SM 给的共享内存这么多,主 K-loop 能同时 放几个 K-block 在飞?LoRA 那条预取又能跑多少 stage?”。LoRA-up 那一块手写了一行算术,少算了一件事:2-CTA 模式下硬件已经把这块 tile 在 cluster 内两个 CTA 之间分了,每个 CTA 实际拿到的片上分配 只是公式给的一半。预算求解器拿到这个虚高了一倍的数字,为了给 “其实不存在的”共享内存让位,悄悄把主 K-loop 的并发深度从 4 个 in-flight K-block 砍到了 2 个。症状是:没有症状 —— trace 通、 kernel 跑、数值对、看上去就是”有点慢”。改法:在那一行后面多除一次 cluster 的 CTA group 大小。生产 shape 上的墙钟:566 TF → 1685 TF (+198 %)4.2 % → 16.9 % MFU。同样 launch 配置下的 ncu A/B: Duration −31.2 %、SM Throughput +11.99 pp、SM Active Cycles −36.3 %。 提交 7296e90;完整数据在 §7。

这篇博文把这两个故事放一块儿写,因为它们本来就是同一个故事:LU SMEM 这条 bug 只有在 FA4 重构把 2-CTA 全链路打通之后才暴露得出来,而 LU SMEM 的修复之所以有价值,又是因为 FA4 重构把预算求解器放到了能真正 干活的位置上。

2. 为什么是这个算子,为什么写这篇博文

数学:

y = scaled_mma(act₄, wgt₄) · wcscale + bias + lora_act_in @ lora_up

输入是 NVFP4 打包格式(act, wgt: [M, K/2] uint8,每字节存两个 E2M1 nibble;ascales, wscales: [K/16, *] FP8-E4M3,按 16 个 K 一块的 scale)。 lora_act_in @ lora_up 是小秩 R 残差(生产里 R ≤ 128,最常见 R=32)。 wcscalebias 都是按输出列的。没有链式数据流,没有 softmax,没有 在线修正:一个主 MMA、一个 LoRA MMA、一个融合仿射。

下面两条设计约束决定了后面所有内容:

铺垫到这里。这篇博文要立的编辑主张:在 Blackwell 原语教学这件事上, 这个算子比 FA4 更适合做教材。FA4 的在线 softmax 和 S→P→O 链式数据流 都有真正的认知税 —— FA4 的复杂性大半其实不在 Blackwell,而在 注意力本身。SVDQuant W4A4 把那一层剥掉了:同样的 warp 专用化主循环、 同样的持久化 tile scheduler、同样的 tcgen05 累加器、同样的 TMA 捆绑、 同样的 2-CTA 切分 —— 但数学只占一屏。要想通过读一个真正的生产内核去 学 Blackwell 原语,这个算子是更干净的读物。

3. 这个内核用到的 Blackwell 原语

默认读者会 CUTLASS 2.x + CUDA。下面是 SM_100/SM_103 新增的部分,大致按 内核里实际碰到的顺序来。

3.1 tcgen05.mma scaled-MMA 与 NVFP4 atom

NVFP4 是块缩放 FP4:两个 E2M1 nibble 打包到一个字节作为数值,再加每 16 个 K 元素一个 FP8-E4M3 的 scale。算上块 scale 后有效精度约 7 bit。 Blackwell 的 tcgen05.mma.kind::mxf4nvf4.block_scale.scale_vec::4X atom 两个 packed operand 加两个 scale tensor 同时进来,输出一个 FP32 累加器 落到 TMEM。

CuTe DSL 通过 make_blockscaled_trivial_tiled_mma(...) 把这条暴露 出来。值得知道的:它在 Blackwell 上只暴露 MXF4、NVFP4、MXF8 三种 scaled-MMA —— NVFP4 落地的同时 INT4 scaled-MMA 在 ISA 层就被砍了。 (Ascend 的 cube unit 仍有 INT4 MMA,这也是为什么仓库的 Ascend pod 保留 INT4 而 CUDA pod 走 NVFP4 —— 框架层数学一致,内核层按硬件特化 格式。)

atom 通过 tiled_mma.set(tcgen05.Field.SFA, …).SFB 两个运行时 入口接受 scale。scale 住在 TMEM(不是 SMEM):内核每 K-block 工作 量 cute.copy 一次 SMEM → TMEM,然后再发 gemm。用法在 kernel_v2_fa4.py:1339-1346

tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator)
tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator)
cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)

前三行是 Python trace 期的对 tiled_mma 对象的状态修改 —— 它对随后 那个 cute.gemm 在 MLIR 里捕获时生效。第四行才是真正在设备上发出去的 umma.commit

关于这里说的 “NVFP4” 与 cuBLAS NVFP4 linear 的差:完整 NVFP4 spec 是两级 scaling —— 一个 per-tensor FP32 scale,再加一个每 16 个 K 元素一个 FP8-E4M3 block scale。nunchaku 的设计选择是只用一级:block scale,任何 per-tensor scaling 都在离线 calibration 时折进 block scale (或折进 wcscale)。我们这个内核沿用 nunchaku 的同一套数学,因此也是 单级 NVFP4。cuBLAS 的 NVFP4 linear 则在运行时把两级都暴露出来。在 per-tensor scale 离线折好的前提下两者数学等价;差别在 spec 把什么带 到运行时 API,不在可达精度。我们跟 nunchaku 走是因为 LoRA + wcscale 那一套本来就自然吸收了 tensor 级 scale。

3.2 2-CTA dense MMA via cta_group=TWO

cluster_shape=(2, 1) cluster 里的两个 CTA 协同处理一个更大的 tile。 atom 用 CtaGroup.TWO 构造,会在 MMA 的线程布局里插一个大小为 2 的 V(volume)维。pair 里每个 CTA 各持有 cluster 级工作的一半,但 leader CTA 发出去的每一条 MMA 两个 CTA 都参与。

cluster layout 因式分解成 (V, M, N, K)

cluster_shape_mn = (2, 1),CtaGroup.TWO:
  cluster_layout_vmnk.shape = ((2,), 1, 1, 1)
  rank=0 → flat coord (0, 0, 0, 0)   ← leader CTA
  rank=1 → flat coord (1, 0, 0, 0)   ← follower CTA

(在 2-CTA 下从 cluster_layout_vmnk 里读哪个 index 才能恢复每个 CTA 的 M 位置,是那种属于”代码理解 gap”的坑、不该塞进讲原语的正文里;写在 docs/gotchas/cute_dsl.md:90-151,要看自取。)

SMEM 红利来了。在 CtaGroup.TWO 下,MMA atom 的 partition_shape_A 会沿 M halve A,partition_shape_B 沿 N halve B。每 CTA 只需要装 1-CTA atom 一半的 operand SMEM —— 这就是 Modular matmul-on-blackwell- part-3 那篇里说的 “2xSM MMA: Shared Memory Optimization”。CUTLASS 的 dense_blockscaled_gemm_persistent.py 用它,v2_fa4 主路径在 A、B 上也 用它(kernel_v2_fa4.py:465-468)。LoRA 路径里的 LU 算子本来也该 用它 —— 那是 §7 的事。

2-CTA cluster:A 按 M 切、B 按 N 切,每 CTA 只持有一半的 operand SMEM

3.3 TMEM —— 可寻址的累加器空间

Blackwell 之前,MMA 的累加器在寄存器里,靠 mma.sync PTX 或 cute::gemm 搬。Blackwell 上累加器住在 tensor memory(TMEM) —— SM 局部的一块 内存,自带分配器(utils.TmemAllocator)、自带释放 barrier、自带 512 列宽的布局。两个直接后果:

SM_100 上 TMEM 预算最多 512 列。NVFP4 block-scaled MMA 在 256×128 tile 上要 128 列累加器 + 16 列 SFA + 32 列 SFB ≈ 176 列;把累加器翻倍 (next-tile ping-pong 的 overlapping_accum)在 tile_n=128 还能塞下、 在 tile_n=256 就崩。所以两边(CUTLASS 参考和我们)都写 num_acc_stage = 1 if tile_n == 256 else 2(gotcha 在 docs/gotchas/cute_dsl.md:231-287)。

3.4 TMA、extra_tx_count 捆绑、is_leader_cta

TMA 拷贝是异步的,通过 mbarrier arrive 信号完成。每次 TMA 给 barrier 的 expected_transactionstx_count)加上它送达的字节数;累加到阈值, barrier 翻 full,consumer 可以放行。CuTe DSL 的 pipeline.PipelineTmaUmma.create 把这一套打包成 producer/consumer pattern。

包装里有两个 SM_100 专用旋钮:

3.5 StaticPersistentTileScheduler

套路:launch min(num_tiles, sm_count) 个 CTA,每个 CTA 沿 tile_idx += grid_dim() 走 tile,越界就退出。省 launch 开销、跨 tile 保 warp 热、TMA pipeline 跨 tile 带状态。

实现量很小 —— FA4 的 tile_scheduler.py::StaticPersistentTileScheduler ~30 行,我们的等价物 inline 在 kernel_v2_fa4.py:885-892。难的不是 scheduler,是内核里其它部分的状态能不能跨 tile 边界活下去 —— 那是 §5 的内容。

3.6 Warp 专用化(预告)

一个 warp 做 TMA load(load warp),一个 warp 发 MMA(mma warp), 四个 warp 跑 epilogue(epilogue warps)。一共 6 warp × 32 线程 = 每 CTA 192 线程。每个 warp 有自己的 pipeline 状态,独立推进 —— 没有”内核 全局状态”这个东西。

这是 FA4 给 tcgen05 + TMEM 这一族内核定下的结构性 pattern。完整图景 (3 pipeline × 3 warp group + 每 warp PipelineStateSimple)在 §6。

4. v1 —— FA4 之前的基线

cute_kernels/gemm_w4a4/kernel.py 是 v1:主 NVFP4 scaled-MMA + β- interleaved LoRA,只跑 1-CTA、单体 @cute.kernel、stock cutlass.pipeline.PipelineState。docstring 把出身写得很坦白:从 cutlass/examples/python/CuTeDSL/blackwell/ dense_blockscaled_gemm_persistent.py 移植过来,剥掉了持久 TileScheduler、 clusters > 1、TMA multicast、overlapping_accumtile_n ∈ {64, 192} 的 SFB-shift 黑魔法。两个 shape 自适应的 1-CTA tiler:小 M 用 (128, 128),其余用 (128, 256)

v1 是的第一步:找一份已知能跑的 CUTLASS 例程,留下里面明显 可泛化的部分,把 LoRA β-interleave 接到同一个 TMEM 累加器上 (kernel.py:30-33 引用了 TV-layout match 校验),先发一版数值干净的 内核,再谈优化。

4.1 v1 干得不错的地方

4.2 v1 撞墙的地方 —— vs CUTLASS 同硬件同 shape

同 B200、同 shape,v1 对 CUTLASS 自家 dense_blockscaled_gemm_persistent.py(只跑主 NVFP4 MMA,不带 LoRA; 严格对应我们 v0 干的事):

shape (M, K, N) CUTLASS 1-CTA 128×256 CUTLASS 2-CTA 256×128 CUTLASS 2-CTA 256×256 ours 1-CTA ours 2-CTA Phase 1
4352 × 3840 × 3072 3847 TF 38.5 % 4202 TF 42.0 % 4545 TF 45.4 % 1309 TF 13.1 % 1185 TF 11.8 %
4352 × 3840 × 15360 4167 TF 41.7 % 5181 TF 51.8 % 5836 TF 58.4 % 2735 TF 27.4 % 2599 TF 26.0 %
4352 × 15360 × 3840 4096 TF 41.0 % 5903 TF 59.0 % 6339 TF 63.4 % 2646 TF 26.5 % 2964 TF 29.6 %
4352 × 10240 × 3072 4174 TF 41.7 % 5375 TF 53.8 % 6074 TF 60.7 % 2299 TF 23.0 % 2350 TF 23.5 %

(出处:cute_kernels/gemm_w4a4/README.md:154-160。CUTLASS 列跑的是 v1 在 v0 模式下做的事 —— 只主 NVFP4,所以比较是 apples-to-apples。)

表里有两条不好辩的事实:

  1. 同 tile(128×256,1-CTA)下,CUTLASS 比我们快 ~14 pp。持久 scheduler、多 stage MMA/epilogue overlap、更细致的 pipeline 纪律。 这些在 v1 架构里不是做不出来 —— 就是没做。
  2. 2-CTA Phase 1 基本拿了零。我们的 2-CTA 列每个 shape 都落在 1-CTA 列 1–3 pp 范围内,有的还更慢。CUTLASS 的 2-CTA 256×128 列 —— FLOPs- per-atom 1-CTA 128×256 一样 —— 在同硬件上抬升 10–18 pp。 2-CTA 的机制对 CUTLASS 是 work 的;在 v1 架构里就是 work 不起来。

4.3 诊断

只要尝试把 v1 抬到 2-CTA 持久化,内核需要跟踪的状态空间就沿五个维度 同时扩张:

  1. Pipeline stage。N 个 A/B SMEM stage,每 stage 一个 mbarrier、 一个 phase bit、一个 index。
  2. 2-CTA pair barriercta_group=TWO 下每个 TMA barrier 都要知道 cluster 里两个 CTA,过 is_leader_cta 闸,tx_count 里要烘进 cta_group_size
  3. 持久 tile loop。tile 边界不排干 pipeline;状态从 tile N 活到 tile N+1。
  4. LoRA β 第二条 MMA。第二条 atom 插进主 K-loop,自己有 LoRA 前奏 SMEM 的 producer/consumer 周期。
  5. Epilogue 校正链。最终在 epilogue 里融合 × wcscale + bias, per-column 因子自己有 SMEM 排级。

stock cutlass.pipeline.PipelineState 是隐式的(状态藏在 PipelineTmaUmma/PipelineUmmaAsync 包装的 advance() 方法里随调用 演化)、分支的(每次调用根据当前 phase 走不同代码路径)、单维度的 (每个 pipeline 一个 PipelineState,每个 warp 角色一个 pipeline)。 它干净处理维度 1。它不和维度 2–5 组合 —— 经验证据很锋利:之前的一个 持久化移植(kernel.py-类,提交 61905df在 1-tile-per-CTA 下数值 正确,到每 CTA 处理 ~20 tile 时挂死 500×。这是 phase/state 跨 tile 边界漂移的经典签名:状态机对第一圈正确,从此累积误差。

修法不是”把 v1 调得更狠”,是把状态机整套换掉。下一节。

5. 为什么是 FA4 —— 我们采用的脚手架

FA4(Flash-Attention 4 在 Blackwell 上的前向,源码在 flash-attention/flash_attn/cute/)在另一个算子上已经解决了上面 那个 5 维状态空间问题:注意力。解法是把 pipeline 状态显式化、per- warp 化,把持久 tile loop 拎成内核最外层结构,把每个 Blackwell 专用 的脚滚雷因式分解成有名字的原语。我们没拿走 FA4 的数学 —— 我们的 算子里没有在线 softmax、没有 S→P→O 链式数据流、没有 Q/K/V 切分 —— 但我们整套地接过了脚手架

5.1 从 FA4 接过来的部分

5.2 从 FA4 拿的部分

这是本节的编辑取舍。FA4 是注意力内核;它的复杂性大半是注意力的复杂性。 具体:

结果是:把 FA4 脚手架适配到 SVDQuant W4A4 反而剥掉了 FA4 较硬的那 部分。如果你读过 FA4 并理解了 warp 专用化 pattern,这个内核是更干净的 第二个示例 —— 动件少,但原语一致。

6. v2_fa4 —— 重写

当前的文件是 cute_kernels/gemm_w4a4/kernel_v2_fa4.py。它是 FA4 衍生 重写的第三个真实迭代:v0_fa4(无 LoRA,仅脚手架)、v1_fa4(= v0_fa4 + 单 stage LoRA,作为参考数字保留在隐藏代码路径)、v2_fa4+C1(= v1_fa4 + 2-stage LoRA 前奏 + × wcscale + bias 融合 epilogue + §7 那条 LU SMEM fix)。出货是 LU 修复后的 v2_fa4+C1;下文描述的就是这个出货面。

v2_fa4 warp 分工:load → mma → epilogue,三条 pipeline 跨 tile 边界从不重置

6.1 v0_fa4 —— 不含 LoRA 的脚手架

FA4 衍生分支的第一次提交是 kernel_v0_fa4.py:FA4 骨架,无 LoRA, 无 wcscale/bias。目的:在重新插 LoRA 之前先把新的状态机器单独验通。

数字,生产 shape M=4352 K=3840 N=3072 fp16:

  1-CTA 2-CTA
v0_fa4 7.7 % 7.6 %

(出处:cute_kernels/gemm_w4a4/README.md:183-189。)比 v1 的 27 % 低 —— 但这是意料之中的。v0_fa4 是个 partial-feature 脚手架;多 stage pipeline 还没调,overlapping_accum 也还没接,墙钟里全套 FA4 pattern (多 pipeline 初始化、tile scheduler 开销、warp 专用化 barrier 集合)的 费用都还没被后续优化摊掉。我们把它冻起来当 v0/v1 reference(用 enable_lora flag 闸住,同一份文件能两用),就往前走了。

v0_fa4 在设备上的第一次跑出了整个项目最干净的一条 bring-up bump,单开 一节。

6.2 Bump(inline):首次 smoke 9 分钟挂死

症状:在 Modal 上 launch、nvidia-smi 显示 GPU 在忙、9 分钟没有 stdout、然后 Modal 容器超时。没有 abort、没有 assert、没有 PTX 错 —— 就是干净地卡住。

原因(根因分析在 docs/kernels/gemm_w4a4_fa4_v0_bringup.md:27-44):MMA warp 的 单 stage pipeline_acc 的 producer phase 初值是 Int32(0)pipeline_init_arrive 在内核启动时跑,把 empty mbarrier 预先 arrive 到 parity 1。MMA warp 调用 producer_acquire with phase 0 —— 这意思是 “等 barrier 翻到 parity 0”。但 consumer(epilogue warp)还没跑、barrier 还在 parity 1、MMA warp 死等。

修法:把 acc_producer_phase 初值改成 Int32(1)。这跟 stock cutlass.pipeline.make_pipeline_state(Producer, ...) 底下返回的一致, 也跟 FA4 在 load() 注释里写的 “single-stage producer starts at 1” 对得上。两字符 patch:

# kernel_v2_fa4.py:1247-1253
# Single-stage pipeline_acc — phase bit only (XOR toggle).
# Producer starts at phase=1: `pipeline_init_arrive` pre-arms
# the empty barrier to parity=1, so the first `producer_acquire`
# with phase=1 returns immediately. Starting at 0 blocks forever
# (consumer never flips full, kernel hangs — was the 9-min hang
# on first smoke). Mirrors stock `make_pipeline_state(Producer)`.
acc_producer_phase = Int32(1)

从 bring-up 里带走的教训:在显式 per-warp pipeline 状态下,初值 不变量是你自己的事。没有 cutlass.pipeline.make_pipeline_state 替你 搞这事;包装也不替你搞。初值 phase 错一个 bit,内核就静默挂死。 bring-up 文档里还列了它的兄弟问题(Bug 2/3,跨持久迭代的 ACCUMULATE 状态 trace 冻结,修法是把 K-tile 循环写成 Python range() 全展开而 不是 cutlass.range(unroll=1))。

6.3 重新加上 LoRA —— 共享 TMEM 上的 β-interleave

LoRA 校正项 lora_act_in @ lora_up 很小(R ≤ 128)。如果和主 MMA 串行 跑(”α” 变体),在最差生产 shape 上墙钟会膨胀 ~50 %,因为 tcgen05 异步发射队列深度 4–8,atom 数量极少的 LoRA pass 喂不满它(完整分析在 docs/kernels/gemm_w4a4.md:26-52)。修法是 β:把 LoRA atom 撒进 主 K-loop 的发射流里,让 pipe 永远不只见到 LoRA。

机制立在三条 Blackwell 事实上:

  1. tcgen05 在 SASS scoreboard 这一层尊重累加器的 RAW 依赖。两条 atom 写同一组 TMEM cell;ptxas 把这识别为数据依赖,在指令控制字 里发出 scoreboard write/wait 编码,LoRA atom 的 UTCHMMA 必然要等 前一条主 UTCOMMA retire 后才完成。不需要软件 fence。(这条值得 单独点出来,因为 PTX 手册 9.7.16.6.2 给的 “pipelined pair” 集合 要求 kind/shape/acc 都同 —— 我们这对 kind 不同,落在集合外;粗 读容易解读成”切 kind 要插 tcgen05.fence“,其实不需要。pair 规则约束的是两条 atom 能否在 tensor pipe 里重叠流水,不是它 们对共享 acc 的写是否保持顺序;顺序由 PTX 之下一层的 SASS scoreboard 保证。对着 cubin 验过:PTX 里 0 个 tcgen05.fence,SASS 里 UTCOMMA/UTCHMMA 之间 0 个 FENCE.*;CUTLASS 上游也是同套 模型。)
  2. 两条 atom 可以打到同一个 TMEM 地址。通过 gemm_ptx_partial(acc_tmem_addr: Int32),两条 atom 写到同一组 FP32 累加器 cell。不需要 cute.Tensor 别名套路(v1 走的别名路; 能跑,但更难维护)。
  3. TV-layout match。主 NVFP4 atom 和 LoRA fp16/bf16 atom 把 per-CTA cta_tile_shape_mnk 切到 per-thread 寄存器 fragment。β 要 work, 线程 t 在两条 atom 下的”第 i 个元素”必须落在同一个 TMEM cell。 匹配是 trace 期校验的(内核 docstring 在 kernel_v2_fa4.py:1261-1270 引用;原始校验通过 cute_kernels/gemm_w4a4/verify_tmem_layout.py 在 bring-up 期跑, 1SM 128×2562SM 256×256 都验过)。

真正付的代价是 kind 切换边界丢掉流水重叠 —— LoRA atom 和它两侧的主 atom 不能在 tensor pipe 里并发,只能串行。在生产密度下看不见(R=128 时大约每 stride = K_atoms / R_atoms ≈ 7–8 个主 atom 才撒一个 LoRA): 能抓到这件事的指标 stall_short_sb 在 v2 上测出来 0.42 cyc/inst, 和 CUTLASS 不带 LoRA 的 0.55 没区别(log/ncu_summary.md)。反过来 如果不共享 acc —— 主 MMA 和 LoRA 写两份独立累加器再合并 —— 每个 tile 要多走一次 TMEM-load + TMEM-store,代价远高于这几拍丢掉的重叠。

interleave pattern 本身就是 K-loop 主 atom 每次多一个分支。 stride = K_atoms // R_atoms 控制 LoRA atom 多久发一次;r_nextnext_lora_at 跟踪当前哪个 LoRA atom 上场、下次发在哪。源码在 kernel_v2_fa4.py:1309-1376

β-interleave:(A) 主 NVFP4 atom 和 LoRA fp16/bf16 atom 都通过 `acc_tmem_addr: Int32` 打到 (B) 同一块 FP32 TMEM 累加器(加粗框出的共享块);(C) tcgen05 发射顺序里每个 atom 都编号 —— 主在 1/2/3/5/6/7/9/10/11、LoRA 每 `stride = K_atoms/R_atoms` 撒一个在 4/8/12;(D) tensor pipe 时间线显示同 kind 的 atom 互相重叠(1↔2↔3),但每次 LoRA 注入要付**两个** retire 气泡 —— 每个 LoRA 两侧各一条粉色"无重叠"带(主→LoRA 抽干 pipe、LoRA→主 重新填满);幻影条标出如果能流水的话该 atom 本应在哪

值得理解的 MLIR trace 细节:tiled_mma.set(tcgen05.Field.ACCUMULATE, ...)Python trace 期对对象的修改。每个 cute.gemm 调用点在 trace 时捕获当时的字段值;运行时不会再执行 setter。所以 K-tile 循环 必须是 Python 全展开(for k_tile in range(k_tile_cnt):,而不是 cutlass.range(unroll=1)),因为后者 trace 一次循环体、然后会把第一个 kblock 位置上的 ACCUMULATE=False 捕获给每个 tile 复用 —— 这会在 每个 tile 边界把累加器擦掉。当前内核正是因为这条用的 Python range, 注释在 kernel_v2_fa4.py:1294-1306

6.4 Bump(inline):2-CTA LoRA 回归

把 LoRA 用单 stage 前奏接回 FA4 骨架之后(我们记作 v1_fa4 的配置), 2-CTA 路径反退

(M=4352 K=3840 N=3072 R=128 fp16) v0_fa4(无 LoRA) v1_fa4(1-stage LoRA)
1-CTA MFU 7.7 % (未测量)
2-CTA MFU 7.6 % 6.0 %

(出处:cute_kernels/gemm_w4a4/README.md:185-189。)从 LoRA 到 一个 stage 的 LoRA,2-CTA 路径反而变慢。这是病理性的 —— 即便 差的 LoRA,也该在 TFLOPS 上做加法,不该做减法。

诊断:LoRA SMEM(LA + LU)吃了预算。单 stage LoRA 前奏体积够大,让 预算求解器(_compute_stages)把主 K-loop 的 num_ab_stage 让出来换。 主循环 pipeline stage 少 → 在飞 tcgen05 atom 少 → SM% 下来 → 墙钟 上去。修法有两块;显而易见的那块在 §6.5,真正大头的那块在 §7。

6.5 C1 —— 2-stage LoRA 前奏

num_lora_stage 从 1 抬到 2(任务追踪里的 C1)。两个 LA/LU 缓冲 ping-pong。代价:LoRA SMEM 翻倍。收益:前奏成本被摊到更多主 MMA 迭代 上,预算求解器把主路径的 stage 还回来一些,回归解开。

数字,LU SMEM 修复(即 C1 单独的贡献):

shape (M=4352, K, N, R) v1_fa4(pre-C1)2-CTA v2_fa4+C1 2-CTA Δ
K=3840 N=3072 R=128 6.0 % 14.2 % +8.2 pp
K=3840 N=15360 R=128 15.2 % 18.6 % +3.4 pp
K=15360 N=3840 R=128 17.0 % 18.1 % +1.1 pp
K=10240 N=3072 R=32 11.6 % 26.1 % +14.5 pp

(出处:cute_kernels/gemm_w4a4/README.md:185-189。)C1 把”2-CTA LoRA 反而比 1-CTA 更费”这条反常消掉 —— 每个 shape 至少 +1 pp,最差 shape (小 N 或小 R)拿到两位数 pp。

ncu 机制,在 Verda B200(counter 不受限)的生产 shape 上抓的:

指标 v2 stage0(LoRA off) v2 stage1(pre-C1) v2 stage2(C1)
duration(µs) 42.0 77.1 69.6
SM throughput % 52.3 54.6 41.2
hmma 子管线 %(NVFP4 tcore) 60.5 31.8 34.9
warp cycles / issued inst 15.0 18.6 25.9
long_scoreboard cyc(L1TEX) 10.6 13.8 21.8

(出处:cute_kernels/gemm_w4a4/README.md:209-217。)三条读法:

顺手说一句:pre-C1 v1_fa4 → v2_fa4+C1 在最小 shape 上 ~2.4× 的加速并不 全是 C1 的功劳 —— 大头来自跟 C1 同一窗口落地的 LU SMEM 修复。C1 单独 的贡献就是这次 ncu A/B 给出的 −9.7 % / +3.1 pp。这是 §7 的有用背景。

6.6 融合 × wcscale + bias epilogue

v2_fa4 比 v1_fa4 多的最后一件事是把按输出列的仿射折进 epilogue warp。 数学:

y[m, n] = acc[m, n] * wcscale[n] + bias[n]

wcscalebias[N] 张量、c_dtype 形式进来。epilogue warp 把 TMEM → 寄存器 → mul-add → SMEM → GMEM 经 TMA store。SMEM 成本可忽略 (tile_n × c_dtype.width/8 = 256 或 512 B 每 buffer;v2 最多两 buffer), 账目算在 wcbias_smem_byteskernel_v2_fa4.py:449-453)里。 pipeline_acc 的 consumer 端增加了 store 前读广播因子的工作;producer 端不变。

为什么不另起一个 epilogue pass 而是折进去:省一次 TMEM → SMEM → 寄存器 来回、省一次 TMA store、省一组 mbarrier。代价:~80 行 epilogue-warp 代码。

7. 静默的 SMEM 预算 bug —— LU ÷ cta_group_size

英雄发现。单行 patch,在生产 shape 上 +198 % TF,整个发现过程是”我写了 一个 cute.cosize probe,2 分钟跑完”。这是让这篇博文值得写的一节。

7.1 手写公式

Sm100GemmW4A4V2FA4._setup_attributes 估算 LoRA 前奏需要多少 SMEM 字节, 以便 _compute_stages 在决定塞几 stage 主 K-loop 之前先从每 SM SMEM 预算里扣掉。修复前的算式:

# kernel_v2_fa4.py —— 手写,修复前
la_bytes = mma_inst_shape_mn[0] * R * lora_ab_dtype.width // 8 // cta_group_size
lu_bytes = mma_inst_shape_mn[1] * R * lora_ab_dtype.width // 8     # ← bug
lora_smem_bytes = (la_bytes + lu_bytes) * num_lora_stage

LA 是 LoRA-down 激活,维度 [mma_tile_m, R]。LU 是 LoRA-up 权重, 维度 [mma_tile_n, R]。两者在 2-CTA 下都喂一个用 cta_group=TWO 构造 的 LoRA MMA atom。

LA 正确地除了 cta_group_size,因为 LoRA atom 用 partition_shape_A 沿 M 切(M-shard,跟主 A 同机制)。cluster shape 是 (2, 1),所以每 CTA 持有一半的 M。

LU cta_group_size。手写公式假设每 CTA 持有完整的 mma_tile_n × R 的 LU SMEM。

7.2 为什么这是 bug

CtaGroup.TWO 下,2-CTA dense MMA atom 同样沿 NB 切到 V 伙伴上(N-shard,via partition_shape_B)。这正是 Modular matmul-on- blackwell-part-3 那篇 “Shared Memory Optimization” 一节里说的 “2xSM MMA halves the B tile” 优化;CUTLASS 在 dense_blockscaled_gemm_ persistent.py 里默认就这么做、连注释都没写,因为这是 partition_shape_B 的默认行为。

sm100_utils.make_smem_layout_b(tiled_mma_2cta, ...) 返回的 per-CTA SMEM layout 已经是 tile_n × tile_k 的一半。所以当 LoRA 的 make_smem_layout_b(...) 构造 LU layout 时,LU layout 已经是 per CTA 一半大。手写估算重复算了一次。

7.3 为什么症状是”什么都没有”

这是危险点。LoRA SMEM 估算偏大不会 crash —— 它会让预算求解器悲观。 求解器以为 LoRA SMEM 比实际多吃 16 KB,于是从主路径里拿走 16 KB,把 num_ab_stage 从 4 钳到 2。内核 trace 通、跑得动、数值对,只是主 K-loop pipeline 深度被砍掉一半。

没有 assert、没有 shape mismatch、没有分配失败。_compute_stages 的 打印(如果你打开)说”stage=2 fits” —— 因为按悲观预算就只有 stage=2 能 塞。内核行为里任何指向这条 bug 的迹象。墙钟是”慢但能跑”;ncu 说 “低 SM%、高 long_scoreboard“;你花一周去调 num_lora_stage 和 tile 几何,全都不动。

7.4 两分钟的 probe

发现 bug 的故事是值得带走的部分。cute.cosize 在 trace 期工作,返回 Int32,给出 layout 的实际 SMEM cosize —— 正是手写公式想估的量。在 _setup_attributes 里丢一个 print:

print("la_one =", cute.cosize(slice_(self.la_smem_layout_staged,
                                     (None, None, None, 0))))
print("lu_one =", cute.cosize(slice_(self.lu_smem_layout_staged,
                                     (None, None, None, 0))))

捕获到的输出(生产 shape,R=128、fp16、2-CTA):

[PROBE96] num_lora_stage=2 cta_group_size=2
[PROBE96] la_one cosize=16384 -> 32768 B (handwritten 32768 B, factor 1.000)
[PROBE96] lu_one cosize=8192  -> 16384 B (handwritten 32768 B, factor 0.500)

LA 跟手写值对得上(factor 1.000)。LU 是恰好一半(factor 0.500)。 bug 找到,120 秒。

7.5 修复

多写一个 // self.cta_group_size

# kernel_v2_fa4.py:429-444 —— 修复后
lora_smem_bytes = 0
if cutlass.const_expr(self.enable_lora):
    la_bytes = (self.mma_inst_shape_mn[0] * self.R
                * self.lora_ab_dtype.width // 8) // self.cta_group_size
    lu_bytes = (self.mma_inst_shape_mn[1] * self.R
                * self.lora_ab_dtype.width // 8) // self.cta_group_size
    lora_smem_bytes = (la_bytes + lu_bytes) * self.num_lora_stage

提交 7296e90。429-444 行的注释引用了 probe 工件和 docs/gotchas/cute_dsl.md:289-347 的 gotcha 条目。

7.6 生产 shape 上的 before/after

同一个 bench_gemm_v2_fa4_c1.py benchmark、fp16、2-CTA、 M=4352 K=3840 N=3072 R=128,pre-fix 在 B300 上、post-fix 在 B200 上(绝对 TFLOPS 跨卡可比;B200 的 NVFP4 峰值更低,所以”同 TF”也意味着我们在更 弱的卡上更快了):

指标 pre-fix post-fix Δ
TFLOPS 566 1685 +198 %
MFU(B200 10 PFLOPS NVFP4) 4.2 % 16.9 % +12.7 pp

(出处:docs/gpu.md:286-296。)同一台 Verda B200 上的 ncu A/B (HEAD^ vs HEAD = 提交 7296e90,在两次跑之间用磁盘上换 kernel, num_lora_stage=2,单次 launch):

指标 pre-LU-fix post-LU-fix Δ
Duration 46.69 µs 32.13 µs −14.56 µs / −31.2 %
Compute (SM) % 41.63 53.62 +11.99 pp
Memory % 25.58 38.91 +13.33 pp
L1/TEX Cache % 28.50 44.75 +16.25 pp
L2 Cache % 24.57 36.18 +11.61 pp
DRAM % 5.04 7.31 +2.27 pp
SM Active Cycles 72 433 46 126 −36.3 %
Memory Throughput 386 GB/s 561 GB/s +45 %
Grid / Block 148 / 192 148 / 192 完全一致

(出处:docs/gpu.md:393-403。)读起来很干净:同样的 launch shape、 同样的 occupancy,但 num_ab_stage 抬到 4 后 SM 侧 pipeline 喂饱 → SM% +12 pp、SM Active Cycles 砍 36 %。L1/TEX 和 L2 吞吐成比例上升, 因为 TMA producer 现在有更多在飞 buffer 要填 —— 不是”省带宽”,是 “带宽在整个内核墙钟上更均匀”。DRAM 仍低(compute-bound 仍成立)。

7.7 为什么这能泛化 —— 教学内容

bug 是具体的(lu_bytes 多算一倍)。pattern 是普适的:任何手写 SMEM 预算算式,只要它喂的是 stage 求解器、且对应的 operand SMEM 是 make_smem_layout_{a, b}(tiled_mma_2cta, ...) 出来的,就必须沿那个被 切分的轴除一遍 cta_group_size。A 是 M-split(partition_shape_A 沿 M halve),B 是 N-split(partition_shape_B 沿 N halve)。2-CTA 下 两者都 per CTA halve,只是沿不同轴。

为什么会有人手写预算:_compute_stages 需要在 operand SMEM 真正分配 之前给出字节估算(layout 取决于 stage 数,stage 数取决于预算 —— 循环依赖)。手写公式是用来打破循环的,但在非主 operand 上很容易把 cta_group 的切分搞错。

更稳健的替代:先把 layout 建出来,回读 cute.cosize,用读回来的值做 预算输入。多写一点代码,但跟硬件真相一致。两条路都行;要规避的失败 模式是”手写公式 + 没人对照 cute.cosize 做交叉验证”。

docs/gotchas/cute_dsl.md:289-347 把这条记作未来自己看的一条 pattern, 带 probe 模板、症状描述(”no assert fires; numerics are still correct; perf is just lower than it should be”)、apply 指南(”任何地方你手写的 SMEM 预算估算,只要 operand 来自 make_smem_layout_{a, b}(tiled_mma_2cta, ...),就除一遍 cta_group_size。A 和 B 在 2-CTA 下都 halve,只是沿 不同轴。”)。

8. 用 Blackwell 内核作者的眼光读 ncu

LU SMEM 这个发现如果没有 ncu 就只是墙钟噪声 —— 47 µs 里有 14 µs 是 真的;但在 Modal(ncu 被封,见下)上你只会看到一个”慢”的墙钟和一份 “正常”的 torch.profiler activity,结论是”内核需要 tuning”,没有任何 具体方向。C1 的机制故事(前奏摊销 vs 延迟修复)就更难没 ncu 读出来 —— duration 下来了你就出货了,永远不知道每 warp 的 long_scoreboard 其实涨了。所以这一节是方法论小结。

8.1 计数器访问 —— Modal 封、Verda 开

CLAUDE.md 执行环境一节写到的分工对任何要复现或扩展这项工作的人都 关键:

8.2 最常被复制粘贴的一条 —— hmma 是 NVFP4 tensor pipe

tcgen05 UTCQMMA 在 ncu metric tree 上跑在 hmma 子管线上。没有 独立的 qmma_* counter。grep qmma 拿不到任何东西,一下午就没了。 你要的 metric 是:

sm__pipe_tensor_subpipe_hmma_cycles_active.avg.pct_of_peak_sustained_active

覆盖 HMMA + UTCHMMA + UTCQMMA + UTCOMMA 一起。按累加器 dtype 拆 FLOPs:

sm__ops_path_tensor_op_utcqmma_src_fp4_fp6_fp8_dst_fp32
sm__ops_path_tensor_op_utcqmma_src_fp4_fp6_fp8_dst_fp16
sm__ops_path_tensor_op_utcomma_src_fp4_dst_fp32     # 单独的 FP4-only 路径

--section ComputeWorkloadAnalysis 自动拉出子管线分解。UTCQMMA 工作 在 SOL “Compute (SM) Pipe Utilization” 面板里挂在 “HMMA Pipe” 下。完整 清单见 docs/gpu.md:105-127

8.3 2-CTA UMMA 内核要读的 SOL 分解

§6.5 的 C1 ncu 表是模板。读法:

要学的 pattern:hmma % 比 CUTLASS reference 低,跟 long_scoreboard 周期高,不是同一个问题。前者说”tensor pipe 闲”,后者说”warp 没活 可发”。两条可以同时成立,修法不同。

8.4 trace 期 cute.cosize probe 模式

解锁 LU SMEM 发现的唯一工具。两块机制:

  1. CuTe layout 在 trace 期就知道 cosize。表达式 cute.cosize(layout) (或切过的 layout,比如 slice_(self.lu_smem_layout_staged, (None, None, None, 0)) 砍掉 stage 维)。
  2. _setup_attributes(trace 期执行)里 print 它,值会在内核运行 之前打到控制台。所以不需要设备侧仪表;诊断在 trace 期就 surface 出来,一行 print。

模板,从 docs/gotchas/cute_dsl.md:319-324

print("la_one =", cute.cosize(slice_(self.la_smem_layout_staged,
                                     (None, None, None, 0))))
print("lu_one =", cute.cosize(slice_(self.lu_smem_layout_staged,
                                     (None, None, None, 0))))

逐一对比手写值。比 1.0 → 一致。比 0.5 或 2.0 → operand 在某个你没算到 的轴上被切了(或没被切)。2 分钟,可发现整类”手写预算偏差”bug。

9. 校准 —— 这个内核实际站在哪里

两个参考点;它们告诉你不同的事。

9.1 诚实的上限 —— 同 B200 上的 CUTLASS NVFP4

CUTLASS 自家 dense_blockscaled_gemm_persistent.py 是纯主 NVFP4 MMA (无 LoRA、无 wcscale、无 bias)。同样的 atom,同样的硬件。在生产 shape 行(M=4352 K=15360 N=3840,K-heavy 那一条):

变体 MFU
CUTLASS 1-CTA 128×256 41.0 %
CUTLASS 2-CTA 256×128 59.0 %
CUTLASS 2-CTA 256×256 63.4 %
v2_fa4+C1+LU-fix, fp16 2-CTA 27.3 %
v2_fa4+C1+LU-fix, bf16 2-CTA 27.3 %

(出处:CUTLASS 列 cute_kernels/gemm_w4a4/README.md:156-160; v2_fa4 fp16/bf16 列 docs/gpu.md:316-318。)两条要带走的:

这不是”我们慢”的对比,是”上面还有多少空间、接下来从空间里拿什么”的 对比。

9.2 实现品质参考 —— RTX PRO 6000 上的 nunchaku

nunchaku 的 NVFP4 在 __CUDA_ARCH__ >= 1200(SM_120a/121a,见 nunchaku/setup.py:41-64)这条上 gated,所以我们在 B200 上没法 跑 —— 没有 SM_100 的 nunchaku 二进制。我们在 RTX PRO 6000 Blackwell Server Edition(SM_120a)上跑它作为实现品质参考,不是上限。两边 硬件峰值差 2.5×(B200 10 PFLOPS NVFP4 vs PRO 6000 4 PFLOPS),所以 MFU 对比要待在同一侧的列里才算 apples-to-apples。

Shape (M, K, N, R) ours fp16 (B200) nunchaku fp16 (PRO 6000) Δ pp ours bf16 nunchaku bf16 Δ pp
4352 × 3840 × 3072 × R=128 16.9 16.2 +0.7 17.3 17.7 −0.4
4352 × 3840 × 15360 × R=128 26.5 19.5 +7.0 26.7 24.7 +2.0
4352 × 15360 × 3840 × R=128 27.3 25.0 +2.3 27.3 30.5 −3.2
4352 × 10240 × 3072 × R=32 26.4 21.4 +5.0 26.2 25.2 +1.0

(出处:docs/gpu.md:314-319。)fp16:4/4 shape 全数领先。bf16: 2/4 领先,1/4 是 −0.4 pp 在噪声内,1/4 还落后 3.2 pp(M=4352 K=15360 N=3840)。那一格 −3.2 pp 就是 docs/gpu.md:79-103 说的”bf16 hand-PTX vs DSL MLIR lowering”不对称:nunchaku 的 MMA 是 inline PTX (mma_earlycuda.cuh),fp16 和 bf16 是两套分别手调的 PTX,register packing / 累加器精度不同。我们走一个 tcgen05 atom + ab_dtype 替换 —— 两种 dtype 走同一份 MLIR。要补上 bf16 上最后那 3 pp 大概率得下到 inline PTX,超出范围。

同 shape 的绝对吞吐(B200 vs PRO 6000,峰值比 ~2.5×):

Shape ours TF (B200) nunchaku TF (PRO 6000) 比例
4352 × 3840 × 3072 × R=128 1685 ~648 2.60×
4352 × 3840 × 15360 × R=128 2648 ~780 3.40×
4352 × 15360 × 3840 × R=128 2735 ~1000 2.74×
4352 × 10240 × 3072 × R=32 2645 ~856 3.09×

(出处:docs/gpu.md:330-335。)跨卡数字仅作绝对参考;apples-to-apples 那一条还是看上面那张同列 MFU 表。

关于 nunchaku 的 fp16 列,简短一句: 他们的 hand-PTX fp16 路径打到 255 寄存器/线程 + ~2.28 M LMEM spill + 101 % spill overhead;bf16 路径 248 寄存器、零 spill。那 7 个寄存器的差就是”装得下/装不下”的差,也就是 他们 bf16-over-fp16 那 ~5 pp 跳变的来源。我们不复现这种不对称 —— 单条 tcgen05 atom + ab_dtype 替换走同一份 MLIR 下降,两个 dtype 都一样, 所以我们 fp16 ≈ bf16(四个 shape 里有三个 ±0.1 pp 之内)。这是他们参考 实现的 codegen 性质,不是我们内核的性质;它解释了他们那一列的形状, 不解释我们这一列的位置。

10. 还在桌上的杠杆

按当前认知的 ROI 排:

11. 代码在哪儿,以及致谢

代码:

仓库里相关的文档:

本文引用的关键提交:

交叉链接:同算子的 Ascend(Atlas A3)侧在 csrc/kernels/gemm_w4a4/, 用 INT4 + AscendC;数学一致。格式分立的架构理由在 docs/architecture.mdCLAUDE.md

致谢。感谢 Verda 提供 counter 不受限的 B200 镜像 —— LU SMEM 修复 在受限计数器的主机上只会读作墙钟噪声,C1 的机制分析也 literally 需要 计数器。感谢 Tri Dao 的 Flash-Attention 4,warp-spec 脚手架 pattern 让整个 FA4 衍生重写成立。感谢 NVIDIA 的 CUTLASS 团队 既贡献了 dense_blockscaled_gemm_persistent.py 参考,又在 Modular matmul-on-blackwell-part-3 那篇里把 “2xSM MMA halves the B tile” 机制用大白话讲清楚。

发现了 bug、对不上的数字,或者觉得某个原语解释得不到位?欢迎给本 仓库开 issue,或者直接 PR 进 docs/gotchas/cute_dsl.md —— 这类发现 最终都会归档到那里。