// Patch summary (commentary only; the patch text below is unchanged).
//
// Files touched:
//   comm/lcal/src/ascendc_kernels/allreduce_big_data.h
//   comm/lcal/src/ascendc_kernels/lccl_op.h
//
// allreduce_big_data.h
//   * Init: inside the existing `if constexpr(!std::is_same_v...)` branch that
//     calls BuildScaleOffset, the `input` and `output` arguments are now also
//     stored in two new members (`GM_ADDR input`, `GM_ADDR output`, both
//     defaulting to nullptr).
//   * New method SupportBigScale(), added just before the `private:` label.
//     Under the same `if constexpr` guard:
//       - the block with blockIdx == 0 binds inputGt to `input` (as __gm__ U*)
//         and outputGt to `output` (as __gm__ T*), calls
//         CpGM2GMWithScale(len, inputGt, outputGt, COPYONLY), then calls
//         sync.SetSyncFlag(magic, 0, blockNum * bigScaleFlagOffset, rank)
//         with bigScaleFlagOffset == 2;
//       - every other block calls sync.WaitSyncFlag with the same arguments.
//
// lccl_op.h (macro defining LcalAllReduce_##type##suffix)
//   * Adds `scaleCountMax = 12 * 1024 * 1024`.
//   * The final quant branch is split: CLASS_OP_QUANT_LAUNCH(AllReduceBigData,
//     half, int8_t) is used only when `scaleCount * rankSize <= scaleCountMax`.
//     Otherwise a temporary `opTmp` is constructed, Init(KERNELS_ARGS_CALL())
//     and SupportBigScale() are called on it, `input` is reassigned to
//     `output`, and CLASS_OP_LAUNCH(AllReduceBigData, half) is issued.
//
// NOTE(review): this copy of the patch is collapsed onto two physical lines and
//   several template argument lists look stripped (`static_cast(op)`,
//   `std::is_same_v)`, `AllReduceBigData opTmp(...)`) - presumably an
//   extraction artifact. Restore line breaks and template arguments from the
//   real commit before attempting to apply it.
// NOTE(review): the trailing `return;` in the void SupportBigScale() is
//   redundant.
// NOTE(review): the types of `scaleCount` and `rankSize` are not visible here;
//   confirm `scaleCount * rankSize` cannot overflow before it is compared with
//   the int32_t `scaleCountMax`.
// NOTE(review): in SupportBigScale only blockIdx 0 performs the copy while all
//   other blocks wait on the flag - presumably intentional; confirm this is
//   acceptable for the large-scale path.
// NOTE(review): `input = output;` overwrites `input` for the rest of that
//   branch; what `input` is declared as is not visible here - confirm nothing
//   after the launch still needs the original pointer.
diff --git a/comm/lcal/src/ascendc_kernels/allreduce_big_data.h b/comm/lcal/src/ascendc_kernels/allreduce_big_data.h index f8ce0276dfeecea54c8c36b130768685827c3405..f8d7c9d534ebac259da99f4834847371cee3adaf 100644 --- a/comm/lcal/src/ascendc_kernels/allreduce_big_data.h +++ b/comm/lcal/src/ascendc_kernels/allreduce_big_data.h @@ -30,6 +30,8 @@ public: DumpLcclLogInfo(LogId::INIT, static_cast(op)); if constexpr(!std::is_same_v) { BuildScaleOffset(scale, scaleCount, offset); + this->input = input; + this->output = output; } if (blockIdx >= PING_PONG_SIZE * rankSize) { @@ -124,6 +126,22 @@ public: } DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); } + + FORCE_INLINE_AICORE void SupportBigScale() + { + if constexpr(!std::is_same_v) { + constexpr int32_t bigScaleFlagOffset = 2; + if (blockIdx == 0) { + inputGt.SetGlobalBuffer((__gm__ U*)input); + outputGt.SetGlobalBuffer((__gm__ T*)output); + CpGM2GMWithScale(len, inputGt, outputGt, COPYONLY); + sync.SetSyncFlag(magic, 0, blockNum * bigScaleFlagOffset, rank); + } else { + sync.WaitSyncFlag(magic, 0, blockNum * bigScaleFlagOffset, rank); + } + } + return; + } private: FORCE_INLINE_AICORE void Producer() { @@ -251,6 +269,8 @@ private: T offset = 0; bool isEnableScale = false; bool isVectorScale = false; + GM_ADDR input = nullptr; + GM_ADDR output = nullptr; }; #endif // LCCL_ALLREDUCE_BIG_DATA_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/lccl_op.h b/comm/lcal/src/ascendc_kernels/lccl_op.h index bf54ce2b1bf05482a6b4e133b52a5d5e2db231a0..115c2690960279d4c967ff93e9fea2e6708c630f 100644 --- a/comm/lcal/src/ascendc_kernels/lccl_op.h +++ b/comm/lcal/src/ascendc_kernels/lccl_op.h @@ -129,6 +129,7 @@ extern "C" __global__ __aicore__ void LcalAllReduce_##type##suffix(KERNELS_ARGS_ constexpr int32_t cceSmallDataSize = 2 * 1024 * 1024; \ constexpr int32_t smallDataSize910a3 = 32 * 1024 * 1024; \ constexpr int32_t rankSize910a3 = 16; \ + constexpr int32_t scaleCountMax = 12 * 1024 * 1024; \ __gm__ type 
* shareAddrs[LCAL_MAX_RANK_SIZE]; \ GET_IPC_MEM_ARGS(type); \ if ((extraFlag & ExtraFlag::TOPO_PCIE) != 0) { \ @@ -142,8 +143,14 @@ extern "C" __global__ __aicore__ void LcalAllReduce_##type##suffix(KERNELS_ARGS_ CLASS_OP_QUANT_LAUNCH(AllReduceOneShot, half, int8_t); \ } else if (len * sizeof(type) <= quantSmallDataSize) { \ CLASS_OP_QUANT_LAUNCH(AllReduceTwoShot, half, int8_t); \ - } else { \ + } else if (scaleCount * rankSize <= scaleCountMax) { \ CLASS_OP_QUANT_LAUNCH(AllReduceBigData, half, int8_t); \ + } else { \ + AllReduceBigData opTmp(localRank, localRankSize, extraFlag); \ + opTmp.Init(KERNELS_ARGS_CALL()); \ + opTmp.SupportBigScale(); \ + input = output; \ + CLASS_OP_LAUNCH(AllReduceBigData, half); \ } \ } else if ((extraFlag & ExtraFlag::TOPO_910B2C) != 0 && rankSize > smallRankSize) { \ if (len * sizeof(type) < cceSmallDataSize) { \