diff --git a/tensorflow/core/kernels/embedding_fused_gather.cc b/tensorflow/core/kernels/embedding_fused_gather.cc index a9391225510326763c293fe0265f97db43d46000..c09d1ce14475a681ae250d675d85cbaa66f66778 100644 --- a/tensorflow/core/kernels/embedding_fused_gather.cc +++ b/tensorflow/core/kernels/embedding_fused_gather.cc @@ -25,68 +25,63 @@ class KPFusedGather : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& data = context->input(0); - const Tensor& slice_input = context->input(1); + const Tensor& keys = context->input(1); const Tensor& begin = context->input(2); - - OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2")); - OP_REQUIRES(context, data.dims() == 2, errors::Internal("identity dims must == 2")); - OP_REQUIRES(context, data.dim_size(1) == 12, errors::Internal("identity dim size must == [n, 12]")); - - VLOG(1) << "Input identity shape: " << data.shape().DebugString(); - VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString(); - VLOG(1) << "Input slice_input: " << slice_input.SummarizeValue(1000); - VLOG(1) << "Input begin value: " << begin.SummarizeValue(10); - + VLOG(1) << "Embedding table size: " << data.shape().DebugString(); + VLOG(1) << "Input key shape: " << keys.shape().DebugString(); + VLOG(1) << "Slice begin value: " << begin.DebugString(); + + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(keys.shape()), + errors::Internal("Input key must be 2D")); + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(data.shape()), + errors::Internal("Embedding table shape must be 2D")); + OP_REQUIRES(context, begin.NumElements() == 2, errors::Internal("begin must be same as keys rank")); int32 col = begin.flat().data()[1]; - OP_REQUIRES(context, col < slice_input.dim_size(1), errors::Internal("begin[1] must < slice_input.dim_size(1)")); - auto data_mat = data.matrix(); - auto slice_input_mat = slice_input.matrix(); - - VLOG(1) << "Column index from begin: " << col; + OP_REQUIRES(context, col < keys.dim_size(1), errors::Internal("slice cols out of keys range")); + + Tensor* out_indices = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 1, TensorShape({static_cast(keys.dim_size(0))}), &out_indices)); + int32 *out_indices_data = out_indices->flat().data(); + auto keys_mat = keys.matrix(); std::vector unique_values; - std::vector indices(slice_input.dim_size(0)); std::unordered_map value_to_index; int current_index = 0; - for (int64_t i = 0; i < slice_input.dim_size(0); ++i) { - auto it = value_to_index.find(slice_input_mat(i, col)); + for (int64_t i = 0; i < keys.dim_size(0); ++i) { + auto it = value_to_index.find(keys_mat(i, col)); if (it == value_to_index.end()) { - value_to_index[slice_input_mat(i, col)] = current_index; - unique_values.push_back(slice_input_mat(i, col)); - indices[i] = current_index; - current_index++; + value_to_index[keys_mat(i, col)] = current_index; + unique_values.push_back(keys_mat(i, col)); + out_indices_data[i] = current_index; + ++current_index; } else { - indices[i] = it->second; + out_indices_data[i] = it->second; } } - Tensor* out_shape = nullptr; - Tensor* out_indices = nullptr; - Tensor* out_data = nullptr; - + Tensor* out_unique_value = nullptr; OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({unique_values.size()}), &out_shape)); - std::memcpy(out_shape->data(), unique_values.data(), unique_values.size() * sizeof(int64_t)); - - OP_REQUIRES_OK(context, - context->allocate_output( - 1, TensorShape({static_cast(indices.size())}), &out_indices)); - std::memcpy(out_indices->data(), indices.data(), indices.size() * sizeof(int32_t)); + 0, TensorShape({static_cast(unique_values.size())}), &out_unique_value)); + std::memcpy(out_unique_value->data(), unique_values.data(), unique_values.size() * sizeof(int64_t)); + Tensor* out_data = nullptr; + int embedding_dims = data.dim_size(1); OP_REQUIRES_OK(context, context->allocate_output( - 2, TensorShape({unique_values.size(), data.dim_size(1)}), &out_data)); - auto output_data = out_data->matrix(); + 2, TensorShape({static_cast(unique_values.size()), embedding_dims}), &out_data)); - int64_t data_row = data.dim_size(0); - int64_t cols = data.dim_size(1); + const float *data_mat = data.flat().data(); for (int64_t cur_row = 0; cur_row < unique_values.size(); ++cur_row) { int64_t idx = unique_values[cur_row]; - OP_REQUIRES(context, idx < data_row, errors::Internal("idx must < data_row")); - const float* src = data_mat.data() + idx * cols; - float* dst = output_data.data() + cur_row * cols; - std::memcpy(dst, src, cols * sizeof(float)); + OP_REQUIRES(context, idx < data.dim_size(0), errors::Internal("idx out of table range")); + const float* src = data_mat + idx * embedding_dims; + float* dst = out_data->flat().data() + cur_row * embedding_dims; + std::memcpy(dst, src, embedding_dims * sizeof(float)); } } }; diff --git a/tensorflow/core/kernels/embedding_fused_gather_test.cc b/tensorflow/core/kernels/embedding_fused_gather_test.cc index ef93bfb3f95abcd07c96bdeb0e072d5be64376ca..fa187ba387ee646fbee297d76676f5d378a0607f 100644 --- a/tensorflow/core/kernels/embedding_fused_gather_test.cc +++ b/tensorflow/core/kernels/embedding_fused_gather_test.cc @@ -126,7 +126,7 @@ TEST_F(KPFusedGatherTest, Valid_NormalInput) { ); } -// 反例1:data不是2维 +// data不是2维 TEST_F(KPFusedGatherTest, Invalid_DataDimsNot2) { std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; Status s = RunOpExpectFailure( @@ -137,24 +137,10 @@ TEST_F(KPFusedGatherTest, Invalid_DataDimsNot2) { data ); EXPECT_FALSE(s.ok()); - EXPECT_TRUE(s.error_message().find("identity dims must == 2") != std::string::npos); + EXPECT_TRUE(s.error_message().find("Embedding table shape must be 2D") != std::string::npos); } -// 反例2:data 第二维不是12 -TEST_F(KPFusedGatherTest, Invalid_DataDimSizeNot12) { - std::vector data(2 * 10, 1.0f); - Status s = RunOpExpectFailure( - TensorShape({2, 10}), // data 第二维不是12 - TensorShape({2, 2}), - {0, 0}, - {0, 1, 2, 3}, - data - ); - EXPECT_FALSE(s.ok()); - EXPECT_TRUE(s.error_message().find("identity dim size must == [n, 12]") != std::string::npos); -} - -// 反例3:slice_input 不是2维 +// key 不是2维 TEST_F(KPFusedGatherTest, Invalid_SliceInputDimsNot2) { std::vector data(2 * 12, 1.0f); Status s = RunOpExpectFailure( @@ -165,10 +151,10 @@ TEST_F(KPFusedGatherTest, Invalid_SliceInputDimsNot2) { data ); EXPECT_FALSE(s.ok()); - EXPECT_TRUE(s.error_message().find("slice_input dims must == 2") != std::string::npos); + EXPECT_TRUE(s.error_message().find("Input key must be 2D") != std::string::npos); } -// 反例4: begin[1] 超出列范围 +// begin[1] 超出列范围 TEST_F(KPFusedGatherTest, Invalid_BeginColOutOfRange) { std::vector data(2 * 12, 1.0f); Status s = RunOpExpectFailure( @@ -179,10 +165,10 @@ TEST_F(KPFusedGatherTest, Invalid_BeginColOutOfRange) { data ); EXPECT_FALSE(s.ok()); - EXPECT_TRUE(s.error_message().find("begin[1] must < slice_input.dim_size(1)") != std::string::npos); + EXPECT_TRUE(s.error_message().find("slice cols out of keys range") != std::string::npos); } -// 反例5: gather 索引超出 data 行数 +// gather 索引超出 data 行数 TEST_F(KPFusedGatherTest, Invalid_IndexOutOfRangeInData) { std::vector data(2 * 12, 1.0f); Status s = RunOpExpectFailure( @@ -194,7 +180,7 @@ TEST_F(KPFusedGatherTest, Invalid_IndexOutOfRangeInData) { data ); EXPECT_FALSE(s.ok()); - EXPECT_TRUE(s.error_message().find("idx must < data_row") != std::string::npos); + EXPECT_TRUE(s.error_message().find("idx out of table range") != std::string::npos); } } \ No newline at end of file diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py index f4c0795652d5e900988362865de91f44b1e4986f..8e8586b83cd20ae9237db30ee41f43ba1208572b 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_gather_test.py @@ -19,16 +19,12 @@ def ori_fused_embedding_gather_graph(data, slice_input, begin): shrink_axis_mask=2 ) - slice_out, slice_out_indices = tf.unique(slice_out) - output_shape = slice_out - slice_out = tf.reshape(slice_out, [-1]) - slice_out, slice_out_indices2 = tf.unique(slice_out) - - gather1_result = tf.gather(data, slice_out) - gather1_result = tf.reshape(gather1_result, [-1, 12]) - - gather2_result = tf.gather(gather1_result, slice_out_indices2) - return output_shape, slice_out_indices, gather2_result + value, indices = tf.unique(slice_out) + value = tf.reshape(value, [-1]) + value_1, indices_1 = tf.unique(value) + gather1 = tf.gather(data, value_1) + gather2 = tf.gather(gather1, indices_1) + return value, indices, gather2 def opt_fused_embedding_gather_graph(data, slice_input, begin):