From 5757fed5db3b27c2d2e90ae09e372feef92e5534 Mon Sep 17 00:00:00 2001 From: chow_chow <57070221@qq.com> Date: Wed, 16 Dec 2020 22:12:23 +0800 Subject: [PATCH] MinCon Bug Fix Beijing --- MindCon/1214/Custom Operator GPU.md | 262 ++++++++++++++++++++++++++++ 1 file changed, 262 insertions(+) create mode 100644 MindCon/1214/Custom Operator GPU.md diff --git a/MindCon/1214/Custom Operator GPU.md b/MindCon/1214/Custom Operator GPU.md new file mode 100644 index 0000000000..a0d45d2d7a --- /dev/null +++ b/MindCon/1214/Custom Operator GPU.md @@ -0,0 +1,262 @@ +# Custom Operator (GPU) + +## Overview +Operators are the basic elements for building neural networks. When the built-in operators cannot meet the requirements of network under development, a GPU operator can be implemented with MindSpore easily. + +- Primitive registration: primitives operators are the basic unit for building network models. Users can directly or indirectly call primitive operators to build a neural network model. + +- GPU Kernel implementation: GPU Kernel is used to call GPU to implement accelerated computation. + +- GPU Kernel registration: Operator registration is used to register GPU Kernel and necessary information to the framework, and the framework completes the call to the GPU Kernel. + +In this tutorial, a TensorAddV2 operator will be develop with C++ and CUDA in the MindSpore framework. TensorAddV2 is used to add two Tensors of the same dimension element by element. + +## Primitive Operator Registration + +A primitives operator usually includes: + +- Operator name: the operator name is used to uniquely identify an operator + +- Comment: describes the algorithm and usage constraints of the operator. The comment will be exported as the MindSpore API documentation for developers to consult. + +- Input: input Tensor of the operator. + +- Attribute: general description of algorithm parameters, for example, `data_format` in Conv2d describes the input data in `NCHW` or `NHWC` format. + +- Input data compliance check: checks input data and attributes. It is for the convenience of developers to locate problems in the network model as early as possible. + +- Output data type and dimension derivation: derives the output data type and dimension. + +An operator named TensorAddV2 is defined in the following code: + +- `TensorAddV2` inherits from `PrimitiveWithInfer`. + +- The `__init__` constructor is used to initialize the operator. TensorAddV2 has no attributes, so there is no additional input in `__init__`. + +- In the `infer_shape` method, the dimension of two inputs must be the same, and the of dimension of output is the same as as that of x1 + +- The `infer_dtype` method restricts that the two input data must be of float32 type, and the type of output data is the same as that of input data . + +``` +# mindspore/ops/operations/math_ops.py +class TensorAddV2(PrimitiveWithInfer): + """ + Adds two input tensors element-wise. + """ + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y']) + + def infer_shape(self, x1_shape, x2_shape): + validator.check_integer('input dims', len(x1_shape), len(x2_shape), Rel.EQ, self.name) + for i in range(len(x1_shape)): + validator.check_integer('input_shape', x1_shape[i], x2_shape[i], Rel.EQ, self.name) + return x1_shape + + def infer_dtype(self, x1_dtype, x2_type): + validator.check_tensor_type_same({'x1_dtype': x1_dtype}, [mstype.float32], self.name) + validator.check_tensor_type_same({'x2_dtype': x2_dtype}, [mstype.float32], self.name) + return x1_dtype +``` + +Then export the TensorAddV2 type in `__init__.py` for developers to import and use in the network. + +``` +# mindspore/ops/operations/__init__.py +from .math_ops import (Abs, ACos, ..., TensorAddV2) +... +... +__all__ = [ + 'ReverseSequence', + 'CropAndResize', + ..., + 'TensorAddV2' +] +``` + +## Define Back Propagationion Function of the Operator + +It is a necessity to define a back propagation function (bprop) in the primitive operator If it support automatic differentiation. the bprop must define how the reverse calculation logic gets its input gradient from the forward input, forward output, and output gradient. The reverse calculation logic can be constructed by using built-in operators or customized reverse operators. + +Pay attention to the following points when defining back propagation function of the operator: + +- The input parameter order of the bprop function is defined as forward input, forward output, and output gradient. If the operator has multiple outputs, the forward output and output gradient will be provided in the form of tuples. + +- The return values of the bprop function is defined to be a tuple of input gradients, and the order of the elements in the tuple is consistent with the order of the forward input parameters. Even if there is only one input gradient, the return value must be in the form of a tuple. + +For example, the reverse primitive of TensorAddV2 is: + +``` +import mindspore.ops as ops +@bprop_getters.register(ops.TensorAddV2) +def get_bprop_tensoraddv2(self): + """Generate bprop for TensorAddV2""" + + def bprop(x1, x2, out, dout): + return dout, dout + + return bprop +``` + +## GPU Operator Development + +GPU custom operators are inherited from `GPUKernel`: + +- `Init()`: initializes the GPU Kernel, records the input/output dimensions of the operator, and prepares for the launch. + +- `GetInputSizeList()`: reports the input Tensor's requirement of GPU memory (in bytes) to MindSpore. + +- `GetOutputSizeList()`: reports the output Tensor's requirement of GPU memory (in bytes) to MindSpore. + +- `GetWorkspaceSizeList()`: reports the bytes number of `Workspace` to MindSpore. `Workspace` is the space used to store temporary data during the calculation. + +- `Launch()`: usually calls CUDA kernel (CUDA kernel is the kernel function based on the parallel computing architecture of Nvidia GPU), or cuDNN interface, etc., to implement the acceleration of operator on the GPUs. + +The following code shows the implementation of TensorAddV2: In order to support the generalization of data types, we use class templates to define `TensorAddV2GpuKernel`: + +- `Init()`: records the number of elements of Tensor + +`GetInputSizeList()`: returns the number of bytes that the input Tensor needs to occupy. TensorAddV2 has two Inputs, and each Input occupies the space of element_num * sizeof(T). + +- `GetOutputSizeList()` : returns the number of bytes that the output Tensor needs to occupy. TensorAddV2 has an output that occupies element_num * sizeof(T) bytes. + +- Since TensorAddV2 does not require `Workspace`, `GetWorkspaceSizeList()` returns an empty `std::vector`. + - `Launch()`: receives the address of input and output in the GPU memory, and then calls `TensorAddV2` to accelerate. + +``` +// mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.h + +template +class TensorAddV2GpuKernel : public GpuKernel { + public: + TensorAddV2GpuKernel() : element_num_(1) {} + ~TensorAddV2GpuKernel() override = default; + + bool Init(const CNodePtr &kernel_node) override { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < shape.size(); i++) { + element_num_ *= shape[i]; + } + InitSizeLists(); + return true; + } + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, void *stream_ptr) override { + T *x1 = GetDeviceAddress(inputs, 0); + T *x2 = GetDeviceAddress(inputs, 1); + T *y = GetDeviceAddress(outputs, 0); + + TensorAddV2(element_num_, x1, x2, y, reinterpret_cast(stream_ptr)); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(element_num_ * sizeof(T)); + input_size_list_.push_back(element_num_ * sizeof(T)); + output_size_list_.push_back(element_num_ * sizeof(T)); + } + + private: + size_t element_num_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +``` + +The CUDA kernel `TensorAddV2Kernel` is called in `TensorAddV2` to realize the parallel addition of `element_num` elements: + +``` +// mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.h + + __global__ void TensorAddV2Kernel(const size_t element_num, const T* x1, const T* x2, T* y) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < element_num; i += blockDim.x * gridDim.x) { + y[i] = x1[i] + x2[i]; + } + } + + template + void TensorAddV2(const size_t &element_num, const T* x1, const T* x2, T* y, cudaStream_t stream){ + size_t thread_per_block = 256; + size_t block_per_grid = (element_num + thread_per_block - 1 ) / thread_per_block; + TensorAddV2Kernel<<>>(element_num, x1, x2, y); + return; + } + + template void TensorAddV2(const size_t &element_num, const float* x1, const float* x2, float* y, cudaStream_t stream); +``` + +## GPU Operator Registration + +Operator information includes: + +- Primitive + +- Input dtype, output dtype + +- GPU Kernel class + +- CUDA build-in dtype + +MindSpore will call the `CUDA build-in dtype` to instantiate the `GPU Kernel class` template class according to `Primive`, `Input dtype`, and `output dtype`. + +The following code shows the TensorAddV2 operators that support float and int data types. + +``` +// mindspore/ccsrc/backend/kernel_compiler/gpu/math/tensor_add_v2_gpu_kernel.cc + +MS_REG_GPU_KERNEL_ONE(TensorAddV2, KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + TensorAddV2GpuKernel, float) + +MS_REG_GPU_KERNEL_ONE(TensorAddV2, KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + TensorAddV2GpuKernel, int) +``` + +## Compile MindSpore + +MindSpore needs re-compilation and installation for the custom GPU operator. For details, please refer to the [installation document](https://gitee.com/mindspore/docs/blob/master/install/mindspore_gpu_install_source.md#). + +## Operator Verification + +the following codes verifies the TensorAddV2 operator with a single operator network: + +``` +# tests/st/ops/gpu/test_tensoraddv2_op.py + +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(device_target='GPU') + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_TensroAdd(): + x1 = Tensor(np.ones((3, 4), np.float32)) + x2 = Tensor(np.ones((3, 4), np.float32)) + y = P.TensorAddV2()(x1, x2) + print('result: ', y) +``` + +There shoould be result is as by the `pytest -s tests/st/ops/gpu/test_tensoraddv2_op.py` command: + +``` +result: [[2. 2. 2. 2.] + [2. 2. 2. 2.] + [2. 2. 2. 2.]] +``` + + -- Gitee