diff --git a/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp b/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp index 1f864b0a68c7fcf41ed76e25f465e55f4e4c2e4b..65b11dd4a923ca7123d1d47e179c0702b6159334 100644 --- a/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp +++ b/src/ops_infer/topk_topp_sampling/topk_topp_sampling_operation.cpp @@ -128,8 +128,9 @@ Status TopkToppSamplingOperation::InferShapeImpl(const SVector &inTe outTensorDescs.at(1) = inTensorDescs.at(0); outTensorDescs.at(1).shape.dims[dimNum - 1] = 1; - if (param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING || - param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) { + if ((param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING || + param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) && + param_.logProbsSize != 0) { outTensorDescs.at(OUT_TENSOR_LOGPROBS) = inTensorDescs.at(0); outTensorDescs.at(OUT_TENSOR_LOGPROBS).dtype = ACL_FLOAT; outTensorDescs.at(OUT_TENSOR_LOGPROBS).shape.dims[dimNum - 1] = param_.logProbsSize; @@ -277,6 +278,29 @@ Status TopkToppSamplingOperation::InferShapeCheckImpl(const SVector return CheckIntensorAndParam(inTensorDescs); } +Status TopkToppSamplingOperation::TopkToppLogProbsOutTensorCheck(const SVector &outTensorDescs) const +{ + if (param_.logProbsSize == 0) { + if (outTensorDescs.at(LOG_PROBS_OUT_TENSOR_INDEX).shape.dimNum != 0) { + ATB_LOG(ERROR) << "outTensor[2] should be empty when logProbsSize=0, dimNum expected to be 0, but got: " + << outTensorDescs.at(LOG_PROBS_OUT_TENSOR_INDEX).shape.dimNum; + return ERROR_INVALID_TENSOR_DIM_NUM; + } else { + return NO_ERROR; + } + } + if (outTensorDescs.at(LOG_PROBS_OUT_TENSOR_INDEX).shape.dimNum != LOG_PROBS_OUT_TENSOR_DIM) { + ATB_LOG(ERROR) << "logProbsTensor shape dimNum: " << outTensorDescs.at(LOG_PROBS_OUT_TENSOR_INDEX).shape.dimNum + << ", which should be same as " << LOG_PROBS_OUT_TENSOR_DIM; + return ERROR_INVALID_TENSOR_DIM_NUM; + } + if (outTensorDescs.at(LOG_PROBS_OUT_TENSOR_INDEX).shape.dims[LAST_DIM] != param_.logProbsSize) { + ATB_LOG(ERROR) << "logProbsTensor dim:" << LAST_DIM << ", which should be same as " << param_.logProbsSize; + return ERROR_INVALID_TENSOR_DIM; + } + return NO_ERROR; +} + Status TopkToppSamplingOperation::SetupCheckImpl(const SVector &inTensors, const SVector &outTensors) const { @@ -285,18 +309,6 @@ Status TopkToppSamplingOperation::SetupCheckImpl(const SVector &inTensor SVector outTensorDescs = {}; OperationUtil::InTensorsToInTensorDescs(outTensors, outTensorDescs); ATB_LOG(DEBUG) << "outTensors size:" << outTensors.size(); - for (size_t i = 0; i < outTensorDescs.size(); i++) { - if (!TensorCheck::IsTensorDescDimNumValid(outTensorDescs.at(i), 2)) { - if ((param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING || - param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) && - param_.logProbsSize == 0 && i == outTensorDescs.size() - 1) { - ATB_LOG(INFO) << GetLogPrefix() << "the dimNum of outTensor[" << i << "] allows not be 2 when logProbsSize is 0"; - } else { - ATB_LOG(ERROR) << GetLogPrefix() << "the dimNum of outTensor[" << i <<"] should be 2"; - return ERROR_INVALID_TENSOR_DIM_NUM; - } - } - } if (!(outTensorDescs.at(0).shape.dims[0] == outTensorDescs.at(1).shape.dims[0])) { ATB_LOG(ERROR) << "The batch size of outTensors should be same."; return ERROR_INVALID_TENSOR_DIM; @@ -313,10 +325,9 @@ Status TopkToppSamplingOperation::SetupCheckImpl(const SVector &inTensor } if (param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING || param_.topkToppSamplingType == atbInferTopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING) { - if (param_.logProbsSize != 0 && outTensorDescs.at(LOG_PROBS_OUT_TENSOR_INDEX).shape.dims[LAST_DIM] != param_.logProbsSize) { - ATB_LOG(ERROR) << "the last dim of logProbsTensor dim:" << outTensorDescs.at(LOG_PROBS_OUT_TENSOR_INDEX).shape.dims[LAST_DIM] - << ", which should be same as " << param_.logProbsSize; - return ERROR_INVALID_TENSOR_DIM; + Status LogProbsOutTensorCheckRes = TopkToppLogProbsOutTensorCheck(outTensorDescs); + if (LogProbsOutTensorCheckRes != NO_ERROR) { + return LogProbsOutTensorCheckRes; } } return CheckIntensorAndParam(inTensorDescs);