diff --git a/.gitmodules b/.gitmodules index a3e7dafc1abe00b5e6b7037cbda6e816ca39a84d..e480fc6e336b795a85b06c5274a9c5755c7c513e 100644 --- a/.gitmodules +++ b/.gitmodules @@ -17,3 +17,6 @@ [submodule "third_party/nlohmann"] path = third_party/nlohmann url = https://gitee.com/mirrors/nlohmann-json.git +[submodule "third_party/catlass"] + path = third_party/catlass + url = https://gitee.com/ascend/catlass.git diff --git a/test/_inductor/test_catlass_backend.py b/test/_inductor/test_catlass_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..2b0d69d7a62f4ab168701e15bd980b436c1d4786 --- /dev/null +++ b/test/_inductor/test_catlass_backend.py @@ -0,0 +1,272 @@ +# Owner(s): ["module: inductor"] +import math +import os +from typing import Callable, List, Optional + +import torch +from testutils import TestUtils +from torch._inductor import config +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, parametrize, run_tests) + +import torch_npu + +_CATLASS_DIR = os.path.join(os.path.dirname(__file__), "../../third_party/catlass") + + +class NpuConfigPatch: + def __init__(self, kwargs): + self.kwargs = kwargs + self.old_cfgs = {} + self.old_envs = {} + + def __enter__(self): + from torch_npu._inductor import config as npu_config + + for key, val in self.kwargs.items(): + _obj = npu_config + keys = key.split(".") + for k in keys[:-1]: + _obj = getattr(_obj, k) + + env_name = f"TORCHINDUCTOR_NPU_{keys[-1].upper()}" + self.old_envs[env_name] = os.environ.get(env_name) + + old_val = getattr(_obj, keys[-1], None) + setattr(_obj, keys[-1], val) + if isinstance(val, bool): + os.environ[env_name] = "1" if val else "0" + elif val is None: + os.environ[env_name] = "" + else: + os.environ[env_name] = str(val) + + self.old_cfgs[keys[-1]] = (_obj, old_val) + + def __exit__(self, exc_type, exc_value, traceback): + # restore env + for env_name, old_val in self.old_envs.items(): + if old_val is None: + os.environ.pop(env_name, None) + else: + os.environ[env_name] = old_val + + # restore npu_config + for key, pair in self.old_cfgs.items(): + setattr(pair[0], key, pair[1]) + + @staticmethod + def patch_npu_config(kwargs): + return NpuConfigPatch(kwargs) + + +class TestCatlassBackend(TestUtils): + + @staticmethod + def mm_assert_close(actual, golden): + mask = golden.abs() < 1.0 + tmpatol = tmprtol = 2**-6 + + torch.testing.assert_close(actual[mask], golden[mask], atol=tmpatol, rtol=0) + torch.testing.assert_close(actual[~mask], golden[~mask], atol=0, rtol=tmprtol) + + def test_max_autotune_precompile(self): + """ + Make sure autotuning mm in sub processes work without crashes. 
+ """ + + def mm(a, b): + return a @ b + + a = torch.randn(100, 10).npu().half() + b = torch.randn(10, 100).npu().half() + + patch_npu_cfgs = { + "npu.catlass_dir": _CATLASS_DIR, + "npu.catlass_max_profiling_configs": 2, + } + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": True, + "max_autotune_gemm_backends": "CATLASS,ATen", + "compile_threads": 4, + } + ), NpuConfigPatch.patch_npu_config(patch_npu_cfgs): + Y_compiled = torch.compile(mm, dynamic=False)(a, b) + Y = mm(a, b) + self.mm_assert_close(Y_compiled, Y) + + @parametrize("dynamic", (False,)) + @parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) + @parametrize("max_autotune_gemm_backends", ("CATLASS", "ATen,CATLASS")) + def test_max_autotune_catlass_backend_regular_mm( + self, + dtype: torch.dtype, + dynamic: bool, + max_autotune_gemm_backends: str, + ): + if max_autotune_gemm_backends == "CATLASS" and torch.version.hip: + return + + def mm(a, b): + return a @ b + + m, n, k = 128, 128, 16 + a = torch.randn((m, k), device="npu", dtype=dtype) + b = torch.randn((k, n), device="npu", dtype=dtype) + + patch_npu_cfgs = { + "npu.catlass_dir": _CATLASS_DIR, + "npu.catlass_max_profiling_configs": 4, + } + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + } + ), NpuConfigPatch.patch_npu_config(patch_npu_cfgs): + Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) + Y = mm(a, b) + self.mm_assert_close(Y_compiled, Y) + + @parametrize("dynamic", (False,)) + @parametrize("dtype", (torch.float32, torch.float16, torch.bfloat16)) + @parametrize("max_autotune_gemm_backends", ("CATLASS", "ATen,CATLASS")) + def test_max_autotune_catlass_backend_simple_bmm( + self, + dtype: torch.dtype, + dynamic: bool, + max_autotune_gemm_backends: str, + ): + if max_autotune_gemm_backends == "CATLASS" and torch.version.hip: + return + + def bmm(a, b): + return torch.bmm(a, b) + + batch_size = 10 + m, n, k = 256, 256, 32 + a = torch.randn((batch_size, m, k), device="npu", dtype=dtype) + b = torch.randn((batch_size, k, n), device="npu", dtype=dtype) + + patch_npu_cfgs = { + "npu.catlass_dir": _CATLASS_DIR, + "npu.catlass_max_profiling_configs": 4, + } + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + } + ), NpuConfigPatch.patch_npu_config(patch_npu_cfgs): + Y_compiled = torch.compile(bmm, dynamic=dynamic)(a, b) + Y = bmm(a, b) + self.mm_assert_close(Y_compiled, Y) + + @parametrize("dynamic", (False,)) + @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float32)) + @parametrize("max_autotune_gemm_backends", ("CATLASS", "ATen,CATLASS")) + def test_max_autotune_catlass_backend_addmm( + self, + dtype: torch.dtype, + dynamic: bool, + max_autotune_gemm_backends: str, + ): + if max_autotune_gemm_backends == "CATLASS" and torch.version.hip: + return + + def addmm(x, a, b, alpha, beta): + return torch.addmm(x, a, b, alpha=alpha, beta=beta) + + def compare_results( + m: int, k: int, n: int, alpha: float, beta: float, x_shape: List[int] + ) -> None: + x = torch.randn(x_shape, device="npu", dtype=dtype) + a = torch.randn((m, k), device="npu", dtype=dtype) + b = torch.randn((k, n), device="npu", dtype=dtype) + Y = addmm(x, a, b, alpha, beta) + + compiled_fn = torch.compile(addmm, dynamic=dynamic) + Y_compiled = compiled_fn(x, a, b, alpha, beta) + self.mm_assert_close(Y_compiled, Y) + + patch_npu_cfgs = { + "npu.catlass_dir": _CATLASS_DIR, + 
"npu.catlass_max_profiling_configs": 4, + } + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + } + ), NpuConfigPatch.patch_npu_config(patch_npu_cfgs): + # 1. GEMM template does not support bfloat16 + # 2. MatmulBias template has precision issue on bfloat16 + if dtype != torch.bfloat16: + # No broadcast + compare_results(4096, 2578, 2048, 2.0, 0.4, [4096, 2048]) + # Boardcast first dim (only support standard alpha & beta) + compare_results(4096, 2578, 2048, 1.0, 1.0, [2048]) + + def _test_max_autotune_catlass_gemm_autotune( + self, + max_autotune_gemm_backends: str = "CATLASS", + fp16=True, + mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + batch_size: Optional[int] = None, + ): + if batch_size is None: + a = torch.randn(256, 32).npu() + b = torch.randn(32, 256).npu() + else: + a = torch.randn(batch_size, 256, 32).npu() + b = torch.randn(batch_size, 32, 256).npu() + if fp16: + a = a.half() + b = b.half() + + patch_npu_cfgs = { + "npu.catlass_dir": _CATLASS_DIR, + "npu.catlass_max_profiling_configs": 4, + "npu.catlass_use_gemm_autotune": True, + } + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + } + ), NpuConfigPatch.patch_npu_config(patch_npu_cfgs): + Y_compiled = torch.compile(mm)(a, b) + Y = mm(a, b) + self.mm_assert_close(Y_compiled, Y) + + @parametrize("fp16", (False, True)) + def test_max_autotune_catlass_regular_mm_autotune(self, fp16: bool): + def mm(a, b): + return a @ b + + self._test_max_autotune_catlass_gemm_autotune(fp16=fp16, mm=mm) + + @parametrize("fp16", (False, True)) + def test_max_autotune_catlass_simple_bmm_autotune(self, fp16: bool): + def bmm(a, b): + return torch.bmm(a, b) + + self._test_max_autotune_catlass_gemm_autotune(fp16=fp16, mm=bmm, batch_size=10) + + +instantiate_parametrized_tests(TestCatlassBackend) + + +if __name__ == "__main__": + run_tests() diff --git a/third_party/catlass b/third_party/catlass new file mode 160000 index 0000000000000000000000000000000000000000..c97676ee233cfcceb5ff5222f1d1ab289b7d506c --- /dev/null +++ b/third_party/catlass @@ -0,0 +1 @@ +Subproject commit c97676ee233cfcceb5ff5222f1d1ab289b7d506c diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py index 31761bf5c3277ab03216eee06990442a8784b11f..47f5c5fefbbee41986433638c1f05ce43355ea0f 100644 --- a/torch_npu/_inductor/__init__.py +++ b/torch_npu/_inductor/__init__.py @@ -39,14 +39,23 @@ else: from .runtime import _load_cached_autotuning from .utils import get_current_raw_stream, patch_device_need_guard + from .async_compile import patch_async_compile + from .autotune_process import patch_tuning_process, patch_tuning_process_pool + from .select_algorithm import patch_algorithm_selector + from .kernel import ( + _register_npu_inductor_mm, + _register_npu_inductor_addmm, + _register_npu_inductor_bmm, + ) + set_compile_threads() def _inductor_register_backend_for_device(): - from .codegen.scheduling import NPUTritonScheduling + from .codegen.npu_combined_scheduling import NPUCombinedScheduling from .codegen.wrapper import NPUWrapperCodeGen from .codegen.cpp_wrapper import CppWrapperNpu - register_backend_for_device('npu', NPUTritonScheduling, NPUWrapperCodeGen, CppWrapperNpu) + register_backend_for_device('npu', NPUCombinedScheduling, NPUWrapperCodeGen, CppWrapperNpu) _inductor_register_backend_for_device() @@ -97,6 +106,14 @@ else: 
_register_npu_inductor_fallbacks() _register_npu_inductor_decompositons() + _register_npu_inductor_mm() + _register_npu_inductor_addmm() + _register_npu_inductor_bmm() + + patch_algorithm_selector() + patch_tuning_process() + patch_tuning_process_pool() + patch_async_compile() # register fx_pass should be put behind of _register_npu_inductor_decompositons diff --git a/torch_npu/_inductor/async_compile.py b/torch_npu/_inductor/async_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..c5b1c58fd729662552c87f134b7519a166d343b4 --- /dev/null +++ b/torch_npu/_inductor/async_compile.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import logging + +import torch + + +def patch_async_compile(): + from .codecache import NPUCodeCache + + log = logging.getLogger("torch._inductor") + + def npu(self, source_code, dst_file_ext, aot_compile=False): + log.info("NPU Kernel:\n%s", source_code) + + def task(): + if aot_compile: + # We rely on JITInductor to compile the CUDA code, + # so that we can load it into AOTInductor. + NPUCodeCache.compile(source_code, "o") + return NPUCodeCache.load(source_code, dst_file_ext)[0] + + return self.submit(task) + + torch._inductor.async_compile.AsyncCompile.npu = npu diff --git a/torch_npu/_inductor/autotune_process.py b/torch_npu/_inductor/autotune_process.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf7efd3b1acc17e78b300ee25339f0c2dcc60d8 --- /dev/null +++ b/torch_npu/_inductor/autotune_process.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import contextlib +import ctypes +import dataclasses +import functools +import logging +import os +import queue +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +from ctypes import CDLL, byref, c_size_t, c_void_p +from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List, + Optional, Sequence, Union) + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch import multiprocessing +from torch._dynamo.testing import rand_strided +from torch._inductor import config, ir +from torch._inductor.autotune_process import ( + BenchmarkRequest, NonzeroWorkspaceNotSupportedError, TensorMeta) +from torch._inductor.codecache import DLLWrapper +from torch._inductor.runtime.benchmarking import benchmarker + +from .codecache import NPUCodeCache + + +ASCEND_VISIBLE_DEVICES = "ASCEND_RT_VISIBLE_DEVICES" +EXIT_HANDLER_REGISTERED = False + +log = logging.getLogger("torch._inductor") + + +def patch_tuning_process(): + from torch._inductor import autotune_process + + autotune_process.CUDA_VISIBLE_DEVICES = ASCEND_VISIBLE_DEVICES + + +def patch_tuning_process_pool(): + from torch._inductor.autotune_process import TuningProcessPool + + def get_device_list(self) -> Sequence[Optional[int]]: + """ + Gather the list of devices to be used in the pool. + """ + if not config.autotune_multi_device: + # Don't use multiple devices + return [None] + + count = torch.npu.device_count() + + # If the user specified the visible devices in the env, use those. 
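+ # ASCEND_RT_VISIBLE_DEVICES is assumed to be a comma-separated list of integer
+ # device indices (e.g. "0,1"), mirroring how the upstream tuning pool parses
+ # CUDA_VISIBLE_DEVICES; each entry is converted to an int below.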
+ if ASCEND_VISIBLE_DEVICES in os.environ: + devices = [int(d) for d in os.environ[ASCEND_VISIBLE_DEVICES].split(",")] + if len(devices) > count: + raise ValueError(f"Specified visible devices exceed the number of total devices: {devices}") + return devices + + return list(range(count)) + + TuningProcessPool.get_device_list = get_device_list + + +class NPUDeviceBenchmarkMixin: + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + device_idx_set = { + tensor.device.index + for tensor in [*input_tensors, output_tensor] + if isinstance(tensor, torch.Tensor) + and tensor.is_npu + and tensor.device.index is not None + } + if len(device_idx_set) > 1: + raise ValueError(f"Can not mix devices: {device_idx_set}") + if len(device_idx_set) == 1: + device_idx = next(iter(device_idx_set)) + else: + device_idx = torch.npu.current_device() + + with torch.npu.device(device_idx): + out = self._bench(fn) + torch.npu.synchronize() # shake out any NPU errors + + return out + + def _bench( + self, + fn, + warmup=25, + repeats=100, + ) -> float: + fn() + torch.npu.synchronize() + + # Estimate the runtime of the function + start_event = torch.npu.Event(enable_timing=True) + end_event = torch.npu.Event(enable_timing=True) + start_event.record() + for _ in range(5): + fn() + end_event.record() + torch.npu.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = min(max(int(warmup / estimate_ms), 1), 250) + n_repeat = min(max(int(repeats / estimate_ms), 1), 1000) + + # warm-up + for _ in range(n_warmup): + fn() + # benchmark + start_event.record() + for _ in range(n_repeat): + fn() + end_event.record() + torch.npu.synchronize() + + return start_event.elapsed_time(end_event) / n_repeat + + +class NPUBenchmarkRequest(NPUDeviceBenchmarkMixin, BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put NPU Tensors in here! 
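+ # When autotuning runs in a subprocess (autotune_in_subproc=True), the request is
+ # pickled and sent to a worker, so only plain metadata (the source code string,
+ # TensorMeta descriptions, extra args) is stored on the instance.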
+ + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = NPUCodeCache.write(self.source_code, "so") + + def benchmark( + self, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + # create args and out tensor + if output_tensor is None: + input_tensors = tuple(x.to_tensor() for x in self.input_tensor_meta) + output_tensor = self.output_tensor_meta.to_tensor() + + try: + fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor) + except NonzeroWorkspaceNotSupportedError: + log.info("Skipping op due to nonzero workspace requirement") + return float("inf") + + out = self.do_bench(fn, *input_tensors, output_tensor) + return out + + def precompile(self): + # Prepopulate NPUCodeCache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + NPUCodeCache.compile(self.source_code, "so") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + self.ensure_dll_loaded() + self.update_workspace_size() + args = [ + c_void_p(tensor.data_ptr()) + for tensor in list(input_tensors) + [output_tensor] + ] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.npu.current_stream().npu_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=output_tensor.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *self.extra_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len({meta.name for meta in self.input_tensor_meta}) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.npu.current_stream().npu_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. 
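+ # Query pass: call the generated kernel with null tensor/workspace pointers and a
+ # non-null workspace-size pointer, so it only reports the required workspace bytes
+ # instead of launching the computation.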
+ c_workspace_size = c_size_t() + run_method( + *args, # input ptrs and output ptrs + *self.extra_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.npu.synchronize() # shake out any NPU errors + self.workspace_size = c_workspace_size.value + log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = NPUCodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" diff --git a/torch_npu/_inductor/codecache.py b/torch_npu/_inductor/codecache.py index c9223e9ffd2c351097d2fbf76c89932ad74a1cf3..65b540c6b6165e771574c346ada05e369801326f 100644 --- a/torch_npu/_inductor/codecache.py +++ b/torch_npu/_inductor/codecache.py @@ -1,8 +1,14 @@ +import dataclasses import os import contextlib import functools import hashlib import json +import logging +import subprocess +import sys +import sysconfig +from time import time, time_ns from typing import ( Any, Callable, @@ -21,17 +27,36 @@ from typing import ( import torch from torch._inductor import config -from torch._inductor.codecache import CacheBase, get_lock_dir, LOCK_TIMEOUT +from torch._inductor.exc import CppCompileError +from torch._inductor.codecache import ( + CacheBase, + get_lock_dir, + write, + LOCK_TIMEOUT, + DLLWrapper, +) from torch._inductor.graph import GraphLowering +from torch._inductor.utils import ( + clear_on_fresh_inductor_cache, + is_linux, + is_windows, +) import torch_npu from torch_npu.utils._error_code import ErrCode, pta_error +from .cpp_builder import library_paths +from . import config as npu_config +from .codegen.npu.catlass_utils import get_npu_arch, _normalize_npu_arch + empty_json = "{}" +log = logging.getLogger("torch._inductor") + @contextlib.contextmanager def lock_context(key): from filelock import FileLock + lock_dir = get_lock_dir() lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: @@ -88,6 +113,7 @@ def patch_aot_code_compiler_compile(): # which could not be skipped, so here we try to create a new npu op_json, # and clear the content of default op_json. 
from torch._inductor.codecache import AotCodeCompiler + AotCodeCompiler.src_compile = AotCodeCompiler.compile @classmethod @@ -100,17 +126,23 @@ def patch_aot_code_compiler_compile(): additional_files: List[str], ) -> Union[List[str], str]: result = cls.src_compile( - graph, source_code, serialized_extern_kernel_nodes, - device_type, additional_files + graph, + source_code, + serialized_extern_kernel_nodes, + device_type, + additional_files, ) generated_files = additional_files if not config.aot_inductor.package: return result - + output_so = [r for r in result if r.endswith(".so")] if len(output_so) > 1: - raise RuntimeError(f"Could not generate npu op json, because there are" - f"more than one so in generated files: {result}" + pta_error(ErrCode.INTERNAL)) + raise RuntimeError( + f"Could not generate npu op json, because there are" + f"more than one so in generated files: {result}" + + pta_error(ErrCode.INTERNAL) + ) output_so = output_so[0] key = os.path.basename(output_so)[0].replace(".", "_") dir_basename = os.path.splitext(output_so)[0] @@ -120,11 +152,225 @@ def patch_aot_code_compiler_compile(): with open(extern_kernel_nodes_json, "w") as f: f.write(serialized_extern_kernel_nodes) generated_files.append(extern_kernel_nodes_json) - + if serialized_extern_kernel_nodes: source_json_file = dir_basename + ".json" with open(source_json_file, "w") as f: f.write(empty_json) return generated_files + AotCodeCompiler.compile = compile_npu - \ No newline at end of file + + +def _catlass_include_paths() -> List[str]: + from .cpp_builder import get_ascend_home + + ASCEND_HOME = get_ascend_home() + catlass_path = npu_config.npu.catlass_dir + return [ + # Use realpath to get canonical absolute paths, in order not to mess up cache keys + os.path.realpath(os.path.join(ASCEND_HOME, "compiler/tikcpp")), + os.path.realpath(os.path.join(ASCEND_HOME, "compiler/tikcpp/tikcfw")), + os.path.realpath(os.path.join(ASCEND_HOME, "compiler/tikcpp/tikcfw/impl")), + os.path.realpath(os.path.join(ASCEND_HOME, "compiler/tikcpp/tikcfw/interface")), + os.path.realpath(os.path.join(ASCEND_HOME, "include")), + os.path.realpath(os.path.join(ASCEND_HOME, "include/experiment/runtime")), + os.path.realpath(os.path.join(ASCEND_HOME, "include/experiment/msprof")), + os.path.realpath(os.path.join(catlass_path, "include")), + os.path.realpath(os.path.join(catlass_path, "tools/library/include")), + os.path.realpath(os.path.join(catlass_path, "tools/library/src")), + os.path.realpath(os.path.join(catlass_path, "tools/util/include")), + ] + + +def _ascend_lib_options() -> List[str]: + lpaths = library_paths(npu=True) + [sysconfig.get_config_var("LIBDIR")] + extra_ldflags: List[str] = [] + if is_linux(): + for path in lpaths: + # -rpath ensures the DLL can find its dependencies when loaded, even + # if the library path is non-standard. + extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"]) + + extra_ldflags.append("-lruntime") + extra_ldflags.append("-lstdc++") + extra_ldflags.append("-lascendcl") + extra_ldflags.append("-lm") + extra_ldflags.append("-ltiling_api") + extra_ldflags.append("-lplatform") + extra_ldflags.append("-lc_sec") + extra_ldflags.append("-ldl") + extra_ldflags.append("-lnnopbase") + else: + raise NotImplementedError( + "Unsupported env, failed to find ascend libs! Currently only Linux is supported." 
+ ) + return extra_ldflags + + +def _bisheng_host_compiler_options() -> List[str]: + return [ + "-fPIC", + "-fno-strict-aliasing", + "-fvisibility=hidden", + "-Wconversion", + ] + + +def _bisheng_compiler_options() -> List[str]: + npu_arch = _normalize_npu_arch(get_npu_arch()) + if npu_arch == "910B": + arch = "dav-c220" + else: + raise ValueError(f"Unrecognized NPU arch: {npu_arch}") + options = [ + f"--cce-aicore-arch={arch}", + "-O2", + "-std=c++17", + "-xcce", + "-mllvm -cce-aicore-stack-size=0x8000", + "-mllvm -cce-aicore-function-stack-size=0x8000", + "-mllvm -cce-aicore-record-overflow=true", + "-mllvm -cce-aicore-addr-transform", + "-mllvm -cce-aicore-dcci-insert-for-scalar=false", + "-DL2_CACHE_HINT", + ] + + if npu_config.npu.enable_debug_info: + options.extend(["--lineinfo", "-g"]) + + return options + + +def _bisheng_compiler() -> Optional[str]: + if os.path.exists(os.getenv("ASCEND_HOME_PATH")): + return os.path.realpath( + os.path.join( + os.getenv("ASCEND_HOME_PATH", ""), "compiler/ccec_compiler/bin/bisheng" + ) + ) + return "bisheng" + + +def npu_compile_command( + src_files: List[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[List[str]] = None, +) -> str: + if extra_args is None: + extra_args = [] + include_paths = _catlass_include_paths() + ascend_lib_options = _ascend_lib_options() + bisheng_host_compiler_options = _bisheng_host_compiler_options() + bisheng_compiler_options = _bisheng_compiler_options() + options = ( + bisheng_compiler_options + + extra_args + + [ + f"-Xcompiler {opt}" if "=" in opt else f"-Xcompiler={opt}" + for opt in bisheng_host_compiler_options + ] + + ["-I" + path for path in include_paths] + + ascend_lib_options + ) + src_file = " ".join(src_files) + res = "" + if dst_file_ext == "o": + res = f"{_bisheng_compiler()} {' '.join(options)} -c -o {dst_file} {src_file}" + elif dst_file_ext == "so": + options.append("-shared") + res = f"{_bisheng_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + elif dst_file_ext == "exe": + res = f"{_bisheng_compiler()} {' '.join(options)} -o {dst_file} {src_file}" + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + log.debug("Bisheng command: %s", res) + return res + + +class NPUCompileError(CppCompileError): + pass + + +@clear_on_fresh_inductor_cache +class NPUCodeCache: + @dataclasses.dataclass + class CacheEntry: + input_path: str + output_path: str + + cache: Dict[str, CacheEntry] = {} + cache_clear = staticmethod(cache.clear) + _SOURCE_CODE_SUFFIX = "cpp" + + @classmethod + def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]: + """ + Writes source code into a file with dst_file_ext as the file extension. + Returns the hash key of source code, and the path to the file. + """ + + npu_command = repr( + npu_compile_command(["dummy_input"], "dummy_output", dst_file_ext) + ) + key, input_path = write(source_code, cls._SOURCE_CODE_SUFFIX, extra=npu_command) + return key, input_path + + @classmethod + def compile( + cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None + ) -> Tuple[str, str, str]: + """ + Compiles NPU source_code into a file with dst_file_ext extension. 
+ Returns a tuple of dst_file_path, hash_key, source_code_path + """ + key, input_path = cls.write(source_code, dst_file_ext) + if key not in cls.cache: + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if not os.path.exists(output_path): + cmd = npu_compile_command( + [input_path], output_path, dst_file_ext, extra_args + ) + start_time = time() + log.debug("NPU Compilation: %s", cmd) + cmd_parts = cmd.split(" ") + try: + subprocess.check_output( + cmd_parts, stderr=subprocess.STDOUT, env=os.environ + ) + except subprocess.CalledProcessError as error: + raise NPUCompileError(cmd_parts, error.output) from error + end_time = time() + log_duration_msg = f"NPU Compilation took {end_time - start_time} seconds. Compile command: {cmd}" + log.info(log_duration_msg) + else: + log.debug( + "NPU Compilation skipped: %s since output already exists", + input_path, + ) + cls.cache[key] = NPUCodeCache.CacheEntry(input_path, output_path) + + return (cls.cache[key].output_path, key, input_path) + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> Tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + return (DLLWrapper(dst_file_path), hash_key, source_code_path) diff --git a/torch_npu/_inductor/codegen/npu/catlass_library/__init__.py b/torch_npu/_inductor/codegen/npu/catlass_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch_npu/_inductor/codegen/npu/catlass_library/gemm_autotune.py b/torch_npu/_inductor/codegen/npu/catlass_library/gemm_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..adf8f2d70b3925f493dc922edd90f29d1e528942 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/catlass_library/gemm_autotune.py @@ -0,0 +1,213 @@ +import warnings + +from torch_npu._inductor import config +from .library import * + + +def _generate_tile_desc(Block_M, Block_N, Block_K_l1, Block_K_l0): + return TileDesription( + [Block_M, Block_N, Block_K_l1], [Block_M, Block_N, Block_K_l0] + ) + + +def _is_support_tile_autotune(op_kind): + return op_kind not in [ + GemmKind.MatmulBias, + GemmKind.Gemm, + ] + + +def _is_support_tuning_swizzle(op_kind): + return op_kind not in [ + GemmKind.Gemm, + ] + + +_default_tile_descs = { + "default": { + 2: [ + _generate_tile_desc(256, 128, 256, 64), + _generate_tile_desc(128, 256, 256, 64), + ], + 4: [ + _generate_tile_desc(128, 128, 256, 64), + _generate_tile_desc(128, 128, 128, 64), + _generate_tile_desc(256, 64, 64, 32), + ], + }, + # MatmulBias occupy more L1 cache + GemmKind.MatmulBias: { + 4: [ + _generate_tile_desc(112, 128, 256, 64), + ], + 2: [ + _generate_tile_desc(256, 128, 256, 64), + _generate_tile_desc(128, 256, 256, 64), + ], + }, + GemmKind.Gemm: { + 4: [ + _generate_tile_desc(128, 128, 128, 64), + ], + 2: [ + _generate_tile_desc(256, 128, 128, 64), + _generate_tile_desc(128, 256, 128, 64), + _generate_tile_desc(128, 128, 128, 64), + ], + }, +} + +_default_block_swizzles = { + 
"default": [ + BlockSwizzle.GemmIdentityBlockSwizzle_30, + BlockSwizzle.GemmIdentityBlockSwizzle_31, + ], + GemmKind.OptimizedMatmul: [ + BlockSwizzle.GemmIdentityBlockSwizzle_30, + ], +} + + +class Config: + __slots__ = ["tile_desc", "block_swizzle"] + + def __init__(self, tile_desc, blk_swizzle): + self.tile_desc = tile_desc + self.block_swizzle = blk_swizzle + + +class TileAutotune: + + stages: int = 2 + + def __init__(self, arch_type): + self.arch_type = arch_type + self.init_l1_l0_size(arch_type) + + def init_l1_l0_size(self, arch_type): + if arch_type != ArchType.A2: + warnings.warn( + f"Unknown arch type to get specific tile size: {arch_type}." + f"Will use the default tile size to generate tile configs." + ) + arch_type = ArchType.A2 + + if arch_type == ArchType.A2: + self.L1Size = 512 * 1024 // self.stages + self.L0CSize = 128 * 1024 // self.stages + self.L0ASize = 64 * 1024 // self.stages + self.L0BSize = 64 * 1024 // self.stages + self.L0Size = min(self.L0ASize, self.L0BSize) + + @staticmethod + def floor_power_of_2(n): + if n <= 1: + return n + return 1 << (n.bit_length() - 1) + + def gen_tile_configs(self, op_kind, dtype_size, shape_desc): + if not _is_support_tile_autotune(op_kind): + # currently autotune does not support MatmulBias & Gemm + return [] + + configs = [] + M, N, K = shape_desc + min_mn, max_mn = min(M, N), max(M, N) + + # helper method + def add_block_sizes_to_configs(BLOCK_min_mn): + BLOCK_max_mn_start = self.floor_power_of_2( + min(max_mn, self.L0CSize // BLOCK_min_mn // 4) + ) + for j in [1, 2]: + BLOCK_max_mn = max(BLOCK_max_mn_start // j, 16) + BLOCK_K_start = self.floor_power_of_2( + min(K, self.L1Size // (BLOCK_min_mn + BLOCK_max_mn) // dtype_size) + ) + + BLOCK_K = max(BLOCK_K_start, 16) + SUB_BLOCK_K = ( + self.L0Size // max(BLOCK_min_mn, BLOCK_max_mn) // dtype_size + ) + SUB_BLOCK_K = max(min(SUB_BLOCK_K, BLOCK_K), 16) + + if M < N: + BLOCK_M, BLOCK_N = BLOCK_min_mn, BLOCK_max_mn + else: + BLOCK_M, BLOCK_N = BLOCK_max_mn, BLOCK_min_mn + + configs.append( + _generate_tile_desc(BLOCK_M, BLOCK_N, BLOCK_K, SUB_BLOCK_K) + ) + + if min_mn < 128: + # non-split case + min_mn = max(min_mn, 16) + add_block_sizes_to_configs(min_mn) + + BLOCK_min_mn_start = min(self.floor_power_of_2(min_mn), 128) + for i in [1, 2]: + BLOCK_min_mn = max(BLOCK_min_mn_start // i, 16) + add_block_sizes_to_configs(BLOCK_min_mn) + + return configs + + +class GemmAutotune: + def __init__(self, arch_type): + self.arch_type = arch_type + self.tile_autotune = ( + TileAutotune(arch_type) if config.npu.catlass_use_gemm_autotune else None + ) + self.caches = {} + + def gen_configs(self, op_kind, dtype, shape_desc): + dtype_size = DataTypeSize[dtype] // 8 + + key = (op_kind, dtype_size, shape_desc) + if key in self.caches: + return self.caches[key] + + tile_cfgs = self._get_tile_configs(op_kind, dtype_size, shape_desc) + blk_swizzle_cfgs = self._get_default_block_swizzle_configs(op_kind) + + cfgs = [] + if blk_swizzle_cfgs is not None: + for tile_desc in tile_cfgs: + for blk_swizzle in blk_swizzle_cfgs: + cfgs.append(Config(tile_desc, blk_swizzle)) + else: + cfgs = [Config(tile_desc, None) for tile_desc in tile_cfgs] + + self.caches[key] = cfgs + return cfgs + + def _get_tile_configs(self, op_kind, dtype_size, shape_desc): + if self.tile_autotune is not None and shape_desc is not None: + res = self.tile_autotune.gen_tile_configs(op_kind, dtype_size, shape_desc) + if res: + return res + + return _default_tile_descs.get(op_kind, _default_tile_descs["default"])[ + dtype_size + ] + + def 
_get_default_block_swizzle_configs(self, op_kind): + if _is_support_tuning_swizzle(op_kind): + return _default_block_swizzles.get( + op_kind, _default_block_swizzles["default"] + ) + return None + + +def get_gemm_autotune(arch_type=None): + if not hasattr(get_gemm_autotune, "instance"): + if arch_type is None: + arch_type = ArchType.A2 + get_gemm_autotune.instance = GemmAutotune(arch_type) + return get_gemm_autotune.instance + + +def generate_configs(arch_type, op_kind, data_type, shape_desc): + autotune = get_gemm_autotune(arch_type) + return autotune.gen_configs(op_kind, data_type, shape_desc) diff --git a/torch_npu/_inductor/codegen/npu/catlass_library/gemm_operation.py b/torch_npu/_inductor/codegen/npu/catlass_library/gemm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..f31a3532842f651d74b618e627dace884a7a3df5 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/catlass_library/gemm_operation.py @@ -0,0 +1,111 @@ +import types + +from .library import * + + +class GemmOperation: + # + def __init__( + self, + gemm_kind, + arch, + dispatch_policy, + tile_description, + A, + B, + C, + element_epilogue, + block_swizzle=None, + block_epilogue=None, + D=None, + ): + self.gemm_kind = gemm_kind + self.arch = arch + self.dispatch_policy = dispatch_policy + self.tile_description = tile_description + self.A = A + self.B = B + self.C = C + self.D = D + + if self.D is None: + self.D = self.C + + self.element_epilogue = element_epilogue + self.block_swizzle = block_swizzle + self.block_epilogue = block_epilogue + + def accumulator_type(self): + return self.element_epilogue + + def arch_name(self): + return ArchTypeNames[self.arch] + + # Generates a short string representing the AB layout tags (e.g., nt or tn) + def layout_name(self): + return "%s%s" % ( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ) + + def dispatch_policy_name(self): + return f"{ShortDispatchPolicyNames[self.dispatch_policy]}" + + def block_swizzle_name(self): + return f"{ShortBlockSwizzleNames[self.block_swizzle]}" + + # Generates a string representing the element type. 
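+ # e.g. "f16_f16_f32_f32_f32" for half A/B with float C, epilogue, and D tensors.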
+ def extended_name(self): + extended_name = ( + "{element_a}_{element_b}_{element_c}_{element_epi}_{element_d}".format( + element_a=DataTypeNames[self.A.element], + element_b=DataTypeNames[self.B.element], + element_c=DataTypeNames[self.C.element], + element_epi=DataTypeNames[self.element_epilogue], + element_d=DataTypeNames[self.D.element], + ) + ) + return extended_name + + # Generate the full kernel function name + def procedural_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + tile_desc = self.tile_description.procedural_name() + swizzle_name = ( + "" if self.block_swizzle is None else f"_{self.block_swizzle_name()}" + ) + return "catlass_{p}_{op}_{dp}{sw}_{ex}_{td}_{l}".format( + p=self.arch_name(), + op=self.gemm_typename(), + dp=self.dispatch_policy_name(), + sw=swizzle_name, + ex=self.extended_name(), + td=tile_desc, + l=self.layout_name(), + ) + + def configuration_name(self): + return self.procedural_name() + + def arch_typename(self): + return ArchTypeTag[self.arch] + + def gemm_typename(self): + return GemmKindNames[self.gemm_kind] + + def swizzle_typename(self): + return "" if self.block_swizzle is None else BlockSwizzleTag[self.block_swizzle] + + def dispatch_policy_typename(self): + return DispatchPolicyTag[self.dispatch_policy] + + +def _make_layouttypname_func(attr_name: str): + def _func(self): + return LayoutTag[getattr(self, attr_name).layout] + + return _func + + +for name in ["A", "B", "C", "D"]: + setattr(GemmOperation, f"layout{name}_typename", _make_layouttypname_func(name)) diff --git a/torch_npu/_inductor/codegen/npu/catlass_library/generator.py b/torch_npu/_inductor/codegen/npu/catlass_library/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..f2d94a0a0d43133ce3cbc1d1cce640a2d1a8cad4 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/catlass_library/generator.py @@ -0,0 +1,169 @@ +from .library import * +from .gemm_operation import GemmOperation +from .gemm_autotune import generate_configs + + +def CreateGemmOperator( + manifest, + arch, + gemm_kind, + layouts, + dispatch_policies, + tile_description, + data_type, + block_swizzle=None, + shape_desc=None, +): + + element_a, element_b, element_c, element_epilogue = data_type + + for layout in layouts: + for dispatch_policy in dispatch_policies: + A = TensorDescription(element_a, layout[0]) + B = TensorDescription(element_b, layout[1]) + C = TensorDescription(element_c, layout[2]) + + new_operation = GemmOperation( + gemm_kind, + arch, + dispatch_policy, + tile_description, + A, + B, + C, + element_epilogue, + block_swizzle, + ) + manifest.append(new_operation, shape_desc) + + +def Generate910B(manifest, shape_desc=None): + if manifest.get_ops(shape_desc) is not None: + # use cached ops + return + + Generate910B_MM(manifest, shape_desc) + Generate910B_GEMM(manifest, shape_desc) + + +def Generate910B_MM(manifest, shape_desc): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + dispatch_policies = [ + DispatchPolicyType.MmadA2PingPongEUnitFlag, # only for BasicMatmul + ] + + # dtype of A, B, C and epilogue + data_types = [ + (DataType.f16, DataType.f16, DataType.f32, DataType.f32), + (DataType.f16, DataType.f16, DataType.f16, DataType.f32), + (DataType.bf16, DataType.bf16, 
DataType.f32, DataType.f32), + (DataType.bf16, DataType.bf16, DataType.bf16, DataType.f32), + (DataType.f32, DataType.f32, DataType.f32, DataType.f32), + ] + + gemm_kinds = [ + GemmKind.BasicMatmul, + GemmKind.BatchedMatmul, + ] + + arch = ArchType.A2 + for data_type in data_types: + for gemm_kind in gemm_kinds: + configs = generate_configs(arch, gemm_kind, data_type[0], shape_desc) + for cfg in configs: + CreateGemmOperator( + manifest, + arch, + gemm_kind, + layouts, + dispatch_policies, + cfg.tile_desc, + data_type, + cfg.block_swizzle, + shape_desc, + ) + + # MatmulBias op + gemm_kinds = [ + GemmKind.MatmulBias, + ] + dispatch_policies = [ + DispatchPolicyType.MmadA2PingPongBiasEUnitFlag, # only for MatmulBias + ] + for data_type in data_types: + for gemm_kind in gemm_kinds: + configs = generate_configs(arch, gemm_kind, data_type[0], shape_desc) + for cfg in configs: + CreateGemmOperator( + manifest, + arch, + gemm_kind, + layouts, + dispatch_policies, + cfg.tile_desc, + data_type, + cfg.block_swizzle, + ) + + # Optimized Matmul + gemm_kinds = [ + GemmKind.OptimizedMatmul, + ] + dispatch_policies = [ + DispatchPolicyType.MmadA2PreloadEUnitFlagESuffleK, # only for OptimizedMatmul + ] + for data_type in data_types: + for gemm_kind in gemm_kinds: + configs = generate_configs(arch, gemm_kind, data_type[0], shape_desc) + for cfg in configs: + CreateGemmOperator( + manifest, + arch, + gemm_kind, + layouts, + dispatch_policies, + cfg.tile_desc, + data_type, + cfg.block_swizzle, + ) + + +def Generate910B_GEMM(manifest, shape_desc): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + dispatch_policies = [ + DispatchPolicyType.GemmA2EUnitFlagEShuffleKEABBA, + ] + + # dtype of A, B, C and epilogue + # GEMM does not support bfloat16 + data_types = [ + (DataType.f16, DataType.f16, DataType.f32, DataType.f32), + (DataType.f16, DataType.f16, DataType.f16, DataType.f32), + (DataType.f32, DataType.f32, DataType.f32, DataType.f32), + ] + + arch = ArchType.A2 + for data_type in data_types: + configs = generate_configs(arch, GemmKind.Gemm, data_type[0], shape_desc) + for cfg in configs: + CreateGemmOperator( + manifest, + arch, + GemmKind.Gemm, + layouts, + dispatch_policies, + cfg.tile_desc, + data_type, + ) diff --git a/torch_npu/_inductor/codegen/npu/catlass_library/library.py b/torch_npu/_inductor/codegen/npu/catlass_library/library.py new file mode 100644 index 0000000000000000000000000000000000000000..d490fc0a5f735e77098a9c3cc146174913b46499 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/catlass_library/library.py @@ -0,0 +1,252 @@ +import enum + +from enum import auto as enum_auto + + +class DataType(enum.Enum): + void = enum_auto() # primary used to disable C tensor for epilogues + u8 = enum_auto() + u16 = enum_auto() + u32 = enum_auto() + u64 = enum_auto() + s8 = enum_auto() + s16 = enum_auto() + s32 = enum_auto() + s64 = enum_auto() + f16 = enum_auto() + bf16 = enum_auto() + f32 = enum_auto() + f64 = enum_auto() + invalid = enum_auto() + + +# +DataTypeNames = { + DataType.void: "void", + DataType.u8: "u8", + DataType.u16: "u16", + DataType.u32: "u32", + DataType.u64: "u64", + DataType.s8: "s8", + DataType.s16: "s16", + DataType.s32: "s32", + DataType.s64: "s64", + DataType.f16: "f16", + DataType.bf16: "bf16", + DataType.f32: "f32", + 
DataType.f64: "f64", +} + + +DataTypeTag = { + DataType.void: "void", + DataType.u8: "uint8_t", + DataType.u16: "uint16_t", + DataType.u32: "uint32_t", + DataType.u64: "uint64_t", + DataType.s8: "int8_t", + DataType.s16: "int16_t", + DataType.s32: "int32_t", + DataType.s64: "int64_t", + DataType.f16: "half", + DataType.bf16: "bfloat16_t", + DataType.f32: "float", + DataType.f64: "double", +} + + +DataTypeSize = { + DataType.void: 0, + DataType.u8: 8, + DataType.u16: 16, + DataType.u32: 32, + DataType.u64: 64, + DataType.s8: 8, + DataType.s16: 16, + DataType.s32: 32, + DataType.s64: 64, + DataType.f16: 16, + DataType.bf16: 16, + DataType.f32: 32, + DataType.f64: 64, +} + + +class LayoutType(enum.Enum): + ColumnMajor = enum_auto() + RowMajor = enum_auto() + VectorLayout = enum_auto() + + +# +LayoutTag = { + LayoutType.ColumnMajor: "Catlass::layout::ColumnMajor", + LayoutType.RowMajor: "Catlass::layout::RowMajor", + LayoutType.VectorLayout: "Catlass::layout::VectorLayout", +} + + +# +ShortLayoutTypeNames = { + LayoutType.ColumnMajor: "n", + LayoutType.RowMajor: "t", + LayoutType.VectorLayout: "v", +} + + +class DispatchPolicyType(enum.Enum): + MmadA2 = enum_auto() + MmadA2Async = enum_auto() + + # Matmul + MmadA2PingPongDUnitFlag = enum_auto() + MmadA2PingPongEUnitFlag = enum_auto() + + # Optimized Matmul + MmadA2PreloadDUnitFlagDSuffleK = enum_auto() + MmadA2PreloadEUnitFlagESuffleK = enum_auto() + MmadA2PreloadEUnitFlagDSuffleK = enum_auto() + MmadA2PreloadDUnitFlagESuffleK = enum_auto() + + # MatmulBias + MmadA2PingPongBiasDUnitFlag = enum_auto() + MmadA2PingPongBiasEUnitFlag = enum_auto() + + # GEMM + GemmA2EUnitFlagDShuffleKDABBA = enum_auto() + GemmA2EUnitFlagEShuffleKDABBA = enum_auto() + GemmA2EUnitFlagEShuffleKEABBA = enum_auto() + + +DispatchPolicyTag = { + DispatchPolicyType.MmadA2: "Catlass::Gemm::MmadAtlasA2", + DispatchPolicyType.MmadA2Async: "Catlass::Gemm:MmadAtlasA2Async", + # Matmul + DispatchPolicyType.MmadA2PingPongDUnitFlag: "Catlass::Gemm::MmadAtlasA2Pingpong", + DispatchPolicyType.MmadA2PingPongEUnitFlag: "Catlass::Gemm::MmadAtlasA2Pingpong", + DispatchPolicyType.MmadA2PreloadDUnitFlagDSuffleK: "Catlass::Gemm::MmadAtlasA2Preload", + DispatchPolicyType.MmadA2PreloadEUnitFlagESuffleK: "Catlass::Gemm::MmadAtlasA2Preload", + DispatchPolicyType.MmadA2PreloadEUnitFlagDSuffleK: "Catlass::Gemm::MmadAtlasA2Preload", + DispatchPolicyType.MmadA2PreloadDUnitFlagESuffleK: "Catlass::Gemm::MmadAtlasA2Preload", + # MatmulBias + DispatchPolicyType.MmadA2PingPongBiasDUnitFlag: "Catlass::Gemm::MmadAtlasA2PingpongBias", + DispatchPolicyType.MmadA2PingPongBiasEUnitFlag: "Catlass::Gemm::MmadAtlasA2PingpongBias", + # GEMM + DispatchPolicyType.GemmA2EUnitFlagDShuffleKDABBA: "Catlass::Gemm::GemmAtlasA2", + DispatchPolicyType.GemmA2EUnitFlagEShuffleKDABBA: "Catlass::Gemm::GemmAtlasA2", + DispatchPolicyType.GemmA2EUnitFlagEShuffleKEABBA: "Catlass::Gemm::GemmAtlasA2", +} + +ShortDispatchPolicyNames = { + DispatchPolicyType.MmadA2: "dp-mmada2", + DispatchPolicyType.MmadA2Async: "dp-mmada2-async", + # Matmul + DispatchPolicyType.MmadA2PingPongDUnitFlag: "dp-mmada2-pp-f", + DispatchPolicyType.MmadA2PingPongEUnitFlag: "dp-mmada2-pp-t", + DispatchPolicyType.MmadA2PreloadDUnitFlagDSuffleK: "dp-mmada2-pl-ff", + DispatchPolicyType.MmadA2PreloadDUnitFlagESuffleK: "dp-mmada2-pl-ft", + DispatchPolicyType.MmadA2PreloadEUnitFlagESuffleK: "dp-mmada2-pl-tt", + DispatchPolicyType.MmadA2PreloadEUnitFlagDSuffleK: "dp-mmada2-pl-tf", + # MatmulBias + DispatchPolicyType.MmadA2PingPongBiasDUnitFlag: "dp-mmada2-ppb-f", 
+ DispatchPolicyType.MmadA2PingPongBiasEUnitFlag: "dp-mmada2-ppb-t", + # GEMM + DispatchPolicyType.GemmA2EUnitFlagDShuffleKDABBA: "dp-gemm-tff", + DispatchPolicyType.GemmA2EUnitFlagEShuffleKDABBA: "dp-gemm-ttf", + DispatchPolicyType.GemmA2EUnitFlagEShuffleKEABBA: "dp-gemm-ttt", +} + + +class BlockSwizzle(enum.Enum): + GemmIdentityBlockSwizzle_30 = enum_auto() + GemmIdentityBlockSwizzle_31 = enum_auto() + GemmIdentityBlockSwizzle_40 = enum_auto() + GemmIdentityBlockSwizzle_41 = enum_auto() + + +BlockSwizzleTag = { + BlockSwizzle.GemmIdentityBlockSwizzle_30: "Catlass::Gemm::Block::GemmIdentityBlockSwizzle<3, 0>", + BlockSwizzle.GemmIdentityBlockSwizzle_31: "Catlass::Gemm::Block::GemmIdentityBlockSwizzle<3, 1>", + BlockSwizzle.GemmIdentityBlockSwizzle_40: "Catlass::Gemm::Block::GemmIdentityBlockSwizzle<4, 0>", + BlockSwizzle.GemmIdentityBlockSwizzle_41: "Catlass::Gemm::Block::GemmIdentityBlockSwizzle<4, 1>", +} + + +ShortBlockSwizzleNames = { + BlockSwizzle.GemmIdentityBlockSwizzle_30: "idbs30", + BlockSwizzle.GemmIdentityBlockSwizzle_31: "idbs31", + BlockSwizzle.GemmIdentityBlockSwizzle_40: "idbs40", + BlockSwizzle.GemmIdentityBlockSwizzle_41: "idbs41", +} + + +class GemmKind(enum.Enum): + # Standard Matmul + BasicMatmul = enum_auto() + OptimizedMatmul = enum_auto() + BatchedMatmul = enum_auto() + MatmulBias = enum_auto() + + # GEMM + Gemm = enum_auto() + Group = enum_auto() + + +GemmKindNames = { + GemmKind.BasicMatmul: "BasicMatmul", + GemmKind.OptimizedMatmul: "OptimizedMatmul", + GemmKind.BatchedMatmul: "BatchedMatmul", + GemmKind.MatmulBias: "MatmulBias", + GemmKind.Gemm: "KernelGemm", + GemmKind.Group: "KernelGroupGemm", +} + + +class ArchType(enum.Enum): + A2 = enum_auto() + + +ArchTypeTag = { + ArchType.A2: "Arch::AtlasA2", +} + + +ArchTypeNames = { + ArchType.A2: "A2", +} + + +class TensorDescription: + def __init__(self, element, layout): + self.element = element + self.layout = layout + + +class TileDesription: + def __init__(self, L1TileShape, L0TileShape): + self.l1_tile_shape = L1TileShape + self.l0_tile_shape = L0TileShape + + def procedural_name(self): + return "l1_{l1m}x{l1n}x{l1k}_l0_{l0m}x{l0n}x{l0k}".format( + l1m=self.l1_tile_shape[0], + l1n=self.l1_tile_shape[1], + l1k=self.l1_tile_shape[2], + l0m=self.l0_tile_shape[0], + l0n=self.l0_tile_shape[1], + l0k=self.l0_tile_shape[2], + ) + + def l1_tile_typename(self): + return "GemmShape<{l1m}, {l1n}, {l1k}>".format( + l1m=self.l1_tile_shape[0], + l1n=self.l1_tile_shape[1], + l1k=self.l1_tile_shape[2], + ) + + def l0_tile_typename(self): + return "GemmShape<{l0m}, {l0n}, {l0k}>".format( + l0m=self.l0_tile_shape[0], + l0n=self.l0_tile_shape[1], + l0k=self.l0_tile_shape[2], + ) diff --git a/torch_npu/_inductor/codegen/npu/catlass_library/manifest.py b/torch_npu/_inductor/codegen/npu/catlass_library/manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..9630e3440cfc04d82199435eb39f51165733ffb5 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/catlass_library/manifest.py @@ -0,0 +1,40 @@ +from typing import List + +from .gemm_operation import GemmOperation + + +class Manifest: + + # + def __init__(self, args=None): + self.operations = {} + self.args = args + self.operation_count = 0 + + @staticmethod + def _get_shape_key(shape_desc) -> str: + if shape_desc is None or not isinstance(shape_desc, tuple): + shape_desc = "default" + else: + shape_desc = "x".join(str(dim) for dim in shape_desc) + + def append(self, operation, shape_desc) -> None: + """ + Inserts the operation. 
+ + shape_desc -> [] + """ + self.operations.setdefault(self._get_shape_key(shape_desc), []).append( + operation + ) + self.operation_count += 1 + + def get_ops(self, shape_desc) -> List[GemmOperation]: + key = self._get_shape_key(shape_desc) + if key in self.operations: + return self.operations[key] + else: + return None + + +manifest = Manifest() diff --git a/torch_npu/_inductor/codegen/npu/catlass_utils.py b/torch_npu/_inductor/codegen/npu/catlass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..000be9b24f22a2aa309c053d81df799f034ba621 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/catlass_utils.py @@ -0,0 +1,122 @@ +import functools +import logging +import os +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import torch + +from torch._inductor.runtime.runtime_utils import cache_dir +from torch._inductor.virtualized import V + +from . import catlass_library + + +log = logging.getLogger("torch._inductor") + + +def get_npu_arch() -> Optional[str]: + try: + from ... import config as npu_config + + npu_arch = npu_config.target.arch + return npu_arch + except Exception as e: + log.error("Error getting npu arch: %s", e) + return None + + +def _normalize_npu_arch(arch: str) -> str: + if "910B" in arch or arch.startswith("Ascend910_93"): + return "910B" + else: + raise NotImplementedError(f"Unsupported npu arch: {arch}") + + +@functools.lru_cache(None) +def _gen_ops_cached(arch: str, shape_desc=None) -> List[Any]: + from .catlass_library import generator as catlass_generator + from .catlass_library.manifest import manifest + + if arch is None: + log.error( + "Cannot detect npu arch %s. " + "Will discard all catlass ops. " + "Please consider setting _inductor.npu.arch configs.", + arch, + ) + return [] + + arch = _normalize_npu_arch(arch) + + try: + func = getattr(catlass_generator, "Generate" + arch) + func(manifest, shape_desc) + except AttributeError as e: + raise NotImplementedError( + "Arch " + arch + " is not supported by current catlass lib." 
+ ) from e + + return manifest.get_ops(shape_desc) + + +def gen_ops(shape_desc=None) -> List[Any]: + """ + Generates all supported CATLASS Gemm operations for M, N, K + """ + arch = get_npu_arch() + return _gen_ops_cached(arch, shape_desc) + + +def torch_dtype_to_catlass_type( + torch_dtype: torch.dtype, +) -> "catlass_library.library.DataType": # type: ignore[name-defined] # noqa: F821 + if torch_dtype == torch.float: + return catlass_library.library.DataType.f32 + elif torch_dtype == torch.half: + return catlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return catlass_library.library.DataType.bf16 + else: + raise NotImplementedError(f"Unsupported data type: {torch_dtype=}") + + +def dtype_match( + torch_dtype: Optional[torch.dtype], + catlass_dtype: "catlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821 +) -> bool: + if torch_dtype == torch.float: + return catlass_dtype == catlass_library.library.DataType.f32 + elif torch_dtype == torch.half: + return catlass_dtype == catlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return catlass_dtype == catlass_library.library.DataType.bf16 + elif torch_dtype == torch.int8: + return catlass_dtype == catlass_library.library.DataType.s8 + elif torch.dtype == torch.uint8: + return catlass_dtype == catlass_library.library.DataType.u8 + elif torch.dtype == torch.int32: + return catlass_dtype == catlass_library.library.DataType.s32 + else: + return False + + +def get_accumulator_dtype( + input_torch_dtypes: List[torch.dtype], +) -> Optional[torch.dtype]: + """ + Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. + """ + + if len(input_torch_dtypes) != 2: + return None + + torch_dtype = None + if input_torch_dtypes[0] == input_torch_dtypes[1]: + torch_dtype = input_torch_dtypes[0] + + if torch_dtype in {torch.half, torch.bfloat16, torch.float}: + return torch.float + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}") diff --git a/torch_npu/_inductor/codegen/npu/gemm_template.py b/torch_npu/_inductor/codegen/npu/gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b97acf820994d023f50efa1a940783db6e2cfb --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/gemm_template.py @@ -0,0 +1,1377 @@ +import copy +import enum +import logging +import re +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Tuple, Union + +import sympy +from torch._inductor import ir +from torch._inductor.codegen.common import IndentedBuffer +from torch._inductor.ir import (Buffer, ChoiceCaller, FixedLayout, IRNode, + Layout, ReinterpretView) +from torch._inductor.utils import is_dynamic +from torch._inductor.virtualized import V + +from ...config import npu as inductor_npu_config +from . import catlass_utils +from .catlass_library import library as catlass_lib +from .catlass_library.gemm_operation import GemmOperation +from .npu_kernel import NPUTemplateBuffer, NPUTemplateKernel +from .npu_template import CATLASSTemplate + +log = logging.getLogger("torch._inductor") + + +# Optimized Matmul template +OPT_MM_TEMPLATE_CATLASS_1X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
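+// The two-phase convention above matches NPUBenchmarkRequest in
+// torch_npu/_inductor/autotune_process.py: update_workspace_size() first calls this
+// entry point with a workspace_size pointer, then make_run_fn() launches the kernel
+// with the allocated workspace.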
+ +template< + class ArchTag, + class AType, + class BType, + class CType, + class BiasType = void +> +struct TileCopyOpt : public Catlass::Gemm::Tile::TileCopy { + using Base = Catlass::Gemm::Tile::TileCopy; + using ElementA = typename Base::ElementA; + using ElementB = typename Base::ElementB; + using ElementAccumulator = typename Base::ElementAccumulator; + + // When matrix A is row-major, if the number of rows in matrix A is less than 16, + // using the CopyGmToL1IntervalDataCopy method can improve the transfer efficiency. + // The situation is similar for matrix B. If the above conditions are met, + // please uncomment the following and comment out the original matrix A transfer method + + // using CopyGmToL1A = Gemm::Tile::CopyGmToL1IntervalDataCopy; + + using CopyGmToL1A = typename Base::CopyGmToL1A; + using CopyGmToL1B = typename Base::CopyGmToL1B; + + using CopyL1ToL0A = typename Base::CopyL1ToL0A; + using CopyL1ToL0B = typename Base::CopyL1ToL0B; + + using CopyL0CToGm = typename Base::CopyL0CToGm; + using BiasTypeSelector = typename Base::BiasTypeSelector; + using CopyGmToL1Bias = typename Base::CopyGmToL1Bias; + using CopyL1ToBT = typename Base::CopyL1ToBT; +}; + +{{template.render_gemm_arguments(op_instance, argument_template, epilogue_template, + X, W, Bias, Y, alpha, beta, kernel)}} + +template +int LaunchGemmKernelImpl( + const GemmCoord& problemShape, + const LayoutA& layoutA, const LayoutB& layoutB, + uint8_t* deviceA, uint8_t* deviceB, uint8_t* deviceC, + size_t* workspace_size, uint8_t* workspace, aclrtStream stream) +{ + using TileCopy = TileCopyOpt, + std::conditional_t, + CType>; + + using BlockMmadOpt = Gemm::Block::BlockMmad< + DispatchPolicy, L1TileShape, L0TileShape, + std::conditional_t, + std::conditional_t, + CType, void, TileCopy>; + + using GemmKernel = Gemm::Kernel::{{op_instance.gemm_typename()}}< + std::conditional_t, + std::conditional_t, + BlockMmadOpt, BlockEpilogue, + std::conditional_t>; + + {{kernel_arguments}} + + using GemmAdapter = Gemm::Device::DeviceGemm; + GemmAdapter gemm_op; + + if (workspace_size) { + *workspace_size = gemm_op.GetWorkspaceSize(arguments); + return 0; + } + + {{ffts_addr_prepare}} + auto aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic(); + { + auto status = gemm_op.CanImplement(arguments); + CATLASS_CHECK(status); + } + { + auto status = gemm_op.Initialize(arguments, workspace); + CATLASS_CHECK(status); + } + { + auto status = {{kernel_call}} + CATLASS_CHECK(status); + } + + return 0; +} + +extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + uint32_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; + uint32_t m = {{kernel.size(X, -2)}}; + uint32_t k = {{kernel.size(X, -1)}}; + uint32_t n = {{kernel.size(W, -1)}}; + + GemmCoord problemShape{m, n, k}; + + // Define the layout of each matrix + LayoutA layoutA = LayoutA::template MakeLayout(m, k); + LayoutB layoutB = LayoutB::template MakeLayout(k, n); + LayoutC layoutC = LayoutC::template MakeLayout(m, n); + + uint8_t* deviceA = {{template.catlass_type_cast(X, kernel.ptr(X))}}; + uint8_t* deviceB = {{template.catlass_type_cast(W, kernel.ptr(W))}}; + uint8_t* deviceBias = {{template.catlass_type_cast(Bias, kernel.ptr(Bias))}}; + uint8_t* deviceC = {{template.catlass_type_cast(Y, kernel.ptr(Y))}}; + + bool isNeedPaddingA = IsNeedPadding(layoutA, alignByElement); + bool isNeedPaddingB = IsNeedPadding(layoutB, alignByElement); + + if (m > n) { + if (isNeedPaddingA && isNeedPaddingB) { + LaunchGemmKernelImpl(problemShape, layoutA, layoutB, 
+ deviceA, deviceB, deviceC, workspace_size, workspace, stream); + } else if (isNeedPaddingA) { + LaunchGemmKernelImpl(problemShape, layoutA, layoutB, + deviceA, deviceB, deviceC, workspace_size, workspace, stream); + } else if (isNeedPaddingB) { + LaunchGemmKernelImpl(problemShape, layoutA, layoutB, + deviceA, deviceB, deviceC, workspace_size, workspace, stream); + } else { + LaunchGemmKernelImpl(problemShape, layoutA, layoutB, + deviceA, deviceB, deviceC, workspace_size, workspace, stream); + } + } else { + if (isNeedPaddingA && isNeedPaddingB) { + LaunchGemmKernelImpl(problemShape, layoutA, layoutB, + deviceA, deviceB, deviceC, workspace_size, workspace, stream); + } else if (isNeedPaddingA) { + LaunchGemmKernelImpl(problemShape, layoutA, layoutB, + deviceA, deviceB, deviceC, workspace_size, workspace, stream); + } else if (isNeedPaddingB) { + LaunchGemmKernelImpl(problemShape, layoutA, layoutB, + deviceA, deviceB, deviceC, workspace_size, workspace, stream); + } else { + LaunchGemmKernelImpl(problemShape, layoutA, layoutB, + deviceA, deviceB, deviceC, workspace_size, workspace, stream); + } + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + + +OPT_MM_ARGS_CATLASS_1X = r""" + // Initialize GemmUniversal1xInstance arguments. + + // Define ArchTag + using ArchTag = {{op_instance.arch_typename()}}; + + using ElementA = {{kernel.catlass_dtype(X)}}; + using ElementB = {{kernel.catlass_dtype(W)}}; + using ElementC = {{kernel.catlass_dtype(Y)}}; + + constexpr uint32_t alignByByte = 512; + constexpr uint32_t alignByElement = alignByByte / sizeof(ElementC); + + // Define the Layout + using LayoutA = {{op_instance.layoutA_typename()}}; + using LayoutB = {{op_instance.layoutB_typename()}}; + using LayoutC = {{op_instance.layoutC_typename()}}; + using LayoutBias = Catlass::layout::VectorLayout; + + using AType = Gemm::GemmType; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + // Define padding layout + static const uint32_t COMPUTE_LENGTH_A = 96 * 1024 / sizeof(ElementA); + static const uint32_t COMPUTE_LENGTH_B = 96 * 1024 / sizeof(ElementB); + using PaddingTag = Catlass::Gemm::Kernel::PaddingTag; + constexpr PaddingTag paddingTagA = (std::is_same_v || std::is_same_v ? + PaddingTag::NO_PADDING : PaddingTag::PADDING_BLOCK_ND); + constexpr PaddingTag paddingTagB = (std::is_same_v || std::is_same_v ? 
+        PaddingTag::NO_PADDING : PaddingTag::PADDING_BLOCK_ND);
+    using PaddingBuilderA = Catlass::Gemm::Kernel::PaddingBuilder<
+        ArchTag, ElementA, LayoutA, COMPUTE_LENGTH_A, paddingTagA>;
+    using GlobalPaddingA = PaddingBuilderA::Padding;
+    using PaddingBuilderB = Catlass::Gemm::Kernel::PaddingBuilder<
+        ArchTag, ElementB, LayoutB, COMPUTE_LENGTH_B, paddingTagB>;
+    using GlobalPaddingB = PaddingBuilderB::Padding;
+
+    using LayoutMmadA = typename PaddingBuilderA::LayoutAfterPadding;
+    using LayoutMmadB = typename PaddingBuilderB::LayoutAfterPadding;
+    using ATypePadding = Gemm::GemmType<ElementA, LayoutMmadA>;
+    using BTypePadding = Gemm::GemmType<ElementB, LayoutMmadB>;
+
+    using DispatchPolicy = {{op_instance.dispatch_policy_typename()}};
+    using L1TileShape = {{op_instance.tile_description.l1_tile_typename()}};
+    using L0TileShape = {{op_instance.tile_description.l0_tile_typename()}};
+    using BlockScheduler30 = typename Gemm::Block::GemmIdentityBlockSwizzle<3, 0>;
+    using BlockScheduler31 = typename Gemm::Block::GemmIdentityBlockSwizzle<3, 1>;
+    using BiasType = {{template.catlass_elem_type(kernel.catlass_dtype(Bias), "LayoutBias")}};
+
+    {{epilogue_arguments}}
+"""
+
+
+# ------ Basic Matmul -------
+
+MM_TEMPLATE_CATLASS_1X = r"""
+{{template.header().getvalue()}}
+{{template.globals().getvalue()}}
+// When workspace_size is not a nullptr, populates requested workspace_size and returns.
+// Otherwise, computes the Gemm kernel using the given workspace ptr.
+
+{{template.render_gemm_arguments(op_instance, argument_template, epilogue_template,
+    X, W, Bias, Y, alpha, beta, kernel)}}
+
+extern "C" {
+PT_EXPORT {{kernel_call_signature}} {
+    try {
+    uint32_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
+    uint32_t m = {{kernel.size(X, -2)}};
+    uint32_t k = {{kernel.size(X, -1)}};
+    uint32_t n = {{kernel.size(W, -1)}};
+
+    GemmCoord problemShape{m, n, k};
+
+    // Define the layout of each matrix
+    LayoutA layoutA{m, k};
+    LayoutB layoutB{k, n};
+    LayoutC layoutC{m, n};
+
+    uint8_t* deviceA = {{template.catlass_type_cast(X, kernel.ptr(X))}};
+    uint8_t* deviceB = {{template.catlass_type_cast(W, kernel.ptr(W))}};
+    uint8_t* deviceBias = {{template.catlass_type_cast(Bias, kernel.ptr(Bias))}};
+    uint8_t* deviceC = {{template.catlass_type_cast(Y, kernel.ptr(Y))}};
+
+    {{kernel_arguments}}
+    GemmAdapter gemm_op;
+
+    if (workspace_size) {
+        *workspace_size = gemm_op.GetWorkspaceSize(arguments);
+        return 0;
+    }
+
+    {{ffts_addr_prepare}}
+    auto aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic();
+    {
+        auto status = gemm_op.CanImplement(arguments);
+        CATLASS_CHECK(status);
+    }
+    {
+        auto status = gemm_op.Initialize(arguments, workspace);
+        CATLASS_CHECK(status);
+    }
+    {
+        auto status = {{kernel_call}}
+        CATLASS_CHECK(status);
+    }
+    }
+    catch (std::exception& e) {
+        std::cerr << "Runtime error: " << e.what() << std::endl;
+        return -1;
+    }
+    catch (...) {
+        return -1;
+    }
+    return 0;
+}
+}
+"""
+
+
+MM_KERNEL_CALL_CATLASS_1X = r"""gemm_op(stream, aicCoreNum);"""
+
+
+FFTS_ADDR_PREPARE = r"""
+    uint64_t fftsAddr{0};
+    uint32_t fftsLen{0};
+    RT_CHECK(rtGetC2cCtrlAddr(&fftsAddr, &fftsLen));
+"""
+
+
+MM_FFTS_KERNEL_CALL_CATLASS_1X = r"""gemm_op(stream, aicCoreNum, fftsAddr);"""
+
+
+MM_ARGS_CATLASS_1X = r"""
+    // Initialize GemmUniversal1xInstance arguments.
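+    // (The ArchTag, layouts, tile shapes and dispatch policy below are rendered from the
+    // GemmOperation instance selected during autotuning.)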
+ + // Define ArchTag + using ArchTag = {{op_instance.arch_typename()}}; + + // Define the Layout + using LayoutA = {{op_instance.layoutA_typename()}}; + using LayoutB = {{op_instance.layoutB_typename()}}; + using LayoutC = {{op_instance.layoutC_typename()}}; + using LayoutBias = Catlass::layout::VectorLayout; + + using DispatchPolicy = {{op_instance.dispatch_policy_typename()}}; + using L1TileShape = {{op_instance.tile_description.l1_tile_typename()}}; + using L0TileShape = {{op_instance.tile_description.l0_tile_typename()}}; + + using AType = {{template.catlass_elem_type(kernel.catlass_dtype(X), "LayoutA")}}; + using BType = {{template.catlass_elem_type(kernel.catlass_dtype(W), "LayoutB")}}; + using CType = {{template.catlass_elem_type(kernel.catlass_dtype(Y), "LayoutC")}}; + using BiasType = {{template.catlass_elem_type(kernel.catlass_dtype(Bias), "LayoutBias")}}; + + using GemmBlock = Gemm::Block::BlockMmad; + using BlockScheduler = {{op_instance.swizzle_typename()}}; + {{epilogue_arguments}} + + using GemmKernel = Gemm::Kernel::{{op_instance.gemm_typename()}}; + using GemmAdapter = Gemm::Device::DeviceGemm; +""" + + +MM_ARGS_CATLASS_1X_VOID_EPILOGUE = r""" + using BlockEpilogue = void; +""" + + +MM_ARGS_CATLASS_1X_EPILOGUE = r""" + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2ElemWiseOneSource; + using DType = BiasType; + using ComputeType = CType; + constexpr uint32_t computeLength = 16384; + using TileElemWiseEpilogue = Epilogue::Tile::TileElemWiseAdd; + using EpilogueTileCopy = Epilogue::Tile::TileCopy; + using BlockEpilogue = Epilogue::Block::BlockEpilogue; +""" + + +MM_KERNEL_ARGUMENTS_CATLASS_1X = r""" + typename GemmKernel::Arguments arguments{problemShape, deviceA, deviceB, deviceC}; +""" + + +BMM_KERNEL_ARGUMENTS_CATLASS_1X = r""" + typename GemmKernel::Arguments arguments{B, problemShape, deviceA, deviceB, deviceC}; +""" + + +MMBIAS_KERNEL_ARGUMENTS_CATLASS_1X = r""" + typename GemmKernel::Arguments arguments{problemShape, deviceA, deviceB, deviceC, deviceBias}; +""" + + +GEMM_KERNEL_ARGUMENTS_CATLASS_1X = r""" + typename EpilogueBlock::Params epilogueParams{alpha, beta, deviceBias, layoutBias, deviceC, layoutC}; + typename GemmKernel::Arguments arguments{problemShape, align, deviceA, deviceB, workspace, deviceWA, deviceWB, epilogueParams}; +""" + + +GEMM_TEMPLATE_CATLASS_1X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
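+// In practice the generated entry point is invoked twice: once with a valid workspace_size
+// pointer to query the required allocation, and once with workspace_size == nullptr and an
+// allocated workspace to launch the kernel.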
+ +{{template.render_gemm_arguments(op_instance, argument_template, epilogue_template, + X, W, Bias, Y, alpha, beta, kernel)}} + +extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + uint32_t B = {{kernel.size(Y, 0, -3, default_value=1)}}; + uint32_t m = {{kernel.size(X, -2)}}; + uint32_t k = {{kernel.size(X, -1)}}; + uint32_t n = {{kernel.size(W, -1)}}; + + GemmCoord problemShape{m, n, k}; + + // define scalar + float alpha = {{alpha}}; + float beta = {{beta}}; + + // Define the layout of each matrix + LayoutA layoutA{m, k}; + LayoutB layoutB{k, n}; + LayoutC layoutC{m, n}; + LayoutBias layoutBias{m, n}; // TODO(px): check here + + // Define Workspace layout & size + const uint32_t align = 128; + LayoutA layoutWA = GetWorkspaceLayout(layoutA, align); + LayoutB layoutWB = GetWorkspaceLayout(layoutB, align); + size_t sizeWA = GetWorkspaceLen(layoutWA) * sizeof({{kernel.catlass_dtype(X)}}); + size_t sizeWB = GetWorkspaceLen(layoutWB) * sizeof({{kernel.catlass_dtype(W)}}); + size_t sizeWorkspace = static_cast(M) * N * sizeof({{kernel.catlass_dtype(Y)}}); + + uint8_t* deviceA = {{template.catlass_type_cast(X, kernel.ptr(X))}}; + uint8_t* deviceB = {{template.catlass_type_cast(W, kernel.ptr(W))}}; + uint8_t* deviceBias = {{template.catlass_type_cast(Bias, kernel.ptr(Bias))}}; + uint8_t* deviceC = {{template.catlass_type_cast(Y, kernel.ptr(Y))}}; + + if (workspace_size) { + // TODO(px): Gemm's GetWorkspaceSize has no use + // *workspace_size = gemm_op.GetWorkspaceSize(arguments); + if (!IsSameStride(layoutWA, layoutA)) { + sizeWorkspace += sizeWA; + } + if (!IsSameStride(layoutWB, layoutB)) { + sizeWorkspace += sizeWB; + } + *workspace_size = sizeWorkspace; + return 0; + } + + // split the workspace to three part: workspace, deviceWA (optional), deviceWB (optional) + uint8_t* deviceWA = deviceA; + uint8_t* deviceWB = deviceB; + uint8_t* offset = workspace + sizeWorkspace; + if (!IsSameStride(layoutWA, layoutA)) { + deviceWA = offset; + offset += sizeWA; + } + if (!IsSameStride(layoutWB, layoutB)) { + deviceWB = offset; + } + + {{kernel_arguments}} + GemmAdapter gemm_op; + + {{ffts_addr_prepare}} + auto aicCoreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAic(); + { + auto status = gemm_op.CanImplement(arguments); + CATLASS_CHECK(status); + } + { + auto status = gemm_op.Initialize(arguments, workspace); + CATLASS_CHECK(status); + } + { + auto status = {{kernel_call}} + CATLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + + +# Jinja template for Catlass 1.x GEMM Kernel arguments, used by the CATLASSGemmTemplate class below. +GEMM_ARGS_CATLASS_1X = r""" + // Initialize GemmUniversal1xInstance arguments. 
+ + // Define ArchTag + using ArchTag = {{op_instance.arch_typename()}}; + + // Define the Layout + using LayoutA = {{op_instance.layoutA_typename()}}; + using LayoutB = {{op_instance.layoutB_typename()}}; + using LayoutC = {{op_instance.layoutC_typename()}}; + using LayoutBias = {{op_instance.layoutD_typename()}}; + + using DispatchPolicy = {{op_instance.dispatch_policy_typename()}}; + using L1TileShape = {{op_instance.tile_description.l1_tile_typename()}}; + using L0TileShape = {{op_instance.tile_description.l0_tile_typename()}}; + + using AType = {{template.catlass_elem_type(kernel.catlass_dtype(X), "LayoutA")}}; + using BType = {{template.catlass_elem_type(kernel.catlass_dtype(W), "LayoutB")}}; + using CType = {{template.catlass_elem_type(kernel.catlass_dtype(Y), "LayoutC")}}; + using BiasType = {{template.catlass_elem_type(kernel.catlass_dtype(Bias), "LayoutBias")}}; + + using GemmBlock = Gemm::Block::BlockGemm; + + {{epilogue_arguments}} + + using GemmKernel = Gemm::Kernel::{{op_instance.gemm_typename()}}; + using GemmAdapter = Gemm::Device::DeviceGemm; +""" + + +# Jinja template for Catlass 1.x GEMM Kernel arguments if epilogue fusion is applied, +# used by the CATLASSGemmTemplate class below. +GEMM_ARGS_CATLASS_1X_EPILOGUE = r""" + // TODO(px): define GemmOperation's epilogue dispatch policy + using EpilogueBlockDispatchPolicy = Catlass::Epilogue::EpilogueAtlasA2Gemm; + using DType = BiasType; + using ComputeType = CType; + using TileShapeCast = MatrixShape; + constexpr uint32_t computeLength = L1TileShape::MN / 2; + using TileElemWiseAddGemm = Epilogue::Tile::TileElemWiseAdd; + using TileElemWiseMulsGemm = Epilogue::Tile::TileElemWiseMuls; + using TileElemWiseCastD = Epilogue::Tile::TileCast; + using EpilogueTileCopy = Epilogue::Tile::TileCopy; + using EpilogueBlock = Epilogue::Block::BlockEpilogue; +""" + + +class CATLASSGemmTemplate(CATLASSTemplate, ABC): + """ + CATLASS GEMM Template, which is used to generate CATLASS GEMM kernels + including those which allow flexible fusions with epilogues. 
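+
+    Concrete subclasses (for example CATLASS1xGemmTemplate below) implement the abstract
+    hooks (_get_template, _get_template_args, _get_kernel_arguments_and_call, ...) to choose
+    between the BasicMatmul, OptimizedMatmul, BatchedMatmul, MatmulBias and Gemm kernel
+    templates defined above.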
+ """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[List[int]] = None, + ) -> None: + super().__init__("catlass_gemm", input_nodes, layout, input_reorder) + self.alpha = alpha + self.beta = beta + assert len(input_nodes) == 2 or len(input_nodes) == 3 + assert self._are_inputs_layout_compatible( + [node.get_layout() for node in input_nodes] + ) + self.is_batchmm = any(len(node.get_size()) == 3 for node in input_nodes) + self.has_bias = len(input_nodes) == 3 + self.use_gemm_addmm = False + self.shape_desc = self.get_shape_desc(input_nodes) + if self.has_bias: + bias = input_nodes[2] + bias_first_stride = bias.get_stride()[-2] + # For N = 1, cannot distinguish bias shape is (M, 1) or (1,) + # currently use matmulBias for this case + self.use_gemm_addmm = bias_first_stride != 0 and not ( + self.shape_desc[1] == 1 + ) + + @staticmethod + def get_shape_desc(input_nodes) -> Tuple[int, int, int]: + X, W = input_nodes[0], input_nodes[1] + M = X.get_size()[-2] + K = X.get_size()[-1] + N = W.get_size()[-1] + shape_desc = [M, N, K] + + for i, x in enumerate(shape_desc): + if isinstance(x, (int, sympy.Integer)): + shape_desc[i] = int(x) + elif isinstance(x, (sympy.Symbol, sympy.Expr)): + shape_desc[i] = x.subs(V.graph.sizevars.var_to_val) + else: + raise ValueError(f"Unknown shape dim type: {type(x)}, value: {x}") + return tuple(shape_desc) + + @staticmethod + @abstractmethod + def add_catlass_gemm_choices( + choices: List[ChoiceCaller], + layout: ir.Layout, + input_nodes: List[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[List[int]] = None, + **extra_kwargs, + ) -> None: + raise NotImplementedError + + @abstractmethod + def _is_op_kind_supported( + self, + op_kind: "catlass_lib.GemmKind", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _get_template(self, op_kind: "catlass_lib.GemmKind") -> str: + raise NotImplementedError + + @abstractmethod + def _get_template_args(self, op_kind: "catlass_lib.GemmKind") -> Tuple[str, str]: + raise NotImplementedError + + @abstractmethod + def _get_kernel_arguments_and_call( + self, op_kind: "catlass_lib.GemmKind" + ) -> Tuple[str, str, str]: + raise NotImplementedError + + @abstractmethod + def _are_inputs_layout_compatible(self, layouts: List[Layout]) -> bool: + raise NotImplementedError + + @abstractmethod + def _shape_match( + self, + op: "GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _set_bias_layout( + self, + op: "GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _get_extra_inputs_and_names( + self, + op: "GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[Optional[Buffer], List[Optional[Buffer]], List[str]]: + raise NotImplementedError + + def _is_standard_matmul(self) -> bool: + return self.alpha == 1.0 and (self.beta == 0.0 or self.beta == 1.0) + + def _add_catlass_gemm_choices( + self, + choices: List[ChoiceCaller], + layout: ir.Layout, + input_nodes: List[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[List[int]] = None, + **extra_kwargs, + ) -> None: + """ + Adds Catlass GEMM configurations choices to the auto-tuning list. + + This function mutates the passed list of choices by appending the choices for Catlass GEMM configs to it. 
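+
+        A rough usage sketch (the surrounding autotune plumbing and tensor names are
+        illustrative, not part of this module)::
+
+            choices: List[ChoiceCaller] = []
+            CATLASS1xGemmTemplate.add_catlass_gemm_choices(
+                choices, layout, [mat_a, mat_b], alpha=1, beta=0
+            )
+            # choices is extended with one candidate per GemmOperation that survives
+            # filtering (capped at inductor_npu_config.catlass_max_profiling_configs).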
+ + Args: + choices (list): The list to which choices are appended. + layout (ir.Layout): The layout configuration. + input_nodes (list): The list of input nodes. + alpha (float,int): Scaling factor, defaults to 1. + beta (float,int): Offset, defaults to 0. + input_reorder (list, optional): Order of the inputs, defaults to None. + **extra_kwargs: Additional keyword arguments. + + """ + ops = self.gen_ops(self.shape_desc) + for name, op in ops: + self.maybe_append_choice( + choices, + description=name, + op=op, + ) + if len(ops) == 0: + input_layouts = [node.get_layout() for node in input_nodes] + input_strides = [node.get_stride() for node in input_nodes] + output_layout = layout + warning_msg = f"No suitable Catlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950 + log.warning(warning_msg) + log.debug( + "Added %d Catlass gemm configs.", + len(ops), + ) + + def filter_op( + self, + op: "GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> Optional["GemmOperation"]: # type: ignore[name-defined] # noqa: F821 + """ + Helper method: + + Determines whether a given Catlass GEMM op definition is suitable for the current + input / output of the operation that this template is supposed to implement. + + Takes memory layout, dtype and support for EVT operations into account, + and filters potentially problematic ops. + + Returns None if the op is not suitable, otherwise returns the op to be used, which might + have been mutated. + """ + + X = self.input_nodes[0] + W = self.input_nodes[1] + + # Filter ops according to the shape match. + if not self._shape_match(op): + return None + + # Filter ops by dtypes. + accumulator_torch_dtype = catlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + catlass_utils.dtype_match(X.get_dtype(), op.A.element) + and catlass_utils.dtype_match(W.get_dtype(), op.B.element) + and catlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.C.element + ) + and catlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return None + + # Filter ops by input layouts. + if not ( + self.layout_match(X.get_layout(), op.A.layout) + and self.layout_match(W.get_layout(), op.B.layout) + ): + return None + + # Update op. + op = copy.deepcopy(op) + + # Set output layout. + op.D.layout = CATLASSGemmTemplate.catlass_layout(self.output_node.get_layout()) + + op.element_epilogue = op.accumulator_type() + + # Set bias layout and alignment. + if not self._set_bias_layout(op): + return None + + return op + + def gen_ops( + self, shape_desc: Tuple[int, int, int] + ) -> "List[Tuple[str, GemmOperation]]": # type: ignore[name-defined] # noqa: F821 + """ + Creates a list of Catlass GemmOperation instances that match the operation this template is designed to represent. + The matching is carried out with respect to the input and output specifications of the operation. + + No function arguments. + + Returns: + List[Tuple[str, GemmOperation]]: A list of (catlass_name, GemmOperation) + tuples that are compatible with the operation requirements of this template. 
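+
+        Candidate ops are de-duplicated by configuration name and the resulting list is
+        truncated to inductor_npu_config.catlass_max_profiling_configs entries.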
+ """ + ops = catlass_utils.gen_ops(shape_desc) + res: Dict[str, GemmOperation] = {} + + for op in ops: + assert isinstance(op, GemmOperation) + if not self._is_op_kind_supported(op.gemm_kind): + continue + + filter_res = self.filter_op(op) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + + log.debug("Got catlass configs: total number of ops: %d, ", len(res)) + return list(res.items())[: inductor_npu_config.catlass_max_profiling_configs] + + def header(self) -> IndentedBuffer: + """ + # Returns a buffer containing CUDA C++ code for the header section of the CATLASS GEMM template. + This section primarily includes the necessary header files. + + Returns: + IndentedBuffer: An instance of IndentedBuffer that contains the generated C++ header code. + """ + res = super().header() + res.splice( + """ + #include "catlass/gemm/block/block_mmad.hpp" + #include "catlass/gemm/block/block_swizzle.hpp" + #include "catlass/gemm/dispatch_policy.hpp" + #include "catlass/gemm/gemm_type.hpp" + #include "catlass/gemm/device/device_gemm.hpp" + #include "catlass/gemm_coord.hpp" + #include "catlass/matrix_coord.hpp" + + // Epilogue + #include "catlass/epilogue/dispatch_policy.hpp" + #include "catlass/epilogue/tile/tile_copy.hpp" + #include "catlass/epilogue/tile/tile_elemwise_add.hpp" + #include "catlass/epilogue/tile/tile_elemwise_muls.hpp" + #include "catlass/epilogue/tile/tile_cast.hpp" + #include "catlass/epilogue/block/block_epilogue.hpp" + + // kernel headers + """ + ) + if not self._is_standard_matmul() or self.use_gemm_addmm: + res.splice( + """ + #include "catlass/gemm/kernel/gemm.hpp" + """ + ) + else: + res.splice( + """ + #include "catlass/gemm/kernel/basic_matmul.hpp" + #include "catlass/gemm/kernel/optimized_matmul.hpp" + #include "catlass/gemm/kernel/batched_matmul.hpp" + #include "catlass/gemm/kernel/matmul_bias.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + if not self._is_standard_matmul() or self.use_gemm_addmm: + res.splice( + """ + // Workspace util funcs + layout::RowMajor GetWorkspaceLayout(layout::RowMajor layout, uint32_t align) + { + if (align == 0) { + return layout; + } + return layout::RowMajor(layout.shape(0), layout.shape(1), RoundUp(layout.shape(1), align)); + } + + + layout::ColumnMajor GetWorkspaceLayout(layout::ColumnMajor layout, uint32_t align) + { + if (align == 0) { + return layout; + } + return layout::ColumnMajor(layout.shape(0), layout.shape(1), RoundUp(layout.shape(0), align)); + } + + + size_t GetWorkspaceLen(layout::RowMajor layout) + { + return layout.shape(0) * layout.stride(0); + } + + + size_t GetWorkspaceLen(layout::ColumnMajor layout) + { + return layout.shape(1) * layout.stride(1); + } + + bool IsSameStride(layout::RowMajor layout1, layout::RowMajor layout2) + { + return layout1.stride(0) == layout2.stride(0); + } + + bool IsSameStride(layout::ColumnMajor layout1, layout::ColumnMajor layout2) + { + return layout1.stride(1) == layout2.stride(1); + } + + """ + ) + else: + res.splice( + """ + bool IsNeedPadding(layout::RowMajor layout, uint32_t align) + { + // If the stride is greater than 65536, padding is required to reduce the stride. + if (layout.stride(0) < 65536) { + return layout.stride(0) % align != 0; + } else { + return true; + } + } + + bool IsNeedPadding(layout::ColumnMajor layout, uint32_t align) + { + // If the stride is greater than 65536, padding is required to reduce the stride. 
+ if (layout.stride(1) < 65536) { + return layout.stride(1) % align != 0; + } else { + return true; + } + } + + bool IsNeedPadding(layout::zN layout, uint32_t align) + { + return false; + } + + bool IsNeedPadding(layout::nZ layout, uint32_t align) + { + return false; + } + """ + ) + return res + + @staticmethod + def catlass_layout(torch_layout: ir.Layout) -> "Optional[catlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + """ + Converts an ir.Layout instance into the corresponding catlass layout str + (RowMajor, ColumnMajor, or None if no matching value is found ). + + Args: + torch_layout (ir.Layout): The layout that needs to be looked up. + + Returns: + str: The converted layout corresponding to the `torch_layout` or None if no matching + value is found. + """ + + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return catlass_lib.LayoutType.RowMajor + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-2], 1): + return catlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def catlass_elem_type( + catlass_dtype: str, + catlass_layout: str, + ) -> str: + if catlass_dtype == "void": + return "void" + else: + return f"Gemm::GemmType<{catlass_dtype}, {catlass_layout}>" + + @staticmethod + def layout_match( + torch_layout: ir.Layout, + catlass_layout: "catlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """Helper Method: Determines whether a given torch layout matches a given Catlass layout""" + return CATLASSGemmTemplate.catlass_layout(torch_layout) == catlass_layout + + def render( # type: ignore[override] + self, + kernel: NPUTemplateKernel, + op: "GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[NPUTemplateBuffer] = None, + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + Renders the Catlass based C++ code for the GEMM Kernel that this template is designed to implement, + including potentially fused epilogues. + + Args: + kernel (NPUTemplateKernel): The kernel to be rendered. + op (GemmOperation, optional): A GEMM operation that is required to be compatible with the + input and output definitions as well as a possible epilogue. Defaults to None. + **kwargs: Additional keyword arguments. Currently unused. + + Returns: + str: Catlass based C++ code fragment as a string, to be used by the current + NPUTemplateKernel or autotuning code. + + Note: + All inputs and their corresponding buffer addresses and names take precedence over previously + passed inputs to the template at construction time. However, they should be layout compatible. 
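+
+            If the layouts did change since autotuning (e.g. FlexibleLayout inputs), the op is
+            patched via fix_op_layout below, possibly at the cost of suboptimal performance.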
+ """ + assert isinstance( + op, GemmOperation + ), "op argument is required and has to be an instance of GemmOperation" + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + if not isinstance(X.layout, FixedLayout): + raise NotImplementedError("X.layout is not fixed") + if not isinstance(W.layout, FixedLayout): + raise NotImplementedError("W.layout is not fixed") + + Y = self.output_node + if template_buffer_node is not None: + Y = template_buffer_node + + Bias, extra_inputs, extra_names = self._get_extra_inputs_and_names() + + # Define Kernel call signature + # Important: This step also populates Kernel name to node mapping data structures, + # which are required further below ( for example by CutlassEVTEpilogueArgumentFormatter and + # the template renderer ) + inputs = [X, W, Bias, *extra_inputs] + names = ["X", "W", "Bias", *extra_names] + ["Y"] + names_str = ",".join(names) + if self.input_reorder is not None: + input_reorder = self.input_reorder + else: + input_reorder = None + kernel_call_signature = kernel.def_kernel( + inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type] + ) + test_call_statement = self.test_call_statement(kernel, inputs, names_str) + + # The layouts might have changed between autotuning and this call if they were FlexibleLayout + # we need to adapt, which might lead to suboptimal performance. + op = self.fix_op_layout(op, X, W, Bias, Y) + + # to make op mutable without affecting others + op = copy.deepcopy(op) + if Bias is not None: + assert Bias.get_layout().dtype == X.get_layout().dtype + # This might have been set to void during filtering, when the assumption was still that there's no C + # operand + op.C.element = op.A.element + + argument_template, epilogue_template = self._get_template_args(op.gemm_kind) + kernel_arguments, ffts_addr_prepare, kernel_call = ( + self._get_kernel_arguments_and_call(op.gemm_kind) + ) + + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + kernel_call_signature=kernel_call_signature, + Bias=Bias, + argument_template=argument_template, + epilogue_template=epilogue_template, + kernel_arguments=kernel_arguments, + ffts_addr_prepare=ffts_addr_prepare, + kernel_call=kernel_call, + template=self, + kernel=kernel, + op_instance=op, + input_reorder=self.input_reorder, + test_call_statement=test_call_statement, + ) + options.update(dict(zip(extra_names, extra_inputs))) + res = self._template_from_string(self._get_template(op.gemm_kind)).render( + **options + ) + + return res + + def fix_op_layout( + self, + op: "GemmOperation", # type: ignore[name-defined] # noqa: F821 + X: Buffer, + W: Buffer, + Bias: Optional[Buffer], + Y: Union[Buffer, ReinterpretView], + ) -> "GemmOperation": # type: ignore[name-defined] # noqa: F821 + # This is a workaround to deal with cases where the input layouts have changed + # between autotuning and rendering. This happens if the inputs layout + # are FlexibleLayout instances. In this case, we need to update the + # op's input layouts. It is a hack, because now the op + # we benchmarked is not the same as the op we render, + # but there is no simple way to fix this in the autotuner, since that would + # potentially disable other optimizations. 
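+        # Illustrative example: X may have been benchmarked as RowMajor (stride[-1] == 1)
+        # but later realized as ColumnMajor (stride[-2] == 1); the op's layout fields are
+        # patched below so the rendered kernel matches the buffers it actually receives.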
+ a_layout = X.get_layout() + b_layout = W.get_layout() + c_layout = Bias.get_layout() if Bias is not None else None + + d_layout = copy.deepcopy(Y.get_layout()) + match_list = [ + CATLASSGemmTemplate.layout_match(buf.get_layout(), op_layout) + for buf, op_layout in zip( + (X, W, Bias, Y), + (op.A.layout, op.B.layout, op.C.layout, op.D.layout), + ) + if buf is not None + ] + all_match = all(match_list) + if all_match: + return op + log.warning( + f"Catlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004, B950 + ) + new_op = copy.deepcopy(op) + + if a_layout is not None: + new_op.A.layout = CATLASSGemmTemplate.catlass_layout(a_layout) + if b_layout is not None: + new_op.B.layout = CATLASSGemmTemplate.catlass_layout(b_layout) + if c_layout is not None: + new_op.C.layout = CATLASSGemmTemplate.catlass_layout(c_layout) + new_op.C.element = catlass_utils.torch_dtype_to_catlass_type(c_layout.dtype) + if d_layout is not None: + new_op.D.layout = CATLASSGemmTemplate.catlass_layout(d_layout) + return new_op + + def test_call_statement( + self, + kernel, + input_nodes, + names_str: str = "", + ) -> str: + """ + Helper method to render the Catlass C++ code required for calling the GEMM operation in the standalone + test runner that might also be generated along with the rest of the code, if the corresponding config is + enabled. + + Returns a C++ statement that calls the GEMM operation with the correct arguments. + """ + _, __, arg_types = kernel.args.cpp_argdefs() + arg_names = [name.strip() for name in names_str.strip().split(",")] + if input_nodes[2] is None: + del arg_names[2] + arguments = [ + f"(({arg_type}){arg_name}_data.get())" + for arg_type, arg_name in zip(arg_types, arg_names) + ] + return f"{kernel.kernel_name}({', '.join(arguments)}, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" + + +class CATLASS1xGemmTemplate(CATLASSGemmTemplate): + def __init__( + self, + input_nodes: List[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[List[int]] = None, + ): + super().__init__(input_nodes, layout, alpha, beta, input_reorder) + + @staticmethod + def add_catlass_gemm_choices( + choices: List[ChoiceCaller], + layout: ir.Layout, + input_nodes: List[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[List[int]] = None, + **extra_kwargs, + ) -> None: + template = CATLASS1xGemmTemplate( + input_nodes, layout, alpha, beta, input_reorder + ) + template._add_catlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + def _is_op_kind_supported( + self, + op_kind: "catlass_lib.GemmKind", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + if not self._is_standard_matmul() or self.use_gemm_addmm: + if self.is_batchmm: + # GEMM template limitation + return False + else: + return op_kind == catlass_lib.GemmKind.Gemm + + if self.is_batchmm and self.has_bias: + return False + + if self.is_batchmm: + return op_kind == catlass_lib.GemmKind.BatchedMatmul + + if self.has_bias: + supported_kinds = {catlass_lib.GemmKind.MatmulBias} + if not inductor_npu_config.catlass_ignore_gemm_in_standard_mm: + supported_kinds.add(catlass_lib.GemmKind.Gemm) + return op_kind in supported_kinds + + supported_kinds = { + catlass_lib.GemmKind.BasicMatmul, + catlass_lib.GemmKind.OptimizedMatmul, + } + if not 
inductor_npu_config.catlass_ignore_gemm_in_standard_mm: + supported_kinds.add(catlass_lib.GemmKind.Gemm) + return op_kind in supported_kinds + + def _get_extra_inputs_and_names( + self, + op: "GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> Tuple[Optional[Buffer], List[Optional[Buffer]], List[str]]: + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + inputs: List[Optional[Buffer]] = [] + names: List[str] = [] + return (Bias, inputs, names) + + def _shape_match( + self, + op: "GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + X, W = self.input_nodes[0], self.input_nodes[1] + return X.get_size()[-1] == W.get_size()[-2] + + def _set_bias_layout( + self, + op: "GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None: + Bias = self.input_nodes[2] + bias_layout = CATLASSGemmTemplate.catlass_layout(Bias.get_layout()) + if bias_layout != op.D.layout: + # bias and output layout must match + return False + op.C.layout = bias_layout + else: + op.C.element = catlass_lib.DataType.void + op.C.layout = op.D.layout + return True + + def _get_template(self, op_kind: "catlass_lib.GemmKind") -> str: + if op_kind == catlass_lib.GemmKind.Gemm: + return GEMM_TEMPLATE_CATLASS_1X + elif op_kind == catlass_lib.GemmKind.OptimizedMatmul: + return OPT_MM_TEMPLATE_CATLASS_1X + else: + return MM_TEMPLATE_CATLASS_1X + + def _get_template_args(self, op_kind: "catlass_lib.GemmKind") -> Tuple[str, str]: + if op_kind == catlass_lib.GemmKind.Gemm: + return GEMM_ARGS_CATLASS_1X, GEMM_ARGS_CATLASS_1X_EPILOGUE + + # No kernel need epilogue currently + MM_EPILOGUE = MM_ARGS_CATLASS_1X_VOID_EPILOGUE + + if op_kind == catlass_lib.GemmKind.OptimizedMatmul: + return OPT_MM_ARGS_CATLASS_1X, MM_EPILOGUE + else: + return MM_ARGS_CATLASS_1X, MM_EPILOGUE + + def _get_kernel_arguments_and_call( + self, op_kind: "catlass_lib.GemmKind" + ) -> Tuple[str, str, str]: + KERNEL_ARGUMENTS = MM_KERNEL_ARGUMENTS_CATLASS_1X + if op_kind == catlass_lib.GemmKind.BatchedMatmul: + KERNEL_ARGUMENTS = BMM_KERNEL_ARGUMENTS_CATLASS_1X + if op_kind == catlass_lib.GemmKind.MatmulBias: + KERNEL_ARGUMENTS = MMBIAS_KERNEL_ARGUMENTS_CATLASS_1X + if op_kind == catlass_lib.GemmKind.Gemm: + KERNEL_ARGUMENTS = GEMM_KERNEL_ARGUMENTS_CATLASS_1X + + if ( + op_kind == catlass_lib.GemmKind.Gemm + or op_kind == catlass_lib.GemmKind.OptimizedMatmul + ): + return KERNEL_ARGUMENTS, FFTS_ADDR_PREPARE, MM_FFTS_KERNEL_CALL_CATLASS_1X + return KERNEL_ARGUMENTS, "", MM_KERNEL_CALL_CATLASS_1X + + def _are_inputs_layout_compatible(self, layouts: List[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for set of operations supported by this class. + + Args: + layouts (List[Layout]): List containing Layout objects representing + the input matrices + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert len(layouts) == 2 or len(layouts) == 3 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) < 1: + return False + if len(B_layout.size) < 1: + return False + A_size = list(V.graph.sizevars.size_hints(A_layout.size)) + B_size = list(V.graph.sizevars.size_hints(B_layout.size)) + if len(A_size) < 2: + A_size.insert(0, 1) + if len(B_size) < 2: + A_size.insert(1, 1) + # Are batch dims broadcastable? 
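+        # e.g. A of shape (10, m, k) with B of shape (k, n): B is treated as (1, k, n)
+        # below, and its size-1 batch dim broadcasts against A's batch dim.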
+ while len(A_size) < len(B_size): + A_size.insert(0, 1) + while len(B_size) < len(A_size): + B_size.insert(0, 1) + K = max(A_size[-1], B_size[-2]) + M = A_size[-2] + N = B_size[-1] + if K != A_size[-1] and A_size[-1] != 1: + return False + if K != B_size[-2] and B_size[-1] != 1: + return False + # check batch dim broadcastable + for i in range(len(A_size) - 2): + if A_size[i] != B_size[i] and A_size[i] != 1 and B_size[i] != 1: + return False + if len(layouts) == 3: + C_layout = layouts[2] + C_size = [int(i) for i in C_layout.size] + while len(C_size) < len(A_size): + C_size.insert(0, 1) + # check batch dims + for i in range(len(A_size) - 2): + bd = max(A_size[i], B_size[i]) + if bd != C_size[i] and C_size[i] != 1: + return False + if len(C_size) > len(A_size): + # This may happen if the last elements of C are contiguous and + # their multiplied size equals the last dim size of B + if M != C_size[len(A_size) - 2] and C_size[len(A_size) - 2] != 1: + return False + remaining_size = 1 + for i in range(len(A_size) - 1, len(C_size)): + remaining_size *= C_size[i] + if N != remaining_size and remaining_size != 1: + return False + return True + assert len(C_size) == len(A_size) + if M != C_size[-2] and C_size[-2] != 1: + return False + if N != C_size[-1] and C_size[-1] != 1: + return False + return True + + def render_gemm_arguments( + self, + op: "GemmOperation", + argument_template: str, + epilogue_template: str, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: NPUTemplateKernel, + ) -> str: + """ + Render the Catlass C++ code required for rendering Gemm operation. + + Args: + op (GemmOperation): GemmOperation instance. + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The Bias input tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs + beta (float): Scaling factor for the output tensor. + kernel (NPUTemplateKernel): NPU Template kernel for the operation. + + Returns: + str: A block of Catlass C++ code as a string, ready to be used as arguments for the GEMM operation. 
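+
+        The symbolic names "M", "N" and "K" passed via `options` correspond to the extra
+        `const int` size arguments appended to the kernel signature by
+        NPUTemplateKernel.def_kernel.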
+ """ + options = dict( + op_instance=op, + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + template=self, + kernel=kernel, + M="M", + N="N", + K="K", + ) + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, + **options, + ) + + return arguments diff --git a/torch_npu/_inductor/codegen/npu/npu_cpp_scheduling.py b/torch_npu/_inductor/codegen/npu/npu_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..8c008781f84784f2bc2292e9df222edd51b9a2d7 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/npu_cpp_scheduling.py @@ -0,0 +1,126 @@ +import logging +from typing import cast, Sequence + +from torch._dynamo.utils import counters +from torch._inductor import config +from torch._inductor.codecache import code_hash, get_path + +from torch._inductor.scheduler import ( + BaseSchedulerNode, + BaseScheduling, + Scheduler, + SchedulerNode, +) +from torch._inductor.utils import ( + get_fused_kernel_name, + get_kernel_metadata, + sympy_product, +) +from torch._inductor.virtualized import V +from torch._inductor.codegen.common import IndentedBuffer + +from .npu_kernel import NPUTemplateBuffer + + +log = logging.getLogger("torch._inductor") + + +class NPUCPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for NPU C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by NPUCombinedScheduling. + + It handles fusion decisions and NPU C++ specific template code generation. + """ + + def __init__(self, scheduler: Scheduler) -> None: + super().__init__() + self.scheduler = scheduler + + @classmethod + def get_backend_features(cls, device): + return {} + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_npu_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, NPUTemplateBuffer + ) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["npu", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.npu(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a NPU template, possibly with fused epilogues + """ + 
counters["inductor"]["npu_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_npu_cpp_template( + template_node + ), "Template node passed to NPUScheduler.codegen_template must be a SchedulerNode that wraps a NPUTemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: NPUTemplateBuffer = cast(NPUTemplateBuffer, template_node.node) + kernel, render = ctb.make_kernel_render(ctb) + with kernel: + template_node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code, node_schedule) + + # debug printing values of intermediate tensors + _, call_args, arg_signatures, _ = kernel.args.python_argdefs() + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_signatures, kernel + ) + with debug_printer_manager: + kernel.call_kernel(kernel_name, ctb) + + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() diff --git a/torch_npu/_inductor/codegen/npu/npu_kernel.py b/torch_npu/_inductor/codegen/npu/npu_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d359efdd20b262dd06f4a1cce1d05d492e6b1e --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/npu_kernel.py @@ -0,0 +1,555 @@ +import copy +import logging +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, + Optional, Tuple, Union) + +import torch +from sympy import Expr, symbols +from torch import dtype as torch_dtype +from torch._inductor.codegen.common import (IndentedBuffer, Kernel, KernelArgs, + OpOverrides, WorkspaceArg, + WorkspaceZeroMode) +from torch._inductor.codegen.cpp_utils import DTYPE_TO_CPP, CppPrinter +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.ir import (Buffer, ChoiceCaller, IRNode, Layout, + PrimitiveInfoType, TemplateBuffer, TensorBox) +from torch._inductor.utils import sympy_product +from torch._inductor.virtualized import V + +from ...autotune_process import NPUBenchmarkRequest + +if TYPE_CHECKING: + from .npu_template import ArgInfo, NPUTemplate + + +log = logging.getLogger("torch._inductor") + +cexpr = CppPrinter().doprint + +# NB: for catlass bf16 type +_DTYPE_TO_CPP = copy.deepcopy(DTYPE_TO_CPP) +_DTYPE_TO_CPP[torch.bfloat16] = "bfloat16_t" + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +ValidLayoutSymbols = Literal["M", "N", "K"] +ValidLayoutAttrs = Literal["size"] + + +@dataclass(frozen=True) +class LayoutArg: + node: IRNode + symbol: ValidLayoutSymbols + attr: ValidLayoutAttrs + dim: int + + def matches(self, node, attr, dim) -> bool: + return self.node == node and self.attr == attr and self.dim == dim + + +class NPUTemplateBuffer(TemplateBuffer): + def __init__( # type: ignore[no-untyped-def] + self, + layout, + inputs, + make_kernel_render, + workspace_size: int, + template: "NPUTemplate", # + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. 
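+        # (populated from NPUBenchmarkRequest.workspace_size when the buffer is created
+        # in NPUTemplateCaller.output_node)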
+ self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self): # type: ignore[no-untyped-def] + return self.workspace_size if self.workspace_size is not None else 0 + + +class NPUKernelArgs(KernelArgs): + @staticmethod + def replace_bf16(s: str) -> str: + return s.replace("bfloat16", "bfloat16_t") + + # HACK for catlass's bfloat16 type + # torch.bfloat16 is "bfloat16_t" in catlass while "bfloat16" in other place + # so we can not modify the DTYPE_TO_CPP dict + def cpp_argdefs(self): + arg_defs, call_args, arg_types = super().cpp_argdefs() + + if any("bfloat16" in s for s in arg_types): + return ( + [self.replace_bf16(s) for s in arg_defs], + call_args, + [self.replace_bf16(s) for s in arg_types], + ) + else: + return arg_defs, call_args, arg_types + + +class NPUKernel(Kernel): + """ + Baseclass for NPU / Catlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + def __init__(self, *args, **kwargs) -> None: + npu_args = NPUKernelArgs() + super().__init__(args=npu_args, *args, **kwargs) + self.layout_args: Dict[str, LayoutArg] = {} + # Mapping from arg name to IRNode. + self.named_nodes: Dict[str, IRNode] = {} + + def find_symbol( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[str]: + arg = self.find_layout_arg(node, attr, dim) + return arg.symbol if arg else None + + def find_layout_arg( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[LayoutArg]: + matches = [arg for arg in self.layout_args.values() if arg.matches(node, attr, dim)] + if len(matches) >= 1: + # Verify all matches have the same node, attribute, and dimension + # And if they come from the same node, whichever symbol we use is fine. + # if in runtime the logic changes, this would trigger guard + first_match = matches[0] + if not all( + match.node == first_match.node + and match.attr == first_match.attr + and match.dim == first_match.dim + for match in matches + ): + raise AssertionError("All matching layout args should be identical") + return first_match + return None + + def add_layout_arg( + self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int + ): + arg = LayoutArg(node, symbol, attr, dim) + self.layout_args.setdefault(symbol, arg) + + def init_layout_args(self) -> None: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + mdim = _normalize_idx(-2, len(X.get_size())) + ndim = _normalize_idx(-1, len(W.get_size())) + kdim = _normalize_idx(-1, len(X.get_size())) + self.add_layout_arg("M", X, "size", mdim) + self.add_layout_arg("N", W, "size", ndim) + self.add_layout_arg("K", X, "size", kdim) + + def get_layout_args(self) -> Tuple[Union[Expr, int], ...]: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + mdim = _normalize_idx(-2, len(X.get_size())) + ndim = _normalize_idx(-1, len(W.get_size())) + kdim = _normalize_idx(-1, len(X.get_size())) + + M = X.get_size()[mdim] + N = W.get_size()[ndim] + K = X.get_size()[kdim] + return (M, N, K) + + @staticmethod + def find_ld_idx(node: IRNode) -> int: + strides = node.get_stride() + # Handle 1D tensor case + if V.graph.sizevars.statically_known_equals(strides[-1], 1): + return _normalize_idx(-2, len(strides)) + + assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2] + return _normalize_idx(-1, len(strides)) + + +class NPUTemplateKernel(NPUKernel): + """ + Template kernels 
defined by NPU / Catlass in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, aclrtStream stream" + + def __init__( + self, + kernel_name: str, + ) -> None: + """ + Initializes a new instance of the NPUTemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + + def arg_name(self, node: IRNode) -> Optional[str]: + """ + Returns arg name of a given input or output node. + """ + if node is None: + return None + return {**self.args.input_buffers, **self.args.output_buffers}.get( + node.get_name(), None + ) + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. + """ + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def get_signature(self) -> str: + return self.signature + + def def_kernel( + self, + inputs: List[IRNode], + outputs: List[IRNode], + names_str: str = "", + input_reorder: Optional[List[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. + """ + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs() + + self.init_layout_args() + + size_args = [f"const int {s}" for s in ("M", "N", "K")] + + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {self._EXTRA_CPP_ARGS})" + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "NPUTemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The NPUTemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. 
+ assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # We always originally initialize name with "KERNEL_NAME". So, we + # we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace("KERNEL_NAME", name) + _, call_args, arg_types = self.args.cpp_argdefs() + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + layout_args = self.get_layout_args() + call_args.extend(layout_args) # type: ignore[arg-type] + arg_types.extend("int" for a in layout_args) + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + elif isinstance(arg_types[i], torch_dtype): + call_args[i] = ( + call_args[i] + if V.graph.cpp_wrapper + else f"c_void_p({call_args[i]}.data_ptr())" + ) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + # workspace_size is here. + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + workspace = str(ws.outer_name) + call_args.append( + workspace + if V.graph.cpp_wrapper + else f"c_void_p({workspace}.data_ptr())" + ) + else: + ws = None + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + + wrapper.generate_kernel_call( + name, + call_args, + gpu=True, # generate non-cpu code + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return _DTYPE_TO_CPP.get(node.get_layout().dtype) + + def catlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: + # Helper method, called into from CATLASSGemmTemplate + if node is None: + return default_dtype + from .npu_template import CATLASSTemplate + + return CATLASSTemplate._DTYPE_TO_CATLASS[node.get_layout().dtype] + + def max_valid_index(self, node: IRNode, default=-1): + # Helper method, called into from CATLASSGemmTemplate + if node is None: + return default + max_valid_offset = 0 + for i in range(len(node.get_size())): + max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] + return max_valid_offset + + def offset(self, node: IRNode) -> str: + """ + Generates code which represents offset of a given node. + """ + + if node is None: + return "0" + return str(node.get_layout().offset) # type: ignore[union-attr] + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. 
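+        For example, the GEMM templates use {{kernel.size(X, -2)}} for M and
+        {{kernel.size(Y, 0, -3, default_value=1)}} for the (possibly absent) batch dimension.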
+ Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + sizes = [ + self.find_symbol(node, "size", dim=i) or node.get_size()[i] + for i in range(start_index, end_index + 1) + ] + if len(sizes) == 0: + return str(default_value) + + sizes = [symbols(v) if isinstance(v, str) else v for v in sizes] + val = sympy_product(sizes) + return val + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + if V.graph.sizevars.statically_known_leq(stride, 1): + return str(stride) + return self.find_symbol(node, "stride", dim=index) or str(stride) + + +class NPUTemplateCaller(ChoiceCaller): + """ + NPUTemplateCaller + + This class represents a caller for NPU template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (NPUBenchmarkRequest): The benchmark request for the caller. + template_buffer (NPUTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[Buffer], + layout: Layout, + make_kernel_render: Callable[[NPUTemplateBuffer, Optional[List[IRNode]]], str], + bmreq: NPUBenchmarkRequest, + template: "NPUTemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + description: str, + ) -> None: + super().__init__(name, input_nodes, layout, description) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark( + *args, output_tensor=out + ) # @TODO: Hack for ensuring that Catlass Kernel is preferred + + def __str__(self) -> str: + return f"NPUTemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"npu_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + return { + "backend": "NPU", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch_typename()), + "tile_shape": str(op.tile_description.procedural_name()), + "dispatch_policy": str(op.dispatch_policy_typename()), + "swizzling": 
str(op.swizzle_typename()), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + } + else: + return {"backend": "NPU", "op_type": "unknown"} + + def output_node(self) -> TensorBox: + self.bmreq.update_workspace_size() + return TensorBox.create( + NPUTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/torch_npu/_inductor/codegen/npu/npu_template.py b/torch_npu/_inductor/codegen/npu/npu_template.py new file mode 100644 index 0000000000000000000000000000000000000000..e11bbb30d796677e3f4ccb8618e998e3039c1bd7 --- /dev/null +++ b/torch_npu/_inductor/codegen/npu/npu_template.py @@ -0,0 +1,243 @@ +import functools +import itertools +import logging +from dataclasses import dataclass +from typing import List, Optional +from unittest.mock import patch + +import sympy + +import torch + +from torch._inductor.autotune_process import TensorMeta +from torch._inductor.ir import Buffer, IRNode, Layout +from torch._inductor.utils import IndentedBuffer, Placeholder, unique +from torch._inductor.virtualized import V +from torch._inductor.codegen.common import KernelTemplate + +from .npu_kernel import NPUTemplateCaller, NPUTemplateKernel, NPUTemplateBuffer +from ...autotune_process import NPUBenchmarkRequest + + +log = logging.getLogger("torch._inductor") + + +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +class NPUTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes: List[Buffer], + layout: Layout, + input_reorder: Optional[List[int]] = None, + ): + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate( # type: ignore[override] + self, + description, + **kwargs, + ) -> NPUTemplateCaller: + """ + Generates the NPU template caller object for the given GEMM template and operation. This NPUTemplateCaller + may be used to call and benchmark the generated NPU kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A NPUTemplateCaller object representing the generated NPU template caller. 
+ """ + kernel_name = f"npu_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), NPUTemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + size_args = V.graph.sizevars.size_hints(kernel.get_layout_args()) + + kernel_hash_name = f"npu_{self.name}_{next(self.index_counter)}" + + # create the BenchmarkRequest + bmreq = NPUBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=size_args, + source_code=code, + ) + + def make_kernel_render( + template_node: NPUTemplateBuffer, + epilogue_nodes: Optional[List[IRNode]] = None, + ): + kernel = NPUTemplateKernel( + kernel_name="KERNEL_NAME", + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CATLASSGemmTemplate + ) + return kernel, render + + return NPUTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + kwargs, + description, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + + #define ACL_CHECK(status) \\ + do { \\ + aclError error = status; \\ + if (error != ACL_ERROR_NONE) { \\ + std::cerr << __FILE__ << ":" << __LINE__ << " aclError:" << error << std::endl; \\ + } \\ + } while (0) + + + // Macro function for unwinding rt errors. + #define RT_CHECK(status) \\ + do { \\ + rtError_t error = status; \\ + if (error != RT_ERROR_NONE) { \\ + std::cerr << __FILE__ << ":" << __LINE__ << " rtError:" << error << std::endl; \\ + } \\ + } while (0) + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + +class CATLASSTemplate(NPUTemplate): + """ + CATLASSTemplate is a class that provides a template for generating CATLASS Templates. Used as a baseclass for the + CATLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CATLASS Kernels. 
+ """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include + #include + #include + + #include "catlass/catlass.hpp" + #include "catlass/arch/arch.hpp" + #include "catlass/layout/layout.hpp" + #include "catlass/status.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace Catlass; + + // Macro function for unwinding catlass errors. + #define CATLASS_CHECK(status) \\ + do { \\ + Catlass::Status error = status; \\ + if (error != Catlass::Status::kSuccess) { \\ + std::cerr << __FILE__ << ":" << __LINE__ << " raise catlassError" << std::endl; \\ + } \\ + } while (0) + + """ + ) + return res + + _DTYPE_TO_CATLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "half", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "bfloat16_t", + } + + def catlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"(uint8_t*)({ptr})" diff --git a/torch_npu/_inductor/codegen/npu_combined_scheduling.py b/torch_npu/_inductor/codegen/npu_combined_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..6965482a96dcc27fa73b19d9ec7b0bccef5e1f6b --- /dev/null +++ b/torch_npu/_inductor/codegen/npu_combined_scheduling.py @@ -0,0 +1,91 @@ +from typing import Sequence, Union + +from torch._inductor.scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) + +from .npu.npu_cpp_scheduling import NPUCPPScheduling +from .scheduling import NPUTritonScheduling + + +class NPUCombinedScheduling(BaseScheduling): + """ + Scheduler for NPU Kernels, which delegates calls as appropriate + to the C++ and Triton Schedulers, which both work for NPU devices + and use a unified-wrapper for codegen. + + If Scheduling code needs to be specialized for the case of mixed Triton / C++ code, + this would also be the place to do it. 
+ """ + + def __init__(self, scheduler: Scheduler) -> None: + super().__init__() + self._scheduler = scheduler + self._triton_scheduling = NPUTritonScheduling(scheduler) + self._npu_cpp_scheduling = NPUCPPScheduling(scheduler) + + def get_backend_features(self, device): # type:ignore[override] + return self._triton_scheduling.get_backend_features(device) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._npu_cpp_scheduling.is_npu_cpp_template(node): + return self._npu_cpp_scheduling + return self._triton_scheduling + + def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + if self._npu_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + for node in (node1, node2): + if self._npu_cpp_scheduling.is_npu_cpp_template(node): + return self._npu_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn(self, sizes): + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ): + if self._npu_cpp_scheduling.is_npu_cpp_template(template_node): + assert epilogue_nodes is None or len(epilogue_nodes) == 0 + return self._npu_cpp_scheduling.codegen_template( + template_node, epilogue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes + ) + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]): + return self._triton_scheduling.codegen_node(node) + + def codegen_sync(self): + return self._triton_scheduling.codegen_sync() + + def flush(self): + return self._triton_scheduling.flush() + + def codegen_combo_kernel(self, *args, **kwargs): + return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs) + + def benchmark_fused_nodes(self, nodes): + return self._triton_scheduling.benchmark_fused_nodes(nodes) + + def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): + return self._triton_scheduling.generate_kernel_code_from_nodes( + nodes, benchmark_kernel + ) + + def benchmark_combo_kernel(self, node_list): + return self._triton_scheduling.benchmark_combo_kernel(node_list) diff --git a/torch_npu/_inductor/config.py b/torch_npu/_inductor/config.py index f9bf23ee3378200f366c568b1e3cd6ed49a451e8..31a653a8d48c64e18fbcb666e079b432f6538351 100644 --- a/torch_npu/_inductor/config.py +++ b/torch_npu/_inductor/config.py @@ -5,6 +5,8 @@ import torch from torch._inductor import config from triton.runtime.driver import driver +from .utils import classproperty + enable_npu_indexing = True config.triton.unique_kernel_names = True @@ -49,6 +51,39 @@ class aot_inductor: dump_path_py = os.environ.get("AOTI_DUMP_PATH_PY", "aoti_dump_py") +class npu: + # Whether to enable debug info, e.g., line number + enable_debug_info: bool = False + + @classproperty + def catlass_dir(self) -> str: + return os.environ.get( + "TORCHINDUCTOR_NPU_CATLASS_DIR", + os.path.abspath( + os.path.join(os.path.dirname(torch.__file__), "../third_party/catlass") + ), + ) + + # Configures the maximum number of CATLASS configs to profile in max_autotune. + # By default it's None, so that all CATLASS configs are tuned. + # This is mainly used to reduce test time in CI. 
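+ # Illustrative example: setting
+ #   torch_npu._inductor.config.npu.catlass_max_profiling_configs = 4
+ # caps autotuning at four CATLASS configs instead of profiling all of them.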
+ catlass_max_profiling_configs: Optional[int] = None
+
+ catlass_backend_min_gemm_size: int = 1
+
+ # Whether to ignore the GEMM template for standard matmul
+ catlass_ignore_gemm_in_standard_mm: bool = True
+
+ # Whether to use catlass matmul autotune to generate tile config
+ catlass_use_gemm_autotune: bool = True
+
+ # Enable generation of an inline standalone runner in the CATLASS C++ generated code,
+ # which allows compiling the generated code into a standalone executable.
+ generate_test_runner: bool = (
+ os.environ.get("INDUCTOR_NPU_BACKEND_GENERATE_TEST_RUNNER_CODE", "0") == "1"
+ )
+
+
 traced_fx_graph_cache = os.environ.get("INDUCTOR_ASCEND_FX_GRAPH_CACHE", None)
 check_accuracy = os.environ.get("INDUCTOR_ASCEND_CHECK_ACCURACY", False)
 auto_fallback = os.environ.get("INDUCTOR_ASCEND_AUTO_FALLBACK", True)
diff --git a/torch_npu/_inductor/kernel/__init__.py b/torch_npu/_inductor/kernel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e040885f621e148f16637594fbd9f2b132b84ab6
--- /dev/null
+++ b/torch_npu/_inductor/kernel/__init__.py
@@ -0,0 +1,2 @@
+from .mm import _register_npu_inductor_mm, _register_npu_inductor_addmm
+from .bmm import _register_npu_inductor_bmm
diff --git a/torch_npu/_inductor/kernel/bmm.py b/torch_npu/_inductor/kernel/bmm.py
new file mode 100644
index 0000000000000000000000000000000000000000..183436becb21bbf391570bea277fa44f2cacedec
--- /dev/null
+++ b/torch_npu/_inductor/kernel/bmm.py
@@ -0,0 +1,117 @@
+import logging
+
+import torch
+from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate
+
+from torch._inductor import ir, lowering as L
+from torch._inductor.select_algorithm import (
+ autotune_select_algorithm,
+ ExternKernelChoice,
+ TritonTemplate,
+)
+from torch._inductor.utils import (
+ ceildiv as cdiv,
+ use_aten_gemm_kernels,
+ use_ck_template,
+ use_cpp_bmm_template,
+ use_triton_template,
+)
+from torch._inductor.virtualized import V
+from torch._inductor.kernel.mm_common import (
+ _is_static_problem,
+ addmm_epilogue,
+ mm_args,
+ mm_configs,
+ mm_options,
+)
+
+from ..utils import use_catlass_template
+
+
+log = logging.getLogger("torch._inductor")
+aten = torch.ops.aten
+
+aten_bmm = torch._inductor.kernel.bmm.aten_bmm
+aten_baddbmm = torch._inductor.kernel.bmm.aten_baddbmm
+bmm_configs = torch._inductor.kernel.bmm.bmm_configs
+bmm_template = torch._inductor.kernel.bmm.bmm_template
+
+
+def _register_npu_inductor_bmm():
+ @L.register_lowering(aten.bmm)
+ def tuned_bmm(mat1, mat2, *, layout=None):
+ if all(x.get_device().type == "cpu" for x in [mat1, mat2]):
+ # decompose to small ops when memory bound
+ if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1:
+ mat1 = L.unsqueeze(mat1, -1)
+ mat2 = L.unsqueeze(mat2, 1)
+ return L.sum_(L.mul(mat1, mat2), axis=2)
+
+ def is_valid_to_require_contiguous(t):
+ if not ir.is_storage_and_layout(t):
+ return True
+ _, layout = ir.as_storage_and_layout(t, freeze=False)
+ return isinstance(layout, ir.FlexibleLayout)
+
+ def is_preferred_layout_as_bmm_input(sizes, strides):
+ # contiguous on one of the last two dims
+ return (
+ strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1])
+ ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2]))
+
+ # Make the input of bmm contiguous
+ # if it is not contiguous on either of the last two dims,
+ # because bmm cpu implementation would do contiguous() if not.
+ # This is to avoid additional copies in bmm.
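+ # Worked example (illustrative): for an input of size (B, M, K), row-major
+ # strides (M*K, K, 1) and column-major strides (M*K, 1, M) both satisfy
+ # is_preferred_layout_as_bmm_input above and are left as-is, whereas a strided
+ # view with strides (2*M*K, 2*K, 2) fails the check and is made contiguous
+ # below (when is_valid_to_require_contiguous allows it).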
+ def may_require_contiguous(t, meta_t): + sizes = meta_t.meta["val"].size() + strides = meta_t.meta["val"].stride() + if not is_preferred_layout_as_bmm_input(sizes, strides): + t = ir.ExternKernel.require_contiguous(t) + return t + + if is_valid_to_require_contiguous(mat1): + meta_mat1 = V.graph.current_node.args[0] + mat1 = may_require_contiguous(mat1, meta_mat1) + if is_valid_to_require_contiguous(mat2): + meta_mat2 = V.graph.current_node.args[1] + mat2 = may_require_contiguous(mat2, meta_mat2) + + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + + # options to tune from + choices = ( + [aten_bmm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + ) + if use_triton_template(layout): + for config in bmm_configs(m, n, k, device_type=ir.get_device_type(mat1)): + bmm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + static_shape, is_nonzero = _is_static_problem(layout) + if static_shape and is_nonzero and use_catlass_template(layout, m, n, k): + from ..codegen.npu.gemm_template import CATLASS1xGemmTemplate + + CATLASS1xGemmTemplate.add_catlass_gemm_choices( + choices, layout, [mat1, mat2] + ) + + if use_cpp_bmm_template(layout, mat1, mat2): + from torch._inductor.codegen.cpp_bmm_template import CppBmmTemplate + + CppBmmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + if use_ck_template(layout): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + + if len(choices) == 0: + log.warning("No choices for GEMM, using ATen backend as fallback") + choices.append(aten_bmm.bind((mat1, mat2), layout)) + + return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout) diff --git a/torch_npu/_inductor/kernel/mm.py b/torch_npu/_inductor/kernel/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..a44bc2038086d039a1f049022b433d69af9cfead --- /dev/null +++ b/torch_npu/_inductor/kernel/mm.py @@ -0,0 +1,298 @@ +import functools +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm +from torch._inductor.autoheuristic.autoheuristic_utils import ( + mm_operations, +) +import torch._inductor.codegen +from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +import torch._inductor.kernel +from torch._inductor.virtualized import V + +from torch._inductor import config as inductor_config, ir +from torch._inductor.codegen.common import BackendFeature +from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate +from torch._inductor.codegen.wrapper import PythonWrapperCodegen +from torch._inductor.ir import FixedLayout, FlexibleLayout, is_triton +from torch._inductor.lowering import register_lowering +from torch._inductor.select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + NoValidChoicesError, + TritonTemplate, +) +from torch._inductor.utils import ( + use_aten_gemm_kernels, + use_ck_gemm_template, + use_cpp_gemm_template, + use_max_autotune, + use_triton_template, +) +from torch._inductor.kernel.mm_common import ( + _is_static_problem, + addmm_epilogue, + extra_mm_configs, + mm_args, + mm_configs, + mm_grid, + mm_options, + triton_config, +) + +from ..codegen.npu.gemm_template import CATLASS1xGemmTemplate +from ..utils import use_catlass_template + + +log = logging.getLogger("torch._inductor") +aten = torch.ops.aten + +lazy_register_extern_choice = 
torch._inductor.kernel.mm.lazy_register_extern_choice +aten_mm = torch._inductor.kernel.mm.aten_mm +aten_addmm = torch._inductor.kernel.mm.aten_addmm +mm_config_kwargs = torch._inductor.kernel.mm.mm_config_kwargs +mm_autoheuristic = torch._inductor.kernel.mm.mm_autoheuristic +mm_template = torch._inductor.kernel.mm.mm_template + + +def _register_npu_inductor_mm(): + @register_lowering(aten.mm, type_promotion_kind=None) + def tuned_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + name = "mm" + + aten_layout = layout + if not use_max_autotune(): + aten_layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + + # options to tune from + choices = ( + [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else [] + ) + static_shape, is_nonzero = _is_static_problem(layout) + if is_nonzero and use_triton_template(layout): + for config in mm_configs( + m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + if is_nonzero and use_catlass_template(layout, m, n, k): + CATLASS1xGemmTemplate.add_catlass_gemm_choices( + choices, layout, [mat1, mat2] + ) + # debug log + log.info(f"Choices number after catlass template: {len(choices)}") + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + + if use_cpp_gemm_template(layout, mat1, mat2): + CppGemmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + + input_nodes = [mat1, mat2] + if ( + is_nonzero + and use_triton_template(layout) + and torch._inductor.config.run_autoheuristic(name) + and is_triton(mat1) + ): + always_included = [] + if use_aten_gemm_kernels(): + always_included.append("extern_mm") + num_choices_before_extra_configs = len(choices) + for config in extra_mm_configs( + m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + # using AutoHeuristic for ranking + ah_choices = mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + mm_operations(), + None, + top_k=10, + always_included=always_included, + ) + if not torch._inductor.config.collect_autoheuristic(name): + # if we are collecting data, we do not want to modify choices + if ah_choices is not None and len(ah_choices) > 0: + # the order in which autoheuristic returns choices is not the same as + # as the order of choices, which affects things like epilogue fusion. 
+ # once epilogue fusion benchmarks choices in sorted order, I think we can + # just use the order returned by autoheuristic + choices = [choice for choice in choices if choice in ah_choices] + else: + choices = choices[:num_choices_before_extra_configs] + + if ( + len(choices) == 0 + and not use_aten_gemm_kernels() + and inductor_config.autotune_fallback_to_aten + ): + log.warning("No choices for GEMM, using ATen backend as fallback") + return aten_mm.bind((mat1, mat2), aten_layout).output_node() + + for k in inductor_config.external_matmul: + choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout)) + + try: + return autotune_select_algorithm(name, choices, [mat1, mat2], layout) + except NoValidChoicesError: + if not inductor_config.autotune_fallback_to_aten: + raise + log.warning( + "All choices for GEMM were invalid, using ATen backend as fallback" + ) + return aten_mm.bind((mat1, mat2), aten_layout).output_node() + + +def _register_npu_inductor_addmm(): + @register_lowering(aten.addmm, type_promotion_kind=None) + def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + ordered_kwargs_for_cpp_kernel = ("beta", "alpha") + m, n, k, layout, mat1, mat2, inp_expanded = mm_args( + mat1, mat2, inp, layout=layout + ) + static_shape, is_nonzero = _is_static_problem(layout) + if (not is_nonzero) or (not use_max_autotune()): + if isinstance(layout, FixedLayout): + layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + choices = ( + [ + aten_addmm.bind( + (inp, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + return autotune_select_algorithm( + "addmm", choices, [inp, mat1, mat2], layout + ) + + choices = ( + [ + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if is_nonzero and use_triton_template(layout): + for config in mm_configs( + m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + if static_shape and is_nonzero and use_catlass_template(layout, m, n, k): + # Filter out broadcasting on the last dim of the bias term + # since catlass does not support it yet. 
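+ # A bias whose last-dim stride is statically known to be 0 is broadcast along
+ # that dim, so CATLASS choices are only added when that stride is non-zero
+ # (or not statically known), which is what the check below expresses.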
+ if ( + PythonWrapperCodegen.statically_known_int_or_none( + inp_expanded.layout.stride[-1] + ) + != 0 + ): + CATLASS1xGemmTemplate.add_catlass_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + ) + + if use_cpp_gemm_template(layout, mat1, mat2): + CppGemmTemplate.add_choices( + choices, + layout, + [inp_expanded, mat1, mat2], + alpha=alpha, + beta=beta, + has_bias=True, + ) + + add_aten_fallback = False + if len(choices) == 0: + log.warning("No choices for GEMM, using ATen backend as fallback") + add_aten_fallback = True + + if add_aten_fallback: + choices.append( + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + ordered_kwargs_for_cpp_kernel, + alpha=alpha, + beta=beta, + ) + ) + + try: + return autotune_select_algorithm( + "addmm", choices, [inp_expanded, mat1, mat2], layout + ) + except NoValidChoicesError: + if not inductor_config.autotune_fallback_to_aten: + raise + log.warning( + "All choices for GEMM were invalid, using ATen backend as fallback" + ) + fallback_choice = aten_addmm.bind( + (inp, mat1, mat2), + layout, + ordered_kwargs_for_cpp_kernel, + alpha=alpha, + beta=beta, + ) + return fallback_choice.output_node() diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py index 30ff2092a66a4e56c56f7cb54c9f2c10333e5361..abe57f326a4754977fd92dc4a02c5ac078fd418c 100644 --- a/torch_npu/_inductor/lowering_op_list.py +++ b/torch_npu/_inductor/lowering_op_list.py @@ -78,7 +78,10 @@ GENERATE_LIST = [ aten.squeeze, aten.copy, aten.copy_, - aten.reciprocal + aten.reciprocal, + aten.mm, + aten.bmm, + aten.addmm, ] GENERATE_LIST2 = [ @@ -106,4 +109,7 @@ LOWERING_OVERLOAD_OP = [ aten.embedding, aten.cat, + aten.mm, + aten.bmm, + aten.addmm, ] diff --git a/torch_npu/_inductor/select_algorithm.py b/torch_npu/_inductor/select_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..388d745019bd35f5e9b45b669a4a1469023b8142 --- /dev/null +++ b/torch_npu/_inductor/select_algorithm.py @@ -0,0 +1,457 @@ +import builtins +import contextlib +import dataclasses +import functools +import inspect +import itertools +import json +import logging +import math +import operator +import os +import sys +import textwrap +import time +from collections import namedtuple +from concurrent.futures import as_completed, ThreadPoolExecutor +from io import StringIO +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union +from unittest.mock import patch + +import sympy +from filelock import FileLock + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.utils import counters, dynamo_timed, identity, preserve_rng_state + +from torch._inductor import config, ir +from torch._inductor.ir import ChoiceCaller +from torch._inductor.utils import restore_stdout_stderr +from torch._inductor.virtualized import V + +from torch._inductor.select_algorithm import ( + VERIFY, + PRINT_AUTOTUNE, + DEBUG, + get_mm_log_filename, + append_to_log, + get_env_num_workers, + NoValidChoicesError, + create_inputs_key, + create_precompile_key, + ExternKernelCaller, + TritonTemplateCaller, + AutotuneArgs, +) +from torch._inductor.exc import CppCompileError + + +log = logging.getLogger("torch._inductor") + + +class NPUCompileError(CppCompileError): + pass + + 
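+# The replacement __call__ / make_benchmark_fn installed below differ from the stock
+# AlgorithmSelectorCache in a few NPU-specific ways (see the bodies that follow):
+#   * a lone NPUTemplateCaller choice still goes through autotuning so that its
+#     workspace size is retrieved;
+#   * NPUCompileError raised while benchmarking a choice is logged and scored as
+#     float("inf") instead of aborting the search;
+#   * with autotune_in_subproc, ATen/extern and CATLASS (NPUTemplateCaller) choices
+#     are still benchmarked in the current process; only Triton choices are sent to
+#     the tuning subprocess pool;
+#   * compiled template libraries are released via bmreq.cleanup_run_fn() once all
+#     choices have been timed.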
+def patch_algorithm_selector(): + + def __call__( + self, + name, + choices: List[ChoiceCaller], + input_nodes, + layout, + # optional dict mapping arg indices to the functions + # generating a torch.Tensor for that input from the + # corresponding ir.Buffer. if passed for a given + # arg, the function will be called instead of + # generating a random torch.Tensor for benchmarking. + input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, + precompilation_timeout_seconds: int = 60 * 60, + return_multi_template=False, + ): + from .codegen.npu.npu_kernel import NPUTemplateCaller + + # Templates selected with input_gen_fns require specific input data to avoid IMA + # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection + # TODO(jgong5): support multi-template on CPU + if input_gen_fns is not None or layout.device.type == "cpu": + return_multi_template = False + + choices = [choice for choice in choices if choice is not None] + + if mm_file_name := get_mm_log_filename(): + M, K = input_nodes[-2].get_size()[:2] + N = input_nodes[-1].get_size()[-1] + append_to_log(mm_file_name, {"invoke": str((M, K, N))}) + + if len(choices) == 0: + backend_config = ( + "max_autotune_gemm_backends" + if name != "convolution" + else "max_autotune_conv_backends" + ) + raise NoValidChoicesError( + f"No choices to select, please consider adding ATEN into {backend_config} " + "config (defined in torch/_inductor/config.py) to allow at least one choice. " + ) + log.debug("Max autotune selects from %s choices.", str(len(choices))) + + if len(choices) == 1: + if not isinstance(choices[0], NPUTemplateCaller): + # NPUTemplateCaller still needs to go through autotuning process to retrieve workspace size. + return choices[0].output_node() + + @functools.lru_cache(None) + def make_benchmark_fn(): + return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns) + + inputs_key = create_inputs_key(input_nodes) + + def precompile(choices) -> Callable[[], None]: + def no_op(*args, **kwargs): + return + + if ( + precompilation_timeout_seconds is None + or precompilation_timeout_seconds <= 0 + ): + return no_op + + env_workers = get_env_num_workers() + num_workers = env_workers if env_workers is not None else (len(choices)) + + if num_workers <= 0: + return no_op + + if ( + sys.version_info.major == 3 + and sys.version_info.minor == 11 + and sys.version_info.micro <= 8 + ): + return no_op + + # check local and global cache before precompiling + timings = self.lookup( + choices, + name, + inputs_key, + benchmark=None, + ) + + if timings: + return no_op + + if config.search_autotune_cache and not ( + config.max_autotune or config.max_autotune_gemm + ): + return no_op + + precompile_key = create_precompile_key(name, inputs_key, choices) + if precompile_func := self.precompile_cache.get(precompile_key): + return precompile_func + + log.info( + "Multithreaded precompilation for %d choices using %d worker threads", + len(choices), + num_workers, + ) + + # In rare circumstances, because python threads inherit global state, + # thread pool executor can race and leave stdout/stderr in a state + # different than the original values. we explicitly restore the state + # here to avoid this issue. 
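+ # (Both precompile_with_captured_stdout and wait_on_futures below re-apply the
+ # captured stdout/stderr via restore_stdout_stderr for the same reason.)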
+ + initial_stdout = sys.stdout + initial_stderr = sys.stderr + + def precompile_with_captured_stdout(choice): + with restore_stdout_stderr(initial_stdout, initial_stderr): + start_time = time.time() + choice.precompile() + return time.time() - start_time + + executor = ThreadPoolExecutor(max_workers=num_workers) + + futures = {} + for c in choices: + if hasattr(c, "precompile"): + future = executor.submit(precompile_with_captured_stdout, c) + futures[future] = c + + @functools.lru_cache(None) + @restore_stdout_stderr(initial_stdout, initial_stderr) + def wait_on_futures(): + counters["inductor"]["select_algorithm_precompile"] += 1 + for future in as_completed( + futures, + timeout=precompilation_timeout_seconds, + ): + if e := future.exception(): + log.error( + "Exception %s for benchmark choice %s", e, futures[future] + ) + else: + log.info( + "Precompiling benchmark choice %s took %.02fs", + futures[future], + future.result(), + ) + + executor.shutdown(wait=True) + + self.precompile_cache[precompile_key] = wait_on_futures + + return wait_on_futures + + def autotune(choices): + with dynamo_timed(f"{name}_template_autotuning"): + return make_benchmark_fn()(choices) + + if config.autotune_in_subproc: + from torch._inductor.autotune_process import tuning_pool + + # do the optional warmup + tuning_pool.initialize() + + def do_autotuning(precompile_fn): + precompile_start_ts = time.time() + with dynamo_timed(f"{name}_template_precompiling"): + precompile_fn() + precompile_elapse = time.time() - precompile_start_ts + + autotune_start_ts = time.time() + timings = self.lookup( + choices, + name, + inputs_key, + autotune, + ) + autotune_elapse = time.time() - autotune_start_ts + + if timings and all( + not math.isfinite(timing) for timing in timings.values() + ): + raise NoValidChoicesError + + if make_benchmark_fn.cache_info().currsize: + counters["inductor"]["select_algorithm_autotune"] += 1 + + if ( + make_benchmark_fn.cache_info().currsize + or log.getEffectiveLevel() == logging.DEBUG + or config.trace.log_autotuning_results + ): + self.log_results( + name, input_nodes, timings, autotune_elapse, precompile_elapse + ) + + for feedback_fn in self.feedback_saver_fns: + feedback_fn(timings, name, input_nodes, choices) + + return timings + + precompile_fn = precompile(choices) + + if return_multi_template and (config.max_autotune or config.max_autotune_gemm): + + def get_timings(): + timings = do_autotuning(precompile_fn) + min_extern_choice = float("inf") + for choice, timing in timings.items(): + if isinstance(choice, ExternKernelCaller): + min_extern_choice = min(min_extern_choice, timing) + + timings = { + choice: time + for choice, time in timings.items() + if ( + time <= min_extern_choice + or not isinstance(choice, ExternKernelCaller) + ) + } + + return timings + + return torch._inductor.ir.TensorBox.create( + torch._inductor.ir.MultiTemplateBuffer( + layout, + input_nodes, + get_timings, + choices, + ) + ) + + timings = do_autotuning(precompile_fn) + if timings == {} or choices[0] not in timings: + return choices[0].output_node() + + selected_key = builtins.min(timings, key=timings.__getitem__) + selected_time = timings[selected_key] + selected_choice = selected_key.output_node() + log.debug("selected choice: %s", str(selected_choice)) + return selected_choice + + @classmethod + def make_benchmark_fn( + cls, + choices, + input_nodes, + layout, + input_gen_fns=None, + ): + if input_gen_fns is None: + input_gen_fns = {} + + def get_inputs( + choices: Union[List[ExternKernelCaller], 
List[TritonTemplateCaller]], + ) -> AutotuneArgs: + # de-duplicate args + unique_example_inputs = { + x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) + for i, x in enumerate(input_nodes) + } + example_inputs = list(unique_example_inputs.values()) + example_inputs_extern = [ + ( + unique_example_inputs[input_node.get_name()] + if unique_example_inputs[input_node.get_name()].is_mkldnn + else torch.as_strided( + unique_example_inputs[input_node.get_name()], + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + ) + for input_node in input_nodes + ] + out = cls.benchmark_example_value(layout) + out_extern = torch.as_strided( + out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset) + ) + expected = None + if VERIFY: + choices[0].benchmark(*example_inputs_extern, out=out_extern) + expected = out_extern.clone() + + return AutotuneArgs.from_choice_args( + example_inputs, + example_inputs_extern, + out, + out_extern, + expected, + ) + + if DEBUG: + print(f"{len(choices)} tuning requests:") + + def benchmark_choice_in_current_process( + choice: ChoiceCaller, autotune_args: AutotuneArgs + ) -> float: + is_extern = isinstance(choice, ExternKernelCaller) + benchmark_tensors = autotune_args.get_benchmark_tensors(is_extern) + inpts, output = benchmark_tensors.unpack() + output.zero_() + result = choice.benchmark(*inpts, out=output) + if VERIFY and autotune_args.expected is not None: + autotune_args.verify(**VERIFY) + if torch.npu.is_available(): + torch.npu.synchronize() # shake out any NPU errors + return result + + def benchmark_in_current_process( + choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]], + ) -> Dict[Union[ExternKernelCaller, TritonTemplateCaller], float]: + inputs = get_inputs(choices) + timings = {} + for choice in choices: + try: + timing = benchmark_choice_in_current_process(choice, inputs) + except NPUCompileError as e: + log.error( + "NPU compilation error during autotuning: \n%s. \nIgnoring this choice.", + str(e), + ) + timing = float("inf") + except NotImplementedError as e: + log.warning("Not yet implemented: %s", e) + timing = float("inf") + except RuntimeError as e: + msg = str(e) + if "invalid argument" in msg: + msg += "\n\nThis may mean this NPU is too small for max_autotune mode.\n\n" + else: + if "illegal memory access" in msg: + msg += "\n\nEither error in template or triton bug.\n" + log.error( + "Runtime error during autotuning: \n%s. \nIgnoring this choice.", + msg, + ) + timing = float("inf") + except AssertionError as e: + raise AssertionError( # noqa: B904 + f"Incorrect result from choice {choice}\n" + ) from e + except Exception as e: + try: + from triton.runtime.autotuner import OutOfResources + + if isinstance(e, OutOfResources): + log.warning(e) + timing = float("inf") + else: + raise e + except ImportError: + raise e from None + + timings[choice] = timing + + # NB: We close the DLL after all catlass choices have been benchmarked + # to avoid runtime error. 
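+ # cleanup_run_fn() on each template choice's benchmark request releases the
+ # library that was loaded for benchmarking; ExternKernelCaller (ATen/extern)
+ # choices have no such library and are skipped.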
+ for choice in choices: + if not isinstance(choice, ExternKernelCaller): + choice.bmreq.cleanup_run_fn() + + return timings + + def benchmark_in_sub_process( + choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]], + ): + from torch._inductor import autotune_process + from .codegen.npu.npu_kernel import NPUTemplateCaller + + # only benchmark triton kernel in sub process for now. + # ATen/Extern/Catlass kernel are still benchmarked in the current process. + extern = [ + c + for c in choices + if isinstance(c, ExternKernelCaller) or isinstance(c, NPUTemplateCaller) + ] + triton = [c for c in choices if c not in extern] + + timings = benchmark_in_current_process(extern) + timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type] + return timings + + benchmark = ( + benchmark_in_sub_process + if config.autotune_in_subproc + else benchmark_in_current_process + ) + + return benchmark + + from torch._inductor.select_algorithm import AlgorithmSelectorCache + + AlgorithmSelectorCache.__call__ = __call__ + AlgorithmSelectorCache.make_benchmark_fn = make_benchmark_fn diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py index 85d8a7c2780f3d9c22237fab311b65f76135acc6..2ac7e2abc1fa917ab2a97f467ed02c61df2072ac 100644 --- a/torch_npu/_inductor/utils.py +++ b/torch_npu/_inductor/utils.py @@ -1,7 +1,8 @@ -from typing import Optional +from typing import List, Optional import torch -from torch._inductor import utils, graph, scheduler +from torch._inductor import graph, scheduler, utils +from torch._inductor.utils import _use_autotune_backend, use_max_autotune import torch_npu @@ -45,4 +46,39 @@ def patch_is_same_tensor(): utils.is_same_tensor = is_same_tensor # We need to do extra-patch because of code like `from xxx import is_same_tensor` - graph.is_same_tensor = is_same_tensor \ No newline at end of file + graph.is_same_tensor = is_same_tensor + + +class classproperty: + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner): + return self.func(owner) + + +def _use_template_for_npu(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: + return layout.device.type == "npu" and layout.dtype in allowed_layout_dtypes + + +def use_catlass_template(layout, m, n, k): + from torch._inductor.virtualized import V + + from . import config as npu_config + + gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) + if gemm_size <= 0 or gemm_size < npu_config.npu.catlass_backend_min_gemm_size: + return False + + # Do not use catlass template on ROCm + if torch.version.hip: + return False + + layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] + res = ( + _use_template_for_npu(layout, layout_dtypes) + and use_max_autotune() + and _use_autotune_backend("CATLASS") + ) + + return res
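# Illustrative sketch (not from the patch): how the `classproperty` helper defined
# in torch_npu/_inductor/utils.py above behaves. The decorated function receives the
# owning class and is evaluated on every class-level attribute access, which is what
# lets `config.npu.catlass_dir` re-read TORCHINDUCTOR_NPU_CATLASS_DIR lazily without
# instantiating the config class. `_DemoConfig` and `backend_label` are made-up names.
class _DemoClassProperty:  # same two-method descriptor shape as `classproperty` above
    def __init__(self, func):
        self.func = func

    def __get__(self, instance, owner):
        return self.func(owner)


class _DemoConfig:
    @_DemoClassProperty
    def backend_label(cls):
        # recomputed on each access to _DemoConfig.backend_label
        return f"catlass-enabled:{cls.__name__}"


assert _DemoConfig.backend_label == "catlass-enabled:_DemoConfig"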