diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5a1d6ab72b64b6da0f987a898ed0f4b242ea4605 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Datasource local storage ignored files +/../../../../../:\00gitee\00tfadapter\tensorflow\.idea/dataSources/ +/dataSources.local.xml +# Editor-based HTTP Client requests +/httpRequests/ diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000000000000000000000000000000000000..79b3c94830bab93d40d0770f2765540fe24ed423 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..2bda9d8e1a0910c478191c7ee0ddf347d8d14854 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/tensorflow.iml b/.idea/tensorflow.iml new file mode 100644 index 0000000000000000000000000000000000000000..f08604bb65b25149b195f9e9f282f9683a428592 --- /dev/null +++ b/.idea/tensorflow.iml @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/cmake-build-debug/CMakeFiles/clion-log.txt b/cmake-build-debug/CMakeFiles/clion-log.txt new file mode 100644 index 0000000000000000000000000000000000000000..c7ba5e8f9efad023df0acf3c59e29fd413f530aa --- /dev/null +++ b/cmake-build-debug/CMakeFiles/clion-log.txt @@ -0,0 +1 @@ +Toolchains are not configured Configure diff --git a/conver_tf2npu/README.md b/conver_tf2npu/README.md new file mode 100644 index 0000000000000000000000000000000000000000..afc720d8a8bcd936da0eacecfa923ed974ec11ec --- /dev/null +++ b/conver_tf2npu/README.md @@ -0,0 +1,25 @@ +# Tensorflow Adapter For Ascend + +[View English](README.en.md) + +Tensorflow Adapter For Ascend(简称TF_Adapter)致力于将昇腾AI处理器卓越的运算能力,便捷地提供给使用Tensorflow框架的开发者。 +开发者只需安装TF_Adapter插件,并在现有Tensorflow脚本中添加少量配置,即可实现在昇腾AI处理器上加速自己的训练任务。 + +![tfadapter](https://images.gitee.com/uploads/images/2020/1027/094640_8f305b88_8175427.jpeg "framework.jpg") + +您可以通过阅读 [TF_Adapter接口文档](https://support.huaweicloud.com/mprtg-A800_9000_9010/atlasprtg_13_0013.html) 获取更多使用细节。 + +脚本使用指导: +待补充 + +## 贡献 + +欢迎参与贡献。 + +## Release Notes + +Release Notes请参考[RELEASE](RELEASE.md). + +## License + +[Apache License 2.0](LICENSE) diff --git a/conver_tf2npu/ast_impl.py b/conver_tf2npu/ast_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c48a2fec8afaadf10557aa65e10d5bce4c48f5 --- /dev/null +++ b/conver_tf2npu/ast_impl.py @@ -0,0 +1,120 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import ast +import util_global +from util import log_success_report +from util import log_migration_report + +def attribute(node): + log_success_report(getattr(node, "lineno", "None"), node.attr) + if node.attr == 'dropout': + node.value = ast.Name(id=util_global.get_value(node.attr)[0], ctx=ast.Load()) + else: + node = ast.Name(id=util_global.get_value(node.attr)[0], ctx=ast.Load()) + util_global.set_value('need_conver', True) + return node + +def import_from(node): + if node.module != None: + values = node.module.split(".") + if "keras" in values: + log_migration_report(getattr(node, "lineno", "None"), "keras") + util_global.set_value('need_conver', True) + +def ast_import(node): + for value in node.names: + if isinstance(value, ast.alias): + values = value.name.split(".") + if "keras" in values: + log_migration_report(getattr(node, "lineno", "None"), "keras") + util_global.set_value('need_conver', True) + +def ast_function_def(node): + log_success_report(getattr(node, "lineno", "None"), node.name) + node.body = [ast.Return(value=ast.Call( + func=ast.Attribute(value=ast.Name(id=util_global.get_value(node.name)[0], + ctx=ast.Load()), attr='gelu', + ctx=ast.Load()), + args=[ast.Name(id='x', ctx=ast.Load())], + keywords=[]))] + + util_global.set_value('need_conver', True) + return node + +def ast_call(node): + if isinstance(node.func, ast.Attribute): + if len(node.args) > 0: + if isinstance(node.args[0], ast.Call): + if isinstance(node.args[0].func, ast.Attribute): + if node.args[0].func.attr == 'BroadcastGlobalVariablesHook': + log_success_report(getattr(node, "lineno", "None"), 'BroadcastGlobalVariablesHook') + node.func = ast.Name(id=util_global.get_value('BroadcastGlobalVariablesHook')[0], ctx=ast.Load()) + node.args = [] + util_global.set_value('need_conver', True) + if isinstance(node.func, ast.Attribute) and node.func.attr == 'shard': + log_success_report(getattr(node, "lineno", "None"), 'shard') + node.args = [ast.Call(func=ast.Name(id='get_rank_size', ctx=ast.Load()), args=[], keywords=[]), + ast.Call(func=ast.Name(id='get_rank_id', ctx=ast.Load()), args=[], keywords=[])] + util_global.set_value('need_conver', True) + if isinstance(node.func, ast.Attribute) and (node.func.attr == 'batch' or node.func.attr == 'map_and_batch'): + exist = False + for keyword in node.keywords: + if keyword.arg == 'drop_remainder': + exist = True + if ((isinstance(keyword.value, ast.NameConstant) and keyword.value.value != True) or + (not isinstance(keyword.value, ast.NameConstant))): + log_success_report(getattr(node, "lineno", "None"), node.func.attr) + keyword.value = ast.NameConstant(value=True) + util_global.set_value('need_conver', True) + if not exist: + log_success_report(getattr(node, "lineno", "None"), node.func.attr) + keyword = ast.keyword(arg='drop_remainder', value=ast.NameConstant(value=True)) + node.keywords.insert(0, keyword) + util_global.set_value('need_conver', True) + if (isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name) and + node.func.value.id == 'tf' and node.func.attr == 'device'): + log_success_report(getattr(node, "lineno", "None"), node.func.attr) + node.args = [ast.Str(s='/cpu:0')] + util_global.set_value('need_conver', True) + return node + +def insert_npu_import(r_node): + npu_alias = ast.alias(name='npu_ops', asname=None) + npu_import = ast.ImportFrom(module='npu_bridge.estimator', names=[npu_alias], level=0) + r_node.body.insert(0, npu_import) + npu_alias = ast.alias(name='npu_unary_ops', asname=None) + npu_import = ast.ImportFrom(module='npu_bridge.estimator.npu_unary_ops', names=[npu_alias], level=0) + r_node.body.insert(0, npu_import) + +# Format printing for locate +def node_tree(node:str): + str2list = list(node.replace(' ', '')) + count = 0 + for i, e in enumerate(str2list): + if e == '(': + count += 1 + str2list[i] = '(\n{}'.format('| ' * count) + elif e == ')': + count -= 1 + str2list[i] = '\n{})'.format('| ' * count) + elif e == ',': + str2list[i] = ',\n{}'.format('| ' * count) + elif e == '[': + count += 1 + str2list[i] = '[\n{}'.format('| ' * count) + elif e == ']': + count -= 1 + str2list[i] = '\n{}]'.format('| ' * count) + return ''.join(str2list) \ No newline at end of file diff --git a/conver_tf2npu/conver.py b/conver_tf2npu/conver.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4acd977e6e5d2aa376ffee65683c7793124220 --- /dev/null +++ b/conver_tf2npu/conver.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import util_global +from conver_by_ast import conver_ast +from file_op import mkdir +from file_op import mkdir_and_copyfile +from file_op import write_report_terminator +from file_op import abs_join + +def conver(): + print("Begin conver, input file: " + util_global.get_value('input')) + out_path = util_global.get_value('output') + dst_path = os.path.split(util_global.get_value('input').rstrip('\\/'))[-1] + conver_path = os.walk(util_global.get_value('input')) + for path,dir_list,file_list in conver_path: + for file_name in file_list: + out_path_dst = abs_join(dst_path, path.split(dst_path)[1]) + if file_name.endswith(".py"): + util_global.set_value('path', os.path.join(path, file_name)) + mkdir(os.path.join(out_path, out_path_dst)) + conver_ast(path, out_path_dst, file_name) + if util_global.get_value('need_conver', False): + content = "Finish conver file: " + os.path.join(path, file_name) + print(content) + write_report_terminator(content) + else: + mkdir_and_copyfile(path, abs_join(out_path, out_path_dst), file_name) + else: + mkdir_and_copyfile(path, abs_join(out_path, out_path_dst), file_name) + + print("Finish conver, output file: " + out_path + "; report file: " + util_global.get_value('report')) diff --git a/conver_tf2npu/conver_by_ast.py b/conver_tf2npu/conver_by_ast.py new file mode 100644 index 0000000000000000000000000000000000000000..16a1aaaf060356b0b18c1dc3f70711c045c0e628 --- /dev/null +++ b/conver_tf2npu/conver_by_ast.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import sys +import ast +import astunparse +import util_global +from file_op import write_output_after_conver +from file_op import write_report_after_conver +from util import log_success_report +from ast_impl import attribute +from ast_impl import node_tree +from ast_impl import insert_npu_import +from ast_impl import import_from +from ast_impl import ast_import +from ast_impl import ast_function_def +from ast_impl import ast_call + +class ConverByAst(ast.NodeTransformer): + def generic_visit(self, node): + ast.NodeTransformer.generic_visit(self, node) + return node + def visit_Attribute(self, node): + if node.attr in util_global.get_value('nn') and isinstance(node.value, ast.Attribute): + if node.value.attr == 'nn': + return attribute(node) + if node.attr in util_global.get_value('estimator') and isinstance(node.value, ast.Attribute): + if node.value.attr == 'estimator': + return attribute(node) + if node.attr in util_global.get_value('hvd'): + if isinstance(node.value, ast.Name): + if 'hvd' in str(node.value.id): + return attribute(node) + if isinstance(node.value, ast.Attribute): + if 'hvd' in str(node.value.attr): + return attribute(node) + return node + + def visit_FunctionDef(self, node): + if node.name == 'gelu': + return ast_function_def(node) + self.generic_visit(node) + return node + + def visit_Call(self, node): + node = ast_call(node) + self.generic_visit(node) + return node + + def visit_ImportFrom(self, node): + import_from(node) + self.generic_visit(node) + return node + + def visit_Import(self, node): + ast_import(node) + self.generic_visit(node) + return node + +def conver_ast(path, out_path_dst, file_name): + util_global.set_value('need_conver', False) + file = open(os.path.join(path, file_name), "r") + source = file.read() + r_node = ast.parse(source) + + sys.setrecursionlimit(10000) + visitor = ConverByAst() + visitor.visit(r_node) + ast.fix_missing_locations(r_node) + + if util_global.get_value('need_conver', False): + insert_npu_import(r_node) + dst_content = astunparse.unparse(r_node) + write_output_after_conver(os.path.join(util_global.get_value('output'), out_path_dst, file_name), dst_content) + + if file_name.endswith("a.py"): + write_report_after_conver("only_for_test", file_name, node_tree(ast.dump(r_node))) \ No newline at end of file diff --git a/conver_tf2npu/file_op.py b/conver_tf2npu/file_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6ede41965fd33eb1ab5885fb9332b2c53ca13f31 --- /dev/null +++ b/conver_tf2npu/file_op.py @@ -0,0 +1,69 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import shutil +import util_global + +def before_clear(): + exit_folder = os.path.exists(util_global.get_value('output')) + if exit_folder: + shutil.rmtree(util_global.get_value('output')) + exit_folder = os.path.exists(util_global.get_value('report')) + if exit_folder: + shutil.rmtree(util_global.get_value('report')) + +def mkdir(path): + folder = os.path.exists(path) + if not folder: + os.makedirs(path) + +def mkdir_and_copyfile(srcfile, dstpath, file_name): + mkdir(dstpath) + shutil.copyfile(os.path.join(srcfile, file_name), os.path.join(dstpath, file_name)) + +def write_output_after_conver(out_file, dst_content): + file = open(out_file, 'w') + file.write(dst_content) + file.close() + +def write_report_after_conver(new_file_path, report_file, dst_content): + mkdir(new_file_path) + file = open(os.path.join(new_file_path, report_file), 'w') + file.write(dst_content) + file.close() + +def write_report_terminator(content): + report_path = util_global.get_value('report') + for file in util_global.get_value('report_file'): + if os.path.exists(os.path.join(report_path, file)): + file = open(os.path.join(report_path, file), 'a') + file.write(content) + file.write("\r\n") + file.write("\r\n") + file.close() + +def write_conver_report(content, file): + report_path = util_global.get_value('report') + mkdir(report_path) + file = open(os.path.join(report_path, file), 'a') + file.write(content) + file.write("\r\n") + file.close() + +def abs_join(abs1, abs2): + abs2 = os.fspath(abs2) + abs2 = os.path.splitdrive(abs2)[1] + abs2 = abs2.strip('\\/') or abs2 + return os.path.join(abs1, abs2) \ No newline at end of file diff --git a/conver_tf2npu/main.py b/conver_tf2npu/main.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed7f3d042a75dc477e1a90ce6cc06c7870e9632 --- /dev/null +++ b/conver_tf2npu/main.py @@ -0,0 +1,59 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import sys +import getopt +import util_global +from file_op import before_clear +from conver import conver + +def para_check_and_set(argv): + input = "input" + output = "output" + report = "report" + + try: + opts, args = getopt.getopt(argv, "hi:o:r:", ["help", "input=", "output=", "report="]) + except getopt.GetoptError: + print('Parameter error, please check.') + print(' main.py -i -o -r ') + print('or: main.py --input= --output= --report=') + print('-i or --input: The source script to be converted, Default value: input/') + print('-o or --output: The destination script after converted, Default value: output/') + print('-r or --report: Conversion report, Default value: report/') + sys.exit(2) + + for opt, arg in opts: + if opt in ("-h", "--help"): + print(' main.py -i -o -r ') + print('or: main.py --input= --output= --report=') + print('-i or --input: The source script to be converted, Default value: input/') + print('-o or --output: The destination script after converted, Default value: output/') + print('-r or --report: Conversion report, Default value: report/') + sys.exit() + elif opt in ("-i", "--input"): + input = arg + elif opt in ("-o", "--output"): + output = arg + elif opt in ("-r", "--report"): + report = arg + util_global.set_value('input', input) + util_global.set_value('output', output) + util_global.set_value('report', report) + +if __name__ == "__main__": + util_global._init() + para_check_and_set(sys.argv[1:]) + before_clear() + conver() \ No newline at end of file diff --git a/conver_tf2npu/mappings/ast.json b/conver_tf2npu/mappings/ast.json new file mode 100644 index 0000000000000000000000000000000000000000..76b87ed138eb107a4afc44ec19a2083a13fc6cd3 --- /dev/null +++ b/conver_tf2npu/mappings/ast.json @@ -0,0 +1,26 @@ +{ +"need_conver": false, +"gelu": ["npu_unary_ops", "tf.gelu", "npu_unary_ops.gelu"], +"dropout": ["npu_ops", "tf.nn.dropout", "npu_ops.dropout"], +"init": ["print", "hvd.init", "None"], +"DistributedOptimizer": ["NPUDistributedOptimizer", "hvd.DistributedOptimizer", "NPUDistributedOptimizer"], +"rank": ["get_rank_id", "hvd.rank", "get_rank_id"], +"local_rank": ["get_local_rank_id", "hvd.local_rank", "get_local_rank_id"], +"size": ["get_rank_size", "hvd.size", "get_rank_size"], +"BroadcastGlobalVariablesHook": ["print", "BroadcastGlobalVariablesHook", "None"], +"shard": ["", "dataset.shard(xxx, xxx)", "dataset.shard(get_rank_size(), get_rank_id())"], +"EstimatorSpec": ["NPUEstimatorSpec", "tf.estimator.EstimatorSpec", "NPUEstimatorSpec"], +"RunConfig": ["NPURunConfig", "tf.estimator.RunConfig", "NPURunConfig"], +"Estimator": ["NPUEstimator", "tf.estimator.Estimator", "NPUEstimator"], + +"batch": ["", "batch(xxx)", "batch(xxx, drop_remainder=True)"], +"map_and_batch": ["", "map_and_batch(xxx)", "map_and_batch(xxx, drop_remainder=True)"], +"device": ["", "tf.device(xxx)", "tf.device(/cpu:0)"], + +"hvd": ["init","rank", "local_rank","size", "DistributedOptimizer"], +"estimator": ["Estimator", "RunConfig", "EstimatorSpec"], +"nn": ["dropout"], +"keras": ["https://support.huaweicloud.com/mprtg-A800_9000_9010/atlasprtg_13_0008.html"], + +"report_file": ["success_report.txt", "failed_report.txt", "need_migration_doc.txt"] +} \ No newline at end of file diff --git a/conver_tf2npu/util.py b/conver_tf2npu/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ae56619e90ed1014dfefadb8aa1383a035be8b41 --- /dev/null +++ b/conver_tf2npu/util.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import util_global +from file_op import write_conver_report + +def log_info(msg): + print(util_global.get_value('path') + ': ' + msg) + +def log_info(lineno, msg, file): + content = (util_global.get_value('path', '') + ':' + str(lineno) + + ' change ' + util_global.get_value(msg)[1] + + ' to ' + util_global.get_value(msg)[2]) + print(content) + write_conver_report(content, file) + +def log_success_report(lineno, msg): + content = (util_global.get_value('path', '') + ':' + str(lineno) + + ' "change ' + util_global.get_value(msg)[1] + + ' to ' + util_global.get_value(msg)[2]) + print(content) + write_conver_report(content, util_global.get_value('report_file')[0]) + +def log_migration_report(lineno, msg): + content = (util_global.get_value('path', '') + ':' + str(lineno) + ' "' + msg + + '" feature needs to be migrated manually, Please refer to the migration guide: ' + + util_global.get_value(msg)[0]) + print(content) + write_conver_report(content, util_global.get_value('report_file')[2]) \ No newline at end of file diff --git a/conver_tf2npu/util_global.py b/conver_tf2npu/util_global.py new file mode 100644 index 0000000000000000000000000000000000000000..8941003eadf7bb7a3272f9aa63e05fcbc597a445 --- /dev/null +++ b/conver_tf2npu/util_global.py @@ -0,0 +1,33 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import json + +def _init(): + global _global_dict + _global_dict = {} + with open('mappings/ast.json', encoding='utf-8') as f: + load_dict = json.load(f) + items = load_dict.items() + for key, value in items: + set_value(key, value) + +def set_value(key, value): + _global_dict[key] = value + +def get_value(key, def_value = None): + try: + return _global_dict[key] + except KeyError: + return def_value diff --git a/tf_adapter/python/npu_bridge/hccl_init.py b/tf_adapter/python/npu_bridge/hccl_init.py deleted file mode 100644 index b60dc5f44d3229deca499b61f1e382012d5ab44e..0000000000000000000000000000000000000000 --- a/tf_adapter/python/npu_bridge/hccl_init.py +++ /dev/null @@ -1,3 +0,0 @@ -from hccl.manage.api import get_local_rank_id -from hccl.manage.api import get_rank_size -from hccl.manage.api import get_rank_id \ No newline at end of file diff --git a/tf_adapter/python/npu_bridge/npu_init.py b/tf_adapter/python/npu_bridge/npu_init.py index bf5bec4c5c7d573cf781cbe1954dffb1c1e70da7..91dc7cf8d95db14d89fee7b134354bf0551810eb 100644 --- a/tf_adapter/python/npu_bridge/npu_init.py +++ b/tf_adapter/python/npu_bridge/npu_init.py @@ -4,4 +4,7 @@ from npu_bridge.estimator.npu.npu_optimizer import NPUDistributedOptimizer from npu_bridge.estimator import npu_ops from npu_bridge.estimator.npu_unary_ops import npu_unary_ops from tensorflow.core.protobuf import rewriter_config_pb2 -from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig \ No newline at end of file +from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig +from hccl.manage.api import get_local_rank_id +from hccl.manage.api import get_rank_size +from hccl.manage.api import get_rank_id \ No newline at end of file diff --git a/tools/.gitkeep b/tools/.gitkeep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000