diff --git a/omniadvisor/src/common/constant.py b/omniadvisor/src/common/constant.py index d154ee6ae74f9f1f0bca5816b4f09299e60e370c..7994e26442318049791b055ada7397a0520c92ca 100644 --- a/omniadvisor/src/common/constant.py +++ b/omniadvisor/src/common/constant.py @@ -119,5 +119,17 @@ class OmniAdvisorConf: # 保留小数位数 decimal_digits = 3 + # spark 的 name、sql中日期匹配格式,最高精确度是小时 + date_patterns = [ + # YYYY-MM-DD 或 YYYY-MM-DD HH + r'\d{4}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])(?:\s+(0\d|1\d|2[0-3]))?', + # YYYY/MM/DD 或 YYYY/MM/DD HH + r'\d{4}/(0[1-9]|1[0-2])/(0[1-9]|[12]\d|3[01])(?:\s+(0\d|1\d|2[0-3]))?', + # YYYY.MM.DD 或 YYYY.MM.DD HH + r'\d{4}\.(0[1-9]|1[0-2])\.(0[1-9]|[12]\d|3[01])(?:\s+(0\d|1\d|2[0-3]))?', + # 无分隔:YYYYMMDD 或 YYYYMMDDHH + r'\d{4}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])((0\d|1\d|2[0-3]))?' + ] + OA_CONF = OmniAdvisorConf() diff --git a/omniadvisor/src/common/exceptions.py b/omniadvisor/src/common/exceptions.py index 6c92d4dc8c9a8b3c7fd03df5f020e9ca27649203..b23ef4c8241a8f1df19f0ed30255d68dcb5c363a 100644 --- a/omniadvisor/src/common/exceptions.py +++ b/omniadvisor/src/common/exceptions.py @@ -9,4 +9,11 @@ class SystemKilledError(Exception): """ 异常,程序被异常中断,例如kill等 """ - pass \ No newline at end of file + pass + + +class DuplicateEntryError(Exception): + """ + 数据重复异常 + """ + pass diff --git a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py index bc6980a4eb9ef998a4eddc9338fd9049e940704c..e3d1cc8092500b232fc6f492ce327b157f440938 100644 --- a/omniadvisor/src/omniadvisor/interface/hijack_recommend.py +++ b/omniadvisor/src/omniadvisor/interface/hijack_recommend.py @@ -1,40 +1,23 @@ -import sys +import hashlib +import json import multiprocessing +import re +import sys from common.constant import OA_CONF -from omniadvisor.service.spark_service.spark_cmd_parser import SparkCMDParser -from omniadvisor.repository.model.load import Load from omniadvisor.repository.load_repository import LoadRepository +from omniadvisor.repository.model.load import Load from omniadvisor.repository.tuning_record_repository import TuningRecordRepository +from omniadvisor.service.spark_service.spark_cmd_parser import SparkCMDParser +from omniadvisor.service.spark_service.spark_run import spark_run from omniadvisor.service.tuning_result.tuning_result import get_tuning_result from omniadvisor.service.tuning_result.tuning_result_history import get_tuning_result_history from omniadvisor.utils.logger import global_logger -from omniadvisor.service.spark_service.spark_run import spark_run - from omniadvisor.utils.utils import float_format - -def _query_or_create_load(name: str, exec_attr: dict, default_config: dict): - """ - 根据名称和默认配置,查询或创建负载 - - :param name: 负载名称 - :param exec_attr: 执行属性 - :param default_config: 默认参数配置 - :return: 负载实例 - """ - # 从负载数据库中查询是否有相关联的Load - loads = LoadRepository.query_by_exec_attr_and_default_config(exec_attr=exec_attr, default_config=default_config) - # 如果查询不到负载信息,则创建新的负载 - if not loads: - global_logger.info("Load not found in database, create new one and execute.") - load = LoadRepository.create(name=name, exec_attr=exec_attr, default_config=default_config) - TuningRecordRepository.create(load=load, config=default_config, method=OA_CONF.TuningMethod.user) - else: - global_logger.info("Load found in database, ready to get config to execute.") - load = loads.pop() - - return load +_NAME_STR = 'name' +_USER_CONFIG_STR = 'user_config' +_FILE_CONTENT_STR = 'file_content' def _get_exec_config_from_load(load: Load): @@ -110,22 +93,82 @@ def _process_load_config(load: Load, config: dict): pass +def _get_load_name_from_exec_attr(exec_attr: dict): + """ + 获取任务名,无则空 + """ + return exec_attr.get('name', '') + + +def _calculate_hash_value(exec_attr: dict, user_config: dict): + """ + 根据 执行参数和用户配置计算负载唯一标识,hash值 + """ + class_info = exec_attr.get('class', '') + if 'SparkSQLCLIDriver' in class_info: + # spark-sql,则提取 -e 和 -f,替换 + if 'e' in exec_attr: + file_content = _remove_time(exec_attr.get('e', '')) + else: + file_content = exec_attr.get('f', '') + else: + # spark-submit,则提取 file 及之后的内容,包含 file的入参 + unknown_content = exec_attr.get('unknown', []) + unknown_content = [item.strip() for item in unknown_content] + if len(unknown_content) > 0: + file_content = unknown_content[0] + else: + file_content = '' + + name = _get_load_name_from_exec_attr(exec_attr) + data = { + _NAME_STR: _remove_time(name), + _USER_CONFIG_STR: user_config, + _FILE_CONTENT_STR: file_content + } + json_str = json.dumps(data, sort_keys=True) + hash_value = hashlib.sha256(json_str.encode()).hexdigest() + return hash_value + + +def _remove_time(content: str): + """ + 移除时间信息,可精确到小时 + """ + for pattern in OA_CONF.date_patterns: + content = re.sub(pattern, '{date}', content) + return content.strip() + + +def _create_or_update_load(exec_attr: dict, default_config: dict): + """ + 无则创建,有则更新 name 和 exec_attr 字段。如果日期信息等一直保持不变,后期复测时,会导致测试出现异常 + """ + name = _get_load_name_from_exec_attr(exec_attr) + hash_value = _calculate_hash_value(exec_attr, default_config) + loads = LoadRepository.query_by_hash_value(hash_value) + if loads: + global_logger.info("Load found in database, update and ready to get config to execute.") + return LoadRepository.update_name_and_exec_attr(loads.pop(), name, exec_attr) + else: + # 负载没有创建过 + load = LoadRepository.create(name, exec_attr, default_config, hash_value) + global_logger.info("Load not found in database, created new one and execute.") + TuningRecordRepository.create(load=load, config=default_config, method=OA_CONF.TuningMethod.user) + return load + + def hijack_recommend(argv: list): """ 任务劫持,使能参数下发执行任务 - :param argv: Spark执行命令字段 """ # 获取用户传入的Spark命令 并解析命令 global_logger.debug("Hijack input params: %s", argv) exec_attr, user_config = SparkCMDParser.parse_cmd(argv=argv) - # 提取任务名字 - if 'name' not in exec_attr.keys(): - raise ValueError(f'Task name not in Spark submit cmd!') - name = exec_attr['name'] # 查询或创建相应负载 - load = _query_or_create_load(name=name, exec_attr=exec_attr, default_config=user_config) + load = _create_or_update_load(exec_attr=exec_attr, default_config=user_config) # 获取待执行参数配置 exec_config = _get_exec_config_from_load(load=load) diff --git a/omniadvisor/src/omniadvisor/repository/load_repository.py b/omniadvisor/src/omniadvisor/repository/load_repository.py index 940554a0f88448cf360cb42d45d493fdea5b2b75..28448fd5c78fc4fc66ec3002458ebfbf89183dcc 100644 --- a/omniadvisor/src/omniadvisor/repository/load_repository.py +++ b/omniadvisor/src/omniadvisor/repository/load_repository.py @@ -1,3 +1,6 @@ +from django.db import IntegrityError + +from common.exceptions import DuplicateEntryError from server.app.models import DatabaseLoad from omniadvisor.repository.model.load import Load from omniadvisor.repository.repository import Repository @@ -22,7 +25,7 @@ class LoadRepository(Repository): ] # 用于存储字段格式 _fields_format = { - Load.FieldName.name: ([str], [''], []), + Load.FieldName.name: ([str], [], []), Load.FieldName.exec_attr: ([dict], [dict()], []), Load.FieldName.default_config: ([dict], [], []), Load.FieldName.best_config: ([dict], [], []), @@ -30,59 +33,77 @@ class LoadRepository(Repository): } @classmethod - def create(cls, name: str, exec_attr: dict, default_config: dict): + def create(cls, name: str, exec_attr: dict, default_config: dict, hash_value: str): """ 指定名称、执行属性和默认配置,新增负载 :param name: 负载名称 :param exec_attr: 负载执行属性 :param default_config: 负载默认配置 + :param hash_value: hash值 :return: 负载实例 """ - # 判断是否有相同name和default_config的Load存在 - model_attr = { - Load.FieldName.exec_attr: exec_attr, - Load.FieldName.default_config: default_config - } - database_loads = cls._query(model_attr=model_attr) - # 强制要求name和default_config组合是唯一的 - if len(database_loads) != 0: - raise RuntimeError(f'Create {cls._model_class.__name__} fail because combination of exec_attr {exec_attr} ' - f'and default_config {default_config} exist.') - model_attr = { Load.FieldName.name: name, Load.FieldName.exec_attr: exec_attr, - Load.FieldName.default_config: default_config + Load.FieldName.default_config: default_config, + Load.FieldName.hash_value: hash_value } - database_load = cls._create(model_attr=model_attr) + try: + database_load = cls._create(model_attr=model_attr) + except IntegrityError as e: + raise DuplicateEntryError( + f'Create {cls._model_class.__name__} fail because hash_value {hash_value} exist.' + ) from e return Load(database_model=database_load) @classmethod - def query_by_exec_attr_and_default_config(cls, exec_attr: dict, default_config: dict): + def query_by_id(cls, load_id: str) -> list: """ - 指定名称和执行属性,查询负载 - (暂定以这两个信息进行查询,具体视现网情况而定) - - :param exec_attr: 负载的执行参数 - :param default_config: 负载默认配置 - :return: 负载实例列表 + 根据 load_id 查询 load + :param load_id: 负载id + :return: 负载列表 """ model_attr = { - Load.FieldName.exec_attr: exec_attr, - Load.FieldName.default_config: default_config + Load.FieldName.id: load_id, } database_loads = cls._query(model_attr=model_attr) - # 强制要求exec_attr和default_config组合是唯一的 - if len(database_loads) > 1: - raise RuntimeError(f'Find {cls._model_class.__name__} combination of name {exec_attr} and default_config ' - f'{default_config} is not unique.') + return [ + Load(database_model=database_load) for database_load in database_loads + ] + @classmethod + def query_by_hash_value(cls, hash_value: str): + """ + 根据hash值,查找负载实例 + :param hash_value: 负载实例 + :return: 负载实例 + """ + model_attr = { + Load.FieldName.hash_value: hash_value + } + database_loads = cls._query(model_attr=model_attr) return [ Load(database_model=database_load) for database_load in database_loads ] + @classmethod + def update_name_and_exec_attr(cls, load: Load, name: str, exec_attr: dict): + """ + 更新和时间相关的信息 + :param load: 负载 + :param name: 负载名称 + :param exec_attr: 负载执行属性 + :return: 负载实例 + """ + model_attr = { + Load.FieldName.name: name, + Load.FieldName.exec_attr: exec_attr + } + database_load = cls._update(load.database_model, model_attr=model_attr) + return Load(database_model=database_load) + @classmethod def update_best_config(cls, load: Load, best_config: dict): """ @@ -135,23 +156,3 @@ class LoadRepository(Repository): model_attr=model_attr ) return Load(database_model=database_load) - - @classmethod - def query_by_id(cls, load_id: str) -> list: - """ - 根据 load_id 查询 load - :param load_id: 负载id - :return: 负载列表 - """ - model_attr = { - Load.FieldName.id: load_id, - } - database_loads = cls._query(model_attr=model_attr) - - # ID应该是唯一的 - if len(database_loads) > 1: - raise RuntimeError(f'Find {cls._model_class.__name__} id {load_id} is not unique.') - - return [ - Load(database_model=database_load) for database_load in database_loads - ] diff --git a/omniadvisor/src/omniadvisor/repository/model/load.py b/omniadvisor/src/omniadvisor/repository/model/load.py index 01d909ea174b3b549f8cf5d95d4b4a045f866189..ca21c0f4cb6a6d96ba1bacfbc4e36a9e6ad0cd4a 100644 --- a/omniadvisor/src/omniadvisor/repository/model/load.py +++ b/omniadvisor/src/omniadvisor/repository/model/load.py @@ -58,3 +58,4 @@ class Load: test_config = 'test_config' create_time = 'create_time' tuning_needed = 'tuning_needed' + hash_value = 'hash_value' diff --git a/omniadvisor/src/omniadvisor/repository/repository.py b/omniadvisor/src/omniadvisor/repository/repository.py index 1ddaaa2dea897fd69625f29c21c3ecbdb9095387..8b118dfec4a80e0950bf16ba33d0bab04b0bf76d 100644 --- a/omniadvisor/src/omniadvisor/repository/repository.py +++ b/omniadvisor/src/omniadvisor/repository/repository.py @@ -50,7 +50,7 @@ class Repository(ABC): """ for field in cls._frozen_fields: if field in model_attr.keys(): - raise ValueError(f'The {cls._model_class.__name__} filed {field} can not be changed.') + raise ValueError(f'The {cls._model_class.__name__} field {field} can not be changed.') @classmethod def _check_required_field(cls, model_attr: dict): @@ -61,7 +61,7 @@ class Repository(ABC): """ for field in cls._required_fields: if field not in model_attr.keys(): - raise ValueError(f'The {cls._model_class.__name__} filed {field} is required.') + raise ValueError(f'The {cls._model_class.__name__} field {field} is required.') @classmethod def _check_field_format(cls, model_attr: dict): @@ -70,17 +70,17 @@ class Repository(ABC): :param model_attr: Model属性字典 """ - for field, filed_format in cls._fields_format.items(): + for field, field_format in cls._fields_format.items(): if field not in model_attr.keys(): continue - filed_types, filed_forbidden, filed_allow = filed_format - if not isinstance(model_attr[field], tuple(filed_types)): - raise TypeError(f'Type of {cls._model_class.__name__} filed {field} must in {filed_types}.') - if model_attr[field] in filed_forbidden: - raise ValueError(f'Value of {cls._model_class.__name__} filed {field} cannot be in {filed_forbidden}.') - if filed_allow and model_attr[field] not in filed_allow: - raise ValueError(f'Value of {cls._model_class.__name__} filed {field} must be in {filed_allow}.') + field_types, field_forbidden, field_allow = field_format + if not isinstance(model_attr[field], tuple(field_types)): + raise TypeError(f'Type of {cls._model_class.__name__} field {field} must in {field_types}.') + if model_attr[field] in field_forbidden: + raise ValueError(f'Value of {cls._model_class.__name__} field {field} cannot be in {field_forbidden}.') + if field_allow and model_attr[field] not in field_allow: + raise ValueError(f'Value of {cls._model_class.__name__} field {field} must be in {field_allow}.') @classmethod def _create(cls, model_attr: dict): diff --git a/omniadvisor/src/server/app/models.py b/omniadvisor/src/server/app/models.py index cfb1e071c3bcc60b06148f5b93ef902426d7457e..49f59be28678891af050d3b897ecbb49257f43be 100644 --- a/omniadvisor/src/server/app/models.py +++ b/omniadvisor/src/server/app/models.py @@ -9,18 +9,16 @@ class DatabaseLoad(models.Model): 负载模型 """ id = models.AutoField(primary_key=True) - name = models.CharField(max_length=100, null=False) + name = models.CharField(max_length=100, null=True) exec_attr = models.JSONField(null=False) default_config = models.JSONField(null=False) best_config = models.JSONField(null=True) test_config = models.JSONField(null=True) create_time = models.DateTimeField(auto_now_add=True) tuning_needed = models.BooleanField(default=True) + hash_value = models.CharField(max_length=64, null=False, unique=True) class Meta: - constraints = [ - models.UniqueConstraint(fields=['exec_attr', 'default_config'], name='load_unique') - ] db_table = 'omniadvisor_load' # 自定义表名 diff --git a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py index 69fe108c6faa420f044cfe4fc5818f644bd67b26..2bf198fcf50799c143883f867c71779687f68b07 100644 --- a/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py +++ b/omniadvisor/tests/omniadvisor/interface/test_hijack_recommend.py @@ -1,21 +1,23 @@ -import unittest -import pytest from unittest.mock import MagicMock, patch +import pytest + from common.constant import OA_CONF -from omniadvisor.interface.hijack_recommend import hijack_recommend, _process_load_config +from common.exceptions import DuplicateEntryError +from omniadvisor.interface.hijack_recommend import hijack_recommend, _process_load_config, _create_or_update_load, \ + _remove_time, _calculate_hash_value # 测试代码 -class TestHijackRecommend(unittest.TestCase): +class TestHijackRecommend: # 场景 1: 正常流程,Spark 执行成功 @patch("omniadvisor.interface.hijack_recommend._process_load_config") @patch("omniadvisor.interface.hijack_recommend.spark_run") @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") - @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") + @patch("omniadvisor.interface.hijack_recommend._create_or_update_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_successful_execution(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + def test_successful_execution(self, mock_parse_cmd, mock_create_or_update_load, mock_get_config, mock_spark_run, mock_process_config): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "test_job"} @@ -24,7 +26,7 @@ class TestHijackRecommend(unittest.TestCase): exec_config = user_config mock_parse_cmd.return_value = (exec_attr, user_config) - mock_query_load.return_value = load + mock_create_or_update_load.return_value = load mock_get_config.return_value = exec_config mock_spark_run.return_value = (MagicMock(status="success"), "job output") @@ -37,9 +39,9 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._process_load_config") @patch("omniadvisor.interface.hijack_recommend.spark_run") @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") - @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") + @patch("omniadvisor.interface.hijack_recommend._create_or_update_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_fallback_to_safe_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + def test_fallback_to_safe_config(self, mock_parse_cmd, mock_create_or_update_load, mock_get_config, mock_spark_run, mock_process_config, mock_multiprocess): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job_name"} @@ -48,7 +50,7 @@ class TestHijackRecommend(unittest.TestCase): exec_config = {"spark.executor.memory": "8g"} mock_parse_cmd.return_value = (exec_attr, user_config) - mock_query_load.return_value = load + mock_create_or_update_load.return_value = load mock_get_config.return_value = exec_config mock_spark_run.side_effect = [ @@ -67,9 +69,9 @@ class TestHijackRecommend(unittest.TestCase): @patch("omniadvisor.interface.hijack_recommend._process_load_config") @patch("omniadvisor.interface.hijack_recommend.spark_run") @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") - @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") + @patch("omniadvisor.interface.hijack_recommend._create_or_update_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_failed_user_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + def test_failed_user_config(self, mock_parse_cmd, mock_create_or_update_load, mock_get_config, mock_spark_run, mock_process_config): argv = ["--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job_name"} @@ -78,7 +80,7 @@ class TestHijackRecommend(unittest.TestCase): exec_config = user_config mock_parse_cmd.return_value = (exec_attr, user_config) - mock_query_load.return_value = load + mock_create_or_update_load.return_value = load mock_get_config.return_value = exec_config mock_spark_run.return_value = (MagicMock(status="fail"), "fail output") @@ -87,23 +89,14 @@ class TestHijackRecommend(unittest.TestCase): assert mock_spark_run.call_count == 1 mock_process_config.assert_not_called() - # 场景 4: 缺少任务名称 - @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_missing_task_name(self, mock_parse_cmd): - argv = ["--conf", "spark.executor.memory=4g"] - mock_parse_cmd.return_value = ({}, {"spark.executor.memory": "4g"}) - - with pytest.raises(ValueError, match="Task name not in Spark submit cmd"): - hijack_recommend(argv) - # 场景 5: 执行的是 test_config,触发 _process_load_config @patch('omniadvisor.service.spark_service.spark_run.multiprocessing.Process') @patch("omniadvisor.interface.hijack_recommend._process_load_config") @patch("omniadvisor.interface.hijack_recommend.spark_run") @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") - @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") + @patch("omniadvisor.interface.hijack_recommend._create_or_update_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_process_test_config(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run, + def test_process_test_config(self, mock_parse_cmd, mock_create_or_update_load, mock_get_config, mock_spark_run, mock_process_config, mock_multiprocess): argv = ["--class", "Job", "--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job"} @@ -113,7 +106,7 @@ class TestHijackRecommend(unittest.TestCase): exec_config = test_config mock_parse_cmd.return_value = (exec_attr, user_config) - mock_query_load.return_value = load + mock_create_or_update_load.return_value = load mock_get_config.return_value = exec_config mock_spark_run.return_value = (MagicMock(status=OA_CONF.ExecStatus.success), "job output") mock_process = MagicMock() @@ -129,11 +122,11 @@ class TestHijackRecommend(unittest.TestCase): hijack_recommend(["--conf", "x"]) # 场景 7: 创建 Load 异常 - @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") + @patch("omniadvisor.interface.hijack_recommend._create_or_update_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_query_create_load_raises(self, mock_parse_cmd, mock_query_load): + def test_query_create_load_raises(self, mock_parse_cmd, mock_create_or_update_load): mock_parse_cmd.return_value = ({"name": "job"}, {}) - mock_query_load.side_effect = Exception("db error") + mock_create_or_update_load.side_effect = Exception("db error") with pytest.raises(Exception, match="db error"): hijack_recommend(["--conf", "x"]) @@ -141,16 +134,16 @@ class TestHijackRecommend(unittest.TestCase): # 场景 8: spark_run 抛异常 @patch("omniadvisor.interface.hijack_recommend.spark_run") @patch("omniadvisor.interface.hijack_recommend._get_exec_config_from_load") - @patch("omniadvisor.interface.hijack_recommend._query_or_create_load") + @patch("omniadvisor.interface.hijack_recommend._create_or_update_load") @patch("omniadvisor.interface.hijack_recommend.SparkCMDParser.parse_cmd") - def test_spark_run_raises(self, mock_parse_cmd, mock_query_load, mock_get_config, mock_spark_run): + def test_spark_run_raises(self, mock_parse_cmd, mock_create_or_update_load, mock_get_config, mock_spark_run): argv = ["--conf", "spark.executor.memory=4g"] exec_attr = {"name": "job"} config = {"spark.executor.memory": "4g"} load = MagicMock(test_config={}, default_config=config) mock_parse_cmd.return_value = (exec_attr, config) - mock_query_load.return_value = load + mock_create_or_update_load.return_value = load mock_get_config.return_value = config mock_spark_run.side_effect = RuntimeError("spark failed") @@ -212,3 +205,109 @@ class TestHijackRecommend(unittest.TestCase): _process_load_config(load=load, config=test_config) mock_update_test_config.assert_not_called() mock_tuning_result_history.refresh_best_config.assert_not_called() + + @patch("omniadvisor.repository.load_repository.LoadRepository.query_by_hash_value") + @patch("omniadvisor.repository.tuning_record_repository.TuningRecordRepository.create") + @patch("omniadvisor.repository.load_repository.LoadRepository.create") + def test_create_or_update_load_when_create_success(self, mock_load_create, mock_tuning_record_create, + mock_query_by_hash_value): + name = "test_exec" + exec_attr = {"name": name, "cmd": "run"} + default_config = {"param": 123} + hash_value = _calculate_hash_value(exec_attr, default_config) + mock_load_create.return_value = "created_object" + mock_query_by_hash_value.return_value = [] + result = _create_or_update_load(exec_attr, default_config) + mock_load_create.assert_called_once_with(name, exec_attr, default_config, hash_value) + mock_tuning_record_create.assert_called_once() + assert result == "created_object" + + # 测试:创建抛出异常 -> 更新路径 + @patch("omniadvisor.repository.load_repository.LoadRepository.update_name_and_exec_attr") + @patch("omniadvisor.repository.load_repository.LoadRepository.query_by_hash_value") + @patch("omniadvisor.repository.load_repository.LoadRepository.create", side_effect=DuplicateEntryError) + def test_create_or_update_load_when_create_failed(self, mock_create, mock_query_by_hash_value, mock_update): + name = "test_exec" + exec_attr = {"name": name, "cmd": "run"} + default_config = {"param": 123} + hash_value = _calculate_hash_value(exec_attr, default_config) + mock_load = MagicMock() + + mock_query_by_hash_value.return_value = [mock_load] + mock_update.return_value = "updated_object" + + result = _create_or_update_load(exec_attr, default_config) + + mock_query_by_hash_value.assert_called_once_with(hash_value) + mock_update.assert_called_once_with(mock_load, name, exec_attr) + assert result == "updated_object" + + def test_remove_date_yyyy_mm_dd(self): + input_text = "任务截止时间是 2025-08-01" + expected = "任务截止时间是 {date}" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_yyyy_mm_dd_hour(self): + input_text = "请在 2025-08-01 14 前完成" + expected = "请在 {date} 前完成" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_yyyy_slash_mm_slash_dd(self): + input_text = "报告日期为2025/08/01,请查阅" + expected = "报告日期为{date},请查阅" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_yyyy_slash_mm_slash_dd_hour(self): + input_text = "开始于2025/08/01 09,持续3小时" + expected = "开始于{date},持续3小时" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_yyyy_dot_mm_dot_dd(self): + input_text = "会议安排:2025.08.01 开始" + expected = "会议安排:{date} 开始" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_yyyy_dot_mm_dot_dd_hour(self): + input_text = "演讲时间:2025.08.01 17" + expected = "演讲时间:{date}" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_compact_format(self): + input_text = "任务编号:2025080114,已完成" + expected = "任务编号:{date},已完成" + assert _remove_time(input_text) == expected.strip() + + def test_remove_multiple_dates(self): + input_text = "事件发生在2025-08-01 09,与2025/08/01 10无关" + expected = "事件发生在{date},与{date}无关" + assert _remove_time(input_text) == expected.strip() + + def test_remove_when_no_date_present(self): + input_text = "这是一段没有时间信息的文本" + expected = input_text + assert _remove_time(input_text) == expected.strip() + + def test_remove_from_empty_string(self): + input_text = "" + expected = "" + assert _remove_time(input_text) == expected.strip() + + def test_remove_from_spaces_only(self): + input_text = " " + expected = "" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_at_start(self): + input_text = "2025-08-01 是个重要日子" + expected = "{date} 是个重要日子" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_at_end(self): + input_text = "截止日期是 2025/08/01" + expected = "截止日期是 {date}" + assert _remove_time(input_text) == expected.strip() + + def test_remove_date_embedded_in_filename(self): + input_text = "文件名:report_2025080112_final.pdf" + expected = "文件名:report_{date}_final.pdf" + assert _remove_time(input_text) == expected.strip() diff --git a/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py b/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py index 46b66689dd38d78ac516c5bda3fc5a38b9c1ca5a..fd4000d41ef6b7f3eb269b414ee44cd883bddaf8 100644 --- a/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_exam_record_repository.py @@ -1,6 +1,7 @@ import pytest from datetime import datetime +from omniadvisor.interface.hijack_recommend import _calculate_hash_value from .common import create_table, delete_table from common.constant import OA_CONF from server.app.models import ( @@ -50,10 +51,12 @@ class TestExamRecordRepository: DatabaseLoad.objects.all().delete() # 构建测试数据 + hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, user_config={"param1": "value2"}) self.load = LoadRepository.create( name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, - default_config={'param1': 'value2'} + default_config={'param1': 'value2'}, + hash_value=hash_value ) self.config = {'param1': 'value2'} self.tuning_record = TuningRecordRepository.create( diff --git a/omniadvisor/tests/omniadvisor/repository/test_load_repository.py b/omniadvisor/tests/omniadvisor/repository/test_load_repository.py index afad5b3debb781c379d7b8f99830a1e71586ba77..745cc1e6828c2dcbd344a450588143392faa9c25 100644 --- a/omniadvisor/tests/omniadvisor/repository/test_load_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_load_repository.py @@ -1,5 +1,6 @@ import pytest +from omniadvisor.interface.hijack_recommend import _calculate_hash_value from .common import create_table from .common import delete_table from omniadvisor.repository.load_repository import DatabaseLoad @@ -32,12 +33,14 @@ class TestLoadRepository: load = LoadRepository.create( name='example_name', exec_attr={'name': 'example_name', 'cpu': 4}, - default_config={"param1": "value2"} + default_config={'param1': 'value2'}, + hash_value='0607' ) assert isinstance(load, Load) assert load.database_model.name == 'example_name' assert load.database_model.exec_attr == {'name': 'example_name', 'cpu': 4} - assert load.database_model.default_config == {"param1": "value2"} + assert load.database_model.default_config == {'param1': 'value2'} + assert load.database_model.hash_value == '0607' # 查询数据库中是否有新增数据 database_loads = DatabaseLoad.objects.filter(name='example_name') @@ -45,72 +48,76 @@ class TestLoadRepository: # 输入无效数据 with pytest.raises(ValueError): - LoadRepository.create(name="", exec_attr={}, default_config={}) + LoadRepository.create(name='', exec_attr={}, default_config={}, hash_value='') - def test_query(self): + def test_query_by_hash_value(self): # 创建测试数据 - load = LoadRepository.create( + hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, user_config={'param1': 'value2'}) + + LoadRepository.create( name='example_name', - exec_attr={'name': "example_name", 'cpu': 4}, - default_config={"param1": "value2"} + exec_attr={'name': 'example_name', 'cpu': 4}, + default_config={'param1': 'value2'}, + hash_value=hash_value ) # 查询负载 - loads = LoadRepository.query_by_exec_attr_and_default_config( - exec_attr={'name': "example_name", 'cpu': 4}, - default_config={"param1": "value2"} - ) + loads = LoadRepository.query_by_hash_value(hash_value) assert len(loads) == 1 - assert loads[0].exec_attr == {'name': "example_name", 'cpu': 4} + assert loads[0].exec_attr == {'name': 'example_name', 'cpu': 4} # 查询失败 - loads = LoadRepository.query_by_exec_attr_and_default_config( - exec_attr={"name": "wrong_name"}, - default_config={"param1": "value2"} - ) + wrong_hash_value = _calculate_hash_value(exec_attr={'name': 'wrong_name'}, user_config={'param1': 'value2'}) + loads = LoadRepository.query_by_hash_value(wrong_hash_value) assert len(loads) == 0 def test_update_best_config(self): # 创建测试数据 + hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, user_config={'param1': 'value2'}) load = LoadRepository.create( name='example_name', - exec_attr={'name': "example_name", 'cpu': 4}, - default_config={"param1": "value2"} + exec_attr={'name': 'example_name', 'cpu': 4}, + default_config={'param1': 'value2'}, + hash_value=hash_value ) # 更新最优配置 updated_load = LoadRepository.update_best_config( load=load, - best_config={"param2": "value2"} + best_config={'param2': 'value2'} ) - assert updated_load.best_config == {"param2": "value2"} + assert updated_load.best_config == {'param2': 'value2'} def test_update_test_config(self): # 创建测试数据 + hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, user_config={'param1': 'value2'}) load = LoadRepository.create( name='example_name', - exec_attr={'name': "example_name", 'cpu': 4}, - default_config={"param1": "value2"} + exec_attr={'name': 'example_name', 'cpu': 4}, + default_config={'param1': 'value2'}, + hash_value=hash_value ) # 更新测试配置 updated_load = LoadRepository.update_test_config( load=load, - test_config={"param2": "value2"} + test_config={'param2': 'value2'} ) - assert updated_load.test_config == {"param2": "value2"} + assert updated_load.test_config == {'param2': 'value2'} def test_update_tuning_needed(self): # 创建测试数据 + hash_value = _calculate_hash_value(exec_attr={'name': 'example_name', 'cpu': 4}, user_config={'param1': 'value2'}) load = LoadRepository.create( name='example_name', - exec_attr={'name': "example_name", 'cpu': 4}, - default_config={"param1": "value2"} + exec_attr={'name': 'example_name', 'cpu': 4}, + default_config={'param1': 'value2'}, + hash_value=hash_value ) # 更新测试配置 updated_load = LoadRepository.update_test_config( load=load, - test_config={"is_need_tuning": True} + test_config={'is_need_tuning': True} ) - assert updated_load.test_config == {"is_need_tuning": True} + assert updated_load.test_config == {'is_need_tuning': True} diff --git a/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py b/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py index 4c927c016dea46a25b20c7b92b7f046d5b637938..9283d8e3c9a4f066a0ce6a03b3aea11bb115d698 100644 --- a/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py +++ b/omniadvisor/tests/omniadvisor/repository/test_tuning_record_repository.py @@ -1,3 +1,4 @@ +from omniadvisor.interface.hijack_recommend import _calculate_hash_value from .common import create_table, delete_table from common.constant import OA_CONF from server.app.models import DatabaseLoad, DatabaseTuningRecord @@ -32,10 +33,12 @@ class TestTaskRepository: def test_create(self): # 创建测试数据 + hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, user_config={"param1": "value2"}) load = LoadRepository.create( name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, - default_config={"param1": "value2"} + default_config={"param1": "value2"}, + hash_value=hash_value ) # 创建TuningRecord @@ -63,10 +66,12 @@ class TestTaskRepository: def test_query_by_load(self): # 创建测试数据 + hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, user_config={"param1": "value2"}) load = LoadRepository.create( name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, - default_config={"param1": "value2"} + default_config={"param1": "value2"}, + hash_value=hash_value ) tuning_record = TuningRecordRepository.create( load=load, @@ -84,10 +89,12 @@ class TestTaskRepository: def test_query_by_load_and_config(self): # 创建测试数据 + hash_value = _calculate_hash_value(exec_attr={'cmd': 'spark-sql -f test.sql'}, user_config={"param1": "value2"}) load = LoadRepository.create( name='test', exec_attr={'cmd': 'spark-sql -f test.sql'}, - default_config={"param1": "value2"} + default_config={"param1": "value2"}, + hash_value=hash_value ) tuning_record = TuningRecordRepository.create( load=load,