From 02b2b2e9c22be9c79f56ec80fb8b472349621535 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E8=8E=89?= <2020857336@qq.com> Date: Thu, 23 Nov 2023 06:08:31 +0000 Subject: [PATCH] =?UTF-8?q?v1=E8=BD=ACv2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 石莉 <2020857336@qq.com> --- .../models/research/object_detection/model_main.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/TensorFlow/built-in/cv/detection/SSD-Resnet50V1-FPN_ID1463_for_TensorFlow/models/research/object_detection/model_main.py b/TensorFlow/built-in/cv/detection/SSD-Resnet50V1-FPN_ID1463_for_TensorFlow/models/research/object_detection/model_main.py index b2706af1f..6fcc6fd8d 100644 --- a/TensorFlow/built-in/cv/detection/SSD-Resnet50V1-FPN_ID1463_for_TensorFlow/models/research/object_detection/model_main.py +++ b/TensorFlow/built-in/cv/detection/SSD-Resnet50V1-FPN_ID1463_for_TensorFlow/models/research/object_detection/model_main.py @@ -35,6 +35,8 @@ from npu_bridge.npu_init import * from tensorflow.core.protobuf import config_pb2 from absl import flags import tensorflow as tf +tf.enable_control_flow_v2() +tf.enable_resource_variables() #import horovod.tensorflow as hvd import dllogger import time @@ -81,7 +83,10 @@ class DLLoggerHook(tf.estimator.SessionRunHook): def before_run(self, run_context): self.t0 = time.time() - return tf.estimator.SessionRunArgs(fetches=['global_step:0', 'learning_rate:0']) + global_step = tf.get_collection('global_step:0') + learning_rate = tf.get_collection('learning_rate:0') + #return tf.estimator.SessionRunArgs(fetches=['global_step:0', 'learning_rate:0']) + return tf.estimator.SessionRunArgs(fetches=[global_step,learning_rate]) def after_run(self, run_context, run_values): throughput = (self.global_batch_size / (time.time() - self.t0)) @@ -95,7 +100,10 @@ class DLLoggerHook(tf.estimator.SessionRunHook): ###############################NPU_modify add##################################### class _LogSessionRunHook(tf.train.SessionRunHook): def before_run(self, run_context): - return tf.estimator.SessionRunArgs(fetches=['overflow_status_reduce_all:0', 'loss_scale:0']) + overflow_status_reduce_all = tf.get_collection('overflow_status_reduce_all:0') + loss_scale = tf.get_collection('loss_scale:0') + #return tf.estimator.SessionRunArgs(fetches=['overflow_status_reduce_all:0', 'loss_scale:0']) + return tf.estimator.SessionRunArgs(fetches=[overflow_status_reduce_all,loss_scale]) def after_run(self, run_context, run_values): if not run_values.results[0]: -- Gitee