From a5873edfea20fab850e688d5e2169fb0cc1b57b6 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Sat, 9 Nov 2024 17:58:30 +0800 Subject: [PATCH] fix input kwargs null & testcase --- .../mindspore/api_accuracy_checker/api_info.py | 4 ++-- .../api_accuracy_checker/test_api_info.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_info.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_info.py index b355fac9bc..e6776a539b 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_info.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_info.py @@ -66,8 +66,8 @@ class ApiInfo: err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)) - if not isinstance(compute_element_info, (list, dict)): - err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict" + if not (isinstance(compute_element_info, (list, dict)) or compute_element_info is None): + err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list, dict or null" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)) kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_info.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_info.py index d4f1c11d75..b4b84c3928 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_info.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_info.py @@ -21,6 +21,20 @@ class TestApiInfo(unittest.TestCase): """ global_context.init(False, os.path.join(directory, "files")) + def test_get_kwargs_with_null(self): + # first load forward backward api_info + only_kwargs_api_info_dict = { + "input_kwargs": { + "approximate": None, + } + } + api_info = ApiInfo("only_input_kwargs_api") + api_info.load_forward_info(only_kwargs_api_info_dict) + + self.assertTrue(api_info.check_forward_info()) + kwargs_compute_element_dict = api_info.get_kwargs() + self.assertEqual(kwargs_compute_element_dict.get("approximate").get_parameter(), None) + def test_get_compute_element_list(self): # first load forward backward api_info forward_api_info_dict = { -- Gitee