From d46ddc2c2cd02c230df67cceaa0ff63281deeffd Mon Sep 17 00:00:00 2001 From: rayshine <1324789704@qq.com> Date: Tue, 2 Sep 2025 15:51:35 +0800 Subject: [PATCH] fix:solve reshape and select op issues --- .../kernels/embedding_fused_sparse_reshape.cc | 24 ++++----- .../embedding_fused_sparse_reshape_test.cc | 51 +++++++++++++++++-- .../kernels/embedding_fused_sparse_select.cc | 8 +-- 3 files changed, 63 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc index 5e84b5bd2..e4acad2c6 100644 --- a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc +++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc @@ -157,39 +157,37 @@ class KPFusedSparseReshapeOp : public OpKernel { OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2")); OP_REQUIRES(context, new_shape.dim_size(0) == 2, errors::Internal("new_shape dim size must == 2")); + OP_REQUIRES(context, pack_const.dims() == 0, + errors::InvalidArgument("pack_const must be a scalar")); VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString(); VLOG(1) << "Input begin value: " << begin.DebugString(); VLOG(1) << "Input new_shape value: " << new_shape.DebugString(); + OP_REQUIRES(context, begin.dims() == 1 && begin.dim_size(0) == 2, + errors::InvalidArgument("begin must be 1D with at least 2 elements")); int32 col = begin.flat().data()[1]; OP_REQUIRES(context, col < slice_input.dim_size(1), errors::Internal("begin[1] must < slice_input.dim_size(1)")); - int64_t stridedslice57_out = slice_input.dim_size(0); + int64_t num_rows = slice_input.dim_size(0); auto slice_input_mat = slice_input.matrix(); - VLOG(1) << "stridedslice57_out: " << stridedslice57_out; + VLOG(1) << "num_rows: " << num_rows; VLOG(1) << "slice_input.dim_size(0): " << slice_input.dim_size(0); VLOG(1) << "slice_input.dim_size(1): " << slice_input.dim_size(1); - OP_REQUIRES(context, stridedslice57_out == slice_input.dim_size(0), errors::Internal("concat shape mismatch")); VLOG(1) << "Column index from begin: " << col; - VLOG(1) << "indices size: " << stridedslice57_out; Tensor shape_in(DT_INT64, TensorShape({2})); auto tensor_flat = shape_in.flat(); - tensor_flat(0) = stridedslice57_out; - tensor_flat(1) = pack_const.flat().data()[0]; + tensor_flat(0) = num_rows; + tensor_flat(1) = pack_const.scalar()(); - Tensor indices_in(DT_INT64, TensorShape({stridedslice57_out, 2})); + Tensor indices_in(DT_INT64, TensorShape({num_rows, 2})); auto indices_in_mat = indices_in.matrix(); - for (int i = 0; i < stridedslice57_out; ++i) { + for (int i = 0; i < num_rows; ++i) { indices_in_mat(i, 0) = i; indices_in_mat(i, 1) = slice_input_mat(i, col); } - Tensor new_shape_in(DT_INT64, TensorShape({2})); - auto newshape_tensor_flat = new_shape_in.flat(); - newshape_tensor_flat(0) = new_shape.flat()(0); - newshape_tensor_flat(1) = new_shape.flat()(1); - ReshapeKp(context, indices_in, shape_in, new_shape_in, 0, 1); + ReshapeKp(context, indices_in, shape_in, new_shape, 0, 1); } }; diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc index c96b685cb..a42bdea36 100644 --- a/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc +++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape_test.cc @@ -55,7 +55,7 @@ class KPFusedSparseReshapeTest : public OpsTestBase { AddInputFromArray(slice_shape, slice_data); AddInputFromArray(TensorShape({2}), begin_val); AddInputFromArray(TensorShape({2}), new_shape_val); - AddInputFromArray(TensorShape({1}), pack_const_val); + AddInputFromArray(TensorShape({}), pack_const_val); TF_ASSERT_OK(RunOpKernel()); @@ -85,9 +85,9 @@ class KPFusedSparseReshapeTest : public OpsTestBase { TF_CHECK_OK(InitOp()); AddInputFromArray(slice_shape, slice_data); - AddInputFromArray(TensorShape({2}), begin_val); + AddInputFromArray(TensorShape({static_cast(begin_val.size())}), begin_val); AddInputFromArray(TensorShape({static_cast(new_shape_val.size())}), new_shape_val); - AddInputFromArray(TensorShape({1}), pack_const_val); + AddInputFromArray(TensorShape({}), pack_const_val); return RunOpKernel(); } @@ -233,4 +233,49 @@ TEST_F(KPFusedSparseReshapeTest, Invalid_ProductZeroWithUnknownDim) { << "Actual error: " << s.error_message(); } +// 反例9:begin 是 1D 但长度为 1(不够 2 个元素) +TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize1) { + Status s = RunOpExpectFailure( + TensorShape({2, 2}), {0, 1, 1, 0}, + {0}, // begin = [0],长度为 1 + {2, 2}, + {2}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("begin must be 1D with at least 2 elements") != std::string::npos) + << "Actual error: " << s.error_message(); +} + +// 反例10:begin 是 1D 但长度为 3(超过 2) +TEST_F(KPFusedSparseReshapeTest, Invalid_BeginRank1ButSize3) { + Status s = RunOpExpectFailure( + TensorShape({2, 2}), {0, 1, 1, 0}, + {0, 1, 2}, // begin = [0,1,2],长度为 3 + {2, 2}, + {2}); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("begin must be 1D with at least 2 elements") != std::string::npos) + << "Actual error: " << s.error_message(); +} + +// 反例11:pack_const 是标量(0维) +TEST_F(KPFusedSparseReshapeTest, Invalid_PackConstIsScalarButExpect1D) { + TF_CHECK_OK(NodeDefBuilder("kp_fused_sparse_reshape", "KPFusedSparseReshape") + .Input(FakeInput(DT_INT64)) // slice_input + .Input(FakeInput(DT_INT32)) // begin + .Input(FakeInput(DT_INT64)) // new_shape + .Input(FakeInput(DT_INT64)) // pack_const + .Finalize(node_def())); + TF_CHECK_OK(InitOp()); + + AddInputFromArray(TensorShape({2, 2}), {0, 1, 1, 0}); + AddInputFromArray(TensorShape({2}), {0, 1}); + AddInputFromArray(TensorShape({2}), {2, 2}); + AddInputFromArray(TensorShape({1}), {1}); // pack_const = 标量 1(0维) + + Status s = RunOpKernel(); + EXPECT_FALSE(s.ok()); + EXPECT_TRUE(s.error_message().find("pack_const must be a scalar") != std::string::npos) + << "Actual error: " << s.error_message(); +} + } // namespace diff --git a/tensorflow/core/kernels/embedding_fused_sparse_select.cc b/tensorflow/core/kernels/embedding_fused_sparse_select.cc index df19fa100..306a42074 100644 --- a/tensorflow/core/kernels/embedding_fused_sparse_select.cc +++ b/tensorflow/core/kernels/embedding_fused_sparse_select.cc @@ -95,11 +95,11 @@ public: for (int64 i = start; i < end; i++) { // Greater(bool)+Cast.2406(float) --> 1.0f / 0.0f float a_greater = (a_reshaped_tensor(i, 0) > greater_val) ? 1.0f : 0.0f; - float select_2412 = (b_reshaped_tensor(i, 0) == equal1_val) ? 1.0f : a_greater; // Fill.2409-->1.0f - float select_2415 = (b_reshaped_tensor(i, 0) == equal2_val) ? 1.0f : select_2412; // Fill.2409-->1.0f + float res_equal1 = (b_reshaped_tensor(i, 0) == equal1_val) ? 1.0f : a_greater; // Fill.2409-->1.0f + float res_equal2 = (b_reshaped_tensor(i, 0) == equal2_val) ? 1.0f : res_equal1; // Fill.2409-->1.0f out_x(i, 0) = a_reshaped_tensor(i, 0); // Reshape.2401 - out_y(i, 0) = select_2415; - out_w(i, 0) = select_2415; // Mul.2419 硬编码 1.0f * input + out_y(i, 0) = res_equal2; + out_w(i, 0) = res_equal2; // Mul.2419 硬编码 1.0f * input out_w(i, 1) = 1.0f; // select_2427被消除,直接使用Fill.2422-->1.0f } }; -- Gitee