From ad2e0e723176ec7c02f3d2c0184e2e27a55d862b Mon Sep 17 00:00:00 2001
From: Dayuxiaoshui <792179245@qq.com>
Date: Sat, 28 Feb 2026 11:56:50 +0800
Subject: [PATCH 1/2] =?UTF-8?q?[MACA]=20codegen:=20=E6=94=AF=E6=8C=81=20at?=
 =?UTF-8?q?omic=5Fadd=5Felem=5Fop=20/=20atomic=5Fadd=5Fret=5Felem=5Fop?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- 在 VisitExpr_(CallNode) 中为 tl::atomic_add_elem_op 生成 AtomicAdd(...);
- 为 tl::atomic_add_ret_elem_op 生成 AtomicAddRet(...) 表达式
- 修复 gemv、attention_sink 等用例的 MACA codegen
---
Reviewer notes (English). This free-form section sits after the "---"
separator and before the first "diff --git" line, so it is neither part of
the commit message nor part of the change that gets applied.

Commit message in English:
  [MACA] codegen: support atomic_add_elem_op / atomic_add_ret_elem_op
  * In VisitExpr_(CallNode), emit "AtomicAdd(...);" for tl::atomic_add_elem_op.
  * Emit an "AtomicAddRet(...)" expression for tl::atomic_add_ret_elem_op.
  * Fix MACA codegen for cases such as gemv and attention_sink.

What the hunk does:
  * tl::atomic_add_elem_op: prints args[0] and args[1], then writes an
    indented "AtomicAdd(<args[0]>, <args[1]>[, <args[2]>]);" line to
    this->stream (not to the `os` parameter), mirroring the shape of the
    atomic_addx2_elem_op branch that follows it.
  * tl::atomic_add_ret_elem_op: writes
    "AtomicAddRet(<args[0]>, <args[1]>[, <args[2]>])" to `os`, with no
    indentation and no trailing semicolon.
  * In both branches the third argument is appended only when
    op->args.size() > 2.

Open points for the author:
  * NOTE(review): the hunk deletes the comment
    "// atomic_addx2_elem_op(dst_ptr, src_ptr[, memory_order])" from the
    existing branch. This looks unrelated to the stated purpose; consider
    keeping it and adding matching signature comments to the new branches.
  * NOTE(review): the optional third argument is presumably the memory
    order (by analogy with the deleted comment) - confirm.
  * NOTE(review): args[0] and args[1] are indexed without an arity check in
    the visible lines - confirm the argument count is validated elsewhere.

 src/target/codegen_maca.cc | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/target/codegen_maca.cc b/src/target/codegen_maca.cc
