From 3dc359c3453b41a9b23ae9e45b531933c5f9024b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Thu, 17 Apr 2025 21:04:28 +0800 Subject: [PATCH 1/6] =?UTF-8?q?1.=E5=9C=A8hijack=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=E4=B8=8B=E7=A6=81=E7=94=A8consoleHandler=202.=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E5=91=BD=E4=BB=A4=E7=9A=84=E6=8F=90=E4=BA=A4=E6=96=B9?= =?UTF-8?q?=E5=BC=8F=20=E4=BD=BF=E7=94=A8{OA=5FCONF.spark=5Fhome}/bin/spar?= =?UTF-8?q?k-class=E6=8F=90=E4=BA=A4=203.=E7=9B=B8=E5=85=B3=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E7=B1=BB=E5=90=8C=E6=AD=A5=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/config/common_config.cfg | 1 + omniadvisor/src/common/constant.py | 1 + omniadvisor/src/hijack.py | 7 ++++--- .../omniadvisor/interface/config_tuning.py | 4 ++-- .../omniadvisor/interface/hijack_recommend.py | 2 +- .../src/omniadvisor/service/retest_service.py | 10 +++++----- .../spark_command_reconstruct.py | 7 ++++--- .../spark_service/spark_parameter_parser.py | 6 +++--- .../service/spark_service/spark_run.py | 13 ++++++------ omniadvisor/src/omniadvisor/utils/logger.py | 16 ++++++++++++++- omniadvisor/src/omniadvisor/utils/utils.py | 12 +++++------ omniadvisor/src/tuning.py | 6 +++--- .../interface/test_config_tuning.py | 2 +- .../test_spark_command_reconstruct.py | 9 +++++---- .../tests/omniadvisor/utils/test_logger.py | 20 +++++++++---------- 15 files changed, 67 insertions(+), 49 deletions(-) diff --git a/omniadvisor/config/common_config.cfg b/omniadvisor/config/common_config.cfg index a7208d244..f173aef5a 100755 --- a/omniadvisor/config/common_config.cfg +++ b/omniadvisor/config/common_config.cfg @@ -5,3 +5,4 @@ tuning.retest.times=3 [spark] # Spark History Server的URL 仅用于Rest模式 spark.history.rest.url=http://localhost:18081 +spark.home=/usr/local/spark diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index f545a9eca..b7f83af26 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -73,6 +73,7 @@ class OmniAdvisorConf: # 配置罗列 tuning_retest_times = _common_config.getint('common', 'tuning.retest.times') spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url') + spark_home = _common_config.get('spark', 'spark.home') OA_CONF = OmniAdvisorConf() diff --git a/omniadvisor/src/hijack.py b/omniadvisor/src/hijack.py index ed3553496..6303feaed 100644 --- a/omniadvisor/src/hijack.py +++ b/omniadvisor/src/hijack.py @@ -1,19 +1,20 @@ from omniadvisor.interface import hijack_recommend -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger, disable_console_handler if __name__ == '__main__': + disable_console_handler(global_logger) try: hijack_recommend.main() # 无需进行逻辑处理的异常,直接抛至该层 # 若需进行逻辑处理(如环境清理等),则需在相应位置处理后重新抛至该层 except Exception as e: # 异常信息统一在此处打印堆栈,方便定位,抛出异常的地方无需打印log - logger.exception(e) + global_logger.exception(e) # 异常退出 # 劫持任务异常退出后,需在Spark脚本内执行原Spark命令语句 exit(1) # 正常退出 - logger.info('Hijack mission complete!') + global_logger.info('Hijack mission complete!') exit(0) diff --git a/omniadvisor/src/omniadvisor/interface/config_tuning.py b/omniadvisor/src/omniadvisor/interface/config_tuning.py index 6c89b132f..55b431a63 100644 --- a/omniadvisor/src/omniadvisor/interface/config_tuning.py +++ b/omniadvisor/src/omniadvisor/interface/config_tuning.py @@ -1,7 +1,7 @@ import argparse from common.constant import OA_CONF -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger from omniadvisor.repository.load_repository import LoadRepository from omniadvisor.service.tuning_result.tuning_result_history import get_tuning_result_history from omniadvisor.service.retest_service import retest @@ -42,7 +42,7 @@ def unified_tuning(load_id: str, retest_way: str, tuning_method: str): loads = LoadRepository.query_by_id(load_id) if not loads: # 若查询结果为空,直接返回即可 - logger.info('Cannot find load id: %s in database.', load_id) + global_logger.info('Cannot find load id: %s in database.', load_id) return load = loads[0] diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index 3a8290d64..53417dd6d 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -3,7 +3,7 @@ import sys from omniadvisor.service.spark_service.spark_parameter_parser import SparkParameterParser from omniadvisor.repository.model.load import Load from omniadvisor.repository.load_repository import LoadRepository -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger from omniadvisor.service.spark_service.spark_run import spark_run from common.constant import OA_CONF diff --git a/omniadvisor/src/omniadvisor/service/retest_service.py b/omniadvisor/src/omniadvisor/service/retest_service.py index face40306..f982b1649 100644 --- a/omniadvisor/src/omniadvisor/service/retest_service.py +++ b/omniadvisor/src/omniadvisor/service/retest_service.py @@ -1,7 +1,7 @@ from common.constant import OA_CONF from omniadvisor.repository.model.load import Load from omniadvisor.service.spark_service.spark_run import spark_run -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger def retest(load: Load, config: dict): @@ -12,15 +12,15 @@ def retest(load: Load, config: dict): :param config: 配置 :return: """ - logger.debug('开始复测配置...') + global_logger.debug('开始复测配置...') for i in range(1, OA_CONF.tuning_retest_times + 1): try: exam_record, spark_output = spark_run(load, config) except Exception as e: - logger.error('复测第 %d 轮失败,原因:%s', i, e) + global_logger.error('复测第 %d 轮失败,原因:%s', i, e) return if exam_record.status == OA_CONF.ExecStatus.success: - logger.info('复测第 %d 轮成功,性能结果:%.3f', i, exam_record.runtime) + global_logger.info('复测第 %d 轮成功,性能结果:%.3f', i, exam_record.runtime) else: - logger.error('复测第 %d 轮失败,spark执行异常:%s', i, spark_output) + global_logger.error('复测第 %d 轮失败,spark执行异常:%s', i, spark_output) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py index e42ef37d1..541b3c13f 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py @@ -1,4 +1,5 @@ -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger +from common.constant import OA_CONF def spark_command_reconstruct(load, conf): @@ -7,7 +8,7 @@ def spark_command_reconstruct(load, conf): :param conf: {key:value} 保存spark-sql命令中形如 --conf spark.sql.shuffle.partitions(key)=3200(value)的部分 :return: """ - cmd_prefix = "spark-sql" + cmd_prefix = f"{OA_CONF.spark_home}/bin/spark-class org.apache.spark.deploy.SparkSubmit" name = load.name submit_cmd_list = [cmd_prefix, "--name", name] @@ -38,7 +39,7 @@ def spark_command_reconstruct(load, conf): submit_cmd_list.append(f"{_normalize_key(key)}={_normalize_value(value)}") submit_cmd = " ".join(submit_cmd_list) - logger.info(f"拼接后的spark-sql命令如下{submit_cmd}") + global_logger.info(f"拼接后的spark-sql命令如下{submit_cmd}") return submit_cmd diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py index 028fd18e4..28138ce23 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py @@ -1,6 +1,6 @@ import shlex from dataclasses import dataclass -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger from common.constant import OA_CONF from omniadvisor.service.parameter_parser import ParserInterface @@ -111,7 +111,7 @@ class SparkParameterParser(ParserInterface): conf_params[confkey].append(confvalue) else: conf_params[confkey] = [conf_params[key], value] - logger.warn(f"{confkey}被重复配置,多次配置值如下{conf_params[confvalue]}") + global_logger.warn(f"{confkey}被重复配置,多次配置值如下{conf_params[confvalue]}") else: conf_params[confkey] = confvalue # 将 --num-executors --executors-cores --driver-cores这类的配置归类到conf_params中便于后续重建命令 @@ -119,6 +119,6 @@ class SparkParameterParser(ParserInterface): conf_params[SPARK_CONF_SUPPLEMENT_MAP[key]] = value else: name = value - logger.warn(f"The remainder unknown params {unknown} will be add to exec_attr[remainder]") + global_logger.warn(f"The remainder unknown params {unknown} will be add to exec_attr[remainder]") exec_attr["remainder"] = unknown return BaseParams(name=name, conf_params=conf_params, exec_attr=exec_attr) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 3a65fc757..b16603449 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -2,7 +2,7 @@ import time import multiprocessing from requests.exceptions import HTTPError -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger from omniadvisor.service.spark_service.spark_fetcher import SparkFetcher from omniadvisor.service.spark_service.spark_executor import SparkExecutor from omniadvisor.repository.model.exam_record import ExamRecord @@ -51,7 +51,7 @@ def spark_run(load, conf): return exam_record, spark_output -# TODO 这个也要端到端验证一下 用了子进程会不会对ExamRecordRepository有同时读写的问题? + def _update_trace_from_history_server(exam_record: ExamRecord, application_id: str): """ 创建一个子进程对history_server进行轮询 超时时间为10s @@ -67,14 +67,13 @@ def _update_trace_from_history_server(exam_record: ExamRecord, application_id: s trace_sql = spark_fetcher.get_spark_sql_by_app(application_id) trace_stages = spark_fetcher.get_spark_stages_by_app(application_id) trace_executor = spark_fetcher.get_spark_executor_by_app(application_id) + trace_dict['sql'] = trace_data_saver(data=trace_sql, data_dir=OA_CONF.data_dir) + trace_dict['stages'] = trace_data_saver(data=trace_stages, data_dir=OA_CONF.data_dir) + trace_dict['executor'] = trace_data_saver(data=trace_executor, data_dir=OA_CONF.data_dir) break except HTTPError as httpe: time.sleep(1) - logger.warning(f"HistoryServer访问错误:{httpe}") + global_logger.warning(f"HistoryServer访问错误:{httpe}") continue - trace_dict['sql'] = trace_data_saver(data=trace_sql, data_dir=OA_CONF.data_dir) - trace_dict['stages'] = trace_data_saver(data=trace_stages, data_dir=OA_CONF.data_dir) - trace_dict['executor'] = trace_data_saver(data=trace_executor, data_dir=OA_CONF.data_dir) - ExamRecordRepository.update_exam_result(exam_record, trace=trace_dict) diff --git a/omniadvisor/src/omniadvisor/utils/logger.py b/omniadvisor/src/omniadvisor/utils/logger.py index 9eddd2073..8febfccdd 100755 --- a/omniadvisor/src/omniadvisor/utils/logger.py +++ b/omniadvisor/src/omniadvisor/utils/logger.py @@ -56,5 +56,19 @@ if not os.path.exists(OA_CONF.log_dir): # 使用dictConfig加载配置 dictConfig(LOGGING_CONFIG) + # 获取logger并使用 -logger = logging.getLogger('project_logger') +global_logger = logging.getLogger('project_logger') + + +def disable_console_handler(logger): + """ + 如果logger中存在consoleHandler 将consoleHandler移除 + + :param logger: 目标 logger 实例 + """ + for handler in logger.handler[:]: + if getattr(handler, "name", None) == 'consoleHandler': + logger.removeHandler(handler) # 禁用 consoleHandler + return + logger.info("consoleHandler not exist, no need to disable") \ No newline at end of file diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py index f4a65957c..add4ff219 100644 --- a/omniadvisor/src/omniadvisor/utils/utils.py +++ b/omniadvisor/src/omniadvisor/utils/utils.py @@ -8,7 +8,7 @@ from ConfigSpace import ConfigurationSpace, UniformIntegerHyperparameter, Consta UniformFloatHyperparameter from common.constant import OA_CONF -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger def run_cmd(submit_cmd) -> Tuple[int, str]: @@ -22,9 +22,9 @@ def run_cmd(submit_cmd) -> Tuple[int, str]: 第二个元素是一个保存在列表中的字符串 ["xxxx"] :rtype: Tuple[int, str] """ - logger.debug(f"Executor system command: {submit_cmd}") + global_logger.debug(f"Executor system command: {submit_cmd}") exitcode, data = subprocess.getstatusoutput(submit_cmd) - logger.debug(data) + global_logger.debug(data) return exitcode, data @@ -60,11 +60,11 @@ def read_json_file(file_path: str): # 将 JSON 文件内容解析为 Python 对象 data = json.load(file) except FileNotFoundError: - logger.error(f"Error: The file '{file_path}' does not exist.") + global_logger.error(f"Error: The file '{file_path}' does not exist.") except json.JSONDecodeError: - logger.error(f"Error: The file '{file_path}' is not a valid JSON file.") + global_logger.error(f"Error: The file '{file_path}' is not a valid JSON file.") except Exception as e: - logger.error(f"An unexpected error occurred: {e}") + global_logger.error(f"An unexpected error occurred: {e}") return data diff --git a/omniadvisor/src/tuning.py b/omniadvisor/src/tuning.py index 674e8e565..8f75b8e07 100644 --- a/omniadvisor/src/tuning.py +++ b/omniadvisor/src/tuning.py @@ -1,5 +1,5 @@ from omniadvisor.interface import config_tuning -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger if __name__ == '__main__': @@ -9,10 +9,10 @@ if __name__ == '__main__': # 若需进行逻辑处理(如环境清理等),则需在相应位置处理后重新抛至该层 except Exception as e: # 异常信息统一在此处打印堆栈,方便定位,抛出异常的地方无需打印log - logger.exception(e) + global_logger.exception(e) # 异常退出 exit(1) # 正常退出 - logger.info('Tuning mission complete!') + global_logger.info('Tuning mission complete!') exit(0) diff --git a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py index aeebb7d90..7439930ec 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py +++ b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py @@ -32,7 +32,7 @@ class TestTuning: :return: """ with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') as mock_query_by_id, \ - patch('omniadvisor.utils.logger.logger.info') as mock_info: + patch('omniadvisor.utils.logger.global_logger.info') as mock_info: mock_query_by_id.return_value = None unified_tuning(load_id=self.load_id, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) mock_info.assert_called_with('Cannot find load id: %s in database.', '2') diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_command_reconstruct.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_command_reconstruct.py index 0bf413bc0..22abe880b 100644 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_command_reconstruct.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_command_reconstruct.py @@ -1,4 +1,5 @@ from unittest.mock import MagicMock +from common.constant import OA_CONF from omniadvisor.service.spark_service.spark_command_reconstruct import spark_command_reconstruct @@ -11,8 +12,8 @@ class TestSparkCommandReconstruct: # 定义 conf 参数 conf = {"spark.sql.shuffle.partitions": "3200", "spark.executor.cores": "4"} - - expected_cmd = "spark-sql --name TestApp" \ + cmd_prefix = f"{OA_CONF.spark_home}/bin/spark-class org.apache.spark.deploy.SparkSubmit" + expected_cmd = f"{cmd_prefix} --name TestApp" \ " --driver-memory 8G" \ " --executor-memory 16G" \ " --conf spark.sql.shuffle.partitions=3200" \ @@ -28,8 +29,8 @@ class TestSparkCommandReconstruct: mock_load.exec_attr = {"i": "init.script.path"} conf = {} - - expected_cmd = "spark-sql --name InitScriptApp -i init.script.path" + cmd_prefix = f"{OA_CONF.spark_home}/bin/spark-class org.apache.spark.deploy.SparkSubmit" + expected_cmd = f"{cmd_prefix} --name InitScriptApp -i init.script.path" result = spark_command_reconstruct(mock_load, conf) diff --git a/omniadvisor/tests/omniadvisor/utils/test_logger.py b/omniadvisor/tests/omniadvisor/utils/test_logger.py index dbdd27083..5ed1206a6 100755 --- a/omniadvisor/tests/omniadvisor/utils/test_logger.py +++ b/omniadvisor/tests/omniadvisor/utils/test_logger.py @@ -4,7 +4,7 @@ import os import pytest from pathlib import Path -from omniadvisor.utils.logger import logger +from omniadvisor.utils.logger import global_logger @pytest.fixture(scope="class", autouse=True) @@ -17,27 +17,27 @@ def class_fixture(): # 临时替换FileHandler 用于测试 orihandler = "" orihandler_index = "" - for handler in logger.handlers: + for handler in global_logger.handlers: if isinstance(handler, logging.FileHandler): orihandler = handler new_file_handler = logging.FileHandler(tmplog) new_file_handler.setLevel(logging.INFO) - orihandler_index = logger.handlers.index(handler) - logger.handlers[orihandler_index] = new_file_handler + orihandler_index = global_logger.handlers.index(handler) + global_logger.handlers[orihandler_index] = new_file_handler yield - logger.handlers[orihandler_index] = orihandler + global_logger.handlers[orihandler_index] = orihandler os.remove(tmplog) class TestLogger: def test_log(self): - logger.debug("This is a debug message") - logger.info("This is an info message") - logger.warning("This is a warning message") - logger.error("This is an error message") - logger.critical("This is a critical message") + global_logger.debug("This is a debug message") + global_logger.info("This is an info message") + global_logger.warning("This is a warning message") + global_logger.error("This is an error message") + global_logger.critical("This is a critical message") assert Path(tmplog).stat().st_size > 0, f"文件 {tmplog} 大小为 0 字节,写入失败" # 验证日志文件内容 -- Gitee From ffd1f76949f0693db4e31aa41fd8d857956206ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Fri, 18 Apr 2025 10:25:14 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E6=B3=A8=E9=87=8A=E8=A1=A5=E5=85=85=20?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=90=8D=E7=A7=B0=E4=BF=AE=E6=94=B9trace=5Fd?= =?UTF-8?q?ata=5Fsaver=20--->=20save=5Ftrace=5Fdata?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/hijack.py | 1 + .../omniadvisor/service/spark_service/spark_run.py | 8 ++++---- omniadvisor/src/omniadvisor/utils/utils.py | 2 +- .../service/spark_service/test_spark_run.py | 12 ++++++------ omniadvisor/tests/omniadvisor/utils/test_utils.py | 14 +++++++------- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/omniadvisor/src/hijack.py b/omniadvisor/src/hijack.py index 6303feaed..cc979e0cd 100644 --- a/omniadvisor/src/hijack.py +++ b/omniadvisor/src/hijack.py @@ -3,6 +3,7 @@ from omniadvisor.utils.logger import global_logger, disable_console_handler if __name__ == '__main__': + # 前台劫持功能为了保证用户无感 不应在控制台有日志输出 因此禁用控制台输出 disable_console_handler(global_logger) try: hijack_recommend.main() diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index b16603449..fc71f93c8 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -9,7 +9,7 @@ from omniadvisor.repository.model.exam_record import ExamRecord from omniadvisor.repository.exam_record_repository import ExamRecordRepository from omniadvisor.service.spark_service.spark_command_reconstruct import spark_command_reconstruct from common.constant import OA_CONF -from omniadvisor.utils.utils import trace_data_saver +from omniadvisor.utils.utils import save_trace_data def spark_run(load, conf): @@ -67,9 +67,9 @@ def _update_trace_from_history_server(exam_record: ExamRecord, application_id: s trace_sql = spark_fetcher.get_spark_sql_by_app(application_id) trace_stages = spark_fetcher.get_spark_stages_by_app(application_id) trace_executor = spark_fetcher.get_spark_executor_by_app(application_id) - trace_dict['sql'] = trace_data_saver(data=trace_sql, data_dir=OA_CONF.data_dir) - trace_dict['stages'] = trace_data_saver(data=trace_stages, data_dir=OA_CONF.data_dir) - trace_dict['executor'] = trace_data_saver(data=trace_executor, data_dir=OA_CONF.data_dir) + trace_dict['sql'] = save_trace_data(data=trace_sql, data_dir=OA_CONF.data_dir) + trace_dict['stages'] = save_trace_data(data=trace_stages, data_dir=OA_CONF.data_dir) + trace_dict['executor'] = save_trace_data(data=trace_executor, data_dir=OA_CONF.data_dir) break except HTTPError as httpe: time.sleep(1) diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py index add4ff219..9ea1230d8 100644 --- a/omniadvisor/src/omniadvisor/utils/utils.py +++ b/omniadvisor/src/omniadvisor/utils/utils.py @@ -28,7 +28,7 @@ def run_cmd(submit_cmd) -> Tuple[int, str]: return exitcode, data -def trace_data_saver(data: List[Dict[str, str]], data_dir): +def save_trace_data(data: List[Dict[str, str]], data_dir): """ 用于把各类trace信息以文件的形式保存,并返回文件的绝对路径 :param data:通过RESTAPI 获取的list[dict]形式文件 diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py index 1cd5e26ba..c1513bc33 100644 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py @@ -18,10 +18,10 @@ class TestSparkRun: @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.create') @patch('omniadvisor.service.spark_service.spark_run.SparkExecutor.submit_spark_task') @patch('omniadvisor.service.spark_service.spark_run.SparkExecutor.parser_spark_output') - @patch('omniadvisor.service.spark_service.spark_run.trace_data_saver') + @patch('omniadvisor.service.spark_service.spark_run.save_trace_data') @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.update_exam_result') @patch('requests.get') - def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_trace_data_saver, mock_parser_spark_output, + def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_save_trace_data, mock_parser_spark_output, mock_submit_spark_task, mock_create_exam_record, mock_spark_command_reconstruct): # 配置mock对象返回值 @@ -40,7 +40,7 @@ class TestSparkRun: mock_update_exam_result.return_value = task_update_mock mock_submit_spark_task.return_value = (0, "success output\napplication_id:app_12345\ntime_taken: 10") mock_parser_spark_output.return_value = ["spark-submit --master local[*]", "app_12345", 10] - mock_trace_data_saver.return_value = f"{OA_CONF.data_dir}/testfile" + mock_save_trace_data.return_value = f"{OA_CONF.data_dir}/testfile" load_mock = MagicMock() load_mock.name = "test_name" load_mock.exec_attr = { @@ -61,10 +61,10 @@ class TestSparkRun: @patch('omniadvisor.service.spark_service.spark_run.spark_command_reconstruct') @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.create') @patch('omniadvisor.service.spark_service.spark_run.SparkExecutor.submit_spark_task') - @patch('omniadvisor.service.spark_service.spark_run.trace_data_saver') + @patch('omniadvisor.service.spark_service.spark_run.save_trace_data') @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.update_exam_result') @patch('requests.get') - def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_trace_data_saver, + def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_save_trace_data, mock_submit_spark_task, mock_create_exam_record, mock_spark_command_reconstruct): # 配置mock对象返回值 mock_response = Mock() @@ -81,7 +81,7 @@ class TestSparkRun: task_update_mock = MagicMock() mock_update_exam_result.return_value = task_update_mock mock_submit_spark_task.return_value = (1, "failure output") - mock_trace_data_saver.return_value = f"{OA_CONF.data_dir}/testfile" + mock_save_trace_data.return_value = f"{OA_CONF.data_dir}/testfile" load_mock = MagicMock() load_mock.name = "test_name" diff --git a/omniadvisor/tests/omniadvisor/utils/test_utils.py b/omniadvisor/tests/omniadvisor/utils/test_utils.py index 0f0cf913d..9c5bb71b1 100644 --- a/omniadvisor/tests/omniadvisor/utils/test_utils.py +++ b/omniadvisor/tests/omniadvisor/utils/test_utils.py @@ -1,5 +1,5 @@ from unittest.mock import patch,mock_open -from omniadvisor.utils.utils import run_cmd, trace_data_saver # 假设你的函数定义在 'your_module' 中 +from omniadvisor.utils.utils import run_cmd, save_trace_data # 假设你的函数定义在 'your_module' 中 class TestRunCmd: @@ -30,19 +30,19 @@ class TestRunCmd: class TestTraceDataSaver: @patch('os.makedirs') @patch('uuid.uuid4', return_value="test-uuid") - def test_trace_data_saver_success(self, mock_uuid, mock_makedirs): + def test_save_trace_data_success(self, mock_uuid, mock_makedirs): data = [{"key": "value"}] data_dir = "/tmp" m = mock_open() with patch('builtins.open', m, create=True): - file_path = trace_data_saver(data, data_dir) + file_path = save_trace_data(data, data_dir) expected_path = f"{data_dir}/test-uuid" assert file_path == expected_path m.assert_called_once_with(expected_path, 'w', encoding='utf-8') @patch('os.makedirs') - def test_trace_data_saver_ioerror(self, mock_makedirs): + def test_save_trace_data_ioerror(self, mock_makedirs): data = [{"key": "value"}] data_dir = "/tmp" m = mock_open() @@ -52,14 +52,14 @@ class TestTraceDataSaver: patch('uuid.uuid4', return_value="test-uuid"), \ patch('builtins.open', m, create=True): try: - trace_data_saver(data, data_dir) + save_trace_data(data, data_dir) assert False, "Expected an IOError to be raised." except IOError as e: assert str(e) == "出现IO错误: IO Error", f"Unexpected error message: {str(e)}" @patch('os.makedirs') - def test_trace_data_saver_exception(self, mock_makedirs): + def test_save_trace_data_exception(self, mock_makedirs): data = [{"key": "value"}] data_dir = "/tmp" m = mock_open() @@ -69,7 +69,7 @@ class TestTraceDataSaver: patch('uuid.uuid4', return_value="test-uuid"), \ patch('builtins.open', m, create=True): try: - trace_data_saver(data, data_dir) + save_trace_data(data, data_dir) assert False, "Expected an Exception to be raised." except Exception as e: assert str(e) == "保存过程中出现错误: Unexpected error", f"Unexpected error message: {str(e)}" \ No newline at end of file -- Gitee From 5dfc002d95656688b61567d625c0a4b0c95de418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Fri, 18 Apr 2025 11:41:49 +0800 Subject: [PATCH 3/6] =?UTF-8?q?save=5Fdata=5Ftrace=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../omniadvisor/service/spark_service/spark_run.py | 11 ++++++----- omniadvisor/src/omniadvisor/utils/logger.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index fc71f93c8..719916aa8 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -67,13 +67,14 @@ def _update_trace_from_history_server(exam_record: ExamRecord, application_id: s trace_sql = spark_fetcher.get_spark_sql_by_app(application_id) trace_stages = spark_fetcher.get_spark_stages_by_app(application_id) trace_executor = spark_fetcher.get_spark_executor_by_app(application_id) - trace_dict['sql'] = save_trace_data(data=trace_sql, data_dir=OA_CONF.data_dir) - trace_dict['stages'] = save_trace_data(data=trace_stages, data_dir=OA_CONF.data_dir) - trace_dict['executor'] = save_trace_data(data=trace_executor, data_dir=OA_CONF.data_dir) - break except HTTPError as httpe: time.sleep(1) global_logger.warning(f"HistoryServer访问错误:{httpe}") continue - + trace_dict['sql'] = save_trace_data(data=trace_sql, data_dir=OA_CONF.data_dir) + trace_dict['stages'] = save_trace_data(data=trace_stages, data_dir=OA_CONF.data_dir) + trace_dict['executor'] = save_trace_data(data=trace_executor, data_dir=OA_CONF.data_dir) + break ExamRecordRepository.update_exam_result(exam_record, trace=trace_dict) + + diff --git a/omniadvisor/src/omniadvisor/utils/logger.py b/omniadvisor/src/omniadvisor/utils/logger.py index 8febfccdd..66cb21c6f 100755 --- a/omniadvisor/src/omniadvisor/utils/logger.py +++ b/omniadvisor/src/omniadvisor/utils/logger.py @@ -71,4 +71,4 @@ def disable_console_handler(logger): if getattr(handler, "name", None) == 'consoleHandler': logger.removeHandler(handler) # 禁用 consoleHandler return - logger.info("consoleHandler not exist, no need to disable") \ No newline at end of file + logger.debug("consoleHandler not exist, no need to disable") \ No newline at end of file -- Gitee From 4a4cc004de47b1afae17adb45ab2906eb00c9ceb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:18:43 +0800 Subject: [PATCH 4/6] testing --- omniadvisor/config/common_config.cfg | 2 +- omniadvisor/src/omniadvisor/interface/hijack_recommend.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/omniadvisor/config/common_config.cfg b/omniadvisor/config/common_config.cfg index f173aef5a..fa48e559e 100755 --- a/omniadvisor/config/common_config.cfg +++ b/omniadvisor/config/common_config.cfg @@ -4,5 +4,5 @@ tuning.retest.times=3 [spark] # Spark History Server的URL 仅用于Rest模式 -spark.history.rest.url=http://localhost:18081 +spark.history.rest.url=http://90.91.16.183:18081 spark.home=/usr/local/spark diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index 53417dd6d..12de329f5 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -21,7 +21,7 @@ def _query_or_create_load(name: str, exec_attr: dict, default_config: dict): loads = LoadRepository.query_by_name_and_default_config(name=name, default_config=default_config) # 如果查询不到负载信息,则创建新的负载 if not loads: - logger.info("Not exit available load, create new one.") + global_logger.info("Not exit available load, create new one.") # TODO 如果用户在这里创建了一个格式异常的config, 会引起spark执行失败 这个应该如何处理 load = LoadRepository.create(name=name, exec_attr=exec_attr, default_config=default_config) else: -- Gitee From 61b0a18457bd5d9c62fe9e474067ad0c30549cc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Fri, 18 Apr 2025 15:21:21 +0800 Subject: [PATCH 5/6] testing --- omniadvisor/src/omniadvisor/utils/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omniadvisor/src/omniadvisor/utils/logger.py b/omniadvisor/src/omniadvisor/utils/logger.py index 66cb21c6f..f1e070a31 100755 --- a/omniadvisor/src/omniadvisor/utils/logger.py +++ b/omniadvisor/src/omniadvisor/utils/logger.py @@ -67,7 +67,7 @@ def disable_console_handler(logger): :param logger: 目标 logger 实例 """ - for handler in logger.handler[:]: + for handler in logger.handlers[:]: if getattr(handler, "name", None) == 'consoleHandler': logger.removeHandler(handler) # 禁用 consoleHandler return -- Gitee From b3bbbcbb64568faef0e5017477c1f7c2e99735f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Thu, 24 Apr 2025 19:57:12 +0800 Subject: [PATCH 6/6] testing --- omniadvisor/config/common_config.cfg | 3 +-- omniadvisor/src/common/constant.py | 2 -- .../spark_service/spark_command_reconstruct.py | 6 +++++- .../service/spark_service/spark_executor.py | 11 ++++++----- .../service/spark_service/spark_parameter_parser.py | 9 +++++++-- 5 files changed, 19 insertions(+), 12 deletions(-) diff --git a/omniadvisor/config/common_config.cfg b/omniadvisor/config/common_config.cfg index fa48e559e..73ccf53ef 100755 --- a/omniadvisor/config/common_config.cfg +++ b/omniadvisor/config/common_config.cfg @@ -4,5 +4,4 @@ tuning.retest.times=3 [spark] # Spark History Server的URL 仅用于Rest模式 -spark.history.rest.url=http://90.91.16.183:18081 -spark.home=/usr/local/spark +spark.history.rest.url=http://localhost:18080 \ No newline at end of file diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index b7f83af26..450f60e24 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -73,7 +73,5 @@ class OmniAdvisorConf: # 配置罗列 tuning_retest_times = _common_config.getint('common', 'tuning.retest.times') spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url') - spark_home = _common_config.get('spark', 'spark.home') - OA_CONF = OmniAdvisorConf() diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py index 541b3c13f..31aa2dfa3 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_command_reconstruct.py @@ -8,11 +8,15 @@ def spark_command_reconstruct(load, conf): :param conf: {key:value} 保存spark-sql命令中形如 --conf spark.sql.shuffle.partitions(key)=3200(value)的部分 :return: """ - cmd_prefix = f"{OA_CONF.spark_home}/bin/spark-class org.apache.spark.deploy.SparkSubmit" + cmd_prefix = load.exec_attr.get('cmd_prefix') + if not cmd_prefix: + raise AttributeError("the spark-sql command prefix is deprecated") name = load.name submit_cmd_list = [cmd_prefix, "--name", name] for key, value in load.exec_attr.items(): + if key == 'cmd_prefix': + continue # -i参数在处理上不支持--的形式 单独处理 if key == "i": submit_cmd_list.append("-i") diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py index 125807e65..1c3426177 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py @@ -20,7 +20,8 @@ class SparkExecutor: spark_output = spark_output.split("\n") spark_submit_cmd = "" application_id = "" - time_taken = -1.0 + total_time_taken = -1.0 + time_taken_list = [] for item in spark_output: if "deploy.SparkSubmit" in item: spark_submit_cmd = item.strip() @@ -35,7 +36,7 @@ class SparkExecutor: pattern = r"Time taken:\s*(.*)seconds" match = re.search(pattern, item) if match: - time_taken = match.group(1).strip() - if application_id and float(time_taken) > 0: - break - return spark_submit_cmd, application_id, float(time_taken) + time_taken_list.append(float(match.group(1).strip())) + if time_taken_list: + total_time_taken = sum(time_taken_list) + return spark_submit_cmd, application_id, float(total_time_taken) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py index 28138ce23..07ba72cfa 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py @@ -84,11 +84,16 @@ class SparkParameterParser(ParserInterface): """ if not self._submit_cmd: raise ValueError("Submit command cannot be null") - param_start_index = self._submit_cmd.strip().index(" ") name = "" conf_params = {} exec_attr = {} - args = shlex.split(self._submit_cmd[param_start_index:]) + # 解析spark-class的绝对路径 + first_space_index = self._submit_cmd.strip().find(" ") + # 从第一个空格之后的子字符串中找到第二个空格 + second_space_index = self._submit_cmd.strip().find(" ", first_space_index + 1) + exec_attr['cmd_prefix'] = self._submit_cmd[0:second_space_index] + # 解析参数 + args = shlex.split(self._submit_cmd[second_space_index:]) params, unknown = self._parser.parse_known_args(args=args) options_dict = vars(params) for key, value in options_dict.items(): -- Gitee