From 734642145d85276491f7fa68d71e196dd24119ce Mon Sep 17 00:00:00 2001 From: CuiXiaoFeng Date: Mon, 18 Aug 2025 11:16:06 +0800 Subject: [PATCH 1/3] Add switch optimization --- tensorflow/core/grappler/optimizers/BUILD | 25 ++ .../grappler/optimizers/meta_optimizer.cc | 6 + .../grappler/optimizers/switch_optimizer.cc | 370 ++++++++++++++++++ .../grappler/optimizers/switch_optimizer.h | 52 +++ .../core/protobuf/rewriter_config.proto | 2 + tensorflow/python/eager/context.py | 2 + tensorflow/python/framework/config.py | 1 + 7 files changed, 458 insertions(+) create mode 100644 tensorflow/core/grappler/optimizers/switch_optimizer.cc create mode 100644 tensorflow/core/grappler/optimizers/switch_optimizer.h diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 9be609b3..16307b89 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -620,6 +620,7 @@ cc_library( ":remapper", ":scoped_allocator_optimizer", ":shape_optimizer", + ":switch_optimizer", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -792,6 +793,30 @@ tf_cuda_cc_test( ], ) +tf_kernel_library( + name = "switch_optimizer", + srcs = ["switch_optimizer.cc"], + hdrs = [ + "switch_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":constant_folding", + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:frame", + "//tensorflow/core/grappler/utils:symbolic_shapes", + ], +) + + tf_kernel_library( name = "remapper", srcs = ["remapper.cc"], diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc 
b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index da83e413..4e671327 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/remapper.h" #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h" #include "tensorflow/core/grappler/optimizers/shape_optimizer.h" +#include "tensorflow/core/grappler/optimizers/switch_optimizer.h" #include "tensorflow/core/grappler/utils/canonicalizer.h" #include "tensorflow/core/grappler/utils/colocation.h" #include "tensorflow/core/grappler/utils/functions.h" @@ -179,6 +180,7 @@ std::unique_ptr MetaOptimizer::MakeNewOptimizer( /*lower_control_flow=*/!IsSingleThreadedExecutor())); MK_OPT("constfold", new ConstantFolding(cpu_device_)); MK_OPT("shape", new ShapeOptimizer()); + MK_OPT("switch", new SwitchOptimizer()); MK_OPT("remap", new Remapper(cfg_.remapping())); MK_OPT("layout", new GenericLayoutOptimizer()); MK_OPT("auto_mixed_precision", @@ -234,6 +236,9 @@ Status MetaOptimizer::InitializeOptimizers( if (cfg_.shape_optimization() != RewriterConfig::OFF) { optimizers->push_back(MakeUnique()); } + if (cfg_.switch_optimization() != RewriterConfig::OFF) { + optimizers->push_back(MakeUnique()); + } if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) { optimizers->push_back( MakeUnique(cfg_.auto_mixed_precision())); @@ -811,6 +816,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) { rewrite_cfg.function_optimization() != RewriterConfig::OFF || rewrite_cfg.constant_folding() != RewriterConfig::OFF || rewrite_cfg.shape_optimization() != RewriterConfig::OFF || + rewrite_cfg.switch_optimization() != RewriterConfig::OFF || rewrite_cfg.remapping() != RewriterConfig::OFF || rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF || rewrite_cfg.loop_optimization() != RewriterConfig::OFF || diff --git 
a/tensorflow/core/grappler/optimizers/switch_optimizer.cc b/tensorflow/core/grappler/optimizers/switch_optimizer.cc new file mode 100644 index 00000000..8ec03f95 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/switch_optimizer.cc @@ -0,0 +1,370 @@ +/* Copyright 2025 Huawei. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/switch_optimizer.h" + +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" +#include "tensorflow/core/lib/core/errors.h" +#include + +namespace tensorflow { +namespace grappler { +const string thisPrefix = "my_"; + +struct SwitchChain { + // The producer, usually a data node. + // Only got name of the producer. 
As we are not using NodeMap, we don't know its NodeDef + string producer; + + // Ths consumer, usually a compute node guarded by switch node[s] + NodeDef *consumer; + + // The chain of switches, ordered in reverse + std::vector > switches; + + // The chain of predicate, again only have their names here + std::vector > predicates; + + // The built string used to match for the repeated chain of predicates + string match_str; +}; + +using SwitchChains = std::vector; +using Group = std::vector; +// TODO: describing this optimization +Status SwitchOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + const char* enable_dump = getenv("SWTICH_OPTIMIZATION_INFO"); + if (enable_dump != nullptr) { + std::cerr << "In SwitchOptimizer::Optimize\n"; + } + *optimized_graph = item.graph; + std::vector groupList={}; + std::deque worklist = {}; + for (NodeDef& node : *optimized_graph->mutable_node()) { + worklist.push_back(&node); + } + std::vector tranversedList={}; + std::unordered_map createdLogicalNodes = {}; + for (;!worklist.empty(); worklist.pop_front()) { + NodeDef* np = worklist.front(); + NodeDef& node = *np; + if (enable_dump != nullptr) { + std::cerr << np->DebugString() <<"\n"; + } + // Node is ready if const or placeholder + if (IsConstant(node) || IsPlaceholder(node) || IsIdentity(node)) { + tranversedList.push_back(&node); + continue; + } + auto node_name = NodeName(node.name()); + // If this node was created by this pass(starts with thisPrefix) in earlier run, + // add it into the createLogicalNodes list to aviod creating duplicated nodes + if (node_name.compare(0, thisPrefix.size(), thisPrefix) == 0) { + createdLogicalNodes[node_name] = &node; + tranversedList.push_back(&node); + continue; + } + + auto all_inputs = node.input(); + bool ready = true; + // if any of inputs is not in the tranversedList or const/variables, it is not ready. 
+ for (auto in : all_inputs) { + if (!std::any_of(tranversedList.begin(), tranversedList.end(), [in](NodeDef* n) { + int dummy; + return n->name() == ParseNodeName(in, &dummy); })) { + worklist.push_back(&node); + ready = false; + break; + } + } + if (!ready) continue; + tranversedList.push_back(&node); + bool found = false; + for (auto in : all_inputs) { + int pos; + string input_name = ParseNodeName(in, &pos); + // check if this switch's input data (which should be another switch) is in the list + for (Group& nodes: groupList) { + if (std::any_of(nodes.begin(), nodes.end(), [input_name](NodeDef* n) { + return IsSwitch(*n) && n->name() == input_name; })) + { + if (enable_dump != nullptr) { + std::cerr << "Found existing node\n"; + std::cerr << "Pusing node to existing group: " << node.name()<< "\n"; + } + nodes.push_back(&node); + found = true; + break; + } + } + // If found, don't break here as a node may have more than one inputs that are switch + } + if (!found) { + if (IsSwitch(node)) { + Group newGroup; + if (enable_dump != nullptr) { + std::cerr << "Pusing node to new group: " << node.name()<< "\n"; + } + newGroup.push_back(&node); + groupList.push_back(newGroup); + } + } + } +if (enable_dump != nullptr) { + for (Group nodes: groupList) { + std::cerr << "Group start: " << nodes.size() << "\n"; + for (auto n: nodes) { + std::cerr << n->name() << "\n"; + } + std::cerr << "Group end\n"; + } + + std::cerr << "Previously created logical node found:\n"; + for (auto n : createdLogicalNodes) { + std::cerr << n.second->DebugString() << "\n"; + } +} + + // Now build switch chains + SwitchChains chains={}; + for (Group nodes: groupList) { + // mark bypass + if (nodes.size() != 5) + continue; + std::vector switchList = {}; + std::vector consumerList = {}; + for (NodeDef * np : nodes) { + if (IsSwitch(*np)) + switchList.push_back(np); + else + consumerList.push_back(np); + } + for (NodeDef * consumer: consumerList) { + SwitchChain chain={}; + chain.consumer = consumer; + 
NodeDef * tailSwitch; + bool found = false; + for (auto in: consumer->input()) { + int pos; + auto it = std::find_if(switchList.begin(), switchList.end(), [in,&pos](NodeDef *n) { + return n->name() == ParseNodeName(in, &pos);}); + if (it != switchList.end()) { + tailSwitch = *it; + // Insert this switch in the front of the list + chain.switches.insert(chain.switches.begin(), std::make_pair(tailSwitch, bool(pos))); + found = true; + break; + } + } + if (found) { + for (NodeDef * theSwitch = tailSwitch;;) { + auto switch_data = theSwitch->input(0); + int data_pos; + string switch_data_name = ParseNodeName(switch_data, &data_pos); + if (enable_dump != nullptr) { + std::cerr << "Data of switch: " << switch_data_name << "\n"; // cxf success + } + auto it = std::find_if(switchList.begin(), switchList.end(), [switch_data_name](NodeDef *n) { + return n->name() == switch_data_name;}); + if (it != switchList.end()) { + // Insert this switch in the front of the list + chain.switches.insert(chain.switches.begin(), std::make_pair(*it, bool(data_pos))); + theSwitch = *it; + } + else { + // Cannot find further switch in the producer chain, Log the most recent producer and stop + chain.producer = theSwitch->input(0); + break; + } + } + } + else { + if (enable_dump != nullptr) { + return Status(error::NOT_FOUND, "Not switch in a chain"); + } + } + // mark bypass + if (chain.switches.size() != 4) + continue; + chains.push_back(std::move(chain)); + } + } +if (enable_dump != nullptr) { + std::cerr << "Chains:\n"; + for (auto chain: chains) { + std::cerr << "Producer: " << chain.producer << "\n"; + //for (auto s: chain.switches) { + for (auto it = chain.switches.begin(); it != chain.switches.end(); ++it) { + std::cerr << it->first->name() <<":"<second<< " - "; + } + std::cerr << "\nConsumer: " << chain.consumer->name() << "\n"; + } +} + + // Now build predicate chains from the switch chains and match_str + //using predicateChain = std::pair, string>; + for (auto& sc : chains) { + 
sc.match_str = ""; + for (auto it = sc.switches.begin(); it != sc.switches.end(); ++it) { + auto thisSwitch = *it; + // Always take the last predicate as true in matching + auto thisPred = thisSwitch.first->input(1); + auto thisPredValue = thisSwitch.second; + //auto thisPredValue = ((std::next(it) == sc.switches.end()) ? true : thisSwitch.second); + sc.predicates.push_back(std::make_pair(thisPred, thisPredValue)); + sc.match_str += "@"; + sc.match_str += thisPred; + sc.match_str += "^"; + sc.match_str += std::to_string(thisPredValue); + } + } + +if (enable_dump != nullptr) { + std::cerr << "Predicate chain strings:\n"; + for (auto p: chains) { + std::cerr << p.match_str <<"\n"; + } +} + + // Now group the same predicate chains + std::unordered_map> groupByPredicates; + for (auto p: chains) { + auto key = p.match_str; + groupByPredicates[key].push_back(p); + } + +if (enable_dump != nullptr) { + std::cerr << "Grouped predicate chains:\n"; + for (auto g: groupByPredicates) { + std::cerr << g.first << ": "; + for (SwitchChain p: g.second) { + for (auto pred: p.predicates) { + std::cerr << pred.first << ":" << pred.second << "-"; + } + std::cerr << " & "; + } + std::cerr << "\n"; + } +} + + for (auto g: groupByPredicates) { + // Pick up the first chain as a representative for rest of them + SwitchChain firstChain = *g.second.begin(); + if (firstChain.predicates.size() <= 1) continue; + auto thisDevice = firstChain.switches[0].first->device(); + bool first_pred = true; + string prevPredicate = {}; + bool prevPredValue = true; + //for (auto pred: firstChain.predicates) { + for (auto it = firstChain.predicates.begin(); it != firstChain.predicates.end(); ++it) { + auto pred = *it; + auto thisPredicate = pred.first; + auto thisPredValue = pred.second; + // For the last predicator, we don't need to create the not operation, + // as the not logic of the last predicator is embedded in the edge between + // the last switch and the consumer. We don't change this edge. 
+ if (!thisPredValue /* && std::next(it) != firstChain.predicates.end()*/) { + // Create a logical not node + string node_name = thisPrefix + "not_" + thisPredicate; + if (createdLogicalNodes.find(node_name) != createdLogicalNodes.end()) { + thisPredicate = node_name; + } + else { + NodeDef* not_op = optimized_graph->add_node(); + not_op->set_op("LogicalNot"); + not_op->add_input(thisPredicate); + not_op->set_name(node_name); + not_op->set_device(thisDevice); + createdLogicalNodes[node_name] = not_op; + thisPredicate = node_name; + } + } + if (!first_pred) { + // Create a logical and node + string node_name = thisPrefix + prevPredicate + "_and_" + thisPredicate; + if (createdLogicalNodes.find(node_name) != createdLogicalNodes.end()) { + thisPredicate = node_name; + } + else { + NodeDef* and_op = optimized_graph->add_node(); + and_op->set_op("LogicalAnd"); + and_op->add_input(prevPredicate); + and_op->add_input(thisPredicate); + and_op->set_name(node_name); + and_op->set_device(thisDevice); + createdLogicalNodes[node_name] = and_op; + thisPredicate = node_name; + } + } + first_pred = false; + prevPredicate = thisPredicate; + prevPredValue = thisPredValue; + } + + if (prevPredicate.empty()) continue; + for (SwitchChain p: g.second) { + NodeDef * lastSwitch = (*(p.switches.rbegin())).first; + if (enable_dump != nullptr) { + std::cerr << "Last switch before change: \n"; + std::cerr << lastSwitch->DebugString(); + } + lastSwitch->set_input(0, p.producer); + lastSwitch->set_input(1, prevPredicate); + if (prevPredValue == false) { + int consumer_input_size = p.consumer->input().size(); + int pos; + string lastSwitchName = ParseNodeName(lastSwitch->name(), &pos); + for (int i = 0; i< consumer_input_size; i++) { + if (lastSwitchName == p.consumer->input(i)) { + if (pos != 0) { + return Status(error::NOT_FOUND, "consumer of the last switch is not using the false branch"); + } + p.consumer->set_input(i, lastSwitchName + ":1"); + } + } + } + if (enable_dump != nullptr) { + 
std::cerr << "Last switch after change: \n"; + std::cerr << lastSwitch->DebugString(); + } + } + } + if (enable_dump != nullptr) { + std::cerr << "List of logical node created:\n"; + for (auto n: createdLogicalNodes) { + std::cerr << n.second->DebugString(); + } + std::cerr << "After optimization:\n" << optimized_graph->DebugString(); + } + return Status::OK(); +} + +void SwitchOptimizer::Feedback(Cluster* /*cluster*/, + const GrapplerItem& /*item*/, + const GraphDef& /*optimized_graph*/, + double /*result*/) { + // Nothing to do for LoopOptimizer. +} + +} // end namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/switch_optimizer.h b/tensorflow/core/grappler/optimizers/switch_optimizer.h new file mode 100644 index 00000000..5659d420 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/switch_optimizer.h @@ -0,0 +1,52 @@ +/* Copyright 2025 Huawei. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SWITCH_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SWITCH_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Optimize TensorFlow subgraphs that operate on shape and shape related +// information. +class SwitchOptimizer : public GraphOptimizer { + public: + SwitchOptimizer() {} + explicit SwitchOptimizer(RewriterConfig::Toggle opt_level) {} + + ~SwitchOptimizer() override {} + + string name() const override { return "~witch_optimizer"; }; + + bool UsesFunctionLibrary() const override { return false; } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SWITCH_OPTIMIZER_H_ diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index e657a184..b0e7ab66 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -87,6 +87,8 @@ message RewriterConfig { // Note that this can change the numerical stability of the graph and may // require the use of loss scaling to maintain model convergence. Toggle auto_mixed_precision = 23; + // Switch reduction optimization + Toggle switch_optimization = 24; // Disable the entire meta optimizer (off by default). 
bool disable_meta_optimizer = 19; diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 073c3338..3982751e 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -925,6 +925,7 @@ class Context(object): rewriter_toggle("layout_optimizer") rewriter_toggle("constant_folding") rewriter_toggle("shape_optimization") + rewriter_toggle("switch_optimization") rewriter_toggle("remapping") rewriter_toggle("arithmetic_optimization") rewriter_toggle("dependency_optimization") @@ -1408,6 +1409,7 @@ class Context(object): rewriter_toggle("layout_optimizer") rewriter_toggle("constant_folding") rewriter_toggle("shape_optimization") + rewriter_toggle("switch_optimization") rewriter_toggle("remapping") rewriter_toggle("arithmetic_optimization") rewriter_toggle("dependency_optimization") diff --git a/tensorflow/python/framework/config.py b/tensorflow/python/framework/config.py index c696675f..065a812d 100644 --- a/tensorflow/python/framework/config.py +++ b/tensorflow/python/framework/config.py @@ -137,6 +137,7 @@ def set_optimizer_experimental_options(options): Statically infer the value of tensors when possible, and materialize the result using constants. - shape_optimization: Simplify computations made on shapes. + - switch_optimization: Simplify chain of switches - remapping: Remap subgraphs onto more efficient implementations. - arithmetic_optimization: Simplify arithmetic ops with common sub-expression elimination and arithmetic simplification. 
-- Gitee From 1760d8778426a4037e2fa652bca5630f70be3987 Mon Sep 17 00:00:00 2001 From: CuiXiaoFeng Date: Thu, 21 Aug 2025 20:35:34 +0800 Subject: [PATCH 2/3] Remove restrict for switch opt and reject switch with multi users --- .../grappler/optimizers/switch_optimizer.cc | 230 ++++++++---------- .../grappler/optimizers/switch_optimizer.h | 5 +- 2 files changed, 101 insertions(+), 134 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/switch_optimizer.cc b/tensorflow/core/grappler/optimizers/switch_optimizer.cc index 8ec03f95..590009d0 100644 --- a/tensorflow/core/grappler/optimizers/switch_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/switch_optimizer.cc @@ -24,16 +24,14 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/lib/core/errors.h" -#include namespace tensorflow { namespace grappler { -const string thisPrefix = "my_"; +const string thisPrefix = "swo_"; struct SwitchChain { // The producer, usually a data node. - // Only got name of the producer. As we are not using NodeMap, we don't know its NodeDef - string producer; + NodeDef *producer; // Ths consumer, usually a compute node guarded by switch node[s] NodeDef *consumer; @@ -41,7 +39,7 @@ struct SwitchChain { // The chain of switches, ordered in reverse std::vector > switches; - // The chain of predicate, again only have their names here + // The chain of predicate, only need their names here std::vector > predicates; // The built string used to match for the repeated chain of predicates @@ -50,68 +48,43 @@ struct SwitchChain { using SwitchChains = std::vector; using Group = std::vector; -// TODO: describing this optimization + +// SwitchOptimizer reduces a chain of switches. 
+// For a given chain of switches +// tf.add(tf.raw_ops.Switch(tf.raw_ops.Switch(tf.raw_ops.Switch(data, pred0)[1], pred1)[0], pred2))[1], x) +// Replace it with +// all_preds = tf.logical_and(pred0, tf.logical_and(tf.logical_not(pred1), pred2)) +// tf.add(tf.raw_ops.Switch(data, all_preds)[1],x) +// By doing so, number of Switches can be reduced, at the cost of computing more logical_and/not node. +// However, in the real-world models, the same chain of logical_and/not can be reused in many places. Say there +// are other places like this: +// tf.multiply(tf.raw_ops.Switch(tf.raw_ops.Switch(tf.raw_ops.Switch(data2, pred0)[1], pred1)[0], pred2))[1], y) +// to be replaced with +// tf.multiply(data2, tf.raw_ops.Switch(all_preds, all_preds)[1], y) +// Thus this replacement should actually benefit. Status SwitchOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { - const char* enable_dump = getenv("SWTICH_OPTIMIZATION_INFO"); - if (enable_dump != nullptr) { - std::cerr << "In SwitchOptimizer::Optimize\n"; - } + VLOG(2) << "In SwitchOptimizer::Optimize\n"; *optimized_graph = item.graph; + NodeMap node_map(optimized_graph); + Status status; + utils::MutableGraphView graph_view(optimized_graph, &status); + TF_RETURN_IF_ERROR(graph_view.SortTopologically(false, {})); std::vector groupList={}; - std::deque worklist = {}; - for (NodeDef& node : *optimized_graph->mutable_node()) { - worklist.push_back(&node); - } - std::vector tranversedList={}; - std::unordered_map createdLogicalNodes = {}; - for (;!worklist.empty(); worklist.pop_front()) { - NodeDef* np = worklist.front(); + for (auto& nv: graph_view.GetNodes()) { + NodeDef* np = nv.node(); NodeDef& node = *np; - if (enable_dump != nullptr) { - std::cerr << np->DebugString() <<"\n"; - } - // Node is ready if const or placeholder - if (IsConstant(node) || IsPlaceholder(node) || IsIdentity(node)) { - tranversedList.push_back(&node); - continue; - } - auto node_name = 
NodeName(node.name()); - // If this node was created by this pass(starts with thisPrefix) in earlier run, - // add it into the createLogicalNodes list to aviod creating duplicated nodes - if (node_name.compare(0, thisPrefix.size(), thisPrefix) == 0) { - createdLogicalNodes[node_name] = &node; - tranversedList.push_back(&node); - continue; - } - - auto all_inputs = node.input(); - bool ready = true; - // if any of inputs is not in the tranversedList or const/variables, it is not ready. - for (auto in : all_inputs) { - if (!std::any_of(tranversedList.begin(), tranversedList.end(), [in](NodeDef* n) { - int dummy; - return n->name() == ParseNodeName(in, &dummy); })) { - worklist.push_back(&node); - ready = false; - break; - } - } - if (!ready) continue; - tranversedList.push_back(&node); bool found = false; + auto all_inputs = node.input(); for (auto in : all_inputs) { - int pos; - string input_name = ParseNodeName(in, &pos); + string input_name = NodeName(in); // check if this switch's input data (which should be another switch) is in the list for (Group& nodes: groupList) { if (std::any_of(nodes.begin(), nodes.end(), [input_name](NodeDef* n) { return IsSwitch(*n) && n->name() == input_name; })) { - if (enable_dump != nullptr) { - std::cerr << "Found existing node\n"; - std::cerr << "Pusing node to existing group: " << node.name()<< "\n"; - } + // Input is a switch node, and the switch node is in this group + VLOG(2) << "Pusing node to existing group: " << node.name()<< "\n"; nodes.push_back(&node); found = true; break; @@ -121,36 +94,28 @@ Status SwitchOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } if (!found) { if (IsSwitch(node)) { + // This is a switch node, and none of its input is in a existing group. 
+ // Adding this switch to a new group Group newGroup; - if (enable_dump != nullptr) { - std::cerr << "Pusing node to new group: " << node.name()<< "\n"; - } + VLOG(2) << "Pusing node to new group: " << node.name()<< "\n"; newGroup.push_back(&node); groupList.push_back(newGroup); } + else { + VLOG(2) << "Not adding into any group: " << node.name()<< "\n"; + } } } -if (enable_dump != nullptr) { for (Group nodes: groupList) { - std::cerr << "Group start: " << nodes.size() << "\n"; + VLOG(2) << "Group start: " << nodes.size() << "\n"; for (auto n: nodes) { - std::cerr << n->name() << "\n"; + VLOG(2) << n->name() << "\n"; } - std::cerr << "Group end\n"; - } - - std::cerr << "Previously created logical node found:\n"; - for (auto n : createdLogicalNodes) { - std::cerr << n.second->DebugString() << "\n"; + VLOG(2) << "Group end\n"; } -} - - // Now build switch chains + // Now build switch chains from groups SwitchChains chains={}; for (Group nodes: groupList) { - // mark bypass - if (nodes.size() != 5) - continue; std::vector switchList = {}; std::vector consumerList = {}; for (NodeDef * np : nodes) { @@ -159,6 +124,17 @@ if (enable_dump != nullptr) { else consumerList.push_back(np); } + // If a switch has more than two users forming a fork, reject the group + if (switchList.empty()) + continue; + bool invalide = false; + auto last_it = std::prev(switchList.end()); + for (auto it = switchList.begin(); it != last_it; it++) { + if (node_map.GetOutputs(NodeName((*it)->name())).size() != 1) + invalide = true; + } + if (invalide) + continue; for (NodeDef * consumer: consumerList) { SwitchChain chain={}; chain.consumer = consumer; @@ -181,9 +157,6 @@ if (enable_dump != nullptr) { auto switch_data = theSwitch->input(0); int data_pos; string switch_data_name = ParseNodeName(switch_data, &data_pos); - if (enable_dump != nullptr) { - std::cerr << "Data of switch: " << switch_data_name << "\n"; // cxf success - } auto it = std::find_if(switchList.begin(), switchList.end(), 
[switch_data_name](NodeDef *n) { return n->name() == switch_data_name;}); if (it != switchList.end()) { @@ -193,38 +166,31 @@ if (enable_dump != nullptr) { } else { // Cannot find further switch in the producer chain, Log the most recent producer and stop - chain.producer = theSwitch->input(0); + chain.producer = node_map.GetNode(NodeName(theSwitch->input(0))); break; } } } else { - if (enable_dump != nullptr) { - return Status(error::NOT_FOUND, "Not switch in a chain"); - } + return Status(error::NOT_FOUND, "Not switch in a chain"); } - // mark bypass - if (chain.switches.size() != 4) - continue; chains.push_back(std::move(chain)); } } -if (enable_dump != nullptr) { - std::cerr << "Chains:\n"; + VLOG(2) << "Chains:\n"; for (auto chain: chains) { - std::cerr << "Producer: " << chain.producer << "\n"; + VLOG(2) << "Producer: " << chain.producer->name() << "\n"; //for (auto s: chain.switches) { for (auto it = chain.switches.begin(); it != chain.switches.end(); ++it) { - std::cerr << it->first->name() <<":"<second<< " - "; + VLOG(2) << it->first->name() <<":"<second<< " - "; } - std::cerr << "\nConsumer: " << chain.consumer->name() << "\n"; + VLOG(2) << "\nConsumer: " << chain.consumer->name() << "\n"; } -} // Now build predicate chains from the switch chains and match_str //using predicateChain = std::pair, string>; for (auto& sc : chains) { - sc.match_str = ""; + if (sc.switches.size() <= 1) continue; for (auto it = sc.switches.begin(); it != sc.switches.end(); ++it) { auto thisSwitch = *it; // Always take the last predicate as true in matching @@ -239,12 +205,10 @@ if (enable_dump != nullptr) { } } -if (enable_dump != nullptr) { - std::cerr << "Predicate chain strings:\n"; + VLOG(2) << "Predicate chain strings:\n"; for (auto p: chains) { - std::cerr << p.match_str <<"\n"; + VLOG(2) << p.match_str <<"\n"; } -} // Now group the same predicate chains std::unordered_map> groupByPredicates; @@ -253,19 +217,17 @@ if (enable_dump != nullptr) { 
groupByPredicates[key].push_back(p); } -if (enable_dump != nullptr) { - std::cerr << "Grouped predicate chains:\n"; + VLOG(2) << "Grouped predicate chains:\n"; for (auto g: groupByPredicates) { - std::cerr << g.first << ": "; + VLOG(2) << (g.first.empty() ? "" : g.first) << ": "; for (SwitchChain p: g.second) { for (auto pred: p.predicates) { - std::cerr << pred.first << ":" << pred.second << "-"; + VLOG(2) << pred.first << ":" << pred.second << "-"; } - std::cerr << " & "; + VLOG(2) << " & "; } - std::cerr << "\n"; + VLOG(2) << "\n"; } -} for (auto g: groupByPredicates) { // Pick up the first chain as a representative for rest of them @@ -275,44 +237,55 @@ if (enable_dump != nullptr) { bool first_pred = true; string prevPredicate = {}; bool prevPredValue = true; - //for (auto pred: firstChain.predicates) { - for (auto it = firstChain.predicates.begin(); it != firstChain.predicates.end(); ++it) { - auto pred = *it; + for (auto pred: firstChain.predicates) { auto thisPredicate = pred.first; auto thisPredValue = pred.second; - // For the last predicator, we don't need to create the not operation, - // as the not logic of the last predicator is embedded in the edge between - // the last switch and the consumer. We don't change this edge. 
- if (!thisPredValue /* && std::next(it) != firstChain.predicates.end()*/) { - // Create a logical not node + if (!thisPredValue) { string node_name = thisPrefix + "not_" + thisPredicate; - if (createdLogicalNodes.find(node_name) != createdLogicalNodes.end()) { + if (node_map.GetNode(node_name) != nullptr) { thisPredicate = node_name; } else { + // Not found in in exist nodes, create a logical not node NodeDef* not_op = optimized_graph->add_node(); not_op->set_op("LogicalNot"); not_op->add_input(thisPredicate); not_op->set_name(node_name); not_op->set_device(thisDevice); - createdLogicalNodes[node_name] = not_op; + // Not sure if need to set _dtype, as the default LogicalNot does not + //(*not_op->mutable_attr())["_dtype"].set_type(DT_BOOL); + tensorflow::TensorShapeProto* shape = + (*not_op->mutable_attr())["_output_shapes"] + .mutable_list() + ->add_shape(); + shape->set_unknown_rank(true); + node_map.AddNode(node_name, not_op); + node_map.AddOutput(thisPredicate, node_name); thisPredicate = node_name; } } if (!first_pred) { - // Create a logical and node string node_name = thisPrefix + prevPredicate + "_and_" + thisPredicate; - if (createdLogicalNodes.find(node_name) != createdLogicalNodes.end()) { + if (node_map.GetNode(node_name) != nullptr) { thisPredicate = node_name; } else { + // Not found in in exist nodes, create a logical and node NodeDef* and_op = optimized_graph->add_node(); and_op->set_op("LogicalAnd"); and_op->add_input(prevPredicate); and_op->add_input(thisPredicate); and_op->set_name(node_name); and_op->set_device(thisDevice); - createdLogicalNodes[node_name] = and_op; + // Not sure if need to set _dtype, as the default LogicalAnd does not + // (*and_op->mutable_attr())["_dtype"].set_type(DT_BOOL); + tensorflow::TensorShapeProto* shape = + (*and_op->mutable_attr())["_output_shapes"] + .mutable_list() + ->add_shape(); + shape->set_unknown_rank(true); + node_map.AddNode(node_name, and_op); + node_map.AddOutput(thisPredicate, node_name); 
thisPredicate = node_name; } } @@ -324,12 +297,14 @@ if (enable_dump != nullptr) { if (prevPredicate.empty()) continue; for (SwitchChain p: g.second) { NodeDef * lastSwitch = (*(p.switches.rbegin())).first; - if (enable_dump != nullptr) { - std::cerr << "Last switch before change: \n"; - std::cerr << lastSwitch->DebugString(); - } - lastSwitch->set_input(0, p.producer); + VLOG(2) << "Last switch before change: \n"; + VLOG(2) << lastSwitch->DebugString(); + auto old_input0 = lastSwitch->input(0); + auto old_input1 = lastSwitch->input(1); + lastSwitch->set_input(0, p.producer->name()); lastSwitch->set_input(1, prevPredicate); + node_map.UpdateInput(lastSwitch->name(), old_input0, p.producer->name()); + node_map.UpdateInput(lastSwitch->name(), old_input1, prevPredicate); if (prevPredValue == false) { int consumer_input_size = p.consumer->input().size(); int pos; @@ -340,22 +315,15 @@ if (enable_dump != nullptr) { return Status(error::NOT_FOUND, "consumer of the last switch is not using the false branch"); } p.consumer->set_input(i, lastSwitchName + ":1"); + node_map.UpdateInput(lastSwitch->name(), lastSwitchName, lastSwitchName + ":1"); } } } - if (enable_dump != nullptr) { - std::cerr << "Last switch after change: \n"; - std::cerr << lastSwitch->DebugString(); - } - } - } - if (enable_dump != nullptr) { - std::cerr << "List of logical node created:\n"; - for (auto n: createdLogicalNodes) { - std::cerr << n.second->DebugString(); + VLOG(2) << "Last switch after change: \n"; + VLOG(2) << lastSwitch->DebugString(); } - std::cerr << "After optimization:\n" << optimized_graph->DebugString(); } + // VLOG(3) << "Optimized graph =\n" << optimized_graph->DebugString(); return Status::OK(); } @@ -363,7 +331,7 @@ void SwitchOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/, const GraphDef& /*optimized_graph*/, double /*result*/) { - // Nothing to do for LoopOptimizer. + // Nothing to do for SwitchOptimizer. 
} } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/switch_optimizer.h b/tensorflow/core/grappler/optimizers/switch_optimizer.h index 5659d420..1d541a00 100644 --- a/tensorflow/core/grappler/optimizers/switch_optimizer.h +++ b/tensorflow/core/grappler/optimizers/switch_optimizer.h @@ -26,8 +26,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -// Optimize TensorFlow subgraphs that operate on shape and shape related -// information. +// Optimize TensorFlow subgraphs by shortening chains of switches class SwitchOptimizer : public GraphOptimizer { public: SwitchOptimizer() {} @@ -35,7 +34,7 @@ class SwitchOptimizer : public GraphOptimizer { ~SwitchOptimizer() override {} - string name() const override { return "~witch_optimizer"; }; + string name() const override { return "switch_optimizer"; }; bool UsesFunctionLibrary() const override { return false; } -- Gitee From 37c0b60280446836818823b559705b77c76c89f9 Mon Sep 17 00:00:00 2001 From: CuiXiaoFeng Date: Tue, 9 Sep 2025 09:28:45 +0800 Subject: [PATCH 3/3] Fix review comments --- tensorflow/core/grappler/optimizers/BUILD | 2 +- tensorflow/core/grappler/optimizers/switch_optimizer.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 16307b89..5958258d 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -802,7 +802,7 @@ tf_kernel_library( visibility = ["//visibility:public"], deps = [ ":constant_folding", - ":graph_optimizer", + ":graph_optimizer", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/grappler/optimizers/switch_optimizer.cc b/tensorflow/core/grappler/optimizers/switch_optimizer.cc index 590009d0..dec2ed3c 100644 --- a/tensorflow/core/grappler/optimizers/switch_optimizer.cc +++ 
b/tensorflow/core/grappler/optimizers/switch_optimizer.cc @@ -70,7 +70,7 @@ Status SwitchOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, Status status; utils::MutableGraphView graph_view(optimized_graph, &status); TF_RETURN_IF_ERROR(graph_view.SortTopologically(false, {})); - std::vector groupList={}; + std::vector groupList = {}; for (auto& nv: graph_view.GetNodes()) { NodeDef* np = nv.node(); NodeDef& node = *np; -- Gitee