From ef046b62e90330dd6f3a82c8f0afd9e6fc4a9a83 Mon Sep 17 00:00:00 2001 From: Xuan Peng Date: Sun, 17 Aug 2025 16:10:18 +0800 Subject: [PATCH 1/6] fix triton codegen for dynamic shape --- torch_npu/_inductor/codegen/split_tiling.py | 79 ++++++++++++++++----- torch_npu/_inductor/codegen/triton.py | 35 ++++----- torch_npu/_inductor/codegen/wrapper.py | 16 ++--- 3 files changed, 82 insertions(+), 48 deletions(-) diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py index 782cc9f745..4b86845a28 100644 --- a/torch_npu/_inductor/codegen/split_tiling.py +++ b/torch_npu/_inductor/codegen/split_tiling.py @@ -1,6 +1,6 @@ from functools import reduce import sympy as sympy -from torch._inductor.codegen.simd import (EnableReduction, DisableReduction) +from torch._inductor.codegen.simd import EnableReduction, DisableReduction from torch._inductor.codegen.triton import TritonKernel from torch._inductor.loop_body import MemoryUsageType from torch._inductor.runtime.runtime_utils import next_power_of_2 @@ -16,7 +16,7 @@ from ..config import num_vector_core, log class SplitTiling: def __init__(self, kernel: TritonKernel): self.kernel = kernel - self.indexing = [] # load and store indexing among all scheduler nodes + self.indexing = [] # load and store indexing among all scheduler nodes kernel.sorted_axis = [x for x in kernel.range_tree_nodes.values()] kernel.sorted_axis.sort(reverse=True, key=self.key) for i, dim in enumerate(kernel.sorted_axis): @@ -61,9 +61,17 @@ class SplitTiling: else: return x.name + @staticmethod + def get_length_val(x): + length_expr = x.length + if not isinstance(length_expr, sympy.Integer): + return length_expr.subs(V.graph.sizevars.var_to_val) + else: + return length_expr + @classmethod def total_split_numels(cls, axis_list): - numels = [x.length for x in axis_list] + numels = [cls.get_length_val(x) for x in axis_list] return reduce(lambda x, y: x * y, numels) if numels else 1 # Split 原则1 :先做维度合并,再切分 。通过维度合并降维降低split和tiling轴选择策略的复杂性 。 @@ -114,7 +122,7 @@ class SplitTiling: for i, x in enumerate(self.kernel.split_axis): x.split_order = i - # Tiling 原则1:load / store 中索引表达式的中的低维轴都要成为tiling 轴. + # Tiling 原则1:load / store 中索引表达式的中的低维轴都要成为tiling 轴. 
# Tiling 原则2:对于规约算子,规约轴要成为tiling轴。 # Tiling 原则3: 多维规约, 只有规约轴可以被选择为tiling轴 # Tiling 原则4: tiling轴 要覆盖 total numel 的 80% @@ -125,26 +133,51 @@ class SplitTiling: # cover the biggest axis and not exceed 3 axis def meet_stop_condition(): - total_numel = reduce(lambda x, y: x + y, - map(lambda x: x.length, self.kernel.sorted_axis)) if self.kernel.sorted_axis else 1 - tiling_numel = reduce(lambda x, y: x + y, - map(lambda x: x.length, self.kernel.tiling_axis)) if self.kernel.tiling_axis else 1 + total_numel = ( + reduce( + lambda x, y: x + y, + map(lambda x: self.get_length_val(x), self.kernel.sorted_axis), + ) + if self.kernel.sorted_axis + else 1 + ) + tiling_numel = ( + reduce( + lambda x, y: x + y, + map(lambda x: self.get_length_val(x), self.kernel.tiling_axis), + ) + if self.kernel.tiling_axis + else 1 + ) if self.kernel.numof_reduction_axis() > 1 and all( - self.kernel.range_tree_nodes[var].is_tiling_axis for var in self.kernel.reduction_axis_list()): + self.kernel.range_tree_nodes[var].is_tiling_axis + for var in self.kernel.reduction_axis_list() + ): return True # currently, the maximum dim that triton-ascend support is 2 max_transpose_dims = 2 - if (self.possible_need_permute or tiling_numel / total_numel >= 0.8) and \ - len(self.kernel.tiling_axis) >= min(max_transpose_dims, len(self.kernel.sorted_axis)): + if ( + self.possible_need_permute or tiling_numel / total_numel >= 0.8 + ) and len(self.kernel.tiling_axis) >= min( + max_transpose_dims, len(self.kernel.sorted_axis) + ): return True return False def select_tiling(low_dim=True, reduction=True): for axis in reversed(self.kernel.sorted_axis): - if low_dim and axis.sorted_order in self.kernel.low_dims and axis not in self.kernel.tiling_axis: + if ( + low_dim + and axis.sorted_order in self.kernel.low_dims + and axis not in self.kernel.tiling_axis + ): axis.is_tiling_axis = True self.kernel.tiling_axis.append(axis) - if reduction and axis.prefix == 'r' and axis not in self.kernel.tiling_axis: + if ( + reduction + and axis.prefix == "r" + and axis not in self.kernel.tiling_axis + ): axis.is_tiling_axis = True self.kernel.tiling_axis.append(axis) if low_dim or reduction: @@ -174,8 +207,11 @@ class SplitTiling: # the below logic doesn't work when there're two reduction axis, but only one need outer reduction def should_outer_reduce_me(self, x): - should_outer = self.kernel.is_higher_order_reduction(True) and SplitTiling.great_than(x.length, - 32768) and x.is_loop + should_outer = ( + self.kernel.is_higher_order_reduction(True) + and SplitTiling.great_than(x.length, 32768) + and x.is_loop + ) if should_outer: self.should_outer_reduce = True self.kernel.split_axis = x @@ -185,8 +221,9 @@ class SplitTiling: def find_longest_dimension(self, check_in_tiling=False): longest = None for axis in self.kernel.sorted_axis: - if (longest is None or axis.length > longest.length) and \ - (not check_in_tiling or axis not in self.kernel.tiling_axis): + if (longest is None or axis.length > longest.length) and ( + not check_in_tiling or axis not in self.kernel.tiling_axis + ): longest = axis return longest @@ -253,10 +290,14 @@ class SplitTiling: def convert(x, y): xnumel = x ynumel = y - if isinstance(xnumel, (sympy.Symbol, sympy.Expr)) and not isinstance(xnumel, sympy.Integer): + if isinstance(xnumel, (sympy.Symbol, sympy.Expr)) and not isinstance( + xnumel, sympy.Integer + ): xnumel = xnumel.subs(V.graph.sizevars.var_to_val) - if isinstance(ynumel, (sympy.Symbol, sympy.Expr)) and not isinstance(ynumel, sympy.Integer): + if isinstance(ynumel, 
(sympy.Symbol, sympy.Expr)) and not isinstance( + ynumel, sympy.Integer + ): ynumel = ynumel.subs(V.graph.sizevars.var_to_val) if isinstance(xnumel, sympy.Integer) and isinstance(ynumel, int): diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py index a729a8560c..63da1d2849 100644 --- a/torch_npu/_inductor/codegen/triton.py +++ b/torch_npu/_inductor/codegen/triton.py @@ -471,14 +471,6 @@ class NPUIndexTritonKernel(TritonKernel): pass def initialize_range_tree(self, pid_cache): - self.total_numels = 0 - for k, x in self.numels.items(): - if not isinstance(x, sympy.Integer): - x = x.subs(V.graph.sizevars.var_to_val) - self.numels[k] = x - if x > 1: - self.total_numels += 1 - no_r_dim = not self.inside_reduction or self.numels["r"] == 1 prefixes = "wvtzyxr" active_prefixes = prefixes[-len(self.numels):] @@ -842,7 +834,7 @@ class NPUIndexTritonKernel(TritonKernel): def codegen_range(index): def is_1d_reduction(): - return self.numels["r"] > 1 and len(self.numels) == 1 + return V.graph.sizevars.statically_known_gt(self.numels["r"], 1) and len(self.numels) == 1 def loop_body(index, indexing_code, is_last_axis, do_indent=True): if do_indent: @@ -1565,7 +1557,7 @@ class NPUIndexTritonKernel(TritonKernel): def add_range(i, expr): expr = sv.simplify(expr) if not sv.statically_known_multiple_of(remaining[i], expr): - raise CantSplit() + raise CantSplit # guard on the last item out remaining[i] = FloorDiv(remaining[i], expr) new_ranges[i].append(expr) @@ -1594,17 +1586,21 @@ class NPUIndexTritonKernel(TritonKernel): # Two checks: # 1. remaining sizes to be merged # 2. remained_size is already divided to 1 - while (group < len(remaining) and remaining[group] > 1) and (remained_size > 1): + while ( + group < len(remaining) + and sv.statically_known_gt(remaining[group], 1) + and sv.statically_known_gt(remained_size, 1) + ): group_size = remaining[group] # size should be divisible by group_size if not sv.statically_known_multiple_of(remained_size, group_size): - raise CantSplit() + raise CantSplit index_list.append(add_range(group, group_size)) remained_size = FloorDiv(remained_size, group_size) stride_list.append(remained_size) group = group + 1 if remained_size != 1: - raise CantSplit() + raise CantSplit return_getters.append(make_combined(stride_list, index_list)) return_getters_groups = [] @@ -1614,22 +1610,21 @@ class NPUIndexTritonKernel(TritonKernel): return_getters = [] for size in length_group: if sv.statically_known_equals(size, 1): # type: ignore[arg-type] - return_getters.append(lambda _: sympy.Integer(0)) + return_getters.append(lambda _: sympy.S.Zero) continue while ( - current_group < len(remaining) - and size_hints(remaining[current_group]) == 1 + current_group < len(remaining) + and sv.statically_known_equals(remaining[current_group], 1) ): # scroll to next group with remaining elements current_group += 1 - size_hint = sv.size_hint(size) - if size_hint > size_hints(remaining[current_group]): + if sv.statically_known_gt(size, remaining[current_group]): # add multiple ranges (two or more) to the list, as well as the getter funcs - add_multiple_range(size_hint, return_getters) + add_multiple_range(size, return_getters) else: return_getters.append( - operator.itemgetter(add_range(current_group, size_hint)) + operator.itemgetter(add_range(current_group, size)) ) return_getters_groups.append(return_getters) diff --git a/torch_npu/_inductor/codegen/wrapper.py b/torch_npu/_inductor/codegen/wrapper.py index 1b28c273bb..b1f80cac0a 100644 --- 
a/torch_npu/_inductor/codegen/wrapper.py +++ b/torch_npu/_inductor/codegen/wrapper.py @@ -5,7 +5,12 @@ import sympy import torch from torch._inductor import config -from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg, SubgraphPythonWrapperCodegen +from torch._inductor.codegen.wrapper import ( + PythonWrapperCodegen, + SymbolicCallArg, + SubgraphPythonWrapperCodegen, + pexpr, +) from torch._inductor.runtime import triton_heuristics from torch._inductor.utils import ( cache_on_self, @@ -69,14 +74,7 @@ class NPUWrapperCodeGen(PythonWrapperCodegen): # generate numel expr for range_tree_node def generate_node_numel_expr(self, kernel_name: str, node, numel_expr): expr = f"{kernel_name}_{node.name}_numel" - if (expr, V.graph) not in self.kernel_numel_expr: - # declare expr once in each graph (scope) - self.kernel_numel_expr.add((expr, V.graph)) - self.writeline( - f"{self.declare}{expr} = {self.expr_printer(numel_expr)}{self.ending}" - ) - else: - self.writeline(f"{expr} = {self.expr_printer(numel_expr)}{self.ending}") + self.writeline(f"{expr} = {pexpr(numel_expr)}") # We can get symbolic expressions here, like s0*64 # It is fine to have them here, but we need to handle them correctly as their own type # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* -- Gitee From f83b231e1f42027ed6302374a592e95a79d53dfb Mon Sep 17 00:00:00 2001 From: Xuan Peng Date: Wed, 27 Aug 2025 22:21:22 +0800 Subject: [PATCH 2/6] modify test for dynamic shape --- test/_inductor/test_abs.py | 2 +- test/_inductor/test_add_sum.py | 14 ++++++++------ test/_inductor/{test_arrange.py => test_arange.py} | 2 +- test/_inductor/test_broadcast.py | 2 +- test/_inductor/test_cat.py | 2 +- test/_inductor/test_empty.py | 4 ++-- test/_inductor/test_high_order_sum.py | 2 +- test/_inductor/test_reduction_brocast_add.py | 6 +++--- test/_inductor/test_renorm.py | 2 +- test/_inductor/test_repeat.py | 2 +- test/_inductor/test_split_loop.py | 2 +- test/_inductor/test_trans_to_npu.py | 2 +- 12 files changed, 22 insertions(+), 20 deletions(-) rename test/_inductor/{test_arrange.py => test_arange.py} (97%) diff --git a/test/_inductor/test_abs.py b/test/_inductor/test_abs.py index 62482afd7a..ed34ffb2c7 100644 --- a/test/_inductor/test_abs.py +++ b/test/_inductor/test_abs.py @@ -9,7 +9,7 @@ class TestAbs(TestUtils): result = torch.abs(first_element) return result - @parametrize('shape', [(1024, 32), (256, 8)]) + @parametrize('shape', [(1024, 32), (256, 8), (512, 64)]) @parametrize('dtype', ['float16', 'float32', 'bfloat16']) def test_pointwise_cases(self, shape, dtype): first_element = self._generate_tensor(shape, dtype) diff --git a/test/_inductor/test_add_sum.py b/test/_inductor/test_add_sum.py index bafa69ddb4..e173e0fe84 100644 --- a/test/_inductor/test_add_sum.py +++ b/test/_inductor/test_add_sum.py @@ -14,20 +14,22 @@ class TestSumAdd(TestUtils): @parametrize('shape', [(9, 9, 31, 64)]) @parametrize('dim', [3]) @parametrize('dtype', ['float32']) - def test_reduction_cases_shapes(self, shape, dim, dtype): - a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + @parametrize('dynamic', [False, True]) + def test_reduction_cases_shapes(self, shape, dim, dtype, dynamic): + a, b = [torch.randn(shape, requires_grad=False, dtype=eval('torch.' 
+ dtype), device="npu") for _ in range(2)] r1 = self.foo(a, b, dim) - func = torch.compile(self.foo, backend="inductor", dynamic=False) + func = torch.compile(self.foo, backend="inductor", dynamic=dynamic) r = func(a, b, dim) self.assertEqual(r, r1, atol=1e-3, rtol=1e-3) @parametrize('shape', [(9, 10, 31, 63)]) @parametrize('dim', [0, 1]) @parametrize('dtype', ['float32']) - def test_reduction_cases_shapes1(self, shape, dim, dtype): - a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + @parametrize('dynamic', [False, True]) + def test_reduction_cases_shapes1(self, shape, dim, dtype, dynamic): + a, b = [torch.randn(shape, requires_grad=False, dtype=eval('torch.' + dtype), device="npu") for _ in range(2)] r1 = self.foo(a, b, dim) - func = torch.compile(self.foo, backend="inductor", dynamic=False) + func = torch.compile(self.foo, backend="inductor", dynamic=dynamic) r = func(a, b, dim) self.assertEqual(r, r1, atol=1e-3, rtol=1e-3) diff --git a/test/_inductor/test_arrange.py b/test/_inductor/test_arange.py similarity index 97% rename from test/_inductor/test_arrange.py rename to test/_inductor/test_arange.py index f80b2fb92f..b9f99713a7 100644 --- a/test/_inductor/test_arrange.py +++ b/test/_inductor/test_arange.py @@ -20,7 +20,7 @@ class TestArrange(TestUtils): std_arrange = self.op_calc(start, end, step) - compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") inductor_arrange = compiled_op_calc(start, end, step) self.assertEqual(std_arrange, inductor_arrange) diff --git a/test/_inductor/test_broadcast.py b/test/_inductor/test_broadcast.py index 93e78f0351..1f89959749 100644 --- a/test/_inductor/test_broadcast.py +++ b/test/_inductor/test_broadcast.py @@ -23,7 +23,7 @@ class TestBroadcast(TestUtils): a = self._generate_tensor(shape, dtype) b = self._generate_tensor(shape, dtype) - compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") for dim in [3, 2, 1, 0]: new_shape = list(copy.deepcopy(shape)) new_shape.insert(dim, self.broadcast_size) diff --git a/test/_inductor/test_cat.py b/test/_inductor/test_cat.py index 26d89caaa8..cf09712b56 100644 --- a/test/_inductor/test_cat.py +++ b/test/_inductor/test_cat.py @@ -10,7 +10,7 @@ class TestCat(TestUtils): return torch.cat([input_element, input_element], dim) # case:change shapes - @parametrize('shape', [(8, 16, 32, 64)]) + @parametrize('shape', [(8, 16, 32, 64), (16, 32, 64, 128)]) @parametrize('dim', [-1]) @parametrize('dtype', ['bfloat16']) def test_reduction_cases_shapes(self, shape, dim, dtype): diff --git a/test/_inductor/test_empty.py b/test/_inductor/test_empty.py index 06f7323a8e..0f8b3aa12a 100644 --- a/test/_inductor/test_empty.py +++ b/test/_inductor/test_empty.py @@ -18,7 +18,7 @@ class TestEmpty(TestUtils): return x # case: change shapes - @parametrize('shape', [(8, 64, 128)]) + @parametrize('shape', [(8, 64, 128), (16, 128, 256)]) @parametrize('dim', [0]) @parametrize('dtype', ['float32']) def test_cases_empty(self, shape, dim, dtype): @@ -29,7 +29,7 @@ class TestEmpty(TestUtils): self.assertTrue(inductor_ret.numel() > 0) - @parametrize('shape', [(8, 64, 128)]) + @parametrize('shape', [(8, 64, 128), (16, 128, 256)]) @parametrize('dim', [0]) @parametrize('dtype', ['float32']) def test_cases_empty_permuted(self, shape, dim, dtype): diff --git a/test/_inductor/test_high_order_sum.py 
b/test/_inductor/test_high_order_sum.py index a0253c261f..21242c07a3 100644 --- a/test/_inductor/test_high_order_sum.py +++ b/test/_inductor/test_high_order_sum.py @@ -15,7 +15,7 @@ class TestSum(TestUtils): def test_high_order_sum(self): npu_dropout_backward_9 = torch.randn((32768, 256), device='npu', dtype=torch.float32) ref = self.op_sum(npu_dropout_backward_9) - func = torch.compile(self.op_sum, backend="inductor", dynamic=False) + func = torch.compile(self.op_sum, backend="inductor") calc = func(npu_dropout_backward_9) self.assertEqual(ref, calc, atol=1e-3, rtol=1e-3) diff --git a/test/_inductor/test_reduction_brocast_add.py b/test/_inductor/test_reduction_brocast_add.py index fb29fa1516..63e5b67b18 100644 --- a/test/_inductor/test_reduction_brocast_add.py +++ b/test/_inductor/test_reduction_brocast_add.py @@ -13,13 +13,13 @@ class TestSumAdd(TestUtils): return y # case:change shapes - @parametrize('shape', [(9, 9, 31, 63)]) + @parametrize('shape', [(9, 9, 31, 63), (11, 11, 63, 127)]) @parametrize('dim', [0, 1, 2]) @parametrize('dtype', ['float32']) def test_reduction_cases_shapes1(self, shape, dim, dtype): - a, b = [torch.randn(shape, requires_grad=False, dtype=torch.float32, device="npu") for _ in range(2)] + a, b = [torch.randn(shape, requires_grad=False, dtype=eval('torch' + dtype), device="npu") for _ in range(2)] r1 = self.foo(a, b, dim, shape) - func = torch.compile(self.foo, backend="inductor", dynamic=False) + func = torch.compile(self.foo, backend="inductor") r = func(a, b, dim, shape) self.assertEqual(r, r1, atol=1e-3, rtol=1e-3) diff --git a/test/_inductor/test_renorm.py b/test/_inductor/test_renorm.py index 0d49727221..3412c50c3f 100644 --- a/test/_inductor/test_renorm.py +++ b/test/_inductor/test_renorm.py @@ -9,7 +9,7 @@ class TestRenorm(TestUtils): return torch.renorm(input_element, p=2, dim=dim, maxnorm=5) # case:change shapes - @parametrize('shape', [(32, 64)]) + @parametrize('shape', [(32, 64), (64, 128)]) @parametrize('dim', [-1]) @parametrize('dtype', ['float32']) def test_reduction_cases_shapes(self, shape, dim, dtype): diff --git a/test/_inductor/test_repeat.py b/test/_inductor/test_repeat.py index 9df53202ac..fca6604b90 100644 --- a/test/_inductor/test_repeat.py +++ b/test/_inductor/test_repeat.py @@ -9,7 +9,7 @@ class TestRepeat(TestUtils): return input_element.repeat(dim) # case:change shapes - @parametrize('shape', [(16, 128, 64)]) + @parametrize('shape', [(16, 128, 64), (32, 256, 128)]) @parametrize('dim', [(1, 1, 2), (1, 2, 1), (2, 1, 1)]) @parametrize('dtype', ['float32']) def test_reduction_cases_shapes(self, shape, dim, dtype): diff --git a/test/_inductor/test_split_loop.py b/test/_inductor/test_split_loop.py index 840de0a95d..5276245b38 100644 --- a/test/_inductor/test_split_loop.py +++ b/test/_inductor/test_split_loop.py @@ -17,7 +17,7 @@ class TestSplitLoop(TestUtils): std_ = self.op_calc(a, b) - compiled_op_calc = torch.compile(self.op_calc, backend="inductor", dynamic=False) + compiled_op_calc = torch.compile(self.op_calc, backend="inductor") inductor_ = compiled_op_calc(a, b) self.assertEqual(std_, inductor_, atol=1e-3, rtol=1e-3) diff --git a/test/_inductor/test_trans_to_npu.py b/test/_inductor/test_trans_to_npu.py index 8c14f7dcd0..07c16665d7 100644 --- a/test/_inductor/test_trans_to_npu.py +++ b/test/_inductor/test_trans_to_npu.py @@ -19,7 +19,7 @@ class TestTransToNpu(TestUtils): std_result = self.op_add(input_element1, input_element2) - compiled_op_add = torch.compile(self.op_add, backend="inductor", dynamic=False) + compiled_op_add = 
torch.compile(self.op_add, backend="inductor") inductor_result1 = compiled_op_add(input_element1, input_element2) torch.testing.assert_close(std_result, inductor_result1, atol=1e-3, rtol=1e-3) -- Gitee From d72bbd15b0f3d39b00dc8817c4b1fd462d20892b Mon Sep 17 00:00:00 2001 From: Xuan Peng Date: Thu, 28 Aug 2025 09:40:43 +0800 Subject: [PATCH 3/6] clean code --- torch_npu/_inductor/codegen/split_tiling.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch_npu/_inductor/codegen/split_tiling.py b/torch_npu/_inductor/codegen/split_tiling.py index 4b86845a28..fc319c7bc5 100644 --- a/torch_npu/_inductor/codegen/split_tiling.py +++ b/torch_npu/_inductor/codegen/split_tiling.py @@ -221,9 +221,8 @@ class SplitTiling: def find_longest_dimension(self, check_in_tiling=False): longest = None for axis in self.kernel.sorted_axis: - if (longest is None or axis.length > longest.length) and ( - not check_in_tiling or axis not in self.kernel.tiling_axis - ): + not_tiling = not check_in_tiling or axis not in self.kernel.tiling_axis + if (longest is None or axis.length > longest.length) and not_tiling: longest = axis return longest -- Gitee From 0bd61f0fa2c93a1018ec8ce9db65811619f1c928 Mon Sep 17 00:00:00 2001 From: Xuan Peng Date: Thu, 28 Aug 2025 11:45:49 +0800 Subject: [PATCH 4/6] nit --- test/_inductor/test_reduction_brocast_add.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/_inductor/test_reduction_brocast_add.py b/test/_inductor/test_reduction_brocast_add.py index 63e5b67b18..915eb8ac0d 100644 --- a/test/_inductor/test_reduction_brocast_add.py +++ b/test/_inductor/test_reduction_brocast_add.py @@ -17,7 +17,7 @@ class TestSumAdd(TestUtils): @parametrize('dim', [0, 1, 2]) @parametrize('dtype', ['float32']) def test_reduction_cases_shapes1(self, shape, dim, dtype): - a, b = [torch.randn(shape, requires_grad=False, dtype=eval('torch' + dtype), device="npu") for _ in range(2)] + a, b = [torch.randn(shape, requires_grad=False, dtype=eval('torch.' 
+ dtype), device="npu") for _ in range(2)] r1 = self.foo(a, b, dim, shape) func = torch.compile(self.foo, backend="inductor") r = func(a, b, dim, shape) -- Gitee From ed07abfd788abd405838c568b213b8c3bb7b37bc Mon Sep 17 00:00:00 2001 From: Xuan Peng Date: Thu, 28 Aug 2025 17:10:33 +0800 Subject: [PATCH 5/6] fix ut --- test/_inductor/test_abs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/_inductor/test_abs.py b/test/_inductor/test_abs.py index ed34ffb2c7..62482afd7a 100644 --- a/test/_inductor/test_abs.py +++ b/test/_inductor/test_abs.py @@ -9,7 +9,7 @@ class TestAbs(TestUtils): result = torch.abs(first_element) return result - @parametrize('shape', [(1024, 32), (256, 8), (512, 64)]) + @parametrize('shape', [(1024, 32), (256, 8)]) @parametrize('dtype', ['float16', 'float32', 'bfloat16']) def test_pointwise_cases(self, shape, dtype): first_element = self._generate_tensor(shape, dtype) -- Gitee From 8418e9d615b4589cdbd490f832ddf95cda63b58c Mon Sep 17 00:00:00 2001 From: Xuan Peng Date: Fri, 29 Aug 2025 11:20:23 +0800 Subject: [PATCH 6/6] nit --- torch_npu/_inductor/codegen/wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_npu/_inductor/codegen/wrapper.py b/torch_npu/_inductor/codegen/wrapper.py index 7afba0d288..86a63ee301 100644 --- a/torch_npu/_inductor/codegen/wrapper.py +++ b/torch_npu/_inductor/codegen/wrapper.py @@ -14,6 +14,7 @@ from torch._inductor.codegen.wrapper import ( user_defined_kernel_grid_fn_code, pexpr, ) +from torch._inductor.ir import IRNode from torch._inductor.runtime import triton_heuristics from torch._inductor.utils import ( cache_on_self, -- Gitee
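
The common thread in this series is that, under dynamic shapes, axis lengths and numels arrive as sympy expressions rather than plain integers, so the split/tiling heuristics must either substitute the recorded size hints (V.graph.sizevars.var_to_val) before comparing values, or route comparisons through the size-var oracle (statically_known_gt / statically_known_equals / statically_known_multiple_of), as the changes to split_tiling.py and triton.py do. The sketch below illustrates only the hint-substitution half, using plain sympy with no torch dependency; the names var_to_val and length_hint are local stand-ins chosen for illustration, not the torch._inductor API itself.

    import sympy

    # Hypothetical stand-ins for what the patch reads from V.graph.sizevars:
    # a dynamic-shape size symbol and the example ("hint") value recorded for it.
    s0 = sympy.Symbol("s0")
    var_to_val = {s0: sympy.Integer(1024)}


    def length_hint(length):
        # Mirrors the idea of SplitTiling.get_length_val: concrete integers pass
        # through unchanged, symbolic lengths are resolved against the hints.
        if isinstance(length, sympy.Integer):
            return length
        return length.subs(var_to_val)


    expr = s0 * 64  # e.g. an axis length of s0*64 under dynamic shapes
    # A relational such as `expr > 1` stays symbolic and cannot be coerced to a
    # Python bool, which is why the heuristics either substitute hints first
    # (as here) or ask the size-var oracle via statically_known_gt and friends.
    print(length_hint(expr))              # 65536
    print(length_hint(sympy.Integer(8)))  # 8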