import math


class MatmulAutotune:
    """Generate candidate tiling configs for autotuning NPU matmul kernels.

    Every config is a list
    ``[BLOCK_M, BLOCK_N, BLOCK_K, SUB_BLOCK_M, SUB_BLOCK_N, SUB_BLOCK_K,
    num_stages, num_warps]`` sized so that the operand tiles fit the
    on-chip buffers (L1, L0A/B, L0C).
    """

    def __init__(self, M, N, K, dtype):
        """
        :param M: first dim of matrix A (rows of the output)
        :param N: second dim of matrix B (cols of the output)
        :param K: second dim of matrix A (the reduction axis)
        :param dtype: torch tensor dtype of the operands (only ``.itemsize``
            is read, so any object exposing ``itemsize`` works)
        """
        self.M = M
        self.N = N
        self.K = K
        self.dtype = dtype
        # Reserve half of L1 for temporaries (otherwise tests fail;
        # the compiler would need to catch the overflow exception).
        self.L1Size = 512 * 1024 // 2
        self.L0Size = 64 * 1024
        self.L0CSize = 128 * 1024

    def next_power_of_2(self, n):
        """Return the largest power of two <= ``n`` (0 for ``n <= 0``).

        NOTE: despite the name this rounds *down*, which is what the tiling
        code needs so a block never exceeds the matrix extent.
        """
        if n <= 0:
            return 0
        return 1 << math.floor(math.log2(n))

    def _oriented(self, b_small, b_large):
        """Map (smaller-axis block, larger-axis block) to (BLOCK_M, BLOCK_N).

        The smaller of M/N keeps its block first when M < N, otherwise the
        order is swapped, mirroring which matrix axis each block belongs to.
        """
        return (b_small, b_large) if self.M < self.N else (b_large, b_small)

    def _k_start(self, block_m, block_n, round_to_pow2=False):
        """Largest BLOCK_K whose A- and B-tiles both fit in L1, clamped to K.

        :param round_to_pow2: additionally round the result down to a power
            of two (used by the heuristic path only).
        """
        block_k = int(self.L1Size // ((block_m + block_n) * self.dtype.itemsize))
        block_k = min(block_k, self.K)
        return self.next_power_of_2(block_k) if round_to_pow2 else block_k

    def _append_k_variants(self, configs, block_m, block_n, block_k_start):
        """Append configs for BLOCK_K in {start, start//2, start//4}.

        SUB_BLOCK_K is capped so one sub-tile of the larger operand fits in
        L0; SUB_BLOCK_M/N always equal BLOCK_M/N. num_stages = num_warps = 1.
        """
        for div in (1, 2, 4):
            block_k = block_k_start // div
            if block_k < 1:
                break
            sub_k = int(self.L0Size // (max(block_m, block_n) * self.dtype.itemsize))
            sub_k = min(sub_k, block_k)
            configs.append([block_m, block_n, block_k,
                            block_m, block_n, sub_k, 1, 1])

    def get_configs(self, configs):
        """Config generation for the 8x8 core-split algorithm.

        :param configs: list that candidate configs are appended to, each of
            shape [BLOCK_M, BLOCK_N, BLOCK_K, SUB_BLOCK_M, SUB_BLOCK_N,
            SUB_BLOCK_K, num_stages, num_warps]
        :return: the same ``configs`` list (mutated in place)
        """
        if self.M >= 64 and self.N >= 64:
            bm_start = self.next_power_of_2(min(self.M, 256))
            # NOTE(review): the N budget divides by bm_start, not the current
            # BLOCK_M — presumably to bound L0C by the largest M tile; confirm
            # this is intentional (it is loop-invariant, so it is hoisted).
            bn_start = int(self.L0CSize // 4 // bm_start)
            for i in (1, 2, 4):
                block_m = bm_start // i
                if block_m < 8:
                    break
                for j in (1, 2, 4):
                    block_n = bn_start // j
                    if block_n < 8:
                        break
                    self._append_k_variants(configs, block_m, block_n,
                                            self._k_start(block_m, block_n))
            return configs

        min_mn = min(self.M, self.N)
        max_mn = max(self.M, self.N)

        # 1) keep the smaller axis unsplit
        b_small = min_mn
        b_large_start = int(self.L0CSize // 4 // b_small)
        b_large_start = self.next_power_of_2(min(b_large_start, max_mn))
        for i in (1, 2, 4):
            b_large = b_large_start // i
            if b_large < 8:
                break
            block_m, block_n = self._oriented(b_small, b_large)
            self._append_k_variants(configs, block_m, block_n,
                                    self._k_start(block_m, block_n))

        # 2) split the smaller axis too: 32 / 16 / 8
        b_small_start = self.next_power_of_2(min_mn)
        for i in (1, 2, 4):
            b_small = b_small_start // i
            if b_small < 8:
                break
            b_large_start = int(self.L0CSize // 4 // b_small)
            for j in (1, 2, 4):
                # clamp and round per split level, as the larger axis may be
                # shorter than the L0C budget allows
                b_large = self.next_power_of_2(min(b_large_start // j, max_mn))
                if b_large < 8:
                    break
                block_m, block_n = self._oriented(b_small, b_large)
                self._append_k_variants(configs, block_m, block_n,
                                        self._k_start(block_m, block_n))
        return configs

    def get_configs_heuristics(self, configs):
        """Generic heuristic tiling algorithm.

        :param configs: list that candidate configs are appended to, each of
            shape [BLOCK_M, BLOCK_N, BLOCK_K, SUB_BLOCK_M, SUB_BLOCK_N,
            SUB_BLOCK_K, num_stages, num_warps]
        :return: the same ``configs`` list (mutated in place)
        """
        min_mn = min(self.M, self.N)
        max_mn = max(self.M, self.N)

        if min_mn < 128:
            # extra candidates that keep the smaller axis unsplit
            b_small = min_mn
            b_large_start = int(self.L0CSize // 4 // b_small)
            b_large_start = self.next_power_of_2(min(max_mn, b_large_start))
            for j in (1, 2, 4, 8):
                b_large = b_large_start // j
                if b_large < 8:
                    break
                block_m, block_n = self._oriented(b_small, b_large)
                self._append_k_variants(
                    configs, block_m, block_n,
                    self._k_start(block_m, block_n, round_to_pow2=True))

        b_small_start = self.next_power_of_2(min(min_mn, 128))
        for i in (1, 2, 4, 8):
            b_small = b_small_start // i
            if b_small < 8:
                break
            b_large_start = int(self.L0CSize // 4 // b_small)
            b_large_start = self.next_power_of_2(min(max_mn, b_large_start))
            for j in (1, 2, 4, 8):
                b_large = b_large_start // j
                if b_large < 8:
                    break
                block_m, block_n = self._oriented(b_small, b_large)
                self._append_k_variants(
                    configs, block_m, block_n,
                    self._k_start(block_m, block_n, round_to_pow2=True))
        return configs