diff --git a/test/_inductor/test_aoti_user_defined_op.py b/test/_inductor/test_aoti_user_defined_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4e1c7b18cfb5e55469598258645e54eaaf4db1 --- /dev/null +++ b/test/_inductor/test_aoti_user_defined_op.py @@ -0,0 +1,204 @@ +import os +import torch +from torch._inductor.pattern_matcher import register_graph_pattern, CallFunction, Arg, PatternMatcherPass, Match + +from torch.library import custom_op, triton_op, wrap_triton +from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests + +import triton +import triton.language as tl + +from testutils import TestUtils + +import torch_npu +import torch_npu._inductor + + +@triton.jit +def triton_cross_add(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + block_start = tl.program_id(0) * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + tmp0 = tl.load(in_ptr0 + offsets, mask) + tmp1 = tl.load(in_ptr1 + offsets, mask) + tmp2 = tl.sin(tmp0) + tmp3 = tl.cos(tmp2) + tmp4 = tl.cos(tmp1) + tmp5 = tl.sin(tmp4) + tmp6 = tl.add(tmp3, tmp5) + tl.store(out_ptr + offsets, tmp6, mask) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 4}), + triton.Config({"BLOCK_SIZE": 8}), + triton.Config({"BLOCK_SIZE": 16}), + ], + key=["n_elements"], +) +@triton.jit +def triton_fused_add_one(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + block_start = tl.program_id(0) * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + tmp0 = tl.load(in_ptr0 + offsets, mask) + tmp1 = tmp0 + 1 + tl.store(out_ptr + offsets, tmp1, mask) + + +@triton.jit +def triton_fused_add_sin(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + block_start = tl.program_id(0) * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + tmp0 = tl.load(in_ptr0 + offsets, mask) + tmp1 = tl.load(in_ptr1 + offsets, mask) + tmp2 = tl.sin(tmp1) + tmp3 = tl.add(tmp0, tmp2) + tl.store(out_ptr + offsets, tmp3, mask) + + +@register_graph_pattern( + CallFunction(torch.add, Arg(), CallFunction(torch.sin, Arg())), + pass_dict=PatternMatcherPass(pass_name="test"), +) +def add_sin_replacement(match: Match, x, y): + z = torch.zeros_like(x) + n_element = x.numel() + BLOCK_SIZE = 16 + grid = (triton.cdiv(n_element, BLOCK_SIZE),) + triton_fused_add_sin[grid](x, y, z, n_element, BLOCK_SIZE=BLOCK_SIZE) + return z + + +@custom_op("my_custom::cpu_add", mutates_args={}) +def cpu_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x_cpu = x.cpu() + y_cpu = y.cpu() + z_cpu = x_cpu + y_cpu + return z_cpu.to("npu") + + +@cpu_add.register_fake +def _(x, y): + return torch.zeros_like(x) + + +@triton.jit +def triton_fused_activation_min_max(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + block_start = tl.program_id(0) * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + tmp0 = tl.load(in_ptr0 + offsets, mask) + tmp1 = tl.load(in_ptr1 + offsets, mask) + tmp2 = tl.sigmoid(tmp0) + tmp3 = tl.softmax(tmp1) + tmp4 = tl.min(tmp2, axis=0) + tmp5 = tl.max(tmp3, axis=0) + tmp6 = tmp4 + tmp5 + tl.store(out_ptr + offsets, tmp6, mask) + + +@triton_op("my_triton::activation_min_max", mutates_args={}) +def activation_min_max(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + n_element = x.numel() + BLOCK_SIZE = 16 + grid = (triton.cdiv(n_element, BLOCK_SIZE),) + z = torch.zeros([1, 
x.shape[1]], dtype=x.dtype, device=x.device) + wrap_triton(triton_fused_activation_min_max)[grid](x, y, z, n_element, BLOCK_SIZE=BLOCK_SIZE) + return z + + +class Model(torch.nn.Module): + def __init__(self, dim=32): + super().__init__() + self.fc1 = torch.nn.Linear(dim, dim, dtype=torch.float16) + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, y, test_aoti=False): + # test fused kernel with weights + x = self.fc1(x + 1) + x = self.sigmoid(x) + y = torch.abs(y) + + # test user defined triton kernel + z1 = torch.zeros_like(x) + n_element = x.numel() + BLOCK_SIZE = 8 + grid = (triton.cdiv(n_element, BLOCK_SIZE),) + triton_cross_add[grid](x, y, z1, n_element, BLOCK_SIZE=BLOCK_SIZE) + + # test user defined triton kernel with autotune + z2 = torch.zeros_like(x) + triton_fused_add_one[grid](x, z2, n_element) + + # test register_graph_pattern + z3 = torch.cos(x) + torch.sin(y) + + # test user defined custom_op, AOTInductor not support custom_op + if test_aoti: + z4 = torch.ones_like(x) + else: + z4 = torch.ops.my_custom.cpu_add(x, y) + + # test user defined triton_op, output shape is [1, dim] + z5 = torch.ops.my_triton.activation_min_max(x, y) + + # test op_plugin ascendC kernel + z6, z7 = torch_npu.npu_rms_norm(x, y) + + # sum of all result, auto broadcast z4 + z8 = z1 + z2 + z3 + z4 + z5 + z6 + + return z8 + + +class TestAotiUserDefinedOp(TestUtils): + def generate_input_tensor(self, batch_size=8, dim=32, device="npu"): + x_input = torch.arange(0, batch_size * dim, 1, device=device).reshape([batch_size, dim]) + x_input = 1.0 / x_input.to(torch.float16) + y_input = torch.arange(batch_size * dim, 0, -1, device=device).reshape([batch_size, dim]) + y_input = 1.0 / y_input.to(torch.float16) + return x_input, y_input + + + @parametrize('shape_x', [8]) + @parametrize('shape_y', [32]) + def test_compile(self, shape_x, shape_y): + with torch.no_grad(): + model = Model().to("npu") + x_input, y_input = self.generate_input_tensor(shape_x, shape_y) + eager_res = model.forward(x_input, y_input) + + model_c = torch.compile(model, backend="inductor", dynamic=False) + compile_res = model_c(x_input, y_input, False) + self.assertEqual(eager_res, compile_res, atol=1e-3, rtol=1e-3) + + + @parametrize('shape_x', [8]) + @parametrize('shape_y', [32]) + @parametrize('autotune_at_compile', [True, False]) + @parametrize('static_mode', [True, False]) + def test_aoti_export(self, shape_x, shape_y, autotune_at_compile, static_mode): + with torch.no_grad(): + model = Model().to("npu") + torch._inductor.config.triton.autotune_at_compile_time = autotune_at_compile + torch_npu._inductor.config.inductor_static_mode = static_mode + x_input, y_input = self.generate_input_tensor(shape_x, shape_y) + + exported = torch.export.export(model, (x_input, y_input, True)) + model_name = f"model_{os.getpid()}_{shape_x}_{shape_y}_{int(autotune_at_compile)}_{int(static_mode)}.pt2" + output_path = torch._inductor.aoti_compile_and_package( + exported, + package_path=os.path.join(os.getcwd(), model_name), + ) + self.assertTrue( + os.path.exists(output_path), + f"could not find target {output_path} generated by test_aoti_export", + ) + +instantiate_parametrized_tests(TestAotiUserDefinedOp) + +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py index d1c588ddd5f3a9e7959e405c3146f6109402e01b..c6b7db14295fdb3893e34e8a80c9bae2265c0711 100644 --- a/torch_npu/_inductor/__init__.py +++ b/torch_npu/_inductor/__init__.py @@ -18,11 +18,12 @@ from .codecache import 
patch_cache_base_get_system from .config import aggresive_autotune, num_vector_core, set_compile_threads from .config import log as npulog from .decomposition import _register_npu_inductor_decompositons +from .graph import patch_count_bytes from .lowering import make_reduction, npu_make_fallback from .npu_choices import should_use_persistent_reduction from .npu_device import NewNPUDeviceOpOverrides from .runtime import _load_cached_autotuning -from .utils import get_current_raw_stream +from .utils import get_current_raw_stream, patch_device_need_guard set_compile_threads() @@ -104,3 +105,5 @@ register_fa_pass() patch_cache_base_get_system() patch_triton_for_inductor() +patch_count_bytes() +patch_device_need_guard() diff --git a/torch_npu/_inductor/codegen/cpp_wrapper.py b/torch_npu/_inductor/codegen/cpp_wrapper.py index 4a8be794f975c694087efda1f854f9f9926e1075..ea97f7c5be87ae2d7f47d96bcb4e3a2e55d75db6 100644 --- a/torch_npu/_inductor/codegen/cpp_wrapper.py +++ b/torch_npu/_inductor/codegen/cpp_wrapper.py @@ -18,10 +18,11 @@ from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallAr from torch._inductor.ir import IRNode, TensorBox from torch._inductor.runtime.runtime_utils import dynamo_timed from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn -from torch._inductor.utils import DeferredLineBase +from torch._inductor.utils import DeferredLineBase, cache_on_self from torch._inductor.virtualized import V from torch._inductor.utils import _align, ALIGN_BYTES +from .triton_utils import write_npu_triton_header_once, define_user_defined_npu_triton_kernel from .. import config as npu_config from ..config import npu_block as NPU_ALIGN_BYTES @@ -218,6 +219,7 @@ class CppWrapperNpu(CppWrapperCpu): """ #include #include + #include """ ) with open( @@ -262,6 +264,7 @@ class CppWrapperNpu(CppWrapperCpu): #include #include + #include using namespace torch::aot_inductor; """ ) @@ -321,6 +324,10 @@ class CppWrapperNpu(CppWrapperCpu): if npu_config.aot_inductor.debug_kernel: self.header.splice("#include ") + @cache_on_self + def write_triton_header_once(self) -> None: + write_npu_triton_header_once(self) + def write_get_raw_stream(self, device_idx: int, graph=None) -> str: name = f"stream{device_idx}" self.writeline( @@ -396,6 +403,42 @@ class CppWrapperNpu(CppWrapperCpu): self.prefix.writeline("\n") return super().generate(is_inference) + def prepare_npu_triton_kernel_call(self, device_index, call_args, arg_types): + new_call_args = call_args + new_arg_types = arg_types + if npu_config.inductor_static_mode: + new_call_args = [] + new_arg_types = [] + # in inductor_static_mode, remove all Integer constant args from call_args + if len(call_args) != len(arg_types): + raise RuntimeError("call_args length and arg_types length should be same") + for zip_arg in zip(call_args, arg_types): + if not isinstance(zip_arg[0], sympy.Integer) and not zip_arg[1] is sympy.Integer: + new_call_args.append(zip_arg[0]) + new_arg_types.append(zip_arg[1]) + + device_index, new_call_args = super().prepare_triton_kernel_call(device_index, new_call_args) + + return device_index, new_call_args, new_arg_types + + def define_user_defined_triton_kernel( + self, + kernel, + configs, + kwargs, + restore_value_args, + reset_to_zero_args, + ): + return define_user_defined_npu_triton_kernel( + self, + kernel, + configs, + kwargs, + restore_value_args, + reset_to_zero_args, + ) + + def generate_user_defined_triton_kernel( self, kernel_name: str, @@ -763,8 +806,8 @@ class 
CppWrapperNpu(CppWrapperCpu): ) if triton: - device_index, call_args = self.prepare_triton_kernel_call( - device_index, call_args + device_index, call_args, arg_types = self.prepare_npu_triton_kernel_call( + device_index, call_args, arg_types ) _ = self.generate_load_kernel_once(kernel_name, device_index, V.graph) diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py index a729a8560cfd589a63a92c013fd7446fae9145ca..7097ab14aec91d81187ee698cdd22214a4bad52d 100644 --- a/torch_npu/_inductor/codegen/triton.py +++ b/torch_npu/_inductor/codegen/triton.py @@ -28,7 +28,7 @@ from torch._inductor.codegen.triton import ( IndexingOptions, triton_reshape, TritonCSEVariable, - OpsHandler, + OpsHandler, triton_compute_type, ) from torch._inductor.codegen.triton import ( TritonKernel, @@ -86,6 +86,22 @@ def flatten(nums): return res +def gen_npu_triton_ext_imports(): + imports = IndentedBuffer() + imports.splice( + """ + from torch._inductor.runtime import triton_helpers + from torch_npu._inductor import npu_triton_heuristics + from torch_npu._inductor import npu_triton_helpers + from torch_npu._inductor.runtime import NPUDeviceProperties + from torch_npu._inductor.npu_triton_helpers import libdevice, math as tl_math + import torch + import torch_npu + """ + ) + return imports.getvalue() + + class NPUTritonKernelOverrides(TritonKernelOverrides): @staticmethod @@ -439,21 +455,6 @@ class NPUIndexTritonKernel(TritonKernel): self.reduce_analysis = None self.load_store_indexing = None - def gen_triton_ext_imports(self): - imports = IndentedBuffer() - imports.splice( - """ - from torch._inductor.runtime import triton_helpers - from torch_npu._inductor import npu_triton_heuristics - from torch_npu._inductor import npu_triton_helpers - from torch_npu._inductor.runtime import NPUDeviceProperties - from torch_npu._inductor.npu_triton_helpers import libdevice, math as tl_math - import torch - import torch_npu - """ - ) - return imports.getvalue() - def patch_triton_hash(self): # remove this method once the original invocation is fixed import hashlib @@ -540,6 +541,17 @@ class NPUIndexTritonKernel(TritonKernel): break return dtype + @staticmethod + def inductor_meta_common(): + inductor_meta = { + "traced_graph_hash": "TRACED_GRAPH_HASH", + "traced_graph_dir": "TRACED_GRAPH_DIR", + "store_cubin": config.triton.store_cubin, + "force_disable_caches": config.force_disable_caches, + "profile_bandwidth_with_do_bench_using_profiling": config.profile_bandwidth_with_do_bench_using_profiling, + } + return inductor_meta + def create_inductor_meta(self): mutated_args = set() for mutation in self.mutations: @@ -572,11 +584,7 @@ class NPUIndexTritonKernel(TritonKernel): "numof_reduction_axis": self.numof_reduction_axis(), "split_axis_dtype": split_axis_dtype, "dual_reduction": self.numof_reduction_axis() > 1, - "traced_graph_hash": "TRACED_GRAPH_HASH", - "traced_graph_dir": "TRACED_GRAPH_DIR", - "store_cubin": config.triton.store_cubin, - "force_disable_caches": config.force_disable_caches, - "profile_bandwidth_with_do_bench_using_profiling": config.profile_bandwidth_with_do_bench_using_profiling, + **self.inductor_meta_common() } return inductor_meta @@ -671,7 +679,7 @@ class NPUIndexTritonKernel(TritonKernel): if name is None: code.splice(gen_common_triton_imports()) # Note: add extra imports for extensions - code.splice(self.gen_triton_ext_imports()) + code.splice(gen_npu_triton_ext_imports()) if config.benchmark_kernel: code.splice(self.imports_for_benchmark_kernel()) diff --git 
a/torch_npu/_inductor/codegen/triton_utils.py b/torch_npu/_inductor/codegen/triton_utils.py index 1bbaef2a2fd2f6b0c3b2f6b17dbf57ddd569d750..6e72a7fea396f931bc8292536762a1fbda4cf631 100644 --- a/torch_npu/_inductor/codegen/triton_utils.py +++ b/torch_npu/_inductor/codegen/triton_utils.py @@ -1,4 +1,20 @@ +import inspect +from typing import List, Dict, Any + +import sympy import torch +from torch._inductor import ir, config +from torch._inductor.codegen.common import KernelArgType, SizeArg, TensorArg, TMADescriptorArg +from torch._inductor.codegen.triton_utils import signature_to_meta +from torch._inductor.codegen.wrapper import PythonWrapperCodegen, \ + user_defined_triton_kernel_transitive_closure_source_code +from torch._inductor.runtime import triton_heuristics +from torch._inductor.utils import IndentedBuffer +from torch._inductor.virtualized import V +from torch.utils._triton import patch_triton_dtype_repr + +from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel, gen_npu_triton_ext_imports +from torch_npu._inductor.runtime import NPUDeviceProperties # wrapper npu 32 bytes align, get and pass unalign info to triton meta # then autotune choose tiling param and send them to bishengIR @@ -24,3 +40,179 @@ def get_aligned_numel(dtype): return 32 // byte_per_numel[dtype] else: return 1 + + +def write_npu_triton_header_once(wrapper: PythonWrapperCodegen) -> None: + import_str = f""" + import triton + import triton.language as tl + from {triton_heuristics.__name__} import ( + split_scan_grid, + grid_combo_kernels, + start_graph, + end_graph, + cooperative_reduction_grid, + ) + from torch_npu._inductor.npu_triton_heuristics import grid + import torch_npu + """ + if config.triton.autotune_at_compile_time: + wrapper.kernel_autotune_calls.splice(import_str) + wrapper.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + wrapper.imports.splice(import_str, strip=True) + wrapper.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + +def define_user_defined_npu_triton_kernel( + wrapper: PythonWrapperCodegen, + kernel, + configs, + kwargs, + restore_value_args, + reset_to_zero_args +): + patch_triton_dtype_repr() + + original_name = kernel.__name__ + + signature: List[KernelArgType] = [] + constants: Dict[str, Any] = {} + non_constant_indices = [] + equal_to_1_args: List[str] = [] + for idx, key in enumerate(kernel.arg_names): + if key not in kwargs: + continue + arg = kwargs[key] + if idx in kernel.constexprs: + constants[key] = arg + elif kwargs[key] is None: + constants[key] = None + else: + non_constant_indices.append(idx) + if isinstance(arg, ir.TMADescriptor): + signature.append( + TMADescriptorArg( + name=key, + ) + ) + elif isinstance(arg, ir.Buffer): + signature.append( + TensorArg( + name=key, + buffer=arg.get_name(), + dtype=arg.get_dtype(), + ) + ) + elif isinstance(arg, ir.ReinterpretView): + # for ReinterpretView we use the underlying + # buffer name and note the (possibly non-zero) + # offset relative to the underlying buffer + signature.append( + TensorArg( + name=key, + buffer=arg.data.get_name(), + dtype=arg.get_dtype(), + offset=arg.layout.offset, + ) + ) + else: + signature.append(SizeArg(key, arg)) + if isinstance( + arg, (int, sympy.Integer) + ) and V.graph.sizevars.statically_known_equals( + arg, 1 # type: ignore[arg-type] + ): + equal_to_1_args.append(key) + triton_meta: Dict[str, Any] = { + "signature": signature_to_meta( + signature, + 
size_dtype=None, # try to infer based on symints + indices=non_constant_indices, + argdefs=kernel.arg_names, + ), + "device": NPUDeviceProperties.create( + V.graph.get_current_device_or_throw() + ), + "constants": { + **constants, + **dict.fromkeys(equal_to_1_args, 1), + }, + # special config for NPU, specify compile target + "mix_mode": "aiv", + } + + if restore_value_args: + triton_meta["restore_value"] = tuple(restore_value_args) + + if reset_to_zero_args: + triton_meta["reset_to_zero"] = tuple(reset_to_zero_args) + + # Distinguish between different functions using function id + cache_key: List[Any] = [id(kernel.fn)] + if len(configs) > 0: + for arg in kwargs.values(): + # We need to key on non tensor arg only in autotune mode + if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): + cache_key.append(arg) + cache_key.append(str(triton_meta)) + cache_key_tuple = tuple(cache_key) + + if cache_key_tuple in wrapper.user_defined_kernel_cache: + return wrapper.user_defined_kernel_cache[cache_key_tuple] + + name = f"{original_name}_{len(wrapper.user_defined_kernel_cache)}" + # Add to the cache for the next use + wrapper.user_defined_kernel_cache[cache_key_tuple] = (name, triton_meta) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") + + from .triton import gen_common_triton_imports, TritonKernel + + compile_wrapper.splice(gen_common_triton_imports()) + compile_wrapper.splice(gen_npu_triton_ext_imports()) + + inductor_meta = { + "kernel_name": name, + **NPUIndexTritonKernel.inductor_meta_common(), + } + + configs = [ + { + "kwargs": config.kwargs, + } + for config in configs + ] + + compile_wrapper.splice( + f""" + @npu_triton_heuristics.npu_user_autotune( + configs={configs!r}, + triton_meta={triton_meta!r}, + filename=__file__, + inductor_meta={inductor_meta!r}, + custom_kernel=True, + ) + @triton.jit + """ + ) + compile_wrapper.splice( + user_defined_triton_kernel_transitive_closure_source_code(kernel) + ) + + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + _, lineno = inspect.getsourcelines(kernel.fn) + srcfile = inspect.getsourcefile(kernel.fn) + metadata = f"# Original path: {srcfile}:{lineno}" + wrapper.define_kernel( + name, + compile_wrapper.getvalue(), + metadata, + ) + return name, triton_meta diff --git a/torch_npu/_inductor/codegen/wrapper.py b/torch_npu/_inductor/codegen/wrapper.py index 1b28c273bb25a85d01972c1c3eef726940e06f83..508bc85bd97d269bd43e6d582a5aeea223f97221 100644 --- a/torch_npu/_inductor/codegen/wrapper.py +++ b/torch_npu/_inductor/codegen/wrapper.py @@ -1,11 +1,15 @@ import os import copy import hashlib +from typing import List, Any + import sympy import torch from torch._inductor import config -from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg, SubgraphPythonWrapperCodegen +from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg, SubgraphPythonWrapperCodegen, \ + user_defined_kernel_grid_fn_code +from torch._inductor.ir import IRNode from torch._inductor.runtime import triton_heuristics from torch._inductor.utils import ( cache_on_self, @@ -14,8 +18,9 @@ from torch._inductor.virtualized import V from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.utils._sympy.singleton_int import SingletonInt -from torch_npu._inductor import config as npu_config import torch_npu.npu.aclnn +from torch_npu._inductor import config as npu_config +from 
torch_npu._inductor.codegen.triton_utils import define_user_defined_npu_triton_kernel class NPUWrapperCodeGen(PythonWrapperCodegen): @@ -86,9 +91,61 @@ class NPUWrapperCodeGen(PythonWrapperCodegen): # it suffices as a type hint for the purposes of producing the correct code for this type. return SymbolicCallArg(expr, numel_expr) - # don't free anything - def make_buffer_free(self, buffer): - return "" + + def define_user_defined_triton_kernel( + self, + kernel, + configs, + kwargs, + restore_value_args, + reset_to_zero_args, + ): + return define_user_defined_npu_triton_kernel( + self, + kernel, + configs, + kwargs, + restore_value_args, + reset_to_zero_args, + ) + + + def generate_user_defined_triton_kernel( + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, + ): + grid_fn, code = user_defined_kernel_grid_fn_code( + kernel_name, configs, grid, wrapper=self + ) + if not (config.triton.autotune_at_compile_time and V.graph.cpp_wrapper): + # When codegen the autotune block only, do no insert Triton kernel + # code into the main block + # + # Must happen after free symbols are already codegened + # Emit the grid wrapper function right before the call + for line in code.split("\n"): + self.writeline(line) + + # Explicitly call the Python version of val_to_arg_str + args = [PythonWrapperCodegen.val_to_arg_str(self, v) for v in raw_args] + arg_types = [ + arg.get_dtype() if isinstance(arg, IRNode) else type(arg) + for arg in raw_args + ] + + # call self.generate_kernel_call here + self.generate_kernel_call( + kernel_name, + args, + grid_fn=grid_fn, + arg_types=arg_types, + raw_args=raw_args, + ) # don't assert def codegen_input_size_asserts(self) -> None: diff --git a/torch_npu/_inductor/graph.py b/torch_npu/_inductor/graph.py index caff8fbc60c1ba44c737658896e9c44095bec474..c39369fd0db361ce65e6bd01e5de0455a9e35744 100644 --- a/torch_npu/_inductor/graph.py +++ b/torch_npu/_inductor/graph.py @@ -111,4 +111,22 @@ def patch_codegen_with_cpp_wrapper(): # cpu return self.codegen() from torch._inductor.graph import GraphLowering - GraphLowering.codegen_with_cpp_wrapper = npu_codegen_with_cpp_wrapper \ No newline at end of file + GraphLowering.codegen_with_cpp_wrapper = npu_codegen_with_cpp_wrapper + + +def patch_count_bytes(): + def count_bytes(self): + total_bytes = 0 + node_counts = [] + node_runtimes = [] + for node in self.scheduler.nodes: + try: + num_bytes = node.get_read_write_buffers_sizes() + except AssertionError: + num_bytes = 0 + total_bytes += num_bytes + node_counts.append((node, num_bytes // 4)) + node_runtimes.append((node, node.get_estimated_runtime())) + + return total_bytes, node_counts, node_runtimes + torch._inductor.graph.GraphLowering.count_bytes = count_bytes \ No newline at end of file diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py index db9c427e60d69e95e39e9a7b83396198831d6070..6e381b71e6aa97fbb1a46eb89b8e0bef19b2ff34 100644 --- a/torch_npu/_inductor/lowering_op_list.py +++ b/torch_npu/_inductor/lowering_op_list.py @@ -1,4 +1,6 @@ import torch +from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation + from torch_npu import npu_dtype_cast, _npu_dtype_cast aten = torch.ops.aten @@ -6,6 +8,7 @@ tr_c10d = torch.ops.tr_c10d prims = torch.ops.prims GENERATE_LIST = [ + triton_kernel_wrapper_mutation, prims.iota, aten.full, aten.mul, @@ -69,11 +72,14 @@ GENERATE_LIST = [ aten.logical_not, aten.pow, aten.gelu, + aten.sin, + aten.cos, aten.tanh, 
aten.isnan, aten.bitwise_and, aten.squeeze, aten.copy, + aten.copy_, aten.reciprocal ] diff --git a/torch_npu/_inductor/npu_device.py b/torch_npu/_inductor/npu_device.py index d8245460a78dec9602a2b514491b7ef064217d2f..73b21c497a7237038e2a29e20cde52ee0e4ec902 100644 --- a/torch_npu/_inductor/npu_device.py +++ b/torch_npu/_inductor/npu_device.py @@ -20,13 +20,13 @@ class NewNPUDeviceOpOverrides(NPUDeviceOpOverrides): """ def device_guard(self, device_idx): - return f"torch.npu._DeviceGuard({device_idx})" + return f"torch.npu.utils.device({device_idx})" def cpp_aoti_device_guard(self): - raise NotImplementedError + return "AOTINpuGuard" def cpp_aoti_stream_guard(self): - return "AOTICudaStreamGuard" + return "AOTINpuStreamGuard" def kernel_driver(self): source_code = """ diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py index f954d4719c9f904d2921f9b50e4e0e0fc60d468e..98973b771e994ff750686636ec6390103dc6243e 100644 --- a/torch_npu/_inductor/npu_triton_heuristics.py +++ b/torch_npu/_inductor/npu_triton_heuristics.py @@ -42,7 +42,7 @@ from torch._inductor.runtime.triton_heuristics import ( get_first_attr, collected_calls, _dump_launch_params, - builtins + builtins, _pop_config_kwargs ) from triton.compiler import CompiledKernel @@ -1184,6 +1184,31 @@ def persistent_reduction_npu_index( ) +def npu_user_autotune( + configs, + triton_meta, + filename=None, + inductor_meta=None, + custom_kernel=False +): + if len(configs) == 0: + configs = [triton.Config({})] + else: + configs = [ + triton.Config(c.get("kwargs", {}), **_pop_config_kwargs({**c})) + for c in configs + ] + return cached_autotune( + None, + configs, + triton_meta=triton_meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + inductor_meta=inductor_meta, + custom_kernel=custom_kernel + ) + + def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): """ Compile a triton foreach kernel diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py index 79f7e165beb1fb2a9de83d03ef22ad7dcc547e68..85d8a7c2780f3d9c22237fab311b65f76135acc6 100644 --- a/torch_npu/_inductor/utils.py +++ b/torch_npu/_inductor/utils.py @@ -1,12 +1,32 @@ +from typing import Optional + import torch +from torch._inductor import utils, graph, scheduler + import torch_npu +NPU_TYPES = ["npu"] + # Not good implementation, but no other way def get_current_raw_stream(device): return torch.npu.current_stream(device).npu_stream +def is_npu(device: Optional[str]): + assert isinstance(device, str) or device is None, device + return device in NPU_TYPES + + +def patch_device_need_guard(): + def device_need_guard_npu(device: str): + assert isinstance(device, str) + return utils.is_gpu(device) or is_npu(device) + + utils.device_need_guard = device_need_guard_npu + scheduler.device_need_guard = device_need_guard_npu + + def patch_is_same_tensor(): from torch._subclasses.fake_tensor import FakeTensor @@ -23,7 +43,6 @@ def patch_is_same_tensor(): and data.storage_offset() == value.storage_offset() ) - from torch._inductor import utils, graph utils.is_same_tensor = is_same_tensor # We need to do extra-patch because of code like `from xxx import is_same_tensor` graph.is_same_tensor = is_same_tensor \ No newline at end of file diff --git a/torch_npu/csrc/inductor/aoti_runtime/utils_npu.h b/torch_npu/csrc/inductor/aoti_runtime/utils_npu.h index 06355dd0b25b4271cfb8deb94519b6d00d77ae9e..c62e101c3c0f3666cdf995f4ba72fc34fdbfe296 100644 --- a/torch_npu/csrc/inductor/aoti_runtime/utils_npu.h +++ 
b/torch_npu/csrc/inductor/aoti_runtime/utils_npu.h @@ -1,70 +1,13 @@ #pragma once -#ifdef USE_CUDA -// WARNING: Be careful when adding new includes here. This header will be used -// in model.so, and should not refer to any aten/c10 headers except the stable -// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule -// applies to other files under torch/csrc/inductor/aoti_runtime/. -#include - -#include -#include - -namespace torch::aot_inductor { - -inline void delete_cuda_guard(void* ptr) { - AOTI_TORCH_ERROR_CODE_CHECK( - aoti_torch_delete_cuda_guard(reinterpret_cast(ptr))); -} - -inline void delete_cuda_stream_guard(void* ptr) { - AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_delete_cuda_stream_guard( - reinterpret_cast(ptr))); -} - -class AOTICudaGuard { - public: - AOTICudaGuard(int32_t device_index) : guard_(nullptr, delete_cuda_guard) { - CUDAGuardHandle ptr = nullptr; - AOTI_TORCH_ERROR_CODE_CHECK( - aoti_torch_create_cuda_guard(device_index, &ptr)); - guard_.reset(ptr); - } - - void set_index(int32_t device_index) { - AOTI_TORCH_ERROR_CODE_CHECK( - aoti_torch_cuda_guard_set_index(guard_.get(), device_index)); - } - - private: - std::unique_ptr guard_; -}; - -class AOTICudaStreamGuard { - public: - AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) - : guard_(nullptr, delete_cuda_stream_guard) { - CUDAStreamGuardHandle ptr = nullptr; - AOTI_TORCH_ERROR_CODE_CHECK( - aoti_torch_create_cuda_stream_guard(stream, device_index, &ptr)); - guard_.reset(ptr); - } - - private: - std::unique_ptr guard_; -}; - -} // namespace torch::aot_inductor -#endif // USE_CUDA - #ifdef USE_NPU // WARNING: Be careful when adding new includes here. This header will be used // in model.so, and should not refer to any aten/c10 headers except the stable // C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule // applies to other files under torch/csrc/inductor/aoti_runtime/. 
#include - -#include +#include +#include namespace torch::aot_inductor { @@ -100,7 +43,7 @@ class AOTINpuStreamGuard { public: AOTINpuStreamGuard(aclrtStream stream, int32_t device_index) : guard_(nullptr, delete_npu_stream_guard) { - NpuStreamGuardHandle ptr = nullptr; + NPUStreamGuardHandle ptr = nullptr; AOTI_TORCH_ERROR_CODE_CHECK( aoti_torch_create_npu_stream_guard(stream, device_index, &ptr)); guard_.reset(ptr); diff --git a/torch_npu/csrc/inductor/aoti_torch/c/shim_npu.h b/torch_npu/csrc/inductor/aoti_torch/c/shim_npu.h new file mode 100644 index 0000000000000000000000000000000000000000..512213cb6f4bc107a984fbb64adf7e3c88b34344 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/c/shim_npu.h @@ -0,0 +1,52 @@ +#ifndef AOTI_TORCH_SHIM_NPU +#define AOTI_TORCH_SHIM_NPU + +#include + +#ifdef USE_NPU +#ifdef __cplusplus +extern "C" { +#endif + +struct NPUGuardOpaque; +using NPUGuardHandle = NPUGuardOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_npu_guard( + int32_t device_index, + NPUGuardHandle* ret_guard // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_npu_guard(NPUGuardHandle guard); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_npu_guard_set_index(NPUGuardHandle guard, int32_t device_index); + +struct NPUStreamGuardOpaque; +using NPUStreamGuardHandle = NPUStreamGuardOpaque*; + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_npu_stream_guard( + void* stream, + int32_t device_index, + NPUStreamGuardHandle* ret_guard // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_npu_stream_guard(NPUStreamGuardHandle guard); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_current_npu_stream(int32_t device_index, void** ret_stream); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_current_npu_device(int32_t* device_index); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_set_current_npu_device(const int32_t& device_index); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_current_sycl_queue(void** ret); + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // USE_NPU +#endif // AOTI_TORCH_SHIM_NPU \ No newline at end of file diff --git a/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp b/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp index 95f5a807ae435084da7d1e8d1652ec58058f1ce0..cc52595074138c9ebbcc0a92c9de9c1e16a15a7f 100644 --- a/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp +++ b/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include @@ -32,6 +32,41 @@ namespace { } } // namespace +#ifdef USE_NPU +AOTITorchError aoti_torch_create_npu_guard(int32_t device_index, NPUGuardHandle* ret_guard) +{ + // todo: implement create npu guard logic + return AOTI_TORCH_SUCCESS; +} + +AOTITorchError aoti_torch_delete_npu_guard(NPUGuardHandle guard) +{ + // todo: implement delete npu guard logic + return AOTI_TORCH_SUCCESS; +} + +AOTITorchError aoti_torch_npu_guard_set_index(NPUGuardHandle guard, int32_t device_index) +{ + // todo: implement npu guard set index logic + return AOTI_TORCH_SUCCESS; +} + +AOTITorchError aoti_torch_create_npu_stream_guard( + void* stream, + int32_t device_index, + NPUStreamGuardHandle* ret_guard) +{ + // todo: implement create npu stream guard logic + return AOTI_TORCH_SUCCESS; +} + +AOTITorchError aoti_torch_delete_npu_stream_guard(NPUStreamGuardHandle guard) +{ + // todo: implement delete npu stream guard logic + return AOTI_TORCH_SUCCESS; +} +#endif // USE_NPU + AOTITorchError aoti_torch_create_tensor_from_blob_npu( 
void* data, int64_t ndim, diff --git a/torch_npu/utils/_triton.py b/torch_npu/utils/_triton.py index ebd07266fe67ec6dd8d5df6896814ba730459cda..45e175b79818c471ba6e0e414d928f61c309b06f 100644 --- a/torch_npu/utils/_triton.py +++ b/torch_npu/utils/_triton.py @@ -4,22 +4,78 @@ from torch.utils._triton import has_triton_package import torch_npu +@functools.lru_cache(None) def has_triton() -> bool: - # here has_triton only return False, - # when has_triton() is True, config.triton.autotune_at_compile_time is True, - # AOTI is not currently supported for autotune at compile stage - return False + if not has_triton_package(): + return False + + from torch._dynamo.device_interface import get_interface_for_device + + def cuda_extra_check(device_interface): + return True + + def cpu_extra_check(device_interface): + import triton.backends + + return "cpu" in triton.backends.backends + + def _return_true(device_interface): + return True + triton_supported_devices = { + "cuda": cuda_extra_check, + "xpu": _return_true, + "cpu": cpu_extra_check, + "npu": _return_true + } + def is_device_compatible_with_triton(): + for device, extra_check in triton_supported_devices.items(): + device_interface = get_interface_for_device(device) + if device_interface.is_available() and extra_check(device_interface): + return True + return False + + return is_device_compatible_with_triton() + + +@functools.lru_cache(None) def has_triton_tma(): - # here has_triton_tma only return False, - # keep pace with no transfer_to_npu, will be fully implemented in future + if has_triton_package(): + if ( + torch_npu.npu.is_available() + and not torch.version.hip + ): + try: + from triton.tools.experimental_descriptor import ( # noqa: F401 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + + return True + except ImportError: + pass + return False +@functools.lru_cache(None) def has_triton_tma_device(): - # here has_triton_tma_device only return False, - # keep pace with no transfer_to_npu, will be fully implemented in future + if has_triton_package(): + if ( + torch_npu.npu.is_available() + and not torch.version.hip + ): + try: + from triton.language.extra.ascend.libdevice import ( # noqa: F401 + reciprocal, + log1p, + ) + + return True + except ImportError: + pass + return False