diff --git a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
index e353b56556b4379cdea285e4945833fbecc32f9d..9937a07e095059998e34819c7c80ce968298e70d 100644
--- a/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
+++ b/tensorflow/core/kernels/embedding_fused_sparse_dynamic_stitch.cc
@@ -38,24 +38,28 @@ class KPFusedSparseDynamicStitchOp : public OpKernel {
     const int num_inputs = context->num_inputs();
     const int num_partitions = num_inputs - 1;
 
+    int output_stride = 0;
     std::vector<const float*> variables(num_partitions);
     for (int i = 1; i < num_inputs; ++i) {
+      if (i == 1) {
+        const Tensor& input_tensor = context->input(i);
+        if (input_tensor.shape().dims() == 2) {
+          output_stride = input_tensor.shape().dim_size(1);
+        }
+      }
       variables[i - 1] = context->input(i).flat<float>().data();
     }
 
     Tensor* output_tensor = nullptr;
     OP_REQUIRES_OK(context,
-                   context->allocate_output(0, TensorShape({num_elems, 10}),
+                   context->allocate_output(0, TensorShape({num_elems, output_stride}),
                                             &output_tensor));
     output = (float*)output_tensor->tensor_data().data();
 
-    const int64_t output_stride = 10;
     const size_t copy_size = output_stride * sizeof(float);
-
     auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
     const int64 cost_per_unit = 1000 * num_elems;
 
     auto work = [&](int start, int end) {
-      const int64_t output_stride = 10;
       const size_t copy_size = output_stride * sizeof(float);
       for (int i = start; i < end; ++i) {
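
Reviewer sketch (not part of the patch): the hunk above only inspects input 1, so output_stride stays 0 whenever that tensor is not rank 2, and the op would then silently allocate a {num_elems, 0} output with copy_size == 0. A shape check along the following lines could harden this, assuming the usual kernel headers (tensorflow/core/framework/op_kernel.h, tensorflow/core/platform/errors.h) are already included; the helper name ValidatePartitionShapes is hypothetical.

// Hypothetical helper: verifies that every partition tensor
// (inputs 1..num_inputs-1) is rank 2 and that all partitions share one
// embedding width, returning that width through *output_stride.
Status ValidatePartitionShapes(OpKernelContext* context, int num_inputs,
                               int64_t* output_stride) {
  *output_stride = 0;
  for (int i = 1; i < num_inputs; ++i) {
    const TensorShape& shape = context->input(i).shape();
    if (shape.dims() != 2) {
      return errors::InvalidArgument("Partition ", i - 1,
                                     " must be rank 2, got rank ",
                                     shape.dims());
    }
    if (i == 1) {
      *output_stride = shape.dim_size(1);  // width of the first partition
    } else if (shape.dim_size(1) != *output_stride) {
      return errors::InvalidArgument("Partition ", i - 1, " has width ",
                                     shape.dim_size(1), "; expected ",
                                     *output_stride);
    }
  }
  return Status::OK();
}

Compute() could then declare int64_t output_stride = 0; and call OP_REQUIRES_OK(context, ValidatePartitionShapes(context, num_inputs, &output_stride)); in place of the inline i == 1 branch, rejecting malformed inputs before the output is allocated.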