From cf78b76bd74c8fb58a8e96cfbc25650be96113f7 Mon Sep 17 00:00:00 2001
From: guanguan
Date: Tue, 22 Jul 2025 10:05:20 +0800
Subject: [PATCH 1/2] add K value

---
 .../linear_parallel_operation.cpp             | 45 ++++++++++++++++---
 1 file changed, 40 insertions(+), 5 deletions(-)

diff --git a/src/ops_infer/linear_parallel/linear_parallel_operation.cpp b/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
index 95eabcfc..9e33f915 100644
--- a/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
+++ b/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
@@ -35,6 +35,7 @@ static const uint32_t IN_TENSOR_DIM_NUM = 2;
 static const uint32_t RESIDUAL_TENSOR_INDEX_3 = 3;
 static const uint32_t RESIDUAL_TENSOR_INDEX_4 = 4;
 static const uint32_t MAX_OUTPUT_SIZE = 204800;
+static const uint32_t MAX_K = 24000;

 static bool AllToAllvcAllGatherGmmOutTensorCheck(const SVector<TensorDesc> &inTensorDescs, const TensorDesc &outTensorDesc,
                                                  const std::string &logPrefix)
@@ -67,9 +68,11 @@ bool CheckType(const infer::LinearParallelParam &opParam, Status &isOK)
     }
     if (opParam.quantType == atb::infer::LinearParallelParam::QuantType::QUANT_TYPE_PER_TOKEN &&
         opParam.type != atb::infer::LinearParallelParam::ParallelType::ALLTOALLVC_ALL_GATHER_GMM &&
-        opParam.type != atb::infer::LinearParallelParam::ParallelType::GMM_REDUCE_SCATTER_ALLTOALLVC) {
-        ATB_LOG(ERROR) << "LinearParallel backend lcoc only ALLTOALLVC_ALL_GATHER_GMM and "
-                       << "GMM_REDUCE_SCATTER_ALLTOALLVC support quantType[QUANT_TYPE_PER_TOKEN]";
+        opParam.type != atb::infer::LinearParallelParam::ParallelType::GMM_REDUCE_SCATTER_ALLTOALLVC &&
+        opParam.type != atb::infer::LinearParallelParam::ParallelType::LINEAR_ALL_REDUCE) {
+        ATB_LOG(ERROR) << "When LinearParallel backend is lcoc, "
+                       << "only ALLTOALLVC_ALL_GATHER_GMM, GMM_REDUCE_SCATTER_ALLTOALLVC and LINEAR_ALL_REDUCE"
+                       << " support quantType[QUANT_TYPE_PER_TOKEN]";
         isOK = ERROR_INVALID_PARAM;
         return true;
     }
@@ -119,9 +122,8 @@ template <> Status CreateOperation(const infer::LinearParallelParam &opParam, Op
     }
     if (opParam.backend == "lcoc") {
         Status isOk;
-        if (CheckType(opParam, isOk)) {
+        if (CheckType(opParam, isOk))
             return isOk;
-        }
     }
     *operation = new (std::nothrow) LinearParallelOperation(opParam);
     if (*operation == nullptr) {
@@ -278,6 +280,11 @@ Status LinearParallelOperation::InferShapeAllToAllvcAllGatherGmm(const SVector<TensorDesc> &inTensorDescs,
+    if (k > MAX_K) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "K is too large. "
+                       << "should be 1 to 24000, now is " << k;
+        return ERROR_INVALID_TENSOR_SIZE;
+    }
     outTensorDescs.at(0).shape.dims[0] = inTensorDescs.at(inTensorDescs.size() - 1).shape.dims[0];
     return NO_ERROR;
 }
@@ -336,8 +343,11 @@ Status LinearParallelOperation::InferShapeCheckLinearAllReduce(const SVector<TensorDesc> &inTensorDescs) const
     bool isQuant = param_.quantType > infer::LinearParallelParam::QuantType::QUANT_TYPE_UNQUANT &&
                    param_.quantType < infer::LinearParallelParam::QuantType::QUANT_TYPE_MAX;
     if (isQuant && inTensorDescs.at(3).dtype == ACL_FLOAT && param_.outDataType == ACL_FLOAT16) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "when perChannelScale's type is float, "
+                       << "outputDataType do not support float16_t";
         return ERROR_INVALID_TENSOR_INI_MATCH;
     }
+
     return CheckResidual(inTensorDescs);
 }

@@ -346,6 +356,15 @@ Status LinearParallelOperation::InferShapeCheckLinearReduceScatter(const SVector<TensorDesc> &inTensorDescs) const
     if (!OperationUtil::MatmulInTensorDescsCheck(inTensorDescs, GetLogPrefix(), commonCheckParam_)) {
         return ERROR_INVALID_TENSOR_DIM;
     }
+
+    bool isQuant = param_.quantType > infer::LinearParallelParam::QuantType::QUANT_TYPE_UNQUANT &&
+                   param_.quantType < infer::LinearParallelParam::QuantType::QUANT_TYPE_MAX;
+    if (isQuant && inTensorDescs.at(3).dtype == ACL_FLOAT && param_.outDataType == ACL_FLOAT16) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "when perChannelScale's type is float, "
+                       << "outputDataType do not support float16_t";
+        return ERROR_INVALID_TENSOR_INI_MATCH;
+    }
+
     int64_t xTensorM = OperationUtil::GetXTensorM(inTensorDescs.at(0));
     if (xTensorM % param_.rankSize != 0) {
         ATB_LOG(ERROR) << GetLogPrefix() << "inTensor0 m [" << xTensorM
@@ -357,6 +376,14 @@ Status LinearParallelOperation::InferShapeCheckLinearReduceScatter(const SVector<TensorDesc> &inTensorDescs) const

 Status LinearParallelOperation::InferShapeCheckAllGatherLinear(const SVector<TensorDesc> &inTensorDescs) const
 {
+    bool isQuant = param_.quantType > infer::LinearParallelParam::QuantType::QUANT_TYPE_UNQUANT &&
+                   param_.quantType < infer::LinearParallelParam::QuantType::QUANT_TYPE_MAX;
+    if (isQuant && inTensorDescs.at(3).dtype == ACL_FLOAT && param_.outDataType == ACL_FLOAT16) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "when perChannelScale's type is float, "
+                       << "outputDataType do not support float16_t";
+        return ERROR_INVALID_TENSOR_INI_MATCH;
+    }
+
     SVector<TensorDesc> newInTensorDescs;
     newInTensorDescs = inTensorDescs;
     newInTensorDescs.at(0).shape.dims[0] *= param_.rankSize;
@@ -397,6 +424,14 @@ Status LinearParallelOperation::InferShapeCheckAllGatherLinearReduceScatter(

 Status LinearParallelOperation::InferShapeCheckAllToAllvcAllGatherGmm(const SVector<TensorDesc> &inTensorDescs) const
 {
+    bool isQuant = param_.quantType > infer::LinearParallelParam::QuantType::QUANT_TYPE_UNQUANT &&
+                   param_.quantType < infer::LinearParallelParam::QuantType::QUANT_TYPE_MAX;
+    if (isQuant && inTensorDescs.at(2).dtype == ACL_FLOAT && param_.outDataType == ACL_FLOAT16) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "when perChannelScale's type is float, "
+                       << "outputDataType do not support float16_t";
+        return ERROR_INVALID_TENSOR_INI_MATCH;
+    }
+
     int64_t globalTokensPerExpertMatrixM = OperationUtil::GetXTensorM(inTensorDescs.at(inTensorDescs.size() - 2));
     int64_t globalTokensPerExpertMatrixK = OperationUtil::GetXTensorK(inTensorDescs.at(inTensorDescs.size() - 2));
     int64_t expertNums = param_.moeInfo.epSize * param_.moeInfo.localExpertNums;
--
Gitee


From 311d4cb2bdfe46571cfdb618c1a930c5a2e7d5bf Mon Sep 17 00:00:00 2001
From: guanguan
Date: Fri, 25 Jul 2025 14:13:18 +0800
Subject: [PATCH 2/2] fix rs data

---
 tests/framework/python/CsvOpsTestTool/data_generation.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)
 mode change 100755 => 100644 tests/framework/python/CsvOpsTestTool/data_generation.py

diff --git a/tests/framework/python/CsvOpsTestTool/data_generation.py b/tests/framework/python/CsvOpsTestTool/data_generation.py
old mode 100755
new mode 100644
index 97662fe9..a5c22deb
--- a/tests/framework/python/CsvOpsTestTool/data_generation.py
+++ b/tests/framework/python/CsvOpsTestTool/data_generation.py
@@ -783,7 +783,12 @@ class LinearParallelOperation(DataGen):
         sum_tensor = matmul_result.clone()
         for i in range(rank_size - 1):
             sum_tensor += matmul_result
-        chunks = torch.split(sum_tensor, int(in_tensors[0].shape[0]/rank_size))
+        if len(in_tensors[0].shape) == 3 and in_tensors[0].shape[0] == 1:
+            chunks_size = int(in_tensors[0].shape[1] / rank_size)
+            chunks = torch.split(sum_tensor, chunks_size, dim = 1)
+        else:
+            chunks_size = int(in_tensors[0].shape[0] / rank_size)
+            chunks = torch.split(sum_tensor, chunks_size)
         golden_result = chunks[rank]
         if LinearParallelOperation.residual_golden is not None:
            golden_result = golden_result + LinearParallelOperation.residual_golden
--
Gitee
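
Note on PATCH 2/2: the sketch below replays the golden-data chunking on its own, outside the CsvOpsTestTool framework, to show why a 3-D input with a leading batch dimension of 1 must be split along dim 1 rather than dim 0. It is a minimal illustration, not part of the patch; rank_size and the tensor shape are made up, and only torch is assumed to be available.

    # Hypothetical standalone example; shapes and rank_size are illustrative only.
    import torch

    rank_size = 4
    sum_tensor = torch.randn(1, 32, 16)  # stands in for the summed matmul result

    if sum_tensor.dim() == 3 and sum_tensor.shape[0] == 1:
        # Splitting along dim 0 would yield a single chunk; split the token dim instead.
        chunk_size = sum_tensor.shape[1] // rank_size
        chunks = torch.split(sum_tensor, chunk_size, dim=1)
    else:
        chunk_size = sum_tensor.shape[0] // rank_size
        chunks = torch.split(sum_tensor, chunk_size)

    assert len(chunks) == rank_size
    print([tuple(c.shape) for c in chunks])  # four (1, 8, 16) slices, one per rank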