From aca7ee212fb61ce4f2da96adfb558172ce7a11fe Mon Sep 17 00:00:00 2001 From: l30044004 Date: Mon, 23 Oct 2023 16:17:29 +0800 Subject: [PATCH 1/9] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E5=8F=8D=E5=90=91api=E7=9A=84=E6=8F=90=E7=A4=BA=E4=BF=A1?= =?UTF-8?q?=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/common/config.py | 2 +- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 02d40973c1..345da9016b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -34,7 +34,7 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path, real_data=False, enable_dataloader=False, target_iter=1): + def update_config(self, dump_path, real_data=False, enable_dataloader=True, target_iter=1): args = { "dump_path": dump_path, "real_data": real_data, diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 4cc27880c7..4ce54b989d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -16,7 +16,6 @@ from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate from ut_api_info import UtAPIInfo from api_accuracy_checker.common.config import msCheckerConfig -NO_GRAD_APIS = ["hardtanh"] def init_environment(): @@ -146,7 +145,7 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di need_backward = api_full_name in backward_content need_backward = need_backward and need_grad if not need_grad: - print_warn_log("%s involves in-place operations, skip backward" % api_full_name) + print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." % api_full_name) cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) grad_out, npu_grad_out = None, None @@ -173,8 +172,6 @@ def get_api_info(api_info_dict, api_name): need_grad = True if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): need_grad = False - if api_name in NO_GRAD_APIS: - need_grad = False args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) return args, kwargs, need_grad -- Gitee From 1231e25a50aa157ad7f3742d182d93cec2dc9232 Mon Sep 17 00:00:00 2001 From: louyujing Date: Mon, 23 Oct 2023 12:23:52 +0000 Subject: [PATCH 2/9] update debug/accuracy_tools/api_accuracy_checker/README.md. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/README.md b/debug/accuracy_tools/api_accuracy_checker/README.md index 3ee506f042..c4ea080957 100644 --- a/debug/accuracy_tools/api_accuracy_checker/README.md +++ b/debug/accuracy_tools/api_accuracy_checker/README.md @@ -67,7 +67,7 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 | ----------------- | ------------------------------------------------------------ | -------- | | dump_path | 设置dump路径,须为已存在目录,默认为当前目录。 | 否 | | real_data | 真实数据模式,可取值True或False,默认为False,配置为True后开启真实数据模式,dump信息增加forward_real_data和backward_real_data目录,目录下保存每个API输入的具体数值。开启真实数据模式目前仅支持单卡,且会存盘较多数据,可能对磁盘空间有较大冲击。 | 否 | - | enable_dataloader | 自动控制开关,可取值True或False,默认为False,配置为True后自动识别dump target_iter参数指定的迭代数据,并在该迭代执行完成后退出训练。 | 否 | + | enable_dataloader | 自动控制开关,可取值True或False,默认为True,配置为True后自动识别dump target_iter参数指定的迭代数据,并在该迭代执行完成后退出训练。 | 否 | | target_iter | 指定dump某个step的数据,默认为1,仅支持dump1个step,须指定为训练脚本中存在的step。 | 否 | 3. 将API信息输入给run_ut模块运行精度检测并比对,运行如下命令: -- Gitee From d56ce35351fbc473967d22a32c86460395b903b0 Mon Sep 17 00:00:00 2001 From: louyujing Date: Mon, 23 Oct 2023 12:28:23 +0000 Subject: [PATCH 3/9] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py. Signed-off-by: louyujing --- .../api_accuracy_checker/run_ut/run_overflow_check.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py index 7c0fa0f6a6..79913e9e4e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py @@ -10,7 +10,6 @@ from api_accuracy_checker.common.utils import print_info_log, print_warn_log, ge print_error_log -NO_GRAD_APIS = ["hardtanh"] init_environment() @@ -77,10 +76,10 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di api_type = api_full_name.split("_")[0] api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0] args, kwargs, need_grad = get_api_info(api_info_dict, api_name) - need_backward = api_full_name.replace("forward", "backward") in backward_content and api_name[-1] != "_" + need_backward = api_full_name.replace("forward", "backward") in backward_content need_backward = need_backward and need_grad if not need_grad: - print_warn_log("%s involves in-place operations, skip backward" % api_full_name) + print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." % api_full_name) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) if kwargs.get("device"): del kwargs["device"] -- Gitee From 6f27052585ec3dd42d886a72f0a6bae855d06d1c Mon Sep 17 00:00:00 2001 From: louyujing Date: Tue, 24 Oct 2023 04:01:07 +0000 Subject: [PATCH 4/9] update debug/accuracy_tools/api_accuracy_checker/config.yaml. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 76a31db425..22ef99c3a0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -3,7 +3,6 @@ jit_compile: True real_data: False dump_step: 1000 error_data_path: './' -enable_dataloader: True target_iter: 1 precision: 14 \ No newline at end of file -- Gitee From 2ba70003807045c4312c5beb940dccff5905ddd9 Mon Sep 17 00:00:00 2001 From: louyujing Date: Tue, 24 Oct 2023 04:02:01 +0000 Subject: [PATCH 5/9] update debug/accuracy_tools/api_accuracy_checker/common/config.py. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/common/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 345da9016b..780773381f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -7,6 +7,7 @@ class Config: check_file_or_directory_path(yaml_file, False) with open(yaml_file, 'r') as file: config = yaml.safe_load(file) + config['enable_dataloader'] = True self.config = {key: self.validate(key, value) for key, value in config.items()} def validate(self, key, value): @@ -34,11 +35,10 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path, real_data=False, enable_dataloader=True, target_iter=1): + def update_config(self, dump_path, real_data=False, target_iter=1): args = { "dump_path": dump_path, "real_data": real_data, - "enable_dataloader": enable_dataloader, "target_iter": target_iter } for key, value in args.items(): -- Gitee From fe0e4b4dbd703886153e1f32706efdedd9294e1a Mon Sep 17 00:00:00 2001 From: louyujing Date: Tue, 24 Oct 2023 04:03:09 +0000 Subject: [PATCH 6/9] update debug/accuracy_tools/api_accuracy_checker/README.md. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/README.md b/debug/accuracy_tools/api_accuracy_checker/README.md index c4ea080957..3f786edb69 100644 --- a/debug/accuracy_tools/api_accuracy_checker/README.md +++ b/debug/accuracy_tools/api_accuracy_checker/README.md @@ -67,7 +67,6 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 | ----------------- | ------------------------------------------------------------ | -------- | | dump_path | 设置dump路径,须为已存在目录,默认为当前目录。 | 否 | | real_data | 真实数据模式,可取值True或False,默认为False,配置为True后开启真实数据模式,dump信息增加forward_real_data和backward_real_data目录,目录下保存每个API输入的具体数值。开启真实数据模式目前仅支持单卡,且会存盘较多数据,可能对磁盘空间有较大冲击。 | 否 | - | enable_dataloader | 自动控制开关,可取值True或False,默认为True,配置为True后自动识别dump target_iter参数指定的迭代数据,并在该迭代执行完成后退出训练。 | 否 | | target_iter | 指定dump某个step的数据,默认为1,仅支持dump1个step,须指定为训练脚本中存在的step。 | 否 | 3. 将API信息输入给run_ut模块运行精度检测并比对,运行如下命令: -- Gitee From a4900af45af025802dd74e39adb94b147be6c8bc Mon Sep 17 00:00:00 2001 From: louyujing Date: Tue, 24 Oct 2023 04:04:03 +0000 Subject: [PATCH 7/9] update debug/accuracy_tools/api_accuracy_checker/README.md. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/README.md b/debug/accuracy_tools/api_accuracy_checker/README.md index 3f786edb69..2ec0bd9792 100644 --- a/debug/accuracy_tools/api_accuracy_checker/README.md +++ b/debug/accuracy_tools/api_accuracy_checker/README.md @@ -60,7 +60,7 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 ```Python from api_accuracy_checker.dump import msCheckerConfig - msCheckerConfig.update_config(dump_path="my/dump/path", real_data=True, enable_dataloader=True, target_iter=1) + msCheckerConfig.update_config(dump_path="my/dump/path", real_data=True, target_iter=1) ``` | 参数名称 | 说明 | 是否必选 | -- Gitee From 56f57b0acaef76ef7bcb50513f71111ce18b8b0a Mon Sep 17 00:00:00 2001 From: louyujing Date: Tue, 24 Oct 2023 09:17:46 +0000 Subject: [PATCH 8/9] update debug/accuracy_tools/api_accuracy_checker/common/config.py. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/common/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 39d285df41..c47911e213 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -8,7 +8,6 @@ class Config: check_file_or_directory_path(yaml_file, False) with FileOpen(yaml_file, 'r') as file: config = yaml.safe_load(file) - config['enable_dataloader'] = True self.config = {key: self.validate(key, value) for key, value in config.items()} def validate(self, key, value): @@ -18,7 +17,6 @@ class Config: 'real_data': bool, 'dump_step': int, 'error_data_path': str, - 'enable_dataloader': bool, 'target_iter': int, 'precision': int } -- Gitee From 5ad4ec3944a48bfbf8820bc40d5d07dc735ee7be Mon Sep 17 00:00:00 2001 From: louyujing Date: Tue, 24 Oct 2023 09:20:00 +0000 Subject: [PATCH 9/9] update debug/accuracy_tools/api_accuracy_checker/dump/dump.py. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 830ff5e4fc..2a69e226cd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -45,9 +45,9 @@ class DumpUtil(object): @staticmethod def incr_iter_num_maybe_exit(): - if DumpUtil.call_num == msCheckerConfig.target_iter or not msCheckerConfig.enable_dataloader: + if DumpUtil.call_num == msCheckerConfig.target_iter: set_dump_switch("ON") - elif DumpUtil.call_num > msCheckerConfig.target_iter and msCheckerConfig.enable_dataloader: + elif DumpUtil.call_num > msCheckerConfig.target_iter: raise Exception("Model pretest: exit after iteration {}".format(msCheckerConfig.target_iter)) else: set_dump_switch("OFF") -- Gitee