From 92c408188b85d6dd6ccfe9ec53a7562f5773a08c Mon Sep 17 00:00:00 2001
From: zhttjd
Date: Mon, 9 Jun 2025 11:54:07 +0800
Subject: [PATCH] Add patcher UTs for distribute.py, functions.py, mmcv.py,
 mmdet.py, numpy.py; fix import issue in distribute.py

---
 mx_driving/patcher/distribute.py       |   2 +-
 tests/torch/test_patcher_distribute.py |  79 +++++++++
 tests/torch/test_patcher_functions.py  | 126 +++++++++++++
 tests/torch/test_patcher_mmcv.py       |  81 +++++++++
 tests/torch/test_patcher_mmdet.py      | 237 +++++++++++++++++++++++++
 tests/torch/test_patcher_numpy.py      |  22 +++
 6 files changed, 546 insertions(+), 1 deletion(-)
 create mode 100644 tests/torch/test_patcher_distribute.py
 create mode 100644 tests/torch/test_patcher_functions.py
 create mode 100644 tests/torch/test_patcher_mmcv.py
 create mode 100644 tests/torch/test_patcher_mmdet.py
 create mode 100644 tests/torch/test_patcher_numpy.py

diff --git a/mx_driving/patcher/distribute.py b/mx_driving/patcher/distribute.py
index bab48147..93c5c67c 100644
--- a/mx_driving/patcher/distribute.py
+++ b/mx_driving/patcher/distribute.py
@@ -5,7 +5,7 @@ from typing import Dict
 
 
 def ddp(mmcvparallel: ModuleType, options: Dict):
     if hasattr(mmcvparallel, "distributed"):
-        import mmcv
+        import mmcv.device
 
         mmcvparallel.distributed.MMDistributedDataParallel = mmcv.device.npu.NPUDistributedDataParallel
diff --git a/tests/torch/test_patcher_distribute.py b/tests/torch/test_patcher_distribute.py
new file mode 100644
index 00000000..7dba095c
--- /dev/null
+++ b/tests/torch/test_patcher_distribute.py
@@ -0,0 +1,79 @@
+import random
+import types
+import unittest
+from unittest.mock import ANY, patch, MagicMock, PropertyMock
+
+import torch
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from mx_driving.patcher import ddp, ddp_forward
+
+
+def assertIsNotInstance(obj, cls):
+    assert not isinstance(obj, cls), f"Expected {repr(obj)} to NOT be an instance of {cls.__name__}"
+
+
+class TestDistribute(TestCase):
+    def setUp(self):
+        # Create mock objects for testing
+        self.mock_mmcvparallel = MagicMock()
+        self.mock_mmcvparallel.distributed = MagicMock()
+        self.mock_mmcvparallel.distributed.MMDistributedDataParallel = MagicMock()
+
+    def test_ddp_patch(self):
+        # Apply monkey patch
+        ddp(self.mock_mmcvparallel, {})
+
+        assertIsNotInstance(self.mock_mmcvparallel.distributed.MMDistributedDataParallel, MagicMock)
+
+    def test_ddp_forward_patch(self):
+        # Apply the ddp_forward patch
+        ddp_forward(self.mock_mmcvparallel, {})
+
+        # Get the patched _run_ddp_forward method
+        new_forward = self.mock_mmcvparallel.distributed.MMDistributedDataParallel._run_ddp_forward
+
+        # Verify _run_ddp_forward is correctly replaced
+        assertIsNotInstance(
+            new_forward,
+            MagicMock
+        )
+
+        # Create mock instance and inputs
+        mock_self = MagicMock()
+        mock_self.device_ids = [0]  # Simulate device IDs present
+        mock_self.module = MagicMock(return_value="module_output")
+
+        # Mock the to_kwargs method
+        mock_self.to_kwargs = MagicMock(return_value=(
+            [("processed_input",)],
+            [{"processed_kwarg": "value"}]
+        ))
+
+        # Call the patched forward method
+        result = new_forward(mock_self, "input1", "input2", kwarg1="value1")
+
+        # Check to_kwargs is called correctly
+        mock_self.to_kwargs.assert_called_once_with(
+            ("input1", "input2"),
+            {"kwarg1": "value1"},
+            0
+        )
+
+        # Check module is called correctly
+        mock_self.module.assert_called_once_with(
+            "processed_input",
+            processed_kwarg="value"
+        )
+
+        # Verify return value
+        self.assertEqual(result, "module_output")
+
+        # Test case with no device_ids
+        mock_self.reset_mock()
+        mock_self.device_ids = []
+        result = new_forward(mock_self, "input3", kwarg2="value2")
+        mock_self.module.assert_called_once_with("input3", kwarg2="value2")
+
+if __name__ == '__main__':
+    run_tests()
\ No newline at end of file
diff --git a/tests/torch/test_patcher_functions.py b/tests/torch/test_patcher_functions.py
new file mode 100644
index 00000000..24822b65
--- /dev/null
+++ b/tests/torch/test_patcher_functions.py
@@ -0,0 +1,126 @@
+import random
+import types
+import unittest
+from unittest.mock import patch, MagicMock, PropertyMock
+from typing import List, Union, Dict
+from types import ModuleType
+
+import torch
+from torch_npu.testing.testcase import TestCase, run_tests
+
+
+def assertIsNotInstance(obj, cls):
+    assert not isinstance(obj, cls), f"Expected {repr(obj)} to NOT be an instance of {cls.__name__}"
+
+
+class TestPatcherStream(TestCase):
+    def setUp(self):
+        # Create mock mmcvparallel module
+        self.mock_mmcvparallel = types.ModuleType('mmcvparallel')
+        self.mock_mmcvparallel._functions = types.ModuleType('_functions')
+        self.mock_mmcvparallel._functions.Scatter = MagicMock()
+
+        # Add the missing attributes for torch
+        self.mock_mmcvparallel._functions.torch = types.ModuleType('torch')
+        self.mock_mmcvparallel._functions.torch.device = torch.device
+
+        # Set up necessary functions and types
+        self.mock_mmcvparallel._functions.get_input_device = MagicMock()
+        self.mock_mmcvparallel._functions.scatter = MagicMock()
+        self.mock_mmcvparallel._functions.synchronize_stream = MagicMock()
+        self.mock_mmcvparallel._functions._get_stream = MagicMock()
+        self.mock_mmcvparallel._functions.Tensor = torch.Tensor
+
+        # Set default return values
+        self.mock_mmcvparallel._functions.get_input_device.return_value = -1
+        self.mock_mmcvparallel._functions.scatter.return_value = ["scatter_output"]
+
+        # Dynamically return one output per target GPU
+        def scatter_mock(input_, target_gpus, streams=None):
+            return [f"output_{i}" for i in range(len(target_gpus))]
+
+        self.mock_mmcvparallel._functions.scatter = MagicMock(side_effect=scatter_mock)
+
+    def test_monkeypatch(self):
+        """Verify forward method is correctly replaced"""
+        from mx_driving.patcher import stream
+        options = {}
+
+        # Apply monkeypatch using stream function
+        stream(self.mock_mmcvparallel, options)
+
+        # Verify Scatter.forward has been replaced with new_forward
+        assertIsNotInstance(self.mock_mmcvparallel._functions.Scatter.forward, MagicMock)
+
+    def test_new_forward_input_device_neg_one(self):
+        """Test stream behavior when input device is -1 and target GPUs are not [-1]"""
+        from mx_driving.patcher import stream
+        stream(self.mock_mmcvparallel, {})
+
+        # Create mock input
+        test_input = MagicMock(spec=torch.Tensor)
+        target_gpus = [0, 1]
+
+        # Execute new forward method
+        result = self.mock_mmcvparallel._functions.Scatter.forward.__func__(target_gpus, test_input)
+
+        # Verify stream handling logic
+        self.mock_mmcvparallel._functions._get_stream.assert_called()
+        self.mock_mmcvparallel._functions.scatter.assert_called_once()
+        self.mock_mmcvparallel._functions.synchronize_stream.assert_called_once()
+
+        # Verify output format
+        self.assertEqual(len(result), len(target_gpus))
+        self.assertIsInstance(result, tuple)
+
+    def test_new_forward_non_neg_input_device(self):
+        """Test behavior when input device is not -1"""
+        from mx_driving.patcher import stream
+        stream(self.mock_mmcvparallel, {})
+
+        # Set input device to non-negative value
+        self.mock_mmcvparallel._functions.get_input_device.return_value = 0
+        test_input = MagicMock(spec=torch.Tensor)
+        target_gpus = [0, 1]
+
+        # Execute new forward method
+        result = self.mock_mmcvparallel._functions.Scatter.forward.__func__(target_gpus, test_input)
+
+        # Verify no stream handling occurs
+        self.mock_mmcvparallel._functions._get_stream.assert_not_called()
+        self.mock_mmcvparallel._functions.scatter.assert_called_once()
+        self.mock_mmcvparallel._functions.synchronize_stream.assert_not_called()
+        self.assertIsInstance(result, tuple)
+
+    def test_new_forward_list_input(self):
+        """Test handling of list input"""
+        from mx_driving.patcher import stream
+        stream(self.mock_mmcvparallel, {})
+
+        # Create list input
+        test_input = [torch.tensor([1]), torch.tensor([2])]
+        target_gpus = [0, 1]
+
+        # Execute new forward method
+        result = self.mock_mmcvparallel._functions.Scatter.forward.__func__(target_gpus, test_input)
+
+        # Verify processing logic
+        self.mock_mmcvparallel._functions.get_input_device.assert_called_once()
+        self.mock_mmcvparallel._functions.scatter.assert_called_once()
+        self.assertIsInstance(result, tuple)
+
+    def test_no_scatter_class(self):
+        """Verify graceful handling when Scatter class is missing"""
+        mock_mmcvparallel = MagicMock()
+        mock_mmcvparallel._functions = MagicMock()
+        delattr(mock_mmcvparallel._functions, "Scatter")
+
+        from mx_driving.patcher import stream
+        try:
+            stream(mock_mmcvparallel, {})
+        except AttributeError:
+            self.fail("stream should handle missing Scatter class gracefully")
+
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/tests/torch/test_patcher_mmcv.py b/tests/torch/test_patcher_mmcv.py
new file mode 100644
index 00000000..0b61c607
--- /dev/null
+++ b/tests/torch/test_patcher_mmcv.py
@@ -0,0 +1,81 @@
+import importlib
+import types
+import unittest
+from unittest.mock import MagicMock, patch, PropertyMock
+from torch_npu.testing.testcase import TestCase, run_tests
+
+import mx_driving.patcher.mmcv as mx_mmcv
+
+
+def assertIsNotInstance(obj, cls):
+    assert not isinstance(obj, cls), f"Expected {repr(obj)} to NOT be an instance of {cls.__name__}"
+
+
+class TestPatcherMMCV(TestCase):
+    def setUp(self):
+        pass
+
+    def test_patch_mmcv_version_found(self):
+        """Test successful import of mmcv and version patching"""
+        with patch('importlib.import_module') as mock_import:
+            # Mock mmcv module
+            mock_mmcv = MagicMock()
+            mock_mmcv.__version__ = "1.7.2"
+            mock_import.return_value = mock_mmcv
+
+            # Call patching function
+            mx_mmcv.patch_mmcv_version("2.1.0")
+
+            # Assert version restoration
+            self.assertEqual(mock_mmcv.__version__, "1.7.2", "Version should be restored to original")
+
+            # Assert import attempts
+            mock_import.assert_any_call("mmdet")
+            mock_import.assert_any_call("mmdet3d")
+
+    def test_patch_mmcv_version_not_found(self):
+        """Test handling when mmcv cannot be imported"""
+        with patch('importlib.import_module') as mock_import:
+            mock_import.side_effect = ImportError
+            # Assert no exception raised
+            mx_mmcv.patch_mmcv_version("666.888.2333")
+            mock_import.assert_called_once_with("mmcv")
+
+    def test_dc(self):
+        """Test monkeypatching for deform_conv2d"""
+        mock_mmcvops = MagicMock()
+
+        # Call dc function
+        mx_mmcv.dc(mock_mmcvops, {})
+
+        # Assert function replacements
+        assertIsNotInstance(mock_mmcvops.deform_conv.DeformConv2dFunction, MagicMock)
+        assertIsNotInstance(mock_mmcvops.deform_conv.deform_conv2d, MagicMock)
+
+    def test_mdc(self):
+        """Test monkeypatching for modulated_deform_conv2d"""
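+        # As in test_dc above: build a mock ops module, apply the patch, and
+        # assert that the MagicMock attributes were swapped for real callables.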
+        mock_mmcvops = MagicMock()
+        mock_mmcvops.modulated_deform_conv = MagicMock()
+
+        # Call mdc function
+        mx_mmcv.mdc(mock_mmcvops, {})
+
+        # Assert function replacements
+        assertIsNotInstance(mock_mmcvops.modulated_deform_conv.ModulatedDeformConv2dFunction, MagicMock)
+        assertIsNotInstance(mock_mmcvops.modulated_deform_conv.modulated_deform_conv2d, MagicMock)
+
+    def test_msda(self):
+        """Test monkeypatching for multi_scale_deformable_attn"""
+        mock_mmcvops = MagicMock()
+        mock_mmcvops.multi_scale_deformable_attn = MagicMock()
+
+        # Call msda function
+        mx_mmcv.msda(mock_mmcvops, {})
+
+        # Assert function replacements
+        assertIsNotInstance(mock_mmcvops.multi_scale_deformable_attn.MultiScaleDeformableAttnFunction, MagicMock)
+        assertIsNotInstance(mock_mmcvops.multi_scale_deformable_attn.multi_scale_deformable_attn, MagicMock)
+
+
+if __name__ == '__main__':
+    run_tests()
\ No newline at end of file
diff --git a/tests/torch/test_patcher_mmdet.py b/tests/torch/test_patcher_mmdet.py
new file mode 100644
index 00000000..7e1eac86
--- /dev/null
+++ b/tests/torch/test_patcher_mmdet.py
@@ -0,0 +1,237 @@
+import types
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+import torch.nn as nn
+from torch_npu.testing.testcase import TestCase, run_tests
+
+from mx_driving.patcher import pseudo_sampler, resnet_add_relu, resnet_maxpool
+
+
+def assertIsNotInstance(obj, cls):
+    assert not isinstance(obj, cls), f"Expected {repr(obj)} to NOT be an instance of {cls.__name__}"
+
+
+class TestPseudoSamplerPatcher(TestCase):
+
+    def setUp(self):
+        # Create mock mmdetsamplers module
+        self.mock_mmdetsamplers = types.ModuleType('mmdetsamplers')
+        self.mock_mmdetsamplers.pseudo_sampler = types.ModuleType('pseudo_sampler')
+        self.mock_mmdetsamplers.sampling_result = types.ModuleType('sampling_result')
+
+        # Mock PseudoSampler class and sample method
+        self.mock_pseudo_sampler_cls = MagicMock()
+        self.mock_mmdetsamplers.pseudo_sampler.PseudoSampler = self.mock_pseudo_sampler_cls
+
+        # Mock SamplingResult class
+        self.mock_sampling_result = MagicMock()
+        self.mock_mmdetsamplers.sampling_result.SamplingResult = self.mock_sampling_result
+
+    def test_patching_with_pseudo_sampler(self):
+        """Test successful patching when PseudoSampler exists."""
+        # Apply patching
+        pseudo_sampler(self.mock_mmdetsamplers, {})
+
+        # Verify sample method was replaced
+        assertIsNotInstance(
+            self.mock_mmdetsamplers.pseudo_sampler.PseudoSampler.sample,
+            MagicMock
+        )
+
+    def test_patched_behavior(self):
+        """Test behavior of patched sample method."""
+        # Apply patching
+        pseudo_sampler(self.mock_mmdetsamplers, {})
+
+        # Create mock inputs
+        mock_assign_result = MagicMock()
+        mock_assign_result.gt_inds = torch.tensor([1, 0, 2]).unsqueeze(1)
+        mock_bboxes = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
+        mock_gt_bboxes = torch.tensor([[0.1, 0.1, 0.2], [0.3, 0.3, 0.4]])
+
+        # Call patched method
+        patched_sample = self.mock_mmdetsamplers.pseudo_sampler.PseudoSampler.sample
+        sampling_result = patched_sample(
+            None, mock_assign_result, mock_bboxes, mock_gt_bboxes
+        )
+
+        # Verify SamplingResult construction
+        self.mock_mmdetsamplers.sampling_result.SamplingResult.assert_called_once()
+        pos_inds, neg_inds, _, _, _, _ = self.mock_sampling_result.call_args[0]
+        self.assertRtolEqual(pos_inds, torch.tensor([True, False, True]))
+        self.assertRtolEqual(neg_inds, torch.tensor([False, True, False]))
+
+
+class TestResNetAddReLU(TestCase):
+    def setUp(self):
+        self.mock_mmdetresnet = types.ModuleType('mmdetresnet')
+
+        class BasicBlock(nn.Module):
+            def __init__(self, downsample=None):
+                super().__init__()
+                self.conv1 = nn.Identity()
+                self.norm1 = nn.Identity()
+                self.relu = nn.Identity()
+                self.conv2 = nn.Identity()
+                self.norm2 = nn.Identity()
+                self.downsample = downsample
+                self.with_cp = False
+
+            def forward(self, x):
+                return x  # just a mock, original implementation will be replaced
+
+        class Bottleneck(nn.Module):
+            def __init__(self, downsample=None):
+                super().__init__()
+                self.conv1 = nn.Identity()
+                self.norm1 = nn.Identity()
+                self.relu = nn.Identity()
+                self.conv2 = nn.Identity()
+                self.norm2 = nn.Identity()
+                self.conv3 = nn.Identity()
+                self.norm3 = nn.Identity()
+                self.downsample = downsample
+                self.with_cp = False
+                self.with_plugins = False
+                self.after_conv1_plugin_names = []
+                self.after_conv2_plugin_names = []
+                self.after_conv3_plugin_names = []
+
+                def forward_plugin_func(x, _):
+                    return x
+
+                self.forward_plugin = forward_plugin_func
+
+            def forward(self, x):
+                return x  # just a mock, original implementation will be replaced
+
+        self.mock_mmdetresnet.BasicBlock = BasicBlock
+        self.mock_mmdetresnet.Bottleneck = Bottleneck
+
+    def test_basic_block_forward(self):
+        # Apply patch
+        resnet_add_relu(self.mock_mmdetresnet, {})
+
+        block = self.mock_mmdetresnet.BasicBlock()
+        x = torch.tensor([1.0]).npu()
+
+        # execute forward and verify
+        result = block(x)
+        self.assertRtolEqual(result, torch.tensor([2.0]).npu())
+
+    def test_basic_block_with_downsample(self):
+        # Apply patch
+        resnet_add_relu(self.mock_mmdetresnet, {})
+
+        # with downsample
+        downsample = nn.Identity()
+        block = self.mock_mmdetresnet.BasicBlock(downsample=downsample)
+        x = torch.tensor([1.0]).npu()
+
+        # execute forward and verify
+        result = block(x)
+        self.assertRtolEqual(result, torch.tensor([2.0]).npu())
+
+    def test_basic_block_with_cp(self):
+        # Apply patch
+        resnet_add_relu(self.mock_mmdetresnet, {})
+
+        # with checkpoint
+        block = self.mock_mmdetresnet.BasicBlock()
+        block.with_cp = True
+        x = torch.tensor([1.0], requires_grad=True).npu()
+
+        # execute forward and verify
+        result = block(x)
+        self.assertRtolEqual(result, torch.tensor([2.0]).npu())
+
+    def test_bottleneck_forward(self):
+        # Apply patch
+        resnet_add_relu(self.mock_mmdetresnet, {})
+
+        block = self.mock_mmdetresnet.Bottleneck()
+        x = torch.tensor([1.0]).npu()
+
+        # execute forward and verify
+        result = block(x)
+        self.assertRtolEqual(result, torch.tensor([2.0]).npu())
+
+    def test_bottleneck_with_plugins(self):
+        # Apply patch
+        resnet_add_relu(self.mock_mmdetresnet, {})
+
+        # with plugins
+        block = self.mock_mmdetresnet.Bottleneck()
+        block.with_plugins = True
+        x = torch.tensor([1.0]).npu()
+
+        # execute and verify
+        result = block(x)
+        self.assertRtolEqual(result, torch.tensor([2.0]).npu())
+
+
+class TestResNetMaxPool(TestCase):
+    def setUp(self):
+        self.mock_mmdetresnet = types.ModuleType('mmdetresnet')
+
+        class ResNet(nn.Module):
+            def __init__(self, deep_stem=False):
+                super().__init__()
+                self.deep_stem = deep_stem
+                self.stem = nn.Identity() if deep_stem else None
+                self.conv1 = nn.Identity()
+                self.norm1 = nn.Identity()
+                self.relu = nn.Identity()
+                self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
+                self.res_layers = ['layer1', 'layer2']
+                self.out_indices = [0, 1]
+                self.layer1 = nn.Identity()
+                self.layer2 = nn.Identity()
+
+            def forward(self, x):
+                return x  # just a mock, original implementation will be replaced
+
+        self.mock_mmdetresnet.ResNet = ResNet
+
+    def test_forward_with_grad(self):
+        resnet_maxpool(self.mock_mmdetresnet, {})
+
+        model = self.mock_mmdetresnet.ResNet()
+        x = torch.ones(1, 3, 32, 32, requires_grad=True).npu()
+
+        result = model(x)
+        self.assertEqual(len(result), 2)  # verify the number of output layers
+
+    def test_forward_without_grad(self):
+        resnet_maxpool(self.mock_mmdetresnet, {})
+
+        model = self.mock_mmdetresnet.ResNet()
+        x = torch.ones(1, 3, 32, 32, requires_grad=False).npu()
+
+        result = model(x)
+        self.assertEqual(len(result), 2)
+
+    def test_deep_stem_path(self):
+        resnet_maxpool(self.mock_mmdetresnet, {})
+
+        model = self.mock_mmdetresnet.ResNet(deep_stem=True)
+        x = torch.ones(1, 3, 32, 32).npu()
+
+        result = model(x)
+        self.assertEqual(len(result), 2)
+
+    def test_out_indices_handling(self):
+        resnet_maxpool(self.mock_mmdetresnet, {})
+
+        model = self.mock_mmdetresnet.ResNet()
+        model.out_indices = [1]  # output only the second layer
+
+        x = torch.ones(1, 3, 32, 32).npu()
+
+        result = model(x)
+        self.assertEqual(len(result), 1)  # verify that only one layer is output
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/tests/torch/test_patcher_numpy.py b/tests/torch/test_patcher_numpy.py
new file mode 100644
index 00000000..1534773a
--- /dev/null
+++ b/tests/torch/test_patcher_numpy.py
@@ -0,0 +1,22 @@
+import importlib
+import types
+import unittest
+from unittest.mock import Mock
+from torch_npu.testing.testcase import TestCase, run_tests
+
+from mx_driving.patcher import numpy_type
+
+
+class TestNumpyTypePatch(TestCase):
+
+    def setUp(self):
+        self.mock_np = Mock(spec=[])  # restricted attrs, so hasattr() returns False
+
+    def test_numpy_type_patch_replacement(self):
+        numpy_type(self.mock_np, {})
+        self.assertEqual(self.mock_np.bool, bool)
+        self.assertEqual(self.mock_np.float, float)
+
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
-- 
Gitee
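
A minimal usage sketch of the patcher entry points exercised above, assuming
mmcv 1.x and numpy are importable in the target environment (the tests
substitute mocks for both, and the contents of `options` depend on the caller):

    import numpy as np
    import mmcv.parallel as mmcvparallel

    from mx_driving.patcher import ddp, ddp_forward, numpy_type

    options = {}
    # Re-adds np.bool / np.float when numpy no longer defines the aliases
    numpy_type(np, options)
    # Points MMDistributedDataParallel at mmcv.device.npu.NPUDistributedDataParallel
    ddp(mmcvparallel, options)
    # Replaces MMDistributedDataParallel._run_ddp_forward with the patched version
    ddp_forward(mmcvparallel, options)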