diff --git a/impl/index/arithprogression/arithprogression_common_impl.h b/impl/index/arithprogression/arithprogression_common_impl.h index 27ca2c84ce1d57d93b58abb9c6f0eb8ec188af79..1c0bc44a7ce2ab16098544f203d31ceaeb4ec09d 100644 --- a/impl/index/arithprogression/arithprogression_common_impl.h +++ b/impl/index/arithprogression/arithprogression_common_impl.h @@ -60,9 +60,6 @@ __aicore__ inline void ArithProgressionImpl(const LocalTensor &dstLocal, cons constexpr int32_t BLOCK_NUM = (ONE_BLK_SIZE / sizeof(T)); constexpr int32_t REPEAT_NUM = (ONE_REPEAT_BYTE_SIZE / sizeof(T)); - auto eventIdVToS = GetTPipePtr()->FetchEventID(HardEvent::V_S); - SetFlag(eventIdVToS); - WaitFlag(eventIdVToS); if (count > BLOCK_NUM) { // Generates a basic arithmetic sequence of the BLOCK_NUM length for filling in subsequent arithmetic sequences. GetBaseArithProgression(dstLocal, firstValue, diffValue, BLOCK_NUM); @@ -114,6 +111,9 @@ __aicore__ inline void ArithProgressionImpl(const LocalTensor &dstLocal, cons } } else { // When the length is less than BLOCK_NUM, the arithmetic sequence is generated directly by using a scalar. + auto eventIdVToS = GetTPipePtr()->FetchEventID(HardEvent::V_S); + SetFlag(eventIdVToS); + WaitFlag(eventIdVToS); GetBaseArithProgression(dstLocal, firstValue, diffValue, count); auto eventIdSToV = GetTPipePtr()->FetchEventID(HardEvent::S_V); SetFlag(eventIdSToV);