diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_info.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_info.py index b355fac9bc319039c2523d967b70de443f27f6b9..e6776a539bd242f41109e08eb5a2de1eb5a74020 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_info.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_info.py @@ -66,8 +66,8 @@ class ApiInfo: err_msg = "ApiInfo.get_kwargs failed: compute_element_dict key is not a string" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)) - if not isinstance(compute_element_info, (list, dict)): - err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list or dict" + if not (isinstance(compute_element_info, (list, dict)) or compute_element_info is None): + err_msg = "ApiInfo.get_kwargs failed: compute_element_dict value is not a list, dict or null" logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ParseJsonFailed)) kwargs_compute_element_dict = {key_str: ComputeElement(compute_element_info=compute_element_info) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_info.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_info.py index d4f1c11d7544785c353d2dd22792e45c257cdfe6..b4b84c39286efb6e08f3c5c927c5c1d3c6b07dda 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_info.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/api_accuracy_checker/test_api_info.py @@ -21,6 +21,20 @@ class TestApiInfo(unittest.TestCase): """ global_context.init(False, os.path.join(directory, "files")) + def test_get_kwargs_with_null(self): + # first load forward backward api_info + only_kwargs_api_info_dict = { + "input_kwargs": { + "approximate": None, + } + } + api_info = ApiInfo("only_input_kwargs_api") + api_info.load_forward_info(only_kwargs_api_info_dict) + + self.assertTrue(api_info.check_forward_info()) + kwargs_compute_element_dict = api_info.get_kwargs() + self.assertEqual(kwargs_compute_element_dict.get("approximate").get_parameter(), None) + def test_get_compute_element_list(self): # first load forward backward api_info forward_api_info_dict = {