From 35f2b6369db9a1f63f7eb60108fcb4129aa06787 Mon Sep 17 00:00:00 2001
From: lxy
Date: Thu, 18 Sep 2025 15:23:20 +0800
Subject: [PATCH] Add flagtree dot hint support files

---
 ...atrix-multiplication-optimized-flagtree.py | 214 ++++++++++++++++++
 .../triton_patch/compiler/code_generator.py   |  25 ++
 .../python/triton_patch/runtime/jit.py        |  21 ++
 3 files changed, 260 insertions(+)
 create mode 100644 ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py

diff --git a/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py b/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py
new file mode 100644
index 0000000..e1c0d22
--- /dev/null
+++ b/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py
@@ -0,0 +1,214 @@
+import torch
+import torch_npu
+import triton
+import triton.language as tl
+import triton.runtime.driver as driver
+
+
+# get device properties of the NPU
+def get_npu_properties():
+    device = torch.npu.current_device()
+    return driver.active.utils.get_device_properties(device)
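+
+
+# Of the returned properties, this tutorial only uses "num_aicore" (the number
+# of AI cores): triton_matmul below launches one persistent program per core,
+# and each program strides over the task blocks with that count.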
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_THRESHOLD": 4}),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_THRESHOLD": 5}),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_THRESHOLD": 6}),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_THRESHOLD": 7}),
+        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_THRESHOLD": 8}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_THRESHOLD": 4}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_THRESHOLD": 5}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_THRESHOLD": 6}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_THRESHOLD": 7}),
+        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_THRESHOLD": 8}),
+    ],
+    key=["M", "N", "K"],
+)
+@triton.jit
+def matmul_kernel(
+    mat_a, mat_b, mat_c,
+    M: tl.constexpr,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    num_cores: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    BLOCK_THRESHOLD: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    task_m_idx = 0
+    task_n_idx = 0
+
+    '''
+    With horizontal (row-major) core assignment, the task blocks are numbered:
+    [0,  1,  2,  3,  4,  5,  6,  7]
+    [8,  9,  10, 11, 12, 13, 14, 15]
+    [16, 17, 18, 19, 20, 21, 22, 23]
+    [24, 25, 26, 27, 28, 29, 30, 31]
+    [32, 33, 34, 35, 36, 37, 38, 39]
+    [40, 41, 42, 43, 44, 45, 46, 47]
+    [48, 49, 50, 51, 52, 53, 54, 55]
+    [56, 57, 58, 59, 60, 61, 62, 63]
+    With 20 cores, core 0 handles blocks 0, 20, 40, 60 (4 blocks),
+    core 1 handles blocks 1, 21, 41, 61 (4 blocks),
+    core 2 handles blocks 2, 22, 42, 62 (4 blocks),
+    ...
+    core 19 handles blocks 19, 39, 59 (3 blocks).
+
+    For large shapes, this traditional horizontal split has two problems:
+    1. Many cores access the same region of the left matrix at the same time,
+       causing bank conflicts that lower memory-access efficiency.
+    2. Computing one full row of mat_c touches every element of the right
+       matrix; when the right matrix exceeds the L2 cache capacity, it is
+       repeatedly moved in and out of L2, so each subsequent row of work
+       incurs cache misses, lowering the L2 hit rate and slowing the kernel.
+    Assigning blocks to cores along the diagonal of 8 x 8 tiles largely
+    mitigates both problems.
+
+    An 8 x 8 diagonal split is used here as the example; in practice the tile
+    size is the autotuned parameter BLOCK_THRESHOLD. Within each 8 x 8 tile,
+    the task blocks are numbered:
+    [0,  8,  16, 24, 32, 40, 48, 56]
+    [57, 1,  9,  17, 25, 33, 41, 49]
+    [50, 58, 2,  10, 18, 26, 34, 42]
+    [43, 51, 59, 3,  11, 19, 27, 35]
+    [36, 44, 52, 60, 4,  12, 20, 28]
+    [29, 37, 45, 53, 61, 5,  13, 21]
+    [22, 30, 38, 46, 54, 62, 6,  14]
+    [15, 23, 31, 39, 47, 55, 63, 7]
+
+    When the M axis spans more than 8 basic blocks, the diagonal split
+    noticeably reduces bank conflicts; when the right matrix is larger than
+    the L2 cache, it also improves L2 utilization. Enabling the diagonal
+    split whenever the grid exceeds the threshold in both M and N is
+    therefore a win, and the gain is largest when the right matrix exceeds
+    the L2 cache size.
+    '''
+    NUM_BLOCKS_M = tl.cdiv(M, BLOCK_M)
+    NUM_BLOCKS_N = tl.cdiv(N, BLOCK_N)
+    NUM_BLOCKS = NUM_BLOCKS_M * NUM_BLOCKS_N
+    # With enough blocks in both dimensions, enable the diagonal split.
+    if NUM_BLOCKS_M >= BLOCK_THRESHOLD and NUM_BLOCKS_N >= BLOCK_THRESHOLD:
+        for block_idx in range(pid, NUM_BLOCKS, num_cores):
+            # Diagonal (BLOCK_THRESHOLD x BLOCK_THRESHOLD) split: map the flat
+            # block index to (task_m_idx, task_n_idx) along the diagonal,
+            # handling partial tiles at the grid edges.
+            curThresholdM = BLOCK_THRESHOLD if block_idx < (NUM_BLOCKS_M // BLOCK_THRESHOLD * BLOCK_THRESHOLD) * NUM_BLOCKS_N else NUM_BLOCKS_M % BLOCK_THRESHOLD
+            curThresholdM_thresholdN = curThresholdM * BLOCK_THRESHOLD
+            curThresholdN = BLOCK_THRESHOLD if block_idx % (NUM_BLOCKS_N * BLOCK_THRESHOLD) < (curThresholdM * NUM_BLOCKS_N) // curThresholdM_thresholdN * curThresholdM_thresholdN else NUM_BLOCKS_N % BLOCK_THRESHOLD
+            localRelativeBlock = block_idx % (BLOCK_THRESHOLD * NUM_BLOCKS_N) % (BLOCK_THRESHOLD * curThresholdM)
+            task_m_idx = localRelativeBlock % curThresholdM + block_idx // (BLOCK_THRESHOLD * NUM_BLOCKS_N) * BLOCK_THRESHOLD
+            # Compute the LCM of the current tile extents (via Euclid's GCD);
+            # it is needed to recover the block's column coordinate.
+            x, y = (curThresholdM, curThresholdN) if curThresholdM > curThresholdN else (curThresholdN, curThresholdM)
+            while y != 0:
+                x, y = y, x % y
+            lcm = curThresholdM * curThresholdN // x
+            task_n_idx = (localRelativeBlock + (localRelativeBlock // lcm)) % curThresholdN + block_idx % (BLOCK_THRESHOLD * NUM_BLOCKS_N) // curThresholdM_thresholdN * BLOCK_THRESHOLD
+
+            m_start = task_m_idx * BLOCK_M
+            n_start = task_n_idx * BLOCK_N
+
+            mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+            for k_start in range(0, K, BLOCK_K):
+                mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (
+                    k_start + tl.arange(0, BLOCK_K)
+                )[None, :]
+                mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & (
+                    (k_start + tl.arange(0, BLOCK_K)) < K
+                )[None, :]
+                mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0)  # @hint: dot_pad_only_k
+                mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (
+                    n_start + tl.arange(0, BLOCK_N)
+                )[None, :]
+                mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & (
+                    (n_start + tl.arange(0, BLOCK_N)) < N
+                )[None, :]
+                mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0)  # @hint: dot_pad_only_k
+                mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block)
+            mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (
+                n_start + tl.arange(0, BLOCK_N)
+            )[None, :]
+            mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & (
+                (n_start + tl.arange(0, BLOCK_N)) < N
+            )[None, :]
+            tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask)
+    else:
+        # Traditional sequential (horizontal) split.
+        for block_idx in range(pid, NUM_BLOCKS, num_cores):
+            task_m_idx = block_idx // NUM_BLOCKS_N
+            task_n_idx = block_idx % NUM_BLOCKS_N
+            m_start = task_m_idx * BLOCK_M
+            n_start = task_n_idx * BLOCK_N
+
+            mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+            for k_start in range(0, K, BLOCK_K):
+                mat_a_offset = ((m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (
+                    k_start + tl.arange(0, BLOCK_K)
+                )[None, :]
+                mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & (
+                    (k_start + tl.arange(0, BLOCK_K)) < K
+                )[None, :]
+                mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0)  # @hint: dot_pad_only_k
+                mat_b_offset = ((k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (
+                    n_start + tl.arange(0, BLOCK_N)
+                )[None, :]
+                mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & (
+                    (n_start + tl.arange(0, BLOCK_N)) < N
+                )[None, :]
+                mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0)  # @hint: dot_pad_only_k
+                mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block)
+            mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (
+                n_start + tl.arange(0, BLOCK_N)
+            )[None, :]
+            mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & (
+                (n_start + tl.arange(0, BLOCK_N)) < N
+            )[None, :]
+            tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask)
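+
+
+# A small host-side sketch (illustrative only; never called by the kernel): it
+# reproduces the 8 x 8 diagonal numbering pictured in the kernel docstring for
+# the full-tile case, where block v lands in row v % 8 and column
+# (v // 8 + v % 8) % 8. The kernel's index math above additionally handles
+# partial tiles at the grid edges.
+def diagonal_numbering_demo(threshold=8):
+    grid = [[0] * threshold for _ in range(threshold)]
+    for v in range(threshold * threshold):
+        row = v % threshold
+        col = (v // threshold + row) % threshold
+        grid[row][col] = v
+    # e.g. grid[1] == [57, 1, 9, 17, 25, 33, 41, 49], matching the docstring
+    return grid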
+
+
+def triton_matmul(
+    mat_a,
+    mat_b,
+):
+    m = mat_a.shape[0]
+    k = mat_a.shape[1]
+    n = mat_b.shape[1]
+    mat_c = torch.empty(m, n, dtype=mat_a.dtype, device=mat_a.device)
+
+    '''
+    The NPU is friendlier to 512-byte-aligned accesses; the tiling below
+    performs well in general, and autotune picks the best variant:
+    BLOCK_M = 128
+    BLOCK_N = 256
+    BLOCK_K = 256
+    '''
+
+    num_cores = get_npu_properties()["num_aicore"]
+
+    matmul_kernel[(num_cores,)](
+        mat_a,
+        mat_b,
+        mat_c,
+        m,
+        n,
+        k,
+        num_cores,
+    )
+    return mat_c
+
+
+if __name__ == "__main__":
+    M = 2048
+    K = 7168
+    N = 16384
+
+    mat_a = torch.randn([M, K], dtype=torch.bfloat16, device="npu")
+    mat_b = torch.randn([K, N], dtype=torch.bfloat16, device="npu")
+
+    result = triton_matmul(mat_a, mat_b)
+    golden = torch.matmul(mat_a, mat_b)
+
+    # Check small magnitudes against an absolute tolerance and the rest
+    # against a relative tolerance.
+    mask = golden.abs() < 1.0
+    tmpatol = tmprtol = 2 ** -6
+    try:
+        torch.testing.assert_close(result[mask], golden[mask], atol=tmpatol, rtol=0)
+        torch.testing.assert_close(result[~mask], golden[~mask], atol=0, rtol=tmprtol)
+        print("run matmul success")
+    except AssertionError:
+        print(f"[ERROR] M={M}, K={K}, N={N} failed the accuracy check")
\ No newline at end of file
diff --git a/triton_patch/python/triton_patch/compiler/code_generator.py b/triton_patch/python/triton_patch/compiler/code_generator.py
index a4b1233..d8ee4fd 100644
--- a/triton_patch/python/triton_patch/compiler/code_generator.py
+++ b/triton_patch/python/triton_patch/compiler/code_generator.py
@@ -486,6 +486,7 @@ class CodeGenerator(ast.NodeVisitor):
         return self.visit_Assign(node)
 
     def visit_Assign(self, node):
+        # flagtree: First, do normal assignment processing
         _names = []
         if isinstance(node, ast.AnnAssign):
             _names += [self.visit(node.target)]
@@ -509,6 +510,30 @@
                not isinstance(value, native_nontensor_types):
                 value = language.semantic.to_tensor(value, self.builder)
             self.set_value(name, value)
+
+        # flagtree: After normal processing, check if we need to add hint annotation
+        if hasattr(node, 'lineno') and hasattr(self, 'jit_fn'):
+            line_num = node.lineno
+            # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later
+            function_def = self.jit_fn.parse()
+            line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {})
+            flagtree_hints = line_flagtree_hints.get(line_num)
+
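+            # Illustrative example (using the tutorial kernel above): a line like
+            #     mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0)  # @hint: dot_pad_only_k
+            # causes the patched parse() in jit.py (below) to record
+            # {lineno: 'dot_pad_only_k'}, which is looked up here via node.lineno.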
+            # Check if this is a tl.load call with dot_pad_only_k hint
+            if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and
+                    isinstance(node.value, ast.Call) and
+                    isinstance(node.value.func, ast.Attribute) and
+                    isinstance(node.value.func.value, ast.Name) and
+                    node.value.func.value.id == 'tl' and
+                    node.value.func.attr == 'load'):
+
+                # Add hint annotation to the loaded tensor(s)
+                for name, value in zip(names, values):
+                    if _is_triton_value(value):
+                        # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}")
+                        # Create hint annotation
+                        hint_val = self.builder.get_unit_attr()
+                        self.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val)
 
     def visit_AugAssign(self, node):
         name = node.target.id
diff --git a/triton_patch/python/triton_patch/runtime/jit.py b/triton_patch/python/triton_patch/runtime/jit.py
index f6eaf52..281c5de 100644
--- a/triton_patch/python/triton_patch/runtime/jit.py
+++ b/triton_patch/python/triton_patch/runtime/jit.py
@@ -6,12 +6,14 @@
 import itertools
 import os
 import re
 import textwrap
+import tokenize
 from collections import defaultdict
 from functools import cached_property
 from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple
 from ..runtime.driver import driver
 from ..backends.ascend.compiler import AscendAttrsDescriptor
 from types import ModuleType
+from io import StringIO
 
 TRITON_MODULE = __name__[:-len(".runtime.jit")]
 
@@ -773,10 +775,29 @@ class JITFunction(KernelInterface[T]):
     # the user might want to monkey-patch self.src dynamically.
     # Our unit tests do this, for example.
     def parse(self):
+        # Maps line numbers to comment hints
+        line_flagtree_hints = {}
+        code_str = self.src
+        g = tokenize.generate_tokens(StringIO(code_str).readline)
+        for tok_type, tok_text, start, end, _ in g:
+            if tok_type == tokenize.COMMENT:
+                comment = tok_text.replace(" ", "").strip()
+                if comment.startswith('#@hint:'):
+                    flagtree_hints = comment[len('#@hint:'):].strip()
+                    # Record the line number of the comment
+                    line_num = start[0]
+                    line_flagtree_hints[line_num] = flagtree_hints
+
+                    # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}")
+
         tree = ast.parse(self.src)
         assert isinstance(tree, ast.Module)
         assert len(tree.body) == 1
         assert isinstance(tree.body[0], ast.FunctionDef)
+
+        # Attach the line number to comment mapping to the function definition node
+        tree.body[0].line_flagtree_hints = line_flagtree_hints
+
         return tree
 
     def __call__(self, *args, **kwargs):
-- 
Gitee