From 5d045b2f8b336503cd1fae262b7872f9484ffb08 Mon Sep 17 00:00:00 2001
From: hxf12345677
Date: Tue, 21 Jun 2022 17:24:32 +0800
Subject: [PATCH] fixbug: fix check_5d_5d_match error when checking scalar
 tensors
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

check_5d_5d_match() fetched the NPU storage descriptor before its
is_contiguous() early return, which made the check raise an error when it
was reached with a scalar input. Defer the descriptor access until after
the contiguity check and read it through NPUBridge::GetNpuStorageImpl(tensor)
only where it is actually needed. Add a regression test that adds a scalar
to a contiguous and a non-contiguous tensor on both CPU and NPU.
---
 test/test_network_ops/test_add.py           | 14 ++++++++++++++
 torch_npu/csrc/framework/utils/NpuUtils.cpp | 11 +++--------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/test/test_network_ops/test_add.py b/test/test_network_ops/test_add.py
index 8700e39e62d..c79786b7b41 100644
--- a/test/test_network_ops/test_add.py
+++ b/test/test_network_ops/test_add.py
@@ -390,6 +390,20 @@ class TestAdd(TestCase):
         npu_output = npu_output.to("cpu")
         self.assertRtolEqual(cpu_output, npu_output)
 
+    def test_add_scalar_check_5d_5d_match(self, device="npu"):
+        ca = torch.randn(4)
+        cb = ca.view(2, 2).transpose(1, 0)
+        na = ca.npu()
+        nb = cb.npu()
+        caout = torch.add(ca, 1)
+        cbout = torch.add(cb, 1)
+        naout = torch.add(na, 1)
+        nbout = torch.add(nb, 1)
+        naout = naout.to("cpu")
+        nbout = nbout.to("cpu")
+        self.assertRtolEqual(caout, naout)
+        self.assertRtolEqual(cbout, nbout)
+
 
 if __name__ == "__main__":
     run_tests()

diff --git a/torch_npu/csrc/framework/utils/NpuUtils.cpp b/torch_npu/csrc/framework/utils/NpuUtils.cpp
index 8a6eea007f3..a786480b818 100644
--- a/torch_npu/csrc/framework/utils/NpuUtils.cpp
+++ b/torch_npu/csrc/framework/utils/NpuUtils.cpp
@@ -82,16 +82,11 @@ bool NpuUtils::check_5d_5d_match(const at::Tensor &tensor) {
   // (1) NC1HWC0 format in storage, NCHW format in des.
   // (2) 4d format situation, only uncontiguous in Channel size
   // (3) size and start point must be 16*, make sure the memory be contiguous
-  const c10::Storage storage = tensor.storage();
-  const torch_npu::NPUStorageDesc npuDesc =
-      torch_npu::NPUBridge::GetNpuStorageImpl(storage.unsafeGetStorageImpl())
-          ->get_npu_desc();
-
   if (tensor.is_contiguous()) {
     return false;
   }
 
-  if (npuDesc.npu_format_ != ACL_FORMAT_NC1HWC0) {
+  if (torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_.npu_format_ != ACL_FORMAT_NC1HWC0) {
     return false;
   }
 
@@ -117,8 +112,8 @@ bool NpuUtils::check_5d_5d_match(const at::Tensor &tensor) {
   int64_t contiguous_len = 16;
   int64_t c0_len = 16;
 
-  for (int i = 2; i < npuDesc.base_sizes_.size(); i++) {
-    contiguous_len *= npuDesc.base_sizes_[i];
+  for (int i = 2; i < torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_.base_sizes_.size(); i++) {
+    contiguous_len *= torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_.base_sizes_[i];
   }
   bool is_offset_match = (tensor.storage_offset() % contiguous_len == 0);
   bool is_length_match = (tensor.size(1) % c0_len == 0);
-- 
Gitee
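
Usage note: a minimal sketch of the scenario the new regression test covers,
runnable outside the test harness. It assumes a torch_npu build where
importing torch_npu registers the "npu" device and Tensor.npu() is available;
the helper name reproduce_scalar_add is hypothetical and not part of the
patch.

    import torch
    import torch_npu  # assumed to register the "npu" device on import

    def reproduce_scalar_add():
        # Non-contiguous input: the transpose of a 2x2 view of a 1-D tensor.
        cpu_t = torch.randn(4).view(2, 2).transpose(1, 0)
        npu_t = cpu_t.npu()
        # Before this patch, check_5d_5d_match read the NPU storage
        # descriptor before its contiguity early return, so adding a
        # scalar to such a tensor could raise an error on the NPU path;
        # after the patch the NPU result matches the CPU result.
        cpu_out = torch.add(cpu_t, 1)
        npu_out = torch.add(npu_t, 1).to("cpu")
        return torch.allclose(cpu_out, npu_out)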