diff --git a/debug/accuracy_tools/atat/core/utils.py b/debug/accuracy_tools/atat/core/utils.py index e3a30579a80bd8ce1f2da9ef7b9c28d796a6acd9..89ea1aafdb95494b8e35680017d8cfc8694d728e 100644 --- a/debug/accuracy_tools/atat/core/utils.py +++ b/debug/accuracy_tools/atat/core/utils.py @@ -99,7 +99,7 @@ class Const: "_reduce_scatter_base", "_all_gather_base"] TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] - LEVEL_LIST = ["L0", "L1", "L2", "mix"] + LEVEL_LIST = ["L0", "L1", "L2", "L0+L1"] STATISTICS = "statistics" TENSOR = "tensor" OVERFLOW_CHECK = "overflow_check" diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/atat/pytorch/common/utils.py index e88d506b2c340f9b6141c2e0bb775a693d61a16c..660eb260ecc55abcada2b99e151333aaeebff031 100644 --- a/debug/accuracy_tools/atat/pytorch/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/common/utils.py @@ -188,7 +188,7 @@ class Const: "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"] TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] - LEVEL_LIST = ["L0", "L1", "L2", "mix"] + LEVEL_LIST = ["L0", "L1", "L2", "L0+L1"] STATISTICS = "statistics" TENSOR = "tensor" OVERFLOW_CHECK = "overflow_check" diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py index 9fc97332f64df89a056cda1b0bfdc730c2cdfa29..9a6770f5157b13dd467cb2d9faa83c03f8c44521 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py @@ -82,7 +82,7 @@ class DebuggerConfig: raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.") def check_model(self, model): - if self.level in ["L0", "mix"] and not model: + if self.level in ["L0", "L0+L1"] and not model: raise Exception( f"For level {self.level}, PrecisionDebugger must receive a model argument.", ) \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/atat/pytorch/service.py index 9c079aedebeec26120f27318bab1f1cdc8cf99cb..f291e4a49036deea976b3088b61197cf201dbc6f 100644 --- a/debug/accuracy_tools/atat/pytorch/service.py +++ b/debug/accuracy_tools/atat/pytorch/service.py @@ -150,7 +150,7 @@ class Service: pass print_info_log_rank_0("The {} hook function is successfully mounted to the model.".format(hook_name)) - if self.config.level in ["L0", "mix"]: + if self.config.level in ["L0", "L0+L1"]: assert self.model is not None print_info_log_rank_0("The init dump mode is enabled, and the module dump function will not be available") for name, module in self.model.named_modules(): @@ -172,7 +172,7 @@ class Service: module.register_full_backward_hook( self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) - if self.config.level in ["mix", "L1", "L2"]: + if self.config.level in ["L0+L1", "L1", "L2"]: api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) api_register.api_modularity()