diff --git a/src/target/codegen_maca.cc b/src/target/codegen_maca.cc index fc221965afd929f65350150c4df4c9b1cff9736d..6e8ae676c5e1139777486d8966b4901f01725b3c 100644 --- a/src/target/codegen_maca.cc +++ b/src/target/codegen_maca.cc @@ -1961,8 +1961,23 @@ void CodeGenTileLangMACA::VisitExpr_(const CallNode *op, std::ostream &os) { os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::warp_reduce_bitor())) { os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::atomic_add_elem_op())) { + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_value = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicAdd(" << dst_ptr << ", " << src_value; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_add_ret_elem_op())) { + os << "AtomicAddRet(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]); + if (op->args.size() > 2) { + os << ", " << PrintExpr(op->args[2]); + } + os << ")"; } else if (op->op.same_as(tl::atomic_addx2_elem_op())) { - // atomic_addx2_elem_op(dst_ptr, src_ptr[, memory_order]) std::string dst_ptr = PrintExpr(op->args[0]); std::string src_ptr = PrintExpr(op->args[1]); this->PrintIndent(); diff --git a/tilelang/intrinsics/maca_mma_macro_generator.py b/tilelang/intrinsics/maca_mma_macro_generator.py index f551a0c5f78921dbb4c7012de8f0783386023759..cb59955348591c36e1c4cbec7fb6f3d70f9de037 100644 --- a/tilelang/intrinsics/maca_mma_macro_generator.py +++ b/tilelang/intrinsics/maca_mma_macro_generator.py @@ -49,6 +49,8 @@ class TensorCoreIntrinEmitter: "int32": "int32", "float8_e4m3": "e4m3", "float8_e5m2": "e5m2", + "float8_e4m3fn": "e4m3", + "float8_e5m2fn": "e5m2", "float8_e4m3fnuz": "e4m3fnuz", "float8_e5m2fnuz": "e5m2fnuz", } @@ -124,17 +126,28 @@ class TensorCoreIntrinEmitter: self.local_size_b = (n_dim * k_dim) // warp_size self.local_size_out = (m_dim * n_dim) // warp_size + def _dtype_abbrv_lookup(self, dtype): + s = str(dtype) + if s.startswith("dtype('") and s.endswith("')"): + s = s[7:-2] + if s not in self.dtype_abbrv: + raise KeyError(f"Unsupported dtype for MACA MMA: {dtype!r}") + return self.dtype_abbrv[s] + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): - self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] - self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] - self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + self.a_dtype_abbrv = self._dtype_abbrv_lookup(a_dtype) + self.b_dtype_abbrv = self._dtype_abbrv_lookup(b_dtype) + self.accum_dtype_abbrv = self._dtype_abbrv_lookup(accum_dtype) def _initialize_mma_prefix(self, k_dim=16): in_dtype, out_dtype = self.a_dtype, self.accum_dtype M_DIM, N_DIM = self.M_DIM, self.N_DIM out_dtype_abbrv = {T.float16: "f16", T.float32: "f32", T.int8: "i8", T.int32: "i32"}[out_dtype] - in_dtype_abbrv = { + in_dtype_key = str(in_dtype) + if in_dtype_key.startswith("dtype('") and in_dtype_key.endswith("')"): + in_dtype_key = in_dtype_key[7:-2] + in_dtype_map = { "bfloat16": "bf16", "float16": "f16", "float32": "f32", @@ -142,7 +155,10 @@ class TensorCoreIntrinEmitter: "int32": "i32", "float8_e4m3fnuz": "fp8", "float8_e5m2fnuz": "fp8", - }[in_dtype] + "float8_e4m3fn": "fp8", + "float8_e5m2fn": "fp8", + } + in_dtype_abbrv = in_dtype_map[in_dtype_key] if in_dtype_abbrv == "fp8": self.mma_suffix = f"{M_DIM}x{N_DIM}x{k_dim}fp8"