# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import shutil
import sys
import unittest
from unittest import TestCase, mock
from unittest.mock import patch

import pandas as pd

from msprof_analyze.test.st.utils import execute_cmd
from msprof_analyze.cluster_analyse.cluster_analysis import cluster_analysis_main
from msprof_analyze.cluster_analyse.cluster_analysis import Interface
from msprof_analyze.cluster_analyse.cluster_data_preprocess.msprof_data_preprocessor import MsprofDataPreprocessor

# Dotted prefix shared by every mock.patch target in this module.
NAMESPACE = "msprof_analyze.cluster_analyse"


class TestClusterAnalyseClusterAnalysis(TestCase):
    """Unit tests for the ``cluster_analysis`` command-line entry point.

    ``cluster_analysis.py`` is the entrance of cluster analysis: it parses
    argv and runs the requested analysis task. Running a whole analysis task
    inside a unit test is not reasonable, so ``Interface`` is replaced by a
    mock and the tests only check what the entry point hands over to it.
    """

    # Scratch directories created next to this test file for each test.
    CLUSTER_ANALYSIS_PATH = os.path.join(os.path.dirname(__file__), "cluster_analysis_test")
    PROF_PATH = os.path.join(CLUSTER_ANALYSIS_PATH, "PROF_114514")
    OUTPUT_PATH = os.path.join(CLUSTER_ANALYSIS_PATH, "output")

    def setUp(self):
        """Remember the real ``sys.argv`` and create the scratch directories."""
        # Back up argv so tearDown can undo whatever a test assigns to it.
        self._orig_argv = sys.argv
        # PROF_PATH lives under CLUSTER_ANALYSIS_PATH, so one call creates both;
        # exist_ok avoids the separate exists()/mkdir() check-then-create steps.
        os.makedirs(self.PROF_PATH, exist_ok=True)

    def tearDown(self):
        """Restore ``sys.argv`` and delete the scratch directories."""
        # Runs even when the test body fails, so argv never leaks between tests.
        sys.argv = self._orig_argv
        if os.path.exists(self.CLUSTER_ANALYSIS_PATH):
            shutil.rmtree(self.CLUSTER_ANALYSIS_PATH)

    @mock.patch(NAMESPACE + ".cluster_analysis.Interface")  # stub out the real Interface
    def test_cluster_analysis_main_should_run_success_and_handle_correct_parameter(self, mock_if):
        """The parsed command-line options reach ``Interface`` as one parameter dict."""
        # Build the fake command line; tearDown puts the real argv back.
        sys.argv = [
            "cluster_analysis.py",
            "-d", "./tmp/prof",
            "-o", "./tmp/out",
            "-m", "all",
            "--data_simplification",
            "--force",
        ]

        # Execute the cluster analysis entrance.
        cluster_analysis_main()

        # Interface must be constructed exactly once.
        mock_if.assert_called_once()
        params = mock_if.call_args[0][0]  # first positional arg is the parameter dict
        self.assertEqual(params["profiling_path"], "./tmp/prof")
        self.assertEqual(params["mode"], "all")
        self.assertEqual(params["output_path"], "./tmp/out")
        self.assertTrue(params["data_simplification"])
        self.assertTrue(params["force"])

    # NOTE(review): the two tests below are disabled. Their patch target
    # "cluster_analysis.Interface" lacks the NAMESPACE prefix used by the
    # active test above -- confirm the target before re-enabling them.
    #
    # @mock.patch("cluster_analysis.Interface")
    # def test_unknown_args_non_comm(self, mock_if):
    #     sys.argv = [
    #         "cluster_analysis.py",
    #         "-d", "/tmp/prof",
    #         "-m", "all",
    #         "extra", "unknown"
    #     ]
    #     cluster_analysis_main()
    #     kwargs = mock_if.call_args[0][0]
    #     self.assertEqual(kwargs["mode"], "all")
    #     self.assertIn("extra_args", kwargs)
    #     self.assertEqual(kwargs["extra_args"], ["extra", "unknown"])
    #
    # @mock.patch("cluster_analysis.Interface")
    # def test_unknown_args_comm_mode_logs_warning(self, mock_if):
    #     with self.assertLogs("cluster_analysis", level="WARNING") as cm:
    #         sys.argv = [
    #             "cluster_analysis.py",
    #             "-d", "/tmp/prof",
    #             "-m", "comm",
    #             "extra"
    #         ]
    #         cluster_analysis_main()
    #     # confirm the log contains the expected hint
    #     self.assertTrue(any("Invalid parameters" in msg for msg in cm.output))