From 74e1b857d3777a4ded8e7d6565f10a81f1c2c4af Mon Sep 17 00:00:00 2001
From: rayshine <1324789704@qq.com>
Date: Tue, 26 Aug 2025 22:09:03 +0800
Subject: [PATCH] 1. Fused-operator optimization: multithreaded ActionIdGather,
 SIMD-parallel Gather, multithreaded Select. 2. Update test scripts and the
 gather computation logic.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../embedding_fused_action_id_gather.cc       |  71 ++--
 .../core/kernels/embedding_fused_gather.cc    |  32 +-
 .../core/kernels/embedding_fused_padding.cc   |  27 +-
 .../embedding_fused_sparse_dynamic_stitch.cc  |  11 +-
 .../kernels/embedding_fused_sparse_reshape.cc |   2 +-
 .../embedding_fused_sparse_segment_reduce.cc  |   8 +-
 .../kernels/embedding_fused_sparse_select.cc  |  66 ++--
 .../core/profiler/lib/profiler_session.cc     |   2 +-
 .../fused_embedding_action_id_gather_test.py  |  43 ++-
 .../fused_embedding_gather_test.py            | 206 +++++++----
 .../fused_embedding_padding_test.py           | 160 ++++++---
 ...ed_embedding_sparse_dynamic_stitch_test.py | 182 ++++++----
 .../fused_embedding_sparse_reshape_test.py    | 212 ++++++-----
 ...ed_embedding_sparse_segment_reduce_test.py | 329 ++++++++++++------
 .../fused_embedding_sparse_select_test.py     | 106 ++++--
 .../embedding_fused_test/utils/utils.py       |  88 ++++-
 16 files changed, 983 insertions(+), 562 deletions(-)

diff --git a/tensorflow/core/kernels/embedding_fused_action_id_gather.cc b/tensorflow/core/kernels/embedding_fused_action_id_gather.cc
index b324f35f0..c7df8520c 100644
--- a/tensorflow/core/kernels/embedding_fused_action_id_gather.cc
+++ b/tensorflow/core/kernels/embedding_fused_action_id_gather.cc
@@ -13,23 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <cstring>
-#include <vector>
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/op_kernel.h"
-namespace tensorflow {
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
 
+namespace tensorflow {
+
 template <typename Tindices>
-static void GatherV2Impl(OpKernelContext* context,
-                         const float* params_data,
-                         const TensorShape& params_shape,
-                         const Tindices* indices_data,
-                         const TensorShape& indices_shape,
-                         int axis, Tensor* temp) {
+static void GatherV2Impl(OpKernelContext* context, const float* params_data, const TensorShape& params_shape,
+                         const Tindices* indices_data, const TensorShape& indices_shape, int axis, Tensor* temp) {
   TensorShape temp_shape;
   const int P0 = params_shape.dim_size(0);
   int P1 = 1;
@@ -41,8 +33,7 @@ static void GatherV2Impl(OpKernelContext* context,
     temp_shape.AddDim(params_shape.dim_size(d));
     P1 *= params_shape.dim_size(d);
   }
-  OP_REQUIRES_OK(context,
-                 context->allocate_temp(DT_FLOAT, temp_shape, temp));
+  OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, temp_shape, temp));
   VLOG(1) << "temp shape: " << temp->shape().DebugString();
 
   const int num_indices = indices_shape.num_elements();
@@ -52,17 +43,18 @@ static void GatherV2Impl(OpKernelContext* context,
   const int slice_size = P1;
   for (int i = 0; i < num_indices; ++i) {
     Tindices idx = indices_data[i];
-    OP_REQUIRES(context, (idx < 0 || idx >= P0), errors::InvalidArgument("GatherV2 axis=0: index out of range"));
-    std::memcpy(temp_data + i * slice_size,
-                params_data + idx * slice_size,
-                sizeof(float) * slice_size);
+    OP_REQUIRES(context, (idx >= 0 && idx < P0), errors::InvalidArgument("GatherV2 axis=0: index out of range"));
+    std::memcpy(
+        temp_data + i * slice_size, params_data + idx * slice_size, sizeof(float) * slice_size
+    );
   }
   VLOG(1) << "temp value : " << temp->DebugString(100);
 }
 
+
 template <typename Tindices>
 class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
- public:
+public:
   explicit KPFusedEmbeddingActionIdGatherOp(OpKernelConstruction* context) : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
@@ -81,13 +73,11 @@ class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
     OP_REQUIRES(context, pack_dim.NumElements() == 1,
                 errors::InvalidArgument("pack_dim NumElements must = 1"));
     Tensor temp;
-    GatherV2Impl<Tindices>(context, params.flat<float>().data(), params.shape(),
-                 indices1.flat<Tindices>().data(), indices1.shape(),
-                 0, &temp);
+    GatherV2Impl<Tindices>(context, params.flat<float>().data(), params.shape(), indices1.flat<Tindices>().data(),
+                 indices1.shape(), 0, &temp);
     Tensor temp1;
-    GatherV2Impl<Tindices>(context, temp.flat<float>().data(), temp.shape(),
-                 indices2.flat<Tindices>().data(), indices2.shape(),
-                 0, &temp1);
+    GatherV2Impl<Tindices>(context, temp.flat<float>().data(), temp.shape(), indices2.flat<Tindices>().data(),
+                 indices2.shape(), 0, &temp1);
     int pack_size = pack_dim.scalar<int32>()();
     VLOG(1) << "pack_size value: " << pack_size;
     int a_reshaped_cols = temp1.NumElements() / pack_size;
@@ -96,16 +86,25 @@ class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
     Tensor* output;
     int output_cols = a_reshaped_cols + 1680;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(0, TensorShape({pack_size, output_cols}), &output));
+        context->allocate_output(0, TensorShape({pack_size, output_cols}), &output));
     VLOG(1) << "output shape: " << output->shape().DebugString();
-    auto output_matrix = output->matrix<float>();
-    output_matrix.slice(
-        Eigen::array<Eigen::DenseIndex, 2>{0, 0},
-        Eigen::array<Eigen::DenseIndex, 2>{pack_size, a_reshaped_cols}) = a_reshaped;
-
-    output_matrix.slice(
-        Eigen::array<Eigen::DenseIndex, 2>{0, a_reshaped_cols},
-        Eigen::array<Eigen::DenseIndex, 2>{pack_size, 1680}).setZero();
+    auto a_reshaped_data = a_reshaped.data();
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit = a_reshaped_cols + 1680;
+    auto work = [&](int64 start_row, int64 end_row) {
+      float* base = output->matrix<float>().data();
+      for (int64 row = start_row; row < end_row; ++row) {
+        float* dst_row = base + row * (a_reshaped_cols + 1680);
+        std::memcpy(
+            dst_row, a_reshaped_data + row * a_reshaped_cols, sizeof(float) * a_reshaped_cols
+        );
+        std::memset(
+            dst_row + a_reshaped_cols, 0, sizeof(float) * 1680
+        );
+      }
+    };
+    Shard(worker_threads->num_threads, worker_threads->workers, pack_size,
+          cost_per_unit, work);
   }
 };
diff --git a/tensorflow/core/kernels/embedding_fused_gather.cc b/tensorflow/core/kernels/embedding_fused_gather.cc
index c909307e2..404e4cc48 100644
--- a/tensorflow/core/kernels/embedding_fused_gather.cc
+++ b/tensorflow/core/kernels/embedding_fused_gather.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -22,7 +20,7 @@ limitations under the License.
 using namespace tensorflow;
 
 class KPFusedGather : public OpKernel {
- public:
+public:
   explicit KPFusedGather(OpKernelConstruction* context) : OpKernel(context) {
   }
 
   void Compute(OpKernelContext* context) override {
@@ -66,33 +64,27 @@ class KPFusedGather : public OpKernel {
 
     OP_REQUIRES_OK(context,
                    context->allocate_output(
-                                 0, TensorShape({unique_values.size()}), &out_shape));
+                       0, TensorShape({unique_values.size()}), &out_shape));
     std::memcpy(out_shape->data(), unique_values.data(), unique_values.size() * sizeof(int64_t));
     OP_REQUIRES_OK(context,
                    context->allocate_output(
-                                 1, TensorShape({static_cast<int64_t>(indices.size())}), &out_indices));
+                       1, TensorShape({static_cast<int64_t>(indices.size())}), &out_indices));
     std::memcpy(out_indices->data(), indices.data(), indices.size() * sizeof(int32_t));
 
     OP_REQUIRES(context, data.dim_size(1) * unique_values.size() % 12 == 0,
                 errors::Internal("cannot reshape to [-1, 12]"));
-
-    std::vector<float> gather1_result;
-    for (auto &indice : unique_values) {
-      for (int64_t i = 0; i < data.dim_size(1); ++i) {
-        gather1_result.push_back(data_mat(indice, i));
-      }
-    }
     OP_REQUIRES_OK(context,
                    context->allocate_output(
-                                 2, TensorShape({unique_values.size(), 12}), &out_data));
+                       2, TensorShape({unique_values.size(), 12}), &out_data));
     auto output_data = out_data->matrix<float>();
-    int cur_row = 0;
-    for (int indice = 0; indice < unique_values.size(); ++indice) {
-      for (int i = 0; i < 12; ++i) {
-        output_data(cur_row, i) = gather1_result[12 * indice + i];
-      }
-      cur_row++;
+
+    int64_t cols = data.dim_size(1);
+    for (int64_t cur_row = 0; cur_row < unique_values.size(); ++cur_row) {
+      int64_t idx = unique_values[cur_row];
+      const float* src = data_mat.data() + idx * cols;
+      float* dst = output_data.data() + cur_row * cols;
+      std::memcpy(dst, src, cols * sizeof(float));
     }
   }
 };
diff --git a/tensorflow/core/kernels/embedding_fused_padding.cc b/tensorflow/core/kernels/embedding_fused_padding.cc
index e36fbf7fa..98351004d 100644
--- a/tensorflow/core/kernels/embedding_fused_padding.cc
+++ b/tensorflow/core/kernels/embedding_fused_padding.cc
@@ -13,21 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <string>
 
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/op_kernel.h"
+
 namespace tensorflow {
 
 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;
 
 class KPFusedEmbeddingPaddingOp : public OpKernel {
- public:
+public:
   explicit KPFusedEmbeddingPaddingOp(OpKernelConstruction* context) : OpKernel(context) {
     fast_ = (type_string() == "KPFusedEmbeddingPaddingFast");
   }
 
   void Compute(OpKernelContext* context) override {
@@ -67,32 +67,29 @@ class KPFusedEmbeddingPaddingOp : public OpKernel {
     int output_rows = padding_rows + input.dim_size(0);
     int output_cols = input.dim_size(1);
     OP_REQUIRES(
-        context,
-        output_rows * output_cols % reshape_cols == 0,
-        errors::InvalidArgument("padding cannot reshape to [-1, ", reshape_cols, "]")
+        context, output_rows * output_cols % reshape_cols == 0,
+        errors::InvalidArgument("padding cannot reshape to [-1, ", reshape_cols, "]")
     );
     int reshape_rows = output_rows * output_cols / reshape_cols;
 
     if (fast_) {
-      OP_REQUIRES_OK(context,
-                     context->allocate_output(1, TensorShape({}),
-                                              &output1));
+      OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), &output1));
       output1->scalar<int32>()() = reshape_rows;
       return;
     }
 
     OP_REQUIRES_OK(context,
                    context->allocate_temp(DT_FLOAT, TensorShape({padding_rows + input_rows_value, output_cols}),
-                                          &padding));
+                       &padding));
     auto input_matrix = input.matrix<float>();
     auto padding_matrix = padding.matrix<float>();
     padding_matrix.slice(
-        Eigen::array<Eigen::DenseIndex, 2>{0, 0},
-        Eigen::array<Eigen::DenseIndex, 2>{input_rows_value, output_cols}) = input_matrix;
+            Eigen::array<Eigen::DenseIndex, 2>{0, 0},
+            Eigen::array<Eigen::DenseIndex, 2>{input_rows_value, output_cols}) = input_matrix;
     padding_matrix.slice(
-        Eigen::array<Eigen::DenseIndex, 2>{input_rows_value, 0},
-        Eigen::array<Eigen::DenseIndex, 2>{padding_rows, output_cols}).setZero();
+            Eigen::array<Eigen::DenseIndex, 2>{input_rows_value, 0},
+            Eigen::array<Eigen::DenseIndex, 2>{padding_rows, output_cols}).setZero();
 
     TensorShape reshaped_shape({reshape_rows, reshape_cols});
     OP_REQUIRES_OK(context,
@@ -100,7 +97,7 @@ class KPFusedEmbeddingPaddingOp : public OpKernel {
     output1->flat<float>() = padding.flat<float>();
   }
 
- private:
+private:
   bool fast_;
 };
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
index 9937a07e0..9e6674546 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
@@ -13,12 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <string>
-
 #include <vector>
 
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -26,12 +22,11 @@
 using namespace tensorflow;
 
 class KPFusedSparseDynamicStitchOp : public OpKernel {
- public:
+public:
   explicit KPFusedSparseDynamicStitchOp(OpKernelConstruction* context)
       : OpKernel(context) {}
 
   void Compute(OpKernelContext* context) override {
-    float* output;
     const Tensor& x = context->input(0);
     auto x_flat = x.flat<int64>();
     int64_t num_elems = x_flat.size();
@@ -54,7 +49,7 @@ class KPFusedSparseDynamicStitchOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({num_elems, output_stride}),
                                             &output_tensor));
-    output = (float*)output_tensor->tensor_data().data();
+    float* output = (float*)output_tensor->tensor_data().data();
     const size_t copy_size = output_stride * sizeof(float);
 
     auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
@@ -78,4 +73,4 @@ class KPFusedSparseDynamicStitchOp : public OpKernel {
 };
 
 REGISTER_KERNEL_BUILDER(Name("KPFusedSparseDynamicStitch").Device(DEVICE_CPU),
-                        KPFusedSparseDynamicStitchOp);
+                        KPFusedSparseDynamicStitchOp);
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc
index 43428b88c..66dfa4838 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc
@@ -192,4 +192,4 @@ class KPFusedSparseReshapeOp : public OpKernel {
 };
 
 REGISTER_KERNEL_BUILDER(Name("KPFusedSparseReshape").Device(DEVICE_CPU),
-                        KPFusedSparseReshapeOp);
+                        KPFusedSparseReshapeOp);
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc
index 19cc7394c..d303c60b2 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc
@@ -15,8 +15,6 @@ limitations under the License.
 
 #include <vector>
 
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -25,7 +23,7 @@ using namespace tensorflow;
 
 template <typename type>
 class KPFusedSparseSegmentReduceOp : public OpKernel {
- public:
+public:
   explicit KPFusedSparseSegmentReduceOp(OpKernelConstruction* context)
       : OpKernel(context) {
     int combiner_mode;
@@ -138,7 +136,7 @@ class KPFusedSparseSegmentReduceOp : public OpKernel {
     }
   }
 
- private:
+private:
   bool is_mean_;
 };
 
@@ -149,4 +147,4 @@ class KPFusedSparseSegmentReduceOp : public OpKernel {
                             KPFusedSparseSegmentReduceOp<type>);
 REGISTER_KERNEL(int64)
 REGISTER_KERNEL(int32)
-#undef REGISTER_KERNEL
+#undef REGISTER_KERNEL
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_select.cc b/tensorflow/core/kernels/embedding_fused_sparse_select.cc
index 086092d51..953d65241 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_select.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_select.cc
@@ -16,22 +16,20 @@ limitations under the License.
 #include <string>
 #include <vector>
 
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/platform/logging.h"
 
 using namespace tensorflow;
 
 class KPFusedSparseSelect : public OpKernel {
- public:
+public:
   explicit KPFusedSparseSelect(OpKernelConstruction* context) : OpKernel(context) {
-
   }
 
   void Compute(OpKernelContext* context) override {
-
     const Tensor& input_a = context->input(0);
     const Tensor& input_b = context->input(1);
     const Tensor& input_c = context->input(2);
@@ -42,9 +40,9 @@ class KPFusedSparseSelect : public OpKernel {
     VLOG(1) << "input_a shape: " << input_a.shape().DebugString();
     VLOG(1) << "input_b shape: " << input_b.shape().DebugString();
     VLOG(1) << "input_c shape: " << input_c.shape().DebugString();
-    OP_REQUIRES(context,input_a.NumElements() == input_b.NumElements(),
+    OP_REQUIRES(context, input_a.NumElements() == input_b.NumElements(),
                 errors::InvalidArgument("Input num elements must match"));
-    OP_REQUIRES(context,input_a.NumElements() == input_c.NumElements(),
+    OP_REQUIRES(context, input_a.NumElements() == input_c.NumElements(),
                 errors::InvalidArgument("Input num elements must match"));
 
     auto N = input_a.NumElements();
@@ -52,60 +50,54 @@ class KPFusedSparseSelect : public OpKernel {
     Eigen::TensorMap<Eigen::Tensor<float, 2>> b_reshaped_tensor(b_flat.data(), N, 1);
     Eigen::TensorMap<Eigen::Tensor<float, 2>> c_reshaped_tensor(c_flat.data(), N, 1);
 
-    auto a_greater = (a_reshaped_tensor > 0);
-    auto a_greater_casted = a_greater.cast<float>();
-
-    auto b_equal_node0 = (b_reshaped_tensor == 4563);
-    auto b_equal_node1 = (b_reshaped_tensor == 10831);
-
-    Eigen::Tensor<float, 2> tensor_ones(N, 1);
-    tensor_ones.setConstant(1.0f);
-
-    Eigen::Tensor<float, 2> tensor_zeros(N, 1);
-    tensor_zeros.setConstant(0.0f);
-
-    auto select_2412 = b_equal_node0.select(tensor_ones, a_greater_casted);
-    auto select_2415 = b_equal_node1.select(tensor_ones, select_2412);
-
-    auto sub_out = 1.0f - select_2415;
-    auto concat_out = select_2415.concatenate(tensor_ones, 1);
-
     Tensor* output_x = nullptr;
     Tensor* output_y = nullptr;
     Tensor* output_w = nullptr;
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(0,TensorShape({N, 1}), &output_x));
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(1,TensorShape({N, 1}), &output_y));
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(2,TensorShape({N, 2}), &output_w));
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({N, 1}), &output_x));
+    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({N, 1}), &output_y));
+    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({N, 2}), &output_w));
+
+    auto out_x = output_x->matrix<float>();  // [N,1]
+    auto out_y = output_y->matrix<float>();  // [N,1]
+    auto out_w = output_w->matrix<float>();  // [N,2]
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit = 10;
+
+    auto work = [&](int64 start, int64 end) {
+      for (int64 i = start; i < end; i++) {
+        float a_greater = (a_reshaped_tensor(i) > 0) ? 1.0f : 0.0f;
+        float select_2412 = (b_reshaped_tensor(i) == 4563) ? 1.0f : a_greater;
+        float select_2415 = (b_reshaped_tensor(i) == 10831) ? 1.0f : select_2412;
+        float sub_out = 1.0f - select_2415;
+        out_x(i, 0) = 0.0f;
+        out_y(i, 0) = sub_out;
+        out_w(i, 0) = select_2415;
+        out_w(i, 1) = 1.0f;
+      }
+    };
+    Shard(worker_threads->num_threads, worker_threads->workers, N, cost_per_unit, work);
 
     Eigen::TensorMap<Eigen::Tensor<float, 2>> map_output_x(
         output_x->flat<float>().data(), output_x->dim_size(0), output_x->dim_size(1)
     );
-    map_output_x = tensor_zeros;
 
     Eigen::TensorMap<Eigen::Tensor<float, 2>> map_output_y(
         output_y->flat<float>().data(), output_y->dim_size(0), output_y->dim_size(1)
     );
-    map_output_y = sub_out;
 
     Eigen::TensorMap<Eigen::Tensor<float, 2>> map_output_w(
         output_w->flat<float>().data(), output_w->dim_size(0), output_w->dim_size(1)
    );
-    map_output_w = concat_out;
-
   }
-
 };
 
 REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSelect").Device(DEVICE_CPU),
-                        KPFusedSparseSelect);
+                        KPFusedSparseSelect);
\ No newline at end of file
diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc
index 982a0f933..e2d4111c7 100644
--- a/tensorflow/core/profiler/lib/profiler_session.cc
+++ b/tensorflow/core/profiler/lib/profiler_session.cc
@@ -156,7 +156,7 @@ ProfilerSession::ProfilerSession(const profiler::ProfilerOptions& options)
     return;
   }
 
-  LOG(INFO) << "Profiler session started.";
+  VLOG(1) << "Profiler session started.";
 
 #if !defined(IS_MOBILE_PLATFORM)
   CreateProfilers(options, &profilers_);
diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_action_id_gather_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_action_id_gather_test.py
index d20628b0d..1ae2c1f54 100644
--- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_action_id_gather_test.py
+++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_action_id_gather_test.py
@@ -1,9 +1,13 @@
-import tensorflow as tf
-import numpy as np
+# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
import unittest +import numpy as np +import tensorflow as tf + from tensorflow.python.ops import gen_embedding_fused_ops -from utils.utils import perf_run, generate_timeline, wrapper_sess +from utils.utils import benchmark_op + +np.random.seed(140) def ori_fused_embedding_action_id_gather_graph(input0, input1, input2, input3): @@ -33,7 +37,7 @@ class TestFusedEmbeddingActionIdGather(unittest.TestCase): """Initialize config""" cls.config = tf.compat.v1.ConfigProto() cls.config.intra_op_parallelism_threads = 16 - cls.config.inter_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) cls.run_metadata_ori = tf.compat.v1.RunMetadata() @@ -68,23 +72,34 @@ class TestFusedEmbeddingActionIdGather(unittest.TestCase): # Create tf session with tf.compat.v1.Session(config=self.config) as sess: # functest - out_ori_val = sess.run([out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori) - out_opt_val = sess.run([out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt) + out_ori_val = sess.run( + [out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori + ) + out_opt_val = sess.run( + [out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt + ) np.testing.assert_array_equal( out_ori_val, out_opt_val, err_msg="result mismatch" ) + + benchmark_op( + sess, + feed, + [out_ori], + [out_opt], + self.run_options, + self.run_metadata_ori, + self.run_metadata_opt, + op_name="KPFusedEmbeddingActionIdGather", + start_op="ori/stack_1", + end_op="ori/concat", + num_runs=10000, + tag="----------TF_origin-----------" + ) - generate_timeline(self.run_metadata_ori.step_stats, f"{self._testMethodName}_ori") - generate_timeline(self.run_metadata_opt.step_stats, f"{self._testMethodName}_opt") - - # perftest - perf_run(wrapper_sess(sess, [out_ori], feed_dict=feed), - wrapper_sess(sess, [out_opt], feed_dict=feed), - "KPFusedEmbeddingActionIdGather") - if __name__ == "__main__": tf.compat.v1.disable_eager_execution() diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py index 1c73adc18..5154ffa82 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py @@ -1,87 +1,147 @@ +# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. 
+import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops +from utils.utils import benchmark_op + + +def ori_fused_embedding_gather_graph(data, slice_input, begin): + slice_out = tf.strided_slice( + slice_input, + begin=begin, + end=[tf.shape(slice_input)[0], begin[1] + 2], + strides=[1, 1], + begin_mask=1, + end_mask=1, + shrink_axis_mask=2 + ) + + slice_out, slice_out_indices = tf.unique(slice_out) + output_shape = slice_out + slice_out = tf.reshape(slice_out, [-1]) + slice_out, slice_out_indices2 = tf.unique(slice_out) + + gather1_result = tf.gather(data, slice_out) + gather1_result = tf.reshape(gather1_result, [-1, 12]) + + gather2_result = tf.gather(gather1_result, slice_out_indices2) + return output_shape, slice_out_indices, gather2_result + + +def opt_fused_embedding_gather_graph(data, slice_input, begin): + custom_out1, custom_out2, custom_out3 = gen_embedding_fused_ops.KPFusedGather( + data=data, + slice_input=slice_input, + begin=begin + ) + return custom_out1, custom_out2, custom_out3 + class TestFusedGather(unittest.TestCase): @classmethod def setUpClass(cls): - """Initialize test data and custom op""" - # Load custom op - cls.custom_op = gen_embedding_fused_ops - - # Base test data - cls.base_data = np.linspace(0, 11, num=240, endpoint=False, dtype=np.float32).reshape(20, 12) - cls.base_slice_input = np.array([[0, 0], [0, 1], [1, 2]], dtype=np.int64) - cls.base_begin = [0, 1] - cls.base_end = [0, 2] - cls.base_strides = [1, 1] - # Create tf session - cls.sess = tf.compat.v1.Session() + """Initialize""" + cls.config = tf.compat.v1.ConfigProto() + cls.config.intra_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 + + cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) + cls.run_metadata_ori = tf.compat.v1.RunMetadata() + cls.run_metadata_opt = tf.compat.v1.RunMetadata() @classmethod def tearDownClass(cls): - cls.sess.close() - - def test_custom(self): - # execute custom op - custom_out1, custom_out2, custom_out3= self.custom_op.KPFusedGather( - data=self.base_data, - slice_input=self.base_slice_input, - begin=self.base_begin, - ) - - # tf native implementation - tf_out1, tf_out2, tf_out3 = self._tf_reference_impl( - self.base_data, - self.base_slice_input, - self.base_begin, - ) - - custom_out_val1, custom_out_val2, custom_out_val3 = self.sess.run([custom_out1, custom_out2, custom_out3]) - tf_out_val1, tf_out_val2, tf_out_val3 = self.sess.run([tf_out1, tf_out2, tf_out3]) - - np.testing.assert_array_equal( - custom_out_val1, - tf_out_val1, - err_msg="Segment count mismatch" - ) - - np.testing.assert_array_equal( - custom_out_val2, - tf_out_val2, - err_msg="Segment count mismatch" - ) - - np.testing.assert_allclose( - custom_out_val3, - tf_out_val3, - rtol=1e-6, - err_msg="Output values mismatch" - ) - - def _tf_reference_impl(self, data, slice_input, begin): - slice_out = tf.strided_slice( - slice_input, - begin = begin, - end = [tf.shape(slice_input)[0], begin[1] + 2], - strides = [1, 1], - begin_mask = 1, - end_mask = 1, - shrink_axis_mask = 2 - ) - - slice_out, slice_out_indices = tf.unique(slice_out) - output_shape = slice_out - slice_out = tf.reshape(slice_out, [-1]) - slice_out, _ = tf.unique(slice_out) - - gather1_result = tf.gather(data, slice_out) - gather1_result = tf.reshape(gather1_result, [-1, 12]) - - gather2_result = tf.gather(gather1_result, slice_out) - return output_shape, slice_out_indices, gather2_result + return + + def 
_run_kp_gather_test(self, data_shape, slice_shape, base_data, base_slice_input, base_begin, num_runs):
+        with tf.Graph().as_default():
+            data = tf.compat.v1.placeholder(tf.float32, shape=data_shape, name="data")
+            slice_input = tf.compat.v1.placeholder(tf.int64, shape=slice_shape, name="slice_input")
+            begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin")
+
+            feed = {
+                data: base_data,
+                slice_input: base_slice_input,
+                begin: base_begin
+            }
+
+            # original graph
+            with tf.name_scope("ori"):
+                out_ori1, out_ori2, out_ori3 = ori_fused_embedding_gather_graph(data, slice_input, begin)
+
+            # optimized graph
+            with tf.name_scope("opt"):
+                out_opt1, out_opt2, out_opt3 = opt_fused_embedding_gather_graph(data, slice_input, begin)
+
+            with tf.compat.v1.Session(config=self.config) as sess:
+                # run ori
+                out_ori_val1, out_ori_val2, out_ori_val3 = sess.run(
+                    [out_ori1, out_ori2, out_ori3],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_ori
+                )
+                # run opt
+                out_opt_val1, out_opt_val2, out_opt_val3 = sess.run(
+                    [out_opt1, out_opt2, out_opt3],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_opt
+                )
+
+                # functional test
+                np.testing.assert_array_equal(
+                    out_ori_val1,
+                    out_opt_val1,
+                    err_msg="Segment count mismatch"
+                )
+
+                np.testing.assert_array_equal(
+                    out_ori_val2,
+                    out_opt_val2,
+                    err_msg="Segment count mismatch"
+                )
+                np.testing.assert_allclose(
+                    out_opt_val3,
+                    out_ori_val3,
+                    rtol=1e-6,
+                    err_msg="Output values mismatch"
+                )
+
+                # benchmark
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori1, out_ori2, out_ori3],
+                    [out_opt1, out_opt2, out_opt3],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedGather",
+                    start_op="ori/strided_slice_1",
+                    end_op="ori/GatherV2_1",
+                    num_runs=num_runs,
+                    tag="--TF_origin--"
+                )
+
+
+    def test_kp_embedding_gather(self):
+        base_data = np.linspace(0, 11, num=240, endpoint=False, dtype=np.float32).reshape(20, 12)
+        base_slice_input = np.array([[0, 0], [0, 1], [1, 2]], dtype=np.int64)
+        base_begin = [0, 1]
+        self._run_kp_gather_test((20, 12), (3, 2), base_data, base_slice_input, base_begin, num_runs=100)
+
+
+    def test_kp_gather_262145(self):
+        base_data = np.linspace(0, 11111, num=262145*12, dtype=np.float32).reshape(262145, 12)
+        base_slice_input = np.random.randint(0, 262145, size=(46, 2), dtype=np.int64)
+        base_begin = [0, 0]
+        self._run_kp_gather_test((262145, 12), (46, 2), base_data, base_slice_input, base_begin, num_runs=100)
 
 if __name__ == "__main__":
     tf.compat.v1.disable_eager_execution()
diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_padding_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_padding_test.py
index b9950e51c..dc9bd968e 100644
--- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_padding_test.py
+++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_padding_test.py
@@ -1,31 +1,35 @@
+# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops -from utils.utils import perf_run, generate_timeline, wrapper_sess +from utils.utils import benchmark_op np.random.seed(140) -def opt_fused_embedding_padding_fast_graph(input0, input1, input2, input3): - # execute custom op - _, custom_out = gen_embedding_fused_ops.kp_fused_embedding_padding_fast(input0, input1, input2, input3) - return custom_out - -def opt_fused_embedding_padding_graph(input0, input1, input2, input3): +def opt_padding_fast_graph(input0, input1, input2, input3): + # execute custom op + _, custom_out = gen_embedding_fused_ops.kp_fused_embedding_padding_fast(input0, input1, input2, input3) + return custom_out + + +def opt_padding_graph(input0, input1, input2, input3): # execute custom op _, custom_out = gen_embedding_fused_ops.kp_fused_embedding_padding(input0, input1, input2, input3) return custom_out -def ori_fused_embedding_padding_fast_graph(input0, input1, input2, input3): + +def ori_padding_fast_graph(input0, input1, input2, input3): cast = tf.cast(input0, tf.int32) begin = tf.constant([0], dtype=tf.int32) end = tf.constant([1], dtype=tf.int32) strides = tf.constant([1], dtype=tf.int32) hash_rows = tf.strided_slice(cast, begin=begin, end=end, strides=strides, shrink_axis_mask=1) sub_out = hash_rows - input2 - const = tf.constant(10, dtype=tf.int32) + const = tf.constant(input1.shape[1], dtype=tf.int32) pack = tf.stack([sub_out, const], axis=0) fill = tf.fill(pack, tf.constant(0, dtype=tf.float32)) concat = tf.concat([input1, fill], 0) @@ -34,14 +38,15 @@ def ori_fused_embedding_padding_fast_graph(input0, input1, input2, input3): output = tf.strided_slice(shape_tensor, begin=begin, end=end, strides=strides, shrink_axis_mask=1) return output -def ori_fused_embedding_padding_graph(input0, input1, input2, input3): + +def ori_padding_graph(input0, input1, input2, input3): cast = tf.cast(input0, tf.int32) begin = tf.constant([0], dtype=tf.int32) end = tf.constant([1], dtype=tf.int32) strides = tf.constant([1], dtype=tf.int32) hash_rows = tf.strided_slice(cast, begin=begin, end=end, strides=strides, shrink_axis_mask=1) sub_out = hash_rows - input2 - const = tf.constant(10, dtype=tf.int32) + const = tf.constant(input1.shape[1], dtype=tf.int32) pack = tf.stack([sub_out, const], axis=0) fill = tf.fill(pack, tf.constant(0, dtype=tf.float32)) concat = tf.concat([input1, fill], 0) @@ -55,7 +60,7 @@ class TestFusedEmbeddingPadding(unittest.TestCase): """Initialize config""" cls.config = tf.compat.v1.ConfigProto() cls.config.intra_op_parallelism_threads = 16 - cls.config.inter_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) cls.run_metadata_ori = tf.compat.v1.RunMetadata() @@ -65,87 +70,126 @@ class TestFusedEmbeddingPadding(unittest.TestCase): def tearDownClass(cls): return - def test_func_kp_fused_embedding_padding(self): - # Create Graph + def _run_kp_padding_test(self, input1_shape, input3_shape, num_runs=500): with tf.Graph().as_default(): - input0 = tf.compat.v1.placeholder(tf.int64, shape=[2], name="input0") - input1 = tf.compat.v1.placeholder(tf.float32, shape=[None, 10], name="input1") - input2 = tf.compat.v1.placeholder(tf.int32, shape=[], name="input2") - input3 = tf.compat.v1.placeholder(tf.int32, shape=[2], name="input3") + input0 = tf.compat.v1.placeholder(tf.int64, shape=(2,), name="input0") + input1 = 
tf.compat.v1.placeholder(tf.float32, shape=input1_shape, name="input1")
+            input2 = tf.compat.v1.placeholder(tf.int32, shape=(), name="input2")
+            input3 = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="input3")
 
             """Initialize test data"""
             feed = {
-                input0: np.array([6, 10]).astype(np.int64),
-                input1: np.random.rand(4, 10).astype(np.float),
-                input2: 4,
-                input3: np.array([-1, 20]).astype(np.int32),
+                input0: np.array([6, input1_shape[1]]).astype(np.int64),
+                input1: np.random.rand(*input1_shape).astype(np.float32),
+                input2: input1_shape[0],
+                input3: np.array(input3_shape).astype(np.int32),
             }
 
             with tf.name_scope("ori"):
-                out_ori = ori_fused_embedding_padding_graph(input0, input1, input2, input3)
+                out_ori = ori_padding_graph(input0, input1, input2, input3)
             with tf.name_scope("opt"):
-                out_opt = opt_fused_embedding_padding_graph(input0, input1, input2, input3)
+                out_opt = opt_padding_graph(input0, input1, input2, input3)
 
             # Create tf session
             with tf.compat.v1.Session(config=self.config) as sess:
                 # functest
-                ori_result = sess.run([out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori)
-                opt_result = sess.run([out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt)
+                ori_result = sess.run(
+                    [out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori
+                )
+                opt_result = sess.run(
+                    [out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt
+                )
                 np.testing.assert_array_equal(
                     ori_result,
                     opt_result,
                     err_msg="result mismatch"
                 )
-
-                from tensorflow.python.client import timeline
-                tl_ori = timeline.Timeline(self.run_metadata_ori.step_stats)
-                tl_opt = timeline.Timeline(self.run_metadata_opt.step_stats)
-                ctf_ori = tl_ori.generate_chrome_trace_format()
-                ctf_opt = tl_opt.generate_chrome_trace_format()
-
-                with open("timeline_ori.json", "w") as f:
-                    f.write(ctf_ori)
-                with open("timeline_opt.json", "w") as f:
-                    f.write(ctf_opt)
 
-                # perftest
-                perf_run(wrapper_sess(sess, [out_ori], feed), wrapper_sess(sess, [out_opt], feed_dict=feed), "KPFusedEmbeddingPadding")
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori],
+                    [out_opt],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedEmbeddingPadding",
+                    start_op="ori/Cast",
+                    end_op="ori/Reshape",
+                    num_runs=num_runs,
+                    tag="-------TF_origin-------"
+                )
 
-    def test_func_kp_fused_embedding_padding_fast(self):
-        # Create Graph
+
+    def _run_kp_padding_fast_test(self, input1_shape, input3_shape, num_runs=500):
         with tf.Graph().as_default():
-            input0 = tf.compat.v1.placeholder(tf.int64, shape=[2], name="input0")
-            input1 = tf.compat.v1.placeholder(tf.float32, shape=[None, 10], name="input1")
-            input2 = tf.compat.v1.placeholder(tf.int32, shape=[], name="input2")
-            input3 = tf.compat.v1.placeholder(tf.int32, shape=[2], name="input3")
+            input0 = tf.compat.v1.placeholder(tf.int64, shape=(2,), name="input0")
+            input1 = tf.compat.v1.placeholder(tf.float32, shape=input1_shape, name="input1")
+            input2 = tf.compat.v1.placeholder(tf.int32, shape=(), name="input2")
+            input3 = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="input3")
 
             """Initialize test data"""
             feed = {
-                input0: np.array([6, 10]).astype(np.int64),
-                input1: np.random.rand(4, 10).astype(np.float),
-                input2: 4,
-                input3: np.array([-1, 20]).astype(np.int32),
+                input0: np.array([6, input1_shape[1]]).astype(np.int64),
+                input1: np.random.rand(*input1_shape).astype(np.float32),
+                input2: input1_shape[0],
+                input3: np.array(input3_shape).astype(np.int32),
             }
 
            with 
tf.name_scope("ori"): - out_ori = ori_fused_embedding_padding_fast_graph(input0, input1, input2, input3) + out_ori = ori_padding_fast_graph(input0, input1, input2, input3) with tf.name_scope("opt"): - out_opt = opt_fused_embedding_padding_fast_graph(input0, input1, input2, input3) + out_opt = opt_padding_fast_graph(input0, input1, input2, input3) # Create tf session with tf.compat.v1.Session(config=self.config) as sess: # functest - ori_result = sess.run([out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori) - opt_result = sess.run([out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt) + ori_result = sess.run( + [out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori + ) + opt_result = sess.run( + [out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt + ) np.testing.assert_array_equal( ori_result, opt_result, err_msg="result mismatch" ) + + benchmark_op( + sess, + feed, + [out_ori], + [out_opt], + self.run_options, + self.run_metadata_ori, + self.run_metadata_opt, + op_name="KPFusedEmbeddingPaddingFast", + start_op="ori/Cast", + end_op="ori/StridedSlice_1", + num_runs=num_runs, + tag="---------TF_origin---------" + ) + + + def test_kp_padding_shape10(self): + input1_shape = (4, 10) + input3_shape = (-1, 20) + self._run_kp_padding_test(input1_shape, input3_shape, num_runs=100) + + def test_kp_padding_shape12(self): + input1_shape = (1, 12) + input3_shape = (-1, 36) + self._run_kp_padding_test(input1_shape, input3_shape, num_runs=100) + + def test_kp_padding_fast_shape10(self): + input1_shape = (4, 10) + input3_shape = (-1, 20) + self._run_kp_padding_fast_test(input1_shape, input3_shape, num_runs=100) - generate_timeline(self.run_metadata_ori.step_stats, f"{self._testMethodName}_ori") - generate_timeline(self.run_metadata_opt.step_stats, f"{self._testMethodName}_opt") + def test_kp_padding_fast_shape12(self): + input1_shape = (1, 12) + input3_shape = (-1, 36) + self._run_kp_padding_fast_test(input1_shape, input3_shape, num_runs=100) - # perftest - perf_run(wrapper_sess(sess, [out_ori], feed), wrapper_sess(sess, [out_opt], feed_dict=feed), "KPFusedEmbeddingPaddingFast") if __name__ == "__main__": tf.compat.v1.disable_eager_execution() diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_dynamic_stitch_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_dynamic_stitch_test.py index 4de55241c..ab471db89 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_dynamic_stitch_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_dynamic_stitch_test.py @@ -1,31 +1,47 @@ -import os +# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. 
+import unittest
+
 import tensorflow as tf
 import numpy as np
-import unittest
 
 from tensorflow.python.ops import gen_embedding_fused_ops
+from utils.utils import benchmark_op
+
+np.random.seed(140)
+
 
-class TestSparseSegmentMeanSlice(unittest.TestCase):
+def ori_fused_sparse_dynamic_stitch_graph(x, emb_tables):
+    x_1 = tf.reshape(x, shape=[-1])  # flatten the input x into a 1-D vector x_1
+    group_ids = tf.math.floormod(x_1, 12)
+    group_ids = tf.cast(group_ids, dtype=np.int32)
+    chunk_indices = tf.math.floordiv(x_1, 12)
+    original_indices = tf.range(0, tf.size(x_1), 1)
+    a = tf.dynamic_partition(original_indices, group_ids, num_partitions=12)
+    b = tf.dynamic_partition(chunk_indices, group_ids, num_partitions=12)
+    c = [tf.gather(emb_tables[i], b[i]) for i in range(12)]
+    d = tf.dynamic_stitch(a, c)
+    return d
+
+
+def opt_fused_sparse_dynamic_stitch_graph(x, emb_tables):
+    output = gen_embedding_fused_ops.KPFusedSparseDynamicStitch(
+        x=x,
+        variables=emb_tables
+    )
+    return output
+
+
+class TestSparseDynamicStitch(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        """Initialize test data and custom op"""
-        # Load custom op
-        cls.custom_op = gen_embedding_fused_ops
+        """Initialize config"""
+        cls.config = tf.compat.v1.ConfigProto()
+        cls.config.intra_op_parallelism_threads = 16
+        cls.config.inter_op_parallelism_threads = 1
 
-        cls.variables = []
-        max_val = float('inf')
-        for i in range(12):
-            N_i = np.random.randint(1000000, 44739244)
-            max_val = min(N_i, max_val)
-            var = tf.Variable(
-                tf.random.normal([N_i, 10], dtype=tf.float32),  # shape: (N_i, 10)
-                name=f"embedding_table_{i}"
-            )
-            cls.variables.append(var)
-            print(f"Created variable {i}: shape={var.shape}")
-
-        x_np = np.random.randint(0, 12*max_val, size=(10000, 12))
-        cls.x = tf.constant(x_np, dtype=tf.int64)
+        cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+        cls.run_metadata_ori = tf.compat.v1.RunMetadata()
+        cls.run_metadata_opt = tf.compat.v1.RunMetadata()
 
         # Create tf session
         cls.sess = tf.compat.v1.Session()
@@ -36,11 +52,29 @@ class TestSparseSegmentMeanSlice(unittest.TestCase):
         cls.sess.close()
 
     def test_base(self):
-        x_first = self.sess.run(self.x)
-        var_first = self.sess.run(self.variables[0])
+        variables = []
+        max_val = float('inf')
+        for i in range(12):
+            N_i = np.random.randint(100000, 4473924)
+            max_val = min(N_i, max_val)
+            var = tf.Variable(
+                tf.random.normal([N_i, 10], dtype=tf.float32),  # shape: (N_i, 10)
+                name=f"embedding_{i}"
+            )
+            variables.append(var)
 
-        x_second = self.sess.run(self.x)
-        var_second = self.sess.run(self.variables[0])
+        x_np = np.random.randint(0, 12*max_val, size=(10000, 12))
+        x = tf.constant(x_np, dtype=tf.int64)
 
+        self.sess.run(tf.compat.v1.variables_initializer(variables))
+
+        x_first = self.sess.run(x)
+        var_first = self.sess.run(variables[0])
+
+        x_second = self.sess.run(x)
+        var_second = self.sess.run(variables[0])
+
         np.testing.assert_allclose(
             x_first,
             x_second,
@@ -55,42 +89,72 @@ class TestSparseSegmentMeanSlice(unittest.TestCase):
             err_msg="Input values mismatch"
         )
 
-        # execute custom op
-        custom_out = self.custom_op.KPFusedSparseDynamicStitch(x=self.x, variables=self.variables)
-
-        # tf native implementation
-        tf_out = self._tf_reference_impl(x=self.x, variables=self.variables)
-
-        custom_out_val = self.sess.run([custom_out])
-        tf_out_val = self.sess.run([tf_out])
-        print("custom_shape: ")
-        print(custom_out_val[0].shape)
-        print("tf_out shape: ")
-        print(tf_out_val[0].shape)
-        # Numerical comparison
-        np.testing.assert_allclose(
-            custom_out_val[0],
-            tf_out_val[0],
-            rtol=1e-6,
-            err_msg="Output values mismatch"
-        )
-
-    def _tf_reference_impl(self, x, variables):
-        x_1 = tf.reshape(x, shape=[-1])
-        group_ids = tf.math.floormod(x_1, 12)
-        group_ids = tf.cast(group_ids, dtype=np.int32)
-        chunk_indices = tf.math.floordiv(x_1, 12)
-
-        original_indices = tf.range(0,tf.size(x_1),1)
-
-        a = tf.dynamic_partition(original_indices, group_ids, num_partitions=12)
-        b = tf.dynamic_partition(chunk_indices, group_ids, num_partitions=12)
-
-        c = [tf.gather(variables[i], b[i]) for i in range(12)]
-
-        d = tf.dynamic_stitch(a, c)
+    def test_kp_sparse_dynamic_stitch(self):
+        # Create Graph
+        with tf.Graph().as_default():
+            num_tables = 12
+            emb_dim = 10
+            max_val = float('inf')
+            # one placeholder per table; the row count is generated randomly
+            tables = []
+            table_sizes = []
+            for i in range(num_tables):
+                N_i = np.random.randint(1000000, 44739244)
+                table_sizes.append(N_i)
+                max_val = min(N_i, max_val)
+                table_ph = tf.compat.v1.placeholder(
+                    tf.float32, shape=(N_i, emb_dim), name=f"embedding_table_{i}"
+                )
+                tables.append(table_ph)
+            # placeholder for the global indices
+            x_shape = (1000, num_tables)
+            input_x = tf.compat.v1.placeholder(tf.int64, shape=x_shape, name="input_x")
+            # build the feed data
+            feed = {}
+            rng = np.random.default_rng(12345)
+            # random embedding data for each table
+            for i in range(num_tables):
+                feed[tables[i]] = rng.standard_normal(size=(table_sizes[i], emb_dim)).astype(np.float32)
+            # index data (same logic as before: the range is 0 ~ num_tables * max_val - 1)
+            feed[input_x] = rng.integers(
+                low=0, high=num_tables * max_val, size=x_shape, dtype=np.int64
+            )
+            with tf.name_scope("ori"):
+                out_ori = ori_fused_sparse_dynamic_stitch_graph(input_x, tables)
+            with tf.name_scope("opt"):
+                out_opt = opt_fused_sparse_dynamic_stitch_graph(input_x, tables)
+
+            # Create tf session
+            with tf.compat.v1.Session(config=self.config) as sess:
+                # functest
+                out_ori_val = sess.run(
+                    [out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori
+                )
+                out_opt_val = sess.run(
+                    [out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt
+                )
+
+                np.testing.assert_array_equal(
+                    out_ori_val,
+                    out_opt_val,
+                    err_msg="result mismatch"
+                )
+
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori],
+                    [out_opt],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedSparseDynamicStitch",
+                    start_op="ori/Reshape",
+                    end_op="ori/DynamicStitch",
+                    num_runs=100,
+                    tag="--------TF_origin---------"
+                )
 
-        return d
 
 if __name__ == "__main__":
     tf.compat.v1.disable_eager_execution()
diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py
index 37d275315..0930d9331 100644
--- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py
+++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py
@@ -1,101 +1,151 @@
+# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops +from utils.utils import benchmark_op + + +def ori_sparse_reshape_graph(slice_input, begin, newshape): + slice67_out = tf.strided_slice( + slice_input, + begin=begin, + end=[0, 2], + strides=[1, 1], + begin_mask=1, + end_mask=1, + shrink_axis_mask=2 + ) + + slice67_out = tf.reshape(slice67_out, [-1, 1]) + shape_out = tf.shape(slice67_out) + slice57_out = tf.strided_slice( + shape_out, + begin=[0], + end=[1], + strides=[1], + shrink_axis_mask=1 + ) + + const2 = tf.constant(2) + input_shape = tf.stack([slice57_out, const2]) + input_shape = tf.cast(input_shape, tf.int64) + + range_out = tf.range(0, slice57_out, 1) + range_out = tf.reshape(range_out, [-1, 1]) + range_out_64 = tf.cast(range_out, dtype=tf.int64) + concat_out = tf.concat([range_out_64, slice67_out], axis=-1) + + values = np.arange(slice_input.shape[0], dtype=np.float32) + + sparse_tensor = tf.SparseTensor( + indices=concat_out, + values=values, + dense_shape=input_shape + ) + sparse_tensor_out = tf.sparse.reshape(sparse_tensor, newshape) + return sparse_tensor_out.indices, sparse_tensor_out.dense_shape, concat_out + + +def opt_sparse_reshape_graph(slice_input, begin, newshape): + custom_out1, custom_out2 = gen_embedding_fused_ops.KPFusedSparseReshape( + slice_input=slice_input, + begin=begin, + new_shape=newshape + ) + return custom_out1, custom_out2 + class TestFusedSparseReshape(unittest.TestCase): @classmethod def setUpClass(cls): - """Initialize test data and custom op""" - # Load custom op - cls.custom_op = gen_embedding_fused_ops - - # Base test data - cls.base_slice_input = np.array([[0, 0], [0, 1], [1, 2], [3, 4]], dtype=np.int64) - cls.base_begin = [0, 1] - cls.base_end = [0, 2] - cls.base_strides = [1, 1] - cls.base_newshape = [2, 4] - # Create tf session - cls.sess = tf.compat.v1.Session() + """Initialize""" + cls.config = tf.compat.v1.ConfigProto() + cls.config.intra_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 + + cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) + cls.run_metadata_ori = tf.compat.v1.RunMetadata() + cls.run_metadata_opt = tf.compat.v1.RunMetadata() @classmethod def tearDownClass(cls): - cls.sess.close() - - def test_custom(self): - # execute custom op - custom_out1, custom_out2, = self.custom_op.KPFusedSparseReshape( - slice_input=self.base_slice_input, - begin=self.base_begin, - new_shape=self.base_newshape - ) + # cls.sess.close() + return - # tf native implementation - tf_out1, tf_out2, tf_out3 = self._tf_reference_impl( - self.base_slice_input, - self.base_begin, - self.base_newshape - ) + def _run_kp_reshape_test(self, slice_shape, base_slice_input, base_begin, base_newshape, num_runs): + with tf.Graph().as_default(): + slice_input = tf.compat.v1.placeholder(tf.int64, shape=slice_shape, name="slice_input") + begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin") + newshape = tf.compat.v1.placeholder(tf.int64, shape=(2,), name="newshape") - custom_out_val1, custom_out_val2 = self.sess.run([custom_out1, custom_out2]) - tf_out_val1, tf_out_val2, tf_out_val3 = self.sess.run([tf_out1, tf_out2, tf_out3]) - - print("custom_out_val1: ", custom_out_val1) - print("custom_out_val2: ", custom_out_val2) - print("tf_out_val1: ", tf_out_val1) - print("tf_out_val2: ", tf_out_val2) - - np.testing.assert_array_equal( - custom_out_val1, - tf_out_val1, - err_msg="Segment count mismatch" - ) + feed = 
{
+                slice_input: base_slice_input,
+                begin: base_begin,
+                newshape: base_newshape
+            }
 
+            with tf.name_scope("ori"):
+                out_ori1, out_ori2, out_ori3 = ori_sparse_reshape_graph(slice_input, begin, newshape)
+            with tf.name_scope("opt"):
+                out_opt1, out_opt2 = opt_sparse_reshape_graph(slice_input, begin, newshape)
+
+            with tf.compat.v1.Session(config=self.config) as sess:
+                out_ori_val1, out_ori_val2, out_ori_val3 = sess.run(
+                    [out_ori1, out_ori2, out_ori3],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_ori
+                )
+                out_opt_val1, out_opt_val2 = sess.run(
+                    [out_opt1, out_opt2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_opt
+                )
+
+                # functional test
+                np.testing.assert_array_equal(
+                    out_opt_val1,
+                    out_ori_val1,
+                    err_msg="Segment count mismatch"
+                )
+                np.testing.assert_array_equal(
+                    out_opt_val2,
+                    out_ori_val2,
+                    err_msg="Segment count mismatch"
+                )
+
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori1, out_ori2, out_ori3],
+                    [out_opt1, out_opt2],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedSparseReshape",
+                    start_op="ori/StridedSlice",
+                    end_op="ori/SparseReshape",
+                    num_runs=num_runs,
+                    tag="------TF_origin-----"
+                )
 
-    def _tf_reference_impl(self, slice_input, begin, new_shape):
-        slice67_out = tf.strided_slice(
-            slice_input,
-            begin=begin,
-            end=[0, 2],
-            strides=[1, 1],
-            begin_mask=1,
-            end_mask=1,
-            shrink_axis_mask=2
-        )
-        slice67_out = tf.reshape(slice67_out, [-1, 1])
-        shape_out = tf.shape(slice67_out)
-        slice57_out = tf.strided_slice(
-            shape_out,
-            begin=[0],
-            end=[1],
-            strides=[1],
-            shrink_axis_mask=1
-        )
-
-        const2 = tf.constant(2)
-        input_shape = tf.stack([slice57_out, const2])
-        input_shape = tf.cast(input_shape, tf.int64)
-
-        range_out = tf.range(0, slice57_out, 1)
-        range_out = tf.reshape(range_out, [-1, 1])
-        range_out_64 = tf.cast(range_out, dtype=tf.int64)
-        concat_out = tf.concat([range_out_64, slice67_out], axis=-1)
 
-        sparse_tensor = tf.SparseTensor(
-            indices=concat_out,
-            values=[1,2,3,4],
-            dense_shape=input_shape
-        )
-        sparse_tensor_out = tf.sparse.reshape(sparse_tensor, new_shape)
-        return sparse_tensor_out.indices, sparse_tensor_out.dense_shape, concat_out
+    def test_kp_sparse_reshape(self):
+        base_slice_input = np.array([[0, 0], [0, 1], [1, 2], [3, 4]], dtype=np.int64)
+        base_begin = [0, 1]
+        base_newshape = [2, 4]
+        self._run_kp_reshape_test((4, 2), base_slice_input, base_begin, base_newshape, num_runs=100)
 
+    def test_kp_reshape_2(self):
+        base_slice_input = np.array([[0, 1]], dtype=np.int64)
+        base_begin = [0, 1]
+        base_newshape = [-1, 1]
+        self._run_kp_reshape_test((1, 2), base_slice_input, base_begin, base_newshape, num_runs=100)
+
 
 if __name__ == "__main__":
     tf.compat.v1.disable_eager_execution()
diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_segment_reduce_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_segment_reduce_test.py
index 69c7a114f..5536eb1cd 100644
--- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_segment_reduce_test.py
+++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_segment_reduce_test.py
@@ -1,133 +1,238 @@
+# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops +from utils.utils import benchmark_op + + +def ori_fused_embedding_sparse_segment_reduce_graph(data, indices, slice_input, begin, end, strides, is_mean): + slice_out = tf.strided_slice( + slice_input, + begin=begin, + end=end, + strides=strides, + begin_mask=1, + end_mask=1, + shrink_axis_mask=2 + ) + + segment_ids = tf.cast(slice_out, dtype=tf.int32) + if is_mean: + output = tf.sparse.segment_mean( + data=data, + indices=indices, + segment_ids=segment_ids + ) + else: + output = tf.sparse.segment_sum( + data=data, + indices=indices, + segment_ids=segment_ids + ) + + output_shape = tf.shape(output) + slice_out = tf.strided_slice(output_shape, begin=[0], end=[1], strides=[1]) + + return output, slice_out + + +def opt_fused_embedding_sparse_segment_reduce_graph(data, indices, slice_input, begin, end, strides, is_mean): + if is_mean: + custom_out, custom_slice_out = gen_embedding_fused_ops.KPFusedSparseSegmentReduce( + data=data, + indices=indices, + slice_input=slice_input, + begin=begin, + end=end, + strides=strides + ) + else: + custom_out, custom_slice_out = gen_embedding_fused_ops.KPFusedSparseSegmentReduce( + data=data, + indices=indices, + slice_input=slice_input, + begin=begin, + end = end, + strides=strides, + combiner=0 + ) + return custom_out, custom_slice_out + class TestSparseSegmentMeanSlice(unittest.TestCase): @classmethod def setUpClass(cls): - """Initialize test data and custom op""" - # Load custom op - cls.custom_op = gen_embedding_fused_ops - - # Base test data - cls.base_data = np.array([[1.0, 2.0, 3.0], [3.0, 4.0,5.0], [5.0, 6.0, 7.0], [5.0, 6.0, 7.0]], dtype=np.float32) # shape {4, 3} - cls.base_indices = np.array([0, 1, 2], dtype=np.int64) # shape {3} - cls.base_slice_input = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64) # shape {3, 2} - cls.base_begin = [0, 1] - cls.base_end = [0, 2] - cls.base_strides = [1, 2] - # Create tf session - cls.sess = tf.compat.v1.Session() + """Initialize""" + cls.config = tf.compat.v1.ConfigProto() + cls.config.intra_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 + + cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) + cls.run_metadata_ori = tf.compat.v1.RunMetadata() + cls.run_metadata_opt = tf.compat.v1.RunMetadata() @classmethod def tearDownClass(cls): - cls.sess.close() + return def test_mean(self): - # execute custom op - custom_out, custom_slice_out = self.custom_op.KPFusedSparseSegmentReduce( - data=self.base_data, - indices=self.base_indices, - slice_input=self.base_slice_input, - begin=self.base_begin, - end = self.base_end, - strides = self.base_strides - ) + with tf.Graph().as_default(): + data = tf.compat.v1.placeholder(tf.float32, shape=(4,3), name="data") + indices = tf.compat.v1.placeholder(tf.int32, shape=(3,), name="indices") + slice_input = tf.compat.v1.placeholder(tf.int64, shape=(3,2), name="slice_input") + begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin") + end = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="end") + strides = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="strides") + + base_data = np.array( + [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [5.0, 6.0, 7.0]], + dtype=np.float32 + ) # shape {4, 3} + base_indices = np.array([0, 1, 2], dtype=np.int64) # shape {3} + base_slice_input = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64) # shape {3, 2} + base_begin = [0, 
 class TestSparseSegmentMeanSlice(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        """Initialize test data and custom op"""
-        # Load custom op
-        cls.custom_op = gen_embedding_fused_ops
-
-        # Base test data
-        cls.base_data = np.array([[1.0, 2.0, 3.0], [3.0, 4.0,5.0], [5.0, 6.0, 7.0], [5.0, 6.0, 7.0]], dtype=np.float32)  # shape {4, 3}
-        cls.base_indices = np.array([0, 1, 2], dtype=np.int64)  # shape {3}
-        cls.base_slice_input = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64)  # shape {3, 2}
-        cls.base_begin = [0, 1]
-        cls.base_end = [0, 2]
-        cls.base_strides = [1, 2]
-        # Create tf session
-        cls.sess = tf.compat.v1.Session()
+        """Initialize session config and profiling metadata"""
+        cls.config = tf.compat.v1.ConfigProto()
+        cls.config.intra_op_parallelism_threads = 16
+        cls.config.inter_op_parallelism_threads = 1
+
+        cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+        cls.run_metadata_ori = tf.compat.v1.RunMetadata()
+        cls.run_metadata_opt = tf.compat.v1.RunMetadata()
 
     @classmethod
     def tearDownClass(cls):
-        cls.sess.close()
+        return
 
     def test_mean(self):
-        # execute custom op
-        custom_out, custom_slice_out = self.custom_op.KPFusedSparseSegmentReduce(
-            data=self.base_data,
-            indices=self.base_indices,
-            slice_input=self.base_slice_input,
-            begin=self.base_begin,
-            end = self.base_end,
-            strides = self.base_strides
-        )
+        with tf.Graph().as_default():
+            data = tf.compat.v1.placeholder(tf.float32, shape=(4, 3), name="data")
+            indices = tf.compat.v1.placeholder(tf.int32, shape=(3,), name="indices")
+            slice_input = tf.compat.v1.placeholder(tf.int64, shape=(3, 2), name="slice_input")
+            begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin")
+            end = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="end")
+            strides = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="strides")
+
+            base_data = np.array(
+                [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [5.0, 6.0, 7.0]],
+                dtype=np.float32
+            )  # shape {4, 3}
+            base_indices = np.array([0, 1, 2], dtype=np.int32)  # shape {3}
+            base_slice_input = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64)  # shape {3, 2}
+            base_begin = [0, 1]
+            base_end = [0, 2]
+            base_strides = [1, 2]
+
+            feed = {
+                data: base_data,
+                indices: base_indices,
+                slice_input: base_slice_input,
+                begin: base_begin,
+                end: base_end,
+                strides: base_strides
+            }
+
+            with tf.name_scope("ori"):
+                out_ori1, out_ori2 = ori_fused_embedding_sparse_segment_reduce_graph(
+                    data, indices, slice_input, begin, end, strides, True
+                )
+            with tf.name_scope("opt"):
+                out_opt1, out_opt2 = opt_fused_embedding_sparse_segment_reduce_graph(
+                    data, indices, slice_input, begin, end, strides, True
+                )
+
+            with tf.compat.v1.Session(config=self.config) as sess:
+                out_ori_val1, out_ori_val2 = sess.run(
+                    [out_ori1, out_ori2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_ori
+                )
+                out_opt_val1, out_opt_val2 = sess.run(
+                    [out_opt1, out_opt2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_opt
+                )
 
-        # tf native implementation
-        tf_out, tf_slice_out = self._tf_reference_impl(
-            self.base_data,
-            self.base_indices,
-            self.base_slice_input,
-            self.base_begin,
-            self.base_end,
-            self.base_strides,
-            True
-        )
+                np.testing.assert_allclose(
+                    out_opt_val1,
+                    out_ori_val1,
+                    rtol=1e-6,
+                    err_msg="Output values mismatch"
+                )
+                np.testing.assert_array_equal(
+                    out_opt_val2,
+                    out_ori_val2,
+                    err_msg="Segment count mismatch"
+                )
+
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori1, out_ori2],
+                    [out_opt1, out_opt2],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedSparseSegmentReduce",
+                    start_op="ori/StridedSlice",
+                    end_op="ori/StridedSlice_1",
+                    num_runs=500,
+                    tag="--------TF_origin---------"
+                )
-        custom_out_val, custom_slice_out_val = self.sess.run([custom_out, custom_slice_out])
-        tf_out_val, tf_slice_out_val = self.sess.run([tf_out, tf_slice_out])
-
-        # Numerical comparison
-        np.testing.assert_allclose(
-            custom_out_val,
-            tf_out_val,
-            rtol=1e-6,
-            err_msg="Output values mismatch"
-        )
-        np.testing.assert_array_equal(
-            custom_slice_out_val,
-            tf_slice_out_val,
-            err_msg="Segment count mismatch"
-        )
 
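+    # Same inputs and checks as test_mean; combiner=0 selects segment_sum instead.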
     def test_sum(self):
-        custom_out, custom_slice_out = self.custom_op.KPFusedSparseSegmentReduce(
-            data=self.base_data,
-            indices=self.base_indices,
-            slice_input=self.base_slice_input,
-            begin=self.base_begin,
-            end = self.base_end,
-            strides = self.base_strides,
-            combiner=0
-        )
-
-        tf_out, tf_slice_out = self._tf_reference_impl(
-            self.base_data,
-            self.base_indices,
-            self.base_slice_input,
-            self.base_begin,
-            self.base_end,
-            self.base_strides,
-            False
-        )
-
-        custom_out_val, custom_slice_out_val = self.sess.run([custom_out, custom_slice_out])
-        tf_out_val, tf_slice_out_val = self.sess.run([tf_out, tf_slice_out])
-
-        np.testing.assert_allclose(
-            custom_out_val,
-            tf_out_val,
-            rtol=1e-6,
-            err_msg="Output values mismatch"
-        )
-        np.testing.assert_array_equal(
-            custom_slice_out_val,
-            tf_slice_out_val,
-            err_msg="Segment count mismatch"
-        )
+        with tf.Graph().as_default():
+            data = tf.compat.v1.placeholder(tf.float32, shape=(4, 3), name="data")
+            indices = tf.compat.v1.placeholder(tf.int32, shape=(3,), name="indices")
+            slice_input = tf.compat.v1.placeholder(tf.int64, shape=(3, 2), name="slice_input")
+            begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin")
+            end = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="end")
+            strides = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="strides")
+
+            base_data = np.array(
+                [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [5.0, 6.0, 7.0]],
+                dtype=np.float32
+            )  # shape {4, 3}
+            base_indices = np.array([0, 1, 2], dtype=np.int32)
+            base_slice_input = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64)
+            base_begin = [0, 1]
+            base_end = [0, 2]
+            base_strides = [1, 2]
+
+            feed = {
+                data: base_data,
+                indices: base_indices,
+                slice_input: base_slice_input,
+                begin: base_begin,
+                end: base_end,
+                strides: base_strides
+            }
+
+            with tf.name_scope("ori"):
+                out_ori1, out_ori2 = ori_fused_embedding_sparse_segment_reduce_graph(
+                    data, indices, slice_input, begin, end, strides, False
+                )
+            with tf.name_scope("opt"):
+                out_opt1, out_opt2 = opt_fused_embedding_sparse_segment_reduce_graph(
+                    data, indices, slice_input, begin, end, strides, False
+                )
+
+            with tf.compat.v1.Session(config=self.config) as sess:
+                out_ori_val1, out_ori_val2 = sess.run(
+                    [out_ori1, out_ori2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_ori
+                )
+                out_opt_val1, out_opt_val2 = sess.run(
+                    [out_opt1, out_opt2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_opt
+                )
+                np.testing.assert_allclose(
+                    out_opt_val1,
+                    out_ori_val1,
+                    rtol=1e-6,
+                    err_msg="Output values mismatch"
+                )
+                np.testing.assert_array_equal(
+                    out_opt_val2,
+                    out_ori_val2,
+                    err_msg="Segment count mismatch"
+                )
+
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori1, out_ori2],
+                    [out_opt1, out_opt2],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedSparseSegmentReduce",
+                    start_op="ori/StridedSlice",
+                    end_op="ori/StridedSlice_1",
+                    num_runs=1000,
+                    tag="---------TF_origin--------"
+                )
 
-    def _tf_reference_impl(self, data, indices, slice_input, begin, end, strides, is_mean):
-        slice_out = tf.strided_slice(
-            slice_input,
-            begin= begin,
-            end= end,
-            strides= strides,
-            begin_mask=1,
-            end_mask=1,
-            shrink_axis_mask=2
-        )
-
-        segment_ids = tf.cast(slice_out, dtype=tf.int32)
-        if is_mean:
-            output = tf.sparse.segment_mean(
-                data = data,
-                indices = indices,
-                segment_ids= segment_ids
-            )
-        else:
-            output = tf.sparse.segment_sum(
-                data = data,
-                indices = indices,
-                segment_ids= segment_ids
-            )
-
-        output_shape = tf.shape(output)
-        slice_out = tf.strided_slice(output_shape, begin=[0], end=[1], strides=[1])
-
-        return output, slice_out
 
 if __name__ == "__main__":
     tf.compat.v1.disable_eager_execution()
diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py
index d37128cb2..c5041495c 100644
--- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py
+++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py
@@ -1,12 +1,16 @@
+# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+import unittest
+
 import tensorflow as tf
 import numpy as np
-import unittest
 from tensorflow.python.ops import gen_embedding_fused_ops
-from utils.utils import perf_run, generate_timeline, wrapper_sess
+from utils.utils import benchmark_op
 
+np.random.seed(140)
 
-def ori_fused_embedding_sparse_select_graph(input_a, input_b, input_c):
+
+def ori_sparse_select_graph(input_a, input_b, input_c):
     a = tf.reshape(input_a, [-1, 1])
     b = tf.reshape(input_b, [-1, 1])
     c = tf.reshape(input_c, [-1, 1])
@@ -16,7 +20,7 @@ def ori_fused_embedding_sparse_select_graph(input_a, input_b, input_c):
     shape_reshape_a2 = tf.shape(a)
     fill_a1 = tf.fill(shape_reshape_a1, tf.constant(1, dtype=tf.float32))
     realdiv = tf.realdiv(fill_a1, tf.constant(1, dtype=tf.float32))
-    output_x = tf.fill(shape_reshape_a2, tf.constant(0, dtype=tf.float32))
+    output_x = tf.fill(shape_reshape_a2, tf.constant(0, dtype=tf.float32))  # all zeros; the first output
     cast_a = tf.cast(greater_a, tf.float32)
     shape_a = tf.shape(cast_a)
     fill_a = tf.fill(shape_a, tf.constant(1, dtype=tf.float32))
@@ -25,14 +29,14 @@ def ori_fused_embedding_sparse_select_graph(input_a, input_b, input_c):
     equal_3 = tf.equal(c, 3)
     select_1 = tf.where(equal_4563, fill_a, cast_a)
     select_2 = tf.where(equal_10831, fill_a, select_1)
-    output_y = tf.subtract(tf.constant(1, dtype=tf.float32), select_2)
-    mul = tf.multiply(tf.constant(1, dtype=tf.float32), select_2)
-    select_3 = tf.where(equal_3, realdiv, fill_a1)
+    output_y = tf.subtract(tf.constant(1, dtype=tf.float32), select_2)  # output_y = 1 - select_2
+    mul = tf.multiply(tf.constant(1, dtype=tf.float32), select_2)  # Select.2415 --> Mul.2419
+    select_3 = tf.where(equal_3, realdiv, fill_a1)  # realdiv == fill_a1 (ones / 1), so this is all ones
     output_z = tf.concat([mul, select_3], axis=-1)
     return output_x, output_y, output_z
 
 
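+# Note: because select_3 is always all ones, output_z reduces to
+# concat([select_2, ones], -1), output_y to 1 - select_2, and output_x to zeros
+# shaped like a; the fused op below is expected to reproduce these three outputs.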
-def opt_fused_embedding_sparse_select_graph(input_a, input_b, input_c):
+def opt_sparse_select_graph(input_a, input_b, input_c):
     output_x, output_y, output_z = gen_embedding_fused_ops.KPFusedSparseSelect(
         input_a=input_a, input_b=input_b, input_c=input_c
     )
@@ -45,7 +49,7 @@ class TestKPFusedSparseSelect(unittest.TestCase):
         """Initialize config"""
         cls.config = tf.compat.v1.ConfigProto()
         cls.config.intra_op_parallelism_threads = 16
-        cls.config.inter_op_parallelism_threads = 16
+        cls.config.inter_op_parallelism_threads = 1
 
         cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
         cls.run_metadata_ori = tf.compat.v1.RunMetadata()
@@ -55,57 +59,85 @@ class TestKPFusedSparseSelect(unittest.TestCase):
     def tearDownClass(cls):
         return
 
-    def test_fused_embedding_sparse_select(self):
-        # Create Graph
+    def _run_kp_select_test(self, a_shape, b_shape, c_shape, num_runs):
         with tf.Graph().as_default():
-            input0 = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="input_a")
-            input1 = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="input_b")
-            input2 = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name="input_c")
+            input0 = tf.compat.v1.placeholder(tf.int32, name="input_a")
+            input1 = tf.compat.v1.placeholder(tf.int32, name="input_b")
+            input2 = tf.compat.v1.placeholder(tf.int32, name="input_c")
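+            # Shapes are left unspecified so the same graph accepts the 1-D, 2-D
+            # and 3-D inputs exercised below.
 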
             """Initialize test data"""
             feed = {
-                input0: np.random.randint(0, 100, size=(100, 10)).astype(np.int32),
-                input1: np.random.randint(0, 100, size=(10, 100)).astype(np.int32),
-                input2: np.random.randint(0, 100, size=(20, 50)).astype(np.int32),
+                input0: np.random.randint(0, 100, size=a_shape).astype(np.int32),
+                input1: np.random.randint(0, 100, size=b_shape).astype(np.int32),
+                input2: np.random.randint(0, 100, size=c_shape).astype(np.int32),
             }
+
             with tf.name_scope("ori"):
-                out0_ori, out1_ori, out2_ori = ori_fused_embedding_sparse_select_graph(input0, input1, input2)
+                out_ori1, out_ori2, out_ori3 = ori_sparse_select_graph(input0, input1, input2)
             with tf.name_scope("opt"):
-                out0_opt, out1_opt, out2_opt = opt_fused_embedding_sparse_select_graph(input0, input1, input2)
-
-            # Create tf session
+                out_opt1, out_opt2, out_opt3 = opt_sparse_select_graph(input0, input1, input2)
+
             with tf.compat.v1.Session(config=self.config) as sess:
-                # functest
-                out0_ori_val, out1_ori_val, out2_ori_val = sess.run([out0_ori, out1_ori, out2_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori)
-                out0_opt_val, out1_opt_val, out2_opt_val = sess.run([out0_opt, out1_opt, out2_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt)
-
+                out_ori_val1, out_ori_val2, out_ori_val3 = sess.run(
+                    [out_ori1, out_ori2, out_ori3],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_ori
+                )
+                out_opt_val1, out_opt_val2, out_opt_val3 = sess.run(
+                    [out_opt1, out_opt2, out_opt3],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_opt
+                )
+
+                # Functional comparison
                 np.testing.assert_allclose(
-                    out0_ori_val,
-                    out0_opt_val,
+                    out_ori_val1,
+                    out_opt_val1,
                     rtol=1e-5,
                     err_msg="Output values mismatch"
                 )
                 np.testing.assert_allclose(
-                    out1_ori_val,
-                    out1_opt_val,
+                    out_ori_val2,
+                    out_opt_val2,
                     rtol=1e-5,
                     err_msg="Output values mismatch"
                 )
                 np.testing.assert_allclose(
-                    out2_ori_val,
-                    out2_opt_val,
+                    out_ori_val3,
+                    out_opt_val3,
                     rtol=1e-5,
                     err_msg="Output values mismatch"
                 )
+
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori1, out_ori2, out_ori3],
+                    [out_opt1, out_opt2, out_opt3],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedSparseSelect",
+                    start_op="ori/Reshape",
+                    end_op="ori/Sub",
+                    num_runs=num_runs,
+                    tag="-----TF_origin-----"
+                )
 
-            generate_timeline(self.run_metadata_ori.step_stats, f"{self._testMethodName}_ori")
-            generate_timeline(self.run_metadata_opt.step_stats, f"{self._testMethodName}_opt")
-            # perftest
-            perf_run(wrapper_sess(sess, [out0_ori, out1_ori, out2_ori], feed_dict=feed),
-                     wrapper_sess(sess, [out0_opt, out1_opt, out2_opt], feed_dict=feed),
-                     "KPFusedEmbeddingSparseSelect")
+    def test_fused_embedding_sparse_select(self):
+        # New cases: a, b and c share the same shape, with sizes from 1 to 100
+        shapes = [[(i,), (i,), (i,)] for i in range(1, 101)]
+        shapes.append([(100, 10), (10, 100), (20, 50)])
+        shapes.extend([[(i, i), (i, i), (i, i)] for i in range(1, 101)])
+        shapes.extend([[(i, i, i), (i, i, i), (i, i, i)] for i in range(1, 101)])
+        for shape in shapes:
+            self._run_kp_select_test(*shape, num_runs=10)
+            print(f"tested shape_a {shape[0]}")
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/grappler/embedding_fused_test/utils/utils.py b/tensorflow/python/grappler/embedding_fused_test/utils/utils.py
index 06f02d6bf..f982f5dd3 100644
--- a/tensorflow/python/grappler/embedding_fused_test/utils/utils.py
+++ b/tensorflow/python/grappler/embedding_fused_test/utils/utils.py
@@ -1,10 +1,88 @@
 import timeit
+import json
+import os
 
 from tensorflow.python.client import timeline
 
 
-def perf_run(ori_func, opt_func, name, warmup=5, iters=50):
-
+def extract_op_dur(timeline_file, op_name):
+    """Extract the duration (μs) of the given fused op from a timeline JSON file."""
+    with open(f"timeline/{timeline_file}.json", "r") as f:
+        trace_events = json.load(f)["traceEvents"]  # event list of the timeline JSON
+    durations = [e["dur"] for e in trace_events if e.get("name") == op_name and "dur" in e]
+    return durations[0]
+
+
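+# The timeline JSON is in Chrome trace format: every event carries a start
+# timestamp "ts" and a duration "dur", both in microseconds, and the graph node
+# name appears under args["name"], which extract_op_total_time matches on.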
e.get("name") == op_name and "dur" in e] + return durations[0] + + +def extract_op_total_time(timeline_file, start_op, end_op): + """计算从 start_op 到 end_op 的总耗时(包含调度空隙)""" + with open(f"timeline/{timeline_file}.json", "r") as f: + trace_events = json.load(f)["traceEvents"] + start_event = next(e for e in trace_events if e.get("args", {}).get("name") == start_op) # 找到 timeline 里第一个 name 等于 start_op 的事件 + end_event = next(e for e in trace_events if e.get("args", {}).get("name") == end_op) # 找不到会报错 + start_time = start_event["ts"] + end_time = end_event["ts"] + end_event["dur"] # ts 是开始时间,dur是算子的持续时间 + return end_time - start_time + + +def benchmark_op( + sess, + feed, + out_ori, + out_opt, + run_options, + run_metadata_ori, + run_metadata_opt, + op_name, + start_op, + end_op, + num_runs=500, + tag="--------TF_origin---------" +): + print("-" * 60) + print("-" * 60) + print("new test") + + total_times_ori = 0.0 + total_times_opt = 0.0 + + for i in range(num_runs): + # 执行原始算子 + sess.run( + out_ori, + feed_dict=feed, + options=run_options, + run_metadata=run_metadata_ori + ) + # 执行优化后的算子 + sess.run( + out_opt, + feed_dict=feed, + options=run_options, + run_metadata=run_metadata_opt + ) + + # 生成 timeline 文件 + filename_ori = f"{op_name}_ori" + filename_opt = f"{op_name}_opt" + generate_timeline(run_metadata_ori.step_stats, filename_ori) + generate_timeline(run_metadata_opt.step_stats, filename_opt) + + # 统计时延 + total_times_ori += extract_op_total_time(filename_ori, start_op, end_op) + total_times_opt += extract_op_dur(filename_opt, op_name) + + # 计算平均值和加速比 + avg_ori = total_times_ori / num_runs + avg_opt = total_times_opt / num_runs + speedup = (avg_ori - avg_opt) / avg_ori * 100 if avg_ori > 0 else 0 + + # 打印结果 + print(f"{tag}: {avg_ori:.4f} us per run") + print(f"{op_name}: {avg_opt:.4f} us per run") + print(f"improve: {speedup:.2f}%") + + +def perf_run(ori_func, opt_func, name, warmup=5, iters=5): print(f"\nWarmup ori: {warmup} iters") for _ in range(warmup): ori_func() @@ -12,7 +90,7 @@ def perf_run(ori_func, opt_func, name, warmup=5, iters=50): print(f"Running performance test: ori {iters} iters") total_time = timeit.timeit(ori_func, number=iters) ori_avg_time = total_time / iters * 1000 - print(f"{name}: {ori_avg_time:.2f} ms per run") + print(f"{name}: {ori_avg_time:.6f} ms per run") print(f"\nWarmup opt: {warmup} iters") for _ in range(warmup): @@ -21,7 +99,7 @@ def perf_run(ori_func, opt_func, name, warmup=5, iters=50): print(f"Running performance test: opt {iters} iters") total_time = timeit.timeit(opt_func, number=iters) opt_avg_time = total_time / iters * 1000 - print(f"{name}: {opt_avg_time:.2f} ms per run") + print(f"{name}: {opt_avg_time:.6f} ms per run") improvement = (ori_avg_time - opt_avg_time) / ori_avg_time * 100 print(f"improve: {improvement:.2f}%") @@ -36,4 +114,4 @@ def generate_timeline(step_stats, filename): def wrapper_sess(sess, fetches, feed_dict=None, options=None, run_metadata=None): - return lambda: sess.run(fetches, feed_dict=feed_dict, options=options, run_metadata=run_metadata) \ No newline at end of file + return lambda: sess.run(fetches, feed_dict=feed_dict, options=options, run_metadata=run_metadata) \ No newline at end of file -- Gitee