1 Star 0 Fork 51

gice/tensorflow

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
CVE-2021-29530.patch 5.52 KB
一键复制 编辑 原始数据 按行查看 历史
From e6a7c7cc18c3aaad1ae0872cb0a959f5c923d2bd Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Tue, 20 Apr 2021 14:45:33 -0700
Subject: [PATCH] Remove `OP_REQUIRES` call from helper function.
Since `OP_REQUIRES` macro expands to a `return;` (among other), calling it in a helper function only ends the helper function's execution earlier, but the kernel will still run from start to end. Thus, all the expected validations are actually broken/useless as the code ploughs through the next crash anyway.
PiperOrigin-RevId: 369524386
Change-Id: I54f6cf9328445675ccc392e661b04336b229c9da
---
.../core/kernels/sparse/sparse_cholesky_op.cc | 67 ++++++++++---------
1 file changed, 34 insertions(+), 33 deletions(-)
diff --git a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
index 9a939276f0b6c..47ab252317de5 100644
--- a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
+++ b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "tensorflow/core/framework/op_requires.h"
+
#define EIGEN_USE_THREADS
#include "third_party/eigen3/Eigen/Core"
@@ -82,8 +84,8 @@ class CSRSparseCholeskyCPUOp : public OpKernel {
int64 num_rows;
int batch_size;
- ValidateInputs(ctx, *input_matrix, input_permutation_indices, &batch_size,
- &num_rows);
+ OP_REQUIRES_OK(ctx, ValidateInputs(*input_matrix, input_permutation_indices,
+ &batch_size, &num_rows));
// Allocate batch pointers.
Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
@@ -226,49 +228,48 @@ class CSRSparseCholeskyCPUOp : public OpKernel {
}
private:
- void ValidateInputs(OpKernelContext* ctx,
- const CSRSparseMatrix& sparse_matrix,
- const Tensor& permutation_indices, int* batch_size,
- int64* num_rows) {
- OP_REQUIRES(ctx, sparse_matrix.dtype() == DataTypeToEnum<T>::value,
- errors::InvalidArgument(
- "Asked for a CSRSparseMatrix of type ",
- DataTypeString(DataTypeToEnum<T>::value),
- " but saw dtype: ", DataTypeString(sparse_matrix.dtype())));
+ Status ValidateInputs(const CSRSparseMatrix& sparse_matrix,
+ const Tensor& permutation_indices, int* batch_size,
+ int64* num_rows) {
+ if (sparse_matrix.dtype() != DataTypeToEnum<T>::value)
+ return errors::InvalidArgument(
+ "Asked for a CSRSparseMatrix of type ",
+ DataTypeString(DataTypeToEnum<T>::value),
+ " but saw dtype: ", DataTypeString(sparse_matrix.dtype()));
const Tensor& dense_shape = sparse_matrix.dense_shape();
const int rank = dense_shape.dim_size(0);
- OP_REQUIRES(ctx, rank == 2 || rank == 3,
- errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
- "but dense_shape has size ", rank));
+ if (rank < 2 || rank > 3)
+ return errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
+ "but dense_shape has size ", rank);
const int row_dim = (rank == 2) ? 0 : 1;
auto dense_shape_vec = dense_shape.vec<int64>();
*num_rows = dense_shape_vec(row_dim);
const int64 num_cols = dense_shape_vec(row_dim + 1);
- OP_REQUIRES(ctx, *num_rows == num_cols,
- errors::InvalidArgument("sparse matrix must be square; got: ",
- *num_rows, " != ", num_cols));
+ if (*num_rows != num_cols)
+ return errors::InvalidArgument(
+ "sparse matrix must be square; got: ", *num_rows, " != ", num_cols);
const TensorShape& perm_shape = permutation_indices.shape();
- OP_REQUIRES(
- ctx, perm_shape.dims() + 1 == rank,
- errors::InvalidArgument(
- "sparse matrix must have the same rank as permutation; got: ", rank,
- " != ", perm_shape.dims(), " + 1."));
- OP_REQUIRES(
- ctx, perm_shape.dim_size(rank - 2) == *num_rows,
- errors::InvalidArgument(
- "permutation must have the same number of elements in each batch "
- "as the number of rows in sparse matrix; got: ",
- perm_shape.dim_size(rank - 2), " != ", *num_rows));
+ if (perm_shape.dims() + 1 != rank)
+ return errors::InvalidArgument(
+ "sparse matrix must have the same rank as permutation; got: ", rank,
+ " != ", perm_shape.dims(), " + 1.");
+ if (perm_shape.dim_size(rank - 2) != *num_rows)
+ return errors::InvalidArgument(
+ "permutation must have the same number of elements in each batch "
+ "as the number of rows in sparse matrix; got: ",
+ perm_shape.dim_size(rank - 2), " != ", *num_rows);
*batch_size = sparse_matrix.batch_size();
if (*batch_size > 1) {
- OP_REQUIRES(
- ctx, perm_shape.dim_size(0) == *batch_size,
- errors::InvalidArgument("permutation must have the same batch size "
- "as sparse matrix; got: ",
- perm_shape.dim_size(0), " != ", *batch_size));
+ if (perm_shape.dim_size(0) != *batch_size)
+ return errors::InvalidArgument(
+ "permutation must have the same batch size "
+ "as sparse matrix; got: ",
+ perm_shape.dim_size(0), " != ", *batch_size);
}
+
+ return Status::OK();
}
};
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gice/tensorflow.git
git@gitee.com:gice/tensorflow.git
gice
tensorflow
tensorflow
master

搜索帮助