diff --git a/debug/accuracy_tools/atat/config/config.json b/debug/accuracy_tools/atat/config/config.json index 6ddae793f3f5623dfa03a5f0bc68d0b46bd5ad5f..c249a1e4982207c6a7467119b35f918d021b6ed3 100644 --- a/debug/accuracy_tools/atat/config/config.json +++ b/debug/accuracy_tools/atat/config/config.json @@ -1,32 +1,26 @@ { - "task": "statistics", //(ms,pt), 必要, 默认statistics | 取值["tensor"(真实数据及统计数据), "statistics", "summary_md5", "overflow_check"] - "dump_path": "/home/wangchao/out", //(ms,pt),必要, 输出数据path - "rank": [1], //(ms,pt),默认不配置或为空即为dump全量,注意pytorch针对此功能要做整改 - "step": [0,1], //(ms,pt),默认不配置或为空即为dump全量,注意mindspore需要注意此step的真实行为 - "level": "L1", //(ms,pt),默认L1。 | L0:module, L1:api, L2:kernel - // 上述为通用配置 - // 以下为每种任务的配置,当前从业务上讲,只支持一种任务 - "tensor" : { - "scope": [""], //[]中有两个值,用于锁定区间,L0级别下使用,表示两个模块间的范围(pt) - "list":[""], //list, api_list, kernel_api, kernel_name(ms,pt), 某个api下的kernel dump(pt) - "data_mode": ["all"], //dump数据模式。默认all(pt,ms)。可取值"all"、"forward"、"backward"、"input"和"output",表示仅保存dump的数据中文件名包含"forward"、"backward"、"input"和"output"的前向、反向、输入或输出的.npy文件。 - "backward_input": "/home/wangchao/out/forward_mul_1.npy", //"dump_mode"="acl"时,反向数据 - "file_format": "npy" // 真实tensor数据的保存格式(ms) | [bin, npy] + "task": "overflow_check", + "dump_path": "/dump/path", + "rank": [12], + "step": [], + "level": "L1", + "seed": 1234, + "is_deterministic": false, + "tensor": { + "scope": [], + "list":[], + "data_mode": ["all"], + "backward_input": "", + "file_format": "npy" }, - "statistics" : { - "scope": [""], //[]中有两个值,用于锁定区间,L0级别下使用,表示两个模块间的范围(pt) - "list":[""], //list(pt), api_list(pt), kernel_name(ms) - "data_mode": ["all"] //dump数据模式。默认all(pt,ms)可取值"all"、"forward"、"backward"、"input"和"output",表示仅保存dump的数据中文件名包含"forward"、"backward"、"input"和"output"的前向、反向、输入或输出的.npy文件。 - }, - "summary_md5" : { - "scope": [""], //[]中有两个值,用于锁定区间,L0级别下使用,表示两个模块间的范围(pt) - "list":[""], //list(pt), api_list(pt), kernel_name(ms) - "data_mode": ["all"] 
//dump数据模式。默认all(pt,ms)可取值"all"、"forward"、"backward"、"input"和"output",表示仅保存dump的数据中文件名包含"forward"、"backward"、"input"和"output"的前向、反向、输入或输出的.npy文件。 + "statistics": { + "scope": [], + "list":[], + "data_mode": ["all"], + "summary_mode": "statistics" }, "overflow_check": { - "overflow_nums": 1, //溢出次数(pt),默认1,即检测到1次溢出,训练停止,配置为-1时,表示持续检测溢出直到训练结束。数据类型:int。 - "check_mode":"all" //["aicore", "atomic", "all"] kernel级别溢出检测的能力(ms) + "overflow_nums": 1, + "check_mode":"all" } - - //后续配置文件维护策略:当前架构归一只关心数据采集,后续如果集成解析、梯度、无标杆等能力,新增task及task的任务配置 } \ No newline at end of file diff --git a/debug/accuracy_tools/atat/core/common_config.py b/debug/accuracy_tools/atat/core/common_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f0fe9ec5e009ad268364effeaa32d7f4c5850c --- /dev/null +++ b/debug/accuracy_tools/atat/core/common_config.py @@ -0,0 +1,52 @@ +# from ..common.exceptions import CalibratorException +from .file_check_util import FileChecker, FileCheckConst +from .utils import Const + + +# 公共配置类 +class CommonConfig: + def __init__(self, json_config): + self.task = json_config.get('task') + self.dump_path = json_config.get('dump_path') + self.rank = json_config.get('rank') + self.step = json_config.get('step') + self.level = json_config.get('level') + self.seed = json_config.get('seed') + self.is_deterministic = json_config.get('is_deterministic') + self._check_config() + + def _check_config(self): + if self.task not in Const.TASK_LIST: + raise Exception("task is invalid") + if self.rank is not None and not isinstance(self.rank, list): + raise Exception("rank is invalid") + if self.step is not None and not isinstance(self.step, list): + raise Exception("step is invalid") + if self.level not in ["L0", "L1", "L2"]: + raise Exception("level is invalid") + if not isinstance(self.seed, int): + raise Exception("seed is invalid") + if not isinstance(self.is_deterministic, bool): + raise Exception("is_deterministic is invalid") + + +# 基础配置类 +class BaseConfig: + 
def __init__(self, json_config): + self.scope = json_config.get('scope') + self.list = json_config.get('list') + self.data_mode = json_config.get('data_mode') + + def check_config(self): + if self.scope is not None and not isinstance(self.scope, list): + raise Exception("scope is invalid") + if self.list is not None and not isinstance(self.list, list): + raise Exception("list is invalid") + if self.data_mode is not None and not isinstance(self.data_mode, list): + raise Exception("data_mode is invalid") + + + + + + diff --git a/debug/accuracy_tools/atat/core/utils.py b/debug/accuracy_tools/atat/core/utils.py index ab9c26008e95c3b791f00cac0d4cee1df790a47e..f19ac15f3b161d9a72b757e36f2a1878dc431b5b 100644 --- a/debug/accuracy_tools/atat/core/utils.py +++ b/debug/accuracy_tools/atat/core/utils.py @@ -96,6 +96,7 @@ class Const: INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", "_reduce_scatter_base", "_all_gather_base"] + TASK_LIST = ["tensor", "statistics", "overflow_check"] class CompareConst: """ diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py index 59d41bea94ca1fd8b3fd1289d28747bb979fe513..ef6432ba2713720e9ed4ef3888a4642a9ee670fe 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py @@ -1,37 +1,47 @@ -import os -from ..common.utils import print_warn_log +from ..common import print_warn_log_rank_0, seed_all class DebuggerConfig: - def __init__(self, dump_path, hook_name, rank=None, step=None): - self.dump_path = dump_path - self.hook_name = hook_name - self.rank = rank - self.step = step if step is not None else [] + def __init__(self, common_config, task_config): + self.dump_path = common_config.dump_path + self.task = common_config.task + self.rank = common_config.rank + self.step = common_config.step + self.level = 
common_config.level
+        self.seed = common_config.seed
+        self.is_deterministic = common_config.is_deterministic
+        self.scope = task_config.scope
+        self.list = task_config.list
+        self.data_mode = task_config.data_mode
+        # Task-specific fields exist only on some task configs (e.g. the
+        # statistics config has no backward_input), so read them defensively.
+        self.backward_input = getattr(task_config, 'backward_input', None)
+        self.summary_mode = getattr(task_config, 'summary_mode', None)
+        self.overflow_check = getattr(task_config, 'overflow_check', None)
+        self.repair_type = None
+        self.repair_scope = None
+        self.repair_api_str = None
+        self.on_step_end = None
+        self.check()
         if self.step:
             self.step.sort()
+        seed_all(self.seed, self.is_deterministic)
 
     def check(self):
-        self._check_hook_name()
         self._check_rank()
         self._check_step()
         return True
 
-    def _check_hook_name(self):
-        if self.hook_name not in ["dump", "overflow_check"]:
-            raise ValueError(f"hook_name should be in ['dump', 'overflow_check'], got {self.hook_name}")
-
     def _check_rank(self):
-        if self.rank is not None:
-            if not isinstance(self.rank, int) or self.rank < 0:
-                raise ValueError(f"rank {self.rank} must be a positive integer.")
+        if self.rank:
+            for rank_id in self.rank:
+                if not isinstance(rank_id, int) or rank_id < 0:
+                    raise ValueError(f"rank {self.rank} must be a positive integer.")
         else:
-            print_warn_log(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")
+            print_warn_log_rank_0(f"Rank argument is provided. 
Only rank {self.rank} data will be dumpped.") def _check_step(self): if not isinstance(self.step, list): raise ValueError(f"step {self.step} should be list") for s in self.step: if not isinstance(s, int): - raise ValueError(f"step element {s} should be int") + raise ValueError(f"step element {s} should be int") \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py index d9bbf6f4060e29fab0164c66ccbd121f22944c31..1e1cc58cdbcb09bbcca01eb75a481af89a03bd89 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py @@ -1,16 +1,7 @@ -import os -from concurrent.futures import ThreadPoolExecutor -import torch -from ..common.utils import Const, check_switch_valid, generate_compare_script, check_is_npu, print_error_log, \ - CompareException, print_warn_log -from ..dump.dump import DumpUtil, acc_cmp_dump, write_to_disk, get_pkl_file_path, reset_module_count -from ..dump.utils import set_dump_path, set_dump_switch_print_info, generate_dump_path_str, \ - set_dump_switch_config, set_backward_input -from ..overflow_check.utils import OverFlowUtil -from ..overflow_check.overflow_check import overflow_check -from ..hook_module.register_hook import register_hook_core, init_overflow_nums -from ..hook_module.hook_module import HOOKModule from .debugger_config import DebuggerConfig +from ..service import Service +from ..common import print_warn_log_rank_0 +from ..pt_config import parse_json_config class PrecisionDebugger: @@ -19,104 +10,31 @@ class PrecisionDebugger: def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super(PrecisionDebugger, cls).__new__(cls) - cls._instance.first_start = True - cls._instance.hook_func = None cls._instance.config = None cls._instance.model = None cls._instance.enable_dataloader = False return cls._instance - def 
__init__(self, dump_path=None, hook_name=None, rank=None, step=None, enable_dataloader=False, model=None): + def __init__(self, config_path=None, *args, **kwargs): if not hasattr(self, 'initialized'): self.initialized = True - if hook_name is None: - err_msg = "You must provide hook_name argument to PrecisionDebugger\ - when config is not provided." - raise Exception(err_msg) - self.config = DebuggerConfig(dump_path, hook_name, rank, step) - self.configure_hook = self.get_configure_hook(self.config.hook_name) - self.configure_hook() - DumpUtil.target_iter = self.config.step - DumpUtil.target_rank = self.config.rank - set_dump_path(self.config.dump_path) - self.hook_func = overflow_check if self.config.hook_name == "overflow_check" else acc_cmp_dump - self.model = model - self.enable_dataloader = enable_dataloader - if not isinstance(enable_dataloader, bool): - print_error_log("Params enable_dataloader only support True or False.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - if enable_dataloader: - DumpUtil.iter_num -= 1 - torch.utils.data.dataloader._BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) - - def get_configure_hook(self, hook_name): - hook_dict = {"dump": self.configure_full_dump, "overflow_check": self.configure_overflow_dump} - return hook_dict.get(hook_name, lambda: ValueError("hook name {} is not in ['dump', 'overflow_check']".format(hook_name))) - - def configure_full_dump(self, mode='api_stack', scope=None, api_list=None, filter_switch=Const.OFF, - input_output_mode=[Const.ALL], acl_config=None, backward_input=None, summary_only=False, summary_mode=None): - if mode == "acl" and self.model is not None: - print_error_log("Init dump does not support ACL dump mode.") - raise CompareException(CompareException.INVALID_DUMP_MODE) - scope = scope if scope is not None else [] - api_list = api_list if api_list is not None else [] - backward_input = backward_input if backward_input is not 
None else [] - - if summary_only: - if summary_mode is not None: - raise ValueError("summary_mode can not be used with summary_only") - print_warn_log("Argument 'summary_only' will be deprecated, it would be better to use 'summary_mode'") - summary_mode = "summary" - elif summary_mode is None: - summary_mode = "all" - - set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, - filter_switch=filter_switch, dump_mode=input_output_mode, summary_only=summary_only, - summary_mode=summary_mode) - if mode == 'acl': - DumpUtil.set_acl_config(acl_config) - if not scope or not isinstance(scope, list) or len(scope) != 1: - raise ValueError("scope must be congfigured as a list with one api name") - if isinstance(scope[0], str) and 'backward' in scope[0] and not backward_input: - raise ValueError("backward_input must be configured when scope contains 'backward'") - elif 'backward' in scope[0]: - set_backward_input(backward_input) - - def configure_overflow_dump(self, mode="api", acl_config=None, overflow_nums=1, filter_switch=Const.OFF, need_replicate=False): - if mode == "acl": - DumpUtil.dump_switch_mode = mode - DumpUtil.set_acl_config(acl_config) - init_overflow_nums(overflow_nums) - check_switch_valid(filter_switch) - OverFlowUtil.overflow_filter_switch = filter_switch - if not isinstance(need_replicate, bool): - print_error_log("Params autojudge only support True or False.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - if need_replicate: - DumpUtil.need_replicate = True + common_config, task_config = parse_json_config(config_path) + self.config = DebuggerConfig(common_config, task_config) + self.model = kwargs.get('model') + self.enable_dataloader = kwargs.get('enable_dataloader') + self.service = Service(self.model, self.config) + @classmethod def start(cls): instance = cls._instance if not instance: raise Exception("No instance of PrecisionDebugger found.") if instance.enable_dataloader: - print_warn_log("DataLoader is enabled, start() 
skipped.") + print_warn_log_rank_0("DataLoader is enabled, start() skipped.") else: - if DumpUtil.iter_num in DumpUtil.target_iter or not DumpUtil.target_iter: - if instance.first_start: - register_hook_core(instance.hook_func, instance.model) - instance.first_start = False - DumpUtil.dump_switch = "ON" - DumpUtil.dump_thread_pool = ThreadPoolExecutor() - OverFlowUtil.overflow_check_switch = "ON" - dump_path_str = generate_dump_path_str() - set_dump_switch_print_info("ON", DumpUtil.dump_switch_mode, dump_path_str) - elif DumpUtil.target_iter and DumpUtil.iter_num > max(DumpUtil.target_iter): - cls.stop() - raise Exception("ptdbg: exit after iteration {}".format(max(DumpUtil.target_iter))) - else: - cls.stop() + instance.service.start() + @classmethod def stop(cls): @@ -124,45 +42,12 @@ class PrecisionDebugger: if not instance: raise Exception("PrecisionDebugger instance is not created.") if instance.enable_dataloader: - print_warn_log("DataLoader is enabled, stop() skipped.") + print_warn_log_rank_0("DataLoader is enabled, stop() skipped.") else: - DumpUtil.dump_switch = "OFF" - OverFlowUtil.overflow_check_switch = "OFF" - dump_path_str = generate_dump_path_str() - set_dump_switch_print_info("OFF", DumpUtil.dump_switch_mode, dump_path_str) - write_to_disk() - if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool: - DumpUtil.dump_thread_pool.shutdown(wait=True) - if check_is_npu() and DumpUtil.dump_switch_mode in [Const.ALL, Const.API_STACK, Const.LIST, Const.RANGE, Const.API_LIST]: - generate_compare_script(DumpUtil.dump_data_dir, get_pkl_file_path(), DumpUtil.dump_switch_mode) + instance.service.stop() @classmethod def step(cls): - instance = cls._instance - if not instance: + if not cls._instance: raise Exception("PrecisionDebugger instance is not created.") - if not instance.enable_dataloader: - DumpUtil.iter_num += 1 - DumpUtil.dump_init_enable = True - HOOKModule.module_count = {} - reset_module_count() - else: - print_warn_log("DataLoader is enabled, 
step() skipped.")
-
-    @staticmethod
-    def incr_iter_num_maybe_exit():
-        PrecisionDebugger.step()
-        PrecisionDebugger.start()
-
-
-def iter_tracer(func):
-    def func_wrapper(*args, **kwargs):
-        debugger_instance = PrecisionDebugger._instance
-        temp_enable_dataloader = debugger_instance.enable_dataloader
-        debugger_instance.enable_dataloader = False
-        debugger_instance.stop()
-        result = func(*args, **kwargs)
-        debugger_instance.incr_iter_num_maybe_exit()
-        debugger_instance.enable_dataloader = temp_enable_dataloader
-        return result
-    return func_wrapper
+    cls._instance.service.step()
diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/atat/pytorch/pt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..75cf230f5dd9fa2f2dee6aba904721979041b611
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/pt_config.py
@@ -0,0 +1,68 @@
+import os
+import json
+from ..core.common_config import CommonConfig, BaseConfig
+from ..core.file_check_util import FileOpen
+
+
+# Task-specific configuration classes
+class TensorConfig(BaseConfig):
+    def __init__(self, json_config):
+        super().__init__(json_config)
+        self.backward_input = json_config.get("backward_input")
+        self.file_format = json_config.get("file_format")
+        self.check_config()
+        self._check_file_format()
+
+    def _check_file_format(self):
+        if self.file_format not in ["npy", "bin"]:
+            raise Exception("file_format is invalid")
+
+
+class StatisticsConfig(BaseConfig):
+    def __init__(self, json_config):
+        super().__init__(json_config)
+        self.summary_mode = json_config.get("summary_mode")
+        self.check_config()
+        self._check_summary_mode()
+
+    def _check_summary_mode(self):
+        if self.summary_mode not in ["statistics", "md5"]:
+            raise Exception("summary_mode is invalid")
+
+
+class OverviewConfig(BaseConfig):
+    def __init__(self, json_config):
+        super().__init__(json_config)
+        # config.json defines the key as "overflow_nums" (plural); reading
+        # "overflow_num" always yielded None and failed validation.
+        self.overflow_num = json_config.get("overflow_nums")
+        self.check_mode = json_config.get("check_mode")
+        
self.check_overflow_config()
+
+    def check_overflow_config(self):
+        if not isinstance(self.overflow_num, int):
+            raise Exception("overflow_num is invalid")
+        if self.check_mode not in ["all", "aicore", "atomic"]:
+            raise Exception("check_mode is invalid")
+
+
+def parse_common_config(json_config):
+    return CommonConfig(json_config)
+
+
+def parse_task_config(task, json_config):
+    # Task names must match Const.TASK_LIST and the config.json section keys:
+    # "tensor", "statistics", "overflow_check" (not "overview").
+    if task == "tensor":
+        return TensorConfig(json_config["tensor"])
+    elif task == "statistics":
+        return StatisticsConfig(json_config["statistics"])
+    elif task == "overflow_check":
+        return OverviewConfig(json_config["overflow_check"])
+
+
+def parse_json_config(json_file_path):
+    if not json_file_path:
+        config_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+        json_file_path = os.path.join(os.path.join(config_dir, "config"), "config.json")
+    with FileOpen(json_file_path, 'r') as file:
+        json_config = json.load(file)
+    common_config = parse_common_config(json_config)
+    task_config = parse_task_config(common_config.task, json_config)
+    return common_config, task_config
\ No newline at end of file