diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_functional.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_functional.py
index 1ce938129628aefebfa7a991ce166c47159cda0d..0533f55b54c43543301f5653162dc7f522027de9 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_functional.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_functional.py
@@ -67,9 +67,6 @@ yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
 with FileOpen(yaml_path, 'r') as f:
     WrapFunctionalOps = yaml.safe_load(f).get('functional')
 
-for f in dir(torch.nn.functional):
-    locals().update({f: getattr(torch.nn.functional, f)})
-
 
 def get_functional_ops():
     global WrapFunctionalOps
@@ -77,6 +74,9 @@ def get_functional_ops():
     return set(WrapFunctionalOps) & set(_all_functional_ops)
 
 
+TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()}
+
+
 class HOOKFunctionalOP(object):
     pass
 
@@ -89,7 +89,7 @@ class FunctionalOPTemplate(HOOKModule):
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
-        return eval(self.op_name_)(*args, **kwargs)
+        return TorchFunctions[str(self.op_name_)](*args, **kwargs)
 
 
 def wrap_functional_op(op_name, hook):
diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_tensor.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_tensor.py
index c5e3321bb8303fb4d92ff4c14b9a0dcb5df640b9..cddc99d91b0a6c9e8cd5d8a41ed11f8cf084b032 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_tensor.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_tensor.py
@@ -32,10 +32,13 @@ with FileOpen(yaml_path, 'r') as f:
 
 
 def get_tensor_ops():
     global WrapTensorOps
-    _tensor_ops = dir(torch._C._TensorBase)
+    _tensor_ops = dir(torch.Tensor)
     return set(WrapTensorOps) & set(_tensor_ops)
 
 
+TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()}
+
+
 class HOOKTensor(object):
     pass
@@ -50,7 +53,7 @@ class TensorOPTemplate(HOOKModule):
     @torch_device_guard
     @parameter_adapter
     def forward(self, *args, **kwargs):
-        return getattr(torch._C._TensorBase, str(self.op_name_))(*args, **kwargs)
+        return TensorOps[str(self.op_name_)](*args, **kwargs)
 
 
 def wrap_tensor_op(op_name, hook):
diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_torch.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_torch.py
index d1ad40fe45c0feec8a604f28bb592b0825ea085a..5dcc41b1c8c23c90a6bbad1fe764fef389595661 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_torch.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_torch.py
@@ -32,10 +32,13 @@ with FileOpen(yaml_path, 'r') as f:
 
 
 def get_torch_ops():
     global WrapTorchOps
-    _torch_ops = dir(torch._C._VariableFunctionsClass)
+    _torch_ops = dir(torch)
     return set(WrapTorchOps) & set(_torch_ops)
 
 
+TorchOps = {op: getattr(torch, op) for op in get_torch_ops()}
+
+
 class HOOKTorchOP(object):
     pass
@@ -47,52 +50,9 @@ class TorchOPTemplate(HOOKModule):
         self.prefix_op_name_ = "Torch_" + str(op_name) + "_"
         super().__init__(hook)
 
-    def input_param_need_adapt(self):
-        special_op_list = ["broadcast_tensors", "block_diag"]
-        for item in special_op_list:
-            if item in self.op_name_:
-                return True
-        return False
-
-    def einsum_adapt(self, *args):
-        if len(args) < 2:
-            raise ValueError('einsum(): must specify the equation string and at least one operand, '
-                             'or at least one operand and its subscripts list')
-        equation = None
-        operands = None
-        if isinstance(args[0], torch.Tensor):
-            def parse_subscript(n: int) -> str:
-                if n == Ellipsis:
-                    return '...'
-                if n >= 0 and n < 26:
-                    return chr(ord('A') + n)
-                if n >= 26 and n < 52:
-                    return chr(ord('a') + n - 26)
-                raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52]')
-            equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
-
-            if len(args) % 2 == 1:
-                equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
-                operands = args[:-1:2]
-            else:
-                operands = args[::2]
-        else:
-            equation = args[0]
-            operands = args[1:]
-
-        if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
-            _operands = operands[0]
-            return self.einsum_adapt(equation, *_operands)
-        return equation, operands
-
     @torch_device_guard
     def forward(self, *args, **kwargs):
-        if self.input_param_need_adapt():
-            return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(args, **kwargs)
-        else:
-            if self.op_name_ == 'einsum':
-                args = self.einsum_adapt(*args)
-            return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs)
+        return TorchOps[str(self.op_name_)](*args, **kwargs)
 
 
 def wrap_torch_op(op_name, hook):
diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_wrap_torch.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_wrap_torch.py
index 0c5a997f307f82857bff62043c6e7a3f8e88397f..ef0350ccc0196deda397d51be4d75afd57fe0076 100644
--- a/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_wrap_torch.py
+++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_wrap_torch.py
@@ -20,10 +20,6 @@ class TestWrapTorch(unittest.TestCase):
         self.assertEqual(template.op_name_, self.op_name)
         self.assertEqual(template.prefix_op_name_, "Torch_" + str(self.op_name) + "_")
 
-    def test_input_param_need_adapt(self):
-        template = TorchOPTemplate(self.op_name, self.hook)
-        self.assertFalse(template.input_param_need_adapt())
-
     def test_forward(self):
         template = TorchOPTemplate(self.op_name, self.hook)
         result = template.forward(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]))