diff --git a/debug/accuracy_tools/msprobe/pytorch/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/module_processer.py index f9368a087458ef62b875ccbad7c9d3c101aea7ec..9cee721db9b4ea6fce0969a22cac409bf71859e7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/module_processer.py @@ -26,9 +26,15 @@ class ModuleProcesser: def filter_tensor_and_tuple(func): @wraps(func) def wrap_by_filter_tensor_and_tuple(*args, **kwargs): - # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是非tensor数据不传入 + # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是解析非tensor数据的属性,对tensor属性挂hook # setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1] if not isinstance(args[1], (torch.Tensor, tuple)): + for item_str in dir(args[1]): + item = getattr(args[1], item_str) + if isinstance(item, (torch.Tensor, tuple)): + args_new = (args[0], item) + result = func(*args_new, **kwargs) + setattr(args[1], item_str, result) return args[1] return func(*args, **kwargs)