index fc221965..6e8ae676 100644
--- a/src/target/codegen_maca.cc
+++ b/src/target/codegen_maca.cc
@@ -1961,8 +1961,23 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) {
     os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")";
   } else if (op->op.same_as(tl::warp_reduce_bitor())) {
     os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::atomic_add_elem_op())) {
+    std::string dst_ptr = PrintExpr(op->args[0]);
+    std::string src_value = PrintExpr(op->args[1]);
+    this->PrintIndent();
+    this->stream << "AtomicAdd(" << dst_ptr << ", " << src_value;
+    if (op->args.size() > 2) {
+      this->stream << ", " << PrintExpr(op->args[2]);
+    }
+    this->stream << ");\n";
+  } else if (op->op.same_as(tl::atomic_add_ret_elem_op())) {
+    os << "AtomicAddRet(" << PrintExpr(op->args[0]) << ", "
+       << PrintExpr(op->args[1]);
+    if (op->args.size() > 2) {
+      os << ", " << PrintExpr(op->args[2]);
+    }
+    os << ")";
   } else if (op->op.same_as(tl::atomic_addx2_elem_op())) {
-    // atomic_addx2_elem_op(dst_ptr, src_ptr[, memory_order])
     std::string dst_ptr = PrintExpr(op->args[0]);
     std::string src_ptr = PrintExpr(op->args[1]);
     this->PrintIndent();
-- 
Gitee

From 16dc6013ccf77c653be6314af7b94d062931ae3c Mon Sep 17 00:00:00 2001
From: Dayuxiaoshui <792179245@qq.com>
Date: Wed, 4 Mar 2026 13:21:32 +0800
Subject: [PATCH 2/2] =?UTF-8?q?[MACA]=20=E6=94=AF=E6=8C=81=20float8=5Fe4m3?=
 =?UTF-8?q?fn/float8=5Fe5m2fn=20=E7=9A=84=20MMA=20Layout=20=E6=8E=A8?=
 =?UTF-8?q?=E6=96=AD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- dtype_abbrv 增加 float8_e4m3fn、float8_e5m2fn
- 新增 _dtype_abbrv_lookup 处理 dtype('...') 形式
- in_dtype_map 增加上述 FP8 映射为 fp8,解决 Layout 阶段 KeyError
---
Reviewer notes (English). This free-form section sits after the "---"
separator and before the first "diff --git" line, so it is neither part of
the commit message nor part of the change that gets applied.

Commit message in English:
  [MACA] Support MMA layout inference for float8_e4m3fn / float8_e5m2fn
  * Add float8_e4m3fn and float8_e5m2fn to dtype_abbrv.
  * Add _dtype_abbrv_lookup to handle the dtype('...') string form.
  * Map the FP8 types above to "fp8" in in_dtype_map, resolving the
    KeyError raised during the layout stage.

What the hunks do:
  * dtype_abbrv gains "float8_e4m3fn" -> "e4m3" and
    "float8_e5m2fn" -> "e5m2".
  * _dtype_abbrv_lookup(dtype) takes str(dtype); when the text starts with
    "dtype('" and ends with "')" it drops the first 7 and last 2 characters
    (exactly that wrapper), then returns self.dtype_abbrv[<name>]. For a
    name missing from the table it raises
    KeyError("Unsupported dtype for MACA MMA: ...").
  * _initialize_abbrev now resolves all three abbreviations through that
    helper instead of indexing self.dtype_abbrv directly.
  * _initialize_mma_prefix applies the same wrapper stripping inline to
    str(in_dtype), names the former anonymous dict in_dtype_map, adds the
    two new FP8 names with value "fp8", and indexes it with the normalised
    key.

Open points for the author:
  * NOTE(review): the dtype('...') stripping now exists twice (in the
    helper and inline in _initialize_mma_prefix); a single shared
    normaliser would keep the two from drifting apart.
  * NOTE(review): in_dtype_map[in_dtype_key] still raises a bare KeyError
    for unknown names, unlike the descriptive error in the helper.
  * NOTE(review): the out_dtype table in _initialize_mma_prefix is still
    keyed by T.float16 / T.float32 / T.int8 / T.int32 and indexed with
    out_dtype unnormalised - confirm that is intended.
  * NOTE(review): the visible in_dtype_map entries do not include
    "float8_e4m3" / "float8_e5m2", which dtype_abbrv accepts; one old line
    between the last two hunks is not shown, so confirm against the file.

 .../intrinsics/maca_mma_macro_generator.py    | 26 +++++++++++++++----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/tilelang/intrinsics/maca_mma_macro_generator.py b/tilelang/intrinsics/maca_mma_macro_generator.py
index f551a0c5..cb599553 100644
--- a/tilelang/intrinsics/maca_mma_macro_generator.py
+++ b/tilelang/intrinsics/maca_mma_macro_generator.py
@@ -49,6 +49,8 @@ class TensorCoreIntrinEmitter:
         "int32": "int32",
         "float8_e4m3": "e4m3",
         "float8_e5m2": "e5m2",
+        "float8_e4m3fn": "e4m3",
+        "float8_e5m2fn": "e5m2",
         "float8_e4m3fnuz": "e4m3fnuz",
         "float8_e5m2fnuz": "e5m2fnuz",
     }
@@ -124,17 +126,28 @@ class TensorCoreIntrinEmitter:
         self.local_size_b = (n_dim * k_dim) // warp_size
         self.local_size_out = (m_dim * n_dim) // warp_size
 
+    def _dtype_abbrv_lookup(self, dtype):
+        s = str(dtype)
+        if s.startswith("dtype('") and s.endswith("')"):
+            s = s[7:-2]
+        if s not in self.dtype_abbrv:
+            raise KeyError(f"Unsupported dtype for MACA MMA: {dtype!r}")
+        return self.dtype_abbrv[s]
+
     def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
-        self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
-        self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
-        self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]
+        self.a_dtype_abbrv = self._dtype_abbrv_lookup(a_dtype)
+        self.b_dtype_abbrv = self._dtype_abbrv_lookup(b_dtype)
+        self.accum_dtype_abbrv = self._dtype_abbrv_lookup(accum_dtype)
 
     def _initialize_mma_prefix(self, k_dim=16):
         in_dtype, out_dtype = self.a_dtype, self.accum_dtype
         M_DIM, N_DIM = self.M_DIM, self.N_DIM
         out_dtype_abbrv = {T.float16: "f16", T.float32: "f32", T.int8: "i8", T.int32: "i32"}[out_dtype]
 
-        in_dtype_abbrv = {
+        in_dtype_key = str(in_dtype)
+        if in_dtype_key.startswith("dtype('") and in_dtype_key.endswith("')"):
+            in_dtype_key = in_dtype_key[7:-2]
+        in_dtype_map = {
             "bfloat16": "bf16",
             "float16": "f16",
             "float32": "f32",
@@ -142,7 +155,10 @@ class TensorCoreIntrinEmitter:
             "int32": "i32",
             "float8_e4m3fnuz": "fp8",
             "float8_e5m2fnuz": "fp8",
-        }[in_dtype]
+            "float8_e4m3fn": "fp8",
+            "float8_e5m2fn": "fp8",
+        }
+        in_dtype_abbrv = in_dtype_map[in_dtype_key]
 
         if in_dtype_abbrv == "fp8":
             self.mma_suffix = f"{M_DIM}x{N_DIM}x{k_dim}fp8"
-- 
Gitee