From 4cfd78f045478ca7a078da75d75dd00428ddc153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Sat, 9 Aug 2025 18:09:45 +0800 Subject: [PATCH 1/2] =?UTF-8?q?1.=E4=BF=AE=E6=94=B9tuning=5Fstrategies?= =?UTF-8?q?=E7=9A=84=E7=B1=BB=E5=9E=8B:=E4=BB=8EList[List]=20->=20List[Tup?= =?UTF-8?q?le]=202.=E5=87=BD=E6=95=B0=E5=85=A5=E5=8F=82=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E3=80=81=E8=BF=94=E5=9B=9E=E5=80=BC=E7=B1=BB=E5=9E=8B=E4=BB=A5?= =?UTF-8?q?=E5=8F=8A=E6=B3=A8=E8=A7=A3=E8=A7=84=E8=8C=83=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/common/constant.py | 20 +++--- .../omniadvisor/interface/config_tuning.py | 65 ++++++++++++------- .../omniadvisor/interface/hijack_recommend.py | 58 +++++++++++------ .../repository/exam_record_repository.py | 30 +++++---- .../repository/load_prefetch_repository.py | 5 ++ .../omniadvisor/repository/load_repository.py | 14 ++-- .../repository/model/exam_record.py | 6 +- .../src/omniadvisor/repository/repository.py | 46 +++++++++---- .../repository/tuning_record_repository.py | 19 +++--- .../src/omniadvisor/service/retest_service.py | 17 +++-- .../service/spark_service/spark_cmd_parser.py | 20 +++--- .../service/spark_service/spark_executor.py | 18 ++--- .../service/spark_service/spark_fetcher.py | 54 ++++++++------- .../service/spark_service/spark_run.py | 26 ++++---- .../service/tuning_result/tuning_result.py | 9 ++- .../tuning_result/tuning_result_history.py | 5 +- omniadvisor/src/server/app/admin.py | 2 +- omniadvisor/src/server/app/models.py | 2 +- .../repository/test_exam_record_repository.py | 10 +-- .../repository/test_load_repository.py | 10 +-- .../test_tuning_record_repository.py | 6 +- .../spark_service/test_spark_executor.py | 6 +- .../spark_service/test_spark_fetcher.py | 2 +- .../service/spark_service/test_spark_run.py | 6 +- 24 files changed, 273 insertions(+), 183 deletions(-) diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index 8a75091fb..4a06f50cb 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -5,7 +5,7 @@ import configparser from server.engine.settings import BASE_DIR -def load_common_config(config_path: str): +def load_common_config(config_path: str) -> configparser.ConfigParser: """ 使用configparser库加载common_config @@ -21,9 +21,11 @@ def load_common_config(config_path: str): return common_config -def check_oa_conf(): +def check_oa_conf() -> None: """ - 校验OA_CONF中参数是否正确 + 校验OA_CONF中参数是否符合要求,在参数值不符合规范时抛出异常 + + :raises ValueError: OA_CONF中的值不符合要求时抛出ValueError """ if OA_CONF.tuning_retest_times <= 0: raise ValueError('The tuning retest times must > 0, please check common configuration.') @@ -46,7 +48,9 @@ def check_oa_conf(): class OmniAdvisorConf: """ - OmniAdvisor常量 + OmniAdvisor 常量配置类。 + + 存放 OmniAdvisor 系统的全局常量,包括 API 路径、默认参数等。 """ # OmniAdvisor版本信息 product_name = "Kunpeng BoostKit" @@ -85,7 +89,7 @@ class OmniAdvisorConf: empty_config = dict() exec_fail_return_runtime = float('inf') - exec_fail_return_app_id = '' + exec_fail_return_application_id = '' exec_fail_return_trace = dict() # 调优算法 @@ -109,12 +113,12 @@ class OmniAdvisorConf: tuning_retest_times = _common_config.getint('common', 'tuning.retest.times') config_fail_threshold = _common_config.getint('common', 'config.fail.threshold') spark_history_rest_url = _common_config.get('spark', 'spark.history.rest.url') - spark_history_username = _common_config.get('spark', 'spark.history.username', fallback='') - spark_history_password = _common_config.get('spark', 'spark.history.password', fallback='') + spark_history_username = _common_config.get('spark', 'spark.history.username') + spark_history_password = _common_config.get('spark', 'spark.history.password') spark_fetch_trace_timeout = _common_config.getint('spark', 'spark.fetch.trace.timeout') spark_fetch_trace_interval = _common_config.getint('spark', 'spark.fetch.trace.interval') spark_exec_timeout_ratio = _common_config.getfloat('spark', 'spark.exec.timeout.ratio') - tuning_strategies = json.loads(_common_config.get('common', 'tuning.strategy')) + strategies_tuples = [tuple(item) for item in json.loads(_common_config.get('common', 'tuning.strategy'))] # 保留小数位数 decimal_digits = 3 diff --git a/omniadvisor/src/omniadvisor/interface/config_tuning.py b/omniadvisor/src/omniadvisor/interface/config_tuning.py index 0d4803924..99ad84bbc 100644 --- a/omniadvisor/src/omniadvisor/interface/config_tuning.py +++ b/omniadvisor/src/omniadvisor/interface/config_tuning.py @@ -1,5 +1,6 @@ import argparse import signal +from typing import Optional from algo.expert.tuning import ExpertTuning from algo.iterative.tuning import SmacAppendTuning @@ -23,17 +24,22 @@ from omniadvisor.utils.logger import global_logger from omniadvisor.utils.utils import float_format -def handler(signum, frame): +def handler(signum: int, frame) -> None: # TODO 这个frame为什么是no value 以及signum是什么类型 """ - 用于注册异常退出信号 + 用于注册异常退出信号的处理函数。 + + :param signum: 接收到的信号编号(如 signal.SIGTERM) + :param frame: 当前栈帧,可能为 None + :raises SystemKilledError: 当接收到终止信号时抛出该异常 """ raise SystemKilledError(f'System is terminated, because of catching signal %s', signum) -def _parse_tuning_args(): +def _parse_tuning_args() -> argparse.Namespace: """ - 解析用户直命令行输入的参数 - :return: + 解析调优模块用户直命令行输入的参数 + + :return: 解析后参数的命名空间 """ parser = argparse.ArgumentParser( description="Parse params that will frequently used. For other configs, change the config file.") @@ -51,13 +57,17 @@ def _parse_tuning_args(): return args -def _single_tuning(load, retest_way: str, tuning_method: str): +def _single_tuning(load: Load, retest_way: str, tuning_method: str) -> None: """ - 后台调优 + 对入参中的负载使用指定的算完成一次调优动作 + (一次完整的调优动作包括①调优算法推出配置,②调优记录生成,③配置复测流程执行,④负载对应的test_config与best_config刷新) + 当推出的配置符合要求时,生成对应的TuningRecord,并依据复测方式按照指定流程完成复测 + 根据复测情况对该负载相关联test_config、best_config等字段进行更新 + :param load: 负载 :param retest_way: 复测方法 - :param tuning_method: 算法的类型 - :return: + :param tuning_method: 调优算法的类型 + :raises NoOptimalConfigError: 当没有推出符合要求的优化后参数时抛出异常 """ signal.signal(signal.SIGTERM, handler) signal.signal(signal.SIGINT, handler) @@ -65,7 +75,8 @@ def _single_tuning(load, retest_way: str, tuning_method: str): method_extend, next_config = _get_next_config(load, tuning_method) if not next_config: - raise NoOptimalConfigError(f'The recommending config of method {tuning_method} is empty, please try other tuning methods.') + raise NoOptimalConfigError(f'The recommending config of method {tuning_method} is empty,' + f' please try other tuning methods.') # 用户的default_config上叠加next_config叠加 global_logger.info("Load config tuning success, get new config to retest.") @@ -73,7 +84,8 @@ def _single_tuning(load, retest_way: str, tuning_method: str): # 判断是否有重复推荐配置 if TuningRecordRepository.query_by_load_and_config(load=load, config=next_config): - raise NoOptimalConfigError(f'The recommending config of method {tuning_method} already exists, please try other tuning methods.') + raise NoOptimalConfigError(f'The recommending config of method {tuning_method} already exists,' + f' please try other tuning methods.') TuningRecordRepository.create(load=load, config=next_config, method=tuning_method, method_extend=method_extend) @@ -105,13 +117,17 @@ def _single_tuning(load, retest_way: str, tuning_method: str): LoadRepository.update_test_config(load, next_config) -def _get_next_config(load: Load, tuning_method: str): +def _get_next_config(load: Load, tuning_method: str) -> tuple[str, str]: # TODO 检查一下这个返回值类型 """ - 调优推荐下一个配置 + 对指定的负载通过指定的调优算法进行调优,并返回调优后的配置 :param load: 负载 :param tuning_method: 调优方法 - :return: + :raises ValueError: 当传入的调优方法不被支持时抛出异常 + :returns: 一个元组 (method_extend, next_config) + - **method_extend**: 调优方法描述 + - **next_config**: 优化后配置 + :rtype: tuple[str, str] """ # 获取指定负载的调优历史记录 tuning_result_history = get_tuning_result_history(load) @@ -119,20 +135,20 @@ def _get_next_config(load: Load, tuning_method: str): # AI迭代调优 if tuning_method == OA_CONF.TuningMethod.iterative: global_logger.info("Use AI iterative optimization method to tuning.") - tuning = SmacAppendTuning(tuning_history=tuning_result_history.to_tuning_data_list()) + tuner = SmacAppendTuning(tuning_history=tuning_result_history.to_tuning_data_list()) # 专家规则调优 elif tuning_method == OA_CONF.TuningMethod.expert: global_logger.info("Use expert rule optimization method to tuning.") - tuning = ExpertTuning(tuning_history=tuning_result_history.to_tuning_data_list()) + tuner = ExpertTuning(tuning_history=tuning_result_history.to_tuning_data_list()) # Native特性使能 elif tuning_method == OA_CONF.TuningMethod.native: global_logger.info("Use native operator acceleration feature to tuning.") - tuning = NativeTuning(tuning_history=tuning_result_history.to_tuning_data_list()) + tuner = NativeTuning(tuning_history=tuning_result_history.to_tuning_data_list()) # 迁移泛化调优 elif tuning_method == OA_CONF.TuningMethod.transfer: global_logger.info("Use migration generalization optimization method to tuning.") other_history = get_other_tuning_result_history(load) - tuning = TransferTuning( + tuner = TransferTuning( tuning_history=tuning_result_history.to_tuning_data_list(), other_histories=[other.to_tuning_data_list() for other in other_history] ) @@ -140,12 +156,12 @@ def _get_next_config(load: Load, tuning_method: str): raise ValueError(f'Not supported tuning method: {tuning_method}') # 调优 - next_config, method_extend = tuning.tune() + next_config, method_extend = tuner.tune() global_logger.debug(f"The method_extend is: {method_extend}, config from tuning is: {next_config}") return method_extend, next_config -def _continuous_tuning_with_strategies(load, tuning_strategies: list): +def _continuous_tuning_with_strategies(load: Load, tuning_strategies: list[tuple[str, int]]) -> None: """ 根据输入的调优策略列表,进行连续的调优 @@ -165,11 +181,12 @@ def _continuous_tuning_with_strategies(load, tuning_strategies: list): break -def _query_and_check_load(load_id): +def _query_and_check_load(load_id: str) -> Optional[Load]: # TODO load_id用str类型感觉略微奇怪 """ - 前置检查 - :param load_id: - :return: + 根据load_id查询该负载是否存在于数据库中且需要调优 + + :param load_id: 负载 ID + :returns: 如果找到并且需要调优,则返回对应的 Load 对象;否则返回 ``None``。 """ # 查询负载: loads = LoadRepository.query_by_id(load_id) diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index 5da174b1a..dcef90f99 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -3,6 +3,7 @@ import json import multiprocessing import re import sys +from typing import Any from common.constant import OA_CONF from omniadvisor.repository.load_repository import LoadRepository @@ -20,9 +21,9 @@ _USER_CONFIG_STR = 'user_config' _FILE_CONTENT_STR = 'file_content' -def _get_exec_config_from_load(load: Load): +def _get_exec_config_from_load(load: Load) -> dict[str, Any]: """ - 根据负载属性,按优先级获取即将执行的参数配置 + 根据负载属性,按优先级规则获取即将执行的参数配置 :param load: 负载实例 :return: 待执行配置 @@ -47,11 +48,12 @@ def _get_exec_config_from_load(load: Load): return load.default_config -def _process_load_config(load: Load, config: dict): +def _process_load_config(load: Load, config: dict) -> None: """ 1.检查当前负载中test_config是否完成复测流程 2.根据复测情况判断是否需要清空load中的test_config 3.根据测试性能判断是否需要刷新load中保存的best_config + :param load: 本次测试用负载 :param config: 本次测试用配置 :return: @@ -93,19 +95,27 @@ def _process_load_config(load: Load, config: dict): pass -def _get_load_name_from_exec_attr(exec_attr: dict): +def _get_load_name_from_exec_attr(exec_attr: dict) -> str: """ - 获取任务名,无则空 + 获取负载的name参数,无则返回空字符串 + + :param exec_attr: Spark任务的执行配置 + :return: 负载的name参数 """ return exec_attr.get('name', '') -def _calculate_hash_value(exec_attr: dict, user_config: dict): +def _calculate_hash_value(exec_attr: dict, config: dict) -> str: """ 根据 执行参数和用户配置计算负载唯一标识,hash值 + + :param exec_attr: 负载的执行参数 + :param config: 负载的配置 + :return: 计算得出的哈希值 """ class_info = exec_attr.get('class', '') - if 'SparkSQLCLIDriver' in class_info: + # 提交类为OA_CONF.SparkSQLCLIDriver时,表示该任务是由spark-sql命令所提交 + if class_info == OA_CONF.SparkSQLCLIDriver: # spark-sql,则提取 -e 和 -f,替换 if 'e' in exec_attr: file_content = _remove_time(exec_attr.get('e', '')) @@ -123,7 +133,7 @@ def _calculate_hash_value(exec_attr: dict, user_config: dict): name = _get_load_name_from_exec_attr(exec_attr) data = { _NAME_STR: _remove_time(name), - _USER_CONFIG_STR: user_config, + _USER_CONFIG_STR: config, _FILE_CONTENT_STR: file_content } json_str = json.dumps(data, sort_keys=True) @@ -131,37 +141,47 @@ def _calculate_hash_value(exec_attr: dict, user_config: dict): return hash_value -def _remove_time(content: str): +def _remove_time(content: str) -> str: """ - 移除时间信息,可精确到小时 + 移除时间信息,可精确到小时 #TODO 移除哪里的时间信息了 content是哪儿的入参 + + :param content: + :return: """ for pattern in OA_CONF.date_patterns: content = re.sub(pattern, '{date}', content) return content.strip() -def _create_or_update_load(exec_attr: dict, default_config: dict): +def _create_or_update_load(exec_attr: dict, config: dict) -> Load: """ - 无则创建,有则更新 name 和 exec_attr 字段。如果日期信息等一直保持不变,后期复测时,会导致测试出现异常 + 根据exec_attr和config查询是否存在对应的负载,无则创建新负载并返回,有则更新 name 和 exec_attr 字段后返回该负载。 + (如果日期信息等一直保持不变,后期复测时,会导致测试出现异常) #TODO 展开讲讲 + + :param exec_attr: 负载执行参数 + :param config: 负载配置 + :return: 更新后或者新创建的负载 """ name = _get_load_name_from_exec_attr(exec_attr) - hash_value = _calculate_hash_value(exec_attr, default_config) + hash_value = _calculate_hash_value(exec_attr, config) loads = LoadRepository.query_by_hash_value(hash_value) if loads: global_logger.info("Load found in database, update and ready to get config to execute.") return LoadRepository.update_name_and_exec_attr(loads.pop(), name, exec_attr) else: # 负载没有创建过 - load = LoadRepository.create(name, exec_attr, default_config, hash_value) + load = LoadRepository.create(name, exec_attr, config, hash_value) global_logger.info("Load not found in database, created new one and execute.") - TuningRecordRepository.create(load=load, config=default_config, method=OA_CONF.TuningMethod.user) + TuningRecordRepository.create(load=load, config=config, method=OA_CONF.TuningMethod.user) return load -def hijack_recommend(argv: list): +def hijack_recommend(argv: list) -> None: """ - 任务劫持,使能参数下发执行任务 - :param argv: Spark执行命令字段 + 对用户的任务进行劫持,使能参数并下发执行任务 + + :param argv: Spark任务的执行命令 + :return: """ # 非SUBMIT动作(指kill任务/查询状态/查询版本)的提交直接回退到原生spark-submit脚本执行 不被特性所劫持 SparkCMDParser.validate_submit_arguments(argv) @@ -171,7 +191,7 @@ def hijack_recommend(argv: list): exec_attr, user_config = SparkCMDParser.parse_cmd(argv=argv) # 查询或创建相应负载 - load = _create_or_update_load(exec_attr=exec_attr, default_config=user_config) + load = _create_or_update_load(exec_attr=exec_attr, config=user_config) # 获取待执行参数配置 exec_config = _get_exec_config_from_load(load=load) diff --git a/omniadvisor/src/omniadvisor/repository/exam_record_repository.py b/omniadvisor/src/omniadvisor/repository/exam_record_repository.py index ebb4c1e47..c8e790fb4 100644 --- a/omniadvisor/src/omniadvisor/repository/exam_record_repository.py +++ b/omniadvisor/src/omniadvisor/repository/exam_record_repository.py @@ -1,5 +1,5 @@ from datetime import datetime - +from typing import List from common.constant import OA_CONF from server.app.models import ( DatabaseTuningRecord, @@ -33,12 +33,13 @@ class ExamRecordRepository(Repository): ExamRecord.FieldName.end_time: ([datetime], [''], []), ExamRecord.FieldName.status: ([str], [], OA_CONF.ExecStatus.all), ExamRecord.FieldName.runtime: ([float], [], []), - ExamRecord.FieldName.app_id: ([str], [], []), + ExamRecord.FieldName.application_id: ([str], [], []), ExamRecord.FieldName.trace: ([dict], [], []), } @classmethod - def create(cls, load: Load, config: dict, status: str, runtime: float, start_time: datetime, end_time: datetime, app_id: str): + def create(cls, load: Load, config: dict, status: str, runtime: float, + start_time: datetime, end_time: datetime, application_id: str) -> ExamRecord: """ 指定exam_record属性,新增测试记录 @@ -48,7 +49,7 @@ class ExamRecordRepository(Repository): :param runtime: 执行用时 :param start_time: 任务开始时间 :param end_time: 任务结束时间 - :param app_id: Spark任务application_id + :param application_id: Spark任务application_id :return: 测试记录实例 """ tuning_records = TuningRecordRepository.query_by_load_and_config(load=load, config=config) @@ -62,15 +63,15 @@ class ExamRecordRepository(Repository): ExamRecord.FieldName.runtime: runtime, ExamRecord.FieldName.start_time: start_time, ExamRecord.FieldName.end_time: end_time, - ExamRecord.FieldName.app_id: app_id + ExamRecord.FieldName.application_id: application_id } database_record = cls._create(model_attr=model_attr) return ExamRecord(database_model=database_record) @classmethod - def query_by_tuning_record(cls, tuning_record: TuningRecord): + def query_by_tuning_record(cls, tuning_record: TuningRecord) -> List[ExamRecord]: """ - 指定调优记录,查询测试记录 + 指定调优记录,查询与该调优记录所关联的全部测试记录,并以列表的形式返回 :param tuning_record: 调优记录实例 :return: 测试记录实例列表 @@ -84,9 +85,9 @@ class ExamRecordRepository(Repository): ] @classmethod - def query_by_load_and_config(cls, load: Load, config: dict): + def query_by_load_and_config(cls, load: Load, config: dict) -> List[ExamRecord]: """ - 指定负载、配置,查询测试记录 + 指定(负载,配置),查询与该(负载,配置)关联的全部测试记录,并以列表形式返回 :param load: 负载 :param config: 参数配置 @@ -100,7 +101,7 @@ class ExamRecordRepository(Repository): return cls.query_by_tuning_record(tuning_record=tuning_record) @classmethod - def update_trace(cls, exam_record: ExamRecord, trace: dict): + def update_trace(cls, exam_record: ExamRecord, trace: dict) -> ExamRecord: """ 更新测试记录中的trace信息 @@ -119,7 +120,7 @@ class ExamRecordRepository(Repository): return ExamRecord(database_model=database_task) @classmethod - def update_runtime(cls, exam_record: ExamRecord, runtime: float): + def update_runtime(cls, exam_record: ExamRecord, runtime: float) -> ExamRecord: """ 更新从jobs中获取的runtime信息 @@ -138,11 +139,12 @@ class ExamRecordRepository(Repository): return ExamRecord(database_model=database_task) @classmethod - def delete(cls, exam_record: ExamRecord): + def delete(cls, exam_record: ExamRecord) -> None: """ - 更新测试记录中的结束时间为此刻 + 从数据库删除指定的测试记录 :param exam_record: 测试记录实例 - :return: 更新完后的测试记录 + :return: """ + # TODO 这个delete动作返回什么东西了?我甚至找不到这个delete函数 return exam_record.database_model.delete() diff --git a/omniadvisor/src/omniadvisor/repository/load_prefetch_repository.py b/omniadvisor/src/omniadvisor/repository/load_prefetch_repository.py index 1e616274a..4362ec724 100644 --- a/omniadvisor/src/omniadvisor/repository/load_prefetch_repository.py +++ b/omniadvisor/src/omniadvisor/repository/load_prefetch_repository.py @@ -10,11 +10,16 @@ from omniadvisor.repository.repository import Repository class LoadPrefetchRepository(Repository): + """ + #TODO 这个类居然五月份就有了吗?能不能介绍一下这是个什么类。 + #TODO 为什么这个类只有这一个函数 这是干啥用的 + """ @classmethod def exclude_query_load_with_tuning_records(cls, exclude_load: Load): """ 获取 exclude_load 的 id 以外的 load 及 tuning record的聚合 + :param exclude_load: 排除的load id :return: """ diff --git a/omniadvisor/src/omniadvisor/repository/load_repository.py b/omniadvisor/src/omniadvisor/repository/load_repository.py index 28448fd5c..64e0e1d61 100644 --- a/omniadvisor/src/omniadvisor/repository/load_repository.py +++ b/omniadvisor/src/omniadvisor/repository/load_repository.py @@ -33,7 +33,7 @@ class LoadRepository(Repository): } @classmethod - def create(cls, name: str, exec_attr: dict, default_config: dict, hash_value: str): + def create(cls, name: str, exec_attr: dict, default_config: dict, hash_value: str) -> Load: """ 指定名称、执行属性和默认配置,新增负载 @@ -58,7 +58,7 @@ class LoadRepository(Repository): return Load(database_model=database_load) @classmethod - def query_by_id(cls, load_id: str) -> list: + def query_by_id(cls, load_id: str) -> list[Load]: """ 根据 load_id 查询 load :param load_id: 负载id @@ -74,7 +74,7 @@ class LoadRepository(Repository): ] @classmethod - def query_by_hash_value(cls, hash_value: str): + def query_by_hash_value(cls, hash_value: str) -> list[Load]: """ 根据hash值,查找负载实例 :param hash_value: 负载实例 @@ -89,7 +89,7 @@ class LoadRepository(Repository): ] @classmethod - def update_name_and_exec_attr(cls, load: Load, name: str, exec_attr: dict): + def update_name_and_exec_attr(cls, load: Load, name: str, exec_attr: dict) -> Load: """ 更新和时间相关的信息 :param load: 负载 @@ -105,7 +105,7 @@ class LoadRepository(Repository): return Load(database_model=database_load) @classmethod - def update_best_config(cls, load: Load, best_config: dict): + def update_best_config(cls, load: Load, best_config: dict) -> Load: """ 更新负载的最优配置 @@ -123,7 +123,7 @@ class LoadRepository(Repository): return Load(database_model=database_load) @classmethod - def update_test_config(cls, load: Load, test_config: dict): + def update_test_config(cls, load: Load, test_config: dict) -> Load: """ 更新负载的测试配置 @@ -141,7 +141,7 @@ class LoadRepository(Repository): return Load(database_model=database_load) @classmethod - def update_tuning_needed(cls, load: Load, tuning_needed: bool): + def update_tuning_needed(cls, load: Load, tuning_needed: bool) -> Load: """ 更新当前负载是否需要调优这一信息 :param load: 负载实例 diff --git a/omniadvisor/src/omniadvisor/repository/model/exam_record.py b/omniadvisor/src/omniadvisor/repository/model/exam_record.py index 26c2a747f..d260126c5 100644 --- a/omniadvisor/src/omniadvisor/repository/model/exam_record.py +++ b/omniadvisor/src/omniadvisor/repository/model/exam_record.py @@ -49,8 +49,8 @@ class ExamRecord: return self._database_model.runtime @property - def app_id(self): - return self._database_model.app_id + def application_id(self): + return self._database_model.application_id @property def trace(self): @@ -67,5 +67,5 @@ class ExamRecord: end_time = 'end_time' status = 'status' runtime = 'runtime' - app_id = 'app_id' + application_id = 'application_id' trace = 'trace' diff --git a/omniadvisor/src/omniadvisor/repository/repository.py b/omniadvisor/src/omniadvisor/repository/repository.py index 8b118dfec..aed2f18bf 100644 --- a/omniadvisor/src/omniadvisor/repository/repository.py +++ b/omniadvisor/src/omniadvisor/repository/repository.py @@ -1,7 +1,7 @@ from abc import ABC from abc import abstractmethod - -from django.db.models import Model +from typing import Any +from django.db.models import Model, QuerySet class Repository(ABC): @@ -42,33 +42,51 @@ class Repository(ABC): pass @classmethod - def _check_frozen_field(cls, model_attr: dict): + def _check_frozen_field(cls, model_attr: dict[str, Any]) -> None: """ - 检查是否存在不允许修改的字段 + 检查是否存在不允许修改的字段。 - :param model_attr: Model属性字典 + 遍历类定义的 ``_frozen_fields`` 列表,如果在 ``model_attr`` 中发现这些字段, + 则说明用户试图修改被冻结的字段,方法会抛出 ``ValueError`` 异常。 + + :param model_attr: 模型属性字典,key 为字段名,value 为字段值。 + :raises ValueError: 当检测到存在不允许修改的字段时抛出。 """ for field in cls._frozen_fields: if field in model_attr.keys(): raise ValueError(f'The {cls._model_class.__name__} field {field} can not be changed.') @classmethod - def _check_required_field(cls, model_attr: dict): + def _check_required_field(cls, model_attr: dict[str, Any]) -> None: """ - 检查是否缺少必须填写的字段 + 检查是否缺少必须填写的字段。 - :param model_attr: Model属性字典 + 遍历类定义的 ``_required_fields`` 列表,如果在 ``model_attr`` 中缺少这些字段, + 则说明必填字段未提供,方法会抛出 ``ValueError`` 异常。 + + :param model_attr: 模型属性字典,key 为字段名,value 为字段值。 + :raises ValueError: 当检测到缺少必填字段时抛出。 """ for field in cls._required_fields: if field not in model_attr.keys(): raise ValueError(f'The {cls._model_class.__name__} field {field} is required.') @classmethod - def _check_field_format(cls, model_attr: dict): + def _check_field_format(cls, model_attr: dict[str, Any]) -> None: """ - 检查字段格式是否合理 + 检查字段的类型和值是否符合定义的格式约束。 - :param model_attr: Model属性字典 + 遍历类定义的 ``_fields_format`` 配置,对传入的 ``model_attr`` 中已定义的字段进行以下检查: + + 1. 类型检查:字段值的类型必须在允许类型集合 ``field_types`` 中。 + 2. 禁止值检查:字段值不能在 ``field_forbidden`` 列表中。 + 3. 允许值检查:如果 ``field_allow`` 非空,则字段值必须在该列表中。 + + :param model_attr: 模型属性字典,key 为字段名,value 为字段值。 + :raises TypeError: 当字段类型不在允许类型集合中时抛出。 + :raises ValueError: + - 当字段值在禁止值列表中时抛出。 + - 当字段值不在允许值列表(且允许值列表非空)中时抛出。 """ for field, field_format in cls._fields_format.items(): if field not in model_attr.keys(): @@ -83,7 +101,7 @@ class Repository(ABC): raise ValueError(f'Value of {cls._model_class.__name__} field {field} must be in {field_allow}.') @classmethod - def _create(cls, model_attr: dict): + def _create(cls, model_attr: dict) -> Model: """ 添加Model模型至数据库 @@ -98,7 +116,7 @@ class Repository(ABC): return database_model @classmethod - def _query(cls, model_attr: dict): + def _query(cls, model_attr: dict) -> QuerySet[Model]: """ 根据Model属性字典,搜索符合要求的实例 @@ -111,7 +129,7 @@ class Repository(ABC): return database_models @classmethod - def _update(cls, database_model: Model, model_attr: dict): + def _update(cls, database_model: Model, model_attr: dict) -> Model: """ 更新Model模型至数据库 diff --git a/omniadvisor/src/omniadvisor/repository/tuning_record_repository.py b/omniadvisor/src/omniadvisor/repository/tuning_record_repository.py index d612b9a79..2e4952854 100644 --- a/omniadvisor/src/omniadvisor/repository/tuning_record_repository.py +++ b/omniadvisor/src/omniadvisor/repository/tuning_record_repository.py @@ -1,3 +1,4 @@ +from typing import List from common.constant import OA_CONF from server.app.models import DatabaseLoad, DatabaseTuningRecord @@ -29,7 +30,7 @@ class TuningRecordRepository(Repository): } @classmethod - def create(cls, load: Load, config: dict, method: str, method_extend: str = ''): + def create(cls, load: Load, config: dict, method: str, method_extend: str = '') -> TuningRecord: """ 指定负载、参数配置和调优方法,新增调优记录 @@ -68,12 +69,12 @@ class TuningRecordRepository(Repository): return TuningRecord(database_model=database_record) @classmethod - def query_by_load(cls, load: Load): + def query_by_load(cls, load: Load) -> List[TuningRecord]: """ - 指定负载,查询调优记录 + 指定负载,查询与该负载关联的全部调优记录 :param load: 负载实例 - :return: 调优记录实例 + :return: 调优记录实例列表 """ model_attr = { TuningRecord.FieldName.load: load.database_model, @@ -84,13 +85,13 @@ class TuningRecordRepository(Repository): ] @classmethod - def query_by_load_and_config(cls, load: Load, config: dict): + def query_by_load_and_config(cls, load: Load, config: dict) -> List[TuningRecord]: """ - 指定负载、参数配置,查询调优记录 + 指定负载、参数配置,查询唯一匹配的调优记录 :param load: 负载实例 :param config: 参数配置 - :return: 调优记录实例 + :return: 调优记录实例列表(实际上只有一个实例,返回列表是为了维护接口的数据格式统一) """ model_attr = { TuningRecord.FieldName.load: load.database_model, @@ -108,10 +109,10 @@ class TuningRecordRepository(Repository): ] @classmethod - def delete(cls, tuning_record: TuningRecord): + def delete(cls, tuning_record: TuningRecord) -> None: """ 删除 tuning_record 记录 - :param tuning_record: + :param tuning_record:调优记录实例 :return: """ return tuning_record.database_model.delete() diff --git a/omniadvisor/src/omniadvisor/service/retest_service.py b/omniadvisor/src/omniadvisor/service/retest_service.py index c06bbb642..202ce53bb 100644 --- a/omniadvisor/src/omniadvisor/service/retest_service.py +++ b/omniadvisor/src/omniadvisor/service/retest_service.py @@ -1,3 +1,4 @@ +from typing import Any from common.constant import OA_CONF from omniadvisor.repository.model.load import Load from omniadvisor.service.spark_service.spark_run import spark_run @@ -6,13 +7,19 @@ from omniadvisor.utils.logger import global_logger from omniadvisor.utils.utils import float_format -def retest(load: Load, config: dict): +def retest(load: Load, config: dict[str, Any]) -> None: """ - 根据复测结果,更新最优配置 + 根据复测结果更新最优配置。 - :param load: 负载 - :param config: 配置 - :return: + 该方法会在同一配置下执行多轮 Spark 任务复测(次数由 OA_CONF.tuning_retest_times控制), + 用于验证配置的稳定性与性能表现: + - 如果某轮任务成功执行(状态为 OA_CONF.ExecStatus.success),会记录该轮性能结果。 + - 如果出现非 Spark 异常,会立即抛出并中止复测。 + - 如果出现 Spark 异常且调优结果状态为 OA_CONF.TuningResultStatus.fail,则提前结束复测。 + + :param load: 待复测的负载实例。 + :param config: Spark 执行配置(JSON 风格字典)。 + :raises Exception: 当任务执行过程中出现非 Spark 异常时抛出。 """ global_logger.debug('Starting retest config...') for i in range(1, OA_CONF.tuning_retest_times + 1): diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py index f778d2f2d..0b69b3642 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_cmd_parser.py @@ -4,7 +4,7 @@ # Copyright (c) Huawei Technologies Co, Ltd. 2023-2023. All rights reserved. import argparse -from typing import List +from typing import List, Any from omniadvisor.utils.logger import global_logger @@ -88,13 +88,14 @@ class SparkCMDParser: _parser.add_argument('-d', '--database') @classmethod - def validate_submit_arguments(cls, args: List): + def validate_submit_arguments(cls, args: List) -> None: """ 检查当前提交的spark-submit命令是否属于SUBMIT动作的spark命令 该函数与Spark 3.3.1源码中SparkSubmitArguments.scala文件中的validateSubmitArguments逻辑保持一致 :param args: 从spark-submit文件中劫持的入参列表 - :return: 属于SUBMIT动作返回True 不属于SUBMIT动作返回False + :raises ValueError: 当未提供任何参数时抛出(args 为空)。 + :raises TypeError: 当参数不属于 SUBMIT 动作时抛出。 """ if len(args) == 0: raise ValueError("No arguments were provided. At least one argument is required.") @@ -107,9 +108,10 @@ class SparkCMDParser: raise TypeError("This is not a SUBMIT type spark task.") @staticmethod - def _normalize_value(value): + def _normalize_value(value: Any) -> Any: """ 对spark-sql命令解析出来的value做标准化处理 + (该函数虽然暂时不对value做任何处理 但考虑到整体代码结构 保留在此) :param value: 原始命令解析出的value :return: 标准化的value @@ -117,12 +119,12 @@ class SparkCMDParser: return value @staticmethod - def _append_double_dash_args(cls, key, value): + def _append_double_dash_args(cls, key, value) -> List: """ 用于处理spark命令中以 "--" 作为前缀的参数 - :param cls: - :param key: - :param value: + + :param key: 参数名 + :param value: 参数值 :return: """ if key in _BOOLEAN_TYPE_KEYS: @@ -134,7 +136,7 @@ class SparkCMDParser: return [f'--{key}', cls._normalize_value(value)] @classmethod - def parse_cmd(cls, argv: list): + def parse_cmd(cls, argv: list) -> dict: """ 解析提交后的命令,得到解析的参数后,以字典的形式保存 返回一个包含所有提取出的基础参数的字典,每个键代表一个参数名,详见add_argument中的配置的参数。 diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py index 2bcb584f6..c20ea64a8 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_executor.py @@ -1,4 +1,5 @@ import re +from typing import Tuple from dataclasses import dataclass from datetime import datetime @@ -11,7 +12,7 @@ from omniadvisor.utils.logger import global_logger class SparkExecResult: exitcode: int output: str - app_id: str + application_id: str duration: float start_time: datetime end_time: datetime @@ -23,13 +24,13 @@ class SparkExecutor: """ @classmethod - def submit_spark_task(cls, cmd_fields: list, timeout: int): + def submit_spark_task(cls, cmd_fields: list, timeout: int) -> SparkExecResult: """ - 在shell终端提交spark命令,并在命令成功执行并返回时解析结果中的数据 + 在shell终端提交Spark任务,在Spark任务执行正常退出时解析结果中的数据,并使用SparkExecResult类返回解析后的结果 :param cmd_fields: Spark的提交命令 :param timeout: 超时时间 - :return: + :return: SparkExecResult """ if timeout: # 命令列表需全为str类型 @@ -48,7 +49,7 @@ class SparkExecutor: return SparkExecResult( exitcode=exitcode, output=output, - app_id=application_id, + application_id=application_id, duration=total_time_taken, start_time=start_time, end_time=end_time @@ -57,20 +58,21 @@ class SparkExecutor: return SparkExecResult( exitcode=exitcode, output=output, - app_id='', + application_id='', duration=OA_CONF.exec_fail_return_runtime, start_time=start_time, end_time=end_time ) @classmethod - def _parser_spark_output(cls, spark_output: str): + def _parser_spark_output(cls, spark_output: str) -> Tuple[str, float]: """ 解析Spark输出,获得spark提交命令、Application ID和Time taken(若存在) (注: Time taken后续不再作为性能指标被使用 目前仅作为性能参考获取系其数值) :param spark_output: Spark执行输出 - :return: + :raises RuntimeError: 当无法从输出中匹配到 Application ID 时抛出。 + :return: application_id, total_time_taken """ # 解析Application id: # 该条匹配模式的文本来自Spark3.3.1源码中yarn/Client.scala文件在Line224所打印的日志 diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py index 7a544dceb..a154b94a0 100755 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py @@ -2,12 +2,13 @@ import re import json import requests from dateutil import parser +from typing import Any from omniadvisor.utils.logger import global_logger class SparkFetcher: - def __init__(self, history_server_url, username, password): + def __init__(self, history_server_url: str, username: str, password: str): """ 初始化SparkFetcher类 @@ -17,13 +18,18 @@ class SparkFetcher: self.username = username self.password = password - def _make_request(self, endpoint): + def _make_request(self, endpoint: str) -> dict[str, Any]: """ - 发送GET请求并处理响应。 + 发送 GET 请求到 History Server,并返回解析后的 JSON 数据。 + + 该方法会根据配置决定是否使用用户名和密码进行 HTTP 基本认证。 + 如果请求返回错误状态码(非 2xx),会直接抛出异常。 + 响应数据会解析为 JSON 格式返回,如果解析失败,则抛出 ValueError :param endpoint: API端点路径 :return: 解析后的JSON响应数据 - :raises: requests.exceptions.HTTPError 如果HTTP请求返回了错误状态码 + :raises requests.exceptions.HTTPError:当 HTTP 响应状态码不是 200~209 时抛出。 + :raises ValueError:当响应内容无法解析为 JSON 时抛出。 """ endpoint = endpoint.strip("/") url = f"{self.history_server_url}/{endpoint}" @@ -41,7 +47,7 @@ class SparkFetcher: raise ValueError('Something wrong in trace fetched, can not decode into Json data.') from e return json_data - def get_spark_apps(self): + def get_spark_apps(self) -> dict[str, Any]: """ 获取所有Spark应用的信息。 @@ -49,62 +55,62 @@ class SparkFetcher: """ return self._make_request("api/v1/applications") - def get_application_details(self, app_id): + def get_application_details(self, application_id: str) -> dict[str, Any]: """ 根据应用ID获取指定Spark应用的详细信息。 - :param app_id: 应用ID + :param application_id: 应用ID :return: 返回指定应用的详细信息 """ - return self._make_request(f"api/v1/applications/{app_id}") + return self._make_request(f"api/v1/applications/{application_id}") - def get_spark_sql_by_app(self, app_id): + def get_spark_sql_by_app(self, application_id: str) -> dict[str, Any]: """ 根据应用ID获取指定Spark应用的SQL信息。 - :param app_id: 应用ID + :param application_id: 应用ID :return: 返回指定应用的SQL信息 """ - return self._make_request(f"api/v1/applications/{app_id}/sql") + return self._make_request(f"api/v1/applications/{application_id}/sql") - def get_spark_stages_by_app(self, app_id): + def get_spark_stages_by_app(self, application_id: str) -> dict[str, Any]: """ 根据应用ID获取指定Spark应用的阶段信息。 - :param app_id: 应用ID + :param application_id: 应用ID :return: 返回指定应用的阶段信息 """ - return self._make_request(f"api/v1/applications/{app_id}/stages?withSummaries=true") + return self._make_request(f"api/v1/applications/{application_id}/stages?withSummaries=true") - def get_spark_executor_by_app(self, app_id): + def get_spark_executor_by_app(self, application_id: str) -> dict[str, Any]: """ 根据应用ID获取指定Spark应用的执行器信息。 - :param app_id: 应用ID + :param application_id: 应用ID :return: 返回指定应用的执行器信息 """ - return self._make_request(f"api/v1/applications/{app_id}/executors") + return self._make_request(f"api/v1/applications/{application_id}/executors") - def get_spark_jobs_by_app(self, app_id): + def get_spark_jobs_by_app(self, application_id: str) -> dict[str, Any]: """ 根据应用ID获取指定Spark应用的Jobs信息 - :param app_id: 应用ID + :param application_id: 应用ID :return: 返回指定应用的Jobs信息 """ - return self._make_request(f"api/v1/applications/{app_id}/jobs") + return self._make_request(f"api/v1/applications/{application_id}/jobs") - def get_spark_runtime_by_app(self, app_id) -> float: + def get_spark_runtime_by_app(self, application_id) -> float: """ 通过从HistoryServer上获取的jobs信息,计算任务的执行耗时 runtime的计算的耗时为 最早的Job提交的时间 与 最晚的Job完成时间 之间的差值(注:最晚完成的Job并不一定是最后提交的Job) :return: runtime (seconds) """ - jobs_detail = self.get_spark_jobs_by_app(app_id) + jobs_detail = self.get_spark_jobs_by_app(application_id) if not jobs_detail: - global_logger.error(f"No job info returned for app_id={app_id}") - raise ValueError(f"No job information found for app_id={app_id}") + global_logger.error(f"No job info returned for application_id={application_id}") + raise ValueError(f"No job information found for application_id={application_id}") # 获取全部的时间戳信息 timestamp_list = [] diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index ded0683b9..f258adb83 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -1,7 +1,7 @@ import multiprocessing import time from requests.exceptions import HTTPError - +from typing import Tuple from common.constant import OA_CONF from omniadvisor.repository.model.load import Load from omniadvisor.repository.exam_record_repository import ExamRecordRepository @@ -24,7 +24,7 @@ _RETURN_CODE_MAP = { } -def spark_run(load: Load, config: dict, wait_for_trace: bool = False): +def spark_run(load: Load, config: dict, wait_for_trace: bool = False) -> Tuple[ExamRecord, str]: """ 输入负载与配置,执行Spark任务 并从Spark命令的返回值中获取 生成一条记录本次执行信息的exam_record @@ -33,7 +33,7 @@ def spark_run(load: Load, config: dict, wait_for_trace: bool = False): :param load: 负载 :param config: 参数配置 :param wait_for_trace: 是否阻塞等待获取trace - :return: + :return:测试记录实例, Spark任务执行的返回结果 """ # 从解析后的参数列表中提取负载与任务的相关信息 submit_cmd = SparkCMDParser.reconstruct_cmd(exec_attr=load.exec_attr, conf_params=config) @@ -51,13 +51,13 @@ def spark_run(load: Load, config: dict, wait_for_trace: bool = False): global_logger.info('Spark Load %d execute success', load.id) exam_record_status = OA_CONF.ExecStatus.success exam_record_runtime = exec_result.duration - exam_record_app_id = exec_result.app_id + exam_record_application_id = exec_result.application_id else: global_logger.warning('Spark Load %d execute failed, return code: %d, error code describe: %s', load.id, exec_result.exitcode, _RETURN_CODE_MAP.get(exec_result.exitcode, 'Abnormal exit code')) exam_record_status = OA_CONF.ExecStatus.fail exam_record_runtime = OA_CONF.exec_fail_return_runtime - exam_record_app_id = OA_CONF.exec_fail_return_app_id + exam_record_application_id = OA_CONF.exec_fail_return_application_id # 构建测试记录 exam_record = ExamRecordRepository.create( @@ -67,7 +67,7 @@ def spark_run(load: Load, config: dict, wait_for_trace: bool = False): runtime=exam_record_runtime, start_time=exec_result.start_time, end_time=exec_result.end_time, - app_id=exam_record_app_id + application_id=exam_record_application_id ) if exec_result.exitcode == 0: @@ -75,32 +75,32 @@ def spark_run(load: Load, config: dict, wait_for_trace: bool = False): if wait_for_trace: # 阻塞获取trace global_logger.info('Going to fetch Spark execute trace, the process is blocking.') - _update_trace_and_runtime_from_history_server(exam_record=exam_record, application_id=exec_result.app_id) + _update_trace_and_runtime_from_history_server(exam_record=exam_record, application_id=exec_result.application_id) else: # 不阻塞获取trace,通过子进程进行获取 global_logger.info('Going to fetch Spark execute trace, the process is non-blocking.') - p = multiprocessing.Process(target=_update_trace_and_runtime_from_history_server, args=(exam_record, exec_result.app_id)) + p = multiprocessing.Process(target=_update_trace_and_runtime_from_history_server, args=(exam_record, exec_result.application_id)) p.daemon = False p.start() return exam_record, exec_result.output -def _calc_timeout_from_load(load: Load): +def _calc_timeout_from_load(load: Load) -> int: """ - 针对负载和配置,计算得到超时时间 + 根据负载的基线执行时间计算该负载在复测时的超时时间 :param load: 负载 - :return: + :return: 超时时间 """ # 超时时间为基线执行用时的倍数 baseline_result = get_tuning_result(load=load, config=load.default_config) return int(OA_CONF.spark_exec_timeout_ratio * baseline_result.runtime) -def _update_trace_and_runtime_from_history_server(exam_record: ExamRecord, application_id: str): +def _update_trace_and_runtime_from_history_server(exam_record: ExamRecord, application_id: str) -> None: """ - 根据application_id对history_server进行轮询, 查询该任务的trace信息和runtime信息 + 根据application_id对history_server进行轮询, 查询该任务的trace信息和runtime信息,并刷新ExamRecord中相关字段的信息 :param application_id: Spark任务application_id :return: diff --git a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py index 43e656acf..f46d7e2fd 100644 --- a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py +++ b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py @@ -12,6 +12,7 @@ from omniadvisor.repository.tuning_record_repository import TuningRecordReposito from omniadvisor.repository.exam_record_repository import ExamRecordRepository +# TODO 为什么这两个函数单独放在类外面 有点奇怪啊 def get_tuning_result(load: Load, config: dict): """ 指定负载、参数配置,获取调优结果 @@ -39,9 +40,13 @@ def get_tuning_result(load: Load, config: dict): ) -def remove_tuning_result(load: Load, config: dict): +def remove_tuning_result(load: Load, config: dict) -> None: """ - 删除给定负载和配置的执行记录 + 删除给定负载和配置所对应的全部执行记录 + + :param load: 负载实例 + :param config: 负载配置 + :return: """ tuning_result = get_tuning_result(load, config) # 清测试记录 diff --git a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result_history.py b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result_history.py index bb51a2ec6..3ea6b631e 100644 --- a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result_history.py +++ b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result_history.py @@ -10,9 +10,10 @@ from omniadvisor.service.tuning_result.tuning_result import TuningResult from omniadvisor.service.tuning_result.tuning_result import get_tuning_result -def get_other_tuning_result_history(exclude_load: Load): +def get_other_tuning_result_history(exclude_load: Load) -> List: """ 查询 exclude_load 以外的所有 tuning result history + :param exclude_load :return: """ @@ -39,7 +40,7 @@ def get_other_tuning_result_history(exclude_load: Load): return result - +# TODO 这个函数怎么也单独在类外部 def get_tuning_result_history(load: Load): """ 指定负载,获取调优结果所有历史 diff --git a/omniadvisor/src/server/app/admin.py b/omniadvisor/src/server/app/admin.py index 68ce07a33..e27e4ca69 100644 --- a/omniadvisor/src/server/app/admin.py +++ b/omniadvisor/src/server/app/admin.py @@ -19,5 +19,5 @@ class TuningRecordAdmin(admin.ModelAdmin): @admin.register(DatabaseExamRecord) class ExamRecordAdmin(admin.ModelAdmin): - list_display = ('id', 'tuning_record', 'start_time', 'end_time', 'status', 'runtime', 'app_id', 'trace') + list_display = ('id', 'tuning_record', 'start_time', 'end_time', 'status', 'runtime', 'application_id', 'trace') list_filter = () diff --git a/omniadvisor/src/server/app/models.py b/omniadvisor/src/server/app/models.py index 49f59be28..7b12f52d3 100644 --- a/omniadvisor/src/server/app/models.py +++ b/omniadvisor/src/server/app/models.py @@ -75,7 +75,7 @@ class DatabaseExamRecord(models.Model): default=OA_CONF.ExecStatus.running ) runtime = models.FloatField(null=True) - app_id = models.CharField(max_length=50, null=True) + application_id = models.CharField(max_length=50, null=True) trace = models.JSONField(null=True) class Meta: diff --git a/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py b/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py index fd4000d41..161dbf4e3 100644 --- a/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py @@ -51,7 +51,7 @@ class TestExamRecordRepository: DatabaseLoad.objects.all().delete() # 构建测试数据 - hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, user_config={"param1": "value2"}) + hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, config={"param1": "value2"}) self.load = LoadRepository.create( name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, @@ -68,7 +68,7 @@ class TestExamRecordRepository: self.runtime = 20.4 self.start_time = datetime(2023, 10, 5, 14, 30, 15) self.end_time = datetime(2024, 10, 5, 14, 30, 15) - self.app_id = 'application_123456' + self.application_id = 'application_123456' # 创建task self.exam_record = ExamRecordRepository.create( load=self.load, @@ -77,7 +77,7 @@ class TestExamRecordRepository: runtime=self.runtime, start_time=self.start_time, end_time=self.end_time, - app_id=self.app_id + application_id=self.application_id ) def test_create(self): @@ -93,7 +93,7 @@ class TestExamRecordRepository: assert self.exam_record.runtime == self.runtime assert self.exam_record.start_time == self.start_time assert self.exam_record.end_time == self.end_time - assert self.exam_record.app_id == self.app_id + assert self.exam_record.application_id == self.application_id # 测试没有tuning_record场景 with pytest.raises(RuntimeError): @@ -104,7 +104,7 @@ class TestExamRecordRepository: runtime=self.runtime, start_time=self.start_time, end_time=self.end_time, - app_id=self.app_id + application_id=self.application_id ) def test_query_by_tuning_record(self): diff --git a/omniadvisor/tests/omniadvisor/repository/test_load_repository.py b/omniadvisor/tests/omniadvisor/repository/test_load_repository.py index 745cc1e68..c00891849 100644 --- a/omniadvisor/tests/omniadvisor/repository/test_load_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_load_repository.py @@ -52,7 +52,7 @@ class TestLoadRepository: def test_query_by_hash_value(self): # 创建测试数据 - hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, user_config={'param1': 'value2'}) + hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, config={'param1': 'value2'}) LoadRepository.create( name='example_name', @@ -67,13 +67,13 @@ class TestLoadRepository: assert loads[0].exec_attr == {'name': 'example_name', 'cpu': 4} # 查询失败 - wrong_hash_value = _calculate_hash_value(exec_attr={'name': 'wrong_name'}, user_config={'param1': 'value2'}) + wrong_hash_value = _calculate_hash_value(exec_attr={'name': 'wrong_name'}, config={'param1': 'value2'}) loads = LoadRepository.query_by_hash_value(wrong_hash_value) assert len(loads) == 0 def test_update_best_config(self): # 创建测试数据 - hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, user_config={'param1': 'value2'}) + hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, config={'param1': 'value2'}) load = LoadRepository.create( name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, @@ -90,7 +90,7 @@ class TestLoadRepository: def test_update_test_config(self): # 创建测试数据 - hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, user_config={'param1': 'value2'}) + hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, config={'param1': 'value2'}) load = LoadRepository.create( name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, @@ -107,7 +107,7 @@ class TestLoadRepository: def test_update_tuning_needed(self): # 创建测试数据 - hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, user_config={'param1': 'value2'}) + hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, config={'param1': 'value2'}) load = LoadRepository.create( name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, diff --git a/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py b/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py index 9283d8e3c..ff2e019ef 100644 --- a/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py @@ -33,7 +33,7 @@ class TestTaskRepository: def test_create(self): # 创建测试数据 - hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, user_config={"param1": "value2"}) + hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, config={"param1": "value2"}) load = LoadRepository.create( name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, @@ -66,7 +66,7 @@ class TestTaskRepository: def test_query_by_load(self): # 创建测试数据 - hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, user_config={"param1": "value2"}) + hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, config={"param1": "value2"}) load = LoadRepository.create( name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, @@ -89,7 +89,7 @@ class TestTaskRepository: def test_query_by_load_and_config(self): # 创建测试数据 - hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, user_config={"param1": "value2"}) + hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, config={"param1": "value2"}) load = LoadRepository.create( name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py index c9ff0f041..b4a1eea84 100644 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_executor.py @@ -44,7 +44,7 @@ class TestSparkExecutor: mock_run_cmd.assert_called_once_with(cmd_fields=['timeout', str(self.timeout)] + self.cmd_fields) assert exec_result.exitcode == self.exitcode assert exec_result.output == self.output - assert exec_result.app_id == self.application_id + assert exec_result.application_id == self.application_id assert exec_result.duration == OA_CONF.exec_fail_return_runtime assert type(exec_result.start_time) is datetime assert type(exec_result.end_time) is datetime @@ -54,9 +54,9 @@ class TestSparkExecutor: 测试 parser_spark_output 方法 """ # 正确解析spark output - app_id, total_time_taken = self.spark_executor._parser_spark_output(self.output) + application_id, total_time_taken = self.spark_executor._parser_spark_output(self.output) # 验证结果 - assert app_id == self.application_id + assert application_id == self.application_id assert total_time_taken == OA_CONF.exec_fail_return_runtime # 解析spark output缺失application id diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py index e897e3ca9..fb1282091 100755 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_fetcher.py @@ -175,7 +175,7 @@ class TestSparkFetcher(unittest.TestCase): # 验证是否抛出异常 with self.assertRaises(ValueError) as context: spark_fetcher.get_spark_runtime_by_app("app-1") - self.assertIn("No job information found for app_id=app-1", str(context.exception)) + self.assertIn("No job information found for application_id=app-1", str(context.exception)) @mock.patch('requests.get') def test_get_spark_runtime_by_app_incomplete_jobs(self, mock_get): diff --git a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py index b983bb87b..21deb8434 100644 --- a/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py +++ b/omniadvisor/tests/omniadvisor/service/spark_service/test_spark_run.py @@ -32,7 +32,7 @@ class TestSparkRun: # 构建exec_result mock对象 self.mock_exec_result = MagicMock self.mock_exec_result.exitcode = 0 - self.mock_exec_result.app_id = 'application_123456' + self.mock_exec_result.application_id = 'application_123456' self.mock_exec_result.duration = 24.2 self.mock_exec_result.start_time = datetime(2023, 10, 5, 14, 30, 15) self.mock_exec_result.end_time = datetime(2023, 10, 5, 14, 30, 15) @@ -87,7 +87,7 @@ class TestSparkRun: runtime=self.mock_exec_result.duration, start_time=self.mock_exec_result.start_time, end_time=self.mock_exec_result.end_time, - app_id=self.mock_exec_result.app_id + application_id=self.mock_exec_result.application_id ) mock_multiprocess.assert_called_once() mock_process.start.assert_called_once() @@ -125,6 +125,6 @@ class TestSparkRun: runtime=OA_CONF.exec_fail_return_runtime, start_time=self.mock_exec_result.start_time, end_time=self.mock_exec_result.end_time, - app_id=OA_CONF.exec_fail_return_app_id + application_id=OA_CONF.exec_fail_return_application_id ) mock_multiprocess.assert_not_called() -- Gitee From 742c4ad00551bf3ab67721d870d9abf270ae37ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E8=89=BA=E4=B8=B9?= <53546877+Craven1701@users.noreply.github.com> Date: Tue, 12 Aug 2025 10:04:21 +0800 Subject: [PATCH 2/2] =?UTF-8?q?clean=20code=E4=BB=A5=E5=8F=8A=E4=B8=80?= =?UTF-8?q?=E4=BA=9BTODO=E4=BA=8B=E9=A1=B9=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- omniadvisor/src/omniadvisor/interface/config_tuning.py | 8 ++++---- omniadvisor/src/omniadvisor/interface/hijack_recommend.py | 2 +- .../src/omniadvisor/repository/exam_record_repository.py | 5 ++--- .../omniadvisor/repository/load_prefetch_repository.py | 3 +-- .../omniadvisor/service/spark_service/spark_fetcher.py | 5 +++-- .../src/omniadvisor/service/spark_service/spark_run.py | 4 +++- .../omniadvisor/service/tuning_result/tuning_result.py | 1 - .../service/tuning_result/tuning_result_history.py | 2 +- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/omniadvisor/src/omniadvisor/interface/config_tuning.py b/omniadvisor/src/omniadvisor/interface/config_tuning.py index 16ce88dcd..2d17d4c31 100644 --- a/omniadvisor/src/omniadvisor/interface/config_tuning.py +++ b/omniadvisor/src/omniadvisor/interface/config_tuning.py @@ -24,7 +24,7 @@ from omniadvisor.utils.logger import global_logger from omniadvisor.utils.utils import float_format -def handler(signum: int, frame) -> None: # TODO 这个frame为什么是no value 以及signum是什么类型 +def handler(signum: int, frame) -> None: """ 用于注册异常退出信号的处理函数。 @@ -117,7 +117,7 @@ def _single_tuning(load: Load, retest_way: str, tuning_method: str) -> None: LoadRepository.update_test_config(load, next_config) -def _get_next_config(load: Load, tuning_method: str) -> tuple[str, str]: # TODO 检查一下这个返回值类型 +def _get_next_config(load: Load, tuning_method: str) -> tuple[dict, str]: """ 对指定的负载通过指定的调优算法进行调优,并返回调优后的配置 @@ -127,7 +127,7 @@ def _get_next_config(load: Load, tuning_method: str) -> tuple[str, str]: # TODO :returns: 一个元组 (method_extend, next_config) - **method_extend**: 调优方法描述 - **next_config**: 优化后配置 - :rtype: tuple[str, str] + :rtype: tuple[dict, str] """ # 获取指定负载的调优历史记录 tuning_result_history = get_tuning_result_history(load) @@ -181,7 +181,7 @@ def _continuous_tuning_with_strategies(load: Load, tuning_strategies: list[tuple break -def _query_and_check_load(load_id: str) -> Optional[Load]: # TODO load_id用str类型感觉略微奇怪 +def _query_and_check_load(load_id: str) -> Optional[Load]: """ 根据load_id查询该负载是否存在于数据库中且需要调优 diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index 7b40cb255..76dfde80f 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -27,7 +27,7 @@ _EMPTY_STR = '' _UNKNOWN_STR = 'unknown' -def _get_exec_config_from_load(load: Load) -> dict[str, Any]: +def _get_exec_config_from_load(load: Load) -> dict[str, str]: """ 根据负载属性,按优先级规则获取即将执行的参数配置 diff --git a/omniadvisor/src/omniadvisor/repository/exam_record_repository.py b/omniadvisor/src/omniadvisor/repository/exam_record_repository.py index c8e790fb4..f03d2878e 100644 --- a/omniadvisor/src/omniadvisor/repository/exam_record_repository.py +++ b/omniadvisor/src/omniadvisor/repository/exam_record_repository.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List +from typing import List, Tuple from common.constant import OA_CONF from server.app.models import ( DatabaseTuningRecord, @@ -139,12 +139,11 @@ class ExamRecordRepository(Repository): return ExamRecord(database_model=database_task) @classmethod - def delete(cls, exam_record: ExamRecord) -> None: + def delete(cls, exam_record: ExamRecord) -> Tuple[int, dict]: """ 从数据库删除指定的测试记录 :param exam_record: 测试记录实例 :return: """ - # TODO 这个delete动作返回什么东西了?我甚至找不到这个delete函数 return exam_record.database_model.delete() diff --git a/omniadvisor/src/omniadvisor/repository/load_prefetch_repository.py b/omniadvisor/src/omniadvisor/repository/load_prefetch_repository.py index 4362ec724..12b4d472f 100644 --- a/omniadvisor/src/omniadvisor/repository/load_prefetch_repository.py +++ b/omniadvisor/src/omniadvisor/repository/load_prefetch_repository.py @@ -11,8 +11,7 @@ from omniadvisor.repository.repository import Repository class LoadPrefetchRepository(Repository): """ - #TODO 这个类居然五月份就有了吗?能不能介绍一下这是个什么类。 - #TODO 为什么这个类只有这一个函数 这是干啥用的 + 提前批量获取反向外键相关对象的数据 避免N+1的查询问题 """ @classmethod diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py index a154b94a0..6868c59b5 100755 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_fetcher.py @@ -1,8 +1,9 @@ -import re import json +import re +from typing import Any + import requests from dateutil import parser -from typing import Any from omniadvisor.utils.logger import global_logger diff --git a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py index 66b4fbbec..53bd47ae8 100644 --- a/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py +++ b/omniadvisor/src/omniadvisor/service/spark_service/spark_run.py @@ -1,7 +1,9 @@ import multiprocessing import time -from requests.exceptions import HTTPError from typing import Tuple + +from requests.exceptions import HTTPError + from common.constant import OA_CONF from omniadvisor.repository.model.load import Load from omniadvisor.repository.exam_record_repository import ExamRecordRepository diff --git a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py index f46d7e2fd..1b3475a54 100644 --- a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py +++ b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result.py @@ -12,7 +12,6 @@ from omniadvisor.repository.tuning_record_repository import TuningRecordReposito from omniadvisor.repository.exam_record_repository import ExamRecordRepository -# TODO 为什么这两个函数单独放在类外面 有点奇怪啊 def get_tuning_result(load: Load, config: dict): """ 指定负载、参数配置,获取调优结果 diff --git a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result_history.py b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result_history.py index 3ea6b631e..e28be2747 100644 --- a/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result_history.py +++ b/omniadvisor/src/omniadvisor/service/tuning_result/tuning_result_history.py @@ -40,7 +40,7 @@ def get_other_tuning_result_history(exclude_load: Load) -> List: return result -# TODO 这个函数怎么也单独在类外部 + def get_tuning_result_history(load: Load): """ 指定负载,获取调优结果所有历史 -- Gitee