From 2b1df9f980c3ea3477e2cc435da080e482daa1e8 Mon Sep 17 00:00:00 2001 From: yang_feida Date: Sat, 2 Aug 2025 15:48:39 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E8=8E=B7=E5=8F=96history=20server=E6=97=B6?= =?UTF-8?q?=EF=BC=8C=E6=B7=BB=E5=8A=A0=E7=94=A8=E6=88=B7=E9=AA=8C=E8=AF=81?= =?UTF-8?q?=E7=9A=84=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/config/common_config.cfg | 4 +++ omniadvisor/src/common/constant.py | 2 ++ .../service/spark_service/spark_fetcher.py | 28 +++++++++++++++++-- .../service/spark_service/spark_run.py | 4 ++- .../spark_service/test_spark_fetcher.py | 10 +++++-- 5 files changed, 42 insertions(+), 6 deletions(-) diff --git a/omniadvisor/config/common_config.cfg b/omniadvisor/config/common_config.cfg index 240e42c4e..4248a8900 100755 --- a/omniadvisor/config/common_config.cfg +++ b/omniadvisor/config/common_config.cfg @@ -9,6 +9,10 @@ tuning.strategy=[["transfer", 1],["expert", 2],["iterative", 10]] [spark] # Spark History Server的URL 仅用于Rest模式 spark.history.rest.url=http://localhost:18080 +# Spark History Server的URL 的用户名,仅需要时填写 +spark.history.username= +# Spark History Server的URL 的密码,仅需要时填写 +spark.history.password= # Spark从History Sever抓取Trace的超时时间 spark.fetch.trace.timeout=30 # Spark从History Sever抓取Trace的间隔用时 diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index 98a20ff69..e6e21d502 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -96,6 +96,8 @@ class OmniAdvisorConf: tuning_retest_times = _common_config.getint('common', 'tuning.retest.times') config_fail_threshold = _common_config.getint('common', 'config.fail.threshold') spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url') + spark_history_username = _common_config.get('spark', 'spark.history.username') + spark_history_password = _common_config.get('spark', 'spark.history.password') spark_fetch_trace_timeout = _common_config.getint('spark', 'spark.fetch.trace.timeout') spark_fetch_trace_interval = _common_config.getint('spark', 'spark.fetch.trace.interval') spark_exec_timeout_ratio = _common_config.getfloat('spark', 'spark.exec.timeout.ratio') diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py index de1a7dd73..c3fb5dc3d 100755 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py @@ -3,15 +3,19 @@ import requests class SparkFetcher: - def __init__(self, history_server_url,): + def __init__(self, history_server_url, username, password): """ 初始化SparkFetcher类 :param history_server_url: Spark History Server的URL地址 """ self.history_server_url = history_server_url.rstrip('/') # 确保URL末尾没有斜杠 + self.username = username + self.password = password + # 若用户配置了用户名和密码,那么应当使用 with_auth 方法 + self._make_request = self._make_request_with_auth if self.username else self._make_request_without_auth - def _make_request(self, endpoint): + def _make_request_without_auth(self, endpoint): """ 发送GET请求并处理响应。 @@ -31,6 +35,26 @@ class SparkFetcher: raise ValueError('Something wrong in trace fetched, can not decode into Json data.') from e return json_data + def _make_request_with_auth(self, endpoint): + """ + 发送带 ssl 认证的 http 请求 + :param endpoint: API端点路径 + :return: 解析后的JSON响应数据 + :raises: requests.exceptions.HTTPError 如果HTTP请求返回了错误状态码 + """ + endpoint = endpoint.strip("/") + url = f"{self.history_server_url}/{endpoint}" + # 使用 self.username 和 self.password 进行 HTTP Basic 认证 & 忽略 SSL 证书验证 & 禁止自动跳转 + response = requests.get(url, auth=(self.username, self.password), verify=False, allow_redirects=False) + # 若status.code != 200~209则抛出异常 + response.raise_for_status() + # 将获取到的数据转为json返回 + try: + json_data = json.loads(response.text) + except Exception as e: + raise ValueError('Something wrong in trace fetched, can not decode into Json data.') from e + return json_data + def get_spark_apps(self): """ 获取所有Spark应用的信息。 diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 403e120eb..f8f3b59e3 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -105,7 +105,9 @@ def _update_trace_from_history_server(exam_record: ExamRecord, application_id: s """ trace_dict = {} history_server_url = OA_CONF.spark_history_rest_url - spark_fetcher = SparkFetcher(history_server_url) + history_server_username = OA_CONF.spark_history_username + history_server_password = OA_CONF.spark_history_password + spark_fetcher = SparkFetcher(history_server_url, history_server_username, history_server_password) start_time = time.time() while time.time() - start_time < OA_CONF.spark_fetch_trace_timeout: try: diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py index 8c05130d0..feb3b3ae5 100755 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py @@ -5,9 +5,8 @@ from unittest import mock from omniadvisor.service.spark_service.spark_fetcher import SparkFetcher from common.constant import OA_CONF - history_server_url = OA_CONF.spark_history_rest_url -spark_fetcher = SparkFetcher(history_server_url) +spark_fetcher = SparkFetcher(history_server_url, None, None) # 示例数据 MOCK_APPS_RESPONSE = [ @@ -140,5 +139,10 @@ class TestSparkFetcher: with pytest.raises(requests.exceptions.HTTPError): spark_fetcher.get_spark_executor_by_app("app-1") + def test_init_when_with_auth(self): + # 带用户验证情况 + pass - + def test_init_when_without_auth(self): + # 不带用户验证情况 + pass -- Gitee From e71c2bd1d38f1f1206a351fe75d88b6f1ad2e440 Mon Sep 17 00:00:00 2001 From: yang_feida Date: Sat, 2 Aug 2025 16:26:38 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E5=8E=BB=E9=99=A4=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E5=85=9C=E5=BA=95=E6=9C=BA=E5=88=B6=EF=BC=8C?= =?UTF-8?q?=E6=94=B9=E4=B8=BAspark-submit=E5=85=9C=E5=BA=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/omniadvisor/interface/hijack_recommend.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index e48e21dff..bc6980a4e 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -140,11 +140,7 @@ def hijack_recommend(argv: list): # 若执行失败 则判断是否需要拉起安全机制 else: if exec_config != user_config: - global_logger.warning("Spark execute failed, ready to activate security protection mechanism.") - safe_exam_record, safe_output = spark_run(load=load, config=user_config, wait_for_trace=False) - global_logger.info("Spark execute in security protection mechanism, going to print Spark output.") - # 打印安全机制下任务的输出 - print(safe_output, end="", flush=True) + raise RuntimeError("Spark execute failed, ready to activate security protection mechanism.") else: global_logger.warning("Spark execute failed in user config, going to print Spark output.") print(output, end="", flush=True) -- Gitee From 4d9b08423c2e619d28745911121a334f19b22107 Mon Sep 17 00:00:00 2001 From: yang_feida Date: Mon, 4 Aug 2025 11:09:14 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E8=AF=84=E5=AE=A1=E6=84=8F=E8=A7=81?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/common/constant.py | 11 +- .../service/spark_service/spark_fetcher.py | 30 +---- .../tests/omniadvisor/common/__init__.py | 0 .../tests/omniadvisor/common/test_constant.py | 122 ++++++++++++++++++ .../interface/test_hijack_recommend.py | 25 ++-- .../spark_service/test_spark_fetcher.py | 37 ++++-- 6 files changed, 179 insertions(+), 46 deletions(-) create mode 100644 omniadvisor/tests/omniadvisor/common/__init__.py create mode 100644 omniadvisor/tests/omniadvisor/common/test_constant.py diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index e6e21d502..ba212b7c7 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -36,6 +36,13 @@ def check_oa_conf(): if OA_CONF.spark_exec_timeout_ratio <= 0: raise ValueError('The spark exec timeout ratio must > 0, please check common configuration.') + if not ( + (OA_CONF.spark_history_username and OA_CONF.spark_history_password) or + (not OA_CONF.spark_history_username and not OA_CONF.spark_history_password) + ): + raise ValueError('The spark history username and password should be provided, or leave both blank, please check' + ' common configuration.') + class OmniAdvisorConf: """ @@ -96,8 +103,8 @@ class OmniAdvisorConf: tuning_retest_times = _common_config.getint('common', 'tuning.retest.times') config_fail_threshold = _common_config.getint('common', 'config.fail.threshold') spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url') - spark_history_username = _common_config.get('spark', 'spark.history.username') - spark_history_password = _common_config.get('spark', 'spark.history.password') + spark_history_username = _common_config.get('spark', 'spark.history.username', fallback='') + spark_history_password = _common_config.get('spark', 'spark.history.password', fallback='') spark_fetch_trace_timeout = _common_config.getint('spark', 'spark.fetch.trace.timeout') spark_fetch_trace_interval = _common_config.getint('spark', 'spark.fetch.trace.interval') spark_exec_timeout_ratio = _common_config.getfloat('spark', 'spark.exec.timeout.ratio') diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py index c3fb5dc3d..0e1aec2b7 100755 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py @@ -12,10 +12,8 @@ class SparkFetcher: self.history_server_url = history_server_url.rstrip('/') # 确保URL末尾没有斜杠 self.username = username self.password = password - # 若用户配置了用户名和密码,那么应当使用 with_auth 方法 - self._make_request = self._make_request_with_auth if self.username else self._make_request_without_auth - def _make_request_without_auth(self, endpoint): + def _make_request(self, endpoint): """ 发送GET请求并处理响应。 @@ -25,27 +23,11 @@ class SparkFetcher: """ endpoint = endpoint.strip("/") url = f"{self.history_server_url}/{endpoint}" - response = requests.get(url) - # 若status.code != 200~209则抛出异常 - response.raise_for_status() - # 将获取到的数据转为json返回 - try: - json_data = json.loads(response.text) - except Exception as e: - raise ValueError('Something wrong in trace fetched, can not decode into Json data.') from e - return json_data - - def _make_request_with_auth(self, endpoint): - """ - 发送带 ssl 认证的 http 请求 - :param endpoint: API端点路径 - :return: 解析后的JSON响应数据 - :raises: requests.exceptions.HTTPError 如果HTTP请求返回了错误状态码 - """ - endpoint = endpoint.strip("/") - url = f"{self.history_server_url}/{endpoint}" - # 使用 self.username 和 self.password 进行 HTTP Basic 认证 & 忽略 SSL 证书验证 & 禁止自动跳转 - response = requests.get(url, auth=(self.username, self.password), verify=False, allow_redirects=False) + if self.username: + # 若用户配置了用户名和密码,那么应当走验证通道 + response = requests.get(url, auth=(self.username, self.password), verify=False, allow_redirects=False) + else: + response = requests.get(url) # 若status.code != 200~209则抛出异常 response.raise_for_status() # 将获取到的数据转为json返回 diff --git a/omniadvisor/tests/omniadvisor/common/__init__.py b/omniadvisor/tests/omniadvisor/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/omniadvisor/tests/omniadvisor/common/test_constant.py b/omniadvisor/tests/omniadvisor/common/test_constant.py new file mode 100644 index 000000000..77c823244 --- /dev/null +++ b/omniadvisor/tests/omniadvisor/common/test_constant.py @@ -0,0 +1,122 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from common.constant import check_oa_conf + +OA_CONF_PATH = 'common.constant.OA_CONF' + + +class TestConstant: + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_normal(self): + # 不应抛异常 + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + # invalid + tuning_retest_times=0, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_tuning_retest_times(self): + with pytest.raises(ValueError, match='tuning retest times'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + # invalid + config_fail_threshold=0, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_config_fail_threshold(self): + with pytest.raises(ValueError, match='config fail threshold'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + # invalid + spark_fetch_trace_timeout=0, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_fetch_trace_timeout(self): + with pytest.raises(ValueError, match='spark fetch trace timeout'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + # invalid + spark_fetch_trace_interval=0, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_fetch_trace_interval(self): + with pytest.raises(ValueError, match='spark fetch trace interval'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + # invalid + spark_exec_timeout_ratio=0, + spark_history_username='user', + spark_history_password='pass' + )) + def test_check_oa_conf_when_invalid_exec_timeout_ratio(self): + with pytest.raises(ValueError, match='spark exec timeout ratio'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + spark_history_username='user', + # mismatch + spark_history_password='' + )) + def test_check_oa_conf_when_username_password_mismatch_case1(self): + with pytest.raises(ValueError, match='username and password'): + check_oa_conf() + + @patch(OA_CONF_PATH, new=SimpleNamespace( + tuning_retest_times=1, + config_fail_threshold=1, + spark_fetch_trace_timeout=10, + spark_fetch_trace_interval=5, + spark_exec_timeout_ratio=0.5, + # mismatch + spark_history_username='', + spark_history_password='pass' + )) + def test_check_oa_conf_when_username_password_mismatch_case2(self): + with pytest.raises(ValueError, match='username and password'): + check_oa_conf() diff --git a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py index 3c9bfd314..69fe108c6 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py +++ b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py @@ -15,7 +15,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_successful_execution(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, mock_process_config): + def test_successful_execution(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + mock_process_config): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "test_job"} user_config = {"spark.executor.memory": "4g"} @@ -38,7 +39,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_fallback_to_safe_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, mock_process_config, mock_multiprocess): + def test_fallback_to_safe_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + mock_process_config, mock_multiprocess): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job_name"} user_config = {"spark.executor.memory": "4g"} @@ -56,9 +58,10 @@ class TestHijackRecommend(unittest.TestCase): mock_process = MagicMock() mock_multiprocess.return_value = mock_process - hijack_recommend(argv) - assert mock_spark_run.call_count == 2 - mock_process.start.assert_called_once() + with pytest.raises(RuntimeError, + match="Spark execute failed, ready to activate security protection mechanism."): + hijack_recommend(argv) + assert mock_spark_run.call_count == 1 # 场景 3: 执行失败 + 不进入安全机制 @patch("omniadvisor.interface.hijack_recommend._process_load_config") @@ -66,7 +69,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_failed_user_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, mock_process_config): + def test_failed_user_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + mock_process_config): argv = ["--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job_name"} user_config = {"spark.executor.memory": "4g"} @@ -99,7 +103,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_process_test_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, mock_process_config, mock_multiprocess): + def test_process_test_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + mock_process_config, mock_multiprocess): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job"} user_config = {"spark.executor.memory": "4g"} @@ -110,12 +115,11 @@ class TestHijackRecommend(unittest.TestCase): mock_parse_cmd.return_value = (exec_attr, user_config) mock_query_load.return_value = load mock_get_config.return_value = exec_config - mock_spark_run.return_value = (MagicMock(status="success"), "job output") + mock_spark_run.return_value = (MagicMock(status=OA_CONF.ExecStatus.success), "job output") mock_process = MagicMock() mock_multiprocess.return_value = mock_process hijack_recommend(argv) - mock_process.start.assert_called_once() # 场景 6: 命令解析异常 @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") @@ -158,7 +162,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend.LoadRepository.update_test_config") @patch("omniadvisor.interface.hijack_recommend.get_tuning_result_history") @patch("omniadvisor.interface.hijack_recommend.get_tuning_result") - def test_process_load_config(self, mock_get_tuning_result, mock_get_tuning_result_history, mock_update_test_config, mock_float_format): + def test_process_load_config(self, mock_get_tuning_result, mock_get_tuning_result_history, mock_update_test_config, + mock_float_format): best_config = {'param1': 'value1'} test_config = {'param2': 'value2'} diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py index feb3b3ae5..1b3a9f215 100755 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py @@ -1,9 +1,11 @@ import json +from unittest import mock + import pytest import requests -from unittest import mock -from omniadvisor.service.spark_service.spark_fetcher import SparkFetcher + from common.constant import OA_CONF +from omniadvisor.service.spark_service.spark_fetcher import SparkFetcher history_server_url = OA_CONF.spark_history_rest_url spark_fetcher = SparkFetcher(history_server_url, None, None) @@ -116,12 +118,12 @@ class TestSparkFetcher: mock_response.text = json.dumps(MOCK_APPS_RESPONSE) # 配置raise_for_status_code方法 使其根据状态码抛出HTTPError - mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("HTTP error occurred", response=mock_response) + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("HTTP error occurred", + response=mock_response) # 配置mock对象的返回值 mock_get.return_value = mock_response - # 断言HTTPError被抛出 with pytest.raises(requests.exceptions.HTTPError): spark_fetcher.get_spark_apps() @@ -139,10 +141,25 @@ class TestSparkFetcher: with pytest.raises(requests.exceptions.HTTPError): spark_fetcher.get_spark_executor_by_app("app-1") - def test_init_when_with_auth(self): - # 带用户验证情况 - pass - - def test_init_when_without_auth(self): + @mock.patch('requests.get') + def test_make_request_when_without_username(self, mock_get): # 不带用户验证情况 - pass + mock_response = mock.Mock() + mock_response.text = "{}" + mock_get.return_value = mock_response + fetcher = SparkFetcher('server_url', '', '') + endpoint = 'endpoint' + fetcher._make_request(endpoint) + mock_get.assert_called_once_with('server_url/endpoint') + + @mock.patch('requests.get') + def test_make_request_when_with_username(self, mock_get): + # 带用户验证情况 + mock_response = mock.Mock() + mock_response.text = "{}" + mock_get.return_value = mock_response + fetcher = SparkFetcher('server_url', 'user', 'pass') + endpoint = 'endpoint' + fetcher._make_request(endpoint) + mock_get.assert_called_once_with('server_url/endpoint', auth=('user', 'pass'), verify=False, + allow_redirects=False) -- Gitee