1 Star 0 Fork 51

gice/tensorflow

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
CVE-2021-29547.patch 6.17 KB
一键复制 编辑 原始数据 按行查看 历史
From d6ed5bcfe1dcab9e85a4d39931bd18d99018e75b Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Fri, 23 Apr 2021 11:40:06 -0700
Subject: [PATCH] Add missing validation in
`QuantizedBatchNormWithGlobalNormalization`
PiperOrigin-RevId: 370123451
Change-Id: Id234d6dab1ec21230bb8e503dba30f899af87f33
---
.../core/kernels/quantized_batch_norm_op.cc | 77 ++++++++++++++++---
1 file changed, 67 insertions(+), 10 deletions(-)
diff --git a/tensorflow/core/kernels/quantized_batch_norm_op.cc b/tensorflow/core/kernels/quantized_batch_norm_op.cc
index b03da7ad17fab..6dfe07f97a400 100644
--- a/tensorflow/core/kernels/quantized_batch_norm_op.cc
+++ b/tensorflow/core/kernels/quantized_batch_norm_op.cc
@@ -173,20 +173,50 @@ class QuantizedBatchNormOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- const float input_min = context->input(1).flat<float>()(0);
- const float input_max = context->input(2).flat<float>()(0);
+ const auto& input_min_tensor = context->input(1);
+ OP_REQUIRES(context, input_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("input_min must have 1 element"));
+ const float input_min = input_min_tensor.flat<float>()(0);
+ const auto& input_max_tensor = context->input(2);
+ OP_REQUIRES(context, input_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("input_max must have 1 element"));
+ const float input_max = input_max_tensor.flat<float>()(0);
const Tensor& mean = context->input(3);
- const float mean_min = context->input(4).flat<float>()(0);
- const float mean_max = context->input(5).flat<float>()(0);
+ const auto& mean_min_tensor = context->input(4);
+ OP_REQUIRES(context, mean_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("mean_min must have 1 element"));
+ const float mean_min = mean_min_tensor.flat<float>()(0);
+ const auto& mean_max_tensor = context->input(5);
+ OP_REQUIRES(context, mean_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("mean_max must have 1 element"));
+ const float mean_max = mean_max_tensor.flat<float>()(0);
const Tensor& var = context->input(6);
- const float var_min = context->input(7).flat<float>()(0);
- const float var_max = context->input(8).flat<float>()(0);
+ const auto& var_min_tensor = context->input(7);
+ OP_REQUIRES(context, var_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("var_min must have 1 element"));
+ const float var_min = var_min_tensor.flat<float>()(0);
+ const auto& var_max_tensor = context->input(8);
+ OP_REQUIRES(context, var_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("var_max must have 1 element"));
+ const float var_max = var_max_tensor.flat<float>()(0);
const Tensor& beta = context->input(9);
- const float beta_min = context->input(10).flat<float>()(0);
- const float beta_max = context->input(11).flat<float>()(0);
+ const auto& beta_min_tensor = context->input(10);
+ OP_REQUIRES(context, beta_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("beta_min must have 1 element"));
+ const float beta_min = beta_min_tensor.flat<float>()(0);
+ const auto& beta_max_tensor = context->input(11);
+ OP_REQUIRES(context, beta_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("beta_max must have 1 element"));
+ const float beta_max = beta_max_tensor.flat<float>()(0);
const Tensor& gamma = context->input(12);
- const float gamma_min = context->input(13).flat<float>()(0);
- const float gamma_max = context->input(14).flat<float>()(0);
+ const auto& gamma_min_tensor = context->input(13);
+ OP_REQUIRES(context, gamma_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("gamma_min must have 1 element"));
+ const float gamma_min = gamma_min_tensor.flat<float>()(0);
+ const auto& gamma_max_tensor = context->input(14);
+ OP_REQUIRES(context, gamma_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("gamma_max must have 1 element"));
+ const float gamma_max = gamma_max_tensor.flat<float>()(0);
OP_REQUIRES(context, input.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
@@ -203,6 +233,33 @@ class QuantizedBatchNormOp : public OpKernel {
OP_REQUIRES(context, gamma.dims() == 1,
errors::InvalidArgument("gamma must be 1-dimensional",
gamma.shape().DebugString()));
+ OP_REQUIRES(context, mean.NumElements() > 1,
+ errors::InvalidArgument("Must have at least a mean value",
+ gamma.shape().DebugString()));
+ OP_REQUIRES(context, mean.NumElements() > 1,
+ errors::InvalidArgument("Must have at least a mean value"));
+ const auto last_dim = input.shape().dims() - 1;
+ OP_REQUIRES(context,
+ mean.shape().dim_size(0) == input.shape().dim_size(last_dim),
+ errors::InvalidArgument("Must provide as many means as the "
+ "last dimension of the input tensor: ",
+ mean.shape().DebugString(), " vs. ",
+ input.shape().DebugString()));
+ OP_REQUIRES(
+ context, mean.shape().dim_size(0) == var.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Mean and variance tensors must have the same shape: ",
+ mean.shape().DebugString(), " vs. ", var.shape().DebugString()));
+ OP_REQUIRES(
+ context, mean.shape().dim_size(0) == beta.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Mean and beta tensors must have the same shape: ",
+ mean.shape().DebugString(), " vs. ", beta.shape().DebugString()));
+ OP_REQUIRES(
+ context, mean.shape().dim_size(0) == gamma.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Mean and gamma tensors must have the same shape: ",
+ mean.shape().DebugString(), " vs. ", gamma.shape().DebugString()));
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gice/tensorflow.git
git@gitee.com:gice/tensorflow.git
gice
tensorflow
tensorflow
master

搜索帮助