diff --git a/test/custom_ops/test_resize_.py b/test/custom_ops/test_resize_.py new file mode 100644 index 0000000000000000000000000000000000000000..93b3018324930e4011458b277e8630c11aeb290a --- /dev/null +++ b/test/custom_ops/test_resize_.py @@ -0,0 +1,32 @@ +import numpy as np +import torch +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests + + +class TestResize(TestCase): + + def test_masked_select_out(self): + + input_data = torch.tensor([[[[[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20], [21, 22, 23, 24, 25]]]]]], dtype=torch.float) + mask = torch.tensor([True, False, True, False, True]) + + input_data_npu = input_data.npu() + mask_npu = mask.npu() + + out_tensor = torch.empty((1, 1, 1, 1, 1), dtype=input_data.dtype) + out_tensor_npu = out_tensor.npu() + + out_tensor_npu = out_tensor_npu.view(-1) + out_tensor_npu = torch.masked_select(input_data_npu, mask_npu, out=out_tensor_npu) + out_tensor = torch.masked_select(input_data, mask, out=out_tensor) + self.assertRtolEqual(out_tensor_npu, out_tensor) + + def test_resize_ncdhw(self): + out_tensor = torch.empty((1, 1, 1, 1, 1), dtype=torch.float16).npu() + shape = [25] + out_tensor.resize_(shape) + self.assertEqual(shape, out_tensor.shape) + +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/framework/StorageDescHelper.cpp b/torch_npu/csrc/framework/StorageDescHelper.cpp index 24f0f83c4d20d830b62dc1c65f0b67c89e78d116..a451a1a54053875d1ebcfd9c43bb7907def92677 100644 --- a/torch_npu/csrc/framework/StorageDescHelper.cpp +++ b/torch_npu/csrc/framework/StorageDescHelper.cpp @@ -61,9 +61,13 @@ void StorageDescHelper::UpdateDesc(torch_npu::NPUStorageDesc &npuDesc, const c10 } } npuDesc.base_strides_ = new_stride; - // 更新物理内存信息 - npuDesc.storage_sizes_ = FormatHelper::GetStorageSizes(npuDesc); + int NCDHW_OR_NDHWC_DIM = 5; + if ((npuDesc.npu_format_ == ACL_FORMAT_NCDHW || npuDesc.npu_format_ == ACL_FORMAT_NDHWC) && new_size.size() < NCDHW_OR_NDHWC_DIM) { + npuDesc.storage_sizes_ = new_size; + } else { + npuDesc.storage_sizes_ = FormatHelper::GetStorageSizes(npuDesc); + } if (new_data_numel > new_shape_numel) { // Refresh format to base format only when flattening storage data npuDesc.storage_sizes_ = new_size;