diff --git a/docs/deeplearning_operators/elementwise.md b/docs/deeplearning_operators/elementwise/elementwise.md similarity index 100% rename from docs/deeplearning_operators/elementwise.md rename to docs/deeplearning_operators/elementwise/elementwise.md diff --git a/docs/deeplearning_operators/elementwise/elementwise_add_comparison.md b/docs/deeplearning_operators/elementwise/elementwise_add_comparison.md new file mode 100644 index 0000000000000000000000000000000000000000..c2bf5219afa233dfb83d11070730ae33329e8cbc --- /dev/null +++ b/docs/deeplearning_operators/elementwise/elementwise_add_comparison.md @@ -0,0 +1,235 @@ +# TileLang Elementwise Add 实现对比分析 + +## 概述 + +对比 `mcTileLang` 和 `tilelang` 两个版本的 elementwise add 实现差异。 + +## 代码对比 + +### mcTileLang 版本 + +```python +def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): + @T.prim_func + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), + C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + for (local_y, local_x) in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + C[y, x] = A[y, x] + B[y, x] # 直接访问 global memory + return elem_add + +# 使用时需要手动编译 +kernel = tilelang.compile(elementwise_add(...), out_idx=-1) +``` + +### tilelang 版本 + +```python +@tilelang.jit(out_idx=[-1]) # JIT 装饰器 +def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): + @T.prim_func + def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), + C: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_N), in_dtype) + B_shared = T.alloc_shared((block_M, block_N), in_dtype) + C_local = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + T.copy(A[by * block_M, bx * block_N], A_shared) + T.copy(B[by * block_M, bx * block_N], B_shared) + + for (local_y, local_x) in T.Parallel(block_M, block_N): + C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x] + + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + return elem_add + +# 使用时自动编译 +kernel = elementwise_add(...) +``` + +## 主要差异 + +| 特性 | mcTileLang | tilelang | +|------|------------|----------| +| 编译方式 | 手动 `tilelang.compile()` | `@tilelang.jit` 装饰器 | +| 内存访问 | 直接访问 global memory | 显式内存层次 (global → shared → fragment) | +| 数据写回 | 直接写 global | fragment → shared → global | +| 默认 block 大小 | 128×256 | 32×32 | + +## 关键问题:为什么 tilelang 版本使用两次 copy 写回? + +```python +T.copy(C_local, C_shared) # copy 1: fragment → shared +T.copy(C_shared, C[by * block_M, bx * block_N]) # copy 2: shared → global +``` + +### 设计意图:利用硬件加速指令 + +这里可能是为了利用 **stmatrix + TMA store** 硬件特性: + +| Copy 操作 | 期望生成的指令 | 架构要求 | +|-----------|---------------|----------| +| `C_local → C_shared` | **stmatrix** | SM90+ (Hopper) | +| `C_shared → C[...]` | **TMA store** | SM90+ (Hopper) | + +### TileLang 编译器的指令选择逻辑 + +```cpp +// tilelang/src/target/utils.cc + +bool TargetHasStmatrix(Target target) { + return arch >= 90; // SM90+ (Hopper) +} + +bool TargetHasBulkCopy(Target target) { // TMA + return arch >= 90; // SM90+ (Hopper) +} + +// tilelang/src/op/copy.cc + +// stmatrix 触发条件 +bool CopyNode::CheckSTSMCopy(Target target) const { + return TargetHasStmatrix(target) && // SM90+ + src.scope() == "local.fragment" && // 源是 fragment + (dst.scope() == "shared.dyn" || dst.scope() == "shared"); // 目标是 shared +} + +// TMA store 触发条件 +bool CopyNode::CheckBulkStore(Target target, ...) const { + return TargetHasBulkCopy(target) && // SM90+ + (src.scope() == "shared.dyn" || src.scope() == "shared") && // 源是 shared + dst.scope() == "global"; // 目标是 global +} +``` + +### 数据流 + +``` +┌─────────────────────────────────────────────────────────┐ +│ Register (C_local) │ +│ │ │ +│ │ stmatrix (SM90+) │ +│ ▼ │ +│ Shared Memory (C_shared) │ +│ │ │ +│ │ TMA store (SM90+) 或 普通 store (SM80) │ +│ ▼ │ +│ Global Memory (C) │ +└─────────────────────────────────────────────────────────┘ +``` + +### 为什么不能直接 fragment → global 使用 TMA? + +**硬件限制**:TMA 只支持 Shared Memory ↔ Global Memory,不支持从寄存器直接发起。 + +### 架构依赖性 + +| 架构 | stmatrix | TMA store | 两次 copy 效果 | +|------|----------|-----------|---------------| +| SM90+ (Hopper) | ✓ | ✓ | **最优**:stmatrix + TMA store | +| SM80 (Ampere) | ✗ | ✗ | 无优化,退化为普通 store | +| SM70 及以下 | ✗ | ✗ | 无优化,反而增加开销 | + +**重要**:stmatrix 和 TMA 都是 **Hopper (SM90+)** 专属特性,在老架构上两次 copy 会退化为普通操作。 + +--- + +## TMA 控制参数:disable_tma + +`T.copy` 支持 `disable_tma` 参数来显式禁用 TMA: + +```python +# tilelang/tilelang/language/copy.py + +def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, + dst: tir.Buffer | tir.BufferLoad, + coalesced_width: int | None = None, + disable_tma: bool = False, + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): + +``` + +也可以通过编译配置全局禁用: + +```python +tilelang.compile( + program, + pass_configs={"tl.disable_tma_lower": True} +) +``` + +**注意**:只能**禁用** TMA,不能显式启用——编译器会自动判断是否使用。 + +--- + +## 即使没有 TMA,经过 Shared Memory 也有好处? + +### TensorCore 输出数据布局问题 + +TensorCore 的 MMA 指令输出数据分布在多个线程的寄存器中,布局是**非连续的**: + +``` +TensorCore MMA 输出布局(简化): +┌─────────────────────────────────────┐ +│ Thread 0: C[0,0], C[0,1], C[2,0]... │ +│ Thread 1: C[0,2], C[0,3], C[2,2]... │ +│ Thread 2: C[1,0], C[1,1], C[3,0]... │ +└─────────────────────────────────────┘ +``` + +直接写回 Global Memory 会导致非合并访问(non-coalesced access)。 + +### Shared Memory 重排的作用 + +``` +Fragment (分散布局) → Shared Memory (重排) → Global Memory (连续布局) +``` + +通过 Shared Memory 中转,可以将分散的数据重新排列成连续布局,提升写回效率。 + +### 重排是隐式的吗? + +**是的,完全隐式**。TileLang 的 `T.copy` 会: +1. 根据源和目标的内存类型推断需要的布局转换 +2. 自动生成必要的 swizzle/permute 操作 +3. 开发者无法显式控制重排方式 + +### 对 elementwise_add 的影响 + +| 场景 | 是否需要重排 | 原因 | +|------|-------------|------| +| **GEMM (TensorCore)** | 是 | MMA 输出布局分散 | +| **Elementwise** | 否 | 数据本身就是连续的 | + +**elementwise_add 示例的问题**:它没有使用 `T.gemm`,所以 `C_local` 的布局本身就是连续的,两次 copy 的收益主要来自 stmatrix + TMA store。 + +--- + +## 总结 + +| 问题 | 答案 | +|------|------| +| tilelang 版本为什么用两次 copy? | 利用 stmatrix + TMA store 硬件加速 | +| 这种设计在所有架构上都有效吗? | 否,**仅 SM90+ (Hopper)** 才能发挥优势 | +| 能显式控制 TMA 吗? | 只能通过 `disable_tma` 禁用,不能显式启用 | +| 经过 shared memory 有数据重排吗? | 对 TensorCore 输出有 | +| 重排是隐式的吗? | 开发者无法控制重排的底层实现算法,但可以通过 disable_tma 选择是否使用特定硬件指令来执行这个重排 | +| mcTileLang 版本 | 直接访问 global memory | +| 哪个版本更好? | **仅在 Hopper 上** tilelang 更优,老架构上两次 copy 反而是开销 | + +--- + +## 验证方法 + +查看生成的 CUDA 代码确认实际使用的指令: + +检查是否包含: +- stmatrix 指令 +- TMA store 指令 diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index bba2a036c40f8454a6b6fa34fada7c4855524587..98c3fc8bd7a3b040b2a0c548e025ee5187c50ca4 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -42,7 +42,6 @@ def get_best_config(M, N): autotuner = AutoTuner.from_kernel( kernel=kernel, configs=get_configs(M, N)).set_compile_args( out_idx=[-1], - target="cuda", ).set_profile_args( supply_type=tilelang.TensorSupplyType.Auto, ref_prog=ref_program, diff --git a/examples/elementwise/example_elementwise_gelu.py b/examples/elementwise/example_elementwise_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..ec72e7b8ea673af8dc052a2b78e36b6588741bba --- /dev/null +++ b/examples/elementwise/example_elementwise_gelu.py @@ -0,0 +1,87 @@ +"""GELU 激活函数的 Elementwise 实现""" + +import argparse +import itertools +import torch +import tilelang +import tilelang.language as T +from tilelang.autotuner import AutoTuner + + +def ref_program(x): + return torch.nn.functional.gelu(x) + + +def elementwise_gelu(M, N, block_M, block_N, in_dtype, out_dtype, threads): + """GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))""" + + @T.prim_func + def gelu_kernel(X: T.Tensor((M, N), in_dtype), Y: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + for (local_y, local_x) in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + val = X[y, x] + # GELU approximation + val_cubed = val * val * val + inner = 0.7978845608 * (val + 0.044715 * val_cubed) # sqrt(2/pi) ≈ 0.7978845608 + Y[y, x] = 0.5 * val * (1.0 + T.tanh(inner)) + + return gelu_kernel + + +def get_configs(M, N): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +def get_best_config(M, N): + + def kernel(block_M=None, block_N=None, threads=None): + return elementwise_gelu(M, N, block_M, block_N, "float32", "float32", threads) + + autotuner = AutoTuner.from_kernel( + kernel=kernel, configs=get_configs(M, N)).set_compile_args( + out_idx=[-1], + ).set_profile_args( + supply_type=tilelang.TensorSupplyType.Auto, + ref_prog=ref_program, + skip_check=False, + ) + return autotuner.run(warmup=3, rep=20) + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=512) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + M, N = args.m, args.n + + x = torch.randn(M, N, dtype=torch.float32, device="cuda") + + if args.use_autotune: + result = get_best_config(M, N) + kernel = result.kernel + else: + config = {"block_M": 128, "block_N": 256, "threads": 128} + kernel = tilelang.compile( + elementwise_gelu(M, N, **config, in_dtype="float32", out_dtype="float32"), + out_idx=-1) + + out = kernel(x) + torch.testing.assert_close(out, ref_program(x), rtol=1e-5, atol=1e-5) + print("PASS") + except Exception as e: + print(f"NO PASS: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/elementwise/example_elementwise_leaky_relu.py b/examples/elementwise/example_elementwise_leaky_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce5518dd765af6148001ec7585150aa15ae22c6 --- /dev/null +++ b/examples/elementwise/example_elementwise_leaky_relu.py @@ -0,0 +1,87 @@ +"""LeakyReLU 激活函数的 Elementwise 实现""" + +import argparse +import itertools +import torch +import tilelang +import tilelang.language as T +from tilelang.autotuner import AutoTuner + + +def ref_program(x, negative_slope=0.01): + return torch.nn.functional.leaky_relu(x, negative_slope=negative_slope) + + +def elementwise_leaky_relu(M, N, block_M, block_N, in_dtype, out_dtype, threads, negative_slope=0.01): + """LeakyReLU: max(0, x) + negative_slope * min(0, x)""" + + @T.prim_func + def leaky_relu_kernel(X: T.Tensor((M, N), in_dtype), Y: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + for (local_y, local_x) in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + val = X[y, x] + Y[y, x] = T.if_then_else(val > 0.0, val, negative_slope * val) + + return leaky_relu_kernel + + +def get_configs(M, N): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +def get_best_config(M, N, negative_slope=0.01): + + def kernel(block_M=None, block_N=None, threads=None): + return elementwise_leaky_relu(M, N, block_M, block_N, "float32", "float32", threads, negative_slope) + + autotuner = AutoTuner.from_kernel( + kernel=kernel, configs=get_configs(M, N)).set_compile_args( + out_idx=[-1], + ).set_profile_args( + supply_type=tilelang.TensorSupplyType.Auto, + ref_prog=lambda x: ref_program(x, negative_slope), + skip_check=False, + ) + return autotuner.run(warmup=3, rep=20) + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=512) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--negative_slope", type=float, default=0.01) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + M, N = args.m, args.n + negative_slope = args.negative_slope + + x = torch.randn(M, N, dtype=torch.float32, device="cuda") + + if args.use_autotune: + result = get_best_config(M, N, negative_slope) + kernel = result.kernel + else: + config = {"block_M": 128, "block_N": 256, "threads": 128} + kernel = tilelang.compile( + elementwise_leaky_relu(M, N, **config, in_dtype="float32", out_dtype="float32", + negative_slope=negative_slope), + out_idx=-1) + + out = kernel(x) + torch.testing.assert_close(out, ref_program(x, negative_slope), rtol=1e-5, atol=1e-5) + print("PASS") + except Exception as e: + print(f"NO PASS: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/elementwise/example_elementwise_relu.py b/examples/elementwise/example_elementwise_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..a366dd503b0177be1f12fdef5835833cfd983843 --- /dev/null +++ b/examples/elementwise/example_elementwise_relu.py @@ -0,0 +1,80 @@ +import argparse +import itertools +import torch +import tilelang +import tilelang.language as T +from tilelang.autotuner import AutoTuner + + +def ref_program(x): + return torch.relu(x) + + +def elementwise_relu(M, N, block_M, block_N, in_dtype, out_dtype, threads): + + @T.prim_func + def relu_kernel(X: T.Tensor((M, N), in_dtype), Y: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + for (local_y, local_x) in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + val = X[y, x] + Y[y, x] = T.if_then_else(val > 0.0, val, 0.0) + + return relu_kernel + + +def get_configs(M, N): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +def get_best_config(M, N): + + def kernel(block_M=None, block_N=None, threads=None): + return elementwise_relu(M, N, block_M, block_N, "float32", "float32", threads) + + autotuner = AutoTuner.from_kernel( + kernel=kernel, configs=get_configs(M, N)).set_compile_args( + out_idx=[-1], + ).set_profile_args( + supply_type=tilelang.TensorSupplyType.Auto, + ref_prog=ref_program, + skip_check=False, + ) + return autotuner.run(warmup=3, rep=20) + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=512) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + M, N = args.m, args.n + + x = torch.randn(M, N, dtype=torch.float32, device="cuda") + + if args.use_autotune: + result = get_best_config(M, N) + kernel = result.kernel + else: + config = {"block_M": 128, "block_N": 256, "threads": 128} + kernel = tilelang.compile( + elementwise_relu(M, N, **config, in_dtype="float32", out_dtype="float32"), out_idx=-1) + + out = kernel(x) + torch.testing.assert_close(out, ref_program(x), rtol=1e-5, atol=1e-5) + print("PASS") + except Exception as e: + print(f"NO PASS: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/elementwise/example_elementwise_sigmoid.py b/examples/elementwise/example_elementwise_sigmoid.py new file mode 100644 index 0000000000000000000000000000000000000000..467d4403775d54eb517b9514d7291808a9463a3a --- /dev/null +++ b/examples/elementwise/example_elementwise_sigmoid.py @@ -0,0 +1,80 @@ +import argparse +import itertools +import torch +import tilelang +import tilelang.language as T +from tilelang.autotuner import AutoTuner + + +def ref_program(x): + return torch.sigmoid(x) + + +def elementwise_sigmoid(M, N, block_M, block_N, in_dtype, out_dtype, threads): + + @T.prim_func + def sigmoid_kernel(X: T.Tensor((M, N), in_dtype), Y: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + for (local_y, local_x) in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + val = X[y, x] + Y[y, x] = 1.0 / (1.0 + T.exp(-val)) + + return sigmoid_kernel + + +def get_configs(M, N): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +def get_best_config(M, N): + + def kernel(block_M=None, block_N=None, threads=None): + return elementwise_sigmoid(M, N, block_M, block_N, "float32", "float32", threads) + + autotuner = AutoTuner.from_kernel( + kernel=kernel, configs=get_configs(M, N)).set_compile_args( + out_idx=[-1], + ).set_profile_args( + supply_type=tilelang.TensorSupplyType.Auto, + ref_prog=ref_program, + skip_check=False, + ) + return autotuner.run(warmup=3, rep=20) + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=512) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + M, N = args.m, args.n + + x = torch.randn(M, N, dtype=torch.float32, device="cuda") + + if args.use_autotune: + result = get_best_config(M, N) + kernel = result.kernel + else: + config = {"block_M": 128, "block_N": 256, "threads": 128} + kernel = tilelang.compile( + elementwise_sigmoid(M, N, **config, in_dtype="float32", out_dtype="float32"), out_idx=-1) + + out = kernel(x) + torch.testing.assert_close(out, ref_program(x), rtol=1e-5, atol=1e-5) + print("PASS") + except Exception as e: + print(f"NO PASS: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/elementwise/example_elementwise_silu.py b/examples/elementwise/example_elementwise_silu.py new file mode 100644 index 0000000000000000000000000000000000000000..d56c7b0745ecbfd72452b396f6d55d385b2a13bd --- /dev/null +++ b/examples/elementwise/example_elementwise_silu.py @@ -0,0 +1,83 @@ +"""SiLU (Swish) 激活函数的 Elementwise 实现""" + +import argparse +import itertools +import torch +import tilelang +import tilelang.language as T +from tilelang.autotuner import AutoTuner + + +def ref_program(x): + return torch.nn.functional.silu(x) + + +def elementwise_silu(M, N, block_M, block_N, in_dtype, out_dtype, threads): + """SiLU (Swish): x * sigmoid(x) = x / (1 + exp(-x))""" + + @T.prim_func + def silu_kernel(X: T.Tensor((M, N), in_dtype), Y: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + for (local_y, local_x) in T.Parallel(block_M, block_N): + yi = start_y + local_y + xi = start_x + local_x + val = X[yi, xi] + Y[yi, xi] = val / (1.0 + T.exp(-val)) + + return silu_kernel + + +def get_configs(M, N): + block_M = [32, 64] + block_N = [32, 64] + threads = [64, 128] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +def get_best_config(M, N): + def kernel(block_M=None, block_N=None, threads=None): + return elementwise_silu(M, N, block_M, block_N, "float32", "float32", threads) + + autotuner = AutoTuner.from_kernel( + kernel=kernel, configs=get_configs(M, N)).set_compile_args( + out_idx=[-1], + ).set_profile_args( + supply_type=tilelang.TensorSupplyType.Auto, + ref_prog=ref_program, + skip_check=False, + ) + return autotuner.run(warmup=3, rep=20) + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=512) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + M, N = args.m, args.n + + x = torch.randn(M, N, dtype=torch.float32, device="cuda") + + if args.use_autotune: + result = get_best_config(M, N) + kernel = result.kernel + else: + config = {"block_M": 64, "block_N": 64, "threads": 128} + kernel = tilelang.compile( + elementwise_silu(M, N, **config, in_dtype="float32", out_dtype="float32"), + out_idx=-1) + + out = kernel(x) + torch.testing.assert_close(out, ref_program(x), rtol=1e-5, atol=1e-5) + print("PASS") + except Exception as e: + print(f"NO PASS: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/elementwise/example_elementwise_tanh.py b/examples/elementwise/example_elementwise_tanh.py new file mode 100644 index 0000000000000000000000000000000000000000..c79130babb4e5179f0f81b4c6b50aaa01a454a06 --- /dev/null +++ b/examples/elementwise/example_elementwise_tanh.py @@ -0,0 +1,79 @@ +import argparse +import itertools +import torch +import tilelang +import tilelang.language as T +from tilelang.autotuner import AutoTuner + + +def ref_program(x): + return torch.tanh(x) + + +def elementwise_tanh(M, N, block_M, block_N, in_dtype, out_dtype, threads): + + @T.prim_func + def tanh_kernel(X: T.Tensor((M, N), in_dtype), Y: T.Tensor((M, N), out_dtype)): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + start_x = bx * block_N + start_y = by * block_M + for (local_y, local_x) in T.Parallel(block_M, block_N): + y = start_y + local_y + x = start_x + local_x + Y[y, x] = T.tanh(X[y, x]) + + return tanh_kernel + + +def get_configs(M, N): + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [64, 128, 256] + configs = list(itertools.product(block_M, block_N, threads)) + return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] + + +def get_best_config(M, N): + + def kernel(block_M=None, block_N=None, threads=None): + return elementwise_tanh(M, N, block_M, block_N, "float32", "float32", threads) + + autotuner = AutoTuner.from_kernel( + kernel=kernel, configs=get_configs(M, N)).set_compile_args( + out_idx=[-1], + ).set_profile_args( + supply_type=tilelang.TensorSupplyType.Auto, + ref_prog=ref_program, + skip_check=False, + ) + return autotuner.run(warmup=3, rep=20) + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=512) + parser.add_argument("--n", type=int, default=1024) + parser.add_argument("--use_autotune", action="store_true", default=False) + args, _ = parser.parse_known_args() + M, N = args.m, args.n + + x = torch.randn(M, N, dtype=torch.float32, device="cuda") + + if args.use_autotune: + result = get_best_config(M, N) + kernel = result.kernel + else: + config = {"block_M": 128, "block_N": 256, "threads": 128} + kernel = tilelang.compile( + elementwise_tanh(M, N, **config, in_dtype="float32", out_dtype="float32"), out_idx=-1) + + out = kernel(x) + torch.testing.assert_close(out, ref_program(x), rtol=1e-2, atol=1e-2) + print("PASS") + except Exception as e: + print(f"NO PASS: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/elementwise/test_example_elementwise.py b/examples/elementwise/test_example_elementwise.py index f1668f4aa645bf1c8b530195f323675451555f1f..9106ed6cb26d33ccb7ebc63a91eecbb9db8bb7c6 100644 --- a/examples/elementwise/test_example_elementwise.py +++ b/examples/elementwise/test_example_elementwise.py @@ -1,10 +1,40 @@ import tilelang.testing import example_elementwise_add +import example_elementwise_relu +import example_elementwise_sigmoid +import example_elementwise_tanh +import example_elementwise_silu +import example_elementwise_leaky_relu +import example_elementwise_gelu def test_example_elementwise_add(): example_elementwise_add.main() +def test_example_elementwise_relu(): + example_elementwise_relu.main() + + +def test_example_elementwise_sigmoid(): + example_elementwise_sigmoid.main() + + +def test_example_elementwise_tanh(): + example_elementwise_tanh.main() + + +def test_example_elementwise_silu(): + example_elementwise_silu.main() + + +def test_example_elementwise_leaky_relu(): + example_elementwise_leaky_relu.main() + + +def test_example_elementwise_gelu(): + example_elementwise_gelu.main() + + if __name__ == "__main__": tilelang.testing.main()