From 5e87df8edcbfa9bbc81030385996e129929b4cd9 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Mon, 20 Jan 2025 11:05:54 +0800 Subject: [PATCH 01/20] pt&ms debugger.save + base.py ut --- .../msprobe/core/common/const.py | 3 +- .../msprobe/core/data_dump/data_collector.py | 13 ++++ .../core/data_dump/data_processor/base.py | 73 ++++++++++++++++++- .../msprobe/core/data_dump/json_writer.py | 21 +++++- .../mindspore/debugger/precision_debugger.py | 14 ++++ .../msprobe/mindspore/service.py | 61 +++++++++++++++- .../pytorch/debugger/precision_debugger.py | 9 +++ .../accuracy_tools/msprobe/pytorch/service.py | 57 ++++++++++++++- .../data_dump/data_processor/test_base.py | 36 ++++++++- 9 files changed, 277 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 5a165443be..21804eead3 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -109,7 +109,8 @@ class Const: LEVEL_L1 = "L1" LEVEL_L2 = "L2" LEVEL_MIX = "mix" - LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX] + LEVEL_DEBUG = "debug" + LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX, LEVEL_DEBUG] ATTR_NAME_PREFIX = "wrap_" ATTR_NAME_PREFIX_LEN = len(ATTR_NAME_PREFIX) KERNEL_DUMP = "kernel_dump" diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index d613517cc8..c38d805c16 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -162,6 +162,19 @@ class DataCollector: def update_iter(self, current_iter): self.data_processor.update_iter(current_iter) + def debug_data_collect_forward(self, variable, name_with_count): + self.update_api_or_module_name(name_with_count) + data_info = self.data_processor.analyze_debug_forward(variable) + self.data_writer.update_debug({name_with_count: data_info}) + + def debug_data_collect_backward(self, variable, grad_name_with_count): + # prepare all None nested data structure + all_none_data_info = self.data_processor.analyze_element_to_all_none(variable) + self.data_writer.update_debug({grad_name_with_count: all_none_data_info}) + + # register tensor backward hook + self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug) + def params_data_collect(self, name, param_name, pid, data): grad_name = name + Const.SEP + Const.PARAMS_GRAD # 校验scope和pid,以及当前name是否有过反向计算 diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index dfb4745b87..87792edca1 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -17,6 +17,8 @@ import inspect import os from dataclasses import dataclass, is_dataclass from typing import Tuple, Dict, Optional, Any +from functools import partial +import copy import numpy as np @@ -104,6 +106,7 @@ class BaseDataProcessor: self.save_name = None if hasattr(config, "data_mode"): self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode) + self.custom_dump_data_name = None @property def data_path(self): @@ -219,7 +222,7 @@ class BaseDataProcessor: return cls.apply_transform_dict(args_dict, transform, depth) elif isinstance(args, (list, tuple)): result_list = cls.apply_transform_list(args, transform, depth) - return type(args)(result_list) + return result_list elif isinstance(args, dict): return cls.apply_transform_dict(args, transform, depth) elif args is not None: @@ -227,7 +230,7 @@ class BaseDataProcessor: return None else: return None - + @classmethod def apply_transform_dict(cls, args, transform, depth): result_dict = {} @@ -276,6 +279,68 @@ class BaseDataProcessor: def analyze_element(self, element): return self.recursive_apply_transform(element, self.analyze_single_element) + def analyze_element_to_all_none(self, element): + return self.recursive_apply_transform(element, self.analyze_element_to_none) + + def analyze_element_to_none(self, element, suffix_stack): + return None + + def analyze_hook_single_element(self, element, suffix_stack, hook_fn): + if hasattr(element, "register_hook"): + # element might be mindspore.tensor or torch.tensor + if hasattr(element, "requires_grad") and not element.requires_grad: + return + indexes = copy.deepcopy(suffix_stack) + wrap_hook_fn = partial(hook_fn, index=indexes) + def real_hook_fn(grad): + return wrap_hook_fn(grad) + element.register_hook(real_hook_fn) + + @staticmethod + def set_value_into_nested_structure(data_structure, index, value): + ''' + Args: + data_structure: nested data structure + index: List[str] + value: value to be set + ''' + if not index: + raise ValueError(f"index need to be non empty when set value to nested data structure") + current_level = data_structure + for i in index[:-1]: + if isinstance(current_level, list): + current_level = current_level[int(i)] + elif isinstance(current_level, dict): + current_level = current_level[i] + else: + raise ValueError(f"Unsupported type in nested structure: {type(current_level)}") + + if isinstance(current_level, list): + current_level[int(index[-1])] = value + elif isinstance(current_level, dict): + current_level[index[-1]] = value + else: + raise ValueError(f"Unsupported type for final assignment: {type(current_level)}") + + def analyze_debug_forward(self, variable): + self.api_data_category = Const.OUTPUT + data_info = self.analyze_element(variable) + return data_info + + def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure): + def hook_fn(grad, index): + self.api_data_category = Const.OUTPUT + file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX + suffix = Const.SEP.join(index) + self.custom_dump_data_name = (grad_name_with_count + Const.SEP + Const.OUTPUT + Const.SEP + suffix + file_format) + grad_data_info = self.analyze_element(grad) + self.custom_dump_data_name = None + full_index = [grad_name_with_count] + index + self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info) + return grad + wrap_analyze_hook_single_element = partial(self.analyze_hook_single_element, hook_fn=hook_fn) + self.recursive_apply_transform(variable, wrap_analyze_hook_single_element) + def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs): api_info_struct = {} # check whether data_mode contains forward or input @@ -373,6 +438,10 @@ class BaseDataProcessor: return api_info_struct def get_save_file_path(self, suffix): + if self.custom_dump_data_name: + dump_data_name = self.custom_dump_data_name + file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + return dump_data_name, file_path file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX if self.save_name is not None: dump_data_name = (self.save_name + file_format) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index e992359777..d404c6a628 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -15,6 +15,7 @@ import csv import os +import copy from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.file_utils import change_mode, FileOpen, save_json @@ -29,10 +30,12 @@ class DataWriter: self.construct_file_path = None self.free_benchmark_file_path = None self.dump_tensor_data_dir = None + self.debug_file_path = None self.flush_size = 1000 self.cache_data = {} self.cache_stack = {} self.cache_construct = {} + self.cache_debug = {} @staticmethod def write_data_to_csv(result: list, result_header: tuple, file_path: str): @@ -53,6 +56,7 @@ class DataWriter: self.cache_data = {} self.cache_stack = {} self.cache_construct = {} + self.cache_debug = {} def initialize_json_file(self, **kwargs): if not self.cache_data: @@ -63,14 +67,21 @@ class DataWriter: save_json(self.stack_file_path, self.cache_stack, indent=1) if not self.cache_construct: save_json(self.construct_file_path, self.cache_construct, indent=1) + if self.debug_file_path and not self.cache_debug: + debug_dict = copy.deepcopy(kwargs) + debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) + self.cache_debug = debug_dict + save_json(self.debug_file_path, self.cache_debug, indent=1) + def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, - free_benchmark_file_path): + free_benchmark_file_path, debug_file_path=None): self.dump_file_path = dump_file_path self.stack_file_path = stack_file_path self.construct_file_path = construct_file_path self.dump_tensor_data_dir = dump_data_dir self.free_benchmark_file_path = free_benchmark_file_path + self.debug_file_path = debug_file_path def flush_data_periodically(self): dump_data = self.cache_data.get(Const.DATA) @@ -98,6 +109,9 @@ class DataWriter: def update_construct(self, new_data): self.cache_construct.update(new_data) + def update_debug(self, new_data): + self.cache_debug.update(new_data) + def write_data_json(self, file_path): logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ") save_json(file_path, self.cache_data, indent=1) @@ -108,6 +122,9 @@ class DataWriter: def write_construct_info_json(self, file_path): save_json(file_path, self.cache_construct, indent=1) + def write_debug_info_json(self, file_path): + save_json(file_path, self.cache_debug, indent=1) + def write_json(self): if self.cache_data: self.write_data_json(self.dump_file_path) @@ -115,3 +132,5 @@ class DataWriter: self.write_stack_info_json(self.stack_file_path) if self.cache_construct: self.write_construct_info_json(self.construct_file_path) + if self.cache_debug: + self.write_debug_info_json(self.debug_file_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 1e8d969077..41dd8b1bbb 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -153,6 +153,20 @@ class PrecisionDebugger: return instance.gm.monitor(opt) + @classmethod + def save(cls, variable, name, save_backward=True): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.task not in [Const.TENSOR, Const.STATISTICS]: + return + + instance.config.execution_mode = cls._get_execution_mode() + if cls._need_service(): + if not instance.service: + instance.service = Service(instance.config) + instance.service.save(variable, name, save_backward) + @classmethod def _need_service(cls): instance = cls._instance diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 0bc8d4db88..a766f59e5c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -62,6 +62,8 @@ class Service: self.check_level_valid() self.should_stop_service = False self.params_grad_info = {} + if self.config.level == Const.LEVEL_DEBUG: + self.init_for_debug_level() @staticmethod def check_model_valid(model): @@ -139,7 +141,7 @@ class Service: # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位 if data_info.get(grad_name): # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新 - self.data_collector.handle_data(grad_name, data_info, + self.data_collector.handle_data(grad_name, data_info, flush=self.data_collector.data_processor.is_terminated) # 记录当前模块的参数梯度信息已占位 self.params_grad_info[grad_name] = True @@ -241,12 +243,16 @@ class Service: primitive.__class__ = new_primitive def step(self): + if self.config.level == Const.LEVEL_DEBUG: + return self.data_collector.write_json() self.current_iter += 1 self.data_collector.update_iter(self.current_iter) self.reset_status() def start(self, model=None): + if self.config.level == Const.LEVEL_DEBUG: + return self.start_call = True if self.should_stop_service: return @@ -291,6 +297,8 @@ class Service: JitDump.jit_dump_switch = True def stop(self): + if self.config.level == Const.LEVEL_DEBUG: + return if self.should_stop_service: return logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. " @@ -308,6 +316,57 @@ class Service: self.data_collector.write_json() JitDump.jit_dump_switch = False + def save(self, variable, name, save_backward): + ''' + Args: + variable: Union[List[variable], dict{str: variable}, mindspore.tensor, str, float, int] + name: str + save_backward: boolean + Return: + void + ''' + if self.config.level != Const.LEVEL_DEBUG: + return + count = self.debug_variable_counter[name] + self.debug_variable_counter[name] += 1 + + name_with_count = f"{name}.{count}" + grad_name_with_count = f"{name}_grad.{count}" + + # forward save + self.data_collector.debug_data_collect_forward(variable, name_with_count) + + # backward save + if save_backward: + self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) + + def init_for_debug_level(self): + try: + self.current_rank = get_rank_if_initialized() + except DistributedNotInitializedError: + self.current_rank = None + + # dir: dump_path -- rank{} -- debug.json + create_directory(self.config.dump_path) + self.dump_iter_dir = self.config.dump_path + cur_rank = self.current_rank if self.current_rank is not None else '' + dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") + create_directory(dump_dir) + if self.config.task in self.data_collector.tasks_need_tensor_data: + dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") + create_directory(dump_data_dir) + else: + dump_data_dir = None + + dump_file_path = os.path.join(dump_dir, "dump.json") + stack_file_path = os.path.join(dump_dir, "stack.json") + construct_file_path = os.path.join(dump_dir, "construct.json") + debug_file_path = os.path.join(dump_dir, "debug.json") + self.data_collector.update_dump_paths( + dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None, debug_file_path) + + self.debug_variable_counter = defaultdict(int) + def need_end_service(self): if self.config.step and self.current_iter > max(self.config.step): return True diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 3588b02b42..a31444383c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -123,6 +123,15 @@ class PrecisionDebugger: else: instance.service.start(instance.model) + @classmethod + def save(cls, variable, name, save_backward=True): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.task not in [Const.TENSOR, Const.STATISTICS]: + return + instance.service.save(variable, name, save_backward) + @classmethod def forward_backward_dump_end(cls): instance = cls._instance diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 7c73e1e4f2..3c97f0c04d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -16,6 +16,7 @@ import functools import os from collections import namedtuple +from collections import defaultdict import torch from msprobe.core.common.const import Const @@ -55,6 +56,8 @@ class Service: self.should_stop_service = False self.attl = None self.params_grad_info = {} + if self.config.level == Const.LEVEL_DEBUG: + self.init_for_debug_level() def build_hook(self, module_type, name): def pre_hook(api_or_module_name, module, args, kwargs): @@ -114,7 +117,7 @@ class Service: # 当模块中的参数有requires_grad属性为True时,才会进行梯度计算,此时才需要占位 if data_info.get(grad_name): # 将grad_name的data_info先写入cache_data中, 梯度计算后再更新 - self.data_collector.handle_data(grad_name, data_info, + self.data_collector.handle_data(grad_name, data_info, flush=self.data_collector.data_processor.is_terminated) # 记录当前模块的参数梯度信息已占位 self.params_grad_info[grad_name] = True @@ -214,6 +217,8 @@ class Service: return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn) def start(self, model): + if self.config.level == Const.LEVEL_DEBUG: + return if self.need_stop_service(): return @@ -238,6 +243,8 @@ class Service: logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.") def stop(self): + if self.config.level == Const.LEVEL_DEBUG: + return if self.should_stop_service: return if self.config.step and self.current_iter not in self.config.step: @@ -253,6 +260,8 @@ class Service: self.data_collector.write_json() def step(self): + if self.config.level == Const.LEVEL_DEBUG: + return if self.should_stop_service: return self.data_collector.write_json() @@ -319,6 +328,52 @@ class Service: self.data_collector.update_dump_paths( dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path) + def save(self, variable, name, save_backward): + if self.config.level != Const.LEVEL_DEBUG: + return + count = self.debug_variable_counter[name] + self.debug_variable_counter[name] += 1 + + name_with_count = f"{name}.{count}" + grad_name_with_count = f"{name}_grad.{count}" + + # forward save + self.data_collector.debug_data_collect_forward(variable, name_with_count) + + # backward save + if save_backward: + self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) + + + def init_for_debug_level(self): + try: + self.current_rank = get_rank_if_initialized() + except DistributedNotInitializedError: + self.current_rank = None + + # dir: dump_path -- rank{} -- debug.json + create_directory(self.config.dump_path) + self.dump_iter_dir = self.config.dump_path + cur_rank = self.current_rank if self.current_rank is not None else '' + dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") + create_directory(dump_dir) + if self.config.task in self.data_collector.tasks_need_tensor_data: + dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") + create_directory(dump_data_dir) + else: + dump_data_dir = None + + dump_file_path = os.path.join(dump_dir, "dump.json") + stack_file_path = os.path.join(dump_dir, "stack.json") + construct_file_path = os.path.join(dump_dir, "construct.json") + free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv") + debug_file_path = os.path.join(dump_dir, "debug.json") + self.data_collector.update_dump_paths( + dump_file_path, stack_file_path, construct_file_path, dump_data_dir, + free_benchmark_file_path, debug_file_path) + + self.debug_variable_counter = defaultdict(int) + def register_hook_new(self): logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task)) if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]: diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py index 14e5f71f00..6047e1850c 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py @@ -65,7 +65,7 @@ class TestBaseDataProcessor(unittest.TestCase): self.data_writer.dump_tensor_data_dir = "./dump_data" self.processor.current_api_or_module_name = "test_api" self.processor.api_data_category = "input" - + @patch('inspect.stack') def test_analyze_api_call_stack(self, mock_stack): mock_stack.return_value = [ @@ -81,8 +81,8 @@ class TestBaseDataProcessor(unittest.TestCase): result = BaseDataProcessor.analyze_api_call_stack('test_stack') expected_output = { 'test_stack': [ - 'File file5.py, line 50, in function5, \n code line 5', - 'File file6.py, line 60, in function6, \n code line 6', + 'File file5.py, line 50, in function5, \n code line 5', + 'File file6.py, line 60, in function6, \n code line 6', 'File file7.py, line 70, in function7, \n code line 7', ] } @@ -128,7 +128,7 @@ class TestBaseDataProcessor(unittest.TestCase): last_hidden_state: int = None hidden_states: Optional[Tuple[int, ...]] = None attentions: Optional[Tuple[int, ...]] = None - + myData = MyDataClass( last_hidden_state=1, hidden_states=(2, 3), @@ -253,3 +253,31 @@ class TestBaseDataProcessor(unittest.TestCase): expected_file_name = "test_api.input.suffix.pt" expected_file_path = os.path.join(self.data_writer.dump_tensor_data_dir, expected_file_name) self.assertEqual(result, (expected_file_name, expected_file_path)) + + def test_get_save_file_path_with_custom_dump_data_name(self): + self.config.framework = "pytorch" + self.processor.custom_dump_data_name = "custom_name" + result = self.processor.get_save_file_path("suffix") + expected_file_name = "custom_name" + expected_file_path = os.path.join(self.data_writer.dump_tensor_data_dir, expected_file_name) + self.assertEqual(result, (expected_file_name, expected_file_path)) + + def test_set_value_into_nested_structure(self): + dst_data_structure = {"key1": [None, None]} + result = self.processor.set_value_into_nested_structure(dst_data_structure, ["key1", "0"], 12) + excepted_result = {"key1": [None, 12]} + self.assertEqual(result, excepted_result) + + def test_analyze_element_to_all_none(self): + element = {"key1": [12, 3, {"key2": 10, "key3":["12"]}]} + result = self.processor.analyze_element_to_all_none(element) + excepted_result = {"key1": [None, None, {"key2": None, "key3":[None]}]} + self.assertEqual(result, excepted_result) + + def test_analyze_hook_single_element(self): + element = MagicMock() + element.hasattr = MagicMock(side_effect=lambda attr: attr == "register_hook") + element.requires_grad = True + hook_fn = MagicMock() + self.processor.analyze_hook_single_element(element, [1, 2], hook_fn) + element.register_hook.assert_called_once() -- Gitee From 5127a3bcc71ed1a3a746c6d85e47e62eac558a57 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Mon, 20 Jan 2025 19:13:31 +0800 Subject: [PATCH 02/20] ut --- .../data_dump/data_processor/test_base.py | 40 ++++++++++++++++++- .../core_ut/data_dump/test_data_collector.py | 17 ++++++++ .../core_ut/data_dump/test_json_writer.py | 6 +++ 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py index 6047e1850c..710f16a863 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py @@ -134,12 +134,12 @@ class TestBaseDataProcessor(unittest.TestCase): hidden_states=(2, 3), attentions=(4, 5) ) - expected_dataclass_res = {'last_hidden_state': 2, 'hidden_states': (4, 6), 'attentions': (8,10)} + expected_dataclass_res = {'last_hidden_state': 2, 'hidden_states': [4, 6], 'attentions': [8,10]} self.assertEqual(BaseDataProcessor.recursive_apply_transform(2, transform), 4) self.assertEqual(BaseDataProcessor.recursive_apply_transform(myData, transform), expected_dataclass_res) self.assertEqual(BaseDataProcessor.recursive_apply_transform(myNamedTuple, transform), {'a': 2}) self.assertEqual(BaseDataProcessor.recursive_apply_transform([1, 2], transform), [2, 4]) - self.assertEqual(BaseDataProcessor.recursive_apply_transform((1, 2), transform), (2, 4)) + self.assertEqual(BaseDataProcessor.recursive_apply_transform((1, 2), transform), [2, 4]) self.assertEqual(BaseDataProcessor.recursive_apply_transform({'a': 1}, transform), {'a': 2}) @patch.object(logger, 'warning') @@ -281,3 +281,39 @@ class TestBaseDataProcessor(unittest.TestCase): hook_fn = MagicMock() self.processor.analyze_hook_single_element(element, [1, 2], hook_fn) element.register_hook.assert_called_once() + + def test_analyze_debug_backward(self): + variable = MagicMock() # 模拟输入变量 + grad_name_with_count = "grad_name_1" + nested_data_structure = {"key": "value"} # 模拟嵌套数据结构 + + self.processor.recursive_apply_transform = MagicMock() + self.processor.set_value_into_nested_structure = MagicMock() + self.processor.analyze_element = MagicMock() + self.processor.analyze_hook_single_element = MagicMock() + self.processor.recursive_apply_transform.assert_called_once_with( + variable, self.processor.analyze_hook_single_element + ) + + + self.processor.analyze_debug_backward(variable, grad_name_with_count, nested_data_structure)\ + # 验证 analyze_hook_single_element 的调用 + self.processor.analyze_hook_single_element.assert_called_once() + args, kwargs = self.processor.analyze_hook_single_element.call_args + self.assertIn("hook_fn", kwargs) + self.assertEqual(kwargs["hook_fn"].__name__, "hook_fn") + + grad = MagicMock() + index = ["layer1", "layer2"] + result = kwargs["hook_fn"](grad, index) + + # 验证 hook_fn 内部逻辑 + self.assertEqual(self.processor.api_data_category, "OUTPUT") + self.assertEqual(self.processor.custom_dump_data_name, "grad_name_1__OUTPUT__layer1__layer2.pt") + self.processor.analyze_element.assert_called_once_with(grad) + self.processor.set_value_into_nested_structure.assert_called_once_with( + nested_data_structure, ["grad_name_1", "layer1", "layer2"], "grad_data_info" + ) + self.assertIsNone(self.processor.custom_dump_data_name) + self.assertEqual(result, grad) + diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py index 8357c41346..2b099942b2 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py @@ -101,3 +101,20 @@ class TestDataCollector(unittest.TestCase): self.data_collector.backward_data_collect("name", "module", "pid", "module_input_output") mock_handle_data.assert_called_with("name", {}, flush=False) + + @patch.object(DataWriter, "update_debug") + @patch.object(BaseDataProcessor, "analyze_debug_forward", return_value="data_info") + @patch.object(DataCollector, "update_api_or_module_name") + def test_debug_data_collect_forward(self, mock_update_api_or_module_name, _, mock_update_debug): + self.data_collector.debug_data_collect_forward("variable", "name_with_count") + mock_update_api_or_module_name.assert_called_with("name_with_count") + mock_update_debug.assert_called_with({"name_with_count": "data_info"}) + + @patch.object(DataWriter, "cache_debug") + @patch.object(DataWriter, "update_debug") + @patch.object(BaseDataProcessor, "analyze_debug_backward") + @patch.object(BaseDataProcessor, "analyze_element_to_all_none", return_value = "all_none_data_info") + def test_debug_data_collect_forward(self, _, mock_analyze_debug_backward, mock_update_debug, mock_cache_debug): + self.data_collector.debug_data_collect_forward("variable", "name_with_count") + mock_update_debug.assert_called_with({"name_with_count": "all_none_data_info"}) + mock_analyze_debug_backward.assert_called_with("variable", "name_with_count", mock_cache_debug) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py index 042e16c5d3..7c20810c14 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py @@ -54,12 +54,17 @@ class TestDataWriter(unittest.TestCase): self.data_writer.dump_file_path = os.path.join(self.cur_path, "dump.json") self.data_writer.stack_file_path = os.path.join(self.cur_path, "stack.json") self.data_writer.construct_file_path = os.path.join(self.cur_path, "construct.json") + self.data_writer.debug_file_path = os.path.join(self.cur_path, "debug.json") self.data_writer.initialize_json_file(task="tensor", level="L1") load_data = load_json(self.data_writer.dump_file_path) expected = {"task": "tensor", "level": "L1", "dump_data_dir": "./dump_tensor_data", "data": {}} self.assertEqual(load_data, expected) + load_data = load_json(self.data_writer.debug_file_path) + self.assertEqual(load_data, expected) + self.assertEqual(self.data_writer.cache_debug, expected) + expected = {} load_data = load_json(self.data_writer.stack_file_path) self.assertEqual(load_data, expected) @@ -70,6 +75,7 @@ class TestDataWriter(unittest.TestCase): remove_path(self.data_writer.dump_file_path) remove_path(self.data_writer.stack_file_path) remove_path(self.data_writer.construct_file_path) + remove_path(self.data_writer.debug_file_path) def test_update_dump_paths(self): self.assertIsNone(self.data_writer.dump_file_path) -- Gitee From d7f7893aa03bbb436fe646b107f8a2d7314f14e9 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Mon, 20 Jan 2025 20:17:52 +0800 Subject: [PATCH 03/20] ut v3 --- .../data_dump/data_processor/test_base.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py index 710f16a863..d83cd7833d 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py @@ -265,7 +265,7 @@ class TestBaseDataProcessor(unittest.TestCase): def test_set_value_into_nested_structure(self): dst_data_structure = {"key1": [None, None]} result = self.processor.set_value_into_nested_structure(dst_data_structure, ["key1", "0"], 12) - excepted_result = {"key1": [None, 12]} + excepted_result = {"key1": [12, None]} self.assertEqual(result, excepted_result) def test_analyze_element_to_all_none(self): @@ -282,34 +282,35 @@ class TestBaseDataProcessor(unittest.TestCase): self.processor.analyze_hook_single_element(element, [1, 2], hook_fn) element.register_hook.assert_called_once() - def test_analyze_debug_backward(self): + @patch("msprobe.core.data_dump.data_processor.base.partial") + def test_analyze_debug_backward(self, mock_partial): variable = MagicMock() # 模拟输入变量 grad_name_with_count = "grad_name_1" nested_data_structure = {"key": "value"} # 模拟嵌套数据结构 self.processor.recursive_apply_transform = MagicMock() self.processor.set_value_into_nested_structure = MagicMock() - self.processor.analyze_element = MagicMock() + self.processor.analyze_element = MagicMock(return_value="grad_data_info") self.processor.analyze_hook_single_element = MagicMock() - self.processor.recursive_apply_transform.assert_called_once_with( - variable, self.processor.analyze_hook_single_element - ) + # call + self.processor.analyze_debug_backward(variable, grad_name_with_count, nested_data_structure) - self.processor.analyze_debug_backward(variable, grad_name_with_count, nested_data_structure)\ - # 验证 analyze_hook_single_element 的调用 - self.processor.analyze_hook_single_element.assert_called_once() - args, kwargs = self.processor.analyze_hook_single_element.call_args + # check partial + args, kwargs = mock_partial.call_args self.assertIn("hook_fn", kwargs) + self.assertEqual(args[0], self.processor.analyze_hook_single_element) self.assertEqual(kwargs["hook_fn"].__name__, "hook_fn") + wrap_func = mock_partial.return_value + self.processor.recursive_apply_transform.assert_called_once_with(variable, wrap_func) + grad = MagicMock() index = ["layer1", "layer2"] result = kwargs["hook_fn"](grad, index) # 验证 hook_fn 内部逻辑 - self.assertEqual(self.processor.api_data_category, "OUTPUT") - self.assertEqual(self.processor.custom_dump_data_name, "grad_name_1__OUTPUT__layer1__layer2.pt") + self.assertEqual(self.processor.api_data_category, "output") self.processor.analyze_element.assert_called_once_with(grad) self.processor.set_value_into_nested_structure.assert_called_once_with( nested_data_structure, ["grad_name_1", "layer1", "layer2"], "grad_data_info" -- Gitee From f2b1537f4d8c0cef35f2a72118ac28a006205d17 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Mon, 20 Jan 2025 20:19:29 +0800 Subject: [PATCH 04/20] ut v3 --- .../test/core_ut/data_dump/data_processor/test_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py index d83cd7833d..d36ebda877 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py @@ -264,9 +264,9 @@ class TestBaseDataProcessor(unittest.TestCase): def test_set_value_into_nested_structure(self): dst_data_structure = {"key1": [None, None]} - result = self.processor.set_value_into_nested_structure(dst_data_structure, ["key1", "0"], 12) + self.processor.set_value_into_nested_structure(dst_data_structure, ["key1", "0"], 12) excepted_result = {"key1": [12, None]} - self.assertEqual(result, excepted_result) + self.assertEqual(dst_data_structure, excepted_result) def test_analyze_element_to_all_none(self): element = {"key1": [12, 3, {"key2": 10, "key3":["12"]}]} -- Gitee From 396b56542907eb6889a1808c82482d443dd11e76 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Wed, 22 Jan 2025 11:02:42 +0800 Subject: [PATCH 05/20] cleancode --- .../msprobe/core/data_dump/data_collector.py | 9 ++ .../core/data_dump/data_processor/base.py | 82 ++++++++++--------- .../msprobe/core/data_dump/json_writer.py | 15 ++-- .../msprobe/mindspore/service.py | 29 ++++--- .../accuracy_tools/msprobe/pytorch/service.py | 32 ++++---- .../core_ut/data_dump/test_json_writer.py | 14 +++- .../test/mindspore_ut/test_ms_service.py | 16 ++-- 7 files changed, 109 insertions(+), 88 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index a0bd034925..ef676f119b 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -23,6 +23,15 @@ from msprobe.core.common.const import Const from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory +class DumpPathAggregation: + dump_file_path: str + stack_file_path: str + construct_file_path: str + dump_tensor_data_dir: str + free_benchmark_file_path: str + debug_file_path = None + + def build_data_collector(config): return DataCollector(config) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 87792edca1..bb6fe4ae39 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -200,6 +200,48 @@ class BaseDataProcessor: allowed_data_mode += [Const.INPUT, Const.OUTPUT] return allowed_data_mode + @staticmethod + def set_value_into_nested_structure(data_structure, index, value): + ''' + Args: + data_structure: nested data structure + index: List[str] + value: value to be set + ''' + if not index: + raise ValueError(f"index need to be non empty when set value to nested data structure") + current_level = data_structure + for i in index[:-1]: + if isinstance(current_level, list): + current_level = current_level[int(i)] + elif isinstance(current_level, dict): + current_level = current_level[i] + else: + raise ValueError(f"Unsupported type in nested structure: {type(current_level)}") + + if isinstance(current_level, list): + current_level[int(index[-1])] = value + elif isinstance(current_level, dict): + current_level[index[-1]] = value + else: + raise ValueError(f"Unsupported type for final assignment: {type(current_level)}") + + @staticmethod + def analyze_element_to_none(element, suffix_stack): + return None + + @staticmethod + def analyze_hook_single_element(element, suffix_stack, hook_fn): + if hasattr(element, "register_hook"): + # element might be mindspore.tensor or torch.tensor + if hasattr(element, "requires_grad") and not element.requires_grad: + return + indexes = copy.deepcopy(suffix_stack) + wrap_hook_fn = partial(hook_fn, index=indexes) + def real_hook_fn(grad): + return wrap_hook_fn(grad) + element.register_hook(real_hook_fn) + @classmethod def get_special_types(cls): return cls.special_type @@ -282,46 +324,6 @@ class BaseDataProcessor: def analyze_element_to_all_none(self, element): return self.recursive_apply_transform(element, self.analyze_element_to_none) - def analyze_element_to_none(self, element, suffix_stack): - return None - - def analyze_hook_single_element(self, element, suffix_stack, hook_fn): - if hasattr(element, "register_hook"): - # element might be mindspore.tensor or torch.tensor - if hasattr(element, "requires_grad") and not element.requires_grad: - return - indexes = copy.deepcopy(suffix_stack) - wrap_hook_fn = partial(hook_fn, index=indexes) - def real_hook_fn(grad): - return wrap_hook_fn(grad) - element.register_hook(real_hook_fn) - - @staticmethod - def set_value_into_nested_structure(data_structure, index, value): - ''' - Args: - data_structure: nested data structure - index: List[str] - value: value to be set - ''' - if not index: - raise ValueError(f"index need to be non empty when set value to nested data structure") - current_level = data_structure - for i in index[:-1]: - if isinstance(current_level, list): - current_level = current_level[int(i)] - elif isinstance(current_level, dict): - current_level = current_level[i] - else: - raise ValueError(f"Unsupported type in nested structure: {type(current_level)}") - - if isinstance(current_level, list): - current_level[int(index[-1])] = value - elif isinstance(current_level, dict): - current_level[index[-1]] = value - else: - raise ValueError(f"Unsupported type for final assignment: {type(current_level)}") - def analyze_debug_forward(self, variable): self.api_data_category = Const.OUTPUT data_info = self.analyze_element(variable) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index d404c6a628..26dddb91cd 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -74,14 +74,13 @@ class DataWriter: save_json(self.debug_file_path, self.cache_debug, indent=1) - def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, - free_benchmark_file_path, debug_file_path=None): - self.dump_file_path = dump_file_path - self.stack_file_path = stack_file_path - self.construct_file_path = construct_file_path - self.dump_tensor_data_dir = dump_data_dir - self.free_benchmark_file_path = free_benchmark_file_path - self.debug_file_path = debug_file_path + def update_dump_paths(self, dump_path_aggregation): + self.dump_file_path = dump_path_aggregation.dump_file_path + self.stack_file_path = dump_path_aggregation.stack_file_path + self.construct_file_path = dump_path_aggregation.construct_file_path + self.dump_tensor_data_dir = dump_path_aggregation.dump_data_dir + self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path + self.debug_file_path = dump_path_aggregation.debug_file_path def flush_data_periodically(self): dump_data = self.cache_data.get(Const.DATA) diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 30f896d547..a252d6564c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -32,7 +32,7 @@ else: from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException from msprobe.core.common.file_utils import create_directory from msprobe.core.common.utils import Const, print_tools_ends_info -from msprobe.core.data_dump.data_collector import build_data_collector +from msprobe.core.data_dump.data_collector import build_data_collector, DumpPathAggregation from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs from msprobe.core.data_dump.scope import BaseScope from msprobe.mindspore.cell_processor import CellProcessor @@ -354,12 +354,14 @@ class Service: else: dump_data_dir = None - dump_file_path = os.path.join(dump_dir, "dump.json") - stack_file_path = os.path.join(dump_dir, "stack.json") - construct_file_path = os.path.join(dump_dir, "construct.json") - debug_file_path = os.path.join(dump_dir, "debug.json") - self.data_collector.update_dump_paths( - dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None, debug_file_path) + dump_path_aggregation = DumpPathAggregation + dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") + dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") + dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") + dump_path_aggregation.dump_tensor_data_dir = dump_data_dir + dump_path_aggregation.free_benchmark_file_path = None + dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") + self.data_collector.update_dump_paths(dump_path_aggregation) self.debug_variable_counter = defaultdict(int) @@ -397,12 +399,13 @@ class Service: else: dump_data_dir = None - dump_file_path = os.path.join(dump_dir, "dump.json") - stack_file_path = os.path.join(dump_dir, "stack.json") - construct_file_path = os.path.join(dump_dir, "construct.json") - self.data_collector.update_dump_paths( - dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None - ) + dump_path_aggregation = DumpPathAggregation + dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") + dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") + dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") + dump_path_aggregation.dump_tensor_data_dir = dump_data_dir + dump_path_aggregation.free_benchmark_file_path = None + self.data_collector.update_dump_paths(dump_path_aggregation) self.data_collector.initialize_json_file( framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK ) diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 163b21901e..368fd69d59 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -23,7 +23,7 @@ from msprobe.core.common.const import Const from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import create_directory from msprobe.core.common.utils import print_tools_ends_info -from msprobe.core.data_dump.data_collector import build_data_collector +from msprobe.core.data_dump.data_collector import build_data_collector, DumpPathAggregation from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs from msprobe.core.data_dump.scope import BaseScope from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData @@ -318,13 +318,13 @@ class Service: else: dump_data_dir = None - dump_file_path = os.path.join(dump_dir, "dump.json") - stack_file_path = os.path.join(dump_dir, "stack.json") - construct_file_path = os.path.join(dump_dir, "construct.json") - free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv") - self.data_collector.update_dump_paths( - dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path - ) + dump_path_aggregation = DumpPathAggregation + dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") + dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") + dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") + dump_path_aggregation.dump_tensor_data_dir = dump_data_dir + dump_path_aggregation.free_benchmark_file_path = None + self.data_collector.update_dump_paths(dump_path_aggregation) self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) @@ -363,14 +363,14 @@ class Service: else: dump_data_dir = None - dump_file_path = os.path.join(dump_dir, "dump.json") - stack_file_path = os.path.join(dump_dir, "stack.json") - construct_file_path = os.path.join(dump_dir, "construct.json") - free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv") - debug_file_path = os.path.join(dump_dir, "debug.json") - self.data_collector.update_dump_paths( - dump_file_path, stack_file_path, construct_file_path, dump_data_dir, - free_benchmark_file_path, debug_file_path) + dump_path_aggregation = DumpPathAggregation + dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") + dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") + dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") + dump_path_aggregation.dump_tensor_data_dir = dump_data_dir + dump_path_aggregation.free_benchmark_file_path = None + dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") + self.data_collector.update_dump_paths(dump_path_aggregation) self.debug_variable_counter = defaultdict(int) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py index 7c20810c14..1c53530503 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py @@ -1,7 +1,7 @@ import csv import os import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock from msprobe.core.common.file_utils import FileOpen, remove_path, load_json from msprobe.core.data_dump.json_writer import DataWriter @@ -80,13 +80,21 @@ class TestDataWriter(unittest.TestCase): def test_update_dump_paths(self): self.assertIsNone(self.data_writer.dump_file_path) test_path = os.path.join(self.cur_path, "test1.json") - - self.data_writer.update_dump_paths(test_path, test_path, test_path, test_path, test_path) + mock_dump_path_aggregation = MagicMock() + mock_dump_path_aggregation.dump_file_path = test_path + mock_dump_path_aggregation.stack_file_path = test_path + mock_dump_path_aggregation.construct_file_path = test_path + mock_dump_path_aggregation.dump_tensor_data_dir = test_path + mock_dump_path_aggregation.free_benchmark_file_path = test_path + mock_dump_path_aggregation.debug_file_path = test_path + + self.data_writer.update_dump_paths(mock_dump_path_aggregation) self.assertTrue(self.data_writer.dump_file_path == test_path) self.assertTrue(self.data_writer.stack_file_path == test_path) self.assertTrue(self.data_writer.construct_file_path == test_path) self.assertTrue(self.data_writer.dump_tensor_data_dir == test_path) self.assertTrue(self.data_writer.free_benchmark_file_path == test_path) + self.assertTrue(self.data_writer.debug_file_path == test_path) @patch.object(DataWriter, "write_json") def test_flush_data_periodically(self, mock_write_json): diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 7e0ea2f05d..28b8813377 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -29,7 +29,7 @@ from msprobe.mindspore.common.utils import register_backward_hook_functions from msprobe.mindspore.dump.hook_cell.api_registry import ApiRegistry, api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.dump.jit_dump import JitDump -from msprobe.mindspore.service import Service +from msprobe.mindspore.service import Service, DumpPathAggregation class TestService(unittest.TestCase): @@ -92,13 +92,13 @@ class TestService(unittest.TestCase): mock_create_directory.assert_has_calls( [unittest.mock.call(path) for path in expected_calls], any_order=True) - self.service.data_collector.update_dump_paths.assert_called_once_with( - "/tmp/dump/step1/rank0/dump.json", - "/tmp/dump/step1/rank0/stack.json", - "/tmp/dump/step1/rank0/construct.json", - "/tmp/dump/step1/rank0/dump_tensor_data", - None - ) + dump_path_aggregation = DumpPathAggregation + dump_path_aggregation.dump_file_path = "/tmp/dump/step1/rank0/dump.json" + dump_path_aggregation.stack_file_path = "/tmp/dump/step1/rank0/stack.json" + dump_path_aggregation.construct_file_path = "/tmp/dump/step1/rank0/construct.json" + dump_path_aggregation.dump_tensor_data_dir = "/tmp/dump/step1/rank0/dump_tensor_data" + dump_path_aggregation.free_benchmark_file_path = None + self.service.data_collector.update_dump_paths.assert_called_once_with(dump_path_aggregation) self.service.data_collector.initialize_json_file.assert_called_once_with( framework=Const.MS_FRAMEWORK ) -- Gitee From 90fa65d03684296027d067e1f80faa26179e8113 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Wed, 22 Jan 2025 11:18:41 +0800 Subject: [PATCH 06/20] bugfix --- .../core/data_dump/data_processor/base.py | 84 +++++++++---------- .../accuracy_tools/msprobe/pytorch/service.py | 2 + 2 files changed, 44 insertions(+), 42 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index bb6fe4ae39..615ae72266 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -145,6 +145,48 @@ class BaseDataProcessor: else: return data + @staticmethod + def set_value_into_nested_structure(data_structure, index, value): + ''' + Args: + data_structure: nested data structure + index: List[str] + value: value to be set + ''' + if not index: + raise ValueError(f"index need to be non empty when set value to nested data structure") + current_level = data_structure + for i in index[:-1]: + if isinstance(current_level, list): + current_level = current_level[int(i)] + elif isinstance(current_level, dict): + current_level = current_level[i] + else: + raise ValueError(f"Unsupported type in nested structure: {type(current_level)}") + + if isinstance(current_level, list): + current_level[int(index[-1])] = value + elif isinstance(current_level, dict): + current_level[index[-1]] = value + else: + raise ValueError(f"Unsupported type for final assignment: {type(current_level)}") + + @staticmethod + def analyze_element_to_none(element, suffix_stack): + return None + + @staticmethod + def analyze_hook_single_element(element, suffix_stack, hook_fn): + if hasattr(element, "register_hook"): + # element might be mindspore.tensor or torch.tensor + if hasattr(element, "requires_grad") and not element.requires_grad: + return + indexes = copy.deepcopy(suffix_stack) + wrap_hook_fn = partial(hook_fn, index=indexes) + def real_hook_fn(grad): + return wrap_hook_fn(grad) + element.register_hook(real_hook_fn) + @staticmethod def _convert_numpy_to_builtin(arg): type_mapping = { @@ -200,48 +242,6 @@ class BaseDataProcessor: allowed_data_mode += [Const.INPUT, Const.OUTPUT] return allowed_data_mode - @staticmethod - def set_value_into_nested_structure(data_structure, index, value): - ''' - Args: - data_structure: nested data structure - index: List[str] - value: value to be set - ''' - if not index: - raise ValueError(f"index need to be non empty when set value to nested data structure") - current_level = data_structure - for i in index[:-1]: - if isinstance(current_level, list): - current_level = current_level[int(i)] - elif isinstance(current_level, dict): - current_level = current_level[i] - else: - raise ValueError(f"Unsupported type in nested structure: {type(current_level)}") - - if isinstance(current_level, list): - current_level[int(index[-1])] = value - elif isinstance(current_level, dict): - current_level[index[-1]] = value - else: - raise ValueError(f"Unsupported type for final assignment: {type(current_level)}") - - @staticmethod - def analyze_element_to_none(element, suffix_stack): - return None - - @staticmethod - def analyze_hook_single_element(element, suffix_stack, hook_fn): - if hasattr(element, "register_hook"): - # element might be mindspore.tensor or torch.tensor - if hasattr(element, "requires_grad") and not element.requires_grad: - return - indexes = copy.deepcopy(suffix_stack) - wrap_hook_fn = partial(hook_fn, index=indexes) - def real_hook_fn(grad): - return wrap_hook_fn(grad) - element.register_hook(real_hook_fn) - @classmethod def get_special_types(cls): return cls.special_type diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 368fd69d59..b9e3fb0cd3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -57,6 +57,8 @@ class Service: self.should_stop_service = False self.attl = None self.params_grad_info = {} + if self.config.level == Const.LEVEL_DEBUG: + self.init_for_debug_level() def build_hook(self, module_type, name): def pre_hook(api_or_module_name, module, args, kwargs): -- Gitee From 9e113784813a3b2ea19ace4a613f6015d7f31a95 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Wed, 22 Jan 2025 15:49:40 +0800 Subject: [PATCH 07/20] fix ut --- debug/accuracy_tools/msprobe/core/data_dump/json_writer.py | 2 +- .../msprobe/test/core_ut/data_dump/test_data_collector.py | 7 +++---- .../accuracy_tools/msprobe/test/pytorch_ut/test_service.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 26dddb91cd..ad567b48f7 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -78,7 +78,7 @@ class DataWriter: self.dump_file_path = dump_path_aggregation.dump_file_path self.stack_file_path = dump_path_aggregation.stack_file_path self.construct_file_path = dump_path_aggregation.construct_file_path - self.dump_tensor_data_dir = dump_path_aggregation.dump_data_dir + self.dump_tensor_data_dir = dump_path_aggregation.dump_tensor_data_dir self.free_benchmark_file_path = dump_path_aggregation.free_benchmark_file_path self.debug_file_path = dump_path_aggregation.debug_file_path diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py index 2b099942b2..d42335811c 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py @@ -110,11 +110,10 @@ class TestDataCollector(unittest.TestCase): mock_update_api_or_module_name.assert_called_with("name_with_count") mock_update_debug.assert_called_with({"name_with_count": "data_info"}) - @patch.object(DataWriter, "cache_debug") @patch.object(DataWriter, "update_debug") @patch.object(BaseDataProcessor, "analyze_debug_backward") @patch.object(BaseDataProcessor, "analyze_element_to_all_none", return_value = "all_none_data_info") - def test_debug_data_collect_forward(self, _, mock_analyze_debug_backward, mock_update_debug, mock_cache_debug): - self.data_collector.debug_data_collect_forward("variable", "name_with_count") + def test_debug_data_collect_backward(self, _, mock_analyze_debug_backward, mock_update_debug): + self.data_collector.debug_data_collect_backward("variable", "name_with_count") mock_update_debug.assert_called_with({"name_with_count": "all_none_data_info"}) - mock_analyze_debug_backward.assert_called_with("variable", "name_with_count", mock_cache_debug) + mock_analyze_debug_backward.assert_called_with("variable", "name_with_count", self.data_collector.data_writer.cache_debug) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py index 51b343fc2c..4b8e5204d5 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_service.py @@ -90,7 +90,7 @@ class TestService(unittest.TestCase): patch("msprobe.pytorch.service.api_register.initialize_hook") as mock_init_hook, \ patch("msprobe.pytorch.service.api_register.api_modularity") as mock_api_modularity: self.service.register_api_hook() - self.assertEqual(mock_logger.call_count, 1) + self.assertEqual(mock_logger.call_count, 2) mock_init_hook.assert_called_once() mock_api_modularity.assert_called_once() -- Gitee From 7c8aca2e0ce252e8b3bcb7bbbdebc5be374a1ce7 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Wed, 22 Jan 2025 17:44:09 +0800 Subject: [PATCH 08/20] bugfix --- .../msprobe/core/data_dump/data_collector.py | 2 +- .../core/data_dump/data_processor/base.py | 21 ++++++++----------- .../data_processor/mindspore_processor.py | 6 +++++- .../data_processor/pytorch_processor.py | 6 +++++- .../msprobe/core/data_dump/json_writer.py | 2 +- .../msprobe/mindspore/service.py | 4 +++- 6 files changed, 24 insertions(+), 17 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index ef676f119b..99fe0a72b2 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -188,7 +188,7 @@ class DataCollector: self.data_writer.update_debug({grad_name_with_count: all_none_data_info}) # register tensor backward hook - self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug) + self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug['data']) def params_data_collect(self, name, param_name, pid, data): grad_name = name + Const.SEP + Const.PARAMS_GRAD diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 615ae72266..ffe4f77c00 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -175,18 +175,6 @@ class BaseDataProcessor: def analyze_element_to_none(element, suffix_stack): return None - @staticmethod - def analyze_hook_single_element(element, suffix_stack, hook_fn): - if hasattr(element, "register_hook"): - # element might be mindspore.tensor or torch.tensor - if hasattr(element, "requires_grad") and not element.requires_grad: - return - indexes = copy.deepcopy(suffix_stack) - wrap_hook_fn = partial(hook_fn, index=indexes) - def real_hook_fn(grad): - return wrap_hook_fn(grad) - element.register_hook(real_hook_fn) - @staticmethod def _convert_numpy_to_builtin(arg): type_mapping = { @@ -291,6 +279,15 @@ class BaseDataProcessor: cls._recursive_key_stack.pop() return result_list + @classmethod + def analyze_hook_single_element(cls, element, suffix_stack, hook_fn): + if cls.is_hookable_element(element): + indexes = copy.deepcopy(suffix_stack) + wrap_hook_fn = partial(hook_fn, index=indexes) + def real_hook_fn(grad): + return wrap_hook_fn(grad) + element.register_hook(real_hook_fn) + def if_return_forward_new_output(self): return self._return_forward_new_output diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 3f3e547e32..aaf8312108 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -49,6 +49,10 @@ class MindsporeDataProcessor(BaseDataProcessor): def analyze_dtype_in_kwargs(element): return {"type": "mindspore.dtype", "value": str(element)} + @staticmethod + def is_hookable_element(element): + return hasattr(element, "register_hook") and callable(element.register_hook) + @classmethod def get_special_types(cls): return super().get_special_types() + cls.mindspore_special_type @@ -176,7 +180,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): api_info_struct = super().analyze_backward(name, module, module_input_output) self.maybe_save_overflow_data() return api_info_struct if self.has_overflow else None - + def analyze_params(self, name, param_name, grad): self.has_overflow = False api_info_struct = super().analyze_params(name, param_name, grad) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 045eb2b74a..39b5464a08 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -81,6 +81,10 @@ class PytorchDataProcessor(BaseDataProcessor): def analyze_dtype_in_kwargs(element): return {"type": "torch.dtype", "value": str(element)} + @staticmethod + def is_hookable_element(element): + return hasattr(element, "register_hook") and callable(element.register_hook) and element.requires_grad + @staticmethod def get_stat_info(data): tensor_stat = TensorStatInfo() @@ -275,7 +279,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): api_info_struct = super().analyze_backward(name, module, module_input_output) self.handle_overflow() return api_info_struct if self.has_overflow else None - + def analyze_params(self, name, param_name, grad): self.has_overflow = False self._is_support_inf_nan() diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index ad567b48f7..93fc46399a 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -109,7 +109,7 @@ class DataWriter: self.cache_construct.update(new_data) def update_debug(self, new_data): - self.cache_debug.update(new_data) + self.cache_debug['data'].update(new_data) def write_data_json(self, file_path): logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ") diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index a252d6564c..672ce8d536 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -362,7 +362,9 @@ class Service: dump_path_aggregation.free_benchmark_file_path = None dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") self.data_collector.update_dump_paths(dump_path_aggregation) - + self.data_collector.initialize_json_file( + framework=Const.MT_FRAMEWORK if is_mindtorch() else Const.MS_FRAMEWORK + ) self.debug_variable_counter = defaultdict(int) def need_end_service(self): -- Gitee From faeaf042bc84a7825e8e6aea52d0f080272a89af Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 10:25:50 +0800 Subject: [PATCH 09/20] add ut --- .../test/mindspore_ut/test_debug_save.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py new file mode 100644 index 0000000000..f2d5c23398 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import TestCase +from unittest.mock import patch +import mindspore + +from msprobe.mindspore import PrecisionDebugger + + + +class TestDebuggerSave(TestCase): + def setUp(self): + statistics_task_json = { + "task": "statistics", + "dump_path": "./dump_path", + "rank": [], + "step": [], + "level": "debug", + "enable_dataloader": False, + "statistics": { + "scope": [], + "list":[], + "data_mode": ["all"], + "summary_mode": "statistics" + } + } + with patch("msprobe.core.common.utils.load_json", return_value=statistics_task_json): + debugger = PrecisionDebugger() + def test_simple_case(): + def forward_func(x, y): + PrecisionDebugger.save(x, "x_tensor") + PrecisionDebugger.save(y, "y_tensor") + return x * y + x = mindspore.Tensor([1.]) + y = mindspore.Tensor([2.]) + result_json = {} + with patch("msprobe.core.common.utils.save_json") as mock_save_json: + forward_func(x, y) + mock_save_json.assert_called_once_with("./dump_path", result_json, 1) + + -- Gitee From b7c0e9eca47223d5ffad7126abfd7a02b193abff Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 11:08:29 +0800 Subject: [PATCH 10/20] fix added ut --- .../test/mindspore_ut/test_debug_save.py | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py index f2d5c23398..274c9aa298 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py @@ -19,7 +19,6 @@ import mindspore from msprobe.mindspore import PrecisionDebugger - class TestDebuggerSave(TestCase): def setUp(self): statistics_task_json = { @@ -36,18 +35,33 @@ class TestDebuggerSave(TestCase): "summary_mode": "statistics" } } - with patch("msprobe.core.common.utils.load_json", return_value=statistics_task_json): - debugger = PrecisionDebugger() - def test_simple_case(): + with patch("msprobe.mindspore.ms_config.load_json", return_value=statistics_task_json): + self.debugger = PrecisionDebugger() + def test_only_forward(self): def forward_func(x, y): PrecisionDebugger.save(x, "x_tensor") - PrecisionDebugger.save(y, "y_tensor") return x * y x = mindspore.Tensor([1.]) y = mindspore.Tensor([2.]) - result_json = {} - with patch("msprobe.core.common.utils.save_json") as mock_save_json: - forward_func(x, y) - mock_save_json.assert_called_once_with("./dump_path", result_json, 1) + result_json = { + "task": "statistics", + "level": "debug", + "framework": "mindspore", + "dump_data_dir": None, + "data": { + "x_tensor.0": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": (1,), + "Max": 1.0, + "Min": 1.0, + "Mean": 1.0, + "Norm": 1.0 + }, + "x_tensor_grad.0": None + } + } + forward_func(x, y) + self.assertEqual(self.debugger.service.data_collector.data_writer.cache_debug, result_json) -- Gitee From 493c871aab023364246e61c21d1e8cc07cdcbdd3 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 11:20:52 +0800 Subject: [PATCH 11/20] ut done --- .../msprobe/test/mindspore_ut/test_debug_save.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py index 274c9aa298..ec7988580f 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py @@ -37,7 +37,8 @@ class TestDebuggerSave(TestCase): } with patch("msprobe.mindspore.ms_config.load_json", return_value=statistics_task_json): self.debugger = PrecisionDebugger() - def test_only_forward(self): + + def test_forward_and_backward(self): def forward_func(x, y): PrecisionDebugger.save(x, "x_tensor") return x * y @@ -58,10 +59,19 @@ class TestDebuggerSave(TestCase): "Mean": 1.0, "Norm": 1.0 }, - "x_tensor_grad.0": None + "x_tensor_grad.0": { + "type": "mindspore.Tensor", + "dtype": "Float32", + "shape": (1,), + "Max": 2.0, + "Min": 2.0, + "Mean": 2.0, + "Norm": 2.0 + } } } - forward_func(x, y) + grad_fn = mindspore.value_and_grad(forward_func, (0, 1)) + grad_fn(x, y) self.assertEqual(self.debugger.service.data_collector.data_writer.cache_debug, result_json) -- Gitee From 5f7687291fe95517c0ecd7962757e863a44e3c17 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 15:14:56 +0800 Subject: [PATCH 12/20] add pt debug save ut --- .../data_dump/data_processor/test_base.py | 6 +- .../core_ut/data_dump/test_data_collector.py | 4 +- .../test/pytorch_ut/test_debug_save.py | 77 +++++++++++++++++++ 3 files changed, 84 insertions(+), 3 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py index d36ebda877..d994cfce01 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_base.py @@ -9,6 +9,7 @@ import numpy as np from msprobe.core.common.log import logger from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs, \ TensorStatInfo, BaseDataProcessor +from msprobe.core.data_dump.data_processor.mindspore_processor import MindsporeDataProcessor class TestModuleForwardInputsOutputs(unittest.TestCase): @@ -274,12 +275,13 @@ class TestBaseDataProcessor(unittest.TestCase): excepted_result = {"key1": [None, None, {"key2": None, "key3":[None]}]} self.assertEqual(result, excepted_result) - def test_analyze_hook_single_element(self): + @patch.object(MindsporeDataProcessor, "is_hookable_element", return_value=True) + def test_analyze_hook_single_element(self, _): element = MagicMock() element.hasattr = MagicMock(side_effect=lambda attr: attr == "register_hook") element.requires_grad = True hook_fn = MagicMock() - self.processor.analyze_hook_single_element(element, [1, 2], hook_fn) + MindsporeDataProcessor.analyze_hook_single_element(element, [1, 2], hook_fn) element.register_hook.assert_called_once() @patch("msprobe.core.data_dump.data_processor.base.partial") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py index d42335811c..a27daa3c8b 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py @@ -114,6 +114,8 @@ class TestDataCollector(unittest.TestCase): @patch.object(BaseDataProcessor, "analyze_debug_backward") @patch.object(BaseDataProcessor, "analyze_element_to_all_none", return_value = "all_none_data_info") def test_debug_data_collect_backward(self, _, mock_analyze_debug_backward, mock_update_debug): + self.data_collector.data_writer.cache_debug = {"data": None} self.data_collector.debug_data_collect_backward("variable", "name_with_count") mock_update_debug.assert_called_with({"name_with_count": "all_none_data_info"}) - mock_analyze_debug_backward.assert_called_with("variable", "name_with_count", self.data_collector.data_writer.cache_debug) + mock_analyze_debug_backward.assert_called_with("variable", "name_with_count", self.data_collector.data_writer.cache_debug['data']) + self.data_collector.data_writer.cache_debug = None diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py new file mode 100644 index 0000000000..b37e49adf6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import TestCase +from unittest.mock import patch +import torch + +from msprobe.pytorch import PrecisionDebugger + + +class TestDebuggerSave(TestCase): + def setUp(self): + statistics_task_json = { + "task": "statistics", + "dump_path": "./dump_path", + "rank": [], + "step": [], + "level": "debug", + "enable_dataloader": False, + "statistics": { + "scope": [], + "list":[], + "data_mode": ["all"], + "summary_mode": "statistics" + } + } + with patch("msprobe.mindspore.pt_config.load_json", return_value=statistics_task_json): + self.debugger = PrecisionDebugger() + + def test_forward_and_backward(self): + def forward_func(x, y): + PrecisionDebugger.save(x, "x_tensor") + return x * y + x = torch.tensor([1.]) + y = torch.tensor([2.]) + result_json = { + "task": "statistics", + "level": "debug", + "framework": "pytorch", + "dump_data_dir": None, + "data": { + "x_tensor.0": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": (1,), + "Max": 1.0, + "Min": 1.0, + "Mean": 1.0, + "Norm": 1.0 + }, + "x_tensor_grad.0": { + "type": "torch.Tensor", + "dtype": "torch.float32", + "shape": (1,), + "Max": 2.0, + "Min": 2.0, + "Mean": 2.0, + "Norm": 2.0 + } + } + } + loss = forward_func(x, y) + loss.backward() + self.assertEqual(self.debugger.service.data_collector.data_writer.cache_debug, result_json) + + -- Gitee From 42caf16419b202ed231c4e6d7da2e9756b4a2623 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 15:24:42 +0800 Subject: [PATCH 13/20] fix pt ut --- debug/accuracy_tools/msprobe/pytorch/service.py | 1 + .../msprobe/test/pytorch_ut/test_debug_save.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index b9e3fb0cd3..5f85c20206 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -373,6 +373,7 @@ class Service: dump_path_aggregation.free_benchmark_file_path = None dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") self.data_collector.update_dump_paths(dump_path_aggregation) + self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) self.debug_variable_counter = defaultdict(int) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py index b37e49adf6..ad2b3d36f3 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py @@ -35,7 +35,7 @@ class TestDebuggerSave(TestCase): "summary_mode": "statistics" } } - with patch("msprobe.mindspore.pt_config.load_json", return_value=statistics_task_json): + with patch("msprobe.pytorch.pt_config.load_json", return_value=statistics_task_json): self.debugger = PrecisionDebugger() def test_forward_and_backward(self): @@ -44,6 +44,8 @@ class TestDebuggerSave(TestCase): return x * y x = torch.tensor([1.]) y = torch.tensor([2.]) + x.requires_grad = True + y.requires_grad = True result_json = { "task": "statistics", "level": "debug", @@ -53,20 +55,22 @@ class TestDebuggerSave(TestCase): "x_tensor.0": { "type": "torch.Tensor", "dtype": "torch.float32", - "shape": (1,), + "shape": torch.Size([1]), "Max": 1.0, "Min": 1.0, "Mean": 1.0, - "Norm": 1.0 + "Norm": 1.0, + "requires_grad": True }, "x_tensor_grad.0": { "type": "torch.Tensor", "dtype": "torch.float32", - "shape": (1,), + "shape": torch.Size([1]), "Max": 2.0, "Min": 2.0, "Mean": 2.0, - "Norm": 2.0 + "Norm": 2.0, + "requires_grad": False } } } -- Gitee From c50a3cc012579428ec4cfc3e7c807ae87ed1e7b0 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 15:39:02 +0800 Subject: [PATCH 14/20] rename ut --- .../mindspore_ut/{test_debug_save.py => test_ms_debug_save.py} | 2 +- .../pytorch_ut/{test_debug_save.py => test_pt_debug_save.py} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename debug/accuracy_tools/msprobe/test/mindspore_ut/{test_debug_save.py => test_ms_debug_save.py} (98%) rename debug/accuracy_tools/msprobe/test/pytorch_ut/{test_debug_save.py => test_pt_debug_save.py} (98%) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py similarity index 98% rename from debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py index ec7988580f..9e157b75f1 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py @@ -19,7 +19,7 @@ import mindspore from msprobe.mindspore import PrecisionDebugger -class TestDebuggerSave(TestCase): +class TestMindsporeDebuggerSave(TestCase): def setUp(self): statistics_task_json = { "task": "statistics", diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py similarity index 98% rename from debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py rename to debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py index ad2b3d36f3..cc4c57eaf2 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py @@ -19,7 +19,7 @@ import torch from msprobe.pytorch import PrecisionDebugger -class TestDebuggerSave(TestCase): +class TestPytorchDebuggerSave(TestCase): def setUp(self): statistics_task_json = { "task": "statistics", -- Gitee From ea521397c09bd856997ec6098a09aaf52f547256 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 15:53:55 +0800 Subject: [PATCH 15/20] fix ms ut --- .../msprobe/test/mindspore_ut/test_ms_debug_save.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py index 9e157b75f1..72c38e9a18 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py @@ -20,7 +20,7 @@ from msprobe.mindspore import PrecisionDebugger class TestMindsporeDebuggerSave(TestCase): - def setUp(self): + def test_forward_and_backward(self): statistics_task_json = { "task": "statistics", "dump_path": "./dump_path", @@ -37,8 +37,6 @@ class TestMindsporeDebuggerSave(TestCase): } with patch("msprobe.mindspore.ms_config.load_json", return_value=statistics_task_json): self.debugger = PrecisionDebugger() - - def test_forward_and_backward(self): def forward_func(x, y): PrecisionDebugger.save(x, "x_tensor") return x * y -- Gitee From 7115d0a64320083ef7c0919cfffaf375fa6369e1 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 16:43:47 +0800 Subject: [PATCH 16/20] PrecisionDebugger._instance = None at ut setup --- .../msprobe/test/mindspore_ut/test_ms_debug_save.py | 12 +++++++++--- .../msprobe/test/pytorch_ut/test_pt_debug_save.py | 6 +++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py index 72c38e9a18..d35e8c076b 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py @@ -17,10 +17,12 @@ from unittest.mock import patch import mindspore from msprobe.mindspore import PrecisionDebugger - +from msprobe.core.common_config import CommonConfig, BaseConfig class TestMindsporeDebuggerSave(TestCase): - def test_forward_and_backward(self): + def setUp(self): + PrecisionDebugger._instance = None + mindspore.set_context(mode=mindspore.PYNATIVE_MODE) statistics_task_json = { "task": "statistics", "dump_path": "./dump_path", @@ -35,8 +37,12 @@ class TestMindsporeDebuggerSave(TestCase): "summary_mode": "statistics" } } - with patch("msprobe.mindspore.ms_config.load_json", return_value=statistics_task_json): + common_config = CommonConfig(statistics_task_json) + task_config = BaseConfig(statistics_task_json) + with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", return_value=(common_config, task_config)): self.debugger = PrecisionDebugger() + + def test_forward_and_backward(self): def forward_func(x, y): PrecisionDebugger.save(x, "x_tensor") return x * y diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py index cc4c57eaf2..2de8620823 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_debug_save.py @@ -17,10 +17,12 @@ from unittest.mock import patch import torch from msprobe.pytorch import PrecisionDebugger +from msprobe.core.common_config import CommonConfig, BaseConfig class TestPytorchDebuggerSave(TestCase): def setUp(self): + PrecisionDebugger._instance = None statistics_task_json = { "task": "statistics", "dump_path": "./dump_path", @@ -35,7 +37,9 @@ class TestPytorchDebuggerSave(TestCase): "summary_mode": "statistics" } } - with patch("msprobe.pytorch.pt_config.load_json", return_value=statistics_task_json): + common_config = CommonConfig(statistics_task_json) + task_config = BaseConfig(statistics_task_json) + with patch("msprobe.pytorch.debugger.precision_debugger.parse_json_config", return_value=(common_config, task_config)): self.debugger = PrecisionDebugger() def test_forward_and_backward(self): -- Gitee From 3d76940d2e3e721ff7b964c489652f0e819ec398 Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Thu, 23 Jan 2025 17:34:49 +0800 Subject: [PATCH 17/20] add patch --- .../msprobe/test/mindspore_ut/test_ms_debug_save.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py index d35e8c076b..9040895be5 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_debug_save.py @@ -39,7 +39,8 @@ class TestMindsporeDebuggerSave(TestCase): } common_config = CommonConfig(statistics_task_json) task_config = BaseConfig(statistics_task_json) - with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", return_value=(common_config, task_config)): + with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", return_value=(common_config, task_config)), \ + patch("msprobe.mindspore.debugger.precision_debugger.set_register_backward_hook_functions"): self.debugger = PrecisionDebugger() def test_forward_and_backward(self): -- Gitee From 5b4c005a09c06b5809b038fd8fc465e613c63bda Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Fri, 24 Jan 2025 14:29:10 +0800 Subject: [PATCH 18/20] api hook fix --- debug/accuracy_tools/msprobe/mindspore/service.py | 6 +++++- debug/accuracy_tools/msprobe/pytorch/service.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 1cef0888de..42494f3a40 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -68,7 +68,8 @@ class Service: self.should_stop_service = False self.params_grad_info = {} # 提前注册,确保注册尽可能多的API hook - self.register_api_hook() + if self.need_register_hook(): + self.register_api_hook() if self.config.level == Const.LEVEL_DEBUG: self.init_for_debug_level() @@ -91,6 +92,9 @@ class Service: module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output) return module_input_output + def need_register_hook(self): + return self.config.level != Const.LEVEL_DEBUG + def build_hook(self, target_type, name): def pre_hook(api_or_cell_name, cell, input_data): if not self.should_execute_hook(target_type, cell, True): diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 5f85c20206..76d336842a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -57,9 +57,15 @@ class Service: self.should_stop_service = False self.attl = None self.params_grad_info = {} + # 提前注册,确保注册尽可能多的API hook + if self.need_register_hook(): + self.register_api_hook() if self.config.level == Const.LEVEL_DEBUG: self.init_for_debug_level() + def need_register_hook(self): + return self.config.level != Const.LEVEL_DEBUG + def build_hook(self, module_type, name): def pre_hook(api_or_module_name, module, args, kwargs): if not self.should_execute_hook(module_type, module, True): -- Gitee From 85f61081f66b76c054e1f1d66b9bc599d381774b Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Mon, 27 Jan 2025 11:22:36 +0800 Subject: [PATCH 19/20] code review --- .../msprobe/core/common/utils.py | 10 +++ .../msprobe/core/data_dump/data_collector.py | 12 +--- .../core/data_dump/data_processor/base.py | 72 ++++++++----------- .../data_processor/mindspore_processor.py | 2 +- .../data_processor/pytorch_processor.py | 2 +- .../msprobe/core/data_dump/json_writer.py | 12 ++-- .../msprobe/mindspore/common/utils.py | 19 +++++ .../mindspore/debugger/precision_debugger.py | 8 ++- .../msprobe/mindspore/service.py | 19 ++--- .../msprobe/pytorch/common/utils.py | 18 +++++ .../pytorch/debugger/precision_debugger.py | 7 +- .../accuracy_tools/msprobe/pytorch/service.py | 25 ++----- 12 files changed, 111 insertions(+), 95 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 6d5f9e0fd7..b819142b64 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -495,3 +495,13 @@ def check_str_param(param): if not re.match(Const.REGEX_PREFIX_PATTERN, param): logger.error('The parameter {} contains special characters.'.format(param)) raise MsprobeBaseException(MsprobeBaseException.INVALID_CHAR_ERROR) + + +class DumpPathAggregation: + dump_file_path = None + stack_file_path = None + construct_file_path = None + dump_tensor_data_dir = None + free_benchmark_file_path = None + debug_file_path = None + diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index 99fe0a72b2..f8b5d9541a 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -23,15 +23,6 @@ from msprobe.core.common.const import Const from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory -class DumpPathAggregation: - dump_file_path: str - stack_file_path: str - construct_file_path: str - dump_tensor_data_dir: str - free_benchmark_file_path: str - debug_file_path = None - - def build_data_collector(config): return DataCollector(config) @@ -178,8 +169,7 @@ class DataCollector: self.data_processor.update_iter(current_iter) def debug_data_collect_forward(self, variable, name_with_count): - self.update_api_or_module_name(name_with_count) - data_info = self.data_processor.analyze_debug_forward(variable) + data_info = self.data_processor.analyze_debug_forward(variable, name_with_count) self.data_writer.update_debug({name_with_count: data_info}) def debug_data_collect_backward(self, variable, grad_name_with_count): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index ffe4f77c00..6818d47d4b 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -106,7 +106,6 @@ class BaseDataProcessor: self.save_name = None if hasattr(config, "data_mode"): self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode) - self.custom_dump_data_name = None @property def data_path(self): @@ -146,34 +145,29 @@ class BaseDataProcessor: return data @staticmethod - def set_value_into_nested_structure(data_structure, index, value): + def set_value_into_nested_structure(data_structure, indexes, value): ''' Args: data_structure: nested data structure - index: List[str] + indexes: List value: value to be set ''' - if not index: - raise ValueError(f"index need to be non empty when set value to nested data structure") + if not indexes: + raise ValueError(f"set_value_into_nested_structure failed: \ + indexes need to be non empty when set value to nested data structure") current_level = data_structure - for i in index[:-1]: - if isinstance(current_level, list): - current_level = current_level[int(i)] - elif isinstance(current_level, dict): - current_level = current_level[i] + for i, index in enumerate(indexes): + valid_for_list = isinstance(current_level, list) and isinstance(index, int) and len(current_level) > index + valid_for_dict = isinstance(current_level, dict) and index in current_level + is_last = i == len(index)-1 + if valid_for_dict or valid_for_list: + if is_last: + current_level = value + else: + current_level = current_level[index] else: - raise ValueError(f"Unsupported type in nested structure: {type(current_level)}") - - if isinstance(current_level, list): - current_level[int(index[-1])] = value - elif isinstance(current_level, dict): - current_level[index[-1]] = value - else: - raise ValueError(f"Unsupported type for final assignment: {type(current_level)}") - - @staticmethod - def analyze_element_to_none(element, suffix_stack): - return None + raise ValueError(f"set_value_into_nested_structure failed: \ + invalid data_structure type(dict or list) or invalid index") @staticmethod def _convert_numpy_to_builtin(arg): @@ -265,7 +259,7 @@ class BaseDataProcessor: def apply_transform_dict(cls, args, transform, depth): result_dict = {} for k, arg in args.items(): - cls._recursive_key_stack.append(str(k)) + cls._recursive_key_stack.append(k) result_dict[k] = cls.recursive_apply_transform(arg, transform, depth=depth + 1) cls._recursive_key_stack.pop() return result_dict @@ -274,13 +268,13 @@ class BaseDataProcessor: def apply_transform_list(cls, args, transform, depth): result_list = [] for i, arg in enumerate(args): - cls._recursive_key_stack.append(str(i)) + cls._recursive_key_stack.append(i) result_list.append(cls.recursive_apply_transform(arg, transform, depth=depth + 1)) cls._recursive_key_stack.pop() return result_list @classmethod - def analyze_hook_single_element(cls, element, suffix_stack, hook_fn): + def register_hook_single_element(cls, element, suffix_stack, hook_fn): if cls.is_hookable_element(element): indexes = copy.deepcopy(suffix_stack) wrap_hook_fn = partial(hook_fn, index=indexes) @@ -319,26 +313,26 @@ class BaseDataProcessor: return self.recursive_apply_transform(element, self.analyze_single_element) def analyze_element_to_all_none(self, element): - return self.recursive_apply_transform(element, self.analyze_element_to_none) + return self.recursive_apply_transform(element, lambda element, stack: None) - def analyze_debug_forward(self, variable): - self.api_data_category = Const.OUTPUT + def analyze_debug_forward(self, variable, name_with_count): + self.current_api_or_module_name = name_with_count + self.api_data_category = Const.TENSOR + # these two attributes are used to construct tensor file name {name_with_count}.tensor.{indexes}.npy/pt data_info = self.analyze_element(variable) return data_info def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure): - def hook_fn(grad, index): - self.api_data_category = Const.OUTPUT - file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX - suffix = Const.SEP.join(index) - self.custom_dump_data_name = (grad_name_with_count + Const.SEP + Const.OUTPUT + Const.SEP + suffix + file_format) + def hook_fn(grad, indexes): + suffix = Const.SEP.join([str(index) for index in indexes]) + self.save_name = grad_name_with_count + Const.SEP + Const.TENSOR + Const.SEP + suffix grad_data_info = self.analyze_element(grad) - self.custom_dump_data_name = None - full_index = [grad_name_with_count] + index + self.save_name = None + full_index = [grad_name_with_count] + indexes self.set_value_into_nested_structure(nested_data_structure, full_index, grad_data_info) return grad - wrap_analyze_hook_single_element = partial(self.analyze_hook_single_element, hook_fn=hook_fn) - self.recursive_apply_transform(variable, wrap_analyze_hook_single_element) + wrap_register_hook_single_element = partial(self.register_hook_single_element, hook_fn=hook_fn) + self.recursive_apply_transform(variable, wrap_register_hook_single_element) def analyze_forward_input(self, name, module, module_input_output: ModuleForwardInputsOutputs): api_info_struct = {} @@ -437,10 +431,6 @@ class BaseDataProcessor: return api_info_struct def get_save_file_path(self, suffix): - if self.custom_dump_data_name: - dump_data_name = self.custom_dump_data_name - file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) - return dump_data_name, file_path file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX if self.save_name is not None: dump_data_name = (self.save_name + file_format) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 1388c2cc43..91826e0bf3 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -106,7 +106,7 @@ class MindsporeDataProcessor(BaseDataProcessor): if isinstance(element, Number): return self.analyze_dtype_in_kwargs(element) if isinstance(element, ms.Tensor): - return self._analyze_tensor(element, Const.SEP.join(suffix_stack)) + return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): return self._analyze_builtin(element) return {} diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 39b5464a08..5e72b7dea1 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -180,7 +180,7 @@ class PytorchDataProcessor(BaseDataProcessor): if converted_numpy is not element: return self._analyze_numpy(converted_numpy, numpy_type) if isinstance(element, torch.Tensor): - return self._analyze_tensor(element, Const.SEP.join(suffix_stack)) + return self._analyze_tensor(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): return self._analyze_builtin(element) return {} diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 93fc46399a..4d8dbaface 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -59,6 +59,13 @@ class DataWriter: self.cache_debug = {} def initialize_json_file(self, **kwargs): + if self.debug_file_path and not self.cache_debug: + # debug level case only create debug.json + debug_dict = copy.deepcopy(kwargs) + debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) + self.cache_debug = debug_dict + save_json(self.debug_file_path, self.cache_debug, indent=1) + return if not self.cache_data: kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) self.cache_data = kwargs @@ -67,11 +74,6 @@ class DataWriter: save_json(self.stack_file_path, self.cache_stack, indent=1) if not self.cache_construct: save_json(self.construct_file_path, self.cache_construct, indent=1) - if self.debug_file_path and not self.cache_debug: - debug_dict = copy.deepcopy(kwargs) - debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) - self.cache_debug = debug_dict - save_json(self.debug_file_path, self.cache_debug, indent=1) def update_dump_paths(self, dump_path_aggregation): diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index 6d837c8511..ebf3ac85fe 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -179,3 +179,22 @@ def set_register_backward_hook_functions(): else: register_backward_hook_functions["pre"] = ms.nn.Cell.register_backward_pre_hook register_backward_hook_functions["full"] = ms.nn.Cell.register_backward_hook + + +def check_save_param(variable, name, save_backward): + # try catch this api to skip invalid call + if not isinstance(variable, (list, dict, ms.Tensor, int, float, str)): + logger.warning("PrecisionDebugger.save variable type not valid, \ + should be one of list, dict, ms.Tensor, int, float or string. \ + skip current save process") + raise ValueError + if not isinstance(name, str): + logger.warning("PrecisionDebugger.save name not valid, \ + should be string. \ + skip current save process") + raise ValueError + if not isinstance(save_backward, bool): + logger.warning("PrecisionDebugger.save_backward name not valid, \ + should be bool. \ + skip current save process") + raise ValueError \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 853caa2987..795485d204 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -22,7 +22,7 @@ from mindspore._c_expression import MSContext from msprobe.core.common.const import Const, MsgConst from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.const import Const as MsConst -from msprobe.mindspore.common.utils import set_register_backward_hook_functions +from msprobe.mindspore.common.utils import set_register_backward_hook_functions, check_save_param from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.hook_cell.api_registry import api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell @@ -179,7 +179,11 @@ class PrecisionDebugger: instance = cls._instance if not instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if instance.task not in [Const.TENSOR, Const.STATISTICS]: + if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level != Const.LEVEL_DEBUG: + return + try: + check_save_param(variable, name, save_backward) + except ValueError: return instance.config.execution_mode = cls._get_execution_mode() diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 42494f3a40..107c545541 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -31,8 +31,8 @@ else: from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import Const, print_tools_ends_info -from msprobe.core.data_dump.data_collector import build_data_collector, DumpPathAggregation +from msprobe.core.common.utils import Const, print_tools_ends_info, DumpPathAggregation +from msprobe.core.data_dump.data_collector import build_data_collector from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, ModuleBackwardInputs from msprobe.core.data_dump.scope import BaseScope from msprobe.mindspore.cell_processor import CellProcessor @@ -68,8 +68,7 @@ class Service: self.should_stop_service = False self.params_grad_info = {} # 提前注册,确保注册尽可能多的API hook - if self.need_register_hook(): - self.register_api_hook() + self.register_api_hook() if self.config.level == Const.LEVEL_DEBUG: self.init_for_debug_level() @@ -92,9 +91,6 @@ class Service: module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output) return module_input_output - def need_register_hook(self): - return self.config.level != Const.LEVEL_DEBUG - def build_hook(self, target_type, name): def pre_hook(api_or_cell_name, cell, input_data): if not self.should_execute_hook(target_type, cell, True): @@ -355,7 +351,6 @@ class Service: self.current_rank = None # dir: dump_path -- rank{} -- debug.json - create_directory(self.config.dump_path) self.dump_iter_dir = self.config.dump_path cur_rank = self.current_rank if self.current_rank is not None else '' dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") @@ -366,12 +361,8 @@ class Service: else: dump_data_dir = None - dump_path_aggregation = DumpPathAggregation - dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") - dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") - dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") + dump_path_aggregation = DumpPathAggregation() dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.free_benchmark_file_path = None dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") self.data_collector.update_dump_paths(dump_path_aggregation) self.data_collector.initialize_json_file( @@ -419,7 +410,7 @@ class Service: else: dump_data_dir = None - dump_path_aggregation = DumpPathAggregation + dump_path_aggregation = DumpPathAggregation() dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 3eed56300f..4f6a60b0ab 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -402,3 +402,21 @@ def load_api_data(api_data_bytes): except Exception as e: raise RuntimeError(f"load api_data from bytes failed") from e return buffer + +def check_save_param(variable, name, save_backward): + # try catch this api to skip invalid call + if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)): + logger.warning("PrecisionDebugger.save variable type not valid, \ + should be one of list, dict, torch.Tensor, int, float or string. \ + skip current save process") + raise ValueError + if not isinstance(name, str): + logger.warning("PrecisionDebugger.save name not valid, \ + should be string. \ + skip current save process") + raise ValueError + if not isinstance(save_backward, bool): + logger.warning("PrecisionDebugger.save save_backward not valid, \ + should be bool. \ + skip current save process") + raise ValueError \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index a31444383c..13e5999660 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -21,6 +21,7 @@ from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import FileChecker from msprobe.core.common.utils import get_real_step_or_rank from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.utils import check_save_param from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor from msprobe.pytorch.pt_config import parse_json_config @@ -128,7 +129,11 @@ class PrecisionDebugger: instance = cls._instance if not instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if instance.task not in [Const.TENSOR, Const.STATISTICS]: + if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level != Const.LEVEL_DEBUG: + return + try: + check_save_param(variable, name, save_backward) + except ValueError: return instance.service.save(variable, name, save_backward) diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 76d336842a..b1153e1d24 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -22,8 +22,8 @@ import torch from msprobe.core.common.const import Const from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import create_directory -from msprobe.core.common.utils import print_tools_ends_info -from msprobe.core.data_dump.data_collector import build_data_collector, DumpPathAggregation +from msprobe.core.common.utils import print_tools_ends_info, DumpPathAggregation +from msprobe.core.data_dump.data_collector import build_data_collector from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs from msprobe.core.data_dump.scope import BaseScope from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData @@ -58,14 +58,10 @@ class Service: self.attl = None self.params_grad_info = {} # 提前注册,确保注册尽可能多的API hook - if self.need_register_hook(): - self.register_api_hook() + self.register_api_hook() if self.config.level == Const.LEVEL_DEBUG: self.init_for_debug_level() - def need_register_hook(self): - return self.config.level != Const.LEVEL_DEBUG - def build_hook(self, module_type, name): def pre_hook(api_or_module_name, module, args, kwargs): if not self.should_execute_hook(module_type, module, True): @@ -326,12 +322,12 @@ class Service: else: dump_data_dir = None - dump_path_aggregation = DumpPathAggregation + dump_path_aggregation = DumpPathAggregation() dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.free_benchmark_file_path = None + dump_path_aggregation.free_benchmark_file_path = os.path.join(dump_dir, "free_benchmark.csv") self.data_collector.update_dump_paths(dump_path_aggregation) self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) @@ -360,7 +356,6 @@ class Service: self.current_rank = None # dir: dump_path -- rank{} -- debug.json - create_directory(self.config.dump_path) self.dump_iter_dir = self.config.dump_path cur_rank = self.current_rank if self.current_rank is not None else '' dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") @@ -371,12 +366,8 @@ class Service: else: dump_data_dir = None - dump_path_aggregation = DumpPathAggregation - dump_path_aggregation.dump_file_path = os.path.join(dump_dir, "dump.json") - dump_path_aggregation.stack_file_path = os.path.join(dump_dir, "stack.json") - dump_path_aggregation.construct_file_path = os.path.join(dump_dir, "construct.json") + dump_path_aggregation = DumpPathAggregation() dump_path_aggregation.dump_tensor_data_dir = dump_data_dir - dump_path_aggregation.free_benchmark_file_path = None dump_path_aggregation.debug_file_path = os.path.join(dump_dir, "debug.json") self.data_collector.update_dump_paths(dump_path_aggregation) self.data_collector.initialize_json_file(framework=Const.PT_FRAMEWORK) @@ -385,10 +376,6 @@ class Service: def register_api_hook(self): - logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task)) - if self.config.level in [Const.LEVEL_L0, Const.LEVEL_MIX]: - self.module_processor.hook_modules(self.model, self.build_hook) - if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1, Const.LEVEL_L2]: logger.info_on_rank_0(f"The api {self.config.task} hook function is successfully mounted to the model.") api_register.initialize_hook( -- Gitee From 7bd9949d91ae30f7a5a6e80a35b047043193b68b Mon Sep 17 00:00:00 2001 From: qianzhengxin Date: Wed, 5 Feb 2025 15:20:38 +0800 Subject: [PATCH 20/20] fix --- .../core/data_dump/data_processor/base.py | 6 +++--- .../msprobe/mindspore/common/utils.py | 18 +++++++++--------- .../mindspore/debugger/precision_debugger.py | 2 +- .../msprobe/pytorch/common/utils.py | 18 +++++++++--------- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 6818d47d4b..4cc939abd4 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -159,10 +159,10 @@ class BaseDataProcessor: for i, index in enumerate(indexes): valid_for_list = isinstance(current_level, list) and isinstance(index, int) and len(current_level) > index valid_for_dict = isinstance(current_level, dict) and index in current_level - is_last = i == len(index)-1 + is_last = i == len(indexes)-1 if valid_for_dict or valid_for_list: if is_last: - current_level = value + current_level[index] = value else: current_level = current_level[index] else: @@ -277,7 +277,7 @@ class BaseDataProcessor: def register_hook_single_element(cls, element, suffix_stack, hook_fn): if cls.is_hookable_element(element): indexes = copy.deepcopy(suffix_stack) - wrap_hook_fn = partial(hook_fn, index=indexes) + wrap_hook_fn = partial(hook_fn, indexes=indexes) def real_hook_fn(grad): return wrap_hook_fn(grad) element.register_hook(real_hook_fn) diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index ebf3ac85fe..5c713e02c0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -184,17 +184,17 @@ def set_register_backward_hook_functions(): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call if not isinstance(variable, (list, dict, ms.Tensor, int, float, str)): - logger.warning("PrecisionDebugger.save variable type not valid, \ - should be one of list, dict, ms.Tensor, int, float or string. \ - skip current save process") + logger.warning("PrecisionDebugger.save variable type not valid, " + "should be one of list, dict, ms.Tensor, int, float or string. " + "Skip current save process.") raise ValueError if not isinstance(name, str): - logger.warning("PrecisionDebugger.save name not valid, \ - should be string. \ - skip current save process") + logger.warning("PrecisionDebugger.save name not valid, " + "should be string. " + "skip current save process.") raise ValueError if not isinstance(save_backward, bool): - logger.warning("PrecisionDebugger.save_backward name not valid, \ - should be bool. \ - skip current save process") + logger.warning("PrecisionDebugger.save_backward name not valid, " + "should be bool. " + "Skip current save process.") raise ValueError \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 795485d204..33bed1f06f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -179,7 +179,7 @@ class PrecisionDebugger: instance = cls._instance if not instance: raise Exception(MsgConst.NOT_CREATED_INSTANCE) - if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level != Const.LEVEL_DEBUG: + if instance.task not in [Const.TENSOR, Const.STATISTICS] or instance.config.level_ori != Const.LEVEL_DEBUG: return try: check_save_param(variable, name, save_backward) diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py index 4f6a60b0ab..7aef707a0d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py @@ -406,17 +406,17 @@ def load_api_data(api_data_bytes): def check_save_param(variable, name, save_backward): # try catch this api to skip invalid call if not isinstance(variable, (list, dict, torch.Tensor, int, float, str)): - logger.warning("PrecisionDebugger.save variable type not valid, \ - should be one of list, dict, torch.Tensor, int, float or string. \ - skip current save process") + logger.warning("PrecisionDebugger.save variable type not valid, " + "should be one of list, dict, torch.Tensor, int, float or string. " + "Skip current save process.") raise ValueError if not isinstance(name, str): - logger.warning("PrecisionDebugger.save name not valid, \ - should be string. \ - skip current save process") + logger.warning("PrecisionDebugger.save name not valid, " + "should be string. " + "skip current save process.") raise ValueError if not isinstance(save_backward, bool): - logger.warning("PrecisionDebugger.save save_backward not valid, \ - should be bool. \ - skip current save process") + logger.warning("PrecisionDebugger.save_backward name not valid, " + "should be bool. " + "Skip current save process.") raise ValueError \ No newline at end of file -- Gitee