diff --git a/omniadvisor/config/common_config.cfg b/omniadvisor/config/common_config.cfg index 240e42c4e6674aea6ee10987e57b1f500a934bb6..4248a8900fdd7c2c4e1baa0426af76b6ef26896a 100755 --- a/omniadvisor/config/common_config.cfg +++ b/omniadvisor/config/common_config.cfg @@ -9,6 +9,10 @@ tuning.strategy=[["transfer", 1],["expert", 2],["iterative", 10]] [spark] # Spark History Server的URL 仅用于Rest模式 spark.history.rest.url=http://localhost:18080 +# Spark History Server的URL 的用户名,仅需要时填写 +spark.history.username= +# Spark History Server的URL 的密码,仅需要时填写 +spark.history.password= # Spark从History Sever抓取Trace的超时时间 spark.fetch.trace.timeout=30 # Spark从History Sever抓取Trace的间隔用时 diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index 98a20ff69fe76dcc52028f8be1edb2043a14ea9a..ba212b7c7ad6f4c3f369d0d64f8cdd304db262b6 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -36,6 +36,13 @@ def check_oa_conf(): if OA_CONF.spark_exec_timeout_ratio <= 0: raise ValueError('The spark exec timeout ratio must > 0, please check common configuration.') + if not ( + (OA_CONF.spark_history_username and OA_CONF.spark_history_password) or + (not OA_CONF.spark_history_username and not OA_CONF.spark_history_password) + ): + raise ValueError('The spark history username and password should be provided, or leave both blank, please check' + ' common configuration.') + class OmniAdvisorConf: """ @@ -96,6 +103,8 @@ class OmniAdvisorConf: tuning_retest_times = _common_config.getint('common', 'tuning.retest.times') config_fail_threshold = _common_config.getint('common', 'config.fail.threshold') spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url') + spark_history_username = _common_config.get('spark', 'spark.history.username', fallback='') + spark_history_password = _common_config.get('spark', 'spark.history.password', fallback='') spark_fetch_trace_timeout = _common_config.getint('spark', 'spark.fetch.trace.timeout') spark_fetch_trace_interval = _common_config.getint('spark', 'spark.fetch.trace.interval') spark_exec_timeout_ratio = _common_config.getfloat('spark', 'spark.exec.timeout.ratio') diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index e48e21dff9fcb5043727341d7e88cc588a565d62..bc6980a4eb9ef998a4eddc9338fd9049e940704c 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -140,11 +140,7 @@ def hijack_recommend(argv: list): # 若执行失败 则判断是否需要拉起安全机制 else: if exec_config != user_config: - global_logger.warning("Spark execute failed, ready to activate security protection mechanism.") - safe_exam_record, safe_output = spark_run(load=load, config=user_config, wait_for_trace=False) - global_logger.info("Spark execute in security protection mechanism, going to print Spark output.") - # 打印安全机制下任务的输出 - print(safe_output, end="", flush=True) + raise RuntimeError("Spark execute failed, ready to activate security protection mechanism.") else: global_logger.warning("Spark execute failed in user config, going to print Spark output.") print(output, end="", flush=True) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py index de1a7dd73ed7056686407834ef59c12cce9b5d86..0e1aec2b7bea77602e47060f5610b522eed5dc9d 100755 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py @@ -3,13 +3,15 @@ import requests class SparkFetcher: - def __init__(self, history_server_url,): + def __init__(self, history_server_url, username, password): """ 初始化SparkFetcher类 :param history_server_url: Spark History Server的URL地址 """ self.history_server_url = history_server_url.rstrip('/') # 确保URL末尾没有斜杠 + self.username = username + self.password = password def _make_request(self, endpoint): """ @@ -21,7 +23,11 @@ class SparkFetcher: """ endpoint = endpoint.strip("/") url = f"{self.history_server_url}/{endpoint}" - response = requests.get(url) + if self.username: + # 若用户配置了用户名和密码,那么应当走验证通道 + response = requests.get(url, auth=(self.username, self.password), verify=False, allow_redirects=False) + else: + response = requests.get(url) # 若status.code != 200~209则抛出异常 response.raise_for_status() # 将获取到的数据转为json返回 diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 403e120ebb5ce21662662b0cc78790e70ab0e811..f8f3b59e3f9c226baea4f52d85bc19402525bf61 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -105,7 +105,9 @@ def _update_trace_from_history_server(exam_record: ExamRecord, application_id: s """ trace_dict = {} history_server_url = OA_CONF.spark_history_rest_url - spark_fetcher = SparkFetcher(history_server_url) + history_server_username = OA_CONF.spark_history_username + history_server_password = OA_CONF.spark_history_password + spark_fetcher = SparkFetcher(history_server_url, history_server_username, history_server_password) start_time = time.time() while time.time() - start_time < OA_CONF.spark_fetch_trace_timeout: try: diff --git a/omniadvisor/tests/omniadvisor/common/__init__.py b/omniadvisor/tests/omniadvisor/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/omniadvisor/tests/omniadvisor/common/test_constant.py b/omniadvisor/tests/omniadvisor/common/test_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..77c823244d0e3de5d65f88812489d87d8946f1bf --- /dev/null +++ b/omniadvisor/tests/omniadvisor/common/test_constant.py @@ -0,0 +1,122 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from common.constant import check_oa_conf + +OA_CONF_PATH = 'common.constant.OA_CONF' + + +class TestConstant: + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_normal(self): + # 不应抛异常 + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + # invalid + tuning_retest_times=0, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_tuning_retest_times(self): + with pytest.raises(ValueError, match='tuning retest times'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + # invalid + config_fail_threshold=0, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_config_fail_threshold(self): + with pytest.raises(ValueError, match='config fail threshold'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + # invalid + spark_fetch_trace_timeout=0, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_fetch_trace_timeout(self): + with pytest.raises(ValueError, match='spark fetch trace timeout'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + # invalid + spark_fetch_trace_interval=0, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_fetch_trace_interval(self): + with pytest.raises(ValueError, match='spark fetch trace interval'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + # invalid + spark_exec_timeout_ratio=0, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_exec_timeout_ratio(self): + with pytest.raises(ValueError, match='spark exec timeout ratio'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + # mismatch + spark_history_password='' + )) + def test_check_oa_conf_when_username_password_mismatch_case1(self): + with pytest.raises(ValueError, match='username and password'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + # mismatch + spark_history_username='', + spark_history_password='pass' + )) + def test_check_oa_conf_when_username_password_mismatch_case2(self): + with pytest.raises(ValueError, match='username and password'): + check_oa_conf() diff --git a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py index 3c9bfd314fb85108a20490b1b9ba8196e4a04653..69fe108c6faa420f044cfe4fc5818f644bd67b26 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py +++ b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py @@ -15,7 +15,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_successful_execution(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, mock_process_config): + def test_successful_execution(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + mock_process_config): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "test_job"} user_config = {"spark.executor.memory": "4g"} @@ -38,7 +39,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_fallback_to_safe_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, mock_process_config, mock_multiprocess): + def test_fallback_to_safe_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + mock_process_config, mock_multiprocess): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job_name"} user_config = {"spark.executor.memory": "4g"} @@ -56,9 +58,10 @@ class TestHijackRecommend(unittest.TestCase): mock_process = MagicMock() mock_multiprocess.return_value = mock_process - hijack_recommend(argv) - assert mock_spark_run.call_count == 2 - mock_process.start.assert_called_once() + with pytest.raises(RuntimeError, + match="Spark execute failed, ready to activate security protection mechanism."): + hijack_recommend(argv) + assert mock_spark_run.call_count == 1 # 场景 3: 执行失败 + 不进入安全机制 @patch("omniadvisor.interface.hijack_recommend._process_load_config") @@ -66,7 +69,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_failed_user_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, mock_process_config): + def test_failed_user_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + mock_process_config): argv = ["--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job_name"} user_config = {"spark.executor.memory": "4g"} @@ -99,7 +103,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_process_test_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, mock_process_config, mock_multiprocess): + def test_process_test_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + mock_process_config, mock_multiprocess): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job"} user_config = {"spark.executor.memory": "4g"} @@ -110,12 +115,11 @@ class TestHijackRecommend(unittest.TestCase): mock_parse_cmd.return_value = (exec_attr, user_config) mock_query_load.return_value = load mock_get_config.return_value = exec_config - mock_spark_run.return_value = (MagicMock(status="success"), "job output") + mock_spark_run.return_value = (MagicMock(status=OA_CONF.ExecStatus.success), "job output") mock_process = MagicMock() mock_multiprocess.return_value = mock_process hijack_recommend(argv) - mock_process.start.assert_called_once() # 场景 6: 命令解析异常 @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") @@ -158,7 +162,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend.LoadRepository.update_test_config") @patch("omniadvisor.interface.hijack_recommend.get_tuning_result_history") @patch("omniadvisor.interface.hijack_recommend.get_tuning_result") - def test_process_load_config(self, mock_get_tuning_result, mock_get_tuning_result_history, mock_update_test_config, mock_float_format): + def test_process_load_config(self, mock_get_tuning_result, mock_get_tuning_result_history, mock_update_test_config, + mock_float_format): best_config = {'param1': 'value1'} test_config = {'param2': 'value2'} diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py index 8c05130d042d218c74a4c2586a6b3ff0369d21ee..1b3a9f2154a732d97398acdbded66a03a51ae061 100755 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py @@ -1,13 +1,14 @@ import json +from unittest import mock + import pytest import requests -from unittest import mock -from omniadvisor.service.spark_service.spark_fetcher import SparkFetcher -from common.constant import OA_CONF +from common.constant import OA_CONF +from omniadvisor.service.spark_service.spark_fetcher import SparkFetcher history_server_url = OA_CONF.spark_history_rest_url -spark_fetcher = SparkFetcher(history_server_url) +spark_fetcher = SparkFetcher(history_server_url, None, None) # 示例数据 MOCK_APPS_RESPONSE = [ @@ -117,12 +118,12 @@ class TestSparkFetcher: mock_response.text = json.dumps(MOCK_APPS_RESPONSE) # 配置raise_for_status_code方法 使其根据状态码抛出HTTPError - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("HTTP error occurred", response=mock_response) + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("HTTP error occurred", + response=mock_response) # 配置mock对象的返回值 mock_get.return_value = mock_response - # 断言HTTPError被抛出 with pytest.raises(requests.exceptions.HTTPError): spark_fetcher.get_spark_apps() @@ -140,5 +141,25 @@ class TestSparkFetcher: with pytest.raises(requests.exceptions.HTTPError): spark_fetcher.get_spark_executor_by_app("app-1") + @mock.patch('requests.get') + def test_make_request_when_without_username(self, mock_get): + # 不带用户验证情况 + mock_response = mock.Mock() + mock_response.text = "{}" + mock_get.return_value = mock_response + fetcher = SparkFetcher('server_url', '', '') + endpoint = 'endpoint' + fetcher._make_request(endpoint) + mock_get.assert_called_once_with('server_url/endpoint') - + @mock.patch('requests.get') + def test_make_request_when_with_username(self, mock_get): + # 带用户验证情况 + mock_response = mock.Mock() + mock_response.text = "{}" + mock_get.return_value = mock_response + fetcher = SparkFetcher('server_url', 'user', 'pass') + endpoint = 'endpoint' + fetcher._make_request(endpoint) + mock_get.assert_called_once_with('server_url/endpoint', auth=('user', 'pass'), verify=False, + allow_redirects=False)