diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index a119966d9f66ead41e1f06ad7f521b6948bafd26..2a86699d83259ed75f5f05b4b25f0ffa74fbab61 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -5,15 +5,15 @@ from api_accuracy_checker.common.base_api import BaseAPIInfo class APIInfo(BaseAPIInfo): - def __init__(self, api_name, is_forward, is_save_data=msCheckerConfig.real_data, - save_path=msCheckerConfig.dump_path, forward_path='forward_real_data', + def __init__(self, api_name, is_forward, is_save_data, save_path, forward_path='forward_real_data', backward_path='backward_real_data'): super().__init__(api_name, is_forward, is_save_data, save_path, forward_path, backward_path) class ForwardAPIInfo(APIInfo): def __init__(self, name, args, kwargs): - super().__init__(name, is_forward=True) + super().__init__(name, is_forward=True, is_save_data=msCheckerConfig.real_data, + save_path=msCheckerConfig.dump_path) self.analyze_api_input(args, kwargs) self.analyze_api_call_stack() @@ -35,7 +35,8 @@ class ForwardAPIInfo(APIInfo): class BackwardAPIInfo(APIInfo): def __init__(self, name, grads): - super().__init__(name, is_forward=False) + super().__init__(name, is_forward=False, is_save_data=msCheckerConfig.real_data, + save_path=msCheckerConfig.dump_path) self.analyze_api_input(grads) def analyze_api_input(self, grads):