diff --git a/convert_tf2npu/ast_impl.py b/convert_tf2npu/ast_impl.py
index e6c91b4323d41cdb01da5b3b10b3682fd5dce7d8..548bb50b4c00a7c6d11b8b2beace42c732f5bd9b 100644
--- a/convert_tf2npu/ast_impl.py
+++ b/convert_tf2npu/ast_impl.py
@@ -156,6 +156,12 @@ def convert_loss_scale_api(node):
 
 def ast_call(node):
     convert_loss_scale_api(node)
+    if isinstance(node.func, ast.Attribute) and node.func.attr == "gradients" and \
+       isinstance(node.func.value, ast.Name) and node.func.value.id == "tf":
+        new_node = ast.Call(func=ast.Name(id="npu_allreduce", ctx=ast.Load()), args=[node], keywords=[])
+        ast.copy_location(new_node, node)
+        util_global.set_value('need_conver', True)
+        return new_node
     if _call_name_match(node.func, "set_experimental_options"):
         log_msg(getattr(node, 'lineno', 'None'), 'change set_experimental_options(*) to set_experimental_options(experimental_options)')
         node.args = [ast.Name(id='experimental_options', ctx=ast.Load())]
@@ -406,15 +412,24 @@ def ast_call(node):
                    "nadam": "tf.keras.optimizers.Nadam()",
                    "rmsprop": "tf.keras.optimizers.RMSprop()",
                    "sgd": "tf.keras.optimizers.SGD()"}
+        opt_keyword = None
         for keyword in node.keywords:
             if keyword.arg == "optimizer":
+                opt_keyword = keyword
                 log_success_report(getattr(node, 'lineno', 'None'), 'KerasDistributeOptimizer')
                 if isinstance(keyword.value, ast.Str):
                     keras_opt = opt_map[keyword.value.s]
                     npu_keras_opt = "npu_keras_optimizer(" + keras_opt + ")"
                     keyword.value = pasta.parse(npu_keras_opt)
                 util_global.set_value('need_conver', True)
-                return node
+                break
+        if opt_keyword is None:
+            opt_keyword = node.args[0]
+            if isinstance(opt_keyword, ast.Str):
+                keras_opt = opt_map[opt_keyword.s]
+                npu_keras_opt = "npu_keras_optimizer(" + keras_opt + ")"
+                node.args[0] = pasta.parse(npu_keras_opt)
+        return node
     if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Attribute):
         if (node.func.attr.find("Optimizer") != -1) and (node.func.attr != 'ScipyOptimizerInterface'):
             log_msg(getattr(node, "lineno", "None"), "add NPUDistributedOptimizer()")
@@ -422,14 +437,21 @@ def ast_call(node):
             ast.copy_location(new_node, node)
             util_global.set_value('need_conver', True)
             return new_node
+    opt_list = ["Adadelta", "Adagrad", "Adam", "Adamax", "Ftrl", "Nadam", "RMSprop", "SGD"]
     if isinstance(node.func, ast.Attribute):
-        opt_list = ["Adadelta", "Adagrad", "Adam", "Adamax", "Ftrl", "Nadam", "RMSprop", "SGD"]
         if node.func.attr in opt_list:
             log_success_report(getattr(node, "lineno", "None"), "KerasDistributeOptimizer")
             new_node = ast.Call(func=ast.Name(id="npu_keras_optimizer", ctx=ast.Load()), args=[node], keywords=[])
             ast.copy_location(new_node, node)
             util_global.set_value('need_conver', True)
             return new_node
+    if isinstance(node.func, ast.Name):
+        if node.func.id in opt_list:
+            log_success_report(getattr(node, "lineno", "None"), "KerasDistributeOptimizer")
+            new_node = ast.Call(func=ast.Name(id="npu_keras_optimizer", ctx=ast.Load()), args=[node], keywords=[])
+            ast.copy_location(new_node, node)
+            util_global.set_value('need_conver', True)
+            return new_node
     if (isinstance(node.func, ast.Attribute) and (node.func.attr == 'MonitoredTrainingSession')) or \
        (isinstance(node.func, ast.Name) and (node.func.id == 'MonitoredTrainingSession')):
         log_success_report(getattr(node, "lineno", "None"), 'MonitoredTrainingSession')
diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_optimizer.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_optimizer.py
index b4a7aef8a80da0853e7c37acb13dc43947d5e92e..ba86ba67fcc2cf8b267d16b6f906215768a7b448 100644
--- a/tf_adapter/python/npu_bridge/estimator/npu/npu_optimizer.py
+++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_optimizer.py
@@ -125,6 +125,17 @@ def reduce(tensor, var, root_rank, average=True, fusion=0, fusion_id=-1):
         new_tensor = tf.div(summed_tensor, rank_size) if average else summed_tensor
         return new_tensor
 
+def npu_allreduce(grads):
+    rank_size = os.getenv('RANK_SIZE')
+    if rank_size == None or int(rank_size) <= 1:
+        return grads
+    averaged_gradients = []
+    with tf.name_scope("NpuDistribute_Allreduce"):
+        for grad in grads:
+            avg_grad = allreduce(grad, True) if grad is not None else None
+            averaged_gradients.append(avg_grad)
+    return averaged_gradients
+
 class NPUOptimizer(optimizer.Optimizer):
     """An optimizer that wraps another tf.Optimizer that can using an allreduce to
     average gradient values before applying gradients to model weights when
diff --git a/tf_adapter/python/npu_bridge/npu_init.py b/tf_adapter/python/npu_bridge/npu_init.py
index b4b2527eb63c8016bf48a46e42e480986cc79531..2c9323f20060ae07c84da802012b0b82d193b0a4 100644
--- a/tf_adapter/python/npu_bridge/npu_init.py
+++ b/tf_adapter/python/npu_bridge/npu_init.py
@@ -26,6 +26,7 @@ from npu_bridge.estimator.npu.npu_hook import NPUBroadcastGlobalVariablesHook
 from npu_bridge.estimator.npu.npu_optimizer import NPUDistributedOptimizer
 from npu_bridge.estimator.npu.npu_optimizer import NPUOptimizer
 from npu_bridge.estimator.npu.npu_optimizer import KerasDistributeOptimizer
+from npu_bridge.estimator.npu.npu_optimizer import npu_allreduce
 from npu_bridge.estimator.npu.npu_loss_scale_optimizer import NPULossScaleOptimizer
 from npu_bridge.estimator.npu.npu_loss_scale_manager import FixedLossScaleManager
 from npu_bridge.estimator.npu.npu_loss_scale_manager import ExponentialUpdateLossScaleManager
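
Usage sketch for the new npu_allreduce entry point. This is illustrative only and not part of the patch: it assumes TensorFlow 1.x and an installed npu_bridge package, and the tiny linear model below is a placeholder. It shows the shape of user code after the converter rewrites a bare tf.gradients(...) call into npu_allreduce(tf.gradients(...)), which averages each gradient across devices when RANK_SIZE > 1 and returns the gradients unchanged otherwise.

import tensorflow as tf
from npu_bridge.npu_init import *  # exports npu_allreduce (added by this patch)

# Placeholder model: one linear layer trained with squared error.
x = tf.placeholder(tf.float32, shape=[None, 4])
y = tf.placeholder(tf.float32, shape=[None, 1])
w = tf.Variable(tf.zeros([4, 1]))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

# Before conversion: grads = tf.gradients(loss, [w])
# After conversion the call is wrapped so gradients are allreduced across NPUs.
grads = npu_allreduce(tf.gradients(loss, [w]))
train_op = tf.train.GradientDescentOptimizer(0.01).apply_gradients(zip(grads, [w]))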