diff --git a/CMakeLists.txt b/CMakeLists.txt
index f1606b572de7f1f0338db54d0e89cbd87af32f66..f71a4d3fa42407cfc5738c565f7f5fd34f73eb24 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -83,9 +83,9 @@ include_directories(
     ${PROJECT_SOURCE_DIR}/include
     ${PROJECT_SOURCE_DIR}/src/include
     ${PROJECT_SOURCE_DIR}/src/kernels/include
-    ${PROJECT_SOURCE_DIR}/src/kernels/include/lcal
-    ${PROJECT_SOURCE_DIR}/src/kernels/include/lcal/lcoc
-    ${PROJECT_SOURCE_DIR}/src/kernels/include/lcal/tiling
+    ${PROJECT_SOURCE_DIR}/comm/lcal/include
+    ${PROJECT_SOURCE_DIR}/comm/lcal/include/lcoc
+    ${PROJECT_SOURCE_DIR}/comm/lcal/include/lcoc/tiling
     ${PROJECT_SOURCE_DIR}/3rdparty/mki/include
     ${PROJECT_SOURCE_DIR}/3rdparty/nlohmannJson/include
     $ENV{ASCEND_HOME_PATH}/include
@@ -117,14 +117,12 @@ if (BUILD_CUSTOMIZE_OPS)
 endif()
 
 set(CMAKE_INSTALL_PREFIX "${CMAKE_SOURCE_DIR}/output/atb/cxx_abi_${cxx_abi}")
-
+add_subdirectory(comm/lcal)
 message(STATUS "CMAKE_INSTALL_PREFIX:${CMAKE_INSTALL_PREFIX}")
 install(FILES ${PROJECT_SOURCE_DIR}/scripts/set_env.sh DESTINATION ./..)
 install(DIRECTORY ${PROJECT_SOURCE_DIR}/ops_configs DESTINATION ./configs)
 install(FILES ${PROJECT_SOURCE_DIR}/3rdparty/mki/lib/libmki.so DESTINATION lib)
-install(FILES ${PROJECT_SOURCE_DIR}/3rdparty/asdops/lib/liblcal.so DESTINATION lib)
-install(FILES ${PROJECT_SOURCE_DIR}/3rdparty/asdops/lib/liblcal_static.a DESTINATION lib)
 install(FILES ${PROJECT_SOURCE_DIR}/3rdparty/asdops/lib/libasdops_aicpu_kernels.so DESTINATION lib OPTIONAL)
 install(FILES ${PROJECT_SOURCE_DIR}/3rdparty/asdops/lib/libtbe_adapter.so DESTINATION lib OPTIONAL)
 install(FILES ${PROJECT_SOURCE_DIR}/3rdparty/asdops/lib/libcann_ops_adapter.so DESTINATION lib OPTIONAL)
diff --git a/comm/lcal/CMakeLists.txt b/comm/lcal/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a5e63434a947b96c2bdc5f67c284b9996cf6dc96
--- /dev/null
+++ b/comm/lcal/CMakeLists.txt
@@ -0,0 +1,60 @@
+#
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+#
+cmake_minimum_required(VERSION 3.12)
+project(Lcal LANGUAGES CXX)
+set(CMAKE_CXX_STANDARD 14)
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
+
+option(USE_CXX11_ABI "USE_CXX11_ABI" 0)
+
+IF (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "aarch64")
+    set(ARCH aarch64)
+ELSEIF (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "x86_64")
+    set(ARCH x86_64)
+ENDIF()
+if(USE_CXX11_ABI)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
+else()
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
+endif()
+
+message("== CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE}")
+message("== CMAKE_SOURCE_DIR ${CMAKE_SOURCE_DIR}")
+
+# Read the toolkit path from the environment
+if (DEFINED ENV{ASCEND_HOME_PATH})
+    set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH})
+else()
+    message("ASCEND_HOME_PATH not set! using default path!")
+    set(ASCEND_HOME_PATH /usr/local/Ascend/ascend-toolkit/latest)
+endif()
+
+message("== ASCEND_HOME_PATH: ${ASCEND_HOME_PATH}")
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations")
+add_link_options(-Wl,-z,relro,-z,now)
+add_link_options(-s)
+
+include_directories(
+    ${CMAKE_CURRENT_LIST_DIR}/include
+    ${CMAKE_CURRENT_LIST_DIR}/include/lcoc/
+    ${CMAKE_CURRENT_LIST_DIR}/include/lcoc/tiling/
+    ${ASCEND_HOME_PATH}/${ARCH}-linux/include/
+    ${ASCEND_HOME_PATH}/${ARCH}-linux/include/hccl/
+    ${ASCEND_HOME_PATH}/${ARCH}-linux/include/experiment
+    ${ASCEND_HOME_PATH}/${ARCH}-linux/include/experiment/runtime/
+    ${ASCEND_HOME_PATH}/${ARCH}-linux/include/experiment/msprof/
+)
+link_directories(${ASCEND_HOME_PATH}/${ARCH}-linux/lib64/)
+
+set(AIV_ARCH dav-c220-vec)
+set(AIC_ARCH dav-c220-cube)
+
+add_subdirectory(src)
diff --git a/comm/lcal/cmake/CMakeCCECompiler.cmake.in b/comm/lcal/cmake/CMakeCCECompiler.cmake.in
new file mode 100644
index 0000000000000000000000000000000000000000..073b728b1a357502ed1d07eea4311f5b96f49dc6
--- /dev/null
+++ b/comm/lcal/cmake/CMakeCCECompiler.cmake.in
@@ -0,0 +1,14 @@
+#
+ # Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ # This file is a part of the CANN Open Software.
+ # Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ # Please refer to the License for details. You may not use this file except in compliance with the License.
+ # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ # See LICENSE in the root of the software repository for the full text of the License.
+ #
+ set(CMAKE_CCE_COMPILER "@CMAKE_CCE_COMPILER@")
+ set(CMAKE_CCE_COMPILER_LOADED 1)
+ set(CMAKE_CCE_OUTPUT_EXTENSION @CMAKE_CCE_OUTPUT_EXTENSION@)
+ set(CMAKE_CCE_COMPILER_ENV_VAR "@CMAKE_CCE_COMPILER_ENV_VAR@")
+ set(CMAKE_CCE_SOURCE_FILE_EXTENSIONS @CMAKE_CCE_SOURCE_FILE_EXTENSIONS@)
diff --git a/comm/lcal/cmake/CMakeCCEInformation.cmake b/comm/lcal/cmake/CMakeCCEInformation.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..e85910f9436127cbbe34b413fe87a60a5c80a0e0
--- /dev/null
+++ b/comm/lcal/cmake/CMakeCCEInformation.cmake
@@ -0,0 +1,28 @@
+#
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+#
+include(CMakeCommonLanguageInclude)
+include(Compiler/CMakeCommonCompilerMacros)
+set(CMAKE_INCLUDE_FLAG_CCE "-I")
+if(UNIX)
+    set(CMAKE_CCE_OUTPUT_EXTENSION .o)
+else()
+    set(CMAKE_CCE_OUTPUT_EXTENSION .obj)
+endif()
+set(CMAKE_DEPFILE_FLAGS_CCE "-MD -MT <DEP_TARGET> -MF <DEP_FILE>")
+set(CMAKE_CCE_DEPFILE_FORMAT gcc)
+set(CMAKE_CCE_DEPENDS_USE_COMPILER TRUE)
+if(NOT CMAKE_CCE_COMPILE_OBJECT)
+    set(CMAKE_CCE_COMPILE_OBJECT
+    "${CMAKE_CCE_COMPILER} -xcce \
+    ${__IMPLICIT_INCLUDES} ${_CMAKE_CCE_BUILTIN_INCLUDE_PATH}\
+    ${_CMAKE_COMPILE_AS_CCE_FLAG} ${_CMAKE_CCE_COMPILE_OPTIONS}\
+    ${_CMAKE_CCE_COMMON_COMPILE_OPTIONS} -o <OBJECT> -c <SOURCE>")
+endif()
+
diff --git a/comm/lcal/cmake/CMakeDetermineCCECompiler.cmake b/comm/lcal/cmake/CMakeDetermineCCECompiler.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..955cf6af2c08ad7504cd5634e5d569bef4153e3c
--- /dev/null
+++ b/comm/lcal/cmake/CMakeDetermineCCECompiler.cmake
@@ -0,0 +1,31 @@
+#
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+#
+set(PRIVATE_CCEC_PATH ${CMAKE_SOURCE_DIR}/3rdparty/compiler)
+find_program(CMAKE_CCE_COMPILER
+    NAMES "ccec"
+    HINTS "${PRIVATE_CCEC_PATH}/ccec_compiler/bin"
+    HINTS "${ASCEND_HOME_PATH}/${ARCH}-linux/ccec_compiler/bin"
+    DOC "CCE Compiler"
+)
+find_program(CMAKE_CCE_LINKER
+    NAMES "ld.lld"
+    HINTS "${PRIVATE_CCEC_PATH}/ccec_compiler/bin"
+    HINTS "${ASCEND_HOME_PATH}/${ARCH}-linux/ccec_compiler/bin"
+    DOC "CCE Linker"
+)
+message(STATUS "CMAKE_CCE_COMPILER: " ${CMAKE_CCE_COMPILER})
+message(STATUS "CMAKE_PLATFORM_INFO_DIR: "${CMAKE_PLATFORM_INFO_DIR})
+configure_file(${CMAKE_CURRENT_LIST_DIR}/CMakeCCECompiler.cmake.in
+    ${CMAKE_PLATFORM_INFO_DIR}/CMakeCCECompiler.cmake
+    @ONLY
+)
+set(CMAKE_CCE_SOURCE_FILE_EXTENSIONS cce;cpp)
+set(CMAKE_CCE_COMPILER_ENV_VAR "CCEC")
+
diff --git a/comm/lcal/cmake/CMakeTestCCECompiler.cmake b/comm/lcal/cmake/CMakeTestCCECompiler.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..1136cb4bcc74e6957deb4e1661dd71bdf4c850cc
--- /dev/null
+++ b/comm/lcal/cmake/CMakeTestCCECompiler.cmake
@@ -0,0 +1,10 @@
+#
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+#
+set(CMAKE_CCE_COMPILER_WORKS 1 CACHE INTERNAL "")
\ No newline at end of file
diff --git a/comm/lcal/include/comm_args.h b/comm/lcal/include/comm_args.h
new file mode 100644
index 0000000000000000000000000000000000000000..ff4f329757e05c03c9767822b088b30e91e9e937
--- /dev/null
+++ b/comm/lcal/include/comm_args.h
@@ -0,0 +1,130 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCCL_COMM_ARGS_H
+#define LCCL_COMM_ARGS_H
+#include <cstdint>
+
+#if !defined(__DAV_C220_VEC__) && !defined(__DAV_C310__) && !defined(__DAV_C220_CUBE__)
+using GM_ADDR = uint8_t*;
+#else
+#define FORCE_INLINE_AICORE __attribute__((always_inline)) inline __aicore__
+#include "kernel_operator.h"
+#endif
+namespace Lcal {
+
+constexpr int LCAL_MAX_RANK_SIZE = 128; // maximum number of NPUs supported by the lcal communication library
+constexpr int RANK_SIZE_TWO = 2; // rank size at which SIO is usable, and the threshold below which cross-card data-move cores are unneeded
+constexpr int64_t IPC_BUFF_MAX_SIZE = 100 * 1024 * 1024;
+constexpr int64_t IPC_DATA_OFFSET = 2 * 1024 * 1024; // the first 2 MB holds sync flags, followed by 100 MB of data storage
+constexpr int64_t SYNC_FLAG_BIT_NUM = 10; // used by the cce operators
+constexpr int64_t MEM_DMA_UNIT_INT_NUM = 4;
+constexpr int64_t EVENT_ID_MASK = 0xFFFFFFFF;
+constexpr int64_t PING_PONG_SIZE = 2;
+constexpr int64_t UB_SINGLE_DMA_SIZE_MAX = 190 * 1024;
+constexpr int64_t SMALL_DATA_SIZE = 1 * 1024 * 1024;
+constexpr int64_t UB_SINGLE_PING_PONG_ADD_SIZE_MAX = UB_SINGLE_DMA_SIZE_MAX / 2;
+constexpr int UB_ALIGN_SIZE = 32;
+
+// in the 2-step algorithm, 2 AIVs actually do data preprocessing
+constexpr int64_t PRE_CORE_REAL_NUM = 2;
+
+constexpr int64_t AIV_PER_AICORE = 2;
+
+constexpr int DFX_COUNT = 50;
+
+constexpr int64_t HALF_NUM = 2;
+
+constexpr int64_t THREE_NUM = 3;
+
+constexpr int64_t FOUR_NUM = 4;
+
+constexpr int64_t VADD_MAX_REPEAT = 255;
+constexpr int64_t VADD_UNIT_BYTE = 256;
+
+// vadd granularity is 256 B and its maximum repeat count is 255; this is their product
+constexpr int64_t MAX_VADD_SIZE = VADD_MAX_REPEAT * VADD_UNIT_BYTE;
+constexpr int64_t BLOCK_UNIT_BYTE = 32;
+constexpr int64_t VADD_UNIT_TO_BLOCK_UNIT_RATIO = VADD_UNIT_BYTE / BLOCK_UNIT_BYTE; // 8
+
+constexpr bool ATOMIC_ENABLE = false;
+
+constexpr int32_t LCCL_DUMP_UINT_SIZE = 1 * 1024 * 1024;
+enum Op : int {
+    COPYONLY = -1,
+    ADD = 0,
+    MUL = 1,
+    MAX = 2,
+    MIN = 3
+};
+
+struct ExtraFlag {
+    static constexpr uint32_t RDMA = 1;
+    static constexpr uint32_t TOPO_910B2C = 1 << 1;
+    static constexpr uint32_t TOPO_910_93 = 1 << 2;
+    static constexpr uint32_t DETERMINISTIC = 1 << 3;
+    static constexpr uint32_t QUANT_FP16 = 1 << 4;
+    static constexpr uint32_t QUANT_FP32 = 1 << 5;
+    static constexpr uint32_t TOPO_910A5 = 1 << 6;
+    static constexpr uint32_t QUANT_DELAY = 1 << 7;
+    static constexpr uint32_t QUANT_CURRENT = 1 << 8;
+    static constexpr uint32_t TOPO_PCIE = 1 << 9;
+    static constexpr uint32_t IS_GREATER_THAN_40_AIV = 1 << 16;
+};
+
+struct CommArgs {
+    int rank = 0; // attr rank_id, global rank
+    int localRank = -1;
+    int rankSize = 0; // global rank size
+    int localRankSize = -1; // number of cards in the full-mesh interconnect
+    uint32_t extraFlag = 0; // 32-bit map; the meaning of each bit is defined just above in this file
+    GM_ADDR peerMems[LCAL_MAX_RANK_SIZE] = {}; // buffers obtained at initialization; all allreduce calls share this parameter
+    /**
+     * @param sendCountMatrix a one-dimensional array of size rankSize * rankSize
+     * e.g. the value of sendCountMatrix[1] corresponds to element [0][1] of the 2-D view:
+     * the number of elements rank 0 sends to rank 1
+     */
+    int64_t sendCountMatrix[LCAL_MAX_RANK_SIZE * LCAL_MAX_RANK_SIZE] = {}; // for all2allv
+    int64_t dfx[DFX_COUNT] = {};
+    GM_ADDR dumpAddr = nullptr;
+    int32_t magics[LCAL_MAX_RANK_SIZE] = {0};
+    uint64_t fftsVal = 0;
+};
+
+struct LcclDumpBlockInfo {
+    uint32_t len = 0;
+    uint32_t core = 0;
+    uint32_t blockNum = 0;
+    uint32_t dumpOffset = 0;
+    uint32_t magic = 0;
+    uint32_t rsv = 0;
+    uint64_t dumpAddr = 0;
+};
+
+struct LcclDumpLogInfo {
+    uint32_t logId = 0;
+    uint32_t blockId = 0;
+    uint64_t syscyc = 0;
+    uint64_t curPc = 0;
+    uint32_t operationType = 0;
+    uint32_t rsv = 0;
+};
+
+union LcclDumpUnion {
+    LcclDumpBlockInfo blockInfo;
+    LcclDumpLogInfo logInfo;
+};
+
+enum LogId : int {
+    OVERALL = 0,
+    INIT,
+    PROCESS
+};
+
+} // namespace Lcal
+#endif // LCCL_COMM_ARGS_H
\ No newline at end of file
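The `sendCountMatrix` field above flattens a `rankSize x rankSize` table in row-major order. A minimal host-side sketch of populating it for all2allv; the uniform `countPerPair` and the helper name are illustrative assumptions, not part of the API:

```cpp
#include "comm_args.h"

// Hypothetical helper: every rank sends countPerPair elements to every other rank.
void FillUniformSendCounts(Lcal::CommArgs &args, int rankSize, int64_t countPerPair)
{
    args.rankSize = rankSize;
    for (int src = 0; src < rankSize; ++src) {
        for (int dst = 0; dst < rankSize; ++dst) {
            // Row-major flattening: [src][dst] -> src * rankSize + dst, i.e. the
            // number of elements rank `src` sends to rank `dst`.
            args.sendCountMatrix[src * rankSize + dst] = countPerPair;
        }
    }
}
```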
diff --git a/comm/lcal/include/lcal.h b/comm/lcal/include/lcal.h
new file mode 100644
index 0000000000000000000000000000000000000000..bc13ba18df3fc5a24de3dda836a74482f20e30af
--- /dev/null
+++ b/comm/lcal/include/lcal.h
@@ -0,0 +1,16 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_H
+#define LCAL_H
+#include "lcal_types.h"
+#include "lcal_comm.h"
+#include "lccl.h"
+#include "lcoc.h"
+#endif // LCAL_H
\ No newline at end of file
diff --git a/comm/lcal/include/lcal_api.h b/comm/lcal/include/lcal_api.h
new file mode 100644
index 0000000000000000000000000000000000000000..582c951479f6158fd74b3305fbea428cd39ffa6b
--- /dev/null
+++ b/comm/lcal/include/lcal_api.h
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_API_H
+#define LCAL_API_H
+
+#include <hccl/hccl_types.h>
+#include <acl/acl.h>
+#include <comm_args.h>
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef void *LcalCommPtr;
+#define LCAL_UNIQUE_ID_BYTES 128
+typedef struct { char internal[LCAL_UNIQUE_ID_BYTES]; } LcalUniqueId;
+
+int LcalGetUniqueId(LcalUniqueId *uniqueId, int commDomain);
+
+int LcalCommInitRankLocal(int rankSize, int rank, LcalCommPtr *comm);
+
+int LcalCommInitRank(LcalUniqueId commId, int rankSize, int rank, LcalCommPtr *comm);
+
+int LcalCommInitRankWithCustDomainSize(int commDomain, int bufferSize, int rankSize, int rank, LcalCommPtr *comm,
+    const bool isEnableAutoMagicNum = false);
+
+int LcalCommInitRankWithDomain(int commDomain, int rankSize, int rank, LcalCommPtr *comm);
+
+int LcalGetCommArgsDev(LcalCommPtr comm, GM_ADDR &commArgsPtr);
+
+int LcalGetCommArgsHost(LcalCommPtr comm, Lcal::CommArgs *&commArgsPtr);
+
+void LcalPrintDFX2Log(LcalCommPtr comm);
+
+int LcalCommInit(int rank, int rankSize, LcalCommPtr *comms);
+
+int LcalCommInitAll(uint32_t ndev, int32_t* devices, LcalCommPtr *comms);
+
+int LcalCommInitThread(int rank, int rankSize, const char *uid, LcalCommPtr *comms);
+
+int LcclAllReduce(void *sendBuf, void *recvBuf, int64_t count, HcclDataType dataType, HcclReduceOp op,
+    LcalCommPtr comm, aclrtStream stream);
+
+int LcclAllGather(void *sendBuf, void *recvBuf, int64_t sendCount, HcclDataType dataType, LcalCommPtr comm,
+    aclrtStream stream);
+
+int LcclReduceScatter(void *sendBuf, void *recvBuf, int64_t recvCount, HcclDataType dataType, HcclReduceOp op,
+    LcalCommPtr comm, aclrtStream stream);
+
+int LcclBroadcast(void *buf, int64_t count, HcclDataType dataType, int root, LcalCommPtr comm,
+    aclrtStream stream);
+
+int LcclCommDestroy(LcalCommPtr comm);
+
+#ifdef __cplusplus
+}
+#endif // __cplusplus
+#endif // LCAL_API_H
\ No newline at end of file
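A hedged end-to-end sketch of the C API above: initialize a communicator, run an allreduce on device buffers, then tear down. Error handling is abbreviated, the buffers are assumed to be pre-allocated on the device, and the stream calls follow the standard ACL runtime:

```cpp
#include <acl/acl.h>
#include "lcal_api.h"

int RunAllReduce(int rank, int rankSize, void *devSendBuf, void *devRecvBuf, int64_t count)
{
    LcalCommPtr comm = nullptr;
    if (LcalCommInitRankLocal(rankSize, rank, &comm) != 0) {
        return -1; // communicator setup failed
    }
    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);
    int ret = LcclAllReduce(devSendBuf, devRecvBuf, count, HCCL_DATA_TYPE_FP16,
                            HCCL_REDUCE_SUM, comm, stream);
    aclrtSynchronizeStream(stream); // the collective is asynchronous on the stream
    aclrtDestroyStream(stream);
    LcclCommDestroy(comm);
    return ret;
}
```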
diff --git a/comm/lcal/include/lcal_comm.h b/comm/lcal/include/lcal_comm.h
new file mode 100644
index 0000000000000000000000000000000000000000..6ec0fbd7e8d9ba061ed9278378fee49fdf797438
--- /dev/null
+++ b/comm/lcal/include/lcal_comm.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_COMM_H
+#define LCAL_COMM_H
+
+#include <string>
+#include <vector>
+
+#include <cstdint>
+#include "lcal_types.h"
+#include "lcal_api.h"
+#include "comm_args.h"
+
+namespace Lcal {
+constexpr int IPC_NAME_SIZE = 65;
+constexpr int SINGLE_MACHINE_910B2C_RANK_SIZE = 16;
+
+class LcalSockExchange;
+class LcalComm {
+public:
+    LcalComm(int rank, int rankSize);
+    LcalComm(int rank, int rankSize, int bufferSize);
+    LcalComm(int rank, int rankSize, int commDomain, int bufferSize, int isEnableMagic);
+    LcalComm(int rank, int rankSize, LcalUniqueId commId);
+    ~LcalComm();
+    LcalComm(const LcalComm &) = delete;
+    LcalComm &operator=(const LcalComm &) = delete;
+    int Init();
+    int InitThread(const std::string &uid = "default");
+    int GetRank() const;
+    int GetRankSize() const;
+    int GetCommSize() const;
+    int GetBufferSize() const;
+    const PhysicalInfo &GetPhysicalInfo() const;
+    GM_ADDR GetCommArgsPtr() const;
+    CommArgs* GetCommArgs();
+    std::string PrintDFX();
+    friend class Lccl;
+    friend class Lcoc;
+    friend class LcclTest;
+
+private:
+    int SetMemoryName(std::string &name);
+    int SetIpcPidSdid(std::string &name, const uint32_t *pids, const int64_t *sdids) const;
+    int OpenIpcMem(const char names[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE]);
+    int GetDev();
+    int GetDevThread(const std::string &uid = "");
+    int EnablePeerAccess();
+    int InitCommMem();
+    int InitCommon();
+    void CloseIpcMem();
+    void FreePeerMem(GM_ADDR &mem) const;
+    int InitMem();
+    int GetSidId(int64_t sdids[LCAL_MAX_RANK_SIZE], int rankSize);
+    int GetPid(uint32_t *pids);
+    int GetName(std::string &name, char names[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE]) const;
+    int SyncCommArgs();
+    int InitDumpAddr();
+
+private:
+    int rank_ = 0; // global rank id
+    int rankSize_ = 0; // global rank size
+    int commSize_ = 0; // local LcalComm size
+    int localRank_ = -1;
+    int localRankSize_ = -1;
+    int devId_ = 0;
+    int64_t magic_ = 1;
+    bool inited_ = false;
+    bool ipcMemInited_ = false;
+    std::string uid_ = {};
+    std::vector<int> devList_ = {};
+    std::vector<int> rankList_ = {};
+    int commDomain_ = {};
+    int bufferSize_ = LCAL_COMM_BUFFER_SIZE;
+
+    // shared ping-pong buffer; the address is allocated on HBM up front, so the host
+    // can read it but must not modify it directly.
+    GM_ADDR peerMem_[LCAL_MAX_RANK_SIZE] = {};
+    PhysicalInfo physicalInfo_ = {};
+    CommArgs commArgs_ = {}; // host side
+    GM_ADDR commArgsPtr_ = nullptr; // device side
+    LcalUniqueId commId_ = {};
+    LcalSockExchange *socketExchange_ = nullptr;
+    bool deterministic_ = false;
+    bool isEnableMsprofOp_ = false;
+    bool isEnableMix_ = false;
+};
+} // Lcal
+
+#endif // LCAL_COMM_H
\ No newline at end of file
diff --git a/comm/lcal/include/lcal_types.h b/comm/lcal/include/lcal_types.h
new file mode 100644
index 0000000000000000000000000000000000000000..104b3622363b8d822a87547e307cc6a3fe4a5270
--- /dev/null
+++ b/comm/lcal/include/lcal_types.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_TYPES_H
+#define LCAL_TYPES_H
+
+#include <cstdint>
+#include <map>
+#include <string>
+
+namespace Lcal {
+constexpr int LCAL_SUCCESS = 0;
+constexpr int LCAL_ERROR_NOT_INITIALIZED = -1;
+constexpr int LCAL_ERROR_MKIRT = -2;
+constexpr int LCAL_ERROR_PARA_CHECK_FAIL = -3;
+constexpr int LCAL_ERROR_INTERNAL = -4;
+constexpr int LCAL_ERROR_TIMEOUT = -5;
+constexpr int LCAL_ERROR_NOT_FOUND = -7;
+constexpr int OUT_OF_DEVICE_MEMORY = -8;
+constexpr int64_t LCAL_INVALID_VALUE = -1;
+
+// shared buffer size; must be changed in lockstep with the constants in the collectives.cce file!!!
+constexpr int LCAL_BUFF_BYTES = 204 * 1024 * 1024;
+constexpr int LCAL_FLAG_BUFF_BYTES = 4 * 1024 * 1024;
+constexpr int LCAL_COMM_BUFFER_SIZE = 200; // unit: MB
+
+enum class ChipName {
+    CHIP_310P3 = 0,
+    CHIP_910B1,
+    CHIP_910B2,
+    CHIP_910B3,
+    CHIP_910B4,
+    CHIP_910B41,
+    CHIP_910B2C,
+    CHIP_910_9391,
+    CHIP_910_9381,
+    CHIP_910_9392,
+    CHIP_910_9382,
+    CHIP_910_9372,
+    CHIP_910_9361,
+    CHIP_910_9362,
+    CHIP_910A5,
+    RESERVED,
+};
+
+enum class PhysicalLink {
+    HCCS = 0,
+    PCIE = 1,
+    RESERVED,
+};
+
+// Holds physical-link and chip-name information.
+struct PhysicalInfo {
+    ChipName chipName = ChipName::RESERVED;
+    PhysicalLink physicalLink = PhysicalLink::RESERVED;
+    uint32_t coreNum = 0;
+};
+
+enum class LcalType {
+    ALL_REDUCE = 1,
+    REDUCE_SCATTER = 2,
+    ALL_GATHER = 3,
+    BROADCAST = 4,
+    ALL2ALL = 5,
+    ALL_REDUCE_910B2C = 6,
+    ALL_GATHER_910B2C = 7,
+    LOCAL_REDUCE = 8,
+    SEND = 9,
+    RECV = 10,
+    ALL2ALL_V_C = 11,
+    GATHER = 12,
+    PURE_MATMUL = 101,
+    MATMUL_ALL_REDUCE = 102,
+    MATMUL_REDUCE_SCATTER = 103,
+    ALL_GATHER_MATMUL = 104,
+    ALL_GATHER_MATMUL_V2 = 105,
+    ALL2ALL_MATMUL = 106,
+    MATMUL_ALL2ALL = 107,
+    ALL_GATHER_MATMUL_REDUCE_SCATTER = 111,
+    BANDWIDTH = 201,
+    ALLTOALLV_ALLGATHER_MATMUL = 305,
+    ALLTOALLVC_ALLGATHER_MATMUL_HIDDEN = 309,
+    MATMUL_REDUCESCATTER_ALLTOALLVC_HIDDEN = 310,
+    LCAL_TYPE_MAX = 311,
+};
+
+const std::map<LcalType, std::string> LCAL_TYPE2NAME = {
+    { LcalType::ALL_REDUCE, "LcalAllReduce" },
+    { LcalType::REDUCE_SCATTER, "LcalReduceScatter" },
+    { LcalType::ALL_GATHER, "LcalAllGather" },
+    { LcalType::BROADCAST, "LcalBroadcast" },
+    { LcalType::PURE_MATMUL, "LcalPureMatmul" },
+    { LcalType::MATMUL_ALL_REDUCE, "LcalMatmulAllReduce" },
+    { LcalType::MATMUL_REDUCE_SCATTER, "LcalMatmulReduceScatter" },
+    { LcalType::ALL_GATHER_MATMUL, "LcalAllGatherMatmul" },
+    { LcalType::ALL_GATHER_MATMUL_V2, "LcalAllGatherMatmulV2" },
+    { LcalType::ALL2ALL_MATMUL, "LcalAll2AllMatmul" },
+    { LcalType::MATMUL_ALL2ALL, "LcalMatmulAll2All" },
+    { LcalType::ALL2ALL, "LcalAll2All" },
+    { LcalType::ALL2ALL_V_C, "LcalAll2AllVC" },
+    { LcalType::ALL_GATHER_MATMUL_REDUCE_SCATTER, "LcalAllGatherMatmulReduceScatter" },
+    { LcalType::BANDWIDTH, "LcalBandwidthTest" },
+    { LcalType::ALL_REDUCE_910B2C, "LcalAllReduce910B2C" },
+    { LcalType::ALL_GATHER_910B2C, "LcalAllGather910B2C" },
+    { LcalType::ALLTOALLV_ALLGATHER_MATMUL, "LcalAllToAllVAllGatherMatmul" },
+    { LcalType::ALLTOALLVC_ALLGATHER_MATMUL_HIDDEN, "LcalAllToAllVAllGatherMatmulHidden" },
+    { LcalType::MATMUL_REDUCESCATTER_ALLTOALLVC_HIDDEN, "LcalMatmulReduceScatterAllToAllVHidden" }
+};
+
+
+} // namespace Lcal
+#endif // LCAL_TYPES_H
diff --git a/comm/lcal/include/lccl.h b/comm/lcal/include/lccl.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9cae2a2d0e71609471bdfae414b6ace64da543b
--- /dev/null
+++ b/comm/lcal/include/lccl.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_LCCL_H
+#define LCAL_LCCL_H
+
+#include <lcal_comm.h>
+
+
+namespace Lcal {
+class Lccl {
+public:
+    Lccl() = delete;
+    explicit Lccl(LcalComm *comm);
+    explicit Lccl(LcalComm &comm);
+    ~Lccl();
+    uint32_t GetBlockNum(LcalType cclType, uint32_t rankSize, int64_t dataSize, int localRankSize, uint32_t extraFlag)
+        const;
+    int AllReduce(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType,
+        HcclReduceOp op = HCCL_REDUCE_SUM, aclrtStream stream = nullptr,
+        HcclDataType outputDataType = HCCL_DATA_TYPE_RESERVED, const void *scale = nullptr, int64_t scaleCount = 0,
+        const void *offset = nullptr) const;
+    int ReduceScatter(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType,
+        HcclReduceOp op = HCCL_REDUCE_SUM, aclrtStream stream = nullptr) const;
+    int AllGather(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream) const;
+    int All2All(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream) const;
+    int All2All(void *sendBuff, void *recvBuff, int64_t count, int burstLen,
+        int stride, HcclDataType dataType, aclrtStream stream) const;
+    int All2AllVC(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream) const;
+
+    int Broadcast(void *buff, int64_t count, HcclDataType dataType, int32_t root, aclrtStream stream) const;
+    int BandwidthTest(const void *buff, void *recvBuff, int64_t count, HcclDataType dataType,
+        int32_t root, aclrtStream stream) const;
+    friend class LcclTest;
+
+private:
+    bool CheckDataType(const HcclDataType &dataType) const;
+    bool CheckBuff(const void *sendBuff, const void *recvBuff) const;
+    int LoopBack(const void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream) const;
+
+private:
+    LcalComm *comm_ = nullptr;
+    int rank_ = 0;
+    int rankSize_ = 0;
+};
+}
+#endif // LCAL_LCCL_H
\ No newline at end of file
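For C++ callers the same collectives are reachable through the `Lccl` wrapper declared above rather than the C API; a minimal sketch, with the return code checked against `LCAL_SUCCESS` from lcal_types.h:

```cpp
#include "lcal.h"

int RunAllGather(Lcal::LcalComm &comm, void *sendBuf, void *recvBuf,
                 int64_t count, aclrtStream stream)
{
    Lcal::Lccl lccl(comm);  // wraps an already-initialized communicator
    int ret = lccl.AllGather(sendBuf, recvBuf, count, HCCL_DATA_TYPE_FP16, stream);
    return ret == Lcal::LCAL_SUCCESS ? 0 : ret;  // LCAL_SUCCESS == 0
}
```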
diff --git a/comm/lcal/include/lcoc/lcoc.h b/comm/lcal/include/lcoc/lcoc.h
new file mode 100644
index 0000000000000000000000000000000000000000..ea8b413e6f00b126af0bd5f32d7e078d2a56b300
--- /dev/null
+++ b/comm/lcal/include/lcoc/lcoc.h
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_LCOC_H
+#define LCAL_LCOC_H
+
+#include <lcal_types.h>
+#include <lcal_comm.h>
+#include "lcoc_args.h"
+#include "tiling_args.h"
+
+namespace Lcal {
+class Lcoc {
+public:
+    Lcoc() = delete;
+    explicit Lcoc(LcalComm &comm);
+    explicit Lcoc(LcalComm *comm);
+    ~Lcoc();
+    int SetParam(LcalType lcalType, const CoCTiling &tiling, const CoCParamDesc &paramDesc);
+    int AllGatherMatmul(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream = nullptr);
+    int AllGatherMatmulV2(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream = nullptr);
+    int MatmulReduceScatter(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace,
+        aclrtStream stream = nullptr);
+    int MatmulAllReduce(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream = nullptr);
+    int PureMatmul(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream = nullptr);
+    int AllGatherMatmulReduceScatter(CoCInputPkg inputPkg, CoCOutputPkg outputPkg,
+        void *workspace, aclrtStream stream = nullptr);
+    int AllToAllVAllGatherMatmul(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace,
+        aclrtStream stream = nullptr);
+    int AllToAllVAllGatherMatmulHidden(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace,
+        aclrtStream stream = nullptr);
+    int MatmulReduceScatterAllToAllVHidden(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace,
+        aclrtStream stream = nullptr);
+    int64_t GetWorkspaceSize();
+    LcalComm *GetComm();
+    MatMulInfo &GetMatMulInfo();
+    void GetTiling(CoCTiling &tiling);
+
+private:
+    int LaunchOperator(CoCInputPkg &inputPkg, CoCOutputPkg &outputPkg, void *workspace, aclrtStream stream);
+    bool CheckBasic(const CoCInputPkg &inputPkg, const CoCOutputPkg &outputPkg, LcalType lcalType) const;
+    bool CheckInputParam(LcalType lcalType, const CoCTiling &tiling, const CoCParamDesc &paramDesc) const;
+    void SetLcocParam(LcalType lcalType, const CoCParamDesc &paramDesc);
+    void SetTaskParam(LcalType lcalType, const CoCParamDesc &paramDesc, const LcalComm &comm);
+
+private:
+    LcalComm *comm_ = nullptr;
+    CoCTilingData tiling_ = {};
+    TaskParam taskParam_ = {};
+    bool tilingSuccess_ = false;
+};
+}
+#endif // LCAL_LCOC_H
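The header above implies a fixed call sequence for `Lcoc`: construct over a communicator, `SetParam` with an op type, tiling hints, and a param descriptor, size the workspace, then launch. A sketch under assumed shapes; the default-constructed `CoCTiling` (all fields -1) leaves tiling selection to the library:

```cpp
#include "lcal.h"

int RunMatmulAllReduce(Lcal::LcalComm &comm, Lcal::CoCInputPkg in, Lcal::CoCOutputPkg out,
                       void *workspace, aclrtStream stream)
{
    Lcal::Lcoc lcoc(comm);
    Lcal::CoCParamDesc desc;
    desc.mmInfo.m = 1024; // illustrative shapes
    desc.mmInfo.k = 4096;
    desc.mmInfo.n = 2048;
    Lcal::CoCTiling tiling; // all fields -1: derive a default tiling internally
    int ret = lcoc.SetParam(Lcal::LcalType::MATMUL_ALL_REDUCE, tiling, desc);
    if (ret != Lcal::LCAL_SUCCESS) {
        return ret;
    }
    // workspace is expected to hold at least GetWorkspaceSize() bytes of device memory.
    return lcoc.MatmulAllReduce(in, out, workspace, stream);
}
```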
diff --git a/comm/lcal/include/lcoc/lcoc_args.h b/comm/lcal/include/lcoc/lcoc_args.h
new file mode 100644
index 0000000000000000000000000000000000000000..c62b169e3f06fe7ce1a9558a8a74440f4b66f638
--- /dev/null
+++ b/comm/lcal/include/lcoc/lcoc_args.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_LCOC_ARGS_H
+#define LCAL_LCOC_ARGS_H
+
+#include <hccl/hccl_types.h>
+#include <lcal_types.h>
+#include <lcoc_base.h>
+
+constexpr int64_t WORKSPACE_REDUCE_SIZE = 4000000;
+#pragma once
+namespace Lcal {
+    const constexpr int32_t INT8_ELE_SIZE = 1;
+    const constexpr int32_t FP_BF_16_ELE_SIZE = 2;
+    constexpr uint32_t ALIGN_BYTES = 512;
+    constexpr int32_t PARAM_CHECK_MAX_VALUE = -1;
+    constexpr int32_t PARAM_CHECK_MIN_VALUE_ZERO = 0;
+    constexpr int32_t PARAM_CHECK_MIN_VALUE_ONE = 1;
+    constexpr int32_t INPUT_PARAM_DEFAULT_VALUE = -1;
+    constexpr int32_t MAX_M_VALUE = 10000000;
+    constexpr int32_t MAX_K_VALUE = 100000;
+    constexpr int32_t MAX_N_VALUE = 100000;
+
+    enum CoCDataTypeDesc : int {
+        COC_DATA_TYPE_UNDEFINED = -1,
+        FP16FP16_FP32_FP16 = 0, // no quantization, no dequantization
+        BF16BF16_FP32_BF16 = 1, // no quantization, no dequantization
+        INT8INT8_INT32_FP16 = 2, // W8A8, unfused quantization, inline (on-the-fly) dequantization
+        INT8INT8_INT32_BF16 = 3, // W8A8, unfused quantization, AIV dequantization
+        FP16INT8_INT32_FP16 = 4, // W8A8, fused quantization, inline dequantization
+        BF16INT8_INT32_BF16 = 5, // W8A8, fused quantization, AIV dequantization
+        FP16INT8_FP32_FP16 = 6, // W8A16, fused pseudo-quantization, no dequantization
+        BF16INT8_FP32_BF16 = 7, // W8A16, fused pseudo-quantization, no dequantization
+        FP16INT4_FP32_FP16 = 8, // W4A16, fused pseudo-quantization, no dequantization
+        BF16INT4_FP32_BF16 = 9, // W4A16, fused pseudo-quantization, no dequantization
+        COC_DATA_TYPE_DESC_MAX = 10,
+    };
+
+    const std::map<CoCDataTypeDesc, int32_t> COC_TYPE2ELE_SIZE = {
+        { FP16FP16_FP32_FP16, FP_BF_16_ELE_SIZE }, { BF16BF16_FP32_BF16, FP_BF_16_ELE_SIZE },
+        { INT8INT8_INT32_FP16, INT8_ELE_SIZE }, { INT8INT8_INT32_BF16, INT8_ELE_SIZE },
+        { FP16INT8_INT32_FP16, INT8_ELE_SIZE }, { BF16INT8_INT32_BF16, INT8_ELE_SIZE },
+        { FP16INT8_FP32_FP16, FP_BF_16_ELE_SIZE }, { BF16INT8_FP32_BF16, FP_BF_16_ELE_SIZE },
+        { FP16INT4_FP32_FP16, FP_BF_16_ELE_SIZE }, { BF16INT4_FP32_BF16, FP_BF_16_ELE_SIZE }
+    };
+
+    const std::map<CoCDataTypeDesc, HcclDataType> COC_TYPE2HCCL_TYPE = {
+        { FP16FP16_FP32_FP16, HCCL_DATA_TYPE_FP16 }, { BF16BF16_FP32_BF16, HCCL_DATA_TYPE_BFP16 },
+        { INT8INT8_INT32_FP16, HCCL_DATA_TYPE_FP16 }, { INT8INT8_INT32_BF16, HCCL_DATA_TYPE_BFP16 },
+        { FP16INT8_INT32_FP16, HCCL_DATA_TYPE_FP16 }, { BF16INT8_INT32_BF16, HCCL_DATA_TYPE_BFP16 },
+        { FP16INT8_FP32_FP16, HCCL_DATA_TYPE_FP16 }, { BF16INT8_FP32_BF16, HCCL_DATA_TYPE_BFP16 },
+        { FP16INT4_FP32_FP16, HCCL_DATA_TYPE_FP16 }, { BF16INT4_FP32_BF16, HCCL_DATA_TYPE_BFP16 }
+    };
+
+    struct CoCParamDesc {
+        CoCDataTypeDesc dataTypeDesc = FP16FP16_FP32_FP16;
+        MatMulInfo mmInfo = {};
+        QuantInfo quantInfo = {};
+        PostInfo postInfo = {};
+        HcclReduceOp op = HCCL_REDUCE_SUM; // other values are not currently supported
+        TwoDimTPInfo twoDimTPInfo = {};
+        MoeInfo moeInfo = {};
+    };
+
+    struct CoCInputPkg {
+        void *matrixA = nullptr;
+        void *matrixB = nullptr;
+        void *bias = nullptr;
+        void *gamma = nullptr;
+        void *dequantScale = nullptr; // dequant parameter; required when pre-matmul pseudo-quantization or post-matmul dequantization is fused
+        void *dequantOffset = nullptr; // optional; pass nullptr when there is no offset (e.g. symmetric quantization)
+
+        void *quantScale = nullptr; // quant parameter; required when quantization is fused
+        void *quantOffset = nullptr; // optional; pass nullptr when there is no offset (e.g. symmetric quantization)
+        void *num_local_tokens_per_expert = nullptr;
+        void *num_global_tokens_per_local_expert = nullptr;
+        void *global_tokens_per_expert_matrix = nullptr;
+    };
+
+    struct CoCOutputPkg {
+        void *output = nullptr;
+        void *midOutput = nullptr; // intermediate communication result when communication runs before computation
+    };
+
+    struct TaskParam {
+        // hardware info
+        int32_t rank = -1;
+        int32_t rankSize = -1;
+        int32_t blockDim = -1;
+        int32_t bufferSize = -1;
+        ChipName chipName = ChipName::CHIP_910B3;
+        // param info
+        CoCParamDesc cocParamDesc = {};
+
+        // type
+        LcalType lcalType = LcalType::ALL_REDUCE;
+    };
+}
+#endif // LCAL_LCOC_ARGS_H
diff --git a/comm/lcal/include/lcoc/lcoc_base.h b/comm/lcal/include/lcoc/lcoc_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..7bf4321ce53ced51f5cd840e3c2cd29a9a35378c
--- /dev/null
+++ b/comm/lcal/include/lcoc/lcoc_base.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef LCAL_LCOC_BASE_H
+#define LCAL_LCOC_BASE_H
+
+#include <cstdint>
+
+#pragma once
+namespace Lcal {
+enum QuantGranularity : int {
+    QUANT_GRANULARITY_UNDEFINED = -1,
+    PER_TENSOR = 0,
+    PER_CHANNEL = 1,
+    PER_GROUP = 2,
+    PER_TOKEN = 3,
+    FLOAT32_SCALE_PER_CHANNEL = 4,
+    QUANT_GRANULARITY_MAX = 5,
+};
+
+struct MatMulInfo {
+    int64_t batchSize = 1;
+    int64_t m = -1;
+    int64_t k = -1;
+    int64_t n = -1;
+    bool transA = false;
+    bool transB = false;
+    bool withBias = false;
+    bool isInt8 = false;
+    bool weightNz = false;
+};
+
+struct TwoDimTPInfo { // 2D-TP: communication along both the x axis and the y axis
+    int32_t agDim = -1; // number of cards on the allgather axis; the x direction uses non-contiguous rank numbers
+    int32_t rsDim = -1; // number of cards on the reduce-scatter axis; the y direction uses contiguous rank numbers
+    bool innerDimIsAg = true; // whether allgather communication runs along the inner axis
+};
+
+struct QuantInfo {
+    // dequantization granularity (covers both pre-matmul pseudo-quantization and post-matmul dequantization)
+    QuantGranularity dequantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED;
+    int32_t dequantGroupSize = -1;
+
+    QuantGranularity quantGranularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED; // quantization granularity
+    int32_t quantGroupSize = -1;
+};
+
+struct PostInfo {
+    int32_t withRmsNorm = 0;
+};
+
+struct MoeInfo {
+    int16_t local_expert_nums = 0;
+    int8_t EP = 0;
+    int8_t TP = 0;
+    int32_t maxOutputSize = -1;
+    int8_t isMoe = 0;
+};
+}
+#endif // LCAL_LCOC_BASE_H
diff --git a/comm/lcal/include/lcoc/lcoc_func.h b/comm/lcal/include/lcoc/lcoc_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..35a95c97dfa275e2941fa324cd6f9b9260cc1f2c
--- /dev/null
+++ b/comm/lcal/include/lcoc/lcoc_func.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef LCAL_LCOC_FUNC_H
+#define LCAL_LCOC_FUNC_H
+
+#include <cstdint>
+#include <string>
+#include <vector>
+#include <tuple>
+#include <lcal_types.h>
+
+#pragma once
+namespace Lcal {
+    // Check that a parameter lies in [min, max]; when max = -1 the valid range is [min, +∞)
+    bool CheckParamScope(const std::string &name, const int &value, const int &min, const int &max);
+    bool CheckParamScopeList(std::vector<std::tuple<std::string, int, int, int>> paramCheckList);
+    bool CheckParamAlign(const std::string &name, const int &value, const int &align);
+    void PrintErrorLog(LcalType lcalType, const std::string &log);
+    bool CheckParamPowerOfTwo(const std::string &name, int value);
+}
+
+#endif // LCAL_LCOC_FUNC_H
\ No newline at end of file
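The comment above fixes `CheckParamScope`'s contract: the value must lie in [min, max], and max = -1 widens the range to [min, +∞). A reference sketch of just that contract; the real implementation and its logging live in the lcoc sources:

```cpp
#include <cstdio>
#include <string>

// Sketch of the documented range check, not the library's implementation.
bool CheckParamScopeSketch(const std::string &name, int value, int min, int max)
{
    bool ok = value >= min && (max == -1 || value <= max);
    if (!ok) {
        std::fprintf(stderr, "param %s = %d out of range [%d, %s]\n", name.c_str(),
                     value, min, max == -1 ? "+inf" : std::to_string(max).c_str());
    }
    return ok;
}
```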
diff --git a/comm/lcal/include/lcoc/lcoc_workspace.h b/comm/lcal/include/lcoc/lcoc_workspace.h
new file mode 100644
index 0000000000000000000000000000000000000000..0b9e40fe124f37c8f62b14d31fa484a711bff5a3
--- /dev/null
+++ b/comm/lcal/include/lcoc/lcoc_workspace.h
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_LCOC_WORKSPACE_H
+#define LCAL_LCOC_WORKSPACE_H
+
+#if !defined(__DAV_C220_VEC__) && !defined(__DAV_M200_VEC__) && !defined(__DAV_C220_CUBE__) && !defined(__DAV_C310__)
+#define __aicore__
+#define GM_ADDR int64_t
+#endif
+
+struct LcalWorkspaceInfo { // offsets: the host side starts at 0, the device side starts at gm_workspace
+    GM_ADDR gm_reducebuf{ 0 };
+    GM_ADDR gm_a_align{ 0 };
+    GM_ADDR gm_b_align{ 0 };
+    GM_ADDR gm_accum{ 0 };
+    GM_ADDR gm_formate_dequant_scale{ 0 };
+    GM_ADDR gm_dequant_param{ 0 };
+
+    // moe
+    GM_ADDR gm_out_loop_per_expert{ 0 };
+    GM_ADDR gm_in_loop_per_expert{ 0 };
+    GM_ADDR gm_out_loop_per_EP{ 0 };
+    GM_ADDR gm_in_loop_per_EP{ 0 };
+    GM_ADDR gm_sum_num_local_tokens_per_expert{ 0 };
+    GM_ADDR gm_sum_num_global_tokens_per_local_expert{ 0 };
+    GM_ADDR gm_in_expert_comm_count_accum{ 0 };
+    GM_ADDR gm_out_expert_comm_count_accum{ 0 };
+
+    GM_ADDR gm_num_local_tokens_per_expert{ 0 };
+    GM_ADDR gm_num_global_tokens_per_local_expert{ 0 };
+    GM_ADDR comm_matrix_trunc{ 0 };
+
+    GM_ADDR workspaceSize {0}; // total size
+};
+
+inline __aicore__ int32_t AlignUp(int32_t len, int32_t size)
+{
+    return (len + size - 1) & ~(size - 1);
+}
+
+#if !defined(__DAV_C220_VEC__) && !defined(__DAV_M200_VEC__) && !defined(__DAV_C220_CUBE__) && !defined(__DAV_C310__)
+inline uint64_t GetDequantWorkSpaceSize(Lcal::LcalType lcalType, int32_t withSerialMode, int32_t m, int32_t n,
+    int32_t m0, int32_t n0, int32_t pValue, int32_t nLoop, int32_t rankSize, int32_t blockDim,
+    int32_t maxOutputSize = -1)
+{
+    (void) nLoop;
+    constexpr int32_t TWO = 2;
+    uint64_t dequantWorkSpaceSize = 0;
+    if (withSerialMode > 0) {
+        dequantWorkSpaceSize = (maxOutputSize == -1 ? m : maxOutputSize) * n * sizeof(int32_t);
+        if (lcalType == Lcal::LcalType::ALL_GATHER_MATMUL) {
+            dequantWorkSpaceSize *= rankSize;
+        }
+    } else {
+        if (lcalType == Lcal::LcalType::MATMUL_ALL_REDUCE || lcalType == Lcal::LcalType::MATMUL_REDUCE_SCATTER) {
+            dequantWorkSpaceSize = pValue * blockDim * m0 * n0 * TWO * sizeof(int32_t);
+        } else {
+            dequantWorkSpaceSize = (maxOutputSize == -1 ? m : maxOutputSize) * n * sizeof(int32_t);
+            if (lcalType == Lcal::LcalType::ALL_GATHER_MATMUL) {
+                dequantWorkSpaceSize *= rankSize;
+            }
+        }
+    }
+    return dequantWorkSpaceSize;
+}
+#endif
+
+inline __aicore__ void GetLcalMoeWorkspaceInfo(LcalWorkspaceInfo& lcalWorkspaceInfo, GM_ADDR& workspaceOffset,
+    int32_t m, bool hasDequantParam = false, int32_t is_alltoallvc = false,
+    int32_t EP = 1, int32_t expertPerRank = 1, int32_t outputSize = -1)
+{
+    (void) is_alltoallvc;
+    (void) outputSize;
+    constexpr int32_t ALIGN8 = 8;
+    if (hasDequantParam) {
+        lcalWorkspaceInfo.gm_dequant_param = workspaceOffset;
+        workspaceOffset += sizeof(float) * AlignUp(m * EP, ALIGN8);
+    }
+    lcalWorkspaceInfo.comm_matrix_trunc = workspaceOffset;
+    workspaceOffset += sizeof(int32_t) * EP * EP * expertPerRank;
+}
+
+inline __aicore__ LcalWorkspaceInfo GetLcalWorkspaceInfo(GM_ADDR gmWorkSpace, int32_t batchSize, int32_t m,
+    int32_t k, int32_t n, int32_t mAlign, int32_t kAlign, int32_t nAlign, bool transa, bool transb,
+    int32_t mmadSize, bool hasAAlign, bool hasBAlign, int32_t accumRankSize, bool hasAccum = false,
+    uint64_t dequantWorkSpaceSize = 0, bool hasDequantParam = false, bool hasFormatDequantScale = false,
+    bool isDeterministic = false,
+    int32_t isMoe = false, int32_t is_alltoallvc = false,
+    int32_t EP = 1, int32_t expertPerRank = 1, int32_t outputSize = -1
+)
+{
+    (void) accumRankSize;
+    if (outputSize == -1) {
+        outputSize = m;
+    }
+    constexpr int32_t ALIGN8 = 8;
+    LcalWorkspaceInfo lcalWorkspaceInfo;
+    lcalWorkspaceInfo.gm_reducebuf = gmWorkSpace;
+    GM_ADDR workspaceOffset = gmWorkSpace;
+    if (isDeterministic) {
+        workspaceOffset += WORKSPACE_REDUCE_SIZE;
+    }
+
+    if (hasAAlign) {
+        lcalWorkspaceInfo.gm_a_align = workspaceOffset;
+        workspaceOffset += static_cast<int64_t>(batchSize) * (transa ? k * mAlign : m * kAlign) * mmadSize;
+    }
+
+    if (hasBAlign) {
+        lcalWorkspaceInfo.gm_b_align = workspaceOffset;
+        workspaceOffset += static_cast<int64_t>(batchSize) * (transb ? n * kAlign : k * nAlign) * mmadSize *
+            (expertPerRank <= 0 ? 1 : expertPerRank);
+    }
+
+    if (isMoe) {
+        GetLcalMoeWorkspaceInfo(lcalWorkspaceInfo, workspaceOffset, m, hasDequantParam, is_alltoallvc, EP,
+            expertPerRank, outputSize);
+    }
+
+    if (!isMoe && hasDequantParam) {
+        lcalWorkspaceInfo.gm_dequant_param = workspaceOffset;
+        workspaceOffset += sizeof(int32_t) * AlignUp(n, ALIGN8);
+    }
+
+    if (hasFormatDequantScale) {
+        lcalWorkspaceInfo.gm_formate_dequant_scale = workspaceOffset;
+        workspaceOffset += sizeof(float) * AlignUp(n, ALIGN8);
+    }
+
+    if (hasAccum) {
+        lcalWorkspaceInfo.gm_accum = workspaceOffset;
+        workspaceOffset += dequantWorkSpaceSize;
+    }
+    lcalWorkspaceInfo.workspaceSize = workspaceOffset;
+    return lcalWorkspaceInfo;
+}
+
+
+#endif
\ No newline at end of file
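On the host, where the DAV macros are undefined and `GM_ADDR` degrades to an `int64_t` offset per the guard above, `GetLcalWorkspaceInfo` can be called with a base of 0 so that every field becomes a byte offset and `workspaceSize` the total allocation. A sizing sketch under assumed fp16 shapes:

```cpp
#include "lcoc_args.h"       // WORKSPACE_REDUCE_SIZE
#include "lcoc_workspace.h"

int64_t QueryWorkspaceBytes()
{
    const int32_t batch = 1, m = 1024, k = 4096, n = 2048; // illustrative shapes
    LcalWorkspaceInfo info = GetLcalWorkspaceInfo(
        /*gmWorkSpace=*/0, batch, m, k, n,
        /*mAlign=*/m, /*kAlign=*/k, /*nAlign=*/n,
        /*transa=*/false, /*transb=*/false, /*mmadSize=*/2, // 2 bytes per fp16 element
        /*hasAAlign=*/true, /*hasBAlign=*/true, /*accumRankSize=*/0);
    return info.workspaceSize; // bytes the caller must allocate on the device
}
```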
diff --git a/comm/lcal/include/lcoc/tiling/tiling.h b/comm/lcal/include/lcoc/tiling/tiling.h
new file mode 100644
index 0000000000000000000000000000000000000000..d6aa177b75b620f7d3c9bc755a47507eba31f861
--- /dev/null
+++ b/comm/lcal/include/lcoc/tiling/tiling.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_TILING_H
+#define LCAL_TILING_H
+
+#include <map>
+#include <vector>
+#include "tiling_args.h"
+#include "lcal_types.h"
+#include "lcal_comm.h"
+#include "lcoc.h"
+
+namespace Lcal {
+class CoCTilingFunc {
+public:
+    CoCTilingFunc(const CoCTilingFunc &) = delete;
+    CoCTilingFunc &operator = (const CoCTilingFunc &) = delete;
+    CoCTilingFunc() {}
+    virtual ~CoCTilingFunc() {}
+    CoCTilingData GenerateTiling(const TaskParam &taskParam, const CoCTiling &tiling);
+
+    virtual bool CheckTiling(const TaskParam &taskParam);
+    virtual void GetDefaultTiling(const TaskParam &taskParam);
+
+protected:
+    CoCTilingData cocTilingData = {};
+};
+
+class CoCMatmulAllReduceTilingFunc : public CoCTilingFunc {
+public:
+    CoCMatmulAllReduceTilingFunc(const CoCMatmulAllReduceTilingFunc &) = delete;
+    CoCMatmulAllReduceTilingFunc &operator = (const CoCMatmulAllReduceTilingFunc &) = delete;
+    CoCMatmulAllReduceTilingFunc() {}
+    bool CheckTiling(const TaskParam &taskParam) override;
+    void GetDefaultTiling(const TaskParam &taskParam) override;
+};
+
+class CoCMatmulAllReduceDeterTilingFunc : public CoCMatmulAllReduceTilingFunc {
+public:
+    CoCMatmulAllReduceDeterTilingFunc(const CoCMatmulAllReduceDeterTilingFunc &) = delete;
+    CoCMatmulAllReduceDeterTilingFunc &operator = (const CoCMatmulAllReduceDeterTilingFunc &) = delete;
+    CoCMatmulAllReduceDeterTilingFunc() {}
+    bool CheckTiling(const TaskParam &taskParam) override;
+    void GetDefaultTiling(const TaskParam &taskParam) override;
+};
+
+class CoCMatmulReduceScatterTilingFunc : public CoCMatmulAllReduceTilingFunc {
+public:
+    CoCMatmulReduceScatterTilingFunc(const CoCMatmulReduceScatterTilingFunc &) = delete;
+    CoCMatmulReduceScatterTilingFunc &operator = (const CoCMatmulReduceScatterTilingFunc &) = delete;
+    CoCMatmulReduceScatterTilingFunc() {}
+    bool CheckTiling(const TaskParam &taskParam) override;
+    void GetDefaultTiling(const TaskParam &taskParam) override;
+};
+
+class CoCAllGatherMatmulTilingFunc : public CoCTilingFunc {
+public:
+    CoCAllGatherMatmulTilingFunc(const CoCAllGatherMatmulTilingFunc &) = delete;
+    CoCAllGatherMatmulTilingFunc &operator = (const CoCAllGatherMatmulTilingFunc &) = delete;
+    CoCAllGatherMatmulTilingFunc() {}
+    bool CheckTiling(const TaskParam &taskParam) override;
+    void GetDefaultTiling(const TaskParam &taskParam) override;
+};
+
+class CoCAllGatherMatmulV2TilingFunc : public CoCTilingFunc {
+public:
+    CoCAllGatherMatmulV2TilingFunc(const CoCAllGatherMatmulV2TilingFunc &) = delete;
+    CoCAllGatherMatmulV2TilingFunc &operator = (const CoCAllGatherMatmulV2TilingFunc &) = delete;
+    CoCAllGatherMatmulV2TilingFunc() {}
+    bool CheckTiling(const TaskParam &taskParam) override;
+    void GetDefaultTiling(const TaskParam &taskParam) override;
+};
+
+class CoCAllgatherMatmulReduceScatterTilingFunc : public CoCTilingFunc {
+public:
+    CoCAllgatherMatmulReduceScatterTilingFunc(const CoCAllgatherMatmulReduceScatterTilingFunc &) = delete;
+    CoCAllgatherMatmulReduceScatterTilingFunc &operator = (const CoCAllgatherMatmulReduceScatterTilingFunc &) = delete;
+    CoCAllgatherMatmulReduceScatterTilingFunc() {}
+    bool CheckTiling(const TaskParam &taskParam) override;
+    void GetDefaultTiling(const TaskParam &taskParam) override;
+};
+class CoCAllToAllAllGatherMatmulTilingFunc : public CoCAllGatherMatmulTilingFunc {
+public:
+    CoCAllToAllAllGatherMatmulTilingFunc(const CoCAllToAllAllGatherMatmulTilingFunc &) = delete;
+    CoCAllToAllAllGatherMatmulTilingFunc &operator = (const CoCAllToAllAllGatherMatmulTilingFunc &) = delete;
+    CoCAllToAllAllGatherMatmulTilingFunc() {}
+    bool CheckTiling(const TaskParam &tilingInfo) override;
+    void GetDefaultTiling(const TaskParam &tilingInfo) override;
+};
+class CoCAllToAllAllGatherMatmulHiddenTilingFunc : public CoCAllGatherMatmulTilingFunc {
+public:
+    CoCAllToAllAllGatherMatmulHiddenTilingFunc(const CoCAllToAllAllGatherMatmulHiddenTilingFunc &) = delete;
+    CoCAllToAllAllGatherMatmulHiddenTilingFunc &operator = (
+        const CoCAllToAllAllGatherMatmulHiddenTilingFunc &) = delete;
+    CoCAllToAllAllGatherMatmulHiddenTilingFunc() {}
+    bool CheckTiling(const TaskParam &tilingInfo) override;
+    void GetDefaultTiling(const TaskParam &tilingInfo) override;
+};
+
+class CoCMatmulReduceScatterAllToAllHiddenTilingFunc : public CoCMatmulReduceScatterTilingFunc {
+public:
+    CoCMatmulReduceScatterAllToAllHiddenTilingFunc(const CoCMatmulReduceScatterAllToAllHiddenTilingFunc &) = delete;
+    CoCMatmulReduceScatterAllToAllHiddenTilingFunc &operator = (
+        const CoCMatmulReduceScatterAllToAllHiddenTilingFunc &) = delete;
+    CoCMatmulReduceScatterAllToAllHiddenTilingFunc() {}
+    bool CheckTiling(const TaskParam &tilingInfo) override;
+    void GetDefaultTiling(const TaskParam &tilingInfo) override;
+};
+
+}
+#endif // LCAL_TILING_H
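All of the tiling functors above share one entry point: `GenerateTiling` merges caller hints with `GetDefaultTiling`/`CheckTiling` to produce a `CoCTilingData`. The dispatch below is illustrative only; the real op-type-to-functor mapping lives in the lcoc sources:

```cpp
#include "tiling.h"

Lcal::CoCTilingData MakeTiling(const Lcal::TaskParam &taskParam, const Lcal::CoCTiling &hints)
{
    if (taskParam.lcalType == Lcal::LcalType::MATMUL_ALL_REDUCE) {
        Lcal::CoCMatmulAllReduceTilingFunc func;
        return func.GenerateTiling(taskParam, hints);
    }
    Lcal::CoCAllGatherMatmulTilingFunc func; // fallback chosen only for illustration
    return func.GenerateTiling(taskParam, hints);
}
```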
diff --git a/comm/lcal/include/lcoc/tiling/tiling_91093.h b/comm/lcal/include/lcoc/tiling/tiling_91093.h
new file mode 100644
index 0000000000000000000000000000000000000000..d977d2531a180b92fe9795f986ee2d59abaf846c
--- /dev/null
+++ b/comm/lcal/include/lcoc/tiling/tiling_91093.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef LCAL_TILING_91093_H
+#define LCAL_TILING_91093_H
+
+#include "tiling_args.h"
+namespace Lcal {
+    void AllGatherNPU91093EightRankFP16Tiling(CoCTilingData &cocTilingData);
+    void AllGatherNPU91093SixteenRankFP16Tiling(CoCTilingData &cocTilingData);
+    void AllGatherNPU91093TwoRankFP16Tiling(CoCTilingData &cocTilingData);
+    void AllGatherNPU91093TwoRankINT8Tiling(CoCTilingData &cocTilingData);
+
+    void AllGatherV2NPU91093EightRankFP16Tiling(CoCTilingData &cocTilingData);
+    void AllGatherV2NPU91093SixteenRankFP16Tiling(CoCTilingData &cocTilingData);
+    void AllGatherV2NPU91093TwoRankFP16Tiling(CoCTilingData &cocTilingData);
+
+    void AllReduceNPU91093EightRankFP16Tiling(CoCTilingData &cocTilingData);
+    void AllReduceNPU91093SixteenRankFP16Tiling(CoCTilingData &cocTilingData);
+
+    void ReduceScatterNPU91093EightRankFP16Tiling(CoCTilingData &cocTilingData);
+    void ReduceScatterNPU91093SixteenRankFP16Tiling(CoCTilingData &cocTilingData);
+    void ReduceScatterNPU91093TwoRankFP16Tiling(CoCTilingData &cocTilingData);
+    void ReduceScatterNPU91093TwoRankINT8Tiling(CoCTilingData &cocTilingData);
+    void ReduceScatterNPU91093FourRankFP16Tiling(CoCTilingData &cocTilingData);
+
+    void CoCAllgatherMatmulReduceScatterAgEightRsTwoTiling(CoCTilingData &cocTilingData);
+    void CoCAllgatherMatmulReduceScatterDefaultTiling(CoCTilingData &cocTilingData, int32_t rsDim);
+}
+#endif // LCAL_TILING_91093_H
diff --git a/comm/lcal/include/lcoc/tiling/tiling_910B.h b/comm/lcal/include/lcoc/tiling/tiling_910B.h
new file mode 100644
index 0000000000000000000000000000000000000000..ca6efab23c51919f28902790aff165da5422bc3c
--- /dev/null
+++ b/comm/lcal/include/lcoc/tiling/tiling_910B.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef LCAL_TILING_910B_H
+#define LCAL_TILING_910B_H
+
+#include "tiling_args.h"
+namespace Lcal {
+    void AllGatherGetDefaultTiling(CoCTilingData &cocTilingData);
+    void AllGatherEightRankFP16GetDefaultTiling(CoCTilingData &cocTilingData);
+    void AllGatherFourRankINT8Tiling(CoCTilingData &cocTilingData);
+
+    void AllGatherV2EightRankFP16GetDefaultTiling(CoCTilingData &cocTilingData);
+    void AllGatherV2EightRankFP16Core16GetDefaultTiling(CoCTilingData &cocTilingData);
+
+    void AllReduceGetDefaultTiling(CoCTilingData &cocTilingData);
+    void AllReduceFourRankInt8GetDefaultTiling(CoCTilingData &cocTilingData);
+    void AllReduceFourRankFP16GetDefaultTiling(CoCTilingData &cocTilingData);
+    void AllReduceEightRankFP16GetDefaultTiling(CoCTilingData &cocTilingData);
+    void AllReduceEightRankINT8GetDefaultTiling(CoCTilingData &cocTilingData);
+    void AllReduceTwoRankFP16Tiling(CoCTilingData &cocTilingData);
+
+    void ReduceScatterEightRankFP16GetDefaultTiling(CoCTilingData &cocTilingData);
+    void ReduceScatterFourRankINT8Tiling(CoCTilingData &cocTilingData);
+}
+#endif // LCAL_TILING_910B_H
\ No newline at end of file
diff --git a/comm/lcal/include/lcoc/tiling/tiling_args.h b/comm/lcal/include/lcoc/tiling/tiling_args.h
new file mode 100644
index 0000000000000000000000000000000000000000..f46e047702dbbfff03f03f75cd5d47add9f95613
--- /dev/null
+++ b/comm/lcal/include/lcoc/tiling/tiling_args.h
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_TILING_ARGS_H
+#define LCAL_TILING_ARGS_H
+
+#include "lcoc_base.h"
+
+#pragma once
+namespace Lcal {
+    constexpr int32_t MAX_CORE_NUM = 20;
+    constexpr int32_t MAX_L2_SIZE = 192 * 1024 * 1024;
+    constexpr int32_t MAX_L0CSIZE = 128 * 1024;
+    constexpr int32_t HBM_BM = 1;
+    constexpr int32_t L2_BW = 5;
+    constexpr int32_t BYTE_512 = 512;
+    constexpr int32_t MAX_UB_NUM = 97280; // 190 * 1024 / 2
+    constexpr int32_t MIN_UB_NUM = 256;
+    constexpr int32_t A3_DIE_NUM = 2; // one card has two dies
+    constexpr int32_t DEFAULT_P_VALUE = 1;
+    constexpr int32_t MIN_P_VALUE = 1;
+    constexpr int32_t MAX_P_VALUE = 15;
+    constexpr int32_t TAG_MOD = 10000;
+    constexpr int32_t SWIZZLE_COUNT_FOUR = 4;
+    constexpr int32_t DEFAULT_SWIZZLE_COUNT = 7;
+    constexpr int32_t SWIZZLE_DIRECT_ZERO = 0;
+    constexpr int32_t SWIZZLE_DIRECT_ONE = 1;
+    constexpr int32_t COMM_DATA_DIRECT = 0;
+    constexpr int32_t COMM_NPU_DIRECT = 1;
+    constexpr int32_t COMMNPUSPLIT_ONE = 1;
+    constexpr int32_t COMMNPUSPLIT_TWO = 2;
+    constexpr int32_t COMMNPUSPLIT_THREE = 3;
+    constexpr int32_t COMMNPUSPLIT_EIGHT = 8;
+    constexpr int32_t COMMNPUSPLIT_FOUR = 4;
+    constexpr int32_t COMMDATASPLIT_ONE = 1;
+    constexpr int32_t COMMDATASPLIT_TWO = 2;
+    constexpr int32_t COMMDATASPLIT_FOUR = 4;
+    constexpr int32_t COMMDATASPLIT_EIGHT = 8;
+    constexpr int32_t COMMDATASPLIT_SIXTEEN = 16;
+    constexpr int32_t FLAG_BUFF_BYTES = 5 * 512 * 1024; // 2.5MB
+    constexpr int32_t AXES_ALIGN_SIZE_INT8 = 128;
+    constexpr int32_t DEFAULT_ROW = 128;
+    constexpr int32_t DEFAULT_COL = 256;
+    constexpr int32_t AXES_ALIGN_SIZE = 512;
+    constexpr int32_t BASE_BLOCK_STEP = 2;
+    constexpr int32_t INPUT_DTYPE = 2;
+    constexpr int32_t MAX_BLOCK_COUNT = 2;
+    constexpr int32_t BLOCK_COUNT_3 = 3;
+    constexpr int32_t FP16_SIZE = 2;
+    constexpr int32_t FP32_SIZE = 4;
+    constexpr int32_t BLOCK_SIZE = 16;
+    constexpr int32_t BLOCK_SIZE_K = 32;
+    constexpr int32_t ND_SHAPE_SIZE = 2;
+    constexpr int32_t NZ_SHAPE_SIZE = 4;
+    constexpr int32_t CUBE_BLOCK_SIZE_INT8 = 512;
+    constexpr int32_t CUBE_BLOCK_SIZE = 256;
+    constexpr int32_t MIN_UB_MOVE_NUM = 5120;
+    constexpr int32_t VALID_UB_MOVE_NUM = 20480;
+    constexpr int32_t L1AB_PINGPONG_BUFFER_LEN_FP16 = 131072; // 128 KB
+    constexpr int32_t HALF_KBYTE = 512;
+    constexpr int32_t SECOND_TO_MS = 1e3;
+    constexpr int64_t MATMUL_BASE_100US = static_cast<int64_t>(1024) * 8192 * 1024;
+    constexpr int64_t ALLREDUCE_BASE_100US = 4096 * 1024;
+    constexpr double ONE_K = 1024.0;
+    constexpr double B1_FLOP_PER_MS = (364 * 0.8) * 1e9;
+    constexpr double DOUBLE = 2.0;
+    constexpr double HALF_PROB = 0.5;
+    constexpr int32_t CONDITION_M_ST = 0;
+    constexpr int32_t CONDITION_M_END = 1;
+    constexpr int32_t CONDITION_K_ST = 2;
+    constexpr int32_t CONDITION_K_END = 3;
+    constexpr int32_t CONDITION_N_ST = 4;
+    constexpr int32_t CONDITION_N_END = 5;
+    constexpr int32_t RANKSIZE_TWO = 2;
+    constexpr int32_t RANKSIZE_FOUR = 4;
+    constexpr int32_t RANKSIZE_EIGHT = 8;
+    constexpr int32_t RANKSIZE_SIXTEEN = 16;
+    constexpr int32_t DIV_TWO = 2;
+    constexpr int32_t LENPERLOOP_DEFAULT = 5120;
+    constexpr int32_t ALLGATHERV2_CORENUM_SIXTEEN = 16;
+    constexpr int32_t ALLREDUCE_LENPERLOOP_DEFAULT = 5120; // value when 16 cores are in use
+    constexpr int32_t TREE_LEN_PER_LOOP = 20480;
+    constexpr int32_t DIM_EIGHT = 8;
+    constexpr int32_t DIM_TWO = 2;
+    constexpr int32_t DEFAULT_SPLIT_K = 0;
+    constexpr int32_t NUM_TWO = 2;
+
+    // Todo: tmp hard code, need tiling func for moe
+    constexpr int32_t AllTOAll_HIDDEN_UBMOVENUM = 28672;
+
+
+    // all defaults are -1
+    struct CoCTiling {
Tiling参数,用来控制融合算子执行策略 + // 可外部传入,也可内部计算得到 + int32_t m0 = -1; + int32_t k0 = -1; + int32_t n0 = -1; + int32_t swizzlCount = -1; + int32_t swizzlDirect = -1; + int32_t pValue = -1; + int32_t ubMoveNum = -1; + int32_t commNpuSplit = -1; + int32_t commDataSplit = -1; + int32_t commDirect = -1; + int32_t lenPerLoop = -1; + int32_t extraUbMoveNum = -1; + int32_t extraCommNpuSplit = -1; // 2dtp使用 + int32_t extraCommDataSplit = -1; // 2dtp使用 + int32_t extraCommDirect = -1; // 2dtp使用 + int32_t extraLenPerLoop = -1; // 2dtp使用 + int32_t splitK = -1; + int32_t write2OtherRank = -1; + int32_t withSerialMode = -1; + // 控制融合算子实现的参数 + int32_t is91093 = -1; + int32_t bufferSize = -1; + }; + + struct CoCTilingData : CoCTiling { + // 外部传入的参数 + int64_t m = -1; + int64_t k = -1; + int64_t n = -1; + int64_t batchSize = -1; + + // NPU相关的参数 + int32_t blockDim = -1; + int32_t rank = -1; + int32_t rankSize = -1; + int32_t tag = -1; // 默认值为0 + + // 内部计算得到的参数 + int32_t mLoop = -1; + int32_t kLoop = -1; + int32_t nLoop = -1; + int32_t coreLoop = -1; + uint32_t tilingKey = -1; + + // Tiling Func + const char* ToString() const; + void SetDefaultValue(); // 设置默认值 + }; + + struct CoCKernelParam { + CoCTilingData cocTilingData = {}; + QuantInfo quantInfo = {}; // device侧对应23-26 + TwoDimTPInfo twoDimTPInfo = {}; // device侧对应27-29 + PostInfo postInfo = {}; // device侧对应30 + MoeInfo moeInfo = {}; // device侧对应31 + bool weightNz = false; + }; +} +#endif // LCAL_TILING_ARGS_H diff --git a/comm/lcal/include/lcoc/tiling/tiling_func.h b/comm/lcal/include/lcoc/tiling/tiling_func.h new file mode 100644 index 0000000000000000000000000000000000000000..111e2b02510857b1161f7c5ff7f6c607d9a8488c --- /dev/null +++ b/comm/lcal/include/lcoc/tiling/tiling_func.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
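CoCTilingData splits into externally supplied shape and NPU fields and internally derived loop counts. The derived fields presumably come from ceil-dividing the shape by the base block (m0, k0, n0); a self-contained sketch of that arithmetic, illustrative only, since the real derivation lives in the tiling .cpp files rather than this header:

    #include <cstdint>

    // Illustrative only: how mLoop/kLoop/nLoop/coreLoop plausibly relate to
    // (m, k, n) and the base block (m0, k0, n0).
    struct LoopCounts { int32_t mLoop; int32_t kLoop; int32_t nLoop; int32_t coreLoop; };

    inline LoopCounts DeriveLoopCounts(int64_t m, int64_t k, int64_t n,
                                       int32_t m0, int32_t k0, int32_t n0)
    {
        auto ceilDiv = [](int64_t a, int64_t b) { return static_cast<int32_t>((a + b - 1) / b); };
        LoopCounts lc{};
        lc.mLoop = ceilDiv(m, m0);
        lc.kLoop = ceilDiv(k, k0);
        lc.nLoop = ceilDiv(n, n0);
        lc.coreLoop = lc.mLoop * lc.nLoop; // one pass per output tile
        return lc;
    }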
diff --git a/comm/lcal/include/lcoc/tiling/tiling_func.h b/comm/lcal/include/lcoc/tiling/tiling_func.h
new file mode 100644
index 0000000000000000000000000000000000000000..111e2b02510857b1161f7c5ff7f6c607d9a8488c
--- /dev/null
+++ b/comm/lcal/include/lcoc/tiling/tiling_func.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_TILING_FUNC_H
+#define LCAL_TILING_FUNC_H
+
+#include <cstdint>
+#include <map>
+#include <vector>
+#include "lcoc_args.h"
+#include "lcal_types.h"
+#include "tiling_args.h"
+
+#pragma once
+namespace Lcal {
+    struct TilingValue {
+        int32_t value = -1;
+        std::map<int32_t, std::vector<std::vector<int64_t>>> conditionMap = {};
+    };
+
+    int32_t CeilDev(int32_t num, int32_t div);
+    int32_t RoundNum(int32_t num, int32_t rnd);
+    void UpdateTilingValue(const int32_t &tilingParam, int32_t &tilingDataParam);
+    double GetMTETime(double mknGB, int32_t m0, int32_t n0, double aBindWidth = 3.0, double bBindWidth = 3.0);
+    int32_t GetValueFromMKNConditionMap(int32_t m, int32_t k, int32_t n,
+                                        int32_t defaultValue,
+                                        std::map<int32_t, std::vector<std::vector<int64_t>>> conditionMap);
+    bool Is910B(const ChipName &chipName);
+    bool Is91093(const ChipName &chipName);
+    uint32_t GetTilingKey(const MatMulInfo &mmInfo, CoCTilingData &tilingData);
+    void DealTilingParamByBuffSize(CoCTilingData &cocTilingData);
+    int ClampValue(int32_t value, int32_t min, int32_t max);
+    void SetTilingParam(CoCTilingData &cocTilingData, const std::map<int32_t *, TilingValue> &tilingParamMap);
+    void SetSecondCoreSplitTling(CoCTilingData &cocTilingData);
+    void SetTilingParam2D(CoCTilingData &cocTilingData, const std::map<int32_t *, TilingValue> &tilingParamMap);
+    bool CheckCoCTiling(const CoCTiling &tiling);
+    bool CheckCoCTilingData(const CoCTilingData &tilingData);
+    void TransformCoCTiling(const CoCTiling &tiling, CoCTilingData &tilingData);
+    void CalTilingParam(const MatMulInfo &mmInfo, CoCTilingData &tilingData);
+    void SetTilingInputParam(const TaskParam &taskParam, CoCTilingData &tilingData);
+    void SetTilingData(const TaskParam &taskParam, const CoCTiling &tiling, CoCTilingData &tilingData);
+}
+
+#endif // LCAL_TILING_FUNC_H
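CeilDev, RoundNum and ClampValue are the usual ceiling-division, round-up-to-multiple and range-clamp helpers that the tiling code builds on. The header only declares them, so the bodies below are a sketch under that assumption:

    #include <algorithm>
    #include <cstdint>

    // Sketch implementations consistent with the declared signatures.
    int32_t CeilDev(int32_t num, int32_t div)
    {
        return (num + div - 1) / div; // ceil(num / div) for positive div
    }

    int32_t RoundNum(int32_t num, int32_t rnd)
    {
        return CeilDev(num, rnd) * rnd; // smallest multiple of rnd that is >= num
    }

    int ClampValue(int32_t value, int32_t min, int32_t max)
    {
        return std::max(min, std::min(value, max));
    }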
diff --git a/comm/lcal/src/CMakeLists.txt b/comm/lcal/src/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..400edecda5750b7d12f293d0cd4ba42b3cbb0965
--- /dev/null
+++ b/comm/lcal/src/CMakeLists.txt
@@ -0,0 +1,58 @@
+#
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+#
+set_source_files_properties(lcal_internal.cpp PROPERTIES COMPILE_FLAGS "-O3")
+
+set(LCAL_SOURCE_FILE lcal_comm.cpp lccl.cpp
+    lcal_internal.cpp lcal_internal.h lcal_wrap.cpp
+    tools/socket/lcal_sock_exchange.h
+    tools/socket/lcal_sock_exchange.cpp
+    coc_kernel_args.h coc_kernel_args.cpp lcoc.cpp lcoc_func.cpp
+)
+file(GLOB TILING_SOURCE_FILE tiling/*.cpp)
+list(APPEND LCAL_SOURCE_FILE ${TILING_SOURCE_FILE})
+
+add_library(lcal SHARED ${LCAL_SOURCE_FILE})
+add_library(lcal_static STATIC ${LCAL_SOURCE_FILE})
+set_target_properties(lcal_static PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+target_link_libraries(lcal ascendcl runtime profapi c_sec mki)
+target_link_libraries(lcal_static ascendcl runtime profapi c_sec mki)
+
+message(STATUS "LCAL USE_MSSANITIZER = ${USE_MSSANITIZER}")
+set(LCAL_CCE_PATH "/tmp/lcal_cce.o")
+if(USE_MSSANITIZER)
+    math(EXPR LCAL_1OP_BIN_SIZE "128 * 1024 * 1024")
+    add_definitions(-DUSE_MSSANITIZER)
+else()
+    math(EXPR LCAL_1OP_BIN_SIZE "5 * 1024 * 1024")
+endif()
+
+add_definitions(-DLCAL_1OP_BIN_SIZE=${LCAL_1OP_BIN_SIZE})
+
+add_subdirectory(kernels)
+add_subdirectory(ascendc_kernels)
+
+add_custom_command(
+    OUTPUT ${LCAL_CCE_PATH}
+    COMMAND cat ascendc_kernels/lccl_op.o kernels/lcoc_op.o > ${LCAL_CCE_PATH}
+    COMMAND echo "concat op..."
+    DEPENDS lccl_op lcoc_op
+)
+
+set_source_files_properties(
+    lcal_internal.cpp
+    PROPERTIES
+    OBJECT_DEPENDS ${LCAL_CCE_PATH}
+)
+install(TARGETS lcal LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)
+install(TARGETS lcal_static DESTINATION ${CMAKE_INSTALL_PREFIX}/lib)
+
+
+
diff --git a/comm/lcal/src/ascendc.cmake b/comm/lcal/src/ascendc.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..27f3d366cf19f6485c62058e4678596c72f7e2c5
--- /dev/null
+++ b/comm/lcal/src/ascendc.cmake
@@ -0,0 +1,52 @@
+#
+# Copyright (c) 2024 Huawei Technologies Co., Ltd.
+# This file is a part of the CANN Open Software.
+# Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+# See LICENSE in the root of the software repository for the full text of the License.
+#
+enable_language(CCE)
+# Compile options
+# Define the sanitizer-related compile options: added when enabled, empty otherwise
+if(USE_MSSANITIZER)
+    set(SANITIZER_FLAGS
+        -g --cce-enable-sanitizer
+    )
+    set(SANITIZER_DEPEND_LIBS
+        --dependent-libraries ${ASCEND_HOME_PATH}/tools/mssanitizer/lib64/libsanitizer_stub_dav-c220-vec.a
+        --dependent-libraries ${ASCEND_HOME_PATH}/tools/mssanitizer/lib64/libsanitizer_stub_dav-c220-cube.a
+    )
+else()
+    set(SANITIZER_FLAGS) # empty
+    set(SANITIZER_DEPEND_LIBS)
+endif()
+set(CCE_COMPILE_OPTION
+    -O2 -std=gnu++17
+    --cce-aicore-only
+    -Wno-deprecated-declarations
+    ${SANITIZER_FLAGS}
+    "SHELL:-mllvm -cce-aicore-long-call"
+    "SHELL:-mllvm -cce-aicore-function-stack-size=16000"
+    "SHELL:-mllvm -cce-aicore-record-overflow=false"
+    "SHELL:-mllvm -cce-aicore-addr-transform"
+    "SHELL:-mllvm --cce-aicore-jump-expand=true"
+)
+set(PRIVATE_CCEC_PATH ${CMAKE_SOURCE_DIR}/3rdparty/compiler)
+# Include paths
+if (EXISTS ${PRIVATE_CCEC_PATH})
+    message(STATUS "Using custom ccec include directories")
+    set(CCE_INCLUDE_BASE ${PRIVATE_CCEC_PATH})
+else()
+    set(CCE_INCLUDE_BASE ${ASCEND_HOME_PATH}/${ARCH}-linux)
+endif()
+
+message(STATUS "Using tikcpp include directories")
+include_directories(
+    ${ASCEND_HOME_PATH}/toolkit/toolchain/hcc/aarch64-target-linux-gnu/include/c++/7.3.0
+    ${ASCEND_HOME_PATH}/toolkit/toolchain/hcc/aarch64-target-linux-gnu/include/c++/7.3.0/aarch64-target-linux-gnu/
+    ${CCE_INCLUDE_BASE}/tikcpp/tikcfw/
+    ${CCE_INCLUDE_BASE}/tikcpp/tikcfw/interface/
+    ${CCE_INCLUDE_BASE}/tikcpp/tikcfw/impl/
+)
\ No newline at end of file
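The kernels that follow carve the shared IPC buffer into fixed-depth queues; slot capacity is the buffer size divided by queue count, depth and element width, which is exactly what InitShare() computes below. A standalone recomputation of that sizing, where the IPC_BUFF_MAX_SIZE value is a placeholder and the real constant comes from the lccl headers:

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        constexpr int64_t IPC_BUFF_MAX_SIZE = 100 * 1024 * 1024; // placeholder size
        constexpr int64_t QUEUE_DEPTH = 2;   // matches All2AllHierarchy below
        constexpr int64_t queNum = 16;       // one queue per peer rank
        constexpr int64_t elemBytes = 2;     // sizeof(T) for fp16

        const int64_t perQueElemLen = IPC_BUFF_MAX_SIZE / queNum / QUEUE_DEPTH / elemBytes;
        const int64_t queLen = perQueElemLen * QUEUE_DEPTH;  // elements per queue
        const int64_t queSize = queLen * elemBytes;          // bytes per queue
        std::printf("slot elems: %lld, queue elems: %lld, queue bytes: %lld\n",
                    static_cast<long long>(perQueElemLen),
                    static_cast<long long>(queLen),
                    static_cast<long long>(queSize));
        return 0;
    }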
diff --git a/comm/lcal/src/ascendc_kernels/91093/all2all_hierarchy.h b/comm/lcal/src/ascendc_kernels/91093/all2all_hierarchy.h
new file mode 100644
index 0000000000000000000000000000000000000000..c1da7cc98bab6d2b1c23685cc5601f645dabe323
--- /dev/null
+++ b/comm/lcal/src/ascendc_kernels/91093/all2all_hierarchy.h
@@ -0,0 +1,240 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef LCCL_ALL2ALL_HIERARCHY_H
+#define LCCL_ALL2ALL_HIERARCHY_H
+
+#include "collectives.h"
+#include "sync_collectives.h"
+#include "ipc_queue.h"
+
+using namespace AscendC;
+
+template <typename T>
+class All2AllHierarchy : protected Collectives {
+    constexpr static int QUEUE_DEPTH = 2;
+    constexpr static int32_t STEP_TIMES = 2;
+    constexpr static int INVALID_RANK_NUM = 0xFFFFFFFF;
+    constexpr static int INVALID_RANK = 0xFFFFFFFF;
+    constexpr static const int64_t SIO = 2;
+    constexpr static int64_t CORE_NUM_PER_STAGE = 16;
+    constexpr static int64_t MULTI_RANK_SIZE = CORE_NUM_PER_STAGE;
+    constexpr static int64_t PRODUCER_CORE = 1;
+    constexpr static int64_t CONSUMER_CORE = 2;
+    static const int64_t DIE_CHANGE = 1;
+
+public:
+    FORCE_INLINE_AICORE All2AllHierarchy(int rank, int rankSize, uint32_t extraFlag)
+        : Collectives(rank, rankSize, extraFlag) {}
+    FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN())
+    {
+        Collectives::Init(KERNELS_ARGS_CALL());
+        this->input = (__gm__ T *) input;
+        this->output = (__gm__ T *) output;
+
+        perRankDataNum = GetDataCount(len, rankSize);
+        curRankDataNum = perRankDataNum;
+        InitShare();
+        InitCoreGroup();
+        InitDataSlice();
+    }
+    FORCE_INLINE_AICORE void Process()
+    {
+        if (coreGroup == PRODUCER_CORE) {
+            Producer();
+        } else {
+            Consumer();
+        }
+    }
+private:
+    FORCE_INLINE_AICORE void InitShare()
+    {
+        int64_t queNum = blockNum / STEP_TIMES;
+        if (rankSize <= CORE_NUM_PER_STAGE) {
+            queNum = rankSize;
+        }
+        if (len < perQueElemLen) {
+            coreNumPerRank = 1;
+        }
+        perQueElemLen = IPC_BUFF_MAX_SIZE / queNum / QUEUE_DEPTH / sizeof(T);
+        queLen = perQueElemLen * QUEUE_DEPTH;
+        queSize = queLen * sizeof(T);
+    }
+
+    FORCE_INLINE_AICORE void InitCoreGroup()
+    {
+        coreNumPerRank = 1;
+        if (len < perQueElemLen) {
+            coreNumPerRank = 1;
+        }
+        coreNumPerStage = coreNumPerRank * rankSize < CORE_NUM_PER_STAGE ?
+            coreNumPerRank * rankSize : CORE_NUM_PER_STAGE;
+        rankNumPerCore = CeilDiv(rankSize, coreNumPerStage);
+        flagNumPerStage = rankSize;
+        groupCore = (rank / coreNumPerStage) * coreNumPerStage;
+        if (blockIdx < coreNumPerStage) {
+            coreGroup = PRODUCER_CORE;
+            for (auto i = 0; i < rankNumPerCore; ++i) {
+                groupCoreIdx[i] = (groupCore + i * coreNumPerStage) % rankSize + blockIdx;
+            }
+        } else if (blockIdx < coreNumPerStage + coreNumPerStage) {
+            coreGroup = CONSUMER_CORE;
+            for (auto i = 0; i < rankNumPerCore; ++i) {
+                int64_t prefix = (groupCore - i * coreNumPerStage) >= 0 ?
+                    (groupCore - i * coreNumPerStage) : groupCore + ((rankNumPerCore - i) * coreNumPerStage);
+                groupCoreIdx[i] = prefix + blockIdx - coreNumPerStage;
+            }
+        }
+    }
+
+    FORCE_INLINE_AICORE void InitDataSlice()
+    {
+        ipcDataNumPreBlock = curRankDataNum;
+        if (coreGroup == PRODUCER_CORE) {
+            for (auto i = 0; i < rankNumPerCore; ++i) {
+                if (groupCoreIdx[i] % SIO == rank % SIO) {
+                    srcInnerQue[i].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET +
+                        (groupCoreIdx[i] % coreNumPerStage) * queSize, queLen, perQueElemLen);
+                } else {
+                    SrcSioQue[i].Init(&sync, magic, shareAddrs[sioRank] + IPC_DATA_OFFSET +
+                        ((groupCoreIdx[i] + (rank - sioRank)) % coreNumPerStage) * queSize,
+                        queLen, perQueElemLen);
+                }
+                sliceNum = CeilDiv(ipcDataNumPreBlock, perQueElemLen);
+            }
+        } else if (coreGroup == CONSUMER_CORE) {
+            for (auto i = 0; i < rankNumPerCore; ++i) {
+                computePullRank(groupCoreIdx[i], rank);
+                if (rank % SIO == 0) {
+                    pullOffset = DIE_CHANGE * groupCoreIdx[i] % SIO;
+                } else {
+                    pullOffset = groupCoreIdx[i] % SIO - DIE_CHANGE;
+                }
+
+                pullQue[i].Init(&sync, magic, shareAddrs[pullRank] + IPC_DATA_OFFSET +
+                    (rank % coreNumPerStage) * queSize + pullOffset * queSize, queLen, perQueElemLen);
+                sliceNum = CeilDiv(ipcDataNumPreBlock, perQueElemLen);
+            }
+        }
+    }
+
+    FORCE_INLINE_AICORE void Producer()
+    {
+        for (auto i = 0; i < rankNumPerCore; ++i) {
+            for (auto sliceIdx = 0; sliceIdx < sliceNum; ++sliceIdx) {
+                Input2IpcSlice(i, sliceIdx);
+            }
+        }
+    }
+
+    FORCE_INLINE_AICORE void Input2IpcSlice(int64_t idx, int64_t sliceIdx)
+    {
+        inputGt.SetGlobalBuffer((__gm__ T*)input + groupCoreIdx[idx] * ipcDataNumPreBlock, ipcDataNumPreBlock);
+        copyLen = ipcDataNumPreBlock - perQueElemLen * sliceIdx;
+        if (copyLen > perQueElemLen) {
+            copyLen = perQueElemLen;
+        } else if (copyLen < 0) {
+            copyLen = 0;
+        }
+        if (groupCoreIdx[idx] % SIO == rank % SIO) {
+            if (idx > 0) {
+                sync.WaitSyncFlag(magic, sliceIdx + sliceNum * (idx - 1),
+                    groupCoreIdx[idx - 1] + flagNumPerStage, rank);
+            }
+            srcInnerQue[idx].DeQue(rank, groupCoreIdx[idx] + flagNumPerStage);
+            writeGt = srcInnerQue[idx].EnQue();
+            if (copyLen > 0) {
+                CpGM2GMPingPong(copyLen * sizeof(T), inputGt[sliceIdx * perQueElemLen], writeGt, Op::COPYONLY);
+                sync.SetSyncFlag(magic, sliceIdx + sliceNum * idx, groupCoreIdx[idx], rank);
+            }
+        } else {
+            if (idx > 0) {
+                sync.WaitSyncFlag(magic, sliceIdx + sliceNum * (idx - 1),
+                    groupCoreIdx[idx - 1] + flagNumPerStage + (rank - sioRank), sioRank);
+            }
+            SrcSioQue[idx].DeQue(sioRank, groupCoreIdx[idx] + (rank - sioRank) + flagNumPerStage);
+            writeGt = SrcSioQue[idx].EnQue();
+            if (copyLen > 0) {
+                CpGM2GMPingPong(copyLen * sizeof(T), inputGt[sliceIdx * perQueElemLen], writeGt, Op::COPYONLY);
+                sync.SetSyncFlag(magic, sliceIdx + sliceNum * idx, groupCoreIdx[idx] + (rank - sioRank), sioRank);
+            }
+        }
+    }
+    FORCE_INLINE_AICORE void Consumer()
+    {
+        for (auto i = 0; i < rankNumPerCore; ++i) {
+            computePullRank(groupCoreIdx[i], rank);
+            for (auto sliceIdx = 0; sliceIdx < sliceNum; ++sliceIdx) {
+                Ipc2Output(i, sliceIdx);
+            }
+        }
+    }
+
+    FORCE_INLINE_AICORE void computePullRank(int64_t& target, int64_t rank)
+    {
+        if (rank % SIO == 0) {
+            pullRank = (target / SIO) * SIO;
+        } else {
+            pullRank = (target / SIO) * SIO + DIE_CHANGE;
+        }
+    }
+
+    FORCE_INLINE_AICORE void Ipc2Output(int64_t idx, int64_t sliceIdx)
+    {
+        outputGt.SetGlobalBuffer((__gm__ T*)output + groupCoreIdx[idx] * ipcDataNumPreBlock,
+            ipcDataNumPreBlock);
+        cpOffset = rank % SIO == 0 ? rank + groupCoreIdx[idx] % SIO :
+            (rank - DIE_CHANGE) + groupCoreIdx[idx] % SIO;
+        copyLen = ipcDataNumPreBlock - perQueElemLen * sliceIdx;
+        if (copyLen > perQueElemLen) {
+            copyLen = perQueElemLen;
+        } else if (copyLen < 0) {
+            copyLen = 0;
+        }
+        readGt = pullQue[idx].ReadFront();
+        sync.WaitSyncFlag(magic, sliceIdx + sliceNum * idx, cpOffset, pullRank);
+        if (copyLen > 0) {
+            CpGM2GMPingPong(copyLen * sizeof(T), readGt, outputGt[sliceIdx * perQueElemLen], Op::COPYONLY);
+        }
+        sync.SetSyncFlag(magic, sliceIdx + sliceNum * idx, cpOffset + flagNumPerStage, pullRank);
+    }
+    GlobalTensor<T> inputGt;
+    GlobalTensor<T> outputGt;
+    GlobalTensor<T> readGt;
+    GlobalTensor<T> writeGt;
+    __gm__ T *input;
+    __gm__ T *output;
+
+    int atomOp;
+    IpcQueue<T> srcInnerQue[MULTI_RANK_SIZE];
+    IpcQueue<T> SrcSioQue[MULTI_RANK_SIZE];
+    IpcQueue<T> pullQue[MULTI_RANK_SIZE];
+    int64_t perRankDataNum;
+    int64_t curRankDataNum;
+    int64_t ipcDataNumPreBlock;
+    int64_t pullRank;
+    int64_t pullOffset;
+    int64_t sioRank = (rank % 2 == 0) ? rank + 1 : rank - 1;
+    int64_t cpOffset;
+    int64_t perQueElemLen;
+    int64_t queLen;
+    int64_t queSize;
+    int64_t coreNumPerStage;
+    int64_t flagNumPerStage;
+    int64_t coreNumPerRank;
+    int64_t rankNumPerCore;
+    int64_t coreGroup;
+    int64_t groupCoreIdx[MULTI_RANK_SIZE];
+    int64_t sliceNum;
+    int64_t copyLen;
+    int64_t groupCore;
+};
+
+#endif // LCCL_ALL2ALL_HIERARCHY_H
diff --git a/comm/lcal/src/ascendc_kernels/91093/all2all_hierarchy_small.h b/comm/lcal/src/ascendc_kernels/91093/all2all_hierarchy_small.h
new file mode 100644
index 0000000000000000000000000000000000000000..3ea12ebf58e3a5e39cf35f5574f2548156eadebb
--- /dev/null
+++ b/comm/lcal/src/ascendc_kernels/91093/all2all_hierarchy_small.h
@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef LCCL_ALL2ALL_HIERARCHY_SMALL_H +#define LCCL_ALL2ALL_HIERARCHY_SMALL_H + +#include "collectives.h" +#include "sync_collectives.h" +#include "ipc_queue.h" + +using namespace AscendC; + +template +class All2AllHierarchySmall : protected Collectives { + constexpr static int QUEUE_DEPTH = 2; + constexpr static int32_t STEP_TIMES = 2; + constexpr static int INVALID_RANK_NUM = 0xFFFFFFFF; + constexpr static int INVALID_RANK = 0xFFFFFFFF; + constexpr static int64_t CORE_NUM_PER_STAGE = 16; + constexpr static int64_t PRODUCER_CORE = 1; + constexpr static int64_t CONSUMER_CORE = 2; + constexpr static int64_t SIO = 2; + +public: + FORCE_INLINE_AICORE All2AllHierarchySmall(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + this->input = (__gm__ T *) input; + this->output = (__gm__ T *) output; + + curRankDataNum = GetDataCount(len, rankSize); + InitShare(); + InitCoreGroup(); + InitDataSlice(); + } + FORCE_INLINE_AICORE void Process() + { + if (coreGroup == PRODUCER_CORE) { + Producer(); + } else { + Consumer(); + } + } +private: + FORCE_INLINE_AICORE void InitShare() + { + coreNumPerStage = CORE_NUM_PER_STAGE; + singleStage = coreNumPerStage / SIO; + perQueElemLen = IPC_BUFF_MAX_SIZE / SIO / singleStage / QUEUE_DEPTH / sizeof(T); + queLen = perQueElemLen * QUEUE_DEPTH; + queSize = queLen * sizeof(T); + queBlockSize = IPC_BUFF_MAX_SIZE / SIO; + } + + FORCE_INLINE_AICORE void InitCoreGroup() + { + if (len < perQueElemLen) { + coreNumPerRank = 1; + } + loopCount = rankSize / SIO; + flagNumPerStage = coreNumPerStage; + if (blockIdx < coreNumPerStage) { + coreGroup = PRODUCER_CORE; + groupCoreIdx = blockIdx; + } else if (blockIdx < coreNumPerStage + coreNumPerStage) { + coreGroup = CONSUMER_CORE; + groupCoreIdx = blockIdx - coreNumPerStage; + } + } + + FORCE_INLINE_AICORE void InitDataSlice() + { + ipcDataNumPreBlock = GetDataCount(curRankDataNum, singleStage); + int64_t ifOffSet = queBlockSize * (rank % SIO); + if (coreGroup == PRODUCER_CORE) { + for (auto i = 0; i < loopCount; ++i) { + if (groupCoreIdx < singleStage) { + srcLocalQue1.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + ifOffSet + + groupCoreIdx * queSize, queLen, perQueElemLen); + } else { + srcSioQue1.Init(&sync, magic, shareAddrs[sioRank] + IPC_DATA_OFFSET + ifOffSet + + (groupCoreIdx - singleStage) * queSize, queLen, perQueElemLen); + } + sliceNum = CeilDiv(ipcDataNumPreBlock, perQueElemLen); + } + } + sliceNum = CeilDiv(ipcDataNumPreBlock, perQueElemLen); + } + + FORCE_INLINE_AICORE void Producer() + { + for (auto i = 0; i < loopCount; ++i) { + srcRank = (rank + i * SIO) % rankSize; + sioSrcRank = (srcRank % SIO == 0) ? srcRank + 1 : srcRank - 1; + srcLocalQue = srcLocalQue1; + srcSioQue = srcSioQue1; + for (auto sliceIdx = 0; sliceIdx < sliceNum; ++sliceIdx) { + Input2IpcSlice(i, sliceIdx); + } + } + } + + FORCE_INLINE_AICORE void Input2IpcSlice(int64_t idx, int64_t sliceIdx) + { + copyLen = ipcDataNumPreBlock - perQueElemLen * sliceIdx; + if (copyLen > perQueElemLen) { + copyLen = perQueElemLen; + } else if (copyLen < 0) { + copyLen = 0; + } + int64_t flagIdx = groupCoreIdx + (rank % SIO) * singleStage; + if (groupCoreIdx < singleStage) { + if (idx > 0) { + int64_t waitRank = (srcRank - SIO) >= 0 ? 
(srcRank - SIO) : srcRank + ((loopCount - 1) * SIO); + sync.WaitSyncFlag(magic, sliceIdx + sliceNum * (idx - 1), flagIdx + (waitRank / SIO) * coreNumPerStage + + flagNumPerStage, rank); + } + inputGt.SetGlobalBuffer((__gm__ T*)input + srcRank * curRankDataNum + groupCoreIdx * ipcDataNumPreBlock, + ipcDataNumPreBlock); + srcLocalQue.DeQue(rank, flagIdx + (srcRank / SIO) * coreNumPerStage + flagNumPerStage); + writeGt = srcLocalQue.EnQue(); + if (copyLen > 0) { + CpGM2GMPingPong(copyLen * sizeof(T), inputGt[sliceIdx * perQueElemLen], writeGt, Op::COPYONLY); + sync.SetSyncFlag(magic, sliceIdx + sliceNum * idx, flagIdx, rank); + } + } else { + flagIdx = flagIdx - singleStage; + if (idx > 0) { + int64_t waitRank = (sioSrcRank - SIO) >= 0 ? (sioSrcRank - SIO) : sioSrcRank + ((loopCount - 1) * SIO); + sync.WaitSyncFlag(magic, sliceIdx + sliceNum * (idx - 1), flagIdx + (waitRank / SIO) * coreNumPerStage + + flagNumPerStage, sioRank); + } + inputGt.SetGlobalBuffer((__gm__ T*)input + sioSrcRank * curRankDataNum + + (groupCoreIdx - singleStage) * ipcDataNumPreBlock, ipcDataNumPreBlock); + srcSioQue.DeQue(sioRank, flagIdx + (sioSrcRank / SIO) * coreNumPerStage + flagNumPerStage); + writeGt = srcSioQue.EnQue(); + if (copyLen > 0) { + CpGM2GMPingPong(copyLen * sizeof(T), inputGt[sliceIdx * perQueElemLen], writeGt, Op::COPYONLY); + sync.SetSyncFlag(magic, sliceIdx + sliceNum * idx, flagIdx, sioRank); + } + } + } + FORCE_INLINE_AICORE void Consumer() + { + for (auto i = 0; i < loopCount; ++i) { + destRank = (rank - i * SIO) >= 0 ? (rank - i * SIO) : rank + ((loopCount - i) * SIO); + if (groupCoreIdx < singleStage) { + detHccsQue.Init(&sync, magic, shareAddrs[destRank] + IPC_DATA_OFFSET + + groupCoreIdx * queSize, queLen, perQueElemLen); + } else { + detHccsSioQue.Init(&sync, magic, shareAddrs[destRank] + IPC_DATA_OFFSET + queBlockSize + + (groupCoreIdx - singleStage) * queSize, queLen, perQueElemLen); + } + for (auto sliceIdx = 0; sliceIdx < sliceNum; ++sliceIdx) { + Ipc2Output(i, sliceIdx); + } + } + } + + FORCE_INLINE_AICORE void Ipc2Output(int64_t idx, int64_t sliceIdx) + { + outputGt.SetGlobalBuffer((__gm__ T*)output + (destRank / SIO) * SIO * curRankDataNum + + groupCoreIdx * ipcDataNumPreBlock, ipcDataNumPreBlock); + copyLen = ipcDataNumPreBlock - perQueElemLen * sliceIdx; + if (copyLen > perQueElemLen) { + copyLen = perQueElemLen; + } else if (copyLen < 0) { + copyLen = 0; + } + sync.WaitSyncFlag(magic, sliceIdx + sliceNum * idx, groupCoreIdx, destRank); + if (groupCoreIdx < singleStage) { + readGt = detHccsQue.ReadFront(); + } else { + readGt = detHccsSioQue.ReadFront(); + } + CpGM2GMPingPong(copyLen * sizeof(T), readGt, outputGt[sliceIdx * perQueElemLen], Op::COPYONLY); + sync.SetSyncFlag(magic, sliceIdx + sliceNum * idx, groupCoreIdx + flagNumPerStage + + (rank / SIO) * coreNumPerStage, destRank); + } + GlobalTensor inputGt; + GlobalTensor readGt; + GlobalTensor writeGt; + GlobalTensor outputGt; + __gm__ T *input; + __gm__ T *output; + + int atomOp; + IpcQueue srcLocalQue; + IpcQueue srcSioQue; + IpcQueue detHccsQue; + IpcQueue detHccsSioQue; + IpcQueue srcLocalQue1; + IpcQueue srcSioQue1; + + int64_t loopCount; + int64_t queBlockSize; + int64_t srcRank; + int64_t sioSrcRank; + int64_t destRank; + int64_t singleStage; + int64_t curRankDataNum; + int64_t ipcDataNumPreBlock; + int64_t sioRank = (rank % 2 == 0) ? 
rank + 1:rank - 1; + int64_t perQueElemLen; + int64_t queLen; + int64_t queSize; + int64_t coreNumPerStage; + int64_t flagNumPerStage; + int64_t coreNumPerRank; + int64_t coreGroup; + int64_t groupCoreIdx; + int64_t sliceNum; + int64_t copyLen; +}; + +#endif // LCCL_ALL2ALL_HIERARCHY_SMALL_H diff --git a/comm/lcal/src/ascendc_kernels/91093/allgather_hierarchy_double_ring.h b/comm/lcal/src/ascendc_kernels/91093/allgather_hierarchy_double_ring.h new file mode 100644 index 0000000000000000000000000000000000000000..babfaa525c5b8fecacca750d85dc6726e9eceffc --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/91093/allgather_hierarchy_double_ring.h @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LCCL_ALLGATHER_HIERARCHY_DOUBLE_RING_H +#define LCCL_ALLGATHER_HIERARCHY_DOUBLE_RING_H + +#include "collectives.h" +#include "ipc_queue.h" +using namespace AscendC; + +constexpr int STAGE_NUM = 4; +constexpr int QUE_DEPTH = 8; +constexpr int QUE_NUM_LOCAL = 2; +constexpr int RING_NUM = 2; +constexpr int STAGE_EVENT = 0; +constexpr int RING_EVENT = 1; + +enum STAGE { + HCCS_RING = 0, + HCCS_TO_OUT, + HCCS_TO_SIO, + SIO_TO_OUT +}; + +template +class AllGatherHierarchyDoubleRing : public Collectives { +public: + FORCE_INLINE_AICORE AllGatherHierarchyDoubleRing(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + DumpLcclLogInfo(LogId::INIT, Op::COPYONLY); + int64_t dataTotalSize = len * sizeof(T); + const int coreNumPerStep = blockNum / STAGE_NUM; + stage = blockIdx / coreNumPerStep; + int stageCoreIdx = blockIdx % coreNumPerStep; + dataSizePerCore = dataTotalSize / coreNumPerStep; + const int64_t inputOffset = stageCoreIdx * dataSizePerCore; + if (stageCoreIdx == coreNumPerStep - 1) { + dataSizePerCore = dataTotalSize - (coreNumPerStep - 1) * dataSizePerCore; + } + + inputGm.SetGlobalBuffer(input + inputOffset, dataSizePerCore); + if (stage == STAGE::HCCS_TO_OUT) { + for (int i = rank % RING_NUM; i < rankSize; i += RING_NUM) { + outputGm[i / RING_NUM].SetGlobalBuffer(output + dataTotalSize * i + inputOffset, dataSizePerCore); + } + } else if (stage == STAGE::SIO_TO_OUT) { + for (int i = (rank + 1) % RING_NUM; i < rankSize; i += RING_NUM) { + outputGm[i / RING_NUM].SetGlobalBuffer(output + dataTotalSize * i + inputOffset, dataSizePerCore); + } + } + + int64_t queTotalSize = IPC_BUFF_MAX_SIZE / coreNumPerStep; + int64_t queSize = queTotalSize / QUE_NUM_LOCAL; + int64_t queHccsOffset = stageCoreIdx * queTotalSize; + blockSize = queSize / QUE_DEPTH; + + queHccsLocal.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + queHccsOffset, queSize, blockSize); + queSioLocal.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + queHccsOffset + queSize, + queSize, blockSize); + rankRingForward = (rank + RING_NUM) % rankSize; + queHccsForward.Init(&sync, magic, shareAddrs[rankRingForward] + 
IPC_DATA_OFFSET + queHccsOffset, + queSize, blockSize); + rankSioAdjoint = rank ^ 1; + queSioAdjoint.Init(&sync, magic, + shareAddrs[rankSioAdjoint] + IPC_DATA_OFFSET + queHccsOffset + queSize, queSize, blockSize); + + for (int i = 0; i < STAGE_NUM; ++i) { + stageEvents[i] = sync.CalEventIdByMulBlockNum(STAGE_EVENT, stageCoreIdx + coreNumPerStep * i); + } + + DumpLcclLogInfo(LogId::INIT, Op::COPYONLY); + } + + FORCE_INLINE_AICORE void Process() + { + DumpLcclLogInfo(LogId::PROCESS, Op::COPYONLY); + int count = rankSize / RING_NUM * CeilDiv(dataSizePerCore, blockSize); + if (stage == STAGE::HCCS_RING) { + ProcessHccsRing(count); + } else if (stage == STAGE::HCCS_TO_OUT) { + ProcessHccsToOut(count); + } else if (stage == STAGE::HCCS_TO_SIO) { + ProcessHccsToSio(count); + } else if (stage == STAGE::SIO_TO_OUT) { + ProcessSioToOut(count); + } + DumpLcclLogInfo(LogId::PROCESS, Op::COPYONLY); + } +private: + FORCE_INLINE_AICORE void ProcessHccsRing(const int count) + { + constexpr int dependencyNum = 3; + int deQueWaitRanks[dependencyNum] = {(rank + rankSize - RING_NUM) % rankSize, rank, rank}; + int deQueWaitEvents[dependencyNum] = { + sync.CalEventIdByMulBlockNum(RING_EVENT, blockIdx), + stageEvents[static_cast(STAGE::HCCS_TO_OUT)], + stageEvents[static_cast(STAGE::HCCS_TO_SIO)]}; + int64_t remainSize = dataSizePerCore; + int64_t dataSize = 0; + GlobalTensor input; + GlobalTensor output; + int64_t waitFlag = 0; + int i = 0; + while (i < count) { + int countRankId = (rank + i * RING_NUM) % rankSize; + if (countRankId == rank) { + dataSize = (remainSize >= blockSize) ? blockSize : remainSize; + input = inputGm[dataSizePerCore - remainSize]; + remainSize -= blockSize; + } else { + if (i == 1) { + sync.WaitSyncFlag(magic, 0, stageEvents[static_cast(STAGE::HCCS_RING)], rankRingForward); + waitFlag = sync.GetInnerFlag(rankRingForward, + stageEvents[static_cast(STAGE::HCCS_RING)]) & EVENT_ID_MASK; + } + if (waitFlag < i - 1) { + waitFlag = sync.GetInnerFlag(rankRingForward, + stageEvents[static_cast(STAGE::HCCS_RING)]) & EVENT_ID_MASK; + continue; + } + input = queHccsForward.ReadFront(); + } + queHccsLocal.DeQue(deQueWaitRanks, deQueWaitEvents, dependencyNum); + output = queHccsLocal.EnQue(); + CpGM2GMPingPong(dataSize, input, output, -1); + + sync.SetSyncFlag(magic, i, stageEvents[static_cast(STAGE::HCCS_RING)], rank); + if (countRankId != rank) { + if ((rank + (i + 1) * RING_NUM) % rankSize == rank) { + queHccsForward.ReadFront(); + sync.SetSyncFlag(magic, i, sync.CalEventIdByMulBlockNum(RING_EVENT, blockIdx), rank); + } else { + sync.SetSyncFlag(magic, i - 1, sync.CalEventIdByMulBlockNum(RING_EVENT, blockIdx), rank); + } + } + ++i; + } + } + + FORCE_INLINE_AICORE void ProcessHccsToOut(const int count) + { + GlobalTensor input; + GlobalTensor output; + int64_t remainSize = dataSizePerCore; + int64_t dataSize = 0; + sync.WaitSyncFlag(magic, 0, stageEvents[static_cast(STAGE::HCCS_RING)], rank); + int64_t waitFlag = sync.GetInnerFlag(rank, stageEvents[static_cast(STAGE::HCCS_RING)]) & EVENT_ID_MASK; + int i = 0; + while (i < count) { + if (waitFlag < i) { + waitFlag = sync.GetInnerFlag(rank, stageEvents[static_cast(STAGE::HCCS_RING)]) & EVENT_ID_MASK; + continue; + } + int countRankId = (rank + i * RING_NUM) % rankSize; + if (countRankId == rank) { + dataSize = (remainSize >= blockSize) ? 
blockSize : remainSize; + } + input = queHccsLocal.ReadFront(); + output = outputGm[countRankId / RING_NUM][dataSizePerCore - remainSize]; + CpGM2GMPingPong(dataSize, input, output, -1); + constexpr int32_t halfQueDepth = 2; + if (i % (QUE_DEPTH / halfQueDepth) == 0) { + sync.SetSyncFlag(magic, i, stageEvents[static_cast(STAGE::HCCS_TO_OUT)], rank); + } + if ((countRankId + RING_NUM) % rankSize == rank) { + remainSize -= blockSize; + } + ++i; + } + } + FORCE_INLINE_AICORE void ProcessHccsToSio(const int count) + { + GlobalTensor input; + GlobalTensor output; + int64_t remainSize = dataSizePerCore; + int64_t dataSize = 0; + sync.WaitSyncFlag(magic, 0, stageEvents[static_cast(STAGE::HCCS_RING)], rank); + int64_t waitFlag = sync.GetInnerFlag(rank, stageEvents[static_cast(STAGE::HCCS_RING)]) & EVENT_ID_MASK; + int i = 0; + while (i < count) { + if (waitFlag < i) { + waitFlag = sync.GetInnerFlag(rank, stageEvents[static_cast(STAGE::HCCS_RING)]) & EVENT_ID_MASK; + continue; + } + int countRankId = (rank + i * RING_NUM) % rankSize; + if (countRankId == rank) { + dataSize = (remainSize >= blockSize) ? blockSize : remainSize; + remainSize -= blockSize; + } + input = queHccsLocal.ReadFront(); + queSioAdjoint.DeQue(rankSioAdjoint, stageEvents[static_cast(STAGE::SIO_TO_OUT)]); + output = queSioAdjoint.EnQue(); + CpGM2GMPingPong(dataSize, input, output, -1); + sync.SetSyncFlag(magic, i, stageEvents[static_cast(STAGE::HCCS_TO_SIO)], rank); + ++i; + } + } + FORCE_INLINE_AICORE void ProcessSioToOut(const int count) + { + GlobalTensor input; + GlobalTensor output; + int64_t remainSize = dataSizePerCore; + int64_t dataSize = 0; + sync.WaitSyncFlag(magic, 0, stageEvents[static_cast(STAGE::HCCS_TO_SIO)], rankSioAdjoint); + int64_t waitFlag = sync.GetInnerFlag(rankSioAdjoint, + stageEvents[static_cast(STAGE::HCCS_TO_SIO)]) & EVENT_ID_MASK; + int i = 0; + while (i < count) { + if (waitFlag < i) { + waitFlag = sync.GetInnerFlag(rankSioAdjoint, + stageEvents[static_cast(STAGE::HCCS_TO_SIO)]) & EVENT_ID_MASK; + continue; + } + int countRankId = (rankSioAdjoint + i * RING_NUM) % rankSize; + if (countRankId == rankSioAdjoint) { + dataSize = (remainSize >= blockSize) ? blockSize : remainSize; + } + input = queSioLocal.ReadFront(); + output = outputGm[countRankId / RING_NUM][dataSizePerCore - remainSize]; + CpGM2GMPingPong(dataSize, input, output, -1); + constexpr int32_t halfQueDepth = 2; + if (i % (QUE_DEPTH / halfQueDepth) == 0) { + sync.SetSyncFlag(magic, i, stageEvents[static_cast(STAGE::SIO_TO_OUT)], rank); + } + if ((countRankId + RING_NUM) % rankSize == rankSioAdjoint) { + remainSize -= blockSize; + } + ++i; + } + } +private: + int stageEvents[STAGE_NUM]; + GlobalTensor inputGm; + GlobalTensor outputGm[LCAL_MAX_RANK_SIZE / RING_NUM]; + IpcQueue queHccsLocal; + IpcQueue queHccsForward; + IpcQueue queSioLocal; + IpcQueue queSioAdjoint; + int64_t dataSizePerCore; + int stage; + int rankRingForward; + int rankSioAdjoint; + int64_t blockSize; +}; + +#endif // LCCL_ALLGATHER_HIERARCHY_DOUBLE_RING_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/91093/allreduce_big_data_sio.h b/comm/lcal/src/ascendc_kernels/91093/allreduce_big_data_sio.h new file mode 100644 index 0000000000000000000000000000000000000000..1c62f2e58625b530cc1309071ede4ad5781a683c --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/91093/allreduce_big_data_sio.h @@ -0,0 +1,245 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LCCL_ALLREDUCE_BIG_DATA_SIO_H +#define LCCL_ALLREDUCE_BIG_DATA_SIO_H + +#include "collectives.h" +#include "sync_collectives.h" +#include "ipc_queue.h" +using namespace AscendC; + +template +class AllReduceBigDataSio : protected Collectives { + constexpr static int QUEUE_DEPTH = 4; + +public: + FORCE_INLINE_AICORE AllReduceBigDataSio(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + + perStepBlockNum = rankSize; + + ipcBuffMaxSizeAligned = IPC_BUFF_MAX_SIZE / rankSize / QUEUE_DEPTH * rankSize * QUEUE_DEPTH; + perQueSize = ipcBuffMaxSizeAligned / rankSize; + perQueNum = perQueSize / sizeof(T); + curBlockSize = perQueSize / QUEUE_DEPTH; + curBlockNum = curBlockSize / sizeof(T); + atomOp = op; + for (int i = 0; i < rankSize; ++i) { + rankList[i] = i; + coreIdxList[i] = PING_PONG_SIZE * rankSize + blockIdx % perStepBlockNum; + } + + peerRank = blockIdx % perStepBlockNum; + perRankDataNum = len / rankSize; + + if (rank % RANK_SIZE_TWO == 0) { + adjRank = rank + 1; + } else { + adjRank = rank - 1; + } + + curRankDataNum = perRankDataNum; + if (blockIdx % perStepBlockNum == rankSize - 1) { + curRankDataNum = len - (rankSize - 1) * perRankDataNum; + } + pullRankDataNum = perRankDataNum; + if (rank == rankSize - 1) { + pullRankDataNum = len - rank * perRankDataNum; + } + inputBuffOffsetNum = blockIdx % rankSize * perRankDataNum; + + inputGt.SetGlobalBuffer((__gm__ T*)input + inputBuffOffsetNum, curRankDataNum); + + outputBuffOffsetNum = peerRank * perRankDataNum; + + outputGt.SetGlobalBuffer((__gm__ T*)output + outputBuffOffsetNum, curRankDataNum); + inputIpcGtOffsetNum = perQueSize * (blockIdx % perStepBlockNum); + + if (blockIdx / perStepBlockNum == 0) { + ProducerInit(); + } else if (blockIdx / perStepBlockNum == 1) { + ConsumerInit(); + } else { + PullerInit(); + } + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + } + + FORCE_INLINE_AICORE void Process() + { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + if (blockIdx / perStepBlockNum == 0) { + Producer(); + } else if (blockIdx / perStepBlockNum == 1) { + Consumer(); + } else { + Puller(); + } + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + } +private: + FORCE_INLINE_AICORE void Producer() + { + int64_t loopCount = CeilDiv(curRankDataNum, curBlockNum); + int64_t remain = curRankDataNum; + int count = 0; + while (count < loopCount) { + inputQue.DeQue(rankList, coreIdxList, rankSize); + GlobalTensor outputGm = inputQue.EnQue(); + int64_t copyNum = (remain < curBlockNum) ? 
remain : curBlockNum; + CpGM2GMPingPong(copyNum * sizeof(T), inputGt[count * curBlockNum], outputGm, COPYONLY); + sync.SetOuterFlag(magic, count); + + if (blockIdx % RANK_SIZE_TWO == rank % RANK_SIZE_TWO) { + sync.WaitOuterFlag(magic, count, rank, blockIdx); + sync.WaitOuterFlag(magic, count, adjRank, blockIdx); + GlobalTensor inputGm = sioAtomSrcQue.ReadFront(); + GlobalTensor outputGm = sioAtomDstQue.EnQue(); + CpGM2GMPingPong(copyNum * sizeof(T), inputGm, outputGm, atomOp); + } + sync.SetInnerFlag(magic, count); + remain = remain - curBlockNum; + count = count + 1; + } + } + + FORCE_INLINE_AICORE void Consumer() + { + int64_t atomLoopCount = CeilDiv(pullRankDataNum, curBlockNum); + int64_t atomRemain = pullRankDataNum; + int64_t loopCount = CeilDiv(curRankDataNum, curBlockNum); + int64_t remain = curRankDataNum; + int count = 0; + int64_t maxLoopCount = (loopCount < atomLoopCount) ? loopCount : atomLoopCount; + while (count < maxLoopCount) { + if (peerRank != rank && rank % RANK_SIZE_TWO == peerRank % RANK_SIZE_TWO && count != atomLoopCount) { + sync.WaitInnerFlag(magic, count, rank, rank); + sync.WaitInnerFlag(magic, count, peerRank, rank); + + GlobalTensor inputGm = srcQue.ReadFront(); + GlobalTensor outputGm = dstQue.EnQue(); + + int64_t atomCopyNum = (atomRemain < curBlockNum) ? atomRemain : curBlockNum; + CpGM2GMPingPong(atomCopyNum * sizeof(T), inputGm, outputGm, atomOp); + atomRemain = atomRemain - curBlockNum; + } + sync.SetOuterFlag(magic, count); + if (count == loopCount) { + break; + } + if (rank % RANK_SIZE_TWO == peerRank % RANK_SIZE_TWO) { + sync.WaitOneRankPartOuterFlag(magic, count, peerRank, perStepBlockNum, perStepBlockNum); + if (peerRank != rank) { + GlobalTensor inputGm = pullSrcQue.ReadFront(); + GlobalTensor outputGm = pullDstQue.EnQue(); + int64_t copyNum = (remain < curBlockNum) ? remain : curBlockNum; + CpGM2GMPingPong(copyNum * sizeof(T), inputGm, outputGm, COPYONLY); + } + sync.SetInnerFlag(magic, count); + } + remain = remain - curBlockNum; + count = count + 1; + } + } + FORCE_INLINE_AICORE void Puller() + { + int64_t loopCount = CeilDiv(curRankDataNum, curBlockNum); + int64_t remain = curRankDataNum; + int count = 0; + while (count < loopCount) { + if (rank % RANK_SIZE_TWO == peerRank % RANK_SIZE_TWO) { + sync.WaitInnerFlag(magic, count, rank, blockIdx - perStepBlockNum); + } else { + sync.WaitInnerFlag(magic, count, adjRank, blockIdx - perStepBlockNum); + } + GlobalTensor inputGm = pullQue.ReadFront(); + int64_t copyNum = (remain < curBlockNum) ? 
remain : curBlockNum; + CpGM2GMPingPong(copyNum * sizeof(T), inputGm, outputGt[count * curBlockNum], COPYONLY); + sync.SetInnerFlag(magic, count); + remain = remain - curBlockNum; + count = count + 1; + } + } + + FORCE_INLINE_AICORE void ProducerInit() + { + inputQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + inputIpcGtOffsetNum, + perQueNum, curBlockNum); + if (blockIdx % RANK_SIZE_TWO == rank % RANK_SIZE_TWO) { + sioAtomSrcQue.Init(&sync, magic, shareAddrs[adjRank] + IPC_DATA_OFFSET + inputIpcGtOffsetNum, + perQueNum, curBlockNum); + sioAtomDstQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + inputIpcGtOffsetNum, + perQueNum, curBlockNum); + } + } + FORCE_INLINE_AICORE void ConsumerInit() + { + srcQue.Init(&sync, magic, shareAddrs[peerRank] + IPC_DATA_OFFSET + rank * perQueSize, + perQueNum, curBlockNum); + dstQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + rank * perQueSize, + perQueNum, curBlockNum); + if (peerRank != rank && rank % RANK_SIZE_TWO == peerRank % RANK_SIZE_TWO) { + pullSrcQue.Init(&sync, magic, shareAddrs[peerRank] + IPC_DATA_OFFSET + + peerRank * perQueSize, perQueNum, curBlockNum); + pullDstQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + peerRank * perQueSize, perQueNum, curBlockNum); + } + } + + FORCE_INLINE_AICORE void PullerInit() + { + if (rank % RANK_SIZE_TWO == peerRank % RANK_SIZE_TWO) { + pullQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + inputIpcGtOffsetNum, + perQueNum, curBlockNum); + } else { + pullQue.Init(&sync, magic, shareAddrs[adjRank] + IPC_DATA_OFFSET + inputIpcGtOffsetNum, + perQueNum, curBlockNum); + } + } +private: + GlobalTensor inputGt; + GlobalTensor outputGt; + + int atomOp; + int64_t ipcBuffMaxSizeAligned; + + int64_t perRankDataNum; + int64_t curRankDataNum; + int64_t peerRank; + int64_t adjRank; + int64_t pullRankDataNum; + int64_t inputBuffOffsetNum; + int64_t outputBuffOffsetNum; + int64_t inputIpcGtOffsetNum; + int64_t curBlockSize; + int64_t perStepBlockNum; + int64_t curBlockNum; + int64_t perQueSize; + int64_t perQueNum; + + IpcQueue inputQue; + IpcQueue srcQue; + IpcQueue dstQue; + IpcQueue pullQue; + IpcQueue sioAtomSrcQue; + IpcQueue sioAtomDstQue; + IpcQueue pullSrcQue; + IpcQueue pullDstQue; + + int rankList[LCAL_MAX_RANK_SIZE]; + int coreIdxList[LCAL_MAX_RANK_SIZE]; +}; + +#endif // LCCL_ALLREDUCE_BIG_DATA_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/91093/allreduce_hierarchy_double_ring.h b/comm/lcal/src/ascendc_kernels/91093/allreduce_hierarchy_double_ring.h new file mode 100644 index 0000000000000000000000000000000000000000..8fc881494d75ca1ec34c9c4c19ec679a536753c3 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/91093/allreduce_hierarchy_double_ring.h @@ -0,0 +1,420 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef LCCL_ALLREDUCE_HIERARCHY_DOUBLE_RING_H +#define LCCL_ALLREDUCE_HIERARCHY_DOUBLE_RING_H + +#include "sync_collectives.h" +#include "collectives.h" +#include "ipc_queue.h" +using namespace AscendC; + +template +class AllReduceHierarchyDoubleRing : protected Collectives { + constexpr static int32_t RING_LAYER_NUM = 2; + constexpr static int32_t INPUT_CORE_NUM = 4; + constexpr static int32_t SIO_CORE_NUM = 12; + constexpr static int32_t RING_CORE_NUM = 12; + constexpr static int32_t OUTPUT_CORE_NUM = 6; + constexpr static int32_t IPC_QUE_DEPTH = 32; + constexpr static int32_t RING_GATHER_QUE_DEPTH = 3; + constexpr static int32_t SIO_GATHER_QUE_DEPTH = 2; + constexpr static int32_t INPUT_FLAG = 0 * RING_CORE_NUM; + constexpr static int32_t SIO_REDUCE_FLAG = 1 * RING_CORE_NUM; + constexpr static int32_t RING_REDUCE_FLAG = 2 * RING_CORE_NUM; + constexpr static int32_t RING_REDUCE_PEER_FLAG = 3 * RING_CORE_NUM; + constexpr static int32_t RING_GATHER_FLAG = 4 * RING_CORE_NUM; + constexpr static int32_t RING_GATHER_PEER_FLAG = 5 * RING_CORE_NUM; + constexpr static int32_t SIO_GATHER_PEER_FLAG = 6 * RING_CORE_NUM; + constexpr static int32_t SIO_GATHER_FLAG = 7 * RING_CORE_NUM; + constexpr static int32_t SIO_GATHER_OUTPUT_FLAG = 8 * RING_CORE_NUM; + constexpr static int32_t OUTPUT_FLAG = 9 * RING_CORE_NUM; + constexpr static int32_t INPUT_CORE_SCALE = RING_CORE_NUM / INPUT_CORE_NUM; + constexpr static int32_t SIO_CORE_SCALE = RING_CORE_NUM / SIO_CORE_NUM; + constexpr static int32_t OUTPUT_CORE_SCALE = RING_CORE_NUM / OUTPUT_CORE_NUM; + constexpr static int64_t BLOCK_NUM_ALIGN = BLOCK_SIZE / sizeof(T); + +public: + FORCE_INLINE_AICORE AllReduceHierarchyDoubleRing(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + atomOp = op; + DumpLcclLogInfo(LogId::INIT, static_cast(atomOp)); + blockNum = INPUT_CORE_NUM + SIO_CORE_NUM + RING_CORE_NUM + OUTPUT_CORE_NUM; + if (blockIdx >= blockNum) { + DumpLcclLogInfo(LogId::INIT, static_cast(atomOp)); + return; + } + sioLayerId = rank / RING_LAYER_NUM; + ringLayerId = rank % RING_LAYER_NUM; + ringRankSize = rankSize / RING_LAYER_NUM; + ringNextRankId = (sioLayerId + 1) % ringRankSize * RING_LAYER_NUM + ringLayerId; + ringPrevRankId = (sioLayerId + (ringRankSize - 1)) % ringRankSize * RING_LAYER_NUM + ringLayerId; + sioPeerRankId = sioLayerId * RING_LAYER_NUM + (ringLayerId + 1) % RING_LAYER_NUM; + ipcBlockNum = IPC_BUFF_MAX_SIZE / (IPC_QUE_DEPTH + RING_GATHER_QUE_DEPTH + SIO_GATHER_QUE_DEPTH) / sizeof(T); + dmaPerLoop = ipcBlockNum - rankSize; + loopCount = CeilDiv(len, rankSize * dmaPerLoop); + const int64_t sumDataLastLoop = len - (loopCount - 1) * rankSize * dmaPerLoop; + dmaLastLoop = sumDataLastLoop / rankSize; + dmaLastRankLoop = sumDataLastLoop - (rankSize - 1) * dmaLastLoop; + totalBlockDataNum = (loopCount - 1) * dmaPerLoop + dmaLastLoop; + + InitQue(); + inputTensor.SetGlobalBuffer((__gm__ T*) input); + outputTensor.SetGlobalBuffer((__gm__ T*) output); + DumpLcclLogInfo(LogId::INIT, static_cast(atomOp)); + } + + FORCE_INLINE_AICORE void Process() + { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + if (blockIdx >= blockNum) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + for (curLoopCnt = 0; curLoopCnt < loopCount; ++curLoopCnt) { + for (sioLayerLoop = 0; sioLayerLoop < ringRankSize; ++sioLayerLoop) { + if (blockIdx < INPUT_CORE_NUM) { + Input2Ipc(); + } else if (blockIdx < 
INPUT_CORE_NUM + SIO_CORE_NUM) { + SioReduce(); + } else if (blockIdx < INPUT_CORE_NUM + SIO_CORE_NUM + RING_CORE_NUM) { + RingReduce(); + } else { + PrepareOutput(); + } + ++ipcQueIdx; + } + + for (sioLayerLoop = 0; sioLayerLoop < ringRankSize; ++sioLayerLoop) { + if (blockIdx < INPUT_CORE_NUM) { + ; + } else if (blockIdx < INPUT_CORE_NUM + SIO_CORE_NUM) { + SioGather(); + } else if (blockIdx < INPUT_CORE_NUM + SIO_CORE_NUM + RING_CORE_NUM) { + RingGather(); + } else { + Ipc2Output(); + } + ++gatherQueIdx; + } + } + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + } +private: + IpcQueue inputQueList[INPUT_CORE_SCALE]; + IpcQueue sioQueList[SIO_CORE_SCALE]; + IpcQueue sioGatherSrc1QueList[SIO_CORE_SCALE]; + IpcQueue sioGatherSrc2QueList[SIO_CORE_SCALE]; + IpcQueue sioGatherDstQueList[SIO_CORE_SCALE]; + IpcQueue ringSrcQue; + IpcQueue ringDstQue; + IpcQueue ringGatherSrcQue; + IpcQueue ringGatherDstQue; + IpcQueue outputSrc1QueList[OUTPUT_CORE_SCALE]; + IpcQueue outputSrc2QueList[OUTPUT_CORE_SCALE]; + IpcQueue outputSrc3QueList[OUTPUT_CORE_SCALE]; + + IpcQueue *inputQue = nullptr; + IpcQueue *sioQue = nullptr; + IpcQueue *sioGatherSrc1Que = nullptr; + IpcQueue *sioGatherSrc2Que = nullptr; + IpcQueue *sioGatherDstQue = nullptr; + IpcQueue *outputSrc1Que = nullptr; + IpcQueue *outputSrc2Que = nullptr; + IpcQueue *outputSrc3Que = nullptr; + GlobalTensor srcIpcTensor; + GlobalTensor dstIpcTensor; + GlobalTensor inputTensor; + GlobalTensor outputTensor; + int atomOp = COPYONLY; + int32_t sioLayerId = 0; + int32_t ringLayerId = 0; + int32_t ringRankSize = 0; + int32_t ringNextRankId = 0; + int32_t ringPrevRankId = 0; + int32_t sioPeerRankId = 0; + int32_t localBlockIdx = 0; + int64_t ipcBlockNum = 0; + int64_t totalBlockDataNum = 0; + int64_t dmaPerLoop = 0; + int64_t dmaLastLoop = 0; + int64_t dmaLastRankLoop = 0; + int32_t ipcQueIdx = 0; + int32_t gatherQueIdx = 0; + int32_t loopCount = 0; + int32_t curLoopCnt = 0; + int32_t sioLayerLoop = 0; + int64_t coreDataNum = 0; + int64_t lastCoreDataNum = 0; + int64_t curCoreDataNum = 0; + + FORCE_INLINE_AICORE void InitQue() + { + const int64_t dmaSizePerCore = ipcBlockNum / RING_CORE_NUM * sizeof(T); + const int64_t ipcBlockSize = ipcBlockNum * sizeof(T); + if (blockIdx < INPUT_CORE_NUM) { + for (int32_t blockLoop = 0; blockLoop < INPUT_CORE_SCALE; ++blockLoop) { + localBlockIdx = blockIdx * INPUT_CORE_SCALE + blockLoop; + inputQueList[blockLoop].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + } + } else if (blockIdx < INPUT_CORE_NUM + SIO_CORE_NUM) { + for (int32_t blockLoop = 0; blockLoop < SIO_CORE_SCALE; ++blockLoop) { + localBlockIdx = (blockIdx - INPUT_CORE_NUM) * SIO_CORE_SCALE + blockLoop; + sioQueList[blockLoop].Init(&sync, magic, shareAddrs[sioPeerRankId] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + sioGatherSrc1QueList[blockLoop].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + sioGatherSrc2QueList[blockLoop].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + IPC_QUE_DEPTH * ipcBlockSize + dmaSizePerCore * localBlockIdx, + ipcBlockNum * RING_GATHER_QUE_DEPTH, ipcBlockNum); + sioGatherDstQueList[blockLoop].Init(&sync, magic, shareAddrs[sioPeerRankId] + IPC_DATA_OFFSET + + (IPC_QUE_DEPTH + RING_GATHER_QUE_DEPTH) * ipcBlockSize + dmaSizePerCore * localBlockIdx, + ipcBlockNum * SIO_GATHER_QUE_DEPTH, 
ipcBlockNum); + } + } else if (blockIdx < INPUT_CORE_NUM + SIO_CORE_NUM + RING_CORE_NUM) { + localBlockIdx = (blockIdx - (INPUT_CORE_NUM + SIO_CORE_NUM)); + ringSrcQue.Init(&sync, magic, shareAddrs[ringPrevRankId] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + ringDstQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + ringGatherSrcQue.Init(&sync, magic, shareAddrs[ringPrevRankId] + IPC_DATA_OFFSET + + IPC_QUE_DEPTH * ipcBlockSize + dmaSizePerCore * localBlockIdx, + ipcBlockNum * RING_GATHER_QUE_DEPTH, ipcBlockNum); + ringGatherDstQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + IPC_QUE_DEPTH * ipcBlockSize + dmaSizePerCore * localBlockIdx, + ipcBlockNum * RING_GATHER_QUE_DEPTH, ipcBlockNum); + } else { + for (int32_t blockLoop = 0; blockLoop < OUTPUT_CORE_SCALE; ++blockLoop) { + localBlockIdx = (blockIdx - (INPUT_CORE_NUM + SIO_CORE_NUM + RING_CORE_NUM)) * OUTPUT_CORE_SCALE + + blockLoop; + outputSrc1QueList[blockLoop].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + (IPC_QUE_DEPTH + RING_GATHER_QUE_DEPTH) * ipcBlockSize + dmaSizePerCore * localBlockIdx, + ipcBlockNum * SIO_GATHER_QUE_DEPTH, ipcBlockNum); + outputSrc2QueList[blockLoop].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + IPC_QUE_DEPTH * ipcBlockSize + dmaSizePerCore * localBlockIdx, + ipcBlockNum * RING_GATHER_QUE_DEPTH, ipcBlockNum); + outputSrc3QueList[blockLoop].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + } + } + } + + FORCE_INLINE_AICORE void Input2Ipc() + { + for (int32_t blockLoop = 0; blockLoop < INPUT_CORE_SCALE; ++blockLoop) { + localBlockIdx = blockIdx * INPUT_CORE_SCALE + blockLoop; + inputQue = &(inputQueList[blockLoop]); + Input2IpcByCore(); + } + } + + FORCE_INLINE_AICORE void Input2IpcByCore() + { + const int32_t targetSioLayerId = (sioLayerId + (ringRankSize - 1 - sioLayerLoop)) % ringRankSize; + const int32_t targetRankOffset = targetSioLayerId * RING_LAYER_NUM + ringLayerId; + + (*inputQue).DeQue(rank, RING_REDUCE_PEER_FLAG + localBlockIdx); + const int32_t consumedQueIdx = ipcQueIdx - (IPC_QUE_DEPTH + ringRankSize - 1); + if (consumedQueIdx >= 0 && consumedQueIdx % ringRankSize == 0) { + sync.WaitSyncFlag(magic, consumedQueIdx, OUTPUT_FLAG + localBlockIdx, rank); + sync.WaitSyncFlag(magic, consumedQueIdx, RING_GATHER_PEER_FLAG + localBlockIdx, rank); + } + + BuildCoreDataNum(curLoopCnt, targetRankOffset); + srcIpcTensor = inputTensor[targetRankOffset * totalBlockDataNum + curLoopCnt * dmaPerLoop + + localBlockIdx * coreDataNum]; + dstIpcTensor = (*inputQue).EnQue(); + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcIpcTensor, dstIpcTensor, COPYONLY); + sync.SetSyncFlag(magic, ipcQueIdx, INPUT_FLAG + localBlockIdx, sioPeerRankId); + } + + FORCE_INLINE_AICORE void SioReduce() + { + for (int32_t blockLoop = 0; blockLoop < SIO_CORE_SCALE; ++blockLoop) { + if (sioLayerLoop < ringRankSize - 1) { + sioGatherSrc1QueList[blockLoop].ReadFront(); + } + } + if (curLoopCnt > 0 && sioLayerLoop == 0) { + return; + } + const int32_t endIdx = (curLoopCnt < loopCount - 1) && (sioLayerLoop == ringRankSize - 1) ? 
1 : 0; + for (int32_t i = 0; i <= endIdx; ++i) { + for (int32_t blockLoop = 0; blockLoop < SIO_CORE_SCALE; ++blockLoop) { + localBlockIdx = (blockIdx - INPUT_CORE_NUM) * SIO_CORE_SCALE + blockLoop; + sioQue = &(sioQueList[blockLoop]); + SioReduceByCore(curLoopCnt + i, (sioLayerLoop + i) % ringRankSize, ipcQueIdx + i); + } + } + } + + FORCE_INLINE_AICORE void SioReduceByCore(int32_t newLoopCnt, int32_t newLayerLoop, int32_t newQueIdx) + { + const int32_t targetSioLayerId = (sioLayerId + (ringRankSize - 1 - newLayerLoop)) % ringRankSize; + const int32_t targetRankOffset = targetSioLayerId * RING_LAYER_NUM + (ringLayerId + 1) % RING_LAYER_NUM; + + sync.WaitSyncFlag(magic, newQueIdx, INPUT_FLAG + localBlockIdx, rank); + BuildCoreDataNum(newLoopCnt, targetRankOffset); + srcIpcTensor = inputTensor[targetRankOffset * totalBlockDataNum + newLoopCnt * dmaPerLoop + + localBlockIdx * coreDataNum]; + dstIpcTensor = (*sioQue).EnQue(); + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcIpcTensor, dstIpcTensor, atomOp); + sync.SetSyncFlag(magic, newQueIdx, SIO_REDUCE_FLAG + localBlockIdx, sioPeerRankId); + } + + FORCE_INLINE_AICORE void BuildCoreDataNum(int32_t processLoopIdx, int32_t targetRankOffset) + { + const int64_t damCurLoop = (processLoopIdx == loopCount - 1) ? + (targetRankOffset == rankSize - 1 ? dmaLastRankLoop : dmaLastLoop) : dmaPerLoop; + coreDataNum = ipcBlockNum / RING_CORE_NUM; + const int32_t maxIdx = damCurLoop / coreDataNum; + const int32_t lastIdx = maxIdx >= RING_CORE_NUM ? (RING_CORE_NUM - 1) : maxIdx; + + lastCoreDataNum = damCurLoop - lastIdx * coreDataNum; + curCoreDataNum = localBlockIdx < lastIdx ? coreDataNum : (localBlockIdx == lastIdx ? lastCoreDataNum : 0); + } + + FORCE_INLINE_AICORE void SioGather() + { + for (int32_t blockLoop = 0; blockLoop < SIO_CORE_SCALE; ++blockLoop) { + localBlockIdx = (blockIdx - INPUT_CORE_NUM) * SIO_CORE_SCALE + blockLoop; + sioGatherSrc1Que = &(sioGatherSrc1QueList[blockLoop]); + sioGatherSrc2Que = &(sioGatherSrc2QueList[blockLoop]); + sioGatherDstQue = &(sioGatherDstQueList[blockLoop]); + SioGatherByCore(); + } + } + + FORCE_INLINE_AICORE void SioGatherByCore() + { + const int32_t targetSioLayerId = (sioLayerId + (ringRankSize - sioLayerLoop)) % ringRankSize; + const int32_t targetRankOffset = targetSioLayerId * RING_LAYER_NUM + ringLayerId; + + sync.WaitSyncFlag(magic, gatherQueIdx, RING_GATHER_FLAG + localBlockIdx, rank); + if (gatherQueIdx >= SIO_GATHER_QUE_DEPTH) { + sync.WaitSyncFlag(magic, gatherQueIdx - SIO_GATHER_QUE_DEPTH, SIO_GATHER_OUTPUT_FLAG + localBlockIdx, rank); + } + BuildCoreDataNum(curLoopCnt, targetRankOffset); + srcIpcTensor = (sioLayerLoop == 0 ? 
(*sioGatherSrc1Que).ReadFront() : (*sioGatherSrc2Que).ReadFront()); + dstIpcTensor = (*sioGatherDstQue).ReadFront(); + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcIpcTensor, dstIpcTensor, COPYONLY); + sync.SetSyncFlag(magic, gatherQueIdx, SIO_GATHER_PEER_FLAG + localBlockIdx, sioPeerRankId); + sync.SetSyncFlag(magic, gatherQueIdx, SIO_GATHER_FLAG + localBlockIdx, rank); + } + + FORCE_INLINE_AICORE void RingReduce() + { + if (sioLayerLoop == 0) { + ringDstQue.ReadFront(); + return; + } + + const int32_t consumedQueIdx = ipcQueIdx - 1; + sync.WaitSyncFlag(magic, consumedQueIdx + 1, SIO_REDUCE_FLAG + localBlockIdx, rank); + if (sioLayerLoop == 1) { + sync.WaitSyncFlag(magic, consumedQueIdx, SIO_REDUCE_FLAG + localBlockIdx, ringPrevRankId); + } else { + sync.WaitSyncFlag(magic, consumedQueIdx - 1, RING_REDUCE_FLAG + localBlockIdx, + ringPrevRankId); + } + const int32_t targetSioLayerId = (sioLayerId + (ringRankSize - 1 -sioLayerLoop)) % ringRankSize; + const int32_t targetRankOffset = targetSioLayerId * RING_LAYER_NUM + ringLayerId; + BuildCoreDataNum(curLoopCnt, targetRankOffset); + srcIpcTensor = ringSrcQue.ReadFront(); + dstIpcTensor = ringDstQue.ReadFront(); + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcIpcTensor, dstIpcTensor, atomOp); + sync.SetSyncFlag(magic, consumedQueIdx, RING_REDUCE_FLAG + localBlockIdx, rank); + sync.SetSyncFlag(magic, consumedQueIdx, RING_REDUCE_PEER_FLAG + localBlockIdx, ringPrevRankId); + } + + FORCE_INLINE_AICORE void RingGather() + { + if (sioLayerLoop == 0) { + sync.SetSyncFlag(magic, gatherQueIdx, RING_GATHER_FLAG + localBlockIdx, rank); + sync.SetSyncFlag(magic, gatherQueIdx, RING_GATHER_PEER_FLAG + localBlockIdx, ringPrevRankId); + return; + } + + const int32_t targetSioLayerId = (sioLayerId + (ringRankSize - sioLayerLoop)) % ringRankSize; + const int32_t targetRankOffset = targetSioLayerId * RING_LAYER_NUM + ringLayerId; + sync.WaitSyncFlag(magic, gatherQueIdx - 1, RING_GATHER_FLAG + localBlockIdx, ringPrevRankId); + if (gatherQueIdx > RING_GATHER_QUE_DEPTH) { + sync.WaitSyncFlag(magic, gatherQueIdx - RING_GATHER_QUE_DEPTH, OUTPUT_FLAG + localBlockIdx, rank); + if (targetRankOffset != ringPrevRankId) { + sync.WaitSyncFlag(magic, gatherQueIdx - RING_GATHER_QUE_DEPTH, RING_GATHER_PEER_FLAG + localBlockIdx, + rank); + } + } + + BuildCoreDataNum(curLoopCnt, targetRankOffset); + if (sioLayerLoop == 1) { + srcIpcTensor = ringSrcQue.ReadFront(); + } else { + srcIpcTensor = ringGatherSrcQue.ReadFront(); + } + dstIpcTensor = ringGatherDstQue.ReadFront(); + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcIpcTensor, dstIpcTensor, COPYONLY); + sync.SetSyncFlag(magic, gatherQueIdx, RING_GATHER_FLAG + localBlockIdx, rank); + if (gatherQueIdx > 0) { + sync.SetSyncFlag(magic, gatherQueIdx - 1, RING_GATHER_PEER_FLAG + localBlockIdx, ringPrevRankId); + } + if (sioLayerLoop == ringRankSize - 1) { + ringGatherSrcQue.ReadFront(); + } + } + + FORCE_INLINE_AICORE void PrepareOutput() + { + for (int32_t blockLoop = 0; blockLoop < OUTPUT_CORE_SCALE; ++blockLoop) { + localBlockIdx = (blockIdx - (INPUT_CORE_NUM + SIO_CORE_NUM + RING_CORE_NUM)) * OUTPUT_CORE_SCALE + + blockLoop; + if (sioLayerLoop < ringRankSize - 1) { + outputSrc3QueList[blockLoop].ReadFront(); + } + } + } + + FORCE_INLINE_AICORE void Ipc2Output() + { + for (int32_t blockLoop = 0; blockLoop < OUTPUT_CORE_SCALE; ++blockLoop) { + localBlockIdx = (blockIdx - (INPUT_CORE_NUM + SIO_CORE_NUM + RING_CORE_NUM)) * OUTPUT_CORE_SCALE + + blockLoop; + outputSrc1Que = &(outputSrc1QueList[blockLoop]); + outputSrc2Que = 
&(outputSrc2QueList[blockLoop]); + outputSrc3Que = &(outputSrc3QueList[blockLoop]); + Ipc2OutputByCore(); + } + } + + FORCE_INLINE_AICORE void Ipc2OutputByCore() + { + const int32_t targetSioLayerId = (sioLayerId + (ringRankSize - sioLayerLoop)) % ringRankSize; + const int32_t targetSioRankOffset = targetSioLayerId * RING_LAYER_NUM + (ringLayerId + 1) % RING_LAYER_NUM; + const int32_t targetRingRankOffset = targetSioLayerId * RING_LAYER_NUM + ringLayerId; + BuildCoreDataNum(curLoopCnt, targetSioRankOffset); + sync.WaitSyncFlag(magic, gatherQueIdx, SIO_GATHER_PEER_FLAG + localBlockIdx, rank); + srcIpcTensor = (*outputSrc1Que).ReadFront(); + dstIpcTensor = outputTensor[targetSioRankOffset * totalBlockDataNum + curLoopCnt * dmaPerLoop + + localBlockIdx * coreDataNum]; + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcIpcTensor, dstIpcTensor, COPYONLY); + sync.SetSyncFlag(magic, gatherQueIdx, SIO_GATHER_OUTPUT_FLAG + localBlockIdx, sioPeerRankId); + BuildCoreDataNum(curLoopCnt, targetRingRankOffset); + sync.WaitSyncFlag(magic, gatherQueIdx, SIO_GATHER_FLAG + localBlockIdx, rank); + srcIpcTensor = sioLayerLoop == 0 ? (*outputSrc3Que).ReadFront() : (*outputSrc2Que).ReadFront(); + dstIpcTensor = outputTensor[targetRingRankOffset * totalBlockDataNum + + curLoopCnt * dmaPerLoop + localBlockIdx * coreDataNum]; + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcIpcTensor, dstIpcTensor, COPYONLY); + sync.SetSyncFlag(magic, gatherQueIdx, OUTPUT_FLAG + localBlockIdx, rank); + } +}; +#endif // LCCL_ALLREDUCE_HIERARCHY_DOUBLE_RING_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/91093/reduce_scatter_big_data_91093_4step.h b/comm/lcal/src/ascendc_kernels/91093/reduce_scatter_big_data_91093_4step.h new file mode 100644 index 0000000000000000000000000000000000000000..d97e24a794ab3d334f80f5b2dab8ba28fc274b3a --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/91093/reduce_scatter_big_data_91093_4step.h @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef LCCL_REDUCE_SCATTER_BIG_DATA_91093_H +#define LCCL_REDUCE_SCATTER_BIG_DATA_91093_H + +#include "sync_collectives.h" +#include "collectives.h" +#include "ipc_queue.h" + +constexpr int PER_STEP_BLOCKNUM = 8; +constexpr int ARRAY_MAX_SIZE = 10; +constexpr int NUM_OF_TWO = 2; +constexpr int NUM_OF_THREE = 3; +constexpr int NUM_OF_FOUR = 4; + +template +class ReduceScatterBigData91093 : protected Collectives { +public: + __aicore__ inline ReduceScatterBigData91093(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + __aicore__ inline void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + constexpr int IPC_QUEUE_DEPTH_91093 = NUM_OF_FOUR; + atomOp = op; + relaBlockIdx = blockIdx % PER_STEP_BLOCKNUM; + ipcSizeOfBlock = IPC_BUFF_MAX_SIZE / rankSize; + ipcNumOfBlock = ipcSizeOfBlock / sizeof(T); + ipcBlockNum = ipcNumOfBlock / IPC_QUEUE_DEPTH_91093; + totalBlockDataNum = len; + loopCount = CeilDiv(totalBlockDataNum, ipcBlockNum); + dstOutputGlobal.SetGlobalBuffer((__gm__ T*)output, totalBlockDataNum); + if ((rank % NUM_OF_TWO) == 0) { + adjPeerRank = rank + 1; + } else { + adjPeerRank = rank - 1; + } + StepRankPerCoreInit(); + IpcQueueInit(); + if ((blockIdx / PER_STEP_BLOCKNUM) == 0) { + for (int i = 0; i < stepOneRankPerCore; i++) { + srcInputGlobal[i].SetGlobalBuffer((__gm__ T*)input + (blockIdx * stepOneOriginRankPerCore + i) * + totalBlockDataNum, totalBlockDataNum); + } + } + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + } + + __aicore__ inline void StepRankPerCoreInit() + { + int halfRankSize = rankSize / NUM_OF_TWO; + stepOneOriginRankPerCore = CeilDiv(rankSize, PER_STEP_BLOCKNUM); + stepTwoOriginRankPerCore = CeilDiv(halfRankSize, PER_STEP_BLOCKNUM); + stepThreeOriginRankPerCore = CeilDiv(halfRankSize, PER_STEP_BLOCKNUM); + stepOneInUseBlockNum = CeilDiv(rankSize, stepOneOriginRankPerCore); + stepTwoInUseBlockNum = CeilDiv(halfRankSize, stepTwoOriginRankPerCore); + stepThreeInUseBlockNum = CeilDiv(halfRankSize, stepThreeOriginRankPerCore); + if ((blockIdx / PER_STEP_BLOCKNUM) == 0) { + if ((blockIdx % PER_STEP_BLOCKNUM) == (stepOneInUseBlockNum - 1)) { + stepOneRankPerCore = rankSize - (blockIdx % PER_STEP_BLOCKNUM) * stepOneOriginRankPerCore; + } else { + stepOneRankPerCore = stepOneOriginRankPerCore; + } + } else if ((blockIdx / PER_STEP_BLOCKNUM) == 1) { + if ((blockIdx % PER_STEP_BLOCKNUM) == (stepTwoInUseBlockNum - 1)) { + stepTwoRankPerCore = halfRankSize - (blockIdx % PER_STEP_BLOCKNUM) * stepTwoOriginRankPerCore; + } else { + stepTwoRankPerCore = stepTwoOriginRankPerCore; + } + } else if ((blockIdx / PER_STEP_BLOCKNUM) == NUM_OF_TWO || (blockIdx / PER_STEP_BLOCKNUM) == NUM_OF_THREE) { + if (((blockIdx - PER_STEP_BLOCKNUM * NUM_OF_TWO) / NUM_OF_TWO) == (stepThreeInUseBlockNum - 1)) { + stepThreeRankPerCore = halfRankSize - ((blockIdx - PER_STEP_BLOCKNUM * NUM_OF_TWO) / + NUM_OF_TWO) * stepThreeOriginRankPerCore; + } else { + stepThreeRankPerCore = stepThreeOriginRankPerCore; + } + } + } + + __aicore__ inline void IpcQueueInit() + { + int ipcRank; + if ((blockIdx / PER_STEP_BLOCKNUM) == 0) { + for (int i = 0; i < stepOneRankPerCore; i++) { + ipcRank = blockIdx * stepOneOriginRankPerCore + i; + writeIpcQue[i].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + ipcRank * + ipcNumOfBlock * sizeof(T), ipcNumOfBlock, ipcBlockNum); + } + } else if ((blockIdx / PER_STEP_BLOCKNUM) == 1) { + for (int i = 0; i < stepTwoRankPerCore; i++) { + ipcRank = ((blockIdx 
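+                // step two walks this core's share of same-parity ranks: the core-local index is
+                // scaled by stepTwoOriginRankPerCore, doubled, then offset by this rank's SIO parity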
% PER_STEP_BLOCKNUM) * stepTwoOriginRankPerCore + i) * + NUM_OF_TWO + (rank % NUM_OF_TWO); + readIpcQue[i].Init(&sync, magic, shareAddrs[adjPeerRank] + IPC_DATA_OFFSET + ipcRank * + ipcNumOfBlock * sizeof(T), ipcNumOfBlock, ipcBlockNum); + writeIpcQue[i].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + ipcRank * + ipcNumOfBlock * sizeof(T), ipcNumOfBlock, ipcBlockNum); + } + } else if ((blockIdx / PER_STEP_BLOCKNUM) == NUM_OF_TWO || (blockIdx / PER_STEP_BLOCKNUM) == NUM_OF_THREE) { + for (int i = 0; i < stepThreeRankPerCore; i++) { + stepThreeRank = (((blockIdx - PER_STEP_BLOCKNUM * NUM_OF_TWO) / NUM_OF_TWO) * + stepThreeOriginRankPerCore + i) * NUM_OF_TWO + (rank % NUM_OF_TWO); + writeIpcQue[i].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + rank * + ipcNumOfBlock * sizeof(T), ipcNumOfBlock, ipcBlockNum); + readIpcQue[i].Init(&sync, magic, shareAddrs[stepThreeRank] + IPC_DATA_OFFSET + rank * + ipcNumOfBlock * sizeof(T), ipcNumOfBlock, ipcBlockNum); + } + } else if (blockIdx == (NUM_OF_FOUR * PER_STEP_BLOCKNUM)) { + readIpcQue[0].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + rank * + ipcNumOfBlock * sizeof(T), ipcNumOfBlock, ipcBlockNum); + } + } + + __aicore__ inline void Process() + { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + int stepIndex = blockIdx / PER_STEP_BLOCKNUM; + if (stepIndex == 0 && ((relaBlockIdx * stepOneOriginRankPerCore) >= rankSize)) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + if (stepIndex == 1 && ((relaBlockIdx * stepTwoOriginRankPerCore) >= (rankSize / NUM_OF_TWO))) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + if ((stepIndex == NUM_OF_TWO || stepIndex == NUM_OF_THREE) && ((blockIdx - PER_STEP_BLOCKNUM * + NUM_OF_TWO) / NUM_OF_TWO * stepThreeOriginRankPerCore) >= (rankSize / NUM_OF_TWO)) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + if (stepIndex == 0) { + StepOneProcess(); + } else if (stepIndex == 1) { + StepTwoProcess(); + } else if ((stepIndex == NUM_OF_TWO || stepIndex == NUM_OF_THREE) && ((blockIdx % NUM_OF_TWO) == 0)) { + StepThreeProcess(); + } else if (blockIdx == (NUM_OF_FOUR * PER_STEP_BLOCKNUM)) { + StepFourProcess(); + } + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + } + + __aicore__ inline void StepOneProcess() + { + for (int i = 0; i < stepOneRankPerCore; i++) { + if ((blockIdx * stepOneOriginRankPerCore + i) % NUM_OF_TWO == rank % NUM_OF_TWO) { + if ((blockIdx * stepOneOriginRankPerCore + i) == rank) { + waitWriteRankArr[i] = rank; + waitWriteBlockArr[i] = PER_STEP_BLOCKNUM * NUM_OF_FOUR; + } else { + waitWriteRankArr[i] = blockIdx * stepOneOriginRankPerCore + i; + waitWriteBlockArr[i] = PER_STEP_BLOCKNUM * NUM_OF_TWO + ((rank / NUM_OF_TWO) / + stepThreeOriginRankPerCore) * NUM_OF_TWO; + } + } else { + waitWriteRankArr[i] = adjPeerRank; + waitWriteBlockArr[i] = PER_STEP_BLOCKNUM + ((blockIdx * stepOneOriginRankPerCore + i) / + NUM_OF_TWO) / stepTwoOriginRankPerCore; + } + } + InputToIpcProcess(waitWriteRankArr, waitWriteBlockArr, stepOneRankPerCore); + } + __aicore__ inline void StepTwoProcess() + { + int waitReadRank; + int processRank; + waitReadRank = adjPeerRank; + for (int i = 0; i < stepTwoRankPerCore; i++) { + processRank = (relaBlockIdx * stepTwoOriginRankPerCore + i) * NUM_OF_TWO + (rank % NUM_OF_TWO); + waitReadRankArr[i] = waitReadRank; + waitReadBlockArr[i] = processRank / stepOneOriginRankPerCore; + if (processRank == rank) { + waitWriteRankArr[i] = rank; + waitWriteBlockArr[i] = PER_STEP_BLOCKNUM * 
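+                    // this rank's own chunk is drained directly by the single step-four output core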
NUM_OF_FOUR; + } else { + waitWriteRankArr[i] = processRank; + waitWriteBlockArr[i] = PER_STEP_BLOCKNUM * NUM_OF_TWO + ((rank / NUM_OF_TWO) / + stepThreeOriginRankPerCore) * NUM_OF_TWO; + } + } + SioAtomicToIpcProcess(waitReadRankArr, waitReadBlockArr, waitWriteRankArr, + waitWriteBlockArr, stepTwoRankPerCore); + } + + __aicore__ inline void StepThreeProcess() + { + for (int i = 0; i < stepThreeRankPerCore; i++) { + waitReadRankArr[i] = (((blockIdx - PER_STEP_BLOCKNUM * NUM_OF_TWO) / NUM_OF_TWO) * + stepThreeOriginRankPerCore + i) * NUM_OF_TWO + (rank % NUM_OF_TWO); + waitReadBlockArr[i] = PER_STEP_BLOCKNUM + (rank / NUM_OF_TWO) / stepTwoOriginRankPerCore; + waitWriteRankArr[i] = rank; + waitWriteBlockArr[i] = PER_STEP_BLOCKNUM * NUM_OF_FOUR; + } + HccsAtomicToIpcProcess(waitReadRankArr, waitReadBlockArr, waitWriteRankArr, + waitWriteBlockArr, stepThreeRankPerCore); + } + + __aicore__ inline void StepFourProcess() + { + for (int i = 0; i < stepThreeInUseBlockNum; i++) { + waitReadRankArr[i] = rank; + waitReadBlockArr[i] = PER_STEP_BLOCKNUM * NUM_OF_TWO + i * NUM_OF_TWO; + } + IpcToOutputProcess(waitReadRankArr, waitReadBlockArr, stepThreeInUseBlockNum); + } + + __aicore__ inline void InputToIpcProcess(int *waitWriteRank, int *waitWriteBlock, int waitCount) + { + int processBlockNum = ipcBlockNum; + for (int count = 0; count < loopCount; count++) { + if (count == (loopCount - 1)) { + processBlockNum = totalBlockDataNum - ipcBlockNum * count; + } + for (int i = 0; i < waitCount; i++) { + writeIpcQue[i].DeQue(waitWriteRank[i], waitWriteBlock[i]); + dstIpcGlobal = writeIpcQue[i].EnQue(); + CpInputToIpc(count, processBlockNum, srcInputGlobal[i]); + } + sync.SetInnerFlag(magic, count); + } + } + + __aicore__ inline void SioAtomicToIpcProcess(int *waitReadRank, int *waitReadBlock, int *waitWriteRank, + int *waitWriteBlock, int waitCount) + { + int processBlockNum = ipcBlockNum; + for (int count = 0; count < loopCount; count++) { + if (count == (loopCount - 1)) { + processBlockNum = totalBlockDataNum - ipcBlockNum * count; + } + for (int i = 0; i < waitCount; i++) { + srcIpcGlobal = readIpcQue[i].ReadFront(); + sync.WaitInnerFlag(magic, count, waitReadRank[i], waitReadBlock[i]); + sync.WaitInnerFlag(magic, count, rank, waitReadBlock[i]); + writeIpcQue[i].DeQue(waitWriteRank[i], waitWriteBlock[i]); + dstIpcGlobal = writeIpcQue[i].EnQue(); + SioAtomicAddToIpc(count, processBlockNum, waitWriteRankArr[i], i); + } + sync.SetInnerFlag(magic, count); + } + } + + __aicore__ inline void HccsAtomicToIpcProcess(int *waitReadRank, int *waitReadBlock, int *waitWriteRank, + int *waitWriteBlock, int waitCount) + { + int processBlockNum = ipcBlockNum; + for (int count = 0; count < loopCount; count++) { + if (count == (loopCount - 1)) { + processBlockNum = totalBlockDataNum - ipcBlockNum * count; + } + for (int i = 0; i < waitCount; i++) { + sync.WaitInnerFlag(magic, count, waitReadRank[i], waitReadBlock[i]); + sync.WaitInnerFlag(magic, count, rank, waitReadBlock[i]); + srcIpcGlobal = readIpcQue[i].ReadFront(); + writeIpcQue[i].DeQue(waitWriteRank[i], waitWriteBlock[i]); + dstIpcGlobal = writeIpcQue[i].EnQue(); + HccsAtomicAddToIpc(count, processBlockNum, waitReadRank[i], i); + } + sync.SetInnerFlag(magic, count); + } + } + + __aicore__ inline void IpcToOutputProcess(int *waitReadRank, int *waitReadBlock, int waitCount) + { + int processBlockNum = ipcBlockNum; + for (int count = 0; count < loopCount; count++) { + if (count == (loopCount - 1)) { + processBlockNum = totalBlockDataNum - ipcBlockNum * count; + } + for 
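+            // wait until every step-three producer has published this loop iteration before copying to user output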
(int i = 0; i < waitCount; i++) { + sync.WaitInnerFlag(magic, count, waitReadRank[i], waitReadBlock[i]); + } + srcIpcGlobal = readIpcQue[0].ReadFront(); + CpIpcToOutput(count, processBlockNum); + sync.SetInnerFlag(magic, count); + } + } + +protected: + GlobalTensor srcInputGlobal[ARRAY_MAX_SIZE]; + GlobalTensor srcIpcGlobal; + GlobalTensor dstIpcGlobal; + GlobalTensor dstOutputGlobal; + + int totalBlockDataNum; + int atomOp; + int relaBlockIdx; + int ipcBlockNum; + int loopCount; + int ipcNumOfBlock; + int ipcSizeOfBlock; + IpcQueue writeIpcQue[ARRAY_MAX_SIZE]; + IpcQueue readIpcQue[ARRAY_MAX_SIZE]; + int adjPeerRank; + int stepThreeRank; + int stepOneRankPerCore; + int stepTwoRankPerCore; + int stepThreeRankPerCore; + int stepOneOriginRankPerCore; + int stepTwoOriginRankPerCore; + int stepThreeOriginRankPerCore; + int stepOneInUseBlockNum; + int stepTwoInUseBlockNum; + int stepThreeInUseBlockNum; + int waitWriteRankArr[ARRAY_MAX_SIZE]; + int waitWriteBlockArr[ARRAY_MAX_SIZE]; + int waitReadRankArr[ARRAY_MAX_SIZE]; + int waitReadBlockArr[ARRAY_MAX_SIZE]; + +private: + __aicore__ inline void HccsAtomicAddToIpc(int num, int processBlockNum, int waitRank, int i) + { + if (waitRank != rank) { + CpGM2GMPingPong(processBlockNum * sizeof(T), srcIpcGlobal, dstIpcGlobal, atomOp); + } + } + + __aicore__ inline void CpInputToIpc(int num, int processBlockNum, GlobalTensor inputTensor) + { + CpGM2GMPingPong(processBlockNum * sizeof(T), inputTensor[num * ipcBlockNum], dstIpcGlobal, -1); + } + + __aicore__ inline void SioAtomicAddToIpc(int num, int processBlockNum, int processRank, int i) + { + CpGM2GMPingPong(processBlockNum * sizeof(T), srcIpcGlobal, dstIpcGlobal, atomOp); + } + + __aicore__ inline void CpIpcToOutput(int num, int processBlockNum) + { + CpGM2GMPingPong(processBlockNum * sizeof(T), srcIpcGlobal, dstOutputGlobal[num * ipcBlockNum], -1); + } +}; +#endif // LCCL_REDUCE_SCATTER_BIG_DATA_91093_H diff --git a/comm/lcal/src/ascendc_kernels/91093/reduce_scatter_hierarchy_double_ring.h b/comm/lcal/src/ascendc_kernels/91093/reduce_scatter_hierarchy_double_ring.h new file mode 100644 index 0000000000000000000000000000000000000000..1c7ac15695f3146ded4ff42151b5b721fae5ae91 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/91093/reduce_scatter_hierarchy_double_ring.h @@ -0,0 +1,227 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef LCCL_REDUCE_SCATTER_HIERARCHY_DOUBLE_RING_H +#define LCCL_REDUCE_SCATTER_HIERARCHY_DOUBLE_RING_H + +#include "sync_collectives.h" +#include "collectives.h" +#include "ipc_queue.h" +using namespace AscendC; + +template +class ReduceScatterHierarchyDoubleRing : protected Collectives { + constexpr static int32_t RING_LAYER_NUM = 2; + constexpr static int32_t INPUT_CORE_NUM = 12; + constexpr static int32_t SIO_CORE_NUM = 12; + constexpr static int32_t RING_CORE_NUM = 12; + constexpr static int32_t IPC_QUE_DEPTH = 32; + constexpr static int32_t INPUT_SIO_PEER_FLAG = 0 * RING_CORE_NUM; + constexpr static int32_t SIO_REDUCE_FLAG = 1 * RING_CORE_NUM; + constexpr static int32_t RING_REDUCE_FLAG = 2 * RING_CORE_NUM; + constexpr static int32_t RING_REDUCE_PEER_FLAG = 3 * RING_CORE_NUM; + constexpr static int32_t OUTPUT_FLAG = 4 * RING_CORE_NUM; + constexpr static int32_t INPUT_FLAG = 5 * RING_CORE_NUM; + + constexpr static int32_t INPUT_CORE_SCALE = RING_CORE_NUM / INPUT_CORE_NUM; + constexpr static int32_t SIO_CORE_SCALE = RING_CORE_NUM / SIO_CORE_NUM; + constexpr static int64_t BLOCK_NUM_ALIGN = BLOCK_SIZE / sizeof(T); + constexpr static int32_t BREAK_CYCLE = 10; + +public: + FORCE_INLINE_AICORE ReduceScatterHierarchyDoubleRing(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + atomOp = op; + DumpLcclLogInfo(LogId::INIT, static_cast(atomOp)); + blockNum = INPUT_CORE_NUM + SIO_CORE_NUM + RING_CORE_NUM; + if (blockIdx >= blockNum) { + DumpLcclLogInfo(LogId::INIT, static_cast(atomOp)); + return; + } + sioLayerId = rank / RING_LAYER_NUM; + ringLayerId = rank % RING_LAYER_NUM; + ringRankSize = rankSize / RING_LAYER_NUM; + ringNextRankId = (sioLayerId + 1) % ringRankSize * RING_LAYER_NUM + ringLayerId; + ringPrevRankId = (sioLayerId + (ringRankSize - 1)) % ringRankSize * RING_LAYER_NUM + ringLayerId; + sioPeerRankId = sioLayerId * RING_LAYER_NUM + (ringLayerId + 1) % RING_LAYER_NUM; + ipcBlockNum = IPC_BUFF_MAX_SIZE / IPC_QUE_DEPTH / sizeof(T); + totalBlockDataNum = len; + loopCount = CeilDiv(totalBlockDataNum, ipcBlockNum); + dmaPerLoop = ipcBlockNum; + dmaLastLoop = totalBlockDataNum - (loopCount - 1) * ipcBlockNum; + const int64_t dmaSizePerCore = ipcBlockNum / RING_CORE_NUM * sizeof(T); + if (blockIdx < INPUT_CORE_NUM) { + for (int32_t blockLoop = 0; blockLoop < INPUT_CORE_SCALE; ++blockLoop) { + localBlockIdx = blockIdx * INPUT_CORE_SCALE + blockLoop; + inputQueList[blockLoop].Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + } + } else if (blockIdx < INPUT_CORE_NUM + SIO_CORE_NUM) { + for (int32_t blockLoop = 0; blockLoop < SIO_CORE_SCALE; ++blockLoop) { + localBlockIdx = (blockIdx - INPUT_CORE_NUM) * SIO_CORE_SCALE + blockLoop; + sioQueList[blockLoop].Init(&sync, magic, shareAddrs[sioPeerRankId] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + } + } else { + localBlockIdx = (blockIdx - (INPUT_CORE_NUM + SIO_CORE_NUM)); + ringSrcQue.Init(&sync, magic, shareAddrs[ringPrevRankId] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + ringDstQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + + dmaSizePerCore * localBlockIdx, ipcBlockNum * IPC_QUE_DEPTH, ipcBlockNum); + } + inputTensor.SetGlobalBuffer((__gm__ T*) input); + outputTensor.SetGlobalBuffer((__gm__ T*) 
output); + DumpLcclLogInfo(LogId::INIT, static_cast(atomOp)); + } + + FORCE_INLINE_AICORE void Process() + { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + if (blockIdx >= blockNum) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + + for (curLoopCnt = 0; curLoopCnt < loopCount; ++curLoopCnt) { + const int64_t damCurLoop = (curLoopCnt == loopCount - 1) ? dmaLastLoop : dmaPerLoop; + coreDataNum = damCurLoop / RING_CORE_NUM; + lastCoreDataNum = damCurLoop - (RING_CORE_NUM - 1) * coreDataNum; + for (sioLayerLoop = 0; sioLayerLoop < ringRankSize; ++sioLayerLoop) { + if (blockIdx < INPUT_CORE_NUM) { + Input2Ipc(); + } else if (blockIdx < INPUT_CORE_NUM + SIO_CORE_NUM) { + SioReduce(); + } else { + RingReduceOutput(); + } + ++ipcQueIdx; + } + } + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + } + +private: + IpcQueue inputQueList[INPUT_CORE_SCALE]; + IpcQueue sioQueList[SIO_CORE_SCALE]; + IpcQueue ringSrcQue; + IpcQueue ringDstQue; + IpcQueue *inputQue = nullptr; + IpcQueue *sioQue = nullptr; + GlobalTensor inputTensor; + GlobalTensor outputTensor; + GlobalTensor srcTensor; + GlobalTensor dstTensor; + int atomOp = COPYONLY; + int32_t sioLayerId = 0; + int32_t ringLayerId = 0; + int32_t ringRankSize = 0; + int32_t ringNextRankId = 0; + int32_t ringPrevRankId = 0; + int32_t sioPeerRankId = 0; + int32_t localBlockIdx = 0; + int64_t ipcBlockNum = 0; + int64_t totalBlockDataNum = 0; + int64_t dmaPerLoop = 0; + int64_t dmaLastLoop = 0; + int32_t ipcQueIdx = 0; + int32_t loopCount = 0; + int32_t curLoopCnt = 0; + int32_t sioLayerLoop = 0; + int64_t coreDataNum = 0; + int64_t lastCoreDataNum = 0; + int64_t curCoreDataNum = 0; + + FORCE_INLINE_AICORE void Input2Ipc() + { + for (int32_t blockLoop = 0; blockLoop < INPUT_CORE_SCALE; ++blockLoop) { + localBlockIdx = blockIdx * INPUT_CORE_SCALE + blockLoop; + inputQue = &(inputQueList[blockLoop]); + Input2IpcByCore(); + } + } + + FORCE_INLINE_AICORE void Input2IpcByCore() + { + const int32_t targetSioLayerId = (sioLayerId + (ringRankSize - 1 - sioLayerLoop)) % ringRankSize; + const int32_t targetRankOffset = targetSioLayerId * RING_LAYER_NUM + ringLayerId; + curCoreDataNum = (localBlockIdx == RING_CORE_NUM - 1) ? lastCoreDataNum : coreDataNum; + srcTensor = inputTensor[targetRankOffset * totalBlockDataNum + curLoopCnt * ipcBlockNum + + localBlockIdx * coreDataNum]; + dstTensor = (*inputQue).EnQue(); + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcTensor, dstTensor, COPYONLY); + sync.SetSyncFlag(magic, ipcQueIdx, INPUT_FLAG + localBlockIdx, rank); + } + + FORCE_INLINE_AICORE void SioReduce() + { + for (int32_t blockLoop = 0; blockLoop < SIO_CORE_SCALE; ++blockLoop) { + localBlockIdx = (blockIdx - INPUT_CORE_NUM) * SIO_CORE_SCALE + blockLoop; + sioQue = &(sioQueList[blockLoop]); + SioReduceByCore(); + } + } + + FORCE_INLINE_AICORE void SioReduceByCore() + { + const int32_t targetSioLayerId = (sioLayerId + (ringRankSize - 1 - sioLayerLoop)) % ringRankSize; + const int32_t targetRankOffset = targetSioLayerId * RING_LAYER_NUM + (ringLayerId + 1) % RING_LAYER_NUM; + + curCoreDataNum = (localBlockIdx == RING_CORE_NUM - 1) ? 
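+        // the last ring core picks up the remainder of the uneven per-core split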
lastCoreDataNum : coreDataNum; + srcTensor = inputTensor[targetRankOffset * totalBlockDataNum + curLoopCnt * ipcBlockNum + + localBlockIdx * coreDataNum]; + dstTensor = (*sioQue).EnQue(); + if (ipcQueIdx == 0) { + sync.WaitSyncFlag(magic, ipcQueIdx, INPUT_FLAG + localBlockIdx, sioPeerRankId, BREAK_CYCLE); + } else { + sync.WaitSyncFlag(magic, ipcQueIdx, INPUT_FLAG + localBlockIdx, sioPeerRankId); + } + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcTensor, dstTensor, atomOp); + sync.SetSyncFlag(magic, ipcQueIdx, SIO_REDUCE_FLAG + localBlockIdx, sioPeerRankId); + } + + FORCE_INLINE_AICORE void RingReduceOutput() + { + if (sioLayerLoop == 0) { + ringDstQue.ReadFront(); + return; + } + curCoreDataNum = (localBlockIdx == RING_CORE_NUM - 1) ? lastCoreDataNum : coreDataNum; + srcTensor = ringSrcQue.ReadFront(); + dstTensor = ringDstQue.ReadFront(); + GlobalTensor srcOutTensor; + GlobalTensor dstOutTensor; + if (sioLayerLoop == ringRankSize - 1) { + ringSrcQue.ReadFront(); + srcOutTensor = dstTensor; + dstOutTensor = outputTensor[curLoopCnt * ipcBlockNum + localBlockIdx * coreDataNum]; + } + const int32_t consumedQueIdx = ipcQueIdx - 1; + if (consumedQueIdx == 0) { + sync.WaitSyncFlag(magic, consumedQueIdx, SIO_REDUCE_FLAG + localBlockIdx, ringPrevRankId, BREAK_CYCLE); + } else { + sync.WaitSyncFlag(magic, consumedQueIdx, SIO_REDUCE_FLAG + localBlockIdx, ringPrevRankId); + } + if (sioLayerLoop > 1) { + sync.WaitSyncFlag(magic, consumedQueIdx - 1, RING_REDUCE_FLAG + localBlockIdx, ringPrevRankId); + } + sync.WaitSyncFlag(magic, ipcQueIdx, INPUT_FLAG + localBlockIdx, rank); + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcTensor, dstTensor, atomOp); + if (sioLayerLoop != ringRankSize - 1) { + sync.SetSyncFlag(magic, consumedQueIdx, RING_REDUCE_FLAG + localBlockIdx, rank); + } else { + sync.WaitSyncFlag(magic, ipcQueIdx, SIO_REDUCE_FLAG + localBlockIdx, rank); + CpGM2GMPingPong(curCoreDataNum * sizeof(T), srcOutTensor, dstOutTensor, COPYONLY); + } + } +}; + +#endif // LCCL_REDUCE_SCATTER_HIERARCHY_DOUBLE_RING_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/CMakeLists.txt b/comm/lcal/src/ascendc_kernels/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..873cb6a663e5025c6e771ad139f057b4113a7a81 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/CMakeLists.txt @@ -0,0 +1,193 @@ +# +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# +include(../ascendc.cmake) +include_directories(.) 
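+# Build flow (summary): each lccl_op*.cpp is compiled into CCE objects (separate AIC/AIV
+# objects for the mix variants), linked by hand with ${CMAKE_CCE_LINKER}, and the resulting
+# kernel objects are concatenated into a single lccl_op.o by the lccl_op target below.
+# Hypothetical configure example: cmake -DENABLE_LCCL_910A5_OP=ON ..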
+option(ENABLE_LCCL_910A5_OP "ENABLE lccl_910A5_op library and compile options" OFF) + +file(GLOB_RECURSE KERNEL_FILES *.cpp) +set_source_files_properties(${KERNEL_FILES} PROPERTIES LANGUAGE CCE) + +# Regular operators +add_library(normal_lccl_op1_tmp OBJECT + lccl_op1.cpp + sync_collectives.h + collectives.h +) + +add_library(normal_lccl_op2_tmp OBJECT + lccl_op2.cpp + sync_collectives.h + collectives.h +) +# Set compile options +target_compile_options(normal_lccl_op1_tmp PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIV_ARCH} +) +# Set compile options +target_compile_options(normal_lccl_op2_tmp PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIV_ARCH} +) + +add_custom_target(normal_lccl_op1 + COMMAND ${CMAKE_CCE_LINKER} -m aicorelinux -Ttext=0 + "CMakeFiles/normal_lccl_op1_tmp.dir/lccl_op1.cpp.o" + ${SANITIZER_DEPEND_LIBS} + --static -o "CMakeFiles/normal_lccl_op1.dir/lccl_op1.cpp.o" --allow-multiple-definition +) +add_dependencies(normal_lccl_op1 normal_lccl_op1_tmp) +add_custom_target(normal_lccl_op2 + COMMAND ${CMAKE_CCE_LINKER} -m aicorelinux -Ttext=0 + "CMakeFiles/normal_lccl_op2_tmp.dir/lccl_op2.cpp.o" + ${SANITIZER_DEPEND_LIBS} + --static -o "CMakeFiles/normal_lccl_op2.dir/lccl_op2.cpp.o" --allow-multiple-definition +) +add_dependencies(normal_lccl_op2 normal_lccl_op2_tmp) + +# Mix operators with dump enabled +add_library(dump_lccl_op1_tmp_aic OBJECT + lccl_op1.cpp + sync_collectives.h + collectives.h +) +add_library(dump_lccl_op1_tmp_aiv OBJECT + lccl_op1.cpp + sync_collectives.h + collectives.h +) +target_compile_options(dump_lccl_op1_tmp_aic PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIC_ARCH} + -DENABLE_LCCL_DUMP + -DENABLE_LCCL_MIX +) +target_compile_options(dump_lccl_op1_tmp_aiv PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIV_ARCH} + --cce-long-call=true + -DENABLE_LCCL_DUMP + -DENABLE_LCCL_MIX +) + +add_custom_target(dump_lccl_op1 + COMMAND ${CMAKE_CCE_LINKER} -m aicorelinux -Ttext=0 + "CMakeFiles/dump_lccl_op1_tmp_aic.dir/lccl_op1.cpp.o" + "CMakeFiles/dump_lccl_op1_tmp_aiv.dir/lccl_op1.cpp.o" + ${SANITIZER_DEPEND_LIBS} + --static -o "CMakeFiles/dump_lccl_op1.dir/lccl_op1.cpp.o" --allow-multiple-definition +) +add_dependencies(dump_lccl_op1 dump_lccl_op1_tmp_aic dump_lccl_op1_tmp_aiv) + +add_library(dump_lccl_op2_tmp_aic OBJECT + lccl_op2.cpp + sync_collectives.h + collectives.h +) +add_library(dump_lccl_op2_tmp_aiv OBJECT + lccl_op2.cpp + sync_collectives.h + collectives.h +) +target_compile_options(dump_lccl_op2_tmp_aic PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIC_ARCH} + -DENABLE_LCCL_DUMP + -DENABLE_LCCL_MIX +) +target_compile_options(dump_lccl_op2_tmp_aiv PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIV_ARCH} + --cce-long-call=true + -DENABLE_LCCL_DUMP + -DENABLE_LCCL_MIX +) + +add_custom_target(dump_lccl_op2 + COMMAND ${CMAKE_CCE_LINKER} -m aicorelinux -Ttext=0 + "CMakeFiles/dump_lccl_op2_tmp_aic.dir/lccl_op2.cpp.o" + "CMakeFiles/dump_lccl_op2_tmp_aiv.dir/lccl_op2.cpp.o" + ${SANITIZER_DEPEND_LIBS} + --static -o "CMakeFiles/dump_lccl_op2.dir/lccl_op2.cpp.o" --allow-multiple-definition +) +add_dependencies(dump_lccl_op2 dump_lccl_op2_tmp_aic dump_lccl_op2_tmp_aiv) + +# Mix operators without dump + +add_library(mix_lccl_op1_tmp_aic OBJECT + lccl_op1.cpp + sync_collectives.h + collectives.h +) +add_library(mix_lccl_op1_tmp_aiv OBJECT + lccl_op1.cpp + sync_collectives.h + collectives.h +) +target_compile_options(mix_lccl_op1_tmp_aic PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIC_ARCH} + -DENABLE_LCCL_MIX +) +target_compile_options(mix_lccl_op1_tmp_aiv PRIVATE +
${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIV_ARCH} + --cce-long-call=true + -DENABLE_LCCL_MIX +) + +add_custom_target(mix_lccl_op1 + COMMAND ${CMAKE_CCE_LINKER} -m aicorelinux -Ttext=0 + "CMakeFiles/mix_lccl_op1_tmp_aic.dir/lccl_op1.cpp.o" + "CMakeFiles/mix_lccl_op1_tmp_aiv.dir/lccl_op1.cpp.o" + ${SANITIZER_DEPEND_LIBS} + --static -o "CMakeFiles/mix_lccl_op1.dir/lccl_op1.cpp.o" --allow-multiple-definition +) +add_dependencies(mix_lccl_op1 mix_lccl_op1_tmp_aic mix_lccl_op1_tmp_aiv) + +add_library(mix_lccl_op2_tmp_aic OBJECT + lccl_op2.cpp + sync_collectives.h + collectives.h +) +add_library(mix_lccl_op2_tmp_aiv OBJECT + lccl_op2.cpp + sync_collectives.h + collectives.h +) +target_compile_options(mix_lccl_op2_tmp_aic PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIC_ARCH} + -DENABLE_LCCL_MIX +) +target_compile_options(mix_lccl_op2_tmp_aiv PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIV_ARCH} + --cce-long-call=true + -DENABLE_LCCL_MIX +) + +add_custom_target(mix_lccl_op2 + COMMAND ${CMAKE_CCE_LINKER} -m aicorelinux -Ttext=0 + "CMakeFiles/mix_lccl_op2_tmp_aic.dir/lccl_op2.cpp.o" + "CMakeFiles/mix_lccl_op2_tmp_aiv.dir/lccl_op2.cpp.o" + ${SANITIZER_DEPEND_LIBS} + --static -o "CMakeFiles/mix_lccl_op2.dir/lccl_op2.cpp.o" --allow-multiple-definition +) +add_dependencies(mix_lccl_op2 mix_lccl_op2_tmp_aic mix_lccl_op2_tmp_aiv) + +add_custom_target(lccl_op + COMMAND echo "generating lccl op ... ENABLE_LCCL_910A5_OP=${ENABLE_LCCL_910A5_OP}" + COMMAND rm -f lccl_op.o + COMMAND find CMakeFiles -name "*.o" ! -path "*tmp*" | sort | xargs -I {} sed '1s/^/DDDD/' {} >> lccl_op.o + COMMAND truncate -c -s ${LCAL_1OP_BIN_SIZE} lccl_op.o + COMMAND rm -f ${LCAL_CCE_PATH} +) +add_dependencies(lccl_op dump_lccl_op1 dump_lccl_op2 mix_lccl_op1 mix_lccl_op2 normal_lccl_op1 normal_lccl_op2) diff --git a/comm/lcal/src/ascendc_kernels/allgather.h b/comm/lcal/src/ascendc_kernels/allgather.h new file mode 100644 index 0000000000000000000000000000000000000000..d015de722d8c6f1c15308f6e7fb2799f065b35ce --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/allgather.h @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef LCCL_ALLGATHER_H +#define LCCL_ALLGATHER_H + +#include "collectives.h" + +using namespace AscendC; + +constexpr int64_t MEM_DMA_UNIT_SIZE = MEM_DMA_UNIT_INT_NUM * sizeof(int64_t); + +constexpr int64_t STEP1 = 1; + +template + +class AllGather : public Collectives { +public: + FORCE_INLINE_AICORE AllGather(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + globalRank = (reinterpret_cast<__gm__ CommArgs *>(commArgs))->rank; + globalRankSize = (reinterpret_cast<__gm__ CommArgs *>(commArgs))->rankSize; + localRankSize = (reinterpret_cast<__gm__ CommArgs *>(commArgs))->localRankSize; + baseOffsetSize = IPC_DATA_OFFSET; + GetBlockDataCount(len, blockNum, offsetFromInput, countToShare); + offsetToShare = offsetFromInput; + + inputGm.SetGlobalBuffer((__gm__ T*)input + offsetFromInput, countToShare); + if (extraFlag & ExtraFlag::RDMA) { + blockNumPerRank = blockNum / localRankSize; + useCoreNumToOutput = blockNumPerRank * localRankSize; + } else { + blockNumPerRank = blockNum / rankSize; + useCoreNumToOutput = blockNumPerRank * rankSize; + } + if (blockIdx >= useCoreNumToOutput) { + return; + } + GetBlockDataCount(len, blockNumPerRank, offsetFromShare, countToOutput); + blockRank = blockIdx / blockNumPerRank; + offsetToOutput = blockRank * len + offsetFromShare; + + if ((extraFlag & ExtraFlag::RDMA) == 0) { + outputGm.SetGlobalBuffer((__gm__ T*)output + offsetToOutput, countToOutput); + } + } + FORCE_INLINE_AICORE void Process() + { + if (extraFlag & ExtraFlag::RDMA) { + shareGm.SetGlobalBuffer((__gm__ T*)(shareAddrs[rank % localRankSize] + baseOffsetSize) + + len * globalRank + offsetToShare, countToShare); + if (countToShare > 0) { + CpGM2GMPingPong(countToShare * sizeof(T), inputGm, shareGm, COPYONLY); + } + sync.SetInnerFlag(magic, STEP1); + sync.WaitRankInnerFlag(magic, STEP1, blockRank); + if (blockIdx >= useCoreNumToOutput) { + return; + } + outputGm.SetGlobalBuffer((__gm__ T*)(shareAddrs[globalRank % localRankSize] + baseOffsetSize) + + len * (globalRank / localRankSize) * localRankSize + offsetToOutput, countToOutput); + shareGm.SetGlobalBuffer((__gm__ T*)(shareAddrs[blockRank] + baseOffsetSize) + + len * (globalRank / localRankSize) * localRankSize + offsetToOutput, countToOutput); + if (countToOutput > 0 && blockRank != rank) { + CpGM2GMPingPong(countToOutput * sizeof(T), shareGm, outputGm, COPYONLY); + } + } else { + shareGm.SetGlobalBuffer((__gm__ T*)(shareAddrs[rank] + baseOffsetSize) + offsetToShare, countToShare); + if (countToShare > 0) { + CpGM2GM(shareGm, inputGm, countToShare, COPYONLY); + } + sync.SetInnerFlag(magic, STEP1); + sync.WaitRankInnerFlag(magic, STEP1, blockRank); + if (blockIdx >= useCoreNumToOutput) { + return; + } + shareGm.SetGlobalBuffer((__gm__ T*)(shareAddrs[blockRank] + baseOffsetSize) + offsetFromShare, + countToOutput); + if (countToOutput > 0) { + CpGM2GM(outputGm, shareGm, countToOutput, COPYONLY); + } + } + } + +private: + + FORCE_INLINE_AICORE void GetBlockDataCount( + const int64_t dataLen, const int64_t useBlockNum, int64_t& blockDataOffset, int64_t& blockDataCount) + { + blockDataCount = CeilDiv(dataLen, useBlockNum); + blockDataCount = blockDataCount > MEM_DMA_UNIT_SIZE / sizeof(T) ? 
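+        // enforce a minimum of one DMA unit per core; cores whose offset falls past the end of the data get zero work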
+ blockDataCount : MEM_DMA_UNIT_SIZE / sizeof(T); + blockDataOffset = blockIdx % useBlockNum * blockDataCount; + if (blockDataOffset >= dataLen) { + blockDataOffset = dataLen; + blockDataCount = 0; + return; + } + if (blockDataOffset + blockDataCount > dataLen) { + blockDataCount = dataLen - blockDataOffset; + } + } +private: + GlobalTensor inputGm; + GlobalTensor outputGm; + GlobalTensor shareGm; + + int64_t baseOffsetSize; + int64_t offsetFromInput; + int64_t offsetToShare; + int64_t countToShare; + int64_t useCoreNumToOutput; + int64_t blockNumPerRank; + int64_t blockRank; + int64_t offsetFromShare; + int64_t offsetToOutput; + int64_t countToOutput; + int globalRank; + int globalRankSize; + int localRankSize; +}; + +#endif // LCCL_ALLGATHER_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/allreduce_big_data.h b/comm/lcal/src/ascendc_kernels/allreduce_big_data.h new file mode 100644 index 0000000000000000000000000000000000000000..f8ce0276dfeecea54c8c36b130768685827c3405 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/allreduce_big_data.h @@ -0,0 +1,256 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LCCL_ALLREDUCE_BIG_DATA_H +#define LCCL_ALLREDUCE_BIG_DATA_H + +#include "allreduce_quant.h" +#include "sync_collectives.h" +#include "ipc_queue.h" +using namespace AscendC; + +template +class AllReduceBigData : protected AllReduceQuant { + constexpr static int QUEUE_DEPTH = 4; + constexpr static T oneCast = (T) 1; + +public: + FORCE_INLINE_AICORE AllReduceBigData(int rank, int rankSize, uint32_t extraFlag) + : AllReduceQuant(rank, rankSize, extraFlag) {} + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + if constexpr(!std::is_same_v) { + BuildScaleOffset(scale, scaleCount, offset); + } + + if (blockIdx >= PING_PONG_SIZE * rankSize) { + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + return; + } + + perStepBlockNum = rankSize; + + __gm__ CommArgs *localArgs = reinterpret_cast<__gm__ CommArgs *>(commArgs); + int globalRankSize = localArgs->rankSize <= 0 ? rankSize : localArgs->rankSize; + int localRankSize = localArgs->localRankSize <= 0 ?
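+            // rankSize/localRankSize may be unset (<= 0) in the comm args; fall back to the kernel's rankSize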
rankSize : localArgs->localRankSize; + int serverNum = globalRankSize / localRankSize; + int64_t ipcBuffMaxSizeAligned = IPC_BUFF_MAX_SIZE / (globalRankSize + serverNum - 1) / + QUEUE_DEPTH / sizeof(T) /scaleNum * scaleNum * QUEUE_DEPTH * sizeof(T) * globalRankSize; + curBlockSize = ipcBuffMaxSizeAligned / localRankSize / QUEUE_DEPTH; + curBlockNum = curBlockSize / sizeof(T); + atomOp = op; + int64_t perQueSize = ipcBuffMaxSizeAligned / localRankSize; + int64_t perQueNum = perQueSize / sizeof(T); + + for (int i = 0; i < rankSize; ++i) { + rankList[i] = i; + coreIdxList[i] = rankSize + blockIdx % perStepBlockNum; + } + + peerRank = blockIdx % perStepBlockNum; + perRankDataNum = GetDataCount(len, rankSize) / scaleNum * scaleNum; + + curRankDataNum = perRankDataNum; + if (blockIdx % perStepBlockNum == rankSize - 1) { + curRankDataNum = len - (rankSize - 1) * perRankDataNum; + } + + pullRankDataNum = (rank == rankSize - 1) ? (len - rank * perRankDataNum) : perRankDataNum; + + inputBuffOffsetNum = blockIdx % rankSize * perRankDataNum; + + inputGt.SetGlobalBuffer((__gm__ U*)input + inputBuffOffsetNum, curRankDataNum); + + outputBuffOffsetNum = peerRank * perRankDataNum; + + outputGt.SetGlobalBuffer((__gm__ T*)output + outputBuffOffsetNum, curRankDataNum); + + inputIpcGtOffsetNum = perQueSize * (blockIdx % perStepBlockNum); + + if (blockIdx / perStepBlockNum == 0) { + inputQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + inputIpcGtOffsetNum, + perQueNum, curBlockNum); + } else { + srcQue.Init(&sync, magic, shareAddrs[peerRank] + IPC_DATA_OFFSET + rank * perQueSize, + perQueNum, curBlockNum); + dstQue.Init(&sync, magic, shareAddrs[rank] + IPC_DATA_OFFSET + rank * perQueSize, + perQueNum, curBlockNum); + pullQue.Init(&sync, magic, shareAddrs[peerRank] + IPC_DATA_OFFSET + peerRank * perQueSize, + perQueNum, curBlockNum); + } + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + } + + FORCE_INLINE_AICORE void Process() + { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + if (blockIdx >= PING_PONG_SIZE * rankSize) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + + if constexpr (!std::is_same_v) { + if (rankSize == 1 && blockIdx == 0) { + int64_t remain = curRankDataNum; + int64_t loopCount = CeilDiv(curRankDataNum, curBlockNum); + int64_t count = 0; + while (count < loopCount) { + int64_t copyNum = (remain < curBlockNum) ? remain : curBlockNum; + Collectives::CpGM2GMPingPong(copyNum * sizeof(T), inputGt[count * curBlockNum], + outputGt[count * curBlockNum], COPYONLY); + remain -= curBlockNum; + ++count; + } + } + if (rankSize == 1) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + } + + if (blockIdx / perStepBlockNum == 0) { + Producer(); + } else { + Consumer(); + } + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + } +private: + FORCE_INLINE_AICORE void Producer() + { + int64_t loopCount = CeilDiv(curRankDataNum, curBlockNum); + int64_t remain = curRankDataNum; + int count = 0; + while (count < loopCount) { + inputQue.DeQue(rankList, coreIdxList, rankSize); + GlobalTensor outputGm = inputQue.EnQue(); + int64_t copyNum = (remain < curBlockNum) ? 
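+            // the final iteration copies only the tail that is left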
remain : curBlockNum; + if constexpr (std::is_same_v) { + Collectives::CpGM2GMPingPong(copyNum * sizeof(T), inputGt[count * curBlockNum], outputGm, COPYONLY); + } else { + if (blockIdx != rank) { + GlobalTensor outputGmTmp; + outputGmTmp.SetGlobalBuffer((__gm__ U*)outputGm.GetPhyAddr()); + Collectives::CpGM2GMPingPong(copyNum * sizeof(U), inputGt[count * curBlockNum], outputGmTmp, + COPYONLY); + } else { + CpGM2GMWithScale(copyNum, inputGt[count * curBlockNum], outputGm, COPYONLY); + } + } + sync.SetInnerFlag(magic, count); + + remain = remain - curBlockNum; + count = count + 1; + } + } + + FORCE_INLINE_AICORE void Consumer() + { + int64_t atomLoopCount = CeilDiv(pullRankDataNum, curBlockNum); + int64_t atomRemain = pullRankDataNum; + int64_t loopCount = CeilDiv(curRankDataNum, curBlockNum); + int64_t remain = curRankDataNum; + int count = 0; + while (count < loopCount || count < atomLoopCount) { + if (peerRank != rank && count != atomLoopCount) { + sync.WaitInnerFlag(magic, count, rank, rank); + sync.WaitInnerFlag(magic, count, peerRank, rank); + + GlobalTensor inputGm = srcQue.ReadFront(); + GlobalTensor outputGm = dstQue.EnQue(); + + int64_t atomCopyNum = (atomRemain < curBlockNum) ? atomRemain : curBlockNum; + if constexpr (std::is_same_v) { + Collectives::CpGM2GMPingPong(atomCopyNum * sizeof(T), inputGm, outputGm, atomOp); + } else { + GlobalTensor inputGmTmp; + inputGmTmp.SetGlobalBuffer((__gm__ U*)inputGm.GetPhyAddr()); + CpGM2GMWithScale(atomCopyNum, inputGmTmp, outputGm, atomOp); + } + atomRemain = atomRemain - curBlockNum; + } + sync.SetOuterFlag(magic, count); + if (count == loopCount) { + break; + } + sync.WaitOneRankPartOuterFlag(magic, count, peerRank, rankSize, rankSize); + if (!(extraFlag & ExtraFlag::RDMA)) { + GlobalTensor pullGm = pullQue.ReadFront(); + int64_t copyNum = (remain < curBlockNum) ? remain : curBlockNum; + Collectives::CpGM2GMPingPong(copyNum * sizeof(T), pullGm, outputGt[count * curBlockNum], COPYONLY); + } + + sync.SetInnerFlag(magic, count); + remain = remain - curBlockNum; + count = count + 1; + } + } + + FORCE_INLINE_AICORE void BuildScaleOffset(GM_ADDR scale, int64_t scaleCount, GM_ADDR offset) + { + if (scale != nullptr && offset != nullptr) { + scaleGt.SetGlobalBuffer((__gm__ T*)scale); + this->firstScale = scaleGt.GetValue(0); + this->offset = *reinterpret_cast<__gm__ T*>(offset); + this->scaleNum = scaleCount < 1 ?
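+            // clamp to at least 1 so the x / scaleNum * scaleNum alignment math stays well defined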
1 : scaleCount; + isVectorScale = scaleCount > 1; + isEnableScale = scaleCount > 0 && !(*(uint16_t *)(&(this->offset)) == 0 && + scaleCount == 1 && *(uint16_t *)(&firstScale) == *(uint16_t *)(&oneCast)); + } + } + + FORCE_INLINE_AICORE void CpGM2GMWithScale(int64_t atomCopyNum, GlobalTensor inputGm, GlobalTensor outputGm, + int64_t atomOp) + { + if (!isEnableScale) { + Collectives::CpGM2GMPingPong(atomCopyNum * sizeof(T), inputGm, outputGm, atomOp); + } else if (!isVectorScale) { + CpGM2GMPingPong(atomCopyNum * sizeof(T), inputGm, outputGm, atomOp, firstScale, offset); + } else { + CpGM2GMPingPong(atomCopyNum * sizeof(T), inputGm, outputGm, atomOp, scaleGt, scaleNum, + offset); + } + } +private: + GlobalTensor inputGt; + GlobalTensor outputGt; + + int atomOp; + + int64_t perRankDataNum; + int64_t curRankDataNum; + int64_t peerRank; + int64_t pullRankDataNum; + int64_t inputBuffOffsetNum; + int64_t outputBuffOffsetNum; + int64_t inputIpcGtOffsetNum; + int64_t curBlockSize; + int64_t perStepBlockNum; + int64_t curBlockNum; + + IpcQueue inputQue; + IpcQueue srcQue; + IpcQueue dstQue; + IpcQueue pullQue; + + int rankList[LCAL_MAX_RANK_SIZE]; + int coreIdxList[LCAL_MAX_RANK_SIZE]; + + GlobalTensor scaleGt; + int64_t scaleNum = 1; + T firstScale = 1; + T offset = 0; + bool isEnableScale = false; + bool isVectorScale = false; +}; + +#endif // LCCL_ALLREDUCE_BIG_DATA_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/allreduce_one_shot.h b/comm/lcal/src/ascendc_kernels/allreduce_one_shot.h new file mode 100644 index 0000000000000000000000000000000000000000..9532a04f62ba3dc17bdf1d023e780fa748cac9fe --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/allreduce_one_shot.h @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef LCCL_ALLREDUCE_ONE_SHOT_H +#define LCCL_ALLREDUCE_ONE_SHOT_H + +#include "sync_collectives.h" +#include "allreduce_quant.h" + +using namespace AscendC; +template +class AllReduceOneShot : protected AllReduceQuant { + constexpr static T oneCast = (T) 1; + +public: + FORCE_INLINE_AICORE AllReduceOneShot(int rank, int rankSize, uint32_t extraFlag) + : AllReduceQuant(rank, rankSize, extraFlag) {} + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + if constexpr(!std::is_same_v) { + BuildScaleOffset(scale, scaleCount, offset); + } + atomOp = op; + blockNum = blockNum / rankSize * rankSize; + if (blockIdx >= blockNum) { + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + return; + } + + corePerRank = blockNum / rankSize; + rankIDOfBlock = blockIdx / corePerRank; + + dataDMAPerCore = len / rankSize / corePerRank / scaleNum * scaleNum; + dataReducePerCore = len / corePerRank / scaleNum * scaleNum; + + blockDataNum = dataDMAPerCore; + if (blockIdx == rankSize * corePerRank - 1) { + blockDataNum = len - blockIdx * dataDMAPerCore; + } + + blockReduceNum = dataReducePerCore; + if (blockIdx % corePerRank == corePerRank - 1) { + blockReduceNum = len - blockIdx % corePerRank * dataReducePerCore; + } + + __gm__ U* curRankGm = (__gm__ U*)shareAddrs[rank] + IPC_DATA_OFFSET / sizeof(U); + __gm__ U* peerRankGm = (__gm__ U*)shareAddrs[rankIDOfBlock] + IPC_DATA_OFFSET / sizeof(U); + __gm__ U* intputGm = (__gm__ U*)input; + __gm__ T* outputGm = (__gm__ T*)output; + + srcInputGlobal.SetGlobalBuffer(intputGm + blockIdx * dataDMAPerCore); + dstIPCGlobal.SetGlobalBuffer(curRankGm + blockIdx * dataDMAPerCore); + copyOutputGlobal.SetGlobalBuffer(outputGm + blockIdx * dataDMAPerCore); + srcIPCGlobal.SetGlobalBuffer(peerRankGm + blockIdx % corePerRank * dataReducePerCore); + dstOutputGlobal.SetGlobalBuffer(outputGm + blockIdx % corePerRank * dataReducePerCore); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + } + + FORCE_INLINE_AICORE void Process() + { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + if (blockIdx >= blockNum) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + CpInputToBuffAndOutput(); + sync.SetInnerFlag(magic, 1); + + sync.WaitRankInnerFlag(magic, 1, rank); + sync.WaitRankInnerFlag(magic, 1, rankIDOfBlock); + if (rankIDOfBlock != rank) { + if constexpr (!std::is_same_v) { + if (!isEnableScale) { + Collectives::CpGM2GM(dstOutputGlobal, srcIPCGlobal, blockReduceNum, atomOp); + } else if (!isVectorScale) { + CpGM2GM(dstOutputGlobal, srcIPCGlobal, blockReduceNum, atomOp, firstScale, offset); + } else { + CpGM2GM(dstOutputGlobal, srcIPCGlobal, blockReduceNum, atomOp, scaleGt, scaleNum, offset); + } + } else { + Collectives::CpGM2GM(dstOutputGlobal, srcIPCGlobal, blockReduceNum, atomOp); + } + } + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + } + + FORCE_INLINE_AICORE void CpInputToBuffAndOutput() + { + Collectives::CpGM2GM(dstIPCGlobal, srcInputGlobal, blockDataNum, COPYONLY); + if constexpr (!std::is_same_v) { + if (!isEnableScale) { + Collectives::CpGM2GM(copyOutputGlobal, srcInputGlobal, blockDataNum, COPYONLY); + } else if (!isVectorScale) { + CpGM2GM(copyOutputGlobal, srcInputGlobal, blockDataNum, COPYONLY, firstScale, offset); + } else { + CpGM2GM(copyOutputGlobal, srcInputGlobal, blockDataNum, COPYONLY, scaleGt, scaleNum, offset); + } + } else { + Collectives::CpGM2GM(copyOutputGlobal, srcInputGlobal, blockDataNum, -1); + } + } + +protected: + 
GlobalTensor srcInputGlobal; + GlobalTensor srcIPCGlobal; + GlobalTensor dstIPCGlobal; + GlobalTensor dstOutputGlobal; + GlobalTensor copyOutputGlobal; + + int rankIDOfBlock; + int corePerRank; + int dataDMAPerCore; + int dataReducePerCore; + int blockDataNum; + int blockReduceNum; + int atomOp; + GlobalTensor scaleGt; + int64_t scaleNum = 1; + T firstScale = 1; + T offset = 0; + bool isEnableScale = false; + bool isVectorScale = false; + +private: + FORCE_INLINE_AICORE void BuildScaleOffset(GM_ADDR scale, int64_t scaleCount, GM_ADDR offset) + { + if (scale != nullptr && offset != nullptr) { + this->offset = *reinterpret_cast<__gm__ T*>(offset); + scaleGt.SetGlobalBuffer((__gm__ T*)scale); + this->firstScale = scaleGt.GetValue(0); + this->scaleNum = scaleCount < 1 ? 1 : scaleCount; + isVectorScale = scaleCount > 1; + isEnableScale = scaleCount > 0 && !(*(uint16_t *)(&(this->offset)) == 0 && + scaleCount == 1 && *(uint16_t *)(&firstScale) == *(uint16_t *)(&oneCast)); + } + } +}; + +#endif // LCCL_ALLREDUCE_ONE_SHOT_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/allreduce_quant.h b/comm/lcal/src/ascendc_kernels/allreduce_quant.h new file mode 100644 index 0000000000000000000000000000000000000000..aa749abc81bfda31f9e6048c05b6f259c5663d75 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/allreduce_quant.h @@ -0,0 +1,217 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef LCCL_ALLREDUCE_QUANT_H +#define LCCL_ALLREDUCE_QUANT_H +#include "collectives.h" +using namespace AscendC; + +class AllReduceQuant : protected Collectives { + constexpr static int32_t UB_HEAD_OFFSET = 96; + constexpr static int32_t UB_MID_OFFSET = UB_HEAD_OFFSET + UB_SINGLE_PING_PONG_ADD_SIZE_MAX + ALIGN_SIZE; +public: + FORCE_INLINE_AICORE AllReduceQuant(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + + template + FORCE_INLINE_AICORE void CpGM2GM(const GlobalTensor& outputGT, const GlobalTensor& inputGT, + const uint32_t calCount, int op, T scale, T offset) + { + DataCopyGM2GM cpKernel; + cpKernel.Init(outputGT, inputGT, calCount, op); + cpKernel.Process(scale, offset); + } + + template + FORCE_INLINE_AICORE void CpGM2GM(const GlobalTensor& outputGT, const GlobalTensor& inputGT, + const uint32_t calCount, int op, const GlobalTensor& scaleGT, int64_t scaleCount, T offset) + { + DataCopyGM2GM cpKernel; + cpKernel.Init(outputGT, inputGT, calCount, op); + cpKernel.Process(scaleGT, scaleCount, offset); + } + + template + FORCE_INLINE_AICORE void CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor& inputGT, + const GlobalTensor& outputGT, int op, T scale, T offset) + { + constexpr int32_t ubBlockSize = UB_SINGLE_PING_PONG_ADD_SIZE_MAX; + constexpr int32_t ubAlignNum = ubBlockSize / (sizeof(T) + sizeof(U)) / ALIGN_SIZE * ALIGN_SIZE; + constexpr int32_t inputUbBlockSize = std::is_same_v ? ubBlockSize : ubAlignNum * sizeof(U); + constexpr int32_t outputUbBlockSize = std::is_same_v ?
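+        // identical in/out types use the whole UB block; otherwise it is split between U-typed input and T-typed output halves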
ubBlockSize : ubAlignNum * sizeof(T); + __gm__ U *input = const_cast<__gm__ U *>(inputGT.GetPhyAddr()); + __gm__ T *output = const_cast<__gm__ T *>(outputGT.GetPhyAddr()); + __ubuf__ U* inputUB[2] = {(__ubuf__ U*)(UB_HEAD_OFFSET), (__ubuf__ U*)(UB_MID_OFFSET)}; + __ubuf__ T* outputUB[2] = {(__ubuf__ T*)(inputUB[0] + inputUbBlockSize / sizeof(U)), + (__ubuf__ T*)(inputUB[1] + inputUbBlockSize / sizeof(U))}; + __ubuf__ T* targetOutputUB = nullptr; + int inputOffsetNum = 0; + int outputOffsetNum = 0; + + SetAtomic(op); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; dataSizeRemain > 0; i++) { + uint32_t size = dataSizeRemain > outputUbBlockSize ? outputUbBlockSize : dataSizeRemain; + event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + targetOutputUB = (i & 1) ? outputUB[0] : outputUB[1]; + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], input + inputOffsetNum, size / sizeof(T) * sizeof(U)); + SetWaitEvent(eventId); + CastImpl(targetOutputUB, (i & 1) ? inputUB[0] : inputUB[1], RoundMode::CAST_NONE, size / sizeof(T)); + PipeBarrier(); + AddsImpl(targetOutputUB, targetOutputUB, offset, size / sizeof(T)); + PipeBarrier(); + MulsImpl(targetOutputUB, targetOutputUB, scale, size / sizeof(T)); + SetWaitEvent(eventId); + SetWaitEvent(eventId); + CpUB2GM(output + outputOffsetNum, targetOutputUB, size); + AscendC::SetFlag(eventId); + + dataSizeRemain -= size; + inputOffsetNum += (size / sizeof(T)); + outputOffsetNum += (size / sizeof(T)); + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + + SetWaitEvent(EVENT_ID3); + UnsetAtomic(op); + return; + } + + template + FORCE_INLINE_AICORE void CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor& inputGT, + const GlobalTensor& outputGT, int op, const GlobalTensor& scaleGT, int64_t scaleCount, T offset) + { + constexpr int32_t ubSplitSize = sizeof(T) + sizeof(U) + sizeof(T) + sizeof(U) + sizeof(T); + constexpr int64_t ubAlignNum = UB_SINGLE_DMA_SIZE_MAX / ubSplitSize / ALIGN_SIZE * ALIGN_SIZE; + __gm__ T *scale = const_cast<__gm__ T *>(scaleGT.GetPhyAddr()); + __gm__ U *input = const_cast<__gm__ U *>(inputGT.GetPhyAddr()); + __gm__ T *output = const_cast<__gm__ T *>(outputGT.GetPhyAddr()); + if (scaleCount > ubAlignNum) { + CpGM2GMPingPongForBigScale(dataSizeRemain, input, output, op, scale, scaleCount, offset); + } else { + CpGM2GMPingPongForSmallScale(dataSizeRemain, input, output, op, scale, scaleCount, offset); + } + return; + } + +protected: + template + FORCE_INLINE_AICORE void CpGM2GMPingPongForBigScale(int64_t dataSizeRemain, __gm__ U *input, + __gm__ T *output, int op, __gm__ T *scale, int64_t scaleCount, T offset) + { + constexpr int64_t mulVal = 2; + constexpr int64_t ubSplitSize = (sizeof(T) + sizeof(U) + sizeof(T)) * mulVal; + constexpr int64_t ubAlignNum = UB_SINGLE_DMA_SIZE_MAX / ubSplitSize / ALIGN_SIZE * ALIGN_SIZE; + const int64_t batchDataNum = (scaleCount + ubAlignNum - 1) / ubAlignNum; + + __ubuf__ T* scaleUB[2] = {(__ubuf__ T*)(UB_HEAD_OFFSET), (__ubuf__ T*)(UB_MID_OFFSET)}; + __ubuf__ U* inputUB[2] = {(__ubuf__ U*)(UB_HEAD_OFFSET + ubAlignNum * sizeof(T)), + (__ubuf__ U*)(UB_MID_OFFSET + ubAlignNum * sizeof(T))}; + __ubuf__ T* outputUB[2] = {(__ubuf__ T*)(UB_HEAD_OFFSET + ubAlignNum * (sizeof(T) + sizeof(U))), + (__ubuf__ T*)(UB_MID_OFFSET + ubAlignNum * (sizeof(T) + sizeof(U)))}; + __ubuf__ T* targetOutputUB = nullptr; + int64_t i = 0; + int32_t curDataNum = 0; + int32_t processedNum = 0; + + SetAtomic(op); + + AscendC::SetFlag(EVENT_ID0); + 
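+        // prime both ping-pong events up front so the first two WaitFlag calls in the loop do not block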
AscendC::SetFlag(EVENT_ID1); + while (dataSizeRemain > 0) { + if (i % batchDataNum == batchDataNum - 1) { + curDataNum = scaleCount - i % batchDataNum * ubAlignNum; + } else { + curDataNum = ubAlignNum; + } + event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + targetOutputUB = (i & 1) ? outputUB[0] : outputUB[1]; + + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], input + processedNum, curDataNum * sizeof(U)); + SetWaitEvent(eventId); + CpGM2UB((i & 1) ? scaleUB[0] : scaleUB[1], scale + i % batchDataNum * ubAlignNum, curDataNum * sizeof(T)); + CastImpl(targetOutputUB, (i & 1) ? inputUB[0] : inputUB[1], RoundMode::CAST_NONE, curDataNum); + SetWaitEvent(eventId); + AddsImpl(targetOutputUB, targetOutputUB, offset, curDataNum); + PipeBarrier(); + MulImpl(targetOutputUB, targetOutputUB, (i & 1) ? scaleUB[0] : scaleUB[1], curDataNum); + SetWaitEvent(eventId); + CpUB2GM(output + processedNum, targetOutputUB, curDataNum * sizeof(T)); + AscendC::SetFlag(eventId); + + dataSizeRemain -= curDataNum * sizeof(T); + processedNum += curDataNum; + ++i; + } + + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + SetWaitEvent(EVENT_ID3); + UnsetAtomic(op); + return; + } + + template + FORCE_INLINE_AICORE void CpGM2GMPingPongForSmallScale(int64_t dataSizeRemain, __gm__ U *input, + __gm__ T *output, int op, __gm__ T *scale, int64_t scaleCount, T offset) + { + constexpr int32_t ubSplitSize = sizeof(T) + sizeof(U) + sizeof(T) + sizeof(U) + sizeof(T); + constexpr int64_t ubAlignNum = UB_SINGLE_DMA_SIZE_MAX / ubSplitSize / ALIGN_SIZE * ALIGN_SIZE; + const int64_t batchDataNum = ubAlignNum / scaleCount * scaleCount; + const int64_t ubMidOffset = ubAlignNum * (sizeof(T) + sizeof(U) + sizeof(T)) + UB_HEAD_OFFSET + ALIGN_SIZE; + + __ubuf__ T* scaleUB = (__ubuf__ T*)(UB_HEAD_OFFSET); + __ubuf__ U* inputUB[2] = {(__ubuf__ U*)(UB_HEAD_OFFSET + ubAlignNum * sizeof(T)), (__ubuf__ U*)(ubMidOffset)}; + __ubuf__ T* outputUB[2] = {(__ubuf__ T*)(UB_HEAD_OFFSET + ubAlignNum * (sizeof(T) + sizeof(U))), + (__ubuf__ T*)(ubMidOffset + ubAlignNum * sizeof(U))}; + __ubuf__ T* targetOutputUB = nullptr; + int64_t processedNum = 0; + SetAtomic(op); + CpGM2UB(scaleUB, scale, scaleCount * sizeof(T)); + SetWaitEvent(EVENT_ID1); + int64_t repeatTimes = batchDataNum / scaleCount; + int64_t mulVal = 2; + for (int64_t i = 1; i < repeatTimes; i *= mulVal) { + PipeBarrier(); + CopyUB2UB(scaleUB + i * scaleCount, scaleUB, (repeatTimes > i * mulVal ? i : repeatTimes - i) * scaleCount); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; dataSizeRemain > 0; i++) { + uint32_t size = dataSizeRemain > batchDataNum * sizeof(T) ? batchDataNum * sizeof(T) : dataSizeRemain; + event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + targetOutputUB = (i & 1) ? outputUB[0] : outputUB[1]; + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], input + processedNum, size / sizeof(T) * sizeof(U)); + SetWaitEvent(eventId); + CastImpl(targetOutputUB, (i & 1) ? 
inputUB[0] : inputUB[1], RoundMode::CAST_NONE, size / sizeof(T)); + PipeBarrier(); + AddsImpl(targetOutputUB, targetOutputUB, offset, size / sizeof(T)); + PipeBarrier(); + MulImpl(targetOutputUB, targetOutputUB, scaleUB, size / sizeof(T)); + SetWaitEvent(eventId); + CpUB2GM(output + processedNum, targetOutputUB, size); + AscendC::SetFlag(eventId); + dataSizeRemain -= size; + processedNum += (size / sizeof(T)); + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + SetWaitEvent(EVENT_ID3); + UnsetAtomic(op); + return; + } +}; + +#endif // LCCL_ALLREDUCE_QUANT_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/allreduce_two_shot.h b/comm/lcal/src/ascendc_kernels/allreduce_two_shot.h new file mode 100644 index 0000000000000000000000000000000000000000..3ef039603cc40ae7680ca59027aec2aa3d2c969f --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/allreduce_two_shot.h @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LCCL_ALLREDUCE_TWO_SHOT_H +#define LCCL_ALLREDUCE_TWO_SHOT_H + +#include "allreduce_quant.h" +#include "sync_collectives.h" +using namespace AscendC; +template +class AllReduceTwoShot : protected AllReduceQuant { + constexpr static int QUEUE_DEPTH = 4; + constexpr static T oneCast = (T) 1; + +public: + FORCE_INLINE_AICORE AllReduceTwoShot(int rank, int rankSize, uint32_t extraFlag) + : AllReduceQuant(rank, rankSize, extraFlag) {} + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + if constexpr(!std::is_same_v) { + BuildScaleOffset(scale, scaleCount, offset); + } + + if (blockIdx >= rankSize) { + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + return; + } + + blockNum = rankSize; + + __gm__ CommArgs *localArgs = reinterpret_cast<__gm__ CommArgs *>(commArgs); + + int localRankSize = localArgs->localRankSize <= 0 ? rankSize : localArgs->localRankSize; + int globalRankSize = localArgs->rankSize <= 0 ? rankSize : localArgs->rankSize; + int serverNum = globalRankSize / localRankSize; + int64_t ipcBuffMaxSizeAligned = IPC_BUFF_MAX_SIZE / (globalRankSize + serverNum - 1) / + QUEUE_DEPTH / sizeof(T) /scaleNum * scaleNum * QUEUE_DEPTH * sizeof(T) * globalRankSize; + ipcDataPerParagraphSize = ipcBuffMaxSizeAligned / localRankSize; + int64_t ipcDataPerParagraphNum = ipcDataPerParagraphSize / sizeof(T); + atomOp = op; + corePerRank = blockNum / rankSize; + coreSegmentedIdx = blockIdx % corePerRank; + peerRank = blockIdx / corePerRank; + perRankDataNum = GetDataCount(len, rankSize) / scaleNum * scaleNum; + curRankDataNum = (rank == rankSize - 1) ? 
(len - rank * perRankDataNum) : perRankDataNum; + pullRankDataNum = perRankDataNum; + if (peerRank == rankSize - 1) { + pullRankDataNum = len - peerRank * perRankDataNum; + } + pullBlockDataNum = GetDataCount(pullRankDataNum, corePerRank); + dataNumPreBlock = pullBlockDataNum; + if (coreSegmentedIdx == corePerRank - 1) { + dataNumPreBlock = pullRankDataNum - coreSegmentedIdx * pullBlockDataNum; + } + buffOffsetNum = peerRank * perRankDataNum + coreSegmentedIdx * pullBlockDataNum + + ipcDataPerParagraphNum * peerRank; + + curBlockDataNum = GetDataCount(curRankDataNum, corePerRank); + ipcDataNumPreBlock = curBlockDataNum; + ipcbuffOffsetNum = rank * perRankDataNum + coreSegmentedIdx * curBlockDataNum + ipcDataPerParagraphNum * rank; + + inputGt.SetGlobalBuffer((__gm__ U*)input + buffOffsetNum - ipcDataPerParagraphNum * peerRank, dataNumPreBlock); + inputIpcGt.SetGlobalBuffer((__gm__ T*)(shareAddrs[rank] + IPC_DATA_OFFSET) + buffOffsetNum, dataNumPreBlock); + srcIpcGt.SetGlobalBuffer((__gm__ T*)(shareAddrs[peerRank] + IPC_DATA_OFFSET) + ipcbuffOffsetNum, + ipcDataNumPreBlock); + processIpcGt.SetGlobalBuffer((__gm__ T*)(shareAddrs[rank] + IPC_DATA_OFFSET) + ipcbuffOffsetNum, + ipcDataNumPreBlock); + + processedIpcGt.SetGlobalBuffer((__gm__ T*)(shareAddrs[peerRank] + IPC_DATA_OFFSET) + buffOffsetNum, + dataNumPreBlock); + outputGt.SetGlobalBuffer((__gm__ T*)output + buffOffsetNum - ipcDataPerParagraphNum * peerRank, + dataNumPreBlock); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + } + + FORCE_INLINE_AICORE void Process() + { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + if (blockIdx >= rankSize) { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + return; + } + if constexpr (std::is_same_v) { + Collectives::CpGM2GM(inputIpcGt, inputGt, dataNumPreBlock, COPYONLY); + } else { + if (peerRank == rank) { + if (!isEnableScale) { + Collectives::CpGM2GM(inputIpcGt, inputGt, dataNumPreBlock, COPYONLY); + } else if (!isVectorScale) { + CpGM2GM(inputIpcGt, inputGt, dataNumPreBlock, COPYONLY, firstScale, offset); + } else { + CpGM2GM(inputIpcGt, inputGt, dataNumPreBlock, COPYONLY, scaleGt, scaleNum, offset); + } + } else { + GlobalTensor inputIpcGtTmp; + inputIpcGtTmp.SetGlobalBuffer((__gm__ U*)inputIpcGt.GetPhyAddr()); + Collectives::CpGM2GM(inputIpcGtTmp, inputGt, dataNumPreBlock, COPYONLY); + } + } + sync.SetInnerFlag(magic, 1); + + sync.WaitInnerFlag(magic, 1, rank, coreSegmentedIdx + rank * corePerRank); + sync.WaitInnerFlag(magic, 1, peerRank, coreSegmentedIdx + rank * corePerRank); + if (peerRank != rank) { + if constexpr (std::is_same_v) { + Collectives::CpGM2GM(processIpcGt, srcIpcGt, ipcDataNumPreBlock, atomOp); + } else { + GlobalTensor srcIpcGtTmp; + srcIpcGtTmp.SetGlobalBuffer((__gm__ U*)srcIpcGt.GetPhyAddr()); + if (!isEnableScale) { + Collectives::CpGM2GM(processIpcGt, srcIpcGtTmp, ipcDataNumPreBlock, atomOp); + } else if (!isVectorScale) { + CpGM2GM(processIpcGt, srcIpcGtTmp, ipcDataNumPreBlock, atomOp, firstScale, offset); + } else { + CpGM2GM(processIpcGt, srcIpcGtTmp, ipcDataNumPreBlock, atomOp, scaleGt, scaleNum, offset); + } + } + } + + if (!(extraFlag & ExtraFlag::RDMA)) { + sync.SetOuterFlag(magic, 1); + sync.WaitOneRankOuterFlag(magic, 1, peerRank); + Collectives::CpGM2GM(outputGt, processedIpcGt, dataNumPreBlock, COPYONLY); + } + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + } + +private: + GlobalTensor inputGt; + GlobalTensor outputGt; + GlobalTensor inputIpcGt; + GlobalTensor srcIpcGt; + GlobalTensor processedIpcGt; + GlobalTensor processIpcGt; + 
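+    // Added explanatory sketch (not in the original source): Init() above slices
+    // the payload as perRankDataNum = (len / rankSize) / scaleNum * scaleNum, with
+    // the last rank absorbing the remainder. For example, len = 1000, rankSize = 4
+    // and scaleNum = 3 give perRankDataNum = 249 and 1000 - 3 * 249 = 253 elements
+    // on the last rank; each slice is then split over corePerRank cores the same way.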
+    int atomOp;
+
+    int64_t corePerRank;
+    int64_t coreSegmentedIdx;
+    int64_t ipcDataPerParagraphSize;
+    int64_t perRankDataNum;
+    int64_t curRankDataNum;
+    int64_t pullBlockDataNum;
+    int64_t curBlockDataNum;
+    int64_t peerRank;
+    int64_t pullRankDataNum;
+    int64_t dataNumPreBlock;
+    int64_t buffOffsetNum;
+    int64_t ipcDataNumPreBlock;
+    int64_t ipcbuffOffsetNum;
+
+    GlobalTensor scaleGt;
+    int64_t scaleNum = 1;
+    T firstScale = 1;
+    T offset = 0;
+    bool isEnableScale = false;
+    bool isVectorScale = false;
+    FORCE_INLINE_AICORE void BuildScaleOffset(GM_ADDR scale, int64_t scaleCount, GM_ADDR offset)
+    {
+        if (scale != nullptr && offset != nullptr) {
+            scaleGt.SetGlobalBuffer((__gm__ T*)scale);
+            this->firstScale = scaleGt.GetValue(0);
+            this->scaleNum = scaleCount < 1 ? 1 : scaleCount;
+            this->offset = *reinterpret_cast<__gm__ T*>(offset);
+            isVectorScale = scaleCount > 1;
+            isEnableScale = scaleCount > 0 && !(*(uint16_t *)(&(this->offset)) == 0 &&
+                scaleCount == 1 && *(uint16_t *)(&firstScale) == *(uint16_t *)(&oneCast));
+        }
+    }
+};
+
+#endif // LCCL_ALLREDUCE_TWO_SHOT_H
\ No newline at end of file
diff --git a/comm/lcal/src/ascendc_kernels/collectives.h b/comm/lcal/src/ascendc_kernels/collectives.h
new file mode 100644
index 0000000000000000000000000000000000000000..9743a2dcd7273cfd74893a327ea0567399aba6cb
--- /dev/null
+++ b/comm/lcal/src/ascendc_kernels/collectives.h
@@ -0,0 +1,502 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#ifndef LCCL_COLLECTIVES_H +#define LCCL_COLLECTIVES_H + +#include + +#include "datacopy_gm2gm.h" +#include "datacopy_gm2gm_delay.h" +#include "sync_collectives.h" +using namespace AscendC; +using namespace Lcal; + +#define KERNELS_ARGS_FUN() \ +GM_ADDR input, GM_ADDR output, GM_ADDR commArgs, int64_t len, int64_t magic, int op, int root, int cycleCount, \ +GM_ADDR scale, int64_t scaleCount, GM_ADDR offset + +#define KERNELS_ARGS_CALL() \ +input, output, commArgs, len, magic, op, root, cycleCount, scale, scaleCount, offset + +#define KERNELS_GATHER_TABLE_ARGS_FUN() \ +GM_ADDR embTable, GM_ADDR lookup, GM_ADDR revData, int64_t lookupLen, int64_t embTableLen, int64_t embTableDim + +#define KERNELS_GATHER_TABLE_ARGS_CALL() \ +embTable, lookup, revData, lookupLen, embTableLen, embTableDim + +enum DfxPos : int { + MAGIC, + LEN, + RUN_STATUS +}; + +class Collectives { + constexpr static int32_t UB_HEAD_OFFSET = 96; + constexpr static int32_t UB_MID_OFFSET = UB_HEAD_OFFSET + UB_SINGLE_PING_PONG_ADD_SIZE_MAX + ALIGN_SIZE; +public: + FORCE_INLINE_AICORE Collectives(int rank, int rankSize, uint32_t extraFlag) : rank(rank), rankSize(rankSize), + extraFlag(extraFlag) {} + + FORCE_INLINE_AICORE ~Collectives() + { + const int64_t notRunning = 0xdead; + dfx.SetValue(RUN_STATUS, notRunning); + } + + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + dumpAddr_ = (reinterpret_cast<__gm__ CommArgs *>(commArgs))->dumpAddr; + GlobalTensor peerMemsAddrGm; + peerMemsAddrGm.SetGlobalBuffer(&(reinterpret_cast<__gm__ CommArgs *>(commArgs))->peerMems[0], + LCAL_MAX_RANK_SIZE); + for (int i = 0; i < rankSize; ++i) { + shareAddrs[i] = peerMemsAddrGm.GetValue(i) + + (magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET); + } + dfx.SetGlobalBuffer((reinterpret_cast<__gm__ CommArgs *>(commArgs))->dfx, + DFX_COUNT); + this->root = root; + this->len = len; + this->magic = magic; + this->localRank = reinterpret_cast<__gm__ CommArgs *>(commArgs)->localRank; + this->localRankSize = reinterpret_cast<__gm__ CommArgs *>(commArgs)->localRankSize; + this->xRankSize = localRankSize; + this->yRankSize = rankSize / localRankSize; + this->xRankIdx = rank % localRankSize; + this->yRankIdx = rank / localRankSize; + + blockIdx = GetBlockIdx(); + blockNum = GetBlockNum() * LCAL_BLOCK_NUM_MULTI; + + sync.Init(rank, rankSize, shareAddrs); + dfx.SetValue(MAGIC, magic); + dfx.SetValue(LEN, len); + const int64_t running = 0xbeef; + dfx.SetValue(RUN_STATUS, running); + } + + template + FORCE_INLINE_AICORE void DataCopyWrapPingPong(const GlobalTensor& inputGT, const GlobalTensor& outputGT, + int64_t dataSizeRemain, int op, TBuf tbuf) + { + if (dataSizeRemain <= 0) { + return; + } + LocalTensor localUB[2]; + localUB[0] = tbuf.GetWithOffset(UB_SINGLE_PING_PONG_ADD_SIZE_MAX, 0); + localUB[1] = tbuf.GetWithOffset(UB_SINGLE_PING_PONG_ADD_SIZE_MAX, UB_SINGLE_PING_PONG_ADD_SIZE_MAX); + + int inputOffsetNum = 0; + int outputOffsetNum = 0; + + PipeBarrier(); + if (op != COPYONLY) { + SetAscendCAtomic(op); + } + PipeBarrier(); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; dataSizeRemain > 0; i++) { + uint32_t size = dataSizeRemain > UB_SINGLE_PING_PONG_ADD_SIZE_MAX ? + UB_SINGLE_PING_PONG_ADD_SIZE_MAX : dataSizeRemain; + TEventID eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + AscendC::WaitFlag(eventId); + DataCopyWrap(localUB[(i & 1) ? 0 : 1], inputGT[inputOffsetNum], size); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + DataCopyWrap(outputGT[outputOffsetNum], localUB[(i & 1) ? 
0 : 1], size); + AscendC::SetFlag(eventId); + dataSizeRemain -= size; + inputOffsetNum += (size / sizeof(T)); + outputOffsetNum += (size / sizeof(T)); + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + + if (op != COPYONLY) { + SetAtomicNone(); + } + PipeBarrier(); + } + + template + FORCE_INLINE_AICORE void CpGM2GM(const GlobalTensor& outputGT, const GlobalTensor& inputGT, + const uint32_t calCount, int op) + { + DataCopyGM2GM cpKernel; + cpKernel.Init(outputGT, inputGT, calCount, op); + cpKernel.Process(); + } + + template + FORCE_INLINE_AICORE void CpGM2GMDelay(GlobalTensor& outputGT, GlobalTensor (&inputGT)[8], + GlobalTensor (&inputScaleGT)[8], const uint32_t calCount, int rankCount, GlobalTensor& outScaleGT, + TBuf tbuf) + { + DataCopyGM2GMDelay cpKernel; + cpKernel.Init(outputGT, inputGT, inputScaleGT, calCount, rankCount, outScaleGT, tbuf); + cpKernel.Process(); + } + + template + FORCE_INLINE_AICORE T1 CeilDiv(T1 a, T2 b) + { + if (b == 0) { + return 0; + } + return (a + b - 1) / b; + } + + template + FORCE_INLINE_AICORE void VecAddCce(int64_t curDealSize, __ubuf__ T *ubuf0, __ubuf__ T *ubuf1) + { + if (curDealSize > MAX_VADD_SIZE) { + vadd(ubuf0, ubuf1, ubuf0, VADD_MAX_REPEAT, 1, 1, 1, + VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO); + vadd((__ubuf__ T*)((__ubuf__ int8_t*)ubuf0 + VADD_MAX_REPEAT * VADD_UNIT_BYTE), + (__ubuf__ T*)((__ubuf__ int8_t*)ubuf1 + VADD_MAX_REPEAT * VADD_UNIT_BYTE), + (__ubuf__ T*)((__ubuf__ int8_t*)ubuf0 + VADD_MAX_REPEAT * VADD_UNIT_BYTE), + CeilDiv((curDealSize - MAX_VADD_SIZE), VADD_UNIT_BYTE), 1, 1, 1, + VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO); + } else { + vadd(ubuf0, ubuf1, ubuf0, CeilDiv(curDealSize, VADD_UNIT_BYTE), 1, 1, 1, + VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO); + } + } + + template + FORCE_INLINE_AICORE void LoopVaddCceProcess(__ubuf__ T* localUB[2], const int64_t remainSize, + int64_t (&targetRankArr)[8], const int64_t targetRankArrValidSize, const int64_t srcIpcOffsetNum, + __gm__ T *srcGmMem, __gm__ T *dstIpcMem, int64_t alreadyDealNum) + { + for (int64_t alreadyDealSize = 0; alreadyDealSize < remainSize; + alreadyDealSize += UB_SINGLE_PING_PONG_ADD_SIZE_MAX) { + int64_t curDealSize = UB_SINGLE_PING_PONG_ADD_SIZE_MAX; + if (remainSize - alreadyDealSize < UB_SINGLE_PING_PONG_ADD_SIZE_MAX) { + curDealSize = remainSize - alreadyDealSize; + } + if (alreadyDealSize != 0) { + AscendC::WaitFlag(EVENT_ID0); + } + CpGM2UB(localUB[0], srcGmMem + alreadyDealNum, curDealSize); + + for (int64_t i = 0; i < targetRankArrValidSize; i++) { + int64_t targetRank = targetRankArr[i]; + if (targetRank == rank) { + continue; + } + if (i > 0 && !((targetRankArr[0] == rank) && i == 1)) { + AscendC::WaitFlag(EVENT_ID1); + } + CpGM2UB(localUB[1], + (__gm__ T*)(shareAddrs[targetRank] + IPC_DATA_OFFSET) + srcIpcOffsetNum + alreadyDealNum, + curDealSize); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::SetFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID2); + VecAddCce(curDealSize, localUB[0], localUB[1]); + if (((i + 1) == targetRankArrValidSize)) { + continue; + } + if ((i + 1 == targetRankArrValidSize - 1) && (targetRankArr[i + 1] == rank)) { + continue; + } + AscendC::SetFlag(EVENT_ID1); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::SetFlag(EVENT_ID0); + 
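+            // Added explanatory sketch (not in the original source): each tile
+            // pass of this loop is effectively
+            //     acc = ownChunk;                   // MTE2 into localUB[0]
+            //     for each peer: acc += peerChunk;  // MTE2 into localUB[1] + vadd
+            //     dstIpcMem = acc;                  // MTE3 write-back just below
+            // with the SetFlag/WaitFlag pairs serializing UB buffer reuse between
+            // the copy and vector pipelines.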
AscendC::WaitFlag(EVENT_ID0); + CpUB2GM((__gm__ T*)dstIpcMem + alreadyDealNum, localUB[0], curDealSize); + if (alreadyDealSize + UB_SINGLE_PING_PONG_ADD_SIZE_MAX < remainSize) { + AscendC::SetFlag(EVENT_ID0); + } + alreadyDealNum += curDealSize / sizeof(T); + } + } + + template + FORCE_INLINE_AICORE void LoopVaddCce(__ubuf__ T* localUB[2], const int64_t remainNum, int64_t (&targetRankArr)[8], + int64_t targetRankArrValidSize, int64_t srcIpcOffsetNum, __gm__ T *srcGmMem, __gm__ T *dstIpcMem) + { + AscendC::PipeBarrier(); + LoopVaddCceProcess(localUB, remainNum * (int64_t)sizeof(T), targetRankArr, targetRankArrValidSize, + srcIpcOffsetNum, srcGmMem, dstIpcMem, 0); + AscendC::PipeBarrier(); + } + + template + FORCE_INLINE_AICORE void CpGM2GMPingPong(int64_t dataSizeRemain, const GlobalTensor& inputGT, + const GlobalTensor& outputGT, int op) + { + constexpr int32_t ubBlockSize = UB_SINGLE_PING_PONG_ADD_SIZE_MAX; + constexpr int32_t ubAlignNum = ubBlockSize / (sizeof(T) + sizeof(U)) / ALIGN_SIZE * ALIGN_SIZE; + constexpr int32_t inputUbBlockSize = std::is_same_v ? ubBlockSize : ubAlignNum * sizeof(U); + constexpr int32_t outputUbBlockSize = std::is_same_v ? ubBlockSize : ubAlignNum * sizeof(T); + + __gm__ U *input = const_cast<__gm__ U *>(inputGT.GetPhyAddr()); + __gm__ T *output = const_cast<__gm__ T *>(outputGT.GetPhyAddr()); + __ubuf__ U* inputUB[2] = {(__ubuf__ U*)(UB_HEAD_OFFSET), (__ubuf__ U*)(UB_MID_OFFSET)}; + __ubuf__ T* outputUB[2] = {(__ubuf__ T*)inputUB[0], (__ubuf__ T*)inputUB[1]}; + if constexpr (!std::is_same_v) { + outputUB[0] = (__ubuf__ T*)(inputUB[0] + inputUbBlockSize / sizeof(U)); + outputUB[1] = (__ubuf__ T*)(inputUB[1] + inputUbBlockSize / sizeof(U)); + } + int inputOffsetNum = 0; + int outputOffsetNum = 0; + if (dataSizeRemain <= 0) { + return; + } + + SetAtomic(op); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; dataSizeRemain > 0; i++) { + uint32_t size = dataSizeRemain > outputUbBlockSize ? outputUbBlockSize : dataSizeRemain; + event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], input + inputOffsetNum, size / sizeof(T) * sizeof(U)); + if constexpr (!std::is_same_v) { + SetWaitEvent(eventId); + CastImpl((i & 1) ? outputUB[0] : outputUB[1], (i & 1) ? inputUB[0] : inputUB[1], RoundMode::CAST_NONE, + size / sizeof(T)); + SetWaitEvent(eventId); + } + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + CpUB2GM(output + outputOffsetNum, (i & 1) ? 
outputUB[0] : outputUB[1], size); + AscendC::SetFlag(eventId); + dataSizeRemain -= size; + inputOffsetNum += (size / sizeof(T)); + outputOffsetNum += (size / sizeof(T)); + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + UnsetAtomic(op); + return; + } + + template + FORCE_INLINE_AICORE void VecAdd(int64_t curDealSize, LocalTensor &ubuf0, LocalTensor &ubuf1) + { + if (curDealSize > MAX_VADD_SIZE) { + Add(ubuf0, ubuf1, ubuf0, MASK_PLACEHOLDER, VADD_MAX_REPEAT, + {1, 1, 1, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO}); + + Add(ubuf0[MAX_VADD_SIZE / sizeof(T)], ubuf1[MAX_VADD_SIZE / sizeof(T)], + ubuf0[MAX_VADD_SIZE / sizeof(T)], MASK_PLACEHOLDER, + CeilDiv((curDealSize - MAX_VADD_SIZE), VADD_UNIT_BYTE), + {1, 1, 1, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO}); + } else { + Add(ubuf0, ubuf1, ubuf0, MASK_PLACEHOLDER, CeilDiv(curDealSize, VADD_UNIT_BYTE), + {1, 1, 1, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO, VADD_UNIT_TO_BLOCK_UNIT_RATIO}); + } + } + + template + FORCE_INLINE_AICORE void LoopVadd(TBuf tbuf, int64_t &remainNum, int64_t (&targetRankArr)[8], + int64_t targetRankArrValidSize, int64_t srcIpcOffsetNum, const GlobalTensor &srcGt, + const GlobalTensor &dstGt) + { + if (remainNum <= 0) { + return; + } + LocalTensor localUB[2]; + localUB[0] = tbuf.GetWithOffset(95 * 1024, 0); + localUB[1] = tbuf.GetWithOffset(95 * 1024, 95 * 1024); + + AscendC::PipeBarrier(); + LoopVaddProcess(localUB, remainNum * sizeof(T), targetRankArr, targetRankArrValidSize, + srcIpcOffsetNum, srcGt, dstGt, 0); + AscendC::PipeBarrier(); + } + template + FORCE_INLINE_AICORE void LoopVaddProcess(LocalTensor (&localUB)[2], const int64_t remainSize, + int64_t (&targetRankArr)[8], const int64_t targetRankArrValidSize, const int64_t srcIpcOffsetNum, + const GlobalTensor &srcGt, const GlobalTensor &dstGt, int64_t alreadyDealNum) + { + for (int64_t alreadyDealSize = 0; alreadyDealSize < remainSize; + alreadyDealSize += UB_SINGLE_PING_PONG_ADD_SIZE_MAX) { + int64_t curDealSize = UB_SINGLE_PING_PONG_ADD_SIZE_MAX; + if (remainSize - alreadyDealSize < UB_SINGLE_PING_PONG_ADD_SIZE_MAX) { + curDealSize = remainSize - alreadyDealSize; + } + if (alreadyDealSize != 0) { + AscendC::WaitFlag(EVENT_ID0); + } + DataCopyWrap(localUB[0], srcGt[alreadyDealNum], curDealSize); + + for (int64_t i = 0; i < targetRankArrValidSize; i++) { + int64_t targetRank = targetRankArr[i]; + if (targetRank == rank) { + continue; + } + if (i > 0 && !((targetRankArr[0] == rank) && i == 1)) { + AscendC::WaitFlag(EVENT_ID1); + } + GlobalTensor srcGtTmp; + srcGtTmp.SetGlobalBuffer( + (__gm__ T*)(shareAddrs[targetRank] + IPC_DATA_OFFSET) + srcIpcOffsetNum + alreadyDealNum); + DataCopyWrap(localUB[1], srcGtTmp, curDealSize); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::SetFlag(EVENT_ID2); + AscendC::WaitFlag(EVENT_ID2); + VecAdd(curDealSize, localUB[0], localUB[1]); + if (((i + 1) == targetRankArrValidSize)) { + continue; + } + if ((i + 1 == targetRankArrValidSize - 1) && (targetRankArr[i + 1] == rank)) { + continue; + } + AscendC::SetFlag(EVENT_ID1); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + DataCopyWrap(dstGt[alreadyDealNum], localUB[0], curDealSize); + if (alreadyDealSize + UB_SINGLE_PING_PONG_ADD_SIZE_MAX < remainSize) { + 
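+                // Added note (not in the original source): the flag is re-armed
+                // only when another tile remains, so the last iteration does not
+                // leave a set EVENT_ID0 behind without a matching WaitFlag.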
AscendC::SetFlag(EVENT_ID0); + } + alreadyDealNum += curDealSize / sizeof(T); + } + } + + template + FORCE_INLINE_AICORE void SetSingleValue2Gm(GM_ADDR gm, T value) + { + AscendC::PipeBarrier(); + __ubuf__ T *inputUB = (__ubuf__ T *)(96); + *inputUB = value; + AscendC::PipeBarrier(); + CpUB2GM((__gm__ T *)gm, inputUB, sizeof(T)); + AscendC::PipeBarrier(); + } + +protected: + int rank; + int rankSize; + int localRank = 0; + int localRankSize = 0; + int xRankSize = 0; + int yRankSize = 0; + int xRankIdx = 0; + int yRankIdx = 0; + uint32_t extraFlag; + int root; + int64_t len; + int64_t magic; + int64_t blockIdx; + int64_t blockNum; + GM_ADDR shareAddrs[LCAL_MAX_RANK_SIZE]; + GlobalTensor dfx; + SyncCollectives sync; + GM_ADDR dumpAddr_ = nullptr; + + template + FORCE_INLINE_AICORE void SetAscendCAtomic(int op) + { + SetAtomicType(); + switch (op) { + case ADD: + SetAtomicAdd(); + return; + case MUL: + return; + case MAX: + SetAtomicMax(); + return; + case MIN: + SetAtomicMin(); + return; + default: + ; + } + } + + template + FORCE_INLINE_AICORE void SetAtomic(int op) + { + PipeBarrier(); + if (op != -1) { +#ifdef __DAV_C220_VEC__ + SetAtomicOpType(op); +#endif + } + PipeBarrier(); + } + + FORCE_INLINE_AICORE void UnsetAtomic(int op) + { + if (op != -1) { + AscendC::SetAtomicNone(); + } + PipeBarrier(); + } + + template + FORCE_INLINE_AICORE void SetWaitEvent(event_t eventId) + { + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + } + + FORCE_INLINE_AICORE void DumpLcclLogInfo(LogId logId, Op operationType) + { +#ifdef ENABLE_LCCL_DUMP + constexpr int32_t UB_HEAD_OFFSET = 96; + + AscendC::PipeBarrier(); + GM_ADDR blockGm = (GM_ADDR)(dumpAddr_ + LCCL_DUMP_UINT_SIZE * GetBlockIdx()); + __ubuf__ LcclDumpBlockInfo *blockUb = (__ubuf__ LcclDumpBlockInfo*)(UB_HEAD_OFFSET); + __ubuf__ LcclDumpLogInfo *logUb = (__ubuf__ LcclDumpLogInfo*)(UB_HEAD_OFFSET + sizeof(LcclDumpBlockInfo)); + + CpGM2UB((__ubuf__ uint8_t*)blockUb, blockGm, sizeof(LcclDumpBlockInfo)); + AscendC::PipeBarrier(); + + if (blockUb->dumpOffset < sizeof(LcclDumpLogInfo)) { + return; + } + + logUb->logId = logId; + logUb->blockId = GetBlockIdx(); + logUb->syscyc = static_cast(GetSystemCycle()); + logUb->curPc = static_cast(get_pc()); + logUb->operationType = operationType; + logUb->rsv = 0; + CpUB2GM((GM_ADDR)blockUb->dumpAddr, (__ubuf__ uint8_t*)logUb, sizeof(LcclDumpLogInfo)); + + blockUb->dumpAddr += sizeof(LcclDumpBlockInfo); + blockUb->dumpOffset -= sizeof(LcclDumpLogInfo); + CpUB2GM(blockGm, (__ubuf__ uint8_t*)blockUb, sizeof(LcclDumpBlockInfo)); + AscendC::PipeBarrier(); +#endif + } +}; + +FORCE_INLINE_AICORE int64_t GetDataCount(const int64_t dataLen, const int64_t useBlockNum) +{ + return dataLen / useBlockNum; +} +#endif // LCCL_COLLECTIVES_H diff --git a/comm/lcal/src/ascendc_kernels/datacopy_gm2gm.h b/comm/lcal/src/ascendc_kernels/datacopy_gm2gm.h new file mode 100644 index 0000000000000000000000000000000000000000..318f679d7b0d9c2bc188b0dd17a993d2845f1a19 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/datacopy_gm2gm.h @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LCCL_DATACOPY_GM2GM_H +#define LCCL_DATACOPY_GM2GM_H +#include +#include "comm_args.h" + +using namespace AscendC; +using namespace Lcal; + +constexpr int32_t BUFFER_NUM = 1; +constexpr int32_t TILE_NUM = 2; +constexpr int32_t BLOCK_SIZE = UB_SINGLE_DMA_SIZE_MAX / TILE_NUM / BUFFER_NUM; + +template +FORCE_INLINE_AICORE void SetAtomicOpType(int op) +{ + switch (op) { + case ADD: + AscendC::SetAtomicAdd(); + break; + + case MUL: + break; + case MAX: + AscendC::SetAtomicMax(); + break; + case MIN: + AscendC::SetAtomicMin(); + break; + default: + AscendC::SetAtomicNone(); + ; + } +} + +template +FORCE_INLINE_AICORE void CpUB2GM(__gm__ T *gmAddr, __ubuf__ T *ubAddr, uint32_t size) +{ + LocalTensor ubTensor; + GlobalTensor gmTensor; + DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(ubAddr); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr)); + DataCopyPad(gmTensor, ubTensor, dataCopyParams); +} + +template +FORCE_INLINE_AICORE void CpGM2UB(__ubuf__ T *ubAddr, __gm__ T *gmAddr, uint32_t size) +{ + LocalTensor ubTensor; + GlobalTensor gmTensor; + DataCopyExtParams dataCopyParams(1, size, 0, 0, 0); + ubTensor.address_.logicPos = static_cast(TPosition::VECIN); + ubTensor.address_.bufferAddr = reinterpret_cast(ubAddr); + gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint8_t *>(gmAddr)); + DataCopyPadExtParams padParams; + DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); +} + + +template +FORCE_INLINE_AICORE void CopyUB2UB(__ubuf__ T *dst, __ubuf__ T *src, const uint32_t calCount) +{ + LocalTensor srcTensor; + LocalTensor dstTensor; + TBuffAddr srcAddr, dstAddr; + srcAddr.bufferAddr = reinterpret_cast(src); + dstAddr.bufferAddr = reinterpret_cast(dst); + srcTensor.SetAddr(srcAddr); + dstTensor.SetAddr(dstAddr); + DataCopy(dstTensor, srcTensor, calCount); +} +template +__aicore__ inline void DataCopyWrap(const GlobalTensor &dstGlobal, const LocalTensor &srcLocal, + const uint32_t size) +{ + if (size % UB_ALIGN_SIZE == 0) { + DataCopy(dstGlobal, srcLocal, size / sizeof(T)); + } else { + DataCopyExtParams copyParams{1, size, 0, 0, 0}; + DataCopyPad(dstGlobal, srcLocal, copyParams); + } +} + +template +__aicore__ inline void DataCopyWrap(const LocalTensor &dstLocal, const GlobalTensor &srcGlobal, + const uint32_t size) +{ + if (size % UB_ALIGN_SIZE == 0) { + DataCopy(dstLocal, srcGlobal, size / sizeof(T)); + } else { + DataCopyExtParams copyParams{1, size, 0, 0, 0}; + DataCopyPadExtParams padParams{true, 0, 1, 0}; + DataCopyPad(dstLocal, srcGlobal, copyParams, padParams); + } +} + +template +class DataCopyGM2GM { + constexpr static int32_t UB_HEAD_OFFSET = 64; + constexpr static int32_t BLOCK_SIZE_PIECE = BLOCK_SIZE / (sizeof(T) + sizeof(U)) / ALIGN_SIZE * ALIGN_SIZE; + constexpr static int32_t INPUT_BLOCK_SIZE = std::is_same_v ? BLOCK_SIZE : BLOCK_SIZE_PIECE * sizeof(U); + constexpr static int32_t OUTPUT_BLOCK_SIZE = std::is_same_v ? 
BLOCK_SIZE : BLOCK_SIZE_PIECE * sizeof(T); +public: + FORCE_INLINE_AICORE DataCopyGM2GM() {} + FORCE_INLINE_AICORE void Init(const GlobalTensor& outputGt, const GlobalTensor& inputGt, + const uint32_t calCount, int op) + { + inputGm = inputGt.GetPhyAddr(); + outputGm = outputGt.GetPhyAddr(); + inputUB = (__ubuf__ U*)(UB_HEAD_OFFSET); + if constexpr (std::is_same_v) { + outputUB = (__ubuf__ T*)inputUB; + } else { + outputUB = (__ubuf__ T*)(UB_HEAD_OFFSET + INPUT_BLOCK_SIZE); + } + this->op = op; + dataSizeRemain = calCount * sizeof(T); + } + + FORCE_INLINE_AICORE void Process() + { + SetAtomic(op); + int64_t i = 0; + while (dataSizeRemain >= OUTPUT_BLOCK_SIZE) { + CpGM2UB(inputUB, (__gm__ U*)inputGm + i * INPUT_BLOCK_SIZE / sizeof(U), INPUT_BLOCK_SIZE); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + if constexpr (!std::is_same_v) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CastImpl(outputUB, inputUB, RoundMode::CAST_NONE, INPUT_BLOCK_SIZE / sizeof(U)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + CpUB2GM((__gm__ T*)outputGm + i * OUTPUT_BLOCK_SIZE / sizeof(T), (__ubuf__ T*)outputUB, + OUTPUT_BLOCK_SIZE); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + i += 1; + dataSizeRemain -= OUTPUT_BLOCK_SIZE; + } + if (dataSizeRemain > 0) { + CpGM2UB(inputUB, (__gm__ U*)inputGm + i * INPUT_BLOCK_SIZE / sizeof(U), + dataSizeRemain / sizeof(T) * sizeof(U)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + if constexpr (!std::is_same_v) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CastImpl(outputUB, inputUB, RoundMode::CAST_NONE, dataSizeRemain / sizeof(T)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + CpUB2GM((__gm__ T*)outputGm + i * OUTPUT_BLOCK_SIZE / sizeof(T), (__ubuf__ T*)outputUB, + dataSizeRemain); + PipeBarrier(); + } + UnsetAtomic(op); + } + + FORCE_INLINE_AICORE void Process(T scale, T offset) + { + SetAtomic(op); + int64_t i = 0; + int64_t batchDataNum = OUTPUT_BLOCK_SIZE / sizeof(T); + while (dataSizeRemain > 0) { + int64_t curProcessNum = (dataSizeRemain > OUTPUT_BLOCK_SIZE ? 
OUTPUT_BLOCK_SIZE : dataSizeRemain) / + sizeof(T); + CpGM2UB(inputUB, (__gm__ U*)inputGm + i * batchDataNum, curProcessNum * sizeof(U)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + if constexpr (!std::is_same_v) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CastImpl(outputUB, inputUB, RoundMode::CAST_NONE, curProcessNum); + PipeBarrier(); + AddsImpl(outputUB, outputUB, offset, curProcessNum); + PipeBarrier(); + MulsImpl(outputUB, outputUB, scale, curProcessNum); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + CpUB2GM((__gm__ T*)outputGm + i * batchDataNum, (__ubuf__ T*)outputUB, curProcessNum * sizeof(T)); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + i += 1; + dataSizeRemain -= OUTPUT_BLOCK_SIZE; + } + UnsetAtomic(op); + } + + FORCE_INLINE_AICORE void Process(const GlobalTensor& scaleGT, int64_t scaleCount, T offset) + { + if (scaleCount > UB_SINGLE_DMA_SIZE_MAX / (sizeof(T) + sizeof(U) + sizeof(T)) / ALIGN_SIZE * ALIGN_SIZE) { + ProcessForBigScale(scaleGT, scaleCount, offset); + } else { + ProcessForSmallScale(scaleGT, scaleCount, offset); + } + } +private: + FORCE_INLINE_AICORE void UnsetAtomic(int op) + { + if (op != -1) { + AscendC::SetAtomicNone(); + } + PipeBarrier(); + } + + FORCE_INLINE_AICORE void SetAtomic(int op) + { + PipeBarrier(); + if (op != -1) { +#ifdef __DAV_C220_VEC__ + SetAtomicOpType(op); +#endif + } + PipeBarrier(); + } + + FORCE_INLINE_AICORE void ProcessForSmallScale(const GlobalTensor& scaleGT, int64_t scaleCount, T offset) + { + SetAtomic(op); + constexpr int32_t blockPieceNum = UB_SINGLE_DMA_SIZE_MAX / (sizeof(T) + sizeof(T) + sizeof(U)) / ALIGN_SIZE * + ALIGN_SIZE; + const int32_t batchDataNum = blockPieceNum / scaleCount * scaleCount; + const int32_t inputBlockSize = batchDataNum * sizeof(U); + const int32_t outputBlockSize = batchDataNum * sizeof(T); + scaleUB = (__ubuf__ T*)(UB_HEAD_OFFSET); + outputUB = (__ubuf__ T*)(scaleUB + blockPieceNum); + inputUB = (__ubuf__ U*)(outputUB + blockPieceNum); + __gm__ T *scale = const_cast<__gm__ T*>(scaleGT.GetPhyAddr()); + + CpGM2UB((__ubuf__ T*)scaleUB, scale, scaleCount * sizeof(T)); + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + + int64_t repeatTimes = (dataSizeRemain > outputBlockSize ? outputBlockSize : dataSizeRemain) / sizeof(T) / + scaleCount; + int64_t mulVal = 2; + for (int64_t i = 1; i < repeatTimes; i *= mulVal) { + PipeBarrier(); + CopyUB2UB(scaleUB + i * scaleCount, scaleUB, (repeatTimes > i * mulVal ? i : repeatTimes - i) * scaleCount); + } + int64_t i = 0; + while (dataSizeRemain > 0) { + int64_t curProcessNum = (dataSizeRemain > outputBlockSize ? 
outputBlockSize : dataSizeRemain) / sizeof(T); + CpGM2UB(inputUB, (__gm__ U*)inputGm + i * batchDataNum, curProcessNum * sizeof(U)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CastImpl(outputUB, inputUB, RoundMode::CAST_NONE, curProcessNum); + PipeBarrier(); + AddsImpl(outputUB, outputUB, offset, curProcessNum); + PipeBarrier(); + MulImpl(outputUB, outputUB, scaleUB, curProcessNum); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM((__gm__ T*)outputGm + i * batchDataNum, (__ubuf__ T*)outputUB, curProcessNum * sizeof(T)); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + i += 1; + dataSizeRemain -= outputBlockSize; + } + UnsetAtomic(op); + } + + FORCE_INLINE_AICORE void ProcessForBigScale(const GlobalTensor& scaleGT, int64_t scaleCount, T offset) + { + SetAtomic(op); + const int32_t blockPieceNum = UB_SINGLE_DMA_SIZE_MAX / (sizeof(T) + sizeof(U) + sizeof(T)) / ALIGN_SIZE * + ALIGN_SIZE; + const int32_t inputBlockSize = blockPieceNum * sizeof(U); + const int32_t outputBlockSize = blockPieceNum * sizeof(T); + const int32_t dataNumPerBatch = outputBlockSize / sizeof(T); + const int32_t scaleBatchNum = (scaleCount + dataNumPerBatch - 1) / dataNumPerBatch; + + scaleUB = (__ubuf__ T*)(UB_HEAD_OFFSET); + outputUB = (__ubuf__ T*)(scaleUB + outputBlockSize / sizeof(T)); + inputUB = (__ubuf__ U*)(outputUB + outputBlockSize / sizeof(T)); + __gm__ T *scale = const_cast<__gm__ T*>(scaleGT.GetPhyAddr()); + + int64_t i = 0; + int32_t curDataNum = 0; + int32_t processedNum = 0; + while (dataSizeRemain > 0) { + if (i % scaleBatchNum == scaleBatchNum - 1) { + curDataNum = scaleCount - i % scaleBatchNum * dataNumPerBatch; + } else { + curDataNum = dataNumPerBatch; + } + CpGM2UB(inputUB, (__gm__ U*)inputGm + processedNum, curDataNum * sizeof(U)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpGM2UB(scaleUB, scale + i % scaleBatchNum * dataNumPerBatch, curDataNum * sizeof(T)); + CastImpl(outputUB, inputUB, RoundMode::CAST_NONE, curDataNum); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AddsImpl(outputUB, outputUB, offset, curDataNum); + PipeBarrier(); + MulImpl(outputUB, outputUB, scaleUB, curDataNum); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM((__gm__ T*)outputGm + processedNum, (__ubuf__ T*)outputUB, curDataNum * sizeof(T)); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + i += 1; + dataSizeRemain -= curDataNum * sizeof(T); + processedNum += curDataNum; + } + UnsetAtomic(op); + } +private: + int64_t dataSizeRemain = 0; + __ubuf__ U* inputUB = nullptr; + __ubuf__ T* outputUB = nullptr; + __ubuf__ T* scaleUB = nullptr; + const __gm__ U* inputGm = nullptr; + const __gm__ T* outputGm = nullptr; + int op; +}; +#endif // LCCL_DATACOPY_GM2GM_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/datacopy_gm2gm_delay.h b/comm/lcal/src/ascendc_kernels/datacopy_gm2gm_delay.h new file mode 100644 index 0000000000000000000000000000000000000000..46a94bd3020ec3cb8bb18746e895c29828078dbb --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/datacopy_gm2gm_delay.h @@ -0,0 +1,250 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef LCCL_DATACOPY_GM2GM_DELAY_H +#define LCCL_DATACOPY_GM2GM_DELAY_H +#include "datacopy_gm2gm.h" + +using namespace AscendC; +using namespace Lcal; + +template +class DataCopyGM2GMDelay { + constexpr static int64_t THREE_NUM = 3; + constexpr static int64_t FOUR_NUM = 4; + constexpr static int32_t WORK_OFFSET = 8192; + constexpr static int32_t WORK_BLOCK_NUM = WORK_OFFSET / sizeof(T); + constexpr static int32_t UB_HEAD_OFFSET = WORK_OFFSET * 2; + constexpr static int32_t SCALE_SIZE = 32; + constexpr static int32_t SCALE_NUM = SCALE_SIZE / sizeof(T); + constexpr static int32_t SINGLE_SCALE_SIZE = 2; + constexpr static int32_t BLOCK_NUM = (UB_SINGLE_DMA_SIZE_MAX - WORK_OFFSET * 2 - SCALE_SIZE * 4) / 2 / + (sizeof(U) + sizeof(T)) / ALIGN_SIZE * ALIGN_SIZE; + constexpr static int32_t IN_BLOCKSIZE = BLOCK_NUM * sizeof(U); + +public: + FORCE_INLINE_AICORE DataCopyGM2GMDelay() {} + + FORCE_INLINE_AICORE void Init(GlobalTensor& outputGt, GlobalTensor (&inputGt)[8], + GlobalTensor (&inputScaleGt)[8], const uint32_t calNum, int rankCount, GlobalTensor& outScaleGt, + TBuf tbuf) + { + for (int index = 0; index < rankCount; index++) { + this->inputGt[index] = inputGt[index]; + this->inputScaleGt[index] = inputScaleGt[index]; + } + this->outputGt = outputGt; + this->outScaleGt = outScaleGt; + inTensor[0] = tbuf.GetWithOffset(BLOCK_NUM, 0); + inTensor[1] = tbuf.GetWithOffset(BLOCK_NUM, WORK_OFFSET + SCALE_SIZE * HALF_NUM + IN_BLOCKSIZE * THREE_NUM); + singleScaleUBTensor[0] = tbuf.GetWithOffset(SCALE_NUM, IN_BLOCKSIZE); + singleScaleUBTensor[1] = tbuf.GetWithOffset(SCALE_NUM, WORK_OFFSET + SCALE_SIZE * HALF_NUM + + IN_BLOCKSIZE * FOUR_NUM); + singleScaleUUBTensor[0] = tbuf.GetWithOffset(SCALE_NUM, IN_BLOCKSIZE); + singleScaleUUBTensor[1] = tbuf.GetWithOffset(SCALE_NUM, WORK_OFFSET + SCALE_SIZE * HALF_NUM + + IN_BLOCKSIZE * FOUR_NUM); + scaleUBTensor[0] = tbuf.GetWithOffset(SCALE_NUM, IN_BLOCKSIZE + SCALE_SIZE); + scaleUBTensor[1] = tbuf.GetWithOffset(SCALE_NUM, WORK_OFFSET + SCALE_SIZE * THREE_NUM + + IN_BLOCKSIZE * FOUR_NUM); + scaleUUBTensor[0] = tbuf.GetWithOffset(SCALE_NUM, IN_BLOCKSIZE + SCALE_SIZE); + scaleUUBTensor[1] = tbuf.GetWithOffset(SCALE_NUM, WORK_OFFSET + SCALE_SIZE * THREE_NUM + + IN_BLOCKSIZE * FOUR_NUM); + workUBTensor[0] = tbuf.GetWithOffset(WORK_BLOCK_NUM, IN_BLOCKSIZE + SCALE_SIZE * HALF_NUM); + workUBTensor[1] = tbuf.GetWithOffset(WORK_BLOCK_NUM, WORK_OFFSET + SCALE_SIZE * FOUR_NUM + + IN_BLOCKSIZE * FOUR_NUM); + outputUBTensor[0] = tbuf.GetWithOffset(BLOCK_NUM, IN_BLOCKSIZE + SCALE_SIZE * HALF_NUM + WORK_OFFSET); + outputUBTensor[1] = tbuf.GetWithOffset(BLOCK_NUM, WORK_OFFSET * HALF_NUM + SCALE_SIZE * FOUR_NUM + + IN_BLOCKSIZE * FOUR_NUM); + this->rankCount = rankCount; + totalDataSize = calNum * sizeof(U); + this->calNum = calNum; + this->rankId = rankId; + } + + FORCE_INLINE_AICORE void PreProcess() + { + for (int index = 0; index < rankCount; index++) { + DataCopyWrap(scaleUUBTensor[0][index * SCALE_SIZE / sizeof(U)], inputScaleGt[index], SCALE_SIZE); + pipe_barrier(PIPE_ALL); + DataCopyWrap(scaleUUBTensor[1][index * SCALE_SIZE / sizeof(U)], inputScaleGt[index], SCALE_SIZE); + pipe_barrier(PIPE_ALL); + } + for (int index = 0; index < rankCount; index++) { + 
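+            // Added explanatory sketch (not in the original source): this pass
+            // unpacks the per-rank scales gathered above and derives two factors:
+            //     recip[r] = 1 / scale[r]      (Div against a tensor of ones)
+            //     outScale = min over ranks    (ReduceMin, then written to GM)
+            // e.g. scales {0.02, 0.04} give recip {50, 25} and outScale = 0.02.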
scaleUBTensor[0][index].SetValue(0, scaleUBTensor[0].GetValue(index * SCALE_SIZE / sizeof(T))); + pipe_barrier(PIPE_ALL); + scaleUBTensor[1][index].SetValue(0, scaleUBTensor[1].GetValue(index * SCALE_SIZE / sizeof(T))); + pipe_barrier(PIPE_ALL); + outputUBTensor[0][index].SetValue(0, 1); + AscendC::PipeBarrier(); + } + Div(scaleUBTensor[1], outputUBTensor[0], scaleUBTensor[1], rankCount); + AscendC::PipeBarrier(); + ReduceMin(singleScaleUBTensor[0], scaleUBTensor[0], + workUBTensor[1][WORK_BLOCK_NUM / HALF_NUM], rankCount, false); + pipe_barrier(PIPE_ALL); + DataCopyWrap(outScaleGt, singleScaleUUBTensor[0], sizeof(T)); + AscendC::PipeBarrier(); + } + + FORCE_INLINE_AICORE void LoopUncastAndMul(int idx, int index, event_t eventId) + { + PipeBarrier(); + T scalarValue = scaleUBTensor[1].GetValue(index); + PipeBarrier(); + int32_t perRankNum; + PipeBarrier(); + for (int j = 0; perRankNumRemain > 0; j++) { + PipeBarrier(); + perRankNum = perRankNumRemain >= WORK_BLOCK_NUM ? WORK_BLOCK_NUM : perRankNumRemain; + PipeBarrier(); + + perRankNumRemain -= perRankNum; + PipeBarrier(); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + PipeBarrier(); + Cast((idx & 1) ? workUBTensor[0] : workUBTensor[1], (idx & 1) ? inTensor[0][j * + WORK_BLOCK_NUM] : inTensor[1][j * WORK_BLOCK_NUM], RoundMode::CAST_NONE, perRankNum); + PipeBarrier(); + if (index == 0) { + Muls((idx & 1) ? outputUBTensor[0][j * WORK_BLOCK_NUM] : outputUBTensor[1][j * + WORK_BLOCK_NUM], (idx & 1) ? workUBTensor[0] : workUBTensor[1], scalarValue, perRankNum); + } else { + Axpy((idx & 1) ? outputUBTensor[0][j * WORK_BLOCK_NUM] : outputUBTensor[1][j * + WORK_BLOCK_NUM], (idx & 1) ? workUBTensor[0] : workUBTensor[1], scalarValue, perRankNum); + } + PipeBarrier(); + } + } + + FORCE_INLINE_AICORE void Mte3Process(int idx, int index, int calCount, event_t eventId) + { + if (index == (rankCount - 1)) { + if constexpr (std::is_same_v) { + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + DataCopyWrap(outputGt[idx * BLOCK_NUM], (idx & 1) ? + outputUBTensor[0] : outputUBTensor[1], calCount * sizeof(V)); + } + if constexpr (std::is_same_v) { + PipeBarrier(); + T scaleValue = singleScaleUBTensor[0].GetValue(0); + PipeBarrier(); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + PipeBarrier(); + Muls((idx & 1) ? outputUBTensor[0] : outputUBTensor[1], (idx & 1) ? + outputUBTensor[0] : outputUBTensor[1], scaleValue, calCount); + PipeBarrier(); + Cast((idx & 1) ? inTensor[0] : inTensor[1], (idx & 1) ? + outputUBTensor[0] : outputUBTensor[1], RoundMode::CAST_NONE, calCount); + PipeBarrier(); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + DataCopyWrap(outputGt[idx * BLOCK_NUM], (idx & 1) ? 
+ inTensor[0] : inTensor[1], calCount * sizeof(V)); + } + } + } + + FORCE_INLINE_AICORE int GetSize(int idx, int numOfPiece) + { + int size; + if (idx < (numOfPiece - 1)) { + size = IN_BLOCKSIZE; + } else if (idx == (numOfPiece - 1)) { + size = totalDataSize - (numOfPiece - 1) * IN_BLOCKSIZE; + } else { + size = 0; + } + return size; + } + + FORCE_INLINE_AICORE void Process() + { + PreProcess(); + int numOfPiece = CeilDiv(calNum, BLOCK_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; i < numOfPiece; i += HALF_NUM) { + for (int index = 0; index < rankCount; index++) { + for (int k = 0; k < HALF_NUM; k++) { + int idx = i + k; + int size = GetSize(idx, numOfPiece); + int32_t calCount = size / sizeof(U); + perRankNumRemain = calCount; + event_t eventId = (idx & 1) ? EVENT_ID0 : EVENT_ID1; + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::WaitFlag(eventId); + DataCopyWrap((idx & 1) ? inTensor[0] : inTensor[1], inputGt[index][BLOCK_NUM * idx], size); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + LoopUncastAndMul(idx, index, eventId); + Mte3Process(idx, index, calCount, eventId); + AscendC::SetFlag(eventId); + AscendC::SetFlag(eventId); + AscendC::SetFlag(eventId); + } + } + } + + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + } +private: + template + FORCE_INLINE_AICORE T1 CeilDiv(T1 a, T2 b) + { + if (b == 0) { + return 0; + } + return (a + b - 1) / b; + } + +private: + int64_t totalDataSize = 0; + int rankCount; + int perRankNumRemain; + int calNum; + int rankId; + int numLayer; + + LocalTensor inTensor[2]; + LocalTensor singleScaleUUBTensor[2]; + LocalTensor singleScaleUBTensor[2]; + LocalTensor scaleUUBTensor[2]; + LocalTensor scaleUBTensor[2]; + LocalTensor workUBTensor[2]; + LocalTensor outputUBTensor[2]; + + GlobalTensor outputGt; + GlobalTensor inputGt[8]; + GlobalTensor inputScaleGt[8]; + GlobalTensor outScaleGt; +}; + +#endif // LCCL_DATACOPY_GM2GM_DELAYH + diff --git a/comm/lcal/src/ascendc_kernels/ipc_queue.h b/comm/lcal/src/ascendc_kernels/ipc_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..89967c2f55017d46ef2b8d63a90d8077bb016024 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/ipc_queue.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef LCCL_IPC_QUEUE_H +#define LCCL_IPC_QUEUE_H +#include "sync_collectives.h" +using namespace AscendC; + +template +class IpcQueue { +public: + FORCE_INLINE_AICORE IpcQueue() {} + + FORCE_INLINE_AICORE void Init(SyncCollectives *sync, int64_t magic, GM_ADDR workSpace, uint64_t bufferNum, + uint64_t blockNum) + { + this->sync = sync; + this->magic = magic; + depth = bufferNum / blockNum; + front = 0; + rear = 0; + count = 0; + this->blockNum = blockNum; + buff.SetGlobalBuffer((__gm__ T*)workSpace, bufferNum); + blockIdx = GetBlockIdx(); + } + + FORCE_INLINE_AICORE bool Full() + { + if ((rear + 1) % depth == front) { + return true; + } + return false; + } + + FORCE_INLINE_AICORE GlobalTensor EnQue() + { + uint64_t rearOld = rear; + rear = (rear + 1) % depth; + return buff[rearOld * blockNum]; + } + + FORCE_INLINE_AICORE void DeQue(int checkRank, int checkBlock = -1) + { + if (!Full()) { + return; + } + if (checkBlock == -1) { + checkBlock = blockIdx; + } + sync->WaitInnerFlag(magic, count, checkRank, checkBlock); + PipeBarrier(); + int64_t val = sync->GetInnerFlag(checkRank, checkBlock) & EVENT_ID_MASK; + count = val + 1; + front = (val + 1) % depth; + } + + FORCE_INLINE_AICORE void DeQue(int *rankList, int checkCount, int checkBlock = -1) + { + if (!Full()) { + return; + } + if (checkBlock == -1) { + checkBlock = blockIdx; + } + int64_t minIndex = LLONG_MAX; + for (int i = 0; i < checkCount; i++) { + sync->WaitInnerFlag(magic, count, rankList[i], checkBlock); + PipeBarrier(); + + int64_t val = sync->GetInnerFlag(rankList[i], checkBlock) & EVENT_ID_MASK; + if (minIndex > val) { + minIndex = val; + } + } + count = minIndex + 1; + front = (minIndex + 1) % depth; + } + FORCE_INLINE_AICORE void DeQue(int *rankList, int *blockIdxList, int checkCount) + { + if (!Full()) { + return; + } + + int64_t minIndex = LLONG_MAX; + for (int i = 0; i < checkCount; i++) { + sync->WaitInnerFlag(magic, count, rankList[i], blockIdxList[i]); + PipeBarrier(); + + int64_t val = sync->GetInnerFlag(rankList[i], blockIdxList[i]) & EVENT_ID_MASK; + if (minIndex > val) { + minIndex = val; + } + } + count = minIndex + 1; + front = (minIndex + 1) % depth; + } + + FORCE_INLINE_AICORE GlobalTensor ReadFront() + { + uint64_t frontOld = front; + front = (front + 1) % depth; + return buff[frontOld * blockNum]; + } + +private: + int64_t magic; + uint64_t depth; + uint64_t front; + uint64_t rear; + uint64_t count; + uint64_t blockNum; + GlobalTensor buff; + SyncCollectives *sync; + int blockIdx; +}; +#endif // LCCL_IPC_QUEUE_H diff --git a/comm/lcal/src/ascendc_kernels/lccl_op.h b/comm/lcal/src/ascendc_kernels/lccl_op.h new file mode 100644 index 0000000000000000000000000000000000000000..bf54ce2b1bf05482a6b4e133b52a5d5e2db231a0 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/lccl_op.h @@ -0,0 +1,250 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef LCCL_OP_H +#define LCCL_OP_H + +#if defined(__DAV_C220_VEC__) || defined(__DAV_C220_CUBE__) + +#include "op_def.h" +#include "allgather.h" +#include "91093/allgather_hierarchy_double_ring.h" +#include "allreduce_one_shot.h" +#include "allreduce_two_shot.h" +#include "allreduce_big_data.h" +#include "91093/allreduce_big_data_sio.h" +#include "91093/allreduce_hierarchy_double_ring.h" +#include "reduce_scatter.h" +#include "91093/reduce_scatter_big_data_91093_4step.h" +#include "91093/reduce_scatter_hierarchy_double_ring.h" +#include "91093/all2all_hierarchy.h" +#include "91093/all2all_hierarchy_small.h" + +#include "../kernels/lcal_allreduce_2npu_read.cce" +#include "../kernels/lcal_allreduce_2npu_write.cce" +#include "../kernels/lcal_allreduce_2npu_big_write.cce" +#include "../kernels/lcal_allreduce_two_shot.cce" +#include "../kernels/lcal_allreduce_big_data.cce" +#include "../kernels/lcal_allreduce_two_shot_910B2C.cce" +#include "../kernels/lcal_allreduce_big_data_910B2C.cce" +#include "../kernels/lcal_allreduce_deterministic.cce" +#include "../kernels/lcal_allreduce_deterministic_big_data.cce" +#include "../kernels/lcal_reduce_scatter_big_data_write.cce" +#include "../kernels/lcal_reduce_scatter_write.cce" +#include "../kernels/lcal_reduce_scatter.cce" +#include "../kernels/lcal_reduce_scatter_big_data.cce" +#include "../kernels/lcal_allgather_910B2C.cce" +#include "../kernels/lcal_allgather_big_data_910B2C.cce" +#include "../kernels/lcal_allgather_2npu.cce" +#include "../kernels/lcal_allgather_2npu_big_data_write.cce" +#include "../kernels/lcal_allgather.cce" +#include "../kernels/lcal_allgather_big_data.cce" +#include "../kernels/lcal_broadcast_write.cce" +#include "../kernels/lcal_broadcast_big_data.cce" +#include "../kernels/lcal_all2all_transpose.cce" + +#define CLASS_OP_910B_RDMA_LAUNCH(name, type) \ +do { \ +name opKernel(localRank, localRankSize, extraFlag); \ +opKernel.Init(KERNELS_ARGS_CALL()); \ +opKernel.Process(); \ +} while (0) + +#define CLASS_OP_QUANT_LAUNCH(name, outputType, inputType) \ +do { \ +name opKernel(localRank, localRankSize, extraFlag); \ +opKernel.Init(KERNELS_ARGS_CALL()); \ +opKernel.Process(); \ +} while (0) + +extern "C" __global__ __aicore__ __attribute__((section("Attr_Section_Lcal"))) void LcalDescriptor() {} + +#define LCCL_BROADCAST_FUNC_AUTO_DEF(suffix) \ +extern "C" __global__ __aicore__ void LcalBroadcast##suffix(KERNELS_ARGS_FUN()) \ +{ \ + if ASCEND_IS_AIV { \ + GET_COMM_ARGS; \ + __gm__ char * shareAddrs[LCAL_MAX_RANK_SIZE]; \ + GET_IPC_MEM_ARGS(char); \ + if ((extraFlag & ExtraFlag::TOPO_PCIE) != 0) { \ + LcalBroadcast2npuBigDataWrite(ALLREDUCE_ARGS_CALL(char)); \ + } else { \ + LcalBroadcastBigData(ALLREDUCE_ARGS_CALL(char)); \ + } \ + } \ +} + +#define LCCL_ALLGATHER_FUNC_AUTO_DEF(type, suffix) \ +extern "C" __global__ __aicore__ void LcalAllGather_##type##suffix(KERNELS_ARGS_FUN()) { \ + if ASCEND_IS_AIV { \ + GET_COMM_ARGS; \ + constexpr int32_t quickOneshotRankSize = 2; \ + constexpr int32_t cceSmallDataSize = 2 * 1024 * 1024; \ + constexpr int32_t smallRankSize = 8; \ + constexpr int32_t smallDataSize910a3 = 32 * 1024 * 1024; \ + __gm__ type * shareAddrs[LCAL_MAX_RANK_SIZE]; \ + GET_IPC_MEM_ARGS(type); \ + if ((extraFlag & ExtraFlag::TOPO_910B2C) != 0 && rankSize > smallRankSize) { \ + if (len * sizeof(type) < cceSmallDataSize) { \ + LcalAllGather910B2C(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else { \ + LcalAllGatherBigData910B2C(ALLREDUCE_ARGS_CALL_16P(type)); \ + } \ + } else if ((extraFlag & ExtraFlag::TOPO_PCIE) != 0) { \ 
+ LcalAllGather2npuBigDataWrite(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else if ((extraFlag & ExtraFlag::TOPO_910_93) != 0 && lcalBlockNum != rankSize && \ + (len > smallDataSize910a3 / sizeof(type) || rankSize > smallRankSize) && \ + rankSize > quickOneshotRankSize && rankSize % quickOneshotRankSize == 0) { \ + CLASS_OP_LAUNCH(AllGatherHierarchyDoubleRing, type); \ + } else { \ + if (rankSize == quickOneshotRankSize && len * sizeof(type) < SIZE_OF_8M && lcalBlockNum != rankSize) { \ + LcalAllGather2npu(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else if (rankSize == quickOneshotRankSize && lcalBlockNum != rankSize) { \ + LcalAllGather2npuBigDataWrite(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else if (rankSize > quickOneshotRankSize && len * sizeof(type) < cceSmallDataSize || \ + lcalBlockNum == rankSize) { \ + LcalAllGather(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else { \ + LcalAllGatherBigData(ALLREDUCE_ARGS_CALL_16P(type)); \ + } \ + } \ + } \ +} + +#define LCCL_ALL_REDUCE_FUNC_AUTO_DEF(type, suffix) \ +extern "C" __global__ __aicore__ void LcalAllReduce_##type##suffix(KERNELS_ARGS_FUN()) { \ + if ASCEND_IS_AIV { \ + GET_COMM_ARGS; \ + constexpr int32_t quickOneshotRankSize = 2; \ + constexpr int32_t threeStepNum = 3; \ + constexpr int32_t smallRankSize = 8; \ + constexpr int32_t oneshotDataSize = 16 * 1024; \ + constexpr int64_t quantSmallDataSize = 512 * 1024; \ + constexpr int32_t cceSmallDataSize = 2 * 1024 * 1024; \ + constexpr int32_t smallDataSize910a3 = 32 * 1024 * 1024; \ + constexpr int32_t rankSize910a3 = 16; \ + __gm__ type * shareAddrs[LCAL_MAX_RANK_SIZE]; \ + GET_IPC_MEM_ARGS(type); \ + if ((extraFlag & ExtraFlag::TOPO_PCIE) != 0) { \ + if (len * sizeof(type) < SIZE_OF_8M) { \ + LcalAllReduce2npuWrite(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else { \ + LcalAllReduce2npuBigDataWrite(ALLREDUCE_ARGS_CALL_16P(type)); \ + } \ + } else if ((extraFlag & ExtraFlag::QUANT_FP16) != 0 && std::is_same_v) { \ + if (len * sizeof(type) <= oneshotDataSize) { \ + CLASS_OP_QUANT_LAUNCH(AllReduceOneShot, half, int8_t); \ + } else if (len * sizeof(type) <= quantSmallDataSize) { \ + CLASS_OP_QUANT_LAUNCH(AllReduceTwoShot, half, int8_t); \ + } else { \ + CLASS_OP_QUANT_LAUNCH(AllReduceBigData, half, int8_t); \ + } \ + } else if ((extraFlag & ExtraFlag::TOPO_910B2C) != 0 && rankSize > smallRankSize) { \ + if (len * sizeof(type) < cceSmallDataSize) { \ + LcalAllReduceTwoShot910B2C(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else { \ + LcalAllReduceBigData910B2C(ALLREDUCE_ARGS_CALL_16P(type)); \ + } \ + } else if ((extraFlag & ExtraFlag::DETERMINISTIC) != 0) { \ + constexpr uint32_t maxAivNum = 40; \ + const bool isAivNumSupport = ((extraFlag & ExtraFlag::IS_GREATER_THAN_40_AIV) != 0 || \ + rankSize * threeStepNum <= maxAivNum); \ + if ((extraFlag & ExtraFlag::TOPO_910_93) != 0) { \ + if (rankSize % quickOneshotRankSize == 1 || rankSize == quickOneshotRankSize || \ + (rankSize <= rankSize910a3 && len * sizeof(type) <= smallDataSize910a3 && isAivNumSupport)) { \ + LcalAllReduceDeterministicBigData(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else { \ + CLASS_OP_LAUNCH(AllReduceHierarchyDoubleRing, type); \ + } \ + } else if (len * sizeof(type) < SMALL_DATA_SIZE) { \ + LcalAllReduceDeterministic(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else { \ + LcalAllReduceDeterministicBigData(ALLREDUCE_ARGS_CALL_16P(type)); \ + } \ + } else if ((extraFlag & ExtraFlag::TOPO_910_93) != 0 && lcalBlockNum != rankSize && \ + (rankSize == quickOneshotRankSize || len * sizeof(type) > smallDataSize910a3)) { \ + if (rankSize == quickOneshotRankSize) { \ 
+ LcalAllReduce2npuBigDataWrite(ALLREDUCE_ARGS_CALL_16P(type)); \ + } else if (rankSize % quickOneshotRankSize == 0) { \ + CLASS_OP_LAUNCH(AllReduceHierarchyDoubleRing, type); \ + } else { \ + CLASS_OP_LAUNCH(AllReduceBigDataSio, type); \ + } \ + } else { \ + if (len * sizeof(type) < cceSmallDataSize or lcalBlockNum == rankSize) { \ + if (rankSize == quickOneshotRankSize && lcalBlockNum != rankSize) { \ + LcalAllReduce2npuRead(ALLREDUCE_ARGS_CALL(type)); \ + } else { \ + LcalAllReduceTwoShot(ALLREDUCE_ARGS_CALL_16P(type)); \ + } \ + } else { \ + LcalAllReduceBigData(ALLREDUCE_ARGS_CALL_16P(type)); \ + } \ + } \ + } \ +} + +#define LCCL_ALL2ALL_FUNC_AUTO_DEF(type, suffix) \ +extern "C" __global__ __aicore__ void LcalAll2All_##type##suffix(KERNELS_ARGS_FUN()) { \ + if ASCEND_IS_AIV { \ + GET_COMM_ARGS; \ + __gm__ type * shareAddrs[LCAL_MAX_RANK_SIZE]; \ + GET_IPC_MEM_ARGS(type); \ + constexpr int32_t smallRankSize = 8; \ + if (op != 0 && root != 0) { \ + LcalAll2AllTranspose(ALLREDUCE_ARGS_CALL_16P(type)); \ + } \ + else if ((extraFlag & ExtraFlag::TOPO_910_93) != 0) { \ + if (rankSize <= smallRankSize && len * sizeof(type) > SMALL_DATA_SIZE && \ + (len * sizeof(type)) % (smallRankSize * smallRankSize * rankSize) == 0) { \ + CLASS_OP_LAUNCH(All2AllHierarchySmall, type); \ + } else { \ + CLASS_OP_LAUNCH(All2AllHierarchy, type); \ + } \ + } \ + } \ +} + +#define LCCL_REDUCE_SCATTER_FUNC_AUTO_DEF(type, suffix) \ +extern "C" __global__ __aicore__ void LcalReduceScatter_##type##suffix(KERNELS_ARGS_FUN()) { \ + if ASCEND_IS_AIV { \ + GET_COMM_ARGS; \ + constexpr int32_t quickOneshotRankSize = 2; \ + constexpr int32_t cceSmallDataSize = 2 * 1024 * 1024; \ + constexpr int32_t a3BigDataSize = 32 * 1024 * 1024; \ + constexpr int32_t a3SupportRankSize = 4; \ + constexpr int32_t smallRankSize = 8; \ + const bool isDbRing = (rankSize == a3SupportRankSize || rankSize == smallRankSize) && \ + (len * sizeof(type) * smallRankSize > cceSmallDataSize && \ + len * sizeof(type) * smallRankSize <= a3BigDataSize); \ + __gm__ type * shareAddrs[LCAL_MAX_RANK_SIZE]; \ + GET_IPC_MEM_ARGS(type); \ + if ((extraFlag & ExtraFlag::TOPO_PCIE) != 0) { \ + LcalReduceScatterBigDataWrite(ALLREDUCE_ARGS_CALL(type)); \ + } else if ((extraFlag & ExtraFlag::TOPO_910_93) != 0 && (rankSize > smallRankSize || isDbRing)) { \ + if (isDbRing) { \ + CLASS_OP_LAUNCH(ReduceScatterHierarchyDoubleRing, type); \ + } else if (len * sizeof(type) <= SMALL_DATA_SIZE) { \ + CLASS_OP_LAUNCH(ReduceScatter, type); \ + } else { \ + CLASS_OP_LAUNCH(ReduceScatterBigData91093, type); \ + } \ + } else { \ + if (rankSize == quickOneshotRankSize && len * sizeof(type) < SIZE_OF_8M) { \ + LcalReduceScatterWrite(ALLREDUCE_ARGS_CALL(type)); \ + } else if (rankSize > quickOneshotRankSize && len * sizeof(type) < cceSmallDataSize){ \ + LcalReduceScatter(ALLREDUCE_ARGS_CALL(type)); \ + } else { \ + LcalReduceScatterBigData(ALLREDUCE_ARGS_CALL(type)); \ + } \ + } \ + } \ +} +#endif +#endif \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/lccl_op1.cpp b/comm/lcal/src/ascendc_kernels/lccl_op1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..62539f3d7ad76990eeaf89475ec8355686079346 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/lccl_op1.cpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifdef __DAV_C220_VEC__ + +#include "lccl_op.h" + +LCCL_TYPE_AIV_FUNC(LCCL_ALL_REDUCE_FUNC_AUTO_DEF); + +#endif + +#ifdef __DAV_C220_CUBE__ + +#include "lccl_op.h" + +LCCL_TYPE_AIC_FUNC(LCCL_ALL_REDUCE_FUNC_AUTO_DEF); + +#endif + diff --git a/comm/lcal/src/ascendc_kernels/lccl_op2.cpp b/comm/lcal/src/ascendc_kernels/lccl_op2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76b52f7370ec843f913c7cd3630d30c03720c75f --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/lccl_op2.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifdef __DAV_C220_VEC__ + +#include "lccl_op.h" + +LCCL_TYPE_AIV_FUNC(LCCL_ALLGATHER_FUNC_AUTO_DEF); +LCCL_TYPE_AIV_FUNC(LCCL_REDUCE_SCATTER_FUNC_AUTO_DEF); +LCCL_TYPE_AIV_FUNC(LCCL_ALL2ALL_FUNC_AUTO_DEF); + +#ifdef ENABLE_LCCL_MIX +LCCL_BROADCAST_FUNC_AUTO_DEF(_mix_aiv) +#else +LCCL_BROADCAST_FUNC_AUTO_DEF() +#endif +#endif +#ifdef __DAV_C220_CUBE__ + +#include "lccl_op.h" +LCCL_TYPE_AIC_FUNC(LCCL_ALLGATHER_FUNC_AUTO_DEF); +LCCL_TYPE_AIC_FUNC(LCCL_REDUCE_SCATTER_FUNC_AUTO_DEF); +LCCL_TYPE_AIC_FUNC(LCCL_ALL2ALL_FUNC_AUTO_DEF); +#ifdef ENABLE_LCCL_MIX +LCCL_BROADCAST_FUNC_AUTO_DEF(_mix_aic) +#else +LCCL_BROADCAST_FUNC_AUTO_DEF() +#endif +#endif \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/op_def.h b/comm/lcal/src/ascendc_kernels/op_def.h new file mode 100644 index 0000000000000000000000000000000000000000..45086dae027690760d1d2c1f6b940d58a2c0c139 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/op_def.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#ifndef LCCL_OP_DEF_H
+#define LCCL_OP_DEF_H
+
+/* Unpacks the host-side CommArgs layout: rank ids, rank sizes and the extra topology flags. */
+#define GET_COMM_ARGS \
+    GlobalTensor<int> commArgsGm; \
+    commArgsGm.SetGlobalBuffer(reinterpret_cast<__gm__ int *>(commArgs), 5); \
+    int rank = commArgsGm.GetValue(0); \
+    int localRank = commArgsGm.GetValue(1); \
+    int rankSize = commArgsGm.GetValue(2); \
+    int localRankSize = commArgsGm.GetValue(3); \
+    uint32_t extraFlag = commArgsGm.GetValue(4); \
+    GM_ADDR dumpAddr = (reinterpret_cast<__gm__ CommArgs *>(commArgs))->dumpAddr; \
+    int32_t lcalBlockNum = GetBlockNum()
+
+#ifdef ENABLE_LCCL_MIX
+/* Mix mode: block 0 atomically bumps the shared magic counter; all blocks then read it back. */
+#define SET_MAGIC \
+do { \
+    __gm__ CommArgs * commArgsTmp = reinterpret_cast<__gm__ CommArgs *>(commArgs); \
+    PipeBarrier<PIPE_ALL>(); \
+    SetAtomicNone(); \
+    SetMaskNormImpl(); \
+    SetSyncBaseAddr(commArgsTmp->fftsVal); \
+    SetVectorMask((uint64_t)-1, (uint64_t)-1); \
+    PipeBarrier<PIPE_ALL>(); \
+    LocalTensor<int32_t> localSet; \
+    localSet.address_.logicPos = static_cast<uint8_t>(TPosition::VECIN); \
+    localSet.address_.bufferAddr = reinterpret_cast<uint64_t>((__ubuf__ int32_t *)96); \
+    GlobalTensor<int32_t> magicGt; \
+    magicGt.SetGlobalBuffer((__gm__ int32_t *)commArgsTmp->magics); \
+    if (GetBlockIdx() == 0) { \
+        SetAtomicOpType(Op::ADD); \
+        localSet.SetValue(0, 1); \
+        AscendC::SetFlag<HardEvent::S_MTE3>(EVENT_ID0); \
+        AscendC::WaitFlag<HardEvent::S_MTE3>(EVENT_ID0); \
+        DataCopyExtParams dataCopyParams(1, sizeof(int32_t), 0, 0, 0); \
+        DataCopyPad(magicGt[rankSize - 1], localSet, dataCopyParams); \
+        AscendC::SetAtomicNone(); \
+        PipeBarrier<PIPE_ALL>(); \
+    } \
+    SyncAll(); \
+    DataCopyExtParams dataCopyParams(1, sizeof(int32_t), 0, 0, 0); \
+    DataCopyPadExtParams<int32_t> padParams; \
+    DataCopyPad(localSet, magicGt[rankSize - 1], dataCopyParams, padParams); \
+    AscendC::SetFlag<HardEvent::MTE2_S>(EVENT_ID0); \
+    AscendC::WaitFlag<HardEvent::MTE2_S>(EVENT_ID0); \
+    magic = static_cast<int32_t>(localSet.GetValue(0)); \
+    PipeBarrier<PIPE_ALL>(); \
+    constexpr int32_t aivNumPerAic = 2; \
+    lcalBlockNum = GetBlockNum() * aivNumPerAic; \
+} while (0)
+#else
+#define SET_MAGIC \
+do {} while (0)
+#endif
+
+/* Resolves each peer's IPC buffer for this round; magic selects the ping-pong half. */
+#define GET_IPC_MEM_ARGS(type) \
+do { \
+    SET_MAGIC; \
+    GlobalTensor<GM_ADDR> peerMemsAddrGm; \
+    peerMemsAddrGm.SetGlobalBuffer(&(reinterpret_cast<__gm__ CommArgs *>(commArgs))->peerMems[0], \
+        LCAL_MAX_RANK_SIZE); \
+    for (int i = 0; i < rankSize; ++i) { \
+        shareAddrs[i] = (__gm__ type *) (peerMemsAddrGm.GetValue(i) + \
+            (magic % PING_PONG_SIZE) * (IPC_BUFF_MAX_SIZE + IPC_DATA_OFFSET)); \
+    } \
+    AscendC::PipeBarrier<PIPE_ALL>(); \
+} while (0) \
+
+#define CLASS_OP_LAUNCH(name, type) \
+do { \
+    name<type> opKernel(rank, rankSize, extraFlag); \
+    opKernel.Init(KERNELS_ARGS_CALL()); \
+    opKernel.Process(); \
+} while (0)
+
+#define CLASS_OP_QUANT_910A5_LAUNCH(name, outputType, addType, inputType) \
+do { \
+    name<outputType, addType, inputType> opKernel(rank, rankSize, extraFlag); \
+    opKernel.Init(KERNELS_ARGS_CALL()); \
+    opKernel.Process(); \
+} while (0)
+
+#define LCCL_TYPE_FUNC(fun) \
+    fun(int,);fun(int8_t,);fun(int16_t,);fun(int64_t,); \
+    fun(float,);fun(float16_t,);fun(bfloat16_t,)
+
+#ifdef ENABLE_LCCL_MIX
+#define LCCL_TYPE_AIC_FUNC(fun) \
+    fun(int, _mix_aic); fun(int8_t, _mix_aic); fun(int16_t, _mix_aic); fun(int64_t, _mix_aic); \
+    fun(float, _mix_aic); fun(float16_t, _mix_aic); fun(bfloat16_t, _mix_aic)
+
+#define LCCL_TYPE_AIV_FUNC(fun) \
+    fun(int, _mix_aiv); fun(int8_t, _mix_aiv); fun(int16_t, _mix_aiv); fun(int64_t, _mix_aiv); \
+    fun(float, _mix_aiv); fun(float16_t, _mix_aiv); fun(bfloat16_t, _mix_aiv)
+#else
+#define LCCL_TYPE_AIC_FUNC(fun) \
+    (void)0
+
+#define LCCL_TYPE_AIV_FUNC(fun) \
+    fun(int,); fun(int8_t,); fun(int16_t,); fun(int64_t,); \
+    fun(float,); fun(float16_t,); fun(bfloat16_t,)
+#endif
+
+#define 
LCCL_VADD_910B_TYPE_FUNC(fun) \ + fun(int);fun(int16_t); \ + fun(float);fun(float16_t) + +#define LCCL_QUANT_LOW_TYPE_FUNC(fun) \ + fun(int8_t) + +#endif \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/reduce_scatter.h b/comm/lcal/src/ascendc_kernels/reduce_scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..82405c41654667b5c03dc8e8c171f53ae85866a7 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/reduce_scatter.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LCCL_REDUCE_SCATTER_H +#define LCCL_REDUCE_SCATTER_H + +#include "sync_collectives.h" +#include "collectives.h" +using namespace AscendC; + +template +class ReduceScatter : protected Collectives { +public: + FORCE_INLINE_AICORE ReduceScatter(int rank, int rankSize, uint32_t extraFlag) + : Collectives(rank, rankSize, extraFlag) {} + + FORCE_INLINE_AICORE void Init(KERNELS_ARGS_FUN()) + { + Collectives::Init(KERNELS_ARGS_CALL()); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + atomOp = op; + DMANumMax = BLOCK_SIZE / sizeof(T); + corePerRank = blockNum / rankSize; + rankIDOfBlock = blockIdx / corePerRank; + dataDMAPerCore = CeilDiv(len, corePerRank); + blockIdxOfLen = blockIdx % corePerRank; + if (blockIdxOfLen == corePerRank - 1) { + blockDataNum = len - blockIdxOfLen * dataDMAPerCore; + } else { + blockDataNum = dataDMAPerCore; + } + inputOffset = rankIDOfBlock * len + (blockIdx % corePerRank) * dataDMAPerCore; + outputOffset = dataDMAPerCore * (blockIdx % corePerRank); + dstIpcDataOffset = IPC_DATA_OFFSET / sizeof(T) + rankIDOfBlock * len + outputOffset; + srcIpcDataOffset = IPC_DATA_OFFSET / sizeof(T) + rank * len + outputOffset; + srcInputGlobal.SetGlobalBuffer((__gm__ T*)input + inputOffset, blockDataNum); + if ((extraFlag & ExtraFlag::RDMA) == ExtraFlag::RDMA) { + dstOutputGlobal.SetGlobalBuffer((__gm__ T*)shareAddrs[rank] + srcIpcDataOffset, blockDataNum); + } else { + dstOutputGlobal.SetGlobalBuffer((__gm__ T*)output + outputOffset, blockDataNum); + } + dstIPCGlobal.SetGlobalBuffer((__gm__ T*)shareAddrs[rank] + dstIpcDataOffset, blockDataNum); + srcIPCGlobal.SetGlobalBuffer((__gm__ T*)shareAddrs[rankIDOfBlock] + srcIpcDataOffset, blockDataNum); + DumpLcclLogInfo(LogId::INIT, static_cast(op)); + } + FORCE_INLINE_AICORE void Process() + { + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + CpInputToBuffAndOutput(); + sync.SetInnerFlag(magic, 1); + sync.WaitRankInnerFlag(magic, 1, rank); + sync.WaitInnerFlag(magic, 1, rankIDOfBlock, rank * corePerRank + blockIdx % corePerRank); + if (rankIDOfBlock != rank) { + CpGM2GM(dstOutputGlobal, srcIPCGlobal, blockDataNum, atomOp); + } + DumpLcclLogInfo(LogId::PROCESS, static_cast(atomOp)); + } + + FORCE_INLINE_AICORE void CpInputToBuffAndOutput() + { + CpGM2GM(dstIPCGlobal, srcInputGlobal, blockDataNum, -1); + if ((extraFlag & ExtraFlag::RDMA) != ExtraFlag::RDMA) { + if ((blockIdx >= rank * corePerRank) && (blockIdx < (rank * corePerRank + 
corePerRank))) { + CpGM2GM(dstOutputGlobal, srcInputGlobal, blockDataNum, -1); + } + } + } + +protected: + GlobalTensor srcInputGlobal; + GlobalTensor srcIPCGlobal; + GlobalTensor dstIPCGlobal; + GlobalTensor dstOutputGlobal; + int blockIdxOfLen; + int DMANumMax; + int rankIDOfBlock; + int corePerRank; + int inputOffset; + int outputOffset; + int srcIpcDataOffset; + int dstIpcDataOffset; + int dataDMAPerCore; + int blockDataNum; + int atomOp; +}; +#endif // LCCL_REDUCE_SCATTER_H \ No newline at end of file diff --git a/comm/lcal/src/ascendc_kernels/sync_collectives.h b/comm/lcal/src/ascendc_kernels/sync_collectives.h new file mode 100644 index 0000000000000000000000000000000000000000..89f5ac66115a8bb7c1362c00f34ce2fad76b3ad5 --- /dev/null +++ b/comm/lcal/src/ascendc_kernels/sync_collectives.h @@ -0,0 +1,310 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#ifndef LCCL_SYNC_H +#define LCCL_SYNC_H + +#include "comm_args.h" + +using namespace AscendC; +using namespace Lcal; + +constexpr int64_t FLAG_UNIT_INT_NUM = 4; +constexpr int64_t SYNC_UNIT_SIZE = FLAG_UNIT_INT_NUM * sizeof(int64_t); +constexpr int64_t MAGIC_OFFSET = 32; +constexpr int64_t MAGIC_MASK = ~((1LL << MAGIC_OFFSET) - 1); +#ifdef ENABLE_LCCL_MIX +constexpr int32_t LCAL_BLOCK_NUM_MULTI = 2; +#else +constexpr int32_t LCAL_BLOCK_NUM_MULTI = 1; +#endif + +class SyncCollectives { +public: + __aicore__ inline SyncCollectives() {} + + __aicore__ inline void Init(int rank, int rankSize, GM_ADDR *shareAddrs) + { + this->rank = rank; + this->rankSize = rankSize; + this->shareAddrs = shareAddrs; + this->blockIdx = GetBlockIdx(); + this->blockNum = GetBlockNum() * LCAL_BLOCK_NUM_MULTI; + segmentCount = GetBlockNum() * LCAL_BLOCK_NUM_MULTI * FLAG_UNIT_INT_NUM; + localSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]); + basicSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + GetBlockIdx() * FLAG_UNIT_INT_NUM; + blockOuterSyncAddr = (__gm__ int64_t*)(shareAddrs[rank]) + segmentCount + GetBlockIdx() * FLAG_UNIT_INT_NUM; + TPipe pipe; + pipe.InitBuffer(tBuf, GetBlockNum() * SYNC_UNIT_SIZE); + } + + __aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID) + { + int64_t v = MergeMagicWithValue(magic, value); + SetFlag(localSyncAddr + eventID * FLAG_UNIT_INT_NUM, v); + } + + __aicore__ inline void SetSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank) + { + int64_t v = MergeMagicWithValue(magic, value); + SetFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, v); + } + + __aicore__ inline int32_t CalEventIdByMulBlockNum(int32_t blockMultiplier, int32_t targetCoreId) + { + return (blockMultiplier * blockNum) + targetCoreId; + } + + __aicore__ inline void WaitSyncFlag(int32_t magic, int32_t value, int32_t eventID, int32_t rank, + int32_t breakCycle = 0) + { + int64_t v = MergeMagicWithValue(magic, value); + WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[rank]) + eventID * FLAG_UNIT_INT_NUM, 1, v, breakCycle); + } + + __aicore__ 
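+    /* Flag layout sketch: each block owns one 32-byte flag slot (FLAG_UNIT_INT_NUM int64s).
+       The first blockNum slots of a rank's share memory are the "inner" flags, the next
+       blockNum (at offset segmentCount) the "outer" ones. A flag packs
+       (magic << MAGIC_OFFSET) | eventID, so e.g. magic = 3, eventID = 1 yields
+       0x0000000300000001; the wait loops below compare against this merged value. */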
inline void SetInnerFlag(int32_t magic, int32_t eventID) + { + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag(basicSyncAddr, value); + } + __aicore__ inline void SetInnerFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock) + { + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag((__gm__ int64_t*)(shareAddrs[setRank]) + setBlock * FLAG_UNIT_INT_NUM, value); + } + + __aicore__ inline void WaitInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock) + { + int64_t value = MergeMagicWithValue(magic, eventID); + WaitOneRankPartFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM, 1, value); + } + + __aicore__ inline void WaitRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank) + { + int64_t value = MergeMagicWithValue(magic, eventID); + WaitOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value); + } + + __aicore__ inline bool CheckRankInnerFlag(int32_t magic, int32_t eventID, int64_t waitRank) + { + int64_t value = MergeMagicWithValue(magic, eventID); + return CheckOneRankAllFlag((__gm__ int64_t*)(shareAddrs[waitRank]), value); + } + + __aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID) + { + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag(blockOuterSyncAddr, value); + } + + __aicore__ inline void SetOuterFlag(int32_t magic, int32_t eventID, int64_t setRank, int64_t setBlock) + { + __gm__ int64_t *flagAddr = GetOuterFlagAddr(setRank, setBlock); + int64_t value = MergeMagicWithValue(magic, eventID); + SetFlag(flagAddr, value); + } + + __aicore__ inline void WaitOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, int64_t waitBlock) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr = GetOuterFlagAddr(waitRank, waitBlock); + WaitOneRankPartFlag(flagAddr, 1, value); + } + + __aicore__ inline void WaitOneRankOuterFlag(int32_t magic, int32_t eventID, int64_t rank) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; + flagAddr = GetOuterFlagAddr(rank, 0); + WaitOneRankPartFlag(flagAddr, blockNum, value); + } + __aicore__ inline void WaitAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, int64_t flagNum) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; + int waitRank; + for (auto r = 0; r < rankSize; ++r) { + waitRank = (rank + r) % rankSize; + flagAddr = GetOuterFlagAddr(waitRank, startBlock); + WaitOneRankPartFlag(flagAddr, flagNum, value); + } + } + + __aicore__ inline bool CheckAllRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t startBlock, + int64_t flagNum) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; + int waitRank; + for (auto r = 0; r < rankSize; ++r) { + waitRank = (rank + r) % rankSize; + flagAddr = GetOuterFlagAddr(waitRank, startBlock); + if (!CheckOneRankPartFlag(flagAddr, flagNum, value)) { + return false; + } + } + return true; + } + + __aicore__ inline void WaitAllRankOuterFlag(int32_t magic, int32_t eventID) + { + WaitAllRankPartOuterFlag(magic, eventID, 0, blockNum); + } + + __aicore__ inline bool CheckAllRankOuterFlag(int32_t magic, int32_t eventID) + { + return CheckAllRankPartOuterFlag(magic, eventID, 0, blockNum); + } + + __aicore__ inline void SetFlag(__gm__ int64_t* setAddr, int64_t setValue) + { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + GlobalTensor globalSet; + 
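+        /* Sketch of the flag publish: the 64-bit value is staged in UB and written out as one
+           full 32-byte flag unit via DataCopy; the surrounding Set/WaitFlag pairs order the
+           scalar store, the MTE3 copy to GM, and any in-flight reads of the same slot. */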
globalSet.SetGlobalBuffer(setAddr, FLAG_UNIT_INT_NUM); + LocalTensor localSet = tBuf.GetWithOffset(1, 0); + localSet.SetValue(0, setValue); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + DataCopy(globalSet, localSet, FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + tBuf.FreeTensor(localSet); + } + + __aicore__ inline void WaitFlag(__gm__ int64_t* waitAddr, int64_t waitValue) + { + WaitOneRankPartFlag(waitAddr, 1, waitValue); + } + + __aicore__ inline int64_t GetFlag(__gm__ int64_t* waitAddr) + { + GlobalTensor globalWait; + globalWait.SetGlobalBuffer(waitAddr, FLAG_UNIT_INT_NUM); + LocalTensor localWait = tBuf.GetWithOffset(1, 0); + DataCopy(localWait, globalWait, FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + int64_t res = localWait.GetValue(0); + tBuf.FreeTensor(localWait); + return res; + } + + __aicore__ inline void WaitOneRankPartOuterFlag(int32_t magic, int32_t eventID, int64_t waitRank, + int64_t startBlock, int64_t flagNum) + { + int64_t value = MergeMagicWithValue(magic, eventID); + __gm__ int64_t *flagAddr; + flagAddr = GetOuterFlagAddr(waitRank, startBlock); + WaitOneRankPartFlag(flagAddr, flagNum, value); + } + + __aicore__ inline int64_t GetInnerFlag(int64_t waitRank, int64_t waitBlock) + { + return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + waitBlock * FLAG_UNIT_INT_NUM); + } + + __aicore__ inline int64_t GetOuterFlag(int64_t waitRank, int64_t waitBlock) + { + return GetFlag((__gm__ int64_t*)(shareAddrs[waitRank]) + segmentCount + waitBlock * FLAG_UNIT_INT_NUM); + } + +private: + __aicore__ inline int64_t MergeMagicWithValue(int32_t magic, int32_t value) + { + return (static_cast(magic) << MAGIC_OFFSET) | static_cast(value); + } + + __aicore__ inline __gm__ int64_t* GetInnerFlagAddr(int64_t flagRank, int64_t flagBlock) + { + return (__gm__ int64_t*)(shareAddrs[flagRank]) + flagBlock * FLAG_UNIT_INT_NUM; + } + + __aicore__ inline __gm__ int64_t* GetOuterFlagAddr(int64_t flagRank, int64_t flagBlock) + { + return (__gm__ int64_t*)(shareAddrs[flagRank]) + segmentCount + flagBlock * FLAG_UNIT_INT_NUM; + } + + __aicore__ inline void WaitOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue, + int32_t breakCycle = 0) + { + GlobalTensor globalWait; + globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM); + LocalTensor localWait = tBuf.GetWithOffset(flagNum * FLAG_UNIT_INT_NUM, 0); + bool isSync = true; + do { + if (breakCycle > 0) { + int64_t systemCycleBefore = AscendC::GetSystemCycle(); + int64_t systemCycleAfter = AscendC::GetSystemCycle(); + while (systemCycleAfter - systemCycleBefore < breakCycle) { + systemCycleAfter = AscendC::GetSystemCycle(); + }; + } + DataCopy(localWait, globalWait, flagNum * FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + isSync = true; + for (auto i = 0; i < flagNum; ++i) { + int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM); + if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) { + isSync = false; + break; + } + } + } while (!isSync); + tBuf.FreeTensor(localWait); + } + + __aicore__ inline void WaitOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue) + { + WaitOneRankPartFlag(waitAddr, blockNum, checkValue); + } + + __aicore__ inline bool CheckOneRankPartFlag(__gm__ int64_t* waitAddr, int64_t flagNum, int64_t checkValue) + { + GlobalTensor globalWait; + globalWait.SetGlobalBuffer(waitAddr, flagNum * FLAG_UNIT_INT_NUM); + LocalTensor localWait = 
tBuf.GetWithOffset(flagNum * FLAG_UNIT_INT_NUM, 0); + DataCopy(localWait, globalWait, flagNum * FLAG_UNIT_INT_NUM); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + bool isSync = true; + for (auto i = 0; i < flagNum; ++i) { + int64_t v = localWait.GetValue(i * FLAG_UNIT_INT_NUM); + if ((v & MAGIC_MASK) != (checkValue & MAGIC_MASK) || v < checkValue) { + isSync = false; + break; + } + } + tBuf.FreeTensor(localWait); + return isSync; + } + + __aicore__ inline bool CheckOneRankAllFlag(__gm__ int64_t* waitAddr, int64_t checkValue) + { + return CheckOneRankPartFlag(waitAddr, blockNum, checkValue); + } + + int rank; + int rankSize; + int blockIdx; + int blockNum; + GM_ADDR *shareAddrs; + int64_t segmentCount; + __gm__ int64_t* localSyncAddr; + __gm__ int64_t* basicSyncAddr; + __gm__ int64_t* blockOuterSyncAddr; + TBuf tBuf; +}; + +#endif // LCCL_SYNC _H \ No newline at end of file diff --git a/comm/lcal/src/ccl_kernel_args.h b/comm/lcal/src/ccl_kernel_args.h new file mode 100644 index 0000000000000000000000000000000000000000..18fdffe7905322e8b34621182ce835869dbd6f33 --- /dev/null +++ b/comm/lcal/src/ccl_kernel_args.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef LCAL_CCL_KERNEL_ARGS_H +#define LCAL_CCL_KERNEL_ARGS_H + +#include "lcal_types.h" +#include "comm_args.h" + +namespace Lcal { +struct AscendCCLKernelArgs { + const void *input = nullptr; + const void *output = nullptr; + const void *commArgsPtr = nullptr; + int64_t count = 0; + int64_t magic = 0; + int op = 0; + int root = 0; + int cycleCount = 0; + const void *scale = nullptr; + int64_t scaleCount = 0; + const void *offset = nullptr; +}; + +struct CCLGatherArgs { + const void *embTable = nullptr; + const void *lookup = nullptr; + const void *revData = nullptr; + int64_t lookupLen = 0; + int64_t embTableLen = 0; + int64_t embTableDim = 0; +}; +} +#endif // LCAL_CCL_KERNEL_ARGS_H diff --git a/comm/lcal/src/coc_kernel_args.cpp b/comm/lcal/src/coc_kernel_args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89a30187fccb0861f83b463472278e1066bf5807 --- /dev/null +++ b/comm/lcal/src/coc_kernel_args.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "coc_kernel_args.h" +#include +#include +#include +#include +#include "tiling.h" +#include "lcal_internal.h" +using namespace Mki; + +namespace Lcal { +int CoCKernelArgs::SetFFTSAddr() +{ + uint32_t fftsLen; + int error = MkiRtGetC2cCtrlAddr(&fftsAddr, &fftsLen); + if (error != MKIRT_SUCCESS) { + MKI_LOG(ERROR) << "MkiRtGetC2cCtrlAddr err"; + return LCAL_ERROR_MKIRT; + } + return LCAL_SUCCESS; +} + +void CoCKernelArgs::SetInputPkgArgs(CoCInputPkg &inputPkg) +{ + matrixA = inputPkg.matrixA; + matrixB = inputPkg.matrixB; + bias = inputPkg.bias; + gamma = inputPkg.gamma; + dequantScale = inputPkg.dequantScale; + dequantOffset = inputPkg.dequantOffset; + quantScale = inputPkg.quantScale; + quantOffset = inputPkg.quantOffset; + numLocalTokensPerExpertPtr = inputPkg.num_local_tokens_per_expert; + numGlobalTokensPerLocalExpertPtr = inputPkg.num_global_tokens_per_local_expert; + globalTokensPerLocalExpertMatrixPtr = inputPkg.global_tokens_per_expert_matrix; +} + +void CoCKernelArgs::SetOutputPkgArgs(CoCOutputPkg &outputPkg) +{ + output = outputPkg.output; + midOutput = outputPkg.midOutput; +} + +void CoCKernelArgs::SetWorkspacePtrArg(void *workspacePtr) +{ + workspace = workspacePtr; +} + +void CoCKernelArgs::SetParamDescArgs(const CoCParamDesc ¶mDesc) +{ + cocKernelParam.quantInfo = paramDesc.quantInfo; + cocKernelParam.twoDimTPInfo = paramDesc.twoDimTPInfo; + cocKernelParam.postInfo = paramDesc.postInfo; + cocKernelParam.weightNz = paramDesc.mmInfo.weightNz; + cocKernelParam.moeInfo = paramDesc.moeInfo; +} + +void CoCKernelArgs::SetCommArgs(const LcalComm &comm) +{ + commArgsPtr = comm.GetCommArgsPtr(); +} + +void CoCKernelArgs::SetCoCTilingDataArgs(const CoCTilingData &tilingData) +{ + pCocTiling = &(cocKernelParam.cocTilingData); + cocKernelParam.cocTilingData = tilingData; +} + +std::string CoCKernelArgs::ParamToString() +{ + std::string quantInfoString = "[QuantInfo]: dequantGranularity=" + + std::to_string(cocKernelParam.quantInfo.dequantGranularity) + "\n"; + auto moeInfo = cocKernelParam.moeInfo; + std::string moeInfoString = + std::string("[MoeInfo]: local_expert_nums=") + std::to_string(moeInfo.local_expert_nums) + + ", EP=" + std::to_string(static_cast(moeInfo.EP)) + + ", TP=" + std::to_string(static_cast(moeInfo.TP)) + + ", maxOutputSize=" + std::to_string(moeInfo.maxOutputSize) + + ", isMoe=" + std::to_string(static_cast(moeInfo.isMoe)) + "\n"; + std::string weightNzInfoString = "[weightNz]: weightNz=" + + std::to_string(cocKernelParam.weightNz) + "\n"; + std::string tilingInfoString = cocKernelParam.cocTilingData.ToString(); + return quantInfoString + moeInfoString + weightNzInfoString + tilingInfoString; +} +} diff --git a/comm/lcal/src/coc_kernel_args.h b/comm/lcal/src/coc_kernel_args.h new file mode 100644 index 0000000000000000000000000000000000000000..91ce2cd449531da530ba7aeaf55dc81372edb71a --- /dev/null +++ b/comm/lcal/src/coc_kernel_args.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#ifndef LCAL_COC_KERNEL_ARGS_H +#define LCAL_COC_KERNEL_ARGS_H + +#include +#include "tiling_args.h" +#include "lcal_comm.h" +#include "lcoc_args.h" + +namespace Lcal { +struct CoCKernelArgs { + void *matrixA = nullptr; + void *matrixB = nullptr; + void *bias = nullptr; + void *gamma = nullptr; + void *output = nullptr; + void *midOutput = nullptr; + void *workspace = nullptr; + void *dequantScale = nullptr; + void *dequantOffset = nullptr; + void *quantScale = nullptr; + void *quantOffset = nullptr; + void *commArgsPtr = nullptr; + uint64_t fftsAddr = 0; + + void *numLocalTokensPerExpertPtr = nullptr; + void *numGlobalTokensPerLocalExpertPtr = nullptr; + void *globalTokensPerLocalExpertMatrixPtr = nullptr; + CoCTilingData *pCocTiling = nullptr; + CoCKernelParam cocKernelParam = {}; + int SetFFTSAddr(); + void SetInputPkgArgs(CoCInputPkg &inputPkg); + void SetOutputPkgArgs(CoCOutputPkg &outputPkg); + void SetWorkspacePtrArg(void *workspacePtr); + void SetParamDescArgs(const CoCParamDesc ¶mDesc); + void SetCommArgs(const LcalComm &comm); + void SetCoCTilingDataArgs(const CoCTilingData &tilingData); + std::string ParamToString(); +}; + +} + +#endif // LCAL_COC_KERNEL_ARGS_H diff --git a/comm/lcal/src/kernels/CMakeLists.txt b/comm/lcal/src/kernels/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..598c35b372003936a28deedc245f79270ebbe64c --- /dev/null +++ b/comm/lcal/src/kernels/CMakeLists.txt @@ -0,0 +1,50 @@ +# +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. 
+# +include(../ascendc.cmake) +set(OP_NAMES pure_matmul matmul_allreduce matmul_reduce_scatter allgather_matmul allgather_matmul_reduce_scatter alltoallv_allgather_matmul matmul_reduce_scatter_alltoallv) + +file(GLOB KERNEL_FILES *.cpp) +set_source_files_properties(${KERNEL_FILES} PROPERTIES LANGUAGE CCE) +file(GLOB KERNEL_FILES2 *.cce) +set_source_files_properties(${KERNEL_FILES2} PROPERTIES LANGUAGE CCE) + +foreach(OP_NAME IN LISTS OP_NAMES) + add_library(lcoc_${OP_NAME}_aic_obj OBJECT coc_${OP_NAME}.cce) + target_compile_options(lcoc_${OP_NAME}_aic_obj PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIC_ARCH} + ) + add_library(lcoc_${OP_NAME}_aiv_obj OBJECT coc_${OP_NAME}.cce) + target_compile_options(lcoc_${OP_NAME}_aiv_obj PRIVATE + ${CCE_COMPILE_OPTION} + --cce-aicore-arch=${AIV_ARCH} + ) + add_custom_target(${OP_NAME}_o + COMMAND ${CMAKE_CCE_LINKER} -m aicorelinux -Ttext=0 + "CMakeFiles/lcoc_${OP_NAME}_aic_obj.dir/coc_${OP_NAME}*.o" + "CMakeFiles/lcoc_${OP_NAME}_aiv_obj.dir/coc_${OP_NAME}*.o" + ${SANITIZER_DEPEND_LIBS} + --static -o "lcal_coc_${OP_NAME}.o" --allow-multiple-definition + COMMAND truncate -c -s ${LCAL_1OP_BIN_SIZE} "lcal_coc_${OP_NAME}.o" + ) +endforeach() +# 生成文件名列表,每个都带有 .o 后缀 +set(OUTPUT_FILES "") +foreach(OP_NAME IN LISTS OP_NAMES) + list(APPEND OUTPUT_FILES "lcal_coc_${OP_NAME}.o") +endforeach() + +add_custom_target(lcoc_op + COMMAND cat ${OUTPUT_FILES} > lcoc_op.o +) +foreach(OP_NAME IN LISTS OP_NAMES) + add_dependencies(${OP_NAME}_o lcoc_${OP_NAME}_aic_obj lcoc_${OP_NAME}_aiv_obj) + add_dependencies(lcoc_op ${OP_NAME}_o) +endforeach() \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_add_bias_runner.cce b/comm/lcal/src/kernels/coc_add_bias_runner.cce new file mode 100644 index 0000000000000000000000000000000000000000..f79db13de6931f41c15bff7d1f892a8431c58777 --- /dev/null +++ b/comm/lcal/src/kernels/coc_add_bias_runner.cce @@ -0,0 +1,343 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef __COC_ADD_BIAS_RUNNER__ +#define __COC_ADD_BIAS_RUNNER__ + +#ifdef __DAV_C220_VEC__ + +#include +#include "coc_internal.cce" + +enum class BiasMode { ADD = 0, MOVE, ATOMIC_ADD }; + +template +class BaseSerialBiasAdder { +public: + __aicore__ explicit BaseSerialBiasAdder() = default; + + inline __aicore__ void SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_FUN()) + { + this->gm_out = reinterpret_cast<__gm__ OutputDtype *>(gm_out); + this->gm_bias = reinterpret_cast<__gm__ OutputDtype *>(gm_bias); + + this->batch_size = batch_size; + this->m = m; + this->n = n; + + int32_t align_core_num = get_block_num() * get_subblockdim(); + int32_t align_core_idx = get_block_idx() * get_subblockdim() + get_subblockid(); + + if constexpr (MODE == BiasMode::MOVE || MODE == BiasMode::ATOMIC_ADD) { + max_len = Block32B::AlignDown(MAX_UB_BUFF / sizeof(OutputDtype)); + } else if constexpr (MODE == BiasMode::ADD) { + max_len = Block32B::AlignDown(MAX_UB_BUFF / (sizeof(OutputDtype) * 3)); + } + + int32_t n_round = Block32B::AlignUp(n); + m_per_loop = (n_round <= max_len) ? (max_len / n_round) : 1; + n_per_loop = (n_round <= max_len) ? n : max_len; + + int32_t m_per_core_base = m / align_core_num; + int32_t m_remainder = m % align_core_num; + int32_t m_offset_base = align_core_idx * m_per_core_base; + if (align_core_idx < m_remainder) { + m_this_core = m_per_core_base + 1; + m_offset_this_core = m_offset_base + align_core_idx; + } else { + m_this_core = m_per_core_base; + m_offset_this_core = m_offset_base + m_remainder; + } + } + + inline __aicore__ void Run() + { + if constexpr (MODE == BiasMode::ADD) { + AddBias(); + } else if constexpr (MODE == BiasMode::MOVE) { + MoveBias(); + } else if constexpr (MODE == BiasMode::ATOMIC_ADD) { + SetAtomicAdd(); + PipeBarrier(); + MoveBias(); + SetAtomicNone(); + PipeBarrier(); + } + } + + inline __aicore__ void Barrier() + { + FFTSCrossCoreSync(0, AIV_FINISH_ADD_BIAS_FLAG_ID); + WaitEvent(AIV_FINISH_ADD_BIAS_FLAG_ID); + } + +private: + inline __aicore__ void AddBias() + { + if constexpr (MODE != BiasMode::ADD) { + return; + } + + auto ub_bias = reinterpret_cast<__ubuf__ OutputDtype *>((uintptr_t)0); + auto ub_out1 = reinterpret_cast<__ubuf__ OutputDtype *>((uintptr_t)(max_len * sizeof(OutputDtype))); + auto ub_out2 = reinterpret_cast<__ubuf__ OutputDtype *>((uintptr_t)(max_len * sizeof(OutputDtype) * 2)); + bool ping = true; + + for (int32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + for (int32_t n_complete = 0, n_this_loop = n_per_loop; n_complete < n; n_complete += n_this_loop) { + n_this_loop = (n_complete + n_this_loop > n) ? 
(n - n_complete) : n_this_loop; + + // MTE2: ub_bias <- gm_bias + CopyGmToUbufAlign(ub_bias, gm_bias + n_complete, 1, n_this_loop, 0); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + BroadcastBiasforbias(ub_bias, n_this_loop); + + PipeBarrier(); + + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + + ProcessMLoop(n_complete, n_this_loop, ub_out1, ub_out2, ping, ub_bias); + + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + ping = !ping; + } + } + } + + inline __aicore__ void BroadcastBiasforbias(__ubuf__ OutputDtype *ub_bias, int32_t n_this_loop) + { + for (int32_t row_idx = 1; row_idx < m_per_loop; ++row_idx) { + CopyUB2UB(ub_bias + row_idx * Block32B::AlignUp(n), ub_bias, 0, 1, + Block32B::Count(n_this_loop), 0, 0); + } + } + + inline __aicore__ void ProcessMLoop(int32_t n_complete, int32_t n_this_loop, __ubuf__ OutputDtype *ub_out1, + __ubuf__ OutputDtype *ub_out2, bool ping, __ubuf__ OutputDtype *ub_bias) + { + for (int32_t m_complete = 0, m_this_loop = m_per_loop; m_complete < m_this_core; m_complete += m_this_loop) { + m_this_loop = (m_complete + m_this_loop > m_this_core) ? (m_this_core - m_complete) : m_this_loop; + + auto ub_out = ping ? ub_out1 : ub_out2; + auto event_id = ping ? EVENT_ID1 : EVENT_ID2; + int32_t out_offset = (m_offset_this_core + m_complete) * n + n_complete; + + WaitFlag(event_id); + + // MTE2: ub_out <- gm_out + CopyGmToUbufAlign(ub_out, gm_out + out_offset, m_this_loop, n_this_loop, n - n_this_loop); + + SetFlag(event_id); + WaitFlag(event_id); + + // V: ub_out <- ub_out + ub_bias + AddBiasToOutput(ub_out, ub_bias, m_this_loop, n_this_loop); + + SetFlag(event_id); + WaitFlag(event_id); + + // MTE3: gm_out <- ub_out + CopyUbufToGmAlign(gm_out + out_offset, ub_out, m_this_loop, n_this_loop, n - n_this_loop); + + SetFlag(event_id); + } + } + + inline __aicore__ void AddBiasToOutput(__ubuf__ OutputDtype *ub_out, __ubuf__ OutputDtype *ub_bias, + int32_t m_this_loop, int32_t n_this_loop) + { + int32_t n_blocks = m_this_loop * Block32B::Count(n_this_loop); + int32_t repeat_times = DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT); + uint8_t repeat = UINT8_MAX; + for (int32_t repeat_complete = 0; repeat_complete < repeat_times; repeat_complete += repeat) { + repeat = (repeat_complete + repeat > repeat_times) ? 
(repeat_times - repeat_complete) : repeat; + + int32_t vadd_offset = repeat_complete * Block256B::size; + Vadd(ub_out + vadd_offset, ub_out + vadd_offset, ub_bias + vadd_offset, repeat, 1, 1, 1, 8, 8, 8); + } + } + + inline __aicore__ void MoveBias() + { + if constexpr (MODE != BiasMode::MOVE && MODE != BiasMode::ATOMIC_ADD) { + return; + } + + auto ub_base = reinterpret_cast<__ubuf__ OutputDtype *>((uintptr_t)0); + + for (int32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + ProcessBias(ub_base); + } + } + + inline __aicore__ void ProcessBias(__ubuf__ OutputDtype *ub_base) + { + int32_t n_this_loop = n_per_loop; + for (int32_t n_complete = 0; n_complete < n; n_complete += n_this_loop) { + if (n_complete + n_this_loop > n) { + n_this_loop = n - n_complete; + } + + // MTE2: ub_base <- gm_bias + CopyGmToUbufAlign(ub_base, gm_bias + n_complete, 1, n_this_loop, 0); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + BroadcastBias(ub_base, n_this_loop); + + // MTE3: gm_out <- ub_base + CopyBiasToOutput(n_complete, n_this_loop, ub_base); + } + } + + inline __aicore__ void BroadcastBias(__ubuf__ OutputDtype *ub_base, int32_t n_this_loop) + { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + for (int32_t row_idx = 1; row_idx < m_per_loop; ++row_idx) { + CopyUB2UB(ub_base + row_idx * Block32B::AlignUp(n), ub_base, 0, 1, + Block32B::Count(n_this_loop), 0, 0); + } + } + + inline __aicore__ void CopyBiasToOutput(int32_t n_complete, int32_t n_this_loop, __ubuf__ OutputDtype *ub_base) + { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + int32_t m_this_loop = m_per_loop; + for (int32_t m_complete = 0; m_complete < m_this_core; m_complete += m_this_loop) { + if (m_complete + m_this_loop > m_this_core) { + m_this_loop = m_this_core - m_complete; + } + + CopyUbufToGmAlign(gm_out + (m_offset_this_core + m_complete) * n + n_complete, ub_base, m_this_loop, + n_this_loop, n - n_this_loop); + } + } + + __gm__ OutputDtype *gm_out; + __gm__ OutputDtype *gm_bias; + + int32_t batch_size; + int32_t m; + int32_t n; + + int32_t m_this_core; + int32_t m_offset_this_core; + + int32_t m_per_loop; + int32_t n_per_loop; + + int32_t max_len; + int32_t repeat_per_loop; +}; + +template +class PureMatmulBiasAdder { + static constexpr auto MODE = std::is_same::value ? BiasMode::ADD : BiasMode::ATOMIC_ADD; + +public: + __aicore__ explicit PureMatmulBiasAdder() = default; + + inline void __aicore__ SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_FUN()) + { + base_adder.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL()); + } + + inline void __aicore__ Run() + { + base_adder.Run(); + base_adder.Barrier(); + } + +private: + BaseSerialBiasAdder base_adder; +}; + +template +class MatmulAllReduceBiasAdder { +public: + __aicore__ explicit MatmulAllReduceBiasAdder() = default; + + inline void __aicore__ SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_FUN()) + { + base_adder.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL()); + } + + inline void __aicore__ Run() + { + base_adder.Run(); + base_adder.Barrier(); + } + +private: + BaseSerialBiasAdder base_adder; +}; + +template +class MatmulReduceScatterBiasAdder { + static constexpr auto MODE = std::is_same::value ? 
BiasMode::ADD : BiasMode::ATOMIC_ADD; + +public: + __aicore__ explicit MatmulReduceScatterBiasAdder() = default; + + inline void __aicore__ SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_FUN()) + { + m = m / rank_size; + base_adder.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL()); + } + + inline void __aicore__ Run() + { + base_adder.Run(); + base_adder.Barrier(); + } + +private: + BaseSerialBiasAdder base_adder; +}; + +template +class AllGatherMatmulBiasAdder { + static constexpr auto MODE = std::is_same::value ? BiasMode::ADD : BiasMode::ATOMIC_ADD; + +public: + __aicore__ explicit AllGatherMatmulBiasAdder() = default; + + inline void __aicore__ SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_FUN()) + { + m = m * rank_size; + base_adder.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL()); + } + + inline void __aicore__ Run() + { + base_adder.Run(); + base_adder.Barrier(); + } + +private: + BaseSerialBiasAdder base_adder; +}; + +#endif + +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_allgather.cce b/comm/lcal/src/kernels/coc_allgather.cce new file mode 100644 index 0000000000000000000000000000000000000000..e63f160d6b180a580c454f10726e18ad032e7364 --- /dev/null +++ b/comm/lcal/src/kernels/coc_allgather.cce @@ -0,0 +1,435 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifdef __DAV_C220_VEC__ +#include "coc_internal.cce" +#include "coc_comm_base.cce" +#include "kernel_operator.h" +using namespace AscendC; + +template // T: allgather type; MatType: matmul type +class AllGather : public CocCommBase { +public: + __aicore__ explicit AllGather() {}; + FORCE_INLINE_AICORE void SetArgs(COC_ARGS_FUN(T)) + { + CocCommBase::SetArgs(COC_ARGS_CALL()); + preprocessor.SetArgs(PP_MATMUL_AIV_PADDING_ARGS_CALL()); + if constexpr (HAVE_BIAS) { + add_bias_runner.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL()); + } + need_dequant = workspace_info.gm_accum; + if (need_dequant) { + serial_dequant_runner.SetArgs(reinterpret_cast<__gm__ bfloat16_t *>(gm_out), workspace_info, + reinterpret_cast<__gm__ int64_t *>(gm_dequant_scale), + reinterpret_cast<__gm__ int32_t *>(gm_dequant_offset), dequant_granularity, 1, m * rank_size, n); + } + m_align = Block512B::AlignUp(m); + k_align = Block512B::AlignUp(k); + n_align = Block512B::AlignUp(n); + AlignJudge(trans_a, trans_b, m, k, n, m_align, k_align, n_align, aligned_a, aligned_b); + this->gm_out = aligned_a ? 
reinterpret_cast<__gm__ T *>(workspace_info.gm_a_align) : gm_a; + gm_a_pingpong_size = m0 * k_align * p_value * rank_size; + cal_count = DivCeil(m_loop, p_value); + } + + + + FORCE_INLINE_AICORE void EndFlagsAndBias() + { + ResetIpcFlags(2); + + if (aiv_idx == 1 && core_idx < rank_size) { + CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[other_rank] + flag_offset + FLAG_ZERO_IDX, 0); + } + PipeBarrier(); + + if constexpr (HAVE_BIAS) { + add_bias_runner.Run(); + } + } + + FORCE_INLINE_AICORE void MoveResultFromSrcToDst(__gm__ T *gm_src, __gm__ T *gm_dst, + int32_t len) + { + SetFlag(EVENT_ID0); // MTE2等MTE3 + SetFlag(EVENT_ID1); // MTE2等MTE3 + MoveResultToDst(gm_src, gm_dst, len); + WaitFlag(EVENT_ID0); // MTE2等MTE3 + WaitFlag(EVENT_ID1); // MTE2等MTE3 + } + + FORCE_INLINE_AICORE void MoveResultToDst(__gm__ T *gm_src, __gm__ T *gm_dst, + int32_t len) + { + int32_t ping_pong_move_count = (len + max_ub_ping_pong_size - 1) / max_ub_ping_pong_size; + for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) { + int32_t actual_move_size = max_ub_ping_pong_size; + if (move_idx == ping_pong_move_count - 1) { + actual_move_size = len - move_idx * max_ub_ping_pong_size; + } + auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1; + auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1]; + WaitFlag(event_id); + CopyGmToUbuf(ub_buff_st, gm_src, 1, actual_move_size * sizeof(T) / 32, 0, 0); + SetFlag(event_id); + WaitFlag(event_id); + CopyUbufToGm(gm_dst, ub_buff_st, 1, actual_move_size * sizeof(T) / 32, 0, 0); + gm_src += max_ub_ping_pong_size; + gm_dst += max_ub_ping_pong_size; + SetFlag(event_id); + } + } + + + FORCE_INLINE_AICORE + void MoveToOtherRankWithSkip(__gm__ T *gm_src, int32_t rank_offset, int32_t len, + int32_t rank_st, int32_t skip_num, int32_t group_num, int32_t rank_scope) + { + int32_t ping_pong_move_count = (len + max_ub_ping_pong_size - 1) / max_ub_ping_pong_size; + for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) { + int32_t actual_move_size = max_ub_ping_pong_size; + if (move_idx == ping_pong_move_count - 1) { + actual_move_size = len - move_idx * max_ub_ping_pong_size; + } + int32_t block_len = actual_move_size * sizeof(T) / 32; + auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1; + auto ub_buff_st = (move_idx & 1) ? 
output_UB_T[0] : output_UB_T[1]; + WaitFlag(event_id); + CopyGmToUbuf(ub_buff_st, gm_src, 1, block_len, 0, 0); + SetFlag(event_id); + WaitFlag(event_id); + int32_t dst_rank = rank_st % rank_scope; + for (int32_t cycle_idx = 0; cycle_idx < group_num; ++cycle_idx) { + if (dst_rank != rank && dst_rank < rank_size) { + CopyUbufToGm(buff[dst_rank] + rank_offset, ub_buff_st, 1, block_len, 0, 0); + } + dst_rank = (dst_rank + skip_num) % rank_scope; + } + gm_src += max_ub_ping_pong_size; + rank_offset += max_ub_ping_pong_size; + SetFlag(event_id); + } + } + + FORCE_INLINE_AICORE + void MoveWithSplit(__gm__ T *gm_src, int32_t rank_offset, int32_t len) + { + int32_t data_split = DivCeil(len, len_per_loop); + int32_t data_block = len_per_loop; // 每份数据量 + int32_t rank_st = core_idx; + int32_t skip_num = comm_npu_split; + int32_t group_num = DivCeil(rank_size, comm_npu_split); + int32_t scope = comm_npu_split * group_num; + int32_t data_offset = -data_block; // 当前份数据的起始位置 + + if (is_91093) { // 卡间通信:91093只copy奇偶相同的卡 + rank_st = rank_st * A3_DIE_NUM + (rank % A3_DIE_NUM); + group_num = DivCeil(group_num, A3_DIE_NUM); + skip_num = skip_num * A3_DIE_NUM; + } + + SetFlag(EVENT_ID0); // MTE2等MTE3 + SetFlag(EVENT_ID1); // MTE2等MTE3 + for (int32_t data_block_idx = 0; data_block_idx < data_split; ++data_block_idx) { + data_offset += data_block; // 当前份数据的起始位置 + data_block = data_block_idx == data_split - 1 ? len - data_offset : data_block; // 当前份数据量 + int32_t num_per_core = DivCeil(data_block, comm_data_split); + + int32_t data_src = data_offset + (core_idx / comm_npu_split) * num_per_core; + int32_t data_len = data_block + data_offset - data_src; + data_len = data_len >= num_per_core ? num_per_core : data_len; + // npu 方向:一份数据先发送到所有目标卡,再发送下一份数据,以此类推 + if (comm_direct) { + MoveToOtherRankWithSkip(gm_src + data_src, rank_offset + data_src, data_len, + rank_st, comm_npu_split, group_num, scope); + continue; + } + // data len 方向:所有的数据先发送到目标卡0,再发送到目标卡1,以此类推 + int32_t dst_rank = rank_st % scope; + for (int32_t rank_group_idx = 0; rank_group_idx < group_num; ++rank_group_idx) { + if (dst_rank != rank && dst_rank < rank_size) { + MoveResultToDst(gm_src + data_src, buff[dst_rank] + rank_offset + data_src, data_len); + } + dst_rank = (dst_rank + comm_npu_split) % scope; + } + } + WaitFlag(EVENT_ID0); // MTE2等MTE3 + WaitFlag(EVENT_ID1); // MTE2等MTE3 + } + + FORCE_INLINE_AICORE void RunWithSplit() + { + // Padding + preprocessor.Run(); + + ResetIpcFlags(2); + + int64_t data_len = static_cast(m) * k_align; // 数据量 + int32_t num_per_rank_move = m0 * k_align * p_value; // 每轮搬运到其他卡的数据量 + int64_t src_offset = 0; // 当前份数据的起始位置 + int64_t rank_offset = rank * num_per_rank_move; + for (int32_t cal_idx = 0; cal_idx < cal_count + MAX_BLOCK_COUNT; ++cal_idx) { + uint64_t flag_idx = cal_idx % MAX_BLOCK_COUNT; + + if (cal_idx == cal_count - 1) { + num_per_rank_move = data_len - src_offset; + } + + // wait aic + if (cal_idx >= MAX_BLOCK_COUNT) { + WaitEvent(flag_idx); + } + // Step 1: AIV sync + SetAndWaitAivSync(flag_idx); + + if (cal_idx < cal_count) { + // Step 2: Rank sync + CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1); + // Step 3: AIV sync + SetAndWaitAivSync(flag_idx); + // Step 4: Move + if (aiv_idx == 0 && core_idx < core_count) { + int64_t gm_rank_offset = flag_idx * gm_a_pingpong_size + rank_offset; + MoveWithSplit(gm_out + src_offset, gm_rank_offset, num_per_rank_move); + src_offset += num_per_rank_move; + } + // Step 5: AIV Sync + SetAndWaitAivSync(flag_idx); + // Step 6: Rank Sync + CrossRankSyncV2(FLAG_ONE_IDX, cal_idx + 
1); + } + // aiv之间同步 + SetAndWaitAivSync(flag_idx); // 通信后aiv同步 + // 发送aic同步 + SetAicSync(flag_idx); + } + if (need_dequant) { + serial_dequant_runner.Run(); + } + EndFlagsAndBias(); + } + + FORCE_INLINE_AICORE void DataCopySio(int32_t cal_idx_sio, int32_t copy_len_sio) + { + if (cal_idx_sio < 0 || cal_idx_sio >= cal_count) { + return; + } + int32_t flag_idx_sio = cal_idx_sio % BLOCK_COUNT_3; + int32_t len_per_core = copy_len_sio / SIO_TOTAL_CORE_NUM; + int32_t sio_core_idx = aiv_idx * core_num + core_idx - core_count; + int32_t core_offset = sio_core_idx * len_per_core; + int64_t src_offset_sio = cal_idx_sio * p_value * m0 * k_align; + + if (sio_core_idx >= 0 && sio_core_idx < SIO_TOTAL_CORE_NUM) { + for (int32_t src_rank = rank % 2; src_rank < rank_size; src_rank += 2) { + int32_t sio_rank_offset = flag_idx_sio * gm_a_pingpong_size + src_rank * p_value * m0 * k_align; + __gm__ T *src_addr = buff[rank] + sio_rank_offset + core_offset; + if (src_rank == rank) { + src_addr = gm_out + src_offset_sio + core_offset; + } + MoveResultFromSrcToDst(src_addr, buff[rank ^ 1] + sio_rank_offset + core_offset, len_per_core); + } + } + } + + FORCE_INLINE_AICORE void RunWithSio() + { + // Padding + preprocessor.Run(); + + ResetIpcFlags(2); + int32_t copy_len_hccs = p_value * m0 * k_align; + int32_t copy_len_sio = p_value * m0 * k_align; + + for (int32_t cal_idx = 0; cal_idx < cal_count + BLOCK_COUNT_3; ++cal_idx) { + int32_t cal_idx_sio = cal_idx - 1; + uint64_t flag_idx = cal_idx % BLOCK_COUNT_3; + uint64_t flag_idx_sio = cal_idx_sio % BLOCK_COUNT_3; + int64_t src_offset = cal_idx * p_value * m0 * k_align; + int32_t rank_offset = flag_idx * gm_a_pingpong_size + rank * p_value * m0 * k_align; + + // 一次copy p_value * m0 行 + if (cal_idx == cal_count - 1) { + copy_len_hccs = (m - cal_idx * p_value * m0) * k_align; + } + + if (cal_idx_sio == cal_count - 1) { + copy_len_sio = (m - cal_idx_sio * p_value * m0) * k_align; + } + + // wait aic + if (cal_idx >= BLOCK_COUNT_3) { + WaitEvent(flag_idx); + } + // Step 1: AIV sync + SetAndWaitAivSync(flag_idx, BLOCK_COUNT_3); + + if (cal_idx < cal_count + 1) { + // Step 2: Rank sync + CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1); + SetAndWaitAivSync(flag_idx, BLOCK_COUNT_3); + } + + // HCCS part + if (cal_idx < cal_count && core_idx < core_count) { + // Step 4: Move Hccs + MoveWithSplit(gm_out + src_offset, rank_offset, copy_len_hccs); + } + // SIO part + DataCopySio(cal_idx_sio, copy_len_sio); + + if (cal_idx < cal_count + 1) { + // Step 5: AIV Sync + SetAndWaitAivSync(flag_idx, BLOCK_COUNT_3); + // Step 6: Rank Sync + CrossRankSyncV2(FLAG_ONE_IDX, cal_idx + 1); + } + // aiv之间同步 + SetAndWaitAivSync(flag_idx, BLOCK_COUNT_3); + + // 发送aic同步 + if (cal_idx >= 1) { + SetAicSync(flag_idx_sio); + } + } + EndFlagsAndBias(); + } + + FORCE_INLINE_AICORE void Run() + { + if (is_91093) { + RunWithSio(); + } else { + RunWithSplit(); + } + } + +public: + using CocCommBase::SetAicSync; + using CocCommBase::SetAndWaitAivSync; + using CocCommBase::SetBuffFlag; + using CocCommBase::SetBuffFlagByAdd; + using CocCommBase::CheckBuffFlag; + using CocCommBase::ResetIpcFlags; + using CocCommBase::CrossRankSyncV1; + using CocCommBase::CrossRankSyncV2; + using CocCommBase::buff; + using CocCommBase::gm_out; + using CocCommBase::ctrl_flags_UB; + using CocCommBase::output_UB_T; + using CocCommBase::batch_size; + using CocCommBase::m; + using CocCommBase::k; + using CocCommBase::n; + using CocCommBase::m0; + using CocCommBase::k0; + using CocCommBase::n0; + using CocCommBase::m_loop; + using 
CocCommBase::n_loop; + using CocCommBase::k_loop; + using CocCommBase::core_idx; + using CocCommBase::core_num; + using CocCommBase::rank; + using CocCommBase::rank_size; + using CocCommBase::tiling_key; + using CocCommBase::swizzl_direct; + using CocCommBase::swizzl_count; + using CocCommBase::trans_a; + using CocCommBase::trans_b; + using CocCommBase::is_int8; + using CocCommBase::is_91093; + using CocCommBase::p_value; + using CocCommBase::aiv_idx; + using CocCommBase::other_rank; + using CocCommBase::max_ub_single_dma_size; + using CocCommBase::max_ub_ping_pong_size; + using CocCommBase::dequant_granularity; + using CocCommBase::dequant_group_size; + using CocCommBase::quant_granularity; + using CocCommBase::quant_group_size; + using CocCommBase::workspace_info; + using CocCommBase::comm_npu_split; + using CocCommBase::comm_data_split; + using CocCommBase::comm_direct; + using CocCommBase::len_per_loop; + using CocCommBase::core_count; + using CocCommBase::weight_nz; + using CocCommBase::local_expert_nums; + using CocCommBase::is_moe; + using CocCommBase::is_moe_averaged; + using CocCommBase::is_alltoallvc; + using CocCommBase::is_deterministic; + using CocCommBase::EP; + using CocCommBase::TP; + using CocCommBase::flag_offset; + int32_t m_align; + int32_t k_align; + int32_t n_align; + int32_t aligned_a; + int32_t aligned_b; + int32_t cal_count; + int32_t gm_a_pingpong_size; + bool need_dequant; + Preprocessor preprocessor; + AllGatherMatmulBiasAdder add_bias_runner; + SerialDequantRunner serial_dequant_runner; +}; + +constexpr int32_t NO_BIAS_MASK4 = 0b000000 | 0b100000 | 0b010000 | 0b110000 | 0b001000 | 0b101000 | 0b011000 | 0b111000; +constexpr int32_t BIAS_MASK4 = 0b000010 | 0b100010 | 0b010010 | 0b110010 | 0b001010 | 0b101010 | 0b011010 | 0b111010; + +template +inline __aicore__ void CocAllGatherMatmulAiv(COC_ARGS_FUN(T)) +{ + // write + + + AllGather allgather_write_without_bias; + AllGather allgather_write_with_bias; + AllGather allgather_int8_write_without_bias; + AllGather allgather_int8_write_with_bias; + SetAtomicNone(); + SetMaskNorm(); + SetSyncBaseAddr((uint64_t)ffts_addr); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + + auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm); + auto cocTilingData = ¶->cocTilingData; + int32_t tiling_key = cocTilingData->tilingKey; + int32_t write_to_other_rank = cocTilingData->write2OtherRank; + // swizzl = 0 transa = 0 transb = 0 splitk = 0 bias = 0 int8 = 0 + switch (tiling_key) { + case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 : + case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 : + allgather_write_without_bias.SetArgs(COC_ARGS_CALL()); + allgather_write_without_bias.Run(); + break; + case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 : + case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 : + allgather_write_with_bias.SetArgs(COC_ARGS_CALL()); + allgather_write_with_bias.Run(); + break; + case 0b000100 : case 0b100100 : case 0b010100 : case 0b110100 : + case 0b001100 : case 0b101100 : case 0b011100 : case 0b111100 : + allgather_int8_write_without_bias.SetArgs(COC_ARGS_CALL_INT8()); + allgather_int8_write_without_bias.Run(); + break; + case 0b000110 : case 0b100110 : case 0b010110 : case 0b110110 : + case 0b001110 : case 0b101110 : case 0b011110 : case 0b111110 : + allgather_int8_write_with_bias.SetArgs(COC_ARGS_CALL_INT8()); + allgather_int8_write_with_bias.Run(); + break; + default : + break; + } + PipeBarrier(); +} + +#endif \ No newline at end of file diff --git 
+
+#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_allgather_matmul.cce b/comm/lcal/src/kernels/coc_allgather_matmul.cce new file mode 100644 index 0000000000000000000000000000000000000000..8e7812818e13467130298c4647bd78c5e7be6228 --- /dev/null +++ b/comm/lcal/src/kernels/coc_allgather_matmul.cce @@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifdef __CCE_KT_TEST__
+#define __aicore__
+#else
+#define __aicore__ [aicore]
+#endif
+
+#include "coc_ppmatmul_switch.cce"
+#include "coc_allgather.cce"
+#include "coc_allgather_v2.cce"
+
+#ifdef __DAV_C220_CUBE__
+// Matmul in LcalAllGatherMatmul
+#define COC_ALL_GATHER_MATMUL_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllGatherMatmul_##type##_mix_aic(COC_ARGS_FUN(type)) { \
+ CocPpmatmulSwitchAic(COC_ARGS_CALL()); \
+}
+
+// Matmul in LcalAllGatherMatmulV2
+#define COC_ALL_GATHER_MATMUL_V2_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllGatherMatmulV2_##type##_mix_aic(COC_ARGS_FUN(type)) { \
+ CocPpmatmulSwitchAic(COC_ARGS_CALL()); \
+}
+
+#elif __DAV_C220_VEC__
+// AllGather in LcalAllGatherMatmul
+#define COC_ALL_GATHER_MATMUL_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllGatherMatmul_##type##_mix_aiv(COC_ARGS_FUN(type)) { \
+ CocAllGatherMatmulAiv(COC_ARGS_CALL()); \
+}
+
+// AllGather in LcalAllGatherMatmulV2
+#define COC_ALL_GATHER_MATMUL_V2_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllGatherMatmulV2_##type##_mix_aiv(COC_ARGS_FUN(type)) { \
+ CocAllGatherMatmulV2Aiv(COC_ARGS_CALL()); \
+}
+
+#endif
+
+#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // 910B support bf16
+#define COC_TYPE_FUNC(fun) fun(float16_t);fun(bfloat16_t)
+
+COC_TYPE_FUNC(COC_ALL_GATHER_MATMUL_FUNC_AUTO_DEF);
+COC_TYPE_FUNC(COC_ALL_GATHER_MATMUL_V2_FUNC_AUTO_DEF);
+
+#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_allgather_matmul_reduce_scatter.cce b/comm/lcal/src/kernels/coc_allgather_matmul_reduce_scatter.cce new file mode 100644 index 0000000000000000000000000000000000000000..a630e6bdb444991d7a9be6234717aa559f2d1c89 --- /dev/null +++ b/comm/lcal/src/kernels/coc_allgather_matmul_reduce_scatter.cce @@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifdef __CCE_KT_TEST__
+#define __aicore__
+#else
+#define __aicore__ [aicore]
+#endif
+
+#include "coc_ppmatmul_switch.cce"
+#include "coc_allgather_reducescatter.cce"
+#ifdef __DAV_C220_CUBE__
+
+// Matmul in LcalAllGatherMatmulReduceScatter
+#define COC_ALL_GATHER_MATMUL_REDUCESCATTER_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllGatherMatmulReduceScatter_##type##_mix_aic(COC_ARGS_FUN(type)) { \
+ return CocPpmatmulSwitchAic(COC_ARGS_CALL()); \
+}
+
+#elif __DAV_C220_VEC__
+// Vector in AllGatherMatmulReduceScatter
+#define COC_ALL_GATHER_MATMUL_REDUCESCATTER_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllGatherMatmulReduceScatter_##type##_mix_aiv(COC_ARGS_FUN(type)) { \
+ return CocAllGatherMatmulReduceScatterAiv(COC_ARGS_CALL()); \
+}
+#endif
+
+#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // 910B support bf16
+#define COC_TYPE_FUNC(fun) fun(float16_t);fun(bfloat16_t)
+
+COC_TYPE_FUNC(COC_ALL_GATHER_MATMUL_REDUCESCATTER_FUNC_AUTO_DEF);
+#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_allgather_reducescatter.cce b/comm/lcal/src/kernels/coc_allgather_reducescatter.cce new file mode 100644 index 0000000000000000000000000000000000000000..84c1f80332913ecea9c283df37b0665ee3d50645 --- /dev/null +++ b/comm/lcal/src/kernels/coc_allgather_reducescatter.cce @@ -0,0 +1,501 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifdef __DAV_C220_VEC__
+#include "coc_internal.cce"
+#include "coc_comm_base.cce"
+#include "kernel_operator.h"
+using namespace AscendC;
+template
+class AllGatherReduceScatter : public CocCommBase {
+public:
+ __aicore__ explicit AllGatherReduceScatter() {};
+ FORCE_INLINE_AICORE void SetArgs(COC_ARGS_FUN(T)) {
+ CocCommBase::SetArgsForReduce(COC_ARGS_CALL());
+ preprocessor.SetArgs(PP_MATMUL_AIV_PADDING_ARGS_CALL());
+ if constexpr (HAVE_BIAS) {
+ add_bias_runner.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL());
+ }
+
+ m_align = (m + CUBE_MATRIX_SIZE - 1) / CUBE_MATRIX_SIZE * CUBE_MATRIX_SIZE;
+ k_align = (k + CUBE_MATRIX_SIZE - 1) / CUBE_MATRIX_SIZE * CUBE_MATRIX_SIZE;
+ n_align = (n + CUBE_MATRIX_SIZE - 1) / CUBE_MATRIX_SIZE * CUBE_MATRIX_SIZE;
+ AlignJudge(trans_a, trans_b, m, k, n, m_align, k_align, n_align, aligned_a, aligned_b);
+ this->gm_a = aligned_a ? reinterpret_cast<__gm__ T *>(workspace_info.gm_a_align) : gm_a;
+ // work out this rank's AllGather (ag) and ReduceScatter (rs) indices
+ if (inner_dim_is_Ag) {
+ this->rank_ag_idx = rank % ag_dim;
+ this->rank_rs_idx = rank / ag_dim;
+ this->other_rank_ag_idx = other_rank % ag_dim;
+ this->other_rank_rs_idx = other_rank / ag_dim;
+ } else {
+ this->rank_ag_idx = rank / rs_dim;
+ this->rank_rs_idx = rank % rs_dim;
+ this->other_rank_ag_idx = other_rank / rs_dim;
+ this->other_rank_rs_idx = other_rank % rs_dim;
+ }
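+
+ /*
+ How a flat rank id splits into the (ag_idx, rs_idx) pair used above: the inner
+ dimension is the fast-varying one, so consecutive ranks walk it first.
+ Standalone C++, for illustration only:
+
+ struct TwoDimIdx { int ag_idx; int rs_idx; };
+ TwoDimIdx Decompose(int rank, int ag_dim, int rs_dim, bool inner_dim_is_ag) {
+ if (inner_dim_is_ag) {
+ return { rank % ag_dim, rank / ag_dim }; // consecutive ranks walk the AG dim
+ }
+ return { rank / rs_dim, rank % rs_dim }; // consecutive ranks walk the RS dim
+ }
+ // e.g. 8 ranks, ag_dim = 4, rs_dim = 2, inner dim AG: rank 6 -> ag_idx 2, rs_idx 1
+ */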
+
+ twod_big_dim = ag_dim > rs_dim ? ag_dim : rs_dim;
+ gm_a_pingpong_size = m0 * k_align * p_value * twod_big_dim;
+ gm_c_pingpong_size = p_value * twod_big_dim * n_loop * m0 * n0;
+ m_loop_per_bigdim = DivCeil(m_loop * ag_dim, twod_big_dim);
+ m_per_bigdim = m * ag_dim / twod_big_dim;
+ comm_count = DivCeil(batch_size * m_loop_per_bigdim, p_value);
+ ag_part_dim = twod_big_dim / ag_dim;
+ rs_part_dim = twod_big_dim / rs_dim;
+
+ ag_comm_npu_split = comm_npu_split;
+ ag_comm_data_split = comm_data_split;
+ ag_len_per_loop = len_per_loop;
+ ag_comm_direct = comm_direct;
+
+ rs_comm_npu_split = extra_comm_npu_split;
+ rs_comm_data_split = extra_comm_data_split;
+ rs_len_per_loop = extra_len_per_loop;
+
+ ag_core_count = ag_comm_npu_split * ag_comm_data_split;
+ rs_core_count = rs_comm_npu_split * rs_comm_data_split;
+
+ ag_max_ub_ping_pong_size = (max_ub_single_dma_size / 2) / n0 * n0;
+ rs_max_ub_ping_pong_size = (extra_ub_move_num / 2) / n0 * n0;
+ }
+
+ FORCE_INLINE_AICORE void CopyGMToGM(__gm__ T* gm_src, __gm__ T* gm_dst, int32_t copy_size) {
+ auto ub0 = output_UB_T[0];
+ auto ub1 = output_UB_T[1];
+ int32_t interm_offset = 0;
+ for (int32_t move_idx = 0; interm_offset < copy_size; ++move_idx) {
+ uint32_t data_size = interm_offset + ag_max_ub_ping_pong_size < copy_size ? ag_max_ub_ping_pong_size : copy_size - interm_offset;
+ auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+ auto ub = (move_idx & 1) ? ub0 : ub1;
+ WaitFlag(event_id);
+ CopyGmToUbuf(ub, gm_src + interm_offset, 1, data_size * sizeof(T) / 32, 0, 0);
+ SetFlag(event_id);
+ WaitFlag(event_id);
+ CopyUbufToGm(gm_dst + interm_offset, ub, 1, data_size * sizeof(T) / 32, 0, 0);
+ SetFlag(event_id);
+ interm_offset += data_size;
+ }
+ }
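+
+ /*
+ The ping-pong pattern in CopyGMToGM above, reduced to its skeleton: two staging
+ buffers alternate per chunk so the next GM->UB load can overlap the previous
+ UB->GM store. Host-side sketch in plain C++ (memcpy stands in for the MTE
+ queues and event flags); illustrative only:
+
+ #include <algorithm>
+ #include <cstdint>
+ #include <cstring>
+ void StagedCopy(const uint8_t *src, uint8_t *dst, size_t total,
+ uint8_t *ub0, uint8_t *ub1, size_t ub_cap) {
+ size_t off = 0;
+ for (int idx = 0; off < total; ++idx) {
+ size_t len = std::min(ub_cap, total - off);
+ uint8_t *ub = (idx & 1) ? ub0 : ub1; // alternate staging buffer
+ std::memcpy(ub, src + off, len); // kernel: CopyGmToUbuf + events
+ std::memcpy(dst + off, ub, len); // kernel: CopyUbufToGm + events
+ off += len;
+ }
+ }
+ */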
+
+ FORCE_INLINE_AICORE
+ void MoveToOtherRankWithSkip(__gm__ T *gm_src, int32_t rank_offset, int32_t len,
+ int32_t rank_st, int32_t skip_num, int32_t group_num)
+ {
+ int32_t ping_pong_move_count = (len + ag_max_ub_ping_pong_size - 1) / ag_max_ub_ping_pong_size;
+ for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+ int32_t actual_move_size = ag_max_ub_ping_pong_size;
+ if (move_idx == ping_pong_move_count - 1) {
+ actual_move_size = len - move_idx * ag_max_ub_ping_pong_size;
+ }
+ int32_t block_len = actual_move_size * sizeof(T) / 32;
+ auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+ auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+ WaitFlag(event_id);
+ CopyGmToUbuf(ub_buff_st, gm_src, 1, block_len, 0, 0);
+ SetFlag(event_id);
+ WaitFlag(event_id);
+ int32_t dst_rank = rank_st % ag_dim;
+ for (int32_t cycle_idx = 0; cycle_idx < group_num; ++cycle_idx) {
+ int32_t real_rank;
+ if (inner_dim_is_Ag) {
+ real_rank = dst_rank + rank / ag_dim * ag_dim;
+ } else {
+ real_rank = dst_rank * rs_dim + rank % rs_dim;
+ }
+ if (real_rank != rank && dst_rank < ag_dim) {
+ CopyUbufToGm(buff[real_rank] + rank_offset, ub_buff_st, 1, block_len, 0, 0);
+ }
+ dst_rank = (dst_rank + skip_num) % ag_dim;
+ }
+ gm_src += ag_max_ub_ping_pong_size;
+ rank_offset += ag_max_ub_ping_pong_size;
+ SetFlag(event_id);
+ }
+ }
+
+ FORCE_INLINE_AICORE
+ void MoveWithSplit(__gm__ T *gm_src, int32_t rank_offset, int32_t len)
+ {
+ int32_t data_split = DivCeil(len, ag_len_per_loop);
+ int32_t data_block = ag_len_per_loop; // size of each chunk
+ int32_t group_num = ag_dim / ag_comm_npu_split;
+ int32_t data_offset = -data_block; // start offset of the current chunk
+ SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+ SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+ for (int32_t data_block_idx = 0; data_block_idx < data_split; ++data_block_idx) {
+ data_offset += data_block; // start offset of the current chunk
+ data_block = data_block_idx == data_split - 1 ? len - data_offset : data_block; // size of the current chunk
+ int32_t num_per_core = DivCeil(data_block, ag_comm_data_split);
+
+ int32_t data_src = data_offset + (core_idx / ag_comm_npu_split) * num_per_core;
+ int32_t data_len = data_block + data_offset - data_src;
+ data_len = data_len >= num_per_core ? num_per_core : data_len;
+ if (ag_comm_direct) {
+ MoveToOtherRankWithSkip(gm_src + data_src, rank_offset + data_src, data_len,
+ core_idx, ag_comm_npu_split, group_num);
+ continue;
+ }
+ // along the data direction: everything is sent to target rank 0 first, then to target rank 1, and so on
+ int32_t dst_rank = core_idx % ag_dim;
+ for (int32_t rank_group_idx = 0; rank_group_idx < group_num; ++rank_group_idx) {
+ int32_t real_rank;
+ if (inner_dim_is_Ag) {
+ real_rank = dst_rank + rank / ag_dim * ag_dim;
+ } else {
+ real_rank = dst_rank * rs_dim + rank % rs_dim;
+ }
+ if (real_rank != rank && dst_rank < ag_dim) {
+ CopyGMToGM(gm_src + data_src, buff[real_rank] + rank_offset + data_src, data_len);
+ }
+ dst_rank = (dst_rank + ag_comm_npu_split) % ag_dim;
+ }
+ }
+ WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+ WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+ }
+
+ FORCE_INLINE_AICORE int32_t GetRealCoreIdx(int32_t index, int32_t rank_per_core)
+ {
+ int32_t core_index = core_idx - ag_core_count;
+ int32_t core_rank_offset = (core_index / rs_comm_data_split) * rank_per_core;
+ int32_t rank_idx_rot = (index + core_index) % rank_per_core;
+ int32_t real_core_idx = core_rank_offset + rank_idx_rot;
+
+ return real_core_idx;
+ }
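+
+ /*
+ GetLenPerCore below clamps each core's slice against what is still left in the
+ current loop. The same arithmetic as a standalone helper plus one worked case
+ (plain C++, illustrative only):
+
+ int SliceLen(int rank_total, int loop_index, int core_in_split,
+ int len_per_loop, int data_split) {
+ int before = len_per_loop * data_split * loop_index; // consumed by earlier loops
+ int left = rank_total - before; // remaining in this loop
+ int my_off = core_in_split * len_per_loop; // this core's offset in the loop
+ if (my_off >= left) return 0; // nothing left for this core
+ return (my_off + len_per_loop > left) ? left - my_off : len_per_loop;
+ }
+ // rank_total = 100, len_per_loop = 16, data_split = 4, loop 1, core 2:
+ // before = 64, left = 36, my_off = 32 -> this core moves 4 elements
+ */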
+
+ FORCE_INLINE_AICORE void GetLenPerCore(int32_t rank_total, int32_t loop_index, int32_t &m_in_core, int32_t &buff_offset)
+ {
+ int32_t core_index = core_idx - ag_core_count;
+ int32_t before_core_offset = rs_len_per_loop * rs_comm_data_split * loop_index;
+ int32_t loop_total = rank_total - before_core_offset;
+ int32_t real_core_offset = core_index % rs_comm_data_split * rs_len_per_loop;
+
+ buff_offset = before_core_offset + real_core_offset;
+
+ m_in_core = (real_core_offset >= loop_total) ? 0 :
+ ((real_core_offset + rs_len_per_loop) > loop_total ?
+ loop_total - real_core_offset : rs_len_per_loop);
+ }
+
+ FORCE_INLINE_AICORE void FirstStepInOutWithSplit(int32_t rank_total, int32_t rank_buff_offset, int32_t comm_idx, int32_t flag_idx, int64_t out_part_offset)
+ {
+ SetAtomicAdd();
+ PipeBarrier();
+ SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+ SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+
+ int32_t rank_per_core = rs_dim / rs_comm_npu_split;
+ int32_t m_per_core = rank_total / rs_comm_data_split;
+ int32_t data_split_num = DivCeil(m_per_core, rs_len_per_loop);
+ for (int32_t loop_idx = 0; loop_idx < data_split_num; loop_idx++) {
+ int32_t m_in_core;
+ int32_t offset;
+ GetLenPerCore(rank_total, loop_idx, m_in_core, offset);
+
+ for (int32_t rank_idx = 0; rank_idx < rank_per_core; rank_idx++) {
+ int32_t real_rank_idx_tmp = GetRealCoreIdx(rank_idx, rank_per_core);
+ int32_t real_rank_idx;
+ if (inner_dim_is_Ag) {
+ real_rank_idx = real_rank_idx_tmp * ag_dim + rank % ag_dim;
+ } else {
+ real_rank_idx = real_rank_idx_tmp + rank / rs_dim * rs_dim;
+ }
+
+ if (real_rank_idx == rank)
+ continue;
+
+ FirstStepInOut(m_in_core, buff[real_rank_idx], rank_buff_offset, offset, comm_idx, flag_idx, out_part_offset);
+ }
+ }
+ WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+ WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+ SetFlag(EVENT_ID0); // Scalar waits for MTE3
+ WaitFlag(EVENT_ID0);
+ SetAtomicNone();
+ PipeBarrier();
+ }
+
+
+ FORCE_INLINE_AICORE void FirstStepInOut(int32_t mat_blocks_size, __gm__ T *input, int32_t gm_offset, int32_t offset, int32_t comm_idx, int32_t flag_idx, int64_t out_part_offset) {
+ int32_t ping_pong_move_count = DivCeil(mat_blocks_size, rs_max_ub_ping_pong_size); // max_ub_ping_pong_size is always a multiple of n0, but not necessarily of m0 * n0
+ for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+ int32_t actual_move_size = rs_max_ub_ping_pong_size;
+ if (move_idx == ping_pong_move_count - 1) {
+ actual_move_size = mat_blocks_size - move_idx * rs_max_ub_ping_pong_size;
+ }
+ auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+ auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+ WaitFlag(event_id);
+ // the matrix read from other ranks is laid out as consecutive small m0*n0 blocks; it has to be rearranged on the write side
+ CopyGmToUbuf(ub_buff_st, input + gm_offset + offset + move_idx * rs_max_ub_ping_pong_size, 1, actual_move_size * sizeof(T) / 32, 0, 0);
+ SetFlag(event_id);
+ WaitFlag(event_id);
+ int32_t move_num_offset = offset + move_idx * rs_max_ub_ping_pong_size;
+ auto ub_buff = ub_buff_st;
+ int32_t left_m = actual_move_size / n0;
+ while (left_m > 0) {
+ // work out the m and n indices of the write into this rank's output
+ int32_t loop_idx = (move_num_offset / (m0 * n0));
+ int32_t n_idx = loop_idx % n_loop;
+ int64_t m_idx = comm_idx * p_value + loop_idx / n_loop;
+ int32_t actual_m = (m_idx == (m_loop_per_bigdim - 1)) ? (m_per_bigdim - m_idx * m0) : m0;
+ int32_t actual_n = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
+ int32_t m_offset = (move_num_offset % (m0 * n0)) / n0; // row position of the current chunk's start inside the current block
+ int32_t actual_move_m;
+ if (m_offset >= actual_m) { // e.g. m0 = 128 but the last block only holds m = 120 valid rows
+ actual_move_m = m0 < m_offset + left_m ? m0 - m_offset : left_m;
+ // m0 - m_offset is the padded remainder of this block; skip over it
+ } else {
+ actual_move_m = actual_m < m_offset + left_m ? actual_m - m_offset : left_m;
+ // if left_m is large, finish this block now and copy the next block on the next pass;
+ // if left_m is small, copy only the left_m portion
+ int64_t out_buff_offset = (m_idx * m0 + m_offset) * n + n_idx * n0;
+ CopyUbufToGmUnknown(n % BLOCK_SIZE_16 == 0, gm_out + out_part_offset + out_buff_offset,
+ ub_buff, actual_move_m, actual_n * sizeof(T), (n0 - actual_n) * sizeof(T) / 32, (n - actual_n) * sizeof(T));
+ }
+ left_m -= actual_move_m;
+ move_num_offset += actual_move_m * n0;
+ ub_buff += actual_move_m * n0;
+ }
+ SetFlag(event_id);
+ }
+ }
+
+
+ FORCE_INLINE_AICORE void EndFlagsAndBias()
+ {
+ ResetIpcFlags(2);
+
+ if (aiv_idx == 1 && core_idx < rank_size) {
+ CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[other_rank] + flag_offset + FLAG_ZERO_IDX, 0);
+ }
+ PipeBarrier();
+
+ if constexpr (HAVE_BIAS) {
+ add_bias_runner.Run();
+ }
+ }
+
+ // p_value means different things in RS and AG: in RS, each core communicates once after p_value compute steps;
+ // in AG, a compute step runs after gathering p_value rows from every other rank.
+ // In 2D-TP, p_value has the same meaning as in AG.
+ FORCE_INLINE_AICORE void Run() {
+ // Padding
+ preprocessor.Run();
+
+ ResetIpcFlags(2);
+ PipeBarrier();
+ // twod_big_dim: total number of blocks moved per round in the 2D case, taken as the larger of the AG and RS dims
+ int32_t twod_big_dim = ag_dim > rs_dim ? ag_dim : rs_dim;
+ int64_t gm_a_pingpong_size = m0 * k_align * p_value * twod_big_dim;
+ // e.g. 2 * 4 * 8 * 128 * 256
+ int64_t gm_c_pingpong_size = p_value * twod_big_dim * n_loop * m0 * n0;
+ int32_t m_loop_per_bigdim = DivCeil(m_loop * ag_dim, twod_big_dim);
+ int64_t m_per_bigdim = m * ag_dim / twod_big_dim;
+ int32_t comm_count = DivCeil(m_loop_per_bigdim, p_value);
+ int32_t ag_m = p_value * m0;
+ int32_t rs_p_value = p_value;
+
+ for (int32_t comm_idx = 0; comm_idx < comm_count + MAX_BLOCK_COUNT; ++comm_idx) {
+ uint64_t flag_idx = comm_idx % MAX_BLOCK_COUNT;
+ int32_t commrs_idx = comm_idx - MAX_BLOCK_COUNT;
+ if (comm_idx == comm_count - 1) { // last allgather
+ ag_m = m_per_bigdim - (comm_count - 1) * p_value * m0;
+ }
+ if (commrs_idx == comm_count - 1) { // last reducescatter
+ rs_p_value = m_loop_per_bigdim - (comm_count - 1) * p_value;
+ }
+ // wait aic
+ if (commrs_idx >= 0) {
+ WaitEvent(flag_idx);
+ }
+
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx);
+ // in the AG+RS case, AG only communicates on a subset of the cores
+ // ranks that share this rank's RS index belong to the same AG communication domain
+ // the first two iterations run AG only; the last two run RS only
+
+ // first, write this rank's own flag
+
+ CrossRankSyncV1(FLAG_ZERO_IDX, comm_idx + 1);
+ SetAndWaitAivSync(flag_idx);
+ // AG part
+ if (comm_idx < comm_count && aiv_idx == 0 && core_idx < ag_comm_npu_split * ag_comm_data_split) {
+ // check whether the target rank's data is ready
+ // in AG, every rank copies ag_part_dim times
+ for (int32_t ag_part_idx = 0; ag_part_idx < ag_part_dim; ag_part_idx++) {
+ int64_t src_offset = comm_idx * p_value * m0 * k_align + ag_part_idx * m_per_bigdim * k_align;
+ int32_t bigdim_idx = rank_ag_idx * ag_part_dim + ag_part_idx;
+ int32_t rank_offset = flag_idx * gm_a_pingpong_size + bigdim_idx * p_value * m0 * k_align;
+ MoveWithSplit(gm_a + src_offset, rank_offset, ag_m * k_align);
+ }
+ }
+ // RS part
+ if (comm_idx >= MAX_BLOCK_COUNT && aiv_idx == 0 && core_idx >= ag_core_count && core_idx < ag_core_count + rs_core_count) {
+ for (int32_t rs_part_idx = 0; rs_part_idx < rs_part_dim; rs_part_idx++) {
+ int32_t bigdim_idx = rank_rs_idx * rs_part_dim + rs_part_idx;
+ int32_t rank_buff_offset = flag_idx * gm_c_pingpong_size + bigdim_idx * rs_p_value * m0 * n_loop * n0;
+ FirstStepInOutWithSplit(rs_p_value * m0 * n_loop * n0, LCAL_2DTP_C_OFFSET + rank_buff_offset, commrs_idx, flag_idx, m_per_bigdim * rs_part_idx * n);
+ }
+ }
+
+ SetAndWaitAivSync(flag_idx);
+ CrossRankSyncV2(FLAG_ONE_IDX, comm_idx + 1);
+
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx);
+
+ // signal the AIC
+ SetAicSync(flag_idx);
+ }
+
+ EndFlagsAndBias();
+ }
+public:
+ using CocCommBase::SetAicSync;
+ using CocCommBase::SetAndWaitAivSync;
+ using CocCommBase::SetBuffFlag;
+ using CocCommBase::SetBuffFlagByAdd;
+ using CocCommBase::CheckBuffFlag;
+ using CocCommBase::CrossRankSyncV1;
+ using CocCommBase::CrossRankSyncV2;
+ using CocCommBase::ResetIpcFlags;
+ using CocCommBase::buff;
+ using CocCommBase::gm_out;
+ using CocCommBase::ctrl_flags_UB;
+ using CocCommBase::output_UB_T;
+ using CocCommBase::batch_size;
+ using CocCommBase::m;
+ using CocCommBase::k;
+ using CocCommBase::n;
+ using CocCommBase::m0;
+ using CocCommBase::k0;
+ using CocCommBase::n0;
+ using CocCommBase::m_loop;
+ using CocCommBase::n_loop;
+ using CocCommBase::k_loop;
+ using CocCommBase::core_loop;
+ using CocCommBase::core_idx;
+ using CocCommBase::rank;
+ using CocCommBase::rank_size;
+ using CocCommBase::tiling_key;
+ using CocCommBase::swizzl_direct;
+ using CocCommBase::swizzl_count;
+ using CocCommBase::trans_a;
+ using CocCommBase::trans_b;
+ using CocCommBase::is_int8;
+ using CocCommBase::p_value;
+ using CocCommBase::aiv_idx;
+ using CocCommBase::other_rank;
+ using CocCommBase::max_ub_single_dma_size;
+ using CocCommBase::dequant_granularity;
+ using CocCommBase::dequant_group_size;
+ using CocCommBase::quant_granularity;
+ using CocCommBase::quant_group_size;
+ using CocCommBase::workspace_info;
+ using CocCommBase::ag_dim;
+ using CocCommBase::rs_dim;
+ using CocCommBase::inner_dim_is_Ag;
+ using CocCommBase::comm_npu_split;
+ using CocCommBase::comm_data_split;
+ using CocCommBase::comm_direct;
+ using CocCommBase::len_per_loop;
+ using CocCommBase::extra_comm_npu_split;
+ using CocCommBase::extra_comm_data_split;
+ using CocCommBase::extra_comm_direct;
+ using CocCommBase::extra_len_per_loop;
+ using CocCommBase::extra_ub_move_num;
+ using CocCommBase::weight_nz;
+ using CocCommBase::local_expert_nums;
+ using CocCommBase::is_moe;
+ using CocCommBase::is_moe_averaged;
+ using CocCommBase::is_alltoallvc;
+ using CocCommBase::is_deterministic;
+ using CocCommBase::EP;
+ using CocCommBase::TP;
+ using CocCommBase::flag_offset;
+ int32_t m_align;
+ int64_t k_align;
+ int32_t n_align;
+ int32_t aligned_a;
+ int32_t aligned_b;
+ int32_t comm_count;
+
+ int32_t ag_comm_npu_split;
+ int32_t ag_comm_data_split;
+ int32_t ag_len_per_loop;
+ int32_t ag_comm_direct;
+
+ int32_t rs_comm_npu_split;
+ int32_t rs_comm_data_split;
+ int32_t rs_len_per_loop;
+ int32_t rs_comm_direct;
+
+ int32_t ag_core_count;
+ int32_t rs_core_count;
+
+ int32_t ag_max_ub_ping_pong_size;
+ int32_t rs_max_ub_ping_pong_size;
+ __gm__ T *gm_a;
+
+ // this rank's AG and RS rank indices
+ // the first 8 cores each handle communication with one rank
+ int32_t rank_ag_idx;
+ int32_t rank_rs_idx;
+ // AG and RS rank indices of the peer rank this core communicates with
+ int32_t other_rank_ag_idx;
+ int32_t other_rank_rs_idx;
+ Preprocessor preprocessor;
+ AllGatherMatmulBiasAdder add_bias_runner;
+
+ int32_t twod_big_dim;
+ int64_t gm_a_pingpong_size;
+ int64_t gm_c_pingpong_size;
+ int32_t m_loop_per_bigdim;
+ int32_t m_per_bigdim;
+ int32_t ag_part_dim;
+ int32_t rs_part_dim;
+
+};
+
+template <typename T>
+inline __aicore__ void CocAllGatherMatmulReduceScatterAiv(COC_ARGS_FUN(T)) {
+ // write
+ AllGatherReduceScatter allgatherreducescatter_write_without_bias;
+ AllGatherReduceScatter allgatherreducescatter_write_with_bias;
+
+ SetAtomicNone();
+ SetMaskNormImpl();
+ SetSyncBaseAddr((uint64_t)ffts_addr);
+ SetVectorMask((uint64_t)-1, (uint64_t)-1);
+ auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm);
+ auto cocTilingData = &para->cocTilingData;
+ int32_t tiling_key = cocTilingData->tilingKey;
+ // swizzl = 0 transa = 0 transb = 0 splitk = 0 bias = 0 int8 = 0
+ switch (tiling_key) {
+ case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 :
+ case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 :
+ allgatherreducescatter_write_without_bias.SetArgs(COC_ARGS_CALL());
+ allgatherreducescatter_write_without_bias.Run();
+ break;
+ case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 :
+ case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 :
+ allgatherreducescatter_write_with_bias.SetArgs(COC_ARGS_CALL());
+ allgatherreducescatter_write_with_bias.Run();
+ break;
+ default :
+ break;
+ }
+
+ PipeBarrier();
+}
+
+#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_allgather_v2.cce b/comm/lcal/src/kernels/coc_allgather_v2.cce new file mode 100644 index 0000000000000000000000000000000000000000..76e8f42411044679d2d861885e9fb96952155095 --- /dev/null +++ b/comm/lcal/src/kernels/coc_allgather_v2.cce @@ -0,0 +1,367 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifdef __DAV_C220_VEC__
+#include "coc_internal.cce"
+#include "coc_comm_base.cce"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+template
+class AllGatherV2 : public AllGather {
+public:
+ __aicore__ explicit AllGatherV2(){};
+ FORCE_INLINE_AICORE void SetArgs(COC_ARGS_FUN(T))
+ {
+ AllGather::SetArgs(COC_ARGS_CALL());
+ max_move_m = max_ub_ping_pong_size > max_move_k ? max_ub_ping_pong_size / max_move_k : 1;
+ gm_allgather = gm_allgather_out;
+ }
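+
+ /*
+ MoveResultFromPeerMemToOut below reads rows padded out to k_align elements from
+ peer memory and writes compact rows of length k to the user output, dropping the
+ per-row padding via the source/destination strides. A scalar reference for the
+ same re-striding (plain C++, illustrative only):
+
+ template <typename T>
+ void Unpad(const T *src, T *dst, long rows, long k, long k_align) {
+ for (long r = 0; r < rows; ++r)
+ for (long c = 0; c < k; ++c)
+ dst[r * k + c] = src[r * k_align + c]; // skip the padded tail of each row
+ }
+ */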
+
+ FORCE_INLINE_AICORE void MoveResultFromPeerMemToOut(__gm__ T *gm_src, __gm__ T *gm_dst, int32_t actual_m)
+ {
+ int32_t ping_pong_move_count = (actual_m + max_move_m - 1) / max_move_m;
+ SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+ SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+ for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+ int32_t actual_move_m = max_move_m;
+ if (move_idx == ping_pong_move_count - 1) {
+ actual_move_m = actual_m - move_idx * max_move_m;
+ }
+ auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+ auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+ int32_t k_move_count = (k_align + max_move_k - 1) / max_move_k;
+ for (int32_t k_move_idx = 0; k_move_idx < k_move_count; ++k_move_idx) {
+ int32_t actual_k_move_num_in_peer_mem = max_move_k;
+ int32_t actual_k_move_num_in_out = max_move_k;
+ if (k_move_idx == k_move_count - 1) {
+ actual_k_move_num_in_peer_mem = k_align - k_move_idx * max_move_k;
+ actual_k_move_num_in_out = k - k_move_idx * max_move_k;
+ }
+ WaitFlag(event_id);
+ CopyGmToUbuf(ub_buff_st, gm_src + move_idx * max_move_m * k_align + k_move_idx * max_move_k,
+ actual_move_m, actual_k_move_num_in_peer_mem * sizeof(T) / 32,
+ (k_align - actual_k_move_num_in_peer_mem) * sizeof(T) / 32, 0);
+ SetFlag(event_id);
+ WaitFlag(event_id);
+ if (ALIGN) {
+ CopyUbufToGm(gm_dst + move_idx * max_move_m * k + k_move_idx * max_move_k, ub_buff_st,
+ actual_move_m, actual_k_move_num_in_out * sizeof(T) / 32,
+ (actual_k_move_num_in_peer_mem - actual_k_move_num_in_out) * sizeof(T) / 32,
+ (k - actual_k_move_num_in_out) * sizeof(T) / 32);
+ } else {
+ CopyUbufToGmAlignB16(gm_dst + move_idx * max_move_m * k + k_move_idx * max_move_k, ub_buff_st,
+ actual_move_m, actual_k_move_num_in_out * sizeof(T),
+ (actual_k_move_num_in_peer_mem - actual_k_move_num_in_out) * sizeof(T) / 32,
+ (k - actual_k_move_num_in_out) * sizeof(T));
+ }
+ SetFlag(event_id);
+ }
+ }
+ WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+ WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+ }
+
+ FORCE_INLINE_AICORE void Run()
+ {
+ // Padding
+ preprocessor.Run();
+
+ ResetIpcFlags(2);
+ PipeBarrier();
+
+ for (int32_t cal_idx = 0; cal_idx < cal_count + MAX_BLOCK_COUNT; ++cal_idx) {
+ uint64_t flag_idx = cal_idx % MAX_BLOCK_COUNT;
+ int32_t actual_m = p_value * m0;
+ if (cal_idx == cal_count - 1) {
+ actual_m = m - cal_idx * p_value * m0;
+ }
+ // wait aic
+ int32_t cal_done_idx = cal_idx - MAX_BLOCK_COUNT;
+ if (cal_done_idx >= 0) {
+ WaitEvent(flag_idx);
+ }
+
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx);
+ if (cal_idx < cal_count && aiv_idx == 0 && core_idx < rank_size) {
+ int64_t src_offset = cal_idx * p_value * m0 * k_align;
+ int32_t rank_offset = flag_idx * gm_a_pingpong_size + rank * p_value * m0 * k_align;
+ CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1);
+
+ if (other_rank != rank) {
+ MoveResultFromSrcToDst(gm_out + src_offset, buff[other_rank] + rank_offset, actual_m * k_align);
+ }
+ CrossRankSyncV2(FLAG_ONE_IDX, cal_idx + 1);
+ } else if (cal_idx > 0 && cal_idx < cal_count + 1 && aiv_idx == 1 && core_idx >= rank_size &&
+ core_idx < rank_size * 2) { // peermem to out
+ uint64_t s2_flag_idx = (cal_idx - 1) % MAX_BLOCK_COUNT;
+ int32_t s2_other_rank = core_idx - rank_size;
+ int64_t src_offset = (cal_idx - 1) * p_value * m0 * k_align;
+ int32_t other_rank_offset = s2_flag_idx * gm_a_pingpong_size + s2_other_rank * p_value * m0 * k_align;
+ int64_t dst_offset = s2_other_rank * static_cast<int64_t>(m) * k + (cal_idx - 1) * p_value * m0 * k;
+ int32_t s2_actual_m = p_value * m0;
+ if (cal_idx == cal_count) {
+ s2_actual_m = m - (cal_idx - 1) * p_value * m0;
+ }
+ if (s2_other_rank != rank) {
+ MoveResultFromPeerMemToOut(buff[rank] + other_rank_offset, gm_allgather + dst_offset, s2_actual_m);
+ } else {
+ MoveResultFromPeerMemToOut(gm_out + src_offset, gm_allgather + dst_offset, s2_actual_m);
+ }
+ }
+
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx);
+
+ // signal the AIC
+ SetAicSync(flag_idx);
+ }
+
+ EndFlagsAndBias();
+ }
+
+
+ FORCE_INLINE_AICORE void RunWithSplit()
+ {
+ // Padding
+ preprocessor.Run();
+
+ ResetIpcFlags(2);
+ PipeBarrier();
+
+ int64_t data_len = static_cast<int64_t>(m) * k_align; // total element count
+ int32_t num_per_rank_move = m0 * k0 * p_value * k_loop; // amount of data moved to the other ranks per round
+ int32_t core_count = comm_npu_split * comm_data_split; // number of cores used on each rank
+ int64_t src_offset = 0; // start offset of the current chunk
+ int64_t rank_offset = rank * num_per_rank_move;
+ for (int32_t cal_idx = 0; cal_idx < cal_count + MAX_BLOCK_COUNT; ++cal_idx) {
+ uint64_t flag_idx = cal_idx % MAX_BLOCK_COUNT;
+ if (cal_idx == cal_count - 1) {
+ num_per_rank_move = data_len - src_offset;
+ }
+
+ // wait aic
+ if (cal_idx >= MAX_BLOCK_COUNT) {
+ WaitEvent(flag_idx);
+ }
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx);
+ if (cal_idx < cal_count) {
+ CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1);
+ }
+ SetAndWaitAivSync(flag_idx);
+ if (cal_idx < cal_count && aiv_idx == 0 && core_idx < core_count) {
+ int64_t gm_rank_offset = flag_idx * gm_a_pingpong_size + rank_offset;
+ MoveWithSplit(gm_out + src_offset, gm_rank_offset, num_per_rank_move);
+ src_offset += num_per_rank_move;
+ } else if (cal_idx > 0 && cal_idx < cal_count + 1 && aiv_idx == 1 &&
+ core_idx >= core_count && core_idx < rank_size + core_count) { // peermem to out
+ // if there are not enough remaining cores, loop over them
+ int32_t other_core_num = get_block_num() - core_count; // number of remaining cores
+ int32_t cycle_num = (other_core_num + rank_size - 1) / other_core_num; // number of cycles
+ uint64_t s2_flag_idx = (cal_idx - 1) % MAX_BLOCK_COUNT;
+ int64_t src_offset = (cal_idx - 1) * p_value * m0 * k_align;
+ int32_t s2_actual_m = cal_idx == cal_count ? m - (cal_idx - 1) * p_value * m0 : p_value * m0;
+ for (int32_t cycle_idx = 0; cycle_idx < cycle_num; ++cycle_idx) {
+ int32_t s2_other_rank = core_idx - core_count + cycle_idx * other_core_num;
+ int32_t other_rank_offset = s2_flag_idx * gm_a_pingpong_size + s2_other_rank * p_value * m0 * k_align;
+ int64_t dst_offset = s2_other_rank * static_cast<int64_t>(m) * k + (cal_idx - 1) * p_value * m0 * k;
+ if (s2_other_rank >= rank_size) {
+ break;
+ }
+ if (s2_other_rank != rank) {
+ MoveResultFromPeerMemToOut(buff[rank] + other_rank_offset, gm_allgather + dst_offset, s2_actual_m);
+ } else {
+ MoveResultFromPeerMemToOut(gm_out + src_offset, gm_allgather + dst_offset, s2_actual_m);
+ }
+ }
+ }
+ SetAndWaitAivSync(flag_idx);
+ if (cal_idx < cal_count) {
+ CrossRankSyncV2(FLAG_ONE_IDX, cal_idx + 1);
+ }
+
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx);
+
+ // signal the AIC
+ SetAicSync(flag_idx);
+ }
+
+ EndFlagsAndBias();
+ }
+
+public:
+ using AllGather::SetAicSync;
+ using AllGather::SetAndWaitAivSync;
+ using AllGather::SetBuffFlag;
+ using AllGather::SetBuffFlagByAdd;
+ using AllGather::CheckBuffFlag;
+ using AllGather::ResetIpcFlags;
+ using AllGather::EndFlagsAndBias;
+ using AllGather::CrossRankSyncV1;
+ using AllGather::CrossRankSyncV2;
+ using AllGather::buff;
+ using AllGather::gm_out;
+ using AllGather::ctrl_flags_UB;
+ using AllGather::output_UB_T;
+ using AllGather::batch_size;
+ using AllGather::m;
+ using AllGather::k;
+ using AllGather::n;
+ using AllGather::m0;
+ using AllGather::k0;
+ using AllGather::n0;
+ using AllGather::m_loop;
+ using AllGather::n_loop;
+ using AllGather::k_loop;
+ using AllGather::core_loop;
+ using AllGather::core_idx;
+ using AllGather::rank;
+ using AllGather::rank_size;
+ using AllGather::tiling_key;
+ using AllGather::swizzl_count;
+ using AllGather::p_value;
+ using AllGather::aiv_idx;
+ using AllGather::other_rank;
+ using AllGather::max_ub_single_dma_size;
+ using AllGather::max_ub_ping_pong_size;
+ using AllGather::m_align;
+ using AllGather::k_align;
+ using AllGather::n_align;
+ using AllGather::aligned_a;
+ using AllGather::aligned_b;
+ using AllGather::cal_count;
+ using AllGather::gm_a_pingpong_size;
+ using AllGather::preprocessor;
+ using AllGather::add_bias_runner;
+ using AllGather::MoveResultFromSrcToDst;
+ using AllGather::comm_npu_split;
+ using AllGather::comm_data_split;
+ using AllGather::comm_direct;
+ using AllGather::len_per_loop;
+ using AllGather::MoveWithSplit;
+ using AllGather::local_expert_nums;
+ using AllGather::is_moe;
+ using AllGather::is_moe_averaged;
+ using AllGather::is_alltoallvc;
+ using AllGather::EP;
+ using AllGather::TP;
+ int32_t max_move_m;
+ int32_t max_move_k = 20480;
+ int32_t copy_core_num;
+ int32_t m_k_num;
+ int32_t num_per_rank_move;
+ int32_t core_count;
+ int32_t first_step_core_num;
+ int32_t num_per_move;
+ __gm__ T *gm_allgather;
+};
+
+constexpr int32_t NO_BIAS_MASK5 = 0b000000 | 0b100000 | 0b010000 | 0b110000 | 0b001000 | 0b101000 | 0b011000 | 0b111000;
+constexpr int32_t BIAS_MASK5 = 0b000010 | 0b100010 | 0b010010 | 0b110010 | 0b001010 | 0b101010 | 0b011010 | 0b111010;
+
+template <typename T>
+FORCE_INLINE_AICORE void RunAllGatherV2Align16(int32_t tiling_key, COC_ARGS_FUN(T))
+{
+ // 16 align
+ AllGatherV2 allgather_write_align_16_without_bias;
+ AllGatherV2 allgather_write_align_16_with_bias;
+ switch (tiling_key) {
+ case 0b000000:
+ case 0b100000:
+ case 0b010000:
+ case 0b110000:
+ case 0b001000:
+ case 0b101000:
+ case 0b011000:
+ case 0b111000:
+ allgather_write_align_16_without_bias.SetArgs(COC_ARGS_CALL());
+ allgather_write_align_16_without_bias.RunWithSplit();
+ break;
+ case 0b000010:
+ case 0b100010:
+ case 0b010010:
+ case 0b110010:
+ case 0b001010:
+ case 0b101010:
+ case 0b011010:
+ case 0b111010:
+ allgather_write_align_16_with_bias.SetArgs(COC_ARGS_CALL());
+ allgather_write_align_16_with_bias.RunWithSplit();
+ break;
+ default:
+ break;
+ }
+}
+
+template <typename T>
+FORCE_INLINE_AICORE void RunAllGatherV2UnAlign16(int32_t tiling_key, COC_ARGS_FUN(T))
+{
+ // 16 unalign
+ AllGatherV2 allgather_write_unalign_16_without_bias;
+ AllGatherV2 allgather_write_unalign_16_with_bias;
+ switch (tiling_key) {
+ case 0b000000:
+ case 0b100000:
+ case 0b010000:
+ case 0b110000:
+ case 0b001000:
+ case 0b101000:
+ case 0b011000:
+ case 0b111000:
+ allgather_write_unalign_16_without_bias.SetArgs(COC_ARGS_CALL());
+ allgather_write_unalign_16_without_bias.RunWithSplit();
+ break;
+ case 0b000010:
+ case 0b100010:
+ case 0b010010:
+ case 0b110010:
+ case 0b001010:
+ case 0b101010:
+ case 0b011010:
+ case 0b111010:
+ allgather_write_unalign_16_with_bias.SetArgs(COC_ARGS_CALL());
+ allgather_write_unalign_16_with_bias.RunWithSplit();
+ break;
+ default :
+ break;
+ }
+}
+
+template <typename T>
+inline __aicore__ void CocAllGatherMatmulV2Aiv(COC_ARGS_FUN(T))
+{
+ // write
+ AllGatherV2 allgather_write_align_16_without_bias;
+ AllGatherV2 allgather_write_align_16_with_bias;
+ AllGatherV2 allgather_write_unalign_16_without_bias;
+ AllGatherV2 allgather_write_unalign_16_with_bias;
+
+ SetAtomicNone();
+ SetMaskNormImpl();
+ SetSyncBaseAddr((uint64_t)ffts_addr);
+ SetVectorMask((uint64_t)-1, (uint64_t)-1);
+
+ auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm);
+ auto cocTilingData = &para->cocTilingData;
+ int32_t k = cocTilingData->k;
+ int32_t tiling_key = cocTilingData->tilingKey;
+ int32_t write_to_other_rank = cocTilingData->write2OtherRank;
+ // swizzl = 0 transa = 0 transb = 0 splitk = 0 bias = 0 int8 = 0
+ if (k % BLOCK_SIZE_16 == 0) {
+ RunAllGatherV2Align16(tiling_key, COC_ARGS_CALL());
+ } else {
+ RunAllGatherV2UnAlign16(tiling_key, COC_ARGS_CALL());
+ }
+ PipeBarrier();
+}
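+
+ /*
+ The ALIGN template flag above is fixed once per launch from whether k is a
+ multiple of 16 elements (for fp16/bf16, 16 elements = one 32-byte transfer
+ block): the aligned instantiation can move whole blocks, while the unaligned
+ one falls back to byte-aligned tail handling. Illustration only:
+
+ constexpr int BLOCK_SIZE_16 = 16;
+ bool UseAlignedPath(int k) { return k % BLOCK_SIZE_16 == 0; }
+ // fp16: 16 elements * 2 bytes = 32 bytes, one UB/GM burst
+ */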
+
+#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_allreduce.cce b/comm/lcal/src/kernels/coc_allreduce.cce new file mode 100644 index 0000000000000000000000000000000000000000..c1f95df044d6cc68ac5848b168d658087d948ddd --- /dev/null +++ b/comm/lcal/src/kernels/coc_allreduce.cce @@ -0,0 +1,729 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifdef __DAV_C220_VEC__
+#include "coc_internal.cce"
+#include "coc_comm_base.cce"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+template
+class AllReduce : public CocCommBase {
+public:
+ __aicore__ explicit AllReduce(){};
+
+ FORCE_INLINE_AICORE void SetArgs(COC_ARGS_FUN(T))
+ {
+ CocCommBase::SetArgsForReduce(COC_ARGS_CALL());
+ preprocessor.SetArgs(PP_MATMUL_AIV_PADDING_ARGS_CALL());
+ postprocessor.SetArgs(PP_MATMUL_AIV_POST_ARGS_CALL());
+ if constexpr (HAVE_BIAS) {
+ add_bias_runner.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL());
+ }
+ need_dequant = workspace_info.gm_accum;
+ if (need_dequant) {
+ if (withSerialMode) {
+ serial_dequant_runner.SetArgs(reinterpret_cast<__gm__ bfloat16_t *>(buff[rank]), workspace_info,
+ reinterpret_cast<__gm__ int64_t *>(gm_dequant_scale),
+ reinterpret_cast<__gm__ int32_t *>(gm_dequant_offset),
+ dequant_granularity, batch_size, m, n);
+ } else {
+ fused_dequant_runner.SetArgs(reinterpret_cast<__gm__ bfloat16_t *>(buff[rank]), workspace_info,
+ reinterpret_cast<__gm__ int64_t *>(gm_dequant_scale),
+ reinterpret_cast<__gm__ int32_t *>(gm_dequant_offset), dequant_granularity,
+ batch_size, m, n, m0, n0, m_loop, n_loop, core_loop, swizzl_direct,
+ swizzl_count, p_value, rank_size);
+ }
+ }
+ if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+ fused_pertoken_dequant_runner.SetArgs(reinterpret_cast<__gm__ T *>(buff[rank]),
+ reinterpret_cast<__gm__ float32_t *>(gm_quant_scale), m, n,
+ m0, n0, m_loop, n_loop, core_loop, swizzl_direct, swizzl_count, p_value, rank_size);
+ serial_pertoken_dequant_runner.SetArgs(reinterpret_cast<__gm__ T *>(gm_out), reinterpret_cast<__gm__ float32_t*>(gm_quant_scale), m, n, m0, n0);
+ }
+ total_core_idx = aiv_idx * core_num + core_idx;
+ cal_count = DivCeil(core_loop, loop_num_per_comm);
+ }
+
+ FORCE_INLINE_AICORE int32_t GetCoreGroup() {
+ if (total_core_idx < core_count) {
+ return 0;
+ }
+ if (total_core_idx < core_count + SIO_TOTAL_CORE_NUM) {
+ return 1;
+ }
+ return -1;
+ }
+
+ FORCE_INLINE_AICORE void InitFlags() {
+ if constexpr (HAVE_BIAS) {
+ SetAtomicAdd();
+ PipeBarrier();
+ }
+ SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+ SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+ }
+
+ FORCE_INLINE_AICORE void EndFlagsAndBias() {
+ WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+ WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+
+ if constexpr (HAVE_BIAS) {
+ SetFlag(EVENT_ID0); // Scalar waits for MTE3
+ WaitFlag(EVENT_ID0);
+ SetAtomicNone();
+ PipeBarrier();
+ }
+ }
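+
+ /*
+ StartBeforeFirstStep/EndFirstStep below bracket the reduce phase with atomic-add
+ mode: while it is active, every UB->GM store in the first step accumulates into
+ the destination instead of overwriting it, so copying each peer's partial result
+ into this rank's buffer performs the reduction. The net effect in scalar form
+ (plain C++, illustrative only; the kernel achieves this through SetAtomicAdd):
+
+ template <typename T>
+ void AtomicAddStore(T *gm_dst, const T *ub_src, int len) {
+ for (int i = 0; i < len; ++i)
+ gm_dst[i] += ub_src[i]; // store-with-accumulate, not a plain store
+ }
+ */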
+
+ FORCE_INLINE_AICORE void StartBeforeFirstStep(uint64_t flag_idx) {
+ SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
+ SetAtomicAdd();
+ PipeBarrier();
+ }
+
+ FORCE_INLINE_AICORE void EndFirstStep(uint64_t flag_idx) {
+ SetFlag(EVENT_ID0); // Scalar waits for MTE3
+ WaitFlag(EVENT_ID0);
+ SetAtomicNone();
+ PipeBarrier();
+ SetAndWaitAivSync(flag_idx, is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT);
+ }
+
+ // input is the base address of this rank's section of the peer memory
+ FORCE_INLINE_AICORE void SecondStepParallel(int32_t data_size_remain, __gm__ T* input, int32_t gm_out_offset) {
+ if (data_size_remain <= 0) {
+ return;
+ }
+ InitFlags();
+ int32_t ping_pong_move_count = DivCeil(data_size_remain, max_ub_ping_pong_size);
+
+ for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+ int32_t actual_move_size = (move_idx == ping_pong_move_count - 1) ?
+ data_size_remain - move_idx * max_ub_ping_pong_size : max_ub_ping_pong_size;
+ auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+ auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+ WaitFlag(event_id);
+ CopyGmToUbuf(ub_buff_st, input + move_idx * max_ub_ping_pong_size, 1,
+ actual_move_size * sizeof(T) / 32, 0, 0);
+ SetFlag(event_id);
+ WaitFlag(event_id);
+ int32_t move_num_offset = gm_out_offset + move_idx * max_ub_ping_pong_size;
+ CopyUbToGmTransLayout(ub_buff_st, actual_move_size, move_num_offset);
+ SetFlag(event_id);
+ }
+ EndFlagsAndBias();
+ }
+
+ FORCE_INLINE_AICORE void SecondStepParallelWithSplit(int32_t data_size_remain, int32_t cal_idx,
+ int32_t flag_idx, int32_t data_loop_idx) {
+ if (data_size_remain <= 0) {
+ return;
+ }
+ InitFlags();
+ int32_t rank_per_core = rank_size / comm_npu_split;
+ int32_t core_rank_offset = (core_idx / comm_data_split) * rank_per_core;
+ for (int32_t index = 0; index < rank_per_core; index++) {
+ int32_t rank_idx_rot = (index + core_idx) % rank_per_core;
+ int32_t real_core_idx = core_rank_offset + rank_idx_rot;
+ int32_t before_other_rank_offset = data_loop_idx * comm_data_split * len_per_loop;
+ int32_t other_rank_offset = before_other_rank_offset + real_core_idx * m_per_rank * n0 + core_idx % comm_data_split * len_per_loop;
+ int32_t other_rank_buff_offset = flag_idx * gm_c_pingpong_size + other_rank_offset;
+ int32_t ping_pong_move_count = DivCeil(data_size_remain, max_ub_ping_pong_size);
+
+ for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+ int32_t actual_move_size = (move_idx == ping_pong_move_count - 1) ?
+ data_size_remain - move_idx * max_ub_ping_pong_size : max_ub_ping_pong_size;
+ auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+ auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+ WaitFlag(event_id);
+ CopyGmToUbuf(ub_buff_st, buff[real_core_idx] + other_rank_buff_offset + move_idx * max_ub_ping_pong_size, 1,
+ actual_move_size * sizeof(T) / 32, 0, 0);
+ SetFlag(event_id);
+ WaitFlag(event_id);
+ int64_t move_num_offset = other_rank_offset + move_idx * max_ub_ping_pong_size;
+ CopyUbToGmTransLayout(ub_buff_st, actual_move_size, move_num_offset + cal_idx * gm_c_pingpong_size);
+ SetFlag(event_id);
+ }
+ }
+ EndFlagsAndBias();
+ }
+
+ FORCE_INLINE_AICORE void FirstStepDivCore(int32_t data_len, int32_t offset) {
+ // deterministic mode with 4 to 8 ranks: use tree reduction
+ if (is_deterministic && rank_size >= 4 && rank_size <= 8) {
+ return FirstStepInPeerMemTree(data_len, offset);
+ }
+ // otherwise, accumulate linearly
+ return FirstStepInPeerMemSeq(data_len, offset);
+ }
+
+ FORCE_INLINE_AICORE void SecondStepSerial(int32_t data_size_remain, __gm__ T *input,
+ __gm__ T *output)
+ {
+ if (data_size_remain <= 0) {
+ return;
+ }
+ InitFlags();
+
+ int32_t offset = 0;
+ for (int32_t move_idx = 0; data_size_remain >= max_ub_ping_pong_size; ++move_idx) {
+ auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+ auto ub = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+ WaitFlag(event_id);
+ CopyGmToUbuf(ub, input + offset, 1, max_ub_ping_pong_size * sizeof(T) / 32, 0, 0);
+ SetFlag(event_id); // MTE3 waits for MTE2
+ WaitFlag(event_id);
+ CopyUbufToGm(output + offset, ub, 1, max_ub_ping_pong_size * sizeof(T) / 32, 0, 0);
+ SetFlag(event_id); // MTE2 waits for MTE3
+ data_size_remain -= max_ub_ping_pong_size;
+ offset += max_ub_ping_pong_size;
+ }
+ WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+ WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+
+ if (data_size_remain >= 0) {
+ CopyGmToUbuf(output_UB_T[0], input + offset, 1, (data_size_remain * sizeof(T) + 31) / 32, 0, 0);
+ SetFlag(EVENT_ID0); // MTE3 waits for MTE2
+ WaitFlag(EVENT_ID0);
+ if (ALIGN) {
+ CopyUbufToGm(output + offset, output_UB_T[0], 1, data_size_remain * sizeof(T) / 32, 0, 0);
+ } else {
+ CopyUbufToGmAlignB16(output + offset, output_UB_T[0], 1, data_size_remain * sizeof(T), 0, 0);
+ }
+ }
+
+ if constexpr (HAVE_BIAS) {
+ SetFlag(EVENT_ID0); // Scalar waits for MTE3
+ WaitFlag(EVENT_ID0);
+ SetAtomicNone();
+ PipeBarrier();
+ }
+ }
+
+ FORCE_INLINE_AICORE void ParallelWithSplit() {
+ ResetIpcFlags(3);
+ PipeBarrier();
+
+ for (int32_t cal_idx = 0; cal_idx < cal_count; ++cal_idx) {
+ uint64_t flag_idx = cal_idx % MAX_BLOCK_COUNT;
+ int32_t actual_loop_num = (cal_idx == cal_count - 1) ? core_loop - cal_idx * loop_num_per_comm :
+ loop_num_per_comm;
+ int32_t m_total = actual_loop_num * m0;
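+ /*
+ The split below divides m_total rows evenly across ranks and clamps the last
+ rank(s). The same arithmetic standalone, with a worked case (plain C++,
+ illustrative only):
+
+ int RowsInRank(int rank, int rank_size, int m_total) {
+ int per = (m_total + rank_size - 1) / rank_size; // DivCeil
+ if (rank * per >= m_total) return 0; // fully past the end
+ return ((rank + 1) * per > m_total) ? m_total - rank * per : per;
+ }
+ // m_total = 100, rank_size = 8: per = 13; ranks 0..6 take 13 rows, rank 7 takes 9
+ */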
+ m_per_rank = DivCeil(m_total, rank_size); // p_value * core_num / rank_size * m0
+ m_in_rank = (rank * m_per_rank >= m_total) ? 0 :
+ ((rank + 1) * m_per_rank > m_total ? m_total - rank * m_per_rank : m_per_rank);
+ // wait aic
+ WaitEvent(flag_idx);
+
+ if (need_dequant) {
+ SetAndWaitAivSync(flag_idx);
+ fused_dequant_runner.RunDequantAllReduce(cal_idx);
+ }
+
+ if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+ SetAndWaitAivSync(flag_idx);
+ //fused_pertoken_dequant_runner.Run(cal_idx);
+ fused_pertoken_dequant_runner.RunDequantAllReduce(cal_idx);
+ }
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx);
+
+ // the local matmul result is ready
+ CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1);
+
+ StartBeforeFirstStep(flag_idx);
+
+ int32_t rank_total = m_in_rank * n0;
+ int32_t rank_offset = rank * m_per_rank * n0;
+
+ int32_t rank_buff_offset = flag_idx * m0 * n0 * loop_num_per_comm + rank_offset;
+
+ int32_t len_per_core = rank_total / comm_data_split;
+ int32_t data_split_num = DivCeil(len_per_core, len_per_loop);
+
+ SetFlag(EVENT_ID0);
+ SetFlag(EVENT_ID1);
+ for (int loop_index = 0; loop_index < data_split_num; loop_index++) {
+ if (aiv_idx == 0 && core_idx < comm_data_split * comm_npu_split) {
+ int32_t before_core_offset = len_per_loop * comm_data_split * loop_index;
+ int32_t loop_total = rank_total - before_core_offset;
+ int32_t real_core_offset = core_idx % comm_data_split * len_per_loop;
+
+ int32_t m_in_core = (real_core_offset >= loop_total) ? 0 :
+ ((real_core_offset + len_per_loop) > loop_total ?
+ loop_total - real_core_offset : len_per_loop);
+
+ FirstStepDivCore(m_in_core, rank_buff_offset + before_core_offset + real_core_offset);
+ }
+ }
+ WaitFlag(EVENT_ID0);
+ WaitFlag(EVENT_ID1);
+ EndFirstStep(flag_idx);
+
+ CrossRankSyncV1(FLAG_ONE_IDX, cal_idx + 1);
+ SetAndWaitAivSync(flag_idx);
+
+ for (int loop_index = 0; loop_index < data_split_num; loop_index++) {
+ if (aiv_idx == 0 && core_idx < comm_data_split * comm_npu_split) {
+ int32_t before_core_offset = len_per_loop * comm_data_split * loop_index;
+ int32_t loop_total = rank_total - before_core_offset;
+ int32_t real_core_offset = core_idx % comm_data_split * len_per_loop;
+
+ int32_t m_in_core = (real_core_offset >= loop_total) ? 0 :
+ ((real_core_offset + len_per_loop) > loop_total ? loop_total - real_core_offset : len_per_loop);
+
+ SecondStepParallelWithSplit(m_in_core, cal_idx, flag_idx, loop_index);
+ }
+ }
+ SetAndWaitAivSync(flag_idx);
+
+ CrossRankSyncV2(FLAG_TWO_IDX, cal_idx + 1);
+
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx);
+
+ // signal the AIC
+ SetAicSync(flag_idx);
+ }
+ ResetIpcFlags(3);
+
+ if (aiv_idx == 0 && core_idx < rank_size) {
+ CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[other_rank] + flag_offset + FLAG_ZERO_IDX, 0);
+ }
+ }
+
+
+ FORCE_INLINE_AICORE void DataCopySioRs(int32_t cal_idx_sio, int32_t len_per_rank) {
+ int32_t flag_idx_sio = cal_idx_sio % BLOCK_COUNT_4;
+ int32_t len_per_core = len_per_rank / SIO_TOTAL_CORE_NUM;
+ int32_t sio_core_idx = total_core_idx - core_count;
+ int32_t core_offset = sio_core_idx * len_per_core;
+ int32_t sio_peer_rank = rank ^ 1;
+ // even ranks move source ranks 0 2 4 6 over SIO; odd ranks move 1 3 5 7
+ for (int32_t src_rank = rank % 2; src_rank < rank_size; src_rank += 2) {
+ int32_t peer_offset = flag_idx_sio * gm_c_pingpong_size + src_rank * len_per_rank + core_offset;
+ FirstStepInPeerMem(len_per_core, buff[sio_peer_rank] + peer_offset, buff[rank] + peer_offset);
+ }
+ }
+
+ FORCE_INLINE_AICORE void DataCopySioAg(int32_t cal_idx_sio, int32_t len_per_rank) {
+ int32_t flag_idx_sio = cal_idx_sio % BLOCK_COUNT_4;
+ int32_t len_per_core = len_per_rank / SIO_TOTAL_CORE_NUM;
+ int32_t sio_core_idx = total_core_idx - core_count;
+ int32_t core_offset = sio_core_idx * len_per_core;
+ int32_t sio_peer_rank = rank ^ 1;
+ // rank 1 pulls source ranks 0 2 4 6 from rank 0; rank 0 pulls 1 3 5 7 from rank 1
+ for (int32_t src_rank = sio_peer_rank % 2; src_rank < rank_size; src_rank += 2) {
+ int32_t peer_offset = flag_idx_sio * gm_c_pingpong_size + src_rank * len_per_rank;
+ int32_t dst_offset = cal_idx_sio * gm_c_pingpong_size + src_rank * len_per_rank + core_offset;
+ SecondStepParallel(len_per_core, buff[sio_peer_rank] + peer_offset + core_offset, dst_offset);
+ }
+ // copy this rank's own slice
+ int32_t local_offset = flag_idx_sio * gm_c_pingpong_size + rank * len_per_rank + core_offset;
+ int32_t dst_offset = cal_idx_sio * gm_c_pingpong_size + rank * len_per_rank + core_offset;
+ SecondStepParallel(len_per_core, buff[rank] + local_offset, dst_offset);
+ }
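+
+ /*
+ The SIO stage above pairs adjacent ranks (rank ^ 1) over the chip-to-chip link
+ and splits the source ranks between the pair by parity, so each link carries
+ half the traffic. Plain C++ illustration of the split (names hypothetical):
+
+ #include <cstdio>
+ void SioSplit(int rank, int rank_size) {
+ printf("rank %d (SIO peer %d) moves sources:", rank, rank ^ 1);
+ for (int src = rank % 2; src < rank_size; src += 2) printf(" %d", src);
+ printf("\n"); // with 8 ranks: rank 0 -> 0 2 4 6, rank 1 -> 1 3 5 7
+ }
+ */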
+
+ FORCE_INLINE_AICORE void ParallelSio() {
+ ResetIpcFlags(3);
+ PipeBarrier();
+ int32_t last_loop_num = core_loop - (cal_count - 1) * loop_num_per_comm;
+ int32_t core_group = GetCoreGroup();
+ for (int32_t cal_idx = 0; cal_idx < cal_count + 2; ++cal_idx) {
+ int32_t hccs_idx = cal_idx - 1; // sio-rs -> hccs -> sio-ag
+ int32_t sio2_idx = cal_idx - 2; // sio-ag
+ int32_t flag_idx_sio1 = cal_idx % BLOCK_COUNT_4;
+ int32_t flag_idx_hccs = hccs_idx % BLOCK_COUNT_4;
+ int32_t flag_idx_sio2 = sio2_idx % BLOCK_COUNT_4;
+ int32_t loop_num_hccs = hccs_idx == cal_count - 1 ? last_loop_num : loop_num_per_comm;
+ // wait aic
+ if (cal_idx < cal_count) {
+ WaitEvent(flag_idx_sio1);
+ }
+
+ if (need_dequant) {
+ fused_dequant_runner.RunDequantAllReduce(cal_idx);
+ }
+
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx_sio1, BLOCK_COUNT_4);
+
+ // the local matmul result is ready
+ CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1);
+
+ StartBeforeFirstStep(flag_idx_sio1);
+
+ if (core_group == 0 && cal_idx >= 1 && cal_idx < cal_count + 1) { // step 2-1 hccs rs
+ int32_t size_per_rank = loop_num_hccs * m0 * n0 / rank_size;
+ int32_t rank_offset = rank * size_per_rank;
+ int32_t rank_buff_offset = flag_idx_hccs * gm_c_pingpong_size + rank_offset;
+ int32_t size_per_core = size_per_rank / (comm_data_split);
+
+ int32_t data_split_num = DivCeil(size_per_core, len_per_loop);
+
+ SetFlag(EVENT_ID0);
+ SetFlag(EVENT_ID1);
+ for (int loop_index = 0; loop_index < data_split_num; loop_index++) {
+ int32_t before_core_offset = len_per_loop * comm_data_split * loop_index;
+ int32_t loop_total = size_per_rank - before_core_offset;
+ int32_t real_core_offset = core_idx % comm_data_split * len_per_loop;
+
+ int32_t m_in_core = (real_core_offset >= loop_total) ? 0 :
+ ((real_core_offset + len_per_loop) > loop_total ?
+ loop_total - real_core_offset : len_per_loop);
+
+ FirstStepDivCore(m_in_core, rank_buff_offset + before_core_offset + real_core_offset);
+ }
+ WaitFlag(EVENT_ID0);
+ WaitFlag(EVENT_ID1);
+ }
+ if (core_group == 1 && cal_idx < cal_count) { // step 1 sio reducescatter
+ int32_t loop_num_sio1 = cal_idx == cal_count - 1 ? last_loop_num : loop_num_per_comm;
+ int32_t size_per_rank = loop_num_sio1 * m0 * n0 / rank_size;
+ DataCopySioRs(cal_idx, size_per_rank);
+ }
+
+ EndFirstStep(flag_idx_sio1);
+
+ CrossRankSyncV1(FLAG_ONE_IDX, cal_idx + 1);
+ SetAndWaitAivSync(flag_idx_sio1, BLOCK_COUNT_4);
+ if (core_group == 0 && cal_idx >= 1 && cal_idx < cal_count + 1) { // step 2-2 hccs ag
+ int32_t size_per_rank = loop_num_hccs * m0 * n0 / rank_size;
+ int32_t pipe_offset = flag_idx_hccs * gm_c_pingpong_size + other_rank * size_per_rank;
+ int32_t dst_offset = hccs_idx * gm_c_pingpong_size + other_rank * size_per_rank;
+ if ((other_rank % 2) == (rank % 2) && other_rank != rank) {
+ FirstStepInPeerMemTransLayout(size_per_rank, buff[other_rank] + pipe_offset, buff[rank] + pipe_offset, dst_offset);
+ }
+ }
+ if (core_group == 1 && cal_idx >= 2) { // step 3: sio-ag
+ int32_t loop_num_sio2 = sio2_idx == cal_count - 1 ? last_loop_num : loop_num_per_comm;
+ int32_t size_per_rank = loop_num_sio2 * m0 * n0 / rank_size;
+ DataCopySioAg(sio2_idx, size_per_rank);
+ }
+ SetAndWaitAivSync(flag_idx_sio1, BLOCK_COUNT_4);
+ CrossRankSyncV2(FLAG_TWO_IDX, cal_idx + 1);
+
+ // sync between AIVs
+ SetAndWaitAivSync(flag_idx_sio1, BLOCK_COUNT_4);
+
+ // signal the AIC
+ if (cal_idx >= 2)
+ SetAicSync(flag_idx_sio2);
+ }
+ ResetIpcFlags(3);
+
+ if (aiv_idx == 0 && core_idx < rank_size) {
+ CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[other_rank] + flag_offset + FLAG_ZERO_IDX, 0);
+ }
+ }
+
+
+ FORCE_INLINE_AICORE void Serial() {
+ SetBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + MAX_FLAG_COUNT + FLAG_ONE_IDX, tag);
+ // sync between AIC and AIV
+ WaitEvent(AIV_WAIT_AIC_FINISH_MATMUL_FLAG_ID);
+
+ FFTSCrossCoreSync(0, AIV_FINISH_ALIGN_FLAG_ID);
+ WaitEvent(AIV_FINISH_ALIGN_FLAG_ID);
+
+ if (need_dequant) {
+ serial_dequant_runner.Run();
+ }
+ if (aiv_idx == 1 && core_idx < rank_size) {
+ int32_t data_size = batch_size * m * n;
+ int32_t data_size_per_rank = (data_size + BLOCK_SIZE_16 * rank_size - 1) / (BLOCK_SIZE_16 * rank_size) * BLOCK_SIZE_16;
+ if (other_rank == rank) {
+ SetBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + MAX_FLAG_COUNT + FLAG_ZERO_IDX, tag);
+ } else {
+ CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[other_rank] + flag_offset + MAX_FLAG_COUNT + FLAG_ZERO_IDX, tag);
+ PipeBarrier();
+ int32_t rank_buff_offset = rank * data_size_per_rank;
+ FirstStepInPeerMem(data_size_per_rank, buff[other_rank] + rank_buff_offset, buff[rank] + rank_buff_offset, true);
+ SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + MAX_FLAG_COUNT + FLAG_ONE_IDX, tag);
+ }
+ CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[other_rank] + flag_offset + MAX_FLAG_COUNT + FLAG_ONE_IDX, tag * rank_size);
+ PipeBarrier();
+ int32_t data_size_in_other_rank = data_size_per_rank;
+ if (other_rank * data_size_in_other_rank >= data_size) {
+ data_size_in_other_rank = 0;
+ } else if ((other_rank + 1) * data_size_in_other_rank > data_size) {
+ data_size_in_other_rank = data_size - other_rank * data_size_per_rank;
+ }
+ int32_t other_rank_buff_offset = other_rank * data_size_per_rank;
+ SecondStepSerial(data_size_in_other_rank, buff[other_rank] + other_rank_buff_offset, gm_out + other_rank_buff_offset);
+ }
+ }
+
+ FORCE_INLINE_AICORE void SerialWithSplit() {
+ SetBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + MAX_FLAG_COUNT + FLAG_ONE_IDX, tag);
+ // sync between AIC and AIV
+ WaitEvent(AIV_WAIT_AIC_FINISH_MATMUL_FLAG_ID);
+
+ FFTSCrossCoreSync(0, AIV_FINISH_ALIGN_FLAG_ID);
+ WaitEvent(AIV_FINISH_ALIGN_FLAG_ID);
+
+ if (need_dequant) {
+ serial_dequant_runner.Run();
+ }
+
+ int32_t data_size = batch_size * m * n;
+ int32_t data_size_per_rank = (data_size + BLOCK_SIZE_16 * rank_size - 1) / (BLOCK_SIZE_16 * rank_size) * BLOCK_SIZE_16;
+
+ int32_t use_core_count = comm_npu_split * comm_data_split;
+ int32_t rank_buff_offset = rank * data_size_per_rank;
+
+ int32_t len_per_core = data_size_per_rank / comm_data_split;
+ int32_t data_split_num = DivCeil(len_per_core, len_per_loop);
+
+ SetAndWaitAivSync(0);
+ CrossRankSyncV3(MAX_FLAG_COUNT + FLAG_ZERO_IDX, tag);
+ StartBeforeFirstStep(0);
+
+ SetFlag(EVENT_ID0);
+ SetFlag(EVENT_ID1);
+ for (int loop_index = 0; loop_index < data_split_num; loop_index++) {
+ if (aiv_idx == 0 && core_idx < comm_data_split * comm_npu_split) {
+ int32_t before_core_offset = len_per_loop * comm_data_split * loop_index;
+ int32_t loop_total = data_size_per_rank -
before_core_offset; + int32_t real_core_offset = core_idx % comm_data_split * len_per_loop; + + int32_t m_in_core = (real_core_offset >= loop_total) ? 0 : + ((real_core_offset + len_per_loop) > loop_total ? + loop_total - real_core_offset : len_per_loop); + + FirstStepDivCore(m_in_core, rank_buff_offset + before_core_offset + real_core_offset); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + + EndFirstStep(0); + + CrossRankSyncV4(MAX_FLAG_COUNT + FLAG_ONE_IDX, tag); + SetAndWaitAivSync(0); + + if (aiv_idx == 0 && core_idx < rank_size) { + PipeBarrier(); + int32_t data_size_in_other_rank = data_size_per_rank; + if (other_rank * data_size_in_other_rank >= data_size){ + data_size_in_other_rank = 0; + } else if ((other_rank + 1) * data_size_in_other_rank > data_size){ + data_size_in_other_rank = data_size - other_rank * data_size_per_rank; + } + int32_t other_rank_buff_offset = other_rank * data_size_per_rank; + SecondStepSerial(data_size_in_other_rank, buff[other_rank] + other_rank_buff_offset, gm_out + other_rank_buff_offset); + } + } + + FORCE_INLINE_AICORE void Run() + { + // Padding + preprocessor.Run(); + + if constexpr (HAVE_BIAS) { + add_bias_runner.Run(); + } + + if (withSerialMode) { + if (is_deterministic) { + SerialWithSplit(); + } else { + Serial(); + } + } else { + ParallelWithSplit(); + } + + + + PipeBarrier(); + postprocessor.Run(); + PipeBarrier(); + + if (withSerialMode && dequant_granularity == QuantGranularity::PER_TOKEN) { + serial_pertoken_dequant_runner.Run(); + } + + } + +public: + using CocCommBase::SetAicSync; + using CocCommBase::SetAndWaitAivSync; + using CocCommBase::SetBuffFlag; + using CocCommBase::SetBuffFlagByAdd; + using CocCommBase::CheckBuffFlag; + using CocCommBase::FillZero; + using CocCommBase::FirstStepInPeerMem; + using CocCommBase::FirstStepInPeerMemSeq; + using CocCommBase::FirstStepInPeerMemTree; + using CocCommBase::FirstStepInPeerMemTransLayout; + using CocCommBase::CopyUbToGmTransLayout; + using CocCommBase::ResetIpcFlags; + using CocCommBase::CrossRankSyncV1; + using CocCommBase::CrossRankSyncV2; + using CocCommBase::CrossRankSyncV3; + using CocCommBase::CrossRankSyncV4; + using CocCommBase::buff; + using CocCommBase::gm_out; + using CocCommBase::ctrl_flags_UB; + using CocCommBase::output_UB_T; + using CocCommBase::batch_size; + using CocCommBase::m; + using CocCommBase::k; + using CocCommBase::n; + using CocCommBase::m0; + using CocCommBase::k0; + using CocCommBase::n0; + using CocCommBase::m_loop; + using CocCommBase::n_loop; + using CocCommBase::k_loop; + using CocCommBase::core_loop; + using CocCommBase::core_idx; + using CocCommBase::core_num; + using CocCommBase::rank; + using CocCommBase::rank_size; + using CocCommBase::tiling_key; + using CocCommBase::swizzl_count; + using CocCommBase::swizzl_direct; + using CocCommBase::trans_a; + using CocCommBase::trans_b; + using CocCommBase::is_int8; + using CocCommBase::is_91093; + using CocCommBase::p_value; + using CocCommBase::aiv_idx; + using CocCommBase::other_rank; + using CocCommBase::max_ub_single_dma_size; + using CocCommBase::max_ub_ping_pong_size; + using CocCommBase::withSerialMode; + using CocCommBase::tag; + using CocCommBase::loop_num_per_comm; // p_value * core_num + using CocCommBase::gm_c_pingpong_size; + using CocCommBase::dequant_granularity; + using CocCommBase::dequant_group_size; + using CocCommBase::quant_granularity; + using CocCommBase::quant_group_size; + using CocCommBase::workspace_info; + using CocCommBase::comm_npu_split; + using CocCommBase::comm_data_split; + 
using CocCommBase::len_per_loop; + using CocCommBase::core_count; + using CocCommBase::weight_nz; + using CocCommBase::is_deterministic; + using CocCommBase::local_expert_nums; + using CocCommBase::is_moe; + using CocCommBase::is_moe_averaged; + using CocCommBase::is_alltoallvc; + using CocCommBase::EP; + using CocCommBase::TP; + using CocCommBase::flag_offset; + int32_t cal_count; + int32_t m_per_rank; + int32_t m_in_rank; + int32_t total_core_idx; + Preprocessor preprocessor; + Postprocessor postprocessor; + MatmulAllReduceBiasAdder add_bias_runner; + SerialDequantRunner serial_dequant_runner; + FusedDequantRunner fused_dequant_runner; + //AllReduceFusedPerTokenDequantRunner fused_pertoken_dequant_runner; + FusedPerTokenDequantRunner fused_pertoken_dequant_runner; + SerialPerTokenDequantRunner serial_pertoken_dequant_runner; + bool need_dequant; +}; + +constexpr int32_t NO_BIAS_MASK1 = 0b000000 | 0b100000 | 0b010000 | 0b110000 | 0b001000 | 0b101000 | 0b011000 | + 0b111000 | 0b000100 | 0b100100 | 0b010100 | 0b110100 | 0b001100 | 0b101100 | + 0b011100 | 0b111100; +constexpr int32_t BIAS_MASK1 = 0b000010 | 0b100010 | 0b010010 | 0b110010 | 0b001010 | 0b101010 | 0b011010 | 0b111010 | + 0b000110 | 0b100110 | 0b010110 | 0b110110 | 0b001110 | 0b101110 | 0b011110 | 0b111110; + +template +FORCE_INLINE_AICORE void RunAllReduceAlign16(int32_t tiling_key, COC_ARGS_FUN(T)) { + // 16 align + AllReduce allreduce_align_16_without_bias; + AllReduce allreduce_align_16_with_bias; + switch (tiling_key) { + case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 : + case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 : + case 0b000100 : case 0b100100 : case 0b010100 : case 0b110100 : + case 0b001100 : case 0b101100 : case 0b011100 : case 0b111100 : + allreduce_align_16_without_bias.SetArgs(COC_ARGS_CALL()); + allreduce_align_16_without_bias.Run(); + break; + case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 : + case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 : + case 0b000110 : case 0b100110 : case 0b010110 : case 0b110110 : + case 0b001110 : case 0b101110 : case 0b011110 : case 0b111110 : + allreduce_align_16_with_bias.SetArgs(COC_ARGS_CALL()); + allreduce_align_16_with_bias.Run(); + break; + default : + break; + } +} + +template +FORCE_INLINE_AICORE void RunAllReduceUnAlign16(int32_t tiling_key, COC_ARGS_FUN(T)) { + // 16 unalign + AllReduce allreduce_unalign_16_without_bias; + AllReduce allreduce_unalign_16_with_bias; + switch (tiling_key) { + case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 : + case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 : + case 0b000100 : case 0b100100 : case 0b010100 : case 0b110100 : + case 0b001100 : case 0b101100 : case 0b011100 : case 0b111100 : + allreduce_unalign_16_without_bias.SetArgs(COC_ARGS_CALL()); + allreduce_unalign_16_without_bias.Run(); + break; + case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 : + case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 : + case 0b000110 : case 0b100110 : case 0b010110 : case 0b110110 : + case 0b001110 : case 0b101110 : case 0b011110 : case 0b111110 : + allreduce_unalign_16_with_bias.SetArgs(COC_ARGS_CALL()); + allreduce_unalign_16_with_bias.Run(); + break; + default : + break; + } +} + +template +inline __aicore__ void CocMatmulAllReduceAiv(COC_ARGS_FUN(T)) +{ + // 16 align + AllReduce allreduce_align_16_without_bias; + AllReduce allreduce_align_16_with_bias; + // 16 unalign + AllReduce allreduce_unalign_16_without_bias; + AllReduce 
allreduce_unalign_16_with_bias;
+
+    SetAtomicNone();
+    SetMaskNormImpl();
+    SetSyncBaseAddr((uint64_t)ffts_addr);
+    SetVectorMask((uint64_t)-1, (uint64_t)-1);
+
+    auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm);
+    auto cocTilingData = &para->cocTilingData;
+    int64_t batch_size = cocTilingData->batchSize;
+    int32_t m = cocTilingData->m;
+    int32_t n = cocTilingData->n;
+    int32_t tiling_key = cocTilingData->tilingKey;
+    int32_t rank_size = cocTilingData->rankSize;
+    int32_t withSerialMode = cocTilingData->withSerialMode;
+    // tiling_key is a 6-bit variant id packing the swizzl / transa / transb / splitk / bias / int8 bits
+    if ((withSerialMode == 0 && n % BLOCK_SIZE_16 == 0) || (withSerialMode && (batch_size * m * n) % (rank_size * BLOCK_SIZE_16) == 0)) {
+        RunAllReduceAlign16(tiling_key, COC_ARGS_CALL());
+    } else {
+        RunAllReduceUnAlign16(tiling_key, COC_ARGS_CALL());
+    }
+    PipeBarrier();
+}
+#endif
\ No newline at end of file
diff --git a/comm/lcal/src/kernels/coc_alltoall_allgather_hidden.cce b/comm/lcal/src/kernels/coc_alltoall_allgather_hidden.cce
new file mode 100644
index 0000000000000000000000000000000000000000..dcfdeeb15823235a599361d772f10631be430437
--- /dev/null
+++ b/comm/lcal/src/kernels/coc_alltoall_allgather_hidden.cce
@@ -0,0 +1,600 @@
+#ifdef __DAV_C220_VEC__
+#include
+
+#include "coc_internal.cce"
+#include "coc_comm_base.cce"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+template
+class AllToAllvAllGatherHiddenSplit: public CocCommBase{
+public:
+    FORCE_INLINE_AICORE AllToAllvAllGatherHiddenSplit(){};
+    FORCE_INLINE_AICORE void SetArgs(COC_ARGS_FUN(T)){
+        CocCommBase::SetArgs(COC_ARGS_CALL());
+        preprocessor.SetArgs(PP_MATMUL_AIV_PADDING_ARGS_CALL());
+        m_align = Block512B::AlignUp(m);
+        k_align = Block512B::AlignUp(k);
+        n_align = Block512B::AlignUp(n);
+        AlignJudge(trans_a, trans_b, m, k, n, m_align, k_align, n_align, aligned_a, aligned_b);
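+        // NOTE (reading of the code below): when A had to be re-padded (aligned_a), the padded
+        // copy in the workspace (gm_a_align) becomes the effective input, which is why the
+        // all-to-all copies step through rows with a k_align pitch.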
+        this->gm_out = aligned_a ? reinterpret_cast<__gm__ T *>(workspace_info.gm_a_align) : gm_a;
+        this->expert_nums = local_expert_nums * EP;
+        comm_k = p_value * k0;
+        comm_count = DivCeil(k, comm_k);
+        this->gm_quant_scale = reinterpret_cast<__gm__ float32_t *>(gm_quant_scale);
+        int32_t output_num = m;
+        if (!is_moe_averaged) {
+            output_num = 0;
+            this->global_tokens_per_expert_matrix = global_tokens_per_expert_matrix;
+            for (int32_t i = 0; i < EP; i++) {
+                for (int32_t j = 0; j < local_expert_nums; j++) {
+                    output_num += this->global_tokens_per_expert_matrix[i * expert_nums + j + rank * local_expert_nums];
+                }
+            }
+        }
+        if (maxOutputSize > 0 && output_num >= maxOutputSize) {
+            output_num = maxOutputSize;
+        }
+        if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+            serial_pertoken_dequant_runner.SetArgs(reinterpret_cast<__gm__ MatType *>(gm_out), reinterpret_cast<__gm__ float32_t*>(workspace_info.gm_dequant_param), output_num, n, m0, n0);
+        }
+    }
+
+    inline __attribute__((always_inline)) __aicore__ void ScaleAllToAll(){
+        int32_t usable_buff = 200 * 1024 * 1024 / 4 / 2;
+        int32_t max_move_num = usable_buff / rank_size;
+        int32_t scale_pingpang_size = usable_buff;
+
+        int32_t cal_count = 0;
+        if (is_moe_averaged) {
+            cal_count = DivCeil(m / EP, max_move_num);
+        } else {
+            for (int32_t ep_idx = 0; ep_idx < EP; ep_idx++) {
+                int32_t in_num = 0;
+                int32_t out_num = 0;
+                for (int32_t j = 0; j < local_expert_nums; j++) {
+                    out_num += global_tokens_per_expert_matrix[rank * expert_nums + j + ep_idx * local_expert_nums];
+                }
+                for (int32_t j = 0; j < local_expert_nums; j++) {
+                    in_num += global_tokens_per_expert_matrix[ep_idx * expert_nums + j + rank * local_expert_nums];
+                }
+                cal_count = max(cal_count, max(in_num, out_num));
+            }
+            cal_count = DivCeil(cal_count, max_move_num);
+        }
+
+        PipeBarrier();
+
+        int64_t sum_out = 0, sum_in = 0;
+        int32_t received_loop_number = 0;
+        int32_t ep_idx = real_core_idx;
+
+        int32_t out_num, in_num;
+        if (is_moe_averaged) {
+            out_num = m / EP;
+            in_num = m / EP;
+        } else {
+            out_num = 0;
+            in_num = 0;
+            for (int32_t j = 0; j < local_expert_nums; j++) {
+                out_num += global_tokens_per_expert_matrix[rank * expert_nums + j + real_core_idx * local_expert_nums];
+            }
+            for (int32_t j = 0; j < local_expert_nums; j++) {
+                in_num += global_tokens_per_expert_matrix[real_core_idx * expert_nums + rank * local_expert_nums + j];
+            }
+        }
+
+        max_ub_ping_pong_size = max_ub_ping_pong_size / 2; // halved here, restored after the loop
+        int32_t receive_expert_id = 0;
+        int32_t receive_expert_token_nums;
+        int32_t last_ep_local = 0;
+        if (is_moe_averaged) {
+            receive_expert_token_nums = m / EP / local_expert_nums;
+            last_ep_local = (m / EP) * real_core_idx;
+        } else {
+            receive_expert_token_nums = global_tokens_per_expert_matrix[real_core_idx * expert_nums + rank * local_expert_nums];
+            for (int32_t i = 0; i < real_core_idx * local_expert_nums; i++) {
+                last_ep_local += global_tokens_per_expert_matrix[rank * expert_nums + i];
+            }
+        }
+
+        for (int32_t cal_idx = 0; cal_idx < cal_count; cal_idx++) {
+            int32_t flag_idx = cal_idx % MAX_BLOCK_COUNT;
+
+            SetAndWaitAivSync(flag_idx, gm_a_pingpong_num);
+            int32_t received_rank_num = 0;
+            if (is_moe_averaged) {
+                received_rank_num = rank_size;
+            } else {
+                for (int32_t i = 0; i < EP; i++) {
+                    int32_t in_num_tmp = 0;
+                    for (int32_t j = 0; j < local_expert_nums; j++) {
+                        in_num_tmp += global_tokens_per_expert_matrix[i * expert_nums + rank * local_expert_nums + j];
+                    }
+                    if (cal_idx * max_move_num < in_num_tmp) {
+                        received_rank_num += 1;
+                    }
+                }
+            }
+
+            received_loop_number += received_rank_num;
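+            // A peer is counted in received_loop_number while it still has scale data to send in
+            // round cal_idx, so the FLAG_ADD_IDX check below only passes once every such peer has
+            // deposited its chunk into this rank's shared buffer.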
+            if (real_core_idx < rank_size) {
+                if (real_core_idx == rank) {
+                    SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_TWO_IDX, FLAG_VALUE);
+                }
+                if (is_moe_averaged || cal_idx * max_move_num < out_num) {
+                    int32_t data_len = ((cal_idx + 1) * max_move_num >= out_num) ? (out_num - cal_idx * max_move_num) : max_move_num;
+                    __gm__ float32_t *src_address = gm_quant_scale + 1LL * last_ep_local + sum_out;
+                    __gm__ float32_t *dst_address = (__gm__ float32_t *)buff[real_core_idx] + flag_idx * scale_pingpang_size + max_move_num * rank;
+
+                    CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_TWO_IDX, FLAG_VALUE * (cal_idx + 1));
+
+                    SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+                    SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+                    MoveResultFromSrcToDstv2(src_address, dst_address, data_len, 0);
+                    WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+                    WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+                    sum_out += data_len;
+                    SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ADD_IDX, FLAG_VALUE);
+                }
+                CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_ADD_IDX, FLAG_VALUE * received_loop_number);
+
+                if (is_moe_averaged || cal_idx * max_move_num < in_num) {
+                    int32_t data_len = ((cal_idx + 1) * max_move_num >= in_num) ? (in_num - cal_idx * max_move_num) : max_move_num;
+                    __gm__ float32_t *src_address = (__gm__ float32_t *)buff[rank] + flag_idx * scale_pingpang_size + max_move_num * real_core_idx;
+                    __gm__ float32_t *dst_address;
+
+                    while (receive_expert_id < local_expert_nums && data_len > 0) {
+                        int32_t move_data_len;
+                        if (data_len >= receive_expert_token_nums) {
+                            move_data_len = receive_expert_token_nums;
+                        } else {
+                            move_data_len = data_len;
+                        }
+
+                        if (is_moe_averaged) {
+                            dst_address = reinterpret_cast<__gm__ float32_t *>(workspace_info.gm_dequant_param) +
+                                1LL * (m / local_expert_nums) * receive_expert_id + 1LL * (m / expert_nums) * real_core_idx + sum_in;
+                        } else {
+                            int32_t before_expert_sum = 0;
+                            for (int32_t i = 0; i < receive_expert_id; i++) {
+                                for (int32_t j = 0; j < EP; j++) {
+                                    before_expert_sum += global_tokens_per_expert_matrix[j * expert_nums + i + rank * local_expert_nums];
+                                }
+                            }
+                            int32_t before_rank_in_expert_sum = 0;
+                            for (int32_t i = 0; i < real_core_idx; i++) {
+                                before_rank_in_expert_sum += global_tokens_per_expert_matrix[i * expert_nums + rank * local_expert_nums + receive_expert_id];
+                            }
+                            dst_address = reinterpret_cast<__gm__ float32_t *>(workspace_info.gm_dequant_param) +
+                                1LL * before_expert_sum + 1LL * before_rank_in_expert_sum + sum_in;
+                        }
+
+                        SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+                        SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+                        MoveResultFromSrcToDstv2(src_address, dst_address, move_data_len, 0);
+                        WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+                        WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+
+                        if (data_len >= receive_expert_token_nums) {
+                            receive_expert_id += 1;
+                            data_len -= receive_expert_token_nums;
+                            if (receive_expert_id > local_expert_nums) {
+                                break;
+                            }
+                            if (is_moe_averaged) {
+                                receive_expert_token_nums = m / EP / local_expert_nums;
+                            } else {
+                                receive_expert_token_nums = global_tokens_per_expert_matrix[real_core_idx * expert_nums + receive_expert_id + rank * local_expert_nums];
+                            }
+                            sum_in = 0;
+                        } else {
+                            sum_in += data_len;
+                            receive_expert_token_nums -= data_len;
+                            data_len = 0;
+                        }
+                        src_address += move_data_len;
+                    }
+                }
+            }
+        }
+
+        max_ub_ping_pong_size = max_ub_ping_pong_size * 2; // restore the full ping-pong size
+        if (real_core_idx < rank_size) {
+            if (real_core_idx == rank) {
+                SetBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_TWO_IDX, 0);
+            }
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_TWO_IDX, 0);
+        }
+        PipeBarrier();
+    }
+
+    inline __attribute__((always_inline)) __aicore__ void AllGatherGlobalTokensMatrix(){
+        int32_t usable_buff = 100 * 1024 * 1024 / 4;
+        // First publish this rank's num_local_tokens_per_expert row into its shared buffer.
+        PipeBarrier();
+        SetAndWaitAivSync(0);
+        if (real_core_idx < rank_size) {
+            int32_t data_len = expert_nums;
+            __gm__ int32_t *src_address = num_local_tokens_per_expert;
+            __gm__ int32_t *dst_address = (__gm__ int32_t *)buff[rank];
+            if (real_core_idx == rank) {
+                SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+                SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+                MoveResultFromSrcToDst(src_address, dst_address, 1, data_len, 0);
+                WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+                WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+                SetBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_TWO_IDX, FLAG_VALUE);
+            }
+
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_TWO_IDX, FLAG_VALUE);
+
+            // Then gather every peer's row into the local global_tokens_per_expert_matrix.
+            src_address = (__gm__ int32_t *)buff[real_core_idx];
+            dst_address = global_tokens_per_expert_matrix + real_core_idx * data_len;
+            SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+            SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+            MoveResultFromSrcToDst(src_address, dst_address, 1, data_len, 0);
+            WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+            WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+        }
+    }
+
+    template
+    inline __attribute__((always_inline)) __aicore__ void MoveResultFromSrcToDst(__gm__ CommType *gm_src, __gm__ CommType *gm_dst,
+        int32_t m_actual, int32_t n_actual, bool is_align)
+    {
+        __ubuf__ CommType *output_UB_T[2] = {(__ubuf__ CommType *)(32), (__ubuf__ CommType *)(97440)};
+        int32_t max_move_m = (max_ub_ping_pong_size / Block32B::AlignUp(n_actual));
+        if (max_move_m > 4095) {
+            max_move_m = 4095;
+        }
+        int32_t ping_pong_move_count = DivCeil(m_actual, max_move_m);
+        for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+            int32_t actual_move_m = max_move_m;
+            int32_t actual_move_n = n_actual;
+            if (move_idx == ping_pong_move_count - 1) {
+                actual_move_m = m_actual - move_idx * max_move_m;
+            }
+            auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
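+            // Alternate between the two UB halves (ping-pong) so the next GM->UB load can
+            // overlap the previous UB->GM store.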
+            auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+            WaitFlag(event_id);
+            if (is_align) {
+                CopyGmToUbuf(ub_buff_st, gm_src, actual_move_m, actual_move_n * sizeof(CommType) / 32, (k_align - actual_move_n) * sizeof(CommType) / 32, 0);
+            } else {
+                CopyGmToUbufAlignB16(ub_buff_st, gm_src, actual_move_m, actual_move_n * sizeof(CommType), (k_align - actual_move_n) * sizeof(CommType), 0);
+            }
+            SetFlag(event_id);
+            WaitFlag(event_id);
+            if (is_align) {
+                CopyUbufToGm(gm_dst, ub_buff_st, actual_move_m, actual_move_n * sizeof(CommType) / 32, 0, 0);
+            } else {
+                CopyUbufToGmAlignB16(gm_dst, ub_buff_st, actual_move_m, actual_move_n * sizeof(CommType), 0, 0);
+            }
+            gm_src += actual_move_m * k_align;
+            gm_dst += actual_move_m * actual_move_n;
+            SetFlag(event_id);
+        }
+    }
+
+    template
+    inline __attribute__((always_inline)) __aicore__ void MoveResultFromSrcToDstv2(__gm__ CommType *gm_src, __gm__ CommType *gm_dst,
+        int32_t len, bool is_align)
+    {
+        __ubuf__ CommType *output_UB_T[2] = {(__ubuf__ CommType *)(32), (__ubuf__ CommType *)(97440)};
+        int32_t ping_pong_move_count = (len + max_ub_ping_pong_size - 1) / max_ub_ping_pong_size;
+        for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+            int32_t actual_move_size = max_ub_ping_pong_size;
+            if (move_idx == ping_pong_move_count - 1) {
+                actual_move_size = len - move_idx * max_ub_ping_pong_size;
+            }
+            auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+            auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+            WaitFlag(event_id);
+            if (is_align) {
+                CopyGmToUbuf(ub_buff_st, gm_src, 1, actual_move_size * sizeof(CommType) / 32, 0, 0);
+            } else {
+                CopyGmToUbufAlignB16(ub_buff_st, gm_src, 1, actual_move_size * sizeof(CommType), 0, 0);
+            }
+            SetFlag(event_id);
+            WaitFlag(event_id);
+            if (is_align) {
+                CopyUbufToGm(gm_dst, ub_buff_st, 1, actual_move_size * sizeof(CommType) / 32, 0, 0);
+            } else {
+                CopyUbufToGmAlignB16(gm_dst, ub_buff_st, 1, actual_move_size * sizeof(CommType), 0, 0);
+            }
+            gm_src += max_ub_ping_pong_size;
+            gm_dst += max_ub_ping_pong_size;
+            SetFlag(event_id);
+        }
+    }
+
+    inline __attribute__((always_inline)) __aicore__ void EndFlagsAndBias()
+    {
+        ResetIpcFlags(4);
+        if (real_core_idx < rank_size) {
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ZERO_IDX, 0);
+        }
+        PipeBarrier();
+        // if constexpr (HAVE_BIAS) {
+        //     add_bias_runner.Run();
+        // }
+    }
+
+    inline __attribute__((always_inline)) __aicore__ void Run(){
+        preprocessor.Run(local_expert_nums);
+        if (is_moe_averaged) {
+            max_m = m;
+        } else {
+            if (maxOutputSize == -1) {
+                max_m = 0;
+                for (int32_t ep_idx = 0; ep_idx < EP; ep_idx++) {
+                    int32_t sum_m_ep = 0;
+                    for (int32_t local_expert_id = 0; local_expert_id < local_expert_nums; local_expert_id++) {
+                        int32_t expert_id = local_expert_id + ep_idx * local_expert_nums;
+                        for (int32_t i = 0; i < EP; i++) {
+                            sum_m_ep += global_tokens_per_expert_matrix[i * expert_nums + expert_id];
+                        }
+                    }
+                    max_m = max(max_m, sum_m_ep);
+                }
+            } else {
+                max_m = maxOutputSize;
+            }
+        }
+        gm_a_pingpong_size = comm_k * max_m;
+        gm_a_pingpong_num = buffer_size * 1024 * 1024 / sizeof(T) / gm_a_pingpong_size;
+        if (gm_a_pingpong_num > 8) {
+            gm_a_pingpong_num = 8;
+        }
+
+        if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+            ScaleAllToAll();
+        }
+
+        withSerialMode = 1;
+        int64_t dst_before_expert_sum[16] = {0}; // per local expert: offset of this rank's slice in the peer's buffer (destination)
+        int32_t sum_num_local_tokens_per_expert[16] = {0}; // per local expert: offset of the outgoing tokens in gm_out (source)
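+        // Both tables are prefix sums over global_tokens_per_expert_matrix, whose entry
+        // [r * expert_nums + e] is the number of tokens rank r routes to expert e.
+        // Worked example (hypothetical sizes): with is_moe_averaged, m = 16, EP = 2 and
+        // local_expert_nums = 2, each (rank, expert) pair owns m / expert_nums = 4 tokens, so
+        // for gmm_ep_idx = 1 on rank 0: sum_num_local_tokens_per_expert = {8, 12} and
+        // dst_before_expert_sum = {0, 8}.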
+        int32_t gmm_ep_idx = real_core_idx < rank_size ? real_core_idx : rank_size - 1;
+        if (!is_moe_averaged) {
+            int32_t hcumsum = 0;
+            for (int32_t j = 0; j <= gmm_ep_idx; j++) {
+                for (int32_t i = 0; i < local_expert_nums; i++) {
+                    if (j == gmm_ep_idx) {
+                        sum_num_local_tokens_per_expert[i] = hcumsum;
+                    }
+                    hcumsum += global_tokens_per_expert_matrix[rank * expert_nums + i + j * local_expert_nums];
+                }
+            }
+
+            int32_t cumsum = 0;
+            for (int32_t i = 0; i < local_expert_nums; i++) {
+                for (int32_t j = 0; j < rank_size; j++) {
+                    if (j == rank) {
+                        dst_before_expert_sum[i] = cumsum;
+                    }
+                    cumsum += global_tokens_per_expert_matrix[j * expert_nums + i + gmm_ep_idx * local_expert_nums];
+                }
+            }
+        } else {
+            for (int32_t i = 0; i < local_expert_nums; i++) {
+                sum_num_local_tokens_per_expert[i] = (m / expert_nums) * (gmm_ep_idx * local_expert_nums + i);
+                dst_before_expert_sum[i] = (m / expert_nums) * (EP * i + rank);
+            }
+        }
+
+        for (int32_t comm_idx = 0; comm_idx < comm_count + gm_a_pingpong_num; comm_idx++) {
+            uint64_t flag_idx = comm_idx % gm_a_pingpong_num;
+
+            if (comm_idx > gm_a_pingpong_num - 1) {
+                WaitEvent(flag_idx);
+            }
+            SetAndWaitAivSync(flag_idx, gm_a_pingpong_num);
+            if (real_core_idx < rank_size && comm_idx < comm_count) {
+                if (real_core_idx == rank) {
+                    SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_ZERO_IDX, FLAG_VALUE);
+                }
+                int32_t k_len;
+                if (comm_idx == comm_count - 1) {
+                    k_len = k - comm_idx * comm_k;
+                } else {
+                    k_len = comm_k;
+                }
+
+                CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ZERO_IDX, FLAG_VALUE * (comm_idx + 1));
+
+                int32_t m_len = 0;
+                for (int32_t local_expert_id = 0; local_expert_id < local_expert_nums; local_expert_id++) {
+                    int32_t expert_id = real_core_idx * local_expert_nums + local_expert_id;
+                    if (is_moe_averaged) {
+                        m_len = m / EP / local_expert_nums;
+                    } else {
+                        m_len = global_tokens_per_expert_matrix[rank * expert_nums + expert_id];
+                    }
+                    if (maxOutputSize > 0 && m_len > maxOutputSize - dst_before_expert_sum[local_expert_id]) {
+                        m_len = maxOutputSize - dst_before_expert_sum[local_expert_id];
+                    }
+                    if (m_len <= 0) {
+                        continue;
+                    }
+                    __gm__ T *src_address, *dst_address;
+                    src_address = gm_out + 1LL * k_align * sum_num_local_tokens_per_expert[local_expert_id] + comm_idx * comm_k;
+                    dst_address = buff[real_core_idx] + 1LL * flag_idx * gm_a_pingpong_size + 1LL * k_len * dst_before_expert_sum[local_expert_id];
+
+                    SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+                    SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+                    MoveResultFromSrcToDst(src_address, dst_address, m_len, k_len, 0);
+                    WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+                    WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+                }
+                SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ONE_IDX, FLAG_VALUE);
+                if (real_core_idx == rank) {
+                    CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_ONE_IDX, FLAG_VALUE * (comm_idx + 1) * rank_size);
+                }
+            }
+            SetAndWaitAivSync(flag_idx, gm_a_pingpong_num);
+            if (comm_idx < comm_count) {
+                SetAicSync(flag_idx);
+            }
+        }
+        if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+            serial_pertoken_dequant_runner.Run();
+        }
+        EndFlagsAndBias();
+    }
+
+public:
+    using CocCommBase::SetAicSync;
+    using CocCommBase::SetAndWaitAivSync;
+
+    using CocCommBase::SetBuffFlag;
+    using CocCommBase::SetBuffFlagByAdd;
+    using CocCommBase::CheckBuffFlag;
+    using CocCommBase::ResetIpcFlags;
+    using CocCommBase::CrossRankSyncV1;
+    using CocCommBase::CrossRankSyncV2;
+
+    using CocCommBase::buff;
+    using CocCommBase::gm_out;
+    using CocCommBase::ctrl_flags_UB;
+    using CocCommBase::output_UB_T;
+    using CocCommBase::batch_size;
+    using CocCommBase::m;
+    using CocCommBase::k;
+    using CocCommBase::n;
+    using CocCommBase::m0;
+    using CocCommBase::k0;
+    using CocCommBase::n0;
+    using CocCommBase::m_loop;
+    using CocCommBase::n_loop;
+    using CocCommBase::k_loop;
+    using CocCommBase::core_loop;
+    using CocCommBase::real_core_idx;
+    using CocCommBase::core_num;
+    using CocCommBase::rank;
+    using CocCommBase::rank_size;
+    using CocCommBase::buffer_size;
+    using CocCommBase::tiling_key;
+    using CocCommBase::swizzl_direct;
+    using CocCommBase::swizzl_count;
+    using CocCommBase::trans_a;
+    using CocCommBase::trans_b;
+    using CocCommBase::is_int8;
+    using CocCommBase::p_value;
+    using CocCommBase::aiv_idx;
+    using CocCommBase::other_rank;
+    using CocCommBase::max_ub_single_dma_size;
+    using CocCommBase::max_ub_ping_pong_size;
+    using CocCommBase::dequant_granularity;
+    using CocCommBase::dequant_group_size;
+    using CocCommBase::quant_granularity;
+    using CocCommBase::quant_group_size;
+    using CocCommBase::workspace_info;
+    using CocCommBase::withSerialMode;
+    using CocCommBase::is_moe;
+    using CocCommBase::is_moe_averaged;
+    using CocCommBase::is_alltoallvc;
+    using CocCommBase::is_deterministic;
+    using CocCommBase::weight_nz;
+
+    using CocCommBase::global_tokens_per_expert_matrix;
+    using CocCommBase::num_local_tokens_per_expert;
+
+    using CocCommBase::local_expert_nums;
+    using CocCommBase::TP;
+    using CocCommBase::EP;
+    using CocCommBase::maxOutputSize;
+    using CocCommBase::flag_offset;
+    int32_t max_m;
+    int32_t comm_k;
+    int32_t comm_count;
+    int32_t gm_a_pingpong_size;
+    int32_t expert_nums;
+    int32_t gm_a_pingpong_num;
+    int32_t m_align;
+    int32_t k_align;
+    int32_t n_align;
+    int32_t aligned_a;
+    int32_t aligned_b;
+    Preprocessor preprocessor;
+
+    //AllGatherMatmulBiasAdder add_bias_runner;
+    FusedPerTokenDequantRunner fused_pertoken_dequant_runner;
+    SerialPerTokenDequantRunner serial_pertoken_dequant_runner;
+    __gm__ float32_t *gm_quant_scale;
+};
+
+template
+inline __aicore__ void CocAllToAllVAllGatherHiddenAiv(COC_ARGS_FUN(T)){
+    AllToAllvAllGatherHiddenSplit alltoall_allgather_without_bias;
+    AllToAllvAllGatherHiddenSplit alltoall_allgather_with_bias;
+    AllToAllvAllGatherHiddenSplit alltoall_allgather_int8_without_bias;
+    AllToAllvAllGatherHiddenSplit alltoall_allgather_int8_with_bias;
+    SetAtomicNone();
+    SetMaskNormImpl();
+    SetSyncBaseAddr((uint64_t)ffts_addr);
+    SetVectorMask((uint64_t)-1, (uint64_t)-1);
+
+    auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm);
+    auto cocTilingData = &para->cocTilingData;
+    int32_t tiling_key = cocTilingData->tilingKey;
+    int32_t write_to_other_rank = cocTilingData->write2OtherRank;
+    switch (tiling_key) {
+        case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 :
+        case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 :
+            alltoall_allgather_without_bias.SetArgs(COC_ARGS_CALL());
+            alltoall_allgather_without_bias.Run();
+            break;
+        case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 :
+        case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 :
+            alltoall_allgather_with_bias.SetArgs(COC_ARGS_CALL());
+            alltoall_allgather_with_bias.Run();
+            break;
+        case 0b000100 : case 0b100100 : case 0b010100 : case 0b110100 :
+        case 0b001100 : case 0b101100 : case 0b011100 : case 0b111100 :
alltoall_allgather_int8_without_bias.SetArgs(COC_ARGS_CALL_INT8()); + alltoall_allgather_int8_without_bias.Run(); + break; + case 0b000110 : case 0b100110 : case 0b010110 : case 0b110110 : + case 0b001110 : case 0b101110 : case 0b011110 : case 0b111110 : + alltoall_allgather_int8_with_bias.SetArgs(COC_ARGS_CALL_INT8()); + alltoall_allgather_int8_with_bias.Run(); + break; + default : + break; + } + PipeBarrier(); +} + +#endif diff --git a/comm/lcal/src/kernels/coc_alltoall_reduce_scatter_hidden.cce b/comm/lcal/src/kernels/coc_alltoall_reduce_scatter_hidden.cce new file mode 100644 index 0000000000000000000000000000000000000000..7bea5dc40f00a775fb0a6c1dc481580b6a4873b3 --- /dev/null +++ b/comm/lcal/src/kernels/coc_alltoall_reduce_scatter_hidden.cce @@ -0,0 +1,350 @@ +#ifdef __DAV_C220_VEC__ +#include "coc_internal.cce" +#include "coc_comm_base.cce" +#include "kernel_operator.h" +using namespace AscendC; + +template +class AllToAllVReduceScatterHiddenSplit : public CocCommBase { +public: + __aicore__ explicit AllToAllVReduceScatterHiddenSplit(){}; + + inline __attribute__((always_inline)) __aicore__ void SetArgs(COC_ARGS_FUN(T)) + { + CocCommBase::SetArgs(COC_ARGS_CALL()); + preprocessor.SetArgs(PP_MATMUL_AIV_PADDING_ARGS_CALL()); + this->gm_out = gm_out; + expert_nums = local_expert_nums * EP; + if (!is_moe_averaged) { + this->global_tokens_per_expert_matrix = global_tokens_per_expert_matrix; + } + m_align = Block512B::AlignUp(m); + k_align = Block512B::AlignUp(k); + n_align = Block512B::AlignUp(n); + if(dequant_granularity == QuantGranularity::PER_TOKEN) { + fused_pertoken_dequant_runner.SetArgs(reinterpret_cast<__gm__ T *>(buff[rank]), workspace_info, reinterpret_cast<__gm__ float32_t*>(gm_quant_scale), + batch_size, m, k, n, m0,k0, n0, m_loop, n_loop, core_loop, rank, swizzl_direct, + swizzl_count, p_value, EP, TP, local_expert_nums, is_moe_averaged, 1, maxOutputSize, buffer_size, global_tokens_per_expert_matrix); + } + } + + inline __attribute__((always_inline)) __aicore__ void EndFlagsAndBias() + { + ResetIpcFlags(2); + if (real_core_idx < rank_size) { + CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ZERO_IDX, 0); + } + PipeBarrier(); + } + + + + + template + inline __attribute__((always_inline)) __aicore__ void MoveResultFromSrcToDst(__gm__ CommType *gm_src, __gm__ CommType *gm_dst, + int32_t m_actual,int32_t n_actual) + { + __ubuf__ CommType *output_UB_T[2] = {(__ubuf__ CommType *)(32), (__ubuf__ CommType *)(97440)}; + int32_t max_move_m = (max_ub_ping_pong_size / Block32B::AlignUp(n_actual)); + if (max_move_m > 4095) + max_move_m = 4095; + int32_t ping_pong_move_count = DivCeil(m_actual, max_move_m); + for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) { + int32_t actual_move_m = max_move_m; // 4 + int32_t actual_move_n = n_actual;//3584 + if(move_idx == ping_pong_move_count - 1) { + actual_move_m = m_actual - move_idx * max_move_m; + } + + auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1; + auto ub_buff_st = (move_idx & 1) ? 
output_UB_T[0] : output_UB_T[1];
+            WaitFlag(event_id);
+            CopyGmToUbufAlignB16(ub_buff_st, gm_src, actual_move_m, actual_move_n * sizeof(CommType), 0, 0);
+
+            SetFlag(event_id);
+            WaitFlag(event_id);
+            CopyUbufToGmAlignB16(gm_dst, ub_buff_st, actual_move_m, actual_move_n * sizeof(CommType), 0, (n - actual_move_n) * sizeof(CommType));
+            gm_src += actual_move_m * actual_move_n;
+            gm_dst += actual_move_m * n;
+            SetFlag(event_id);
+        }
+    }
+
+    inline __attribute__((always_inline)) __aicore__ void Run()
+    {
+        preprocessor.Run(local_expert_nums);
+        if (is_moe_averaged) {
+            max_m = m;
+        } else {
+            if (maxOutputSize == -1) {
+                max_m = 0;
+                for (int32_t ep_idx = 0; ep_idx < EP; ep_idx++) {
+                    int32_t sum_m_ep = 0;
+                    for (int32_t local_expert_id = 0; local_expert_id < local_expert_nums; local_expert_id++) {
+                        int32_t expert_id = local_expert_id + ep_idx * local_expert_nums;
+                        for (int32_t i = 0; i < EP; i++) {
+                            sum_m_ep += global_tokens_per_expert_matrix[i * expert_nums + expert_id];
+                        }
+                    }
+                    max_m = max(max_m, sum_m_ep);
+                }
+            } else {
+                max_m = maxOutputSize;
+            }
+        }
+
+        comm_n = p_value * n0;
+        gm_a_pingpong_size = max_m * comm_n;
+        gm_a_pingpong_num = buffer_size * 1024 * 1024 / 2 / gm_a_pingpong_size;
+        if (gm_a_pingpong_num > 8) {
+            gm_a_pingpong_num = 8;
+        }
+
+        cal_count = DivCeil(n, comm_n);
+        int32_t max_flag_id = cal_count < gm_a_pingpong_num ? cal_count : gm_a_pingpong_num;
+        for (int64_t cal_idx = 0; cal_idx < max_flag_id; ++cal_idx) {
+            SetAicSync(cal_idx);
+        }
+
+        int64_t dst_before_expert_sum[16] = {0}; // per local expert: offset of this rank's slice in the peer's buffer (source of the copy below)
+        int32_t sum_num_local_tokens_per_expert[16] = {0}; // per local expert: offset of the tokens in gm_out (destination of the copy below)
+        int32_t gmm_ep_idx = real_core_idx < rank_size ? real_core_idx : rank_size - 1;
+        if (!is_moe_averaged) {
+            int32_t hcumsum = 0;
+            for (int32_t j = 0; j <= gmm_ep_idx; j++) {
+                for (int32_t i = 0; i < local_expert_nums; i++) {
+                    if (j == gmm_ep_idx) {
+                        sum_num_local_tokens_per_expert[i] = hcumsum;
+                    }
+                    hcumsum += global_tokens_per_expert_matrix[rank * expert_nums + i + j * local_expert_nums];
+                }
+            }
+
+            int32_t cumsum = 0;
+            for (int32_t i = 0; i < local_expert_nums; i++) {
+                for (int32_t j = 0; j < rank_size; j++) {
+                    if (j == rank) {
+                        dst_before_expert_sum[i] = cumsum;
+                    }
+                    cumsum += global_tokens_per_expert_matrix[j * expert_nums + i + gmm_ep_idx * local_expert_nums];
+                }
+            }
+        } else {
+            for (int32_t i = 0; i < local_expert_nums; i++) {
+                sum_num_local_tokens_per_expert[i] = (m / expert_nums) * (gmm_ep_idx * local_expert_nums + i);
+                dst_before_expert_sum[i] = (m / expert_nums) * (EP * i + rank);
+            }
+        }
+
+        for (int32_t cal_idx = 0; cal_idx < cal_count; ++cal_idx) {
+            uint64_t flag_idx = cal_idx % gm_a_pingpong_num;
+            WaitEvent(flag_idx);
+
+            if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+                SetAndWaitAivSync(flag_idx, gm_a_pingpong_num);
+                fused_pertoken_dequant_runner.DequantPerTokenMatmulAllToAllHidden(cal_idx);
+            }
+            SetAndWaitAivSync(flag_idx, gm_a_pingpong_num);
+            int64_t n_len, m_len;
+            if (cal_idx == cal_count - 1) {
+                n_len = n - cal_idx * comm_n;
+            } else {
+                n_len = comm_n;
+            }
+            int32_t n_loop_cal = DivCeil(n_len, n0);
+
+            if (real_core_idx < rank_size) {
+                if (real_core_idx == rank) {
+                    SetBuffFlagByAdd(
+                        ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_ZERO_IDX, FLAG_VALUE);
+                }
+                CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ZERO_IDX, FLAG_VALUE * (cal_idx + 1));
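+                // Scatter this n-slice per local expert: m_len rows for expert_id are read from
+                // the peer's ping-pong slot and written to gm_out at the expert's offset, clipped
+                // to maxOutputSize when a cap is configured.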
+                for (int32_t local_expert_id = 0; local_expert_id < local_expert_nums; local_expert_id++) {
+                    int32_t expert_id = real_core_idx * local_expert_nums + local_expert_id;
+                    if (is_moe_averaged) {
+                        m_len = m / EP / local_expert_nums;
+                    } else {
+                        m_len = global_tokens_per_expert_matrix[rank * expert_nums + expert_id];
+                    }
+
+                    if (maxOutputSize > 0 && m_len > maxOutputSize - dst_before_expert_sum[local_expert_id]) {
+                        m_len = maxOutputSize - dst_before_expert_sum[local_expert_id];
+                    }
+
+                    if (m_len <= 0) {
+                        continue;
+                    }
+                    int64_t buff_offset = flag_idx * gm_a_pingpong_size + 1LL * dst_before_expert_sum[local_expert_id] * n_len;
+                    int64_t gm_offset = 1LL * sum_num_local_tokens_per_expert[local_expert_id] * n + 1LL * cal_idx * comm_n;
+                    __gm__ T *src_address, *dst_address;
+                    src_address = buff[real_core_idx] + buff_offset;
+                    dst_address = gm_out + gm_offset;
+                    SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+                    SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+                    MoveResultFromSrcToDst(src_address, dst_address, m_len, n_len);
+                    WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+                    WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+                }
+                SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ONE_IDX, FLAG_VALUE);
+                if (real_core_idx == rank) {
+                    CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_ONE_IDX,
+                        FLAG_VALUE * (cal_idx + 1) * EP);
+                }
+            }
+            SetAndWaitAivSync(flag_idx, gm_a_pingpong_num);
+            SetAicSync(flag_idx);
+        }
+        EndFlagsAndBias();
+    }
+
+    using CocCommBase::SetAicSync;
+    using CocCommBase::SetAndWaitAivSync;
+    using CocCommBase::SetBuffFlag;
+    using CocCommBase::SetBuffFlagByAdd;
+    using CocCommBase::CheckBuffFlag;
+    using CocCommBase::FillZero;
+    using CocCommBase::FirstStepInPeerMem;
+    using CocCommBase::ResetIpcFlags;
+    using CocCommBase::CrossRankSyncV1;
+    using CocCommBase::CrossRankSyncV2;
+    using CocCommBase::buff;
+    using CocCommBase::gm_out;
+    using CocCommBase::ctrl_flags_UB;
+    using CocCommBase::output_UB_T;
+    using CocCommBase::batch_size;
+    using CocCommBase::m;
+    using CocCommBase::k;
+    using CocCommBase::n;
+    using CocCommBase::m0;
+    using CocCommBase::k0;
+    using CocCommBase::n0;
+    using CocCommBase::m_loop;
+    using CocCommBase::n_loop;
+    using CocCommBase::k_loop;
+    using CocCommBase::core_loop;
+    using CocCommBase::real_core_idx;
+    using CocCommBase::rank;
+    using CocCommBase::rank_size;
+    using CocCommBase::buffer_size;
+    using CocCommBase::tiling_key;
+    using CocCommBase::swizzl_count;
+    using CocCommBase::swizzl_direct;
+    using CocCommBase::trans_a;
+    using CocCommBase::trans_b;
+    using CocCommBase::is_int8;
+    using CocCommBase::p_value;
+    using CocCommBase::aiv_idx;
+    using CocCommBase::other_rank;
+    using CocCommBase::max_ub_single_dma_size;
+    using CocCommBase::max_ub_ping_pong_size;
+    using CocCommBase::loop_num_per_comm;
+    using CocCommBase::dequant_granularity;
+    using CocCommBase::dequant_group_size;
+    using CocCommBase::quant_granularity;
+    using CocCommBase::quant_group_size;
+    using CocCommBase::workspace_info;
+    using CocCommBase::maxOutputSize;
+    using CocCommBase::is_moe;
+    using CocCommBase::is_moe_averaged;
+    using CocCommBase::is_alltoallvc;
+    using CocCommBase::is_deterministic;
+    using CocCommBase::flag_offset;
+    using CocCommBase::weight_nz;
+
+    int32_t gm_a_pingpong_size;
+    int32_t gm_a_pingpong_num;
+    int32_t cal_count;
+    int32_t comm_n;
+    int32_t max_m;
+
+    int32_t m_align;
+    int32_t k_align;
+    int32_t n_align;
+
+    using CocCommBase::global_tokens_per_expert_matrix;
+    using CocCommBase::expert_nums;
+    using CocCommBase::local_expert_nums;
+    using CocCommBase::TP;
+    using CocCommBase::EP;
+    Preprocessor preprocessor;
+    FusedPerTokenDequantRunner fused_pertoken_dequant_runner;
+};
+
+template
+inline __attribute__((always_inline)) __aicore__ void RunAllToAllVReduceScatterHiddenAlign16(int32_t tiling_key, COC_ARGS_FUN(T))
+{
+    // 16 align
+    AllToAllVReduceScatterHiddenSplit all_to_allv_reduce_scatter_align_16_without_bias;
+    //AllToAllVReduceScatterHiddenSplit all_to_allv_reduce_scatter_align_16_with_bias;
+    switch (tiling_key) {
+        case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 :
+        case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 :
+        case 0b000100 : case 0b100100 : case 0b010100 : case 0b110100 :
+        case 0b001100 : case 0b101100 : case 0b011100 : case 0b111100 :
+            all_to_allv_reduce_scatter_align_16_without_bias.SetArgs(COC_ARGS_CALL());
+            all_to_allv_reduce_scatter_align_16_without_bias.Run();
+            break;
+        // case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 :
+        // case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 :
+        // case 0b000110 : case 0b100110 : case 0b010110 : case 0b110110 :
+        // case 0b001110 : case 0b101110 : case 0b011110 : case 0b111110 :
+        //     all_to_allv_reduce_scatter_align_16_with_bias.SetArgs(COC_ARGS_CALL());
+        //     all_to_allv_reduce_scatter_align_16_with_bias.Run();
+        //     break;
+        default:
+            break;
+    }
+}
+
+template
+inline __attribute__((always_inline)) __aicore__ void RunAllToAllVReduceScatterHiddenUnAlign16(int32_t tiling_key, COC_ARGS_FUN(T))
+{
+    // 16 unalign
+    AllToAllVReduceScatterHiddenSplit all_to_allv_reduce_scatter_unalign_16_without_bias;
+    AllToAllVReduceScatterHiddenSplit all_to_allv_reduce_scatter_unalign_16_with_bias;
+    switch (tiling_key) {
+        case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 :
+        case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 :
+        case 0b000100 : case 0b100100 : case 0b010100 : case 0b110100 :
+        case 0b001100 : case 0b101100 : case 0b011100 : case 0b111100 :
+            all_to_allv_reduce_scatter_unalign_16_without_bias.SetArgs(COC_ARGS_CALL());
+            all_to_allv_reduce_scatter_unalign_16_without_bias.Run();
+            break;
+        case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 :
+        case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 :
+        case 0b000110 : case 0b100110 : case 0b010110 : case 0b110110 :
+        case 0b001110 : case 0b101110 : case 0b011110 : case 0b111110 :
+            all_to_allv_reduce_scatter_unalign_16_with_bias.SetArgs(COC_ARGS_CALL());
+            all_to_allv_reduce_scatter_unalign_16_with_bias.Run();
+            break;
+        default:
+            break;
+    }
+}
+
+template
+inline __attribute__((always_inline)) __aicore__ void CocMatmulAllToAllVReduceScatterHiddenAiv(COC_ARGS_FUN(T))
+{
+    SetAtomicNone();
+    SetMaskNormImpl();
+    SetSyncBaseAddr((uint64_t)ffts_addr);
+    SetVectorMask((uint64_t)-1, (uint64_t)-1);
+
+    auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm);
+    auto cocTilingData = &para->cocTilingData;
+    int32_t n = cocTilingData->n;
+    int32_t tiling_key = cocTilingData->tilingKey;
+    if (n % BLOCK_SIZE_16 == 0) {
+        RunAllToAllVReduceScatterHiddenAlign16(tiling_key, COC_ARGS_CALL());
+    } else {
+        RunAllToAllVReduceScatterHiddenUnAlign16(tiling_key, COC_ARGS_CALL());
+    }
+
+    PipeBarrier();
+}
+
+#endif
diff --git a/comm/lcal/src/kernels/coc_alltoallv_allgather.cce b/comm/lcal/src/kernels/coc_alltoallv_allgather.cce
new file mode 100644
index 0000000000000000000000000000000000000000..0b70df2aa9ed0e22f527a57cae48ec3c28fc4eef
--- /dev/null
+++ b/comm/lcal/src/kernels/coc_alltoallv_allgather.cce
@@ -0,0 +1,642 @@
+#ifdef __DAV_C220_VEC__
+#include
+
+#include "coc_internal.cce"
+#include "coc_comm_base.cce"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+template
+class AllToAllvAllGather: public CocCommBase{
+public:
+    __aicore__ explicit AllToAllvAllGather(){};
+    inline __attribute__((always_inline)) __aicore__ void SetArgs(COC_ARGS_FUN(T)){
+        CocCommBase::SetArgs(COC_ARGS_CALL());
+        preprocessor.SetArgs(PP_MATMUL_AIV_PADDING_ARGS_CALL());
+        if constexpr (HAVE_BIAS) {
+            add_bias_runner.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL());
+        }
+        m_align = Block512B::AlignUp(m);
+        k_align = Block512B::AlignUp(k);
+        n_align = Block512B::AlignUp(n);
+
+        AlignJudge(trans_a, trans_b, m, k, n, m_align, k_align, n_align, aligned_a, aligned_b);
+        this->gm_out = aligned_a ? reinterpret_cast<__gm__ T *>(workspace_info.gm_a_align) : gm_a;
+        this->gm_quant_scale = reinterpret_cast<__gm__ float32_t *>(gm_quant_scale);
+        this->expert_nums = local_expert_nums * EP;
+        is_moe_averaged = 0;
+        if (global_tokens_per_expert_matrix == nullptr) {
+            is_moe_averaged = 1;
+        }
+        this->global_tokens_per_expert_matrix = reinterpret_cast<__gm__ int32_t *>(global_tokens_per_expert_matrix);
+        gm_a_pingpong_size = m0 * k_align * p_value * rank_size;
+        if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+            int32_t output_num = m;
+            if (!is_moe_averaged) {
+                output_num = 0;
+                for (int32_t i = 0; i < EP; i++) {
+                    for (int32_t j = 0; j < local_expert_nums; j++) {
+                        output_num += global_tokens_per_expert_matrix[i * expert_nums + j + rank * local_expert_nums];
+                    }
+                }
+                if (maxOutputSize > 0 && output_num >= maxOutputSize) {
+                    output_num = maxOutputSize;
+                }
+            }
+            serial_pertoken_dequant_runner.SetArgs(reinterpret_cast<__gm__ MatType *>(gm_out), reinterpret_cast<__gm__ float32_t*>(workspace_info.gm_dequant_param), output_num, n, m0, n0);
+        }
+    }
+
+    inline __attribute__((always_inline)) __aicore__ void ScaleAllToAll(){
+        int32_t usable_buff = 200 * 1024 * 1024 / 4 / 2;
+        int32_t max_move_num = usable_buff / rank_size;
+        int32_t scale_pingpang_size = usable_buff;
+
+        int32_t cal_count = 0;
+        if (is_moe_averaged) {
+            cal_count = DivCeil(m / EP, max_move_num);
+        } else {
+            for (int32_t ep_idx = 0; ep_idx < EP; ep_idx++) {
+                int32_t in_num = 0;
+                int32_t out_num = 0;
+                for (int32_t j = 0; j < local_expert_nums; j++) {
+                    out_num += global_tokens_per_expert_matrix[rank * expert_nums + j + ep_idx * local_expert_nums];
+                }
+                for (int32_t j = 0; j < local_expert_nums; j++) {
+                    in_num += global_tokens_per_expert_matrix[ep_idx * expert_nums + j + rank * local_expert_nums];
+                }
+                cal_count = max(cal_count, max(in_num, out_num));
+            }
+            cal_count = DivCeil(cal_count, max_move_num);
+        }
+
+        PipeBarrier();
+
+        int64_t sum_out = 0, sum_in = 0;
+        int32_t received_loop_number = 0;
+        int32_t ep_idx = real_core_idx;
+
+        int32_t out_num, in_num;
+        if (is_moe_averaged) {
+            out_num = m / EP;
+            in_num = m / EP;
+        } else if (real_core_idx < rank_size) {
+            out_num = 0;
+            in_num = 0;
+            for (int32_t j = 0; j < local_expert_nums; j++) {
+                out_num += global_tokens_per_expert_matrix[rank * expert_nums + j + real_core_idx * local_expert_nums];
+            }
+            for (int32_t j = 0; j < local_expert_nums; j++) {
+                in_num += global_tokens_per_expert_matrix[real_core_idx * expert_nums + rank * local_expert_nums + j];
+            }
+        }
+
+        max_ub_ping_pong_size = max_ub_ping_pong_size / 2; // halved here, restored after the loop
+        int32_t receive_expert_id = 0;
+        int32_t receive_expert_token_nums;
+        int32_t last_ep_local = 0;
+        if (is_moe_averaged) {
+            receive_expert_token_nums = m / EP / local_expert_nums;
+            last_ep_local = (m / EP) * real_core_idx;
+        } else if (real_core_idx < rank_size) {
+            receive_expert_token_nums = global_tokens_per_expert_matrix[real_core_idx * expert_nums + rank * local_expert_nums];
+            for (int32_t i = 0; i < real_core_idx * local_expert_nums; i++) {
+                last_ep_local += global_tokens_per_expert_matrix[rank * expert_nums + i];
+            }
+        }
+
+        for (int32_t cal_idx = 0; cal_idx < cal_count; cal_idx++) {
+            int32_t flag_idx = cal_idx % MAX_BLOCK_COUNT;
+
+            SetAndWaitAivSync(flag_idx);
+            int32_t received_rank_num = 0;
+            if (is_moe_averaged) {
+                received_rank_num = rank_size;
+            } else {
+                for (int32_t i = 0; i < EP; i++) {
+                    int32_t in_num_tmp = 0;
+                    for (int32_t j = 0; j < local_expert_nums; j++) {
+                        in_num_tmp += global_tokens_per_expert_matrix[i * expert_nums + rank * local_expert_nums + j];
+                    }
+                    if (cal_idx * max_move_num < in_num_tmp) {
+                        received_rank_num += 1;
+                    }
+                }
+            }
+
+            received_loop_number += received_rank_num;
+
+            if (real_core_idx < rank_size) {
+                if (real_core_idx == rank) {
+                    SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_TWO_IDX, FLAG_VALUE);
+                }
+                if (is_moe_averaged || cal_idx * max_move_num < out_num) {
+                    int32_t data_len = ((cal_idx + 1) * max_move_num >= out_num) ? (out_num - cal_idx * max_move_num) : max_move_num;
+                    __gm__ float32_t *src_address = gm_quant_scale + 1LL * last_ep_local + sum_out;
+                    __gm__ float32_t *dst_address = (__gm__ float32_t *)buff[real_core_idx] + flag_idx * scale_pingpang_size + max_move_num * rank;
+                    CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_TWO_IDX, FLAG_VALUE * (cal_idx + 1));
+
+                    SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+                    SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+                    MoveResultFromSrcToDst(src_address, dst_address, data_len, 0);
+                    WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+                    WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+
+                    sum_out += data_len;
+                    SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ADD_IDX, FLAG_VALUE);
+                }
+                CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_ADD_IDX, FLAG_VALUE * received_loop_number);
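+                // Receive side: the contiguous chunk deposited by this peer can span several
+                // local experts, so it is split at expert boundaries before being copied out.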
+                if (is_moe_averaged || cal_idx * max_move_num < in_num) {
+                    int32_t data_len = ((cal_idx + 1) * max_move_num >= in_num) ? (in_num - cal_idx * max_move_num) : max_move_num;
+                    __gm__ float32_t *src_address = (__gm__ float32_t *)buff[rank] + flag_idx * scale_pingpang_size + max_move_num * real_core_idx;
+                    __gm__ float32_t *dst_address;
+
+                    while (receive_expert_id < local_expert_nums && data_len > 0) {
+                        int32_t move_data_len;
+                        if (data_len >= receive_expert_token_nums) {
+                            move_data_len = receive_expert_token_nums;
+                        } else {
+                            move_data_len = data_len;
+                        }
+
+                        if (is_moe_averaged) {
+                            dst_address = reinterpret_cast<__gm__ float32_t *>(workspace_info.gm_dequant_param) +
+                                1LL * (m / local_expert_nums) * receive_expert_id + 1LL * (m / expert_nums) * real_core_idx + sum_in;
+                        } else {
+                            int32_t before_expert_sum = 0;
+                            for (int32_t i = 0; i < receive_expert_id; i++) {
+                                for (int32_t j = 0; j < EP; j++) {
+                                    before_expert_sum += global_tokens_per_expert_matrix[j * expert_nums + i + rank * local_expert_nums];
+                                }
+                            }
+                            int32_t before_rank_in_expert_sum = 0;
+                            for (int32_t i = 0; i < real_core_idx; i++) {
+                                before_rank_in_expert_sum += global_tokens_per_expert_matrix[i * expert_nums + rank * local_expert_nums + receive_expert_id];
+                            }
+                            dst_address = reinterpret_cast<__gm__ float32_t *>(workspace_info.gm_dequant_param) +
+                                1LL * before_expert_sum + 1LL * before_rank_in_expert_sum + sum_in;
+                        }
+
+                        SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+                        SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+                        MoveResultFromSrcToDst(src_address, dst_address, move_data_len, 0);
+                        WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+                        WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+
+                        if (data_len >= receive_expert_token_nums) {
+                            receive_expert_id += 1;
+                            data_len -= receive_expert_token_nums;
+                            if (receive_expert_id > local_expert_nums) {
+                                break;
+                            }
+                            if (is_moe_averaged) {
+                                receive_expert_token_nums = m / EP / local_expert_nums;
+                            } else {
+                                receive_expert_token_nums = global_tokens_per_expert_matrix[real_core_idx * expert_nums + receive_expert_id + rank * local_expert_nums];
+                            }
+                            sum_in = 0;
+                        } else {
+                            sum_in += data_len;
+                            receive_expert_token_nums -= data_len;
+                            data_len = 0;
+                        }
+                        src_address += move_data_len;
+                    }
+                }
+            }
+        }
+
+        max_ub_ping_pong_size = max_ub_ping_pong_size * 2; // restore the full ping-pong size
+
+        if (real_core_idx < rank_size) {
+            if (real_core_idx == rank) {
+                SetBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_TWO_IDX, 0);
+            }
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_TWO_IDX, 0);
+        }
+        PipeBarrier();
+    }
+
+    template
+    inline __attribute__((always_inline)) __aicore__ void MoveResultFromSrcToDst(__gm__ CommType *gm_src, __gm__ CommType *gm_dst,
+        int32_t len, bool is_align = true)
+    {
+        __ubuf__ CommType *output_UB_T[2] = {(__ubuf__ CommType *)(32), (__ubuf__ CommType *)(97440)};
+        int32_t ping_pong_move_count = (len + max_ub_ping_pong_size - 1) / max_ub_ping_pong_size;
+        for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+            int32_t actual_move_size = max_ub_ping_pong_size;
+            if (move_idx == ping_pong_move_count - 1) {
+                actual_move_size = len - move_idx * max_ub_ping_pong_size;
+            }
+            auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+            auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+            WaitFlag(event_id);
+            if (is_align) {
+                CopyGmToUbuf(ub_buff_st, gm_src, 1, actual_move_size * sizeof(CommType) / 32, 0, 0);
+            } else {
+                CopyGmToUbufAlignB16(ub_buff_st, gm_src, 1, actual_move_size * sizeof(CommType), 0, 0);
+            }
+            SetFlag(event_id);
+            WaitFlag(event_id);
+            if (is_align) {
+                CopyUbufToGm(gm_dst, ub_buff_st, 1, actual_move_size * sizeof(CommType) / 32, 0, 0);
+            } else {
+                CopyUbufToGmAlignB16(gm_dst, ub_buff_st, 1, actual_move_size * sizeof(CommType), 0, 0);
+            }
+            gm_src += max_ub_ping_pong_size;
+            gm_dst += max_ub_ping_pong_size;
+            SetFlag(event_id);
+        }
+    }
+
+    inline __attribute__((always_inline)) __aicore__ void EndFlagsAndBias()
+    {
+        ResetIpcFlags(4);
+        if (real_core_idx < rank_size) {
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ZERO_IDX, 0);
+        }
+        PipeBarrier();
+        if constexpr (HAVE_BIAS) {
+            add_bias_runner.Run();
+        }
+    }
+
+    inline __attribute__((always_inline)) __aicore__ void Run(){
+        preprocessor.Run(local_expert_nums);
+        int32_t comm_m = p_value * m0;
+        int32_t comm_count;
+        if (is_moe_averaged) {
+            comm_count = DivCeil(m / EP, comm_m);
+        } else {
+            int32_t max_comm_count = 0;
+            int32_t max_input_per_ep = 0;
+            int32_t max_output_per_ep = 0;
+            for (int32_t ep_idx = 0; ep_idx < EP; ep_idx++) {
+                int32_t tmp_sum = 0;
+                for (int32_t i = 0; i < local_expert_nums; i++) {
+                    tmp_sum += global_tokens_per_expert_matrix[rank * expert_nums + ep_idx * local_expert_nums + i];
+                }
+                max_output_per_ep = max(max_output_per_ep, tmp_sum);
+                tmp_sum = 0;
+                for (int32_t i = 0; i < local_expert_nums; i++) {
+                    tmp_sum += global_tokens_per_expert_matrix[ep_idx * expert_nums + rank * local_expert_nums + i];
+                }
+                max_input_per_ep = max(max_input_per_ep, tmp_sum);
+                max_comm_count = max(max_comm_count, max(max_output_per_ep, max_input_per_ep));
+            }
+            comm_count = DivCeil(max_comm_count, comm_m);
+        }
+
+        int32_t out_num = 0; // tokens this rank sends to the target rank (real_core_idx)
+        int32_t before_rank_offset_src = 0; // offset of those tokens in this rank's source buffer
+        int32_t cur_local_expert_id = 0; // local expert currently being sent
+        int32_t cur_expert_len = 0; // token count of the local expert currently being sent
+        int32_t expert_remain_data_len;
+        if (real_core_idx < rank_size) {
+            if (is_moe_averaged) {
+                before_rank_offset_src = (m / rank_size) * real_core_idx;
+                out_num = (m / rank_size);
+                cur_expert_len = m / rank_size / local_expert_nums;
+            } else {
+                for (int32_t i = 0; i < real_core_idx; i++) {
+                    for (int32_t j = 0; j < local_expert_nums; j++) {
+                        before_rank_offset_src += global_tokens_per_expert_matrix[rank * expert_nums + j + i * local_expert_nums];
+                    }
+                }
+                for (int32_t i = 0; i < local_expert_nums; i++) {
+                    out_num += global_tokens_per_expert_matrix[rank * expert_nums + i + real_core_idx * local_expert_nums];
+                }
+                cur_expert_len = global_tokens_per_expert_matrix[rank * expert_nums + real_core_idx * local_expert_nums];
+            }
+        }
+        expert_remain_data_len = cur_expert_len;
+
+        if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+            ScaleAllToAll();
+        }
+
+        int32_t cur_expert = real_core_idx * local_expert_nums;
+        int32_t received_loop_number = 0;
+        int32_t sum_out_this_core = 0; // tokens already sent to the target rank
+        int32_t sum_in_expert = 0; // tokens of the current expert already sent
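+        // Main pipelined loop (as read from the synchronization calls below): MAX_BLOCK_COUNT
+        // shared-memory slots are cycled ping-pong style, and WaitEvent/SetAicSync hand each
+        // slot back and forth between this communication code on the vector cores and the
+        // matmul on the cube cores.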
+        for (int32_t comm_idx = 0; comm_idx < comm_count + MAX_BLOCK_COUNT; comm_idx++) {
+            uint64_t flag_idx = comm_idx % MAX_BLOCK_COUNT;
+            int32_t received_rank_num = 0;
+            if (is_moe_averaged) {
+                received_rank_num = rank_size;
+            } else {
+                for (int32_t i = 0; i < EP; i++) {
+                    int32_t in_loop_per_ep = 0;
+                    for (int32_t j = 0; j < local_expert_nums; j++) {
+                        in_loop_per_ep += global_tokens_per_expert_matrix[i * expert_nums + j + rank * local_expert_nums];
+                    }
+                    if (comm_idx * comm_m < in_loop_per_ep) {
+                        received_rank_num += 1;
+                    }
+                }
+            }
+            received_loop_number += received_rank_num;
+
+            if (comm_idx > 1) {
+                WaitEvent(flag_idx);
+            }
+            SetAndWaitAivSync(flag_idx);
+
+            if (real_core_idx < rank_size && comm_idx < comm_count) {
+                if (real_core_idx == rank) {
+                    SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_ZERO_IDX, FLAG_VALUE);
+                }
+                if (is_moe_averaged || comm_idx * comm_m < out_num) {
+                    int32_t data_len;
+                    if ((comm_idx + 1) * comm_m >= out_num) {
+                        data_len = out_num - comm_idx * comm_m;
+                    } else {
+                        data_len = comm_m;
+                    }
+
+                    __gm__ T *src_address, *dst_address;
+                    src_address = gm_out + 1LL * before_rank_offset_src * k_align + 1LL * comm_idx * comm_m * k_align;
+                    CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ZERO_IDX, FLAG_VALUE * (comm_idx + 1));
+
+                    // The data_len tokens may span several experts, so loop over expert boundaries.
+                    int32_t remain_data_len = data_len;
+                    while (cur_local_expert_id < local_expert_nums && remain_data_len > 0) {
+                        int32_t move_data_len;
+                        if (remain_data_len >= cur_expert_len - sum_in_expert) {
+                            move_data_len = cur_expert_len - sum_in_expert;
+                        } else {
+                            move_data_len = remain_data_len;
+                        }
+
+                        if (move_data_len > 0) {
+                            move_data_len = 1LL * move_data_len * k_align;
+                            // Key step: compute where this transfer lands in the target rank's shared buffer.
+                            int32_t before_expert_offset = 0; // offset of this expert in the target buffer for this round
+                            int32_t before_rank_offset = 0; // offset of this rank within the current expert's slice for this round
+                            for (int32_t i = 0; i < rank_size; i++) {
+                                // Tokens rank i sends to the target rank in total.
+                                int32_t out_this_rank = 0;
+                                for (int32_t j = 0; j < local_expert_nums; j++) {
+                                    int32_t expert_token_num;
+                                    if (is_moe_averaged) {
+                                        expert_token_num = m / expert_nums;
+                                    } else {
+                                        expert_token_num = global_tokens_per_expert_matrix[i * expert_nums + real_core_idx * local_expert_nums + j];
+                                    }
+                                    out_this_rank += expert_token_num;
+                                }
+
+                                int32_t data_len_this_rank;
+                                if ((comm_idx + 1) * comm_m >= out_this_rank) {
+                                    data_len_this_rank = out_this_rank - comm_idx * comm_m;
+                                } else {
+                                    data_len_this_rank = comm_m;
+                                }
+
+                                // Prefix sum of expert token counts.
+                                int32_t sum = 0;
+                                for (int32_t j = 0; j < cur_local_expert_id; j++) {
+                                    int32_t expert_id = real_core_idx * local_expert_nums + j;
+                                    // Tokens rank i sends to expert_id in total, clipped to this round below.
+                                    int32_t expert_token_num;
+                                    if (is_moe_averaged) {
+                                        expert_token_num = m / expert_nums;
+                                    } else {
+                                        expert_token_num = global_tokens_per_expert_matrix[i * expert_nums + expert_id];
+                                    }
+                                    if (comm_idx * comm_m < sum + expert_token_num && comm_idx * comm_m + data_len_this_rank > sum) {
+                                        int32_t tmp_len = min(comm_idx * comm_m + data_len_this_rank, sum + expert_token_num) -
+                                            max(comm_idx * comm_m, sum);
+                                        before_expert_offset += tmp_len;
+                                    }
+                                    sum += expert_token_num;
+                                }
+                                if (i < rank) {
+                                    int32_t expert_id = real_core_idx * local_expert_nums + cur_local_expert_id;
+                                    int32_t expert_token_num;
+                                    if (is_moe_averaged) {
+                                        expert_token_num = m / expert_nums;
+                                    } else {
+                                        expert_token_num = global_tokens_per_expert_matrix[i * expert_nums + expert_id];
+                                    }
+                                    if ((comm_idx * comm_m < sum + expert_token_num) && (comm_idx * comm_m + data_len_this_rank > sum)) {
+                                        int32_t tmp_len = min(comm_idx * comm_m + data_len_this_rank, sum + expert_token_num) -
+                                            max(comm_idx * comm_m, sum);
+                                        before_rank_offset += tmp_len;
+                                    }
+                                }
+                            }
+
+                            dst_address = buff[real_core_idx] + 1LL * flag_idx * gm_a_pingpong_size +
+                                1LL * before_expert_offset * k_align + 1LL * before_rank_offset * k_align;
+                            SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+                            SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+                            MoveResultFromSrcToDst(src_address, dst_address, move_data_len);
+                            WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+                            WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+                        }
+
+                        if (remain_data_len >= cur_expert_len - sum_in_expert) {
+                            cur_local_expert_id++;
+                            remain_data_len -= (cur_expert_len - sum_in_expert);
+                            if (is_moe_averaged) {
+                                cur_expert_len = m / expert_nums;
+                            } else if (cur_local_expert_id < local_expert_nums) {
+                                cur_expert_len = global_tokens_per_expert_matrix[rank * expert_nums + real_core_idx * local_expert_nums + cur_local_expert_id];
+                            }
+                            sum_in_expert = 0;
+                        } else {
+                            sum_in_expert += remain_data_len;
+                            remain_data_len = 0;
+                        }
+                        src_address += move_data_len;
+                    }
+
+                    SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[real_core_idx] + flag_offset + FLAG_ONE_IDX, FLAG_VALUE);
+                }
+                if (real_core_idx == rank) {
+                    CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + FLAG_ONE_IDX, FLAG_VALUE * received_loop_number);
+                }
+            }
+
+            SetAndWaitAivSync(flag_idx);
+            SetAicSync(flag_idx);
+        }
+
+        if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+            serial_pertoken_dequant_runner.Run();
+        }
+        EndFlagsAndBias();
+    }
+
+public:
+    using CocCommBase::SetAicSync;
+    using CocCommBase::SetAndWaitAivSync;
+
+    using CocCommBase::SetBuffFlag;
+    using CocCommBase::SetBuffFlagByAdd;
+    using CocCommBase::CheckBuffFlag;
+    using CocCommBase::ResetIpcFlags;
+    using CocCommBase::CrossRankSyncV1;
+    using CocCommBase::CrossRankSyncV2;
+
+    using CocCommBase::buff;
+    using CocCommBase::gm_out;
+    using CocCommBase::ctrl_flags_UB;
+    using CocCommBase::output_UB_T;
+    using CocCommBase::batch_size;
+    using CocCommBase::m;
+    using CocCommBase::k;
+    using CocCommBase::n;
+    using CocCommBase::m0;
+    using CocCommBase::k0;
+    using CocCommBase::n0;
+    using CocCommBase::m_loop;
+    using CocCommBase::n_loop;
+    using CocCommBase::k_loop;
+    using CocCommBase::core_loop;
+    using CocCommBase::real_core_idx;
+    using CocCommBase::core_num;
+    using CocCommBase::rank;
+    using CocCommBase::rank_size;
+    using CocCommBase::tiling_key;
+    using CocCommBase::swizzl_direct;
+    using CocCommBase::swizzl_count;
+    using CocCommBase::trans_a;
+    using CocCommBase::trans_b;
+    using CocCommBase::is_int8;
+    using CocCommBase::p_value;
+    using CocCommBase::aiv_idx;
+    using CocCommBase::other_rank;
+    using CocCommBase::max_ub_single_dma_size;
+    using CocCommBase::max_ub_ping_pong_size;
+    using CocCommBase::dequant_granularity;
+    using CocCommBase::dequant_group_size;
+    using CocCommBase::quant_granularity;
+    using CocCommBase::quant_group_size;
+    using CocCommBase::workspace_info;
+    using CocCommBase::withSerialMode;
+
+    using CocCommBase::num_local_tokens_per_expert;
+    using CocCommBase::num_global_tokens_per_local_expert;
+    using CocCommBase::global_tokens_per_expert_matrix;
+
+    using CocCommBase::local_expert_nums;
+    using CocCommBase::TP;
+    using CocCommBase::EP;
+    using CocCommBase::is_moe;
+    using CocCommBase::is_moe_averaged;
+    using CocCommBase::is_alltoallvc;
+    using CocCommBase::is_deterministic;
+    using CocCommBase::maxOutputSize;
+    using CocCommBase::weight_nz;
+
+    using CocCommBase::comm_npu_split;
+    using CocCommBase::comm_data_split;
+    using CocCommBase::comm_direct;
+    using CocCommBase::len_per_loop;
+    using CocCommBase::core_count;
+    using CocCommBase::flag_offset;
+
+    __gm__ int32_t *out_loop_per_ep;
+    __gm__ int32_t *in_loop_per_ep;
+    __gm__ int32_t *sum_num_local_tokens_per_expert;
+    __gm__ int32_t *sum_num_global_tokens_per_local_expert;
+    __gm__ int32_t *expert_comm_count_accum;
+
+    __gm__ float32_t *gm_quant_scale;
+
+
+
+    int32_t gm_a_pingpong_size;
+    int32_t m_align;
+    int32_t k_align;
+    int32_t n_align;
+    int32_t aligned_a;
+    int32_t aligned_b;
+
+    int32_t expert_nums;
+
+    Preprocessor preprocessor;
+    AllGatherMatmulBiasAdder add_bias_runner;
+    SerialPerTokenDequantRunner<T> serial_pertoken_dequant_runner;
+
+    bool need_dequant;
+};
+
+
+
+template <typename T>
+inline __aicore__ void CocAllToAllVAllGatherAiv(COC_ARGS_FUN(T)) {
+    AllToAllvAllGather alltoall_allgather_without_bias;
+    AllToAllvAllGather alltoall_allgather_with_bias;
+    AllToAllvAllGather alltoall_allgather_int8_without_bias;
+    AllToAllvAllGather alltoall_allgather_int8_with_bias;
+    SetAtomicNone();
+    SetMaskNormImpl();
+    SetSyncBaseAddr((uint64_t)ffts_addr);
+    SetVectorMask((uint64_t)-1, (uint64_t)-1);
+
+    auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm);
+    auto cocTilingData = &para->cocTilingData;
+    int32_t tiling_key = cocTilingData->tilingKey;
+    int32_t write_to_other_rank = cocTilingData->write2OtherRank;
+    switch (tiling_key) {
+        case 0b000000: case 0b100000: case 0b010000: case 0b110000:
+        case 0b001000: case 0b101000: case 0b011000: case 0b111000:
+            alltoall_allgather_without_bias.SetArgs(COC_ARGS_CALL());
+            alltoall_allgather_without_bias.Run();
+            break;
+        case 0b000010: case 0b100010: case 0b010010: case 0b110010:
+        case 0b001010: case 0b101010: case 0b011010: case 0b111010:
+            alltoall_allgather_with_bias.SetArgs(COC_ARGS_CALL());
+            alltoall_allgather_with_bias.Run();
+            break;
+        case 0b000100: case 0b100100: case 0b010100: case 0b110100:
+        case 0b001100: case 0b101100: case 0b011100: case 0b111100:
+            alltoall_allgather_int8_without_bias.SetArgs(COC_ARGS_CALL_INT8());
+            alltoall_allgather_int8_without_bias.Run();
+            break;
+        case 0b000110: case 0b100110: case 0b010110: case 0b110110:
+        case 0b001110: case 0b101110: case 0b011110: case 0b111110:
+            alltoall_allgather_int8_with_bias.SetArgs(COC_ARGS_CALL_INT8());
+            alltoall_allgather_int8_with_bias.Run();
+            break;
+        default:
+            break;
+    }
+    PipeBarrier();
+}
+
+#endif
diff --git a/comm/lcal/src/kernels/coc_alltoallv_allgather_matmul.cce b/comm/lcal/src/kernels/coc_alltoallv_allgather_matmul.cce
new file mode 100644
index 0000000000000000000000000000000000000000..ab8f4469c396ce9572afabd3ce5645932fa61a70
--- /dev/null
+++ b/comm/lcal/src/kernels/coc_alltoallv_allgather_matmul.cce
@@ -0,0 +1,44 @@
+#ifdef __CCE_KT_TEST__
+#define __aicore__
+#else
+#define __aicore__ [aicore]
+#endif
+
+
+
+#include "coc_ppmatmul_switch.cce"
+#include "coc_alltoallv_allgather.cce"
+#include "coc_alltoall_allgather_hidden.cce"
+#ifdef __DAV_C220_CUBE__
+
+
+#define COC_ALL_TO_ALL_ALL_GATHER_MATMUL_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllToAllVAllGatherMatmul_##type##_mix_aic(COC_ARGS_FUN(type)) { \
+    return CocPpmatmulSwitchAic(COC_ARGS_CALL()); \
+}
+#define COC_ALL_TO_ALL_ALL_GATHER_MATMUL_HIDDEN_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllToAllVAllGatherMatmulHidden_##type##_mix_aic(COC_ARGS_FUN(type)) { \
+    return CocPpmatmulSwitchAic(COC_ARGS_CALL()); \
+}
+
+
+#elif __DAV_C220_VEC__
+#define COC_ALL_TO_ALL_ALL_GATHER_MATMUL_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalAllToAllVAllGatherMatmul_##type##_mix_aiv(COC_ARGS_FUN(type)) { \
+    return CocAllToAllVAllGatherAiv(COC_ARGS_CALL()); \
+}
+#define COC_ALL_TO_ALL_ALL_GATHER_MATMUL_HIDDEN_FUNC_AUTO_DEF(type) \
"C" __global__ __aicore__ void LcalAllToAllVAllGatherMatmulHidden_##type##_mix_aiv(COC_ARGS_FUN(type)){ \ + return CocAllToAllVAllGatherHiddenAiv(COC_ARGS_CALL()); \ +} + +#endif + + +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // +#define COC_TYPE_FUNC(fun) fun(float16_t);fun(bfloat16_t) + +COC_TYPE_FUNC(COC_ALL_TO_ALL_ALL_GATHER_MATMUL_FUNC_AUTO_DEF); +COC_TYPE_FUNC(COC_ALL_TO_ALL_ALL_GATHER_MATMUL_HIDDEN_FUNC_AUTO_DEF); + +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_comm_base.cce b/comm/lcal/src/kernels/coc_comm_base.cce new file mode 100644 index 0000000000000000000000000000000000000000..5b3b1209efe96ef08166e55bb566e3492036e77d --- /dev/null +++ b/comm/lcal/src/kernels/coc_comm_base.cce @@ -0,0 +1,545 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef LCAL_COC_COMM_BASE_H +#define LCAL_COC_COMM_BASE_H + +#ifdef __DAV_C220_VEC__ + +#include "coc_internal.cce" +#include "coc_add_bias_runner.cce" +#include "coc_preprocessor.cce" +#include "coc_postprocessor.cce" +#include "tiling_args.h" +#include "lcoc_workspace.h" +template +class CocCommBase { +public: + __aicore__ explicit CocCommBase(){}; + + FORCE_INLINE_AICORE void SetArgs(COC_ARGS_FUN(T)) + { + CoCBuffAddrAndArgs coc_buff_and_args(COC_ARGS_CALL()); + for (int i=0; igm_out = gm_out; + max_ub_ping_pong_size = max_ub_ping_pong_size / n0 * n0; + loop_num_per_comm = p_value * get_block_num(); + gm_c_pingpong_size = m0 * n0 * loop_num_per_comm; + } + + FORCE_INLINE_AICORE void SetFromParam(__gm__ uint8_t *para_gm) + { + auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm); + auto cocTilingData = ¶->cocTilingData; + auto quantInfo = ¶->quantInfo; + auto twoDimTPInfo = ¶->twoDimTPInfo; + auto moeInfo = ¶->moeInfo; + batch_size = cocTilingData->batchSize; + m = cocTilingData->m; + k = cocTilingData->k; + n = cocTilingData->n; + + m0 = cocTilingData->m0; + k0 = cocTilingData->k0; + n0 = cocTilingData->n0; + + m_loop = cocTilingData->mLoop; + k_loop = cocTilingData->kLoop; + n_loop = cocTilingData->nLoop; + + core_loop = cocTilingData->coreLoop; + swizzl_count = cocTilingData->swizzlCount; + tiling_key = cocTilingData->tilingKey; + rank = cocTilingData->rank; + rank_size = cocTilingData->rankSize; + buffer_size = cocTilingData->bufferSize; + flag_offset = buffer_size * 1024 * 1024 / sizeof(int32_t);; + p_value = cocTilingData->pValue; + max_ub_single_dma_size = cocTilingData->ubMoveNum; + withSerialMode = cocTilingData->withSerialMode; + tag = cocTilingData->tag; + comm_npu_split = cocTilingData->commNpuSplit; + comm_data_split = cocTilingData->commDataSplit; + comm_direct = cocTilingData->commDirect; + len_per_loop = cocTilingData->lenPerLoop; + extra_ub_move_num = cocTilingData->extraUbMoveNum; + extra_comm_npu_split = cocTilingData->extraCommNpuSplit; + extra_comm_data_split = cocTilingData->extraCommDataSplit; + extra_comm_direct = cocTilingData->extraCommDirect; + extra_len_per_loop = 
+        is_91093 = cocTilingData->is91093;
+        core_count = comm_npu_split * comm_data_split;
+        dequant_granularity = static_cast<QuantGranularity>(quantInfo->dequantGranularity);
+        dequant_group_size = quantInfo->dequantGroupSize;
+        quant_granularity = static_cast<QuantGranularity>(quantInfo->quantGranularity);
+        quant_group_size = quantInfo->quantGroupSize;
+        swizzl_direct = (tiling_key & SWIZZL_MASK) ? true : false;
+        trans_a = (tiling_key & TRANS_A_MASK) ? true : false;
+        trans_b = (tiling_key & TRANS_B_MASK) ? true : false;
+        is_int8 = (tiling_key & INT8_MASK) ? true : false;
+
+        ag_dim = twoDimTPInfo->agDim;
+        rs_dim = twoDimTPInfo->rsDim;
+        inner_dim_is_Ag = twoDimTPInfo->innerDimIsAg;
+        weight_nz = para->weightNz;
+
+        local_expert_nums = moeInfo->local_expert_nums;
+        TP = moeInfo->TP;
+        EP = moeInfo->EP;
+        maxOutputSize = moeInfo->maxOutputSize;
+        is_moe = moeInfo->isMoe;
+    }
+
+    FORCE_INLINE_AICORE void SetWorkspace(__gm__ uint8_t *gm_workspace)
+    {
+        int32_t m_align, k_align, n_align;
+        if (is_int8) {
+            m_align = Block512B<int8_t>::AlignUp(m);
+            k_align = Block512B<int8_t>::AlignUp(k);
+            n_align = Block512B<int8_t>::AlignUp(n);
+        } else {
+            m_align = Block512B<T>::AlignUp(m);
+            k_align = Block512B<T>::AlignUp(k);
+            n_align = Block512B<T>::AlignUp(n);
+        }
+        int32_t aligned_a, aligned_b;
+        AlignJudge(trans_a, trans_b, m, k, n, m_align, k_align, n_align, aligned_a, aligned_b);
+
+        bool has_a_align = IsQuant(quant_granularity) || aligned_a;
+        bool has_b_align = (IsQuant(dequant_granularity) && !is_int8) || aligned_b;
+        bool has_accum = IsQuant(dequant_granularity) && is_int8 && (std::is_same<T, float16_t>::value || std::is_same<T, bfloat16_t>::value);
+        bool has_dequant_param = (dequant_granularity == QuantGranularity::PER_TOKEN || dequant_granularity == QuantGranularity::PER_TENSOR);
+        bool hasFormatDequantScale = (dequant_granularity == QuantGranularity::PER_CHANNEL);
+
+        if (weight_nz) {
+            aligned_b = 0;
+            has_b_align = false;
+        }
+        workspace_info = GetLcalWorkspaceInfo(gm_workspace, batch_size, m, k, n, m_align, k_align, n_align,
+            trans_a, trans_b, is_int8 ? 1 : 2, has_a_align, has_b_align, 0, has_accum, 0, has_dequant_param,
+            hasFormatDequantScale, is_deterministic, is_moe, is_alltoallvc, EP, local_expert_nums, maxOutputSize);
+
+    }
+
+    FORCE_INLINE_AICORE void SetAicSync(uint64_t flag_idx)
+    {
+        FFTSCrossCoreSync(2, flag_idx);
+    }
+
+    FORCE_INLINE_AICORE void SetAndWaitAivSync(uint64_t flag_idx, int32_t pipe_depth = 2)
+    {
+        FFTSCrossCoreSync(0, flag_idx + pipe_depth);
+        WaitEvent(flag_idx + pipe_depth);
+    }
+
+
+
+    FORCE_INLINE_AICORE void SetBuffFlag(__ubuf__ int32_t *ctrl_flags_UB, \
+        __gm__ int32_t *buff, int32_t flag)
+    {
+        *ctrl_flags_UB = flag;
+        SetFlag(EVENT_ID2);
+        WaitFlag(EVENT_ID2);
+        CopyUbufToGmAlignB16(buff, ctrl_flags_UB, 1, sizeof(int32_t), 0, 0);
+    }
+
+    FORCE_INLINE_AICORE void SetBuffFlagByAdd(__ubuf__ int32_t *ctrl_flags_UB, \
+        __gm__ int32_t *buff, int32_t flag)
+    {
+        PipeBarrier();
+        *ctrl_flags_UB = flag;
+        PipeBarrier();
+        SetAtomicAdd();
+        PipeBarrier();
+        CopyUbufToGmAlignB16(buff, ctrl_flags_UB, 1, sizeof(int32_t), 0, 0);
+        PipeBarrier();
+        SetAtomicNone();
+        PipeBarrier();
+    }
+
+    inline __aicore__ void call_dcci(__gm__ void *__restrict__ gm_ptr)
+    {
+        __asm__ __volatile__("");
+        dcci(gm_ptr, SINGLE_CACHE_LINE);
+        __asm__ __volatile__("");
+    }
+
+    FORCE_INLINE_AICORE void CheckBuffFlag(__ubuf__ int32_t *ctrl_flags_UB, \
+        __gm__ int32_t *buff, int32_t flag)
+    {
+        SetFlag(EVENT_ID1);
+        WaitFlag(EVENT_ID1);
+        while (true) {
+            CopyGmToUbufAlignB16(ctrl_flags_UB, buff, 1, sizeof(int32_t), 0, 0);
+            SetFlag(EVENT_ID3);
+            WaitFlag(EVENT_ID3); // Scalar waits for MTE2
+            if (*ctrl_flags_UB == flag) {
+                break;
+            }
+        }
+    }
+
+    FORCE_INLINE_AICORE void CrossRankSyncV1(int32_t flag_idx, int32_t flag_data)
+    {
+        if (aiv_idx == 0 && core_idx == rank) {
+            SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + flag_idx, FLAG_VALUE);
+        } else if (aiv_idx == 0 && core_idx < rank_size) {
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[core_idx] + flag_offset + flag_idx,
+                FLAG_VALUE * flag_data);
+        }
+    }
+
+    FORCE_INLINE_AICORE void CrossRankSyncV2(int32_t flag_idx, int32_t flag_data)
+    {
+        if (aiv_idx == 0 && core_idx < rank_size) {
+            SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[core_idx] + flag_offset + flag_idx, FLAG_VALUE);
+        }
+        if (aiv_idx == 0 && core_idx == rank) {
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + flag_idx,
+                FLAG_VALUE * rank_size * flag_data);
+        }
+    }
+
+    FORCE_INLINE_AICORE void CrossRankSyncV3(int32_t flag_idx, int32_t flag_data)
+    {
+        if (aiv_idx == 0 && core_idx == rank) {
+            SetBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + flag_idx, flag_data);
+        } else if (aiv_idx == 0 && core_idx < rank_size) {
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[core_idx] + flag_offset + flag_idx,
+                flag_data);
+        }
+    }
+
+    FORCE_INLINE_AICORE void CrossRankSyncV4(int32_t flag_idx, int32_t flag_data)
+    {
+        if (aiv_idx == 0 && core_idx < rank_size) {
+            if (core_idx != rank) {
+                SetBuffFlagByAdd(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + flag_idx, flag_data);
+            }
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[core_idx] + flag_offset + flag_idx, flag_data * rank_size);
+        }
+    }
+
+
+    FORCE_INLINE_AICORE void ResetIpcFlags(int32_t num_flags)
+    {
+        for (int32_t idx = 0; idx < num_flags; ++idx) {
+            if (core_idx == 0 && aiv_idx == 0) {
+                SetBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[rank] + flag_offset + idx, 0);
+            }
+        }
+    }
+
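+    // FillZero (below) splits data_size_remain across total_aiv vector cores: each core
+    // zero-fills its own slice of `output`, flushing a vector-dup'ed UB buffer to GM in
+    // chunks of at most repeat_time * 128 elements.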
+    FORCE_INLINE_AICORE void FillZero(int32_t data_size_remain, __gm__ T *output,
+        int32_t total_aiv, int32_t aiv_idx_in_clean)
+    {
+        int32_t repeat_time = 128;
+        int32_t num_per_call = repeat_time * 128;
+        // check whether T is float16_t
+        if constexpr (std::is_same<T, float16_t>::value) {
+            VectorDup(output_UB_T[0], static_cast<float16_t>(0), repeat_time, 1, 8);
+        }
+        // check whether T is bfloat16_t
+        else if constexpr (std::is_same<T, bfloat16_t>::value) {
+            VectorDup(output_UB_T[0], static_cast<bfloat16_t>(0), repeat_time, 1, 8);
+        }
+        SetFlag(EVENT_ID0);
+        WaitFlag(EVENT_ID0);
+
+        data_size_remain = DivCeil(data_size_remain, total_aiv);
+        data_size_remain = (data_size_remain + 15) / 16 * 16;
+        int32_t offset = aiv_idx_in_clean * data_size_remain;
+        while (data_size_remain > 0) {
+            int32_t data_size = data_size_remain < num_per_call ? data_size_remain : num_per_call;
+            CopyUbufToGm(output + offset, output_UB_T[0], 1, data_size * sizeof(T) / 32, 0, 0);
+            data_size_remain -= data_size;
+            offset += data_size;
+        }
+    }
+
+    FORCE_INLINE_AICORE void CopyUbToGmTransLayout(__ubuf__ T* ub_buff_st, int32_t actual_move_size, int64_t move_num_offset) {
+        auto ub_buff = ub_buff_st;
+        int32_t left_m = actual_move_size / n0;
+        while (left_m > 0) {
+            int32_t loop_idx = move_num_offset / (m0 * n0);
+            int64_t m_idx, n_idx;
+            GetBlockIdx(loop_idx, m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx);
+            int32_t actual_m = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0;
+            int32_t actual_n = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
+            int32_t m_offset = (move_num_offset % (m0 * n0)) / n0;
+            int32_t actual_move_m = m0 < m_offset + left_m ? m0 - m_offset : left_m;
+            if (m_offset < actual_m) {
+                actual_move_m = actual_m < m_offset + left_m ? actual_m - m_offset : left_m;
+                int64_t out_buff_offset = (m_idx * m0 + m_offset) * n + n_idx * n0;
+                CopyUbufToGmUnknown(n % BLOCK_SIZE_16 == 0, gm_out + out_buff_offset, ub_buff, actual_move_m, actual_n * sizeof(T),
+                    (n0 - actual_n) * sizeof(T) / 32, (n - actual_n) * sizeof(T));
+            }
+            left_m -= actual_move_m;
+            move_num_offset += actual_move_m * n0;
+            ub_buff += actual_move_m * n0;
+        }
+    }
+
+
+    FORCE_INLINE_AICORE void CopyGMToGM(__gm__ T* gm_src, __gm__ T* gm_dst, int32_t copy_size) {
+        auto ub0 = output_UB_T[0];
+        auto ub1 = output_UB_T[1];
+        int32_t interm_offset = 0;
+        for (int32_t move_idx = 0; interm_offset < copy_size; ++move_idx) {
+            uint32_t data_size = interm_offset + max_ub_ping_pong_size < copy_size ? max_ub_ping_pong_size : copy_size - interm_offset;
+            auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+            auto ub = (move_idx & 1) ? ub0 : ub1;
+            WaitFlag(event_id);
+            CopyGmToUbuf(ub, gm_src + interm_offset, 1, data_size * sizeof(T) / 32, 0, 0);
+            SetFlag(event_id);
+            WaitFlag(event_id);
+            CopyUbufToGm(gm_dst + interm_offset, ub, 1, data_size * sizeof(T) / 32, 0, 0);
+            SetFlag(event_id);
+            interm_offset += data_size;
+        }
+    }
+
+    // Linear core-split strategy: supports arbitrary partitioning schemes
+    FORCE_INLINE_AICORE void FirstStepInPeerMemSeq(int32_t data_size_remain, int32_t core_buff_offset) {
+        if (data_size_remain <= 0) {
+            return;
+        }
+        auto ub0 = output_UB_T[0];
+        auto ub1 = output_UB_T[1];
+        int32_t rank_per_core = (rank_size) / comm_npu_split;
+        int32_t core_rank_offset = (core_idx / comm_data_split) * rank_per_core; // each core copies from a different card
+
+        for (int32_t rank_idx = 0; rank_idx < rank_per_core; ++rank_idx) {
+            int32_t rank_idx_rot = (rank_idx + core_idx) % rank_per_core;
+            int32_t m_rank_idx = core_rank_offset + rank_idx_rot;
+            if (m_rank_idx == rank) {
+                continue;
+            }
+            if (is_91093 && (m_rank_idx % 2) != (rank % 2)) { // on 91093, only copy cards with the same parity as this rank
+                continue;
+            }
+            CopyGMToGM(buff[m_rank_idx] + core_buff_offset, buff[rank] + core_buff_offset, data_size_remain);
+        }
+    }
+
+    // Tree core-split strategy; only supports comm_npu_split = 1, and currently only 4 or 8 cards
+    FORCE_INLINE_AICORE void FirstStepInPeerMemTree(int32_t data_size_remain, int32_t core_buff_offset) {
+        if (data_size_remain <= 0) {
+            return;
+        }
+        int32_t rank_per_core = (rank_size) / comm_npu_split;
+        int32_t core_rank_offset = (core_idx / comm_data_split) * rank_per_core; // each core copies from a different card
+        // extra buffer: core_num * len_per_loop[20480] * (rank_size / 2) * sizeof(fp16) = 3932160 (4MB)
+        __gm__ T* gm_reducebuf = reinterpret_cast<__gm__ T *>(workspace_info.gm_reducebuf) + core_idx * len_per_loop * rank_size / 2; // each core uses a region of rank_size / 2 * len_per_loop elements
+
+        // 7 copies in total; the first 3 are plain moves without atomic add
+        SetAtomicNone();
+        int32_t rank_idx = 0; // initial NPU ID
+        int32_t turn_atomic_step = rank_size / 2 - 1; // rank_size == 8, step = 3; rank_size == 4, step = 1
+        for (int32_t visited = 0; visited < rank_size - 1; visited++) { // rank 8
+            if (visited == turn_atomic_step) { // after the first (rank_size / 2 - 1) copies, switch to atomic add
+                SetAtomicAdd();
+            }
+            int32_t rank_idx_rot = (rank_idx + core_idx) % rank_per_core; // actual NPU ID
+            if (rank_idx_rot == rank) {
+                rank_idx++;
+                rank_idx_rot = (rank_idx + core_idx) % rank_per_core;
+            }
+            if (is_91093 && (rank_idx_rot % 2) != (rank % 2)) { // on 91093, only copy cards with the same parity as this rank
+                continue;
+            }
+            // destinations: sources 0 1 2 and 3 4 5 go to the same slots; the last 3 copies use atomic add
+            auto gm_interm = gm_reducebuf + (visited % turn_atomic_step) * len_per_loop;
+            if (visited == rank_size - 2) { // last, atomic add to self peermem
+                gm_interm = buff[rank] + core_buff_offset;
+            }
+            auto gm_peer = buff[rank_idx_rot] + core_buff_offset;
+            CopyGMToGM(gm_peer, gm_interm, data_size_remain);
+            rank_idx++;
+        }
+        if (rank_size == 8) { // 8-rank tree accumulation
+            // interm[1] -> self peermem
+            CopyGMToGM(gm_reducebuf + 1 * len_per_loop, buff[rank] + core_buff_offset, data_size_remain);
+            // interm[2] -> interm[0]
+            CopyGMToGM(gm_reducebuf + 2 * len_per_loop, gm_reducebuf, data_size_remain);
+        }
+        if (rank_size >= 4) {
+            // interm[0] -> self peermem
+            CopyGMToGM(gm_reducebuf, buff[rank] + core_buff_offset, data_size_remain);
+        }
+
+    }
+
+    // Original strategy: each core handles the copy for one NPU
+    FORCE_INLINE_AICORE void FirstStepInPeerMem(int32_t data_size_remain, __gm__ T *input, __gm__ T *output, bool atomic_add = false) {
+        if (data_size_remain <= 0) {
+            return;
+        }
+        if (atomic_add) {
+            SetAtomicAdd();
+            PipeBarrier();
+        }
+        int32_t offset = 0;
+        SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+        SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+        CopyGMToGM(input, output, data_size_remain);
+        WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+        WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+        if (atomic_add) {
+            SetFlag(EVENT_ID0); // Scalar waits for MTE3
+            WaitFlag(EVENT_ID0);
+            SetAtomicNone();
+            PipeBarrier();
+        }
+    }
+
+    // FirstStepInPeerMem plus layout-transformed output
+    FORCE_INLINE_AICORE void FirstStepInPeerMemTransLayout(int32_t data_size_remain, __gm__ T *input, __gm__ T *output, int32_t out_offset = -1, bool atomic_add = false) {
+        if (data_size_remain <= 0) {
+            return;
+        }
+        if (atomic_add) {
+            SetAtomicAdd();
+            PipeBarrier();
+        }
+        int32_t offset = 0;
+        SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+        SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+        for (int32_t move_idx = 0; data_size_remain > 0; ++move_idx) {
+            uint32_t data_size = data_size_remain > max_ub_ping_pong_size ? max_ub_ping_pong_size : data_size_remain;
+            auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+            auto ub = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
+            WaitFlag(event_id);
+            CopyGmToUbuf(ub, input + offset, 1, data_size * sizeof(T) / 32, 0, 0);
+            SetFlag(event_id); // MTE3 waits for MTE2
+            WaitFlag(event_id);
+            CopyUbufToGm(output + offset, ub, 1, data_size * sizeof(T) / 32, 0, 0);
+            CopyUbToGmTransLayout(ub, data_size, out_offset + offset);
+            SetFlag(event_id); // MTE2 waits for MTE3
+            data_size_remain -= data_size;
+            offset += data_size;
+        }
+        WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+        WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+        if (atomic_add) {
+            SetFlag(EVENT_ID0); // Scalar waits for MTE3
+            WaitFlag(EVENT_ID0);
+            SetAtomicNone();
+            PipeBarrier();
+        }
+    }
+
+public:
+    __gm__ T *buff[LCAL_MAX_RANK_SIZE];
+    __gm__ T *gm_out;
+    __ubuf__ int32_t *ctrl_flags_UB = (__ubuf__ int32_t *)(0);
+    __ubuf__ T *output_UB_T[2] = {(__ubuf__ T *)(32), (__ubuf__ T *)(97440)};
+
+    __gm__ int32_t *num_local_tokens_per_expert;
+    __gm__ int32_t *num_global_tokens_per_local_expert;
+    __gm__ int32_t *global_tokens_per_expert_matrix;
+    int32_t expert_nums, local_expert_nums, TP, EP, maxOutputSize;
+    int32_t is_moe, is_moe_averaged, is_alltoallvc;
+
+    int32_t batch_size;
+    int32_t m;
+    int32_t k;
+    int32_t n;
+    int32_t m0;
+    int32_t k0;
+    int32_t n0;
+
+    int32_t m_loop;
+    int32_t n_loop;
+    int32_t k_loop;
+    int32_t core_loop;
+    int32_t core_idx;
+    int32_t real_core_idx;
+
+    int32_t rank;
+    int32_t rank_size;
+    int32_t buffer_size;
+    int32_t flag_offset;
+
+    int32_t tiling_key;
+    int32_t swizzl_count;
+    bool swizzl_direct;
+    bool trans_a;
+    bool trans_b;
+    bool is_int8;
+    bool is_91093;
+    int32_t p_value;
+
+    int32_t aiv_idx;
+    int32_t other_rank;
+    int32_t core_num;
+    int32_t max_ub_single_dma_size;
+    int32_t max_ub_ping_pong_size;
+    int32_t loop_num_per_comm;
+    int32_t gm_c_pingpong_size;
+    int32_t withSerialMode;
+    int32_t tag;
+    int32_t comm_npu_split;
+    int32_t comm_data_split;
+    int32_t comm_direct;
+    int32_t len_per_loop;
+    int32_t core_count;
+
+    int32_t extra_ub_move_num;
+    int32_t extra_comm_npu_split; // used by 2D-TP allReduce
+    int32_t extra_comm_data_split; // used by 2D-TP allReduce
+    int32_t extra_comm_direct; // used by 2D-TP allReduce
+    int32_t extra_len_per_loop; // used by 2D-TP allReduce
+    bool is_deterministic;
+
+    QuantGranularity dequant_granularity;
+    int32_t dequant_group_size;
+    QuantGranularity quant_granularity;
+    int32_t quant_group_size;
+
+    LcalWorkspaceInfo workspace_info;
+
+    int32_t ag_dim;
+    int32_t rs_dim;
+    bool inner_dim_is_Ag;
+    bool weight_nz{false};
+};
+
+#endif
+#endif
\ No newline at end of file
diff --git a/comm/lcal/src/kernels/coc_const_args.cce b/comm/lcal/src/kernels/coc_const_args.cce
new file mode 100644
index 0000000000000000000000000000000000000000..9832d2943253aa9c6a2cd49292dcb56714513edd
--- /dev/null
+++ b/comm/lcal/src/kernels/coc_const_args.cce
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_COC_CONST_ARGS_H
+#define LCAL_COC_CONST_ARGS_H
+#include <cstdint>
+#include "kernel_operator.h"
+using namespace AscendC;
+#ifndef FORCE_INLINE_AICORE
+#define FORCE_INLINE_AICORE inline __attribute__((always_inline)) __aicore__
+constexpr int32_t BLOCK_SIZE = 16;
+constexpr int LCAL_MAX_RANK_SIZE = 32;
+
+struct ExtraFlag {
+    static constexpr uint32_t RDMA = 1;
+    static constexpr uint32_t TOPO_910B2C = 1 << 1;
+    static constexpr uint32_t TOPO_910_93 = 1 << 2;
+    static constexpr uint32_t DETERMINISTIC = 1 << 3;
+    static constexpr uint32_t QUANT_FP16 = 1 << 4;
+    static constexpr uint32_t QUANT_FP32 = 1 << 5;
+};
+#endif
+constexpr int32_t AIV_FINISH_ALIGN_FLAG_ID = 8;
+constexpr int32_t AIC_FINISH_MATMUL_FLAG_ID = 9;
+constexpr int32_t AIV_FINISH_ADD_BIAS_FLAG_ID = 10;
+constexpr int32_t AIV_FINISH_DEQUANT_FLAG_ID = 11;
+constexpr int32_t AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID = 12;
+constexpr int32_t AIV_WAIT_AIC_FINISH_MATMUL_FLAG_ID = 13;
+
+constexpr int32_t A3_DIE_NUM = 2; // one card has two dies
+constexpr int32_t BLOCK_SIZE_16 = 16;
+constexpr int32_t BLOCK_SIZE_32 = 32;
+constexpr int32_t SWIZZL_MASK = 0b100000;
+constexpr int32_t TRANS_A_MASK = 0b010000;
+constexpr int32_t TRANS_B_MASK = 0b001000;
+constexpr int32_t INT8_MASK = 0b000100;
+constexpr int32_t BIAS_MASK = 0b000010;
+constexpr int32_t QUANT_MASK = 0x00FF0000;
+constexpr int32_t QUANT_SHIFT = 16;
+constexpr int32_t MAX_BLOCK_COUNT = 2;
+constexpr int32_t BLOCK_COUNT_3 = 3;
+constexpr int32_t BLOCK_COUNT_4 = 4;
+constexpr int32_t L0AB_PINGPONG_BUFFER_LEN = 16384; // 32 KB
+constexpr int32_t CUBE_MATRIX_SIZE = 256; // 16 * 16
+constexpr int64_t L1_PINGPONG_BUFFER_LEN = 131072; // 256 KB
+constexpr int32_t MAX_CORE_NUM = 25;
+constexpr int64_t MAX_UB_BUFF = 196608; // 192 * 1024 bytes
+constexpr int32_t ADD_REPEAT_TIME = 4;
+constexpr int32_t FLAG_ZERO_IDX = 0;
+constexpr int32_t FLAG_ONE_IDX = 1;
+constexpr int32_t FLAG_TWO_IDX = 2;
+constexpr int32_t FLAG_ADD_IDX = 3;
+constexpr int32_t MAX_FLAG_COUNT = 3 + ADD_REPEAT_TIME * 2;
+constexpr int32_t FLAG_VALUE = 1;
+
+constexpr int32_t VEC_BLOCK_PER_REPEAT = 8;
+constexpr uint8_t REPEAT_PER_LOOP = 255;
+constexpr uint32_t PPMATMUL_RUN_PURE_MATMUL = 1;
+constexpr uint32_t PPMATMUL_RUN_MATMUL_ALLREDUCE = 2;
+constexpr uint32_t PPMATMUL_RUN_MATMUL_REDUCE_SCATTER = 3;
+constexpr uint32_t PPMATMUL_RUN_ALL_GATHER_MATMUL = 4;
+constexpr uint32_t PPMATMUL_RUN_ALL_GATHER_MATMUL_V2 = 5;
+constexpr int32_t LCAL_2DTP_C_OFFSET = 100 * 1024 * 1024 / sizeof(half);
+constexpr uint32_t PPMATMUL_RUN_ALL_GATHER_MATMUL_REDUCE_SCATTER = 6;
+constexpr uint32_t PPMATMUL_RUN_ALL_GATHER_MATMUL_SIO = 7;
+constexpr int32_t HCCS_TOTAL_CORE_NUM = 8;
+constexpr int32_t SIO_TOTAL_CORE_NUM = 8;
+constexpr uint64_t WORKSPACE_REDUCE_SIZE = 4000000;
+constexpr int32_t TWOD_DATA_SPLIT_DEFAULT = 2;
+constexpr int32_t TWOD_LEN_PER_LOOP_DEFAULT = 5120;
+
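+// The mask constants above decode tiling_key as independent bit flags. For example,
+// tiling_key = 0b110010 appears among the "with bias" cases of the alltoallv-allgather
+// dispatch switch because it sets SWIZZL_MASK (swizzled tile order), TRANS_A_MASK
+// (transposed A) and BIAS_MASK (bias add); kernels recover their options with tests like:
+//     swizzl_direct = (tiling_key & SWIZZL_MASK) ? true : false;
+//     trans_a = (tiling_key & TRANS_A_MASK) ? true : false;
+//     is_int8 = (tiling_key & INT8_MASK) ? true : false;
+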
+constexpr uint32_t PPMATMUL_RUN_ALL_TO_ALL_ALL_GATHER_MATMUL = 13;
+constexpr uint32_t PPMATMUL_RUN_ALL_TO_ALL_ALL_GATHER_MATMUL_HIDDEN = 15;
+constexpr uint32_t PPMATMUL_RUN_MATMUL_REDUCE_SCATTER_ALL_TO_ALL_HIDDEN = 16;
+constexpr int LCAL_BUFF_BYTES = 204 * 1024 * 1024;
+constexpr int32_t FLAG_BUFF_BYTES = 5 * 512 * 1024; // 2.5 MB
+constexpr int32_t FLAG_OFFSET = (LCAL_BUFF_BYTES - FLAG_BUFF_BYTES) / sizeof(int32_t); // 201.5 * 1024 * 1024
+
+enum QuantGranularity : int {
+    QUANT_GRANULARITY_UNDEFINED = -1,
+    PER_TENSOR = 0,
+    PER_CHANNEL = 1,
+    PER_GROUP = 2,
+    PER_TOKEN = 3,
+    FLOAT32_SCALE_PER_CHANNEL = 4,
+    QUANT_GRANULARITY_MAX = 5,
+};
+
+
+template <typename T, size_t SIZE>
+struct BaseBlock {
+    static_assert((SIZE & (SIZE - 1)) == 0, "Invalid block size");
+    static constexpr size_t size = SIZE / sizeof(T);
+
+    static FORCE_INLINE_AICORE size_t Count(size_t len)
+    {
+        return (len + size - 1) / size;
+    }
+
+    static FORCE_INLINE_AICORE bool IsAligned(size_t len)
+    {
+        return len % size == 0;
+    }
+
+    static FORCE_INLINE_AICORE size_t AlignUp(size_t len)
+    {
+        return (len + size - 1) & ~(size - 1);
+    }
+
+    static FORCE_INLINE_AICORE size_t AlignDown(size_t len)
+    {
+        return len & ~(size - 1);
+    }
+};
+
+template <typename T>
+using Block32B = BaseBlock<T, 32>;
+
+template <typename T>
+using Block256B = BaseBlock<T, 256>;
+
+template <typename T>
+using Block512B = BaseBlock<T, 512>;
+
+template <typename T>
+struct CoCCommArgs {
+    int rank; // attr rank_id, global rank
+    int localRank;
+    int rankSize; // global rank size
+    int localRankSize;
+    uint32_t extraFlag;
+    __gm__ T *peerMems[LCAL_MAX_RANK_SIZE];
+    int64_t sendCountMatrix[LCAL_MAX_RANK_SIZE * LCAL_MAX_RANK_SIZE];
+};
+
+
+
+
+
+#endif // LCAL_COC_CONST_ARGS_H
\ No newline at end of file
diff --git a/comm/lcal/src/kernels/coc_dequant_runner.cce b/comm/lcal/src/kernels/coc_dequant_runner.cce
new file mode 100644
index 0000000000000000000000000000000000000000..b2f4a400b0ee0960fb9037aab6b6745d85540baa
--- /dev/null
+++ b/comm/lcal/src/kernels/coc_dequant_runner.cce
@@ -0,0 +1,1381 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef __COC_DEQUANTER__
+#define __COC_DEQUANTER__
+
+#ifdef __DAV_C220_VEC__
+
+#include <cstdint>
+#include "coc_internal.cce"
+
+template <typename ScaleType>
+class LoopDequanter {
+};
+
+template <>
+class LoopDequanter<float32_t> {
+public:
+    static constexpr int32_t max_len = 9792;
+
+    inline __aicore__ LoopDequanter() = default;
+
+    inline __aicore__ void SetForLoop()
+    {
+        SetFlag(EVENT_ID0);
+        SetFlag(EVENT_ID1);
+        SetFlag(EVENT_ID0);
+        SetFlag(EVENT_ID1);
+    }
+
+    inline __aicore__ void WaitForLoop()
+    {
+        WaitFlag(EVENT_ID0);
+        WaitFlag(EVENT_ID1);
+        WaitFlag(EVENT_ID0);
+        WaitFlag(EVENT_ID1);
+    }
+
+    inline __aicore__ void Loop(__gm__ bfloat16_t *dst, __gm__ int32_t *src, float32_t scale, int32_t offset,
+        int32_t n_rows_this_loop, int32_t n_cols_this_loop, int32_t src_stride, int32_t dst_stride)
+    {
+        is_ping = !is_ping;
+        auto ub_in = is_ping ? ub_in0 : ub_in1;
+        auto ub_out = is_ping ? ub_out0 : ub_out1;
+        auto event_id = is_ping ?
EVENT_ID0 : EVENT_ID1; + + int32_t n_blocks = Block32B::Count(n_cols_this_loop) * (sizeof(int32_t) / sizeof(bfloat16_t)); + int32_t ubuf_gap = n_blocks - Block32B::Count(n_cols_this_loop); + + WaitFlag(event_id); + CopyGmToUbufAlign(ub_in, src, n_rows_this_loop, n_cols_this_loop, src_stride - n_cols_this_loop, ubuf_gap); + SetFlag(event_id); + + WaitFlag(event_id); + Vadds(ub_adds, ub_in, offset, repeat, 1, 1, 8, 8); + SetFlag(event_id); + + PipeBarrier(); + + Vconv(ub_adds_f32, ub_adds, repeat, 1, 1, 8, 8); + + PipeBarrier(); + + Vmuls(ub_muls, ub_adds_f32, scale, repeat, 1, 1, 8, 8); + + PipeBarrier(); + + WaitFlag(event_id); + Vconv(ub_out, ub_muls, repeat, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(event_id); + + WaitFlag(event_id); + CopyUbufToGmAlign(dst, ub_out, n_rows_this_loop, n_cols_this_loop, dst_stride - n_cols_this_loop); + SetFlag(event_id); + } + +private: + static constexpr uint8_t repeat = 153; + __ubuf__ bfloat16_t *ub_out0 = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)0); + __ubuf__ bfloat16_t *ub_out1 = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)19584); + __ubuf__ float32_t *ub_adds_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)39936); + __ubuf__ int32_t *ub_in0 = reinterpret_cast<__ubuf__ int32_t *>((uintptr_t)79104); + __ubuf__ int32_t *ub_in1 = reinterpret_cast<__ubuf__ int32_t *>((uintptr_t)118272); + __ubuf__ int32_t *ub_adds = reinterpret_cast<__ubuf__ int32_t *>((uintptr_t)157440); + __ubuf__ float32_t *ub_muls = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)157440); + + bool is_ping = false; +}; + +template <> +class LoopDequanter { +public: + static constexpr int32_t max_len = 8192; + + inline __aicore__ LoopDequanter() = default; + + inline __aicore__ void SetForLoop() + { + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + SetFlag(EVENT_ID0); + } + + inline __aicore__ void WaitForLoop() + { + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + WaitFlag(EVENT_ID0); + } + + inline __aicore__ void Loop(__gm__ bfloat16_t *dst, __gm__ int32_t *src, __gm__ float32_t *scale, + int32_t n_rows_this_loop, int32_t n_cols_this_loop, int32_t src_stride, int32_t dst_stride) + { + is_ping = !is_ping; + auto ub_in = is_ping ? ub_in0 : ub_in1; + auto event_id = is_ping ? 
EVENT_ID0 : EVENT_ID1;
+
+        int32_t n_blocks = Block32B::Count(n_cols_this_loop) * (sizeof(int32_t) / sizeof(bfloat16_t));
+        int32_t ubuf_gap = n_blocks - Block32B::Count(n_cols_this_loop);
+
+        WaitFlag(event_id);
+        CopyGmToUbufAlign(ub_in, src, n_rows_this_loop, n_cols_this_loop, src_stride - n_cols_this_loop, ubuf_gap);
+        SetFlag(event_id);
+
+        WaitFlag(event_id);
+        Vconv(ub_in_f32, ub_in, repeat, 1, 1, 8, 8);
+        SetFlag(event_id);
+
+        WaitFlag(EVENT_ID2);
+        if (scale_rows == 0 || scale_source != scale) {
+            scale_rows = 1;
+            scale_source = scale;
+
+            CopyGmToUbufAlign(ub_scale, scale, 1, n_cols_this_loop, 0);
+        }
+        SetFlag(EVENT_ID2);
+
+        WaitFlag(EVENT_ID2);
+        for (; scale_rows < n_rows_this_loop; ++scale_rows) {
+            CopyUB2UB(ub_scale + scale_rows * n_blocks * Block32B::size, ub_scale,
+                0, 1, n_blocks, 0, 0);
+        }
+        PipeBarrier();
+
+        Vmul(ub_mul, ub_in_f32, ub_scale, repeat, 1, 1, 1, 8, 8, 8);
+        SetFlag(EVENT_ID2);
+
+        WaitFlag(EVENT_ID0);
+        Vconv(ub_out, ub_mul, repeat, 1, 1, 4, 8, RoundMode::CAST_RINT);
+        SetFlag(EVENT_ID0);
+
+        WaitFlag(EVENT_ID0);
+        CopyUbufToGmAlign(dst, ub_out, n_rows_this_loop, n_cols_this_loop, dst_stride - n_cols_this_loop);
+        SetFlag(EVENT_ID0);
+    }
+
+private:
+    static constexpr uint8_t repeat = 128;
+    __ubuf__ int32_t *ub_in0 = reinterpret_cast<__ubuf__ int32_t *>((uintptr_t)0);
+    __ubuf__ float32_t *ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)32768);
+    __ubuf__ float32_t *ub_in_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)65536);
+    __ubuf__ float32_t *ub_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)98560);
+    __ubuf__ bfloat16_t *ub_out = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)131328);
+    __ubuf__ int32_t *ub_in1 = reinterpret_cast<__ubuf__ int32_t *>((uintptr_t)163840);
+
+    __gm__ float32_t *scale_source = nullptr;
+    int32_t scale_rows = 0;
+    bool is_ping = false;
+};
+
+template <typename T>
+class LoopPerTokenDequanter {
+public:
+    static constexpr int32_t max_len = 8 * 32 / 4 * 128;
+
+    inline __aicore__ LoopPerTokenDequanter(int32_t n0)
+    {
+        n_round = (n0 + 127) / 128 * 128; // (n_this_loop + 127) / 128 is the required repeat count; each repeat occupies 8 blocks
+        ub_in0 = reinterpret_cast<__ubuf__ T *>((uintptr_t)0);
+        ub_in1 = reinterpret_cast<__ubuf__ T *>(ub_in0 + max_len);
+        ub_out = reinterpret_cast<__ubuf__ T *>(ub_in1 + max_len);
+        ub_scales = reinterpret_cast<__ubuf__ float32_t *>(ub_out + max_len);
+        ub_in_f32 = reinterpret_cast<__ubuf__ float32_t *>(ub_scales + max_len);
+        ub_out_f32 = reinterpret_cast<__ubuf__ float32_t *>(ub_in_f32 + max_len);
+    }
+
+    inline __aicore__ void SetForLoop()
+    {
+        SetFlag(EVENT_ID0);
+        SetFlag(EVENT_ID1);
+        SetFlag(EVENT_ID2);
+
+
+        SetFlag(EVENT_ID2);
+        SetFlag(EVENT_ID2);
+    }
+
+    inline __aicore__ void WaitForLoop()
+    {
+        WaitFlag(EVENT_ID0);
+        WaitFlag(EVENT_ID1);
+        WaitFlag(EVENT_ID2);
+
+
+        WaitFlag(EVENT_ID2);
+        WaitFlag(EVENT_ID2);
+    }
+
+    inline __aicore__ void Loop(__gm__ T *buff, __gm__ float32_t *scale,
+        int32_t n_rows_this_loop, int32_t n_cols_this_loop, int32_t stride)
+    {
+
+        is_ping = !is_ping;
+        auto ub_in = is_ping ? ub_in0 : ub_in1;
+        auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
+        int32_t ubufGap = Block32B::Count(n_round) - Block32B::Count(n_cols_this_loop);
+        WaitFlag(event_id);
+        CopyGmToUbufAlign(ub_in, buff, n_rows_this_loop, n_cols_this_loop, stride - n_cols_this_loop, ubufGap);
+        SetFlag(event_id);
+
+        WaitFlag(event_id);
+        Vconv(ub_in_f32, ub_in, repeat, 1, 1, 8, 4);
+        SetFlag(event_id);
+
+
+        WaitFlag(EVENT_ID2);
+        WaitFlag(EVENT_ID2);
+        if (scale_source != scale) {
+            scale_source = scale;
+            CopyGmToUbufAlign(ub_scales, scale, 1, n_rows_this_loop, 0);
+        }
+        SetFlag(EVENT_ID2);
+        SetFlag(EVENT_ID2);
+
+
+        WaitFlag(EVENT_ID2);
+        WaitFlag(EVENT_ID2); // note: this must be MTE2_S, not MTE2_V; otherwise zeros are read, corrupting the output
+        WaitFlag(EVENT_ID2);
+        PipeBarrier();
+        for (int32_t row = 0; row < n_rows_this_loop; ++row) {
+            float32_t scale = ub_scales[row];
+            Vmuls(ub_out_f32 + n_round * row, ub_in_f32 + n_round * row, scale, (n_cols_this_loop + 127) / 128 * 2, 1, 1, 8, 8);
+        }
+        PipeBarrier();
+        Vconv(ub_out, ub_out_f32, repeat, 1, 1, 4, 8, RoundMode::CAST_RINT);
+        SetFlag(EVENT_ID2);
+        SetFlag(EVENT_ID2);
+        SetFlag(EVENT_ID2);
+
+
+
+        WaitFlag(EVENT_ID2);
+        CopyUbufToGmAlign(buff, ub_out, n_rows_this_loop, n_cols_this_loop, stride - n_cols_this_loop, ubufGap);
+        SetFlag(EVENT_ID2);
+    }
+
+private:
+    static constexpr uint8_t repeat = 128;
+    __ubuf__ T *ub_in0 = nullptr;
+    __ubuf__ T *ub_in1 = nullptr;
+    __ubuf__ T *ub_out = nullptr;
+    __ubuf__ float32_t *ub_scales = nullptr;
+    __gm__ float32_t *scale_source = nullptr;
+    __ubuf__ float32_t *ub_in_f32 = nullptr;
+    __ubuf__ float32_t *ub_out_f32 = nullptr;
+    int32_t n_round;
+    bool is_ping = false;
+};
+
+class LoopScaleFormater {
+public:
+    static constexpr int32_t max_len = 8160;
+
+    inline __aicore__ LoopScaleFormater() = default;
+
+    inline __aicore__ void SetForLoop()
+    {
+        set_ctrl(sbitset1(get_ctrl(), 59));
+        SetFlag(EVENT_ID0);
+        SetFlag(EVENT_ID1);
+        SetFlag(EVENT_ID0);
+        SetFlag(EVENT_ID1);
+    }
+
+    inline __aicore__ void WaitForLoop()
+    {
+        WaitFlag(EVENT_ID0);
+        WaitFlag(EVENT_ID1);
+        WaitFlag(EVENT_ID0);
+        WaitFlag(EVENT_ID1);
+        set_ctrl(sbitset0(get_ctrl(), 59));
+    }
+
+    inline __aicore__ void Loop(__gm__ float32_t *dst, __gm__ int64_t *src, int32_t len)
+    {
+        is_ping = !is_ping;
+        auto ub_in = is_ping ? ub_in0 : ub_in1;
+        auto ub_vconv = is_ping ? ub_vconv0 : ub_vconv1;
+        auto ub_out = is_ping ? ub_out0 : ub_out1;
+        auto event_id = is_ping ? EVENT_ID0 : EVENT_ID1;
+
+        WaitFlag(event_id);
+        CopyGmToUbufAlign(ub_in, src, 1, len, 0);
+        SetFlag(event_id);
+
+        WaitFlag(event_id);
+        WaitFlag(event_id);
+        Vconv(ub_vconv, ub_in, repeat, 1, 1, 4, 8);
+        SetFlag(event_id);
+        SetFlag(event_id);
+
+        WaitFlag(event_id);
+        CopyUbufToGmAlign(dst, ub_out, 1, len, 0);
+        SetFlag(event_id);
+    }
+
+private:
+    static constexpr uint8_t repeat = 255;
+    __ubuf__ int64_t *ub_in0 = reinterpret_cast<__ubuf__ int64_t *>((uintptr_t)0);
+    __ubuf__ int64_t *ub_in1 = reinterpret_cast<__ubuf__ int64_t *>((uintptr_t)131072);
+    __ubuf__ int32_t *ub_vconv0 = reinterpret_cast<__ubuf__ int32_t *>((uintptr_t)65536);
+    __ubuf__ int32_t *ub_vconv1 = reinterpret_cast<__ubuf__ int32_t *>((uintptr_t)98304);
+    __ubuf__ float32_t *ub_out0 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)65536);
+    __ubuf__ float32_t *ub_out1 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)98304);
+
+    bool is_ping = false;
+};
+
+class BaseDequantRunner {
+public:
+
+    class TileLoopIter {
+    public:
+        inline __aicore__ TileLoopIter(int32_t m_this_tile, int32_t n_this_tile)
+        {
+            m_this_subcore = m_this_tile >> 1;
+            n_this_subcore = n_this_tile;
+            if (get_subblockid() == 1) {
+                m_offset_this_subcore = m_this_subcore;
+                m_this_subcore += m_this_tile & 1;
+            } else {
+                m_offset_this_subcore = 0;
+            }
+        }
+
+        inline __aicore__ void Init(int32_t max_len)
+        {
+            int32_t max_m_per_loop = max_len / Block32B::AlignUp(n_this_subcore);
+            m_complete = 0;
+            m_this_loop = max_m_per_loop > m_this_subcore ? m_this_subcore : max_m_per_loop;
+            n_this_loop = n_this_subcore;
+        }
+        inline __aicore__ void Init(int32_t max_len, int32_t n0) // max_len = 8192 or 9792
+        {
+            // Block32B::AlignUp: round up to a multiple of 32 / sizeof(half), i.e. a multiple of 16
+            // upper bound of m_this_subcore: max_len / n_this_subcore, 16384 / 256 = 64
+            int32_t max_m_per_loop = max_len / ((n0 + 127) / 128 * 128);
+            m_complete = 0;
+            m_this_loop = max_m_per_loop > m_this_subcore ? m_this_subcore : max_m_per_loop; // m handled in this loop, at most max_m_per_loop
+            n_this_loop = n_this_subcore; // n handled in this loop
+        }
+
+        inline __aicore__ bool End()
+        {
+            return m_complete >= m_this_subcore;
+        }
+
+        inline __aicore__ void Next()
+        {
+            m_complete += m_this_loop;
+            if (End()) {
+                return;
+            }
+            if (m_complete + m_this_loop > m_this_subcore) {
+                m_this_loop = m_this_subcore - m_complete;
+            }
+        }
+
+        inline __aicore__ int32_t m_offset_in_tile() const
+        {
+            return m_offset_this_subcore + m_complete;
+        }
+
+        int32_t m_this_subcore;
+        int32_t n_this_subcore;
+
+        int32_t m_this_loop;
+        int32_t n_this_loop;
+
+        int32_t m_offset_this_subcore;
+        int32_t m_complete;
+    };
+    __aicore__ explicit BaseDequantRunner() = default;
+
+    inline __aicore__ void SetArgs(__gm__ bfloat16_t *gm_out, const LcalWorkspaceInfo &workspace_info,
+        __gm__ int64_t *gm_dequant_scale, __gm__ int32_t *gm_dequant_offset, QuantGranularity dequant_granularity,
+        int32_t batch_size, int32_t m, int32_t n)
+    {
+        this->gm_accum = reinterpret_cast<__gm__ int32_t *>(workspace_info.gm_accum);
+        this->gm_format_dequant_scale = reinterpret_cast<__gm__ float32_t *>(workspace_info.gm_formate_dequant_scale);
+        this->gm_out = gm_out;
+
+        this->gm_dequant_scale = gm_dequant_scale;
+        this->gm_dequant_offset = gm_dequant_offset;
+        this->dequant_granularity = dequant_granularity;
+
+        this->batch_size = batch_size;
+        this->m = m;
+        this->n = n;
+
+        if (dequant_granularity == QuantGranularity::PER_TENSOR) {
+            gm_format_dequant_scale = reinterpret_cast<__gm__ float32_t *>(gm_dequant_scale);
+        } else if (dequant_granularity == QuantGranularity::PER_CHANNEL) {
+            FormatScale();
+        } else {
+            gm_format_dequant_scale = reinterpret_cast<__gm__ float32_t *>(gm_dequant_scale);
+        }
+    }
+
+    inline __aicore__ void FormatScale()
+    {
+        // if (dequant_granularity != QuantGranularity::PER_CHANNEL) {
+        //     return;
+        // }
+
+        int32_t align_core_idx = get_block_idx() * get_subblockdim() + get_subblockid();
+        int32_t align_core_num = get_block_num() * get_subblockdim();
+
+        int32_t len = LoopScaleFormater::max_len;
+        int32_t loop_num = DivCeil(n, len);
+        LoopScaleFormater loop_scale_formater;
+        loop_scale_formater.SetForLoop();
+        for (int32_t i = align_core_idx; i < loop_num; i += align_core_num) {
+            int32_t offset = i * len;
+            if (offset + len > n) {
+                len = n - offset;
+            }
+            loop_scale_formater.Loop(gm_format_dequant_scale + offset, gm_dequant_scale + offset, len);
+        }
+        loop_scale_formater.WaitForLoop();
+
+        Barrier();
+    }
+
+protected:
+    inline __aicore__ void Barrier()
+    {
+        FFTSCrossCoreSync(0, AIV_FINISH_DEQUANT_FLAG_ID);
+        WaitEvent(AIV_FINISH_DEQUANT_FLAG_ID);
+    }
+
+    __gm__ int32_t *gm_accum;
+    __gm__ bfloat16_t *gm_out;
+
+    __gm__ int64_t *gm_dequant_scale;
+    __gm__ int32_t *gm_dequant_offset;
+    QuantGranularity dequant_granularity;
+
+    __gm__ float32_t *gm_format_dequant_scale;
+
+    int32_t batch_size;
+    int32_t m;
+    int32_t k;
+    int32_t n;
+};
+
+class SerialDequantRunner : public BaseDequantRunner {
+public:
+    class LoopIter {
+    public:
+        inline __aicore__ LoopIter(int32_t batch_size, int32_t n_rows, int32_t n_cols) :
+            batch_size(batch_size), n_rows(n_rows), n_cols(n_cols)
+        {
+            int32_t align_core_num = get_block_num() * get_subblockdim();
+            int32_t align_core_idx = get_block_idx() * get_subblockdim() + get_subblockid();
+            int32_t n_rows_per_core_base = n_rows / align_core_num;
+            int32_t n_rows_remainder = n_rows % align_core_num;
+            int32_t row_offset_base = align_core_idx * n_rows_per_core_base;
+            if (align_core_idx < n_rows_remainder) {
+                n_rows_this_core
= n_rows_per_core_base + 1; + row_offset_this_core = row_offset_base + align_core_idx; + } else { + n_rows_this_core = n_rows_per_core_base; + row_offset_this_core = row_offset_base + n_rows_remainder; + } + n_cols_this_core = n_cols; + col_offset_this_core = 0; + + core_offset = row_offset_this_core * n_cols; + } + + inline __aicore__ void InitBatchLoop() + { + batch_idx = 0; + batch_offset = 0; + } + + inline __aicore__ bool EndBatchLoop() const + { + return batch_idx == batch_size; + } + + inline __aicore__ void NextBatchLoop() + { + ++batch_idx; + if (EndBatchLoop()) { + return; + } + batch_offset = static_cast(batch_idx) * n_rows * n_cols; + } + + inline __aicore__ void InitRowLoop(int32_t max_rows_per_loop) + { + n_rows_complete = 0; + n_rows_this_loop = (n_rows_this_core < max_rows_per_loop) ? n_rows_this_core : max_rows_per_loop; + row_offset = 0; + } + + inline __aicore__ bool EndRowLoop() const + { + return n_rows_complete == n_rows_this_core; + } + + inline __aicore__ void NextRowLoop() + { + n_rows_complete += n_rows_this_loop; + if (EndRowLoop()) { + return; + } + if (n_rows_complete + n_rows_this_loop > n_rows_this_core) { + n_rows_this_loop = n_rows_this_core - n_rows_complete; + } + row_offset = n_rows_complete; + } + + inline __aicore__ void InitColLoop(int32_t max_cols_per_loop) + { + n_cols_complete = 0; + n_cols_this_loop = (n_cols < max_cols_per_loop) ? n_cols : max_cols_per_loop; + col_offset = 0; + } + + inline __aicore__ bool EndColLoop() const + { + return n_cols_complete == n_cols_this_core; + } + + inline __aicore__ void NextColLoop() + { + n_cols_complete += n_cols_this_loop; + if (EndColLoop()) { + return; + } + if (n_cols_complete + n_cols_this_loop > n_cols_this_core) { + n_cols_this_loop = n_cols_this_core - n_cols_complete; + } + col_offset = n_cols_complete; + } + + inline __aicore__ int64_t offset() const + { + return core_offset + row_offset * n_cols + col_offset; + } + + int32_t batch_size; + int32_t n_rows; + int32_t n_cols; + + int32_t n_rows_this_core; + int32_t n_cols_this_core; + int64_t row_offset_this_core; + int64_t col_offset_this_core; + + int32_t batch_idx; + int32_t n_rows_complete; + int32_t n_cols_complete; + + int32_t n_rows_this_loop; + int32_t n_cols_this_loop; + + int64_t core_offset; + int64_t batch_offset; + int64_t row_offset; + int64_t col_offset; + }; + + __aicore__ explicit SerialDequantRunner() = default; + + inline __aicore__ void Run() + { + switch (dequant_granularity) { + case QuantGranularity::PER_TENSOR: + DequantPerTensor(); + break; + case QuantGranularity::PER_CHANNEL: + DequantPerChannel(); + break; + case QuantGranularity::PER_TOKEN: + DequantPerChannel(); + break; + case QuantGranularity::FLOAT32_SCALE_PER_CHANNEL: + DequantPerChannel(); + break; + default: + break; + } + + Barrier(); + } + +private: + inline __aicore__ void DequantPerTensor() + { + float32_t scale = gm_format_dequant_scale[0]; + + const auto max_len = LoopDequanter::max_len; + int32_t n_round = Block32B::AlignUp(n); + int32_t max_m_per_loop = (n_round <= max_len) ? (max_len / n_round) : 1; + int32_t max_n_per_loop = (n_round <= max_len) ? 
n : max_len; + + LoopIter it(batch_size, m, n); + LoopDequanter loop_dequanter; + loop_dequanter.SetForLoop(); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_n_per_loop); !it.EndColLoop(); it.NextColLoop()) { + for (it.InitRowLoop(max_m_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto dst = gm_out + it.offset(); + auto src = gm_accum + it.offset(); + loop_dequanter.Loop(dst, src, scale, 0, it.n_rows_this_loop, it.n_cols_this_loop, n, n); + } + } + } + loop_dequanter.WaitForLoop(); + } + + inline __aicore__ void DequantPerChannel() + { + const auto max_len = LoopDequanter::max_len; + int32_t n_round = Block32B::AlignUp(n); + int32_t max_m_per_loop = (n_round <= max_len) ? (max_len / n_round) : 1; + int32_t max_n_per_loop = (n_round <= max_len) ? n : max_len; + + LoopIter it(batch_size, m, n); + LoopDequanter loop_dequanter; + loop_dequanter.SetForLoop(); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_n_per_loop); !it.EndColLoop(); it.NextColLoop()) { + for (it.InitRowLoop(max_m_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto dst = gm_out + it.offset(); + auto src = gm_accum + it.offset(); + //auto src = gm_accum; + auto scale = gm_format_dequant_scale + it.col_offset; + loop_dequanter.Loop(dst, src, scale, it.n_rows_this_loop, it.n_cols_this_loop, n, n); + } + } + } + loop_dequanter.WaitForLoop(); + } + +}; + + + +template +class SerialPerTokenDequantRunner : public SerialDequantRunner{ +public: + __aicore__ explicit SerialPerTokenDequantRunner() = default; + inline __aicore__ void SetArgs(__gm__ T *gm_out, + __gm__ float32_t *gm_dequant_scale_pertoken, int32_t m, int32_t n, int32_t m0, int32_t n0) + { + this->gm_out = reinterpret_cast<__gm__ T *>(gm_out); + this->gm_dequant_scale_pertoken = reinterpret_cast<__gm__ float32_t *>(gm_dequant_scale_pertoken); + this->m = m; + this->n = n; + this->m0 = m0; + this->n0 = n0; + } + + inline __aicore__ void Run() { + const auto max_len = LoopPerTokenDequanter::max_len; + int32_t max_m_per_loop = max_len / ((n0 + 127) / 128 * 128); + LoopIter it(1, m, n); + LoopPerTokenDequanter loop_dequanter(n0); + loop_dequanter.SetForLoop(); + for (it.InitRowLoop(max_m_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + for (it.InitColLoop(n0); !it.EndColLoop(); it.NextColLoop()) { + __gm__ T * dst_add = gm_out + it.offset(); + __gm__ float32_t * scale = gm_dequant_scale_pertoken + it.row_offset + it.row_offset_this_core; + loop_dequanter.Loop(dst_add, scale, it.n_rows_this_loop, it.n_cols_this_loop, n); + } + } + loop_dequanter.WaitForLoop(); + } + + +private: + __gm__ T *gm_out; + __gm__ float32_t *gm_dequant_scale_pertoken; + int32_t m; + int32_t n; + int32_t m0; + int32_t n0; +}; + + + +class FusedDequantRunner : public BaseDequantRunner { +public: + __aicore__ explicit FusedDequantRunner() = default; + inline __aicore__ void SetArgs(__gm__ bfloat16_t *gm_out, const LcalWorkspaceInfo &workspace_info, + __gm__ int64_t *gm_dequant_scale, __gm__ int32_t *gm_dequant_offset, QuantGranularity dequant_granularity, + int32_t batch_size, int32_t m, int32_t n, int32_t m0, int32_t n0, int32_t m_loop, int32_t n_loop, + int32_t core_loop, int32_t swizzl_direct, int32_t swizzl_count, int32_t p_value, int32_t rank_size) + { + BaseDequantRunner::SetArgs(gm_out, workspace_info, gm_dequant_scale, gm_dequant_offset, dequant_granularity, + batch_size, m, n); + + //cit.SetArgs(m, n, m0, n0, m_loop, n_loop, core_loop, swizzle_direct, swizzle_count, p_value); + core_num = 
get_block_num(); + core_idx = get_block_idx(); + this -> m0 = m0; + this -> n0 = n0; + this -> m_loop = m_loop; + this -> n_loop = n_loop; + this -> core_loop = core_loop; + this->swizzl_direct = swizzl_direct; + this->swizzl_count = swizzl_count; + + this->loop_num_per_comm = p_value * core_num; + this -> p_value = p_value; + this -> rank_size = rank_size; + + } + + inline __aicore__ void RunDequantAllReduce(int32_t cal_idx) + { + switch (dequant_granularity) { + case QuantGranularity::PER_TENSOR: + DequantAllReducePerTensor(cal_idx); + return; + case QuantGranularity::PER_CHANNEL: + DequantAllReducePerChannel(cal_idx); + return; + case QuantGranularity::PER_TOKEN: + DequantAllReducePerChannel(cal_idx); + return; + case QuantGranularity::FLOAT32_SCALE_PER_CHANNEL: + DequantAllReducePerChannel(cal_idx); + return; + default: + return; + } + } + + + + + inline __aicore__ void DequantAllReducePerChannel(int32_t cal_idx) + { + LoopDequanter loop_dequanter; + loop_dequanter.SetForLoop(); + //int32_t pipe_depth = is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT; + int32_t pipe_depth = MAX_BLOCK_COUNT; + int32_t flag_idx = cal_idx % pipe_depth; + int32_t loop_idx = cal_idx * core_num + core_idx; + for (int32_t p = 0; p < p_value; p++) { + int loop_idx = cal_idx * p_value * core_num + p * core_num + core_idx; + if (loop_idx >= core_loop) + break; + int64_t m_idx, n_idx; + GetBlockIdx(loop_idx, m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + TileLoopIter tit(m_actual, n_actual); + int64_t offset_this_tile = flag_idx * loop_num_per_comm * m0 * n0 + + (loop_idx % loop_num_per_comm) * m0 * n0; + for (tit.Init(LoopDequanter::max_len); !tit.End(); tit.Next()) { + int64_t src_offset = offset_this_tile + tit.m_offset_in_tile() * n0; + int64_t dst_offset = offset_this_tile + tit.m_offset_in_tile() * n0; + auto accum = gm_accum + src_offset; + auto out = gm_out + dst_offset; + auto scale = gm_format_dequant_scale + n_idx * n0; + loop_dequanter.Loop(out, accum, scale, tit.m_this_loop, tit.n_this_loop, n0, n0); + } + } + loop_dequanter.WaitForLoop(); + } + + inline __aicore__ void DequantAllReducePerTensor(int32_t cal_idx) + { + LoopDequanter loop_dequanter; + float32_t scale = gm_format_dequant_scale[0]; + loop_dequanter.SetForLoop(); + //int32_t pipe_depth = is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT; + int32_t pipe_depth = MAX_BLOCK_COUNT; + int32_t flag_idx = cal_idx % pipe_depth; + int32_t loop_idx = cal_idx * core_num + core_idx; + for (int32_t p = 0; p < p_value; p++) { + int loop_idx = cal_idx * p_value * core_num + p * core_num + core_idx; + if (loop_idx >= core_loop) + break; + int64_t m_idx, n_idx; + GetBlockIdx(loop_idx, m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? 
(n - n_idx * n0) : n0; + TileLoopIter tit(m_actual, n_actual); + int64_t offset_this_tile = flag_idx * loop_num_per_comm * m0 * n0 + + (loop_idx % loop_num_per_comm) * m0 * n0; + for (tit.Init(LoopDequanter::max_len); !tit.End(); tit.Next()) { + int64_t src_offset = offset_this_tile + tit.m_offset_in_tile() * n0; + int64_t dst_offset = offset_this_tile + tit.m_offset_in_tile() * n0; + auto accum = gm_accum + src_offset; + auto out = gm_out + dst_offset; + loop_dequanter.Loop(out, accum, scale, 0, tit.m_this_loop, tit.n_this_loop, n0, n0); + } + } + loop_dequanter.WaitForLoop(); + } + + inline __aicore__ void RunDequantReduceScatter(int32_t cal_idx) + { + switch (dequant_granularity) { + case QuantGranularity::PER_TENSOR: + DequantReduceScatterPerTensor(cal_idx); + return; + case QuantGranularity::PER_CHANNEL: + DequantReduceScatterPerChannel(cal_idx); + return; + case QuantGranularity::PER_TOKEN: + DequantReduceScatterPerChannel(cal_idx); + return; + case QuantGranularity::FLOAT32_SCALE_PER_CHANNEL: + DequantReduceScatterPerChannel(cal_idx); + return; + default: + return; + } + } + + inline __aicore__ void DequantReduceScatterPerChannel(int32_t cal_idx) + { + LoopDequanter loop_dequanter; + loop_dequanter.SetForLoop(); + int32_t m_loop_per_rank = m_loop / rank_size; + //int32_t pipe_depth = is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT; + int32_t pipe_depth = MAX_BLOCK_COUNT; + int32_t flag_idx = cal_idx % pipe_depth; + int32_t comm_num = DivCeil(core_loop, loop_num_per_comm); + int32_t actual_loop_num = loop_num_per_comm; + if (cal_idx == comm_num - 1) { + actual_loop_num = core_loop - cal_idx * loop_num_per_comm; + } + + for (int32_t p = 0; p < p_value; p++) { + int loop_idx = cal_idx * p_value * core_num + p * core_num + core_idx; + if (loop_idx >= core_loop) + break; + + int32_t in_batch_idx = loop_idx % (m_loop * n_loop); + int64_t rank_idx = in_batch_idx % rank_size; + int32_t in_rank_idx = in_batch_idx / rank_size; + + int64_t m_idx, n_idx; + GetBlockIdx(in_rank_idx, m_loop_per_rank, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t m_actual = (m_idx == (m_loop_per_rank - 1)) ? (m / rank_size - m_idx * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + + TileLoopIter tit(m_actual, n_actual); + int64_t rank_offset_c = (loop_idx % rank_size) * (actual_loop_num / rank_size) * m0 * n0; + int64_t offset_this_tile = flag_idx * m0 * loop_num_per_comm * n0 + rank_offset_c + + + ((loop_idx % loop_num_per_comm) / rank_size) * m0 * n0; + + for (tit.Init(LoopDequanter::max_len); !tit.End(); tit.Next()) { + int64_t src_offset = offset_this_tile + tit.m_offset_in_tile() * n0; + int64_t dst_offset = offset_this_tile + tit.m_offset_in_tile() * n0; + auto accum = gm_accum + src_offset; + auto out = gm_out + dst_offset; + auto scale = gm_format_dequant_scale + n_idx * n0; + loop_dequanter.Loop(out, accum, scale, tit.m_this_loop, tit.n_this_loop, n0, n0); + } + } + loop_dequanter.WaitForLoop(); + } + + inline __aicore__ void DequantReduceScatterPerTensor(int32_t cal_idx) + { + LoopDequanter loop_dequanter; + loop_dequanter.SetForLoop(); + float32_t scale = gm_format_dequant_scale[0]; + int32_t m_loop_per_rank = m_loop / rank_size; + //int32_t pipe_depth = is_91093 ? 
BLOCK_COUNT_4 : MAX_BLOCK_COUNT; + int32_t pipe_depth = MAX_BLOCK_COUNT; + int32_t flag_idx = cal_idx % pipe_depth; + int32_t comm_num = DivCeil(core_loop, loop_num_per_comm); + int32_t actual_loop_num = loop_num_per_comm; + if (cal_idx == comm_num - 1) { + actual_loop_num = core_loop - cal_idx * loop_num_per_comm; + } + for (int32_t p = 0; p < p_value; p++) { + int loop_idx = cal_idx * p_value * core_num + p * core_num + core_idx; + if (loop_idx >= core_loop) + break; + + int32_t in_batch_idx = loop_idx % (m_loop * n_loop); + int64_t rank_idx = in_batch_idx % rank_size; + int32_t in_rank_idx = in_batch_idx / rank_size; + + int64_t m_idx, n_idx; + GetBlockIdx(in_rank_idx, m_loop_per_rank, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t m_actual = (m_idx == (m_loop_per_rank - 1)) ? (m / rank_size - m_idx * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + + TileLoopIter tit(m_actual, n_actual); + int64_t rank_offset_c = (loop_idx % rank_size) * (actual_loop_num / rank_size) * m0 * n0; + int64_t offset_this_tile = flag_idx * m0 * loop_num_per_comm * n0 + rank_offset_c + + + ((loop_idx % loop_num_per_comm) / rank_size) * m0 * n0; + + for (tit.Init(LoopDequanter::max_len); !tit.End(); tit.Next()) { + int64_t src_offset = offset_this_tile + tit.m_offset_in_tile() * n0; + int64_t dst_offset = offset_this_tile + tit.m_offset_in_tile() * n0; + auto accum = gm_accum + src_offset; + auto out = gm_out + dst_offset; + loop_dequanter.Loop(out, accum, scale, 0, tit.m_this_loop, tit.n_this_loop, n0, n0); + } + } + loop_dequanter.WaitForLoop(); + } + + + + inline __aicore__ void SetArgs(__gm__ bfloat16_t *gm_out, const LcalWorkspaceInfo &workspace_info, + __gm__ int64_t *gm_dequant_scale, __gm__ int32_t *gm_dequant_offset, QuantGranularity dequant_granularity, + int32_t batch_size, int32_t m, int32_t n, int32_t m0, int32_t n0, int32_t m_loop, int32_t n_loop, + int32_t core_loop,int32_t rank, int32_t swizzle_direct, int32_t swizzle_count, int32_t p_value, int32_t EP, int32_t TP, + int32_t local_expert_nums, int32_t is_moe_averaged, int32_t is_alltoallvc, + __gm__ int32_t* num_local_tokens_per_expert, __gm__ int32_t* num_global_tokens_per_local_expert) + { + BaseDequantRunner::SetArgs(gm_out, workspace_info, gm_dequant_scale, gm_dequant_offset, dequant_granularity, + batch_size, m, n); + + core_num = get_block_num(); + core_idx = get_block_idx(); + + loop_per_EP = p_value * core_num / (EP * TP); + + out_loop_per_expert = reinterpret_cast<__gm__ int32_t *> (workspace_info.gm_out_loop_per_expert); + out_loop_per_ep = reinterpret_cast<__gm__ int32_t *> (workspace_info.gm_out_loop_per_EP); + sum_num_local_tokens_per_expert = reinterpret_cast<__gm__ int32_t *> (workspace_info.gm_sum_num_local_tokens_per_expert); + sum_num_global_tokens_per_local_expert = reinterpret_cast<__gm__ int32_t *> (workspace_info.gm_sum_num_global_tokens_per_local_expert); + + in_expert_comm_count_accum = reinterpret_cast<__gm__ int32_t *> (workspace_info.gm_in_expert_comm_count_accum); + + this->n_loop = n_loop; + this->m_loop = m_loop; + this->m0 = m0; + this->n0 = n0; + this->swizzl_direct = swizzle_direct; + this->swizzl_count = swizzle_count; + this->p_value = p_value; + this->rank_size = EP * TP; + this->rank = rank; + + + this->EP = EP; + this->TP = TP; + this->local_expert_nums = local_expert_nums; + + this->is_moe_averaged = is_moe_averaged; + this->is_alltoallvc = is_alltoallvc; + this->num_local_tokens_per_expert = reinterpret_cast<__gm__ int32_t *>(num_local_tokens_per_expert); + 
+        this->num_global_tokens_per_local_expert =
+            reinterpret_cast<__gm__ int32_t *>(num_global_tokens_per_local_expert);
+    }
+
+private:
+    int32_t core_num;
+    int32_t core_idx;
+
+    int32_t m0;
+    int32_t n0;
+    int32_t m_loop;
+    int32_t n_loop;
+    int32_t core_loop;
+    int32_t loop_num_per_comm;
+    int32_t swizzl_direct;
+    int32_t swizzl_count;
+
+    int32_t p_value;
+    int32_t rank_size;
+
+    int32_t loop_per_EP;
+    int32_t rank;
+    int32_t EP;
+    int32_t TP;
+    int32_t local_expert_nums;
+    int32_t is_moe_averaged;
+    int32_t is_alltoallvc;
+    __gm__ int32_t *out_loop_per_expert;
+    __gm__ int32_t *out_loop_per_ep;
+    __gm__ int32_t *sum_num_local_tokens_per_expert;
+    __gm__ int32_t *sum_num_global_tokens_per_local_expert;
+    __gm__ int32_t *in_expert_comm_count_accum;
+    __gm__ int32_t *num_local_tokens_per_expert;
+    __gm__ int32_t *num_global_tokens_per_local_expert;
+
+    int32_t sum_loop;
+};
+
+template <typename T>
+class FusedPerTokenDequantRunner : public BaseDequantRunner {
+public:
+    __aicore__ explicit FusedPerTokenDequantRunner() = default;
+
+    inline __aicore__ void SetArgs(__gm__ T *gm_buff,
+        __gm__ float32_t *gm_dequant_scale_pertoken, int32_t m, int32_t n, int32_t m0, int32_t n0,
+        int32_t m_loop, int32_t n_loop, int32_t core_loop, int32_t swizzl_direct, int32_t swizzl_count,
+        int32_t p_value, int32_t rank_size)
+    {
+        this->gm_buff = gm_buff;
+        this->gm_dequant_scale_pertoken = gm_dequant_scale_pertoken;
+        core_num = get_block_num();
+        core_idx = get_block_idx();
+        this->m = m;
+        this->n = n;
+        this->m0 = m0;
+        this->n0 = n0;
+        this->m_loop = m_loop;
+        this->n_loop = n_loop;
+        this->core_loop = core_loop;
+        this->swizzl_direct = swizzl_direct;
+        this->swizzl_count = swizzl_count;
+
+        this->loop_num_per_comm = p_value * core_num;
+        this->p_value = p_value;
+        this->rank_size = rank_size;
+    }
+
+    inline __aicore__ void SetArgs(__gm__ T *gm_buff, const LcalWorkspaceInfo &workspace_info,
+        __gm__ float32_t *gm_dequant_scale_pertoken,
+        int32_t batch_size, int32_t m, int32_t n, int32_t m0, int32_t n0, int32_t m_loop, int32_t n_loop,
+        int32_t core_loop, int32_t rank, int32_t swizzle_direct, int32_t swizzle_count, int32_t p_value,
+        int32_t EP, int32_t TP, int32_t local_expert_nums, int32_t is_moe_averaged, int32_t is_alltoallvc,
+        __gm__ int32_t *num_local_tokens_per_expert, __gm__ int32_t *num_global_tokens_per_local_expert)
+    {
+        this->gm_buff = gm_buff;
+        this->gm_dequant_scale_pertoken = gm_dequant_scale_pertoken;
+        this->m = m;
+        this->n = n;
+
+        core_num = get_block_num();
+        core_idx = get_block_idx();
+
+        loop_per_EP = p_value * core_num / (EP * TP);
+
+        out_loop_per_expert = reinterpret_cast<__gm__ int32_t *>(workspace_info.gm_out_loop_per_expert);
+        out_loop_per_ep = reinterpret_cast<__gm__ int32_t *>(workspace_info.gm_out_loop_per_EP);
+        sum_num_local_tokens_per_expert =
+            reinterpret_cast<__gm__ int32_t *>(workspace_info.gm_sum_num_local_tokens_per_expert);
+        sum_num_global_tokens_per_local_expert =
+            reinterpret_cast<__gm__ int32_t *>(workspace_info.gm_sum_num_global_tokens_per_local_expert);
+        in_expert_comm_count_accum = reinterpret_cast<__gm__ int32_t *>(workspace_info.gm_in_expert_comm_count_accum);
+
+        this->n_loop = n_loop;
+        this->m_loop = m_loop;
+        this->m0 = m0;
+        this->n0 = n0;
+        this->swizzl_direct = swizzle_direct;
+        this->swizzl_count = swizzle_count;
+        this->p_value = p_value;
+        this->rank_size = EP * TP;
+        this->rank = rank;
+
+        this->EP = EP;
+        this->TP = TP;
+        this->local_expert_nums = local_expert_nums;
+
+        this->is_moe_averaged = is_moe_averaged;
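+        // is_moe_averaged assumes tokens are spread evenly over the experts, so all counts
+        // can be derived from m alone; otherwise the per-(rank, expert) counts are read from
+        // the routing tables (the is_alltoallvc flag set below appears to select the
+        // AllToAllVC layout -- an inference from the code paths these flags gate).
+ 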
this->is_alltoallvc = is_alltoallvc; + + this->num_local_tokens_per_expert = reinterpret_cast<__gm__ int32_t *>(num_local_tokens_per_expert); + this->num_global_tokens_per_local_expert = + reinterpret_cast<__gm__ int32_t *>(num_global_tokens_per_local_expert); + } +inline __aicore__ void SetArgs(__gm__ T *gm_buff, const LcalWorkspaceInfo &workspace_info, + __gm__ float32_t *gm_dequant_scale_pertoken, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m0, int32_t k0, int32_t n0, int32_t m_loop, int32_t n_loop, + int32_t core_loop,int32_t rank, int32_t swizzle_direct, int32_t swizzle_count, int32_t p_value, int32_t EP, int32_t TP, + int32_t local_expert_nums, int32_t is_moe_averaged, int32_t is_alltoallvc, int32_t max_output_size, int32_t buffer_size, + __gm__ int32_t* global_tokens_per_expert_matrix) + { + this->gm_buff = gm_buff; + this->gm_dequant_scale_pertoken = gm_dequant_scale_pertoken; + this->m = m; + this->k = k; + this->n = n; + + + core_num = get_block_num(); + core_idx = get_block_idx(); + + this->n_loop = n_loop; + this->m_loop = m_loop; + this->m0 = m0; + this->k0 = k0; + this->n0 = n0; + this->swizzl_direct = swizzle_direct; + this->swizzl_count = swizzle_count; + this->p_value = p_value; + this->rank_size = EP * TP; + this->rank = rank; + this->buffer_size = buffer_size; + + + this->EP = EP; + this->TP = TP; + this->local_expert_nums = local_expert_nums; + + this->is_moe_averaged = is_moe_averaged; + this->is_alltoallvc = is_alltoallvc; + + //hidden + this->comm_n = p_value * n0; + this->global_tokens_per_expert_matrix = reinterpret_cast<__gm__ int32_t *>(global_tokens_per_expert_matrix); + this->expert_nums = EP * local_expert_nums; + this->maxOutputSize = max_output_size; + if(is_moe_averaged) { + sum_m_loop = DivCeil((m / expert_nums) * EP, m0) * local_expert_nums; + max_m = m; + } else { + if (maxOutputSize == -1) { + max_m = 0; + for(int32_t ep_idx = 0; ep_idx < EP; ep_idx ++) { + int32_t sum_m_ep = 0; + for(int32_t local_expert_id = 0; local_expert_id < local_expert_nums; local_expert_id ++) { + int32_t expert_id = local_expert_id + ep_idx * local_expert_nums; + for(int32_t i = 0; i < EP; i++) { + sum_m_ep += global_tokens_per_expert_matrix[i * expert_nums + expert_id]; + } + } + max_m = max(max_m, sum_m_ep); + } + } else { + max_m = maxOutputSize; + } + + + for(int32_t i = 0; i < local_expert_nums; i++){ + int32_t last_sum_m = (i == 0 ? 0 : sum_m[i - 1]); + for(int j = 0; j < EP; j++) { + sum_m[i] += global_tokens_per_expert_matrix[j * expert_nums + rank * local_expert_nums + i]; + //global_tokens_per_expert_matrix[j][rank * local_expert_nums + i] + } + if (maxOutputSize > 0 && sum_m[i] + last_sum_m > maxOutputSize) { + sum_m[i] = maxOutputSize - last_sum_m; + } + sum_m_loop += DivCeil(sum_m[i], m0); + sum_m[i] += (i == 0 ? 0 : sum_m[i - 1]); + } + + } + sum_loop = 0; + //hidden end. + } + + + inline __aicore__ void RunDequantAllReduce(int32_t cal_idx) + { + LoopPerTokenDequanter loop_dequanter(n0); + loop_dequanter.SetForLoop(); + int32_t pipe_depth = MAX_BLOCK_COUNT; + int32_t flag_idx = cal_idx % pipe_depth; + int32_t loop_idx = cal_idx * core_num + core_idx; + for (int32_t p = 0; p < p_value; p++) { + int loop_idx = cal_idx * p_value * core_num + p * core_num + core_idx; + if (loop_idx >= core_loop) + break; + int64_t m_idx, n_idx; + GetBlockIdx(loop_idx, m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? 
(n - n_idx * n0) : n0;
+            TileLoopIter tit(m_actual, n_actual);
+            int64_t offset_this_tile = flag_idx * loop_num_per_comm * m0 * n0 +
+                (loop_idx % loop_num_per_comm) * m0 * n0;
+            for (tit.Init(LoopPerTokenDequanter::max_len, n0); !tit.End(); tit.Next()) {
+                int64_t offset = offset_this_tile + tit.m_offset_in_tile() * n0; // offset of the data this sub-core processes in this pass
+                auto buff = gm_buff + offset; // address inside the communication buffer
+                auto scale = gm_dequant_scale_pertoken + m_idx * m0 + tit.m_offset_in_tile(); // note: m_offset_in_tile must be added
+                loop_dequanter.Loop(buff, scale, tit.m_this_loop, tit.n_this_loop, n0);
+            }
+        }
+        loop_dequanter.WaitForLoop();
+    }
+
+    inline __aicore__ void RunDequantReduceScatter(int32_t cal_idx)
+    {
+        LoopPerTokenDequanter loop_dequanter(n0);
+        loop_dequanter.SetForLoop();
+        int32_t m_loop_per_rank = m_loop / rank_size;
+        //int32_t pipe_depth = is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT;
+        int32_t pipe_depth = MAX_BLOCK_COUNT;
+        int32_t flag_idx = cal_idx % pipe_depth;
+        int32_t comm_num = DivCeil(core_loop, loop_num_per_comm);
+        int32_t actual_loop_num = loop_num_per_comm;
+        if (cal_idx == comm_num - 1) {
+            actual_loop_num = core_loop - cal_idx * loop_num_per_comm;
+        }
+
+        for (int32_t p = 0; p < p_value; p++) {
+            int loop_idx = cal_idx * p_value * core_num + p * core_num + core_idx;
+            if (loop_idx >= core_loop)
+                break;
+
+            int32_t in_batch_idx = loop_idx % (m_loop * n_loop);
+            int64_t rank_idx = in_batch_idx % rank_size;
+            int32_t in_rank_idx = in_batch_idx / rank_size;
+
+            int64_t m_idx, n_idx;
+            GetBlockIdx(in_rank_idx, m_loop_per_rank, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx);
+            int32_t m_actual = (m_idx == (m_loop_per_rank - 1)) ? (m / rank_size - m_idx * m0) : m0;
+            int32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
+
+            TileLoopIter tit(m_actual, n_actual);
+            int64_t rank_offset_c = (loop_idx % rank_size) * (actual_loop_num / rank_size) * m0 * n0;
+            int64_t offset_this_tile = flag_idx * m0 * loop_num_per_comm * n0 + rank_offset_c +
+                ((loop_idx % loop_num_per_comm) / rank_size) * m0 * n0;
+            for (tit.Init(LoopPerTokenDequanter::max_len, n0); !tit.End(); tit.Next()) {
+                int64_t offset = offset_this_tile + tit.m_offset_in_tile() * n0; // offset of the data this sub-core processes in this pass
+                auto buff = gm_buff + offset; // address inside the communication buffer
+                auto scale = gm_dequant_scale_pertoken + m_idx * m0 + tit.m_offset_in_tile(); // note: m_offset_in_tile must be added
+                loop_dequanter.Loop(buff, scale, tit.m_this_loop, tit.n_this_loop, n0);
+            }
+        }
+        loop_dequanter.WaitForLoop();
+    }
+
+    inline __aicore__ void DequantPerTokenMatmulAllToAllHidden(int32_t cal_idx) {
+        cal_count = DivCeil(n, comm_n);
+        gm_a_pingpong_size = comm_n * max_m;
+        gm_a_pingpong_num = buffer_size * 1024 * 1024 / 2 / gm_a_pingpong_size;
+        if (gm_a_pingpong_num > 8) {
+            gm_a_pingpong_num = 8;
+        }
+        LoopPerTokenDequanter loop_dequanter(n0);
+        loop_dequanter.SetForLoop();
+        int32_t n_len;
+        if (cal_idx == cal_count - 1) {
+            n_len = n - cal_idx * comm_n;
+        } else {
+            n_len = comm_n;
+        }
+        n_loop = DivCeil(n_len, n0);
+        int32_t sum_loop_num = sum_m_loop * n_loop;
+        //int32_t flag_id = cal_idx % MAX_BLOCK_COUNT;
+        int32_t flag_id = cal_idx % gm_a_pingpong_num;
+
+        for (int32_t loop_idx = 0; loop_idx < sum_loop_num; loop_idx++) {
+            if ((loop_idx + sum_loop) % core_num != core_idx) {
+                continue;
+            }
+            int64_t m_idx, n_idx;
+            GetBlockIdx(loop_idx, sum_m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx);
+            int32_t sum_loop_before = 0;
+            int32_t local_expert_idx = -1;
+            int32_t m_in_expert;
+            for (int32_t i = 0; i < local_expert_nums; i++) {
+                if (is_moe_averaged) {
+                    m_in_expert = m / local_expert_nums;
+                } else {
+                    m_in_expert = sum_m[i] - (i == 0 ? 0 : sum_m[i - 1]);
+                }
+                sum_loop_before += DivCeil(m_in_expert, m0);
+                if (sum_loop_before > m_idx) {
+                    local_expert_idx = i;
+                    break;
+                }
+            }
+            int32_t m_loop_in_expert = DivCeil(m_in_expert, m0);
+            sum_loop_before -= m_loop_in_expert;
+            int32_t m_idx_in_expert = m_idx - sum_loop_before;
+            int32_t m_actual = ((m_idx_in_expert == m_loop_in_expert - 1) ? (m_in_expert - m_idx_in_expert * m0) : m0);
+            int32_t n_actual = ((n_idx == n_loop - 1) ? (n_len - n_idx * n0) : n0);
+
+            int32_t sum_m_before = 0;
+            if (is_moe_averaged) {
+                sum_m_before = local_expert_idx * (m / local_expert_nums);
+            } else {
+                sum_m_before = sum_m[local_expert_idx] - m_in_expert;
+            }
+
+            int64_t m_offset_this_tile = sum_m_before + m_idx_in_expert * m0;
+
+            int64_t offset_this_tile = flag_id * gm_a_pingpong_size +
+                1LL * (sum_m_before + m_idx_in_expert * m0) * n_len + 1LL * n_idx * n0;
+
+            TileLoopIter tit(m_actual, n_actual);
+
+            for (tit.Init(LoopPerTokenDequanter::max_len, n0); !tit.End(); tit.Next()) {
+                int64_t buff_offset = offset_this_tile + tit.m_offset_in_tile() * n_len; // offset of the data this sub-core processes in this pass
+                auto buff = gm_buff + buff_offset;
+                auto scale = gm_dequant_scale_pertoken + m_offset_this_tile + tit.m_offset_in_tile();
+                loop_dequanter.Loop(buff, scale, tit.m_this_loop, tit.n_this_loop, n_len);
+            }
+        }
+        sum_loop += sum_loop_num;
+        loop_dequanter.WaitForLoop();
+    }
+
+private:
+    int32_t core_num;
+    int32_t core_idx;
+    int32_t m0;
+    int32_t k0;
+    int32_t n0;
+    int32_t m_loop;
+    int32_t n_loop;
+    int32_t core_loop;
+    int32_t loop_num_per_comm;
+    int32_t swizzl_direct;
+    int32_t swizzl_count;
+
+    int32_t p_value;
+    int32_t rank_size;
+    __gm__ T *gm_buff;
+    __gm__ float32_t *gm_dequant_scale_pertoken;
+
+    int32_t loop_per_EP;
+    int32_t rank;
+    int32_t EP;
+    int32_t TP;
+    int32_t local_expert_nums;
+    int32_t is_moe_averaged;
+    int32_t is_alltoallvc;
+    int32_t buffer_size;
+
+    __gm__ int32_t *out_loop_per_expert;
+    __gm__ int32_t *out_loop_per_ep;
+    __gm__ int32_t *sum_num_local_tokens_per_expert;
+    __gm__ int32_t *sum_num_global_tokens_per_local_expert;
+    __gm__ int32_t *in_expert_comm_count_accum;
+
+    __gm__ int32_t *num_local_tokens_per_expert;
+    __gm__ int32_t *num_global_tokens_per_local_expert;
+
+    int32_t sum_loop;
+
+    __gm__ int32_t *global_tokens_per_expert_matrix;
+    int32_t max_m;
+    int32_t sum_m[32] = {0};
+    int32_t sum_m_loop = 0;
+    int32_t comm_n;
+    int32_t comm_k;
+    int64_t gm_a_pingpong_size;
+    int64_t gm_a_pingpong_num;
+    int32_t expert_nums;
+    int32_t cal_count;
+    int32_t maxOutputSize;
+};
+#endif
+
+#endif
\ No newline at end of file
diff --git a/comm/lcal/src/kernels/coc_internal.cce b/comm/lcal/src/kernels/coc_internal.cce
new file mode 100644
index 0000000000000000000000000000000000000000..548575ad0178786a8b12dc33a534b8f9d8b07fa4
--- /dev/null
+++ b/comm/lcal/src/kernels/coc_internal.cce
@@ -0,0 +1,482 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCAL_COC_INTERNAL_H
+#define LCAL_COC_INTERNAL_H
+
+#include
+#include "kernel_operator.h"
+#include "coc_const_args.cce"
+using namespace AscendC;
+
+template <typename T>
+FORCE_INLINE_AICORE LocalTensor<T> CreateLocalTensor(__ubuf__ T *addr)
+{
+    LocalTensor<T> tensor;
+    TBuffAddr taddr;
+    taddr.bufferAddr = reinterpret_cast<uint64_t>(addr);
+    tensor.SetAddr(taddr);
+    return tensor;
+}
+
+template <typename T>
+FORCE_INLINE_AICORE LocalTensor<T> CreateLocalTensor(uint32_t buffer_offset)
+{
+    LocalTensor<T> tensor;
+    tensor.address_.bufferAddr = buffer_offset;
+    return tensor;
+}
+
+template <typename T>
+FORCE_INLINE_AICORE LocalTensor<T> CreateLocalTensor(uint32_t buffer_offset, uint8_t logic_pos)
+{
+    LocalTensor<T> tensor;
+    tensor.address_.logicPos = logic_pos;
+    tensor.address_.bufferAddr = buffer_offset;
+    return tensor;
+}
+
+template <typename T>
+FORCE_INLINE_AICORE GlobalTensor<T> CreateGlobalTensor(__gm__ T *addr)
+{
+    GlobalTensor<T> tensor;
+    tensor.SetGlobalBuffer(addr);
+    return tensor;
+}
+
+template <pipe_t pipe>
+inline __aicore__ void FFTSCrossCoreSync(uint64_t mode, uint64_t flag_id)
+{
+    uint64_t config = 1 | (mode << 4) | (flag_id << 8);
+    ffts_cross_core_sync(pipe, config);
+}
+
+template <typename T>
+inline __aicore__ void CopyUB2UB(__ubuf__ T *dst, __ubuf__ T *src, uint8_t sid, uint16_t nBurst, uint16_t lenBurst,
+    uint16_t srcStride, uint16_t dstStride)
+{
+    LocalTensor<T> srcTensor = CreateLocalTensor(src);
+    LocalTensor<T> dstTensor = CreateLocalTensor(dst);
+    DataCopyParams repeatParams(nBurst, lenBurst, srcStride, dstStride);
+    DataCopy(dstTensor, srcTensor, repeatParams);
+}
+
+template <typename Tdst, typename Tsrc>
+inline __aicore__ void Vconv(__ubuf__ Tdst *dst, __ubuf__ Tsrc *src, uint8_t repeat, uint16_t dstBlockStride,
+    uint16_t srcBlockStride, uint8_t dstRepeatStride, uint8_t srcRepeatStride,
+    const RoundMode &roundMode = RoundMode::CAST_NONE)
+{
+    LocalTensor<Tsrc> srcTensor = CreateLocalTensor(src);
+    LocalTensor<Tdst> dstTensor = CreateLocalTensor(dst);
+    UnaryRepeatParams repeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride);
+    Cast(dstTensor, srcTensor, roundMode, -1, repeat, repeatParams);
+}
+
+template <typename T>
+inline __aicore__ void Vadd(__ubuf__ T *dst, __ubuf__ T *src0, __ubuf__ T *src1, uint8_t repeat, uint8_t dstBlockStride,
+    uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
+    uint8_t src0RepeatStride, uint8_t src1RepeatStride)
+{
+    LocalTensor<T> srcTensor0 = CreateLocalTensor(src0);
+    LocalTensor<T> srcTensor1 = CreateLocalTensor(src1);
+    LocalTensor<T> dstTensor = CreateLocalTensor(dst);
+    BinaryRepeatParams repeatParams(dstBlockStride, src0BlockStride, src1BlockStride, dstRepeatStride, src0RepeatStride,
+        src1RepeatStride);
+    Add(dstTensor, srcTensor0, srcTensor1, -1, repeat, repeatParams);
+}
+
+template <typename T>
+inline __aicore__ void Vadds(__ubuf__ T *dst, __ubuf__ T *src, const T &scalarValue, uint8_t repeat,
+    uint16_t dstBlockStride, uint16_t srcBlockStride, uint8_t dstRepeatStride,
+    uint8_t srcRepeatStride)
+{
+    LocalTensor<T> srcTensor = CreateLocalTensor(src);
+    LocalTensor<T> dstTensor = CreateLocalTensor(dst);
+    UnaryRepeatParams repeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride);
+    Adds(dstTensor, srcTensor, scalarValue, -1, repeat, repeatParams);
+}
+
+template <typename T>
+inline __aicore__ void Vmul(__ubuf__ T *dst, __ubuf__ T *src0, __ubuf__ T *src1, uint8_t repeat, uint8_t dstBlockStride,
+    uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
+    uint8_t src0RepeatStride, uint8_t src1RepeatStride)
+{
+    LocalTensor<T> srcTensor0 = CreateLocalTensor(src0);
+    LocalTensor<T> srcTensor1 = CreateLocalTensor(src1);
+    LocalTensor<T> dstTensor = CreateLocalTensor(dst);
+    BinaryRepeatParams repeatParams(dstBlockStride, src0BlockStride, src1BlockStride, dstRepeatStride, src0RepeatStride,
+        src1RepeatStride);
+    Mul(dstTensor, srcTensor0, srcTensor1, -1, repeat, repeatParams);
+}
+
+template <typename T>
+inline __aicore__ void Vmuls(__ubuf__ T *dst, __ubuf__ T *src, const T &scalarValue, uint8_t repeat,
+    uint16_t dstBlockStride, uint16_t srcBlockStride, uint8_t dstRepeatStride,
+    uint8_t srcRepeatStride)
+{
+    LocalTensor<T> srcTensor = CreateLocalTensor(src);
+    LocalTensor<T> dstTensor = CreateLocalTensor(dst);
+    UnaryRepeatParams repeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride);
+    Muls(dstTensor, srcTensor, scalarValue, -1, repeat, repeatParams);
+}
+
+inline __aicore__ bool IsQuant(const QuantGranularity &granularity)
+{
+    return (granularity > QuantGranularity::QUANT_GRANULARITY_UNDEFINED) &&
+        (granularity < QuantGranularity::QUANT_GRANULARITY_MAX);
+}
+
+#define COC_ARGS_FUN_IIO(T_INPUT1, T_INPUT2, T_OUTPUT) \
+    __gm__ T_INPUT1 *gm_a, __gm__ T_INPUT2 *gm_b, __gm__ T_OUTPUT *gm_bias, __gm__ T_OUTPUT *gm_gamma, \
+    __gm__ T_OUTPUT *gm_out, __gm__ T_OUTPUT *gm_allgather_out, GM_ADDR gm_workspace, \
+    GM_ADDR gm_dequant_scale, GM_ADDR gm_dequant_offset, GM_ADDR gm_quant_scale, \
+    GM_ADDR gm_quant_offset, GM_ADDR coc_comm_args, GM_ADDR ffts_addr, \
+    __gm__ int32_t *num_local_tokens_per_expert, __gm__ int32_t *num_global_tokens_per_local_expert, \
+    __gm__ int32_t *global_tokens_per_expert_matrix, GM_ADDR para_gm
+
+#define COC_ARGS_FUN_IO(T_INPUT, T_OUTPUT) COC_ARGS_FUN_IIO(T_INPUT, T_INPUT, T_OUTPUT)
+
+#define COC_ARGS_FUN(T) COC_ARGS_FUN_IO(T, T)
+
+#define COC_ARGS_CALL() \
+    gm_a, gm_b, gm_bias, gm_gamma, gm_out, gm_allgather_out, gm_workspace, gm_dequant_scale, gm_dequant_offset, \
+    gm_quant_scale, gm_quant_offset, coc_comm_args, ffts_addr, \
+    num_local_tokens_per_expert, num_global_tokens_per_local_expert, \
+    global_tokens_per_expert_matrix, para_gm
+
+#define COC_ARGS_CALL_INT8() \
+    reinterpret_cast(gm_a), reinterpret_cast(gm_b), reinterpret_cast(gm_bias), \
+    reinterpret_cast(gm_gamma), reinterpret_cast(gm_out), \
+    reinterpret_cast(gm_allgather_out), gm_workspace, gm_dequant_scale, gm_dequant_offset, \
+    gm_quant_scale, gm_quant_offset, coc_comm_args, ffts_addr, \
+    num_local_tokens_per_expert, num_global_tokens_per_local_expert, \
+    global_tokens_per_expert_matrix, para_gm
+
+#define PP_MATMUL_AIC_ARGS_FUN(T_INPUT, T_OUTPUT) \
+    GM_ADDR gm_a, GM_ADDR gm_b, __gm__ T_OUTPUT *gm_bias, __gm__ T_OUTPUT *gm_c, \
+    __gm__ T_OUTPUT *gm_peer_mem, GM_ADDR gm_workspace, GM_ADDR gm_dequant_scale, \
+    GM_ADDR gm_dequant_offset, int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m0, \
+    int32_t k0, int32_t n0, int32_t m_loop, int32_t k_loop, int32_t n_loop, int32_t core_loop, \
+    int32_t swizzl_count, int32_t swizzl_direct, int32_t rank, int32_t rank_size, int32_t p_value, \
+    int32_t withSerialMode, QuantGranularity quant_granularity, QuantGranularity dequant_granularity, \
+    int32_t ag_dim, int32_t rs_dim, bool inner_dim_is_Ag, bool weight_nz, bool is_91093, \
+    __gm__ int32_t *num_local_tokens_per_expert, __gm__ int32_t *num_global_tokens_per_local_expert, \
+    __gm__ int32_t *global_tokens_per_expert_matrix, int32_t local_expert_nums, int32_t EP, int32_t TP, \
+    int32_t maxOutputSize, int32_t is_moe, bool is_deterministic, int32_t buffer_size
+
+#define PP_MATMUL_AIC_ARGS_CALL() \
+    reinterpret_cast<GM_ADDR>(gm_a), reinterpret_cast<GM_ADDR>(gm_b), gm_bias, gm_c, gm_peer_mem, \
+    reinterpret_cast<GM_ADDR>(gm_workspace), reinterpret_cast<GM_ADDR>(gm_dequant_scale), \
+    reinterpret_cast<GM_ADDR>(gm_dequant_offset), batch_size, m, k, n, m0, k0, n0, m_loop, k_loop, \
+    n_loop, core_loop, swizzl_count, swizzl_direct, rank, rank_size, p_value, withSerialMode, quant_granularity, \
+    dequant_granularity, ag_dim, rs_dim, inner_dim_is_Ag, weight_nz, is_91093, \
+    num_local_tokens_per_expert, num_global_tokens_per_local_expert, \
+    global_tokens_per_expert_matrix, local_expert_nums, EP, TP, maxOutputSize, is_moe, is_deterministic, buffer_size
+
+#define PP_MATMUL_AIV_PADDING_ARGS_FUN() \
+    GM_ADDR gm_a, GM_ADDR gm_b, GM_ADDR gm_workspace, GM_ADDR gm_dequant_scale, \
+    GM_ADDR gm_dequant_offset, GM_ADDR gm_quant_scale, GM_ADDR gm_quant_offset, \
+    int32_t batch_size, int32_t m, int32_t k, int32_t n, bool trans_a, bool trans_b, bool is_int8, \
+    QuantGranularity dequant_granularity, int32_t dequant_group_size, QuantGranularity quant_granularity, \
+    int32_t quant_group_size, int32_t weight_nz, int32_t is_moe, int32_t is_moe_averaged, int32_t is_alltoallvc, \
+    int32_t EP, int32_t TP, int32_t local_expert_nums, bool is_deterministic
+
+#define PP_MATMUL_AIV_PADDING_ARGS_CALL() \
+    reinterpret_cast<GM_ADDR>(gm_a), reinterpret_cast<GM_ADDR>(gm_b), \
+    reinterpret_cast<GM_ADDR>(gm_workspace), reinterpret_cast<GM_ADDR>(gm_dequant_scale), \
+    reinterpret_cast<GM_ADDR>(gm_dequant_offset), reinterpret_cast<GM_ADDR>(gm_quant_scale), \
+    reinterpret_cast<GM_ADDR>(gm_quant_offset), batch_size, m, k, n, trans_a, trans_b, is_int8, \
+    dequant_granularity, dequant_group_size, quant_granularity, quant_group_size, weight_nz, is_moe, \
+    is_moe_averaged, is_alltoallvc, EP, TP, local_expert_nums, is_deterministic
+
+#define PP_MATMUL_AIV_ADD_BIAS_ARGS_FUN() \
+    GM_ADDR gm_bias, GM_ADDR gm_out, int32_t batch_size, int32_t m, int32_t n, int32_t rank_size
+
+#define PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL() \
+    reinterpret_cast<GM_ADDR>(gm_bias), reinterpret_cast<GM_ADDR>(gm_out), batch_size, m, n, rank_size
+
+#define PP_MATMUL_AIV_POST_ARGS_CALL() \
+    reinterpret_cast<GM_ADDR>(gm_out), reinterpret_cast<GM_ADDR>(gm_bias), \
+    reinterpret_cast<GM_ADDR>(gm_gamma), reinterpret_cast<GM_ADDR>(para_gm)
+
+#define PP_MATMUL_AIV_POST_ARGS_FUN() \
+    GM_ADDR gm_out, GM_ADDR gm_bias, GM_ADDR gm_gamma, GM_ADDR para_gm
+
+#define TEMPLATE_ARGS_FUN() bool ALIGN = true, bool IS_INT8 = false, bool HAVE_BIAS = false, typename T = half
+
+#define TEMPLATE_ARGS_CALL() ALIGN, IS_INT8, HAVE_BIAS, T
+
+inline __aicore__ void AlignJudge(bool trans_a, bool trans_b, int32_t m, int32_t k, int32_t n, int32_t m_align,
+    int32_t k_align, int32_t n_align, int32_t &aligned_a, int32_t &aligned_b)
+{
+    if (!trans_a) {
+        aligned_a = k != k_align;
+    } else {
+        aligned_a = (m != m_align && m != 1);
+    }
+
+    if (!trans_b) {
+        aligned_b = (n != n_align);
+    } else {
+        aligned_b = (k != k_align);
+    }
+}
+
+inline __aicore__ void GetBlockIdx(int32_t loop_idx, int32_t m_loop, int32_t n_loop, int32_t swizzl_direction,
+    int32_t swizzl_count, int64_t &m_idx, int64_t &n_idx)
+{
+    uint32_t in_batch_idx = loop_idx % (m_loop * n_loop);
+    if (swizzl_direction == 0) { // Zn
+        uint32_t tile_block_loop = (m_loop + swizzl_count - 1) / swizzl_count;
+        uint32_t tile_block_idx = in_batch_idx / (swizzl_count * n_loop);
+        uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * n_loop);
+
+        uint32_t n_row =
swizzl_count; + if (tile_block_idx == tile_block_loop - 1) { + n_row = m_loop - swizzl_count * tile_block_idx; + } + m_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_row; + n_idx = in_tile_block_idx / n_row; + if (tile_block_idx % 2 != 0) { + n_idx = n_loop - n_idx - 1; + } + } else if (swizzl_direction == 1) { // Nz + uint32_t tile_block_loop = (n_loop + swizzl_count - 1) / swizzl_count; + uint32_t tile_block_idx = in_batch_idx / (swizzl_count * m_loop); + uint32_t in_tile_block_idx = in_batch_idx % (swizzl_count * m_loop); + + uint32_t n_col = swizzl_count; + if (tile_block_idx == tile_block_loop - 1) { + n_col = n_loop - swizzl_count * tile_block_idx; + } + m_idx = in_tile_block_idx / n_col; + n_idx = tile_block_idx * swizzl_count + in_tile_block_idx % n_col; + if (tile_block_idx % 2 != 0) { + m_idx = m_loop - m_idx - 1; + } + } +} + +template +FORCE_INLINE_AICORE void CopyGmToUbufAlign(__ubuf__ T *dst, __gm__ T *src, uint16_t nBurst, uint32_t lenBurst, + uint32_t gmGap, uint32_t ubufGap = 0) +{ + if constexpr (sizeof(T) == 8) { + CopyGmToUbufAlign(reinterpret_cast<__ubuf__ int32_t *>(dst), reinterpret_cast<__gm__ int32_t *>(src), + nBurst * 2, lenBurst * 2, gmGap, ubufGap); + return; + } + DataCopyParams dataCopyParams(nBurst, // blockCount + (Block32B::Count(lenBurst)), // blockLen + (Block32B::Count(gmGap)), // srcStride + (ubufGap) // dstStride + ); + DataCopyExtParams dataCopyAlignParams(nBurst, lenBurst * sizeof(T), gmGap * sizeof(T), ubufGap, 0); + LocalTensor ubTensor; + TBuffAddr ubAddr; + ubAddr.logicPos = static_cast(TPosition::VECIN); + ubAddr.bufferAddr = reinterpret_cast(dst); + ubTensor.SetAddr(ubAddr); + GlobalTensor gmTensor; + gmTensor.SetGlobalBuffer(src); + + if (Block32B::IsAligned(lenBurst) && Block32B::IsAligned(gmGap)) { + DataCopy(ubTensor, gmTensor, dataCopyParams); + } else { + DataCopyPadExtParams padParams; + DataCopyPad(ubTensor, gmTensor, dataCopyAlignParams, padParams); + } +} + +template +FORCE_INLINE_AICORE void CopyUbufToGmAlign(__gm__ T *dst, __ubuf__ T *src, uint16_t nBurst, uint32_t lenBurst, + uint32_t gmGap, uint32_t ubufGap = 0) +{ + DataCopyParams dataCopyParams(nBurst, // blockCount + static_cast(Block32B::Count(lenBurst)), // blockLen + static_cast(ubufGap), // srcStride + static_cast(Block32B::Count(gmGap)) // dstStride + ); + DataCopyExtParams dataCopyAlignParams(nBurst, lenBurst * sizeof(T), ubufGap, gmGap * sizeof(T), 0); + LocalTensor ubTensor; + TBuffAddr ubAddr; + ubAddr.logicPos = static_cast(TPosition::VECIN); + ubAddr.bufferAddr = reinterpret_cast(src); + ubTensor.SetAddr(ubAddr); + GlobalTensor gmTensor; + gmTensor.SetGlobalBuffer(dst); + if (Block32B::IsAligned(lenBurst) && Block32B::IsAligned(gmGap)) { + DataCopy(gmTensor, ubTensor, dataCopyParams); + } else { + DataCopyPadParams padParams; + DataCopyPad(gmTensor, ubTensor, dataCopyAlignParams); + } +} + +template +FORCE_INLINE_AICORE void CopyGmToUbufAlignB16(__ubuf__ T *dst, __gm__ T *src, uint16_t nBurst, uint32_t lenBurst, + uint16_t srcStride, uint16_t dstStride) +{ + DataCopyExtParams dataCopyParams(nBurst, // blockCount + lenBurst, // blockLen + srcStride, // srcStride + dstStride, // dstStride + 0); + LocalTensor ubTensor; + TBuffAddr ubAddr; + ubAddr.logicPos = static_cast(TPosition::VECIN); + ubAddr.bufferAddr = reinterpret_cast(dst); + ubTensor.SetAddr(ubAddr); + GlobalTensor gmTensor; + gmTensor.SetGlobalBuffer(reinterpret_cast(src)); + DataCopyPadExtParams padParams; + DataCopyPad(ubTensor, gmTensor, dataCopyParams, padParams); +} + +template 
+FORCE_INLINE_AICORE void CopyUbufToGmAlignB16(__gm__ T *dst, __ubuf__ T *src, uint16_t nBurst, uint32_t lenBurst, + uint16_t srcStride, uint16_t dstStride) +{ + DataCopyExtParams dataCopyParams(nBurst, // blockCount + lenBurst, // blockLen + srcStride, // srcStride + dstStride, // dstStride + 0); + LocalTensor ubTensor; + TBuffAddr ubAddr; + ubAddr.logicPos = static_cast(TPosition::VECIN); + ubAddr.bufferAddr = reinterpret_cast(src); + ubTensor.SetAddr(ubAddr); + GlobalTensor gmTensor; + gmTensor.SetGlobalBuffer(reinterpret_cast(dst)); + DataCopyPad(gmTensor, ubTensor, dataCopyParams); +} + +template +FORCE_INLINE_AICORE void CopyGmToUbuf(__ubuf__ T *dst, __gm__ T *src, uint16_t nBurst, uint32_t lenBurst, + uint16_t srcStride, uint16_t dstStride) +{ + DataCopyParams dataCopyParams(nBurst, // blockCount + lenBurst, // blockLen + srcStride, // srcStride + dstStride // dstStride + ); + LocalTensor ubTensor; + TBuffAddr ubAddr; + ubAddr.logicPos = static_cast(TPosition::VECIN); + ubAddr.bufferAddr = reinterpret_cast(dst); + ubTensor.SetAddr(ubAddr); + GlobalTensor gmTensor; + gmTensor.SetGlobalBuffer(src); + DataCopy(ubTensor, gmTensor, dataCopyParams); +} + +template +FORCE_INLINE_AICORE void CopyUbufToGm(__gm__ T *dst, __ubuf__ T *src, uint16_t nBurst, uint16_t lenBurst, + uint16_t srcStride, uint16_t dstStride) +{ + DataCopyParams dataCopyParams(nBurst, // blockCount + lenBurst, // blockLen + srcStride, // srcStride + dstStride // dstStride + ); + LocalTensor ubTensor; + TBuffAddr ubAddr; + ubAddr.logicPos = static_cast(TPosition::VECIN); + ubAddr.bufferAddr = reinterpret_cast(src); + ubTensor.SetAddr(ubAddr); + GlobalTensor gmTensor; + gmTensor.SetGlobalBuffer(dst); + DataCopy(gmTensor, ubTensor, dataCopyParams); +} + +template +FORCE_INLINE_AICORE void CopyUbufToGmUnknown(bool ALIGN, __gm__ T *dst, __ubuf__ T *src, uint16_t nBurst, + uint32_t lenBurst, uint16_t srcStride, uint16_t dstStride) +{ + if (ALIGN) { + CopyUbufToGm(dst, src, nBurst, lenBurst / 32, srcStride, dstStride / 32); + } else { + CopyUbufToGmAlignB16(dst, src, nBurst, lenBurst, srcStride, dstStride); + } +} + +template +FORCE_INLINE_AICORE void VectorDup(__ubuf__ T *dst, const T &src, uint8_t repeat, uint16_t dstBlockStride, + uint8_t dstRepeatStride) +{ + LocalTensor ubTensor = CreateLocalTensor(dst); + Duplicate(ubTensor, src, -1, repeat, dstBlockStride, dstRepeatStride); +} + +template +struct CoCBuffAddrAndArgs { +public: + __aicore__ inline CoCBuffAddrAndArgs(COC_ARGS_FUN(T)) + { + GlobalTensor commArgsGm; + commArgsGm.SetGlobalBuffer(reinterpret_cast<__gm__ int *>(coc_comm_args), 2); + rank = commArgsGm.GetValue(0); + localRank = commArgsGm.GetValue(1); + rankSize = commArgsGm.GetValue(2); + localRankSize = commArgsGm.GetValue(3); + extraFlag = commArgsGm.GetValue(4); + RDMA = (extraFlag & ExtraFlag::RDMA) != 0; + TOPO_910B2C = (extraFlag & ExtraFlag::TOPO_910B2C) != 0; + TOPO_910_93 = (extraFlag & ExtraFlag::TOPO_910_93) != 0; + DETERMINISTIC = (extraFlag & ExtraFlag::DETERMINISTIC) != 0; + QUANT_FP16 = (extraFlag & ExtraFlag::QUANT_FP16) != 0; + QUANT_FP32 = (extraFlag & ExtraFlag::QUANT_FP32) != 0; + GlobalTensor<__gm__ T *> peerMemsAddrGm; + peerMemsAddrGm.SetGlobalBuffer(&(reinterpret_cast<__gm__ CoCCommArgs *>(coc_comm_args))->peerMems[0], + LCAL_MAX_RANK_SIZE); + for (int i = 0; i < rankSize; ++i) { + buff[i] = peerMemsAddrGm.GetValue(i); + } + } + + int rank; // attr rank_id, global rank + int localRank; + int rankSize; // global rank size + int localRankSize; + uint32_t extraFlag; + bool RDMA; + bool 
TOPO_910B2C;
+    bool TOPO_910_93;
+    bool DETERMINISTIC;
+    bool QUANT_FP16;
+    bool QUANT_FP32;
+    __gm__ T *buff[LCAL_MAX_RANK_SIZE]; // list of peer shared-memory buffer addresses
+    //int64_t sendCountMatrix[LCAL_MAX_RANK_SIZE * LCAL_MAX_RANK_SIZE];
+};
+
+FORCE_INLINE_AICORE void CommMatrixTrunc(__gm__ int32_t *global_tokens_per_expert_matrix, __gm__ int32_t *workspace,
+    int32_t EP, int32_t local_expert_nums, int32_t maxOutputSize)
+{
+    int32_t expert_nums = local_expert_nums * EP;
+    for (int32_t i = 0; i < EP; i++) {
+        int32_t sum_tokens = 0;
+        for (int32_t local_expert_id = 0; local_expert_id < local_expert_nums; local_expert_id++) {
+            int32_t expert_id = i * local_expert_nums + local_expert_id;
+            for (int32_t j = 0; j < EP; j++) {
+                if (sum_tokens + global_tokens_per_expert_matrix[j * expert_nums + expert_id] >= maxOutputSize) {
+                    workspace[j * expert_nums + expert_id] = maxOutputSize - sum_tokens;
+                    sum_tokens = maxOutputSize;
+                } else {
+                    workspace[j * expert_nums + expert_id] = global_tokens_per_expert_matrix[j * expert_nums + expert_id];
+                    sum_tokens += global_tokens_per_expert_matrix[j * expert_nums + expert_id];
+                }
+            }
+        }
+    }
+}
+
+#endif // LCAL_COC_INTERNAL_H
diff --git a/comm/lcal/src/kernels/coc_matmul_allreduce.cce b/comm/lcal/src/kernels/coc_matmul_allreduce.cce
new file mode 100644
index 0000000000000000000000000000000000000000..757553613c3a2c16d9799904d0a6312a409af51b
--- /dev/null
+++ b/comm/lcal/src/kernels/coc_matmul_allreduce.cce
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifdef __CCE_KT_TEST__
+#define __aicore__
+#else
+#define __aicore__ [aicore]
+#endif
+
+#include "coc_ppmatmul_switch.cce"
+#include "coc_allreduce.cce"
+#include "coc_internal.cce"
+
+#ifdef __DAV_C220_CUBE__
+// Matmul in LcalMatmulAllReduce
+#define COC_MATMUL_ALLREDUCE_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalMatmulAllReduce_##type##_mix_aic(COC_ARGS_FUN(type)) { \
+    CocPpmatmulSwitchAic(COC_ARGS_CALL()); \
+}
+
+#elif __DAV_C220_VEC__
+// AllReduce in LcalMatmulAllReduce
+#define COC_MATMUL_ALLREDUCE_FUNC_AUTO_DEF(type) \
+extern "C" __global__ __aicore__ void LcalMatmulAllReduce_##type##_mix_aiv(COC_ARGS_FUN(type)) { \
+    CocMatmulAllReduceAiv(COC_ARGS_CALL()); \
+}
+#endif
+
+#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // 910B support bf16
+#define COC_TYPE_FUNC(fun) fun(float16_t);fun(bfloat16_t)
+
+COC_TYPE_FUNC(COC_MATMUL_ALLREDUCE_FUNC_AUTO_DEF);
+#endif
\ No newline at end of file
diff --git a/comm/lcal/src/kernels/coc_matmul_reduce_scatter.cce b/comm/lcal/src/kernels/coc_matmul_reduce_scatter.cce
new file mode 100644
index 0000000000000000000000000000000000000000..5366e03cfd14a9a32f1af232d5f086b4dd49a016
--- /dev/null
+++ b/comm/lcal/src/kernels/coc_matmul_reduce_scatter.cce
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifdef __CCE_KT_TEST__ +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif + +#include "coc_ppmatmul_switch.cce" +#include "coc_reduce_scatter.cce" +#include "coc_internal.cce" + +#ifdef __DAV_C220_CUBE__ +// Matmul in LcalMatmulReduceScatter +#define COC_MATMUL_REDUCE_SCATTER_FUNC_AUTO_DEF(type) \ +extern "C" __global__ __aicore__ void LcalMatmulReduceScatter_##type##_mix_aic(COC_ARGS_FUN(type)) { \ + CocPpmatmulSwitchAic(COC_ARGS_CALL()); \ +} + +#elif __DAV_C220_VEC__ +// ReduceScatter in LcalMatmulReduceScatter +#define COC_MATMUL_REDUCE_SCATTER_FUNC_AUTO_DEF(type) \ +extern "C" __global__ __aicore__ void LcalMatmulReduceScatter_##type##_mix_aiv(COC_ARGS_FUN(type)) { \ + CocMatmulReduceScatterAiv(COC_ARGS_CALL()); \ +} +#endif + + +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // 910B support bf16 +#define COC_TYPE_FUNC(fun) fun(float16_t);fun(bfloat16_t) + +COC_TYPE_FUNC(COC_MATMUL_REDUCE_SCATTER_FUNC_AUTO_DEF); +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_matmul_reduce_scatter_alltoallv.cce b/comm/lcal/src/kernels/coc_matmul_reduce_scatter_alltoallv.cce new file mode 100644 index 0000000000000000000000000000000000000000..4d1e93acfadb4041d27993f32ab097a079c50873 --- /dev/null +++ b/comm/lcal/src/kernels/coc_matmul_reduce_scatter_alltoallv.cce @@ -0,0 +1,28 @@ +#ifdef __CCE_KT_TEST__ +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif + + +#include "coc_ppmatmul_switch.cce" +#include "coc_alltoall_reduce_scatter_hidden.cce" +#ifdef __DAV_C220_CUBE__ +#define COC_MATMUL_REDUCE_SCATTER_ALL_TO_ALL_HIDDEN_FUNC_AUTO_DEF(type) \ +extern "C" __global__ __aicore__ void LcalMatmulReduceScatterAllToAllVHidden_##type##_mix_aic(COC_ARGS_FUN(type)){ \ + return CocPpmatmulSwitchAic(COC_ARGS_CALL()); \ +} + +#elif __DAV_C220_VEC__ +#define COC_MATMUL_REDUCE_SCATTER_ALL_TO_ALL_HIDDEN_FUNC_AUTO_DEF(type) \ +extern "C" __global__ __aicore__ void LcalMatmulReduceScatterAllToAllVHidden_##type##_mix_aiv(COC_ARGS_FUN(type)){ \ + return CocMatmulAllToAllVReduceScatterHiddenAiv(COC_ARGS_CALL()); \ +} +#endif + + +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // +#define COC_TYPE_FUNC(fun) fun(float16_t);fun(bfloat16_t) + +COC_TYPE_FUNC(COC_MATMUL_REDUCE_SCATTER_ALL_TO_ALL_HIDDEN_FUNC_AUTO_DEF); +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_matmulmoe.cce b/comm/lcal/src/kernels/coc_matmulmoe.cce new file mode 100644 index 0000000000000000000000000000000000000000..5bf45181f68ba7a837538e16f51088a4fb26bc83 --- /dev/null +++ b/comm/lcal/src/kernels/coc_matmulmoe.cce @@ -0,0 +1,1171 @@ +#include "coc_internal.cce" +#include "coc_ppmatmul.cce" +#ifdef __DAV_C220_CUBE__ +template +class PpMatmulMoe : public PpMatmul { + using T_ACCUM = typename GetAccumType::T; + static constexpr bool IS_INT8 = std::is_same::value; +public: + __aicore__ explicit PpMatmulMoe() {}; + inline __aicore__ void SetArgs(PP_MATMUL_AIC_ARGS_FUN(MmadDtype, OutDtype)) + { + PpMatmul::SetArgs(PP_MATMUL_AIC_ARGS_CALL()); + + // moe args + is_moe_averaged = 0; + if (global_tokens_per_expert_matrix != nullptr) { + this -> 
global_tokens_per_expert_matrix = global_tokens_per_expert_matrix; + } else { + is_moe_averaged = 1; + } + this->local_expert_nums = local_expert_nums; + expert_nums = local_expert_nums * EP; + this->EP = EP; + this->TP = TP; + this->maxOutputSize = maxOutputSize; + + } + + //GMM + inline __aicore__ void CalLoop(int64_t batch_idx, int64_t m_idx, int64_t n_idx, int32_t m_actual, int32_t n_actual, + __gm__ MmadDtype *gm_a_src_tmp, __gm__ MmadDtype *gm_b_src_tmp, int32_t k, int32_t k_all, int32_t expert_dequant_param_offset = 0) { + + + int32_t k_loop = DivCeil(k, k0); + int32_t k_align = Block512B::AlignUp(k); + int32_t k_all_align = Block512B::AlignUp(k_all); + if (k != k_all) { + k_align = k; + } + + int64_t offset_a, offset_b, offset_a_next, offset_b_next; + int32_t m_round, n_round; + if (IS_INT8) { + // directive Restrictions + if (TA) { + m_round = DivCeil(m_actual, BLOCK_SIZE_32) * BLOCK_SIZE_32; + } else { + m_round = DivCeil(m_actual, BLOCK_SIZE_16) * BLOCK_SIZE_16; + } + if (TB) { + n_round = DivCeil(n_actual, BLOCK_SIZE_16) * BLOCK_SIZE_16; + } else { + n_round = DivCeil(n_actual, BLOCK_SIZE_32) * BLOCK_SIZE_32; + } + } else { + m_round = DivCeil(m_actual, BLOCK_SIZE_16) * BLOCK_SIZE_16; + n_round = DivCeil(n_actual, BLOCK_SIZE_16) * BLOCK_SIZE_16; + } + + int32_t mn_max = m_round > n_round ? m_round : n_round; + int32_t k_part_len = L0AB_PINGPONG_BUFFER_LEN / mn_max / block_size * block_size; + if (TA) { + if (aligned_a == 1) { + offset_a = batch_idx * k * m_align + m_idx * m0; + } else { + offset_a = batch_idx * k * m + m_idx * m0; + } + } else { + if (aligned_a == 1) { + offset_a = batch_idx * m * k_align + m_idx * m0 * k_align; + } else { + offset_a = batch_idx * m * k + m_idx * m0 * k; + } + } + + if (TB) { + if (aligned_b == 1) { + offset_b = batch_idx * n * k_all_align + n_idx * n0 * k_all_align; + } else { + if (weight_nz) { + offset_b = n_idx * n0 * block_size; + } else { + offset_b = n_idx * n0 * k_all; + } + } + } else { + if (aligned_b == 1) { + offset_b = batch_idx * k * n_align + n_idx * n0; + } else { + if (weight_nz) { + offset_b = n_idx * n0 * k_align16; + } else { + offset_b = n_idx * n0; + } + } + } + + int64_t dequant_param_offset = n_idx * n0 + expert_dequant_param_offset; + + int32_t k_actual = (k_loop == 1) ? k : k0; + int32_t k_round = DivCeil(k_actual, block_size) * block_size; // int8 :32 fp16 :16 + + auto l1_buf_a = ping_flag ? l1_base_a : l1_base_a + L1_PINGPONG_BUFFER_LEN; + auto l1_buf_b = ping_flag ? l1_base_b : l1_base_b + L1_PINGPONG_BUFFER_LEN; + auto l0a_buf = ping_flag ? l0a_base : l0a_base + L0AB_PINGPONG_BUFFER_LEN; + auto l0b_buf = ping_flag ? l0b_base : l0b_base + L0AB_PINGPONG_BUFFER_LEN; + auto event_id = ping_flag ? 
EVENT_ID0 : EVENT_ID1; + + if (IS_INT8 && has_offset) { + PipeBarrier(); + IntrinsicCopyGmToL1Nd2Nz::move( + ((__cbuf__ int32_t *)bias_l1), + ((__gm__ int32_t *)gm_format_dequant_offset) + dequant_param_offset, + 0, // sid + 1, // ndNum + 1, // nValue + n_actual, // dValue + 0, // srcNdMatrixStride, unused + n, // srcDValue + 1, // dstNzC0Stride + 1, // dstNzNStride + 0 // dstNzMatrixStride, unused + ); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + CopyCubfToBt(((uint64_t)bias_bt), ((__cbuf__ int32_t *)bias_l1), + (uint16_t)0ULL, 1, (n_actual * 4 + 63) / 64, 0, 0); + SetFlag(EVENT_ID1); // bias ready, mte2 can begin move A/B or scalar + SetFlag(EVENT_ID1); // bias ready, mmad can begin + WaitFlag(EVENT_ID1); // A/B or scalar wait moving bias from L1 to BT + + } + + auto gm_src_a = gm_a_src_tmp + offset_a; + //auto gm_src_b = gm_b_src + offset_b; + auto gm_src_b = gm_b_src_tmp + offset_b; + WaitFlag(event_id); + // *** load matrix A to L1 + if (m_actual == 1 && !TA) { + CopyGmToCbuf( + l1_buf_a, + gm_src_a, + 0, // sid + 1, // nBurst + k_round / block_size, // lenBurst + 0, // srcGap + 0, // dstGap + PAD_NONE // padMode + ); + } else { + if (TA) { + auto src_len = m; + if (aligned_a == 1) { + src_len = m_align; + } + CopyGmToL1Nd2zN::move(l1_buf_a, gm_src_a, k_actual, m_actual, src_len, k_round); + } else { + auto src_len = k; + if (aligned_a == 1) { + src_len = k_align; + } + CopyGmToL1Nd2zN::move(l1_buf_a, gm_src_a, m_actual, k_actual, src_len, m_round); + } + } + SetFlag(event_id); + + // *** load matrix B to L1 + WaitFlag(event_id + 2); + if (TB) { + //auto src_len = k; + auto src_len = k_all; + if (aligned_b == 1) { + //src_len = k_align; + src_len = k_all_align; + } + if (weight_nz) { + int32_t num_col = DivCeil(k_actual, block_size); + CopyGmToCbuf(l1_buf_b, gm_src_b, 0, num_col, n_actual, n_align16 - n_actual, n_round - n_actual, PAD_NONE); + } else { + CopyGmToL1Nd2zN::move(l1_buf_b, gm_src_b, n_actual, k_actual, src_len, n_round); + } + } else { + auto src_len = n; + if (aligned_b == 1) { + src_len = n_align; + } + if (weight_nz) { + int32_t num_col = DivCeil(n_actual, block_size); + CopyGmToCbuf(l1_buf_b, gm_src_b, 0, num_col, k_actual, k_align16 - k_actual, k_round - k_actual, PAD_NONE); + } else { + CopyGmToL1Nd2zN::move(l1_buf_b, gm_src_b, k_actual, n_actual, src_len, k_round); + } + } + SetFlag(event_id + 2); + + int mte1_mad_ping_flag = 1; + + for (int64_t k_idx = 0; k_idx < k_loop; k_idx++) { + + int32_t k_actual = (k_idx == (k_loop - 1)) ? (k - k_idx * k0) : k0; + int32_t k_round = DivCeil(k_actual, block_size) * block_size; + int32_t k_part_loop = DivCeil(k_actual, k_part_len); + + __cbuf__ MmadDtype *l1_buf_a = ping_flag ? l1_base_a : l1_base_a + L1_PINGPONG_BUFFER_LEN; + __cbuf__ MmadDtype *l1_buf_b = ping_flag ? l1_base_b : l1_base_b + L1_PINGPONG_BUFFER_LEN; + auto event_id = ping_flag ? 
EVENT_ID0 : EVENT_ID1; + + if (k_idx < k_loop - 1) { + if (TA) { + if (aligned_a == 1) { + offset_a_next = batch_idx * k * m_align + (k_idx + 1) * k0 * m_align + m_idx * m0; + } else { + offset_a_next = batch_idx * k * m + (k_idx + 1) * k0 * m + m_idx * m0; + } + } else { + if (aligned_a == 1) { + offset_a_next = batch_idx * m * k_align + m_idx * m0 * k_align + (k_idx + 1) * k0; + } else { + offset_a_next = batch_idx * m * k + m_idx * m0 * k + (k_idx + 1) * k0; + } + } + if (TB) { + if (aligned_b == 1) { + //offset_b_next = batch_idx * n * k_align + n_idx * n0 * k_align + (k_idx + 1) * k0; + offset_b_next = batch_idx * n * k_align + n_idx * n0 * k_all_align + (k_idx + 1) * k0; + } else { + if (weight_nz) { + offset_b_next = batch_idx * n * k + (k_idx + 1) * k0 * n_align16 + n_idx * n0 * block_size; + } else { + offset_b_next = batch_idx * n * k + n_idx * n0 * k_all + (k_idx + 1) * k0; + } + //offset_b_next = batch_idx * n * k + n_idx * n0 * k + (k_idx + 1) * k0; + //offset_b_next = batch_idx * n * k + n_idx * n0 * k_all + (k_idx + 1) * k0; + } + } else { + if (aligned_b == 1) { + offset_b_next = batch_idx * k * n_align + (k_idx + 1) * k0 * n_align + n_idx * n0; + } else { + //offset_b_next = batch_idx * k * n + (k_idx + 1) * k0 * n + n_idx * n0; + if (weight_nz) { + offset_b_next = batch_idx * k * n + (k_idx + 1) * k0 * block_size + n_idx * n0 * k_align16; + } else { + offset_b_next = batch_idx * k * n + (k_idx + 1) * k0 * n + n_idx * n0; + } + } + } + + int32_t k_actual_next = ((k_idx + 1) == (k_loop - 1)) ? (k - (k_idx + 1) * k0) : k0; + int32_t k_round_next = DivCeil(k_actual_next, block_size) * block_size; + + __cbuf__ MmadDtype *l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a + L1_PINGPONG_BUFFER_LEN; + __cbuf__ MmadDtype *l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b + L1_PINGPONG_BUFFER_LEN; + auto event_id_next = (1 - ping_flag) ? 
EVENT_ID0 : EVENT_ID1; + + auto gm_src_a = gm_a_src_tmp + offset_a_next; + //auto gm_src_b = gm_b_src + offset_b_next; + auto gm_src_b = gm_b_src_tmp + offset_b_next; + WaitFlag(event_id_next); + // *** load matrix A to L1 + if (m_actual == 1 && !TA) { + CopyGmToCbuf( + l1_buf_a_next, + gm_src_a, + 0, // sid + 1, // nBurst + k_round_next / block_size, // lenBurst + 0, // srcGap + 0, // dstGap + PAD_NONE // padMode + ); + } else { + if (TA) { + auto src_len = m; + if (aligned_a == 1) { + src_len = m_align; + } + CopyGmToL1Nd2zN::move( + l1_buf_a_next, gm_src_a, k_actual_next, m_actual, src_len, k_round_next); + } else { + auto src_len = k; + if (aligned_a == 1) { + src_len = k_align; + } + CopyGmToL1Nd2zN::move( + l1_buf_a_next, gm_src_a, m_actual, k_actual_next, src_len, m_round); + } + } + SetFlag(event_id_next); + + // *** load matrix B to L1 + WaitFlag(event_id_next + 2); + if (TB) { + //auto src_len = k; + auto src_len = k_all; + if (aligned_b == 1) { + //src_len = k_align; + src_len = k_all_align; + } + if (weight_nz) { + int32_t num_col = DivCeil(k_actual_next, block_size); + CopyGmToCbuf(l1_buf_b_next, gm_src_b, 0, num_col, n_actual, n_align16 - n_actual, n_round - n_actual, PAD_NONE); + } else { + CopyGmToL1Nd2zN::move(l1_buf_b_next, gm_src_b, n_actual, k_actual_next, src_len, n_round); + } + // CopyGmToL1Nd2zN::move( + // l1_buf_b_next, gm_src_b, n_actual, k_actual_next, src_len, n_round); + } else { + auto src_len = n; + if (aligned_b == 1) { + src_len = n_align; + } + if (weight_nz) { + int32_t num_col = DivCeil(n_actual, block_size); + CopyGmToCbuf(l1_buf_b_next, gm_src_b, 0, num_col, k_actual_next, k_align16 - k_actual_next, k_round_next - k_actual_next, PAD_NONE); + } else { + CopyGmToL1Nd2zN::move(l1_buf_b_next, gm_src_b, k_actual_next, n_actual, src_len, k_round_next); + } + // CopyGmToL1Nd2zN::move( + // l1_buf_b_next, gm_src_b, k_actual_next, n_actual, src_len, k_round_next); + } + SetFlag(event_id_next + 2); + } + + for (int k_part_idx = 0; k_part_idx < k_part_loop; k_part_idx++) { + int32_t k0_round = (k_part_idx < k_part_loop - 1) ? + k_part_len : k_round - k_part_idx * k_part_len; + int32_t k0_actual = (k_part_idx < k_part_loop - 1) ? + k_part_len : k_actual - k_part_idx * k_part_len; + + auto mte1_mad_event_id = mte1_mad_ping_flag ? 
EVENT_ID0 : EVENT_ID1; + auto l0a_buf = l0a_base + (1 - mte1_mad_ping_flag) * L0AB_PINGPONG_BUFFER_LEN; + auto l0b_buf = l0b_base + (1 - mte1_mad_ping_flag) * L0AB_PINGPONG_BUFFER_LEN; + + // *** load matrix A from L1 to L0A + if (k_part_idx == 0) { + WaitFlag(event_id); + } + WaitFlag(mte1_mad_event_id); + if (m_actual == 1 && !TA) { + LoadCbufToCa( + l0a_buf, + l1_buf_a + k_part_idx * k_part_len, + 0, // baseIdx + DivCeil(k0_round, cube_matrix_size), // repeat + 1, // srcStride + 0, // dstStride + 0, // sid + false, // transpose + inc // addr_cal_mode_t + ); + } else { + if (TA) { + if (IS_INT8) { + for (int i = 0; i < m_round / BLOCK_SIZE_32; i++) { + LoadCbufToCaTranspose( + l0a_buf + i * k0_round * BLOCK_SIZE_32, + l1_buf_a + k_part_idx * k_part_len * BLOCK_SIZE_32 + + i * k_round * BLOCK_SIZE_32, + 0, // baseIdx + k0_round / BLOCK_SIZE_32, // repeat + 1, // srcStride + 0, // dstStride + 0, // addrmode + k0_round / BLOCK_SIZE_32 - 1 // dstFracStride + ); + } + } else { + for (int i = 0; i < m_round / BLOCK_SIZE_16; i++) { + LoadCbufToCa( + l0a_buf + i * k0_round * BLOCK_SIZE_16, + l1_buf_a + k_part_idx * k_part_len * BLOCK_SIZE_16 + + i * k_round * BLOCK_SIZE_16, + 0, // baseIdx + k0_round / BLOCK_SIZE_16, // repeat + 1, // srcStride + 0, // dstStride + 0, // sid + true, // transpose + inc // addr_cal_mode_t + ); + } + } + } else { + for (int32_t i = 0; i < k0_round / block_size; i++) { + LoadCbufToCa( + l0a_buf + i * cube_matrix_size, + l1_buf_a + k_part_idx * k_part_len * m_round + + i * m_round * block_size, + 0, // baseIdx + m_round / BLOCK_SIZE_16, // repeat + 1, // srcStride + k0_round / block_size - 1, // dstStride + 0, // sid + false, // transpose + inc // addr_cal_mode_t + ); + } + } + } + if (k_part_idx == k_part_loop - 1) { + SetFlag(event_id); + } + + // *** load matrix B from L1 to L0B + if (k_part_idx == 0) { + WaitFlag(event_id + 2); + } + if (TB) { + LoadCbufToCb( + l0b_buf, + l1_buf_b + k_part_idx * k_part_len * n_round, + 0, // baseIdx + k0_round * n_round / cube_matrix_size, // repeat + 1, // srcStride + 0, // dstStride + 0, // sid + false, // transpose + inc // addr_cal_mode_t + ); + } else { + if (IS_INT8) { + for (int32_t i = 0; i < k0_round / BLOCK_SIZE_32; i++) { + LoadCbufToCbTranspose( + l0b_buf + i * ((n_actual + 15) / 16 * 16) * BLOCK_SIZE_32, + l1_buf_b + (k_part_idx * k_part_len + i * BLOCK_SIZE_32) * BLOCK_SIZE_32, + 0, // baseIdx + n_round / BLOCK_SIZE_32, // repeat + k_round / BLOCK_SIZE_32, // srcStride + 1, // dstStride + 0, // addrmode + 0 // dstFracStride + ); + } + } else { + for (int32_t i = 0; i < k0_round / BLOCK_SIZE_16; i++) { + LoadCbufToCb( + l0b_buf + i * n_round * BLOCK_SIZE_16, + l1_buf_b + (k_part_idx * k_part_len + i * BLOCK_SIZE_16) * BLOCK_SIZE_16, + 0, // baseIdx + n_round / BLOCK_SIZE_16, // repeat + k_round / BLOCK_SIZE_16, // srcStride + 0, // dstStride + 0, // sid + true, // transpose + inc // addr_cal_mode_t + ); + } + } + } + if (k_part_idx == k_part_loop - 1) { + SetFlag(event_id + 2); + } + + SetFlag(mte1_mad_event_id); + WaitFlag(mte1_mad_event_id); + + bool init_c = (k_idx == 0 && k_part_idx == 0); + if (init_c) { + WaitFlag(EVENT_ID0); + } + + if (IS_INT8 && has_offset) { + if (init_c) { + WaitFlag(EVENT_ID1); // wait move bias fron L1 to BT + } + PipeBarrier(); + if (m != 1 && m_actual == 1 && TA) { + mad((__cc__ int32_t *)l0c_buf, + (__ca__ int8_t *)l0a_buf, + (__cb__ int8_t *)l0b_buf, + ((uint64_t)bias_bt), + 16, // m + k0_actual, // k + n_actual, // n + 0, // unitFlag + 0, // kDirectionAlign + init_c, // cmatrixSource 
add C from BT + 0 // cmatrixInitVal + ); + } else { + mad((__cc__ int32_t *)l0c_buf, + (__ca__ int8_t *)l0a_buf, + (__cb__ int8_t *)l0b_buf, + ((uint64_t)bias_bt), + m_actual, // m + k0_actual, // k + n_actual, // n + 0, // unitFlag + 0, // kDirectionAlign + init_c, // cmatrixSource add C from BT + 0 // cmatrixInitVal + ); + } + //has_offset = 0; + } else { + PipeBarrier(); + if (m != 1 && m_actual == 1 && TA) { + mad(l0c_buf, + l0a_buf, + l0b_buf, + 16, // m + k0_actual, // k + n_actual, // n + 0, // unitFlag + 0, // kDirectionAlign + 0, // cmatrixSource + init_c // cmatrixInitVal + ); + } else { + mad(l0c_buf, + l0a_buf, + l0b_buf, + m_actual, // m + k0_actual, // k + n_actual, // n + 0, // unitFlag + 0, // kDirectionAlign + 0, // cmatrixSource + init_c // cmatrixInitVal + ); + } + } + PipeBarrier(); + SetFlag(mte1_mad_event_id); + + mte1_mad_ping_flag = 1 - mte1_mad_ping_flag; + } + ping_flag = 1 - ping_flag; + } + + + if (IS_INT8 && std::is_same::value && (dequant_granularity == QuantGranularity::PER_CHANNEL || + dequant_granularity == QuantGranularity::PER_TOKEN)) { + //if (IS_INT8 && std::is_same::value && (dequant_granularity == QuantGranularity::PER_CHANNEL)) { + WaitFlag(EVENT_ID0); + PipeBarrier(); + CopyGmToCbuf( + scale_l1, + gm_dequant_scale + dequant_param_offset, + 0, + 1, + (n_actual * sizeof(int64_t) + 31) / 32, + 0, + 0, + PAD_NONE + ); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + + copy_cbuf_to_fbuf( + scale_FB, + scale_l1, + 1, + (n_actual * sizeof(int64_t) + 127) / 128, + 0, + 0 + ); + PipeBarrier(); + } + } + + inline __aicore__ void MoveL0CToGM(__gm__ OutDtype *gm_dst, int64_t offset_c, int64_t offset_l0c, int32_t m_actual, int32_t n_actual, int32_t src_stride, int32_t dst_stride) { + #if (__CCE_AICORE__ == 220) + FixpipeParamsV220 FixpipeParams( + n_actual, // nSize = nSizeIn; + m_actual, // mSize = mSizeIn; + src_stride, // srcStride = srcStrideIn; + dst_stride, // dstStride = dstStrideIn; + false // reluEn = reluEnIn; + ); + #elif (defined(__DAV_C310__)) + FixpipeParamsC310 FixpipeParams( + n_actual, // nSize = nSizeIn; + m_actual, // mSize = mSizeIn; + src_stride, // srcStride = srcStrideIn; + dst_stride // dstStride = dstStrideIn; + ); + #endif + LocalTensor srcTensor = CreateLocalTensor + (reinterpret_cast(l0c_buf + offset_l0c), static_cast(TPosition::CO1)); + GlobalTensor dstTensor = CreateGlobalTensor(gm_dst + offset_c); + + if (IS_INT8) { + if constexpr (std::is_same::value) { + if (dequant_granularity == QuantGranularity::PER_CHANNEL || dequant_granularity == QuantGranularity::PER_TOKEN) { + SetFpc(scale_FB); + FixpipeParams.quantPre = VDEQF16; + Fixpipe(dstTensor, srcTensor, FixpipeParams); + //SetFlag(EVENT_ID0); + } else if (dequant_granularity == QuantGranularity::PER_TENSOR) { + FixpipeParams.quantPre = DEQF16; + FixpipeParams.deqScalar = gm_dequant_scale[0]; + Fixpipe(dstTensor, srcTensor, FixpipeParams); + } + } else if constexpr (std::is_same::value) { + GlobalTensor dstAccum = CreateGlobalTensor(gm_accum + offset_c); + Fixpipe(dstAccum, srcTensor, FixpipeParams); + } + } else { + if constexpr (std::is_same::value) { + FixpipeParams.quantPre = F322BF16; + Fixpipe(dstTensor, srcTensor, FixpipeParams); + } else { + FixpipeParams.quantPre = F322F16; + Fixpipe(dstTensor, srcTensor, FixpipeParams); + } + } + } + + + + inline __aicore__ void RunAllToAllAllGatherMatmul(){ + InitFlags(); + int32_t k_actual; + if (aligned_a){ + k_actual = k_align; + } else { + k_actual = k; + } + + int64_t gm_a_pingpong_size = m0 * k_align * p_value * rank_size; + + 
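// Communication plan: m is processed in chunks of comm_m = m0 * p_value rows per round.
+        // When tokens are unevenly routed, comm_count is sized by the worst case over all EP
+        // peers of the tokens they send or receive, so every rank runs the same number of
+        // rounds. Illustrative numbers only: m0 = 128, p_value = 4 gives comm_m = 512, and a
+        // peer exchanging at most 1500 tokens needs DivCeil(1500, 512) = 3 rounds.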
+        int32_t comm_m = m0 * p_value;
+        int32_t comm_count;
+        if (is_moe_averaged) {
+            comm_count = DivCeil(m / EP, comm_m);
+        } else {
+            int32_t max_comm_count = 0;
+            int32_t max_input_per_ep = 0;
+            int32_t max_output_per_ep = 0;
+            for (int32_t ep_idx = 0; ep_idx < EP; ep_idx++) {
+                int32_t tmp_sum = 0;
+                for (int32_t i = 0; i < local_expert_nums; i++) {
+                    tmp_sum += global_tokens_per_expert_matrix[rank * expert_nums + ep_idx * local_expert_nums + i];
+                }
+                max_output_per_ep = max(max_output_per_ep, tmp_sum);
+                tmp_sum = 0;
+                for (int32_t i = 0; i < local_expert_nums; i++) {
+                    tmp_sum += global_tokens_per_expert_matrix[ep_idx * expert_nums + rank * local_expert_nums + i];
+                }
+                max_input_per_ep = max(max_input_per_ep, tmp_sum);
+                max_comm_count = max(max_comm_count, max(max_output_per_ep, max_input_per_ep));
+            }
+            comm_count = DivCeil(max_comm_count, comm_m);
+        }
+        int32_t in_expert_offset[16 * 16] = {0}; // [i][j]: current offset of rank j's data within expert i
+        int32_t data_len_in_expert_from_rank[16] = {0}; // number of tokens the current expert receives from rank[i]
+        int32_t before_expert_offset_dst[16] = {0};
+        int32_t before_rank_offset[16 * 16] = {0};
+
+        for (int32_t local_expert_idx = 0; local_expert_idx < local_expert_nums; local_expert_idx++) {
+            int32_t expert_idx = rank * local_expert_nums + local_expert_idx;
+            if (is_moe_averaged) {
+                before_expert_offset_dst[local_expert_idx] = local_expert_idx * (m / local_expert_nums);
+            } else {
+                for (int32_t i = 0; i < local_expert_idx; i++) {
+                    for (int32_t j = 0; j < rank_size; j++) {
+                        before_expert_offset_dst[local_expert_idx] +=
+                            global_tokens_per_expert_matrix[j * expert_nums + i + rank * local_expert_nums];
+                    }
+                }
+            }
+
+            for (int i = 0; i < rank_size - 1; i++) {
+                int32_t tmp_len;
+                if (is_moe_averaged) {
+                    tmp_len = m / local_expert_nums / rank_size;
+                } else {
+                    tmp_len = global_tokens_per_expert_matrix[i * expert_nums + expert_idx];
+                }
+                before_rank_offset[local_expert_idx * rank_size + i + 1] =
+                    before_rank_offset[local_expert_idx * rank_size + i] + tmp_len;
+            }
+        }
+        int32_t out_this_rank[16] = {0};
+        for (int32_t i = 0; i < rank_size; i++) {
+            for (int32_t j = 0; j < local_expert_nums; j++) {
+                if (is_moe_averaged) {
+                    out_this_rank[i] = m / EP;
+                } else {
+                    out_this_rank[i] += global_tokens_per_expert_matrix[i * expert_nums + j + rank * local_expert_nums];
+                }
+            }
+        }
+
+        int32_t sum_loop = 0;
+        for (int32_t comm_idx = 0; comm_idx < comm_count; comm_idx++) {
+            uint64_t flag_id = comm_idx % MAX_BLOCK_COUNT;
+            WaitEvent(flag_id);
+
+            for (int32_t local_expert_idx = 0; local_expert_idx < local_expert_nums; local_expert_idx++) {
+                int32_t expert_idx = rank * local_expert_nums + local_expert_idx;
+                int32_t expert_num_this_com = 0; // tokens this expert receives in this communication round
+                int32_t before_expert_offset_src = 0; // offset of this expert's data in shared memory for this round
+                for (int32_t i = 0; i < rank_size; i++) {
+                    int32_t data_len;
+                    if ((comm_idx + 1) * comm_m >= out_this_rank[i]) {
+                        data_len = out_this_rank[i] - comm_idx * comm_m;
+                    } else {
+                        data_len = comm_m;
+                    }
+
+                    // prefix sum of this rank's expert token counts
+                    int32_t sum = 0;
+                    for (int32_t j = 0; j <= local_expert_idx; j++) {
+                        int32_t tmp_expert_id = rank * local_expert_nums + j;
+                        int32_t expert_token_num;
+                        if (is_moe_averaged) {
+                            expert_token_num = (m / expert_nums);
+                        } else {
+                            expert_token_num = global_tokens_per_expert_matrix[i * expert_nums + tmp_expert_id];
+                        }
+
+                        if (comm_idx * comm_m < sum + expert_token_num && comm_idx * comm_m + data_len > sum) {
+                            int32_t tmp_len = min(comm_idx * comm_m + data_len, sum + expert_token_num) -
+                                max(comm_idx * comm_m, sum);
+                            if (j <
local_expert_idx) { + before_expert_offset_src += tmp_len; + } else { + expert_num_this_com += tmp_len; + data_len_in_expert_from_rank[i] = tmp_len; + } + } else { + if (j == local_expert_idx) { + data_len_in_expert_from_rank[i] = 0; + } + } + sum += expert_token_num; + } + } + + int32_t m_loop_in_expert = DivCeil(expert_num_this_com, m0); + int32_t loop_in_expert = m_loop_in_expert * n_loop; + for(int32_t loop_idx = 0; loop_idx < loop_in_expert; loop_idx ++) { + if ((loop_idx + sum_loop) % core_num != core_idx) { + continue; + } + int64_t m_idx, n_idx; + GetBlockIdx(loop_idx, m_loop_in_expert, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t m_actual = (m_idx == m_loop_in_expert - 1) ? expert_num_this_com - m_idx * m0 : m0; + int32_t n_actual = (n_idx == n_loop - 1) ? n - n_idx * n0 : n0; + __gm__ MmadDtype *gm_peer_mem_st = reinterpret_cast<__gm__ MmadDtype *>(gm_peer_mem) + + flag_id * gm_a_pingpong_size + + before_expert_offset_src * k_align; + + __gm__ MmadDtype *gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k * n_align; + if(TB){ + gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k_align * n; + } + if (weight_nz) { + gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k_align16 * n_align16; + } + CalLoop(0, m_idx, n_idx, m_actual, n_actual, gm_peer_mem_st, gm_b_src_tmp, k, k, n * local_expert_idx); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + int64_t rank_offset = 0; + int32_t l0c_offset = 0; + for (int32_t src_rank_id = 0; src_rank_id < rank_size; src_rank_id ++) { + if (rank_offset + data_len_in_expert_from_rank[src_rank_id] > m_idx * m0 && + rank_offset < m_idx * m0 + m_actual) { + int32_t rank_m_actual = min(m_idx * m0 + m_actual, rank_offset + data_len_in_expert_from_rank[src_rank_id]) - + max(rank_offset, m_idx * m0); + int32_t dst_stride = n; + int32_t tmp_in_rank_offset = 0; + if (m_idx * m0 > rank_offset) { + tmp_in_rank_offset = m_idx * m0 - rank_offset; + } + int64_t offset_c = before_expert_offset_dst[local_expert_idx] * n + before_rank_offset[local_expert_idx * rank_size + src_rank_id] * n + + in_expert_offset[local_expert_idx * rank_size + src_rank_id] * n + + tmp_in_rank_offset * n + + n_idx * n0; + int32_t src_stride = (m_actual + 15) / 16 * 16; + int32_t real_rank_m_actual = rank_m_actual; + int64_t m_offset_c = before_expert_offset_dst[local_expert_idx] + before_rank_offset[local_expert_idx * rank_size + src_rank_id] + + in_expert_offset[local_expert_idx * rank_size + src_rank_id] + tmp_in_rank_offset; + if (maxOutputSize > 0) { + if (maxOutputSize <= m_offset_c) { + real_rank_m_actual = 0; + } else if (m_offset_c + real_rank_m_actual > maxOutputSize) { + real_rank_m_actual = maxOutputSize - m_offset_c; + } + } + if (real_rank_m_actual > 0) { + MoveL0CToGM(gm_c, offset_c, l0c_offset, real_rank_m_actual, n_actual, src_stride, dst_stride); + } + l0c_offset += rank_m_actual * 16; + } + rank_offset += data_len_in_expert_from_rank[src_rank_id]; + } + + if (IS_INT8) { + if constexpr (std::is_same::value) { + if (dequant_granularity == QuantGranularity::PER_CHANNEL || dequant_granularity == QuantGranularity::PER_TOKEN) { + SetFlag(EVENT_ID0); + } + } + } + SetFlag(EVENT_ID0); + if (IS_INT8 && has_offset) { + SetFlag(EVENT_ID1); + } + } + + for (int32_t i = 0; i < rank_size; i ++) { + in_expert_offset[local_expert_idx * rank_size + i] += data_len_in_expert_from_rank[i]; + } + sum_loop += loop_in_expert; + } + FFTSCrossCoreSync(2, flag_id); + } + Endflags(); + PipeBarrier(); + } + + + inline __aicore__ void RunAllToAllAllGatherMatmulHidden(){ + 
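+        // [Illustrative sketch, excluded from compilation] The "Hidden" variant
+        // tiles the reduction dimension: each communication step ships a
+        // comm_k-wide slice of A, and from the second slice onward the partial
+        // products are accumulated into the output via atomic add. A host-side
+        // reference of that accumulation order (matmul_slice is a hypothetical
+        // dense helper, not part of this kernel):
+#if 0
+        for (int32_t ci = 0; ci < (k + comm_k - 1) / comm_k; ++ci) {
+            int32_t ks = ci * comm_k;
+            int32_t kl = (ks + comm_k <= k) ? comm_k : k - ks;
+            // C += A[:, ks:ks+kl] * B[ks:ks+kl, :]; ci == 0 initializes C.
+            matmul_slice(C, A, B, ks, kl, /*accumulate=*/ci > 0);
+        }
+#endif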
InitFlags();
+        int32_t max_m;
+        int32_t sum_m[16] = {0};
+        int32_t sum_m_loop = 0;
+        if (is_moe_averaged) {
+            sum_m_loop = DivCeil((m / expert_nums) * EP, m0) * local_expert_nums;
+            max_m = m;
+        } else {
+            if (maxOutputSize == -1) {
+                max_m = 0;
+                for (int32_t ep_idx = 0; ep_idx < EP; ep_idx++) {
+                    int32_t sum_m_ep = 0;
+                    for (int32_t local_expert_id = 0; local_expert_id < local_expert_nums; local_expert_id++) {
+                        int32_t expert_id = local_expert_id + ep_idx * local_expert_nums;
+                        for (int32_t i = 0; i < EP; i++) {
+                            sum_m_ep += global_tokens_per_expert_matrix[i * expert_nums + expert_id];
+                        }
+                    }
+                    max_m = max(max_m, sum_m_ep);
+                }
+            } else {
+                max_m = maxOutputSize;
+            }
+            for (int32_t i = 0; i < local_expert_nums; i++) {
+                int32_t last_sum_m = (i == 0 ? 0 : sum_m[i - 1]);
+                for (int j = 0; j < EP; j++) {
+                    sum_m[i] += global_tokens_per_expert_matrix[j * expert_nums + rank * local_expert_nums + i];
+                    //global_tokens_per_expert_matrix[j][rank * local_expert_nums + i]
+                }
+                if (maxOutputSize > 0 && sum_m[i] + last_sum_m > maxOutputSize) {
+                    sum_m[i] = maxOutputSize - last_sum_m;
+                }
+                sum_m_loop += DivCeil(sum_m[i], m0);
+                sum_m[i] += (i == 0 ? 0 : sum_m[i - 1]);
+            }
+        }
+
+        int32_t comm_k = k0 * p_value;
+        int64_t gm_a_pingpong_size = comm_k * max_m;
+        int64_t gm_a_pingpong_num = buffer_size * 1024 * 1024 / sizeof(MmadDtype) / gm_a_pingpong_size;
+        if (gm_a_pingpong_num > 8) {
+            gm_a_pingpong_num = 8;
+        }
+        int32_t comm_count = DivCeil(k, comm_k);
+        int32_t sum_loop_num = sum_m_loop * n_loop;
+        int32_t sum_loop = 0;
+        //SetAtomicAdd();
+        for (int32_t comm_idx = 0; comm_idx < comm_count; comm_idx++) {
+            if (comm_idx == 1) {
+                PipeBarrier();
+                SetAtomicAdd();
+                PipeBarrier();
+            }
+            int32_t k_len;
+            if (comm_idx == comm_count - 1) {
+                k_len = k - comm_idx * comm_k;
+            } else {
+                k_len = comm_k;
+            }
+            if (comm_idx == 1) {
+                PipeBarrier();
+                SetAtomicAdd();
+                PipeBarrier();
+                FFTSCrossCoreSync(0, AIC_FINISH_MATMUL_FLAG_ID);
+                WaitEvent(AIC_FINISH_MATMUL_FLAG_ID);
+            }
+
+            uint64_t flag_id = comm_idx % gm_a_pingpong_num;
+            WaitEvent(flag_id);
+            for (int32_t loop_idx = 0; loop_idx < sum_loop_num; loop_idx++) {
+                if ((loop_idx + sum_loop) % core_num != core_idx) {
+                    continue;
+                }
+                int64_t m_idx, n_idx;
+                GetBlockIdx(loop_idx, sum_m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx);
+                /*
+                 * 1. First determine which expert m_idx and n_idx belong to.
+                 * 2. Then compute the coordinates within that expert.
+                 */
+                int32_t sum_loop_before = 0;
+                int32_t local_expert_idx = -1;
+                int32_t m_in_expert;
+                for (int32_t i = 0; i < local_expert_nums; i++) {
+                    if (is_moe_averaged) {
+                        m_in_expert = m / local_expert_nums;
+                    } else {
+                        m_in_expert = sum_m[i] - (i == 0 ? 0 : sum_m[i - 1]);
+                    }
+                    sum_loop_before += DivCeil(m_in_expert, m0);
+                    if (sum_loop_before > m_idx) {
+                        local_expert_idx = i;
+                        break;
+                    }
+                }
+                int32_t m_loop_in_expert = DivCeil(m_in_expert, m0);
+                sum_loop_before -= m_loop_in_expert;
+                int32_t m_idx_in_expert = m_idx - sum_loop_before;
+                int32_t m_actual = (m_idx_in_expert == m_loop_in_expert - 1 ? m_in_expert - m_idx_in_expert * m0 : m0);
+                int32_t n_actual = (n_idx == n_loop - 1) ? 
n - n_idx * n0 : n0; + int32_t sum_m_before; + if(is_moe_averaged) { + sum_m_before = local_expert_idx * (m / local_expert_nums); + } else { + sum_m_before = sum_m[local_expert_idx] - m_in_expert; + } + + __gm__ MmadDtype *gm_peer_mem_st = reinterpret_cast<__gm__ MmadDtype *>(gm_peer_mem) + + 1LL * flag_id * gm_a_pingpong_size + + 1LL * sum_m_before * k_len; + __gm__ MmadDtype *gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k * n_align + 1LL * comm_idx * comm_k * n_align; + //__gm__ MmadDtype *gm_b_src_tmp = gm_b_src; + if(TB){ + gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k_align * n + 1LL * comm_idx * comm_k; + } + if (weight_nz) { + gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k_align16 * n_align16 + 1LL * comm_idx * comm_k * block_size; + //gm_b_src_tmp = gm_b_src; + } + //CalLoop(0, m_idx_in_expert, n_idx, m_actual, n_actual, gm_peer_mem_st, gm_b_src_tmp); + CalLoop(0, m_idx_in_expert, n_idx, m_actual, n_actual, gm_peer_mem_st, gm_b_src_tmp, k_len, k_align, n * local_expert_idx); + //CalLoop(local_expert_idx, m_idx_in_expert, n_idx, m_actual, n_actual, gm_peer_mem_st, gm_b_src_tmp, k_len, k); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + int64_t offset_c = 1LL * sum_m_before * n + 1LL * m_idx_in_expert * m0 * n + 1LL * n_idx * n0; + MoveL0CToGM(gm_c, offset_c, m_actual, n_actual, (m_actual + 15) / 16 * 16, n); + } + sum_loop += sum_loop_num; + has_offset = 0; + FFTSCrossCoreSync(2, flag_id); + } + PipeBarrier(); + SetAtomicNone(); + PipeBarrier(); + + Endflags(); + PipeBarrier(); + } + + + inline __aicore__ void RunMatmulReduceScatterAllToAllHidden(){ + InitFlags(); + int32_t comm_n = p_value * n0; + int32_t cal_count = DivCeil(n, comm_n); + int32_t max_m; + int32_t sum_m[16] = {0}; + int32_t sum_m_loop = 0; + if(is_moe_averaged) { + sum_m_loop = DivCeil((m / expert_nums) * EP, m0) * local_expert_nums; + max_m = m; + } else { + if (maxOutputSize == -1) { + max_m = 0; + for(int32_t ep_idx = 0; ep_idx < EP; ep_idx ++) { + int32_t sum_m_ep = 0; + for(int32_t local_expert_id = 0; local_expert_id < local_expert_nums; local_expert_id ++) { + int32_t expert_id = local_expert_id + ep_idx * local_expert_nums; + for(int32_t i = 0; i < EP; i++) { + sum_m_ep += global_tokens_per_expert_matrix[i * expert_nums + expert_id]; + } + } + max_m = max(max_m, sum_m_ep); + } + } else { + max_m = maxOutputSize; + } + for(int32_t i = 0; i < local_expert_nums; i++){ + int32_t last_sum_m = (i == 0 ? 0 : sum_m[i - 1]); + for(int j = 0; j < EP; j++) { + sum_m[i] += global_tokens_per_expert_matrix[j * expert_nums + rank * local_expert_nums + i]; + //global_tokens_per_expert_matrix[j][rank * local_expert_nums + i] + } + if (maxOutputSize > 0 && sum_m[i] + last_sum_m > maxOutputSize) { + sum_m[i] = maxOutputSize - last_sum_m; + } + sum_m_loop += DivCeil(sum_m[i], m0); + sum_m[i] += (i == 0 ? 
0 : sum_m[i - 1]); + } + } + + int64_t gm_a_pingpong_size = comm_n * max_m; + int64_t gm_a_pingpong_num = buffer_size * 1024 * 1024 / 2 / gm_a_pingpong_size; + if (gm_a_pingpong_num > 8) { + gm_a_pingpong_num = 8; + } + int32_t sum_loop = 0; + for (int32_t cal_idx = 0; cal_idx < cal_count; cal_idx++) { + int32_t n_len; + if(cal_idx == cal_count - 1) { + n_len = n - cal_idx * comm_n; + } else { + n_len = comm_n; + } + n_loop = DivCeil(n_len,n0); + int32_t sum_loop_num = sum_m_loop * n_loop; + int32_t flag_id = cal_idx % gm_a_pingpong_num; + WaitEvent(flag_id); + for(int32_t loop_idx = 0; loop_idx < sum_loop_num; loop_idx ++) { + if((loop_idx + sum_loop) % core_num != core_idx) { + continue; + } + int64_t m_idx, n_idx; + GetBlockIdx(loop_idx, sum_m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t sum_loop_before = 0; + int32_t local_expert_idx = -1; + int32_t m_in_expert; + for(int32_t i = 0; i < local_expert_nums; i++) { + if(is_moe_averaged) { + m_in_expert = m / local_expert_nums; + } else { + m_in_expert = sum_m[i] - (i == 0 ? 0 : sum_m[i - 1]); + } + sum_loop_before += DivCeil(m_in_expert, m0); + if(sum_loop_before > m_idx) { + local_expert_idx = i; + break; + } + } + int32_t m_loop_in_expert = DivCeil(m_in_expert, m0); + sum_loop_before -= m_loop_in_expert; + int32_t m_idx_in_expert = m_idx - sum_loop_before; + int32_t m_actual = ((m_idx_in_expert == m_loop_in_expert - 1) ? (m_in_expert - m_idx_in_expert * m0) : m0); + int32_t n_actual = ((n_idx == n_loop - 1) ? (n_len - n_idx * n0) : n0); + + int32_t sum_m_before = 0; + if(is_moe_averaged) { + sum_m_before = local_expert_idx * (m / local_expert_nums); + } else { + sum_m_before = sum_m[local_expert_idx] - m_in_expert; + } + __gm__ MmadDtype *gm_a_src_inner = gm_a_src + 1LL * sum_m_before * k_align; + __gm__ MmadDtype *gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k * n_align; + // __gm__ MmadDtype *gm_b_src_tmp = gm_b_src; + if(TB){ + gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k_align * n; + } + if (weight_nz) { + gm_b_src_tmp = gm_b_src + 1LL * local_expert_idx * k_align16 * n_align16; + //gm_b_src_tmp = gm_b_src; + } + + int32_t real_n_idx = n_idx + cal_idx * comm_n / n0; + CalLoop(0, m_idx_in_expert, real_n_idx, m_actual, n_actual, gm_a_src_inner, gm_b_src_tmp, k, k, n * local_expert_idx); + //CalLoop(0, m_idx_in_expert, real_n_idx, m_actual, n_actual, gm_a_src_inner); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + int32_t dst_stride = n_len; + __gm__ OutDtype *gm_out = reinterpret_cast<__gm__ OutDtype *>(gm_peer_mem); + int32_t offset_c = flag_id * gm_a_pingpong_size + 1LL * (sum_m_before + m_idx_in_expert * m0) * n_len + 1LL * n_idx * n0; + MoveL0CToGM(gm_out, offset_c, m_actual, n_actual, (m_actual + 15) / 16 * 16, dst_stride); + } + sum_loop += sum_loop_num; + FFTSCrossCoreSync(2, flag_id); + } + Endflags(); + PipeBarrier(); + } + + inline __aicore__ void Run() { + + if(RUN_TYPE == PPMATMUL_RUN_ALL_TO_ALL_ALL_GATHER_MATMUL_HIDDEN) { + RunAllToAllAllGatherMatmulHidden(); + } else if(RUN_TYPE == PPMATMUL_RUN_MATMUL_REDUCE_SCATTER_ALL_TO_ALL_HIDDEN){ + RunMatmulReduceScatterAllToAllHidden(); + } else if (RUN_TYPE == PPMATMUL_RUN_ALL_TO_ALL_ALL_GATHER_MATMUL) { + RunAllToAllAllGatherMatmul(); + } + } + using PpMatmul::gm_a_src; + using PpMatmul::gm_b_src; + using PpMatmul::gm_c; + using PpMatmul::gm_peer_mem; + using PpMatmul::gm_dequant_scale; + using PpMatmul::gm_format_dequant_offset; + using PpMatmul::gm_accum; + using PpMatmul::l1_base_a; + using PpMatmul::l1_base_b; + using PpMatmul::l0a_base; 
+ using PpMatmul::l0b_base; + using PpMatmul::l0c_buf; + using PpMatmul::scale_l1; + using PpMatmul::scale_FB; + using PpMatmul::bias_l1; + using PpMatmul::bias_bt; + using PpMatmul::has_offset; + using PpMatmul::core_num; + using PpMatmul::batch_size; + using PpMatmul::m; + using PpMatmul::k; + using PpMatmul::n; + using PpMatmul::m_align; + using PpMatmul::k_align; + using PpMatmul::n_align; + using PpMatmul::k_align16; + using PpMatmul::n_align16; + using PpMatmul::m0; + using PpMatmul::k0; + using PpMatmul::n0; + using PpMatmul::m_loop; + using PpMatmul::n_loop; + using PpMatmul::k_loop; + using PpMatmul::core_loop; + using PpMatmul::core_idx; + using PpMatmul::ping_flag; + using PpMatmul::block_size; + using PpMatmul::cube_matrix_size; + using PpMatmul::aligned_a; + using PpMatmul::aligned_b; + using PpMatmul::swizzl_count; + using PpMatmul::swizzl_direct; + using PpMatmul::L1_PINGPONG_BUFFER_LEN; + using PpMatmul::L0AB_PINGPONG_BUFFER_LEN; + using PpMatmul::rank; + using PpMatmul::rank_size; + using PpMatmul::p_value; + using PpMatmul::loop_num_per_comm; + using PpMatmul::InitFlags; + using PpMatmul::Endflags; + using PpMatmul::MoveL0CToGM; + using PpMatmul::dequant_granularity; + using PpMatmul::workspace_info; + using PpMatmul::withSerialMode; + using PpMatmul::weight_nz; + using PpMatmul::CalLoop; + using PpMatmul::buffer_size; + +private: + int32_t EP; + int32_t TP; + int32_t maxOutputSize; + __gm__ int32_t *num_local_tokens_per_expert; + __gm__ int32_t *num_global_tokens_per_local_expert; + __gm__ int32_t * global_tokens_per_expert_matrix; + + __gm__ int32_t* gm_out_loop_per_expert; + __gm__ int32_t* gm_in_loop_per_expert; + __gm__ int32_t* gm_out_loop_per_EP; + __gm__ int32_t* gm_in_loop_per_EP; + __gm__ int32_t* gm_sum_num_local_tokens_per_expert; + __gm__ int32_t* gm_sum_num_global_tokens_per_local_expert; + __gm__ int32_t* gm_num_local_tokens_per_expert; + __gm__ int32_t* gm_num_global_tokens_per_local_expert; + __gm__ int32_t *gm_in_expert_comm_count_accum; + __gm__ int32_t *gm_out_expert_comm_count_accum; + + int32_t expert_nums; + int32_t local_expert_nums; + int32_t is_moe_averaged; + int32_t is_alltoallvc; +}; +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_postprocessor.cce b/comm/lcal/src/kernels/coc_postprocessor.cce new file mode 100644 index 0000000000000000000000000000000000000000..9609cc18ee127f03d8f260f9896a2176503bbf75 --- /dev/null +++ b/comm/lcal/src/kernels/coc_postprocessor.cce @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#ifndef __COC_POSTPROCESSOR__
+#define __COC_POSTPROCESSOR__
+
+#ifdef __DAV_C220_VEC__
+
+#include 
+#include "coc_internal.cce"
+#include "kernel_operator.h"
+#include "tiling_args.h"
+using namespace AscendC;
+
+constexpr int32_t BUFFER_NUM = 1; // tensor num for each queue
+constexpr int32_t NUM_PER_REP_FP32 = 64; // ONE_REPEAT_BYTE_SIZE / sizeof(float);
+constexpr int32_t NUM_PER_BLK_FP32 = 8;
+constexpr float MINUS_HALF = -0.5;
+constexpr float ZERO = 0;
+constexpr float ONE = 1;
+
+template 
+class RMSNormprocessor {
+public:
+    __aicore__ explicit RMSNormprocessor() = default;
+    FORCE_INLINE_AICORE void SetArgs(__gm__ uint8_t *gm_in, __gm__ uint8_t *gm_out, __gm__ uint8_t *gm_gamma,
+        uint32_t m, uint32_t n)
+    {
+        this->rmsnorm_in = reinterpret_cast<__gm__ T_in *>(gm_out);
+        this->rmsnorm_gamma = reinterpret_cast<__gm__ T_out *>(gm_gamma);
+        this->rmsnorm_out = reinterpret_cast<__gm__ T_out *>(gm_out);
+        this->m = m;
+        this->n = n;
+        this->core_used = core_used;
+    }
+
+    // Only float16 is supported for now
+    struct UBufConfig {
+        int64_t global_subblock_idx;
+        int64_t total_subblock;
+        __ubuf__ half *gamma;
+        __ubuf__ half *fp16_0;
+        __ubuf__ float *fp32_0;
+        __ubuf__ float *sqx0;
+        __ubuf__ float *sum_tmp0;
+        __ubuf__ float *sum0;
+        __ubuf__ half *fp16_1;
+        __ubuf__ float *fp32_1;
+        __ubuf__ float *sqx1;
+        __ubuf__ float *sum_tmp1;
+        __ubuf__ float *sum1;
+        float epsilon;
+        bool ping;
+    };
+
+    FORCE_INLINE_AICORE UBufConfig InitializeUBufConfig()
+    {
+        UBufConfig config;
+        config.global_subblock_idx = AscendC::GetBlockIdx();
+        config.total_subblock = AscendC::GetBlockNum() * AscendC::GetTaskRation();
+
+        config.gamma = (__ubuf__ half *)get_imm(0);
+        config.fp16_0 = (__ubuf__ half *)get_imm(1 * 16 * 1024);
+        config.fp32_0 = (__ubuf__ float *)get_imm(2 * 16 * 1024);
+        config.sqx0 = (__ubuf__ float *)get_imm(4 * 16 * 1024);
+        config.sum_tmp0 = (__ubuf__ float *)config.fp16_0;
+        config.sum0 = (__ubuf__ float *)config.fp16_0 + 64;
+
+        config.fp16_1 = (__ubuf__ half *)get_imm(1 * 16 * 1024 + 96 * 1024);
+        config.fp32_1 = (__ubuf__ float *)get_imm(2 * 16 * 1024 + 96 * 1024);
+        config.sqx1 = (__ubuf__ float *)get_imm(4 * 16 * 1024 + 96 * 1024);
+        config.sum_tmp1 = (__ubuf__ float *)config.fp16_1;
+        config.sum1 = (__ubuf__ float *)config.fp16_1 + 64;
+        config.epsilon = 1e-6;
+        config.ping = true;
+        return config;
+    }
+
+    FORCE_INLINE_AICORE void RMSNormRun()
+    {
+        SetMaskCount();
+        SetAtomicNone();
+
+        UBufConfig ubufConfig = InitializeUBufConfig();
+
+        CopyGmToUbufAlign(ubufConfig.gamma, (__gm__ half *)rmsnorm_gamma, 1, n, 0, 0);
+        PipeBarrier();
+        SetFlag(EVENT_ID0);
+        SetFlag(EVENT_ID1);
+
+        for (int64_t global_row_id = ubufConfig.global_subblock_idx % ubufConfig.total_subblock; global_row_id < m;
+            global_row_id += ubufConfig.total_subblock) {
+            auto &fp16 = ubufConfig.ping ? ubufConfig.fp16_0 : ubufConfig.fp16_1;
+            auto &fp32 = ubufConfig.ping ? ubufConfig.fp32_0 : ubufConfig.fp32_1;
+            auto &sqx = ubufConfig.ping ? ubufConfig.sqx0 : ubufConfig.sqx1;
+            auto &sum_tmp = ubufConfig.ping ? ubufConfig.sum_tmp0 : ubufConfig.sum_tmp1;
+            auto &sum = ubufConfig.ping ? ubufConfig.sum0 : ubufConfig.sum1;
+            auto event_id = ubufConfig.ping ? 
EVENT_ID0 : EVENT_ID1; + + WaitFlag(event_id); + CopyGmToUbufAlign(fp16, (__gm__ half *)rmsnorm_in + global_row_id * n, 1, n, 0, 0); + SetFlag(event_id); + WaitFlag(event_id); + + // fp16 -> fp32 + SetVectorMask(0x0, n); + Vconv(((__ubuf__ float *)fp32), ((__ubuf__ half *)fp16), 1, 1, 1, 8, 4); + PipeBarrier(); + + // x^2 + Vmul(((__ubuf__ float *)sqx), ((__ubuf__ float *)fp32), ((__ubuf__ float *)fp32), 1, 1, 1, 1, 8, 8, 8); + PipeBarrier(); + + // x^2 / n + float average_val = 1.f / n; + Vmuls(((__ubuf__ float *)sqx), ((__ubuf__ float *)sqx), average_val, 1, 1, 1, 8, 8); + PipeBarrier(); + + // sum(x^2 / n) + SetVectorMask(0x0, 64); + VectorDup(((__ubuf__ float *)sum_tmp), 0.f, 1, 1, 8); + PipeBarrier(); + + SetVectorMask(0x0, n); + Vadd(((__ubuf__ float *)sum_tmp), ((__ubuf__ float *)sqx), ((__ubuf__ float *)sum_tmp), 1, 1, 1, 1, 0, 8, + 0); + PipeBarrier(); + + SetVectorMask(0x0, 64); + vcadd(((__ubuf__ float *)sum), ((__ubuf__ float *)sum_tmp), 1, 0, 1, 0, 0); + PipeBarrier(); + + // x * 1 / sqrt(sum(x^2 / n) + eps) + SetVectorMask(0x0, n); + SetFlag(event_id); + WaitFlag(event_id); + float mul_val = 1.f / sqrt(sum[0] + ubufConfig.epsilon); + PipeBarrier(); + SetFlag(event_id); + WaitFlag(event_id); + Vmuls(((__ubuf__ float *)fp32), ((__ubuf__ float *)fp32), mul_val, 1, 1, 1, 8, 8); + PipeBarrier(); + + // fp32 -> fp16 + Vconv(((__ubuf__ half *)fp16), ((__ubuf__ float *)fp32), 1, 1, 1, 4, 8); + PipeBarrier(); + + // x * 1 / sqrt(sum(x^2 / n) + eps) * g + Vmul(((__ubuf__ half *)fp16), ((__ubuf__ half *)fp16), ((__ubuf__ half *)ubufConfig.gamma), 1, 1, 1, 1, 8, + 8, 8); + PipeBarrier(); + SetFlag(event_id); + WaitFlag(event_id); + + CopyUbufToGmAlign((__gm__ half *)rmsnorm_out + global_row_id * n, fp16, 1, n, 0, 0); + SetFlag(event_id); + ubufConfig.ping = !ubufConfig.ping; + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + } + +private: + __gm__ T_out *rmsnorm_gamma; + __gm__ T_in *rmsnorm_in; + __gm__ T_out *rmsnorm_out; + int32_t m; + int32_t n; + int32_t core_used; +}; + +template +class Postprocessor { +public: + __aicore__ explicit Postprocessor() = default; + + FORCE_INLINE_AICORE void SetArgs(PP_MATMUL_AIV_POST_ARGS_FUN()) + { + auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm); + auto cocTilingData = ¶->cocTilingData; + this->with_rms_norm = para->postInfo.withRmsNorm; + if (this->with_rms_norm) { + uint32_t m = cocTilingData->m; + uint32_t n = cocTilingData->n; + rmsnormprocessor.SetArgs(gm_out, gm_out, gm_gamma, m, n); + } + } + + FORCE_INLINE_AICORE void Run() + { + // mode, flag_id + FFTSCrossCoreSync(0, 0); + WaitEvent(0); + if (this->with_rms_norm) { + rmsnormprocessor.RMSNormRun(); + } + } + +private: + int32_t with_rms_norm; + RMSNormprocessor rmsnormprocessor; +}; + +#endif +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_ppmatmul.cce b/comm/lcal/src/kernels/coc_ppmatmul.cce new file mode 100644 index 0000000000000000000000000000000000000000..ac9521d26f029179cace90d848fbf8206907449f --- /dev/null +++ b/comm/lcal/src/kernels/coc_ppmatmul.cce @@ -0,0 +1,1471 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef __PP_MATMUL__ +#define __PP_MATMUL__ +#include "coc_internal.cce" +#include "lcoc_workspace.h" +template +struct GetAccumType { + using T = float; +}; + +template <> +struct GetAccumType { + using T = int32_t; +}; + +#ifdef __DAV_C220_CUBE__ + +constexpr int32_t L0AB_PINGPONG_BUFFER_SIZE = 32768; // 32 KB +constexpr int32_t CUBE_MATRIX_SIZE_B16 = 256; // 16 * 16 +constexpr int32_t CUBE_MATRIX_SIZE_B8 = 16 * 32; // 16 * 32 +constexpr int64_t ND2NZ_STRIDE_LIMIT = 65536; +constexpr int32_t SCALE_L1_SIZE = 256 * 8; // 2 KB + +template +inline __aicore__ void CopyCubfToBt(uint64_t dst, __cbuf__ T *src, uint16_t convControl, uint16_t nBurst, uint16_t lenBurst, uint16_t sourceGap, uint16_t dstGap) +{ + DataCopyParams intriParams(nBurst, lenBurst, sourceGap, dstGap); + uint32_t src_buffer_offset = reinterpret_cast(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t src_logicpos = static_cast(TPosition::C1); // L1 + uint8_t dst_logicpos = static_cast(TPosition::C2); // Bias + LocalTensor srcTensor; + LocalTensor dstTensor; + srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos); + dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos); + DataCopy(dstTensor, srcTensor, intriParams); +} + +template +inline __aicore__ void CopyGmToCbuf(__cbuf__ T *dst, __gm__ T *src, uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride, pad_t padMode) +{ + DataCopyParams intriParams(nBurst, lenBurst, srcStride, dstStride); + GlobalTensor srcTensor; + srcTensor.SetGlobalBuffer(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t logicpos = static_cast(TPosition::C1); // L1 + LocalTensor dstTensor; + dstTensor = CreateLocalTensor(dst_buffer_offset, logicpos); + DataCopy(dstTensor, srcTensor, intriParams); +} + + +template +inline __aicore__ void SetFpc(__fbuf__ T *src) +{ + LocalTensor tensor; + uint32_t src_buffer_offset = reinterpret_cast(src); + tensor = CreateLocalTensor(src_buffer_offset); + SetFixPipeConfig(tensor); +} + + +template +inline __aicore__ void LoadCbufToCaTranspose(__ca__ T *dst, __cbuf__ T *src, uint16_t indexID, uint8_t repeat, uint16_t srcStride, uint16_t dstStride, bool addrmode, uint16_t dstFracStride) +{ + LoadData2dTransposeParams params( + indexID, + repeat, + srcStride, + dstStride, + dstFracStride, + addrmode + ); + uint32_t src_buffer_offset = reinterpret_cast(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); + uint8_t src_logicpos = static_cast(TPosition::C1); // L1 + uint8_t dst_logicpos = static_cast(TPosition::A2); // L0A + LocalTensor srcTensor; + LocalTensor dstTensor; + srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos); + dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos); + LoadDataWithTranspose(dstTensor, srcTensor, params); +} + +template +inline __aicore__ void LoadCbufToCbTranspose(__cb__ T *dst, __cbuf__ T *src, uint16_t indexID, uint8_t repeat, uint16_t srcStride, uint16_t dstStride, bool addrmode, uint16_t dstFracStride) +{ + LoadData2dTransposeParams params( + indexID, + repeat, + srcStride, + dstStride, + dstFracStride, + addrmode + ); + uint32_t src_buffer_offset = reinterpret_cast(src); + uint32_t dst_buffer_offset = reinterpret_cast(dst); 
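+    // These load/copy wrappers bridge raw hardware-qualified pointers to the
+    // AscendC tensor API: the pointer value is reused as a buffer offset and
+    // tagged with a logical tensor position below, so the DataCopy / LoadData
+    // calls can consume it.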
+    uint8_t src_logicpos = static_cast(TPosition::C1); // L1
+    uint8_t dst_logicpos = static_cast(TPosition::B2); // L0B
+    LocalTensor srcTensor;
+    LocalTensor dstTensor;
+    srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos);
+    dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos);
+    LoadDataWithTranspose(dstTensor, srcTensor, params);
+}
+
+template 
+inline __aicore__ void LoadCbufToCa(__ca__ T *dst, __cbuf__ T *src, uint16_t baseIdx, uint8_t repeat, uint16_t srcStride, uint16_t dstStride, uint8_t sid, bool transpose, uint8_t addr_cal_mode)
+{
+    LoadData2dParams params(
+        baseIdx,
+        repeat,
+        srcStride,
+        sid,
+        dstStride,
+        transpose,
+        addr_cal_mode
+    );
+    uint32_t src_buffer_offset = reinterpret_cast(src);
+    uint32_t dst_buffer_offset = reinterpret_cast(dst);
+    uint8_t src_logicpos = static_cast(TPosition::C1); // L1
+    uint8_t dst_logicpos = static_cast(TPosition::A2); // L0A
+    LocalTensor srcTensor;
+    LocalTensor dstTensor;
+    srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos);
+    dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos);
+    LoadData(dstTensor, srcTensor, params);
+}
+
+template 
+inline __aicore__ void LoadCbufToCb(__cb__ T *dst, __cbuf__ T *src, uint16_t baseIdx, uint8_t repeat, uint16_t srcStride, uint16_t dstStride, uint8_t sid, bool transpose, uint8_t addr_cal_mode)
+{
+    LoadData2dParams params(
+        baseIdx,
+        repeat,
+        srcStride,
+        sid,
+        dstStride,
+        transpose,
+        addr_cal_mode
+    );
+    uint32_t src_buffer_offset = reinterpret_cast(src);
+    uint32_t dst_buffer_offset = reinterpret_cast(dst);
+    uint8_t src_logicpos = static_cast(TPosition::C1); // L1
+    uint8_t dst_logicpos = static_cast(TPosition::B2); // L0B
+    LocalTensor srcTensor;
+    LocalTensor dstTensor;
+    srcTensor = CreateLocalTensor(src_buffer_offset, src_logicpos);
+    dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos);
+    LoadData(dstTensor, srcTensor, params);
+}
+
+template 
+struct IntrinsicCopyGmToL1Nd2Nz {
+    static inline __aicore__ void move(
+        __cbuf__ T *dst, __gm__ T *src,
+        uint8_t sid, uint16_t ndNum, uint16_t nValue, uint16_t dValue,
+        uint16_t srcNdMatrixStride, uint16_t srcDValue, uint16_t dstNzC0Stride,
+        uint16_t dstNzNStride, uint16_t dstNzMatrixStride) {
+        Nd2NzParams nd2nzParams(
+            ndNum, nValue, dValue,
+            srcNdMatrixStride, srcDValue, dstNzC0Stride,
+            dstNzNStride, dstNzMatrixStride
+        );
+        uint32_t dst_buffer_offset = reinterpret_cast(dst);
+        uint8_t dst_logicpos = static_cast(TPosition::C1);
+        LocalTensor dstTensor;
+        dstTensor = CreateLocalTensor(dst_buffer_offset, dst_logicpos);
+        GlobalTensor srcTensor;
+        srcTensor.SetGlobalBuffer(src);
+        DataCopy(dstTensor, srcTensor, nd2nzParams);
+    }
+};
+
+template 
+struct CopyGmToL1Nd2zN {
+    static inline __aicore__ void move(
+        __cbuf__ T *dst, __gm__ T *src,
+        uint16_t nValue, uint16_t dValue, uint32_t srcDValue, uint16_t dstNzC0Stride) {
+        constexpr int BLOCK_LEN = 32 / sizeof(T);
+        if (srcDValue < ND2NZ_STRIDE_LIMIT) {
+            IntrinsicCopyGmToL1Nd2Nz::move(
+                dst,
+                src,
+                0, // sid
+                1, // ndNum
+                nValue, // nValue
+                dValue, // dValue
+                0, // srcNdMatrixStride, unused
+                srcDValue, // srcDValue
+                dstNzC0Stride, // dstNzC0Stride
+                1, // dstNzNStride,
+                0 // dstNzMatrixStride, unused
+            );
+        } else {
+            for (int i = 0; i < nValue; i++) {
+                IntrinsicCopyGmToL1Nd2Nz::move(
+                    dst + i * BLOCK_LEN,
+                    src + i * srcDValue,
+                    0, // sid
+                    1, // ndNum
+                    1, // nValue
+                    dValue, // dValue
+                    0, // srcNdMatrixStride, unused
+                    0, // srcDValue, unused
+                    dstNzC0Stride, // dstNzC0Stride
+                    0, // dstNzNStride, unused
+                    0 // dstNzMatrixStride, unused
+                );
+            }
+        }
+    }
+};
+
+template 
+class PpMatmul {
+    using T_ACCUM = typename GetAccumType::T;
+    static constexpr bool IS_INT8 = std::is_same::value;
+public:
+    __aicore__ explicit PpMatmul() {};
+
+    inline __aicore__ void SetArgs(PP_MATMUL_AIC_ARGS_FUN(MmadDtype, OutDtype))
+    {
+        this->gm_c = reinterpret_cast<__gm__ OutDtype *>(gm_c);
+        this->gm_peer_mem = reinterpret_cast<__gm__ OutDtype *>(gm_peer_mem);
+        this->gm_dequant_scale = reinterpret_cast<__gm__ int64_t *>(gm_dequant_scale);
+        has_offset = gm_dequant_offset != nullptr;
+
+        this->batch_size = batch_size;
+        this->m = m;
+        this->k = k;
+        this->n = n;
+        this->weight_nz = weight_nz;
+
+        cube_matrix_size = IS_INT8 ? CUBE_MATRIX_SIZE_B8 : CUBE_MATRIX_SIZE_B16;
+
+        m_align = Block512B::AlignUp(m);
+        k_align = Block512B::AlignUp(k);
+        n_align = Block512B::AlignUp(n);
+
+        this->m0 = m0;
+        this->k0 = k0;
+        this->n0 = n0;
+
+        this->dequant_granularity = dequant_granularity;
+
+        AlignJudge(TA, TB, m, k, n, m_align, k_align, n_align, aligned_a, aligned_b);
+        bool has_a_align = IsQuant(quant_granularity) || aligned_a;
+        bool has_b_align = IsQuant(dequant_granularity) && !IS_INT8 || aligned_b;
+        if (weight_nz) {
+            //k_align16 = Block32B::AlignUp(k);
+            k_align16 = (k + 16 - 1) / 16 * 16;
+            n_align16 = Block32B::AlignUp(n);
+            aligned_b = 0; // don't do padding for nz weight
+            has_b_align = false;
+        }
+        bool has_accum = IsQuant(dequant_granularity) && IS_INT8 && std::is_same::value;
+        bool has_format_dequant_offset = (dequant_granularity == QuantGranularity::PER_TENSOR) && IS_INT8 && has_offset;
+        // if allgather, workspace *= rank size
+        int32_t accum_rank_size = 1;
+        if (RUN_TYPE == PPMATMUL_RUN_ALL_GATHER_MATMUL) {
+            accum_rank_size = rank_size;
+        }
+        int32_t is_moe_averaged = 0;
+        int32_t is_alltoallvc = 0;
+
+        if (num_local_tokens_per_expert == nullptr && num_global_tokens_per_local_expert == nullptr &&
+            global_tokens_per_expert_matrix == nullptr) {
+            is_moe_averaged = 1;
+        } else if (global_tokens_per_expert_matrix != nullptr) {
+            is_alltoallvc = 1;
+        } else {
+            is_alltoallvc = 0;
+        }
+        bool has_dequant_param = (dequant_granularity == QuantGranularity::PER_TOKEN || dequant_granularity == QuantGranularity::PER_TENSOR);
+        bool hasFormatDequantScale = (dequant_granularity == QuantGranularity::PER_CHANNEL);
+
+        workspace_info = GetLcalWorkspaceInfo(gm_workspace, batch_size, m, k, n, m_align, k_align, n_align,
+            TA, TB, sizeof(MmadDtype), has_a_align, has_b_align, accum_rank_size, has_accum, 0, has_dequant_param,
+            hasFormatDequantScale, is_deterministic, is_moe, is_alltoallvc, EP, local_expert_nums, maxOutputSize);
+
+        gm_a_src = reinterpret_cast<__gm__ MmadDtype *>(has_a_align ? workspace_info.gm_a_align : gm_a);
+        gm_b_src = reinterpret_cast<__gm__ MmadDtype *>(has_b_align ? workspace_info.gm_b_align : gm_b);
+        gm_accum = reinterpret_cast<__gm__ int32_t *>(workspace_info.gm_accum);
+        gm_format_dequant_offset = reinterpret_cast<__gm__ int32_t *>(has_format_dequant_offset ?
+            workspace_info.gm_dequant_param : gm_dequant_offset);
+
+        block_size = 32 / sizeof(MmadDtype);
+
+        L1_PINGPONG_BUFFER_LEN = ((m0 * k0 + cube_matrix_size - 1) / cube_matrix_size * cube_matrix_size +
+            (n0 * k0 + cube_matrix_size - 1) / cube_matrix_size * cube_matrix_size) * (IS_INT8 ? 2 : 1);
+        L0AB_PINGPONG_BUFFER_LEN = L0AB_PINGPONG_BUFFER_SIZE / sizeof(MmadDtype);
+
+        int32_t a_l1_size = m0 * k0 * sizeof(MmadDtype);
+        int32_t a_l1_size_round = DivCeil(a_l1_size, 512) * 512;
+        int32_t b_l1_size = n0 * k0 * sizeof(MmadDtype);
+        int32_t b_l1_size_round = DivCeil(b_l1_size, 512) * 512;
+        l1_base_a = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t)(IS_INT8 ? SCALE_L1_SIZE : 0));
+        l1_base_b = reinterpret_cast<__cbuf__ MmadDtype *>(a_l1_size_round * (IS_INT8 ? 2 : 1) + (uintptr_t) l1_base_a);
+
+        core_num = get_block_num();
+        core_idx = get_block_idx();
+
+        this->m_loop = m_loop;
+        this->k_loop = k_loop;
+        this->n_loop = n_loop;
+        this->core_loop = core_loop;
+        this->swizzl_count = swizzl_count;
+        this->swizzl_direct = swizzl_direct;
+        this->is_91093 = is_91093;
+        ping_flag = 1;
+        this->rank = rank;
+        this->rank_size = rank_size;
+        this->p_value = p_value;
+        this->withSerialMode = withSerialMode;
+        loop_num_per_comm = p_value * core_num;
+        this->buffer_size = buffer_size;
+
+        // 2D TP: determine this card's AG index and RS index
+        this->ag_dim = ag_dim;
+        this->rs_dim = rs_dim;
+        this->inner_dim_is_Ag = inner_dim_is_Ag;
+        if (inner_dim_is_Ag) {
+            this->ag_rank_idx = rank % ag_dim;
+            this->rs_rank_idx = rank / ag_dim;
+        } else {
+            this->ag_rank_idx = rank / rs_dim;
+            this->rs_rank_idx = rank % rs_dim;
+        }
+    }
+
+    inline __aicore__ void CalLoop(int64_t batch_idx, int64_t m_idx, int64_t n_idx, int32_t m_actual, int32_t n_actual,
+        __gm__ MmadDtype *gm_a_src_tmp) {
+        int64_t offset_a, offset_b, offset_a_next, offset_b_next;
+        int32_t m_round, n_round;
+        if (IS_INT8) {
+            // directive restrictions
+            if (TA) {
+                m_round = DivCeil(m_actual, BLOCK_SIZE_32) * BLOCK_SIZE_32;
+            } else {
+                m_round = DivCeil(m_actual, BLOCK_SIZE_16) * BLOCK_SIZE_16;
+            }
+            if (TB) {
+                n_round = DivCeil(n_actual, BLOCK_SIZE_16) * BLOCK_SIZE_16;
+            } else {
+                n_round = DivCeil(n_actual, BLOCK_SIZE_32) * BLOCK_SIZE_32;
+            }
+        } else {
+            m_round = DivCeil(m_actual, BLOCK_SIZE_16) * BLOCK_SIZE_16;
+            n_round = DivCeil(n_actual, BLOCK_SIZE_16) * BLOCK_SIZE_16;
+        }
+
+        int32_t mn_max = m_round > n_round ? m_round : n_round;
+        int32_t k_part_len = L0AB_PINGPONG_BUFFER_LEN / mn_max / block_size * block_size;
+        if (TA) {
+            if (aligned_a == 1) {
+                offset_a = batch_idx * k * m_align + m_idx * m0;
+            } else {
+                offset_a = batch_idx * k * m + m_idx * m0;
+            }
+        } else {
+            if (aligned_a == 1) {
+                offset_a = batch_idx * m * k_align + m_idx * m0 * k_align;
+            } else {
+                offset_a = batch_idx * m * k + m_idx * m0 * k;
+            }
+        }
+        if (TB) {
+            if (aligned_b == 1) {
+                offset_b = n_idx * n0 * k_align;
+            } else {
+                if (weight_nz) {
+                    offset_b = n_idx * n0 * block_size;
+                } else {
+                    offset_b = n_idx * n0 * k;
+                }
+            }
+        } else {
+            if (aligned_b == 1) {
+                offset_b = n_idx * n0;
+            } else {
+                if (weight_nz) {
+                    offset_b = n_idx * n0 * k_align16;
+                } else {
+                    offset_b = n_idx * n0;
+                }
+            }
+        }
+        int64_t dequant_param_offset = n_idx * n0;
+
+        int32_t k_actual = (k_loop == 1) ? k : k0;
+        int32_t k_round = DivCeil(k_actual, block_size) * block_size; // int8: 32, fp16: 16
+
+        auto l1_buf_a = ping_flag ? l1_base_a : l1_base_a + L1_PINGPONG_BUFFER_LEN;
+        auto l1_buf_b = ping_flag ? l1_base_b : l1_base_b + L1_PINGPONG_BUFFER_LEN;
+        auto l0a_buf = ping_flag ? l0a_base : l0a_base + L0AB_PINGPONG_BUFFER_LEN;
+        auto l0b_buf = ping_flag ? l0b_base : l0b_base + L0AB_PINGPONG_BUFFER_LEN;
+        auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;
+
+        if (IS_INT8 && has_offset) {
+            PipeBarrier();
+            IntrinsicCopyGmToL1Nd2Nz::move(
+                ((__cbuf__ int32_t *)bias_l1),
+                ((__gm__ int32_t *)gm_format_dequant_offset) + dequant_param_offset,
+                0, // sid
+                1, // ndNum
+                1, // nValue
+                n_actual, // dValue
+                0, // srcNdMatrixStride, unused
+                n, // srcDValue
+                1, // dstNzC0Stride
+                1, // dstNzNStride
+                0 // dstNzMatrixStride, unused
+            );
+            SetFlag(EVENT_ID0);
+            WaitFlag(EVENT_ID0);
+            WaitFlag(EVENT_ID1); // int8: MTE1 needs to wait for FIX
+            CopyCubfToBt(((uint64_t)bias_bt), ((__cbuf__ int32_t *)bias_l1),
+                (uint16_t)0ULL, 1, (n_actual * 4 + 63) / 64, 0, 0);
+            SetFlag(EVENT_ID1); // bias ready, mte2 can begin moving A/B or scalar
+            SetFlag(EVENT_ID1); // bias ready, mmad can begin
+            WaitFlag(EVENT_ID1); // A/B or scalar waits for the bias move from L1 to BT
+        }
+
+        auto gm_src_a = gm_a_src_tmp + offset_a;
+        auto gm_src_b = gm_b_src + offset_b;
+
+        WaitFlag(event_id);
+        // *** load matrix A to L1
+        if (m == 1 || m_actual == 1 && !TA) {
+            CopyGmToCbuf(
+                l1_buf_a,
+                gm_src_a,
+                0, // sid
+                1, // nBurst
+                k_round / block_size, // lenBurst
+                0, // srcGap
+                0, // dstGap
+                PAD_NONE // padMode
+            );
+        } else {
+            if (TA) {
+                auto src_len = m;
+                if (aligned_a == 1) {
+                    src_len = m_align;
+                }
+                CopyGmToL1Nd2zN::move(l1_buf_a, gm_src_a, k_actual, m_actual, src_len, k_round);
+            } else {
+                auto src_len = k;
+                if (aligned_a == 1) {
+                    src_len = k_align;
+                }
+                CopyGmToL1Nd2zN::move(l1_buf_a, gm_src_a, m_actual, k_actual, src_len, m_round);
+            }
+        }
+        SetFlag(event_id);
+
+        // *** load matrix B to L1
+        WaitFlag(event_id + 2);
+        if (TB) {
+            auto src_len = k;
+            if (aligned_b == 1) {
+                src_len = k_align;
+            }
+            if (weight_nz) {
+                int32_t num_col = DivCeil(k_actual, block_size);
+                CopyGmToCbuf(l1_buf_b, gm_src_b, 0, num_col, n_actual, n_align16 - n_actual, n_round - n_actual, PAD_NONE);
+            } else {
+                CopyGmToL1Nd2zN::move(l1_buf_b, gm_src_b, n_actual, k_actual, src_len, n_round);
+            }
+        } else {
+            auto src_len = n;
+            if (aligned_b == 1) {
+                src_len = n_align;
+            }
+            if (weight_nz) {
+                int32_t num_col = DivCeil(n_actual, block_size);
+                CopyGmToCbuf(l1_buf_b, gm_src_b, 0, num_col, k_actual, k_align16 - k_actual, k_round - k_actual, PAD_NONE);
+            } else {
+                CopyGmToL1Nd2zN::move(l1_buf_b, gm_src_b, k_actual, n_actual, src_len, k_round);
+            }
+        }
+        SetFlag(event_id + 2);
+
+        int mte1_mad_ping_flag = 1;
+
+        for (int64_t k_idx = 0; k_idx < k_loop; k_idx++) {
+
+            int32_t k_actual = (k_idx == (k_loop - 1)) ? (k - k_idx * k0) : k0;
+            int32_t k_round = DivCeil(k_actual, block_size) * block_size;
+            int32_t k_part_loop = DivCeil(k_actual, k_part_len);
+
+            __cbuf__ MmadDtype *l1_buf_a = ping_flag ? l1_base_a : l1_base_a + L1_PINGPONG_BUFFER_LEN;
+            __cbuf__ MmadDtype *l1_buf_b = ping_flag ? l1_base_b : l1_base_b + L1_PINGPONG_BUFFER_LEN;
+            auto event_id = ping_flag ? 
EVENT_ID0 : EVENT_ID1; + + if (k_idx < k_loop - 1) { + if (TA) { + if (aligned_a == 1) { + offset_a_next = batch_idx * k * m_align + (k_idx + 1) * k0 * m_align + m_idx * m0; + } else { + offset_a_next = batch_idx * k * m + (k_idx + 1) * k0 * m + m_idx * m0; + } + } else { + if (aligned_a == 1) { + offset_a_next = batch_idx * m * k_align + m_idx * m0 * k_align + (k_idx + 1) * k0; + } else { + offset_a_next = batch_idx * m * k + m_idx * m0 * k + (k_idx + 1) * k0; + } + } + if (TB) { + if (aligned_b == 1) { + offset_b_next = batch_idx * n * k_align + n_idx * n0 * k_align + (k_idx + 1) * k0; + } else { + if (weight_nz) { + offset_b_next = batch_idx * n * k + (k_idx + 1) * k0 * n_align16 + n_idx * n0 * block_size; + } else { + offset_b_next = batch_idx * n * k + n_idx * n0 * k + (k_idx + 1) * k0; + } + } + } else { + if (aligned_b == 1) { + offset_b_next = batch_idx * k * n_align + (k_idx + 1) * k0 * n_align + n_idx * n0; + } else { + if (weight_nz) { + offset_b_next = batch_idx * k * n + (k_idx + 1) * k0 * block_size + n_idx * n0 * k_align16; + } else { + offset_b_next = batch_idx * k * n + (k_idx + 1) * k0 * n + n_idx * n0; + } + } + } + + int32_t k_actual_next = ((k_idx + 1) == (k_loop - 1)) ? (k - (k_idx + 1) * k0) : k0; + int32_t k_round_next = DivCeil(k_actual_next, block_size) * block_size; + + __cbuf__ MmadDtype *l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a + L1_PINGPONG_BUFFER_LEN; + __cbuf__ MmadDtype *l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b + L1_PINGPONG_BUFFER_LEN; + auto event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1; + + auto gm_src_a = gm_a_src_tmp + offset_a_next; + auto gm_src_b = gm_b_src + offset_b_next; + + WaitFlag(event_id_next); + // *** load matrix A to L1 + if (m == 1 || m_actual == 1 && !TA) { + CopyGmToCbuf( + l1_buf_a_next, + gm_src_a, + 0, // sid + 1, // nBurst + k_round_next / block_size, // lenBurst + 0, // srcGap + 0, // dstGap + PAD_NONE // padMode + ); + } else { + if (TA) { + auto src_len = m; + if (aligned_a == 1) { + src_len = m_align; + } + CopyGmToL1Nd2zN::move( + l1_buf_a_next, gm_src_a, k_actual_next, m_actual, src_len, k_round_next); + } else { + auto src_len = k; + if (aligned_a == 1) { + src_len = k_align; + } + CopyGmToL1Nd2zN::move( + l1_buf_a_next, gm_src_a, m_actual, k_actual_next, src_len, m_round); + } + } + SetFlag(event_id_next); + + // *** load matrix B to L1 + WaitFlag(event_id_next + 2); + if (TB) { + auto src_len = k; + if (aligned_b == 1) { + src_len = k_align; + } + if (weight_nz) { + int32_t num_col = DivCeil(k_actual_next, block_size); + CopyGmToCbuf(l1_buf_b_next, gm_src_b, 0, num_col, n_actual, n_align16 - n_actual, n_round - n_actual, PAD_NONE); + } else { + CopyGmToL1Nd2zN::move(l1_buf_b_next, gm_src_b, n_actual, k_actual_next, src_len, n_round); + } + } else { + auto src_len = n; + if (aligned_b == 1) { + src_len = n_align; + } + if (weight_nz) { + int32_t num_col = DivCeil(n_actual, block_size); + CopyGmToCbuf(l1_buf_b_next, gm_src_b, 0, num_col, k_actual_next, k_align16 - k_actual_next, k_round_next - k_actual_next, PAD_NONE); + } else { + CopyGmToL1Nd2zN::move(l1_buf_b_next, gm_src_b, k_actual_next, n_actual, src_len, k_round_next); + } + } + SetFlag(event_id_next + 2); + } + + for (int k_part_idx = 0; k_part_idx < k_part_loop; k_part_idx++) { + int32_t k0_round = (k_part_idx < k_part_loop - 1) ? + k_part_len : k_round - k_part_idx * k_part_len; + int32_t k0_actual = (k_part_idx < k_part_loop - 1) ? 
+                    k_part_len : k_actual - k_part_idx * k_part_len;
+
+                auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1;
+                auto l0a_buf = l0a_base + (1 - mte1_mad_ping_flag) * L0AB_PINGPONG_BUFFER_LEN;
+                auto l0b_buf = l0b_base + (1 - mte1_mad_ping_flag) * L0AB_PINGPONG_BUFFER_LEN;
+
+                // *** load matrix A from L1 to L0A
+                if (k_part_idx == 0) {
+                    WaitFlag(event_id);
+                }
+                WaitFlag(mte1_mad_event_id);
+                if (m == 1 || m_actual == 1 && !TA) {
+                    LoadCbufToCa(
+                        l0a_buf,
+                        l1_buf_a + k_part_idx * k_part_len,
+                        0, // baseIdx
+                        DivCeil(k0_round, cube_matrix_size), // repeat
+                        1, // srcStride
+                        0, // dstStride
+                        0, // sid
+                        false, // transpose
+                        inc // addr_cal_mode_t
+                    );
+                } else {
+                    if (TA) {
+                        if (IS_INT8) {
+                            for (int i = 0; i < m_round / BLOCK_SIZE_32; i++) {
+                                LoadCbufToCaTranspose(
+                                    l0a_buf + i * k0_round * BLOCK_SIZE_32,
+                                    l1_buf_a + k_part_idx * k_part_len * BLOCK_SIZE_32 +
+                                        i * k_round * BLOCK_SIZE_32,
+                                    0, // baseIdx
+                                    k0_round / BLOCK_SIZE_32, // repeat
+                                    1, // srcStride
+                                    0, // dstStride
+                                    0, // addrmode
+                                    k0_round / BLOCK_SIZE_32 - 1 // dstFracStride
+                                );
+                            }
+                        } else {
+                            for (int i = 0; i < m_round / BLOCK_SIZE_16; i++) {
+                                LoadCbufToCa(
+                                    l0a_buf + i * k0_round * BLOCK_SIZE_16,
+                                    l1_buf_a + k_part_idx * k_part_len * BLOCK_SIZE_16 +
+                                        i * k_round * BLOCK_SIZE_16,
+                                    0, // baseIdx
+                                    k0_round / BLOCK_SIZE_16, // repeat
+                                    1, // srcStride
+                                    0, // dstStride
+                                    0, // sid
+                                    true, // transpose
+                                    inc // addr_cal_mode_t
+                                );
+                            }
+                        }
+                    } else {
+                        for (int32_t i = 0; i < k0_round / block_size; i++) {
+                            LoadCbufToCa(
+                                l0a_buf + i * cube_matrix_size,
+                                l1_buf_a + k_part_idx * k_part_len * m_round +
+                                    i * m_round * block_size,
+                                0, // baseIdx
+                                m_round / BLOCK_SIZE_16, // repeat
+                                1, // srcStride
+                                k0_round / block_size - 1, // dstStride
+                                0, // sid
+                                false, // transpose
+                                inc // addr_cal_mode_t
+                            );
+                        }
+                    }
+                }
+                if (k_part_idx == k_part_loop - 1) {
+                    SetFlag(event_id);
+                }
+
+                // *** load matrix B from L1 to L0B
+                if (k_part_idx == 0) {
+                    WaitFlag(event_id + 2);
+                }
+                if (TB) {
+                    LoadCbufToCb(
+                        l0b_buf,
+                        l1_buf_b + k_part_idx * k_part_len * n_round,
+                        0, // baseIdx
+                        k0_round * n_round / cube_matrix_size, // repeat
+                        1, // srcStride
+                        0, // dstStride
+                        0, // sid
+                        false, // transpose
+                        inc // addr_cal_mode_t
+                    );
+                } else {
+                    if (IS_INT8) {
+                        for (int32_t i = 0; i < k0_round / BLOCK_SIZE_32; i++) {
+                            LoadCbufToCbTranspose(
+                                l0b_buf + i * ((n_actual + 15) / 16 * 16) * BLOCK_SIZE_32,
+                                l1_buf_b + (k_part_idx * k_part_len + i * BLOCK_SIZE_32) * BLOCK_SIZE_32,
+                                0, // baseIdx
+                                n_round / BLOCK_SIZE_32, // repeat
+                                k_round / BLOCK_SIZE_32, // srcStride
+                                1, // dstStride
+                                0, // addrmode
+                                0 // dstFracStride
+                            );
+                        }
+                    } else {
+                        for (int32_t i = 0; i < k0_round / BLOCK_SIZE_16; i++) {
+                            LoadCbufToCb(
+                                l0b_buf + i * n_round * BLOCK_SIZE_16,
+                                l1_buf_b + (k_part_idx * k_part_len + i * BLOCK_SIZE_16) * BLOCK_SIZE_16,
+                                0, // baseIdx
+                                n_round / BLOCK_SIZE_16, // repeat
+                                k_round / BLOCK_SIZE_16, // srcStride
+                                0, // dstStride
+                                0, // sid
+                                true, // transpose
+                                inc // addr_cal_mode_t
+                            );
+                        }
+                    }
+                }
+                if (k_part_idx == k_part_loop - 1) {
+                    SetFlag(event_id + 2);
+                }
+
+                SetFlag(mte1_mad_event_id);
+                WaitFlag(mte1_mad_event_id);
+
+                bool init_c = (k_idx == 0 && k_part_idx == 0);
+                if (init_c) {
+                    WaitFlag(EVENT_ID0);
+                }
+
+                if (IS_INT8 && has_offset) {
+                    if (init_c) {
+                        WaitFlag(EVENT_ID1); // wait for the bias move from L1 to BT
+                    }
+                    PipeBarrier();
+                    if (m != 1 && m_actual == 1 && TA) {
+                        mad((__cc__ int32_t *)l0c_buf,
+                            (__ca__ int8_t *)l0a_buf,
+                            (__cb__ int8_t *)l0b_buf,
+                            ((uint64_t)bias_bt),
+
16, // m + k0_actual, // k + n_actual, // n + 0, // unitFlag + 0, // kDirectionAlign + init_c, // cmatrixSource add C from BT + 0 // cmatrixInitVal + ); + } else { + mad((__cc__ int32_t *)l0c_buf, + (__ca__ int8_t *)l0a_buf, + (__cb__ int8_t *)l0b_buf, + ((uint64_t)bias_bt), + m_actual, // m + k0_actual, // k + n_actual, // n + 0, // unitFlag + 0, // kDirectionAlign + init_c, // cmatrixSource add C from BT + 0 // cmatrixInitVal + ); + } + } else { + PipeBarrier(); + if (m != 1 && m_actual == 1 && TA) { + mad(l0c_buf, + l0a_buf, + l0b_buf, + 16, // m + k0_actual, // k + n_actual, // n + 0, // unitFlag + 0, // kDirectionAlign + 0, // cmatrixSource + init_c // cmatrixInitVal + ); + } else { + mad(l0c_buf, + l0a_buf, + l0b_buf, + m_actual, // m + k0_actual, // k + n_actual, // n + 0, // unitFlag + 0, // kDirectionAlign + 0, // cmatrixSource + init_c // cmatrixInitVal + ); + } + } + PipeBarrier(); + SetFlag(mte1_mad_event_id); + + mte1_mad_ping_flag = 1 - mte1_mad_ping_flag; + } + ping_flag = 1 - ping_flag; + } + + + if (IS_INT8 && std::is_same::value && (dequant_granularity == QuantGranularity::PER_CHANNEL || + dequant_granularity == QuantGranularity::PER_TOKEN)) { + WaitFlag(EVENT_ID0); + PipeBarrier(); + CopyGmToCbuf( + scale_l1, + gm_dequant_scale + dequant_param_offset, + 0, + 1, + (n_actual * sizeof(int64_t) + 31) / 32, + 0, + 0, + PAD_NONE + ); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + + copy_cbuf_to_fbuf( + scale_FB, + scale_l1, + 1, + (n_actual * sizeof(int64_t) + 127) / 128, + 0, + 0 + ); + PipeBarrier(); + } + } + + inline __aicore__ void MoveL0CToGM(__gm__ OutDtype *gm_dst, int64_t offset_c, int32_t m_actual, int32_t n_actual, int32_t src_stride, int32_t dst_stride) { + #if (__CCE_AICORE__ == 220) + FixpipeParamsV220 FixpipeParams( + n_actual, // nSize = nSizeIn; + m_actual, // mSize = mSizeIn; + src_stride, // srcStride = srcStrideIn; + dst_stride, // dstStride = dstStrideIn; + false // reluEn = reluEnIn; + ); + #elif (defined(__DAV_C310__)) + FixpipeParamsC310 FixpipeParams( + n_actual, // nSize = nSizeIn; + m_actual, // mSize = mSizeIn; + src_stride, // srcStride = srcStrideIn; + dst_stride // dstStride = dstStrideIn; + ); + #endif + uint64_t src_addr = reinterpret_cast(l0c_buf); + LocalTensor srcTensor = CreateLocalTensor + (reinterpret_cast(l0c_buf), static_cast(TPosition::CO1)); + GlobalTensor dstTensor = CreateGlobalTensor(gm_dst + offset_c); + + if (IS_INT8) { + if constexpr (std::is_same::value) { + if (dequant_granularity == QuantGranularity::PER_CHANNEL || dequant_granularity == QuantGranularity::PER_TOKEN) { + SetFpc(scale_FB); + FixpipeParams.quantPre = VDEQF16; + Fixpipe(dstTensor, srcTensor, FixpipeParams); + SetFlag(EVENT_ID0); + } else if (dequant_granularity == QuantGranularity::PER_TENSOR) { + FixpipeParams.quantPre = DEQF16; + FixpipeParams.deqScalar = gm_dequant_scale[0]; + Fixpipe(dstTensor, srcTensor, FixpipeParams); + } + } else if constexpr (std::is_same::value) { + GlobalTensor dstAccum = CreateGlobalTensor(gm_accum + offset_c); + Fixpipe(dstAccum, srcTensor, FixpipeParams); + } + } else { + if constexpr (std::is_same::value) { + FixpipeParams.quantPre = F322BF16; + Fixpipe(dstTensor, srcTensor, FixpipeParams); + } else { + FixpipeParams.quantPre = F322F16; + Fixpipe(dstTensor, srcTensor, FixpipeParams); + } + } + SetFlag(EVENT_ID0); + if (IS_INT8 && has_offset) { + SetFlag(EVENT_ID1); + } + } + + inline __aicore__ void InitFlags() { + WaitEvent(AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID); + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + 
SetFlag(EVENT_ID3); + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID1); // + } + + inline __aicore__ void Endflags() { + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + WaitFlag(EVENT_ID3); + } + + inline __aicore__ void RunPureMatmul() { + + InitFlags(); + for (int32_t loop_idx = 0; loop_idx < core_loop; loop_idx++) { + if (loop_idx % core_num != core_idx) { + continue; + } + + int64_t batch_idx = loop_idx / (m_loop * n_loop); + int64_t m_idx, n_idx; + GetBlockIdx(loop_idx, m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + CalLoop(batch_idx, m_idx, n_idx, m_actual, n_actual, gm_a_src); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + int64_t offset_c = batch_idx * m * n + m_idx * m0 * n + n_idx * n0; + // copy from L0C to gm + MoveL0CToGM(gm_c, offset_c, m_actual, n_actual, (m_actual + 15) / 16 * 16, n); + } + Endflags(); + PipeBarrier(); + + FFTSCrossCoreSync(0, AIC_FINISH_MATMUL_FLAG_ID); + WaitEvent(AIC_FINISH_MATMUL_FLAG_ID); + + FFTSCrossCoreSync(2, AIV_WAIT_AIC_FINISH_MATMUL_FLAG_ID); + PipeBarrier(); + } + + inline __aicore__ void RunMatmulAllReduce() { + InitFlags(); + int32_t comm_count = DivCeil(core_loop, loop_num_per_comm); + int32_t pipe_depth = is_91093 ? BLOCK_COUNT_4 : MAX_BLOCK_COUNT; + for (int32_t cal_idx = 0; cal_idx < comm_count; cal_idx++) { + int32_t loop_idx = cal_idx * core_num + core_idx; + int32_t flag_idx = cal_idx % pipe_depth; + if (cal_idx >= pipe_depth) { + WaitEvent(flag_idx); + } + int32_t actual_loop_num = loop_num_per_comm; + if (cal_idx == comm_count - 1){ + actual_loop_num = core_loop - cal_idx * loop_num_per_comm; + } + for (int32_t p = 0; p < p_value; p++) { + int loop_idx = cal_idx * p_value * core_num + p * core_num + core_idx; + if (loop_idx >= core_loop) + break; + int64_t batch_idx = loop_idx / (m_loop * n_loop); + int64_t m_idx, n_idx; + GetBlockIdx(loop_idx, m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + int32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? 
(n - n_idx * n0) : n0;
+                CalLoop(batch_idx, m_idx, n_idx, m_actual, n_actual, gm_a_src);
+
+                SetFlag(EVENT_ID0);
+                WaitFlag(EVENT_ID0);
+
+                int64_t offset_c;
+                int32_t n_stride;
+                // if constexpr (IS_INT8 && std::is_same::value) {
+                //     offset_c = batch_idx * m * n + m_idx * m0 * n + n_idx * n0;
+                //     n_stride = n;
+                // } else {
+                offset_c = flag_idx * m0 * loop_num_per_comm * n0 +
+                    (loop_idx % loop_num_per_comm) * m0 * n0;
+                n_stride = n0;
+                //}
+                MoveL0CToGM(gm_peer_mem, offset_c, m_actual, n_actual, (m_actual + 15) / 16 * 16, n_stride);
+            }
+            FFTSCrossCoreSync(2, flag_idx);
+        }
+        Endflags();
+        PipeBarrier();
+    }
+
+    inline __aicore__ void RunMatmulReduceScatter() {
+        int32_t tail_m = (m / rank_size) % m0;
+        m_loop = m / rank_size / m0;
+        if (tail_m) {
+            m_loop += 1;
+        }
+        m_loop *= rank_size;
+        core_loop = batch_size * m_loop * n_loop;
+
+        InitFlags();
+
+        int32_t comm_num = DivCeil(core_loop, loop_num_per_comm);
+        // core_loop = batch_size * m_loop * n_loop = p_value * core_num * comm_num
+        int32_t m_loop_per_rank = m_loop / rank_size;
+        for (int32_t comm_idx = 0; comm_idx < comm_num; comm_idx++) {
+            int cur_p_value = p_value;
+            int32_t actual_loop_num = loop_num_per_comm;
+            int32_t flag_idx = is_91093 ? comm_idx % BLOCK_COUNT_3 : comm_idx % MAX_BLOCK_COUNT;
+            if (comm_idx == comm_num - 1) {
+                actual_loop_num = core_loop - comm_idx * loop_num_per_comm;
+            }
+            WaitEvent(flag_idx);
+            // core_num * p_value
+            for (int32_t p = 0; p < p_value; p++) { // each core computes p_value tiles per communication step
+                int loop_idx = comm_idx * p_value * core_num + p * core_num + core_idx;
+                if (loop_idx >= core_loop)
+                    break;
+                int64_t batch_idx = loop_idx / (m_loop * n_loop);
+                int32_t in_batch_idx = loop_idx % (m_loop * n_loop);
+                int64_t rank_idx = in_batch_idx % rank_size;
+                int32_t in_rank_idx = in_batch_idx / rank_size;
+
+                int64_t m_idx, n_idx;
+                GetBlockIdx(in_rank_idx, m_loop_per_rank, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx);
+
+                int32_t m_actual = (m_idx == (m_loop_per_rank - 1)) ? (m / rank_size - m_idx * m0) : m0;
+                int32_t n_actual = (n_idx == (n_loop - 1)) ? 
(n - n_idx * n0) : n0; + __gm__ MmadDtype *gm_a_rank_st; + if (TA) { + gm_a_rank_st = gm_a_src + rank_idx * m / rank_size; + } else { + gm_a_rank_st = gm_a_src + rank_idx * m / rank_size * k_align; + } + CalLoop(batch_idx, m_idx, n_idx, m_actual, n_actual, gm_a_rank_st); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + int64_t offset_c; + int32_t dst_stride; + __gm__ OutDtype *gm_dst = nullptr; + if (rank_idx == rank && !(IS_INT8 && (dequant_granularity == QuantGranularity::PER_TOKEN|| std::is_same::value))) { + offset_c = batch_idx * m * n / rank_size + m_idx * m0 * n + n_idx * n0; + gm_dst = gm_c; + dst_stride = n; + } else { + int64_t rank_offset_c = (loop_idx % rank_size) * (actual_loop_num / rank_size) * m0 * n0; + offset_c = flag_idx * m0 * loop_num_per_comm * n0 + + rank_offset_c + + ((loop_idx % loop_num_per_comm) / rank_size) * m0 * n0; + gm_dst = gm_peer_mem; + dst_stride = n0; + } + // copy from L0C to gm + MoveL0CToGM(gm_dst, offset_c, m_actual, n_actual, (m_actual + 15) / 16 * 16, dst_stride); + } + FFTSCrossCoreSync(2, flag_idx); + } + + Endflags(); + PipeBarrier(); + } + + inline __aicore__ void DoLocalMatmul() { + for (int32_t loop_idx = 0; loop_idx < core_loop; loop_idx++) { + if (loop_idx % core_num != core_idx) { + continue; + } + int64_t batch_idx = loop_idx / (m_loop * n_loop); + + int64_t m_idx, n_idx; + GetBlockIdx(loop_idx, m_loop, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + + int32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0; + + CalLoop(batch_idx, m_idx, n_idx, m_actual, n_actual, gm_a_src); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + int64_t offset_c = batch_idx * m * n * rank_size + (rank * m + m_idx * m0) * n + n_idx * n0; + // copy from L0C to gm + MoveL0CToGM(gm_c, offset_c, m_actual, n_actual, (m_actual + 15) / 16 * 16, n); + } + } + + inline __aicore__ void RunAllGatherMatmul() { + InitFlags(); + // rank + // m_loop * n_loop + DoLocalMatmul(); + + int64_t gm_a_pingpong_size = m0 * k_align * p_value * rank_size; + int32_t comm_count = DivCeil(batch_size * m_loop, p_value); + for (int32_t comm_idx = 0; comm_idx < comm_count; comm_idx++) { + uint64_t flag_id = comm_idx % MAX_BLOCK_COUNT; + if (is_91093) { + flag_id = comm_idx % 3; + } + int32_t actual_p_value = p_value; + if (comm_idx == comm_count - 1) { + actual_p_value = m_loop - comm_idx * p_value; + } + WaitEvent(flag_id); + + // other_rank, p_value * n_loop * (rank_size - 1) + int32_t actual_loop_num_in_other_rank = actual_p_value * (rank_size - 1) * n_loop; + for (int32_t loop_offset = 0; loop_offset < actual_loop_num_in_other_rank; loop_offset++) { + int32_t loop_idx = core_loop + comm_idx * p_value * n_loop * (rank_size - 1) + loop_offset; + if (loop_idx % core_num != core_idx) { + continue; + } + int64_t batch_idx = loop_idx / (m_loop * n_loop * rank_size); + + int64_t m_idx, n_idx; + GetBlockIdx(loop_offset, actual_p_value * (rank_size - 1), n_loop, swizzl_direct, swizzl_count, m_idx, n_idx); + + int32_t m_idx_in_rank = m_idx % actual_p_value; + int64_t m_idx_in_c = comm_idx * p_value + m_idx_in_rank; + int32_t m_actual = (m_idx_in_c == (m_loop - 1)) ? (m - m_idx_in_c * m0) : m0; + int32_t n_actual = (n_idx == (n_loop - 1)) ? 
(n - n_idx * n0) : n0;
+                int64_t rank_idx = m_idx / actual_p_value;
+                if (rank_idx >= rank) {
+                    rank_idx += 1;
+                }
+                __gm__ MmadDtype *gm_peer_mem_st = reinterpret_cast<__gm__ MmadDtype *>(gm_peer_mem) +
+                    flag_id * gm_a_pingpong_size +
+                    rank_idx * p_value * m0 * k_align;
+                CalLoop(batch_idx, m_idx_in_rank, n_idx, m_actual, n_actual, gm_peer_mem_st);
+                SetFlag(EVENT_ID0);
+                WaitFlag(EVENT_ID0);
+
+                int64_t offset_c = batch_idx * m * n * rank_size + (rank_idx * m + m_idx_in_c * m0) * n + n_idx * n0;
+                // copy from L0C to gm
+                MoveL0CToGM(gm_c, offset_c, m_actual, n_actual, (m_actual + 15) / 16 * 16, n);
+            }
+            FFTSCrossCoreSync(2, flag_id);
+        }
+        Endflags();
+        PipeBarrier();
+    }
+
+
+    // p_value means different things in RS and AG: in RS, each core computes p_value tiles and then
+    // communicates once; in AG, one compute step runs after gathering p_value row-blocks from each other rank.
+    // In 2D TP, p_value has the same meaning as in AG.
+    inline __aicore__ void RunAllGatherMatmulReduceScatter() {
+
+        InitFlags();
+        int32_t twod_big_dim = ag_dim > rs_dim ? ag_dim : rs_dim;
+        int64_t gm_a_pingpong_size = m0 * k_align * p_value * twod_big_dim;
+        int64_t gm_c_pingpong_size = p_value * twod_big_dim * n_loop * m0 * n0;
+        int32_t m_loop_per_bigdim = DivCeil(m_loop * ag_dim, twod_big_dim);
+        int64_t m_per_bigdim = m * ag_dim / twod_big_dim;
+        int32_t comm_count = DivCeil(batch_size * m_loop_per_bigdim, p_value);
+        int32_t loop_num_per_cal = p_value * n_loop * twod_big_dim;
+        int32_t ag_part_dim = twod_big_dim / ag_dim;
+        int32_t rs_part_dim = twod_big_dim / rs_dim;
+        for (int32_t comm_idx = 0; comm_idx < comm_count; comm_idx++) {
+            uint64_t flag_id = comm_idx % MAX_BLOCK_COUNT;
+            int32_t actual_p_value = p_value;
+            if (comm_idx == comm_count - 1) {
+                actual_p_value = m_loop_per_bigdim - comm_idx * p_value;
+            }
+            WaitEvent(flag_id);
+
+            int32_t actual_loop_num = actual_p_value * twod_big_dim * n_loop;
+            int32_t core_loop_num = DivCeil(actual_p_value * twod_big_dim * n_loop, core_num);
+            for (int32_t core_loop_idx = 0; core_loop_idx < core_loop_num; core_loop_idx++) {
+                int32_t loop_offset = core_loop_idx * core_num + core_idx;
+                if (loop_offset >= actual_loop_num) {
+                    continue;
+                }
+                int32_t loop_idx = comm_idx * loop_num_per_cal + loop_offset;
+                int64_t batch_idx = loop_idx / (m_loop * n_loop * twod_big_dim);
+
+                int64_t m_idx, n_idx;
+                GetBlockIdx(loop_offset, actual_p_value * twod_big_dim, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx);
+
+                int32_t m_idx_in_rank = m_idx % actual_p_value;
+                int64_t m_idx_in_c = comm_idx * p_value + m_idx_in_rank;
+                int32_t m_actual = (m_idx_in_c == (m_loop_per_bigdim - 1)) ? (m_per_bigdim - m_idx_in_c * m0) : m0;
+                int32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
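+                // Worked example with illustrative values: ag_dim = 4 and rs_dim = 2 give
+                // twod_big_dim = max(4, 2) = 4, ag_part_dim = 4 / 4 = 1,
+                // rs_part_dim = 4 / 2 = 2 and m_per_bigdim = m * 4 / 4 = m.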
+                int64_t bigdim_idx = m_idx / actual_p_value;
+                // if bigdim=rs: ag_src_idx = bigdim_idx / (bigdim/agdim), ag_part_idx = bigdim_idx % (bigdim/agdim)
+                // When rs_dim > ag_dim, AG pulls ag_part_dim blocks (p_value rows each) from every card;
+                // the current core pulls the ag_part_idx-th block from card ag_src_idx.
+                // When rs_dim < ag_dim [...]
+                [...]
+                    gm_mem_st = reinterpret_cast<__gm__ MmadDtype *>(gm_peer_mem) +
+                        (comm_idx % MAX_BLOCK_COUNT) * gm_a_pingpong_size +
+                        bigdim_idx * p_value * m0 * k_align;
+                } else {
+                    gm_mem_st = gm_a_src + (comm_idx * p_value) * m0 * k_align + ag_part_idx * m_per_bigdim * k_align;
+                    // comm_idx * p_value selects the position inside a block; ag_part_idx * m_per_bigdim selects which block
+                }
+
+                CalLoop(batch_idx, m_idx_in_rank, n_idx, m_actual, n_actual, gm_mem_st);
+                SetFlag(EVENT_ID0);
+                WaitFlag(EVENT_ID0);
+
+                int64_t offset_c;
+                int64_t dst_stride;
+                __gm__ OutDtype *gm_dst = nullptr;
+
+                // the final size on each card is m * ag_dim / rs_dim
+                if (rs_dst_idx != rs_rank_idx) { // RS still needed: write to this card's shared memory
+                    offset_c = gm_c_pingpong_size * (comm_idx % MAX_BLOCK_COUNT) +
+                        (m_idx * n_loop + n_idx) * m0 * n0 +
+                        LCAL_2DTP_C_OFFSET;
+                    gm_dst = gm_peer_mem;
+                    dst_stride = n0;
+                } else { // no RS needed: write to this card's gm_c; the batch size may be wrong here
+                    offset_c = rs_part_idx * m_per_bigdim * n +
+                        m_idx_in_c * m0 * n +
+                        n_idx * n0;
+                    gm_dst = gm_c;
+                    dst_stride = n;
+                }
+                // copy from L0C to gm
+                MoveL0CToGM(gm_dst, offset_c, m_actual, n_actual, (m_actual + 15) / 16 * 16, dst_stride);
+            }
+            FFTSCrossCoreSync(2, flag_id);
+        }
+
+        Endflags();
+        PipeBarrier();
+    }
+
+    inline __aicore__ void Run() {
+        if (RUN_TYPE == PPMATMUL_RUN_PURE_MATMUL) {
+            RunPureMatmul();
+        } else if (RUN_TYPE == PPMATMUL_RUN_MATMUL_ALLREDUCE) {
+            if (withSerialMode) {
+                gm_c = gm_peer_mem;
+                RunPureMatmul();
+            } else {
+                RunMatmulAllReduce();
+            }
+        } else if (RUN_TYPE == PPMATMUL_RUN_MATMUL_REDUCE_SCATTER) {
+            RunMatmulReduceScatter();
+        } else if (RUN_TYPE == PPMATMUL_RUN_ALL_GATHER_MATMUL) {
+            RunAllGatherMatmul();
+        } else if (RUN_TYPE == PPMATMUL_RUN_ALL_GATHER_MATMUL_REDUCE_SCATTER) {
+            RunAllGatherMatmulReduceScatter();
+        }
+    }
+
+protected:
+    __gm__ MmadDtype *gm_a_src{nullptr};
+    __gm__ MmadDtype *gm_b_src{nullptr};
+
+    __gm__ OutDtype *gm_c{nullptr};
+    __gm__ OutDtype *gm_peer_mem{nullptr};
+    __gm__ int64_t *gm_dequant_scale{nullptr};
+    __gm__ int32_t *gm_format_dequant_offset{nullptr};
+    __gm__ int32_t *gm_accum{nullptr};
+
+    __cbuf__ MmadDtype *l1_base_a = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t) SCALE_L1_SIZE);
+    __cbuf__ MmadDtype *l1_base_b = reinterpret_cast<__cbuf__ MmadDtype *>((uintptr_t) (128 * 1024));
+
+    __ca__ MmadDtype *l0a_base = reinterpret_cast<__ca__ MmadDtype *>((uintptr_t) 0);
+    __cb__ MmadDtype *l0b_base = reinterpret_cast<__cb__ MmadDtype *>((uintptr_t) 0);
+
+    __cc__ T_ACCUM *l0c_buf = reinterpret_cast<__cc__ T_ACCUM *>((uintptr_t) 0);
+
+    __cbuf__ int64_t *scale_l1 = reinterpret_cast<__cbuf__ int64_t *>((uintptr_t) 0);
+    __fbuf__ int64_t *scale_FB = (__fbuf__ int64_t *)(0);
+
+    __cbuf__ int32_t *bias_l1 = reinterpret_cast<__cbuf__ int32_t *>((uintptr_t)0);
+    uint16_t bias_bt = 0;
+    bool has_offset{false};
+    LcalWorkspaceInfo workspace_info;
+
+    int32_t core_num;
+
+    int32_t batch_size;
+    int32_t m;
+    int32_t k;
+    int32_t n;
+    int32_t m_align;
+    int64_t k_align;
+    int32_t n_align;
+    int32_t k_align16;
+    int32_t n_align16;
+    int32_t m0;
+    int32_t k0;
+    int32_t n0;
+
+    int32_t m_loop;
+    int32_t n_loop;
+    int32_t k_loop;
+    int32_t core_loop;
+    int32_t core_idx;
+    int32_t ping_flag;
+    int32_t block_size;
+    int32_t cube_matrix_size;
+
+    int32_t aligned_a;
+    int32_t aligned_b;
+
+    int32_t swizzl_count;
+    int32_t swizzl_direct;
+
+    int32_t L1_PINGPONG_BUFFER_LEN;
+    int32_t L0AB_PINGPONG_BUFFER_LEN;
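+    // L1/L0 layout implied by the pointers above: L1 keeps the dequant scale in
+    // [0, SCALE_L1_SIZE) (scale_l1), matrix A tiles from SCALE_L1_SIZE (l1_base_a) and
+    // matrix B tiles from 128 KB (l1_base_b); L0A, L0B and L0C each start at offset 0
+    // of their own dedicated memories.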
+    int32_t rank;
+    int32_t rank_size;
+    int32_t p_value;
+    int32_t loop_num_per_comm;
+
+    int32_t withSerialMode;
+    int32_t buffer_size;
+
+    // AG+MM+RS
+    int32_t ag_dim;
+    int32_t rs_dim;
+    bool inner_dim_is_Ag{false};
+    int32_t ag_rank_idx;
+    int32_t rs_rank_idx;
+    bool weight_nz{false};
+    // sio
+    bool is_91093{false};
+    QuantGranularity dequant_granularity;
+
+};
+
+#elif __DAV_C220_VEC__
+
+#include "coc_preprocessor.cce"
+#include "coc_add_bias_runner.cce"
+#include "coc_dequant_runner.cce"
+#include "tiling_args.h"
+
+template <typename T>
+inline __aicore__ void CocPureMatmulAiv(COC_ARGS_FUN(T))
+{
+    SetAtomicNone();
+    SetMaskNorm();
+    SetSyncBaseAddr((uint64_t)ffts_addr);
+    SetVectorMask((uint64_t)-1, (uint64_t)-1);
+    // get tiling args
+    auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm);
+    auto cocTilingData = &para->cocTilingData;
+    auto quantInfo = &para->quantInfo;
+    auto moeInfo = &para->moeInfo;
+
+    GlobalTensor<int32_t> commArgsGm;
+    commArgsGm.SetGlobalBuffer(reinterpret_cast<__gm__ int *>(coc_comm_args), 2);
+    uint32_t extraFlag = commArgsGm.GetValue(4);
+    bool is_deterministic = (extraFlag & ExtraFlag::DETERMINISTIC) != 0;
+
+    int32_t batch_size = cocTilingData->batchSize;
+    int32_t m = cocTilingData->m;
+    int32_t k = cocTilingData->k;
+    int32_t n = cocTilingData->n;
+
+    int32_t m0 = cocTilingData->m0;
+    int32_t k0 = cocTilingData->k0;
+    int32_t n0 = cocTilingData->n0;
+
+    int32_t m_loop = cocTilingData->mLoop;
+    int32_t k_loop = cocTilingData->kLoop;
+    int32_t n_loop = cocTilingData->nLoop;
+
+    int32_t core_loop = cocTilingData->coreLoop;
+    int32_t swizzl_count = cocTilingData->swizzlCount;
+    int32_t tiling_key = cocTilingData->tilingKey;
+    int32_t rank = cocTilingData->rank;
+    int32_t rank_size = cocTilingData->rankSize;
+    int32_t p_value = cocTilingData->pValue;
+    QuantGranularity dequant_granularity = static_cast<QuantGranularity>(quantInfo->dequantGranularity);
+    int32_t dequant_group_size = quantInfo->dequantGroupSize;
+    QuantGranularity quant_granularity = static_cast<QuantGranularity>(quantInfo->quantGranularity);
+    int32_t quant_group_size = quantInfo->quantGroupSize;
+    bool weight_nz = para->weightNz;
+    bool swizzl_direct = (tiling_key & SWIZZL_MASK) ? true : false;
+    bool trans_a = (tiling_key & TRANS_A_MASK) ? true : false;
+    bool trans_b = (tiling_key & TRANS_B_MASK) ? true : false;
+    bool have_bias = (tiling_key & BIAS_MASK) ? true : false;
+    bool is_int8 = (tiling_key & INT8_MASK) ?
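+    // tiling_key packs independent boolean flags; each mask above is tested separately,
+    // so e.g. a key with only TRANS_B_MASK and INT8_MASK set yields trans_b = true and
+    // is_int8 = true while swizzl_direct, trans_a and have_bias stay false.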
true : false; + + int32_t local_expert_nums = moeInfo->local_expert_nums; + int32_t EP = moeInfo->EP; + int32_t TP = moeInfo->TP; + int32_t is_moe_averaged = 0; + int32_t is_alltoallvc = 0; + int32_t is_moe = moeInfo->isMoe; + + + int32_t m_align, k_align, n_align; + if (is_int8) { + m_align = Block512B::AlignUp(m); + k_align = Block512B::AlignUp(k); + n_align = Block512B::AlignUp(n); + } else { + m_align = Block512B::AlignUp(m); + k_align = Block512B::AlignUp(k); + n_align = Block512B::AlignUp(n); + } + int32_t aligned_a, aligned_b; + AlignJudge(trans_a, trans_b, m, k, n, m_align, k_align, n_align, aligned_a, aligned_b); + + bool has_a_align = IsQuant(quant_granularity) || aligned_a; + bool has_b_align = IsQuant(dequant_granularity) && !is_int8 || aligned_b; + bool has_accum = IsQuant(dequant_granularity) && is_int8 && std::is_same::value; + bool has_dequant_param = (dequant_granularity == QuantGranularity::PER_TOKEN || dequant_granularity == QuantGranularity::PER_TENSOR); + bool hasFormatDequantScale = (dequant_granularity == QuantGranularity::PER_CHANNEL); + if (weight_nz) { + aligned_b = 0; + has_b_align = false; + } + auto workspace_info = GetLcalWorkspaceInfo(gm_workspace, batch_size, m, k, n, m_align, k_align, n_align, + trans_a, trans_b, is_int8 ? 1 : 2, has_a_align, has_b_align, 0, has_accum, 0, has_dequant_param, + hasFormatDequantScale,is_deterministic, 0, is_alltoallvc, 0, 0, 0); + + Preprocessor preprocessor; + PureMatmulBiasAdder add_bias_runner; + SerialDequantRunner serial_dequant_runner; + + preprocessor.SetArgs(PP_MATMUL_AIV_PADDING_ARGS_CALL()); + preprocessor.Run(); + + if (has_accum) { + serial_dequant_runner.SetArgs(reinterpret_cast<__gm__ bfloat16_t *>(gm_out), workspace_info, + reinterpret_cast<__gm__ int64_t *>(gm_dequant_scale), + reinterpret_cast<__gm__ int32_t *>(gm_dequant_offset), dequant_granularity, batch_size, m, n); + serial_dequant_runner.FormatScale(); + } + + if (have_bias) { + add_bias_runner.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL()); + } + + WaitEvent(AIV_WAIT_AIC_FINISH_MATMUL_FLAG_ID); + + if (has_accum) { + serial_dequant_runner.Run(); + } + + if (have_bias) { + add_bias_runner.Run(); + } +} + +#endif +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_ppmatmul_switch.cce b/comm/lcal/src/kernels/coc_ppmatmul_switch.cce new file mode 100644 index 0000000000000000000000000000000000000000..de0789d26084ead1723ca4ab4f289917583aefaf --- /dev/null +++ b/comm/lcal/src/kernels/coc_ppmatmul_switch.cce @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#include "coc_internal.cce"
+#include "coc_ppmatmul.cce"
+#include "tiling_args.h"
+#include "coc_matmulmoe.cce"
+
+
+#ifdef __DAV_C220_CUBE__
+
+template
+FORCE_INLINE_AICORE void RunPpMatmul(int32_t tiling_key, PP_MATMUL_AIC_ARGS_FUN(TData, TData)) {
+    constexpr bool IS_MOE = (RUN_TYPE == PPMATMUL_RUN_ALL_TO_ALL_ALL_GATHER_MATMUL_HIDDEN) || (RUN_TYPE == PPMATMUL_RUN_MATMUL_REDUCE_SCATTER_ALL_TO_ALL_HIDDEN)
+        || (RUN_TYPE == PPMATMUL_RUN_ALL_TO_ALL_ALL_GATHER_MATMUL);
+    // instantiate the PpMatmul variant matching the (trans_b, int8) bits of the tiling key
+    if (IS_MOE) {
+        PpMatmulMoe matmul_z;
+        PpMatmulMoe matmul_tb_z;
+        PpMatmulMoe matmul_z_int8;
+        PpMatmulMoe matmul_tb_z_int8;
+        int32_t tiling_key_sel = tiling_key & 0b011101;
+        switch (tiling_key_sel) {
+            case 0b000000 :
+                matmul_z.SetArgs(PP_MATMUL_AIC_ARGS_CALL());
+                matmul_z.Run();
+                break;
+            case 0b001000 :
+                matmul_tb_z.SetArgs(PP_MATMUL_AIC_ARGS_CALL());
+                matmul_tb_z.Run();
+                break;
+            case 0b000100 :
+                matmul_z_int8.SetArgs(PP_MATMUL_AIC_ARGS_CALL());
+                matmul_z_int8.Run();
+                break;
+            case 0b001100 :
+                matmul_tb_z_int8.SetArgs(PP_MATMUL_AIC_ARGS_CALL());
+                matmul_tb_z_int8.Run();
+                break;
+            default :
+                break;
+        }
+    } else {
+        PpMatmul matmul_z;
+        PpMatmul matmul_tb_z;
+        PpMatmul matmul_z_int8;
+        PpMatmul matmul_tb_z_int8;
+        int32_t tiling_key_sel = tiling_key & 0b011101;
+        switch (tiling_key_sel) {
+            case 0b000000 :
+                matmul_z.SetArgs(PP_MATMUL_AIC_ARGS_CALL());
+                matmul_z.Run();
+                break;
+            case 0b001000 :
+                matmul_tb_z.SetArgs(PP_MATMUL_AIC_ARGS_CALL());
+                matmul_tb_z.Run();
+                break;
+            case 0b000100 :
+                matmul_z_int8.SetArgs(PP_MATMUL_AIC_ARGS_CALL());
+                matmul_z_int8.Run();
+                break;
+            case 0b001100 :
+                matmul_tb_z_int8.SetArgs(PP_MATMUL_AIC_ARGS_CALL());
+                matmul_tb_z_int8.Run();
+                break;
+            default :
+                break;
+        }
+    }
+
+}
+
+template <typename TData>
+inline __aicore__ void CocPpmatmulSwitchAic(COC_ARGS_FUN(TData)) {
+    CoCBuffAddrAndArgs<TData> coc_buff_and_args(COC_ARGS_CALL());
+    __gm__ TData* buff[LCAL_MAX_RANK_SIZE];
+    for (int i = 0; i < coc_buff_and_args.rankSize; ++i) {
+        buff[i] = coc_buff_and_args.buff[i];
+    }
+    bool is_deterministic = coc_buff_and_args.DETERMINISTIC;
+    set_padding(0);
+    SetAtomicNone();
+    uint64_t config = 0x1;
+    set_nd_para(config);
+    SetSyncBaseAddr((uint64_t)ffts_addr);
+
+    // fetch the tiling parameters
+    auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm);
+    auto cocTilingData = &para->cocTilingData;
+    auto quantInfo = &para->quantInfo;
+    auto twoDimTPInfo = &para->twoDimTPInfo;
+    auto moeInfo = &para->moeInfo;
+
+    bool weight_nz = para->weightNz;
+    int32_t batch_size = cocTilingData->batchSize;
+    int32_t m = cocTilingData->m;
+    int32_t k = cocTilingData->k;
+    int32_t n = cocTilingData->n;
+
+    int32_t m0 = cocTilingData->m0;
+    int32_t k0 = cocTilingData->k0;
+    int32_t n0 = cocTilingData->n0;
+
+    int32_t m_loop = cocTilingData->mLoop;
+    int32_t k_loop = cocTilingData->kLoop;
+    int32_t n_loop = cocTilingData->nLoop;
+
+    int32_t core_loop = cocTilingData->coreLoop;
+    int32_t swizzl_count = cocTilingData->swizzlCount;
+    int32_t tiling_key = cocTilingData->tilingKey;
+    int32_t rank = cocTilingData->rank;
+    int32_t rank_size = cocTilingData->rankSize;
+    int32_t p_value = cocTilingData->pValue;
+    int32_t withSerialMode = cocTilingData->withSerialMode;
+    bool is_91093 = cocTilingData->is91093;
+    int32_t buffer_size = cocTilingData->bufferSize;
+
+    int32_t swizzl_direct = (tiling_key & SWIZZL_MASK) ?
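+    // The dispatch in RunPpMatmul keeps only bits 0b011101 of this key; for instance a
+    // key with the trans-B bit (0b001000) and the int8 bit (0b000100) both set selects
+    // case 0b001100 (matmul_tb_z_int8) regardless of any bits outside the mask.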
1 : 0; + bool is_int8 = (tiling_key & INT8_MASK) != 0; + QuantGranularity dequant_granularity = static_cast(quantInfo->dequantGranularity); + int32_t dequant_group_size = quantInfo->dequantGroupSize; + QuantGranularity quant_granularity = static_cast(quantInfo->quantGranularity); + int32_t quant_group_size = quantInfo->quantGroupSize; + __gm__ TData* gm_peer_mem = buff[rank]; + __gm__ TData* gm_c = gm_out; + int32_t ag_dim = twoDimTPInfo->agDim; + int32_t rs_dim = twoDimTPInfo->rsDim; + bool inner_dim_is_Ag = twoDimTPInfo->innerDimIsAg; + + int32_t local_expert_nums = moeInfo->local_expert_nums; + int32_t TP = moeInfo->TP; + int32_t EP = moeInfo->EP; + int32_t maxOutputSize = moeInfo->maxOutputSize; + int32_t is_moe = moeInfo->isMoe; + + RunPpMatmul(tiling_key, PP_MATMUL_AIC_ARGS_CALL()); + PipeBarrier(); +} + +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_preprocessor.cce b/comm/lcal/src/kernels/coc_preprocessor.cce new file mode 100644 index 0000000000000000000000000000000000000000..ce0de5678d532836bbae2c734c71a10e863df20d --- /dev/null +++ b/comm/lcal/src/kernels/coc_preprocessor.cce @@ -0,0 +1,2684 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef __COC_PREPROCESSOR__ +#define __COC_PREPROCESSOR__ + +#ifdef __DAV_C220_VEC__ + +#include +#include "coc_internal.cce" +#include "kernel_operator.h" +using namespace AscendC; + + +template +class BasePadder { +public: + class LoopIter { + public: + inline __aicore__ LoopIter(int32_t batch_size, int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) : + batch_size(batch_size), n_rows(n_rows), n_cols(n_cols), n_cols_aligned(n_cols_aligned) + { + int32_t align_core_num = get_block_num() * get_subblockdim(); + int32_t align_core_idx = get_block_idx() * get_subblockdim() + get_subblockid(); + int32_t n_rows_per_core_base = n_rows / align_core_num; + int32_t n_rows_remainder = n_rows % align_core_num; + int32_t row_offset_base = align_core_idx * n_rows_per_core_base; + if (align_core_idx < n_rows_remainder) { + n_rows_this_core = n_rows_per_core_base + 1; + row_offset_this_core = row_offset_base + align_core_idx; + } else { + n_rows_this_core = n_rows_per_core_base; + row_offset_this_core = row_offset_base + n_rows_remainder; + } + n_cols_this_core = n_cols; + col_offset_this_core = 0; + + src_core_offset = 1LL * row_offset_this_core * n_cols; + dst_core_offset = 1LL * row_offset_this_core * n_cols_aligned; + } + + inline __aicore__ void InitBatchLoop() + { + batch_idx = 0; + + src_batch_offset = 0; + dst_batch_offset = 0; + } + + inline __aicore__ bool EndBatchLoop() const + { + return batch_idx == batch_size; + } + + inline __aicore__ void NextBatchLoop() + { + ++batch_idx; + if (EndBatchLoop()) { + return; + } + + src_batch_offset = batch_idx * n_rows * n_cols; + dst_batch_offset = batch_idx * n_rows * n_cols_aligned; + } + + inline __aicore__ void InitRowLoop(int32_t max_rows_per_loop) + { + this->max_rows_per_loop = max_rows_per_loop; + n_rows_complete 
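+            // Worked example of the per-core row split in the constructor (illustrative):
+            // n_rows = 100 over align_core_num = 48 cores gives a base of 2 rows per core
+            // with remainder 4, so cores 0..3 take 3 rows starting at offsets 0, 3, 6, 9
+            // and cores 4..47 take 2 rows starting at 12, 14, 16, ...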
= 0; + src_row_loop_offset = 0; + dst_row_loop_offset = 0; + + n_rows_this_loop = (n_rows_this_core < max_rows_per_loop) ? n_rows_this_core : max_rows_per_loop; + } + + inline __aicore__ bool EndRowLoop() const + { + return n_rows_complete == n_rows_this_core; + } + + inline __aicore__ void NextRowLoop() + { + n_rows_complete += n_rows_this_loop; + if (EndRowLoop()) { + return; + } + + if (n_rows_complete + n_rows_this_loop > n_rows_this_core) { + n_rows_this_loop = n_rows_this_core - n_rows_complete; + } + src_row_loop_offset = n_rows_complete * n_cols; + dst_row_loop_offset = n_rows_complete * n_cols_aligned; + } + + inline __aicore__ void InitColLoop(int32_t max_cols_per_loop) + { + this->max_cols_per_loop = max_cols_per_loop; + n_cols_complete = 0; + col_loop_offset = 0; + + n_cols_this_loop = (n_cols < max_cols_per_loop) ? n_cols : max_cols_per_loop; + } + + inline __aicore__ bool EndColLoop() const + { + return n_cols_complete == n_cols_this_core; + } + + inline __aicore__ void NextColLoop() + { + n_cols_complete += n_cols_this_loop; + if (EndColLoop()) { + return; + } + + if (n_cols_complete + n_cols_this_loop > n_cols_this_core) { + n_cols_this_loop = n_cols_this_core - n_cols_complete; + } + col_loop_offset = n_cols_complete; + } + + inline __aicore__ int64_t src_offset() const + { + return src_core_offset + src_batch_offset + src_row_loop_offset + col_loop_offset; + } + + inline __aicore__ int64_t dst_offset() const + { + return dst_core_offset + dst_batch_offset + dst_row_loop_offset + col_loop_offset; + } + + int32_t batch_size; + int32_t n_rows; + int32_t n_cols; + int32_t n_cols_aligned; + + int32_t n_rows_this_core; + int32_t n_cols_this_core; + int32_t row_offset_this_core; + int32_t col_offset_this_core; + + int32_t max_rows_per_loop; + int32_t max_cols_per_loop; + + int32_t batch_idx; + int32_t n_rows_complete; + int32_t n_cols_complete; + + int32_t n_rows_this_loop; + int32_t n_cols_this_loop; + + int64_t src_core_offset; + int64_t dst_core_offset; + int64_t src_batch_offset; + int64_t dst_batch_offset; + int64_t src_row_loop_offset; + int64_t dst_row_loop_offset; + int64_t col_loop_offset; + }; + + __aicore__ explicit BasePadder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b) + { + this->gm_a = reinterpret_cast<__gm__ LhsDtype *>(gm_a); + this->gm_b = reinterpret_cast<__gm__ RhsDtype *>(gm_b); + + this->batch_size = batch_size; + this->m = m; + this->k = k; + this->n = n; + this->trans_a = trans_a; + this->trans_b = trans_b; + + this->m_align = m_align; + this->k_align = k_align; + this->n_align = n_align; + + this->aligned_a = aligned_a; + this->aligned_b = aligned_b; + + gm_a_align = reinterpret_cast<__gm__ MmadDtype *>(workspace_info.gm_a_align ? workspace_info.gm_a_align : gm_a); + gm_b_align = reinterpret_cast<__gm__ MmadDtype *>(workspace_info.gm_b_align ? workspace_info.gm_b_align : gm_b); + } + +protected: + inline __aicore__ void PadMatrix(__gm__ MmadDtype *gm_dst, __gm__ MmadDtype *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = Block32B::AlignDown(MAX_UB_BUFF / sizeof(MmadDtype)); + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? 
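+        // If one 32 B-aligned row fits in UB, several rows are moved per iteration;
+        // otherwise a single row is processed in MAX_LEN-element column slices.
+        // Illustrative numbers: MAX_LEN = 98304 halves and n_cols_round = 5008 give
+        // max_rows_per_loop = 98304 / 5008 = 19 and max_cols_per_loop = n_cols.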
(MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? n_cols : MAX_LEN; + + auto ub_base = reinterpret_cast<__ubuf__ MmadDtype *>((uintptr_t)0); + + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + CopyGmToUbufAlign(ub_base, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + CopyUbufToGmAlign(dst, ub_base, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + } + } + } + + inline __aicore__ void Barrier() + { + FFTSCrossCoreSync(0, AIV_FINISH_ALIGN_FLAG_ID); + WaitEvent(AIV_FINISH_ALIGN_FLAG_ID); + + FFTSCrossCoreSync(2, AIC_WAIT_AIV_FINISH_ALIGN_FLAG_ID); + PipeBarrier(); + } + + __gm__ LhsDtype *__restrict__ gm_a{ nullptr }; + __gm__ RhsDtype *__restrict__ gm_b{ nullptr }; + __gm__ MmadDtype *__restrict__ gm_a_align{ nullptr }; + __gm__ MmadDtype *__restrict__ gm_b_align{ nullptr }; + + int32_t batch_size; + + int32_t m_align; + int32_t n_align; + int32_t k_align; + + int32_t m; + int32_t n; + int32_t k; + + bool trans_a; + bool trans_b; + + int32_t aligned_a; + int32_t aligned_b; + + LcalWorkspaceInfo workspace_info; +}; + +template +class Padder : public BasePadder { +public: + __aicore__ explicit Padder() = default; + + inline __aicore__ void Run(int32_t expert_per_rank = 1) + { + if (this->aligned_a) { + int n_rows = this->trans_a ? this->k : this->m; + int n_cols = this->trans_a ? this->m : this->k; + int n_cols_aligned = this->trans_a ? this->m_align : this->k_align; + + this->PadMatrix(this->gm_a_align, this->gm_a, n_rows, n_cols, n_cols_aligned); + } + + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + + if (this->aligned_b) { + int n_rows = this->trans_b ? this->n : this->k; + int n_cols = this->trans_b ? this->k : this->n; + int n_cols_aligned = this->trans_b ? 
this->k_align : this->n_align; + + this->PadMatrix(this->gm_b_align, this->gm_b, n_rows * expert_per_rank, n_cols, n_cols_aligned); + } + + this->Barrier(); + } +}; + +class FormatOffset { +public: + static constexpr int32_t max_len = 49152; + + static inline __aicore__ void Loop(__gm__ int32_t *dst, int32_t offset, int32_t len) + { + static const auto ub_offset = reinterpret_cast<__ubuf__ int32_t *>((uintptr_t)0); + + int32_t repeat_num = Block256B::Count(len); + int32_t loop_num = DivCeil(repeat_num, repeat); + uint8_t repeat_this_loop = static_cast(repeat); + for (int32_t loop_idx = 0; loop_idx < loop_num; ++loop_idx) { + if (loop_idx == loop_num - 1) { + repeat_this_loop = repeat_num - loop_idx * repeat; + } + VectorDup(ub_offset + loop_idx * repeat * Block256B::size, offset, repeat_this_loop, 1, 8); + } + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + CopyUbufToGmAlign(dst, ub_offset, 1, len, 0); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + +private: + static constexpr uint8_t repeat = 255; +}; + +template <> +class Padder : public BasePadder { +public: + __aicore__ explicit Padder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_offset = nullptr, + QuantGranularity dequant_granularity = QuantGranularity::QUANT_GRANULARITY_UNDEFINED) + { + this->BasePadder::SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n, + m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b); + + if (gm_dequant_offset != nullptr && dequant_granularity == QuantGranularity::PER_TENSOR) { + offset = *reinterpret_cast<__gm__ int32_t *>(gm_dequant_offset); + gm_format_dequant_offset = reinterpret_cast<__gm__ int32_t *>(workspace_info.gm_dequant_param); + need_format_dequant_offset = true; + } + } + + inline __aicore__ void Run(int32_t expert_per_rank = 1) + { + if (this->aligned_a) { + int n_rows = this->trans_a ? this->k : this->m; + int n_cols = this->trans_a ? this->m : this->k; + int n_cols_aligned = this->trans_a ? this->m_align : this->k_align; + + this->PadMatrix(this->gm_a_align, this->gm_a, n_rows, n_cols, n_cols_aligned); + } + + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + + if (this->aligned_b) { + int n_rows = this->trans_b ? this->n : this->k; + int n_cols = this->trans_b ? this->k : this->n; + int n_cols_aligned = this->trans_b ? 
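+        // FormatOffset (used below for PER_TENSOR dequant) broadcasts one int32 offset
+        // across n slots in chunks of max_len = 49152 elements; VectorDup covers 64
+        // int32 per repeat (256 B), so a full chunk is 768 repeats issued as
+        // DivCeil(768, 255) = 4 calls of at most 255 repeats each (255 + 255 + 255 + 3).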
this->k_align : this->n_align; + + this->PadMatrix(this->gm_b_align, this->gm_b, n_rows * expert_per_rank, n_cols, n_cols_aligned); + } + + if (need_format_dequant_offset) { + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + FormatOffset(); + } + + this->Barrier(); + } + +private: + inline __aicore__ void FormatOffset() + { + int32_t align_core_idx = get_block_idx() * get_subblockdim() + get_subblockid(); + int32_t align_core_num = get_block_num() * get_subblockdim(); + + int32_t len = FormatOffset::max_len; + int32_t loop_num = DivCeil(n, len); + for (int32_t i = align_core_idx; i < loop_num; i += align_core_num) { + int32_t n_complete = i * len; + if (n_complete + len > n) { + len = n - n_complete; + } + FormatOffset::Loop(gm_format_dequant_offset + n_complete, offset, len); + } + } + + __gm__ int32_t *gm_format_dequant_offset; + int32_t offset; + bool need_format_dequant_offset{ false }; +}; + +template +class DequantPadder : public BasePadder { +public: + __aicore__ explicit DequantPadder() = default; + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_scale, __gm__ uint8_t *gm_dequant_offset) + {} + inline __aicore__ void Run() {} +}; + +template <> +class DequantPadder : public BasePadder { +public: + __aicore__ explicit DequantPadder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_scale, __gm__ uint8_t *gm_dequant_offset) + { + this->BasePadder::SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n, + m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b); + + scale = *reinterpret_cast<__gm__ half *>(gm_dequant_scale); + if (gm_dequant_offset) { + offset = *reinterpret_cast<__gm__ half *>(gm_dequant_offset); + has_offset = true; + } + } + + inline __aicore__ void Run() + { + if (this->aligned_a) { + int n_rows = this->trans_a ? this->k : this->m; + int n_cols = this->trans_a ? this->m : this->k; + int n_cols_aligned = this->trans_a ? this->m_align : this->k_align; + + this->PadMatrix(this->gm_a_align, this->gm_a, n_rows, n_cols, n_cols_aligned); + } + + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + + int n_rows = this->trans_b ? this->n : this->k; + int n_cols = this->trans_b ? this->k : this->n; + int n_cols_aligned = this->trans_b ? this->k_align : this->n_align; + + DequantAndPadMatrix(this->gm_b_align, this->gm_b, n_rows, n_cols, n_cols_aligned); + + this->Barrier(); + } + +private: + inline __aicore__ void DequantAndPadMatrix(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = Block256B::AlignDown(MAX_UB_BUFF / (sizeof(int8_t) + sizeof(half))); + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
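+        // The staged pipeline below computes y = (half(x) + offset) * scale when an
+        // offset is present, else y = half(x) * scale; e.g. x = -5, offset = 3.0 and
+        // scale = 0.1 give y = -0.2.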
n_cols : MAX_LEN;
+
+        auto ub_vconv = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0);
+        auto ub_muls = reinterpret_cast<__ubuf__ half *>((uintptr_t)(MAX_LEN * sizeof(int8_t)));
+
+        for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) {
+            for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) {
+                int32_t src_gap = n_cols - it.n_cols_this_loop;
+                int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop;
+                for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) {
+                    auto src = gm_src + it.src_offset();
+                    auto dst = gm_dst + it.dst_offset();
+
+                    // 1. MTE2: ub_vconv <- gm_src
+                    CopyGmToUbufAlign(ub_vconv, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap);
+
+                    int32_t n_blocks_per_row = Block32B::Count(it.n_cols_this_loop) *
+                        (sizeof(half) / sizeof(int8_t));
+                    int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row;
+                    int32_t repeat_times = DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT);
+
+                    // 1 -> 2
+                    SetFlag(EVENT_ID0);
+                    WaitFlag(EVENT_ID0);
+
+                    // 2. V: ub_muls <- vconv(ub_vconv)
+                    uint8_t repeat = REPEAT_PER_LOOP;
+                    for (int32_t n_repeat_complete = 0; n_repeat_complete < repeat_times; n_repeat_complete += repeat) {
+                        if (n_repeat_complete + repeat > repeat_times) {
+                            repeat = repeat_times - n_repeat_complete;
+                        }
+                        Vconv(ub_muls + n_repeat_complete * Block256B::size,
+                            ub_vconv + n_repeat_complete * Block256B::size, repeat, 1, 1, 8, 4);
+                    }
+
+                    // 2 -> 1
+                    SetFlag(EVENT_ID0);
+                    WaitFlag(EVENT_ID0);
+
+                    if (has_offset) {
+                        // 2 -> 3
+                        PipeBarrier();
+
+                        // 3. V: ub_muls <- ub_muls + offset
+                        repeat = REPEAT_PER_LOOP;
+                        for (int32_t n_repeat_complete = 0; n_repeat_complete < repeat_times;
+                            n_repeat_complete += repeat) {
+                            if (n_repeat_complete + repeat > repeat_times) {
+                                repeat = repeat_times - n_repeat_complete;
+                            }
+                            Vadds(ub_muls + n_repeat_complete * Block256B::size,
+                                ub_muls + n_repeat_complete * Block256B::size, offset, repeat, 1, 1, 8, 8);
+                        }
+                    }
+
+                    // 2/3 -> 4
+                    PipeBarrier();
+
+                    // 4. V: ub_muls <- ub_muls * scale
+                    repeat = REPEAT_PER_LOOP;
+                    for (int32_t n_repeat_complete = 0; n_repeat_complete < repeat_times; n_repeat_complete += repeat) {
+                        if (n_repeat_complete + repeat > repeat_times) {
+                            repeat = repeat_times - n_repeat_complete;
+                        }
+                        Vmuls(ub_muls + n_repeat_complete * Block256B::size,
+                            ub_muls + n_repeat_complete * Block256B::size, scale, repeat, 1, 1, 8, 8);
+                    }
+
+                    int32_t ubuf_gap = n_blocks_per_row - Block32B::Count(it.n_cols_this_loop);
+
+                    SetFlag(EVENT_ID0);
+                    WaitFlag(EVENT_ID0);
+
+                    // 5.
MTE3: ub_muls -> dst + CopyUbufToGmAlign(dst, ub_muls, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + + // 5 -> 2 + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + } + } + } + + half scale; + half offset; + bool has_offset{ false }; +}; + +template <> +class DequantPadder : public BasePadder { +public: + __aicore__ explicit DequantPadder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_scale, __gm__ uint8_t *gm_dequant_offset) + { + this->BasePadder::SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n, + m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b); + + if (gm_dequant_offset) { + auto scale_dptr = reinterpret_cast<__gm__ bfloat16_t *>(gm_dequant_scale); + auto offset_dptr = reinterpret_cast<__gm__ bfloat16_t *>(gm_dequant_offset); + + auto ub_args = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)0); + auto ub_args_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)256); + + int32_t args_gap = Block32B::size; + + CopyGmToUbufAlign(ub_args, scale_dptr, 1, 1, 0); + CopyGmToUbufAlign(ub_args + args_gap, offset_dptr, 1, 1, 0); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + Vconv(ub_args_f32, ub_args, 1, 1, 1, 8, 4); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + scale = ub_args_f32[0]; + offset = ub_args_f32[args_gap]; + + has_offset = true; + } else { + auto scale_dptr = reinterpret_cast<__gm__ bfloat16_t *>(gm_dequant_scale); + + auto ub_args = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)0); + auto ub_args_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)256); + + CopyGmToUbufAlign(ub_args, scale_dptr, 1, 1, 0); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + Vconv(ub_args_f32, ub_args, 1, 1, 1, 8, 4); + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + scale = ub_args_f32[0]; + offset = 0; + } + } + + inline __aicore__ void Run() + { + if (this->aligned_a) { + int n_rows = this->trans_a ? this->k : this->m; + int n_cols = this->trans_a ? this->m : this->k; + int n_cols_aligned = this->trans_a ? this->m_align : this->k_align; + + this->PadMatrix(this->gm_a_align, this->gm_a, n_rows, n_cols, n_cols_aligned); + } + + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + + int n_rows = this->trans_b ? this->n : this->k; + int n_cols = this->trans_b ? this->k : this->n; + int n_cols_aligned = this->trans_b ? this->k_align : this->n_align; + + DequantAndPadMatrix(this->gm_b_align, this->gm_b, n_rows, n_cols, n_cols_aligned); + + this->Barrier(); + } + +private: + inline __aicore__ void DequantAndPadMatrix(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 16320; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
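+        // The pipeline below widens int8 -> f16 -> f32, applies offset and scale in f32
+        // and only rounds back to bfloat16 (RoundMode::CAST_RINT) at the final store,
+        // keeping the intermediate arithmetic out of bf16's 8-bit mantissa.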
n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)32768); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)65536); + auto ub_adds = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)65536); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)131072); + auto ub_muls = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)131072); + + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + int32_t n_blocks_per_row_b8 = Block32B::Count(it.n_cols_this_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = n_blocks_per_row_b8 * (sizeof(float32_t) / sizeof(int8_t)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID0); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vadds(ub_adds, ub_vconv_f32, offset, repeat_b32, 1, 1, 8, 8); + + PipeBarrier(); + Vmuls(ub_muls, ub_adds, scale, repeat_b32, 1, 1, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID1); + Vconv(ub_output, ub_muls, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID1); + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + } + } + } + + float scale; + float offset; + bool has_offset{ false }; +}; + +template <> +class DequantPadder : public BasePadder { +public: + __aicore__ explicit DequantPadder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_scale, __gm__ uint8_t *gm_dequant_offset) + { + this->BasePadder::SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n, + m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b); + + gm_scale = reinterpret_cast<__gm__ half *>(gm_dequant_scale); + if (gm_dequant_offset) { + gm_offset = reinterpret_cast<__gm__ half *>(gm_dequant_offset); + has_offset = true; + } + } + + inline __aicore__ void Run() + { + if (this->aligned_a) { + int n_rows = this->trans_a ? this->k : this->m; + int n_cols = this->trans_a ? this->m : this->k; + int n_cols_aligned = this->trans_a ? 
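+        // PER_CHANNEL dequant in this class keeps one scale (and optionally one offset)
+        // per output channel: a single row of scales is loaded per column slice and then
+        // replicated across max_rows_per_loop rows with CopyUB2UB so that Vmul can
+        // consume it with the same block stride as the data.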
this->m_align : this->k_align; + + this->PadMatrix(this->gm_a_align, this->gm_a, n_rows, n_cols, n_cols_aligned); + } + + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + + if (!this->trans_b && !has_offset) { + DequantAndPadMatrixNoOffset(this->gm_b_align, this->gm_b, this->k, this->n, this->n_align); + } else if (!this->trans_b && has_offset) { + DequantAndPadMatrixHasOffset(this->gm_b_align, this->gm_b, this->k, this->n, this->n_align); + } else if (this->trans_b && !has_offset) { + DequantAndPadMatrixTransposeNoOffset(this->gm_b_align, this->gm_b, this->n, this->k, this->k_align); + } else { + DequantAndPadMatrixTransposeHasOffset(this->gm_b_align, this->gm_b, this->n, this->k, this->k_align); + } + + this->Barrier(); + } + +private: + inline __aicore__ void DequantAndPadMatrixNoOffset(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 28032; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ half *>((uintptr_t)28416); + auto ub_vconv = reinterpret_cast<__ubuf__ half *>((uintptr_t)84480); + auto ub_output = reinterpret_cast<__ubuf__ half *>((uintptr_t)140544); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto scale = gm_scale + it.n_cols_complete; + + int32_t n_blocks_per_row = Block32B::Count(it.n_cols_this_loop) * + (sizeof(half) / sizeof(int8_t)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row - Block32B::Count(it.n_cols_this_loop); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale, scale, 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + for (int32_t row = 1; row < max_rows_per_loop; ++row) { + CopyUB2UB(ub_quant_scale + row * n_blocks_per_row * Block32B::size, ub_quant_scale, + 0, 1, n_blocks_per_row, 0, 0); /* sid */ + } + + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row; + uint8_t repeat = static_cast(DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv, ub_input, repeat, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vmul(ub_output, ub_vconv, ub_quant_scale, repeat, 1, 1, 1, 8, 8, 8); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID2); + } + SetFlag(EVENT_ID0); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixHasOffset(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const 
int32_t MAX_LEN = 17792; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ half *>((uintptr_t)18688); + auto ub_quant_offset = reinterpret_cast<__ubuf__ half *>((uintptr_t)54272); + auto ub_output = reinterpret_cast<__ubuf__ half *>((uintptr_t)89856); + auto ub_add = reinterpret_cast<__ubuf__ half *>((uintptr_t)125440); + auto ub_vconv = reinterpret_cast<__ubuf__ half *>((uintptr_t)161024); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + SetFlag(EVENT_ID3); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto scale = gm_scale + it.n_cols_complete; + auto offset = gm_offset + it.n_cols_complete; + + int32_t n_blocks_per_row = Block32B::Count(it.n_cols_this_loop) * + (sizeof(half) / sizeof(int8_t)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row - Block32B::Count(it.n_cols_this_loop); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale, scale, 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_quant_offset, offset, 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID0); + for (int32_t row = 1; row < max_rows_per_loop; ++row) { + CopyUB2UB(ub_quant_scale + row * n_blocks_per_row * Block32B::size, ub_quant_scale, + 0, 1, n_blocks_per_row, 0, 0); /* sid */ + } + + WaitFlag(EVENT_ID1); + for (int32_t row = 1; row < max_rows_per_loop; ++row) { + CopyUB2UB(ub_quant_offset + row * n_blocks_per_row * Block32B::size, ub_quant_offset, + 0, 1, n_blocks_per_row, 0, 0); /* sid */ + } + + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row; + uint8_t repeat = static_cast(DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID2); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + Vconv(ub_vconv, ub_input, repeat, 1, 1, 8, 4); + SetFlag(EVENT_ID2); + + PipeBarrier(); + Vadd(ub_add, ub_vconv, ub_quant_offset, repeat, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID3); + Vmul(ub_output, ub_add, ub_quant_scale, repeat, 1, 1, 1, 8, 8, 8); + SetFlag(EVENT_ID3); + + WaitFlag(EVENT_ID3); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID3); + } + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + WaitFlag(EVENT_ID3); + } + + inline __aicore__ void DequantAndPadMatrixTransposeNoOffset(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 28032; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ half *>((uintptr_t)28416); + auto ub_vconv = reinterpret_cast<__ubuf__ half *>((uintptr_t)84480); + auto ub_output = reinterpret_cast<__ubuf__ half *>((uintptr_t)140544); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto scale = gm_scale + it.row_offset_this_core + it.n_rows_complete; + + int32_t n_blocks_per_row_b8 = Block32B::Count(max_cols_per_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(bfloat16_t) / sizeof(int8_t)); + + int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row_b16; + uint8_t repeat = static_cast(DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale, scale, it.n_rows_this_loop, 1, 0, n_blocks_per_row_b16 - 1); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + for (int32_t block_col = 1; block_col < n_blocks_per_row_b16; ++block_col) { + CopyUB2UB(ub_quant_scale + block_col * Block32B::size, ub_quant_scale, + 0, it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); /* sid */ + } + + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + + int32_t ubuf_gap_b8 = n_blocks_per_row_b8 - Block32B::Count(it.n_cols_this_loop); + int32_t ubuf_gap_b16 = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap, ubuf_gap_b8); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv, ub_input, repeat, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vmul(ub_output, ub_vconv, ub_quant_scale, repeat, 1, 1, 1, 8, 8, 8); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap_b16); + SetFlag(EVENT_ID2); + } + SetFlag(EVENT_ID0); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixTransposeHasOffset(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 17792; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ half *>((uintptr_t)18688); + auto ub_quant_offset = reinterpret_cast<__ubuf__ half *>((uintptr_t)54272); + auto ub_output = reinterpret_cast<__ubuf__ half *>((uintptr_t)89856); + auto ub_add = reinterpret_cast<__ubuf__ half *>((uintptr_t)125440); + auto ub_vconv = reinterpret_cast<__ubuf__ half *>((uintptr_t)161024); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + SetFlag(EVENT_ID3); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto scale = gm_scale + it.row_offset_this_core + it.n_rows_complete; + auto offset = gm_offset + it.row_offset_this_core + it.n_rows_complete; + + int32_t n_blocks_per_row_b8 = Block32B::Count(max_cols_per_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(bfloat16_t) / sizeof(int8_t)); + + int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row_b16; + uint8_t repeat = static_cast(DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale, scale, it.n_rows_this_loop, 1, 0, n_blocks_per_row_b16 - 1); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_quant_offset, offset, it.n_rows_this_loop, 1, 0, n_blocks_per_row_b16 - 1); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID0); + for (int32_t block_col = 1; block_col < n_blocks_per_row_b16; ++block_col) { + CopyUB2UB(ub_quant_scale + block_col * Block32B::size, ub_quant_scale, + 0, it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); /* sid */ + } + + WaitFlag(EVENT_ID1); + for (int32_t block_col = 1; block_col < n_blocks_per_row_b16; ++block_col) { + CopyUB2UB(ub_quant_offset + block_col * Block32B::size, ub_quant_offset, + 0, it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); /* sid */ + } + + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + + int32_t ubuf_gap_b8 = n_blocks_per_row_b8 - Block32B::Count(it.n_cols_this_loop); + int32_t ubuf_gap_b16 = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + WaitFlag(EVENT_ID2); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap, ubuf_gap_b8); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + Vconv(ub_vconv, ub_input, repeat, 1, 1, 8, 4); + SetFlag(EVENT_ID2); + + PipeBarrier(); + Vadd(ub_add, ub_vconv, ub_quant_offset, repeat, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID3); + Vmul(ub_output, ub_add, ub_quant_scale, repeat, 1, 1, 1, 8, 8, 8); + SetFlag(EVENT_ID3); + + WaitFlag(EVENT_ID3); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap_b16); + SetFlag(EVENT_ID3); + } + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + WaitFlag(EVENT_ID3); + } + + __gm__ half *gm_scale{ nullptr }; + __gm__ half *gm_offset{ nullptr }; + bool has_offset{ false }; +}; + +template <> +class DequantPadder : public BasePadder { +public: + __aicore__ explicit DequantPadder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t 
batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_scale, __gm__ uint8_t *gm_dequant_offset) + { + this->BasePadder::SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n, + m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b); + gm_scale = reinterpret_cast<__gm__ bfloat16_t *>(gm_dequant_scale); + if (gm_dequant_offset) { + gm_offset = reinterpret_cast<__gm__ bfloat16_t *>(gm_dequant_offset); + has_offset = true; + } + } + + inline __aicore__ void Run() + { + if (aligned_a) { + int n_rows = this->trans_a ? this->k : this->m; + int n_cols = this->trans_a ? this->m : this->k; + int n_cols_aligned = this->trans_a ? this->m_align : this->k_align; + + this->PadMatrix(this->gm_a_align, this->gm_a, n_rows, n_cols, n_cols_aligned); + } + + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + + if (!trans_b && !has_offset) { + DequantAndPadMatrixNoOffset(this->gm_b_align, this->gm_b, this->k, this->n, this->n_align); + } else if (!trans_b && has_offset) { + DequantAndPadMatrixHasOffset(this->gm_b_align, this->gm_b, this->k, this->n, this->n_align); + } else if (trans_b && !has_offset) { + DequantAndPadMatrixTransposeNoOffset(this->gm_b_align, this->gm_b, this->n, this->k, this->k_align); + } else { + DequantAndPadMatrixTransposeHasOffset(this->gm_b_align, this->gm_b, this->n, this->k, this->k_align); + } + + this->Barrier(); + } + +private: + inline __aicore__ void DequantAndPadMatrixNoOffset(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 10240; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)10496); + auto ub_quant_scale_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)51712); + auto ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)72192); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)113152); + auto ub_quant_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)133632); + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)174592); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto scale = gm_scale + it.n_cols_complete; + + int32_t n_blocks_per_row_b16 = + Block32B::Count(it.n_cols_this_loop) * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = + Block32B::Count(it.n_cols_this_loop) * (sizeof(float32_t) / sizeof(int8_t)); + uint8_t quant_repeat_b32 = static_cast( + DivCeil(n_blocks_per_row_b32, VEC_BLOCK_PER_REPEAT)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale_origin, scale, 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + Vconv(ub_quant_scale, ub_quant_scale_origin, quant_repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + for (int32_t row = 1; row < max_rows_per_loop; ++row) { + CopyUB2UB(ub_quant_scale + row * n_blocks_per_row_b32 * Block32B::size, + ub_quant_scale, /* sid */ 0, 1, n_blocks_per_row_b32, 0, 0); + } + + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vmul(ub_mul, ub_vconv_f32, ub_quant_scale, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vconv(ub_output, ub_mul, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID2); + } + SetFlag(EVENT_ID0); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixHasOffset(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 9344; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
n_cols : MAX_LEN; + + auto ub_quant_scale_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)0); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)0); // multiplex ub_quant_scale_origin + auto ub_add = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)18688); + auto ub_quant_offset_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)56064); + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)56064); // multiplex ub_quant_offset_origin + auto ub_quant_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)74752); + auto ub_quant_offset = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)112384); + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)149760); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)159232); + auto ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)159232); // multiplex ub_conv_f32 + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + SetFlag(EVENT_ID3); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto scale = gm_scale + it.n_cols_complete; + auto offset = gm_offset + it.n_cols_complete; + + int32_t n_blocks_per_row_b16 = + Block32B::Count(it.n_cols_this_loop) * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = + Block32B::Count(it.n_cols_this_loop) * (sizeof(float32_t) / sizeof(int8_t)); + uint8_t quant_repeat_b32 = static_cast( + DivCeil(n_blocks_per_row_b32, VEC_BLOCK_PER_REPEAT)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale_origin, scale, 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_quant_offset_origin, offset, 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID0); + Vconv(ub_quant_scale, ub_quant_scale_origin, quant_repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + for (int32_t row = 1; row < max_rows_per_loop; ++row) { + CopyUB2UB(ub_quant_scale + row * n_blocks_per_row_b32 * Block32B::size, + ub_quant_scale, /* sid */ 0, 1, n_blocks_per_row_b32, 0, 0); + } + + WaitFlag(EVENT_ID1); + Vconv(ub_quant_offset, ub_quant_offset_origin, quant_repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + for (int32_t row = 1; row < max_rows_per_loop; ++row) { + CopyUB2UB(ub_quant_offset + row * n_blocks_per_row_b32 * Block32B::size, + ub_quant_offset, /* sid */ 0, 1, n_blocks_per_row_b32, 0, 0); + } + + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID2); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID2); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vadd(ub_add, ub_vconv_f32, ub_quant_offset, repeat_b32, 1, 1, 1, 8, 8, 8); + + 
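+                    // Dequant arithmetic so far: int8 -> f16 -> f32, then "+ offset" via Vadd; the
+                    // Vmul below applies "* scale", completing value = (q + offset) * scale per
+                    // channel, before the f32 -> bf16 round-to-nearest-even cast and copy back to GM.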
PipeBarrier(); + Vmul(ub_mul, ub_add, ub_quant_scale, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID3); + Vconv(ub_output, ub_mul, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID3); + + WaitFlag(EVENT_ID3); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID3); + } + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + WaitFlag(EVENT_ID3); + } + + inline __aicore__ void DequantAndPadMatrixTransposeNoOffset(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 10240; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)10496); + auto ub_quant_scale_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)51712); + auto ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)72192); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)113152); + auto ub_quant_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)133632); + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)174592); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto scale = gm_scale + it.row_offset_this_core + it.n_rows_complete; + + int32_t n_blocks_per_row_b8 = Block32B::Count(max_cols_per_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = n_blocks_per_row_b8 * (sizeof(float32_t) / sizeof(int8_t)); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale_origin, scale, it.n_rows_this_loop, 1, 0, n_blocks_per_row_b16 - 1); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + for (int32_t block_col = 1; block_col < n_blocks_per_row_b16; ++block_col) { + CopyUB2UB(ub_quant_scale_origin + block_col * Block32B::size, + ub_quant_scale_origin, /* sid */ 0, it.n_rows_this_loop, 1, + n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + } + + PipeBarrier(); + Vconv(ub_quant_scale, ub_quant_scale_origin, repeat_b32, 1, 1, 8, 4); + + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + + int32_t ubuf_gap_b8 = n_blocks_per_row_b8 - Block32B::Count(it.n_cols_this_loop); + int32_t ubuf_gap_b16 = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap, ubuf_gap_b8); + SetFlag(EVENT_ID1); + + 
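+                    // EVENT_ID1 serializes reuse of ub_input between the GM -> UB copy above and the
+                    // vector converts below; EVENT_ID2 plays the same role for ub_output against the
+                    // UB -> GM copy at the end of this loop body.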
WaitFlag(EVENT_ID1); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vmul(ub_mul, ub_vconv_f32, ub_quant_scale, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vconv(ub_output, ub_mul, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap_b16); + SetFlag(EVENT_ID2); + } + SetFlag(EVENT_ID0); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixTransposeHasOffset(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 9344; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? n_cols : MAX_LEN; + + auto ub_quant_scale_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)0); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)0); // multiplex ub_quant_scale_origin + auto ub_add = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)18688); + auto ub_quant_offset_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)56064); + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)56064); // multiplex ub_quant_offset_origin + auto ub_quant_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)74752); + auto ub_quant_offset = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)112384); + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)149760); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)159232); + auto ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)159232); // multiplex ub_conv_f32 + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + SetFlag(EVENT_ID3); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto scale = gm_scale + it.row_offset_this_core + it.n_rows_complete; + auto offset = gm_offset + it.row_offset_this_core + it.n_rows_complete; + + int32_t n_blocks_per_row_b8 = Block32B::Count(max_cols_per_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = n_blocks_per_row_b8 * (sizeof(float32_t) / sizeof(int8_t)); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale_origin, scale, it.n_rows_this_loop, 1, 0, n_blocks_per_row_b16 - 1); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_quant_offset_origin, offset, it.n_rows_this_loop, 1, 0, n_blocks_per_row_b16 - 1); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID0); + for (int32_t block_col = 1; block_col < n_blocks_per_row_b16; ++block_col) { + CopyUB2UB(ub_quant_scale_origin + block_col * Block32B::size, + ub_quant_scale_origin, /* sid */ 0, it.n_rows_this_loop, 1, + 
n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + } + + PipeBarrier(); + Vconv(ub_quant_scale, ub_quant_scale_origin, repeat_b32, 1, 1, 8, 4); + + WaitFlag(EVENT_ID1); + for (int32_t block_col = 1; block_col < n_blocks_per_row_b16; ++block_col) { + CopyUB2UB(ub_quant_offset_origin + block_col * Block32B::size, + ub_quant_offset_origin, /* sid */ 0, it.n_rows_this_loop, 1, + n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + } + + PipeBarrier(); + Vconv(ub_quant_offset, ub_quant_offset_origin, repeat_b32, 1, 1, 8, 4); + + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + + int32_t ubuf_gap_b8 = n_blocks_per_row_b8 - Block32B::Count(it.n_cols_this_loop); + int32_t ubuf_gap_b16 = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + WaitFlag(EVENT_ID2); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap, ubuf_gap_b8); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID2); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vadd(ub_add, ub_vconv_f32, ub_quant_offset, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + Vmul(ub_mul, ub_add, ub_quant_scale, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID3); + Vconv(ub_output, ub_mul, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID3); + + WaitFlag(EVENT_ID3); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap_b16); + SetFlag(EVENT_ID3); + } + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + WaitFlag(EVENT_ID3); + } + + __gm__ bfloat16_t *gm_scale{ nullptr }; + __gm__ bfloat16_t *gm_offset{ nullptr }; + bool has_offset{ false }; +}; + +template +class DequantPadder : public BasePadder { +public: + __aicore__ explicit DequantPadder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_scale, __gm__ uint8_t *gm_dequant_offset, int32_t dequant_group_size) + {} + inline __aicore__ void Run() + {} +}; + +template <> +class DequantPadder : public BasePadder { +public: + __aicore__ explicit DequantPadder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_scale, __gm__ uint8_t *gm_dequant_offset, int32_t dequant_group_size) + { + this->BasePadder::SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n, + m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b); + gm_scale = reinterpret_cast<__gm__ half *>(gm_dequant_scale); + if (gm_dequant_offset) { + gm_offset = reinterpret_cast<__gm__ half *>(gm_dequant_offset); + has_offset = true; + } + group_size = dequant_group_size; + group_num = (this->k + group_size - 1) / group_size; + } + + inline __aicore__ void Run() + 
{ + if (this->aligned_a) { + int n_rows = this->trans_a ? this->k : this->m; + int n_cols = this->trans_a ? this->m : this->k; + int n_cols_aligned = this->trans_a ? this->m_align : this->k_align; + + this->PadMatrix(this->gm_a_align, this->gm_a, n_rows, n_cols, n_cols_aligned); + } + + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + + if (!trans_b && !has_offset) { + DequantAndPadMatrixNoOffset(this->gm_b_align, this->gm_b, this->k, this->n, this->n_align); + } else if (!trans_b && has_offset) { + DequantAndPadMatrixHasOffset(this->gm_b_align, this->gm_b, this->k, this->n, this->n_align); + } else if (trans_b && !has_offset) { + DequantAndPadMatrixTransposeNoOffset(this->gm_b_align, this->gm_b, this->n, this->k, this->k_align); + } else { + DequantAndPadMatrixTransposeHasOffset(this->gm_b_align, this->gm_b, this->n, this->k, this->k_align); + } + + this->Barrier(); + } + +private: + inline __aicore__ void DequantAndPadMatrixNoOffset(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 28032; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ half *>((uintptr_t)28416); + auto ub_vconv = reinterpret_cast<__ubuf__ half *>((uintptr_t)84480); + auto ub_output = reinterpret_cast<__ubuf__ half *>((uintptr_t)140544); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto scale = gm_scale + it.n_cols_complete; + + int32_t n_blocks_per_row = Block32B::Count(it.n_cols_this_loop) * + (sizeof(half) / sizeof(int8_t)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row - Block32B::Count(it.n_cols_this_loop); + + int32_t ub_quant_args_root_offset = 0; + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row; + uint8_t repeat = static_cast(DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv, ub_input, repeat, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + bool is_after_mte2 = false; + WaitFlag(EVENT_ID0); + for (int32_t row = 0; row < max_rows_per_loop; ++row) { + int32_t row_idx = it.row_offset_this_core + it.n_rows_complete + row; + int32_t in_group_idx = row_idx % group_size; + if (in_group_idx == 0 || it.n_rows_complete + row == 0) { + int32_t ub_quant_args_offset = row * n_blocks_per_row * Block32B::size; + int32_t group_idx = row_idx / group_size; + + if (ub_quant_args_offset == ub_quant_args_root_offset && !is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyGmToUbufAlign(ub_quant_scale + ub_quant_args_offset, scale + group_idx * n_cols, + 1, it.n_cols_this_loop, 0); + is_after_mte2 = true; + ub_quant_args_root_offset = ub_quant_args_offset; + } else if (in_group_idx < 
max_rows_per_loop || it.n_rows_complete == 0) { + int32_t ub_quant_args_offset = row * n_blocks_per_row * Block32B::size; + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyUB2UB(ub_quant_scale + ub_quant_args_offset, + ub_quant_scale + ub_quant_args_root_offset, /* sid */ 0, 1, n_blocks_per_row, 0, 0); + is_after_mte2 = false; + } + } + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } else { + PipeBarrier(); + } + WaitFlag(EVENT_ID2); + Vmul(ub_output, ub_vconv, ub_quant_scale, repeat, 1, 1, 1, 8, 8, 8); + is_after_mte2 = false; + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID2); + } + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixHasOffset(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 17792; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ half *>((uintptr_t)18688); + auto ub_quant_offset = reinterpret_cast<__ubuf__ half *>((uintptr_t)54272); + auto ub_output = reinterpret_cast<__ubuf__ half *>((uintptr_t)89856); + auto ub_add = reinterpret_cast<__ubuf__ half *>((uintptr_t)125440); + auto ub_vconv = reinterpret_cast<__ubuf__ half *>((uintptr_t)161024); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto scale = gm_scale + it.n_cols_complete; + auto offset = gm_offset + it.n_cols_complete; + + int32_t n_blocks_per_row = Block32B::Count(it.n_cols_this_loop) * + (sizeof(half) / sizeof(int8_t)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row - Block32B::Count(it.n_cols_this_loop); + + int32_t ub_quant_args_root_offset = 0; + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row; + uint8_t repeat = static_cast(DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT)); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv, ub_input, repeat, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + bool is_after_mte2 = false; + WaitFlag(EVENT_ID0); + for (int32_t row = 0; row < max_rows_per_loop; ++row) { + int32_t row_idx = it.row_offset_this_core + it.n_rows_complete + row; + int32_t in_group_idx = row_idx % group_size; + if (in_group_idx == 0 || it.n_rows_complete + row == 0) { + int32_t ub_quant_args_offset = row * n_blocks_per_row * Block32B::size; + int32_t group_idx = row_idx / group_size; + + if (ub_quant_args_offset == ub_quant_args_root_offset && !is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyGmToUbufAlign(ub_quant_scale + ub_quant_args_offset, scale 
+ group_idx * n_cols, + 1, it.n_cols_this_loop, 0); + CopyGmToUbufAlign(ub_quant_offset + ub_quant_args_offset, offset + group_idx * n_cols, + 1, it.n_cols_this_loop, 0); + is_after_mte2 = true; + ub_quant_args_root_offset = ub_quant_args_offset; + } else if (in_group_idx < max_rows_per_loop || it.n_rows_complete == 0) { + int32_t ub_quant_args_offset = row * n_blocks_per_row * Block32B::size; + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyUB2UB(ub_quant_scale + ub_quant_args_offset, + ub_quant_scale + ub_quant_args_root_offset, /* sid */ 0, + 1, n_blocks_per_row, 0, 0); + CopyUB2UB(ub_quant_offset + ub_quant_args_offset, + ub_quant_offset + ub_quant_args_root_offset, /* sid */ 0, + 1, n_blocks_per_row, 0, 0); + is_after_mte2 = false; + } + } + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } else { + PipeBarrier(); + } + Vadd(ub_add, ub_vconv, ub_quant_offset, repeat, 1, 1, 1, 8, 8, 8); + is_after_mte2 = false; + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vmul(ub_output, ub_add, ub_quant_scale, repeat, 1, 1, 1, 8, 8, 8); + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID2); + } + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixTransposeNoOffset(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 28032; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ half *>((uintptr_t)28416); + auto ub_vconv = reinterpret_cast<__ubuf__ half *>((uintptr_t)84480); + auto ub_output = reinterpret_cast<__ubuf__ half *>((uintptr_t)140544); + + int32_t group_block = Block32B::Count(group_size); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto scale = gm_scale + (it.row_offset_this_core + it.n_rows_complete) * group_num; + + int32_t n_blocks_per_row_b8 = Block32B::Count(max_cols_per_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(half) / sizeof(int8_t)); + + int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row_b16; + uint8_t repeat = static_cast(DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT)); + + int32_t ub_quant_args_root_offset = 0; + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + + int32_t ubuf_gap_b8 = n_blocks_per_row_b8 - Block32B::Count(it.n_cols_this_loop); + int32_t ubuf_gap_b16 = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap, ubuf_gap_b8); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv, ub_input, repeat, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + bool is_after_mte2 = false; + WaitFlag(EVENT_ID0); + for (int32_t block_col = 0; block_col < n_blocks_per_row_b16; ++block_col) { + int32_t block_col_idx = Block32B::Count(it.n_cols_complete) + block_col; + int32_t in_group_idx = block_col_idx % group_block; + if (in_group_idx == 0 || block_col_idx == 0) { + int32_t ub_quant_args_offset = block_col * Block32B::size; + int32_t group_idx = block_col_idx / group_block; + + if (ub_quant_args_offset == ub_quant_args_root_offset && !is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyGmToUbufAlign(ub_quant_scale + ub_quant_args_offset, scale + group_idx, + it.n_rows_this_loop, 1, group_num - 1, n_blocks_per_row_b16 - 1); + is_after_mte2 = true; + ub_quant_args_root_offset = ub_quant_args_offset; + } else if (in_group_idx < n_blocks_per_row_b16 || it.n_cols_complete == 0) { + int32_t ub_quant_args_offset = block_col * Block32B::size; + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyUB2UB(ub_quant_scale + ub_quant_args_offset, + ub_quant_scale + ub_quant_args_root_offset, /* sid */ 0, + it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + is_after_mte2 = false; + } + } + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } else { + PipeBarrier(); + } + WaitFlag(EVENT_ID2); + Vmul(ub_output, ub_vconv, ub_quant_scale, repeat, 1, 1, 1, 8, 8, 8); + is_after_mte2 = false; + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap_b16); + SetFlag(EVENT_ID2); + } + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixTransposeHasOffset(__gm__ half *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t 
n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 17792; + int32_t max_rows_per_loop = (it.n_rows_this_core * Block32B::size <= MAX_LEN) ? + it.n_rows_this_core : MAX_LEN / Block32B::size; + int32_t max_cols_per_loop = (it.n_rows_this_core * Block32B::size <= MAX_LEN) ? + Block32B::AlignDown(MAX_LEN / it.n_rows_this_core) : Block32B::size; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ half *>((uintptr_t)18688); + auto ub_quant_offset = reinterpret_cast<__ubuf__ half *>((uintptr_t)54272); + auto ub_output = reinterpret_cast<__ubuf__ half *>((uintptr_t)89856); + auto ub_add = reinterpret_cast<__ubuf__ half *>((uintptr_t)125440); + auto ub_vconv = reinterpret_cast<__ubuf__ half *>((uintptr_t)161024); + + int32_t group_block = Block32B::Count(group_size); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto scale = gm_scale + (it.row_offset_this_core + it.n_rows_complete) * group_num; + auto offset = gm_offset + (it.row_offset_this_core + it.n_rows_complete) * group_num; + + int32_t n_blocks_per_row_b8 = Block32B::Count(max_cols_per_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(half) / sizeof(int8_t)); + + int32_t n_blocks = it.n_rows_this_loop * n_blocks_per_row_b16; + uint8_t repeat = static_cast(DivCeil(n_blocks, VEC_BLOCK_PER_REPEAT)); + + int32_t ub_quant_args_root_offset = 0; + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + + int32_t ubuf_gap_b8 = n_blocks_per_row_b8 - Block32B::Count(it.n_cols_this_loop); + int32_t ubuf_gap_b16 = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap, ubuf_gap_b8); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv, ub_input, repeat, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + bool is_after_mte2 = false; + WaitFlag(EVENT_ID0); + for (int32_t block_col = 0; block_col < n_blocks_per_row_b16; ++block_col) { + int32_t block_col_idx = Block32B::Count(it.n_cols_complete) + block_col; + int32_t in_group_idx = block_col_idx % group_block; + if (in_group_idx == 0 || block_col_idx == 0) { + int32_t ub_quant_args_offset = block_col * Block32B::size; + int32_t group_idx = block_col_idx / group_block; + + if (ub_quant_args_offset == ub_quant_args_root_offset && !is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyGmToUbufAlign(ub_quant_scale + ub_quant_args_offset, scale + group_idx, + it.n_rows_this_loop, 1, group_num - 1, n_blocks_per_row_b16 - 1); + CopyGmToUbufAlign(ub_quant_offset + ub_quant_args_offset, offset + group_idx, + it.n_rows_this_loop, 1, group_num - 1, n_blocks_per_row_b16 - 1); + is_after_mte2 = true; + ub_quant_args_root_offset = ub_quant_args_offset; + } else if (in_group_idx < n_blocks_per_row_b16 || it.n_cols_complete == 0) { + int32_t ub_quant_args_offset = block_col * Block32B::size; + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyUB2UB(ub_quant_scale + ub_quant_args_offset, + ub_quant_scale + ub_quant_args_root_offset, /* sid 
*/ 0, + it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + CopyUB2UB(ub_quant_offset + ub_quant_args_offset, + ub_quant_offset + ub_quant_args_root_offset, /* sid */ 0, + it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + is_after_mte2 = false; + } + } + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } else { + PipeBarrier(); + } + Vadd(ub_add, ub_vconv, ub_quant_offset, repeat, 1, 1, 1, 8, 8, 8); + is_after_mte2 = false; + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vmul(ub_output, ub_add, ub_quant_scale, repeat, 1, 1, 1, 8, 8, 8); + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap_b16); + SetFlag(EVENT_ID2); + } + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + __gm__ half *gm_scale{ nullptr }; + __gm__ half *gm_offset{ nullptr }; + int32_t group_size; + int32_t group_num; + bool has_offset{ false }; +}; + +template <> +class DequantPadder : public BasePadder { +public: + __aicore__ explicit DequantPadder() = default; + + inline __aicore__ void SetArgs(__gm__ uint8_t *gm_a, __gm__ uint8_t *gm_b, const LcalWorkspaceInfo &workspace_info, + int32_t batch_size, int32_t m, int32_t k, int32_t n, int32_t m_align, int32_t k_align, int32_t n_align, bool aligned_a, bool aligned_b, bool trans_a, bool trans_b, + __gm__ uint8_t *gm_dequant_scale, __gm__ uint8_t *gm_dequant_offset, int32_t dequant_group_size) + { + this->BasePadder::SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n, + m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b); + gm_scale = reinterpret_cast<__gm__ bfloat16_t *>(gm_dequant_scale); + if (gm_dequant_offset) { + gm_offset = reinterpret_cast<__gm__ bfloat16_t *>(gm_dequant_offset); + has_offset = true; + } + group_size = dequant_group_size; + group_num = (this->k + group_size - 1) / group_size; + } + + inline __aicore__ void Run() + { + if (aligned_a) { + int n_rows = this->trans_a ? this->k : this->m; + int n_cols = this->trans_a ? this->m : this->k; + int n_cols_aligned = this->trans_a ? this->m_align : this->k_align; + + this->PadMatrix(this->gm_a_align, this->gm_a, n_rows, n_cols, n_cols_aligned); + } + + SetFlag(EVENT_ID1); + WaitFlag(EVENT_ID1); + + if (!trans_b && !has_offset) { + DequantAndPadMatrixNoOffset(this->gm_b_align, this->gm_b, this->k, this->n, this->n_align); + } else if (!trans_b && has_offset) { + DequantAndPadMatrixHasOffset(this->gm_b_align, this->gm_b, this->k, this->n, this->n_align); + } else if (trans_b && !has_offset) { + DequantAndPadMatrixTransposeNoOffset(this->gm_b_align, this->gm_b, this->n, this->k, this->k_align); + } else { + DequantAndPadMatrixTransposeHasOffset(this->gm_b_align, this->gm_b, this->n, this->k, this->k_align); + } + + this->Barrier(); + } + +private: + inline __aicore__ void DequantAndPadMatrixNoOffset(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 10240; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)10496); + auto ub_quant_scale_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)51712); + auto ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)72192); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)113152); + auto ub_quant_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)133632); + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)174592); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto scale = gm_scale + it.n_cols_complete; + + int32_t n_blocks_per_row_b16 = + Block32B::Count(it.n_cols_this_loop) * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = + Block32B::Count(it.n_cols_this_loop) * (sizeof(float32_t) / sizeof(int8_t)); + uint8_t quant_repeat_b32 = static_cast( + DivCeil(n_blocks_per_row_b32, VEC_BLOCK_PER_REPEAT)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + int32_t ub_quant_args_root_offset = 0; + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + for (int32_t row = 0; row < max_rows_per_loop; ++row) { + int32_t row_idx = it.row_offset_this_core + it.n_rows_complete + row; + int32_t in_group_idx = row_idx % group_size; + if (in_group_idx == 0 || it.n_rows_complete + row == 0) { + int32_t ub_quant_args_offset = row * n_blocks_per_row_b16 * Block32B::size; + int32_t group_idx = row_idx / group_size; + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale_origin + ub_quant_args_offset, scale + group_idx * n_cols, + 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + Vconv(ub_quant_scale + ub_quant_args_offset, + ub_quant_scale_origin + ub_quant_args_offset, quant_repeat_b32, 1, 1, 8, 4); + SetFlag(EVENT_ID0); + + ub_quant_args_root_offset = ub_quant_args_offset; + PipeBarrier(); + } else if (in_group_idx < max_rows_per_loop || it.n_rows_complete == 0) { + int32_t ub_quant_args_offset = row * n_blocks_per_row_b32 * Block32B::size; + CopyUB2UB(ub_quant_scale + ub_quant_args_offset, + ub_quant_scale + ub_quant_args_root_offset, /* sid */ 0, + 1, n_blocks_per_row_b32, 0, 0); + } + } + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vmul(ub_mul, ub_vconv_f32, ub_quant_scale, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vconv(ub_output, ub_mul, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, 
ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID2); + } + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixHasOffset(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 8512; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? n_cols : MAX_LEN; + + auto ub_quant_offset = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)34048); + auto ub_add = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)68096); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)68096); // multiplex ub_add + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)102144); + auto ub_quant_offset_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)119168); + auto ub_quant_scale_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)136192); + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)153216); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)162560); + auto ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)162560); // multiplex ub_vconv_f32 + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + SetFlag(EVENT_ID3); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto scale = gm_scale + it.n_cols_complete; + auto offset = gm_offset + it.n_cols_complete; + + int32_t n_blocks_per_row_b16 = + Block32B::Count(it.n_cols_this_loop) * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = + Block32B::Count(it.n_cols_this_loop) * (sizeof(float32_t) / sizeof(int8_t)); + uint8_t quant_repeat_b32 = static_cast( + DivCeil(n_blocks_per_row_b32, VEC_BLOCK_PER_REPEAT)); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + int32_t ubuf_gap = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + int32_t ub_quant_args_root_offset = 0; + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + for (int32_t row = 0; row < max_rows_per_loop; ++row) { + int32_t row_idx = it.row_offset_this_core + it.n_rows_complete + row; + int32_t in_group_idx = row_idx % group_size; + if (in_group_idx == 0 || it.n_rows_complete + row == 0) { + int32_t ub_quant_args_offset = row * n_blocks_per_row_b16 * Block32B::size; + int32_t group_idx = row_idx / group_size; + WaitFlag(EVENT_ID0); + CopyGmToUbufAlign(ub_quant_scale_origin + ub_quant_args_offset, scale + group_idx * n_cols, + 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID0); + Vconv(ub_quant_scale + ub_quant_args_offset, + ub_quant_scale_origin + ub_quant_args_offset, quant_repeat_b32, 1, 1, 8, 
4); + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_quant_offset_origin + ub_quant_args_offset, + offset + group_idx * n_cols, 1, it.n_cols_this_loop, 0); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_quant_offset + ub_quant_args_offset, + ub_quant_offset_origin + ub_quant_args_offset, quant_repeat_b32, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + ub_quant_args_root_offset = ub_quant_args_offset; + PipeBarrier(); + } else if (in_group_idx < max_rows_per_loop || it.n_rows_complete == 0) { + int32_t ub_quant_args_offset = row * n_blocks_per_row_b32 * Block32B::size; + CopyUB2UB(ub_quant_scale + ub_quant_args_offset, + ub_quant_scale + ub_quant_args_root_offset, /* sid */ 0, + 1, n_blocks_per_row_b32, 0, 0); + CopyUB2UB(ub_quant_offset + ub_quant_args_offset, + ub_quant_offset + ub_quant_args_root_offset, /* sid */ 0, + 1, n_blocks_per_row_b32, 0, 0); + } + } + + WaitFlag(EVENT_ID2); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID2); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vadd(ub_add, ub_vconv_f32, ub_quant_offset, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + Vmul(ub_mul, ub_add, ub_quant_scale, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID3); + Vconv(ub_output, ub_mul, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID3); + + WaitFlag(EVENT_ID3); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap); + SetFlag(EVENT_ID3); + } + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + WaitFlag(EVENT_ID3); + } + + inline __aicore__ void DequantAndPadMatrixTransposeNoOffset(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 10240; + int32_t n_cols_round = Block32B::AlignUp(n_cols); + int32_t max_rows_per_loop = (n_cols_round <= MAX_LEN) ? (MAX_LEN / n_cols_round) : 1; + int32_t max_cols_per_loop = (n_cols_round <= MAX_LEN) ? 
n_cols : MAX_LEN; + + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)0); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)10496); + auto ub_quant_scale_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)51712); + auto ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)72192); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)113152); + auto ub_quant_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)133632); + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)174592); + + int32_t group_block = Block32B::Count(group_size); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto scale = gm_scale + (it.row_offset_this_core + it.n_rows_complete) * group_num; + + int32_t n_blocks_per_row_b8 = Block32B::Count(max_cols_per_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = n_blocks_per_row_b8 * (sizeof(float32_t) / sizeof(int8_t)); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + int32_t ub_quant_args_root_offset = 0; + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + + int32_t ubuf_gap_b8 = n_blocks_per_row_b8 - Block32B::Count(it.n_cols_this_loop); + int32_t ubuf_gap_b16 = n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + bool is_after_mte2 = false; + WaitFlag(EVENT_ID0); + for (int32_t block_col = 0; block_col < n_blocks_per_row_b16; ++block_col) { + int32_t block_col_idx = Block32B::Count(it.n_cols_complete) + block_col; + int32_t in_group_idx = block_col_idx % group_block; + if (in_group_idx == 0 || block_col_idx == 0) { + int32_t ub_quant_args_offset = block_col * Block32B::size; + int32_t group_idx = block_col_idx / group_block; + + if (ub_quant_args_offset == ub_quant_args_root_offset && !is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyGmToUbufAlign(ub_quant_scale_origin + ub_quant_args_offset, scale + group_idx, + it.n_rows_this_loop, 1, group_num - 1, n_blocks_per_row_b16 - 1); + is_after_mte2 = true; + ub_quant_args_root_offset = ub_quant_args_offset; + } else if (in_group_idx < n_blocks_per_row_b16 || it.n_cols_complete == 0) { + int32_t ub_quant_args_offset = block_col * Block32B::size; + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyUB2UB(ub_quant_scale_origin + ub_quant_args_offset, + ub_quant_scale_origin + ub_quant_args_root_offset, /* sid */ 0, + it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + is_after_mte2 = false; + } + } + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } else { + PipeBarrier(); + } + Vconv(ub_quant_scale, ub_quant_scale_origin, repeat_b32, 1, 1, 8, 4); + is_after_mte2 = false; + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap, 
ubuf_gap_b8); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vmul(ub_mul, ub_vconv_f32, ub_quant_scale, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vconv(ub_output, ub_mul, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap_b16); + SetFlag(EVENT_ID2); + } + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + inline __aicore__ void DequantAndPadMatrixTransposeHasOffset(__gm__ bfloat16_t *gm_dst, __gm__ int8_t *gm_src, + int32_t n_rows, int32_t n_cols, int32_t n_cols_aligned) + { + LoopIter it(this->batch_size, n_rows, n_cols, n_cols_aligned); + + const int32_t MAX_LEN = 8512; + int32_t max_rows_per_loop = (it.n_rows_this_core * Block32B::size <= MAX_LEN) ? + it.n_rows_this_core : MAX_LEN / Block32B::size; + int32_t max_cols_per_loop = (it.n_rows_this_core * Block32B::size <= MAX_LEN) ? + Block32B::AlignDown(MAX_LEN / it.n_rows_this_core) : Block32B::size; + + auto ub_quant_offset = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)0); + auto ub_quant_scale = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)34048); + auto ub_add = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)68096); + auto ub_vconv_f16 = reinterpret_cast<__ubuf__ float16_t *>((uintptr_t)68096); // multiplex ub_add + auto ub_output = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)102144); + auto ub_quant_offset_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)119168); + auto ub_quant_scale_origin = reinterpret_cast<__ubuf__ bfloat16_t *>((uintptr_t)136192); + auto ub_input = reinterpret_cast<__ubuf__ int8_t *>((uintptr_t)153216); + auto ub_vconv_f32 = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)162560); + auto ub_mul = reinterpret_cast<__ubuf__ float32_t *>((uintptr_t)162560); // multiplex ub_vconv_f32 + + int32_t group_block = Block32B::Count(group_size); + + SetFlag(EVENT_ID0); + SetFlag(EVENT_ID1); + SetFlag(EVENT_ID2); + for (it.InitBatchLoop(); !it.EndBatchLoop(); it.NextBatchLoop()) { + for (it.InitRowLoop(max_rows_per_loop); !it.EndRowLoop(); it.NextRowLoop()) { + auto scale = gm_scale + (it.row_offset_this_core + it.n_rows_complete) * group_num; + auto offset = gm_offset + (it.row_offset_this_core + it.n_rows_complete) * group_num; + + int32_t n_blocks_per_row_b8 = Block32B::Count(max_cols_per_loop); + int32_t n_blocks_per_row_b16 = n_blocks_per_row_b8 * (sizeof(bfloat16_t) / sizeof(int8_t)); + int32_t n_blocks_per_row_b32 = n_blocks_per_row_b8 * (sizeof(float32_t) / sizeof(int8_t)); + + int32_t n_blocks_b16 = it.n_rows_this_loop * n_blocks_per_row_b16; + int32_t n_blocks_b32 = it.n_rows_this_loop * n_blocks_per_row_b32; + uint8_t repeat_b16 = static_cast( + DivCeil(n_blocks_b16, VEC_BLOCK_PER_REPEAT)); + uint8_t repeat_b32 = static_cast( + DivCeil(n_blocks_b32, VEC_BLOCK_PER_REPEAT)); + + int32_t ub_quant_args_root_offset = 0; + for (it.InitColLoop(max_cols_per_loop); !it.EndColLoop(); it.NextColLoop()) { + auto src = gm_src + it.src_offset(); + auto dst = gm_dst + it.dst_offset(); + + int32_t src_gap = n_cols - it.n_cols_this_loop; + int32_t dst_gap = n_cols_aligned - it.n_cols_this_loop; + + int32_t ubuf_gap_b8 = n_blocks_per_row_b8 - Block32B::Count(it.n_cols_this_loop); + int32_t ubuf_gap_b16 = 
n_blocks_per_row_b16 - Block32B::Count(it.n_cols_this_loop); + + bool is_after_mte2 = false; + WaitFlag(EVENT_ID0); + for (int32_t block_col = 0; block_col < n_blocks_per_row_b16; ++block_col) { + int32_t block_col_idx = Block32B::Count(it.n_cols_complete) + block_col; + int32_t in_group_idx = block_col_idx % group_block; + if (in_group_idx == 0 || block_col_idx == 0) { + int32_t ub_quant_args_offset = block_col * Block32B::size; + int32_t group_idx = block_col_idx / group_block; + + if (ub_quant_args_offset == ub_quant_args_root_offset && !is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyGmToUbufAlign(ub_quant_scale_origin + ub_quant_args_offset, scale + group_idx, + it.n_rows_this_loop, 1, group_num - 1, n_blocks_per_row_b16 - 1); + CopyGmToUbufAlign(ub_quant_offset_origin + ub_quant_args_offset, offset + group_idx, + it.n_rows_this_loop, 1, group_num - 1, n_blocks_per_row_b16 - 1); + is_after_mte2 = true; + ub_quant_args_root_offset = ub_quant_args_offset; + } else if (in_group_idx < n_blocks_per_row_b16 || it.n_cols_complete == 0) { + int32_t ub_quant_args_offset = block_col * Block32B::size; + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } + CopyUB2UB(ub_quant_scale_origin + ub_quant_args_offset, + ub_quant_scale_origin + ub_quant_args_root_offset, /* sid */ 0, + it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + CopyUB2UB(ub_quant_offset_origin + ub_quant_args_offset, + ub_quant_offset_origin + ub_quant_args_root_offset, /* sid */ 0, + it.n_rows_this_loop, 1, n_blocks_per_row_b16 - 1, n_blocks_per_row_b16 - 1); + is_after_mte2 = false; + } + } + + if (is_after_mte2) { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + is_after_mte2 = false; + } else { + PipeBarrier(); + } + Vconv(ub_quant_scale, ub_quant_scale_origin, repeat_b32, 1, 1, 8, 4); + Vconv(ub_quant_offset, ub_quant_offset_origin, repeat_b32, 1, 1, 8, 4); + is_after_mte2 = false; + SetFlag(EVENT_ID0); + + WaitFlag(EVENT_ID1); + CopyGmToUbufAlign(ub_input, src, it.n_rows_this_loop, it.n_cols_this_loop, src_gap, ubuf_gap_b8); + SetFlag(EVENT_ID1); + + WaitFlag(EVENT_ID1); + Vconv(ub_vconv_f16, ub_input, repeat_b16, 1, 1, 8, 4); + SetFlag(EVENT_ID1); + + PipeBarrier(); + Vconv(ub_vconv_f32, ub_vconv_f16, repeat_b32, 1, 1, 8, 4); + + PipeBarrier(); + Vadd(ub_add, ub_vconv_f32, ub_quant_offset, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + Vmul(ub_mul, ub_add, ub_quant_scale, repeat_b32, 1, 1, 1, 8, 8, 8); + + PipeBarrier(); + WaitFlag(EVENT_ID2); + Vconv(ub_output, ub_mul, repeat_b32, 1, 1, 4, 8, RoundMode::CAST_RINT); + SetFlag(EVENT_ID2); + + WaitFlag(EVENT_ID2); + CopyUbufToGmAlign(dst, ub_output, it.n_rows_this_loop, it.n_cols_this_loop, dst_gap, ubuf_gap_b16); + SetFlag(EVENT_ID2); + } + } + } + WaitFlag(EVENT_ID0); + WaitFlag(EVENT_ID1); + WaitFlag(EVENT_ID2); + } + + __gm__ bfloat16_t *gm_scale{ nullptr }; + __gm__ bfloat16_t *gm_offset{ nullptr }; + int32_t group_size; + int32_t group_num; + bool has_offset{ false }; +}; + + +template +class Preprocessor { +public: + __aicore__ explicit Preprocessor() = default; + + FORCE_INLINE_AICORE void SetArgs(PP_MATMUL_AIV_PADDING_ARGS_FUN()) + { + this->is_int8 = is_int8; + this->dequant_granularity = dequant_granularity; + + int32_t m_align = is_int8 ? Block512B::AlignUp(m) : Block512B::AlignUp(m); + int32_t k_align = is_int8 ? Block512B::AlignUp(k) : Block512B::AlignUp(k); + int32_t n_align = is_int8 ? 
Block512B<int8_t>::AlignUp(n) : Block512B<T>::AlignUp(n);
+
+        int32_t aligned_a, aligned_b;
+        AlignJudge(trans_a, trans_b, m, k, n, m_align, k_align, n_align, aligned_a, aligned_b);
+
+        bool has_a_align = IsQuant(quant_granularity) || aligned_a;
+        bool has_b_align = (IsQuant(dequant_granularity) && !is_int8) || aligned_b;
+        bool has_accum = IsQuant(dequant_granularity) && is_int8 && std::is_same<T, bfloat16_t>::value;
+        bool has_dequant_param = (dequant_granularity == QuantGranularity::PER_TOKEN || dequant_granularity == QuantGranularity::PER_TENSOR);
+        bool hasFormatDequantScale = (has_dequant_param || dequant_granularity == QuantGranularity::PER_CHANNEL);
+
+        if (weight_nz) {
+            aligned_b = 0;
+            has_b_align = false;
+        }
+        LcalWorkspaceInfo workspace_info = GetLcalWorkspaceInfo(gm_workspace, batch_size, m, k, n, m_align, k_align, n_align,
+            trans_a, trans_b, is_int8 ? 1 : 2, has_a_align, has_b_align, 0, has_accum, 0, has_dequant_param,
+            hasFormatDequantScale, is_deterministic, is_moe, is_alltoallvc, EP, local_expert_nums, m * EP * TP);
+
+        if (this->is_int8) {
+            switch (this->dequant_granularity) {
+                case QuantGranularity::PER_TENSOR:
+                    padder_int8.SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n,
+                        m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b,
+                        gm_dequant_offset, dequant_granularity);
+                    return;
+                case QuantGranularity::PER_CHANNEL:
+                    padder_int8.SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n,
+                        m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b);
+                    return;
+                case QuantGranularity::PER_TOKEN:
+                    padder_int8.SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n,
+                        m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b);
+                    return;
+                case QuantGranularity::FLOAT32_SCALE_PER_CHANNEL:
+                    padder_int8.SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n,
+                        m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b);
+                    return;
+                default:
+                    return;
+            }
+        }
+        switch (this->dequant_granularity) {
+            case QuantGranularity::PER_TENSOR:
+                dequant_per_tensor_padder.SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n,
+                    m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b,
+                    gm_dequant_scale, gm_dequant_offset);
+                return;
+            case QuantGranularity::PER_CHANNEL:
+                dequant_per_channel_padder.SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n,
+                    m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b,
+                    gm_dequant_scale, gm_dequant_offset);
+                return;
+            case QuantGranularity::PER_GROUP:
+                dequant_per_group_padder.SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n,
+                    m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b,
+                    gm_dequant_scale, gm_dequant_offset, dequant_group_size);
+                return;
+            default:
+                padder.SetArgs(gm_a, gm_b, workspace_info, batch_size, m, k, n,
+                    m_align, k_align, n_align, aligned_a, aligned_b, trans_a, trans_b);
+                return;
+        }
+    }
+
+    FORCE_INLINE_AICORE void Run(int32_t expert_per_rank = 1)
+    {
+        if (this->is_int8) {
+            padder_int8.Run(expert_per_rank);
+            return;
+        }
+        switch (this->dequant_granularity) {
+            case QuantGranularity::PER_TENSOR:
+                dequant_per_tensor_padder.Run();
+                return;
+            case QuantGranularity::PER_CHANNEL:
+                dequant_per_channel_padder.Run();
+                return;
+            case QuantGranularity::PER_GROUP:
+                dequant_per_group_padder.Run();
+                return;
+            default:
+                padder.Run(expert_per_rank);
+                return;
+        }
+    }
+
+private:
+    Padder<T> padder;
+    Padder<int8_t> padder_int8;
+
+    DequantPadder<T, QuantGranularity::PER_TENSOR> dequant_per_tensor_padder;
+    DequantPadder<T, QuantGranularity::PER_CHANNEL> dequant_per_channel_padder;
+    DequantPadder<T, QuantGranularity::PER_GROUP> dequant_per_group_padder;
+    bool 
is_int8; + QuantGranularity dequant_granularity; +}; + +#endif + +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_pure_matmul.cce b/comm/lcal/src/kernels/coc_pure_matmul.cce new file mode 100644 index 0000000000000000000000000000000000000000..9922690d93711e54e43754f37031a00b3677910a --- /dev/null +++ b/comm/lcal/src/kernels/coc_pure_matmul.cce @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifdef __CCE_KT_TEST__ +#define __aicore__ +#else +#define __aicore__ [aicore] +#endif + +#include "coc_ppmatmul_switch.cce" + +#ifdef __DAV_C220_CUBE__ +// LcalPureMatmul +#define COC_PURE_MATMUL_FUNC_AUTO_DEF(type) \ +extern "C" __global__ __aicore__ void LcalPureMatmul_##type##_mix_aic(COC_ARGS_FUN(type)) { \ + CocPpmatmulSwitchAic(COC_ARGS_CALL()); \ +} + + +#elif __DAV_C220_VEC__ +// LcalPureMatmul_Align +#define COC_PURE_MATMUL_FUNC_AUTO_DEF(type) \ +extern "C" __global__ __aicore__ void LcalPureMatmul_##type##_mix_aiv(COC_ARGS_FUN(type)) { \ + CocPureMatmulAiv(COC_ARGS_CALL()); \ +} +#endif + +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // 910B support bf16 +#define COC_TYPE_FUNC(fun) fun(float16_t);fun(bfloat16_t) + +COC_TYPE_FUNC(COC_PURE_MATMUL_FUNC_AUTO_DEF); +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/coc_reduce_scatter.cce b/comm/lcal/src/kernels/coc_reduce_scatter.cce new file mode 100644 index 0000000000000000000000000000000000000000..ecf35301f05b39688935ba71d1fa0cdbbd13e147 --- /dev/null +++ b/comm/lcal/src/kernels/coc_reduce_scatter.cce @@ -0,0 +1,526 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#ifdef __DAV_C220_VEC__
+#include "coc_internal.cce"
+#include "coc_comm_base.cce"
+#include "kernel_operator.h"
+using namespace AscendC;
+
+template
+class ReduceScatter: public CocCommBase {
+public:
+    __aicore__ explicit ReduceScatter() {};
+
+    FORCE_INLINE_AICORE void SetArgs(COC_ARGS_FUN(T)) {
+        CocCommBase::SetArgsForReduce(COC_ARGS_CALL());
+        preprocessor.SetArgs(PP_MATMUL_AIV_PADDING_ARGS_CALL());
+        if constexpr (HAVE_BIAS) {
+            add_bias_runner.SetArgs(PP_MATMUL_AIV_ADD_BIAS_ARGS_CALL());
+        }
+        int32_t tail_m = (m / rank_size) % m0;
+        m_loop = m / rank_size / m0;
+        if (tail_m) {
+            m_loop += 1;
+        }
+        m_loop *= rank_size;
+        core_loop = batch_size * m_loop * n_loop;
+        cal_count = (core_loop + loop_num_per_comm - 1) / loop_num_per_comm; // each communication round covers cal_count compute loops
+
+        need_dequant = workspace_info.gm_accum;
+        if (need_dequant) {
+            fused_dequant_runner.SetArgs(reinterpret_cast<__gm__ bfloat16_t *>(buff[rank]), workspace_info,
+                reinterpret_cast<__gm__ int64_t *>(gm_dequant_scale),
+                reinterpret_cast<__gm__ int32_t *>(gm_dequant_offset), dequant_granularity,
+                batch_size, m, n, m0, n0, m_loop, n_loop, core_loop, swizzl_direct,
+                swizzl_count, p_value, rank_size);
+        }
+        if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+            fused_pertoken_dequant_runner.SetArgs(reinterpret_cast<__gm__ T *>(buff[rank]),
+                reinterpret_cast<__gm__ float32_t *>(gm_quant_scale), m, n,
+                m0, n0, m_loop, n_loop, core_loop, swizzl_direct, swizzl_count, p_value, rank_size);
+        }
+    }
+
+    FORCE_INLINE_AICORE void StartBeforeFisrtStep(bool needAdd)
+    {
+        if (needAdd) {
+            SetAtomicAdd();
+            PipeBarrier();
+        }
+
+        SetFlag(EVENT_ID0); // MTE2 waits for MTE3
+        SetFlag(EVENT_ID1); // MTE2 waits for MTE3
+    }
+
+    FORCE_INLINE_AICORE void EndFirstStep(bool needAdd) {
+        WaitFlag(EVENT_ID0); // MTE2 waits for MTE3
+        WaitFlag(EVENT_ID1); // MTE2 waits for MTE3
+
+        if (needAdd) {
+            SetFlag(EVENT_ID0); // Scalar waits for MTE3
+            WaitFlag(EVENT_ID0);
+            SetAtomicNone();
+            PipeBarrier();
+        }
+    }
+
+    FORCE_INLINE_AICORE void UbufToGm(int32_t m_offset, int32_t actual_m,
+                                      int32_t& actual_move_m, int32_t left_m,
+                                      int64_t batch_idx, int64_t m_idx, int64_t n_idx,
+                                      __ubuf__ T *ub_buff, int32_t actual_n)
+    {
+        // m0 - m_offset is the small remainder left in the current block; skip it
+        if (m_offset < actual_m) {
+            actual_move_m = actual_m < m_offset + left_m ? actual_m - m_offset : left_m;
+            // if left_m is large, finish copying this block and copy the next block next time;
+            // if left_m is small, copy only the left_m part
+            int64_t out_buff_offset = batch_idx * m * n / rank_size +
+                (m_idx * m0 + m_offset) * n + n_idx * n0;
+            CopyUbufToGmUnknown(ALIGN, gm_out + out_buff_offset, ub_buff, actual_move_m,
+                actual_n * sizeof(T), (n0 - actual_n) * sizeof(T) / 32,
+                (n - actual_n) * sizeof(T));
+        }
+    }
+    /*
+        Accumulate peer-mem data from the src rank into this rank's local GM, with on-the-fly layout conversion.
+        int32_t data_size_remain: amount of data to copy
+        __gm__ T *input: address on the src rank
+        int32_t offset: offset on the src rank
+        int32_t loop_idx_st: loop idx before the offset, used to compute this rank's output position
+    */
+    FORCE_INLINE_AICORE void FirstStepInOut(int32_t data_size_remain, __gm__ T *input,
+        int32_t gm_offset, int32_t move_offset, int32_t loop_idx_st)
+    {
+        int32_t ping_pong_move_count = (data_size_remain + max_ub_ping_pong_size - 1) / max_ub_ping_pong_size; // max_ub_ping_pong_size is always a multiple of N0, but not necessarily of M0*N0
+
+        for (int32_t move_idx = 0; move_idx < ping_pong_move_count; ++move_idx) {
+            int32_t actual_move_size = max_ub_ping_pong_size;
+            if (move_idx == ping_pong_move_count - 1) {
+                actual_move_size = data_size_remain - move_idx * max_ub_ping_pong_size;
+            }
+            auto event_id = (move_idx & 1) ? EVENT_ID0 : EVENT_ID1;
+            auto ub_buff_st = (move_idx & 1) ? output_UB_T[0] : output_UB_T[1];
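+            // Ping-pong scheme: even/odd move_idx alternate between the two UB buffers and the
+            // two event IDs, so the MTE2 copy-in of one chunk can overlap with the MTE3
+            // copy-out of the previous chunk.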
+            WaitFlag(event_id);
+            // the matrix being read is laid out as many small m0*n0 blocks in sequence; it must be rearranged when written out
+            CopyGmToUbuf(ub_buff_st, input + gm_offset + move_idx * max_ub_ping_pong_size, 1,
+                actual_move_size * sizeof(T) / 32, 0, 0);
+            SetFlag(event_id);
+            WaitFlag(event_id);
+
+            int32_t move_num_offset = move_offset + move_idx * max_ub_ping_pong_size;
+            auto ub_buff = ub_buff_st;
+            int32_t left_m = actual_move_size / n0;
+            while (left_m > 0) {
+                int32_t loop_idx = loop_idx_st + (move_num_offset / (m0 * n0)) * rank_size;
+                int64_t batch_idx = loop_idx / (m_loop * n_loop);
+                int32_t in_batch_idx = loop_idx % (m_loop * n_loop);
+                int32_t in_rank_idx = in_batch_idx / rank_size;
+                int64_t m_idx, n_idx;
+                GetBlockIdx(in_rank_idx, m_loop / rank_size, n_loop, swizzl_direct, swizzl_count, m_idx, n_idx);
+                int32_t actual_m = (m_idx == (m_loop / rank_size - 1)) ? (m / rank_size - m_idx * m0) : m0;
+                int32_t actual_n = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
+                int32_t m_offset = (move_num_offset % (m0 * n0)) / n0; // m position, within the current block, of the current chunk's start
+                int32_t actual_move_m = m0 < m_offset + left_m ? m0 - m_offset : left_m;
+                // m0 - m_offset is the small remainder left in the current block; skip it
+                if (m_offset < actual_m) {
+                    actual_move_m = actual_m < m_offset + left_m ? actual_m - m_offset : left_m;
+                    // if left_m is large, finish copying this block and copy the next block next time;
+                    // if left_m is small, copy only the left_m part
+                    int64_t out_buff_offset = batch_idx * m * n / rank_size + (m_idx * m0 + m_offset) * n + n_idx * n0;
+                    CopyUbufToGmUnknown(ALIGN, gm_out + out_buff_offset, ub_buff, actual_move_m, actual_n * sizeof(T),
+                        (n0 - actual_n) * sizeof(T) / 32, (n - actual_n) * sizeof(T));
+                }
+                left_m -= actual_move_m;
+                move_num_offset += actual_move_m * n0;
+                ub_buff += actual_move_m * n0;
+            }
+            SetFlag(event_id);
+        }
+    }
+
+    FORCE_INLINE_AICORE void FirstStepInOutWithSplit(int32_t rank_total, int32_t rank_offset,
+        int32_t loop_idx_st, int32_t data_loop_idx, bool isSio)
+    {
+        int32_t rank_per_core = isSio ? rank_size / 2 / comm_npu_split : rank_size / comm_npu_split;
+        int32_t before_core_offset = data_loop_idx * comm_data_split * len_per_loop;
+        int32_t core_rank_offset = (core_idx / comm_data_split) * rank_per_core;
+        int32_t core_offset = core_idx % comm_data_split * len_per_loop;
+        int32_t loop_total = rank_total - before_core_offset;
+
+        int32_t rank_buff_offset = rank_offset + before_core_offset + core_offset;
+
+        int32_t m_in_core = (core_offset >= loop_total) ? 0 :
+            ((core_offset + len_per_loop) > loop_total ?
+            loop_total - core_offset : len_per_loop);
+
+        for (int32_t rank_idx = 0; rank_idx < rank_per_core; rank_idx++) {
+            // On some servers GM starts out holding dirty data, and in the reduceScatter per-token quant
+            // scenario all of the matmul output is written to peerMem. The aiv writes to GM with atomic
+            // add, which would accumulate onto that dirty data and corrupt the result.
+            // So in the per-token quant scenario the first copy overwrites instead of accumulating,
+            // and accumulation only starts from the second copy.
+            if ((is_int8 && (dequant_granularity == QuantGranularity::PER_TOKEN || std::is_same::value)) && !isSio && (rank_idx == 1)) {
+                SetAtomicAdd();
+                PipeBarrier();
+            }
+            int32_t rank_idx_rot = (rank_idx + core_idx) % rank_per_core;
+            int32_t real_rank_idx = core_rank_offset + rank_idx_rot;
+
+            real_rank_idx = isSio ? 2 * real_rank_idx + (rank % 2) : real_rank_idx;
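+            // With isSio, the per-core rank index is mapped onto ranks with the same parity as
+            // this rank (SIO pairs adjacent even/odd NPUs); e.g. an even rank walks ranks 0, 2, 4, ...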
+
+            if (real_rank_idx == rank && !need_dequant && dequant_granularity != QuantGranularity::PER_TOKEN)
+                continue;
+
+            FirstStepInOut(m_in_core, buff[real_rank_idx], rank_buff_offset,
+                before_core_offset + core_offset, loop_idx_st);
+        }
+
+        if ((is_int8 && (dequant_granularity == QuantGranularity::PER_TOKEN || std::is_same::value)) && !isSio) {
+            SetFlag(EVENT_ID0); // Scalar waits for MTE3
+            WaitFlag(EVENT_ID0);
+            SetAtomicNone();
+            PipeBarrier();
+        }
+    }
+
+    FORCE_INLINE_AICORE void RunLegacy()
+    {
+        // Padding
+        preprocessor.Run();
+
+        ResetIpcFlags(2);
+        PipeBarrier();
+
+        // initially notify the aic that the shared memory is free
+        int32_t max_flag_id = cal_count < MAX_BLOCK_COUNT ? cal_count : MAX_BLOCK_COUNT;
+        for (int64_t cal_idx = 0; cal_idx < max_flag_id; ++cal_idx) {
+            if (cal_idx * loop_num_per_comm + core_idx < core_loop) {
+                SetAicSync(cal_idx);
+            }
+        }
+        for (int32_t cal_idx = 0; cal_idx < cal_count; ++cal_idx) {
+            uint64_t flag_idx = cal_idx % MAX_BLOCK_COUNT;
+            int32_t actual_loop_num =
+                (cal_idx == cal_count - 1) ? (core_loop - cal_idx * loop_num_per_comm) : loop_num_per_comm;
+
+            m_per_rank = actual_loop_num * m0 / rank_size;
+            // wait aic
+            if (core_idx < actual_loop_num) {
+                WaitEvent(flag_idx);
+            }
+            if (need_dequant) {
+                //fused_dequant_runner.Run(cal_idx);
+                fused_dequant_runner.RunDequantReduceScatter(cal_idx);
+            }
+            if (dequant_granularity == QuantGranularity::PER_TOKEN) {
+                SetAndWaitAivSync(flag_idx);
+                //fused_pertoken_dequant_runner.Run(cal_idx);
+                fused_pertoken_dequant_runner.RunDequantReduceScatter(cal_idx);
+            }
+            // sync among aivs
+            SetAndWaitAivSync(flag_idx);
+
+            CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1);
+
+            SetAndWaitAivSync(flag_idx);
+            bool needAdd = !(is_int8 && (dequant_granularity == QuantGranularity::PER_TOKEN || std::is_same::value));
+            StartBeforeFisrtStep(needAdd);
+
+            int32_t m_per_core = (m_per_rank * n0) / comm_data_split;
+            int32_t data_split_num = DivCeil(m_per_core, len_per_loop);
+
+            int32_t rank_offset = flag_idx * m0 * n0 * loop_num_per_comm + rank * m_per_rank * n0;
+            for (int32_t loop_idx = 0; loop_idx < data_split_num; loop_idx++) {
+                if (aiv_idx == 0 && core_idx < comm_npu_split * comm_data_split) {
+                    FirstStepInOutWithSplit(m_per_rank * n0, rank_offset, cal_idx * loop_num_per_comm, loop_idx, false);
+                }
+            }
+
+            EndFirstStep(needAdd);
+            SetAndWaitAivSync(flag_idx);
+
+            CrossRankSyncV2(FLAG_ONE_IDX, cal_idx + 1);
+            // sync among aivs
+            SetAndWaitAivSync(flag_idx);
+
+            // send aic sync
+            SetAicSync(flag_idx);
+        }
+
+        ResetIpcFlags(2);
+
+        if (aiv_idx == 1 && core_idx < rank_size) {
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[other_rank] + flag_offset + FLAG_ZERO_IDX, 0);
+        }
+        PipeBarrier();
+
+        if constexpr (HAVE_BIAS) {
+            add_bias_runner.Run();
+        }
+    }
+
+    FORCE_INLINE_AICORE void DataCopySio(int32_t cal_idx_sio, int32_t len_per_rank)
+    {
+        int32_t flag_idx_sio = cal_idx_sio % BLOCK_COUNT_3;
+        int32_t len_per_core = len_per_rank / SIO_TOTAL_CORE_NUM;
+        int32_t sio_core_idx = core_idx - core_count;
+        int32_t core_offset = sio_core_idx * len_per_core;
+        int32_t sio_peer_rank = rank ^ 1;
+        int32_t size_per_rank = gm_c_pingpong_size / rank_size;
+        // loop over all ranks; e.g. rank 0 reads parts 0, 2, 4, 6 of rank 1
+
+        for (int32_t src_rank = rank % 2; src_rank < rank_size; src_rank += 2) {
+            int32_t peer_offset = flag_idx_sio * gm_c_pingpong_size + src_rank * size_per_rank + core_offset;
+            if (src_rank == rank) { // e.g. rank 0 reads rank 1's part 0 and stores it straight back to local
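+                // src_rank == rank: this rank's own slice is read from the SIO peer's buffer and
+                // accumulated straight into the local output; StartBeforeFisrtStep(true) enables
+                // atomic add around the copy.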
+                StartBeforeFisrtStep(true);
+                FirstStepInOut(len_per_core,
+                    buff[sio_peer_rank] + flag_idx_sio * gm_c_pingpong_size + src_rank * size_per_rank,
+                    core_offset, core_offset, cal_idx_sio * loop_num_per_comm);
+                EndFirstStep(true);
+            } else { // e.g. rank 0 reads rank 1's parts 2, 4, 6 and stores them back to the same position in peerMem
+                FirstStepInPeerMem(len_per_core, buff[sio_peer_rank] + peer_offset, buff[rank] + peer_offset, true);
+            }
+        }
+    }
+
+    FORCE_INLINE_AICORE void RunWithSio()
+    {
+        // Padding
+        preprocessor.Run();
+
+        ResetIpcFlags(2);
+        PipeBarrier();
+
+        // initially notify the aic that the shared memory is free
+        int32_t max_flag_id = cal_count < BLOCK_COUNT_3 ? cal_count : BLOCK_COUNT_3;
+        int32_t size_per_rank = gm_c_pingpong_size / rank_size;
+        for (int64_t cal_idx = 0; cal_idx < max_flag_id; ++cal_idx) {
+            SetAicSync(cal_idx);
+        }
+        int32_t tile_per_rank = loop_num_per_comm / rank_size;
+        for (int32_t cal_idx = 0; cal_idx < cal_count + 1; ++cal_idx) {
+            uint64_t flag_idx = cal_idx % BLOCK_COUNT_3;
+            int32_t hccs_idx = cal_idx - 1; // sio first, then hccs
+            int32_t flag_idx_hccs = hccs_idx % BLOCK_COUNT_3;
+            int32_t tile_per_rank_sio =
+                (cal_idx == cal_count - 1) ? (core_loop - cal_idx * loop_num_per_comm) / rank_size : tile_per_rank;
+            int32_t tile_per_rank_hccs =
+                (hccs_idx == cal_count - 1) ? (core_loop - hccs_idx * loop_num_per_comm) / rank_size : tile_per_rank;
+
+            // wait aic
+            if (cal_idx < cal_count) {
+                WaitEvent(flag_idx);
+            }
+
+            // sync among aivs
+            SetAndWaitAivSync(flag_idx, BLOCK_COUNT_3);
+
+            CrossRankSyncV1(FLAG_ZERO_IDX, cal_idx + 1);
+            SetAndWaitAivSync(flag_idx, BLOCK_COUNT_3);
+            // the last SIO_TOTAL_CORE_NUM cores handle the SIO copies
+            if (aiv_idx == 0 && core_idx >= core_count &&
+                core_idx < core_count + SIO_TOTAL_CORE_NUM && cal_idx < cal_count) { // MoveSio
+                DataCopySio(cal_idx, tile_per_rank_sio * m0 * n0);
+            }
+
+            StartBeforeFisrtStep(true);
+            int32_t m_per_core = tile_per_rank_hccs * m0 * n0 / comm_data_split;
+            int32_t data_split_num = DivCeil(m_per_core, len_per_loop);
+
+            for (int32_t loop_idx = 0; loop_idx < data_split_num; loop_idx++) {
+                if (aiv_idx == 0 && core_idx < comm_npu_split * comm_data_split && cal_idx >= 1) { // hccs copies start from the second round
+                    FirstStepInOutWithSplit(tile_per_rank_hccs * m0 * n0,
+                        flag_idx_hccs * gm_c_pingpong_size + rank * size_per_rank,
+                        hccs_idx * loop_num_per_comm, loop_idx, true);
+                }
+            }
+            EndFirstStep(true);
+
+            SetAndWaitAivSync(flag_idx, BLOCK_COUNT_3);
+            CrossRankSyncV2(FLAG_ONE_IDX, cal_idx + 1);
+            // sync among aivs
+            SetAndWaitAivSync(flag_idx, BLOCK_COUNT_3);
+
+            // send aic sync
+            if (cal_idx >= 1) {
+                SetAicSync(flag_idx_hccs);
+            }
+        }
+
+        ResetIpcFlags(2);
+
+        if (aiv_idx == 1 && core_idx < rank_size) {
+            CheckBuffFlag(ctrl_flags_UB, (__gm__ int32_t *)buff[other_rank] + flag_offset + FLAG_ZERO_IDX, 0);
+        }
+        PipeBarrier();
+
+        if constexpr (HAVE_BIAS) {
+            add_bias_runner.Run();
+        }
+    }
+
+    FORCE_INLINE_AICORE void Run()
+    {
+        if (is_91093) {
+            RunWithSio();
+        } else {
+            RunLegacy();
+        }
+    }
+
+public:
+    using CocCommBase::SetAicSync;
+    using CocCommBase::SetAndWaitAivSync;
+    using CocCommBase::SetBuffFlag;
+    using CocCommBase::SetBuffFlagByAdd;
+    using CocCommBase::CheckBuffFlag;
+    using CocCommBase::FillZero;
+    using CocCommBase::FirstStepInPeerMem;
+    using CocCommBase::ResetIpcFlags;
+    using CocCommBase::CrossRankSyncV1;
+    using CocCommBase::CrossRankSyncV2;
+    using CocCommBase::buff;
+    using CocCommBase::gm_out;
+    using CocCommBase::ctrl_flags_UB;
+    using CocCommBase::output_UB_T;
+    using CocCommBase::batch_size;
+    using CocCommBase::m;
+    using CocCommBase::k;
+    using CocCommBase::n;
+    using CocCommBase::m0;
+    using CocCommBase::k0;
+    using
CocCommBase::n0; + using CocCommBase::m_loop; + using CocCommBase::n_loop; + using CocCommBase::k_loop; + using CocCommBase::core_loop; + using CocCommBase::core_idx; + using CocCommBase::rank; + using CocCommBase::rank_size; + using CocCommBase::tiling_key; + using CocCommBase::swizzl_count; + using CocCommBase::swizzl_direct; + using CocCommBase::trans_a; + using CocCommBase::trans_b; + using CocCommBase::is_int8; + using CocCommBase::is_91093; + using CocCommBase::p_value; + using CocCommBase::aiv_idx; + using CocCommBase::other_rank; + using CocCommBase::comm_npu_split; + using CocCommBase::comm_data_split; + using CocCommBase::comm_direct; + using CocCommBase::len_per_loop; + using CocCommBase::core_count; + using CocCommBase::max_ub_single_dma_size; + using CocCommBase::max_ub_ping_pong_size; + using CocCommBase::loop_num_per_comm; + using CocCommBase::gm_c_pingpong_size; + using CocCommBase::dequant_granularity; + using CocCommBase::dequant_group_size; + using CocCommBase::quant_granularity; + using CocCommBase::quant_group_size; + using CocCommBase::workspace_info; + using CocCommBase::local_expert_nums; + using CocCommBase::is_moe; + using CocCommBase::is_moe_averaged; + using CocCommBase::is_alltoallvc; + using CocCommBase::is_deterministic; + using CocCommBase::weight_nz; + using CocCommBase::EP; + using CocCommBase::TP; + using CocCommBase::flag_offset; + + int32_t cal_count; + int32_t m_per_rank; + Preprocessor preprocessor; + MatmulReduceScatterBiasAdder add_bias_runner; + //ReduceScatterFusedPerTokenDequantRunner fused_pertoken_dequant_runner; + FusedPerTokenDequantRunner fused_pertoken_dequant_runner; + //FusedReduceScatterDequantRunner fused_dequant_runner; + FusedDequantRunner fused_dequant_runner; + bool need_dequant; +}; + +constexpr int32_t NO_BIAS_MASK2 = 0b000000 | 0b100000 | 0b010000 | 0b110000 | + 0b001000 | 0b101000 | 0b011000 | 0b111000; +constexpr int32_t BIAS_MASK2 = 0b000010 | 0b100010 | 0b010010 | 0b110010 | + 0b001010 | 0b101010 | 0b011010 | 0b111010; + +template +FORCE_INLINE_AICORE void RunReduceScatterAlign16(int32_t tiling_key, COC_ARGS_FUN(T)) { + // 16 align + ReduceScatter reduce_scatter_align_16_without_bias; + ReduceScatter reduce_scatter_align_16_with_bias; + switch (tiling_key) { + case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 : + case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 : + case 0b000100 : case 0b100100 : case 0b010100 : case 0b110100 : + case 0b001100 : case 0b101100 : case 0b011100 : case 0b111100 : + reduce_scatter_align_16_without_bias.SetArgs(COC_ARGS_CALL()); + reduce_scatter_align_16_without_bias.Run(); + break; + case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 : + case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 : + case 0b000110 : case 0b100110 : case 0b010110 : case 0b110110 : + case 0b001110 : case 0b101110 : case 0b011110 : case 0b111110 : + reduce_scatter_align_16_with_bias.SetArgs(COC_ARGS_CALL()); + reduce_scatter_align_16_with_bias.Run(); + break; + default : + break; + } +} + +template +FORCE_INLINE_AICORE void RunReduceScatterUnAlign16(int32_t tiling_key, COC_ARGS_FUN(T)) { + // 16 unalign + ReduceScatter reduce_scatter_unalign_16_without_bias; + ReduceScatter reduce_scatter_unalign_16_with_bias; + switch (tiling_key) { + case 0b000000 : case 0b100000 : case 0b010000 : case 0b110000 : + case 0b001000 : case 0b101000 : case 0b011000 : case 0b111000 : + case 0b000100 : case 0b100100 : case 0b010100 : case 0b110100 : + case 0b001100 : case 0b101100 : case 0b011100 : 
case 0b111100 : + reduce_scatter_unalign_16_without_bias.SetArgs(COC_ARGS_CALL()); + reduce_scatter_unalign_16_without_bias.Run(); + break; + case 0b000010 : case 0b100010 : case 0b010010 : case 0b110010 : + case 0b001010 : case 0b101010 : case 0b011010 : case 0b111010 : + case 0b000110 : case 0b100110 : case 0b010110 : case 0b110110 : + case 0b001110 : case 0b101110 : case 0b011110 : case 0b111110 : + reduce_scatter_unalign_16_with_bias.SetArgs(COC_ARGS_CALL()); + reduce_scatter_unalign_16_with_bias.Run(); + break; + default : + break; + } +} + +template +FORCE_INLINE_AICORE void CocMatmulReduceScatterAiv(COC_ARGS_FUN(T)) { + SetAtomicNone(); + SetMaskNormImpl(); + SetSyncBaseAddr((uint64_t)ffts_addr); + SetVectorMask((uint64_t)-1, (uint64_t)-1); + + auto para = reinterpret_cast<__gm__ Lcal::CoCKernelParam *>(para_gm); + auto cocTilingData = ¶->cocTilingData; + int32_t n = cocTilingData->n; + int32_t tiling_key = cocTilingData->tilingKey; + if (n % BLOCK_SIZE_16 == 0) { + RunReduceScatterAlign16(tiling_key, COC_ARGS_CALL()); + } else { + RunReduceScatterUnAlign16(tiling_key, COC_ARGS_CALL()); + } + PipeBarrier(); +} + +#endif diff --git a/comm/lcal/src/kernels/collectives.cce b/comm/lcal/src/kernels/collectives.cce new file mode 100644 index 0000000000000000000000000000000000000000..401a7c840968a79d293119c47f35c2a93b033253 --- /dev/null +++ b/comm/lcal/src/kernels/collectives.cce @@ -0,0 +1,729 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#ifndef LCAL_COLLECTIVES_H +#define LCAL_COLLECTIVES_H + +#if !defined(__DAV_C220_VEC__) && !defined(__DAV_M200_VEC__) && !defined(__DAV_C220_CUBE__) +#define __aicore__ +#define __ubuf__ +#define __gm__ +#endif + +#include +#include +#include "kernel_operator.h" +#include "comm_args.h" +#include "../ascendc_kernels/datacopy_gm2gm.h" +#include "coc_internal.cce" +using namespace AscendC; +using namespace Lcal; +constexpr int64_t UB_MAX_SIZE = 196608; + +constexpr int64_t MEM_DMA_UNIT_BYTE = 32; + +constexpr int64_t DMA_SIZE_PER_FLAG = UB_SINGLE_DMA_SIZE_MAX; + +constexpr int64_t EXCEPTION_VALUE = -11327; + +constexpr int64_t SIZE_OF_2M = 2 * 1024 * 1024; + +constexpr int64_t SIZE_OF_8M = 8 * 1024 * 1024; + +constexpr int64_t SIZE_OF_1M = 1 * 1024 * 1024; + +constexpr int64_t MAX_RANK_NUM_OF_ONE_910B2C = 16; + +constexpr int64_t MAX_SEND_COUNT_MATRIX_SIZ_OF_ONE_910B2C = MAX_RANK_NUM_OF_ONE_910B2C * MAX_RANK_NUM_OF_ONE_910B2C; + +constexpr int64_t ALL2ALL_V_C_BUFF_SIZE_PER_PARAGRAPH_910B2C = IPC_BUFF_MAX_SIZE / MAX_RANK_NUM_OF_ONE_910B2C / 2 * 2; + +constexpr int64_t DETERMINISTIC_BUFF_SIZE = (IPC_BUFF_MAX_SIZE >> 1) - 4 * 1024; + +#define ALLREDUCE_ARGS_FUN(T) \ +__gm__ T *input, __gm__ T *output, int rank, int rankSize, int64_t len, int64_t magic, int op, int root, \ +int localRankSize, int64_t loopTime, __gm__ int64_t *sendCountMatrix, GM_ADDR dumpAddr, \ +__gm__ T *buff0, __gm__ T *buff1, __gm__ T *buff2, __gm__ T *buff3, __gm__ T *buff4,\ +__gm__ T *buff5, __gm__ T *buff6, __gm__ T *buff7 + +#define ALLREDUCE_ARGS_CALL(type) \ +(__gm__ type *)input, (__gm__ type *) output, rank, rankSize, len, \ +magic, op, root, localRankSize, 0, nullptr, dumpAddr, shareAddrs[0], shareAddrs[1], shareAddrs[2], \ +shareAddrs[3], shareAddrs[4], shareAddrs[5], shareAddrs[6], shareAddrs[7] + +#define ALLREDUCE_ARGS_FUN_16P(T) \ +__gm__ T *input, __gm__ T *output, int rank, int rankSize, int64_t len, int64_t magic, int op, int root, \ +int localRankSize, int64_t loopTime, __gm__ int64_t *sendCountMatrix, GM_ADDR dumpAddr, \ +__gm__ T *buff0, __gm__ T *buff1, __gm__ T *buff2, __gm__ T *buff3, __gm__ T *buff4, \ +__gm__ T *buff5, __gm__ T *buff6, __gm__ T *buff7, __gm__ T *buff8, __gm__ T *buff9, \ +__gm__ T *buff10, __gm__ T *buff11, __gm__ T *buff12, __gm__ T *buff13, __gm__ T *buff14, __gm__ T *buff15 + +#define ALLREDUCE_ARGS_CALL_16P(type) \ +(__gm__ type *)input, (__gm__ type *) output, rank, rankSize, len, \ +magic, op, root, localRankSize, 0, nullptr, dumpAddr, shareAddrs[0], shareAddrs[1], shareAddrs[2], \ +shareAddrs[3], shareAddrs[4], shareAddrs[5], shareAddrs[6], shareAddrs[7], shareAddrs[8], shareAddrs[9], \ +shareAddrs[10], shareAddrs[11], shareAddrs[12], shareAddrs[13], shareAddrs[14], shareAddrs[15] \ + +#define ALLREDUCE_ARGS_FUN_16P_Origin(T) \ +__gm__ T *input, __gm__ T *output, int rank, int rankSize, int64_t len, int64_t magic, int op, int root, \ +int localRankSize, __gm__ int64_t *sendCountMatrix, GM_ADDR dumpAddr, __gm__ T* buff[MAX_RANK_NUM_OF_ONE_910B2C] + +#define ALLREDUCE_ARGS_CALL_16P_Origin() \ +input, output, rank, rankSize, len, magic, op, root, localRankSize, sendCountMatrix, dumpAddr, buff + +#define MODIFIABLE_MAGIC_PROCESSED_NUM_ALLREDUCE_ARGS_CALL_16P_Origin(processedNum, remainNum, magic) \ +(input + (processedNum)), (output + (processedNum)), rank, rankSize, (remainNum), (magic), op, root, \ +localRankSize, sendCountMatrix, dumpAddr, buff + +#define MODIFIABLE_MAGIC_ALLREDUCE_ARGS_CALL_16P(magic) \ +input, output, rank, rankSize, len, (magic), op, root, 
localRankSize, sendCountMatrix, dumpAddr, \ +buff0, buff1, buff2, buff3, buff4, buff5, buff6, buff7, buff8, buff9, buff10, buff11, \ +buff12, buff13, buff14, buff15 + +__attribute__((always_inline)) inline __aicore__ int64_t CeilDiv(int64_t source, int64_t cardinality) +{ + return (((source) + (cardinality) - 1) / (cardinality)); +} + +constexpr int64_t UB_SINGLE_ADD_SIZE_MAX = UB_SINGLE_DMA_SIZE_MAX; + +__attribute__((always_inline)) inline __aicore__ void CpUB2GMAlignB16(__gm__ void* gmAddr, __ubuf__ void* ubAddr, uint32_t size) +{ + CopyUbufToGmAlignB16(gmAddr, ubAddr, 1, size, 0, 0); +} + +__attribute__((always_inline)) inline __aicore__ void CpGM2UBAlignB16(__ubuf__ void* ubAddr, __gm__ void* gmAddr, uint32_t size) +{ + CopyGmToUbufAlignB16(ubAddr, gmAddr, 1, size, 0, 0); +} + +__attribute__((always_inline)) inline __aicore__ void DumpLcclLogInfo(GM_ADDR workspaceDumpAddr, LogId logId, Op operationType) +{ +#ifdef ENABLE_LCCL_DUMP + constexpr int32_t UB_HEAD_OFFSET = 96; + + AscendC::PipeBarrier(); + GM_ADDR blockGm = (GM_ADDR)(workspaceDumpAddr + LCCL_DUMP_UINT_SIZE * GetBlockIdx()); + __ubuf__ LcclDumpBlockInfo *blockUb = (__ubuf__ LcclDumpBlockInfo*)(UB_HEAD_OFFSET); + __ubuf__ LcclDumpLogInfo *logUb = (__ubuf__ LcclDumpLogInfo*)(UB_HEAD_OFFSET + sizeof(LcclDumpBlockInfo)); + + CpGM2UB((__ubuf__ uint8_t*)blockUb, blockGm, sizeof(LcclDumpBlockInfo)); + AscendC::PipeBarrier(); + + if (blockUb->dumpOffset < sizeof(LcclDumpLogInfo)) { + return; + } + + logUb->logId = logId; + logUb->blockId = GetBlockIdx(); + logUb->syscyc = static_cast(GetSystemCycle()); + logUb->curPc = static_cast(get_pc()); + logUb->operationType = operationType; + logUb->rsv = 0; + CpUB2GM((GM_ADDR) blockUb->dumpAddr, (__ubuf__ uint8_t*)logUb, sizeof(LcclDumpLogInfo)); + + blockUb->dumpAddr += sizeof(LcclDumpBlockInfo); + blockUb->dumpOffset -= sizeof(LcclDumpLogInfo); + CpUB2GM(blockGm, (__ubuf__ uint8_t*)blockUb, sizeof(LcclDumpBlockInfo)); + AscendC::PipeBarrier(); +#endif +} + +__attribute__((always_inline)) inline __aicore__ void SetFlag(__ubuf__ int64_t *ctrlFlagsUB, __gm__ int64_t *ctrlFlagGM, + int64_t checkValue) +{ + AscendC::PipeBarrier(); + *ctrlFlagsUB = checkValue; + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + CpUB2GM(ctrlFlagGM, ctrlFlagsUB, sizeof(int64_t)); + AscendC::PipeBarrier(); +} + +__attribute__((always_inline)) inline __aicore__ void SetFlagNonPipeBarrier(__ubuf__ int64_t *ctrlFlagsUB, __gm__ int64_t *ctrlFlagGM, + int64_t checkValue) +{ + *ctrlFlagsUB = checkValue; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM(ctrlFlagGM, ctrlFlagsUB, sizeof(int64_t)); +} + +__attribute__((always_inline)) inline __aicore__ void SetFlag(__ubuf__ int64_t *ctrlFlagsUB, + __gm__ int64_t *ctrlFlagGM1, __gm__ int64_t *ctrlFlagGM2, int64_t checkValue) +{ + AscendC::PipeBarrier(); + *ctrlFlagsUB = checkValue; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM(ctrlFlagGM1, ctrlFlagsUB, sizeof(int64_t)); + CpUB2GM(ctrlFlagGM2, ctrlFlagsUB, sizeof(int64_t)); + AscendC::PipeBarrier(); +} + +__attribute__((always_inline)) inline __aicore__ void SetFlagNonPipeBarrier(__ubuf__ int64_t *ctrlFlagsUB, + __gm__ int64_t *ctrlFlagGM1, __gm__ int64_t *ctrlFlagGM2, int64_t checkValue) +{ + *ctrlFlagsUB = checkValue; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM(ctrlFlagGM1, ctrlFlagsUB, sizeof(int64_t)); + CpUB2GM(ctrlFlagGM2, ctrlFlagsUB, sizeof(int64_t)); +} + +__attribute__((always_inline)) inline __aicore__ void CheckFlag(__ubuf__ 
int64_t *ctrlFlagsUB, + __gm__ int64_t *ctrlFlagGM, int64_t checkValue) +{ + while (true) { + AscendC::PipeBarrier(); + CpGM2UB(ctrlFlagsUB, ctrlFlagGM, sizeof(int64_t)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + if (*ctrlFlagsUB == checkValue) { + break; + } + } +} + +__attribute__((always_inline)) inline __aicore__ void CheckFlagNew(__ubuf__ int64_t *ctrlFlagsUB, + __gm__ int64_t *ctrlFlagGM, int64_t checkValue) +{ + while (true) { + AscendC::PipeBarrier(); + CpGM2UB(ctrlFlagsUB, ctrlFlagGM, sizeof(int64_t)); + AscendC::PipeBarrier(); + if (*ctrlFlagsUB == checkValue || (*ctrlFlagsUB) == (checkValue + 1)) { + break; + } + } +} + +__attribute__((always_inline)) inline __aicore__ int64_t GetLcalBlockNum() { + #ifdef ENABLE_LCCL_MIX + constexpr int32_t aivNumPerAic = 2; + return GetBlockNum() * aivNumPerAic; + #else + return GetBlockNum(); + #endif +} + +__attribute__((always_inline)) inline __aicore__ void SyncWithinNPU(__ubuf__ int64_t* ctrlFlagsUB, __gm__ int64_t* buffRank, int64_t magic) { + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buffRank + (GetBlockIdx() * MEM_DMA_UNIT_INT_NUM), magic); + for (int64_t i = 0; i < GetLcalBlockNum(); i++) { + if (i == GetBlockIdx()) { + continue; + } + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, (__gm__ int64_t*)buffRank + i * MEM_DMA_UNIT_INT_NUM, magic); + } +} + +__attribute__((always_inline)) inline __aicore__ void SyncWithinNPUNew(__ubuf__ int64_t* ctrlFlagsUB, __gm__ int64_t* buffRank, int64_t magic) { + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buffRank + (GetBlockIdx() * MEM_DMA_UNIT_INT_NUM), magic); + for (int64_t i = 0; i < GetLcalBlockNum(); i++) { + if (i == GetBlockIdx()) { + continue; + } + CheckFlagNew((__ubuf__ int64_t*)ctrlFlagsUB, (__gm__ int64_t*)buffRank + i * MEM_DMA_UNIT_INT_NUM, magic); + } +} + +template +__attribute__((always_inline)) inline __aicore__ void GM2GM( + int64_t dataSizeRemain, __ubuf__ T *inputUB, __gm__ T *receiveBuff, + int64_t revBuffOffsetNum, __gm__ T *sendBuff, int64_t sendBuffOffsetNum) +{ + int64_t times = 0; + while (dataSizeRemain >= UB_SINGLE_DMA_SIZE_MAX) { + CpGM2UB(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + UB_SINGLE_DMA_SIZE_MAX); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM( + (__gm__ T*)receiveBuff + revBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + inputUB, UB_SINGLE_DMA_SIZE_MAX); + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + times += 1; + dataSizeRemain -= UB_SINGLE_DMA_SIZE_MAX; + } + if (dataSizeRemain <= 0) { + return; + } + CpGM2UB(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + dataSizeRemain); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM( + (__gm__ T*)receiveBuff + revBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + inputUB, dataSizeRemain); + AscendC::PipeBarrier(); +} + +template +__attribute__((always_inline)) inline __aicore__ void GM2GMPingPong( + int64_t dataSizeRemain, __ubuf__ T *inputUB[2], __gm__ T *receiveBuff, + int64_t revBuffOffsetNum, __gm__ T *sendBuff, int64_t sendBuffOffsetNum) +{ + if (dataSizeRemain <= 0) { + return; + } + AscendC::PipeBarrier(); + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; dataSizeRemain > 0; i++) { + uint32_t size = dataSizeRemain > UB_SINGLE_PING_PONG_ADD_SIZE_MAX ? UB_SINGLE_PING_PONG_ADD_SIZE_MAX : dataSizeRemain; + event_t eventId = (i & 1) ? 
EVENT_ID0 : EVENT_ID1; + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], (__gm__ T*)sendBuff + sendBuffOffsetNum, size); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + CpUB2GM((__gm__ T*)receiveBuff + revBuffOffsetNum, (i & 1) ? inputUB[0] : inputUB[1], size); + AscendC::SetFlag(eventId); + dataSizeRemain -= size; + sendBuffOffsetNum += (size / sizeof(T)); + revBuffOffsetNum += (size / sizeof(T)); + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::PipeBarrier(); + if (dataSizeRemain <= 0) { + return; + } +} + +template +__attribute__((always_inline)) inline __aicore__ void GM2GMPingPongNonPipeBarrier( + int64_t dataSizeRemain, __ubuf__ T *inputUB[2], __gm__ T *receiveBuff, + int64_t revBuffOffsetNum, __gm__ T *sendBuff, int64_t sendBuffOffsetNum) +{ + if (dataSizeRemain <= 0) { + return; + } + const int64_t offsetNumPerLoop = UB_SINGLE_PING_PONG_ADD_SIZE_MAX / sizeof(T); + uint32_t size = 0; + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; dataSizeRemain > 0; i++) { + size = dataSizeRemain > UB_SINGLE_PING_PONG_ADD_SIZE_MAX ? UB_SINGLE_PING_PONG_ADD_SIZE_MAX : dataSizeRemain; + event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], (__gm__ T*)sendBuff + sendBuffOffsetNum, size); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + CpUB2GM((__gm__ T*)receiveBuff + revBuffOffsetNum, (i & 1) ? inputUB[0] : inputUB[1], size); + AscendC::SetFlag(eventId); + dataSizeRemain -= size; + sendBuffOffsetNum += offsetNumPerLoop; + revBuffOffsetNum += offsetNumPerLoop; + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + if (dataSizeRemain <= 0) { + return; + } +} + +template +__attribute__((always_inline)) inline __aicore__ void input2BuffRankMagic( + int64_t dataSizeRemain, __ubuf__ T *inputUB, __gm__ T *ipcReceiveBuff, int64_t revBuffOffsetNum, + __gm__ T *sendBuff, int64_t sendBuffOffsetNum, __ubuf__ int64_t* ctrlFlagsUB, __gm__ int64_t* ctrlFlagGM, + int64_t magic) +{ + int64_t times = 0; + int64_t flag = 0; + + while (dataSizeRemain >= UB_SINGLE_DMA_SIZE_MAX) { + CpGM2UB(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + UB_SINGLE_DMA_SIZE_MAX); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM( + (__gm__ T*)ipcReceiveBuff + revBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + inputUB, UB_SINGLE_DMA_SIZE_MAX); + times += 1; + flag = times * UB_SINGLE_DMA_SIZE_MAX / DMA_SIZE_PER_FLAG + magic; + if (flag != *ctrlFlagsUB && flag > 0) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + *ctrlFlagsUB = flag; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM(ctrlFlagGM, ctrlFlagsUB, sizeof(int64_t)); + } + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + dataSizeRemain -= UB_SINGLE_DMA_SIZE_MAX; + } + if (dataSizeRemain <= 0) { + return; + } + CpGM2UB(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + dataSizeRemain); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM( + (__gm__ T*)ipcReceiveBuff + revBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + inputUB, dataSizeRemain); + flag = CeilDiv(times * UB_SINGLE_DMA_SIZE_MAX + dataSizeRemain, DMA_SIZE_PER_FLAG) + magic; + AscendC::PipeBarrier(); + *ctrlFlagsUB = flag; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM(ctrlFlagGM, 
ctrlFlagsUB, sizeof(int64_t)); + AscendC::PipeBarrier(); +} + +template +__attribute__((always_inline)) inline __aicore__ void input2BuffRank( + int64_t dataSizeRemain, __ubuf__ T *inputUB, __gm__ T *ipcReceiveBuff, int64_t revBuffOffsetNum, + __gm__ T *sendBuff, int64_t sendBuffOffsetNum, __ubuf__ int64_t* ctrlFlagsUB, __gm__ int64_t* ctrlFlagGM) +{ + int64_t times = 0; + int64_t flag = 0; + + while (dataSizeRemain >= UB_SINGLE_DMA_SIZE_MAX) { + CpGM2UB(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + UB_SINGLE_DMA_SIZE_MAX); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM( + (__gm__ T*)ipcReceiveBuff + revBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + inputUB, UB_SINGLE_DMA_SIZE_MAX); + times += 1; + flag = times * UB_SINGLE_DMA_SIZE_MAX / DMA_SIZE_PER_FLAG; + if (flag != *ctrlFlagsUB && flag > 0) { + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + *ctrlFlagsUB = flag; + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + CpUB2GM(ctrlFlagGM, ctrlFlagsUB, sizeof(int64_t)); + } + AscendC::SetFlag(EVENT_ID1); + AscendC::WaitFlag(EVENT_ID1); + dataSizeRemain -= UB_SINGLE_DMA_SIZE_MAX; + } + if (dataSizeRemain <= 0) { + return; + } + CpGM2UB(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + dataSizeRemain); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM( + (__gm__ T*)ipcReceiveBuff + revBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + inputUB, dataSizeRemain); + flag = CeilDiv(times * UB_SINGLE_DMA_SIZE_MAX + dataSizeRemain, DMA_SIZE_PER_FLAG); + AscendC::PipeBarrier(); + *ctrlFlagsUB = flag; + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM(ctrlFlagGM, ctrlFlagsUB, sizeof(int64_t)); + AscendC::PipeBarrier(); +} + +template +__attribute__((always_inline)) inline __aicore__ void PostSyncBigData( + __ubuf__ int64_t *ctrlFlagsUB, __gm__ T* buff[8], uint32_t rank, uint32_t rankSize, + int64_t dataOffsetNum, int64_t ipcBuffMaxNum, int64_t magic, int64_t i) +{ + if (i <= 0) { + return; + } + + const int64_t postSyncFlagIdx = MEM_DMA_UNIT_INT_NUM + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM; + + SyncWithinNPUNew(ctrlFlagsUB, (__gm__ int64_t *)((__gm__ T *)buff[rank] + ipcBuffMaxNum) + dataOffsetNum + MEM_DMA_UNIT_INT_NUM, magic + i); + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t *)((__gm__ T *)buff[rank] + ipcBuffMaxNum) + dataOffsetNum + postSyncFlagIdx; + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGM, (int64_t)magic + i); + + for (int64_t targetNPU = 0; targetNPU < rankSize; targetNPU++) { + if (targetNPU == rank) { + continue; + } + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t *)((__gm__ T *)buff[targetNPU] + ipcBuffMaxNum) + dataOffsetNum + postSyncFlagIdx; + CheckFlagNew(ctrlFlagsUB, ctrlFlagsGMX, (int64_t)magic + i); + } +} + +template +__attribute__((always_inline)) inline __aicore__ void PostSyncBigData910B2C( + __ubuf__ int64_t *ctrlFlagsUB, __gm__ T* buff[MAX_RANK_NUM_OF_ONE_910B2C], uint32_t rank, uint32_t rankSize, + int64_t dataOffsetNum, int64_t ipcBuffMaxNum, int64_t magic, int64_t i, const int64_t peerRankId, + const int64_t singleNodeRankSize) +{ + if (i <= 0) { + return; + } + + const int64_t postSyncFlagIdx = MEM_DMA_UNIT_INT_NUM + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM; + + SyncWithinNPUNew(ctrlFlagsUB, (__gm__ int64_t *)((__gm__ T *)buff[rank] + ipcBuffMaxNum) + dataOffsetNum + MEM_DMA_UNIT_INT_NUM, magic + i); + 
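+    // Post-sync order on 910B2C: first sync all cores on this NPU, then exchange flags with
+    // the other ranks of the local node (targetNPUBegin..targetNPUEnd below), and finally
+    // handshake with the paired peerRankId on the other node.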
+ __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t *)((__gm__ T *)buff[rank] + ipcBuffMaxNum) + dataOffsetNum + postSyncFlagIdx; + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGM, (int64_t)magic + i); + + int64_t targetNPUBegin = rank < singleNodeRankSize ? 0 : singleNodeRankSize; + int64_t targetNPUEnd = rank < singleNodeRankSize ? singleNodeRankSize : rankSize; + for (int64_t targetNPU = targetNPUBegin; targetNPU < targetNPUEnd; targetNPU++) { + if (targetNPU == rank) { + continue; + } + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t *)((__gm__ T *)buff[targetNPU] + ipcBuffMaxNum) + dataOffsetNum + postSyncFlagIdx; + CheckFlagNew(ctrlFlagsUB, ctrlFlagsGMX, (int64_t)magic + i); + } + const int64_t postSyncPeerFlagIdx = MEM_DMA_UNIT_INT_NUM + dataOffsetNum + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM; + __gm__ int64_t* ctrlFlagsGMPeer = + (__gm__ int64_t *)((__gm__ T *)buff[peerRankId] + ipcBuffMaxNum) + dataOffsetNum + postSyncPeerFlagIdx; + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMPeer, (int64_t)magic + i); + CheckFlagNew(ctrlFlagsUB, + (__gm__ int64_t *)((__gm__ T *)buff[rank] + ipcBuffMaxNum) + dataOffsetNum + postSyncPeerFlagIdx, + (int64_t)magic + i); +} + +template +__attribute__((always_inline)) inline __aicore__ void PostSyncBigDataWriteAcrossCard( + __ubuf__ int64_t *ctrlFlagsUB, __gm__ T* buff[8], uint32_t rank, uint32_t rankSize, + int64_t dataOffsetNum, int64_t ipcBuffMaxNum, int64_t magic, int64_t i) +{ + const int64_t postSyncFlagIdx = MEM_DMA_UNIT_INT_NUM + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM; + int64_t x = (rank == 0) ? 1 : 0; + if (i > 0) { + SyncWithinNPUNew(ctrlFlagsUB, (__gm__ int64_t *)((__gm__ T *)buff[rank] + ipcBuffMaxNum) + dataOffsetNum + MEM_DMA_UNIT_INT_NUM, magic + i); + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t *)((__gm__ T *)buff[x] + ipcBuffMaxNum) + dataOffsetNum + postSyncFlagIdx; + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGM, (int64_t)magic + i); + + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t *)((__gm__ T *)buff[rank] + ipcBuffMaxNum) + dataOffsetNum + postSyncFlagIdx; + CheckFlagNew(ctrlFlagsUB, ctrlFlagsGMX, (int64_t)magic + i); + } +} + +template +__attribute__((always_inline)) inline __aicore__ void SetAtomicOp(int op) +{ + switch (op) { + case 0: + AscendC::SetAtomicAdd(); + break; + case 1: + break; + case 2: + AscendC::SetAtomicMax(); + break; + case 3: + AscendC::SetAtomicMin(); + break; + default: + ; + } +} + +__attribute__((always_inline)) inline __aicore__ void PostSync(__ubuf__ int64_t *ctrlFlagsUB, __gm__ int64_t **buff, + int32_t rank, int32_t rankSize, int64_t magic) +{ + if (GetBlockIdx() == 0) { + AscendC::PipeBarrier(); + *ctrlFlagsUB = rank + magic; + AscendC::PipeBarrier(); + CpUB2GM(buff[rank] + 1, ctrlFlagsUB, sizeof(int64_t)); + + AscendC::PipeBarrier(); + + for (int64_t x = 0; x < rankSize; ++x) { + if (x == rank) { + continue; + } + CheckFlag(ctrlFlagsUB, buff[x] + 1, x + magic); + } + } +} + +template +__attribute__((always_inline)) inline __aicore__ void ProcessData(int64_t dataSizeRemain, __ubuf__ T *inputUB, + __gm__ T *buff, int64_t dataOffsetNum, int64_t buffOffsetNum, __gm__ T *output, int64_t outputOffsetNum, int op) +{ + if (dataSizeRemain <= 0) { + return; + } + AscendC::PipeBarrier(); + #ifdef __DAV_C220_VEC__ + SetAtomicOpType(op); + #endif + AscendC::PipeBarrier(); + + while (dataSizeRemain >= UB_SINGLE_ADD_SIZE_MAX) { + CpGM2UB(inputUB, (__gm__ T *)((__gm__ int64_t *)buff + dataOffsetNum) + buffOffsetNum, UB_SINGLE_ADD_SIZE_MAX); + AscendC::SetFlag(EVENT_ID0); + 
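+        // The SetFlag/WaitFlag pair orders the UB copy-out after the copy-in above has landed;
+        // with the atomic op selected by SetAtomicOpType(op), the CpUB2GM below accumulates
+        // (add/max/min) into the destination GM instead of overwriting it.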
AscendC::WaitFlag(EVENT_ID0); + CpUB2GM((__gm__ T *)output + outputOffsetNum, inputUB, UB_SINGLE_ADD_SIZE_MAX); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + dataSizeRemain -= UB_SINGLE_ADD_SIZE_MAX; + buffOffsetNum += (UB_SINGLE_ADD_SIZE_MAX / sizeof(T)); + outputOffsetNum += (UB_SINGLE_ADD_SIZE_MAX / sizeof(T)); + } + if (dataSizeRemain <= 0) { + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + AscendC::SetAtomicNone(); + AscendC::PipeBarrier(); + return; + } + + CpGM2UB(inputUB, (__gm__ T *)((__gm__ int64_t *)buff + dataOffsetNum) + buffOffsetNum, dataSizeRemain); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + CpUB2GM((__gm__ T *)output + outputOffsetNum, (__ubuf__ T *)inputUB, dataSizeRemain); + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + AscendC::SetAtomicNone(); + AscendC::PipeBarrier(); +} + +template +__attribute__((always_inline)) inline __aicore__ void ProcessDataNew(int64_t dataSizeRemain, __ubuf__ T *inputUB[2], + __gm__ T *buff, int64_t dataOffsetNum, int64_t buffOffsetNum, __gm__ T *output, int64_t outputOffsetNum, int op) +{ + if (dataSizeRemain <= 0) { + return; + } + + AscendC::PipeBarrier(); +#ifdef __DAV_C220_VEC__ + SetAtomicOpType(op); +#endif + AscendC::PipeBarrier(); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; dataSizeRemain > 0; i++) { + uint32_t size = dataSizeRemain > UB_SINGLE_PING_PONG_ADD_SIZE_MAX ? UB_SINGLE_PING_PONG_ADD_SIZE_MAX : dataSizeRemain; + event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], (__gm__ T*)((__gm__ int64_t*)buff + dataOffsetNum) + buffOffsetNum, size); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + CpUB2GM((__gm__ T*)output + outputOffsetNum, (i & 1) ? inputUB[0] : inputUB[1], size); + AscendC::SetFlag(eventId); + + dataSizeRemain -= size; + buffOffsetNum += (size / sizeof(T)); + outputOffsetNum += (size / sizeof(T)); + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + AscendC::SetAtomicNone(); + AscendC::PipeBarrier(); + return; +} + + +template +__attribute__((always_inline)) inline __aicore__ void ProcessDataNewNonBarrier(int64_t dataSizeRemain, __ubuf__ T *inputUB[2], + __gm__ T *buff, int64_t dataOffsetNum, int64_t buffOffsetNum, __gm__ T *output, int64_t outputOffsetNum, int op) +{ + if (dataSizeRemain <= 0) { + return; + } + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); +#ifdef __DAV_C220_VEC__ + SetAtomicOpType(op); +#endif + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for (int64_t i = 0; dataSizeRemain > 0; i++) { + uint32_t size = dataSizeRemain > UB_SINGLE_PING_PONG_ADD_SIZE_MAX ? UB_SINGLE_PING_PONG_ADD_SIZE_MAX : dataSizeRemain; + event_t eventId = (i & 1) ? EVENT_ID0 : EVENT_ID1; + AscendC::WaitFlag(eventId); + CpGM2UB((i & 1) ? inputUB[0] : inputUB[1], (__gm__ T*)((__gm__ int64_t*)buff + dataOffsetNum) + buffOffsetNum, size); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + CpUB2GM((__gm__ T*)output + outputOffsetNum, (i & 1) ? 
inputUB[0] : inputUB[1], size); + AscendC::SetFlag(eventId); + + dataSizeRemain -= size; + buffOffsetNum += (size / sizeof(T)); + outputOffsetNum += (size / sizeof(T)); + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + + AscendC::SetFlag(EVENT_ID3); + AscendC::WaitFlag(EVENT_ID3); + AscendC::SetAtomicNone(); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + return; +} + +__attribute__((always_inline)) inline __aicore__ void CheckFlagGE(__ubuf__ int64_t *ctrlFlagsUB, + __gm__ int64_t *ctrlFlagGM, int64_t checkValue) +{ + while (true) { + AscendC::PipeBarrier(); + CpGM2UBAlignB16(ctrlFlagsUB, ctrlFlagGM, sizeof(int64_t)); + AscendC::PipeBarrier(); + if ((*ctrlFlagsUB >> 10) == (checkValue >> 10) && (*ctrlFlagsUB & 0x3FF) >= (checkValue & 0x3FF)) { + break; + } + } +} + +__attribute__((always_inline)) inline __aicore__ void NewCheckFlagGE(__ubuf__ int64_t *ctrlFlagsUB, + __gm__ int64_t *ctrlFlagGM, int64_t checkValue, event_t eventId) +{ + AscendC::SetFlag(eventId); + while (true) { + AscendC::WaitFlag(eventId); + CpGM2UBAlignB16(ctrlFlagsUB, ctrlFlagGM, sizeof(int64_t)); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + if ((*ctrlFlagsUB >> 20) == (checkValue >> 20) && (*ctrlFlagsUB & 0xFFFFF) >= (checkValue & 0xFFFFF)) { + break; + } + AscendC::SetFlag(eventId); + } +} + +__attribute__((always_inline)) inline __aicore__ int64_t GetDeterministicRankOffset(int64_t x) { + int64_t count = 1; + while (!(x & 1)) { + x >>= 1; + count <<= 1; + } + return count; +} + +__attribute__((always_inline)) inline __aicore__ void CopyInput2BuffBroadCast(__ubuf__ char* inputUB, __gm__ char* buff, + __gm__ char* input, int64_t singleCoreDataNum, + int64_t blockDataOffset) +{ + if (singleCoreDataNum <= 0) { + return; + } + CpGM2UBAlignB16(inputUB, input + blockDataOffset, singleCoreDataNum * sizeof(char)); + AscendC::PipeBarrier(); + + CpUB2GMAlignB16((__gm__ char*)((__gm__ int64_t * )buff + GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM) + blockDataOffset, + inputUB, singleCoreDataNum * sizeof(char)); + AscendC::PipeBarrier(); +} + + +#endif \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_all2all_transpose.cce b/comm/lcal/src/kernels/lcal_all2all_transpose.cce new file mode 100644 index 0000000000000000000000000000000000000000..68a772d0d6cf9c6d72dfc099c02d2bca52a0c565 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_all2all_transpose.cce @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAll2AllTranspose(ALLREDUCE_ARGS_FUN_16P(T)) +{ + int32_t width = root; + int32_t burstLen = width / rankSize; + const int64_t dataOffsetNum = GetLcalBlockNum() * 32 * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + int numRows = len / width; + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(64), (__ubuf__ T*)(97312)}; + + int32_t coreIdx = GetBlockIdx(); + int32_t coreNum = GetLcalBlockNum(); + const int64_t corePerRank = coreNum / rankSize; + const int64_t coreIdxInRank = GetBlockIdx() % corePerRank; + const int64_t coreIdxRankId = GetBlockIdx() / corePerRank; + const int64_t rowNumPerCore = CeilDiv(numRows, corePerRank); + int64_t rowNumThisCore = rowNumPerCore; + if (coreIdxInRank == corePerRank - 1) { + rowNumThisCore = numRows - rowNumPerCore * (corePerRank - 1); + } + + const int64_t lenPerRank = len / rankSize; + AscendC::PipeBarrier(); + + AscendC::SetFlag(EVENT_ID0); + AscendC::SetFlag(EVENT_ID1); + for(int32_t loopId = 0; loopId < rowNumPerCore; ++loopId) { + event_t eventId = (loopId & 1)? EVENT_ID1 : EVENT_ID0; + int32_t rowId = loopId + coreIdxInRank * rowNumPerCore; + if (rowId >= numRows) { + break; + } + __gm__ T* srcPtr = (__gm__ T*)input + rowId * width + coreIdxRankId * burstLen; + __ubuf__ T* iub = (loopId & 1) ? inputUB[1] : inputUB[0]; + AscendC::WaitFlag(eventId); + CpGM2UB(iub, srcPtr, burstLen * sizeof(T)); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + __gm__ T* dstPtr = (__gm__ T*)buff[rank] + coreIdxRankId * lenPerRank + rowId * burstLen + dataOffsetNum; + if (coreIdxRankId == rank) { + dstPtr = (__gm__ T*) output + coreIdxRankId * lenPerRank + rowId * burstLen; + } + CpUB2GM(dstPtr, iub, burstLen * sizeof(T)); + AscendC::SetFlag(eventId); + } + AscendC::WaitFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID1); + AscendC::PipeBarrier(); + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*) buff[rank] + flagOffset1st; + SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic); + __gm__ int64_t* ctrlFlagsGMWait = (__gm__ int64_t*)buff[coreIdxRankId] + (rank * corePerRank + coreIdxInRank) * MEM_DMA_UNIT_INT_NUM; + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMWait, (int64_t)magic); + + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + __gm__ T* gm_src = (__gm__ T*)buff[coreIdxRankId] + + rank * lenPerRank + coreIdxInRank * rowNumPerCore * burstLen + dataOffsetNum; + __gm__ T* gm_dst = (__gm__ T*)output + coreIdxRankId * lenPerRank + + coreIdxInRank * rowNumPerCore * burstLen; + if (coreIdxRankId != rank) { + GM2GM(rowNumThisCore * burstLen * sizeof(T), inputUB[0], gm_dst, 0, gm_src, 0); + } +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allgather.cce b/comm/lcal/src/kernels/lcal_allgather.cce new file mode 100644 index 0000000000000000000000000000000000000000..f9be19a178060dab677b9e478cea2f83701f747e --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allgather.cce @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllGather(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + const int64_t flagOffset2nd = MEM_DMA_UNIT_INT_NUM * GetLcalBlockNum() + flagOffset1st; + const int64_t corePerRank = GetLcalBlockNum() / rankSize; + const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + const int64_t x = GetBlockIdx() / corePerRank; + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(64), (__ubuf__ T*)(97312)}; + + int64_t dataNumRemain = len / GetLcalBlockNum(); + int64_t buffOffsetNum = rank * len + GetBlockIdx() * dataNumRemain; + if (GetBlockIdx() == GetLcalBlockNum() - 1) { + dataNumRemain = len - dataNumRemain * GetBlockIdx(); + } + + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); + + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + __gm__ T *sendBuff = input; + int64_t sendBuffOffsetNum = buffOffsetNum - rank * len; + GM2GM(dataNumRemain * sizeof(T), inputUB[0], receiveBuff, buffOffsetNum, sendBuff, sendBuffOffsetNum); + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*) buff[rank] + flagOffset1st; + SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic); + + for (int64_t i = 0; i < GetLcalBlockNum(); i++) { + __gm__ int64_t* ctrlFlagsGMTemp = (__gm__ int64_t*)buff[x] + i * MEM_DMA_UNIT_INT_NUM; + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMTemp, (int64_t)magic); + } + dataNumRemain = len / corePerRank; + buffOffsetNum = x * len + coreSegmentedIdx * dataNumRemain; + if (coreSegmentedIdx == corePerRank - 1) { + dataNumRemain = len - dataNumRemain * coreSegmentedIdx; + } + + sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + GM2GM(dataNumRemain * sizeof(T), inputUB[0], (__gm__ T*)output, buffOffsetNum, sendBuff, buffOffsetNum); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allgather_2npu.cce b/comm/lcal/src/kernels/lcal_allgather_2npu.cce new file mode 100644 index 0000000000000000000000000000000000000000..713fe3202931e8275a8610d196649fad48019101 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allgather_2npu.cce @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "collectives.cce" + +template +inline __aicore__ void LcalAllGather2npu(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + const int64_t corePerRank = GetLcalBlockNum() / rankSize; + const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + const int64_t x = GetBlockIdx() / corePerRank; + + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset1st = (rank * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset2nd = (x * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(32); + *ctrlFlagsUB2 = 0; + __ubuf__ T* inputUB[2] = { (__ubuf__ T*)(64), (__ubuf__ T*)(97312) }; + + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + __gm__ T *sendBuff = input; + int64_t dataNumRemain = len / corePerRank; + int64_t sendBuffOffsetNum = coreSegmentedIdx * dataNumRemain; + int64_t buffOffsetNum = sendBuffOffsetNum + rank * len; + if (coreSegmentedIdx == corePerRank - 1) { + dataNumRemain = len - dataNumRemain * coreSegmentedIdx; + } + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); + GM2GM(dataNumRemain * sizeof(T), inputUB[0], receiveBuff, buffOffsetNum, sendBuff, sendBuffOffsetNum); + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*) buff[x] + flagOffset1st; + AscendC::PipeBarrier(); + SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic); + AscendC::PipeBarrier(); + + __gm__ int64_t* ctrlFlagsGMTemp = (__gm__ int64_t*)buff[rank] + flagOffset2nd; + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMTemp, (int64_t)magic); + + buffOffsetNum = sendBuffOffsetNum + x * len; + sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + GM2GM(dataNumRemain * sizeof(T), inputUB[0], (__gm__ T*)output, buffOffsetNum, sendBuff, buffOffsetNum); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allgather_2npu_big_data_write.cce b/comm/lcal/src/kernels/lcal_allgather_2npu_big_data_write.cce new file mode 100644 index 0000000000000000000000000000000000000000..39945f1e076a81f09d9fce474bde72313c038215 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allgather_2npu_big_data_write.cce @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllGather2npuBigDataWriteOrigin( + __gm__ T* buff[8], __gm__ T *input, __gm__ T *output, int64_t processedNum, int64_t blockNumPerGroup, uint32_t rank, + uint32_t rankSize, int64_t allLen, int64_t len, int64_t magic, __ubuf__ int64_t* ctrlFlagsUB, __ubuf__ int64_t* ctrlFlagsUB1, + __ubuf__ int64_t* ctrlFlagsUB2, __ubuf__ T* inputUB[2], int64_t dataOffsetNum, int64_t flagOffset1st, int64_t flagOffset2nd, + int64_t x, int64_t corePerRank, int64_t coreSegmentedIdx) +{ + const int64_t dataBlockAllNum = len * sizeof(T) / MEM_DMA_UNIT_BYTE; + const int64_t singleCoreDataBlockNum = dataBlockAllNum / blockNumPerGroup; + const int64_t singleCoreDataNum = singleCoreDataBlockNum * MEM_DMA_UNIT_BYTE / sizeof(T); + const int64_t buffDataDMAOffsetNum = coreSegmentedIdx * singleCoreDataNum; + + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + __gm__ T *sendBuff = input; + int64_t dataSizeRemain = singleCoreDataBlockNum * MEM_DMA_UNIT_BYTE; + if (coreSegmentedIdx == blockNumPerGroup - 1) { + dataSizeRemain = (len - singleCoreDataNum * coreSegmentedIdx) * sizeof(T); + } + if (dataSizeRemain <= 0) { + return; + } + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset1st; + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + flagOffset1st; + if (GetBlockIdx() < blockNumPerGroup) { + input2BuffRankMagic( + dataSizeRemain, inputUB[0], receiveBuff, buffDataDMAOffsetNum, input, buffDataDMAOffsetNum, + ctrlFlagsUB, ctrlFlagsGMX, magic); + return; + } + GM2GMPingPong(dataSizeRemain, inputUB, output + allLen * rank + processedNum, buffDataDMAOffsetNum, input, buffDataDMAOffsetNum); + + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + int64_t allDataSizeNeed2Add = dataSizeRemain; + AscendC::PipeBarrier(); + while (true) { + if (*ctrlFlagsUB >= CeilDiv(allDataSizeNeed2Add, DMA_SIZE_PER_FLAG)) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagsGM, sizeof(int64_t)); + AscendC::PipeBarrier(); + + if ((*ctrlFlagsUB1 >> 10) != (magic >> 10)) { + continue; + } + int64_t preparedDataGroupCount = (*ctrlFlagsUB1 - magic); + if (preparedDataGroupCount <= 0 || *ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeed2Add) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + + GM2GMPingPong(dataSizeRemain, inputUB, output + allLen * x + processedNum, + buffDataDMAOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), + (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum), + buffDataDMAOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T)); + AscendC::PipeBarrier(); + + *ctrlFlagsUB = preparedDataGroupCount; + AscendC::PipeBarrier(); + } + SetFlag(ctrlFlagsUB, ctrlFlagsGM, 0); +} + +template +inline __aicore__ void LcalAllGather2npuBigDataWrite(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + magic *= 1024; + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ T* inputUB[2] = 
{(__ubuf__ T*)(96), (__ubuf__ T*)(97440)}; + + int64_t blockNumPerGroup = GetLcalBlockNum() >> 1; + int64_t corePerRank = blockNumPerGroup; + int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + int64_t x = (rank == 0 ? 1 : 0); + if (GetBlockIdx() >= blockNumPerGroup) { + flagOffset1st = (GetBlockIdx() - blockNumPerGroup) * MEM_DMA_UNIT_INT_NUM; + } + int64_t flagOffset2nd = GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM + flagOffset1st; + + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); + int64_t ipcBuffMaxNum = IPC_BUFF_MAX_SIZE / sizeof(T); + for (int64_t i = 0; i < CeilDiv(len, ipcBuffMaxNum); i++) { + *ctrlFlagsUB = 0; + AscendC::PipeBarrier(); + + int64_t processedNum = i * ipcBuffMaxNum; + int64_t remainNum = (len - processedNum < ipcBuffMaxNum) ? len - processedNum : ipcBuffMaxNum; + + PostSyncBigData(ctrlFlagsUB, buff, rank, rankSize, dataOffsetNum, ipcBuffMaxNum, magic, i); + LcalAllGather2npuBigDataWriteOrigin( + buff, input + processedNum, output, processedNum, blockNumPerGroup, rank, rankSize, len, remainNum, (magic + i) * 1024, ctrlFlagsUB, ctrlFlagsUB1, + ctrlFlagsUB2, inputUB, dataOffsetNum, flagOffset1st, flagOffset2nd, x, corePerRank, coreSegmentedIdx); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allgather_910B2C.cce b/comm/lcal/src/kernels/lcal_allgather_910B2C.cce new file mode 100644 index 0000000000000000000000000000000000000000..ea4b88387cfff66f633b3f022b35403047b55239 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allgather_910B2C.cce @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllGather910B2C(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + + const int64_t singleNodeRankSize = rankSize >> 1; + if (GetBlockIdx() >= singleNodeRankSize + 2) { + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + return; + } + const int64_t localNodeRankId = rank >= singleNodeRankSize ? rank - singleNodeRankSize : rank; + const int64_t nodeId = rank < singleNodeRankSize ? 0 : 1; + + const int64_t peerRankId = rank < singleNodeRankSize ? 
rank + singleNodeRankSize : rank - singleNodeRankSize; + + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + const int64_t flagOffset2nd = MEM_DMA_UNIT_INT_NUM * GetLcalBlockNum() + flagOffset1st; + const int64_t corePerRank = 1; + + __gm__ T* buff[16] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7, + buff8, buff9, buff10, buff11, + buff12, buff13, buff14, buff15 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(64), (__ubuf__ T*)(97312)}; + + int64_t dataSizeRemain = len * sizeof(T); + + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); + + if (GetBlockIdx() >= singleNodeRankSize || GetBlockIdx() >= 8) { + int coreStep0Idx = 0; + if (GetBlockIdx() == 9 || GetBlockIdx() == singleNodeRankSize + 1) { + coreStep0Idx = 1; + } + int64_t sendBuffOffsetNum = 0; + int64_t revBuffOffsetNum = 0; + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (nodeId + singleNodeRankSize) * MEM_DMA_UNIT_INT_NUM; + if ((rank < singleNodeRankSize && coreStep0Idx == 1) || + (rank >= singleNodeRankSize && coreStep0Idx == 0)) { + receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[peerRankId] + dataOffsetNum); + ctrlFlagsGM = (__gm__ int64_t*)buff[peerRankId] + (nodeId + singleNodeRankSize) * MEM_DMA_UNIT_INT_NUM; + } + if (rank >= singleNodeRankSize) { + revBuffOffsetNum = len; + } + + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, revBuffOffsetNum, input, sendBuffOffsetNum); + + SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); + return; + } + + int64_t x = rank < singleNodeRankSize ? GetBlockIdx() : GetBlockIdx() + singleNodeRankSize; + __gm__ T *receiveBuff = output; + __gm__ T *sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + + CheckFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[x] + singleNodeRankSize * MEM_DMA_UNIT_INT_NUM, magic); + int64_t revBuffOffsetNum = GetBlockIdx() * len; + int64_t sendBuffOffsetNum = 0; + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, revBuffOffsetNum, sendBuff, sendBuffOffsetNum); + + CheckFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[x] + (singleNodeRankSize + 1) * MEM_DMA_UNIT_INT_NUM, magic); + revBuffOffsetNum = (singleNodeRankSize + GetBlockIdx()) * len; + sendBuffOffsetNum = len; + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, revBuffOffsetNum, sendBuff, sendBuffOffsetNum); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); +} diff --git a/comm/lcal/src/kernels/lcal_allgather_big_data.cce b/comm/lcal/src/kernels/lcal_allgather_big_data.cce new file mode 100644 index 0000000000000000000000000000000000000000..f6a59d335d225eb236c35f176be5e415ca119256 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allgather_big_data.cce @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllGatherBigDataOrigin( + __gm__ T* buff[8], __gm__ T *input, __gm__ T *output, int64_t processedNum, int64_t blockNumPerGroup, uint32_t rank, + uint32_t rankSize, uint64_t allLen, uint64_t len, int64_t magic, __ubuf__ int64_t* ctrlFlagsUB, __ubuf__ int64_t* ctrlFlagsUB1[16], + __ubuf__ int64_t* ctrlFlagsUB2[16], __ubuf__ T* inputUB[2], int64_t dataOffsetNum, int64_t flagOffset1st, int64_t flagOffset2nd, + int64_t x, int64_t corePerRank, int64_t coreSegmentedIdx) +{ + int64_t avgNumDMAPerCore = len / blockNumPerGroup; + int64_t dataNumRemain = avgNumDMAPerCore; + if (GetBlockIdx() == blockNumPerGroup - 1) { + dataNumRemain = len - dataNumRemain * GetBlockIdx(); + } + + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + __gm__ T *sendBuff = input; + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset1st; + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + flagOffset1st; + if (GetBlockIdx() < blockNumPerGroup) { + int64_t ipcBuffOffsetNum = GetBlockIdx() * avgNumDMAPerCore; + int64_t inputOffsetNum = GetBlockIdx() * avgNumDMAPerCore; + input2BuffRankMagic(dataNumRemain * sizeof(T), inputUB[0], receiveBuff, ipcBuffOffsetNum, + sendBuff, inputOffsetNum, ctrlFlagsUB, ctrlFlagsGM, magic); + return; + } + + for (int64_t i = 0; i < blockNumPerGroup; i++) { + *ctrlFlagsUB1[i] = 0; + *ctrlFlagsUB2[i] = 0; + } + + while (true) { + for (int64_t blockGroup0Idx = 0; blockGroup0Idx < blockNumPerGroup; blockGroup0Idx++) { + if (*ctrlFlagsUB1[blockGroup0Idx] == INT64_MAX) { + continue; + } + + int64_t allDataSizeNeedDMA = avgNumDMAPerCore * sizeof(T); + if (blockGroup0Idx == blockNumPerGroup - 1) { + allDataSizeNeedDMA = (len - blockGroup0Idx * avgNumDMAPerCore) * sizeof(T); + } + + if (*ctrlFlagsUB1[blockGroup0Idx] * DMA_SIZE_PER_FLAG >= allDataSizeNeedDMA) { + *ctrlFlagsUB1[blockGroup0Idx] = INT64_MAX; + continue; + } + + ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + (blockGroup0Idx) * MEM_DMA_UNIT_INT_NUM; + CpGM2UB(ctrlFlagsUB2[blockGroup0Idx], ctrlFlagsGMX, sizeof(int64_t)); + AscendC::PipeBarrier(); + + if ((*ctrlFlagsUB2[blockGroup0Idx] >> 10) != (magic >> 10)) { + continue; + } + int64_t preparedDataGroupCount = *ctrlFlagsUB2[blockGroup0Idx] - magic; + if (preparedDataGroupCount <= 0 || *ctrlFlagsUB1[blockGroup0Idx] >= preparedDataGroupCount) { + continue; + } + + receiveBuff = (__gm__ T *)output; + sendBuff = (__gm__ T *)((__gm__ int64_t *)buff[x] + dataOffsetNum); + int64_t revBuffOffsetNum = x * allLen + processedNum + blockGroup0Idx * avgNumDMAPerCore + + *ctrlFlagsUB1[blockGroup0Idx] * DMA_SIZE_PER_FLAG / sizeof(T); + int64_t sendBuffOffsetNum = blockGroup0Idx * avgNumDMAPerCore + + *ctrlFlagsUB1[blockGroup0Idx] * DMA_SIZE_PER_FLAG / sizeof(T); + + int64_t dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB1[blockGroup0Idx]) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeedDMA) { + dataSizeRemain = allDataSizeNeedDMA - *ctrlFlagsUB1[blockGroup0Idx] * DMA_SIZE_PER_FLAG; + } + + AscendC::PipeBarrier(); + GM2GMPingPong(dataSizeRemain, inputUB, receiveBuff, revBuffOffsetNum, sendBuff, sendBuffOffsetNum); + AscendC::PipeBarrier(); + + *ctrlFlagsUB1[blockGroup0Idx] = preparedDataGroupCount; + AscendC::PipeBarrier(); + } + + bool finished = true; + for (int64_t blockGroup0Idx = 0; blockGroup0Idx < 
blockNumPerGroup; blockGroup0Idx++) { + if (*ctrlFlagsUB1[blockGroup0Idx] != INT64_MAX) { + finished = false; + break; + } + } + if (finished) { + break; + } + } +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllGatherBigData(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + magic *= 1024; + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1[16]; + __ubuf__ int64_t* ctrlFlagsUB2[16]; + for (int64_t i = 0; i * 8 < 128; i ++) { + ctrlFlagsUB1[i] = (__ubuf__ int64_t*)(32) + i * 8; + ctrlFlagsUB2[i] = (__ubuf__ int64_t*)(544) + i * 8; + } + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(1056), (__ubuf__ T*)(98336)}; + + int64_t blockNumPerGroup = GetLcalBlockNum() >> 1; + int64_t corePerRank = blockNumPerGroup / rankSize; + int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + int64_t x = GetBlockIdx() / corePerRank; + if (GetBlockIdx() >= blockNumPerGroup) { + x = (GetBlockIdx() - blockNumPerGroup) / corePerRank; + flagOffset1st = (GetBlockIdx() - blockNumPerGroup) * MEM_DMA_UNIT_INT_NUM; + } + int64_t flagOffset2nd = GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM + flagOffset1st; + + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); + + int64_t ipcBuffMaxNum = IPC_BUFF_MAX_SIZE / sizeof(T); + int64_t dataLen = len; + for (int64_t i = 0; i < CeilDiv(dataLen, ipcBuffMaxNum); i++) { + *ctrlFlagsUB = 0; + AscendC::PipeBarrier(); + + int64_t processedNum = i * ipcBuffMaxNum; + int64_t remainNum = (dataLen - processedNum < ipcBuffMaxNum) ? dataLen - processedNum : ipcBuffMaxNum; + + PostSyncBigData(ctrlFlagsUB, buff, rank, rankSize, dataOffsetNum, ipcBuffMaxNum, magic, i); + LcalAllGatherBigDataOrigin( + buff, input + processedNum, output, processedNum, blockNumPerGroup, rank, rankSize, len, remainNum, (magic + i) * 1024, ctrlFlagsUB, ctrlFlagsUB1, + ctrlFlagsUB2, inputUB, dataOffsetNum, flagOffset1st, flagOffset2nd, x, corePerRank, coreSegmentedIdx); + AscendC::PipeBarrier(); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allgather_big_data_910B2C.cce b/comm/lcal/src/kernels/lcal_allgather_big_data_910B2C.cce new file mode 100644 index 0000000000000000000000000000000000000000..ad6304736c8deef5a2c0e095a8bee7be0b0560e1 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allgather_big_data_910B2C.cce @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void CheckThenDMAGM2GM(__ubuf__ int64_t* ctrlFlagsUB, __ubuf__ int64_t* ctrlFlagsUB1, + __gm__ int64_t *ctrlFlagsGMStep1ToCheck, int64_t newMagic, int64_t allDataSizeNeedDMA, + int64_t revBuffOffsetNumOrigin, int64_t processedDataNum, __gm__ T *sendBuff, __gm__ T *revBuff, + int64_t sendBuffOffsetNumOrigin, __ubuf__ T* inputUB[2], int64_t &processedDataGroupCount, int64_t multipleTimes) +{ + PipeBarrier(); + CpGM2UB(ctrlFlagsUB1, ctrlFlagsGMStep1ToCheck, sizeof(int64_t)); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + if (*ctrlFlagsUB1 == 0 || ((*ctrlFlagsUB1 >> 10) != (newMagic >> 10))) { + return; + } + + int64_t preparedDataGroupCount = (*ctrlFlagsUB1 & 0x3FF); + if (processedDataGroupCount >= preparedDataGroupCount) { + return; + } + + int64_t curDataSizeRemain = (preparedDataGroupCount - processedDataGroupCount) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount >= multipleTimes) { + curDataSizeRemain = allDataSizeNeedDMA - processedDataGroupCount * DMA_SIZE_PER_FLAG; + } + PipeBarrier(); + GM2GMPingPongNonPipeBarrier(curDataSizeRemain, inputUB, revBuff, + revBuffOffsetNumOrigin + processedDataGroupCount * DMA_SIZE_PER_FLAG / sizeof(T), + sendBuff, sendBuffOffsetNumOrigin + processedDataGroupCount * DMA_SIZE_PER_FLAG / sizeof(T)); + processedDataGroupCount = preparedDataGroupCount; + PipeBarrier(); +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllGatherBigData910B2C(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + magic *= 1024; + + const int64_t singleNodeRankSize = rankSize >> 1; + + const int64_t allGatherBuffSizePerParagraph910B2C = IPC_BUFF_MAX_SIZE / 2 / sizeof(T) * sizeof(T); + + const int64_t allGatherBuffNumPerParagraph910B2C = allGatherBuffSizePerParagraph910B2C / sizeof(T); + + if (GetBlockIdx() >= singleNodeRankSize + 2) { + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + return; + } + const int64_t localNodeRankId = rank >= singleNodeRankSize ? rank - singleNodeRankSize : rank; + const int64_t nodeId = rank < singleNodeRankSize ? 0 : 1; + + const int64_t peerRankId = rank < singleNodeRankSize ? rank + singleNodeRankSize : rank - singleNodeRankSize; + + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + const int64_t flagOffset2nd = MEM_DMA_UNIT_INT_NUM * GetLcalBlockNum() + flagOffset1st; + const int64_t corePerRank = 1; + + __gm__ T *buff[16] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7, + buff8, buff9, buff10, buff11, + buff12, buff13, buff14, buff15 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ int64_t* ctrlFlagsUB3 = (__ubuf__ int64_t*)(96); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(128), (__ubuf__ T*)(98336)}; + int64_t revBuffOffsetNumOrigin = nodeId * allGatherBuffNumPerParagraph910B2C; + int64_t processedDataNum = 0; + int64_t totalDataSizeRemain = len * sizeof(T); + const int64_t totalLoopTimes = CeilDiv(totalDataSizeRemain, allGatherBuffSizePerParagraph910B2C); + DumpLcclLogInfo(dumpAddr, LogId::INIT, Op::COPYONLY); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); + for (int i = 0; i < totalLoopTimes; i++) { + int64_t newMagic = (magic + i + 1) * 1024; + int64_t dataSizeRemain = (i == totalLoopTimes - 1) ? 
+ (totalDataSizeRemain - processedDataNum * sizeof(T)) : allGatherBuffSizePerParagraph910B2C; + int64_t dataNumRemain = dataSizeRemain / sizeof(T); + if (GetBlockIdx() >= singleNodeRankSize) { + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + __gm__ int64_t *ctrlFlagsGMStep0ToSet = (__gm__ int64_t*)buff[rank] + flagOffset1st; + if ((nodeId == 0 && GetBlockIdx() == singleNodeRankSize + 1)) { + ctrlFlagsGMStep0ToSet = (__gm__ int64_t*)buff[peerRankId] + singleNodeRankSize * MEM_DMA_UNIT_INT_NUM; + } + if ((nodeId == 1 && GetBlockIdx() == singleNodeRankSize)) { + ctrlFlagsGMStep0ToSet = (__gm__ int64_t*)buff[peerRankId] + (singleNodeRankSize + 1) * MEM_DMA_UNIT_INT_NUM; + } + if ((nodeId == 0 && GetBlockIdx() == singleNodeRankSize + 1) || + (nodeId == 1 && GetBlockIdx() == singleNodeRankSize)) { + receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[peerRankId] + dataOffsetNum); + } + + input2BuffRankMagic(dataSizeRemain, inputUB[0], receiveBuff, revBuffOffsetNumOrigin, input, + processedDataNum, ctrlFlagsUB, ctrlFlagsGMStep0ToSet, newMagic); + if (i < totalLoopTimes - 1) { + if ((nodeId == 0 && GetBlockIdx() == singleNodeRankSize)) { + int64_t checkFlagOffset = nodeId * GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM; + for (int checkLogicRank = 0; checkLogicRank < singleNodeRankSize; checkLogicRank++) { + CheckFlag(ctrlFlagsUB, + (__gm__ int64_t*)buff[rank] + checkFlagOffset + checkLogicRank * MEM_DMA_UNIT_INT_NUM, + newMagic); + } + } + if ((nodeId == 1 && GetBlockIdx() == singleNodeRankSize + 1)) { + int64_t checkFlagOffset = nodeId * GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM; + for (int checkLogicRank = 0; checkLogicRank < singleNodeRankSize; checkLogicRank++) { + CheckFlag(ctrlFlagsUB, + (__gm__ int64_t*)buff[rank] + checkFlagOffset + checkLogicRank * MEM_DMA_UNIT_INT_NUM, + newMagic); + } + } + if (nodeId == 0 && GetBlockIdx() == singleNodeRankSize + 1) { + for (int checkLogicRank = 0; checkLogicRank < singleNodeRankSize; checkLogicRank++) { + CheckFlag(ctrlFlagsUB, + (__gm__ int64_t*)buff[checkLogicRank] + (GetLcalBlockNum() + localNodeRankId) * MEM_DMA_UNIT_INT_NUM, + newMagic); + } + } + if (nodeId == 1 && GetBlockIdx() == singleNodeRankSize) { + for (int checkLogicRank = 0; checkLogicRank < singleNodeRankSize; checkLogicRank++) { + CheckFlag(ctrlFlagsUB, + (__gm__ int64_t*)buff[checkLogicRank + singleNodeRankSize] + localNodeRankId * MEM_DMA_UNIT_INT_NUM, + newMagic); + } + } + } + } else { + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + *ctrlFlagsUB3 = 0; + __gm__ int64_t *ctrlFlagsGMStep1ToCheck1st = (__gm__ int64_t*)buff[GetBlockIdx()] + singleNodeRankSize * MEM_DMA_UNIT_INT_NUM; + __gm__ int64_t *ctrlFlagsGMStep1ToSet1st = (__gm__ int64_t*)buff[GetBlockIdx()] + localNodeRankId * MEM_DMA_UNIT_INT_NUM; + __gm__ int64_t *ctrlFlagsGMStep1ToCheck2nd = (__gm__ int64_t*)buff[GetBlockIdx()] + (singleNodeRankSize + 1) * MEM_DMA_UNIT_INT_NUM; + __gm__ int64_t *ctrlFlagsGMStep1ToSet2nd = (__gm__ int64_t*)buff[peerRankId] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM; + __gm__ T *sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[GetBlockIdx()] + dataOffsetNum); + if (nodeId == 1) { + ctrlFlagsGMStep1ToCheck1st = (__gm__ int64_t*)buff[GetBlockIdx() + singleNodeRankSize] + singleNodeRankSize * MEM_DMA_UNIT_INT_NUM; + ctrlFlagsGMStep1ToSet1st = (__gm__ int64_t*)buff[peerRankId] + (GetBlockIdx() + GetLcalBlockNum()) * MEM_DMA_UNIT_INT_NUM; + ctrlFlagsGMStep1ToCheck2nd = (__gm__ int64_t*)buff[GetBlockIdx() + singleNodeRankSize] + (singleNodeRankSize + 1) * 
MEM_DMA_UNIT_INT_NUM; + ctrlFlagsGMStep1ToSet2nd = (__gm__ int64_t*)buff[GetBlockIdx() + singleNodeRankSize] + (localNodeRankId + GetLcalBlockNum()) * MEM_DMA_UNIT_INT_NUM; + sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[GetBlockIdx() + singleNodeRankSize] + dataOffsetNum); + } + int64_t revBuffOffsetNumOrigin1st = GetBlockIdx() * len + processedDataNum; + int64_t revBuffOffsetNumOrigin2nd = (GetBlockIdx() + singleNodeRankSize) * len + processedDataNum; + int64_t allDataSizeNeedDMA = dataSizeRemain; + int64_t multipleTimes = CeilDiv(dataSizeRemain, DMA_SIZE_PER_FLAG); + bool step1NeedSetFirst = true; + bool step1NeedSetSecond = true; + + int64_t processedDataGroupCount1st = 0; + int64_t processedDataGroupCount2nd = 0; + PipeBarrier(); + while (true) { + if (processedDataGroupCount1st < multipleTimes) { + CheckThenDMAGM2GM(ctrlFlagsUB, ctrlFlagsUB1, ctrlFlagsGMStep1ToCheck1st, newMagic, dataSizeRemain, + revBuffOffsetNumOrigin1st, processedDataNum, sendBuff, output, 0, inputUB, + processedDataGroupCount1st, multipleTimes); + } else if (step1NeedSetFirst) { + if (i < totalLoopTimes - 1) { + SetFlag(ctrlFlagsUB1, ctrlFlagsGMStep1ToSet1st, newMagic); + } + step1NeedSetFirst = false; + } + + if (processedDataGroupCount2nd < multipleTimes) { + CheckThenDMAGM2GM(ctrlFlagsUB2, ctrlFlagsUB3, ctrlFlagsGMStep1ToCheck2nd, newMagic, dataSizeRemain, + revBuffOffsetNumOrigin2nd, processedDataNum, sendBuff, output, allGatherBuffNumPerParagraph910B2C, inputUB, + processedDataGroupCount2nd, multipleTimes); + } else if (step1NeedSetSecond) { + if (i < totalLoopTimes - 1) { + SetFlag(ctrlFlagsUB3, ctrlFlagsGMStep1ToSet2nd, newMagic); + } + step1NeedSetSecond = false; + } + + if (!step1NeedSetFirst && !step1NeedSetSecond) { + break; + } + } + } + processedDataNum += allGatherBuffNumPerParagraph910B2C; + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, Op::COPYONLY); +} diff --git a/comm/lcal/src/kernels/lcal_allreduce_2npu_big_write.cce b/comm/lcal/src/kernels/lcal_allreduce_2npu_big_write.cce new file mode 100644 index 0000000000000000000000000000000000000000..d770dc58984ec77c954d109d733839e62c87bf00 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_2npu_big_write.cce @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduce2npuBigDataWriteOrigin + (__gm__ T* buff[8], __gm__ T *input, __gm__ T *output, int64_t blockNumPerGroup, uint32_t rank, uint32_t rankSize, + uint64_t len, int64_t magic, __ubuf__ int64_t* ctrlFlagsUB, __ubuf__ int64_t* ctrlFlagsUB1, + __ubuf__ int64_t* ctrlFlagsUB2, __ubuf__ T* inputUB[2], int64_t dataOffsetNum, int64_t flagOffset1st, + int64_t flagOffset2nd, int64_t x, int64_t corePerRank, int64_t coreSegmentedIdx, int op) +{ + const int64_t dataBlockAllNum = CeilDiv(len * sizeof(T), MEM_DMA_UNIT_BYTE); + const int64_t singleCoreDataBlockNum = dataBlockAllNum / blockNumPerGroup; + const int64_t singleCoreDataNum = singleCoreDataBlockNum * MEM_DMA_UNIT_BYTE / sizeof(T); + const int64_t buffDataDMAOffsetNum = coreSegmentedIdx * singleCoreDataNum; + + int64_t dataSizeRemain = singleCoreDataBlockNum * MEM_DMA_UNIT_BYTE; + if (coreSegmentedIdx == blockNumPerGroup - 1) { + dataSizeRemain = (len - singleCoreDataNum * coreSegmentedIdx) * sizeof(T); + } + + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset1st; + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + flagOffset1st; + if (GetBlockIdx() < blockNumPerGroup) { + input2BuffRankMagic( + dataSizeRemain, inputUB[0], receiveBuff, buffDataDMAOffsetNum, input, buffDataDMAOffsetNum, + ctrlFlagsUB, ctrlFlagsGMX, magic); + return; + } + GM2GMPingPong(dataSizeRemain, inputUB, output, buffDataDMAOffsetNum, input, buffDataDMAOffsetNum); + + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + int64_t allDataSizeNeed2Add = dataSizeRemain; + AscendC::PipeBarrier(); + while (true) { + if (*ctrlFlagsUB >= CeilDiv(allDataSizeNeed2Add, DMA_SIZE_PER_FLAG)) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagsGM, sizeof(int64_t)); + AscendC::PipeBarrier(); + + if ((*ctrlFlagsUB1 >> 10) != (magic >> 10)) { + continue; + } + int64_t preparedDataGroupCount = (*ctrlFlagsUB1 - magic); + if (preparedDataGroupCount <= 0 || *ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeed2Add) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + ProcessDataNew(dataSizeRemain, inputUB, buff[rank], dataOffsetNum, + buffDataDMAOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), + output, buffDataDMAOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), op); + AscendC::PipeBarrier(); + + *ctrlFlagsUB = preparedDataGroupCount; + AscendC::PipeBarrier(); + } +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduce2npuBigDataWrite(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + magic *= 1024; + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(96), (__ubuf__ T*)(97440)}; + + int64_t blockNumPerGroup = GetLcalBlockNum() >> 1; + int64_t corePerRank = blockNumPerGroup; + 
int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + int64_t x = (rank == 0 ? 1 : 0); + if (GetBlockIdx() >= blockNumPerGroup) { + flagOffset1st = (GetBlockIdx() - blockNumPerGroup) * MEM_DMA_UNIT_INT_NUM; + } + int64_t flagOffset2nd = GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM + flagOffset1st; + + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + int64_t ipcBuffMaxNum = IPC_BUFF_MAX_SIZE / sizeof(T); + for (int64_t i = 0; i < CeilDiv(len, ipcBuffMaxNum); i++) { + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + AscendC::PipeBarrier(); + + int64_t processedNum = i * ipcBuffMaxNum; + int64_t remainNum = (len - processedNum < ipcBuffMaxNum) ? len - processedNum : ipcBuffMaxNum; + + PostSyncBigDataWriteAcrossCard(ctrlFlagsUB, buff, rank, rankSize, dataOffsetNum, ipcBuffMaxNum, magic, i); + LcalAllReduce2npuBigDataWriteOrigin( + buff, input + processedNum, output + processedNum, blockNumPerGroup, rank, rankSize, remainNum, (magic + i) * 1024, ctrlFlagsUB, ctrlFlagsUB1, + ctrlFlagsUB2, inputUB, dataOffsetNum, flagOffset1st, flagOffset2nd, x, corePerRank, coreSegmentedIdx, op); + AscendC::PipeBarrier(); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} diff --git a/comm/lcal/src/kernels/lcal_allreduce_2npu_read.cce b/comm/lcal/src/kernels/lcal_allreduce_2npu_read.cce new file mode 100644 index 0000000000000000000000000000000000000000..bbf2ec421bc0bc6f38c294509f1dfe55a07a6025 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_2npu_read.cce @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduce2npuRead(ALLREDUCE_ARGS_FUN(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(32), (__ubuf__ T*)(97312)}; + + const int64_t corePerRank = GetLcalBlockNum() / rankSize; + const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * coreSegmentedIdx; + const int64_t x = GetBlockIdx() / corePerRank; + + const int64_t dataBlockAllNum = len * sizeof(T) / MEM_DMA_UNIT_BYTE; + const int64_t singleCoreDataBlockNum = dataBlockAllNum / corePerRank; + const int64_t singleCoreDataNum = singleCoreDataBlockNum * MEM_DMA_UNIT_BYTE / sizeof(T); + const int64_t buffOffsetNum = coreSegmentedIdx * singleCoreDataNum; + + int64_t dataSizeRemain = singleCoreDataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (len - singleCoreDataNum * coreSegmentedIdx) * sizeof(T); + } + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + if (x == rank) { + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, buffOffsetNum, input, buffOffsetNum); + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + flagOffset1st, magic); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + return; + } else { + GM2GM(dataSizeRemain, inputUB[0], output, buffOffsetNum, input, buffOffsetNum); + } + + CheckFlag(ctrlFlagsUB, (((__gm__ int64_t*)buff[x]) + flagOffset1st), magic); + CheckFlag(ctrlFlagsUB, (((__gm__ int64_t*)buff[rank]) + flagOffset1st), magic); + + ProcessData(dataSizeRemain, inputUB[0], buff[x], dataOffsetNum, buffOffsetNum, output, buffOffsetNum, op); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allreduce_2npu_write.cce b/comm/lcal/src/kernels/lcal_allreduce_2npu_write.cce new file mode 100644 index 0000000000000000000000000000000000000000..a4f058160db513144d7708cb90fa31c2efb78094 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_2npu_write.cce @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduce2npuWrite(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(32), (__ubuf__ T*)(97312)}; + + const int64_t corePerRank = GetLcalBlockNum() / rankSize; + const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * coreSegmentedIdx; + const int64_t flagOffset2nd = MEM_DMA_UNIT_INT_NUM * (GetLcalBlockNum() + coreSegmentedIdx); + const int64_t x = GetBlockIdx() / corePerRank; + + const int64_t dataBlockAllNum = len * sizeof(T) / MEM_DMA_UNIT_BYTE; + const int64_t singleCoreDataBlockNum = dataBlockAllNum / corePerRank; + const int64_t singleCoreDataNum = singleCoreDataBlockNum * MEM_DMA_UNIT_BYTE / sizeof(T); + const int64_t buffOffsetNum = coreSegmentedIdx * singleCoreDataNum; + + int64_t dataSizeRemain = singleCoreDataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (len - singleCoreDataNum * coreSegmentedIdx) * sizeof(T); + } + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + if (x != rank) { + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, buffOffsetNum, input, buffOffsetNum); + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[x] + flagOffset1st, magic); + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + flagOffset2nd, magic); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + return; + } else { + GM2GM(dataSizeRemain, inputUB[0], output, buffOffsetNum, input, buffOffsetNum); + } + + CheckFlag(ctrlFlagsUB, (((__gm__ int64_t*)buff[rank]) + flagOffset1st), magic); + CheckFlag(ctrlFlagsUB, (((__gm__ int64_t*)buff[rank]) + flagOffset2nd), magic); + + ProcessData(dataSizeRemain, inputUB[0], buff[rank], dataOffsetNum, buffOffsetNum, output, buffOffsetNum, op); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allreduce_big_data.cce b/comm/lcal/src/kernels/lcal_allreduce_big_data.cce new file mode 100644 index 0000000000000000000000000000000000000000..d16ef7ede30570facf8d27037f2edf08e1baea2b --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_big_data.cce @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduceBigDataOrigin + (__gm__ T* buff[16], __gm__ T *input, __gm__ T *output, int64_t blockNumPerGroup, uint32_t rank, uint32_t rankSize, + uint64_t len, int64_t magic, __ubuf__ int64_t* ctrlFlagsUB, __ubuf__ int64_t* ctrlFlagsUB1, + __ubuf__ int64_t* ctrlFlagsUB2, __ubuf__ T* inputUB[2], int64_t dataOffsetNum, int64_t flagOffset1st, + int64_t flagOffset2nd, int64_t x, int64_t corePerRank, int64_t coreSegmentedIdx, int op) +{ + const int64_t memDmaUnitNum = MEM_DMA_UNIT_BYTE / sizeof(T); + const int64_t singleNPUProcessDataBlockNum = (len / memDmaUnitNum) / rankSize; + const int64_t singleNPUProcessDataNum = singleNPUProcessDataBlockNum * memDmaUnitNum; + int64_t thisNPUProcessDataNum = singleNPUProcessDataNum; + if (rank == rankSize - 1) { + thisNPUProcessDataNum = len - rank * singleNPUProcessDataNum; + } + + int64_t xNPUProcessDataNum = singleNPUProcessDataNum; + if (x == rankSize - 1) { + xNPUProcessDataNum = len - x * singleNPUProcessDataNum; + } + + const int64_t xNPUCoreGroupAvgDMADataNum = (xNPUProcessDataNum / corePerRank / memDmaUnitNum) * memDmaUnitNum; + const int64_t thisNPUCoreGroupAvgDMADataNum = (thisNPUProcessDataNum / corePerRank / memDmaUnitNum) * memDmaUnitNum; + + int64_t dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + int64_t buffOffsetNum = x * singleNPUProcessDataNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset1st; + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + flagOffset1st; + if (GetBlockIdx() < blockNumPerGroup) { + input2BuffRankMagic(dataSizeRemain, inputUB[0], receiveBuff, buffOffsetNum, input, buffOffsetNum, ctrlFlagsUB, ctrlFlagsGM, magic); + } else { + if (x == rank) { + goto label0; + } + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM; + ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM; + __gm__ T *processOutput = (__gm__ T *)((__gm__ int64_t *)buff[rank] + dataOffsetNum); + + int64_t allDataSizeNeed2Add = thisNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + allDataSizeNeed2Add = (thisNPUProcessDataNum - coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + AscendC::PipeBarrier(); + while (true) { + if (*ctrlFlagsUB >= CeilDiv(allDataSizeNeed2Add, DMA_SIZE_PER_FLAG)) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagsGM, sizeof(int64_t)); + CpGM2UB(ctrlFlagsUB2, ctrlFlagsGMX, sizeof(int64_t)); + AscendC::PipeBarrier(); + + if (*ctrlFlagsUB1 == 0 || *ctrlFlagsUB2 == 0 || + ((*ctrlFlagsUB1 >> 10) != (magic >> 10)) || ((*ctrlFlagsUB2 >> 10) != (magic >> 10))) { + continue; + } + + int64_t preparedDataGroupCount = ((*ctrlFlagsUB1 & 0x3FF) <= (*ctrlFlagsUB2 & 0x3FF)) ? 
+ (*ctrlFlagsUB1 & 0x3FF) : (*ctrlFlagsUB2 & 0x3FF); + if (*ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + buffOffsetNum = rank * singleNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeed2Add) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + ProcessDataNew(dataSizeRemain, inputUB, buff[x], dataOffsetNum, buffOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), + processOutput, buffOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), op); + AscendC::PipeBarrier(); + + *ctrlFlagsUB = preparedDataGroupCount; + AscendC::PipeBarrier(); + } + } +label0: + if (GetBlockIdx() >= blockNumPerGroup) { + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + flagOffset2nd, (int64_t)magic); + return; + } + AscendC::PipeBarrier(); + + for (int64_t i = 0; i < blockNumPerGroup; i++) { + if (i / corePerRank == x) { + continue; + } + __gm__ int64_t* ctrlFlagsGMTemp = ((__gm__ int64_t*)buff[x] + (GetLcalBlockNum() + i) * MEM_DMA_UNIT_INT_NUM); + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMTemp, (int64_t)magic); + } + + buffOffsetNum = x * singleNPUProcessDataNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + __gm__ T *sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + GM2GMPingPong(dataSizeRemain, inputUB, output, buffOffsetNum, sendBuff, buffOffsetNum); +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduceBigData(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + magic *= 1024; + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + __gm__ T* buff[16] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7, + buff8, buff9, buff10, buff11, + buff12, buff13, buff14, buff15 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(96), (__ubuf__ T*)(97440)}; + + int64_t blockNumPerGroup = GetLcalBlockNum() >> 1; + int64_t corePerRank = blockNumPerGroup / rankSize; + int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + int64_t x = GetBlockIdx() / corePerRank; + if (GetBlockIdx() >= blockNumPerGroup) { + x = (GetBlockIdx() - blockNumPerGroup) / corePerRank; + flagOffset1st = (GetBlockIdx() - blockNumPerGroup) * MEM_DMA_UNIT_INT_NUM; + } + int64_t flagOffset2nd = GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM + flagOffset1st; + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + int64_t ipcBuffMaxNum = IPC_BUFF_MAX_SIZE / sizeof(T); + for (int64_t i = 0; i < CeilDiv(len, ipcBuffMaxNum); i++) { + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + AscendC::PipeBarrier(); + + int64_t processedNum = i * ipcBuffMaxNum; + int64_t remainNum = (len - processedNum < ipcBuffMaxNum) ? 
len - processedNum : ipcBuffMaxNum; + + PostSyncBigData(ctrlFlagsUB, buff, rank, rankSize, dataOffsetNum, ipcBuffMaxNum, magic, i); + LcalAllReduceBigDataOrigin( + buff, input + processedNum, output + processedNum, blockNumPerGroup, rank, rankSize, remainNum, (magic + i) * 1024, ctrlFlagsUB, ctrlFlagsUB1, + ctrlFlagsUB2, inputUB, dataOffsetNum, flagOffset1st, flagOffset2nd, x, corePerRank, coreSegmentedIdx, op); + AscendC::PipeBarrier(); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allreduce_big_data_910B2C.cce b/comm/lcal/src/kernels/lcal_allreduce_big_data_910B2C.cce new file mode 100644 index 0000000000000000000000000000000000000000..4f68ed44919891cb40cb4c536f08d0f5a6ae8677 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_big_data_910B2C.cce @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduceBigData910B2COrigin( + ALLREDUCE_ARGS_FUN_16P_Origin(T), const int64_t singleNodeRankSize, const int64_t localNodeRankId, + const int64_t coreGroupIdx, const int64_t peerRankId, const int64_t dataOffsetNum, __ubuf__ int64_t* ctrlFlagsUB, + __ubuf__ int64_t* ctrlFlagsUB1, __ubuf__ int64_t* ctrlFlagsUB2, __ubuf__ int64_t* ctrlFlagsUB3, + __ubuf__ T* inputUB[2], const int64_t x, const int64_t xLocalNodeRankId, __gm__ int64_t *ctrlFlagGMSet, + __gm__ int64_t *ctrlFlagGMCheck, __gm__ T *sendBuff, __gm__ T *receiveBuff +) +{ + const int64_t oneNPUProcessDataAvgNum = len / singleNodeRankSize; + int64_t thisNPUProcessDataNum = oneNPUProcessDataAvgNum; + if (localNodeRankId == singleNodeRankSize - 1) { + thisNPUProcessDataNum = len - localNodeRankId * oneNPUProcessDataAvgNum; + } + + int64_t xNPUProcessDataNum = oneNPUProcessDataAvgNum; + if (xLocalNodeRankId == singleNodeRankSize - 1) { + xNPUProcessDataNum = len - xLocalNodeRankId * oneNPUProcessDataAvgNum; + } + + int64_t dataSizeRemain = xNPUProcessDataNum * sizeof(T); + *ctrlFlagsUB = 0; + if (coreGroupIdx == 0) { + const int64_t buffOffsetNum = xLocalNodeRankId * oneNPUProcessDataAvgNum; + AscendC::PipeBarrier(); + input2BuffRankMagic(dataSizeRemain, inputUB[0], receiveBuff, buffOffsetNum, sendBuff, buffOffsetNum, ctrlFlagsUB, ctrlFlagGMSet, magic); + } else if (coreGroupIdx == 1) { + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + __gm__ int64_t *ctrlFlagGMCheckLocal = (__gm__ int64_t*)buff[rank] + localNodeRankId * MEM_DMA_UNIT_INT_NUM; + + const int64_t buffOffsetNum = localNodeRankId * oneNPUProcessDataAvgNum; + const int64_t allDataSizeNeed2Add = thisNPUProcessDataNum * sizeof(T); + const int64_t multipleTimes = CeilDiv(allDataSizeNeed2Add, DMA_SIZE_PER_FLAG); + if (x == rank || multipleTimes == 0) { + SetFlag(ctrlFlagsUB3, ctrlFlagGMSet, ((magic & 0xfffffffffffffc00) | multipleTimes)); + return; + } + 
AscendC::PipeBarrier(); + while (true) { + if (*ctrlFlagsUB >= multipleTimes) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagGMCheckLocal, sizeof(int64_t)); + CpGM2UB(ctrlFlagsUB2, ctrlFlagGMCheck, sizeof(int64_t)); + AscendC::PipeBarrier(); + + if (*ctrlFlagsUB1 == 0 || *ctrlFlagsUB2 == 0 || + ((*ctrlFlagsUB1 >> 10) != (magic >> 10)) || ((*ctrlFlagsUB2 >> 10) != (magic >> 10))) { + continue; + } + + int64_t preparedDataGroupCount = ((*ctrlFlagsUB1 & 0x3FF) <= (*ctrlFlagsUB2 & 0x3FF)) ? + (*ctrlFlagsUB1 & 0x3FF) : (*ctrlFlagsUB2 & 0x3FF); + if (*ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount >= multipleTimes) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + ProcessDataNewNonBarrier(dataSizeRemain, inputUB, sendBuff, 0, buffOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), + receiveBuff, buffOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), op); + SetFlag(ctrlFlagsUB3, ctrlFlagGMSet, ((*ctrlFlagsUB1 & 0xfffffffffffffc00) | preparedDataGroupCount)); + + *ctrlFlagsUB = preparedDataGroupCount; + AscendC::PipeBarrier(); + } + } else if (coreGroupIdx == 3) { + if (GetBlockIdx() == singleNodeRankSize * 3) { + __gm__ int64_t *ctrlFlagGMSetLocal = (__gm__ int64_t *)buff[rank] + (rankSize + 1) * MEM_DMA_UNIT_INT_NUM; + *ctrlFlagsUB2 = 0; + const int64_t buffOffsetNum = localNodeRankId * oneNPUProcessDataAvgNum; + const int64_t allDataSizeNeed2Add = thisNPUProcessDataNum * sizeof(T); + const int64_t multipleTimes = CeilDiv(allDataSizeNeed2Add, DMA_SIZE_PER_FLAG); + int64_t processedDataGroupCount = 0; + int64_t preparedDataGroupCount = 0; + AscendC::PipeBarrier(); + while (true) { + *ctrlFlagsUB1 = INT64_MAX; + if (processedDataGroupCount >= multipleTimes) { + break; + } + + for (int i = 0; i < singleNodeRankSize; i++) { + if (i == localNodeRankId) { + continue; + } + *ctrlFlagsUB2 = 0; + + do { + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + CpGM2UB(ctrlFlagsUB2, ctrlFlagGMCheck + i * MEM_DMA_UNIT_INT_NUM, sizeof(int64_t)); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + } while ((*ctrlFlagsUB2 >> 10) != (magic >> 10)); + + if (*ctrlFlagsUB1 > *ctrlFlagsUB2) { + *ctrlFlagsUB1 = *ctrlFlagsUB2; + } + } + + preparedDataGroupCount = (*ctrlFlagsUB1 & 0x3FF); + if (processedDataGroupCount >= preparedDataGroupCount) { + continue; + } + + dataSizeRemain = (preparedDataGroupCount - processedDataGroupCount) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount >= multipleTimes) { + dataSizeRemain = allDataSizeNeed2Add - processedDataGroupCount * DMA_SIZE_PER_FLAG; + } + + AscendC::PipeBarrier(); + GM2GMPingPongNonPipeBarrier(dataSizeRemain, inputUB, receiveBuff, + len + processedDataGroupCount * DMA_SIZE_PER_FLAG / sizeof(T), + sendBuff, + buffOffsetNum + processedDataGroupCount * DMA_SIZE_PER_FLAG / sizeof(T)); + SetFlagNonPipeBarrier(ctrlFlagsUB3, ctrlFlagGMSet, ctrlFlagGMSetLocal, + ((*ctrlFlagsUB1 & 0xfffffffffffffc00) | preparedDataGroupCount)); + + processedDataGroupCount = preparedDataGroupCount; + } + } else { + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + __gm__ int64_t *ctrlFlagGMCheckLocal = (__gm__ int64_t*)buff[rank] + (rankSize + 1) * MEM_DMA_UNIT_INT_NUM; + + const int64_t buffOffsetNum = localNodeRankId * oneNPUProcessDataAvgNum; + const int64_t allDataSizeNeed2Add = thisNPUProcessDataNum * sizeof(T); + const int64_t multipleTimes = CeilDiv(allDataSizeNeed2Add, DMA_SIZE_PER_FLAG); + int64_t processedDataGroupCount = 0; + int64_t 
preparedDataGroupCount = 0; + while (true) { + AscendC::PipeBarrier(); + if (processedDataGroupCount >= multipleTimes) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagGMCheckLocal, sizeof(int64_t)); + CpGM2UB(ctrlFlagsUB2, ctrlFlagGMCheck, sizeof(int64_t)); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + if (*ctrlFlagsUB1 == 0 || *ctrlFlagsUB2 == 0 || + ((*ctrlFlagsUB1 >> 10) != (magic >> 10)) || ((*ctrlFlagsUB2 >> 10) != (magic >> 10))) { + continue; + } + + preparedDataGroupCount = ((*ctrlFlagsUB1 & 0x3FF) <= (*ctrlFlagsUB2 & 0x3FF)) ? + (*ctrlFlagsUB1 & 0x3FF) : (*ctrlFlagsUB2 & 0x3FF); + if (processedDataGroupCount >= preparedDataGroupCount) { + continue; + } + + dataSizeRemain = (preparedDataGroupCount - processedDataGroupCount) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount >= multipleTimes) { + dataSizeRemain = allDataSizeNeed2Add - processedDataGroupCount * DMA_SIZE_PER_FLAG; + } + AscendC::PipeBarrier(); + ProcessDataNewNonBarrier(dataSizeRemain, inputUB, sendBuff, 0, buffOffsetNum + processedDataGroupCount * DMA_SIZE_PER_FLAG / sizeof(T), + receiveBuff, len + processedDataGroupCount * DMA_SIZE_PER_FLAG / sizeof(T), op); + SetFlagNonPipeBarrier(ctrlFlagsUB3, ctrlFlagGMSet, ((magic & 0xfffffffffffffc00) | preparedDataGroupCount)); + + processedDataGroupCount = preparedDataGroupCount; + } + } + } else if (coreGroupIdx == 2) { + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + const int64_t buffOffsetNum = xLocalNodeRankId * oneNPUProcessDataAvgNum; + const int64_t allDataSizeNeed2Add = xNPUProcessDataNum * sizeof(T); + const int64_t multipleTimes = CeilDiv(allDataSizeNeed2Add, DMA_SIZE_PER_FLAG); + int64_t processedDataGroupCount = 0; + int64_t preparedDataGroupCount = 0; + + if (thisNPUProcessDataNum != 0) { + CheckFlag(ctrlFlagsUB, + (__gm__ int64_t*)buff[rank] + (singleNodeRankSize + xLocalNodeRankId) * MEM_DMA_UNIT_INT_NUM, + CeilDiv(thisNPUProcessDataNum * sizeof(T), DMA_SIZE_PER_FLAG) + magic); + } + + while (true) { + if (processedDataGroupCount >= multipleTimes) { + break; + } + + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + CpGM2UB(ctrlFlagsUB1, ctrlFlagGMCheck, sizeof(int64_t)); + SetFlag(EVENT_ID0); + WaitFlag(EVENT_ID0); + + if (*ctrlFlagsUB1 == 0 || ((*ctrlFlagsUB1 >> 10) != (magic >> 10))) { + continue; + } + + preparedDataGroupCount = (*ctrlFlagsUB1 & 0x3FF); + if (processedDataGroupCount >= preparedDataGroupCount) { + continue; + } + + dataSizeRemain = (preparedDataGroupCount - processedDataGroupCount) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount >= multipleTimes) { + dataSizeRemain = allDataSizeNeed2Add - processedDataGroupCount * DMA_SIZE_PER_FLAG; + } + AscendC::PipeBarrier(); + GM2GMPingPongNonPipeBarrier(dataSizeRemain, inputUB, receiveBuff, + buffOffsetNum + processedDataGroupCount * DMA_SIZE_PER_FLAG / sizeof(T), + sendBuff, len + processedDataGroupCount * DMA_SIZE_PER_FLAG / sizeof(T)); + processedDataGroupCount = preparedDataGroupCount; + } + } +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduceBigData910B2C(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + magic *= 1024; + __gm__ T* buff[16] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7, + buff8, buff9, buff10, buff11, + buff12, buff13, buff14, buff15 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ int64_t* 
ctrlFlagsUB3 = (__ubuf__ int64_t*)(96); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(128), (__ubuf__ T*)(98336)}; + + const int64_t singleNodeRankSize = rankSize >> 1; + const int64_t localNodeRankId = rank >= singleNodeRankSize ? rank - singleNodeRankSize : rank; + + const int64_t coreGroupIdx = GetBlockIdx() / singleNodeRankSize; + + const int64_t peerRankId = rank < singleNodeRankSize ? rank + singleNodeRankSize : rank - singleNodeRankSize; + + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + + const int64_t x = (rank < singleNodeRankSize) ? (GetBlockIdx() % singleNodeRankSize) : + ((GetBlockIdx() % singleNodeRankSize) + singleNodeRankSize); + const int64_t xLocalNodeRankId = x % singleNodeRankSize; + + __gm__ T *sendBuff = input; + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + __gm__ int64_t *ctrlFlagGMSet = ((__gm__ int64_t*)buff[rank] + (GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM); + + + __gm__ int64_t *ctrlFlagGMCheck = ((__gm__ int64_t*)buff[x] + (localNodeRankId) * MEM_DMA_UNIT_INT_NUM); + switch (coreGroupIdx) { + case 0: + break; + case 1: + sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + ctrlFlagGMSet = ((__gm__ int64_t*)buff[rank] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM); + break; + case 2: + sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + receiveBuff = output; + ctrlFlagGMCheck = ((__gm__ int64_t*)buff[x] + (rankSize + 2) * MEM_DMA_UNIT_INT_NUM); + break; + case 3: + { + if (GetBlockIdx() == singleNodeRankSize * 3) { + sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[peerRankId] + dataOffsetNum); + ctrlFlagGMCheck = ((__gm__ int64_t*)buff[rank] + singleNodeRankSize * MEM_DMA_UNIT_INT_NUM); + ctrlFlagGMSet = ((__gm__ int64_t*)buff[peerRankId] + rankSize * MEM_DMA_UNIT_INT_NUM); + } else { + sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + ctrlFlagGMCheck = ((__gm__ int64_t*)buff[rank] + rankSize * MEM_DMA_UNIT_INT_NUM); + ctrlFlagGMSet = ((__gm__ int64_t*)buff[rank] + (rankSize + 2) * MEM_DMA_UNIT_INT_NUM); + } + } + default: + ; + } + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + + const int64_t allreduceBuffSizePerParagraph910B2C = + IPC_BUFF_MAX_SIZE / (singleNodeRankSize + 1) / sizeof(T) * sizeof(T); + + const int64_t ipcBuffMaxSizePerLoop = allreduceBuffSizePerParagraph910B2C * singleNodeRankSize; + const int64_t ipcBuffMaxNumPerLoop = ipcBuffMaxSizePerLoop / sizeof(T); + const int64_t loopTimes = CeilDiv(len, ipcBuffMaxNumPerLoop); + const int64_t ipcMaxNum = IPC_BUFF_MAX_SIZE / sizeof(T); + for (int64_t i = 0; i < loopTimes; i++) { + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + *ctrlFlagsUB3 = 0; + AscendC::PipeBarrier(); + + int64_t processedNum = i * ipcBuffMaxNumPerLoop; + int64_t remainNum = (len - processedNum < ipcBuffMaxNumPerLoop) ? 
len - processedNum : ipcBuffMaxNumPerLoop; + + switch (coreGroupIdx) { + case 0: + sendBuff = input + processedNum; + break; + case 2: + receiveBuff = output + processedNum; + break; + default: + ; + } + + PostSyncBigData910B2C(ctrlFlagsUB, buff, rank, rankSize, dataOffsetNum, ipcMaxNum, magic, i, peerRankId, + singleNodeRankSize); + LcalAllReduceBigData910B2COrigin( + MODIFIABLE_MAGIC_PROCESSED_NUM_ALLREDUCE_ARGS_CALL_16P_Origin(processedNum, remainNum, ((magic + i) * 1024)), + singleNodeRankSize, localNodeRankId, coreGroupIdx, peerRankId, dataOffsetNum, ctrlFlagsUB, ctrlFlagsUB1, + ctrlFlagsUB2, ctrlFlagsUB3, inputUB, x, xLocalNodeRankId, ctrlFlagGMSet, ctrlFlagGMCheck, sendBuff, + receiveBuff + ); + AscendC::PipeBarrier(); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allreduce_deterministic.cce b/comm/lcal/src/kernels/lcal_allreduce_deterministic.cce new file mode 100644 index 0000000000000000000000000000000000000000..595f1920439b171ceb97ff5761c99eed23769bac --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_deterministic.cce @@ -0,0 +1,197 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#include "collectives.cce"
+
+template <typename T>
+__attribute__((always_inline)) inline __aicore__ void Utils(__ubuf__ T * ub, __gm__ T * gm, T value)
+{
+    AscendC::PipeBarrier<PIPE_ALL>();
+    *ub = value;
+    AscendC::PipeBarrier<PIPE_ALL>();
+    CpUB2GM(gm, ub, sizeof(T));
+    AscendC::PipeBarrier<PIPE_ALL>();
+}
+
+template <typename T>
+__attribute__((always_inline)) inline __aicore__ void SumByPairs(
+    __ubuf__ int64_t *ctrlFlagsUB, __gm__ T* buff[8], int64_t x, int64_t blockNumPerGroup, int64_t corePerRank,
+    int64_t coreSegmentedIdx, int64_t magic, int64_t deterministicOffNum, int64_t thisNPUProcessDataNum,
+    int64_t thisNPUCoreGroupAvgDMADataNum, int64_t dataOffsetNum, int64_t dataSizeRemain, __ubuf__ T *inputUB[2],
+    int op, int rank, int rankSize) {
+    int64_t target = 0;
+    __gm__ int64_t *ctrlFlagsGM;
+    __gm__ int64_t *ctrlFlagsGMTemp;
+    __gm__ int64_t *ctrlFlagsGMTemp1;
+    int64_t buffOffsetNum;
+
+    if (x == 0) {
+        return;
+    }
+
+    int64_t multiple = GetDeterministicRankOffset(x);
+    if ((x & 1) == 1) {
+        target = x - multiple;
+        ctrlFlagsGMTemp = (__gm__ int64_t*)buff[rank] + (blockNumPerGroup + target * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM;
+        CheckFlag(ctrlFlagsUB, ctrlFlagsGMTemp, magic);
+
+        buffOffsetNum = deterministicOffNum + x * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+        __gm__ T *processOutput = (__gm__ T*)((__gm__ int64_t *)buff[rank] + dataOffsetNum);
+        int64_t outputOffsetNum = deterministicOffNum + target * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+        ProcessData(dataSizeRemain, inputUB[0], buff[rank], dataOffsetNum, buffOffsetNum,
+            processOutput, outputOffsetNum, op);
+        ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM;
+        SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic + multiple);
+    } else {
+        target = x - multiple;
+        ctrlFlagsGMTemp = (__gm__ int64_t*)buff[rank] + (GetLcalBlockNum() + blockNumPerGroup + (target + multiple / 2) * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM;
+        CheckFlag(ctrlFlagsUB, ctrlFlagsGMTemp, magic + multiple / 2);
+
+        int64_t multipleTemp = multiple;
+        while (x + multipleTemp / 2 >= rankSize) {
+            multipleTemp /= 2;
+        }
+        if (multipleTemp > 1) {
+            ctrlFlagsGMTemp1 = (__gm__ int64_t*)buff[rank] + (GetLcalBlockNum() + blockNumPerGroup + (x + multipleTemp / 2) * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM;
+            CheckFlag(ctrlFlagsUB, ctrlFlagsGMTemp1, magic + multipleTemp / 2);
+        }
+
+        buffOffsetNum = deterministicOffNum + x * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+        __gm__ T *processOutput = (__gm__ T*)((__gm__ int64_t *)buff[rank] + dataOffsetNum);
+        int64_t outputOffsetNum = deterministicOffNum + target * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+        ProcessData(dataSizeRemain, inputUB[0], buff[rank], dataOffsetNum, buffOffsetNum,
+            processOutput, outputOffsetNum, op);
+        ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM;
+        SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic + multiple);
+    }
+}
+
+template <typename T>
+__attribute__((always_inline)) inline __aicore__ void LcalAllReduceDeterministic(ALLREDUCE_ARGS_FUN_16P(T))
+{
+    DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast<uint8_t>(op));
+    DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast<uint8_t>(op));
+    magic <<= 10;
+    const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM;
+    int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx();
+    constexpr int32_t maxBuffSize = 16;
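+    // Peer IPC windows: one GM buffer per rank, indexed by rank id (up to 16 ranks).
+    __gm__ T* 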
buff[maxBuffSize] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7, + buff8, buff9, buff10, buff11, + buff12, buff13, buff14, buff15 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(96), (__ubuf__ T*)(97440)}; + + int64_t blockNumPerGroup = GetLcalBlockNum() >> 1; + int64_t corePerRank = blockNumPerGroup / rankSize; + int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + + int64_t x = GetBlockIdx() / corePerRank; + if (GetBlockIdx() >= blockNumPerGroup) { + x = (GetBlockIdx() - blockNumPerGroup) / corePerRank; + flagOffset1st = (GetBlockIdx() - blockNumPerGroup) * MEM_DMA_UNIT_INT_NUM; + } + int64_t flagOffset2nd = GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM + flagOffset1st; + + const int64_t singleNPUProcessDataNum = len / rankSize; + int64_t thisNPUProcessDataNum = singleNPUProcessDataNum; + if (rank == rankSize - 1) { + thisNPUProcessDataNum = len - rank * singleNPUProcessDataNum; + } + + int64_t xNPUProcessDataNum = singleNPUProcessDataNum; + if (x == rankSize - 1) { + xNPUProcessDataNum = len - x * singleNPUProcessDataNum; + } + + const int64_t xNPUCoreGroupAvgDMADataNum = xNPUProcessDataNum / corePerRank; + const int64_t thisNPUCoreGroupAvgDMADataNum = thisNPUProcessDataNum / corePerRank; + + int64_t dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + int64_t buffOffsetNum = x * singleNPUProcessDataNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + int64_t deterministicOffNum = len; + + if (GetBlockIdx() < blockNumPerGroup) { + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, buffOffsetNum, input, buffOffsetNum); + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset1st; + SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic); + } else { + buffOffsetNum = rank * singleNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + dataSizeRemain = thisNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (thisNPUProcessDataNum - coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM; + + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t *)buff[rank] + dataOffsetNum); + __gm__ T *sendBuff = (__gm__ T *)((__gm__ int64_t *)buff[x] + dataOffsetNum); + int64_t revBuffOffsetNum = deterministicOffNum + x * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMX, magic); + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, revBuffOffsetNum, sendBuff, buffOffsetNum); + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM; + + if (rankSize >= 4) { + SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic); + SumByPairs(ctrlFlagsUB, buff, x, blockNumPerGroup, corePerRank, coreSegmentedIdx, magic, deterministicOffNum, thisNPUProcessDataNum, + thisNPUCoreGroupAvgDMADataNum, dataOffsetNum, dataSizeRemain, inputUB, op, rank, rankSize); + } else { + 
SetFlag(ctrlFlagsUB, ctrlFlagsGM, ((x == 0) ? (magic + 1) : magic)); + if (x != 0) { + __gm__ int64_t *ctrlFlagsGMTemp = (__gm__ int64_t*)buff[rank] + (blockNumPerGroup + (x - 1) * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + CheckFlag(ctrlFlagsUB, ctrlFlagsGMTemp, magic + 1); + buffOffsetNum = deterministicOffNum + x * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + __gm__ T *processOutput = (__gm__ T*)((__gm__ int64_t *)buff[rank] + dataOffsetNum); + int64_t outputOffsetNum = deterministicOffNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + ProcessData(dataSizeRemain, inputUB[0], buff[rank], dataOffsetNum, buffOffsetNum, + processOutput, outputOffsetNum, op); + SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic + 1); + } + } + } + SyncWithinNPU(ctrlFlagsUB, (__gm__ int64_t *)((__gm__ T *)((__gm__ int64_t *)buff[rank] + dataOffsetNum) + IPC_BUFF_MAX_SIZE / sizeof(T)) + MEM_DMA_UNIT_INT_NUM, magic); + + if (GetBlockIdx() >= blockNumPerGroup) { + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + return; + } + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM; + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGM, (int64_t)magic); + + __gm__ int64_t* ctrlFlagsGMX= ((__gm__ int64_t*)buff[x] + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM); + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMX, (int64_t)magic); + + buffOffsetNum = coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + __gm__ T *sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + int64_t revBuffOffsetNum = x * singleNPUProcessDataNum + buffOffsetNum; + int64_t sendBuffOffsetNum = deterministicOffNum + buffOffsetNum; + GM2GM(dataSizeRemain, inputUB[0], output, revBuffOffsetNum, sendBuff, sendBuffOffsetNum); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allreduce_deterministic_big_data.cce b/comm/lcal/src/kernels/lcal_allreduce_deterministic_big_data.cce new file mode 100644 index 0000000000000000000000000000000000000000..c9454d6a226455800d11cadb0be1c7fedf88b3c5 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_deterministic_big_data.cce @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#include "collectives.cce"
+
+template <typename T>
+__attribute__((always_inline)) inline __aicore__ void SumByPairsBigData(
+    __ubuf__ int64_t *ctrlFlagsUB, __ubuf__ int64_t *ctrlFlagsUB1, __ubuf__ int64_t *ctrlFlagsUB2, __gm__ T* buff[8], int64_t x,
+    int64_t blockNumPerGroup, int64_t corePerRank, int64_t coreSegmentedIdx, int64_t magic, int64_t deterministicOffNum,
+    int64_t thisNPUProcessDataNum, int64_t thisNPUCoreGroupAvgDMADataNum, int64_t dataOffsetNum, int64_t dataSizeRemain, __ubuf__ T *inputUB[2],
+    int op, int rank, int rankSize, int64_t allTimes, int64_t allDataSizeNeed2Add) {
+    int64_t target = 0;
+    __gm__ int64_t *ctrlFlagsGM;
+    __gm__ int64_t *ctrlFlagsGMTemp;
+    __gm__ int64_t *ctrlFlagsGMTemp1;
+    int64_t buffOffsetNum;
+    __gm__ T *processOutput;
+    int64_t outputOffsetNum;
+
+    if (x == 0) {
+        return;
+    }
+
+    int64_t multiple = GetDeterministicRankOffset(x);
+    if (x % 2 == 1) {
+        target = x - multiple;
+        ctrlFlagsGMTemp = (__gm__ int64_t*)buff[rank] + (blockNumPerGroup + target * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM;
+        ctrlFlagsGMTemp1 = (__gm__ int64_t*)buff[rank] + (blockNumPerGroup + x * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM;
+        buffOffsetNum = deterministicOffNum + x * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+        processOutput = (__gm__ T*)((__gm__ int64_t *)buff[rank] + dataOffsetNum);
+        outputOffsetNum = deterministicOffNum + target * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+    } else {
+        target = x - multiple;
+        ctrlFlagsGMTemp = (__gm__ int64_t*)buff[rank] + (blockNumPerGroup * 2 + (target + multiple / 2) * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM;
+
+        int64_t multipleTemp = multiple;
+        while (x + multipleTemp / 2 >= rankSize) {
+            multipleTemp /= 2;
+        }
+        if (multipleTemp > 0) {
+            if ((x + multipleTemp / 2) != x) {
+                ctrlFlagsGMTemp1 = (__gm__ int64_t*)buff[rank] +
+                    (blockNumPerGroup * 2 + (x + multipleTemp / 2) * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM;
+            } else {
+                ctrlFlagsGMTemp1 = (__gm__ int64_t*)buff[rank] +
+                    (blockNumPerGroup + x * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM;
+            }
+        }
+
+        buffOffsetNum = deterministicOffNum + x * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+        processOutput = (__gm__ T*)((__gm__ int64_t *)buff[rank] + dataOffsetNum);
+        outputOffsetNum = deterministicOffNum + target * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+    }
+    AscendC::PipeBarrier<PIPE_ALL>();
+
+    while (true) {
+        if (*ctrlFlagsUB >= allTimes) {
+            break;
+        }
+
+        CpGM2UB(ctrlFlagsUB1, ctrlFlagsGMTemp, sizeof(int64_t));
+        CpGM2UB(ctrlFlagsUB2, ctrlFlagsGMTemp1, sizeof(int64_t));
+        AscendC::PipeBarrier<PIPE_ALL>();
+        if ((*ctrlFlagsUB1 >> 10) != (magic >> 10) || (*ctrlFlagsUB2 >> 10) != (magic >> 10)) {
+            continue;
+        }
+
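+        // Both producers publish "magic | groupCount"; keep the smaller prepared
+        // count so we only reduce data that both sources have finished writing.
+        *ctrlFlagsUB1 = ((*ctrlFlagsUB1 & 0x3FF) <= (*ctrlFlagsUB2 & 0x3FF)) ? 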
*ctrlFlagsUB1 : *ctrlFlagsUB2; + AscendC::PipeBarrier(); + + int64_t preparedDataGroupCount = (*ctrlFlagsUB1 & 0x3FF); + if (*ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeed2Add) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + ProcessDataNew(dataSizeRemain, inputUB, buff[rank], dataOffsetNum, buffOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), + processOutput, outputOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), op); + AscendC::PipeBarrier(); + *ctrlFlagsUB = preparedDataGroupCount; + CpUB2GM((__gm__ int64_t *) buff[rank] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM, ctrlFlagsUB1, sizeof(int64_t)); + AscendC::PipeBarrier(); + } +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduceDeterministicBigDataOrigin( + __gm__ T* buff[8], __gm__ T *input, __gm__ T *output, int64_t blockNumPerGroup, uint32_t rank, uint32_t rankSize, + uint64_t len, int64_t magic, __ubuf__ int64_t* ctrlFlagsUB, __ubuf__ int64_t* ctrlFlagsUB1, + __ubuf__ int64_t* ctrlFlagsUB2, __ubuf__ T* inputUB[2], int64_t dataOffsetNum, int64_t flagOffset1st, + int64_t flagOffset2nd, int64_t x, int64_t corePerRank, int64_t coreSegmentedIdx, int op) +{ + const int64_t singleNPUProcessDataNum = len / rankSize; + int64_t thisNPUProcessDataNum = singleNPUProcessDataNum; + if (rank == rankSize - 1) { + thisNPUProcessDataNum = len - rank * singleNPUProcessDataNum; + } + + int64_t xNPUProcessDataNum = singleNPUProcessDataNum; + if (x == rankSize - 1) { + xNPUProcessDataNum = len - x * singleNPUProcessDataNum; + } + + const int64_t xNPUCoreGroupAvgDMADataNum = xNPUProcessDataNum / corePerRank; + const int64_t thisNPUCoreGroupAvgDMADataNum = thisNPUProcessDataNum / corePerRank; + + int64_t dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + int64_t buffOffsetNum = x * singleNPUProcessDataNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + int64_t deterministicOffNum = len; + + if (GetBlockIdx() < blockNumPerGroup) { + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset1st; + input2BuffRankMagic(dataSizeRemain, inputUB[0], receiveBuff, buffOffsetNum, input, buffOffsetNum, ctrlFlagsUB, ctrlFlagsGM, magic); + return; + } + + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + int64_t allDataSizeNeed2Add = thisNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + allDataSizeNeed2Add = (thisNPUProcessDataNum - coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + int64_t allTimes = CeilDiv(allDataSizeNeed2Add, DMA_SIZE_PER_FLAG); + + if (GetBlockIdx() < blockNumPerGroup * 2) { + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM; + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM; + + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t *)buff[rank] + dataOffsetNum); + __gm__ T *sendBuff = (__gm__ T *)((__gm__ int64_t *)buff[x] + dataOffsetNum); + AscendC::PipeBarrier(); + while (true) { + if (*ctrlFlagsUB >= allTimes) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagsGMX, 
sizeof(int64_t)); + AscendC::PipeBarrier(); + if ((*ctrlFlagsUB1 >> 10) != (magic >> 10)) { + continue; + } + + int64_t preparedDataGroupCount = *ctrlFlagsUB1 & 0x3FF; + if (*ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + buffOffsetNum = rank * singleNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeed2Add) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + int64_t revBuffOffsetNum = deterministicOffNum + x * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + + GM2GMPingPong(dataSizeRemain, inputUB, receiveBuff, revBuffOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), + sendBuff, buffOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T)); + AscendC::PipeBarrier(); + *ctrlFlagsUB = preparedDataGroupCount; + if (x == 0) { + CpUB2GM((__gm__ int64_t *) buff[rank] + (GetBlockIdx() + blockNumPerGroup) * MEM_DMA_UNIT_INT_NUM, ctrlFlagsUB1, sizeof(int64_t)); + } + CpUB2GM(ctrlFlagsGM, ctrlFlagsUB1, sizeof(int64_t)); + AscendC::SetFlag(EVENT_ID0); + AscendC::WaitFlag(EVENT_ID0); + } + } + + if (GetBlockIdx() >= blockNumPerGroup * 2) { + if (x == 0) { + return; + } + + if (rankSize >= 4) { + AscendC::PipeBarrier(); + SumByPairsBigData(ctrlFlagsUB, ctrlFlagsUB1, ctrlFlagsUB2, buff, x, blockNumPerGroup, corePerRank, coreSegmentedIdx, magic, deterministicOffNum, thisNPUProcessDataNum, + thisNPUCoreGroupAvgDMADataNum, dataOffsetNum, dataSizeRemain, inputUB, op, rank, rankSize, allTimes, allDataSizeNeed2Add); + } else { + __gm__ int64_t* ctrlFlagsGMPre = (__gm__ int64_t*)buff[rank] + (blockNumPerGroup * 2 + (x - 1) * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (blockNumPerGroup + x * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + buffOffsetNum = deterministicOffNum + x * thisNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + __gm__ T *processOutput = (__gm__ T*)((__gm__ int64_t *)buff[rank] + dataOffsetNum); + int64_t outputOffsetNum = deterministicOffNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + AscendC::PipeBarrier(); + while (true) { + if (*ctrlFlagsUB >= allTimes) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagsGMPre, sizeof(int64_t)); + CpGM2UB(ctrlFlagsUB2, ctrlFlagsGM, sizeof(int64_t)); + AscendC::PipeBarrier(); + if ((*ctrlFlagsUB1 >> 10) != (magic >> 10) || (*ctrlFlagsUB2 >> 10) != (magic >> 10)) { + continue; + } + + *ctrlFlagsUB1 = ((*ctrlFlagsUB1 & 0x3FF) <= (*ctrlFlagsUB2 & 0x3FF)) ? 
*ctrlFlagsUB1 : *ctrlFlagsUB2; + AscendC::PipeBarrier(); + int64_t preparedDataGroupCount = (*ctrlFlagsUB1 & 0x3FF); + if (*ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeed2Add) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + ProcessDataNew(dataSizeRemain, inputUB, buff[rank], dataOffsetNum, buffOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), + processOutput, outputOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T), op); + AscendC::PipeBarrier(); + *ctrlFlagsUB = preparedDataGroupCount; + CpUB2GM((__gm__ int64_t *) buff[rank] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM, ctrlFlagsUB1, sizeof(int64_t)); + AscendC::PipeBarrier(); + } + } + SetFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[rank] + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM, magic); + return; + } + + __gm__ int64_t* ctrlFlagsGMX; + if (rankSize >= 4) { + ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + + (GetLcalBlockNum() + 2 * blockNumPerGroup + + (rankSize > 4 ? 4 : 2) * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + } else { + ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + + (GetLcalBlockNum() + 2 * blockNumPerGroup + + (rankSize - 1) * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + } + + constexpr int32_t lastFlagPos = 8; + constexpr int32_t sumPairGroup = 2; + if (rankSize > lastFlagPos) { + ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + MEM_DMA_UNIT_INT_NUM * + (GetLcalBlockNum() + sumPairGroup * blockNumPerGroup + lastFlagPos * corePerRank + coreSegmentedIdx); + } + + dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + buffOffsetNum = coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMX, (int64_t)magic); + + __gm__ T *sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + int64_t revBuffOffsetNum = x * singleNPUProcessDataNum + buffOffsetNum; + int64_t sendBuffOffsetNum = deterministicOffNum + buffOffsetNum; + + GM2GMPingPong(dataSizeRemain, inputUB, output, revBuffOffsetNum, sendBuff, sendBuffOffsetNum); + return; +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduceDeterministicBigData(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + magic <<= 10; + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + constexpr int32_t maxBuffSize = 16; + __gm__ T* buff[maxBuffSize] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7, + buff8, buff9, buff10, buff11, + buff12, buff13, buff14, buff15 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(96), (__ubuf__ T*)(97440)}; + + int64_t blockNumPerGroup = GetLcalBlockNum() / 3; + int64_t corePerRank = blockNumPerGroup / rankSize; + int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + + int64_t x = GetBlockIdx() / corePerRank; + if (GetBlockIdx() >= blockNumPerGroup && GetBlockIdx() < 2 * blockNumPerGroup) { + x = (GetBlockIdx() - 
blockNumPerGroup) / corePerRank; + flagOffset1st = (GetBlockIdx() - blockNumPerGroup) * MEM_DMA_UNIT_INT_NUM; + } else if (GetBlockIdx() >= 2 * blockNumPerGroup) { + x = (GetBlockIdx() - blockNumPerGroup * 2) / corePerRank; + flagOffset1st = (GetBlockIdx() - blockNumPerGroup * 2) * MEM_DMA_UNIT_INT_NUM; + } + int64_t flagOffset2nd = GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM + flagOffset1st; + + int64_t ipcBuffMaxNum = IPC_BUFF_MAX_SIZE / sizeof(T); + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + int64_t ipcBuffDeterministicMaxNum = DETERMINISTIC_BUFF_SIZE / sizeof(T); + int64_t loopTimes = CeilDiv(len, ipcBuffDeterministicMaxNum); + for (int64_t i = 0; i < loopTimes; i++) { + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + AscendC::PipeBarrier(); + + int64_t processedNum = i * ipcBuffDeterministicMaxNum; + int64_t remainNum = (len - processedNum < ipcBuffDeterministicMaxNum) ? len - processedNum : ipcBuffDeterministicMaxNum; + + PostSyncBigData(ctrlFlagsUB, buff, rank, rankSize, dataOffsetNum, ipcBuffMaxNum, magic, i); + LcalAllReduceDeterministicBigDataOrigin( + buff, input + processedNum, output + processedNum, blockNumPerGroup, rank, rankSize, remainNum, (magic + i) << 10, ctrlFlagsUB, ctrlFlagsUB1, + ctrlFlagsUB2, inputUB, dataOffsetNum, flagOffset1st, flagOffset2nd, x, corePerRank, coreSegmentedIdx, op); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allreduce_two_shot.cce b/comm/lcal/src/kernels/lcal_allreduce_two_shot.cce new file mode 100644 index 0000000000000000000000000000000000000000..0c9b4a65c3718d1bec0610415df95a87f335150e --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_two_shot.cce @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#include "collectives.cce"
+
+template <typename T>
+__attribute__((always_inline)) inline __aicore__ void LcalAllReduceTwoShot(ALLREDUCE_ARGS_FUN_16P(T))
+{
+    DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast<uint8_t>(op));
+    DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast<uint8_t>(op));
+    const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM;
+    const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx();
+    const int64_t flagOffset2nd = MEM_DMA_UNIT_INT_NUM * GetLcalBlockNum() + flagOffset1st;
+    const int64_t corePerRank = GetLcalBlockNum() / rankSize;
+    const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank;
+    const int64_t x = GetBlockIdx() / corePerRank;
+    __gm__ T* buff[16] = {
+        buff0, buff1, buff2, buff3,
+        buff4, buff5, buff6, buff7,
+        buff8, buff9, buff10, buff11,
+        buff12, buff13, buff14, buff15
+    };
+    __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0);
+    __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(64), (__ubuf__ T*)(97312)};
+
+    const int64_t memDmaUnitNum = MEM_DMA_UNIT_BYTE / sizeof(T);
+    const int64_t singleNPUProcessDataBlockNum = len / memDmaUnitNum / rankSize;
+    const int64_t singleNPUProcessDataNum = singleNPUProcessDataBlockNum * memDmaUnitNum;
+    int64_t thisNPUProcessDataNum = singleNPUProcessDataNum;
+    if (rank == rankSize - 1) {
+        thisNPUProcessDataNum = len - rank * singleNPUProcessDataNum;
+    }
+
+    int64_t xNPUProcessDataNum = singleNPUProcessDataNum;
+    if (x == rankSize - 1) {
+        xNPUProcessDataNum = len - x * singleNPUProcessDataNum;
+    }
+
+    const int64_t xNPUCoreGroupAvgDMADataNum = xNPUProcessDataNum / corePerRank / memDmaUnitNum * memDmaUnitNum;
+    const int64_t thisNPUCoreGroupAvgDMADataNum = thisNPUProcessDataNum / corePerRank / memDmaUnitNum * memDmaUnitNum;
+
+    int64_t dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T);
+    if (coreSegmentedIdx == corePerRank - 1) {
+        dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T);
+    }
+
+    int64_t buffOffsetNum = x * singleNPUProcessDataNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum;
+    DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast<uint8_t>(op));
+
+    DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast<uint8_t>(op));
+    __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset1st;
+    if (input != (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum)) {
+        __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum);
+        GM2GM(dataSizeRemain, inputUB[0], receiveBuff, buffOffsetNum, input, buffOffsetNum);
+        AscendC::PipeBarrier<PIPE_ALL>();
+
+        SyncWithinNPU(ctrlFlagsUB, (__gm__ int64_t *)((__gm__ T *)((__gm__ int64_t *)buff[rank] + dataOffsetNum) + len) + MEM_DMA_UNIT_INT_NUM, magic);
+        AscendC::PipeBarrier<PIPE_ALL>();
+
+        SetFlag(ctrlFlagsUB, ctrlFlagsGM, magic);
+    }
+
+    __gm__ T *processOutput = (__gm__ T *)(((__gm__ int64_t *)buff[rank]) + dataOffsetNum);
+
+    if (x == rank) {
+        goto label0;
+    }
+    buffOffsetNum = rank * singleNPUProcessDataNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum;
+    dataSizeRemain = thisNPUCoreGroupAvgDMADataNum * sizeof(T);
+    if (coreSegmentedIdx == corePerRank - 1) {
+        dataSizeRemain = (thisNPUProcessDataNum - coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum) * sizeof(T);
+    }
+
+    AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
+    AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID0);
+    ctrlFlagsGM = ((__gm__ int64_t*)buff[x]) + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM;
+    if (input != (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum)) {
+        CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGM, (int64_t)magic);
+    }
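+    // Shot 1: reduce the peer's slice from its IPC window into this rank's accumulation area.
+    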
ProcessData(dataSizeRemain, inputUB[0], buff[x], dataOffsetNum, buffOffsetNum, processOutput, buffOffsetNum, op); + +label0: + ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset2nd; + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGM, (int64_t)magic); + + for (int i = 0; i < GetLcalBlockNum(); i++) { + if (i / corePerRank == x) { + continue; + } + __gm__ int64_t* ctrlFlagsGMTemp = ((__gm__ int64_t*)buff[x] + (GetLcalBlockNum() + i) * MEM_DMA_UNIT_INT_NUM); + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGMTemp, (int64_t)magic); + } + + buffOffsetNum = x * singleNPUProcessDataNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + __gm__ T *sendBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + GM2GM(dataSizeRemain, inputUB[0], output, buffOffsetNum, sendBuff, buffOffsetNum); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_allreduce_two_shot_910B2C.cce b/comm/lcal/src/kernels/lcal_allreduce_two_shot_910B2C.cce new file mode 100644 index 0000000000000000000000000000000000000000..fe0f3ebdb85204c18ff76e10c5f5511df9e6993a --- /dev/null +++ b/comm/lcal/src/kernels/lcal_allreduce_two_shot_910B2C.cce @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalAllReduceTwoShot910B2C(ALLREDUCE_ARGS_FUN_16P(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + const int64_t singleNodeRankSize = rankSize >> 1; + const int64_t localNodeRankId = rank >= singleNodeRankSize ? rank - singleNodeRankSize : rank; + + const int64_t peerRankId = rank < singleNodeRankSize ? 
rank + singleNodeRankSize : rank - singleNodeRankSize; + + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + const int64_t flagOffset2nd = MEM_DMA_UNIT_INT_NUM * GetLcalBlockNum() + flagOffset1st; + const int64_t corePerRank = GetLcalBlockNum() / rankSize; + + const int64_t x = GetBlockIdx() / corePerRank; + const int64_t xLocalRankId = x % singleNodeRankSize; + const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + __gm__ T* buff[16] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7, + buff8, buff9, buff10, buff11, + buff12, buff13, buff14, buff15 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(64), (__ubuf__ T*)(97312)}; + + const int64_t oneNPUProcessDataAvgNum = len / singleNodeRankSize; + int64_t thisNPUProcessDataNum = oneNPUProcessDataAvgNum; + if (localNodeRankId == singleNodeRankSize - 1) { + thisNPUProcessDataNum = len - localNodeRankId * oneNPUProcessDataAvgNum; + } + + int64_t xNPUProcessDataNum = oneNPUProcessDataAvgNum; + if (xLocalRankId == singleNodeRankSize - 1) { + xNPUProcessDataNum = len - xLocalRankId * oneNPUProcessDataAvgNum; + } + + const int64_t xNPUCoreGroupAvgDMADataNum = xNPUProcessDataNum / corePerRank; + const int64_t thisNPUCoreGroupAvgDMADataNum = thisNPUProcessDataNum / corePerRank; + + int64_t dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + if ((rank < singleNodeRankSize && x < singleNodeRankSize) || + (rank >= singleNodeRankSize && x >= singleNodeRankSize)) { + __gm__ int64_t* ctrlFlagsGMSet = (__gm__ int64_t*)buff[rank] + (xLocalRankId + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + + int64_t sendBuffOffsetNum = xLocalRankId * oneNPUProcessDataAvgNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + int64_t revBuffOffsetNum = xLocalRankId * oneNPUProcessDataAvgNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, revBuffOffsetNum, input, sendBuffOffsetNum); + SetFlag(ctrlFlagsUB, ctrlFlagsGMSet, magic); + + dataSizeRemain = thisNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (thisNPUProcessDataNum - coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + if (rank != x) { + CheckFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + (localNodeRankId + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM, magic); + CheckFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[x] + (localNodeRankId + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM, magic); + sendBuffOffsetNum = localNodeRankId * oneNPUProcessDataAvgNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + revBuffOffsetNum = localNodeRankId * oneNPUProcessDataAvgNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + ProcessData(dataSizeRemain, inputUB[0], buff[x], dataOffsetNum, sendBuffOffsetNum, receiveBuff, revBuffOffsetNum, op); + } + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + (xLocalRankId + singleNodeRankSize + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM, magic); + if (rank == x) { + for (int i = 0; i < singleNodeRankSize; i++) { + if ((xLocalRankId + singleNodeRankSize 
+ coreSegmentedIdx) == + (i * corePerRank + singleNodeRankSize + coreSegmentedIdx)) { + continue; + } + CheckFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + (i * corePerRank + singleNodeRankSize + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM, magic); + } + + receiveBuff = ((__gm__ T*)((__gm__ int64_t*)buff[peerRankId] + dataOffsetNum)) + len; + sendBuffOffsetNum = localNodeRankId * oneNPUProcessDataAvgNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + revBuffOffsetNum = coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + GM2GM(dataSizeRemain, inputUB[0], receiveBuff, revBuffOffsetNum, (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum), sendBuffOffsetNum); + + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[peerRankId] + (rankSize + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM, magic); + + CheckFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + (rankSize + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM, magic); + + revBuffOffsetNum = localNodeRankId * oneNPUProcessDataAvgNum + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum; + ProcessData(dataSizeRemain, inputUB[0], buff[rank], dataOffsetNum, len + coreSegmentedIdx * thisNPUCoreGroupAvgDMADataNum, + (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum), revBuffOffsetNum, op); + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + (rankSize + corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM, magic); + } + + CheckFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[x] + (rankSize + corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM, magic); + + int64_t dataSizeRemain = xNPUCoreGroupAvgDMADataNum * sizeof(T); + if (coreSegmentedIdx == corePerRank - 1) { + dataSizeRemain = (xNPUProcessDataNum - coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum) * sizeof(T); + } + + sendBuffOffsetNum = xLocalRankId * oneNPUProcessDataAvgNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum;; + revBuffOffsetNum = xLocalRankId * oneNPUProcessDataAvgNum + coreSegmentedIdx * xNPUCoreGroupAvgDMADataNum; + GM2GM(dataSizeRemain, inputUB[0], output, revBuffOffsetNum, (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum), sendBuffOffsetNum); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::OVERALL, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_broadcast_big_data.cce b/comm/lcal/src/kernels/lcal_broadcast_big_data.cce new file mode 100644 index 0000000000000000000000000000000000000000..5debb811ade5b0c461c58a23b858956a76eb53ef --- /dev/null +++ b/comm/lcal/src/kernels/lcal_broadcast_big_data.cce @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#include <cstdint>
+#include "collectives.cce"
+
+__attribute__((always_inline)) inline __aicore__ void LcalBroadcastOrigin(ALLREDUCE_ARGS_FUN(char))
+{
+    uint32_t blockSize = UB_SINGLE_DMA_SIZE_MAX;
+    int64_t magicInner = magic << SYNC_FLAG_BIT_NUM;
+    __gm__ char* buff[8] = {
+        buff0, buff1, buff2, buff3,
+        buff4, buff5, buff6, buff7
+    };
+    __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t *)(0);
+    __ubuf__ char* inputUB = (__ubuf__ char*)(64);
+    len = CeilDiv(len, 2) * 2;
+    const int64_t groupSize = rankSize - 1;
+    const int64_t groupDataSize = blockSize * groupSize;
+    const int64_t groupNum = CeilDiv(len, groupDataSize);
+    const int64_t blockTotalNum = CeilDiv(len, blockSize);
+    const int64_t lastBlockNum = CeilDiv(len - (groupNum - 1) * groupDataSize, blockSize);
+    if (GetBlockIdx() == rank || GetBlockIdx() >= rankSize) {
+        return;
+    }
+    int64_t inputIndex = GetBlockIdx();
+    if (GetBlockIdx() == root) {
+        if (rank > root) {
+            inputIndex = rank - 1;
+        } else {
+            inputIndex = rank;
+        }
+    } else if (GetBlockIdx() > root) {
+        inputIndex = GetBlockIdx() - 1;
+    }
+    int64_t blockDataOffset;
+    int64_t remain;
+    for (int64_t currentCount = inputIndex; currentCount < blockTotalNum; currentCount += groupSize) {
+        blockDataOffset = currentCount * blockSize;
+        remain = blockSize;
+        if (currentCount == blockTotalNum - 1) {
+            remain = len - (blockTotalNum - 1) * blockSize;
+        }
+        if (rank == root) {
+            CopyInput2BuffBroadCast(inputUB, buff[rank], (__gm__ char*)input, remain, blockDataOffset);
+            SetFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[rank] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM, magicInner + currentCount);
+        } else {
+            if (GetBlockIdx() == root) {
+                CheckFlagGE(ctrlFlagsUB, ((__gm__ int64_t *)buff[GetBlockIdx()] + rank * MEM_DMA_UNIT_INT_NUM),
+                    magicInner + currentCount);
+            } else {
+                CheckFlagGE(ctrlFlagsUB, ((__gm__ int64_t *)buff[GetBlockIdx()] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM),
+                    magicInner + currentCount);
+            }
+            AscendC::PipeBarrier<PIPE_ALL>();
+
+            if (remain > 0) {
+                CpGM2UBAlignB16(inputUB, (__gm__ char*)((__gm__ int64_t *)buff[GetBlockIdx()] + GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM) + blockDataOffset, remain);
+                AscendC::PipeBarrier<PIPE_ALL>();
+                CpUB2GMAlignB16((__gm__ char*)output + blockDataOffset, inputUB, remain);
+                if (GetBlockIdx() == root) {
+                    CpUB2GMAlignB16((__gm__ char*)((__gm__ int64_t *)buff[rank] + GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM) + blockDataOffset, inputUB, remain);
+                    SetFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[rank] + rank * MEM_DMA_UNIT_INT_NUM, magicInner + currentCount);
+                }
+            }
+        }
+    }
+    if (rank != root) {
+        SetFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[GetBlockIdx()] + (GetLcalBlockNum() + rank) * MEM_DMA_UNIT_INT_NUM, magic);
+    }
+
+    if (rank == root) {
+        CheckFlag(ctrlFlagsUB, ((__gm__ int64_t *)buff[rank] + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM),
+            magic);
+    } else if (GetBlockIdx() == root) {
+        for (int64_t i = 0; i < rankSize; ++i) {
+            if (i == rank || i == root) {
+                continue;
+            }
+            CheckFlag(ctrlFlagsUB, ((__gm__ int64_t *)buff[rank] + (GetLcalBlockNum() + i) * MEM_DMA_UNIT_INT_NUM), magic);
+        }
+    }
+}
+
+__attribute__((always_inline)) inline __aicore__ void LcalBroadcastBigData(ALLREDUCE_ARGS_FUN(char))
+{
+    magic = magic << SYNC_FLAG_BIT_NUM;
+    __gm__ char* buff[8] = {
+        buff0, buff1, buff2, buff3,
+        buff4, buff5, buff6, buff7
+    };
+    __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0);
+
+    const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM;
+    const int64_t postSyncFlagIdx = MEM_DMA_UNIT_INT_NUM + 
(GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM; + const int64_t loopNum = CeilDiv(len, IPC_BUFF_MAX_SIZE); + + for (int64_t i = 0; i < loopNum; i++) { + int64_t processedNum = i * IPC_BUFF_MAX_SIZE; + int64_t remainNum = (len - processedNum < IPC_BUFF_MAX_SIZE) ? len - processedNum : IPC_BUFF_MAX_SIZE; + if (i > 0) { + SyncWithinNPUNew(ctrlFlagsUB, (__gm__ int64_t *)((__gm__ char *)buff[rank] + IPC_BUFF_MAX_SIZE) + dataOffsetNum + MEM_DMA_UNIT_INT_NUM, magic + i); + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t *)((__gm__ char *)buff[rank] + IPC_BUFF_MAX_SIZE) + dataOffsetNum + postSyncFlagIdx; + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGM, (int64_t)magic + i); + + for (int64_t targetNPU = 0; targetNPU < rankSize; targetNPU++) { + if (targetNPU == rank) { + continue; + } + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t *)((__gm__ char *)buff[targetNPU] + IPC_BUFF_MAX_SIZE) + dataOffsetNum + postSyncFlagIdx; + CheckFlagNew(ctrlFlagsUB, ctrlFlagsGMX, (int64_t)magic + i); + } + } + LcalBroadcastOrigin( + input + processedNum, output + processedNum, rank, rankSize, remainNum, magic + i, 0, root, localRankSize, + loopTime, sendCountMatrix, dumpAddr, buff0, buff1, buff2, buff3, buff4, buff5, buff6, buff7); + AscendC::PipeBarrier(); + } +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_broadcast_write.cce b/comm/lcal/src/kernels/lcal_broadcast_write.cce new file mode 100644 index 0000000000000000000000000000000000000000..838372a06288533079adc693ffc2ec2eca17e9a8 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_broadcast_write.cce @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#include <cstdint>
+#include "collectives.cce"
+
+template <typename T>
+inline __aicore__ void GM2GMB8(int64_t dataSizeRemain, __ubuf__ T *inputUB, __gm__ T *receiveBuff, int64_t revBuffOffsetNum,
+    __gm__ T *sendBuff, int64_t sendBuffOffsetNum)
+{
+    int64_t times = 0;
+    AscendC::PipeBarrier<PIPE_ALL>();
+    while (dataSizeRemain >= UB_SINGLE_DMA_SIZE_MAX) {
+        AscendC::PipeBarrier<PIPE_ALL>();
+        CpGM2UBAlignB16(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times,
+            UB_SINGLE_DMA_SIZE_MAX);
+        AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
+        AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
+        AscendC::PipeBarrier<PIPE_ALL>();
+        CpUB2GMAlignB16(
+            (__gm__ T*)receiveBuff + revBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times,
+            inputUB, UB_SINGLE_DMA_SIZE_MAX);
+        AscendC::SetFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
+        AscendC::WaitFlag<AscendC::HardEvent::MTE3_MTE2>(EVENT_ID1);
+        AscendC::PipeBarrier<PIPE_ALL>();
+        times += 1;
+        dataSizeRemain -= UB_SINGLE_DMA_SIZE_MAX;
+    }
+    if (dataSizeRemain <= 0) {
+        return;
+    }
+    AscendC::PipeBarrier<PIPE_ALL>();
+    CpGM2UBAlignB16(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T),
+        dataSizeRemain);
+    AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
+    AscendC::WaitFlag<AscendC::HardEvent::MTE2_MTE3>(EVENT_ID0);
+    AscendC::PipeBarrier<PIPE_ALL>();
+    CpUB2GMAlignB16(
+        (__gm__ T*)receiveBuff + revBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T),
+        inputUB, dataSizeRemain);
+    AscendC::PipeBarrier<PIPE_ALL>();
+}
+
+
+extern "C" __global__ __aicore__ void LcalBroadcastWrite(ALLREDUCE_ARGS_FUN(char))
+{
+    const int64_t corePerRank = GetLcalBlockNum() / rankSize;
+    const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank;
+    const int64_t x = GetBlockIdx() / corePerRank;
+    if (x >= rankSize) {
+        return;
+    }
+    if (rank != root && x != rank) {
+        return;
+    }
+    if (rank == root && x == root) {
+        return;
+    }
+
+    __gm__ char* buff[8] = {
+        buff0, buff1, buff2, buff3,
+        buff4, buff5, buff6, buff7
+    };
+    __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0);
+    __ubuf__ char* inputUB = (__ubuf__ char*)(64);
+
+    const int64_t singleCoreDataSize = len / corePerRank;
+    int64_t dataNumRemain = singleCoreDataSize;
+    int64_t buffOffsetNum = coreSegmentedIdx * singleCoreDataSize;
+    if (coreSegmentedIdx == corePerRank - 1) {
+        dataNumRemain = len - buffOffsetNum;
+    }
+    if (rank == root) {
+        __gm__ char *receiveBuff = (__gm__ char*)((__gm__ int64_t*)buff[x] + GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM);
+        GM2GMB8(dataNumRemain, inputUB, receiveBuff, buffOffsetNum, input, buffOffsetNum);
+        SetFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[x] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM, magic);
+        CheckFlag(ctrlFlagsUB, ((__gm__ int64_t*)buff[rank] + (GetBlockIdx() + GetLcalBlockNum()) * MEM_DMA_UNIT_INT_NUM), magic);
+    } else {
+        CheckFlag(ctrlFlagsUB, ((__gm__ int64_t*)buff[rank] + GetBlockIdx() * MEM_DMA_UNIT_INT_NUM), magic);
+        __gm__ char *sendBuff = (__gm__ char*)((__gm__ int64_t*)buff[x] + GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM);
+        __gm__ char *receiveBuff = (__gm__ char*)output;
+        GM2GMB8(dataNumRemain, inputUB, receiveBuff, buffOffsetNum, sendBuff, buffOffsetNum);
+        SetFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[root] + (GetBlockIdx() + GetLcalBlockNum()) * MEM_DMA_UNIT_INT_NUM, magic);
+    }
+}
+
+__attribute__((always_inline)) inline __aicore__ void LcalBroadcast2npuBigDataWriteOrigin(ALLREDUCE_ARGS_FUN(char))
+{
+    uint32_t blockSize = UB_SINGLE_DMA_SIZE_MAX;
+    int64_t magicInner = magic << SYNC_FLAG_BIT_NUM;
+    __gm__ char* buff[8] = {
+        buff0, buff1, buff2, buff3,
+        buff4, buff5, buff6, buff7
+    };
+    __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t *)(0);
+    __ubuf__ char* inputUB = (__ubuf__ char*)(64);
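+    // Payloads move in UB_SINGLE_DMA_SIZE_MAX chunks via B16-aligned DMAs
+    // (CpGM2UBAlignB16/CpUB2GMAlignB16), so round the byte count up to an even value first.
+    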
len = CeilDiv(len, 2) * 2; + const int64_t groupSize = rankSize - 1; + const int64_t groupDataSize = blockSize * groupSize; + const int64_t groupNum = CeilDiv(len, groupDataSize); + const int64_t blockTotalNum = CeilDiv(len, blockSize); + const int64_t lastBlockNum = CeilDiv(len - (groupNum - 1) * groupDataSize, blockSize); + if (GetBlockIdx() == rank || GetBlockIdx() >= rankSize) { + return; + } + int64_t inputIndex = GetBlockIdx(); + if (GetBlockIdx() == root) { + if (rank > root) { + inputIndex = rank - 1; + } else { + inputIndex = rank; + } + } else if (GetBlockIdx() > root) { + inputIndex = GetBlockIdx() - 1; + } + int64_t blockDataOffset; + int64_t remain; + if (rank == root) { + for (int64_t currentCount = inputIndex; currentCount < blockTotalNum; currentCount += groupSize) { + blockDataOffset = currentCount * blockSize; + remain = (currentCount == blockTotalNum - 1) ? (len - (blockTotalNum - 1) * blockSize) : blockSize; + CopyInput2BuffBroadCast(inputUB, buff[GetBlockIdx()], (__gm__ char*)input, remain, blockDataOffset); + SetFlag(ctrlFlagsUB, (__gm__ int64_t * )buff[GetBlockIdx()] + root * MEM_DMA_UNIT_INT_NUM, magicInner + currentCount); + } + } else { + for (int64_t currentCount = inputIndex; currentCount < blockTotalNum; currentCount += groupSize) { + blockDataOffset = currentCount * blockSize; + remain = (currentCount == blockTotalNum - 1) ? (len - (blockTotalNum - 1) * blockSize) : blockSize; + CheckFlagGE(ctrlFlagsUB, ((__gm__ int64_t * )buff[rank] + root * MEM_DMA_UNIT_INT_NUM), magicInner + currentCount); + AscendC::PipeBarrier(); + + if (remain > 0) { + CpGM2UB(inputUB, (__gm__ char*)((__gm__ int64_t * )buff[rank] + GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM) + blockDataOffset, remain); + AscendC::PipeBarrier(); + CpUB2GM((__gm__ char*)output + blockDataOffset, inputUB, remain); + } + } + } + SetFlag(ctrlFlagsUB, (__gm__ int64_t * )buff[GetBlockIdx()] + (GetLcalBlockNum() + root) * MEM_DMA_UNIT_INT_NUM, magic); + CheckFlag(ctrlFlagsUB, ((__gm__ int64_t * )buff[rank] + (GetLcalBlockNum() + root) * MEM_DMA_UNIT_INT_NUM), magic); +} + +__attribute__((always_inline)) inline __aicore__ void LcalBroadcast2npuBigDataWrite(ALLREDUCE_ARGS_FUN(char)) +{ + magic = magic << SYNC_FLAG_BIT_NUM; + __gm__ char* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + const int64_t postSyncFlagIdx = MEM_DMA_UNIT_INT_NUM + (GetLcalBlockNum() + GetBlockIdx()) * MEM_DMA_UNIT_INT_NUM; + const int64_t loopNum = CeilDiv(len, IPC_BUFF_MAX_SIZE); + + for (int64_t i = 0; i < loopNum; i++) { + int64_t processedNum = i * IPC_BUFF_MAX_SIZE; + int64_t remainNum = (len - processedNum < IPC_BUFF_MAX_SIZE) ? 
len - processedNum : IPC_BUFF_MAX_SIZE; + if (i > 0) { + SyncWithinNPUNew(ctrlFlagsUB, (__gm__ int64_t *)((__gm__ char *)buff[rank] + IPC_BUFF_MAX_SIZE) + dataOffsetNum + MEM_DMA_UNIT_INT_NUM, magic + i); + + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t *)((__gm__ char *)buff[rank] + IPC_BUFF_MAX_SIZE) + dataOffsetNum + postSyncFlagIdx; + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, ctrlFlagsGM, (int64_t)magic + i); + + for (int64_t targetNPU = 0; targetNPU < rankSize; targetNPU++) { + if (targetNPU == rank) { + continue; + } + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t *)((__gm__ char *)buff[targetNPU] + IPC_BUFF_MAX_SIZE) + dataOffsetNum + postSyncFlagIdx; + CheckFlagNew(ctrlFlagsUB, ctrlFlagsGMX, (int64_t)magic + i); + } + } + LcalBroadcast2npuBigDataWriteOrigin( + input + processedNum, output + processedNum, rank, rankSize, remainNum, magic + i, 0, root, + localRankSize, loopTime, sendCountMatrix, dumpAddr, + buff0, buff1, buff2, buff3, buff4, buff5, buff6, buff7); + AscendC::PipeBarrier(); + } +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_reduce_scatter.cce b/comm/lcal/src/kernels/lcal_reduce_scatter.cce new file mode 100644 index 0000000000000000000000000000000000000000..eaae5176aa9db3f844ad7375e03dd6b9e53145b8 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_reduce_scatter.cce @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#include "collectives.cce"
+
+template <typename T>
+__attribute__((always_inline)) inline __aicore__ void CpInputToBuffAndOutput(__ubuf__ T** inputUB, __gm__ T* buff, __gm__ T* input, __gm__ T* output,
+    int64_t dataOffsetNum, int64_t dataNumDMARemain, int64_t inputOffset,
+    int64_t outputOffsetNum, int32_t rank, int64_t corePerRank,
+    int64_t UB_SINGLE_DMA_NUM_MAX)
+{
+    int64_t dataProcessingBatchTime = 0;
+    while (dataNumDMARemain >= UB_SINGLE_DMA_NUM_MAX) {
+        CpGM2UB(inputUB[0], input + inputOffset + UB_SINGLE_DMA_NUM_MAX * dataProcessingBatchTime,
+            UB_SINGLE_DMA_SIZE_MAX);
+        AscendC::PipeBarrier<PIPE_ALL>();
+        if (GetBlockIdx() >= rank * corePerRank && (GetBlockIdx() < (rank * corePerRank + corePerRank))) {
+            CpUB2GM((__gm__ T *)output + outputOffsetNum + UB_SINGLE_DMA_NUM_MAX * dataProcessingBatchTime,
+                inputUB[0], UB_SINGLE_DMA_SIZE_MAX);
+        } else {
+            CpUB2GM(
+                (__gm__ T *)((__gm__ int64_t *)buff + dataOffsetNum) + inputOffset + UB_SINGLE_DMA_NUM_MAX * dataProcessingBatchTime,
+                inputUB[0], UB_SINGLE_DMA_SIZE_MAX);
+        }
+        AscendC::PipeBarrier<PIPE_ALL>();
+        dataNumDMARemain -= UB_SINGLE_DMA_NUM_MAX;
+        dataProcessingBatchTime += 1;
+        AscendC::PipeBarrier<PIPE_ALL>();
+    }
+    if (dataNumDMARemain <= 0) {
+        return;
+    }
+    CpGM2UB(inputUB[0], input + inputOffset + UB_SINGLE_DMA_NUM_MAX * dataProcessingBatchTime,
+        dataNumDMARemain * sizeof(T));
+    AscendC::PipeBarrier<PIPE_ALL>();
+    if (GetBlockIdx() >= rank * corePerRank && (GetBlockIdx() < (rank * corePerRank + corePerRank))) {
+        CpUB2GM((__gm__ T *)output + outputOffsetNum + UB_SINGLE_DMA_NUM_MAX * dataProcessingBatchTime,
+            inputUB[0], dataNumDMARemain * sizeof(T));
+        AscendC::PipeBarrier<PIPE_ALL>();
+    } else {
+        CpUB2GM(
+            (__gm__ T *)((__gm__ int64_t *)buff + dataOffsetNum) + inputOffset + UB_SINGLE_DMA_NUM_MAX * dataProcessingBatchTime,
+            inputUB[0], dataNumDMARemain * sizeof(T));
+    }
+}
+
+template <typename T>
+__attribute__((always_inline)) inline __aicore__ void LcalReduceScatter(ALLREDUCE_ARGS_FUN(T))
+{
+    DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast<uint8_t>(op));
+    const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM;
+    const int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx();
+    __gm__ T* buff[8] = {
+        buff0, buff1, buff2, buff3,
+        buff4, buff5, buff6, buff7
+    };
+    __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(64), (__ubuf__ T*)(98304)};
+    __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0);
+    const int64_t flagOffset2nd = MEM_DMA_UNIT_INT_NUM * GetLcalBlockNum() + flagOffset1st;
+    const int64_t UB_SINGLE_DMA_NUM_MAX = UB_SINGLE_DMA_SIZE_MAX / sizeof(T);
+
+    const int64_t corePerRank = GetLcalBlockNum() / rankSize;
+    const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank;
+    const int64_t inputNum = len * rankSize;
+    const int64_t dataDMAPerCore = CeilDiv(len, corePerRank);
+    const int64_t inputOffset = GetBlockIdx() / corePerRank * len + coreSegmentedIdx * dataDMAPerCore;
+
+    int64_t dataNumDMARemain = dataDMAPerCore;
+    int64_t oneNPUProcessNum = len;
+    int64_t oneCoreProcessNum = CeilDiv(len, corePerRank);
+    const int64_t outputOffsetNum = oneCoreProcessNum * (GetBlockIdx() % corePerRank);
+    int64_t dataSizeRemain = oneCoreProcessNum * sizeof(T);
+    if (coreSegmentedIdx == corePerRank - 1) {
+        dataNumDMARemain = len - coreSegmentedIdx * dataDMAPerCore;
+        dataSizeRemain = (len - coreSegmentedIdx * oneCoreProcessNum) * sizeof(T);
+    }
+
+    DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast<uint8_t>(op));
+    DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast<uint8_t>(op));
+    AscendC::PipeBarrier<PIPE_ALL>();
+    CpInputToBuffAndOutput(inputUB, buff[rank], input, output, dataOffsetNum, 
dataNumDMARemain, + inputOffset, outputOffsetNum, rank, corePerRank, UB_SINGLE_DMA_NUM_MAX); + SyncWithinNPU(ctrlFlagsUB, (__gm__ int64_t *)((__gm__ T *)((__gm__ int64_t *)buff[rank] + dataOffsetNum) + inputNum) + MEM_DMA_UNIT_INT_NUM, magic); + + SetFlag(ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + flagOffset1st, (int64_t)magic); + const int64_t x = GetBlockIdx() / corePerRank; + AscendC::PipeBarrier(); + if (x == rank) { + SetFlag((__ubuf__ int64_t*)ctrlFlagsUB, (__gm__ int64_t*)buff[rank] + flagOffset2nd, (int64_t)magic); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + return; + } + const int64_t buffOffsetNum = rank * oneNPUProcessNum + outputOffsetNum; + + CheckFlag((__ubuf__ int64_t*)ctrlFlagsUB, + (__gm__ int64_t*)buff[x] + (rank * corePerRank + (GetBlockIdx() % corePerRank)) * MEM_DMA_UNIT_INT_NUM, + (int64_t)magic); + + ProcessData(dataSizeRemain, inputUB[0], buff[x], dataOffsetNum, buffOffsetNum, output, outputOffsetNum, op); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_reduce_scatter_big_data.cce b/comm/lcal/src/kernels/lcal_reduce_scatter_big_data.cce new file mode 100644 index 0000000000000000000000000000000000000000..414d690e717228809f8a0d8a49f85a8fe577c3d5 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_reduce_scatter_big_data.cce @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
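Note: in LcalReduceScatter above, the GetLcalBlockNum() cores are split evenly across peers (corePerRank), each core handles CeilDiv(len, corePerRank) elements, and the last core of a group absorbs the remainder. The same partitioning arithmetic as a standalone sketch (hypothetical helper, not part of the kernel):

#include <algorithm>
#include <cstdint>
#include <utility>

inline int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Slice of `len` elements owned by core `coreIdx` out of the `corePerRank` cores serving one peer.
inline std::pair<int64_t, int64_t> CoreSlice(int64_t len, int64_t corePerRank, int64_t coreIdx)
{
    int64_t per = CeilDiv(len, corePerRank);
    int64_t off = coreIdx * per;
    int64_t cnt = std::max<int64_t>(0, std::min(per, len - off)); // last core takes the remainder
    return {off, cnt};
}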
+ */ +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalReduceScatterBigDataOrigin( + __gm__ T* buff[8], __gm__ T *input, __gm__ T *output, uint64_t processedNum, int64_t blockNumPerGroup, uint32_t rank, + uint32_t rankSize, int64_t allLen, int64_t len, int64_t magic, __ubuf__ int64_t* ctrlFlagsUB, __ubuf__ int64_t* ctrlFlagsUB1, + __ubuf__ int64_t* ctrlFlagsUB2, __ubuf__ T *inputUB[2], int64_t dataOffsetNum, int64_t flagOffset1st, int64_t flagOffset2nd, + int64_t x, int64_t corePerRank, int64_t coreSegmentedIdx, int op) +{ + const int64_t inputNum = len * rankSize; + const int64_t avgNumDMAPerCore = len / corePerRank; + int64_t dataNumRemain = avgNumDMAPerCore; + + int64_t inputOffsetNum = coreSegmentedIdx * avgNumDMAPerCore; + if (coreSegmentedIdx == corePerRank - 1) { + dataNumRemain = len - inputOffsetNum; + } + if (dataNumRemain <= 0) { + return; + } + + if (GetBlockIdx() < blockNumPerGroup) { + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset1st; + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[rank] + dataOffsetNum); + int64_t ipcBuffOffsetNum = x * len + inputOffsetNum; + input2BuffRankMagic(dataNumRemain * sizeof(T), inputUB[0], receiveBuff, ipcBuffOffsetNum, + input, inputOffsetNum, ctrlFlagsUB, ctrlFlagsGM, magic); + return; + } + + if (x == rank) { + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + flagOffset2nd; + input2BuffRankMagic(dataNumRemain * sizeof(T), inputUB[0], output, inputOffsetNum, + input, inputOffsetNum, ctrlFlagsUB, ctrlFlagsGM, magic); + return; + } + + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (coreSegmentedIdx + rank * corePerRank + GetLcalBlockNum()) * MEM_DMA_UNIT_INT_NUM; + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM; + int64_t allDataSizeNeed2Add = dataNumRemain * sizeof(T); + AscendC::PipeBarrier(); + while (true) { + if (*ctrlFlagsUB * DMA_SIZE_PER_FLAG >= allDataSizeNeed2Add) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagsGM, sizeof(int64_t)); + CpGM2UB(ctrlFlagsUB2, ctrlFlagsGMX, sizeof(int64_t)); + AscendC::PipeBarrier(); + + if (((*ctrlFlagsUB1 >> 10) != (magic >> 10)) || ((*ctrlFlagsUB2 >> 10) != (magic >> 10))) { + continue; + } + if (*ctrlFlagsUB1 == 0 || *ctrlFlagsUB2 == 0) { + continue; + } + + int64_t preparedDataGroupCount = (*ctrlFlagsUB1 <= *ctrlFlagsUB2) ? 
(*ctrlFlagsUB1 - magic) : (*ctrlFlagsUB2 - magic); + if (preparedDataGroupCount <= 0 || *ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + int64_t dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeed2Add) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + int64_t ipcBuffOffsetNum = rank * len + inputOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T); + int64_t outputOffsetNum = inputOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T); + + ProcessDataNew(dataSizeRemain, inputUB, buff[x], dataOffsetNum, ipcBuffOffsetNum, output, outputOffsetNum, op); + AscendC::PipeBarrier(); + + *ctrlFlagsUB = preparedDataGroupCount; + AscendC::PipeBarrier(); + } +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalReduceScatterBigData(ALLREDUCE_ARGS_FUN(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + magic *= 1024; + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(96), (__ubuf__ T*)(97440)}; + + int64_t blockNumPerGroup = GetLcalBlockNum() >> 1; + int64_t corePerRank = blockNumPerGroup / rankSize; + int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + int64_t x = GetBlockIdx() / corePerRank; + if (GetBlockIdx() >= blockNumPerGroup) { + x = (GetBlockIdx() - blockNumPerGroup) / corePerRank; + flagOffset1st = (GetBlockIdx() - blockNumPerGroup) * MEM_DMA_UNIT_INT_NUM; + } + int64_t flagOffset2nd = GetLcalBlockNum() * MEM_DMA_UNIT_INT_NUM + flagOffset1st; + + int64_t ipcBuffMaxNum = IPC_BUFF_MAX_SIZE / sizeof(T); + int64_t ipcBuffMaxNumPerRank = ipcBuffMaxNum / rankSize; + int64_t dataLen = len; + + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + for (int64_t i = 0; i < CeilDiv(dataLen, ipcBuffMaxNumPerRank); i++) { + *ctrlFlagsUB = 0; + AscendC::PipeBarrier(); + + int64_t processedNum = i * ipcBuffMaxNumPerRank; + int64_t remainNum = (dataLen - processedNum < ipcBuffMaxNumPerRank) ? dataLen - processedNum : ipcBuffMaxNumPerRank; + + PostSyncBigData(ctrlFlagsUB, buff, rank, rankSize, dataOffsetNum, ipcBuffMaxNum, magic, i); + LcalReduceScatterBigDataOrigin( + buff, input + len * x + processedNum, output + processedNum, processedNum, blockNumPerGroup, rank, rankSize, + len, remainNum, (magic + i) * 1024, ctrlFlagsUB, ctrlFlagsUB1, ctrlFlagsUB2, inputUB, dataOffsetNum, + flagOffset1st, flagOffset2nd, x, corePerRank, coreSegmentedIdx, op); + AscendC::PipeBarrier(); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_reduce_scatter_big_data_write.cce b/comm/lcal/src/kernels/lcal_reduce_scatter_big_data_write.cce new file mode 100644 index 0000000000000000000000000000000000000000..c28659003d97d00505bc8fcbf6351225221b52d6 --- /dev/null +++ b/comm/lcal/src/kernels/lcal_reduce_scatter_big_data_write.cce @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
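Note: the big-data path multiplies magic by 1024 (2^10), which frees the low 10 bits of each flag to carry a progress counter: a producer publishes magic + groupsReady, and a consumer recognizes the current round via flag >> 10. A sketch of that encoding as I read it from the polling loop above (my interpretation, not an authoritative spec):

#include <cstdint>

constexpr int64_t kMagicShift = 10; // magic *= 1024 reserves the low 10 bits for the counter

inline int64_t EncodeFlag(int64_t magic, int64_t groupsReady) { return magic + groupsReady; }
inline bool SameRound(int64_t flag, int64_t magic) { return (flag >> kMagicShift) == (magic >> kMagicShift); }
inline int64_t GroupsReady(int64_t flag, int64_t magic) { return flag - magic; } // valid only if SameRound() holds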
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "collectives.cce" + +template +__attribute__((always_inline)) inline __aicore__ void LcalReduceScatterBigDataWriteOrigin( + __gm__ T* buff[8], __gm__ T *input, __gm__ T *output, uint64_t processedNum, int64_t blockNumPerGroup, uint32_t rank, + uint32_t rankSize, int64_t allLen, int64_t len, int64_t magic, __ubuf__ int64_t* ctrlFlagsUB, __ubuf__ int64_t* ctrlFlagsUB1, + __ubuf__ int64_t* ctrlFlagsUB2, __ubuf__ T *inputUB[2], int64_t dataOffsetNum,int64_t x, int64_t corePerRank, + int64_t coreSegmentedIdx, int op) +{ + const int64_t inputNum = len * rankSize; + const int64_t avgNumDMAPerCore = len / corePerRank; + int64_t dataNumRemain = avgNumDMAPerCore; + + int64_t inputOffsetNum = coreSegmentedIdx * avgNumDMAPerCore; + if (coreSegmentedIdx == corePerRank - 1) { + dataNumRemain = len - inputOffsetNum; + } + if (dataNumRemain <= 0) { + return; + } + + if (GetBlockIdx() < blockNumPerGroup) { + if (rank == x) { + return; + } + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[x] + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM; + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + int64_t receiveBuffOffsetNum = rank * len + inputOffsetNum; + input2BuffRankMagic(dataNumRemain * sizeof(T), inputUB[0], receiveBuff, receiveBuffOffsetNum, + input, inputOffsetNum, ctrlFlagsUB, ctrlFlagsGMX, magic); + return; + } + + if (x == rank) { + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM; + input2BuffRankMagic(dataNumRemain * sizeof(T), inputUB[0], output, inputOffsetNum, + input, inputOffsetNum, ctrlFlagsUB, ctrlFlagsGM, magic); + return; + } + + *ctrlFlagsUB = 0; + *ctrlFlagsUB1 = 0; + *ctrlFlagsUB2 = 0; + __gm__ int64_t* ctrlFlagsGM = (__gm__ int64_t*)buff[rank] + (coreSegmentedIdx + rank * corePerRank) * MEM_DMA_UNIT_INT_NUM; + __gm__ int64_t* ctrlFlagsGMX = (__gm__ int64_t*)buff[rank] + (coreSegmentedIdx + x * corePerRank) * MEM_DMA_UNIT_INT_NUM; + int64_t allDataSizeNeed2Add = dataNumRemain * sizeof(T); + AscendC::PipeBarrier(); + while (true) { + if (*ctrlFlagsUB * DMA_SIZE_PER_FLAG >= allDataSizeNeed2Add) { + break; + } + + CpGM2UB(ctrlFlagsUB1, ctrlFlagsGM, sizeof(int64_t)); + CpGM2UB(ctrlFlagsUB2, ctrlFlagsGMX, sizeof(int64_t)); + AscendC::PipeBarrier(); + + if (((*ctrlFlagsUB1 >> 10) != (magic >> 10)) || ((*ctrlFlagsUB2 >> 10) != (magic >> 10))) { + continue; + } + if (*ctrlFlagsUB1 == 0 || *ctrlFlagsUB2 == 0) { + continue; + } + + int64_t preparedDataGroupCount = (*ctrlFlagsUB1 <= *ctrlFlagsUB2) ? 
(*ctrlFlagsUB1 - magic) : (*ctrlFlagsUB2 - magic); + if (preparedDataGroupCount <= 0 || *ctrlFlagsUB >= preparedDataGroupCount) { + continue; + } + + int64_t dataSizeRemain = (preparedDataGroupCount - *ctrlFlagsUB) * DMA_SIZE_PER_FLAG; + if (preparedDataGroupCount * DMA_SIZE_PER_FLAG > allDataSizeNeed2Add) { + dataSizeRemain = allDataSizeNeed2Add - *ctrlFlagsUB * DMA_SIZE_PER_FLAG; + } + int64_t ipcBuffOffsetNum = x * len + inputOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T); + int64_t outputOffsetNum = inputOffsetNum + (*ctrlFlagsUB) * DMA_SIZE_PER_FLAG / sizeof(T); + + ProcessDataNew(dataSizeRemain, inputUB, buff[rank], dataOffsetNum, ipcBuffOffsetNum, output, outputOffsetNum, op); + AscendC::PipeBarrier(); + + *ctrlFlagsUB = preparedDataGroupCount; + AscendC::PipeBarrier(); + } + SetFlag(ctrlFlagsUB, ctrlFlagsGM, 0); + SetFlag(ctrlFlagsUB1, ctrlFlagsGMX, 0); +} + +template +__attribute__((always_inline)) inline __aicore__ void LcalReduceScatterBigDataWrite(ALLREDUCE_ARGS_FUN(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + magic *= 1024; + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + int64_t flagOffset1st = MEM_DMA_UNIT_INT_NUM * GetBlockIdx(); + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t* ctrlFlagsUB = (__ubuf__ int64_t*)(0); + __ubuf__ int64_t* ctrlFlagsUB1 = (__ubuf__ int64_t*)(32); + __ubuf__ int64_t* ctrlFlagsUB2 = (__ubuf__ int64_t*)(64); + __ubuf__ T* inputUB[2] = {(__ubuf__ T*)(96), (__ubuf__ T*)(97440)}; + + int64_t blockNumPerGroup = GetLcalBlockNum() >> 1; + int64_t corePerRank = blockNumPerGroup / rankSize; + int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + int64_t x = GetBlockIdx() / corePerRank; + if (GetBlockIdx() >= blockNumPerGroup) { + x = (GetBlockIdx() - blockNumPerGroup) / corePerRank; + } + + int64_t ipcBuffMaxNum = IPC_BUFF_MAX_SIZE / sizeof(T); + int64_t ipcBuffMaxNumPerRank = ipcBuffMaxNum / rankSize; + int64_t dataLen = len; + + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + for (int64_t i = 0; i < CeilDiv(dataLen, ipcBuffMaxNumPerRank); i++) { + *ctrlFlagsUB = 0; + AscendC::PipeBarrier(); + + int64_t processedNum = i * ipcBuffMaxNumPerRank; + int64_t remainNum = (dataLen - processedNum < ipcBuffMaxNumPerRank) ? dataLen - processedNum : ipcBuffMaxNumPerRank; + + PostSyncBigData(ctrlFlagsUB, buff, rank, rankSize, dataOffsetNum, ipcBuffMaxNum, magic, i); + LcalReduceScatterBigDataWriteOrigin( + buff, input + len * x + processedNum, output + processedNum, processedNum, blockNumPerGroup, rank, rankSize, len, remainNum, (magic + i) * 1024, ctrlFlagsUB, ctrlFlagsUB1, + ctrlFlagsUB2, inputUB, dataOffsetNum, x, corePerRank, coreSegmentedIdx, op); + AscendC::PipeBarrier(); + } + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/kernels/lcal_reduce_scatter_write.cce b/comm/lcal/src/kernels/lcal_reduce_scatter_write.cce new file mode 100644 index 0000000000000000000000000000000000000000..350e3ab4cdc36bcf6ebae6ab70945f4a8aa3736b --- /dev/null +++ b/comm/lcal/src/kernels/lcal_reduce_scatter_write.cce @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
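Note: in the write variant above, a consumer core polls two flags, one written locally and one pushed by the peer, and may only reduce as many DMA groups as the slower producer has published. The progress rule, reduced to a sketch (assumes both flags already passed the same-round check):

#include <algorithm>
#include <cstdint>

// Number of DMA_SIZE_PER_FLAG-sized groups that are safe to consume right now.
inline int64_t ConsumableGroups(int64_t flagLocal, int64_t flagPeer, int64_t magic)
{
    return std::min(flagLocal, flagPeer) - magic; // the slower producer bounds progress
}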
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "collectives.cce" + +template <typename T> +inline __aicore__ void GM2GMAndOutput(int64_t dataSizeRemain, __ubuf__ T *inputUB, __gm__ T *receiveBuff, int64_t revBuffOffsetNum, + __gm__ T *sendBuff, int64_t sendBuffOffsetNum, bool needDMA2Output, __gm__ T *output, int64_t outputOffsetNum) +{ + int64_t times = 0; + while (dataSizeRemain >= UB_SINGLE_DMA_SIZE_MAX) { + CpGM2UB(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + UB_SINGLE_DMA_SIZE_MAX); + AscendC::SetFlag<HardEvent::MTE2_MTE3>(EVENT_ID0); + AscendC::WaitFlag<HardEvent::MTE2_MTE3>(EVENT_ID0); + CpUB2GM( + (__gm__ T*)receiveBuff + revBuffOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + inputUB, UB_SINGLE_DMA_SIZE_MAX); + + if (needDMA2Output) { + CpUB2GM( + (__gm__ T*)output + outputOffsetNum + UB_SINGLE_DMA_SIZE_MAX / sizeof(T) * times, + inputUB, UB_SINGLE_DMA_SIZE_MAX); + } + AscendC::SetFlag<HardEvent::MTE3_MTE2>(EVENT_ID1); + AscendC::WaitFlag<HardEvent::MTE3_MTE2>(EVENT_ID1); + times += 1; + dataSizeRemain -= UB_SINGLE_DMA_SIZE_MAX; + } + if (dataSizeRemain <= 0) { + return; + } + CpGM2UB(inputUB, (__gm__ T*)sendBuff + sendBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + dataSizeRemain); + AscendC::SetFlag<HardEvent::MTE2_MTE3>(EVENT_ID0); + AscendC::WaitFlag<HardEvent::MTE2_MTE3>(EVENT_ID0); + CpUB2GM( + (__gm__ T*)receiveBuff + revBuffOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + inputUB, dataSizeRemain); + if (needDMA2Output) { + CpUB2GM( + (__gm__ T*)output + outputOffsetNum + times * UB_SINGLE_DMA_SIZE_MAX / sizeof(T), + inputUB, dataSizeRemain); + } + AscendC::PipeBarrier<PIPE_ALL>(); +} + +template <typename T> +inline __aicore__ void LcalReduceScatterWrite(ALLREDUCE_ARGS_FUN(T)) +{ + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + const int64_t corePerRank = GetLcalBlockNum() / rankSize; + const int64_t coreSegmentedIdx = GetBlockIdx() % corePerRank; + const int64_t inputNum = len * rankSize; + const int64_t x = GetBlockIdx() / corePerRank; + + __gm__ T* buff[8] = { + buff0, buff1, buff2, buff3, + buff4, buff5, buff6, buff7 + }; + __ubuf__ int64_t *ctrlFlagsUB = (__ubuf__ int64_t *)(0); + __ubuf__ T *inputUB[2] = {(__ubuf__ T *)(32), (__ubuf__ T *)(98304)}; + + const int64_t dataOffsetNum = GetLcalBlockNum() * 2 * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset1st = (rank * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + const int64_t flagOffset2nd = (x * corePerRank + coreSegmentedIdx) * MEM_DMA_UNIT_INT_NUM; + + const int64_t dataDMAPerCore = CeilDiv(len, corePerRank); + int64_t buffDMAOffsetNum = coreSegmentedIdx * dataDMAPerCore; + int64_t dataNumDMARemain = dataDMAPerCore; + if (coreSegmentedIdx == corePerRank - 1) { + dataNumDMARemain = len - buffDMAOffsetNum; + } + DumpLcclLogInfo(dumpAddr, LogId::INIT, static_cast(op)); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + __gm__ T *receiveBuff = (__gm__ T*)((__gm__ int64_t*)buff[x] + dataOffsetNum); + GM2GMAndOutput(dataNumDMARemain * sizeof(T), inputUB[0], receiveBuff, rank * len + buffDMAOffsetNum, + input, x * len + buffDMAOffsetNum, (x == rank), output, buffDMAOffsetNum); + SetFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[x] + flagOffset1st, magic); + AscendC::PipeBarrier<PIPE_ALL>(); + + CheckFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[rank] +
flagOffset2nd, magic); + CheckFlag(ctrlFlagsUB, (__gm__ int64_t *)buff[rank] + flagOffset1st, magic); + const int64_t buffOffsetNum = x * len + buffDMAOffsetNum; + if (x == rank) { + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); + return; + } + ProcessData(dataNumDMARemain * sizeof(T), inputUB[0], buff[rank], dataOffsetNum, buffOffsetNum, output, buffDMAOffsetNum, op); + DumpLcclLogInfo(dumpAddr, LogId::PROCESS, static_cast(op)); +} \ No newline at end of file diff --git a/comm/lcal/src/lcal_comm.cpp b/comm/lcal/src/lcal_comm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b54380a0c0ca6f4a9109041183a4cc4fb357e993 --- /dev/null +++ b/comm/lcal/src/lcal_comm.cpp @@ -0,0 +1,833 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include "lcal_internal.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "mki/utils/log/log.h" +#include "mki/utils/env/env.h" +#include "tools/socket/lcal_sock_exchange.h" + +#include "runtime/kernel.h" +#include "runtime/mem.h" +#include "runtime/dev.h" +#include "runtime/rt_ffts.h" +#include "profiling/report_timing.h" + +constexpr int AI_CORE_NUM_24 = 24; +constexpr int AI_CORE_NUM_20 = 20; +constexpr int AI_CORE_NUM_2 = 2; + +enum TopologyType : int { + TOPOLOGY_HCCS = 0, + TOPOLOGY_PIX, + TOPOLOGY_PIB, + TOPOLOGY_PHB, + TOPOLOGY_SYS, + TOPOLOGY_SIO, + TOPOLOGY_HCCS_SW +}; + +using namespace std; +using namespace chrono; +using namespace Mki; + +namespace Lcal { +constexpr int HCCL_IPC_PID_ARRAY_SIZE = 1; +constexpr int LCAL_INIT_TIMEOUT = 600; + +static map g_localPeerMemMap; +static map g_devList; +static std::mutex g_mtx; + +static const std::unordered_map CHIP_MAP = { + {"Ascend310P", ChipName::CHIP_310P3}, + {"Ascend910B1", ChipName::CHIP_910B1}, + {"Ascend910B2", ChipName::CHIP_910B2}, + {"Ascend910B2C", ChipName::CHIP_910B2C}, + {"Ascend910B3", ChipName::CHIP_910B3}, + {"Ascend910B4", ChipName::CHIP_910B4}, + {"Ascend910B4-1", ChipName::CHIP_910B41}, + {"Ascend910_9391", ChipName::CHIP_910_9391}, + {"Ascend910_9381", ChipName::CHIP_910_9381}, + {"Ascend910_9392", ChipName::CHIP_910_9392}, + {"Ascend910_9382", ChipName::CHIP_910_9382}, + {"Ascend910_9372", ChipName::CHIP_910_9372}, + {"Ascend910_9361", ChipName::CHIP_910_9361}, + {"Ascend910_9362", ChipName::CHIP_910_9362} +}; + +ChipName GetChipName() +{ + static ChipName curChipName = ChipName::RESERVED; + if (curChipName != ChipName::RESERVED) { + return curChipName; + } + constexpr int socVerLength = 100; + char ver[socVerLength]; + auto ret = rtGetSocVersion(ver, socVerLength); + if (ret != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "rtGetSocVersion failed, not sure whether the function is normal, please use it with caution"; + return ChipName::RESERVED; + } + string chipName(ver); + MKI_LOG(DEBUG) << "rtGetSocVersion -- The result after converting ver to string is:" << chipName; + + auto it = CHIP_MAP.find(chipName); + if (it != 
CHIP_MAP.end()) { + curChipName = it->second; + } else { + MKI_LOG(WARN) << "There is no commitment to the supported chip types yet," << + " and it is not certain whether the functions will work properly."; + } + return curChipName; +} + +uint32_t GetCoreNum(ChipName chipName) +{ + switch (chipName) { + case ChipName::CHIP_910B1: + case ChipName::CHIP_910B2: + case ChipName::CHIP_910_9391: + case ChipName::CHIP_910_9381: + case ChipName::CHIP_910_9392: + case ChipName::CHIP_910_9382: + case ChipName::CHIP_910B2C: + return AI_CORE_NUM_24; + case ChipName::CHIP_910B3: + case ChipName::CHIP_910B4: + case ChipName::CHIP_910B41: + case ChipName::CHIP_910_9372: + case ChipName::CHIP_910_9361: + case ChipName::CHIP_910_9362: + case ChipName::CHIP_910A5: + return AI_CORE_NUM_20; + case ChipName::CHIP_310P3: + return AI_CORE_NUM_2; + default: + MKI_LOG(ERROR) << "Unknown chip name"; + return 0; + } +} + +bool SkipUnusedChannel910B2C(int curRank, int peerRank, ChipName chipName) +{ + if (chipName == ChipName::CHIP_910B2C) { + constexpr int rankSizePerNode = 8; + if ((curRank / rankSizePerNode != peerRank / rankSizePerNode) && + (std::abs(curRank - peerRank) != rankSizePerNode)) { + return true; + } + } + return false; +} + +int LcalComm::InitDumpAddr() +{ + constexpr uint32_t dumpCoreCnt = 75; + constexpr uint32_t dumpSizePerCore = 1 * 1024 * 1024; + constexpr uint32_t dumpWorkspaceSize = dumpCoreCnt * dumpSizePerCore; + GM_ADDR dumpAddr = nullptr; + int ret = 0; + ret = aclrtMalloc(reinterpret_cast(&dumpAddr), dumpWorkspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); + if (ret != ACL_SUCCESS) { + MKI_LOG(ERROR) << "aclrtMalloc err " << __LINE__ << " " << ret; + return LCAL_ERROR_INTERNAL; + } + aclrtMemset(dumpAddr, dumpWorkspaceSize, 0, dumpWorkspaceSize); + + GM_ADDR memory = static_cast(std::malloc(dumpWorkspaceSize)); + if (!memory) { + MKI_LOG(ERROR) << "std::malloc err " << __LINE__; + return LCAL_ERROR_INTERNAL; + } + errno_t result = memset_s(memory, dumpWorkspaceSize, 0, dumpWorkspaceSize); + if (result != 0) { + MKI_LOG(ERROR) << "memset_s err " << result; + } + for (uint32_t i = 0; i < dumpCoreCnt; ++i) { + GM_ADDR block_start = memory + i * dumpSizePerCore; + GM_ADDR deviceBlockStart = dumpAddr + i * dumpSizePerCore; + + LcclDumpBlockInfo* block_info = reinterpret_cast(block_start); + block_info->len = dumpSizePerCore; + block_info->core = i; + block_info->blockNum = 0; + block_info->dumpOffset = dumpSizePerCore - sizeof(LcclDumpBlockInfo); + block_info->magic = 0; + block_info->dumpAddr = reinterpret_cast(deviceBlockStart + sizeof(LcclDumpBlockInfo)); + } + + ret = aclrtMemcpy(dumpAddr, dumpWorkspaceSize, memory, dumpWorkspaceSize, ACL_MEMCPY_HOST_TO_DEVICE); + std::free(memory); + if (ret != ACL_SUCCESS) { + MKI_LOG(ERROR) << "aclrtMemcpy err " << __LINE__ << " " << ret; + return LCAL_ERROR_INTERNAL; + } + + commArgs_.dumpAddr = dumpAddr; + return LCAL_SUCCESS; +} + +int LcalComm::SyncCommArgs() +{ + commArgs_.rank = rank_; + commArgs_.localRank = localRank_; + commArgs_.rankSize = rankSize_; + commArgs_.localRankSize = localRankSize_; + for (int i = 0; i < rankSize_; ++i) { + commArgs_.peerMems[i] = peerMem_[i]; + } + + if (isEnableMsprofOp_ && InitDumpAddr() != LCAL_SUCCESS) { + return LCAL_ERROR_INTERNAL; + } + + if (isEnableMix_) { + uint64_t fftsVal = 0; + uint32_t fftsLen = 0; + int error = rtGetC2cCtrlAddr(&fftsVal, &fftsLen); + if (error != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "rtGetC2cCtrlAddr err:" << error; + return LCAL_ERROR_MKIRT; + } + commArgs_.fftsVal = fftsVal; + } + + int ret 
= 0; + ret = aclrtMalloc(reinterpret_cast(&commArgsPtr_), sizeof(commArgs_), ACL_MEM_MALLOC_HUGE_FIRST); + if (ret != ACL_SUCCESS) { + MKI_LOG(ERROR) << "aclrtMalloc err " << __LINE__ << " " << ret; + return LCAL_ERROR_INTERNAL; + } + ret = aclrtMemcpy(commArgsPtr_, sizeof(commArgs_), &commArgs_, sizeof(commArgs_), ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_SUCCESS) { + MKI_LOG(ERROR) << "aclrtMemcpy err " << __LINE__ << " " << ret; + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +int LcalComm::InitCommon() +{ + if (EnablePeerAccess() != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "EnablePeerAccess failed!"; + return LCAL_ERROR_INTERNAL; + } + const char *lcclDeterministic = Mki::GetEnv("LCCL_DETERMINISTIC"); + if (lcclDeterministic && (string(lcclDeterministic) == "1" || string(lcclDeterministic) == "true")) { + deterministic_ = true; + commArgs_.extraFlag |= ExtraFlag::DETERMINISTIC; + } + if (GetChipName() == ChipName::CHIP_910B2C) { + commArgs_.extraFlag |= ExtraFlag::TOPO_910B2C; + } + if (GetChipName() >= ChipName::CHIP_910_9391) { + commArgs_.extraFlag |= ExtraFlag::TOPO_910_93; + } + if (GetChipName() > ChipName::CHIP_910_9362) { + commArgs_.extraFlag |= ExtraFlag::TOPO_910A5; + } + if (GetCoreNum(GetChipName()) > AI_CORE_NUM_20) { + commArgs_.extraFlag |= ExtraFlag::IS_GREATER_THAN_40_AIV; + } + + ReportTiming report("LcclReporting", rank_, false, nullptr, nullptr); + MKI_LOG(INFO) << "LcalComm::InitCommon ReportTiming " << std::hex << ReportTiming::ProfilingStatus() << std::dec; + if (ReportTiming::ProfilingStatus() == ReportTiming::PROF_TASK_TIME_DUMP) { + isEnableMsprofOp_ = true; + isEnableMix_ = true; + } + + int32_t opGroup = 0; + if (isEnableMsprofOp_) { + opGroup = 0; + } else if (isEnableMix_) { + opGroup = 1; + } else { + constexpr int32_t normalOpGroup = 2; + opGroup = normalOpGroup; + } + MKI_LOG(INFO) << "LcalComm::InitCommon RegistKernel opGroup " << opGroup; + RegistKernel(opGroup); + + localRank_ = rank_ % localRankSize_; + return LCAL_SUCCESS; +} + +void LcalComm::CloseIpcMem() +{ + for (int i = 0; i < rankSize_; ++i) { + if (i == rank_ || peerMem_[i] == nullptr) { + continue; + } + + int ret = rtIpcCloseMemory(static_cast(peerMem_[i])); + if (ret != RT_ERROR_NONE) { + MKI_LOG(WARN) << "Close ipc[" << i << "] memory failed! ret: " << ret; + } + peerMem_[i] = nullptr; + } +} + +void LcalComm::FreePeerMem(GM_ADDR &mem) const +{ + if (mem != nullptr) { + aclError aclRet = aclrtFree(mem); + if (aclRet != ACL_SUCCESS) { + MKI_LOG(ERROR) << "Free share memory failed! ret: " << aclRet; + } + } + mem = nullptr; +} + +int LcalComm::Init() +{ + if (inited_) { + return LCAL_SUCCESS; + } + if (rank_ < 0 || rank_ >= rankSize_ || rankSize_ <= 0 || rankSize_ > LCAL_MAX_RANK_SIZE) { + MKI_LOG(ERROR) << "The rank is invalid! rank:" << rank_ << " rankSize:" << rankSize_; + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (LcalSockExchange::CheckValid(commId_)) { + socketExchange_ = new (nothrow) LcalSockExchange(rank_, rankSize_, commId_); + } else { + socketExchange_ = new (nothrow) LcalSockExchange(rank_, rankSize_, rankList_, commDomain_); + } + if (socketExchange_ == nullptr) { + MKI_LOG(ERROR) << "LcalSockExchange create failed. rank : " << rank_ << " rankSize:" << rankSize_; + return LCAL_ERROR_INTERNAL; + } + int ret = GetDev(); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "init context failed! 
ret: " << ret; + return ret; + } + + MKI_LOG(INFO) << "rank " << rank_ << "/" << rankSize_ << " running devId:" << devId_; + + if (InitCommon() != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "init common failed!"; + return LCAL_ERROR_INTERNAL; + } + + MKI_LOG(DEBUG) << "Prepare to InitCommMem localRankSize_ -> " << localRankSize_ << ", localRank_ -> " << localRank_; + if (InitCommMem() != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "InitCommMem failed!"; + return LCAL_ERROR_INTERNAL; + } + MKI_LOG(DEBUG) << "InitCommMem " << rank_ << "/" << rankSize_ << ", localRank_ : " << localRank_ << + ", localRankSize_ : " << localRankSize_ << " success"; + + SyncCommArgs(); + MKI_LOG(INFO) << "LcalCommInit " << rank_ << "/" << rankSize_ << " success and extraFlag:" << commArgs_.extraFlag << + " commArgs_.localRank : " << commArgs_.localRank << " commArgs_.localRankSize : " << commArgs_.localRankSize; + inited_ = true; + delete socketExchange_; + socketExchange_ = nullptr; + return LCAL_SUCCESS; +} + +int LcalComm::InitThread(const std::string &uid) +{ + if (inited_) { + return LCAL_SUCCESS; + } + if (rank_ < 0 || rank_ >= rankSize_ || rankSize_ <= 0 || rankSize_ > LCAL_MAX_RANK_SIZE) { + MKI_LOG(ERROR) << "The rank is invalid! rank:" << rank_ << "rankSize:" << rankSize_; + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (GetDevThread(uid) != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "get devs failed."; + return LCAL_ERROR_INTERNAL; + } + MKI_LOG(INFO) << "rank " << rank_ << "/" << rankSize_ << " running devId:" << devId_ << "uid: " << uid; + + if (InitCommon() != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "init common failed!"; + return LCAL_ERROR_INTERNAL; + } + { + lock_guard lock(g_mtx); + if (g_localPeerMemMap.find(uid) == g_localPeerMemMap.end()) { + for (int i = 0; i < rankSize_; ++i) { + g_localPeerMemMap[uid][i] = nullptr; + } + } + uid_ = uid; + } + InitMem(); + g_localPeerMemMap[uid][rank_] = peerMem_[rank_]; + + auto start = high_resolution_clock::now(); + for (int i = 0; i < rankSize_; ++i) { + while (g_localPeerMemMap[uid][i] == nullptr) { + this_thread::sleep_for(1ms); + auto elapsed = duration_cast(high_resolution_clock::now() - start); + if (elapsed.count() > LCAL_INIT_TIMEOUT) { + MKI_LOG(ERROR) << "Lccl Init timeout!"; + FreePeerMem(g_localPeerMemMap[uid][rank_]); + return LCAL_ERROR_TIMEOUT; + } + } + peerMem_[i] = g_localPeerMemMap[uid][i]; + } + localRank_ = rank_; + localRankSize_ = rankSize_; + SyncCommArgs(); + MKI_LOG(INFO) << "Lccl init multi thread " << rank_ << "/" << rankSize_ << " success, uid:" << uid; + inited_ = true; + return LCAL_SUCCESS; +} + +int LcalComm::EnablePeerAccess() +{ + physicalInfo_.chipName = GetChipName(); + for (auto &dev : devList_) { + if (devId_ == dev) { + continue; + } + if (SkipUnusedChannel910B2C(dev, devId_, GetChipName())) { + continue; + } + + int64_t value = 0; + if (rtGetPairDevicesInfo(devId_, dev, 0, &value) != RT_ERROR_NONE) { + MKI_LOG(WARN) << devId_ << " & " << dev << " pair devices info failed to get"; + } else { + MKI_LOG(DEBUG) << devId_ << " <-----> " << dev << ", halGetPairDevicesInfo: *value = " << value; + } + + if (value == TOPOLOGY_HCCS || value == TOPOLOGY_SIO || value == TOPOLOGY_HCCS_SW || + GetChipName() == ChipName::CHIP_910B2C) { + physicalInfo_.physicalLink = PhysicalLink::HCCS; + commArgs_.extraFlag &= ~(ExtraFlag::TOPO_PCIE); + } else if (physicalInfo_.physicalLink == PhysicalLink::RESERVED) { + physicalInfo_.physicalLink = PhysicalLink::PCIE; + commArgs_.extraFlag |= ExtraFlag::TOPO_PCIE; + if (rankSize_ > PING_PONG_SIZE) { + MKI_LOG(ERROR) << 
"do not support pcie > 2 rank! rankSize_ = " << rankSize_; + return LCAL_ERROR_INTERNAL; + } + } + + physicalInfo_.coreNum = GetCoreNum(physicalInfo_.chipName); + + if (physicalInfo_.chipName == ChipName::CHIP_310P3 && value == 0) { + MKI_LOG(WARN) << "warn aclrtDeviceEnablePeerAccess is skipped! peerDeviceId = " << dev; + continue; + } + + aclError ret = aclrtDeviceEnablePeerAccess(dev, 0); + if (ret != ACL_SUCCESS) { + MKI_LOG(ERROR) << "err aclrtDeviceEnablePeerAccess failed peerDeviceId = " << dev << " ,rank = " << rank_ + << ", value = " << value << ", flags = " << 0 << "," << __LINE__ << ": " << ret; + return LCAL_ERROR_INTERNAL; + } + } + MKI_LOG(DEBUG) << "EnablePeerAccess succeed" << rank_; + return LCAL_SUCCESS; +} + +int LcalComm::GetDev() +{ + int nodeNum = socketExchange_->GetNodeNum(); + if (nodeNum <= 0 || nodeNum > rankSize_) { + MKI_LOG(ERROR) << "error! node num : " << nodeNum << " rank size: " << rankSize_; + return LCAL_ERROR_INTERNAL; + } + localRankSize_ = rankSize_ / nodeNum; + localRank_ = rank_ % localRankSize_; + MKI_LOG(DEBUG) << "GetDev : localRankSize_ : " << localRankSize_ << " localRank_: " << localRank_ + << " rank :" << rank_ << " rankSize :" << rankSize_; + devList_.resize(rankSize_); + aclError aclRet = aclrtGetDevice(&devId_); + if (aclRet != ACL_SUCCESS) { + MKI_LOG(ERROR) << "aclrtGetDevice error! ret: " << aclRet; + return LCAL_ERROR_INTERNAL; + } + int ret = socketExchange_->AllGather(&devId_, 1, devList_.data()); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "LcalSockExchange AllGather error! ret: " << ret; + return LCAL_ERROR_INTERNAL; + } + std::string devIdStr = ""; + for (int i = 0; i < rankSize_; ++i) { + devIdStr += (i == 0 ? "" : ", "); + devIdStr += to_string(devList_[i]); + } + MKI_LOG(DEBUG) << "rank " << rank_ << " devId: " << devId_ << ", otherDevList : " << devIdStr; + MKI_LOG(INFO) << "AllGather: Get other rank dev id success"; + return LCAL_SUCCESS; +} + +int LcalComm::GetDevThread(const std::string &uid) +{ + devList_.resize(rankSize_); + aclError aclRet = aclrtGetDevice(&devId_); + if (aclRet != ACL_SUCCESS) { + MKI_LOG(ERROR) << "aclrtGetDevice error! ret: " << aclRet; + return LCAL_ERROR_INTERNAL; + } + { + std::lock_guard lock(g_mtx); + if (g_devList.find(uid) == g_devList.end()) { + for (int i = 0; i < rankSize_; ++i) { + g_devList[uid][i] = 0; + } + } + } + g_devList[uid][rank_] = devId_ + 1; + auto start = high_resolution_clock::now(); + for (int i = 0; i < rankSize_; ++i) { + while (g_devList[uid][i] == 0) { + this_thread::sleep_for(1ms); + auto elapsed = duration_cast(high_resolution_clock::now() - start); + if (elapsed.count() > LCAL_INIT_TIMEOUT) { + MKI_LOG(ERROR) << "Lccl Init timeout!"; + return LCAL_ERROR_TIMEOUT; + } + } + devList_.at(i) = g_devList[uid][i] - 1; + } + return LCAL_SUCCESS; +} + +int LcalComm::InitMem() +{ + constexpr int32_t bufferSizeUint = 1024 * 1024; + int lcalBuffSize = bufferSize_ * bufferSizeUint + LCAL_FLAG_BUFF_BYTES; + + MKI_LOG(DEBUG) << "lcal buffer size " << lcalBuffSize; + aclError ret = aclrtMalloc( + reinterpret_cast(&peerMem_[rank_]), lcalBuffSize, + (GetChipName() == ChipName::CHIP_310P3) ? 
ACL_MEM_MALLOC_HUGE_FIRST_P2P : ACL_MEM_MALLOC_HUGE_FIRST); + if (ret != ACL_SUCCESS) { + MKI_LOG(ERROR) << "allocate device mem error " << __FILE__ << ":" << __LINE__ << " " << ret; + return LCAL_ERROR_INTERNAL; + } + MKI_LOG(DEBUG) << "peerMem[rank" << rank_ << "], allocate finished."; + aclrtMemset(peerMem_[rank_], lcalBuffSize, 0, lcalBuffSize); + return LCAL_SUCCESS; +} + +int LcalComm::GetPid(uint32_t *pids) +{ + if (rtDeviceGetBareTgid(&pids[rank_]) != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "DeviceGetBareTgid err " << __LINE__; + return LCAL_ERROR_INTERNAL; + } + int ret = socketExchange_->AllGather(&pids[rank_], 1, pids); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "LcalSockExchange AllGather error! ret: " << ret; + return ret; + } + for (int i = 0; i < rankSize_; ++i) { + MKI_LOG(DEBUG) << "rank : " << rank_ << ", otherRank : " << i << " pid[" << i << "]: " << pids[i]; + } + MKI_LOG(DEBUG) << "AllGather: Get other rank pid"; + return LCAL_SUCCESS; +} + +int LcalComm::GetSidId(int64_t sdids[LCAL_MAX_RANK_SIZE], int rankSize) +{ + if (rank_ >= rankSize) { + MKI_LOG(ERROR) << "LcalComm::GetSidId err rank_ >= rankSize " << rank_ << ">=" << rankSize; + return LCAL_ERROR_INTERNAL; + } + if ((physicalInfo_.chipName >= ChipName::CHIP_910_9391) && (physicalInfo_.chipName < ChipName::RESERVED)) { + const int rtModuleTypeSystem = 0; + const int infoTypeSdid = 26; + if (rtGetDeviceInfo(devList_[rank_], rtModuleTypeSystem, infoTypeSdid, &sdids[rank_]) != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "DeviceGetDeviceInfo err " << __LINE__; + return LCAL_ERROR_INTERNAL; + } + MKI_LOG(DEBUG) << "rank " << rank_ << " dev id: " << devList_[rank_] + << " rtGetDeviceInfo sdid: " << sdids[rank_]; + + int ret = socketExchange_->AllGather(&sdids[rank_], 1, sdids); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "LcalSockExchange AllGather error! ret: " << ret; + return ret; + } + for (int i = 0; i < rankSize_; ++i) { + MKI_LOG(DEBUG) << "rank " << i << " sdid: " << sdids[i]; + } + MKI_LOG(DEBUG) << "AllGather: Get other rank sdid"; + } + return LCAL_SUCCESS; +} + +int LcalComm::GetName(string &name, char names[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE]) const +{ + int ret = socketExchange_->AllGather(name.c_str(), IPC_NAME_SIZE, names[0]); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "LcalSockExchange AllGather error! ret: " << ret; + return LCAL_ERROR_INTERNAL; + } + for (int i = 0; i < rankSize_; ++i) { + names[i][IPC_NAME_SIZE - 1] = '\0'; + MKI_LOG(DEBUG) << "rank " << i << " mem name: " << names[i]; + } + MKI_LOG(DEBUG) << "AllGather: Get other rank mem name"; + return LCAL_SUCCESS; +} + +int LcalComm::InitCommMem() +{ + int ret = InitMem(); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "InitMem error! ret: " << ret; + return ret; + } + + uint32_t pids[LCAL_MAX_RANK_SIZE] = {0}; + ret = GetPid(pids); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "GetPid error! ret: " << ret; + return ret; + } + + int64_t sdids[LCAL_MAX_RANK_SIZE] = {0}; + ret = GetSidId(sdids, rankSize_); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "GetSidId error! 
ret: " << ret; + return ret; + } + + string name; + if (SetMemoryName(name) != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "SetMemoryName err "; + return LCAL_ERROR_INTERNAL; + } + + if (SetIpcPidSdid(name, pids, sdids) != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "SetIpcPidSdid failed!"; + return LCAL_ERROR_INTERNAL; + } + + MKI_LOG(DEBUG) << "rank " << rank_ << " mem name: " << name; + char names[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE]; + ret = GetName(name, names); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "GetName error! ret: " << ret; + return ret; + } + + if (OpenIpcMem(names) != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "rank: " << rank_ << " OpenIpcMem failed!"; + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +int LcalComm::OpenIpcMem(const char names[LCAL_MAX_RANK_SIZE][IPC_NAME_SIZE]) +{ + static mutex mut; + lock_guard lock(mut); + for (int i = 0; i < rankSize_; ++i) { + if (i == rank_) { + continue; + } + if (SkipUnusedChannel910B2C(rank_, i, GetChipName())) { + continue; + } + int ret = rtIpcOpenMemory(reinterpret_cast(&peerMem_[i]), names[i]); + if (ret != RT_ERROR_NONE) { + CloseIpcMem(); + MKI_LOG(ERROR) << "rank : " << rank_ << " localRank : " << localRank_ << " peerMem: " << i << + " IpcOpenMemory err " << ret; + return LCAL_ERROR_INTERNAL; + } + } + ipcMemInited_ = true; + return LCAL_SUCCESS; +} + +int LcalComm::SetMemoryName(string &name) +{ + char nameModified[IPC_NAME_SIZE] = {}; + int memRank = rank_; + constexpr int32_t bufferSizeUint = 1024 * 1024; + int lcalBuffSize = bufferSize_ * bufferSizeUint + LCAL_FLAG_BUFF_BYTES; + if (rtIpcSetMemoryName(peerMem_[memRank], lcalBuffSize, nameModified, IPC_NAME_SIZE) != RT_ERROR_NONE) { + return LCAL_ERROR_INTERNAL; + } + name = nameModified; + return LCAL_SUCCESS; +} + +int LcalComm::SetIpcPidSdid(string &name, const uint32_t *pids, const int64_t *sdids) const +{ + for (int i = 0; i < rankSize_; ++i) { + if (i == rank_) { + continue; + } + + if (physicalInfo_.chipName < ChipName::CHIP_910_9391) { + int32_t pidInt32 = pids[i]; + int rtRet = rtSetIpcMemPid(name.c_str(), &pidInt32, HCCL_IPC_PID_ARRAY_SIZE); + if (rtRet != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "err " << rtRet; + return LCAL_ERROR_INTERNAL; + } + } else { + int32_t pidInt32 = pids[i]; + int rtRet = rtSetIpcMemorySuperPodPid(name.c_str(), sdids[i], &pidInt32, HCCL_IPC_PID_ARRAY_SIZE); + if (rtRet != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "err " << rtRet; + return LCAL_ERROR_INTERNAL; + } + } + } + return LCAL_SUCCESS; +} + +LcalComm::~LcalComm() +{ + { + lock_guard lock(g_mtx); + if (g_localPeerMemMap.find(uid_) != g_localPeerMemMap.end()) { + g_localPeerMemMap.erase(uid_); + } + } + + if (ipcMemInited_) { +#ifndef USE_MSSANITIZER + CloseIpcMem(); +#endif + ipcMemInited_ = false; + } + if (socketExchange_) { + delete socketExchange_; + socketExchange_ = nullptr; + } + FreePeerMem(commArgs_.dumpAddr); + FreePeerMem(peerMem_[rank_]); + FreePeerMem(commArgsPtr_); +} + +LcalComm::LcalComm(int rank, int rankSize) : rank_(rank), rankSize_(rankSize) +{ +} + +LcalComm::LcalComm(int rank, int rankSize, int bufferSize) : rank_(rank), rankSize_(rankSize), bufferSize_(bufferSize) +{ +} + +LcalComm::LcalComm(int rank, int rankSize, int commDomain, int bufferSize, int isEnableMagic) + : rank_(rank), rankSize_(rankSize), commDomain_(commDomain), bufferSize_(bufferSize), isEnableMix_(isEnableMagic) +{ +} + +LcalComm::LcalComm(int rank, int rankSize, LcalUniqueId commId) + : rank_(rank), rankSize_(rankSize), commId_(commId) +{ +} + +int LcalComm::GetRank() const +{ + return rank_; +} + 
+int LcalComm::GetRankSize() const +{ + return rankSize_; +} + +int LcalComm::GetCommSize() const +{ + return commSize_; +} + +int LcalComm::GetBufferSize() const +{ + return bufferSize_; +} + +const PhysicalInfo &LcalComm::GetPhysicalInfo() const +{ + return physicalInfo_; +} + +GM_ADDR LcalComm::GetCommArgsPtr() const +{ + return commArgsPtr_; +} + +CommArgs* LcalComm::GetCommArgs() +{ + return &commArgs_; +} + + +std::string LcalComm::PrintDFX() +{ + if (commArgsPtr_ == nullptr) { + return "no comm args"; + } + int ret = aclrtMemcpy(&commArgs_, sizeof(commArgs_), commArgsPtr_, sizeof(commArgs_), + ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_SUCCESS) { + MKI_LOG(ERROR) << "aclrtMemcpy err " << __LINE__ << " " << ret; + return "acl mem copy error"; + } + stringstream ss; + ss << "CommArgs {" + << "\n rank: " << commArgs_.rank + << "\n localRank: " << commArgs_.localRank + << "\n rankSize: " << commArgs_.rankSize + << "\n localRankSize: " << commArgs_.localRankSize + << "\n extraFlag: 0x" << std::hex << std::setfill('0') << commArgs_.extraFlag << std::dec; + + ss << "\n peerMems: ["; + for (int i = 0; i < LCAL_MAX_RANK_SIZE; ++i) { + if (commArgs_.peerMems[i] == nullptr) { + continue; + } + if (i > 0) { + ss << ", "; + } + ss << "{id: " << static_cast(commArgs_.peerMems[i]) << "}"; + } + ss << "]"; + + ss << "\n magics: ["; + for (int i = 0; i < rankSize_; ++i) { + ss << std::dec << commArgs_.magics[i] << ","; + } + ss << "] \n"; + + ss << "\n dfx: ["; + const int dfxGroupCount = 5; + for (int i = 0; i < DFX_COUNT; ++i) { + if (i % dfxGroupCount == 0) { + ss << "\n " << std::dec << setw(dfxGroupCount) << i << ": "; + } + ss << "0x"<< std::hex << commArgs_.dfx[i] << std::dec << ", "; + } + ss << "\n ]"; + + ss << "\n}"; + return ss.str(); +} + +} // Lcal \ No newline at end of file diff --git a/comm/lcal/src/lcal_internal.cpp b/comm/lcal/src/lcal_internal.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5fd8b8b56e2d95b78b176219ab2686ed2b0fc33f --- /dev/null +++ b/comm/lcal/src/lcal_internal.cpp @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
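Note: lcal_internal.cpp, which follows, embeds the pre-built CCE kernel image into the host binary's .rodata with an .incbin directive and later hands slices of it to rtDevBinaryRegister. The embedding trick in isolation (same technique; generic names and path are mine):

// File-scope asm places the blob in .rodata and brackets it with two symbols.
extern const unsigned char kBlob[];
extern const unsigned char kBlobEnd[];
asm(R"(.section .rodata, "a", @progbits
kBlob: .incbin "/tmp/payload.bin"
kBlobEnd:
.previous)");
// kBlobEnd - kBlob gives the payload size at run time.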
+ */ +#include "lcal_internal.h" +#include +#include +#include +#include +#include +#include +#include "ccl_kernel_args.h" +#include "coc_kernel_args.h" +#include "lcoc.h" + +using namespace std; +using namespace Mki; + +extern const int LCAL_CCE_BIN_STR[]; +asm(R"(.section .rodata, "a", @progbits +LCAL_CCE_BIN_STR:.incbin "/tmp/lcal_cce.o" +.byte 0 +.previous)"); + +constexpr int LCCL_RT_DEV_BINARY_MAGIC_ELF_AIVEC = 0x41415246; +constexpr int COC_RT_DEV_BINARY_MAGIC_ELF = 0x43554245; + +namespace Lcal { +const std::map<HcclDataType, std::string> DATATYPE2NAME = { + { HCCL_DATA_TYPE_INT32, "int" }, + { HCCL_DATA_TYPE_INT16, "int16_t" }, + { HCCL_DATA_TYPE_INT8, "int8_t" }, + { HCCL_DATA_TYPE_INT64, "int64_t" }, + { HCCL_DATA_TYPE_FP32, "float" }, + { HCCL_DATA_TYPE_FP16, "float16_t" }, + { HCCL_DATA_TYPE_BFP16, "bfloat16_t" } +}; + +template <typename T> +int RegisterBinaryKernel(const string &funcName, int8_t *funSig, const T *binStrPtr, int magic, int len = 0) +{ + rtDevBinary_t binary; + void *binHandle = nullptr; + binary.data = binStrPtr; + binary.length = (len == 0 ? LCAL_1OP_BIN_SIZE : len); + + binary.magic = magic; + binary.version = 0; + rtError_t rtRet = rtDevBinaryRegister(&binary, &binHandle); + if (rtRet != RT_ERROR_NONE) { + MKI_LOG(WARN) << "rtDevBinaryRegister failed! " << to_string(rtRet) << ", funcName = " << funcName; + return LCAL_ERROR_INTERNAL; + } + rtRet = rtFunctionRegister(binHandle, funSig, funcName.c_str(), funcName.c_str(), 0); + if (rtRet != RT_ERROR_NONE) { + MKI_LOG(WARN) << "rtFunctionRegister failed! " << to_string(rtRet) << ", funcName = " << funcName; + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +int8_t *GetFunSig(LcalType type, HcclDataType dataType, uint64_t devType = 0) +{ + constexpr int sigOffset = 16; + constexpr int sigSkew = 0x1000; + return reinterpret_cast<int8_t *>((static_cast<uint64_t>(type) << sigOffset << sigOffset) + + (static_cast<uint64_t>(dataType) << sigOffset) + devType + sigSkew); +} + +const int* FindNextOpStart(const int opStartMagic, const int* cclBinEndPtr, const int* cclBinPtr) +{ + if (cclBinPtr == nullptr) { + MKI_LOG(ERROR) << "FindNextOpStart failed! cclBinPtr is nullptr"; + return nullptr; + } + while (cclBinPtr < cclBinEndPtr && *cclBinPtr != opStartMagic) { + cclBinPtr++; + } + if (cclBinPtr < cclBinEndPtr && *cclBinPtr == opStartMagic) { + cclBinPtr++; + } + return cclBinPtr; +} + +int RegistCCLOp2Kernel(const int* cclBinPtr, const int* nextPtr) +{ + vector<HcclDataType> registerTypes = { HCCL_DATA_TYPE_INT32, HCCL_DATA_TYPE_INT16, HCCL_DATA_TYPE_INT8, + HCCL_DATA_TYPE_FP32, HCCL_DATA_TYPE_FP16, HCCL_DATA_TYPE_BFP16, + HCCL_DATA_TYPE_INT64 }; + std::vector<LcalType> registerCCLTypesOp2 = { // Register additional operators here once they are implemented + LcalType::ALL_GATHER, LcalType::REDUCE_SCATTER, LcalType::ALL2ALL, + }; + int res = LCAL_SUCCESS; + for (auto ccl : registerCCLTypesOp2) { + for (auto t : registerTypes) { + res = RegisterBinaryKernel(LCAL_TYPE2NAME.at(ccl) + "_" + DATATYPE2NAME.at(t), GetFunSig(ccl, t), + cclBinPtr, LCCL_RT_DEV_BINARY_MAGIC_ELF_AIVEC, (nextPtr - cclBinPtr) * sizeof(int)); + } + } + if (res != LCAL_SUCCESS) { + return res; + } + res = RegisterBinaryKernel(LCAL_TYPE2NAME.at(LcalType::BROADCAST), + GetFunSig(LcalType::BROADCAST, HCCL_DATA_TYPE_RESERVED), cclBinPtr, LCCL_RT_DEV_BINARY_MAGIC_ELF_AIVEC); + return res; +} + +int RegistCCLOp1Kernel(const int* cclBinPtr, const int* nextPtr) +{ + vector<HcclDataType> registerTypes = { HCCL_DATA_TYPE_INT32, HCCL_DATA_TYPE_INT16, HCCL_DATA_TYPE_INT8, + HCCL_DATA_TYPE_FP32, HCCL_DATA_TYPE_FP16, HCCL_DATA_TYPE_BFP16, + HCCL_DATA_TYPE_INT64 }; + std::vector<LcalType> registerCCLTypesOp1 = { // Register additional operators here once they are implemented + LcalType::ALL_REDUCE, + }; + int res = LCAL_SUCCESS; + for (auto ccl : registerCCLTypesOp1) { + for (auto t : registerTypes) { + res = RegisterBinaryKernel(LCAL_TYPE2NAME.at(ccl) + "_" + DATATYPE2NAME.at(t), GetFunSig(ccl, t), + cclBinPtr, LCCL_RT_DEV_BINARY_MAGIC_ELF_AIVEC, (nextPtr - cclBinPtr) * sizeof(int)); + } + } + return res; +} + +int RegistCCLKernel(const int32_t opGroup) +{ + const int* cclBinStr = LCAL_CCE_BIN_STR; + auto cclBinEndPtr = cclBinStr + LCAL_1OP_BIN_SIZE / sizeof(int); + const int* cclBinPtr = cclBinStr + 1; + constexpr int opStartMagic = 0x44444444; + const int* nextPtr = FindNextOpStart(opStartMagic, cclBinEndPtr, cclBinPtr); + if (nextPtr == nullptr) { + return LCAL_ERROR_INTERNAL; + } + + constexpr int32_t smallGroupNum = 2; + for (int32_t opGroupIdx = 0; opGroupIdx < opGroup; ++opGroupIdx) { + for (int32_t opIdx = 0; opIdx < smallGroupNum; ++opIdx) { + cclBinPtr = nextPtr; + nextPtr = FindNextOpStart(opStartMagic, cclBinEndPtr, nextPtr); + if (cclBinPtr == nullptr || cclBinPtr == cclBinEndPtr || nextPtr == nullptr) { + return LCAL_ERROR_INTERNAL; + } + } + } + + int ret = 0; + ret = RegistCCLOp1Kernel(cclBinPtr, nextPtr); + if (ret != LCAL_SUCCESS) { + return LCAL_ERROR_INTERNAL; + } + + // Move on to the second subgroup within this group + cclBinPtr = nextPtr; + nextPtr = FindNextOpStart(opStartMagic, cclBinEndPtr, nextPtr); + if (cclBinPtr == nullptr || cclBinPtr == cclBinEndPtr || nextPtr == nullptr) { + return LCAL_ERROR_INTERNAL; + } + + // The second subgroup in each group holds reduce-scatter, all-gather, etc. + ret = RegistCCLOp2Kernel(cclBinPtr, nextPtr); + if (ret != LCAL_SUCCESS) { + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +void RegistCoCKernel() +{ + vector<HcclDataType> registerTypes = { HCCL_DATA_TYPE_FP16, HCCL_DATA_TYPE_BFP16 }; + vector<vector<LcalType>> registerCOCTypes = { + { LcalType::PURE_MATMUL}, + { LcalType::MATMUL_ALL_REDUCE }, + { LcalType::MATMUL_REDUCE_SCATTER }, + { LcalType::ALL_GATHER_MATMUL, LcalType::ALL_GATHER_MATMUL_V2 }, + { LcalType::ALL_GATHER_MATMUL_REDUCE_SCATTER}, + { LcalType::ALLTOALLV_ALLGATHER_MATMUL, LcalType::ALLTOALLVC_ALLGATHER_MATMUL_HIDDEN}, + {
LcalType::MATMUL_REDUCESCATTER_ALLTOALLVC_HIDDEN}, + }; + + auto cocCceBinStr = LCAL_CCE_BIN_STR + LCAL_1OP_BIN_SIZE / sizeof(int); + for (auto lcalTypeGroup : registerCOCTypes) { + for (auto lcalType : lcalTypeGroup) { + for (auto t : registerTypes) { + RegisterBinaryKernel(LCAL_TYPE2NAME.at(lcalType) + "_" + DATATYPE2NAME.at(t), GetFunSig(lcalType, t), + cocCceBinStr, COC_RT_DEV_BINARY_MAGIC_ELF); + } + } + cocCceBinStr += LCAL_1OP_BIN_SIZE / sizeof(int); + } +} + +int RegistKernel(const int32_t opGroup) +{ + static bool init = false; + static mutex mut; + lock_guard guard(mut); + if (init) { + return 0; + } + RegistCoCKernel(); + RegistCCLKernel(opGroup); + init = true; + return LCAL_SUCCESS; +} + +int64_t Count2Size(int64_t count, const HcclDataType &dataType) +{ + int64_t dataSize = LCAL_INVALID_VALUE; + if (dataType == HCCL_DATA_TYPE_INT8 || dataType == HCCL_DATA_TYPE_UINT8) { + dataSize = count; + } else if (dataType == HCCL_DATA_TYPE_INT16 || dataType == HCCL_DATA_TYPE_FP16 || + dataType == HCCL_DATA_TYPE_BFP16 || dataType == HCCL_DATA_TYPE_UINT16) { + dataSize = count * sizeof(int16_t); + } else if (dataType == HCCL_DATA_TYPE_FP32 || dataType == HCCL_DATA_TYPE_INT32 || + dataType == HCCL_DATA_TYPE_UINT32) { + dataSize = count * sizeof(int32_t); + } else if (dataType == HCCL_DATA_TYPE_INT64 || dataType == HCCL_DATA_TYPE_UINT64) { + dataSize = count * sizeof(int64_t); + } else { + MKI_LOG(ERROR) << "unknown datatype"; + } + return dataSize; +} + +int LoadMTE(LcalType cclType, AscendCCLKernelArgs &args, uint32_t blockDim, HcclDataType dataType, aclrtStream stream) +{ + int error = 0; + MKI_LOG(DEBUG) << "LoadMTE " << LCAL_TYPE2NAME.at(cclType) << " count:" << args.count << " dataType:" << dataType + << " op:" << args.op << " blockDim:" << blockDim << " rootRank:" << args.root + << ", magic: " << args.magic; + int64_t dataSize = Count2Size(args.count, dataType); + if (dataSize == LCAL_INVALID_VALUE || blockDim == 0) { + MKI_LOG(ERROR) << ("LoadMTE args are invalid"); + return LCAL_ERROR_PARA_CHECK_FAIL; + } + + static const char *ENV = Mki::GetEnv("LCCL_PARALLEL"); + if (ENV && (string(ENV) == "1" || string(ENV) == "true") && dataSize >= IPC_BUFF_MAX_SIZE) { + MKI_LOG(ERROR) << ("LoadMTE args are invalid, because LCCL_PARALLEL is open, and dataSize is too big."); + return LCAL_ERROR_PARA_CHECK_FAIL; + } + + rtTaskCfgInfo_t cfgInfo{}; + cfgInfo.schemMode = 1; + + rtArgsEx_t argsInfo{}; + argsInfo.args = &args; + argsInfo.argsSize = sizeof(args); + + if (cclType == LcalType::BROADCAST || cclType == LcalType::BANDWIDTH) { + args.count = dataSize; + error = rtKernelLaunchWithFlagV2(GetFunSig(cclType, HCCL_DATA_TYPE_RESERVED), + blockDim, &argsInfo, nullptr, stream, 0, &cfgInfo); + } else { + error = rtKernelLaunchWithFlagV2(GetFunSig(cclType, dataType), + blockDim, &argsInfo, nullptr, stream, 0, &cfgInfo); + } + if (error != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "AsdRtFunctionLaunch -:" << LCAL_TYPE2NAME.at(cclType) << to_string(error); + return LCAL_ERROR_MKIRT; + } + return error; +} + +int LoadMTE(LcalType cclType, CCLGatherArgs &args, uint32_t blockDim, HcclDataType dataType, aclrtStream stream) +{ + int error = 0; + MKI_LOG(DEBUG) << "LoadMTE " << LCAL_TYPE2NAME.at(cclType) << " embTableLen:" << args.embTableLen + << " embTableDim:" << args.embTableDim + << " lookupLen:" << args.lookupLen; + + rtTaskCfgInfo_t cfgInfo{}; + cfgInfo.schemMode = 1; + + rtArgsEx_t argsInfo{}; + argsInfo.args = &args; + argsInfo.argsSize = sizeof(args); + + if (cclType == LcalType::GATHER) { + error = 
rtKernelLaunchWithFlagV2(GetFunSig(cclType, dataType), + blockDim, &argsInfo, nullptr, stream, 0, &cfgInfo); + } + if (error != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "AsdRtFunctionLaunch -:" << to_string(error); + return LCAL_ERROR_MKIRT; + } + return error; +} + +template +size_t OffsetOf(M T::*member, T obj) +{ + return reinterpret_cast(&(obj.*member)) - reinterpret_cast(&obj); +} + +int ComputeOverComm(LcalType cocType, CoCKernelArgs kernelArgs, HcclDataType dataType, aclrtStream stream) +{ + int error = LCAL_SUCCESS; + + size_t tilingAddrOffset = OffsetOf(&CoCKernelArgs::pCocTiling, kernelArgs); + size_t tilingDataOffset = OffsetOf(&CoCKernelArgs::cocKernelParam, kernelArgs) + + OffsetOf(&CoCKernelParam::cocTilingData, kernelArgs.cocKernelParam); + + auto &cocTilingData = kernelArgs.cocKernelParam.cocTilingData; + if (cocTilingData.withSerialMode != 0) { + static std::vector serialTags(LCAL_MAX_RANK_SIZE, 1); + cocTilingData.tag = serialTags[cocTilingData.rank]; + serialTags[cocTilingData.rank] = serialTags[cocTilingData.rank] % TAG_MOD + 1; + } + + rtTaskCfgInfo_t cfgInfo{}; + cfgInfo.schemMode = 1; + + rtArgsEx_t argsInfo{}; + argsInfo.args = static_cast(&kernelArgs); + argsInfo.hostInputInfoPtr = nullptr; + argsInfo.argsSize = sizeof(kernelArgs); + argsInfo.tilingAddrOffset = tilingAddrOffset; + argsInfo.tilingDataOffset = tilingDataOffset; + argsInfo.hostInputInfoNum = 0; + argsInfo.hasTiling = 1; + argsInfo.isNoNeedH2DCopy = 0; + + error = rtKernelLaunchWithFlagV2(GetFunSig(cocType, dataType), + kernelArgs.cocKernelParam.cocTilingData.blockDim, + &argsInfo, nullptr, stream, 0, &cfgInfo); + if (error != RT_ERROR_NONE) { + MKI_LOG(ERROR) << "AsdRtFunctionLaunch -:" << to_string(error); + return LCAL_ERROR_MKIRT; + } + return error; +} +} \ No newline at end of file diff --git a/comm/lcal/src/lcal_internal.h b/comm/lcal/src/lcal_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..2a0d3897c3c8b83b35e8a450354a1fe33242e377 --- /dev/null +++ b/comm/lcal/src/lcal_internal.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
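Note: ComputeOverComm above uses the OffsetOf helper to tell the runtime where, inside the launch-argument struct, the tiling pointer and tiling data live. With its (stripped) template header restored, the helper is equivalent to this standalone sketch:

#include <cstddef>
#include <cstdint>

template <typename T, typename M>
size_t OffsetOfMember(M T::*member, T obj)
{
    // distance in bytes from the start of obj to the given member
    return reinterpret_cast<uintptr_t>(&(obj.*member)) - reinterpret_cast<uintptr_t>(&obj);
}
// For standard-layout types this matches offsetof(T, member).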
+ */ +#ifndef LCAL_INTERNAL_H +#define LCAL_INTERNAL_H + +#include +#include +#include "lcal_types.h" +#include "coc_kernel_args.h" +#include "ccl_kernel_args.h" + +namespace Lcal { + +int RegistKernel(const int32_t opGroup = 0); + +int64_t Count2Size(int64_t count, const HcclDataType &dataType); + +int LoadMTE(LcalType cclType, AscendCCLKernelArgs &args, uint32_t blockDim, HcclDataType dataType, aclrtStream stream); + +int LoadMTE(LcalType cclType, CCLGatherArgs &args, uint32_t blockDim, HcclDataType dataType, aclrtStream stream); + +int ComputeOverComm(LcalType cocType, CoCKernelArgs kernelArgs, HcclDataType dataType, aclrtStream stream); +} + +#endif \ No newline at end of file diff --git a/comm/lcal/src/lcal_wrap.cpp b/comm/lcal/src/lcal_wrap.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0bf96fecafcb20257a2b47e8c93d44d08836c9bb --- /dev/null +++ b/comm/lcal/src/lcal_wrap.cpp @@ -0,0 +1,278 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include +#include +#include +#include "mki/utils/log/log.h" +#include "lcal.h" +#include "tools/socket/lcal_sock_exchange.h" + +using namespace std; +using namespace Lcal; + +int LcalCommInitRankLocal(int rankSize, int rank, LcalCommPtr *comm) +{ + MKI_LOG(INFO) << "using lcal c++ api! rank" << rank; + if (comm == nullptr) { + MKI_LOG(ERROR) << "lcal comm ptr is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + auto *c = new (std::nothrow) LcalComm(rank, rankSize); + if (c == nullptr) { + MKI_LOG(ERROR) << "LcalComm create failed. rank : " << rank << ", rankSize : " << rankSize; + return LCAL_ERROR_INTERNAL; + } + *comm = c; + int ret = c->Init(); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "lccl init failed!"; + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +int LcalGetUniqueId(LcalUniqueId *uniqueId, int commDomain) +{ + if (uniqueId == nullptr) { + MKI_LOG(ERROR) << "uniqueId is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + int res = BootstrapGetUniqueId(reinterpret_cast(uniqueId), commDomain); + if (res != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "lcal BootstrapGetUniqueId failed!"; + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +int LcalCommInitRank(LcalUniqueId commId, int rankSize, int rank, LcalCommPtr *comm) +{ + MKI_LOG(INFO) << "using lcal c++ api! rank" << rank; + if (comm == nullptr) { + MKI_LOG(ERROR) << "lcal comm ptr is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + auto *c = new (std::nothrow) LcalComm(rank, rankSize, commId); + if (c == nullptr) { + MKI_LOG(ERROR) << "LcalComm create failed. 
rank : " << rank << ", rankSize : " << rankSize; + return LCAL_ERROR_INTERNAL; + } + *comm = c; + int ret = c->Init(); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "lccl init failed!"; + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +int LcalCommInitRankWithCustDomainSize(int commDomain, int bufferSize, int rankSize, int rank, LcalCommPtr *comm, + const bool isEnableAutoMagicNum) +{ + MKI_LOG(INFO) << "using lcal c++ api! rank : " << rank << ", rankSize : " << rankSize << ", commDomain:" << + commDomain << ", bufferSize:" << bufferSize << ", isEnableAutoMagicNum:" << isEnableAutoMagicNum; + if (comm == nullptr) { + MKI_LOG(ERROR) << "lcal comm ptr is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + + constexpr int minBufferSize = LCAL_COMM_BUFFER_SIZE; + if (bufferSize < minBufferSize) { + MKI_LOG(ERROR) << "lcal comm buffer size " << bufferSize << " MBytes should not be less than " << + minBufferSize << " MBytes!"; + return LCAL_ERROR_INTERNAL; + } + + auto *c = new (std::nothrow) LcalComm(rank, rankSize, commDomain, bufferSize, isEnableAutoMagicNum); + if (c == nullptr) { + MKI_LOG(ERROR) << "LcalComm create failed. rank : " << rank << ", rankSize : " << rankSize << ", commDomain:" << + commDomain << ", bufferSize:" << bufferSize << ", isEnableAutoMagicNum:" << isEnableAutoMagicNum; + return LCAL_ERROR_INTERNAL; + } + *comm = c; + int ret = c->Init(); + if (ret != LCAL_SUCCESS) { + MKI_LOG(ERROR) << "lccl init failed!"; + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +int LcalCommInitRankWithDomain(int commDomain, int rankSize, int rank, LcalCommPtr *comm) +{ + constexpr int minBufferSize = LCAL_COMM_BUFFER_SIZE; + return LcalCommInitRankWithCustDomainSize(commDomain, minBufferSize, rankSize, rank, comm); +} + +int LcalGetCommArgsDev(LcalCommPtr comm, GM_ADDR &commArgsPtr) +{ + if (comm == nullptr) { + MKI_LOG(ERROR) << "lcal comm is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + auto *lccl = static_cast(comm); + commArgsPtr = lccl->GetCommArgsPtr(); + return LCAL_SUCCESS; +} + +int LcalGetCommArgsHost(LcalCommPtr comm, Lcal::CommArgs *&commArgsPtr) +{ + if (comm == nullptr) { + MKI_LOG(ERROR) << "lcal comm is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + auto *c = static_cast(comm); + commArgsPtr = c->GetCommArgs(); + return LCAL_SUCCESS; +} + +void LcalPrintDFX2Log(LcalCommPtr comm) +{ + if (comm == nullptr) { + MKI_LOG(ERROR) << "lcal comm is nullptr!"; + return; + } + auto *lcal = static_cast(comm); + MKI_LOG(INFO) << lcal->PrintDFX(); +} + +int LcalCommInit(int rank, int rankSize, LcalCommPtr *comms) +{ + if (comms == nullptr) { + MKI_LOG(ERROR) << "lcal comms is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + *comms = new (std::nothrow) LcalComm(rank, rankSize); + if (*comms == nullptr) { + MKI_LOG(ERROR) << "LcalComm create failed. rank : " << rank << ", rankSize : " << rankSize; + return LCAL_ERROR_INTERNAL; + } + return LCAL_SUCCESS; +} + +int LcalCommInitAll(uint32_t ndev, int32_t *devices, LcalCommPtr *comms) +{ + if (comms == nullptr) { + MKI_LOG(ERROR) << "lcal comms is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + if (devices == nullptr) { + MKI_LOG(ERROR) << "lcal devices is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + static int commDomain = 0; + commDomain++; + for (uint32_t i = 0; i < ndev; ++i) { + comms[i] = new (std::nothrow) LcalComm(i, ndev, commDomain, LCAL_COMM_BUFFER_SIZE, false); + if (comms[i] == nullptr) { + MKI_LOG(ERROR) << "LcalComm create failed. 
dev : " << i << ", ndev : " << ndev; + return LCAL_ERROR_INTERNAL; + } + } + static atomic uid; + uid++; + vector> threads; + int error = LCAL_SUCCESS; + for (uint32_t r = 0; r < ndev; r++) { + threads.emplace_back(make_unique( + [&](int rank) { + aclrtSetDevice(devices[rank]); + auto *c = static_cast(comms[rank]); + int ret = c->InitThread("uid" + to_string(uid)); + if (ret != LCAL_SUCCESS) { + error = ret; + } + }, + r)); + } + for (auto &t : threads) { + t->join(); + } + threads.clear(); + return error; +} + +int LcalCommInitThread(int rank, int rankSize, const char *uid, LcalCommPtr *comms) +{ + if (uid == nullptr) { + MKI_LOG(ERROR) << "lcal uid is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + if (comms == nullptr) { + MKI_LOG(ERROR) << "lcal comms is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + if (rank >= rankSize) { + MKI_LOG(ERROR) << "lcal rank : " << rank << " rankSize : " << rankSize; + return LCAL_ERROR_INTERNAL; + } + *comms = new (std::nothrow) LcalComm(rank, rankSize); + if (*comms == nullptr) { + MKI_LOG(ERROR) << "LcalComm create failed. rank : " << rank << ", rankSize : " << rankSize; + return LCAL_ERROR_INTERNAL; + } + auto *c = static_cast(*comms); + return c->InitThread(string(uid)); +} + +int LcclAllReduce(void *sendBuf, void *recvBuf, int64_t count, HcclDataType dataType, HcclReduceOp op, + LcalCommPtr comm, aclrtStream stream) +{ + if (comm == nullptr) { + MKI_LOG(ERROR) << "LcclAllReduce comm is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + Lccl lccl(static_cast(comm)); + return lccl.AllReduce(sendBuf, recvBuf, count, dataType, op, stream); +} + +int LcclAllGather(void *sendBuf, void *recvBuf, int64_t sendCount, HcclDataType dataType, LcalCommPtr comm, + aclrtStream stream) +{ + if (comm == nullptr) { + MKI_LOG(ERROR) << "LcclAllGather comm is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + Lccl lccl(static_cast(comm)); + return lccl.AllGather(sendBuf, recvBuf, sendCount, dataType, stream); +} + +int LcclReduceScatter(void *sendBuf, void *recvBuf, int64_t recvCount, HcclDataType dataType, HcclReduceOp op, + LcalCommPtr comm, aclrtStream stream) +{ + if (comm == nullptr) { + MKI_LOG(ERROR) << "LcclReduceScatter comm is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + Lccl lccl(static_cast(comm)); + return lccl.ReduceScatter(sendBuf, recvBuf, recvCount, dataType, op, stream); +} + +int LcclBroadcast(void *buf, int64_t count, HcclDataType dataType, int root, LcalCommPtr comm, + aclrtStream stream) +{ + if (comm == nullptr) { + MKI_LOG(ERROR) << "LcclBroadcast comm is nullptr!"; + return LCAL_ERROR_INTERNAL; + } + Lccl lccl(static_cast(comm)); + return lccl.Broadcast(buf, count, dataType, root, stream); +} + +int LcclCommDestroy(LcalCommPtr comm) +{ + if (comm == nullptr) { + return LCAL_INVALID_VALUE; + } + auto *c = static_cast(comm); + delete c; + return LCAL_SUCCESS; +} \ No newline at end of file diff --git a/comm/lcal/src/lccl.cpp b/comm/lcal/src/lccl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..694bfc74e05b3e02fdad124e7c5bfe08258f89e1 --- /dev/null +++ b/comm/lcal/src/lccl.cpp @@ -0,0 +1,506 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "lccl.h" +#include "lcal_internal.h" + +#include +#include +#include +#include + +#include +#include + +#include "profiling/report_timing.h" + +using namespace std; +using namespace chrono; +using namespace Mki; + +namespace Lcal { + +uint32_t GetLocalReduceBlockDum(int64_t dataSize) +{ + constexpr int oneDataSize = 190 * 1024; + constexpr int maxBlockDim = 8; + int blockDim = dataSize / oneDataSize + 1; + return blockDim <= maxBlockDim ? blockDim : maxBlockDim; +} + +bool GetParallel() +{ + static int parallel = -1; + if (parallel == -1) { + static const char *ENV = Mki::GetEnv("LCCL_PARALLEL"); + parallel = (ENV && (string(ENV) == "1" || string(ENV) == "true")) ? 1 : 0; + MKI_LOG(INFO) << "LCCL_PARALLEL is " << parallel; + } + return static_cast(parallel); +} + +uint32_t GetAllReduceDetermBlockNum(uint32_t rankSize, int64_t dataSize, uint32_t extraFlag) +{ + constexpr uint32_t quickOneshotRankSize = 2; + constexpr uint32_t twoBlockNum = 2; + constexpr uint32_t threeStepNum = 3; + constexpr uint32_t rankSize910a3 = 16; + constexpr uint32_t dbRingBlockNum = 34; + constexpr int64_t smallDataSize = 1 * 1024 * 1024; + constexpr int32_t smallDataSize910a3 = 32 * 1024 * 1024; + if ((extraFlag & ExtraFlag::TOPO_910_93) != 0) { + constexpr uint32_t maxAivNum = 40; + const bool isAivNumSupport = ((extraFlag & ExtraFlag::IS_GREATER_THAN_40_AIV) != 0 || + rankSize * threeStepNum <= maxAivNum); + if (rankSize % quickOneshotRankSize == 1 || rankSize == quickOneshotRankSize || + (rankSize <= rankSize910a3 && dataSize <= smallDataSize910a3 && isAivNumSupport)) { + return rankSize * threeStepNum; + } else { + return dbRingBlockNum; + } + } + if (dataSize < smallDataSize) { + return rankSize * twoBlockNum; + } + return rankSize * threeStepNum; +} + +uint32_t GetAllReduceBlockNum(uint32_t rankSize, int64_t dataSize, uint32_t extraFlag) +{ + constexpr uint32_t twoBlockNum = 2; + constexpr uint32_t threeStepNum = 3; + constexpr uint32_t dbRingBlockNum = 34; + constexpr int64_t smallDataSize = 1 * 1024 * 1024; + constexpr uint32_t smallRankSize = 8; + constexpr uint32_t cceSmallDataSize = 2 * 1024 * 1024; + constexpr uint32_t quickOneshotRankSize = 2; + const int64_t quantSmallDataSize = ((extraFlag & ExtraFlag::QUANT_FP16) != 0) ? (smallDataSize / 2) : smallDataSize; + constexpr int32_t smallDataSize910a3 = 32 * 1024 * 1024; + + if ((extraFlag & ExtraFlag::TOPO_PCIE) != 0) { + return rankSize * twoBlockNum; + } else if ((extraFlag & ExtraFlag::QUANT_FP16) != 0) { + return dataSize <= quantSmallDataSize ? rankSize : rankSize * twoBlockNum; + } else if ((extraFlag & ExtraFlag::TOPO_910B2C) != 0 && rankSize > smallRankSize) { + return dataSize < cceSmallDataSize ? rankSize : (rankSize / twoBlockNum * threeStepNum + twoBlockNum); + } else if ((extraFlag & ExtraFlag::DETERMINISTIC) != 0) { + return GetAllReduceDetermBlockNum(rankSize, dataSize, extraFlag); + } + + if (GetParallel()) { + return rankSize; + } + + if ((extraFlag & ExtraFlag::TOPO_910_93) != 0 && dataSize > smallDataSize910a3 && + (rankSize != quickOneshotRankSize)) { + return rankSize % quickOneshotRankSize == 0 ? dbRingBlockNum : rankSize * threeStepNum; + } + return (rankSize == quickOneshotRankSize || dataSize >= cceSmallDataSize) ? 
rankSize * twoBlockNum : rankSize; +} + +uint32_t GetReduceScatterBlockNum(uint32_t rankSize, int64_t dataSize, uint32_t extraFlag) +{ + constexpr uint32_t twoBlockNum = 2; + constexpr int64_t smallDataSize = 1 * 1024 * 1024; + constexpr uint32_t quickOneshotRankSize = 2; + constexpr int64_t cceSmallDataSize = 2 * 1024 * 1024; + constexpr int64_t a3BigDataSize = 32 * 1024 * 1024; + constexpr uint32_t fourStepBlockNum = 34; + constexpr uint32_t a3SupportRankSize = 4; + constexpr uint32_t smallRankSize = 8; + constexpr uint32_t dbRingBlockNum = 36; + + const bool isDbRing = (rankSize == a3SupportRankSize || rankSize == smallRankSize) && + (dataSize * smallRankSize > cceSmallDataSize && dataSize * smallRankSize <= a3BigDataSize); + + if ((extraFlag & ExtraFlag::TOPO_910_93) != 0 && (rankSize > smallRankSize || isDbRing)) { + if (isDbRing) { + return dbRingBlockNum; + } else { + return dataSize <= smallDataSize ? rankSize : fourStepBlockNum; + } + } else { + return (rankSize == quickOneshotRankSize || dataSize >= cceSmallDataSize) ? rankSize * twoBlockNum : rankSize; + } +} + +uint32_t GetAll2AllBlockNum(uint32_t rankSize, int64_t dataSize, uint32_t extraFlag) +{ + constexpr uint32_t twoStepBlockNum = 16; + constexpr uint32_t twoBlockNum = 2; + constexpr int64_t smallDataSize = 1 * 1024 * 1024; + constexpr uint32_t smallRankSize = 8; + + if ((extraFlag & ExtraFlag::TOPO_910_93) != 0) { + if (rankSize <= smallRankSize && dataSize > smallDataSize && + dataSize % (smallRankSize * smallRankSize * rankSize) == 0) { + return twoStepBlockNum * twoBlockNum; + } else { + return rankSize <= twoStepBlockNum ? rankSize * twoBlockNum : twoStepBlockNum * twoBlockNum; + } + } + return rankSize * twoBlockNum; +} + + +uint32_t GetAllGatherBlockNum(uint32_t rankSize, int64_t dataSize, uint32_t extraFlag) +{ + constexpr uint32_t axRankSize = 16; + constexpr uint32_t twoBlockNum = 2; + constexpr uint32_t quickOneshotRankSize = 2; + constexpr uint32_t allGatherHDBRingBlockNum = 32; + constexpr uint32_t cceSmallDataSize = 2 * 1024 * 1024; + constexpr int64_t smallDataSize910a3 = 32 * 1024 * 1024; + constexpr uint32_t smallRankSize = 8; + + if ((extraFlag & ExtraFlag::TOPO_910B2C) != 0 && (rankSize == axRankSize)) { + constexpr uint32_t axBlockNum = 10; + return axBlockNum; + } else if ((extraFlag & ExtraFlag::TOPO_PCIE) != 0) { + return rankSize * twoBlockNum; + } + + if (GetParallel()) { + return rankSize; + } + + if ((extraFlag & ExtraFlag::TOPO_910_93) != 0 && + (dataSize > smallDataSize910a3 || rankSize > smallRankSize) && + rankSize > quickOneshotRankSize && rankSize % quickOneshotRankSize == 0) { + return allGatherHDBRingBlockNum; + } + return (rankSize == quickOneshotRankSize || dataSize >= cceSmallDataSize) ? 
rankSize * twoBlockNum : rankSize; +} + +uint32_t GetKernelBlockNum(LcalType cclType, uint32_t rankSize, int64_t dataSize, int localRankSize, uint32_t extraFlag) +{ + constexpr uint32_t twoStepBlockNum = 16; + constexpr uint32_t twoBlockNum = 2; + constexpr int64_t smallDataSize = 1 * 1024 * 1024; + constexpr uint32_t gatherDefaultBlockNum = 4; + const uint32_t rankSizeLocal = static_cast(localRankSize); + + if (cclType == LcalType::LOCAL_REDUCE) { + return GetLocalReduceBlockDum(dataSize); + } + + if (cclType == LcalType::BROADCAST) { + return rankSize; + } + + if (cclType == LcalType::ALL2ALL_V_C) { + return twoStepBlockNum * twoBlockNum; + } + if (cclType == LcalType::ALL2ALL) { + return GetAll2AllBlockNum(rankSize, dataSize, extraFlag); + } + if (cclType == LcalType::BANDWIDTH) { + return twoStepBlockNum * twoBlockNum; + } + if (cclType == LcalType::ALL_REDUCE) { + return GetAllReduceBlockNum(rankSize, dataSize, extraFlag); + } + if (cclType == LcalType::REDUCE_SCATTER) { + return GetReduceScatterBlockNum(rankSize, dataSize, extraFlag); + } + if (cclType == LcalType::ALL_GATHER) { + return GetAllGatherBlockNum(rankSize, dataSize, extraFlag); + } + if (cclType == LcalType::GATHER) { + return gatherDefaultBlockNum; + } + bool sendOrRecv = cclType == LcalType::RECV || cclType == LcalType::SEND; + if (sendOrRecv) { + return dataSize <= smallDataSize ? rankSizeLocal : rankSizeLocal * twoBlockNum; + } + return twoStepBlockNum; +} + +uint32_t Lccl::GetBlockNum(LcalType cclType, uint32_t rankSize, int64_t dataSize, + int localRankSize, uint32_t extraFlag) const +{ + if (comm_ == nullptr) { + MKI_LOG(ERROR) << "comm is nullptr" << __LINE__; + return 0; + } + uint32_t blockNum = GetKernelBlockNum(cclType, rankSize, dataSize, localRankSize, extraFlag); + if (comm_->isEnableMix_) { + constexpr uint32_t aivNumPerAic = 2; + if (blockNum % aivNumPerAic == 1) { + MKI_LOG(ERROR) << "Lccl not support odd block number at msprof op enabled!"; + return 0; + } + return blockNum / aivNumPerAic; + } else { + return blockNum; + } +} + +int Lccl::LoopBack(const void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream) const +{ + if (sendBuff != recvBuff) { + auto ret = aclrtMemcpyAsync(recvBuff, Count2Size(count, dataType), sendBuff, Count2Size(count, dataType), + ACL_MEMCPY_DEVICE_TO_DEVICE, stream); + if (ret != 0) { + MKI_LOG(ERROR) << "LoopBack failed!"; + return LCAL_ERROR_INTERNAL; + } + } + return LCAL_SUCCESS; +} + +int Lccl::AllReduce(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, HcclReduceOp op, + aclrtStream stream, HcclDataType outputDataType, const void *scale, int64_t scaleCount, const void *offset) const +{ + if (!CheckBuff(sendBuff, recvBuff)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (!CheckDataType(dataType) || op == HCCL_REDUCE_PROD || + (outputDataType != HCCL_DATA_TYPE_RESERVED && !CheckDataType(outputDataType))) { + MKI_LOG(ERROR) << "Lccl not support."; + return LCAL_ERROR_NOT_INITIALIZED; + } + std::unique_ptr report; + if (comm_->isEnableMsprofOp_) { + report = std::make_unique("LcclAllReduce", comm_->rank_, true, + comm_->commArgs_.dumpAddr, stream); + } else { + report = std::make_unique("LcclAllReduce", comm_->commDomain_, count, dataType); + } + if ((dataType == HCCL_DATA_TYPE_INT8 && outputDataType == HCCL_DATA_TYPE_FP16) != + static_cast(comm_->commArgs_.extraFlag & ExtraFlag::QUANT_FP16)) { + if (dataType == HCCL_DATA_TYPE_INT8 && outputDataType == HCCL_DATA_TYPE_FP16) { + comm_->commArgs_.extraFlag |= 
ExtraFlag::QUANT_FP16; + } else { + comm_->commArgs_.extraFlag &= ~ExtraFlag::QUANT_FP16; + } + + auto ret = aclrtMemcpyAsync(comm_->commArgsPtr_, sizeof(CommArgs), &(comm_->commArgs_), sizeof(CommArgs), + ACL_MEMCPY_HOST_TO_DEVICE, stream); + if (ret != ACL_SUCCESS) { + MKI_LOG(ERROR) << "aclrtMemcpy err " << __LINE__ << " " << ret; + return LCAL_ERROR_INTERNAL; + } + } + + if ((comm_->commArgs_.extraFlag & ExtraFlag::QUANT_FP16) != 0 && + (comm_->commArgs_.extraFlag & (ExtraFlag::QUANT_DELAY | ExtraFlag::QUANT_CURRENT)) == 0) { + uint32_t blockDim = GetBlockNum(LcalType::ALL_REDUCE, rankSize_, Count2Size(count, dataType), + comm_->localRankSize_, comm_->commArgs_.extraFlag); + AscendCCLKernelArgs ascendArgs = {sendBuff, recvBuff, comm_->commArgsPtr_, count, comm_->magic_, op, 0, 0, + scale, scaleCount, offset}; + comm_->magic_++; + return LoadMTE(LcalType::ALL_REDUCE, ascendArgs, blockDim, dataType, stream); + } + + if (rankSize_ <= 1) { + return LoopBack(sendBuff, recvBuff, count, dataType, stream); + } + + if ((comm_->commArgs_.extraFlag & (ExtraFlag::QUANT_DELAY | ExtraFlag::QUANT_CURRENT)) != 0) { + uint32_t blockDim = GetBlockNum(LcalType::ALL_REDUCE, rankSize_, Count2Size(count, dataType), + comm_->localRankSize_, comm_->commArgs_.extraFlag); + AscendCCLKernelArgs args = { sendBuff, recvBuff, comm_->commArgsPtr_, count, comm_->magic_, op, 0, 0, scale, + scaleCount}; + comm_->magic_++; + return LoadMTE(LcalType::ALL_REDUCE, args, blockDim, dataType, stream); + } + + uint32_t blockDim = GetBlockNum(LcalType::ALL_REDUCE, rankSize_, Count2Size(count, dataType), + comm_->localRankSize_, comm_->commArgs_.extraFlag); + AscendCCLKernelArgs args = {sendBuff, recvBuff, comm_->commArgsPtr_, count, comm_->magic_, op, 0}; + comm_->magic_++; + return LoadMTE(LcalType::ALL_REDUCE, args, blockDim, dataType, stream); +} + +bool Lccl::CheckDataType(const HcclDataType &dataType) const +{ + return (dataType == HCCL_DATA_TYPE_INT32 or dataType == HCCL_DATA_TYPE_FP16 or dataType == HCCL_DATA_TYPE_FP32 or + dataType == HCCL_DATA_TYPE_INT8 or dataType == HCCL_DATA_TYPE_INT16 or dataType == HCCL_DATA_TYPE_BFP16 or + dataType == HCCL_DATA_TYPE_INT64); +} + +bool Lccl::CheckBuff(const void *sendBuff, const void *recvBuff) const +{ + bool res = true; + if (sendBuff == nullptr) { + MKI_LOG(ERROR) << "Lccl sendBuff is nullptr"; + res = false; + } else if (recvBuff == nullptr) { + MKI_LOG(ERROR) << "Lccl recvBuff is nullptr"; + res = false; + } else if (comm_ == nullptr) { + MKI_LOG(ERROR) << "comm is nullptr" << __LINE__; + res = false; + } + return res; +} + +int Lccl::ReduceScatter(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, HcclReduceOp op, + aclrtStream stream) const +{ + if (!CheckBuff(sendBuff, recvBuff)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (rankSize_ <= 1) { + return LoopBack(sendBuff, recvBuff, count, dataType, stream); + } + std::unique_ptr report; + if (comm_->isEnableMsprofOp_) { + report = std::make_unique("LcclReduceScatter", comm_->rank_, true, + comm_->commArgs_.dumpAddr, stream); + } else { + report = std::make_unique("LcclReduceScatter", comm_->commDomain_, count, dataType); + } + if (CheckDataType(dataType) and op != HCCL_REDUCE_PROD) { + uint32_t blockDim = GetBlockNum(LcalType::REDUCE_SCATTER, rankSize_, Count2Size(count, dataType), + comm_->localRankSize_, comm_->commArgs_.extraFlag); + AscendCCLKernelArgs args = { sendBuff, recvBuff, comm_->commArgsPtr_, count, comm_->magic_, op, 0 }; + comm_->magic_++; + return LoadMTE(LcalType::REDUCE_SCATTER, args, 
blockDim, dataType, stream); + } + MKI_LOG(ERROR) << "Lccl not support."; + return LCAL_ERROR_NOT_INITIALIZED; +} + +int Lccl::AllGather(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream) const +{ + if (!CheckBuff(sendBuff, recvBuff)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (rankSize_ <= 1) { + return LoopBack(sendBuff, recvBuff, count, dataType, stream); + } + std::unique_ptr report; + if (comm_->isEnableMsprofOp_) { + report = std::make_unique("LcclAllGather", comm_->rank_, true, + comm_->commArgs_.dumpAddr, stream); + } else { + report = std::make_unique("LcclAllGather", comm_->commDomain_, count, dataType); + } + AscendCCLKernelArgs args = { sendBuff, recvBuff, comm_->commArgsPtr_, count, comm_->magic_, 0, 0 }; + comm_->magic_++; + uint32_t blockDim = GetBlockNum(LcalType::ALL_GATHER, rankSize_, Count2Size(count, dataType), + comm_->localRankSize_, comm_->commArgs_.extraFlag); + return LoadMTE(LcalType::ALL_GATHER, args, blockDim, dataType, stream); +} + +int Lccl::All2All(void *sendBuff, void *recvBuff, int64_t count, HcclDataType dataType, aclrtStream stream) const +{ + constexpr int32_t supportRankNum = 2; + if (!CheckBuff(sendBuff, recvBuff) || (rankSize_ > 1 && rankSize_ % supportRankNum != 0)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (rankSize_ <= 1) { + return LoopBack(sendBuff, recvBuff, count, dataType, stream); + } + ReportTiming report("LcclAll2All", comm_->commDomain_, count, dataType); + AscendCCLKernelArgs args = { sendBuff, recvBuff, comm_->commArgsPtr_, count, comm_->magic_, 0, 0, 0 }; + comm_->magic_++; + uint32_t blockDim = GetBlockNum(LcalType::ALL2ALL, rankSize_, Count2Size(count, dataType), + comm_->localRankSize_, comm_->commArgs_.extraFlag); + return LoadMTE(LcalType::ALL2ALL, args, blockDim, dataType, stream); +} + +int Lccl::All2All(void *sendBuff, void *recvBuff, int64_t count, int32_t burstLen, + int32_t stride, HcclDataType dataType, aclrtStream stream) const +{ + if (!CheckBuff(sendBuff, recvBuff)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (rankSize_ <= 1) { + return LoopBack(sendBuff, recvBuff, count, dataType, stream); + } + ReportTiming report("LcclAll2AllTranspose", comm_->commDomain_, count, dataType); + + AscendCCLKernelArgs args = { sendBuff, recvBuff, comm_->commArgsPtr_, count, comm_->magic_, burstLen, stride}; + comm_->magic_++; + uint32_t blockDim = GetBlockNum(LcalType::ALL2ALL, rankSize_, Count2Size(count, dataType), + comm_->localRankSize_, comm_->commArgs_.extraFlag); + return LoadMTE(LcalType::ALL2ALL, args, blockDim, dataType, stream); +} + +int64_t GetSizeByHcclDataType(const HcclDataType &dataType) +{ + int64_t dataSize = sizeof(int); + switch (dataType) { + case HCCL_DATA_TYPE_INT8: + case HCCL_DATA_TYPE_UINT8: + dataSize = sizeof(int8_t); + break; + case HCCL_DATA_TYPE_INT16: + case HCCL_DATA_TYPE_FP16: + case HCCL_DATA_TYPE_BFP16: + case HCCL_DATA_TYPE_UINT16: + dataSize = sizeof(int16_t); + break; + case HCCL_DATA_TYPE_FP32: + case HCCL_DATA_TYPE_INT32: + case HCCL_DATA_TYPE_UINT32: + dataSize = sizeof(int32_t); + break; + case HCCL_DATA_TYPE_INT64: + case HCCL_DATA_TYPE_UINT64: + dataSize = sizeof(int64_t); + break; + default: + MKI_LOG(ERROR) << "unknown datatype"; + } + return dataSize; +} + +int Lccl::Broadcast(void *buff, int64_t count, HcclDataType dataType, int32_t root, aclrtStream stream) const +{ + constexpr int supportRankSize = 8; + if (rankSize_ <= 1) { + return LCAL_SUCCESS; + } + if (rankSize_ > supportRankSize) { + MKI_LOG(ERROR) << "Broadcast does not support 
ranksize over 8"; + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (!CheckBuff(buff, buff)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + ReportTiming report("LcclBroadcast", comm_->commDomain_, count, dataType); + AscendCCLKernelArgs args = { buff, buff, comm_->commArgsPtr_, count, comm_->magic_, 0, root }; + comm_->magic_++; + uint32_t blockDim = GetBlockNum(LcalType::BROADCAST, rankSize_, Count2Size(count, dataType), + comm_->localRankSize_, comm_->commArgs_.extraFlag); + return LoadMTE(LcalType::BROADCAST, args, blockDim, dataType, stream); +} + +Lccl::~Lccl() +{ + if (rankSize_ == -1 and comm_ != nullptr) { + delete comm_; + } +} + +Lccl::Lccl(LcalComm *comm) : comm_(comm) +{ + if (comm != nullptr) { + rank_ = comm->rank_; + rankSize_ = comm->rankSize_; + } else { + MKI_LOG(ERROR) << "comm is nullptr."; + comm_ = new (std::nothrow) LcalComm(0, 0); + if (comm_ == nullptr) { + MKI_LOG(ERROR) << "LcalComm create failed " << __LINE__; + } + rankSize_ = -1; + } +} + +Lccl::Lccl(LcalComm &comm) : comm_(&comm) +{ + rank_ = comm.rank_; + rankSize_ = comm.rankSize_; +} +} \ No newline at end of file diff --git a/comm/lcal/src/lcoc.cpp b/comm/lcal/src/lcoc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ac343d62db27b4eabeaac87b2eecba0ece2bb91 --- /dev/null +++ b/comm/lcal/src/lcoc.cpp @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "lcal_internal.h"
+#include "mki/utils/log/log.h"
+#include "mki/utils/env/env.h"
+#include "profiling/report_timing.h"
+#include "runtime/rt_ffts.h"
+
+using namespace std;
+using namespace chrono;
+namespace Lcal {
+bool CheckLcalComm(const LcalComm *lcalComm)
+{
+    if (lcalComm == nullptr) {
+        MKI_LOG(ERROR) << "The lcalComm is nullptr!";
+        return false;
+    }
+
+    auto rank = lcalComm->GetRank();
+    auto rankSize = lcalComm->GetRankSize();
+    auto coreNum = lcalComm->GetPhysicalInfo().coreNum;
+    std::vector<std::tuple<std::string, int64_t, int64_t, int64_t>> paramCheckList = {
+        {"rankSize", rankSize, PARAM_CHECK_MIN_VALUE_ONE, LCAL_MAX_RANK_SIZE},
+        {"rank", rank, PARAM_CHECK_MIN_VALUE_ZERO, rankSize - 1},
+        {"coreNum", coreNum, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+    };
+    return CheckParamScopeList(paramCheckList);
+}
+
+bool CheckLcalType(LcalType lcalType)
+{
+    if (lcalType < LcalType::PURE_MATMUL || lcalType >= LcalType::LCAL_TYPE_MAX) {
+        MKI_LOG(ERROR) << "The lcalType:" << static_cast<int>(lcalType)
+            << " must be in [" << static_cast<int>(LcalType::PURE_MATMUL) << ", "
+            << static_cast<int>(LcalType::LCAL_TYPE_MAX) << ")!";
+        return false;
+    }
+    return true;
+}
+
+bool Check2DTPType(LcalType lcalType)
+{
+    return lcalType == LcalType::ALL_GATHER_MATMUL_REDUCE_SCATTER;
+}
+
+bool CheckMOEType(LcalType lcalType)
+{
+    return (lcalType >= LcalType::ALLTOALLV_ALLGATHER_MATMUL) &&
+        (lcalType <= LcalType::MATMUL_REDUCESCATTER_ALLTOALLVC_HIDDEN);
+}
+
+bool CheckCoCParamDesc(LcalType lcalType, const CoCParamDesc &paramDesc)
+{
+    if (COC_TYPE2ELE_SIZE.find(paramDesc.dataTypeDesc) == COC_TYPE2ELE_SIZE.end()) {
+        MKI_LOG(ERROR) << "The dataTypeDesc:" << paramDesc.dataTypeDesc << " is not supported yet!";
+        return false;
+    }
+    if (paramDesc.op != HCCL_REDUCE_SUM) {
+        MKI_LOG(ERROR) << "The ReduceOp:" << paramDesc.op << " is not supported yet!";
+        return false;
+    }
+
+    auto batchSize = paramDesc.mmInfo.batchSize;
+    auto m = paramDesc.mmInfo.m;
+    auto n = paramDesc.mmInfo.n;
+    auto k = paramDesc.mmInfo.k;
+    std::vector<std::tuple<std::string, int64_t, int64_t, int64_t>> paramCheckList = {
+        {"batchSize", batchSize, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MIN_VALUE_ONE},
+        {"m", m, INPUT_PARAM_DEFAULT_VALUE, MAX_M_VALUE},
+        {"n", n, PARAM_CHECK_MIN_VALUE_ONE, MAX_N_VALUE},
+        {"k", k, PARAM_CHECK_MIN_VALUE_ONE, MAX_K_VALUE},
+    };
+    if (Check2DTPType(lcalType)) {
+        auto agDim = paramDesc.twoDimTPInfo.agDim;
+        auto rsDim = paramDesc.twoDimTPInfo.rsDim;
+        paramCheckList.emplace_back("agDim", agDim, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE);
+        paramCheckList.emplace_back("rsDim", rsDim, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE);
+    }
+    if (CheckMOEType(lcalType)) {
+        auto ep = paramDesc.moeInfo.EP;
+        auto tp = paramDesc.moeInfo.TP;
+        auto localExpertNums = paramDesc.moeInfo.local_expert_nums;
+        paramCheckList.emplace_back("ep", ep, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE);
+        paramCheckList.emplace_back("tp", tp, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE);
+        paramCheckList.emplace_back("localExpertNums", localExpertNums,
+            PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE);
+    }
+    return CheckParamScopeList(paramCheckList);
+}
+
+bool Lcoc::CheckInputParam(LcalType lcalType, const CoCTiling &tiling, const CoCParamDesc &paramDesc) const
+{
+    if (!CheckLcalComm(comm_)) {
+        return false;
+    }
+    if (!CheckLcalType(lcalType)) {
+        return false;
+    }
+    if (!CheckCoCTiling(tiling)) {
+        return false;
+    }
+    if (!CheckCoCParamDesc(lcalType, paramDesc)) {
+        return false;
+    }
+    return true;
+}
+
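+// NOTE on the checks above: batchSize is pinned to exactly 1 (its min and max
+// are both PARAM_CHECK_MIN_VALUE_ONE), entries whose value equals
+// INPUT_PARAM_DEFAULT_VALUE are skipped entirely by CheckParamScopeList (see
+// lcoc_func.cpp), and m uses INPUT_PARAM_DEFAULT_VALUE as its lower bound, so
+// it is effectively constrained only from above by MAX_M_VALUE.
+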
+void Lcoc::SetTaskParam(LcalType lcalType, const CoCParamDesc &paramDesc, const LcalComm &comm)
+{
+    taskParam_.rank = comm.GetRank();
+    taskParam_.rankSize = comm.GetRankSize();
+    taskParam_.blockDim = comm.GetPhysicalInfo().coreNum;
+    taskParam_.chipName = comm.GetPhysicalInfo().chipName;
+    taskParam_.cocParamDesc = paramDesc;
+    taskParam_.lcalType = lcalType;
+    taskParam_.bufferSize = comm.GetBufferSize();
+}
+
+void Lcoc::SetLcocParam(LcalType lcalType, const CoCParamDesc &paramDesc)
+{
+    SetTaskParam(lcalType, paramDesc, *comm_);
+    tilingSuccess_ = false;
+}
+
+CoCTilingFunc *CreateCoCTilingFunc(LcalType lcalType)
+{
+    bool isDeterministic = false;
+    const char *lcocDeterministic = Mki::GetEnv("LCCL_DETERMINISTIC");
+    std::string lcocDeterministicStr = lcocDeterministic == nullptr ? "" : lcocDeterministic;
+    if (lcocDeterministicStr == "1" || lcocDeterministicStr == "true") {
+        isDeterministic = true;
+    }
+    CoCTilingFunc *pTilingFunc = nullptr;
+    switch (lcalType) {
+        case LcalType::ALL_GATHER_MATMUL:
+            pTilingFunc = new (std::nothrow) CoCAllGatherMatmulTilingFunc();
+            break;
+        case LcalType::ALL_GATHER_MATMUL_V2:
+            pTilingFunc = new (std::nothrow) CoCAllGatherMatmulV2TilingFunc();
+            break;
+        case LcalType::MATMUL_REDUCE_SCATTER:
+            pTilingFunc = new (std::nothrow) CoCMatmulReduceScatterTilingFunc();
+            break;
+        case LcalType::MATMUL_ALL_REDUCE:
+            if (isDeterministic) {
+                pTilingFunc = new (std::nothrow) CoCMatmulAllReduceDeterTilingFunc();
+            } else {
+                pTilingFunc = new (std::nothrow) CoCMatmulAllReduceTilingFunc();
+            }
+            break;
+        case LcalType::ALL_GATHER_MATMUL_REDUCE_SCATTER:
+            pTilingFunc = new (std::nothrow) CoCAllgatherMatmulReduceScatterTilingFunc();
+            break;
+        case LcalType::ALLTOALLV_ALLGATHER_MATMUL:
+            pTilingFunc = new (std::nothrow) CoCAllToAllAllGatherMatmulTilingFunc();
+            break;
+        case LcalType::ALLTOALLVC_ALLGATHER_MATMUL_HIDDEN:
+            pTilingFunc = new (std::nothrow) CoCAllToAllAllGatherMatmulHiddenTilingFunc();
+            break;
+        case LcalType::MATMUL_REDUCESCATTER_ALLTOALLVC_HIDDEN:
+            pTilingFunc = new (std::nothrow) CoCMatmulReduceScatterAllToAllHiddenTilingFunc();
+            break;
+        default:
+            pTilingFunc = new (std::nothrow) CoCTilingFunc();
+    }
+    return pTilingFunc;
+}
+
+Lcoc::~Lcoc() {}
+
+Lcoc::Lcoc(LcalComm *comm) : comm_(comm) {}
+
+Lcoc::Lcoc(LcalComm &comm) : comm_(&comm) {}
+
+int Lcoc::SetParam(LcalType lcalType, const CoCTiling &tiling, const CoCParamDesc &paramDesc)
+{
+    // Check the input parameters
+    if (!CheckInputParam(lcalType, tiling, paramDesc)) {
+        return LCAL_ERROR_PARA_CHECK_FAIL;
+    }
+    // Set the LCOC initialization parameters
+    SetLcocParam(lcalType, paramDesc);
+    // Create the tiling function
+    CoCTilingFunc *pTilingFunc = CreateCoCTilingFunc(lcalType);
+    if (pTilingFunc == nullptr) {
+        PrintErrorLog(lcalType, "Create CoCTilingFunc failed!");
+        return LCAL_ERROR_INTERNAL;
+    }
+    // Generate the tiling strategy parameters
+    CoCTilingData tilingData = pTilingFunc->GenerateTiling(taskParam_, tiling);
+    // Check whether the generated tiling parameters are valid
+    bool tilingCheckRes = pTilingFunc->CheckTiling(taskParam_);
+    if (!tilingCheckRes) {
+        PrintErrorLog(lcalType, "Tiling check failed!");
+        // Release the tiling function
+        delete pTilingFunc;
+        pTilingFunc = nullptr;
+        return LCAL_ERROR_INTERNAL;
+    }
+    // Store the tiling parameters
+    tiling_ = tilingData;
+    // Mark tiling as successful
+    tilingSuccess_ = true;
+    // Release the tiling function
+    delete pTilingFunc;
+    pTilingFunc = nullptr;
+    return LCAL_SUCCESS;
+}
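+
+/*
+ * Typical call sequence for the launch entry points below (a hedged sketch,
+ * not part of the public headers; exact field names follow lcoc_args.h):
+ *
+ *     Lcoc lcoc(comm);
+ *     CoCParamDesc desc{};
+ *     desc.op = HCCL_REDUCE_SUM;            // only SUM passes CheckCoCParamDesc
+ *     desc.mmInfo.m = 4096; desc.mmInfo.k = 8192; desc.mmInfo.n = 4096;
+ *     CoCTiling tiling{};                   // defaults; GenerateTiling fills in the rest
+ *     if (lcoc.SetParam(LcalType::MATMUL_ALL_REDUCE, tiling, desc) == LCAL_SUCCESS) {
+ *         void *ws = nullptr;
+ *         aclrtMalloc(&ws, lcoc.GetWorkspaceSize(), ACL_MEM_MALLOC_HUGE_FIRST);
+ *         lcoc.MatmulAllReduce(inputPkg, outputPkg, ws, stream);
+ *     }
+ */
+int Lcoc::LaunchOperator(CoCInputPkg &inputPkg, CoCOutputPkg &outputPkg, void *workspace, aclrtStream stream)
+{
+    CoCKernelArgs args;
+    int error = args.SetFFTSAddr();
+    if (error != LCAL_SUCCESS) {
+        return error;
+    }
+    auto paramDesc =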
taskParam_.cocParamDesc; + args.SetInputPkgArgs(inputPkg); + args.SetOutputPkgArgs(outputPkg); + args.SetWorkspacePtrArg(workspace); + args.SetParamDescArgs(paramDesc); + args.SetCommArgs(*comm_); + args.SetCoCTilingDataArgs(tiling_); + MKI_LOG(DEBUG) << "[" << LCAL_TYPE2NAME.at(taskParam_.lcalType) << "]:" << args.ParamToString(); + return ComputeOverComm(taskParam_.lcalType, args, COC_TYPE2HCCL_TYPE.at(paramDesc.dataTypeDesc), stream); +} + +bool Lcoc::CheckBasic(const CoCInputPkg &inputPkg, const CoCOutputPkg &outputPkg, LcalType lcalType) const +{ + (void) outputPkg; + if (!tilingSuccess_) { + std::string str = "Tiling error. Please check whether the 'Lcoc::SetParam' method has been called, " + "or verify if the tiling parameter is valid."; + PrintErrorLog(lcalType, str); + return false; + } + if (taskParam_.lcalType != lcalType) { + std::string str = "lcalType of Lcoc::SetParam doesn't match launch function."; + PrintErrorLog(lcalType, str); + return false; + } + if (COC_TYPE2HCCL_TYPE.find(taskParam_.cocParamDesc.dataTypeDesc) == COC_TYPE2HCCL_TYPE.end()) { + std::string str = "invalid dataTypeDesc"; + PrintErrorLog(lcalType, str); + return false; + } + if (inputPkg.matrixA == nullptr || inputPkg.matrixB == nullptr) { + std::string str = "inputPkg.matrixA or inputPkg.matrixB is nullptr"; + PrintErrorLog(lcalType, str); + return false; + } + return true; +} + +int Lcoc::AllGatherMatmul(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream) +{ + LcalType lcalType = LcalType::ALL_GATHER_MATMUL; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + ReportTiming report("LcocAllGatherMatmul", true); + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} + +int Lcoc::AllGatherMatmulV2(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream) +{ + LcalType lcalType = LcalType::ALL_GATHER_MATMUL_V2; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + ReportTiming report("LcocAllGatherMatmulV2", true); + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} + +int Lcoc::MatmulReduceScatter(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream) +{ + LcalType lcalType = LcalType::MATMUL_REDUCE_SCATTER; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + if (taskParam_.cocParamDesc.mmInfo.m % taskParam_.rankSize != 0) { + if (taskParam_.rank == 0) { + MKI_LOG(ERROR) << "MatmulReduceScatter: input tensor must be the same size as output size times world size"; + } + return LCAL_ERROR_PARA_CHECK_FAIL; + } + ReportTiming report("LcocMatmulReduceScatter", true); + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} + +int Lcoc::MatmulAllReduce(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream) +{ + LcalType lcalType = LcalType::MATMUL_ALL_REDUCE; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + ReportTiming report("LcocMatmulAllReduce", true); + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} + +int Lcoc::PureMatmul(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, aclrtStream stream) +{ + LcalType lcalType = LcalType::PURE_MATMUL; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + ReportTiming report("LcocPureMatmul", true); + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} + +int 
Lcoc::AllGatherMatmulReduceScatter(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, + aclrtStream stream) +{ + LcalType lcalType = LcalType::ALL_GATHER_MATMUL_REDUCE_SCATTER; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + ReportTiming report("LcocAllGatherMatmulReduceScatter", true); + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} + +int Lcoc::AllToAllVAllGatherMatmul(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, + aclrtStream stream) +{ + LcalType lcalType = LcalType::ALLTOALLV_ALLGATHER_MATMUL; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} + +int Lcoc::MatmulReduceScatterAllToAllVHidden(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, + aclrtStream stream) +{ + LcalType lcalType = LcalType::MATMUL_REDUCESCATTER_ALLTOALLVC_HIDDEN; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} + +int Lcoc::AllToAllVAllGatherMatmulHidden(CoCInputPkg inputPkg, CoCOutputPkg outputPkg, void *workspace, + aclrtStream stream) +{ + LcalType lcalType = LcalType::ALLTOALLVC_ALLGATHER_MATMUL_HIDDEN; + if (!CheckBasic(inputPkg, outputPkg, lcalType)) { + return LCAL_ERROR_PARA_CHECK_FAIL; + } + return LaunchOperator(inputPkg, outputPkg, workspace, stream); +} +LcalComm *Lcoc::GetComm() +{ + return comm_; +} + +MatMulInfo &Lcoc::GetMatMulInfo() +{ + return taskParam_.cocParamDesc.mmInfo; +} + +void Lcoc::GetTiling(CoCTiling &tiling) +{ + tiling = tiling_; +} + + +bool IsMatrixAligned(const int64_t &m, const int64_t &n, const bool &transpose, int nElemAlign) +{ + if (nElemAlign == 0) { + return false; + } + return (transpose ? m : n) % nElemAlign == 0; +} + +int64_t Lcoc::GetWorkspaceSize() +{ + LcalType lcalType = taskParam_.lcalType; + auto cocParamDesc = taskParam_.cocParamDesc; + bool isDeterministic = (GetComm()->GetCommArgs()->extraFlag & ExtraFlag::DETERMINISTIC) != 0; + CoCDataTypeDesc dataType = cocParamDesc.dataTypeDesc; + const MatMulInfo &mmInfo = cocParamDesc.mmInfo; + const QuantInfo &quantInfo = cocParamDesc.quantInfo; + const MoeInfo& moeInfo = cocParamDesc.moeInfo; + bool hasQuant = quantInfo.quantGranularity != QuantGranularity::QUANT_GRANULARITY_UNDEFINED; + bool hasDequant = quantInfo.dequantGranularity != QuantGranularity::QUANT_GRANULARITY_UNDEFINED; + int32_t eleSize = COC_TYPE2ELE_SIZE.at(dataType); + int32_t nElemAlign = Lcal::ALIGN_BYTES / eleSize; + int32_t mAlign = AlignUp(mmInfo.m, nElemAlign); + int32_t nAlign = AlignUp(mmInfo.n, nElemAlign); + int32_t kAlign = AlignUp(mmInfo.k, nElemAlign); + int32_t maxOutputSize = moeInfo.maxOutputSize; + + bool hasAAlign = hasQuant || (!IsMatrixAligned(mmInfo.m, mmInfo.k, mmInfo.transA, nElemAlign) && mmInfo.m != 1); + + bool hasBAlign = (!mmInfo.weightNz) && ((hasDequant && !mmInfo.isInt8) + || (!IsMatrixAligned(mmInfo.k, mmInfo.n, mmInfo.transB, nElemAlign))); + + int32_t accumRankSize = taskParam_.lcalType == LcalType::ALL_GATHER_MATMUL ? 
taskParam_.rankSize : 0; + + bool hasAccum = dataType == CoCDataTypeDesc::INT8INT8_INT32_BF16; + bool hasDequantParam = (quantInfo.dequantGranularity == QuantGranularity::PER_TOKEN || + quantInfo.dequantGranularity == QuantGranularity::PER_TENSOR); + bool hasFormatDequantScale = (quantInfo.dequantGranularity == QuantGranularity::PER_CHANNEL); + bool isMoe = false; + if (lcalType == LcalType::ALLTOALLV_ALLGATHER_MATMUL || + lcalType == LcalType::ALLTOALLVC_ALLGATHER_MATMUL_HIDDEN || + lcalType == LcalType::MATMUL_REDUCESCATTER_ALLTOALLVC_HIDDEN) { + isMoe = true; + } + bool isAlltoallVc = + lcalType == LcalType::ALLTOALLV_ALLGATHER_MATMUL || lcalType == LcalType::ALLTOALLVC_ALLGATHER_MATMUL_HIDDEN || + lcalType == LcalType::MATMUL_REDUCESCATTER_ALLTOALLVC_HIDDEN; + + uint64_t dequantWorkSpaceSize = GetDequantWorkSpaceSize(lcalType, tiling_.withSerialMode, mmInfo.m, mmInfo.n, + tiling_.m0, tiling_.n0, tiling_.pValue, tiling_.nLoop, taskParam_.rankSize, taskParam_.blockDim, maxOutputSize); + LcalWorkspaceInfo lcalWorkspaceInfo = GetLcalWorkspaceInfo(0, mmInfo.batchSize, mmInfo.m, mmInfo.k, + mmInfo.n, mAlign, kAlign, nAlign, mmInfo.transA, mmInfo.transB, eleSize, hasAAlign, hasBAlign, + accumRankSize, hasAccum, dequantWorkSpaceSize, hasDequantParam, hasFormatDequantScale, isDeterministic, + isMoe, isAlltoallVc, moeInfo.EP, moeInfo.local_expert_nums, maxOutputSize); + + MKI_LOG(DEBUG) << "[Lcoc Workspace]: " << "m=" << mmInfo.m << ", k=" << mmInfo.k << ", n=" << mmInfo.n + << ", mAlign=" << mAlign << ", kAlign=" << kAlign << ", nAlign=" << nAlign << ", transA=" << mmInfo.transA + << ", transB=" << mmInfo.transB << ", eleSize=" << eleSize << ", hasAAlign=" << hasAAlign + << ", hasBAlign=" << hasBAlign << ", accumRankSize=" << accumRankSize << ", hasAccum=" << hasAccum + << ", dequantWorkSpaceSize=" << dequantWorkSpaceSize << ", hasDequantParam=" << hasDequantParam + << ", hasFormatDequantScale=" << hasFormatDequantScale << ", isDeterministic=" << isDeterministic + << ", isMoe=" << isMoe << ", isAlltoallVc=" << isAlltoallVc << ", moeInfo.EP=" << static_cast(moeInfo.EP) + << ", moeInfo.local_expert_nums=" << moeInfo.local_expert_nums + << ", maxOutputSize=" << maxOutputSize << ", workspaceSize=" << lcalWorkspaceInfo.workspaceSize; + return lcalWorkspaceInfo.workspaceSize; +} +} diff --git a/comm/lcal/src/lcoc_func.cpp b/comm/lcal/src/lcoc_func.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed5550fd2296e5e56a2f8c58deb3871402f856c5 --- /dev/null +++ b/comm/lcal/src/lcoc_func.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+
+#include "lcoc_func.h"
+#include "lcoc_args.h"
+#include "mki/utils/log/log.h"
+
+using namespace std;
+namespace Lcal {
+    // Validate that the parameter value lies in [min, max]; max = -1 means the valid range is [min, +inf)
+    bool CheckParamScope(const std::string &name, const int &value, const int &min, const int &max)
+    {
+        if (value < min || (max != PARAM_CHECK_MAX_VALUE && value > max)) {
+            if (max == PARAM_CHECK_MAX_VALUE) {
+                MKI_LOG(ERROR) << "The " << name << ":" << value << " must be equal to or greater than " << min << "!";
+            } else {
+                MKI_LOG(ERROR) << "The " << name << ":" << value << " must be in [" << min << ", " << max << "]!";
+            }
+            return false;
+        }
+        return true;
+    }
+
+    bool CheckParamScopeList(std::vector<std::tuple<std::string, int64_t, int64_t, int64_t>> paramCheckList)
+    {
+        for (auto &param : paramCheckList) {
+            auto name = std::get<0>(param);
+            auto value = std::get<1>(param);
+            auto min = std::get<2>(param);
+            auto max = std::get<3>(param);
+            if (value == INPUT_PARAM_DEFAULT_VALUE) {
+                continue;
+            }
+            if (!CheckParamScope(name, value, min, max)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    bool CheckParamAlign(const std::string &name, const int &value, const int &align)
+    {
+        if (align == 0) {
+            return false;
+        }
+        if (value % align != 0) {
+            MKI_LOG(ERROR) << "The " << name << ":" << value << " must be aligned by " << align << "!";
+            return false;
+        }
+        return true;
+    }
+
+    void PrintErrorLog(LcalType lcalType, const std::string &log)
+    {
+        MKI_LOG(ERROR) << "[" + LCAL_TYPE2NAME.at(lcalType) + "]: " << log;
+    }
+
+    bool CheckParamPowerOfTwo(const std::string &name, int value)
+    {
+        if (value <= 0) {
+            MKI_LOG(ERROR) << "The " << name << ":" << value << " must be greater than zero!";
+            return false;
+        }
+        if ((static_cast<uint32_t>(value) & (static_cast<uint32_t>(value) - 1)) != 0) {
+            MKI_LOG(ERROR) << "The " << name << ":" << value << " must be a power of two!";
+            return false;
+        }
+        return true;
+    }
+
+    int64_t GetAlignedMatrixSize(const int64_t &batchSize, const int64_t &m, const int64_t &n, const bool &transpose,
+        int nElemAlign)
+    {
+        if (nElemAlign == 0) {
+            return 0; // invalid alignment
+        }
+        int64_t nRow = transpose ? n : m;
+        int64_t nCol = transpose ? m : n;
+        int64_t nColAlign = (nCol + nElemAlign - 1) / nElemAlign * nElemAlign;
+        return batchSize * nRow * nColAlign;
+    }
+
+}
\ No newline at end of file
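GetAlignedMatrixSize above pads the contiguous dimension of a (possibly transposed) matrix up to a multiple of nElemAlign, which GetWorkspaceSize in lcoc.cpp derives as ALIGN_BYTES / eleSize. A self-contained sketch of the arithmetic; the 512-byte alignment and fp16 element size below are illustrative assumptions, not values taken from lcoc_args.h:

#include <cstdint>
#include <cstdio>

// Mirrors the padding arithmetic of GetAlignedMatrixSize for illustration.
int64_t AlignedMatrixSize(int64_t batchSize, int64_t m, int64_t n, bool transpose, int nElemAlign)
{
    if (nElemAlign == 0) {
        return 0;
    }
    int64_t nRow = transpose ? n : m;  // outer dimension stays as-is
    int64_t nCol = transpose ? m : n;  // inner (contiguous) dimension gets padded
    int64_t nColAlign = (nCol + nElemAlign - 1) / nElemAlign * nElemAlign;
    return batchSize * nRow * nColAlign;
}

int main()
{
    const int alignBytes = 512;                   // assumption; the real ALIGN_BYTES lives in lcoc_args.h
    const int eleSize = 2;                        // fp16
    const int nElemAlign = alignBytes / eleSize;  // 256 elements
    // A 1 x 1000 x 1000 fp16 matrix: each 1000-element row pads up to 1024 columns,
    // so this prints 1024000.
    printf("%lld elements\n", static_cast<long long>(AlignedMatrixSize(1, 1000, 1000, false, nElemAlign)));
    return 0;
}

diff --git a/comm/lcal/src/profiling/report_timing.h b/comm/lcal/src/profiling/report_timing.h
new file mode 100644
index 0000000000000000000000000000000000000000..bbcc3e365a741580fa36af1e59e380d714c4d761
--- /dev/null
+++ b/comm/lcal/src/profiling/report_timing.h
@@ -0,0 +1,373 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.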
+ */ +#ifndef REPORT_TIMING_H +#define REPORT_TIMING_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Lcal { +class ReportTiming { +public: + static constexpr uint64_t PROF_TASK_TIME_DUMP = 0x0000100000000ULL; + ReportTiming() = delete; + explicit ReportTiming(const char *opName, int commDomain, int64_t count = 0, + HcclDataType dataType = HCCL_DATA_TYPE_RESERVED) + : opName_(opName), typeMix_(false), count_(count), dataType_(dataType) + { + InitProfiling(commDomain); + } + + explicit ReportTiming(const char *opName, uint32_t blockDim) + : opName_(opName), blockDim_(blockDim), typeMix_(true) + { + InitProfiling(0); + } + + explicit ReportTiming(const char *opName, const int32_t rankId, const bool isReporting, uint8_t *dumpAddr, + const aclrtStream stream) : opName_(opName), rankId_(rankId), isReporting_(isReporting), + dumpAddr_(dumpAddr), stream_(stream) + { + moduleId_ = DUMP_MODULE_ID; + InitProfiling(0); + } + + ~ReportTiming() + { + MKI_LOG(DEBUG) << "ReportTiming " << __LINE__ << " ~ReportTiming() " << + " isReporting_:" << isReporting_ << " profEnable_:" << profEnable_; + if (profEnable_ && isReporting_) { + ReportMsprofData(); + } + + if (!isReporting_) { + ProfilingStatus(RESET_STATUS); + } + } + + void InitProfiling(int commDomain) + { + if (ProfilingStatus() == -1) { + ProfilingStatus(0); + MKI_LOG(INFO) << "MsprofRegisterCallback start!"; + if (MsprofRegisterCallback(moduleId_, ProfHandle) != 0) { + MKI_LOG(ERROR) << "MsprofRegisterCallback fail!"; + } + } + + MKI_LOG(DEBUG) << "InitProfiling " << __LINE__ << " ProfilingStatus():" << ProfilingStatus() << + " isReporting_:" << isReporting_; + if (ProfilingStatus() > 0) { + ParamsInit(commDomain); + } + MKI_LOG(DEBUG) << "InitProfiling " << __LINE__ << " ProfilingStatus():" << ProfilingStatus() << + " isReporting_:" << isReporting_ << " profEnable_:" << profEnable_; + } + + static int64_t ProfilingStatus(int64_t setValue = -1) + { + static int64_t profilingStatus = -1; + if (setValue == RESET_STATUS) { + profilingStatus = -1; + } else if (setValue != -1) { + profilingStatus = setValue; + } + return profilingStatus; + } + + void ParamsInit(int commDomain) + { + profEnable_ = true; + std::string groupName = std::to_string(commDomain); + groupHash_ = MsprofGetHashId(groupName.c_str(), strlen(groupName.c_str())); + + std::string naStr = "NA"; + naHash_ = MsprofGetHashId(naStr.c_str(), strlen(naStr.c_str())); + + nameHash_ = MsprofGetHashId(opName_, strlen(opName_)); + beginTime_ = MsprofSysCycleTime(); + } + + void ReportMsprofData() + { + tid_ = GetCurrentThreadId(); + if (tid_ == -1) { + MKI_LOG(ERROR) << "GetCurrentThreadId error!" << " name: " << opName_; + return; + } + endTime_ = MsprofSysCycleTime(); + + MKI_LOG(DEBUG) << "ReportMsprofData " << ProfilingStatus() << " dumpAddr_ is " << + (dumpAddr_ == nullptr ? 
"" : "not") << " nullptr "; + + if (ProfilingStatus() != PROF_TASK_TIME_DUMP || dumpAddr_ == nullptr) { + CallMsprofReportHostNodeApi(); + CallMsprofReportHostLcclOpApi(); + CallMsprofReportHostLcclOpInfo(); + CallMsprofReportHostNodeBasicInfo(); + CallMsprofReportContextIdInfo(); + } else { + CallMsprofReportDumpApi(); + } + } + + void CallMsprofReportDumpApi() const + { + constexpr uint32_t dumpCoreCnt = 75; + constexpr uint32_t dumpSizePerCore = 1 * 1024 * 1024; + constexpr uint32_t dumpWorkspaceSize = dumpCoreCnt * dumpSizePerCore; + + MKI_LOG(DEBUG) << "LcclReporting dump rankId " << rankId_; + uint8_t *devProfData = dumpAddr_; + size_t profLen = dumpWorkspaceSize; + + std::vector buffer(profLen, 0); + int ret = 0; + ret = aclrtMemcpyAsync(&buffer[0], profLen, devProfData, profLen, ACL_MEMCPY_DEVICE_TO_HOST, stream_); + if (ret != 0) { + MKI_LOG(ERROR) << "aclrtMemcpyAsync dump data failed"; + } + ret = aclrtSynchronizeStream(stream_); + if (ret != 0) { + MKI_LOG(ERROR) << "aclrtSynchronizeStream dump data failed"; + } + + constexpr int32_t logLimit = 2; + constexpr int32_t logFirstLimit = 10; + constexpr int32_t profLevel = 3000; + MsprofAdditionalInfo t; + t.level = profLevel; + t.type = 0; + t.threadId = 0; + t.dataLen = sizeof(LcclDumpLogInfo); + t.timeStamp = 0; + for (uint32_t coreId = 0; coreId < dumpCoreCnt; ++coreId) { + LcclDumpUnion *u = reinterpret_cast(&buffer[coreId * dumpSizePerCore]); + LcclDumpBlockInfo *b = &(u->blockInfo); + LcclDumpLogInfo *l = &((u + 1)->logInfo); + + int32_t logLen = (dumpSizePerCore - b->dumpOffset) / sizeof(LcclDumpUnion) - 1; + for (int32_t logInfoIdx = 0; logInfoIdx < logLen; ++logInfoIdx) { + LcclDumpLogInfo *logInfo = l + logInfoIdx; + auto ret = memcpy_s(t.data, sizeof(LcclDumpLogInfo), logInfo, sizeof(LcclDumpLogInfo)); + if (ret != 0) { + MKI_LOG(ERROR) << "LcclReporting report memcpy_s err " << ret; + } + if ((logInfoIdx < logLimit) || (logInfoIdx < logFirstLimit && rankId_ == 0 && coreId == 0)) { + MKI_LOG(DEBUG) << "LcclReporting report: rankId=" << rankId_ << ", coreId=" << coreId << + ", curLog=" << logInfoIdx << "/" << logLen << + "; LcclDumpLogInfo: logId=" << logInfo->logId << ", blockId=" << logInfo->blockId << + ", syscyc=" << logInfo->syscyc << ", curPc=" << logInfo->curPc << + ", operationType=" << logInfo->operationType; + } + MsprofReportAdditionalInfo(0, &t, sizeof(MsprofAdditionalInfo)); + } + } + } + + void CallMsprofReportHostNodeApi() const + { + MsprofApi reporterData{}; + reporterData.level = MSPROF_REPORT_NODE_LEVEL; + reporterData.type = MSPROF_REPORT_NODE_LAUNCH_TYPE; + reporterData.threadId = static_cast(tid_); + reporterData.beginTime = beginTime_; + reporterData.endTime = endTime_; + reporterData.itemId = nameHash_; + + auto ret = MsprofReportApi(true, &reporterData); + if (ret != 0) { + MKI_LOG(ERROR) << "CallMsprofReportHostNodeApi error! code: " << ret << " name: " << opName_; + } + } + + void CallMsprofReportHostLcclOpApi() const + { + if (typeMix_) { + return; + } + MsprofApi reporterData{}; + reporterData.level = MSPROF_REPORT_HCCL_NODE_LEVEL; + reporterData.type = MSPROF_REPORT_HCCL_MASTER_TYPE; + reporterData.threadId = static_cast(tid_); + reporterData.beginTime = beginTime_; + reporterData.endTime = endTime_; + reporterData.itemId = nameHash_; + + auto ret = MsprofReportApi(true, &reporterData); + if (ret != 0) { + MKI_LOG(ERROR) << "CallMsprofReportHostLcclOpApi error! 
code: " << ret << " name: " << opName_; + } + } + + void CallMsprofReportHostLcclOpInfo() const + { + if (typeMix_) { + return; + } + MsprofCompactInfo reporterData = {}; + reporterData.level = MSPROF_REPORT_NODE_LEVEL; + reporterData.type = MSPROF_REPORT_NODE_HCCL_OP_INFO_TYPE; + reporterData.threadId = static_cast(tid_); + reporterData.dataLen = sizeof(MsprofHCCLOPInfo); + reporterData.timeStamp = beginTime_ + 1; + + reporterData.data.hcclopInfo.relay = 0; + reporterData.data.hcclopInfo.retry = 0; + reporterData.data.hcclopInfo.dataType = dataType_; + reporterData.data.hcclopInfo.algType = naHash_; + reporterData.data.hcclopInfo.count = count_; + reporterData.data.hcclopInfo.groupName = groupHash_; + + auto ret = MsprofReportCompactInfo(static_cast(true), + static_cast(&reporterData), static_cast(sizeof(MsprofCompactInfo))); + if (ret != 0) { + MKI_LOG(ERROR) << "CallMsprofReportHostLcclOpInfo error! code: " << ret << " name: " << opName_; + } + } + + void CallMsprofReportHostNodeBasicInfo() const + { + if (ProfilingStatus() == PROF_TASK_TIME_L0) { + return; + } + MsprofCompactInfo reporterData{}; + + reporterData.level = MSPROF_REPORT_NODE_LEVEL; + reporterData.type = MSPROF_REPORT_NODE_BASIC_INFO_TYPE; + reporterData.threadId = static_cast(tid_); + reporterData.dataLen = sizeof(MsprofNodeBasicInfo); + reporterData.timeStamp = endTime_; + + reporterData.data.nodeBasicInfo.opName = nameHash_; + reporterData.data.nodeBasicInfo.opType = nameHash_; + reporterData.data.nodeBasicInfo.blockDim = ((blockDim_ & 0x0000FFFU) | 0x20000U); + + auto ret = MsprofReportCompactInfo(static_cast(true), + static_cast(&reporterData), + static_cast(sizeof(MsprofCompactInfo))); + if (ret != 0) { + MKI_LOG(ERROR) << "CallMsprofReportHostNodeBasicInfo error! code: " << ret << " name: " << opName_; + } + } + + void CallMsprofReportContextIdInfo() const + { + if (!typeMix_) { + return; + } + + MsprofAdditionalInfo additionalInfo = {}; + additionalInfo.magicNumber = MSPROF_REPORT_DATA_MAGIC_NUM; + additionalInfo.level = MSPROF_REPORT_NODE_LEVEL; + additionalInfo.type = MSPROF_REPORT_NODE_CONTEXT_ID_INFO_TYPE; + additionalInfo.timeStamp = beginTime_ + 1; + additionalInfo.threadId = static_cast(tid_); + additionalInfo.dataLen = sizeof(MsprofContextIdInfo); + + MsprofContextIdInfo info = {}; + info.opName = nameHash_; + info.ctxIdNum = 1; + info.ctxIds[0] = 0; + + int ret = memcpy_s(additionalInfo.data, MSPROF_ADDTIONAL_INFO_DATA_LENGTH, &info, sizeof(MsprofContextIdInfo)); + MKI_LOG_IF(ret != EOK, ERROR) << "memcpy_s Error! Error Code: " << ret; + + auto retReport = MsprofReportAdditionalInfo(static_cast(true), + static_cast(&additionalInfo), + static_cast(sizeof(MsprofAdditionalInfo))); + if (retReport != 0) { + MKI_LOG(ERROR) << "ProfReportAdditionalInfo error!" << " name: " << opName_; + } + } + + static int32_t GetCurrentThreadId() + { + int32_t tid = static_cast(syscall(SYS_gettid)); + if (tid == -1) { + MKI_LOG(ERROR) << "get tid failed, errno: " << errno; + } + return tid; + } + + static int32_t ProfHandle(uint32_t type, void *data, uint32_t len) + { + if (data == nullptr) { + MKI_LOG(ERROR) << "ProfHandle failed! data is nullptr!"; + return -1; + } + if (type != PROF_CTRL_SWITCH) { + MKI_LOG(ERROR) << "ProfHandle failed! ProfCtrlType is not correct!"; + return -1; + } + if (len < sizeof(MsprofCommandHandle)) { + MKI_LOG(ERROR) << "ProfHandle failed! 
dataSize is not correct!";
+            return -1;
+        }
+        MsprofCommandHandle *profilerConfig = static_cast<MsprofCommandHandle *>(data);
+        const uint32_t profType = profilerConfig->type;
+        const uint64_t profSwitch = profilerConfig->profSwitch;
+        if (profType == PROF_COMMANDHANDLE_TYPE_START) {
+            MKI_LOG(INFO) << "Open Profiling Switch " << std::hex << profSwitch << std::dec;
+            if ((profSwitch & PROF_TASK_TIME_L0) != PROF_CTRL_INVALID) {
+                ProfilingStatus(PROF_TASK_TIME_L0);
+                MKI_LOG(DEBUG) << "Profiling Level0 Enable";
+            }
+            if ((profSwitch & PROF_TASK_TIME_L1) != PROF_CTRL_INVALID) {
+                ProfilingStatus(PROF_TASK_TIME_L1);
+                MKI_LOG(DEBUG) << "Profiling Level1 Enable";
+            }
+            if ((profSwitch & PROF_TASK_TIME_DUMP) != PROF_CTRL_INVALID) {
+                ProfilingStatus(PROF_TASK_TIME_DUMP);
+                MKI_LOG(DEBUG) << "Profiling dump Enable";
+            }
+        }
+        if (profType == PROF_COMMANDHANDLE_TYPE_STOP) {
+            MKI_LOG(INFO) << "Close Profiling Switch";
+            ProfilingStatus(0);
+        }
+        return 0;
+    }
+
+private:
+    static constexpr uint64_t PROF_TASK_TIME_L0 = 0x00000800ULL;
+    static constexpr uint64_t PROF_TASK_TIME_L1 = 0x00000002ULL;
+    static constexpr int32_t DUMP_MODULE_ID = 61;
+    static constexpr int32_t RESET_STATUS = -2;
+    uint64_t beginTime_ = 0;
+    uint64_t endTime_ = 0;
+    const char *opName_ = nullptr;
+    uint32_t blockDim_ = 0;
+    uint64_t nameHash_ = 0;
+    uint64_t groupHash_ = 0;
+    uint64_t naHash_ = 0;
+    bool typeMix_ = false;
+    long tid_ = 0;
+    bool profEnable_ = false;
+    int64_t count_ = 0;
+    uint8_t dataType_ = HCCL_DATA_TYPE_RESERVED;
+    int32_t rankId_ = 0;
+    bool isReporting_ = true;
+    uint8_t *dumpAddr_ = nullptr;
+    aclrtStream stream_ = nullptr;
+    int32_t moduleId_ = INVLID_MOUDLE_ID;
+};
+}
+#endif
\ No newline at end of file
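ReportTiming is deliberately RAII-shaped: the constructor records MsprofSysCycleTime() and the destructor emits the report, so a single stack object brackets a kernel launch. A minimal usage sketch, mirroring how lccl.cpp instantiates it above; the wrapper function itself is hypothetical:

#include "profiling/report_timing.h"

// The ReportTiming object stamps its begin time on construction and reports
// begin/end to msprof from its destructor (when profiling is enabled).
int TimedLaunch(int commDomain, int64_t count, HcclDataType dataType)
{
    Lcal::ReportTiming report("LcclAllReduce", commDomain, count, dataType);
    // ... enqueue the kernel on the stream here ...
    return 0;
}   // ~ReportTiming() runs here and calls ReportMsprofData()

diff --git a/comm/lcal/src/tiling/allgather_reducescatter_tiling.cpp b/comm/lcal/src/tiling/allgather_reducescatter_tiling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f25363b392715bfbe031f70aa4676e199845fd87
--- /dev/null
+++ b/comm/lcal/src/tiling/allgather_reducescatter_tiling.cpp
@@ -0,0 +1,410 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.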
+ */
+#include <algorithm>
+#include <map>
+#include <vector>
+#include "tiling.h"
+#include "tiling_func.h"
+#include "lcoc_func.h"
+
+#define TILING_MAP std::map<int32_t, std::vector<std::vector<int32_t>>>
+namespace Lcal {
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_SWIZZLECOUNT_DEFAULT = 11;
+static TILING_MAP g_allgatherEightReducescatterTwoFalseFP16SwizzlecountMap = {
+    {9,
+     {{768, 1536, -1, 2147483647, -1, 7168},
+      {1536, 3072, -1, 5120, -1, 14848},
+      {1536, 2147483647, 5120, 2147483647, -1, 10752}}},
+    {14, {{768, 1536, -1, 5120, 10752, 2147483647}, {1536, 2147483647, -1, 5120, 14848, 2147483647}}}};
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_UBMOVENUM_DEFAULT = 40;
+static TILING_MAP g_allgatherEightReducescatterTwoFalseFP16UbmovenumMap = {
+    {24,
+     {{768, 1536, -1, 2147483647, 3072, 10752},
+      {1536, 3072, -1, 7168, 3072, 2147483647},
+      {3072, 2147483647, -1, 7168, 3072, 2147483647}}},
+    {30, {{3072, 2147483647, 7168, 2147483647, 3072, 2147483647}}}};
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_LENPERLOOPMULT_DEFAULT = 400;
+static TILING_MAP g_allgatherEightReducescatterTwoFalseFP16LenperloopmultMap = {
+    {2, {{768, 1536, -1, 5120, -1, 3072}}},
+    {4, {{3072, 2147483647, -1, 3072, -1, 2147483647}, {3072, 2147483647, 14848, 2147483647, 3072, 2147483647}}}};
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_COMMNPUSPLIT_DEFAULT = 8;
+static TILING_MAP g_allgatherEightReducescatterTwoFalseFP16CommnpusplitMap = {
+    {1,
+     {{768, 1536, 5120, 2147483647, -1, 3072},
+      {1536, 3072, 14848, 2147483647, -1, 7168},
+      {3072, 2147483647, 14848, 2147483647, -1, 2147483647}}}};
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_COMMDATASPLIT_DEFAULT = 1;
+static TILING_MAP g_allgatherEightReducescatterTwoFalseFP16CommdatasplitMap = {
+    {8,
+     {{768, 1536, 5120, 2147483647, -1, 3072},
+      {1536, 3072, 14848, 2147483647, -1, 7168},
+      {3072, 2147483647, 14848, 2147483647, -1, 2147483647}}}};
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_EXTRAUBMOVENUM_DEFAULT = 12;
+static TILING_MAP g_allgatherEightReducescatterTwoFalseFP16ExtraubmovenumMap = {
+    {10,
+     {{-1, 768, -1, 2147483647, 5120, 10752},
+      {768, 1536, -1, 2147483647, 5120, 2147483647},
+      {1536, 2147483647, -1, 10752, 5120, 2147483647},
+      {1536, 2147483647, 10752, 14848, -1, 10752}}},
+    {20, {{1536, 2147483647, 14848, 2147483647, 10752, 2147483647}}}};
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_EXTRALENPERLOOPMULT_DEFAULT = 4;
+static TILING_MAP g_allgatherEightReducescatterTwoFalseFP16ExtralenperloopmultMap = {
+    {1, {{3072, 2147483647, -1, 3072, 10752, 2147483647}, {1536, 2147483647, 3072, 7168, -1, 2147483647}}},
+    {2, {{768, 1536, 5120, 2147483647, 5120, 2147483647}, {1536, 2147483647, 7168, 10752, -1, 2147483647}}},
+    {400, {{3072, 2147483647, 10752, 2147483647, -1, 2147483647}}}};
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_EXTRACOMMNPUSPLIT_DEFAULT = 1;
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_EXTRACOMMDATASPLIT_DEFAULT = 8;
+
+// 821: agDim = 8, rsDim = 2, innerDimIsAg = true
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_SWIZZLECOUNT_DEFAULT = 5;
+static TILING_MAP g_allgatherEightReducescatterTwoTrueFP16SwizzlecountMap = {
+    {9, {{192, 768, -1, 2147483647, -1, 12800}}},
+    {17, {{768, 2147483647, 7168, 2147483647, -1, 7936}}},
+    {13,
+     {{-1, 192, 5120, 2147483647, -1, 12800},
+      {-1, 192, -1, 2147483647, 15360, 2147483647},
+      {768, 2147483647, -1, 7168, -1, 9088}}}};
+
+constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_UBMOVENUM_DEFAULT = 60;
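+// Reading these tables (assumed semantics, inferred from the data layout rather
+// than from SetTilingParam2D's documentation): each TILING_MAP key is a candidate
+// value for one tiling parameter, and each row is an (m, k, n) shape box of the
+// form {mMin, mMax, kMin, kMax, nMin, nMax}, where -1 and 2147483647 (INT32_MAX)
+// mark open bounds. A parameter takes the key whose box contains the problem
+// shape and falls back to its *_DEFAULT constant otherwise. For example,
+// m = 1024, k = 8192, n = 4096 lies inside {768, 1536, -1, 2147483647, -1, 7168},
+// so g_allgatherEightReducescatterTwoFalseFP16SwizzlecountMap would yield 9.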
+static TILING_MAP g_allgatherEightReducescatterTwoTrueFP16UbmovenumMap = { + {30, {{384, 2147483647, -1, 5120, 3968, 2147483647}, {768, 2147483647, 7168, 2147483647, -1, 2147483647}}}}; + +constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_LENPERLOOPMULT_DEFAULT = 400; + +constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_COMMNPUSPLIT_DEFAULT = 8; +static TILING_MAP g_allgatherEightReducescatterTwoTrueFP16CommnpusplitMap = { + {2, {{-1, 192, 7168, 2147483647, -1, 4608}, {192, 2147483647, 5120, 7168, -1, 4544}}}}; + +constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_COMMDATASPLIT_DEFAULT = 1; +static TILING_MAP g_allgatherEightReducescatterTwoTrueFP16CommdatasplitMap = { + {4, {{-1, 192, 7168, 2147483647, -1, 4608}, {192, 2147483647, 5120, 7168, -1, 4544}}}}; + +constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_COMMDIRECT_DEFAULT = 1; + +constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_EXTRAUBMOVENUM_DEFAULT = 20; +static TILING_MAP g_allgatherEightReducescatterTwoTrueFP16ExtraubmovenumMap = { + {60, {{-1, 192, -1, 5120, -1, 6912}, {-1, 192, -1, 2147483647, 10368, 2147483647}}}, + {40, + {{-1, 192, 5120, 2147483647, -1, 6912}, + {-1, 192, -1, 2147483647, 6912, 10368}, + {192, 384, -1, 2147483647, 1600, 4608}}}, + {30, {{192, 384, -1, 2147483647, 4608, 2147483647}, {768, 2147483647, -1, 5120, -1, 3968}}}}; + +constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_EXTRALENPERLOOPMULT_DEFAULT = 2; +static TILING_MAP g_allgatherEightReducescatterTwoTrueFP16ExtralenperloopmultMap = { + {4, {{384, 2147483647, -1, 5120, -1, 3968}, {384, 2147483647, 7168, 2147483647, -1, 2147483647}}}}; + +constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_EXTRACOMMNPUSPLIT_DEFAULT = 1; + +constexpr int32_t ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_EXTRACOMMDATASPLIT_DEFAULT = 8; + +// 281 +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_SWIZZLECOUNT_DEFAULT = 11; +static TILING_MAP g_allgatherTwoReducescatterEightTrueFP16SwizzlecountMap = { + {9, + {{3072, 6144, -1, 2147483647, -1, 10752}, + {12288, 2147483647, -1, 7168, -1, 10752}, + {12288, 2147483647, 10752, 2147483647, -1, 5120}}}, + {14, + {{-1, 3072, -1, 7168, -1, 14848}, + {-1, 3072, -1, 10752, 14848, 2147483647}, + {12288, 2147483647, 7168, 10752, -1, 5120}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_UBMOVENUM_DEFAULT = 10; +static TILING_MAP g_allgatherTwoReducescatterEightTrueFP16UbmovenumMap = { + {14, {{3072, 6144, 14848, 2147483647, -1, 3072}, {6144, 12288, 14848, 2147483647, 3072, 2147483647}}}, + {24, {{12288, 2147483647, 10752, 14848, -1, 2147483647}}}, + {32, {{-1, 6144, -1, 2147483647, 10752, 2147483647}, {12288, 2147483647, -1, 10752, -1, 2147483647}}}, + {40, {{3072, 6144, -1, 2147483647, 10752, 2147483647}, {6144, 12288, -1, 14848, 3072, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_LENPERLOOPMULT_DEFAULT = 400; +static TILING_MAP g_allgatherTwoReducescatterEightTrueFP16LenperloopmultMap = { + {4, + {{3072, 6144, -1, 2147483647, 3072, 10752}, + {12288, 2147483647, -1, 10752, -1, 3072}, + {6144, 2147483647, -1, 2147483647, 3072, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_COMMNPUSPLIT_DEFAULT = 1; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_COMMDATASPLIT_DEFAULT = 8; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_EXTRAUBMOVENUM_DEFAULT = 20; +static TILING_MAP g_allgatherTwoReducescatterEightTrueFP16ExtraubmovenumMap = 
{ + {8, {{6144, 2147483647, 3072, 5120, 7168, 2147483647}}}, + {10, {{6144, 2147483647, 3072, 5120, -1, 7168}, {6144, 2147483647, 5120, 14848, -1, 2147483647}}}, + {12, {{3072, 6144, 3072, 2147483647, 14848, 2147483647}}}, + {15, {{-1, 3072, 3072, 2147483647, -1, 10752}, {3072, 6144, -1, 2147483647, -1, 5120}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_EXTRALENPERLOOPMULT_DEFAULT = 2; +static TILING_MAP g_allgatherTwoReducescatterEightTrueFP16ExtralenperloopmultMap = { + {4, + {{-1, 3072, -1, 10752, 14848, 2147483647}, + {12288, 2147483647, 3072, 5120, -1, 2147483647}, + {6144, 2147483647, 5120, 7168, -1, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_EXTRACOMMNPUSPLIT_DEFAULT = 8; +static TILING_MAP g_allgatherTwoReducescatterEightTrueFP16ExtracommnpusplitMap = { + {1, {{12288, 2147483647, 14848, 2147483647, -1, 3072}, {12288, 2147483647, 14848, 2147483647, 5120, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_EXTRACOMMDATASPLIT_DEFAULT = 2; +static TILING_MAP g_allgatherTwoReducescatterEightTrueFP16ExtracommdatasplitMap = { + {8, {{12288, 2147483647, 14848, 2147483647, -1, 3072}, {12288, 2147483647, 14848, 2147483647, 5120, 2147483647}}}}; + +// 280 +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_SWIZZLECOUNT_DEFAULT = 9; +static TILING_MAP g_allgatherTwoReducescatterEightFalseFP16SwizzlecountMap = { + {13, + {{-1, 768, 1280, 2147483647, -1, 7168}, + {1536, 3072, -1, 2147483647, -1, 7168}, + {3072, 2147483647, 5184, 2147483647, -1, 2147483647}}}, + {17, + {{-1, 768, -1, 2147483647, 7168, 2147483647}, + {3072, 2147483647, -1, 4544, 7168, 2147483647}, + {3072, 2147483647, 4544, 5184, -1, 2147483647}}}, + {5, + {{768, 1536, -1, 2147483647, 5120, 2147483647}, + {3072, 2147483647, -1, 4544, -1, 7168}, + {3072, 2147483647, 7680, 2147483647, -1, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_UBMOVENUM_DEFAULT = 40; +static TILING_MAP g_allgatherTwoReducescatterEightFalseFP16UbmovenumMap = { + {30, + {{-1, 768, 2176, 3840, -1, 5120}, {-1, 768, 2560, 2147483647, 5120, 7168}, + {-1, 768, -1, 7680, 7168, 2147483647}, {1536, 3072, -1, 6400, -1, 2147483647}}}, + {60, {{-1, 768, 7680, 2147483647, 7168, 2147483647}, {768, 1536, -1, 1280, -1, 7168}}}, + {20, {{768, 1536, -1, 4352, 7168, 2147483647}, {3072, 2147483647, -1, 6400, -1, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_LENPERLOOPMULT_DEFAULT = 400; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_COMMNPUSPLIT_DEFAULT = 1; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_COMMDATASPLIT_DEFAULT = 8; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_COMMDIRECT_DEFAULT = 0; +static TILING_MAP g_allgatherTwoReducescatterEightFalseFP16CommdirectMap = { + {1, + {{-1, 768, 3456, 2147483647, -1, 5120}, + {-1, 768, 2560, 2147483647, 5120, 7168}, + {-1, 768, 4352, 7680, 7168, 2147483647}, + {768, 1536, -1, 2147483647, -1, 7168}, + {1536, 3072, 1280, 2147483647, -1, 2147483647}, + {3072, 2147483647, -1, 7680, 5120, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_EXTRAUBMOVENUM_DEFAULT = 60; +static TILING_MAP g_allgatherTwoReducescatterEightFalseFP16ExtraubmovenumMap = { + {40, {{768, 2147483647, -1, 2176, -1, 5120}}}, + {30, + {{768, 1536, 2176, 2147483647, -1, 5120}, + {768, 1536, -1, 2147483647, 5120, 2147483647}, + {1536, 2147483647, -1, 1792, 5120, 2147483647}}}, + {20, {{1536, 2147483647, 2176, 
2147483647, -1, 5120}, {1536, 2147483647, 1792, 2147483647, 5120, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_EXTRALENPERLOOPMULT_DEFAULT = 2; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_EXTRACOMMNPUSPLIT_DEFAULT = 8; +static TILING_MAP g_allgatherTwoReducescatterEightFalseFP16ExtracommnpusplitMap = { + {1, {{3072, 2147483647, 2176, 2147483647, -1, 5120}, {768, 2147483647, -1, 2147483647, 5120, 2147483647}}}}; + +constexpr int32_t ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_EXTRACOMMDATASPLIT_DEFAULT = 1; +static TILING_MAP g_allgatherTwoReducescatterEightFalseFP16ExtracommdatasplitMap = { + {8, {{3072, 2147483647, 2176, 2147483647, -1, 5120}, {768, 2147483647, -1, 2147483647, 5120, 2147483647}}}}; + +const int PVALUE_ONE = 1; +const int M0_DEFAULT = 128; +const int K0_DEFAULT = 256; +const int N0_DEFAULT = 256; +const int SWIZZLEDIRECT_ONE = 1; + +void AG8RS2FalseFP16Tiling(CoCTilingData &cocTilingData) +{ + std::map tilingParamMap = { + {&cocTilingData.swizzlCount, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_SWIZZLECOUNT_DEFAULT, + g_allgatherEightReducescatterTwoFalseFP16SwizzlecountMap}}, + {&cocTilingData.ubMoveNum, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_UBMOVENUM_DEFAULT, + g_allgatherEightReducescatterTwoFalseFP16UbmovenumMap}}, + {&cocTilingData.lenPerLoop, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_LENPERLOOPMULT_DEFAULT, + g_allgatherEightReducescatterTwoFalseFP16LenperloopmultMap}}, + {&cocTilingData.commNpuSplit, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_COMMNPUSPLIT_DEFAULT, + g_allgatherEightReducescatterTwoFalseFP16CommnpusplitMap}}, + {&cocTilingData.commDataSplit, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_COMMDATASPLIT_DEFAULT, + g_allgatherEightReducescatterTwoFalseFP16CommdatasplitMap}}, + {&cocTilingData.extraUbMoveNum, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_EXTRAUBMOVENUM_DEFAULT, + g_allgatherEightReducescatterTwoFalseFP16ExtraubmovenumMap}}, + {&cocTilingData.extraLenPerLoop, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_EXTRALENPERLOOPMULT_DEFAULT, + g_allgatherEightReducescatterTwoFalseFP16ExtralenperloopmultMap}}, + {&cocTilingData.extraCommNpuSplit, {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_EXTRACOMMNPUSPLIT_DEFAULT}}, + {&cocTilingData.extraCommDataSplit, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_FALSE_FP16_EXTRACOMMDATASPLIT_DEFAULT}}}; + SetTilingParam2D(cocTilingData, tilingParamMap); + return; +} + +void AG8RS2TrueFP16Tiling(CoCTilingData &cocTilingData) +{ + std::map tilingParamMap = { + {&cocTilingData.swizzlCount, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_SWIZZLECOUNT_DEFAULT, + g_allgatherEightReducescatterTwoTrueFP16SwizzlecountMap}}, + {&cocTilingData.ubMoveNum, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_UBMOVENUM_DEFAULT, + g_allgatherEightReducescatterTwoTrueFP16UbmovenumMap}}, + {&cocTilingData.lenPerLoop, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_LENPERLOOPMULT_DEFAULT}}, + {&cocTilingData.commNpuSplit, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_COMMNPUSPLIT_DEFAULT, + g_allgatherEightReducescatterTwoTrueFP16CommnpusplitMap}}, + {&cocTilingData.commDataSplit, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_COMMDATASPLIT_DEFAULT, + g_allgatherEightReducescatterTwoTrueFP16CommdatasplitMap}}, + {&cocTilingData.commDirect, {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_COMMDIRECT_DEFAULT}}, + {&cocTilingData.extraUbMoveNum, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_EXTRAUBMOVENUM_DEFAULT, + 
g_allgatherEightReducescatterTwoTrueFP16ExtraubmovenumMap}}, + {&cocTilingData.extraLenPerLoop, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_EXTRALENPERLOOPMULT_DEFAULT, + g_allgatherEightReducescatterTwoTrueFP16ExtralenperloopmultMap}}, + {&cocTilingData.extraCommNpuSplit, {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_EXTRACOMMNPUSPLIT_DEFAULT}}, + {&cocTilingData.extraCommDataSplit, + {ALLGATHER_EIGHT_REDUCESCATTER_TWO_TRUE_FP16_EXTRACOMMDATASPLIT_DEFAULT}}}; + SetTilingParam2D(cocTilingData, tilingParamMap); + return; +} + +void AG2RS8TrueFP16Tiling(CoCTilingData &cocTilingData) +{ + std::map tilingParamMap = { + {&cocTilingData.swizzlCount, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_SWIZZLECOUNT_DEFAULT, + g_allgatherTwoReducescatterEightTrueFP16SwizzlecountMap}}, + {&cocTilingData.ubMoveNum, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_UBMOVENUM_DEFAULT, + g_allgatherTwoReducescatterEightTrueFP16UbmovenumMap}}, + {&cocTilingData.lenPerLoop, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_LENPERLOOPMULT_DEFAULT, + g_allgatherTwoReducescatterEightTrueFP16LenperloopmultMap}}, + {&cocTilingData.commNpuSplit, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_COMMNPUSPLIT_DEFAULT}}, + {&cocTilingData.commDataSplit, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_COMMDATASPLIT_DEFAULT}}, + {&cocTilingData.extraUbMoveNum, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_EXTRAUBMOVENUM_DEFAULT, + g_allgatherTwoReducescatterEightTrueFP16ExtraubmovenumMap}}, + {&cocTilingData.extraLenPerLoop, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_EXTRALENPERLOOPMULT_DEFAULT, + g_allgatherTwoReducescatterEightTrueFP16ExtralenperloopmultMap}}, + {&cocTilingData.extraCommNpuSplit, + {DIM_EIGHT, g_allgatherTwoReducescatterEightTrueFP16ExtracommnpusplitMap}}, + {&cocTilingData.extraCommDataSplit, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_TRUE_FP16_EXTRACOMMDATASPLIT_DEFAULT, + g_allgatherTwoReducescatterEightTrueFP16ExtracommdatasplitMap}}}; + SetTilingParam2D(cocTilingData, tilingParamMap); + return; +} + +void AG2RS8FalseFP16Tiling(CoCTilingData &cocTilingData) +{ + std::map tilingParamMap = { + {&cocTilingData.swizzlCount, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_SWIZZLECOUNT_DEFAULT, + g_allgatherTwoReducescatterEightFalseFP16SwizzlecountMap}}, + {&cocTilingData.ubMoveNum, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_UBMOVENUM_DEFAULT, + g_allgatherTwoReducescatterEightFalseFP16UbmovenumMap}}, + {&cocTilingData.lenPerLoop, {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_LENPERLOOPMULT_DEFAULT}}, + {&cocTilingData.commNpuSplit, {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_COMMNPUSPLIT_DEFAULT}}, + {&cocTilingData.commDataSplit, {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_COMMDATASPLIT_DEFAULT}}, + {&cocTilingData.commDirect, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_COMMDIRECT_DEFAULT, + g_allgatherTwoReducescatterEightFalseFP16CommdirectMap}}, + {&cocTilingData.extraUbMoveNum, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_EXTRAUBMOVENUM_DEFAULT, + g_allgatherTwoReducescatterEightFalseFP16ExtraubmovenumMap}}, + {&cocTilingData.extraLenPerLoop, {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_EXTRALENPERLOOPMULT_DEFAULT}}, + {&cocTilingData.extraCommNpuSplit, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_EXTRACOMMNPUSPLIT_DEFAULT, + g_allgatherTwoReducescatterEightFalseFP16ExtracommnpusplitMap}}, + {&cocTilingData.extraCommDataSplit, + {ALLGATHER_TWO_REDUCESCATTER_EIGHT_FALSE_FP16_EXTRACOMMDATASPLIT_DEFAULT, + g_allgatherTwoReducescatterEightFalseFP16ExtracommdatasplitMap}}}; + 
SetTilingParam2D(cocTilingData, tilingParamMap); + return; +} + +void CoCAllgatherMatmulReduceScatterTilingFunc::GetDefaultTiling(const TaskParam &taskParam) +{ + CoCTilingFunc::GetDefaultTiling(taskParam); + + cocTilingData.swizzlDirect = SWIZZLEDIRECT_ONE; + + cocTilingData.m0 = M0_DEFAULT; + cocTilingData.k0 = K0_DEFAULT; + cocTilingData.n0 = N0_DEFAULT; + + cocTilingData.withSerialMode = 0; + cocTilingData.is91093 = 0; + cocTilingData.pValue = PVALUE_ONE; + cocTilingData.commDirect = 0; + + auto rsDim = taskParam.cocParamDesc.twoDimTPInfo.rsDim; + auto agDim = taskParam.cocParamDesc.twoDimTPInfo.agDim; + auto innerDimIsAg = taskParam.cocParamDesc.twoDimTPInfo.innerDimIsAg; + if (agDim == DIM_EIGHT && rsDim == DIM_TWO && !innerDimIsAg) { + AG8RS2FalseFP16Tiling(cocTilingData); + } else if (agDim == DIM_EIGHT && rsDim == DIM_TWO && innerDimIsAg) { + AG8RS2TrueFP16Tiling(cocTilingData); + } else if (agDim == DIM_TWO && rsDim == DIM_EIGHT && innerDimIsAg) { + AG2RS8TrueFP16Tiling(cocTilingData); + } else { + AG2RS8FalseFP16Tiling(cocTilingData); + } + cocTilingData.commNpuSplit = std::min(cocTilingData.commNpuSplit, agDim); + cocTilingData.extraCommNpuSplit = std::min(cocTilingData.extraCommNpuSplit, rsDim); +} + +bool CoCAllgatherMatmulReduceScatterTilingFunc::CheckTiling(const TaskParam &taskParam) +{ + if (!CoCTilingFunc::CheckTiling(taskParam)) { + return false; + } + + auto commNpuSplit = cocTilingData.commNpuSplit; + auto commDataSplit = cocTilingData.commDataSplit; + auto extraCommNpuSplit = cocTilingData.extraCommNpuSplit; + auto extraCommDataSplit = cocTilingData.extraCommDataSplit; + auto coreNum = cocTilingData.blockDim; + auto useCoreCount = commNpuSplit * commDataSplit + extraCommNpuSplit * extraCommDataSplit; + + const int maxMValue = 200000; + const int maxNValue = 32768; + const int maxKValue = 32768; + std::vector> paramCheckList = { + {"m", cocTilingData.m, PARAM_CHECK_MIN_VALUE_ONE, maxMValue}, + {"k", cocTilingData.k, PARAM_CHECK_MIN_VALUE_ONE, maxKValue}, + {"n", cocTilingData.n, PARAM_CHECK_MIN_VALUE_ONE, maxNValue}, + {"commNpuSplit * commDataSplit + extraCommNpuSplit * extraCommDataSplit", + useCoreCount, PARAM_CHECK_MIN_VALUE_ONE, coreNum}, + }; + return CheckParamScopeList(paramCheckList); +} +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/allgather_tiling.cpp b/comm/lcal/src/tiling/allgather_tiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d30ce0daea7e191eabc8ce2a91d6f36b29b5dcaf --- /dev/null +++ b/comm/lcal/src/tiling/allgather_tiling.cpp @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+
+#include <string>
+#include <vector>
+#include "tiling.h"
+#include "tiling_910B.h"
+#include "tiling_91093.h"
+#include "tiling_func.h"
+#include "lcoc_func.h"
+
+namespace Lcal {
+void CoCAllGatherMatmulTilingFunc::GetDefaultTiling(const TaskParam &taskParam)
+{
+    CoCTilingFunc::GetDefaultTiling(taskParam);
+    if (Is91093(taskParam.chipName)) {
+        if (cocTilingData.rankSize == RANKSIZE_EIGHT) {
+            AllGatherNPU91093EightRankFP16Tiling(cocTilingData);
+            return;
+        } else if (cocTilingData.rankSize == RANKSIZE_SIXTEEN) {
+            AllGatherNPU91093SixteenRankFP16Tiling(cocTilingData);
+            return;
+        } else if (cocTilingData.rankSize == RANKSIZE_TWO &&
+                   taskParam.cocParamDesc.mmInfo.isInt8) {
+            AllGatherNPU91093TwoRankINT8Tiling(cocTilingData);
+            return;
+        } else if (cocTilingData.rankSize == RANKSIZE_TWO) {
+            AllGatherNPU91093TwoRankFP16Tiling(cocTilingData);
+            return;
+        }
+    } else if (Is910B(taskParam.chipName)) {
+        if (cocTilingData.rankSize == RANKSIZE_EIGHT) {
+            AllGatherEightRankFP16GetDefaultTiling(cocTilingData);
+            return;
+        } else if (cocTilingData.rankSize == RANKSIZE_FOUR) {
+            AllGatherFourRankINT8Tiling(cocTilingData); // INT8
+            return;
+        }
+    }
+    // No chip/rank-size specific table matched: fall back to the generic tiling.
+    AllGatherGetDefaultTiling(cocTilingData);
+}
+
+void CoCAllGatherMatmulV2TilingFunc::GetDefaultTiling(const TaskParam &taskParam)
+{
+    CoCTilingFunc::GetDefaultTiling(taskParam);
+    auto coreNum = cocTilingData.blockDim;
+    if (Is91093(taskParam.chipName)) {
+        if (cocTilingData.rankSize == RANKSIZE_EIGHT) {
+            AllGatherV2NPU91093EightRankFP16Tiling(cocTilingData);
+            return;
+        } else if (cocTilingData.rankSize == RANKSIZE_SIXTEEN) {
+            AllGatherV2NPU91093SixteenRankFP16Tiling(cocTilingData);
+            return;
+        } else if (cocTilingData.rankSize == RANKSIZE_TWO) {
+            AllGatherV2NPU91093TwoRankFP16Tiling(cocTilingData);
+            return;
+        }
+    }
+    if (coreNum >= ALLGATHERV2_CORENUM_SIXTEEN + cocTilingData.rankSize) {
+        AllGatherV2EightRankFP16Core16GetDefaultTiling(cocTilingData);
+        return;
+    }
+    AllGatherV2EightRankFP16GetDefaultTiling(cocTilingData);
+}
+
+// The per-loop tile (pValue * m0 * k0 * kLoop elements) must fit in each rank's
+// share of the shared peer-memory buffer (bufferSize is in MB).
+bool CheckKValue(const TaskParam &taskParam, const CoCTilingData &data)
+{
+    auto blockCount = data.is91093 ? BLOCK_COUNT_3 : MAX_BLOCK_COUNT;
+    int32_t maxPeerMemPerRank = (taskParam.bufferSize * 1024 * 1024) / INPUT_DTYPE / data.rankSize / blockCount;
+    if (data.pValue * data.m0 * data.k0 * data.kLoop >= maxPeerMemPerRank) {
+        std::string str = "The k value is too large and is currently not supported. "
+                          "pValue: " + std::to_string(data.pValue) + ", m0: " + std::to_string(data.m0) +
+                          ", k0: " + std::to_string(data.k0) + ", kLoop: " + std::to_string(data.kLoop) +
+                          ", maxPeerMemPerRank: " + std::to_string(maxPeerMemPerRank);
+        PrintErrorLog(taskParam.lcalType, str);
+        return false;
+    }
+    return true;
+}
+
+bool CoCAllGatherMatmulTilingFunc::CheckTiling(const TaskParam &taskParam)
+{
+    if (!CoCTilingFunc::CheckTiling(taskParam)) {
+        return false;
+    }
+    if (!CheckKValue(taskParam, cocTilingData)) {
+        return false;
+    }
+
+    auto rankSize = cocTilingData.rankSize;
+    auto commNpuSplit = cocTilingData.commNpuSplit;
+    auto commDataSplit = cocTilingData.commDataSplit;
+    auto coreNum = cocTilingData.blockDim;
+    auto is91093 = cocTilingData.is91093;
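+    // Assumption read off the formula below: a 91093 (A3) package carries
+    // A3_DIE_NUM dies, so the communication split only has to cover
+    // rankSize / A3_DIE_NUM peers per die, while older chips must cover every
+    // rank. Example: rankSize = 16 with A3_DIE_NUM = 2 would lower the required
+    // commNpuSplit * commDataSplit floor from 16 to 8.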
+    auto minCoreCount = is91093 ? rankSize / A3_DIE_NUM : rankSize;
+    int32_t useCoreCount = commNpuSplit * commDataSplit;
+
+    std::vector> paramCheckList = {
+        {"commNpuSplit * commDataSplit", useCoreCount, minCoreCount, coreNum}
+    };
+    return CheckParamScopeList(paramCheckList);
+}
+
+bool CoCAllGatherMatmulV2TilingFunc::CheckTiling(const TaskParam &taskParam)
+{
+    if (!CoCTilingFunc::CheckTiling(taskParam)) {
+        return false;
+    }
+    if (!CheckKValue(taskParam, cocTilingData)) {
+        return false;
+    }
+
+    auto commNpuSplit = cocTilingData.commNpuSplit;
+    auto commDataSplit = cocTilingData.commDataSplit;
+    auto coreNum = cocTilingData.blockDim;
+    int32_t useCoreCount = commNpuSplit * commDataSplit;
+
+    // V2's upper bound is coreNum - 1, apparently leaving one core free.
+    std::vector> paramCheckList = {
+        {"commNpuSplit * commDataSplit", useCoreCount, PARAM_CHECK_MIN_VALUE_ONE, coreNum - 1}
+    };
+    return CheckParamScopeList(paramCheckList);
+}
+}
\ No newline at end of file
diff --git a/comm/lcal/src/tiling/allgather_tiling_91093.cpp b/comm/lcal/src/tiling/allgather_tiling_91093.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7d36783c09ec581146c8e62e0f1ed9629c5d52d8
--- /dev/null
+++ b/comm/lcal/src/tiling/allgather_tiling_91093.cpp
@@ -0,0 +1,474 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ +#include +#include "tiling_91093.h" +#include "tiling_func.h" +namespace Lcal { + constexpr int32_t ALLGATHER_91093_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t ALLGATHER_91093_EIGHT_RANK_FP16_PVALUE_DEFAULT = 12; + constexpr int32_t ALLGATHER_91093_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 8; + constexpr int32_t ALLGATHER_91093_EIGHT_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLGATHER_91093_SIXTEEN_RANK_FP16_UBMOVENUM_DEFAULT = 30; + constexpr int32_t ALLGATHER_91093_SIXTEEN_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t ALLGATHER_91093_SIXTEEN_RANK_FP16_COMMDIRECT_DEFAULT = 1; + constexpr int32_t ALLGATHER_91093_SIXTEEN_RANK_FP16_PVALUE_DEFAULT = 10; + constexpr int32_t ALLGATHER_91093_SIXTEEN_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLGATHER_91093_TWO_RANK_FP16_PVALUE_DEFAULT = 14; + constexpr int32_t ALLGATHER_91093_TWO_RANK_FP16_UBMOVENUM_DEFAULT = 20; + constexpr int32_t ALLGATHER_91093_TWO_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLGATHER_91093_TWO_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t ALLGATHER_91093_TWO_RANK_INT8_M0_DEFAULT = 128; + constexpr int32_t ALLGATHER_91093_TWO_RANK_INT8_PVALUE_DEFAULT = 14; + constexpr int32_t ALLGATHER_91093_TWO_RANK_INT8_UBMOVENUM_DEFAULT = 40; + constexpr int32_t ALLGATHER_91093_TWO_RANK_INT8_COMMDATASPLIT_DEFAULT = 16; + + static std::map>> g_allgather91093EightRankFP16M0Map = { + {128, + {{-1, 1262, -1, 2147483647, -1, 1720}, {-1, 1262, -1, 2147483647, 1824, 3248}, + {1262, 2147483647, -1, 2147483647, -1, 3248}, {-1, 2274, -1, 6700, 3248, 5660}, + {-1, 2274, 6700, 2147483647, 3248, 6172}, {2274, 2147483647, -1, 2147483647, 3248, 5360}, + {-1, 2147483647, -1, 2147483647, 6934, 8446}, {-1, 2147483647, 9950, 2147483647, 8446, 8958}, + {-1, 2147483647, -1, 2147483647, 8958, 2147483647}}}, + {256, + {{-1, 1262, -1, 2147483647, 1720, 1824}, {-1, 2274, -1, 6700, 5660, 6934}, + {-1, 2274, 6700, 2147483647, 6172, 6934}, {2274, 2147483647, -1, 2147483647, 5360, 6934}, + {-1, 2147483647, -1, 9950, 8446, 8958}}} + }; + + static std::map>> g_allgather91093EightRankFP16UbmovenumMap = { + {8.0, + {{-1, 1768, -1, 2147483647, -1, 1624}, {-1, 1768, 1774, 2147483647, 1624, 2274}, + {1768, 2147483647, -1, 4810, -1, 1200}, {1768, 2147483647, 4810, 2147483647, -1, 2274}}}, + {4.0, + {{-1, 1768, -1, 1774, 1624, 2274}}}, + {2.0, + {{1768, 2147483647, -1, 4810, 1200, 2274}, {-1, 768, 768, 2147483647, 2274, 4608}, + {-1, 768, -1, 2147483647, 4608, 2147483647}, {768, 2147483647, -1, 2147483647, 2274, 2147483647}}}, + {3.0, + {{-1, 768, -1, 768, 2274, 4608}}} + }; + + static std::map>> g_allgather91093EightRankFP16PvalueMap = { + {2, + {{-1, 2786, -1, 4608, -1, 1200}, {-1, 2786, -1, 8192, 1518, 1624}, + {768, 1262, -1, 2147483647, 1624, 1720}, {1262, 2786, 768, 2147483647, 1624, 2274}, + {2786, 2147483647, 4810, 9728, -1, 768}, {7168, 2147483647, 9728, 2147483647, -1, 768}, + {2786, 2147483647, 10720, 2147483647, 768, 1774}, {2786, 2147483647, 4810, 2147483647, 1774, 2274}, + {768, 1262, -1, 2147483647, 2274, 2912}, {1262, 2147483647, -1, 2147483647, 2274, 3248}, + {1262, 2147483647, -1, 2147483647, 3840, 4298}, {-1, 2147483647, -1, 2274, 4298, 5660}, + {1774, 2147483647, -1, 2274, 6684, 8958}}}, + {1, + {{-1, 2786, 4608, 2147483647, -1, 1200}, {-1, 2786, -1, 2147483647, 1200, 1518}, + {-1, 2786, 8192, 2147483647, 1518, 1624}, {-1, 768, -1, 2147483647, 1624, 2274}, + {768, 1262, -1, 2147483647, 1720, 2274}, {2786, 7168, 9728, 2147483647, -1, 768}, + {-1, 768, -1, 2147483647, 2274, 4298}, {768, 1262, -1, 
2147483647, 2912, 4298}, + {1262, 2147483647, -1, 2147483647, 3248, 3840}, {-1, 2147483647, -1, 2274, 5660, 6684}, + {-1, 1774, -1, 2274, 6684, 8958}, {-1, 2147483647, 2274, 9728, 4298, 8958}, + {-1, 2147483647, 9728, 2147483647, 4298, 6934}, {-1, 768, -1, 2147483647, 8958, 11744}, + {768, 1262, -1, 8704, 8958, 11744}, {1262, 2147483647, -1, 8446, 10720, 11744}}}, + {4, + {{1262, 2786, -1, 768, 1624, 2274}, {2786, 2147483647, 1262, 4810, 768, 1774}, + {2786, 4608, -1, 4810, 1774, 2274}, {2786, 2147483647, 4810, 10720, 768, 1774}, + {-1, 2147483647, 9728, 2147483647, 6934, 8958}, {768, 1262, 8704, 2147483647, 8958, 11744}, + {1262, 2147483647, 8446, 2147483647, 8958, 11744}, {-1, 2000, -1, 2147483647, 11744, 2147483647}, + {2000, 2147483647, 8200, 2147483647, 11744, 2147483647}}}, + {12, + {{2786, 2147483647, -1, 1262, -1, 768}}}, + {6, + {{2786, 2147483647, 1262, 4810, -1, 768}, {4608, 2147483647, 768, 4810, 1774, 2274}, + {1262, 2147483647, -1, 8446, 8958, 10720}, {2000, 2147483647, -1, 8200, 11744, 2147483647}}}, + {10, + {{2786, 2147483647, -1, 1262, 768, 1774}, {4608, 2147483647, -1, 768, 1774, 2274}}} + }; + + static std::map>> g_allgather91093EightRankFP16CommdatasplitMap = { + {8, + {{-1, 1768, -1, 2147483647, -1, 832}, {1262, 1768, -1, 768, 832, 1624}, + {-1, 1768, 768, 2147483647, 832, 1624}, {-1, 1768, 1774, 4608, 1880, 2274}, + {-1, 1768, 4608, 2147483647, 1624, 2274}, {1768, 8958, -1, 4810, -1, 1200}, + {8958, 2147483647, -1, 4810, -1, 1536}, {1768, 2147483647, 4810, 2147483647, -1, 2274}, + {8100, 8600, 9728, 2147483647, 4636, 2147483647}}}, + {16, + {{-1, 1262, -1, 768, 832, 1624}, {-1, 1768, -1, 1774, 1624, 2274}, + {-1, 1768, 1774, 4608, 1624, 1880}, {1768, 8958, -1, 4810, 1200, 2274}, + {8958, 2147483647, -1, 4810, 1536, 2274}, {-1, 8100, -1, 2147483647, 2274, 2147483647}, + {8100, 8600, -1, 9728, 2274, 2147483647}, {8100, 8600, 9728, 2147483647, 2274, 4636}, + {8600, 2147483647, -1, 2147483647, 2274, 2147483647}}} + }; + + static std::map>> g_allgather91093SixteenRankFP16M0Map = { + {128, + {{-1, 2274, -1, 2147483647, -1, 5552}, {2274, 2786, 8200, 2147483647, -1, 4000}, + {2274, 2786, 5100, 2147483647, 4000, 5552}, {2786, 2147483647, -1, 5360, -1, 5552}, + {2786, 2147483647, 5900, 2147483647, -1, 5552}, {-1, 2147483647, 5360, 2147483647, 5552, 6172}, + {-1, 8958, 5360, 2147483647, 6172, 6934}, {-1, 2147483647, -1, 2147483647, 6934, 2147483647}}}, + {256, + {{2274, 2786, -1, 8200, -1, 4000}, {2274, 2786, -1, 5100, 4000, 5552}, + {2786, 2147483647, 5360, 5900, -1, 5552}, {-1, 2147483647, -1, 5360, 5552, 6934}, + {8958, 2147483647, 5360, 2147483647, 6172, 6934}}} + }; + + static std::map>> g_allgather91093SixteenRankFP16PvalueMap = { + {10, + {{-1, 3798, -1, 1774, -1, 576}, {3798, 9728, -1, 1262, -1, 2274}, + {3798, 2147483647, 1262, 2274, -1, 768}}}, + {6, + {{-1, 3798, 1774, 4608, -1, 576}, {9728, 2147483647, -1, 1262, -1, 2274}}}, + {1, + {{-1, 3798, 4608, 7696, -1, 576}, {-1, 3798, -1, 2147483647, 576, 832}, + {-1, 2560, -1, 2147483647, 832, 1200}, {-1, 2786, 1774, 2147483647, 1200, 2274}, + {2786, 3798, 4298, 2147483647, 1200, 2274}, {-1, 3798, -1, 2147483647, 2274, 3248}, + {3798, 5900, -1, 2147483647, 2274, 2786}, {3798, 5900, 4608, 2147483647, 2786, 3248}, + {5900, 2147483647, 5360, 2147483647, 2274, 3248}, {-1, 2560, -1, 768, 3248, 8704}, + {-1, 7850, -1, 768, 8704, 2147483647}, {-1, 2147483647, 768, 2147483647, 3248, 11744}, + {-1, 1262, 768, 2147483647, 11744, 2147483647}, {2000, 2147483647, 6150, 2147483647, 11744, 2147483647}}}, + {2, + {{-1, 3798, 7696, 
2147483647, -1, 576}, {2560, 3798, -1, 2147483647, 832, 1200}, + {-1, 2286, -1, 768, 1200, 2274}, {-1, 3798, 768, 1774, 1200, 2274}, + {2786, 3798, 1774, 4298, 1200, 2274}, {3798, 6700, 4810, 5360, -1, 2274}, + {3798, 8100, 5360, 2147483647, -1, 2274}, {8100, 2147483647, 4810, 6450, 768, 2274}, + {8100, 2147483647, 6450, 2147483647, -1, 2274}, {3798, 5900, -1, 4608, 2786, 3248}, + {5900, 2147483647, 2274, 5360, 2274, 3248}, {2560, 2147483647, -1, 768, 3248, 8704}, + {7850, 2147483647, -1, 768, 8704, 2147483647}, {1262, 2000, 768, 2147483647, 11744, 2147483647}}}, + {4, + {{2286, 3798, -1, 768, 1200, 2274}, {3798, 2147483647, 1262, 2274, 768, 2274}, + {3798, 2147483647, 2274, 4810, -1, 2274}, {6700, 8100, 4810, 5360, -1, 2274}, + {8100, 2147483647, 4810, 6450, -1, 768}, {5900, 2147483647, -1, 2274, 2274, 3248}, + {2000, 2147483647, 768, 6150, 11744, 2147483647}}} + }; + + static std::map>> g_allgather91093SixteenRankFP16CommdirectMap = { + {0, + {{-1, 2147483647, -1, 2147483647, -1, 1200}, {768, 8958, -1, 2147483647, 1200, 1438}, + {-1, 8958, -1, 2147483647, 1438, 2147483647}, {8958, 2147483647, -1, 2147483647, 1200, 2147483647}}}, + {1, + {{-1, 768, -1, 2147483647, 1200, 1438}}} + }; + + static std::map>> g_allgather91093SixteenRankFP16CommdatasplitMap = { + {16, + {{-1, 1262, -1, 2147483647, -1, 1624}, {-1, 1262, -1, 2147483647, 1720, 2626}, + {1262, 2147483647, -1, 2147483647, -1, 2626}, {-1, 768, -1, 3798, 2626, 2147483647}, + {2274, 2147483647, -1, 3798, 2626, 3798}}}, + {1, + {{-1, 1262, -1, 2147483647, 1624, 1720}, {768, 2274, -1, 3798, 2626, 2147483647}, + {2274, 2147483647, -1, 3798, 3798, 2147483647}, {-1, 2147483647, 3798, 2147483647, 2626, 2147483647}}} + }; + + static std::map>> g_allgather91093SixteenRankFP16UbmovenumMap = { + {20.0, + {{-1, 3286, -1, 2147483647, -1, 832}, {-1, 3286, -1, 1262, 832, 2274}, + {-1, 3286, 1774, 2147483647, 832, 2274}, {-1, 3286, -1, 2147483647, 2274, 3248}, + {3286, 3798, -1, 2147483647, -1, 2000}, {3286, 3798, 6150, 2147483647, 2000, 3248}, + {3798, 2147483647, -1, 5360, -1, 2274}, {3798, 2147483647, 5360, 5900, -1, 2000}, + {3798, 2147483647, 5900, 2147483647, -1, 3248}, {-1, 1262, -1, 2147483647, 3542, 4298}, + {-1, 1518, 768, 2147483647, 4298, 6172}, {-1, 768, -1, 2147483647, 6172, 8704}, + {-1, 768, -1, 6656, 8704, 2147483647}}}, + {30.0, + {{-1, 3286, 1262, 1774, 832, 2274}}}, + {10.0, + {{3286, 3798, -1, 6150, 2000, 3248}, {3798, 2147483647, -1, 5360, 2274, 3248}, + {3798, 2147483647, 5360, 5900, 2000, 3248}, {-1, 3000, -1, 7200, 3248, 3542}, + {-1, 2147483647, 7200, 2147483647, 3248, 3542}, {1262, 2024, -1, 2147483647, 3542, 4298}, + {-1, 1006, -1, 768, 4298, 6172}, {-1, 768, 6656, 2147483647, 8704, 2147483647}, + {768, 1774, -1, 11264, 6172, 6684}}}, + {8.0, + {{3000, 2147483647, -1, 7200, 3248, 3542}, {2024, 2147483647, -1, 2147483647, 3542, 4298}, + {1006, 3584, -1, 768, 4298, 6172}, {768, 1774, -1, 11264, 6684, 2147483647}}}, + {6.0, + {{3584, 2147483647, -1, 768, 4298, 6172}, {1518, 2147483647, 768, 2147483647, 4298, 6172}, + {768, 1774, 11264, 2147483647, 6172, 2147483647}}}, + {4.0, + {{1774, 2274, -1, 2560, 6172, 2147483647}, {2274, 3286, -1, 2147483647, 6172, 8958}, + {3286, 2147483647, -1, 2147483647, 6172, 8446}}}, + {3.0, + {{1774, 2274, 2560, 2147483647, 6172, 2147483647}, {2274, 3286, -1, 2147483647, 8958, 2147483647}, + {3286, 2147483647, -1, 2147483647, 8446, 2147483647}}} + }; + + static std::map>> g_allgather91093TwoRankFP16CommdatasplitMap = { + {8, + {{-1, 1536, -1, 3584, -1, 1536}, {1536, 2560, -1, 8704, -1, 1536}, + 
{1536, 9728, 8704, 9728, -1, 1536}, {3584, 9728, 9728, 2147483647, -1, 1536}, + {9728, 2147483647, 768, 2560, -1, 1536}, {9728, 2147483647, 5120, 2147483647, -1, 1536}}}, + {16, + {{-1, 1536, 3584, 8704, -1, 1536}, {2560, 9728, -1, 8704, -1, 1536}, + {-1, 1536, 8704, 9728, -1, 1536}, {-1, 3584, 9728, 2147483647, -1, 1536}, + {9728, 2147483647, -1, 768, -1, 1536}, {9728, 2147483647, 2560, 5120, -1, 1536}, + {-1, 2147483647, -1, 2147483647, 1536, 2147483647}}} + }; + + static std::map>> g_allgather91093TwoRankFP16M0Map = { + {128, + {{-1, 4608, -1, 1280, -1, 1536}, {-1, 2560, 1280, 2147483647, -1, 1536}, + {2560, 4608, 5632, 2147483647, -1, 1536}, {4608, 5632, 7680, 2147483647, -1, 1536}, + {9728, 2147483647, 8192, 2147483647, -1, 1536}, {-1, 1536, -1, 3584, 1536, 2147483647}, + {1536, 2147483647, -1, 4608, 1536, 2147483647}, {-1, 2147483647, 4608, 2147483647, 1536, 7680}, + {3584, 2147483647, 4608, 2147483647, 7680, 2147483647}}}, + {256, + {{2560, 4608, 1280, 5632, -1, 1536}, {4608, 5632, -1, 7680, -1, 1536}, + {5632, 9728, -1, 2147483647, -1, 1536}, {9728, 2147483647, -1, 8192, -1, 1536}, + {-1, 1536, 3584, 4608, 1536, 2147483647}, {-1, 3584, 4608, 2147483647, 7680, 2147483647}}} + }; + + static std::map>> g_allgather91093TwoRankFP16UbmovenumMap = { + {10.0, + {{-1, 4608, -1, 1792, -1, 1536}}}, + {20.0, + {{-1, 4608, 1792, 2560, -1, 1536}}}, + {6.0, + {{-1, 4608, 2560, 2147483647, -1, 1536}, {4608, 2147483647, -1, 8704, -1, 1536}, + {5632, 2147483647, -1, 8704, 1536, 2560}, {4608, 2147483647, 8704, 2147483647, 1536, 3584}}}, + {4.0, + {{-1, 3584, -1, 6656, 1536, 2560}, {1536, 4608, 8704, 2147483647, 2560, 3584}, + {4608, 5632, -1, 8704, 1536, 2560}, {4608, 7680, 5632, 8704, 2560, 4608}, + {7680, 2147483647, -1, 8704, 2560, 4608}, {4608, 2147483647, 8704, 9728, 3584, 4608}}}, + {3.0, + {{3584, 4608, -1, 6656, 1536, 2560}, {-1, 4608, 6656, 2147483647, 1536, 2560}, + {-1, 1536, -1, 4608, 2560, 3584}, {-1, 1536, -1, 1536, 3584, 4608}, + {1536, 4608, 5632, 2147483647, 3584, 4608}, {4608, 7680, -1, 5632, 2560, 4608}, + {4608, 2147483647, 9728, 2147483647, 3584, 4608}, {-1, 1536, 9728, 2147483647, 19456, 2147483647}, + {1536, 3584, 6656, 2147483647, 4608, 5632}, {5632, 2147483647, 3584, 2147483647, 4608, 5632}}}, + {2.0, + {{-1, 1536, 4608, 2147483647, 2560, 3584}, {1536, 4608, -1, 8704, 2560, 3584}, + {-1, 1536, 1536, 2147483647, 3584, 4608}, {1536, 4608, -1, 5632, 3584, 4608}, + {-1, 1536, -1, 2147483647, 4608, 15360}, {-1, 1536, -1, 9728, 15360, 2147483647}, + {-1, 1536, 9728, 2147483647, 15360, 19456}, {1536, 3584, -1, 6656, 4608, 5632}, + {1536, 3584, -1, 2147483647, 5632, 2147483647}, {3584, 5632, -1, 2147483647, 4608, 2147483647}, + {5632, 2147483647, -1, 3584, 4608, 5632}, {5632, 2147483647, -1, 2147483647, 5632, 2147483647}}}, + {16.0, + {{4608, 2147483647, 8704, 2147483647, -1, 1536}}} + }; + + static std::map>> g_allgather91093TwoRankFP16PvalueMap = { + {10, + {{-1, 2560, -1, 5632, -1, 1536}, {1536, 2560, -1, 2147483647, 1536, 2560}, + {-1, 2560, -1, 7680, 2560, 3584}, {3584, 7680, -1, 3584, -1, 1536}, + {1536, 2560, -1, 2147483647, 11264, 13312}, {2560, 3584, 9728, 2147483647, 9728, 11264}, + {2560, 3584, 8704, 2147483647, 17408, 2147483647}}}, + {4, + {{-1, 2560, 5632, 2147483647, -1, 1536}, {1536, 2560, 8704, 2147483647, 3584, 4608}}}, + {6, + {{-1, 1536, -1, 2147483647, 1536, 2560}, {-1, 1536, -1, 2147483647, 3584, 4608}, + {-1, 1536, 6656, 2147483647, 4608, 9728}, {1536, 2560, -1, 2147483647, 9728, 11264}, + {2560, 3584, 1792, 8704, 9728, 15360}}}, + {12, + {{-1, 2560, 7680, 
2147483647, 2560, 3584}, {2560, 3584, -1, 3584, -1, 1536}, + {2560, 2147483647, 3584, 5632, 4608, 9728}, {3584, 2147483647, 6656, 2147483647, -1, 9728}, + {-1, 1536, 6656, 2147483647, 9728, 11264}, {2560, 3584, -1, 2560, 15360, 2147483647}, + {2560, 3584, 8704, 9728, 9728, 11264}, {3584, 8704, 3584, 2147483647, 9728, 11264}, + {3584, 7680, -1, 2147483647, 11264, 13312}, {3584, 4608, -1, 2147483647, 13312, 2147483647}, + {4608, 8704, -1, 8704, 13312, 2147483647}, {8704, 9728, -1, 1792, 9728, 2147483647}, + {8704, 9728, 2560, 2147483647, 9728, 2147483647}, {9728, 2147483647, 1280, 2147483647, 9728, 2147483647}}}, + {14, + {{1536, 2560, -1, 8704, 3584, 4608}, {-1, 1536, -1, 6656, 4608, 9728}, + {1536, 2560, -1, 5632, 4608, 9728}, {7680, 2147483647, -1, 3584, -1, 1536}, + {2560, 2147483647, -1, 3584, 1536, 9728}, {2560, 2147483647, 3584, 5632, -1, 4608}, + {2560, 3584, 5632, 8704, -1, 9728}, {-1, 1536, -1, 6656, 9728, 11264}, + {-1, 1536, -1, 2147483647, 11264, 2147483647}, {1536, 2560, -1, 2147483647, 13312, 2147483647}, + {2560, 3584, -1, 1792, 9728, 15360}, {2560, 3584, 2560, 8704, 15360, 2147483647}, + {3584, 8704, -1, 3584, 9728, 11264}, {7680, 8704, -1, 2147483647, 11264, 13312}, + {4608, 8704, 8704, 2147483647, 13312, 2147483647}, {8704, 9728, 1792, 2560, 9728, 2147483647}, + {9728, 2147483647, -1, 1280, 9728, 2147483647}}}, + {3, + {{1536, 2560, 5632, 2147483647, 4608, 9728}, {2560, 3584, 8704, 2147483647, -1, 9728}}}, + {8, + {{3584, 2147483647, 5632, 6656, -1, 9728}, {2560, 3584, 8704, 2147483647, 11264, 17408}}} + }; + + static std::map>> g_allgather91093TwoRankINT8CommdatasplitMap = { + {8, + {{-1, 1536, -1, 4608, -1, 1536}, {-1, 1536, -1, 3584, 1536, 15360}, + {-1, 1536, 3584, 4608, 1536, 6656}, {1536, 3584, 1280, 1792, -1, 7680}, + {-1, 1536, -1, 1280, 15360, 17408}, {-1, 1536, 8192, 9728, 15360, 17408}, + {-1, 1536, -1, 2048, 17408, 2147483647}, {-1, 1536, 4608, 5632, 17408, 2147483647}}}, + {16, + {{-1, 1536, 4608, 2147483647, -1, 1536}, {-1, 1536, 3584, 4608, 6656, 15360}, + {-1, 1536, 4608, 2147483647, 1536, 15360}, {1536, 2147483647, -1, 1280, -1, 15360}, + {3584, 2147483647, 1280, 1792, -1, 7680}, {1536, 2147483647, 1280, 1792, 7680, 15360}, + {1536, 2147483647, 1792, 2147483647, -1, 15360}, {-1, 1536, 1280, 8192, 15360, 17408}, + {-1, 1536, 9728, 2147483647, 15360, 17408}, {-1, 1536, 2048, 4608, 17408, 2147483647}, + {-1, 1536, 5632, 2147483647, 17408, 2147483647}, {1536, 2147483647, -1, 2147483647, 15360, 2147483647}}} + }; + + static std::map>> g_allgather91093TwoRankINT8UbmovenumMap = { + {30.0, + {{-1, 1536, -1, 4608, -1, 1536}}}, + {10.0, + {{-1, 1536, 4608, 2147483647, -1, 5632}, {6656, 2147483647, -1, 768, -1, 4608}, + {1536, 2147483647, 768, 2147483647, -1, 5632}, {-1, 8704, -1, 2147483647, 5632, 9728}, + {-1, 5632, -1, 2147483647, 9728, 11264}, {-1, 6656, 768, 2147483647, 11264, 13312}, + {9728, 2147483647, -1, 2147483647, 11264, 13312}}}, + {40.0, + {{-1, 1536, -1, 768, 1536, 5632}, {1536, 6656, -1, 768, -1, 1536}}}, + {20.0, + {{-1, 1536, 768, 4608, 1536, 5632}, {-1, 6656, -1, 768, 11264, 13312}}}, + {12.0, + {{1536, 6656, -1, 768, 1536, 5632}, {6656, 2147483647, -1, 768, 4608, 5632}, + {8704, 2147483647, -1, 2147483647, 5632, 9728}, {5632, 2147483647, -1, 2147483647, 9728, 11264}, + {-1, 6656, -1, 2147483647, 13312, 2147483647}, {6656, 9728, -1, 6656, 11264, 2147483647}}}, + {16.0, + {{6656, 9728, 6656, 2147483647, 11264, 2147483647}, {9728, 2147483647, -1, 2147483647, 13312, 2147483647}}} + }; + + static std::map>> g_allgather91093TwoRankINT8PvalueMap = { + 
{6, + {{-1, 1536, -1, 4608, -1, 1536}, {1536, 9728, -1, 768, 2560, 13312}, + {9728, 2147483647, -1, 1280, 6656, 13312}, {8704, 9728, 1792, 2560, 13312, 2147483647}, + {9728, 2147483647, 2560, 7680, 2560, 5632}, {9728, 2147483647, 9728, 2147483647, 4608, 11264}}}, + {4, + {{-1, 1536, 4608, 2147483647, -1, 1536}, {9728, 2147483647, 1280, 2560, 6656, 13312}, + {-1, 8704, 1792, 2560, 13312, 2147483647}}}, + {10, + {{1536, 2560, -1, 2147483647, -1, 1536}, {-1, 2560, -1, 1280, 1536, 2560}, + {6656, 7680, 3072, 2147483647, 1536, 2560}, {-1, 6656, 768, 2560, 2560, 13312}, + {-1, 9728, 2560, 2147483647, 15360, 17408}}}, + {14, + {{2560, 6656, -1, 2147483647, -1, 1536}, {2560, 4608, -1, 1280, 1536, 2560}, + {-1, 1536, 1280, 2147483647, 1536, 2560}, {4608, 6656, -1, 2147483647, 1536, 2560}, + {7680, 2147483647, -1, 4608, -1, 2560}, {7680, 2147483647, 4608, 8704, -1, 1536}, + {7680, 8704, 8704, 2147483647, -1, 2560}, {9728, 2147483647, 8704, 2147483647, -1, 2560}, + {-1, 1536, -1, 768, 2560, 13312}, {9728, 2147483647, -1, 2560, 2560, 6656}, + {-1, 2147483647, 768, 1792, 17408, 2147483647}, {-1, 8704, 4608, 2147483647, 2560, 15360}, + {-1, 9728, 2560, 2147483647, 17408, 19456}, {9728, 2147483647, 2560, 7680, 5632, 2147483647}, + {9728, 2147483647, 7680, 9728, 2560, 2147483647}, {9728, 2147483647, 9728, 2147483647, 11264, 2147483647}}}, + {12, + {{6656, 7680, -1, 2147483647, -1, 1536}, {1536, 4608, 1280, 2147483647, 1536, 2560}, + {6656, 7680, -1, 3072, 1536, 2560}, {7680, 9728, 4608, 8704, 1536, 2560}, + {-1, 9728, 2560, 4608, 2560, 15360}, {8704, 9728, 4608, 2147483647, 2560, 15360}, + {-1, 9728, 2560, 2147483647, 19456, 2147483647}}}, + {3, + {{9728, 2147483647, 4608, 8704, 1536, 2560}, {-1, 2147483647, -1, 768, 17408, 2147483647}, + {9728, 2147483647, 1792, 2560, 13312, 2147483647}}}, + {8, + {{8704, 9728, 8704, 2147483647, -1, 2560}, {6656, 9728, 768, 2560, 2560, 13312}, + {-1, 2560, -1, 1792, 13312, 17408}, {9728, 2147483647, 9728, 2147483647, 2560, 4608}}}, + {2, + {{2560, 2147483647, -1, 1792, 13312, 17408}}} + }; + + static std::map>> g_allgather91093TwoRankINT8M0Map = { + {128, + {{-1, 4608, -1, 2147483647, -1, 2560}, {9728, 2147483647, 8704, 2147483647, -1, 1536}, + {7680, 2147483647, 7680, 2147483647, 1536, 2560}, {-1, 2147483647, -1, 4608, 2560, 2147483647}, + {-1, 2147483647, 4608, 2147483647, 2560, 19456}, {5632, 2147483647, 4608, 2147483647, 19456, 2147483647}}}, + {256, + {{4608, 2147483647, -1, 8704, -1, 1536}, {4608, 9728, 8704, 2147483647, -1, 1536}, + {4608, 7680, -1, 2147483647, 1536, 2560}, {7680, 2147483647, -1, 7680, 1536, 2560}, + {-1, 5632, 4608, 2147483647, 19456, 2147483647}}} + }; + + void AllGatherNPU91093EightRankFP16Tiling(CoCTilingData &cocTilingData) + { + std::map TilingParamMap = { + {&cocTilingData.m0, + {ALLGATHER_91093_EIGHT_RANK_FP16_M0_DEFAULT, + g_allgather91093EightRankFP16M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLGATHER_91093_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT, + g_allgather91093EightRankFP16UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLGATHER_91093_EIGHT_RANK_FP16_PVALUE_DEFAULT, + g_allgather91093EightRankFP16PvalueMap}}, + {&cocTilingData.commDataSplit, + {ALLGATHER_91093_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_allgather91093EightRankFP16CommdatasplitMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_NPU_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}}, + }; + SetTilingParam(cocTilingData, TilingParamMap); + + cocTilingData.lenPerLoop 
= cocTilingData.ubMoveNum * cocTilingData.commDataSplit; + DealTilingParamByBuffSize(cocTilingData); + } + + void AllGatherNPU91093SixteenRankFP16Tiling(CoCTilingData &cocTilingData) + { + std::map TilingParamMap = { + {&cocTilingData.m0, + {ALLGATHER_91093_SIXTEEN_RANK_FP16_M0_DEFAULT, + g_allgather91093SixteenRankFP16M0Map}}, + {&cocTilingData.pValue, + {ALLGATHER_91093_SIXTEEN_RANK_FP16_PVALUE_DEFAULT, + g_allgather91093SixteenRankFP16PvalueMap}}, + {&cocTilingData.commDirect, + {ALLGATHER_91093_SIXTEEN_RANK_FP16_COMMDIRECT_DEFAULT, + g_allgather91093SixteenRankFP16CommdirectMap}}, + {&cocTilingData.commDataSplit, + {ALLGATHER_91093_SIXTEEN_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_allgather91093SixteenRankFP16CommdatasplitMap}}, + {&cocTilingData.ubMoveNum, + {ALLGATHER_91093_SIXTEEN_RANK_FP16_UBMOVENUM_DEFAULT, + g_allgather91093SixteenRankFP16UbmovenumMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {SWIZZLE_COUNT_FOUR}} + }; + SetTilingParam(cocTilingData, TilingParamMap); + + cocTilingData.commNpuSplit = + cocTilingData.commDataSplit == COMMDATASPLIT_ONE ? cocTilingData.rankSize : COMMNPUSPLIT_ONE; + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit; + DealTilingParamByBuffSize(cocTilingData); + } + + void AllGatherNPU91093TwoRankFP16Tiling(CoCTilingData &cocTilingData) + { + std::map TilingParamMap = { + {&cocTilingData.commDataSplit, + {ALLGATHER_91093_TWO_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_allgather91093TwoRankFP16CommdatasplitMap}}, + {&cocTilingData.m0, + {ALLGATHER_91093_TWO_RANK_FP16_M0_DEFAULT, + g_allgather91093TwoRankFP16M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLGATHER_91093_TWO_RANK_FP16_UBMOVENUM_DEFAULT, + g_allgather91093TwoRankFP16UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLGATHER_91093_TWO_RANK_FP16_PVALUE_DEFAULT, + g_allgather91093TwoRankFP16PvalueMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}} + }; + SetTilingParam(cocTilingData, TilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit; + DealTilingParamByBuffSize(cocTilingData); + } + + void AllGatherNPU91093TwoRankINT8Tiling(CoCTilingData &cocTilingData) + { + std::map TilingParamMap = { + {&cocTilingData.commDataSplit, + {ALLGATHER_91093_TWO_RANK_INT8_COMMDATASPLIT_DEFAULT, + g_allgather91093TwoRankINT8CommdatasplitMap}}, + {&cocTilingData.ubMoveNum, + {ALLGATHER_91093_TWO_RANK_INT8_UBMOVENUM_DEFAULT, + g_allgather91093TwoRankINT8UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLGATHER_91093_TWO_RANK_INT8_PVALUE_DEFAULT, + g_allgather91093TwoRankINT8PvalueMap}}, + {&cocTilingData.m0, + {ALLGATHER_91093_TWO_RANK_INT8_M0_DEFAULT, + g_allgather91093TwoRankINT8M0Map}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}}, + }; + SetTilingParam(cocTilingData, TilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit; + DealTilingParamByBuffSize(cocTilingData); + } +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/allgather_tiling_910B.cpp b/comm/lcal/src/tiling/allgather_tiling_910B.cpp new file mode 100644 index 0000000000000000000000000000000000000000..81b6ffec7d4cbb4dfbcc9464e107b807b0530cf0 --- /dev/null 
+++ b/comm/lcal/src/tiling/allgather_tiling_910B.cpp @@ -0,0 +1,368 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "tiling_910B.h" +#include "tiling_func.h" +#include "lcal_types.h" + +namespace Lcal { + constexpr int32_t ALLGATHER_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 10; + constexpr int32_t ALLGATHER_EIGHT_RANK_FP16_PVALUE_DEFAULT = 8; + constexpr int32_t ALLGATHER_EIGHT_RANK_FP16_COMMDIRECT_DEFAULT = 1; + constexpr int32_t ALLGATHER_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t ALLGATHER_EIGHT_RANK_FP16_M0_DEFAULT = 128; + + constexpr int32_t ALLGATHER_FOUR_RANK_INT8_UBMOVENUM_DEFAULT = 4; + constexpr int32_t ALLGATHER_FOUR_RANK_INT8_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t ALLGATHER_FOUR_RANK_INT8_PVALUE_DEFAULT = 14; + constexpr int32_t ALLGATHER_FOUR_RANK_INT8_M0_DEFAULT = 128; + + const int UBMOVE_MAX_M = 10096; + const int UBMOVE_MAX_K = 6656; + const int UBMOVE_MAX_N = 2336; + const int UBMOVE_MAX_M_HIGH = 14144; + const int UBMOVE_DEFAULT = 20; + const int UBMOVE_SMALL_M_SMALL_K_SMALL_N = 50; + const int UBMOVE_SMALL_M_LARGE_K_SMALL_N = 60; + const int UBMOVE_LARGE_M_SMALL_N = 80; + const int SCALE_FACTOR = 512; + const int PVALUE_M_SMALL = 6144; + const int PVALUE_K_SMALL = 10240; + const int PVALUE_N_MEDIUM = 9216; + const int PVALUE_M_MEDIUM = 14144; + const int PVALUE_M_LARGE = 10096; + const int PVALUE_ONE = 1; + const int PVALUE_TWO = 2; + const int PVALUE_THREE = 3; + const int PVALUE_FOUR = 4; + + static std::vector g_allgatherSwizzldirectCoef = { + { -2.462e-04, -7.154e-06, 6.700e-05, 1.416e-06, 1.747e-04, 1.513e-07, 2.296e-02, -3.022e-04, -6.992e-03, + -1.865e-03, 8.685e-03, -2.039e-03, -1.701e-02, 1.805e-03, 1.174e-03, -5.262e-03, 3.752e-05, -1.539e-05, + -2.508e-02, -9.660e-05, 2.489e-03, -7.638e-03, -1.360e-03, -3.614e-04, -1.150e-03 } + }; + + static std::map>> g_allgatherEightRankFP16M0Map = { + {128, + {{-1, 1262, -1, 2147483647, -1, 576}, {-1, 1262, 5660, 2147483647, 576, 1200}, + {-1, 1262, -1, 2147483647, 1200, 5552}, {1262, 8958, -1, 2274, 1888, 5552}, + {8958, 2147483647, -1, 2147483647, 1536, 5552}, {-1, 2147483647, -1, 2274, 6684, 2147483647}, + {9728, 2147483647, 2786, 3286, 5552, 2147483647}, {-1, 768, 3286, 7696, 5552, 2147483647}}}, + {256, + {{-1, 1262, -1, 5660, 576, 1200}, {1262, 8958, -1, 2147483647, -1, 1888}, + {1262, 8958, 2274, 2147483647, 1888, 5552}, {8958, 2147483647, -1, 2147483647, -1, 1536}, + {-1, 2147483647, -1, 2274, 5552, 6684}, {-1, 2147483647, 2274, 2786, 5552, 2147483647}, + {-1, 9728, 2786, 3286, 5552, 2147483647}, {-1, 768, 7696, 2147483647, 5552, 2147483647}, + {768, 2147483647, 3286, 2147483647, 5552, 2147483647}}} + }; + + static std::map>> g_allgatherEightRankFP16CommdatasplitMap = { + {8, + {{-1, 1262, -1, 768, -1, 576}, {-1, 2274, 768, 2147483647, -1, 576}, + {-1, 2274, -1, 2147483647, 576, 1624}, {-1, 1512, -1, 2147483647, 1624, 1720}, + {-1, 2274, -1, 2147483647, 1720, 2274}, {2274, 9728, -1, 2147483647, -1, 768}, + {9728, 
2147483647, 768, 2147483647, -1, 768}, {2786, 2147483647, -1, 4298, 768, 1774}, + {2274, 2147483647, 4298, 2147483647, 768, 1774}, {2274, 3584, -1, 1262, 1774, 2274}, + {3584, 2147483647, 1774, 2147483647, 1774, 2274}, {-1, 768, 6700, 2147483647, 2626, 3248}}}, + {16, + {{1262, 2274, -1, 768, -1, 576}, {1512, 2274, -1, 2147483647, 1624, 1720}, + {9728, 2147483647, -1, 768, -1, 768}, {2274, 2786, -1, 4298, 768, 1774}, + {2274, 3584, 2798, 2147483647, 1774, 2274}, {768, 2147483647, 6700, 2147483647, 2626, 2912}}}, + {2, + {{2274, 3584, 1262, 2798, 1774, 2274}, {3584, 2147483647, -1, 1774, 1774, 2274}, + {-1, 2147483647, -1, 2147483647, 2274, 2626}, {-1, 2147483647, -1, 6700, 2626, 3248}, + {768, 2147483647, 6700, 2147483647, 2912, 3248}, {-1, 2147483647, -1, 2147483647, 3248, 2147483647}}} + }; + + static std::map>> g_allgatherEightRankFP16CommdirectMap = { + {1, + {{-1, 768, -1, 2147483647, -1, 4608}, {-1, 768, 768, 2147483647, 4608, 5824}, + {768, 2147483647, -1, 2147483647, -1, 2912}, {1774, 2147483647, -1, 768, 2912, 3584}, + {768, 2147483647, -1, 768, 3584, 6172}, {768, 2147483647, 768, 2147483647, 2912, 6172}, + {-1, 6950, -1, 2147483647, 6172, 2147483647}, {6950, 7450, -1, 2560, 6172, 2147483647}, + {6950, 7450, 2560, 3584, 6172, 8704}, {6950, 7450, 3584, 2147483647, 6172, 2147483647}, + {7450, 2147483647, -1, 2147483647, 6172, 2147483647}}}, + {0, + {{-1, 768, -1, 768, 4608, 5824}, {-1, 768, -1, 2147483647, 5824, 6172}, + {768, 1774, -1, 768, 2912, 3584}, {6950, 7450, 2560, 3584, 8704, 2147483647}}} + }; + + static std::map>> g_allgatherEightRankFP16PvalueMap = { + {1, + {{-1, 768, -1, 2147483647, -1, 576}, {768, 1262, 7196, 2147483647, -1, 576}, + {1262, 4298, -1, 2147483647, -1, 576}, {-1, 4298, 768, 2147483647, 576, 1518}, + {-1, 4298, -1, 2147483647, 1518, 1984}, {-1, 2560, -1, 5660, 1984, 2274}, + {-1, 4298, 5660, 2147483647, 1984, 2274}, {4810, 2147483647, 7946, 10720, -1, 2274}, + {-1, 8958, -1, 2147483647, 2274, 2912}, {-1, 8958, 2560, 2147483647, 2912, 3248}, + {-1, 8958, -1, 1262, 3248, 5660}, {-1, 8958, 1262, 2147483647, 3248, 6450}, + {-1, 8958, -1, 2147483647, 6450, 2147483647}, {8958, 9728, -1, 2147483647, 2274, 5660}, + {8958, 9728, 1536, 2147483647, 5660, 6684}, {8958, 9728, -1, 2147483647, 6684, 2147483647}, + {9728, 2147483647, -1, 768, 6684, 2147483647}, {9728, 2147483647, 768, 3584, 4608, 2147483647}, + {9728, 2147483647, 3584, 2147483647, 2274, 2147483647}}}, + {6, + {{768, 1262, -1, 7196, -1, 576}, {-1, 4298, -1, 768, 576, 1518}, + {4810, 7850, -1, 1262, -1, 2274}, {4810, 7450, 1262, 3286, -1, 2274}}}, + {4, + {{2560, 4298, -1, 5660, 1984, 2274}, {7450, 2147483647, 1262, 3286, -1, 2274}, + {9728, 2147483647, -1, 768, 2274, 3584}}}, + {2, + {{4298, 4810, -1, 2147483647, -1, 2274}, {4810, 2147483647, 3286, 7946, -1, 2274}, + {4810, 2147483647, 10720, 2147483647, -1, 2274}, {-1, 8958, -1, 2560, 2912, 3248}, + {-1, 8958, -1, 1262, 5660, 6450}, {8958, 9728, -1, 1536, 5660, 6684}, + {9728, 2147483647, -1, 768, 3584, 6684}, {9728, 2147483647, 768, 3584, 2274, 4608}}}, + {8, + {{7850, 2147483647, -1, 1262, -1, 2274}}} + }; + + static std::map>> g_allgatherEightRankFP16UbmovenumMap = { + {3.0, + {{-1, 1262, -1, 2147483647, -1, 832}, {-1, 1262, 768, 2147483647, 832, 2400}, + {1262, 2147483647, 768, 2147483647, -1, 1624}, {1262, 2147483647, 1774, 2147483647, 1624, 2274}, + {1262, 2147483647, -1, 1262, 6684, 7434}, {7850, 2147483647, 1262, 1774, 5552, 7434}, + {1262, 2147483647, 1774, 2147483647, 5552, 7434}, {1262, 2147483647, -1, 768, 7434, 8704}, + {-1, 768, 1262, 3286, 
7434, 9728}, {768, 1262, 3286, 2147483647, 7434, 9728}}}, + {2.0, + {{-1, 1262, -1, 768, 832, 2400}, {-1, 1262, 7696, 2147483647, 6684, 7434}, + {1262, 2147483647, -1, 768, -1, 1624}, {1262, 2147483647, -1, 1262, 5552, 6684}, + {1262, 7850, 1262, 1774, 5552, 7434}, {-1, 1262, -1, 768, 7434, 8704}, + {-1, 2147483647, -1, 768, 8704, 9728}, {768, 2147483647, 768, 3286, 7434, 9728}, + {1262, 2147483647, 3286, 2147483647, 7434, 9728}, {-1, 7200, -1, 2147483647, 9728, 2147483647}, + {7200, 2147483647, -1, 2147483647, 9728, 11744}, {7200, 2147483647, -1, 11744, 11744, 2147483647}}}, + {8.0, + {{-1, 1262, -1, 2147483647, 2400, 6172}, {1262, 2147483647, -1, 2147483647, 2274, 3248}, + {-1, 768, 3286, 2147483647, 7434, 9728}}}, + {6.0, + {{-1, 768, -1, 2147483647, 6172, 6684}, {-1, 1262, -1, 7696, 6684, 7434}, + {1262, 2147483647, -1, 2147483647, 3248, 4298}, {1262, 2147483647, -1, 768, 4298, 5552}}}, + {4.0, + {{768, 1262, -1, 2147483647, 6172, 6684}, {1262, 2147483647, 768, 2147483647, 4298, 5552}, + {-1, 768, 768, 1262, 7434, 9728}, {7200, 2147483647, 11744, 2147483647, 11744, 2147483647}}}, + {10.0, + {{1262, 2147483647, -1, 1774, 1624, 2274}}} + }; + + static std::map>> g_allgatherFourRankINT8M0Map = { + {128, + {{-1, 2147483647, -1, 2147483647, -1, 3584}, {-1, 2147483647, -1, 4608, 3584, 8704}, + {-1, 2560, 4608, 2147483647, 3584, 8704}, {-1, 2147483647, -1, 2560, 8704, 2147483647}}}, + {256, + {{2560, 2147483647, 4608, 2147483647, 3584, 8704}, {-1, 2147483647, 2560, 2147483647, 8704, 2147483647}}} + }; + + static std::map>> g_allgatherFourRankINT8PvalueMap = { + {12, + {{-1, 5632, -1, 1792, -1, 1536}, {9728, 2147483647, -1, 8704, -1, 2560}}}, + {10, + {{-1, 5632, 1792, 3584, -1, 1536}, {2560, 5632, -1, 1280, 1536, 2560}, + {5632, 7680, -1, 2147483647, -1, 4608}, {7680, 9728, -1, 2147483647, -1, 2560}, + {7680, 9728, -1, 5632, 3584, 4608}, {9728, 2147483647, 8704, 2147483647, 1536, 2560}, + {9728, 2147483647, 1280, 2147483647, 3584, 4608}}}, + {6, + {{-1, 5632, 3584, 8704, -1, 1536}, {-1, 2560, -1, 1280, 1536, 2560}, + {3584, 5632, -1, 2147483647, 2560, 3584}, {1536, 5632, -1, 6656, 3584, 4608}, + {9728, 2147483647, 8704, 2147483647, -1, 1536}}}, + {3, + {{-1, 5632, 8704, 2147483647, -1, 1536}, {-1, 5632, 1280, 2147483647, 1536, 2560}, + {768, 2560, -1, 2147483647, 2560, 3584}, {1536, 5632, 6656, 2147483647, 3584, 4608}, + {7680, 9728, 5632, 2147483647, 3584, 4608}, {1536, 9728, -1, 1280, 4608, 8704}, + {9728, 2147483647, -1, 2560, 4608, 8704}, {1536, 2147483647, 2560, 6656, 4608, 5632}}}, + {1, + {{-1, 768, -1, 2147483647, 2560, 3584}, {-1, 1536, 4096, 2147483647, 3584, 4608}, + {-1, 768, -1, 2560, 4608, 8704}, {-1, 1536, 2560, 2147483647, 4608, 5632}, + {1536, 2147483647, 6656, 2147483647, 4608, 5632}, {-1, 2147483647, 2560, 2147483647, 5632, 8704}, + {-1, 768, -1, 2147483647, 8704, 2147483647}, {768, 2147483647, -1, 768, 9728, 11264}, + {768, 2147483647, 768, 2147483647, 8704, 11264}, {768, 1536, 4096, 2147483647, 11264, 2147483647}, + {1536, 2147483647, -1, 2147483647, 11264, 2147483647}}}, + {4, + {{2560, 3584, -1, 2147483647, 2560, 3584}}}, + {2, + {{-1, 1536, -1, 4096, 3584, 4608}, {768, 1536, -1, 2560, 4608, 8704}, + {1536, 9728, 1280, 2560, 4608, 8704}, {768, 2147483647, -1, 768, 8704, 9728}, + {768, 1536, -1, 4096, 11264, 2147483647}}}, + {14, + {{7680, 9728, -1, 2147483647, 2560, 3584}, {9728, 2147483647, -1, 1280, 2560, 4608}, + {9728, 2147483647, 1280, 2147483647, 2560, 3584}}} + }; + + static std::map>> g_allgatherFourRankINT8CommdatasplitMap = { + {16, + {{-1, 2147483647, -1, 
2147483647, -1, 2147483647}}}
+    };
+
+    static std::map<double, std::vector<std::vector<int64_t>>> g_allgatherFourRankINT8UbmovenumMap = {
+        {4.0,
+         {{-1, 2560, -1, 2147483647, -1, 3584}, {-1, 2560, 1792, 2147483647, 3584, 4608},
+          {2560, 2147483647, -1, 1792, -1, 2560}, {2560, 4608, -1, 1792, 2560, 4608},
+          {2560, 2147483647, 1792, 2147483647, -1, 3584}, {9728, 2147483647, 1792, 2147483647, 3584, 4608},
+          {-1, 768, -1, 4096, 4608, 5632}}},
+        {2.0,
+         {{-1, 2560, -1, 1792, 3584, 4608}, {-1, 768, 4096, 2147483647, 4608, 5632},
+          {-1, 768, -1, 2147483647, 5632, 2147483647}, {768, 2147483647, -1, 2147483647, 4608, 2147483647}}},
+        {3.0,
+         {{4608, 2147483647, -1, 1792, 2560, 4608}, {2560, 9728, 1792, 2147483647, 3584, 4608}}}
+    };
+
+    void AllGatherFourRankINT8Tiling(CoCTilingData &cocTilingData)
+    {
+        std::map tilingParamMap = {
+            {&cocTilingData.m0,
+             {ALLGATHER_FOUR_RANK_INT8_M0_DEFAULT,
+              g_allgatherFourRankINT8M0Map}},
+            {&cocTilingData.pValue,
+             {ALLGATHER_FOUR_RANK_INT8_PVALUE_DEFAULT,
+              g_allgatherFourRankINT8PvalueMap}},
+            {&cocTilingData.commDataSplit,
+             {ALLGATHER_FOUR_RANK_INT8_COMMDATASPLIT_DEFAULT,
+              g_allgatherFourRankINT8CommdatasplitMap}},
+            {&cocTilingData.ubMoveNum,
+             {ALLGATHER_FOUR_RANK_INT8_UBMOVENUM_DEFAULT,
+              g_allgatherFourRankINT8UbmovenumMap}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ZERO}},
+            {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}},
+            {&cocTilingData.commDirect, {COMM_NPU_DIRECT}},
+            {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}},
+        };
+        SetTilingParam(cocTilingData, tilingParamMap);
+
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit;
+        DealTilingParamByBuffSize(cocTilingData);
+    }
+
+    void AllGatherEightRankFP16GetDefaultTiling(CoCTilingData &cocTilingData)
+    {
+        std::map tilingParamMap = {
+            {&cocTilingData.m0,
+             {ALLGATHER_EIGHT_RANK_FP16_M0_DEFAULT,
+              g_allgatherEightRankFP16M0Map}},
+            {&cocTilingData.commDataSplit,
+             {ALLGATHER_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT,
+              g_allgatherEightRankFP16CommdatasplitMap}},
+            {&cocTilingData.commDirect,
+             {ALLGATHER_EIGHT_RANK_FP16_COMMDIRECT_DEFAULT,
+              g_allgatherEightRankFP16CommdirectMap}},
+            {&cocTilingData.pValue,
+             {ALLGATHER_EIGHT_RANK_FP16_PVALUE_DEFAULT,
+              g_allgatherEightRankFP16PvalueMap}},
+            {&cocTilingData.ubMoveNum,
+             {ALLGATHER_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT,
+              g_allgatherEightRankFP16UbmovenumMap}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ZERO}},
+            {&cocTilingData.swizzlCount, {SWIZZLE_COUNT_FOUR}}
+        };
+        SetTilingParam(cocTilingData, tilingParamMap);
+
+        cocTilingData.commNpuSplit =
+            cocTilingData.commDataSplit >= COMMDATASPLIT_EIGHT ?
+                COMMNPUSPLIT_ONE : cocTilingData.rankSize;
+        cocTilingData.commDataSplit = ClampValue(cocTilingData.commDataSplit, COMMDATASPLIT_ONE,
+                                                 cocTilingData.blockDim / cocTilingData.commNpuSplit);
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit;
+
+        DealTilingParamByBuffSize(cocTilingData);
+    }
+
+    int AllGatherUbMoveNum(int m, int k, int n)
+    {
+        if (m <= UBMOVE_MAX_M) {
+            if (k <= UBMOVE_MAX_K) {
+                if (n <= UBMOVE_MAX_N) {
+                    return UBMOVE_SMALL_M_SMALL_K_SMALL_N * SCALE_FACTOR;
+                } else {
+                    return UBMOVE_DEFAULT * SCALE_FACTOR;
+                }
+            } else {
+                return UBMOVE_SMALL_M_LARGE_K_SMALL_N * SCALE_FACTOR;
+            }
+        } else {
+            if (n <= UBMOVE_MAX_N) {
+                if (m <= UBMOVE_MAX_M_HIGH) {
+                    return UBMOVE_LARGE_M_SMALL_N * SCALE_FACTOR;
+                } else {
+                    return UBMOVE_SMALL_M_LARGE_K_SMALL_N * SCALE_FACTOR;
+                }
+            } else {
+                return UBMOVE_DEFAULT * SCALE_FACTOR;
+            }
+        }
+    }
+
+    int AllGatherPValue(int m, int k, int n)
+    {
+        if (m <= PVALUE_M_SMALL) {
+            if (k <= PVALUE_K_SMALL) {
+                return PVALUE_ONE;
+            } else {
+                if (n <= PVALUE_N_MEDIUM) {
+                    return PVALUE_ONE;
+                } else {
+                    return PVALUE_TWO;
+                }
+            }
+        } else {
+            if (n <= PVALUE_N_MEDIUM) {
+                if (m <= PVALUE_M_MEDIUM) {
+                    return PVALUE_ONE;
+                } else {
+                    return PVALUE_THREE;
+                }
+            } else {
+                if (m <= PVALUE_M_LARGE) {
+                    return PVALUE_THREE;
+                } else {
+                    return PVALUE_FOUR;
+                }
+            }
+        }
+    }
+
+    void AllGatherGetDefaultTiling(CoCTilingData &cocTilingData)
+    {
+        int32_t m = cocTilingData.m;
+        int32_t k = cocTilingData.k;
+        int32_t n = cocTilingData.n;
+        double mknGB = (1.0 * m / ONE_K) * (1.0 * k / ONE_K) * (1.0 * n / ONE_K);
+        double mkGB = (1.0 * m / ONE_K) * (1.0 * k / ONE_K);
+        double mnGB = (1.0 * m / ONE_K) * (1.0 * n / ONE_K);
+        double knGB = (1.0 * k / ONE_K) * (1.0 * n / ONE_K);
+        double c0 = sqrt(1.0 * m / k);
+        double c1 = 1.0 * m * k / n;
+        double c2 = sqrt(c1);
+        double c3 = sqrt(1.0 * m * k) / n;
+        double c4 = sqrt(1.0 * k / n);
+        double swizzlDirectDouble = 0;
+        std::vector<double> feats = { 1.0 * m, 1.0 / m, 1.0 * k, 1.0 / k, 1.0 * n, 1.0 / n, mknGB,
+                                      1.0 / mknGB, mkGB, 1.0 / mkGB, mnGB, 1.0 / mnGB, knGB, 1.0 / knGB,
+                                      c0, 1.0 / c0, c1, 1.0 / c1, c2, 1.0 / c2, c3,
+                                      1.0 / c3, c4, 1.0 / c4, 1 };
+        for (uint32_t i = 0; i < feats.size(); i++) {
+            swizzlDirectDouble += feats[i] * g_allgatherSwizzldirectCoef[i];
+        }
+        swizzlDirectDouble = 1.0 / (1.0 + exp(-swizzlDirectDouble));
+        if (swizzlDirectDouble >= HALF_PROB) {
+            cocTilingData.swizzlDirect = 1;
+        } else {
+            cocTilingData.swizzlDirect = 0;
+        }
+
+        cocTilingData.pValue = AllGatherPValue(m, k, n);
+        cocTilingData.ubMoveNum = AllGatherUbMoveNum(m, k, n);
+        cocTilingData.m0 = DEFAULT_ROW;
+        cocTilingData.n0 = DEFAULT_COL;
+        cocTilingData.k0 = DEFAULT_COL;
+        cocTilingData.kLoop = CeilDev(k, cocTilingData.k0);
+
+        cocTilingData.write2OtherRank = 1;
+        cocTilingData.commDirect = COMM_DATA_DIRECT;
+        cocTilingData.commNpuSplit = cocTilingData.rankSize;
+        cocTilingData.commDataSplit = COMMDATASPLIT_ONE;
+        DealTilingParamByBuffSize(cocTilingData);
+        cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.k0 * cocTilingData.kLoop * cocTilingData.pValue;
+    }
+}
\ No newline at end of file
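Reviewer note on the g_*Map tables in these tiling files: they all share one layout. The key is the candidate parameter value (an int, or a double for the ubMoveNum tables), and each row is a box {mMin, mMax, kMin, kMax, nMin, nMax} over the problem shape, with -1 standing for "no lower bound" and 2147483647 (INT32_MAX) for "no upper bound". The real selection happens inside SetTilingParam, which is not part of this diff; the helper below is only a minimal reading aid, and the (min, max] interval convention and the helper name are assumptions.

    #include <cstdint>
    #include <map>
    #include <vector>

    // Sketch: return the key of the first row whose box contains (m, k, n);
    // fall back to the per-parameter default when no row matches.
    int64_t LookupTilingValue(const std::map<int, std::vector<std::vector<int64_t>>> &table,
                              int64_t m, int64_t k, int64_t n, int64_t defaultValue)
    {
        for (const auto &entry : table) {
            for (const auto &r : entry.second) {
                if (m > r[0] && m <= r[1] && k > r[2] && k <= r[3] && n > r[4] && n <= r[5]) {
                    return entry.first;
                }
            }
        }
        return defaultValue;
    }

Since the boxes belonging to one parameter appear to be disjoint, iteration order across keys should not change the result.

diff --git a/comm/lcal/src/tiling/allgatherv2_tiling_91093.cpp b/comm/lcal/src/tiling/allgatherv2_tiling_91093.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6bb507892ae23e15337d5c4fc665d82696fb422f
--- /dev/null
+++ b/comm/lcal/src/tiling/allgatherv2_tiling_91093.cpp
@@ -0,0 +1,357 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.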
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "tiling_91093.h" +#include "tiling_func.h" + +namespace Lcal { + constexpr int32_t ALLGATHERV2_91093_EIGHT_RANK_FP16_PVALUE_DEFAULT = 6; + constexpr int32_t ALLGATHERV2_91093_EIGHT_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLGATHERV2_91093_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 190; + constexpr int32_t ALLGATHERV2_91093_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t ALLGATHERV2_91093_SIXTEEN_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLGATHERV2_91093_SIXTEEN_RANK_FP16_UBMOVENUM_DEFAULT = 160; + constexpr int32_t ALLGATHERV2_91093_SIXTEEN_RANK_FP16_COMMNPUSPLIT_DEFAULT = 16; + constexpr int32_t ALLGATHERV2_91093_SIXTEEN_RANK_FP16_PVALUE_DEFAULT = 12; + constexpr int32_t ALLGATHERV2_91093_TWO_RANK_FP16_UBMOVENUM_DEFAULT = 12; + constexpr int32_t ALLGATHERV2_91093_TWO_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLGATHERV2_91093_TWO_RANK_FP16_PVALUE_DEFAULT = 14; + + static std::map>> g_allgatherV291093EightRankFP16CommdatasplitMap = { + {8, + {{-1, 1518, -1, 2147483647, -1, 1624}, {-1, 1518, -1, 4608, 1880, 2274}, + {-1, 1518, 4608, 2147483647, 1624, 2274}, {1518, 8600, -1, 2147483647, -1, 1200}, + {8600, 8958, 10494, 2147483647, -1, 2274}, {8958, 2147483647, -1, 2147483647, -1, 2274}, + {-1, 2147483647, -1, 768, 2912, 5660}}}, + {16, + {{-1, 1518, -1, 4608, 1624, 1880}, {1518, 8600, -1, 2147483647, 1200, 2274}, + {8600, 8958, -1, 10494, -1, 2274}, {-1, 2147483647, -1, 2147483647, 2274, 2912}, + {-1, 2147483647, -1, 768, 5660, 6450}, {-1, 2147483647, 768, 2147483647, 2912, 6450}, + {-1, 2147483647, -1, 2147483647, 6450, 2147483647}}} + }; + + static std::map>> g_allgatherV291093EightRankFP16UbmovenumMap = { + {160, + {{-1, 6950, -1, 2560, -1, 576}, {-1, 6950, -1, 2560, 832, 1200}, + {1262, 6950, -1, 1774, 1200, 2274}, {6950, 2147483647, -1, 2274, -1, 1536}, + {8858, 9728, -1, 2274, 1536, 2274}, {-1, 2147483647, -1, 768, 2912, 4608}, + {-1, 8720, -1, 768, 4608, 5312}}}, + {8, + {{-1, 6950, 2560, 2147483647, -1, 576}, {-1, 6950, -1, 2147483647, 576, 832}, + {-1, 6950, 2560, 2147483647, 832, 1200}, {-1, 1262, -1, 2147483647, 1200, 2274}, + {1262, 6950, 3798, 2147483647, 1200, 2274}, {6950, 2147483647, 2274, 2147483647, -1, 768}, + {6950, 9728, 2274, 2147483647, 768, 1262}, {6950, 2147483647, 4810, 2147483647, 1262, 2274}}}, + {2, + {{1262, 6950, 1774, 3798, 1200, 2274}, {6950, 8858, -1, 2274, 1536, 2274}, + {9728, 2147483647, 1280, 2274, 1536, 2274}, {6950, 2147483647, 2274, 4810, 1262, 1774}, + {-1, 2147483647, -1, 2147483647, 2274, 2912}, {8720, 2147483647, -1, 768, 4608, 5312}, + {-1, 2147483647, 768, 2147483647, 2912, 5312}, {-1, 2147483647, -1, 2147483647, 5312, 8446}, + {-1, 9728, -1, 2147483647, 8446, 10720}, {9728, 2147483647, -1, 9728, 8446, 10720}, + {-1, 2147483647, -1, 2147483647, 10720, 2147483647}}}, + {190, + {{9728, 2147483647, -1, 1280, 1536, 2274}}}, + {20, + {{9728, 2147483647, 2274, 2147483647, 768, 1262}}}, + {3, + {{6950, 2147483647, 2274, 4810, 1774, 2274}, {9728, 2147483647, 
9728, 2147483647, 8446, 10720}}} + }; + + static std::map>> g_allgatherV291093EightRankFP16M0Map = { + {128, + {{-1, 4810, -1, 2147483647, -1, 832}, {-1, 4810, -1, 768, 1536, 2274}, + {-1, 4810, 768, 2147483647, 832, 2274}, {4810, 2147483647, -1, 2147483647, -1, 2274}, + {-1, 2147483647, -1, 2274, 2274, 5660}, {-1, 2147483647, 2274, 6950, 2274, 5360}, + {-1, 2147483647, 6950, 7450, 2274, 5660}, {-1, 2147483647, 7450, 2147483647, 2274, 6934}, + {-1, 2147483647, -1, 2274, 6934, 2147483647}, {-1, 2147483647, 2274, 2786, 9470, 2147483647}, + {-1, 2147483647, 2786, 2147483647, 6934, 2147483647}}}, + {256, + {{-1, 4810, -1, 768, 832, 1536}, {-1, 2147483647, -1, 2274, 5660, 6934}, + {-1, 2147483647, 2274, 6950, 5360, 6934}, {-1, 2147483647, 6950, 7450, 5660, 6934}, + {-1, 2147483647, 2274, 2786, 6934, 9470}}} + }; + + static std::map>> g_allgatherV291093EightRankFP16PvalueMap = { + {1, + {{-1, 1774, -1, 2147483647, -1, 576}, {-1, 1262, -1, 2147483647, 576, 2274}, + {-1, 2147483647, -1, 2274, 5660, 6934}, {-1, 2147483647, 2274, 2147483647, 2274, 6934}, + {-1, 2147483647, 1774, 5900, 6934, 2147483647}}}, + {4, + {{1774, 2786, -1, 2147483647, -1, 576}, {2786, 5660, -1, 2274, -1, 2274}, + {2786, 2147483647, 2274, 2147483647, 768, 2274}, {-1, 2147483647, 5900, 2147483647, 6934, 2147483647}}}, + {2, + {{1262, 2786, -1, 2147483647, 576, 2274}, {2786, 2147483647, 2274, 2147483647, -1, 768}, + {-1, 2147483647, -1, 2274, 2274, 5660}, {-1, 2147483647, -1, 1774, 6934, 2147483647}}}, + {6, + {{5660, 2147483647, -1, 2274, -1, 2274}}} + }; + + static std::map>> g_allgatherV291093SixteenRankFP16PvalueMap = { + {4, + {{-1, 2786, -1, 768, -1, 768}, {-1, 2786, -1, 768, 1984, 2274}, + {4608, 6700, -1, 768, -1, 2274}, {3798, 4298, -1, 1006, 2274, 5312}, + {4298, 2147483647, -1, 1262, 2274, 4608}}}, + {2, + {{-1, 2786, -1, 768, 768, 1200}, {1262, 2786, 768, 1262, 1984, 2274}, + {2786, 9728, 768, 3286, -1, 768}, {2786, 2147483647, 768, 3286, 768, 2274}, + {2786, 5900, 3286, 5660, 1774, 2274}, {5900, 2147483647, 3286, 3798, -1, 2274}, + {768, 3798, -1, 1262, 2274, 3248}, {1774, 3798, -1, 768, 3248, 5312}, + {4298, 2147483647, -1, 1262, 4608, 5312}}}, + {1, + {{-1, 2786, 768, 2147483647, -1, 1200}, {-1, 2786, -1, 2147483647, 1200, 1984}, + {-1, 1262, 768, 1262, 1984, 2274}, {-1, 2786, 1262, 2147483647, 1984, 2274}, + {2786, 5900, 3286, 2147483647, -1, 1774}, {2786, 5900, 5660, 2147483647, 1774, 2274}, + {5900, 2147483647, 3798, 2147483647, -1, 2274}, {-1, 768, -1, 1262, 2274, 3248}, + {-1, 3798, 1262, 2147483647, 2274, 3248}, {-1, 1774, -1, 768, 3248, 5312}, + {-1, 3798, 768, 2147483647, 3248, 5312}, {3798, 4298, 1006, 2147483647, 2274, 5312}, + {4298, 2147483647, 1262, 2147483647, 2274, 5312}, {-1, 2147483647, -1, 2147483647, 5312, 2147483647}}}, + {6, + {{2786, 4608, -1, 768, -1, 2274}}}, + {10, + {{6700, 2147483647, -1, 768, -1, 768}}}, + {8, + {{6700, 2147483647, -1, 768, 768, 2274}}}, + {12, + {{9728, 2147483647, 768, 3286, -1, 768}}} + }; + + static std::map>> g_allgatherV291093SixteenRankFP16CommnpusplitMap = { + {8, + {{-1, 2274, -1, 1262, -1, 576}, {-1, 2274, -1, 768, 576, 1200}, + {2024, 2274, 6160, 7696, 1624, 2274}, {2274, 2147483647, -1, 1774, -1, 768}, + {2274, 2147483647, -1, 768, 768, 1262}, {2274, 2147483647, -1, 768, 1774, 2274}, + {2274, 3584, 5360, 2147483647, -1, 768}, {-1, 768, 5660, 7696, 2274, 2147483647}}}, + {1, + {{-1, 2274, 1262, 2147483647, -1, 576}, {-1, 2274, 768, 2147483647, 576, 1200}, + {-1, 2274, -1, 2147483647, 1200, 1624}, {-1, 1518, -1, 2147483647, 1824, 2274}, + {2024, 2274, -1, 
6160, 1624, 2274}, {2024, 2274, 7696, 2147483647, 1624, 2274}, + {2274, 2147483647, 1774, 5360, -1, 768}, {2274, 2147483647, 768, 5360, 768, 1262}, + {2274, 2147483647, 768, 5360, 1774, 2274}, {6950, 2147483647, 7696, 2147483647, 1774, 2274}, + {-1, 768, -1, 768, 2274, 4608}, {768, 6450, -1, 768, 2274, 5312}, + {6450, 2147483647, -1, 768, 2274, 5660}}}, + {16, + {{-1, 1518, -1, 2147483647, 1624, 1824}, {1518, 2024, -1, 2147483647, 1624, 2274}, + {2274, 2147483647, -1, 5360, 1262, 1774}, {3584, 6950, 5360, 2147483647, -1, 768}, + {2274, 6950, 5360, 2147483647, 768, 2274}, {6950, 2147483647, 5360, 2147483647, -1, 1774}, + {6950, 2147483647, 5360, 7696, 1774, 2274}, {-1, 768, -1, 768, 4608, 2147483647}, + {-1, 768, 768, 5660, 2274, 2147483647}, {-1, 768, 7696, 2147483647, 2274, 2147483647}, + {768, 6450, 768, 2147483647, 2274, 5312}, {768, 6450, -1, 2147483647, 5312, 2147483647}, + {6450, 2147483647, -1, 768, 5660, 2147483647}, {6450, 2147483647, 768, 2147483647, 2274, 2147483647}}} + }; + + static std::map>> g_allgatherV291093SixteenRankFP16UbmovenumMap = { + {160, + {{-1, 2274, -1, 768, -1, 1456}, {2274, 2147483647, -1, 768, -1, 2274}}}, + {16, + {{-1, 2274, 768, 2147483647, -1, 1456}, {-1, 2274, -1, 8704, 1456, 1824}, + {-1, 2274, -1, 2147483647, 1824, 2400}, {1262, 2274, 768, 2147483647, 2400, 3504}, + {2274, 2147483647, 768, 5900, -1, 2274}, {2274, 2147483647, 5900, 2147483647, 1774, 2274}, + {2274, 2147483647, 11744, 2147483647, 2274, 2786}, {2274, 2147483647, -1, 7696, 2786, 3286}}}, + {8, + {{-1, 2274, 8704, 2147483647, 1456, 1824}, {1262, 2274, 768, 2147483647, 3504, 5552}, + {2274, 2147483647, 768, 2147483647, 3286, 4298}, {768, 1262, 2560, 2147483647, 5552, 6684}, + {1262, 1518, -1, 2560, 5552, 6684}}}, + {18, + {{-1, 1262, -1, 2147483647, 2400, 2912}, {2274, 2147483647, 5900, 2147483647, -1, 1774}}}, + {14, + {{-1, 768, -1, 2147483647, 2912, 5552}}}, + {10, + {{768, 1262, -1, 2147483647, 2912, 5552}, {-1, 768, -1, 2560, 5552, 6684}, + {-1, 768, -1, 1262, 6684, 8704}, {-1, 768, 6656, 2147483647, 8704, 2147483647}, + {768, 1262, -1, 2560, 5552, 6684}, {1262, 1518, 2560, 2147483647, 5552, 6684}}}, + {2, + {{1262, 2274, -1, 768, 2400, 5552}, {2274, 2147483647, -1, 768, 3286, 5552}}}, + {12, + {{2274, 2147483647, -1, 11744, 2274, 2786}, {2274, 2147483647, 7696, 2147483647, 2786, 3286}, + {-1, 768, 2560, 2147483647, 5552, 6172}, {-1, 768, 1262, 2147483647, 6684, 8704}, + {-1, 768, -1, 6656, 8704, 2147483647}}}, + {6, + {{2274, 2786, 768, 2147483647, 4298, 5552}, {768, 1518, -1, 1774, 6684, 2147483647}, + {768, 1518, 5660, 11264, 6684, 2147483647}}}, + {5, + {{2786, 2147483647, 768, 2147483647, 4298, 5552}, {1518, 3286, -1, 2147483647, 5552, 5872}}}, + {80, + {{-1, 768, 2560, 2147483647, 6172, 6684}}}, + {4, + {{768, 1518, 1774, 5660, 6684, 2147483647}, {1518, 3286, -1, 2147483647, 5872, 8958}, + {2786, 3286, 6160, 2147483647, 8958, 2147483647}, {3286, 8958, 1774, 2274, 5552, 2147483647}, + {3286, 8958, 2274, 2147483647, 5552, 7434}, {8958, 9728, 2560, 2147483647, 5552, 2147483647}, + {9728, 2147483647, 9728, 2147483647, 5552, 2147483647}}}, + {3, + {{768, 1518, 11264, 2147483647, 6684, 2147483647}, {1518, 2786, -1, 2147483647, 8958, 2147483647}, + {2786, 3286, -1, 6160, 8958, 2147483647}, {3286, 8958, -1, 1774, 5552, 2147483647}, + {3286, 8958, 2274, 2147483647, 7434, 2147483647}, {8958, 9728, -1, 2560, 5552, 2147483647}, + {9728, 2147483647, -1, 9728, 5552, 2147483647}}} + }; + + static std::map>> g_allgatherV291093SixteenRankFP16M0Map = { + {256, + {{-1, 2024, -1, 7696, -1, 576}, {-1, 
2024, -1, 4608, 576, 1200}, + {768, 2024, -1, 2147483647, 1200, 1518}, {2024, 9728, 768, 2147483647, -1, 1518}, + {9728, 2147483647, 2560, 2147483647, -1, 1518}, {-1, 2147483647, -1, 2274, 5660, 6684}, + {2274, 8958, 2274, 2147483647, 1518, 9728}, {2274, 2147483647, 2274, 11744, 9728, 2147483647}}}, + {128, + {{-1, 2024, 7696, 2147483647, -1, 576}, {-1, 2024, 4608, 2147483647, 576, 1200}, + {-1, 768, -1, 2147483647, 1200, 1518}, {2024, 2147483647, -1, 768, -1, 1518}, + {9728, 2147483647, 768, 2560, -1, 1518}, {-1, 2147483647, -1, 2274, 1518, 5660}, + {-1, 2147483647, -1, 2274, 6684, 2147483647}, {-1, 2274, 2274, 2147483647, 1518, 2147483647}, + {8958, 2147483647, 2274, 2147483647, 1518, 9728}, {2274, 2147483647, 11744, 2147483647, 9728, 2147483647}}} + }; + + static std::map>> g_allgatherV291093TwoRankFP16PvalueMap = { + {3, + {{-1, 4608, -1, 3584, -1, 1536}}}, + {6, + {{-1, 4608, 3584, 4608, -1, 1536}, {4608, 6656, 4608, 2147483647, -1, 8704}, + {6656, 2147483647, 4608, 2147483647, -1, 7680}, {6656, 9728, -1, 2560, 9728, 15360}, + {9728, 2147483647, 1280, 2560, 9728, 2147483647}, {3584, 6656, 2560, 3584, 9728, 15360}}}, + {4, + {{-1, 4608, 4608, 2147483647, -1, 1536}, {-1, 2560, -1, 3584, 1536, 2560}, + {-1, 4608, 3584, 2147483647, 1536, 2560}}}, + {10, + {{2560, 4608, -1, 3584, 1536, 2560}, {3584, 4608, -1, 2147483647, 2560, 5632}, + {4608, 2147483647, 768, 1280, -1, 4608}, {4608, 2147483647, 1280, 2560, 1536, 2560}, + {-1, 1536, 7168, 2147483647, 9728, 11264}, {-1, 1536, 8704, 2147483647, 11264, 13312}, + {-1, 1536, 7680, 2147483647, 15360, 2147483647}, {1536, 2560, 8704, 2147483647, 11264, 2147483647}, + {4608, 9728, -1, 2560, 15360, 2147483647}}}, + {12, + {{-1, 1536, -1, 2147483647, 2560, 5632}, {2560, 3584, -1, 2147483647, 2560, 5632}, + {4608, 8704, -1, 768, -1, 9728}, {4608, 2147483647, 768, 1280, 4608, 9728}, + {4608, 2147483647, 1280, 2560, -1, 1536}, {1536, 2560, -1, 2147483647, 9728, 11264}, + {2560, 3584, 7680, 2147483647, 9728, 2147483647}, {3584, 6656, 2560, 3584, 15360, 2147483647}, + {6656, 2147483647, 2560, 3584, 9728, 2147483647}, {3584, 2147483647, 3584, 2147483647, 9728, 2147483647}}}, + {14, + {{1536, 2560, -1, 2147483647, 2560, 5632}, {-1, 4608, -1, 2147483647, 5632, 9728}, + {8704, 2147483647, -1, 768, -1, 9728}, {4608, 2147483647, 1280, 2560, 2560, 8704}, + {4608, 6656, 2560, 4608, -1, 9728}, {4608, 6656, 4608, 2147483647, 8704, 9728}, + {6656, 2147483647, 4608, 2147483647, 7680, 9728}, {-1, 1536, -1, 7168, 9728, 11264}, + {-1, 1536, -1, 8704, 11264, 13312}, {-1, 1536, -1, 7680, 13312, 2147483647}, + {-1, 1536, 7680, 2147483647, 13312, 15360}, {1536, 2560, -1, 8704, 11264, 2147483647}, + {2560, 3584, -1, 7680, 9728, 2147483647}, {3584, 6656, -1, 2560, 9728, 15360}, + {3584, 4608, -1, 2560, 15360, 2147483647}, {9728, 2147483647, -1, 1280, 9728, 2147483647}}}, + {8, + {{4608, 2147483647, 1280, 2560, 8704, 9728}, {6656, 2147483647, 2560, 4608, -1, 9728}}} + }; + + static std::map>> g_allgatherV291093TwoRankFP16M0Map = { + {128, + {{-1, 3584, -1, 2147483647, -1, 2560}, {-1, 3584, -1, 8704, 2560, 3584}, + {3584, 2147483647, -1, 2147483647, -1, 3584}, {4608, 2147483647, -1, 7680, 3584, 2147483647}, + {3584, 2147483647, 7680, 2147483647, 3584, 2147483647}}}, + {256, + {{-1, 3584, 8704, 2147483647, 2560, 3584}, {-1, 4608, -1, 7680, 3584, 2147483647}, + {-1, 3584, 7680, 2147483647, 3584, 2147483647}}} + }; + + static std::map>> g_allgatherV291093TwoRankFP16UbmovenumMap = { + {3.0, + {{-1, 4608, -1, 768, -1, 1536}, {-1, 4608, 2560, 2147483647, 5632, 7680}, + {-1, 1536, -1, 
2147483647, 7680, 8704}, {4608, 7680, -1, 2147483647, -1, 4608}, + {4608, 7680, 7680, 2147483647, 4608, 8704}, {7680, 2147483647, -1, 2147483647, -1, 1536}, + {7680, 2147483647, 6656, 2147483647, 1536, 8704}}}, + {4.0, + {{-1, 4608, 768, 2147483647, -1, 1536}, {9728, 2147483647, -1, 768, 8704, 18432}}}, + {6.0, + {{-1, 4608, -1, 2147483647, 1536, 5632}, {-1, 4608, -1, 2560, 5632, 7680}, + {1536, 4608, -1, 768, 8704, 19456}, {9728, 2147483647, -1, 768, 18432, 2147483647}}}, + {2.0, + {{1536, 4608, -1, 2147483647, 7680, 8704}, {4608, 7680, -1, 7680, 4608, 8704}, + {7680, 2147483647, -1, 6656, 1536, 8704}, {4608, 9728, -1, 768, 8704, 2147483647}, + {-1, 2147483647, 768, 2147483647, 8704, 2147483647}}}, + {8.0, + {{-1, 1536, -1, 768, 8704, 13312}}}, + {10.0, + {{-1, 1536, -1, 768, 13312, 2147483647}}}, + {12.0, + {{1536, 4608, -1, 768, 19456, 2147483647}}} + }; + + void AllGatherV2NPU91093EightRankFP16Tiling(CoCTilingData &cocTilingData) + { + std::map tilingParamMap = { + {&cocTilingData.m0, + {ALLGATHERV2_91093_EIGHT_RANK_FP16_M0_DEFAULT, + g_allgatherV291093EightRankFP16M0Map}}, + {&cocTilingData.commDataSplit, + {ALLGATHERV2_91093_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_allgatherV291093EightRankFP16CommdatasplitMap}}, + {&cocTilingData.ubMoveNum, + {ALLGATHERV2_91093_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT, + g_allgatherV291093EightRankFP16UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLGATHERV2_91093_EIGHT_RANK_FP16_PVALUE_DEFAULT, + g_allgatherV291093EightRankFP16PvalueMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {SWIZZLE_COUNT_FOUR}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.commDirect = + cocTilingData.commDataSplit == COMMDATASPLIT_ONE ? COMM_DATA_DIRECT : COMM_NPU_DIRECT; + cocTilingData.commNpuSplit = + cocTilingData.commDataSplit == COMMDATASPLIT_ONE ? cocTilingData.rankSize : COMMNPUSPLIT_ONE; + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit; + + DealTilingParamByBuffSize(cocTilingData); + } + + void AllGatherV2NPU91093SixteenRankFP16Tiling(CoCTilingData &cocTilingData) + { + std::map tilingParamMap = { + {&cocTilingData.m0, + {ALLGATHERV2_91093_SIXTEEN_RANK_FP16_M0_DEFAULT, + g_allgatherV291093SixteenRankFP16M0Map}}, + {&cocTilingData.commNpuSplit, + {ALLGATHERV2_91093_SIXTEEN_RANK_FP16_COMMNPUSPLIT_DEFAULT, + g_allgatherV291093SixteenRankFP16CommnpusplitMap}}, + {&cocTilingData.ubMoveNum, + {ALLGATHERV2_91093_SIXTEEN_RANK_FP16_UBMOVENUM_DEFAULT, + g_allgatherV291093SixteenRankFP16UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLGATHERV2_91093_SIXTEEN_RANK_FP16_PVALUE_DEFAULT, + g_allgatherV291093SixteenRankFP16PvalueMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {SWIZZLE_COUNT_FOUR}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.commDirect = + cocTilingData.commNpuSplit <= COMMNPUSPLIT_EIGHT ? COMM_NPU_DIRECT : COMM_DATA_DIRECT; + cocTilingData.commDataSplit = + cocTilingData.commNpuSplit > COMMNPUSPLIT_ONE ? 
+                COMMDATASPLIT_ONE : COMMDATASPLIT_EIGHT;
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit;
+
+        DealTilingParamByBuffSize(cocTilingData);
+    }
+
+    void AllGatherV2NPU91093TwoRankFP16Tiling(CoCTilingData &cocTilingData)
+    {
+        std::map tilingParamMap = {
+            {&cocTilingData.pValue,
+             {ALLGATHERV2_91093_TWO_RANK_FP16_PVALUE_DEFAULT,
+              g_allgatherV291093TwoRankFP16PvalueMap}},
+            {&cocTilingData.m0,
+             {ALLGATHERV2_91093_TWO_RANK_FP16_M0_DEFAULT,
+              g_allgatherV291093TwoRankFP16M0Map}},
+            {&cocTilingData.ubMoveNum,
+             {ALLGATHERV2_91093_TWO_RANK_FP16_UBMOVENUM_DEFAULT,
+              g_allgatherV291093TwoRankFP16UbmovenumMap}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}},
+            {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}},
+            {&cocTilingData.commDirect, {COMM_DATA_DIRECT}},
+            {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}},
+            {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}}
+        };
+        SetTilingParam(cocTilingData, tilingParamMap);
+
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit;
+
+        DealTilingParamByBuffSize(cocTilingData);
+    }
+}
\ No newline at end of file
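Reviewer note: a hypothetical call sequence for the functions in this file. The field names follow the code above, but the concrete shape values are illustrative, and zero-initialization plus these few fields standing in for the real upstream setup (blockDim, buffer sizes, and so on) is an assumption.

    CoCTilingData data{};                 // remaining fields as produced upstream
    data.m = 4096;
    data.k = 8192;
    data.n = 2048;
    data.rankSize = 2;
    AllGatherV2NPU91093TwoRankFP16Tiling(data);
    // data.pValue / data.m0 / data.ubMoveNum now hold the box-matched values for
    // (m, k, n), and lenPerLoop == ubMoveNum * commDataSplit before the
    // buffer-size fixup in DealTilingParamByBuffSize.

diff --git a/comm/lcal/src/tiling/allgatherv2_tiling_910B.cpp b/comm/lcal/src/tiling/allgatherv2_tiling_910B.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..416da4eea5fee4e90d4e9dcabcbb8fb2d10ae864
--- /dev/null
+++ b/comm/lcal/src/tiling/allgatherv2_tiling_910B.cpp
@@ -0,0 +1,227 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.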
+ */ +#include +#include "tiling_910B.h" +#include "tiling_func.h" +#include "lcal_types.h" + +namespace Lcal { + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_PVALUE_DEFAULT = 6; + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 3; + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT = 8; + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_CORE16_PVALUE_DEFAULT = 6; + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_CORE16_UBMOVENUM_DEFAULT = 8; + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_CORE16_COMMDIRECT_DEFAULT = 1; + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_CORE16_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t ALLGATHERV2_EIGHT_RANK_FP16_CORE16_M0_DEFAULT = 128; + + static std::map>> g_allgatherV2EightRankFP16CorE16M0Map = { + {128, + {{-1, 3798, -1, 2147483647, -1, 1200}, {-1, 3798, -1, 10720, 1720, 2274}, + {-1, 3798, 10720, 2147483647, 1200, 2274}, {3798, 4298, -1, 2786, -1, 2274}, + {4810, 2147483647, -1, 2786, -1, 2274}, {3798, 2147483647, 2786, 2147483647, -1, 2274}, + {-1, 2147483647, -1, 6950, 2912, 5360}, {-1, 2147483647, 6450, 6950, 5360, 6934}, + {-1, 2147483647, -1, 6950, 6934, 2147483647}, {-1, 2147483647, 9950, 2147483647, 2274, 2626}, + {-1, 2147483647, 6950, 2147483647, 2626, 2147483647}}}, + {256, + {{-1, 3798, -1, 10720, 1200, 1720}, {4298, 4810, -1, 2786, -1, 2274}, + {-1, 2147483647, -1, 6950, 2274, 2912}, {-1, 2147483647, -1, 6450, 5360, 6934}, + {-1, 2147483647, 6950, 9950, 2274, 2626}}} + }; + + static std::map>> g_allgatherV2EightRankFP16CorE16CommdatasplitMap = { + {16, + {{-1, 2147483647, -1, 2147483647, -1, 2147483647}}} + }; + + static std::map>> g_allgatherV2EightRankFP16CorE16CommdirectMap = { + {1, + {{-1, 2147483647, -1, 2147483647, -1, 3248}, {-1, 2147483647, -1, 768, 3248, 5660}, + {-1, 2147483647, -1, 768, 8704, 2147483647}, {-1, 2147483647, 768, 2147483647, 3248, 2147483647}}}, + {0, + {{-1, 2147483647, -1, 768, 5660, 8704}}} + }; + + static std::map>> g_allgatherV2EightRankFP16CorE16UbmovenumMap = { + {2, + {{-1, 6450, -1, 2147483647, -1, 2912}, {-1, 2786, -1, 2147483647, 2912, 7434}, + {2786, 6450, -1, 2147483647, 2912, 6934}, {2786, 6450, 768, 2147483647, 6934, 7434}, + {6450, 7850, 1262, 1774, 768, 7434}, {7850, 2147483647, -1, 1774, 1536, 7434}, + {6450, 8958, 1774, 10720, -1, 7434}, {6450, 8958, 10720, 2147483647, -1, 6150}, + {8958, 2147483647, 1774, 2147483647, -1, 7434}, {-1, 1262, -1, 2147483647, 7434, 8958}, + {1262, 1774, 1792, 2147483647, 7434, 8958}, {1774, 2147483647, -1, 2147483647, 7434, 8958}, + {1774, 2147483647, -1, 768, 8958, 2147483647}, {-1, 2147483647, 768, 2147483647, 8958, 2147483647}}}, + {3, + {{2786, 6450, -1, 768, 6934, 7434}, {6450, 7850, -1, 1262, -1, 7434}, + {7850, 2147483647, -1, 1774, 768, 1536}, {1262, 1774, -1, 1792, 7434, 8958}, + {1262, 1774, -1, 768, 8958, 2147483647}}}, + {6, + {{6450, 7850, 1262, 1774, -1, 768}}}, + {8, + {{7850, 2147483647, -1, 1774, -1, 768}}}, + {4, + {{6450, 8958, 10720, 2147483647, 6150, 7434}, {-1, 1262, -1, 768, 8958, 2147483647}}} + }; + + static std::map>> g_allgatherV2EightRankFP16CorE16PvalueMap = { + {2, + {{-1, 3798, -1, 2147483647, -1, 832}, {768, 3798, 768, 2147483647, 832, 1200}, + {2024, 2560, -1, 1262, 1200, 2274}, {2024, 3798, 1262, 2147483647, 1200, 2274}, + {3798, 4608, 7946, 8446, -1, 2274}, {3798, 2147483647, 9728, 2147483647, -1, 768}, + {-1, 1262, -1, 768, 2274, 5660}, {768, 1262, -1, 7696, 7680, 10752}, + {1262, 2147483647, -1, 2147483647, 2912, 3248}, {8958, 
9728, -1, 2147483647, 7680, 2147483647}}}, + {4, + {{-1, 3798, -1, 768, 832, 1200}, {2560, 3798, -1, 1262, 1200, 2274}, + {3798, 2147483647, -1, 1774, -1, 2274}, {3798, 2147483647, 8446, 9728, -1, 1262}, + {3798, 2147483647, 9728, 2147483647, 768, 1262}, {3798, 2147483647, 8446, 2147483647, 1262, 2274}, + {-1, 1262, 7696, 2147483647, 8704, 10752}, {-1, 1262, -1, 2147483647, 10752, 2147483647}, + {1262, 2147483647, -1, 2274, 3248, 6934}, {1262, 2147483647, 8958, 2147483647, 4298, 6934}, + {1262, 1774, 6700, 2147483647, 6934, 2147483647}, {1774, 8958, 6450, 2147483647, 6934, 2147483647}, + {8958, 2147483647, 6700, 2147483647, 6934, 7680}, {9728, 2147483647, -1, 2147483647, 7680, 2147483647}}}, + {1, + {{-1, 768, 768, 2147483647, 832, 1200}, {-1, 2024, -1, 2147483647, 1200, 2274}, + {-1, 1262, -1, 768, 5660, 7680}, {-1, 1262, 768, 2147483647, 2274, 7680}, + {-1, 768, -1, 7696, 7680, 10752}, {-1, 1262, 7696, 2147483647, 7680, 8704}, + {1262, 2147483647, -1, 2147483647, 2274, 2912}, {1262, 2147483647, 2274, 8958, 3248, 6934}, + {1262, 2147483647, 8958, 2147483647, 3248, 4298}, {1262, 1774, -1, 6700, 6934, 2147483647}, + {1774, 8958, -1, 6450, 6934, 2147483647}, {8958, 2147483647, -1, 6700, 6934, 7680}}}, + {6, + {{3798, 2147483647, 1774, 7946, -1, 2274}, {4608, 2147483647, 7946, 8446, -1, 2274}}} + }; + + static std::map>> g_allgatherV2EightRankFP16CommdatasplitMap = { + {8, + {{-1, 2274, -1, 2147483647, -1, 5312}, {2274, 2147483647, -1, 2147483647, -1, 4810}, + {5148, 2147483647, -1, 768, 4810, 5312}, {2274, 2147483647, 768, 2147483647, 4810, 5312}, + {-1, 2147483647, -1, 2147483647, 5312, 2147483647}}}, + {4, + {{2274, 5148, -1, 768, 4810, 5312}}} + }; + + static std::map>> g_allgatherV2EightRankFP16M0Map = { + {128, + {{-1, 2274, -1, 2147483647, -1, 6172}, {-1, 2274, 6700, 2147483647, 6172, 6934}, + {2274, 2786, 8200, 2147483647, -1, 6934}, {2786, 2147483647, -1, 6950, -1, 5360}, + {2786, 2147483647, 6950, 2147483647, -1, 6934}, {-1, 2147483647, -1, 2274, 6934, 2147483647}, + {-1, 2147483647, 2274, 4810, 6934, 7434}, {-1, 2147483647, 2274, 4810, 7946, 2147483647}, + {-1, 2147483647, 4810, 2147483647, 6934, 2147483647}}}, + {256, + {{-1, 2274, -1, 6700, 6172, 6934}, {2274, 2786, -1, 8200, -1, 6934}, + {2786, 2147483647, -1, 6950, 5360, 6934}, {-1, 2147483647, 2274, 4810, 7434, 7946}}} + }; + + static std::map>> g_allgatherV2EightRankFP16UbmovenumMap = { + {2.0, + {{-1, 768, -1, 2560, -1, 576}, {-1, 768, -1, 3584, 832, 2274}, + {768, 2147483647, -1, 1774, -1, 2274}, {-1, 2147483647, -1, 8200, 2274, 2626}, + {-1, 2147483647, -1, 2147483647, 2626, 2147483647}}}, + {3.0, + {{-1, 768, 2560, 2147483647, -1, 576}, {-1, 768, -1, 2147483647, 576, 832}, + {-1, 768, 3584, 2147483647, 832, 2274}, {768, 2147483647, 1774, 2147483647, -1, 2274}, + {-1, 2147483647, 8200, 2147483647, 2274, 2626}}} + }; + + static std::map>> g_allgatherV2EightRankFP16PvalueMap = { + {2, + {{-1, 2786, -1, 2560, -1, 576}, {1280, 2786, -1, 2147483647, 576, 832}, + {-1, 2786, -1, 4608, 832, 1200}, {1262, 2786, -1, 2147483647, 1720, 2274}, + {2786, 3286, -1, 7708, -1, 768}, {2786, 3286, -1, 2147483647, 768, 2274}, + {3286, 4298, -1, 2147483647, -1, 1262}, {4298, 7450, 4810, 2147483647, -1, 768}, + {-1, 2147483647, -1, 768, 2274, 3584}, {4608, 2147483647, -1, 768, 5660, 7680}, + {-1, 2147483647, 768, 5900, 2912, 3248}, {-1, 2147483647, 6450, 8446, 8446, 2147483647}, + {-1, 768, 11264, 2147483647, 2274, 7680}, {-1, 768, 8446, 2147483647, 7680, 2147483647}, + {768, 1262, 8446, 9728, 4636, 2147483647}, {768, 1262, 9728, 2147483647, 
6684, 2147483647}, + {1262, 2147483647, 8446, 9728, 6450, 2147483647}, {1262, 2147483647, 9728, 10720, 6684, 2147483647}, + {1262, 2147483647, 10720, 2147483647, 5104, 2147483647}}}, + {1, + {{-1, 2786, 2560, 2147483647, -1, 576}, {-1, 1280, -1, 2147483647, 576, 832}, + {-1, 2786, 4608, 2147483647, 832, 1200}, {-1, 2786, -1, 2147483647, 1200, 1720}, + {-1, 1262, -1, 2147483647, 1720, 2274}, {2786, 3286, 7708, 2147483647, -1, 768}, + {3286, 4298, -1, 2147483647, 1774, 2274}, {7450, 2147483647, 4810, 2147483647, 1774, 2274}, + {-1, 4608, -1, 768, 5660, 7680}, {-1, 2147483647, -1, 768, 7680, 2147483647}, + {-1, 2147483647, 768, 5900, 2274, 2912}, {-1, 2147483647, 768, 5900, 3248, 2147483647}, + {-1, 2147483647, 5900, 8446, 2274, 8446}, {-1, 2147483647, 5900, 6450, 8446, 2147483647}, + {-1, 768, 8446, 11264, 2274, 7680}, {768, 1262, 8446, 9728, 2274, 4636}, + {768, 1262, 9728, 2147483647, 2274, 6684}, {1262, 2147483647, 8446, 9728, 2274, 6450}, + {1262, 2147483647, 9728, 10720, 2274, 6684}, {1262, 2147483647, 10720, 2147483647, 2274, 5104}}}, + {4, + {{3286, 4298, -1, 2147483647, 1262, 1774}, {4298, 8958, -1, 4298, -1, 2274}, + {8958, 2147483647, -1, 4810, -1, 1536}, {4298, 7450, 4810, 2147483647, 768, 2274}, + {7450, 2147483647, 4810, 2147483647, -1, 1774}, {-1, 2147483647, -1, 768, 3584, 5660}}}, + {6, + {{4298, 8958, 4298, 4810, -1, 2274}, {8958, 2147483647, -1, 4810, 1536, 2274}}} + }; + + void AllGatherV2EightRankFP16GetDefaultTiling(CoCTilingData &cocTilingData) + { + std::map tilingParamMap = { + {&cocTilingData.commDataSplit, + {ALLGATHERV2_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_allgatherV2EightRankFP16CommdatasplitMap}}, + {&cocTilingData.m0, + {ALLGATHERV2_EIGHT_RANK_FP16_M0_DEFAULT, + g_allgatherV2EightRankFP16M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLGATHERV2_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT, + g_allgatherV2EightRankFP16UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLGATHERV2_EIGHT_RANK_FP16_PVALUE_DEFAULT, + g_allgatherV2EightRankFP16PvalueMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {SWIZZLE_COUNT_FOUR}}, + {&cocTilingData.commDirect, {COMM_NPU_DIRECT}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + int32_t coreNum = cocTilingData.blockDim - cocTilingData.rankSize; + cocTilingData.commNpuSplit = + cocTilingData.commDataSplit >= COMMDATASPLIT_EIGHT ? 
+                COMMNPUSPLIT_ONE : COMMNPUSPLIT_THREE;
+        cocTilingData.commNpuSplit = std::min(cocTilingData.commNpuSplit, cocTilingData.rankSize);
+        cocTilingData.commDataSplit =
+            ClampValue(cocTilingData.commDataSplit, COMMDATASPLIT_ONE, coreNum / cocTilingData.commNpuSplit);
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit;
+
+        DealTilingParamByBuffSize(cocTilingData);
+    }
+
+    void AllGatherV2EightRankFP16Core16GetDefaultTiling(CoCTilingData &cocTilingData)
+    {
+        std::map tilingParamMap = {
+            {&cocTilingData.m0,
+             {ALLGATHERV2_EIGHT_RANK_FP16_CORE16_M0_DEFAULT,
+              g_allgatherV2EightRankFP16CorE16M0Map}},
+            {&cocTilingData.commDataSplit,
+             {ALLGATHERV2_EIGHT_RANK_FP16_CORE16_COMMDATASPLIT_DEFAULT,
+              g_allgatherV2EightRankFP16CorE16CommdatasplitMap}},
+            {&cocTilingData.commDirect,
+             {ALLGATHERV2_EIGHT_RANK_FP16_CORE16_COMMDIRECT_DEFAULT,
+              g_allgatherV2EightRankFP16CorE16CommdirectMap}},
+            {&cocTilingData.ubMoveNum,
+             {ALLGATHERV2_EIGHT_RANK_FP16_CORE16_UBMOVENUM_DEFAULT,
+              g_allgatherV2EightRankFP16CorE16UbmovenumMap}},
+            {&cocTilingData.pValue,
+             {ALLGATHERV2_EIGHT_RANK_FP16_CORE16_PVALUE_DEFAULT,
+              g_allgatherV2EightRankFP16CorE16PvalueMap}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}},
+            {&cocTilingData.swizzlCount, {SWIZZLE_COUNT_FOUR}}
+        };
+        SetTilingParam(cocTilingData, tilingParamMap);
+
+        int32_t coreNum = cocTilingData.blockDim - cocTilingData.rankSize;
+        cocTilingData.commNpuSplit =
+            cocTilingData.commDataSplit >= COMMDATASPLIT_EIGHT ? COMMNPUSPLIT_ONE : cocTilingData.rankSize;
+        cocTilingData.commDataSplit =
+            ClampValue(cocTilingData.commDataSplit, COMMDATASPLIT_ONE, coreNum / cocTilingData.commNpuSplit);
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum * cocTilingData.commDataSplit;
+
+        DealTilingParamByBuffSize(cocTilingData);
+    }
+}
\ No newline at end of file
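Reviewer note on the fixups in the two functions above: blockDim minus rankSize is the core budget left for the communication stage, and the ClampValue call keeps commNpuSplit * commDataSplit within that budget. A worked instance under assumed figures (blockDim = 20, rankSize = 8, a table-selected commDataSplit of 16, and COMMNPUSPLIT_ONE / COMMDATASPLIT_ONE equal to 1):

    int32_t coreNum = 20 - 8;                    // 12 cores remain for comms
    int32_t commNpuSplit = COMMNPUSPLIT_ONE;     // taken on the commDataSplit >= 8 path
    int32_t commDataSplit =
        ClampValue(16, COMMDATASPLIT_ONE, coreNum / commNpuSplit);  // 16 > 12, so -> 12
    // Post-condition: commNpuSplit * commDataSplit <= coreNum (1 * 12 <= 12).

diff --git a/comm/lcal/src/tiling/allreduce_tiling.cpp b/comm/lcal/src/tiling/allreduce_tiling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e4eff5ebd161cd933c9e596116cb6a56997e1dab
--- /dev/null
+++ b/comm/lcal/src/tiling/allreduce_tiling.cpp
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.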
+ */ +#include "tiling.h" +#include "lcoc_func.h" +#include "tiling_910B.h" +#include "tiling_91093.h" +#include "tiling_func.h" + +namespace Lcal { +const int ALLREDUCE_M_EDGE = 3072; +const int ALLREDUCE_N_EDGE = 3072; + +void CoCMatmulAllReduceTilingFunc::GetDefaultTiling(const TaskParam &taskParam) +{ + CoCTilingFunc::GetDefaultTiling(taskParam); + if (Is91093(taskParam.chipName)) { + if (cocTilingData.rankSize == RANKSIZE_EIGHT) { + AllReduceNPU91093EightRankFP16Tiling(cocTilingData); + return; + } else if (cocTilingData.rankSize == RANKSIZE_SIXTEEN) { + AllReduceNPU91093SixteenRankFP16Tiling(cocTilingData); + return; + } + } + if (cocTilingData.rankSize == RANKSIZE_FOUR) { + if (taskParam.cocParamDesc.mmInfo.isInt8) { + AllReduceFourRankInt8GetDefaultTiling(cocTilingData); + return; + } else { + AllReduceFourRankFP16GetDefaultTiling(cocTilingData); + return; + } + } else if (cocTilingData.rankSize == RANKSIZE_TWO) { + AllReduceTwoRankFP16Tiling(cocTilingData); + return; + } + AllReduceGetDefaultTiling(cocTilingData); +} + +void CoCMatmulAllReduceDeterTilingFunc::GetDefaultTiling(const TaskParam &taskParam) +{ + CoCTilingFunc::GetDefaultTiling(taskParam); + if (cocTilingData.rankSize == RANKSIZE_FOUR) { + if (taskParam.cocParamDesc.mmInfo.isInt8) { + AllReduceFourRankInt8GetDefaultTiling(cocTilingData); + } else { + AllReduceFourRankFP16GetDefaultTiling(cocTilingData); + } + } else { + if (taskParam.cocParamDesc.mmInfo.isInt8) { + AllReduceEightRankINT8GetDefaultTiling(cocTilingData); + } else { + AllReduceEightRankFP16GetDefaultTiling(cocTilingData); + } + } + if (cocTilingData.m * cocTilingData.n >= ALLREDUCE_M_EDGE * ALLREDUCE_N_EDGE) { + cocTilingData.lenPerLoop = ALLREDUCE_LENPERLOOP_DEFAULT / RANKSIZE_EIGHT * cocTilingData.rankSize; + cocTilingData.lenPerLoop = RoundNum(cocTilingData.lenPerLoop, HALF_KBYTE); + cocTilingData.ubMoveNum = cocTilingData.lenPerLoop; + cocTilingData.extraLenPerLoop = cocTilingData.lenPerLoop; + cocTilingData.extraUbMoveNum = cocTilingData.ubMoveNum; + } + if (cocTilingData.lenPerLoop > TREE_LEN_PER_LOOP) { + cocTilingData.lenPerLoop = TREE_LEN_PER_LOOP; + cocTilingData.ubMoveNum = TREE_LEN_PER_LOOP; + cocTilingData.extraLenPerLoop = cocTilingData.lenPerLoop; + cocTilingData.extraUbMoveNum = cocTilingData.ubMoveNum; + } +} + +bool CheckCMatrix(const TaskParam &taskParam, const CoCTilingData &data) +{ + constexpr int32_t BUFFER_UNIT = 1024; + if (data.withSerialMode != 0 && + data.batchSize * data.m * data.n >= (taskParam.bufferSize * BUFFER_UNIT * BUFFER_UNIT) + / INPUT_DTYPE / MAX_BLOCK_COUNT) { + std::string str = "The matrix c is too large to support serial. 
" + "withSerialMode: " + std::to_string(data.withSerialMode) + + ", batchSize: " + std::to_string(data.batchSize) + + ", m: " + std::to_string(data.m) + + ", n: " + std::to_string(data.n); + PrintErrorLog(taskParam.lcalType, str); + return false; + } + return true; +} + +bool CoCMatmulAllReduceTilingFunc::CheckTiling(const TaskParam &taskParam) +{ + if (!CoCTilingFunc::CheckTiling(taskParam)) { + return false; + } + if (!CheckCMatrix(taskParam, cocTilingData)) { + return false; + } + + auto rankSize = cocTilingData.rankSize; + auto commNpuSplit = cocTilingData.commNpuSplit; + auto commDataSplit = cocTilingData.commDataSplit; + auto coreNum = cocTilingData.blockDim; + int32_t useCoreCount = commNpuSplit * commDataSplit; + + std::vector> paramCheckList = { + {"commNpuSplit * commDataSplit", useCoreCount, rankSize, coreNum}, + {"commNpuSplit", commNpuSplit, PARAM_CHECK_MIN_VALUE_ONE, rankSize} + }; + return CheckParamScopeList(paramCheckList); +} + +bool CoCMatmulAllReduceDeterTilingFunc::CheckTiling(const TaskParam &taskParam) +{ + if (!CoCMatmulAllReduceTilingFunc::CheckTiling(taskParam)) { + return false; + } + + auto commNpuSplit = cocTilingData.commNpuSplit; + if (commNpuSplit != 1) { + std::string str = "The product of commNpuSplit must equal 1. commNpuSplit: " + std::to_string(commNpuSplit); + PrintErrorLog(taskParam.lcalType, str); + return false; + } + return true; +} +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/allreduce_tiling_91093.cpp b/comm/lcal/src/tiling/allreduce_tiling_91093.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4c40fc6bee8c22dda44ce511f684e43dcc71b600 --- /dev/null +++ b/comm/lcal/src/tiling/allreduce_tiling_91093.cpp @@ -0,0 +1,261 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include +#include "tiling_91093.h" +#include "tiling_func.h" + +namespace Lcal { + constexpr int32_t ALLREDUCE_91093_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 160; + constexpr int32_t ALLREDUCE_91093_EIGHT_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLREDUCE_91093_EIGHT_RANK_FP16_PVALUE_DEFAULT = 14; + constexpr int32_t ALLREDUCE_91093_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t ALLREDUCE_91093_SIXTEEN_RANK_FP16_PVALUE_DEFAULT = 14; + constexpr int32_t ALLREDUCE_91093_SIXTEEN_RANK_FP16_UBMOVENUM_DEFAULT = 160; + constexpr int32_t ALLREDUCE_91093_SIXTEEN_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLREDUCE_91093_SIXTEEN_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + + static std::map>> g_allreduce91093EightRankFP16CommdatasplitMap = { + {1, + {{-1, 3072, -1, 2147483647, -1, 768}, {-1, 768, -1, 2147483647, 768, 1536}, + {768, 1536, -1, 7170, 768, 1536}, {1536, 3072, 3072, 6144, 768, 1536}, + {3072, 5148, 3072, 5634, -1, 768}, {-1, 1280, 384, 768, 1536, 3072}, + {-1, 2304, 768, 4000, 1536, 3072}, {-1, 768, 4000, 5634, 1536, 5376}, + {-1, 768, 5634, 2147483647, 1536, 3072}, {768, 1536, 4000, 5634, 1536, 3072}}}, + {16, + {{768, 1536, 7170, 2147483647, 768, 1536}, {1536, 3072, -1, 3072, 768, 1536}, + {1536, 3072, 6144, 2147483647, 768, 1536}, {3072, 5148, -1, 3072, -1, 768}, + {3072, 5148, -1, 5634, 768, 1536}, {3072, 5148, 5634, 2147483647, -1, 1536}, + {5148, 2147483647, -1, 2147483647, -1, 1536}, {-1, 2147483647, -1, 384, 1536, 3072}, + {1280, 2147483647, 384, 768, 1536, 3072}, {2304, 2147483647, 768, 4000, 1536, 3072}, + {-1, 2147483647, -1, 4000, 3072, 2147483647}, {-1, 768, 4000, 5634, 5376, 2147483647}, + {-1, 768, 5634, 2147483647, 3072, 2147483647}, {1536, 2147483647, 4000, 5634, 1536, 3072}, + {768, 2147483647, 5634, 2147483647, 1536, 3072}, {768, 2147483647, 4000, 2147483647, 3072, 2147483647}}} + }; + + static std::map>> g_allreduce91093EightRankFP16PvalueMap = { + {4, + {{-1, 3072, -1, 3072, -1, 768}, {5148, 31220, 3072, 4608, -1, 768}, + {-1, 3072, 768, 4608, 768, 1536}, {31220, 43740, 3072, 2147483647, -1, 1536}, + {43740, 53340, -1, 2147483647, -1, 1536}, {62400, 68160, -1, 2147483647, -1, 768}, + {68160, 2147483647, 3072, 7170, -1, 768}, {68160, 2147483647, 4608, 2147483647, 768, 1536}, + {3072, 19680, 1280, 7170, 1536, 7424}, {7350, 2147483647, 7170, 11264, 1536, 7424}, + {-1, 11904, 2976, 2147483647, 7424, 19456}, {11904, 2147483647, 7170, 2147483647, 7424, 2147483647}}}, + {1, + {{-1, 5148, 3072, 2147483647, -1, 768}, {-1, 3072, 4608, 2147483647, 768, 1536}, + {-1, 2147483647, 11264, 2147483647, 1536, 7424}}}, + {6, + {{3072, 31220, -1, 3072, -1, 768}, {3072, 31220, 768, 4608, 768, 1536}, + {68160, 2147483647, 3072, 4608, 768, 1536}, {19680, 2147483647, 5634, 7170, 1536, 7424}}}, + {2, + {{5148, 31220, 4608, 2147483647, -1, 768}, {3072, 31220, 4608, 2147483647, 768, 1536}, + {68160, 2147483647, 7170, 2147483647, -1, 768}, {-1, 3072, 1280, 7170, 1536, 7424}, + {-1, 7350, 7170, 11264, 1536, 7424}, {-1, 11904, 2976, 2147483647, 19456, 2147483647}}}, + {10, + {{-1, 31220, -1, 768, 768, 1536}, {31220, 43740, -1, 3072, -1, 1536}, + {53340, 62400, -1, 2147483647, -1, 1536}, {62400, 68160, -1, 2147483647, 768, 1536}, + {112280, 2147483647, -1, 3072, -1, 768}, {68160, 2147483647, 1536, 3072, 768, 1536}, + {-1, 38592, 768, 1280, 1536, 3072}, {-1, 5148, -1, 1280, 7424, 13312}, + {38592, 68160, 768, 1280, 1536, 3072}, {19680, 2147483647, 1280, 5634, 1536, 7424}, + {-1, 14336, 1280, 1792, 7424, 2147483647}, {11904, 2147483647, 2976, 7170, 7424, 2147483647}}}, + 
{12, + {{68160, 112280, -1, 3072, -1, 768}, {-1, 38592, -1, 768, 1536, 7424}, + {-1, 38592, 768, 1280, 3072, 7424}, {-1, 38592, 768, 1280, 13312, 2147483647}, + {68160, 2147483647, 768, 1280, 1536, 3072}, {14336, 2147483647, 1280, 1792, 7424, 2147483647}, + {-1, 2147483647, 1792, 2976, 7424, 2147483647}}}, + {14, + {{68160, 2147483647, -1, 1536, 768, 1536}, {5148, 38592, -1, 1280, 7424, 13312}, + {-1, 38592, -1, 768, 13312, 2147483647}, {38592, 2147483647, -1, 768, 1536, 2147483647}, + {38592, 2147483647, 768, 1280, 3072, 2147483647}}} + }; + + static std::map>> g_allreduce91093EightRankFP16M0Map = { + {128, + {{-1, 3072, -1, 2147483647, -1, 10240}, {-1, 3072, -1, 3072, 10240, 19456}, + {3072, 2147483647, -1, 2147483647, -1, 19456}, {1536, 2147483647, -1, 2147483647, 19456, 2147483647}}}, + {256, + {{-1, 3072, 3072, 2147483647, 10240, 19456}, {-1, 1536, -1, 2147483647, 19456, 2147483647}}} + }; + + static std::map>> g_allreduce91093EightRankFP16UbmovenumMap = { + {80, + {{-1, 768, -1, 7170, -1, 768}, {31220, 36980, -1, 2147483647, -1, 768}, + {-1, 10010, -1, 3072, 1536, 3072}, {-1, 768, 3072, 2147483647, 1536, 3072}}}, + {100, + {{-1, 768, 7170, 2147483647, -1, 768}}}, + {140, + {{768, 3072, -1, 2147483647, -1, 768}}}, + {60, + {{3072, 23040, -1, 3072, -1, 768}, {-1, 36980, -1, 1536, 768, 1536}, + {10010, 36980, -1, 3072, 1536, 3072}, {36980, 2147483647, -1, 1536, -1, 3072}, + {-1, 2147483647, -1, 1280, 3072, 2147483647}}}, + {20, + {{3072, 23040, 3072, 2147483647, -1, 768}, {-1, 36980, 3072, 2147483647, 768, 1536}, + {768, 36980, 3072, 2147483647, 1536, 3072}, {36980, 142040, 3072, 4608, -1, 768}, + {36980, 2147483647, 3072, 4608, 768, 3072}, {768, 2147483647, 2976, 4608, 3072, 2147483647}, + {768, 5148, 4608, 5634, 3072, 2147483647}, {5148, 10010, 4608, 5634, 3072, 9472}, + {-1, 768, 5634, 2147483647, 3072, 2147483647}}}, + {10, + {{23040, 31220, -1, 2147483647, -1, 768}, {142040, 2147483647, 3072, 4608, -1, 768}, + {36980, 2147483647, 4608, 2147483647, -1, 3072}, {5148, 10010, 4608, 5634, 9472, 2147483647}, + {10010, 2147483647, 4608, 5634, 3072, 2147483647}, {768, 2147483647, 5634, 2147483647, 3072, 2147483647}}}, + {30, + {{-1, 36980, 1536, 3072, 768, 1536}, {36980, 2147483647, 1536, 3072, -1, 3072}, + {-1, 2147483647, 1280, 2976, 3072, 2147483647}, {-1, 768, 2976, 4608, 3072, 2147483647}}}, + {160, + {{-1, 768, 4608, 5634, 3072, 2147483647}}} + }; + + static std::map>> g_allreduce91093SixteenRankFP16CommdatasplitMap = { + {1, + {{-1, 36980, -1, 2147483647, -1, 768}, {36980, 74380, -1, 7170, -1, 768}, + {74380, 82060, -1, 3072, -1, 768}, {-1, 82060, -1, 1536, 768, 1536}, + {-1, 23040, 1536, 2147483647, 768, 1536}, {23040, 82060, 5634, 2147483647, 768, 1536}, + {82060, 2147483647, -1, 1536, -1, 1536}, {82060, 112280, 1536, 3072, -1, 1536}, + {129600, 2147483647, 1536, 3072, -1, 1536}, {176600, 222720, 3072, 2147483647, 768, 1536}, + {-1, 2147483647, -1, 2976, 1536, 10240}, {-1, 107968, -1, 2976, 10240, 13312}, + {107968, 2147483647, -1, 1792, 10240, 13312}, {-1, 2147483647, -1, 1536, 13312, 2147483647}, + {-1, 75840, 1536, 2976, 13312, 2147483647}, {-1, 11904, 2976, 2147483647, 1536, 3072}, + {-1, 3072, 2976, 2147483647, 3072, 2147483647}, {3072, 11904, 2976, 2147483647, 3072, 13312}, + {11904, 2147483647, 5634, 2147483647, 5376, 7424}}}, + {16, + {{36980, 74380, 7170, 2147483647, -1, 768}, {74380, 82060, 3072, 2147483647, -1, 768}, + {23040, 82060, 1536, 5634, 768, 1536}, {112280, 129600, 1536, 3072, -1, 1536}, + {82060, 2147483647, 3072, 2147483647, -1, 768}, {82060, 176600, 3072, 
2147483647, 768, 1536}, + {222720, 2147483647, 3072, 2147483647, 768, 1536}, {107968, 2147483647, 1792, 2976, 10240, 13312}, + {75840, 2147483647, 1536, 2976, 13312, 2147483647}, {3072, 11904, 2976, 2147483647, 13312, 2147483647}, + {11904, 2147483647, 2976, 5634, 1536, 2147483647}, {11904, 2147483647, 5634, 2147483647, 1536, 5376}, + {11904, 2147483647, 5634, 2147483647, 7424, 2147483647}}} + }; + + static std::map>> g_allreduce91093SixteenRankFP16M0Map = { + {128, + {{-1, 2147483647, -1, 2147483647, -1, 3072}, {-1, 2147483647, -1, 2976, 3072, 2147483647}}}, + {256, + {{-1, 2147483647, 2976, 2147483647, 3072, 2147483647}}} + }; + + static std::map>> g_allreduce91093SixteenRankFP16UbmovenumMap = { + {60, + {{-1, 768, -1, 5634, -1, 768}, {3072, 2147483647, -1, 1536, -1, 1536}, + {3072, 36980, 1536, 3072, -1, 1536}, {-1, 15412, -1, 2976, 5376, 2147483647}, + {15412, 2147483647, -1, 2976, 1536, 13312}, {15412, 2147483647, -1, 1536, 13312, 2147483647}}}, + {20, + {{-1, 768, 5634, 2147483647, -1, 768}, {10320, 2147483647, 3072, 4608, -1, 1536}, + {3072, 2147483647, 4608, 2147483647, -1, 1536}, {-1, 15412, 3072, 2147483647, 1536, 5376}, + {-1, 15412, 2976, 2147483647, 5376, 2147483647}, {15412, 2147483647, 2976, 2147483647, 1536, 13312}, + {15412, 2147483647, 3072, 2147483647, 13312, 2147483647}}}, + {160, + {{768, 3072, -1, 384, -1, 768}}}, + {80, + {{768, 3072, 384, 2147483647, -1, 768}}}, + {120, + {{-1, 1536, -1, 4608, 768, 1536}, {1536, 3072, 640, 2147483647, 768, 1536}}}, + {40, + {{-1, 1536, 4608, 2147483647, 768, 1536}}}, + {140, + {{1536, 3072, -1, 640, 768, 1536}}}, + {30, + {{36980, 2147483647, 1536, 3072, -1, 1536}, {3072, 10320, 3072, 4608, -1, 1536}, + {-1, 15412, 1536, 3072, 1536, 5376}, {15412, 2147483647, 1536, 3072, 13312, 2147483647}}}, + {100, + {{-1, 15412, -1, 1536, 1536, 5376}}} + }; + + static std::map>> g_allreduce91093SixteenRankFP16PvalueMap = { + {4, + {{-1, 3072, -1, 4608, -1, 768}, {5148, 31220, -1, 4608, -1, 768}, + {10010, 36980, 4608, 2147483647, 768, 1536}, {36980, 53340, 1536, 2147483647, -1, 1536}, + {53340, 68160, -1, 2147483647, -1, 768}, {68160, 74380, 3586, 2147483647, -1, 768}, + {1536, 3072, 2976, 2147483647, 3072, 2147483647}, {3072, 16340, 4608, 7170, 1536, 2147483647}, + {3072, 2147483647, 7170, 2147483647, 1536, 2147483647}}}, + {1, + {{-1, 5148, 4608, 2147483647, -1, 768}, {-1, 1536, 2976, 7170, 1536, 5376}, + {-1, 1536, 7170, 2147483647, 1536, 2147483647}}}, + {8, + {{3072, 5148, -1, 4608, -1, 768}, {-1, 19680, 768, 3072, 768, 1536}, + {-1, 1536, -1, 384, 10240, 2147483647}, {-1, 2560, 384, 1280, 7424, 2147483647}}}, + {12, + {{31220, 36980, -1, 4608, -1, 768}, {-1, 36980, -1, 768, 768, 1536}, + {68160, 74380, -1, 3586, -1, 768}, {68160, 74380, 1536, 2147483647, 768, 1536}, + {-1, 19680, -1, 384, 1536, 7424}, {-1, 6298, -1, 384, 7424, 10240}}}, + {2, + {{5148, 36980, 4608, 2147483647, -1, 768}, {-1, 10010, 3072, 2147483647, 768, 1536}, + {-1, 1536, 2976, 7170, 5376, 2147483647}, {1536, 3072, 2976, 2147483647, 1536, 3072}}}, + {10, + {{19680, 36980, 768, 3072, 768, 1536}, {142040, 2147483647, 1536, 2147483647, -1, 768}, + {74380, 189080, 3072, 2147483647, 768, 1536}, {19680, 2147483647, 2976, 4608, 1536, 5376}, + {3072, 2147483647, 2976, 4608, 5376, 2147483647}}}, + {6, + {{10010, 36980, 3072, 4608, 768, 1536}, {74380, 142040, 1536, 2147483647, -1, 768}, + {189080, 2147483647, 5634, 2147483647, 768, 1536}, {-1, 19680, 1536, 2976, 1536, 7424}, + {3072, 19680, 2976, 4608, 1536, 5376}, {16340, 2147483647, 4608, 7170, 1536, 2147483647}}}, + {14, + 
{{36980, 53340, -1, 1536, -1, 1536}, {53340, 68160, -1, 2147483647, 768, 1536}, + {68160, 74380, -1, 1536, 768, 1536}, {74380, 2147483647, -1, 1536, -1, 768}, + {74380, 189080, -1, 3072, 768, 1536}, {189080, 2147483647, -1, 5634, 768, 1536}, + {-1, 19680, 384, 1536, 1536, 7424}, {19680, 2147483647, -1, 2976, 1536, 7424}, + {6298, 2147483647, -1, 384, 7424, 10240}, {1536, 2147483647, -1, 384, 10240, 2147483647}, + {2560, 2147483647, 384, 1280, 7424, 2147483647}, {-1, 2147483647, 1280, 2976, 7424, 2147483647}}} + }; + + void AllReduceNPU91093EightRankFP16Tiling(CoCTilingData &cocTilingData) + { + std::map tilingParamMap = { + {&cocTilingData.commDataSplit, + {ALLREDUCE_91093_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_allreduce91093EightRankFP16CommdatasplitMap}}, + {&cocTilingData.pValue, + {ALLREDUCE_91093_EIGHT_RANK_FP16_PVALUE_DEFAULT, + g_allreduce91093EightRankFP16PvalueMap}}, + {&cocTilingData.m0, + {ALLREDUCE_91093_EIGHT_RANK_FP16_M0_DEFAULT, + g_allreduce91093EightRankFP16M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLREDUCE_91093_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT, + g_allreduce91093EightRankFP16UbmovenumMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum; + cocTilingData.commNpuSplit = + cocTilingData.commDataSplit == COMMDATASPLIT_ONE ? cocTilingData.rankSize : COMMNPUSPLIT_ONE; + SetSecondCoreSplitTling(cocTilingData); + } + + void AllReduceNPU91093SixteenRankFP16Tiling(CoCTilingData &cocTilingData) + { + std::map tilingParamMap = { + {&cocTilingData.commDataSplit, + {ALLREDUCE_91093_SIXTEEN_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_allreduce91093SixteenRankFP16CommdatasplitMap}}, + {&cocTilingData.m0, + {ALLREDUCE_91093_SIXTEEN_RANK_FP16_M0_DEFAULT, + g_allreduce91093SixteenRankFP16M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLREDUCE_91093_SIXTEEN_RANK_FP16_UBMOVENUM_DEFAULT, + g_allreduce91093SixteenRankFP16UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLREDUCE_91093_SIXTEEN_RANK_FP16_PVALUE_DEFAULT, + g_allreduce91093SixteenRankFP16PvalueMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ZERO}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum; + cocTilingData.commNpuSplit = + cocTilingData.commDataSplit == COMMDATASPLIT_ONE ? cocTilingData.rankSize : COMMNPUSPLIT_ONE; + SetSecondCoreSplitTling(cocTilingData); + } +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/allreduce_tiling_910B.cpp b/comm/lcal/src/tiling/allreduce_tiling_910B.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fee13a288b46f2aa4e250a8c4fe13e28e7d2cb77 --- /dev/null +++ b/comm/lcal/src/tiling/allreduce_tiling_910B.cpp @@ -0,0 +1,663 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + */ +#include <cmath> +#include "tiling_910B.h" +#include "tiling_func.h" +#include "lcal_types.h" +namespace Lcal { + const int32_t ALLREDUCE_SERIAL_MODE_K_SIZE = 8192; + const int64_t ALLREDUCE_SERIAL_MODE_MN_SIZE = 256 * 256 * 12; + + constexpr int32_t ALLREDUCE_FOUR_RANK_FP16_DATASPLIT_DEFAULT = 32; + constexpr int32_t ALLREDUCE_FOUR_RANK_FP16_PVALUE_DEFAULT = 8; + constexpr int32_t ALLREDUCE_FOUR_RANK_FP16_UBMOVENUM_DEFAULT = 30; + constexpr int32_t ALLREDUCE_FOUR_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLREDUCE_FOUR_RANK_INT8_UBMOVENUM_DEFAULT = 40; + constexpr int32_t ALLREDUCE_FOUR_RANK_INT8_PVALUE_DEFAULT = 8; + constexpr int32_t ALLREDUCE_FOUR_RANK_INT8_DATASPLIT_DEFAULT = 32; + constexpr int32_t ALLREDUCE_FOUR_RANK_INT8_M0_DEFAULT = 128; + constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_PVALUE_DEFAULT = 14; + constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 100; + constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_DATASPLIT_DEFAULT = 16; + constexpr int32_t ALLREDUCE_EIGHT_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLREDUCE_EIGHT_RANK_INT8_UBMOVENUM_DEFAULT = 100; + constexpr int32_t ALLREDUCE_EIGHT_RANK_INT8_PVALUE_DEFAULT = 14; + constexpr int32_t ALLREDUCE_EIGHT_RANK_INT8_DATASPLIT_DEFAULT = 8; + constexpr int32_t ALLREDUCE_EIGHT_RANK_INT8_M0_DEFAULT = 128; + constexpr int32_t ALLREDUCE_TWO_RANK_FP16_PVALUE_DEFAULT = 6; + constexpr int32_t ALLREDUCE_TWO_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t ALLREDUCE_TWO_RANK_FP16_SWIZZLCOUNT_DEFAULT = 8; + constexpr int32_t ALLREDUCE_TWO_RANK_FP16_SWIZZLDIRECT_DEFAULT = 0; + constexpr int32_t ALLREDUCE_TWO_RANK_FP16_UBMOVENUM_DEFAULT = 6; + constexpr int32_t ALLREDUCE_TWO_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + + static std::vector<double> g_allreduceUbmovenumCoef = { + -1.72352427e+01, 2.56887672e-03, -8.21819480e+00, 8.70965589e+01, -3.63853858e-01, 1.27789264e+01, + 1.29782183e+02, 1.90250023e-02, -3.48175441e+00, 6.18921914e+03, 3.77072171e+03, -5.86895290e+01, + -8.70740991e-01, -1.40262280e-04, -2.81910331e-08, 3.22795486e-05, -4.84522320e-03, 2.94839177e-01, + 2.97260958e-03, 9.08844709e+01, -5.80426209e-10, 38.183465184603484 + }; + static std::vector<double> g_allreducePvalueCoef = { + -4.23166350e+00, 6.71137487e-04, -1.33434156e+00, 1.12915884e+01, -7.85892737e-02, 2.59059897e+00, + 3.22129881e+01, -5.15776887e-02, 9.15542742e-01, 1.56322201e+03, 3.61977421e+01, -5.49544589e-01, + -2.66903417e-01, -3.68521920e-05, -6.40666333e-09, 6.77406054e-06, -9.92992099e-04, 5.60658043e-02, + 2.69372863e-04, 2.17222337e+01, -1.17749660e-10, 6.100544547671263 + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceFourRankInT8M0Map = { + {256, + {{-1, 3072, -1, 2147483647, -1, 768}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceFourRankInT8DatasplitMap = { + {1, + {{-1, 768, -1, 2147483647, -1, 768}}}, + {2, + {{-1, 1536, -1, 2147483647, 768, 1536}}}, + {4, + {{768, 10010, -1, 2147483647, -1, 768}}}, + {8, + {{1536, 10010, -1, 2147483647, 768, 1536}, {10010, 2147483647, 3072, 2147483647, -1, 768}, + {-1, 19680, 7170, 2147483647, 1536, 7424}, {-1, 7350, 7170, 2147483647, 7424, 2147483647}}}, + {16, + {{10010, 2147483647, 3072, 2147483647, 768, 1536}, {19680, 2147483647, 7170, 2147483647, 1536, 7424}, + {7350, 2147483647, 7170, 2147483647, 7424, 2147483647}}}, + {32, + {{10010, 2147483647, -1, 3072, -1, 1536}, {-1, 2147483647, -1, 7170, 1536, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceFourRankInT8PvalueMap = { + {1, + {{-1, 10010, -1, 2147483647, -1, 768},
{-1, 5148, -1, 2147483647, 768, 1536}}}, + {2, + {{5148, 10010, -1, 2147483647, 768, 1536}, {10010, 2147483647, 3072, 2147483647, -1, 768}, + {-1, 19680, 7170, 2147483647, 1536, 7424}, {-1, 7350, 7170, 2147483647, 7424, 2147483647}}}, + {4, + {{19680, 2147483647, 7170, 2147483647, 1536, 7424}, {7350, 2147483647, 7170, 2147483647, 7424, 2147483647}, + {10010, 2147483647, 3072, 2147483647, 768, 1536}}}, + {6, + {{10010, 36980, -1, 3072, -1, 1536}}}, + {8, + {{36980, 2147483647, -1, 3072, -1, 1536}, {-1, 2147483647, -1, 7170, 1536, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceFourRankInT8UbmovenumMap = { + {20, + {{53340, 2147483647, 7170, 2147483647, -1, 3072}}}, + {30, + {{-1, 53340, -1, 2147483647, -1, 3072}, {53340, 2147483647, -1, 4608, -1, 768}, + {53340, 2147483647, 4608, 7170, -1, 3072}, {-1, 7196, -1, 2147483647, 3072, 2147483647}, + {10010, 15412, -1, 2147483647, 3072, 2147483647}, {15412, 2147483647, 5634, 2147483647, 3072, 2147483647}}}, + {40, + {{53340, 2147483647, -1, 4608, 768, 3072}, {7196, 10010, -1, 2147483647, 3072, 2147483647}, + {15412, 2147483647, -1, 5634, 3072, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceFourRankFP16M0Map = { + {256, + {{-1, 12980, -1, 2147483647, -1, 768}, {12980, 2147483647, -1, 5634, -1, 768}, + {-1, 63360, -1, 4608, 768, 2147483647}, {63360, 2147483647, -1, 4000, 768, 2147483647}, + {63360, 2147483647, 4000, 4608, 1536, 2147483647}, {-1, 5148, 4608, 11264, 768, 2147483647}, + {-1, 2560, 11264, 2147483647, 768, 2147483647}, {5148, 19680, 4608, 2147483647, 13312, 2147483647}}}, + {128, + {{12980, 2147483647, 5634, 2147483647, -1, 768}, {63360, 2147483647, 4000, 4608, 768, 1536}, + {2560, 5148, 11264, 2147483647, 768, 2147483647}, {5148, 19680, 4608, 2147483647, 768, 13312}, + {19680, 2147483647, 4608, 2147483647, 768, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceFourRankFP16UbmovenumMap = { + {20, + {{-1, 1536, -1, 2147483647, -1, 1536}, {-1, 1536, 7170, 2147483647, 1536, 19456}, + {1536, 2147483647, 5634, 2147483647, -1, 19456}, {-1, 2147483647, -1, 2147483647, 19456, 2147483647}}}, + {30.0, + {{-1, 1536, -1, 7170, 1536, 19456}, {1536, 2147483647, -1, 5634, -1, 19456}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceFourRankFP16PvalueMap = { + {2, + {{-1, 5148, -1, 1536, -1, 1536}, {-1, 5148, 1152, 4608, 3072, 5376}, + {5148, 31220, 3072, 2147483647, -1, 1536}, {-1, 3072, -1, 2147483647, 10240, 2147483647}, + {3072, 5148, 7170, 2147483647, 5376, 2147483647}, {13364, 142040, 7170, 2147483647, 5376, 7424}, + {142040, 2147483647, 7170, 2147483647, 5376, 2147483647}}}, + {1, + {{-1, 5148, 1536, 2147483647, -1, 3072}, {-1, 5148, 4608, 2147483647, 3072, 5376}, + {-1, 3072, -1, 2147483647, 5376, 10240}}}, + {8, + {{-1, 5148, -1, 1536, 1536, 3072}, {68160, 2147483647, -1, 3072, -1, 768}, + {16340, 2147483647, -1, 3072, 768, 5376}, {5148, 13364, -1, 2976, 5376, 2147483647}, + {13364, 2147483647, -1, 5634, 5376, 2147483647}, {13364, 2147483647, 5634, 7170, 10240, 2147483647}}}, + {4, + {{-1, 5148, -1, 1152, 3072, 5376}, {5148, 68160, -1, 3072, -1, 768}, + {5148, 16340, -1, 3072, 768, 5376}, {5148, 31220, 3072, 2147483647, 1536, 5376}, + {31220, 2147483647, 3072, 2147483647, -1, 5376}, {3072, 5148, -1, 7170, 5376, 2147483647}, + {5148, 13364, 2976, 2147483647, 5376, 2147483647}, {13364, 2147483647, 5634, 7170, 5376, 10240}, + {13364, 142040, 7170, 2147483647, 7424, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceFourRankFP16DatasplitMap = { + {8, + {{-1, 5148, -1, 3072, -1, 1536}, {-1, 5148, 1152, 4608, 3072, 5376}, + {5148, 68160, 3072,
2147483647, -1, 1536}, {-1, 3072, -1, 2147483647, 10240, 2147483647}, + {3072, 5148, 7170, 2147483647, 5376, 2147483647}, {13364, 142040, 7170, 2147483647, 5376, 7424}, + {142040, 2147483647, 7170, 2147483647, 5376, 2147483647}}}, + {4, + {{-1, 5148, 3072, 2147483647, -1, 1536}, {-1, 5148, 1536, 2147483647, 1536, 3072}, + {-1, 5148, 4608, 2147483647, 3072, 5376}, {-1, 3072, -1, 2147483647, 5376, 10240}}}, + {32, + {{-1, 5148, -1, 1536, 1536, 3072}, {68160, 2147483647, -1, 3072, -1, 768}, + {16340, 2147483647, -1, 3072, 768, 5376}, {5148, 13364, -1, 2976, 5376, 2147483647}, + {13364, 2147483647, -1, 5634, 5376, 2147483647}, {13364, 2147483647, 5634, 7170, 10240, 2147483647}}}, + {16, + {{-1, 5148, -1, 1152, 3072, 5376}, {5148, 68160, -1, 3072, -1, 768}, + {5148, 16340, -1, 3072, 768, 5376}, {5148, 68160, 3072, 2147483647, 1536, 5376}, + {68160, 2147483647, 3072, 2147483647, -1, 5376}, {3072, 5148, -1, 7170, 5376, 2147483647}, + {5148, 13364, 2976, 2147483647, 5376, 2147483647}, {13364, 2147483647, 5634, 7170, 5376, 10240}, + {13364, 142040, 7170, 2147483647, 7424, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceEightRankFP16M0Map = { + {128, + {{-1, 31220, -1, 2147483647, -1, 768}, {31220, 36980, 1280, 2147483647, -1, 768}, + {36980, 2147483647, -1, 2147483647, -1, 768}, {-1, 2147483647, -1, 2147483647, 768, 2147483647}}}, + {256, + {{31220, 36980, -1, 1280, -1, 768}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceEightRankFP16DatasplitMap = { + {1, + {{-1, 3072, -1, 2147483647, -1, 768}, {3072, 26880, 3072, 2147483647, -1, 768}, + {-1, 1536, -1, 2147483647, 768, 1536}, {1536, 26880, 4608, 2147483647, 768, 1536}, + {26880, 53340, 4608, 2147483647, -1, 768}, {26880, 53340, 3072, 2147483647, 768, 1536}, + {53340, 2147483647, 3072, 2147483647, -1, 1536}, {-1, 768, 4608, 2147483647, 1536, 2147483647}, + {768, 5148, 4608, 2147483647, 1536, 7424}}}, + {4, + {{3072, 26880, -1, 3072, -1, 768}, {-1, 22848, 2976, 4608, 1536, 2147483647}, + {23040, 2147483647, 4608, 7170, 1536, 2147483647}}}, + {8, + {{1536, 26880, -1, 4608, 768, 1536}, {26880, 53340, -1, 3072, 768, 1536}, + {53340, 2147483647, -1, 3072, -1, 1536}, {-1, 2147483647, -1, 384, 3072, 10240}, + {3072, 2147483647, 384, 2976, 1536, 2147483647}, {22848, 2147483647, 2976, 4608, 1536, 2147483647}}}, + {2, + {{26880, 53340, -1, 4608, -1, 768}, {-1, 3072, 384, 2976, 1536, 2147483647}, + {768, 5148, 4608, 2147483647, 7424, 2147483647}, {5148, 23040, 4608, 7170, 1536, 2147483647}, + {5148, 2147483647, 7170, 2147483647, 1536, 2147483647}}}, + {16, + {{-1, 2147483647, -1, 384, 1536, 3072}, {-1, 2147483647, -1, 384, 10240, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceEightRankFP16UbmovenumMap = { + {100, + {{-1, 3072, -1, 2147483647, -1, 768}, {3072, 19680, -1, 3072, -1, 768}, + {-1, 3072, -1, 2147483647, 768, 1536}, {3072, 19680, -1, 3072, 768, 1536}, + {-1, 2147483647, 1792, 2976, 1536, 13312}}}, + {30, + {{3072, 19680, 3072, 2147483647, -1, 768}, {19680, 2147483647, -1, 3072, -1, 1536}, + {-1, 2147483647, -1, 1792, 1536, 13312}, {-1, 768, 2976, 2147483647, 5376, 13312}, + {-1, 768, -1, 2147483647, 13312, 2147483647}, {26880, 2147483647, -1, 3072, 13312, 2147483647}}}, + {20, + {{3072, 19680, 3072, 2147483647, 768, 1536}, {19680, 2147483647, 3072, 2147483647, -1, 1536}, + {-1, 2147483647, 2976, 2147483647, 1536, 5376}, {768, 2147483647, 2976, 2147483647, 5376, 13312}, + {768, 26880, -1, 2147483647, 13312, 2147483647}, {26880, 2147483647, 3072, 2147483647, 13312, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceEightRankFP16PvalueMap = { + {4,
+ {{-1, 768, -1, 2147483647, -1, 768}, {12980, 26880, -1, 3072, -1, 768}, + {-1, 15412, 2976, 4608, 1536, 2147483647}, {23040, 2147483647, 4608, 7170, 1536, 2147483647}}}, + {1, + {{768, 12980, -1, 2147483647, -1, 768}, {12980, 26880, 3072, 2147483647, -1, 768}, + {-1, 1536, -1, 2147483647, 768, 1536}, {1536, 26880, 4608, 2147483647, 768, 1536}, + {26880, 53340, 4608, 2147483647, -1, 768}, {26880, 53340, 3072, 2147483647, 768, 1536}, + {53340, 2147483647, 3072, 2147483647, -1, 1536}, {-1, 768, 4608, 2147483647, 1536, 2147483647}, + {768, 5148, 4608, 2147483647, 1536, 7424}}}, + {8, + {{1536, 26880, -1, 4608, 768, 1536}, {26880, 53340, -1, 3072, 768, 1536}, + {53340, 2147483647, -1, 3072, -1, 1536}, {-1, 2147483647, -1, 384, 3072, 10240}, + {3072, 2147483647, 384, 2976, 1536, 2147483647}, {15412, 2147483647, 2976, 4608, 1536, 2147483647}}}, + {2, + {{26880, 53340, -1, 4608, -1, 768}, {-1, 3072, 384, 2976, 1536, 2147483647}, + {768, 5148, 4608, 2147483647, 7424, 2147483647}, {5148, 23040, 4608, 7170, 1536, 2147483647}, + {5148, 2147483647, 7170, 2147483647, 1536, 2147483647}}}, + {14, + {{-1, 2147483647, -1, 384, 1536, 3072}, {-1, 2147483647, -1, 384, 10240, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceEightRankInT8M0Map = { + {128, + {{-1, 31220, -1, 2147483647, -1, 768}, {31220, 36980, 1280, 2147483647, -1, 768}, + {-1, 36980, -1, 2147483647, 768, 3072}, {36980, 2147483647, -1, 2147483647, -1, 3072}, + {-1, 2147483647, -1, 2147483647, 3072, 13312}, {-1, 1536, -1, 384, 13312, 2147483647}, + {5274, 2147483647, -1, 384, 13312, 2147483647}, {-1, 2147483647, 384, 2147483647, 13312, 2147483647}}}, + {256, + {{31220, 36980, -1, 1280, -1, 768}, {1536, 5274, -1, 384, 13312, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceEightRankInT8DatasplitMap = { + {1, + {{-1, 3072, -1, 2147483647, -1, 768}, {3072, 5148, 4608, 2147483647, -1, 768}, + {-1, 1536, -1, 2147483647, 768, 1536}, {3072, 5148, -1, 2147483647, 768, 1536}, + {5148, 2147483647, 5634, 2147483647, -1, 1536}, {-1, 2147483647, 11264, 2147483647, 1536, 5376}}}, + {4, + {{3072, 5148, -1, 4608, -1, 768}, {5148, 31220, -1, 3072, -1, 768}, + {5148, 2147483647, 3072, 4608, -1, 1536}, {-1, 2147483647, 5634, 11264, 1536, 5376}, + {34560, 2147483647, 5634, 2147483647, 5376, 7424}, {7196, 2147483647, 5634, 2147483647, 7424, 13312}}}, + {2, + {{1536, 3072, -1, 2147483647, 768, 1536}, {5148, 2147483647, 4608, 5634, -1, 1536}, + {-1, 34560, 5634, 2147483647, 5376, 7424}, {-1, 3072, -1, 2147483647, 7424, 2147483647}, + {3072, 7196, 5634, 2147483647, 7424, 2147483647}}}, + {8, + {{5148, 31220, -1, 3072, 768, 1536}, {31220, 2147483647, -1, 3072, -1, 1536}, + {-1, 2147483647, -1, 5634, 1536, 7424}, {3072, 7196, -1, 5634, 7424, 2147483647}, + {7196, 2147483647, -1, 5634, 7424, 13312}, {7196, 2147483647, -1, 2147483647, 13312, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceEightRankInT8PvalueMap = { + {14, + {{-1, 1536, -1, 2147483647, -1, 768}, {10010, 12980, -1, 2147483647, -1, 768}, + {-1, 7350, -1, 1536, 768, 1536}, {-1, 768, 1536, 2147483647, 768, 1536}, + {-1, 768, -1, 1536, 7424, 2147483647}}}, + {1, + {{1536, 10010, -1, 2147483647, -1, 768}, {12980, 2147483647, 5634, 2147483647, -1, 1536}, + {-1, 2147483647, 11264, 2147483647, 1536, 5376}}}, + {10, + {{7350, 12980, -1, 1536, 768, 1536}}}, + {2, + {{768, 12980, 1536, 2147483647, 768, 1536}, {12980, 2147483647, 4608, 5634, -1, 1536}, + {-1, 34560, 5634, 2147483647, 5376, 7424}, {-1, 768, 1536, 2147483647, 7424, 2147483647}, + {768, 3072, -1, 2147483647, 7424, 19456}}}, + {4, + {{12980,
36980, -1, 3072, -1, 768}, {12980, 2147483647, 3072, 4608, -1, 1536}, + {-1, 2147483647, 5634, 11264, 1536, 5376}, {34560, 2147483647, 5634, 2147483647, 5376, 7424}, + {768, 3072, -1, 2147483647, 19456, 2147483647}, {3072, 2147483647, 5634, 2147483647, 7424, 2147483647}}}, + {8, + {{12980, 36980, -1, 3072, 768, 1536}, {36980, 2147483647, -1, 3072, -1, 1536}, + {-1, 2147483647, -1, 5634, 1536, 7424}, {3072, 2147483647, -1, 5634, 7424, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceEightRankInT8UbmovenumMap = { + {80, + {{-1, 7350, -1, 3072, -1, 768}}}, + {100, + {{-1, 7350, 3072, 2147483647, -1, 768}, {-1, 7350, -1, 7170, 768, 3072}, + {-1, 7350, -1, 4608, 3072, 5376}, {7350, 2147483647, -1, 3072, -1, 5376}, + {-1, 768, -1, 2147483647, 5376, 10240}, {768, 1536, -1, 4608, 5376, 2147483647}, + {3072, 2147483647, -1, 2976, 5376, 2147483647}}}, + {30, + {{-1, 7350, 7170, 2147483647, 768, 3072}, {-1, 3072, 7170, 2147483647, 3072, 5376}, + {7350, 23040, 3072, 2147483647, 768, 5376}, {23040, 2147483647, 3072, 2147483647, -1, 5376}, + {-1, 768, -1, 2147483647, 13312, 2147483647}, {768, 1536, 4608, 2147483647, 5376, 2147483647}, + {3072, 120832, 2976, 2147483647, 5376, 13312}, {3072, 2147483647, 2976, 4608, 13312, 2147483647}}}, + {50, + {{-1, 7350, 4608, 7170, 3072, 5376}, {-1, 768, -1, 2147483647, 10240, 13312}}}, + {20, + {{3072, 7350, 7170, 2147483647, 3072, 5376}, {1536, 3072, 7170, 2147483647, 5376, 2147483647}, + {120832, 2147483647, 2976, 2147483647, 5376, 13312}, + {3072, 2147483647, 4608, 2147483647, 13312, 2147483647}}}, + {40, + {{7350, 23040, 3072, 2147483647, -1, 768}, {1536, 3072, -1, 7170, 5376, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceTwoRankFP16CommdatasplitMap = { + {16, + {{-1, 6656, -1, 2147483647, -1, 1536}, {6656, 2147483647, -1, 19456, -1, 1536}, + {7680, 2147483647, 19456, 2147483647, -1, 1536}, {-1, 2147483647, -1, 2147483647, 1536, 2147483647}}}, + {4, + {{6656, 7680, 19456, 2147483647, -1, 1536}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceTwoRankFP16UbmovenumMap = { + {2, + {{-1, 1536, -1, 3072, -1, 1536}, {-1, 1536, 15360, 2147483647, -1, 1536}, + {1536, 6656, -1, 2147483647, -1, 1536}, {6656, 2147483647, -1, 19456, -1, 1536}, + {7680, 2147483647, 19456, 2147483647, -1, 1536}, {-1, 2147483647, -1, 2147483647, 1536, 2147483647}}}, + {3, + {{-1, 1536, 3072, 15360, -1, 1536}}}, + {6, + {{6656, 7680, 19456, 2147483647, -1, 1536}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceTwoRankFP16SwizzldirectMap = { + {1, + {{-1, 6656, -1, 2147483647, -1, 7680}, {6656, 35840, -1, 13312, -1, 7680}, + {35840, 2147483647, -1, 2147483647, -1, 7680}, {-1, 25600, -1, 2147483647, 7680, 2147483647}, + {25600, 2147483647, -1, 2147483647, 7680, 9216}, {25600, 2147483647, -1, 15360, 9216, 11264}, + {25600, 2147483647, -1, 2147483647, 11264, 2147483647}}}, + {0, + {{6656, 35840, 13312, 2147483647, -1, 7680}, {25600, 2147483647, 15360, 2147483647, 9216, 11264}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceTwoRankFP16SwizzlcountMap = { + {4, + {{-1, 5632, -1, 2147483647, -1, 1536}, {5632, 7680, -1, 17408, -1, 1536}, + {7680, 9216, -1, 11264, -1, 1536}, {9216, 2147483647, -1, 19456, -1, 1536}, + {19456, 2147483647, 19456, 2147483647, -1, 1536}, {-1, 2147483647, -1, 11264, 1536, 13312}, + {-1, 2147483647, 11264, 15360, 4608, 13312}, {-1, 2147483647, 17408, 2147483647, 1536, 13312}, + {-1, 9216, -1, 15360, 13312, 2147483647}, {-1, 9216, 17408, 2147483647, 13312, 2147483647}, + {9216, 25600, -1, 11264, 13312, 2147483647}, {25600, 35840, -1, 13312, 13312, 2147483647}, + {35840, 2147483647, -1, 2147483647,
13312, 2147483647}}}, + {8, + {{5632, 7680, 17408, 19456, -1, 1536}, {5632, 19456, 19456, 2147483647, -1, 1536}, + {-1, 2147483647, 11264, 17408, 1536, 4608}, {-1, 2147483647, 15360, 17408, 4608, 13312}, + {-1, 9216, 15360, 17408, 13312, 2147483647}, {9216, 25600, 11264, 2147483647, 13312, 2147483647}}}, + {16, + {{7680, 9216, 11264, 19456, -1, 1536}, {25600, 35840, 13312, 2147483647, 13312, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceTwoRankFP16M0Map = { + {128, + {{-1, 6656, -1, 2147483647, -1, 7680}, {6656, 2147483647, -1, 13312, -1, 7680}, + {-1, 1536, -1, 7680, 7680, 11264}, {-1, 1536, -1, 6656, 11264, 2147483647}, + {1536, 2147483647, -1, 2147483647, 7680, 2147483647}}}, + {256, + {{6656, 2147483647, 13312, 2147483647, -1, 7680}, {-1, 1536, 7680, 2147483647, 7680, 11264}, + {-1, 1536, 6656, 2147483647, 11264, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_allreduceTwoRankFP16PvalueMap = { + {2, + {{-1, 2560, -1, 3584, -1, 1536}, {4608, 7680, -1, 7680, -1, 1536}, + {7680, 9216, -1, 2147483647, -1, 1536}, {-1, 15360, 4608, 13312, 1536, 2560}, + {-1, 7680, -1, 13312, 2560, 3584}, {6656, 15360, 13312, 2147483647, 2560, 3584}, + {15360, 25600, 4608, 15360, -1, 2560}, {25600, 2147483647, 19456, 2147483647, -1, 2560}, + {15360, 25600, 11264, 2147483647, 2560, 3584}, {-1, 15360, 9216, 17408, 3584, 9216}, + {-1, 6656, 13312, 2147483647, 11264, 2147483647}, {15360, 35840, 13312, 2147483647, 11264, 2147483647}}}, + {1, + {{-1, 2560, 3584, 2147483647, -1, 1536}, {2560, 4608, -1, 2147483647, -1, 1536}, + {4608, 7680, 7680, 2147483647, -1, 1536}, {9216, 15360, -1, 2147483647, -1, 1536}, + {-1, 15360, 13312, 2147483647, 1536, 2560}, {-1, 6656, 13312, 2147483647, 2560, 3584}, + {15360, 25600, 15360, 2147483647, -1, 2560}, {-1, 15360, 17408, 2147483647, 3584, 9216}, + {-1, 6656, 13312, 2147483647, 9216, 11264}, {9216, 15360, 13312, 2147483647, 9216, 2147483647}}}, + {3, + {{-1, 15360, -1, 4608, 1536, 2560}, {7680, 15360, -1, 13312, 2560, 3584}, + {15360, 2147483647, 2560, 4608, -1, 1536}, {25600, 2147483647, 4608, 19456, -1, 2560}, + {15360, 25600, 4608, 11264, 2560, 3584}, {-1, 15360, 1536, 9216, 3584, 9216}, + {-1, 6656, -1, 13312, 9216, 2147483647}, {15360, 25600, 11264, 2147483647, 3584, 7680}}}, + {4, + {{15360, 30720, -1, 1536, -1, 3584}, {15360, 2147483647, 1536, 2560, -1, 2560}, + {15360, 2147483647, 2560, 4608, 1536, 3584}, {25600, 2147483647, 4608, 2147483647, 2560, 3584}, + {-1, 15360, -1, 1536, 3584, 9216}, {6656, 9216, -1, 2147483647, 9216, 2147483647}, + {9216, 15360, -1, 13312, 9216, 2147483647}, {15360, 25600, -1, 11264, 3584, 7680}, + {25600, 35840, -1, 2147483647, 3584, 6656}, {15360, 35840, 5632, 2147483647, 7680, 11264}}}, + {6, + {{30720, 2147483647, -1, 1536, -1, 3584}, {15360, 2147483647, 1536, 2560, 2560, 3584}, + {25600, 35840, -1, 2147483647, 6656, 7680}, {15360, 35840, -1, 5632, 7680, 11264}, + {15360, 35840, -1, 13312, 11264, 2147483647}, {35840, 2147483647, -1, 2147483647, 3584, 2147483647}}} + }; + + int32_t AllReduceUbMoveNum(int m, int k, int n) + { + double commPredict = 1.0 * (m / ONE_K) * (n / ONE_K) * (SECOND_TO_MS / ONE_K) / 40; + double cubePredict = DOUBLE * m * k / B1_FLOP_PER_MS * n; + double mknGB = (m / ONE_K) * (k / ONE_K) * (n / ONE_K); + double mteTimePredict1 = GetMTETime(mknGB, DEFAULT_ROW, DEFAULT_COL); + double mteTimePredict2 = GetMTETime(mknGB, DEFAULT_COL, DEFAULT_ROW); + double mteTimePredict = std::min(mteTimePredict1, mteTimePredict2); + double matmulPredict = std::max(cubePredict, mteTimePredict); + double c0 = matmulPredict / commPredict;
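+ // c0 is the ratio of the predicted matmul time to the predicted communication + // time; the c1..c14 terms built below mix power-law combinations of (m, k, n) + // with these time estimates. Their dot product with the 22 fitted coefficients + // in g_allreduceUbmovenumCoef (the last entry is a constant bias) gives a + // regression estimate that is then scaled by HALF_KBYTE and clamped to + // [MIN_UB_MOVE_NUM, MAX_UB_NUM].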
+ double c1 = 1.0 * m * n / k; + double c2 = sqrt(c1); + double c3 = sqrt(1.0 * m * n) / k; + double c4 = c3 * c3; + double c5 = matmulPredict; + double c6 = commPredict; + double c7 = 1.0 * n / m; + double c8 = 1.0 * m * n / sqrt(k); + double c9 = 1.0 * m * n * sqrt(k); + double c10 = sqrt(1.0 * m * n) * k; + double c11 = sqrt(1.0 * m * n * k); + double c12 = sqrt(1.0 * m * n); + double c13 = 1.0 * k * k / sqrt(1.0 * m * n); + double c14 = 1.0 * k * k * sqrt(1.0 * m * n); + double ubMoveNumDouble = 0; + std::vector<double> featsUpdate = { c0, c1, c2, c3, c4, c5, c6, c7, 1.0 / c0, 1.0 / c1, 1.0 / c2, 1.0 / c3, + 1.0 / c4, c8, c9, c10, c11, c12, c13, 1.0 / c13, c14, 1 }; + for (uint32_t i = 0; i < featsUpdate.size(); i++) { + ubMoveNumDouble += featsUpdate[i] * g_allreduceUbmovenumCoef[i]; + } + + return std::min(std::max(static_cast<int32_t>(ubMoveNumDouble) * HALF_KBYTE, MIN_UB_MOVE_NUM), MAX_UB_NUM); + } + + int32_t AllReducePValue(int m, int k, int n) + { + double commPredict = 1.0 * (m / ONE_K) * (n / ONE_K) * (SECOND_TO_MS / ONE_K) / 40; + double cubePredict = DOUBLE * m * k / B1_FLOP_PER_MS * n; + double mknGB = (m / ONE_K) * (k / ONE_K) * (n / ONE_K); + double mteTimePredict1 = GetMTETime(mknGB, DEFAULT_ROW, DEFAULT_COL); + double mteTimePredict2 = GetMTETime(mknGB, DEFAULT_COL, DEFAULT_ROW); + double mteTimePredict = std::min(mteTimePredict1, mteTimePredict2); + double matmulPredict = std::max(cubePredict, mteTimePredict); + double c0 = matmulPredict / commPredict; + double c1 = 1.0 * m * n / k; + double c2 = sqrt(c1); + double c3 = sqrt(1.0 * m * n) / k; + double c4 = c3 * c3; + double c5 = matmulPredict; + double c6 = commPredict; + double c7 = 1.0 * n / m; + double c8 = 1.0 * m * n / sqrt(k); + double c9 = 1.0 * m * n * sqrt(k); + double c10 = sqrt(1.0 * m * n) * k; + double c11 = sqrt(1.0 * m * n * k); + double c12 = sqrt(1.0 * m * n); + double c13 = 1.0 * k * k / sqrt(1.0 * m * n); + double c14 = 1.0 * k * k * sqrt(1.0 * m * n); + double pValueDouble = 0; + std::vector<double> featsUpdate = { c0, c1, c2, c3, c4, c5, c6, c7, 1.0 / c0, 1.0 / c1, 1.0 / c2, 1.0 / c3, + 1.0 / c4, c8, c9, c10, c11, c12, c13, 1.0 / c13, c14, 1 }; + for (uint32_t i = 0; i < featsUpdate.size(); i++) { + pValueDouble += featsUpdate[i] * g_allreducePvalueCoef[i]; + } + + return std::min(std::max(static_cast<int32_t>(pValueDouble), 1), MAX_P_VALUE); + } + + void AllReduceSetWithSerialMode(CoCTilingData &cocTilingData) + { + int32_t m = cocTilingData.m; + int32_t k = cocTilingData.k; + int32_t n = cocTilingData.n; + + int64_t batchSize = cocTilingData.batchSize; + int64_t commSize = static_cast<int64_t>(batchSize) * m * n; + if (commSize <= ALLREDUCE_SERIAL_MODE_MN_SIZE && k <= ALLREDUCE_SERIAL_MODE_K_SIZE) { + cocTilingData.withSerialMode = 1; + cocTilingData.ubMoveNum = MAX_UB_NUM; + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum; + } else { + cocTilingData.withSerialMode = 0; + } + } + + void AllReduceGetDefaultTiling(CoCTilingData &cocTilingData) + { + int64_t batchSize = cocTilingData.batchSize; + int32_t m = cocTilingData.m; + int32_t k = cocTilingData.k; + int32_t n = cocTilingData.n; + + cocTilingData.swizzlDirect = SWIZZLE_DIRECT_ONE; + cocTilingData.ubMoveNum = AllReduceUbMoveNum(m, k, n); + cocTilingData.pValue = AllReducePValue(m, k, n); + + int64_t cubeSize = static_cast<int64_t>(batchSize) * m * k * n; + int64_t commSize = static_cast<int64_t>(batchSize) * m * n; + constexpr int32_t bufferUnit = 1024; + if ((cubeSize <= MATMUL_BASE_100US && + commSize < (cocTilingData.bufferSize * bufferUnit * bufferUnit) / INPUT_DTYPE / MAX_BLOCK_COUNT) || +
commSize <= ALLREDUCE_BASE_100US) { + cocTilingData.withSerialMode = 1; + cocTilingData.ubMoveNum = MAX_UB_NUM; + } else { + cocTilingData.withSerialMode = 0; + } + cocTilingData.commDirect = COMM_DATA_DIRECT; + cocTilingData.commNpuSplit = cocTilingData.rankSize; + cocTilingData.commDataSplit = COMMDATASPLIT_ONE; + cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 * cocTilingData.pValue * cocTilingData.blockDim; + cocTilingData.lenPerLoop = cocTilingData.lenPerLoop / cocTilingData.rankSize; + cocTilingData.lenPerLoop = RoundNum(cocTilingData.lenPerLoop, HALF_KBYTE); + SetSecondCoreSplitTling(cocTilingData); + } + + void AllReduceFourRankInt8GetDefaultTiling(CoCTilingData &cocTilingData) + { + std::map tilingParamMap = { + {&cocTilingData.m0, + {ALLREDUCE_FOUR_RANK_INT8_M0_DEFAULT, + g_allreduceFourRankInT8M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLREDUCE_FOUR_RANK_INT8_UBMOVENUM_DEFAULT, + g_allreduceFourRankInT8UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLREDUCE_FOUR_RANK_INT8_PVALUE_DEFAULT, + g_allreduceFourRankInT8PvalueMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}}, + {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.lenPerLoop = ALLREDUCE_LENPERLOOP_DEFAULT / RANKSIZE_EIGHT * cocTilingData.rankSize; + cocTilingData.lenPerLoop = RoundNum(cocTilingData.lenPerLoop, HALF_KBYTE); + cocTilingData.ubMoveNum = cocTilingData.lenPerLoop; + + AllReduceSetWithSerialMode(cocTilingData); + SetSecondCoreSplitTling(cocTilingData); + } + + void AllReduceFourRankFP16GetDefaultTiling(CoCTilingData &cocTilingData) + { + std::map tilingParamMap = { + {&cocTilingData.m0, + {ALLREDUCE_FOUR_RANK_FP16_M0_DEFAULT, + g_allreduceFourRankFP16M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLREDUCE_FOUR_RANK_FP16_UBMOVENUM_DEFAULT, + g_allreduceFourRankFP16UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLREDUCE_FOUR_RANK_FP16_PVALUE_DEFAULT, + g_allreduceFourRankFP16PvalueMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}}, + {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.lenPerLoop = ALLREDUCE_LENPERLOOP_DEFAULT / RANKSIZE_EIGHT * cocTilingData.rankSize; + cocTilingData.lenPerLoop = RoundNum(cocTilingData.lenPerLoop, HALF_KBYTE); + cocTilingData.ubMoveNum = cocTilingData.lenPerLoop; + + AllReduceSetWithSerialMode(cocTilingData); + SetSecondCoreSplitTling(cocTilingData); + } + + void AllReduceEightRankFP16GetDefaultTiling(CoCTilingData &cocTilingData) + { + int dataSplit = 0; + std::map tilingParamMap = { + {&cocTilingData.m0, + {ALLREDUCE_EIGHT_RANK_FP16_M0_DEFAULT, + g_allreduceEightRankFP16M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLREDUCE_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT, + g_allreduceEightRankFP16UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLREDUCE_EIGHT_RANK_FP16_PVALUE_DEFAULT, + g_allreduceEightRankFP16PvalueMap}}, + {&dataSplit, + {ALLREDUCE_EIGHT_RANK_FP16_DATASPLIT_DEFAULT, + g_allreduceEightRankFP16DatasplitMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + 
{&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}}, + {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 * cocTilingData.pValue * cocTilingData.blockDim; + cocTilingData.lenPerLoop = cocTilingData.lenPerLoop / cocTilingData.rankSize / cocTilingData.commDataSplit; + cocTilingData.lenPerLoop = cocTilingData.lenPerLoop / dataSplit; + cocTilingData.lenPerLoop = RoundNum(cocTilingData.lenPerLoop, HALF_KBYTE); + cocTilingData.lenPerLoop = std::min(cocTilingData.lenPerLoop, TREE_LEN_PER_LOOP); + + AllReduceSetWithSerialMode(cocTilingData); + SetSecondCoreSplitTling(cocTilingData); + } + + void AllReduceEightRankINT8GetDefaultTiling(CoCTilingData &cocTilingData) + { + int dataSplit = 0; + std::map tilingParamMap = { + {&cocTilingData.m0, + {ALLREDUCE_EIGHT_RANK_INT8_M0_DEFAULT, + g_allreduceEightRankInT8M0Map}}, + {&cocTilingData.ubMoveNum, + {ALLREDUCE_EIGHT_RANK_INT8_UBMOVENUM_DEFAULT, + g_allreduceEightRankInT8UbmovenumMap}}, + {&cocTilingData.pValue, + {ALLREDUCE_EIGHT_RANK_INT8_PVALUE_DEFAULT, + g_allreduceEightRankInT8PvalueMap}}, + {&dataSplit, + {ALLREDUCE_EIGHT_RANK_INT8_DATASPLIT_DEFAULT, + g_allreduceEightRankInT8DatasplitMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}}, + {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 * cocTilingData.pValue * cocTilingData.blockDim; + cocTilingData.lenPerLoop = cocTilingData.lenPerLoop / cocTilingData.rankSize / cocTilingData.commDataSplit; + cocTilingData.lenPerLoop = cocTilingData.lenPerLoop / dataSplit; + cocTilingData.lenPerLoop = RoundNum(cocTilingData.lenPerLoop, HALF_KBYTE); + cocTilingData.lenPerLoop = std::min(cocTilingData.lenPerLoop, TREE_LEN_PER_LOOP); + + AllReduceSetWithSerialMode(cocTilingData); + SetSecondCoreSplitTling(cocTilingData); + } + + void AllReduceTwoRankFP16Tiling(CoCTilingData &cocTilingData) + { + std::map tilingParamMap = { + {&cocTilingData.commDataSplit, + {ALLREDUCE_TWO_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_allreduceTwoRankFP16CommdatasplitMap}}, + {&cocTilingData.ubMoveNum, + {ALLREDUCE_TWO_RANK_FP16_UBMOVENUM_DEFAULT, + g_allreduceTwoRankFP16UbmovenumMap}}, + {&cocTilingData.swizzlDirect, + {ALLREDUCE_TWO_RANK_FP16_SWIZZLDIRECT_DEFAULT, + g_allreduceTwoRankFP16SwizzldirectMap}}, + {&cocTilingData.swizzlCount, + {ALLREDUCE_TWO_RANK_FP16_SWIZZLCOUNT_DEFAULT, + g_allreduceTwoRankFP16SwizzlcountMap}}, + {&cocTilingData.m0, + {ALLREDUCE_TWO_RANK_FP16_M0_DEFAULT, + g_allreduceTwoRankFP16M0Map}}, + {&cocTilingData.pValue, + {ALLREDUCE_TWO_RANK_FP16_PVALUE_DEFAULT, + g_allreduceTwoRankFP16PvalueMap}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}} + }; + SetTilingParam(cocTilingData, tilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum; + AllReduceSetWithSerialMode(cocTilingData); + } +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/alltoall_allgather_hidden_tiling.cpp b/comm/lcal/src/tiling/alltoall_allgather_hidden_tiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afac5b045ab47913635ee92660bcffc9e8d6c329 --- /dev/null +++ b/comm/lcal/src/tiling/alltoall_allgather_hidden_tiling.cpp @@ 
-0,0 +1,100 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include <string> +#include "tiling.h" +#include "tiling_910B.h" +#include "tiling_91093.h" +#include "tiling_func.h" +#include "lcoc_func.h" + +namespace Lcal { +void CoCAllToAllAllGatherMatmulHiddenTilingFunc::GetDefaultTiling(const TaskParam &tilingInfo) +{ + CoCTilingFunc::GetDefaultTiling(tilingInfo); + auto k = tilingInfo.cocParamDesc.mmInfo.k; + auto m = tilingInfo.cocParamDesc.mmInfo.m; + auto maxOutputSize = tilingInfo.cocParamDesc.moeInfo.maxOutputSize; + + auto blockCount = MAX_BLOCK_COUNT; + int32_t maxPvalue = (k + 255) / 256; + + cocTilingData.m0 = DEFAULT_ROW; + cocTilingData.n0 = DEFAULT_COL; + cocTilingData.k0 = DEFAULT_COL; + int32_t bufferSize = tilingInfo.bufferSize * 1024 * 1024; + int32_t maxPeerMemPerRank = bufferSize / INPUT_DTYPE / blockCount; + constexpr int32_t Seven = 7; + cocTilingData.pValue = Seven; + if (cocTilingData.pValue > maxPvalue) { + cocTilingData.pValue = maxPvalue; + } + + if (m < DEFAULT_ROW) { + cocTilingData.pValue = (k + cocTilingData.k0 - 1) / cocTilingData.k0; + } + + if (cocTilingData.pValue * cocTilingData.k0 * maxOutputSize > maxPeerMemPerRank) { + cocTilingData.pValue = maxPeerMemPerRank / maxOutputSize / cocTilingData.k0; + } + cocTilingData.ubMoveNum = AllTOAll_HIDDEN_UBMOVENUM; + constexpr int32_t two = 2; + int32_t maxUbPingPongSize = cocTilingData.ubMoveNum / two; + if (cocTilingData.pValue * cocTilingData.k0 > maxUbPingPongSize) { + cocTilingData.pValue = maxUbPingPongSize / cocTilingData.k0; + } + + return; +} + +bool CoCAllToAllAllGatherMatmulHiddenTilingFunc::CheckTiling(const TaskParam &tilingInfo) +{ + int32_t rankSize = cocTilingData.rankSize; + int32_t ep = tilingInfo.cocParamDesc.moeInfo.EP; + int32_t tp = tilingInfo.cocParamDesc.moeInfo.TP; + int32_t expertPerRank = tilingInfo.cocParamDesc.moeInfo.local_expert_nums; + int32_t k = tilingInfo.cocParamDesc.mmInfo.k; + auto maxOutputSize = tilingInfo.cocParamDesc.moeInfo.maxOutputSize; + + auto blockCount = MAX_BLOCK_COUNT; + int32_t maxPeerMemPerRank = (tilingInfo.bufferSize * 1024 * 1024) / INPUT_DTYPE / blockCount; + if ((cocTilingData.pValue - 1) * cocTilingData.k0 > k) { + return false; + } + if (cocTilingData.pValue * cocTilingData.k0 * maxOutputSize > maxPeerMemPerRank) { + std::string str = "The k value is too large and is currently not supported. " + "pValue: " + std::to_string(cocTilingData.pValue) + ", k0: " + + std::to_string(cocTilingData.k0) + ", maxPeerMemPerRank: " + std::to_string(maxPeerMemPerRank); + PrintErrorLog(tilingInfo.lcalType, str); + return false; + } + constexpr int32_t Two = 2; + int32_t maxUbPingPongSize = cocTilingData.ubMoveNum / Two; + if (cocTilingData.pValue * cocTilingData.k0 > maxUbPingPongSize) { + std::string str = "The k value is too large and is currently not supported. 
" + "pValue: " + std::to_string(cocTilingData.pValue) + ", k0: " + + std::to_string(cocTilingData.k0) + "maxUbPingPongSize: " + std::to_string(maxUbPingPongSize); + PrintErrorLog(tilingInfo.lcalType, str); + return false; + } + + if (ep * tp != rankSize) { + std::string str = "The ep * tp != rankSize. " + "rankSize: " + std::to_string(rankSize) + ", ep: " + std::to_string(ep) + + " , tp: " + std::to_string(tp); + PrintErrorLog(tilingInfo.lcalType, str); + return false; + } + + std::vector> paramCheckList = { + {"expertPerrank", expertPerRank, 1, 20} + }; + return CheckParamScopeList(paramCheckList); +} +} diff --git a/comm/lcal/src/tiling/alltoall_allgather_tiling.cpp b/comm/lcal/src/tiling/alltoall_allgather_tiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e788eb2ca91ceed93090c9d8a5e163ca907ebf4f --- /dev/null +++ b/comm/lcal/src/tiling/alltoall_allgather_tiling.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "tiling.h" +#include "tiling_910B.h" +#include "tiling_91093.h" +#include "tiling_func.h" +#include "lcoc_func.h" +# +namespace Lcal { +void CoCAllToAllAllGatherMatmulTilingFunc::GetDefaultTiling(const TaskParam &tilingInfo) +{ + CoCTilingFunc::GetDefaultTiling(tilingInfo); + cocTilingData.m0 = DEFAULT_ROW; + cocTilingData.n0 = DEFAULT_COL; + cocTilingData.k0 = DEFAULT_COL; + constexpr int32_t pValue = 1; + cocTilingData.pValue = pValue; + constexpr int32_t ubMove = 28672; + cocTilingData.ubMoveNum = ubMove; + return; +} + +bool CheckPValue(const TaskParam &tilingInfo, const CoCTilingData &data) +{ + auto blockCount = MAX_BLOCK_COUNT; + int32_t bufferSize = tilingInfo.bufferSize * 1024 * 1024; + int32_t maxPeerMemPerRank = bufferSize / INPUT_DTYPE / data.rankSize / blockCount; + if (data.pValue * data.m0 * data.k0 * data.kLoop >= maxPeerMemPerRank) { + std::string str = "The k value is too large and is currently not supported. " + "pValue: " + std::to_string(data.pValue) + ", m0: " + std::to_string(data.m0) + + ", k0: " + std::to_string(data.k0) + ", kLoop: " + std::to_string(data.kLoop) + + "maxPeerMemPerRank: " + std::to_string(maxPeerMemPerRank); + PrintErrorLog(tilingInfo.lcalType, str); + return false; + } + return true; +} + +bool CoCAllToAllAllGatherMatmulTilingFunc::CheckTiling(const TaskParam &tilingInfo) +{ + if (!CoCTilingFunc::CheckTiling(tilingInfo)) { + return false; + } + if (!CheckPValue(tilingInfo, cocTilingData)) { + return false; + } + + int32_t rankSize = cocTilingData.rankSize; + int32_t ep = tilingInfo.cocParamDesc.moeInfo.EP; + int32_t tp = tilingInfo.cocParamDesc.moeInfo.TP; + int32_t expertPerRank = tilingInfo.cocParamDesc.moeInfo.local_expert_nums; + + if (ep * tp != rankSize) { + std::string str = "The ep * tp != rankSize. 
" + "rankSize: " + std::to_string(rankSize) + ", ep: " + std::to_string(ep) + + " , tp: " + std::to_string(tp); + PrintErrorLog(tilingInfo.lcalType, str); + return false; + } + + std::vector> paramCheckList = { + {"expertPerrank", expertPerRank, 1, 20} + }; + return CheckParamScopeList(paramCheckList); +} +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/reducescatter_alltoall_hidden_tiling.cpp b/comm/lcal/src/tiling/reducescatter_alltoall_hidden_tiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a1e9ffb67332432c33f7a7604169465d54654bf7 --- /dev/null +++ b/comm/lcal/src/tiling/reducescatter_alltoall_hidden_tiling.cpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "tiling.h" +#include "tiling_910B.h" +#include "tiling_91093.h" +#include "tiling_func.h" +#include "lcoc_func.h" +# +namespace Lcal { +void CoCMatmulReduceScatterAllToAllHiddenTilingFunc::GetDefaultTiling(const TaskParam &tilingInfo) +{ + CoCTilingFunc::GetDefaultTiling(tilingInfo); + auto n = tilingInfo.cocParamDesc.mmInfo.n; + auto maxOutputSize = tilingInfo.cocParamDesc.moeInfo.maxOutputSize; + int32_t maxPvalue = (n + 255) / 256; + + cocTilingData.m0 = DEFAULT_ROW; + cocTilingData.n0 = DEFAULT_COL; + cocTilingData.k0 = DEFAULT_COL; + int32_t m = tilingInfo.cocParamDesc.mmInfo.m; + auto blockCount = MAX_BLOCK_COUNT; + int32_t bufferSize = tilingInfo.bufferSize * 1024 * 1024; + int32_t maxPeerMemPerRank = bufferSize / INPUT_DTYPE / blockCount; + constexpr int32_t Four = 4; + cocTilingData.pValue = Four; + if (cocTilingData.pValue > maxPvalue) { + cocTilingData.pValue = maxPvalue; + } + + if (m < DEFAULT_ROW) { + cocTilingData.pValue = (n + cocTilingData.n0 - 1) / cocTilingData.n0; + } + + if (cocTilingData.pValue * cocTilingData.n0 * maxOutputSize > maxPeerMemPerRank) { + cocTilingData.pValue = maxPeerMemPerRank / maxOutputSize / cocTilingData.n0; + } + + cocTilingData.ubMoveNum = AllTOAll_HIDDEN_UBMOVENUM; + constexpr int32_t two = 2; + int32_t maxUbPingPongSize = cocTilingData.ubMoveNum / two; + if (cocTilingData.pValue * cocTilingData.n0 > maxUbPingPongSize) { + cocTilingData.pValue = maxUbPingPongSize / cocTilingData.n0; + } + return; +} +bool CoCMatmulReduceScatterAllToAllHiddenTilingFunc::CheckTiling(const TaskParam &tilingInfo) +{ + int32_t rankSize = cocTilingData.rankSize; + int32_t ep = tilingInfo.cocParamDesc.moeInfo.EP; + int32_t tp = tilingInfo.cocParamDesc.moeInfo.TP; + int32_t expertPerRank = tilingInfo.cocParamDesc.moeInfo.local_expert_nums; + int32_t n = tilingInfo.cocParamDesc.mmInfo.n; + auto maxOutputSize = tilingInfo.cocParamDesc.moeInfo.maxOutputSize; + + auto blockCount = MAX_BLOCK_COUNT; + int32_t maxPeerMemPerRank = (tilingInfo.bufferSize * 1024 * 1024) / INPUT_DTYPE / blockCount; + if ((cocTilingData.pValue - 1) * cocTilingData.n0 > n) { + return false; + } + if (cocTilingData.pValue * cocTilingData.n0 * maxOutputSize > maxPeerMemPerRank) { + std::string str = "The k 
value is too large and is currently not supported. " + "pValue: " + std::to_string(cocTilingData.pValue) + ", n0: " + + std::to_string(cocTilingData.n0) + ", maxPeerMemPerRank: " + std::to_string(maxPeerMemPerRank); + PrintErrorLog(tilingInfo.lcalType, str); + return false; + } + constexpr int32_t Two = 2; + int32_t maxUbPingPongSize = cocTilingData.ubMoveNum / Two; + if (cocTilingData.pValue * cocTilingData.n0 > maxUbPingPongSize) { + std::string str = "The n value is too large and is currently not supported. " + "pValue: " + std::to_string(cocTilingData.pValue) + ", n0: " + + std::to_string(cocTilingData.n0) + ", maxUbPingPongSize: " + std::to_string(maxUbPingPongSize); + PrintErrorLog(tilingInfo.lcalType, str); + return false; + } + + if (ep * tp != rankSize) { + std::string str = "ep * tp must equal rankSize. " + "rankSize: " + std::to_string(rankSize) + ", ep: " + std::to_string(ep) + + ", tp: " + std::to_string(tp); + PrintErrorLog(tilingInfo.lcalType, str); + return false; + } + + std::vector> paramCheckList = { + {"expertPerRank", expertPerRank, 1, 20} + }; + return CheckParamScopeList(paramCheckList); +} +} diff --git a/comm/lcal/src/tiling/reducescatter_tiling.cpp b/comm/lcal/src/tiling/reducescatter_tiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0afd92049c46f3c7a957251ad69e13bf5993d977 --- /dev/null +++ b/comm/lcal/src/tiling/reducescatter_tiling.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "tiling.h" +#include "lcoc_func.h" +#include "tiling_910B.h" +#include "tiling_91093.h" +#include "tiling_func.h" +namespace Lcal { +void CoCMatmulReduceScatterTilingFunc::GetDefaultTiling(const TaskParam &taskParam) +{ + CoCTilingFunc::GetDefaultTiling(taskParam); + if (Is91093(taskParam.chipName)) { + if (cocTilingData.rankSize == RANKSIZE_EIGHT) { + ReduceScatterNPU91093EightRankFP16Tiling(cocTilingData); + return; + } else if (cocTilingData.rankSize == RANKSIZE_SIXTEEN) { + ReduceScatterNPU91093SixteenRankFP16Tiling(cocTilingData); + return; + } else if (cocTilingData.rankSize == RANKSIZE_TWO && + taskParam.cocParamDesc.mmInfo.isInt8) { + ReduceScatterNPU91093TwoRankINT8Tiling(cocTilingData); + return; + } else if (cocTilingData.rankSize == RANKSIZE_TWO) { + ReduceScatterNPU91093TwoRankFP16Tiling(cocTilingData); + return; + } else if (cocTilingData.rankSize == RANKSIZE_FOUR) { + ReduceScatterNPU91093FourRankFP16Tiling(cocTilingData); + return; + } + } else if (Is910B(taskParam.chipName)) { + if (cocTilingData.rankSize == RANKSIZE_FOUR) { + ReduceScatterFourRankINT8Tiling(cocTilingData); // INT8 + return; + } + } + ReduceScatterEightRankFP16GetDefaultTiling(cocTilingData); +} + +bool CoCMatmulReduceScatterTilingFunc::CheckTiling(const TaskParam &taskParam) +{ + if (!CoCTilingFunc::CheckTiling(taskParam)) { + return false; + } + auto pValue = cocTilingData.pValue; + auto rankSize = cocTilingData.rankSize; + auto blockDim = cocTilingData.blockDim; + if ((pValue * blockDim) % rankSize != 0) { + std::string str = "The product of pValue and blockDim must be divisible by rankSize." + " pValue: " + std::to_string(pValue) + " blockDim: " + std::to_string(blockDim) + + " rankSize: " + std::to_string(rankSize); + PrintErrorLog(taskParam.lcalType, str); + return false; + } + return true; +} +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/reducescatter_tiling_91093.cpp b/comm/lcal/src/tiling/reducescatter_tiling_91093.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c30c0c1628ed2e8b6dbbc0cbf0a3063647c6c0c9 --- /dev/null +++ b/comm/lcal/src/tiling/reducescatter_tiling_91093.cpp @@ -0,0 +1,516 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "tiling_91093.h" +#include "tiling_func.h" +namespace Lcal { + constexpr int32_t REDUCESCATTER_91093_EIGHT_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t REDUCESCATTER_91093_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 40; + constexpr int32_t REDUCESCATTER_91093_EIGHT_RANK_FP16_PVALUE_DEFAULT = 14; + constexpr int32_t REDUCESCATTER_91093_SIXTEEN_RANK_FP16_UBMOVENUM_DEFAULT = 40; + constexpr int32_t REDUCESCATTER_91093_SIXTEEN_RANK_FP16_PVALUE_DEFAULT = 14; + constexpr int32_t REDUCESCATTER_91093_SIXTEEN_RANK_FP16_M0_DEFAULT = 128; + + constexpr int32_t REDUCESCATTER_91093_TWO_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t REDUCESCATTER_91093_TWO_RANK_FP16_UBMOVENUM_DEFAULT = 16; + constexpr int32_t REDUCESCATTER_91093_TWO_RANK_FP16_PVALUE_DEFAULT = 12; + constexpr int32_t REDUCESCATTER_91093_TWO_RANK_INT8_PVALUE_DEFAULT = 14; + constexpr int32_t REDUCESCATTER_91093_TWO_RANK_INT8_M0_DEFAULT = 128; + constexpr int32_t REDUCESCATTER_91093_TWO_RANK_INT8_UBMOVENUM_DEFAULT = 16; + constexpr int32_t REDUCESCATTER_91093_FOUR_RANK_FP16_UBMOVENUM_DEFAULT = 20; + constexpr int32_t REDUCESCATTER_91093_FOUR_RANK_FP16_PVALUE_DEFAULT = 12; + constexpr int32_t REDUCESCATTER_91093_FOUR_RANK_FP16_M0_DEFAULT = 128; + + static std::map<double, std::vector<std::vector<int64_t>>> g_reducescatter91093EightRankFP16PvalueMap = { + {4, + {{-1, 6656, -1, 2560, -1, 1536}, {2560, 6656, -1, 3584, 1536, 2560}, + {6656, 7680, 1536, 2560, -1, 3584}, {7680, 2147483647, -1, 2560, 1536, 3584}, + {-1, 2147483647, -1, 2560, 3584, 6656}, {-1, 2147483647, 1536, 2560, 8704, 11264}, + {-1, 3584, 1536, 2560, 11264, 13312}, {4608, 2147483647, 2560, 4608, 4608, 2147483647}, + {4608, 2147483647, 4608, 5632, 6656, 2147483647}, {-1, 2147483647, 5632, 6656, 6656, 2147483647}, + {3584, 2147483647, 6656, 7680, 3584, 2147483647}}}, + {2, + {{-1, 6656, 2560, 4608, -1, 1536}, {-1, 6656, 11264, 2147483647, -1, 1536}, + {-1, 2560, -1, 3584, 1536, 2560}, {-1, 6656, 2560, 4608, 2560, 3584}, + {6656, 8704, 2560, 2147483647, -1, 2560}, {8704, 9728, 2560, 2147483647, -1, 3584}, + {9728, 2147483647, 4608, 2147483647, -1, 2560}, {2560, 4608, 2560, 5632, 3584, 4608}, + {-1, 3584, 2560, 5632, 4608, 2147483647}, {4608, 2147483647, 4608, 5632, 3584, 6656}, + {-1, 3584, 6656, 7680, 3584, 2147483647}, {-1, 2560, 7680, 9728, 3584, 2147483647}}}, + {1, + {{-1, 6656, 4608, 11264, -1, 1536}, {-1, 6656, 3584, 2147483647, 1536, 2560}, + {-1, 4608, 4608, 9728, 2560, 3584}, {4608, 6656, 4608, 2147483647, 2560, 3584}, + {9728, 2147483647, 13312, 2147483647, 2560, 3584}, {-1, 2560, 9728, 2147483647, 3584, 2147483647}}}, + {8, + {{-1, 6656, -1, 1536, 2560, 3584}, {-1, 2147483647, 1536, 2560, 6656, 8704}, + {-1, 4608, -1, 1536, 11264, 13312}, {3584, 2147483647, 1536, 2560, 11264, 13312}, + {7680, 2147483647, 1536, 2560, 13312, 2147483647}}}, + {3, + {{-1, 6656, 1536, 2560, 2560, 3584}, {-1, 4608, 9728, 2147483647, 2560, 3584}, + {6656, 8704, 2560, 2147483647, 2560, 3584}, {9728, 2147483647, 2560, 4608, -1, 2560}, + {9728, 2147483647, 2560, 13312, 2560, 3584}, {-1, 2560, 2560, 5632, 3584, 4608}, + {3584, 4608, 2560, 5632, 4608, 2147483647}, {4608, 2147483647, 2560, 4608, 3584, 4608}, + {-1, 2147483647, 5632, 6656, 3584, 6656}, {2560, 2147483647, 7680, 2147483647, 3584, 2147483647}}}, + {10, + {{6656, 7680, -1, 1536, -1, 3584}, {-1, 2147483647, -1, 1536, 6656, 8704}, + {-1, 4608, -1, 1536, 13312, 2147483647}}}, + {6, + {{7680, 2147483647, -1, 2560, -1, 1536}}}, + {12, + {{-1, 2147483647, -1, 1536, 8704, 11264}, {4608, 2147483647, -1, 1536, 13312, 2147483647}, + {-1, 7680, 1536, 2560, 13312, 
2147483647}}}, + {14, + {{4608, 2147483647, -1, 1536, 11264, 13312}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_reducescatter91093EightRankFP16UbmovenumMap = { + {8.0, + {{-1, 1536, -1, 7168, -1, 1536}, {9728, 2147483647, 4608, 7168, 1536, 2560}}}, + {12.0, + {{-1, 1536, 7168, 2147483647, -1, 1536}, {1536, 6656, -1, 2147483647, -1, 1536}, + {7680, 9728, -1, 2147483647, -1, 1536}, {-1, 9728, 3584, 2147483647, 1536, 2560}, + {-1, 9728, 2560, 2147483647, 2560, 6656}, {9728, 2147483647, -1, 4608, 1536, 2560}, + {9728, 2147483647, -1, 8704, 2560, 5632}, {9728, 2147483647, 2560, 11264, 5632, 6656}, + {-1, 2147483647, 2560, 2147483647, 6656, 13312}, {4608, 2147483647, 2560, 11264, 13312, 2147483647}}}, + {16.0, + {{6656, 7680, -1, 6144, -1, 1536}, {-1, 9728, 1536, 3584, 1536, 2560}, + {-1, 9728, -1, 2560, 3584, 6656}, {9728, 2147483647, 1536, 2147483647, -1, 1536}, + {9728, 2147483647, 7168, 2147483647, 1536, 2560}, {9728, 2147483647, 8704, 2147483647, 2560, 5632}, + {9728, 2147483647, 1536, 2560, 5632, 6656}, {-1, 2147483647, 1536, 2560, 6656, 8704}, + {2048, 2147483647, 1536, 2560, 8704, 9728}, {-1, 2147483647, 1536, 2560, 9728, 2147483647}, + {-1, 4608, 2560, 11264, 13312, 2147483647}, {2560, 2147483647, 11264, 2147483647, 13312, 2147483647}}}, + {10.0, + {{6656, 7680, 6144, 2147483647, -1, 1536}}}, + {20.0, + {{-1, 9728, -1, 1536, 1536, 2560}, {-1, 9728, -1, 2560, 2560, 3584}, + {9728, 2147483647, 11264, 2147483647, 5632, 6656}, {-1, 2560, -1, 1536, 11264, 13312}, + {-1, 2048, 1536, 2560, 8704, 9728}, {-1, 2560, 11264, 2147483647, 13312, 2147483647}}}, + {40.0, + {{9728, 2147483647, -1, 1536, -1, 1536}, {9728, 2147483647, -1, 1536, 5632, 6656}, + {-1, 8704, -1, 1536, 6656, 7680}, {-1, 3584, -1, 1536, 7680, 11264}, + {-1, 4608, -1, 1536, 13312, 2147483647}}}, + {30.0, + {{8704, 2147483647, -1, 1536, 6656, 7680}, {3584, 2147483647, -1, 1536, 7680, 11264}, + {2560, 2147483647, -1, 1536, 11264, 13312}, {4608, 2147483647, -1, 1536, 13312, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_reducescatter91093EightRankFP16M0Map = { + {128, + {{-1, 1536, -1, 7168, -1, 1536}, {1536, 2560, -1, 9728, -1, 1536}, + {-1, 2560, -1, 2147483647, 1536, 3584}, {2560, 2147483647, -1, 2147483647, -1, 3584}, + {-1, 1536, -1, 2147483647, 3584, 11264}, {1536, 2560, -1, 2147483647, 3584, 5632}, + {2560, 2147483647, -1, 2147483647, 3584, 11264}, {-1, 1536, -1, 2147483647, 11264, 13312}, + {1536, 2147483647, -1, 2560, 11264, 13312}, {-1, 2147483647, -1, 3584, 13312, 2147483647}, + {-1, 2147483647, 5632, 2147483647, 13312, 2147483647}}}, + {256, + {{-1, 1536, 7168, 2147483647, -1, 1536}, {1536, 2560, 9728, 2147483647, -1, 1536}, + {1536, 2560, -1, 2147483647, 5632, 11264}, {1536, 2147483647, 2560, 2147483647, 11264, 13312}, + {-1, 2147483647, 3584, 5632, 13312, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_reducescatter91093SixteenRankFP16M0Map = { + {128, + {{-1, 6656, -1, 2147483647, -1, 1536}, {6656, 8704, -1, 3584, -1, 1536}, + {-1, 8704, -1, 2147483647, 1536, 3584}, {8704, 2147483647, -1, 2147483647, -1, 3584}, + {-1, 2147483647, -1, 3584, 3584, 2147483647}, {-1, 2560, 3584, 2147483647, 3584, 2147483647}, + {4608, 2147483647, 3584, 2147483647, 3584, 2147483647}}}, + {256, + {{6656, 8704, 3584, 2147483647, -1, 1536}, {2560, 4608, 3584, 2147483647, 3584, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_reducescatter91093SixteenRankFP16PvalueMap = { + {8, + {{-1, 1536, -1, 7168, -1, 1536}, {1536, 2560, -1, 5632, -1, 1536}, + {5632, 6656, -1, 2560, 1536, 2560}, {7680, 9728, -1, 2560, -1, 3584}, + {-1, 2147483647, -1, 2560, 3584, 4608}, {2560, 
2147483647, -1, 2560, 8704, 9728}}}, + {2, + {{-1, 1536, 7168, 2147483647, -1, 1536}, {1536, 2560, 5632, 2147483647, -1, 1536}, + {2560, 4608, -1, 2560, -1, 1536}, {2560, 6656, 2560, 2147483647, -1, 1536}, + {-1, 5632, -1, 2147483647, 1536, 2560}, {5632, 6656, 2560, 2147483647, 1536, 2560}, + {-1, 6656, 3584, 2147483647, 2560, 3584}, {6656, 7680, 4608, 2147483647, -1, 3584}, + {9728, 2147483647, -1, 3584, -1, 1536}, {7680, 2147483647, 3584, 2147483647, -1, 3584}, + {-1, 6656, 4608, 11264, 3584, 2147483647}, {6656, 2147483647, 9728, 11264, 3584, 2147483647}, + {-1, 2147483647, 11264, 2147483647, 3584, 2147483647}}}, + {4, + {{4608, 6656, -1, 2560, -1, 1536}, {-1, 5632, -1, 3584, 2560, 3584}, + {6656, 7680, -1, 4608, -1, 1536}, {6656, 7680, -1, 4608, 2560, 3584}, + {7680, 9728, 2560, 3584, -1, 3584}, {9728, 2147483647, -1, 3584, 1536, 3584}, + {-1, 2147483647, 2560, 4608, 3584, 7680}, {-1, 7680, 2560, 4608, 7680, 2147483647}, + {7680, 2147483647, 2560, 4608, 7680, 8704}, {6656, 2147483647, 4608, 9728, 3584, 2147483647}}}, + {6, + {{5632, 6656, -1, 3584, 2560, 3584}, {6656, 7680, -1, 4608, 1536, 2560}, + {-1, 2560, -1, 2560, 8704, 9728}, {7680, 2147483647, 2560, 4608, 8704, 2147483647}}}, + {10, + {{-1, 2147483647, -1, 2560, 4608, 8704}, {-1, 2147483647, -1, 2560, 9728, 13312}, + {-1, 2147483647, 1536, 2560, 13312, 2147483647}}}, + {14, + {{-1, 2147483647, -1, 1536, 13312, 2147483647}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_reducescatter91093SixteenRankFP16UbmovenumMap = { + {16.0, + {{-1, 9728, -1, 2560, -1, 2560}, {9728, 2147483647, -1, 1536, -1, 2560}, + {9728, 2147483647, 9216, 2147483647, -1, 2560}, {9728, 2147483647, 7680, 2147483647, 2560, 5632}, + {-1, 2147483647, 1536, 2560, 5632, 9728}, {-1, 2147483647, 7680, 13312, 5632, 11264}, + {-1, 2147483647, 7680, 2147483647, 13312, 2147483647}}}, + {12.0, + {{-1, 9728, 2560, 2147483647, -1, 3584}, {-1, 9728, -1, 2147483647, 3584, 5632}, + {9728, 2147483647, 1536, 9216, -1, 2560}, {9728, 2147483647, -1, 7680, 2560, 5632}, + {-1, 2147483647, 2560, 7680, 5632, 11264}, {2560, 2147483647, 2560, 2147483647, 11264, 13312}, + {-1, 2147483647, 2560, 7680, 13312, 2147483647}}}, + {30.0, + {{-1, 9728, -1, 2560, 2560, 3584}, {-1, 2147483647, -1, 1536, 5632, 2147483647}}}, + {20.0, + {{-1, 2147483647, 1536, 2560, 9728, 2147483647}, {-1, 2147483647, 13312, 2147483647, 5632, 11264}}}, + {40.0, + {{-1, 2560, 2560, 2147483647, 11264, 13312}}} + }; + + static std::map<double, std::vector<std::vector<int64_t>>> g_reducescatter91093TwoRankFP16PvalueMap = { + {3, + {{-1, 9728, -1, 1536, -1, 1536}, {-1, 1536, 5632, 2147483647, 2560, 3584}, + {9728, 2147483647, -1, 2560, 1536, 2560}, {9728, 2147483647, 5120, 11264, 2560, 3584}, + {9728, 2147483647, 13312, 2147483647, 2560, 3584}, {2560, 3584, 2560, 2147483647, 3584, 4608}, + {7168, 2147483647, 2560, 3584, 9728, 2147483647}, {3584, 2147483647, 6656, 2147483647, 9728, 2147483647}}}, + {2, + {{-1, 9728, 1536, 2560, -1, 1536}, {-1, 9728, 2560, 3584, 1536, 2560}, + {1536, 9728, 2560, 3584, 2560, 3584}, {9728, 2147483647, 1536, 3584, -1, 1536}, + {9728, 2147483647, 2560, 9216, 1536, 2560}, {9728, 2147483647, 11264, 2147483647, 1536, 2560}, + {9728, 2147483647, 2560, 5120, 2560, 3584}, {-1, 1536, 1536, 2560, 3584, 7680}, + {3584, 2147483647, 2560, 3584, 3584, 9728}, {3584, 7168, 2560, 3584, 9728, 2147483647}, + {3584, 9728, 3584, 6656, 3584, 2147483647}, {3584, 2147483647, 6656, 2147483647, 3584, 9728}}}, + {1, + {{-1, 9728, 2560, 2147483647, -1, 1536}, {-1, 9728, 3584, 2147483647, 1536, 2560}, + {-1, 1536, 2560, 5632, 2560, 3584}, {1536, 9728, 3584, 2147483647, 2560, 
3584}, + {9728, 2147483647, 3584, 2147483647, -1, 1536}, {9728, 2147483647, 9216, 11264, 1536, 2560}, + {9728, 2147483647, 11264, 13312, 2560, 3584}, {-1, 2560, 2560, 2147483647, 3584, 6656}, + {2560, 3584, 2560, 2147483647, 4608, 6656}, {-1, 3584, 2560, 2147483647, 6656, 2147483647}}}, + {4, + {{-1, 5632, -1, 2560, 1536, 2560}, {-1, 2147483647, -1, 2560, 2560, 3584}, + {-1, 3584, -1, 1536, 3584, 7680}, {1536, 6656, 1536, 2560, 3584, 7680}, + {-1, 3584, 1536, 2560, 7680, 2147483647}, {9728, 2147483647, 3584, 6656, 3584, 2147483647}}}, + {6, + {{5632, 9728, -1, 2560, 1536, 2560}, {9728, 2147483647, -1, 1536, -1, 1536}, + {-1, 3584, -1, 1536, 7680, 9728}, {9728, 2147483647, -1, 1536, 3584, 4608}, + {3584, 6656, 1536, 2560, 7680, 2147483647}, {6656, 2147483647, 1536, 2560, 3584, 2147483647}}}, + {8, + {{-1, 3584, -1, 1536, 9728, 2147483647}, {3584, 9728, -1, 1536, 3584, 4608}}}, + {12, + {{3584, 7680, -1, 1536, 4608, 2147483647}}}, + {10, + {{7680, 2147483647, -1, 1536, 4608, 2147483647}}} + }; + + static std::map>> g_reducescatter91093TwoRankFP16UbmovenumMap = { + {6.0, + {{-1, 1536, -1, 8704, -1, 1536}}}, + {8.0, + {{-1, 1536, 8704, 2147483647, -1, 1536}, {-1, 2560, -1, 1536, 9728, 2147483647}}}, + {10.0, + {{1536, 6656, -1, 1536, -1, 1536}, {6656, 7680, -1, 1536, -1, 2560}, + {8704, 2147483647, -1, 1536, -1, 2560}, {2560, 8704, -1, 1536, 2560, 3584}, + {2560, 2147483647, -1, 1536, 5632, 2147483647}}}, + {4.0, + {{1536, 6656, 1536, 2147483647, -1, 1536}, {-1, 8704, 1536, 2147483647, 1536, 2560}, + {-1, 2560, 4608, 2147483647, 2560, 7680}, {-1, 2560, 1536, 2147483647, 7680, 2147483647}, + {8704, 2147483647, -1, 1536, 2560, 3584}, {2560, 2147483647, 13312, 2147483647, 2560, 13312}}}, + {16.0, + {{-1, 3584, -1, 1536, 1536, 2560}}}, + {12.0, + {{3584, 6656, -1, 1536, 1536, 2560}, {7680, 8704, -1, 1536, -1, 2560}, + {-1, 2560, -1, 1536, 2560, 9728}, {2560, 2147483647, -1, 1536, 3584, 5632}}}, + {3.0, + {{6656, 2147483647, 1536, 2147483647, -1, 1536}, {8704, 2147483647, 1536, 2147483647, 1536, 2560}, + {-1, 2560, 1536, 4608, 2560, 7680}, {2560, 2147483647, 1536, 13312, 2560, 13312}, + {2560, 2147483647, 1536, 2147483647, 13312, 2147483647}}} + }; + + static std::map>> g_reducescatter91093TwoRankFP16M0Map = { + {256, + {{-1, 1536, -1, 7168, -1, 1536}, {-1, 1536, 6656, 2147483647, 5632, 2147483647}, + {1536, 3584, 4608, 2147483647, 5632, 2147483647}}}, + {128, + {{-1, 1536, 7168, 2147483647, -1, 1536}, {1536, 4608, -1, 2147483647, -1, 1536}, + {-1, 4608, -1, 2147483647, 1536, 5632}, {4608, 2147483647, -1, 2147483647, -1, 5632}, + {-1, 2147483647, -1, 4608, 5632, 2147483647}, {-1, 1536, 4608, 6656, 5632, 2147483647}, + {3584, 2147483647, 4608, 2147483647, 5632, 2147483647}}} + }; + + static std::map>> g_reducescatter91093TwoRankINT8UbmovenumMap = { + {16.0, + {{-1, 4608, -1, 2560, -1, 1536}}}, + {8.0, + {{-1, 4608, 2560, 4608, -1, 1536}, {-1, 4608, 1536, 3584, 3584, 4608}, + {4608, 8704, 1536, 2560, -1, 4608}, {-1, 1536, 1536, 2560, 4608, 2147483647}, + {1536, 4608, -1, 2560, 9728, 11264}}}, + {4.0, + {{-1, 4608, 4608, 2147483647, -1, 1536}, {-1, 4608, -1, 2147483647, 1536, 2560}, + {-1, 4608, 4608, 2147483647, 2560, 3584}, {-1, 2560, 3584, 2147483647, 3584, 4608}, + {3584, 4608, 3584, 2147483647, 3584, 4608}, {4608, 8704, 4608, 7680, -1, 4608}, + {-1, 1536, 3584, 2147483647, 4608, 2147483647}}}, + {3.0, + {{-1, 4608, -1, 4608, 2560, 3584}, {2560, 3584, 3584, 2147483647, 3584, 4608}, + {4608, 8704, 2560, 4608, -1, 4608}, {4608, 8704, 7680, 2147483647, -1, 4608}, + {8704, 2147483647, 2560, 
2147483647, -1, 4608}, {-1, 1536, 2560, 3584, 6656, 2147483647}, + {1536, 4608, -1, 2147483647, 4608, 9728}, {1536, 4608, 2560, 2147483647, 9728, 11264}, + {1536, 4608, -1, 2147483647, 11264, 2147483647}, {4608, 2147483647, 2560, 2147483647, 4608, 13312}, + {4608, 9728, 1536, 2147483647, 13312, 2147483647}, {9728, 2147483647, 3072, 2147483647, 13312, 2147483647}}}, + {12.0, + {{-1, 4608, -1, 1536, 3584, 4608}, {4608, 7680, -1, 1536, -1, 4608}, + {-1, 1536, -1, 1536, 4608, 2147483647}}}, + {10.0, + {{7680, 2147483647, -1, 1536, -1, 4608}, {4608, 2147483647, -1, 1536, 4608, 2147483647}}}, + {6.0, + {{8704, 2147483647, 1536, 2560, -1, 4608}, {-1, 1536, 2560, 3584, 4608, 6656}, + {4608, 2147483647, 1536, 2560, 4608, 13312}, {9728, 2147483647, 1536, 3072, 13312, 2147483647}}} + }; + + static std::map>> g_reducescatter91093TwoRankINT8M0Map = { + {128, + {{-1, 1536, -1, 4096, -1, 1536}, {-1, 2560, -1, 2147483647, 1536, 9728}, + {2560, 3584, -1, 3584, 1536, 9728}, {3584, 2147483647, -1, 3584, -1, 9728}, + {3584, 2147483647, 3584, 2147483647, 6656, 9728}, {-1, 2147483647, -1, 2147483647, 9728, 11264}, + {3584, 2147483647, -1, 2147483647, 11264, 13312}, {-1, 3584, -1, 3584, 13312, 2147483647}, + {-1, 1536, 3584, 2147483647, 13312, 2147483647}, {3584, 2147483647, -1, 2147483647, 13312, 2147483647}}}, + {256, + {{-1, 1536, 4096, 2147483647, -1, 1536}, {1536, 2560, -1, 2147483647, -1, 1536}, + {2560, 3584, -1, 3584, -1, 1536}, {2560, 3584, 3584, 2147483647, -1, 9728}, + {3584, 2147483647, 3584, 2147483647, -1, 6656}, {-1, 3584, -1, 2147483647, 11264, 13312}, + {1536, 3584, 3584, 2147483647, 13312, 2147483647}}} + }; + + static std::map>> g_reducescatter91093TwoRankINT8PvalueMap = { + {3, + {{-1, 3584, -1, 1536, -1, 1536}, {-1, 3584, 1536, 2560, 2560, 3584}, + {-1, 1536, 9216, 2147483647, 2560, 3584}, {3584, 4608, -1, 2560, -1, 1536}, + {7680, 8704, 4608, 5632, -1, 3584}, {-1, 1536, 1536, 3584, 3584, 2147483647}, + {2560, 3584, 3584, 4608, 3584, 2147483647}, {7680, 2147483647, 3584, 4608, 3584, 6656}, + {6656, 2147483647, 4608, 9728, 3584, 2147483647}}}, + {4, + {{-1, 3584, 1536, 2560, -1, 1536}, {4608, 2147483647, -1, 2560, -1, 1536}, + {3584, 2147483647, -1, 3584, 1536, 3584}, {-1, 1536, -1, 1536, 5632, 2147483647}, + {1536, 3584, 1536, 3584, 3584, 2147483647}, {3584, 2147483647, 3584, 4608, 6656, 2147483647}}}, + {2, + {{-1, 3584, 2560, 4096, -1, 1536}, {1536, 3584, -1, 4608, 1536, 2560}, + {-1, 3584, 2560, 4608, 2560, 3584}, {3584, 2147483647, 2560, 4608, -1, 1536}, + {3584, 2147483647, 3584, 4608, 1536, 3584}, {8704, 2147483647, 4608, 5632, -1, 3584}, + {-1, 2560, 3584, 5632, 3584, 2147483647}, {2560, 3584, 4608, 5632, 3584, 2147483647}, + {3584, 7680, 3584, 4608, 3584, 6656}, {3584, 6656, 4608, 2147483647, 6656, 2147483647}, + {6656, 2147483647, 9728, 2147483647, 3584, 2147483647}}}, + {1, + {{-1, 3584, 4096, 2147483647, -1, 1536}, {-1, 1536, -1, 4608, 1536, 2560}, + {-1, 3584, 4608, 2147483647, 1536, 2560}, {-1, 1536, 4608, 9216, 2560, 3584}, + {1536, 3584, 4608, 2147483647, 2560, 3584}, {3584, 7680, 4608, 2147483647, -1, 3584}, + {7680, 2147483647, 5632, 2147483647, -1, 3584}, {-1, 3584, 5632, 2147483647, 3584, 2147483647}, + {3584, 6656, 4608, 2147483647, 3584, 6656}}}, + {14, + {{-1, 3584, -1, 1536, 2560, 3584}, {2560, 2147483647, -1, 1536, 9728, 2147483647}}}, + {8, + {{-1, 1536, -1, 1536, 3584, 5632}, {3584, 7680, 1536, 2560, 3584, 2147483647}, + {3584, 2147483647, 2560, 3584, 11264, 2147483647}}}, + {6, + {{1536, 2560, -1, 1536, 3584, 2147483647}, {3584, 2147483647, 2560, 3584, 3584, 
11264}}}, + {10, + {{2560, 7680, -1, 1536, 3584, 7680}}}, + {12, + {{7680, 2147483647, -1, 1536, 3584, 7680}, {2560, 2147483647, -1, 1536, 7680, 9728}, + {7680, 2147483647, 1536, 2560, 3584, 2147483647}}} + }; + + static std::map>> g_reducescatter91093FourRankFP16M0Map = { + {256, + {{-1, 1536, -1, 4096, -1, 1536}, {1536, 6656, 1536, 2147483647, -1, 1536}, + {-1, 5632, 3584, 2147483647, 1536, 2560}, {5632, 6656, 2560, 2147483647, 1536, 2560}, + {6656, 7680, -1, 7680, -1, 2560}, {7680, 2147483647, 3584, 7680, -1, 2560}, + {6656, 7680, 7680, 8704, -1, 2560}, {-1, 4608, 4608, 2147483647, 2560, 11264}, + {-1, 4608, 3584, 2147483647, 11264, 2147483647}, {4608, 6656, 3584, 9728, 2560, 2147483647}, + {9728, 2147483647, -1, 1536, 5120, 2147483647}}}, + {128, + {{-1, 1536, 4096, 2147483647, -1, 1536}, {1536, 6656, -1, 1536, -1, 1536}, + {-1, 5632, -1, 3584, 1536, 2560}, {5632, 6656, -1, 2560, 1536, 2560}, + {7680, 2147483647, -1, 3584, -1, 2560}, {7680, 2147483647, 7680, 8704, -1, 2560}, + {6656, 2147483647, 8704, 2147483647, -1, 2560}, {-1, 4608, -1, 4608, 2560, 11264}, + {-1, 4608, -1, 3584, 11264, 2147483647}, {4608, 6656, -1, 3584, 2560, 2147483647}, + {4608, 6656, 9728, 2147483647, 2560, 2147483647}, {6656, 9728, -1, 2147483647, 2560, 2147483647}, + {9728, 2147483647, -1, 1536, 2560, 5120}, {9728, 2147483647, 1536, 2147483647, 2560, 2147483647}}} + }; + + static std::map>> g_reducescatter91093FourRankFP16PvalueMap = { + {4, + {{-1, 2560, -1, 1536, -1, 1536}, {-1, 3584, -1, 1536, 1536, 2560}, + {3584, 4608, -1, 1536, -1, 3584}, {9728, 2147483647, -1, 2560, -1, 3584}, + {-1, 3584, -1, 1536, 3584, 7680}, {4608, 2147483647, 1536, 2560, 3584, 2147483647}}}, + {2, + {{2560, 3584, -1, 1536, -1, 1536}, {-1, 3584, 1536, 2560, 1536, 2560}, + {1536, 3584, 1536, 3584, 2560, 3584}, {3584, 4608, 1536, 3584, -1, 3584}, + {4608, 8704, -1, 3584, -1, 3584}, {9728, 2147483647, 2560, 3584, -1, 3584}, + {3584, 2147483647, 3584, 4608, 1536, 3584}, {-1, 1536, 1536, 2560, 3584, 2147483647}, + {-1, 2147483647, 2560, 4608, 3584, 9728}, {-1, 7680, 2560, 4608, 9728, 2147483647}, + {7680, 2147483647, 3584, 4608, 9728, 2147483647}, {3584, 8704, 4608, 9728, 3584, 2147483647}, + {9728, 2147483647, 4608, 8704, 3584, 2147483647}, {9728, 2147483647, 9728, 2147483647, 3584, 2147483647}}}, + {1, + {{-1, 3584, 1536, 2147483647, -1, 1536}, {-1, 3584, 2560, 2147483647, 1536, 2560}, + {-1, 1536, 1536, 2147483647, 2560, 3584}, {1536, 3584, 3584, 2147483647, 2560, 3584}, + {3584, 2147483647, 3584, 2147483647, -1, 1536}, {3584, 2147483647, 4608, 2147483647, 1536, 3584}, + {-1, 3584, 4608, 2147483647, 3584, 2147483647}, {8704, 9728, 4608, 9728, 3584, 2147483647}, + {9728, 2147483647, 8704, 9728, 3584, 2147483647}, {3584, 9728, 9728, 2147483647, 3584, 2147483647}}}, + {12, + {{-1, 3584, -1, 1536, 2560, 3584}}}, + {8, + {{8704, 9728, -1, 3584, -1, 3584}, {-1, 3584, -1, 1536, 7680, 2147483647}}}, + {10, + {{3584, 2147483647, -1, 1536, 3584, 2147483647}}}, + {3, + {{1536, 4608, 1536, 2560, 3584, 2147483647}, {7680, 2147483647, 2560, 3584, 9728, 2147483647}}} + }; + + static std::map>> g_reducescatter91093FourRankFP16UbmovenumMap = { + {12.0, + {{-1, 1536, -1, 4096, -1, 1536}, {1536, 2560, 2560, 4608, 1536, 2560}, + {2560, 3584, 3072, 7680, -1, 2560}, {2560, 3584, 9216, 2147483647, -1, 1536}, + {3584, 2147483647, 2560, 4608, -1, 2560}, {-1, 2560, 2560, 3584, 2560, 3584}, + {-1, 2560, 3584, 5120, 2560, 4608}, {-1, 1536, 5120, 2147483647, 2560, 4608}, + {2560, 2147483647, 1536, 2560, 2560, 7680}, {-1, 2147483647, 1536, 2560, 7680, 
+        {8.0,
+         {{-1, 1536, 4096, 2147483647, -1, 1536}, {-1, 2560, 4608, 7680, 1536, 2560},
+          {-1, 2560, 8704, 2147483647, 1536, 2560}, {2560, 3584, 7680, 2147483647, 1536, 2560},
+          {3584, 2147483647, 4608, 2147483647, -1, 2560}, {2560, 2147483647, 2560, 4608, 2560, 4608},
+          {-1, 2560, 2560, 2147483647, 4608, 2147483647}, {2560, 2147483647, 2560, 4608, 4608, 2147483647}}},
+        {16.0,
+         {{1536, 2560, -1, 2147483647, -1, 1536}, {-1, 1536, -1, 4608, 1536, 2560},
+          {1536, 2560, -1, 2560, 1536, 2560}, {-1, 2560, 7680, 8704, 1536, 2560},
+          {2560, 3584, -1, 3072, -1, 1536}, {3584, 2147483647, 1536, 2560, -1, 2560},
+          {-1, 2560, -1, 2560, 2560, 3584}, {-1, 2560, -1, 3584, 3584, 4608},
+          {-1, 2560, 1536, 2560, 4608, 7680}}},
+        {20.0,
+         {{2560, 3584, -1, 3072, 1536, 2560}, {3584, 2147483647, -1, 1536, -1, 2560},
+          {2560, 2147483647, -1, 1536, 2560, 4608}, {-1, 2147483647, -1, 1536, 4608, 2147483647}}},
+        {10.0,
+         {{2560, 3584, 7680, 9216, -1, 1536}}},
+        {4.0,
+         {{1536, 2560, 5120, 2147483647, 2560, 4608}, {2560, 2147483647, 13312, 2147483647, 3584, 4608}}},
+        {6.0,
+         {{2560, 2147483647, 4608, 2147483647, 2560, 3584}, {2560, 2147483647, 4608, 13312, 3584, 4608},
+          {2560, 2147483647, 4608, 2147483647, 4608, 2147483647}}}
+    };
+
+    void ReduceScatterNPU91093EightRankFP16Tiling(CoCTilingData &cocTilingData)
+    {
+        std::map<int32_t *, TilingValue> TilingParamMap = {
+            {&cocTilingData.pValue,
+             {REDUCESCATTER_91093_EIGHT_RANK_FP16_PVALUE_DEFAULT,
+              g_reducescatter91093EightRankFP16PvalueMap}},
+            {&cocTilingData.ubMoveNum,
+             {REDUCESCATTER_91093_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT,
+              g_reducescatter91093EightRankFP16UbmovenumMap}},
+            {&cocTilingData.m0,
+             {REDUCESCATTER_91093_EIGHT_RANK_FP16_M0_DEFAULT,
+              g_reducescatter91093EightRankFP16M0Map}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}},
+            {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}},
+            {&cocTilingData.commDirect, {COMM_DATA_DIRECT}},
+            {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}},
+            {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}},
+        };
+        SetTilingParam(cocTilingData, TilingParamMap);
+
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum;
+    }
+
+    void ReduceScatterNPU91093SixteenRankFP16Tiling(CoCTilingData &cocTilingData)
+    {
+        std::map<int32_t *, TilingValue> TilingParamMap = {
+            {&cocTilingData.m0,
+             {REDUCESCATTER_91093_SIXTEEN_RANK_FP16_M0_DEFAULT,
+              g_reducescatter91093SixteenRankFP16M0Map}},
+            {&cocTilingData.pValue,
+             {REDUCESCATTER_91093_SIXTEEN_RANK_FP16_PVALUE_DEFAULT,
+              g_reducescatter91093SixteenRankFP16PvalueMap}},
+            {&cocTilingData.ubMoveNum,
+             {REDUCESCATTER_91093_SIXTEEN_RANK_FP16_UBMOVENUM_DEFAULT,
+              g_reducescatter91093SixteenRankFP16UbmovenumMap}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}},
+            {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}},
+            {&cocTilingData.commDirect, {COMM_DATA_DIRECT}},
+            {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}},
+            {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}},
+        };
+        SetTilingParam(cocTilingData, TilingParamMap);
+
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum;
+    }
+
+    void ReduceScatterNPU91093TwoRankFP16Tiling(CoCTilingData &cocTilingData)
+    {
+        std::map<int32_t *, TilingValue> TilingParamMap = {
+            {&cocTilingData.pValue,
+             {REDUCESCATTER_91093_TWO_RANK_FP16_PVALUE_DEFAULT,
+              g_reducescatter91093TwoRankFP16PvalueMap}},
+            {&cocTilingData.ubMoveNum,
+             {REDUCESCATTER_91093_TWO_RANK_FP16_UBMOVENUM_DEFAULT,
+              g_reducescatter91093TwoRankFP16UbmovenumMap}},
+            {&cocTilingData.m0,
+             {REDUCESCATTER_91093_TWO_RANK_FP16_M0_DEFAULT,
+              g_reducescatter91093TwoRankFP16M0Map}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}},
+            {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}},
+            {&cocTilingData.commDirect, {COMM_DATA_DIRECT}},
+            {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}},
+            {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}},
+        };
+        SetTilingParam(cocTilingData, TilingParamMap);
+
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum;
+    }
+
+    void ReduceScatterNPU91093TwoRankINT8Tiling(CoCTilingData &cocTilingData)
+    {
+        std::map<int32_t *, TilingValue> TilingParamMap = {
+            {&cocTilingData.ubMoveNum,
+             {REDUCESCATTER_91093_TWO_RANK_INT8_UBMOVENUM_DEFAULT,
+              g_reducescatter91093TwoRankINT8UbmovenumMap}},
+            {&cocTilingData.m0,
+             {REDUCESCATTER_91093_TWO_RANK_INT8_M0_DEFAULT,
+              g_reducescatter91093TwoRankINT8M0Map}},
+            {&cocTilingData.pValue,
+             {REDUCESCATTER_91093_TWO_RANK_INT8_PVALUE_DEFAULT,
+              g_reducescatter91093TwoRankINT8PvalueMap}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}},
+            {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}},
+            {&cocTilingData.commDirect, {COMM_DATA_DIRECT}},
+            {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}},
+            {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}},
+        };
+        SetTilingParam(cocTilingData, TilingParamMap);
+
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum;
+    }
+
+    void ReduceScatterNPU91093FourRankFP16Tiling(CoCTilingData &cocTilingData)
+    {
+        std::map<int32_t *, TilingValue> TilingParamMap = {
+            {&cocTilingData.m0,
+             {REDUCESCATTER_91093_FOUR_RANK_FP16_M0_DEFAULT,
+              g_reducescatter91093FourRankFP16M0Map}},
+            {&cocTilingData.pValue,
+             {REDUCESCATTER_91093_FOUR_RANK_FP16_PVALUE_DEFAULT,
+              g_reducescatter91093FourRankFP16PvalueMap}},
+            {&cocTilingData.ubMoveNum,
+             {REDUCESCATTER_91093_FOUR_RANK_FP16_UBMOVENUM_DEFAULT,
+              g_reducescatter91093FourRankFP16UbmovenumMap}},
+            {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}},
+            {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}},
+            {&cocTilingData.commDirect, {COMM_DATA_DIRECT}},
+            {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}},
+            {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}},
+        };
+        SetTilingParam(cocTilingData, TilingParamMap);
+
+        cocTilingData.lenPerLoop = cocTilingData.ubMoveNum;
+    }
+}
diff --git a/comm/lcal/src/tiling/reducescatter_tiling_910B.cpp b/comm/lcal/src/tiling/reducescatter_tiling_910B.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f24cb8988ad90ee3810fc5525bd7195f76b4f7cc
--- /dev/null
+++ b/comm/lcal/src/tiling/reducescatter_tiling_910B.cpp
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ +#include "tiling_910B.h" +#include "tiling_func.h" +namespace Lcal { + constexpr int32_t REDUCESCATTER_FOUR_RANK_INT8_PVALUE_DEFAULT = 14; + constexpr int32_t REDUCESCATTER_FOUR_RANK_INT8_UBMOVENUM_DEFAULT = 8; + constexpr int32_t REDUCESCATTER_FOUR_RANK_INT8_M0_DEFAULT = 128; + constexpr int32_t REDUCESCATTER_EIGHT_RANK_FP16_M0_DEFAULT = 128; + constexpr int32_t REDUCESCATTER_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT = 20; + constexpr int32_t REDUCESCATTER_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT = 16; + constexpr int32_t REDUCESCATTER_EIGHT_RANK_FP16_PVALUE_DEFAULT = 12; + + static std::map>> g_reducescatterFourRankINT8M0Map = { + {128, + {{-1, 2560, -1, 7680, -1, 1536}, {-1, 1536, 7680, 2147483647, -1, 1536}, + {1536, 2560, 8704, 2147483647, -1, 1536}, {3584, 2147483647, -1, 4608, -1, 1536}, + {8704, 2147483647, 4608, 5632, -1, 1536}, {2560, 3584, 5632, 2147483647, -1, 1536}, + {-1, 2147483647, -1, 2147483647, 1536, 2147483647}}}, + {256, + {{1536, 2560, 7680, 8704, -1, 1536}, {2560, 3584, -1, 4608, -1, 1536}, + {2560, 8704, 4608, 5632, -1, 1536}, {3584, 2147483647, 5632, 2147483647, -1, 1536}}} + }; + + static std::map>> g_reducescatterFourRankINT8UbmovenumMap = { + {8.0, + {{-1, 1536, -1, 7168, -1, 1536}, {-1, 1536, -1, 2560, 1536, 3584}, + {1536, 2147483647, -1, 1536, -1, 1536}, {1536, 2560, 1536, 4608, -1, 1536}}}, + {6.0, + {{-1, 1536, 7168, 2147483647, -1, 1536}, {-1, 1536, 2560, 2147483647, 1536, 3584}, + {-1, 1536, -1, 4608, 3584, 13312}, {1536, 2147483647, -1, 1536, 1536, 13312}, + {1536, 2560, 1536, 4608, 1536, 13312}, {2560, 2147483647, 1536, 4608, -1, 13312}, + {1536, 2560, 4608, 5632, -1, 6144}, {-1, 2147483647, -1, 4608, 13312, 2147483647}, + {5632, 6656, 9728, 2147483647, 13312, 2147483647}}}, + {4.0, + {{-1, 1536, 4608, 2147483647, 3584, 13312}, {1536, 2560, 4608, 5632, 6144, 13312}, + {2560, 2147483647, 4608, 5632, -1, 13312}, {1536, 2147483647, 5632, 2147483647, -1, 13312}, + {-1, 5632, 4608, 2147483647, 13312, 2147483647}, {5632, 2147483647, 4608, 9728, 13312, 2147483647}, + {6656, 2147483647, 9728, 2147483647, 13312, 2147483647}}} + }; + + static std::map>> g_reducescatterFourRankINT8PvalueMap = { + {12, + {{-1, 1536, -1, 4096, -1, 1536}, {5632, 2147483647, -1, 2560, 3584, 5632}}}, + {1, + {{-1, 3584, 4096, 2147483647, -1, 1536}, {-1, 3584, 6656, 2147483647, 1536, 3584}, + {4608, 7680, 7680, 2147483647, -1, 3584}, {9728, 2147483647, 8192, 2147483647, -1, 1536}, + {-1, 1536, 6656, 9728, 3584, 2147483647}, {-1, 1536, 9728, 2147483647, 9728, 2147483647}, + {1536, 2560, 7680, 2147483647, 3584, 11264}}}, + {2, + {{1536, 3584, -1, 4096, -1, 1536}, {-1, 3584, -1, 6656, 1536, 3584}, + {3584, 4608, -1, 2147483647, -1, 2560}, {4608, 7680, 4608, 7680, -1, 3584}, + {7680, 9728, -1, 2147483647, -1, 1536}, {9728, 2147483647, -1, 8192, -1, 1536}, + {-1, 1536, 4608, 6656, 3584, 2147483647}, {-1, 1536, 9728, 2147483647, 3584, 9728}, + {1536, 2560, 5632, 7680, 3584, 2147483647}, {1536, 2560, 7680, 2147483647, 11264, 2147483647}}}, + {4, + {{3584, 4608, -1, 6144, 2560, 3584}, {4608, 7680, 1536, 4608, -1, 3584}, + {-1, 1536, 1536, 4608, 3584, 2147483647}, {1536, 2560, -1, 4608, 4608, 7680}, + {5632, 6656, 4608, 5632, 3584, 2147483647}, {6656, 8704, 4608, 2147483647, 6656, 2147483647}}}, + {3, + {{3584, 4608, 6144, 2147483647, 2560, 3584}, {7680, 8704, 4608, 2147483647, 1536, 3584}, + {8704, 2147483647, 5632, 2147483647, 1536, 3584}, {1536, 2560, -1, 4608, 3584, 4608}, + {1536, 2560, 4608, 5632, 3584, 2147483647}, {2560, 5632, 4608, 2147483647, 3584, 2147483647}, + {5632, 6656, 5632, 
2147483647, 3584, 2147483647}, {6656, 8704, 4608, 2147483647, 3584, 6656}, + {8704, 2147483647, 4608, 2147483647, 3584, 2147483647}}}, + {8, + {{4608, 7680, -1, 1536, -1, 3584}, {2560, 5632, -1, 2560, 3584, 7680}, + {4608, 5632, 2560, 4608, 3584, 2147483647}, {5632, 2147483647, 2560, 4608, 3584, 9728}}}, + {6, + {{7680, 8704, -1, 4608, 1536, 3584}, {8704, 2147483647, -1, 5632, 1536, 3584}, + {-1, 1536, -1, 1536, 3584, 2147483647}, {1536, 2560, 1536, 4608, 7680, 2147483647}, + {2560, 4608, 2560, 4608, 3584, 2147483647}}}, + {10, + {{1536, 2560, -1, 1536, 7680, 2147483647}}}, + {14, + {{2560, 5632, -1, 2560, 7680, 2147483647}, {5632, 2147483647, -1, 2560, 5632, 2147483647}, + {5632, 2147483647, 2560, 4608, 9728, 2147483647}}} + }; + + static std::map>> g_reducescatterEightRankFP16PvalueMap = { + {2, + {{-1, 1536, -1, 2147483647, -1, 1536}, {1536, 5632, 1536, 2147483647, -1, 1536}, + {-1, 1536, -1, 2147483647, 1536, 2560}, {1536, 5632, 1536, 2147483647, 1536, 2560}, + {5632, 6656, 1536, 2560, -1, 1536}, {5632, 2147483647, 2560, 2147483647, -1, 2560}, + {-1, 4608, 1536, 2560, 2560, 4608}, {-1, 2147483647, 2560, 2147483647, 2560, 2147483647}}}, + {4, + {{1536, 6656, -1, 1536, -1, 2560}, {5632, 6656, 1536, 2560, 1536, 2560}, + {6656, 2147483647, 1536, 2560, -1, 2560}, {-1, 4608, -1, 1536, 2560, 5632}, + {-1, 4608, 1536, 2560, 4608, 5632}, {4608, 8704, -1, 2560, 2560, 3584}, + {8704, 2147483647, 1536, 2560, 2560, 5632}, {-1, 2560, -1, 2560, 5632, 2147483647}, + {2560, 2147483647, 1536, 2560, 5632, 2147483647}}}, + {6, + {{6656, 8704, -1, 1536, -1, 2560}}}, + {8, + {{8704, 2147483647, -1, 1536, -1, 2560}, {4608, 8704, -1, 2560, 3584, 5632}, + {2560, 6656, -1, 1536, 5632, 2147483647}}}, + {10, + {{8704, 2147483647, -1, 1536, 2560, 5632}}}, + {12, + {{6656, 2147483647, -1, 1536, 5632, 2147483647}}} + }; + + static std::map>> g_reducescatterEightRankFP16CommdatasplitMap = { + {16, + {{-1, 9728, -1, 2147483647, -1, 1536}, {9728, 2147483647, -1, 9728, -1, 1536}, + {-1, 2147483647, -1, 2147483647, 1536, 2147483647}}}, + {8, + {{9728, 2147483647, 9728, 2147483647, -1, 1536}}} + }; + + static std::map>> g_reducescatterEightRankFP16UbmovenumMap = { + {8.0, + {{-1, 1536, -1, 4096, -1, 1536}, {-1, 1536, 7168, 8704, -1, 1536}, + {1536, 2560, -1, 7680, -1, 1536}, {-1, 2560, 8704, 2147483647, -1, 1536}, + {2560, 2147483647, -1, 1536, -1, 1536}, {3584, 2147483647, 7680, 8704, -1, 1536}, + {6144, 2147483647, 8704, 9728, -1, 1536}, {2560, 3584, 9728, 2147483647, -1, 1536}, + {-1, 1536, -1, 3584, 1536, 2560}, {-1, 1536, -1, 5120, 5632, 7680}, + {1536, 2560, -1, 1536, 1536, 2147483647}, {1536, 2560, 9728, 2147483647, 11264, 2147483647}}}, + {10.0, + {{-1, 1536, 4096, 7168, -1, 1536}, {1536, 2560, 7680, 8704, -1, 1536}, + {2560, 2147483647, 1536, 7680, -1, 1536}, {2560, 3584, 7680, 8704, -1, 1536}, + {2560, 6144, 8704, 9728, -1, 1536}, {3584, 9728, 9728, 2147483647, -1, 1536}, + {-1, 1536, -1, 3584, 2560, 5632}, {-1, 1536, 3584, 2147483647, 1536, 5632}, + {-1, 1536, -1, 5120, 7680, 13312}, {-1, 1536, 5120, 2147483647, 5632, 13312}, + {-1, 1536, -1, 5120, 13312, 2147483647}, {-1, 1536, 7680, 2147483647, 13312, 2147483647}, + {2560, 2147483647, -1, 1536, 1536, 2147483647}, {1536, 2147483647, 1536, 9728, 1536, 2147483647}, + {1536, 2147483647, 9728, 2147483647, 1536, 11264}, {2560, 2147483647, 9728, 2147483647, 11264, 2147483647}}}, + {20.0, + {{9728, 2147483647, 9728, 2147483647, -1, 1536}}}, + {6.0, + {{-1, 1536, 5120, 7680, 13312, 2147483647}}} + }; + + static std::map>> g_reducescatterEightRankFP16M0Map = { + {128, 
+ {{-1, 5632, -1, 2147483647, -1, 1536}, {5632, 8704, -1, 3584, -1, 1536}, + {-1, 8704, -1, 2147483647, 1536, 7680}, {8704, 2147483647, -1, 2147483647, -1, 7680}, + {-1, 2147483647, -1, 3584, 7680, 2147483647}, {-1, 1536, 3584, 2147483647, 7680, 2147483647}, + {2560, 2147483647, 3584, 2147483647, 7680, 2147483647}}}, + {256, + {{5632, 8704, 3584, 2147483647, -1, 1536}, {1536, 2560, 3584, 2147483647, 7680, 2147483647}}} + }; + + void ReduceScatterFourRankINT8Tiling(CoCTilingData &cocTilingData) + { + std::map TilingParamMap = { + {&cocTilingData.m0, + {REDUCESCATTER_FOUR_RANK_INT8_M0_DEFAULT, + g_reducescatterFourRankINT8M0Map}}, + {&cocTilingData.ubMoveNum, + {REDUCESCATTER_FOUR_RANK_INT8_UBMOVENUM_DEFAULT, + g_reducescatterFourRankINT8UbmovenumMap}}, + {&cocTilingData.pValue, + {REDUCESCATTER_FOUR_RANK_INT8_PVALUE_DEFAULT, + g_reducescatterFourRankINT8PvalueMap}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {DEFAULT_SWIZZLE_COUNT}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}}, + {&cocTilingData.commDataSplit, {COMMDATASPLIT_SIXTEEN}}, + }; + SetTilingParam(cocTilingData, TilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum; + } + + void ReduceScatterEightRankFP16GetDefaultTiling(CoCTilingData &cocTilingData) + { + std::map TilingParamMap = { + {&cocTilingData.pValue, + {REDUCESCATTER_EIGHT_RANK_FP16_PVALUE_DEFAULT, + g_reducescatterEightRankFP16PvalueMap}}, + {&cocTilingData.commDataSplit, + {REDUCESCATTER_EIGHT_RANK_FP16_COMMDATASPLIT_DEFAULT, + g_reducescatterEightRankFP16CommdatasplitMap}}, + {&cocTilingData.ubMoveNum, + {REDUCESCATTER_EIGHT_RANK_FP16_UBMOVENUM_DEFAULT, + g_reducescatterEightRankFP16UbmovenumMap}}, + {&cocTilingData.m0, + {REDUCESCATTER_EIGHT_RANK_FP16_M0_DEFAULT, + g_reducescatterEightRankFP16M0Map}}, + {&cocTilingData.swizzlDirect, {SWIZZLE_DIRECT_ONE}}, + {&cocTilingData.swizzlCount, {SWIZZLE_COUNT_FOUR}}, + {&cocTilingData.commDirect, {COMM_DATA_DIRECT}}, + {&cocTilingData.commNpuSplit, {COMMNPUSPLIT_ONE}}, + }; + SetTilingParam(cocTilingData, TilingParamMap); + + cocTilingData.lenPerLoop = cocTilingData.ubMoveNum; + } +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/tiling.cpp b/comm/lcal/src/tiling/tiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1d74aa6d039448d67a83750ddec057bf56ad649d --- /dev/null +++ b/comm/lcal/src/tiling/tiling.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ +#include "tiling_func.h" +#include "mki/utils/log/log.h" +#include "tiling.h" +namespace Lcal { +CoCTilingData CoCTilingFunc::GenerateTiling(const TaskParam &taskParam, const CoCTiling &tiling) +{ + SetTilingInputParam(taskParam, cocTilingData); + + cocTilingData.SetDefaultValue(); + + this->GetDefaultTiling(taskParam); + + // 设置Tiling策略参数 + SetTilingData(taskParam, tiling, cocTilingData); + + return cocTilingData; +} + +bool CoCTilingFunc::CheckTiling(const TaskParam &taskParam) +{ + (void) taskParam; + return CheckCoCTilingData(cocTilingData); +} + +void CoCTilingFunc::GetDefaultTiling(const TaskParam &taskParam) +{ + (void) taskParam; + cocTilingData.ubMoveNum = VALID_UB_MOVE_NUM; + cocTilingData.commNpuSplit = cocTilingData.rankSize; + cocTilingData.commDataSplit = COMMDATASPLIT_ONE; + cocTilingData.commDirect = COMM_DATA_DIRECT; + cocTilingData.lenPerLoop = LENPERLOOP_DEFAULT; +} +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/tiling_args.cpp b/comm/lcal/src/tiling/tiling_args.cpp new file mode 100644 index 0000000000000000000000000000000000000000..abe4dde67748e1f7a681755a638b47e4570435e3 --- /dev/null +++ b/comm/lcal/src/tiling/tiling_args.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include +#include "tiling_args.h" + +namespace Lcal { + const char* CoCTilingData::ToString() const + { + std::string str = + "[CoCTilingData]: \nm=" + std::to_string(m) + ", k=" + std::to_string(k) + ", n=" + std::to_string(n) + + ", batchSize=" + std::to_string(batchSize) + + ", \nblockDim=" + std::to_string(blockDim) + ", rank=" + std::to_string(rank) + + ", rankSize=" + std::to_string(rankSize) + ", tag=" + std::to_string(tag) + + ", \nmLoop=" + std::to_string(mLoop) + ", kLoop=" + std::to_string(kLoop) + + ", nLoop=" + std::to_string(nLoop) + ", coreLoop=" + std::to_string(coreLoop) + + ", tilingKey=" + std::to_string(tilingKey) + + ", \nm0=" + std::to_string(m0) + ", k0=" + std::to_string(k0) + ", n0=" + std::to_string(n0) + + ", swizzlCount=" + std::to_string(swizzlCount) + ", swizzlDirect=" + std::to_string(swizzlDirect) + + ", pValue=" + std::to_string(pValue) + ", ubMoveNum=" + std::to_string(ubMoveNum) + + ", commNpuSplit=" + std::to_string(commNpuSplit) + ", commDataSplit=" + std::to_string(commDataSplit) + + ", commDirect=" + std::to_string(commDirect) + ", lenPerLoop=" + std::to_string(lenPerLoop) + + ", \nextraUbMoveNum=" + std::to_string(extraUbMoveNum) + + ", extraCommNpuSplit=" + std::to_string(extraCommNpuSplit) + + ", extraCommDataSplit=" + std::to_string(extraCommDataSplit) + + ", extraCommDirect=" + std::to_string(extraCommDirect) + + ", extraLenPerLoop=" + std::to_string(extraLenPerLoop) + ", \nsplitK=" + std::to_string(splitK) + + ", write2OtherRank=" + std::to_string(write2OtherRank) + + ", withSerialMode=" + std::to_string(withSerialMode) + ", \nis_91093=" + std::to_string(is91093); + return str.data(); + } + + void CoCTilingData::SetDefaultValue() + { + m0 = m < n ? 
DEFAULT_COL : DEFAULT_ROW; + k0 = DEFAULT_COL; + n0 = m0 == DEFAULT_COL ? DEFAULT_ROW : DEFAULT_COL; + swizzlCount = DEFAULT_SWIZZLE_COUNT; + swizzlDirect = m > n ? SWIZZLE_DIRECT_ZERO : SWIZZLE_DIRECT_ONE; + pValue = DEFAULT_P_VALUE; + ubMoveNum = MAX_UB_NUM; + commNpuSplit = COMMNPUSPLIT_ONE; + commDataSplit = rankSize; + commDirect = COMM_DATA_DIRECT; + lenPerLoop = LENPERLOOP_DEFAULT; + extraUbMoveNum = ubMoveNum; + extraCommNpuSplit = commNpuSplit; + extraCommDataSplit = commDataSplit; + extraCommDirect = commDirect; + extraLenPerLoop = lenPerLoop; + splitK = DEFAULT_SPLIT_K; + write2OtherRank = false; + withSerialMode = false; + is91093 = false; + tag = 0; + } +} \ No newline at end of file diff --git a/comm/lcal/src/tiling/tiling_func.cpp b/comm/lcal/src/tiling/tiling_func.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d99cb3f26c9bacbcff405f96d8a48b04f74928ae --- /dev/null +++ b/comm/lcal/src/tiling/tiling_func.cpp @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "lcoc_func.h" +#include "lcoc_args.h" +#include "tiling_args.h" +#include "tiling_func.h" + +namespace Lcal { + int32_t CeilDev(int32_t num, int32_t div) + { + if (div == 0) { + return 0; + } + return (num + div - 1) / div; + } + + int32_t RoundNum(int32_t num, int32_t rnd) + { + if (rnd == 0) { + return 0; + } + return (num + rnd - 1) / rnd * rnd; + } + + void UpdateTilingValue(const int32_t &tilingParam, int32_t &tilingDataParam) + { + if (tilingParam != INPUT_PARAM_DEFAULT_VALUE) { + tilingDataParam = tilingParam; + } + } + + double GetMTETime(double mknGB, int32_t m0, int32_t n0, double aBindWidth, double bBindWidth) + { + // 预估Matmul计算的MTE2搬运时间 + return DOUBLE * mknGB * (SECOND_TO_MS / ONE_K) * (1.0 / (n0 * aBindWidth) + 1.0 / (m0 * bBindWidth)); + } + + int32_t GetValueFromMKNConditionMap(int32_t m, int32_t k, int32_t n, + int32_t defaultValue, + std::map>> conditionMap) + { + int32_t value = defaultValue; + for (auto iter = conditionMap.cbegin(); iter != conditionMap.cend(); ++iter) { + for (auto &condition : iter->second) { + bool inRange = + m > condition[CONDITION_M_ST] && m <= condition[CONDITION_M_END] && + k > condition[CONDITION_K_ST] && k <= condition[CONDITION_K_END] && + n > condition[CONDITION_N_ST] && n <= condition[CONDITION_N_END]; + if (inRange) { + return iter->first; + } + } + } + return value; + } + + bool Is910B(const ChipName &chipName) + { + return chipName >= ChipName::CHIP_910B1 && chipName <= ChipName::CHIP_910B41; + } + + bool Is91093(const ChipName &chipName) + { + return chipName >= ChipName::CHIP_910_9391 && chipName <= ChipName::CHIP_910_9362; + } + + uint32_t GetTilingKey(const MatMulInfo &mmInfo, CoCTilingData &tilingData) + { + uint32_t tilingKey = static_cast(tilingData.swizzlDirect); // 32 + tilingKey = (static_cast(tilingKey) << 1) + static_cast(mmInfo.transA); // 16 + tilingKey = (static_cast(tilingKey) << 1) + static_cast(mmInfo.transB); // 8 + tilingKey = 
+        tilingKey = (static_cast<uint32_t>(tilingKey) << 1) + static_cast<uint32_t>(mmInfo.withBias); // 2
+        tilingKey = (static_cast<uint32_t>(tilingKey) << 1) + static_cast<uint32_t>(tilingData.splitK); // 1
+        return tilingKey;
+    }
+
+    void DealTilingParamByBuffSize(CoCTilingData &cocTilingData)
+    {
+        auto blockCount = (cocTilingData.is91093 != 0) ? BLOCK_COUNT_3 : MAX_BLOCK_COUNT;
+        int maxPeerMemPerRank =
+            (cocTilingData.bufferSize * 1024 * 1024) / INPUT_DTYPE / cocTilingData.rankSize / blockCount;
+        int maxPValue = maxPeerMemPerRank / cocTilingData.m0 / cocTilingData.k0 / cocTilingData.kLoop;
+        cocTilingData.pValue = ClampValue(cocTilingData.pValue, MIN_P_VALUE, maxPValue);
+
+        if (cocTilingData.m0 == DEFAULT_COL
+            && cocTilingData.pValue * cocTilingData.m0 * cocTilingData.k0 * cocTilingData.kLoop >= maxPeerMemPerRank) {
+            cocTilingData.m0 = DEFAULT_ROW;
+            cocTilingData.n0 = DEFAULT_COL;
+            cocTilingData.mLoop = CeilDev(cocTilingData.m, cocTilingData.m0);
+            cocTilingData.nLoop = CeilDev(cocTilingData.n, cocTilingData.n0);
+        }
+    }
+
+    int ClampValue(int32_t value, int32_t min, int32_t max)
+    {
+        return std::max(min, std::min(value, max));
+    }
+
+    void SetTilingParam(CoCTilingData &cocTilingData, const std::map<int32_t *, TilingValue> &tilingParamMap)
+    {
+        int32_t m = cocTilingData.m;
+        int32_t k = cocTilingData.k;
+        int32_t n = cocTilingData.n;
+
+        for (auto &item : tilingParamMap) {
+            auto value = item.second.value;
+            auto conditionMap = item.second.conditionMap;
+            if (!conditionMap.empty()) {
+                *item.first = GetValueFromMKNConditionMap(m, k, n, value, conditionMap);
+            } else if (value != -1) {
+                *item.first = value;
+            }
+        }
+
+        cocTilingData.ubMoveNum = cocTilingData.ubMoveNum * HALF_KBYTE;
+        if (cocTilingData.m0 >= DEFAULT_ROW) {
+            cocTilingData.k0 = DEFAULT_COL;
+            cocTilingData.n0 = cocTilingData.m0 == DEFAULT_ROW ? DEFAULT_COL : DEFAULT_ROW;
+            cocTilingData.mLoop = CeilDev(cocTilingData.m, cocTilingData.m0);
+            cocTilingData.nLoop = CeilDev(cocTilingData.n, cocTilingData.n0);
+            cocTilingData.kLoop = CeilDev(cocTilingData.k, cocTilingData.k0);
+        }
+    }
+
+    void SetSecondCoreSplitTling(CoCTilingData &cocTilingData)
+    {
+        cocTilingData.extraCommDirect = cocTilingData.commDirect;
+        cocTilingData.extraCommNpuSplit = cocTilingData.commNpuSplit;
+        cocTilingData.extraCommDataSplit = cocTilingData.commDataSplit;
+        cocTilingData.extraLenPerLoop = cocTilingData.lenPerLoop;
+        cocTilingData.extraUbMoveNum = cocTilingData.ubMoveNum;
+    }
+
+    void SetTilingParam2D(CoCTilingData &cocTilingData, const std::map<int32_t *, TilingValue> &tilingParamMap)
+    {
+        SetTilingParam(cocTilingData, tilingParamMap);
+
+        cocTilingData.extraUbMoveNum = cocTilingData.extraUbMoveNum * HALF_KBYTE;
+        cocTilingData.lenPerLoop = cocTilingData.lenPerLoop * cocTilingData.ubMoveNum / DIV_TWO;
+        cocTilingData.extraLenPerLoop = cocTilingData.extraLenPerLoop * cocTilingData.extraUbMoveNum / DIV_TWO;
+    }
+
+    std::map<std::string, bool> GetCoCTilingPowerOfTwoParamMap()
+    {
+        std::map<std::string, bool> powerOfTwoParamMap = {
+            {"commDataSplit", true},
+            {"extraCommDataSplit", true}
+        };
+        return powerOfTwoParamMap;
+    }
+
+    std::map<std::string, int32_t> GetCoCTilingAlignParamMap()
+    {
+        std::map<std::string, int32_t> alignParamMap = {
+            {"m0", BLOCK_SIZE},
+            {"n0", BLOCK_SIZE},
+            {"k0", BLOCK_SIZE},
+            {"ubMoveNum", HALF_KBYTE},
+            {"lenPerLoop", HALF_KBYTE},
+            {"extraUbMoveNum", HALF_KBYTE},
+            {"extraLenPerLoop", HALF_KBYTE}
+        };
+        return alignParamMap;
+    }
+
+    std::vector<std::tuple<std::string, int32_t, int32_t, int32_t>> GetCoCTilingParamCheckList(const CoCTiling &tiling)
+    {
+        std::vector<std::tuple<std::string, int32_t, int32_t, int32_t>> paramCheckList = {
+            {"m0", tiling.m0, BLOCK_SIZE, CUBE_BLOCK_SIZE},
+            {"n0", tiling.n0, BLOCK_SIZE, CUBE_BLOCK_SIZE},
+            {"k0", tiling.k0, CUBE_BLOCK_SIZE, AXES_ALIGN_SIZE},
+            {"swizzlCount", tiling.swizzlCount, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"swizzlDirect", tiling.swizzlDirect, SWIZZLE_DIRECT_ZERO, SWIZZLE_DIRECT_ONE},
+            {"ubMoveNum", tiling.ubMoveNum, HALF_KBYTE, MAX_UB_NUM},
+            {"commNpuSplit", tiling.commNpuSplit, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"commDataSplit", tiling.commDataSplit, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"commDirect", tiling.commDirect, COMM_DATA_DIRECT, COMM_NPU_DIRECT},
+            {"lenPerLoop", tiling.lenPerLoop, HALF_KBYTE, PARAM_CHECK_MAX_VALUE},
+            {"extraUbMoveNum", tiling.extraUbMoveNum, HALF_KBYTE, MAX_UB_NUM},
+            {"extraCommNpuSplit", tiling.extraCommNpuSplit, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"extraCommDataSplit", tiling.extraCommDataSplit, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"extraCommDirect", tiling.extraCommDirect, COMM_DATA_DIRECT, COMM_NPU_DIRECT},
+            {"extraLenPerLoop", tiling.extraLenPerLoop, HALF_KBYTE, PARAM_CHECK_MAX_VALUE},
+            {"splitK", tiling.splitK, PARAM_CHECK_MIN_VALUE_ZERO, PARAM_CHECK_MAX_VALUE},
+            {"write2OtherRank", tiling.write2OtherRank, PARAM_CHECK_MIN_VALUE_ZERO, PARAM_CHECK_MIN_VALUE_ONE},
+            {"withSerialMode", tiling.withSerialMode, PARAM_CHECK_MIN_VALUE_ZERO, PARAM_CHECK_MIN_VALUE_ONE},
+            {"is91093", tiling.is91093, PARAM_CHECK_MIN_VALUE_ZERO, PARAM_CHECK_MIN_VALUE_ONE}
+        };
+        return paramCheckList;
+    }
+
+    bool CheckCoCTiling(const CoCTiling &tiling)
+    {
+        auto powerOfTwoParamMap = GetCoCTilingPowerOfTwoParamMap();
+        auto alignParamMap = GetCoCTilingAlignParamMap();
+        auto paramCheckList = GetCoCTilingParamCheckList(tiling);
+        for (auto &param : paramCheckList) {
+            auto name = std::get<0>(param);
+            auto value = std::get<1>(param);
+            auto min = std::get<2>(param);
+            auto max = std::get<3>(param);
+            if (value == INPUT_PARAM_DEFAULT_VALUE) {
+                continue;
+            }
+            if (!CheckParamScope(name, value, min, max)) {
+                return false;
+            }
+            if (alignParamMap.find(name) != alignParamMap.end()
+                && !CheckParamAlign(name, value, alignParamMap[name])) {
+                return false;
+            }
+            if (powerOfTwoParamMap.find(name) != powerOfTwoParamMap.end()
+                && !CheckParamPowerOfTwo(name, value)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    bool CheckCoCTilingData(const CoCTilingData &tilingData)
+    {
+        if (!CheckCoCTiling(tilingData)) {
+            return false;
+        }
+        std::vector<std::tuple<std::string, int32_t, int32_t, int32_t>> paramCheckList = {
+            {"mLoop", tilingData.mLoop, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"kLoop", tilingData.kLoop, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"nLoop", tilingData.nLoop, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"coreLoop", tilingData.coreLoop, PARAM_CHECK_MIN_VALUE_ONE, PARAM_CHECK_MAX_VALUE},
+            {"tilingKey", tilingData.tilingKey, PARAM_CHECK_MIN_VALUE_ZERO, PARAM_CHECK_MAX_VALUE},
+        };
+        return CheckParamScopeList(paramCheckList);
+    }
+
+    void TransformCoCTiling(const CoCTiling &tiling, CoCTilingData &tilingData)
+    {
+        int* tilingPtr = reinterpret_cast<int *>(const_cast<CoCTiling *>(&tiling));
+        int* tilingDataPtr = reinterpret_cast<int *>(&tilingData);
+        int length = sizeof(tiling) / sizeof(int32_t);
+        for (int i = 0; i < length; i++) {
+            UpdateTilingValue(tilingPtr[i], tilingDataPtr[i]);
+        }
+    }
+
+    void CalTilingParam(const MatMulInfo &mmInfo, CoCTilingData &tilingData)
+    {
+        // Compute loop counts and the tiling key
+        tilingData.mLoop = CeilDev(tilingData.m, tilingData.m0);
+        tilingData.kLoop = CeilDev(tilingData.k, tilingData.k0);
+        tilingData.nLoop = CeilDev(tilingData.n, tilingData.n0);
+        tilingData.coreLoop = tilingData.batchSize * tilingData.mLoop * tilingData.nLoop;
+        tilingData.tilingKey = GetTilingKey(mmInfo, tilingData);
+        // Align the data-movement parameters
+        tilingData.ubMoveNum = RoundNum(tilingData.ubMoveNum, HALF_KBYTE);
+        tilingData.lenPerLoop = RoundNum(tilingData.lenPerLoop, HALF_KBYTE);
+        tilingData.extraUbMoveNum = RoundNum(tilingData.extraUbMoveNum, HALF_KBYTE);
+        tilingData.extraLenPerLoop = RoundNum(tilingData.extraLenPerLoop, HALF_KBYTE);
+    }
+
+    void SetTilingInputParam(const TaskParam &taskParam, CoCTilingData &tilingData)
+    {
+        tilingData.m = taskParam.cocParamDesc.mmInfo.m;
+        tilingData.n = taskParam.cocParamDesc.mmInfo.n;
+        tilingData.k = taskParam.cocParamDesc.mmInfo.k;
+        tilingData.batchSize = taskParam.cocParamDesc.mmInfo.batchSize;
+        tilingData.blockDim = taskParam.blockDim;
+        tilingData.rank = taskParam.rank;
+        tilingData.rankSize = taskParam.rankSize;
+        tilingData.bufferSize = taskParam.bufferSize;
+    }
+
+    void SetTilingData(const TaskParam &taskParam, const CoCTiling &tiling, CoCTilingData &tilingData)
+    {
+        // Copy the user-supplied tiling into the tiling strategy parameters
+        TransformCoCTiling(tiling, tilingData);
+        // Derive mLoop and related parameters from the final tiling strategy
+        CalTilingParam(taskParam.cocParamDesc.mmInfo, tilingData);
+    }
+}
\ No newline at end of file
diff --git a/comm/lcal/src/tools/socket/lcal_sock_exchange.cpp b/comm/lcal/src/tools/socket/lcal_sock_exchange.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ba586718fc12dbd18a5bbb1212dc89502ccb88af
--- /dev/null
+++ b/comm/lcal/src/tools/socket/lcal_sock_exchange.cpp
@@ -0,0 +1,419 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#include "lcal_sock_exchange.h"
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#include <cerrno>
+#include <csignal>
+#include <cstring>
+
+#include <fstream>
+#include <sstream>
+
+#include <set>
+#include <string>
+#include <vector>
+
+#include <acl/acl.h>
+
+using namespace std;
+namespace Lcal {
+const string LCAL_LOCAL_SOCK_IP = "127.0.0.1";
+constexpr uint16_t LCAL_DEFAULT_SOCK_PORT = 10067;
+constexpr uint32_t LCAL_MAX_BACK_LOG = 65535;
+
+int ParseIpAndPort(const char* input, string &ip, uint16_t &port)
+{
+    if (input == nullptr) {
+        return LCAL_INVALID_VALUE;
+    }
+    string inputStr(input);
+    size_t colonPos = inputStr.find(':');
+    if (colonPos == string::npos) {
+        MKI_LOG(ERROR) << "Input string does not contain a colon separating IP and port.";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    ip = inputStr.substr(0, colonPos);
+    std::string portStr = inputStr.substr(colonPos + 1);
+
+    std::istringstream portStream(portStr);
+    portStream >> port;
+    if (portStream.fail() || portStream.bad()) {
+        MKI_LOG(ERROR) << "Invalid port number.";
+        return LCAL_ERROR_INTERNAL;
+    }
+    return LCAL_SUCCESS;
+}
+
+LcalSockExchange::~LcalSockExchange()
+{
+    Cleanup();
+}
+
+LcalSockExchange::LcalSockExchange(int rank, int rankSize, std::vector<int> &rankList, int commDomain)
+    : rank_(rank), rankSize_(rankSize), rankList_(rankList), commDomain_(commDomain)
+{
+}
+
+LcalSockExchange::LcalSockExchange(int rank, int rankSize, LcalUniqueId lcalCommId)
+    : rank_(rank), rankSize_(rankSize)
+{
+    lcalCommId_.uid = lcalCommId;
+}
+
+int LcalSockExchange::GetNodeNum()
+{
+    if (!isInit_ && Prepare() != LCAL_SUCCESS) {
+        return LCAL_ERROR_INTERNAL;
+    }
+    isInit_ = true;
+    const string filePath = "/proc/sys/kernel/random/boot_id";
+    ifstream fileStream(filePath);
+    stringstream buffer;
+    if (fileStream) {
+        buffer << fileStream.rdbuf();
+        fileStream.close();
+    }
+    const std::string uuid = buffer.str();
+    MKI_LOG(DEBUG) << "rank:" << rank_ << " UUID " << uuid;
+
+    set<string> uuidSet {};
+    uuidSet.insert(uuid);
+    int nodeNum = -1;
+    if (IsServer()) {
+        // Receive into a mutable buffer; writing through uuid.data() on a const string is undefined behaviour.
+        string peerUuid = uuid;
+        for (int i = 1; i < rankSize_; ++i) {
+            if (Recv(clientFds_[i], &peerUuid[0], peerUuid.size(), 0) <= 0) {
+                MKI_LOG(ERROR) << "Server side recv rank " << i << " buffer failed";
+                return LCAL_ERROR_INTERNAL;
+            }
+            uuidSet.insert(peerUuid);
+        }
+        nodeNum = static_cast<int>(uuidSet.size());
+        for (int i = 1; i < rankSize_; ++i) {
+            if (Send(clientFds_[i], &nodeNum, sizeof(int), 0) <= 0) {
+                MKI_LOG(ERROR) << "Server side send rank " << i << " buffer failed";
+                return LCAL_ERROR_INTERNAL;
+            }
+        }
+    } else {
+        if (Send(fd_, uuid.data(), uuid.size(), 0) <= 0) {
+            MKI_LOG(ERROR) << "Client side " << rank_ << " send buffer failed";
+            return LCAL_ERROR_INTERNAL;
+        }
+        if (Recv(fd_, &nodeNum, sizeof(int), 0) <= 0) {
+            MKI_LOG(ERROR) << "Client side " << rank_ << " recv buffer failed ";
+            return LCAL_ERROR_INTERNAL;
+        }
+    }
+    return nodeNum;
+}
+
+void LcalSockExchange::GetIpAndPort()
+{
+    const char* env = Mki::GetEnv("LCAL_COMM_ID");
+
+    if (env == nullptr || ParseIpAndPort(env, ip_, port_) != LCAL_SUCCESS) {
+        ip_ = LCAL_LOCAL_SOCK_IP;
+        port_ = LCAL_DEFAULT_SOCK_PORT;
+    }
+    port_ += commDomain_;
+    lcalCommId_.handle.addr.sin.sin_family = AF_INET;
+    lcalCommId_.handle.addr.sin.sin_addr.s_addr = inet_addr(LCAL_LOCAL_SOCK_IP.c_str());
+    lcalCommId_.handle.addr.sin.sin_port = htons(port_);
+    MKI_LOG(DEBUG) << "curRank: " << rank_ << " commDomain: " << commDomain_ << " ip: " << ip_ << " port: " << port_;
+}
+
+int LcalSockExchange::Prepare()
+{
+    if (lcalCommId_.handle.magic != LCAL_MAGIC) {
+        GetIpAndPort();
+    }
+    if (!IsServer()) {
+        if (ip_ != LCAL_LOCAL_SOCK_IP) {
+            MKI_LOG(ERROR) << "Multi-machine is not supported at the moment";
+            return LCAL_ERROR_INTERNAL;
+        }
+        return Connect();
+    }
+
+    clientFds_.resize(rankSize_, -1);
+    if (Listen() != LCAL_SUCCESS) {
+        MKI_LOG(ERROR) << "Listen Failed!";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    if (Accept() != LCAL_SUCCESS) {
+        MKI_LOG(ERROR) << "Accept Failed!";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    return LCAL_SUCCESS;
+}
+
+int LcalSockExchange::Listen()
+{
+    fd_ = socket(AF_INET, SOCK_STREAM, 0);
+    if (fd_ < 0) {
+        MKI_LOG(ERROR) << "Server side create socket failed";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    int reuse = 1;
+    if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(int)) < 0) {
+        MKI_LOG(ERROR) << "Server side set reuseaddr failed";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    struct sockaddr *addrPtr = &lcalCommId_.handle.addr.sa;
+    if (bind(fd_, addrPtr, sizeof(struct sockaddr)) < 0) {
+        MKI_LOG(ERROR) << "Server side bind " << ntohs(lcalCommId_.handle.addr.sin.sin_port) << " failed";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    if (listen(fd_, LCAL_MAX_BACK_LOG) < 0) {
+        MKI_LOG(ERROR) << "Server side listen " << ntohs(lcalCommId_.handle.addr.sin.sin_port) << " failed";
+        return LCAL_ERROR_INTERNAL;
+    }
+    MKI_LOG(INFO) << "The server is listening! ip: " << inet_ntoa(lcalCommId_.handle.addr.sin.sin_addr)
+                  << " port: " << ntohs(lcalCommId_.handle.addr.sin.sin_port);
+
+    return LCAL_SUCCESS;
+}
+
+int LcalSockExchange::AcceptConnection(int fd, sockaddr_in& clientAddr, socklen_t *sinSize) const
+{
+    int clientFd;
+    LcalSocketAddress clientAddrPtr;
+    clientAddrPtr.sin = clientAddr;
+
+    do {
+        clientFd = accept(fd, &clientAddrPtr.sa, sinSize);
+        if (clientFd < 0) {
+            if (!CheckErrno(errno)) {
+                MKI_LOG(ERROR) << "Server side accept failed" << strerror(errno);
+                return -1;
+            }
+            MKI_LOG(DEBUG) << "accept failed: " << strerror(errno);
+            continue;
+        }
+        break;
+    } while (true);
+
+    return clientFd;
+}
+
+int LcalSockExchange::Accept()
+{
+    struct sockaddr_in clientAddr;
+    socklen_t sinSize = sizeof(struct sockaddr_in);
+
+    for (int i = 1; i < rankSize_; ++i) {
+        int fd = AcceptConnection(fd_, clientAddr, &sinSize);
+        if (fd < 0) {
+            MKI_LOG(ERROR) << "AcceptConnection failed";
+            return LCAL_ERROR_INTERNAL;
+        }
+
+        int rank = 0;
+        if (Recv(fd, &rank, sizeof(rank), 0) <= 0) {
+            MKI_LOG(ERROR) << "Server side recv rank id failed";
+            return LCAL_ERROR_INTERNAL;
+        }
+
+        if (rank >= rankSize_ || rank <= 0 || clientFds_[rank] >= 0) {
+            MKI_LOG(ERROR) << "Server side recv invalid rank id " << rank;
+            return LCAL_ERROR_INTERNAL;
+        }
+
+        MKI_LOG(DEBUG) << "Server side recv rank id " << rank;
+        clientFds_[rank] = fd;
+    }
+
+    return LCAL_SUCCESS;
+}
+
+void LcalSockExchange::Close(int &fd) const
+{
+    if (fd == -1) {
+        return;
+    }
+
+    if (close(fd) < 0) {
+        MKI_LOG(WARN) << "failed to close fd:" << fd;
+        return;
+    }
+
+    fd = -1;
+}
+
+int LcalSockExchange::Connect()
+{
+    MKI_LOG(DEBUG) << "Client side " << rank_ << " begin to connect";
+
+    fd_ = socket(AF_INET, SOCK_STREAM, 0);
+    if (fd_ < 0) {
+        MKI_LOG(ERROR) << "Client side " << rank_ << " create socket failed";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    int sleepTimeS = 1;
+    int maxRetryCount = 180;
+    int retryCount = 0;
+    bool success = false;
+
+    struct sockaddr *addrPtr = &lcalCommId_.handle.addr.sa;
+    while (retryCount < maxRetryCount) {
+        if (connect(fd_, addrPtr, sizeof(struct sockaddr)) < 0) {
+            if (errno == ECONNREFUSED) {
+                MKI_LOG(DEBUG) << "Client side " << rank_ << " try connect " << (retryCount + 1) << " times refused";
+                retryCount++;
+                sleep(sleepTimeS);
+                continue;
+            }
+            if (errno != EINTR) {
+                MKI_LOG(ERROR) << "Client side " << rank_ << " connect failed: " << strerror(errno);
+                break;
+            }
+            MKI_LOG(DEBUG) << "Client side " << rank_ << " try connect failed: " << strerror(errno);
+            continue;
+        }
+        success = true;
+        break;
+    }
+
+    if (!success) {
+        MKI_LOG(ERROR) << "Client side " << rank_ << " connect failed";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    if (Send(fd_, &rank_, sizeof(rank_), 0) <= 0) {
+        MKI_LOG(ERROR) << "Client side " << rank_ << " send rank failed";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    return LCAL_SUCCESS;
+}
+
+bool LcalSockExchange::IsServer() const
+{
+    return rank_ == 0;
+}
+
+void LcalSockExchange::Cleanup()
+{
+    if (fd_ >= 0) {
+        Close(fd_);
+    }
+
+    if (clientFds_.empty()) {
+        return;
+    }
+
+    for (int i = 1; i < rankSize_; ++i) {
+        if (clientFds_[i] >= 0) {
+            Close(clientFds_[i]);
+        }
+    }
+    if (pid_ > 0) {
+        kill(pid_, SIGINT);
+        int status;
+        waitpid(pid_, &status, 0);
+        MKI_LOG(DEBUG) << "child process resources cleaned up";
+    }
+}
+
+int GetAddrFromString(LcalSocketAddress* ua, const char* ipPortPair)
+{
+    std::string ip;
+    uint16_t port;
+    int ret = ParseIpAndPort(ipPortPair, ip, port);
+    if (ret != LCAL_SUCCESS) {
+        MKI_LOG(ERROR) << "lcal ParseIpAndPort failed!";
+        return LCAL_ERROR_INTERNAL;
+    }
+    ua->sin.sin_family = AF_INET;
+    ua->sin.sin_addr.s_addr = inet_addr(ip.c_str());
+    ua->sin.sin_port = htons(port);
+    return LCAL_SUCCESS;
+}
+
+int BootstrapGetServerIp(LcalSocketAddress& handle)
+{
+    char hostname[256];
+
+    if (gethostname(hostname, sizeof(hostname)) < 0) {
+        MKI_LOG(ERROR) << "ERROR: Failed to get hostname.";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    struct hostent *hostEntry = gethostbyname(hostname);
+    if (hostEntry == nullptr) {
+        MKI_LOG(ERROR) << "ERROR: Failed to get host entry.";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    const char* ip = inet_ntoa(*reinterpret_cast<struct in_addr *>(hostEntry->h_addr_list[0]));
+    if (ip == nullptr) {
+        MKI_LOG(ERROR) << "ERROR: Failed to convert IP address.";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    auto ret = memset_s(&handle, sizeof(handle), 0, sizeof(handle));
+    if (ret != EOK) {
+        MKI_LOG(ERROR) << "Failed to memset_s handle in BootstrapGetServerIp.";
+        return LCAL_ERROR_INTERNAL;
+    }
+    handle.sin.sin_family = AF_INET;
+    handle.sin.sin_addr.s_addr = inet_addr(ip);
+    handle.sin.sin_port = 0;
+
+    return LCAL_SUCCESS;
+}
+
+int BootstrapGetUniqueId(struct LcalBootstrapHandle& handle, int commDomain)
+{
+    auto ret = memset_s(&handle, sizeof(LcalBootstrapHandle), 0, sizeof(LcalBootstrapHandle));
+    if (ret != EOK) {
+        MKI_LOG(ERROR) << "Failed to memset_s handle in BootstrapGetUniqueId.";
+        return LCAL_ERROR_INTERNAL;
+    }
+
+    const char* env = Mki::GetEnv("LCAL_COMM_ID");
+    if (env) {
+        MKI_LOG(INFO) << "LCAL_COMM_ID set by environment to " << env;
+        if (GetAddrFromString(&handle.addr, env) != LCAL_SUCCESS) {
+            MKI_LOG(WARN) << ("Invalid LCAL_COMM_ID, please use format: <ip>:<port>");
+            return LCAL_INVALID_VALUE;
+        }
+    } else {
+        int bootRet = BootstrapGetServerIp(handle.addr);
+        if (bootRet != LCAL_SUCCESS) {
+            MKI_LOG(ERROR) << "lcal BootstrapGetIpPort failed!";
+            return LCAL_ERROR_INTERNAL;
+        }
+    }
+    int dev;
+    int aclRet = aclrtGetDevice(&dev);
+    if (aclRet != ACL_SUCCESS) {
+        MKI_LOG(ERROR) << "ERROR: GetDevice.";
+        return LCAL_ERROR_INTERNAL;
+    }
+    handle.addr.sin.sin_port = htons(LCAL_DEFAULT_SOCK_PORT + dev + commDomain);
+    handle.magic = LCAL_MAGIC;
+
+    return LCAL_SUCCESS;
+}
+}
\ No newline at end of file
diff --git a/comm/lcal/src/tools/socket/lcal_sock_exchange.h b/comm/lcal/src/tools/socket/lcal_sock_exchange.h
new file mode 100644
index 0000000000000000000000000000000000000000..90f997ac550b74b294891f04e52a19b062abce81
--- /dev/null
+++ b/comm/lcal/src/tools/socket/lcal_sock_exchange.h
@@ -0,0 +1,179 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef LCCL_SOCK_EXCHANGE_H
+#define LCCL_SOCK_EXCHANGE_H
+
+#include <netinet/in.h>
+#include <sys/socket.h>
+#include <cerrno>
+#include <cstring>
+
+#include <cstdint>
+#include <securec.h>
+#include <string>
+#include <vector>
+
+#include "mki/utils/log/log.h"
+
+#include "lcal_types.h"
+#include "lcal_api.h"
+
+namespace Lcal {
+
+union LcalSocketAddress {
+    struct sockaddr sa;
+    struct sockaddr_in sin;
+    struct sockaddr_in6 sin6;
+};
+
+constexpr uint64_t LCAL_MAGIC = 0xdddd0000dddd0000;
+
+struct LcalBootstrapHandle {
+    uint64_t magic;
+    union LcalSocketAddress addr;
+};
+
+union LcalBootstrap {
+    LcalBootstrapHandle handle;
+    LcalUniqueId uid;
+};
+
+int BootstrapGetUniqueId(LcalBootstrapHandle &handle, int commDomain);
+
+class LcalSockExchange {
+public:
+    LcalSockExchange(int rank, int rankSize, std::vector<int> &rankList, int commDomain);
+    LcalSockExchange(int rank, int rankSize, LcalUniqueId lcalCommId);
+    ~LcalSockExchange();
+
+    template <typename T> int AllGather(const T *sendBuf, size_t sendCount, T *recvBuf)
+    {
+        if (!isInit_ && Prepare() != LCAL_SUCCESS) {
+            return LCAL_ERROR_INTERNAL;
+        }
+        isInit_ = true;
+
+        if (!IsServer()) {
+            return ClientSendRecv(sendBuf, sendCount, recvBuf);
+        } else {
+            return ServerRecvSend(sendBuf, sendCount, recvBuf);
+        }
+    }
+
+    int GetNodeNum();
+
+    static bool CheckValid(LcalUniqueId lcalCommId)
+    {
+        LcalBootstrap id {};
+        id.uid = lcalCommId;
+        return id.handle.magic == LCAL_MAGIC;
+    }
+
+private:
+    void GetIpAndPort();
+    int Prepare();
+    int Listen();
+    int Accept();
+    int StartSecureTunnel();
+    void Close(int &fd) const;
+    int Connect();
+    int AcceptConnection(int fd, sockaddr_in &clientAddr, socklen_t *sinSize) const;
+    void Cleanup();
+    bool IsServer() const;
+    static bool CheckErrno(int ioErrno)
+    {
+        return ((ioErrno == EAGAIN) || (ioErrno == EWOULDBLOCK) || (ioErrno == EINTR));
+    }
+
+    template <typename T> int Send(int fd, const T *sendBuf, size_t sendSize, int flag) const
+    {
+        do {
+            auto ret = send(fd, sendBuf, sendSize, flag);
+            if (ret < 0) {
+                if (CheckErrno(errno)) {
+                    // Retryable errno (EAGAIN/EWOULDBLOCK/EINTR): log at DEBUG and try again.
+                    MKI_LOG(DEBUG) << "send failed: " << strerror(errno);
+                    continue;
+                }
+                MKI_LOG(ERROR) << "Send failed: " << strerror(errno);
+            }
+            return ret;
+        } while (true);
+    }
+
+    template <typename T> int Recv(int fd, T *recvBuf, size_t recvSize, int flag) const
+    {
+        do {
+            auto ret = recv(fd, recvBuf, recvSize, flag);
+            if (ret < 0) {
+                if (CheckErrno(errno)) {
+                    // Retryable errno: log at DEBUG and try again.
+                    MKI_LOG(DEBUG) << "recv failed: " << strerror(errno);
+                    continue;
+                }
+                MKI_LOG(ERROR) << "recv failed: " << strerror(errno);
+            }
+            return ret;
+        } while (true);
+    }
+
+    template <typename T> int ClientSendRecv(const T *sendBuf, size_t sendSize, T *recvBuf)
+    {
+        if (Send(fd_, sendBuf, sendSize * sizeof(T), 0) <= 0) {
+            MKI_LOG(ERROR) << "Client side " << rank_ << " send buffer failed";
+            return LCAL_ERROR_INTERNAL;
+        }
+
+        if (Recv(fd_, recvBuf, sendSize * rankSize_ * sizeof(T), MSG_WAITALL) <= 0) {
+            MKI_LOG(ERROR) << "Client side " << rank_ << " recv buffer failed ";
+            return LCAL_ERROR_INTERNAL;
+        }
+
+        return LCAL_SUCCESS;
+    }
+
+    template <typename T> int ServerRecvSend(const T *sendBuf, size_t sendSize, T *recvBuf)
+    {
+        auto ret = memcpy_s(recvBuf, sendSize * sizeof(T), sendBuf, sendSize * sizeof(T));
+        if (ret != EOK) {
+            MKI_LOG(ERROR) << "Failed to copy sendBuf to recvBuf.";
+            return LCAL_ERROR_INTERNAL;
+        }
+
+        for (int i = 1; i < rankSize_; ++i) {
+            if (Recv(clientFds_[i], recvBuf + i * sendSize, sendSize * sizeof(T), MSG_WAITALL) <= 0) {
+                MKI_LOG(ERROR) << "Server side recv rank " << i << " buffer failed";
+                return LCAL_ERROR_INTERNAL;
+            }
+        }
+
+        for (int i = 1; i < rankSize_; ++i) {
+            if (Send(clientFds_[i], recvBuf, sendSize * rankSize_ * sizeof(T), 0) <= 0) {
+                MKI_LOG(ERROR) << "Server side send rank " << i << " buffer failed";
+                return LCAL_ERROR_INTERNAL;
+            }
+        }
+
+        return LCAL_SUCCESS;
+    }
+
+    pid_t pid_ = 0;
+    int rank_ = 0;
+    int rankSize_ = 0;
+    int fd_ = -1;
+    std::vector<int> clientFds_ = {};
+    bool isInit_ = false;
+    std::vector<int> rankList_ = {};
+    int commDomain_ = -1;
+    std::string ip_ = "";
+    uint16_t port_ = 0;
+    LcalBootstrap lcalCommId_ = {};
+};
+}
+
+#endif
\ No newline at end of file
diff --git a/src/include/atb/runner/lcal_runner.h b/src/include/atb/runner/lcal_runner.h
index f27f09f76d5dfe5b7089cbfac2c06a17892ad57f..f684ebd541ccac6e5d8e92de5227645bcb30ce59 100644
--- a/src/include/atb/runner/lcal_runner.h
+++ b/src/include/atb/runner/lcal_runner.h
@@ -10,7 +10,7 @@
 #ifndef ATB_LCAL_RUNNER_H
 #define ATB_LCAL_RUNNER_H
-#include
+#include
 #include
 #include
 #include "atb/runner/runner.h"
diff --git a/src/include/atb/runner/lccl_runner.h b/src/include/atb/runner/lccl_runner.h
index c96188cd4bea0e563b5a074533c72603bd9fcf1d..2f2adaaabbd5764741a4f5a1de00878bcb019550 100644
--- a/src/include/atb/runner/lccl_runner.h
+++ b/src/include/atb/runner/lccl_runner.h
@@ -10,7 +10,7 @@
 #ifndef ATB_LCCL_RUNNER_H
 #define ATB_LCCL_RUNNER_H
-#include
+#include
 #include
 #include "atb/runner/runner.h"
 #include "atb/runner/lcal_runner.h"
diff --git a/src/include/atb/runner/lcoc_runner.h b/src/include/atb/runner/lcoc_runner.h
index 0dba0e6e20eb4ee9ab00f9fb6d51fa2ba99fdc08..b181bcfd19a5360a960a01abd67640b4c2476e9e 100644
--- a/src/include/atb/runner/lcoc_runner.h
+++ b/src/include/atb/runner/lcoc_runner.h
@@ -10,7 +10,7 @@
 #ifndef ATB_LCOC_RUNNER_H
 #define ATB_LCOC_RUNNER_H
-#include
+#include
 #include
 #include "atb/runner/runner.h"
 #include "atb/runner/lcal_runner.h"
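
Reviewer's addendum (not part of the patch): every tiling table in this change follows one lookup scheme. A map entry pairs a candidate parameter value with a list of (m, k, n) ranges {mStart, mEnd, kStart, kEnd, nStart, nEnd}; bounds are half-open (start, end], with -1 meaning "no lower bound" and 2147483647 "no upper bound"; the first range containing the problem shape wins, otherwise a per-op default applies. The standalone C++ sketch below illustrates that selection logic and the GetTilingKey bit-packing. The condition-map type mirrors the reconstruction used above (the original template arguments were lost in extraction), and the helper names and the main() driver are hypothetical.

// demo_tiling_lookup.cpp -- illustrative sketch only; compile with any C++14 compiler.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

namespace {
constexpr int32_t NO_UPPER_BOUND = 2147483647; // sentinel used throughout the tables

// Mirrors GetValueFromMKNConditionMap: first (m, k, n) range that contains the shape wins.
int32_t LookupByMkn(int32_t m, int32_t k, int32_t n, int32_t defaultValue,
    const std::map<double, std::vector<std::vector<int32_t>>> &conditionMap)
{
    for (const auto &entry : conditionMap) {
        for (const auto &c : entry.second) {
            // Each row is {mStart, mEnd, kStart, kEnd, nStart, nEnd}; bounds are (start, end].
            if (m > c[0] && m <= c[1] && k > c[2] && k <= c[3] && n > c[4] && n <= c[5]) {
                return static_cast<int32_t>(entry.first);
            }
        }
    }
    return defaultValue;
}

// Mirrors GetTilingKey: six flags packed into the low bits, swizzle direction on top (weight 32).
uint32_t PackTilingKey(bool swizzlDirect, bool transA, bool transB, bool isInt8, bool withBias, bool splitK)
{
    uint32_t key = swizzlDirect; // 32
    key = (key << 1) + transA;   // 16
    key = (key << 1) + transB;   // 8
    key = (key << 1) + isInt8;   // 4
    key = (key << 1) + withBias; // 2
    key = (key << 1) + splitK;   // 1
    return key;
}
} // namespace

int main()
{
    // Toy table in the same shape as the g_...M0Map tables: m0 = 128 for m <= 1536, else 256.
    std::map<double, std::vector<std::vector<int32_t>>> m0Map = {
        {128, {{-1, 1536, -1, NO_UPPER_BOUND, -1, NO_UPPER_BOUND}}},
        {256, {{1536, NO_UPPER_BOUND, -1, NO_UPPER_BOUND, -1, NO_UPPER_BOUND}}},
    };
    std::cout << LookupByMkn(1024, 4096, 4096, 128, m0Map) << '\n'; // prints 128
    std::cout << LookupByMkn(8192, 4096, 4096, 128, m0Map) << '\n'; // prints 256
    std::cout << PackTilingKey(true, false, false, true, false, false) << '\n'; // 32 + 4 = 36
    return 0;
}

Because the ubMoveNum tables use keys such as 8.0 while the m0/pValue tables use integer keys, a single floating-point key type (double above) lets one map type serve all of the tables; integer keys like 128 convert exactly. This is a design inference from the values in the patch, not a confirmed detail of the original source.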