diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index 8a75091fbde9d10761be2ab7a7f3274a53e56f2a..4bfa3646e61311675b421dfd16f9c792b0663aa7 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -131,8 +131,18 @@ class OmniAdvisorConf: r'\d{4}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])((0\d|1\d|2[0-3]))?' ] + # 所有日期-小时被替换后的变量: + date_replacement_value = '{date}' + # spark-sql命令所提交的类名 - SparkSQLCLIDriver = 'org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver' + spark_sql_cli_driver = 'org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver' + + # 所有涉及到数据库变动的sql关键字 + modify_keywords = [ + 'create', 'alter', 'insert', 'truncate', 'drop', 'repair', 'load', + # 以下这些关键字,在非 hadoop表中可能是合法的 + 'update', 'merge', 'delete', 'rewrite', 'upsert', 'replace' + ] OA_CONF = OmniAdvisorConf() diff --git a/omniadvisor/src/common/exceptions.py b/omniadvisor/src/common/exceptions.py index b23ef4c8241a8f1df19f0ed30255d68dcb5c363a..59310db44dc6af6550a0d994ea8f22fe9d886584 100644 --- a/omniadvisor/src/common/exceptions.py +++ b/omniadvisor/src/common/exceptions.py @@ -5,6 +5,13 @@ class NoOptimalConfigError(Exception): pass +class TuningPreconditionError(Exception): + """ + 未满足调优前置条件 + """ + pass + + class SystemKilledError(Exception): """ 异常,程序被异常中断,例如kill等 @@ -17,3 +24,10 @@ class DuplicateEntryError(Exception): 数据重复异常 """ pass + + +class UnknownEncodingError(Exception): + """ + 未知编码异常 + """ + pass diff --git a/omniadvisor/src/omniadvisor/interface/config_tuning.py b/omniadvisor/src/omniadvisor/interface/config_tuning.py index 0d4803924f518f2327820dbcfd26620c5eaa5a94..d934eba699e6fd47aaa3e8501c577139a806ce1d 100644 --- a/omniadvisor/src/omniadvisor/interface/config_tuning.py +++ b/omniadvisor/src/omniadvisor/interface/config_tuning.py @@ -6,7 +6,7 @@ from algo.iterative.tuning import SmacAppendTuning from algo.native.tuning import NativeTuning from algo.transfer.tuning import TransferTuning from common.constant import 
OA_CONF -from common.exceptions import NoOptimalConfigError, SystemKilledError +from common.exceptions import NoOptimalConfigError, SystemKilledError, TuningPreconditionError from omniadvisor.repository.model.load import Load from omniadvisor.repository.load_repository import LoadRepository from omniadvisor.repository.tuning_record_repository import TuningRecordRepository @@ -175,12 +175,10 @@ def _query_and_check_load(load_id): loads = LoadRepository.query_by_id(load_id) if not loads: # 若查询结果为空,直接返回即可 - global_logger.info('Cannot find load id: %s in database, the tuning process exits.', load_id) - return None + raise TuningPreconditionError(f'Cannot find load id: {load_id} in database, the tuning process exits.') load = loads.pop() if not load.tuning_needed: - global_logger.info('The load dont need to tune, the tuning process exits.') - return None + raise TuningPreconditionError('The load dont need to tune, the tuning process exits.') return load @@ -192,29 +190,31 @@ def main(): # 检查load 是否存在以及是否需要调优 load = _query_and_check_load(args.load_id) - if not load: - return # 前台复测 # 1. 必须指定调优方法 # 2. 若负载存在test_config,则说明有配置在前台调优流程中,退出 if args.retest_way == OA_CONF.RetestWay.hijacking: if not args.tuning_method: - global_logger.info('When using retest way \'hijacking\', param --tuning-method is needed.') - return + raise TuningPreconditionError('When using retest way \'hijacking\', param --tuning-method is needed.') if load.test_config: - global_logger.info('There is a config in hijacking retest, please try another load.') - return + raise TuningPreconditionError('There is a config in hijacking retest, please try another load.') _single_tuning(load=load, retest_way=args.retest_way, tuning_method=args.tuning_method) # 后台复测 # 1. 若指定调优方法,则执行一轮调优;若不指定,则进入连续调优策略 # 2. 不做调优冲突判断,若推荐相同配置,则退出 + # 3. 
当前负载不支持后台复测,则退出 elif args.retest_way == OA_CONF.RetestWay.backend: + if load.backend_retest_forbidden: + raise TuningPreconditionError( + 'Current load contains data manipulation operations, only retest way of hijacking is supported.' + ) + if args.tuning_method: _single_tuning(load=load, retest_way=args.retest_way, tuning_method=args.tuning_method) else: _continuous_tuning_with_strategies(load=load, tuning_strategies=OA_CONF.tuning_strategies) else: - raise RuntimeError('The retest way must in %s, please check cmd.', OA_CONF.RetestWay.all) + raise TuningPreconditionError(f'The retest way must in {OA_CONF.RetestWay.all}, please check cmd.') diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index 5da174b1a2ada9c482ddc511ab9385f9d86fa517..4666b6f11880fe74e585a9041e06e01c4a77ce60 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -5,6 +5,7 @@ import re import sys from common.constant import OA_CONF +from common.exceptions import UnknownEncodingError from omniadvisor.repository.load_repository import LoadRepository from omniadvisor.repository.model.load import Load from omniadvisor.repository.tuning_record_repository import TuningRecordRepository @@ -13,11 +14,16 @@ from omniadvisor.service.spark_service.spark_run import spark_run from omniadvisor.service.tuning_result.tuning_result import get_tuning_result from omniadvisor.service.tuning_result.tuning_result_history import get_tuning_result_history from omniadvisor.utils.logger import global_logger -from omniadvisor.utils.utils import float_format +from omniadvisor.utils.utils import float_format, safe_read_file, file_exists -_NAME_STR = 'name' +_PARAM_NAME_IN_EXEC_ATTR = 'name' +_PARAM_F_IN_EXEC_ATTR = 'f' +_PARAM_E_IN_EXEC_ATTR = 'e' +_PARAM_CLASS_IN_EXEC_ATTR = 'class' _USER_CONFIG_STR = 'user_config' _FILE_CONTENT_STR = 'file_content' +_EMPTY_STR =
'' +_UNKNOWN_STR = 'unknown' def _get_exec_config_from_load(load: Load): @@ -93,36 +99,32 @@ def _process_load_config(load: Load, config: dict): pass -def _get_load_name_from_exec_attr(exec_attr: dict): - """ - 获取任务名,无则空 - """ - return exec_attr.get('name', '') - - def _calculate_hash_value(exec_attr: dict, user_config: dict): """ 根据 执行参数和用户配置计算负载唯一标识,hash值 + + :param exec_attr: 执行参数 + :param user_config: 用户配置 + :return: """ - class_info = exec_attr.get('class', '') - if 'SparkSQLCLIDriver' in class_info: + if _is_spark_sql_cli(exec_attr): # spark-sql,则提取 -e 和 -f,替换 - if 'e' in exec_attr: - file_content = _remove_time(exec_attr.get('e', '')) + if _PARAM_E_IN_EXEC_ATTR in exec_attr: + file_content = _remove_time(exec_attr.get(_PARAM_E_IN_EXEC_ATTR, _EMPTY_STR)) else: - file_content = exec_attr.get('f', '') + file_content = exec_attr.get(_PARAM_F_IN_EXEC_ATTR, _EMPTY_STR) else: # spark-submit,则提取 file 及之后的内容,包含 file的入参 - unknown_content = exec_attr.get('unknown', []) + unknown_content = exec_attr.get(_UNKNOWN_STR, []) unknown_content = [item.strip() for item in unknown_content] if len(unknown_content) > 0: file_content = unknown_content[0] else: - file_content = '' + file_content = _EMPTY_STR - name = _get_load_name_from_exec_attr(exec_attr) + name = exec_attr.get(_PARAM_NAME_IN_EXEC_ATTR, _EMPTY_STR) data = { - _NAME_STR: _remove_time(name), + _PARAM_NAME_IN_EXEC_ATTR: _remove_time(name), _USER_CONFIG_STR: user_config, _FILE_CONTENT_STR: file_content } @@ -134,17 +136,71 @@ def _calculate_hash_value(exec_attr: dict, user_config: dict): def _remove_time(content: str): """ 移除时间信息,可精确到小时 + + :param content: 待移除的文本内容 + :return: """ for pattern in OA_CONF.date_patterns: - content = re.sub(pattern, '{date}', content) + content = re.sub(pattern, OA_CONF.date_replacement_value, content) return content.strip() +def _is_spark_sql_cli(exec_attr: dict) -> bool: + """ + 判断是否为 spark-sql类型 + + :param exec_attr: 执行参数 + :return: + """ + class_name = 
exec_attr.get(_PARAM_CLASS_IN_EXEC_ATTR, _EMPTY_STR) + return OA_CONF.spark_sql_cli_driver == class_name + + +def _has_modify_keywords(exec_attr: dict) -> bool: + """ + 检查是否含有修改数据库的操作 + + :param exec_attr: 执行参数 + :return: + """ + if _is_spark_sql_cli(exec_attr): + # spark-sql,则提取 -e 和 -f,替换 + if _PARAM_E_IN_EXEC_ATTR in exec_attr: + file_content = exec_attr.get(_PARAM_E_IN_EXEC_ATTR, _EMPTY_STR) + else: + # 这里也有可能是相对路径 + file_path = exec_attr.get(_PARAM_F_IN_EXEC_ATTR, _EMPTY_STR) + exists = file_exists(file_path) + if not exists: + # 这里初步判断文件用相对路径的比较少,所以打warning + global_logger.warning( + 'File %s is not exists, for safety’s sake, it contains modify operation by default', file_path + ) + return True + try: + file_content = safe_read_file(file_path) + except (UnknownEncodingError, Exception) as e: + # 初步判断,目前已有的编码,能覆盖,所以打warning + global_logger.warning( + 'An unexpected situation: %s occurred when reading file %s , for safety’s sake, it contains ' + 'modify operation by default', e, file_path + ) + return True + + return _any_modify_keywords_in_sql(file_content) + else: + return True + + def _create_or_update_load(exec_attr: dict, default_config: dict): """ 无则创建,有则更新 name 和 exec_attr 字段。如果日期信息等一直保持不变,后期复测时,会导致测试出现异常 + + :param exec_attr: 执行参数 + :param default_config: 默认配置 + :return: """ - name = _get_load_name_from_exec_attr(exec_attr) + name = exec_attr.get(_PARAM_NAME_IN_EXEC_ATTR, _EMPTY_STR) hash_value = _calculate_hash_value(exec_attr, default_config) loads = LoadRepository.query_by_hash_value(hash_value) if loads: @@ -152,12 +208,29 @@ def _create_or_update_load(exec_attr: dict, default_config: dict): return LoadRepository.update_name_and_exec_attr(loads.pop(), name, exec_attr) else: # 负载没有创建过 - load = LoadRepository.create(name, exec_attr, default_config, hash_value) + backend_retest_forbidden = _has_modify_keywords(exec_attr) + load = LoadRepository.create(name, exec_attr, default_config, hash_value, backend_retest_forbidden) global_logger.info("Load not 
found in database, created new one and execute.") TuningRecordRepository.create(load=load, config=default_config, method=OA_CONF.TuningMethod.user) return load +def _any_modify_keywords_in_sql(sql: str): + """ + 判断在当前sql中是否存在任何的修改动作 + + :param sql: sql内容 + :return: + """ + # 正则转义关键词(虽然目前不需要,但防御性更强) + escaped_keywords = [re.escape(k) for k in OA_CONF.modify_keywords] + + # 后缀是空格、换行、制表符、SQL常见符号 或字符串结束 + pattern = re.compile(r'(?:' + '|'.join(escaped_keywords) + r')(?=\s|[(),;]|$)', re.IGNORECASE) + + return bool(pattern.search(sql)) + + def hijack_recommend(argv: list): """ 任务劫持,使能参数下发执行任务 diff --git a/omniadvisor/src/omniadvisor/repository/load_repository.py b/omniadvisor/src/omniadvisor/repository/load_repository.py index 28448fd5c78fc4fc66ec3002458ebfbf89183dcc..36146907ea0144376f256e5ca1afb6984f43e8f5 100644 --- a/omniadvisor/src/omniadvisor/repository/load_repository.py +++ b/omniadvisor/src/omniadvisor/repository/load_repository.py @@ -33,7 +33,7 @@ class LoadRepository(Repository): } @classmethod - def create(cls, name: str, exec_attr: dict, default_config: dict, hash_value: str): + def create(cls, name: str, exec_attr: dict, default_config: dict, hash_value: str, backend_retest_forbidden: bool): """ 指定名称、执行属性和默认配置,新增负载 @@ -41,13 +41,15 @@ class LoadRepository(Repository): :param exec_attr: 负载执行属性 :param default_config: 负载默认配置 :param hash_value: hash值 + :param backend_retest_forbidden: 是否禁用后台复测 :return: 负载实例 """ model_attr = { Load.FieldName.name: name, Load.FieldName.exec_attr: exec_attr, Load.FieldName.default_config: default_config, - Load.FieldName.hash_value: hash_value + Load.FieldName.hash_value: hash_value, + Load.FieldName.backend_retest_forbidden: backend_retest_forbidden } try: database_load = cls._create(model_attr=model_attr) diff --git a/omniadvisor/src/omniadvisor/repository/model/load.py b/omniadvisor/src/omniadvisor/repository/model/load.py index ca21c0f4cb6a6d96ba1bacfbc4e36a9e6ad0cd4a..18683a88f78bb17da9a5c405aebcccbb9a5ae142 100644 --- 
a/omniadvisor/src/omniadvisor/repository/model/load.py +++ b/omniadvisor/src/omniadvisor/repository/model/load.py @@ -59,3 +59,4 @@ class Load: create_time = 'create_time' tuning_needed = 'tuning_needed' hash_value = 'hash_value' + backend_retest_forbidden = 'backend_retest_forbidden' diff --git a/omniadvisor/src/omniadvisor/utils/utils.py b/omniadvisor/src/omniadvisor/utils/utils.py index b53749b0c6cb525db4067fc363b225fbee4f47dc..7bc693626217302148d7c591e2c641853486dc75 100644 --- a/omniadvisor/src/omniadvisor/utils/utils.py +++ b/omniadvisor/src/omniadvisor/utils/utils.py @@ -1,10 +1,11 @@ -import os import json -import uuid +import os import subprocess +import uuid from typing import Tuple, List, Dict from common.constant import OmniAdvisorConf +from common.exceptions import UnknownEncodingError from omniadvisor.utils.logger import global_logger @@ -76,3 +77,26 @@ def float_format(to_format: float) -> str: :return: """ return "{:.{}f}".format(to_format, OmniAdvisorConf.decimal_digits) + + +def safe_read_file(filepath: str): + """ + Read a text file, trying utf-8-sig first so that a leading BOM is stripped. + """ + encoding_types = ['utf-8-sig', 'utf-8', 'gb18030', 'latin1'] + for enc in encoding_types: + try: + with open(filepath, 'r', encoding=enc) as f: + return f.read() + except UnicodeDecodeError: + continue + raise UnknownEncodingError( + f'Only encoding in {encoding_types} is supported, encoding of current file {filepath} is unknown' + ) + + +def file_exists(filepath: str): + """ + Return True if filepath is an existing regular file. + """ + return os.path.isfile(filepath) diff --git a/omniadvisor/src/server/app/admin.py b/omniadvisor/src/server/app/admin.py index 68ce07a33ebf9bce7cd75e99468aee041a1a6033..ccf45875f3e65d52508d8dafaadeb0f0b514bdd8 100644 --- a/omniadvisor/src/server/app/admin.py +++ b/omniadvisor/src/server/app/admin.py @@ -7,7 +7,9 @@ from .models import DatabaseLoad, DatabaseTuningRecord, DatabaseExamRecord @admin.register(DatabaseLoad) class LoadAdmin(admin.ModelAdmin): - list_display = ('id', 'name', 'exec_attr', 'default_config',
'best_config', 'test_config') + list_display = ( + 'id', 'name', 'exec_attr', 'default_config', 'best_config', 'test_config', 'backend_retest_forbidden' + ) list_filter = () diff --git a/omniadvisor/src/server/app/models.py b/omniadvisor/src/server/app/models.py index 49f59be28678891af050d3b897ecbb49257f43be..b6072c10a74dbdeacf7078d3d271227ad267f924 100644 --- a/omniadvisor/src/server/app/models.py +++ b/omniadvisor/src/server/app/models.py @@ -17,6 +17,7 @@ class DatabaseLoad(models.Model): create_time = models.DateTimeField(auto_now_add=True) tuning_needed = models.BooleanField(default=True) hash_value = models.CharField(max_length=64, null=False, unique=True) + backend_retest_forbidden = models.BooleanField(default=True) class Meta: db_table = 'omniadvisor_load' # 自定义表名 diff --git a/omniadvisor/src/tuning.py b/omniadvisor/src/tuning.py index a75f590572f1bf98f3a5c278b07b4b4e4bc2760a..f066bb93e71e3797f2ed41eb668d5859e5ed7fff 100644 --- a/omniadvisor/src/tuning.py +++ b/omniadvisor/src/tuning.py @@ -1,4 +1,5 @@ from common.constant import check_oa_conf +from common.exceptions import TuningPreconditionError from omniadvisor.interface.config_tuning import ( main, NoOptimalConfigError @@ -12,7 +13,7 @@ if __name__ == '__main__': main() # 无需进行逻辑处理的异常,直接抛至该层 # 若需进行逻辑处理(如环境清理等),则需在相应位置处理后重新抛至该层 - except NoOptimalConfigError as e: + except (NoOptimalConfigError, TuningPreconditionError) as e: # 若未找到优化配置,则通过warning打印即可 global_logger.warning(e) # 正常退出 diff --git a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py index 1b8e4c4802b36364de1216bafba1551424d8485b..7a5f972a00d2c8fc2368909ab4e269a6a59ebe43 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py +++ b/omniadvisor/tests/omniadvisor/interface/test_config_tuning.py @@ -2,6 +2,8 @@ import sys from unittest.mock import patch, MagicMock import pytest + +from common.exceptions import TuningPreconditionError from 
omniadvisor.interface.config_tuning import _single_tuning, main from common.constant import OA_CONF @@ -18,95 +20,140 @@ class TestTuning: self.empty_str = '' self.tune_return_val = ({'key': 'value'}, '') - def test_unified_tuning_when_retest_backend(self): + @patch('omniadvisor.service.retest_service.float_format') + @patch('omniadvisor.interface.config_tuning.float_format') + @patch('algo.iterative.tuning.SmacAppendTuning.tune') + @patch('omniadvisor.interface.config_tuning.get_tuning_result_history') + @patch('omniadvisor.interface.config_tuning.get_tuning_result') + @patch('omniadvisor.service.retest_service.spark_run') + @patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create') + @patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.query_by_load_and_config') + @patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') + def test_unified_tuning_when_retest_backend( + self, + mock_query_by_id, + mock_query_tuning_record, + mock_create_tuning_record, + mock_spark_run, + mock_get_tuning_result, + mock_get_history, + mock_smac_tuning, + mock_float_format_config, + mock_float_format_retest + ): """ 后台复测顺利执行 :return: """ - with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id'), \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.query_by_load_and_config') as mock_query_tuning_record, \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create'), \ - patch('omniadvisor.service.retest_service.spark_run') as mock_spark_run, \ - patch('omniadvisor.interface.config_tuning.get_tuning_result'), \ - patch('omniadvisor.interface.config_tuning.get_tuning_result_history') as mock_get_history, \ - patch('algo.iterative.tuning.SmacAppendTuning.tune') as mock_smac_tuning, \ - patch('omniadvisor.interface.config_tuning.float_format'), \ - patch('omniadvisor.service.retest_service.float_format'): - 
mock_query_tuning_record.return_value = list() - mock_exam_record = MagicMock() - mock_smac_tuning.return_value = self.tune_return_val - mock_exam_record.status = OA_CONF.ExecStatus.success - spark_output = self.empty_str - mock_spark_run.return_value = mock_exam_record, spark_output - mock_tuning_history = MagicMock() - mock_get_history.return_value = mock_tuning_history - _single_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) - mock_tuning_history.refresh_best_config.assert_called_once() - mock_smac_tuning.assert_called_once() - assert mock_spark_run.call_count == OA_CONF.tuning_retest_times + mock_query_tuning_record.return_value = list() + mock_exam_record = MagicMock() + mock_smac_tuning.return_value = self.tune_return_val + mock_exam_record.status = OA_CONF.ExecStatus.success + spark_output = self.empty_str + mock_spark_run.return_value = mock_exam_record, spark_output + mock_tuning_history = MagicMock() + mock_get_history.return_value = mock_tuning_history + _single_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) + mock_tuning_history.refresh_best_config.assert_called_once() + mock_smac_tuning.assert_called_once() + assert mock_spark_run.call_count == OA_CONF.tuning_retest_times - def test_unified_tuning_when_spark_execute_failed(self): + @patch('omniadvisor.service.retest_service.float_format') + @patch('omniadvisor.interface.config_tuning.float_format') + @patch('omniadvisor.repository.load_repository.LoadRepository.update_best_config') + @patch('algo.iterative.tuning.SmacAppendTuning.tune') + @patch('omniadvisor.interface.config_tuning.remove_tuning_result') + @patch('omniadvisor.service.retest_service.get_tuning_result') + @patch('omniadvisor.interface.config_tuning.get_tuning_result_history') + @patch('omniadvisor.service.retest_service.spark_run') + @patch('omniadvisor.interface.config_tuning.get_tuning_result') + 
@patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create') + @patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.query_by_load_and_config') + @patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') + def test_unified_tuning_when_spark_execute_failed( + self, + mock_query_by_id, + mock_query_tuning_record, + mock_create_tuning_record, + mock_get_tuning_result_config, + mock_spark_run, + mock_tuning_result_history, + mock_get_tuning_result_service, + mock_remove_tuning_result, + mock_smac_tuning, + mock_update_best, + mock_float_format_config, + mock_float_format_service + ): """ 后台复测 spark命令执行异常 :return: """ - with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id'), \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.query_by_load_and_config') as mock_query_tuning_record, \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create'), \ - patch('omniadvisor.interface.config_tuning.get_tuning_result'), \ - patch('omniadvisor.service.retest_service.spark_run') as mock_spark_run, \ - patch('omniadvisor.interface.config_tuning.get_tuning_result_history') as mock_tuning_result_history, \ - patch('omniadvisor.service.retest_service.get_tuning_result') as mock_get_tuning_result, \ - patch('omniadvisor.interface.config_tuning.remove_tuning_result') as mock_remove_tuning_result, \ - patch('algo.iterative.tuning.SmacAppendTuning.tune') as mock_smac_tuning, \ - patch('omniadvisor.repository.load_repository.LoadRepository.update_best_config') as mock_update_best, \ - patch('omniadvisor.interface.config_tuning.float_format'), \ - patch('omniadvisor.service.retest_service.float_format'): - mock_query_tuning_record.return_value = list() - mock_exam_record = MagicMock() - mock_exam_record.status = OA_CONF.ExecStatus.fail - spark_output = self.empty_str - mock_spark_run.return_value = mock_exam_record, spark_output - 
mock_smac_tuning.return_value = self.tune_return_val - mock_tuning_result = MagicMock() - mock_tuning_result.failed_count = OA_CONF.config_fail_threshold - mock_get_tuning_result.return_value = mock_tuning_result - tuning_result_history = MagicMock() - tuning_result_history.best_config = {**self.load.default_config, **self.tune_return_val[0]} - mock_tuning_result_history.return_value = tuning_result_history - _single_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) - mock_smac_tuning.assert_called_once() - mock_spark_run.assert_called_once() - mock_update_best.assert_not_called() - mock_remove_tuning_result.assert_not_called() + mock_query_tuning_record.return_value = list() + mock_exam_record = MagicMock() + mock_exam_record.status = OA_CONF.ExecStatus.fail + spark_output = self.empty_str + mock_spark_run.return_value = mock_exam_record, spark_output + mock_smac_tuning.return_value = self.tune_return_val + mock_tuning_result = MagicMock() + mock_tuning_result.failed_count = OA_CONF.config_fail_threshold + mock_get_tuning_result_service.return_value = mock_tuning_result + tuning_result_history = MagicMock() + tuning_result_history.best_config = {**self.load.default_config, **self.tune_return_val[0]} + mock_tuning_result_history.return_value = tuning_result_history + _single_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) - def test_unified_tuning_when_other_exception(self): + mock_smac_tuning.assert_called_once() + mock_spark_run.assert_called_once() + mock_update_best.assert_not_called() + mock_remove_tuning_result.assert_not_called() + + @patch('omniadvisor.interface.config_tuning.remove_tuning_result') + @patch('algo.iterative.tuning.SmacAppendTuning.tune') + @patch('omniadvisor.repository.load_repository.LoadRepository.update_best_config') + @patch('omniadvisor.interface.config_tuning.get_tuning_result_history') + @patch('omniadvisor.service.retest_service.spark_run', 
side_effect=RuntimeError) + @patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create') + @patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.query_by_load_and_config') + @patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') + def test_unified_tuning_when_other_exception( + self, + mock_query_by_id, + mock_query_tuning_record, + mock_create_tuning_record, + mock_spark_run, + mock_get_tuning_result_history, + mock_update_best_config, + mock_smac_tuning, + mock_remove_tuning_result + ): """ 后台复测 发生其他异常 :return: """ + mock_query_tuning_record.return_value = list() + mock_exam_record = MagicMock() + mock_exam_record.status = OA_CONF.ExecStatus.success + mock_smac_tuning.return_value = self.tune_return_val + with pytest.raises(RuntimeError): + _single_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, tuning_method=self.tuning_method) - with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id'), \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.query_by_load_and_config') as mock_query_tuning_record, \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create'), \ - patch('omniadvisor.service.retest_service.spark_run', side_effect=RuntimeError) as mock_spark_run, \ - patch('omniadvisor.interface.config_tuning.get_tuning_result_history'), \ - patch('omniadvisor.repository.load_repository.LoadRepository.update_best_config'), \ - patch('algo.iterative.tuning.SmacAppendTuning.tune') as mock_smac_tuning, \ - patch('omniadvisor.interface.config_tuning.remove_tuning_result') as mock_remove_tuning_result: - mock_query_tuning_record.return_value = list() - mock_exam_record = MagicMock() - mock_exam_record.status = OA_CONF.ExecStatus.success - mock_smac_tuning.return_value = self.tune_return_val - with pytest.raises(RuntimeError): - _single_tuning(load=self.load, retest_way=OA_CONF.RetestWay.backend, 
tuning_method=self.tuning_method) - - mock_smac_tuning.assert_called_once() - mock_spark_run.assert_called_once() - mock_remove_tuning_result.assert_called_once() + mock_smac_tuning.assert_called_once() + mock_spark_run.assert_called_once() + mock_remove_tuning_result.assert_called_once() - def test_main_when_load_id_not_exist(self): + @patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create') + @patch('omniadvisor.interface.config_tuning.get_tuning_result_history') + @patch('omniadvisor.utils.logger.global_logger.info') + @patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') + def test_main_when_load_id_not_exist( + self, + mock_query_by_id, + mock_info, + mock_get_tuning_result_history, + mock_create_tuning_record + ): """ 当 load id 不存在时 :return: @@ -116,29 +163,33 @@ class TestTuning: '--retest-way', OA_CONF.RetestWay.backend, '--tuning-method', self.tuning_method] - with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') as mock_query_by_id, \ - patch('omniadvisor.utils.logger.global_logger.info') as mock_info, \ - patch('omniadvisor.interface.config_tuning.get_tuning_result_history'), \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create'): - mock_query_by_id.return_value = None + mock_query_by_id.return_value = None + with pytest.raises(TuningPreconditionError, match='Cannot find load id'): main() - mock_info.assert_called_with('Cannot find load id: %s in database, the tuning process exits.', 'None') - def test_unified_tuning_when_retest_failed(self): + @patch('algo.iterative.tuning.SmacAppendTuning.tune') + @patch('omniadvisor.repository.load_repository.LoadRepository.update_test_config') + @patch('omniadvisor.interface.config_tuning.get_tuning_result_history') + @patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create') + 
@patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.query_by_load_and_config') + @patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id') + def test_unified_tuning_when_retest_failed( + self, + mock_query_by_id, + mock_query_tuning_record, + mock_create_tuning_record, + mocked_query_perf, + mock_update_test, + mock_smac_tuning + ): """ 前台复测 :return: """ - with patch('omniadvisor.repository.load_repository.LoadRepository.query_by_id'), \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.query_by_load_and_config') as mock_query_tuning_record, \ - patch('omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create'), \ - patch('omniadvisor.interface.config_tuning.get_tuning_result_history') as mocked_query_perf, \ - patch('omniadvisor.repository.load_repository.LoadRepository.update_test_config') as mock_update_test, \ - patch('algo.iterative.tuning.SmacAppendTuning.tune') as mock_smac_tuning: - mock_query_tuning_record.return_value = list() - mock_smac_tuning.return_value = self.tune_return_val - _single_tuning(load=self.load, retest_way=OA_CONF.RetestWay.hijacking, - tuning_method=self.tuning_method) - mocked_query_perf.assert_called_once() - mock_update_test.assert_called_once() - mock_smac_tuning.assert_called_once() + mock_query_tuning_record.return_value = list() + mock_smac_tuning.return_value = self.tune_return_val + _single_tuning(load=self.load, retest_way=OA_CONF.RetestWay.hijacking, + tuning_method=self.tuning_method) + mocked_query_perf.assert_called_once() + mock_update_test.assert_called_once() + mock_smac_tuning.assert_called_once() diff --git a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py index fa03b8ea506bc4bc223e1dfbdfbd35b700f8c51b..70a5cb7d88370e4d69a68a28fea7ae7eb6fd6e88 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py +++ 
b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py @@ -5,7 +5,7 @@ import pytest from common.constant import OA_CONF from common.exceptions import DuplicateEntryError from omniadvisor.interface.hijack_recommend import hijack_recommend, _process_load_config, _create_or_update_load, \ - _remove_time, _calculate_hash_value + _remove_time, _calculate_hash_value, _any_modify_keywords_in_sql # 测试代码 @@ -72,7 +72,7 @@ class TestHijackRecommend: @patch("omniadvisor.interface.hijack_recommend._create_or_update_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.validate_submit_arguments") - def test_failed_user_config(self, mock_validate_submit_arguments, mock_parse_cmd,mock_create_or_update_load, + def test_failed_user_config(self, mock_validate_submit_arguments, mock_parse_cmd, mock_create_or_update_load, mock_get_config, mock_spark_run, mock_process_config): argv = ["--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job_name"} @@ -223,7 +223,7 @@ class TestHijackRecommend: mock_load_create.return_value = "created_object" mock_query_by_hash_value.return_value = [] result = _create_or_update_load(exec_attr, default_config) - mock_load_create.assert_called_once_with(name, exec_attr, default_config, hash_value) + mock_load_create.assert_called_once_with(name, exec_attr, default_config, hash_value, True) mock_tuning_record_create.assert_called_once() assert result == "created_object" @@ -316,3 +316,148 @@ class TestHijackRecommend: input_text = "文件名:report_2025080112_final.pdf" expected = "文件名:report_{date}_final.pdf" assert _remove_time(input_text) == expected.strip() + + def test_create_table(self): + sql = "CREATE TABLE IF NOT EXISTS users (id INT, name STRING)" + assert _any_modify_keywords_in_sql(sql) is True + + def test_create_view(self): + sql = "CREATE OR REPLACE TEMP VIEW v_users AS SELECT * FROM users" + assert _any_modify_keywords_in_sql(sql) is True 
+ + def test_insert_into(self): + sql = "INSERT INTO users SELECT * FROM new_users" + assert _any_modify_keywords_in_sql(sql) is True + + def test_insert_overwrite(self): + sql = "INSERT OVERWRITE TABLE warehouse SELECT * FROM staging" + assert _any_modify_keywords_in_sql(sql) is True + + def test_delete_statement(self): + sql = "DELETE FROM logs WHERE level = 'DEBUG'" + assert _any_modify_keywords_in_sql(sql) is True + + def test_update_statement(self): + sql = "UPDATE customers SET status = 'active' WHERE last_login > NOW()" + assert _any_modify_keywords_in_sql(sql) is True + + def test_merge_into(self): + sql = """ + MERGE INTO customers USING updates + ON customers.id = updates.id + WHEN MATCHED THEN UPDATE SET customers.name = updates.name + WHEN NOT MATCHED THEN INSERT (id, name) VALUES (updates.id, updates.name) + """ + assert _any_modify_keywords_in_sql(sql) is True + + def test_drop_table(self): + sql = "DROP TABLE IF EXISTS temp_table" + assert _any_modify_keywords_in_sql(sql) is True + + def test_truncate_table(self): + sql = "TRUNCATE TABLE big_table" + assert _any_modify_keywords_in_sql(sql) is True + + def test_multiple_statements(self): + sql = """ + CREATE TABLE tmp AS SELECT * FROM raw; + INSERT INTO final SELECT * FROM tmp; + """ + assert _any_modify_keywords_in_sql(sql) is True + + def test_create_database(self): + sql = "CREATE DATABASE IF NOT EXISTS my_db" + assert _any_modify_keywords_in_sql(sql) is True + + def test_comment_with_real_sql(self): + sql = """ + -- Let's drop the table + DROP TABLE users + """ + assert _any_modify_keywords_in_sql(sql) is True # 注释中是“注释”但 SQL 在有效区域 + + def test_select_only(self): + sql = "SELECT * FROM customers WHERE region = 'US'" + assert _any_modify_keywords_in_sql(sql) is False + + def test_with_cte_select(self): + sql = """ + WITH recent_orders AS ( + SELECT * FROM orders WHERE order_date > '2025-01-01' + ) + SELECT * FROM recent_orders + """ + assert _any_modify_keywords_in_sql(sql) is False + + def 
test_keywords_in_string(self): + """ + 注意,这里我们会选择粗暴的认为这是危险操作 + """ + sql = "SELECT 'drop table users' AS msg" + assert _any_modify_keywords_in_sql(sql) is True + + def test_keywords_as_column_names(self): + """ + 注意,这里我们会选择粗暴的认为这是危险操作 + """ + sql = "SELECT create, insert, delete FROM keyword_table" + assert _any_modify_keywords_in_sql(sql) is True + + def test_keywords_as_aliases(self): + """ + 注意,这里我们会选择粗暴的认为这是危险操作 + """ + sql = "SELECT user_id AS update, name AS delete FROM users" + assert _any_modify_keywords_in_sql(sql) is True + + def test_function_names_overlap_keywords(self): + """ + 注意,这里 'update_' 'merge_'与并非有效修改、合并操作 + """ + sql = "SELECT update_time(), merge_fields(name, age) FROM meta" + assert _any_modify_keywords_in_sql(sql) is False + + def test_sql_with_comments_only_keywords(self): + """ + 注意,这里我们会选择粗暴的认为这是危险操作 + """ + sql = """ + -- DELETE FROM customers + -- DROP TABLE logs + SELECT * FROM users + """ + assert _any_modify_keywords_in_sql(sql) is True + + def test_empty_sql(self): + sql = "" + assert _any_modify_keywords_in_sql(sql) is False + + def test_whitespace_sql(self): + sql = " \n\t " + assert _any_modify_keywords_in_sql(sql) is False + + def test_sql_with_semicolon_and_comment(self): + """ + 注意,这里我们会选择粗暴的认为这是危险操作 + """ + sql = "SELECT * FROM table; -- insert into table" + assert _any_modify_keywords_in_sql(sql) is True + + def test_multiple_statement_sql_file(self): + """ + 测试多语句的情况 + """ + sql = """ + -- 创建表 + CREATE TABLE IF NOT EXISTS tmp AS SELECT * FROM raw_data; + + -- 插入 + INSERT INTO final SELECT * FROM tmp; + + -- 查询结果 + SELECT COUNT(*) FROM final; + + -- 删除临时表 + DROP TABLE tmp; + """ + assert _any_modify_keywords_in_sql(sql) is True diff --git a/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py b/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py index fd4000d41ef6b7f3eb269b414ee44cd883bddaf8..b442fbc666443ea154659fd953347fc37930a926 100644 --- 
a/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py @@ -56,7 +56,8 @@ class TestExamRecordRepository: name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, default_config={'param1': 'value2'}, - hash_value=hash_value + hash_value=hash_value, + backend_retest_forbidden=True ) self.config = {'param1': 'value2'} self.tuning_record = TuningRecordRepository.create( diff --git a/omniadvisor/tests/omniadvisor/repository/test_load_repository.py b/omniadvisor/tests/omniadvisor/repository/test_load_repository.py index 745cc1e6828c2dcbd344a450588143392faa9c25..099a44402f3c6194806fe3430bb278434f8e0793 100644 --- a/omniadvisor/tests/omniadvisor/repository/test_load_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_load_repository.py @@ -34,7 +34,8 @@ class TestLoadRepository: name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, default_config={'param1': 'value2'}, - hash_value='0607' + hash_value='0607', + backend_retest_forbidden=True ) assert isinstance(load, Load) assert load.database_model.name == 'example_name' @@ -48,7 +49,7 @@ class TestLoadRepository: # 输入无效数据 with pytest.raises(ValueError): - LoadRepository.create(name='', exec_attr={}, default_config={}, hash_value='') + LoadRepository.create(name='', exec_attr={}, default_config={}, hash_value='', backend_retest_forbidden=True) def test_query_by_hash_value(self): # 创建测试数据 @@ -58,7 +59,8 @@ class TestLoadRepository: name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, default_config={'param1': 'value2'}, - hash_value=hash_value + hash_value=hash_value, + backend_retest_forbidden=True ) # 查询负载 @@ -78,7 +80,8 @@ class TestLoadRepository: name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, default_config={'param1': 'value2'}, - hash_value=hash_value + hash_value=hash_value, + backend_retest_forbidden=True ) # 更新最优配置 @@ -95,7 +98,8 @@ class TestLoadRepository: 
name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, default_config={'param1': 'value2'}, - hash_value=hash_value + hash_value=hash_value, + backend_retest_forbidden=True ) # 更新测试配置 @@ -112,7 +116,8 @@ class TestLoadRepository: name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, default_config={'param1': 'value2'}, - hash_value=hash_value + hash_value=hash_value, + backend_retest_forbidden=True ) # 更新测试配置 diff --git a/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py b/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py index 9283d8e3c9a4f066a0ce6a03b3aea11bb115d698..52c25a993090b3d03f2ec89a5f8c7b81659722f5 100644 --- a/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py @@ -38,7 +38,8 @@ class TestTaskRepository: name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, default_config={"param1": "value2"}, - hash_value=hash_value + hash_value=hash_value, + backend_retest_forbidden=True ) # 创建TuningRecord @@ -71,7 +72,8 @@ class TestTaskRepository: name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, default_config={"param1": "value2"}, - hash_value=hash_value + hash_value=hash_value, + backend_retest_forbidden=True ) tuning_record = TuningRecordRepository.create( load=load, @@ -94,7 +96,8 @@ class TestTaskRepository: name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, default_config={"param1": "value2"}, - hash_value=hash_value + hash_value=hash_value, + backend_retest_forbidden=True ) tuning_record = TuningRecordRepository.create( load=load, diff --git a/omniadvisor/tests/omniadvisor/utils/test_utils.py b/omniadvisor/tests/omniadvisor/utils/test_utils.py index 81f7ab29d64dc9df42ff459e5ba8788b932750ed..cd6293ac89e1d2ace444c91af07d934e25714185 100644 --- a/omniadvisor/tests/omniadvisor/utils/test_utils.py +++ b/omniadvisor/tests/omniadvisor/utils/test_utils.py @@ -31,51 +31,48 @@ 
class TestRunCmd: class TestTraceDataSaver: + @patch('builtins.open', new_callable=mock_open, create=True) @patch('os.makedirs') @patch('uuid.uuid4', return_value="test-uuid") - def test_save_trace_data_success(self, mock_uuid, mock_makedirs): + def test_save_trace_data_success(self, mock_uuid, mock_makedirs, mock_file): data = [{"key": "value"}] data_dir = "/tmp" - m = mock_open() - with patch('builtins.open', m, create=True): - file_path = save_trace_data(data, data_dir) + + file_path = save_trace_data(data, data_dir) expected_path = f"{data_dir}/test-uuid" assert file_path == expected_path - m.assert_called_once_with(expected_path, 'w', encoding='utf-8') + mock_file.assert_called_once_with(expected_path, 'w', encoding='utf-8') + @patch('builtins.open', new_callable=mock_open, create=True) + @patch('uuid.uuid4', return_value="test-uuid") @patch('os.makedirs') - def test_save_trace_data_ioerror(self, mock_makedirs): + def test_save_trace_data_ioerror(self, mock_makedirs, mock_uuid, mock_open_file): data = [{"key": "value"}] data_dir = "/tmp" - m = mock_open() - m.side_effect = IOError("IO Error") - with patch('os.makedirs'), \ - patch('uuid.uuid4', return_value="test-uuid"), \ - patch('builtins.open', m, create=True): - try: - save_trace_data(data, data_dir) - assert False, "Expected an IOError to be raised." - except IOError as e: - assert str(e) == "出现IO错误: IO Error", f"Unexpected error message: {str(e)}" + # 设置 open 的 side_effect + mock_open_file.side_effect = IOError("IO Error") + try: + save_trace_data(data, data_dir) + assert False, "Expected an IOError to be raised." 
+ except IOError as e: + assert str(e) == "出现IO错误: IO Error", f"Unexpected error message: {str(e)}" @patch('os.makedirs') - def test_save_trace_data_exception(self, mock_makedirs): + @patch('uuid.uuid4', return_value="test-uuid") + @patch('builtins.open', new_callable=mock_open) + def test_save_trace_data_exception(self, mock_open, mock_uuid, mock_makedirs): data = [{"key": "value"}] data_dir = "/tmp" - m = mock_open() - m.side_effect = Exception("Unexpected error") - - with patch('os.makedirs'), \ - patch('uuid.uuid4', return_value="test-uuid"), \ - patch('builtins.open', m, create=True): - try: - save_trace_data(data, data_dir) - assert False, "Expected an Exception to be raised." - except Exception as e: - assert str(e) == "保存过程中出现错误: Unexpected error", f"Unexpected error message: {str(e)}" + mock_open.side_effect = Exception("Unexpected error") # 设置 mock_open 抛出异常 + + try: + save_trace_data(data, data_dir) + assert False, "Expected an Exception to be raised." + except Exception as e: + assert str(e) == "保存过程中出现错误: Unexpected error", f"Unexpected error message: {str(e)}" class TestFloatFormat: