diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/api_registry.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/api_registry.py
index 064103aead22e0f417794e047bbb83557e1b7020..44c0aa61a5e19b75f2f295f46bb64d8c13437e96 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/api_registry.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/api_registry.py
@@ -17,13 +17,14 @@
 import torch
 import torch.distributed as dist
 
-from . import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten
+from . import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten, wrap_linalg
 from .wrap_torch import get_torch_ops
 from .wrap_functional import get_functional_ops
 from .wrap_tensor import get_tensor_ops
 from .wrap_vf import get_vf_ops
 from .wrap_distributed import get_distributed_ops
 from .wrap_aten import get_aten_ops
+from .wrap_linalg import get_linalg_ops
 from ..common.utils import torch_without_guard_version, npu_distributed_api
 
 torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
@@ -47,6 +48,7 @@ class ApiRegistry:
         self.vf_ori_attr = {}
         self.aten_ori_attr = {}
         self.torch_npu_ori_attr = {}
+        self.linalg_ori_attr = {}
 
         self.tensor_hook_attr = {}
         self.torch_hook_attr = {}
@@ -56,6 +58,7 @@
         self.vf_hook_attr = {}
         self.aten_hook_attr = {}
         self.torch_npu_hook_attr = {}
+        self.linalg_hook_attr = {}
 
     @staticmethod
     def store_ori_attr(ori_api_group, api_list, api_ori_attr):
@@ -68,6 +71,7 @@
             setattr(api_group, api, api_attr)
 
     def api_modularity(self):
+        self.set_api_attr(torch.linalg, self.linalg_hook_attr)
         self.set_api_attr(torch.Tensor, self.tensor_hook_attr)
         self.set_api_attr(torch, self.torch_hook_attr)
         self.set_api_attr(torch.nn.functional, self.functional_hook_attr)
@@ -83,6 +87,7 @@
             self.set_api_attr(torch_npu, self.torch_npu_hook_attr)
 
     def api_originality(self):
+        self.set_api_attr(torch.linalg, self.linalg_ori_attr)
         self.set_api_attr(torch.Tensor, self.tensor_ori_attr)
         self.set_api_attr(torch, self.torch_ori_attr)
         self.set_api_attr(torch.nn.functional, self.functional_ori_attr)
@@ -98,6 +103,12 @@
             self.set_api_attr(torch_npu, self.torch_npu_ori_attr)
 
     def initialize_hook(self, hook):
+        self.store_ori_attr(torch.linalg, get_linalg_ops(), self.linalg_ori_attr)
+        wrap_linalg.wrap_linalg_ops_and_bind(hook)
+        for attr_name in dir(wrap_linalg.HOOKLinalgOP):
+            if attr_name.startswith("wrap_"):
+                self.linalg_hook_attr[attr_name[5:]] = getattr(wrap_linalg.HOOKLinalgOP, attr_name)
+
         self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr)
         wrap_tensor.wrap_tensor_ops_and_bind(hook)
         for attr_name in dir(wrap_tensor.HOOKTensor):
diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/support_wrap_ops.yaml
index de9fac5dbbbbd1a008f2c61ace8ea7fcabd7efe7..8f0d0901512c7099150ac7dc545422c3f7582ae1 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/support_wrap_ops.yaml
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/support_wrap_ops.yaml
@@ -566,6 +566,7 @@ torch:
   - _aminmax
   - _batch_norm_impl_index
   - _convolution
+  - _foreach_norm
   - _softmax_backward_data
   - abs
   - abs_
@@ -1019,6 +1020,9 @@ torch:
 _VF:
   - lstm
 
+torch.linalg:
+  - vector_norm
+
 torch_npu:
   - one_
   - npu_sort_v2
diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_linalg.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_linalg.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43de5ac9773f3e5dae84fa0b3614c59b266cc94
--- /dev/null
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_linalg.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+
+import torch
+import yaml
+
+from .hook_module import HOOKModule
+from ..common.utils import torch_device_guard
+
+
+cur_path = os.path.dirname(os.path.realpath(__file__))
+yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
+with open(yaml_path, 'r') as f:
+    WrapLinalgOps = yaml.safe_load(f).get('torch.linalg', [])
+
+
+def get_linalg_ops():
+    global WrapLinalgOps
+    return WrapLinalgOps
+
+
+LinalgOps = {op: getattr(torch.linalg, op) for op in get_linalg_ops()}
+
+
+class HOOKLinalgOP(object):
+    pass
+
+
+class LinalgOPTemplate(HOOKModule):
+
+    def __init__(self, op_name, hook):
+        self.op_name_ = op_name
+        self.prefix_op_name_ = "Linalg_" + str(op_name) + "_"
+        super().__init__(hook)
+
+    @torch_device_guard
+    def forward(self, *args, **kwargs):
+        return LinalgOps[self.op_name_](*args, **kwargs)
+
+
+def wrap_linalg_op(op_name, hook):
+
+    def linalg_op_template(*args, **kwargs):
+        return LinalgOPTemplate(op_name, hook)(*args, **kwargs)
+
+    return linalg_op_template
+
+
+def wrap_linalg_ops_and_bind(hook):
+    _linalg_ops = get_linalg_ops()
+    for op_name in _linalg_ops:
+        setattr(HOOKLinalgOP, "wrap_" + op_name, wrap_linalg_op(op_name, hook))
\ No newline at end of file
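
Note for reviewers: the new module follows the same pattern as the existing wrap_tensor / wrap_torch wrappers. initialize_hook() stores the original torch.linalg attributes in linalg_ori_attr, wrap_linalg_ops_and_bind() attaches wrap_<op> callables to HOOKLinalgOP, api_modularity() swaps them onto torch.linalg so every call is routed through LinalgOPTemplate under the "Linalg_<op>_" prefix, and api_originality() restores the originals. The minimal standalone sketch below illustrates that swap-and-restore mechanism only; the demo_hook callback and its signature are illustrative assumptions and stand in for ptdbg's real HOOKModule forward-hook plumbing.

import torch

# Op list mirroring the new 'torch.linalg' entry in support_wrap_ops.yaml.
WRAP_LINALG_OPS = ["vector_norm"]

# Keep the original torch.linalg attributes so they can be restored later,
# analogous to ApiRegistry.store_ori_attr() / api_originality().
linalg_ori_attr = {op: getattr(torch.linalg, op) for op in WRAP_LINALG_OPS}


def wrap_linalg_op(op_name, hook):
    """Build a wrapper that reports the call to `hook`, then runs the original op."""
    orig_op = linalg_ori_attr[op_name]

    def linalg_op_template(*args, **kwargs):
        # The real LinalgOPTemplate goes through HOOKModule so inputs/outputs are
        # dumped; this sketch only invokes a plain callback (an assumption).
        hook("Linalg_" + op_name + "_", args, kwargs)
        return orig_op(*args, **kwargs)

    return linalg_op_template


def demo_hook(prefix, args, kwargs):
    print("captured:", prefix)


# Equivalent of api_modularity(), restricted to torch.linalg.
for op in WRAP_LINALG_OPS:
    setattr(torch.linalg, op, wrap_linalg_op(op, demo_hook))

print(torch.linalg.vector_norm(torch.tensor([3.0, 4.0])))  # captured, prints tensor(5.)

# Equivalent of api_originality(): put the original bindings back.
for op, orig in linalg_ori_attr.items():
    setattr(torch.linalg, op, orig)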