From 9335c59c36b94c42bfc75e9fd3986768d1c802c0 Mon Sep 17 00:00:00 2001 From: rayshine <1324789704@qq.com> Date: Wed, 27 Aug 2025 17:49:24 +0800 Subject: [PATCH] =?UTF-8?q?Select=E7=AE=97=E5=AD=90=E5=8A=A0=E5=9B=BA?= =?UTF-8?q?=EF=BC=9A=E5=B8=B8=E9=87=8F=E8=BE=93=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kernels/embedding_fused_sparse_select.cc | 58 +++++++++++-------- tensorflow/core/ops/embedding_fused_ops.cc | 4 ++ .../fused_embedding_sparse_select_test.py | 42 +++++++++----- 3 files changed, 64 insertions(+), 40 deletions(-) diff --git a/tensorflow/core/kernels/embedding_fused_sparse_select.cc b/tensorflow/core/kernels/embedding_fused_sparse_select.cc index 953d65241..0a653cdd0 100644 --- a/tensorflow/core/kernels/embedding_fused_sparse_select.cc +++ b/tensorflow/core/kernels/embedding_fused_sparse_select.cc @@ -33,7 +33,19 @@ public: const Tensor& input_a = context->input(0); const Tensor& input_b = context->input(1); const Tensor& input_c = context->input(2); - + const Tensor& greater = context->input(3); + const Tensor& equal1 = context->input(4); + const Tensor& equal2 = context->input(5); + const Tensor& equal3 = context->input(6); + + int32_t equal1_val = equal1.flat()(0); + int32_t equal2_val = equal2.flat()(0); + int32_t equal3_val = equal3.flat()(0); + VLOG(1) << "equal1_val: " << equal1_val; + VLOG(1) << "equal2_val: " << equal2_val; + VLOG(1) << "equal3_val: " << equal3_val; + + int32_t greater_val = greater.flat()(0); auto a_flat = input_a.flat(); auto b_flat = input_b.flat(); auto c_flat = input_c.flat(); @@ -58,44 +70,40 @@ public: OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({N, 1}), &output_y)); OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({N, 2}), &output_w)); - auto out_x = output_x->matrix(); // [N,1] - auto out_y = output_y->matrix(); // [N,1] - auto out_w = output_w->matrix(); // [N,2] - - auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); - const int64 cost_per_unit = 10; - - auto work = [&](int64 start, int64 end) { - for (int64 i = start; i < end; i++) { - float a_greater = (a_reshaped_tensor(i) > 0) ? 1.0f : 0.0f; - float select_2412 = (b_reshaped_tensor(i) == 4563) ? 1.0f : a_greater; - float select_2415 = (b_reshaped_tensor(i) == 10831) ? 1.0f : select_2412; - float sub_out = 1.0f - select_2415; - out_x(i, 0) = 0.0f; - out_y(i, 0) = sub_out; - out_w(i, 0) = select_2415; - out_w(i, 1) = 1.0f; - } - }; - Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit, work); - - Eigen::TensorMap> map_output_x( + Eigen::TensorMap> out_x( output_x->flat().data(), output_x->dim_size(0), output_x->dim_size(1) ); - Eigen::TensorMap> map_output_y( + Eigen::TensorMap> out_y( output_y->flat().data(), output_y->dim_size(0), output_y->dim_size(1) ); - Eigen::TensorMap> map_output_w( + Eigen::TensorMap> out_w( output_w->flat().data(), output_w->dim_size(0), output_w->dim_size(1) ); + + auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); + const int64 cost_per_unit = std::max(N / worker_threads->num_threads, int64(10)); + + auto work = [&](int64 start, int64 end) { + for (int64 i = start; i < end; i++) { + // Greater(bool)+Cast.2406(float) --> 1.0f / 0.0f + float a_greater = (a_reshaped_tensor(i, 0) > greater_val) ? 1.0f : 0.0f; + float select_2412 = (b_reshaped_tensor(i, 0) == equal1_val) ? 1.0f : a_greater; // Fill.2409-->1.0f + float select_2415 = (b_reshaped_tensor(i, 0) == equal2_val) ? 1.0f : select_2412; // Fill.2409-->1.0f + out_x(i, 0) = a_reshaped_tensor(i, 0); // Reshape.2401 + out_y(i, 0) = select_2415; + out_w(i, 0) = select_2415; // Mul.2419 硬编码 1.0f * input + out_w(i, 1) = 1.0f; // select_2427被消除,直接使用Fill.2422-->1.0f + } + }; + Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit, work); } }; diff --git a/tensorflow/core/ops/embedding_fused_ops.cc b/tensorflow/core/ops/embedding_fused_ops.cc index 27a667e59..aff7e0cfd 100644 --- a/tensorflow/core/ops/embedding_fused_ops.cc +++ b/tensorflow/core/ops/embedding_fused_ops.cc @@ -87,6 +87,10 @@ REGISTER_OP("KPFusedSparseSelect") .Input("input_a: int32") .Input("input_b: int32") .Input("input_c: int32") + .Input("greater: int32") + .Input("equal1: int32") + .Input("equal2: int32") + .Input("equal3: int32") .Output("output_x: float") .Output("output_y: float") .Output("output_w: float") diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py index c5041495c..6bfb042da 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py @@ -10,35 +10,36 @@ from utils.utils import benchmark_op np.random.seed(140) -def ori_sparse_select_graph(input_a, input_b, input_c): +def ori_sparse_select_graph(input_a, input_b, input_c, greater, equal1, equal2, equal3): a = tf.reshape(input_a, [-1, 1]) b = tf.reshape(input_b, [-1, 1]) c = tf.reshape(input_c, [-1, 1]) + output_x = a - greater_a = tf.greater(a, 0) + greater_a = tf.greater(a, greater) shape_reshape_a1 = tf.shape(a) shape_reshape_a2 = tf.shape(a) fill_a1 = tf.fill(shape_reshape_a1, tf.constant(1, dtype=tf.float32)) realdiv = tf.realdiv(fill_a1, tf.constant(1, dtype=tf.float32)) - output_x = tf.fill(shape_reshape_a2, tf.constant(0, dtype=tf.float32)) # 全部填充为0,第一个输出 cast_a = tf.cast(greater_a, tf.float32) shape_a = tf.shape(cast_a) fill_a = tf.fill(shape_a, tf.constant(1, dtype=tf.float32)) - equal_4563 = tf.equal(b, 4563) - equal_10831 = tf.equal(b, 10831) - equal_3 = tf.equal(c, 3) + equal_4563 = tf.equal(b, equal1) + equal_10831 = tf.equal(b, equal2) + equal_3 = tf.equal(c, equal3) select_1 = tf.where(equal_4563, fill_a, cast_a) select_2 = tf.where(equal_10831, fill_a, select_1) - output_y = tf.subtract(tf.constant(1, dtype=tf.float32), select_2) # 1 - select_2 ? + output_y = select_2 # select_2 mul = tf.multiply(tf.constant(1, dtype=tf.float32), select_2) # Select.2415 --> Mul.2419 - select_3 = tf.where(equal_3, realdiv, fill_a1) # realdiv和fill_a1 一样的? + select_3 = tf.where(equal_3, realdiv, fill_a1) output_z = tf.concat([mul, select_3], axis=-1) return output_x, output_y, output_z -def opt_sparse_select_graph(input_a, input_b, input_c): +def opt_sparse_select_graph(input_a, input_b, input_c, greater, equal1, equal2, equal3): output_x, output_y, output_z = gen_embedding_fused_ops.KPFusedSparseSelect( - input_a=input_a, input_b=input_b, input_c=input_c + input_a=input_a, input_b=input_b, input_c=input_c, greater=greater, + equal1=equal1, equal2=equal2, equal3=equal3 ) return output_x, output_y, output_z @@ -64,18 +65,29 @@ class TestKPFusedSparseSelect(unittest.TestCase): input0 = tf.compat.v1.placeholder(tf.int32, name="input_a") input1 = tf.compat.v1.placeholder(tf.int32, name="input_b") input2 = tf.compat.v1.placeholder(tf.int32, name="input_c") + greater = tf.compat.v1.placeholder(tf.int32, shape=(), name="greater") + equal1 = tf.compat.v1.placeholder(tf.int32, shape=(), name="equal1") + equal2 = tf.compat.v1.placeholder(tf.int32, shape=(), name="equal2") + equal3 = tf.compat.v1.placeholder(tf.int32, shape=(), name="equal3") """Initialize test data""" feed = { input0: np.random.randint(0, 100, size=a_shape).astype(np.int32), input1: np.random.randint(0, 100, size=b_shape).astype(np.int32), input2: np.random.randint(0, 100, size=c_shape).astype(np.int32), + greater: np.array(0, dtype=np.int32), + equal1: np.array(4563, dtype=np.int32), + equal2: np.array(10831, dtype=np.int32), + equal3: np.array(3, dtype=np.int32), } with tf.name_scope("ori"): - out_ori1, out_ori2, out_ori3 = ori_sparse_select_graph(input0, input1, input2) + out_ori1, out_ori2, out_ori3 = ori_sparse_select_graph( + input0, input1, input2, greater, equal1, equal2, equal3 + ) with tf.name_scope("opt"): - out_opt1, out_opt2, out_opt3 = opt_sparse_select_graph(input0, input1, input2) - + out_opt1, out_opt2, out_opt3 = opt_sparse_select_graph( + input0, input1, input2, greater, equal1, equal2, equal3 + ) with tf.compat.v1.Session(config=self.config) as sess: out_ori_val1, out_ori_val2, out_ori_val3 = sess.run( [out_ori1, out_ori2, out_ori3], @@ -84,7 +96,7 @@ class TestKPFusedSparseSelect(unittest.TestCase): run_metadata=self.run_metadata_ori ) out_opt_val1, out_opt_val2, out_opt_val3 = sess.run( - [out_opt1,out_opt2, out_opt3], + [out_opt1, out_opt2, out_opt3], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt @@ -122,7 +134,7 @@ class TestKPFusedSparseSelect(unittest.TestCase): self.run_metadata_opt, op_name="KPFusedSparseSelect", start_op="ori/Reshape", - end_op="ori/Sub", + end_op="ori/concat", num_runs=num_runs, tag="-----TF_origin-----" ) -- Gitee