From b680a4ea8356360c1cc8aef7fa28aa1fe0d04cf9 Mon Sep 17 00:00:00 2001 From: "zhousinan@huawei.com" Date: Wed, 16 Feb 2022 18:25:48 +0800 Subject: [PATCH 1/6] add memorypool --- test/test_tensor.py | 245 ++++++++++++++++++ torch_npu/csrc/InitNpuBindings.cpp | 4 +- torch_npu/csrc/aten/common/SetNpu.cpp | 4 +- .../csrc/aten/common/TensorFactories.cpp | 17 +- torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp | 6 +- .../BatchNormBackwardKernelNpu.cpp | 2 +- .../ops/normalization/BatchNormKernelNpu.cpp | 2 +- .../csrc/distributed/ProcessGroupHCCL.cpp | 12 +- torch_npu/csrc/framework/OpParamMaker.cpp | 2 +- .../csrc/framework/utils/CalcuOpUtil.cpp | 2 +- torch_npu/csrc/framework/utils/NpuUtils.cpp | 2 +- torch_npu/csrc/framework/utils/NpuUtils.h | 2 +- torch_npu/csrc/npu/Module.cpp | 36 +-- 13 files changed, 289 insertions(+), 47 deletions(-) create mode 100644 test/test_tensor.py diff --git a/test/test_tensor.py b/test/test_tensor.py new file mode 100644 index 0000000000..71ee4254f4 --- /dev/null +++ b/test/test_tensor.py @@ -0,0 +1,245 @@ +import torch +import torch_npu +import tempfile + +from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_device_type import device_type_test_bases, DeviceTypeTestBase, onlyOn, dtypes, instantiate_device_type_tests +from itertools import product, combinations, combinations_with_replacement, permutations + + +def onlyNPU(fn): + return onlyOn('npu')(fn) + + +class NPUTestBase(DeviceTypeTestBase): + device_type = 'npu' + +device_type_test_bases.append(NPUTestBase) + +class TestTensor(TestCase): + @onlyNPU + def test_narrow_empty(self, device): + x = torch.randn(2, 3, 4).to(device=device) + for d in range(x.dim()): + y = x.narrow(d, x.size(d), 0) + sz = list(x.size()) + sz[d] = 0 + self.assertEqual(sz, y.size()) + + @onlyNPU + def test_tensor_set(self, device): + t1 = torch.Tensor() + t2 = torch.Tensor(3, 4, 9, 10).uniform_() + t1.set_(t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + size = torch.Size([9, 3, 4, 10]) + t1.set_(t2.storage(), 0, size) + self.assertEqual(t1.size(), size) + t1.set_(t2.storage(), 0, tuple(size)) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), (120, 40, 10, 1)) + stride = (10, 360, 90, 1) + t1.set_(t2.storage(), 0, size, stride) + self.assertEqual(t1.stride(), stride) + t1.set_(t2.storage(), 0, size=size, stride=stride) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), stride) + + # test argument names + t1 = torch.Tensor() + # 1. case when source is tensor + t1.set_(source=t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + # 2. case when source is storage + t1.set_(source=t2.storage()) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + # 3. 
case when source is storage, and other args also specified + t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), stride) + + t1 = torch.tensor([True, True], dtype=torch.bool) + t2 = torch.tensor([False, False], dtype=torch.bool) + t1.set_(t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + @onlyNPU + @dtypes(torch.half, torch.float) + def test_cat_all_dtypes_and_devices(self, device, dtype): + x = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device) + + expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dtype, device=device) + self.assertEqual(torch.cat((x, x), 0).to('cpu'), expected1.to('cpu')) + + expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dtype, device=device) + self.assertEqual(torch.cat((x, x), 1).to('cpu'), expected2.to('cpu')) + + @onlyNPU + def test_cat_mem_overlap(self, device): + x = torch.rand((1, 3)).to(device).expand((6, 3)) + y = torch.rand((3, 3)).to(device) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + torch.cat([y, y], out=x) + + @onlyNPU + def test_cat(self, device): + SIZE = 10 + for dim in range(-3, 3): + pos_dim = dim if dim >= 0 else 3 + dim + x = torch.rand(13, SIZE, SIZE).to(device).transpose(0, pos_dim) + y = torch.rand(17, SIZE, SIZE).to(device).transpose(0, pos_dim) + z = torch.rand(19, SIZE, SIZE).to(device).transpose(0, pos_dim) + + res1 = torch.cat((x, y, z), dim) + self.assertEqual(res1.narrow(pos_dim, 0, 13).to('cpu'), x.to('cpu'), atol=0, rtol=0) + self.assertEqual(res1.narrow(pos_dim, 13, 17).to('cpu'), y.to('cpu'), atol=0, rtol=0) + self.assertEqual(res1.narrow(pos_dim, 30, 19).to('cpu'), z.to('cpu'), atol=0, rtol=0) + + x = torch.randn(20, SIZE, SIZE).to(device) + self.assertEqual(torch.cat(torch.split(x, 7)).to('cpu'), x.to('cpu')) + self.assertEqual(torch.cat(torch.chunk(x, 7)).to('cpu'), x.to('cpu')) + + y = torch.randn(1, SIZE, SIZE).to(device) + z = torch.cat([x, y]) + self.assertEqual(z.size(), (21, SIZE, SIZE)) + + + # TODO: this test should be updated + @onlyNPU + def test_zeros(self, device): + res1 = torch.zeros(100, 100, device=device) + res2 = torch.tensor((), device=device) + torch.zeros(100, 100, device=device, out=res2) + + self.assertEqual(res1.to('cpu'), res2.to('cpu')) + + boolTensor = torch.zeros(2, 2, device=device, dtype=torch.bool) + expected = torch.tensor([[False, False], [False, False]], + device=device, dtype=torch.bool) + self.assertEqual(boolTensor.to('cpu'), expected.to('cpu')) + + halfTensor = torch.zeros(1, 1, device=device, dtype=torch.half) + expected = torch.tensor([[0.]], device=device, dtype=torch.float16) + self.assertEqual(halfTensor.to('cpu'), expected.to('cpu')) + + bfloat16Tensor = torch.zeros(1, 1, device=device, dtype=torch.half) + expected = torch.tensor([[0.]], device=device, dtype=torch.half) + self.assertEqual(bfloat16Tensor.to('cpu'), expected.to('cpu')) + + # TODO: this test should be updated + @onlyNPU + def test_zeros_out(self, device): + shape = (3, 4) + out = torch.zeros(shape, device=device) + torch.zeros(shape, device=device, out=out) + + # change the dtype, layout, device + with self.assertRaises(RuntimeError): + torch.zeros(shape, device=device, dtype=torch.int64, out=out) + with self.assertRaises(RuntimeError): + torch.zeros(shape, device=device, layout=torch.sparse_coo, out=out) + + # leave them the same + self.assertEqual(torch.zeros(shape, device=device).to('cpu'), + torch.zeros(shape, device=device, dtype=out.dtype, 
out=out).to('cpu')) + self.assertEqual(torch.zeros(shape, device=device).to('cpu'), + torch.zeros(shape, device=device, layout=torch.strided, out=out).to('cpu')) + self.assertEqual(torch.zeros(shape, device=device).to('cpu'), + torch.zeros(shape, device=device, out=out).to('cpu')) + + # TODO: this test should be updated + @onlyNPU + def test_ones(self, device): + res1 = torch.ones(100, 100, device=device) + res2 = torch.tensor((), device=device) + torch.ones(100, 100, device=device, out=res2) + self.assertEqual(res1.to('cpu'), res2.to('cpu')) + + # test boolean tensor + res1 = torch.ones(1, 2, device=device, dtype=torch.bool) + expected = torch.tensor([[True, True]], device=device, dtype=torch.bool) + self.assertEqual(res1.to('cpu'), expected.to('cpu')) + + @onlyNPU + def test_empty_strided(self, device): + for shape in [(2, 3, 4), (0, 2, 0)]: + # some of these cases are pretty strange, just verifying that if as_strided + # allows them then empty_strided can as well. + for strides in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]: + empty_strided = torch.empty_strided(shape, strides, device=device) + # as_strided checks the storage size is big enough to support such a strided tensor; + # instead of repeating this calculation, we just use empty_strided which does the same + # calculation when setting the storage size. + as_strided = torch.empty(empty_strided.storage().size(), + device=device).as_strided(shape, strides) + self.assertEqual(empty_strided.shape, as_strided.shape) + self.assertEqual(empty_strided.stride(), as_strided.stride()) + @onlyNPU + def test_empty_tensor_props(self, device): + sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)] + for size in sizes: + x = torch.empty(tuple(size), device=device) + self.assertEqual(size, x.shape) + self.assertTrue(x.is_contiguous()) + size_ones_instead_of_zeros = (x if x != 0 else 1 for x in size) + y = torch.empty(tuple(size_ones_instead_of_zeros), device=device) + self.assertEqual(x.stride(), y.stride()) + + @onlyNPU + @dtypes(torch.half, torch.float) + def test_full_inference(self, device, dtype): + size = (2, 2) + + prev_default = torch.get_default_dtype() + torch.set_default_dtype(dtype) + + # Tests bool fill value inference + t = torch.full(size, True) + self.assertEqual(t.dtype, torch.bool) + + # Tests integer fill value inference + t = torch.full(size, 1) + self.assertEqual(t.dtype, torch.long) + + # Tests float fill value inference + t = torch.full(size, 1.) 
+ self.assertEqual(t.dtype, dtype) + + torch.set_default_dtype(prev_default) + + @onlyNPU + def test_full_out(self, device): + size = (5,) + o = torch.empty(size, device=device, dtype=torch.long) + + # verifies dtype/out conflict throws a RuntimeError + with self.assertRaises(RuntimeError): + torch.full(o.shape, 1., dtype=torch.float, out=o) + + # verifies out dtype overrides inference + self.assertEqual(torch.full(o.shape, 1., out=o).dtype, o.dtype) + self.assertEqual(torch.full(size, 1, out=o).dtype, o.dtype) + + # TODO: this test should be updated + @onlyNPU + def test_ones_like(self, device): + expected = torch.ones(100, 100, device=device) + + res1 = torch.ones_like(expected) + self.assertEqual(res1.to('cpu'), expected.to('cpu')) + + # test boolean tensor + expected = torch.tensor([True, True], device=device, dtype=torch.bool) + res1 = torch.ones_like(expected) + self.assertEqual(res1.to('cpu'), expected.to('cpu')) + + @onlyNPU + def test_zeros_like(self, device): + expected = torch.zeros((100, 100,), device=device) + + res1 = torch.zeros_like(expected) + self.assertEqual(res1.to('cpu'), expected.to('cpu')) + +instantiate_device_type_tests(TestTensor, globals(), only_for='npu') + +if __name__ == '__main__': + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index 9dd12dd179..0774cc72c1 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -17,7 +17,7 @@ #include #include -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include #include @@ -53,7 +53,7 @@ PyObject * THPModule_npu_shutdown(PyObject * /* unused */) if (c10::npu::NpuSysCtrl::GetInstance().GetInitFlag()) { c10::npu::npuSynchronizeDevice(); THNPUCachingHostAllocator_emptyCache(); - c10::npu::NPUCachingAllocator::emptyCache(); + c10_npu::emptyCache(); c10::npu::NpuSysCtrl::SysStatus status = c10::npu::NpuSysCtrl::GetInstance().Finalize(); if (status != c10::npu::NpuSysCtrl::SysStatus::FINALIZE_SUCC) { fprintf(stdout, "THPModule_npu_shutdown failed.\n"); diff --git a/torch_npu/csrc/aten/common/SetNpu.cpp b/torch_npu/csrc/aten/common/SetNpu.cpp index 5b6dc25c49..f69d25aae5 100644 --- a/torch_npu/csrc/aten/common/SetNpu.cpp +++ b/torch_npu/csrc/aten/common/SetNpu.cpp @@ -18,8 +18,8 @@ #include #include #include -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/aten/common/ResizeNpu.h" #include "torch_npu/csrc/framework/StorageDescHelper.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" @@ -32,7 +32,7 @@ c10::StorageImpl* storage_new_npu(caffe2::TypeMeta data_type) { c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), 0, - at::npu::NPUCachingAllocator::get(), + c10_npu::get(), true) .release(); return storage; diff --git a/torch_npu/csrc/aten/common/TensorFactories.cpp b/torch_npu/csrc/aten/common/TensorFactories.cpp index c2a9ae6eab..d7b8d73c20 100644 --- a/torch_npu/csrc/aten/common/TensorFactories.cpp +++ b/torch_npu/csrc/aten/common/TensorFactories.cpp @@ -29,9 +29,9 @@ #include #include #include -#include #include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/aten/common/ResizeNpu.h" #include "torch_npu/csrc/framework/StorageDescHelper.h" #include "torch_npu/csrc/framework/InferFormat.h" @@ -84,7 +84,7 @@ namespace at_npu AT_ASSERT(c10::device_or_default(device_opt).type() == at::DeviceType::NPU); TORCH_CHECK(!pinned_memory_or_default(pin_memory_opt), "Only dense CPU tensors can be pinned"); 
check_size_nonnegative(size); - c10::Allocator *allocator = at::npu::NPUCachingAllocator::get(); + c10::Allocator *allocator = c10_npu::get(); int64_t nelements = at::prod_intlist(size); auto dtype = c10::scalarTypeToTypeMeta(dtype_or_default(dtype_opt)); int64_t size_bytes = nelements * dtype.itemsize(); @@ -272,7 +272,7 @@ namespace at_npu AT_ASSERT(c10::device_or_default(device_opt).type() == at::DeviceType::NPU); TORCH_CHECK(!pinned_memory_or_default(pin_memory_opt), "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); - c10::Allocator *allocator = at::npu::NPUCachingAllocator::get(); + c10::Allocator *allocator = c10_npu::get(); // when the shape and format are not match, fix format here. aclFormat format = InferFormat::GuessStorageFormat(size, (aclFormat)dst_format); int64_t nelements = StorageDescHelper::GetMemorySize(size, format); @@ -306,7 +306,7 @@ namespace at_npu AT_ASSERT(options.backend() == at::Backend::NPU); TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); - c10::Allocator *allocator = at::npu::NPUCachingAllocator::get(); + c10::Allocator *allocator = c10_npu::get(); // when the shape and format are not match, fix format here. aclFormat format = InferFormat::GuessStorageFormat(size, (aclFormat)dst_format); int64_t nelements = StorageDescHelper::GetMemorySize(size, format); @@ -347,8 +347,7 @@ namespace at_npu options.device(device); options.layout(layout_opt); options.pinned_memory(pin_memory_opt); - at::Tensor result = - OpPreparation::ApplyTensorWithFormat(size, options, dst_format); + at::Tensor result = OpPreparation::ApplyTensorWithFormat(size, options, dst_format); if (names.has_value()) { internal_set_names_inplace(result, names); @@ -361,8 +360,7 @@ namespace at_npu const c10::TensorOptions &options, int64_t dst_format) { - at::Tensor result = - OpPreparation::ApplyTensorWithFormat(size, options, dst_format); + at::Tensor result = OpPreparation::ApplyTensorWithFormat(size, options, dst_format); if (names.has_value()) { internal_set_names_inplace(result, names); @@ -376,8 +374,7 @@ namespace at_npu const c10::TensorOptions &options, int64_t dst_format) { - at::Tensor result = - OpPreparation::ApplyTensorWithFormat(size, options, dst_format); + at::Tensor result = OpPreparation::ApplyTensorWithFormat(size, options, dst_format); if (names.has_value()) { internal_set_names_inplace(result, names); diff --git a/torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp b/torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp index d56497a1a4..1335693ee7 100644 --- a/torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp @@ -23,9 +23,9 @@ at::Tensor& NPUNativeFunctions::addmv_out( const at::Tensor& vec, at::Scalar beta, at::Scalar alpha, - at::Tensor& result) { + at::Tensor& result) { NpuUtils::check_1d(vec, "vec", "addmv"); - + at::Tensor mat1 = vec.unsqueeze(1); // matmul mat*alpha @@ -33,7 +33,7 @@ at::Tensor& NPUNativeFunctions::addmv_out( // matmul*alpha at::Tensor mmMulResult = at::mm(mat_alpha, mat1); - + at::Tensor mmMulResult1 = mmMulResult.squeeze(); // calculate the output size diff --git a/torch_npu/csrc/aten/ops/normalization/BatchNormBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/normalization/BatchNormBackwardKernelNpu.cpp index 80c0391d8f..135397f761 100644 --- a/torch_npu/csrc/aten/ops/normalization/BatchNormBackwardKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/normalization/BatchNormBackwardKernelNpu.cpp @@ -15,7 +15,7 @@ // limitations under the License. 
#include -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/framework/utils/OpAdapter.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" diff --git a/torch_npu/csrc/aten/ops/normalization/BatchNormKernelNpu.cpp b/torch_npu/csrc/aten/ops/normalization/BatchNormKernelNpu.cpp index df44becf83..0d0f82bea9 100644 --- a/torch_npu/csrc/aten/ops/normalization/BatchNormKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/normalization/BatchNormKernelNpu.cpp @@ -16,7 +16,7 @@ #include #include -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/framework/utils/OpAdapter.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index 77c68f31ff..cc7f30d15d 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include #include #include @@ -460,7 +460,7 @@ c10::intrusive_ptr ProcessGroupHCCL::collective( // operations where `inputs' and `outputs' are not the same. // // See [Sync Streams]. - c10::npu::NPUCachingAllocator::recordStream( + c10_npu::recordStream( inputs[i].storage().data_ptr(), hcclStream); } { @@ -578,7 +578,7 @@ c10::intrusive_ptr ProcessGroupHCCL::allgather( HcclComm comm, c10::npu::NPUStream& stream) { RECORD_FUNCTION("HcclAllgather", std::vector({input})); - c10::npu::NPUCachingAllocator::recordStream( + c10_npu::recordStream( output.storage().data_ptr(), stream); return HcclAllGather( input.data_ptr(), @@ -595,7 +595,7 @@ c10::intrusive_ptr ProcessGroupHCCL::allgather( c10::npu::NPUStreamGuard guard(hcclStreams[i]); for (size_t j = 0; j < outputTensors[0].size(); ++j) { // See [Sync Streams]. - c10::npu::NPUCachingAllocator::recordStream( + c10_npu::recordStream( outputTensors[i][j].storage().data_ptr(), hcclStreams[i]); outputTensors[i][j].copy_(outputFlattened[i][j], true); @@ -629,7 +629,7 @@ c10::intrusive_ptr ProcessGroupHCCL::reduce_scatter( HcclComm comm, c10::npu::NPUStream& stream) { RECORD_FUNCTION("HcclReduceScatter", std::vector({input})); - c10::npu::NPUCachingAllocator::recordStream( + c10_npu::recordStream( output.storage().data_ptr(), stream); return HcclReduceScatter( input.data_ptr(), @@ -646,7 +646,7 @@ c10::intrusive_ptr ProcessGroupHCCL::reduce_scatter( c10::npu::NPUStreamGuard guard(hcclStreams[i]); for (size_t j = 0; j < inputTensors[0].size(); ++j) { // See [Sync Streams]. 
- c10::npu::NPUCachingAllocator::recordStream( + c10_npu::recordStream( inputTensors[i][j].storage().data_ptr(), hcclStreams[i]); inputFlattened[i][j].copy_(inputTensors[i][j], true); diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index 3983af3294..9d4d9f3afa 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -15,7 +15,7 @@ #include "torch_npu/csrc/register/OptionsManager.h" #include -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include #include "torch_npu/csrc/framework/aoe/AoeUtils.h" diff --git a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp index 71bd52d346..899fe33d88 100644 --- a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp +++ b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp @@ -16,7 +16,7 @@ #include #include -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/register/OptionsManager.h" #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" diff --git a/torch_npu/csrc/framework/utils/NpuUtils.cpp b/torch_npu/csrc/framework/utils/NpuUtils.cpp index 07bda7c6a6..01194ff2af 100644 --- a/torch_npu/csrc/framework/utils/NpuUtils.cpp +++ b/torch_npu/csrc/framework/utils/NpuUtils.cpp @@ -357,7 +357,7 @@ namespace at_npu if (index == 1) { C10_NPU_CHECK(aclrtGetDevice(&deviceId)); - c10::npu::NPUCachingAllocator::FreeDeviceCachedMemory(deviceId); + c10_npu::FreeDeviceCachedMemory(deviceId); return true; } AT_ERROR("NPU out of memory. device id: ", deviceId); diff --git a/torch_npu/csrc/framework/utils/NpuUtils.h b/torch_npu/csrc/framework/utils/NpuUtils.h index 1abaeb2851..4c291abd5f 100644 --- a/torch_npu/csrc/framework/utils/NpuUtils.h +++ b/torch_npu/csrc/framework/utils/NpuUtils.h @@ -18,7 +18,7 @@ #define __PULGIN_NATIVE_NPU_UTILS_NUP_UTILS__ #include -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include #include #include diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index c4c36a1726..dee9b8ac2c 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -1,5 +1,5 @@ // Copyright (c) 2020 Huawei Technologies Co., Ltd -// Copyright (c) 2019, Facebook CORPORATION. +// Copyright (c) 2019, Facebook CORPORATION. // All rights reserved. 
// // Licensed under the BSD 3-Clause License (the "License"); @@ -21,7 +21,7 @@ #include #include #include -#include +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include #include #include @@ -171,7 +171,7 @@ PyObject * THNPModule_setStream_wrap(PyObject *self, PyObject *obj) PyObject * THNPModule_emptyCache(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS - c10::npu::NPUCachingAllocator::emptyCache(); + c10_npu::emptyCache(); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -182,10 +182,10 @@ PyObject * THNPModule_memoryStats(PyObject *_unused, PyObject *arg) THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const int device = (int) THPUtils_unpackLong(arg); - using c10::npu::NPUCachingAllocator::StatType; - using c10::npu::NPUCachingAllocator::Stat; - using c10::npu::NPUCachingAllocator::StatArray; - using c10::npu::NPUCachingAllocator::DeviceStats_; + using c10_npu::StatType; + using c10_npu::Stat; + using c10_npu::StatArray; + using c10_npu::DeviceStats_; const auto statToDict = [](const Stat& stat) { py::dict dict; @@ -208,7 +208,7 @@ PyObject * THNPModule_memoryStats(PyObject *_unused, PyObject *arg) return dict; }; - const DeviceStats_ stats = c10::npu::NPUCachingAllocator::getDeviceStats(device); + const DeviceStats_ stats = c10_npu::getDeviceStats(device); py::dict result; result["num_alloc_retries"] = stats.num_alloc_retries; @@ -231,7 +231,7 @@ PyObject * THNPModule_resetAccumulatedMemoryStats(PyObject *_unused, PyObject *a HANDLE_TH_ERRORS THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_accumulated_memory_stats"); const int device = (int) THPUtils_unpackLong(arg); - c10::npu::NPUCachingAllocator::resetAccumulatedStats(device); + c10_npu::resetAccumulatedStats(device); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -241,7 +241,7 @@ PyObject * THNPModule_resetPeakMemoryStats(PyObject *_unused, PyObject *arg) HANDLE_TH_ERRORS THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats"); const int device = (int) THPUtils_unpackLong(arg); - c10::npu::NPUCachingAllocator::resetPeakStats(device); + c10_npu::resetPeakStats(device); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -250,8 +250,8 @@ PyObject * THNPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS - using c10::npu::NPUCachingAllocator::SegmentInfo; - using c10::npu::NPUCachingAllocator::BlockInfo; + using c10_npu::SegmentInfo; + using c10_npu::BlockInfo; const auto segmentInfoToDict = [](const SegmentInfo& segmentInfo) { py::dict segmentDict; @@ -274,7 +274,7 @@ PyObject * THNPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) return segmentDict; }; - const std::vector& snapshot = c10::npu::NPUCachingAllocator::snapshot(); + const std::vector& snapshot = c10_npu::snapshot(); py::list result; for (const auto& segmentInfo : snapshot) { @@ -300,7 +300,7 @@ PyObject * THNPModule_npuCachingAllocator_raw_alloc(PyObject *_unused, PyObject } ssize_t size = PyLong_AsSsize_t(size_o); aclrtStream stream = static_cast(PyLong_AsVoidPtr(stream_o)); - void* mem = c10::npu::NPUCachingAllocator::raw_alloc_with_stream(size, stream); + void* mem = c10_npu::raw_alloc_with_stream(size, stream); return PyLong_FromVoidPtr(mem); END_HANDLE_TH_ERRORS } @@ -308,7 +308,7 @@ PyObject * THNPModule_npuCachingAllocator_raw_alloc(PyObject *_unused, PyObject PyObject * THNPModule_npuCachingAllocator_raw_delete(PyObject *_unused, PyObject *obj){ HANDLE_TH_ERRORS void* mem_ptr = PyLong_AsVoidPtr(obj); - 
c10::npu::NPUCachingAllocator::raw_delete(mem_ptr); + c10_npu::raw_delete(mem_ptr); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -322,7 +322,7 @@ static PyGILState_STATE npuMutexGILState; PyObject * THNPModule_npuLockMutex(PyObject *module, PyObject *noargs) { - auto mutex = c10::npu::NPUCachingAllocator::getFreeMutex(); + auto mutex = c10_npu::getFreeMutex(); // This has to be a busy loop because we **absolutely need to** hold the GIL // or it's a recipe for a deadlock otherwise (if we let other Python threads // run while we have the cudaMutex, but not the GIL, they might try to e.g. @@ -343,7 +343,7 @@ PyObject * THNPModule_npuLockMutex(PyObject *module, PyObject *noargs) PyObject * THNPModule_npuUnlockMutex(PyObject *module, PyObject *noargs) { - auto mutex = c10::npu::NPUCachingAllocator::getFreeMutex(); + auto mutex = c10_npu::getFreeMutex(); PyGILState_Release(npuMutexGILState); mutex->unlock(); Py_RETURN_NONE; @@ -468,7 +468,7 @@ static struct PyMethodDef THNPModule_methods[] = { {"_npu_getDeviceCount", (PyCFunction)THNPModule_getDeviceCount_wrap, METH_NOARGS, nullptr}, {"_npu_getCurrentStream", (PyCFunction)THNPModule_getCurrentStream_wrap, METH_O, nullptr}, {"_npu_getDefaultStream", (PyCFunction)THNPModule_getDefaultStream_wrap, METH_O, nullptr}, - {"_npu_setStream", (PyCFunction)THNPModule_setStream_wrap, METH_O, nullptr}, + {"_npu_setStream", (PyCFunction)THNPModule_setStream_wrap, METH_O, nullptr}, {"_npu_setStream", (PyCFunction)THNPModule_setStream_wrap, METH_O, nullptr}, {"_npu_emptyCache", (PyCFunction) THNPModule_emptyCache, METH_NOARGS, nullptr}, {"_npu_memoryStats", (PyCFunction) THNPModule_memoryStats, METH_O, nullptr}, -- Gitee From 1e574e90468272dceae6d3450014b1d1779de882 Mon Sep 17 00:00:00 2001 From: "zhousinan@huawei.com" Date: Wed, 16 Feb 2022 18:44:47 +0800 Subject: [PATCH 2/6] add test/test_tensor.py --- test/test_tensor.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index 71ee4254f4..7a861a979b 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,9 +1,10 @@ +import tempfile import torch import torch_npu -import tempfile from torch.testing._internal.common_utils import TestCase, run_tests -from torch.testing._internal.common_device_type import device_type_test_bases, DeviceTypeTestBase, onlyOn, dtypes, instantiate_device_type_tests +from torch.testing._internal.common_device_type import device_type_test_bases, \ + DeviceTypeTestBase, onlyOn, dtypes, instantiate_device_type_tests from itertools import product, combinations, combinations_with_replacement, permutations @@ -17,6 +18,7 @@ class NPUTestBase(DeviceTypeTestBase): device_type_test_bases.append(NPUTestBase) class TestTensor(TestCase): + @onlyNPU def test_narrow_empty(self, device): x = torch.randn(2, 3, 4).to(device=device) @@ -62,6 +64,7 @@ class TestTensor(TestCase): t2 = torch.tensor([False, False], dtype=torch.bool) t1.set_(t2) self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + @onlyNPU @dtypes(torch.half, torch.float) def test_cat_all_dtypes_and_devices(self, device, dtype): @@ -102,8 +105,8 @@ class TestTensor(TestCase): z = torch.cat([x, y]) self.assertEqual(z.size(), (21, SIZE, SIZE)) - # TODO: this test should be updated + @onlyNPU def test_zeros(self, device): res1 = torch.zeros(100, 100, device=device) @@ -126,6 +129,7 @@ class TestTensor(TestCase): self.assertEqual(bfloat16Tensor.to('cpu'), expected.to('cpu')) # TODO: this test should be updated + @onlyNPU def test_zeros_out(self, device): shape = 
(3, 4) @@ -147,6 +151,7 @@ class TestTensor(TestCase): torch.zeros(shape, device=device, out=out).to('cpu')) # TODO: this test should be updated + @onlyNPU def test_ones(self, device): res1 = torch.ones(100, 100, device=device) @@ -220,6 +225,7 @@ class TestTensor(TestCase): self.assertEqual(torch.full(size, 1, out=o).dtype, o.dtype) # TODO: this test should be updated + @onlyNPU def test_ones_like(self, device): expected = torch.ones(100, 100, device=device) -- Gitee From a57d794c50f2aefcfd575158cad6adce0b850fcd Mon Sep 17 00:00:00 2001 From: "zhousinan@huawei.com" Date: Wed, 16 Feb 2022 19:33:53 +0800 Subject: [PATCH 3/6] add torch_npu/csrc/core/npu/NPUCachingAllocator.h --- .../csrc/core/npu/NPUCachingAllocator.cpp | 1177 +++++++++++++++++ torch_npu/csrc/core/npu/NPUCachingAllocator.h | 142 ++ 2 files changed, 1319 insertions(+) create mode 100644 torch_npu/csrc/core/npu/NPUCachingAllocator.cpp create mode 100644 torch_npu/csrc/core/npu/NPUCachingAllocator.h diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp new file mode 100644 index 0000000000..dddf79cdab --- /dev/null +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -0,0 +1,1177 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" + +namespace c10_npu { + +C10_DEFINE_REGISTRY(FreeNPUMemoryCallbacksRegistry, FreeMemoryCallback); + +// +// Yet another caching allocator for NPU device allocations. +// +// - Allocations are associated with a stream. Once freed, blocks can be +// re-allocated on the same stream, but not on any other stream. +// - The allocator attempts to find the smallest cached block that will fit the +// requested size. If the block is larger than the requested size, it may be +// split. If no block is found, the allocator will delegate to npuMalloc. +// - If the npuMalloc fails, the allocator will free all cached blocks that +// are not split and retry the allocation. +// - Large (>1MB) and small allocations are stored in separate pools. +// Small requests are packed into 2MB buffers. Large requests will use the +// smallest available free block or allocate a new block using npuMalloc. +// To reduce fragmentation, requests between 1MB and 10MB will allocate and +// split a 20MB block, if no free block of sufficient size is available. +// +// With this allocator, allocations and frees should logically be considered +// "usages" of the memory segment associated with streams, just like kernel +// launches. The programmer must insert the proper synchronization if memory +// segments are used from multiple streams. 
+// +// The library provides a recordStream() function to help insert the correct +// synchronization when allocations are used on multiple streams. This will +// ensure that the block is not reused before each recorded stream completes +// work. +// +namespace { +using stream_set = std::unordered_set; + +constexpr size_t kMinBlockSize = + 512; // all sizes are rounded to at least 512 bytes +constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB +constexpr size_t kSmallBuffer = + 2097152; // "small" allocations are packed in 2 MiB blocks +constexpr size_t kLargeBuffer = + 20971520; // "large" allocations may be packed in 20 MiB blocks +constexpr size_t kMinLargeAlloc = + 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer +constexpr size_t kRoundLarge = 2097152; // round up large allocs to 2 MiB + +typedef std::bitset(StatType::NUM_TYPES)> StatTypes; + +void update_stat(Stat& stat, int64_t amount) { + stat.current += amount; + stat.peak = std::max(stat.current, stat.peak); + if (amount > 0) { + stat.allocated += amount; + } + if (amount < 0) { + stat.freed += -amount; + } +} + +void reset_accumulated_stat(Stat& stat) { + stat.allocated = 0; + stat.freed = 0; +} + +void reset_peak_stat(Stat& stat) { + stat.peak = stat.current; +} + +void update_stat_array( + StatArray& stat_array, + int64_t amount, + const StatTypes& stat_types) { + for (size_t stat_type = 0; stat_type < stat_types.size(); ++stat_type) { + if (stat_types[stat_type]) { + update_stat(stat_array[stat_type], amount); + } + } +} + +struct DeviceStats { + uint64_t amount_allocated; // total amount allocated in bytes + uint64_t max_amount_allocated; // max total amount allocated in bytes + uint64_t amount_cached; // total amount in cache in bytes + uint64_t max_amount_cached; // max total amount in cache in bytes + + DeviceStats() + : amount_allocated(0), + max_amount_allocated(0), + amount_cached(0), + max_amount_cached(0) {} + + void increaseAllocated(size_t delta) { + amount_allocated += delta; + max_amount_allocated = std::max(max_amount_allocated, amount_allocated); + } + + void decreaseAllocated(size_t delta) { + amount_allocated -= delta; + } + + void increaseCached(size_t delta) { + amount_cached += delta; + max_amount_cached = std::max(max_amount_cached, amount_cached); + } + + void decreaseCached(size_t delta) { + amount_cached -= delta; + } +}; + +struct Block; +using Comparison = bool (*)(const Block*, const Block*); +using BlockPool = std::set; + +struct Block { + int device; // npu + aclrtStream stream; // allocation stream + stream_set stream_uses; // streams on which the block was used + size_t size; // block size in bytes + BlockPool* pool; // owning memory pool + void* ptr; // memory address + bool allocated; // in-use flag + Block* prev; // prev block if split from a larger allocation + Block* next; // next block if split from a larger allocation + int event_count; // number of outstanding NPU events + + Block(int device, aclrtStream stream, size_t size, BlockPool* pool, void* ptr) + : device(device), + stream(stream), + stream_uses(), + size(size), + pool(pool), + ptr(ptr), + allocated(0), + prev(nullptr), + next(nullptr), + event_count(0) {} + + // constructor for search key + Block(int device, aclrtStream stream, size_t size) + : device(device), + stream(stream), + stream_uses(), + size(size), + pool(nullptr), + ptr(nullptr), + allocated(0), + prev(nullptr), + next(nullptr), + event_count(0) {} + + bool is_split() const { + return (prev != nullptr) || (next != nullptr); + 
} +}; + +static bool BlockComparator(const Block* a, const Block* b) { + if (a->device != b->device) { + return a->device < b->device; + } + if (a->stream != b->stream) { + return reinterpret_cast(a->stream) < + reinterpret_cast(b->stream); + } + if (a->size != b->size) { + return a->size < b->size; + } + return reinterpret_cast(a->ptr) < + reinterpret_cast(b->ptr); +} + +static std::string format_size(uint64_t size) { + std::ostringstream os; + os.precision(2); + os << std::fixed; + if (size <= 1024) { + os << size << " bytes"; + } else if (size <= 1048576) { + os << (size / 1024.0); + os << " KiB"; + } else if (size <= 1073741824ULL) { + os << (size / 1048576.0); + os << " MiB"; + } else { + os << (size / 1073741824.0); + os << " GiB"; + } + return os.str(); +} +} // namespace + +struct THNCachingAllocator { + // device statistics + std::vector device_stats; + std::vector device_stats_; + + // lock around all operations + mutable std::recursive_mutex mutex; + + // lock around calls to aclFree (to prevent deadlocks with NCCL) + mutable std::mutex npu_free_mutex; + + // cached blocks larger than 1 MB + BlockPool large_blocks; + + // cached blocks 1 MB or smaller + BlockPool small_blocks; + + // allocated blocks by device pointer + std::unordered_map allocated_blocks; + + // outstanding acl events + std::deque> npu_events; + + THNCachingAllocator() + : large_blocks(BlockComparator), small_blocks(BlockComparator) {} + + DeviceStats& get_stats_for_device(int device) { + AT_ASSERT(device >= 0); + if ((size_t)device >= device_stats.size()) { + device_stats.resize(device + 1); + } + return device_stats.at(device); + } + + DeviceStats_& get_stats_for_device_(int device) { + AT_ASSERT(device >= 0); + if ((size_t)device >= device_stats_.size()) { + device_stats_.resize(device + 1); + } + return device_stats_.at(device); + } + + /** allocates a block which is safe to use from the provided stream */ + void malloc(void** devPtr, size_t size, aclrtStream stream) { + std::lock_guard lock(mutex); + int device = 0; + C10_NPU_CHECK(aclrtGetDevice(&device)); + // process outstanding npuEvents + process_events(); + size = round_size(size); + DeviceStats& stats = get_stats_for_device(device); + + Block search_key(device, stream, size); + auto& pool = get_pool(size); + + DeviceStats_& stats_ = get_stats_for_device_(device); + StatTypes stat_types; + stat_types[static_cast(StatType::AGGREGATE)] = true; + stat_types[static_cast(get_stat_type_for_pool(pool))] = true; + + auto find_free_block = [&]() -> Block* { + auto it = pool.lower_bound(&search_key); + if (it != pool.end() && (*it)->device == device && + (*it)->stream == stream) { + Block* block = *it; + pool.erase(it); + return block; + } + return nullptr; + }; + + Block* block = find_free_block(); + if (block == nullptr) { + bool freed_memory = false; + for (const auto& name : FreeNPUMemoryCallbacksRegistry()->Keys()) { + freed_memory |= + FreeNPUMemoryCallbacksRegistry()->Create(name)->Execute(); + } + if (freed_memory) { + block = find_free_block(); + } + } + if (block == nullptr) { + void* ptr = nullptr; + size_t alloc_size = get_allocation_size(size); + aclError err = npu_malloc_retry(device, &ptr, alloc_size); + + if (err != ACL_ERROR_NONE) { + if (err == ACL_ERROR_RT_MEMORY_ALLOCATION) { + size_t device_free; + size_t device_total; + C10_NPU_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, &device_free, &device_total)); + + const auto& stats = get_stats_for_device(device); + + stats_.num_ooms += 1; + // "total capacity": total global memory on NPU + // "already 
allocated": memory allocated by the program using the + // caching allocator + // "free": free memory as reported by the NPU API + // "cached": memory held by the allocator but not used by the program + // + // The "allocated" amount does not include memory allocated outside + // of the caching allocator, such as memory allocated by other + // programs or memory held by the driver. + // + // The sum of "allocated" + "free" + "cached" may be less than the + // total capacity due to memory held by the driver and usage by other + // programs. + // + // Note that at this point npu_malloc_retry has already returned all + // possible "cached" memory to the driver. The only remaining "cached" + // memory is split from a larger block that is partially in-use. + AT_ERROR( + "NPU out of memory. Tried to allocate ", + format_size(alloc_size), + " (NPU ", + device, + "; ", + format_size(device_total), + " total capacity; ", + format_size(stats.amount_allocated), + " already allocated; ", + format_size(device_free), + " free; ", + format_size(stats.amount_cached - stats.amount_allocated), + " cached)"); + } else { + C10_NPU_CHECK(err); + } + } + stats.increaseCached(alloc_size); + block = new Block(device, stream, alloc_size, &pool, ptr); + + update_stat_array(stats_.segment, 1, stat_types); + update_stat_array(stats_.reserved_bytes, alloc_size, stat_types); + } + + Block* remaining = nullptr; + AT_ASSERT(block); + + const bool already_split = block->is_split(); + if (should_split(block, size)) { + remaining = block; + block = new Block(device, stream, size, &pool, block->ptr); + block->prev = remaining->prev; + if (block->prev) { + block->prev->next = block; + } + block->next = remaining; + + remaining->prev = block; + remaining->ptr = static_cast(remaining->ptr) + size; + remaining->size -= size; + pool.insert(remaining); + + if (already_split) { + // An already-split inactive block is being shrunk by size bytes. + update_stat_array( + stats_.inactive_split_bytes, -block->size, stat_types); + } else { + // A new split inactive block is being created from a previously unsplit + // block, size remaining->size bytes. 
+ update_stat_array( + stats_.inactive_split_bytes, remaining->size, stat_types); + update_stat_array(stats_.inactive_split, 1, stat_types); + } + } else if (already_split) { + // An already-split block is becoming active + update_stat_array(stats_.inactive_split_bytes, -block->size, stat_types); + update_stat_array(stats_.inactive_split, -1, stat_types); + } + + block->allocated = true; + allocated_blocks[block->ptr] = block; + *devPtr = block->ptr; + stats.increaseAllocated(block->size); + + c10::reportMemoryUsageToProfiler( + block, block->size, c10::Device(c10::DeviceType::NPU, device)); + + update_stat_array(stats_.allocation, 1, stat_types); + update_stat_array(stats_.allocated_bytes, block->size, stat_types); + update_stat_array(stats_.active, 1, stat_types); + update_stat_array(stats_.active_bytes, block->size, stat_types); + } + + void free(void* ptr) { + std::lock_guard lock(mutex); + if (!ptr) { + return; + } + + auto it = allocated_blocks.find(ptr); + if (it == allocated_blocks.end()) { + AT_ERROR("invalid device pointer: ", ptr); + } + + Block* block = it->second; + allocated_blocks.erase(it); + block->allocated = false; + + c10::reportMemoryUsageToProfiler( + block, -block->size, c10::Device(c10::DeviceType::NPU, block->device)); + + DeviceStats_& stats_ = get_stats_for_device_(block->device); + StatTypes stat_types; + stat_types[static_cast(StatType::AGGREGATE)] = true; + stat_types[static_cast(get_stat_type_for_pool(*(block->pool)))] = + true; + update_stat_array(stats_.allocation, -1, {stat_types}); + update_stat_array(stats_.allocated_bytes, -block->size, {stat_types}); + get_stats_for_device(block->device).decreaseAllocated(block->size); + + if (!block->stream_uses.empty()) { + insert_events(block); + } else { + free_block(block); + } + } + + /** returns cached blocks to the system allocator */ + void emptyCache() { + std::lock_guard lock(mutex); + synchronize_and_free_events(c10::nullopt); + c10::npu::npuSynchronizeDevice(); + free_blocks(large_blocks, large_blocks.begin(), large_blocks.end()); + free_blocks(small_blocks, small_blocks.begin(), small_blocks.end()); + } + + void* getBaseAllocation(void* ptr, size_t* outSize) { + std::lock_guard lock(mutex); + Block* block = find_allocated_block(ptr); + if (!block) { + AT_ERROR("invalid device pointer: ", ptr); + } + while (block->prev) { + block = block->prev; + } + void* basePtr = block->ptr; + if (outSize) { + size_t size = 0; + while (block) { + size += block->size; + block = block->next; + } + *outSize = size; + } + return basePtr; + } + + // Accumulates sizes of all memory blocks for given device in given pool + void cacheInfoAux( + const BlockPool& blocks, + int dev_id, + size_t* total, + size_t* largest) { + Block search_key(dev_id, 0, 0); + auto it = blocks.lower_bound(&search_key); + for (; it != blocks.end() && *it && (*it)->device == dev_id; ++it) { + size_t blocksize = (*it)->size; + *total += blocksize; + if (blocksize > *largest) { + *largest = blocksize; + } + } + } + + void cacheInfo(int dev_id, size_t* total, size_t* largest) { + std::lock_guard lock(mutex); + cacheInfoAux(large_blocks, dev_id, total, largest); + cacheInfoAux(small_blocks, dev_id, total, largest); + } + + /** Returns a copy of the memory allocator stats for the device **/ + DeviceStats_ getStatsForDevice(int dev_id) { + std::lock_guard lock(mutex); + return get_stats_for_device_(dev_id); + } + + /** Resets the historical accumulation stats for the device **/ + void resetAccumulatedStats(int dev_id) { + std::lock_guard lock(mutex); + 
DeviceStats_& stats = get_stats_for_device_(dev_id); + + for (size_t statType = 0; + statType < static_cast(StatType::NUM_TYPES); + ++statType) { + reset_accumulated_stat(stats.allocation[statType]); + reset_accumulated_stat(stats.segment[statType]); + reset_accumulated_stat(stats.active[statType]); + reset_accumulated_stat(stats.inactive_split[statType]); + reset_accumulated_stat(stats.allocated_bytes[statType]); + reset_accumulated_stat(stats.reserved_bytes[statType]); + reset_accumulated_stat(stats.active_bytes[statType]); + reset_accumulated_stat(stats.inactive_split_bytes[statType]); + } + + stats.num_alloc_retries = 0; + stats.num_ooms = 0; + } + + /** Resets the historical peak stats for the device **/ + void resetPeakStats(int dev_id) { + std::lock_guard lock(mutex); + DeviceStats_& stats = get_stats_for_device_(dev_id); + + for (size_t statType = 0; + statType < static_cast(StatType::NUM_TYPES); + ++statType) { + reset_peak_stat(stats.allocation[statType]); + reset_peak_stat(stats.segment[statType]); + reset_peak_stat(stats.active[statType]); + reset_peak_stat(stats.inactive_split[statType]); + reset_peak_stat(stats.allocated_bytes[statType]); + reset_peak_stat(stats.reserved_bytes[statType]); + reset_peak_stat(stats.active_bytes[statType]); + reset_peak_stat(stats.inactive_split_bytes[statType]); + } + } + + /** Dump a complete snapshot of the memory held by the allocator. Potentially + * VERY expensive. **/ + std::vector snapshot() const { + std::lock_guard lock(mutex); + + std::vector result; + const auto all_blocks = get_all_blocks(); + + for (const Block* const head_block : all_blocks) { + if (head_block->prev != nullptr) { + continue; + } + result.emplace_back(); + SegmentInfo& segment_info = result.back(); + segment_info.device = head_block->device; + segment_info.address = reinterpret_cast(head_block->ptr); + segment_info.is_large = (head_block->pool == &large_blocks); + + const Block* block = head_block; + while (block != nullptr) { + segment_info.blocks.emplace_back(); + BlockInfo& block_info = segment_info.blocks.back(); + + block_info.size = block->size; + block_info.allocated = block->allocated; + block_info.active = block->allocated || (block->event_count > 0); + + segment_info.total_size += block_info.size; + if (block_info.allocated) { + segment_info.allocated_size += block_info.size; + } + if (block_info.active) { + segment_info.active_size += block_info.size; + } + + block = block->next; + } + } + + std::sort( + result.begin(), + result.end(), + [](const SegmentInfo& a, const SegmentInfo& b) { + if (a.device != b.device) { + return a.device < b.device; + } + return a.address < b.address; + }); + + return result; + } + + std::vector get_all_blocks() const { + std::vector blocks; + blocks.insert(blocks.end(), small_blocks.begin(), small_blocks.end()); + blocks.insert(blocks.end(), large_blocks.begin(), large_blocks.end()); + for (const auto& item : allocated_blocks) { + blocks.push_back(item.second); + } + return blocks; + } + + void recordStream(const c10::DataPtr& ptr, c10::npu::NPUStream stream) { + // Empty tensor's storage().data() might be a null ptr. As there is no + // blocks associated with those tensors, it is fine to do nothing here. 
+ if (!ptr.get()) { + return; + } + // If a tensor is not allocated by this instance, simply skip + // This usually happens when NPU tensors are shared across processes, + // we have implemented reference counting based sharing mechanism to + // guarantee tensors won't be accidentally freed by one process while + // they are still being used in another + if (ptr.get_deleter() != &raw_delete) { + return; + } + std::lock_guard lock(mutex); + Block* block = find_allocated_block(ptr.get()); + // block could be nullptr in some cases, e.g., tensor loaded from blob, or + // shared from another process, or not pointing to a NPU tensor. + if (block) { + if (stream.stream() == block->stream) { + // ignore uses on the allocation stream, since those don't require any + // special synchronization + return; + } + block->stream_uses.insert(stream); + } + } + + /** moves a block into a pool of cached free blocks */ + void free_block(Block* block) { + AT_ASSERT(!block->allocated && block->event_count == 0); + size_t original_block_size = block->size; + + auto& pool = *block->pool; + int64_t net_change_inactive_split_blocks = 0; + int64_t net_change_inactive_split_size = 0; + + const int64_t subsumed_size_prev = + try_merge_blocks(block, block->prev, pool); + if (subsumed_size_prev > 0) { + net_change_inactive_split_blocks -= 1; + net_change_inactive_split_size -= subsumed_size_prev; + } + const int64_t subsumed_size_next = + try_merge_blocks(block, block->next, pool); + if (subsumed_size_next > 0) { + net_change_inactive_split_blocks -= 1; + net_change_inactive_split_size -= subsumed_size_next; + } + pool.insert(block); + + if (block->is_split()) { + net_change_inactive_split_blocks += 1; + net_change_inactive_split_size += block->size; + } + + DeviceStats_& stats_ = get_stats_for_device_(block->device); + StatTypes stat_types; + stat_types[static_cast(StatType::AGGREGATE)] = true; + stat_types[static_cast(get_stat_type_for_pool(*(block->pool)))] = + true; + + update_stat_array( + stats_.inactive_split, net_change_inactive_split_blocks, stat_types); + update_stat_array( + stats_.inactive_split_bytes, + net_change_inactive_split_size, + stat_types); + update_stat_array(stats_.active, -1, stat_types); + update_stat_array(stats_.active_bytes, -original_block_size, stat_types); + } + + /** combine previously split blocks */ + size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) { + if (!src || src->allocated || src->event_count > 0) { + return 0; + } + if (dst->prev == src) { + dst->ptr = src->ptr; + dst->prev = src->prev; + if (dst->prev) { + dst->prev->next = dst; + } + } else { + dst->next = src->next; + if (dst->next) { + dst->next->prev = dst; + } + } + + const size_t subsumed_size = src->size; + dst->size += src->size; + pool.erase(src); + delete src; + + return subsumed_size; + } + + BlockPool& get_pool(size_t size) { + if (size <= kSmallSize) { + return small_blocks; + } else { + return large_blocks; + } + } + + StatType get_stat_type_for_pool(const BlockPool& pool) { + if (&pool == &small_blocks) { + return StatType::SMALL_POOL; + } else if (&pool == &large_blocks) { + return StatType::LARGE_POOL; + } else { + AT_ERROR("get_stat_type_for_pool: invalid pool"); + } + } + + bool should_split(const Block* block, size_t size) { + size_t remaining = block->size - size; + if (block->pool == &small_blocks) { + return remaining >= kMinBlockSize; + } else if (block->pool == &large_blocks) { + return remaining > kSmallSize; + } else { + AT_ERROR("should_split: invalid pool"); + } + } + + size_t 
round_size(size_t size) { + // be consistent with ACL memory alloc rules + size = size + 32; + if (size < kMinBlockSize) { + return kMinBlockSize; + } else { + return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); + } + } + + size_t get_allocation_size(size_t size) { + if (size <= kSmallSize) { + return kSmallBuffer; + } else if (size < kMinLargeAlloc) { + return kLargeBuffer; + } else { + return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); + } + } + + aclError npu_malloc_retry(int device, void** devPtr, size_t size) { + // Try npuMalloc. If npuMalloc fails, frees all non-split cached blocks + // and retries. + aclError err = aclrtMalloc( + devPtr, size, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST); + if (err != ACL_ERROR_NONE) { + DeviceStats_& stats_ = get_stats_for_device_(device); + stats_.num_alloc_retries += 1; + + // npuGetLastError(); // reset the last NPU error + free_cached_blocks(device); + err = aclrtMalloc( + devPtr, size, aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST); + if (err != ACL_ERROR_NONE) { + return err; + } + } + return ACL_ERROR_NONE; + } + + void free_cached_blocks(int device) { + // First ensure that all blocks that can't currently be allocated due to + // outstanding events are returned to the pool. + synchronize_and_free_events(device); + + // Free all non-split cached blocks on device + Block lower_bound(device, nullptr, 0); + Block upper_bound(device + 1, nullptr, 0); + + c10::npu::npuSynchronizeDevice(); + free_blocks( + large_blocks, + large_blocks.lower_bound(&lower_bound), + large_blocks.lower_bound(&upper_bound)); + free_blocks( + small_blocks, + small_blocks.lower_bound(&lower_bound), + small_blocks.lower_bound(&upper_bound)); + } + + void free_blocks( + BlockPool& blocks, + BlockPool::iterator it, + BlockPool::iterator end) { + // Frees all non-split blocks between `it` and `end` + std::lock_guard lock(npu_free_mutex); + while (it != end) { + Block* block = *it; + if (!block->prev && !block->next) { + aclrtFree((void*)block->ptr); + + get_stats_for_device(block->device).decreaseCached(block->size); + DeviceStats_& stats_ = get_stats_for_device_(block->device); + StatTypes stat_types; + stat_types[static_cast(StatType::AGGREGATE)] = true; + stat_types[static_cast( + get_stat_type_for_pool(*(block->pool)))] = true; + + update_stat_array(stats_.segment, -1, stat_types); + update_stat_array(stats_.reserved_bytes, -block->size, stat_types); + + auto cur = it; + ++it; + blocks.erase(cur); + delete block; + } else { + ++it; + } + } + } + + void synchronize_and_free_events(c10::optional device) { + // Synchronize on outstanding events and then free associated blocks. + // Limited to blocks on the given device if specified. 
+ auto remaining_events = decltype(npu_events)(); + + for (auto& e : npu_events) { + aclrtEvent event = e.first; + Block* block = e.second; + if (device.has_value() && block->device != *device) { + remaining_events.push_back(e); + continue; + } + + C10_NPU_CHECK(aclrtSynchronizeEvent(event)); + C10_NPU_CHECK(aclrtDestroyEvent(event)); + + block->event_count--; + if (block->event_count == 0) { + free_block(block); + } + } + + std::swap(npu_events, remaining_events); + } + + Block* find_allocated_block(void* ptr) { + auto it = allocated_blocks.find(ptr); + if (it == allocated_blocks.end()) { + return nullptr; + } + return it->second; + } + + void insert_events(Block* block) { + int prev_device = 0; + C10_NPU_CHECK(aclrtGetDevice(&prev_device)); + + stream_set streams(std::move(block->stream_uses)); + AT_ASSERT(block->stream_uses.empty()); + for (auto it = streams.begin(); it != streams.end(); ++it) { + int pre_device = 0; + aclError ret = aclrtGetDevice(&pre_device); + if (ret != ACL_ERROR_NONE) { + C10_NPU_CHECK(aclrtSetDevice(it->device_index())); + } else if (pre_device != it->device_index()) { + C10_NPU_CHECK(aclrtSetDevice(it->device_index())); + } + + aclrtEvent event; + aclrtCreateEvent(&event); + aclrtRecordEvent(event, it->stream()); + + block->event_count++; + npu_events.emplace_back(event, block); + } + + int cur_device = 0; + aclError ret = aclrtGetDevice(&cur_device); + if (ret != ACL_ERROR_NONE) { + C10_NPU_CHECK(aclrtSetDevice(prev_device)); + } else if (cur_device != prev_device) { + C10_NPU_CHECK(aclrtSetDevice(prev_device)); + } + } + + void process_events() { + // Process outstanding npuEvents. Events that are completed are removed + // from the queue, and the 'event_count' for the corresponding allocation + // is decremented. Stops at the first event which has not been completed. + // Since events on different devices or streams may occur out of order, + // the processing of some events may be delayed. 
+ while (!npu_events.empty()) { + auto& e = npu_events.front(); + aclrtEvent event = e.first; + Block* block = e.second; + + aclrtEventStatus status = ACL_EVENT_STATUS_RESERVED; + aclError err = aclrtQueryEvent(event, &status); + if (err != ACL_ERROR_NONE) { + C10_NPU_CHECK(err); + } + if (status != ACL_EVENT_STATUS_COMPLETE) { + break; + } + + aclrtDestroyEvent(event); + + block->event_count--; + if (block->event_count == 0) { + free_block(block); + } + npu_events.pop_front(); + } + } + + void allocate_adjacent_ptr( + size_t size1, + size_t size2, + void** ptr_pre, + void** ptr_next, + aclrtStream stream) { + size_t round_size_pre = (size1 + 32 + 511) / 512 * 512; + size_t round_size = round_size_pre + size2; + malloc(ptr_pre, round_size, stream); + + Block* temp_block = allocated_blocks.find(*ptr_pre)->second; + DeviceStats_& stats_ = get_stats_for_device_(temp_block->device); + StatTypes stat_types; + stat_types[static_cast(StatType::AGGREGATE)] = true; + stat_types[static_cast( + get_stat_type_for_pool(*(temp_block->pool)))] = true; + update_stat_array(stats_.allocation, -1, {stat_types}); + update_stat_array(stats_.allocated_bytes, -temp_block->size, {stat_types}); + update_stat_array(stats_.active, -1, {stat_types}); + update_stat_array(stats_.active_bytes, -temp_block->size, {stat_types}); + + Block* next_block = nullptr; + Block* pre_block = allocated_blocks.find(*ptr_pre)->second; + next_block = pre_block; + auto& pool = get_pool(round_size); + pre_block = new Block( + next_block->device, + next_block->stream, + round_size_pre, + &pool, + pre_block->ptr); + + pre_block->prev = next_block->prev; + if (pre_block->prev) { + pre_block->prev->next = pre_block; + } + pre_block->next = next_block; + next_block->prev = pre_block; + next_block->ptr = static_cast(next_block->ptr) + round_size_pre; + pre_block->size = round_size_pre; + next_block->size -= round_size_pre; + + pre_block->allocated = true; + next_block->allocated = true; + allocated_blocks[pre_block->ptr] = pre_block; + allocated_blocks[next_block->ptr] = next_block; + + *ptr_next = next_block->ptr; + + DeviceStats_& stats_pre = get_stats_for_device_(pre_block->device); + StatTypes stat_types_pre; + stat_types_pre[static_cast(StatType::AGGREGATE)] = true; + stat_types_pre[static_cast( + get_stat_type_for_pool(*(pre_block->pool)))] = true; + update_stat_array(stats_pre.allocation, 1, stat_types_pre); + update_stat_array( + stats_pre.allocated_bytes, pre_block->size, stat_types_pre); + update_stat_array(stats_pre.active, 1, stat_types_pre); + update_stat_array(stats_pre.active_bytes, pre_block->size, stat_types_pre); + + DeviceStats_& stats_next = get_stats_for_device_(next_block->device); + StatTypes stat_types_next; + stat_types_next[static_cast(StatType::AGGREGATE)] = true; + stat_types_next[static_cast( + get_stat_type_for_pool(*(next_block->pool)))] = true; + update_stat_array(stats_next.allocation, 1, stat_types_next); + update_stat_array( + stats_next.allocated_bytes, next_block->size, stat_types_next); + update_stat_array(stats_next.active, 1, stat_types_next); + update_stat_array( + stats_next.active_bytes, next_block->size, stat_types_next); + } +}; + +THNCachingAllocator caching_allocator; + +static void NPUCachingDeleter(void* ptr) { + caching_allocator.free(ptr); +} + +// NB: I decided not to fold this into THNCachingAllocator, because the latter +// has a lot more methods and it wasn't altogether clear that they should +// actually be publically exposed +struct NPUCachingAllocator : public c10::Allocator { + 
c10::DataPtr allocate(size_t size) const override { + int device = 0; + C10_NPU_CHECK(aclrtGetDevice(&device)); + void* r = nullptr; + if (size != 0) { + caching_allocator.malloc( + &r, size, c10::npu::getCurrentNPUStreamNoWait(device)); + } + return {r, r, &NPUCachingDeleter, c10::Device(c10::DeviceType::NPU, device)}; + } + c10::DeleterFnPtr raw_deleter() const override { + return &NPUCachingDeleter; + } +}; + +std::tuple allocate_adjacent(size_t size1, size_t size2) { + int device = 0; + C10_NPU_CHECK(aclrtGetDevice(&device)); + void* ptr_pre = nullptr; + void* ptr_next = nullptr; + caching_allocator.allocate_adjacent_ptr( + size1, + size2, + &ptr_pre, + &ptr_next, + c10::npu::getCurrentNPUStreamNoWait(device)); + + c10::DataPtr data_pre = { + ptr_pre, ptr_pre, &NPUCachingDeleter, c10::Device(c10::DeviceType::NPU, device)}; + c10::DataPtr data_next = { + ptr_next, ptr_next, &NPUCachingDeleter, c10::Device(c10::DeviceType::NPU, device)}; + std::tuple adjacent_dataptr = + std::make_tuple(std::move(data_pre), std::move(data_next)); + + return adjacent_dataptr; +} + +NPUCachingAllocator device_allocator; + +c10::Allocator* get(void) { + return &device_allocator; +} + +void emptyCache(void) { + caching_allocator.emptyCache(); +} + +void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock) { + caching_allocator.cacheInfo(dev_id, cachedAndFree, largestBlock); +} + +void* getBaseAllocation(void* ptr, size_t* size) { + return caching_allocator.getBaseAllocation(ptr, size); +} + +void recordStream(const c10::DataPtr& ptr, c10::npu::NPUStream stream) { + caching_allocator.recordStream(ptr, stream); +} + +std::mutex* getFreeMutex() { + return &caching_allocator.npu_free_mutex; +} + +static inline void assertValidDevice(int device) { + int device_num = c10::npu::device_count(); + AT_ASSERTM(0 <= device && device < device_num, "Invalid device argument."); +} + +DeviceStats_ getDeviceStats(int device) { + assertValidDevice(device); + return caching_allocator.getStatsForDevice(device); +} + +void resetAccumulatedStats(int device) { + assertValidDevice(device); + caching_allocator.resetAccumulatedStats(device); +} + +void resetPeakStats(int device) { + assertValidDevice(device); + caching_allocator.resetPeakStats(device); +} + +std::vector snapshot() { + return caching_allocator.snapshot(); +} + +uint64_t currentMemoryAllocated(int device) { + assertValidDevice(device); + return caching_allocator.get_stats_for_device(device).amount_allocated; +} + +uint64_t maxMemoryAllocated(int device) { + assertValidDevice(device); + return caching_allocator.get_stats_for_device(device).max_amount_allocated; +} + +void resetMaxMemoryAllocated(int device) { + assertValidDevice(device); + DeviceStats& stats = caching_allocator.get_stats_for_device(device); + stats.max_amount_allocated = stats.amount_allocated; +} + +uint64_t currentMemoryCached(int device) { + assertValidDevice(device); + return caching_allocator.get_stats_for_device(device).amount_cached; +} + +uint64_t maxMemoryCached(int device) { + assertValidDevice(device); + return caching_allocator.get_stats_for_device(device).max_amount_cached; +} + +void resetMaxMemoryCached(int device) { + assertValidDevice(device); + DeviceStats& stats = caching_allocator.get_stats_for_device(device); + stats.max_amount_cached = stats.amount_cached; +} + +// +// In NPU IPC, sender sends a tensor to receiver, getIpcDevPtr +// is called by the receiving process to map the NPU memory from the sending +// process into its own address space. 
+// +// NPU IPC only allows sharing a big memory block associated with a +// npuIpcMemHandle_t and it can be opened only **once** per context per +// process. There can be multiple types of storage in the same IPC mem block, so +// we must cache the device ptr to construct typed storage as it comes. +// +// ipcMemHandle_to_devptr maps a npuIpcMemHandle_t to a device pointer in the +// process that can be used to access the memory block in the sender process. It +// only saves a weak_ptr of the device pointer in the map, the shared_ptr will +// be used to reconstruct all storages in this NPUMalloc allocation. And it +// will deleted in npuIpcCloseMemHandle when its reference count is 0. +// +namespace { +std::mutex IpcMutex; +std::unordered_map> ipcMemHandle_to_devptr; +} // namespace + + +void* raw_alloc(size_t nbytes) { + if (nbytes == 0) { + return nullptr; + } + int device = 0; + C10_NPU_CHECK(aclrtGetDevice(&device)); + void* r = nullptr; + caching_allocator.malloc(&r, nbytes, c10::npu::getCurrentNPUStreamNoWait(device)); + return r; +} + +void* raw_alloc_with_stream(size_t nbytes, aclrtStream stream) { + if (nbytes == 0) { + return nullptr; + } + void* r = nullptr; + caching_allocator.malloc(&r, nbytes, stream); + return r; +} + +void raw_delete(void* ptr) { + caching_allocator.free(ptr); +} + +void FreeDeviceCachedMemory(int device) +{ + caching_allocator.free_cached_blocks(device); +} + +} // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.h b/torch_npu/csrc/core/npu/NPUCachingAllocator.h new file mode 100644 index 0000000000..a63106e0fe --- /dev/null +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.h @@ -0,0 +1,142 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace c10_npu { + +// Caching allocator will execute every registered callback if it unable to find +// block inside of already allocated area. +class C10_NPU_API FreeMemoryCallback { + public: + virtual ~FreeMemoryCallback(){}; + virtual bool Execute() = 0; +}; + +C10_DECLARE_REGISTRY(FreeNPUMemoryCallbacksRegistry, FreeMemoryCallback); +#define REGISTER_FREE_MEMORY_CALLBACK(name, ...) \ + C10_REGISTER_CLASS(FreeNPUMemoryCallbacksRegistry, name, __VA_ARGS__); + + +// TODO: Turn this into an honest to goodness class. I briefly attempted to do +// this, but it was a bit irritating to figure out how to also correctly +// apply pimpl pattern so I didn't have to leak any internal implementation +// details in the header (NPUCachingAllocator could be made a pimpl, but +// you also need to appropriately define a class which is a subclass +// of Allocator. Not impossible, but required a bit more surgery than +// I wanted to do at the time.) +// +// Why is this using a namespace rather than old-style THNCachingAllocator_ +// prefix? 
Mostly because it made the HIPify rules easier to write; _ is
+// not counted as a word boundary, so you would otherwise have to list each
+// of these functions.
+struct Stat {
+  int64_t current = 0;
+  int64_t peak = 0;
+  int64_t allocated = 0;
+  int64_t freed = 0;
+};
+
+enum struct StatType : uint64_t {
+  AGGREGATE = 0,
+  SMALL_POOL = 1,
+  LARGE_POOL = 2,
+  NUM_TYPES = 3 // remember to update this whenever a new stat type is added
+};
+
+typedef std::array<Stat, static_cast<size_t>(StatType::NUM_TYPES)> StatArray;
+// Struct containing memory allocator summary statistics for a device.
+struct DeviceStats_ {
+  // COUNT: allocations requested by client code
+  StatArray allocation;
+  // COUNT: number of allocated segments from npuMalloc().
+  StatArray segment;
+  // COUNT: number of active memory blocks (allocated or used by stream)
+  StatArray active;
+  // COUNT: number of inactive, split memory blocks (unallocated but can't be released via npuFree)
+  StatArray inactive_split;
+
+  // SUM: bytes requested by client code
+  StatArray allocated_bytes;
+  // SUM: bytes reserved by this memory allocator (both free and used)
+  StatArray reserved_bytes;
+  // SUM: bytes within active memory blocks
+  StatArray active_bytes;
+  // SUM: bytes within inactive, split memory blocks
+  StatArray inactive_split_bytes;
+
+  // COUNT: total number of failed calls to NPU malloc necessitating cache flushes.
+  int64_t num_alloc_retries = 0;
+
+  // COUNT: total number of OOMs (i.e. failed calls to NPU after cache flush)
+  int64_t num_ooms = 0;
+};
+
+// Struct containing info of an allocation block (i.e. a fractional part of a single npuMalloc).
+struct BlockInfo {
+  int64_t size = 0;
+  bool allocated = false;
+  bool active = false;
+};
+
+// Struct containing info of a memory segment (i.e. one contiguous npuMalloc).
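+// Each BlockInfo above describes one split piece of a cached segment; a
+// SegmentInfo groups the pieces that share one underlying device allocation,
+// which is why it carries a total_size alongside the allocated/active subsets.
+// snapshot() below returns one SegmentInfo per segment and is what
+// THNPModule_memorySnapshot turns into Python dicts (see the Module.cpp changes
+// later in this series).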
+struct SegmentInfo { + int64_t device = 0; + uintptr_t address = 0; + int64_t total_size = 0; + int64_t allocated_size = 0; + int64_t active_size = 0; + bool is_large = false; + std::vector blocks; +}; + + +C10_NPU_API void* raw_alloc(size_t nbytes); +C10_NPU_API void* raw_alloc_with_stream(size_t nbytes, aclrtStream stream); +C10_NPU_API void raw_delete(void* ptr); + +C10_NPU_API std::tuple allocate_adjacent(size_t size1, size_t size2); + +C10_NPU_API c10::Allocator* get(); +C10_NPU_API void emptyCache(); +C10_NPU_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock); +C10_NPU_API void* getBaseAllocation(void* ptr, size_t* size); +C10_NPU_API void recordStream(const c10::DataPtr& ptr, c10::npu::NPUStream stream); +C10_NPU_API DeviceStats_ getDeviceStats(int device); +C10_NPU_API void resetAccumulatedStats(int device); +C10_NPU_API void resetPeakStats(int device); +C10_NPU_API std::vector snapshot(); + +C10_NPU_API uint64_t currentMemoryAllocated(int device); +C10_NPU_API uint64_t maxMemoryAllocated(int device); +C10_NPU_API void resetMaxMemoryAllocated(int device); +C10_NPU_API uint64_t currentMemoryCached(int device); +C10_NPU_API uint64_t maxMemoryCached(int device); +C10_NPU_API void resetMaxMemoryCached(int device); + +C10_NPU_API std::mutex* getFreeMutex(); + +C10_NPU_API void FreeDeviceCachedMemory(int device); + +} // namespace c10 -- Gitee From 6babefead09182046a3d0af5586fabd258e8b573 Mon Sep 17 00:00:00 2001 From: "zhousinan@huawei.com" Date: Thu, 17 Feb 2022 09:38:04 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=A3=80=E6=9F=A5?= =?UTF-8?q?=E6=84=8F=E8=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_tensor.py | 90 +++++++++---------- torch_npu/csrc/InitNpuBindings.cpp | 2 +- torch_npu/csrc/aten/common/SetNpu.cpp | 2 +- .../csrc/aten/common/TensorFactories.cpp | 6 +- .../csrc/core/npu/NPUCachingAllocator.cpp | 12 +-- torch_npu/csrc/core/npu/NPUCachingAllocator.h | 44 ++++----- .../csrc/distributed/ProcessGroupHCCL.cpp | 10 +-- torch_npu/csrc/framework/utils/NpuUtils.cpp | 2 +- torch_npu/csrc/npu/Module.cpp | 30 +++---- 9 files changed, 99 insertions(+), 99 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index 7a861a979b..e71b687685 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,11 +1,12 @@ import tempfile +from itertools import product, combinations, combinations_with_replacement, permutations + import torch import torch_npu from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.common_device_type import device_type_test_bases, \ DeviceTypeTestBase, onlyOn, dtypes, instantiate_device_type_tests -from itertools import product, combinations, combinations_with_replacement, permutations def onlyNPU(fn): @@ -30,51 +31,51 @@ class TestTensor(TestCase): @onlyNPU def test_tensor_set(self, device): - t1 = torch.Tensor() - t2 = torch.Tensor(3, 4, 9, 10).uniform_() - t1.set_(t2) - self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) - size = torch.Size([9, 3, 4, 10]) - t1.set_(t2.storage(), 0, size) - self.assertEqual(t1.size(), size) - t1.set_(t2.storage(), 0, tuple(size)) - self.assertEqual(t1.size(), size) - self.assertEqual(t1.stride(), (120, 40, 10, 1)) - stride = (10, 360, 90, 1) - t1.set_(t2.storage(), 0, size, stride) - self.assertEqual(t1.stride(), stride) - t1.set_(t2.storage(), 0, size=size, stride=stride) - self.assertEqual(t1.size(), size) - self.assertEqual(t1.stride(), stride) - - # test 
argument names - t1 = torch.Tensor() - # 1. case when source is tensor - t1.set_(source=t2) - self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) - # 2. case when source is storage - t1.set_(source=t2.storage()) - self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) - # 3. case when source is storage, and other args also specified - t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride) - self.assertEqual(t1.size(), size) - self.assertEqual(t1.stride(), stride) - - t1 = torch.tensor([True, True], dtype=torch.bool) - t2 = torch.tensor([False, False], dtype=torch.bool) - t1.set_(t2) - self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + t1 = torch.Tensor() + t2 = torch.Tensor(3, 4, 9, 10).uniform_() + t1.set_(t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + size = torch.Size([9, 3, 4, 10]) + t1.set_(t2.storage(), 0, size) + self.assertEqual(t1.size(), size) + t1.set_(t2.storage(), 0, tuple(size)) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), (120, 40, 10, 1)) + stride = (10, 360, 90, 1) + t1.set_(t2.storage(), 0, size, stride) + self.assertEqual(t1.stride(), stride) + t1.set_(t2.storage(), 0, size=size, stride=stride) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), stride) + + # test argument names + t1 = torch.Tensor() + # 1. case when source is tensor + t1.set_(source=t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + # 2. case when source is storage + t1.set_(source=t2.storage()) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + # 3. case when source is storage, and other args also specified + t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), stride) + + t1 = torch.tensor([True, True], dtype=torch.bool) + t2 = torch.tensor([False, False], dtype=torch.bool) + t1.set_(t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) @onlyNPU @dtypes(torch.half, torch.float) def test_cat_all_dtypes_and_devices(self, device, dtype): - x = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device) + x = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device) - expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dtype, device=device) - self.assertEqual(torch.cat((x, x), 0).to('cpu'), expected1.to('cpu')) + expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dtype, device=device) + self.assertEqual(torch.cat((x, x), 0).to('cpu'), expected1.to('cpu')) - expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dtype, device=device) - self.assertEqual(torch.cat((x, x), 1).to('cpu'), expected2.to('cpu')) + expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dtype, device=device) + self.assertEqual(torch.cat((x, x), 1).to('cpu'), expected2.to('cpu')) @onlyNPU def test_cat_mem_overlap(self, device): @@ -105,8 +106,6 @@ class TestTensor(TestCase): z = torch.cat([x, y]) self.assertEqual(z.size(), (21, SIZE, SIZE)) - # TODO: this test should be updated - @onlyNPU def test_zeros(self, device): res1 = torch.zeros(100, 100, device=device) @@ -128,8 +127,6 @@ class TestTensor(TestCase): expected = torch.tensor([[0.]], device=device, dtype=torch.half) self.assertEqual(bfloat16Tensor.to('cpu'), expected.to('cpu')) - # TODO: this test should be updated - @onlyNPU def test_zeros_out(self, device): shape = (3, 4) @@ -150,8 +147,6 @@ class TestTensor(TestCase): self.assertEqual(torch.zeros(shape, device=device).to('cpu'), torch.zeros(shape, device=device, 
out=out).to('cpu')) - # TODO: this test should be updated - @onlyNPU def test_ones(self, device): res1 = torch.ones(100, 100, device=device) @@ -178,6 +173,7 @@ class TestTensor(TestCase): device=device).as_strided(shape, strides) self.assertEqual(empty_strided.shape, as_strided.shape) self.assertEqual(empty_strided.stride(), as_strided.stride()) + @onlyNPU def test_empty_tensor_props(self, device): sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)] @@ -225,7 +221,7 @@ class TestTensor(TestCase): self.assertEqual(torch.full(size, 1, out=o).dtype, o.dtype) # TODO: this test should be updated - + @onlyNPU def test_ones_like(self, device): expected = torch.ones(100, 100, device=device) diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index 0774cc72c1..56ddd1f2d2 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -53,7 +53,7 @@ PyObject * THPModule_npu_shutdown(PyObject * /* unused */) if (c10::npu::NpuSysCtrl::GetInstance().GetInitFlag()) { c10::npu::npuSynchronizeDevice(); THNPUCachingHostAllocator_emptyCache(); - c10_npu::emptyCache(); + c10_npu::NPUCachingAllocatoremptyCache(); c10::npu::NpuSysCtrl::SysStatus status = c10::npu::NpuSysCtrl::GetInstance().Finalize(); if (status != c10::npu::NpuSysCtrl::SysStatus::FINALIZE_SUCC) { fprintf(stdout, "THPModule_npu_shutdown failed.\n"); diff --git a/torch_npu/csrc/aten/common/SetNpu.cpp b/torch_npu/csrc/aten/common/SetNpu.cpp index f69d25aae5..08f743c733 100644 --- a/torch_npu/csrc/aten/common/SetNpu.cpp +++ b/torch_npu/csrc/aten/common/SetNpu.cpp @@ -32,7 +32,7 @@ c10::StorageImpl* storage_new_npu(caffe2::TypeMeta data_type) { c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), 0, - c10_npu::get(), + c10_npu::NPUCachingAllocatorget(), true) .release(); return storage; diff --git a/torch_npu/csrc/aten/common/TensorFactories.cpp b/torch_npu/csrc/aten/common/TensorFactories.cpp index d7b8d73c20..43c4ea7419 100644 --- a/torch_npu/csrc/aten/common/TensorFactories.cpp +++ b/torch_npu/csrc/aten/common/TensorFactories.cpp @@ -84,7 +84,7 @@ namespace at_npu AT_ASSERT(c10::device_or_default(device_opt).type() == at::DeviceType::NPU); TORCH_CHECK(!pinned_memory_or_default(pin_memory_opt), "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); - c10::Allocator *allocator = c10_npu::get(); + c10::Allocator *allocator = c10_npu::NPUCachingAllocatorget(); int64_t nelements = at::prod_intlist(size); auto dtype = c10::scalarTypeToTypeMeta(dtype_or_default(dtype_opt)); int64_t size_bytes = nelements * dtype.itemsize(); @@ -272,7 +272,7 @@ namespace at_npu AT_ASSERT(c10::device_or_default(device_opt).type() == at::DeviceType::NPU); TORCH_CHECK(!pinned_memory_or_default(pin_memory_opt), "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); - c10::Allocator *allocator = c10_npu::get(); + c10::Allocator *allocator = c10_npu::NPUCachingAllocatorget(); // when the shape and format are not match, fix format here. aclFormat format = InferFormat::GuessStorageFormat(size, (aclFormat)dst_format); int64_t nelements = StorageDescHelper::GetMemorySize(size, format); @@ -306,7 +306,7 @@ namespace at_npu AT_ASSERT(options.backend() == at::Backend::NPU); TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); - c10::Allocator *allocator = c10_npu::get(); + c10::Allocator *allocator = c10_npu::NPUCachingAllocatorget(); // when the shape and format are not match, fix format here. 
aclFormat format = InferFormat::GuessStorageFormat(size, (aclFormat)dst_format); int64_t nelements = StorageDescHelper::GetMemorySize(size, format); diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp index dddf79cdab..cca744252f 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.cpp @@ -14,11 +14,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - -#include -#include -#include -#include #include #include #include @@ -29,13 +24,19 @@ #include #include #include + +#include +#include #include #include #include +#include "third_party/acl/inc/acl/acl_base.h" +#include "third_party/acl/inc/acl/acl_rt.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" namespace c10_npu { +namespace NPUCachingAllocator { C10_DEFINE_REGISTRY(FreeNPUMemoryCallbacksRegistry, FreeMemoryCallback); @@ -1174,4 +1175,5 @@ void FreeDeviceCachedMemory(int device) caching_allocator.free_cached_blocks(device); } +} // namespace NPUCachingAllocator } // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/NPUCachingAllocator.h b/torch_npu/csrc/core/npu/NPUCachingAllocator.h index a63106e0fe..8b7eed8e82 100644 --- a/torch_npu/csrc/core/npu/NPUCachingAllocator.h +++ b/torch_npu/csrc/core/npu/NPUCachingAllocator.h @@ -25,6 +25,7 @@ #include namespace c10_npu { +namespace NPUCachingAllocator { // Caching allocator will execute every registered callback if it unable to find // block inside of already allocated area. @@ -112,31 +113,32 @@ struct SegmentInfo { }; -C10_NPU_API void* raw_alloc(size_t nbytes); -C10_NPU_API void* raw_alloc_with_stream(size_t nbytes, aclrtStream stream); -C10_NPU_API void raw_delete(void* ptr); +void* raw_alloc(size_t nbytes); +void* raw_alloc_with_stream(size_t nbytes, aclrtStream stream); +void raw_delete(void* ptr); -C10_NPU_API std::tuple allocate_adjacent(size_t size1, size_t size2); +std::tuple allocate_adjacent(size_t size1, size_t size2); -C10_NPU_API c10::Allocator* get(); -C10_NPU_API void emptyCache(); -C10_NPU_API void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock); -C10_NPU_API void* getBaseAllocation(void* ptr, size_t* size); -C10_NPU_API void recordStream(const c10::DataPtr& ptr, c10::npu::NPUStream stream); -C10_NPU_API DeviceStats_ getDeviceStats(int device); -C10_NPU_API void resetAccumulatedStats(int device); -C10_NPU_API void resetPeakStats(int device); -C10_NPU_API std::vector snapshot(); +c10::Allocator* get(); +void emptyCache(); +void cacheInfo(int dev_id, size_t* cachedAndFree, size_t* largestBlock); +void* getBaseAllocation(void* ptr, size_t* size); +void recordStream(const c10::DataPtr& ptr, c10::npu::NPUStream stream); +DeviceStats_ getDeviceStats(int device); +void resetAccumulatedStats(int device); +void resetPeakStats(int device); +std::vector snapshot(); -C10_NPU_API uint64_t currentMemoryAllocated(int device); -C10_NPU_API uint64_t maxMemoryAllocated(int device); -C10_NPU_API void resetMaxMemoryAllocated(int device); -C10_NPU_API uint64_t currentMemoryCached(int device); -C10_NPU_API uint64_t maxMemoryCached(int device); -C10_NPU_API void resetMaxMemoryCached(int device); +uint64_t currentMemoryAllocated(int device); +uint64_t maxMemoryAllocated(int device); +void resetMaxMemoryAllocated(int device); +uint64_t currentMemoryCached(int device); +uint64_t maxMemoryCached(int device); +void resetMaxMemoryCached(int device); -C10_NPU_API std::mutex* getFreeMutex(); 
+std::mutex* getFreeMutex(); -C10_NPU_API void FreeDeviceCachedMemory(int device); +void FreeDeviceCachedMemory(int device); +} // namespace NPUCachingAllocator } // namespace c10 diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index cc7f30d15d..b9239a9645 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -460,7 +460,7 @@ c10::intrusive_ptr ProcessGroupHCCL::collective( // operations where `inputs' and `outputs' are not the same. // // See [Sync Streams]. - c10_npu::recordStream( + c10_npu::NPUCachingAllocatorrecordStream( inputs[i].storage().data_ptr(), hcclStream); } { @@ -578,7 +578,7 @@ c10::intrusive_ptr ProcessGroupHCCL::allgather( HcclComm comm, c10::npu::NPUStream& stream) { RECORD_FUNCTION("HcclAllgather", std::vector({input})); - c10_npu::recordStream( + c10_npu::NPUCachingAllocatorrecordStream( output.storage().data_ptr(), stream); return HcclAllGather( input.data_ptr(), @@ -595,7 +595,7 @@ c10::intrusive_ptr ProcessGroupHCCL::allgather( c10::npu::NPUStreamGuard guard(hcclStreams[i]); for (size_t j = 0; j < outputTensors[0].size(); ++j) { // See [Sync Streams]. - c10_npu::recordStream( + c10_npu::NPUCachingAllocatorrecordStream( outputTensors[i][j].storage().data_ptr(), hcclStreams[i]); outputTensors[i][j].copy_(outputFlattened[i][j], true); @@ -629,7 +629,7 @@ c10::intrusive_ptr ProcessGroupHCCL::reduce_scatter( HcclComm comm, c10::npu::NPUStream& stream) { RECORD_FUNCTION("HcclReduceScatter", std::vector({input})); - c10_npu::recordStream( + c10_npu::NPUCachingAllocatorrecordStream( output.storage().data_ptr(), stream); return HcclReduceScatter( input.data_ptr(), @@ -646,7 +646,7 @@ c10::intrusive_ptr ProcessGroupHCCL::reduce_scatter( c10::npu::NPUStreamGuard guard(hcclStreams[i]); for (size_t j = 0; j < inputTensors[0].size(); ++j) { // See [Sync Streams]. - c10_npu::recordStream( + c10_npu::NPUCachingAllocatorrecordStream( inputTensors[i][j].storage().data_ptr(), hcclStreams[i]); inputFlattened[i][j].copy_(inputTensors[i][j], true); diff --git a/torch_npu/csrc/framework/utils/NpuUtils.cpp b/torch_npu/csrc/framework/utils/NpuUtils.cpp index 01194ff2af..8a8d3031df 100644 --- a/torch_npu/csrc/framework/utils/NpuUtils.cpp +++ b/torch_npu/csrc/framework/utils/NpuUtils.cpp @@ -357,7 +357,7 @@ namespace at_npu if (index == 1) { C10_NPU_CHECK(aclrtGetDevice(&deviceId)); - c10_npu::FreeDeviceCachedMemory(deviceId); + c10_npu::NPUCachingAllocatorFreeDeviceCachedMemory(deviceId); return true; } AT_ERROR("NPU out of memory. 
device id: ", deviceId); diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index dee9b8ac2c..e622f556a2 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -171,7 +171,7 @@ PyObject * THNPModule_setStream_wrap(PyObject *self, PyObject *obj) PyObject * THNPModule_emptyCache(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS - c10_npu::emptyCache(); + c10_npu::NPUCachingAllocatoremptyCache(); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -182,10 +182,10 @@ PyObject * THNPModule_memoryStats(PyObject *_unused, PyObject *arg) THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const int device = (int) THPUtils_unpackLong(arg); - using c10_npu::StatType; - using c10_npu::Stat; - using c10_npu::StatArray; - using c10_npu::DeviceStats_; + using c10_npu::NPUCachingAllocatorStatType; + using c10_npu::NPUCachingAllocatorStat; + using c10_npu::NPUCachingAllocatorStatArray; + using c10_npu::NPUCachingAllocatorDeviceStats_; const auto statToDict = [](const Stat& stat) { py::dict dict; @@ -208,7 +208,7 @@ PyObject * THNPModule_memoryStats(PyObject *_unused, PyObject *arg) return dict; }; - const DeviceStats_ stats = c10_npu::getDeviceStats(device); + const DeviceStats_ stats = c10_npu::NPUCachingAllocatorgetDeviceStats(device); py::dict result; result["num_alloc_retries"] = stats.num_alloc_retries; @@ -231,7 +231,7 @@ PyObject * THNPModule_resetAccumulatedMemoryStats(PyObject *_unused, PyObject *a HANDLE_TH_ERRORS THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_accumulated_memory_stats"); const int device = (int) THPUtils_unpackLong(arg); - c10_npu::resetAccumulatedStats(device); + c10_npu::NPUCachingAllocatorresetAccumulatedStats(device); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -241,7 +241,7 @@ PyObject * THNPModule_resetPeakMemoryStats(PyObject *_unused, PyObject *arg) HANDLE_TH_ERRORS THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats"); const int device = (int) THPUtils_unpackLong(arg); - c10_npu::resetPeakStats(device); + c10_npu::NPUCachingAllocatorresetPeakStats(device); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -250,8 +250,8 @@ PyObject * THNPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS - using c10_npu::SegmentInfo; - using c10_npu::BlockInfo; + using c10_npu::NPUCachingAllocatorSegmentInfo; + using c10_npu::NPUCachingAllocatorBlockInfo; const auto segmentInfoToDict = [](const SegmentInfo& segmentInfo) { py::dict segmentDict; @@ -274,7 +274,7 @@ PyObject * THNPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) return segmentDict; }; - const std::vector& snapshot = c10_npu::snapshot(); + const std::vector& snapshot = c10_npu::NPUCachingAllocatorsnapshot(); py::list result; for (const auto& segmentInfo : snapshot) { @@ -300,7 +300,7 @@ PyObject * THNPModule_npuCachingAllocator_raw_alloc(PyObject *_unused, PyObject } ssize_t size = PyLong_AsSsize_t(size_o); aclrtStream stream = static_cast(PyLong_AsVoidPtr(stream_o)); - void* mem = c10_npu::raw_alloc_with_stream(size, stream); + void* mem = c10_npu::NPUCachingAllocatorraw_alloc_with_stream(size, stream); return PyLong_FromVoidPtr(mem); END_HANDLE_TH_ERRORS } @@ -308,7 +308,7 @@ PyObject * THNPModule_npuCachingAllocator_raw_alloc(PyObject *_unused, PyObject PyObject * THNPModule_npuCachingAllocator_raw_delete(PyObject *_unused, PyObject *obj){ HANDLE_TH_ERRORS void* mem_ptr = PyLong_AsVoidPtr(obj); - c10_npu::raw_delete(mem_ptr); + 
c10_npu::NPUCachingAllocatorraw_delete(mem_ptr); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -322,7 +322,7 @@ static PyGILState_STATE npuMutexGILState; PyObject * THNPModule_npuLockMutex(PyObject *module, PyObject *noargs) { - auto mutex = c10_npu::getFreeMutex(); + auto mutex = c10_npu::NPUCachingAllocatorgetFreeMutex(); // This has to be a busy loop because we **absolutely need to** hold the GIL // or it's a recipe for a deadlock otherwise (if we let other Python threads // run while we have the cudaMutex, but not the GIL, they might try to e.g. @@ -343,7 +343,7 @@ PyObject * THNPModule_npuLockMutex(PyObject *module, PyObject *noargs) PyObject * THNPModule_npuUnlockMutex(PyObject *module, PyObject *noargs) { - auto mutex = c10_npu::getFreeMutex(); + auto mutex = c10_npu::NPUCachingAllocatorgetFreeMutex(); PyGILState_Release(npuMutexGILState); mutex->unlock(); Py_RETURN_NONE; -- Gitee From ebf804eb4de1ca31c5b69897eef7f1b3e2495669 Mon Sep 17 00:00:00 2001 From: "zhousinan@huawei.com" Date: Thu, 17 Feb 2022 09:51:31 +0800 Subject: [PATCH 5/6] add torch_npu/csrc/InitNpuBindings.cpp --- torch_npu/csrc/InitNpuBindings.cpp | 2 +- torch_npu/csrc/aten/common/SetNpu.cpp | 2 +- .../csrc/aten/common/TensorFactories.cpp | 6 ++-- .../csrc/distributed/ProcessGroupHCCL.cpp | 10 +++---- torch_npu/csrc/framework/utils/NpuUtils.cpp | 2 +- torch_npu/csrc/npu/Module.cpp | 30 +++++++++---------- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index 56ddd1f2d2..c34eac8c07 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -53,7 +53,7 @@ PyObject * THPModule_npu_shutdown(PyObject * /* unused */) if (c10::npu::NpuSysCtrl::GetInstance().GetInitFlag()) { c10::npu::npuSynchronizeDevice(); THNPUCachingHostAllocator_emptyCache(); - c10_npu::NPUCachingAllocatoremptyCache(); + c10_npu::NPUCachingAllocator::emptyCache(); c10::npu::NpuSysCtrl::SysStatus status = c10::npu::NpuSysCtrl::GetInstance().Finalize(); if (status != c10::npu::NpuSysCtrl::SysStatus::FINALIZE_SUCC) { fprintf(stdout, "THPModule_npu_shutdown failed.\n"); diff --git a/torch_npu/csrc/aten/common/SetNpu.cpp b/torch_npu/csrc/aten/common/SetNpu.cpp index 08f743c733..42e5feb44d 100644 --- a/torch_npu/csrc/aten/common/SetNpu.cpp +++ b/torch_npu/csrc/aten/common/SetNpu.cpp @@ -32,7 +32,7 @@ c10::StorageImpl* storage_new_npu(caffe2::TypeMeta data_type) { c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), 0, - c10_npu::NPUCachingAllocatorget(), + c10_npu::NPUCachingAllocator::get(), true) .release(); return storage; diff --git a/torch_npu/csrc/aten/common/TensorFactories.cpp b/torch_npu/csrc/aten/common/TensorFactories.cpp index 43c4ea7419..f0b97cf836 100644 --- a/torch_npu/csrc/aten/common/TensorFactories.cpp +++ b/torch_npu/csrc/aten/common/TensorFactories.cpp @@ -84,7 +84,7 @@ namespace at_npu AT_ASSERT(c10::device_or_default(device_opt).type() == at::DeviceType::NPU); TORCH_CHECK(!pinned_memory_or_default(pin_memory_opt), "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); - c10::Allocator *allocator = c10_npu::NPUCachingAllocatorget(); + c10::Allocator *allocator = c10_npu::NPUCachingAllocator::get(); int64_t nelements = at::prod_intlist(size); auto dtype = c10::scalarTypeToTypeMeta(dtype_or_default(dtype_opt)); int64_t size_bytes = nelements * dtype.itemsize(); @@ -272,7 +272,7 @@ namespace at_npu AT_ASSERT(c10::device_or_default(device_opt).type() == at::DeviceType::NPU); 
TORCH_CHECK(!pinned_memory_or_default(pin_memory_opt), "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); - c10::Allocator *allocator = c10_npu::NPUCachingAllocatorget(); + c10::Allocator *allocator = c10_npu::NPUCachingAllocator::get(); // when the shape and format are not match, fix format here. aclFormat format = InferFormat::GuessStorageFormat(size, (aclFormat)dst_format); int64_t nelements = StorageDescHelper::GetMemorySize(size, format); @@ -306,7 +306,7 @@ namespace at_npu AT_ASSERT(options.backend() == at::Backend::NPU); TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned"); check_size_nonnegative(size); - c10::Allocator *allocator = c10_npu::NPUCachingAllocatorget(); + c10::Allocator *allocator = c10_npu::NPUCachingAllocator::get(); // when the shape and format are not match, fix format here. aclFormat format = InferFormat::GuessStorageFormat(size, (aclFormat)dst_format); int64_t nelements = StorageDescHelper::GetMemorySize(size, format); diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp index b9239a9645..31d394afcc 100644 --- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp +++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp @@ -460,7 +460,7 @@ c10::intrusive_ptr ProcessGroupHCCL::collective( // operations where `inputs' and `outputs' are not the same. // // See [Sync Streams]. - c10_npu::NPUCachingAllocatorrecordStream( + c10_npu::NPUCachingAllocator::recordStream( inputs[i].storage().data_ptr(), hcclStream); } { @@ -578,7 +578,7 @@ c10::intrusive_ptr ProcessGroupHCCL::allgather( HcclComm comm, c10::npu::NPUStream& stream) { RECORD_FUNCTION("HcclAllgather", std::vector({input})); - c10_npu::NPUCachingAllocatorrecordStream( + c10_npu::NPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); return HcclAllGather( input.data_ptr(), @@ -595,7 +595,7 @@ c10::intrusive_ptr ProcessGroupHCCL::allgather( c10::npu::NPUStreamGuard guard(hcclStreams[i]); for (size_t j = 0; j < outputTensors[0].size(); ++j) { // See [Sync Streams]. - c10_npu::NPUCachingAllocatorrecordStream( + c10_npu::NPUCachingAllocator::recordStream( outputTensors[i][j].storage().data_ptr(), hcclStreams[i]); outputTensors[i][j].copy_(outputFlattened[i][j], true); @@ -629,7 +629,7 @@ c10::intrusive_ptr ProcessGroupHCCL::reduce_scatter( HcclComm comm, c10::npu::NPUStream& stream) { RECORD_FUNCTION("HcclReduceScatter", std::vector({input})); - c10_npu::NPUCachingAllocatorrecordStream( + c10_npu::NPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); return HcclReduceScatter( input.data_ptr(), @@ -646,7 +646,7 @@ c10::intrusive_ptr ProcessGroupHCCL::reduce_scatter( c10::npu::NPUStreamGuard guard(hcclStreams[i]); for (size_t j = 0; j < inputTensors[0].size(); ++j) { // See [Sync Streams]. - c10_npu::NPUCachingAllocatorrecordStream( + c10_npu::NPUCachingAllocator::recordStream( inputTensors[i][j].storage().data_ptr(), hcclStreams[i]); inputFlattened[i][j].copy_(inputTensors[i][j], true); diff --git a/torch_npu/csrc/framework/utils/NpuUtils.cpp b/torch_npu/csrc/framework/utils/NpuUtils.cpp index 8a8d3031df..a573f25d66 100644 --- a/torch_npu/csrc/framework/utils/NpuUtils.cpp +++ b/torch_npu/csrc/framework/utils/NpuUtils.cpp @@ -357,7 +357,7 @@ namespace at_npu if (index == 1) { C10_NPU_CHECK(aclrtGetDevice(&deviceId)); - c10_npu::NPUCachingAllocatorFreeDeviceCachedMemory(deviceId); + c10_npu::NPUCachingAllocator::FreeDeviceCachedMemory(deviceId); return true; } AT_ERROR("NPU out of memory. 
device id: ", deviceId); diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index e622f556a2..079984ac6a 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -171,7 +171,7 @@ PyObject * THNPModule_setStream_wrap(PyObject *self, PyObject *obj) PyObject * THNPModule_emptyCache(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS - c10_npu::NPUCachingAllocatoremptyCache(); + c10_npu::NPUCachingAllocator::emptyCache(); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -182,10 +182,10 @@ PyObject * THNPModule_memoryStats(PyObject *_unused, PyObject *arg) THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to memory_allocated"); const int device = (int) THPUtils_unpackLong(arg); - using c10_npu::NPUCachingAllocatorStatType; - using c10_npu::NPUCachingAllocatorStat; - using c10_npu::NPUCachingAllocatorStatArray; - using c10_npu::NPUCachingAllocatorDeviceStats_; + using c10_npu::NPUCachingAllocator::StatType; + using c10_npu::NPUCachingAllocator::Stat; + using c10_npu::NPUCachingAllocator::StatArray; + using c10_npu::NPUCachingAllocator::DeviceStats_; const auto statToDict = [](const Stat& stat) { py::dict dict; @@ -208,7 +208,7 @@ PyObject * THNPModule_memoryStats(PyObject *_unused, PyObject *arg) return dict; }; - const DeviceStats_ stats = c10_npu::NPUCachingAllocatorgetDeviceStats(device); + const DeviceStats_ stats = c10_npu::NPUCachingAllocator::getDeviceStats(device); py::dict result; result["num_alloc_retries"] = stats.num_alloc_retries; @@ -231,7 +231,7 @@ PyObject * THNPModule_resetAccumulatedMemoryStats(PyObject *_unused, PyObject *a HANDLE_TH_ERRORS THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_accumulated_memory_stats"); const int device = (int) THPUtils_unpackLong(arg); - c10_npu::NPUCachingAllocatorresetAccumulatedStats(device); + c10_npu::NPUCachingAllocator::resetAccumulatedStats(device); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -241,7 +241,7 @@ PyObject * THNPModule_resetPeakMemoryStats(PyObject *_unused, PyObject *arg) HANDLE_TH_ERRORS THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to reset_peak_memory_stats"); const int device = (int) THPUtils_unpackLong(arg); - c10_npu::NPUCachingAllocatorresetPeakStats(device); + c10_npu::NPUCachingAllocator::resetPeakStats(device); END_HANDLE_TH_ERRORS Py_RETURN_NONE; } @@ -250,8 +250,8 @@ PyObject * THNPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS - using c10_npu::NPUCachingAllocatorSegmentInfo; - using c10_npu::NPUCachingAllocatorBlockInfo; + using c10_npu::NPUCachingAllocator::SegmentInfo; + using c10_npu::NPUCachingAllocator::BlockInfo; const auto segmentInfoToDict = [](const SegmentInfo& segmentInfo) { py::dict segmentDict; @@ -274,7 +274,7 @@ PyObject * THNPModule_memorySnapshot(PyObject *_unused, PyObject *noargs) return segmentDict; }; - const std::vector& snapshot = c10_npu::NPUCachingAllocatorsnapshot(); + const std::vector& snapshot = c10_npu::NPUCachingAllocator::snapshot(); py::list result; for (const auto& segmentInfo : snapshot) { @@ -300,7 +300,7 @@ PyObject * THNPModule_npuCachingAllocator_raw_alloc(PyObject *_unused, PyObject } ssize_t size = PyLong_AsSsize_t(size_o); aclrtStream stream = static_cast(PyLong_AsVoidPtr(stream_o)); - void* mem = c10_npu::NPUCachingAllocatorraw_alloc_with_stream(size, stream); + void* mem = c10_npu::NPUCachingAllocator::raw_alloc_with_stream(size, stream); return PyLong_FromVoidPtr(mem); END_HANDLE_TH_ERRORS } @@ -308,7 +308,7 @@ PyObject * 
THNPModule_npuCachingAllocator_raw_alloc(PyObject *_unused, PyObject PyObject * THNPModule_npuCachingAllocator_raw_delete(PyObject *_unused, PyObject *obj){ HANDLE_TH_ERRORS void* mem_ptr = PyLong_AsVoidPtr(obj); - c10_npu::NPUCachingAllocatorraw_delete(mem_ptr); + c10_npu::NPUCachingAllocator::raw_delete(mem_ptr); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -322,7 +322,7 @@ static PyGILState_STATE npuMutexGILState; PyObject * THNPModule_npuLockMutex(PyObject *module, PyObject *noargs) { - auto mutex = c10_npu::NPUCachingAllocatorgetFreeMutex(); + auto mutex = c10_npu::NPUCachingAllocator::getFreeMutex(); // This has to be a busy loop because we **absolutely need to** hold the GIL // or it's a recipe for a deadlock otherwise (if we let other Python threads // run while we have the cudaMutex, but not the GIL, they might try to e.g. @@ -343,7 +343,7 @@ PyObject * THNPModule_npuLockMutex(PyObject *module, PyObject *noargs) PyObject * THNPModule_npuUnlockMutex(PyObject *module, PyObject *noargs) { - auto mutex = c10_npu::NPUCachingAllocatorgetFreeMutex(); + auto mutex = c10_npu::NPUCachingAllocator::getFreeMutex(); PyGILState_Release(npuMutexGILState); mutex->unlock(); Py_RETURN_NONE; -- Gitee From edb14aa885cc9f2f1bf687c39f0c2a8943c4833e Mon Sep 17 00:00:00 2001 From: "zhousinan@huawei.com" Date: Thu, 17 Feb 2022 10:06:23 +0800 Subject: [PATCH 6/6] add #include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" --- torch_npu/csrc/framework/utils/NpuUtils.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_npu/csrc/framework/utils/NpuUtils.h b/torch_npu/csrc/framework/utils/NpuUtils.h index 4c291abd5f..3f00946e47 100644 --- a/torch_npu/csrc/framework/utils/NpuUtils.h +++ b/torch_npu/csrc/framework/utils/NpuUtils.h @@ -18,7 +18,6 @@ #define __PULGIN_NATIVE_NPU_UTILS_NUP_UTILS__ #include -#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include #include #include @@ -26,10 +25,12 @@ #include "third_party/acl/inc/acl/acl.h" #include "third_party/acl/inc/acl/acl_base.h" -#include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" #include "third_party/acl/inc/acl/acl_op.h" #include "third_party/acl/inc/ge/ge_error_codes.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" + using std::string; using std::vector; -- Gitee
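As a quick orientation for reviewers, the sketch below shows how the public entry points added by this series might be exercised from C++ after the namespace refactor in patches 4 and 5. It is illustrative only: the device index, the sizes, and the surrounding function are assumptions, not part of the patches, and error handling is omitted.

    #include <tuple>
    #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"

    void allocator_smoke_test() {
      namespace alloc = c10_npu::NPUCachingAllocator;

      // Raw allocation on the current stream; raw_delete returns it to the cache.
      void* buf = alloc::raw_alloc(1024);
      alloc::raw_delete(buf);

      // Two adjacent buffers carved out of one rounded allocation; each DataPtr
      // releases its piece through the caching allocator when it goes out of scope.
      auto pair = alloc::allocate_adjacent(512, 2048);
      c10::DataPtr pre = std::move(std::get<0>(pair));
      c10::DataPtr next = std::move(std::get<1>(pair));

      // Per-device statistics; index 0 is the StatType::AGGREGATE slot.
      alloc::DeviceStats_ stats = alloc::getDeviceStats(0);
      int64_t live_bytes = stats.allocated_bytes[0].current;
      (void)live_bytes;

      // Release cached-but-unused blocks back to the device.
      alloc::emptyCache();
    }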