From c437aad8f4a615476f9d10bb998bf2b285742921 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 25 Jul 2024 17:32:26 +0800
Subject: [PATCH 01/67] =?UTF-8?q?=E6=8F=90=E7=82=BC=E5=87=BD=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../mindspore/dump/hook_cell/api_registry.py | 93 +++++++++++++++++++
.../msprobe/mindspore/service.py | 2 +
2 files changed, 95 insertions(+)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 5508416fde0..2f032d93d7c 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -13,12 +13,18 @@
# limitations under the License.
# ============================================================================
+import os
+import functools
import mindspore as ms
+from mindspore.common.tensor import Tensor
+from msprobe.core.common.utils import Const
+from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
from msprobe.core.common.utils import Const
+PRIMITIVE_PREFIX = "Primitive"
class ApiRegistry:
def __init__(self):
@@ -35,6 +41,7 @@ class ApiRegistry:
self.norm_inner_ops_hook_attr = {}
self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
+ self.primitive_counters = {}
@staticmethod
def store_ori_attr(ori_api_group, api_list, api_ori_attr):
@@ -100,5 +107,91 @@ class ApiRegistry:
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
+ def wrap_primitive(self, origin_func, primitive_name, service_instance):
+ primitive_instance = self
+ def func(self, *args, **kwargs):
+ if primitive_name not in primitive_instance.primitive_counters:
+ primitive_instance.primitive_counters[primitive_name] = 0
+ else:
+ primitive_instance.primitive_counters[primitive_name] += 1
+
+ current_count = primitive_instance.primitive_counters[primitive_name]
+ updated_primitive_name = f"{PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ captured_grads_input = []
+ captured_grads_output = []
+
+ def input_backward_hook(grad):
+ print(f"Grad input length: {len(grad)}")
+ print("Captured input grad:", grad)
+ captured_grads_input.append(grad)
+ backward_primitive_name = updated_primitive_name + Const.BACKWARD
+ new_module_input_output = ModuleBackwardInputsOutputs(
+ grad_input=tuple(captured_grads_input),
+ grad_output=tuple(captured_grads_output) if captured_grads_output else None
+ )
+ service_instance.data_collector.backward_data_collect(
+ backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
+ )
+#1未考虑多输出场景
+# 如果时多grad呢
+# 3 输出的序号问题
+ def output_backward_hook(grad):
+ captured_grads_output.append(grad)
+ backward_primitive_name = primitive_name + Const.BACKWARD
+ new_module_input_output = ModuleBackwardInputsOutputs(
+ grad_input=None,
+ grad_output=tuple(captured_grads_output)
+ )
+ service_instance.data_collector.backward_data_collect(
+ backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
+ )
+
+ if not service_instance.switch:
+ return origin_func(*args, **kwargs)
+
+ print(f"Entering {updated_primitive_name} hook, number of args: {len(args)}, name: {self.name}")
+ hooked_inputs = []
+
+ # for idx, arg in enumerate(args):
+ # if isinstance(arg, Tensor):
+ # arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ # hooked_inputs.append(arg_hooked)
+ # else:
+ # hooked_inputs.append(arg)
+
+ out = origin_func(*arg, **kwargs)
+ forward_primitive_name = updated_primitive_name + Const.FORWARD
+
+ if service_instance.data_collector:
+ module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
+ if service_instance.data_collector.if_return_forward_new_output():
+ out = service_instance.data_collector.get_forward_new_output()
+
+ if isinstance(out, Tensor):
+ out = ops.HookBackward(output_backward_hook)(out)
+ elif isinstance(out, tuple):
+ hooked_outputs = []
+ for tensor in out:
+ if isinstance(tensor, Tensor):
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ else:
+ hooked_outputs.append(tensor)
+ out = tuple(hooked_outputs)
+
+ return out
+
+ return func
+
+ def register_hooks(self, service_instance):
+ primitive_set = set()
+ for name, cell in self.model.cells_and_names():
+ for pname, primitive in cell._primitives.items():
+ primitive_set.add((pname, primitive))
+
+ for pname, primitive in primitive_set:
+ print("primitive name is", pname)
+ NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname, service_instance)})
+ primitive.__class__ = NewPrimitive
api_register = ApiRegistry()
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index e8aa34dc4fe..8d802e14d06 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -136,3 +136,5 @@ class Service:
if self.config.level == "L1":
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
api_register.api_set_hook_func()
+ if self.model:
+ api_register.register_hooks(self)
--
Gitee
From e39e7175accd5ff882432572c889f1336a84cb8e Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 25 Jul 2024 20:28:17 +0800
Subject: [PATCH 02/67] =?UTF-8?q?=E8=B7=91=E9=80=9Aprimitive?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../mindspore/debugger/precision_debugger.py | 4 +--
.../mindspore/dump/hook_cell/api_registry.py | 35 ++++---------------
2 files changed, 8 insertions(+), 31 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
index 30f7162ff5c..28161c66855 100644
--- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
@@ -27,12 +27,12 @@ class PrecisionDebugger:
self.service = Service(self.config)
@classmethod
- def start(cls):
+ def start(cls, target=None):
instance = cls._instance
if not instance:
raise Exception("No instance of PrecisionDebugger found.")
if ms.get_context("mode") == ms.PYNATIVE_MODE and instance.config.level_ori == "L1":
- instance.service.start()
+ instance.service.start(target)
else:
handler = TaskHandlerFactory.create(instance.config)
handler.handle()
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 2f032d93d7c..03fd47e8fb1 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -16,6 +16,7 @@
import os
import functools
import mindspore as ms
+from mindspore import ops
from mindspore.common.tensor import Tensor
from msprobe.core.common.utils import Const
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
@@ -117,27 +118,11 @@ class ApiRegistry:
current_count = primitive_instance.primitive_counters[primitive_name]
updated_primitive_name = f"{PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
- captured_grads_input = []
captured_grads_output = []
- def input_backward_hook(grad):
- print(f"Grad input length: {len(grad)}")
- print("Captured input grad:", grad)
- captured_grads_input.append(grad)
- backward_primitive_name = updated_primitive_name + Const.BACKWARD
- new_module_input_output = ModuleBackwardInputsOutputs(
- grad_input=tuple(captured_grads_input),
- grad_output=tuple(captured_grads_output) if captured_grads_output else None
- )
- service_instance.data_collector.backward_data_collect(
- backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
- )
-#1未考虑多输出场景
-# 如果时多grad呢
-# 3 输出的序号问题
def output_backward_hook(grad):
captured_grads_output.append(grad)
- backward_primitive_name = primitive_name + Const.BACKWARD
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
new_module_input_output = ModuleBackwardInputsOutputs(
grad_input=None,
grad_output=tuple(captured_grads_output)
@@ -149,21 +134,13 @@ class ApiRegistry:
if not service_instance.switch:
return origin_func(*args, **kwargs)
- print(f"Entering {updated_primitive_name} hook, number of args: {len(args)}, name: {self.name}")
hooked_inputs = []
- # for idx, arg in enumerate(args):
- # if isinstance(arg, Tensor):
- # arg_hooked = ops.HookBackward(input_backward_hook)(arg)
- # hooked_inputs.append(arg_hooked)
- # else:
- # hooked_inputs.append(arg)
-
- out = origin_func(*arg, **kwargs)
- forward_primitive_name = updated_primitive_name + Const.FORWARD
+ out = origin_func(*args, **kwargs)
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
if service_instance.data_collector:
- module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
if service_instance.data_collector.if_return_forward_new_output():
out = service_instance.data_collector.get_forward_new_output()
@@ -185,7 +162,7 @@ class ApiRegistry:
def register_hooks(self, service_instance):
primitive_set = set()
- for name, cell in self.model.cells_and_names():
+ for name, cell in service_instance.model.cells_and_names():
for pname, primitive in cell._primitives.items():
primitive_set.add((pname, primitive))
--
Gitee
From 6aaed48a6b8c22220b2a2e1e1878b8e2346aa185 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 11:11:06 +0800
Subject: [PATCH 03/67] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=BE=93=E5=85=A5?=
=?UTF-8?q?=E6=A2=AF=E5=BA=A6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../core/data_dump/data_processor/base.py | 24 +++++++++++++++++++
.../mindspore/dump/hook_cell/api_registry.py | 14 +++++++++++
2 files changed, 38 insertions(+)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 5d901291973..4bcf6418197 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -237,6 +237,30 @@ class BaseDataProcessor:
return api_info_struct
+ def analyze_backward_input(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
+ """
+ Analyze and save backward input gradients.
+ """
+ api_info_struct = {}
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
+ api_info_struct[name] = {}
+ self.api_data_category = Const.INPUT
+ output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
+ api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
+ return api_info_struct
+
+ def analyze_backward_output(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
+ """
+ Analyze and save backward output gradients.
+ """
+ api_info_struct = {}
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
+ api_info_struct[name] = {}
+ self.api_data_category = Const.OUTPUT
+ input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
+ api_info_struct[name][Const.GRAD_INPUT] = input_info_list
+ return api_info_struct
+
def get_save_file_path(self, suffix):
file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 03fd47e8fb1..77e740011f1 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -118,8 +118,22 @@ class ApiRegistry:
current_count = primitive_instance.primitive_counters[primitive_name]
updated_primitive_name = f"{PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ captured_grads_input = []
captured_grads_output = []
+ def input_backward_hook(grad):
+ print(f"Grad input length: {len(grad)}")
+ print("Captured input grad:", grad)
+ captured_grads_input.append(grad)
+ backward_primitive_name = updated_primitive_name + Const.BACKWARD
+ new_module_input_output = ModuleBackwardInputsOutputs(
+ grad_input=tuple(captured_grads_input),
+ grad_output=tuple(captured_grads_output) if captured_grads_output else None
+ )
+ service_instance.data_collector.backward_data_collect(
+ backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
+ )
+
def output_backward_hook(grad):
captured_grads_output.append(grad)
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
--
Gitee
From da8338f1f906f677fb6d79d0167300fdaef43593 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 11:31:03 +0800
Subject: [PATCH 04/67] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=BE=93=E5=85=A5hook?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 77e740011f1..b28a07d94c8 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -149,8 +149,14 @@ class ApiRegistry:
return origin_func(*args, **kwargs)
hooked_inputs = []
-
- out = origin_func(*args, **kwargs)
+ for idx, arg in enumerate(args):
+ if isinstance(arg, Tensor):
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ hooked_inputs.append(arg_hooked)
+ else:
+ hooked_inputs.append(arg)
+
+ out = origin_func(*hooked_inputs, **kwargs)
forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
if service_instance.data_collector:
--
Gitee
From 085085d7ae18c8a0ce6060558754868cab0a404c Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 14:40:15 +0800
Subject: [PATCH 05/67] =?UTF-8?q?=E6=8B=86=E5=88=86=E5=87=BD=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../msprobe/core/data_dump/data_collector.py | 16 ++++++++++++++++
.../mindspore/dump/hook_cell/api_registry.py | 4 ++--
2 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index 800a2b81c2f..a537fa3d06a 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -100,6 +100,22 @@ class DataCollector:
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
self.handle_data(name, data_info)
+ def backward_input_data_collect(self, name, module, pid, module_input_output):
+ self.update_construct(name)
+ if not self.check_scope_and_pid(self.scope, name, pid):
+ return
+
+ data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+ self.handle_data(name, data_info)
+
+ def backward_output_data_collect(self, name, module, pid, module_input_output):
+ self.update_construct(name)
+ if not self.check_scope_and_pid(self.scope, name, pid):
+ return
+
+ data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
+ self.handle_data(name, data_info)
+
def update_construct(self, name):
if self.config.level not in DataCollector.level_without_construct:
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index b28a07d94c8..4a790a5cbb8 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -130,7 +130,7 @@ class ApiRegistry:
grad_input=tuple(captured_grads_input),
grad_output=tuple(captured_grads_output) if captured_grads_output else None
)
- service_instance.data_collector.backward_data_collect(
+ service_instance.data_collector.backward_input_data_collect(
backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
)
@@ -141,7 +141,7 @@ class ApiRegistry:
grad_input=None,
grad_output=tuple(captured_grads_output)
)
- service_instance.data_collector.backward_data_collect(
+ service_instance.data_collector.backward_output_data_collect(
backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
)
--
Gitee
From 88ce243a0e483a165e42e010856a88afbbb15220 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 14:43:28 +0800
Subject: [PATCH 06/67] =?UTF-8?q?=E6=94=B9=E5=90=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 4a790a5cbb8..e50502cc991 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -125,7 +125,7 @@ class ApiRegistry:
print(f"Grad input length: {len(grad)}")
print("Captured input grad:", grad)
captured_grads_input.append(grad)
- backward_primitive_name = updated_primitive_name + Const.BACKWARD
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
new_module_input_output = ModuleBackwardInputsOutputs(
grad_input=tuple(captured_grads_input),
grad_output=tuple(captured_grads_output) if captured_grads_output else None
--
Gitee
From b4ea092df299862c74880a870d432e028b56b4b0 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 14:58:46 +0800
Subject: [PATCH 07/67] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E5=8F=AA=E8=BE=93?=
=?UTF-8?q?=E5=85=A5/=E8=BE=93=E5=87=BA?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../core/data_dump/data_processor/base.py | 15 +++++++++++++++
.../mindspore/dump/hook_cell/api_registry.py | 16 ++++++----------
2 files changed, 21 insertions(+), 10 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 4bcf6418197..75238663f04 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -39,7 +39,22 @@ class ModuleBackwardInputsOutputs:
@property
def grad_output_tuple(self):
return convert_tuple(self.grad_output)
+
+@dataclass
+class ModuleBackwardInputs:
+ grad_input: Optional[Tuple]
+
+ @property
+ def grad_input_tuple(self):
+ return convert_tuple(self.grad_input)
+
+@dataclass
+class ModuleBackwardOutputs:
+ grad_output: Optional[Tuple]
+ @property
+ def grad_output_tuple(self):
+ return convert_tuple(self.grad_output)
class TensorStatInfo:
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index e50502cc991..a9f01ef5fac 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -122,27 +122,23 @@ class ApiRegistry:
captured_grads_output = []
def input_backward_hook(grad):
- print(f"Grad input length: {len(grad)}")
- print("Captured input grad:", grad)
+ # print(f"Grad input length: {len(grad)}")
+ # print("Captured input grad:", grad)
captured_grads_input.append(grad)
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardInputsOutputs(
- grad_input=tuple(captured_grads_input),
- grad_output=tuple(captured_grads_output) if captured_grads_output else None
- )
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
+ backward_primitive_name, self, os.getpid(), new_module_input_output
)
def output_backward_hook(grad):
captured_grads_output.append(grad)
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardInputsOutputs(
- grad_input=None,
+ new_module_input_output = ModuleBackwardOutputs(
grad_output=tuple(captured_grads_output)
)
service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
+ backward_primitive_name, self, os.getpid(), new_module_input_output
)
if not service_instance.switch:
--
Gitee
From 477d07441c832f0cd376ba1e65ca54c83c86c5a7 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 15:00:26 +0800
Subject: [PATCH 08/67] Update api_registry.py
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index a9f01ef5fac..16528f76d70 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -19,7 +19,7 @@ import mindspore as ms
from mindspore import ops
from mindspore.common.tensor import Tensor
from msprobe.core.common.utils import Const
-from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
+from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, ModuleForwardInputs, ModuleForwardOutputs
from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
--
Gitee
From 149f349cc60ad3caf096bbd1a2c297f458e8771a Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 15:01:49 +0800
Subject: [PATCH 09/67] Update api_registry.py
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 16528f76d70..a3f9ec35d76 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -19,7 +19,7 @@ import mindspore as ms
from mindspore import ops
from mindspore.common.tensor import Tensor
from msprobe.core.common.utils import Const
-from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, ModuleForwardInputs, ModuleForwardOutputs
+from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, ModuleBackwardInputs, ModuleBackwardOutputs
from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
--
Gitee
From 4dead5fbefa7f68b436b0df1f00e4fd778ed971c Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 15:07:22 +0800
Subject: [PATCH 10/67] Update base.py
---
.../msprobe/core/data_dump/data_processor/base.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 75238663f04..d96d107ac5d 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -260,8 +260,8 @@ class BaseDataProcessor:
if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
api_info_struct[name] = {}
self.api_data_category = Const.INPUT
- output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
- api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
+ output_info_list = self.analyze_element(module_input_output.grad_input_tuple)
+ api_info_struct[name][Const.GRAD_INPUT] = output_info_list
return api_info_struct
def analyze_backward_output(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
@@ -272,8 +272,8 @@ class BaseDataProcessor:
if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
api_info_struct[name] = {}
self.api_data_category = Const.OUTPUT
- input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
- api_info_struct[name][Const.GRAD_INPUT] = input_info_list
+ output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
+ api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
return api_info_struct
def get_save_file_path(self, suffix):
--
Gitee
From d8fce96489f22adab19c9c7c93a5c15f3a05a40b Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 16:15:37 +0800
Subject: [PATCH 11/67] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=AD=A3=E5=90=91?=
=?UTF-8?q?=E5=8F=8D=E5=90=91?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../msprobe/core/data_dump/data_processor/base.py | 6 +++---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 4 ++--
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index d96d107ac5d..e725d362e8c 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -261,7 +261,7 @@ class BaseDataProcessor:
api_info_struct[name] = {}
self.api_data_category = Const.INPUT
output_info_list = self.analyze_element(module_input_output.grad_input_tuple)
- api_info_struct[name][Const.GRAD_INPUT] = output_info_list
+ api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
return api_info_struct
def analyze_backward_output(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
@@ -272,8 +272,8 @@ class BaseDataProcessor:
if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
api_info_struct[name] = {}
self.api_data_category = Const.OUTPUT
- output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
- api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
+ input_info_list = self.analyze_element(module_input_output.grad_output_tuple)
+ api_info_struct[name][Const.GRAD_INPUT] = input_info_list
return api_info_struct
def get_save_file_path(self, suffix):
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index a3f9ec35d76..97ab5e3285c 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -127,7 +127,7 @@ class ApiRegistry:
captured_grads_input.append(grad)
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- service_instance.data_collector.backward_input_data_collect(
+ service_instance.data_collector.backward_output_data_collect(
backward_primitive_name, self, os.getpid(), new_module_input_output
)
@@ -137,7 +137,7 @@ class ApiRegistry:
new_module_input_output = ModuleBackwardOutputs(
grad_output=tuple(captured_grads_output)
)
- service_instance.data_collector.backward_output_data_collect(
+ service_instance.data_collector.backward_input_data_collect(
backward_primitive_name, self, os.getpid(), new_module_input_output
)
--
Gitee
From c7cb33af18ed7c88a18d7dfc8905c0a4c9185c38 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 16:23:26 +0800
Subject: [PATCH 12/67] Update api_registry.py
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 97ab5e3285c..a3f9ec35d76 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -127,7 +127,7 @@ class ApiRegistry:
captured_grads_input.append(grad)
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- service_instance.data_collector.backward_output_data_collect(
+ service_instance.data_collector.backward_input_data_collect(
backward_primitive_name, self, os.getpid(), new_module_input_output
)
@@ -137,7 +137,7 @@ class ApiRegistry:
new_module_input_output = ModuleBackwardOutputs(
grad_output=tuple(captured_grads_output)
)
- service_instance.data_collector.backward_input_data_collect(
+ service_instance.data_collector.backward_output_data_collect(
backward_primitive_name, self, os.getpid(), new_module_input_output
)
--
Gitee
From f1a616bbb1c5f9ab892adfce80cb5961d37975c3 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 16:33:20 +0800
Subject: [PATCH 13/67] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=B8=B8=E9=87=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
debug/accuracy_tools/msprobe/core/common/const.py | 1 +
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 4 +---
2 files changed, 2 insertions(+), 3 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index df82455a676..81e21e3d2ac 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -16,6 +16,7 @@ class Const:
OFF = 'OFF'
BACKWARD = 'backward'
FORWARD = 'forward'
+ PRIMITIVE_PREFIX = 'Primitive'
DEFAULT_LIST = []
DEFAULT_PATH = './'
WHITE_LIST = 'white_list'
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index a3f9ec35d76..01bce6b5259 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -117,13 +117,11 @@ class ApiRegistry:
primitive_instance.primitive_counters[primitive_name] += 1
current_count = primitive_instance.primitive_counters[primitive_name]
- updated_primitive_name = f"{PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
captured_grads_input = []
captured_grads_output = []
def input_backward_hook(grad):
- # print(f"Grad input length: {len(grad)}")
- # print("Captured input grad:", grad)
captured_grads_input.append(grad)
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
--
Gitee
From 1bbc79438375ee7a53942a983a1434c9b6834962 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 16:37:34 +0800
Subject: [PATCH 14/67] Update api_registry.py
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 01bce6b5259..3ee7d9e0e46 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -109,14 +109,13 @@ class ApiRegistry:
self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
def wrap_primitive(self, origin_func, primitive_name, service_instance):
- primitive_instance = self
def func(self, *args, **kwargs):
- if primitive_name not in primitive_instance.primitive_counters:
- primitive_instance.primitive_counters[primitive_name] = 0
+ if primitive_name not in self.primitive_counters:
+ self.primitive_counters[primitive_name] = 0
else:
- primitive_instance.primitive_counters[primitive_name] += 1
+ self.primitive_counters[primitive_name] += 1
- current_count = primitive_instance.primitive_counters[primitive_name]
+ current_count = self.primitive_counters[primitive_name]
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
captured_grads_input = []
captured_grads_output = []
--
Gitee
From 7df3f2edf624c06cb9f9ae1e8000347467aecb60 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 16:39:41 +0800
Subject: [PATCH 15/67] Update api_registry.py
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 3ee7d9e0e46..01bce6b5259 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -109,13 +109,14 @@ class ApiRegistry:
self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
def wrap_primitive(self, origin_func, primitive_name, service_instance):
+ primitive_instance = self
def func(self, *args, **kwargs):
- if primitive_name not in self.primitive_counters:
- self.primitive_counters[primitive_name] = 0
+ if primitive_name not in primitive_instance.primitive_counters:
+ primitive_instance.primitive_counters[primitive_name] = 0
else:
- self.primitive_counters[primitive_name] += 1
+ primitive_instance.primitive_counters[primitive_name] += 1
- current_count = self.primitive_counters[primitive_name]
+ current_count = primitive_instance.primitive_counters[primitive_name]
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
captured_grads_input = []
captured_grads_output = []
--
Gitee
From 95f07448ddd07ccbb18f724b104bc78b0e996e96 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 17:20:06 +0800
Subject: [PATCH 16/67] =?UTF-8?q?=E8=BF=81=E7=A7=BB=E5=87=BD=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../mindspore/dump/hook_cell/api_registry.py | 2 -
.../msprobe/mindspore/service.py | 78 ++++++++++++++++++-
2 files changed, 77 insertions(+), 3 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 01bce6b5259..63a02cefef2 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -25,7 +25,6 @@ from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops,
from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
from msprobe.core.common.utils import Const
-PRIMITIVE_PREFIX = "Primitive"
class ApiRegistry:
def __init__(self):
@@ -181,7 +180,6 @@ class ApiRegistry:
primitive_set.add((pname, primitive))
for pname, primitive in primitive_set:
- print("primitive name is", pname)
NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname, service_instance)})
primitive.__class__ = NewPrimitive
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 8d802e14d06..e7039a3bdb3 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -39,6 +39,7 @@ class Service:
self.first_start = True
self.current_rank = None
self.dump_iter_dir = None
+ self.primitive_counters = {}
def build_hook(self, module_type, name):
def forward_hook(api_or_module_name, module, input, output):
@@ -137,4 +138,79 @@ class Service:
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
api_register.api_set_hook_func()
if self.model:
- api_register.register_hooks(self)
+ register_hooks(self)
+
+ def register_hooks(self):
+ primitive_set = set()
+ for name, cell in self.model.cells_and_names():
+ for pname, primitive in cell._primitives.items():
+ primitive_set.add((pname, primitive))
+
+ for pname, primitive in primitive_set:
+ NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname)})
+ primitive.__class__ = NewPrimitive
+
+ def wrap_primitive(self, origin_func, primitive_name):
+ def func(self, *args, **kwargs):
+ if primitive_name not in self.primitive_counters:
+ self.primitive_counters[primitive_name] = 0
+ else:
+ self.primitive_counters[primitive_name] += 1
+
+ current_count = self.primitive_counters[primitive_name]
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ captured_grads_input = []
+ captured_grads_output = []
+
+ def input_backward_hook(grad):
+ captured_grads_input.append(grad)
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
+ self.data_collector.backward_input_data_collect(
+ backward_primitive_name, self, os.getpid(), new_module_input_output
+ )
+
+ def output_backward_hook(grad):
+ captured_grads_output.append(grad)
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardOutputs(
+ grad_output=tuple(captured_grads_output)
+ )
+ self.data_collector.backward_output_data_collect(
+ backward_primitive_name, self, os.getpid(), new_module_input_output
+ )
+
+ if not self.switch:
+ return origin_func(*args, **kwargs)
+
+ hooked_inputs = []
+ for idx, arg in enumerate(args):
+ if isinstance(arg, Tensor):
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ hooked_inputs.append(arg_hooked)
+ else:
+ hooked_inputs.append(arg)
+
+ out = origin_func(*hooked_inputs, **kwargs)
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+
+ if self.data_collector:
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
+ self.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
+ if self.data_collector.if_return_forward_new_output():
+ out = self.data_collector.get_forward_new_output()
+
+ if isinstance(out, Tensor):
+ out = ops.HookBackward(output_backward_hook)(out)
+ elif isinstance(out, tuple):
+ hooked_outputs = []
+ for tensor in out:
+ if isinstance(tensor, Tensor):
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ else:
+ hooked_outputs.append(tensor)
+ out = tuple(hooked_outputs)
+
+ return out
+
+ return func
--
Gitee
From fac1c0cc2fe68b7d40c67cf4a2305feffb5a3243 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 17:21:39 +0800
Subject: [PATCH 17/67] Update api_registry.py
---
.../mindspore/dump/hook_cell/api_registry.py | 150 +++++++++---------
1 file changed, 75 insertions(+), 75 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 63a02cefef2..4c2f81cb905 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -107,80 +107,80 @@ class ApiRegistry:
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
- def wrap_primitive(self, origin_func, primitive_name, service_instance):
- primitive_instance = self
- def func(self, *args, **kwargs):
- if primitive_name not in primitive_instance.primitive_counters:
- primitive_instance.primitive_counters[primitive_name] = 0
- else:
- primitive_instance.primitive_counters[primitive_name] += 1
-
- current_count = primitive_instance.primitive_counters[primitive_name]
- updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
- captured_grads_input = []
- captured_grads_output = []
-
- def input_backward_hook(grad):
- captured_grads_input.append(grad)
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name, self, os.getpid(), new_module_input_output
- )
-
- def output_backward_hook(grad):
- captured_grads_output.append(grad)
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardOutputs(
- grad_output=tuple(captured_grads_output)
- )
- service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name, self, os.getpid(), new_module_input_output
- )
-
- if not service_instance.switch:
- return origin_func(*args, **kwargs)
-
- hooked_inputs = []
- for idx, arg in enumerate(args):
- if isinstance(arg, Tensor):
- arg_hooked = ops.HookBackward(input_backward_hook)(arg)
- hooked_inputs.append(arg_hooked)
- else:
- hooked_inputs.append(arg)
-
- out = origin_func(*hooked_inputs, **kwargs)
- forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
-
- if service_instance.data_collector:
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
- service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
- if service_instance.data_collector.if_return_forward_new_output():
- out = service_instance.data_collector.get_forward_new_output()
-
- if isinstance(out, Tensor):
- out = ops.HookBackward(output_backward_hook)(out)
- elif isinstance(out, tuple):
- hooked_outputs = []
- for tensor in out:
- if isinstance(tensor, Tensor):
- hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
- else:
- hooked_outputs.append(tensor)
- out = tuple(hooked_outputs)
-
- return out
-
- return func
-
- def register_hooks(self, service_instance):
- primitive_set = set()
- for name, cell in service_instance.model.cells_and_names():
- for pname, primitive in cell._primitives.items():
- primitive_set.add((pname, primitive))
-
- for pname, primitive in primitive_set:
- NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname, service_instance)})
- primitive.__class__ = NewPrimitive
+ # def wrap_primitive(self, origin_func, primitive_name, service_instance):
+ # primitive_instance = self
+ # def func(self, *args, **kwargs):
+ # if primitive_name not in primitive_instance.primitive_counters:
+ # primitive_instance.primitive_counters[primitive_name] = 0
+ # else:
+ # primitive_instance.primitive_counters[primitive_name] += 1
+
+ # current_count = primitive_instance.primitive_counters[primitive_name]
+ # updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ # captured_grads_input = []
+ # captured_grads_output = []
+
+ # def input_backward_hook(grad):
+ # captured_grads_input.append(grad)
+ # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ # new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
+ # service_instance.data_collector.backward_input_data_collect(
+ # backward_primitive_name, self, os.getpid(), new_module_input_output
+ # )
+
+ # def output_backward_hook(grad):
+ # captured_grads_output.append(grad)
+ # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ # new_module_input_output = ModuleBackwardOutputs(
+ # grad_output=tuple(captured_grads_output)
+ # )
+ # service_instance.data_collector.backward_output_data_collect(
+ # backward_primitive_name, self, os.getpid(), new_module_input_output
+ # )
+
+ # if not service_instance.switch:
+ # return origin_func(*args, **kwargs)
+
+ # hooked_inputs = []
+ # for idx, arg in enumerate(args):
+ # if isinstance(arg, Tensor):
+ # arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ # hooked_inputs.append(arg_hooked)
+ # else:
+ # hooked_inputs.append(arg)
+
+ # out = origin_func(*hooked_inputs, **kwargs)
+ # forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+
+ # if service_instance.data_collector:
+ # module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
+ # service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
+ # if service_instance.data_collector.if_return_forward_new_output():
+ # out = service_instance.data_collector.get_forward_new_output()
+
+ # if isinstance(out, Tensor):
+ # out = ops.HookBackward(output_backward_hook)(out)
+ # elif isinstance(out, tuple):
+ # hooked_outputs = []
+ # for tensor in out:
+ # if isinstance(tensor, Tensor):
+ # hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ # else:
+ # hooked_outputs.append(tensor)
+ # out = tuple(hooked_outputs)
+
+ # return out
+
+ # return func
+
+ # def register_hooks(self, service_instance):
+ # primitive_set = set()
+ # for name, cell in service_instance.model.cells_and_names():
+ # for pname, primitive in cell._primitives.items():
+ # primitive_set.add((pname, primitive))
+
+ # for pname, primitive in primitive_set:
+ # NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname, service_instance)})
+ # primitive.__class__ = NewPrimitive
api_register = ApiRegistry()
--
Gitee
From 36916dd494f20c97f2ec5f9926df833d7830ffb6 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 17:24:00 +0800
Subject: [PATCH 18/67] Update service.py
---
.../msprobe/mindspore/service.py | 149 +++++++++---------
1 file changed, 76 insertions(+), 73 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index e7039a3bdb3..b3f0f3885df 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -74,6 +74,81 @@ class Service:
return backward_hook(*args, **kwargs)
return wrap_forward_hook, wrap_backward_hook
+
+ def wrap_primitive(self, origin_func, primitive_name):
+ def func(self, *args, **kwargs):
+ if primitive_name not in self.primitive_counters:
+ self.primitive_counters[primitive_name] = 0
+ else:
+ self.primitive_counters[primitive_name] += 1
+
+ current_count = self.primitive_counters[primitive_name]
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ captured_grads_input = []
+ captured_grads_output = []
+
+ def input_backward_hook(grad):
+ captured_grads_input.append(grad)
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
+ self.data_collector.backward_input_data_collect(
+ backward_primitive_name, self, os.getpid(), new_module_input_output
+ )
+
+ def output_backward_hook(grad):
+ captured_grads_output.append(grad)
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardOutputs(
+ grad_output=tuple(captured_grads_output)
+ )
+ self.data_collector.backward_output_data_collect(
+ backward_primitive_name, self, os.getpid(), new_module_input_output
+ )
+
+ if not self.switch:
+ return origin_func(*args, **kwargs)
+
+ hooked_inputs = []
+ for idx, arg in enumerate(args):
+ if isinstance(arg, Tensor):
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ hooked_inputs.append(arg_hooked)
+ else:
+ hooked_inputs.append(arg)
+
+ out = origin_func(*hooked_inputs, **kwargs)
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+
+ if self.data_collector:
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
+ self.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
+ if self.data_collector.if_return_forward_new_output():
+ out = self.data_collector.get_forward_new_output()
+
+ if isinstance(out, Tensor):
+ out = ops.HookBackward(output_backward_hook)(out)
+ elif isinstance(out, tuple):
+ hooked_outputs = []
+ for tensor in out:
+ if isinstance(tensor, Tensor):
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ else:
+ hooked_outputs.append(tensor)
+ out = tuple(hooked_outputs)
+
+ return out
+
+ return func
+
+ def register_hooks(self):
+ primitive_set = set()
+ for name, cell in self.model.cells_and_names():
+ for pname, primitive in cell._primitives.items():
+ primitive_set.add((pname, primitive))
+
+ for pname, primitive in primitive_set:
+ NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname)})
+ primitive.__class__ = NewPrimitive
def step(self):
self.current_iter += 1
@@ -140,77 +215,5 @@ class Service:
if self.model:
register_hooks(self)
- def register_hooks(self):
- primitive_set = set()
- for name, cell in self.model.cells_and_names():
- for pname, primitive in cell._primitives.items():
- primitive_set.add((pname, primitive))
-
- for pname, primitive in primitive_set:
- NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname)})
- primitive.__class__ = NewPrimitive
-
- def wrap_primitive(self, origin_func, primitive_name):
- def func(self, *args, **kwargs):
- if primitive_name not in self.primitive_counters:
- self.primitive_counters[primitive_name] = 0
- else:
- self.primitive_counters[primitive_name] += 1
-
- current_count = self.primitive_counters[primitive_name]
- updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
- captured_grads_input = []
- captured_grads_output = []
-
- def input_backward_hook(grad):
- captured_grads_input.append(grad)
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- self.data_collector.backward_input_data_collect(
- backward_primitive_name, self, os.getpid(), new_module_input_output
- )
-
- def output_backward_hook(grad):
- captured_grads_output.append(grad)
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardOutputs(
- grad_output=tuple(captured_grads_output)
- )
- self.data_collector.backward_output_data_collect(
- backward_primitive_name, self, os.getpid(), new_module_input_output
- )
-
- if not self.switch:
- return origin_func(*args, **kwargs)
-
- hooked_inputs = []
- for idx, arg in enumerate(args):
- if isinstance(arg, Tensor):
- arg_hooked = ops.HookBackward(input_backward_hook)(arg)
- hooked_inputs.append(arg_hooked)
- else:
- hooked_inputs.append(arg)
-
- out = origin_func(*hooked_inputs, **kwargs)
- forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
-
- if self.data_collector:
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
- self.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
- if self.data_collector.if_return_forward_new_output():
- out = self.data_collector.get_forward_new_output()
-
- if isinstance(out, Tensor):
- out = ops.HookBackward(output_backward_hook)(out)
- elif isinstance(out, tuple):
- hooked_outputs = []
- for tensor in out:
- if isinstance(tensor, Tensor):
- hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
- else:
- hooked_outputs.append(tensor)
- out = tuple(hooked_outputs)
-
- return out
+
- return func
--
Gitee
From da2f5ee3c3017c44b39e3133bc9d978c8d2a89f8 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 17:31:13 +0800
Subject: [PATCH 19/67] Update service.py
---
.../msprobe/mindspore/service.py | 24 ++++++++++---------
1 file changed, 13 insertions(+), 11 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index b3f0f3885df..24ca7d4609b 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -16,6 +16,7 @@
import os
from pathlib import Path
import functools
+from mindspore.common.tensor import Tensor
from msprobe.core.data_dump.data_collector import build_data_collector
from msprobe.core.data_dump.scope import BaseScope
@@ -25,7 +26,7 @@ from msprobe.mindspore.common.log import logger
from msprobe.core.common.utils import Const
from msprobe.core.common.exceptions import DistributedNotInitializedError
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
-from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
+from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, ModuleBackwardInputs, ModuleBackwardOutputs
class Service:
@@ -74,15 +75,16 @@ class Service:
return backward_hook(*args, **kwargs)
return wrap_forward_hook, wrap_backward_hook
-
+
def wrap_primitive(self, origin_func, primitive_name):
+ service_instance = self
def func(self, *args, **kwargs):
- if primitive_name not in self.primitive_counters:
- self.primitive_counters[primitive_name] = 0
+ if primitive_name not in service_instance.primitive_counters:
+ service_instance.primitive_counters[primitive_name] = 0
else:
- self.primitive_counters[primitive_name] += 1
+ service_instance.primitive_counters[primitive_name] += 1
- current_count = self.primitive_counters[primitive_name]
+ current_count = service_instance.primitive_counters[primitive_name]
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
captured_grads_input = []
captured_grads_output = []
@@ -119,11 +121,11 @@ class Service:
out = origin_func(*hooked_inputs, **kwargs)
forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
- if self.data_collector:
+ if service_instance.data_collector:
module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
- self.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
- if self.data_collector.if_return_forward_new_output():
- out = self.data_collector.get_forward_new_output()
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
+ if service_instance.data_collector.if_return_forward_new_output():
+ out = service_instance.data_collector.get_forward_new_output()
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
@@ -213,7 +215,7 @@ class Service:
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
api_register.api_set_hook_func()
if self.model:
- register_hooks(self)
+ self.register_hooks()
--
Gitee
From f8735e2f3abb6321279cd019fe0d5d93a83cd7c0 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 17:32:14 +0800
Subject: [PATCH 20/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 24ca7d4609b..9435c48068d 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -17,6 +17,7 @@ import os
from pathlib import Path
import functools
from mindspore.common.tensor import Tensor
+from mindspore import ops
from msprobe.core.data_dump.data_collector import build_data_collector
from msprobe.core.data_dump.scope import BaseScope
@@ -107,7 +108,7 @@ class Service:
backward_primitive_name, self, os.getpid(), new_module_input_output
)
- if not self.switch:
+ if not service_instance.switch:
return origin_func(*args, **kwargs)
hooked_inputs = []
--
Gitee
From 71eea504ba61a5bb02e19bbdc97e5451f13e7c1e Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 26 Jul 2024 17:35:44 +0800
Subject: [PATCH 21/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 9435c48068d..fddd27718f4 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -94,7 +94,7 @@ class Service:
captured_grads_input.append(grad)
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- self.data_collector.backward_input_data_collect(
+ service_instance.data_collector.backward_input_data_collect(
backward_primitive_name, self, os.getpid(), new_module_input_output
)
@@ -104,7 +104,7 @@ class Service:
new_module_input_output = ModuleBackwardOutputs(
grad_output=tuple(captured_grads_output)
)
- self.data_collector.backward_output_data_collect(
+ service_instance.data_collector.backward_output_data_collect(
backward_primitive_name, self, os.getpid(), new_module_input_output
)
--
Gitee
From 41d45586ddd38c9ab246124e7c08eb3e7ce37103 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 27 Jul 2024 10:32:00 +0800
Subject: [PATCH 22/67] Update api_registry.py
---
.../mindspore/dump/hook_cell/api_registry.py | 76 -------------------
1 file changed, 76 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 4c2f81cb905..3e9425c6fd8 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -107,80 +107,4 @@ class ApiRegistry:
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
- # def wrap_primitive(self, origin_func, primitive_name, service_instance):
- # primitive_instance = self
- # def func(self, *args, **kwargs):
- # if primitive_name not in primitive_instance.primitive_counters:
- # primitive_instance.primitive_counters[primitive_name] = 0
- # else:
- # primitive_instance.primitive_counters[primitive_name] += 1
-
- # current_count = primitive_instance.primitive_counters[primitive_name]
- # updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
- # captured_grads_input = []
- # captured_grads_output = []
-
- # def input_backward_hook(grad):
- # captured_grads_input.append(grad)
- # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- # new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- # service_instance.data_collector.backward_input_data_collect(
- # backward_primitive_name, self, os.getpid(), new_module_input_output
- # )
-
- # def output_backward_hook(grad):
- # captured_grads_output.append(grad)
- # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- # new_module_input_output = ModuleBackwardOutputs(
- # grad_output=tuple(captured_grads_output)
- # )
- # service_instance.data_collector.backward_output_data_collect(
- # backward_primitive_name, self, os.getpid(), new_module_input_output
- # )
-
- # if not service_instance.switch:
- # return origin_func(*args, **kwargs)
-
- # hooked_inputs = []
- # for idx, arg in enumerate(args):
- # if isinstance(arg, Tensor):
- # arg_hooked = ops.HookBackward(input_backward_hook)(arg)
- # hooked_inputs.append(arg_hooked)
- # else:
- # hooked_inputs.append(arg)
-
- # out = origin_func(*hooked_inputs, **kwargs)
- # forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
-
- # if service_instance.data_collector:
- # module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
- # service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
- # if service_instance.data_collector.if_return_forward_new_output():
- # out = service_instance.data_collector.get_forward_new_output()
-
- # if isinstance(out, Tensor):
- # out = ops.HookBackward(output_backward_hook)(out)
- # elif isinstance(out, tuple):
- # hooked_outputs = []
- # for tensor in out:
- # if isinstance(tensor, Tensor):
- # hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
- # else:
- # hooked_outputs.append(tensor)
- # out = tuple(hooked_outputs)
-
- # return out
-
- # return func
-
- # def register_hooks(self, service_instance):
- # primitive_set = set()
- # for name, cell in service_instance.model.cells_and_names():
- # for pname, primitive in cell._primitives.items():
- # primitive_set.add((pname, primitive))
-
- # for pname, primitive in primitive_set:
- # NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname, service_instance)})
- # primitive.__class__ = NewPrimitive
-
api_register = ApiRegistry()
--
Gitee
From 5d69b20f2a7db6576522d38462f6d57b81933854 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 27 Jul 2024 15:51:41 +0800
Subject: [PATCH 23/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index fddd27718f4..e9ff2ab7520 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -40,8 +40,8 @@ class Service:
self.current_iter = 0
self.first_start = True
self.current_rank = None
- self.dump_iter_dir = None
self.primitive_counters = {}
+ self.dump_iter_dir = None
def build_hook(self, module_type, name):
def forward_hook(api_or_module_name, module, input, output):
--
Gitee
From 95339b740327abfdb91d7df7928b473cde1b1e2c Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 27 Jul 2024 15:54:59 +0800
Subject: [PATCH 24/67] Update api_registry.py
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index 3e9425c6fd8..b30505f2d4f 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -13,13 +13,8 @@
# limitations under the License.
# ============================================================================
-import os
-import functools
import mindspore as ms
from mindspore import ops
-from mindspore.common.tensor import Tensor
-from msprobe.core.common.utils import Const
-from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, ModuleBackwardInputs, ModuleBackwardOutputs
from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
@@ -41,7 +36,6 @@ class ApiRegistry:
self.norm_inner_ops_hook_attr = {}
self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
- self.primitive_counters = {}
@staticmethod
def store_ori_attr(ori_api_group, api_list, api_ori_attr):
--
Gitee
From f06b73be7c40b5a7828e23c47459caa9ece5c30f Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 27 Jul 2024 15:55:36 +0800
Subject: [PATCH 25/67] Update api_registry.py
---
.../msprobe/mindspore/dump/hook_cell/api_registry.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
index b30505f2d4f..5508416fde0 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py
@@ -14,7 +14,6 @@
# ============================================================================
import mindspore as ms
-from mindspore import ops
from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
@@ -101,4 +100,5 @@ class ApiRegistry:
if attr_name.startswith(Const.ATTR_NAME_PREFIX):
self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
+
api_register = ApiRegistry()
--
Gitee
From b8f4d421fe9010a01163b8cfac20935569ff423f Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 29 Jul 2024 09:59:49 +0800
Subject: [PATCH 26/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index e9ff2ab7520..ba4eca957e2 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -80,6 +80,7 @@ class Service:
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
def func(self, *args, **kwargs):
+ service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
if primitive_name not in service_instance.primitive_counters:
service_instance.primitive_counters[primitive_name] = 0
else:
--
Gitee
From 4ce94697456e2ffe08a5adf594c36c33e7c84fa3 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 29 Jul 2024 11:30:18 +0800
Subject: [PATCH 27/67] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E6=97=A8=E5=9C=A8?=
=?UTF-8?q?=E6=9C=80=E5=90=8E=E4=B8=80=E5=9D=97=E6=94=B6=E9=9B=86=E6=95=B0?=
=?UTF-8?q?=E6=8D=AE?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../msprobe/mindspore/service.py | 35 ++++++++++---------
1 file changed, 19 insertions(+), 16 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index ba4eca957e2..f7dfa044e8e 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -80,24 +80,15 @@ class Service:
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
def func(self, *args, **kwargs):
- service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
- if primitive_name not in service_instance.primitive_counters:
- service_instance.primitive_counters[primitive_name] = 0
- else:
- service_instance.primitive_counters[primitive_name] += 1
-
- current_count = service_instance.primitive_counters[primitive_name]
- updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
- captured_grads_input = []
- captured_grads_output = []
-
def input_backward_hook(grad):
captured_grads_input.append(grad)
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name, self, os.getpid(), new_module_input_output
- )
+ if len(captured_grads_input) == num_tensors:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
+ service_instance.data_collector.backward_input_data_collect(
+ backward_primitive_name, self, os.getpid(), new_module_input_output
+ )
+# 等所有加入后在收集
def output_backward_hook(grad):
captured_grads_output.append(grad)
@@ -109,10 +100,22 @@ class Service:
backward_primitive_name, self, os.getpid(), new_module_input_output
)
+ service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
+ if primitive_name not in service_instance.primitive_counters:
+ service_instance.primitive_counters[primitive_name] = 0
+ else:
+ service_instance.primitive_counters[primitive_name] += 1
+
+ current_count = service_instance.primitive_counters[primitive_name]
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ captured_grads_input = []
+ captured_grads_output = []
+
if not service_instance.switch:
return origin_func(*args, **kwargs)
hooked_inputs = []
+ num_tensors = sum(isinstance(arg, Tensor) for arg in args)
for idx, arg in enumerate(args):
if isinstance(arg, Tensor):
arg_hooked = ops.HookBackward(input_backward_hook)(arg)
--
Gitee
From 9348c9a4df412c22050c2697279f5cd030474031 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 29 Jul 2024 11:31:14 +0800
Subject: [PATCH 28/67] =?UTF-8?q?=E6=89=93=E5=8D=B0tensor=E5=80=BC?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
debug/accuracy_tools/msprobe/mindspore/service.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index f7dfa044e8e..9ec251353df 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -116,6 +116,7 @@ class Service:
hooked_inputs = []
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
+ print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
for idx, arg in enumerate(args):
if isinstance(arg, Tensor):
arg_hooked = ops.HookBackward(input_backward_hook)(arg)
--
Gitee
From 61659c62dfc2a3e7f8ecf2b9bfe2e78f34a82b23 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 29 Jul 2024 14:53:34 +0800
Subject: [PATCH 29/67] Update base.py
---
.../msprobe/core/data_dump/data_processor/base.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index e725d362e8c..a18efccbc6a 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -240,13 +240,16 @@ class BaseDataProcessor:
api_info_struct = {}
if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
api_info_struct[name] = {}
- self.api_data_category = Const.OUTPUT
+ # self.api_data_category = Const.OUTPUT
+ self.api_data_category = Const.INPUT
+
input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
api_info_struct[name][Const.GRAD_INPUT] = input_info_list
if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
api_info_struct[name] = api_info_struct.get(name, {})
- self.api_data_category = Const.INPUT
+ self.api_data_category = Const.OUTPUT
+ # self.api_data_category = Const.INPUT
output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
--
Gitee
From f0818349a58c98888ca393b27d662488aef5ecd9 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 29 Jul 2024 15:03:17 +0800
Subject: [PATCH 30/67] Update base.py
---
.../msprobe/core/data_dump/data_processor/base.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index a18efccbc6a..c55c5079e25 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -262,7 +262,8 @@ class BaseDataProcessor:
api_info_struct = {}
if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
api_info_struct[name] = {}
- self.api_data_category = Const.INPUT
+ self.api_data_category = Const.OUTPUT
+ # self.api_data_category = Const.INPUT
output_info_list = self.analyze_element(module_input_output.grad_input_tuple)
api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
return api_info_struct
@@ -274,7 +275,8 @@ class BaseDataProcessor:
api_info_struct = {}
if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
api_info_struct[name] = {}
- self.api_data_category = Const.OUTPUT
+ self.api_data_category = Const.INPUT
+ # self.api_data_category = Const.OUTPUT
input_info_list = self.analyze_element(module_input_output.grad_output_tuple)
api_info_struct[name][Const.GRAD_INPUT] = input_info_list
return api_info_struct
--
Gitee
From 7c2eec3adecd2a55791c0fd6a0adfba86fbf1f2c Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 29 Jul 2024 15:41:36 +0800
Subject: [PATCH 31/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 9ec251353df..2d4419b370c 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -117,7 +117,9 @@ class Service:
hooked_inputs = []
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
+ print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
for idx, arg in enumerate(args):
+ print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
if isinstance(arg, Tensor):
arg_hooked = ops.HookBackward(input_backward_hook)(arg)
hooked_inputs.append(arg_hooked)
@@ -133,6 +135,7 @@ class Service:
if service_instance.data_collector.if_return_forward_new_output():
out = service_instance.data_collector.get_forward_new_output()
+ # num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
elif isinstance(out, tuple):
--
Gitee
From b0f4bfca7869145e067b3465e52ab351c8c8ee07 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 29 Jul 2024 15:54:18 +0800
Subject: [PATCH 32/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 2d4419b370c..06535109463 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -136,6 +136,7 @@ class Service:
out = service_instance.data_collector.get_forward_new_output()
# num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
+ print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
elif isinstance(out, tuple):
--
Gitee
From dc1f5a437661e8deb02d628f6b0a5b5cce4ff2ce Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 29 Jul 2024 16:01:48 +0800
Subject: [PATCH 33/67] Update service.py
---
.../msprobe/mindspore/service.py | 19 ++++++++++---------
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 06535109463..54d34861dc6 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -92,13 +92,14 @@ class Service:
def output_backward_hook(grad):
captured_grads_output.append(grad)
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardOutputs(
- grad_output=tuple(captured_grads_output)
- )
- service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name, self, os.getpid(), new_module_input_output
- )
+ if len(captured_grads_output) == num_output_tensors:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardOutputs(
+ grad_output=tuple(captured_grads_output)
+ )
+ service_instance.data_collector.backward_output_data_collect(
+ backward_primitive_name, self, os.getpid(), new_module_input_output
+ )
service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
if primitive_name not in service_instance.primitive_counters:
@@ -134,8 +135,8 @@ class Service:
service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
if service_instance.data_collector.if_return_forward_new_output():
out = service_instance.data_collector.get_forward_new_output()
-
- # num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
+
+ num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
--
Gitee
From 574e684341d5e16a2d3c8983faf507f7ee218233 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 30 Jul 2024 20:16:31 +0800
Subject: [PATCH 34/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 54d34861dc6..6c681a12422 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -223,8 +223,8 @@ class Service:
def register_hook_new(self):
logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
if self.config.level == "L1":
- api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
- api_register.api_set_hook_func()
+ # api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+ # api_register.api_set_hook_func()
if self.model:
self.register_hooks()
--
Gitee
From 5b91d4839615c8572b6a478f42f31dd946cdce0c Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 1 Aug 2024 11:28:42 +0800
Subject: [PATCH 35/67] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dbug?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../core/data_dump/data_processor/mindspore_processor.py | 5 +++--
debug/accuracy_tools/msprobe/mindspore/service.py | 2 +-
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
index 7533e2ee0de..db02f26f607 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
@@ -57,8 +57,9 @@ class MindsporeDataProcessor(BaseDataProcessor):
if data.numel() == 0:
return tensor_stat
elif data.dtype == ms.bool_:
- tensor_stat.max = self.mint_ops_func["max"](data).item()
- tensor_stat.min = self.mint_ops_func["min"](data).item()
+ data_np = data.asnumpy()
+ tensor_stat.max = bool(np.max(data_np))
+ tensor_stat.min = bool(np.min(data_np))
elif not data.shape:
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 6c681a12422..46a054275b1 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -131,7 +131,7 @@ class Service:
forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
if service_instance.data_collector:
- module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
+ module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
if service_instance.data_collector.if_return_forward_new_output():
out = service_instance.data_collector.get_forward_new_output()
--
Gitee
From 74297ab27573157dc4050f162c30e8f846c987e2 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 1 Aug 2024 17:07:56 +0800
Subject: [PATCH 36/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 15 +++++++++++----
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 46a054275b1..42ab8b31aa8 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -77,19 +77,21 @@ class Service:
return wrap_forward_hook, wrap_backward_hook
+
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
- def func(self, *args, **kwargs):
+ def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
def input_backward_hook(grad):
captured_grads_input.append(grad)
if len(captured_grads_input) == num_tensors:
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name, self, os.getpid(), new_module_input_output
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
)
-# 等所有加入后在收集
+ return input_backward_hook
+ def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
def output_backward_hook(grad):
captured_grads_output.append(grad)
if len(captured_grads_output) == num_output_tensors:
@@ -98,9 +100,11 @@ class Service:
grad_output=tuple(captured_grads_output)
)
service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name, self, os.getpid(), new_module_input_output
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
)
+ return output_backward_hook
+ def func(self, *args, **kwargs):
service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
if primitive_name not in service_instance.primitive_counters:
service_instance.primitive_counters[primitive_name] = 0
@@ -119,6 +123,7 @@ class Service:
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
+ input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
for idx, arg in enumerate(args):
print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
if isinstance(arg, Tensor):
@@ -138,6 +143,8 @@ class Service:
num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
+ output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
+
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
elif isinstance(out, tuple):
--
Gitee
From 19330452a66810b9a7eacf23121ce1b406b47828 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 1 Aug 2024 17:16:59 +0800
Subject: [PATCH 37/67] Update service.py
---
.../msprobe/mindspore/service.py | 64 ++++++++++++-------
1 file changed, 41 insertions(+), 23 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 42ab8b31aa8..fea947dfd90 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -81,28 +81,46 @@ class Service:
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
- def input_backward_hook(grad):
- captured_grads_input.append(grad)
- if len(captured_grads_input) == num_tensors:
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
- return input_backward_hook
-
- def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
- def output_backward_hook(grad):
- captured_grads_output.append(grad)
- if len(captured_grads_output) == num_output_tensors:
+ # def input_backward_hook(grad):
+ # captured_grads_input.append(grad)
+ # if len(captured_grads_input) == num_tensors:
+ # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ # new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
+ # service_instance.data_collector.backward_input_data_collect(
+ # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ # )
+ # return input_backward_hook
+
+ # def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
+ # def output_backward_hook(grad):
+ # captured_grads_output.append(grad)
+ # if len(captured_grads_output) == num_output_tensors:
+ # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ # new_module_input_output = ModuleBackwardOutputs(
+ # grad_output=tuple(captured_grads_output)
+ # )
+ # service_instance.data_collector.backward_output_data_collect(
+ # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ # )
+ # return output_backward_hook
+ def create_backward_hook(captured_grads, num_grads, updated_primitive_name, is_input):
+ def backward_hook(grad):
+ captured_grads.append(grad)
+ if len(captured_grads) == num_grads:
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardOutputs(
- grad_output=tuple(captured_grads_output)
- )
- service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
- return output_backward_hook
+ if is_input:
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+ service_instance.data_collector.backward_input_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ else:
+ new_module_input_output = ModuleBackwardOutputs(
+ grad_output=tuple(captured_grads)
+ )
+ service_instance.data_collector.backward_output_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ return backward_hook
def func(self, *args, **kwargs):
service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
@@ -123,7 +141,7 @@ class Service:
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
- input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
+ input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, is_input=True)
for idx, arg in enumerate(args):
print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
if isinstance(arg, Tensor):
@@ -143,7 +161,7 @@ class Service:
num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
- output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
+ output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name, is_input=False)
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
--
Gitee
From 9418a12bde6e93059a89adc2f2bde84df536731a Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 1 Aug 2024 17:20:49 +0800
Subject: [PATCH 38/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index fea947dfd90..326bc75d1b9 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -80,7 +80,7 @@ class Service:
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
- def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
+ # def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
# def input_backward_hook(grad):
# captured_grads_input.append(grad)
# if len(captured_grads_input) == num_tensors:
@@ -141,7 +141,7 @@ class Service:
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
- input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, is_input=True)
+ input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, is_input=True)
for idx, arg in enumerate(args):
print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
if isinstance(arg, Tensor):
@@ -161,7 +161,7 @@ class Service:
num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
- output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name, is_input=False)
+ output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name, is_input=False)
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
--
Gitee
From d44838b7fd782fd3441ccd8bdf0d47da94ae265d Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 1 Aug 2024 17:28:58 +0800
Subject: [PATCH 39/67] Update service.py
---
.../msprobe/mindspore/service.py | 38 +++++++++----------
1 file changed, 19 insertions(+), 19 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 326bc75d1b9..42b8df62695 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -77,7 +77,24 @@ class Service:
return wrap_forward_hook, wrap_backward_hook
-
+ def create_backward_hook(captured_grads, num_grads, updated_primitive_name, is_input):
+ def backward_hook(grad):
+ captured_grads.append(grad)
+ if len(captured_grads) == num_grads:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ if is_input:
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+ service_instance.data_collector.backward_input_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ else:
+ new_module_input_output = ModuleBackwardOutputs(
+ grad_output=tuple(captured_grads)
+ )
+ service_instance.data_collector.backward_output_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ return backward_hook
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
# def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
@@ -103,24 +120,7 @@ class Service:
# backward_primitive_name, service_instance, os.getpid(), new_module_input_output
# )
# return output_backward_hook
- def create_backward_hook(captured_grads, num_grads, updated_primitive_name, is_input):
- def backward_hook(grad):
- captured_grads.append(grad)
- if len(captured_grads) == num_grads:
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- if is_input:
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
- service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
- else:
- new_module_input_output = ModuleBackwardOutputs(
- grad_output=tuple(captured_grads)
- )
- service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
- return backward_hook
+
def func(self, *args, **kwargs):
service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
--
Gitee
From ae8a1f63bd9ef480a413d8443f3b9c33ce618967 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 1 Aug 2024 17:33:52 +0800
Subject: [PATCH 40/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 42b8df62695..108c4df94fc 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -141,7 +141,7 @@ class Service:
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
- input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, is_input=True)
+ input_backward_hook = self.create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, is_input=True)
for idx, arg in enumerate(args):
print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
if isinstance(arg, Tensor):
@@ -161,7 +161,7 @@ class Service:
num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
- output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name, is_input=False)
+ output_backward_hook = self.create_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name, is_input=False)
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
--
Gitee
From 9a97bcf905b4686e746aff63dfff3cf1099f6645 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 1 Aug 2024 20:29:57 +0800
Subject: [PATCH 41/67] service.py: call create_backward_hook via service_instance instead of self
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 108c4df94fc..7868d0897cb 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -141,7 +141,7 @@ class Service:
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
- input_backward_hook = self.create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, is_input=True)
+ input_backward_hook = service_instance.create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, is_input=True)
for idx, arg in enumerate(args):
print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
if isinstance(arg, Tensor):
@@ -161,7 +161,7 @@ class Service:
num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
- output_backward_hook = self.create_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name, is_input=False)
+ output_backward_hook = service_instance.create_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name, is_input=False)
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
--
Gitee
From 2713c9d15fd71ccb4c51b9487a8f013ad69438d4 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Thu, 1 Aug 2024 20:37:07 +0800
Subject: [PATCH 42/67] service.py: split create_backward_hook into separate input/output backward hook factories inside wrap_primitive
---
.../msprobe/mindspore/service.py | 74 +++++++++----------
1 file changed, 37 insertions(+), 37 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 7868d0897cb..fed54b08d33 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -77,49 +77,49 @@ class Service:
return wrap_forward_hook, wrap_backward_hook
- def create_backward_hook(captured_grads, num_grads, updated_primitive_name, is_input):
- def backward_hook(grad):
- captured_grads.append(grad)
- if len(captured_grads) == num_grads:
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- if is_input:
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+ # def create_backward_hook(captured_grads, num_grads, updated_primitive_name, is_input):
+ # def backward_hook(grad):
+ # captured_grads.append(grad)
+ # if len(captured_grads) == num_grads:
+ # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ # if is_input:
+ # new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+ # service_instance.data_collector.backward_input_data_collect(
+ # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ # )
+ # else:
+ # new_module_input_output = ModuleBackwardOutputs(
+ # grad_output=tuple(captured_grads)
+ # )
+ # service_instance.data_collector.backward_output_data_collect(
+ # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ # )
+ # return backward_hook
+ def wrap_primitive(self, origin_func, primitive_name):
+ service_instance = self
+ def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
+ def input_backward_hook(grad):
+ captured_grads_input.append(grad)
+ if len(captured_grads_input) == num_tensors:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
service_instance.data_collector.backward_input_data_collect(
backward_primitive_name, service_instance, os.getpid(), new_module_input_output
)
- else:
+ return input_backward_hook
+
+ def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
+ def output_backward_hook(grad):
+ captured_grads_output.append(grad)
+ if len(captured_grads_output) == num_output_tensors:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
new_module_input_output = ModuleBackwardOutputs(
- grad_output=tuple(captured_grads)
+ grad_output=tuple(captured_grads_output)
)
service_instance.data_collector.backward_output_data_collect(
backward_primitive_name, service_instance, os.getpid(), new_module_input_output
)
- return backward_hook
- def wrap_primitive(self, origin_func, primitive_name):
- service_instance = self
- # def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
- # def input_backward_hook(grad):
- # captured_grads_input.append(grad)
- # if len(captured_grads_input) == num_tensors:
- # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- # new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- # service_instance.data_collector.backward_input_data_collect(
- # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- # )
- # return input_backward_hook
-
- # def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
- # def output_backward_hook(grad):
- # captured_grads_output.append(grad)
- # if len(captured_grads_output) == num_output_tensors:
- # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- # new_module_input_output = ModuleBackwardOutputs(
- # grad_output=tuple(captured_grads_output)
- # )
- # service_instance.data_collector.backward_output_data_collect(
- # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- # )
- # return output_backward_hook
+ return output_backward_hook
def func(self, *args, **kwargs):
@@ -141,7 +141,7 @@ class Service:
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
- input_backward_hook = service_instance.create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, is_input=True)
+ input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
for idx, arg in enumerate(args):
print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
if isinstance(arg, Tensor):
@@ -161,7 +161,7 @@ class Service:
num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
- output_backward_hook = service_instance.create_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name, is_input=False)
+ output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
--
Gitee
From 65b228b154ca1fa14f679951a97a92fc62b81d2c Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 2 Aug 2024 11:16:58 +0800
Subject: [PATCH 43/67] service.py: re-enable api_register hook initialization for L1 level in register_hook_new
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index fed54b08d33..1c4e98450c2 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -248,8 +248,8 @@ class Service:
def register_hook_new(self):
logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
if self.config.level == "L1":
- # api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
- # api_register.api_set_hook_func()
+ api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+ api_register.api_set_hook_func()
if self.model:
self.register_hooks()
--
Gitee
From 980a7537b5e76418f77de886e1ca0443d6ce98c3 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Fri, 2 Aug 2024 15:52:20 +0800
Subject: [PATCH 44/67] service.py: wrap primitive hooks in try/except with debug output and clear captured grad buffers after collection
---
.../msprobe/mindspore/service.py | 150 ++++++++++--------
1 file changed, 87 insertions(+), 63 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 1c4e98450c2..1d2439b14c2 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -99,83 +99,107 @@ class Service:
service_instance = self
def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
def input_backward_hook(grad):
- captured_grads_input.append(grad)
- if len(captured_grads_input) == num_tensors:
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
+ try:
+ captured_grads_input.append(grad)
+ if len(captured_grads_input) == num_tensors:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
+ service_instance.data_collector.backward_input_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ captured_grads_input.clear()
+
+ except Exception as e:
+ print(f"Error occurred in input_backward_hook: {e}")
+ print(f"Captured grads input: {captured_grads_input}")
+ print(f"Num tensors: {num_tensors}")
+ print(f"Updated primitive name: {updated_primitive_name}")
+ raise # 重新引发异常
return input_backward_hook
def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
def output_backward_hook(grad):
- captured_grads_output.append(grad)
- if len(captured_grads_output) == num_output_tensors:
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- new_module_input_output = ModuleBackwardOutputs(
- grad_output=tuple(captured_grads_output)
- )
- service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
+ try:
+ captured_grads_output.append(grad)
+ if len(captured_grads_output) == num_output_tensors:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardOutputs(
+ grad_output=tuple(captured_grads_output)
+ )
+ service_instance.data_collector.backward_output_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ # 释放 captured_grads_output 列表
+ captured_grads_output.clear()
+ except Exception as e:
+ print(f"Error occurred in output_backward_hook: {e}")
+ print(f"Captured grads output: {captured_grads_output}")
+ print(f"Num output tensors: {num_output_tensors}")
+ print(f"Updated primitive name: {updated_primitive_name}")
+ raise # 重新引发异常
return output_backward_hook
def func(self, *args, **kwargs):
- service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
- if primitive_name not in service_instance.primitive_counters:
- service_instance.primitive_counters[primitive_name] = 0
- else:
- service_instance.primitive_counters[primitive_name] += 1
-
- current_count = service_instance.primitive_counters[primitive_name]
- updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
- captured_grads_input = []
- captured_grads_output = []
-
- if not service_instance.switch:
- return origin_func(*args, **kwargs)
-
- hooked_inputs = []
- num_tensors = sum(isinstance(arg, Tensor) for arg in args)
- print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
- print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
- input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
- for idx, arg in enumerate(args):
- print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
- if isinstance(arg, Tensor):
- arg_hooked = ops.HookBackward(input_backward_hook)(arg)
- hooked_inputs.append(arg_hooked)
+ try:
+ service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
+ if primitive_name not in service_instance.primitive_counters:
+ service_instance.primitive_counters[primitive_name] = 0
else:
- hooked_inputs.append(arg)
+ service_instance.primitive_counters[primitive_name] += 1
+
+ current_count = service_instance.primitive_counters[primitive_name]
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ captured_grads_input = []
+ captured_grads_output = []
- out = origin_func(*hooked_inputs, **kwargs)
- forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+ if not service_instance.switch:
+ return origin_func(*args, **kwargs)
- if service_instance.data_collector:
- module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
- service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
- if service_instance.data_collector.if_return_forward_new_output():
- out = service_instance.data_collector.get_forward_new_output()
-
- num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
- print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
- output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
-
- if isinstance(out, Tensor):
- out = ops.HookBackward(output_backward_hook)(out)
- elif isinstance(out, tuple):
- hooked_outputs = []
- for tensor in out:
- if isinstance(tensor, Tensor):
- hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ hooked_inputs = []
+ num_tensors = sum(isinstance(arg, Tensor) for arg in args)
+ print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
+ print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
+ input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
+ for idx, arg in enumerate(args):
+ print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
+ if isinstance(arg, Tensor):
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ hooked_inputs.append(arg_hooked)
else:
- hooked_outputs.append(tensor)
- out = tuple(hooked_outputs)
+ hooked_inputs.append(arg)
+
+ out = origin_func(*hooked_inputs, **kwargs)
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
- return out
+ if service_instance.data_collector:
+ module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
+ if service_instance.data_collector.if_return_forward_new_output():
+ out = service_instance.data_collector.get_forward_new_output()
+
+ num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
+ print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
+ output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
+
+ if isinstance(out, Tensor):
+ out = ops.HookBackward(output_backward_hook)(out)
+ elif isinstance(out, tuple):
+ hooked_outputs = []
+ for tensor in out:
+ if isinstance(tensor, Tensor):
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ else:
+ hooked_outputs.append(tensor)
+ out = tuple(hooked_outputs)
+ return out
+ except Exception as e:
+ print(f"Error occurred in wrap_primitive: {e}")
+ print(f"Arguments(args): {args}")
+ print(f"Arguments(kwargs): {kwargs}")
+ print(f"Current primitive name: {primitive_name}")
+ raise Exception("This is a primitive op dump error")
return func
def register_hooks(self):
--
Gitee
From b3c234dedec9c4d63660461aabd1aaf96f99420f Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 3 Aug 2024 15:33:06 +0800
Subject: [PATCH 45/67] service.py: fix spacing before inline comment on bare raise in input_backward_hook
---
debug/accuracy_tools/msprobe/mindspore/service.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 1d2439b14c2..2b27e84b7b9 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -114,7 +114,7 @@ class Service:
print(f"Captured grads input: {captured_grads_input}")
print(f"Num tensors: {num_tensors}")
print(f"Updated primitive name: {updated_primitive_name}")
- raise # 重新引发异常
+ raise # 重新引发异常
return input_backward_hook
def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
--
Gitee
From 02e759b5866555068f308c2de6d2445317a56adb Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 3 Aug 2024 15:33:55 +0800
Subject: [PATCH 46/67] service.py: raise descriptive exceptions from input/output backward hooks instead of bare raise
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 2b27e84b7b9..66e5d71e664 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -114,7 +114,7 @@ class Service:
print(f"Captured grads input: {captured_grads_input}")
print(f"Num tensors: {num_tensors}")
print(f"Updated primitive name: {updated_primitive_name}")
- raise # 重新引发异常
+ raise Exception("This is a primitive op input_backward dump error") # 重新引发异常
return input_backward_hook
def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
@@ -136,7 +136,7 @@ class Service:
print(f"Captured grads output: {captured_grads_output}")
print(f"Num output tensors: {num_output_tensors}")
print(f"Updated primitive name: {updated_primitive_name}")
- raise # 重新引发异常
+ raise Exception("This is a primitive op output_backward dump error")# 重新引发异常
return output_backward_hook
--
Gitee
From 712981c40d3153a5ef7cf27f1ba4d013edea1eb7 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 3 Aug 2024 16:28:42 +0800
Subject: [PATCH 47/67] service.py: temporarily disable api_register hook initialization for L1 level in register_hook_new
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 66e5d71e664..b8561a2d67a 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -272,8 +272,8 @@ class Service:
def register_hook_new(self):
logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
if self.config.level == "L1":
- api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
- api_register.api_set_hook_func()
+ # api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+ # api_register.api_set_hook_func()
if self.model:
self.register_hooks()
--
Gitee
From de6b85dd48facbf8b44d5a9c79dff50f624a5e5e Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 3 Aug 2024 17:47:58 +0800
Subject: [PATCH 48/67] service.py: move visit_and_clear_overflow_status from func entry to after the forward call
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index b8561a2d67a..7c26f54ec34 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -142,7 +142,7 @@ class Service:
def func(self, *args, **kwargs):
try:
- service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
+ # service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
if primitive_name not in service_instance.primitive_counters:
service_instance.primitive_counters[primitive_name] = 0
else:
@@ -171,7 +171,7 @@ class Service:
out = origin_func(*hooked_inputs, **kwargs)
forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
-
+ service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
if service_instance.data_collector:
module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
--
Gitee
From 9a781c25756b57a9ab325d86803b5b727253219a Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Sat, 3 Aug 2024 17:51:57 +0800
Subject: [PATCH 49/67] service.py: clear overflow status in backward hooks and use forward_primitive_name in the forward path
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 7c26f54ec34..ffdbc1e56ad 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -103,6 +103,7 @@ class Service:
captured_grads_input.append(grad)
if len(captured_grads_input) == num_tensors:
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
service_instance.data_collector.backward_input_data_collect(
backward_primitive_name, service_instance, os.getpid(), new_module_input_output
@@ -123,6 +124,7 @@ class Service:
captured_grads_output.append(grad)
if len(captured_grads_output) == num_output_tensors:
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ service_instance.data_collector.visit_and_clear_overflow_status(updated_primitive_name)
new_module_input_output = ModuleBackwardOutputs(
grad_output=tuple(captured_grads_output)
)
@@ -171,7 +173,7 @@ class Service:
out = origin_func(*hooked_inputs, **kwargs)
forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
- service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
+ service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
if service_instance.data_collector:
module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
--
Gitee
From faf5b9f00c0f4661dc9695216ea890a8bd1f3826 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 5 Aug 2024 11:22:21 +0800
Subject: [PATCH 50/67] service.py: extract _update_primitive_counters, replace debug prints with exception messages, add hook_primitive_inputs/outputs helpers
---
.../msprobe/mindspore/service.py | 78 +++++++++++--------
1 file changed, 46 insertions(+), 32 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index ffdbc1e56ad..12c5ea98c0e 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -97,6 +97,7 @@ class Service:
# return backward_hook
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
+
def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
def input_backward_hook(grad):
try:
@@ -110,12 +111,9 @@ class Service:
)
captured_grads_input.clear()
- except Exception as e:
- print(f"Error occurred in input_backward_hook: {e}")
- print(f"Captured grads input: {captured_grads_input}")
- print(f"Num tensors: {num_tensors}")
- print(f"Updated primitive name: {updated_primitive_name}")
- raise Exception("This is a primitive op input_backward dump error") # 重新引发异常
+ except Exception as exception:
+ raise Exception(f"This is a primitive op input_backward dump error: {exception}"
+ f", updated_primitive_name: {updated_primitive_name}")
return input_backward_hook
def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
@@ -124,7 +122,7 @@ class Service:
captured_grads_output.append(grad)
if len(captured_grads_output) == num_output_tensors:
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- service_instance.data_collector.visit_and_clear_overflow_status(updated_primitive_name)
+ service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
new_module_input_output = ModuleBackwardOutputs(
grad_output=tuple(captured_grads_output)
)
@@ -133,38 +131,51 @@ class Service:
)
# 释放 captured_grads_output 列表
captured_grads_output.clear()
- except Exception as e:
- print(f"Error occurred in output_backward_hook: {e}")
- print(f"Captured grads output: {captured_grads_output}")
- print(f"Num output tensors: {num_output_tensors}")
- print(f"Updated primitive name: {updated_primitive_name}")
- raise Exception("This is a primitive op output_backward dump error")# 重新引发异常
+ except Exception as exception:
+ raise Exception(f"This is a primitive op output_backward dump error: {exception}"
+ f", updated_primitive_name: {updated_primitive_name}")
return output_backward_hook
+ def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
+ hooked_inputs = []
+ num_tensors = sum(isinstance(arg, Tensor) for arg in args)
+ input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
+ for idx, arg in enumerate(args):
+ if isinstance(arg, Tensor):
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ hooked_inputs.append(arg_hooked)
+ else:
+ hooked_inputs.append(arg)
+ return hooked_inputs
+
+
+ def hook_primitive_outputs(out, create_output_backward_hook, output_backward_hook):
+ if isinstance(out, Tensor):
+ return ops.HookBackward(create_output_backward_hook)(out)
+ elif isinstance(out, tuple):
+ hooked_outputs = []
+ for tensor in out:
+ if isinstance(tensor, Tensor):
+ hooked_outputs.append(ops.HookBackward(create_output_backward_hook)(tensor))
+ else:
+ hooked_outputs.append(tensor)
+ return tuple(hooked_outputs)
+ return out
def func(self, *args, **kwargs):
try:
- # service_instance.data_collector.visit_and_clear_overflow_status(primitive_name)
- if primitive_name not in service_instance.primitive_counters:
- service_instance.primitive_counters[primitive_name] = 0
- else:
- service_instance.primitive_counters[primitive_name] += 1
-
+ service_instance._update_primitive_counters(primitive_name)
current_count = service_instance.primitive_counters[primitive_name]
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
- captured_grads_input = []
- captured_grads_output = []
+ captured_grads_input, captured_grads_output = [], []
if not service_instance.switch:
return origin_func(*args, **kwargs)
hooked_inputs = []
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
- print(f"Number of tensor arguments: {num_tensors}") # 打印 num_tensors 的值
- print(f"Arguments(args): type={type(args)}") # 打印每个 arg 的类型
input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
for idx, arg in enumerate(args):
- print(f"Argument {idx}: type={type(arg)}") # 打印每个 arg 的类型
if isinstance(arg, Tensor):
arg_hooked = ops.HookBackward(input_backward_hook)(arg)
hooked_inputs.append(arg_hooked)
@@ -179,11 +190,11 @@ class Service:
service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
if service_instance.data_collector.if_return_forward_new_output():
out = service_instance.data_collector.get_forward_new_output()
-
+
num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
-
+
if isinstance(out, Tensor):
out = ops.HookBackward(output_backward_hook)(out)
elif isinstance(out, tuple):
@@ -196,14 +207,17 @@ class Service:
out = tuple(hooked_outputs)
return out
- except Exception as e:
- print(f"Error occurred in wrap_primitive: {e}")
- print(f"Arguments(args): {args}")
- print(f"Arguments(kwargs): {kwargs}")
- print(f"Current primitive name: {primitive_name}")
- raise Exception("This is a primitive op dump error")
+ except Exception as exception:
+ raise Exception(f"This is a primitive op dump error: {exception}"
+ f", primitive_name: {primitive_name}")
return func
+ def _update_primitive_counters(self, primitive_name):
+ if primitive_name not in self.primitive_counters:
+ self.primitive_counters[primitive_name] = 0
+ else:
+ self.primitive_counters[primitive_name] += 1
+
def register_hooks(self):
primitive_set = set()
for name, cell in self.model.cells_and_names():
--
Gitee
From bbac45931076a68b1e87308862e359550140ddf4 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 5 Aug 2024 11:40:43 +0800
Subject: [PATCH 51/67] Update service.py
---
.../msprobe/mindspore/service.py | 169 ++++++++++--------
1 file changed, 92 insertions(+), 77 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 12c5ea98c0e..4cc60821606 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -77,69 +77,78 @@ class Service:
return wrap_forward_hook, wrap_backward_hook
- # def create_backward_hook(captured_grads, num_grads, updated_primitive_name, is_input):
- # def backward_hook(grad):
- # captured_grads.append(grad)
- # if len(captured_grads) == num_grads:
- # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- # if is_input:
- # new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
- # service_instance.data_collector.backward_input_data_collect(
- # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- # )
- # else:
- # new_module_input_output = ModuleBackwardOutputs(
- # grad_output=tuple(captured_grads)
- # )
- # service_instance.data_collector.backward_output_data_collect(
- # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- # )
- # return backward_hook
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
- def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
- def input_backward_hook(grad):
+ # def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
+ # def input_backward_hook(grad):
+ # try:
+ # captured_grads_input.append(grad)
+ # if len(captured_grads_input) == num_tensors:
+ # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ # service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+ # new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
+ # service_instance.data_collector.backward_input_data_collect(
+ # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ # )
+ # captured_grads_input.clear()
+ #
+ # except Exception as exception:
+ # raise Exception(f"This is a primitive op input_backward dump error: {exception}"
+ # f", updated_primitive_name: {updated_primitive_name}")
+ # return input_backward_hook
+ #
+ # def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
+ # def output_backward_hook(grad):
+ # try:
+ # captured_grads_output.append(grad)
+ # if len(captured_grads_output) == num_output_tensors:
+ # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ # service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+ # new_module_input_output = ModuleBackwardOutputs(
+ # grad_output=tuple(captured_grads_output)
+ # )
+ # service_instance.data_collector.backward_output_data_collect(
+ # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ # )
+ # # 释放 captured_grads_output 列表
+ # captured_grads_output.clear()
+ # except Exception as exception:
+ # raise Exception(f"This is a primitive op output_backward dump error: {exception}"
+ # f", updated_primitive_name: {updated_primitive_name}")
+ # return output_backward_hook
+
+ def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
+ def backward_hook(grad):
try:
- captured_grads_input.append(grad)
- if len(captured_grads_input) == num_tensors:
+ captured_grads.append(grad)
+ if len(captured_grads) == num_tensors:
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
- captured_grads_input.clear()
+ if hook_type == 'input':
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+ service_instance.data_collector.backward_input_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ elif hook_type == 'output':
+ new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
+ service_instance.data_collector.backward_output_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+
+ captured_grads.clear()
except Exception as exception:
- raise Exception(f"This is a primitive op input_backward dump error: {exception}"
+ raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception}"
f", updated_primitive_name: {updated_primitive_name}")
- return input_backward_hook
- def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
- def output_backward_hook(grad):
- try:
- captured_grads_output.append(grad)
- if len(captured_grads_output) == num_output_tensors:
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
- new_module_input_output = ModuleBackwardOutputs(
- grad_output=tuple(captured_grads_output)
- )
- service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
- # 释放 captured_grads_output 列表
- captured_grads_output.clear()
- except Exception as exception:
- raise Exception(f"This is a primitive op output_backward dump error: {exception}"
- f", updated_primitive_name: {updated_primitive_name}")
- return output_backward_hook
+ return backward_hook
def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
hooked_inputs = []
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
- input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
+ input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
+ 'input')
for idx, arg in enumerate(args):
if isinstance(arg, Tensor):
arg_hooked = ops.HookBackward(input_backward_hook)(arg)
@@ -148,15 +157,19 @@ class Service:
hooked_inputs.append(arg)
return hooked_inputs
+ def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
+ num_output_tensors = sum(
+ isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
+ output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
+ updated_primitive_name, 'output')
- def hook_primitive_outputs(out, create_output_backward_hook, output_backward_hook):
if isinstance(out, Tensor):
- return ops.HookBackward(create_output_backward_hook)(out)
+ return ops.HookBackward(output_backward_hook)(out)
elif isinstance(out, tuple):
hooked_outputs = []
for tensor in out:
if isinstance(tensor, Tensor):
- hooked_outputs.append(ops.HookBackward(create_output_backward_hook)(tensor))
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
else:
hooked_outputs.append(tensor)
return tuple(hooked_outputs)
@@ -172,39 +185,41 @@ class Service:
if not service_instance.switch:
return origin_func(*args, **kwargs)
- hooked_inputs = []
- num_tensors = sum(isinstance(arg, Tensor) for arg in args)
- input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
- for idx, arg in enumerate(args):
- if isinstance(arg, Tensor):
- arg_hooked = ops.HookBackward(input_backward_hook)(arg)
- hooked_inputs.append(arg_hooked)
- else:
- hooked_inputs.append(arg)
+ # hooked_inputs = []
+ # num_tensors = sum(isinstance(arg, Tensor) for arg in args)
+ # input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
+ # for idx, arg in enumerate(args):
+ # if isinstance(arg, Tensor):
+ # arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ # hooked_inputs.append(arg_hooked)
+ # else:
+ # hooked_inputs.append(arg)
+ hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
out = origin_func(*hooked_inputs, **kwargs)
forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
if service_instance.data_collector:
module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
- service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, self,
+ os.getpid(), module_input_output)
if service_instance.data_collector.if_return_forward_new_output():
out = service_instance.data_collector.get_forward_new_output()
-
- num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
- print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
- output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
-
- if isinstance(out, Tensor):
- out = ops.HookBackward(output_backward_hook)(out)
- elif isinstance(out, tuple):
- hooked_outputs = []
- for tensor in out:
- if isinstance(tensor, Tensor):
- hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
- else:
- hooked_outputs.append(tensor)
- out = tuple(hooked_outputs)
+ out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
+ # num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
+ # print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
+ # output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
+ #
+ # if isinstance(out, Tensor):
+ # out = ops.HookBackward(output_backward_hook)(out)
+ # elif isinstance(out, tuple):
+ # hooked_outputs = []
+ # for tensor in out:
+ # if isinstance(tensor, Tensor):
+ # hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ # else:
+ # hooked_outputs.append(tensor)
+ # out = tuple(hooked_outputs)
return out
except Exception as exception:
--
Gitee
From 51a52c0117bb1b244c6c5c010638a3da2f33b883 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 5 Aug 2024 11:47:52 +0800
Subject: [PATCH 52/67] Update service.py
---
.../msprobe/mindspore/service.py | 61 -------------------
1 file changed, 61 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 4cc60821606..200f7759bf8 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -80,44 +80,6 @@ class Service:
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
- # def create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name):
- # def input_backward_hook(grad):
- # try:
- # captured_grads_input.append(grad)
- # if len(captured_grads_input) == num_tensors:
- # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- # service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
- # new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads_input))
- # service_instance.data_collector.backward_input_data_collect(
- # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- # )
- # captured_grads_input.clear()
- #
- # except Exception as exception:
- # raise Exception(f"This is a primitive op input_backward dump error: {exception}"
- # f", updated_primitive_name: {updated_primitive_name}")
- # return input_backward_hook
- #
- # def create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name):
- # def output_backward_hook(grad):
- # try:
- # captured_grads_output.append(grad)
- # if len(captured_grads_output) == num_output_tensors:
- # backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
- # service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
- # new_module_input_output = ModuleBackwardOutputs(
- # grad_output=tuple(captured_grads_output)
- # )
- # service_instance.data_collector.backward_output_data_collect(
- # backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- # )
- # # 释放 captured_grads_output 列表
- # captured_grads_output.clear()
- # except Exception as exception:
- # raise Exception(f"This is a primitive op output_backward dump error: {exception}"
- # f", updated_primitive_name: {updated_primitive_name}")
- # return output_backward_hook
-
def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
def backward_hook(grad):
try:
@@ -185,15 +147,6 @@ class Service:
if not service_instance.switch:
return origin_func(*args, **kwargs)
- # hooked_inputs = []
- # num_tensors = sum(isinstance(arg, Tensor) for arg in args)
- # input_backward_hook = create_input_backward_hook(captured_grads_input, num_tensors, updated_primitive_name)
- # for idx, arg in enumerate(args):
- # if isinstance(arg, Tensor):
- # arg_hooked = ops.HookBackward(input_backward_hook)(arg)
- # hooked_inputs.append(arg_hooked)
- # else:
- # hooked_inputs.append(arg)
hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
out = origin_func(*hooked_inputs, **kwargs)
@@ -206,20 +159,6 @@ class Service:
if service_instance.data_collector.if_return_forward_new_output():
out = service_instance.data_collector.get_forward_new_output()
out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
- # num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
- # print(f"Arguments(out): type={type(out)}") # 打印每个 arg 的类型
- # output_backward_hook = create_output_backward_hook(captured_grads_output, num_output_tensors, updated_primitive_name)
- #
- # if isinstance(out, Tensor):
- # out = ops.HookBackward(output_backward_hook)(out)
- # elif isinstance(out, tuple):
- # hooked_outputs = []
- # for tensor in out:
- # if isinstance(tensor, Tensor):
- # hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
- # else:
- # hooked_outputs.append(tensor)
- # out = tuple(hooked_outputs)
return out
except Exception as exception:
--
Gitee
From b62bcb43e3f7ea0a726187c947437413d5a2f2a4 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 5 Aug 2024 17:25:24 +0800
Subject: [PATCH 53/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 617f30ca1fa..a6c570f55f1 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -142,16 +142,16 @@ class Service:
return tuple(hooked_outputs)
return out
- def func(self, *args, **kwargs):
+ def func(instance_self, *args, **kwargs):
try:
+ if not service_instance.switch:
+ return origin_func(*args, **kwargs)
+
service_instance._update_primitive_counters(primitive_name)
current_count = service_instance.primitive_counters[primitive_name]
updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
captured_grads_input, captured_grads_output = [], []
- if not service_instance.switch:
- return origin_func(*args, **kwargs)
-
hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
out = origin_func(*hooked_inputs, **kwargs)
@@ -159,7 +159,7 @@ class Service:
service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
if service_instance.data_collector:
module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
- service_instance.data_collector.forward_data_collect(forward_primitive_name, self,
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
os.getpid(), module_input_output)
if service_instance.data_collector.if_return_forward_new_output():
out = service_instance.data_collector.get_forward_new_output()
@@ -191,6 +191,8 @@ class Service:
self.current_iter += 1
self.data_collector.update_iter(self.current_iter)
HOOKCell.cell_count = defaultdict(int)
+ self.primitive_counters.clear()
+
def start(self, model=None):
self.model = model
--
Gitee
From d3fdd9c3d7cc466dac0d3c42976a50bd2c7348f0 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 5 Aug 2024 20:27:10 +0800
Subject: [PATCH 54/67] Update service.py
---
.../msprobe/mindspore/service.py | 50 +++++++++++--------
1 file changed, 29 insertions(+), 21 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index a6c570f55f1..0829465bcc9 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -17,10 +17,10 @@ import os
import copy
from pathlib import Path
import functools
-from mindspore.common.tensor import Tensor
-from mindspore import ops
from collections import defaultdict
+from mindspore.common.tensor import Tensor
+from mindspore import ops
from msprobe.core.data_dump.data_collector import build_data_collector
from msprobe.core.data_dump.scope import BaseScope
from msprobe.mindspore.common.utils import get_rank_if_initialized
@@ -29,7 +29,8 @@ from msprobe.mindspore.common.log import logger
from msprobe.core.common.utils import Const
from msprobe.core.common.exceptions import DistributedNotInitializedError
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
-from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, ModuleBackwardInputs, ModuleBackwardOutputs
+from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,\
+ ModuleBackwardInputs, ModuleBackwardOutputs
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
@@ -106,8 +107,13 @@ class Service:
captured_grads.clear()
except Exception as exception:
- raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception}"
- f", updated_primitive_name: {updated_primitive_name}")
+ raise Exception(
+ "This is a primitive op {hook_type}_backward dump error: {exception},"
+ " updated_primitive_name: {updated_primitive_name}".format(
+ hook_type=hook_type, exception=exception, updated_primitive_name=updated_primitive_name
+ )
+ )
+#改为.format()
return backward_hook
@@ -116,7 +122,7 @@ class Service:
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
'input')
- for idx, arg in enumerate(args):
+ for _, arg in enumerate(args):
if isinstance(arg, Tensor):
arg_hooked = ops.HookBackward(input_backward_hook)(arg)
hooked_inputs.append(arg_hooked)
@@ -144,14 +150,14 @@ class Service:
def func(instance_self, *args, **kwargs):
try:
+ service_instance.update_primitive_counters(primitive_name)
+ current_count = service_instance.primitive_counters[primitive_name]
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+
if not service_instance.switch:
return origin_func(*args, **kwargs)
- service_instance._update_primitive_counters(primitive_name)
- current_count = service_instance.primitive_counters[primitive_name]
- updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
captured_grads_input, captured_grads_output = [], []
-
hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
out = origin_func(*hooked_inputs, **kwargs)
@@ -167,25 +173,27 @@ class Service:
return out
except Exception as exception:
- raise Exception(f"This is a primitive op dump error: {exception}"
- f", primitive_name: {primitive_name}")
+ raise Exception("This is a primitive op dump error: {},"
+ " primitive_name: {}".format(exception, primitive_name))
+
return func
- def _update_primitive_counters(self, primitive_name):
+ def update_primitive_counters(self, primitive_name):
if primitive_name not in self.primitive_counters:
self.primitive_counters[primitive_name] = 0
else:
self.primitive_counters[primitive_name] += 1
def register_hooks(self):
- primitive_set = set()
- for name, cell in self.model.cells_and_names():
- for pname, primitive in cell._primitives.items():
- primitive_set.add((pname, primitive))
-
- for pname, primitive in primitive_set:
- NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname)})
- primitive.__class__ = NewPrimitive
+ primitive_set = set()
+ for _, cell in self.model.cells_and_names():
+ for pname, primitive in cell._primitives.items():
+ primitive_set.add((pname, primitive))
+
+ for pname, primitive in primitive_set:
+ NewPrimitive = type('NewPrimitive', (primitive.__class__,),
+ {'__call__': self.wrap_primitive(primitive.__call__, pname)})
+ primitive.__class__ = NewPrimitive
def step(self):
self.current_iter += 1
--
Gitee
From 2e4e86c099cd49b30d4f8bd4b93ee2c86fb67d10 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 5 Aug 2024 20:33:18 +0800
Subject: [PATCH 55/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 0829465bcc9..8f3e4c3472e 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -112,8 +112,8 @@ class Service:
" updated_primitive_name: {updated_primitive_name}".format(
hook_type=hook_type, exception=exception, updated_primitive_name=updated_primitive_name
)
- )
-#改为.format()
+ ) from exception
+
return backward_hook
@@ -174,7 +174,7 @@ class Service:
return out
except Exception as exception:
raise Exception("This is a primitive op dump error: {},"
- " primitive_name: {}".format(exception, primitive_name))
+ " primitive_name: {}".format(exception, primitive_name)) from exception
return func
--
Gitee
From 07bb2eb7e97b96f9c6c095b3bf1e44f3ee3b11fc Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Mon, 5 Aug 2024 20:35:57 +0800
Subject: [PATCH 56/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 8f3e4c3472e..bc81eb7bc56 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -266,8 +266,8 @@ class Service:
def register_hook_new(self):
logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
if self.config.level == "L1":
- # api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
- # api_register.api_set_hook_func()
+ api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
+ api_register.api_set_hook_func()
if self.model:
self.register_hooks()
--
Gitee
From 66504478621bb1cf3396fb640983f60c83a91c3d Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 11:16:01 +0800
Subject: [PATCH 57/67] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A3=80=E8=A7=86?=
=?UTF-8?q?=E6=84=8F=E8=A7=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../data_processor/mindspore_processor.py | 4 +-
.../msprobe/mindspore/service.py | 51 ++++++++++---------
2 files changed, 28 insertions(+), 27 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
index e7718504791..b28817e4aa7 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
@@ -75,8 +75,8 @@ class MindsporeDataProcessor(BaseDataProcessor):
return tensor_stat
elif data.dtype == ms.bool_:
data_np = data.asnumpy()
- tensor_stat.max = bool(np.max(data_np))
- tensor_stat.min = bool(np.min(data_np))
+ tensor_stat.max = np.max(data_np)
+ tensor_stat.min = np.min(data_np)
elif not data.shape:
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index bc81eb7bc56..6ff31664156 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -88,24 +88,25 @@ class Service:
def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
def backward_hook(grad):
+ captured_grads.append(grad)
try:
- captured_grads.append(grad)
- if len(captured_grads) == num_tensors:
+ if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
-
- if hook_type == 'input':
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
- service_instance.data_collector.backward_input_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
- elif hook_type == 'output':
- new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
- service_instance.data_collector.backward_output_data_collect(
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
- )
-
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+ service_instance.data_collector.backward_input_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
captured_grads.clear()
+ elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+ new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
+ service_instance.data_collector.backward_output_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ captured_grads.clear()
+
except Exception as exception:
raise Exception(
"This is a primitive op {hook_type}_backward dump error: {exception},"
@@ -114,14 +115,13 @@ class Service:
)
) from exception
-
return backward_hook
def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
hooked_inputs = []
num_tensors = sum(isinstance(arg, Tensor) for arg in args)
input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
- 'input')
+ Const.INPUT)
for _, arg in enumerate(args):
if isinstance(arg, Tensor):
arg_hooked = ops.HookBackward(input_backward_hook)(arg)
@@ -131,10 +131,12 @@ class Service:
return hooked_inputs
def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
- num_output_tensors = sum(
- isinstance(tensor, Tensor) for tensor in out if isinstance(out, tuple)) if isinstance(out, tuple) else 1
+ if isinstance(out, tuple):
+ num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
+ else:
+ num_output_tensors = 1
output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
- updated_primitive_name, 'output')
+ updated_primitive_name, Const.OUTPUT)
if isinstance(out, Tensor):
return ops.HookBackward(output_backward_hook)(out)
@@ -148,7 +150,7 @@ class Service:
return tuple(hooked_outputs)
return out
- def func(instance_self, *args, **kwargs):
+ def wrapped_primitive_call(instance_self, *args, **kwargs):
try:
service_instance.update_primitive_counters(primitive_name)
current_count = service_instance.primitive_counters[primitive_name]
@@ -176,7 +178,7 @@ class Service:
raise Exception("This is a primitive op dump error: {},"
" primitive_name: {}".format(exception, primitive_name)) from exception
- return func
+ return wrapped_primitive_call
def update_primitive_counters(self, primitive_name):
if primitive_name not in self.primitive_counters:
@@ -201,7 +203,6 @@ class Service:
HOOKCell.cell_count = defaultdict(int)
self.primitive_counters.clear()
-
def start(self, model=None):
self.model = model
self.start_call = True
@@ -269,7 +270,7 @@ class Service:
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
api_register.api_set_hook_func()
if self.model:
+ print(f"Type of self.model: {type(self.model)}") # 使用 print
+ # 或者使用 logger
+ logger.info(f"Type of self.model: {type(self.model)}")
self.register_hooks()
-
-
-
--
Gitee
From 57b22c4e8a42630c224aeecfef3fc4e2e45075dd Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 11:24:16 +0800
Subject: [PATCH 58/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 6ff31664156..147dfde8b79 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -265,7 +265,7 @@ class Service:
dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
def register_hook_new(self):
- logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
+ logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
if self.config.level == "L1":
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
api_register.api_set_hook_func()
--
Gitee
From e9eac8eab43cec60586f062b98669f6a7884a0a3 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 12:43:16 +0800
Subject: [PATCH 59/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 147dfde8b79..9b9da9ec3fe 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -21,6 +21,7 @@ from collections import defaultdict
from mindspore.common.tensor import Tensor
from mindspore import ops
+from mindspore import nn
from msprobe.core.data_dump.data_collector import build_data_collector
from msprobe.core.data_dump.scope import BaseScope
from msprobe.mindspore.common.utils import get_rank_if_initialized
@@ -31,6 +32,7 @@ from msprobe.core.common.exceptions import DistributedNotInitializedError
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,\
ModuleBackwardInputs, ModuleBackwardOutputs
+from msprobe.core.common.exceptions import MsprobeException
from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
@@ -203,8 +205,16 @@ class Service:
HOOKCell.cell_count = defaultdict(int)
self.primitive_counters.clear()
+ def check_model_valid(model):
+ if not model or isinstance(model, nn.Cell):
+ return model
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
+ )
+
+
def start(self, model=None):
- self.model = model
+ self.model = self.check_model_valid(model)
self.start_call = True
logger.info("msprobe: debugger.start() is set successfully")
if self.config.step and self.current_iter > max(self.config.step):
--
Gitee
From deb692050d5252047d5a826dbc4bb2f8b1511bb2 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 12:47:34 +0800
Subject: [PATCH 60/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 9b9da9ec3fe..9f001fd8799 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -205,6 +205,7 @@ class Service:
HOOKCell.cell_count = defaultdict(int)
self.primitive_counters.clear()
+ @staticmethod
def check_model_valid(model):
if not model or isinstance(model, nn.Cell):
return model
--
Gitee
From 9ea3eb706d0fb414ed6149c7016ed5dc5afe767b Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 12:57:13 +0800
Subject: [PATCH 61/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 9f001fd8799..912cb9608ef 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -87,7 +87,7 @@ class Service:
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
-
+
def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
def backward_hook(grad):
captured_grads.append(grad)
@@ -206,7 +206,7 @@ class Service:
self.primitive_counters.clear()
@staticmethod
- def check_model_valid(model):
+ def check_model_valid(self, model):
if not model or isinstance(model, nn.Cell):
return model
raise MsprobeException(
--
Gitee
From 89d57719f029081eaa936abeebfd1f0092706e78 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 13:01:59 +0800
Subject: [PATCH 62/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 912cb9608ef..9a240c379db 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -87,7 +87,7 @@ class Service:
def wrap_primitive(self, origin_func, primitive_name):
service_instance = self
-
+
def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
def backward_hook(grad):
captured_grads.append(grad)
@@ -206,7 +206,7 @@ class Service:
self.primitive_counters.clear()
@staticmethod
- def check_model_valid(self, model):
+ def check_model_valid(model):
if not model or isinstance(model, nn.Cell):
return model
raise MsprobeException(
@@ -215,7 +215,7 @@ class Service:
def start(self, model=None):
- self.model = self.check_model_valid(model)
+ self.model = Service.check_model_valid(model)
self.start_call = True
logger.info("msprobe: debugger.start() is set successfully")
if self.config.step and self.current_iter > max(self.config.step):
--
Gitee
From 1525a443446b9de08412448887c9691948f20811 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 14:15:10 +0800
Subject: [PATCH 63/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 9a240c379db..8b71e988c8d 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -205,7 +205,7 @@ class Service:
HOOKCell.cell_count = defaultdict(int)
self.primitive_counters.clear()
- @staticmethod
+# @staticmethod
def check_model_valid(model):
if not model or isinstance(model, nn.Cell):
return model
@@ -213,7 +213,6 @@ class Service:
MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
)
-
def start(self, model=None):
self.model = Service.check_model_valid(model)
self.start_call = True
--
Gitee
From f7ef581ea6ed7bba6dc7b01c4000004efa93e5a3 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 14:32:52 +0800
Subject: [PATCH 64/67] Update service.py
---
.../msprobe/mindspore/service.py | 52 +++++++++++++------
1 file changed, 35 insertions(+), 17 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index 8b71e988c8d..af0bf466152 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -153,33 +153,51 @@ class Service:
return out
def wrapped_primitive_call(instance_self, *args, **kwargs):
- try:
- service_instance.update_primitive_counters(primitive_name)
- current_count = service_instance.primitive_counters[primitive_name]
- updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
- if not service_instance.switch:
- return origin_func(*args, **kwargs)
+ service_instance.update_primitive_counters(primitive_name)
+ current_count = service_instance.primitive_counters[primitive_name]
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+
+ if not service_instance.switch:
+ return origin_func(*args, **kwargs)
+
+ captured_grads_input, captured_grads_output = [], []
- captured_grads_input, captured_grads_output = [], []
+ try:
hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
+ except Exception as exception:
+ raise Exception("This is a primitive op dump error during input hooking: {},"
+ " primitive_name: {}".format(exception, primitive_name)) from exception
+ try:
out = origin_func(*hooked_inputs, **kwargs)
- forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
- service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
- if service_instance.data_collector:
- module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+ except Exception as exception:
+ raise Exception("This is a primitive op dump error during function call: {},"
+ " primitive_name: {}".format(exception, primitive_name)) from exception
+
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+ service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
+ if service_instance.data_collector:
+ module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+ try:
service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
os.getpid(), module_input_output)
- if service_instance.data_collector.if_return_forward_new_output():
- out = service_instance.data_collector.get_forward_new_output()
- out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
+ except Exception as exception:
+ raise Exception("This is a primitive op dump error during forward data collection: {},"
+ " primitive_name: {}".format(exception, primitive_name)) from exception
- return out
+ if service_instance.data_collector.if_return_forward_new_output():
+ out = service_instance.data_collector.get_forward_new_output()
+
+ try:
+ out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
except Exception as exception:
- raise Exception("This is a primitive op dump error: {},"
+ raise Exception("This is a primitive op dump error during output hooking: {},"
" primitive_name: {}".format(exception, primitive_name)) from exception
+ return out
+
+
return wrapped_primitive_call
def update_primitive_counters(self, primitive_name):
@@ -205,7 +223,7 @@ class Service:
HOOKCell.cell_count = defaultdict(int)
self.primitive_counters.clear()
-# @staticmethod
+ @staticmethod
def check_model_valid(model):
if not model or isinstance(model, nn.Cell):
return model
--
Gitee
From f32090296817bace40c8e971cae1c785866b6d9b Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 14:39:23 +0800
Subject: [PATCH 65/67] Update service.py
---
debug/accuracy_tools/msprobe/mindspore/service.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index af0bf466152..b795ec10342 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -293,12 +293,9 @@ class Service:
dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
def register_hook_new(self):
- logger.info("The {} hook function1111 is successfully mounted to the model.".format(self.config.task))
+ logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
if self.config.level == "L1":
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
api_register.api_set_hook_func()
if self.model:
- print(f"Type of self.model: {type(self.model)}") # 使用 print
- # 或者使用 logger
- logger.info(f"Type of self.model: {type(self.model)}")
self.register_hooks()
--
Gitee
From 554232ab08dc914c322e11528019b664068f55c6 Mon Sep 17 00:00:00 2001
From: yangxinxian <947098055@qq.com>
Date: Tue, 6 Aug 2024 15:04:40 +0800
Subject: [PATCH 66/67] primitive op dump
---
.gitignore | 9 +-
.../bench_functions/npu_fusion_attention.py | 56 +-
.../api_accuracy_checker/common/utils.py | 6 +-
.../run_ut/data_generate.py | 8 +-
.../run_ut/multi_run_ut.py | 11 +-
.../tensor_transport_layer/attl.py | 9 +-
debug/accuracy_tools/grad_tool/README.md | 34 +-
.../grad_tool/common/base_comparator.py | 20 +-
.../grad_tool/common/constant.py | 4 +-
.../accuracy_tools/grad_tool/common/utils.py | 11 +-
.../grad_tool/grad_ms/global_context.py | 30 +-
.../grad_tool/grad_ms/grad_analyzer.py | 26 +-
.../grad_tool/grad_ms/grad_comparator.py | 13 +-
.../accuracy_tools/grad_tool/grad_ms/utils.py | 19 +-
.../grad_tool/grad_pt/grad_comparator.py | 15 +-
.../grad_tool/grad_pt/grad_monitor.py | 7 +-
debug/accuracy_tools/msprobe/README.md | 45 +-
debug/accuracy_tools/msprobe/config/README.md | 60 +-
.../msprobe/core/common/const.py | 22 +-
.../msprobe/core/common/utils.py | 11 +-
.../msprobe/core/common_config.py | 41 +-
.../msprobe/core/data_dump/data_collector.py | 45 +-
.../core/data_dump/data_processor/base.py | 105 ++-
.../core/data_dump/data_processor/factory.py | 12 +-
.../data_processor/mindspore_processor.py | 90 +-
.../data_processor/pytorch_processor.py | 75 +-
.../msprobe/core/data_dump/json_writer.py | 22 +-
.../msprobe/mindspore/common/utils.py | 13 +
.../mindspore/debugger/debugger_config.py | 15 +-
.../mindspore/debugger/precision_debugger.py | 42 +-
.../dump/hook_cell/api_registry copy.py | 198 +++++
.../mindspore/dump/hook_cell/hook_cell.py | 18 +-
.../msprobe/mindspore/ms_config.py | 30 +-
.../msprobe/mindspore/service.py | 183 +++-
.../run_ut/data_generate.py | 12 +-
.../run_ut/multi_run_ut.py | 19 +-
.../api_accuracy_checker/run_ut/run_ut.py | 18 +-
.../pytorch/bench_functions/__init__.py | 15 +
.../pytorch/bench_functions/apply_adam_w.py | 28 +
.../bench_functions/confusion_transpose.py | 19 +
.../pytorch/bench_functions/fast_gelu.py | 55 ++
.../bench_functions/layer_norm_eval.py | 6 +
.../msprobe/pytorch/bench_functions/linear.py | 12 +
.../bench_functions/matmul_backward.py | 48 +
.../bench_functions/npu_fusion_attention.py | 421 +++++++++
.../pytorch/bench_functions/rms_norm.py | 15 +
.../pytorch/bench_functions/rotary_mul.py | 52 ++
.../bench_functions/scaled_mask_softmax.py | 26 +
.../msprobe/pytorch/bench_functions/swiglu.py | 55 ++
.../msprobe/pytorch/common/parse_json.py | 4 +-
.../msprobe/pytorch/common/utils.py | 35 +
.../msprobe/pytorch/compare/acc_compare.py | 29 +-
.../pytorch/compare/distributed_compare.py | 16 +-
.../pytorch/debugger/debugger_config.py | 2 +-
.../pytorch/debugger/precision_debugger.py | 1 +
.../pytorch/doc/api_accuracy_checker.md | 68 +-
.../msprobe/pytorch/doc/dump.md | 87 +-
.../pytorch/free_benchmark/common/constant.py | 3 +
.../pytorch/free_benchmark/common/utils.py | 4 +
.../result_handlers/base_handler.py | 72 +-
.../msprobe/pytorch/function_factory.py | 75 ++
.../pytorch/hook_module/hook_module.py | 6 +
.../pytorch/hook_module/support_wrap_ops.yaml | 3 +-
.../msprobe/pytorch/hook_module/wrap_aten.py | 21 +-
.../pytorch/hook_module/wrap_npu_custom.py | 20 +-
.../msprobe/pytorch/module_processer.py | 29 +-
.../msprobe/pytorch/pt_config.py | 4 +-
.../accuracy_tools/msprobe/pytorch/service.py | 15 +-
.../test/core_ut/test_common_config.py | 2 +-
.../test/mindspore_ut/test_ms_config.py | 4 +-
.../run_ut/test_multi_run_ut.py | 4 +-
.../msprobe/test/pytorch_ut/test_pt_config.py | 2 +-
debug/accuracy_tools/setup.py | 2 +-
.../.github/workflows/libkineto_ci.yml | 56 --
.../workflows/tb_plugin_build_pip_package.yml | 19 -
.../.github/workflows/tb_plugin_ci.yml | 57 --
plugins/tensorboard-plugins/.gitignore | 3 -
plugins/tensorboard-plugins/.gitmodules | 6 -
.../tensorboard-plugins/CODE_OF_CONDUCT.md | 77 --
plugins/tensorboard-plugins/CONTRIBUTING.md | 34 -
plugins/tensorboard-plugins/LICENSE | 33 -
plugins/tensorboard-plugins/README.md | 38 -
.../libkineto/CMakeLists.txt | 198 -----
.../tensorboard-plugins/libkineto/README.md | 65 --
.../libkineto/include/AbstractConfig.h | 113 ---
.../include/ActivityProfilerInterface.h | 91 --
.../include/ActivityTraceInterface.h | 21 -
.../libkineto/include/ActivityType.h | 34 -
.../libkineto/include/ClientInterface.h | 16 -
.../libkineto/include/Config.h | 433 ---------
.../libkineto/include/GenericTraceActivity.h | 125 ---
.../libkineto/include/IActivityProfiler.h | 104 ---
.../libkineto/include/ILoggerObserver.h | 50 --
.../libkineto/include/ITraceActivity.h | 53 --
.../libkineto/include/ThreadUtil.h | 22 -
.../libkineto/include/TraceSpan.h | 36 -
.../libkineto/include/libkineto.h | 138 ---
.../libkineto/include/time_since_epoch.h | 16 -
.../libkineto/libkineto_defs.bzl | 77 --
.../sample_programs/kineto_playground.cpp | 38 -
.../sample_programs/kineto_playground.cu | 60 --
.../sample_programs/kineto_playground.cuh | 18 -
.../libkineto/src/AbstractConfig.cpp | 188 ----
.../libkineto/src/ActivityBuffers.h | 29 -
.../libkineto/src/ActivityLoggerFactory.h | 60 --
.../src/ActivityProfilerController.cpp | 246 -----
.../src/ActivityProfilerController.h | 84 --
.../libkineto/src/ActivityProfilerProxy.cpp | 119 ---
.../libkineto/src/ActivityProfilerProxy.h | 73 --
.../libkineto/src/ActivityTrace.h | 45 -
.../libkineto/src/ActivityType.cpp | 58 --
.../libkineto/src/Config.cpp | 473 ----------
.../libkineto/src/ConfigLoader.cpp | 300 -------
.../libkineto/src/ConfigLoader.h | 147 ---
.../libkineto/src/CudaDeviceProperties.cpp | 130 ---
.../libkineto/src/CudaDeviceProperties.h | 31 -
.../libkineto/src/CuptiActivity.h | 114 ---
.../libkineto/src/CuptiActivity.tpp | 111 ---
.../libkineto/src/CuptiActivityApi.cpp | 343 -------
.../libkineto/src/CuptiActivityApi.h | 100 ---
.../libkineto/src/CuptiActivityBuffer.h | 51 --
.../libkineto/src/CuptiActivityPlatform.cpp | 31 -
.../libkineto/src/CuptiActivityPlatform.h | 12 -
.../libkineto/src/CuptiActivityProfiler.cpp | 841 ------------------
.../libkineto/src/CuptiActivityProfiler.h | 364 --------
.../libkineto/src/CuptiCallbackApi.cpp | 260 ------
.../libkineto/src/CuptiCallbackApi.h | 130 ---
.../libkineto/src/CuptiCallbackApiMock.h | 32 -
.../libkineto/src/CuptiEventApi.cpp | 112 ---
.../libkineto/src/CuptiEventApi.h | 49 -
.../libkineto/src/CuptiMetricApi.cpp | 107 ---
.../libkineto/src/CuptiMetricApi.h | 38 -
.../libkineto/src/CuptiNvPerfMetric.cpp | 504 -----------
.../libkineto/src/CuptiNvPerfMetric.h | 71 --
.../libkineto/src/CuptiRangeProfilerApi.cpp | 751 ----------------
.../libkineto/src/CuptiRangeProfilerApi.h | 220 -----
.../src/CuptiRangeProfilerConfig.cpp | 68 --
.../libkineto/src/CuptiRangeProfilerConfig.h | 86 --
.../libkineto/src/DaemonConfigLoader.h | 27 -
.../libkineto/src/Demangle.cpp | 49 -
.../libkineto/src/Demangle.h | 12 -
.../libkineto/src/EventProfiler.cpp | 635 -------------
.../libkineto/src/EventProfiler.h | 341 -------
.../libkineto/src/EventProfilerController.cpp | 423 ---------
.../libkineto/src/EventProfilerController.h | 63 --
.../libkineto/src/GenericTraceActivity.cpp | 10 -
.../libkineto/src/ILoggerObserver.cpp | 54 --
.../libkineto/src/Logger.cpp | 136 ---
.../libkineto/src/Logger.h | 244 -----
.../libkineto/src/LoggerCollector.h | 70 --
.../libkineto/src/RoctracerActivityApi.cpp | 569 ------------
.../libkineto/src/RoctracerActivityApi.h | 171 ----
.../libkineto/src/RoctracerActivityBuffer.h | 30 -
.../libkineto/src/SampleListener.h | 146 ---
.../libkineto/src/ScopeExit.h | 29 -
.../libkineto/src/ThreadUtil.cpp | 203 -----
.../libkineto/src/WeakSymbols.cpp | 12 -
.../libkineto/src/cupti_call.h | 33 -
.../libkineto/src/cupti_strings.cpp | 502 -----------
.../libkineto/src/cupti_strings.h | 14 -
.../libkineto/src/init.cpp | 139 ---
.../libkineto/src/libkineto_api.cpp | 41 -
.../libkineto/src/output_base.h | 104 ---
.../libkineto/src/output_csv.cpp | 88 --
.../libkineto/src/output_csv.h | 39 -
.../libkineto/src/output_json.cpp | 583 ------------
.../libkineto/src/output_json.h | 91 --
.../libkineto/src/output_membuf.h | 130 ---
.../libkineto/test/CMakeLists.txt | 3 -
.../libkineto/test/ConfigTest.cpp | 315 -------
.../test/CuptiActivityProfilerTest.cpp | 629 -------------
.../libkineto/test/CuptiCallbackApiTest.cpp | 239 -----
.../libkineto/test/CuptiProfilerApiTest.cu | 353 --------
.../test/CuptiRangeProfilerApiTest.cpp | 113 ---
.../test/CuptiRangeProfilerConfigTest.cpp | 67 --
.../test/CuptiRangeProfilerTestUtil.h | 96 --
.../libkineto/test/CuptiStringsTest.cpp | 29 -
.../libkineto/test/EventProfilerTest.cpp | 578 ------------
.../libkineto/test/LoggerObserverTest.cpp | 96 --
.../test/MockActivitySubProfiler.cpp | 49 -
.../libkineto/test/MockActivitySubProfiler.h | 72 --
.../libkineto/test/PidInfoTest.cpp | 27 -
profiler/README.md | 1 +
profiler/advisor/README.md | 5 +-
profiler/advisor/common/profiling/ge_info.py | 3 +-
profiler/advisor/common/profiling/msprof.py | 3 +-
.../advisor/common/profiling/op_summary.py | 4 +-
profiler/advisor/common/profiling/tasktime.py | 4 +-
.../config/profiling_data_version_config.yaml | 17 +-
.../dataset/profiling/profiling_dataset.py | 11 +-
.../dataset/profiling/profiling_parser.py | 27 +-
profiler/advisor/img/overall.png | Bin 64492 -> 49616 bytes
profiler/advisor/img/overall_0.png | Bin 0 -> 56377 bytes
profiler/advisor/utils/utils.py | 12 +-
profiler/affinity_cpu_bind/README.md | 40 -
profiler/affinity_cpu_bind/bind_core.py | 213 -----
profiler/cli/cluster_cli.py | 4 +-
profiler/cli/compare_cli.py | 2 +
profiler/cluster_analyse/README.md | 39 +-
.../analysis/analysis_facade.py | 12 -
.../cluster_analyse/analysis/base_analysis.py | 153 ----
.../analysis/cann_api_sum/__init__.py | 14 -
.../analysis/cann_api_sum/cann_api_sum.py | 108 ---
.../analysis/cann_api_sum/stats.ipynb | 86 --
.../analysis/cluster_display.py | 239 -----
.../analysis/compute_op_sum/__init__.py | 14 -
.../analysis/compute_op_sum/compute_op_sum.py | 103 ---
.../analysis/compute_op_sum/stats.ipynb | 164 ----
.../analysis/hccl_sum/__init__.py | 14 -
.../analysis/hccl_sum/hccl_sum.py | 133 ---
.../analysis/hccl_sum/stats.ipynb | 162 ----
.../analysis/mstx_sum/__init__.py | 14 -
.../analysis/mstx_sum/mstx_sum.py | 204 -----
.../analysis/mstx_sum/stats.ipynb | 180 ----
profiler/cluster_analyse/cluster_analysis.py | 75 +-
.../cluster_statistics_export/__init__.py | 14 -
.../cann_api_sum_export.py | 65 --
.../compute_op_sum_export.py | 49 -
.../hccl_sum_export.py | 39 -
.../mstx_mark_export.py | 57 --
.../mstx_step_export.py | 35 -
.../cluster_statistics_export/stats_export.py | 40 -
.../common_func/analysis_loader.py | 38 -
.../cluster_analyse/common_func/constant.py | 10 -
.../cluster_analyse/common_func/context.py | 85 --
.../cluster_analyse/common_func/db_manager.py | 7 -
.../common_func/sql_extention_func.py | 73 --
profiler/cluster_analyse/common_func/utils.py | 73 --
profiler/compare_tools/README.md | 82 +-
.../comparator/api_compare_comparator.py | 32 +
.../comparator/kernel_compare_comparator.py | 35 +
.../compare_bean/api_compare_bean.py | 47 +
.../compare_bean/kernel_compare_bean.py | 75 ++
.../origin_data_bean/kernel_details_bean.py | 6 +
.../data_prepare/operator_data_prepare.py | 17 +
.../generator/detail_performance_generator.py | 22 +-
.../profiling_parser/base_profiling_parser.py | 19 +-
.../profiling_parser/gpu_profiling_parser.py | 5 +
.../profiling_parser/npu_profiling_parser.py | 24 +
.../compare_backend/utils/args_manager.py | 13 +-
.../compare_backend/utils/compare_args.py | 4 +
.../compare_backend/utils/constant.py | 7 +-
.../compare_backend/utils/excel_config.py | 48 +-
.../compare_backend/utils/torch_op_node.py | 8 +
.../compare_backend/utils/tree_builder.py | 3 +-
.../view/work_sheet_creator.py | 12 +-
profiler/compare_tools/img/OverallMetrics.png | Bin 0 -> 66941 bytes
profiler/compare_tools/performance_compare.py | 2 +
profiler/merge_profiling_timeline/README.md | 115 ---
profiler/merge_profiling_timeline/main.py | 233 -----
...\345\257\274\346\210\252\345\233\2761.png" | Bin 53047 -> 0 bytes
...\345\257\274\346\210\252\345\233\2762.png" | Bin 64432 -> 0 bytes
profiler/module_visualization/__init__.py | 0
.../module_visualization/graph/__init__.py | 0
.../module_visualization/graph/prof_node.py | 90 --
.../graph_build/__init__.py | 0
.../graph_build/fwd_module_node.py | 29 -
.../graph_build/prof_graph_builder.py | 115 ---
.../module_visualization/prof_graph_export.py | 39 -
.../prof_parse/__init__.py | 0
.../prof_parse/prof_data_pre_process.py | 102 ---
.../test_base_profiling_parser.py | 5 +
262 files changed, 2720 insertions(+), 19791 deletions(-)
create mode 100644 debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry copy.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py
create mode 100644 debug/accuracy_tools/msprobe/pytorch/function_factory.py
delete mode 100644 plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml
delete mode 100644 plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml
delete mode 100644 plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml
delete mode 100644 plugins/tensorboard-plugins/.gitignore
delete mode 100644 plugins/tensorboard-plugins/.gitmodules
delete mode 100644 plugins/tensorboard-plugins/CODE_OF_CONDUCT.md
delete mode 100644 plugins/tensorboard-plugins/CONTRIBUTING.md
delete mode 100644 plugins/tensorboard-plugins/LICENSE
delete mode 100644 plugins/tensorboard-plugins/README.md
delete mode 100644 plugins/tensorboard-plugins/libkineto/CMakeLists.txt
delete mode 100644 plugins/tensorboard-plugins/libkineto/README.md
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/ActivityProfilerInterface.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/ActivityTraceInterface.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/ActivityType.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/ClientInterface.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/Config.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/GenericTraceActivity.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/IActivityProfiler.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/ILoggerObserver.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/ITraceActivity.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/ThreadUtil.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/TraceSpan.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/libkineto.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/include/time_since_epoch.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/libkineto_defs.bzl
delete mode 100644 plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cu
delete mode 100644 plugins/tensorboard-plugins/libkineto/sample_programs/kineto_playground.cuh
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/AbstractConfig.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ActivityBuffers.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ActivityLoggerFactory.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ActivityProfilerController.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ActivityProfilerProxy.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ActivityTrace.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ActivityType.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/Config.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ConfigLoader.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ConfigLoader.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CudaDeviceProperties.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivity.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivity.tpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivityApi.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivityBuffer.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivityPlatform.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiActivityProfiler.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApi.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiCallbackApiMock.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiEventApi.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiMetricApi.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiNvPerfMetric.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerApi.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/CuptiRangeProfilerConfig.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/DaemonConfigLoader.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/Demangle.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/Demangle.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/EventProfiler.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/EventProfiler.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/EventProfilerController.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/EventProfilerController.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/GenericTraceActivity.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ILoggerObserver.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/Logger.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/Logger.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/LoggerCollector.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/RoctracerActivityApi.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/RoctracerActivityBuffer.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/SampleListener.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ScopeExit.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/ThreadUtil.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/WeakSymbols.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/cupti_call.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/cupti_strings.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/cupti_strings.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/init.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/libkineto_api.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/output_base.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/output_csv.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/output_csv.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/output_json.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/output_json.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/src/output_membuf.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/CMakeLists.txt
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/ConfigTest.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/CuptiActivityProfilerTest.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/CuptiCallbackApiTest.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/CuptiProfilerApiTest.cu
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerApiTest.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerConfigTest.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/CuptiRangeProfilerTestUtil.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/CuptiStringsTest.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/EventProfilerTest.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/LoggerObserverTest.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.cpp
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/MockActivitySubProfiler.h
delete mode 100644 plugins/tensorboard-plugins/libkineto/test/PidInfoTest.cpp
create mode 100644 profiler/advisor/img/overall_0.png
delete mode 100644 profiler/affinity_cpu_bind/README.md
delete mode 100644 profiler/affinity_cpu_bind/bind_core.py
delete mode 100644 profiler/cluster_analyse/analysis/cann_api_sum/__init__.py
delete mode 100644 profiler/cluster_analyse/analysis/cann_api_sum/cann_api_sum.py
delete mode 100644 profiler/cluster_analyse/analysis/cann_api_sum/stats.ipynb
delete mode 100644 profiler/cluster_analyse/analysis/cluster_display.py
delete mode 100644 profiler/cluster_analyse/analysis/compute_op_sum/__init__.py
delete mode 100644 profiler/cluster_analyse/analysis/compute_op_sum/compute_op_sum.py
delete mode 100644 profiler/cluster_analyse/analysis/compute_op_sum/stats.ipynb
delete mode 100644 profiler/cluster_analyse/analysis/hccl_sum/__init__.py
delete mode 100644 profiler/cluster_analyse/analysis/hccl_sum/hccl_sum.py
delete mode 100644 profiler/cluster_analyse/analysis/hccl_sum/stats.ipynb
delete mode 100644 profiler/cluster_analyse/analysis/mstx_sum/__init__.py
delete mode 100644 profiler/cluster_analyse/analysis/mstx_sum/mstx_sum.py
delete mode 100644 profiler/cluster_analyse/analysis/mstx_sum/stats.ipynb
delete mode 100644 profiler/cluster_analyse/cluster_statistics_export/__init__.py
delete mode 100644 profiler/cluster_analyse/cluster_statistics_export/cann_api_sum_export.py
delete mode 100644 profiler/cluster_analyse/cluster_statistics_export/compute_op_sum_export.py
delete mode 100644 profiler/cluster_analyse/cluster_statistics_export/hccl_sum_export.py
delete mode 100644 profiler/cluster_analyse/cluster_statistics_export/mstx_mark_export.py
delete mode 100644 profiler/cluster_analyse/cluster_statistics_export/mstx_step_export.py
delete mode 100644 profiler/cluster_analyse/cluster_statistics_export/stats_export.py
delete mode 100644 profiler/cluster_analyse/common_func/analysis_loader.py
delete mode 100644 profiler/cluster_analyse/common_func/context.py
delete mode 100644 profiler/cluster_analyse/common_func/sql_extention_func.py
delete mode 100644 profiler/cluster_analyse/common_func/utils.py
create mode 100644 profiler/compare_tools/compare_backend/comparator/api_compare_comparator.py
create mode 100644 profiler/compare_tools/compare_backend/comparator/kernel_compare_comparator.py
create mode 100644 profiler/compare_tools/compare_backend/compare_bean/api_compare_bean.py
create mode 100644 profiler/compare_tools/compare_backend/compare_bean/kernel_compare_bean.py
create mode 100644 profiler/compare_tools/img/OverallMetrics.png
delete mode 100644 profiler/merge_profiling_timeline/README.md
delete mode 100644 profiler/merge_profiling_timeline/main.py
delete mode 100644 "profiler/merge_profiling_timeline/perfetto\344\275\277\347\224\250\346\214\207\345\257\274\346\210\252\345\233\2761.png"
delete mode 100644 "profiler/merge_profiling_timeline/perfetto\344\275\277\347\224\250\346\214\207\345\257\274\346\210\252\345\233\2762.png"
delete mode 100644 profiler/module_visualization/__init__.py
delete mode 100644 profiler/module_visualization/graph/__init__.py
delete mode 100644 profiler/module_visualization/graph/prof_node.py
delete mode 100644 profiler/module_visualization/graph_build/__init__.py
delete mode 100644 profiler/module_visualization/graph_build/fwd_module_node.py
delete mode 100644 profiler/module_visualization/graph_build/prof_graph_builder.py
delete mode 100644 profiler/module_visualization/prof_graph_export.py
delete mode 100644 profiler/module_visualization/prof_parse/__init__.py
delete mode 100644 profiler/module_visualization/prof_parse/prof_data_pre_process.py
diff --git a/.gitignore b/.gitignore
index c70c40e0f52..01a2222429c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -142,4 +142,11 @@ cython_debug/
att_advisor*.html
*.xlsx
operator_tuning_file*.cfg
-.ipynb_checkpoints/
\ No newline at end of file
+.ipynb_checkpoints/
+.idea/vcs.xml
+.idea/inspectionProfiles/profiles_settings.xml
+.idea/misc.xml
+.idea/modules.xml
+.idea/mstt_primitive.iml
+.idea/.gitignore
+.gitignore
diff --git a/debug/accuracy_tools/api_accuracy_checker/bench_functions/npu_fusion_attention.py b/debug/accuracy_tools/api_accuracy_checker/bench_functions/npu_fusion_attention.py
index 4c230c17c04..d5a91ce3b5f 100644
--- a/debug/accuracy_tools/api_accuracy_checker/bench_functions/npu_fusion_attention.py
+++ b/debug/accuracy_tools/api_accuracy_checker/bench_functions/npu_fusion_attention.py
@@ -8,7 +8,6 @@ from api_accuracy_checker.common.utils import logger
gtype = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
softmax_build_mode = "QKV" # "MAX_SUM"
-
"""
# 前向函数声明对比
标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
@@ -45,6 +44,9 @@ def softmax_grad(dp, softmax_res):
def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
+ if num_kv_heads == 0 or num_kv_heads < num_heads:
+ raise ValueError(f"num_kv_heads must be non-zero and less than num_heads.")
+
factor = num_heads // num_kv_heads
kv_shape = kv_tensor.shape
B = kv_shape[0]
@@ -102,28 +104,34 @@ def parse_bsnd_args(query, key, head_num, input_layout):
if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")
- if input_layout == "BSH":
- B, S1, H1 = query.shape
- _, S2, H2 = key.shape
- D = H1 // N1
- N2 = H2 // D
- elif input_layout == "SBH":
- S1, B, H1 = query.shape
- S2, _, H2 = key.shape
- D = H1 // N1
- N2 = H2 // D
- elif input_layout == "BSND":
- B, S1, N1, D = query.shape
- _, S2, N2, _ = key.shape
- H1 = N1 * D
- H2 = N2 * D
- elif input_layout == "BNSD":
- B, N1, S1, D = query.shape
- _, N2, S2, _ = key.shape
- H1 = N1 * D
- H2 = N2 * D
- elif input_layout == "TND":
+ if input_layout == "TND":
raise ValueError(f"input_layout {input_layout} does not supported for now.")
+ try:
+ if input_layout == "BSH":
+ B, S1, H1 = query.shape
+ _, S2, H2 = key.shape
+ D = H1 // N1
+ N2 = H2 // D
+ elif input_layout == "SBH":
+ S1, B, H1 = query.shape
+ S2, _, H2 = key.shape
+ D = H1 // N1
+ N2 = H2 // D
+ elif input_layout == "BSND":
+ B, S1, N1, D = query.shape
+ _, S2, N2, _ = key.shape
+ H1 = N1 * D
+ H2 = N2 * D
+ elif input_layout == "BNSD":
+ B, N1, S1, D = query.shape
+ _, N2, S2, _ = key.shape
+ H1 = N1 * D
+ H2 = N2 * D
+ except Exception as e:
+ raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e
+
+ if D == 0:
+ raise ValueError(f"Value D must be non-zero.")
DTYPE = query.dtype
return B, S1, S2, N1, N2, D, H1, H2, DTYPE
@@ -251,6 +259,8 @@ def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softma
"""
print(f"Using softmax_max and softmax_sum to rebuild original softmax")
qk = calculate_qk(q, k, atten_mask, pse, scale)
+ if softmax_max.shape[-1] == 0:
+ raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
repeat_dim = qk.shape[-1] // softmax_max.shape[-1]
softmax_res = torch.exp(qk.sub(softmax_max.repeat(1, 1, 1, repeat_dim))).div(
softmax_sum.repeat(1, 1, 1, repeat_dim))
@@ -394,6 +404,8 @@ def npu_fusion_attention_grad(*args, **kwargs):
# N不等长适配by cdy
if not (N1 == N2):
+ if N2 == 0:
+ raise ValueError("dims_kwargs.N2 must be non-zero.")
G = int(N1 / N2)
dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py
index 76d117afb49..83b73e90f97 100644
--- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py
+++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py
@@ -634,7 +634,11 @@ def initialize_save_path(save_path, dir_name):
def write_pt(file_path, tensor):
if os.path.exists(file_path):
raise ValueError(f"File {file_path} already exists")
- torch.save(tensor, file_path)
+ try:
+ torch.save(tensor, file_path)
+ except Exception as e:
+ error_message = "An unexpected error occurred: %s when saving tensor to %s" % (str(e), file_path)
+ print_error_log(error_message)
full_path = os.path.realpath(file_path)
file_check_util.change_mode(full_path, FileCheckConst.DATA_FILE_AUTHORITY)
return full_path
diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py
index 67dc5ad2532..57811648391 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py
@@ -20,9 +20,10 @@ import math
import torch
import numpy
-from api_accuracy_checker.common.utils import Const, check_file_or_directory_path, check_object_type, print_warn_log, \
- print_error_log, get_full_data_path, CompareException
+from api_accuracy_checker.common.utils import Const, check_object_type, print_warn_log, print_error_log, \
+ get_full_data_path, CompareException
from api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker
TORCH_TYPE = ["torch.device", "torch.dtype"]
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
@@ -83,7 +84,8 @@ def gen_real_tensor(data_path, convert_type):
convert_type: convert ori_type to dist_type flag.
"""
data_path = os.path.realpath(data_path)
- check_file_or_directory_path(data_path)
+ data_path_checker = FileChecker(data_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
+ data_path = data_path_checker.common_check()
if not data_path.endswith('.pt') and not data_path.endswith('.npy'):
error_info = f"The file: {data_path} is not a pt or numpy file."
raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py
index df6c99a567c..0ab8073937f 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -88,14 +88,9 @@ def run_parallel_ut(config):
def update_progress_bar(progress_bar, result_csv_path):
while any(process.poll() is None for process in processes):
- try:
- with open(result_csv_path, 'r') as result_file:
- completed_items = len(result_file.readlines()) - 1
- progress_bar.update(completed_items - progress_bar.n)
- except FileNotFoundError:
- print_warn_log(f"Result CSV file not found: {result_csv_path}.")
- except Exception as e:
- print_error_log(f"An unexpected error occurred while reading result CSV: {e}")
+ with FileOpen(result_csv_path, 'r') as result_file:
+ completed_items = len(result_file.readlines()) - 1
+ progress_bar.update(completed_items - progress_bar.n)
time.sleep(1)
for fwd, bwd in zip(config.forward_files, config.backward_files):
diff --git a/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/attl.py b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/attl.py
index 0b91d2bbc82..5fb63779fbb 100644
--- a/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/attl.py
+++ b/debug/accuracy_tools/api_accuracy_checker/tensor_transport_layer/attl.py
@@ -13,6 +13,7 @@ import torch
from api_accuracy_checker.tensor_transport_layer.client import TCPClient
from api_accuracy_checker.tensor_transport_layer.server import TCPServer
from api_accuracy_checker.common.utils import logger
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker
from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import remove_path
@@ -138,8 +139,10 @@ class ATTL:
file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt")
else:
file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}")
-
- torch.save(buffer, file_path)
+ try:
+ torch.save(buffer, file_path)
+ except Exception as e:
+ self.logger.error("there is something error. please check it. %s", e)
def download(self):
for file_type in ("start*", "*.pt", "end*"):
@@ -150,6 +153,8 @@ class ATTL:
if cur_file is None:
return None
else:
+ cur_file_checker = FileChecker(cur_file, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
+ cur_file = cur_file_checker.common_check()
buffer = torch.load(cur_file)
remove_path(cur_file)
return buffer
diff --git a/debug/accuracy_tools/grad_tool/README.md b/debug/accuracy_tools/grad_tool/README.md
index a7929ca8187..1d35f03e479 100644
--- a/debug/accuracy_tools/grad_tool/README.md
+++ b/debug/accuracy_tools/grad_tool/README.md
@@ -28,7 +28,7 @@
### 梯度数据导出
-1. 创建配置文件config.yaml,PyTorch框架样例代码如下:
+1. 创建配置文件config.yaml,样例如下:
```python
level: L1
@@ -38,40 +38,30 @@
bounds:
output_path: your_output_dir
```
- > 在MindSpore框架下,当前不支持rank和step配置,默认所有rank和所有step都进行采集,
- > MindSpore中step指的是优化器被调用的次数(并非模型跑的step,某些step,例如loss为nan时,不会调用优化器)
+ > step指的是优化器被调用的次数(并非模型跑的step,某些step,例如loss为nan时,不会调用优化器)
**参数说明**
- | 参数 | 说明 | 是否必选 |
- |--------------------------------|----------------------------------------------------|----------|
- | level | Level级别,PyTorch可取值:L0、L1、L2,MindSpore可取值:L0, L1, L2, L3。决定导出数据的详细程度,级别越大导出数据越详细。数据类型:str。 | PyTorch是(MindSpore否,默认为L0) |
- | param_list | 填写需要监控的权重名称。不指定或列表为空就表示监控所有权重。数据类型:List[str]。 | 否 |
- | rank | 在多卡场景下,填写需要导出梯度数据的卡的Rank ID,不指定或列表为空就表示导出所有Rank的数据。单卡场景无需关注该参数。数据类型:List[int]。(MindSpore当前不支持指定rank) | 否 |
- | step | 指定需要导出数据的step。对于PyTorch不指定或列表为空就表示导出所有step的数据,对于MindSpore不指定表示导出所有step,指定时要求传入range列表,例如[1, 2],否则无效。数据类型:List[int]。(MindSpore当前不支持指定step) | 否 |
- | bounds | 用来划分区间以统计值分布。需要保证由数据小到大排列。不传则使用默认值[-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10](mindspore为[-0.1, 0., 1.0]),数据类型:List。 | 否 |
- | output_path | 输出目录。如果不存在就会创建一个新目录。数据类型:str。 | PyTorch是(MindSpore否,默认为./grad_stat |
+ | 参数 | 说明 | 输入类型 | 是否必选 |
+ |--------------------------------|-----------------------------------|-----------------|----------|
+ | level | 输出级别。决定导出数据的详细程度,级别越大导出数据越详细。可取值:L0, L1, L2|str | 是 |
+ | param_list | 权重名称列表,表示需要监控的权重。不指定或列表为空就表示监控所有权重。 | List[str] | 否 |
+ | rank | rank id列表,在多卡场景下,表示需要导出梯度数据的进程的rank id。不指定或列表为空就表示导出所有rank的数据。单卡场景无需关注该参数。 (MindSpore静态图模式下,当前暂不支持指定rank功能) | List[int] | 否 |
+ | step | step列表,表示需要导出数据的step列表。不指定或列表为空就表示导出所有step的数据。(MindSpore静态图模式下,当前暂不支持指定step功能) | List[int] | 否 |
+ | bounds | 区间列表,用来划分区间以统计数值的分布。需要保证由数据小到大排列。不指定则使用默认值[-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10] | List[float] | 否 |
+ | output_path | 输出目录。如果不存在就会创建一个新目录。 | str | 是 |
**不同级别的level的导出数据**
-- PyTorch/MindSpore动态图不同level数据
| 级别 | 特征数据表头 | 是否有方向数据 |
| ---- | ------------------------------------------------------------ | -------------- |
| L0 | ("param_name", "MD5", "max", "min", "norm", "shape") | 否 |
| L1 | ("param_name", "max", "min", "norm", "shape") | 是 |
| L2 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 是 |
-
-- MindSpore静态图不同level数据
-
- | 级别 | 特征数据表头 | 是否有方向数据 |
- | ---- | ------------------------------------------------------------ | -------------- |
- | L0 | ("param_name", "max", "min", "norm", "shape") | 否 |
- | L1 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 否 |
- | L2 | ("param_name", "max", "min", "norm", "shape") | 是 |
- | L3 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 是 |
intervals就是根据值分布bounds划分出的区间。
+ MindSpore静态图模式下,L0级别中暂不支持"MD5"
**方向数据解释**
@@ -98,7 +88,7 @@ gm = GradientMonitor("config_path", framework="MindSpore")
gm.monitor(optimizer)
```
-3. 结束监控(MindSpore需要)
+3. 结束监控(MindSpore静态图模式下需要)
在训练结束之后,调用stop接口
diff --git a/debug/accuracy_tools/grad_tool/common/base_comparator.py b/debug/accuracy_tools/grad_tool/common/base_comparator.py
index d3254ae71f9..03f74a21e47 100644
--- a/debug/accuracy_tools/grad_tool/common/base_comparator.py
+++ b/debug/accuracy_tools/grad_tool/common/base_comparator.py
@@ -7,7 +7,10 @@ import pandas as pd
import matplotlib.pyplot as plt
from grad_tool.common.constant import GradConst
-from grad_tool.common.utils import write_csv, check_file_or_directory_path, print_info_log, create_directory
+from grad_tool.common.utils import write_csv, check_file_or_directory_path, print_info_log, create_directory, print_error_log
+
+from ptdbg_ascend.src.python.ptdbg_ascend.common import file_check_util
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, check_path_pattern_valid, check_path_length
class BaseComparator(ABC):
@@ -85,8 +88,19 @@ class BaseComparator(ABC):
picture_dir = os.path.join(output_dir, "similarities_picture")
if not os.path.isdir(picture_dir):
create_directory(picture_dir)
- plt.savefig(os.path.join(picture_dir, f"{key}_similarities.png"))
- plt.close()
+ file_path= os.path.join(picture_dir, f"{key}_similarities.png")
+ if os.path.exists(file_path):
+ raise ValueError(f"File {file_path} already exists")
+ check_path_length(file_path)
+ check_path_pattern_valid(file_path)
+ try:
+ plt.savefig(file_path)
+ plt.close()
+ except Exception as e:
+ error_message = "An unexpected error occurred: %s when savfig to %s" % (str(e), file_path)
+ print_error_log(error_message)
+ full_path = os.path.realpath(file_path)
+ file_check_util.change_mode(full_path, FileCheckConst.DATA_FILE_AUTHORITY)
head_tuple = tuple(['step'] + [str(step) for step in steps])
write_csv(os.path.join(output_dir, "similarities.csv"), [[key] + value], head_tuple)
diff --git a/debug/accuracy_tools/grad_tool/common/constant.py b/debug/accuracy_tools/grad_tool/common/constant.py
index d569d47c16d..38d33e98864 100644
--- a/debug/accuracy_tools/grad_tool/common/constant.py
+++ b/debug/accuracy_tools/grad_tool/common/constant.py
@@ -23,8 +23,7 @@ class GradConst:
LEVEL0 = "L0"
LEVEL1 = "L1"
LEVEL2 = "L2"
- LEVEL3 = "L3"
- SUPPORTED_LEVEL = {"L0", "L1", "L2", "L3"}
+ SUPPORTED_LEVEL = {"L0", "L1", "L2"}
# numpy coding
STEP_IDX = 0
@@ -40,6 +39,7 @@ class GradConst:
DIRECTORY_LENGTH = 4096
FILE_NAME_LENGTH = 255
FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
+ PARAM_VALID_PATTERN = r"^[a-zA-Z0-9.]+$"
DIR = "dir"
FILE = "file"
diff --git a/debug/accuracy_tools/grad_tool/common/utils.py b/debug/accuracy_tools/grad_tool/common/utils.py
index cdce3fda7e3..fceda8ce0f2 100644
--- a/debug/accuracy_tools/grad_tool/common/utils.py
+++ b/debug/accuracy_tools/grad_tool/common/utils.py
@@ -7,6 +7,7 @@ import yaml
import pandas as pd
from grad_tool.common.constant import GradConst
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen
def _print_log(level, msg, end='\n'):
@@ -114,7 +115,7 @@ class ListCache(list):
def get_config(filepath):
- with open(filepath, 'r') as file:
+ with FileOpen(filepath, 'r') as file:
config = yaml.safe_load(file)
return config
@@ -220,3 +221,11 @@ def change_mode(path, mode):
except PermissionError as ex:
print_error_log(f'Failed to change {path} authority. {str(ex)}')
raise ex
+
+def check_param(param_name):
+ if not re.match(GradConst.PARAM_VALID_PATTERN, param_name):
+ raise RuntimeError("The parameter name contains special characters.")
+
+def check_str(string, variable_name):
+ if not isinstance(string, str):
+ raise ValueError(f'The variable: "{variable_name}" is not a string.')
\ No newline at end of file
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/global_context.py b/debug/accuracy_tools/grad_tool/grad_ms/global_context.py
index d44bea52c78..424f16aedd3 100644
--- a/debug/accuracy_tools/grad_tool/grad_ms/global_context.py
+++ b/debug/accuracy_tools/grad_tool/grad_ms/global_context.py
@@ -4,7 +4,7 @@ from typing import Dict, List, Union
from grad_tool.common.utils import print_warn_log
from grad_tool.common.constant import GradConst
-from grad_tool.common.utils import path_valid_check, create_directory
+from grad_tool.common.utils import path_valid_check, create_directory, check_str
class GlobalContext:
@@ -12,13 +12,13 @@ class GlobalContext:
_instance = None
_instance_lock = threading.Lock()
_setting = {
- GradConst.LEVEL: GradConst.LEVEL0,
+ GradConst.LEVEL: None,
GradConst.PARAM_LIST: None,
GradConst.STEP: None,
GradConst.RANK: None,
GradConst.CURRENT_STEP: 0,
- GradConst.BOUNDS: [-1., 0., 1.],
- GradConst.OUTPUT_PATH: "./grad_stat"
+ GradConst.BOUNDS: [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10],
+ GradConst.OUTPUT_PATH: None
}
def __new__(cls, *args, **kwargs):
@@ -29,23 +29,25 @@ class GlobalContext:
return cls._instance
def init_context(self, config_dict: Dict):
- if config_dict.get(GradConst.LEVEL, None) in GradConst.SUPPORTED_LEVEL:
+ level = config_dict.get(GradConst.LEVEL)
+ check_str(level, variable_name = "level in yaml")
+ if level in GradConst.SUPPORTED_LEVEL:
self._setting[GradConst.LEVEL] = config_dict.get(GradConst.LEVEL)
else:
- print_warn_log("Invalid level set in config yaml file, use L0 instead.")
+ raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2")
+
self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
self._set_input_list(config_dict, GradConst.BOUNDS, float)
self._set_input_list(config_dict, GradConst.STEP, int)
self._set_input_list(config_dict, GradConst.RANK, int)
+
output_path = config_dict.get(GradConst.OUTPUT_PATH)
- if output_path:
- try:
- path_valid_check(output_path)
- except RuntimeError as err:
- print_warn_log(f"Invalid output_path, use default output_path. The error message is {err}.")
- output_path = None
- if output_path:
- self._setting[GradConst.OUTPUT_PATH] = output_path
+ check_str(output_path, variable_name = "output_path in yaml")
+ try:
+ path_valid_check(output_path)
+ except RuntimeError as err:
+ raise ValueError(f"Invalid output_path: {output_path}. The error message is {err}.") from err
+ self._setting[GradConst.OUTPUT_PATH] = output_path
if not os.path.isdir(self._setting.get(GradConst.OUTPUT_PATH)):
create_directory(self._setting.get(GradConst.OUTPUT_PATH))
else:
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py
index 75280b31944..c843df3884e 100644
--- a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py
+++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py
@@ -16,6 +16,7 @@ from grad_tool.common.utils import ListCache, print_warn_log
from grad_tool.common.utils import create_directory, check_file_or_directory_path, write_csv
from grad_tool.grad_ms.global_context import grad_context
from grad_tool.grad_ms.global_context import GlobalContext
+from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker
def get_rank_id():
@@ -31,11 +32,10 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor,
'''
Dump gradient statistic data.
level0: [step, max, min, norm, shape_dim, shape]
- level1: [step, max, min, norm, shape_dim, shape, dist_dim, dist]
- level2: [step, max, min, norm, shape_dim, shape] + grad_bool_data
- level3: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
+ level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data
+ level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
'''
- dump_path = dump_dir + g_name
+ dump_path = os.path.join(dump_dir, g_name)
dump_dir_path = dump_path + "_dir"
save_op = ms.ops.TensorDump()
@@ -51,7 +51,7 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor,
level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
level_stat = level0_stat
- if level == "L1" or level == "L3":
+ if level == GradConst.LEVEL2:
zero_grad = (grad == 0).sum()
dist_dim = ms.Tensor([len(bounds) + 2]).float()
bucket_result = ms.ops.bucketize(grad.float(), bounds)
@@ -60,11 +60,11 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor,
dist_stat.append(zero_grad)
dist_stat.append(ms.Tensor(1, dtype=ms.int64)) # make sure dist_stat is not empty
dist_stat = ms.ops.stack(dist_stat, axis=0).float()
- level1_stat = ms.ops.concat((level0_stat, dist_dim, dist_stat), axis=0)
- level_stat = level1_stat
+ level2_stat = ms.ops.concat((level0_stat, dist_dim, dist_stat), axis=0)
+ level_stat = level2_stat
save_op(dump_path, level_stat)
- if level == "L2" or level == "L3":
+ if level == GradConst.LEVEL1 or level == GradConst.LEVEL2:
grad_direction = grad > 0
save_op(dump_dir_path, grad_direction)
@@ -155,7 +155,7 @@ class CSVGenerator(Process):
level = grad_context.get_context(GradConst.LEVEL)
try:
shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
- if level in [GradConst.LEVEL1, GradConst.LEVEL3]:
+ if level == GradConst.LEVEL2:
dist_dim = int(stat_data[shape_dim + GradConst.SHAPE_DIM_IDX + 1])
length = shape_dim + dist_dim + 7
else:
@@ -170,6 +170,8 @@ class CSVGenerator(Process):
stat_data = None
max_try = 10
while max_try:
+ file_path_checker = FileChecker(file_path, FileCheckConst.DIR,FileCheckConst.READ_ABLE)
+ file_path = file_path_checker.common_check()
try:
stat_data = np.load(file_path)
return stat_data
@@ -178,7 +180,7 @@ class CSVGenerator(Process):
max_try -= 1
time.sleep(0.1)
return stat_data
-
+
def gen_csv_line(self, file_path: str, stat_data) -> None:
shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
file_name = os.path.basename(file_path)
@@ -187,7 +189,7 @@ class CSVGenerator(Process):
if not param_name:
raise RuntimeError("Invalid gradient statistic file name.")
csv_line = [param_name]
- if self.level == GradConst.LEVEL1 or self.level == GradConst.LEVEL3:
+ if self.level == GradConst.LEVEL2:
csv_line.extend(self.get_dist_data(shape_dim, stat_data))
csv_line.extend(self.get_extrem_data(shape_dim, stat_data))
self.cache_list.append(csv_line)
@@ -208,7 +210,7 @@ class CSVGenerator(Process):
def create_csv_file(self):
headers = ["Param_name"]
- if self.level == GradConst.LEVEL1 or self.level == GradConst.LEVEL3:
+ if self.level == GradConst.LEVEL2:
headers.extend(self.get_dist_header())
headers.extend(self.get_extrem_headers())
output_path = f"{self.save_dir}/grad_summary_{self.current_step}.csv"
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_comparator.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_comparator.py
index 2bfeda4387e..3b930d4e283 100644
--- a/debug/accuracy_tools/grad_tool/grad_ms/grad_comparator.py
+++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_comparator.py
@@ -9,10 +9,19 @@ class MsGradComparator(BaseComparator):
@classmethod
def _load_grad_files(cls, grad_file1: str, grad_file2: str):
+ if not os.path.exists(grad_file1):
+ raise ValueError(f"file {grad_file1} not exists, please check the file path.")
+ if not os.path.exists(grad_file2):
+ raise ValueError(f"file {grad_file2} not exists, please check the file path.")
+
grad1_suffix = grad_file1.split(".")[-1]
grad2_suffix = grad_file2.split(".")[-1]
- grad1 = torch.load(grad_file1).numpy() if grad1_suffix == "pt" else np.load(grad_file1)
- grad2 = torch.load(grad_file2).numpy() if grad2_suffix == "pt" else np.load(grad_file2)
+
+ try:
+ grad1 = torch.load(grad_file1).numpy() if grad1_suffix == "pt" else np.load(grad_file1)
+ grad2 = torch.load(grad_file2).numpy() if grad2_suffix == "pt" else np.load(grad_file2)
+ except Exception as e:
+ raise RuntimeError(f"An unexpected error occurred: {e} when loading grad_file.") from e
if grad1.shape != grad2.shape:
raise RuntimeError(f"numpy shape is not equal: {grad_file1}, {grad_file2}")
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/utils.py b/debug/accuracy_tools/grad_tool/grad_ms/utils.py
index 23703f28208..c8ee1fd1d45 100644
--- a/debug/accuracy_tools/grad_tool/grad_ms/utils.py
+++ b/debug/accuracy_tools/grad_tool/grad_ms/utils.py
@@ -3,7 +3,8 @@ import os
import numpy as np
import mindspore
from grad_tool.common.constant import GradConst
-from grad_tool.common.utils import print_warn_log, create_directory, change_mode, check_file_or_directory_path
+from grad_tool.common.utils import (print_warn_log, create_directory, change_mode, check_file_or_directory_path,
+ path_valid_check, check_param)
level_adp = {
"L0": {
@@ -20,23 +21,27 @@ level_adp = {
},
}
+
def save_grad_direction(param_name, grad, save_path):
if not os.path.exists(save_path):
create_directory(save_path)
+ check_file_or_directory_path(save_path, file_type=GradConst.DIR)
+ check_param(param_name)
save_filepath = os.path.join(save_path, f"{param_name}.npy")
- check_file_or_directory_path(save_filepath)
+ path_valid_check(save_filepath)
if grad.dtype == mindspore.bfloat16:
grad = grad.to(mindspore.float32)
grad_direction_tensor = grad > 0
grad_direction_ndarray = grad_direction_tensor.numpy()
- np.save(save_filepath, grad_direction_ndarray)
+ try:
+ np.save(save_filepath, grad_direction_ndarray)
+ except Exception as e:
+ raise RuntimeError(f"An unexpected error occurred: {e} when saving numpy to {save_filepath}") from e
change_mode(save_filepath, 0o640)
+
def get_adapted_level(level: str):
- if level == GradConst.LEVEL3:
- print_warn_log(f"In mindpsore pynative mode, only 'L0', 'L1' and 'L2' are supported, use L0 instead")
- level = GradConst.LEVEL0
level_adapted = level_adp.get(level)
- return level_adapted
\ No newline at end of file
+ return level_adapted
diff --git a/debug/accuracy_tools/grad_tool/grad_pt/grad_comparator.py b/debug/accuracy_tools/grad_tool/grad_pt/grad_comparator.py
index d1229b93de7..38f0e32153e 100644
--- a/debug/accuracy_tools/grad_tool/grad_pt/grad_comparator.py
+++ b/debug/accuracy_tools/grad_tool/grad_pt/grad_comparator.py
@@ -1,3 +1,5 @@
+import os
+
import torch
from grad_tool.common.base_comparator import BaseComparator
@@ -7,8 +9,17 @@ class PtGradComparator(BaseComparator):
@classmethod
def _load_grad_files(cls, grad_file1: str, grad_file2: str):
- tensor1 = torch.load(grad_file1, map_location=torch.device("cpu"))
- tensor2 = torch.load(grad_file2, map_location=torch.device("cpu"))
+ if not os.path.exists(grad_file1):
+ raise ValueError(f"file {grad_file1} not exists, please check the file path.")
+ if not os.path.exists(grad_file2):
+ raise ValueError(f"file {grad_file2} not exists, please check the file path.")
+
+ try:
+ tensor1 = torch.load(grad_file1, map_location=torch.device("cpu"))
+ tensor2 = torch.load(grad_file2, map_location=torch.device("cpu"))
+ except Exception as e:
+ raise RuntimeError(f"An unexpected error occurred: {e} when loading tensor.") from e
+
if tensor1.shape != tensor2.shape:
raise RuntimeError(f"tensor shape is not equal: {grad_file1}, {grad_file2}")
if tensor1.dtype != torch.bool:
diff --git a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py
index f3079e622c2..2e1abde0d1a 100644
--- a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py
+++ b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py
@@ -61,7 +61,10 @@ class PtGradientMonitor(BaseMonitor):
param_grad = grad.clone().detach()
is_positive = param_grad > 0
save_filepath = os.path.join(save_path, f"{param_name}.pt")
- torch.save(is_positive, save_filepath)
+ try:
+ torch.save(is_positive, save_filepath)
+ except Exception as e:
+ raise RuntimeError(f"An unexpected error occurred: {e} when saving tensor to {save_filepath}") from e
change_mode(save_filepath, 0o640)
def monitor(self, model):
@@ -96,7 +99,7 @@ class PtGradientMonitor(BaseMonitor):
output_lines.append(grad_info)
if self._level_adp["have_grad_direction"]:
PtGradientMonitor.save_grad_direction(param_name, grad,
- f'{self._output_path}/rank{self._rank}/step{self._step}')
+ f'{self._output_path}/rank{self._rank}/step{self._step}')
output_path = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}",
f"grad_summary_{self._step}.csv")
write_csv(output_path, output_lines,
diff --git a/debug/accuracy_tools/msprobe/README.md b/debug/accuracy_tools/msprobe/README.md
index 1e8c1a1f08d..42743c50781 100644
--- a/debug/accuracy_tools/msprobe/README.md
+++ b/debug/accuracy_tools/msprobe/README.md
@@ -10,10 +10,15 @@ MindStudio精度调试工具(MindStudio Probe),简称msprobe,是MindStud
```shell
pip install mindstudio-probe
```
- 说明
- 1. 使用`pip install mindstudio-probe==版本号`可安装指定版本的包
- 2. pip命令会自动安装包及其依赖
- 3. 安装成功后,日志会显示`Successfully installed mindstudio-probe-版本号`
+使用`pip install mindstudio-probe==版本号`可安装指定版本的包。
+
+pip命令会自动安装最新的包及其配套依赖。
+
+提示如下信息则表示安装成功。
+
+```bash
+Successfully installed mindstudio_probe-{version}
+```
### 下载whl包安装
1. 使用pip命令安装numpy、openpyxl、pandas、PyYAML、rich、torch、tqdm依赖。
@@ -26,6 +31,7 @@ MindStudio精度调试工具(MindStudio Probe),简称msprobe,是MindStud
| 版本 | 发布日期 | 支持PyTorch版本 | 下载链接 | 校验码 |
| ----- | ---------- | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
+ | 1.0.1 | 2024-07-25 | 2.0/2.1/2.2 | [mindstudio_probe-1.0.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/1.0/mindstudio_probe-1.0.1-py3-none-any.whl) | b699e224e4d4e3bcf9412c54fa858a1ee370f0d7a2bc69cb3f1273ac14a6dc82 |
| 1.0 | 2024-07-09 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-1.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/1.0/ascend_training_accuracy_tools-1.0-py3-none-any.whl) | 5016dfe886c5d340ec6f60a959673355855f313c91f100680da814efb49f8e81 |
| 0.0.3 | 2024-06-11 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-0.0.3-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/0.0/ascend_training_accuracy_tools-0.0.3-py3-none-any.whl) | f46d9714704859e2d67861a65bbb3c76b0a250cf6e238b978b5b959ab1fe125a |
| 0.0.2 | 2024-05-23 | 2.0/2.1/2.2 | [ascend_training_accuracy_tools-0.0.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/att/0.0/ascend_training_accuracy_tools-0.0.2-py3-none-any.whl) | 2e35809bde559e9c4d2f16a02ccde779ed9e436bb65fded0b7ebaf6ac2c88d93 |
@@ -92,6 +98,37 @@ MindStudio精度调试工具(MindStudio Probe),简称msprobe,是MindStud
Finished processing dependencies for mindstudio-probe=={version}
```
+### 查看msprobe工具信息
+
+执行如下命令查看msprobe工具信息。
+
+```bash
+pip show mindstudio-probe
+```
+
+输出结果如下示例:
+
+```bash
+Name: mindstudio-probe
+Version: 1.0
+Summary: This is a pytorch precision comparison tools
+Home-page:
+Author:
+Author-email:
+License:
+Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages
+Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel
+Required-by:
+```
+
+关键字段含义:
+
+- Name:工具名称。
+- Version:工具版本号。
+- Summary:工具概述。
+- Location:工具安装路径。
+- Requires:工具依赖。
+
## 工具使用
安装msprobe工具后,可以按照如下思路选择合适的子工具进行精度调试:
diff --git a/debug/accuracy_tools/msprobe/config/README.md b/debug/accuracy_tools/msprobe/config/README.md
index 7b91bd26f16..7d11a365253 100644
--- a/debug/accuracy_tools/msprobe/config/README.md
+++ b/debug/accuracy_tools/msprobe/config/README.md
@@ -2,13 +2,38 @@
当前配置文件主要为PrecisionDebugger接口执行dump或无标杆比对操作时调用的配置,当PrecisionDebugger接口未指定该配置文件时,使用该文件的默认配置。配置文件详见[config.json](./config.json)。
+当在环境上安装msprobe工具后,config.json文件位置可通过如下方式查找:
+
+查找msprobe工具安装路径。
+
+```
+pip show mindstudio-probe
+```
+
+输出结果如下示例:
+
+```
+Name: mindstudio-probe
+Version: 1.0
+Summary: This is a pytorch precision comparison tools
+Home-page:
+Author:
+Author-email:
+License:
+Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages
+Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel
+Required-by:
+```
+
+Location字段为msprobe工具的安装路径,那么config.json文件位置为/home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/config
+
## 参数说明
### **通用配置参数**
| 参数名 | 说明 | 是否必选 |
| ----------------- | ------------------------------------------------------------ | -------- |
-| task | dump的任务类型,str类型。可取值"free_benchmark"(无标杆比对,仅PyTorch场景支持)、"statistics"(仅dump API统计信息,默认值)、"tensor"(dump API统计信息和完全复刻整网的API运行情况的真实数据)、"overflow_check"(溢出检测)。配置示例:"task": "tensor"。根据task参数取值的不同,可以配置不同场景参数,详见:“**task配置为free_benchmark**”,“**task配置为statistics**”,“**task配置为tensor**”,“**task配置为overflow_check**”。 | 否 |
+| task | dump的任务类型,str类型。可取值:
"free_benchmark"(无标杆比对,仅PyTorch场景支持)。
"statistics"(仅dump API统计信息,默认值)。
"tensor"(dump API统计信息和完全复刻整网的API运行情况的真实数据)。
"overflow_check"(溢出检测,仅PyTorch和MindSpore静态图场景支持)。
"run_ut"(精度预检配置,仅PyTorch场景支持)。
配置示例:"task": "tensor"。
根据task参数取值的不同,可以配置不同场景参数,详见:“**task配置为free_benchmark**”,“**task配置为statistics**”,“**task配置为tensor**”,“**task配置为overflow_check**”,“**task配置为run_ut**”。 | 否 |
| dump_path | 设置dump数据目录路径,str类型。配置示例:"dump_path": "./dump_path"。MindSpore场景仅支持绝对路径。 | 是 |
| rank | 指定对某张卡上的数据进行dump,list[int]类型,默认未配置(表示dump所有卡的数据),应配置为大于等于0的整数,且须配置实际可用的Rank ID。配置示例:"rank": [1]。
对于PyTorch场景,Rank ID从0开始计数,最大取值为所有节点可用卡总数-1,若所配置的值大于实际训练所运行的卡的Rank ID,则dump数据为空,比如当前环境Rank ID为0到7,实际训练运行0到3卡,此时若配置Rank ID为4或不存在的10等其他值,此时dump数据为空。
对于MindSpore场景,所有节点的Rank ID均从0开始计数,最大取值为每个节点可用卡总数-1,config.json配置一次rank参数对所有节点同时生效。 | 否 |
| step | 指定dump某个step的数据,list[int]类型。默认未配置,表示dump所有step数据。dump特定step时,须指定为训练脚本中存在的step。step为list格式,可配置逐个step,例如:"step": [0,1,2]。 | 否 |
@@ -85,6 +110,18 @@ task配置为free_benchmark时,开启**无标杆比对**,在NPU环境下通
| overflow_nums | 控制溢出次数,int类型,仅PyTorch场景支持,表示第N次溢出时,停止训练,过程中检测到溢出API对应kernel数据均dump。配置示例:"overflow_nums": 3。默认为1,即检测到1次溢出,训练停止,配置为-1时,表示持续检测溢出直到训练结束。 | 否 |
| check_mode | MindSpore场景kernel级别的溢出检测,str类型,可取值"aicore"(开启AI Core的溢出检测)、"atomic"(开启Atomic的溢出检测)、"all"(开启AI Core和Atomic的溢出检测,默认值)。配置示例"check_mode": "aicore"。 | 否 |
+### task配置为run_ut
+
+仅PyTorch场景支持。
+
+| 参数名称 | 说明 | 是否必选 |
+| --------------- | ------------------------------------------------------------ | -------- |
+| white_list | API dump白名单,仅对指定的API进行dump。配置示例:"white_list": ["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 |
+| black_list | API dump黑名单,被指定的API不进行dump。配置示例:"black_list": ["conv1d", "conv2d"]。默认未配置黑名单,即dump全量API数据。 | 否 |
+| error_data_path | 配置保存精度未达标的API输入输出数据路径,默认为当前路径。配置示例"error_data_path": "./"。 | 否 |
+
+说明:当white_list和black_list同时配置时,若二者配置的API名单无交集,则白名单生效;若API名单存在交集,则白名单中被排除的部分以及交集中的API均不进行dump。
+
## 配置示例
以下示例包含当前支持的所有场景可配置的完整参数。
@@ -180,6 +217,27 @@ task配置为free_benchmark时,开启**无标杆比对**,在NPU环境下通
}
```
+### PyTorch场景task配置为run_ut
+
+```json
+{
+ "task": "run_ut",
+ "dump_path": "/home/data_dump",
+ "rank": [],
+ "step": [],
+ "level": "L1",
+ "seed": 1234,
+ "is_deterministic": false,
+ "enable_dataloader": false,
+
+ "run_ut": {
+ "white_list": [],
+ "black_list": [],
+ "error_data_path": "./"
+ }
+}
+```
+
### MindSpore场景task配置为statistics
```json
diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index df82455a676..eff7b8be8ad 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -16,6 +16,7 @@ class Const:
OFF = 'OFF'
BACKWARD = 'backward'
FORWARD = 'forward'
+ PRIMITIVE_PREFIX = 'Primitive'
DEFAULT_LIST = []
DEFAULT_PATH = './'
WHITE_LIST = 'white_list'
@@ -45,6 +46,7 @@ class Const:
PT_SUFFIX = ".pt"
ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024
TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024
+ ONE_MB = 1048576 # 1 * 1024 * 1024
FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
DISTRIBUTED_PREFIX_LENGTH = 60
# env dump path
@@ -80,12 +82,12 @@ class Const:
INT_TYPE = [np.int32, np.int64]
NPU = 'NPU'
DISTRIBUTED = 'Distributed'
-
+
INPLACE_LIST = [
"broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
- "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single"
+ "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all"
]
-
+
CONVERT = {
"int32_to_int64": ["torch.int32", "torch.int64"],
}
@@ -252,3 +254,17 @@ class OverflowConst:
OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE"
OVERFLOW_ORIGINAL_MODE = 0
OVERFLOW_DEBUG_MODE = 1
+
+
+class MsConst:
+ CELL = "cell"
+ API = "api"
+ KERNEL = "kernel"
+ TOOL_LEVEL_DICT = {
+ "L0": CELL,
+ "L1": API,
+ "L2": KERNEL
+ }
+ PYNATIVE_MODE = "pynative"
+ GRAPH_GE_MODE = "graph_ge"
+ GRAPH_KBYK_MODE = "graph_kbyk"
diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py
index 32aba8d8af4..6e901deb9eb 100644
--- a/debug/accuracy_tools/msprobe/core/common/utils.py
+++ b/debug/accuracy_tools/msprobe/core/common/utils.py
@@ -148,7 +148,7 @@ def check_summary_only_valid(summary_only):
return summary_only
-def check_compare_param(input_parma, output_path, stack_mode=False, summary_compare=False, md5_compare=False):
+def check_compare_param(input_parma, output_path, summary_compare=False, md5_compare=False):
if not (isinstance(input_parma, dict) and isinstance(output_path, str)):
logger.error("Invalid input parameters")
raise CompareException(CompareException.INVALID_PARAM_ERROR)
@@ -318,15 +318,6 @@ def execute_command(cmd):
raise CompareException(CompareException.INVALID_DATA_ERROR)
-def save_numpy_data(file_path, data):
- """
- save_numpy_data
- """
- if not os.path.exists(os.path.dirname(file_path)):
- os.makedirs(os.path.dirname(file_path))
- np.save(file_path, data)
-
-
def parse_value_by_comma(value):
"""
parse value by comma, like '1,2,4,8'
diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py
index ed38eba008b..d6c15e101e7 100644
--- a/debug/accuracy_tools/msprobe/core/common_config.py
+++ b/debug/accuracy_tools/msprobe/core/common_config.py
@@ -18,24 +18,27 @@ class CommonConfig:
def _check_config(self):
if self.task and self.task not in Const.TASK_LIST:
- logger.error_log_with_exp(
- "task is invalid, it should be one of {}".format(Const.TASK_LIST), MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("task is invalid, it should be one of {}".format(Const.TASK_LIST),
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.rank is not None and not isinstance(self.rank, list):
- logger.error_log_with_exp("rank is invalid, it should be a list", MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("rank is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.step is not None and not isinstance(self.step, list):
- logger.error_log_with_exp("step is invalid, it should be a list", MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("step is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.level and self.level not in Const.LEVEL_LIST:
- logger.error_log_with_exp(
- "level is invalid, it should be one of {}".format(Const.LEVEL_LIST), MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("level is invalid, it should be one of {}".format(Const.LEVEL_LIST),
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.seed is not None and not isinstance(self.seed, int):
- logger.error_log_with_exp("seed is invalid, it should be an integer", MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("seed is invalid, it should be an integer",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if not isinstance(self.is_deterministic, bool):
- logger.error_log_with_exp(
- "is_deterministic is invalid, it should be a boolean", MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("is_deterministic is invalid, it should be a boolean",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if not isinstance(self.enable_dataloader, bool):
- logger.error_log_with_exp(
- "enable_dataloader is invalid, it should be a boolean", MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
-
+ logger.error_log_with_exp("enable_dataloader is invalid, it should be a boolean",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+
class BaseConfig:
def __init__(self, json_config):
@@ -44,15 +47,17 @@ class BaseConfig:
self.data_mode = json_config.get('data_mode')
self.backward_input = json_config.get("backward_input")
self.file_format = json_config.get("file_format")
- self.summary_mode = json_config.get("summary_mode")
- self.overflow_num = json_config.get("overflow_num")
+ self.summary_mode = json_config.get("summary_mode")
+ self.overflow_nums = json_config.get("overflow_nums")
self.check_mode = json_config.get("check_mode")
def check_config(self):
if self.scope is not None and not isinstance(self.scope, list):
- logger.error_log_with_exp("scope is invalid, it should be a list", MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("scope is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.list is not None and not isinstance(self.list, list):
- logger.error_log_with_exp("list is invalid, it should be a list", MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
+ logger.error_log_with_exp("list is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
if self.data_mode is not None and not isinstance(self.data_mode, list):
- logger.error_log_with_exp("data_mode is invalid, it should be a list", MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
-
+ logger.error_log_with_exp("data_mode is invalid, it should be a list",
+ MsprobeException(MsprobeException.INVALID_PARAM_ERROR))
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index 800a2b81c2f..de2b93c206d 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -1,7 +1,6 @@
-
import os
-from msprobe.core.data_dump.scope import build_scope, ListScope
+from msprobe.core.data_dump.scope import build_scope, ListScope
from msprobe.core.data_dump.json_writer import DataWriter
from msprobe.core.common.log import logger
from msprobe.core.common.const import Const
@@ -21,7 +20,8 @@ class DataCollector:
self.config = config
self.data_writer = DataWriter()
self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer)
- self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) if self.config.framework == Const.PT_FRAMEWORK else None
+ self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) \
+ if self.config.framework == Const.PT_FRAMEWORK else None
self.module_count = {}
if self.config.task == Const.FREE_BENCHMARK:
self.scope = build_scope(ListScope, self.config.scope, self.config.list)
@@ -35,7 +35,7 @@ class DataCollector:
@property
def dump_file_path(self):
return self.data_writer.dump_file_path
-
+
@staticmethod
def check_scope_and_pid(scope, name, pid):
return (not scope or scope.check(name)) and pid == os.getpid()
@@ -43,10 +43,10 @@ class DataCollector:
@staticmethod
def is_inplace(module):
return getattr(module, "op_is_inplace", False)
-
+
def if_return_forward_new_output(self):
return self.data_processor.if_return_forward_new_output()
-
+
def get_forward_new_output(self):
return self.data_processor.get_forward_new_output()
@@ -88,8 +88,11 @@ class DataCollector:
else:
data_info = self.data_processor.analyze_forward_inplace(name, module_input_output)
if self.config.level == "L2":
- return
+ return
self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
+ if self.data_processor.stop_run():
+ self.handle_data(name, data_info, use_buffer=False)
+ raise Exception("[msprobe] exit")
self.handle_data(name, data_info)
def backward_data_collect(self, name, module, pid, module_input_output):
@@ -98,6 +101,25 @@ class DataCollector:
return
data_info = self.data_processor.analyze_backward(name, module, module_input_output)
+ if self.data_processor.stop_run():
+ self.handle_data(name, data_info, use_buffer=False)
+ raise Exception("[msprobe] exit")
+ self.handle_data(name, data_info)
+
+ def backward_input_data_collect(self, name, module, pid, module_input_output):
+ self.update_construct(name)
+ if not self.check_scope_and_pid(self.scope, name, pid):
+ return
+
+ data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+ self.handle_data(name, data_info)
+
+ def backward_output_data_collect(self, name, module, pid, module_input_output):
+ self.update_construct(name)
+ if not self.check_scope_and_pid(self.scope, name, pid):
+ return
+
+ data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
self.handle_data(name, data_info)
def update_construct(self, name):
@@ -105,12 +127,15 @@ class DataCollector:
self.data_writer.update_construct({name: self.module_processor.api_parent_node})
self.data_writer.update_construct(self.module_processor.module_node)
- def handle_data(self, name, data_info):
+ def handle_data(self, name, data_info, use_buffer=True):
msg = f"msProbe is collecting data on {name}. "
if data_info:
msg = self.update_data(data_info, msg)
logger.info(msg)
- self.data_writer.flush_data_when_buffer_is_full()
+ if use_buffer:
+ self.data_writer.flush_data_when_buffer_is_full()
+ else:
+ self.write_json()
def module_count_func(self, name, name_template):
module_name = name.split(Const.SEP)[-3]
@@ -135,6 +160,6 @@ class DataCollector:
def update_dump_paths(self, *args):
self.data_writer.update_dump_paths(*args)
self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)
-
+
def update_iter(self, current_iter):
self.data_processor.update_iter(current_iter)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 5d901291973..13134d61980 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -35,11 +35,26 @@ class ModuleBackwardInputsOutputs:
@property
def grad_input_tuple(self):
return convert_tuple(self.grad_input)
-
+
@property
def grad_output_tuple(self):
- return convert_tuple(self.grad_output)
+ return convert_tuple(self.grad_output)
+
+@dataclass
+class ModuleBackwardInputs:
+ grad_input: Optional[Tuple]
+
+ @property
+ def grad_input_tuple(self):
+ return convert_tuple(self.grad_input)
+
+@dataclass
+class ModuleBackwardOutputs:
+ grad_output: Optional[Tuple]
+ @property
+ def grad_output_tuple(self):
+ return convert_tuple(self.grad_output)
class TensorStatInfo:
def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
@@ -53,7 +68,7 @@ class BaseDataProcessor:
_recursive_key_stack = []
special_type = (np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_,
bool, int, float, str, slice)
-
+
def __init__(self, config, data_writer):
self.data_writer = data_writer
self.config = config
@@ -65,11 +80,11 @@ class BaseDataProcessor:
self.current_iter = 0
self._return_forward_new_output = False
self._forward_new_output = None
-
+
@property
def data_path(self):
return self.data_writer.dump_tensor_data_dir
-
+
@staticmethod
def analyze_api_call_stack(name):
stack_str = []
@@ -87,7 +102,7 @@ class BaseDataProcessor:
stack_str.append(stack_line)
stack_info_struct = {name: stack_str}
return stack_info_struct
-
+
@staticmethod
def _convert_numpy_to_builtin(arg):
type_mapping = {
@@ -103,26 +118,15 @@ class BaseDataProcessor:
if isinstance(arg, numpy_type):
return builtin_type(arg), type(arg).__name__
return arg, ''
-
+
@staticmethod
def _analyze_numpy(value, numpy_type):
return {"type": numpy_type, "value": value}
-
- @staticmethod
- def _analyze_builtin(arg):
- single_arg = {}
- if isinstance(arg, slice):
- single_arg.update({"type": "slice"})
- single_arg.update({"value": [arg.start, arg.stop, arg.step]})
- else:
- single_arg.update({"type": type(arg).__name__})
- single_arg.update({"value": arg})
- return single_arg
-
+
@classmethod
def get_special_types(cls):
return cls.special_type
-
+
@classmethod
def recursive_apply_transform(cls, args, transform):
if isinstance(args, cls.get_special_types()):
@@ -177,13 +181,14 @@ class BaseDataProcessor:
return (Const.ALL in self.config.data_mode or
forward_backward in self.config.data_mode or
input_output in self.config.data_mode)
-
- def analyze_pre_forward(self, name, module,module_input_output: ModuleForwardInputsOutputs):
+
+ def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
pass
-
+
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
api_info_struct = {}
- if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input
+ # check whether data_mode contains forward or input
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
api_info_struct[name] = {}
self.api_data_category = Const.INPUT
args_info_list = self.analyze_element(module_input_output.args_tuple)
@@ -192,13 +197,14 @@ class BaseDataProcessor:
kwargs_info_list = self.analyze_element(module_input_output.kwargs)
api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
- if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): # check whether data_mode contains forward or output
+ # check whether data_mode contains forward or output
+ if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
api_info_struct[name] = api_info_struct.get(name, {})
self.api_data_category = Const.OUTPUT
output_info_list = self.analyze_element(module_input_output.output_tuple)
api_info_struct[name][Const.OUTPUT] = output_info_list
return api_info_struct
-
+
def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
api_info_struct = {}
if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT):
@@ -210,7 +216,7 @@ class BaseDataProcessor:
kwargs_info_list = self.analyze_element(module_input_output.kwargs)
api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list
return api_info_struct
-
+
def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs):
concat_args = module_input_output.concat_args_and_kwargs()
api_info_struct = {}
@@ -220,26 +226,55 @@ class BaseDataProcessor:
output_info_list = self.analyze_element(concat_args)
api_info_struct[name][Const.OUTPUT] = output_info_list
return api_info_struct
-
+
def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
api_info_struct = {}
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
api_info_struct[name] = {}
- self.api_data_category = Const.OUTPUT
+ self.api_data_category = Const.INPUT
input_info_list = self.analyze_element(module_input_output.grad_input_tuple)
- api_info_struct[name][Const.GRAD_INPUT] = input_info_list
+ api_info_struct[name][Const.INPUT] = input_info_list
- if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
api_info_struct[name] = api_info_struct.get(name, {})
- self.api_data_category = Const.INPUT
+ self.api_data_category = Const.OUTPUT
output_info_list = self.analyze_element(module_input_output.grad_output_tuple)
+ api_info_struct[name][Const.OUTPUT] = output_info_list
+
+ return api_info_struct
+
+ def analyze_backward_input(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
+ """
+ Analyze and save backward input gradients.
+ """
+ api_info_struct = {}
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT):
+ api_info_struct[name] = {}
+ self.api_data_category = Const.OUTPUT
+ # self.api_data_category = Const.INPUT
+ output_info_list = self.analyze_element(module_input_output.grad_input_tuple)
api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list
+ return api_info_struct
+ def analyze_backward_output(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
+ """
+ Analyze and save backward output gradients.
+ """
+ api_info_struct = {}
+ if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT):
+ api_info_struct[name] = {}
+ self.api_data_category = Const.INPUT
+ # self.api_data_category = Const.OUTPUT
+ input_info_list = self.analyze_element(module_input_output.grad_output_tuple)
+ api_info_struct[name][Const.GRAD_INPUT] = input_info_list
return api_info_struct
def get_save_file_path(self, suffix):
- file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
+ file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX
dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP +
suffix + file_format)
file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
- return dump_data_name, file_path
\ No newline at end of file
+ return dump_data_name, file_path
+
+ def stop_run(self):
+ return False
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
index 86ef2115fb2..ad74acdeeba 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/factory.py
@@ -4,7 +4,7 @@ from msprobe.core.common.const import Const
class DataProcessorFactory:
_data_processor = {}
_module_processor = {}
-
+
@classmethod
def register_processor(cls, framework, task, processor_class):
key = (framework, task)
@@ -13,7 +13,7 @@ class DataProcessorFactory:
@classmethod
def register_module_processor(cls, framework, processor_class):
cls._module_processor[framework] = processor_class
-
+
@classmethod
def get_module_processor(cls, framework):
processor_class = cls._module_processor.get(framework)
@@ -39,7 +39,7 @@ class DataProcessorFactory:
TensorDataProcessor as PytorchTensorDataProcessor,
OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor,
FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor,
- KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
+ KernelDumpDataProcessor as PytorchKernelDumpDataProcessor
)
from ....pytorch.module_processer import ModuleProcesser
cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor)
@@ -47,11 +47,13 @@ class DataProcessorFactory:
cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor)
cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor)
cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor)
- cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
+ cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser)
elif framework == Const.MS_FRAMEWORK:
from .mindspore_processor import (
StatisticsDataProcessor as MindsporeStatisticsDataProcessor,
- TensorDataProcessor as MindsporeTensorDataProcessor
+ TensorDataProcessor as MindsporeTensorDataProcessor,
+ OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor
)
cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)
cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)
+ cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
index 7533e2ee0de..b28817e4aa7 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
@@ -19,7 +19,8 @@ from mindspore import ops
import numpy as np
from msprobe.core.common.const import Const
-from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, TensorStatInfo
+from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo,
+ ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs)
from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst
from msprobe.mindspore.dump.hook_cell.wrap_functional import load_ops_functions
from msprobe.mindspore.common.utils import convert_bf16_to_fp32
@@ -30,7 +31,7 @@ from msprobe.mindspore.dump.hook_cell.api_registry import api_register
class MindsporeDataProcessor(BaseDataProcessor):
mindspore_special_type = tuple([ms.Tensor])
ops_func, mint_ops_func, _ = load_ops_functions()
-
+
def __init__(self, config, data_writer):
super().__init__(config, data_writer)
self.mindspore_object_key = {
@@ -47,18 +48,35 @@ class MindsporeDataProcessor(BaseDataProcessor):
@staticmethod
def analyze_dtype_in_kwargs(element):
return {"type": "mindspore.dtype", "value": str(element)}
-
+
+ @staticmethod
+ def _analyze_builtin(arg):
+ single_arg = {}
+ if isinstance(arg, slice):
+ single_arg.update({"type": "slice"})
+ # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型
+ values = [
+ value if not isinstance(value, ms.Tensor) else value.item()
+ for value in [arg.start, arg.stop, arg.step]
+ ]
+ single_arg.update({"value": values})
+ else:
+ single_arg.update({"type": type(arg).__name__})
+ single_arg.update({"value": arg})
+ return single_arg
+
@classmethod
def get_special_types(cls):
return super().get_special_types() + cls.mindspore_special_type
-
+
def get_stat_info(self, data):
tensor_stat = TensorStatInfo()
if data.numel() == 0:
return tensor_stat
elif data.dtype == ms.bool_:
- tensor_stat.max = self.mint_ops_func["max"](data).item()
- tensor_stat.min = self.mint_ops_func["min"](data).item()
+ data_np = data.asnumpy()
+ tensor_stat.max = np.max(data_np)
+ tensor_stat.min = np.min(data_np)
elif not data.shape:
tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item()
elif data.dtype == ms.complex64 or data.dtype == ms.complex128:
@@ -90,7 +108,7 @@ class MindsporeDataProcessor(BaseDataProcessor):
if isinstance(element, (bool, int, float, str, slice)):
return self._analyze_builtin(element)
- return None
+ return {}
def analyze_element(self, element):
return self.recursive_apply_transform(element, self.analyze_single_element)
@@ -129,3 +147,61 @@ class TensorDataProcessor(MindsporeDataProcessor):
else:
logger.warning(f'The file path {file_path} length exceeds limit.')
return single_arg
+
+
+class OverflowCheckDataProcessor(MindsporeDataProcessor):
+ __slots__ = ["cached_tensors_and_file_paths"]
+
+ def __init__(self, config, data_writer):
+ super().__init__(config, data_writer)
+ self.cached_tensors_and_file_paths = {}
+ self.real_overflow_dump_times = 0
+ self.overflow_nums = config.overflow_nums
+
+ def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+ self.has_overflow = False
+ api_info_struct = super().analyze_forward(name, module, module_input_output)
+ self.maybe_save_overflow_data()
+ return api_info_struct if self.has_overflow else None
+
+ def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
+ self.has_overflow = False
+ api_info_struct = super().analyze_backward(name, module, module_input_output)
+ self.maybe_save_overflow_data()
+ return api_info_struct if self.has_overflow else None
+
+ def maybe_save_overflow_data(self):
+ if self.has_overflow:
+ for file_path, tensor in self.cached_tensors_and_file_paths.items():
+ tensor = convert_bf16_to_fp32(tensor)
+ np.save(file_path, tensor.asnumpy())
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+ self.real_overflow_dump_times += 1
+ self.cached_tensors_and_file_paths = {}
+
+ def stop_run(self):
+ if self.overflow_nums == -1:
+ return False
+ if self.real_overflow_dump_times >= self.overflow_nums:
+ logger.warning(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_dump_times}")
+ return True
+ return False
+
+ def _analyze_maybe_overflow_tensor(self, tensor_json):
+ if tensor_json['Max'] is None:
+ return
+ if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
+ self.has_overflow = True
+ if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
+ self.has_overflow = True
+
+ def _analyze_tensor(self, tensor, suffix):
+ dump_data_name, file_path = self.get_save_file_path(suffix)
+ if not path_len_exceeds_limit(file_path):
+ self.cached_tensors_and_file_paths.update({file_path: tensor})
+ else:
+ logger.warning(f'The file path {file_path} length exceeds limit.')
+ single_arg = super()._analyze_tensor(tensor, suffix)
+ self._analyze_maybe_overflow_tensor(single_arg)
+ single_arg.update({"data_name": dump_data_name})
+ return single_arg
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index f307909a416..007fec80964 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -15,8 +15,9 @@ from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
try:
import torch_npu
+ is_gpu = False
except ImportError:
- pass
+ is_gpu = True
class PytorchDataProcessor(BaseDataProcessor):
@@ -77,6 +78,38 @@ class PytorchDataProcessor(BaseDataProcessor):
tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
return tensor_stat
+ @staticmethod
+ def handle_tensor_extremum_nan_inf(tensor, operator):
+ data_clone = tensor.detach()
+ data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
+ if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
+ return float('nan')
+ finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
+ if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
+ finite_values = data_clone[finite_mask]
+ return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
+ torch._C._VariableFunctionsClass.min(finite_values).item()
+ else:
+ data_no_nan = data_clone[~data_nan]
+ return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
+ torch._C._VariableFunctionsClass.min(data_no_nan).item()
+
+ @staticmethod
+ def _analyze_builtin(arg):
+ single_arg = {}
+ if isinstance(arg, slice):
+ single_arg.update({"type": "slice"})
+ # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型
+ values = [
+ value if not isinstance(value, torch.Tensor) else value.item()
+ for value in [arg.start, arg.stop, arg.step]
+ ]
+ single_arg.update({"value": values})
+ else:
+ single_arg.update({"type": type(arg).__name__})
+ single_arg.update({"value": arg})
+ return single_arg
+
@staticmethod
def _analyze_torch_size(arg):
return {"type": "torch.Size", "value": list(arg)}
@@ -97,7 +130,7 @@ class PytorchDataProcessor(BaseDataProcessor):
return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
if isinstance(element, (bool, int, float, str, slice)):
return self._analyze_builtin(element)
- return None
+ return {}
def analyze_element(self, element):
return self.recursive_apply_transform(element, self.analyze_single_element)
@@ -113,9 +146,17 @@ class PytorchDataProcessor(BaseDataProcessor):
tensor_json.update({"Mean": tensor_stat.mean})
tensor_json.update({"Norm": tensor_stat.norm})
tensor_json.update({"requires_grad": tensor.requires_grad})
- if self.config.summary_mode == "md5":
+
+ if tensor_stat.max is not None:
+ if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max):
+ tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
+ if tensor_stat.min is not None:
+ if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min):
+ tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
+
+ if self.config.summary_mode == Const.MD5:
tensor_md5 = self.get_md5_for_tensor(tensor)
- tensor_json.update({"md5": tensor_md5})
+ tensor_json.update({Const.MD5: tensor_md5})
return tensor_json
@@ -143,7 +184,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
super().__init__(config, data_writer)
self.cached_tensors_and_file_paths = {}
self.real_overflow_dump_times = 0
- self.overflow_nums = config.overflow_num
+ self.overflow_nums = config.overflow_nums
self.bits_for_overflow = 8
@staticmethod
@@ -151,21 +192,6 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
return overflow_mode == Const.ENV_ENABLE
- @staticmethod
- def handle_tensor_extremum_nan_inf(data_clone, operator):
- data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
- if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
- return float('nan')
- finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
- if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
- finite_values = data_clone[finite_mask]
- return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
- torch._C._VariableFunctionsClass.min(finite_values).item()
- else:
- data_no_nan = data_clone[~data_nan]
- return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
- torch._C._VariableFunctionsClass.min(data_no_nan).item()
-
def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
self.has_overflow = False
api_info_struct = super().analyze_forward(name, module, module_input_output)
@@ -211,16 +237,13 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
else:
torch_npu._C._clear_overflow_npu()
- def _analyze_maybe_overflow_tensor(self, tensor_json, tensor):
- data_clone = tensor.detach()
- if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan():
+ def _analyze_maybe_overflow_tensor(self, tensor_json):
+ if is_gpu or (hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan()):
if tensor_json['Max'] is None:
return
if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
- tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max")
self.has_overflow = True
if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
- tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min")
self.has_overflow = True
else:
self.has_overflow = self.check_overflow_npu()
@@ -234,7 +257,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
else:
logger.warning(f'The file path {file_path} length exceeds limit.')
single_arg = super()._analyze_tensor(tensor, suffix)
- self._analyze_maybe_overflow_tensor(single_arg, tensor)
+ self._analyze_maybe_overflow_tensor(single_arg)
single_arg.update({"data_name": dump_data_name})
return single_arg
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py
index c4b7fc11ec4..112e45171ef 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py
@@ -4,7 +4,7 @@ import fcntl
import json
from pathlib import Path
-from msprobe.core.common.file_check import change_mode
+from msprobe.core.common.file_check import change_mode, FileOpen
from msprobe.core.common.log import logger
from msprobe.core.common.const import Const, FileCheckConst
@@ -30,20 +30,20 @@ class DataWriter:
return
is_exists = os.path.exists(file_path)
append = "a+" if is_exists else "w+"
- with os.fdopen(
- os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline=""
- ) as csv_file:
+ with FileOpen(file_path, append) as csv_file:
spawn_writer = csv.writer(csv_file)
if not is_exists:
spawn_writer.writerow(result_header)
spawn_writer.writerows([result,])
+ is_new_file = not is_exists
+ if is_new_file:
+ change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
def initialize_json_file(self, **kwargs):
kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
- with os.fdopen(
- os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w'
- ) as f:
+ with FileOpen(self.dump_file_path, 'w') as f:
json.dump(kwargs, f)
+ change_mode(self.dump_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
if os.path.exists(self.stack_file_path):
os.remove(self.stack_file_path)
@@ -83,7 +83,7 @@ class DataWriter:
def write_data_json(self, file_path):
logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
if Path(file_path).exists() and os.path.getsize(file_path) > 0:
- with open(file_path, "r+") as f:
+ with FileOpen(file_path, "r+") as f:
fcntl.flock(f, fcntl.LOCK_EX)
data_to_write = json.load(f)
fcntl.flock(f, fcntl.LOCK_UN)
@@ -91,7 +91,7 @@ class DataWriter:
self.init_json['data_path'] = self.dump_tensor_data_dir
data_to_write = self.init_json
data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
- with open(file_path, 'w+') as f:
+ with FileOpen(file_path, 'w+') as f:
fcntl.flock(f, fcntl.LOCK_EX)
json.dump(data_to_write, f, indent=1)
fcntl.flock(f, fcntl.LOCK_UN)
@@ -99,13 +99,13 @@ class DataWriter:
self.cache_data[Const.DATA].clear()
def write_stack_info_json(self, file_path):
- with open(file_path, 'w+') as f:
+ with FileOpen(file_path, 'w+') as f:
fcntl.flock(f, fcntl.LOCK_EX)
json.dump(self.cache_stack, f, indent=1)
fcntl.flock(f, fcntl.LOCK_UN)
def write_construct_info_json(self, file_path):
- with open(file_path, 'w+') as f:
+ with FileOpen(file_path, 'w+') as f:
fcntl.flock(f, fcntl.LOCK_EX)
json.dump(self.cache_construct, f, indent=1)
fcntl.flock(f, fcntl.LOCK_UN)
diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
index d02f3819537..6abf0a1ee88 100644
--- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py
+++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py
@@ -29,3 +29,16 @@ def convert_bf16_to_fp32(tensor):
tensor = tensor.to(ms.float32)
return tensor
+
+class MsprobeStep(ms.train.Callback):
+
+ def __init__(self, debugger):
+ super(MsprobeStep, self).__init__()
+ self.debugger = debugger
+
+ def on_train_step_begin(self, run_context):
+ self.debugger.start()
+
+ def on_train_step_end(self, run_context):
+ self.debugger.stop()
+ self.debugger.step()
diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
index 04d66d6a26e..23cb7294b8d 100644
--- a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
+++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py
@@ -1,14 +1,10 @@
import os
+
from msprobe.core.common.utils import Const
+from msprobe.core.common.const import MsConst
class DebuggerConfig:
- convert_map = {
- "L0": "cell",
- "L1": "api",
- "L2": 'kernel'
- }
-
def __init__(self, common_config, task_config):
self.dump_path = common_config.dump_path
self.task = common_config.task
@@ -16,12 +12,13 @@ class DebuggerConfig:
self.step = [] if not common_config.step else common_config.step
if not common_config.level:
common_config.level = "L1"
- self.level = DebuggerConfig.convert_map[common_config.level]
+ self.level = MsConst.TOOL_LEVEL_DICT.get(common_config.level, MsConst.API)
self.level_ori = common_config.level
self.list = [] if not task_config.list else task_config.list
- self.scope =[] if not task_config.scope else task_config.scope
- self.data_mode = [] if not task_config.data_mode else task_config.data_mode
+ self.scope = [] if not task_config.scope else task_config.scope
+ self.data_mode = [] if not task_config.data_mode else task_config.data_mode
self.file_format = task_config.file_format
+ self.overflow_nums = 1 if not task_config.overflow_nums else task_config.overflow_nums
self.check_mode = task_config.check_mode
self.framework = Const.MS_FRAMEWORK
self.summary_mode = task_config.summary_mode
diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
index 30f7162ff5c..40b44c57ec9 100644
--- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py
@@ -1,9 +1,12 @@
import os
+
import mindspore as ms
+
from msprobe.mindspore.service import Service
from msprobe.mindspore.ms_config import parse_json_config
from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
from msprobe.mindspore.task_handler_factory import TaskHandlerFactory
+from msprobe.core.common.const import MsConst
class PrecisionDebugger:
@@ -14,6 +17,8 @@ class PrecisionDebugger:
cls._instance = super().__new__(cls)
cls._instance.initialized = False
cls._instance.config = None
+ cls.service = None
+ cls.first_start = False
return cls._instance
def __init__(self, config_path=None):
@@ -24,28 +29,47 @@ class PrecisionDebugger:
common_config, task_config = parse_json_config(config_path)
self.config = DebuggerConfig(common_config, task_config)
self.initialized = True
- self.service = Service(self.config)
+
+ @staticmethod
+ def _get_execution_mode():
+ if ms.get_context("mode") == ms.GRAPH_MODE:
+ if ms.context.get_jit_config().get("jit_level") == "O2" or ms.get_context("jit_level") == "O2":
+ return MsConst.GRAPH_GE_MODE
+ else:
+ return MsConst.GRAPH_KBYK_MODE
+ else:
+ return MsConst.PYNATIVE_MODE
@classmethod
- def start(cls):
+ def start(cls, target=None):
instance = cls._instance
if not instance:
raise Exception("No instance of PrecisionDebugger found.")
- if ms.get_context("mode") == ms.PYNATIVE_MODE and instance.config.level_ori == "L1":
- instance.service.start()
+
+ instance.config.execution_mode = instance._get_execution_mode()
+ if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.level == MsConst.API:
+ if not instance.service:
+ instance.service = Service(instance.config)
+ instance.service.start(target)
else:
- handler = TaskHandlerFactory.create(instance.config)
- handler.handle()
+ if not instance.first_start:
+ handler = TaskHandlerFactory.create(instance.config)
+ handler.handle()
+
+ instance.first_start = True
@classmethod
def stop(cls):
instance = cls._instance
if not instance:
raise Exception("PrecisionDebugger instance is not created.")
- instance.service.stop()
+ if instance.service:
+ instance.service.stop()
@classmethod
def step(cls):
- if not cls._instance:
+ instance = cls._instance
+ if not instance:
raise Exception("PrecisionDebugger instance is not created.")
- cls._instance.service.step()
\ No newline at end of file
+ if instance.service:
+ instance.service.step()
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry copy.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry copy.py
new file mode 100644
index 00000000000..ad73bcd9119
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry copy.py
@@ -0,0 +1,198 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import functools
+import mindspore as ms
+from mindspore import ops
+from mindspore.common.tensor import Tensor
+from msprobe.core.common.utils import Const
+from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
+from msprobe.mindspore.dump.hook_cell.wrap_functional import get_functional_ops, setup_hooks, \
+ HOOKFunctionalOP, HOOKMintOP, HOOKMintNNFunctionalOP
+from msprobe.mindspore.dump.hook_cell.wrap_tensor import get_tensor_ops, wrap_tensor_ops_and_bind, HOOKTensor
+from msprobe.core.common.utils import Const
+
+PRIMITIVE_PREFIX = "Primitive"
+
+class ApiRegistry:
+ def __init__(self):
+ self.tensor_ori_attr = {}
+ self.functional_ori_attr = {}
+ self.mint_ops_ori_attr = {}
+ self.mint_func_ops_ori_attr = {}
+ self.norm_inner_ops_ori_attr = {}
+
+ self.tensor_hook_attr = {}
+ self.functional_hook_attr = {}
+ self.mint_ops_hook_attr = {}
+ self.mint_func_ops_hook_attr = {}
+ self.norm_inner_ops_hook_attr = {}
+
+ self.norm_inner_ops = ["norm", "square", "sqrt", "is_complex"]
+ self.primitive_counters = {}
+
+ @staticmethod
+ def store_ori_attr(ori_api_group, api_list, api_ori_attr):
+ for api in api_list:
+ if Const.SEP in api:
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
+ sub_module = getattr(ori_api_group, sub_module_name)
+ api_ori_attr[api] = getattr(sub_module, sub_op)
+ else:
+ api_ori_attr[api] = getattr(ori_api_group, api)
+
+ @staticmethod
+ def set_api_attr(api_group, attr_dict):
+ for api, api_attr in attr_dict.items():
+ if Const.SEP in api:
+ sub_module_name, sub_op = api.rsplit(Const.SEP, 1)
+ sub_module = getattr(api_group, sub_module_name, None)
+ if sub_module is not None:
+ setattr(sub_module, sub_op, api_attr)
+ else:
+ setattr(api_group, api, api_attr)
+
+ def norm_inner_op_set_hook_func(self):
+ self.set_api_attr(ms.ops, self.norm_inner_ops_hook_attr)
+
+ def norm_inner_op_set_ori_func(self):
+ self.set_api_attr(ms.ops, self.norm_inner_ops_ori_attr)
+
+ def api_set_hook_func(self):
+ self.set_api_attr(ms.Tensor, self.tensor_hook_attr)
+ self.set_api_attr(ms.ops, self.functional_hook_attr)
+ self.set_api_attr(ms.mint, self.mint_ops_hook_attr)
+ self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_hook_attr)
+
+ def api_set_ori_func(self):
+ self.set_api_attr(ms.Tensor, self.tensor_ori_attr)
+ self.set_api_attr(ms.ops, self.functional_ori_attr)
+ self.set_api_attr(ms.mint, self.mint_ops_ori_attr)
+ self.set_api_attr(ms.mint.nn.functional, self.mint_func_ops_ori_attr)
+
+ def initialize_hook(self, hook):
+ self.store_ori_attr(ms.Tensor, get_tensor_ops(), self.tensor_ori_attr)
+ wrap_tensor_ops_and_bind(hook)
+ for attr_name in dir(HOOKTensor):
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+ self.tensor_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKTensor, attr_name)
+
+ functional_ops, mint_ops, mint_func_ops = get_functional_ops()
+ self.store_ori_attr(ms.ops, self.norm_inner_ops, self.norm_inner_ops_ori_attr)
+ self.store_ori_attr(ms.ops, functional_ops, self.functional_ori_attr)
+ self.store_ori_attr(ms.mint, mint_ops, self.mint_ops_ori_attr)
+ self.store_ori_attr(ms.mint.nn.functional, mint_func_ops, self.mint_func_ops_ori_attr)
+ setup_hooks(hook)
+ for attr_name in dir(HOOKFunctionalOP):
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+ self.functional_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name)
+ if attr_name[Const.ATTR_NAME_PREFIX_LEN:] in self.norm_inner_ops:
+ self.norm_inner_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKFunctionalOP, attr_name)
+ for attr_name in dir(HOOKMintOP):
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+ self.mint_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintOP, attr_name)
+ for attr_name in dir(HOOKMintNNFunctionalOP):
+ if attr_name.startswith(Const.ATTR_NAME_PREFIX):
+ self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name)
+
+ def wrap_primitive(self, origin_func, primitive_name, service_instance):
+ primitive_instance = self
+ def func(self, *args, **kwargs):
+ if primitive_name not in primitive_instance.primitive_counters:
+ primitive_instance.primitive_counters[primitive_name] = 0
+ else:
+ primitive_instance.primitive_counters[primitive_name] += 1
+
+ current_count = primitive_instance.primitive_counters[primitive_name]
+ updated_primitive_name = f"{PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+ captured_grads_input = []
+ captured_grads_output = []
+
+ def input_backward_hook(grad):
+ print(f"Grad input length: {len(grad)}")
+ print("Captured input grad:", grad)
+ captured_grads_input.append(grad)
+ backward_primitive_name = updated_primitive_name + Const.BACKWARD
+ new_module_input_output = ModuleBackwardInputsOutputs(
+ grad_input=tuple(captured_grads_input),
+ grad_output=tuple(captured_grads_output) if captured_grads_output else None
+ )
+ service_instance.data_collector.backward_data_collect(
+ backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
+ )
+# TODO 1: multi-output scenarios are not handled yet
+# TODO 2: what if there are multiple gradients?
+# TODO 3: output index ordering issue
+ def output_backward_hook(grad):
+ captured_grads_output.append(grad)
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ new_module_input_output = ModuleBackwardInputsOutputs(
+ grad_input=None,
+ grad_output=tuple(captured_grads_output)
+ )
+ service_instance.data_collector.backward_data_collect(
+ backward_primitive_name + Const.BACKWARD, self, os.getpid(), new_module_input_output
+ )
+
+ if not service_instance.switch:
+ return origin_func(*args, **kwargs)
+
+ print(f"Entering {updated_primitive_name} hook, number of args: {len(args)}, name: {self.name}")
+ hooked_inputs = []
+
+ # for idx, arg in enumerate(args):
+ # if isinstance(arg, Tensor):
+ # arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ # hooked_inputs.append(arg_hooked)
+ # else:
+ # hooked_inputs.append(arg)
+
+ out = origin_func(*args, **kwargs)
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+
+ if service_instance.data_collector:
+ module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=out)
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, self, os.getpid(), module_input_output)
+ if service_instance.data_collector.if_return_forward_new_output():
+ out = service_instance.data_collector.get_forward_new_output()
+
+ if isinstance(out, Tensor):
+ out = ops.HookBackward(output_backward_hook)(out)
+ elif isinstance(out, tuple):
+ hooked_outputs = []
+ for tensor in out:
+ if isinstance(tensor, Tensor):
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ else:
+ hooked_outputs.append(tensor)
+ out = tuple(hooked_outputs)
+
+ return out
+
+ return func
+
+ def register_hooks(self, service_instance):
+ primitive_set = set()
+ for name, cell in service_instance.model.cells_and_names():
+ for pname, primitive in cell._primitives.items():
+ primitive_set.add((pname, primitive))
+
+ for pname, primitive in primitive_set:
+ print("primitive name is", pname)
+ NewPrimitive = type('NewPrimitive', (primitive.__class__,), {'__call__': self.wrap_primitive(primitive.__call__, pname, service_instance)})
+ primitive.__class__ = NewPrimitive
+
+api_register = ApiRegistry()
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py
index bcb80dd2266..57ed44111ca 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py
@@ -18,26 +18,23 @@ from mindspore import nn
from msprobe.core.common.const import Const
-cell_count = defaultdict(int)
-g_stop_hook = False
-
-
class HOOKCell(nn.Cell):
+ cell_count = defaultdict(int)
+ g_stop_hook = False
def __init__(self, build_hook) -> None:
super(HOOKCell, self).__init__()
self.changed_status = False
self.input_kwargs = {}
self.prefix = ""
- global g_stop_hook
- if not g_stop_hook:
- g_stop_hook = True
+ if not HOOKCell.g_stop_hook:
+ HOOKCell.g_stop_hook = True
self.changed_status = True
if hasattr(self, "prefix_op_name_"):
self.prefix = self.prefix_op_name_
- cell_count[self.prefix] += 1
- self.prefix = self.prefix + str(cell_count[self.prefix] - 1) + Const.SEP
+ HOOKCell.cell_count[self.prefix] += 1
+ self.prefix = self.prefix + str(HOOKCell.cell_count[self.prefix] - 1) + Const.SEP
forward_hook, backward_hook = build_hook(self.prefix)
self.register_forward_hook(forward_hook)
self.register_backward_hook(backward_hook)
@@ -52,6 +49,5 @@ class HOOKCell(nn.Cell):
finally:
if self.changed_status:
self.changed_status = False
- global g_stop_hook
- g_stop_hook = False
+ HOOKCell.g_stop_hook = False
return out
diff --git a/debug/accuracy_tools/msprobe/mindspore/ms_config.py b/debug/accuracy_tools/msprobe/mindspore/ms_config.py
index 49ce4cf2c09..c0ef6bb6c00 100644
--- a/debug/accuracy_tools/msprobe/mindspore/ms_config.py
+++ b/debug/accuracy_tools/msprobe/mindspore/ms_config.py
@@ -36,37 +36,39 @@ class StatisticsConfig(BaseConfig):
raise Exception("summary_mode is invalid")
-class OverflowCheck(BaseConfig):
+class OverflowCheckConfig(BaseConfig):
def __init__(self, json_config):
super().__init__(json_config)
- self.file_format = None
- self.check_mode = json_config.get("check_mode")
+ self.data_mode = ["all"]
self._check_config()
def _check_config(self):
- if self.data_mode is not None and len(self.data_mode) > 0:
- if len(self.data_mode) > 1 or self.data_mode[0] not in ["all", "input", "output"]:
- raise Exception("data_mode must be all, input or output")
+ if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
+ raise Exception("overflow_nums is invalid, it should be an integer")
+ if self.overflow_nums is not None and self.overflow_nums != -1 and self.overflow_nums <= 0:
+ raise Exception("overflow_nums should be -1 or positive integer")
if self.check_mode and self.check_mode not in ["all", "aicore", "atomic"]:
raise Exception("check_mode is invalid")
+TaskDict = {
+ Const.TENSOR: TensorConfig,
+ Const.STATISTICS: StatisticsConfig,
+ Const.OVERFLOW_CHECK: OverflowCheckConfig,
+}
+
+
def parse_common_config(json_config):
return CommonConfig(json_config)
def parse_task_config(task, json_config):
- task_map = json_config[task]
+ task_map = json_config.get(task)
if not task_map:
task_map = dict()
- if task == Const.TENSOR:
- return TensorConfig(task_map)
- elif task == Const.STATISTICS:
- return StatisticsConfig(task_map)
- elif task == Const.OVERFLOW_CHECK:
- return OverflowCheck(task_map)
- else:
+ if task not in TaskDict:
raise Exception("task is invalid.")
+ return TaskDict.get(task)(task_map)
def parse_json_config(json_file_path):
diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py
index e8aa34dc4fe..b795ec10342 100644
--- a/debug/accuracy_tools/msprobe/mindspore/service.py
+++ b/debug/accuracy_tools/msprobe/mindspore/service.py
@@ -14,9 +14,14 @@
# ============================================================================
import os
+import copy
from pathlib import Path
import functools
+from collections import defaultdict
+from mindspore.common.tensor import Tensor
+from mindspore import ops
+from mindspore import nn
from msprobe.core.data_dump.data_collector import build_data_collector
from msprobe.core.data_dump.scope import BaseScope
from msprobe.mindspore.common.utils import get_rank_if_initialized
@@ -25,20 +30,25 @@ from msprobe.mindspore.common.log import logger
from msprobe.core.common.utils import Const
from msprobe.core.common.exceptions import DistributedNotInitializedError
from msprobe.mindspore.dump.hook_cell.api_registry import api_register
-from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
+from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,\
+ ModuleBackwardInputs, ModuleBackwardOutputs
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
class Service:
def __init__(self, config):
self.model = None
- self.config = config
+ self.config = copy.deepcopy(config)
self.config.level = self.config.level_ori
- self.data_collector = build_data_collector(config)
+ self.data_collector = build_data_collector(self.config)
self.switch = False
self.current_iter = 0
self.first_start = True
self.current_rank = None
+ self.primitive_counters = {}
self.dump_iter_dir = None
+ self.start_call = False
def build_hook(self, module_type, name):
def forward_hook(api_or_module_name, module, input, output):
@@ -50,6 +60,7 @@ class Service:
self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output)
if self.data_collector.if_return_forward_new_output():
return self.data_collector.get_forward_new_output()
+ del module.input_kwargs
return output
def backward_hook(api_or_module_name, module, grad_input, grad_output):
@@ -68,18 +79,162 @@ class Service:
def wrap_forward_hook(*args, **kwargs):
return forward_hook(*args, **kwargs)
-
+
def wrap_backward_hook(*args, **kwargs):
return backward_hook(*args, **kwargs)
-
+
return wrap_forward_hook, wrap_backward_hook
+ def wrap_primitive(self, origin_func, primitive_name):
+ service_instance = self
+
+ def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
+ def backward_hook(grad):
+ captured_grads.append(grad)
+ try:
+ if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+ new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+ service_instance.data_collector.backward_input_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ captured_grads.clear()
+ elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
+ backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+ service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+ new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
+ service_instance.data_collector.backward_output_data_collect(
+ backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+ )
+ captured_grads.clear()
+
+ except Exception as exception:
+ raise Exception(
+ "This is a primitive op {hook_type}_backward dump error: {exception},"
+ " updated_primitive_name: {updated_primitive_name}".format(
+ hook_type=hook_type, exception=exception, updated_primitive_name=updated_primitive_name
+ )
+ ) from exception
+
+ return backward_hook
+
+ def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
+ hooked_inputs = []
+ num_tensors = sum(isinstance(arg, Tensor) for arg in args)
+ input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
+ Const.INPUT)
+ for _, arg in enumerate(args):
+ if isinstance(arg, Tensor):
+ arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+ hooked_inputs.append(arg_hooked)
+ else:
+ hooked_inputs.append(arg)
+ return hooked_inputs
+
+ def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
+ if isinstance(out, tuple):
+ num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
+ else:
+ num_output_tensors = 1
+ output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
+ updated_primitive_name, Const.OUTPUT)
+
+ if isinstance(out, Tensor):
+ return ops.HookBackward(output_backward_hook)(out)
+ elif isinstance(out, tuple):
+ hooked_outputs = []
+ for tensor in out:
+ if isinstance(tensor, Tensor):
+ hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+ else:
+ hooked_outputs.append(tensor)
+ return tuple(hooked_outputs)
+ return out
+
+ def wrapped_primitive_call(instance_self, *args, **kwargs):
+
+ service_instance.update_primitive_counters(primitive_name)
+ current_count = service_instance.primitive_counters[primitive_name]
+ updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+
+ if not service_instance.switch:
+ return origin_func(*args, **kwargs)
+
+ captured_grads_input, captured_grads_output = [], []
+
+ try:
+ hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
+ except Exception as exception:
+ raise Exception("This is a primitive op dump error during input hooking: {},"
+ " primitive_name: {}".format(exception, primitive_name)) from exception
+
+ try:
+ out = origin_func(*hooked_inputs, **kwargs)
+ except Exception as exception:
+ raise Exception("This is a primitive op dump error during function call: {},"
+ " primitive_name: {}".format(exception, primitive_name)) from exception
+
+ forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+ service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
+ if service_instance.data_collector:
+ module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+ try:
+ service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
+ os.getpid(), module_input_output)
+ except Exception as exception:
+ raise Exception("This is a primitive op dump error during forward data collection: {},"
+ " primitive_name: {}".format(exception, primitive_name)) from exception
+
+ if service_instance.data_collector.if_return_forward_new_output():
+ out = service_instance.data_collector.get_forward_new_output()
+
+ try:
+ out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
+ except Exception as exception:
+ raise Exception("This is a primitive op dump error during output hooking: {},"
+ " primitive_name: {}".format(exception, primitive_name)) from exception
+
+ return out
+
+
+ return wrapped_primitive_call
+
+ def update_primitive_counters(self, primitive_name):
+ if primitive_name not in self.primitive_counters:
+ self.primitive_counters[primitive_name] = 0
+ else:
+ self.primitive_counters[primitive_name] += 1
+
+ def register_hooks(self):
+ primitive_set = set()
+ for _, cell in self.model.cells_and_names():
+ for pname, primitive in cell._primitives.items():
+ primitive_set.add((pname, primitive))
+
+ for pname, primitive in primitive_set:
+ NewPrimitive = type('NewPrimitive', (primitive.__class__,),
+ {'__call__': self.wrap_primitive(primitive.__call__, pname)})
+ primitive.__class__ = NewPrimitive
+
def step(self):
self.current_iter += 1
self.data_collector.update_iter(self.current_iter)
+ HOOKCell.cell_count = defaultdict(int)
+ self.primitive_counters.clear()
+
+ @staticmethod
+ def check_model_valid(model):
+ if not model or isinstance(model, nn.Cell):
+ return model
+ raise MsprobeException(
+ MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。"
+ )
def start(self, model=None):
- self.model = model
+ self.model = Service.check_model_valid(model)
+ self.start_call = True
+ logger.info("msprobe: debugger.start() is set successfully")
if self.config.step and self.current_iter > max(self.config.step):
self.stop()
raise Exception("msprobe: exit after iteration {}".format(max(self.config.step)))
@@ -96,16 +251,22 @@ class Service:
self.register_hook_new()
self.first_start = False
self.switch = True
- logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ")
+ logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
self.create_dirs()
- logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.")
+ logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
def stop(self):
+ logger.info("msprobe: debugger.stop() is set successfully. "
+ "Please set debugger.start() to turn on the dump switch again. ")
+ if not self.start_call:
+ logger.error("msprobe: debugger.start() is not set in the current scope.")
+ raise Exception("debugger.start() is not set in the current scope.")
if self.config.step and self.current_iter not in self.config.step:
return
if self.config.rank and self.current_rank not in self.config.rank:
return
self.switch = False
+ self.start_call = False
self.data_collector.write_json()
def create_dirs(self):
@@ -130,9 +291,11 @@ class Service:
construct_file_path = os.path.join(dump_dir, "construct.json")
self.data_collector.update_dump_paths(
dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None)
-
+
def register_hook_new(self):
- logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task))
+ logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
if self.config.level == "L1":
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
api_register.api_set_hook_func()
+ if self.model:
+ self.register_hooks()
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
index f495cd673d7..b2eec691af0 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -21,10 +21,11 @@ import torch
import numpy
from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api
-from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, \
- get_full_data_path, CompareException
+from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \
+ CompareException
+from msprobe.core.common.file_check import FileChecker
from msprobe.pytorch.common.log import logger
-from msprobe.core.common.const import Const
+from msprobe.core.common.const import Const, FileCheckConst
TORCH_TYPE = ["torch.device", "torch.dtype"]
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
@@ -87,12 +88,13 @@ def gen_real_tensor(data_path, convert_type):
convert_type: convert ori_type to dist_type flag.
"""
data_path = os.path.realpath(data_path)
- check_file_or_directory_path(data_path)
+ data_path_checker = FileChecker(data_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE)
+ data_path = data_path_checker.common_check()
if not data_path.endswith('.pt') and not data_path.endswith('.npy'):
error_info = f"The file: {data_path} is not a pt or numpy file."
raise CompareException(CompareException.INVALID_FILE_ERROR, error_info)
if data_path.endswith('.pt'):
- data = torch.load(data_path).cpu()
+ data = torch.load(data_path, map_location=torch.device('cpu'))
else:
data_np = numpy.load(data_path)
data = torch.from_numpy(data_np)
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
index 9c96a52d8bd..9acb5ee6498 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -68,7 +68,7 @@ signal.signal(signal.SIGTERM, signal_handler)
ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
'save_error_data_flag', 'jit_compile_flag', 'device_id',
- 'result_csv_path', 'total_items', 'real_data_path'])
+ 'result_csv_path', 'total_items', 'config_path'])
def run_parallel_ut(config):
@@ -90,7 +90,7 @@ def run_parallel_ut(config):
*(['-j'] if config.jit_compile_flag else []),
*(['-save_error_data'] if config.save_error_data_flag else []),
'-csv_path', config.result_csv_path,
- *(['-real_data_path', config.real_data_path] if config.real_data_path else [])
+ *(['-config', config.config_path] if config.config_path else [])
]
return cmd
@@ -110,14 +110,9 @@ def run_parallel_ut(config):
    def update_progress_bar(progress_bar, result_csv_path):
        """Poll the result CSV while worker processes are alive and advance the bar.

        Completed items are counted as CSV lines minus the header row; polls
        once per second. Relies on ``processes`` from the enclosing scope.
        """
        while any(process.poll() is None for process in processes):
            with FileOpen(result_csv_path, 'r') as result_file:
                # first line is the CSV header, hence the -1
                completed_items = len(result_file.readlines()) - 1
                progress_bar.update(completed_items - progress_bar.n)
            time.sleep(1)
for api_info in config.api_files:
@@ -175,7 +170,7 @@ def prepare_config(args):
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
out_path = out_path_checker.common_check()
split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
-
+ config_path = os.path.realpath(args.config_path) if args.config_path else None
result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
if not args.result_csv_path:
details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
@@ -187,7 +182,7 @@ def prepare_config(args):
logger.info(f"UT task details will be saved in {details_csv_path}")
return ParallelUTConfig(split_files, out_path, args.num_splits, args.save_error_data,
args.jit_compile, args.device_id, result_csv_path,
- total_items, args.real_data_path)
+ total_items, config_path)
def main():
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
index 30994f70944..559dfdc0f14 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
@@ -27,6 +27,8 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareC
from msprobe.pytorch.hook_module.wrap_tensor import TensorOPTemplate
from msprobe.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
from msprobe.pytorch.hook_module.wrap_torch import TorchOPTemplate
+from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate
+from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate
from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig
from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
from msprobe.core.common.file_check import FileOpen, FileChecker, \
@@ -78,6 +80,12 @@ def exec_api(api_type, api_name, args, kwargs):
if api_type == "Torch":
torch_api = TorchOPTemplate(api_name, str, False)
out = torch_api.forward(*args, **kwargs)
+ if api_type == "Aten":
+ torch_api = AtenOPTemplate(api_name, None, False)
+ out = torch_api.forward(*args, **kwargs)
+ if api_type == "NPU":
+ torch_api = NpuOPTemplate(api_name, None, False)
+ out = torch_api.forward(*args, **kwargs)
return out
@@ -274,7 +282,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
if need_backward:
if need_to_backward(grad_index, out):
- backward_args = backward_content[api_full_name].get("grad_output")
+ backward_args = backward_content[api_full_name].get("input")
grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
@@ -379,10 +387,6 @@ def _run_ut_parser(parser):
help=" The path of accuracy_checking_result_{timestamp}.csv, "
"when run ut is interrupted, enter the file path to continue run ut.",
required=False)
- parser.add_argument("-real_data_path", dest="real_data_path", nargs="?", const="", default="", type=str,
- help=" In real data mode, the root directory for storing real data "
- "must be configured.",
- required=False)
parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true",
help=" Whether to filter the api in the api_info_file.", required=False)
parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str,
@@ -400,9 +404,9 @@ def preprocess_forward_content(forward_content):
if key not in arg_cache:
filtered_new_args = [
{k: v for k, v in arg.items() if k not in ['Max', 'Min']}
- for arg in value['args'] if isinstance(arg, dict)
+ for arg in value['input_args'] if isinstance(arg, dict)
]
- arg_cache[key] = (filtered_new_args, value['kwargs'])
+ arg_cache[key] = (filtered_new_args, value['input_kwargs'])
filtered_new_args, new_kwargs = arg_cache[key]
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py
new file mode 100644
index 00000000000..eb06867371c
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/__init__.py
@@ -0,0 +1,15 @@
import os
from pkgutil import iter_modules
from importlib import import_module

"""
GPU/CPU builds do not implement the NPU benchmark operators; this package
supplies reference implementations and re-exports them at package level.
"""

# Import every sibling module of this package and lift its callables into the
# package namespace, skipping names containing "npu_custom".
package_path = os.path.dirname(os.path.realpath(__file__))
for _, module_name, _ in iter_modules([package_path]):
    module = import_module(f"{__name__}.{module_name}")
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        # NOTE(review): this also re-exports callables the modules merely
        # imported (e.g. einops.rearrange); a per-module __all__ would be tighter.
        if callable(attr) and "npu_custom" not in attr_name:
            globals()[attr_name] = attr
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py
new file mode 100644
index 00000000000..caf21a604c6
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/apply_adam_w.py
@@ -0,0 +1,28 @@
+import torch
+
+
def npu_apply_adam_w(beta1_power, beta2_power, lr, weight_decay,
                     beta1, beta2, eps, grad, max_grad_norm, amsgrad, maximize, out):
    """CPU benchmark of the fused npu_apply_adam_w optimizer step.

    ``out`` is the (var, m, v) state triple; returns the updated
    (var, m, v) tensors on the CPU.
    """
    var, exp_avg, exp_avg_sq = out
    if amsgrad:
        # NOTE(review): the supplied max_grad_norm is replaced with random data
        # on the amsgrad path — presumably a benchmark stand-in; confirm.
        max_grad_norm = (torch.rand(var.shape) * 10.0 - 5.0).to(var.dtype)

    beta1_power_out = beta1_power * beta1
    beta2_power_out = beta2_power * beta2
    decayed_var = var * (1 + (-lr * weight_decay))
    step_grad = -grad if maximize else grad
    m_out = exp_avg * beta1 - (beta1 + (-1)) * step_grad
    v_out = exp_avg_sq * beta2 - (beta2 + (-1)) * step_grad * step_grad

    if amsgrad:
        max_grad_norm_out = torch.max(max_grad_norm, v_out)
        if (1 - beta2_power_out) == 0:
            beta2_power_out -= eps  # guard against division by zero
        denom = torch.sqrt(torch.div(max_grad_norm_out, (1 - beta2_power_out))) + eps
    else:
        denom = torch.sqrt(torch.div(v_out, (1 - beta2_power_out))) + eps

    if (1 - beta1_power_out) == 0:
        beta1_power_out -= eps  # guard against division by zero
    var_out = decayed_var + torch.div(-lr * m_out, (1 - beta1_power_out)).div(denom)
    return var_out.cpu(), m_out.cpu(), v_out.cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py
new file mode 100644
index 00000000000..627bf11b64f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/confusion_transpose.py
@@ -0,0 +1,19 @@
def npu_confusion_transpose(data, perm, shape, transpose_first):
    """Fused permute+reshape benchmark; transpose_first selects the op order."""
    if transpose_first:
        return data.permute(*perm).contiguous().view(shape).cpu()
    return data.view(shape).permute(*perm).cpu()
+
+
def npu_confusion_transpose_backward(grad, perm, shape, transpose_first):
    """Gradient of npu_confusion_transpose: apply the inverse permute/reshape to grad."""
    inverse_perm = [0] * len(perm)
    for dst, src in enumerate(perm):
        inverse_perm[src] = dst

    if transpose_first:
        original_shape = shape
        result = grad.permute(*inverse_perm).reshape(original_shape)
    else:
        original_shape = [shape[axis] for axis in perm]
        result = grad.reshape(original_shape).permute(*inverse_perm)
    return result.cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py
new file mode 100644
index 00000000000..a1a9ca08085
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/fast_gelu.py
@@ -0,0 +1,55 @@
+import torch
+
+
def fast_gelu(input0):
    """Benchmark fast_gelu: x * sigmoid(1.702 * x), written in an overflow-safe form
    that only exponentiates non-positive arguments."""
    alpha = 1.702
    magnitude = torch.abs(input0)
    # e^(-1.702|x|) + 1 — always finite
    denominator = torch.exp(magnitude * (0 - alpha)) + 1
    # x * e^(1.702 (x - |x|)/2) — exponent is <= 0
    numerator = input0 * torch.exp((input0 - magnitude) * (alpha / 2))
    return (numerator * torch.reciprocal(denominator)).cpu()
+
+
def npu_fast_gelu_backward(grad, input_x):
    """Backward of fast_gelu, evaluated via e^(-1.702|x|) for numerical stability.

    d/dx[x * sigmoid(1.702x)] expressed as
    (e^(-1.702|x|) + 1.702 x e^(-1.702|x|) + e^(1.702(x-|x|))) / (e^(-1.702|x|) + 1)^2,
    then scaled by the upstream gradient.
    """
    alpha = 1.702
    abs_x = torch.abs(input_x)
    exp_neg = torch.exp(abs_x * (0.0 - alpha))         # e^(-1.702|x|)
    scaled = input_x * exp_neg * alpha                 # 1.702 x e^(-1.702|x|)
    exp_shift = torch.exp((input_x - abs_x) * alpha)   # e^(1.702(x - |x|))

    numerator = exp_neg + scaled + exp_shift
    base = exp_neg + 1.0
    derivative = numerator * torch.reciprocal(base * base)
    return (grad * derivative).cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py
new file mode 100644
index 00000000000..f6949c079e2
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/layer_norm_eval.py
@@ -0,0 +1,6 @@
+import torch
+
+
def npu_layer_norm_eval(data, normalized_shape):
    """Benchmark for npu_layer_norm_eval: layer_norm without affine weight/bias."""
    return torch.nn.functional.layer_norm(data, normalized_shape).cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py
new file mode 100644
index 00000000000..95db875edf6
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/linear.py
@@ -0,0 +1,12 @@
+import torch
+
+
def npu_linear(x, weight, bias):
    """Benchmark for npu_linear: y = x @ weight.T + bias."""
    return torch.nn.functional.linear(x, weight, bias).cpu()
+
+
def npu_linear_backward(grad, input_data, weight):
    """Gradients of npu_linear: d_input = grad @ W, d_weight = gradᵀ @ input."""
    grad_input = grad @ weight
    grad_weight = grad.t() @ input_data
    return grad_input.cpu(), grad_weight.cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py
new file mode 100644
index 00000000000..ed1c746ec16
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/matmul_backward.py
@@ -0,0 +1,48 @@
+import torch
+
+
def matmul_backward(grad, self, other, mask):
    """Benchmark gradient of torch.matmul across every dim combination.

    Args:
        grad: upstream gradient of matmul(self, other).
        self: first matmul operand (parameter name kept for interface
              compatibility — this is a free function, not a method).
        other: second matmul operand.
        mask: two-element bool sequence; mask[0]/mask[1] select whether the
              gradient w.r.t. self/other is computed.

    Returns:
        (grad_self, grad_other) moved to the CPU; an entry is None when the
        corresponding mask flag is False.
    """
    grad_self, grad_other = None, None
    dim_self = self.dim()
    dim_other = other.dim()

    size_grad = list(grad.size())
    size_self = list(self.size())
    size_other = list(other.size())
    if dim_self == 1 and dim_other == 1:
        # vector @ vector -> scalar gradient
        grad_self = other.mul(grad) if mask[0] else grad_self
        grad_other = self.mul(grad) if mask[1] else grad_other
    elif dim_self == 2 and dim_other == 1:
        grad_self = grad.unsqueeze(1).mm(other.unsqueeze(0)) if mask[0] else grad_self
        grad_other = self.transpose(-1, -2).mm(grad.unsqueeze(1)).squeeze_(1) if mask[1] else grad_other
    elif dim_self == 1 and dim_other == 2:
        grad_self = grad.unsqueeze(0).mm(other.transpose(-1, -2)).squeeze_(0) if mask[0] else grad_self
        grad_other = self.unsqueeze(1).mm(grad.unsqueeze(0)) if mask[1] else grad_other
    elif dim_self >= 3 and (dim_other == 1 or dim_other == 2):
        # batched self against 1-D/2-D other: flatten the batch dims first
        view_size = 1 if dim_other == 1 else size_grad[-1]
        unfolded_grad = (grad.unsqueeze(-1) if dim_other == 1 else grad).contiguous().view(-1, view_size)
        if mask[0]:
            grad_self = unfolded_grad.mm(other.unsqueeze(0) if dim_other == 1 else other.transpose(-1, -2)) \
                .view(size_self)
        if mask[1]:
            unfolded_self = self.contiguous().view([-1, size_self[-1]])
            grad_other = unfolded_self.transpose(-1, -2).mm(unfolded_grad).view(size_other)
    elif (dim_self == 1 or dim_self == 2) and dim_other >= 3:
        view_size = 1 if dim_self == 1 else size_grad[-2]
        unfolded_grad_T = grad.view([-1, view_size]) \
            if dim_self == 1 else grad.transpose(-1, -2).contiguous().view([-1, view_size])
        if mask[0]:
            # create a 2D-matrix from other
            unfolded_other_T = \
                other.transpose(-1, -2).contiguous().view([-1, size_other[-2]]).transpose(-1, -2)
            grad_self = unfolded_other_T.mm(unfolded_grad_T).transpose(-1, -2).view(size_self)
        if mask[1]:
            size_other_T = size_other[:-2]
            size_other_T.extend(size_other[::-1][:2])
            grad_other = \
                unfolded_grad_T.mm(self.unsqueeze(0) if dim_self == 1 else self).view(size_other_T).transpose(-1, -2)
    else:
        grad_self = torch.matmul(grad, other.transpose(-1, -2)) if mask[0] else grad_self
        grad_other = torch.matmul(self.transpose(-1, -2), grad) if mask[1] else grad_other

    # Fix: when a mask flag is False the corresponding gradient stays None,
    # and calling .cpu() on it unconditionally raised AttributeError.
    grad_self = grad_self.cpu() if grad_self is not None else None
    grad_other = grad_other.cpu() if grad_other is not None else None
    return grad_self, grad_other
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py
new file mode 100644
index 00000000000..63f1fa2a3b6
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/npu_fusion_attention.py
@@ -0,0 +1,421 @@
+import torch
+import numpy as np
+from einops import rearrange
+
+from msprobe.pytorch.common.utils import logger
+
+gtype = torch.float64 # arm host必须选择float64,x86环境选择float32即可,64也行。arm计算很慢,s=8k的场景建议使用x86
+softmax_build_mode = "QKV" # "MAX_SUM"
+
+"""
+# 前向函数声明对比
+标杆实现:fusion_attention_forward: q, k, v, drop_mask, atten_mask, pse, scale, keep_prob
+融合算子:npu_fusion_attention_forward: query, key, value, head_num, input_layout, *, pse=None, padding_mask=None,
+ atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647,
+ next_tockens=2147483647, inner_precise=0, prefix=None, sparse_mode=0,
+ gen_mask_parallel=True, sync=False
+
+# 反向函数声明对比
+标杆实现:fusion_attention_backward: dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
+融合算子:npu_fusion_attention_backward: query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None,
+ atten_mask=None, softmax_max=None, softmax_sum=None, softmax_in=None,
+ attention_in=None, scale_value=1.0, keep_prob=1.0, pre_tockens=2147483647,
+ next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
+ numels=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False
+"""
+
+
def softmax_forward(x):
    """Numerically stable softmax over the last dim; also returns row max and exp-sum."""
    row_max = torch.max(x, dim=-1, keepdims=True)[0]
    exp_shifted = torch.exp(x.sub(row_max))
    row_sum = exp_shifted.sum(dim=-1, keepdims=True)
    return exp_shifted.div(row_sum), row_max, row_sum
+
+
def softmax_grad(dp, softmax_res):
    """Backward of softmax: (dp - sum(dp * y, last dim)) * y."""
    weighted_sum = (dp * softmax_res).sum(dim=-1, keepdims=True)
    return (dp - weighted_sum) * softmax_res
+
+
def broadcast_kv(num_heads, num_kv_heads, kv_tensor, dtype):
    """Expand a (B, num_kv_heads, S, D) key/value tensor to num_heads heads (GQA).

    Each group of ``num_heads // num_kv_heads`` query heads shares one kv head.

    Raises:
        ValueError: when num_kv_heads is zero or not smaller than num_heads.
    """
    # Fix: the original raised on `num_kv_heads < num_heads`, i.e. in exactly
    # the valid broadcast case (contradicting its own message); validate the
    # stated contract instead.
    if num_kv_heads == 0 or num_kv_heads >= num_heads:
        raise ValueError("num_kv_heads must be non-zero and less than num_heads.")

    factor = num_heads // num_kv_heads
    kv_shape = kv_tensor.shape
    B = kv_shape[0]
    S = kv_shape[2]
    D = kv_shape[3]
    kv_res = torch.zeros([B, num_heads, S, D]).to(dtype)
    for i in range(num_heads):
        # query head i reads kv head i // factor
        j = i // factor
        kv_res[:, i:i + 1, :, :] = kv_tensor[:, j:j + 1, :, :]
    return kv_res
+
+
def calculate_qk(q, k, atten_mask, pse, scale):
    """Scaled q·kᵀ scores with optional positional bias (pse) and additive mask."""
    scores = torch.matmul(q, k.permute(0, 1, 3, 2))
    if pse is not None and len(pse.shape) != 0:
        scores = scores + pse
    scores = scores.mul(scale)
    if atten_mask is not None and len(atten_mask.shape) != 0:
        # masked positions receive a large negative bias before softmax
        scores = scores + atten_mask.bool() * (-40000.0)  # -10000
    return scores
+
+
def fusion_attention_forward(q, k, v, drop_mask, atten_mask, pse, scale, keep_prob):
    """Reference attention forward: softmax(scale·qkᵀ + mask), optional dropout, @ v."""
    scores = calculate_qk(q, k, atten_mask, pse, scale)
    softmax_res, softmax_max, softmax_sum = softmax_forward(scores)
    if drop_mask is None or len(drop_mask.shape) == 0:
        attn = softmax_res
    else:
        attn = softmax_res * drop_mask * (1.0 / keep_prob)
    return torch.matmul(attn, v), softmax_max, softmax_sum
+
+
def fusion_attention_backward(dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob):
    """Reference attention backward; returns (dq, dk, dv)."""
    dp = torch.matmul(dx, v.permute(0, 1, 3, 2))
    has_drop = not (drop_mask is None or len(drop_mask.shape) == 0)
    if has_drop:
        drop_res = softmax_res.mul(drop_mask).mul(1.0 / keep_prob).permute(0, 1, 3, 2)
        dp_drop = dp * drop_mask * (1.0 / keep_prob)
    else:
        drop_res = softmax_res.permute(0, 1, 3, 2)
        dp_drop = dp
    dv = torch.matmul(drop_res, dx)
    scaled_grad = softmax_grad(dp_drop, softmax_res) * scale
    dq = torch.matmul(scaled_grad, k)
    dk = torch.matmul(scaled_grad.permute(0, 1, 3, 2), q)
    return dq, dk, dv
+
+
def parse_bsnd_args(query, key, head_num, input_layout):
    """Derive (B, S1, S2, N1, N2, D, H1, H2, DTYPE) from the q/k shapes and layout."""
    supported_input_layout = ["BSH", "SBH", "BSND", "BNSD", "TND"]
    B, S1, S2, N1, N2, D, H1, H2 = None, None, None, head_num, None, None, None, None

    if not isinstance(input_layout, str) or input_layout not in supported_input_layout:
        raise ValueError(f"Invalid input_layout arg which must be one of {supported_input_layout}.")

    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    try:
        if input_layout in ("BSH", "SBH"):
            # hidden size per tensor; head dim derived from head_num
            if input_layout == "BSH":
                B, S1, H1 = query.shape
                _, S2, H2 = key.shape
            else:
                S1, B, H1 = query.shape
                S2, _, H2 = key.shape
            D = H1 // N1
            N2 = H2 // D
        else:
            # explicit head dims; hidden size derived
            if input_layout == "BSND":
                B, S1, N1, D = query.shape
                _, S2, N2, _ = key.shape
            else:  # "BNSD"
                B, N1, S1, D = query.shape
                _, N2, S2, _ = key.shape
            H1 = N1 * D
            H2 = N2 * D
    except Exception as e:
        raise ValueError(f"query.shape: {query.shape}, key.shape: {key.shape}, parse_bsnd_args error: {e}") from e

    if D == 0:
        raise ValueError(f"Value D must be non-zero.")
    DTYPE = query.dtype
    return B, S1, S2, N1, N2, D, H1, H2, DTYPE
+
+
def convert_from_bnsd(_input, input_layout):
    """Convert a (B, N, S, D) tensor back to the caller's input_layout."""
    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    patterns = {
        "BSH": 'b n s d -> b s (n d)',   # (B,N,S,D) => (B,S,N*D)
        "SBH": 'b n s d -> s b (n d)',   # (B,N,S,D) => (S,B,N*D)
        "BSND": 'b n s d -> b s n d',    # (B,N,S,D) => (B,S,N,D)
    }
    if input_layout in patterns:
        return rearrange(_input, patterns[input_layout]).contiguous()
    # "BNSD" (and anything else) passes through unchanged
    return _input
+
+
def convert_to_bnsd(_input, n, input_layout):
    """Convert _input from input_layout to (B, N, S, D) and cast to the golden dtype."""
    if input_layout == "TND":
        raise ValueError(f"input_layout {input_layout} does not supported for now.")
    patterns = {
        "BSH": 'b s (n d) -> b n s d',   # (B,S,N*D) => (B,N,S,D)
        "SBH": 's b (n d) -> b n s d',   # (S,B,N*D) => (B,N,S,D)
        "BSND": 'b s n d -> b n s d',    # (B,S,N,D) => (B,N,S,D)
    }
    # "BNSD" needs no rearrangement
    out = rearrange(_input, patterns[input_layout], n=n) if input_layout in patterns else _input
    if out.dim() != 4:
        raise ValueError(f"convert qkv format failed with input_layout {input_layout}.")
    # gtype is the module-level golden compute dtype
    return out.to(gtype)
+
+
def generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tocken, next_tocken, dtype):
    """
    Rebuild the dense attention mask used by the reference implementation.

    For sparse_mode 2/3/4 the fused operator receives a fixed 2048x2048
    upper-triangular template; reversing that optimization means unpacking it
    back to the basic [S1, S2] mask:
    ===> atten_mask = torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(dtype)
    """
    shape = [S1, S2]

    if atten_mask is not None:
        # When FA already received an atten_mask it is the converted template;
        # the three sparse modes below must be reversed to the dense mask.
        if sparse_mode == 2 or sparse_mode == 3 or sparse_mode == 4:
            logger.info(f"S1: {S1}, S2:{S2}, atten_mask.shape:{atten_mask.shape}, atten_mask.dtype:{atten_mask.dtype}")

            # only the canonical 2048x2048 upper-triangular template is reversed
            if atten_mask.dim() == 2 and atten_mask.shape[0] == 2048 and atten_mask.shape[1] == 2048:
                if atten_mask.equal(torch.from_numpy(np.triu(np.ones([2048, 2048]), k=1)).to(atten_mask.dtype)):
                    if sparse_mode == 2:
                        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
                    elif sparse_mode == 3:
                        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
                    elif sparse_mode == 4:
                        atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
                        atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
                        atten_mask = atten_mask_u + atten_mask_l
                    logger.debug(f"反向转换atten_mask {atten_mask.shape}")
                    return atten_mask.to(dtype)

        return atten_mask.to(dtype)

    # NOTE(review): this branch is unreachable — every atten_mask-not-None path
    # above already returned, so `shape` is never adjusted here; confirm intent.
    if atten_mask is not None:
        if atten_mask.dim() == 2:
            if atten_mask.shape[0] != S1 or atten_mask.shape[1] != S2:
                raise ValueError(f"Invalid atten_mask shape `SS` {atten_mask.shape}")
            shape = [S1, S2]
        elif atten_mask.dim() == 4:
            if atten_mask.shape[1] == 1:
                shape = [B, 1, S1, S2] if B != 1 else [1, 1, S1, S2]
            else:
                shape = [B, N1, S1, S2] if B != 1 else [1, N1, S1, S2]

    # No mask supplied: synthesize one from the sparse mode / token windows.
    if sparse_mode == 0:
        atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
        atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
        atten_mask = atten_mask_u + atten_mask_l
    elif sparse_mode == 1:  # no sparse
        atten_mask = torch.from_numpy(np.zeros(shape))
    elif sparse_mode == 2:
        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=1))
    elif sparse_mode == 3:
        atten_mask = torch.from_numpy(np.triu(np.ones(shape), k=S2 - S1 + 1))
    elif sparse_mode == 4:
        atten_mask_u = torch.from_numpy(np.triu(np.ones(shape), k=next_tocken + 1))
        atten_mask_l = torch.from_numpy(np.tril(np.ones(shape), k=-pre_tocken - 1))
        atten_mask = atten_mask_u + atten_mask_l
    # Note: sparse_mode=5 cannot occur here — it requires an explicit atten_mask
    # in BNSS/B1SS format, which is treated as already correct above.
    return atten_mask.to(dtype)
+
+
def generate_kv(key, value, N1, N2):
    """Broadcast k/v from N2 kv-heads to N1 query heads when they differ (GQA)."""
    if N1 == N2:
        return key, value
    return broadcast_kv(N1, N2, key, key.dtype), broadcast_kv(N1, N2, value, value.dtype)
+
+
def rebuid_softmax_by_qkv(q, k, atten_mask, pse, scale):
    """Recompute the softmax result directly from q/k.

    attention = softmax(QK^T/sqrt(d))V
    softmax(x_i) = e^(x_i - x_max) / sum(e^(x_i - x_max))

    NOTE(review): function name has a typo ("rebuid") but is kept — it is the
    public name callers use.
    """
    logger.info("Using QKV to rebuild original softmax")
    scores = calculate_qk(q, k, atten_mask, pse, scale)
    return softmax_forward(scores)[0]
+
+
def rebuild_softmax_by_max_sum(q, k, atten_mask, pse, scale, softmax_max, softmax_sum):
    """Rebuild the softmax result from the saved per-row max/sum statistics.

    attention = softmax(QK^T/sqrt(d))V
    softmax(x_i) = e^(x_i - x_max_i) / x_sum_i)
    """
    logger.info("Using softmax_max and softmax_sum to rebuild original softmax")
    scores = calculate_qk(q, k, atten_mask, pse, scale)
    if softmax_max.shape[-1] == 0:
        raise ValueError(f"softmax_max.shape[-1] must be non-zero, softmax_max.shape: {softmax_max.shape}")
    # statistics are stored tiled; repeat them up to the score width
    repeats = scores.shape[-1] // softmax_max.shape[-1]
    shifted = scores.sub(softmax_max.repeat(1, 1, 1, repeats))
    return torch.exp(shifted).div(softmax_sum.repeat(1, 1, 1, repeats))
+
+
def npu_fusion_attention_forward_patch(*args, **kwargs):
    """Validate the fused-attention forward call and normalize its arguments.

    Returns (args, dims_kwargs, new_kwargs): the original positional args, the
    parsed BNSD dimensions, and the keyword arguments with defaults applied.
    """
    # query, key, value, head_num, input_layout
    if len(args) != 5:
        raise ValueError(f"Unsupported npu_fusion_attention args {args}.")

    B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[3], args[4])
    if N1 == N2 and S1 == S2:
        logger.debug(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
    else:
        logger.debug(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
    # grouped kv heads require N1 to be a multiple of N2
    if not (N1 % N2 == 0 and N1 >= N2):
        raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.")

    dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
                   "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}

    # NOTE(review): keep_prob is forced to 1 regardless of the caller's value —
    # presumably dropout cannot be reproduced on the golden path; confirm.
    new_kwargs = {"keep_prob": 1,
                  "scale": kwargs.get("scale", 1 / (D ** 0.5)),
                  "sparse_mode": kwargs.get("sparse_mode", 0),
                  "prefix": kwargs.get("prefix"),
                  "pre_tockens": kwargs.get("pre_tockens", 2147483647),
                  "next_tockens": kwargs.get("next_tockens", 2147483647),
                  "pse": kwargs.get("pse"),
                  "padding_mask": kwargs.get("padding_mask"),
                  "atten_mask": kwargs.get("atten_mask")}

    return args, dims_kwargs, new_kwargs
+
+
def npu_fusion_attention_backward_patch(*args, **kwargs):
    """Validate the fused-attention backward call and normalize its arguments.

    Positional args are (query, key, value, dy, head_num, input_layout).
    Returns (args, dims_kwargs, new_kwargs) analogous to the forward patch.
    """
    if len(args) != 6:
        raise ValueError(f"Unsupported npu_fusion_attention_grad args {args}.")

    B, S1, S2, N1, N2, D, H1, H2, DTYPE = parse_bsnd_args(args[0], args[1], args[4], args[5])
    if N1 == N2 and S1 == S2:
        logger.info(f"running case : BNSD = {B}_{N1}_{S1}_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
    else:
        logger.info(f"running case: BNSD = {B}_{N1}({N2})_{S1}({S2})_{D}, sparse = {kwargs.get('sparse_mode', 0)}")
    # grouped kv heads require N1 to be a multiple of N2
    if not (N1 % N2 == 0 and N1 >= N2):
        raise ValueError(f"N1与N2不匹配,请检查: N1 = {N1}, N2 = {N2}.")

    dims_kwargs = {"B": B, "S1": S1, "S2": S2, "N1": N1, "N2": N2,
                   "D": D, "H1": H1, "H2": H2, "DTYPE": DTYPE}

    # NOTE(review): keep_prob forced to 1, mirroring the forward patch — confirm.
    new_kwargs = {"keep_prob": 1,
                  "scale_value": kwargs.get("scale_value", 1 / (D ** 0.5)),
                  "sparse_mode": kwargs.get("sparse_mode", 0),
                  "prefix": kwargs.get("prefix"),
                  "pre_tockens": kwargs.get("pre_tockens", 2147483647),
                  "next_tockens": kwargs.get("next_tockens", 2147483647),
                  "pse": kwargs.get("pse"),
                  "padding_mask": kwargs.get("padding_mask"),
                  "softmax_max": kwargs.get("softmax_max"),
                  "softmax_sum": kwargs.get("softmax_sum"),
                  "softmax_in": kwargs.get("softmax_in"),
                  "attention_in": kwargs.get("attention_in"),
                  "seed": kwargs.get("seed", 0),
                  "offset": kwargs.get("offset", 0),
                  "numels": kwargs.get("numels", 0),
                  "atten_mask": kwargs.get("atten_mask")}

    return args, dims_kwargs, new_kwargs
+
+
def npu_fusion_attention(*args, **kwargs):
    """CPU golden implementation of the fused attention forward.

    Returns (attention_out, softmax_max, softmax_sum) on the CPU; the max/sum
    statistics are tiled x8 on the last dim (see the repeat at the return).
    """
    new_args, dims_kwargs, new_kwargs = npu_fusion_attention_forward_patch(*args, **kwargs)
    query, key, value, input_layout = new_args[0], new_args[1], new_args[2], new_args[4]
    N1 = dims_kwargs.get("N1")
    N2 = dims_kwargs.get("N2")
    S1 = dims_kwargs.get("S1")
    S2 = dims_kwargs.get("S2")
    B = dims_kwargs.get("B")
    DTYPE = dims_kwargs.get("DTYPE")
    atten_mask = new_kwargs.get("atten_mask")
    keep_prob = new_kwargs.get("keep_prob")
    sparse_mode = new_kwargs.get("sparse_mode")
    pre_tockens = new_kwargs.get("pre_tockens")
    next_tockens = new_kwargs.get("next_tockens")
    pse = new_kwargs.get("pse")
    scale = new_kwargs.get("scale")

    # Rebuild the dense mask and normalize q/k/v to the BNSD layout.
    atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
    query = convert_to_bnsd(query, N1, input_layout)
    key = convert_to_bnsd(key, N2, input_layout)
    value = convert_to_bnsd(value, N2, input_layout)
    k_new, v_new = generate_kv(key, value, N1, N2)
    out_golden, softmax_max, softmax_sum = fusion_attention_forward(q=query, k=k_new, v=v_new,
                                                                    drop_mask=None, atten_mask=atten_mask,
                                                                    pse=pse, scale=scale,
                                                                    keep_prob=keep_prob)
    # Collapse a possible extra head-group dim back to 4-D before layout conversion.
    if out_golden.dim() == 5:
        out_golden = out_golden.reshape(out_golden.size(0), out_golden.size(1) * out_golden.size(2), out_golden.size(3),
                                        out_golden.size(4))
    out_golden = convert_from_bnsd(out_golden, input_layout)

    return out_golden.cpu(), softmax_max.repeat(1, 1, 1, 8).cpu(), softmax_sum.repeat(1, 1, 1, 8).cpu()
+
+
def npu_fusion_attention_grad(*args, **kwargs):
    """CPU golden implementation of the fused attention backward.

    Returns (dq, dk, dv) on the CPU, converted back to the caller's layout.
    """
    # dx, q, k, v, softmax_res, drop_mask, pse, scale, keep_prob
    new_args, dims_kwargs, new_kwargs = npu_fusion_attention_backward_patch(*args, **kwargs)
    query, key, value, dx, input_layout = new_args[0], new_args[1], new_args[2], new_args[3], new_args[5]
    N1 = dims_kwargs.get("N1")
    N2 = dims_kwargs.get("N2")
    S1 = dims_kwargs.get("S1")
    S2 = dims_kwargs.get("S2")
    B = dims_kwargs.get("B")
    D = dims_kwargs.get("D")
    DTYPE = dims_kwargs.get("DTYPE")
    atten_mask = new_kwargs.get("atten_mask")
    keep_prob = new_kwargs.get("keep_prob")
    sparse_mode = new_kwargs.get("sparse_mode")
    pre_tockens = new_kwargs.get("pre_tockens")
    next_tockens = new_kwargs.get("next_tockens")
    pse = new_kwargs.get("pse")
    softmax_max = new_kwargs.get("softmax_max")
    softmax_sum = new_kwargs.get("softmax_sum")
    scale_value = new_kwargs.get("scale_value")

    # Rebuild the dense mask and normalize all tensors to the BNSD layout.
    atten_mask = generate_atten_mask(sparse_mode, atten_mask, B, N1, S1, S2, pre_tockens, next_tockens, DTYPE)
    query = convert_to_bnsd(query, N1, input_layout)
    dx = convert_to_bnsd(dx, N1, input_layout)
    key = convert_to_bnsd(key, N2, input_layout)
    value = convert_to_bnsd(value, N2, input_layout)
    k_new, v_new = generate_kv(key, value, N1, N2)

    # The softmax result is not an operator output; rebuild it either from q/k
    # directly or from the saved max/sum statistics (module-level switch).
    if softmax_build_mode == "QKV":
        softmax_res = rebuid_softmax_by_qkv(query, k_new, atten_mask, pse, scale_value)
    else:
        softmax_res = rebuild_softmax_by_max_sum(query, k_new, atten_mask, pse, scale_value, softmax_max, softmax_sum)

    dq, dk, dv = fusion_attention_backward(dx, query, k_new, v_new, softmax_res, None, pse, scale_value, keep_prob)

    # Adapt unequal head counts (grouped kv heads): sum the group dim back to N2.
    if not (N1 == N2):
        if N2 == 0:
            raise ValueError("dims_kwargs.N2 must be non-zero.")
        G = int(N1 / N2)
        dk = torch.sum(dk.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)
        dv = torch.sum(dv.reshape(B, N2, G, S2, D), dim=2, keepdim=True).reshape(B, N2, S2, D)

    # Collapse any residual 5-D shapes to 4-D before converting the layout back.
    if dq.dim() == 5:
        dq = dq.reshape(dq.size(0), dq.size(1) * dq.size(2), dq.size(3), dq.size(4))
    if dk.dim() == 5:
        dk = dk.reshape(dk.size(0), dk.size(1) * dk.size(2), dk.size(3), dk.size(4))
    if dv.dim() == 5:
        dv = dv.reshape(dv.size(0), dv.size(1) * dv.size(2), dv.size(3), dv.size(4))

    dq = convert_from_bnsd(dq, input_layout)
    dk = convert_from_bnsd(dk, input_layout)
    dv = convert_from_bnsd(dv, input_layout)

    return dq.cpu(), dk.cpu(), dv.cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py
new file mode 100644
index 00000000000..e647312fdb2
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rms_norm.py
@@ -0,0 +1,15 @@
+import torch
+
+
+def npu_rms_norm(x, gamma, epsilon=1e-5):  # CPU reference for the NPU RMSNorm forward op
+    rstd = torch.rsqrt(torch.mean(torch.pow(x, 2), axis=-1, keepdim=True) + epsilon)  # reciprocal RMS over the last dim
+    res = x * rstd * gamma  # normalize, then scale by gamma
+    return res.cpu(), rstd.float().cpu()  # return result and fp32 rstd, both moved to CPU
+
+
+def npu_rms_norm_backward(grad, x, gamma, rstd):  # CPU reference for the NPU RMSNorm backward op
+    mean_gy = (grad * x * gamma * rstd).mean(dim=-1, keepdim=True)  # mean of grad * normalized output over last dim
+    grad_x = (grad * gamma - x * rstd * mean_gy) * rstd  # gradient w.r.t. input x
+    grad_gamma = x * grad * rstd  # per-element gamma grad; NOTE(review): no reduction over batch dims -- confirm the caller reduces
+    return grad_x.cpu(), grad_gamma.cpu()
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py
new file mode 100644
index 00000000000..0e0fda5f73f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/rotary_mul.py
@@ -0,0 +1,52 @@
+import torch
+
+
+def npu_rotary_mul(x, r1, r2):  # CPU reference for the NPU rotary position embedding multiply
+    x1, x2 = torch.chunk(x, 2, -1)  # split last dim into two halves
+    x_new = torch.cat((-x2, x1), dim=-1)  # rotate-half: (-x2, x1)
+    output = r1 * x + r2 * x_new  # presumably r1 ~ cos, r2 ~ sin -- confirm against the NPU op's docs
+    return output.cpu()
+
+
+def npu_rotary_mul_backward(dy_tensor, x, r1, r2):  # CPU reference for rotary-mul backward
+    x.requires_grad = True  # NOTE(review): mutates the caller's tensors in place
+    r1.requires_grad = True
+    r2.requires_grad = True
+    # golden path: recompute the forward and let autograd produce x.grad
+    x1, x2 = torch.chunk(x, 2, -1)
+    x_new = torch.cat((-x2, x1), dim=-1)  # rotate-half, same as the forward
+    golden_tensor = r1 * x + r2 * x_new
+    golden_tensor.backward(dy_tensor)
+    r1_shape = r1.shape
+    r1_grad = torch.zeros(r1_shape).type(torch.float32)  # r1/r2 grads are accumulated manually below
+    r2_grad = torch.zeros(r1_shape).type(torch.float32)
+    x1, x2 = torch.chunk(x.float(), 2, -1)
+    x_new2 = torch.cat((-x2, x1), dim=-1)
+    x_shape = x.shape
+    h = x.float()
+    grad = dy_tensor.float()
+    condition_1 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and  # r1 broadcast on dims 0 and 2
+                   ((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and
+                   (r1_shape[1] == x_shape[1]) and (r1_shape[3] == x_shape[3]))
+    condition_2 = (((r1_shape[0] == 1 and x_shape[0] != 1) or (r1_shape[0] == 1 and x_shape[0] == 1)) and  # r1 broadcast on dims 0 and 1
+                   ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
+                   (r1_shape[2] == x_shape[2]) and (r1_shape[3] == x_shape[3]))
+    condition_3 = (((r1_shape[2] == 1 and x_shape[2] != 1) or (r1_shape[2] == 1 and x_shape[2] == 1)) and  # r1 broadcast on dims 1 and 2
+                   ((r1_shape[1] == 1 and x_shape[1] != 1) or (r1_shape[1] == 1 and x_shape[1] == 1)) and
+                   (r1_shape[0] == x_shape[0]) and (r1_shape[3] == x_shape[3]))
+    if condition_1:  # reduce dL/dr over the broadcast dims: dL/dr1 = sum h*g, dL/dr2 = sum x_new*g
+        for i in range(x_shape[0]):
+            for j in range(x_shape[2]):
+                r2_grad[0, :, 0, :] += (x_new2[i, :, j, :] * grad[i, :, j, :])
+                r1_grad[0, :, 0, :] += (h[i, :, j, :] * grad[i, :, j, :])
+    elif condition_2:
+        for i in range(x_shape[0]):
+            for j in range(x_shape[1]):
+                r2_grad[0, 0, :, :] += (x_new2[i, j, :, :] * grad[i, j, :, :])
+                r1_grad[0, 0, :, :] += (h[i, j, :, :] * grad[i, j, :, :])
+    elif condition_3:
+        for i in range(x_shape[1]):
+            for j in range(x_shape[2]):
+                r2_grad[:, 0, 0, :] += (x_new2[:, i, j, :] * grad[:, i, j, :])
+                r1_grad[:, 0, 0, :] += (h[:, i, j, :] * grad[:, i, j, :])
+    return x.grad.cpu(), r1_grad.cpu(), r2_grad.cpu()  # NOTE(review): r1/r2 grads stay zero if no condition matches -- confirm intended
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py
new file mode 100644
index 00000000000..8717aebaf90
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/scaled_mask_softmax.py
@@ -0,0 +1,26 @@
+import torch
+
+
+def npu_scaled_masked_softmax(x, mask, scale, fixed_triu_mask):  # CPU reference for NPU scaled-masked-softmax forward
+    if fixed_triu_mask:
+        mask = (torch.triu(torch.ones(mask.shape), diagonal=1)).bool().to(mask.device)  # fix: torch.triu takes `diagonal`, not numpy's `k`
+    dtype = x.dtype
+    x = (x * scale).masked_fill(mask, value=-10000)  # scale, then push masked positions to a large negative
+    x = x - torch.max(x, dim=-1, keepdims=True)[0]  # subtract row max for numerical stability
+    x = torch.exp(x.float())
+    y = torch.div(x, torch.sum(x, dim=-1, keepdims=True))  # softmax over the last dim
+    return y.to(dtype).cpu()
+
+
+def npu_scaled_masked_softmax_backward(y_grad, y, mask, scale, fixed_triu_mask):  # CPU reference for NPU scaled-masked-softmax backward
+    if fixed_triu_mask:
+        mask = (torch.triu(torch.ones(mask.shape), diagonal=1)).bool().to(mask.device)  # fix: torch.triu takes `diagonal`, not numpy's `k`
+    dtype = y_grad.dtype
+    y_grad = y_grad.float()
+    y = y.float()
+    x_grad = y_grad * y
+    x_grad = y_grad - torch.sum(x_grad, dim=-1, keepdims=True)  # softmax jacobian: g_i - sum_j g_j*y_j, built stepwise
+    x_grad = x_grad * y
+    x_grad = x_grad * scale
+    x_grad = x_grad.masked_fill(mask, value=0)  # masked positions get zero gradient
+    return x_grad.to(dtype).cpu()
diff --git a/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py
new file mode 100644
index 00000000000..e03c975a50a
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/bench_functions/swiglu.py
@@ -0,0 +1,55 @@
+import torch
+
+
+def npu_swiglu(x, dim=-1):  # CPU reference for NPU SwiGLU forward: silu(a) * b with (a, b) = chunk(x, 2, dim)
+    tensor_dtype = x.dtype
+
+    inTensors = torch.chunk(x, 2, dim=dim)
+    if tensor_dtype == torch.float32:
+        tensor_scalar = torch.sigmoid(torch.mul(inTensors[0], 1.0))  # sigmoid(a)
+        output_data = torch.mul(torch.mul(tensor_scalar, inTensors[0]), inTensors[1])  # a * sigmoid(a) * b
+    else:
+        tensor_self_float = inTensors[0].type(torch.float)
+        tensor_other_float = inTensors[1].type(torch.float)
+        tensor_out_float = torch.nn.functional.silu(tensor_self_float).type(tensor_dtype).type(
+            torch.float32) * tensor_other_float  # round-trip through the low-precision dtype to mimic NPU rounding
+        output_data = tensor_out_float.type(tensor_dtype)
+    return output_data.cpu()
+
+
+def npu_swiglu_backward(grad, x, dim=-1):  # CPU reference for NPU SwiGLU backward; returns cat(dL/da, dL/db)
+    tensor_dtype = grad.dtype
+    in_tensors = torch.chunk(x, 2, dim=dim)  # (a, b) halves of the forward input
+    tensor_grad_out = grad
+
+    if tensor_dtype == torch.float16:
+        tensor_out1 = torch.mul(
+            torch.mul(in_tensors[1].type(torch.float32), swish_grad(1, in_tensors[0].type(torch.float32))),
+            tensor_grad_out.type(torch.float32)).type(torch.float16)  # dL/da = b * swish'(a) * g
+        tensor_out2 = torch.mul(tensor_grad_out.type(torch.float32),
+                                swish(1, in_tensors[0].type(torch.float32))).type(torch.float16)  # dL/db = swish(a) * g
+        output = torch.cat((tensor_out1, tensor_out2), dim)
+    elif tensor_dtype == torch.bfloat16:
+        tensor_self_float = in_tensors[0].type(torch.float)
+        tensor_other_float = in_tensors[1].type(torch.float)
+        tensor_gradout_float = tensor_grad_out.type(torch.float)
+
+        tensor_out1 = torch.mul(tensor_gradout_float, swish_grad(1.0, tensor_self_float)).type(torch.bfloat16).type(
+            torch.float32) * tensor_other_float  # dL/da, round-tripped through bf16 to mimic NPU rounding
+        tensor_out2 = swish(1.0, tensor_self_float).type(torch.bfloat16).type(torch.float32) * tensor_gradout_float  # dL/db
+        tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=dim)
+        output = tensor_out_float.type(torch.bfloat16)
+    else:
+        tensor_out1 = torch.mul(torch.mul(in_tensors[1], swish_grad(1.0, in_tensors[0])), tensor_grad_out)  # dL/da
+        tensor_out2 = torch.mul(tensor_grad_out, swish(1.0, in_tensors[0]))  # dL/db
+        output = torch.cat((tensor_out1, tensor_out2), dim)
+    return output.cpu()
+
+
+def swish_grad(beta, x):  # derivative of swish: sigmoid(bx) + x * b * sigmoid(bx) * (1 - sigmoid(bx))
+    return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta
+
+
+def swish(beta, x):  # swish activation: x * sigmoid(beta * x)
+    return x * torch.sigmoid(beta * x)
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
index 22f79879867..ccad903724c 100644
--- a/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py
@@ -1,5 +1,7 @@
import json
+
from msprobe.core.common.exceptions import ParseJsonException
+from msprobe.core.common.file_check import FileOpen
def parse_json_info_forward_backward(json_path):
@@ -11,7 +13,7 @@ def parse_json_info_forward_backward(json_path):
api_name = '.'.join(name_struct[:-1])
return api_name
- with open(json_path, 'r') as f:
+ with FileOpen(json_path, 'r') as f:
dump_json = json.load(f)
real_data_path = dump_json.get("dump_data_dir")
diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
index acc1de10514..181491488f9 100644
--- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
@@ -14,10 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
+import logging
import os
import random
import stat
import torch
+import torch.distributed as dist
import numpy as np
from functools import wraps
from msprobe.core.common.exceptions import DistributedNotInitializedError
@@ -221,3 +223,36 @@ class Const:
CONVERT_API = {
"int32_to_int64": ["cross_entropy"]
}
+
+
+def get_tensor_rank(in_feat, out_feat):  # infer the device rank owning the given input/output features
+    if dist.is_initialized():
+        return dist.get_rank()  # distributed rank takes precedence when initialized
+
+    def get_tensor_rank_single(x):  # first tensor's device index, or None (CPU tensor / no tensor found)
+        if isinstance(x, (list, tuple)):
+            if len(x) > 0:
+                return get_tensor_rank_single(x[0])  # only the first element is inspected
+        elif isinstance(x, torch.Tensor):
+            device = x.device
+            if device.type != 'cpu':
+                return device.index
+        return None
+
+    in_rank = get_tensor_rank_single(in_feat)
+    out_rank = get_tensor_rank_single(out_feat)
+    tensor_rank = in_rank if in_rank is not None else out_rank  # fix: device index 0 is falsy -- compare against None
+    return tensor_rank
+
+
+def _create_logger(level=logging.INFO):  # build the module-level logger used below
+    logger_ = logging.getLogger()  # root logger; NOTE(review): a second call would stack a duplicate handler
+    logger_.setLevel(level)
+    ch = logging.StreamHandler()  # stream handler (stderr by default)
+    ch.setLevel(level)
+    logger_.addHandler(ch)
+    return logger_
+
+
+log_level = logging.DEBUG if os.environ.get("API_ACCURACY_CHECK_LOG_LEVEL") == "1" else logging.INFO
+logger = _create_logger(log_level)
diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py
index e214910566e..2a68c756ed3 100644
--- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py
@@ -492,7 +492,7 @@ def compare_by_op(op_name, op_name_mapping_dict, input_parma):
error_file = error.filename
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
error_flag = True
- except FileCheckerException:
+ except FileCheckException:
error_file = data_name
n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE
error_flag = True
@@ -645,7 +645,11 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
elif (i - 2) in highlight_dict['yellow_rows']:
ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW,
end_color=CompareConst.YELLOW, fill_type="solid")
- wb.save(file_path)
+ try:
+ wb.save(file_path)
+ except Exception as e:
+ logger.error('Save result file failed')
+ raise CompareException(CompareException.WRITE_FILE_ERROR) from e
change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -655,8 +659,8 @@ def compare(input_parma, output_path, stack_mode=False, auto_analyze=True,
summary_compare, md5_compare = task_dumppath_get(input_parma)
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
create_directory(output_path)
- check_compare_param(input_parma, output_path, stack_mode, summary_compare, md5_compare)
- except CompareException as error:
+ check_compare_param(input_parma, output_path, summary_compare, md5_compare)
+ except (CompareException, FileCheckException) as error:
logger.error('Compare failed. Please check the arguments and do it again!')
sys.exit(error.code)
compare_core(input_parma, output_path, stack_mode=stack_mode,
@@ -764,9 +768,14 @@ def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
else:
full_op_name = op_name
else:
- full_op_name = op_name + '.' + str(index)
+ full_op_name = op_name + Const.SEP + str(index)
if isinstance(item, dict):
- if 'dtype' in item:
+ if 'type' not in item:
+ for kwarg in item:
+ kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
+ item_list += kwarg_parsed_list
+ kwarg_parsed_list.clear()
+ elif 'dtype' in item:
parsed_item = item
parsed_item['full_op_name'] = full_op_name
item_list.append(parsed_item)
@@ -869,13 +878,13 @@ def read_op(op_data, op_name):
op_parsed_list += output_parsed_list
output_parsed_list.clear()
if 'backward' in op_name:
- if 'grad_input' in op_data:
- input_item = op_data['grad_input']
+ if 'input' in op_data:
+ input_item = op_data['input']
input_parsed_list = op_item_parse(input_item, op_name + '_input', None)
op_parsed_list = input_parsed_list.copy()
input_parsed_list.clear()
- if 'grad_output' in op_data:
- output_item = op_data['grad_output']
+ if 'output' in op_data:
+ output_item = op_data['output']
output_parsed_list = op_item_parse(output_item, op_name + '_output', None)
op_parsed_list += output_parsed_list
output_parsed_list.clear()
diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py
index 0298eca9e7e..caac1395807 100644
--- a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py
+++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py
@@ -21,6 +21,7 @@ from msprobe.core.common.utils import CompareException, check_compare_param, \
check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid
from msprobe.pytorch.compare.acc_compare import compare_core
from msprobe.core.common.file_check import create_directory
+from msprobe.core.common.exceptions import FileCheckException
from msprobe.pytorch.common.log import logger
@@ -86,12 +87,11 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
'or use compare() api and manually match the ranks.')
raise CompareException(CompareException.INVALID_PATH_ERROR)
for nr, br in zip(npu_ranks, bench_ranks):
- n_dir = os.path.join(npu_dump_dir, nr)
- b_dir = os.path.join(bench_dump_dir, br)
- s_dir = b_dir
- npu_json_path = extract_json(n_dir, stack_json=False)
- bench_json_path = extract_json(b_dir, stack_json=False)
- stack_json_path = extract_json(s_dir, stack_json=True)
+ npu_data_dir = os.path.join(npu_dump_dir, nr)
+ bench_data_dir = os.path.join(bench_dump_dir, br)
+ npu_json_path = extract_json(npu_data_dir, stack_json=False)
+ bench_json_path = extract_json(bench_data_dir, stack_json=False)
+ stack_json_path = extract_json(npu_data_dir, stack_json=True)
dump_result_param = {
'npu_json_path': npu_json_path,
@@ -103,8 +103,8 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
summary_compare, md5_compare = task_dumppath_get(dump_result_param)
check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
create_directory(output_path)
- check_compare_param(dump_result_param, output_path, stack_mode=stack_mode, summary_compare=summary_compare)
- except CompareException as error:
+ check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
+ except (CompareException, FileCheckException) as error:
logger.error('Compare failed. Please check the arguments and do it again!')
sys.exit(error.code)
compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
index cfc588e1e97..f1289e9b013 100644
--- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py
@@ -21,7 +21,7 @@ class DebuggerConfig:
self.acl_config = common_config.acl_config if common_config.acl_config else ""
self.is_forward_acl_dump = True
self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS
- self.overflow_num = task_config.overflow_num if task_config.overflow_num else 1
+ self.overflow_nums = task_config.overflow_nums if task_config.overflow_nums else 1
self.framework = Const.PT_FRAMEWORK
if self.task == Const.FREE_BENCHMARK:
diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
index 137c51895d0..6119bbd1d4f 100644
--- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py
@@ -27,6 +27,7 @@ class PrecisionDebugger:
step=None,
):
if not hasattr(self, "initialized"):
+ self.api_origin = False
self.initialized = True
self.model = self.check_model_valid(model)
common_config, task_config = parse_json_config(config_path, task)
diff --git a/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md b/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md
index b3ed4a9e24c..41b97098ae9 100644
--- a/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md
@@ -21,7 +21,7 @@
精度预检操作流程如下:
1. 在NPU和GPU环境下分别安装msprobe工具。详见《[MindStudio精度调试工具](../../README.md)》的“工具安装”章节。
-2. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger采集待预检数据。详见《[精度数据采集](./dump.md)》。
+2. 在NPU训练脚本内添加msprobe工具dump接口PrecisionDebugger,采集待预检数据。详见《[精度数据采集](./dump.md)》,注意需要配置level="L1"。
3. 将NPU环境下dump的预检数据拷贝至GPU环境。
4. 在NPU和GPU环境下分别执行run_ut,生成结果用于最终api_precision_compare操作的输入。详见“**run_ut预检操作**”。
5. 将NPU和GPU执行run_ut生成的`accuracy_checking_details_{timestamp}.csv`结果文件拷贝至同一环境下。
@@ -51,10 +51,12 @@ run_ut预检操作包括如下场景:
| -api_info或--api_info_file | 指定API信息文件dump.json。 | 是 |
| -save_error_data | 保存精度未达标的API输入输出数据。 | 否 |
| -o或--out_path | 指定run_ut执行结果存盘路径,默认“./”(相对于run_ut的路径)。 | 否 |
+ | | | |
| -j或--jit_compile | 开启jit编译。 | 否 |
| -d或--device | 指定Device ID,选择UT代码运行所在的卡,默认值为0。 | 否 |
| -csv_path或--result_csv_path | 指定本次运行中断时生成的`accuracy_checking_result_{timestamp}.csv`文件路径,执行run_ut中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的`accuracy_checking_result_{timestamp}.csv`文件。详见“**断点续检**”。 | run_ut操作中断后继续执行场景下必选 |
| -f或--filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的API。适用于模型较大且重复API较多的场景。 | 否 |
+ | -config或--config_path | 指定预检操作过程中的额外配置(包括黑名单、白名单等)的[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件,默认未配置。config.json文件的配置可参考《[配置文件说明](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md#pytorch场景task配置为run_ut)》。 | 否 |
run_ut执行结果包括`accuracy_checking_result_{timestamp}.csv`和`accuracy_checking_details_{timestamp}.csv`两个文件。`accuracy_checking_result_{timestamp}.csv`是API粒度的,标明每个API是否通过测试。建议用户先查看`accuracy_checking_result_{timestamp}.csv`文件,对于其中没有通过测试的或者特定感兴趣的API,根据其API name字段在`accuracy_checking_details_{timestamp}.csv`中查询其各个输出的达标情况以及比较指标。详细介绍请参见“**预检结果**”。
@@ -64,7 +66,7 @@ run_ut预检操作包括如下场景:
msprobe -f pytorch run_ut -api_info ./dump.json -save_error_data
```
- 数据默认会存盘到'./ut_error_data{timestamp}'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过修改mstt/debug/accuracy_tools/api_accuracy_checker目录下,config.yaml文件的error_data_path参数来配置保存路径,详见“config.yaml文件说明”。
+ 数据默认会存盘到'./ut_error_data{timestamp}'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过error_data_path参数来配置保存路径,error_data_path参数在[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件或config.yaml文件配置,config.json文件需要在run_ut操作时通过-config参数指定,config.yaml文件详见“**config.yaml文件说明**”。
#### 使用multi_run_ut.py执行多线程预检
@@ -99,23 +101,65 @@ msprobe -f pytorch multi_run_ut -api_info ./dump.json -n 32 -d 0 1 2 3
msprobe -f pytorch run_ut -api_info ./dump.json -csv_path /home/xxx/ut/accuracy_checking_result_{timestamp}.csv
```
-#### API预检白名单
+#### API预检黑名单和白名单
-run_ut过程支持API预检白名单,操作方式如下:
+run_ut过程支持API预检黑名单和白名单,通过如下文件配置black_list(黑名单)或white_list(白名单)参数来指定不需要或需要预检的API名称:
-修改mstt/debug/accuracy_tools/api_accuracy_checker目录下config.yaml文件的white_list参数,配置需要预检的API名称,详见“config.yaml文件说明”。
+- 配置[config.json](https://gitee.com/ascend/mstt/tree/master/debug/accuracy_tools/msprobe/config)文件,config.json文件需要在run_ut操作时通过-config参数指定。
+- 配置config.yaml文件,详见“**config.yaml文件说明**”。
+
+config.json文件的优先级高于config.yaml文件,即执行config.json文件时,config.yaml文件的配置不生效。
### config.yaml文件说明
-config.yaml文件可以通过配置参数来控制dump和run_ut操作的白名单等功能。
+config.yaml文件可以通过配置参数来控制dump和run_ut操作的白名单、黑名单等功能。操作步骤如下:
+
+1. 查找msprobe工具安装路径。
+
+ ```bash
+ pip show mindstudio-probe
+ ```
+
+ 输出结果如下示例:
+
+ ```bash
+ Name: mindstudio-probe
+ Version: 1.0
+ Summary: This is a pytorch precision comparison tools
+ Home-page:
+ Author:
+ Author-email:
+ License:
+ Location: /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages
+ Requires: numpy, openpyxl, pandas, pyyaml, rich, tqdm, wheel
+ Required-by:
+ ```
+
+ Location字段为msprobe工具的安装路径,那么config.yaml文件位置为/home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml
+
+2. 进入config.yaml文件
+
+ ```bash
+ vi /home/xx/anaconda3/envs/pt21py38/lib/python3.8/site-packages/msprobe/pytorch/api_accuracy_checker/config.yaml
+ ```
+
+3. 修改config.yaml文件参数。
+
+ ```yaml
+ white_list: []
+ black_list: []
+ error_data_path: './'
+ precision: 14
+ ```
-文件路径为:mstt/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
+ | 参数名称 | 说明 | 是否必选 |
+ | --------------- | ------------------------------------------------------------ | -------- |
+ | white_list | API dump白名单,仅对指定的API进行dump。参数示例:white_list=["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 |
+ | black_list | API dump黑名单,被指定的API不进行dump。参数示例:black_list=["conv1d", "conv2d"]。默认未配置黑名单,即dump全量API数据。 | 否 |
+ | error_data_path | 配置保存精度未达标的API输入输出数据路径。参数示例"error_data_path": "./"。默认为当前路径。 | 否 |
+ | precision | 浮点数表示位数,默认取小数点后14位。 | 否 |
-| 参数名称 | 说明 | 是否必选 |
-| --------------- | ------------------------------------------------------------ | -------- |
-| white_list | API dump白名单,指定dump具体API数据,也可以直接配置预检的API白名单,详细请参见“**API预检白名单**”。参数示例:white_list=["conv1d", "conv2d"]。默认未配置白名单,即dump全量API数据。 | 否 |
-| error_data_path | 配置保存精度未达标的API输入输出数据路径。 | 否 |
-| precision | 浮点数表示位数,默认取小数点后14位。 | 否 |
+ 说明:white_list和black_list同时配置时,二者配置的API名单若无交集,则白名单生效,若API名单存在交集,则白名单排除的部分以及交集的API不进行dump。
## 预检结果
diff --git a/debug/accuracy_tools/msprobe/pytorch/doc/dump.md b/debug/accuracy_tools/msprobe/pytorch/doc/dump.md
index 7d0763b6848..7e393cd1026 100644
--- a/debug/accuracy_tools/msprobe/pytorch/doc/dump.md
+++ b/debug/accuracy_tools/msprobe/pytorch/doc/dump.md
@@ -12,7 +12,7 @@ msprobe工具主要通过在训练脚本内添加dump接口并启动训练的方
通过加载dump配置文件的方式来确定dump操作的详细配置。
-可以在from msprobe.pytorch import PrecisionDebugger和模型初始化之间的任意位置添加该接口。
+PrecisionDebugger接口可以在from msprobe.pytorch import PrecisionDebugger之后的位置添加。详细使用可参考“**示例代码**”或“**model配置代码示例**”。
**原型**
@@ -20,7 +20,7 @@ msprobe工具主要通过在训练脚本内添加dump接口并启动训练的方
PrecisionDebugger(config_path=None, task=None, dump_path=None, level=None, model=None, step=None)
```
-说明:上述参数除config_path和model外,其他参数均在[config.json](../../config)文件中可配,此处的参数优先级高于[config.json](../../config)文件中的配置,而config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config)文件。
+说明:上述参数除config_path和model外,其他参数均在[config.json](../../config)文件中可配,此处的参数优先级高于[config.json](../../config)文件中的配置,而config.json文件可以配置更多参数,若需要进行更多场景的精度数据dump,建议配置[config.json](../../config)文件。config.json文件的配置可参考《[配置文件说明](https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/config/README.md)》。
**参数说明**
@@ -77,9 +77,9 @@ if __name__ == "__main__"
**功能说明**
-启动函数。
+dump启动函数。
-在模型初始化之后的任意位置添加。
+在模型初始化之后的位置添加。需要与stop函数一起添加在for循环内。
**原型**
@@ -93,9 +93,9 @@ debugger.start()
**功能说明**
-停止函数。
+dump停止函数。
-在**start**函数之后的任意位置添加。
+在**start**函数之后的任意位置添加。若需要dump反向数据,则需要添加在反向计算代码(如loss.backward)之后。
**原型**
@@ -105,13 +105,33 @@ debugger.stop()
该函数为类函数,可以使用debugger.stop()也可以使用PrecisionDebugger.stop()。
+### forward_backward_dump_end函数
+
+**功能说明**
+
+dump停止函数。用于dump指定代码的前反向数据。
+
+在**start**函数之后,反向计算代码(如loss.backward)之前的任意位置添加,可以dump **start**函数和该函数之间的前反向数据,可以通过调整**start**函数与该函数的位置,来指定需要dump的代码块。
+
+要求**stop**函数添加在反向计算代码(如loss.backward)之后,此时该函数与**stop**函数之间的代码不会被dump。
+
+使用示例参见“**示例代码 > 扩展示例**”。
+
+**原型**
+
+```Python
+forward_backward_dump_end()
+```
+
+该函数为类函数,可以使用debugger.forward_backward_dump_end()也可以使用PrecisionDebugger.forward_backward_dump_end()。
+
### step函数
**功能说明**
结束标识。
-在最后一个**stop**函数后或一个step结束的位置添加。
+在最后一个**stop**函数后或一个step结束的位置添加。需要与start函数一起添加在for循环内。
**原型**
@@ -123,24 +143,57 @@ debugger.step()
## 示例代码
+### 基础操作
+
+如下示例可dump完整代码的前反向数据。
+
```Python
from msprobe.pytorch import PrecisionDebugger
+
+# 请勿将PrecisionDebugger的初始化流程插入到循环代码中
debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path")
-# 请勿将以上初始化流程插入到循环代码中
-# 模型初始化
-# 下面代码也可以用PrecisionDebugger.start()和PrecisionDebugger.stop()
-debugger.start()
+# 模型、损失函数的定义及初始化等操作
+# ...
-# 需要dump的代码片段1
+# 数据集迭代的位置一般为模型训练开始的位置
+for data, label in data_loader:
+ debugger.start() # 开启数据dump
-debugger.stop()
-debugger.start()
+ # 如下是模型每个step执行的逻辑
+ output = model(data)
+ #...
+ loss.backward()
+
+ debugger.stop() # 关闭数据dump
+ debugger.step() # 结束一个step的dump
+```
-# 需要dump的代码片段2
+### 扩展示例
-debugger.stop()
-debugger.step()
+如下示例dump指定代码块前反向数据。
+
+```Python
+from msprobe.pytorch import PrecisionDebugger
+
+# 请勿将PrecisionDebugger的初始化流程插入到循环代码中
+debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path")
+
+# 模型、损失函数的定义及初始化等操作
+# ...
+
+# 数据集迭代的位置一般为模型训练开始的位置
+for data, label in data_loader:
+ debugger.start() # 开启数据dump
+
+ # 如下是模型每个step执行的逻辑
+ output = model(data)
+ debugger.forward_backward_dump_end() # 插入该函数到start函数之后,只dump start函数到该函数之间代码的前反向数据,本函数到stop函数之间的数据则不dump
+ #...
+ loss.backward()
+
+ debugger.stop() # 关闭数据dump
+ debugger.step() # 结束一个step的dump
```
## dump结果文件介绍
diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py
index e737e7b2179..c5e93be138d 100644
--- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/constant.py
@@ -52,6 +52,7 @@ class ThresholdConfig:
DTYPE_PER_THD = {
torch.float16: 1.002,
+ torch.bfloat16: 1.004,
torch.float32: 1.0002,
}
BENCHMARK_THD_DICT = {
@@ -60,6 +61,8 @@ class ThresholdConfig:
torch.bfloat16: BenchmarkThd(2**-8, 1.0, 2**-8, 1e-4),
}
+ TENSOR_SPLIT_MAX_CHUNK = 128
+
class PreheatConfig:
IF_PREHEAT = "if_preheat"
diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py
index ddcbd9d0f5c..631beeb85cb 100644
--- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/common/utils.py
@@ -96,3 +96,7 @@ class TorchC:
add = torch._C._VariableFunctionsClass.add
bitwise_xor = torch._C._VariableFunctionsClass.bitwise_xor
clone = torch._C._VariableFunctionsClass.clone
+ clamp = torch._C._VariableFunctionsClass.clamp
+ tensor_split = torch._C._VariableFunctionsClass.tensor_split
+ stack = torch._C._VariableFunctionsClass.stack
+ reshape = torch._C._VariableFunctionsClass.reshape
diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py
index 1728b096f5b..e36f5867355 100644
--- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py
+++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/base_handler.py
@@ -1,6 +1,7 @@
import math
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
+import numpy as np
import torch
from msprobe.core.common.const import Const
@@ -34,15 +35,36 @@ class FuzzHandler(ABC):
origin_ouput = origin_ouput.values
perturbed_output = perturbed_output.values
if hasattr(perturbed_output, "dtype"):
- abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype)
+ abs_tol = ThresholdConfig.ABS_TOL_VALUE_DICT.get(perturbed_output.dtype, FuzzThreshold.F32_THD)
else:
- abs_tol = FuzzThreshold.F32_THD.value
+ abs_tol = FuzzThreshold.F32_THD
return (
origin_ouput.to(perturbed_output.dtype).to(perturbed_output.device),
perturbed_output,
abs_tol,
)
+    @staticmethod
+    def tensor_split_for_error_calculate(origin_output, perturbed_output):
+        """
+        Split the pre- and post-perturbation output tensors into chunks for error calculation.
+        :param origin_output: original output
+        :param perturbed_output: output after perturbation
+        :return origin_output_chunks: list of chunks of the original output
+        :return perturbed_output_chunks: list of chunks of the perturbed output
+        """
+        single_output_mem = origin_output.element_size() * origin_output.nelement() / Const.ONE_MB  # tensor size in MB
+        if single_output_mem == 0 or origin_output.ndim == 0:
+            return [origin_output], [perturbed_output]
+        # relation between tensor size and chunk count: chunks_exp = log2(M) - 4, chunks = 2**chunks_exp (M = size in MB)
+        chunks_exp = int(math.log(single_output_mem, 2)) - 4
+        chunks = 2 ** chunks_exp
+        chunks = max(chunks, 1)
+        chunks = min(chunks, ThresholdConfig.TENSOR_SPLIT_MAX_CHUNK)  # clamp chunk count to [1, TENSOR_SPLIT_MAX_CHUNK]
+        origin_output_chunks = TorchC.tensor_split(TorchC.reshape(origin_output, (-1,)), chunks)
+        perturbed_output_chunks = TorchC.tensor_split(TorchC.reshape(perturbed_output, (-1,)), chunks)
+        return origin_output_chunks, perturbed_output_chunks
+
@staticmethod
def convert_overflow_ratio_to_consistent(ratio):
if math.isnan(ratio) or math.isinf(ratio):
@@ -61,36 +83,28 @@ class FuzzHandler(ABC):
self, origin_output, perturbed_output, norm_type, abs_tol
):
if norm_type == NormType.ENDLESS_NORM:
- return self.get_endless_norm(origin_output, perturbed_output, abs_tol)
+ return self.calculate_error(origin_output, perturbed_output, abs_tol)
return ThresholdConfig.COMP_CONSISTENT
- def get_endless_norm(self, origin_output, perturbed_output, abs_tol):
- ratio_tensor1 = TorchC.where(
- TorchC.gt(TorchC.abs(perturbed_output), abs_tol),
- TorchC.div(
- TorchC.abs(origin_output),
- TorchC.add(TorchC.abs(perturbed_output), abs_tol),
- ),
- 1,
- )
- ratio_tensor2 = TorchC.where(
- TorchC.gt(TorchC.abs(origin_output), abs_tol),
- TorchC.div(
- TorchC.abs(perturbed_output),
- TorchC.add(TorchC.abs(origin_output), abs_tol),
- ),
- 1,
- )
+ def calculate_error(self, origin_output, perturbed_output, abs_tol):
+ origin_output_chunks, perturbed_output_chunks = self.tensor_split_for_error_calculate(origin_output, perturbed_output)
+ norm1 = -np.inf
+ norm2 = -np.inf
+ norm3 = np.inf
+ for i, chunk_origin in enumerate(origin_output_chunks):
+ if chunk_origin.nelement() == 0:
+ break
+ chunk_perturbed = perturbed_output_chunks[i]
+ ratio_tensor1 = TorchC.where(TorchC.abs(chunk_perturbed) > abs_tol,
+ TorchC.div(TorchC.clamp(chunk_origin, min=abs_tol), TorchC.clamp(chunk_perturbed, min=abs_tol)), 1)
+ ratio_tensor2 = TorchC.where(TorchC.abs(chunk_origin) > abs_tol,
+ TorchC.div(TorchC.clamp(chunk_perturbed, min=abs_tol), TorchC.clamp(chunk_origin, min=abs_tol)), 1)
+ norm_values = TorchC.stack([TorchC.max(ratio_tensor1), TorchC.max(ratio_tensor2)])
+ max_ratio1, max_ratio2 = norm_values.tolist()
+ norm1 = max(norm1, self.convert_overflow_ratio_to_consistent(max_ratio1))
+ norm2 = max(norm2, self.convert_overflow_ratio_to_consistent(max_ratio2))
+ norm3 = min(norm3, self.convert_overflow_ratio_to_consistent(max_ratio1))
- norm1 = self.convert_overflow_ratio_to_consistent(
- TorchC.max(ratio_tensor1).item()
- )
- norm2 = self.convert_overflow_ratio_to_consistent(
- TorchC.max(ratio_tensor2).item()
- )
- norm3 = self.convert_overflow_ratio_to_consistent(
- TorchC.min(ratio_tensor1).item()
- )
if norm3 < 0:
ratio = ThresholdConfig.SYMBOL_FLIPPING
else:
diff --git a/debug/accuracy_tools/msprobe/pytorch/function_factory.py b/debug/accuracy_tools/msprobe/pytorch/function_factory.py
new file mode 100644
index 00000000000..c2fd8bfd0cb
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/function_factory.py
@@ -0,0 +1,75 @@
+from msprobe.pytorch.common.utils import logger
+from msprobe.pytorch.bench_functions.apply_adam_w import npu_apply_adam_w
+from msprobe.pytorch.bench_functions.confusion_transpose import npu_confusion_transpose, \
+ npu_confusion_transpose_backward
+from msprobe.pytorch.bench_functions.fast_gelu import fast_gelu, npu_fast_gelu_backward
+from msprobe.pytorch.bench_functions.layer_norm_eval import npu_layer_norm_eval
+from msprobe.pytorch.bench_functions.linear import npu_linear, npu_linear_backward
+from msprobe.pytorch.bench_functions.matmul_backward import matmul_backward
+from msprobe.pytorch.bench_functions.npu_fusion_attention import npu_fusion_attention, npu_fusion_attention_grad
+from msprobe.pytorch.bench_functions.rms_norm import npu_rms_norm, npu_rms_norm_backward
+from msprobe.pytorch.bench_functions.rotary_mul import npu_rotary_mul, npu_rotary_mul_backward
+from msprobe.pytorch.bench_functions.scaled_mask_softmax import npu_scaled_masked_softmax, \
+ npu_scaled_masked_softmax_backward
+from msprobe.pytorch.bench_functions.swiglu import npu_swiglu, npu_swiglu_backward, swish_grad, swish
+
+
+class Register(dict):
+ def __init__(self, *args, **kwargs):
+ super(Register, self).__init__(*args, **kwargs)
+ self._dict = {}
+
+ def __call__(self, target_func_list):
+ for target in target_func_list:
+ self.register(target)
+ return
+
+ def __setitem__(self, key, value):
+ self._dict[key] = value
+
+ def __getitem__(self, key):
+ return self._dict[key]
+
+ def __contains__(self, key):
+ return key in self._dict
+
+ def __str__(self):
+ return str(self._dict)
+
+ def keys(self):
+ return self._dict.keys()
+
+ def values(self):
+ return self._dict.values()
+
+ def items(self):
+ return self._dict.items()
+
+ def register(self, target):
+
+ def add_register_item(key, value):
+ if key in self._dict:
+            logger.warning(f"{value.__name__} has been registered before, so it will be overridden.")
+ self[key] = value
+ return value
+
+ if callable(target):
+ return add_register_item(target.__name__, target)
+ else:
+ raise Exception(f"The func {target} is not callable.")
+
+
+# register for npu custom bench functions
+npu_custom_functions = Register()
+npu_custom_functions([
+ npu_apply_adam_w, npu_confusion_transpose, fast_gelu, npu_layer_norm_eval, npu_linear, npu_fusion_attention,
+ npu_rms_norm, npu_rotary_mul, npu_scaled_masked_softmax, npu_swiglu
+])
+
+# register for npu custom backward bench functions
+npu_custom_grad_functions = Register()
+npu_custom_grad_functions([
+ npu_confusion_transpose_backward, npu_fast_gelu_backward, npu_linear_backward, matmul_backward,
+ npu_fusion_attention_grad, npu_rms_norm_backward, npu_rotary_mul_backward, npu_scaled_masked_softmax_backward,
+ npu_swiglu_backward
+])
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
index 6693a09d028..ff6427e51e5 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
@@ -17,9 +17,11 @@
import functools
import threading
+
import torch
import torch.nn as nn
import torch.utils.hooks as full_hooks
+
from msprobe.core.common.const import Const
@@ -61,6 +63,10 @@ class HOOKModule(nn.Module):
HOOKModule.inner_stop_hook[self.current_thread] = False
return result
+ @classmethod
+ def reset_module_stats(cls):
+ cls.module_count = {}
+
def _call_func(self, *input, **kwargs):
full_backward_hooks, non_full_backward_hooks = [], []
if len(self._backward_hooks) > 0:
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
index d64c577ff38..f68708e945e 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
@@ -1873,4 +1873,5 @@ distributed:
- reduce_scatter
- _reduce_scatter_base
- _all_gather_base
- - all_to_all_single
\ No newline at end of file
+ - all_to_all_single
+ - all_to_all
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
index 4617e4854fc..a02abbe5f4b 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
@@ -24,12 +24,14 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule
from msprobe.pytorch.common.utils import torch_device_guard
from msprobe.core.common.const import Const
from msprobe.core.common.file_check import FileOpen
-
+from msprobe.pytorch.function_factory import npu_custom_grad_functions
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
with FileOpen(yaml_path, 'r') as f:
- WrapAtenOps = yaml.safe_load(f).get('aten')
+ Ops = yaml.safe_load(f)
+ WrapAtenOps = Ops.get('aten')
+ WhiteAtenOps = Ops.get('white_aten_ops', [])
aten_func = {}
@@ -48,7 +50,7 @@ class HOOKAtenOP(object):
class AtenOPTemplate(HOOKModule):
- def __init__(self, op, hook):
+ def __init__(self, op, hook, need_hook=True):
if isinstance(op, torch._ops.OpOverloadPacket):
op_name_ = op._qualified_op_name.split("::")[-1]
else:
@@ -58,10 +60,21 @@ class AtenOPTemplate(HOOKModule):
op_name_ = op_name_ + '.' + overload_name
self.op = op
self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP
- super().__init__(hook)
+ self.need_hook = need_hook
+ if self.need_hook:
+ super().__init__(hook)
@torch_device_guard
def forward(self, *args, **kwargs):
+ if isinstance(self.op, str):
+ if self.op in npu_custom_grad_functions:
+ return npu_custom_grad_functions[self.op](*args, **kwargs)
+ if self.op in WhiteAtenOps:
+ return eval(f"torch.ops.aten.{self.op}")(*args, **kwargs)
+ if self.op not in aten_func:
+                raise Exception(f"Skip op[{self.op}] accuracy check, because the op is neither "
+                                f"in dir(torch.ops.aten) nor in the support yaml.")
+ return aten_func[self.op](*args, **kwargs)
return self.op(*args, **kwargs)
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
index 992713bce57..8a67ed94290 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py
@@ -17,19 +17,26 @@
import os
import torch
-import torch_npu
import yaml
from msprobe.pytorch.hook_module.hook_module import HOOKModule
from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version
from msprobe.core.common.const import Const
from msprobe.core.common.file_check import FileOpen
+from msprobe.pytorch.function_factory import npu_custom_functions
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
with FileOpen(yaml_path, 'r') as f:
WrapNpuOps = yaml.safe_load(f).get('torch_npu')
+try:
+ import torch_npu
+except ImportError:
+ is_gpu = True
+else:
+ is_gpu = False
+
def get_npu_ops():
global WrapNpuOps
@@ -46,13 +53,19 @@ class HOOKNpuOP(object):
class NpuOPTemplate(HOOKModule):
- def __init__(self, op_name, hook):
+ def __init__(self, op_name, hook, need_hook=True):
self.op_name_ = op_name
self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
- super().__init__(hook)
+ self.need_hook = need_hook
+ if need_hook:
+ super().__init__(hook)
@torch_device_guard
def forward(self, *args, **kwargs):
+ if not self.need_hook:
+ if self.op_name_ not in npu_custom_functions:
+                raise Exception(f'There is no bench function {self.op_name_}')
+ return npu_custom_functions[self.op_name_](*args, **kwargs)
if torch_without_guard_version:
return getattr(torch.ops.npu, str(self.op_name_))(*args, **kwargs)
else:
@@ -60,7 +73,6 @@ class NpuOPTemplate(HOOKModule):
def wrap_npu_op(op_name, hook):
-
def npu_op_template(*args, **kwargs):
return NpuOPTemplate(op_name, hook)(*args, **kwargs)
diff --git a/debug/accuracy_tools/msprobe/pytorch/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/module_processer.py
index 422d36d6ac7..f9368a08745 100644
--- a/debug/accuracy_tools/msprobe/pytorch/module_processer.py
+++ b/debug/accuracy_tools/msprobe/pytorch/module_processer.py
@@ -1,15 +1,17 @@
from functools import wraps
+
import torch
from torch.utils.hooks import BackwardHook
+
from msprobe.core.common.const import Const
from msprobe.core.data_dump.scope import ModuleRangeScope
class ModuleProcesser:
+ module_count = {}
module_stack = []
api_parent_node = ""
module_node = {}
- current_module_name = ""
def __init__(self, scope):
if isinstance(scope, ModuleRangeScope):
@@ -19,7 +21,6 @@ class ModuleProcesser:
BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook)
- self.module_count = {}
@staticmethod
def filter_tensor_and_tuple(func):
@@ -55,11 +56,26 @@ class ModuleProcesser:
else:
return result
+ @staticmethod
+ def module_count_func(module_name):
+ if module_name not in ModuleProcesser.module_count:
+ ModuleProcesser.module_count[module_name] = 0
+ else:
+ ModuleProcesser.module_count[module_name] += 1
+ return ModuleProcesser.module_count[module_name]
+
+ @classmethod
+ def reset_module_stats(cls):
+ cls.module_count = {}
+ cls.module_stack = []
+ cls.api_parent_node = ""
+ cls.module_node = {}
+
def node_hook(self, name_prefix, start_or_stop, **kwargs):
def pre_hook(module, input, output=None):
try:
- index = self.module_count_func(name_prefix)
+ index = ModuleProcesser.module_count_func(name_prefix)
except IndexError as e:
index = None
pass
@@ -89,10 +105,3 @@ class ModuleProcesser:
return pre_hook
else:
return end_hook
-
- def module_count_func(self, module_name):
- if module_name not in self.module_count:
- self.module_count[module_name] = 0
- else:
- self.module_count[module_name] += 1
- return self.module_count[module_name]
diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py
index a3d765f3a4d..ceec92a633a 100644
--- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py
@@ -32,12 +32,12 @@ class StatisticsConfig(BaseConfig):
class OverflowCheckConfig(BaseConfig):
def __init__(self, json_config):
super().__init__(json_config)
- self.overflow_num = json_config.get("overflow_nums")
+ self.overflow_nums = json_config.get("overflow_nums")
self.check_mode = json_config.get("check_mode")
self.check_overflow_config()
def check_overflow_config(self):
- if self.overflow_num is not None and not isinstance(self.overflow_num, int):
+ if self.overflow_nums is not None and not isinstance(self.overflow_nums, int):
raise Exception("overflow_num is invalid")
if self.check_mode is not None and self.check_mode not in ["all", "aicore", "atomic"]:
raise Exception("check_mode is invalid")
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index daeda889879..6b8d67abc9f 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -2,17 +2,18 @@ import functools
import os
from pathlib import Path
-from msprobe.pytorch.common.log import logger
-from msprobe.core.common.file_check import FileChecker, check_path_before_create
from msprobe.core.common.const import Const, FileCheckConst
from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
+from msprobe.core.common.file_check import FileChecker, check_path_before_create
from msprobe.core.data_dump.data_collector import build_data_collector
-from msprobe.core.data_dump.scope import BaseScope
from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
+from msprobe.core.data_dump.scope import BaseScope
+from msprobe.pytorch.common.log import logger
from msprobe.pytorch.common.utils import get_rank_if_initialized
-from msprobe.pytorch.module_processer import ModuleProcesser
from msprobe.pytorch.hook_module import remove_dropout
from msprobe.pytorch.hook_module.api_registry import api_register
+from msprobe.pytorch.hook_module.hook_module import HOOKModule
+from msprobe.pytorch.module_processer import ModuleProcesser
class Service:
@@ -67,7 +68,8 @@ class Service:
if not self.switch:
return
if self.data_collector:
- module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
+            # NOTE: in this backward hook, grad_input is actually the output data of the backward
+            # pass and grad_output is its input data, so the two are swapped when passed on.
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output)
pid = os.getpid()
@@ -82,6 +84,9 @@ class Service:
self.current_iter += 1
self.data_collector.update_iter(self.current_iter)
+ ModuleProcesser.reset_module_stats()
+ HOOKModule.reset_module_stats()
+
def start(self, model, api_origin=False):
self.model = model
if self.config.step and self.current_iter > max(self.config.step):
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py b/debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py
index 06c7378ed36..8b2138a485b 100644
--- a/debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/test_common_config.py
@@ -121,7 +121,7 @@ class TestCommonConfig(TestCase):
self.assertIsNone(base_config.backward_input)
self.assertIsNone(base_config.file_format)
self.assertIsNone(base_config.summary_mode)
- self.assertIsNone(base_config.overflow_num)
+ self.assertIsNone(base_config.overflow_nums)
self.assertIsNone(base_config.check_mode)
json_config.update({"scope": "Tensor_Add"})
diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py
index 673386afb5d..30212d95e62 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py
@@ -19,7 +19,7 @@ from unittest.mock import patch, mock_open
from msprobe.core.common.const import Const
from msprobe.mindspore.ms_config import (parse_json_config, parse_task_config,
- TensorConfig, StatisticsConfig, OverflowCheck)
+ TensorConfig, StatisticsConfig, OverflowCheckConfig)
class TestMsConfig(TestCase):
@@ -62,7 +62,7 @@ class TestMsConfig(TestCase):
self.assertTrue(isinstance(task_config, StatisticsConfig))
task_config = parse_task_config("overflow_check", mock_json_config)
- self.assertTrue(isinstance(task_config, OverflowCheck))
+ self.assertTrue(isinstance(task_config, OverflowCheckConfig))
with self.assertRaises(Exception) as context:
parse_task_config("free_benchmark", mock_json_config)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
index 771e0423804..27126cdddda 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py
@@ -48,7 +48,7 @@ class TestMultiRunUT(unittest.TestCase):
device_id=[0, 1],
result_csv_path='result.csv',
total_items=2,
- real_data_path=None
+ config_path=None
)
mock_file.side_effect = [
@@ -81,7 +81,7 @@ class TestMultiRunUT(unittest.TestCase):
args.jit_compile = False
args.device_id = [0, 1]
args.result_csv_path = None
- args.real_data_path = None
+ args.config_path = None
config = prepare_config(args)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py
index c344f0b66b0..470390d77b2 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_config.py
@@ -45,7 +45,7 @@ class TestPtConfig(TestCase):
}
}
result = parse_task_config(Const.OVERFLOW_CHECK, overflow_check_config)
- self.assertEqual(result.overflow_num, 1)
+ self.assertEqual(result.overflow_nums, 1)
self.assertEqual(result.check_mode, "all")
free_benchmark_config = {
diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py
index 4e0eaa1f375..afbf8feb3a0 100644
--- a/debug/accuracy_tools/setup.py
+++ b/debug/accuracy_tools/setup.py
@@ -14,7 +14,7 @@
import setuptools
-__version__ = '1.0.0'
+__version__ = '1.0.1'
INSTALL_REQUIRED = [
"wheel",
diff --git a/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml b/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml
deleted file mode 100644
index 3133d6400fb..00000000000
--- a/plugins/tensorboard-plugins/.github/workflows/libkineto_ci.yml
+++ /dev/null
@@ -1,56 +0,0 @@
-name: LIBKINETOCI
-
-on:
- push:
- branches:
- - main
- pull_request:
- branches:
- - main
-
-jobs:
- build:
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- os: [ubuntu-latest]
-
- steps:
- - uses: actions/checkout@v2
- - name: Checkout submodules
- shell: bash
- run: |
- auth_header="$(git config --local --get http.https://github.com/.extraheader)"
- git submodule sync --recursive
- git -c "http.extraheader=$auth_header" -c protocol.version=2 submodule update --init --force --recursive --depth=1
-
- - name: Get env vars
- run: |
- echo GITHUB_WORKFLOW = $GITHUB_WORKFLOW
- echo HOME = $HOME
- echo GITHUB_ACTION = $GITHUB_ACTION
- echo GITHUB_ACTIONS = $GITHUB_ACTIONS
- echo GITHUB_REPOSITORY = $GITHUB_REPOSITORY
- echo GITHUB_EVENT_NAME = $GITHUB_EVENT_NAME
- echo GITHUB_EVENT_PATH = $GITHUB_EVENT_PATH
- echo GITHUB_WORKSPACE = $GITHUB_WORKSPACE
- echo GITHUB_SHA = $GITHUB_SHA
- echo GITHUB_REF = $GITHUB_REF
- c++ --verbose
-
- # TODO: Figure out how to install cupti headers T84637671
- - name: Build static lib
- run: |
- set -e
- mkdir build_static
- cd build_static
- cmake -DKINETO_LIBRARY_TYPE=static ../libkineto/
- make -j
-
- - name: Build shared lib
- run: |
- set -e
- mkdir build_shared
- cd build_shared
- cmake -DKINETO_LIBRARY_TYPE=shared ../libkineto/
- make -j
diff --git a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml b/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml
deleted file mode 100644
index 9bdafcc4426..00000000000
--- a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_build_pip_package.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-name: Build torch-tb-profiler Pip Package
-
-on:
- # TODO: Add an on_release trigger to build on tags
- workflow_dispatch:
-
-jobs:
- build-package:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v2
- - name: build pip package
- run: |
- set -e
- cd tb_plugin
- python setup.py sdist bdist_wheel
- cd dist/
- pip install *.whl
- python -c "import torch_tb_profiler;print(torch_tb_profiler.__version__)"
diff --git a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml b/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml
deleted file mode 100644
index 1b59a7bf90a..00000000000
--- a/plugins/tensorboard-plugins/.github/workflows/tb_plugin_ci.yml
+++ /dev/null
@@ -1,57 +0,0 @@
-name: TB_Plugin_CI
-
-on:
- push:
- branches:
- - main
- - release/**
- - plugin/**
-
- pull_request:
- branches:
- - main
- - release/**
- - plugin/**
-
-jobs:
- generate-matrix:
- runs-on: ubuntu-latest
- outputs:
- matrix: ${{ steps.set-matrix.outputs.matrix }}
- steps:
- - id: set-matrix
- run: |
- echo $GITHUB_BASE_REF
- if [ $GITHUB_BASE_REF == "plugin/vnext" ]
- then
- echo "::set-output name=matrix::{\"python-version\":[3.7, 3.8, 3.9], \"cuda-version\":[\"cpu\"], \"pytorch-version\":[\"nightly\"]}"
- else
- echo "::set-output name=matrix::{\"python-version\":[3.7, 3.8, 3.9], \"cuda-version\":[\"cpu\"], \"pytorch-version\":[\"nightly\", \"1.11rc\", \"stable\"]}"
- fi
-
- build:
- needs: generate-matrix
- runs-on: ubuntu-latest
- strategy:
- matrix: ${{fromJSON(needs.generate-matrix.outputs.matrix)}}
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
- architecture: 'x64'
- - name: Test
- env:
- CUDA_VERSION: ${{ matrix.cuda-version }}
- PYTORCH_VERSION: ${{ matrix.pytorch-version }}
- TORCH_PROFILER_LOG_LEVEL: DEBUG
- GRPC_VERBOSITY: DEBUG
- GRPC_ENABLE_FORK_SUPPORT: 'False'
- run: |
- set -e
- cd tb_plugin
- sh ./ci_scripts/install_env.sh
- pip install .[gs]
- cd test
- pytest
diff --git a/plugins/tensorboard-plugins/.gitignore b/plugins/tensorboard-plugins/.gitignore
deleted file mode 100644
index ce186381c0b..00000000000
--- a/plugins/tensorboard-plugins/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-# ignore common items
-.idea
-.vscode
diff --git a/plugins/tensorboard-plugins/.gitmodules b/plugins/tensorboard-plugins/.gitmodules
deleted file mode 100644
index 4660ee8bc9e..00000000000
--- a/plugins/tensorboard-plugins/.gitmodules
+++ /dev/null
@@ -1,6 +0,0 @@
-[submodule "libkineto/third_party/googletest"]
- path = libkineto/third_party/googletest
- url = https://github.com/google/googletest.git
-[submodule "libkineto/third_party/fmt"]
- path = libkineto/third_party/fmt
- url = https://github.com/fmtlib/fmt.git
diff --git a/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md b/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md
deleted file mode 100644
index a0cbeaab765..00000000000
--- a/plugins/tensorboard-plugins/CODE_OF_CONDUCT.md
+++ /dev/null
@@ -1,77 +0,0 @@
-# Code of Conduct
-
-## Our Pledge
-
-In the interest of fostering an open and welcoming environment, we as
-contributors and maintainers pledge to make participation in our project and
-our community a harassment-free experience for everyone, regardless of age, body
-size, disability, ethnicity, sex characteristics, gender identity and expression,
-level of experience, education, socio-economic status, nationality, personal
-appearance, race, religion, or sexual identity and orientation.
-
-## Our Standards
-
-Examples of behavior that contributes to creating a positive environment
-include:
-
-* Using welcoming and inclusive language
-* Being respectful of differing viewpoints and experiences
-* Gracefully accepting constructive criticism
-* Focusing on what is best for the community
-* Showing empathy towards other community members
-
-Examples of unacceptable behavior by participants include:
-
-* The use of sexualized language or imagery and unwelcome sexual attention or
- advances
-* Trolling, insulting/derogatory comments, and personal or political attacks
-* Public or private harassment
-* Publishing others' private information, such as a physical or electronic
- address, without explicit permission
-* Other conduct which could reasonably be considered inappropriate in a
- professional setting
-
-## Our Responsibilities
-
-Project maintainers are responsible for clarifying the standards of acceptable
-behavior and are expected to take appropriate and fair corrective action in
-response to any instances of unacceptable behavior.
-
-Project maintainers have the right and responsibility to remove, edit, or
-reject comments, commits, code, wiki edits, issues, and other contributions
-that are not aligned to this Code of Conduct, or to ban temporarily or
-permanently any contributor for other behaviors that they deem inappropriate,
-threatening, offensive, or harmful.
-
-## Scope
-
-This Code of Conduct applies within all project spaces, and it also applies when
-an individual is representing the project or its community in public spaces.
-Examples of representing a project or community include using an official
-project e-mail address, posting via an official social media account, or acting
-as an appointed representative at an online or offline event. Representation of
-a project may be further defined and clarified by project maintainers.
-
-## Enforcement
-
-Instances of abusive, harassing, or otherwise unacceptable behavior may be
-reported by contacting the project team at . All
-complaints will be reviewed and investigated and will result in a response that
-is deemed necessary and appropriate to the circumstances. The project team is
-obligated to maintain confidentiality with regard to the reporter of an incident.
-Further details of specific enforcement policies may be posted separately.
-
-Project maintainers who do not follow or enforce the Code of Conduct in good
-faith may face temporary or permanent repercussions as determined by other
-members of the project's leadership.
-
-## Attribution
-
-This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
-available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
-
-[homepage]: https://www.contributor-covenant.org
-
-For answers to common questions about this code of conduct, see
-https://www.contributor-covenant.org/faq
-
diff --git a/plugins/tensorboard-plugins/CONTRIBUTING.md b/plugins/tensorboard-plugins/CONTRIBUTING.md
deleted file mode 100644
index a2e931bb6f0..00000000000
--- a/plugins/tensorboard-plugins/CONTRIBUTING.md
+++ /dev/null
@@ -1,34 +0,0 @@
-# Contributing to Kineto
-We want to make contributing to this project as easy and transparent as
-possible.
-
-## Code of Conduct
-The code of conduct is described in [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md).
-
-## Pull Requests
-We actively welcome your pull requests.
-
-1. Fork the repo and create your branch from `main`.
-2. If you've added code that should be tested, add tests.
-3. If you've changed APIs, update the documentation.
-4. Ensure the test suite passes.
-5. Make sure your code lints.
-6. If you haven't already, complete the Contributor License Agreement ("CLA").
-
-## Contributor License Agreement ("CLA")
-In order to accept your pull request, we need you to submit a CLA. You only need
-to do this once to work on any of Facebook's open source projects.
-
-Complete your CLA here:
-
-## Issues
-We use GitHub issues to track public bugs. Please ensure your description is
-clear and has sufficient instructions to be able to reproduce the issue.
-
-Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
-disclosure of security bugs. In those cases, please go through the process
-outlined on that page and do not file a public issue.
-
-## License
-By contributing to Kineto, you agree that your contributions will be licensed
-under the LICENSE file in the root directory of this source tree.
diff --git a/plugins/tensorboard-plugins/LICENSE b/plugins/tensorboard-plugins/LICENSE
deleted file mode 100644
index edb179715b5..00000000000
--- a/plugins/tensorboard-plugins/LICENSE
+++ /dev/null
@@ -1,33 +0,0 @@
-BSD License
-
-For Kineto software
-
-Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
-All contributions by Microsoft:
-Copyright (c) Microsoft Corporation. (The Azure AI Platform team)
-
-Redistribution and use in source and binary forms, with or without modification,
-are permitted provided that the following conditions are met:
-
- * Redistributions of source code must retain the above copyright notice, this
- list of conditions and the following disclaimer.
-
- * Redistributions in binary form must reproduce the above copyright notice,
- this list of conditions and the following disclaimer in the documentation
- and/or other materials provided with the distribution.
-
- * Neither the name Facebook nor the names of its contributors may be used to
- endorse or promote products derived from this software without specific
- prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
-ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
-ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/plugins/tensorboard-plugins/README.md b/plugins/tensorboard-plugins/README.md
deleted file mode 100644
index 3a18f4c6239..00000000000
--- a/plugins/tensorboard-plugins/README.md
+++ /dev/null
@@ -1,38 +0,0 @@
-# Kineto
-
-Kineto is part of the PyTorch Profiler.
-
-The Kineto project was started to help enable
-- **performance observability and diagnostics** across common ML bottleneck components
-- **actionable recommendations** for common issues
-- integration of external system-level profiling tools
-- integration with popular visualization platforms and analysis pipelines
-
-A central component is libkineto, a profiling library with special focus on low-overhead GPU timeline tracing.
-
-The PyTorch Profiler TensorBoard plugin provides powerful and intuitive visualizations of profiling results, as well as actionable recommendations, and is the best way to experience the new PyTorch Profiler.
-
-## Libkineto
-Libkineto is an in-process profiling library integrated with the PyTorch Profiler. Please refer to the [README](libkineto/README.md) file in the `libkineto` folder as well as documentation on the [new PyTorch Profiler API](https://pytorch.org/docs/master/profiler.html).
-
-## PyTorch TensorBoard Profiler NPU Plugin
-The goal of the PyTorch TensorBoard Profiler is to provide a seamless and intuitive end-to-end profiling experience, including straightforward collection from PyTorch and insightful visualizations and recommendations in the TensorBoard UI.
-Please refer to the [README](tb_plugin/README.md) file in the `tb_plugin` folder.
-
-## Future Development Direction:
-Some areas we're currently working on:
-- Support for tracing distributed workloads
-- Trace processing, analysis and recommendation engine
-- System-level activities, multiple tracing sources
-- Profiling and monitoring daemon for larger scale deployments
-
-## Releases and Contributing
-We will follow the PyTorch release schedule which roughly happens on a 3 month basis.
-
-We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.
-
-If you plan to contribute new features, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the infrastructure in a different direction than you might be aware of. We expect the architecture to keep evolving.
-
-## License
-Kineto has a BSD-style license, as found in the [LICENSE](LICENSE) file.
-
diff --git a/plugins/tensorboard-plugins/libkineto/CMakeLists.txt b/plugins/tensorboard-plugins/libkineto/CMakeLists.txt
deleted file mode 100644
index 63966de803a..00000000000
--- a/plugins/tensorboard-plugins/libkineto/CMakeLists.txt
+++ /dev/null
@@ -1,198 +0,0 @@
-cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
-
-list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
-
-#install libraries into correct locations on all platforms
-include(GNUInstallDirs)
-
-# function to extract filelists from libkineto_defs.bzl file
-find_package(PythonInterp)
-function(get_filelist name outputvar)
- execute_process(
- COMMAND "${PYTHON_EXECUTABLE}" -c
- "exec(open('libkineto_defs.bzl').read());print(';'.join(${name}))"
- WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
- OUTPUT_VARIABLE _tempvar)
- string(REPLACE "\n" "" _tempvar "${_tempvar}")
- set(${outputvar} ${_tempvar} PARENT_SCOPE)
-endfunction()
-
-project(kineto VERSION 0.1 LANGUAGES CXX C)
-
-set(KINETO_LIBRARY_TYPE "default" CACHE STRING
- "Type of library (default, static or shared) to build")
-set_property(CACHE KINETO_LIBRARY_TYPE PROPERTY STRINGS default shared)
-option(KINETO_BUILD_TESTS "Build kineto unit tests" ON)
-
-set(LIBKINETO_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src")
-set(LIBKINETO_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/include")
-set(LIBKINETO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
-set(LIBKINETO_THIRDPARTY_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party")
-set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-
-#We should default to a Release build
-if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "")
- set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
-endif()
-
-if (NOT CUDA_SOURCE_DIR)
- set(CUDA_SOURCE_DIR "$ENV{CUDA_SOURCE_DIR}")
- message(INFO " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}")
-endif()
-
-if (NOT ROCM_SOURCE_DIR)
- set(ROCM_SOURCE_DIR "$ENV{ROCM_SOURCE_DIR}")
- message(INFO " ROCM_SOURCE_DIR = ${ROCM_SOURCE_DIR}")
-endif()
-
-# Set LIBKINETO_NOCUPTI to explicitly disable CUPTI
-# Otherwise, CUPTI is disabled if not found
-IF (NOT CUDA_SOURCE_DIR OR NOT CUPTI_INCLUDE_DIR OR NOT CUDA_cupti_LIBRARY)
- set(LIBKINETO_NOCUPTI ON CACHE BOOL "" FORCE)
-endif()
-
-IF (NOT ROCM_SOURCE_DIR AND NOT ROCTRACER_INCLUDE_DIR)
- set(LIBKINETO_NOROCTRACER ON CACHE BOOL "" FORCE)
-endif()
-
-# Define file lists
-if (LIBKINETO_NOCUPTI AND LIBKINETO_NOROCTRACER)
- get_filelist("get_libkineto_cpu_only_srcs(with_api=False)" LIBKINETO_SRCS)
- message(INFO " CUPTI unavailable or disabled - not building GPU profilers")
-elseif(NOT LIBKINETO_NOROCTRACER)
- get_filelist("get_libkineto_roctracer_srcs()" LIBKINETO_SRCS)
- message(INFO " Building with roctracer")
-else()
- get_filelist("get_libkineto_cupti_srcs(with_api=False)" LIBKINETO_SRCS)
-endif()
-get_filelist("get_libkineto_public_headers()" LIBKINETO_PUBLIC_HEADERS)
-get_filelist("get_libkineto_api_srcs()" LIBKINETO_API_SRCS)
-
-add_library(kineto_base OBJECT ${LIBKINETO_SRCS})
-add_library(kineto_api OBJECT ${LIBKINETO_API_SRCS})
-
-# Make libraries depend on libkineto_defs.bzl
-add_custom_target(libkineto_defs.bzl DEPENDS libkineto_defs.bzl)
-add_dependencies(kineto_base libkineto_defs.bzl)
-
-set_target_properties(kineto_base kineto_api PROPERTIES
- CXX_STANDARD 14
- CXX_STANDARD_REQUIRED YES
- CXX_EXTENSIONS NO
- CXX_VISIBILITY_PRESET hidden)
-
-set(KINETO_COMPILE_OPTIONS "-DKINETO_NAMESPACE=libkineto")
-list(APPEND KINETO_COMPILE_OPTIONS "-DFMT_HEADER_ONLY")
-if(NOT MSVC)
- list(APPEND KINETO_COMPILE_OPTIONS "-std=c++14")
-else()
- list(APPEND KINETO_COMPILE_OPTIONS "/std:c++14")
- list(APPEND KINETO_COMPILE_OPTIONS "-DWIN32_LEAN_AND_MEAN")
- list(APPEND KINETO_COMPILE_OPTIONS "-DNOGDI")
-endif()
-if (NOT LIBKINETO_NOCUPTI)
- list(APPEND KINETO_COMPILE_OPTIONS "-DHAS_CUPTI")
-endif()
-if (NOT LIBKINETO_NOROCTRACER)
- target_compile_options(kineto_base PRIVATE "-DHAS_ROCTRACER")
- target_compile_options(kineto_base PRIVATE "-D__HIP_PLATFORM_HCC__")
- target_compile_options(kineto_base PRIVATE "-D__HIP_PLATFORM_AMD__")
-endif()
-
-target_compile_options(kineto_base PRIVATE "${KINETO_COMPILE_OPTIONS}")
-target_compile_options(kineto_api PRIVATE "${KINETO_COMPILE_OPTIONS}")
-
-if(NOT TARGET fmt)
- if(NOT FMT_SOURCE_DIR)
- set(FMT_SOURCE_DIR "${LIBKINETO_THIRDPARTY_DIR}/fmt"
- CACHE STRING "fmt source directory from submodules")
- endif()
-
- # Build FMT.
- # FMT and some other libraries use BUILD_SHARED_LIBS to control
- # the library type.
- # Save and restore the value after configuring FMT
- set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
- set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
- set(FMT_LIBRARY_TYPE static CACHE STRING "Set lib type to static")
- add_subdirectory("${FMT_SOURCE_DIR}" "${LIBKINETO_BINARY_DIR}/fmt")
- set_property(TARGET fmt PROPERTY POSITION_INDEPENDENT_CODE ON)
- set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
-endif()
-
-set(FMT_INCLUDE_DIR "${FMT_SOURCE_DIR}/include")
-message(STATUS "Kineto: FMT_SOURCE_DIR = ${FMT_SOURCE_DIR}")
-message(STATUS "Kineto: FMT_INCLUDE_DIR = ${FMT_INCLUDE_DIR}")
-if (NOT CUPTI_INCLUDE_DIR)
- set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/extras/CUPTI/include")
-endif()
-if (NOT CUDA_INCLUDE_DIRS)
- set(CUDA_INCLUDE_DIRS "${CUDA_SOURCE_DIR}/include")
-endif()
-if (NOT ROCTRACER_INCLUDE_DIR)
- set(ROCTRACER_INCLUDE_DIR "${ROCM_SOURCE_DIR}/roctracer/include")
-endif()
-if (NOT ROCM_INCLUDE_DIRS)
- set(ROCM_INCLUDE_DIRS "${ROCM_SOURCE_DIR}/include")
-endif()
-
-message(INFO " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}")
-message(INFO " ROCTRACER_INCLUDE_DIR = ${ROCTRACER_INCLUDE_DIR}")
-
-target_include_directories(kineto_base PUBLIC
- $
- $
- $
- $
- $
- $
- $)
-
-target_include_directories(kineto_api PUBLIC
- $
- $)
-
-if(KINETO_LIBRARY_TYPE STREQUAL "default")
- add_library(kineto
- $
- $)
-elseif(KINETO_LIBRARY_TYPE STREQUAL "static")
- add_library(kineto STATIC
- $
- $)
-elseif(KINETO_LIBRARY_TYPE STREQUAL "shared")
- add_library(kineto SHARED
- $)
- set_property(TARGET kineto_base PROPERTY POSITION_INDEPENDENT_CODE ON)
- set_target_properties(kineto PROPERTIES
- CXX_VISIBILITY_PRESET hidden)
-else()
- message(FATAL_ERROR "Unsupported library type ${KINETO_LIBRARY_TYPE}")
-endif()
-
-if(NOT LIBKINETO_NOROCTRACER)
- find_library(ROCTRACER_LIBRARY NAMES libroctracer64.so HINTS /opt/rocm/roctracer/lib)
- target_link_libraries(kineto "${ROCTRACER_LIBRARY}")
- find_library(KINETO_HIP_LIBRARY NAMES libamdhip64.so HINTS /opt/rocm/lib)
- target_link_libraries(kineto "${KINETO_HIP_LIBRARY}")
-endif()
-
-if(NOT LIBKINETO_NOCUPTI)
- target_link_libraries(kineto "${CUDA_cupti_LIBRARY}")
-endif()
-target_link_libraries(kineto $)
-add_dependencies(kineto fmt::fmt-header-only)
-
-install(TARGETS kineto EXPORT kinetoLibraryConfig
- ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
- LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
-
-install(FILES ${LIBKINETO_PUBLIC_HEADERS}
- DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/kineto")
-
-install(EXPORT kinetoLibraryConfig DESTINATION share/cmake/kineto
- FILE kinetoLibraryConfig.cmake)
-
-if(KINETO_BUILD_TESTS)
- add_subdirectory(test)
-endif()
diff --git a/plugins/tensorboard-plugins/libkineto/README.md b/plugins/tensorboard-plugins/libkineto/README.md
deleted file mode 100644
index 37127ca5aa8..00000000000
--- a/plugins/tensorboard-plugins/libkineto/README.md
+++ /dev/null
@@ -1,65 +0,0 @@
-# Libkineto
-
-Libkineto is an in-process profiling library, part of the Kineto performance
-tools project.
-
-The library provides a way to collect GPU traces and metrics from the host
-process, either via the library public API or by sending a signal, if enabled.
-
-Currently only NVIDIA GPUs are supported.
-
-## Build Notes
-Libkineto uses the standard CMAKE-based build flow.
-
-### Dependencies
-Libkineto requires gcc 5+ and:
-
-- NVIDIA CUPTI: used to collect traces and metrics from NVIDIA GPUs.
-- fmt: used for its convenient and lightweight string formatting functionality.
-- googletest: required to build and run Kineto's tests.
- - **googletest is not required** if you don't want to run Kineto tests.
-By default, building of tests is **on**. Turn it off by setting `KINETO_BUILD_TESTS` to **off**.
-
-You can download [NVIDIA CUPTI][1], [fmt][2], [googletest][3] and set
-`CUDA_SOURCE_DIR`, `FMT_SOURCE_DIR`, `GOOGLETEST_SOURCE_DIR` respectively for
-cmake to find these libraries. If the fmt and googletest variables are not set, cmake will
-build the git submodules found in the `third_party` directory.
-If `CUDA_SOURCE_DIR` is not set, libkineto will fail to build.
-
-### Building Libkineto
-
-```
-# Check out repo and sub modules
-git clone --recursive https://github.com/pytorch/kineto.git
-# Build libkineto with cmake
-cd kineto/libkineto
-mkdir build && cd build
-cmake ..
-make
-```
-
-To run the tests after building libkineto (if tests are built), use the following
-command:
-```
-make test
-```
-
-### Installing Libkineto
-```
-make install
-```
-
-## How Libkineto works
-We will provide a high-level overview, design philosophy and brief descriptions of various
-parts of Libkineto in upcoming blogs.
-
-## Full documentation
-We strive to keep our source files readable. The best and up-to-date
-documentation is available in the source files.
-
-## License
-Libkineto is BSD licensed, as detailed in the [LICENSE](../LICENSE) file.
-
-[1]:https://developer.nvidia.com/CUPTI-CTK10_2
-[2]:https://github.com/fmt
-[3]:https://github.com/google/googletest
diff --git a/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h b/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h
deleted file mode 100644
index 1cadf4906c1..00000000000
--- a/plugins/tensorboard-plugins/libkineto/include/AbstractConfig.h
+++ /dev/null
@@ -1,113 +0,0 @@
-// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
-
-#pragma once
-
-#include
-#include