diff --git a/omniadvisor/README.md b/omniadvisor/README.md index 0e305a65ee20fa4846b418572ad8173eb28c317b..6600b178517049309e59812c53a74b94f707e30b 100755 --- a/omniadvisor/README.md +++ b/omniadvisor/README.md @@ -1,20 +1,34 @@ -1 负载劫持 +1 历史数据库初始化 +初始化数据库,数据库用于存储调优过程与负载执行过程。若之前执行过初始化过程,则此处可以跳过。 + +- ① 进入`omniadvisor/src/`目录下。 +- ② 顺序执行以下命令:`python init.py makemigrations`和`python init.py migrate`,初始化数据库 +- ③ 执行命令`python init.py createsuperuser`创建超级用户,并按指示设置用户名和密码。 +- ④ 修改`omniadvisor/src/server/engine/settings.py`中的 `ALLOWED_HOSTS`,在列表中添加当前节点IP地址 + +2 启动历史数据库服务 +完成数据库初始化后,即可启动历史数据库,并通过Web端查看数据库内数据 + +- ① 执行命令`python init.py runserver 0.0.0.0:8000`启动历史数据库服务。 +- ② 登录Web管理页面:`localhost:8000/admin/`,将`localhost`替换为实际的IP地址,并通过初始化过程中设置的用户名和密码登录。 + +3 负载劫持 首先要对用户负载进行劫持,获取用户负载及相关信息,以便后续能够调优。使能负载劫持,步骤如下: -- ① 修改 omniadvisor/script/spark-submit 脚本中 hijack_path 值,需要填实际的 hijack.py 文件的路径 + +- ① 修改 spark-submit 脚本(位于`omniadvisor/script/`目录下),将 `hijack_path` 值修改为 `hijack.py` 文件(位于`omniadvisor/src/`目录下)的绝对路径 - ② 替换正在使用的 spark-submit 脚本(可通过 `which spark-submit` 查看)为 ① 中的 spark-submit 文件 - ③ 启用环境变量:`export enable_omniadvisor=true` 说明:完成以上步骤之后,用户下发相同的负载,会劫持用户的命令,并替换用户配置为系统推荐的最优的配置(初始状态下保持用户的默认配置)。推荐配置若执行失败,则退化为用户的默认配置 -2 查询负载等信息 +4 查询负载信息 执行负载并劫持之后,在后台管理页面查看到相关信息 -- ① 若之前从未创建过用户,使用命令:`python init.py createsuperuser`,按指示设置用户名和密码。 -- ② `python init.py runserver`(可选 绑定 0.0.0.0:8000) -- ③ 修改settings.py中的 ALLOWED_HOSTS(可选) -- ④ 登录后台管理页面:`localhost:8000/admin/` (替换实际的ip,若访问不通,须执行步骤 ②③) -说明:进入页面后,可以看到APP下有若干数据表,其中 loads 就是负载表 -3 启用调优 +- ① 登录后台管理页面:`localhost:8000/admin/` +- ② 切换到Loads页面,即可查看负载信息 + +5 启用调优 通过在根目录下,执行 `python tuning.py --help` 查看各参数说明 + - ① 各参数的详细介绍 - --load-id,即步骤2中loads表查询到的负载的id - --retest-way,复测方式(复测是为了保证结果的可靠),可选前台或者后台复测。若是前台复测,当用户提交任务时,跟随任务下发;若后台复测,则是由 @@ -22,11 +36,9 @@ - --tuning-method,调优方法,不同的调优方法原理不同,iterative,即迭代调优,本质上使用贝叶斯优化方法;expert,即专家调优,通过诊断资源瓶颈, 给出调优建议;transfer,迁移调优,迁移相似负载的调优经验到陌生负载;native 算子加速,将Spark原生算子替换为C++ Native算子,同时使能CPU向量化执行,实现性能加速 - ② 命令示例 - - `python tuning.py -l 1 -r backend -t expert`。说明:使用专家调优,对id为1的负载,进行调优,期间在后台复测,复测次数为 - common_config.cfg 中 tuning.retest.times 的值 + - `python tuning.py -l 1 -r backend -t expert`。说明:使用专家调优,对id为1的负载,进行调优,期间在后台复测,复测次数为 `common_config.cfg` 中 `tuning.retest.times` 的值 - `python tuning.py -l 1 -r backend`。说明:使用系统中默认的调优策略以及历史调优记录,共同决定当前的调优方式 -说明:调优结束后,会更新负载的最优配置信息,以便步骤1能够使用最优配置 - + 说明:调优结束后,会更新负载的最优配置信息,以便步骤3能够使用最优配置 diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index 7842283bfed52a6948ae9fac5a3dfcc3a51b9a8f..b506eb25827eaf03e2be7b867da1337e71ca4cf2 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -70,8 +70,8 @@ class OmniAdvisorConf: tuning_retest_times = _common_config.getint('common', 'tuning.retest.times') config_fail_threshold = _common_config.getint('common', 'config.fail.threshold') spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url') - spark_fetch_trace_timeout = _common_config.get('spark', 'spark.fetch.trace.timeout') - spark_fetch_trace_interval = _common_config.get('spark', 'spark.fetch.trace.interval') + spark_fetch_trace_timeout = _common_config.getint('spark', 'spark.fetch.trace.timeout') + spark_fetch_trace_interval = _common_config.getint('spark', 'spark.fetch.trace.interval') spark_exec_timeout_ratio = _common_config.getfloat('spark', 'spark.exec.timeout.ratio') tuning_strategies = json.loads(_common_config.get('common', 'tuning.strategy'))
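The `constant.py` hunk above switches the two trace-fetch settings from `get` to `getint`. Since `ConfigParser.get()` always returns a string, later arithmetic or `time.sleep()` on the timeout/interval would fail or silently misbehave. Below is a minimal standalone sketch, not project code: only the section/option names mirror `common_config.cfg` as read in `constant.py`; the values are assumptions for illustration.

```python
# Standalone sketch: why getint() matters for numeric settings.
# Section/option names mirror common_config.cfg as referenced in constant.py;
# the values below are illustrative assumptions.
from configparser import ConfigParser

cfg = ConfigParser()
cfg.read_string("""
[spark]
spark.fetch.trace.timeout = 300
spark.fetch.trace.interval = 5
""")

timeout_str = cfg.get('spark', 'spark.fetch.trace.timeout')      # '300' (str)
timeout_int = cfg.getint('spark', 'spark.fetch.trace.timeout')   # 300 (int)

# With the string value, "numeric" use silently misbehaves:
print(timeout_str * 2)   # '300300' -- string repetition, not arithmetic
print(timeout_int * 2)   # 600

# The trace-fetch loop sleeps on the interval, which requires a number:
interval = cfg.getint('spark', 'spark.fetch.trace.interval')
assert interval + 1 == 6
```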
diff --git a/omniadvisor/src/omniadvisor/interface/config_tuning.py b/omniadvisor/src/omniadvisor/interface/config_tuning.py index 4f1e2cd7cc397503bedab79ade5b764899331919..6509be5c28c88df149c01e7e096a70fb79b8fa27 100644 --- a/omniadvisor/src/omniadvisor/interface/config_tuning.py +++ b/omniadvisor/src/omniadvisor/interface/config_tuning.py @@ -6,6 +6,7 @@ from algo.native.tuning import NativeTuning from algo.transfer.tuning import TransferTuning from common.constant import OA_CONF +from omniadvisor.repository.model.load import Load from omniadvisor.repository.load_repository import LoadRepository from omniadvisor.repository.tuning_record_repository import TuningRecordRepository from omniadvisor.service.retest_service import retest @@ -60,12 +61,14 @@ def unified_tuning(load, retest_way: str, tuning_method: str): raise NoOptimalConfigError('The recommending config is empty, please try other tuning methods.') # 用户的default_config上叠加next_config叠加 + global_logger.info("Load config tuning success, get new config to retest.") next_config = {**load.default_config, **next_config} TuningRecordRepository.create(load=load, config=next_config, method=tuning_method, method_extend=method_extend) # 复测 if retest_way == OA_CONF.RetestWay.backend: + global_logger.info("The way of retest is backend, going to retest the config from tuning...") try: retest(load, next_config) except Exception: @@ -80,33 +83,47 @@ LoadRepository.update_best_config(load, tuning_result_history.best_config) else: # 更新待测试配置即可 + global_logger.info("The way of retest is hijacking, update the config from tuning to test config.") LoadRepository.update_test_config(load, next_config) -def _get_next_config(load, tuning_method): +def _get_next_config(load: Load, tuning_method: str): """ - 获取下一个配置 + 调优推荐下一个配置 + :param load: 负载 :param tuning_method: 调优方法 :return: """ + # 获取指定负载的调优历史记录 + tuning_result_history = get_tuning_result_history(load) + + # AI迭代调优 if tuning_method == OA_CONF.TuningMethod.iterative: - tuning_result_history = get_tuning_result_history(load) - next_config, method_extend = SmacAppendTuning.tune(tuning_result_history.tuning_history) + global_logger.info("Use AI iterative optimization method for tuning.") + tuning = SmacAppendTuning(tuning_history=tuning_result_history.tuning_history) + # 专家规则调优 elif tuning_method == OA_CONF.TuningMethod.expert: - tuning_result_history = get_tuning_result_history(load) - next_config, method_extend = ExpertTuning.tune(tuning_result_history.tuning_history) + global_logger.info("Use expert rule optimization method for tuning.") + tuning = ExpertTuning(tuning_history=tuning_result_history.tuning_history) + # Native特性使能 elif tuning_method == OA_CONF.TuningMethod.native: - tuning_result_history = get_tuning_result_history(load) - next_config, method_extend = NativeTuning.tune(tuning_result_history.tuning_history) + global_logger.info("Use native operator acceleration feature for tuning.") + tuning = NativeTuning(tuning_history=tuning_result_history.tuning_history) + # 迁移泛化调优 elif tuning_method == OA_CONF.TuningMethod.transfer: - tuning_result_history = get_tuning_result_history(load) + global_logger.info("Use migration generalization optimization method for tuning.") other_history = get_other_tuning_result_history(load) - next_config, method_extend = TransferTuning.tune( - tuning_result_history.tuning_history, [other.tuning_history for other in other_history] + tuning = TransferTuning( + tuning_history=tuning_result_history.tuning_history, + other_histories=[other.tuning_history for other in other_history] ) else: raise ValueError(f'Not supported tuning method: {tuning_method}') + + # 调优 + next_config, method_extend = tuning.tune() + global_logger.debug(f"The method_extend is: {method_extend}, config from tuning is: {next_config}") return method_extend, next_config
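The `config_tuning.py` hunk above moves from classmethod-style `XTuning.tune(history)` calls to constructing each tuner with its history and then calling `tune()` with no arguments. The algo classes themselves are outside this diff, so the sketch below only captures the shape implied by these call sites; the class name, annotations, and any real base class in `algo.*` may differ.

```python
# Hypothetical base shape inferred from the call sites in _get_next_config();
# not the project's actual algo.* code.
from abc import ABC, abstractmethod
from typing import Any, Tuple


class BaseTuning(ABC):
    """Tuners are constructed with their tuning history, then tune() is called without arguments."""

    def __init__(self, tuning_history: Any, **kwargs: Any) -> None:
        self.tuning_history = tuning_history

    @abstractmethod
    def tune(self) -> Tuple[dict, str]:
        """Return (next_config, method_extend), matching how the caller unpacks the result."""
```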
diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index 1dd539ee0fa62b9b1696f8b16a265af3923627ae..59d71373be9655c8c6ab524ea1b331b3eb5883bc 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -23,10 +23,11 @@ def _query_or_create_load(name: str, exec_attr: dict, default_config: dict): loads = LoadRepository.query_by_exec_attr_and_default_config(exec_attr=exec_attr, default_config=default_config) # 如果查询不到负载信息,则创建新的负载 if not loads: - global_logger.info("Not exit available load, create new one.") + global_logger.info("Load not found in database, create new one and execute.") load = LoadRepository.create(name=name, exec_attr=exec_attr, default_config=default_config) TuningRecordRepository.create(load=load, config=default_config, method=OA_CONF.TuningMethod.user) else: + global_logger.info("Load found in database, ready to get config to execute.") load = loads.pop() return load @@ -41,17 +42,21 @@ def _get_exec_config_from_load(load: Load): """ # 如果负载不需要推荐优化配置,直接使用默认配置 if not load.tuning_needed: + global_logger.info("Load doesn't need tuning, running with default config.") return load.default_config # 如果负载需要推荐优化配置 根据优先级选择配置 if load.test_config: # 第一使用待测试配置 + global_logger.info("Test config found, running with test config.") return load.test_config elif load.best_config: # 第二使用最优配置 + global_logger.info("Best config found, running with best config.") return load.best_config else: # 最后使用租户默认配置 + global_logger.info("Only default config found, running with default config.") return load.default_config @@ -63,19 +68,23 @@ def _process_load_config(load: Load): :param load: 本次测试用负载 :return: """ + # 获取测试配置的调优结果 tuning_result = get_tuning_result(load, load.test_config) + + # 测试配置执行失败 if tuning_result.status == OA_CONF.ExecStatus.fail: LoadRepository.update_test_config(load, {}) + # 测试配置执行成功 elif tuning_result.status == OA_CONF.ExecStatus.success: # 获取当前best_config的平均测试性能 - best_config_results = get_tuning_result(load, load.best_config) + best_tuning_result = get_tuning_result(load, load.best_config) # 若调优性能优于最佳性能 刷新当前的best_config - if tuning_result.runtime < best_config_results.runtime: + if tuning_result.runtime < best_tuning_result.runtime: boost_percentage = round( - (best_config_results.runtime - tuning_result.runtime) / best_config_results.runtime, 2 + (best_tuning_result.runtime - tuning_result.runtime) / best_tuning_result.runtime, 2 ) - global_logger.info(f"tuning_runtime={tuning_result.runtime} is " - f"quicker than best_runtime={best_config_results.runtime}, " + global_logger.info(f"Tuning_runtime = {tuning_result.runtime} is " + f"quicker than best_runtime = {best_tuning_result.runtime}, " f"boost percentage = {boost_percentage}" f"found a better spark config") LoadRepository.update_best_config(load, load.test_config) @@ -102,14 +111,18 @@ def hijack_recommend(argv: list): exec_config = _get_exec_config_from_load(load=load) # 根据配置和负载执行Spark任务 + global_logger.info("Going to execute Spark load...") exam_record, output = spark_run(load, exec_config) # 执行结果分析 若执行失败则调度用户默认配置重新拉起任务 if exam_record.status != OA_CONF.ExecStatus.success and exec_config != user_config: +
global_logger.warning("Spark execute failed, ready to activate security protection mechanism.") safe_exam_record, safe_output = spark_run(load, user_config) + global_logger.info("Spark execute success in security protection mechanism, going to print Spark output.") # 打印安全机制下任务的输出 print(safe_output) else: # 打印结果输出 + global_logger.info("Spark execute success, going to print Spark output.") print(output) if exec_config == load.test_config: diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py index 9f550fb01edfcea1917f8a9d23442c63a373bb87..b798bc9f54d9069cab174153140d625f50bf2d51 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py @@ -120,8 +120,10 @@ class SparkCMDParser: # 其余参数保存至exec_attr else: exec_attr[key] = value - global_logger.warning(f"The remainder unknown params {unknown} will be add to exec_attr[remainder]") - exec_attr[_CMD_UNKNOWN_KEY] = unknown + + if unknown: + global_logger.warning(f"The remainder unknown params {unknown} will be add to exec_attr[unknown]") + exec_attr[_CMD_UNKNOWN_KEY] = unknown return exec_attr, conf_params @@ -159,5 +161,4 @@ class SparkCMDParser: cmd_fields += ['--conf', f'{key}={cls._normalize_value(value)}'] cmd = ' '.join(cmd_fields) - global_logger.info(f"The complete spark-sql command is as follows {cmd}") return cmd diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py index 8dd3426ad5eef2a4e9db6f2bce950efb688ef734..af1ee9c2db9c33cddc26a0196e3538ca81b26282 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py @@ -1,14 +1,17 @@ import re + +from common.constant import OA_CONF from omniadvisor.utils.utils import run_cmd class SparkExecutor: @classmethod - def submit_spark_task(cls, execute_cmd): + def submit_spark_task(cls, execute_cmd: str): """ 在shell终端提交spark命令 - :param submit_config_str: spark的提交命令 + + :param execute_cmd: Spark的提交命令 :return: """ exitcode, spark_output = run_cmd(execute_cmd) @@ -16,15 +19,16 @@ class SparkExecutor: return exitcode, spark_output @classmethod - def parser_spark_output(cls, spark_output) -> dict: - return cls._reg_match_application_id_and_time_taken(spark_output) + def parser_spark_output(cls, spark_output: str): + """ + 解析Spark输出,获得spark提交命令、Application ID和执行用时 - @classmethod - def _reg_match_application_id_and_time_taken(cls, spark_output): + :param spark_output: Spark执行输出 + :return: + """ spark_output = spark_output.split("\n") spark_submit_cmd = "" application_id = "" - total_time_taken = -1.0 time_taken_list = [] for item in spark_output: if "deploy.SparkSubmit" in item: @@ -41,6 +45,10 @@ class SparkExecutor: match = re.search(pattern, item) if match: time_taken_list.append(float(match.group(1).strip())) + if time_taken_list: total_time_taken = sum(time_taken_list) - return spark_submit_cmd, application_id, float(total_time_taken) + else: + total_time_taken = OA_CONF.exec_fail_return_runtime + + return spark_submit_cmd, application_id, total_time_taken diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 66a120dacf626bb97d72a6525dd98e469e7571fb..10ab73d39c52e1318132f45e91faa27db8e28553 100644 --- 
a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -25,13 +25,13 @@ def spark_run(load: Load, conf: dict): """ # 从解析后的参数列表中提取负载与任务的相关信息 submit_cmd = SparkCMDParser.reconstruct_cmd(exec_attr=load.exec_attr, conf_params=conf) - # 判断当前的conf是否和load.default_config相同 不相同则在submit_cmd前增加超时时间 if conf != load.default_config: # 获取当前default_config的平均测试性能 baseline_results = get_tuning_result(load, load.default_config) timeout_sec = OA_CONF.spark_exec_timeout_ratio * baseline_results.runtime submit_cmd = f"timeout {timeout_sec} " + submit_cmd + global_logger.debug(f"The submit cmd about to execute is: {submit_cmd}") # 根据执行命令创建测试记录 exam_record = ExamRecordRepository.create(load, conf) @@ -41,6 +41,7 @@ def spark_run(load: Load, conf: dict): # 不存在application_id的情况下不提取time_taken 直接返回 if exitcode != 0: + global_logger.info(f"Spark Load execute failed, update the exam result.") try: return ExamRecordRepository.update_exam_result( exam_record=exam_record, @@ -63,6 +64,7 @@ def spark_run(load: Load, conf: dict): raise try: + global_logger.info(f"Spark Load execute success, runtime: {runtime}, update the exam result.") exam_record = ExamRecordRepository.update_exam_result( exam_record=exam_record, status=status, @@ -74,6 +76,7 @@ def spark_run(load: Load, conf: dict): raise # 根据ApplicantID在子进程中获取Trace + global_logger.info(f"Going to fetch Spark execute trace in backend.") p = multiprocessing.Process(target=_update_trace_from_history_server, args=(exam_record, application_id)) p.daemon = False p.start() @@ -127,11 +130,12 @@ def _update_trace_from_history_server(exam_record: ExamRecord, application_id: s trace_executor = spark_fetcher.get_spark_executor_by_app(application_id) except HTTPError as httpe: time.sleep(OA_CONF.spark_fetch_trace_interval) - global_logger.warning(f"Cannot access history server: {httpe}") + global_logger.debug(f"Cannot access history server: {httpe}") continue trace_dict['sql'] = save_trace_data(data=trace_sql, data_dir=OA_CONF.data_dir) trace_dict['stages'] = save_trace_data(data=trace_stages, data_dir=OA_CONF.data_dir) trace_dict['executor'] = save_trace_data(data=trace_executor, data_dir=OA_CONF.data_dir) + global_logger.info(f"Fetch Spark execute trace success.") break else: raise RuntimeError(f'Failed to get App {application_id} trace from {history_server_url}') diff --git a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py index 13cb3380b0ab9a998718c40fec8bd583f69db528..233103825d05423a9c5b6e9cc2142170f6e46aee 100644 --- a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py +++ b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py @@ -1,8 +1,7 @@ from typing import List -from algo.common.model import Tuning, Trace +from algo.common.model import TuningData, Trace from common.constant import OA_CONF - from omniadvisor.repository.model.load import Load from omniadvisor.repository.model.tuning_record import TuningRecord from omniadvisor.repository.model.exam_record import ExamRecord @@ -146,13 +145,13 @@ class TuningResult: def exam_records(self): return self._exam_records - def to_tuning(self) -> Tuning: + def to_tuning(self) -> TuningData: if self.trace is not None: trace = Trace(stages_with_summaries=self.trace.get('stages'), sql=self.trace.get('sql')) else: trace = Trace(stages_with_summaries=None, sql=None) - return Tuning( + return TuningData( round=self.rounds, config=self.config, 
method=self.method, diff --git a/omniadvisor/src/omniadvisor/utils/logger.py b/omniadvisor/src/omniadvisor/utils/logger.py index 2a0984b3a07b84c5a7cc05d16af87dad5e6c117e..bf3d2275b12cf216912479126daf413d1ba0590b 100755 --- a/omniadvisor/src/omniadvisor/utils/logger.py +++ b/omniadvisor/src/omniadvisor/utils/logger.py @@ -60,6 +60,8 @@ dictConfig(LOGGING_CONFIG) # 获取logger并使用 global_logger = logging.getLogger('omniadvisor') +# 禁止logger作用域上升到root +global_logger.propagate = False # 屏蔽部分第三方库中的Logger modules_setting_error = ['smac'] diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py index 4c919f612e0a2e65e642bd765a842efc1986423e..8e333afa411735e9637fb3ef8580c032035d72f9 100644 --- a/omniadvisor/src/omniadvisor/utils/utils.py +++ b/omniadvisor/src/omniadvisor/utils/utils.py @@ -2,13 +2,8 @@ import os import json import uuid import subprocess -import shlex from typing import Tuple, List, Dict -from ConfigSpace import ConfigurationSpace, UniformIntegerHyperparameter, Constant, CategoricalHyperparameter, \ - UniformFloatHyperparameter - -from common.constant import OA_CONF from omniadvisor.utils.logger import global_logger @@ -24,12 +19,14 @@ def run_cmd(submit_cmd) -> Tuple[int, str]: :rtype: Tuple[int, str] """ global_logger.debug(f"Executor system command: {submit_cmd}") - cmd_list = shlex.split(submit_cmd) - result = subprocess.run(cmd_list, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=False) - exitcode = result.returncode - data = result.stdout - global_logger.debug(data) - return exitcode, data + kwargs = { + 'stdout': subprocess.PIPE, + 'stderr': subprocess.STDOUT, + 'shell': True, + 'text': True + } + result = subprocess.run(submit_cmd, **kwargs) + return result.returncode, result.stdout def save_trace_data(data: List[Dict[str, str]], data_dir): @@ -45,7 +42,7 @@ def save_trace_data(data: List[Dict[str, str]], data_dir): try: with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=4) - global_logger.info(f"数据已成功保存到 {file_path}") + global_logger.debug(f"数据已成功保存到 {file_path}") except IOError as e: raise IOError(f"出现IO错误: {e}") from e except Exception as e: diff --git a/omniadvisor/src/server/app/models.py b/omniadvisor/src/server/app/models.py index 40851b68a83c6d24431c584fd5f351a5b2b66d5d..bc5279bef8da279adfe85d77713c5b755a3e0212 100644 --- a/omniadvisor/src/server/app/models.py +++ b/omniadvisor/src/server/app/models.py @@ -19,7 +19,7 @@ class DatabaseLoad(models.Model): class Meta: constraints = [ - models.UniqueConstraint(fields=['name', 'default_config'], name='load_unique') + models.UniqueConstraint(fields=['exec_attr', 'default_config'], name='load_unique') ] db_table = 'omniadvisor_load' # 自定义表名 diff --git a/omniadvisor/tests/conftest.py b/omniadvisor/tests/conftest.py index 815fff6e575dd39b5f0728ff2af0e71db50940bc..8110177177d50f18dfd316ccc64bf655701a19eb 100644 --- a/omniadvisor/tests/conftest.py +++ b/omniadvisor/tests/conftest.py @@ -34,7 +34,7 @@ to_registers = { 'algo.common.model.Trace', 'algo.expert.tuning.ExpertTuning', 'algo.iterative.tuning.SmacAppendTuning', - 'algo.common.model.Tuning', + 'algo.common.model.TuningData', 'algo.native.tuning.NativeTuning', 'algo.transfer.tuning.TransferTuning', } diff --git a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py index 30935facabd50441023b6d1f5ee9be4aa316b020..e7919e68a9a3e6eb208c91b49869a318e1403623 100644 --- 
a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py +++ b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py @@ -40,12 +40,11 @@ class TestTuning: mock_smac_tuning.assert_called_once() assert mock_spark_run.call_count == OA_CONF.tuning_retest_times - def test_unified_tuning_when_spark_execute_failed(self, caplog): + def test_unified_tuning_when_spark_execute_failed(self): """ 后台复测 spark命令执行异常 :return: """ - caplog.set_level(logging.WARNING) with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id'), \ patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create'), \ patch('omniadvisor.service.retest_service.spark_run') as mock_spark_run, \ @@ -72,9 +71,7 @@ class TestTuning: mock_update_best.assert_not_called() mock_remove_tuning_result.assert_not_called() - assert 'Retest failed in round 1. Exception source: Spark exception' in caplog.text - - def test_unified_tuning_when_other_exception(self, caplog): + def test_unified_tuning_when_other_exception(self): """ 后台复测 发生其他异常 :return: @@ -96,7 +93,6 @@ class TestTuning: mock_smac_tuning.assert_called_once() mock_spark_run.assert_called_once() mock_remove_tuning_result.assert_called_once() - assert 'Retest failed in round 1. Exception source: Non-Spark exception' in caplog.text def test_main_when_load_id_not_exist(self): """ diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py index e69de2fa9ed8f6dd08db4afe7065eb180a3abce0..8eb2cbaa2455d4b847f6d6f45610ff8452f05c8f 100644 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py @@ -1,10 +1,14 @@ from unittest.mock import patch + +from common.constant import OA_CONF from omniadvisor.service.spark_service.spark_executor import SparkExecutor class TestSparkExecutor: def setup_class(self): - """在每个测试用例之前运行,初始化 SparkExecutor 实例""" + """ + 在每个测试用例之前运行,初始化 SparkExecutor 实例 + """ self.spark_executor = SparkExecutor() @patch('omniadvisor.service.spark_service.spark_executor.run_cmd') @@ -48,26 +52,6 @@ class TestSparkExecutor: assert application_id == expected_result["application_id"] assert time_taken == expected_result["time_taken"] - def test_reg_match_application_id_and_time_taken(self): - """ - 测试 _reg_match_application_id_and_time_taken 方法 - """ - # 模拟 spark_output 数据 - spark_output ="Some log line\n" \ - "INFO deploy.SparkSubmit: Starting Spark job\n" \ - "Another log line\n" \ - "Application Id: application_1234567890\n" \ - "Time taken: 12.34 seconds" - - # 调用方法 - spark_submit_cmd, application_id, time_taken = self.spark_executor._reg_match_application_id_and_time_taken( - spark_output) - - # 验证结果 - assert spark_submit_cmd == "INFO deploy.SparkSubmit: Starting Spark job" - assert application_id == "application_1234567890" - assert time_taken == 12.34 - def test_parser_spark_output_with_empty_output(self): """ 测试 parser_spark_output 方法处理空输出的情况 @@ -81,4 +65,4 @@ class TestSparkExecutor: # 验证结果 assert spark_submit_cmd == "" assert application_id == "" - assert time_taken == -1.0 + assert time_taken == OA_CONF.exec_fail_return_runtime diff --git a/omniadvisor/tests/omniadvisor/service/tuning_result/test_tuning_result.py b/omniadvisor/tests/omniadvisor/service/tuning_result/test_tuning_result.py index 8cdcafaae824b8a18824b2d39ec7189f1bf7d25c..0d6759328764d824e04388754928374c4a038767 100644 --- 
a/omniadvisor/tests/omniadvisor/service/tuning_result/test_tuning_result.py +++ b/omniadvisor/tests/omniadvisor/service/tuning_result/test_tuning_result.py @@ -66,6 +66,8 @@ class TestTuningResult: task_mock3.status = OA_CONF.ExecStatus.fail task_mock3.runtime = 10 task_mock3.trace = "trace_info3" + # 强制设定复测次数为3,用于测试 + OA_CONF.tuning_retest_times = 3 tasks = [task_mock1, task_mock2] tuning_result = TuningResult( diff --git a/omniadvisor/tests/omniadvisor/utils/test_utils.py b/omniadvisor/tests/omniadvisor/utils/test_utils.py index a1db1a28e012feca89799082a7fb88f4bb76dae3..ae50ad5edcebb0219cf8ad81584cc00603e1e0e8 100644 --- a/omniadvisor/tests/omniadvisor/utils/test_utils.py +++ b/omniadvisor/tests/omniadvisor/utils/test_utils.py @@ -15,8 +15,7 @@ class TestRunCmd: # 验证结果 assert exitcode[0] == 0 assert output == "Success output" - mock_run.assert_called_once_with(['echo', 'hello'], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=False) + mock_run.assert_called_once() @patch('subprocess.run') def test_run_cmd_failure(self, mock_run):
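For reference, a small usage sketch of the refactored output parser. The sample log lines and the expected values mirror the unit test removed above; in a real run `spark_output` would come from `SparkExecutor.submit_spark_task()`.

```python
# Usage sketch only; sample log lines and expected values mirror the removed unit test.
from common.constant import OA_CONF
from omniadvisor.service.spark_service.spark_executor import SparkExecutor

sample_output = (
    "Some log line\n"
    "INFO deploy.SparkSubmit: Starting Spark job\n"
    "Application Id: application_1234567890\n"
    "Time taken: 12.34 seconds"
)

cmd, app_id, time_taken = SparkExecutor.parser_spark_output(sample_output)
# cmd is the deploy.SparkSubmit line, app_id == "application_1234567890",
# and time_taken is the sum of all "Time taken" lines (12.34 here).

# With no recognizable output, the parser now falls back to the failure sentinel:
_, _, fallback = SparkExecutor.parser_spark_output("")
assert fallback == OA_CONF.exec_fail_return_runtime
```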