From 0e79282ccad877282d6610364cce8a630f7dd71d Mon Sep 17 00:00:00 2001 From: wgb Date: Fri, 4 Jul 2025 14:15:14 +0800 Subject: [PATCH] Resize_ to support ncdhw to other dims --- test/custom_ops/test_resize_.py | 32 +++++++++++++++++++ .../csrc/framework/StorageDescHelper.cpp | 8 +++-- 2 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 test/custom_ops/test_resize_.py diff --git a/test/custom_ops/test_resize_.py b/test/custom_ops/test_resize_.py new file mode 100644 index 00000000000..93b30183249 --- /dev/null +++ b/test/custom_ops/test_resize_.py @@ -0,0 +1,32 @@ +import numpy as np +import torch +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests + + +class TestResize(TestCase): + + def test_masked_select_out(self): + + input_data = torch.tensor([[[[[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20], [21, 22, 23, 24, 25]]]]]], dtype=torch.float) + mask = torch.tensor([True, False, True, False, True]) + + input_data_npu = input_data.npu() + mask_npu = mask.npu() + + out_tensor = torch.empty((1, 1, 1, 1, 1), dtype=input_data.dtype) + out_tensor_npu = out_tensor.npu() + + out_tensor_npu = out_tensor_npu.view(-1) + out_tensor_npu = torch.masked_select(input_data_npu, mask_npu, out=out_tensor_npu) + out_tensor = torch.masked_select(input_data, mask, out=out_tensor) + self.assertRtolEqual(out_tensor_npu, out_tensor) + + def test_resize_ncdhw(self): + out_tensor = torch.empty((1, 1, 1, 1, 1), dtype=torch.float16).npu() + shape = [25] + out_tensor.resize_(shape) + self.assertEqual(shape, out_tensor.shape) + +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/framework/StorageDescHelper.cpp b/torch_npu/csrc/framework/StorageDescHelper.cpp index 24f0f83c4d2..a451a1a5405 100644 --- a/torch_npu/csrc/framework/StorageDescHelper.cpp +++ b/torch_npu/csrc/framework/StorageDescHelper.cpp @@ -61,9 +61,13 @@ void StorageDescHelper::UpdateDesc(torch_npu::NPUStorageDesc &npuDesc, const c10 } } npuDesc.base_strides_ = new_stride; - // 更新物理内存信息 - npuDesc.storage_sizes_ = FormatHelper::GetStorageSizes(npuDesc); + int NCDHW_OR_NDHWC_DIM = 5; + if ((npuDesc.npu_format_ == ACL_FORMAT_NCDHW || npuDesc.npu_format_ == ACL_FORMAT_NDHWC) && new_size.size() < NCDHW_OR_NDHWC_DIM) { + npuDesc.storage_sizes_ = new_size; + } else { + npuDesc.storage_sizes_ = FormatHelper::GetStorageSizes(npuDesc); + } if (new_data_numel > new_shape_numel) { // Refresh format to base format only when flattening storage data npuDesc.storage_sizes_ = new_size; -- Gitee