From e446f73babbdd3495316df2546bde8e36a0ab7ac Mon Sep 17 00:00:00 2001 From: limingxing517 Date: Wed, 1 Feb 2023 03:57:04 +0000 Subject: [PATCH] fp32 Signed-off-by: limingxing517 --- .../WideDeep_ID2712_for_TensorFlow/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/TensorFlow/built-in/recommendation/WideDeep_ID2712_for_TensorFlow/train.py b/TensorFlow/built-in/recommendation/WideDeep_ID2712_for_TensorFlow/train.py index 1f7e5fd95..849af2acb 100644 --- a/TensorFlow/built-in/recommendation/WideDeep_ID2712_for_TensorFlow/train.py +++ b/TensorFlow/built-in/recommendation/WideDeep_ID2712_for_TensorFlow/train.py @@ -280,6 +280,9 @@ def parse_args(): help="size of train data") parser.add_argument("--display_step", default= config.display_step, help="display step") + parser.add_argument('--precision_mode', default='allow_mix_precision', + help='allow_fp32_to_fp16/force_fp16/ ' + 'must_keep_origin_dtype/allow_mix_precision.') args = parser.parse_args() '''args, unknown_args = parser.parse_known_args() if len(unknown_args) > 0: @@ -291,7 +294,7 @@ def parse_args(): if __name__ == '__main__': display_step = config.display_step - os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1" + #os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1" tag = algo Base_path = config.BASE_DIR @@ -336,9 +339,11 @@ if __name__ == '__main__': #custom_op.parameter_map["mix_compile_mode"].b = True #开启混合计算,根据实际情况配置 custom_op.parameter_map["use_off_line"].b = True custom_op.parameter_map["min_group_size"].b = 1 - custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision") + custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes(args.precision_mode) custom_op.parameter_map["hcom_parallel"].b = True custom_op.parameter_map["iterations_per_loop"].i = config.iterations_per_loop + if args.precision_mode == "allow_mix_precision": + custom_op.parameter_map["modify_mixlist"].s = tf.compat.as_bytes("ops_info.json") custom_op.parameter_map["modify_mixlist"].s = tf.compat.as_bytes("ops_info.json") custom_op.parameter_map["fusion_switch_file"].s = tf.compat.as_bytes("fusion_switch.cfg") #aic err debug -- Gitee