From e84dda9bdecfcf25c7fc64dd8a226411d8772169 Mon Sep 17 00:00:00 2001 From: xiebangrui Date: Mon, 16 Oct 2023 11:27:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=8F=AF=E8=83=BD=E5=AF=BC?= =?UTF-8?q?=E8=87=B4ReDoS=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tests/st/util/testcase/npu_attrs_test.cc | 32 +++++++++++ .../tests/ut/util/testcase/npu_attrs_test.cc | 32 +++++++++++ tf_adapter/util/npu_attrs.cc | 53 +++++++++++-------- 3 files changed, 94 insertions(+), 23 deletions(-) diff --git a/tf_adapter/tests/st/util/testcase/npu_attrs_test.cc b/tf_adapter/tests/st/util/testcase/npu_attrs_test.cc index 8f6df65cf..da59771c9 100644 --- a/tf_adapter/tests/st/util/testcase/npu_attrs_test.cc +++ b/tf_adapter/tests/st/util/testcase/npu_attrs_test.cc @@ -274,5 +274,37 @@ TEST_F(NpuAttrTest, SetNpuOptimizerAttrInvalidEnableOnlineInference) { s = NpuAttrs::SetNpuOptimizerAttr(options, nullptr); EXPECT_EQ(s.ok(), false); } + +TEST_F(NpuAttrTest, GetNpuOptimizerAttrCheckDumpStep) { + AttrValueMap attr_map; + + AttrValue graph_compiler_cache_dir = AttrValue(); + graph_compiler_cache_dir.set_s("./cache_dir"); + attr_map["_graph_compiler_cache_dir"] = graph_compiler_cache_dir; + + AttrValue npu_optimizer = AttrValue(); + npu_optimizer.set_s("NpuOptimizer"); + attr_map["_NpuOptimizer"] = npu_optimizer; + + AttrValue enable_dump = AttrValue(); + enable_dump.set_s("1"); + attr_map["_enable_dump"] = enable_dump; + + AttrValue dump_step = AttrValue(); + dump_step.set_s("yyy"); + attr_map["_dump_step"] = dump_step; + + AttrSlice attrs(&attr_map); + const auto &all_options = NpuAttrs::GetAllAttrOptions(attrs); + EXPECT_NE(all_options.find("dump_step"), all_options.cend()); + + AttrValue dump_step_2 = AttrValue(); + dump_step_2.set_s("0|2-1"); + attr_map["_dump_step"] = dump_step_2; + + AttrSlice attrs2(&attr_map); + const auto &all_options2 = NpuAttrs::GetAllAttrOptions(attrs2); + EXPECT_NE(all_options2.find("dump_step"), all_options2.cend()); +} } } // end tensorflow diff --git a/tf_adapter/tests/ut/util/testcase/npu_attrs_test.cc b/tf_adapter/tests/ut/util/testcase/npu_attrs_test.cc index a33b0362c..4c2b4f5ad 100644 --- a/tf_adapter/tests/ut/util/testcase/npu_attrs_test.cc +++ b/tf_adapter/tests/ut/util/testcase/npu_attrs_test.cc @@ -375,5 +375,37 @@ TEST_F(NpuAttrTest, CheckGraphCompilerCacheDir) { ASSERT_TRUE(find_ret != all_options.cend()); EXPECT_EQ(find_ret->second, "./cache_dir"); } + +TEST_F(NpuAttrTest, GetNpuOptimizerAttrCheckDumpStep) { + AttrValueMap attr_map; + + AttrValue graph_compiler_cache_dir = AttrValue(); + graph_compiler_cache_dir.set_s("./cache_dir"); + attr_map["_graph_compiler_cache_dir"] = graph_compiler_cache_dir; + + AttrValue npu_optimizer = AttrValue(); + npu_optimizer.set_s("NpuOptimizer"); + attr_map["_NpuOptimizer"] = npu_optimizer; + + AttrValue enable_dump = AttrValue(); + enable_dump.set_s("1"); + attr_map["_enable_dump"] = enable_dump; + + AttrValue dump_step = AttrValue(); + dump_step.set_s("yyy"); + attr_map["_dump_step"] = dump_step; + + AttrSlice attrs(&attr_map); + const auto &all_options = NpuAttrs::GetAllAttrOptions(attrs); + ASSERT_TRUE(all_options.find("dump_step") != all_options.cend()); + + AttrValue dump_step_2 = AttrValue(); + dump_step_2.set_s("0|2-1"); + attr_map["_dump_step"] = dump_step_2; + + AttrSlice attrs2(&attr_map); + const auto &all_options2 = NpuAttrs::GetAllAttrOptions(attrs2); + ASSERT_TRUE(all_options2.find("dump_step") != all_options2.cend()); +} } } // end tensorflow diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc index a9a39c6e7..716a00185 100644 --- a/tf_adapter/util/npu_attrs.cc +++ b/tf_adapter/util/npu_attrs.cc @@ -46,6 +46,11 @@ std::map NpuAttrs::dataset_execute_info_; std::map NpuAttrs::init_options_; std::mutex NpuAttrs::mutex_; const static int32_t kRuntimeTypeHeterogeneous = 1; +const std::string kNumerics = "0123456789"; +const std::string kErrMsgInvalidStepSets = "dump_step only support dump <= 100 sets of data"; +const std::string kErrMsgReverseStepNum = "in range steps, the first step is >= "\ + "second step, correct example:'0|5|10-20'"; +const std::string kErrMsgInvalidStepFormat = "dump_step string style is error, correct example:'0|5|10|50-100'"; bool NpuAttrs::CheckIsNewDataTransfer() { uint32_t device_id = 0U; @@ -186,29 +191,33 @@ void Split(const std::string &s, std::vector &result, const char *d } inline Status checkDumpStep(const std::string &dump_step) { - std::string tmp_dump_step = dump_step + "|"; - std::smatch result; std::vector match_vecs; - std::regex pattern(R"((\d{1,}-\d{1,}\||\d{1,}\|)+)"); - if (regex_match(tmp_dump_step, result, pattern)) { - Split(result.str(), match_vecs, "|"); - // 100 is the max sets of dump steps. - if (match_vecs.size() > 100) { - return errors::InvalidArgument("dump_step only support dump <= 100 sets of data"); - } - for (const auto &match_vec : match_vecs) { - std::vector tmp_vecs; - Split(match_vec, tmp_vecs, "-"); - if (tmp_vecs.size() > 1) { - if (std::atoi(tmp_vecs[0].c_str()) >= std::atoi(tmp_vecs[1].c_str())) { - return errors::InvalidArgument("in range steps, the first step is >= " - "second step, correct example:'0|5|10-20'"); - } + Split(dump_step, match_vecs, "|"); + // 100 is the max sets of dump steps. + if (match_vecs.size() > 100U) { + return errors::InvalidArgument(kErrMsgInvalidStepSets); + } + const auto is_str_num = [] (const std::string &s) -> bool { + return s.find_first_not_of(kNumerics) == string::npos; + }; + for (const auto &match_vec : match_vecs) { + std::vector tmp_vecs; + Split(match_vec, tmp_vecs, "-"); + // 正确的格式是1或者2-3这种,因此tmp_vecs的大小只能是1或者2 + if (tmp_vecs.size() == 1U) { + if (!is_str_num(tmp_vecs[0U])) { + return errors::InvalidArgument(kErrMsgInvalidStepFormat); + } + } else if (tmp_vecs.size() == 2U) { + if (!is_str_num(tmp_vecs[0U]) || !is_str_num(tmp_vecs[1U])) { + return errors::InvalidArgument(kErrMsgInvalidStepFormat); + } + if (std::atoi(tmp_vecs[0U].c_str()) >= std::atoi(tmp_vecs[1U].c_str())) { + return errors::InvalidArgument(kErrMsgReverseStepNum); } + } else { + return errors::InvalidArgument(kErrMsgInvalidStepFormat); } - } else { - return errors::InvalidArgument("dump_step string style is error," - " correct example:'0|5|10|50-100'"); } return Status::OK(); } @@ -439,7 +448,6 @@ std::map NpuAttrs::GetSessOptions(const OpKernelConstr Status s = checkDumpStep(dump_step); if (!s.ok()) { ADP_LOG(FATAL) << s.error_message(); - LOG(FATAL) << s.error_message(); } } if (ctx->GetAttr("_dump_mode", &dump_mode) == Status::OK()) { @@ -1278,7 +1286,6 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & Status s = checkDumpStep(dump_step); if (!s.ok()) { ADP_LOG(FATAL) << s.error_message(); - LOG(FATAL) << s.error_message(); } } } @@ -1759,7 +1766,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options Status s = checkDumpStep(dump_step); if (!s.ok()) { ADP_LOG(FATAL) << s.error_message(); - LOG(FATAL) << s.error_message(); + return errors::Internal(s.error_message()); } } if (params.count("dump_mode") > 0) { -- Gitee