From 1b389f4d9408367e91e166083a2d527ac872d0bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E6=A1=82=E5=86=9B?= Date: Thu, 27 Mar 2025 13:46:52 +0000 Subject: [PATCH] =?UTF-8?q?!2949=20dfx=20print=20option=20name=20Merge=20p?= =?UTF-8?q?ull=20request=20!2949=20from=20=E9=BB=84=E6=A1=82=E5=86=9B/0327?= =?UTF-8?q?=5Fdfx?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/util/npu_attrs.cc | 12 +++++++----- tf_adapter/util/npu_attrs.h | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc index f1e23afb0..c6cdc0eaf 100644 --- a/tf_adapter/util/npu_attrs.cc +++ b/tf_adapter/util/npu_attrs.cc @@ -2076,7 +2076,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options "allow_mix_precision_fp16", "allow_mix_precision_bf16", "allow_fp32_to_bf16"}; - NPU_REQUIRES_OK(CheckValueAllowed(precision_mode, kPrecisionModeList)); + NPU_REQUIRES_OK(CheckValueAllowed("precision_mode", precision_mode, kPrecisionModeList)); init_options_["precision_mode"] = precision_mode; init_options_[ge::PRECISION_MODE] = precision_mode; } @@ -2084,7 +2084,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options precision_mode_v2 = params.at("precision_mode_v2").s(); const static std::vector kPrecisionModeV2List = {"fp16", "origin", "cube_fp16in_fp32out", "mixed_float16", "mixed_bfloat16", "cube_hif8", "mixed_hif8"}; - NPU_REQUIRES_OK(CheckValueAllowed(precision_mode_v2, kPrecisionModeV2List)); + NPU_REQUIRES_OK(CheckValueAllowed("precision_mode_v2", precision_mode_v2, kPrecisionModeV2List)); init_options_["precision_mode_v2"] = precision_mode_v2; init_options_["ge.exec.precision_mode_v2"] = precision_mode_v2; } @@ -2384,7 +2384,8 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options if (params.count("variable_use_1g_huge_page") > 0) { variable_use_1g_huge_page = std::to_string(params.at("variable_use_1g_huge_page").i()); const static std::vector kAllowList = {0, 1, 2}; - NPU_REQUIRES_OK(CheckValueAllowed(params.at("variable_use_1g_huge_page").i(), kAllowList)); + NPU_REQUIRES_OK(CheckValueAllowed("variable_use_1g_huge_page", + params.at("variable_use_1g_huge_page").i(), kAllowList)); } if (params.count("memory_optimization_policy") > 0) { memory_optimization_policy = params.at("memory_optimization_policy").s(); @@ -2464,7 +2465,8 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options (params.at("export_compile_stat").value_case() == AttrValue::ValueCase::kI)) { export_compile_stat = params.at("export_compile_stat").i(); const static std::vector kExportCompileStatList = {0, 1, 2}; - NPU_REQUIRES_OK(CheckValueAllowed(export_compile_stat, kExportCompileStatList)); + NPU_REQUIRES_OK(CheckValueAllowed("export_compile_stat", export_compile_stat, + kExportCompileStatList)); init_options_["export_compile_stat"] = std::to_string(export_compile_stat); init_options_["ge.exportCompileStat"] = std::to_string(export_compile_stat); } @@ -2486,7 +2488,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options "auto"}; NPU_REQUIRES(params.at("jit_compile").value_case() == params.at("jit_compile").kS, errors::InvalidArgument("The data type of jit_compile is invalid. Expected string type.")); - NPU_REQUIRES_OK(CheckValueAllowed(params.at("jit_compile").s(), kJitCompileList)); + NPU_REQUIRES_OK(CheckValueAllowed("jit_compile", params.at("jit_compile").s(), kJitCompileList)); jit_compile = ConvertToGeJitValue(params.at("jit_compile").s()); } else { jit_compile = "2"; // 2 means auto diff --git a/tf_adapter/util/npu_attrs.h b/tf_adapter/util/npu_attrs.h index e5f6a4bed..b81cced0e 100644 --- a/tf_adapter/util/npu_attrs.h +++ b/tf_adapter/util/npu_attrs.h @@ -83,12 +83,12 @@ class NpuAttrs { return ss.str(); } template - static Status CheckValueAllowed(const T &v, const std::vector &allowed_values) { + static Status CheckValueAllowed(const std::string &option, const T &v, const std::vector &allowed_values) { if (find(allowed_values.begin(), allowed_values.end(), v) != allowed_values.cend()) { return Status::OK(); } else { std::stringstream ss; - ss << "'" << v << "' is invalid, it should be one of the list:"; + ss << "option " << option << " value '" << v << "' is invalid, it should be one of the list:"; ss << VectorToString(allowed_values); return errors::InvalidArgument(ss.str()); } -- Gitee