From 0ae06cbea6e7871d457f9b8f60823cf2f13d6d6d Mon Sep 17 00:00:00 2001 From: huanruizhi Date: Tue, 31 Dec 2024 17:09:04 +0800 Subject: [PATCH] debug --- tf_adapter/optimizers/set_var_format_pass.cc | 18 +++ .../testcase/set_var_format_pass_test.cc | 111 ++++++++++++++++++ .../testcase/set_var_format_pass_test.cc | 111 ++++++++++++++++++ 3 files changed, 240 insertions(+) create mode 100644 tf_adapter/tests/st/optimizers/testcase/set_var_format_pass_test.cc create mode 100644 tf_adapter/tests/ut/optimizers/testcase/set_var_format_pass_test.cc diff --git a/tf_adapter/optimizers/set_var_format_pass.cc b/tf_adapter/optimizers/set_var_format_pass.cc index e54fbb317..59d6c70ac 100644 --- a/tf_adapter/optimizers/set_var_format_pass.cc +++ b/tf_adapter/optimizers/set_var_format_pass.cc @@ -144,12 +144,30 @@ Status SetVarFormatPass::Run(const GraphOptimizationPassOptions &options) { return Status::OK(); } + static std::atomic num{0}; + if (kDumpGraph) { + GraphDef ori_graph_def; + graph_in->ToGraphDef(&ori_graph_def); + string ori_model_path = GetDumpPath() + "BeforeSetVarFormatGraph_"; + string graph_path = ori_model_path + std::to_string(num) + ".pbtxt"; + (void)WriteTextProto(Env::Default(), graph_path, ori_graph_def); + } + for (Node *node : graph_in->op_nodes()) { if ((node != nullptr) && ((node->type_string() == KEY_VAR_HANDLE_OP_VALUE) || (node->type_string() == KEY_VARIABLE_V2_VALUE))) { (void) AssignFormatToVarOutNodes(node); } } + if (kDumpGraph) { + GraphDef ori_graph_def; + graph_in->ToGraphDef(&ori_graph_def); + string ori_model_path = GetDumpPath() + "AfterSetVarFormatGraph_"; + string graph_path = ori_model_path + std::to_string(num) + ".pbtxt"; + (void)WriteTextProto(Env::Default(), graph_path, ori_graph_def); + } + num.fetch_add(1); + return Status::OK(); } diff --git a/tf_adapter/tests/st/optimizers/testcase/set_var_format_pass_test.cc b/tf_adapter/tests/st/optimizers/testcase/set_var_format_pass_test.cc new file mode 100644 index 000000000..61df07810 --- /dev/null +++ b/tf_adapter/tests/st/optimizers/testcase/set_var_format_pass_test.cc @@ -0,0 +1,111 @@ +#include "tf_adapter/optimizers/set_var_format_pass.h" +#include "gtest/gtest.h" +#include "mmpa/mmpa_api.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { +class SetVarFormatPassTest : public testing::Test { + public: + SetVarFormatPassTest() : graph_(absl::make_unique(OpRegistry::Global())) {} + static void InitGraph(const string &graph_def_path, Graph *graph) { + GraphDef graph_def; + ReadTextProto(Env::Default(), graph_def_path, &graph_def); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); + } + + void InitGraph(const string &graph_def_path) { + char trusted_path[MMPA_MAX_PATH] = { "\0" }; + if (mmRealPath(graph_def_path.c_str(), trusted_path, MMPA_MAX_PATH) != EN_OK) { + LOG(ERROR) << "Get real path failed."; + return; + } + LOG(INFO) << "input graph def path: " << trusted_path; + InitGraph(trusted_path, graph_.get()); + original_ = CanonicalGraphString(graph_.get()); + } + + static bool IncludeNode(const Node *n) { return n->IsOp(); } + + static string EdgeId(const Node* n, int index) { + if (index == 0) { + return n->type_string(); + } else if (index == Graph::kControlSlot) { + return strings::StrCat(n->type_string(), ":control"); + } else { + return strings::StrCat(n->type_string(), ":", index); + } + } + + string CanonicalGraphString(Graph* g) { + for (Node* n : g->nodes()) { + if (IncludeNode(n)) { + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name("/job:localhost/replica:0/task:0/device:CPU:0"); + break; + } + } + } + + std::vector edges; + for (const Edge* e : g->edges()) { + if (IncludeNode(e->src()) && IncludeNode(e->dst())) { + edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->", + EdgeId(e->dst(), e->dst_input()))); + } + } + // Canonicalize + return strings::StrCat(absl::StrJoin(edges, ";")); + } + + string DoRunSetVarFormatPassTest() { + string before = CanonicalGraphString(graph_.get()); + LOG(INFO) << "Before set var format pass: " << before; + + std::unique_ptr *ug = &graph_; + GraphOptimizationPassOptions options; + SessionOptions session_options; + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_do_function_inlining(true); + auto *custom_config = session_options.config.mutable_graph_options()->mutable_rewrite_options()->add_custom_optimizers(); + custom_config->set_name("NpuOptimizer"); + AttrValue job = AttrValue(); + job.set_s("localhost"); + (*custom_config->mutable_parameter_map())["job"] = job; + options.session_options = &session_options; + options.graph = ug; + FunctionLibraryDefinition flib_def((*ug)->flib_def()); + options.flib_def = &flib_def; + setenv("PRINT_MODEL", "1", 1); + SetVarFormatPass().Run(options); + unsetenv("PRINT_MODEL"); + + string result = CanonicalGraphString(options.graph->get()); + LOG(INFO) << "After set var format pass: " << result; + return result; + } + + const string &OriginalGraph() const { return original_; } + + std::unique_ptr graph_; + string original_; + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(SetVarFormatPassTest, DumpTest) { + string org_graph_def_path = "tf_adapter/tests/ut//optimizers/pbtxt/input_max_size_test.pbtxt"; + InitGraph(org_graph_def_path); + DoRunSetVarFormatPassTest(); +} +} // end namespace +} // end tensorflow + diff --git a/tf_adapter/tests/ut/optimizers/testcase/set_var_format_pass_test.cc b/tf_adapter/tests/ut/optimizers/testcase/set_var_format_pass_test.cc new file mode 100644 index 000000000..61df07810 --- /dev/null +++ b/tf_adapter/tests/ut/optimizers/testcase/set_var_format_pass_test.cc @@ -0,0 +1,111 @@ +#include "tf_adapter/optimizers/set_var_format_pass.h" +#include "gtest/gtest.h" +#include "mmpa/mmpa_api.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { +class SetVarFormatPassTest : public testing::Test { + public: + SetVarFormatPassTest() : graph_(absl::make_unique(OpRegistry::Global())) {} + static void InitGraph(const string &graph_def_path, Graph *graph) { + GraphDef graph_def; + ReadTextProto(Env::Default(), graph_def_path, &graph_def); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); + } + + void InitGraph(const string &graph_def_path) { + char trusted_path[MMPA_MAX_PATH] = { "\0" }; + if (mmRealPath(graph_def_path.c_str(), trusted_path, MMPA_MAX_PATH) != EN_OK) { + LOG(ERROR) << "Get real path failed."; + return; + } + LOG(INFO) << "input graph def path: " << trusted_path; + InitGraph(trusted_path, graph_.get()); + original_ = CanonicalGraphString(graph_.get()); + } + + static bool IncludeNode(const Node *n) { return n->IsOp(); } + + static string EdgeId(const Node* n, int index) { + if (index == 0) { + return n->type_string(); + } else if (index == Graph::kControlSlot) { + return strings::StrCat(n->type_string(), ":control"); + } else { + return strings::StrCat(n->type_string(), ":", index); + } + } + + string CanonicalGraphString(Graph* g) { + for (Node* n : g->nodes()) { + if (IncludeNode(n)) { + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name("/job:localhost/replica:0/task:0/device:CPU:0"); + break; + } + } + } + + std::vector edges; + for (const Edge* e : g->edges()) { + if (IncludeNode(e->src()) && IncludeNode(e->dst())) { + edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->", + EdgeId(e->dst(), e->dst_input()))); + } + } + // Canonicalize + return strings::StrCat(absl::StrJoin(edges, ";")); + } + + string DoRunSetVarFormatPassTest() { + string before = CanonicalGraphString(graph_.get()); + LOG(INFO) << "Before set var format pass: " << before; + + std::unique_ptr *ug = &graph_; + GraphOptimizationPassOptions options; + SessionOptions session_options; + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_do_function_inlining(true); + auto *custom_config = session_options.config.mutable_graph_options()->mutable_rewrite_options()->add_custom_optimizers(); + custom_config->set_name("NpuOptimizer"); + AttrValue job = AttrValue(); + job.set_s("localhost"); + (*custom_config->mutable_parameter_map())["job"] = job; + options.session_options = &session_options; + options.graph = ug; + FunctionLibraryDefinition flib_def((*ug)->flib_def()); + options.flib_def = &flib_def; + setenv("PRINT_MODEL", "1", 1); + SetVarFormatPass().Run(options); + unsetenv("PRINT_MODEL"); + + string result = CanonicalGraphString(options.graph->get()); + LOG(INFO) << "After set var format pass: " << result; + return result; + } + + const string &OriginalGraph() const { return original_; } + + std::unique_ptr graph_; + string original_; + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(SetVarFormatPassTest, DumpTest) { + string org_graph_def_path = "tf_adapter/tests/ut//optimizers/pbtxt/input_max_size_test.pbtxt"; + InitGraph(org_graph_def_path); + DoRunSetVarFormatPassTest(); +} +} // end namespace +} // end tensorflow + -- Gitee