diff --git a/omniadvisor/config/common_config.cfg b/omniadvisor/config/common_config.cfg
index a7208d244313e06ccb37c86307020b3a381fe2ba..73ccf53ef736fabb8c26d653cbf643a09930b9ac 100755
--- a/omniadvisor/config/common_config.cfg
+++ b/omniadvisor/config/common_config.cfg
@@ -4,4 +4,4 @@
 tuning.retest.times=3
 [spark]
 # Spark History Server URL, used only in REST mode
-spark.history.rest.url=http://localhost:18081
+spark.history.rest.url=http://localhost:18080
\ No newline at end of file
diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py
index f545a9eca6835568b39d614a1842a47351811f06..450f60e242b41ffa15fea9025e48bf668ab4808d 100644
--- a/omniadvisor/src/common/constant.py
+++ b/omniadvisor/src/common/constant.py
@@ -74,5 +74,4 @@ class OmniAdvisorConf:
     tuning_retest_times = _common_config.getint('common', 'tuning.retest.times')
     spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url')
-
 
 OA_CONF = OmniAdvisorConf()
diff --git a/omniadvisor/src/hijack.py b/omniadvisor/src/hijack.py
index ed3553496b4c7e22660fe9a9d21ebc761e1a149f..cc979e0cd9ad3392db3c51d9cf6d4bbc0ec03114 100644
--- a/omniadvisor/src/hijack.py
+++ b/omniadvisor/src/hijack.py
@@ -1,19 +1,21 @@
 from omniadvisor.interface import hijack_recommend
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger, disable_console_handler
 
 
 if __name__ == '__main__':
+    # The foreground hijack must be imperceptible to the user, so console log output is disabled
+    disable_console_handler(global_logger)
     try:
         hijack_recommend.main()
     # Exceptions that need no extra handling are raised directly to this layer
     # Exceptions that do need handling (e.g. environment cleanup) are handled where they occur, then re-raised to this layer
     except Exception as e:
         # Stack traces are logged here in one place to ease debugging; raise sites need not log
-        logger.exception(e)
+        global_logger.exception(e)
         # Abnormal exit
         # After the hijacked task exits abnormally, the Spark script must run the original Spark command
         exit(1)
 
     # Normal exit
-    logger.info('Hijack mission complete!')
+    global_logger.info('Hijack mission complete!')
     exit(0)
diff --git a/omniadvisor/src/omniadvisor/interface/config_tuning.py b/omniadvisor/src/omniadvisor/interface/config_tuning.py
index dfa85f02d9ab364b58098dc733d53ff979b39355..d5f5408ae42c5bbcdcbb17ff86bbbd62839e9ed4 100644
--- a/omniadvisor/src/omniadvisor/interface/config_tuning.py
+++ b/omniadvisor/src/omniadvisor/interface/config_tuning.py
@@ -1,7 +1,7 @@
 import argparse
 
 from common.constant import OA_CONF
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger
 from omniadvisor.repository.load_repository import LoadRepository
 from omniadvisor.service.tuning_result.tuning_result_history import get_tuning_result_history
 from omniadvisor.service.retest_service import retest
@@ -42,7 +42,7 @@ def unified_tuning(load_id: str, retest_way: str, tuning_method: str):
     loads = LoadRepository.query_by_id(load_id)
     if not loads:
         # If the query result is empty, simply return
-        logger.info('Cannot find load id: %s in database.', load_id)
+        global_logger.info('Cannot find load id: %s in database.', load_id)
         return
 
     load = loads[0]
diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py
index 3a8290d64322a10c543a8cb68fb74e8b59186126..12de329f572fdd47304a9127568345d5d6b534ff 100644
--- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py
+++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py
@@ -3,7 +3,7 @@ import sys
 from omniadvisor.service.spark_service.spark_parameter_parser import SparkParameterParser
 from omniadvisor.repository.model.load import Load
 from omniadvisor.repository.load_repository import LoadRepository
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger
 from omniadvisor.service.spark_service.spark_run import spark_run
 from common.constant import OA_CONF
@@ -21,7 +21,7 @@ def _query_or_create_load(name: str, exec_attr: dict, default_config: dict):
     loads = LoadRepository.query_by_name_and_default_config(name=name, default_config=default_config)
     # If no load is found, create a new one
     if not loads:
-        logger.info("Not exit available load, create new one.")
+        global_logger.info("No available load exists, creating a new one.")
         # TODO If the user creates a malformed config here, the Spark run will fail; how should this be handled?
         load = LoadRepository.create(name=name, exec_attr=exec_attr, default_config=default_config)
     else:
diff --git a/omniadvisor/src/omniadvisor/service/retest_service.py b/omniadvisor/src/omniadvisor/service/retest_service.py
index b25edf6c97fae4fdd280cc6d531ddd67a2410188..f982b1649f6eaba6a7b1cc63739343d8bc011165 100644
--- a/omniadvisor/src/omniadvisor/service/retest_service.py
+++ b/omniadvisor/src/omniadvisor/service/retest_service.py
@@ -1,7 +1,7 @@
 from common.constant import OA_CONF
 from omniadvisor.repository.model.load import Load
 from omniadvisor.service.spark_service.spark_run import spark_run
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger
 
 
 def retest(load: Load, config: dict):
@@ -12,17 +12,15 @@ def retest(load: Load, config: dict):
     :param config: configuration
     :return:
     """
-    logger.debug('Starting retest config...')
+    global_logger.debug('Starting config retest...')
     for i in range(1, OA_CONF.tuning_retest_times + 1):
         try:
            exam_record, spark_output = spark_run(load, config)
         except Exception as e:
-            logger.error('Retest round %d failed; source: non-Spark runtime error, details: %s', i, e)
-            # Current strategy: raise, since the failure is assumed to be unrelated to the config
-            raise
+            global_logger.error('Retest round %d failed, reason: %s', i, e)
+            return
         if exam_record.status == OA_CONF.ExecStatus.success:
-            logger.info('Retest round %d succeeded, performance result: %.3f', i, exam_record.runtime)
+            global_logger.info('Retest round %d succeeded, performance result: %.3f', i, exam_record.runtime)
         else:
-            # Exit as well when the config itself is faulty
-            raise RuntimeError('Retest round %d failed; source: Spark runtime error: %s', i, spark_output)
+            global_logger.error('Retest round %d failed, Spark execution error: %s', i, spark_output)
diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py
index e42ef37d1e97c6208140ba2c167884af063b86d1..31aa2dfa3e0312df38ae23a6d60a795bcbc3d20b 100644
--- a/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py
+++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py
@@ -1,4 +1,5 @@
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger
+from common.constant import OA_CONF
 
 
 def spark_command_reconstruct(load, conf):
@@ -7,11 +8,15 @@ def spark_command_reconstruct(load, conf):
     :param conf: {key:value} pairs for the parts of the spark-sql command of the form --conf spark.sql.shuffle.partitions(key)=3200(value)
     :return:
     """
-    cmd_prefix = "spark-sql"
+    cmd_prefix = load.exec_attr.get('cmd_prefix')
+    if not cmd_prefix:
+        raise AttributeError("the spark-sql command prefix is missing from exec_attr")
     name = load.name
 
     submit_cmd_list = [cmd_prefix, "--name", name]
     for key, value in load.exec_attr.items():
+        if key == 'cmd_prefix':
+            continue
         # The -i option does not support the -- form and is handled separately
         if key == "i":
             submit_cmd_list.append("-i")
@@ -38,7 +43,7 @@ def spark_command_reconstruct(load, conf):
         submit_cmd_list.append(f"{_normalize_key(key)}={_normalize_value(value)}")
 
     submit_cmd = " ".join(submit_cmd_list)
-    logger.info(f"The reconstructed spark-sql command is: {submit_cmd}")
+    global_logger.info(f"The reconstructed spark-sql command is: {submit_cmd}")
     return submit_cmd
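Together with the parser change below, the prefix handling round-trips: the parser stashes everything up to the second space of the original command in exec_attr['cmd_prefix'], and reconstruction re-emits it verbatim. A minimal standalone sketch of that reconstruction logic (it inlines the key/value normalization that the real module delegates to _normalize_key/_normalize_value, and the sample Load shape is an assumption for illustration):

from types import SimpleNamespace


def rebuild_command(load, conf):
    # Same contract as spark_command_reconstruct: exec_attr must carry the submit prefix
    cmd_prefix = load.exec_attr.get('cmd_prefix')
    if not cmd_prefix:
        raise AttributeError("the spark-sql command prefix is missing from exec_attr")
    parts = [cmd_prefix, "--name", load.name]
    for key, value in load.exec_attr.items():
        if key == 'cmd_prefix':  # the prefix is not a CLI option, skip it
            continue
        if key == "i":           # -i does not support the -- form
            parts += ["-i", str(value)]
        else:
            parts += [f"--{key}", str(value)]
    for key, value in conf.items():
        parts.append(f"--conf {key}={value}")
    return " ".join(parts)


load = SimpleNamespace(
    name="TestApp",
    exec_attr={"cmd_prefix": "/opt/spark/bin/spark-class org.apache.spark.deploy.SparkSubmit",
               "driver-memory": "8G"})
print(rebuild_command(load, {"spark.sql.shuffle.partitions": "3200"}))
# /opt/spark/bin/spark-class org.apache.spark.deploy.SparkSubmit --name TestApp --driver-memory 8G --conf spark.sql.shuffle.partitions=3200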
diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py
index 125807e653779019a4407e593e0b5abab8d1e764..1c3426177c44e796cf4df8e98993c59fd5cba197 100644
--- a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py
+++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py
@@ -20,7 +20,8 @@ class SparkExecutor:
         spark_output = spark_output.split("\n")
         spark_submit_cmd = ""
         application_id = ""
-        time_taken = -1.0
+        total_time_taken = -1.0
+        time_taken_list = []
         for item in spark_output:
             if "deploy.SparkSubmit" in item:
                 spark_submit_cmd = item.strip()
@@ -35,7 +36,7 @@ class SparkExecutor:
             pattern = r"Time taken:\s*(.*)seconds"
             match = re.search(pattern, item)
             if match:
-                time_taken = match.group(1).strip()
-            if application_id and float(time_taken) > 0:
-                break
-        return spark_submit_cmd, application_id, float(time_taken)
+                time_taken_list.append(float(match.group(1).strip()))
+        if time_taken_list:
+            total_time_taken = sum(time_taken_list)
+        return spark_submit_cmd, application_id, float(total_time_taken)
diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py
index 028fd18e4d9c8e8dedfad09a09ce26d9d4826901..07ba72cfa17740d31e0688b4a25f04cc517612e3 100644
--- a/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py
+++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py
@@ -1,6 +1,6 @@
 import shlex
 from dataclasses import dataclass
 
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger
 from common.constant import OA_CONF
 from omniadvisor.service.parameter_parser import ParserInterface
@@ -84,11 +84,16 @@ class SparkParameterParser(ParserInterface):
         """
         if not self._submit_cmd:
             raise ValueError("Submit command cannot be null")
-        param_start_index = self._submit_cmd.strip().index(" ")
         name = ""
         conf_params = {}
         exec_attr = {}
+        # Extract the absolute spark-class path
+        first_space_index = self._submit_cmd.strip().find(" ")
+        # Find the second space, starting just after the first one
+        second_space_index = self._submit_cmd.strip().find(" ", first_space_index + 1)
+        exec_attr['cmd_prefix'] = self._submit_cmd[0:second_space_index]
+        # Parse the remaining arguments
-        args = shlex.split(self._submit_cmd[param_start_index:])
+        args = shlex.split(self._submit_cmd[second_space_index:])
         params, unknown = self._parser.parse_known_args(args=args)
         options_dict = vars(params)
         for key, value in options_dict.items():
@@ -111,7 +116,7 @@ class SparkParameterParser(ParserInterface):
                         conf_params[confkey].append(confvalue)
                     else:
                         conf_params[confkey] = [conf_params[confkey], confvalue]
-                        logger.warn(f"{confkey} was configured more than once; configured values: {conf_params[confvalue]}")
+                        global_logger.warning(f"{confkey} was configured more than once; configured values: {conf_params[confkey]}")
                 else:
                     conf_params[confkey] = confvalue
             # Group options such as --num-executors --executor-cores --driver-cores into conf_params to ease later command reconstruction
@@ -119,6 +124,6 @@ class SparkParameterParser(ParserInterface):
             elif key in SPARK_CONF_SUPPLEMENT_MAP:
                 conf_params[SPARK_CONF_SUPPLEMENT_MAP[key]] = value
             else:
                 name = value
-        logger.warn(f"The remainder unknown params {unknown} will be add to exec_attr[remainder]")
+        global_logger.warning(f"The remaining unknown params {unknown} will be added to exec_attr[remainder]")
         exec_attr["remainder"] = unknown
         return BaseParams(name=name, conf_params=conf_params, exec_attr=exec_attr)
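The executor change above swaps first-match extraction for accumulation: a multi-statement spark-sql session prints one "Time taken: N seconds" line per statement, and the total runtime is now the sum of all of them rather than the first positive match. A minimal sketch of that accounting, reusing the same regex as the hunk (the sample output is fabricated for illustration):

import re


def total_time_taken(spark_output: str) -> float:
    # Collect every "Time taken: N seconds" occurrence instead of stopping at the first
    times = [float(m.group(1).strip())
             for line in spark_output.split("\n")
             if (m := re.search(r"Time taken:\s*(.*)seconds", line))]
    # -1.0 keeps the old sentinel meaning "no timing lines found"
    return sum(times) if times else -1.0


sample = "Time taken: 1.5 seconds\nTime taken: 2.5 seconds, Fetched 10 row(s)"
assert total_time_taken(sample) == 4.0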
diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py
index 5737f7aabf9cb8f4810c4c8684911bbf20b2e75c..054036f8cc1180242d3a90bdd92d7dbea40f7351 100644
--- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py
+++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py
@@ -9,8 +9,8 @@ from omniadvisor.repository.model.exam_record import ExamRecord
 from omniadvisor.service.spark_service.spark_command_reconstruct import spark_command_reconstruct
 from omniadvisor.service.spark_service.spark_executor import SparkExecutor
 from omniadvisor.service.spark_service.spark_fetcher import SparkFetcher
-from omniadvisor.utils.logger import logger
-from omniadvisor.utils.utils import trace_data_saver
+from omniadvisor.utils.logger import global_logger
+from omniadvisor.utils.utils import save_trace_data
 
 
 def spark_run(load, conf):
@@ -28,16 +28,16 @@ def spark_run(load, conf):
     except TimeoutError:
         # e.g. the task submission timed out
         exam_record.delete()
-        logger.error('Spark command submission timed out.')
+        global_logger.error('Spark command submission timed out.')
         raise
     except OSError:
         # e.g. insufficient permissions
         exam_record.delete()
-        logger.error('Spark command submission permission denied.')
+        global_logger.error('Spark command submission permission denied.')
         raise
     except Exception:
         exam_record.delete()
-        logger.error('During Spark command submission, known error occurred.')
+        global_logger.error('During Spark command submission, an unknown error occurred.')
         raise
 
     # When there is no application_id, skip extracting time_taken and return directly
@@ -60,7 +60,7 @@ def spark_run(load, conf):
         spark_submit_cmd, application_id, runtime = spark_executor.parser_spark_output(spark_output)
     except Exception:
         exam_record.delete()
-        logger.error('During parsing spark output, known error occurred.')
+        global_logger.error('During parsing spark output, an unknown error occurred.')
         raise
 
     try:
@@ -100,11 +100,11 @@ def _update_trace_from_history_server(exam_record: ExamRecord, application_id: str):
             break
         except HTTPError as httpe:
             time.sleep(1)
-            logger.warning(f"HistoryServer access error: {httpe}")
+            global_logger.warning(f"HistoryServer access error: {httpe}")
             continue
 
-    trace_dict['sql'] = trace_data_saver(data=trace_sql, data_dir=OA_CONF.data_dir)
-    trace_dict['stages'] = trace_data_saver(data=trace_stages, data_dir=OA_CONF.data_dir)
-    trace_dict['executor'] = trace_data_saver(data=trace_executor, data_dir=OA_CONF.data_dir)
+    trace_dict['sql'] = save_trace_data(data=trace_sql, data_dir=OA_CONF.data_dir)
+    trace_dict['stages'] = save_trace_data(data=trace_stages, data_dir=OA_CONF.data_dir)
+    trace_dict['executor'] = save_trace_data(data=trace_executor, data_dir=OA_CONF.data_dir)
 
     ExamRecordRepository.update_exam_result(exam_record, trace=trace_dict)
diff --git a/omniadvisor/src/omniadvisor/utils/logger.py b/omniadvisor/src/omniadvisor/utils/logger.py
index 9eddd2073dc77098b3c1e05ff6c1b05e0fa6e9ed..f1e070a3115674f0136ac14a5f38f1843787f840 100755
--- a/omniadvisor/src/omniadvisor/utils/logger.py
+++ b/omniadvisor/src/omniadvisor/utils/logger.py
@@ -56,5 +56,19 @@ if not os.path.exists(OA_CONF.log_dir):
 # Load the logging configuration via dictConfig
 dictConfig(LOGGING_CONFIG)
+
 
 # Fetch the logger for use
-logger = logging.getLogger('project_logger')
+global_logger = logging.getLogger('project_logger')
+
+
+def disable_console_handler(logger):
+    """
+    Remove the consoleHandler from the logger if one is present
+
+    :param logger: target logger instance
+    """
+    for handler in logger.handlers[:]:
+        if getattr(handler, "name", None) == 'consoleHandler':
+            logger.removeHandler(handler)  # disable the consoleHandler
+            return
+    logger.debug("consoleHandler does not exist, nothing to disable")
\ No newline at end of file
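hijack.py calls the helper above before any work starts, so a foreground hijack stays silent on the console while the file handlers keep writing. A self-contained sketch of that pattern, assuming (as the project's dictConfig appears to) that the console handler is registered under the name 'consoleHandler'; the demo logger and file path are invented for illustration:

import logging


def disable_console_handler(logger: logging.Logger) -> None:
    # Same logic as the helper above: drop the handler registered as 'consoleHandler'
    for handler in logger.handlers[:]:
        if getattr(handler, "name", None) == 'consoleHandler':
            logger.removeHandler(handler)
            return
    logger.debug("consoleHandler does not exist, nothing to disable")


demo_logger = logging.getLogger('demo_logger')
console = logging.StreamHandler()
console.name = 'consoleHandler'  # the name the helper matches on
demo_logger.addHandler(console)
demo_logger.addHandler(logging.FileHandler('/tmp/demo.log'))

disable_console_handler(demo_logger)
# Only the file handler remains; console output is gone
assert all(getattr(h, "name", None) != 'consoleHandler' for h in demo_logger.handlers)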
diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py
index f4a65957c82f1a4961ca41458a2c107114fd8897..9ea1230d88f2575e2d53b6f0ffdf7157771985d7 100644
--- a/omniadvisor/src/omniadvisor/utils/utils.py
+++ b/omniadvisor/src/omniadvisor/utils/utils.py
@@ -8,7 +8,7 @@ from ConfigSpace import ConfigurationSpace, UniformIntegerHyperparameter, Constant, \
     UniformFloatHyperparameter
 
 from common.constant import OA_CONF
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger
 
 
 def run_cmd(submit_cmd) -> Tuple[int, str]:
@@ -22,13 +22,13 @@ def run_cmd(submit_cmd) -> Tuple[int, str]:
         the second element is a string kept in a list, e.g. ["xxxx"]
     :rtype: Tuple[int, str]
     """
-    logger.debug(f"Executor system command: {submit_cmd}")
+    global_logger.debug(f"Executing system command: {submit_cmd}")
     exitcode, data = subprocess.getstatusoutput(submit_cmd)
-    logger.debug(data)
+    global_logger.debug(data)
     return exitcode, data
 
 
-def trace_data_saver(data: List[Dict[str, str]], data_dir):
+def save_trace_data(data: List[Dict[str, str]], data_dir):
     """
     Saves the various kinds of trace data to a file and returns the file's absolute path
     :param data: list[dict]-style data fetched via the REST API
@@ -60,11 +60,11 @@ def read_json_file(file_path: str):
             # Parse the JSON file content into a Python object
             data = json.load(file)
     except FileNotFoundError:
-        logger.error(f"Error: The file '{file_path}' does not exist.")
+        global_logger.error(f"Error: The file '{file_path}' does not exist.")
     except json.JSONDecodeError:
-        logger.error(f"Error: The file '{file_path}' is not a valid JSON file.")
+        global_logger.error(f"Error: The file '{file_path}' is not a valid JSON file.")
     except Exception as e:
-        logger.error(f"An unexpected error occurred: {e}")
+        global_logger.error(f"An unexpected error occurred: {e}")
 
     return data
diff --git a/omniadvisor/src/tuning.py b/omniadvisor/src/tuning.py
index 674e8e5651cae896372235b21ba8ab44fbc86af1..8f75b8e078735ca888c3b8884efe50e9a2275b3e 100644
--- a/omniadvisor/src/tuning.py
+++ b/omniadvisor/src/tuning.py
@@ -1,5 +1,5 @@
 from omniadvisor.interface import config_tuning
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger
 
 
 if __name__ == '__main__':
@@ -9,10 +9,10 @@ if __name__ == '__main__':
     # Exceptions that do need handling (e.g. environment cleanup) are handled where they occur, then re-raised to this layer
     except Exception as e:
         # Stack traces are logged here in one place to ease debugging; raise sites need not log
-        logger.exception(e)
+        global_logger.exception(e)
         # Abnormal exit
         exit(1)
 
     # Normal exit
-    logger.info('Tuning mission complete!')
+    global_logger.info('Tuning mission complete!')
     exit(0)
diff --git a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py
index 98b302bc6cd7e405a943b97fafe2668fb30d412c..43d93dfbed23bac4217c69241009d0589f8a90a2 100644
--- a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py
+++ b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py
@@ -84,7 +84,7 @@ class TestTuning:
         :return:
         """
         with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') as mock_query_by_id, \
-                patch('omniadvisor.utils.logger.logger.info') as mock_info:
+                patch('omniadvisor.utils.logger.global_logger.info') as mock_info:
             mock_query_by_id.return_value = None
             unified_tuning(load_id=self.load_id, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method)
             mock_info.assert_called_with('Cannot find load id: %s in database.', '2')
diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_command_reconstruct.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_command_reconstruct.py
index 0bf413bc09fc20cd46b6369c77843fe003b9a224..22abe880bdb63dc939760686e31f8c546de9bd0e 100644
--- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_command_reconstruct.py
+++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_command_reconstruct.py
@@ -1,4 +1,5 @@
 from unittest.mock import MagicMock
+from common.constant import OA_CONF
 
 from omniadvisor.service.spark_service.spark_command_reconstruct import spark_command_reconstruct
@@ -11,8 +12,9 @@ class TestSparkCommandReconstruct:
         # Define the conf argument
         conf = {"spark.sql.shuffle.partitions": "3200", "spark.executor.cores": "4"}
-
-        expected_cmd = "spark-sql --name TestApp" \
+        cmd_prefix = f"{OA_CONF.spark_home}/bin/spark-class org.apache.spark.deploy.SparkSubmit"
+        mock_load.exec_attr["cmd_prefix"] = cmd_prefix
+        expected_cmd = f"{cmd_prefix} --name TestApp" \
                        " --driver-memory 8G" \
                        " --executor-memory 16G" \
                        " --conf spark.sql.shuffle.partitions=3200" \
@@ -28,8 +30,9 @@ class TestSparkCommandReconstruct:
         mock_load.exec_attr = {"i": "init.script.path"}
         conf = {}
-
-        expected_cmd = "spark-sql --name InitScriptApp -i init.script.path"
+        cmd_prefix = f"{OA_CONF.spark_home}/bin/spark-class org.apache.spark.deploy.SparkSubmit"
+        mock_load.exec_attr["cmd_prefix"] = cmd_prefix
+        expected_cmd = f"{cmd_prefix} --name InitScriptApp -i init.script.path"
 
         result = spark_command_reconstruct(mock_load, conf)
diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py
index 1cd5e26baf1f7b38e56a1b53d93a5a98bd687bb5..c1513bc3331923178d7bdc7c4c2949ae10384e55 100644
--- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py
+++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py
@@ -18,10 +18,10 @@ class TestSparkRun:
     @patch('omniadvisor.service.spark_service.spark_run.spark_command_reconstruct')
     @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.create')
     @patch('omniadvisor.service.spark_service.spark_run.SparkExecutor.submit_spark_task')
     @patch('omniadvisor.service.spark_service.spark_run.SparkExecutor.parser_spark_output')
-    @patch('omniadvisor.service.spark_service.spark_run.trace_data_saver')
+    @patch('omniadvisor.service.spark_service.spark_run.save_trace_data')
     @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.update_exam_result')
     @patch('requests.get')
-    def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_trace_data_saver, mock_parser_spark_output,
+    def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_save_trace_data, mock_parser_spark_output,
                                mock_submit_spark_task, mock_create_exam_record, mock_spark_command_reconstruct):
 
         # Configure the mock return values
@@ -40,7 +40,7 @@ class TestSparkRun:
         mock_update_exam_result.return_value = task_update_mock
         mock_submit_spark_task.return_value = (0, "success output\napplication_id:app_12345\ntime_taken: 10")
         mock_parser_spark_output.return_value = ["spark-submit --master local[*]", "app_12345", 10]
-        mock_trace_data_saver.return_value = f"{OA_CONF.data_dir}/testfile"
+        mock_save_trace_data.return_value = f"{OA_CONF.data_dir}/testfile"
         load_mock = MagicMock()
         load_mock.name = "test_name"
         load_mock.exec_attr = {
@@ -61,10 +61,10 @@ class TestSparkRun:
     @patch('omniadvisor.service.spark_service.spark_run.spark_command_reconstruct')
     @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.create')
     @patch('omniadvisor.service.spark_service.spark_run.SparkExecutor.submit_spark_task')
-    @patch('omniadvisor.service.spark_service.spark_run.trace_data_saver')
+    @patch('omniadvisor.service.spark_service.spark_run.save_trace_data')
     @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.update_exam_result')
     @patch('requests.get')
-    def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_trace_data_saver,
+    def test_spark_run_failure(self, mock_get, mock_update_exam_result, mock_save_trace_data,
                                mock_submit_spark_task, mock_create_exam_record, mock_spark_command_reconstruct):
         # Configure the mock return values
         mock_response = Mock()
@@ -81,7 +81,7 @@ class TestSparkRun:
         task_update_mock = MagicMock()
         mock_update_exam_result.return_value = task_update_mock
         mock_submit_spark_task.return_value = (1, "failure output")
-        mock_trace_data_saver.return_value = f"{OA_CONF.data_dir}/testfile"
+        mock_save_trace_data.return_value = f"{OA_CONF.data_dir}/testfile"
 
         load_mock = MagicMock()
         load_mock.name = "test_name"
diff --git a/omniadvisor/tests/omniadvisor/utils/test_logger.py b/omniadvisor/tests/omniadvisor/utils/test_logger.py
index dbdd270833c7b873dc649bc29d71d964d5d8b16d..5ed1206a65135edb8be2708516b30a89592f10b6 100755
--- a/omniadvisor/tests/omniadvisor/utils/test_logger.py
+++ b/omniadvisor/tests/omniadvisor/utils/test_logger.py
@@ -4,7 +4,7 @@ import os
 
 import pytest
 from pathlib import Path
 
-from omniadvisor.utils.logger import logger
+from omniadvisor.utils.logger import global_logger
@@ -17,27 +17,27 @@ def class_fixture():
     # Temporarily swap in a FileHandler for the test
     orihandler = ""
     orihandler_index = ""
-    for handler in logger.handlers:
+    for handler in global_logger.handlers:
         if isinstance(handler, logging.FileHandler):
             orihandler = handler
             new_file_handler = logging.FileHandler(tmplog)
             new_file_handler.setLevel(logging.INFO)
-            orihandler_index = logger.handlers.index(handler)
-            logger.handlers[orihandler_index] = new_file_handler
+            orihandler_index = global_logger.handlers.index(handler)
+            global_logger.handlers[orihandler_index] = new_file_handler
 
     yield
 
-    logger.handlers[orihandler_index] = orihandler
+    global_logger.handlers[orihandler_index] = orihandler
     os.remove(tmplog)
 
 
 class TestLogger:
     def test_log(self):
-        logger.debug("This is a debug message")
-        logger.info("This is an info message")
-        logger.warning("This is a warning message")
-        logger.error("This is an error message")
-        logger.critical("This is a critical message")
+        global_logger.debug("This is a debug message")
+        global_logger.info("This is an info message")
+        global_logger.warning("This is a warning message")
+        global_logger.error("This is an error message")
+        global_logger.critical("This is a critical message")
 
         assert Path(tmplog).stat().st_size > 0, f"File {tmplog} is 0 bytes; the write failed"
 
     # Verify the log file content
diff --git a/omniadvisor/tests/omniadvisor/utils/test_utils.py b/omniadvisor/tests/omniadvisor/utils/test_utils.py
index 0f0cf913df7dbe7377423b28a2fd6bbcd4b3a888..9c5bb71b191a6653fb7d384bb92fea2b7f9c77e5 100644
--- a/omniadvisor/tests/omniadvisor/utils/test_utils.py
+++ b/omniadvisor/tests/omniadvisor/utils/test_utils.py
@@ -1,5 +1,5 @@
 from unittest.mock import patch, mock_open
-from omniadvisor.utils.utils import run_cmd, trace_data_saver  # assuming your functions live in 'your_module'
+from omniadvisor.utils.utils import run_cmd, save_trace_data
 
 
 class TestRunCmd:
@@ -30,19 +30,19 @@ class TestRunCmd:
 
 class TestTraceDataSaver:
     @patch('os.makedirs')
     @patch('uuid.uuid4', return_value="test-uuid")
-    def test_trace_data_saver_success(self, mock_uuid, mock_makedirs):
+    def test_save_trace_data_success(self, mock_uuid, mock_makedirs):
         data = [{"key": "value"}]
"value"}] data_dir = "/tmp" m = mock_open() with patch('builtins.open', m, create=True): - file_path = trace_data_saver(data, data_dir) + file_path = save_trace_data(data, data_dir) expected_path = f"{data_dir}/test-uuid" assert file_path == expected_path m.assert_called_once_with(expected_path, 'w', encoding='utf-8') @patch('os.makedirs') - def test_trace_data_saver_ioerror(self, mock_makedirs): + def test_save_trace_data_ioerror(self, mock_makedirs): data = [{"key": "value"}] data_dir = "/tmp" m = mock_open() @@ -52,14 +52,14 @@ class TestTraceDataSaver: patch('uuid.uuid4', return_value="test-uuid"), \ patch('builtins.open', m, create=True): try: - trace_data_saver(data, data_dir) + save_trace_data(data, data_dir) assert False, "Expected an IOError to be raised." except IOError as e: assert str(e) == "出现IO错误: IO Error", f"Unexpected error message: {str(e)}" @patch('os.makedirs') - def test_trace_data_saver_exception(self, mock_makedirs): + def test_save_trace_data_exception(self, mock_makedirs): data = [{"key": "value"}] data_dir = "/tmp" m = mock_open() @@ -69,7 +69,7 @@ class TestTraceDataSaver: patch('uuid.uuid4', return_value="test-uuid"), \ patch('builtins.open', m, create=True): try: - trace_data_saver(data, data_dir) + save_trace_data(data, data_dir) assert False, "Expected an Exception to be raised." except Exception as e: assert str(e) == "保存过程中出现错误: Unexpected error", f"Unexpected error message: {str(e)}" \ No newline at end of file