From 4b625dcc01af8c063f3a2179c26f5ac9664d11a4 Mon Sep 17 00:00:00 2001 From: lotus Date: Thu, 15 Jun 2023 20:24:51 +0800 Subject: [PATCH] feat:Norm CalcLayerNormTensor --- core/ops/add_norm/add_norm_ops_runner.cpp | 13 +++---------- core/ops/norm/norm_ops_runner.cpp | 13 +++---------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/core/ops/add_norm/add_norm_ops_runner.cpp b/core/ops/add_norm/add_norm_ops_runner.cpp index 6bfeda38..25f2730e 100644 --- a/core/ops/add_norm/add_norm_ops_runner.cpp +++ b/core/ops/add_norm/add_norm_ops_runner.cpp @@ -70,16 +70,9 @@ AsdOps::Status AddNormOpsRunner::SetupKernelGraph(const VariantPack &variantPack bool AddNormOpsRunner::CalcLayerNormTensor(const VariantPack &variantPack, int64_t &beginDim) { - AsdOps::TensorDesc inputDesc; - inputDesc.dtype = variantPack.inTensors.at(0).desc.dtype; - if (variantPack.inTensors.at(0).desc.dims.size() > variantPack.inTensors.at(1).desc.dims.size()) { - inputDesc.dims = variantPack.inTensors.at(0).desc.dims; - } else { - inputDesc.dims = variantPack.inTensors.at(1).desc.dims; - } - - const AsdOps::Tensor &weightTensor = variantPack.inTensors.at(2); - const AsdOps::Tensor &biasTensor = variantPack.inTensors.at(3); + const AsdOps::TensorDesc &inputDesc = variantPack.inTensors.at(0).desc; + const AsdOps::Tensor &weightTensor = variantPack.inTensors.at(1); + const AsdOps::Tensor &biasTensor = variantPack.inTensors.at(2); ASD_LOG(INFO) << GetName() << " layer norm input desc:" << TensorUtil::AsdOpsTensorDescToString(inputDesc) << ", weightTensor:" << TensorUtil::AsdOpsTensorToString(weightTensor) diff --git a/core/ops/norm/norm_ops_runner.cpp b/core/ops/norm/norm_ops_runner.cpp index 07edfb88..9b1321d6 100644 --- a/core/ops/norm/norm_ops_runner.cpp +++ b/core/ops/norm/norm_ops_runner.cpp @@ -62,16 +62,9 @@ AsdOps::Status NormOpsRunner::SetupKernelGraph(const VariantPack &variantPack) bool NormOpsRunner::CalcLayerNormTensor(const VariantPack &variantPack, int64_t &beginDim) { - AsdOps::TensorDesc inputDesc; - inputDesc.dtype = variantPack.inTensors.at(0).desc.dtype; - if (variantPack.inTensors.at(0).desc.dims.size() > variantPack.inTensors.at(1).desc.dims.size()) { - inputDesc.dims = variantPack.inTensors.at(0).desc.dims; - } else { - inputDesc.dims = variantPack.inTensors.at(1).desc.dims; - } - - const AsdOps::Tensor &weightTensor = variantPack.inTensors.at(2); - const AsdOps::Tensor &biasTensor = variantPack.inTensors.at(3); + const AsdOps::TensorDesc &inputDesc = variantPack.inTensors.at(0).desc; + const AsdOps::Tensor &weightTensor = variantPack.inTensors.at(1); + const AsdOps::Tensor &biasTensor = variantPack.inTensors.at(2); ASD_LOG(INFO) << GetName() << " layer norm input desc:" << TensorUtil::AsdOpsTensorDescToString(inputDesc) << ", weightTensor:" << TensorUtil::AsdOpsTensorToString(weightTensor) -- Gitee