From d6ed5bcfe1dcab9e85a4d39931bd18d99018e75b Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Fri, 23 Apr 2021 11:40:06 -0700
Subject: [PATCH] Add missing validation in
`QuantizedBatchNormWithGlobalNormalization`
PiperOrigin-RevId: 370123451
Change-Id: Id234d6dab1ec21230bb8e503dba30f899af87f33
---
.../core/kernels/quantized_batch_norm_op.cc | 77 ++++++++++++++++---
1 file changed, 67 insertions(+), 10 deletions(-)
diff --git a/tensorflow/core/kernels/quantized_batch_norm_op.cc b/tensorflow/core/kernels/quantized_batch_norm_op.cc
index b03da7ad17fab..6dfe07f97a400 100644
--- a/tensorflow/core/kernels/quantized_batch_norm_op.cc
+++ b/tensorflow/core/kernels/quantized_batch_norm_op.cc
@@ -173,20 +173,50 @@ class QuantizedBatchNormOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
- const float input_min = context->input(1).flat<float>()(0);
- const float input_max = context->input(2).flat<float>()(0);
+ const auto& input_min_tensor = context->input(1);
+ OP_REQUIRES(context, input_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("input_min must have 1 element"));
+ const float input_min = input_min_tensor.flat<float>()(0);
+ const auto& input_max_tensor = context->input(2);
+ OP_REQUIRES(context, input_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("input_max must have 1 element"));
+ const float input_max = input_max_tensor.flat<float>()(0);
const Tensor& mean = context->input(3);
- const float mean_min = context->input(4).flat<float>()(0);
- const float mean_max = context->input(5).flat<float>()(0);
+ const auto& mean_min_tensor = context->input(4);
+ OP_REQUIRES(context, mean_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("mean_min must have 1 element"));
+ const float mean_min = mean_min_tensor.flat<float>()(0);
+ const auto& mean_max_tensor = context->input(5);
+ OP_REQUIRES(context, mean_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("mean_max must have 1 element"));
+ const float mean_max = mean_max_tensor.flat<float>()(0);
const Tensor& var = context->input(6);
- const float var_min = context->input(7).flat<float>()(0);
- const float var_max = context->input(8).flat<float>()(0);
+ const auto& var_min_tensor = context->input(7);
+ OP_REQUIRES(context, var_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("var_min must have 1 element"));
+ const float var_min = var_min_tensor.flat<float>()(0);
+ const auto& var_max_tensor = context->input(8);
+ OP_REQUIRES(context, var_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("var_max must have 1 element"));
+ const float var_max = var_max_tensor.flat<float>()(0);
const Tensor& beta = context->input(9);
- const float beta_min = context->input(10).flat<float>()(0);
- const float beta_max = context->input(11).flat<float>()(0);
+ const auto& beta_min_tensor = context->input(10);
+ OP_REQUIRES(context, beta_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("beta_min must have 1 element"));
+ const float beta_min = beta_min_tensor.flat<float>()(0);
+ const auto& beta_max_tensor = context->input(11);
+ OP_REQUIRES(context, beta_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("beta_max must have 1 element"));
+ const float beta_max = beta_max_tensor.flat<float>()(0);
const Tensor& gamma = context->input(12);
- const float gamma_min = context->input(13).flat<float>()(0);
- const float gamma_max = context->input(14).flat<float>()(0);
+ const auto& gamma_min_tensor = context->input(13);
+ OP_REQUIRES(context, gamma_min_tensor.NumElements() == 1,
+ errors::InvalidArgument("gamma_min must have 1 element"));
+ const float gamma_min = gamma_min_tensor.flat<float>()(0);
+ const auto& gamma_max_tensor = context->input(14);
+ OP_REQUIRES(context, gamma_max_tensor.NumElements() == 1,
+ errors::InvalidArgument("gamma_max must have 1 element"));
+ const float gamma_max = gamma_max_tensor.flat<float>()(0);
OP_REQUIRES(context, input.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
@@ -203,6 +233,33 @@ class QuantizedBatchNormOp : public OpKernel {
OP_REQUIRES(context, gamma.dims() == 1,
errors::InvalidArgument("gamma must be 1-dimensional",
gamma.shape().DebugString()));
+ OP_REQUIRES(context, mean.NumElements() > 1,
+ errors::InvalidArgument("Must have at least a mean value",
+ gamma.shape().DebugString()));
+ OP_REQUIRES(context, mean.NumElements() > 1,
+ errors::InvalidArgument("Must have at least a mean value"));
+ const auto last_dim = input.shape().dims() - 1;
+ OP_REQUIRES(context,
+ mean.shape().dim_size(0) == input.shape().dim_size(last_dim),
+ errors::InvalidArgument("Must provide as many means as the "
+ "last dimension of the input tensor: ",
+ mean.shape().DebugString(), " vs. ",
+ input.shape().DebugString()));
+ OP_REQUIRES(
+ context, mean.shape().dim_size(0) == var.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Mean and variance tensors must have the same shape: ",
+ mean.shape().DebugString(), " vs. ", var.shape().DebugString()));
+ OP_REQUIRES(
+ context, mean.shape().dim_size(0) == beta.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Mean and beta tensors must have the same shape: ",
+ mean.shape().DebugString(), " vs. ", beta.shape().DebugString()));
+ OP_REQUIRES(
+ context, mean.shape().dim_size(0) == gamma.shape().dim_size(0),
+ errors::InvalidArgument(
+ "Mean and gamma tensors must have the same shape: ",
+ mean.shape().DebugString(), " vs. ", gamma.shape().DebugString()));
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
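
Note (not part of the upstream patch): the OP_REQUIRES checks added above make the kernel reject empty min/max tensors and mismatched per-channel shapes with an InvalidArgument error instead of reading past the end of a buffer. Below is a minimal, illustrative Python sketch of the kind of malformed call these checks guard against; the tf.raw_ops argument names (t, t_min, m, m_min, ...) are assumed from the op's registered inputs, and the quantize() helper exists only for this example.

import tensorflow as tf

# Illustrative repro sketch; argument names are assumed from the op registration.
def quantize(values):
  # Helper for this sketch only: quantize a float tensor to quint8 and drop
  # the returned output_min/output_max.
  q, _, _ = tf.quantization.quantize(
      tf.constant(values, dtype=tf.float32), 0.0, 10.0, tf.quint8)
  return q

t = quantize([[[[1.0, 2.0], [3.0, 4.0]]]])         # 4-D input, last dimension = 2 channels
m, v = quantize([5.0, 6.0]), quantize([7.0, 8.0])   # per-channel mean and variance
beta, gamma = quantize([1.0, 1.0]), quantize([2.0, 2.0])

try:
  tf.raw_ops.QuantizedBatchNormWithGlobalNormalization(
      # An empty t_min used to be read via flat<float>()(0) with no size check.
      t=t, t_min=tf.constant([], dtype=tf.float32), t_max=tf.constant(10.0),
      m=m, m_min=tf.constant(0.0), m_max=tf.constant(10.0),
      v=v, v_min=tf.constant(0.0), v_max=tf.constant(10.0),
      beta=beta, beta_min=tf.constant(0.0), beta_max=tf.constant(10.0),
      gamma=gamma, gamma_min=tf.constant(0.0), gamma_max=tf.constant(10.0),
      out_type=tf.qint32, variance_epsilon=0.001,
      scale_after_normalization=False)
except tf.errors.InvalidArgumentError as e:
  print(e)  # with the patch applied: "input_min must have 1 element"

With the patch, such empty-tensor inputs, as well as mean/variance/beta/gamma vectors whose length does not match the input's last dimension, fail early with a descriptive error rather than triggering undefined behavior inside the kernel.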