diff --git a/examples/fusion_operator/README.md b/examples/fusion_operator/README.md new file mode 100644 index 0000000000000000000000000000000000000000..01fdbec03f77190bedf443aa8e32aea7850ebd90 --- /dev/null +++ b/examples/fusion_operator/README.md @@ -0,0 +1,134 @@ +# Operator Fusion Example (Conv+Bias+ReLU) + +This directory contains a complete example of implementing Operator Fusion using TileLang on MACA GPU. Specifically, it demonstrates fusing **Convolution**, **Bias Addition**, and **ReLU Activation** into a single efficient kernel. + +## Overview + +Operator fusion is a critical optimization technique in deep learning compilers. By combining multiple operations into a single kernel, we can: +- **Reduce Memory I/O**: Intermediate results (like the output of Convolution) are kept in registers or shared memory instead of being written to and read back from global memory. +- **Reduce Kernel Launch Overhead**: Only one kernel is launched instead of three. +- **Improve Cache Locality**: Data is processed immediately after generation. + +This example compares a fused implementation against a baseline of three separate kernels (Conv -> Bias -> ReLU). + +## Mathematical Formula + +The fused operation computes: + +``` +Output = ReLU(Conv2d(Input, Weight) + Bias) +``` + +Where: +- `Conv2d` is a standard 2D convolution. +- `Bias` is a channel-wise addition. +- `ReLU(x) = max(x, 0)` is the activation function. + +## Files + +| File | Description | +|------|-------------| +| `example_conv2d_bias_relu_fusion.py` | Main implementation of fused and unfused kernels, with performance benchmarking. | +| `test_example_conv2d_bias_relu_fusion.py` | Unit tests. | +| `README.md` | This documentation. 
| + +## Usage + +### Basic Example + +Here is the core code for defining and compiling the fused kernel: + +```python +import torch +import tilelang +import tilelang.language as T +from example_conv2d_bias_relu_fusion import conv_bias_relu_fusion + +# Define dimensions +N, C, H, W, F = 16, 16, 64, 64, 16 +K, S, D, P = 3, 1, 1, 1 + +# Optimization parameters +block_m, block_n, block_k = 64, 128, 32 +num_stages, threads = 3, 256 + +# Create fused kernel program +program = conv_bias_relu_fusion( + N, C, H, W, F, K, S, D, P, + block_m, block_n, block_k, num_stages, threads +) + +# Compile +kernel = tilelang.compile(program, out_idx=3, target="maca") + +# Prepare data +data = torch.randn(N, H, W, C).cuda().half() +weight = torch.randn(K, K, C, F).cuda().half() +bias = torch.randn(F).cuda().half() + +# Run fused kernel +output = kernel(data, weight, bias) +``` + +### Running Benchmarks + +To compare the performance of the fused kernel against the unfused baseline: + +```bash +python examples/fusion_operator/example_conv2d_bias_relu_fusion.py +``` + +### Running Tests + + +```bash +python examples/fusion_operator/test_example_conv2d_bias_relu_fusion.py +``` +## Performance + +*These are untuned performance results.* + +Tested on MetaX C500 GPU (64GB): + +| Config (NxHxW) | Speedup | Status | +|----------------|---------|--------| +| 16x64x64 | 1.20 | ✅ Pass | +| 32x256x512 | 1.14 | ✅ Pass | +| 32x1024x1024 | 1.40 | ✅ Pass | + +## Fusion Strategy Implementation + +The fusion is implemented by extending the standard Convolution kernel. The key steps are: + +1. **Pipelined Convolution**: + The main computation loop uses `T.Pipelined` to overlap memory loads with Tensor Core GEMM operations. The partial results are accumulated in local registers (`out_local`). + +2. **Bias Loading**: + The Bias tensor is small and reused across spatial dimensions. It is pre-loaded into Shared Memory (`bias_shared`) or accessed directly if small enough. + +3. 
**In-Register Fusion**: + After the convolution loop completes, the results reside in `out_local` (registers). Instead of writing them to global memory immediately: + - **Bias Add**: We add the bias value from shared memory to the accumulator. + - **ReLU**: We apply `T.max(val, 0)` to the result. + + ```python + # Fusion Logic in TileLang + for i, j in T.Parallel(block_M, block_N): + # 1. Add Bias + biased_value = out_local[i, j] + bias_shared[j] + # 2. Apply ReLU + out_shared[i, j] = T.max(biased_value, T.cast(0.0, accum_dtype)) + ``` + +4. **Single Write-Back**: + Only the final activated result is written to Global Memory, saving two rounds of global memory access (writing Conv output and reading it back for Bias/ReLU). + +## Hardware Requirements + +- MACA GPU (MetaX C500 or compatible) +- MACA SDK 2.33.1 or later +- TileLang with MACA support + +## License + +Copyright 2025 MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved. diff --git a/examples/fusion_operator/example_conv2d_bias_relu_fusion.py b/examples/fusion_operator/example_conv2d_bias_relu_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..be3be687f8879cec281da0875f4b940a0a78eb53 --- /dev/null +++ b/examples/fusion_operator/example_conv2d_bias_relu_fusion.py @@ -0,0 +1,377 @@ +import torch +import tilelang +import tilelang.language as T +from tilelang.autotuner import * + + +def ref_program(stride, padding, dilation): + """ + Reference PyTorch implementation of Conv2d + Bias + ReLU. 
+ + Args: + stride: Convolution stride + padding: Convolution padding + dilation: Convolution dilation + + Returns: + Function that takes (A, B, bias) and returns the result + """ + def main(A, B, bias): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation) + C = C + bias.view(1, -1, 1, 1) + C = torch.relu(C) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + return main + + +# 1. Convolution Kernel +def convolution(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, + dtype="float16", accum_dtype="float"): + """ + Standard Convolution Kernel using TileLang. + + Args: + N, C, H, W: Input dimensions (Batch, Channel, Height, Width) + F, K: Output channels, Kernel size + S, D, P: Stride, Dilation, Padding + block_M, block_N, block_K: Tiling parameters + num_stages: Pipeline stages + threads: Number of threads per block + dtype: Data type for computation + accum_dtype: Accumulation data type + + Returns: + TileLang primitive function + """ + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel( + T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), + threads=threads) as (bx, by): + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout({ + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: 
tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + }) + + T.clear(out_local) + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and + (access_w < W)) + data_shared[i, j] = T.if_then_else( + in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + T.copy(out_local, out_shared) + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main + + +# 2. Bias Add Kernel +def bias_add(N, OH, OW, F, block_size=256, dtype="float16"): + """ + Bias Addition Kernel. + + Adds a bias vector to the output of convolution. + + Args: + N, OH, OW, F: Output dimensions + block_size: Thread block size + dtype: Data type + + Returns: + TileLang primitive function + """ + @T.prim_func + def main( + conv_out: T.Tensor((N, OH, OW, F), dtype), + bias: T.Tensor((F,), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + conv_flat = T.Tensor((N * OH * OW, F), dtype, conv_out.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + with T.Kernel(T.ceildiv(N * OH * OW, block_size), threads=block_size) as bx: + for i in T.Parallel(block_size): + row_idx = bx * block_size + i + if row_idx < N * OH * OW: + for f in T.Parallel(F): + out_flat[row_idx, f] = conv_flat[row_idx, f] + bias[f] + + return main + + +# 3. ReLU Kernel +def relu_activation(N, OH, OW, F, block_size=256, dtype="float16"): + """ + ReLU Activation Kernel. + + Applies ReLU(x) = max(x, 0) element-wise. 
+ + Args: + N, OH, OW, F: Input dimensions + block_size: Thread block size + dtype: Data type + + Returns: + TileLang primitive function + """ + @T.prim_func + def main( + input_tensor: T.Tensor((N, OH, OW, F), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + input_flat = T.Tensor((N * OH * OW * F,), dtype, input_tensor.data) + out_flat = T.Tensor((N * OH * OW * F,), dtype, out.data) + + total_elements = N * OH * OW * F + with T.Kernel(T.ceildiv(total_elements, block_size), threads=block_size) as bx: + for i in T.Parallel(block_size): + idx = bx * block_size + i + if idx < total_elements: + input_val = input_flat[idx] + zero_val = T.Cast(dtype, 0.0) + out_flat[idx] = T.if_then_else(input_val > zero_val, input_val, zero_val) + + return main + + +# 4. Fused Kernel (Conv + Bias + ReLU) +def conv_bias_relu_fusion(N, C, H, W, F, K, S, D, P, + block_M, block_N, block_K, num_stages, threads, + dtype="float16", accum_dtype="float"): + """ + Fused Convolution + Bias + ReLU Kernel. + + Performs convolution, adds bias, and applies ReLU in a single kernel + to reduce memory I/O and kernel launch overhead. 
+ + Args: + N, C, H, W: Input dimensions + F, K: Output channels, Kernel size + S, D, P: Stride, Dilation, Padding + block_M, block_N, block_K: Tiling parameters + num_stages: Pipeline stages + threads: Number of threads per block + dtype: Data type + accum_dtype: Accumulation data type + + Returns: + TileLang primitive function + """ + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + bias: T.Tensor((F,), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel( + T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), + threads=threads) as (bx, by): + + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + bias_shared = T.alloc_shared((block_N,), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout({ + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + }) + + T.copy(bias[bx * block_N:(bx + 1) * block_N], bias_shared) + T.clear(out_local) + + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = ((access_h >= 0) and (access_w >= 0) and + (access_h < H) and (access_w < W)) + data_shared[i, j] = T.if_then_else( + in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + + T.copy(kernel_flat[k_iter * block_K, bx * 
block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + # Fused Bias + ReLU + for i, j in T.Parallel(block_M, block_N): + biased_value = out_local[i, j] + bias_shared[j] + out_shared[i, j] = T.max(biased_value, T.cast(0.0, accum_dtype)) + + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main + + +def run_benchmark(config): + """ + Run benchmark for Fused vs Unfused kernels. + + Args: + config: Tuple of (N, C, H, W, F, K, S, D, P) + + Returns: + dict with latency results + """ + N, C, H, W, F, K, S, D, P = config + print(f"\nBenchmark: N={N}, H={H}, W={W}, F={F}, C={C}") + print("-" * 50) + + # Optimization params + block_m = 64 + block_n = 128 + block_k = 32 + num_stages = 3 + threads = 256 + + # Calculate Output Dims + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + + # Compile Kernels + fused_kernel = tilelang.compile( + conv_bias_relu_fusion(N, C, H, W, F, K, S, D, P, + block_m, block_n, block_k, num_stages, threads), + out_idx=3) + + conv_kernel = tilelang.compile( + convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads), + out_idx=2) + + bias_kernel = tilelang.compile( + bias_add(N, OH, OW, F), + out_idx=2) + + relu_kernel = tilelang.compile( + relu_activation(N, OH, OW, F), + out_idx=1) + + # Prepare Data + a = torch.randn(N, H, W, C).cuda().half().contiguous() + b = torch.randn(K, K, C, F).cuda().half().contiguous() + bias = torch.randn(F).cuda().half().contiguous() + + # Correctness Check + ref_out = ref_program(S, P, D)(a, b, bias) + + # Fused Correctness + out_fused = fused_kernel(a, b, bias) + torch.testing.assert_close(out_fused, ref_out, rtol=1e-2, atol=1e-2) + + # Unfused Correctness + conv_out = conv_kernel(a, b) + bias_out = bias_kernel(conv_out, bias) + out_unfused = relu_kernel(bias_out) + torch.testing.assert_close(out_unfused, ref_out, rtol=1e-2, atol=1e-2) + + print(f"Output Tensor Shape: {ref_out.shape}") + + print("✅ 
Correctness: PASSED") + + # Benchmarking + profiler_fused = fused_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + latency_fused = profiler_fused.do_bench(fused_kernel, warmup=100, rep=100) + + profiler_conv = conv_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + latency_conv = profiler_conv.do_bench(conv_kernel, warmup=100, rep=100) + + profiler_bias = bias_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + latency_bias = profiler_bias.do_bench(bias_kernel, warmup=100, rep=100) + + profiler_relu = relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + latency_relu = profiler_relu.do_bench(relu_kernel, warmup=100, rep=100) + + latency_unfused = latency_conv + latency_bias + latency_relu + + print(f"Fused Latency: {latency_fused:.4f} ms") + print(f"Unfused Latency: {latency_unfused:.4f} ms") + print(f"Speedup: {latency_unfused/latency_fused:.2f}x") + + # Cleanup + del a, b, bias, ref_out, out_fused, out_unfused, conv_out, bias_out + del fused_kernel, conv_kernel, bias_kernel, relu_kernel + torch.cuda.empty_cache() + + return { + "config": f"{N}x{H}x{W}", + "fused_ms": latency_fused, + "unfused_ms": latency_unfused, + "speedup": latency_unfused / latency_fused + } + +def main(): + """ + Main entry point for running benchmarks. 
+ """ + print("=" * 60) + print("TileLang Fusion Example (Conv+Bias+ReLU)") + print("=" * 60) + + # Configs: (N, C, H, W, F, K, S, D, P) + configs = [ + (16, 16, 64, 64, 16, 3, 1, 1, 1), + (32, 32, 256, 512, 32, 3, 1, 1, 1), + (32, 32, 1024, 1024, 64, 3, 1, 1, 1), + ] + + results = [] + for config in configs: + try: + result = run_benchmark(config) + results.append(result) + except Exception as e: + print(f"Failed for config {config}: {e}") + import traceback + traceback.print_exc() + + # Summary + print("\n" + "=" * 70) + print("Summary") + print("=" * 70) + print(f"{'Config (NxHxW)':<20} {'Speedup':<10}") + print("-" * 70) + for r in results: + print(f"{r['config']:<20} {r['speedup']:<10.2f}x") + + +if __name__ == "__main__": + main() diff --git a/examples/fusion_operator/test_example_conv2d_bias_relu_fusion.py b/examples/fusion_operator/test_example_conv2d_bias_relu_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f22668490159b3f5dea8c1cdc5742d96ec369c08 --- /dev/null +++ b/examples/fusion_operator/test_example_conv2d_bias_relu_fusion.py @@ -0,0 +1,9 @@ +import tilelang.testing + +import example_conv2d_bias_relu_fusion + +def test_example_convolution(): + example_conv2d_bias_relu_fusion.main() + +if __name__ == "__main__": + tilelang.testing.main()