diff --git a/test/test_network_ops/test_quantize_per_channel.py b/test/test_network_ops/test_quantize_per_channel.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cc0a808e0fed650b78fff056509ffacd74b0276
--- /dev/null
+++ b/test/test_network_ops/test_quantize_per_channel.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestQuantizePerChannel(TestCase):
+    """Check torch.quantize_per_channel on NPU against the CPU result."""
+
+    def generate_data_per_channel(self, min_d, max_d, shape_x, shape_scale, shape_zp, dtype_x, dtype_scale, dtype_zp):
+        """Return random CPU tensors (input, scales, zero_points) for one case.
+
+        Values are drawn with np.random.uniform(min_d, max_d), cast to the given
+        numpy dtypes and wrapped with torch.from_numpy. Scales are additionally
+        replaced by their magnitude, floored at 1e-3, so they are always positive.
+        """
+        input_x = np.random.uniform(min_d, max_d, shape_x).astype(dtype_x)
+        scales = np.random.uniform(min_d, max_d, shape_scale).astype(dtype_scale)
+        # A quantization scale is a step size; the (-1, 1) range passed by the
+        # tests would otherwise produce negative or near-zero scales.
+        scales = np.maximum(np.abs(scales), 1e-3).astype(dtype_scale)
+        # NOTE(review): casting values from (-1, 1) to an integer dtype gives 0
+        # for practically every entry, so non-zero zero points go untested.
+        zero_points = np.random.uniform(min_d, max_d, shape_zp).astype(dtype_zp)
+        npu_input_x = torch.from_numpy(input_x)
+        npu_input_scales = torch.from_numpy(scales)
+        npu_input_zero_points = torch.from_numpy(zero_points)
+        return npu_input_x, npu_input_scales, npu_input_zero_points
+
+    def cpu_op_exec_per_channel(self, input_x, input_scales, input_zero_points, axis, dtype):
+        """Quantize on CPU and return the integer representation as a numpy array."""
+        output = torch.quantize_per_channel(input_x, input_scales, input_zero_points, axis, dtype).int_repr()
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_per_channel(self, input_x, input_scales, input_zero_points, axis, dtype):
+        """Quantize on NPU and return the op output, copied to CPU, as a numpy array."""
+        input_x = input_x.to("npu")
+        input_scales = input_scales.to("npu")
+        input_zero_points = input_zero_points.to("npu")
+        output = torch.quantize_per_channel(input_x, input_scales, input_zero_points, axis, dtype)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_per_channel_3_3_0_int32(self, device):
+        """float32 (3, 3) input quantized along axis 0 to qint32."""
+        input_x1, scales, zero_points = self.generate_data_per_channel(-1, 1, (3, 3), (3,), (3,), np.float32,
+                                                                       np.float32, np.int32)
+        cpu_output1 = self.cpu_op_exec_per_channel(input_x1, scales, zero_points, 0, torch.qint32)
+        npu_output1 = self.npu_op_exec_per_channel(input_x1, scales, zero_points, 0, torch.qint32)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_per_channel_3_3_3_3_1_int8(self, device):
+        """float32 (3, 3) input quantized along axis 1 to qint8; outputs compared as int32."""
+        input_x1, scales, zero_points = self.generate_data_per_channel(-1, 1, (3, 3), (3,), (3,), np.float32,
+                                                                       np.float32, np.int8)
+        cpu_output1 = self.cpu_op_exec_per_channel(input_x1, scales, zero_points, 1, torch.qint8).astype(np.int32)
+        npu_output1 = self.npu_op_exec_per_channel(input_x1, scales, zero_points, 1, torch.qint8).astype(np.int32)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_per_channel_3_3_3_3_3_3_3_3_4_uint8(self, device):
+        """float32 8-D input (every dim 3) quantized along axis 4 to quint8."""
+        input_x1, scales, zero_points = self.generate_data_per_channel(-1, 1, (3, 3, 3, 3, 3, 3, 3, 3), (3,), (3,),
+                                                                       np.float32, np.float32, np.int32)
+        cpu_output1 = self.cpu_op_exec_per_channel(input_x1, scales, zero_points, 4, torch.quint8)
+        npu_output1 = self.npu_op_exec_per_channel(input_x1, scales, zero_points, 4, torch.quint8)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_per_channel_30_30_30_30_30_2_uint8(self, device):
+        """float16 (30, 30, 30, 30) input, axis 2, quint8; the CPU reference uses a float32 copy."""
+        input_x1, scales, zero_points = self.generate_data_per_channel(-1, 1, (30, 30, 30, 30), (30,), (30,),
+                                                                       np.float16, np.float32, np.uint8)
+        input_x1_cpu = input_x1.float()
+        cpu_output1 = self.cpu_op_exec_per_channel(input_x1_cpu, scales, zero_points, 2, torch.quint8)
+        npu_output1 = self.npu_op_exec_per_channel(input_x1, scales, zero_points, 2, torch.quint8)
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+instantiate_device_type_tests(TestQuantizePerChannel, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_quantize_per_tensor.py b/test/test_network_ops/test_quantize_per_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..446e9f8f48e759c66362016c17d61ee3282d3e67
--- /dev/null
+++ b/test/test_network_ops/test_quantize_per_tensor.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestQuantizePerTensor(TestCase):
+    """Check torch.quantize_per_tensor on NPU against the CPU result."""
+
+    def generate_data_per_tensor(self, min_d, max_d, shape_x, dtype_x):
+        """Return a CPU tensor of np.random.uniform(min_d, max_d) samples cast to dtype_x."""
+        samples = np.random.uniform(min_d, max_d, shape_x).astype(dtype_x)
+        return torch.from_numpy(samples)
+
+    def cpu_op_exec_per_tensor(self, input_x, input_scale, input_zero_point, dtype):
+        """Quantize on CPU and return the integer representation as a numpy array."""
+        quantized = torch.quantize_per_tensor(input_x, input_scale, input_zero_point, dtype)
+        return quantized.int_repr().numpy()
+
+    def npu_op_exec_per_tensor(self, input_x, input_scale, input_zero_point, dtype):
+        """Quantize on NPU and return the op output, copied to CPU, as a numpy array."""
+        device_input = input_x.to("npu")
+        quantized = torch.quantize_per_tensor(device_input, input_scale, input_zero_point, dtype)
+        return quantized.to("cpu").numpy()
+
+    def test_per_tensor_3_3_0p1_10_int32(self, device):
+        # float32 (3, 3) input, scale 0.1, zero point 10, qint32.
+        data = self.generate_data_per_tensor(-1, 1, (3, 3), np.float32)
+        expected = self.cpu_op_exec_per_tensor(data, 0.1, 10, torch.qint32)
+        self.assertRtolEqual(expected, self.npu_op_exec_per_tensor(data, 0.1, 10, torch.qint32))
+
+    def test_per_tensor_3_3_0p1_10_int8(self, device):
+        # float16 (3, 3) input, qint8; the CPU reference runs on a float32 copy.
+        data = self.generate_data_per_tensor(-1, 1, (3, 3), np.float16)
+        reference_input = data.float()
+        expected = self.cpu_op_exec_per_tensor(reference_input, 0.1, 10, torch.qint8)
+        self.assertRtolEqual(expected, self.npu_op_exec_per_tensor(data, 0.1, 10, torch.qint8))
+
+    def test_per_tensor_3_3_3_3_3_3_0p1_10_uint8(self, device):
+        # float32 6-D input (every dim 3), scale 0.1, zero point 10, quint8.
+        data = self.generate_data_per_tensor(-1, 1, (3, 3, 3, 3, 3, 3), np.float32)
+        expected = self.cpu_op_exec_per_tensor(data, 0.1, 10, torch.quint8)
+        self.assertRtolEqual(expected, self.npu_op_exec_per_tensor(data, 0.1, 10, torch.quint8))
+
+    def test_per_tensor_30_30_30_30_30_30_0p01_5_uint8(self, device):
+        # NOTE(review): 30**6 = 729,000,000 elements (~5.8 GB as float64 from np.random.uniform).
+        data = self.generate_data_per_tensor(-1, 1, (30, 30, 30, 30, 30, 30), np.float32)
+        expected = self.cpu_op_exec_per_tensor(data, 0.01, 5, torch.quint8)
+        self.assertRtolEqual(expected, self.npu_op_exec_per_tensor(data, 0.01, 5, torch.quint8))
+
+instantiate_device_type_tests(TestQuantizePerTensor, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 1aaf07af260fab4bdf1d5589f6c9a30f3af69f2c..fc911b50b3e816585e3d2dc926c16bc0149a4c34 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1001,7 +1001,6 @@ supported:
   - mkldnn_reorder_conv3d_weight
   - to_mkldnn_backward
   - quantize_per_tensor
-  - quantize_per_tensor.tensors
   - quantize_per_channel
   - dequantize.self
   - dequantize.tensors
diff --git a/torch_npu/csrc/aten/ops/QuantizePerChannelKernelNpu.cpp b/torch_npu/csrc/aten/ops/QuantizePerChannelKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e7248a57b5beace3a103c563e4c5c7f1c2b783b7
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/QuantizePerChannelKernelNpu.cpp
@@ -0,0 +1,94 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Returns a shape that is 1 everywhere except at `axis`, which keeps self's extent.
+// NOTE(review): template arguments restored as <int64_t, SIZE>; SIZE is assumed to be
+// declared by the included framework headers - confirm.
+c10::SmallVector<int64_t, SIZE> quantize_reshape_size(const at::Tensor& self, int64_t axis) {
+  c10::SmallVector<int64_t, SIZE> outSize;
+  for (int64_t i = 0; i < self.dim(); i++) {
+    outSize.emplace_back(i == axis ? self.sizes()[i] : 1);
+  }
+  return outSize;
+}
+
+// Runs the NPU "Quantize" op on `self`, writing into the caller-allocated `result`.
+// `scales` / `zero_points` are reshaped and broadcast to self.sizes() before the call.
+at::Tensor& quantize_per_channel_out_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& scales,
+    const at::Tensor& zero_points,
+    int64_t axis,
+    at::ScalarType dtype) {
+  auto reshapeSize = quantize_reshape_size(self, axis);
+  at::Tensor scales_reshape = scales.reshape(reshapeSize);
+  at::Tensor zp_reshape = zero_points.reshape(reshapeSize);
+  at::Tensor scales_broadcast = NPUNativeFunctions::npu_broadcast(scales_reshape, self.sizes());
+  at::Tensor zp_broadcast = NPUNativeFunctions::npu_broadcast(zp_reshape, self.sizes());
+  string dtypeStr = "torch.qint8";
+  if (dtype == at::ScalarType::QUInt8) {
+    dtypeStr = "torch.quint8";
+  } else if (dtype == at::ScalarType::QInt32) {
+    dtypeStr = "torch.qint32";
+  }
+  OpCommand cmd;
+  cmd.Name("Quantize")
+      .Input(self)
+      .Input(scales_broadcast)
+      .Input(zp_broadcast)
+      .Output(result)
+      .Attr("axis", axis)
+      .Attr("dtype", dtypeStr)
+      .Run();
+  return result;
+}
+
+// Quantizes `self` per channel along `axis` and returns the integer result tensor.
+at::Tensor NPUNativeFunctions::quantize_per_channel(
+    const at::Tensor& self,
+    const at::Tensor& scales,
+    const at::Tensor& zero_points,
+    int64_t axis,
+    at::ScalarType dtype) {
+  // Without this check any other dtype would reach the op as "torch.qint8" with an int32 result.
+  TORCH_CHECK(
+      dtype == at::ScalarType::QInt8 || dtype == at::ScalarType::QUInt8 || dtype == at::ScalarType::QInt32,
+      "quantize_per_channel only supports qint8, quint8 and qint32 dtype.");
+  axis = CalcuOpUtil::make_wrap_dim(axis, self.dim());
+  TORCH_CHECK(scales.dim() == 1, "Scales' dim should be equal to 1.");
+  TORCH_CHECK(zero_points.dim() == 1, "Zero points' dim should be equal to 1.");
+  TORCH_CHECK(scales.sizes()[0] == zero_points.sizes()[0], "Scales' size should be equal to zero points' size.");
+  TORCH_CHECK(scales.sizes()[0] == self.sizes()[axis], "length of scales must equal to the specified dimension.");
+  auto outputSize = input_same_output_size(self);
+  auto outputDtype = at::kInt;
+  if (dtype == at::ScalarType::QInt8) {
+    outputDtype = at::kChar;
+  } else if (dtype == at::ScalarType::QUInt8) {
+    outputDtype = at::kByte;
+  }
+  at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, self.options().dtype(outputDtype));
+  quantize_per_channel_out_nocheck(result, self, scales, zero_points, axis, dtype);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/QuantizePerTensorKernelNpu.cpp b/torch_npu/csrc/aten/ops/QuantizePerTensorKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ce83ad715fb29a4bf1686281e9dede99b0ff4b14
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/QuantizePerTensorKernelNpu.cpp
@@ -0,0 +1,78 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Runs the NPU "Quantize" op on `self` with the given scale / zero-point tensors,
+// writing into the caller-allocated `result`.
+at::Tensor& quantize_per_tensor_out_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& scales,
+    const at::Tensor& zero_points,
+    at::ScalarType dtype) {
+  string dtypeStr = "torch.qint8";
+  if (dtype == at::ScalarType::QUInt8) {
+    dtypeStr = "torch.quint8";
+  } else if (dtype == at::ScalarType::QInt32) {
+    dtypeStr = "torch.qint32";
+  }
+  OpCommand cmd;
+  cmd.Name("Quantize")
+      .Input(self)
+      .Input(scales)
+      .Input(zero_points)
+      .Output(result)
+      .Attr("axis", (int64_t)1)
+      .Attr("dtype", dtypeStr)
+      .Run();
+  return result;
+}
+
+// Quantizes `self` with a single scale and zero point; returns the integer result tensor.
+at::Tensor NPUNativeFunctions::quantize_per_tensor(
+    const at::Tensor& self,
+    double scale,
+    int64_t zero_point,
+    at::ScalarType dtype) {
+  // Without this check any other dtype would reach the op as "torch.qint8" with an int32 result.
+  TORCH_CHECK(
+      dtype == at::ScalarType::QInt8 || dtype == at::ScalarType::QUInt8 || dtype == at::ScalarType::QInt32,
+      "quantize_per_tensor only supports qint8, quint8 and qint32 dtype.");
+  // NOTE(review): cast target restored as float to match the declared type of scaleFloat.
+  float scaleFloat = static_cast<float>(scale);
+  auto outputSize = input_same_output_size(self);
+  auto outputDtype = at::kInt;
+  if (dtype == at::ScalarType::QInt8) {
+    outputDtype = at::kChar;
+  } else if (dtype == at::ScalarType::QUInt8) {
+    outputDtype = at::kByte;
+  }
+  // Scale and zero point are handed to the op as one-element tensors built from self's options.
+  at::Tensor scaleTensor = OpPreparation::ApplyTensorWithSizes({1}, self.options().dtype(at::kFloat));
+  scaleTensor[0] = scaleFloat;
+  at::Tensor zpTensor = OpPreparation::ApplyTensorWithSizes({1}, self.options().dtype(at::kInt));
+  zpTensor[0] = zero_point;
+  at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, self.options().dtype(outputDtype));
+  quantize_per_tensor_out_nocheck(result, self, scaleTensor, zpTensor, dtype);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file