1 Star 0 Fork 51

gice/tensorflow

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
CVE-2021-29583.patch 2.65 KB
一键复制 编辑 原始数据 按行查看 历史
starlet_dx 提交于 2021-08-31 15:06 +08:00 . fix the cves to tensorflow
From 6972f9dfe325636b3db4e0bc517ee22a159365c0 Mon Sep 17 00:00:00 2001
From: Mihai Maruseac <mihaimaruseac@google.com>
Date: Thu, 6 May 2021 17:45:51 -0700
Subject: [PATCH] Add missing valuidation to FusedBatchNorm.
---
.../core/kernels/fused_batch_norm_op.cc | 29 ++++++++++++++++++-
1 file changed, 28 insertions(+), 1 deletion(-)
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 59470c8a..bd5dab36 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1267,6 +1267,33 @@ class FusedBatchNormOpBase : public OpKernel {
context, estimated_variance.dims() == 1,
errors::InvalidArgument("estimated_variance must be 1-dimensional",
estimated_variance.shape().DebugString()));
+
+ const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
+ OP_REQUIRES(
+ context, scale.NumElements() == num_channels,
+ errors::InvalidArgument("scale must have the same number of elements "
+ "as the channels of x, got ",
+ scale.NumElements(), " and ", num_channels));
+ OP_REQUIRES(
+ context, offset.NumElements() == num_channels,
+ errors::InvalidArgument("offset must have the same number of elements "
+ "as the channels of x, got ",
+ offset.NumElements(), " and ", num_channels));
+ if (estimated_mean.NumElements() != 0) {
+ OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
+ errors::InvalidArgument(
+ "mean must be empty or have the same number of "
+ "elements as the channels of x, got ",
+ estimated_mean.NumElements(), " and ",num_channels));
+ }
+ if (estimated_variance.NumElements() != 0) {
+ OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
+ errors::InvalidArgument(
+ "variance must be empty or have the same number of "
+ "elements as the channels of x, got ",
+ estimated_variance.NumElements(), " and ", num_channels));
+ }
+
if (has_side_input_) {
OP_REQUIRES(context, side_input->shape() == x.shape(),
errors::InvalidArgument(
@@ -1279,7 +1306,7 @@ class FusedBatchNormOpBase : public OpKernel {
// NOTE(ezhulenev): This requirement is coming from implementation
// details of cudnnBatchNormalizationForwardTrainingEx.
OP_REQUIRES(
- context, !is_training_ || x.dim_size(3) % 4 == 0,
+ context, !is_training_ || num_channels % 4 == 0,
errors::InvalidArgument("FusedBatchNorm with activation requires "
"channel dimension to be a multiple of 4."));
}
--
2.23.0
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gice/tensorflow.git
git@gitee.com:gice/tensorflow.git
gice
tensorflow
tensorflow
master

搜索帮助