From 286195f57ad3fdbe3c2bcc7638d0c5c4ee85b28f Mon Sep 17 00:00:00 2001
From: zhang_xu_hao1230
Date: Thu, 14 Aug 2025 17:09:02 +0800
Subject: [PATCH] torch_npu supports dlpack
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torch_npu/csrc/aten/common/DLConvertor.cpp | 304 +++++++++++++++++++++
 torch_npu/csrc/aten/common/DLConvertor.h   |  20 ++
 torch_npu/csrc/aten/common/dlpack.h        | 236 ++++++++++++++++
 torch_npu/csrc/npu/Module.cpp              |  66 +++++
 torch_npu/utils/__init__.py                |   1 +
 torch_npu/utils/dlpack.py                  | 118 ++++++++
 torch_npu/utils/tensor_methods.py          |   4 +
 7 files changed, 749 insertions(+)
 create mode 100644 torch_npu/csrc/aten/common/DLConvertor.cpp
 create mode 100644 torch_npu/csrc/aten/common/DLConvertor.h
 create mode 100644 torch_npu/csrc/aten/common/dlpack.h
 create mode 100644 torch_npu/utils/dlpack.py

diff --git a/torch_npu/csrc/aten/common/DLConvertor.cpp b/torch_npu/csrc/aten/common/DLConvertor.cpp
new file mode 100644
index 00000000000..5583a0bb1fd
--- /dev/null
+++ b/torch_npu/csrc/aten/common/DLConvertor.cpp
@@ -0,0 +1,304 @@
+#include
+#include "torch_npu/csrc/aten/common/from_blob.h"
+#include "torch_npu/csrc/aten/common/DLConvertor.h"
+
+using namespace std;
+namespace at {
+
+DLDataType getDLDataType(const Tensor& t)
+{
+    DLDataType dtype;
+    dtype.lanes = 1;
+    dtype.bits = t.element_size() * 8;
+    switch (t.scalar_type()) {
+        case ScalarType::UInt1:
+        case ScalarType::UInt2:
+        case ScalarType::UInt3:
+        case ScalarType::UInt4:
+        case ScalarType::UInt5:
+        case ScalarType::UInt6:
+        case ScalarType::UInt7:
+        case ScalarType::Byte:
+        case ScalarType::UInt16:
+        case ScalarType::UInt32:
+        case ScalarType::UInt64:
+            dtype.code = DLDataTypeCode::kDLUInt;
+            break;
+        case ScalarType::Int1:
+        case ScalarType::Int2:
+        case ScalarType::Int3:
+        case ScalarType::Int4:
+        case ScalarType::Int5:
+        case ScalarType::Int6:
+        case ScalarType::Int7:
+        case ScalarType::Char:
+            dtype.code = DLDataTypeCode::kDLInt;
+            break;
+        // NOLINTNEXTLINE(bugprone-branch-clone)
+        case ScalarType::Double:
+            dtype.code = DLDataTypeCode::kDLFloat;
+            break;
+        case ScalarType::Float:
+            dtype.code = DLDataTypeCode::kDLFloat;
+            break;
+        // NOLINTNEXTLINE(bugprone-branch-clone)
+        case ScalarType::Int:
+            dtype.code = DLDataTypeCode::kDLInt;
+            break;
+        case ScalarType::Long:
+            dtype.code = DLDataTypeCode::kDLInt;
+            break;
+        case ScalarType::Short:
+            dtype.code = DLDataTypeCode::kDLInt;
+            break;
+        case ScalarType::Half:
+            dtype.code = DLDataTypeCode::kDLFloat;
+            break;
+        case ScalarType::Bool:
+            dtype.code = DLDataTypeCode::kDLBool;
+            break;
+        case ScalarType::ComplexHalf:
+        case ScalarType::ComplexFloat:
+        case ScalarType::ComplexDouble:
+            dtype.code = DLDataTypeCode::kDLComplex;
+            break;
+        case ScalarType::BFloat16:
+            dtype.code = DLDataTypeCode::kDLBfloat;
+            break;
+        case ScalarType::Float8_e5m2:
+        case ScalarType::Float8_e5m2fnuz:
+        case ScalarType::Float8_e4m3fn:
+        case ScalarType::Float8_e4m3fnuz:
+        case ScalarType::Float8_e8m0fnu:
+            TORCH_CHECK(false, "float8 types are not supported by dlpack");
+            break;
+        case ScalarType::QInt8:
+        case ScalarType::QUInt8:
+        case ScalarType::QInt32:
+        case ScalarType::QUInt4x2:
+        case ScalarType::QUInt2x4:
+            TORCH_CHECK(false, "QUInt/QInt types are not supported by dlpack");
+            break;
+        case ScalarType::Bits1x8:
+        case ScalarType::Bits2x4:
+        case ScalarType::Bits4x2:
+        case ScalarType::Bits8:
+        case ScalarType::Bits16:
+            TORCH_CHECK(false, "Bit types are not supported by dlpack");
+            break;
+        case ScalarType::Undefined:
TORCH_CHECK(false, "Undefined is not a valid ScalarType"); + case ScalarType::NumOptions: + TORCH_CHECK(false, "NumOptions is not a valid ScalarType"); + } + return dtype; +} + +static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) +{ + DLDevice ctx; + ctx.device_id = static_cast(static_cast(device_id)); + switch (tensor.device().type()) { + case DeviceType::CPU: + ctx.device_type = DLDeviceType::kDLCPU; + break; + case DeviceType::PrivateUse1: + ctx.device_type = DLDeviceType::kDLExtDev; + break; + default: + TORCH_CHECK(false, "Cannot pack tensors on " + tensor.device().str()); + } + return ctx; +} + +static Device getATenDevice(const DLDevice& ctx, void* data) +{ + switch (ctx.device_type) { + case DLDeviceType::kDLCPU: + return at::Device(DeviceType::CPU); + case DLDeviceType::kDLExtDev: + return at::Device(DeviceType::PrivateUse1, static_cast(ctx.device_id)); + default: + TORCH_CHECK( + false, "Unsupported device_type: ", std::to_string(ctx.device_type)); + } +} + +ScalarType toScalarType(const DLDataType& dtype) +{ + ScalarType stype = ScalarType::Undefined; + TORCH_CHECK(dtype.lanes == 1, "ATen does not support lanes != 1"); + switch (dtype.code) { + case DLDataTypeCode::kDLUInt: + switch (dtype.bits) { + case 8: + stype = ScalarType::Byte; + break; + case 16: + stype = ScalarType::UInt16; + break; + case 32: + stype = ScalarType::UInt32; + break; + case 64: + stype = ScalarType::UInt64; + break; + default: + TORCH_CHECK( + false, "Unsupported kUInt bits ", std::to_string(dtype.bits)); + } + break; + case DLDataTypeCode::kDLInt: + switch (dtype.bits) { + case 8: + stype = ScalarType::Char; + break; + case 16: + stype = ScalarType::Short; + break; + case 32: + stype = ScalarType::Int; + break; + case 64: + stype = ScalarType::Long; + break; + default: + TORCH_CHECK( + false, "Unsupported kInt bits ", std::to_string(dtype.bits)); + } + break; + case DLDataTypeCode::kDLFloat: + switch (dtype.bits) { + case 16: + stype = ScalarType::Half; + break; + case 32: + stype = ScalarType::Float; + break; + case 64: + stype = ScalarType::Double; + break; + default: + TORCH_CHECK( + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); + } + break; + case DLDataTypeCode::kDLBfloat: + switch (dtype.bits) { + case 16: + stype = ScalarType::BFloat16; + break; + default: + TORCH_CHECK( + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); + } + break; + case DLDataTypeCode::kDLComplex: + switch (dtype.bits) { + case 32: + stype = ScalarType::ComplexHalf; + break; + case 64: + stype = ScalarType::ComplexFloat; + break; + case 128: + stype = ScalarType::ComplexDouble; + break; + default: + TORCH_CHECK( + false, "Unsupported kFloat bits ", std::to_string(dtype.bits)); + } + break; + case DLDataTypeCode::kDLBool: + switch (dtype.bits) { + case 8: + stype = ScalarType::Bool; + break; + default: + TORCH_CHECK( + false, "Unsupported kDLBool bits ", std::to_string(dtype.bits)); + } + break; + default: + TORCH_CHECK(false, "Unsupported code ", std::to_string(dtype.code)); + } + return stype; +} + +namespace { +struct ATenDLMTensor +{ + Tensor handle; + DLManagedTensor tensor{}; +}; +} // namespace + +static void deleter(DLManagedTensor* arg) +{ + delete static_cast(arg->manager_ctx); +} + +// This function returns a shared_ptr to memory managed DLpack tensor +// constructed out of ATen tensor +DLManagedTensor* toDLPack(const Tensor& src) +{ + // create a new tensor with possibly normalized strides + // gh-83069 + auto shape = src.sizes(); + auto strides = 
+    for (int i = 0; i < src.dim(); i++) {
+        if (shape[i] < 2) {
+            strides[i] = 1;
+        }
+    }
+
+    auto view = src.as_strided(shape, strides, src.storage_offset());
+    ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
+    atDLMTensor->handle = view;
+    atDLMTensor->tensor.manager_ctx = atDLMTensor;
+    atDLMTensor->tensor.deleter = &deleter;
+    atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
+    c10::DeviceIndex device_id = 0;
+    if (src.is_cuda() || src.is_privateuseone()) {
+        device_id = src.get_device();
+    }
+    atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
+    atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
+    atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
+    atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
+    atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
+    atDLMTensor->tensor.dl_tensor.byte_offset = 0;
+    return &(atDLMTensor->tensor);
+}
+
+Tensor fromDLPack(DLManagedTensor* src)
+{
+    auto deleter = [src](void* self [[maybe_unused]]) {
+        if (src->deleter) {
+            src->deleter(src);
+        }
+    };
+    return fromDLPack(src, std::move(deleter));
+}
+
+Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter)
+{
+    Device device = getATenDevice(src->dl_tensor.device, src->dl_tensor.data);
+    ScalarType stype = toScalarType(src->dl_tensor.dtype);
+    if (!src->dl_tensor.strides) {
+        return at_npu::native::from_blob(
+            src->dl_tensor.data,
+            IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
+            std::move(deleter),
+            at::device(device).dtype(stype),
+            {device});
+    }
+    return at_npu::native::from_blob(
+        src->dl_tensor.data,
+        IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
+        IntArrayRef(src->dl_tensor.strides, src->dl_tensor.ndim),
+        deleter,
+        at::device(device).dtype(stype),
+        {device});
+}
+} // namespace at
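Note: the dtype mapping in getDLDataType()/toScalarType() above is easiest to sanity-check from Python once the _npu_to_dlpack/_npu_from_dlpack bindings added later in this patch are in place. The snippet below is an illustrative sketch only, not part of the patch; it assumes the patch is built into torch_npu, that an NPU device is reachable as "npu", and that the listed dtypes (including torch.float8_e4m3fn) exist in your PyTorch build.

    import torch
    import torch_npu
    from torch_npu.utils import to_dlpack, from_dlpack

    # dtypes handled by getDLDataType()/toScalarType() survive the round trip unchanged
    for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.int64, torch.bool):
        x = torch.zeros(2, 3, dtype=dtype, device="npu")
        assert from_dlpack(to_dlpack(x)).dtype == dtype

    # dtypes rejected by the switch above (float8, quantized, bit types) fail at export time
    try:
        to_dlpack(torch.zeros(2, dtype=torch.float8_e4m3fn))
    except RuntimeError as err:
        print("float8 export rejected:", err)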
diff --git a/torch_npu/csrc/aten/common/DLConvertor.h b/torch_npu/csrc/aten/common/DLConvertor.h
new file mode 100644
index 00000000000..1690820b3ab
--- /dev/null
+++ b/torch_npu/csrc/aten/common/DLConvertor.h
@@ -0,0 +1,20 @@
+#pragma once
+
+#include
+#include
+#include "torch_npu/csrc/aten/common/dlpack.h"
+
+// this convertor will:
+// 1) take a Tensor object and wrap it in the DLPack tensor
+// 2) take a dlpack tensor and convert it to the ATen Tensor
+
+namespace at {
+
+TORCH_API ScalarType toScalarType(const DLDataType& dtype);
+TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
+TORCH_API Tensor fromDLPack(DLManagedTensor* src);
+TORCH_API Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter);
+TORCH_API DLDataType getDLDataType(const Tensor& t);
+TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
+
+} // namespace at
diff --git a/torch_npu/csrc/aten/common/dlpack.h b/torch_npu/csrc/aten/common/dlpack.h
new file mode 100644
index 00000000000..6f8e03dd570
--- /dev/null
+++ b/torch_npu/csrc/aten/common/dlpack.h
@@ -0,0 +1,236 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file dlpack.h
+ * \brief The common header of DLPack.
+ */
+#ifndef DLPACK_DLPACK_H_
+#define DLPACK_DLPACK_H_
+
+/**
+ * \brief Compatibility with C++
+ */
+#ifdef __cplusplus
+#define DLPACK_EXTERN_C extern "C"
+#else
+#define DLPACK_EXTERN_C
+#endif
+
+/*! \brief The current version of dlpack */
+#define DLPACK_VERSION 80
+
+/*! \brief The current ABI version of dlpack */
+#define DLPACK_ABI_VERSION 1
+
+/*! \brief DLPACK_DLL prefix for windows */
+#ifdef _WIN32
+#ifdef DLPACK_EXPORTS
+#define DLPACK_DLL __declspec(dllexport)
+#else
+#define DLPACK_DLL __declspec(dllimport)
+#endif
+#else
+#define DLPACK_DLL
+#endif
+
+// NOLINTNEXTLINE(modernize-deprecated-headers)
+#include <stdint.h>
+// NOLINTNEXTLINE(modernize-deprecated-headers)
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*!
+ * \brief The device type in DLDevice.
+ */
+#ifdef __cplusplus
+typedef enum : int32_t {
+#else
+typedef enum {
+#endif
+  /*! \brief CPU device */
+  kDLCPU = 1,
+  /*! \brief CUDA GPU device */
+  kDLCUDA = 2,
+  /*!
+   * \brief Pinned CUDA CPU memory by cudaMallocHost
+   */
+  kDLCUDAHost = 3,
+  /*! \brief OpenCL devices. */
+  kDLOpenCL = 4,
+  /*! \brief Vulkan buffer for next generation graphics. */
+  kDLVulkan = 7,
+  /*! \brief Metal for Apple GPU. */
+  kDLMetal = 8,
+  /*! \brief Verilog simulator buffer */
+  kDLVPI = 9,
+  /*! \brief ROCm GPUs for AMD GPUs */
+  kDLROCM = 10,
+  /*!
+   * \brief Pinned ROCm CPU memory allocated by hipMallocHost
+   */
+  kDLROCMHost = 11,
+  /*!
+   * \brief Reserved extension device type,
+   * used for quickly test extension device
+   * The semantics can differ depending on the implementation.
+   */
+  kDLExtDev = 12,
+  /*!
+   * \brief CUDA managed/unified memory allocated by cudaMallocManaged
+   */
+  kDLCUDAManaged = 13,
+  /*!
+   * \brief Unified shared memory allocated on a oneAPI non-partitioned
+   * device. Call to oneAPI runtime is required to determine the device
+   * type, the USM allocation type and the sycl context it is bound to.
+   *
+   */
+  kDLOneAPI = 14,
+  /*! \brief GPU support for next generation WebGPU standard. */
+  kDLWebGPU = 15,
+  /*! \brief Qualcomm Hexagon DSP */
+  kDLHexagon = 16,
+  /*! \brief Microsoft AI Accelerator */
+  kDLMAIA = 17,
+} DLDeviceType;
+
+/*!
+ * \brief A Device for Tensor and operator.
+ */
+typedef struct {
+  /*! \brief The device type used in the device. */
+  DLDeviceType device_type;
+  /*!
+   * \brief The device index.
+   * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
+   */
+  int32_t device_id;
+} DLDevice;
+
+/*!
+ * \brief The type code options DLDataType.
+ */
+typedef enum {
+  /*! \brief signed integer */
+  kDLInt = 0U,
+  /*! \brief unsigned integer */
+  kDLUInt = 1U,
+  /*! \brief IEEE floating point */
+  kDLFloat = 2U,
+  /*!
+   * \brief Opaque handle type, reserved for testing purposes.
+   * Frameworks need to agree on the handle data type for the exchange to be well-defined.
+   */
+  kDLOpaqueHandle = 3U,
+  /*! \brief bfloat16 */
+  kDLBfloat = 4U,
+  /*!
+   * \brief complex number
+   * (C/C++/Python layout: compact struct per complex number)
+   */
+  kDLComplex = 5U,
+  /*! \brief boolean */
+  kDLBool = 6U,
+} DLDataTypeCode;
+
+/*!
+ * \brief The data type the tensor can hold. The data type is assumed to follow the
+ * native endian-ness. An explicit error message should be raised when attempting to
+ * export an array with non-native endianness
+ *
+ * Examples
+ *  - float: type_code = 2, bits = 32, lanes = 1
+ *  - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4
+ *  - int8: type_code = 0, bits = 8, lanes = 1
+ *  - std::complex<float>: type_code = 5, bits = 64, lanes = 1
+ *  - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
+ */
+typedef struct {
+  /*!
+   * \brief Type code of base types.
+   * We keep it uint8_t instead of DLDataTypeCode for minimal memory
+   * footprint, but the value should be one of DLDataTypeCode enum values.
+   * */
+  uint8_t code;
+  /*!
+   * \brief Number of bits, common choices are 8, 16, 32.
+   */
+  uint8_t bits;
+  /*! \brief Number of lanes in the type, used for vector types. */
+  uint16_t lanes;
+} DLDataType;
+
+/*!
+ * \brief Plain C Tensor object, does not manage memory.
+ */
+typedef struct {
+  /*!
+   * \brief The data pointer points to the allocated data. This will be CUDA
+   * device pointer or cl_mem handle in OpenCL. It may be opaque on some device
+   * types. This pointer is always aligned to 256 bytes as in CUDA. The
+   * `byte_offset` field should be used to point to the beginning of the data.
+   *
+   * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow,
+   * TVM, perhaps others) do not adhere to this 256 byte alignment requirement
+   * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed
+   * (after which this note will be updated); at the moment it is recommended
+   * to not rely on the data pointer being correctly aligned.
+   *
+   * For given DLTensor, the size of memory required to store the contents of
+   * data is calculated as follows:
+   *
+   * \code{.c}
+   * static inline size_t GetDataSize(const DLTensor* t) {
+   *   size_t size = 1;
+   *   for (tvm_index_t i = 0; i < t->ndim; ++i) {
+   *     size *= t->shape[i];
+   *   }
+   *   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
+   *   return size;
+   * }
+   * \endcode
+   */
+  void* data;
+  /*! \brief The device of the tensor */
+  DLDevice device;
+  /*! \brief Number of dimensions */
+  int32_t ndim;
+  /*! \brief The data type of the pointer*/
+  DLDataType dtype;
+  /*! \brief The shape of the tensor */
+  const int64_t* shape;
+  /*!
+   * \brief strides of the tensor (in number of elements, not bytes)
+   * can be NULL, indicating tensor is compact and row-majored.
+   */
+  const int64_t* strides;
+  /*! \brief The offset in bytes to the beginning pointer to data */
+  uint64_t byte_offset;
+} DLTensor;
+
+/*!
+ * \brief C Tensor object, manage memory of DLTensor. This data structure is
+ *  intended to facilitate the borrowing of DLTensor by another framework. It is
+ *  not meant to transfer the tensor. When the borrowing framework doesn't need
+ *  the tensor, it should call the deleter to notify the host that the resource
+ *  is no longer needed.
+ */
+typedef struct DLManagedTensor {
+  /*! \brief DLTensor which is being memory managed */
+  DLTensor dl_tensor;
+  /*! \brief the context of the original host framework of DLManagedTensor in
+   *   which DLManagedTensor is used in the framework. It can also be NULL.
+   */
+  void * manager_ctx;
+  /*! \brief Destructor signature void (*)(void*) - this should be called
+   *   to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
+   *   if there is no way for the caller to provide a reasonable destructor.
+   *   The destructor deletes the argument self as well.
+   */
+  void (*deleter)(struct DLManagedTensor * self);
+} DLManagedTensor;
+#ifdef __cplusplus
+}  // DLPACK_EXTERN_C
+#endif
+#endif  // DLPACK_DLPACK_H_
diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp
index dd0f006a174..05a5bc03d77 100644
--- a/torch_npu/csrc/npu/Module.cpp
+++ b/torch_npu/csrc/npu/Module.cpp
@@ -17,9 +17,11 @@
 #include
 #include
 #include
+#include
 #include "torch_npu/csrc/aten/NPUGeneratorImpl.h"
+#include "torch_npu/csrc/aten/common/DLConvertor.h"
 #include "torch_npu/csrc/aten/common/SetNpu.h"
 #include "torch_npu/csrc/core/npu/NPUException.h"
 #include "torch_npu/csrc/core/npu/NPUFunctions.h"
@@ -1866,6 +1868,68 @@ PyObject* THNPModule_aclop_stop_dump(PyObject* self, PyObject* noargs)
     Py_RETURN_NONE;
     END_HANDLE_TH_ERRORS
 }
+static void DLPack_Capsule_Destructor(PyObject* data)
+{
+    if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor"))) {
+        // early out, see DLPack spec: if a consuming library sets the capsule
+        // name to something else, they own it and we don't need to do anything
+        return;
+    }
+    HANDLE_TH_ERRORS
+    // Causes overheads for validity checks again, but this case is rare
+    // since consuming libraries should rename the capsule according to spec.
+    // Note that this cannot set a python error (we checked validity above),
+    // so we don't need to handle python error state here.
+    DLManagedTensor* dlMTensor =
+        (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
+    // the dlMTensor has not been consumed, call deleter ourselves.
+    // DLPack spec mentions that deleter may be NULL, but deleter from
+    // `at::toDLPack` is never NULL, so no need for an additional check here.
+    dlMTensor->deleter(dlMTensor);
+    END_HANDLE_TH_ERRORS_RET()
+}
+
+static PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data)
+{
+    HANDLE_TH_ERRORS
+    TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor");
+    DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data));
+    return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
+    END_HANDLE_TH_ERRORS
+}
+
+static PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data)
+{
+    using namespace torch::autograd;
+    HANDLE_TH_ERRORS
+    DLManagedTensor* dlMTensor =
+        (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
+    TORCH_CHECK(
+        dlMTensor,
+        "from_dlpack received an invalid capsule. "
+        "Note that DLTensor capsules can be consumed only once, "
+        "so you might have already constructed a tensor from it once.");
+
+    auto deleter_with_gil = [dlMTensor](void*) {
+        if (dlMTensor->deleter) {
+            pybind11::gil_scoped_acquire gil;
+            dlMTensor->deleter(dlMTensor);
+        }
+    };
+
+    // atensor steals the ownership of the underlying storage. It also passes a
+    // destructor function that will be called when the underlying storage goes
+    // out of scope. When the destructor is called, the dlMTensor is destructed
+    // too.
+    // HACK: Ensure that we hold the GIL here just in case the
+    // managed tensor originates from a buggy NumPy build.
+    auto atensor = at::fromDLPack(dlMTensor, std::move(deleter_with_gil));
+
+    // Make sure this capsule will never be used again.
+    PyCapsule_SetName(data, "used_dltensor");
+    return THPVariable_Wrap(atensor);
+    END_HANDLE_TH_ERRORS
+}

 static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_init", (PyCFunction)THNPModule_initExtension, METH_NOARGS, nullptr},
@@ -1937,6 +2001,8 @@ static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_reset_device_res_limit", (PyCFunction)THNPModule_reset_device_res_limit, METH_O, nullptr},
     {"_aclop_start_dump", (PyCFunction)THNPModule_aclop_start_dump, METH_O, nullptr},
     {"_aclop_stop_dump", (PyCFunction)THNPModule_aclop_stop_dump, METH_NOARGS, nullptr},
+    {"_npu_to_dlpack", (PyCFunction)THPModule_toDLPack, METH_O, nullptr},
+    {"_npu_from_dlpack", (PyCFunction)THPModule_fromDLPack, METH_O, nullptr},
     {nullptr}};

 TORCH_NPU_API PyMethodDef* THNPModule_get_methods()
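Note: the two bindings registered above implement the one-shot capsule contract from the DLPack spec: _npu_to_dlpack returns a capsule named "dltensor", and _npu_from_dlpack consumes it and renames it to "used_dltensor" so it cannot be imported twice; a capsule dropped without being consumed is cleaned up by DLPack_Capsule_Destructor, which runs the DLManagedTensor deleter. A minimal sketch, illustrative only and not part of the patch, assuming an NPU device and a torch_npu build containing this change:

    import torch
    import torch_npu
    from torch_npu._C import _npu_to_dlpack, _npu_from_dlpack

    src = torch.arange(6, device="npu").reshape(2, 3)
    capsule = _npu_to_dlpack(src)       # PyCapsule named "dltensor"

    dst = _npu_from_dlpack(capsule)     # consumes the capsule, renames it "used_dltensor"
    dst[0, 0] = -1                      # dst aliases src's NPU storage
    assert src[0, 0].item() == -1

    try:
        _npu_from_dlpack(capsule)       # a consumed capsule cannot be imported again
    except RuntimeError as err:
        print(err)                      # "from_dlpack received an invalid capsule. ..."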
diff --git a/torch_npu/utils/__init__.py b/torch_npu/utils/__init__.py
index 41484e05708..057dc933d17 100644
--- a/torch_npu/utils/__init__.py
+++ b/torch_npu/utils/__init__.py
@@ -20,6 +20,7 @@ from .flops_count import _FlopsCounter as FlopsCounter
 from .affinity import _set_thread_affinity as set_thread_affinity
 from .affinity import _reset_thread_affinity as reset_thread_affinity
 from ._graph_tree import _apply_npugraph_tree_methods
+from .dlpack import to_dlpack, from_dlpack


 # init flopcount
diff --git a/torch_npu/utils/dlpack.py b/torch_npu/utils/dlpack.py
new file mode 100644
index 00000000000..18ba225c65d
--- /dev/null
+++ b/torch_npu/utils/dlpack.py
@@ -0,0 +1,118 @@
+import enum
+import torch
+
+from typing import Any
+from torch_npu._C import _npu_from_dlpack
+from torch_npu._C import _npu_to_dlpack as to_dlpack
+
+
+class DLDeviceType(enum.IntEnum):
+    # Enums as in DLPack specification (aten/src/ATen/dlpack.h)
+    kDLCPU = 1,
+    kDLGPU = 2,
+    kDLCPUPinned = 3,
+    kDLOpenCL = 4,
+    kDLVulkan = 7,
+    kDLMetal = 8,
+    kDLVPI = 9,
+    kDLROCM = 10,
+    kDLExtDev = 12,
+    kDLOneAPI = 14,
+
+
+torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule
+
+Returns an opaque object (a "DLPack capsule") representing the tensor.
+
+.. note::
+    ``to_dlpack`` is a legacy DLPack interface. The capsule it returns
+    cannot be used for anything in Python other than use it as input to
+    ``from_dlpack``. The more idiomatic use of DLPack is to call
+    ``from_dlpack`` directly on the tensor object - this works when that
+    object has a ``__dlpack__`` method, which PyTorch and most other
+    libraries indeed have now.
+
+.. warning::
+    Only call ``from_dlpack`` once per capsule produced with ``to_dlpack``.
+    Behavior when a capsule is consumed multiple times is undefined.
+
+Args:
+    tensor: a tensor to be exported
+
+The DLPack capsule shares the tensor's memory.
+""")
+
+
+def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
+    """from_dlpack(ext_tensor) -> Tensor
+
+    Converts a tensor from an external library into a ``torch.Tensor``.
+
+    The returned PyTorch tensor will share the memory with the input tensor
+    (which may have come from another library). Note that in-place operations
+    will therefore also affect the data of the input tensor. This may lead to
+    unexpected issues (e.g., other libraries may have read-only flags or
+    immutable data structures), so the user should only do this if they know
+    for sure that this is fine.
+
+    Args:
+        ext_tensor (object with ``__dlpack__`` attribute, or a DLPack capsule):
+            The tensor or DLPack capsule to convert.
+
+            If ``ext_tensor`` is a tensor (or ndarray) object, it must support
+            the ``__dlpack__`` protocol (i.e., have a ``ext_tensor.__dlpack__``
+            method). Otherwise ``ext_tensor`` may be a DLPack capsule, which is
+            an opaque ``PyCapsule`` instance, typically produced by a
+            ``to_dlpack`` function or method.
+
+    Examples::
+
+        >>> import torch.utils.dlpack
+        >>> t = torch.arange(4)
+
+        # Convert a tensor directly (supported in PyTorch >= 1.10)
+        >>> t2 = torch.from_dlpack(t)
+        >>> t2[:2] = -1  # show that memory is shared
+        >>> t2
+        tensor([-1, -1, 2, 3])
+        >>> t
+        tensor([-1, -1, 2, 3])
+
+        # The old-style DLPack usage, with an intermediate capsule object
+        >>> capsule = torch.utils.dlpack.to_dlpack(t)
+        >>> capsule
+        <capsule object "dltensor" at 0x...>
+        >>> t3 = torch.from_dlpack(capsule)
+        >>> t3
+        tensor([-1, -1, 2, 3])
+        >>> t3[0] = -9  # now we're sharing memory between 3 tensors
+        >>> t3
+        tensor([-9, -1, 2, 3])
+        >>> t2
+        tensor([-9, -1, 2, 3])
+        >>> t
+        tensor([-9, -1, 2, 3])
+
+    """
+    if hasattr(ext_tensor, '__dlpack__'):
+        device = ext_tensor.__dlpack_device__()
+        # device is either CUDA or ROCm, we need to pass the current
+        # stream
+        if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
+            stream = torch.cuda.current_stream(f'cuda:{device[1]}')
+            # cuda_stream is the pointer to the stream and it is a public
+            # attribute, but it is not documented
+            # The array API specifies that the default legacy stream must be passed
+            # with a value of 1 for CUDA
+            # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none
+            is_cuda = device[0] == DLDeviceType.kDLGPU
+            # Since pytorch is not using PTDS by default, let's directly pass
+            # the legacy stream
+            stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream
+            dlpack = ext_tensor.__dlpack__(stream=stream_ptr)
+        else:
+            dlpack = ext_tensor.__dlpack__()
+    else:
+        # Old versions just call the converter
+        dlpack = ext_tensor
+    return _npu_from_dlpack(dlpack)
diff --git a/torch_npu/utils/tensor_methods.py b/torch_npu/utils/tensor_methods.py
index f978dfc8790..a9ea27309c8 100644
--- a/torch_npu/utils/tensor_methods.py
+++ b/torch_npu/utils/tensor_methods.py
@@ -86,3 +86,7 @@ def _add_tensor_methods():
     torch.Tensor.type_raw = torch.Tensor.type
     torch.Tensor.type = _npu_type
     torch.Tensor.__reduce_ex__ = _reduce_ex
+    torch.utils.to_dlpack = torch_npu.utils.to_dlpack
+    torch.utils.from_dlpack = torch_npu.utils.from_dlpack
+    torch.utils.dlpack.to_dlpack = torch_npu.utils.to_dlpack
+    torch.utils.dlpack.from_dlpack = torch_npu.utils.from_dlpack
--
Gitee
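Note: once _add_tensor_methods() has run (torch_npu applies it during import), torch.utils.dlpack is rebound to the NPU-aware helpers added above, so existing DLPack call sites keep working. The sketch below is illustrative only, not part of the patch, and assumes an NPU device is available:

    import torch
    import torch_npu  # importing torch_npu applies _add_tensor_methods() and rebinds torch.utils.dlpack

    x = torch.randn(2, 3, device="npu")

    capsule = torch.utils.dlpack.to_dlpack(x)     # now routed to torch_npu.utils.to_dlpack
    y = torch.utils.dlpack.from_dlpack(capsule)   # now routed to torch_npu.utils.from_dlpack

    y.zero_()                                     # y shares x's NPU memory
    assert bool((x == 0).all())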