diff --git a/docs/deeplearning_operators/convolution.md b/docs/deeplearning_operators/convolution.md
index 7477c569786fcb859cb5f16560adea94b881aa01..90b764e146bb8023bda3e0ab2ae2d3131c77d53a 100644
--- a/docs/deeplearning_operators/convolution.md
+++ b/docs/deeplearning_operators/convolution.md
@@ -1,2 +1,248 @@
-Convolution
-===========
+# Convolution
+=============
+
+
+
+:::{warning}
+ This document is still **experimental** and may be incomplete.
+ Suggestions and improvements are highly encouraged—please submit a PR!
+:::
+
+:::{tip}
+Example code can be found at [`examples/convolution/example_convolution.py`](../../examples/convolution/example_convolution.py).
+:::
+
+Convolution is a fundamental operation in deep learning, widely used for feature extraction in image processing, computer vision tasks, and neural networks like Convolutional Neural Networks (CNNs). It applies a kernel (filter) to an input tensor, computing dot products over sliding windows to produce an output feature map.
+
+## Operator Functionality
+
+### Application Scenarios
+Convolution operators are essential in deep learning for tasks such as:
+- **Image Feature Extraction**: Detecting edges, textures, and patterns in images.
+- **Computer Vision**: Object detection, image classification, and segmentation in models like ResNet, VGG, and YOLO.
+- **Signal Processing**: Applied to 1D signals (e.g., audio) or 3D data (e.g., video).
+
+### Core Computation Logic
+The convolution operation computes the output feature map by sliding a kernel over the input tensor and performing element-wise multiplication followed by summation. For a 2D convolution, the output at position (i, j) is calculated as:
+
+$$
+\text{output}[i][j] = \sum_{m=0}^{k_h - 1} \sum_{n=0}^{k_w - 1} \text{input}[i \cdot s_h + m][j \cdot s_w + n] \cdot \text{kernel}[m][n]
+$$
+
+Where:
+- \( k_h, k_w \): Kernel height and width.
+- \( s_h, s_w \): Stride in height and width.
+- Padding is applied to handle boundaries.
+
+This can be extended to multi-channel inputs and outputs with bias addition.
+
+## Interface Parameters
+
+### Input Parameters
+- **Input Tensor**: Shape `(batch_size, in_channels, height, width)`, data type typically `float16` or `float32`.
+- **Kernel (Weights)**: Shape `(out_channels, in_channels, kernel_height, kernel_width)`, data type matching input.
+- **Bias (Optional)**: Shape `(out_channels,)`, data type matching input.
+
+### Output Parameters
+- **Output Tensor**: Shape `(batch_size, out_channels, out_height, out_width)`, where:
+ - `out_height = (height + 2 * padding_h - kernel_height) // stride_h + 1`
+ - `out_width = (width + 2 * padding_w - kernel_width) // stride_w + 1`
+- Data type matches input.
+
+### Optional Parameters
+- **Kernel Size**: `(kernel_height, kernel_width)`, e.g., `(3, 3)`.
+- **Stride**: `(stride_h, stride_w)`, e.g., `(1, 1)`.
+- **Padding**: `(padding_h, padding_w)`, e.g., `(1, 1)`.
+- **Dilation**: `(dilation_h, dilation_w)`, default `(1, 1)`.
+- **Groups**: Integer for grouped convolution, default 1.
+
+## Usage Example
+
+Below is an example of using the convolution operator in `TileLang` within a MACA environment. The example is divided into sections for clarity: imports and parameter setup, data construction, operator call (compilation and execution), and result verification.
+
+### Imports and Parameter Setup
+
+```python
+import torch
+import tilelang
+from tilelang.autotuner import *
+import tilelang.language as T
+import argparse
+
+def check_hopper():
+ if not torch.cuda.is_available():
+ return None
+ props = torch.cuda.get_device_properties(0)
+ compute_capability = props.major, props.minor
+ return compute_capability == (9, 0)
+
+def ref_program(stride, padding, dilation):
+ def main(A, B):
+ A = A.permute(0, 3, 1, 2) # N, H, W, C -> N, C, H, W
+ B = B.permute(3, 2, 0, 1) # H, W, C, F -> F, C, H, W
+ C = torch.conv2d(A, B, stride=stride, padding=padding, dilation=dilation)
+ C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
+ return C
+ return main
+
+@tilelang.jit(out_idx=[2])
+def convolution(N,
+ C,
+ H,
+ W,
+ F,
+ K,
+ S,
+ D,
+ P,
+ block_M,
+ block_N,
+ block_K,
+ num_stages,
+ threads,
+ dtype="float16",
+ accum_dtype="float"):
+ # Define kernel dimensions
+ KH, KW = K, K
+ # Compute output dimensions
+ OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
+ OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
+ dtype = "float16"
+ accum_dtype = "float"
+ # Check if running on Hopper GPU for optimized im2col
+ is_hopper = check_hopper()
+
+ @T.prim_func
+ def main(
+ data: T.Tensor((N, H, W, C), dtype),
+ kernel: T.Tensor((KH, KW, C, F), dtype),
+ out: T.Tensor((N, OH, OW, F), dtype),
+ ):
+ # Define kernel with block dimensions
+ with T.Kernel(
+ T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
+ threads=threads) as (bx, by):
+ # Allocate shared memory for data, kernel, and output
+ data_shared = T.alloc_shared((block_M, block_K), dtype)
+ kernel_shared = T.alloc_shared((block_K, block_N), dtype)
+ out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+ out_shared = T.alloc_shared((block_M, block_N), dtype)
+
+ # Flatten kernel for efficient access
+ kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
+ out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
+
+ # Annotate layouts for swizzled memory access
+ T.annotate_layout({
+ out_shared: tilelang.layout.make_swizzled_layout(out_shared),
+ data_shared: tilelang.layout.make_swizzled_layout(data_shared),
+ kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
+ })
+
+ # Clear output accumulator
+ T.clear(out_local)
+ # Pipelined loop over K dimension
+ for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
+ # Load data using im2col for Hopper or manual indexing otherwise
+ if is_hopper:
+ T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
+ else:
+ for i, j in T.Parallel(block_M, block_K):
+ k = k_iter * block_K + j
+ m = by * block_M + i
+ # Compute input coordinates
+ access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
+ access_w = m % OW * S + k // C % KW * D - P
+ in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
+ (access_w < W))
+ data_shared[i, j] = T.if_then_else(
+ in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
+ # Copy kernel to shared memory
+ T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
+ # Perform GEMM operation
+ T.gemm(data_shared, kernel_shared, out_local)
+
+ # Copy results back to global memory
+ T.copy(out_local, out_shared)
+ T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+
+ return main
+```
+
+### Data Construction
+
+```python
+def main(argv=None):
+ # Parse command-line arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--n', type=int, default=128, help='n')
+ parser.add_argument('--c', type=int, default=128, help='c')
+ parser.add_argument('--h', type=int, default=64, help='h')
+ parser.add_argument('--w', type=int, default=64, help='w')
+ parser.add_argument('--f', type=int, default=128, help='f')
+ parser.add_argument('--k', type=int, default=3, help='k')
+ parser.add_argument('--s', type=int, default=1, help='s')
+ parser.add_argument('--d', type=int, default=1, help='d')
+ parser.add_argument('--p', type=int, default=1, help='p')
+
+ args = parser.parse_args(argv)
+ N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p
+ # Generate random input and kernel tensors
+ a = torch.randn(N, H, W, C).cuda().half()
+ b = torch.randn(K, K, C, F).cuda().half()
+
+ a = a.contiguous()
+ b = b.contiguous()
+```
+
+### Operator Call (Compilation and Execution)
+
+```python
+ # Set block and thread parameters
+ block_m = 64
+ block_n = 128
+ block_k = 32
+ num_stages = 3
+ threads = 256
+ # Compile the convolution kernel
+ kernel = convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads)
+
+ # Execute the kernel
+ out_c = kernel(a, b)
+```
+
+### Result Verification
+
+```python
+ # Generate reference output using PyTorch
+ ref_c = ref_program(S, P, D)(a, b)
+ # Verify results match within tolerance
+ torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
+ print("All checks passed.✅")
+
+
+if __name__ == "__main__":
+ main()
+```
+
+This example demonstrates an optimized convolution kernel using im2col and GEMM. For production use, further optimizations and hardware-specific tuning refer to the example code at [`examples/convolution/example_convolution_autotune.py`](../../examples/convolution/example_convolution_autotune.py ).
+
+## Performance Notes
+
+### Recommended Configurations on Specific Devices
+
+- **GPU Model**: 曦云 C500-64GB
+- **Optimal Config**: {'block_M': 64, 'block_N': 128, 'block_K': 32, 'num_stages': 1, 'thread_num': 256, 'enable_rasteration': True}
+ - **Note:** This optimal configuration was obtained by running `python examples/convolution/example_convolution_autotune.py` with the example's default convolution parameters. It applies only to those default parameters. To obtain the best configuration for other convolution settings, re-run the autotuner and specify the desired convolution arguments via command-line options (for example: `--n`, `--c`, `--h`, `--w`, `--f`, `--k`, `--s`, `--d`, `--p`).
+
+
+### Performance Optimization Suggestions
+- **Tiling**: Divide input and output into tiles to fit in shared memory, reducing redundant loads.
+- **Vectorized Loads**: Use `tl.vectorized` for loading kernel and input data in chunks (e.g., float4 for float16).
+- **Autotuning**: Use `tilelang.autotune` to search optimal block sizes, thread counts, and tiling factors.
+- **Fusions**: Combine convolution with activation functions (e.g., ReLU) to reduce kernel launches.
+- **Hardware Awareness**: Leverage tensor cores on NVIDIA GPUs or Metax GPUs for mixed-precision computations.
+
+For advanced implementations, refer to the example code at [`examples/convolution/example_convolution.py`](../../examples/convolution/example_convolution.py ). Contributions to improve this documentation are welcome!
\ No newline at end of file
diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py
index e37dac280918ed1b1daac84701c206dcf7e4ca70..8a715ea5af2165320233b89320a7b1a27724eb40 100644
--- a/examples/convolution/example_convolution.py
+++ b/examples/convolution/example_convolution.py
@@ -111,6 +111,9 @@ def main(argv=None):
a = torch.randn(N, H, W, C).cuda().half()
b = torch.randn(K, K, C, F).cuda().half()
+ a = a.contiguous()
+ b = b.contiguous()
+
block_m = 64
block_n = 128
block_k = 32
@@ -123,6 +126,7 @@ def main(argv=None):
out_c = kernel(a, b)
ref_c = ref_program(S, P, D)(a, b)
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
+ print("All checks passed.")
if __name__ == "__main__":
diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py
index bb3a147a5a0741d8156012b2bb0e53cf48f166e2..a0122757595f2e69cfb4144a9f5a29b3c1fba416 100644
--- a/examples/convolution/example_convolution_autotune.py
+++ b/examples/convolution/example_convolution_autotune.py
@@ -6,7 +6,7 @@ from tilelang.autotuner import *
import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import ConvTemplate
-from tilelang.carver.arch import CUDA
+from tilelang.carver.arch import MACA
from tilelang.carver.roller.rasterization import NoRasterization
@@ -32,7 +32,7 @@ def ref_program(stride, padding, dilation):
def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
if with_roller:
- arch = CUDA("cuda")
+ arch = MACA("maca")
carve_template = ConvTemplate(
N=N,
C=C,
@@ -291,14 +291,15 @@ def main(n: int = 128,
with_roller: bool = True):
N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p
ref_prog = ref_program(S, P, D)
- use_autotune = True
+ # use_autotune = True
if use_autotune:
result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
print(result.config)
kernel = result.kernel
else:
config = get_heuristic_config()
- kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_dix=[2])
+ kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, config["block_M"], config["block_N"], config["block_K"],
+ config["num_stages"], config["thread_num"]), out_idx=[2])
profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
@@ -327,7 +328,7 @@ if __name__ == "__main__":
parser.add_argument(
"--with_roller",
action="store_true",
- default=True,
+ default=False,
help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
main(args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p, args.use_autotune,