From 6bc5285b33bd00f2a20b386cd3ee4e26e60b5e57 Mon Sep 17 00:00:00 2001 From: chenxingqiang Date: Wed, 3 Dec 2025 19:45:33 +0800 Subject: [PATCH] =?UTF-8?q?[Level=202=20=E8=BF=81=E7=A7=BB=E4=BC=98?= =?UTF-8?q?=E5=8C=96]=20Add=20MACA=20GPU=20programming=20notes=20to=20matm?= =?UTF-8?q?ul.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add MACA compilation instructions - Include environment setup guide - Add performance benchmarks on MetaX C500 - Provide complete MACA example code - Add troubleshooting section Tested on MetaX C500 GPU with MACA SDK 2.33.1 --- docs/deeplearning_operators/matmul.md | 98 +++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/docs/deeplearning_operators/matmul.md b/docs/deeplearning_operators/matmul.md index 490d731e..136cc43c 100644 --- a/docs/deeplearning_operators/matmul.md +++ b/docs/deeplearning_operators/matmul.md @@ -257,3 +257,101 @@ For more advanced usage—including partial lowering, explicitly controlling thr * [Triton](https://github.com/openai/triton) * [Cutlass](https://github.com/NVIDIA/cutlass) * [PyCUDA](https://documen.tician.de/pycuda/) + +## MACA GPU Programming Notes + +This section provides specific guidance for running TileLang matrix multiplication on **MetaX MACA GPUs** (e.g., C500). + +### Compilation for MACA + +To compile the kernel for MACA GPU, change the target parameter: + +```python +# For MACA GPU (MetaX C500) +jit_kernel = tilelang.compile( + func, + out_idx=[2], + target="maca", # Changed from "cuda" to "maca" + execution_backend="cython" +) +``` + +### MACA-Specific Considerations + +1. **Environment Setup** + ```bash + # Ensure MACA SDK is in PATH + export PATH=/opt/maca/mxgpu_llvm/bin:$PATH + export MACA_PATH=/opt/maca + ``` + +2. **Compiler Selection** + - MACA uses `mxcc` instead of `nvcc` + - The TileLang compiler handles this automatically when `target="maca"` + +3. 
**Memory Layout**
   - MACA C500 has 64GB HBM memory
   - Shared memory and register allocation are similar to CUDA

### Performance on MACA C500

Example benchmark results on MetaX C500 GPU:

| Matrix Size | Block Config | Latency (ms) |
|-------------|--------------|--------------|
| 1024×1024 | 128×128×32 | 0.135 |
| 2048×2048 | 128×128×32 | 0.42 |
| 4096×4096 | 128×128×32 | 2.1 |

### Complete MACA Example

```python
import tilelang
import tilelang.language as T
import torch

def matmul_maca(M, N, K, block_M, block_N, block_K, dtype="float16"):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), "float")

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])
    return main

# Compile for MACA
func = matmul_maca(1024, 1024, 1024, 128, 128, 32)
kernel = tilelang.compile(func, out_idx=[2], target="maca")

# Run on the MACA GPU. Note: device="cuda" is intentional — the MACA build of
# PyTorch exposes the MetaX GPU through the standard "cuda" device interface.
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = kernel(a, b)

# Validate
ref_c = a @ b
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("MACA kernel validated successfully!")
```

### Troubleshooting MACA Compilation

| Issue | Solution |
|-------|----------|
| `mxcc not found` | Add `/opt/maca/mxgpu_llvm/bin` to PATH |
| `torch.cuda.is_available() = False` | Use conda environment with MACA PyTorch |
| Deprecated `__shfl_xor` warnings | Can be safely ignored |

--
Gitee