diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 560d939b345e169a84dd6a06f58749115e93333b..039253180f2afbe43c5a895191c10b39ed5b420b 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -773,6 +773,11 @@ class MonitorConst:
     DEFAULT_STEP_INTERVAL = 1
     OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"]
+    OP_MONVIS_SUPPORTED = [
+        "norm", "min", "max", "zeros", "nans", "mean",
+        "entropy", "softmax_max", "sr", "kernel_norm", "std_x", "jacobian",
+        "proxy", "token_similarity"
+    ]
     MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
     DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
     DATABASE = "database"
diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b5fcb2b86c7e5322ef5fbc191f420b78993374
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sqlite3
+from typing import List, Tuple, Dict, Any
+from functools import wraps
+
+from msprobe.core.common.log import logger
+from msprobe.core.common.file_utils import check_path_before_create, change_mode
+from msprobe.core.common.const import FileCheckConst
+
+
+def _db_operation(func):
+    """Decorator for database operations; manages the connection automatically."""
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        conn, curs = None, None
+        try:
+            conn, curs = self._get_connection()
+            result = func(self, conn, curs, *args, **kwargs)
+            return result  # explicitly return the normal result
+
+        except sqlite3.Error as err:
+            logger.error(f"Database operation failed: {err}")
+            if conn:
+                conn.rollback()
+            return None  # explicitly return None on failure
+
+        finally:
+            self._release_connection(conn, curs)
+    return wrapper
+
+
+class DBManager:
+    """
+    Database manager that wraps common SQLite operations.
+    """
+
+    DEFAULT_FETCH_SIZE = 10000
+    DEFAULT_INSERT_SIZE = 10000
+    MAX_ROW_COUNT = 100000000
+
+    def __init__(self, db_path: str):
+        """
+        Initialize the DBManager.
+        :param db_path: path to the database file
+        """
+        self.db_path = db_path
+
+    @staticmethod
+    def _get_where_sql(where: Dict[str, Any]) -> Tuple[str, Tuple]:
+        if not where:
+            return "", tuple()
+
+        where_clauses = []
+        where_values = []
+        for col, val in where.items():
+            where_clauses.append(f"{col} = ?")
+            where_values.append(val)
+        where_sql = " WHERE " + " AND ".join(where_clauses)
+        return where_sql, tuple(where_values)
+
+    @_db_operation
+    def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+                    table_name: str, data: List[Tuple], key_list: List[str] = None) -> int:
+        """
+        Insert data in batches.
+        :param table_name: table name
+        :param data: list of row tuples to insert
+        :param key_list: optional column names matching each row tuple
+        :return: number of inserted rows
+        """
+        if not data:
+            return 0
+        columns = len(data[0])
+        if key_list and columns != len(key_list):
+            raise ValueError(
+                f"When inserting into table 
{table_name}, the length of key list ({key_list}) "
+                f"does not match the data ({columns}).")
+
+        batch_size = self.DEFAULT_INSERT_SIZE
+        placeholders = ", ".join(["?"] * columns)
+        if key_list:
+            keys = ", ".join(key_list)
+            sql = f"INSERT OR IGNORE INTO {table_name} ({keys}) VALUES ({placeholders})"
+        else:
+            sql = f"INSERT OR IGNORE INTO {table_name} VALUES ({placeholders})"
+
+        inserted_rows = 0
+        for i in range(0, len(data), batch_size):
+            batch = data[i:i + batch_size]
+            curs.executemany(sql, batch)
+            inserted_rows += curs.rowcount
+
+        conn.commit()
+        return inserted_rows
+
+    @_db_operation
+    def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+                    table_name: str,
+                    columns: List[str] = None,
+                    where: dict = None) -> List[Dict]:
+        """
+        Query data.
+        :param table_name: table name
+        :param columns: columns to select
+        :param where: WHERE conditions
+        :return: query results as a list of dicts
+        """
+
+        cols = ", ".join(columns) if columns else "*"
+        sql = f"SELECT {cols} FROM {table_name}"
+
+        where_sql, where_params = self._get_where_sql(where)
+        curs.execute(sql + where_sql, where_params)
+
+        return [dict(row) for row in curs.fetchall()]
+
+    @_db_operation
+    def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+                    table_name: str, updates: Dict[str, Any],
+                    where: dict = None) -> int:
+        """
+        Update data.
+        :param table_name: table name
+        :param updates: columns and values to update
+        :param where: WHERE conditions
+        :return: number of affected rows
+        """
+        set_clause = ", ".join([f"{k} = ?" for k in updates.keys()])
+        sql = f"UPDATE {table_name} SET {set_clause}"
+
+        params = tuple(updates.values())
+
+        where_sql, where_params = self._get_where_sql(where)
+
+        curs.execute(sql + where_sql, params + where_params)
+        conn.commit()
+        return curs.rowcount
+
+    @_db_operation
+    def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+                    sql: str, params: Tuple = None) -> List[Dict]:
+        """
+        Execute a custom SQL statement.
+        :param sql: SQL statement
+        :param params: statement parameters
+        :return: query results for SELECT statements, else an empty list
+        """
+        curs.execute(sql, params or ())
+        if sql.strip().upper().startswith("SELECT"):
+            return [dict(row) for row in curs.fetchall()]
+        conn.commit()
+        return []
+
+    def table_exists(self, table_name: str) -> bool:
+        """
+        :param table_name: table name
+        :return: True if the table exists
+        """
+        result = self.select_data(
+            table_name="sqlite_master",
+            columns=["name"],
+            where={"type": "table", "name": table_name}
+        )
+        return len(result) > 0
+
+    @_db_operation
+    def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+                          sql_commands: List[str]) -> List[List[Dict]]:
+        """
+        Execute multiple SQL statements in one batch.
+        :param sql_commands: [sql1, sql2, ...]
+        :return: a result list for each SELECT statement
+        """
+        results = []
+        for sql in sql_commands:
+            curs.execute(sql)
+            if sql.strip().upper().startswith("SELECT"):
+                results.append([dict(row) for row in curs.fetchall()])
+        conn.commit()
+        return results
+
+    def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]:
+        """Open a database connection and cursor."""
+        check_path_before_create(self.db_path)
+        try:
+            conn = sqlite3.connect(self.db_path)
+            conn.row_factory = sqlite3.Row  # Row factory so rows convert to dicts
+            curs = conn.cursor()
+            return conn, curs
+        except sqlite3.Error as err:
+            logger.error(f"Database connection failed: {err}")
+            raise
+
+    def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None:
+        """Release the database connection."""
+        try:
+            if curs is not None:
+                curs.close()
+            if conn is not None:
+                conn.close()
+        except sqlite3.Error as err:
+            logger.error(f"Failed to release database connection: {err}")
+        change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY)
diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
new file mode 100644
index 0000000000000000000000000000000000000000..05d21d604e8198a56e7f0baeabc305a45a7c05f2
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
@@ -0,0 +1,351 @@
+# Copyright (c) 2025-2026, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
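+
+# Reviewer note: a minimal, hypothetical invocation of this module, added for
+# illustration only (the monitor_path layout and data types are assumptions,
+# based on CSV_FILE_PATTERN and get_target_output_dir below, not part of the API):
+#
+#     from msprobe.core.monitor.csv2db import CSV2DBConfig, csv2db
+#
+#     config = CSV2DBConfig(monitor_path="./monitor_output",
+#                           data_type_list=["actv", "grad_reduced"],
+#                           process_num=4, step_partition=500)
+#     csv2db(config)  # writes monitor_metrics.db under config.output_dirpath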
+ +import datetime +import os +import re +from collections import OrderedDict, defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import pytz +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.file_utils import (create_directory, read_csv, + recursive_chmod, remove_path) +from msprobe.core.common.log import logger +from msprobe.core.common.utils import is_int +from msprobe.core.monitor.db_utils import MonitorDB, update_ordered_dict +from msprobe.core.monitor.utils import get_target_output_dir +from tqdm import tqdm + +# Constants +all_data_type_list = [ + "actv", "actv_grad", "exp_avg", "exp_avg_sq", + "grad_unreduced", "grad_reduced", "param_origin", "param_updated", + "linear_hook", "norm_hook", "proxy_model", "token_hook", "attention_hook" +] +DEFAULT_INT_VALUE = 0 +MAX_PROCESS_NUM = 128 +CSV_FILE_PATTERN = r"(\w+)_(\d+)-(\d+)\.csv" +BATCH_SIZE = 10000 + + +@dataclass +class CSV2DBConfig: + """Configuration for CSV to database conversion""" + monitor_path: str + time_start: Optional[str] = None + time_end: Optional[str] = None + process_num: int = 1 + data_type_list: Optional[List[str]] = None + output_dirpath: Optional[str] = None + step_partition: int = 500 + + +def validate_process_num(process_num: int) -> None: + """Validate process number parameter""" + if not is_int(process_num) or process_num <= 0: + raise ValueError("process_num must be a positive integer") + if process_num > MAX_PROCESS_NUM: + raise ValueError(f"Maximum supported process_num is {MAX_PROCESS_NUM}") + + +def validate_step_partition(step_partition: int) -> None: + """Validate step partition parameter""" + if not is_int(step_partition) or step_partition <= 0: + raise ValueError("step_partition must be a positive integer") + + +def validate_data_type_list(data_type_list: Optional[List[str]]) -> None: + """Validate data type list parameter""" + if data_type_list is None or not data_type_list: + logger.info(f"Using default data types: {all_data_type_list}") + return + + if not isinstance(data_type_list, list): + raise ValueError("data_type_list must be a list") + + invalid_types = [t for t in data_type_list if t not in all_data_type_list] + if invalid_types: + raise ValueError(f"Unsupported data types: {invalid_types}") + + +def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: + """Pre-scan files for a single rank to collect metadata""" + metrics = set() + min_step = None + max_step = 0 + metric_stats = defaultdict(set) + targets = OrderedDict() + + for file_path in files: + file_name = os.path.basename(file_path) + match = re.match(CSV_FILE_PATTERN, file_name) + if not match: + continue + + metric_name, step_start, step_end = match.groups() + step_start, step_end = int(step_start), int(step_end) + + metrics.add(metric_name) + min_step = min( + step_start if min_step is None else min_step, step_start) + max_step = max(max_step, step_end) + + data = read_csv(file_path) + stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED] + metric_stats[metric_name].update(stats) + + for _, row in data.iterrows(): + name = row[MonitorConst.HEADER_NAME] + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + target = (name, vpp_stage, micro_step) + if target not in targets: + targets[target] = None + + return { + 'max_rank': int(rank), + 'metrics': metrics, + 'min_step': min_step, + 'max_step': max_step, + 'metric_stats': 
metric_stats, + 'targets': list(targets.keys()) + } + + +def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1): + """Pre-scan all targets, metrics, and statistics""" + logger.info("Scanning dimensions...") + rank_files = defaultdict(list) + + # Collect files for each rank + for rank, dir_path in data_dirs.items(): + files = os.listdir(dir_path) + for file in files: + match = re.match(CSV_FILE_PATTERN, file) + if not match: + continue + metric_name, _, _ = match.groups() + if metric_name not in data_type_list: + continue + rank_files[rank].append(os.path.join(dir_path, file)) + + # Parallel pre-scan + with ProcessPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit(_pre_scan_single_rank, rank, files): rank + for rank, files in rank_files.items() + } + + results = [] + with tqdm(total=len(futures), desc="Pre-scanning ranks") as pbar: + for future in as_completed(futures): + rank = futures[future] + try: + result = future.result() + results.append(result) + except Exception as e: + logger.error( + f"Error pre-scanning rank {rank}: {str(e)}") + pbar.update(1) + + # Aggregate results + targets = OrderedDict() + metrics = set() + min_step = None + max_step = 0 + max_rank = 0 + metric_stats = defaultdict(set) + + for rank_result in results: + max_rank = max(max_rank, rank_result['max_rank']) + metrics.update(rank_result['metrics']) + min_step = min( + min_step if min_step is not None else rank_result['min_step'], + rank_result['min_step'] + ) + max_step = max(max_step, rank_result['max_step']) + + for metric, stats in rank_result['metric_stats'].items(): + metric_stats[metric].update(stats) + + targets = update_ordered_dict(targets, rank_result['targets']) + + monitor_db.insert_dimensions( + targets, metrics, metric_stats, min_step=min_step, max_step=max_step) + monitor_db.update_global_stats( + max_rank=max_rank, min_step=min_step, max_step=max_step) + return rank_files + + +def process_single_rank( + task: Tuple[int, List[str]], + metric_id_dict: Dict[str, Tuple[int, List[str]]], + target_dict: Dict[Tuple[str, int, int], int], + step_partition_size: int, + db_path: str +) -> int: + """Process data import for a single rank""" + rank, files = task + db = MonitorDB(db_path, step_partition_size=step_partition_size) + total_inserted = 0 + table_batches = defaultdict(list) + + for file in files: + filename = os.path.basename(file) + match = re.match(CSV_FILE_PATTERN, filename) + if not match: + continue + + metric_name, _, _ = match.groups() + metric_info = metric_id_dict.get(metric_name) + if not metric_info: + continue + + metric_id, stats = metric_info + + for row_id, row in read_csv(file).iterrows(): + try: + # Parse row data + name = row.get(MonitorConst.HEADER_NAME) + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + target_id = target_dict.get((name, vpp_stage, micro_step)) + if not target_id: + continue + + step = int(row['step']) + table_name, _, _ = db.get_metric_table_name(metric_id, step) + # Prepare row data + row_data = [rank, step, target_id] + row_data.extend( + float(row[stat]) if stat in row else None + for stat in stats + ) + table_batches[table_name].append(tuple(row_data)) + + # Batch insert when threshold reached + if len(table_batches[table_name]) >= BATCH_SIZE: + inserted = db.insert_rows( + table_name, table_batches[table_name]) + if inserted is not None: + total_inserted += inserted + table_batches[table_name] = [] + + except (ValueError, KeyError) 
as e:
+                logger.error(
+                    f"Failed to parse CSV row | file={file}:{row_id + 2} | error={str(e)}")
+                continue
+
+    # Insert remaining data
+    for table_name, batch in table_batches.items():
+        if batch:
+            inserted = db.insert_rows(table_name, batch)
+            if inserted is not None:
+                total_inserted += inserted
+
+    logger.info(f"Rank {rank} inserted {total_inserted} rows")
+    return total_inserted
+
+
+def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 4) -> bool:
+    """Main entry point for importing data into the database"""
+    # 1. Pre-scan to get rank tasks
+    monitor_db.init_schema()
+    rank_tasks = _pre_scan(monitor_db, data_dirs, data_type_list, workers)
+    if not rank_tasks:
+        logger.error("No valid data files found during pre-scan")
+        return False
+
+    # 2. Get metric and target mappings
+    try:
+        metric_id_dict = monitor_db.get_metric_mapping()
+        target_dict = monitor_db.get_target_mapping()
+    except Exception as e:
+        logger.error(f"Failed to get database mappings: {str(e)}")
+        return False
+
+    # 3. Process data for each rank in parallel
+    total_files = sum(len(files) for files in rank_tasks.values())
+    logger.info(f"Starting data import for {len(rank_tasks)} ranks, "
+                f"{total_files} files...")
+
+    with ProcessPoolExecutor(max_workers=workers) as executor:
+        futures = {
+            executor.submit(
+                process_single_rank,
+                (rank, files),
+                metric_id_dict,
+                target_dict,
+                monitor_db.step_partition_size,
+                monitor_db.db_path): rank
+            for rank, files in rank_tasks.items()
+        }
+
+        with tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar:
+            for future in pbar:
+                rank = futures[future]
+                try:
+                    inserted = future.result()
+                    pbar.set_postfix_str(
+                        f"Rank {rank}: inserted {inserted} rows")
+                except Exception as e:
+                    logger.error(
+                        f"Failed to process Rank {rank}: {str(e)}")
+                    return False
+    return True
+
+
+def csv2db(config: CSV2DBConfig) -> None:
+    """Main function to convert CSV files to a database"""
+    validate_process_num(config.process_num)
+    validate_step_partition(config.step_partition)
+    validate_data_type_list(config.data_type_list)
+
+    target_output_dirs = get_target_output_dir(
+        config.monitor_path, config.time_start, config.time_end)
+
+    if config.output_dirpath is None:
+        local_tz = pytz.timezone("Asia/Shanghai")
+        cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S")
+        config.output_dirpath = os.path.join(
+            config.monitor_path, f"{cur_time}-csv2db")
+
+    create_directory(config.output_dirpath)
+    db_path = os.path.join(config.output_dirpath, "monitor_metrics.db")
+
+    if os.path.exists(db_path):
+        logger.warning(f"Existing database {db_path} will be removed and rebuilt")
+        remove_path(db_path)
+
+    db = MonitorDB(db_path, step_partition_size=config.step_partition)
+
+    result = import_data(
+        db,
+        target_output_dirs,
+        config.data_type_list if config.data_type_list else all_data_type_list,
+        workers=config.process_num
+    )
+    recursive_chmod(config.output_dirpath)
+    if result:
+        logger.info(f"Data import completed. Output saved to: {config.output_dirpath}")
+    else:
+        logger.warning(
+            f"Data import may be incomplete. 
Output directory: {config.output_dirpath} "
+            f"(some records might have failed)"
+        )
diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b135694c420679c1cede2846e0d241b31171a6d1
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+from collections.abc import Iterable
+from typing import Dict, List, Optional, Set, Tuple
+
+from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.db_manager import DBManager
+
+
+def update_ordered_dict(main_dict: OrderedDict, new_list: List) -> OrderedDict:
+    """Update ordered dictionary with new items"""
+    for item in new_list:
+        if item not in main_dict:
+            main_dict[item] = None
+    return main_dict
+
+
+def get_ordered_stats(stats: Iterable) -> List[str]:
+    """Get statistics in predefined order"""
+    if not isinstance(stats, Iterable):
+        return []
+    return [stat for stat in MonitorConst.OP_MONVIS_SUPPORTED if stat in stats]
+
+
+class MonitorSql:
+    """SQL definitions for the monitor database tables"""
+
+    @staticmethod
+    def create_monitoring_targets_table():
+        """Monitoring targets table"""
+        return """
+        CREATE TABLE IF NOT EXISTS monitoring_targets (
+            target_id INTEGER PRIMARY KEY AUTOINCREMENT,
+            target_name TEXT NOT NULL,
+            vpp_stage INTEGER NOT NULL,
+            micro_step INTEGER NOT NULL DEFAULT 0,
+            UNIQUE(target_name, vpp_stage, micro_step)
+        )"""
+
+    @staticmethod
+    def create_monitoring_metrics_table():
+        """Monitoring metrics table"""
+        return """
+        CREATE TABLE IF NOT EXISTS monitoring_metrics (
+            metric_id INTEGER PRIMARY KEY AUTOINCREMENT,
+            metric_name TEXT UNIQUE NOT NULL
+        )"""
+
+    @staticmethod
+    def get_metric_mapping_sql():
+        return """
+            SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats
+            FROM monitoring_metrics m
+            LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id
+            GROUP BY m.metric_id
+        """
+
+    @staticmethod
+    def create_metric_stats_table():
+        """Metric statistics table"""
+        return """
+        CREATE TABLE IF NOT EXISTS metric_stats (
+            metric_id INTEGER NOT NULL,
+            stat_name TEXT NOT NULL,
+            PRIMARY KEY (metric_id, stat_name),
+            FOREIGN KEY (metric_id) REFERENCES monitoring_metrics(metric_id)
+        ) WITHOUT ROWID"""
+
+    @staticmethod
+    def create_global_stat_table():
+        return """
+        CREATE TABLE IF NOT EXISTS global_stats (
+            stat_name TEXT PRIMARY KEY,
+            stat_value INTEGER NOT NULL
+        ) WITHOUT ROWID"""
+
+    @classmethod
+    def get_table_definition(cls, table_name=""):
+        """
+        Get table definition SQL.
+        :param table_name: table name; if empty, return definitions for all tables
+        :return: CREATE TABLE statement(s)
+        :raises ValueError: if the table name is unknown
+        """
+        table_creators = {
+            "monitoring_targets": cls.create_monitoring_targets_table,
+            "monitoring_metrics": cls.create_monitoring_metrics_table,
+            "metric_stats": cls.create_metric_stats_table,
+            "global_stats": cls.create_global_stat_table,
+        }
+        if not table_name:
+            return [table_creators.get(table, lambda
x:"")() for table in table_creators] + if table_name not in table_creators: + raise ValueError(f"Unsupported table name: {table_name}") + return table_creators[table_name]() + + @classmethod + def get_metric_table_definition(cls, table_name, stats, patition=None): + stat_columns = [f"{stat} REAL DEFAULT NULL" for stat in stats] + if patition and len(patition) == 2: + partition_start_step, partition_end_step = patition + step_column = f"""step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step} + AND {partition_end_step}),""" + else: + step_column = "step INTEGER NOT NULL" + create_sql = f""" + CREATE TABLE {table_name} ( + rank INTEGER NOT NULL, + {step_column} + target_id INTEGER NOT NULL, + {', '.join(stat_columns)}, + PRIMARY KEY (rank, step, target_id), + FOREIGN KEY (target_id) REFERENCES monitoring_targets(target_id) + ) WITHOUT ROWID + """ + return create_sql + + +class MonitorDB: + """Main class for monitoring database operations""" + + def __init__(self, db_path: str, step_partition_size: int = 500): + self.db_path = db_path + self.db_manager = DBManager(db_path) + self.step_partition_size = step_partition_size + + def get_metric_table_name(self, metric_id: int, step: int) -> str: + """Generate metric table name""" + step_start = ( + step // self.step_partition_size) * self.step_partition_size + step_end = step_start + self.step_partition_size - 1 + return f"metric_{metric_id}_step_{step_start}_{step_end}", step_start, step_end + + def init_schema(self) -> None: + """Initialize database schema""" + self.db_manager.execute_multi_sql(MonitorSql.get_table_definition()) + + # Insert initial global stats + global_stats = [ + ('max_rank', 0), + ('min_step', 0), + ('max_step', 0), + ('step_partition_size', self.step_partition_size) + ] + self.db_manager.insert_data("global_stats", global_stats) + + def insert_dimensions( + self, + targets: OrderedDict, + metrics: Set[str], + metric_stats: Dict[str, Set[str]], + min_step: Optional[int] = None, + max_step: int = None, + ) -> None: + """Insert dimension data into database""" + # Insert targets + self.db_manager.insert_data( + "monitoring_targets", + [(name, vpp_stage, micro_step) + for (name, vpp_stage, micro_step) in targets], + key_list=["target_name", "vpp_stage", "micro_step"] + ) + + # Insert metrics + self.db_manager.insert_data( + "monitoring_metrics", + [(metric,) for metric in metrics], + key_list=["metric_name"] + ) + + # Insert metric-stat relationships + for metric, stats in metric_stats.items(): + metric_id = self._get_metric_id(metric) + ordered_stats = get_ordered_stats(stats) + + self.db_manager.insert_data( + "metric_stats", + [(metric_id, stat) for stat in ordered_stats], + key_list=["metric_id", "stat_name"] + ) + + # Create metric tables for each partition + if min_step is not None and max_step is not None: + first_partition = min_step // self.step_partition_size + last_partition = max_step // self.step_partition_size + + for partition in range(first_partition, last_partition + 1): + step_start = partition * self.step_partition_size + self.create_metric_table( + metric_id, step_start, ordered_stats) + + def insert_rows(self, table_name, rows): + if not self.db_manager.table_exists(table_name): + raise RuntimeError(f"{table_name} not existed in {self.db_path}") + inserted = self.db_manager.insert_data(table_name, rows) + inserted = 0 if inserted is None else inserted + return inserted + + def create_metric_table(self, metric_id: int, step: int, stats: List[str]) -> str: + """Create metric table for a specific 
partition""" + table_name, partition_start_step, partition_end_step = self.get_metric_table_name( + metric_id, + step + ) + if self.db_manager.table_exists(table_name): + return table_name + + create_sql = MonitorSql.get_metric_table_definition( + table_name, stats, patition=( + partition_start_step, partition_end_step) + ) + self.db_manager.execute_sql(create_sql) + return table_name + + def update_global_stats(self, max_rank: int = None, min_step: Optional[int] = None, max_step: int = None) -> None: + """Update global statistics""" + updates = [ + ("max_rank", max_rank), + ("min_step", min_step), + ("max_step", max_step) + ] + for stat_name, value in updates: + if not value: + continue + self.db_manager.update_data( + table_name="global_stats", + updates={"stat_value": value}, + where={"stat_name": stat_name} + ) + + def get_metric_mapping(self) -> Dict[str, Tuple[int, List[str]]]: + """Get metric name to ID mapping with statistics""" + results = self.db_manager.execute_sql( + MonitorSql.get_metric_mapping_sql() + ) + + return { + row["metric_name"]: ( + row["metric_id"], + get_ordered_stats(row["stats"].split(",") + ) if row["stats"] else [] + ) for row in results + } + + def get_target_mapping(self) -> Dict[Tuple[str, int, int], int]: + """Get target mapping dictionary""" + results = self.db_manager.select_data( + table_name="monitoring_targets", + columns=["target_id", "target_name", "vpp_stage", "micro_step"] + ) + if not results: + return {} + return { + (row["target_name"], row["vpp_stage"], row["micro_step"]): row["target_id"] + for row in results + } + + def _get_metric_id(self, metric_name: str) -> Optional[int]: + """Get metric ID by name""" + result = self.db_manager.select_data( + table_name="monitoring_metrics", + columns=["metric_id"], + where={"metric_name": metric_name} + ) + return result[0]["metric_id"] if result else None diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..f3efde95164e40fcb0ed6a1aeba8fd21c7a98b62 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py @@ -0,0 +1,241 @@ +import unittest +import sqlite3 +import os +import tempfile +from typing import Dict, List +from unittest.mock import patch, MagicMock + +from msprobe.pytorch.common.log import logger +from msprobe.core.common.db_manager import DBManager + +class TestDBManager(unittest.TestCase): + def setUp(self): + # 创建临时数据库文件 + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + self.db_path = self.temp_db.name + self.db_manager = DBManager(self.db_path) + + # 创建测试表 + self.test_table = "test_table" + self.create_test_table() + + def tearDown(self): + # 关闭并删除临时数据库文件 + if hasattr(self, 'temp_db'): + self.temp_db.close() + os.unlink(self.db_path) + + def create_test_table(self): + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS {self.test_table} ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + value INTEGER, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) + conn.commit() + + def test_get_connection_success(self): + """测试成功获取数据库连接""" + conn, curs = self.db_manager._get_connection() + self.assertIsInstance(conn, sqlite3.Connection) + self.assertIsInstance(curs, sqlite3.Cursor) + self.db_manager._release_connection(conn, curs) + + @patch.object(logger, 'error') + def 
test_get_connection_success_failed(self, mock_logger): + """测试错误日志记录""" + with patch('sqlite3.connect', side_effect=sqlite3.Error("Test error")): + with self.assertRaises(sqlite3.Error): + self.db_manager._get_connection() + mock_logger.assert_called_with("Database connection failed: Test error") + + + def test_insert_data_basic(self): + """测试基本数据插入""" + test_data = [ + (1, "item1", 100), + (2, "item2", 200) + ] + columns = ["id", "name", "value"] + + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=columns + ) + self.assertEqual(inserted, 2) + + # 验证数据是否实际插入 + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["id", "name", "value"] + ) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["name"], "item1") + + def test_insert_data_without_keys(self): + """测试无列名的数据插入""" + test_data = [ + (3, "item3", 300, 333), + (4, "item4", 400, 333) + ] + + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=test_data + ) + self.assertEqual(inserted, 2) + + def test_insert_data_empty(self): + """测试空数据插入""" + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=[] + ) + self.assertEqual(inserted, 0) + + def test_insert_data_mismatch_keys(self): + """测试列名与数据不匹配的情况""" + test_data = [(5, "item5")] + with self.assertRaises(ValueError): + self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=["id", "name", "value"] # 多了一个列 + ) + + def test_select_data_basic(self): + """测试基本数据查询""" + # 先插入测试数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(10, "test10", 1000)], + key_list=["id", "name", "value"] + ) + + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["name", "value"], + where={"id": 10} + ) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "test10") + self.assertEqual(results[0]["value"], 1000) + + def test_select_data_no_where(self): + """测试无条件查询""" + # 插入多条数据 + test_data = [ + (20, "item20", 2000), + (21, "item21", 2100) + ] + self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=["id", "name", "value"] + ) + + results = self.db_manager.select_data( + table_name=self.test_table + ) + self.assertGreaterEqual(len(results), 2) + + def test_update_data_basic(self): + """测试基本数据更新""" + # 先插入测试数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(30, "old_name", 3000)], + key_list=["id", "name", "value"] + ) + + updated = self.db_manager.update_data( + table_name=self.test_table, + updates={"name": "new_name", "value": 3500}, + where={"id": 30} + ) + self.assertEqual(updated, 1) + + # 验证更新结果 + results = self.db_manager.select_data( + table_name=self.test_table, + where={"id": 30} + ) + self.assertEqual(results[0]["name"], "new_name") + self.assertEqual(results[0]["value"], 3500) + + def test_execute_sql_select(self): + """测试执行SELECT SQL语句""" + self.db_manager.insert_data( + table_name=self.test_table, + data=[(50, "sql_item", 5000)], + key_list=["id", "name", "value"] + ) + + results = self.db_manager.execute_sql( + sql=f"SELECT name, value FROM {self.test_table} WHERE id = ?", + params=(50,) + ) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "sql_item") + + def test_execute_sql_non_select(self): + """测试执行非SELECT SQL语句""" + # 先插入数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(60, "to_delete", 6000)], + key_list=["id", "name", "value"] + ) + + # 执行DELETE语句 + 
self.db_manager.execute_sql( + sql=f"DELETE FROM {self.test_table} WHERE id = 60" + ) + + # 验证数据已被删除 + results = self.db_manager.select_data( + table_name=self.test_table, + where={"id": 60} + ) + self.assertEqual(len(results), 0) + + def test_table_exists_true(self): + """测试表存在检查(存在的情况)""" + exists = self.db_manager.table_exists(self.test_table) + self.assertTrue(exists) + + def test_table_exists_false(self): + """测试表存在检查(不存在的情况)""" + exists = self.db_manager.table_exists("non_existent_table") + self.assertFalse(exists) + + def test_execute_multi_sql(self): + """测试批量执行多个SQL语句""" + sql_commands = [ + f"INSERT INTO {self.test_table} (id, name, value) VALUES (70, 'multi1', 7000)", + f"INSERT INTO {self.test_table} (id, name, value) VALUES (71, 'multi2', 7100)", + f"SELECT * FROM {self.test_table} WHERE id IN (70, 71)" + ] + + results = self.db_manager.execute_multi_sql(sql_commands) + + # 应该只有最后一个SELECT语句有结果 + self.assertEqual(len(results), 1) + self.assertEqual(len(results[0]), 2) + + @patch.object(logger, 'error') + def test_db_operation_decorator(self, mock_logger): + """测试数据库操作装饰器""" + # 模拟一个会失败的操作 + with patch.object(self.db_manager, '_get_connection', + side_effect=sqlite3.Error("Test error")): + result = self.db_manager.select_data(table_name=self.test_table) + self.assertIsNone(result) # 装饰器会捕获异常并返回None + mock_logger.assert_called_with("Database operation failed: Test error") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py new file mode 100644 index 0000000000000000000000000000000000000000..cac26b90c7b471e050bba4be6d5fc605cc7c2f19 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -0,0 +1,232 @@ +import unittest +import os +import tempfile +import shutil +from unittest.mock import patch, MagicMock +import pandas as pd + +from msprobe.core.monitor.csv2db import ( + CSV2DBConfig, + validate_process_num, + validate_step_partition, + validate_data_type_list, + _pre_scan_single_rank, + _pre_scan, + process_single_rank, + import_data, + csv2db, + all_data_type_list, + MAX_PROCESS_NUM, +) + + +class TestCSV2DBValidations(unittest.TestCase): + def test_validate_process_num_valid(self): + """测试有效的进程数""" + validate_process_num(1) + validate_process_num(MAX_PROCESS_NUM) + + def test_validate_process_num_invalid(self): + """测试无效的进程数""" + with self.assertRaises(ValueError): + validate_process_num(0) + with self.assertRaises(ValueError): + validate_process_num(-1) + with self.assertRaises(ValueError): + validate_process_num(MAX_PROCESS_NUM + 1) + + def test_validate_step_partition_valid(self): + """测试有效的step分区""" + validate_step_partition(1) + validate_step_partition(500) + + def test_validate_step_partition_invalid(self): + """测试无效的step分区""" + with self.assertRaises(ValueError): + validate_step_partition(0) + with self.assertRaises(ValueError): + validate_step_partition(-1) + + def test_validate_data_type_list_valid(self): + """测试有效的数据类型列表""" + validate_data_type_list(["actv", "grad_reduced"]) + validate_data_type_list(all_data_type_list[:2]) + + def test_validate_data_type_list_invalid(self): + """测试无效的数据类型列表""" + with self.assertRaises(ValueError): + validate_data_type_list(["invalid_type"]) + with self.assertRaises(ValueError): + validate_data_type_list(["actv", "invalid_type"]) + + +class TestPreScanFunctions(unittest.TestCase): + def setUp(self): + # 创建临时目录和测试CSV文件 + self.temp_dir = tempfile.mkdtemp() + self.temp_dir_rank2 = tempfile.mkdtemp() + 
self.test_csv_path_actv = os.path.join(self.temp_dir, "actv_0-100.csv")
+        self.test_csv_path_rank2_grad = os.path.join(
+            self.temp_dir_rank2, "grad_reduced_100-200.csv")
+        self.test_csv_path_rank_inv = os.path.join(
+            self.temp_dir_rank2, "invalid_metric_100-200.csv")
+
+        # Create test CSV data
+        test_data_actv = {
+            "name": ["layer1", "layer2"],
+            "vpp_stage": [0, 0],
+            "micro_step": [0, 1],
+            "step": [10, 20],
+            "min": [0.1, 0.2],
+            "max": [1.0, 2.0]
+        }
+        test_data_grad = {
+            "name": ["layer1_weight", "layer2_weight"],
+            "vpp_stage": [0, 0],
+            "micro_step": [0, 1],
+            "step": [10, 20],
+            "min": [0.1, 0.2],
+            "max": [1.0, 2.0]
+        }
+        df = pd.DataFrame(test_data_actv)
+        df.to_csv(self.test_csv_path_actv, index=False)
+        df = pd.DataFrame(test_data_grad)
+        df.to_csv(self.test_csv_path_rank2_grad, index=False)
+        df = pd.DataFrame(test_data_grad)
+        df.to_csv(self.test_csv_path_rank_inv, index=False)
+
+    def tearDown(self):
+        # Clean up both temporary directories
+        shutil.rmtree(self.temp_dir)
+        shutil.rmtree(self.temp_dir_rank2)
+
+    def test_pre_scan_single_rank(self):
+        """Pre-scan of a single rank"""
+        rank = 0
+        files = [self.test_csv_path_actv]
+        result = _pre_scan_single_rank(rank, files)
+        self.assertEqual(result["max_rank"], rank)
+        self.assertEqual(result["metrics"], {"actv"})
+        self.assertEqual(result["min_step"], 0)
+        self.assertEqual(result["max_step"], 100)
+        self.assertEqual(result["metric_stats"], {"actv": {"min", "max"}})
+        self.assertEqual(len(result["targets"]), 2)
+
+    def test_pre_scan(self):
+        """Full pre-scan flow across ranks"""
+        # Mock the MonitorDB
+        mock_db = MagicMock()
+
+        # Test data
+        data_dirs = {0: self.temp_dir, 2: self.temp_dir_rank2}
+        data_type_list = ["actv", "grad_reduced"]
+
+        result = _pre_scan(mock_db, data_dirs, data_type_list)
+
+        self.assertEqual(sorted(list(result.keys())), [0, 2])
+
+        mock_db.insert_dimensions.assert_called_once()
+        mock_db.update_global_stats.assert_called_with(
+            max_rank=2, min_step=0, max_step=200
+        )
+
+
+class TestProcessSingleRank(unittest.TestCase):
+    @patch("msprobe.core.monitor.csv2db.MonitorDB")
+    @patch("msprobe.core.monitor.csv2db.read_csv")
+    def test_process_single_rank(self, mock_read_csv, mock_db_class):
+        """Data import for a single rank"""
+        # Mock the database and mappings
+        mock_db = MagicMock()
+        mock_db_class.return_value = mock_db
+        mock_db.get_metric_table_name.return_value = (
+            "metric_1_step_0_99", 0, 99)
+        mock_db.insert_rows.return_value = 2
+
+        # Mock CSV data
+        mock_result = pd.DataFrame({
+            "name": ["layer1", "layer2"],
+            "vpp_stage": [0, 0],
+            "micro_step": [0, 1],
+            "step": [10, 20],
+            "norm": [0.1, 0.2],
+            "max": [1.0, 2.0]
+        })
+        mock_read_csv.return_value = mock_result
+
+        # Test data
+        task = (0, ["actv_10-20.csv"])
+        metric_id_dict = {"actv": (1, ["norm", "max"])}
+        target_dict = {("layer1", 0, 0): 1, ("layer2", 0, 1): 2}
+        step_partition_size = 100
+        db_path = "dummy.db"
+
+        result = process_single_rank(
+            task, metric_id_dict, target_dict, step_partition_size, db_path)
+
+        self.assertEqual(result, 2)
+        mock_db.insert_rows.assert_called_with(
+            "metric_1_step_0_99", [(0, 10, 1, 0.1, 1.0), (0, 20, 2, 0.2, 2.0)]
+        )
+
+
+class TestImportData(unittest.TestCase):
+    @patch("msprobe.core.monitor.csv2db._pre_scan")
+    def test_import_data_success(self, mock_pre_scan):
+        """Successful data import"""
+        # Mock the pre-scan result
+        mock_pre_scan.return_value = {
+            0: ["actv_10-20.csv"], 1: ["actv_10-20.csv"]}
+
+        # Mock the database
+        mock_db = MagicMock()
+        mock_db.get_metric_mapping.return_value = {"actv": (1, ["min", "max"])}
+        mock_db.get_target_mapping.return_value = {("layer1", 0, 0): 1}
+
+        # Test data
+        data_dirs = {0: "dir0", 1: "dir1"}
+        data_type_list = ["actv"]
+        workers = 2
+
+        import_data(mock_db, 
data_dirs, data_type_list, workers) + + mock_db.init_schema.assert_called_once() + mock_pre_scan.assert_called_once() + + @patch("msprobe.core.monitor.csv2db._pre_scan") + def test_import_data_no_files(self, mock_pre_scan): + """测试没有找到数据文件的情况""" + mock_pre_scan.return_value = {} + + mock_db = MagicMock() + data_dirs = {0: "dir0"} + data_type_list = ["actv"] + + result = import_data(mock_db, data_dirs, data_type_list) + + self.assertFalse(result) + mock_pre_scan.assert_called_once() + + +class TestCSV2DBMain(unittest.TestCase): + @patch("msprobe.core.monitor.csv2db.import_data") + @patch("msprobe.core.monitor.csv2db.get_target_output_dir") + @patch("msprobe.core.monitor.csv2db.create_directory") + def test_csv2db(self, mock_create_dir, mock_get_dirs, mock_import): + """测试主函数csv2db""" + # 模拟配置 + config = CSV2DBConfig( + monitor_path="test_path", + data_type_list=["actv"], + process_num=4, + step_partition=500 + ) + + # 模拟依赖函数 + mock_get_dirs.return_value = {0: "dir0", 1: "dir1"} + mock_import.return_value = True + + csv2db(config) + + mock_get_dirs.assert_called_once() + mock_create_dir.assert_called_once() + mock_import.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d25dc4b7fb5974e4dfcce5455e20dbb8731615e0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py @@ -0,0 +1,255 @@ +import unittest +import os +import re +import tempfile +from collections import OrderedDict +from unittest.mock import patch + +from msprobe.core.common.const import MonitorConst +from msprobe.core.monitor.db_utils import MonitorDB, MonitorSql, update_ordered_dict, get_ordered_stats + +def normalize_spaces(text): + return re.sub(r'\s+', ' ', text) + +class TestDBUtils(unittest.TestCase): + def test_update_ordered_dict(self): + """测试update_ordered_dict函数""" + main_dict = OrderedDict([('a', 1), ('b', 2)]) + new_list = ['b', 'c', 'd'] + + result = update_ordered_dict(main_dict, new_list) + + self.assertEqual(list(result.keys()), ['a', 'b', 'c', 'd']) + self.assertEqual(result['a'], 1) + self.assertIsNone(result['c']) + + def test_get_ordered_stats(self): + """测试get_ordered_stats函数""" + test_stats = ['stat2', 'stat1', 'stat3'] + supported_stats = ['stat1', 'stat2', 'stat3', 'stat4'] + + with patch.object(MonitorConst, 'OP_MONVIS_SUPPORTED', supported_stats): + result = get_ordered_stats(test_stats) + + self.assertEqual(result, ['stat1', 'stat2', 'stat3']) + + def test_get_ordered_stats_with_non_iterable(self): + """测试get_ordered_stats处理非可迭代对象""" + result = get_ordered_stats(123) + self.assertEqual(result, []) + + +class TestMonitorSql(unittest.TestCase): + def test_get_table_definition_all_tables(self): + """测试获取所有表定义""" + result = MonitorSql.get_table_definition() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 4) + self.assertTrue(all("CREATE TABLE" in sql for sql in result)) + + def test_get_table_definition_single_table(self): + """测试获取单个表定义""" + for table in ["monitoring_targets", "monitoring_metrics", "metric_stats", "global_stats"]: + result = MonitorSql.get_table_definition(table) + result = normalize_spaces(result) + self.assertIn(f"CREATE TABLE IF NOT EXISTS {table}", result) + + def test_get_table_definition_invalid_table(self): + """测试获取不存在的表定义""" + with self.assertRaises(ValueError): + MonitorSql.get_table_definition("invalid_table") + + def 
test_get_metric_table_definition_with_partition(self): + """测试带分区的指标表定义""" + stats = ["norm", "max"] + result = MonitorSql.get_metric_table_definition("test_metric", stats, [100, 200]) + result = normalize_spaces(result) + self.assertIn("norm REAL DEFAULT NULL", result) + self.assertIn("max REAL DEFAULT NULL", result) + self.assertIn("step INTEGER NOT NULL CHECK(step BETWEEN 100 AND 200)", result) + + def test_get_metric_mapping_sql(self): + """测试获取指标映射SQL""" + result = MonitorSql.get_metric_mapping_sql() + result = normalize_spaces(result) + self.assertIn("SELECT m.metric_id, m.metric_name", result) + self.assertIn("GROUP_CONCAT(ms.stat_name)", result) + + +class TestMonitorDB(unittest.TestCase): + def setUp(self): + # 创建临时数据库文件 + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + self.db_path = self.temp_db.name + self.monitor_db = MonitorDB(self.db_path, step_partition_size=100) + + # 初始化数据库schema + self.monitor_db.init_schema() + + def tearDown(self): + # 关闭并删除临时数据库文件 + if hasattr(self, 'temp_db'): + self.temp_db.close() + os.unlink(self.db_path) + + def test_init_schema(self): + """测试初始化数据库schema""" + # 验证表是否创建成功 + for table in ["monitoring_targets", "monitoring_metrics", "metric_stats", "global_stats"]: + self.assertTrue(self.monitor_db.db_manager.table_exists(table)) + + # 验证全局统计初始值 + results = self.monitor_db.db_manager.select_data("global_stats") + self.assertEqual(len(results), 4) + self.assertEqual(results[0]['stat_value'], 0) # max_rank + + def test_get_metric_table_name(self): + """测试生成指标表名""" + # 测试分区边界 + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 50), + ("metric_1_step_0_99", 0, 99) + ) + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 100), + ("metric_1_step_100_199", 100, 199) + ) + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 199), + ("metric_1_step_100_199", 100, 199) + ) + + def test_insert_dimensions(self): + """测试插入维度数据""" + targets = OrderedDict() + targets[("layer1", 0, 0)] = None + targets[("layer2", 0, 1)] = None + + metrics = {"metric1", "metric2"} + metric_stats = { + "metric1": {"norm", "max"}, + "metric2": {"min", "max"} + } + + self.monitor_db.insert_dimensions( + targets=targets, + metrics=metrics, + metric_stats=metric_stats, + min_step=0, + max_step=200 + ) + + # 验证目标插入 + target_results = self.monitor_db.db_manager.select_data("monitoring_targets") + self.assertEqual(len(target_results), 2) + + # 验证指标插入 + metric_results = self.monitor_db.db_manager.select_data("monitoring_metrics") + self.assertEqual(len(metric_results), 2) + + # 验证指标统计关系插入 + stat_results = self.monitor_db.db_manager.select_data("metric_stats") + self.assertEqual(len(stat_results), 4) # 2 metrics * 2 stats each + + # 验证指标表创建 + self.assertTrue(self.monitor_db.db_manager.table_exists("metric_1_step_0_99")) + self.assertTrue(self.monitor_db.db_manager.table_exists("metric_1_step_100_199")) + self.assertTrue(self.monitor_db.db_manager.table_exists("metric_2_step_0_99")) + self.assertTrue(self.monitor_db.db_manager.table_exists("metric_2_step_100_199")) + + def test_create_metric_table(self): + """测试创建指标表""" + table_name = self.monitor_db.create_metric_table( + metric_id=1, + step=50, + stats=["norm", "max"] + ) + + self.assertEqual(table_name, "metric_1_step_0_99") + self.assertTrue(self.monitor_db.db_manager.table_exists(table_name)) + + def test_update_global_stats(self): + """测试更新全局统计""" + self.monitor_db.update_global_stats( + max_rank=8, + min_step=10, + max_step=1000 + ) + + # 验证更新结果 + results = 
self.monitor_db.db_manager.select_data("global_stats") + stats = {row['stat_name']: row['stat_value'] for row in results} + self.assertEqual(stats['max_rank'], 8) + self.assertEqual(stats['min_step'], 10) + self.assertEqual(stats['max_step'], 1000) + + def test_get_metric_mapping(self): + """测试获取指标映射""" + # 先插入测试数据 + self.monitor_db.db_manager.insert_data( + "monitoring_metrics", + [("metric1",), ("metric2",)], + ["metric_name"] + ) + + # 获取metric_id + metric1_id = self.monitor_db._get_metric_id("metric1") + metric2_id = self.monitor_db._get_metric_id("metric2") + + # 插入统计关系 + self.monitor_db.db_manager.insert_data( + "metric_stats", + [(metric1_id, "norm"), (metric1_id, "max"), (metric2_id, "min")], + ["metric_id", "stat_name"] + ) + + # 测试获取映射 + mapping = self.monitor_db.get_metric_mapping() + + self.assertEqual(len(mapping), 2) + self.assertEqual(mapping["metric1"][0], metric1_id) + self.assertEqual(sorted(mapping["metric1"][1]), ["max", "norm"]) + self.assertEqual(mapping["metric2"][1], ["min"]) + + def test_get_target_mapping(self): + """测试获取目标映射""" + # 先插入测试数据 + self.monitor_db.db_manager.insert_data( + "monitoring_targets", + [("target1", 0, 0), ("target2", 0, 1)], + ["target_name", "vpp_stage", "micro_step"] + ) + + # 测试获取映射 + mapping = self.monitor_db.get_target_mapping() + + self.assertEqual(len(mapping), 2) + self.assertIn(("target1", 0, 0), mapping) + self.assertIn(("target2", 0, 1), mapping) + + def test_insert_rows(self): + """测试插入行数据""" + # 先创建测试表 + self.monitor_db.db_manager.execute_sql( + "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)" + ) + + # 测试插入 + inserted = self.monitor_db.insert_rows( + "test_table", + [(1, "item1"), (2, "item2")] + ) + + self.assertEqual(inserted, 2) + + # 验证数据 + results = self.monitor_db.db_manager.select_data("test_table") + self.assertEqual(len(results), 2) + + def test_insert_rows_table_not_exists(self): + """测试插入行数据到不存在的表""" + with self.assertRaises(RuntimeError): + self.monitor_db.insert_rows( + "non_existent_table", + [(1, "item1")] + )