diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/api_registry.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/api_registry.py index 064103aead22e0f417794e047bbb83557e1b7020..cf21fe86bb541e64101dbdd360739a136f898d71 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/api_registry.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/api_registry.py @@ -60,12 +60,23 @@ class ApiRegistry: @staticmethod def store_ori_attr(ori_api_group, api_list, api_ori_attr): for api in api_list: - api_ori_attr[api] = getattr(ori_api_group, api) + if '.' in api: + sub_module_name, sub_op = api.rsplit('.', 1) + sub_module = getattr(ori_api_group, sub_module_name) + api_ori_attr[api] = getattr(sub_module, sub_op) + else: + api_ori_attr[api] = getattr(ori_api_group, api) @staticmethod def set_api_attr(api_group, attr_dict): for api, api_attr in attr_dict.items(): - setattr(api_group, api, api_attr) + if '.' in api: + sub_module_name, sub_op = api.rsplit('.', 1) + sub_module = getattr(api_group, sub_module_name, None) + if sub_module is not None: + setattr(sub_module, sub_op, api_attr) + else: + setattr(api_group, api, api_attr) def api_modularity(self): self.set_api_attr(torch.Tensor, self.tensor_hook_attr) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/support_wrap_ops.yaml index de9fac5dbbbbd1a008f2c61ace8ea7fcabd7efe7..92096fc4bb336928b2ddf9c3e8eba33dca71a12c 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/support_wrap_ops.yaml @@ -560,12 +560,45 @@ tensor: - xlogy_ torch: + - linalg.norm + - linalg.vector_norm + - linalg.matrix_norm + - linalg.diagonal + - linalg.det + - linalg.slogdet + - linalg.cond + - linalg.matrix_rank + - linalg.qr + - linalg.lu + - linalg.lu_factor + - linalg.svd + - linalg.svdvals + - linalg.solve + - linalg.lstsq + - linalg.inv + - linalg.pinv + - linalg.matrix_exp + - linalg.matrix_power + - linalg.cross + - linalg.matmul + - linalg.vecdot + - linalg.multi_dot + - linalg.householder_product + - linalg.tensorsolve + - linalg.vander + - linalg.cholesky_ex + - linalg.inv_ex + - linalg.solve_ex + - linalg.lu_factor_ex + - linalg.ldl_factor + - linalg.ldl_factor_ex - _adaptive_avg_pool2d - _add_relu - _add_relu_ - _aminmax - _batch_norm_impl_index - _convolution + - _foreach_norm - _softmax_backward_data - abs - abs_ diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_torch.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_torch.py index 5dcc41b1c8c23c90a6bbad1fe764fef389595661..e3a4af7a850397989fad1e810181eaa7d6fccfb1 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_torch.py @@ -32,12 +32,29 @@ with FileOpen(yaml_path, 'r') as f: def get_torch_ops(): global WrapTorchOps - _torch_ops = dir(torch) - return set(WrapTorchOps) & set(_torch_ops) + _torch_ops = [] + for operation in WrapTorchOps: + if '.' in operation: + operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1) + operation_sub_module = getattr(torch, operation_sub_module_name) + if operation_sub_op in dir(operation_sub_module): + _torch_ops.append(operation) + else: + if hasattr(torch, operation): + _torch_ops.append(operation) + return set(_torch_ops) + + +TorchOps = {} +for op in get_torch_ops(): + if '.' in op: + sub_module_name, sub_op = op.rsplit('.', 1) + sub_module = getattr(torch, sub_module_name) + TorchOps[op] = getattr(sub_module, sub_op) + else: + TorchOps[op] = getattr(torch, op) -TorchOps = {op: getattr(torch, op) for op in get_torch_ops()} - class HOOKTorchOP(object): pass