diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index ea06819ca35e9fd8fd078a67ec87da9310e4b99b..4569688a1000c3997a2fac622eb4d76b96864668 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -228,7 +228,7 @@ class PytorchDataProcessor(BaseDataProcessor): if isinstance(element, dist.ProcessGroup): return self._analyze_process_group(element) if isinstance(element, dist.P2POp): - return self._analyze_p2pop(element) + return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, dist.ReduceOp): return self._analyze_reduce_op(element) converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) @@ -247,10 +247,10 @@ class PytorchDataProcessor(BaseDataProcessor): module_input_output.update_output_with_args_and_kwargs() return super().analyze_forward_output(name, module, module_input_output) - def _analyze_p2pop(self, arg): + def _analyze_p2pop(self, arg, suffix): p2pop_info = {"class_type": "torch.distributed.P2POp"} try: - tensor_info = self._analyze_tensor(arg.tensor, []) + tensor_info = self._analyze_tensor(arg.tensor, suffix) p2pop_info.update({"tensor": tensor_info}) p2pop_info.update({"op": arg.op.__name__}) p2pop_info.update({"peer": arg.peer})