diff --git a/test/test_network_ops/test_ne.py b/test/test_network_ops/test_ne.py
index 395f7bca4f46dcddae26cdf8b921a82be91000f4..74913faa8122843692b34066a1a96273952b1337 100644
--- a/test/test_network_ops/test_ne.py
+++ b/test/test_network_ops/test_ne.py
@@ -49,6 +49,20 @@ class TestNe(TestCase):
         output = input3.to("cpu")
         output = output.numpy()
         return output
+
+    def test_ne_shape_format_int32(self):
+        dtype_list = [np.int32]
+        format_list = [0, 3]
+        shape_list = [[1024], [8, 128], [2, 8, 128], [2, 8, 128, 512]]
+        shape_format = [
+            [d, i, j] for d in dtype_list for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item, 1, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item, 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2)
+            self.assertEqual(cpu_output, npu_output)
 
     def test_ne_shape_format_fp32(self):
         dtype_list = [np.float32]
diff --git a/torch_npu/csrc/aten/ops/NeKernelNpu.cpp b/torch_npu/csrc/aten/ops/NeKernelNpu.cpp
index b95c86a50da19535dcfb20563efed18dbecc572c..adbb7587b37e6167d5a974033d212179d42029a2 100644
--- a/torch_npu/csrc/aten/ops/NeKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/NeKernelNpu.cpp
@@ -21,13 +21,7 @@ namespace at_npu {
 namespace native {
 
 at::Tensor& ne_out_npu_nocheck(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
-  at::Tensor selfCast = self;
-  at::Tensor otherCast = other;
-  if(self.dtype() == at::ScalarType::Int || other.dtype() == at::ScalarType::Int){
-    selfCast = self.to(at::ScalarType::Float);
-    otherCast = other.to(at::ScalarType::Float);
-  }
-  auto unified_result = OpPreparation::comparison_op_check(result, selfCast, otherCast, true);
+  auto unified_result = OpPreparation::comparison_op_check(result, self, other, true);
   if(self.scalar_type() == at::kLong) {
     TORCH_WARN_ONCE("The oprator of ne is executed, Currently High Accuracy but Low Performance OP with 64-bit has been used,"
       "Please Do Some Cast at Python Functions with 32-bit for Better Performance!");
@@ -35,8 +29,8 @@ at::Tensor& ne_out_npu_nocheck(at::Tensor& result, const at::Tensor& self, const
   OpCommand cmd;
   cmd.Name("NotEqual")
     .Expect(unified_result)
-    .Input(selfCast)
-    .Input(otherCast)
+    .Input(self)
+    .Input(other)
     .Output(result)
     .Run();
 
@@ -44,18 +38,14 @@ at::Tensor& ne_out_npu_nocheck(at::Tensor& result, const at::Tensor& self, const
 }
 
 at::Tensor& ne_out_npu_nocheck(at::Tensor& result, const at::Tensor& self, at::Scalar other) {
-  at::Tensor selfCast = self;
-  if(self.dtype() == at::ScalarType::Int){
-    selfCast = self.to(at::ScalarType::Float);
-  }
   if(self.scalar_type() == at::kLong) {
     TORCH_WARN_ONCE("The oprator of ne is executed, Currently High Accuracy but Low Performance OP with 64-bit has been used,"
       "Please Do Some Cast at Python Functions with 32-bit for Better Performance!");
   }
   OpCommand cmd;
   cmd.Name("NotEqual")
-    .Input(selfCast)
-    .Input(other, selfCast.scalar_type())
+    .Input(self)
+    .Input(other, self.scalar_type())
     .Output(result)
     .Run();
 