From 2f0188ed8831c1923188446147eb8a8991337d32 Mon Sep 17 00:00:00 2001 From: xiao-daipeng Date: Tue, 14 Jun 2022 15:56:24 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AE=97=E5=AD=90SignBitsPack=20fixbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_sign_bits_pack.py | 6 +++--- torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp | 9 ++++----- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/test/test_network_ops/test_sign_bits_pack.py b/test/test_network_ops/test_sign_bits_pack.py index b3acdbd849..087d3228f2 100644 --- a/test/test_network_ops/test_sign_bits_pack.py +++ b/test/test_network_ops/test_sign_bits_pack.py @@ -35,10 +35,10 @@ class TestLess(TestCase): def test_sign_bits_pack(self): shape_format = [ - [[np.float16, (16,)], 2], + [[np.float16, (17,)], 1], [[np.float32, (8,)], 1], [[np.float32, (32,)], 1], - [[np.float32, (32,)], 2], + [[np.float32, (33,)], 1], [[np.float32, (16,)], 2] ] @@ -48,7 +48,7 @@ class TestLess(TestCase): cpu_output = self.cpu_op_exec(input1, item[1]) npu_output = self.npu_op_exec(npu_input1, item[1]) - self.assertRtolEqual(cpu_output.astype(item[0][0]), npu_output.astype(item[0][0])) + self.assertRtolEqual(cpu_output, npu_output) if __name__ == "__main__": diff --git a/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp b/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp index a79b93797c..7fe4904e1a 100644 --- a/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp @@ -35,14 +35,13 @@ at::Tensor& sign_bits_pack_npu_nocheck( } at::Tensor NPUNativeFunctions::npu_sign_bits_pack(const at::Tensor& self, int64_t size) { - TORCH_CHECK(self.dim() == 1 && self.numel()%(size*8) == 0, - "input must be one-dimensional and size of input can be divisible by size*8"); + TORCH_CHECK(self.dim() == 1, "input must be one-dimensional"); TORCH_CHECK(self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::Float, "all only supports torch.float16 and torch.float32 dtypes"); + auto ysize = (self.numel() + 7) / 8; + TORCH_CHECK(size != 0 && ysize % size == 0, "all must be divisible by size"); + at::Tensor result = OpPreparation::ApplyTensor({size, ysize / size}, self.options().dtype(at::kByte), self); - // construct the output tensor of the NPU - at::Tensor result = OpPreparation::ApplyTensor({size, self.numel()/(size*8)}, self.options().dtype(at::kByte), self); - // calculate the output result of the NPU sign_bits_pack_npu_nocheck(result, self, size); -- Gitee