From 57c64850318a6f33876456bfce642e0fe0b9b23a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Sat, 10 May 2025 17:34:58 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=90=8E=E7=9A=84spark?= =?UTF-8?q?-submit=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/config/spark-submit | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 omniadvisor/config/spark-submit diff --git a/omniadvisor/config/spark-submit b/omniadvisor/config/spark-submit new file mode 100644 index 000000000..f33e0c3c3 --- /dev/null +++ b/omniadvisor/config/spark-submit @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +if [ -z "${SPARK_HOME}" ]; then + source "$(dirname "$0")"/find-spark-home +fi + +# disable randomized hash for string in Python 3.3+ +export PYTHONHASHSEED=0 +if [[ -v enable_omniadvisor && "$enable_omniadvisor" = "true" ]]; then + echo "enable omniadvisor" + if [ -z "${HIJACK_PATH}" ]; then + # 根据实际配置修改 + hijack_path="project_dir/src/hijack.py" + else + hijack_path=$HIJACK_PATH + fi + echo "hijack_path="$hijack_path + spark_cmd=""${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@"" + echo $spark_cmd + python3 $hijack_path "$spark_cmd" +else + echo "${SPARK_HOME}/bin/spark-class org.apache.spark.deploy.SparkSubmit $@" + exec "${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" +fi \ No newline at end of file -- Gitee From 62354a6d3cb37895981568fb13de0d4ae4a23d7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Sat, 10 May 2025 18:02:20 +0800 Subject: [PATCH 02/10] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=90=8E=E7=9A=84spark?= =?UTF-8?q?-submit=E8=84=9A=E6=9C=AC=20v2.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/{config => script}/spark-submit | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) rename omniadvisor/{config => script}/spark-submit (85%) diff --git a/omniadvisor/config/spark-submit b/omniadvisor/script/spark-submit similarity index 85% rename from omniadvisor/config/spark-submit rename to omniadvisor/script/spark-submit index f33e0c3c3..8e458a092 100644 --- a/omniadvisor/config/spark-submit +++ b/omniadvisor/script/spark-submit @@ -24,16 +24,9 @@ fi # disable randomized hash for string in Python 3.3+ export PYTHONHASHSEED=0 if [[ -v enable_omniadvisor && "$enable_omniadvisor" = "true" ]]; then - echo "enable omniadvisor" - if [ -z "${HIJACK_PATH}" ]; then - # 根据实际配置修改 - hijack_path="project_dir/src/hijack.py" - else - hijack_path=$HIJACK_PATH - fi - echo "hijack_path="$hijack_path + # 根据特性实际的部署路径进行修改 + hijack_path="project_dir/src/hijack.py" spark_cmd=""${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@"" - echo $spark_cmd python3 $hijack_path "$spark_cmd" else echo "${SPARK_HOME}/bin/spark-class org.apache.spark.deploy.SparkSubmit $@" -- Gitee From 8266916a8ff9f5586e65fd7a4e46b0e81e7e1ac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Mon, 12 May 2025 20:39:28 +0800 Subject: [PATCH 03/10] =?UTF-8?q?1.=E4=BF=AE=E6=94=B9spark=5Fparameter=5Fp?= =?UTF-8?q?arser=E9=80=82=E9=85=8D=20--conf=20spark.executor.extraJavaOpti?= =?UTF-8?q?ons=3D"-XX:+UseG1GC=20-XX:ActiveProcessCount=3D8"=E8=BF=99?= =?UTF-8?q?=E7=B1=BBvalue=E4=B8=AD=E5=90=AB=E6=9C=89"=3D"=E7=9A=84?= =?UTF-8?q?=E5=8F=82=E6=95=B0=202.spark-submit=E8=84=9A=E6=9C=AC=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=203.tuning=E7=9B=B8=E5=85=B3=E9=83=A8=E5=88=86?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/script/spark-submit | 26 ++++++++++++++++--- .../omniadvisor/interface/config_tuning.py | 1 + .../spark_service/spark_parameter_parser.py | 4 +-- .../interface/test_config_tuning.py | 5 ++-- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/omniadvisor/script/spark-submit b/omniadvisor/script/spark-submit index 8e458a092..89b711a12 100644 --- a/omniadvisor/script/spark-submit +++ b/omniadvisor/script/spark-submit @@ -25,10 +25,30 @@ fi export PYTHONHASHSEED=0 if [[ -v enable_omniadvisor && "$enable_omniadvisor" = "true" ]]; then # 根据特性实际的部署路径进行修改 - hijack_path="project_dir/src/hijack.py" - spark_cmd=""${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@"" + echo "enable_omniadvisor" + #hijack_path="/home/y30056354/union_test/merge/src/hijack.py" + hijack_path="/home/x30041342/omniadvisor/src/hijack.py" + spark_cmd=""${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit" + for arg in "$@" + do + if [[ "$arg" == -* ]]; then + spark_cmd="$spark_cmd $arg" + continue + fi + # 检查是否同时包含 = 和空格 + if [[ "$arg" =~ = && "$arg" =~ [[:space:]] ]]; then + # 提取第一个等号前的内容和之后的内容 + before=$(echo "$arg" | sed -E 's/^([^=]*=).*/\1/') + after=$(echo "$arg" | sed -E 's/^[^=]*=(.*)/\1/') + # 将后面的内容加上双引号并拼接 + new_arg="${before}\"${after}\"" + spark_cmd="$spark_cmd $new_arg" + else + spark_cmd="$spark_cmd $arg" + fi + done python3 $hijack_path "$spark_cmd" else - echo "${SPARK_HOME}/bin/spark-class org.apache.spark.deploy.SparkSubmit $@" + echo "${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" exec "${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" fi \ No newline at end of file diff --git a/omniadvisor/src/omniadvisor/interface/config_tuning.py b/omniadvisor/src/omniadvisor/interface/config_tuning.py index 69ec19d64..c0a5523d4 100644 --- a/omniadvisor/src/omniadvisor/interface/config_tuning.py +++ b/omniadvisor/src/omniadvisor/interface/config_tuning.py @@ -3,6 +3,7 @@ import argparse from algo.expert.tuning import ExpertTuning from algo.iterative.tuning import SmacAppendTuning from omniadvisor.repository.load_repository import LoadRepository +from omniadvisor.repository.tuning_record_repository import TuningRecordRepository from omniadvisor.service.retest_service import retest from omniadvisor.service.tuning_result.tuning_result_history import get_tuning_result_history, \ get_next_tuning_method diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py index 07ba72cfa..e472b296a 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_parameter_parser.py @@ -106,9 +106,7 @@ class SparkParameterParser(ParserInterface): # 提取conf配置 elif key == "conf": for confitem in value: - parts = confitem.split("=") - if len(parts) != 2: - raise ValueError(f"conf{parts}的格式不符合key=value的形式") + parts = confitem.split("=", 1) confkey, confvalue = parts # 处理重复键问题,例如合并值为列表 if confkey in conf_params: diff --git a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py index bee149100..df62bc15b 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py +++ b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py @@ -1,10 +1,11 @@ import logging +import sys from unittest.mock import patch, MagicMock import pytest from common.constant import OA_CONF -from omniadvisor.interface.config_tuning import unified_tuning +from omniadvisor.interface.config_tuning import unified_tuning, main class TestTuning: @@ -30,7 +31,7 @@ class TestTuning: mock_exam_record.status = OA_CONF.ExecStatus.success spark_output = '' mock_spark_run.return_value = mock_exam_record, spark_output - unified_tuning(load_id=self.load_id, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) + unified_tuning(load=self.load_id, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) mock_update_best.assert_called_once() mock_smac_tuning.assert_called_once() assert mock_spark_run.call_count == OA_CONF.tuning_retest_times -- Gitee From bbe8aafea294be92a47c7d5a6013ed757cd6a89f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Tue, 13 May 2025 20:46:27 +0800 Subject: [PATCH 04/10] =?UTF-8?q?1.load.tuning=5Fneeded=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E8=AE=BE=E7=BD=AE=E4=B8=BATrue=202.=5Fupdate=5Ftrace=5Ffrom=5F?= =?UTF-8?q?history=5Fserver=E9=80=BB=E8=BE=91=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/omniadvisor/service/spark_service/spark_run.py | 8 ++++---- omniadvisor/src/server/app/models.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 054036f8c..427eb02eb 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -97,14 +97,14 @@ def _update_trace_from_history_server(exam_record: ExamRecord, application_id: s trace_sql = spark_fetcher.get_spark_sql_by_app(application_id) trace_stages = spark_fetcher.get_spark_stages_by_app(application_id) trace_executor = spark_fetcher.get_spark_executor_by_app(application_id) - break except HTTPError as httpe: time.sleep(1) global_logger.warning(f"HistoryServer访问错误:{httpe}") continue - trace_dict['sql'] = save_trace_data(data=trace_sql, data_dir=OA_CONF.data_dir) - trace_dict['stages'] = save_trace_data(data=trace_stages, data_dir=OA_CONF.data_dir) - trace_dict['executor'] = save_trace_data(data=trace_executor, data_dir=OA_CONF.data_dir) + trace_dict['sql'] = save_trace_data(data=trace_sql, data_dir=OA_CONF.data_dir) + trace_dict['stages'] = save_trace_data(data=trace_stages, data_dir=OA_CONF.data_dir) + trace_dict['executor'] = save_trace_data(data=trace_executor, data_dir=OA_CONF.data_dir) + break ExamRecordRepository.update_exam_result(exam_record, trace=trace_dict) diff --git a/omniadvisor/src/server/app/models.py b/omniadvisor/src/server/app/models.py index c66b33450..17243fce9 100644 --- a/omniadvisor/src/server/app/models.py +++ b/omniadvisor/src/server/app/models.py @@ -15,7 +15,7 @@ class DatabaseLoad(models.Model): best_config = models.JSONField(null=True) test_config = models.JSONField(null=True) create_time = models.DateTimeField(auto_now_add=True) - tuning_needed = models.BooleanField(default=False) + tuning_needed = models.BooleanField(default=True) class Meta: constraints = [ -- Gitee From 10ee3276656fb302b4a0b81b3bc7a4d18f3ba9bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Wed, 14 May 2025 11:02:18 +0800 Subject: [PATCH 05/10] =?UTF-8?q?1.=E4=BF=AE=E6=94=B9save=5Ftrace=5Fdata?= =?UTF-8?q?=E7=9A=84=E4=BF=9D=E5=AD=98=E6=96=87=E6=9C=AC=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/omniadvisor/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py index 9ea1230d8..c23aaccfd 100644 --- a/omniadvisor/src/omniadvisor/utils/utils.py +++ b/omniadvisor/src/omniadvisor/utils/utils.py @@ -41,7 +41,7 @@ def save_trace_data(data: List[Dict[str, str]], data_dir): try: with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=4) - print(f"数据已成功保存到 {file_path}") + global_logger.info(f"数据已成功保存到 {file_path}") except IOError as e: raise IOError(f"出现IO错误: {e}") except Exception as e: -- Gitee From 4c6dfaf6f1d6dde0e991d6118fcdba00900ca574 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Wed, 14 May 2025 11:13:45 +0800 Subject: [PATCH 06/10] =?UTF-8?q?spark-submit=E5=88=A0=E9=99=A4=E5=86=97?= =?UTF-8?q?=E4=BD=99=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/script/spark-submit | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/omniadvisor/script/spark-submit b/omniadvisor/script/spark-submit index 89b711a12..88763a8cf 100644 --- a/omniadvisor/script/spark-submit +++ b/omniadvisor/script/spark-submit @@ -26,8 +26,7 @@ export PYTHONHASHSEED=0 if [[ -v enable_omniadvisor && "$enable_omniadvisor" = "true" ]]; then # 根据特性实际的部署路径进行修改 echo "enable_omniadvisor" - #hijack_path="/home/y30056354/union_test/merge/src/hijack.py" - hijack_path="/home/x30041342/omniadvisor/src/hijack.py" + hijack_path="" spark_cmd=""${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit" for arg in "$@" do @@ -49,6 +48,5 @@ if [[ -v enable_omniadvisor && "$enable_omniadvisor" = "true" ]]; then done python3 $hijack_path "$spark_cmd" else - echo "${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" exec "${SPARK_HOME}"/bin/spark-class org.apache.spark.deploy.SparkSubmit "$@" fi \ No newline at end of file -- Gitee From c1602c0806ee3624442cfc4ca44f2e3c7e32f1d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Wed, 14 May 2025 16:23:22 +0800 Subject: [PATCH 07/10] =?UTF-8?q?test=5Fconfig=5Ftuning=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tests/omniadvisor/interface/test_config_tuning.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py index df62bc15b..b1d76ae70 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py +++ b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py @@ -1,18 +1,17 @@ import logging -import sys from unittest.mock import patch, MagicMock import pytest from common.constant import OA_CONF -from omniadvisor.interface.config_tuning import unified_tuning, main +from omniadvisor.interface.config_tuning import unified_tuning class TestTuning: def setup_class(self): # 创建表 - self.load_id = '2' + self.load = None self.retest_times = 3 self.tuning_method = OA_CONF.TuningMethod.iterative @@ -31,7 +30,7 @@ class TestTuning: mock_exam_record.status = OA_CONF.ExecStatus.success spark_output = '' mock_spark_run.return_value = mock_exam_record, spark_output - unified_tuning(load=self.load_id, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) + unified_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) mock_update_best.assert_called_once() mock_smac_tuning.assert_called_once() assert mock_spark_run.call_count == OA_CONF.tuning_retest_times -- Gitee From f71512699fc6fd97899752d4b1f236a514895678 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Thu, 15 May 2025 15:06:09 +0800 Subject: [PATCH 08/10] =?UTF-8?q?1.spark=5Frun=E6=96=B0=E5=A2=9Etime=5Fout?= =?UTF-8?q?=E5=8F=82=E6=95=B0=202.=E7=9B=B8=E5=85=B3=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/config/common_config.cfg | 3 ++- omniadvisor/src/common/constant.py | 1 + .../service/spark_service/spark_run.py | 11 +++++++-- .../interface/test_config_tuning.py | 24 ++++++++++++------- .../interface/test_hijack_recommend.py | 6 +++-- .../service/spark_service/test_spark_run.py | 24 ++++++++++++------- .../tuning_result/test_tuning_result.py | 2 +- 7 files changed, 48 insertions(+), 23 deletions(-) diff --git a/omniadvisor/config/common_config.cfg b/omniadvisor/config/common_config.cfg index 73ccf53ef..8c8c0aa9e 100755 --- a/omniadvisor/config/common_config.cfg +++ b/omniadvisor/config/common_config.cfg @@ -4,4 +4,5 @@ tuning.retest.times=3 [spark] # Spark History Server的URL 仅用于Rest模式 -spark.history.rest.url=http://localhost:18080 \ No newline at end of file +spark.history.rest.url=http://localhost:18080 +spark.sql.timeout.ratio=1.5 \ No newline at end of file diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index 70d3bbda8..1e0366dce 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -77,5 +77,6 @@ class OmniAdvisorConf: # 配置罗列 tuning_retest_times = _common_config.getint('common', 'tuning.retest.times') spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url') + timeout_ratio = _common_config.getfloat('spark', 'spark.sql.timeout.ratio') OA_CONF = OmniAdvisorConf() diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 427eb02eb..8c9f37a25 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -9,6 +9,7 @@ from omniadvisor.repository.model.exam_record import ExamRecord from omniadvisor.service.spark_service.spark_command_reconstruct import spark_command_reconstruct from omniadvisor.service.spark_service.spark_executor import SparkExecutor from omniadvisor.service.spark_service.spark_fetcher import SparkFetcher +from omniadvisor.service.tuning_result.tuning_result import get_tuning_result from omniadvisor.utils.logger import global_logger from omniadvisor.utils.utils import save_trace_data @@ -17,8 +18,14 @@ def spark_run(load, conf): # 从解析后的参数列表中提取负载与任务的相关信息 submit_cmd = spark_command_reconstruct(load, conf) + # 判断当前的conf是否和load.default_config相同 不相同则在submit_cmd前增加超时时间 + if conf != load.default_config: + # 获取当前default_config的平均测试性能 + baseline_results = get_tuning_result(load, load.default_config) + timeout_sec = OA_CONF.timeout_ratio * baseline_results.runtime + submit_cmd = f"timeout {timeout_sec} " + submit_cmd + # 根据执行命令创建测试记录 - # TODO Task对象的starttime 和 endtime基本上没有用到过? exam_record = ExamRecordRepository.create(load, conf) # 执行当前的submit_cmd @@ -48,7 +55,7 @@ def spark_run(load, conf): status=OA_CONF.ExecStatus.fail, runtime=OA_CONF.exec_fail_return_runtime, trace=OA_CONF.exec_fail_return_trace - ) + ), "spark-sql failed" except RuntimeError: exam_record.delete() raise diff --git a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py index b1d76ae70..16c350e25 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py +++ b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py @@ -1,10 +1,11 @@ import logging +import sys from unittest.mock import patch, MagicMock import pytest from common.constant import OA_CONF -from omniadvisor.interface.config_tuning import unified_tuning +from omniadvisor.interface.config_tuning import unified_tuning, main class TestTuning: @@ -52,7 +53,7 @@ class TestTuning: mock_exam_record.status = OA_CONF.ExecStatus.fail spark_output = '' mock_spark_run.return_value = mock_exam_record, spark_output - unified_tuning(load_id=self.load_id, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) + unified_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) mock_smac_tuning.assert_called_once() mock_spark_run.assert_called_once() @@ -75,22 +76,29 @@ class TestTuning: pytest.raises(RuntimeError): mock_exam_record = MagicMock() mock_exam_record.status = OA_CONF.ExecStatus.success - unified_tuning(load_id=self.load_id, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) + unified_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) mock_update_best.assert_called_once() mock_smac_tuning.assert_called_once() mock_spark_run.assert_called_once() assert '非spark运行异常' in caplog.text - def test_unified_tuning_when_load_id_not_exist(self): + def test_main_when_load_id_not_exist(self): """ 当 load id 不存在时 :return: """ + sys.argv = ['config_tuning.py', + '--load-id', self.load, + '--retest-way', OA_CONF.RetestWay.backend, + '--tuning-method', self.tuning_method] + with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') as mock_query_by_id, \ - patch('omniadvisor.utils.logger.global_logger.info') as mock_info: + patch('omniadvisor.utils.logger.global_logger.info') as mock_info, \ + patch('omniadvisor.interface.config_tuning.get_tuning_result_history'), \ + patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create'): mock_query_by_id.return_value = None - unified_tuning(load_id=self.load_id, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) - mock_info.assert_called_with('Cannot find load id: %s in database.', '2') + main() + mock_info.assert_called_with('Cannot find load id: %s in database, the tuning process exits.', 'None') def test_unified_tuning_when_retest_failed(self): """ @@ -102,7 +110,7 @@ class TestTuning: patch('omniadvisor.interface.config_tuning.get_tuning_result_history') as mocked_query_perf, \ patch('omniadvisor.repository.load_repository.LoadRepository.update_test_config') as mock_update_test, \ patch('algo.iterative.tuning.SmacAppendTuning.tune') as mock_smac_tuning: - unified_tuning(load_id=self.load_id, retest_way=OA_CONF.RetestWay.hijacking, + unified_tuning(load=self.load, retest_way=OA_CONF.RetestWay.hijacking, tuning_method=self.tuning_method) mocked_query_perf.assert_called_once() mock_update_test.assert_called_once() diff --git a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py index be8778a12..7dfe1e947 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py +++ b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py @@ -48,7 +48,8 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend.SparkParameterParser") @patch("omniadvisor.interface.hijack_recommend.LoadRepository") @patch("omniadvisor.interface.hijack_recommend.spark_run") - def test_hijack_recommend_failure_with_fallback(self, mock_spark_run, mock_load_repo, mock_parser): + @patch("omniadvisor.interface.hijack_recommend._refresh_best_config") + def test_hijack_recommend_failure_with_fallback(self, mock_refreshed_load, mock_spark_run, mock_load_repo, mock_parser): """ 测试 hijack_recommend 在任务失败且需要回退到用户默认配置时的行为。 """ @@ -81,7 +82,8 @@ class TestHijackRecommend(unittest.TestCase): hijack_recommend(spark_sql_cmd) # 验证 LoadRepository.query 被正确调用 - mock_load_repo.query_by_name_and_default_config.assert_called_once_with(name="example_name", default_config={"timeout": 30}) + mock_load_repo.query_by_name_and_default_config.assert_called_once_with(name="example_name", + default_config={"timeout": 30}) # 验证 spark_run 被调用了两次(第一次使用待测试配置,第二次回退到默认配置) assert mock_spark_run.call_count == 2 diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py index c1513bc33..b8bf0b49d 100644 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py @@ -20,8 +20,9 @@ class TestSparkRun: @patch('omniadvisor.service.spark_service.spark_run.SparkExecutor.parser_spark_output') @patch('omniadvisor.service.spark_service.spark_run.save_trace_data') @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.update_exam_result') + @patch('omniadvisor.service.spark_service.spark_run.multiprocessing.Process') @patch('requests.get') - def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_save_trace_data, mock_parser_spark_output, + def test_spark_run_success(self, mock_get, mock_process, mock_update_exam_result, mock_save_trace_data, mock_parser_spark_output, mock_submit_spark_task, mock_create_exam_record, mock_spark_command_reconstruct): # 配置mock对象返回值 @@ -31,6 +32,8 @@ class TestSparkRun: # 配置mock对象的返回值 mock_get.return_value = mock_response + mock_instance = MagicMock() + mock_process.return_value = mock_instance # 配置mock对象的返回值 mock_spark_command_reconstruct.return_value = "spark-submit --master local[*]" @@ -47,24 +50,26 @@ class TestSparkRun: "attr1": "value1", "attr2": "value2" } + load_mock.default_config = {"another_key": "another_value"} - result_exam_record = spark_run(load_mock, conf) + result_exam_record, spark_out_put = spark_run(load_mock, conf) assert result_exam_record == task_update_mock - mock_update_exam_result.assert_called_once_with(exam_record_mock, OA_CONF.ExecStatus.success, 10, { - 'sql': f'{OA_CONF.data_dir}/testfile', - 'stages': f'{OA_CONF.data_dir}/testfile', - 'executor': f'{OA_CONF.data_dir}/testfile' - }, "success output\napplication_id:app_12345\ntime_taken: 10") mock_update_exam_result.assert_called() + # 断言Process被调用 + mock_process.assert_called_once() + mock_instance.start.assert_called_once() + + + @patch('omniadvisor.service.spark_service.spark_run.spark_command_reconstruct') @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.create') @patch('omniadvisor.service.spark_service.spark_run.SparkExecutor.submit_spark_task') @patch('omniadvisor.service.spark_service.spark_run.save_trace_data') @patch('omniadvisor.service.spark_service.spark_run.ExamRecordRepository.update_exam_result') @patch('requests.get') - def test_spark_run_success(self, mock_get, mock_update_exam_result, mock_save_trace_data, + def test_spark_run_failed(self, mock_get, mock_update_exam_result, mock_save_trace_data, mock_submit_spark_task, mock_create_exam_record, mock_spark_command_reconstruct): # 配置mock对象返回值 mock_response = Mock() @@ -89,7 +94,8 @@ class TestSparkRun: "attr1": "value1", "attr2": "value2" } - result_exam_record = spark_run(load_mock, conf) + load_mock.default_config = {"another_key": "another_value"} + result_exam_record, spark_out_put = spark_run(load_mock, conf) assert result_exam_record == task_update_mock mock_update_exam_result.assert_called_once_with( diff --git a/omniadvisor/tests/omniadvisor/service/tuning_result/test_tuning_result.py b/omniadvisor/tests/omniadvisor/service/tuning_result/test_tuning_result.py index 9c3867567..8cdcafaae 100644 --- a/omniadvisor/tests/omniadvisor/service/tuning_result/test_tuning_result.py +++ b/omniadvisor/tests/omniadvisor/service/tuning_result/test_tuning_result.py @@ -77,7 +77,7 @@ class TestTuningResult: assert tuning_result.method == 'method1' assert tuning_result.method_extend == "method_extend1" assert tuning_result.rounds == 1 - assert tuning_result.status == OA_CONF.ExecStatus.success + assert tuning_result.status == OA_CONF.ExecStatus.running assert tuning_result.runtime == 12.5 assert tuning_result.trace == "trace_info1" -- Gitee From 0aa1f81317e43c465c5bec4bc8f5ae9d5e1b27b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Thu, 15 May 2025 15:33:39 +0800 Subject: [PATCH 09/10] =?UTF-8?q?=E6=96=B0=E5=A2=9Espark=5Fexecutor?= =?UTF-8?q?=E5=89=8D=E5=90=8E=E5=AF=B9=E4=BA=8E=E7=9B=B8=E5=BA=94exam=5Fre?= =?UTF-8?q?cord=E7=9A=84start=5Ftime=E5=92=8Cend=5Ftime=E7=9A=84=E6=9B=B4?= =?UTF-8?q?=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/omniadvisor/service/spark_service/spark_run.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 8c9f37a25..c43cfdfc1 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -31,7 +31,9 @@ def spark_run(load, conf): # 执行当前的submit_cmd spark_executor = SparkExecutor() try: + ExamRecordRepository.update_start_time(exam_record) exitcode, spark_output = spark_executor.submit_spark_task(submit_cmd) + ExamRecordRepository.update_end_time(exam_record) except TimeoutError: # 任务提交超时等 exam_record.delete() -- Gitee From 5808fe32ea509ff0913f86ae1afe358e8e8f86ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Thu, 15 May 2025 15:52:02 +0800 Subject: [PATCH 10/10] =?UTF-8?q?spark=5Frun=E6=89=A7=E8=A1=8C=E5=A4=B1?= =?UTF-8?q?=E8=B4=A5=E6=97=B6=E8=BF=94=E5=9B=9E=E7=A9=BA=E5=AD=97=E7=AC=A6?= =?UTF-8?q?=E4=B8=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/omniadvisor/service/spark_service/spark_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index c43cfdfc1..a0a1d4698 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -57,7 +57,7 @@ def spark_run(load, conf): status=OA_CONF.ExecStatus.fail, runtime=OA_CONF.exec_fail_return_runtime, trace=OA_CONF.exec_fail_return_trace - ), "spark-sql failed" + ), "" except RuntimeError: exam_record.delete() raise -- Gitee