From 1525f79597e9350a72db8ff63b374a049358b983 Mon Sep 17 00:00:00 2001 From: PengC Date: Tue, 15 Oct 2024 20:27:15 +0800 Subject: [PATCH] fix antiquant bug --- .../antiquant/ascend_antiquant_impl.h | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/impl/quantization/antiquant/ascend_antiquant_impl.h b/impl/quantization/antiquant/ascend_antiquant_impl.h index 89ccb066..fa5b1b4d 100644 --- a/impl/quantization/antiquant/ascend_antiquant_impl.h +++ b/impl/quantization/antiquant/ascend_antiquant_impl.h @@ -429,8 +429,8 @@ __aicore__ inline void AntiQuantFp16TransposeTailImpl(const LocalTensor &d } } -template -__aicore__ inline void AscendAntiQuantFp16Transpose(const LocalTensor &dst, const LocalTensor &src, +template +__aicore__ inline void AscendAntiQuantFp16Transpose(const LocalTensor &dst, const LocalTensor &src, LocalTensor offset, const LocalTensor &scale, const LocalTensor &sharedTmpBuffer, const uint32_t calCount, const uint32_t K, const AntiQuantShapeInfo& shapeInfo) { @@ -446,7 +446,15 @@ __aicore__ inline void AscendAntiQuantFp16Transpose(const LocalTensor &dst SetMaskCount(); SetVectorMask(0, src.GetSize()); UnaryRepeatParams unaryParams(1, 1, DEFAULT_REPEAT_STRIDE, HALF_DEFAULT_REPEAT_STRIDE); - Cast(dst, src, RoundMode::CAST_NONE, MASK_PLACEHOLDER, 1, unaryParams); + + if constexpr (IsSameType::value) { + UnaryRepeatParams s42f16unaryParams; + s42f16unaryParams.srcRepStride = ONE_FOURTH_DEFAULT_REPEAT_STRIDE; + Cast(dst, src, RoundMode::CAST_NONE, MASK_PLACEHOLDER, 1, s42f16unaryParams); + } else { + Cast(dst, src, RoundMode::CAST_NONE, MASK_PLACEHOLDER, 1, unaryParams); + } + PipeBarrier(); LocalTensor stackBuffer = sharedTmpBuffer.ReinterpretCast(); @@ -481,8 +489,10 @@ __aicore__ inline void AscendAntiQuantImpl(const LocalTensor &ds uint32_t calCount = src.GetSize(); if constexpr (!isTranspose) { return AscendAntiQuantNoTranspose(dst, src, offset, scale, sharedTmpBuffer, calCount, K, shapeInfo); - } else if constexpr (IsSameType::value && IsSameType::value) { - return AscendAntiQuantFp16Transpose(dst, src, offset, scale, sharedTmpBuffer, calCount, K, shapeInfo); + } else if constexpr (IsSameType::value && + (IsSameType::value || IsSameType::value)) { + return AscendAntiQuantFp16Transpose(dst, src, offset, scale, sharedTmpBuffer, calCount, + K, shapeInfo); } AscendAntiQuantTranspose(dst, src, offset, scale, sharedTmpBuffer, K, shapeInfo); } @@ -495,8 +505,10 @@ __aicore__ inline void AscendAntiQuantImpl(const LocalTensor &ds uint32_t calCount = src.GetSize(); if constexpr (!isTranspose) { return AscendAntiQuantNoTranspose(dst, src, scale, sharedTmpBuffer, calCount, K, shapeInfo); - } else if constexpr (IsSameType::value && IsSameType::value) { - return AscendAntiQuantFp16Transpose(dst, src, scale, scale, sharedTmpBuffer, calCount, K, shapeInfo); + } else if constexpr (IsSameType::value && + (IsSameType::value || IsSameType::value)) { + return AscendAntiQuantFp16Transpose(dst, src, scale, scale, sharedTmpBuffer, calCount, + K, shapeInfo); } AscendAntiQuantTranspose(dst, src, scale, sharedTmpBuffer, K, shapeInfo); } -- Gitee