From 21973b295ee075a5ae5c6106f4c1ad4f389fec48 Mon Sep 17 00:00:00 2001 From: zhanghao0689 Date: Mon, 25 Aug 2025 11:10:56 +0800 Subject: [PATCH] move sync event to else branch --- impl/index/arithprogression/arithprogression_common_impl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/impl/index/arithprogression/arithprogression_common_impl.h b/impl/index/arithprogression/arithprogression_common_impl.h index 27ca2c84..1c0bc44a 100644 --- a/impl/index/arithprogression/arithprogression_common_impl.h +++ b/impl/index/arithprogression/arithprogression_common_impl.h @@ -60,9 +60,6 @@ __aicore__ inline void ArithProgressionImpl(const LocalTensor &dstLocal, cons constexpr int32_t BLOCK_NUM = (ONE_BLK_SIZE / sizeof(T)); constexpr int32_t REPEAT_NUM = (ONE_REPEAT_BYTE_SIZE / sizeof(T)); - auto eventIdVToS = GetTPipePtr()->FetchEventID(HardEvent::V_S); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); if (count > BLOCK_NUM) { // Generates a basic arithmetic sequence of the BLOCK_NUM length for filling in subsequent arithmetic sequences. GetBaseArithProgression(dstLocal, firstValue, diffValue, BLOCK_NUM); @@ -114,6 +111,9 @@ __aicore__ inline void ArithProgressionImpl(const LocalTensor &dstLocal, cons } } else { // When the length is less than BLOCK_NUM, the arithmetic sequence is generated directly by using a scalar. + auto eventIdVToS = GetTPipePtr()->FetchEventID(HardEvent::V_S); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); GetBaseArithProgression(dstLocal, firstValue, diffValue, count); auto eventIdSToV = GetTPipePtr()->FetchEventID(HardEvent::S_V); SetFlag(eventIdSToV); -- Gitee