diff --git a/examples/tanhshrink/README.md b/examples/tanhshrink/README.md new file mode 100644 index 0000000000000000000000000000000000000000..65bc95706c61e5027f550ed91828c6ae329be3d7 --- /dev/null +++ b/examples/tanhshrink/README.md @@ -0,0 +1,105 @@ +# TanhShrink Operator Implementation with TileLang + +## Operator Description + +### Application Scenario +`TanhShrink` is an activation function defined mathematically as: +$$ f(x) = x - \tanh(x) $$ + +It is commonly used in deep neural networks, particularly in scenarios involving residual learning or where sparse activation is beneficial. As an Element-wise operation, it is typically applied in: +- Nonlinear activation for hidden layers. +- Feature processing in Generative Adversarial Networks (GANs) or specific regression tasks. + +### Core Calculation Logic +This operator is classified as a **Memory-Bound** kernel. The core implementation logic is as follows: + +* Data Loading: Utilizes TileLang's automatic vectorization to efficiently load data from Global Memory directly into Register Fragments. + +* Parallel Computation: The input tensor `(M, N)` is divided into tiles, each processed by a Thread Block. Inside the block, threads perform the $x - \tanh(x)$ calculation in parallel using SIMT execution. +* Write Back: The computed results are written back directly to Global Memory. +* Optimization: Since no cross-thread reduction is required, the implementation skips Shared Memory and computes directly in registers to minimize latency and maximize bandwidth utilization. + +## Interface Parameters + +### Input Parameters +- A (Input Tensor): + - Shape: `(M, N)`, 2D Tensor. + - Data Type: `float` (float32). +- B (Output Tensor): + - Shape: `(M, N)`, 2D Tensor (Buffer for results). + - Data Type: `float` (float32). + +### Optional Parameters +- block_M: Tile size for the M dimension. +- block_N: Tile size for the N dimension. + +## Usage + +### Basic Example + +```python +import torch +import tilelang +from example_tanh_shrink import tanh_shrink, ref_program + +# Define dimensions +M, N = 4096, 4096 +block_M, block_N = 64, 128 + +# Create kernel +program = tanh_shrink(M, N, block_M, block_N) +kernel = tilelang.compile( + program, + target="maca", + execution_backend="cython", + pass_configs={"tl.disable_tma_lower": True} +) + +# Run kernel +x = torch.randn(M, N, device="cuda", dtype=torch.float32) +y = torch.empty_like(x) +kernel(x, y) + +# Validate +y_ref = ref_program(x) +torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# Profile latency +profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) +latency = profiler.do_bench() + +print(f"Latency: {latency:.3f} ms") +``` + +### Running Benchmarks + +The provided script includes a built-in benchmark suite that tests the operator across multiple scales and compares performance against PyTorch. + +```bash +cd mcTileLang/examples/tanhshrink +python example_tanh_shrink.py +``` + +### Running Tests + +using `tilelang.testing` for test + +```bash +python test_tanh_shrink.py +``` + +--- + +## Performance Report + +### Test Environment +- **Hardware**: XiCloud C500 +- **Software**: TileLang 0.15, Python 3.10, MACA SDK 2.33.1.12 + +### Benchmark Results +| Scale (M, N) | Ref Time (ms) | Tile Time (ms) | Speedup | +| :--------------- | :------------ | :------------- | :-------- | +| **(1024, 1024)** | 0.032 | 0.026 | **1.21x** | +| **(4096, 4096)** | 0.289 | 0.122 | **2.36x** | +| **(8192, 8192)** | 0.956 | 0.436 | **2.19x** | diff --git a/examples/tanhshrink/example_tanh_shrink.py b/examples/tanhshrink/example_tanh_shrink.py new file mode 100644 index 0000000000000000000000000000000000000000..ff88271c8c098775690c3aa2a851f6cd70c800a8 --- /dev/null +++ b/examples/tanhshrink/example_tanh_shrink.py @@ -0,0 +1,170 @@ +import torch +import tilelang +import tilelang.language as T +import torch.nn.functional as F + +def tanh_shrink(M, N, block_M, block_N): + """ + TanhShrink Operator Implementation using TileLang. + Formula: y = x - tanh(x) + + Args: + M (int): Number of rows. + N (int): Number of columns. + block_M (int): Block size for M dimension. + block_N (int): Block size for N dimension. + + Returns: + The compiled TileLang primitive function. + """ + dtype = "float" + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by): + # Allocate Register Fragments + val_local = T.alloc_fragment((block_M, block_N), dtype) + + # Load Data (Global -> Register) + T.copy(A[bx * block_M, by * block_N], val_local) + + # Parallel Computation + for i, j in T.Parallel(block_M, block_N): + x = val_local[i, j] + # Formula: x - tanh(x) + val_local[i, j] = x - T.tanh(x) + + # Store Data (Register -> Global) + T.copy(val_local, B[bx * block_M, by * block_N]) + + return main + +def ref_program(x): + """ + PyTorch reference implementation for TanhShrink. + """ + return F.tanhshrink(x) + +def run_benchmark(configs, block_config, target="maca"): + """ + Main loop for executing benchmarks across different configurations. + + Args: + configs (list): List of dictionaries containing dataset shapes (M, N). + block_config (dict): Dictionary containing block sizes (block_M, block_N). + target (str): Compilation target (default: "maca"). + """ + block_M = block_config["block_M"] + block_N = block_config["block_N"] + + # Print table header + print(f"\n{'='*80}") + print(f"{'TanhShrink Operator Benchmark (TileLang vs PyTorch)':^80}") + print(f"{'='*80}") + print(f"Target: {target} | Block Config: M={block_M}, N={block_N}") + print(f"{'-'*80}") + + # Store summary results + summary_results = [] + + for cfg in configs: + M, N = cfg["M"], cfg["N"] + scale_str = f"({M}, {N})" + + # Build and Compile TileLang Kernel + program = tanh_shrink(M, N, block_M, block_N) + + compile_success = False + error_msg = "" + + try: + kernel = tilelang.compile( + program, + out_idx=1, + target=target, + execution_backend="cython", + pass_configs={"tl.disable_tma_lower": True} + ) + compile_success = True + except Exception as e: + error_msg = str(e) + + if not compile_success: + print(f"| {scale_str:<18} | Compilation Failed: {error_msg} |") + summary_results.append((scale_str, "-", "-", "-", "Compile Err")) + continue + + # Automatically manages Tensor generation and memory allocation + profiler = kernel.get_profiler() + + # Correctness Verification + status = "PASSED" + try: + profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) + except Exception as e: + status = "FAILED" + print(f"Validation Error: {e}") + + # Performance Benchmarking + if status == "PASSED": + ref_latency = profiler.do_bench(ref_program, warmup=500) + tl_latency = profiler.do_bench(warmup=500) + speedup = ref_latency / tl_latency if tl_latency > 0 else 0.0 + + print(f"Running Scale {scale_str:<15} ... Done.") + print(f" > Status: {status}(Correctness verified)") + print(f" > Ref Time: {ref_latency:.3f} ms") + print(f" > Tile Time: {tl_latency:.3f} ms") + print(f" > Speedup: {speedup:.2f}x") + print(f"{'-'*80}") + + summary_results.append((scale_str, ref_latency, tl_latency, speedup, status)) + else: + print(f"Running Scale {scale_str:<15} ... Failed Validation.") + print(f"{'-'*80}") + summary_results.append((scale_str, "-", "-", "-", status)) + + # Summary Result Display + print(f"{'='*80}") + # Table Header + header = f"| {'Scale (M, N)':<20} | {'Ref (ms)':>12} | {'Tile (ms)':>12} | {'Speedup':>10} | {'Status':^10} |" + print(header) + print(f"|{'-'*22}|{'-'*14}|{'-'*14}|{'-'*12}|{'-'*12}|") + + for res in summary_results: + scale, ref, tile, spd, st = res + + # Format data + ref_str = f"{ref:.3f}" if isinstance(ref, float) else str(ref) + tile_str = f"{tile:.3f}" if isinstance(tile, float) else str(tile) + spd_str = f"{spd:.2f}x" if isinstance(spd, float) else str(spd) + + row = f"| {scale:<20} | {ref_str:>12} | {tile_str:>12} | {spd_str:>10} | {st:^10} |" + print(row) + print(f"{'='*80}\n") + +def main(): + """ + Entry point for the benchmark script. + Configures parameters and runs the benchmark suite. + """ + # Block Configuration + BLOCK_CONFIG = { + "block_M": 64, + "block_N": 128 + } + + # Dataset Scale Configuration + DATA_CONFIGS = [ + {"M": 1024, "N": 1024, "desc": "Small"}, + {"M": 4096, "N": 4096, "desc": "Medium"}, + {"M": 8192, "N": 8192, "desc": "Large"} + ] + + # Testing target + TARGET = "maca" + + run_benchmark(DATA_CONFIGS, BLOCK_CONFIG, TARGET) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/tanhshrink/test_tanh_shrink.py b/examples/tanhshrink/test_tanh_shrink.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9b6f8585857f6e5fa6a562f539738f0ee058fa --- /dev/null +++ b/examples/tanhshrink/test_tanh_shrink.py @@ -0,0 +1,13 @@ +import tilelang.testing + +import example_tanh_shrink +from unittest import mock +import sys + +def test_example_tanh_shrink(): + with mock.patch.object(sys, 'argv', ["example_tanh_shrink.py"]): + example_tanh_shrink.main() + + +if __name__ == "__main__": + tilelang.testing.main() \ No newline at end of file