From ecbacfede2494c7c404407563fbc99905234db82 Mon Sep 17 00:00:00 2001
From: zhuceHW
Date: Tue, 2 Sep 2025 16:02:13 +0800
Subject: [PATCH] 1. AOTInductor: support dynamic triton kernels 2. fix test cases

---
 test/_inductor/test_aoti_user_defined_op.py   | 29 ++++++++-----
 test/_inductor/test_embedding.py              |  7 +++-
 test/_inductor/test_lazy_register.py          |  5 ++-
 .../.pytorch-disabled-tests.json              |  1 -
 torch_npu/_inductor/codegen/cpp_wrapper.py    | 41 ++++++++++++++-----
 torch_npu/_inductor/codegen/split_tiling.py   |  3 --
 torch_npu/_inductor/codegen/tile_generator.py | 21 +++++++++-
 torch_npu/_inductor/codegen/triton.py         | 10 ++++-
 torch_npu/_inductor/codegen/triton_utils.py   | 25 -----------
 torch_npu/_inductor/npu_triton_heuristics.py  |  4 --
 10 files changed, 87 insertions(+), 59 deletions(-)

diff --git a/test/_inductor/test_aoti_user_defined_op.py b/test/_inductor/test_aoti_user_defined_op.py
index 6b4e1c7b18..0a97bbc742 100644
--- a/test/_inductor/test_aoti_user_defined_op.py
+++ b/test/_inductor/test_aoti_user_defined_op.py
@@ -116,7 +116,7 @@ class Model(torch.nn.Module):
         self.fc1 = torch.nn.Linear(dim, dim, dtype=torch.float16)
         self.sigmoid = torch.nn.Sigmoid()

-    def forward(self, x, y, test_aoti=False):
+    def forward(self, x, y):
         # test fused kernel with weights
         x = self.fc1(x + 1)
         x = self.sigmoid(x)
@@ -136,11 +136,8 @@ class Model(torch.nn.Module):
         # test register_graph_pattern
         z3 = torch.cos(x) + torch.sin(y)

-        # test user defined custom_op, AOTInductor not support custom_op
-        if test_aoti:
-            z4 = torch.ones_like(x)
-        else:
-            z4 = torch.ops.my_custom.cpu_add(x, y)
+        # test aten
+        z4 = torch.ones_like(x)

         # test user defined triton_op, output shape is [1, dim]
         z5 = torch.ops.my_triton.activation_min_max(x, y)
@@ -172,7 +169,7 @@ class TestAotiUserDefinedOp(TestUtils):
         eager_res = model.forward(x_input, y_input)

         model_c = torch.compile(model, backend="inductor", dynamic=False)
-        compile_res = model_c(x_input, y_input, False)
+        compile_res = model_c(x_input, y_input)

         self.assertEqual(eager_res, compile_res, atol=1e-3, rtol=1e-3)

@@ -180,15 +177,27 @@ class TestAotiUserDefinedOp(TestUtils):
     @parametrize('shape_y', [32])
     @parametrize('autotune_at_compile', [True, False])
     @parametrize('static_mode', [True, False])
-    def test_aoti_export(self, shape_x, shape_y, autotune_at_compile, static_mode):
+    @parametrize('dynamic', [True, False])
+    def test_aoti_export(self, shape_x, shape_y, autotune_at_compile, static_mode, dynamic):
+        if static_mode and dynamic:
+            # static_mode and dynamic are mutually exclusive
+            return
+
         with torch.no_grad():
             model = Model().to("npu")
             torch._inductor.config.triton.autotune_at_compile_time = autotune_at_compile
             torch_npu._inductor.config.inductor_static_mode = static_mode

             x_input, y_input = self.generate_input_tensor(shape_x, shape_y)
-            exported = torch.export.export(model, (x_input, y_input, True))
-            model_name = f"model_{os.getpid()}_{shape_x}_{shape_y}_{int(autotune_at_compile)}_{int(static_mode)}.pt2"
+            model_name = f"model_{os.getpid()}_{shape_x}_{shape_y}_{int(autotune_at_compile)}_{int(static_mode)}_{int(dynamic)}.pt2"
+
+            if dynamic:
+                batch_dim = torch.export.Dim("batch", min=1, max=128)
+                exported = torch.export.export(model, (x_input, y_input),
+                                               dynamic_shapes={"x": {0: batch_dim}, "y": {0: batch_dim}})
+            else:
+                exported = torch.export.export(model, (x_input, y_input))
+
             output_path = torch._inductor.aoti_compile_and_package(
                 exported,
                 package_path=os.path.join(os.getcwd(), model_name),
diff --git a/test/_inductor/test_embedding.py b/test/_inductor/test_embedding.py
index 80b92e69b5..9ca243b7e2 100644
--- a/test/_inductor/test_embedding.py
+++ b/test/_inductor/test_embedding.py
@@ -6,9 +6,12 @@ import torch_npu


 class TestEmbeddingDense(TestUtils):
+    def __init__(self, methodName='runTest'):
+        super().__init__(methodName)
+        self.embedding = nn.Embedding(16, 128).npu()
+
     def op_calc(self, input):
-        embedding = nn.Embedding(16, 128).npu()
-        output = embedding(input)
+        output = self.embedding(input)
         return output

     # UT skip, reason: precision fail
diff --git a/test/_inductor/test_lazy_register.py b/test/_inductor/test_lazy_register.py
index 350c8f97cc..07526ae046 100644
--- a/test/_inductor/test_lazy_register.py
+++ b/test/_inductor/test_lazy_register.py
@@ -1,9 +1,12 @@
+from unittest import skipIf
+
 import torch
 from torch.testing._internal.common_utils import run_tests
 from testutils import TestUtils
 import torch_npu


+@skipIf(torch_npu.utils._dynamo.is_inductor_npu_initialized(), reason="Inductor NPU has already been initialized")
 class TestLazyRegister(TestUtils):

     def test_compile_but_not_invoked(self):
@@ -13,7 +16,7 @@ class TestLazyRegister(TestUtils):
         run = torch.compile(run)
         self.assertFalse(torch_npu.utils._dynamo.is_inductor_npu_initialized())

-    def test_disale_register_inductor_npu(self):
+    def test_disable_register_inductor_npu(self):
         torch_npu.utils._dynamo.disable_register_inductor_npu()

         def run(x, y):
diff --git a/test/unsupported_test_cases/.pytorch-disabled-tests.json b/test/unsupported_test_cases/.pytorch-disabled-tests.json
index b62207999e..999a9b090b 100644
--- a/test/unsupported_test_cases/.pytorch-disabled-tests.json
+++ b/test/unsupported_test_cases/.pytorch-disabled-tests.json
@@ -31314,7 +31314,6 @@
     "test_reduction_cases_shapes_shape0_dim_0_dtype_int32 (__main__.TestGeometric)": ["", [""]],
     "test_pointwise_cases_shape0_dtype_int32 (__main__.TestForeachAdd)": ["", [""]],
     "test_pointwise_cases_shape1_dtype_int32 (__main__.TestForeachAdd)": ["", [""]],
-    "test_pointwise_cases (__main__.TestEmbeddingDense)": ["", [""]],
     "test_fake_autocast_mT_npu_float32 (__main__.TestFakeTensorPRIVATEUSE1)": ["", [""]],
     "test_fake_autocast_scatter_npu_float32 (__main__.TestFakeTensorPRIVATEUSE1)": ["", [""]],
     "test_fake_autocast_special_i0e_npu_float32 (__main__.TestFakeTensorPRIVATEUSE1)": ["", [""]],
diff --git a/torch_npu/_inductor/codegen/cpp_wrapper.py b/torch_npu/_inductor/codegen/cpp_wrapper.py
index ea97f7c5be..fe5d688d40 100644
--- a/torch_npu/_inductor/codegen/cpp_wrapper.py
+++ b/torch_npu/_inductor/codegen/cpp_wrapper.py
@@ -14,7 +14,7 @@ from torch._inductor.codegen.common import get_device_op_overrides
 from torch._inductor.codegen.cpp_utils import cexpr, DTYPE_TO_CPP, DEVICE_TO_ATEN
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch._inductor.codegen.multi_kernel import MultiKernelCall
-from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg
+from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg, pexpr
 from torch._inductor.ir import IRNode, TensorBox
 from torch._inductor.runtime.runtime_utils import dynamo_timed
 from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn
@@ -167,15 +167,15 @@ class DeferredNpuGridLine(DeferredLineBase):
                 if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
                     grid = self.grid[i]
                     break
-            checkIfTrue(grid is not None, "grid can not be None")
-            grid_args_str = ", ".join(
-                [cexpr(V.graph.sizevars.simplify(item)) for item in grid]
-            )
+        elif isinstance(self.grid, DeferredNpuDefaultGrid):
+            grid = self.grid()
         else:
-            launch_grid = (params['grid_x'], params['grid_y'], params['grid_z'])
-            grid_args_str = ", ".join(
-                [cexpr(item) for item in launch_grid]
-            )
+            grid = self.grid
+
+        checkIfTrue(grid is not None, "grid can not be None")
+        grid_args_str = ", ".join(
+            [cexpr(V.graph.sizevars.simplify(item)) for item in grid]
+        )

         return f"\n    Grid {self.grid_var} = Grid({grid_args_str});\n"

@@ -403,6 +403,25 @@ class CppWrapperNpu(CppWrapperCpu):
         self.prefix.writeline("\n")
         return super().generate(is_inference)

+    # generate numel expr for range_tree_node
+    def generate_node_numel_expr(self, kernel_name: str, node, numel_expr):
+        expr = f"{kernel_name}_{node.name}_numel"
+
+        if (expr, V.graph) not in self.kernel_numel_expr:
+            # declare expr once in each graph (scope)
+            self.kernel_numel_expr.add((expr, V.graph))
+            self.writeline(f"int64_t {expr} = {pexpr(numel_expr)};")
+        else:
+            self.writeline(f"{expr} = {pexpr(numel_expr)};")
+        # We can get symbolic expressions here, like s0*64
+        # It is fine to have them here, but we need to handle them correctly as their own type
+        # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
+        # scalars as well.
+        # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for
+        # constant now, need type info. I agree, this needs type info, and while this is not true type info
+        # it suffices as a type hint for the purposes of producing the correct code for this type.
+        return SymbolicCallArg(expr, numel_expr)
+
     def prepare_npu_triton_kernel_call(self, device_index, call_args, arg_types):
         new_call_args = call_args
         new_arg_types = arg_types
@@ -682,7 +701,9 @@ class CppWrapperNpu(CppWrapperCpu):
                 struct_data = f'{signature2dtype[arg_signature]} {var_name} __attribute__((aligned(sizeof({signature2dtype[arg_signature]}))));'
                 arg_data = f'static_cast<{signature2dtype[arg_signature]}>({var_name})'
             else:
-                raise TypeError("Infer arg_type to cpp failed!")
+                self.writeline(f"int32_t {var_name} = {cexpr(arg)};")
+                struct_data = f'int32_t {var_name} __attribute__((aligned(4)));'
+                arg_data = f'static_cast<int32_t>({var_name})'

             nonlocal struct_def_body
             nonlocal struct_arg_body
diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py
index 782cc9f745..dbc3158f8b 100644
--- a/torch_npu/_inductor/codegen/split_tiling.py
+++ b/torch_npu/_inductor/codegen/split_tiling.py
@@ -3,12 +3,9 @@ import sympy as sympy
 from torch._inductor.codegen.simd import (EnableReduction, DisableReduction)
 from torch._inductor.codegen.triton import TritonKernel
 from torch._inductor.loop_body import MemoryUsageType
-from torch._inductor.runtime.runtime_utils import next_power_of_2
 from torch._inductor.utils import ModularIndexing, sympy_subs
 from torch._inductor.virtualized import V

-from .kernel_analysis import IndexAnalysis
-from .triton_utils import get_aligned_numel
 from ..config import num_vector_core, log


diff --git a/torch_npu/_inductor/codegen/tile_generator.py b/torch_npu/_inductor/codegen/tile_generator.py
index 6195bcfa5a..f32000ce4b 100644
--- a/torch_npu/_inductor/codegen/tile_generator.py
+++ b/torch_npu/_inductor/codegen/tile_generator.py
@@ -2,13 +2,32 @@ import copy
 import functools
 import math
 import sys
+import torch

 from torch._inductor.runtime.runtime_utils import next_power_of_2
 from torch._inductor.runtime.triton_heuristics import Config

-from .triton_utils import byte_per_numel
 from .. import config


+# wrapper npu 32 bytes align, get and pass unalign info to triton meta
+# then autotune choose tiling param and send them to bishengIR
+byte_per_numel = {
+    torch.float32: 4,  # torch.float32 or torch.float
+    torch.float64: 8,  # torch.float64 or torch.double
+    torch.float16: 2,  # torch.float16 or torch.half
+    torch.bfloat16: 2,  # torch.bfloat16
+    torch.int32: 4,  # torch.int32 or torch.int
+    torch.int64: 8,  # torch.int64 or torch.long
+    torch.int16: 2,  # torch.int16 or torch.short
+    torch.int8: 1,  # torch.int8
+    torch.uint8: 1,  # torch.uint8
+    torch.bool: 1,  # torch.bool
+    torch.complex32: 4,  # torch.complex32 (not yet available in PyTorch as of the latest stable release)
+    torch.complex64: 8,  # torch.complex64
+    torch.complex128: 16  # torch.complex128
+}
+
+
 # generate tiling configs
 class TileGenerator:
diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py
index 7097ab14ae..c623e90189 100644
--- a/torch_npu/_inductor/codegen/triton.py
+++ b/torch_npu/_inductor/codegen/triton.py
@@ -47,8 +47,8 @@ from torch._inductor.codegen.triton import (
 )
 from torch._inductor.codegen.triton_utils import config_of, signature_of, signature_to_meta
 from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
-from torch._inductor.runtime.hints import ReductionHint
 from torch._inductor.runtime.runtime_utils import next_power_of_2
+from torch._inductor.runtime.triton_heuristics import cooperative_reduction_grid
 from torch._inductor.scheduler import SchedulerNode
 from torch._inductor.utils import (
     Placeholder,
@@ -74,6 +74,7 @@ import torch_npu._inductor.config as inductor_npu_config
 from .kernel_analysis import IndexAnalysis, ReductionAnalysis
 from .npu_kernel_features import NumelList
 from ..runtime import NPUDeviceProperties
+from ..npu_triton_heuristics import grid as npu_grid_fn


 def flatten(nums):
@@ -605,7 +606,11 @@ class NPUIndexTritonKernel(TritonKernel):
             size_hints.append(numel_expr)
         return size_hints

-    # torch251 done
+    def _get_grid_fn(self):
+        if self.cooperative_reduction:
+            return cooperative_reduction_grid
+        return npu_grid_fn
+
     def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):
         for node in self.sorted_axis:
             if isinstance(node.expr, ModularIndexing):
@@ -715,6 +720,7 @@ class NPUIndexTritonKernel(TritonKernel):
         self.triton_meta = triton_meta

         self.gen_numel_args(signature, triton_meta_signature, argdefs)
+        self.triton_meta["configs"] = [config_of(signature)]

         # add in tiling args
         self.add_autotune_args(argdefs)
diff --git a/torch_npu/_inductor/codegen/triton_utils.py b/torch_npu/_inductor/codegen/triton_utils.py
index 6e72a7fea3..06de4fd4f9 100644
--- a/torch_npu/_inductor/codegen/triton_utils.py
+++ b/torch_npu/_inductor/codegen/triton_utils.py
@@ -16,31 +16,6 @@ from torch.utils._triton import patch_triton_dtype_repr
 from torch_npu._inductor.codegen.triton import NPUIndexTritonKernel, gen_npu_triton_ext_imports
 from torch_npu._inductor.runtime import NPUDeviceProperties

-# wrapper npu 32 bytes align, get and pass unalign info to triton meta
-# then autotune choose tiling param and send them to bishengIR
-byte_per_numel = {
-    torch.float32: 4,  # torch.float32 or torch.float
-    torch.float64: 8,  # torch.float64 or torch.double
-    torch.float16: 2,  # torch.float16 or torch.half
-    torch.bfloat16: 2,  # torch.bfloat16
-    torch.int32: 4,  # torch.int32 or torch.int
-    torch.int64: 8,  # torch.int64 or torch.long
-    torch.int16: 2,  # torch.int16 or torch.short
-    torch.int8: 1,  # torch.int8
-    torch.uint8: 1,  # torch.uint8
-    torch.bool: 1,  # torch.bool
-    torch.complex32: 4,  # torch.complex32 (not yet available in PyTorch as of the latest stable release)
-    torch.complex64: 8,  # torch.complex64
-    torch.complex128: 16  # torch.complex128
-}
-
-
-def get_aligned_numel(dtype):
-    if dtype in byte_per_numel:
-        return 32 // byte_per_numel[dtype]
-    else:
-        return 1
-

 def write_npu_triton_header_once(wrapper: PythonWrapperCodegen) -> None:
     import_str = f"""
diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py
index 98973b771e..3dc3d482ed 100644
--- a/torch_npu/_inductor/npu_triton_heuristics.py
+++ b/torch_npu/_inductor/npu_triton_heuristics.py
@@ -58,11 +58,7 @@ except ImportError:

 import torch_npu
 from torch_npu.utils._error_code import ErrCode, pta_error
-from .codegen.split_tiling import SplitTiling
-from .utils import get_current_raw_stream
 from .codegen.tile_generator import TileGenerator
-from .codegen.triton_utils import get_aligned_numel
-from .config import aggresive_autotune
 from .config import log
 from . import config as npu_config

--
Gitee
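
Illustrative note (not part of the patch): the sketch below shows the dynamic-shape AOTInductor flow that the new dynamic=True branch of test_aoti_export exercises. It is a minimal example under assumptions: torch_npu is installed and an NPU device is available, Model stands in for the module defined in test_aoti_user_defined_op.py, and the input shapes and the model_dynamic.pt2 path are placeholders.

import os

import torch
import torch_npu  # registers the "npu" device and the NPU inductor backend

model = Model().to("npu")  # hypothetical: the Model class from the test file
x = torch.randn(8, 32, dtype=torch.float16, device="npu")
y = torch.randn(8, 32, dtype=torch.float16, device="npu")

# Mark dim 0 of both inputs as dynamic so the exported program (and the
# AOTInductor-compiled kernels) accept any batch size in [1, 128].
batch_dim = torch.export.Dim("batch", min=1, max=128)
exported = torch.export.export(
    model, (x, y),
    dynamic_shapes={"x": {0: batch_dim}, "y": {0: batch_dim}},
)

# Compile the exported program into a self-contained .pt2 package and reload it.
package_path = torch._inductor.aoti_compile_and_package(
    exported, package_path=os.path.join(os.getcwd(), "model_dynamic.pt2")
)
runner = torch._inductor.aoti_load_package(package_path)
out = runner(x, y)  # the same package also serves other batch sizes within [1, 128]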