From 91ce0537eb4120dbb491963a0fa1f565222b48b8 Mon Sep 17 00:00:00 2001
From: Zymonody7 <2196905367@qq.com>
Date: Wed, 10 Jun 2026 17:30:42 +0800
Subject: [PATCH] [MACA] codegen Min/Max: explicit integer branches and
 fp16/bf16 vector lowering

- Integer scalars now emit min()/max() explicitly (8/16-bit promote
  through int and cast back) instead of relying on CodeGenC defaults.
- Vectorized fp16/bf16 min/max lower per-lane via a float compare; the
  previous default lowering emitted min(__half, __half), which fails to
  compile on MACA (call to 'min' is ambiguous).
- Add codegen and runtime tests under testing/python/maca/, verified on
  MetaX C500 (MACA 2.33.1.13): 8/8 passed.

Fixes: https://gitee.com/metax-maca/mcTileLang/issues/IFUFVD
---
 src/target/codegen_maca.cc                    |  76 ++++++++++++-
 src/target/codegen_maca.h                     |   2 +
 .../python/maca/test_maca_codegen_minmax.py   | 106 ++++++++++++++++++
 3 files changed, 180 insertions(+), 4 deletions(-)
 create mode 100644 testing/python/maca/test_maca_codegen_minmax.py

diff --git a/src/target/codegen_maca.cc b/src/target/codegen_maca.cc
index 6e8ae676..c606f658 100644
--- a/src/target/codegen_maca.cc
+++ b/src/target/codegen_maca.cc
@@ -1133,10 +1133,44 @@ void CodeGenTileLangMACA::VisitExpr_(const CastNode *op, std::ostream &os) {
   os << sret;
 }
 
+// Per-lane min/max for fp16/bf16 vectors. PrintVecBinaryOp would emit
+// min(__half, __half), which is ambiguous on MACA, so each lane is compared
+// as float and the original fp16/bf16 element is selected unchanged.
+void CodeGenTileLangMACA::PrintVecMinMaxHalf(const PrimExpr &lhs,
+                                             const PrimExpr &rhs, DataType t,
+                                             bool is_min, std::ostream &os) {
+  std::string sret = name_supply_->FreshName("_");
+  this->PrintIndent();
+  this->PrintType(t, stream);
+  stream << ' ' << sret << ";\n";
+  int ssa_scope = BeginScope();
+  {
+    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
+    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
+    const char *cmp = is_min ? " < " : " > ";
+    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
+      std::ostringstream lhs_elem, rhs_elem, value_temp;
+      PrintVecElemLoad(vlhs, lhs.dtype(), i, lhs_elem);
+      PrintVecElemLoad(vrhs, rhs.dtype(), i, rhs_elem);
+      value_temp << "((float)" << lhs_elem.str() << cmp << "(float)"
+                 << rhs_elem.str() << " ? " << lhs_elem.str() << " : "
+                 << rhs_elem.str() << ")";
+      PrintVecElemStore(sret, t, i, value_temp.str());
+    }
+  }
+  EndScope(ssa_scope);
+  os << sret;
+}
+
 void CodeGenTileLangMACA::VisitExpr_(const MinNode *op, std::ostream &os) {
-  // TODO(wt): Consider vectorized reduction and impl for other dtypes
   DataType t = op->dtype;
 
+  // fp16/bf16 vectors: lower lane by lane through PrintVecMinMaxHalf.
+  if ((t.is_bfloat16() || t.is_float16()) && t.lanes() > 1) {
+    PrintVecMinMaxHalf(op->a, op->b, t, /*is_min=*/true, os);
+    return;
+  }
+
   // Standard min/max functions don't support bfloat16 or float16
   if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
     os << "mctlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
@@ -1152,14 +1186,34 @@ void CodeGenTileLangMACA::VisitExpr_(const MinNode *op, std::ostream &os) {
     }
   }
 
-  // For all other scalar types (int, uint), use default implementation
+  // Integer scalars: 32/64-bit operands go straight to min(); every other
+  // width (e.g. 8/16-bit) is compared as int and cast back to type t.
+  if ((t.is_int() || t.is_uint()) && t.is_scalar()) {
+    if (t.bits() == 32 || t.bits() == 64) {
+      os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    } else {
+      os << "(";
+      PrintType(t, os);
+      os << ")min((int)" << PrintExpr(op->a) << ", (int)" << PrintExpr(op->b)
+         << ")";
+    }
+    return;
+  }
+
+  // Anything not handled above (e.g. float/int vectors) uses the CodeGenC
+  // default, which lowers vectors per-lane via PrintVecBinaryOp.
   CodeGenC::VisitExpr_(op, os);
 }
 
 void CodeGenTileLangMACA::VisitExpr_(const MaxNode *op, std::ostream &os) {
-  // TODO(wt): Consider vectorized reduction and impl for other dtypes
   DataType t = op->dtype;
 
+  // fp16/bf16 vectors: lower lane by lane through PrintVecMinMaxHalf.
+  if ((t.is_bfloat16() || t.is_float16()) && t.lanes() > 1) {
+    PrintVecMinMaxHalf(op->a, op->b, t, /*is_min=*/false, os);
+    return;
+  }
+
   // Standard min/max functions don't support bfloat16 or float16
   if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
     os << "mctlass::fast_max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
@@ -1175,7 +1229,21 @@ void CodeGenTileLangMACA::VisitExpr_(const MaxNode *op, std::ostream &os) {
     }
   }
 
