diff --git a/examples/scalenorm/README.md b/examples/scalenorm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e441eecbe49211b5d703c41ca972124be491a634 --- /dev/null +++ b/examples/scalenorm/README.md @@ -0,0 +1,132 @@ +# ScaleNorm Operator Implementation with TileLang + +## Operator Description + +### **Application** Scenario + +ScaleNorm (Scaled L2 Normalization) is a normalization technique often used in Transformer architectures as a computationally efficient alternative to LayerNorm. It normalizes activations based on their L2-norm and applies a learnable affine transformation. +The mathematical definition is: +$$ y = \frac{x}{\|x\|_2 + \epsilon} \cdot w + b $$ + +Where: + +* $\|x\|_2$ is the L2-norm across the feature dimension. +* $w$ (Weight) and $b$ (Bias) are learnable parameters. + +### Core Calculation Logic + +This operator combines Reduction and Element-wise operations. It is a Memory-Bound kernel. The execution flow is: + +* Row Loading: Loads a row of data from Global Memory to Shared Memory to minimize repeated global access during reduction. +* Square & Reduce: Computes $x^2$ locally and performs a parallel reduction sum across the feature dimension ($N$) to calculate $\|x\|_2$. +* Fused Affine Transform: Applies the normalization factor, multiplies by weight $w$, adds bias $b$, and writes the result back to Global Memory in a single fused pass. + + +## Interface Parameters + +### Input Parameters + +- Input (Tensor): + - Shape: $(*, N)$, where $*$ means any number of additional dimensions (e.g., Batch Size, Sequence Length) and $N$ represents the feature dimension (Hidden Size) to be normalized. + - Type: `float32`. +- Weight (Tensor): + - Shape: $(N,)$, `float32`. Learnable scale parameter corresponding to the feature dimension. +- Bias (Tensor): + - Shape: $(N,)$, `float32`. Learnable shift parameter corresponding to the feature dimension. +- Output (Tensor): + - Shape: $(*, N)$, `float32`. The output tensor has the same shape as the input. + +### Optional Parameters + +- blk_m: Tile size for the M dimension. + + +## Usage + +### Basic Example + +```python +import torch +import tilelang +from example_scale_norm import scalenorm, ref_program + +# 1. Configuration +M, N = 4096, 4096 +blk_m = 1 + +# Create and Compile Kernel +# The @tilelang.jit decorator compiles the kernel for the MACA target +kernel = scalenorm(M, N, blk_m) + +# Prepare Data +x = torch.randn(M, N, device="cuda", dtype=torch.float32) +weight = torch.ones(N, device="cuda", dtype=torch.float32) +bias = torch.zeros(N, device="cuda", dtype=torch.float32) + +# Run Kernel +y = kernel(x, weight, bias) + +# Validate +y_ref = ref_program(x, weight, bias) +torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# Profile latency +profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) +latency = profiler.do_bench() + +print(f"Latency: {latency:.3f} ms") +``` + +### Running Benchmarks + +```bash +cd mcTileLang/examples/scalenorm +python example_scale_norm.py +``` + +### Running Tests + +using `tilelang.testing` for test + +```bash +python test_scale_norm.py +``` + + +## Performance Report + +### Test Environment + +- **Hardware**: XiCloud C500 +- **Software**: TileLang 0.15, Python 3.10, MACA SDK 2.33.1.12 + +### Benchmark Results + +The TileLang implementation achieves significant speedups compared to PyTorch, especially on small to medium workloads where overhead reduction is critical. + +| Scale (M, N) | Ref Time (ms) | Tile Time (ms) | Speedup | +| :--------------- | :------------ | :------------- | :-------- | +| **(1024, 1024)** | 0.092 | 0.024 | **3.89x** | +| **(4096, 4096)** | 0.486 | 0.124 | **3.90x** | +| **(8192, 8192)** | 1.505 | 0.594 | **2.53x** | + +## Implementation Details + +* **`T.Kernel` & `T.ceildiv`**: + * Used to define the GPU Grid and Block dimensions. + +* **`T.alloc_shared` (Shared Memory)**: + * `input_shared = T.alloc_shared((blk_m, N), dtype)` + +* **`T.alloc_fragment` (Register Memory)**: + * Allocates data directly in GPU registers for the fastest possible access during mathematical computations (square, add, multiply). + +* **`T.copy` (Data Movement)**: + * Provides explicit control over memory transfers (Global $\leftrightarrow$ Shared $\leftrightarrow$ Register). + +* **`T.Parallel`**: + * `for i, j in T.Parallel(blk_m, N):`This construct maps the loops to GPU threads. + +* **`T.reduce_sum` **: + * This is a highly optimized primitive that performs a tree-based reduction across threads using Warp Shuffle or Shared Memory primitives. It is the core component for calculating the $L_2$ norm ($\sum x^2$) efficiently. \ No newline at end of file diff --git a/examples/scalenorm/example_scale_norm.py b/examples/scalenorm/example_scale_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..3256d5a7b93c80a2407271faa26de2bd4345297e --- /dev/null +++ b/examples/scalenorm/example_scale_norm.py @@ -0,0 +1,206 @@ +import torch +import tilelang +import tilelang.language as T +import torch.nn.init as init + +@tilelang.jit( + out_idx=[3], + target="maca", + execution_backend="cython", + pass_configs={"tl.disable_tma_lower": True} +) +def scalenorm(M, N, blk_m): + """ + ScaleNorm Kernel Implementation using TileLang. + Formula: y = w * (x / (norm(x, p=2) + eps)) + b + + Args: + M (int): Number of rows (Batch size * Sequence length). + N (int): Number of columns (Hidden dimension). + blk_m (int): Block size for the M dimension. + + Returns: + The JIT-compiled TileLang primitive function. + """ + dtype = "float" + eps = 1e-5 + + @T.prim_func + def main( + Input: T.Tensor((M, N), dtype), + Weight: T.Tensor((N,), dtype), + Bias: T.Tensor((N,), dtype), + Output: T.Tensor((M, N), dtype) + ): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + # Allocate Shared Memory and Register Fragments + input_shared = T.alloc_shared((blk_m, N), dtype) + + input_local = T.alloc_fragment((blk_m, N), dtype) + sq_local = T.alloc_fragment((blk_m, N), dtype) + + sum_sq = T.alloc_fragment((blk_m,), dtype) + + # Data Loading + T.copy(Input[bx * blk_m : (bx + 1) * blk_m, :], input_shared) + + T.copy(input_shared, input_local) + + # Compute + for i, j in T.Parallel(blk_m, N): + val = input_local[i, j] + sq_local[i, j] = val * val + + T.reduce_sum(sq_local, sum_sq, dim=1) + + for i in T.Parallel(blk_m): + norm_val = T.sqrt(sum_sq[i]) + sum_sq[i] = 1.0 / (norm_val + eps) + + # Apply Affine Transformation + for i, j in T.Parallel(blk_m, N): + # Load Weight and Bias + w_val = Weight[j] + b_val = Bias[j] + x_val = input_local[i, j] + factor = sum_sq[i] + + # Formula: y = w * x * factor + b + input_local[i, j] = w_val * (x_val * factor) + b_val + + # Write Back Results + T.copy(input_local, Output[bx * blk_m : (bx + 1) * blk_m, :]) + + return main + +def ref_program(input, weight, bias, eps=1e-5): + """ + PyTorch Reference Implementation of ScaleNorm. + + Args: + input (torch.Tensor): Input tensor. + weight (torch.Tensor): Weight parameter. + bias (torch.Tensor): Bias parameter. + eps (float): Epsilon for numerical stability. + + Returns: + torch.Tensor: Normalized and scaled output. + """ + # math: weight * input / (norm(input) + eps) + bias + norm = torch.norm(input, p=2, dim=-1, keepdim=True) + scalenorm = weight * input / (norm + eps) + if bias is not None: + scalenorm = scalenorm + bias + return scalenorm + +def run_benchmark(configs, block_config, target="maca"): + """ + Main loop for executing benchmarks across different configurations. + + Args: + configs (list): List of dictionaries containing dataset shapes (M, N). + block_config (dict): Dictionary containing block parameters (blk_m). + target (str): Compilation target (e.g., "maca", "cuda"). + """ + blk_m = block_config["blk_m"] + + # Print Benchmark Header + print(f"\n{'='*80}") + print(f"{'ScaleNorm Operator Benchmark (TileLang vs PyTorch)':^80}") + print(f"{'='*80}") + print(f"Target: {target} | Block Config: blk_m={blk_m}") + print(f"{'-'*80}") + + summary_results = [] + + for cfg in configs: + M, N = cfg["M"], cfg["N"] + scale_str = f"({M}, {N})" + + compile_success = False + error_msg = "" + kernel = None + + # Acquire/Compile JIT Kernel + try: + # The decorator handles JIT compilation returning a callable wrapper + kernel = scalenorm(M, N, blk_m) + compile_success = True + except Exception as e: + error_msg = str(e) + + if not compile_success: + print(f"| {scale_str:<18} | Compilation Failed: {error_msg} |") + summary_results.append((scale_str, "-", "-", "-", "Compile Err")) + continue + + # Automatically manages Tensor generation and memory allocation + profiler = kernel.get_profiler() + + # Correctness Verification + status = "PASSED" + try: + profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) + except Exception as e: + status = "FAILED" + print(f"Validation Error: {e}") + + # Performance Benchmarking + if status == "PASSED": + ref_latency = profiler.do_bench(ref_program, warmup=500) + tl_latency = profiler.do_bench(warmup=500) + speedup = ref_latency / tl_latency if tl_latency > 0 else 0.0 + + print(f"Running Scale {scale_str:<15} ... Done.") + print(f" > Status: {status}(Correctness verified)") + print(f" > Ref Time: {ref_latency:.3f} ms") + print(f" > Tile Time: {tl_latency:.3f} ms") + print(f" > Speedup: {speedup:.2f}x") + print(f"{'-'*80}") + + summary_results.append((scale_str, ref_latency, tl_latency, speedup, status)) + else: + print(f"Running Scale {scale_str:<15} ... Failed Validation.") + print(f"{'-'*80}") + summary_results.append((scale_str, "-", "-", "-", status)) + + # Summary Result Display + print(f"{'='*80}") + # Table Header + header = f"| {'Scale (M, N)':<20} | {'Ref (ms)':>12} | {'Tile (ms)':>12} | {'Speedup':>10} | {'Status':^10} |" + print(header) + print(f"|{'-'*22}|{'-'*14}|{'-'*14}|{'-'*12}|{'-'*12}|") + + for res in summary_results: + scale, ref, tile, spd, st = res + + # Format data + ref_str = f"{ref:.3f}" if isinstance(ref, float) else str(ref) + tile_str = f"{tile:.3f}" if isinstance(tile, float) else str(tile) + spd_str = f"{spd:.2f}x" if isinstance(spd, float) else str(spd) + + row = f"| {scale:<20} | {ref_str:>12} | {tile_str:>12} | {spd_str:>10} | {st:^10} |" + print(row) + print(f"{'='*80}\n") + + +def main(): + """ + Entry point for the benchmark script. + """ + BLOCK_CONFIG = { + "blk_m": 1, + } + + DATA_CONFIGS = [ + {"M": 1024, "N": 1024, "desc": "Small"}, + {"M": 4096, "N": 4096, "desc": "Medium"}, + {"M": 8192, "N": 8192, "desc": "Large"} + ] + + TARGET = "maca" + + run_benchmark(DATA_CONFIGS, BLOCK_CONFIG, TARGET) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/scalenorm/test_scale_norm.py b/examples/scalenorm/test_scale_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d23c4301422c2c8e7f9f6ef5fde5dec4817433 --- /dev/null +++ b/examples/scalenorm/test_scale_norm.py @@ -0,0 +1,13 @@ +import tilelang.testing + +import example_scale_norm +from unittest import mock +import sys + +def test_example_scale_norm(): + with mock.patch.object(sys, 'argv', ["example_scale_norm.py"]): + example_scale_norm.main() + + +if __name__ == "__main__": + tilelang.testing.main() \ No newline at end of file