From d4bb36439e5717f64feee1d9b459565812082e39 Mon Sep 17 00:00:00 2001 From: littleli <992995360@qq.com> Date: Thu, 7 Aug 2025 21:32:32 +0800 Subject: [PATCH] modify kpfusedgather, kpsparsereshape, kpfusedsparsedynamicstitch. --- tensorflow/core/kernels/embedding_fused_gather.cc | 5 ++--- .../kernels/embedding_fused_sparse_dynamic_stitch.cc | 12 ++++++++---- .../core/kernels/embedding_fused_sparse_reshape.cc | 4 ++-- tensorflow/core/ops/embedding_fused_ops.cc | 2 +- .../fused_embedding_gather_test.py | 2 +- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/kernels/embedding_fused_gather.cc b/tensorflow/core/kernels/embedding_fused_gather.cc index ca3eff64..4820124c 100644 --- a/tensorflow/core/kernels/embedding_fused_gather.cc +++ b/tensorflow/core/kernels/embedding_fused_gather.cc @@ -66,9 +66,8 @@ class KPFusedGather : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({1}), &out_shape)); - auto output_shape = out_shape->flat(); - output_shape(0) = static_cast(unique_values.size()); + 0, TensorShape({unique_values.size()}), &out_shape)); + std::memcpy(out_shape->data(), unique_values.data(), unique_values.size() * sizeof(int64_t)); OP_REQUIRES_OK(context, context->allocate_output( diff --git a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc index e353b565..9937a07e 100644 --- a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc +++ b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc @@ -38,24 +38,28 @@ class KPFusedSparseDynamicStitchOp : public OpKernel { const int num_inputs = context->num_inputs(); const int num_partitions = num_inputs - 1; + int output_stride = 0; std::vector variables(num_partitions); for (int i = 1; i < num_inputs; ++i) { + if (i == 1) { + const Tensor& input_tensor = context->input(i); + if (input_tensor.shape().dims() == 2) { + output_stride = input_tensor.shape().dim_size(1); + } + } variables[i - 1] = context->input(i).flat().data(); } Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({num_elems, 10}), + context->allocate_output(0, TensorShape({num_elems, output_stride}), &output_tensor)); output = (float*)output_tensor->tensor_data().data(); - const int64_t output_stride = 10; const size_t copy_size = output_stride * sizeof(float); - auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); const int64 cost_per_unit = 1000 * num_elems; auto work = [&](int start, int end) { - const int64_t output_stride = 10; const size_t copy_size = output_stride * sizeof(float); for (int i = start; i < end; ++i) { diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc index 2fe42e94..5cc7e95b 100644 --- a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc +++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc @@ -186,8 +186,8 @@ class KPFusedSparseReshapeOp : public OpKernel { Tensor new_shape_in(DT_INT64, TensorShape({2})); auto newshape_tensor_flat = new_shape_in.flat(); - newshape_tensor_flat(0) = static_cast(new_shape.flat()(0)); - newshape_tensor_flat(1) = static_cast(new_shape.flat()(1)); + newshape_tensor_flat(0) = new_shape.flat()(0); + newshape_tensor_flat(1) = new_shape.flat()(1); ReshapeKp(context, indices_in, shape_in, new_shape_in, 0, 1); } }; diff --git a/tensorflow/core/ops/embedding_fused_ops.cc b/tensorflow/core/ops/embedding_fused_ops.cc index 4ba6bec7..27a667e5 100644 --- a/tensorflow/core/ops/embedding_fused_ops.cc +++ b/tensorflow/core/ops/embedding_fused_ops.cc @@ -95,7 +95,7 @@ REGISTER_OP("KPFusedSparseSelect") REGISTER_OP("KPFusedSparseReshape") .Input("slice_input: int64") .Input("begin: int32") - .Input("new_shape: int32") + .Input("new_shape: int64") .Output("out_indices: int64") .Output("out_shape: int64") .SetShapeFn(shape_inference::UnknownShape); diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py index f47b70d2..1c73adc1 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py @@ -73,7 +73,7 @@ class TestFusedGather(unittest.TestCase): ) slice_out, slice_out_indices = tf.unique(slice_out) - output_shape = tf.shape(slice_out) + output_shape = slice_out slice_out = tf.reshape(slice_out, [-1]) slice_out, _ = tf.unique(slice_out) -- Gitee