diff --git a/tensorflow/core/kernels/embedding_fused_gather.cc b/tensorflow/core/kernels/embedding_fused_gather.cc index 404e4cc485d986424a0d826962e8d20725534c29..8a3a585a466c9dfee3e3155299d84d2507190035 100644 --- a/tensorflow/core/kernels/embedding_fused_gather.cc +++ b/tensorflow/core/kernels/embedding_fused_gather.cc @@ -20,7 +20,7 @@ limitations under the License. using namespace tensorflow; class KPFusedGather : public OpKernel { -public: + public: explicit KPFusedGather(OpKernelConstruction* context) : OpKernel(context) { } void Compute(OpKernelContext* context) override { @@ -30,6 +30,7 @@ public: OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2")); OP_REQUIRES(context, data.dims() == 2, errors::Internal("indentity dims must == 2")); + OP_REQUIRES(context, data.dim_size(1) == 12, errors::Internal("indentity dim size must == [n, 12]")); VLOG(1) << "Input indentity shape: " << data.shape().DebugString(); VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString(); @@ -37,6 +38,7 @@ public: VLOG(1) << "Input begin value: " << begin.SummarizeValue(10); int32 col = begin.flat().data()[1]; + OP_REQUIRES(context, col < slice_input.dim_size(1), errors::Internal("begin[1] must < slice_input.dim_size(1)")); auto data_mat = data.matrix(); auto slice_input_mat = slice_input.matrix(); @@ -82,6 +84,7 @@ public: int64_t cols = data.dim_size(1); for (int64_t cur_row = 0; cur_row < unique_values.size(); ++cur_row) { int64_t idx = unique_values[cur_row]; + OP_REQUIRES(context, idx < data.dim_size(0), errors::Internal("idx must < data_row")); const float* src = data_mat.data() + idx * cols; float* dst = output_data.data() + cur_row * cols; std::memcpy(dst, src, cols * sizeof(float)); diff --git a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc index 66dfa4838da1c91d34d759f95b4af0b7ef8834ac..5e84b5bd23dc251731a84f21cac0b97ae2fbec09 100644 --- a/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc +++ b/tensorflow/core/kernels/embedding_fused_sparse_reshape.cc @@ -153,14 +153,16 @@ class KPFusedSparseReshapeOp : public OpKernel { const Tensor& slice_input = context->input(0); const Tensor& begin = context->input(1); const Tensor& new_shape = context->input(2); + const Tensor& pack_const = context->input(3); OP_REQUIRES(context, slice_input.dims() == 2, errors::Internal("slice_input dims must == 2")); - + OP_REQUIRES(context, new_shape.dim_size(0) == 2, errors::Internal("new_shape dim size must == 2")); VLOG(1) << "Input slice_input shape: " << slice_input.shape().DebugString(); VLOG(1) << "Input begin value: " << begin.DebugString(); VLOG(1) << "Input new_shape value: " << new_shape.DebugString(); int32 col = begin.flat().data()[1]; + OP_REQUIRES(context, col < slice_input.dim_size(1), errors::Internal("begin[1] must < slice_input.dim_size(1)")); int64_t stridedslice57_out = slice_input.dim_size(0); auto slice_input_mat = slice_input.matrix(); @@ -174,8 +176,8 @@ class KPFusedSparseReshapeOp : public OpKernel { Tensor shape_in(DT_INT64, TensorShape({2})); auto tensor_flat = shape_in.flat(); tensor_flat(0) = stridedslice57_out; - tensor_flat(1) = 2; - + tensor_flat(1) = pack_const.flat().data()[0]; + Tensor indices_in(DT_INT64, TensorShape({stridedslice57_out, 2})); auto indices_in_mat = indices_in.matrix(); for (int i = 0; i < stridedslice57_out; ++i) { diff --git a/tensorflow/core/ops/embedding_fused_ops.cc b/tensorflow/core/ops/embedding_fused_ops.cc index 27a667e5921d6403327076edd5b92384d1f78add..da4550c7d5e89c013fa5d87826f719dd0b584122 100644 --- a/tensorflow/core/ops/embedding_fused_ops.cc +++ b/tensorflow/core/ops/embedding_fused_ops.cc @@ -96,6 +96,7 @@ REGISTER_OP("KPFusedSparseReshape") .Input("slice_input: int64") .Input("begin: int32") .Input("new_shape: int64") + .Input("pack_const: int64") .Output("out_indices: int64") .Output("out_shape: int64") .SetShapeFn(shape_inference::UnknownShape); diff --git a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py index 0930d933125d224fc126aba578145df32560cb10..5fca54b8ebf5262c1d91649a6dd4120955ff6060 100644 --- a/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py +++ b/tensorflow/python/grappler/embedding_fused_test/fused_embedding_sparse_reshape_test.py @@ -8,7 +8,7 @@ from tensorflow.python.ops import gen_embedding_fused_ops from utils.utils import benchmark_op -def ori_sparse_reshape_graph(slice_input, begin, newshape): +def ori_sparse_reshape_graph(slice_input, begin, newshape, pack_const): slice67_out = tf.strided_slice( slice_input, begin=begin, @@ -29,7 +29,7 @@ def ori_sparse_reshape_graph(slice_input, begin, newshape): shrink_axis_mask=1 ) - const2 = tf.constant(2) + const2 = pack_const input_shape = tf.stack([slice57_out, const2]) input_shape = tf.cast(input_shape, tf.int64) @@ -49,11 +49,12 @@ def ori_sparse_reshape_graph(slice_input, begin, newshape): return sparse_tensor_out.indices, sparse_tensor_out.dense_shape, concat_out -def opt_sparse_reshape_graph(slice_input, begin, newshape): +def opt_sparse_reshape_graph(slice_input, begin, newshape, pack_const): custom_out1, custom_out2 = gen_embedding_fused_ops.KPFusedSparseReshape( slice_input=slice_input, begin=begin, - new_shape=newshape + new_shape=newshape, + pack_const=pack_const, ) return custom_out1, custom_out2 @@ -75,7 +76,7 @@ class TestFusedSparseReshape(unittest.TestCase): # cls.sess.close() return - def _run_kp_reshape_test(self, slice_shape, base_slice_input, base_begin, base_newshape, num_runs): + def _run_kp_reshape_test(self, slice_shape, base_slice_input, base_begin, base_newshape, pack_const, num_runs): with tf.Graph().as_default(): slice_input = tf.compat.v1.placeholder(tf.int64, shape=slice_shape, name="slice_input") begin = tf.compat.v1.placeholder(tf.int32, shape=(2,), name="begin") @@ -88,9 +89,9 @@ class TestFusedSparseReshape(unittest.TestCase): } with tf.name_scope("ori"): - out_ori1, out_ori2, out_ori3 = ori_sparse_reshape_graph(slice_input, begin, newshape) + out_ori1, out_ori2, out_ori3 = ori_sparse_reshape_graph(slice_input, begin, newshape, pack_const) with tf.name_scope("opt"): - out_opt1, out_opt2 = opt_sparse_reshape_graph(slice_input, begin, newshape) + out_opt1, out_opt2 = opt_sparse_reshape_graph(slice_input, begin, newshape, pack_const) with tf.compat.v1.Session(config=self.config) as sess: out_ori_val1, out_ori_val2, out_ori_val3 = sess.run( @@ -138,13 +139,15 @@ class TestFusedSparseReshape(unittest.TestCase): base_slice_input = np.array([[0, 0], [0, 1], [1, 2], [3, 4]], dtype=np.int64) base_begin = [0, 1] base_newshape = [2, 4] - self._run_kp_reshape_test((4, 2), base_slice_input, base_begin, base_newshape, num_runs=100) + pack_const = 2 + self._run_kp_reshape_test((4, 2), base_slice_input, base_begin, base_newshape, pack_const, num_runs=100) def test_kp_reshape_2(self): base_slice_input = np.array([[0, 1]], dtype=np.int64) base_begin = [0, 1] base_newshape = [-1, 1] - self._run_kp_reshape_test((1, 2), base_slice_input, base_begin, base_newshape, num_runs=100) + pack_const = 1 + self._run_kp_reshape_test((1, 2), base_slice_input, base_begin, base_newshape, pack_const, num_runs=100) if __name__ == "__main__":