代码拉取完成,页面将自动刷新
同步操作将从 openEuler-RISC-V/tensorflow 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
From 1d8218f155c1d22c21afda8bf28e36e4094d9e88 Mon Sep 17 00:00:00 2001
From: Ben Barsdell <bbarsdell@nvidia.com>
Date: Fri, 8 Jan 2021 11:04:37 +1100
Subject: [PATCH 1/2] Refactor ReshapeSparseTensor into a template+class
- This is in preparation for adding a GPU implementation.
- No functional change.
---
.../kernels/deserialize_sparse_string_op.cc | 8 +-
tensorflow/core/kernels/reshape_util.cc | 101 ++++++++++++------
tensorflow/core/kernels/reshape_util.h | 18 ++++
tensorflow/core/kernels/sparse_reshape_op.cc | 12 ++-
4 files changed, 98 insertions(+), 41 deletions(-)
diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
index 2e151078..3acd86ef 100644
--- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc
+++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
@@ -35,6 +35,8 @@ limitations under the License.
namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+
namespace {
using sparse::SparseTensor;
@@ -204,9 +206,9 @@ class DeserializeSparseOp : public OpKernel {
target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
}
- ReshapeSparseTensor(context, output.indices(), input_shape, target_shape,
- 0 /* output indices index */,
- 2 /* output shape index */);
+ ReshapeSparseTensor<CPUDevice>(context, output.indices(), input_shape,
+ target_shape, 0 /* output indices index */,
+ 2 /* output shape index */);
context->set_output(1, output.values());
}
diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc
index 1fce80f7..d0d54738 100644
--- a/tensorflow/core/kernels/reshape_util.cc
+++ b/tensorflow/core/kernels/reshape_util.cc
@@ -31,6 +31,53 @@ limitations under the License.
namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+namespace functor {
+
+template <>
+struct ReshapeSparseTensorFunctor<CPUDevice> {
+ Status operator()(const TensorShape &input_shape,
+ const TensorShape &output_shape,
+ typename TTypes<int64>::ConstMatrix input_indices,
+ typename TTypes<int64>::Matrix output_indices) const {
+ const int64 input_rank = input_shape.dims();
+ const int64 output_rank = output_shape.dims();
+ const int64 nnz = input_indices.dimension(0);
+ gtl::InlinedVector<int64, 8> input_strides(input_rank);
+ if (input_rank > 0) {
+ input_strides[input_rank - 1] = 1;
+ for (int d = input_rank - 2; d >= 0; --d) {
+ input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+ }
+ }
+
+ gtl::InlinedVector<int64, 8> output_strides(output_rank);
+ if (output_rank > 0) {
+ output_strides[output_rank - 1] = 1;
+ for (int d = output_rank - 2; d >= 0; --d) {
+ output_strides[d] =
+ output_strides[d + 1] * output_shape.dim_size(d + 1);
+ }
+ }
+
+ for (int i = 0; i < nnz; ++i) {
+ int64 id = 0;
+ for (int j = 0; j < input_rank; ++j) {
+ id += input_indices(i, j) * input_strides[j];
+ }
+ for (int j = 0; j < output_rank; ++j) {
+ output_indices(i, j) = id / output_strides[j];
+ id %= output_strides[j];
+ }
+ }
+ return Status::OK();
+ }
+};
+
+} // namespace functor
+
+template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
const Tensor &input_indices_in,
const Tensor &input_shape_in,
@@ -111,40 +158,6 @@ void ReshapeSparseTensor(OpKernelContext *context,
return;
}
- gtl::InlinedVector<int64, 8> input_strides(input_rank);
- if (input_rank > 0) {
- input_strides[input_rank - 1] = 1;
- for (int d = input_rank - 2; d >= 0; --d) {
- input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
- }
- }
-
- gtl::InlinedVector<int64, 8> output_strides(output_rank);
- if (output_rank > 0) {
- output_strides[output_rank - 1] = 1;
- for (int d = output_rank - 2; d >= 0; --d) {
- output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
- }
- }
-
- Tensor *result_indices = nullptr;
- OP_REQUIRES_OK(context,
- context->allocate_output(output_indices_idx,
- TensorShape({nnz, output_rank}),
- &result_indices));
- auto input_ind = input_indices_in.matrix<int64>();
- auto output_ind = result_indices->matrix<int64>();
- for (int i = 0; i < nnz; ++i) {
- int64 id = 0;
- for (int j = 0; j < input_rank; ++j) {
- id += input_ind(i, j) * input_strides[j];
- }
- for (int j = 0; j < output_rank; ++j) {
- output_ind(i, j) = id / output_strides[j];
- id %= output_strides[j];
- }
- }
-
Tensor *result_shape = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(output_shape_idx,
TensorShape({output_rank}),
@@ -153,6 +166,26 @@ void ReshapeSparseTensor(OpKernelContext *context,
for (int j = 0; j < output_shape.dims(); ++j) {
output_shape_vec(j) = output_shape.dim_size(j);
}
+
+ Tensor *result_indices = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(output_indices_idx,
+ TensorShape({nnz, output_rank}),
+ &result_indices));
+ if (nnz > 0) {
+ OP_REQUIRES_OK(context, functor::ReshapeSparseTensorFunctor<Device>()(
+ input_shape, output_shape,
+ input_indices_in.matrix<int64>(),
+ result_indices->matrix<int64>()));
+ }
}
+#define EXPLICITLY_INSTANTIATE_FUNCTION(Device) \
+ template void ReshapeSparseTensor<Device>( \
+ OpKernelContext *context, const Tensor &input_indices_in, \
+ const Tensor &input_shape_in, const Tensor &target_shape_in, \
+ int output_indices_idx, int output_shape_idx)
+EXPLICITLY_INSTANTIATE_FUNCTION(CPUDevice);
+#undef EXPLICITLY_INSTANTIATE_FUNCTION
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h
index 7e1809e8..b3a35651 100644
--- a/tensorflow/core/kernels/reshape_util.h
+++ b/tensorflow/core/kernels/reshape_util.h
@@ -16,18 +16,36 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/status.h"
+
namespace tensorflow {
class OpKernelContext;
class Tensor;
// Reshapes the input indices and input shape to the target shape.
+// Note: This template is explicitly instantiated for CPU device only.
+template <typename Device>
void ReshapeSparseTensor(OpKernelContext *context,
const Tensor &input_indices_in,
const Tensor &input_shape_in,
const Tensor &target_shape_in, int output_indices_idx,
int output_shape_idx);
+namespace functor {
+
+template <typename Device>
+struct ReshapeSparseTensorFunctor {
+ Status operator()(const TensorShape &input_shape,
+ const TensorShape &output_shape,
+ typename TTypes<int64>::ConstMatrix input_indices,
+ typename TTypes<int64>::Matrix output_indices) const;
+};
+
+} // namespace functor
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_
diff --git a/tensorflow/core/kernels/sparse_reshape_op.cc b/tensorflow/core/kernels/sparse_reshape_op.cc
index 3896c959..490d9ffd 100644
--- a/tensorflow/core/kernels/sparse_reshape_op.cc
+++ b/tensorflow/core/kernels/sparse_reshape_op.cc
@@ -30,6 +30,9 @@ limitations under the License.
namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <typename Device>
class SparseReshapeOp : public OpKernel {
public:
explicit SparseReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -46,12 +49,13 @@ class SparseReshapeOp : public OpKernel {
input_indices_in.dim_size(1) == input_shape_in.dim_size(0),
errors::InvalidArgument(
"Input tensor rank must match input shape length."));
- ReshapeSparseTensor(context, context->input(0), context->input(1),
- context->input(2), 0 /* output indices index */,
- 1 /* output shape index */);
+ ReshapeSparseTensor<Device>(
+ context, context->input(0), context->input(1), context->input(2),
+ 0 /* output indices index */, 1 /* output shape index */);
}
};
REGISTER_KERNEL_BUILDER(Name("SparseReshape").Device(DEVICE_CPU),
- SparseReshapeOp)
+ SparseReshapeOp<CPUDevice>)
+
} // namespace tensorflow
--
2.27.0
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。