From 7214edd8c25f73765f0df7e2966a78254639157f Mon Sep 17 00:00:00 2001 From: lianghuikang <505519763@qq.com> Date: Mon, 29 Mar 2021 15:54:45 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0experimental=20options?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- convert_tf2npu/ast_impl.py | 5 +++++ tf_adapter/python/npu_bridge/npu_init.py | 12 ++++++++++++ 2 files changed, 17 insertions(+) diff --git a/convert_tf2npu/ast_impl.py b/convert_tf2npu/ast_impl.py index c586c4585..80701afbf 100644 --- a/convert_tf2npu/ast_impl.py +++ b/convert_tf2npu/ast_impl.py @@ -97,6 +97,11 @@ def ast_if(node): return node def ast_call(node): + if _call_name_match(node.func, "set_experimental_options"): + log_msg(getattr(node, 'lineno', 'None'), 'change set_experimental_options(*) to set_experimental_options(experimental_options)') + node.args = [ast.Name(id='experimental_options', ctx=ast.Load())] + node.keywords = [] + util_global.set_value('need_conver', True) if isinstance(node.func, ast.Name) and node.func.id == 'check_available_gpus': log_msg(getattr(node, 'lineno', 'None'), "change check_available_gpus() to ['/device:CPU:0']") util_global.set_value('need_conver', True) diff --git a/tf_adapter/python/npu_bridge/npu_init.py b/tf_adapter/python/npu_bridge/npu_init.py index cd1d648cf..dbba95df7 100644 --- a/tf_adapter/python/npu_bridge/npu_init.py +++ b/tf_adapter/python/npu_bridge/npu_init.py @@ -47,6 +47,18 @@ from npu_bridge.estimator.npu.npu_plugin import npu_close import atexit atexit.register(npu_close) +experimental_options = { + "disable_model_pruning": True, + "function_optimization": RewriterConfig.OFF, + "constant_folding": RewriterConfig.OFF, + "shape_optimization": RewriterConfig.OFF, + "arithmetic_optimization": RewriterConfig.OFF, + "loop_optimization": RewriterConfig.OFF, + "dependency_optimization": RewriterConfig.OFF, + "layout_optimizer": RewriterConfig.OFF, + "memory_optimization": RewriterConfig.OFF +} + def npu_hooks_append(hooks_list=[]): if (not isinstance(hooks_list, list)): hooks_list = [] -- Gitee