From 61c55c85acb4f28adfde914002b4b95c8a73f3f0 Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Fri, 4 Aug 2023 04:44:40 +0000 Subject: [PATCH 1/3] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- .../api_accuracy_checker/dump/api_info.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index cff12e6b84a..f5a5b361447 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -12,6 +12,7 @@ class APIInfo: self.rank = torch_npu.npu.current_device() self.api_name = api_name self.save_real_data = DumpUtil.save_real_data + self.special_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} def analyze_element(self, element): if isinstance(element, (list, tuple)): @@ -21,7 +22,11 @@ class APIInfo: elif isinstance(element, dict): out = {} for key, value in element.items(): - out[key] = self.analyze_element(value) + if key in self.special_key.keys(): + fun = self.special_key[key] + out[key] = fun(value) + else: + out[key] = self.analyze_element(value) elif isinstance(element, torch.Tensor): out = self.analyze_tensor(element, self.save_real_data) @@ -77,7 +82,26 @@ class APIInfo: if element is None or isinstance(element, (bool,int,float,str,slice)): return True return False - + + def analyze_device_in_kwargs(self, element): + single_arg = {} + single_arg.update({'type' : 'torch.device'}) + if not isinstance(element, str): + + if hasattr(element, "index"): + device_value = element.type + ":" + str(element.index) + single_arg.update({'value' : device_value}) + else: + device_value = element.type + else: + single_arg.update({'value' : element}) + return single_arg + + def analyze_dtype_in_kwargs(self, element): + single_arg = {} + single_arg.update({'type' : 'torch.dtype'}) + single_arg.update({'value' : str(element)}) + return single_arg def get_tensor_extremum(self, data, operator): if data.dtype is torch.bool: -- Gitee From f7798b7deff6ebcb467d92dc2a4d7501b67a1b45 Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Fri, 4 Aug 2023 04:49:04 +0000 Subject: [PATCH 2/3] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index ced02d593f2..90979ea6525 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -14,10 +14,7 @@ class APIInfo: self.api_name = api_name self.save_real_data = msCheckerConfig.real_data self.special_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} - - - def analyze_element(self, element): if isinstance(element, (list, tuple)): out = [] -- Gitee From c22d4d6d4b4ad6b495c7fb3b3771208d9f525654 Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Sat, 5 Aug 2023 02:31:44 +0000 Subject: [PATCH 3/3] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 90979ea6525..d33ed21d813 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -13,7 +13,7 @@ class APIInfo: self.rank = os.getpid() self.api_name = api_name self.save_real_data = msCheckerConfig.real_data - self.special_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} + self.torch_object_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} def analyze_element(self, element): if isinstance(element, (list, tuple)): @@ -23,8 +23,8 @@ class APIInfo: elif isinstance(element, dict): out = {} for key, value in element.items(): - if key in self.special_key.keys(): - fun = self.special_key[key] + if key in self.torch_object_key.keys(): + fun = self.torch_object_key[key] out[key] = fun(value) else: out[key] = self.analyze_element(value) -- Gitee