diff --git a/test/_inductor/test_add_sum.py b/test/_inductor/test_add_sum.py
index bafa69ddb4b9eedc2f96b23fc8c008798ac2f6d2..e173e0fe84e40bfec41930e647d577740c4801d8 100644
--- a/test/_inductor/test_add_sum.py
+++ b/test/_inductor/test_add_sum.py
@@ -14,20 +14,22 @@ class TestSumAdd(TestUtils):
     @parametrize('shape', [(9, 9, 31, 64)])
     @parametrize('dim', [3])
     @parametrize('dtype', ['float32'])
-    def test_reduction_cases_shapes(self, shape, dim, dtype):
-        a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)]
+    @parametrize('dynamic', [False, True])
+    def test_reduction_cases_shapes(self, shape, dim, dtype, dynamic):
+        a, b = [torch.randn(shape, requires_grad=False, dtype=eval('torch.' + dtype), device="npu") for _ in range(2)]
         r1 = self.foo(a, b, dim)
-        func = torch.compile(self.foo, backend="inductor", dynamic=False)
+        func = torch.compile(self.foo, backend="inductor", dynamic=dynamic)
         r = func(a, b, dim)
         self.assertEqual(r, r1, atol=1e-3, rtol=1e-3)

     @parametrize('shape', [(9, 10, 31, 63)])
     @parametrize('dim', [0, 1])
     @parametrize('dtype', ['float32'])
-    def test_reduction_cases_shapes1(self, shape, dim, dtype):
-        a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)]
+    @parametrize('dynamic', [False, True])
+    def test_reduction_cases_shapes1(self, shape, dim, dtype, dynamic):
+        a, b = [torch.randn(shape, requires_grad=False, dtype=eval('torch.' + dtype), device="npu") for _ in range(2)]
         r1 = self.foo(a, b, dim)
-        func = torch.compile(self.foo, backend="inductor", dynamic=False)
+        func = torch.compile(self.foo, backend="inductor", dynamic=dynamic)
         r = func(a, b, dim)
         self.assertEqual(r, r1, atol=1e-3, rtol=1e-3)

diff --git a/test/_inductor/test_arrange.py b/test/_inductor/test_arange.py
similarity index 97%
rename from test/_inductor/test_arrange.py
rename to test/_inductor/test_arange.py
index f80b2fb92f4fe64eea01eaba1c4016e043ff981f..b9f99713a7b7881c0cff28728b0f2e643fa4f912 100644
--- a/test/_inductor/test_arrange.py
+++ b/test/_inductor/test_arange.py
@@ -20,7 +20,7 @@ class TestArrange(TestUtils):

         std_arrange = self.op_calc(start, end, step)

-        compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False)
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
         inductor_arrange = compiled_op_calc(start, end, step)
         self.assertEqual(std_arrange, inductor_arrange)

diff --git a/test/_inductor/test_broadcast.py b/test/_inductor/test_broadcast.py
index 93e78f03519246f0570b79002990e1f23f68ab9c..1f899597493ba88d4698ad86f47405f039c9eeba 100644
--- a/test/_inductor/test_broadcast.py
+++ b/test/_inductor/test_broadcast.py
@@ -23,7 +23,7 @@ class TestBroadcast(TestUtils):
         a = self._generate_tensor(shape, dtype)
         b = self._generate_tensor(shape, dtype)

-        compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False)
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
         for dim in [3, 2, 1, 0]:
             new_shape = list(copy.deepcopy(shape))
             new_shape.insert(dim, self.broadcast_size)
diff --git a/test/_inductor/test_cat.py b/test/_inductor/test_cat.py
index 26d89caaa8dd679975d020621fdeafbb79f78c0a..cf09712b5641e450c5bd71ec1cf32e6b892ce69a 100644
--- a/test/_inductor/test_cat.py
+++ b/test/_inductor/test_cat.py
@@ -10,7 +10,7 @@ class TestCat(TestUtils):
         return torch.cat([input_element, input_element], dim)

     # case:change shapes
-    @parametrize('shape', [(8, 16, 32, 64)])
+    @parametrize('shape', [(8, 16, 32, 64), (16, 32, 64, 128)])
     @parametrize('dim', [-1])
     @parametrize('dtype', ['bfloat16'])
     def test_reduction_cases_shapes(self, shape, dim, dtype):
diff --git a/test/_inductor/test_empty.py b/test/_inductor/test_empty.py
index 06f7323a8ebdb6685868bf4e44e37bd2ed656464..0f8b3aa12ae0d10d950b822c3a13ac9471301a41 100644
--- a/test/_inductor/test_empty.py
+++ b/test/_inductor/test_empty.py
@@ -18,7 +18,7 @@ class TestEmpty(TestUtils):
         return x

     # case: change shapes
-    @parametrize('shape', [(8, 64, 128)])
+    @parametrize('shape', [(8, 64, 128), (16, 128, 256)])
     @parametrize('dim', [0])
     @parametrize('dtype', ['float32'])
     def test_cases_empty(self, shape, dim, dtype):
@@ -29,7 +29,7 @@ class TestEmpty(TestUtils):

         self.assertTrue(inductor_ret.numel() > 0)

-    @parametrize('shape', [(8, 64, 128)])
+    @parametrize('shape', [(8, 64, 128), (16, 128, 256)])
     @parametrize('dim', [0])
     @parametrize('dtype', ['float32'])
     def test_cases_empty_permuted(self, shape, dim, dtype):
diff --git a/test/_inductor/test_high_order_sum.py b/test/_inductor/test_high_order_sum.py
index a0253c261f26c868b6e69b9ebb6ec77b631253eb..21242c07a36929315594ef667524d4ff8d775df3 100644
--- a/test/_inductor/test_high_order_sum.py
+++ b/test/_inductor/test_high_order_sum.py
@@ -15,7 +15,7 @@ class TestSum(TestUtils):
     def test_high_order_sum(self):
         npu_dropout_backward_9 = torch.randn((32768, 256), device='npu', dtype=torch.float32)
         ref = self.op_sum(npu_dropout_backward_9)
-        func = torch.compile(self.op_sum, backend="inductor", dynamic=False)
+        func = torch.compile(self.op_sum, backend="inductor")
         calc = func(npu_dropout_backward_9)
         self.assertEqual(ref, calc, atol=1e-3, rtol=1e-3)

diff --git a/test/_inductor/test_reduction_brocast_add.py b/test/_inductor/test_reduction_brocast_add.py
index fb29fa1516e2b9620a518e86d746b519a01985f1..915eb8ac0d06c042713ef95ef794bed07dcd423c 100644
--- a/test/_inductor/test_reduction_brocast_add.py
+++ b/test/_inductor/test_reduction_brocast_add.py
@@ -13,13 +13,13 @@ class TestSumAdd(TestUtils):
         return y

     # case:change shapes
-    @parametrize('shape', [(9, 9, 31, 63)])
+    @parametrize('shape', [(9, 9, 31, 63), (11, 11, 63, 127)])
     @parametrize('dim', [0, 1, 2])
     @parametrize('dtype', ['float32'])
     def test_reduction_cases_shapes1(self, shape, dim, dtype):
-        a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)]
+        a, b = [torch.randn(shape, requires_grad=False, dtype=eval('torch.' + dtype), device="npu") for _ in range(2)]
         r1 = self.foo(a, b, dim, shape)
-        func = torch.compile(self.foo, backend="inductor", dynamic=False)
+        func = torch.compile(self.foo, backend="inductor")
         r = func(a, b, dim, shape)
         self.assertEqual(r, r1, atol=1e-3, rtol=1e-3)

diff --git a/test/_inductor/test_renorm.py b/test/_inductor/test_renorm.py
index 0d49727221a41ca79bea7841ac99698ded23e19f..3412c50c3f4ebaf6bdfbcece4ab489a9aaecb3fc 100644
--- a/test/_inductor/test_renorm.py
+++ b/test/_inductor/test_renorm.py
@@ -9,7 +9,7 @@ class TestRenorm(TestUtils):
         return torch.renorm(input_element, p=2, dim=dim, maxnorm=5)

     # case:change shapes
-    @parametrize('shape', [(32, 64)])
+    @parametrize('shape', [(32, 64), (64, 128)])
     @parametrize('dim', [-1])
     @parametrize('dtype', ['float32'])
     def test_reduction_cases_shapes(self, shape, dim, dtype):
diff --git a/test/_inductor/test_repeat.py b/test/_inductor/test_repeat.py
index 9df53202ac2d676e7963275148db287d77f69aaf..fca6604b90c84f9c2298557b2196a5fc0dab6cf2 100644
--- a/test/_inductor/test_repeat.py
+++ b/test/_inductor/test_repeat.py
@@ -9,7 +9,7 @@ class TestRepeat(TestUtils):
         return input_element.repeat(dim)

     # case:change shapes
-    @parametrize('shape', [(16, 128, 64)])
+    @parametrize('shape', [(16, 128, 64), (32, 256, 128)])
     @parametrize('dim', [(1, 1, 2), (1, 2, 1), (2, 1, 1)])
     @parametrize('dtype', ['float32'])
     def test_reduction_cases_shapes(self, shape, dim, dtype):
diff --git a/test/_inductor/test_split_loop.py b/test/_inductor/test_split_loop.py
index 840de0a95d2b2e9f88881088938f8af004f60f82..5276245b38f40e77a04a80fc5d5270c82643c156 100644
--- a/test/_inductor/test_split_loop.py
+++ b/test/_inductor/test_split_loop.py
@@ -17,7 +17,7 @@ class TestSplitLoop(TestUtils):

         std_ = self.op_calc(a, b)

-        compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False)
+        compiled_op_calc = torch.compile(self.op_calc, backend="inductor")
         inductor_ = compiled_op_calc(a, b)
         self.assertEqual(std_, inductor_, atol=1e-3, rtol=1e-3)

diff --git a/test/_inductor/test_trans_to_npu.py b/test/_inductor/test_trans_to_npu.py
index 8c14f7dcd0635d1d00c0384a78ca664a8eea2043..07c16665d75531ddbb0ae93569b8f14a5ab3b794 100644
--- a/test/_inductor/test_trans_to_npu.py
+++ b/test/_inductor/test_trans_to_npu.py
@@ -19,7 +19,7 @@ class TestTransToNpu(TestUtils):

         std_result = self.op_add(input_element1, input_element2)

-        compiled_op_add = torch.compile(self.op_add, backend="inductor", dynamic=False)
+        compiled_op_add = torch.compile(self.op_add, backend="inductor")
         inductor_result1 = compiled_op_add(input_element1, input_element2)
         torch.testing.assert_close(std_result, inductor_result1, atol=1e-3, rtol=1e-3)

diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py
index 782cc9f7455cd6d9f4eea63075768fdeb1af0690..fc319c7bc5e24ccc445e1f3863d84d0fd04233ee 100644
--- a/torch_npu/_inductor/codegen/split_tiling.py
+++ b/torch_npu/_inductor/codegen/split_tiling.py
@@ -1,6 +1,6 @@
 from functools import reduce
 import sympy as sympy
-from torch._inductor.codegen.simd import (EnableReduction, DisableReduction)
+from torch._inductor.codegen.simd import EnableReduction, DisableReduction
 from torch._inductor.codegen.triton import TritonKernel
 from torch._inductor.loop_body import MemoryUsageType
 from torch._inductor.runtime.runtime_utils import next_power_of_2
@@ -16,7 +16,7 @@ from ..config import num_vector_core, log
 class SplitTiling:
     def __init__(self, kernel: TritonKernel):
         self.kernel = kernel
-        self.indexing = [] # load and store indexing among all scheduler nodes
+        self.indexing = []  # load and store indexing among all scheduler nodes
         kernel.sorted_axis = [x for x in kernel.range_tree_nodes.values()]
         kernel.sorted_axis.sort(reverse=True, key=self.key)
         for i, dim in enumerate(kernel.sorted_axis):
@@ -61,9 +61,17 @@ class SplitTiling:
         else:
             return x.name

+    @staticmethod
+    def get_length_val(x):
+        length_expr = x.length
+        if not isinstance(length_expr, sympy.Integer):
+            return length_expr.subs(V.graph.sizevars.var_to_val)
+        else:
+            return length_expr
+
     @classmethod
     def total_split_numels(cls, axis_list):
-        numels = [x.length for x in axis_list]
+        numels = [cls.get_length_val(x) for x in axis_list]
         return reduce(lambda x, y: x * y, numels) if numels else 1

     # Split 原则1 :先做维度合并,再切分 。通过维度合并降维降低split和tiling轴选择策略的复杂性 。
@@ -114,7 +122,7 @@ class SplitTiling:
         for i, x in enumerate(self.kernel.split_axis):
             x.split_order = i

-    # Tiling 原则1:load / store 中索引表达式的中的低维轴都要成为tiling 轴. 
+    # Tiling 原则1:load / store 中索引表达式的中的低维轴都要成为tiling 轴.
     # Tiling 原则2:对于规约算子,规约轴要成为tiling轴。
     # Tiling 原则3: 多维规约, 只有规约轴可以被选择为tiling轴
     # Tiling 原则4: tiling轴 要覆盖 total numel 的 80%
@@ -125,26 +133,51 @@ class SplitTiling:
         # cover the biggest axis and not exceed 3 axis
         def meet_stop_condition():
-            total_numel = reduce(lambda x, y: x + y,
-                                 map(lambda x: x.length, self.kernel.sorted_axis)) if self.kernel.sorted_axis else 1
-            tiling_numel = reduce(lambda x, y: x + y,
-                                  map(lambda x: x.length, self.kernel.tiling_axis)) if self.kernel.tiling_axis else 1
+            total_numel = (
+                reduce(
+                    lambda x, y: x + y,
+                    map(lambda x: self.get_length_val(x), self.kernel.sorted_axis),
+                )
+                if self.kernel.sorted_axis
+                else 1
+            )
+            tiling_numel = (
+                reduce(
+                    lambda x, y: x + y,
+                    map(lambda x: self.get_length_val(x), self.kernel.tiling_axis),
+                )
+                if self.kernel.tiling_axis
+                else 1
+            )
             if self.kernel.numof_reduction_axis() > 1 and all(
-                    self.kernel.range_tree_nodes[var].is_tiling_axis for var in self.kernel.reduction_axis_list()):
+                self.kernel.range_tree_nodes[var].is_tiling_axis
+                for var in self.kernel.reduction_axis_list()
+            ):
                 return True
             # currently, the maximum dim that triton-ascend support is 2
             max_transpose_dims = 2
-            if (self.possible_need_permute or tiling_numel / total_numel >= 0.8) and \
-                    len(self.kernel.tiling_axis) >= min(max_transpose_dims, len(self.kernel.sorted_axis)):
+            if (
+                self.possible_need_permute or tiling_numel / total_numel >= 0.8
+            ) and len(self.kernel.tiling_axis) >= min(
+                max_transpose_dims, len(self.kernel.sorted_axis)
+            ):
                 return True
             return False

         def select_tiling(low_dim=True, reduction=True):
             for axis in reversed(self.kernel.sorted_axis):
-                if low_dim and axis.sorted_order in self.kernel.low_dims and axis not in self.kernel.tiling_axis:
+                if (
+                    low_dim
+                    and axis.sorted_order in self.kernel.low_dims
+                    and axis not in self.kernel.tiling_axis
+                ):
                     axis.is_tiling_axis = True
                     self.kernel.tiling_axis.append(axis)
-                if reduction and axis.prefix == 'r' and axis not in self.kernel.tiling_axis:
+                if (
+                    reduction
+                    and axis.prefix == "r"
+                    and axis not in self.kernel.tiling_axis
+                ):
                     axis.is_tiling_axis = True
                     self.kernel.tiling_axis.append(axis)
             if low_dim or reduction:
@@ -174,8 +207,11 @@ class SplitTiling:

     # the below logic doesn't work when there're two reduction axis, but only one need outer reduction
     def should_outer_reduce_me(self, x):
-        should_outer = self.kernel.is_higher_order_reduction(True) and SplitTiling.great_than(x.length,
-                                                                                              32768) and x.is_loop
+        should_outer = (
+            self.kernel.is_higher_order_reduction(True)
+            and SplitTiling.great_than(x.length, 32768)
+            and x.is_loop
+        )
         if should_outer:
             self.should_outer_reduce = True
             self.kernel.split_axis = x
@@ -185,8 +221,8 @@ class SplitTiling:
     def find_longest_dimension(self, check_in_tiling=False):
         longest = None
         for axis in self.kernel.sorted_axis:
-            if (longest is None or axis.length > longest.length) and \
-                    (not check_in_tiling or axis not in self.kernel.tiling_axis):
+            not_tiling = not check_in_tiling or axis not in self.kernel.tiling_axis
+            if (longest is None or axis.length > longest.length) and not_tiling:
                 longest = axis
         return longest

@@ -253,10 +289,14 @@ class SplitTiling:
         def convert(x, y):
             xnumel = x
             ynumel = y
-            if isinstance(xnumel, (sympy.Symbol, sympy.Expr)) and not isinstance(xnumel, sympy.Integer):
+            if isinstance(xnumel, (sympy.Symbol, sympy.Expr)) and not isinstance(
+                xnumel, sympy.Integer
+            ):
                 xnumel = xnumel.subs(V.graph.sizevars.var_to_val)

-            if isinstance(ynumel, (sympy.Symbol, sympy.Expr)) and not isinstance(ynumel, sympy.Integer):
+            if isinstance(ynumel, (sympy.Symbol, sympy.Expr)) and not isinstance(
+                ynumel, sympy.Integer
+            ):
                 ynumel = ynumel.subs(V.graph.sizevars.var_to_val)

             if isinstance(xnumel, sympy.Integer) and isinstance(ynumel, int):
diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py
index 7097ab14aec91d81187ee698cdd22214a4bad52d..201136d3f6f072ac01597b46b31c50f7df24dbc8 100644
--- a/torch_npu/_inductor/codegen/triton.py
+++ b/torch_npu/_inductor/codegen/triton.py
@@ -472,14 +472,6 @@ class NPUIndexTritonKernel(TritonKernel):
         pass

     def initialize_range_tree(self, pid_cache):
-        self.total_numels = 0
-        for k, x in self.numels.items():
-            if not isinstance(x, sympy.Integer):
-                x = x.subs(V.graph.sizevars.var_to_val)
-                self.numels[k] = x
-            if x > 1:
-                self.total_numels += 1
-
         no_r_dim = not self.inside_reduction or self.numels["r"] == 1
         prefixes = "wvtzyxr"
         active_prefixes = prefixes[-len(self.numels):]
@@ -850,7 +842,7 @@ class NPUIndexTritonKernel(TritonKernel):

         def codegen_range(index):
             def is_1d_reduction():
-                return self.numels["r"] > 1 and len(self.numels) == 1
+                return V.graph.sizevars.statically_known_gt(self.numels["r"], 1) and len(self.numels) == 1

             def loop_body(index, indexing_code, is_last_axis, do_indent=True):
                 if do_indent:
@@ -1573,7 +1565,7 @@ class NPUIndexTritonKernel(TritonKernel):
         def add_range(i, expr):
             expr = sv.simplify(expr)
             if not sv.statically_known_multiple_of(remaining[i], expr):
-                raise CantSplit()
+                raise CantSplit
             # guard on the last item out
             remaining[i] = FloorDiv(remaining[i], expr)
             new_ranges[i].append(expr)
@@ -1602,17 +1594,21 @@ class NPUIndexTritonKernel(TritonKernel):
            # Two checks:
            # 1. remaining sizes to be merged
            # 2. remained_size is already divided to 1
-            while (group < len(remaining) and remaining[group] > 1) and (remained_size > 1):
+            while (
+                group < len(remaining)
+                and sv.statically_known_gt(remaining[group], 1)
+                and sv.statically_known_gt(remained_size, 1)
+            ):
                 group_size = remaining[group]
                 # size should be divisible by group_size
                 if not sv.statically_known_multiple_of(remained_size, group_size):
-                    raise CantSplit()
+                    raise CantSplit
                 index_list.append(add_range(group, group_size))
                 remained_size = FloorDiv(remained_size, group_size)
                 stride_list.append(remained_size)
                 group = group + 1
             if remained_size != 1:
-                raise CantSplit()
+                raise CantSplit
             return_getters.append(make_combined(stride_list, index_list))

         return_getters_groups = []
@@ -1622,22 +1618,21 @@ class NPUIndexTritonKernel(TritonKernel):
             return_getters = []
             for size in length_group:
                 if sv.statically_known_equals(size, 1):  # type: ignore[arg-type]
-                    return_getters.append(lambda _: sympy.Integer(0))
+                    return_getters.append(lambda _: sympy.S.Zero)
                     continue

                 while (
-                        current_group < len(remaining)
-                        and size_hints(remaining[current_group]) == 1
+                    current_group < len(remaining)
+                    and sv.statically_known_equals(remaining[current_group], 1)
                 ):
                     # scroll to next group with remaining elements
                     current_group += 1

-                size_hint = sv.size_hint(size)
-                if size_hint > size_hints(remaining[current_group]):
+                if sv.statically_known_gt(size, remaining[current_group]):
                     # add multiple ranges (two or more) to the list, as well as the getter funcs
-                    add_multiple_range(size_hint, return_getters)
+                    add_multiple_range(size, return_getters)
                 else:
                     return_getters.append(
-                        operator.itemgetter(add_range(current_group, size_hint))
+                        operator.itemgetter(add_range(current_group, size))
                     )
             return_getters_groups.append(return_getters)
diff --git a/torch_npu/_inductor/codegen/wrapper.py b/torch_npu/_inductor/codegen/wrapper.py
index 508bc85bd97d269bd43e6d582a5aeea223f97221..86a63ee301554075cc4eee22e0a85b718e5fc065 100644
--- a/torch_npu/_inductor/codegen/wrapper.py
+++ b/torch_npu/_inductor/codegen/wrapper.py
@@ -7,8 +7,13 @@ import sympy

 import torch
 from torch._inductor import config
-from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg, SubgraphPythonWrapperCodegen, \
-    user_defined_kernel_grid_fn_code
+from torch._inductor.codegen.wrapper import (
+    PythonWrapperCodegen,
+    SymbolicCallArg,
+    SubgraphPythonWrapperCodegen,
+    user_defined_kernel_grid_fn_code,
+    pexpr,
+)
 from torch._inductor.ir import IRNode
 from torch._inductor.runtime import triton_heuristics
 from torch._inductor.utils import (
@@ -74,14 +79,7 @@ class NPUWrapperCodeGen(PythonWrapperCodegen):
     # generate numel expr for range_tree_node
     def generate_node_numel_expr(self, kernel_name: str, node, numel_expr):
         expr = f"{kernel_name}_{node.name}_numel"
-        if (expr, V.graph) not in self.kernel_numel_expr:
-            # declare expr once in each graph (scope)
-            self.kernel_numel_expr.add((expr, V.graph))
-            self.writeline(
-                f"{self.declare}{expr} = {self.expr_printer(numel_expr)}{self.ending}"
-            )
-        else:
-            self.writeline(f"{expr} = {self.expr_printer(numel_expr)}{self.ending}")
+        self.writeline(f"{expr} = {pexpr(numel_expr)}")
         # We can get symbolic expressions here, like s0*64
         # It is fine to have them here, but we need to handle them correctly as their own type
         # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy*
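
Reviewer note (not part of the patch): the recurring change in the split_tiling.py and triton.py hunks is to stop treating axis lengths and numels as concrete Python ints under dynamic shapes. Symbolic lengths are resolved through V.graph.sizevars.var_to_val (the new SplitTiling.get_length_val helper), and comparisons go through the sizevars statically_known_* helpers instead of plain > / ==. A minimal standalone sketch of the substitution half, using plain sympy and a made-up var_to_val mapping rather than the real inductor objects:

# Hypothetical stand-in for V.graph.sizevars.var_to_val: symbolic dim sizes
# mapped to the example values recorded at compile time.
import sympy

var_to_val = {sympy.Symbol("s0"): sympy.Integer(9), sympy.Symbol("s1"): sympy.Integer(31)}


def get_length_val(length_expr):
    # Mirrors the idea of the patched helper: concrete Integers pass through,
    # symbolic expressions (e.g. s0*s1 under dynamic=True) are evaluated
    # against the recorded hints so products and ratios stay plain numbers.
    if not isinstance(length_expr, sympy.Integer):
        return length_expr.subs(var_to_val)
    return length_expr


print(get_length_val(sympy.Integer(64)))                        # 64
print(get_length_val(sympy.Symbol("s0") * sympy.Symbol("s1")))  # 279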