diff --git a/docs/deeplearning_operators/conv2d_bias_relu_fusion.md b/docs/deeplearning_operators/conv2d_bias_relu_fusion.md new file mode 100644 index 0000000000000000000000000000000000000000..710b797380045e164841dbd6e06a987454e221e4 --- /dev/null +++ b/docs/deeplearning_operators/conv2d_bias_relu_fusion.md @@ -0,0 +1,186 @@ +# Conv2d + Bias + ReLU Fusion +============================= + +
+ Contributor: @TileLang +
+ +:::{warning} + This document is still **experimental** and may be incomplete. + Suggestions and improvements are highly encouraged—please submit a PR! +::: + +:::{tip} +Example code can be found at [`examples/fusion_operator/example_conv2d_bias_relu_fusion.py`](../../examples/fusion_operator/example_conv2d_bias_relu_fusion.py). +::: + +Operator fusion is a critical optimization technique in deep learning compilers. This operator fuses **2D Convolution**, **Bias Addition**, and **ReLU Activation** into a single kernel. By performing these operations together, we reduce global memory round-trips (memory bandwidth pressure) and kernel launch overheads. + +## Operator Functionality + +### Application Scenarios +This fused pattern is extremely common in Convolutional Neural Networks (CNNs): +- **Standard CNN Layers**: Most convolutional layers are immediately followed by a bias addition and a non-linear activation function (like ReLU). +- **Inference Optimization**: Fusing these operations is a standard step in optimizing models for deployment (e.g., TensorRT, TVM). + +### Core Computation Logic +The fused operation computes: + +$$ +\text{output} = \text{ReLU}(\text{Conv2d}(\text{input}, \text{kernel}) + \text{bias}) +$$ + +Where: +1. **Conv2d**: Computes the sliding window dot product. +2. **Bias Add**: Adds a broadcasted bias vector to the convolution result. +3. **ReLU**: Applies $f(x) = \max(0, x)$ element-wise. + +All intermediate results (after convolution, before bias/ReLU) are kept in fast on-chip memory (registers or shared memory) rather than writing back to global memory. + +## Interface Parameters + +### Input Parameters +- **Input Tensor**: Shape `(batch_size, height, width, in_channels)`, data type typically `float16`. +- **Kernel (Weights)**: Shape `(kernel_h, kernel_w, in_channels, out_channels)`, data type matching input. +- **Bias**: Shape `(out_channels,)`, data type matching input. + +### Output Parameters +- **Output Tensor**: Shape `(batch_size, out_height, out_width, out_channels)`. + - `out_height` and `out_width` are calculated based on stride, padding, and dilation. +- Data type matches input. + +### Configuration Parameters +- **Dimensions**: `N, C, H, W` (Input), `F` (Output Channels), `K` (Kernel Size). +- **Convolution Params**: `S` (Stride), `D` (Dilation), `P` (Padding). +- **Tiling Params**: `block_M`, `block_N`, `block_K` for GEMM tiling. + +## Usage Example + +Below is an example of using the fused operator in `TileLang`. + +### Implementation Details + +```python +import torch +import tilelang +import tilelang.language as T +from tilelang.autotuner import * + +def ref_program(stride, padding, dilation): + def main(A, B, bias): + A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W + B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W + C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation) + C = C + bias.view(1, -1, 1, 1) + C = torch.relu(C) + C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C + return C + return main + +def conv_bias_relu_fusion(N, C, H, W, F, K, S, D, P, + block_M, block_N, block_K, num_stages, threads, + dtype="float16", accum_dtype="float"): + KH, KW = K, K + OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 + OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 + + @T.prim_func + def main( + data: T.Tensor((N, H, W, C), dtype), + kernel: T.Tensor((KH, KW, C, F), dtype), + bias: T.Tensor((F,), dtype), + out: T.Tensor((N, OH, OW, F), dtype), + ): + with T.Kernel( + T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), + threads=threads) as (bx, by): + + data_shared = T.alloc_shared((block_M, block_K), dtype) + kernel_shared = T.alloc_shared((block_K, block_N), dtype) + out_local = T.alloc_fragment((block_M, block_N), accum_dtype) + out_shared = T.alloc_shared((block_M, block_N), dtype) + bias_shared = T.alloc_shared((block_N,), dtype) + + kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data) + out_flat = T.Tensor((N * OH * OW, F), dtype, out.data) + + T.annotate_layout({ + out_shared: tilelang.layout.make_swizzled_layout(out_shared), + data_shared: tilelang.layout.make_swizzled_layout(data_shared), + kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared), + }) + + T.copy(bias[bx * block_N:(bx + 1) * block_N], bias_shared) + T.clear(out_local) + + for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages): + for i, j in T.Parallel(block_M, block_K): + k = k_iter * block_K + j + m = by * block_M + i + access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P + access_w = m % OW * S + k // C % KW * D - P + in_bound = ((access_h >= 0) and (access_w >= 0) and + (access_h < H) and (access_w < W)) + data_shared[i, j] = T.if_then_else( + in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0) + + T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared) + T.gemm(data_shared, kernel_shared, out_local) + + # Fused Bias + ReLU + for i, j in T.Parallel(block_M, block_N): + biased_value = out_local[i, j] + bias_shared[j] + out_shared[i, j] = T.max(biased_value, T.cast(0.0, accum_dtype)) + + T.copy(out_shared, out_flat[by * block_M, bx * block_N]) + + return main +``` + +### Data Construction + +```python + # Parameters + N, C, H, W, F, K = 16, 16, 64, 64, 16, 3 + S, D, P = 1, 1, 1 + + # Tiling Config + block_m, block_n, block_k = 64, 128, 32 + num_stages, threads = 3, 256 + + # Data + a = torch.randn(N, H, W, C).cuda().half().contiguous() + b = torch.randn(K, K, C, F).cuda().half().contiguous() + bias = torch.randn(F).cuda().half().contiguous() +``` + +### Operator Call (Compilation and Execution) + +```python + # Compile + fused_kernel = tilelang.compile( + conv_bias_relu_fusion(N, C, H, W, F, K, S, D, P, + block_m, block_n, block_k, num_stages, threads), + out_idx=3) + + # Run + out_fused = fused_kernel(a, b, bias) +``` + +### Result Verification + +```python + # Verify + ref_out = ref_program(S, P, D)(a, b, bias) + torch.testing.assert_close(out_fused, ref_out, rtol=1e-2, atol=1e-2) + print("Verification Passed!") +``` + +## Performance Notes + +### Optimization Suggestions +- **Fusion**: Always prefer fused kernels over separate calls to `conv2d`, `add`, and `relu` to save memory bandwidth. +- **Shared Memory**: Loading bias into shared memory allows efficient reuse across the block. +- **Accumulation**: Perform bias addition and ReLU on the accumulation registers (or fragments) before casting back to the output data type to maintain precision. + +For more details, please refer to the README located in the same directory as the example code: [`examples/fusion_operator/README.md`](../../examples/fusion_operator/README.md).