diff --git a/core/ops/add_norm/add_norm_ops_runner.cpp b/core/ops/add_norm/add_norm_ops_runner.cpp index 6bfeda38d589af34e04441fa7b4e2a2c2e70d695..25f2730edbb1647328e5324ba6c2bdeb1cd7121b 100644 --- a/core/ops/add_norm/add_norm_ops_runner.cpp +++ b/core/ops/add_norm/add_norm_ops_runner.cpp @@ -70,16 +70,9 @@ AsdOps::Status AddNormOpsRunner::SetupKernelGraph(const VariantPack &variantPack bool AddNormOpsRunner::CalcLayerNormTensor(const VariantPack &variantPack, int64_t &beginDim) { - AsdOps::TensorDesc inputDesc; - inputDesc.dtype = variantPack.inTensors.at(0).desc.dtype; - if (variantPack.inTensors.at(0).desc.dims.size() > variantPack.inTensors.at(1).desc.dims.size()) { - inputDesc.dims = variantPack.inTensors.at(0).desc.dims; - } else { - inputDesc.dims = variantPack.inTensors.at(1).desc.dims; - } - - const AsdOps::Tensor &weightTensor = variantPack.inTensors.at(2); - const AsdOps::Tensor &biasTensor = variantPack.inTensors.at(3); + const AsdOps::TensorDesc &inputDesc = variantPack.inTensors.at(0).desc; + const AsdOps::Tensor &weightTensor = variantPack.inTensors.at(1); + const AsdOps::Tensor &biasTensor = variantPack.inTensors.at(2); ASD_LOG(INFO) << GetName() << " layer norm input desc:" << TensorUtil::AsdOpsTensorDescToString(inputDesc) << ", weightTensor:" << TensorUtil::AsdOpsTensorToString(weightTensor) diff --git a/core/ops/norm/norm_ops_runner.cpp b/core/ops/norm/norm_ops_runner.cpp index 07edfb88353ffb59a259de19d4fd177b58ccd418..9b1321d645172b4ff73ddf604cd5c791a2c60257 100644 --- a/core/ops/norm/norm_ops_runner.cpp +++ b/core/ops/norm/norm_ops_runner.cpp @@ -62,16 +62,9 @@ AsdOps::Status NormOpsRunner::SetupKernelGraph(const VariantPack &variantPack) bool NormOpsRunner::CalcLayerNormTensor(const VariantPack &variantPack, int64_t &beginDim) { - AsdOps::TensorDesc inputDesc; - inputDesc.dtype = variantPack.inTensors.at(0).desc.dtype; - if (variantPack.inTensors.at(0).desc.dims.size() > variantPack.inTensors.at(1).desc.dims.size()) { - inputDesc.dims = variantPack.inTensors.at(0).desc.dims; - } else { - inputDesc.dims = variantPack.inTensors.at(1).desc.dims; - } - - const AsdOps::Tensor &weightTensor = variantPack.inTensors.at(2); - const AsdOps::Tensor &biasTensor = variantPack.inTensors.at(3); + const AsdOps::TensorDesc &inputDesc = variantPack.inTensors.at(0).desc; + const AsdOps::Tensor &weightTensor = variantPack.inTensors.at(1); + const AsdOps::Tensor &biasTensor = variantPack.inTensors.at(2); ASD_LOG(INFO) << GetName() << " layer norm input desc:" << TensorUtil::AsdOpsTensorDescToString(inputDesc) << ", weightTensor:" << TensorUtil::AsdOpsTensorToString(weightTensor)