From e33763bce3bc4c95915b11e245398c409ee7a267 Mon Sep 17 00:00:00 2001 From: pipihugh Date: Tue, 22 Mar 2022 14:01:43 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dscatter=E7=AE=97=E5=AD=90src?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B=E4=B8=8D=E5=8C=B9=E9=85=8D?= =?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_scatter.py | 16 +++++++++++++++- torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp | 17 ++++++----------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/test/test_network_ops/test_scatter.py b/test/test_network_ops/test_scatter.py index 1936d3761b2..f377284bc0d 100644 --- a/test/test_network_ops/test_scatter.py +++ b/test/test_network_ops/test_scatter.py @@ -51,7 +51,7 @@ class TestScatter(TestCase): input1 = input1.cpu() return input1.numpy() - def test_scatter_shape_format(self, device="npu"): + def test_scatter_shape_format(self): shape_format = [ [0, [3, 5], [np.float32, 0, [2, 5]]], [0, [3, 5], [np.float32, 3, [2, 5]]], @@ -86,6 +86,20 @@ class TestScatter(TestCase): npu_output = self.npu_op_exec_inplace(item[1], item[0], index, 1.23, False) self.assertRtolEqual(cpu_output, npu_output) + def test_scatter_debug(self): + a = np.random.uniform(-2,2,(31, 43, 41, 97)).astype(np.float16) + b = np.random.uniform(0,30,(31, 43, 41, 97)).astype(np.int32) + c = np.random.uniform(-2,2,(31, 43, 41, 97)).astype(np.float16) + ca = torch.from_numpy(a) + cb = torch.from_numpy(b).long() + cc = torch.from_numpy(c) + na = ca.npu() + nb = cb.npu() + nc = cc.npu() + dim = 0 + cpu_output = torch.scatter(ca, dim, cb, cc) + npu_output = torch.scatter(na, dim, nb, nc) + self.assertRtolEqual(cpu_output, npu_output.cpu()) if __name__ == "__main__": run_tests() diff --git a/torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp b/torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp index 0fad28bcd3e..8a32d4da902 100644 --- a/torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/ScatterKernelNpu.cpp @@ -41,7 +41,7 @@ at::Tensor& scatter_npu_src_impl( at::Tensor& self, int64_t dim, const at::Tensor& index_ex, - const at::Tensor& src) { + const at::Tensor& src_ex) { at::ScalarType selfType = self.scalar_type(); if (selfType == at::ScalarType::Half) { self = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float); @@ -52,6 +52,11 @@ at::Tensor& scatter_npu_src_impl( index = NPUNativeFunctions::npu_dtype_cast(index, at::ScalarType::Float); } + at::Tensor src(src_ex); + if (src.scalar_type() != self.scalar_type()) { + src = NPUNativeFunctions::npu_dtype_cast(src, self.scalar_type()); + } + if (!NpuUtils::check_match(&self)) { at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); @@ -64,7 +69,6 @@ at::Tensor& scatter_npu_src_impl( if(self.scalar_type() != selfType){ self = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Half); } - return self; } @@ -74,10 +78,6 @@ at::Tensor& NPUNativeFunctions::scatter_( const at::Tensor& index_ex, const at::Tensor& src_ex) { at::Tensor src(src_ex); - if (src.scalar_type() != self.scalar_type()) { - src = NPUNativeFunctions::npu_dtype_cast(src, self.scalar_type()); - } - scatter_npu_src_impl(self, dim, index_ex, src); return self; } @@ -90,11 +90,6 @@ at::Tensor& NPUNativeFunctions::scatter_( at::Tensor srcTensor = scalar_to_tensor(src).to(at::ScalarType::Float); srcTensor = CalcuOpUtil::copy_tensor_host_to_device(srcTensor); at::Tensor srcTensor_broadcast = NPUNativeFunctions::npu_broadcast(srcTensor, array_to_small_vector(index_ex.sizes())); - - if (srcTensor_broadcast.scalar_type() != self.scalar_type()) { - srcTensor_broadcast = NPUNativeFunctions::npu_dtype_cast(srcTensor_broadcast, self.scalar_type()); - } - scatter_npu_src_impl(self, dim, index_ex, srcTensor_broadcast); return self; } -- Gitee