diff --git a/convert_tf2npu/ast_impl.py b/convert_tf2npu/ast_impl.py
index d1eb1b9322d53ddeb9c3eb15704a1f267ed99a08..bd59bd504fa9a82d589957d29acadfa75be26676 100644
--- a/convert_tf2npu/ast_impl.py
+++ b/convert_tf2npu/ast_impl.py
@@ -13,8 +13,9 @@
# limitations under the License.
# ============================================================================
import ast
-import util_global
import copy
+import pasta
+import util_global
from util import log_success_report
from util import log_migration_report
from util import log_msg
@@ -31,7 +32,8 @@ def import_from(node):
if "keras" in values:
util_global.set_value('is_keras_net', True)
if "horovod" in values:
- util_global.set_value('is_hvd_net', True)
+ util_global.set_value('has_hccl_api', True)
+ util_global.set_value('need_conver', True)
def ast_import(node):
@@ -41,7 +43,8 @@ def ast_import(node):
if "keras" in values:
util_global.set_value('is_keras_net', True)
if "horovod" in values:
- util_global.set_value('is_hvd_net', True)
+ util_global.set_value('has_hccl_api', True)
+ util_global.set_value('need_conver', True)
def ast_function_def(node):
log_success_report(getattr(node, "lineno", "None"), node.name)
@@ -60,6 +63,8 @@ def ast_if(node):
if isinstance(node.test, ast.Compare):
if len(node.test.comparators) == 1 and isinstance(node.test.comparators[0], ast.Str):
if node.test.comparators[0].s == "__main__":
+ util_global.set_value("is_main_file", False)
+ util_global.set_value("has_main_func", True)
if util_global.get_value("is_keras_net", False):
log_msg(getattr(node, "lineno", "None"), " add keras session npu config")
close_sess_call = ast.Call(func=ast.Name(id="close_session", ctx=ast.Load()),
@@ -67,11 +72,9 @@ def ast_if(node):
keras_sess_assign = ast.Assign(targets=[ast.Name(id="npu_keras_sess", ctx=ast.Store())],
value=ast.Call(func=ast.Name(id="set_keras_session_npu_config", ctx=ast.Load()),
args=[], keywords=[]))
- try_node = ast.Try(body=[keras_sess_assign, node.body], handlers=[], orelse=[],
- finalbody=[ast.Expr(value=close_sess_call)])
- node.body = [try_node]
+ node.body = [keras_sess_assign] + node.body + [ast.Expr(value=close_sess_call)]
util_global.set_value('need_conver', True)
- if util_global.get_value("is_hvd_net", False):
+ if util_global.get_value("has_hccl_api", False):
log_msg(getattr(node, "lineno", "None"), " add npu resource init api")
close_sess_call = ast.Call(func=ast.Name(id="close_session", ctx=ast.Load()),
args=[ast.Name(id="npu_sess", ctx=ast.Load())], keywords=[])
@@ -82,9 +85,7 @@ def ast_if(node):
shutdown_call = ast.Call(func=ast.Name(id="shutdown_resource", ctx=ast.Load()),
args=[ast.Name(id="npu_sess", ctx=ast.Load()), ast.Name(id="npu_shutdown", ctx=ast.Load())],
keywords=[])
- try_node = ast.Try(body=[init_assign, node.body], handlers=[], orelse=[],
- finalbody=[ast.Expr(value=shutdown_call), ast.Expr(value=close_sess_call)])
- node.body = [try_node]
+ node.body = [init_assign] + node.body + [ast.Expr(value=shutdown_call), ast.Expr(value=close_sess_call)]
util_global.set_value('need_conver', True)
return node
@@ -159,7 +160,6 @@ def ast_call(node):
node.args = []
node.keywords = []
util_global.set_value('need_conver', True)
- util_global.set_value('insert_empty_hook', True)
return node
if isinstance(node.func, ast.Attribute) and node.func.attr == "DistributedOptimizer":
log_success_report(getattr(node, "lineno", "None"), 'DistributedOptimizer')
@@ -168,6 +168,7 @@ def ast_call(node):
log_success_report(getattr(node, "lineno", "None"), 'shard')
node.args = [ast.Call(func=ast.Name(id='get_rank_size', ctx=ast.Load()), args=[], keywords=[]),
ast.Call(func=ast.Name(id='get_rank_id', ctx=ast.Load()), args=[], keywords=[])]
+ util_global.set_value("has_hccl_api", True)
util_global.set_value('need_conver', True)
if isinstance(node.func, ast.Attribute) and node.func.attr == 'dropout':
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr == 'nn':
@@ -188,11 +189,11 @@ def ast_call(node):
if ((isinstance(keyword.value, ast.NameConstant) and keyword.value.value != True) or
(not isinstance(keyword.value, ast.NameConstant))):
log_success_report(getattr(node, "lineno", "None"), node.func.attr)
- keyword.value = ast.NameConstant(value=True)
+ keyword.value = pasta.parse('True')
util_global.set_value('need_conver', True)
if not exist:
log_success_report(getattr(node, "lineno", "None"), node.func.attr)
- keyword = ast.keyword(arg='drop_remainder', value=ast.NameConstant(value=True))
+ keyword = ast.keyword(arg='drop_remainder', value=pasta.parse('True'))
node.keywords.insert(0, keyword)
util_global.set_value('need_conver', True)
if (isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and
@@ -203,8 +204,10 @@ def ast_call(node):
if isinstance(node.func, ast.Attribute) and (node.func.attr == "get_distribution_strategy" or
node.func.attr == "MirroredStrategy" or node.func.attr == "MultiWorkerMirroredStrategy"):
log_success_report(getattr(node, "lineno", "None"), node.func.attr)
- node.func = ast.Attribute(value=ast.Name(id="npu_strategy", ctx=ast.Load()),
+ new_func = ast.Attribute(value=ast.Name(id="npu_strategy", ctx=ast.Load()),
attr="NPUStrategy", ctx=ast.Load())
+ ast.copy_location(new_func, node.func)
+ node.func = new_func
node.keywords = []
node.args = []
util_global.set_value('need_conver', True)
@@ -278,7 +281,7 @@ def ast_call(node):
if (keyword.arg == 'eval_on_tpu') or (keyword.arg == 'use_tpu') or (keyword.arg == 'export_to_tpu'):
if (not isinstance(keyword.value, ast.NameConstant)) or (isinstance(keyword.value, ast.NameConstant) and (keyword.value.value != False)):
log_success_report(getattr(node, 'lineno', 'None'), 'TPUEstimator(' + keyword.arg + '=*)')
- keyword.value = ast.NameConstant(value=False)
+ keyword.value = pasta.parse('False')
util_global.set_value('need_conver', True)
if add_eval_on_tpu and (keyword.arg == 'eval_on_tpu'):
add_eval_on_tpu = False
@@ -288,15 +291,15 @@ def ast_call(node):
add_export_to_tpu = False
if add_eval_on_tpu:
log_success_report(getattr(node, 'lineno', 'None'), 'TPUEstimator(eval_on_tpu=*)')
- node.keywords.append(ast.keyword(arg='eval_on_tpu', value=ast.NameConstant(value=False)))
+ node.keywords.append(ast.keyword(arg='eval_on_tpu', value=pasta.parse('False')))
util_global.set_value('need_conver', True)
if add_use_tpu:
log_success_report(getattr(node, 'lineno', 'None'), 'TPUEstimator(use_tpu=*)')
- node.keywords.append(ast.keyword(arg='use_tpu', value=ast.NameConstant(value=False)))
+ node.keywords.append(ast.keyword(arg='use_tpu', value=pasta.parse('False')))
util_global.set_value('need_conver', True)
if add_export_to_tpu:
log_success_report(getattr(node, 'lineno', 'None'), 'TPUEstimator(export_to_tpu=*)')
- node.keywords.append(ast.keyword(arg='export_to_tpu', value=ast.NameConstant(value=False)))
+ node.keywords.append(ast.keyword(arg='export_to_tpu', value=pasta.parse('False')))
util_global.set_value('need_conver', True)
if isinstance(node.func, ast.Attribute) and (node.func.attr == 'VirtualDeviceConfiguration'):
log_success_report(getattr(node, 'lineno', 'None'), 'VirtualDeviceConfiguration')
@@ -338,9 +341,9 @@ def ast_call(node):
compile_ops = keyword
break
if compile_ops:
- compile_ops.value = ast.NameConstant(value=False)
+ compile_ops.value = pasta.parse('False')
else:
- node.keywords.append(ast.keyword(arg='compile_ops', value=ast.NameConstant(value=False)))
+ node.keywords.append(ast.keyword(arg='compile_ops', value=pasta.parse('False')))
return node
for estimator in util_global.get_value('Estimators', []):
if (isinstance(node.func, ast.Attribute) and (node.func.attr == estimator)) \
@@ -352,20 +355,14 @@ def ast_call(node):
config = keyword
break
if config:
- config.value = ast.Call(
- func=ast.Name(id='npu_run_config_init', ctx=ast.Load()),
- args=[],
- keywords=[
- ast.keyword(arg='run_config', value=config.value)
- ]
- )
+ new_value = ast.Call(func=ast.Name(id='npu_run_config_init', ctx=ast.Load()),
+ args=[],
+ keywords=[ast.keyword(arg='run_config', value=config.value)])
+ ast.copy_location(new_value, config.value)
+ config.value = new_value
else:
- node.keywords.append(
- ast.keyword(
- arg='config',
- value=ast.Call(func=ast.Name(id='npu_run_config_init', ctx=ast.Load()), args=[], keywords=[])
- )
- )
+ node.keywords.append(ast.keyword(arg='config',
+ value=pasta.parse('npu_run_config_init()')))
util_global.set_value('need_conver', True)
return node
for estimator_func in util_global.get_value('EstimatorFunc', []):
@@ -385,53 +382,51 @@ def ast_call(node):
if not input_fn:
break
if not hooks:
- node.keywords.append(
- ast.keyword(arg='hooks', value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[])))
+ node.keywords.append(ast.keyword(arg='hooks', value=pasta.parse('npu_hooks_append()')))
elif isinstance(hooks, ast.keyword):
- hooks.value = ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[
- ast.keyword(arg='hooks_list', value=hooks.value)])
+ new_value = ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[],
+ keywords=[ast.keyword(arg='hooks_list', value=hooks.value)])
+ ast.copy_location(new_value, hooks.value)
+ hooks.value = new_value
else:
- node.keywords.append(
- ast.keyword(arg='hooks', value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[ast.keyword(arg='hooks_list', value=hooks)])))
+ node.keywords.append(ast.keyword(arg='hooks',
+ value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()),
+ args=[], keywords=[ast.keyword(arg='hooks_list', value=hooks)])))
util_global.set_value('need_conver', True)
return node
if isinstance(node.func, ast.Attribute) and (node.func.attr == 'compile'):
- opt_map = {"adadelta": "tf.keras.optimizers.Adadelta",
- "adagrad": "tf.keras.optimizers.Adagrad",
- "adam": "tf.keras.optimizers.Adam",
- "adamax": "tf.keras.optimizers.Adamax",
- "ftrl": "tf.keras.optimizers.Ftrl",
- "nadam": "tf.keras.optimizers.Nadam",
- "rmsprop": "tf.keras.optimizers.RMSprop",
- "sgd": "tf.keras.optimizers.SGD"}
+ opt_map = {"adadelta": "tf.keras.optimizers.Adadelta()",
+ "adagrad": "tf.keras.optimizers.Adagrad()",
+ "adam": "tf.keras.optimizers.Adam()",
+ "adamax": "tf.keras.optimizers.Adamax()",
+ "ftrl": "tf.keras.optimizers.Ftrl()",
+ "nadam": "tf.keras.optimizers.Nadam()",
+ "rmsprop": "tf.keras.optimizers.RMSprop()",
+ "sgd": "tf.keras.optimizers.SGD()"}
for keyword in node.keywords:
if keyword.arg == "optimizer":
log_success_report(getattr(node, 'lineno', 'None'), 'KerasDistributeOptimizer')
- opt_func_name = ast.Name(id="npu_keras_optimizer", ctx=ast.Load())
if isinstance(keyword.value, ast.Str):
- keras_opt = opt_map[keyword.value.s].split(".")
- tf_opt_func = ast.Attribute(value=ast.Attribute(value=ast.Attribute(value=ast.Name(id=keras_opt[0], ctx=ast.Load()),
- attr=keras_opt[1], ctx=ast.Load()), attr=keras_opt[2], ctx=ast.Load()),
- attr=keras_opt[3], ctx=ast.Load())
- keyword.value = ast.Call(func=opt_func_name, args=[ast.Call(func=tf_opt_func, args=[], keywords=[])], keywords=[])
+ keras_opt = opt_map[keyword.value.s]
+ npu_keras_opt = "npu_keras_optimizer(" + keras_opt + ")"
+ keyword.value = pasta.parse(npu_keras_opt)
util_global.set_value('need_conver', True)
- util_global.set_value('insert_npu_keras_opt_func', True)
return node
if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Attribute):
if (node.func.attr.find("Optimizer") != -1) and (node.func.attr != 'ScipyOptimizerInterface'):
log_msg(getattr(node, "lineno", "None"), "add NPUDistributedOptimizer()")
- node = ast.Call(func=ast.Name(id="npu_tf_optimizer", ctx=ast.Load()), args=[node], keywords=[])
+ new_node = ast.Call(func=ast.Name(id="npu_tf_optimizer", ctx=ast.Load()), args=[node], keywords=[])
+ ast.copy_location(new_node, node)
util_global.set_value('need_conver', True)
- util_global.set_value('insert_npu_tf_opt_func', True)
- return node
+ return new_node
if isinstance(node.func, ast.Attribute):
opt_list = ["Adadelta", "Adagrad", "Adam", "Adamax", "Ftrl", "Nadam", "RMSprop", "SGD"]
if node.func.attr in opt_list:
log_success_report(getattr(node, "lineno", "None"), "KerasDistributeOptimizer")
- node = ast.Call(func=ast.Name(id="npu_keras_optimizer", ctx=ast.Load()), args=[node], keywords=[])
+ new_node = ast.Call(func=ast.Name(id="npu_keras_optimizer", ctx=ast.Load()), args=[node], keywords=[])
+ ast.copy_location(new_node, node)
util_global.set_value('need_conver', True)
- util_global.set_value('insert_npu_keras_opt_func', True)
- return node
+ return new_node
if (isinstance(node.func, ast.Attribute) and (node.func.attr == 'MonitoredTrainingSession')) or \
(isinstance(node.func, ast.Name) and (node.func.id == 'MonitoredTrainingSession')):
log_success_report(getattr(node, "lineno", "None"), 'MonitoredTrainingSession')
@@ -445,14 +440,15 @@ def ast_call(node):
hooks = keyword
break
if not hooks:
- node.keywords.append(
- ast.keyword(arg='hooks', value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[])))
+ node.keywords.append(ast.keyword(arg='hooks', value=pasta.parse('npu_hooks_append()')))
elif isinstance(hooks, ast.keyword):
- hooks.value = ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[
- ast.keyword(arg='hooks_list', value=hooks.value)])
+ new_value = ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[],
+ keywords=[ast.keyword(arg='hooks_list', value=hooks.value)])
+ ast.copy_location(new_value, hooks.value)
+ hooks.value = new_value
else:
- node.keywords.append(
- ast.keyword(arg='hooks', value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[ast.keyword(arg='hooks_list', value=hooks)])))
+ node.keywords.append(ast.keyword(arg='hooks', value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()),
+ args=[], keywords=[ast.keyword(arg='hooks_list', value=hooks)])))
util_global.set_value('need_conver', True)
return node
specs = {'TrainSpec': 2, 'EvalSpec': 3}
@@ -469,14 +465,15 @@ def ast_call(node):
hooks = keyword
break
if not hooks:
- node.keywords.append(
- ast.keyword(arg='hooks', value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[])))
+ node.keywords.append(ast.keyword(arg='hooks', value=pasta.parse('npu_hooks_append()')))
elif isinstance(hooks, ast.keyword):
- hooks.value = ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[
+ new_value = ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[
ast.keyword(arg='hooks_list', value=hooks.value)])
+ ast.copy_location(new_value, hooks.value)
+ hooks.value = new_value
else:
- node.keywords.append(
- ast.keyword(arg='hooks', value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()), args=[], keywords=[ast.keyword(arg='hooks_list', value=hooks)])))
+ node.keywords.append(ast.keyword(arg='hooks', value=ast.Call(func=ast.Name(id='npu_hooks_append', ctx=ast.Load()),
+ args=[], keywords=[ast.keyword(arg='hooks_list', value=hooks)])))
util_global.set_value('need_conver', True)
return node
@@ -508,42 +505,7 @@ def insert_npu_import(r_node):
r_node.body.insert(import_index, npu_import)
log_msg(import_index, "from npu_bridge.npu_init import *")
-def insert_npu_tf_opt_func(r_node):
- n = 0
- lenline = len(r_node.body)
-
- while n < lenline and not isinstance(r_node.body[n], ast.ImportFrom) and not isinstance(r_node.body[n], ast.Import):
- n += 1
-
- while n < lenline and (isinstance(r_node.body[n], ast.ImportFrom) or isinstance(r_node.body[n], ast.Import)):
- n += 1
-
- if n < lenline:
- npu_func = ast.Name(id="NPUDistributedOptimizer", ctx=ast.Load())
- assign_target = ast.Name(id="npu_opt", ctx=ast.Store())
- assign_args = ast.Name(id="opt", ctx=ast.Load())
- npu_opt = ast.Assign(targets=[assign_target], value=ast.Call(func=npu_func, args=[assign_args], keywords=[]))
- return_node = ast.Return(value=ast.Name(id='npu_opt', ctx=ast.Load()))
-
- r_node.body.insert(n, ast.FunctionDef(
- name='npu_tf_optimizer',
- args=ast.arguments(
- args=[
- ast.arg(arg='opt', annotation=None)
- ],
- vararg=None,
- kwonlyargs=[],
- kw_defaults=[],
- kwarg=None,
- defaults=[]),
- body=[
- npu_opt,
- return_node
- ],
- decorator_list=[],
- returns=None))
-
-def insert_npu_keras_opt_func(r_node):
+def insert_npu_resource_init(r_node):
n = 0
lenline = len(r_node.body)
@@ -554,31 +516,22 @@ def insert_npu_keras_opt_func(r_node):
n += 1
if n < lenline:
- npu_func = ast.Name(id="KerasDistributeOptimizer", ctx=ast.Load())
- assign_target = ast.Name(id="npu_opt", ctx=ast.Store())
- assign_args = ast.Name(id="opt", ctx=ast.Load())
- npu_opt = ast.Assign(targets=[assign_target], value=ast.Call(func=npu_func, args=[assign_args], keywords=[]))
- return_node = ast.Return(value=ast.Name(id='npu_opt', ctx=ast.Load()))
+ init_assign = ast.Assign(targets=[ast.Tuple(elts=[ast.Name(id="npu_sess", ctx=ast.Store()),
+ ast.Name(id="npu_shutdown", ctx=ast.Store())],
+ ctx=ast.Store())],
+ value=ast.Call(func=ast.Name(id="init_resource", ctx=ast.Load()), args=[], keywords=[]))
+ r_node.body.insert(n, init_assign)
- r_node.body.insert(n, ast.FunctionDef(
- name='npu_keras_optimizer',
- args=ast.arguments(
- args=[
- ast.arg(arg='opt', annotation=None)
- ],
- vararg=None,
- kwonlyargs=[],
- kw_defaults=[],
- kwarg=None,
- defaults=[]),
- body=[
- npu_opt,
- return_node
- ],
- decorator_list=[],
- returns=None))
+def insert_npu_resource_shutdown(r_node):
+ shutdown_call = ast.Expr(value=ast.Call(func=ast.Name(id="shutdown_resource", ctx=ast.Load()),
+ args=[ast.Name(id="npu_sess", ctx=ast.Load()), ast.Name(id="npu_shutdown", ctx=ast.Load())],
+ keywords=[]))
+ close_sess_call = ast.Expr(value=ast.Call(func=ast.Name(id="close_session", ctx=ast.Load()),
+ args=[ast.Name(id="npu_sess", ctx=ast.Load())], keywords=[]))
+ r_node.body.append(shutdown_call)
+ r_node.body.append(close_sess_call)
-def insert_empty_hook(r_node):
+def insert_keras_sess_npu_config(r_node):
n = 0
lenline = len(r_node.body)
@@ -589,26 +542,15 @@ def insert_empty_hook(r_node):
n += 1
if n < lenline:
- hook_attr = ast.Attribute(value=ast.Attribute(value=ast.Name(id="tf", ctx=ast.Load()), attr="train", ctx=ast.Load()),
- attr="SessionRunHook", ctx=ast.Load())
- class_def = ast.ClassDef(name="NpuEmptyHook", bases=[hook_attr], keywords=[],
- body=[ast.Pass()], decorator_list=[])
- r_node.body.insert(n, class_def)
+ keras_sess_assign = ast.Assign(targets=[ast.Name(id="npu_keras_sess", ctx=ast.Store())],
+ value=ast.Call(func=ast.Name(id="set_keras_session_npu_config", ctx=ast.Load()),
+ args=[], keywords=[]))
+ r_node.body.insert(n, keras_sess_assign)
-def ast_assign(node):
- for target in node.targets:
- if (isinstance(target, ast.Name) and target.id == 'global_jit_level') or (isinstance(target, ast.Attribute) and target.attr == 'global_jit_level'):
- log_msg(getattr(node, 'lineno', 'None'), 'set global_jit_level=config_pb2.OptimizerOptions.OFF')
- util_global.set_value('need_conver', True)
- global_jit_level_assign_node = ast.Assign(
- targets=node.targets,
- ctx=ast.Load(),
- value=ast.Attribute(attr='OFF', ctx=ast.Load(),
- value=ast.Attribute(attr='OptimizerOptions', ctx=ast.Load(),
- value=ast.Name(id='config_pb2', ctx=ast.Load()))))
- node = ast.If(test=ast.NameConstant(value=True), body=[global_jit_level_assign_node], orelse=[])
- return node
- return node
+def insert_keras_sess_close(r_node):
+ close_sess_call = ast.Expr(value=ast.Call(func=ast.Name(id="close_session", ctx=ast.Load()),
+ args=[ast.Name(id="npu_keras_sess", ctx=ast.Load())], keywords=[]))
+ r_node.body.append(close_sess_call)
# Format printing for locate
def node_tree(node:str):
diff --git a/convert_tf2npu/conver.py b/convert_tf2npu/conver.py
index 957c9adcbbbf48e01ed519aa0e77961b28b8ce4d..fb00b69ed4a7d5726ba9df44d27330696b54b13e 100644
--- a/convert_tf2npu/conver.py
+++ b/convert_tf2npu/conver.py
@@ -36,7 +36,7 @@ def conver():
for path, dir_list, file_list in conver_path:
for file_name in file_list:
- out_path_dst = abs_join(dst_path_new, path.split(dst_path)[1])
+ out_path_dst = abs_join(dst_path_new, path.split(util_global.get_value('input'))[1])
file_path = os.path.join(path, file_name).replace('\\', '/')
content = "Begin conver file: " + file_path
print(content)
diff --git a/convert_tf2npu/conver_by_ast.py b/convert_tf2npu/conver_by_ast.py
index 0bdf175dc3f8189f6f47b64174c8775f78f70826..3788b6cc3f364cce411673eb6ac9872af3a2963a 100644
--- a/convert_tf2npu/conver_by_ast.py
+++ b/convert_tf2npu/conver_by_ast.py
@@ -15,25 +15,13 @@
import os
import sys
import ast
-import astunparse
+import pasta
import util_global
from file_op import write_output_after_conver
from file_op import write_report_after_conver
from file_op import scan_file
-from util import log_success_report
-from util import log_migration_report
-from ast_impl import attribute
-from ast_impl import node_tree
-from ast_impl import insert_npu_import
-from ast_impl import insert_npu_tf_opt_func
-from ast_impl import insert_npu_keras_opt_func
-from ast_impl import insert_empty_hook
-from ast_impl import import_from
-from ast_impl import ast_import
-from ast_impl import ast_function_def
-from ast_impl import ast_call
-from ast_impl import ast_assign
-from ast_impl import ast_if
+from util import *
+from ast_impl import *
from visit_by_ast import get_tf_api
class ConverByAst(ast.NodeTransformer):
@@ -75,11 +63,6 @@ class ConverByAst(ast.NodeTransformer):
return node
def visit_Assign(self, node):
- for target in node.targets:
- if (isinstance(target, ast.Name) and target.id == 'global_jit_level') or (isinstance(target, ast.Attribute) and target.attr == 'global_jit_level'):
- return ast_assign(node)
-
- ast_assign(node)
self.generic_visit(node)
return node
@@ -90,16 +73,16 @@ class ConverByAst(ast.NodeTransformer):
def conver_ast(path, out_path_dst, file_name):
util_global.set_value('need_conver', False)
- util_global.set_value('insert_estimator_add_hook_func', False)
- util_global.set_value('insert_npu_tf_opt_func', False)
- util_global.set_value('insert_npu_keras_opt_func', False)
- util_global.set_value('insert_empty_hook', False)
util_global.set_value('is_keras_net', False)
- util_global.set_value('is_hvd_net', False)
+ util_global.set_value('has_hccl_api', False)
+ util_global.set_value('is_main_file', False)
+ util_global.set_value('has_main_func', False)
+ if os.path.join(path, file_name) == util_global.get_value('main', ""):
+ util_global.set_value('is_main_file', True)
with open(os.path.join(path, file_name), "r", encoding='utf-8') as file:
source = file.read()
try:
- r_node = ast.parse(source)
+ r_node = pasta.parse(source)
except Exception as e:
print(repr(e))
return
@@ -116,13 +99,17 @@ def conver_ast(path, out_path_dst, file_name):
if util_global.get_value('need_conver', False):
insert_npu_import(r_node)
- if util_global.get_value('insert_npu_tf_opt_func', False):
- insert_npu_tf_opt_func(r_node)
- if util_global.get_value('insert_npu_keras_opt_func', False):
- insert_npu_keras_opt_func(r_node)
- if util_global.get_value('insert_empty_hook', False):
- insert_empty_hook(r_node)
- dst_content = astunparse.unparse(r_node)
+ if not util_global.get_value('has_main_func', False) and (util_global.get_value('has_hccl_api', False)
+ or util_global.get_value('is_keras_net', False)):
+ log_warning('the network of keras and horovod, or using dataset.shard script do not have main func, '
+ 'should set -m or --main parameter')
+ if util_global.get_value('is_main_file', False) and util_global.get_value('has_hccl_api', False):
+ insert_npu_resource_init(r_node)
+ insert_npu_resource_shutdown(r_node)
+ if util_global.get_value('is_main_file', False) and util_global.get_value('is_keras_net', False):
+ insert_keras_sess_npu_config(r_node)
+ insert_keras_sess_close(r_node)
+ dst_content = pasta.dump(r_node)
write_output_after_conver(os.path.join(util_global.get_value('output'), out_path_dst, file_name), dst_content)
if file_name.endswith("a.py"):
diff --git a/convert_tf2npu/file_op.py b/convert_tf2npu/file_op.py
index 41db221134b10ea23b690b9cf9ee138741e09398..fc4264d30a193b6cfb424c163df811e4ac71450b 100644
--- a/convert_tf2npu/file_op.py
+++ b/convert_tf2npu/file_op.py
@@ -17,6 +17,7 @@ import shutil
import util_global
import pandas as pd
from visit_by_ast import get_tf_enume
+from visit_by_ast import get_unsupport_api
def before_clear():
exit_folder = os.path.exists(util_global.get_value('output'))
@@ -131,6 +132,18 @@ def scan_file(path, file_name, api, lineno):
support_type.append(api_support[api_name.index(class_name)])
migrate_advice.append(api_advice[api_name.index(class_name)])
+ # record unsupported APIs
+ (unsupport, lineno) = get_unsupport_api(os.path.join(path, file_name))
+ for i in range(len(unsupport)):
+ name = unsupport[i]
+ module = name.split('.')[0]
+ script_name.append(file_name)
+ code_api.append(name)
+ code_line.append(lineno[i])
+ code_module.append(module)
+ support_type.append('不支持(无迁移方案,建议用户不使用)')
+ migrate_advice.append('第三方非TF官网API,暂不支持')
+
analyse_result = pd.DataFrame({'脚本文件名': script_name, '代码行': code_line,
'模块名': code_module, 'API名': code_api,
'工具迁移API支持度': support_type, '说明': migrate_advice})
diff --git a/convert_tf2npu/main.py b/convert_tf2npu/main.py
index e890640e916facf3163c78f9651f717f3d603168..0afe005e13c7fb1abdce52125d44b6308c86e586 100644
--- a/convert_tf2npu/main.py
+++ b/convert_tf2npu/main.py
@@ -20,40 +20,45 @@ from file_op import before_clear
from conver import conver
def para_check_and_set(argv):
- input = "input"
- list = "tf1.15_api_support_list.xlsx"
+ input_dir = "npu_input"
+ support_list = "tf1.15_api_support_list.xlsx"
output = "output" + util_global.get_value('timestap')
report = "report" + util_global.get_value('timestap')
report_suffix = report
+ main_file = ""
try:
- opts, args = getopt.getopt(argv, "hi:l:o:r:", ["help", "input=", "list=", "output=", "report="])
+ opts, args = getopt.getopt(argv, "hi:l:o:r:m:", ["help", "input=", "list=", "output=", "report=", "main="])
except getopt.GetoptError:
print('Parameter error, please check.')
- print(' main.py -i -l -o