From c7a085e7f7b52d7349fa6857ef619eac52275a59 Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Mon, 31 Jul 2023 20:43:17 +0800 Subject: [PATCH 1/3] forward and backward api info class --- .../api_accuracy_checker/dump/api_info.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 5c6138f6e4..a2c3e0540f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -4,3 +4,34 @@ class APIInfo: def __init__(self, api_name): self.api_name = api_name + +class ForwardAPIInfo(APIInfo): + def __init__(self, name, args, kwargs): + super().__init__(name) + self.analyze_api_input(args, kwargs) + self.analyze_api_call_stack() + + def analyze_api_input(self, args, kwargs): + args_info_list = self.analyze_element(args) + kwargs_info_dict = self.analyze_element(kwargs) + self.api_info_struct = {self.api_name: {"args":args_info_list, "kwargs":kwargs_info_dict}} + + def analyze_api_call_stack(self): + stack_str = [] + for (_, path, line, func, code, _) in inspect.stack()[3:]: + stack_line = " ".join([ + "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), + " ".join(["\n", code[0].strip() if code else code])])]) + stack_str.append(stack_line) + self.stack_info_struct = {self.api_name: stack_str} + + +class BackwardAPIInfo(APIInfo): + def __init__(self, name, grads): + super().__init__(name) + self.analyze_api_input(grads) + + def analyze_api_input(self, grads): + grads_info_list = self.analyze_element(grads) + self.grad_info_struct = {self.api_name:grads_info_list} + -- Gitee From 4d2bae1fb5f24b7ef78b21283992e7de36f327b5 Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Mon, 31 Jul 2023 20:52:56 +0800 Subject: [PATCH 2/3] forward and backward api info class --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index a2c3e0540f..119829ac1e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -1,5 +1,5 @@ # 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 - +import inspect class APIInfo: def __init__(self, api_name): -- Gitee From b0bb453d97da56c450c6504e41fed5da0b59f715 Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Tue, 1 Aug 2023 16:59:56 +0800 Subject: [PATCH 3/3] add blank lines --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 119829ac1e..c8224c515d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -1,10 +1,12 @@ # 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 import inspect + class APIInfo: def __init__(self, api_name): self.api_name = api_name + class ForwardAPIInfo(APIInfo): def __init__(self, name, args, kwargs): super().__init__(name) -- Gitee