diff --git a/tensorflow/core/kernels/embedding_fused_action_id_gather.cc b/tensorflow/core/kernels/embedding_fused_action_id_gather.cc
index b324f35f03e60726811e83e70587ebaaf460601b..af60b4ab2d5aef43ff5fe92d7c2b6a9b894a59bd 100644
--- a/tensorflow/core/kernels/embedding_fused_action_id_gather.cc
+++ b/tensorflow/core/kernels/embedding_fused_action_id_gather.cc
@@ -13,23 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include
-#include
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/op_kernel.h"
-namespace tensorflow {
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
+namespace tensorflow {
+
 template <typename Tindices>
-static void GatherV2Impl(OpKernelContext* context,
-                         const float* params_data,
-                         const TensorShape& params_shape,
-                         const Tindices* indices_data,
-                         const TensorShape& indices_shape,
-                         int axis, Tensor* temp) {
+static void GatherV2Impl(OpKernelContext* context, const float* params_data, const TensorShape& params_shape,
+                         const Tindices* indices_data, const TensorShape& indices_shape, int axis, Tensor* temp) {
   TensorShape temp_shape;
   const int P0 = params_shape.dim_size(0);
   int P1 = 1;
@@ -41,28 +33,30 @@ static void GatherV2Impl(OpKernelContext* context,
     temp_shape.AddDim(params_shape.dim_size(d));
     P1 *= params_shape.dim_size(d);
   }
-  OP_REQUIRES_OK(context,
-                 context->allocate_temp(DT_FLOAT, temp_shape, temp));
-  VLOG(1) << "temp shape: " << temp->shape().DebugString();
+  OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT, temp_shape, temp));
   const int num_indices = indices_shape.num_elements();
   float* temp_data = temp->flat<float>().data();
-  VLOG(1) << "num_indices : " << num_indices;
-  OP_REQUIRES(context, axis == 0, errors::InvalidArgument("axis only support 0"));
-  const int slice_size = P1;
-  for (int i = 0; i < num_indices; ++i) {
-    Tindices idx = indices_data[i];
-    OP_REQUIRES(context, (idx < 0 || idx >= P0), errors::InvalidArgument("GatherV2 axis=0: index out of range"));
-    std::memcpy(temp_data + i * slice_size,
-                params_data + idx * slice_size,
-                sizeof(float) * slice_size);
+  if (axis == 0) {
+    const int slice_size = P1;
+    for (int i = 0; i < num_indices; ++i) {
+      Tindices idx = indices_data[i];
+      if (idx < 0 || idx >= P0) {
+        LOG(FATAL) << "GatherV2 axis=0: index out of range: " << idx;
+      }
+      std::memcpy(
+          temp_data + i * slice_size, params_data + idx * slice_size, sizeof(float) * slice_size
+      );
+    }
+  } else {
+    LOG(FATAL) << "Only axis=0 is supported";
   }
-  VLOG(1) << "temp value : " << temp->DebugString(100);
 }
+
 template <typename Tindices1, typename Tindices2>
 class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
- public:
+public:
  explicit KPFusedEmbeddingActionIdGatherOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
@@ -72,40 +66,41 @@ class KPFusedEmbeddingActionIdGatherOp : public OpKernel {
     const Tensor& indices2 = context->input(2);
     const Tensor& pack_dim = context->input(3);

-    VLOG(1) << "indices1 shape: " << indices1.shape().DebugString();
-    VLOG(1) << "params shape: " << params.shape().DebugString();
-    VLOG(1) << "indices2 shape: " << indices2.shape().DebugString();
     OP_REQUIRES(context, indices1.dims() <= 2, errors::InvalidArgument("indices1 dims must <= 2"));
     OP_REQUIRES(context, indices2.dims() <= 2, errors::InvalidArgument("indices2 dims must <= 2"));
     OP_REQUIRES(context, params.dims() == 2, errors::InvalidArgument("params dims must = 2"));
     OP_REQUIRES(context, pack_dim.NumElements() == 1, errors::InvalidArgument("pack_dim NumElements must = 1"));

     Tensor temp;
-    GatherV2Impl<Tindices1>(context, params.flat<float>().data(), params.shape(),
-                            indices1.flat<Tindices1>().data(), indices1.shape(),
-                            0, &temp);
+    GatherV2Impl<Tindices1>(context, params.flat<float>().data(), params.shape(), indices1.flat<Tindices1>().data(),
+                            indices1.shape(), 0, &temp);
     Tensor temp1;
-    GatherV2Impl<Tindices2>(context, temp.flat<float>().data(), temp.shape(),
-                            indices2.flat<Tindices2>().data(), indices2.shape(),
-                            0, &temp1);
+    GatherV2Impl<Tindices2>(context, temp.flat<float>().data(), temp.shape(), indices2.flat<Tindices2>().data(),
+                            indices2.shape(), 0, &temp1);

     int pack_size = pack_dim.scalar<int32>()();
-    VLOG(1) << "pack_size value: " << pack_size;
     int a_reshaped_cols = temp1.NumElements() / pack_size;
     auto a_reshaped = temp1.shaped<float, 2>({pack_size, a_reshaped_cols});
-    VLOG(1) << "a_reshaped_cols : " << a_reshaped_cols;

     Tensor* output;
     int output_cols = a_reshaped_cols + 1680;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(0, TensorShape({pack_size, output_cols}), &output));
-    VLOG(1) << "output shape: " << output->shape().DebugString();
-    auto output_matrix = output->matrix<float>();
-    output_matrix.slice(
-        Eigen::array<Eigen::Index, 2>{0, 0},
-        Eigen::array<Eigen::Index, 2>{pack_size, a_reshaped_cols}) = a_reshaped;
-
-    output_matrix.slice(
-        Eigen::array<Eigen::Index, 2>{0, a_reshaped_cols},
-        Eigen::array<Eigen::Index, 2>{pack_size, 1680}).setZero();
+                   context->allocate_output(0, TensorShape({pack_size, output_cols}), &output));
+
+    auto a_reshaped_data = a_reshaped.data();
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit = a_reshaped_cols + 1680;
+    float* base = output->matrix<float>().data();
+    Shard(worker_threads->num_threads, worker_threads->workers, pack_size, cost_per_unit,
+          [&](int64 start_row, int64 end_row) {
+            for (int64 row = start_row; row < end_row; ++row) {
+              float* dst_row = base + row * (a_reshaped_cols + 1680);
+              std::memcpy(
+                  dst_row, a_reshaped_data + row * a_reshaped_cols, sizeof(float) * a_reshaped_cols
+              );
+              std::memset(
+                  dst_row + a_reshaped_cols, 0, sizeof(float) * 1680
+              );
+            }
+          });
   }
 };

@@ -114,11 +109,11 @@
                             .Device(DEVICE_CPU)                      \
                             .TypeConstraint<Tindices1>("Tindices1")  \
                             .TypeConstraint<Tindices2>("Tindices2"), \
-                        KPFusedEmbeddingActionIdGatherOp<Tindices1, Tindices2>);
+                        KPFusedEmbeddingActionIdGatherOp<Tindices1, Tindices2>)

 REGISTER_CPU_KERNEL(int64, int32)
 REGISTER_CPU_KERNEL(int32, int32)
 REGISTER_CPU_KERNEL(int64, int64)
 REGISTER_CPU_KERNEL(int32, int64)

-}
\ No newline at end of file
+}
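For reference, the contract this kernel implements can be stated in a few lines of NumPy (a sketch only; the helper name and the test shapes are illustrative, while the 1680-column zero pad and the gather/reshape order come from the kernel above):

    # Illustrative only: mirrors the kernel's gather -> gather -> reshape -> zero-pad pipeline.
    import numpy as np

    def fused_action_id_gather(params, indices1, indices2, pack_size, pad_cols=1680):
        temp = params[indices1.reshape(-1)]      # first GatherV2, axis=0
        temp1 = temp[indices2.reshape(-1)]       # second GatherV2, axis=0
        left = temp1.reshape(pack_size, -1)      # [pack_size, a_reshaped_cols]
        pad = np.zeros((pack_size, pad_cols), dtype=np.float32)
        return np.concatenate([left, pad], axis=1)

    params = np.arange(20, dtype=np.float32).reshape(5, 4)
    out = fused_action_id_gather(params, np.array([0, 2, 4, 1]), np.array([1, 3]), pack_size=2)
    assert out.shape == (2, 4 + 1680)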
diff --git a/tensorflow/core/kernels/embedding_fused_gather.cc b/tensorflow/core/kernels/embedding_fused_gather.cc
index 51ec57762e215bdd76f4eb843926acc666fd5c3d..6927d6b80d5aafcba0ac61f081a603a905c3f336 100644
--- a/tensorflow/core/kernels/embedding_fused_gather.cc
+++ b/tensorflow/core/kernels/embedding_fused_gather.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
+#include <arm_neon.h>
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -22,7 +22,7 @@
 using namespace tensorflow;

 class KPFusedGather : public OpKernel {
- public:
+public:
  explicit KPFusedGather(OpKernelConstruction* context) : OpKernel(context) {
  }

  void Compute(OpKernelContext* context) override {
@@ -33,16 +33,10 @@ class KPFusedGather : public OpKernel {
     OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2"));
     OP_REQUIRES(context, data.dims() == 2, errors::Internal("indentity dims must == 2"));

-    VLOG(1) << "Input indentity shape: " << data.shape().DebugString();
-    VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString();
-    VLOG(1) << "Input begin value: " << begin.SummarizeValue(10);
-
     int32 col = begin.flat<int32>().data()[1];
     auto data_mat = data.matrix<float>();
     auto slice_input_mat = slice_input.matrix<int64>();

-    VLOG(1) << "Column index from begin: " << col;
-
     std::vector<int64_t> unique_values;
     std::vector<int32_t> indices(slice_input.dim_size(0));
     std::unordered_map<int64_t, int32_t> value_to_index;
@@ -60,41 +54,41 @@ class KPFusedGather : public OpKernel {
     }

     Tensor* out_shape = nullptr;
-    Tensor* out_indices = nullptr;
-    Tensor* out_data = nullptr;
-
     OP_REQUIRES_OK(context, context->allocate_output(
-                                0, TensorShape({unique_values.size()}), &out_shape));
-    std::memcpy(out_shape->data(), unique_values.data(), unique_values.size() * sizeof(int64_t));
-
+                                0, TensorShape({1}), &out_shape));
+    out_shape->flat<int64>()(0) = static_cast<int64>(unique_values.size());
+
+    Tensor* out_indices = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(
-                                1, TensorShape({static_cast<int64>(indices.size())}), &out_indices));
+                                1, TensorShape({static_cast<int64>(indices.size())}), &out_indices));
     std::memcpy(out_indices->data(), indices.data(), indices.size() * sizeof(int32_t));
-    OP_REQUIRES(context, data.dim_size(1) * unique_values.size() % 12 == 0,
+
+    OP_REQUIRES(context, data.dim_size(1) * unique_values.size() % 12 == 0,
                 errors::Internal("cannot reshape to [-1, 12]"));
-
-    std::vector<float> gather1_result;
-    for (auto &indice : unique_values) {
-      for (int64_t i = 0; i < data.dim_size(1); ++i) {
-        gather1_result.push_back(data_mat(indice, i));
-      }
-    }
+
+    Tensor* out_data = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(
-                                2, TensorShape({unique_values.size(), 12}), &out_data));
+                                2, TensorShape({unique_values.size(), 12}), &out_data));
     auto output_data = out_data->matrix<float>();
-    int cur_row = 0;
-    for (auto &indice : unique_values) {
-      for (int i = 0; i < 12; ++i) {
-        output_data(cur_row, i) = gather1_result[12 * indice + i];
-      }
-      cur_row++;
+
+    for (int64_t cur_row = 0; cur_row < unique_values.size(); ++cur_row) {
+      int64_t idx = unique_values[cur_row];
+      // Copy one 12-float row as three 4-lane NEON vectors; the row is copied
+      // exactly once per unique id.
+      const float* src = &data_mat(idx, 0);
+      float* dst = &output_data(cur_row, 0);
+      float32x4_t v0 = vld1q_f32(src);
+      float32x4_t v1 = vld1q_f32(src + 4);
+      float32x4_t v2 = vld1q_f32(src + 8);
+      vst1q_f32(dst, v0);
+      vst1q_f32(dst + 4, v1);
+      vst1q_f32(dst + 8, v2);
     }
   }
 };

 REGISTER_KERNEL_BUILDER(Name("KPFusedGather").Device(DEVICE_CPU),
-                        KPFusedGather);
\ No newline at end of file
+                        KPFusedGather);
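A NumPy sketch of the three outputs KPFusedGather produces, per the kernel above: a one-element unique count, the first-seen inverse indices, and one 12-float row per unique id. The helper name is illustrative, and it assumes data has at least 12 columns and that the ids address valid rows:

    import numpy as np

    def fused_gather(data, slice_input, begin):
        col = begin[1]
        uniq, inverse, seen = [], [], {}
        for v in slice_input[:, col]:          # first-seen order, like the kernel's hash map
            if v not in seen:
                seen[v] = len(uniq)
                uniq.append(v)
            inverse.append(seen[v])
        out_shape = np.array([len(uniq)], dtype=np.int64)
        out_indices = np.array(inverse, dtype=np.int32)
        out_data = data[np.array(uniq), :12].astype(np.float32)
        return out_shape, out_indices, out_data

    data = np.arange(240, dtype=np.float32).reshape(20, 12)
    shape, idx, rows = fused_gather(data, np.array([[0, 0], [0, 1], [1, 2]]), begin=[0, 1])
    assert shape[0] == 3 and list(idx) == [0, 1, 2] and rows.shape == (3, 12)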
diff --git a/tensorflow/core/kernels/embedding_fused_padding.cc b/tensorflow/core/kernels/embedding_fused_padding.cc
index e36fbf7fa7afd8210a43b0f1d2a772ccc77f3273..98351004dcc25fd349c505c1c984bb54a5bd792a 100644
--- a/tensorflow/core/kernels/embedding_fused_padding.cc
+++ b/tensorflow/core/kernels/embedding_fused_padding.cc
@@ -13,21 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include
 #include

 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/op_kernel.h"
+
 namespace tensorflow {

 using shape_inference::InferenceContext;
 using shape_inference::ShapeHandle;

 class KPFusedEmbeddingPaddingOp : public OpKernel {
- public:
+public:
  explicit KPFusedEmbeddingPaddingOp(OpKernelConstruction* context) : OpKernel(context) {
    fast_ = (type_string() == "KPFusedEmbeddingPaddingFast");
  }
@@ -67,32 +67,29 @@ class KPFusedEmbeddingPaddingOp : public OpKernel {
     int output_rows = padding_rows + input.dim_size(0);
     int output_cols = input.dim_size(1);
     OP_REQUIRES(
-        context,
-        output_rows * output_cols % reshape_cols == 0,
-        errors::InvalidArgument("padding cannot reshape to [-1, ", reshape_cols, "]")
+        context, output_rows * output_cols % reshape_cols == 0,
+        errors::InvalidArgument("padding cannot reshape to [-1, ", reshape_cols, "]")
     );
     int reshape_rows = output_rows * output_cols / reshape_cols;

     if (fast_) {
-      OP_REQUIRES_OK(context,
-                     context->allocate_output(1, TensorShape({}),
-                                              &output1));
+      OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({}), &output1));
       output1->scalar<int32>()() = reshape_rows;
       return;
     }

     OP_REQUIRES_OK(context, context->allocate_temp(DT_FLOAT,
                                                    TensorShape({padding_rows + input_rows_value, output_cols}),
-                                                  &padding));
+                                                   &padding));
     auto input_matrix = input.matrix<float>();
     auto padding_matrix = padding.matrix<float>();
     padding_matrix.slice(
-       Eigen::array<Eigen::Index, 2>{0, 0},
-       Eigen::array<Eigen::Index, 2>{input_rows_value, output_cols}) = input_matrix;
+        Eigen::array<Eigen::Index, 2>{0, 0},
+        Eigen::array<Eigen::Index, 2>{input_rows_value, output_cols}) = input_matrix;
     padding_matrix.slice(
-       Eigen::array<Eigen::Index, 2>{input_rows_value, 0},
-       Eigen::array<Eigen::Index, 2>{padding_rows, output_cols}).setZero();
+        Eigen::array<Eigen::Index, 2>{input_rows_value, 0},
+        Eigen::array<Eigen::Index, 2>{padding_rows, output_cols}).setZero();

     TensorShape reshaped_shape({reshape_rows, reshape_cols});
     OP_REQUIRES_OK(context,
@@ -100,7 +97,7 @@ class KPFusedEmbeddingPaddingOp : public OpKernel {
     output1->flat<float>() = padding.flat<float>();
   }

- private:
+private:
  bool fast_;
 };
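The padding contract, sketched in NumPy (names are illustrative; the fast path only reports the row count of the final [-1, reshape_cols] reshape and moves no data):

    import numpy as np

    def fused_padding(input_mat, padding_rows, reshape_cols, fast=False):
        rows, cols = input_mat.shape
        total = (rows + padding_rows) * cols
        assert total % reshape_cols == 0, "padding cannot reshape to [-1, reshape_cols]"
        reshape_rows = total // reshape_cols
        if fast:
            return reshape_rows                  # scalar output, no data movement
        padded = np.vstack([input_mat, np.zeros((padding_rows, cols), input_mat.dtype)])
        return padded.reshape(reshape_rows, reshape_cols)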
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
index 9937a07e095059998e34819c7c80ce968298e70d..e1cdbc5cd205689b26e860fd2914fff04987dc5f 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
@@ -13,12 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include
-
 #include

-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -26,7 +22,7 @@ limitations under the License.
 using namespace tensorflow;

 class KPFusedSparseDynamicStitchOp : public OpKernel {
- public:
+public:
  explicit KPFusedSparseDynamicStitchOp(OpKernelConstruction* context)
      : OpKernel(context) {}

@@ -78,4 +74,4 @@ class KPFusedSparseDynamicStitchOp : public OpKernel {
 };

 REGISTER_KERNEL_BUILDER(Name("KPFusedSparseDynamicStitch").Device(DEVICE_CPU),
-                       KPFusedSparseDynamicStitchOp);
+                        KPFusedSparseDynamicStitchOp);
\ No newline at end of file
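What KPFusedSparseDynamicStitch computes, following the reference graph used in the test further below: each flattened id v is served from table v mod 12 at row v div 12, and the results are stitched back in input order. A NumPy sketch with illustrative names:

    import numpy as np

    def fused_dynamic_stitch(x, tables):
        flat = x.reshape(-1)
        out = np.empty((flat.size, tables[0].shape[1]), dtype=np.float32)
        for i, v in enumerate(flat):
            out[i] = tables[v % 12][v // 12]     # partition by mod, gather by div, stitch back
        return out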
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc
index 43428b88c48a89409ae4e088ec661332488483d8..fb0fa57837000a33c539a4b5aa13b469dd151e96 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc
@@ -13,140 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/util/work_sharder.h"
-#include "tensorflow/core/kernels/reshape_util.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/tensor_util.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/gtl/inlined_vector.h"

 using namespace tensorflow;

-static void ReshapeKp(OpKernelContext *context, const Tensor &input_indices_in,
-                      const Tensor &input_shape_in, const Tensor &target_shape_in,
-                      int output_indices_idx, int output_shape_idx) {
-  OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices_in.shape()),
-              errors::InvalidArgument(
-                  "Input indices should be a matrix but received shape ",
-                  input_indices_in.shape().DebugString()));
-  OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape_in.shape()),
-              errors::InvalidArgument(
-                  "Input shape should be a vector but received shape ",
-                  input_shape_in.shape().DebugString()));
-  OP_REQUIRES(context, TensorShapeUtils::IsVector(target_shape_in.shape()),
-              errors::InvalidArgument(
-                  "Target shape should be a vector but received shape ",
-                  target_shape_in.shape().DebugString()));
-
-  const int64 input_rank = input_shape_in.NumElements();
-  const int64 output_rank = target_shape_in.NumElements();
-  const TensorShape input_shape(input_shape_in.vec<int64>());
-  const int64 dense_size = input_shape.num_elements();
-  const int64 nnz = input_indices_in.shape().dim_size(0);
-
-  TensorShape output_shape;
-  int64 product = 1;
-  int unknown_index = -1;
-  auto target_shape = target_shape_in.vec<int64>();
-  for (int d = 0; d < output_rank; ++d) {
-    const int64 size = target_shape(d);
-    if (size == -1) {
-      OP_REQUIRES(
-          context, unknown_index == -1,
-          errors::InvalidArgument("only one output dimension may be -1, "
-                                  "not both ",
-                                  unknown_index, " and ", d));
-      unknown_index = d;
-      output_shape.AddDim(1);
-    } else {
-      OP_REQUIRES(context, size >= 0,
-                  errors::InvalidArgument("size ", d,
-                                          " must be non-negative, not ", size));
-      product *= size;
-      output_shape.AddDim(size);
-    }
-  }
-  if (unknown_index != -1) {
-    OP_REQUIRES(
-        context, product > 0,
-        errors::InvalidArgument("reshape cannot infer the missing "
-                                "input size for an empty tensor unless all "
-                                "specified input sizes are non-zero"));
-    const int64 missing = dense_size / product;
-    OP_REQUIRES(
-        context, product * missing == dense_size,
-        errors::InvalidArgument(
-            "Input to reshape is a SparseTensor with ", dense_size,
-            " dense values, but the requested shape requires a multiple of ",
-            product, ". input_shape=", input_shape.DebugString(),
-            " output_shape=", output_shape.DebugString()));
-    output_shape.set_dim(unknown_index, missing);
-  }
-
-  OP_REQUIRES(
-      context, output_shape.num_elements() == dense_size,
-      errors::InvalidArgument("Input to reshape is a tensor with ", dense_size,
-                              " dense values, but the requested shape has ",
-                              output_shape.num_elements(),
-                              ". input_shape=", input_shape.DebugString(),
-                              " output_shape=", output_shape.DebugString()));
-
-  if (input_shape == output_shape) {
-    context->set_output(output_indices_idx, input_indices_in);
-    context->set_output(output_shape_idx, input_shape_in);
-    return;
-  }
-
-  gtl::InlinedVector<int64, 8> input_strides(input_rank);
-  if (input_rank > 0) {
-    input_strides[input_rank - 1] = 1;
-    for (int d = input_rank - 2; d >= 0; --d) {
-      input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
-    }
-  }
-
-  gtl::InlinedVector<int64, 8> output_strides(output_rank);
-  if (output_rank > 0) {
-    output_strides[output_rank - 1] = 1;
-    for (int d = output_rank - 2; d >= 0; --d) {
-      output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
-    }
-  }
-
-  Tensor *result_indices = nullptr;
-  OP_REQUIRES_OK(context,
-                 context->allocate_output(output_indices_idx,
-                                          TensorShape({nnz, output_rank}),
-                                          &result_indices));
-  auto input_ind = input_indices_in.matrix<int64>();
-  auto output_ind = result_indices->matrix<int64>();
-  for (int i = 0; i < nnz; ++i) {
-    int64 id = 0;
-    for (int j = 0; j < input_rank; ++j) {
-      id += input_ind(i, j) * input_strides[j];
-    }
-    for (int j = 0; j < output_rank; ++j) {
-      output_ind(i, j) = id / output_strides[j];
-      id %= output_strides[j];
-    }
-  }
-
-  Tensor *result_shape = nullptr;
-  OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
-                                                   TensorShape({output_rank}),
-                                                   &result_shape));
-  auto output_shape_vec = result_shape->vec<int64>();
-  for (int j = 0; j < output_shape.dims(); ++j) {
-    output_shape_vec(j) = output_shape.dim_size(j);
-  }
-}

 class KPFusedSparseReshapeOp : public OpKernel {
- public:
+public:
  explicit KPFusedSparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {
  }

  void Compute(OpKernelContext* context) override {
@@ -155,39 +30,54 @@ class KPFusedSparseReshapeOp : public OpKernel {
     const Tensor& new_shape = context->input(2);

     OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2"));
-
-    VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString();
-    VLOG(1) << "Input begin value: " << begin.DebugString();
-    VLOG(1) << "Input new_shape value: " << new_shape.DebugString();
-
+
     int32 col = begin.flat<int32>().data()[1];
-    int64_t stridedslice57_out = slice_input.dim_size(0);
-    auto slice_input_mat = slice_input.matrix<int64>();
-
-    VLOG(1) << "stridedslice57_out: " << stridedslice57_out;
-    VLOG(1) << "slice_input.dim_size(0): " << slice_input.dim_size(0);
-    VLOG(1) << "slice_input.dim_size(1): " << slice_input.dim_size(1);
-    OP_REQUIRES(context, stridedslice57_out == slice_input.dim_size(0), errors::Internal("concat shape mismatch"));
-    VLOG(1) << "Column index from begin: " << col;
-    VLOG(1) << "indices size: " << stridedslice57_out;
-
-    Tensor shape_in(DT_INT64, TensorShape({2}));
-    auto tensor_flat = shape_in.flat<int64>();
-    tensor_flat(0) = stridedslice57_out;
-    tensor_flat(1) = 2;
-
-    Tensor indices_in(DT_INT64, TensorShape({stridedslice57_out, 2}));
-    auto indices_in_mat = indices_in.matrix<int64>();
-    for (int i = 0; i < stridedslice57_out; ++i) {
-      indices_in_mat(i, 0) = i;
-      indices_in_mat(i, 1) = slice_input_mat(i, col);
+    int64 nnz = slice_input.dim_size(0);
+    TensorShape output_shape;
+    int64 product = 2 * nnz;
+    auto target_shape = new_shape.vec<int64>();
+
+    OP_REQUIRES(context, !(target_shape(0) == -1 && target_shape(1) == -1),
+                errors::InvalidArgument("only one output dimension may be -1."));
+    OP_REQUIRES(context, (target_shape(0) > 0 || target_shape(0) == -1) && (target_shape(1) > 0 || target_shape(1) == -1),
+                errors::InvalidArgument("output dimensions must be positive or -1."));
+    OP_REQUIRES(context, product % target_shape(0) == 0 && product % target_shape(1) == 0,
+                errors::InvalidArgument("reshape cannot infer the missing dimension."));
+
+    output_shape.AddDim(target_shape(0) == -1 ? product / target_shape(1) : target_shape(0));
+    output_shape.AddDim(target_shape(1) == -1 ? product / target_shape(0) : target_shape(1));
+
+    if (output_shape.dim_size(0) == nnz && output_shape.dim_size(1) == 2) {
+      context->set_output(0, slice_input);
+      Tensor input_shape_in(DT_INT64, TensorShape({2}));
+      auto tensor_flat = input_shape_in.flat<int64>();
+      tensor_flat(0) = nnz;
+      tensor_flat(1) = 2;
+      context->set_output(1, input_shape_in);
+      return;
     }
-    Tensor new_shape_in(DT_INT64, TensorShape({2}));
-    auto newshape_tensor_flat = new_shape_in.flat<int64>();
-    newshape_tensor_flat(0) = new_shape.flat<int64>()(0);
-    newshape_tensor_flat(1) = new_shape.flat<int64>()(1);
-    ReshapeKp(context, indices_in, shape_in, new_shape_in, 0, 1);
+
+    Tensor *result_indices = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({nnz, 2}), &result_indices));
+    auto input_ind = slice_input.matrix<int64>();
+    auto output_ind = result_indices->matrix<int64>();
+
+    // Use the resolved output shape so that a -1 in new_shape can never reach
+    // the modulo/division below.
+    const int64 target_shape1 = output_shape.dim_size(1);
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    const int64 cost_per_unit = 50;
+
+    Shard(worker_threads->num_threads, worker_threads->workers, nnz, cost_per_unit,
+          [&](int64 start, int64 limit) {
+            for (int64 i = start; i < limit; ++i) {
+              int64 base_index = 2 * i + input_ind(i, col);
+              output_ind(i, 1) = base_index % target_shape1;
+              output_ind(i, 0) = base_index / target_shape1;
+            }
+          });
+
+    Tensor *result_shape = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({2}), &result_shape));
+    auto output_shape_vec = result_shape->vec<int64>();
+    for (int64 j = 0; j < output_shape.dims(); ++j) {
+      output_shape_vec(j) = output_shape.dim_size(j);
+    }
   }
 };
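The index math of the rewritten kernel, sketched in NumPy: row i contributes the linear coordinate 2*i + slice_input[i, col] of an implicit [nnz, 2] sparse tensor, which is then re-split against the target column count. Names are illustrative and, like the kernel, only 2-D target shapes are handled:

    import numpy as np

    def fused_sparse_reshape(slice_input, begin, new_shape):
        col = begin[1]
        nnz = slice_input.shape[0]
        total = 2 * nnz                                   # dense size of the [nnz, 2] input
        r = new_shape[0] if new_shape[0] != -1 else total // new_shape[1]
        c = new_shape[1] if new_shape[1] != -1 else total // new_shape[0]
        base = 2 * np.arange(nnz) + slice_input[:, col]   # linearized (i, id) coordinate
        return np.stack([base // c, base % c], axis=1), np.array([r, c], dtype=np.int64)

    indices, shape = fused_sparse_reshape(np.array([[0, 0], [0, 1], [1, 2], [3, 4]]), [0, 1], [2, 4])
    assert shape.tolist() == [2, 4] and indices.tolist() == [[0, 0], [0, 3], [1, 2], [2, 2]]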
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc
index 19cc7394c61efea481f6b1143de915506de36f7a..7472fbb9f11f9d07d529525b76f9702d881007e4 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_segment_reduce.cc
@@ -15,8 +15,6 @@ limitations under the License.

 #include

-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/util/work_sharder.h"
@@ -25,7 +23,7 @@
 using namespace tensorflow;

 template <typename Tidx>
 class KPFusedSparseSegmentReduceOp : public OpKernel {
- public:
+public:
  explicit KPFusedSparseSegmentReduceOp(OpKernelConstruction* context)
      : OpKernel(context) {
    int combiner_mode;
@@ -138,7 +136,7 @@ class KPFusedSparseSegmentReduceOp : public OpKernel {
     }
   }

- private:
+private:
  bool is_mean_;
 };

@@ -146,7 +144,7 @@
 REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSegmentReduce") \
                             .Device(DEVICE_CPU)            \
                             .TypeConstraint<Tidx>("Tidx"), \
-                        KPFusedSparseSegmentReduceOp<Tidx>);
+                        KPFusedSparseSegmentReduceOp<Tidx>)

 REGISTER_KERNEL(int64)
 REGISTER_KERNEL(int32)
-#undef REGISTER_KERNEL
+#undef REGISTER_KERNEL
\ No newline at end of file
diff --git a/tensorflow/core/kernels/embedding_fused_sparse_select.cc b/tensorflow/core/kernels/embedding_fused_sparse_select.cc
index 086092d511d8547030bc93882a71bf5dfc2c34c5..89a42d146001fac99c7b1f0e75fcfd6ee7fe5585 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_select.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_select.cc
@@ -16,22 +16,19 @@ limitations under the License.

 #include
 #include

-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/platform/logging.h"

 using namespace tensorflow;

 class KPFusedSparseSelect : public OpKernel {
- public:
+public:
  explicit KPFusedSparseSelect(OpKernelConstruction* context) : OpKernel(context) {
-
  }

  void Compute(OpKernelContext* context) override {
-
   const Tensor& input_a = context->input(0);
   const Tensor& input_b = context->input(1);
   const Tensor& input_c = context->input(2);
@@ -39,12 +36,10 @@ class KPFusedSparseSelect : public OpKernel {
    auto a_flat = input_a.flat();
    auto b_flat = input_b.flat();
    auto c_flat = input_c.flat();
-    VLOG(1) << "input_a shape: " << input_a.shape().DebugString();
-    VLOG(1) << "input_b shape: " << input_b.shape().DebugString();
-    VLOG(1) << "input_c shape: " << input_c.shape().DebugString();
-    OP_REQUIRES(context,input_a.NumElements() == input_b.NumElements(),
+
+    OP_REQUIRES(context, input_a.NumElements() == input_b.NumElements(),
                 errors::InvalidArgument("Input num elements must match"));
-    OP_REQUIRES(context,input_a.NumElements() == input_c.NumElements(),
+    OP_REQUIRES(context, input_a.NumElements() == input_c.NumElements(),
                 errors::InvalidArgument("Input num elements must match"));

    auto N = input_a.NumElements();
@@ -58,10 +53,10 @@
    auto b_equal_node0 = (b_reshaped_tensor == 4563);
    auto b_equal_node1 = (b_reshaped_tensor == 10831);

-   Eigen::Tensor<float, 2> tensor_ones(N, 1);
+    Eigen::Tensor<float, 2> tensor_ones(N, 1);
     tensor_ones.setConstant(1.0f);

-   Eigen::Tensor<float, 2> tensor_zeros(N, 1);
+    Eigen::Tensor<float, 2> tensor_zeros(N, 1);
     tensor_zeros.setConstant(0.0f);

    auto select_2412 = b_equal_node0.select(tensor_ones, a_greater_casted);
@@ -74,13 +69,9 @@
    Tensor* output_x = nullptr;
    Tensor* output_y = nullptr;
    Tensor* output_w = nullptr;
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(0,TensorShape({N, 1}), &output_x));
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(1,TensorShape({N, 1}), &output_y));
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(2,TensorShape({N, 2}), &output_w));
-
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({N, 1}), &output_x));
+    OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape({N, 1}), &output_y));
+    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({N, 2}), &output_w));

    Eigen::TensorMap<Eigen::Tensor<float, 2>> map_output_x(
        output_x->flat<float>().data(),
@@ -102,9 +93,7 @@
        output_w->dim_size(1)
    );
    map_output_w = concat_out;
-
  }
-
};

 REGISTER_KERNEL_BUILDER(Name("KPFusedSparseSelect").Device(DEVICE_CPU),
diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc
index 982a0f933558f45234e1282f2c98d4269234a172..24b820af1f341aff6523aa6c3c8db2a44d50f7fc 100644
--- a/tensorflow/core/profiler/lib/profiler_session.cc
+++ b/tensorflow/core/profiler/lib/profiler_session.cc
@@ -156,7 +156,7 @@ ProfilerSession::ProfilerSession(const profiler::ProfilerOptions& options)
     return;
   }

-  LOG(INFO) << "Profiler session started.";
+  // LOG(INFO) << "Profiler session started.";

 #if !defined(IS_MOBILE_PLATFORM)
   CreateProfilers(options, &profilers_);
diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_action_id_gather_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_action_id_gather_test.py
index d20628b0d301f9b4d430f3cf97f3ab8af11a5f48..1ae2c1f5494842a5c1b1f8606f5bf3ef0bcf4fca 100644
--- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_action_id_gather_test.py
+++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_action_id_gather_test.py
@@ -1,9 +1,13 @@
-import tensorflow as tf
-import numpy as np
+# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
import unittest +import numpy as np +import tensorflow as tf + from tensorflow.python.ops import gen_embedding_fused_ops -from utils.utils import perf_run, generate_timeline, wrapper_sess +from utils.utils import benchmark_op + +np.random.seed(140) def ori_fused_embedding_action_id_gather_graph(input0, input1, input2, input3): @@ -33,7 +37,7 @@ class TestFusedEmbeddingActionIdGather(unittest.TestCase): """Initialize config""" cls.config = tf.compat.v1.ConfigProto() cls.config.intra_op_parallelism_threads = 16 - cls.config.inter_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) cls.run_metadata_ori = tf.compat.v1.RunMetadata() @@ -68,23 +72,34 @@ class TestFusedEmbeddingActionIdGather(unittest.TestCase): # Create tf session with tf.compat.v1.Session(config=self.config) as sess: # functest - out_ori_val = sess.run([out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori) - out_opt_val = sess.run([out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt) + out_ori_val = sess.run( + [out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori + ) + out_opt_val = sess.run( + [out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt + ) np.testing.assert_array_equal( out_ori_val, out_opt_val, err_msg="result mismatch" ) + + benchmark_op( + sess, + feed, + [out_ori], + [out_opt], + self.run_options, + self.run_metadata_ori, + self.run_metadata_opt, + op_name="KPFusedEmbeddingActionIdGather", + start_op="ori/stack_1", + end_op="ori/concat", + num_runs=10000, + tag="----------TF_origin-----------" + ) - generate_timeline(self.run_metadata_ori.step_stats, f"{self._testMethodName}_ori") - generate_timeline(self.run_metadata_opt.step_stats, f"{self._testMethodName}_opt") - - # perftest - perf_run(wrapper_sess(sess, [out_ori], feed_dict=feed), - wrapper_sess(sess, [out_opt], feed_dict=feed), - "KPFusedEmbeddingActionIdGather") - if __name__ == "__main__": tf.compat.v1.disable_eager_execution() diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py index 1c73adc187830c9831347e29ddbfa79b88b3449d..4e1755ef93d32e28820e5806d9b57f84ccb01826 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py @@ -1,88 +1,140 @@ +# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. 
+import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops +from utils.utils import benchmark_op + + +def ori_fused_embedding_gather_graph(data, slice_input, begin): + slice_out = tf.strided_slice( + slice_input, + begin=begin, + end=[tf.shape(slice_input)[0], begin[1] + 2], + strides=[1, 1], + begin_mask=1, + end_mask=1, + shrink_axis_mask=2 + ) + + slice_out, slice_out_indices = tf.unique(slice_out) + output_shape = tf.shape(slice_out) + slice_out = tf.reshape(slice_out, [-1]) + slice_out, _ = tf.unique(slice_out) + + gather1_result = tf.gather(data, slice_out) + gather1_result = tf.reshape(gather1_result, [-1, 12]) + + gather2_result = tf.gather(gather1_result, slice_out) + return output_shape, slice_out_indices, gather2_result + + +def opt_fused_embedding_gather_graph(data, slice_input, begin): + custom_out1, custom_out2, custom_out3 = gen_embedding_fused_ops.KPFusedGather( + data=data, + slice_input=slice_input, + begin=begin + ) + return custom_out1, custom_out2, custom_out3 + class TestFusedGather(unittest.TestCase): @classmethod def setUpClass(cls): - """Initialize test data and custom op""" - # Load custom op - cls.custom_op = gen_embedding_fused_ops - - # Base test data - cls.base_data = np.linspace(0, 11, num=240, endpoint=False, dtype=np.float32).reshape(20, 12) - cls.base_slice_input = np.array([[0, 0], [0, 1], [1, 2]], dtype=np.int64) - cls.base_begin = [0, 1] - cls.base_end = [0, 2] - cls.base_strides = [1, 1] - # Create tf session - cls.sess = tf.compat.v1.Session() + """Initialize""" + cls.config = tf.compat.v1.ConfigProto() + cls.config.intra_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 + + cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) + cls.run_metadata_ori = tf.compat.v1.RunMetadata() + cls.run_metadata_opt = tf.compat.v1.RunMetadata() @classmethod def tearDownClass(cls): - cls.sess.close() - - def test_custom(self): - # execute custom op - custom_out1, custom_out2, custom_out3= self.custom_op.KPFusedGather( - data=self.base_data, - slice_input=self.base_slice_input, - begin=self.base_begin, - ) - - # tf native implementation - tf_out1, tf_out2, tf_out3 = self._tf_reference_impl( - self.base_data, - self.base_slice_input, - self.base_begin, - ) - - custom_out_val1, custom_out_val2, custom_out_val3 = self.sess.run([custom_out1, custom_out2, custom_out3]) - tf_out_val1, tf_out_val2, tf_out_val3 = self.sess.run([tf_out1, tf_out2, tf_out3]) - - np.testing.assert_array_equal( - custom_out_val1, - tf_out_val1, - err_msg="Segment count mismatch" - ) - - np.testing.assert_array_equal( - custom_out_val2, - tf_out_val2, - err_msg="Segment count mismatch" - ) - - np.testing.assert_allclose( - custom_out_val3, - tf_out_val3, - rtol=1e-6, - err_msg="Output values mismatch" - ) - - def _tf_reference_impl(self, data, slice_input, begin): - slice_out = tf.strided_slice( - slice_input, - begin = begin, - end = [tf.shape(slice_input)[0], begin[1] + 2], - strides = [1, 1], - begin_mask = 1, - end_mask = 1, - shrink_axis_mask = 2 - ) - - slice_out, slice_out_indices = tf.unique(slice_out) - output_shape = slice_out - slice_out = tf.reshape(slice_out, [-1]) - slice_out, _ = tf.unique(slice_out) - - gather1_result = tf.gather(data, slice_out) - gather1_result = tf.reshape(gather1_result, [-1, 12]) - - gather2_result = tf.gather(gather1_result, slice_out) - return output_shape, slice_out_indices, gather2_result + return + + def 
test_kp_embedding_gather(self): + with tf.Graph().as_default(): + data = tf.compat.v1.placeholder(tf.float32, shape=(20, 12), name="data") + slice_input = tf.compat.v1.placeholder(tf.int64, shape=(3, 2), name="slice_input") + begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin") + base_data = np.linspace(0, 11, num=240, endpoint=False, dtype=np.float32).reshape(20, 12) + base_slice_input = np.array([[0, 0], [0, 1], [1, 2]], dtype=np.int64) + base_begin = [0, 1] + feed = { + data: base_data, + slice_input: base_slice_input, + begin: base_begin + } + # original graph + with tf.name_scope("ori"): + out_ori1, out_ori2, out_ori3 = ori_fused_embedding_gather_graph( + data, + slice_input, + begin + ) + + # optimized graph + with tf.name_scope("opt"): + out_opt1, out_opt2, out_opt3 = opt_fused_embedding_gather_graph( + data=data, + slice_input=slice_input, + begin=begin + ) + + with tf.compat.v1.Session(config=self.config) as sess: + # run ori + out_ori_val1, out_ori_val2, out_ori_val3 = sess.run( + [out_ori1, out_ori2, out_ori3], + feed_dict=feed, + options=self.run_options, + run_metadata=self.run_metadata_ori + ) + # run opt + out_opt_val1, out_opt_val2, out_opt_val3 = sess.run( + [out_opt1, out_opt2, out_opt3], + feed_dict=feed, + options=self.run_options, + run_metadata=self.run_metadata_opt + ) + # 功能测试 + np.testing.assert_array_equal( + out_ori_val1, + out_opt_val1, + err_msg="Segment count mismatch" + ) + + np.testing.assert_array_equal( + out_ori_val2, + out_opt_val2, + err_msg="Segment count mismatch" + ) + np.testing.assert_allclose( + out_opt_val3, + out_ori_val3, + rtol=1e-6, + err_msg="Output values mismatch" + ) + + benchmark_op( + sess, + feed, + [out_ori1, out_ori2, out_ori3], + [out_opt1, out_opt2], + self.run_options, + self.run_metadata_ori, + self.run_metadata_opt, + op_name="KPFusedGather", + start_op="ori/strided_slice_1", + end_op="ori/GatherV2_1", + num_runs=10000, + tag="--TF_origin--" + ) + if __name__ == "__main__": tf.compat.v1.disable_eager_execution() - unittest.main(argv=[''], verbosity=2) \ No newline at end of file + unittest.main(argv=[''], verbosity=2) diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_padding_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_padding_test.py index b9950e51c61f0b818a038ced840c24e7347e043a..0b9437928c88fad52d7e5d8fed2d1219c6e77d82 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_padding_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_padding_test.py @@ -1,23 +1,27 @@ +# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. 
+import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops -from utils.utils import perf_run, generate_timeline, wrapper_sess +from utils.utils import benchmark_op np.random.seed(140) def opt_fused_embedding_padding_fast_graph(input0, input1, input2, input3): - # execute custom op - _, custom_out = gen_embedding_fused_ops.kp_fused_embedding_padding_fast(input0, input1, input2, input3) - return custom_out - + # execute custom op + _, custom_out = gen_embedding_fused_ops.kp_fused_embedding_padding_fast(input0, input1, input2, input3) + return custom_out + + def opt_fused_embedding_padding_graph(input0, input1, input2, input3): # execute custom op _, custom_out = gen_embedding_fused_ops.kp_fused_embedding_padding(input0, input1, input2, input3) return custom_out + def ori_fused_embedding_padding_fast_graph(input0, input1, input2, input3): cast = tf.cast(input0, tf.int32) begin = tf.constant([0], dtype=tf.int32) @@ -34,6 +38,7 @@ def ori_fused_embedding_padding_fast_graph(input0, input1, input2, input3): output = tf.strided_slice(shape_tensor, begin=begin, end=end, strides=strides, shrink_axis_mask=1) return output + def ori_fused_embedding_padding_graph(input0, input1, input2, input3): cast = tf.cast(input0, tf.int32) begin = tf.constant([0], dtype=tf.int32) @@ -55,7 +60,7 @@ class TestFusedEmbeddingPadding(unittest.TestCase): """Initialize config""" cls.config = tf.compat.v1.ConfigProto() cls.config.intra_op_parallelism_threads = 16 - cls.config.inter_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) cls.run_metadata_ori = tf.compat.v1.RunMetadata() @@ -87,28 +92,34 @@ class TestFusedEmbeddingPadding(unittest.TestCase): # Create tf session with tf.compat.v1.Session(config=self.config) as sess: # functest - ori_result = sess.run([out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori) - opt_result = sess.run([out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt) + ori_result = sess.run( + [out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori + ) + opt_result = sess.run( + [out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt + ) np.testing.assert_array_equal( ori_result, opt_result, err_msg="result mismatch" ) - - from tensorflow.python.client import timeline - tl_ori = timeline.Timeline(self.run_metadata_ori.step_stats) - tl_opt = timeline.Timeline(self.run_metadata_opt.step_stats) - ctf_ori = tl_ori.generate_chrome_trace_format() - ctf_opt = tl_opt.generate_chrome_trace_format() - - with open("timeline_ori.json", "w") as f: - f.write(ctf_ori) - with open("timeline_opt.json", "w") as f: - f.write(ctf_opt) - # perftest - perf_run(wrapper_sess(sess, [out_ori], feed), wrapper_sess(sess, [out_opt], feed_dict=feed), "KPFusedEmbeddingPadding") + benchmark_op( + sess, + feed, + [out_ori], + [out_opt], + self.run_options, + self.run_metadata_ori, + self.run_metadata_opt, + op_name="KPFusedEmbeddingPadding", + start_op="ori/Cast", + end_op="ori/Reshape", + num_runs=1000, + tag="-------TF_origin-------" + ) + def test_func_kp_fused_embedding_padding_fast(self): # Create Graph @@ -132,20 +143,34 @@ class TestFusedEmbeddingPadding(unittest.TestCase): # Create tf session with tf.compat.v1.Session(config=self.config) as sess: # functest - ori_result = sess.run([out_ori], 
feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori) - opt_result = sess.run([out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt) + ori_result = sess.run( + [out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori + ) + opt_result = sess.run( + [out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt + ) np.testing.assert_array_equal( ori_result, opt_result, err_msg="result mismatch" ) + + benchmark_op( + sess, + feed, + [out_ori], + [out_opt], + self.run_options, + self.run_metadata_ori, + self.run_metadata_opt, + op_name="KPFusedEmbeddingPaddingFast", + start_op="ori/Cast", + end_op="ori/StridedSlice_1", + num_runs=1000, + tag="---------TF_origin---------" + ) - generate_timeline(self.run_metadata_ori.step_stats, f"{self._testMethodName}_ori") - generate_timeline(self.run_metadata_opt.step_stats, f"{self._testMethodName}_opt") - - # perftest - perf_run(wrapper_sess(sess, [out_ori], feed), wrapper_sess(sess, [out_opt], feed_dict=feed), "KPFusedEmbeddingPaddingFast") if __name__ == "__main__": tf.compat.v1.disable_eager_execution() diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_dynamic_stitch_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_dynamic_stitch_test.py index 4de55241c17cda33e2d8380b9c7cec9d74fd7083..ab471db894ebe419bd9bfeaaae1116c02092d30c 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_dynamic_stitch_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_dynamic_stitch_test.py @@ -1,31 +1,47 @@ -import os +# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. +import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops +from utils.utils import benchmark_op + +np.random.seed(140) + -class TestSparseSegmentMeanSlice(unittest.TestCase): +def ori_fused_sparse_dynamic_stitch_graph(x, emb_tables): + x_1 = tf.reshape(x, shape=[-1]) # 将输入 x 展平成一维向量 x_1 + group_ids = tf.math.floormod(x_1, 12) + group_ids = tf.cast(group_ids, dtype=np.int32) + chunk_indices = tf.math.floordiv(x_1, 12) + original_indices = tf.range(0, tf.size(x_1), 1) + a = tf.dynamic_partition(original_indices, group_ids, num_partitions=12) + b = tf.dynamic_partition(chunk_indices, group_ids, num_partitions=12) + c = [tf.gather(emb_tables[i], b[i]) for i in range(12)] + d = tf.dynamic_stitch(a, c) + return d + + +def opt_fused_sparse_dynamic_stitch_graph(x, emb_tables): + output = gen_embedding_fused_ops.KPFusedSparseDynamicStitch( + x = x, + variables = emb_tables + ) + return output + + +class TestSparseDynamicStitch(unittest.TestCase): @classmethod def setUpClass(cls): - """Initialize test data and custom op""" - # Load custom op - cls.custom_op = gen_embedding_fused_ops + """Initialize config""" + cls.config = tf.compat.v1.ConfigProto() + cls.config.intra_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 - cls.variables = [] - max_val = float('inf') - for i in range(12): - N_i = np.random.randint(1000000, 44739244) - max_val = min(N_i, max_val) - var = tf.Variable( - tf.random.normal([N_i, 10], dtype=tf.float32), # shape: (N_i, 10) - name=f"embedding_table_{i}" - ) - cls.variables.append(var) - print(f"Created variable {i}: shape={var.shape}") - - x_np = np.random.randint(0, 12*max_val, size=(10000, 12)) - cls.x = tf.constant(x_np, 
dtype=tf.int64)
+
+        cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
+        cls.run_metadata_ori = tf.compat.v1.RunMetadata()
+        cls.run_metadata_opt = tf.compat.v1.RunMetadata()

         # Create tf session
         cls.sess = tf.compat.v1.Session()
@@ -36,11 +52,29 @@
         cls.sess.close()

     def test_base(self):
-        x_first = self.sess.run(self.x)
-        var_first = self.sess.run(self.variables[0])
+        variables = []
+        max_val = float('inf')
+        for i in range(12):
+            N_i = np.random.randint(100000, 4473924)
+            max_val = min(N_i, max_val)
+            var = tf.Variable(
+                tf.random.normal([N_i, 10], dtype=tf.float32),  # shape: (N_i, 10)
+                name=f"embedding_{i}"
+            )
+            variables.append(var)
+            # print(f"Created variable {i}: shape={var.shape}")
+
+        x_np = np.random.randint(0, 12*max_val, size=(10000, 12))
+        x = tf.constant(x_np, dtype=tf.int64)

-        x_second = self.sess.run(self.x)
-        var_second = self.sess.run(self.variables[0])
+        self.sess.run(tf.compat.v1.variables_initializer(variables))
+
+        x_first = self.sess.run(x)
+        var_first = self.sess.run(variables[0])
+
+        x_second = self.sess.run(x)
+        var_second = self.sess.run(variables[0])
+
         np.testing.assert_allclose(
             x_first,
             x_second,
@@ -55,42 +89,72 @@
             err_msg="Input values mismatch"
         )

-        # execute custom op
-        custom_out = self.custom_op.KPFusedSparseDynamicStitch(x=self.x, variables=self.variables)
-
-        # tf native implementation
-        tf_out = self._tf_reference_impl(x=self.x, variables=self.variables)
-
-        custom_out_val = self.sess.run([custom_out])
-        tf_out_val = self.sess.run([tf_out])
-        print("custom_shape: ")
-        print(custom_out_val[0].shape)
-        print("tf_out shape: ")
-        print(tf_out_val[0].shape)
-        # Numerical comparison
-        np.testing.assert_allclose(
-            custom_out_val[0],
-            tf_out_val[0],
-            rtol=1e-6,
-            err_msg="Output values mismatch"
-        )
-
-    def _tf_reference_impl(self, x, variables):
-        x_1 = tf.reshape(x, shape=[-1])
-        group_ids = tf.math.floormod(x_1, 12)
-        group_ids = tf.cast(group_ids, dtype=np.int32)
-        chunk_indices = tf.math.floordiv(x_1, 12)
-
-        original_indices = tf.range(0,tf.size(x_1),1)
-
-        a = tf.dynamic_partition(original_indices, group_ids, num_partitions=12)
-        b = tf.dynamic_partition(chunk_indices, group_ids, num_partitions=12)
-
-        c = [tf.gather(variables[i], b[i]) for i in range(12)]
-
-        d = tf.dynamic_stitch(a, c)
+    def test_kp_sparse_dynamic_stitch(self):
+        # Create Graph
+        with tf.Graph().as_default():
+            num_tables = 12
+            emb_dim = 10
+            max_val = float('inf')
+            # One placeholder per table; the row counts are generated randomly
+            tables = []
+            table_sizes = []
+            for i in range(num_tables):
+                N_i = np.random.randint(1000000, 44739244)
+                table_sizes.append(N_i)
+                max_val = min(N_i, max_val)
+                table_ph = tf.compat.v1.placeholder(
+                    tf.float32, shape=(N_i, emb_dim), name=f"embedding_table_{i}"
+                )
+                tables.append(table_ph)
+            # Placeholder for the global indices
+            x_shape = (1000, num_tables)
+            input_x = tf.compat.v1.placeholder(tf.int64, shape=x_shape, name="input_x")
+            # Build the feed data
+            feed = {}
+            rng = np.random.default_rng(12345)
+            # Generate random embedding data for each table
+            for i in range(num_tables):
+                feed[tables[i]] = rng.standard_normal(size=(table_sizes[i], emb_dim)).astype(np.float32)
+            # Generate the index data (same logic as before: range 0 ~ num_tables * max_val - 1)
+            feed[input_x] = rng.integers(
+                low=0, high=num_tables * max_val, size=x_shape, dtype=np.int64
+            )
+            with tf.name_scope("ori"):
+                out_ori = ori_fused_sparse_dynamic_stitch_graph(input_x, tables)
+            with tf.name_scope("opt"):
+                out_opt = 
opt_fused_sparse_dynamic_stitch_graph(input_x, tables) + + # Create tf session + with tf.compat.v1.Session(config=self.config) as sess: + # functest + out_ori_val = sess.run( + [out_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori + ) + out_opt_val = sess.run( + [out_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt + ) + + np.testing.assert_array_equal( + out_ori_val, + out_opt_val, + err_msg="result mismatch" + ) + + benchmark_op( + sess, + feed, + [out_ori], + [out_opt], + self.run_options, + self.run_metadata_ori, + self.run_metadata_opt, + op_name="KPFusedSparseDynamicStitch", + start_op="ori/Reshape", + end_op="ori/DynamicStitch", + num_runs=100, + tag="--------TF_origin---------" + ) - return d if __name__ == "__main__": tf.compat.v1.disable_eager_execution() diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py index 37d275315eafdaba1e745e887ba5be684015adf5..8236393d5901115b0247476627bc32d7b57839a6 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py @@ -1,66 +1,15 @@ +# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. +import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops +from utils.utils import benchmark_op -class TestFusedSparseReshape(unittest.TestCase): - @classmethod - def setUpClass(cls): - """Initialize test data and custom op""" - # Load custom op - cls.custom_op = gen_embedding_fused_ops - - # Base test data - cls.base_slice_input = np.array([[0, 0], [0, 1], [1, 2], [3, 4]], dtype=np.int64) - cls.base_begin = [0, 1] - cls.base_end = [0, 2] - cls.base_strides = [1, 1] - cls.base_newshape = [2, 4] - # Create tf session - cls.sess = tf.compat.v1.Session() - - @classmethod - def tearDownClass(cls): - cls.sess.close() - - def test_custom(self): - # execute custom op - custom_out1, custom_out2, = self.custom_op.KPFusedSparseReshape( - slice_input=self.base_slice_input, - begin=self.base_begin, - new_shape=self.base_newshape - ) - - # tf native implementation - tf_out1, tf_out2, tf_out3 = self._tf_reference_impl( - self.base_slice_input, - self.base_begin, - self.base_newshape - ) - - custom_out_val1, custom_out_val2 = self.sess.run([custom_out1, custom_out2]) - tf_out_val1, tf_out_val2, tf_out_val3 = self.sess.run([tf_out1, tf_out2, tf_out3]) - - print("custom_out_val1: ", custom_out_val1) - print("custom_out_val2: ", custom_out_val2) - print("tf_out_val1: ", tf_out_val1) - print("tf_out_val2: ", tf_out_val2) - - np.testing.assert_array_equal( - custom_out_val1, - tf_out_val1, - err_msg="Segment count mismatch" - ) - - np.testing.assert_array_equal( - custom_out_val2, - tf_out_val2, - err_msg="Segment count mismatch" - ) - def _tf_reference_impl(self, slice_input, begin, new_shape): - slice67_out = tf.strided_slice( +def ori_fused_embedding_sparse_reshape_graph(slice_input, begin, newshape): + slice67_out = tf.strided_slice( slice_input, begin=begin, end=[0, 2], @@ -70,32 +19,120 @@ class TestFusedSparseReshape(unittest.TestCase): shrink_axis_mask=2 ) - slice67_out = tf.reshape(slice67_out, [-1, 1]) - shape_out = tf.shape(slice67_out) - slice57_out = tf.strided_slice( - shape_out, - begin=[0], - end=[1], - strides=[1], - 
shrink_axis_mask=1 - ) - - const2 = tf.constant(2) - input_shape = tf.stack([slice57_out, const2]) - input_shape = tf.cast(input_shape, tf.int64) + slice67_out = tf.reshape(slice67_out, [-1, 1]) + shape_out = tf.shape(slice67_out) + slice57_out = tf.strided_slice( + shape_out, + begin=[0], + end=[1], + strides=[1], + shrink_axis_mask=1 + ) + + const2 = tf.constant(2) + input_shape = tf.stack([slice57_out, const2]) + input_shape = tf.cast(input_shape, tf.int64) + + range_out = tf.range(0, slice57_out, 1) + range_out = tf.reshape(range_out, [-1, 1]) + range_out_64 = tf.cast(range_out, dtype=tf.int64) + concat_out = tf.concat([range_out_64, slice67_out], axis=-1) + + sparse_tensor = tf.SparseTensor( + indices=concat_out, + values=[1,2,3,4], + dense_shape=input_shape + ) + sparse_tensor_out = tf.sparse.reshape(sparse_tensor, newshape) + return sparse_tensor_out.indices, sparse_tensor_out.dense_shape, concat_out + + +def opt_fused_sparse_reshape_graph(slice_input, begin, newshape): + custom_out1, custom_out2 = gen_embedding_fused_ops.KPFusedSparseReshape( + slice_input=slice_input, + begin=begin, + new_shape=newshape + ) + return custom_out1, custom_out2 + + +class TestFusedSparseReshape(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Initialize""" + cls.config = tf.compat.v1.ConfigProto() + cls.config.intra_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 + + cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) + cls.run_metadata_ori = tf.compat.v1.RunMetadata() + cls.run_metadata_opt = tf.compat.v1.RunMetadata() + + @classmethod + def tearDownClass(cls): + # cls.sess.close() + return + + def test_kp_sparse_reshape(self): + with tf.Graph().as_default(): + slice_input = tf.compat.v1.placeholder(tf.int64, shape=(4,2), name="slice_input") + begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin") + newshape = tf.compat.v1.placeholder(tf.int64, shape=(2,), name="newshape") + base_slice_input = np.array([[0, 0], [0, 1], [1, 2], [3, 4]], dtype=np.int64) + base_begin = [0, 1] + base_newshape = [2, 4] + feed = { + slice_input: base_slice_input, + begin: base_begin, + newshape: base_newshape + } + + with tf.name_scope("ori"): + out_ori1, out_ori2, out_ori3 = ori_fused_embedding_sparse_reshape_graph(slice_input, begin, newshape) + with tf.name_scope("opt"): + out_opt1, out_opt2 = opt_fused_sparse_reshape_graph(slice_input, begin, newshape) + + with tf.compat.v1.Session(config=self.config) as sess: + out_ori_val1, out_ori_val2, out_ori_val3 = sess.run( + [out_ori1, out_ori2, out_ori3], + feed_dict=feed, + options=self.run_options, + run_metadata=self.run_metadata_ori + ) + out_opt_val1, out_opt_val2 = sess.run( + [out_opt1,out_opt2], + feed_dict=feed, + options=self.run_options, + run_metadata=self.run_metadata_opt + ) + + # 功能测试 + np.testing.assert_array_equal( + out_opt_val1, + out_ori_val1, + err_msg="Segment count mismatch" + ) + np.testing.assert_array_equal( + out_opt_val2, + out_ori_val2, + err_msg="Segment count mismatch" + ) + + benchmark_op( + sess, + feed, + [out_ori1, out_ori2, out_ori3], + [out_opt1, out_opt2], + self.run_options, + self.run_metadata_ori, + self.run_metadata_opt, + op_name="KPFusedSparseReshape", + start_op="ori/StridedSlice", + end_op="ori/SparseReshape", + num_runs=10000, + tag="------TF_origin-----" + ) - range_out = tf.range(0, slice57_out, 1) - range_out = tf.reshape(range_out, [-1, 1]) - range_out_64 = tf.cast(range_out, dtype=tf.int64) - concat_out = tf.concat([range_out_64, 
slice67_out], axis=-1) - - sparse_tensor = tf.SparseTensor( - indices=concat_out, - values=[1,2,3,4], - dense_shape=input_shape - ) - sparse_tensor_out = tf.sparse.reshape(sparse_tensor, new_shape) - return sparse_tensor_out.indices, sparse_tensor_out.dense_shape, concat_out if __name__ == "__main__": tf.compat.v1.disable_eager_execution() diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_segment_reduce_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_segment_reduce_test.py index 69c7a114f40062582af03df5c29c644014a147f9..5536eb1cd024eef7556433451f5cf71d883b6c12 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_segment_reduce_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_segment_reduce_test.py @@ -1,133 +1,238 @@ +# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved. +import unittest + import tensorflow as tf import numpy as np -import unittest from tensorflow.python.ops import gen_embedding_fused_ops +from utils.utils import benchmark_op + + +def ori_fused_embedding_sparse_segment_reduce_graph(data, indices, slice_input, begin, end, strides, is_mean): + slice_out = tf.strided_slice( + slice_input, + begin=begin, + end=end, + strides=strides, + begin_mask=1, + end_mask=1, + shrink_axis_mask=2 + ) + + segment_ids = tf.cast(slice_out, dtype=tf.int32) + if is_mean: + output = tf.sparse.segment_mean( + data=data, + indices=indices, + segment_ids=segment_ids + ) + else: + output = tf.sparse.segment_sum( + data=data, + indices=indices, + segment_ids=segment_ids + ) + + output_shape = tf.shape(output) + slice_out = tf.strided_slice(output_shape, begin=[0], end=[1], strides=[1]) + + return output, slice_out + + +def opt_fused_embedding_sparse_segment_reduce_graph(data, indices, slice_input, begin, end, strides, is_mean): + if is_mean: + custom_out, custom_slice_out = gen_embedding_fused_ops.KPFusedSparseSegmentReduce( + data=data, + indices=indices, + slice_input=slice_input, + begin=begin, + end=end, + strides=strides + ) + else: + custom_out, custom_slice_out = gen_embedding_fused_ops.KPFusedSparseSegmentReduce( + data=data, + indices=indices, + slice_input=slice_input, + begin=begin, + end = end, + strides=strides, + combiner=0 + ) + return custom_out, custom_slice_out + class TestSparseSegmentMeanSlice(unittest.TestCase): @classmethod def setUpClass(cls): - """Initialize test data and custom op""" - # Load custom op - cls.custom_op = gen_embedding_fused_ops - - # Base test data - cls.base_data = np.array([[1.0, 2.0, 3.0], [3.0, 4.0,5.0], [5.0, 6.0, 7.0], [5.0, 6.0, 7.0]], dtype=np.float32) # shape {4, 3} - cls.base_indices = np.array([0, 1, 2], dtype=np.int64) # shape {3} - cls.base_slice_input = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64) # shape {3, 2} - cls.base_begin = [0, 1] - cls.base_end = [0, 2] - cls.base_strides = [1, 2] - # Create tf session - cls.sess = tf.compat.v1.Session() + """Initialize""" + cls.config = tf.compat.v1.ConfigProto() + cls.config.intra_op_parallelism_threads = 16 + cls.config.inter_op_parallelism_threads = 1 + + cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE) + cls.run_metadata_ori = tf.compat.v1.RunMetadata() + cls.run_metadata_opt = tf.compat.v1.RunMetadata() @classmethod def tearDownClass(cls): - cls.sess.close() + return def test_mean(self): - # execute custom op - custom_out, custom_slice_out = self.custom_op.KPFusedSparseSegmentReduce( - 
-            data=self.base_data,
-            indices=self.base_indices,
-            slice_input=self.base_slice_input,
-            begin=self.base_begin,
-            end = self.base_end,
-            strides = self.base_strides
-        )
+        with tf.Graph().as_default():
+            data = tf.compat.v1.placeholder(tf.float32, shape=(4, 3), name="data")
+            indices = tf.compat.v1.placeholder(tf.int32, shape=(3,), name="indices")
+            slice_input = tf.compat.v1.placeholder(tf.int64, shape=(3, 2), name="slice_input")
+            begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin")
+            end = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="end")
+            strides = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="strides")
+
+            base_data = np.array(
+                [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [5.0, 6.0, 7.0]],
+                dtype=np.float32
+            )  # shape {4, 3}
+            base_indices = np.array([0, 1, 2], dtype=np.int32)  # shape {3}; int32 to match the placeholder
+            base_slice_input = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64)  # shape {3, 2}
+            base_begin = [0, 1]
+            base_end = [0, 2]
+            base_strides = [1, 2]
+
+            feed = {
+                data: base_data,
+                indices: base_indices,
+                slice_input: base_slice_input,
+                begin: base_begin,
+                end: base_end,
+                strides: base_strides
+            }
+
+            with tf.name_scope("ori"):
+                out_ori1, out_ori2 = ori_fused_embedding_sparse_segment_reduce_graph(
+                    data, indices, slice_input, begin, end, strides, True
+                )
+            with tf.name_scope("opt"):
+                out_opt1, out_opt2 = opt_fused_embedding_sparse_segment_reduce_graph(
+                    data, indices, slice_input, begin, end, strides, True
+                )
+
+            with tf.compat.v1.Session(config=self.config) as sess:
+                out_ori_val1, out_ori_val2 = sess.run(
+                    [out_ori1, out_ori2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_ori
+                )
+                out_opt_val1, out_opt_val2 = sess.run(
+                    [out_opt1, out_opt2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_opt
+                )
-        # tf native implementation
-        tf_out, tf_slice_out = self._tf_reference_impl(
-            self.base_data,
-            self.base_indices,
-            self.base_slice_input,
-            self.base_begin,
-            self.base_end,
-            self.base_strides,
-            True
-        )
+                np.testing.assert_allclose(
+                    out_opt_val1,
+                    out_ori_val1,
+                    rtol=1e-6,
+                    err_msg="Output values mismatch"
+                )
+                np.testing.assert_array_equal(
+                    out_opt_val2,
+                    out_ori_val2,
+                    err_msg="Segment count mismatch"
+                )
+
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori1, out_ori2],
+                    [out_opt1, out_opt2],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedSparseSegmentReduce",
+                    start_op="ori/StridedSlice",
+                    end_op="ori/StridedSlice_1",
+                    num_runs=500,
+                    tag="--------TF_origin---------"
+                )
-        custom_out_val, custom_slice_out_val = self.sess.run([custom_out, custom_slice_out])
-        tf_out_val, tf_slice_out_val = self.sess.run([tf_out, tf_slice_out])
-
-        np.testing.assert_allclose(
-            custom_out_val,
-            tf_out_val,
-            rtol=1e-6,
-            err_msg="Output values mismatch"
-        )
-        np.testing.assert_array_equal(
-            custom_slice_out_val,
-            tf_slice_out_val,
-            err_msg="Segment count mismatch"
-        )
 
     def test_sum(self):
-        custom_out, custom_slice_out = self.custom_op.KPFusedSparseSegmentReduce(
-            data=self.base_data,
-            indices=self.base_indices,
-            slice_input=self.base_slice_input,
-            begin=self.base_begin,
-            end = self.base_end,
-            strides = self.base_strides,
-            combiner=0
-        )
-
-        tf_out, tf_slice_out = self._tf_reference_impl(
-            self.base_data,
-            self.base_indices,
-            self.base_slice_input,
-            self.base_begin,
-            self.base_end,
-            self.base_strides,
-            False
-        )
-
-        custom_out_val, custom_slice_out_val = self.sess.run([custom_out, custom_slice_out])
-        tf_out_val, tf_slice_out_val = self.sess.run([tf_out, tf_slice_out])
-
-        np.testing.assert_allclose(
-            custom_out_val,
-            tf_out_val,
-            rtol=1e-6,
-            err_msg="Output values mismatch"
-        )
-        np.testing.assert_array_equal(
-            custom_slice_out_val,
-            tf_slice_out_val,
-            err_msg="Segment count mismatch"
-        )
+        with tf.Graph().as_default():
+            data = tf.compat.v1.placeholder(tf.float32, shape=(4, 3), name="data")
+            indices = tf.compat.v1.placeholder(tf.int32, shape=(3,), name="indices")
+            slice_input = tf.compat.v1.placeholder(tf.int64, shape=(3, 2), name="slice_input")
+            begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin")
+            end = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="end")
+            strides = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="strides")
+
+            base_data = np.array(
+                [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0], [5.0, 6.0, 7.0]],
+                dtype=np.float32
+            )  # shape {4, 3}
+            base_indices = np.array([0, 1, 2], dtype=np.int32)
+            base_slice_input = np.array([[0, 0], [0, 2], [1, 2]], dtype=np.int64)
+            base_begin = [0, 1]
+            base_end = [0, 2]
+            base_strides = [1, 2]
+
+            feed = {
+                data: base_data,
+                indices: base_indices,
+                slice_input: base_slice_input,
+                begin: base_begin,
+                end: base_end,
+                strides: base_strides
+            }
+            with tf.name_scope("ori"):
+                out_ori1, out_ori2 = ori_fused_embedding_sparse_segment_reduce_graph(
+                    data, indices, slice_input, begin, end, strides, False
+                )
+            with tf.name_scope("opt"):
+                out_opt1, out_opt2 = opt_fused_embedding_sparse_segment_reduce_graph(
+                    data, indices, slice_input, begin, end, strides, False
+                )
+
+            with tf.compat.v1.Session(config=self.config) as sess:
+                out_ori_val1, out_ori_val2 = sess.run(
+                    [out_ori1, out_ori2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_ori
+                )
+                out_opt_val1, out_opt_val2 = sess.run(
+                    [out_opt1, out_opt2],
+                    feed_dict=feed,
+                    options=self.run_options,
+                    run_metadata=self.run_metadata_opt
+                )
+                np.testing.assert_allclose(
+                    out_opt_val1,
+                    out_ori_val1,
+                    rtol=1e-6,
+                    err_msg="Output values mismatch"
+                )
+                np.testing.assert_array_equal(
+                    out_opt_val2,
+                    out_ori_val2,
+                    err_msg="Segment count mismatch"
+                )
+
+                benchmark_op(
+                    sess,
+                    feed,
+                    [out_ori1, out_ori2],
+                    [out_opt1, out_opt2],
+                    self.run_options,
+                    self.run_metadata_ori,
+                    self.run_metadata_opt,
+                    op_name="KPFusedSparseSegmentReduce",
+                    start_op="ori/StridedSlice",
+                    end_op="ori/StridedSlice_1",
+                    num_runs=1000,
+                    tag="---------TF_origin--------"
+                )
 
-    def _tf_reference_impl(self, data, indices, slice_input, begin, end, strides, is_mean):
-        slice_out = tf.strided_slice(
-            slice_input,
-            begin= begin,
-            end= end,
-            strides= strides,
-            begin_mask=1,
-            end_mask=1,
-            shrink_axis_mask=2
-        )
-
-        segment_ids = tf.cast(slice_out, dtype=tf.int32)
-        if is_mean:
-            output = tf.sparse.segment_mean(
-                data = data,
-                indices = indices,
-                segment_ids= segment_ids
-            )
-        else:
-            output = tf.sparse.segment_sum(
-                data = data,
-                indices = indices,
-                segment_ids= segment_ids
-            )
-
-        output_shape = tf.shape(output)
-        slice_out = tf.strided_slice(output_shape, begin=[0], end=[1], strides=[1])
-
-        return output, slice_out
 
 if __name__ == "__main__":
     tf.compat.v1.disable_eager_execution()
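Note: both tests pin the fused kernel to tf.sparse.segment_mean / tf.sparse.segment_sum, which gather data[indices] and then reduce the rows that share a segment id. A NumPy sketch of that reduction under the test's data (the helper name is hypothetical, for illustration only):

    import numpy as np

    def sparse_segment_reduce(data, indices, segment_ids, mean=True):
        # Gather the selected rows, then sum (or average) the rows per segment id.
        num_segments = int(segment_ids.max()) + 1
        out = np.zeros((num_segments, data.shape[1]), dtype=data.dtype)
        counts = np.zeros(num_segments)
        for idx, seg in zip(indices, segment_ids):
            out[seg] += data[idx]
            counts[seg] += 1
        if mean:
            nonzero = counts > 0
            out[nonzero] /= counts[nonzero][:, None]
        return out

    # With the test data, segment_ids is column 1 of slice_input, i.e. [0, 2, 2]:
    # row 0 of the result is data[0], and row 2 averages data[1] and data[2].
    data = np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0]], dtype=np.float32)
    print(sparse_segment_reduce(data, np.array([0, 1, 2]), np.array([0, 2, 2])))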
diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py
index d37128cb28003afaf48ebe7b991e433c25bbf819..54c0926fcbc3415a843e32108b62c044558444a5 100644
--- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py
+++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_select_test.py
@@ -1,9 +1,13 @@
+# Copyright 2025 The Huawei Technologies Co. Authors. All Rights Reserved.
+import unittest
+
 import tensorflow as tf
 import numpy as np
-import unittest
 from tensorflow.python.ops import gen_embedding_fused_ops
-from utils.utils import perf_run, generate_timeline, wrapper_sess
+from utils.utils import benchmark_op
+
+np.random.seed(140)
 
 
 def ori_fused_embedding_sparse_select_graph(input_a, input_b, input_c):
@@ -45,7 +49,7 @@ class TestKPFusedSparseSelect(unittest.TestCase):
         """Initialize config"""
         cls.config = tf.compat.v1.ConfigProto()
         cls.config.intra_op_parallelism_threads = 16
-        cls.config.inter_op_parallelism_threads = 16
+        cls.config.inter_op_parallelism_threads = 1
 
         cls.run_options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
         cls.run_metadata_ori = tf.compat.v1.RunMetadata()
@@ -68,46 +72,63 @@ class TestKPFusedSparseSelect(unittest.TestCase):
             input2: np.random.randint(0, 100, size=(20, 50)).astype(np.int32),
         }
         with tf.name_scope("ori"):
-            out0_ori, out1_ori, out2_ori = ori_fused_embedding_sparse_select_graph(input0, input1, input2)
+            out_ori1, out_ori2, out_ori3 = ori_fused_embedding_sparse_select_graph(input0, input1, input2)
         with tf.name_scope("opt"):
-            out0_opt, out1_opt, out2_opt = opt_fused_embedding_sparse_select_graph(input0, input1, input2)
+            out_opt1, out_opt2, out_opt3 = opt_fused_embedding_sparse_select_graph(input0, input1, input2)
 
         # Create tf session
         with tf.compat.v1.Session(config=self.config) as sess:
             # functest
-            out0_ori_val, out1_ori_val, out2_ori_val = sess.run([out0_ori, out1_ori, out2_ori], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_ori)
-            out0_opt_val, out1_opt_val, out2_opt_val = sess.run([out0_opt, out1_opt, out2_opt], feed_dict=feed, options=self.run_options, run_metadata=self.run_metadata_opt)
+            out_ori_val1, out_ori_val2, out_ori_val3 = sess.run(
+                [out_ori1, out_ori2, out_ori3],
+                feed_dict=feed,
+                options=self.run_options,
+                run_metadata=self.run_metadata_ori
+            )
+            out_opt_val1, out_opt_val2, out_opt_val3 = sess.run(
+                [out_opt1, out_opt2, out_opt3],
+                feed_dict=feed,
+                options=self.run_options,
+                run_metadata=self.run_metadata_opt
+            )
 
             np.testing.assert_allclose(
-                out0_ori_val,
-                out0_opt_val,
+                out_ori_val1,
+                out_opt_val1,
                 rtol=1e-5,
                 err_msg="Output values mismatch"
             )
             np.testing.assert_allclose(
-                out1_ori_val,
-                out1_opt_val,
+                out_ori_val2,
+                out_opt_val2,
                 rtol=1e-5,
                 err_msg="Output values mismatch"
             )
             np.testing.assert_allclose(
-                out2_ori_val,
-                out2_opt_val,
+                out_ori_val3,
+                out_opt_val3,
                 rtol=1e-5,
                 err_msg="Output values mismatch"
             )
-
-            generate_timeline(self.run_metadata_ori.step_stats, f"{self._testMethodName}_ori")
-            generate_timeline(self.run_metadata_opt.step_stats, f"{self._testMethodName}_opt")
-
-            # perftest
-            perf_run(wrapper_sess(sess, [out0_ori, out1_ori, out2_ori], feed_dict=feed),
-                     wrapper_sess(sess, [out0_opt, out1_opt, out2_opt], feed_dict=feed),
-                     "KPFusedEmbeddingSparseSelect")
+
+            benchmark_op(
+                sess,
+                feed,
+                [out_ori1, out_ori2, out_ori3],
+                [out_opt1, out_opt2, out_opt3],
+                self.run_options,
+                self.run_metadata_ori,
+                self.run_metadata_opt,
+                op_name="KPFusedSparseSelect",
+                start_op="ori/Reshape",
+                end_op="ori/Sub",
+                num_runs=1000,
+                tag="-----TF_origin-----"
+            )
 
 
 if __name__ == "__main__":
     tf.compat.v1.disable_eager_execution()
-    unittest.main(argv=[''], verbosity=2)
\ No newline at end of file
+    unittest.main(argv=[''], verbosity=2)
diff --git a/tensorflow/python/grappler/embedding_fused_test/utils/utils.py b/tensorflow/python/grappler/embedding_fused_test/utils/utils.py
index 06f02d6bf1d1a715b61a739f3331ec5eaa5057cf..f982f5dd381536a5c93de4c1c788e367dc198d10 100644
--- a/tensorflow/python/grappler/embedding_fused_test/utils/utils.py
+++ b/tensorflow/python/grappler/embedding_fused_test/utils/utils.py
@@ -1,10 +1,88 @@
 import timeit
+import json
+import os
 
 from tensorflow.python.client import timeline
 
-def perf_run(ori_func, opt_func, name, warmup=5, iters=50):
-
+def extract_op_dur(timeline_file, op_name):
+    """Extract the duration (in us) of the given fused op from a timeline JSON file."""
+    with open(f"timeline/{timeline_file}.json", "r") as f:
+        trace_events = json.load(f)["traceEvents"]  # structure of timeline.json
+    durations = [e["dur"] for e in trace_events if e.get("name") == op_name and "dur" in e]
+    return durations[0]
+
+
+def extract_op_total_time(timeline_file, start_op, end_op):
+    """Total elapsed time from start_op to end_op (scheduling gaps included)."""
+    with open(f"timeline/{timeline_file}.json", "r") as f:
+        trace_events = json.load(f)["traceEvents"]
+    start_event = next(e for e in trace_events if e.get("args", {}).get("name") == start_op)  # first event whose name equals start_op
+    end_event = next(e for e in trace_events if e.get("args", {}).get("name") == end_op)  # raises StopIteration if end_op is missing
+    start_time = start_event["ts"]
+    end_time = end_event["ts"] + end_event["dur"]  # ts is the start timestamp; dur is the op's duration
+    return end_time - start_time
+
+
+def benchmark_op(
+    sess,
+    feed,
+    out_ori,
+    out_opt,
+    run_options,
+    run_metadata_ori,
+    run_metadata_opt,
+    op_name,
+    start_op,
+    end_op,
+    num_runs=500,
+    tag="--------TF_origin---------"
+):
+    print("-" * 60)
+    print("new test")
+
+    total_times_ori = 0.0
+    total_times_opt = 0.0
+
+    for i in range(num_runs):
+        # run the original graph
+        sess.run(
+            out_ori,
+            feed_dict=feed,
+            options=run_options,
+            run_metadata=run_metadata_ori
+        )
+        # run the optimized (fused) graph
+        sess.run(
+            out_opt,
+            feed_dict=feed,
+            options=run_options,
+            run_metadata=run_metadata_opt
+        )
+
+        # generate the timeline files
+        filename_ori = f"{op_name}_ori"
+        filename_opt = f"{op_name}_opt"
+        generate_timeline(run_metadata_ori.step_stats, filename_ori)
+        generate_timeline(run_metadata_opt.step_stats, filename_opt)
+
+        # accumulate the latencies
+        total_times_ori += extract_op_total_time(filename_ori, start_op, end_op)
+        total_times_opt += extract_op_dur(filename_opt, op_name)
+
+    # compute the averages and the speedup
+    avg_ori = total_times_ori / num_runs
+    avg_opt = total_times_opt / num_runs
+    speedup = (avg_ori - avg_opt) / avg_ori * 100 if avg_ori > 0 else 0
+
+    # print the results
+    print(f"{tag}: {avg_ori:.4f} us per run")
+    print(f"{op_name}: {avg_opt:.4f} us per run")
+    print(f"improve: {speedup:.2f}%")
+
+
+def perf_run(ori_func, opt_func, name, warmup=5, iters=5):
     print(f"\nWarmup ori: {warmup} iters")
     for _ in range(warmup):
         ori_func()
@@ -12,7 +90,7 @@ def perf_run(ori_func, opt_func, name, warmup=5, iters=50):
     print(f"Running performance test: ori {iters} iters")
     total_time = timeit.timeit(ori_func, number=iters)
     ori_avg_time = total_time / iters * 1000
-    print(f"{name}: {ori_avg_time:.2f} ms per run")
+    print(f"{name}: {ori_avg_time:.6f} ms per run")
 
     print(f"\nWarmup opt: {warmup} iters")
     for _ in range(warmup):
@@ -21,7 +99,7 @@ def perf_run(ori_func, opt_func, name, warmup=5, iters=50):
     print(f"Running performance test: opt {iters} iters")
     total_time = timeit.timeit(opt_func, number=iters)
     opt_avg_time = total_time / iters * 1000
-    print(f"{name}: {opt_avg_time:.2f} ms per run")
print(f"{name}: {opt_avg_time:.6f} ms per run") improvement = (ori_avg_time - opt_avg_time) / ori_avg_time * 100 print(f"improve: {improvement:.2f}%") @@ -36,4 +114,4 @@ def generate_timeline(step_stats, filename): def wrapper_sess(sess, fetches, feed_dict=None, options=None, run_metadata=None): - return lambda: sess.run(fetches, feed_dict=feed_dict, options=options, run_metadata=run_metadata) \ No newline at end of file + return lambda: sess.run(fetches, feed_dict=feed_dict, options=options, run_metadata=run_metadata) \ No newline at end of file