diff --git a/test/test_network_ops/test_min.py b/test/test_network_ops/test_min.py
new file mode 100644
index 0000000000000000000000000000000000000000..c70c365bc3cd150f839433123400e3490611b6da
--- /dev/null
+++ b/test/test_network_ops/test_min.py
@@ -0,0 +1,519 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestMin(TestCase):
+    def cpu_op_exec(self, input1):
+        output = torch.min(input1)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1):
+        output = torch.min(input1)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_other_exec(self, input1, input2):
+        output = torch.min(input1, input2)
+        output = output.numpy()
+        return output
+
+    def npu_op_other_exec(self, input1, input2):
+        input1 = input1.to("npu")
+        input2 = input2.to("npu")
+        output = torch.min(input1, input2)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_other_exec_out(self, input1, input2, out):
+        torch.min(input1, input2, out=out)
+        output = out.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_dim_exec(self, input1, dim, keepdim):
+        output1, output2 = torch.min(input1, dim, keepdim)
+        output1 = output1.numpy()
+        output2 = output2.int().numpy()
+        return output1, output2
+
+    def npu_op_dim_exec(self, input1, dim, keepdim):
+        input1 = input1.to("npu")
+        output1, output2 = torch.min(input1, dim, keepdim)
+        output1 = output1.to("cpu")
+        output2 = output2.to("cpu")
+        output1 = output1.numpy()
+        output2 = output2.numpy()
+        return output1, output2
+
+    def cpu_op_dim_exec_out(self, input1, dim, keepdim):
+        out = torch.tensor(0).to(input1.dtype)
+        indices = torch.tensor(0).to(torch.long)
+        torch.min(input1, dim=dim, keepdim=keepdim, out=(out, indices))
+        out = out.numpy()
+        indices = indices.numpy()
+        return out, indices
+
+    def npu_op_dim_exec_out(self, input1, dim, keepdim):
+        out = torch.tensor(0).to(input1.dtype).npu()
+        indices = torch.tensor(0).to(torch.long).npu()
+        torch.min(input1, dim=dim, keepdim=keepdim, out=(out, indices))
+        out = out.to("cpu").numpy()
+        indices = indices.to("cpu").numpy()
+        return out, indices
+
+    def cpu_op_amin_exec(self, input1, dim, keepdim):
+        output = torch.amin(input1, dim, keepdim)
+        output = output.numpy()
+        return output
+
+    def npu_op_amin_exec(self, input1, dim, keepdim):
+        output = torch.amin(input1, dim, keepdim)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_amin_exec_out(self, input1, dim, keepdim, out):
+        torch.amin(input1, dim, keepdim, out=out)
+        output = out.to("cpu")
+        output = output.numpy()
+        return output
+
+    def min_result(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input1)
+            npu_output = self.npu_op_exec(npu_input1)
+            cpu_output = cpu_output.astype(npu_output.dtype)
+
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def min_result_dim(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output_dim, cpu_output_indices = self.cpu_op_dim_exec(cpu_input1, item[1], item[2])
+            npu_output_dim, npu_output_indices = self.npu_op_dim_exec(npu_input1, item[1], item[2])
+            cpu_output_dim = cpu_output_dim.astype(npu_output_dim.dtype)
+            self.assertRtolEqual(cpu_output_dim, npu_output_dim)
+
+    def min_result_other(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item[0], 0, 10)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+                cpu_input2 = cpu_input2.to(torch.float32)
+            cpu_output_other = self.cpu_op_other_exec(cpu_input1, cpu_input2)
+            npu_output_other = self.npu_op_other_exec(npu_input1, npu_input2)
+            cpu_output_other = cpu_output_other.astype(npu_output_other.dtype)
+
+            self.assertRtolEqual(cpu_output_other, npu_output_other)
+
+    def min_out_result_other(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item[0], -100, 100)
+            cpu_input3, npu_input3 = create_common_tensor(item[0], -100, 100)
+            cpu_input4, npu_input4 = create_common_tensor(item[1], -100, 100)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            if cpu_input2.dtype == torch.float16:
+                cpu_input2 = cpu_input2.to(torch.float32)
+            cpu_output = self.cpu_op_other_exec(cpu_input1, cpu_input2)
+            npu_output_out1 = self.npu_op_other_exec(npu_input1, npu_input2)
+            npu_output_out2 = self.npu_op_other_exec_out(npu_input1, npu_input2, npu_input4)
+            cpu_output = cpu_output.astype(npu_output_out1.dtype)
+            self.assertRtolEqual(cpu_output, npu_output_out1)
+            self.assertRtolEqual(cpu_output, npu_output_out2)
+            cpu_out_dim, cpu_out_indices = self.cpu_op_dim_exec_out(cpu_input1, dim=0, keepdim=True)
+            npu_out_dim, npu_out_indices = self.npu_op_dim_exec_out(npu_input1, dim=0, keepdim=True)
+            npu_output_dim, npu_output_indices = self.npu_op_dim_exec(npu_input1, dim=0, keepdim=True)
+            cpu_out_dim = cpu_out_dim.astype(npu_out_dim.dtype)
+            if cpu_out_dim.dtype != np.float16:
+                self.assertRtolEqual(npu_out_dim, cpu_out_dim)
+                self.assertRtolEqual(npu_out_indices, cpu_out_indices)
+            else:
+                self.assertRtolEqual(npu_out_dim, npu_output_dim)
+                self.assertRtolEqual(npu_out_indices, npu_output_indices)
+
+    def min_name_result_other(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input1.names = item[0][3]
+            npu_input1.names = item[0][3]
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output_dim, cpu_output_indices = self.cpu_op_dim_exec(cpu_input1, item[1], item[2])
+            npu_output_dim, npu_output_indices = self.npu_op_dim_exec(cpu_input1, item[1], item[2])
+
+            if npu_output_dim.dtype != np.float16:
+                self.assertRtolEqual(npu_output_dim, cpu_output_dim)
+                self.assertRtolEqual(npu_output_indices.astype(np.int32), cpu_output_indices.astype(np.int32))
+            else:
+                self.assertRtolEqual(npu_output_dim, cpu_output_dim.astype(np.float16))
+                self.assertRtolEqual(npu_output_indices.astype(np.int32), cpu_output_indices.astype(np.int32))
+
+    def min_name_out_result_other(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input1.names = item[0][3]
+            npu_input1.names = item[0][3]
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output_dim, cpu_output_indices = self.cpu_op_dim_exec_out(cpu_input1, item[1], item[2])
+            npu_output_dim, npu_output_indices = self.npu_op_dim_exec_out(npu_input1, item[1], item[2])
+
+            if npu_output_dim.dtype != np.float16:
+                self.assertRtolEqual(npu_output_dim, cpu_output_dim)
+                self.assertRtolEqual(npu_output_indices.astype(np.int32), cpu_output_indices.astype(np.int32))
+            else:
+                self.assertRtolEqual(npu_output_dim, cpu_output_dim.astype(np.float16))
+                self.assertRtolEqual(npu_output_indices.astype(np.int32), cpu_output_indices.astype(np.int32))
+
+    def amin_result(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item[0], 0, 100)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output_amin = self.cpu_op_amin_exec(cpu_input1, item[1], item[2])
+            npu_output_amin = self.npu_op_amin_exec(npu_input1, item[1], item[2])
+            npu_output_amin_out = self.npu_op_amin_exec_out(npu_input1, item[1], item[2], npu_input2)
+            cpu_output_amin = cpu_output_amin.astype(npu_output_amin.dtype)
+            self.assertRtolEqual(cpu_output_amin, npu_output_amin)
+            self.assertRtolEqual(cpu_output_amin, npu_output_amin_out)
+
+    def test_min_out_result(self, device):
+        shape_format = [
+            [[np.float16, 0, [19, 16, 14, 14]], [np.float16, 0, [25, 16, 1, 1]]],
+            [[np.float16, 0, [19, 17, 28, 28]], [np.float16, 0, [17, 17, 1, 1]]],
+            [[np.float16, 0, [19, 3, 22, 22]], [np.float16, 0, [3, 3, 3]]],
+            [[np.float16, 0, [19, 16, 14, 14]], [np.float16, 0, [19, 16, 14, 14]]],
+            [[np.float32, 0, [25, 19, 7, 7]], [np.float32, 0, [19, 25, 3, 3]]],
+            [[np.float32, 0, [25, 3, 22, 22]], [np.float32, 0, [3, 3, 7, 7]]],
+            [[np.float32, 0, [2, 3, 3, 3]], [np.float32, 0, [3, 1, 3]]],
+            [[np.float32, 0, [19, 23, 7, 7]], [np.float32, 0, [19, 23, 7, 7]]],
+        ]
+        self.min_out_result_other(shape_format)
+
+    def test_min_shape_format_fp16_1d(self, device):
+        format_list = [0, 3]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list
+                        ]
+        self.min_result(shape_format)
+
+    def test_min_shape_format_fp32_1d(self, device):
+        format_list = [0, 3]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result(shape_format)
+
+    def test_min_shape_format_fp16_2d(self, device):
+        format_list = [0, 3]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result(shape_format)
+
+    def test_min_shape_format_fp32_2d(self, device):
+        format_list = [0, 3]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result(shape_format)
+
+    def test_min_shape_format_fp16_3d(self, device):
+        format_list = [0, 3, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result(shape_format)
+
+    def test_min_shape_format_fp32_3d(self, device):
+        format_list = [0, 3, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result(shape_format)
+
+    def test_min_shape_format_fp16_4d(self, device):
+        format_list = [0, 4, 3, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15, 16]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result(shape_format)
+
+    def test_min_shape_format_fp32_4d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15, 16]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result(shape_format)
+
+    def test_min_dim_shape_format_fp16_1d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list
+                        ]
+        self.min_result_dim(shape_format)
+
+    def test_min_dim_shape_format_fp32_1d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_dim(shape_format)
+
+    def test_min_dim_shape_format_fp16_2d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_dim(shape_format)
+
+    def test_min_dim_shape_format_fp32_2d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_dim(shape_format)
+
+    def test_min_dim_shape_format_fp16_3d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_dim(shape_format)
+
+    def test_min_dim_shape_format_fp32_3d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_dim(shape_format)
+
+    def test_min_dim_shape_format_fp16_4d(self, device):
+        format_list = [0, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j
+                        in keepdim_list
+                        ]
+        self.min_result_dim(shape_format)
+
+    def test_min_dim_shape_format_fp32_4d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j
+                        in keepdim_list
+                        ]
+        self.min_result_dim(shape_format)
+
+    def test_min_other_shape_format_fp16_1d(self, device):
+        format_list = [0, 3, 4]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list
+                        ]
+        self.min_result_other(shape_format)
+
+    def test_min_other_shape_format_fp32_1d(self, device):
+        format_list = [0, 3, 4]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_other(shape_format)
+
+    def test_min_other_shape_format_fp16_2d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_other(shape_format)
+
+    def test_min_other_shape_format_fp32_2d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_other(shape_format)
+
+    def test_min_other_shape_format_fp16_3d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_other(shape_format)
+
+    def test_min_other_shape_format_fp32_3d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.min_result_other(shape_format)
+
+    def test_min_other_shape_format_fp16_4d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j
+                        in keepdim_list
+                        ]
+        self.min_result_other(shape_format)
+
+    def test_min_other_shape_format_fp32_4d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j
+                        in keepdim_list
+                        ]
+        self.min_result_other(shape_format)
+
+    def test_min_dimname_shape_format(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15, 16], ('N', 'C', 'H', 'W')],
+                         np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j
+                        in
+                        keepdim_list
+                        ]
+        self.min_name_result_other(shape_format)
+
+    def test_min_dimname_shape_format_fp16(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15, 16], ('N', 'C', 'H', 'W')],
+                         np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j
+                        in
+                        keepdim_list
+                        ]
+        self.min_name_result_other(shape_format)
+
+    def test_min_dimname_out_shape_format(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15, 16], ('N', 'C', 'H', 'W')],
+                         np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j
+                        in
+                        keepdim_list
+                        ]
+        self.min_name_out_result_other(shape_format)
+
+    def test_min_dimname_out_shape_format_fp16(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15, 16], ('N', 'C', 'H', 'W')],
+                         np.random.choice(['N', 'C', 'H', 'W']), j] for i in format_list for j
+                        in
+                        keepdim_list
+                        ]
+        self.min_name_out_result_other(shape_format)
+
+    def test_amin_shape_format_fp16_1d(self, device):
+        format_list = [0, 3, 4]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in keepdim_list
+                        ]
+        self.amin_result(shape_format)
+
+    def test_amin_shape_format_fp32_1d(self, device):
+        format_list = [0, 3, 4]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.amin_result(shape_format)
+
+    def test_amin_shape_format_fp16_2d(self, device):
+        format_list = [0, 3, 4]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.amin_result(shape_format)
+
+    def test_amin_shape_format_fp32_2d(self, device):
+        format_list = [0, 3, 4]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25]], np.random.randint(0, 2), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.amin_result(shape_format)
+
+    def test_amin_shape_format_fp16_3d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.amin_result(shape_format)
+
+    def test_amin_shape_format_fp32_3d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15]], np.random.randint(0, 3), j] for i in format_list for j in
+                        keepdim_list
+                        ]
+        self.amin_result(shape_format)
+
+    def test_amin_shape_format_fp16_4d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float16, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j
+                        in keepdim_list
+                        ]
+        self.amin_result(shape_format)
+
+    def test_amin_shape_format_fp32_4d(self, device):
+        format_list = [0, 3, 4, 29]
+        keepdim_list = [True, False]
+        shape_format = [[[np.float32, i, [18, 25, 15, 16]], np.random.randint(0, 4), j] for i in format_list for j
+                        in keepdim_list
+                        ]
+        self.amin_result(shape_format)
+
+
+instantiate_device_type_tests(TestMin, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/MinKernelNpu.cpp b/torch_npu/csrc/aten/ops/MinKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1940e03beaa7f60c793f0644ff6d33e4453632ab
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/MinKernelNpu.cpp
@@ -0,0 +1,185 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+
+namespace at_npu {
+namespace native {
+
+tuple<at::Tensor&, at::Tensor&> min_out_npu_nocheck(
+    const at::Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    at::Tensor& output,
+    at::Tensor& indices) {
+  OpCommand cmd;
+  cmd.Name("ArgMinWithValue")
+      .Input(self)
+      .Output(indices)
+      .Output(output)
+      .Attr("dimension", dim)
+      .Attr("keep_dims", keepdim)
+      .Run();
+  return std::tie(output, indices);
+}
+
+tuple<at::Tensor&, at::Tensor&> NPUNativeFunctions::min_out(
+    const at::Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    at::Tensor& output,
+    at::Tensor& indices) {
+  c10::SmallVector<int64_t, SIZE> dims = {dim};
+  auto outputSize = reduce_ops_npu_output_size(self, dims, keepdim);
+  c10::SmallVector<int64_t, SIZE> indicesSize = outputSize;
+  auto func = [&self, dim, keepdim](at::Tensor& output, at::Tensor& indices) {
+    min_out_npu_nocheck(self, dim, keepdim, output, indices);
+  };
+
+  at::Tensor indices_tmp;
+  OpPipeWithMultiOut<at::Tensor&, at::Tensor&> pipe(output, indices_tmp);
+  return pipe.FixOutputSizeAndFormat<0>({self}, self, ACL_FORMAT_ND, outputSize)
+      .ApplyOutputWithSpecailParams<1>(indicesSize, self.options().dtype(at::ScalarType::Int), ACL_FORMAT_ND)
+      .Call(func)
+      .ReflushOutputDtype<1>(at::ScalarType::Long)
+      .FixOutputExceptDtype<1>({self}, ACL_FORMAT_ND, at::ScalarType::Long, indicesSize)
+      .FixOutputWithReplace<1>(indices)
+      .ReturnRef<at::Tensor&, at::Tensor&>();
+}
+
+tuple<at::Tensor, at::Tensor> NPUNativeFunctions::min(const at::Tensor& self, int64_t dim, bool keepdim) {
+  at::Tensor selfCast = self;
+  if (self.dtype() == at::ScalarType::Bool) {
+    selfCast = self.to(at::ScalarType::Float);
+  }
+  c10::SmallVector<int64_t, SIZE> dims = {dim};
+  auto outputSize = reduce_ops_npu_output_size(selfCast, dims, keepdim);
+  c10::SmallVector<int64_t, SIZE> indicesSize = outputSize;
+  auto func = [&selfCast, dim, keepdim](at::Tensor outputs, at::Tensor indices) {
+    min_out_npu_nocheck(selfCast, dim, keepdim, outputs, indices);
+  };
+
+  at::Tensor outputs, indices;
+  OpPipeWithDefinedMultiOut<at::Tensor, at::Tensor> pipe(outputs, indices);
+  std::tie(outputs, indices) = pipe.ApplyOutputWithSpecailParams<0>(outputSize, selfCast.options(), ACL_FORMAT_ND)
+      .ApplyOutputWithSpecailParams<1>(indicesSize, selfCast.options().dtype(at::ScalarType::Int), ACL_FORMAT_NCHW)
+      .Call(func)
+      .ReflushOutputDtype<1>(at::ScalarType::Long)
+      .Return<at::Tensor, at::Tensor>();
+
+  if (self.dtype() == at::ScalarType::Bool) {
+    outputs = outputs.to(at::ScalarType::Bool);
+  }
+  return std::tie(outputs, indices);
+}
+
+tuple<at::Tensor&, at::Tensor&> NPUNativeFunctions::min_out(
+    const at::Tensor& self,
+    at::Dimname dim,
+    bool keepdim,
+    at::Tensor& output,
+    at::Tensor& indices) {
+  return min_out(self, dimname_to_position(self, dim), keepdim, output, indices);
+}
+
+tuple<at::Tensor, at::Tensor> NPUNativeFunctions::min(const at::Tensor& self, at::Dimname dim, bool keepdim) {
+  return min(self, dimname_to_position(self, dim), keepdim);
+}
+
+at::Tensor& min_out_npu_nocheck(
+    const at::Tensor& self,
+    const at::Tensor& other,
+    at::Tensor& result) {
+  OpCommand cmd;
+  cmd.Name("Minimum")
+      .Input(self)
+      .Input(other)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::min_out(
+    const at::Tensor& self,
+    const at::Tensor& other,
+    at::Tensor& result) {
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      ACL_FORMAT_ND,
+      self.scalar_type(),
+      self.sizes());
+  min_out_npu_nocheck(self, other, result);
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::min(const at::Tensor& self, const at::Tensor& other) {
+  auto outputSize = broadcast_ops_npu_output_size(self, other);
+  at::Tensor result = OpPreparation::ApplyTensor(self, outputSize);
+  min_out_npu_nocheck(self, other, result);
+  return result;
+}
+
+at::Tensor& min_out_npu_nocheck(
+    const at::Tensor& self,
+    at::IntArrayRef dims,
+    bool keepdim,
+    at::Tensor& result) {
+  OpCommand cmd;
+  cmd.Name("ReduceMin")
+      .Input(self)
+      .Input(dims)
+      .Output(result)
+      .Attr("keep_dims", keepdim)
+      .Run();
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::amin(const at::Tensor& self, at::IntArrayRef dims, bool keepdim) {
+  auto outputSize = reduce_ops_npu_output_size(self, dims, keepdim);
+  int64_t npu_format = CalcuOpUtil::get_tensor_npu_format(self);
+  if (outputSize.empty()) {
+    npu_format = ACL_FORMAT_NCHW;
+  }
+  at::Tensor result = OpPreparation::ApplyTensorWithFormat(self, outputSize, npu_format);
+  min_out_npu_nocheck(self, dims, keepdim, result);
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::min(const at::Tensor& self) {
+  c10::SmallVector<int64_t, N> dims = CalcuOpUtil::get_dimlist_for_tensor(self);
+  return amin(self, dims, false);
+}
+
+at::Tensor& NPUNativeFunctions::amin_out(
+    const at::Tensor& self,
+    at::IntArrayRef dims,
+    bool keepdim,
+    at::Tensor& result) {
+  auto outputSize = reduce_ops_npu_output_size(self, dims, keepdim);
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      ACL_FORMAT_ND,
+      self.scalar_type(),
+      outputSize);
+  min_out_npu_nocheck(self, dims, keepdim, result);
+  return result;
+}
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py
index 758c2abd34e4ff9e7678df8e943087d3734bd542..4b73e36c763bbb46c0a1c93475d1ac37e2effb82 100644
--- a/torch_npu/npu/__init__.py
+++ b/torch_npu/npu/__init__.py
@@ -26,7 +26,7 @@ __all__ = [
     "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached",
     "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved",
     "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary",
-    "Stream", "Event", "profiler", "set_option"
+    "Stream", "Event", "profiler"
 ]
 
 
@@ -44,4 +44,3 @@ from .memory import (_free_mutex, caching_allocator_alloc, caching_allocator_del
                      memory_cached, max_memory_cached, memory_snapshot, memory_summary)
 from .streams import Stream, Event
 from . import profiler
-from .npu_frontend_enhance import set_option
\ No newline at end of file
diff --git a/torch_npu/npu/profiler.py b/torch_npu/npu/profiler.py
index b7c1474fa6b68a7e9540f4bce3ab067000598113..0dbf9de39a5f6f3e3e5375fd786bf94ccfe76fab 100644
--- a/torch_npu/npu/profiler.py
+++ b/torch_npu/npu/profiler.py
@@ -1165,7 +1165,7 @@ def parse_kineto_results(result):
         is_async = kineto_event.start_thread_id() != kineto_event.end_thread_id()
 
         fe = FunctionEvent(
-            id_event=kineto_event.correlation_id(),
+            id=kineto_event.correlation_id(),
             name=rewrite_name(name=kineto_event.name(), with_wildcard=True),
             trace_name=rewrite_name(name=kineto_event.name(), with_wildcard=False),
             thread=kineto_event.start_thread_id(),
@@ -1212,7 +1212,7 @@ def parse_kineto_results(result):
     for mem_record in mem_records:
         if not mem_record[1]:
             fe = FunctionEvent(
-                id_event=mem_record[0].handle(),
+                id=mem_record[0].handle(),
                 name="[memory]",
                 trace_name=None,  # not outputting in the trace
                 thread=mem_record[0].thread_id(),
@@ -1340,7 +1340,7 @@ def parse_legacy_records(thread_records):
                 start_flops = start.flops()
 
                 fe = FunctionEvent(
-                    id_event=record.handle(),
+                    id=record.handle(),
                     node_id=record.node_id(),
                     name=rewrite_name(name=start.name(), with_wildcard=True),
                     trace_name=rewrite_name(name=start.name(), with_wildcard=False),
@@ -1398,7 +1398,7 @@ def parse_legacy_records(thread_records):
             if num_open_handles_cpu == 0:
                 # output event as a top-level memory event
                 fe = FunctionEvent(
-                    id_event=0,
+                    id=0,
                     name="[memory]",
                     trace_name=None,
                     thread=0,