diff --git a/test/test_network_ops/test_sign_bits_pack.py b/test/test_network_ops/test_sign_bits_pack.py index b3acdbd8490b3c9f19d54ebfcea79188929d841d..087d3228f23694f9c90fb6d15e14a2b9cb284082 100644 --- a/test/test_network_ops/test_sign_bits_pack.py +++ b/test/test_network_ops/test_sign_bits_pack.py @@ -35,10 +35,10 @@ class TestLess(TestCase): def test_sign_bits_pack(self): shape_format = [ - [[np.float16, (16,)], 2], + [[np.float16, (17,)], 1], [[np.float32, (8,)], 1], [[np.float32, (32,)], 1], - [[np.float32, (32,)], 2], + [[np.float32, (33,)], 1], [[np.float32, (16,)], 2] ] @@ -48,7 +48,7 @@ class TestLess(TestCase): cpu_output = self.cpu_op_exec(input1, item[1]) npu_output = self.npu_op_exec(npu_input1, item[1]) - self.assertRtolEqual(cpu_output.astype(item[0][0]), npu_output.astype(item[0][0])) + self.assertRtolEqual(cpu_output, npu_output) if __name__ == "__main__": diff --git a/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp b/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp index a79b93797cf04d80d705ad4d85145592d42a1aa0..7fe4904e1a99632b3bb1a4fdd854b27b259e7e11 100644 --- a/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp @@ -35,14 +35,13 @@ at::Tensor& sign_bits_pack_npu_nocheck( } at::Tensor NPUNativeFunctions::npu_sign_bits_pack(const at::Tensor& self, int64_t size) { - TORCH_CHECK(self.dim() == 1 && self.numel()%(size*8) == 0, - "input must be one-dimensional and size of input can be divisible by size*8"); + TORCH_CHECK(self.dim() == 1, "input must be one-dimensional"); TORCH_CHECK(self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::Float, "all only supports torch.float16 and torch.float32 dtypes"); + auto ysize = (self.numel() + 7) / 8; + TORCH_CHECK(size != 0 && ysize % size == 0, "all must be divisible by size"); + at::Tensor result = OpPreparation::ApplyTensor({size, ysize / size}, self.options().dtype(at::kByte), self); - // construct the output tensor of the NPU - at::Tensor result = OpPreparation::ApplyTensor({size, self.numel()/(size*8)}, self.options().dtype(at::kByte), self); - // calculate the output result of the NPU sign_bits_pack_npu_nocheck(result, self, size);