From 9d0f8c62c3f5848e1cada06a7151ab7ed0d93892 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E6=9D=A8?=
Date: Thu, 21 Aug 2025 10:56:12 +0800
Subject: [PATCH] allgather opt

---
 .../AclNNInvocation/README.md                 |  12 +-
 .../AclNNInvocation/scripts/gen_data.py       |  20 +--
 .../AclNNInvocation/src/main.cpp              |   8 +-
 .../AclNNInvocation/src/op_runner.cpp         |   1 -
 .../op_host/all_gather_matmul_custom.cpp      | 147 +++++-------------
 .../op_kernel/all_gather_matmul_custom.cpp    |  55 +++----
 .../all_gather_matmul_custom_tiling.h         |  13 +-
 .../op_kernel/gather_mm.h                     |  27 ++--
 .../op_kernel/mc2_matmul_block.h              |  57 +------
 .../op_kernel/mc2_matmul_compute.h            |  74 +++------
 .../21_all_gather_matmul_custom/README.md     |  18 +--
 .../all_gather_matmul_custom.json             |  55 +------
 .../all_gather_matmul_demo_def.h              |  12 --
 13 files changed, 136 insertions(+), 363 deletions(-)

diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/README.md b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/README.md
index b6ddd2e8c..40cfc9d50 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/README.md
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/README.md
@@ -26,17 +26,11 @@
   ```cpp
   // get the workspace size used by the operator
   aclnnStatus aclnnAllGatherMatmulCustomGetWorkspaceSize(
-      const aclTensor *x1,
-      const aclTensor *x2,
+      const aclTensor *a,
+      const aclTensor *b,
       const aclTensor *biasOptional,
       char *group,
-      bool isTransAOptional,
-      bool isTransBOptional,
-      int64_t gatherIndexOptional,
-      int64_t commTurnOptional,
-      int64_t rankSizeOptional,
-      bool isGatherOutOptional,
-      const aclTensor *yOut,
+      const aclTensor *cOut,
       const aclTensor *gatherOutOut,
       uint64_t *workspaceSize,
       aclOpExecutor **executor);

diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/scripts/gen_data.py b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/scripts/gen_data.py
index cd0d33c99..bd6cea60b 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/scripts/gen_data.py
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/scripts/gen_data.py
@@ -15,20 +15,20 @@ def gen_golden_data_simple():
     if not os.path.exists("output"):
         os.mkdir("output")
 
-    input_x1 = []
-    input_x2 = []
+    input_a = []
+    input_b = []
     for i in range(rank_dim):
-        x1 = np.random.uniform(-3, 3, [rank_m, rank_k]).astype(np.float16)
-        x2 = np.random.uniform(-3, 3, [rank_k, rank_n]).astype(np.float16)
-        x1.tofile("./input/input_x1_{}.bin".format(i))
-        x2.tofile("./input/input_x2_{}.bin".format(i))
-        input_x1.append(x1)
-        input_x2.append(x2)
+        a = np.random.uniform(-3, 3, [rank_m, rank_k]).astype(np.float16)
+        b = np.random.uniform(-3, 3, [rank_k, rank_n]).astype(np.float16)
+        a.tofile("./input/input_a_{}.bin".format(i))
+        b.tofile("./input/input_b_{}.bin".format(i))
+        input_a.append(a)
+        input_b.append(b)
 
-    golden_gather_out = np.concatenate(input_x1, axis=0)
+    golden_gather_out = np.concatenate(input_a, axis=0)
     for i in range(rank_dim):
         golden_gather_out.tofile("./output/golden_gather_out_{}.bin".format(i))
-        out = np.matmul(golden_gather_out.astype(np.float32), input_x2[i].astype(np.float32)).astype(np.float16)
+        out = np.matmul(golden_gather_out.astype(np.float32), input_b[i].astype(np.float32)).astype(np.float16)
         out.tofile("./output/golden_out_{}.bin".format(i))
 
 if __name__ == "__main__":
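For orientation alongside gen_data.py: once the demo has run, the per-rank device outputs can be compared against the golden files written above. A minimal sketch, assuming the demo dumps per-rank results as `output/output_c_{i}.bin` — that file name and the tolerance are illustrative assumptions, not taken from this repo:

```python
import numpy as np

RANK_DIM = 8  # demo shape from all_gather_matmul_demo_def.h

for i in range(RANK_DIM):
    golden = np.fromfile(f"output/golden_out_{i}.bin", dtype=np.float16)
    actual = np.fromfile(f"output/output_c_{i}.bin", dtype=np.float16)  # hypothetical dump name
    golden32, actual32 = golden.astype(np.float32), actual.astype(np.float32)
    rel_err = np.abs(actual32 - golden32) / (np.abs(golden32) + 1e-6)
    # fp16 accumulation over K=5120 is noisy; ~1e-3 mean relative error is a reasonable bar
    print(f"rank {i}: mean rel err = {rel_err.mean():.2e}, ok = {rel_err.mean() < 1e-3}")
```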
diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/main.cpp b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/main.cpp
index dafa01311..86ff36642 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/main.cpp
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/main.cpp
@@ -52,10 +52,10 @@ OperatorDesc CreateOpDesc()
 bool SetInputData(OpRunner &runner, uint32_t rankId)
 {
     size_t fileSize = 0;
-    ReadFile("../input/input_x1_" + std::to_string(rankId) + ".bin", fileSize,
-             runner.GetInputBuffer(0), runner.GetInputSize(0)); // Read input_x1 file
-    ReadFile("../input/input_x2_" + std::to_string(rankId) + ".bin", fileSize,
-             runner.GetInputBuffer(1), runner.GetInputSize(1)); // Read input_x2 file
+    ReadFile("../input/input_a_" + std::to_string(rankId) + ".bin", fileSize,
+             runner.GetInputBuffer(0), runner.GetInputSize(0)); // Read input_a file
+    ReadFile("../input/input_b_" + std::to_string(rankId) + ".bin", fileSize,
+             runner.GetInputBuffer(1), runner.GetInputSize(1)); // Read input_b file
     ReadFile("../input/input_bias_" + std::to_string(rankId) + ".bin", fileSize,
              runner.GetInputBuffer(INPUT_BUFFER_BIAS), runner.GetInputSize(INPUT_BUFFER_BIAS));
     INFO_LOG("Set input success");

diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/op_runner.cpp b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/op_runner.cpp
index 0d40007d4..5aa62934f 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/op_runner.cpp
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AclNNInvocation/src/op_runner.cpp
@@ -302,7 +302,6 @@ bool OpRunner::RunOp(std::string group, aclrtStream stream)
     size_t workspaceSize = 0;
     aclOpExecutor *handle = nullptr;
     auto ret = aclnnAllGatherMatmulCustomGetWorkspaceSize(inputTensor_[0], inputTensor_[1], bias, (char*)group.c_str(),
-        IS_TRANS_A, IS_TRANS_B, GATHER_INDEX, COMM_TURN, RANK_DIM, IS_GATHER_OUT,
         outputTensor_[0], outputTensor_[1], &workspaceSize, &handle);
     if (ret != ACL_SUCCESS) {
         ERROR_LOG("Get Operator Workspace failed. error code is %d", static_cast<int>(ret));

diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_host/all_gather_matmul_custom.cpp b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_host/all_gather_matmul_custom.cpp
index 2166672ad..9916b7b9d 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_host/all_gather_matmul_custom.cpp
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_host/all_gather_matmul_custom.cpp
@@ -20,39 +20,11 @@
 #define ERROR_LOG(fmt, args...) fprintf(stderr, "[ERROR] " fmt "\n", ##args)
 
 // tiling
+namespace {
 constexpr int32_t TILE_M = 448;
 constexpr uint32_t HCCL_CMD_ALLGATHER = 6;
-constexpr uint32_t FP16_BF16_SIZE = 2;
-constexpr uint32_t CUSTOM_TILING_KEY = 100; // full mesh + no nd2nz + no cast bias
+constexpr uint32_t HCCL_REDUCE_SUM = 0;
 constexpr int32_t L1_BUFFER_SIZE = 512 * 1024;
-
-namespace {
-enum class HcclDataType {
-    HCCL_DATA_TYPE_INT8 = 0,    /* *< int8 */
-    HCCL_DATA_TYPE_INT16 = 1,   /* *< int16 */
-    HCCL_DATA_TYPE_INT32 = 2,   /* *< int32 */
-    HCCL_DATA_TYPE_FP16 = 3,    /* *< fp16 */
-    HCCL_DATA_TYPE_FP32 = 4,    /* *< fp32 */
-    HCCL_DATA_TYPE_INT64 = 5,   /* *< int64 */
-    HCCL_DATA_TYPE_UINT64 = 6,  /* *< uint64 */
-    HCCL_DATA_TYPE_UINT8 = 7,   /* *< uint8 */
-    HCCL_DATA_TYPE_UINT16 = 8,  /* *< uint16 */
-    HCCL_DATA_TYPE_UINT32 = 9,  /* *< uint32 */
-    HCCL_DATA_TYPE_FP64 = 10,   /* *< fp64 */
-    HCCL_DATA_TYPE_BFP16 = 11,  /* *< bfp16 */
-    HCCL_DATA_TYPE_RESERVED     /* *< reserved */
-};
-
-const std::set<ge::DataType> SUPPORT_DTYPE = { ge::DT_FLOAT16, ge::DT_BF16 };
-const std::map<ge::DataType, matmul_tiling::DataType> DTYPE_MAP = {
-    {ge::DT_FLOAT16, matmul_tiling::DataType::DT_FLOAT16},
-    {ge::DT_BF16, matmul_tiling::DataType::DT_BFLOAT16}
-};
-
-const std::map<ge::DataType, HcclDataType> HCCL_DATA_TYPE_MAP = {
-    {ge::DataType::DT_FLOAT16, HcclDataType::HCCL_DATA_TYPE_FP16},
-    {ge::DataType::DT_BF16, HcclDataType::HCCL_DATA_TYPE_BFP16}
-};
 }
 
 static ge::graphStatus AllGatherMatmulCustomTilingFunc(gert::TilingContext *context)
@@ -61,43 +33,31 @@ static ge::graphStatus AllGatherMatmulCustomTilingFunc(gert::TilingContext *cont
     auto aicCoreNum = ascendcPlatform.GetCoreNumAic();
 
     // get attrs
-    uint32_t index = 0U;
-    const char *group = context->GetAttrs()->GetAttrPointer<char>(index++);
-    bool isTransA = *(context->GetAttrs()->GetAttrPointer<bool>(index++));
-    bool isTransB = *(context->GetAttrs()->GetAttrPointer<bool>(index++));
-    int gatherIndex = *(context->GetAttrs()->GetAttrPointer<int>(index++));
-    int commTurn = *(context->GetAttrs()->GetAttrPointer<int>(index++));
-    int rankSize = *(context->GetAttrs()->GetAttrPointer<int>(index++));
-    bool isGatherOut = *(context->GetAttrs()->GetAttrPointer<bool>(index++));
-    INFO_LOG("Group %s, isTransA is %d, isTransB is %d, gatherIndex is %d, commTurn is %d, rankSize %d, isGatherOut %d",
-        group, isTransA, isTransB, gatherIndex, commTurn, rankSize, isGatherOut);
+    const char *group = context->GetAttrs()->GetAttrPointer<char>(0);
+    INFO_LOG("Group %s", group);
 
     // get shape [[4096/8,5120], [5120,640]] fp16
     uint64_t rankM = context->GetInputShape(0)->GetStorageShape().GetDim(0);
     uint64_t rankK = context->GetInputShape(0)->GetStorageShape().GetDim(1);
-    uint64_t rankN = isTransB ?
-        context->GetInputShape(1)->GetStorageShape().GetDim(0) : context->GetInputShape(1)->GetStorageShape().GetDim(1);
+    uint64_t rankN = context->GetInputShape(1)->GetStorageShape().GetDim(1);
     INFO_LOG("RankM %lu, rankK %lu, rankN %lu", rankM, rankK, rankN);
 
     // get dtype
     auto aType = context->GetInputTensor(0)->GetDataType();
     auto bType = context->GetInputTensor(1)->GetDataType();
     auto cType = aType;
-    if (SUPPORT_DTYPE.find(aType) == SUPPORT_DTYPE.cend() || SUPPORT_DTYPE.find(bType) == SUPPORT_DTYPE.cend()) {
-        ERROR_LOG("Dtype of a %d or b %d is unsupported", static_cast<int>(aType), static_cast<int>(bType));
+    if (aType != ge::DT_FLOAT16 || bType != ge::DT_FLOAT16) {
+        ERROR_LOG("Dtype is unsupported");
         return ge::GRAPH_FAILED;
     }
 
-    // set block dim & tiling key
+    // set block dim
     context->SetBlockDim(ascendcPlatform.GetCoreNumAic());
-    context->SetTilingKey(CUSTOM_TILING_KEY);
 
     // set work space size
-    size_t nd2nzLen = (rankN + 16) * (rankK + 16) * FP16_BF16_SIZE;
     size_t systemWorkspaceSize = static_cast<size_t>(ascendcPlatform.GetLibApiWorkSpaceSize());
-    size_t workspaceSize = nd2nzLen + 0 + 0 + 0 + systemWorkspaceSize; // nd2nzLen + gmcFloat + gatherLen + biasLen + 16M
-    size_t *currentWorkspace = context->GetWorkspaceSizes(1);
-    currentWorkspace[0] = workspaceSize;
+    size_t *workspaceSizes = context->GetWorkspaceSizes(1); // get the pointer for setting the workspace size
+    workspaceSizes[0] = systemWorkspaceSize;
 
     uint64_t tileNum = rankM / TILE_M;
     uint64_t tailNum = (rankM % TILE_M == 0) ? 0 : 1;
@@ -106,27 +66,18 @@ static ge::graphStatus AllGatherMatmulCustomTilingFunc(gert::TilingContext *cont
     AllGatherMatmulCustomTilingData *tiling = context->GetTilingData<AllGatherMatmulCustomTilingData>();
-    tiling->param.rankDim = rankSize;
-    tiling->param.tileNum = tileNum;
-    tiling->param.tailM = tailM;
-    tiling->param.tailNum = tailNum;
-    tiling->param.rankM = rankM;
-    tiling->param.rankN = rankN;
-    tiling->param.rankK = rankK;
-    tiling->param.gatherIndex = gatherIndex;
-    tiling->param.isTransposeA = isTransA ? 1 : 0;
-    tiling->param.isTransposeB = isTransB ? 1 : 0;
-    tiling->param.storageGather = isGatherOut;
-    tiling->param.cToFloatLen = 0; // the communication result goes to workspace or gatherOut; no gmcFloat is needed
-    tiling->param.nd2NzWorkLen = nd2nzLen;
-    tiling->param.gatherLen = 0; // gatherOut=true: nothing is written to workspace
-    tiling->param.dataType = static_cast<uint32_t>(HCCL_DATA_TYPE_MAP.at(aType));
+    tiling->cfg.tileNum = tileNum;
+    tiling->cfg.tailM = tailM;
+    tiling->cfg.tailNum = tailNum;
+    tiling->cfg.rankM = rankM;
+    tiling->cfg.rankN = rankN;
+    tiling->cfg.rankK = rankK;
 
     // matmul tiling func
     auto matmulTilingFunc = [&](int64_t m, int64_t n, int64_t k, TCubeTiling &cubeTiling) -> bool {
         matmul_tiling::MultiCoreMatmulTiling mmTiling;
-        mmTiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, DTYPE_MAP.at(aType), isTransA);
-        mmTiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, DTYPE_MAP.at(bType), isTransB);
-        mmTiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, DTYPE_MAP.at(cType));
+        mmTiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
+        mmTiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
+        mmTiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT16);
         mmTiling.SetBias(false);
         mmTiling.SetDim(aicCoreNum);
         mmTiling.SetShape(m, n, k);
@@ -153,9 +104,9 @@ static ge::graphStatus AllGatherMatmulCustomTilingFunc(gert::TilingContext *cont
         return ge::GRAPH_FAILED;
     }
 
-    uint32_t opType = 6;
+    uint32_t opType = HCCL_CMD_ALLGATHER;
     std::string algConfig = "AllGather=level0:doublering";
-    uint32_t reduceType = 0;
+    uint32_t reduceType = HCCL_REDUCE_SUM;
     AscendC::Mc2CcTilingConfig mc2CcTilingConfig(group, opType, algConfig, reduceType);
     mc2CcTilingConfig.GetTiling(tiling->mc2InitTiling);
     mc2CcTilingConfig.SetSkipLocalRankCopy(0);
@@ -169,54 +120,38 @@ class AllGatherMatmulCustom : public OpDef {
 public:
     explicit AllGatherMatmulCustom(const char *name) : OpDef(name)
     {
-        this->Input("x1")
+        this->Input("a")
             .ParamType(REQUIRED)
-            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
-            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
-            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
-        this->Input("x2")
+            .DataType({ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Input("b")
             .ParamType(REQUIRED)
-            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
-            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
-            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND})
+            .DataType({ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND})
             .IgnoreContiguous();
         this->Input("bias")
             .ParamType(OPTIONAL)
-            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
-            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
-            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+            .DataType({ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
 
-        this->Output("y")
+        this->Output("c")
             .ParamType(REQUIRED)
-            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
-            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
-            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+            .DataType({ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
         this->Output("gather_out")
             .ParamType(REQUIRED)
-            .DataType({ge::DT_FLOAT16, ge::DT_BF16})
-            .Format({ge::FORMAT_ND, ge::FORMAT_ND})
-            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
+            .DataType({ge::DT_FLOAT16})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
 
         this->Attr("group").AttrType(REQUIRED).String();
-        this->Attr("isTransA").AttrType(OPTIONAL).Bool(false);
-        this->Attr("isTransB").AttrType(OPTIONAL).Bool(false);
-        this->Attr("gatherIndex").AttrType(OPTIONAL).Int(0);
-        this->Attr("commTurn").AttrType(OPTIONAL).Int(0);
-        this->Attr("rank_size").AttrType(OPTIONAL).Int(0);
-        this->Attr("is_gather_out").AttrType(OPTIONAL).Bool(true);
-
-        OpAICoreConfig aicore_config;
-        aicore_config.DynamicCompileStaticFlag(true)
-            .DynamicFormatFlag(true)
-            .DynamicRankSupportFlag(true)
-            .DynamicShapeSupportFlag(true)
-            .NeedCheckSupportFlag(false)
-            .PrecisionReduceFlag(true)
-            .ExtendCfgInfo("aclnnSupport.value", "support_aclnn")
-            .ExtendCfgInfo("jitCompile.flag", "static_false")
-            .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel");
 
         this->AICore().SetTiling(AllGatherMatmulCustomTilingFunc);
-        this->AICore().AddConfig("ascend910b", aicore_config);
+        this->AICore().AddConfig("ascend910b");
         this->MC2().HcclGroup("group");
     }
 };
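To make the tile/tail split in the tiling function concrete: with the demo's per-rank rankM = 512 and TILE_M = 448, it yields one full 448-row tile plus a 64-row tail. A plain-Python restatement (tailM is read as the remainder rankM % TILE_M, which is the natural reading of the assignments above but is computed outside the shown hunks):

```python
TILE_M = 448                          # base tile height used by this sample
rank_m = 512                          # per-rank M from the demo shapes

tile_num = rank_m // TILE_M           # tileNum  -> 1 full 448-row tile
tail_m = rank_m % TILE_M              # tailM    -> 64 leftover rows (assumed remainder)
tail_num = 0 if tail_m == 0 else 1    # tailNum  -> 1 tail block

print(tile_num, tail_m, tail_num)     # 1 64 1
```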
this->Attr("group").AttrType(REQUIRED).String(); - this->Attr("isTransA").AttrType(OPTIONAL).Bool(false); - this->Attr("isTransB").AttrType(OPTIONAL).Bool(false); - this->Attr("gatherIndex").AttrType(OPTIONAL).Int(0); - this->Attr("commTurn").AttrType(OPTIONAL).Int(0); - this->Attr("rank_size").AttrType(OPTIONAL).Int(0); - this->Attr("is_gather_out").AttrType(OPTIONAL).Bool(true); - - OpAICoreConfig aicore_config; - aicore_config.DynamicCompileStaticFlag(true) - .DynamicFormatFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(true) - .ExtendCfgInfo("aclnnSupport.value", "support_aclnn") - .ExtendCfgInfo("jitCompile.flag", "static_false") - .ExtendCfgInfo("multiKernelSupportDynamicGraph.value", "multi_kernel"); + this->AICore().SetTiling(AllGatherMatmulCustomTilingFunc); - this->AICore().AddConfig("ascend910b", aicore_config); + this->AICore().AddConfig("ascend910b"); this->MC2().HcclGroup("group"); } }; diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom.cpp b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom.cpp index c35f7616a..bcdae45b8 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom.cpp +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom.cpp @@ -16,15 +16,6 @@ using namespace AscendC; -template -struct BiasType { - using type = float; -}; -template <> -struct BiasType { - using type = half; -}; - extern "C" __global__ __aicore__ void all_gather_matmul_custom(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR biasGM, GM_ADDR cGM, GM_ADDR gatherOutGM, GM_ADDR workspaceGM, GM_ADDR tilingGM) { @@ -37,7 +28,7 @@ extern "C" __global__ __aicore__ void all_gather_matmul_custom(GM_ADDR aGM, GM_A __gm__ void *mc2CcTiling = (__gm__ void *)(&(tiling->mc2CcTiling)); GET_TILING_DATA(tilingData, tilingGM); - auto &&cfg = tilingData.param; + auto &&cfg = tilingData.cfg; auto &&localTiling = tilingData.localTiling; auto &&tileTiling = tilingData.tileTiling; auto &&tailTiling = tilingData.tailTiling; @@ -59,35 +50,33 @@ extern "C" __global__ __aicore__ void all_gather_matmul_custom(GM_ADDR aGM, GM_A // 下发allgather任务 // 首块 - auto handleId = hccl.AllGather(aGM, gatherOutGM, aTileCnt, HcclDataType(cfg.dataType), aRankCnt, tileNum); + auto handleId = hccl.AllGather(aGM, gatherOutGM, aTileCnt, HcclDataType::HCCL_DATA_TYPE_FP16, aRankCnt, tileNum); // 尾块 auto tailHandleId = hccl.AllGather(aGM + tileNum * aTileOffset, gatherOutGM + tileNum * aTileOffset, aTailCnt, - HcclDataType(cfg.dataType), aRankCnt, tailNum); + HcclDataType::HCCL_DATA_TYPE_FP16, aRankCnt, tailNum); - if (TILING_KEY_IS(100)) { // full mesh + no nd2nz + biasNoNeedCast - using A_TYPE = MatmulType; - using B_TYPE = MatmulType; - using C_TYPE = MatmulType; - using BIAS_TYPE = MatmulType::type>; + using A_TYPE = MatmulType; + using B_TYPE = MatmulType; + using C_TYPE = MatmulType; - // 本卡数据计算 - MatmulKernelLocal(aGM, bGM, biasGM, cGM, cfg, localTiling, hccl); + // 本卡数据计算 + MatmulKernelLocal(aGM, bGM, cGM, cfg, localTiling, hccl); - // tile首块计算 - auto aAddr = gatherOutGM; // gatherOut 作为 mm A矩阵地址 - auto cAddr = cGM; - if (tileNum > 0) { - MatmulKernel(aAddr, bGM, biasGM, cAddr, cfg, tileTiling, hccl, handleId, - tileNum); - } + // tile首块计算 + auto aAddr = gatherOutGM; // gatherOut 作为 mm 
A矩阵地址 + auto cAddr = cGM; + if (tileNum > 0) { + MatmulKernel(aAddr, bGM, cAddr, cfg, tileTiling, hccl, handleId, + tileNum); + } - // tail尾块计算 - aAddr = gatherOutGM + tileNum * aTileOffset; - cAddr = cGM + tileNum * cTileOffset; - if (tailNum > 0) { - MatmulKernel(aAddr, bGM, biasGM, cAddr, cfg, tailTiling, hccl, tailHandleId, - tailNum); - } + // tail尾块计算 + aAddr = gatherOutGM + tileNum * aTileOffset; + cAddr = cGM + tileNum * cTileOffset; + if (tailNum > 0) { + MatmulKernel(aAddr, bGM, cAddr, cfg, tailTiling, hccl, tailHandleId, + tailNum); } + hccl.Finalize(); } \ No newline at end of file diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom_tiling.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom_tiling.h index de2ba7834..207eb9493 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom_tiling.h +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/all_gather_matmul_custom_tiling.h @@ -14,22 +14,13 @@ #include #include "kernel_tiling/kernel_tiling.h" -struct AllGatherMatmulRCSTiling { - uint32_t rankDim; +struct AllGatherMatmulTiling { uint32_t rankM; uint32_t rankN; uint32_t rankK; - uint32_t gatherIndex; - uint32_t isTransposeA; - uint32_t isTransposeB; - uint32_t storageGather; - uint64_t cToFloatLen; - uint64_t nd2NzWorkLen; - uint64_t gatherLen; uint32_t tileNum; uint32_t tailM; uint32_t tailNum; - uint32_t dataType; }; class AllGatherMatmulCustomTilingData { @@ -39,7 +30,7 @@ public: TCubeTiling localTiling; TCubeTiling tileTiling; TCubeTiling tailTiling; - AllGatherMatmulRCSTiling param; + AllGatherMatmulTiling cfg; }; #endif //__ALL_GATHER_MATMUL_CUSTOM_TILING_H__ diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h index 891f1082e..b363d8ce2 100644 --- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h +++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h @@ -14,8 +14,8 @@ #if defined ASCENDC_CPU_DEBUG #define SET_G_CORE_TYPE_IS_AIV thread_local int g_coreType = 2 #define SET_G_CORE_TYPE_IS_AIC thread_local int g_coreType = 1 -#define DTYPE_X1 half -#define DTYPE_Y half +#define DTYPE_A half +#define DTYPE_C half #else #define SET_G_CORE_TYPE_IS_AIV #define SET_G_CORE_TYPE_IS_AIC @@ -26,13 +26,12 @@ #include "all_gather_matmul_custom_tiling.h" namespace AscendC { -using A_DTYPE = DTYPE_X1; -using B_DTYPE = DTYPE_X1; -using C_DTYPE = DTYPE_Y; -using BIAS_DTYPE = DTYPE_Y; +using A_DTYPE = DTYPE_A; +using B_DTYPE = DTYPE_B; +using C_DTYPE = DTYPE_C; -template -__aicore__ inline void MatmulKernelLocal(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR biasGM, GM_ADDR cGM, AllGatherMatmulRCSTiling &cfg, +template +__aicore__ inline void MatmulKernelLocal(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM, AllGatherMatmulTiling &cfg, TCubeTiling &tiling, Hccl &hccl) { if ASCEND_IS_AIC { @@ -43,17 +42,17 @@ __aicore__ inline void MatmulKernelLocal(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR biasG const auto aRankDataCnt = cfg.rankM * cfg.rankK; const auto cRankDataCnt = cfg.rankM * cfg.rankN; - MatmulCompute mmLocal; + MatmulCompute mmLocal; mmLocal.Init(cfg, tiling); - 
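The kernel above carries the point of the patch: AllGather is issued tile by tile, the local slice is multiplied immediately, and each gathered tile is multiplied as soon as its handle completes. A toy NumPy model of why the tiled schedule still produces the full result — HCCL handles and the per-rank interleaving of the real kernel are deliberately simplified away:

```python
import numpy as np

RANK_DIM, M, K, N, TILE_M = 4, 512, 64, 32, 448
rng = np.random.default_rng(0)
a_slices = [rng.standard_normal((M, K), dtype=np.float32) for _ in range(RANK_DIM)]
b = rng.standard_normal((K, N), dtype=np.float32)

gather_out = np.concatenate(a_slices, axis=0)      # what the two AllGather calls produce
c = np.zeros((RANK_DIM * M, N), dtype=np.float32)

edges = list(range(0, M, TILE_M)) + [M]            # tile/tail split of each rank's M rows
for r in range(RANK_DIM):                          # the kernel steps across ranks via aRankOffset
    for t in range(len(edges) - 1):                # one hccl.Wait() per tile in the real kernel
        rows = slice(r * M + edges[t], r * M + edges[t + 1])
        c[rows] = gather_out[rows] @ b             # this tile's matmul can start now

print(np.allclose(c, gather_out @ b))              # tiled schedule == one big matmul
```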
diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h
index 891f1082e..b363d8ce2 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/gather_mm.h
@@ -14,8 +14,8 @@
 #if defined ASCENDC_CPU_DEBUG
 #define SET_G_CORE_TYPE_IS_AIV thread_local int g_coreType = 2
 #define SET_G_CORE_TYPE_IS_AIC thread_local int g_coreType = 1
-#define DTYPE_X1 half
-#define DTYPE_Y half
+#define DTYPE_A half
+#define DTYPE_C half
 #else
 #define SET_G_CORE_TYPE_IS_AIV
 #define SET_G_CORE_TYPE_IS_AIC
@@ -26,13 +26,12 @@
 #include "all_gather_matmul_custom_tiling.h"
 
 namespace AscendC {
-using A_DTYPE = DTYPE_X1;
-using B_DTYPE = DTYPE_X1;
-using C_DTYPE = DTYPE_Y;
-using BIAS_DTYPE = DTYPE_Y;
+using A_DTYPE = DTYPE_A;
+using B_DTYPE = DTYPE_B;
+using C_DTYPE = DTYPE_C;
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
-__aicore__ inline void MatmulKernelLocal(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR biasGM, GM_ADDR cGM, AllGatherMatmulRCSTiling &cfg,
+template <class A_TYPE, class B_TYPE, class C_TYPE>
+__aicore__ inline void MatmulKernelLocal(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR cGM, AllGatherMatmulTiling &cfg,
     TCubeTiling &tiling, Hccl<HCCL_SERVER_TYPE_AICPU> &hccl)
 {
     if ASCEND_IS_AIC {
@@ -43,17 +42,17 @@ __aicore__ inline void MatmulKernelLocal(GM_ADDR aGM, GM_ADDR bGM, GM_ADDR biasG
         const auto aRankDataCnt = cfg.rankM * cfg.rankK;
         const auto cRankDataCnt = cfg.rankM * cfg.rankN;
 
-        MatmulCompute<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> mmLocal;
+        MatmulCompute<A_TYPE, B_TYPE, C_TYPE> mmLocal;
         mmLocal.Init(cfg, tiling);
-        mmLocal.UpdateWeightBias(bGM, biasGM);
+        mmLocal.UpdateWeight(bGM);
         mmLocal.UpdateAddress(aGM, aRankDataCnt, cGM + hccl.GetRankId() * cRankDataCnt * sizeof(C_T), cRankDataCnt);
         mmLocal.Process();
         mmLocal.End();
     }
 }
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
-__aicore__ inline void MatmulKernel(GM_ADDR aAddr, GM_ADDR bGM, GM_ADDR biasGM, GM_ADDR cAddr, AllGatherMatmulRCSTiling &cfg,
+template <class A_TYPE, class B_TYPE, class C_TYPE>
+__aicore__ inline void MatmulKernel(GM_ADDR aAddr, GM_ADDR bGM, GM_ADDR cAddr, AllGatherMatmulTiling &cfg,
     TCubeTiling &tiling, Hccl<HCCL_SERVER_TYPE_AICPU> &hccl, HcclHandle &handleId, uint32_t tileCnt)
 {
     if ASCEND_IS_AIC {
@@ -73,9 +72,9 @@ __aicore__ inline void MatmulKernel(GM_ADDR aAddr, GM_ADDR bGM, GM_ADDR biasGM,
         const auto aRankOffset = cfg.rankM * cfg.rankK * sizeof(A_T);
         const auto cRankOffset = cfg.rankM * cfg.rankN * sizeof(C_T);
 
-        MatmulCompute<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> mm;
+        MatmulCompute<A_TYPE, B_TYPE, C_TYPE> mm;
         mm.Init(cfg, tiling);
-        mm.UpdateWeightBias(bGM, biasGM);
+        mm.UpdateWeight(bGM);
         for (uint32_t i = 0; i < tileCnt; i++) {
             // wait for the current tile's allgather handle
             hccl.Wait(handleId);

diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_block.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_block.h
index 5e62fd6f6..00d4322b5 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_block.h
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_block.h
@@ -19,15 +19,11 @@
 struct BaseBlockOffset {
     uint64_t offsetA;
     uint64_t offsetB;
     uint64_t offsetC;
-    uint64_t offsetBias;
 };
 
 struct BaseBlockArguments
 {
     bool isRowOrder;
-    bool isAtomic;
-    bool isTransA;
-    bool isTransB;
     uint32_t singleCoreM;
     uint32_t singleCoreN;
     uint32_t mBlockCnt;    // number of base blocks along M
@@ -48,11 +44,10 @@ struct BaseBlockArguments
 class MatmulBaseBlock {
 public:
     __aicore__ inline MatmulBaseBlock() {}
-    __aicore__ inline void Init(TCubeTiling& tiling, bool isTransA, bool isTransB);
+    __aicore__ inline void Init(TCubeTiling& tiling);
     __aicore__ inline void InitBlockWithoutIndex();
     __aicore__ inline void UpdateBlockIndex(uint32_t currPos);
     __aicore__ inline void UpdateBlockParams(int32_t mTileIndex=0, int32_t nTileIndex=0);
-    template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
     __aicore__ inline void CalcGMOffset();
     __aicore__ inline void GetBlockStartIdx(uint32_t startIdx, uint32_t endIdx);
@@ -62,7 +57,7 @@ public:
     TCubeTiling tiling_;
 };
 
-__aicore__ inline void MatmulBaseBlock::Init(TCubeTiling& tiling, bool isTransA, bool isTransB)
+__aicore__ inline void MatmulBaseBlock::Init(TCubeTiling& tiling)
 {
     tiling_ = tiling;
     args_.preCoreStartIdx = 0;
@@ -75,12 +70,6 @@ __aicore__ inline void MatmulBaseBlock::Init(TCubeTiling& tiling, bool isTransA,
     if (tiling_.N > 5 * tiling_.M) { // 5: ratio of rowOrder
         args_.isRowOrder = false;
     }
-    args_.isTransA = isTransA;
-    args_.isTransB = isTransB;
-    args_.isAtomic = false;
-    if (args_.isTransA) {
-        args_.isAtomic = true;
-    }
 }
 
 __aicore__ inline void MatmulBaseBlock::InitBlockWithoutIndex()
@@ -168,47 +157,11 @@ __aicore__ inline void MatmulBaseBlock::UpdateBlockParams(int32_t mTileIndex, in
     args_.mCWorkOffset = args_.mBlockOffset;
 }
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
 __aicore__ inline void MatmulBaseBlock::CalcGMOffset()
 {
-    auto alignedKa = AlignUp(tiling_.Ka, C0_SIZE);
-    auto alignedKb = AlignUp(tiling_.Kb, C0_SIZE);
-
-    if constexpr (A_TYPE::format == CubeFormat::ND) {
-        if (args_.isTransA) {
-            offset_.offsetA = args_.mBlockOffset;
-        } else {
-            offset_.offsetA = args_.mBlockOffset * tiling_.Ka;
-        }
-    } else if constexpr (A_TYPE::format == CubeFormat::NZ) {
-        if (args_.isTransA) {
-            offset_.offsetA = args_.mBlockOffset * alignedKa;
-        } else {
-            offset_.offsetA = args_.mBlockOffset * C0_SIZE;
-        }
-    }
-
-    if constexpr (B_TYPE::format == CubeFormat::ND) {
-        if (args_.isTransB) {
-            offset_.offsetB = args_.nBlockOffset * tiling_.Kb;
-        } else {
-            offset_.offsetB = args_.nBlockOffset;
-        }
-    } else if constexpr (B_TYPE::format == CubeFormat::NZ) {
-        if (args_.isTransB) {
-            offset_.offsetB = args_.nBlockOffset * C0_SIZE;
-        } else {
-            offset_.offsetB = args_.nBlockOffset * alignedKb;
-        }
-    }
-
-    // the C matrix and bias support ND only
-    if constexpr (C_TYPE::format == CubeFormat::ND || C_TYPE::format == CubeFormat::ND_ALIGN) {
-        offset_.offsetC = args_.nBlockOffset + args_.mCWorkOffset * tiling_.N;
-    }
-    if constexpr (BIAS_TYPE::format == CubeFormat::ND) {
-        offset_.offsetBias = args_.nBlockOffset;
-    }
+    offset_.offsetA = args_.mBlockOffset * tiling_.Ka;
+    offset_.offsetB = args_.nBlockOffset;
+    offset_.offsetC = args_.nBlockOffset + args_.mCWorkOffset * tiling_.N;
 }
 } // namespace ASCENDC
 #endif // MC2_MATMUL_BLOCK_H
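With the transpose and NZ branches gone, CalcGMOffset reduces to plain row-major indexing. A small self-check of the three formulas — the variable names mirror the struct fields, and mCWorkOffset equals mBlockOffset per UpdateBlockParams above; this is a restatement, not repo code:

```python
import numpy as np

M, Ka, N = 8, 6, 10                             # tiny stand-ins for tiling_.M / tiling_.Ka / tiling_.N
A = np.arange(M * Ka).reshape(M, Ka)            # row-major A in GM
B = np.arange(Ka * N).reshape(Ka, N)            # row-major B in GM
m_block_offset, n_block_offset = 4, 5           # some base block's starting row / column

offset_a = m_block_offset * Ka                  # offsetA = mBlockOffset * tiling_.Ka
offset_b = n_block_offset                       # offsetB = nBlockOffset
offset_c = n_block_offset + m_block_offset * N  # offsetC = nBlockOffset + mCWorkOffset * tiling_.N

assert A.flat[offset_a] == A[m_block_offset, 0]  # first element of the block's A rows
assert B.flat[offset_b] == B[0, n_block_offset]  # first element of the block's B columns
assert offset_c == np.ravel_multi_index((m_block_offset, n_block_offset), (M, N))
print("ND row-major offsets check out")
```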
diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_compute.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_compute.h
index 23f0a1f18..0bac09100 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_compute.h
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/AllGatherMatmulCustom/op_kernel/mc2_matmul_compute.h
@@ -16,103 +16,69 @@
 namespace AscendC {
 using namespace matmul;
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
+template <class A_TYPE, class B_TYPE, class C_TYPE>
 class MatmulCompute {
     using A_T = typename A_TYPE::T;
     using B_T = typename B_TYPE::T;
     using C_T = typename C_TYPE::T;
-    using BiasT = typename BIAS_TYPE::T;
 
 public:
     __aicore__ inline MatmulCompute() {}
-    __aicore__ inline void Init(AllGatherMatmulRCSTiling& cfg, TCubeTiling& tiling);
-    __aicore__ inline void UpdateWeightBias(GM_ADDR bGM, GM_ADDR biasGM);
+    __aicore__ inline void Init(AllGatherMatmulTiling& cfg, TCubeTiling& tiling);
+    __aicore__ inline void UpdateWeight(GM_ADDR bGM);
     __aicore__ inline void UpdateAddress(GM_ADDR aGM, uint32_t aSize, GM_ADDR cGM, uint32_t cSize);
     __aicore__ inline void Process();
     __aicore__ inline void End();
 
 private:
-    __aicore__ inline void SetOrgShapeAlign();
-
-private:
-    MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE> mm_;
+    MatmulImpl<A_TYPE, B_TYPE, C_TYPE, C_TYPE> mm_;
     GlobalTensor<A_T> aGlobal;
     GlobalTensor<B_T> bGlobal;
     GlobalTensor<C_T> cGlobal;
-    GlobalTensor<BiasT> biasGlobal;
     MatmulBaseBlock block_;
     TCubeTiling tiling_;
-    AllGatherMatmulRCSTiling cfg_;
+    AllGatherMatmulTiling cfg_;
 };
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
-__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::UpdateWeightBias(
-    GM_ADDR bGM, GM_ADDR biasGM)
+template <class A_TYPE, class B_TYPE, class C_TYPE>
+__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE>::UpdateWeight(GM_ADDR bGM)
 {
     // in the MC2 compute flow the B matrix is assumed fixed, so its GM address needs no offset
     bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ B_T *>(bGM), tiling_.Kb * tiling_.N);
-    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ BiasT *>(biasGM), tiling_.N);
 }
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
-__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::UpdateAddress(
+template <class A_TYPE, class B_TYPE, class C_TYPE>
+__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE>::UpdateAddress(
     GM_ADDR aGM, uint32_t aSize, GM_ADDR cGM, uint32_t cSize)
 {
     aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ A_T *>(aGM), aSize);
     cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ C_T *>(cGM), cSize);
 }
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
-__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::Init(AllGatherMatmulRCSTiling& cfg, TCubeTiling& tiling)
+template <class A_TYPE, class B_TYPE, class C_TYPE>
+__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE>::Init(AllGatherMatmulTiling& cfg, TCubeTiling& tiling)
 {
-    // MatmulImpl initialization
     mm_.SetSubBlockIdx(0);
     mm_.Init(&tiling, GetTPipePtr());
     tiling_ = tiling;
     cfg_ = cfg;
-    bool isTransA = cfg.isTransposeA > 0;
-    bool isTransB = cfg.isTransposeB > 0;
-    block_.Init(tiling, isTransA, isTransB);
-    SetOrgShapeAlign();
-}
-
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
-__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::SetOrgShapeAlign()
-{
-    if constexpr (A_TYPE::format == CubeFormat::NZ && B_TYPE::format == CubeFormat::NZ) {
-        auto alignKa = AlignUp(tiling_.Ka, C0_SIZE);
-        auto alignKb = AlignUp(tiling_.Kb, C0_SIZE);
-        auto alignM = AlignUp(tiling_.M, C0_SIZE);
-        auto alignN = AlignUp(tiling_.N, C0_SIZE);
-        mm_.SetOrgShape(alignM, alignN, alignKa, alignKb, cfg_.rankN);
-    } else if (A_TYPE::format == CubeFormat::NZ) {
-        auto alignKa = AlignUp(tiling_.Ka, C0_SIZE);
-        auto alignM = AlignUp(tiling_.M, C0_SIZE);
-        mm_.SetOrgShape(alignM, tiling_.N, alignKa, tiling_.Kb, cfg_.rankN);
-    } else if (B_TYPE::format == CubeFormat::NZ) {
-        auto alignKb = AlignUp(tiling_.Kb, C0_SIZE);
-        auto alignN = AlignUp(tiling_.N, C0_SIZE);
-        mm_.SetOrgShape(tiling_.M, alignN, tiling_.Ka, alignKb, cfg_.rankN);
-    }
+    block_.Init(tiling);
 }
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
-__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::Process()
+template <class A_TYPE, class B_TYPE, class C_TYPE>
+__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE>::Process()
 {
     // the initial block index must be computed before each block loop starts
     block_.InitBlockWithoutIndex();
     for (uint32_t i = 0; i < block_.args_.blockCnt; i++) {
         // calculate blockCurrIndex
         block_.UpdateBlockIndex(i);
         if (block_.args_.blockCurrIdx < block_.args_.totalBlockCnt) {
             block_.UpdateBlockParams();
-            block_.template CalcGMOffset<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>();
+            block_.CalcGMOffset();
             mm_.SetSingleShape(block_.args_.singleCoreM, block_.args_.singleCoreN, tiling_.singleCoreK);
-            mm_.SetTensorA(aGlobal[block_.offset_.offsetA], block_.args_.isTransA);
-            mm_.SetTensorB(bGlobal[block_.offset_.offsetB], block_.args_.isTransB);
-            if (tiling_.isBias) {
-                mm_.SetBias(biasGlobal[block_.offset_.offsetBias]);
-            }
+            mm_.SetTensorA(aGlobal[block_.offset_.offsetA]);
+            mm_.SetTensorB(bGlobal[block_.offset_.offsetB]);
             mm_.Iterate();
             mm_.GetTensorC(cGlobal[block_.offset_.offsetC]);
             // add M/FIX synchronization
@@ -123,8 +89,8 @@ __aicore__ inline void MatmulCompute::Process
     }
 }
 
-template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
-__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>::End()
+template <class A_TYPE, class B_TYPE, class C_TYPE>
+__aicore__ inline void MatmulCompute<A_TYPE, B_TYPE, C_TYPE>::End()
 {
     mm_.End();
 }

diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/README.md b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/README.md
index c121aac07..03c7430e5 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/README.md
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/README.md
@@ -7,21 +7,21 @@
 │   ├── AllGatherMatmulCustom            // the AllGatherMatmulCustom operator project
 │   ├── all_gather_matmul_custom.json    // prototype-definition JSON file of the AllGatherMatmulCustom operator
 │   ├── all_gather_matmul_demo_def.h     // parameter configuration of the AllGatherMatmulCustom operator
-│   └── install.sh                          // script that invokes msOpGen to generate the custom operator project and build it
+│   └── install.sh                       // script that invokes msOpGen to generate the custom operator project and build it
 ```
 
 ## Operator description
-The AllGatherMatmul operator fuses the AllGather collective-communication operation with the Matmul matrix multiplication, producing the AllGather result gather\_out and the Matmul result y. The corresponding math is:
+The AllGatherMatmul operator fuses the AllGather collective-communication operation with the Matmul matrix multiplication, producing the AllGather result gather\_out and the Matmul result c. The corresponding math is:
 
 $$
-gather\_out=AllGather(x1)
+gather\_out=AllGather(a)
 $$
 
 $$
-y=gather\_out * x2
+c=gather\_out * b
 $$
 
 - AllGather() is the AllGather collective-communication operation.
-- x1 and x2 are the source operands: x1 is the left matrix of shape [M, K]; x2 is the right matrix of shape [K, N].
+- a and b are the source operands: a is the left matrix of shape [M, K]; b is the right matrix of shape [K, N].
 - gather\_out is a destination operand holding the AllGather result, of shape [M * rankDim, K], where rankDim is the number of ranks in the communication domain.
-- y is a destination operand holding the Matmul result, of shape [M * rankDim, N].
+- c is a destination operand holding the Matmul result, of shape [M * rankDim, N].
 
 Note: for the AllGather collective-communication primitive, see the [HCCL user guide](https://hiascend.com/document/redirect/CannCommunityHcclUg) > collective-communication primitives.
 
@@ -30,12 +30,12 @@
 | OpType           | AllGatherMatmulCustom    |             |           |        |
 | ---------------- | ------------------------ | ----------- | --------- | ------ |
 | operator inputs  | name                     | shape       | data type | format |
-|                  | x1                       | 512 * 5120  | float16   | ND     |
-|                  | x2                       | 5120 * 640  | float16   | ND     |
+|                  | a                        | 512 * 5120  | float16   | ND     |
+|                  | b                        | 5120 * 640  | float16   | ND     |
 |                  | bias                     | /           | /         | /      |
-| operator outputs | y                        | 4096 * 640  | float16   | ND     |
+| operator outputs | c                        | 4096 * 640  | float16   | ND     |
 |                  | gather_out               | 4096 * 5120 | float16   | ND     |
 | kernel name      | all_gather_matmul_custom |             |           |        |
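The two formulas above are exactly concatenate-then-matmul. As a reference restatement with the demo shapes (this mirrors gen_data.py and adds nothing beyond it):

```python
import numpy as np

rank_dim, m, k, n = 8, 512, 5120, 640
a_slices = [np.random.rand(m, k).astype(np.float16) for _ in range(rank_dim)]  # one a per rank
b = np.random.rand(k, n).astype(np.float16)

gather_out = np.concatenate(a_slices, axis=0)   # gather_out = AllGather(a): [m * rank_dim, k]
c = (gather_out.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)  # c = gather_out * b
print(gather_out.shape, c.shape)                # (4096, 5120) (4096, 640)
```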
diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_custom.json b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_custom.json
index 6c137f335..6ab16e763 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_custom.json
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_custom.json
@@ -3,14 +3,13 @@
     "op": "AllGatherMatmulCustom",
     "input_desc": [
         {
-            "name": "x1",
+            "name": "a",
             "param_type": "required",
             "format": [
                 "ND"
             ],
             "type": [
-                "float16",
-                "bfloat16"
+                "float16"
             ]
         },
         {
@@ -20,8 +19,7 @@
                 "ND"
             ],
             "type": [
-                "float16",
-                "bfloat16"
+                "float16"
             ]
         },
         {
@@ -31,21 +29,19 @@
                 "ND"
             ],
             "type": [
-                "float16",
-                "bfloat16"
+                "float16"
             ]
         }
     ],
     "output_desc":[
         {
-            "name": "y",
+            "name": "c",
             "param_type": "required",
             "format": [
                 "ND"
             ],
             "type": [
-                "float16",
-                "bfloat16"
+                "float16"
             ]
         },
         {
@@ -55,8 +51,7 @@
                 "ND"
             ],
             "type": [
-                "float16",
-                "bfloat16"
+                "float16"
             ]
         }
     ],
@@ -66,42 +61,6 @@
             "type": "string",
             "default_value":"",
             "param_type":"required"
-        },
-        {
-            "name": "is_trans_a",
-            "type": "bool",
-            "default_value":false,
-            "param_type":"optional"
-        },
-        {
-            "name": "is_trans_b",
-            "type": "bool",
-            "default_value":false,
-            "param_type":"optional"
-        },
-        {
-            "name": "gather_index",
-            "type": "int",
-            "default_value":0,
-            "param_type":"optional"
-        },
-        {
-            "name": "comm_turn",
-            "type": "int",
-            "default_value":0,
-            "param_type":"optional"
-        },
-        {
-            "name": "rank_size",
-            "type": "int",
-            "default_value":0,
-            "param_type":"optional"
-        },
-        {
-            "name": "is_gather_out",
-            "type": "bool",
-            "default_value":true,
-            "param_type":"optional"
         }
     ]
 }

diff --git a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_demo_def.h b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_demo_def.h
index bea5e9251..1377915f5 100644
--- a/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_demo_def.h
+++ b/operator/ascendc/4_best_practices/21_all_gather_matmul_custom/all_gather_matmul_demo_def.h
@@ -15,17 +15,5 @@
 constexpr uint32_t RANK_DIM = 8;
 constexpr uint32_t RANK_M = 512;
 constexpr uint32_t RANK_K = 5120;
 constexpr uint32_t RANK_N = 640;
-constexpr bool IS_TRANS_A = false;
-constexpr bool IS_TRANS_B = false;
-constexpr int64_t GATHER_INDEX = 0;
-constexpr int64_t COMM_TURN = 0;
-constexpr bool IS_GATHER_OUT = true;
-
-// tiling
-constexpr int32_t TILE_M = 448;
-constexpr uint32_t HCCL_CMD_ALLGATHER = 6;
-constexpr uint32_t FP16_BF16_SIZE = 2;
-constexpr uint32_t CUSTOM_TILING_KEY = 100; // full mesh + no nd2nz + no cast bias
-constexpr int32_t L1_BUFFER_SIZE = 512 * 1024;
 
 #endif
\ No newline at end of file
-- 
Gitee