diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_aten.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_aten.py
index 3b746e9c5f3bb5fd726ad3c41ae49a675b33751a..57003eb5252c327c71b6c4bf252c8432667571fa 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_aten.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_aten.py
@@ -47,24 +47,53 @@ class HOOKAtenOP(object):
 
 
 class AtenOPTemplate(HOOKModule):
-    def __init__(self, op_name, hook):
-        self.op_name_ = op_name
-        self.prefix_op_name_ = "Aten_" + str(op_name) + "_"
+    def __init__(self, op, hook):
+        if isinstance(op, torch._ops.OpOverloadPacket):
+            op_name_ = op._qualified_op_name.split("::")[-1]
+        else:
+            op_name_ = op.name().split("::")[-1]
+            overload_name = op._overloadname
+            if '.' + overload_name not in op_name_:
+                op_name_ = op_name_ + '.' + overload_name
+        self.op = op
+        self.prefix_op_name_ = "Aten_" + str(op_name_) + "_"
         super().__init__(hook)
 
     @torch_device_guard
     def forward(self, *args, **kwargs):
-        return aten_func.get(self.op_name_)(*args, **kwargs)
+        return self.op(*args, **kwargs)
 
 
-def wrap_aten_op(op_name, hook):
-    def aten_op_template(*args, **kwargs):
-        return AtenOPTemplate(op_name, hook)(*args, **kwargs)
+class AtenOPPacketTemplate:
+    def __init__(self, opPacket, hook):
+        self.opPacket = opPacket
+        self.hook = hook
 
-    return aten_op_template
+    def __getattr__(self, key):
+        try:
+            attr = getattr(self.opPacket, key)
+        except AttributeError as e:
+            raise AttributeError(f"AtenOPPacketTemplate or OpOverloadPacket does not have attribute '{key}'.") from e
+        if isinstance(attr, torch._ops.OpOverload):
+            return AtenOPTemplate(attr, self.hook)
+        else:
+            return attr
+
+    def overloads(self):
+        return self.opPacket.overloads()
+
+    @torch_device_guard
+    def __call__(self, *args, **kwargs):
+        return AtenOPTemplate(self.opPacket, self.hook)(*args, **kwargs)
+
+
+def wrap_aten_op(op, hook):
+    return AtenOPPacketTemplate(op, hook)
 
 
 def wrap_aten_ops_and_bind(hook):
     _aten_ops = get_aten_ops()
     for op_name in _aten_ops:
-        setattr(HOOKAtenOP, "wrap_" + str(op_name), wrap_aten_op(op_name, hook))
+        if not isinstance(aten_func.get(op_name), torch._ops.OpOverloadPacket):
+            continue
+        setattr(HOOKAtenOP, "wrap_" + str(op_name), wrap_aten_op(aten_func.get(op_name), hook))
diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_wrap_aten.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_wrap_aten.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb3379f3bd9366acec456ae71d8d3a8251526659
--- /dev/null
+++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_wrap_aten.py
@@ -0,0 +1,44 @@
+import unittest
+import torch
+from ptdbg_ascend.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTemplate
+
+
+def noop_hook_wrapper(name):
+    def noop_hook(module, in_feat, out_feat):
+        pass
+    return noop_hook
+
+if torch.__version__.split("+")[0] > '2.0':
+    class TestWrapAten(unittest.TestCase):
+        def setUp(self):
+            self.aten_op = AtenOPPacketTemplate(torch.ops.aten.convolution, noop_hook_wrapper)
+
+        def test_atenop_attribute(self):
+            self.assertIsInstance(self.aten_op.default, AtenOPTemplate)
+            self.assertIsInstance(self.aten_op.out, AtenOPTemplate)
+            self.assertEqual(self.aten_op.default.op, torch.ops.aten.convolution.default)
+            self.assertEqual(self.aten_op.out.op, torch.ops.aten.convolution.out)
+
+        def test_atenop_forward(self):
+            image = torch.randn(4, 3, 24, 24)
+            kernel = torch.randn(10, 3, 3, 3)
+            functional_out = torch.nn.functional.conv2d(image, kernel, stride=[1, 1],
+                                                        padding=[1, 1], dilation=[1, 1], groups=1, bias=None)
+            aten_out = self.aten_op(image, kernel, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
+            self.assertTrue(torch.all(functional_out == aten_out))
+
+        def test_atenop_overload_forward(self):
+            image = torch.randn(4, 3, 24, 24)
+            kernel = torch.randn(10, 3, 3, 3)
+            functional_out = torch.nn.functional.conv2d(image, kernel, stride=[1, 1],
+                                                        padding=[1, 1], dilation=[1, 1], groups=1, bias=None)
+            aten_out = self.aten_op.default(image, kernel, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
+            self.assertTrue(torch.all(functional_out == aten_out))
+
+        def test_atenop_nonattr(self):
+            self.assertRaises(AttributeError, getattr, self.aten_op, "foo")
+
+        def test_atenop_overloads(self):
+            self.assertEqual(self.aten_op.overloads(), self.aten_op.opPacket.overloads())
+
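A quick usage sketch for reviewers, showing the intent of the patch: the packet wrapper stays callable like the original OpOverloadPacket, while attribute access resolves individual overloads through __getattr__. The print_hook_wrapper factory and the aten.add example below are illustrative assumptions modeled on noop_hook_wrapper in the tests, not part of this change.

import torch
from ptdbg_ascend.hook_module.wrap_aten import AtenOPPacketTemplate

def print_hook_wrapper(name):
    # Hypothetical hook factory, same shape as noop_hook_wrapper above.
    def print_hook(module, in_feat, out_feat):
        print(f"captured {name}")
    return print_hook

# Wrapping an OpOverloadPacket keeps both call styles working:
wrapped_add = AtenOPPacketTemplate(torch.ops.aten.add, print_hook_wrapper)
out_packet = wrapped_add(torch.ones(2), torch.ones(2))           # dispatches via __call__
out_overload = wrapped_add.Tensor(torch.ones(2), torch.ones(2))  # OpOverload wrapped as AtenOPTemplate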