diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py
index 8d24917a2c6b7b086180217af23565df3a4699a7..0ae0e74fbc2fdeecb389201dce5754bcf554c87a 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py
@@ -421,15 +421,71 @@ def module_count_func(name, name_template):
     return index
 
 
+def retrieve_golden(item_to_replace, golden_path):
+    """Load the golden value saved at ``golden_path`` to stand in for ``item_to_replace``.
+
+    A tensor input yields the golden tensor moved to the input's device; any other
+    input yields the Python scalar stored in the golden file. Errors raised while
+    loading or converting the file propagate to the caller.
+    """
+    golden_tensor = torch.tensor(np.load(golden_path))
+    if isinstance(item_to_replace, torch.Tensor):
+        return golden_tensor.to(item_to_replace.device)
+    return golden_tensor.item()
+
+
+def replace_input_with_golden(name_template, in_feat, dump_step, module, golden_prefix=None):
+    """Replace the inputs of an API call with golden data loaded from ``.npy`` files.
+
+    Args:
+        name_template: dump name of the input; ``<name_template>.npy`` is the file looked up.
+        in_feat: the input. Tuples and lists are walked recursively, appending the
+            element index to the name at each level.
+        dump_step: not used here; kept for signature compatibility.
+        module: not used here; kept for signature compatibility.
+        golden_prefix: directory holding the golden files. Defaults to the
+            ``PTDBG_GOLDEN_PATH`` environment variable.
+
+    Returns:
+        ``in_feat`` with every element that has a loadable golden file replaced.
+        The input is returned untouched when no golden directory is configured.
+    """
+    if golden_prefix is None:
+        golden_prefix = os.environ.get("PTDBG_GOLDEN_PATH")
+    if not golden_prefix:
+        # Replacement is opt-in: without a golden directory this is a no-op.
+        return in_feat
+    if isinstance(in_feat, (tuple, list)):
+        replaced = [
+            replace_input_with_golden(
+                "{}.{}".format(name_template, idx), feat, dump_step, module, golden_prefix
+            )
+            for idx, feat in enumerate(in_feat)
+        ]
+        return tuple(replaced) if isinstance(in_feat, tuple) else replaced
+    golden_path = os.path.join(golden_prefix, name_template + ".npy")
+    # Each leaf is handled on its own so that one missing file does not stop the
+    # remaining inputs from being replaced.
+    try:
+        return retrieve_golden(in_feat, golden_path)
+    except FileNotFoundError as e:
+        print("[WARNING] GOLDEN NOT FOUND: ", e)
+    except ValueError as e:
+        print("[WARNING] OP NOT EXIST: ", e)
+    return in_feat
+
+
 def acc_cmp_dump(name, **kwargs):
     dump_step = kwargs.get('dump_step', 1)
     pid = kwargs.get('pid')
+    pre_forward_hook = kwargs.get('pre_forward_hook', False)
     name_template = name
     if not pid:
         return RuntimeError("Not get the specified process pid.")
 
     def acc_cmp_hook(module, in_feat, out_feat=None):
         nonlocal name, name_template
+        replaced_input = None
         place_holder = Const.DELIMITER + "{}" + Const.DELIMITER
         if place_holder in name_template:
             try:
@@ -439,11 +495,19 @@ def acc_cmp_dump(name, **kwargs):
                 raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
             name = name_template.format(index)
         if pid == os.getpid():
-            dump_acc_cmp(name, in_feat, out_feat, dump_step, module)
+            if pre_forward_hook:
+                # Build the input name in a local: rebinding the nonlocal name_template
+                # here would drop its index placeholder after the first call.
+                input_name = name + Const.DELIMITER + "input"
+                replaced_input = replace_input_with_golden(input_name, in_feat, dump_step, module)
+            else:
+                dump_acc_cmp(name, in_feat, out_feat, dump_step, module)
         if hasattr(module, "input_args"):
             del module.input_args
         if hasattr(module, "input_kwargs"):
             del module.input_kwargs
+
+        return replaced_input
 
     return acc_cmp_hook
 
diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/hook_module.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/hook_module.py
index 1b3e7d37f289423aef98b049ffc66cf2b400351f..e56c42c677d4294425674ed753a1367fea8cb51d 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/hook_module.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/hook_module.py
@@ -47,6 +47,7 @@ class HOOKModule(nn.Module):
         else:
             HOOKModule.module_count[self.prefix] += 1
         self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.DELIMITER
+        self.register_forward_pre_hook(hook(self.prefix + Const.FORWARD, pre_forward_hook=1))
         self.register_forward_hook(hook(self.prefix + Const.FORWARD))
         self.register_backward_hook(hook(self.prefix + Const.BACKWARD))
 