-  // For all other scalar types (int, uint), use default implementation
+  // Integer scalars: same width handling as the MinNode visitor, using max().
+  if ((t.is_int() || t.is_uint()) && t.is_scalar()) {
+    if (t.bits() == 32 || t.bits() == 64) {
+      os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    } else {
+      os << "(";
+      PrintType(t, os);
+      os << ")max((int)" << PrintExpr(op->a) << ", (int)" << PrintExpr(op->b)
+         << ")";
+    }
+    return;
+  }
+
+  // Anything not handled above (e.g. float/int vectors) uses the CodeGenC
+  // default, which lowers vectors per-lane via PrintVecBinaryOp.
   CodeGenC::VisitExpr_(op, os);
 }
 
diff --git a/src/target/codegen_maca.h b/src/target/codegen_maca.h
index 7a88a374..e68cb0d5 100644
--- a/src/target/codegen_maca.h
+++ b/src/target/codegen_maca.h
@@ -36,6 +36,8 @@ public:
   void PrintVecBinaryOp(const std::string &op, DataType t, PrimExpr lhs,
                         PrimExpr rhs,
                         std::ostream &os) final; // NOLINT(*)
+  void PrintVecMinMaxHalf(const PrimExpr &lhs, const PrimExpr &rhs, DataType t,
+                          bool is_min, std::ostream &os); // NOLINT(*)
   void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
   void PrintVecElemLoad(const std::string &vec, DataType t, int i,
                         std::ostream &os) final; // NOLINT(*)
diff --git a/testing/python/maca/test_maca_codegen_minmax.py b/testing/python/maca/test_maca_codegen_minmax.py
new file mode 100644
index 00000000..f2086f86
--- /dev/null
+++ b/testing/python/maca/test_maca_codegen_minmax.py
@@ -0,0 +1,106 @@
+import pytest
+
+import tilelang
+from tilelang import tvm as tvm
+import tilelang.language as T
+import tilelang.testing
+from tilelang.utils.target import check_maca_availability
+
+MACA_TARGET = "maca"
+
+requires_maca = pytest.mark.skipif(
+    not check_maca_availability(), reason="MACA toolchain is not available")
+
+
+def _minmax_kernel(dtype: str, fn, vec: int = 1, M: int = 128):
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, vec), dtype=dtype),
+        B: T.Tensor((M, vec), dtype=dtype),
+        C: T.Tensor((M, vec), dtype=dtype),
+    ):
+        with T.Kernel(1, 1, threads=M) as (bx, by):
+            tid = T.get_thread_binding()
+            for v in T.vectorized(vec):
+                C[tid, v] = fn(A[tid, v], B[tid, v])
+
+    return main
+
+
+def _lower_to_maca_source(func) -> str:
+    with tvm.transform.PassContext(), tvm.target.Target(MACA_TARGET):
+        artifact = tilelang.lower(func, target=MACA_TARGET)
+    assert artifact.kernel_source is not None
+    return artifact.kernel_source
+
+
+@requires_maca
+def test_maca_codegen_min_fp16_scalar():
+    src = _lower_to_maca_source(_minmax_kernel("float16", T.min))
+    assert "mctlass::fast_min" in src
+
+
+@requires_maca
+def test_maca_codegen_max_bf16_scalar():
+    src = _lower_to_maca_source(_minmax_kernel("bfloat16", T.max))
+    assert "mctlass::fast_max" in src
+
+
+@requires_maca
+def test_maca_codegen_min_int32_scalar():
+    src = _lower_to_maca_source(_minmax_kernel("int32", T.min))
+    assert "min(" in src
+
+
+@requires_maca
+def test_maca_codegen_min_int8_scalar():
+    # Narrow integers are compared as int: expect the min((int)...) form.
+    src = _lower_to_maca_source(_minmax_kernel("int8", T.min))
+    assert "min((int)" in src
+
+
+@requires_maca
+def test_maca_codegen_min_fp16_vectorized():
+    # fp16x2 min must not emit min(__half, __half) (ambiguous on MACA); it
+    # lowers per lane via a float compare. Only the absence is asserted.
+    src = _lower_to_maca_source(_minmax_kernel("float16", T.min, vec=2))
+    assert "min(__half" not in src
+
+
+@requires_maca
+def test_maca_codegen_max_bf16_vectorized():
+    src = _lower_to_maca_source(_minmax_kernel("bfloat16", T.max, vec=2))
+    assert "max(__maca_bfloat16" not in src
+
+
+@requires_maca
+def test_maca_min_int32_runtime():
+    import torch
+
+    M = 128
+    kernel = tilelang.compile(
+        _minmax_kernel("int32", T.min), target=MACA_TARGET)
+    a = torch.randint(-1000, 1000, (M, 1), dtype=torch.int32, device="cuda")
+    b = torch.randint(-1000, 1000, (M, 1), dtype=torch.int32, device="cuda")
+    c = torch.empty_like(a)
+    kernel(a, b, c)
+    torch.testing.assert_close(c, torch.minimum(a, b))
+
+
+@requires_maca
+def test_maca_min_fp16_vectorized_runtime():
+    import torch
+
+    M = 128
+    kernel = tilelang.compile(
+        _minmax_kernel("float16", T.min, vec=2), target=MACA_TARGET)
+    a = torch.randn(M, 2, dtype=torch.float16, device="cuda")
+    b = torch.randn(M, 2, dtype=torch.float16, device="cuda")
+    c = torch.empty_like(a)
+    kernel(a, b, c)
+    torch.testing.assert_close(c, torch.minimum(a, b))
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
-- 
Gitee