diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5a1d6ab72b64b6da0f987a898ed0f4b242ea4605
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Datasource local storage ignored files
+/../../../../../:\00gitee\00tfadapter\tensorflow\.idea/dataSources/
+/dataSources.local.xml
+# Editor-based HTTP Client requests
+/httpRequests/
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..79b3c94830bab93d40d0770f2765540fe24ed423
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..2bda9d8e1a0910c478191c7ee0ddf347d8d14854
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/tensorflow.iml b/.idea/tensorflow.iml
new file mode 100644
index 0000000000000000000000000000000000000000..f08604bb65b25149b195f9e9f282f9683a428592
--- /dev/null
+++ b/.idea/tensorflow.iml
@@ -0,0 +1,2 @@
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/cmake-build-debug/CMakeFiles/clion-log.txt b/cmake-build-debug/CMakeFiles/clion-log.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c7ba5e8f9efad023df0acf3c59e29fd413f530aa
--- /dev/null
+++ b/cmake-build-debug/CMakeFiles/clion-log.txt
@@ -0,0 +1 @@
+Toolchains are not configured Configure
diff --git a/conver_tf2npu/README.md b/conver_tf2npu/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..afc720d8a8bcd936da0eacecfa923ed974ec11ec
--- /dev/null
+++ b/conver_tf2npu/README.md
@@ -0,0 +1,25 @@
+# Tensorflow Adapter For Ascend
+
+[View English](README.en.md)
+
+Tensorflow Adapter For Ascend(简称TF_Adapter)致力于将昇腾AI处理器卓越的运算能力,便捷地提供给使用Tensorflow框架的开发者。
+开发者只需安装TF_Adapter插件,并在现有Tensorflow脚本中添加少量配置,即可实现在昇腾AI处理器上加速自己的训练任务。
+
+
+
+您可以通过阅读 [TF_Adapter接口文档](https://support.huaweicloud.com/mprtg-A800_9000_9010/atlasprtg_13_0013.html) 获取更多使用细节。
+
+脚本使用指导:
+待补充
+
+## 贡献
+
+欢迎参与贡献。
+
+## Release Notes
+
+Release Notes请参考[RELEASE](RELEASE.md).
+
+## License
+
+[Apache License 2.0](LICENSE)
diff --git a/conver_tf2npu/ast_impl.py b/conver_tf2npu/ast_impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0c48a2fec8afaadf10557aa65e10d5bce4c48f5
--- /dev/null
+++ b/conver_tf2npu/ast_impl.py
@@ -0,0 +1,120 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless REQUIRED by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import ast
+import util_global
+from util import log_success_report
+from util import log_migration_report
+
+def attribute(node):
+ log_success_report(getattr(node, "lineno", "None"), node.attr)
+ if node.attr == 'dropout':
+ node.value = ast.Name(id=util_global.get_value(node.attr)[0], ctx=ast.Load())
+ else:
+ node = ast.Name(id=util_global.get_value(node.attr)[0], ctx=ast.Load())
+ util_global.set_value('need_conver', True)
+ return node
+
+def import_from(node):
+ if node.module != None:
+ values = node.module.split(".")
+ if "keras" in values:
+ log_migration_report(getattr(node, "lineno", "None"), "keras")
+ util_global.set_value('need_conver', True)
+
+def ast_import(node):
+ for value in node.names:
+ if isinstance(value, ast.alias):
+ values = value.name.split(".")
+ if "keras" in values:
+ log_migration_report(getattr(node, "lineno", "None"), "keras")
+ util_global.set_value('need_conver', True)
+
+def ast_function_def(node):
+ log_success_report(getattr(node, "lineno", "None"), node.name)
+ node.body = [ast.Return(value=ast.Call(
+ func=ast.Attribute(value=ast.Name(id=util_global.get_value(node.name)[0],
+ ctx=ast.Load()), attr='gelu',
+ ctx=ast.Load()),
+ args=[ast.Name(id='x', ctx=ast.Load())],
+ keywords=[]))]
+
+ util_global.set_value('need_conver', True)
+ return node
+
+def ast_call(node):
+ if isinstance(node.func, ast.Attribute):
+ if len(node.args) > 0:
+ if isinstance(node.args[0], ast.Call):
+ if isinstance(node.args[0].func, ast.Attribute):
+ if node.args[0].func.attr == 'BroadcastGlobalVariablesHook':
+ log_success_report(getattr(node, "lineno", "None"), 'BroadcastGlobalVariablesHook')
+ node.func = ast.Name(id=util_global.get_value('BroadcastGlobalVariablesHook')[0], ctx=ast.Load())
+ node.args = []
+ util_global.set_value('need_conver', True)
+ if isinstance(node.func, ast.Attribute) and node.func.attr == 'shard':
+ log_success_report(getattr(node, "lineno", "None"), 'shard')
+ node.args = [ast.Call(func=ast.Name(id='get_rank_size', ctx=ast.Load()), args=[], keywords=[]),
+ ast.Call(func=ast.Name(id='get_rank_id', ctx=ast.Load()), args=[], keywords=[])]
+ util_global.set_value('need_conver', True)
+ if isinstance(node.func, ast.Attribute) and (node.func.attr == 'batch' or node.func.attr == 'map_and_batch'):
+ exist = False
+ for keyword in node.keywords:
+ if keyword.arg == 'drop_remainder':
+ exist = True
+ if ((isinstance(keyword.value, ast.NameConstant) and keyword.value.value != True) or
+ (not isinstance(keyword.value, ast.NameConstant))):
+ log_success_report(getattr(node, "lineno", "None"), node.func.attr)
+ keyword.value = ast.NameConstant(value=True)
+ util_global.set_value('need_conver', True)
+ if not exist:
+ log_success_report(getattr(node, "lineno", "None"), node.func.attr)
+ keyword = ast.keyword(arg='drop_remainder', value=ast.NameConstant(value=True))
+ node.keywords.insert(0, keyword)
+ util_global.set_value('need_conver', True)
+ if (isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and
+ node.func.value.id == 'tf' and node.func.attr == 'device'):
+ log_success_report(getattr(node, "lineno", "None"), node.func.attr)
+ node.args = [ast.Str(s='/cpu:0')]
+ util_global.set_value('need_conver', True)
+ return node
+
+def insert_npu_import(r_node):
+ npu_alias = ast.alias(name='npu_ops', asname=None)
+ npu_import = ast.ImportFrom(module='npu_bridge.estimator', names=[npu_alias], level=0)
+ r_node.body.insert(0, npu_import)
+ npu_alias = ast.alias(name='npu_unary_ops', asname=None)
+ npu_import = ast.ImportFrom(module='npu_bridge.estimator.npu_unary_ops', names=[npu_alias], level=0)
+ r_node.body.insert(0, npu_import)
+
+# Format printing for locate
+def node_tree(node:str):
+ str2list = list(node.replace(' ', ''))
+ count = 0
+ for i, e in enumerate(str2list):
+ if e == '(':
+ count += 1
+ str2list[i] = '(\n{}'.format('| ' * count)
+ elif e == ')':
+ count -= 1
+ str2list[i] = '\n{})'.format('| ' * count)
+ elif e == ',':
+ str2list[i] = ',\n{}'.format('| ' * count)
+ elif e == '[':
+ count += 1
+ str2list[i] = '[\n{}'.format('| ' * count)
+ elif e == ']':
+ count -= 1
+ str2list[i] = '\n{}]'.format('| ' * count)
+ return ''.join(str2list)
\ No newline at end of file
diff --git a/conver_tf2npu/conver.py b/conver_tf2npu/conver.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf4acd977e6e5d2aa376ffee65683c7793124220
--- /dev/null
+++ b/conver_tf2npu/conver.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless REQUIRED by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import util_global
+from conver_by_ast import conver_ast
+from file_op import mkdir
+from file_op import mkdir_and_copyfile
+from file_op import write_report_terminator
+from file_op import abs_join
+
+def conver():
+ print("Begin conver, input file: " + util_global.get_value('input'))
+ out_path = util_global.get_value('output')
+ dst_path = os.path.split(util_global.get_value('input').rstrip('\\/'))[-1]
+ conver_path = os.walk(util_global.get_value('input'))
+ for path,dir_list,file_list in conver_path:
+ for file_name in file_list:
+ out_path_dst = abs_join(dst_path, path.split(dst_path)[1])
+ if file_name.endswith(".py"):
+ util_global.set_value('path', os.path.join(path, file_name))
+ mkdir(os.path.join(out_path, out_path_dst))
+ conver_ast(path, out_path_dst, file_name)
+ if util_global.get_value('need_conver', False):
+ content = "Finish conver file: " + os.path.join(path, file_name)
+ print(content)
+ write_report_terminator(content)
+ else:
+ mkdir_and_copyfile(path, abs_join(out_path, out_path_dst), file_name)
+ else:
+ mkdir_and_copyfile(path, abs_join(out_path, out_path_dst), file_name)
+
+ print("Finish conver, output file: " + out_path + "; report file: " + util_global.get_value('report'))
diff --git a/conver_tf2npu/conver_by_ast.py b/conver_tf2npu/conver_by_ast.py
new file mode 100644
index 0000000000000000000000000000000000000000..16a1aaaf060356b0b18c1dc3f70711c045c0e628
--- /dev/null
+++ b/conver_tf2npu/conver_by_ast.py
@@ -0,0 +1,89 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless REQUIRED by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import sys
+import ast
+import astunparse
+import util_global
+from file_op import write_output_after_conver
+from file_op import write_report_after_conver
+from util import log_success_report
+from ast_impl import attribute
+from ast_impl import node_tree
+from ast_impl import insert_npu_import
+from ast_impl import import_from
+from ast_impl import ast_import
+from ast_impl import ast_function_def
+from ast_impl import ast_call
+
+class ConverByAst(ast.NodeTransformer):
+ def generic_visit(self, node):
+ ast.NodeTransformer.generic_visit(self, node)
+ return node
+ def visit_Attribute(self, node):
+ if node.attr in util_global.get_value('nn') and isinstance(node.value, ast.Attribute):
+ if node.value.attr == 'nn':
+ return attribute(node)
+ if node.attr in util_global.get_value('estimator') and isinstance(node.value, ast.Attribute):
+ if node.value.attr == 'estimator':
+ return attribute(node)
+ if node.attr in util_global.get_value('hvd'):
+ if isinstance(node.value, ast.Name):
+ if 'hvd' in str(node.value.id):
+ return attribute(node)
+ if isinstance(node.value, ast.Attribute):
+ if 'hvd' in str(node.value.attr):
+ return attribute(node)
+ return node
+
+ def visit_FunctionDef(self, node):
+ if node.name == 'gelu':
+ return ast_function_def(node)
+ self.generic_visit(node)
+ return node
+
+ def visit_Call(self, node):
+ node = ast_call(node)
+ self.generic_visit(node)
+ return node
+
+ def visit_ImportFrom(self, node):
+ import_from(node)
+ self.generic_visit(node)
+ return node
+
+ def visit_Import(self, node):
+ ast_import(node)
+ self.generic_visit(node)
+ return node
+
+def conver_ast(path, out_path_dst, file_name):
+ util_global.set_value('need_conver', False)
+ file = open(os.path.join(path, file_name), "r")
+ source = file.read()
+ r_node = ast.parse(source)
+
+ sys.setrecursionlimit(10000)
+ visitor = ConverByAst()
+ visitor.visit(r_node)
+ ast.fix_missing_locations(r_node)
+
+ if util_global.get_value('need_conver', False):
+ insert_npu_import(r_node)
+ dst_content = astunparse.unparse(r_node)
+ write_output_after_conver(os.path.join(util_global.get_value('output'), out_path_dst, file_name), dst_content)
+
+ if file_name.endswith("a.py"):
+ write_report_after_conver("only_for_test", file_name, node_tree(ast.dump(r_node)))
\ No newline at end of file
diff --git a/conver_tf2npu/file_op.py b/conver_tf2npu/file_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ede41965fd33eb1ab5885fb9332b2c53ca13f31
--- /dev/null
+++ b/conver_tf2npu/file_op.py
@@ -0,0 +1,69 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless REQUIRED by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import os
+import shutil
+import util_global
+
+def before_clear():
+ exit_folder = os.path.exists(util_global.get_value('output'))
+ if exit_folder:
+ shutil.rmtree(util_global.get_value('output'))
+ exit_folder = os.path.exists(util_global.get_value('report'))
+ if exit_folder:
+ shutil.rmtree(util_global.get_value('report'))
+
+def mkdir(path):
+ folder = os.path.exists(path)
+ if not folder:
+ os.makedirs(path)
+
+def mkdir_and_copyfile(srcfile, dstpath, file_name):
+ mkdir(dstpath)
+ shutil.copyfile(os.path.join(srcfile, file_name), os.path.join(dstpath, file_name))
+
+def write_output_after_conver(out_file, dst_content):
+ file = open(out_file, 'w')
+ file.write(dst_content)
+ file.close()
+
+def write_report_after_conver(new_file_path, report_file, dst_content):
+ mkdir(new_file_path)
+ file = open(os.path.join(new_file_path, report_file), 'w')
+ file.write(dst_content)
+ file.close()
+
+def write_report_terminator(content):
+ report_path = util_global.get_value('report')
+ for file in util_global.get_value('report_file'):
+ if os.path.exists(os.path.join(report_path, file)):
+ file = open(os.path.join(report_path, file), 'a')
+ file.write(content)
+ file.write("\r\n")
+ file.write("\r\n")
+ file.close()
+
+def write_conver_report(content, file):
+ report_path = util_global.get_value('report')
+ mkdir(report_path)
+ file = open(os.path.join(report_path, file), 'a')
+ file.write(content)
+ file.write("\r\n")
+ file.close()
+
+def abs_join(abs1, abs2):
+ abs2 = os.fspath(abs2)
+ abs2 = os.path.splitdrive(abs2)[1]
+ abs2 = abs2.strip('\\/') or abs2
+ return os.path.join(abs1, abs2)
\ No newline at end of file
diff --git a/conver_tf2npu/main.py b/conver_tf2npu/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ed7f3d042a75dc477e1a90ce6cc06c7870e9632
--- /dev/null
+++ b/conver_tf2npu/main.py
@@ -0,0 +1,59 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless REQUIRED by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import sys
+import getopt
+import util_global
+from file_op import before_clear
+from conver import conver
+
+def para_check_and_set(argv):
+ input = "input"
+ output = "output"
+ report = "report"
+
+ try:
+ opts, args = getopt.getopt(argv, "hi:o:r:", ["help", "input=", "output=", "report="])
+ except getopt.GetoptError:
+ print('Parameter error, please check.')
+ print(' main.py -i -o