From 1306464afbc617b909b9484b7f57b53cbc0515cf Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 13 Sep 2023 10:59:42 +0800 Subject: [PATCH] api_info --- .../accuracy_tools/api_accuracy_checker/dump/api_info.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index a119966d9f6..2a86699d832 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -5,15 +5,15 @@ from api_accuracy_checker.common.base_api import BaseAPIInfo class APIInfo(BaseAPIInfo): - def __init__(self, api_name, is_forward, is_save_data=msCheckerConfig.real_data, - save_path=msCheckerConfig.dump_path, forward_path='forward_real_data', + def __init__(self, api_name, is_forward, is_save_data, save_path, forward_path='forward_real_data', backward_path='backward_real_data'): super().__init__(api_name, is_forward, is_save_data, save_path, forward_path, backward_path) class ForwardAPIInfo(APIInfo): def __init__(self, name, args, kwargs): - super().__init__(name, is_forward=True) + super().__init__(name, is_forward=True, is_save_data=msCheckerConfig.real_data, + save_path=msCheckerConfig.dump_path) self.analyze_api_input(args, kwargs) self.analyze_api_call_stack() @@ -35,7 +35,8 @@ class ForwardAPIInfo(APIInfo): class BackwardAPIInfo(APIInfo): def __init__(self, name, grads): - super().__init__(name, is_forward=False) + super().__init__(name, is_forward=False, is_save_data=msCheckerConfig.real_data, + save_path=msCheckerConfig.dump_path) self.analyze_api_input(grads) def analyze_api_input(self, grads): -- Gitee