diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 560d939b345e169a84dd6a06f58749115e93333b..039253180f2afbe43c5a895191c10b39ed5b420b 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -773,6 +773,11 @@ class MonitorConst:
DEFAULT_STEP_INTERVAL = 1
OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"]
+ OP_MONVIS_SUPPORTED = [
+ "norm", "min", "max", "zeros", "nans", "mean",
+ "entropy", "softmax_max", "sr", "kernel_norm", "std_x", "jacobian",
+ "proxy", "token_similarity"
+ ]
MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
DATABASE = "database"
diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..28b5fcb2b86c7e5322ef5fbc191f420b78993374
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sqlite3
+from typing import List, Tuple, Dict, Any
+from functools import wraps
+
+from msprobe.core.common.log import logger
+from msprobe.core.common.file_utils import check_path_before_create, change_mode
+from msprobe.core.common.const import FileCheckConst
+
+
+def _db_operation(func):
+    """Decorator for database operations that manages the connection automatically"""
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        conn, curs = None, None
+        try:
+            conn, curs = self._get_connection()
+            result = func(self, conn, curs, *args, **kwargs)
+            return result
+
+        except sqlite3.Error as err:
+            logger.error(f"Database operation failed: {err}")
+            if conn:
+                conn.rollback()
+            return None  # explicitly return None on failure
+
+        finally:
+            self._release_connection(conn, curs)
+    return wrapper
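+# Note: methods decorated with _db_operation receive (conn, curs) from the wrapper;
+# callers invoke them without these two arguments, e.g. DBManager(path).select_data("tbl").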
+
+
+class DBManager:
+    """
+    Database manager that wraps common sqlite operations
+    """
+
+ DEFAULT_FETCH_SIZE = 10000
+ DEFAULT_INSERT_SIZE = 10000
+ MAX_ROW_COUNT = 100000000
+
+ def __init__(self, db_path: str):
+        """
+        Initialize the DBManager
+        :param db_path: path of the sqlite database file
+        """
+ self.db_path = db_path
+
+    @staticmethod
+    def _get_where_sql(where_dict):
+        """Build a parameterized WHERE clause from a {column: value} dict"""
+        if not where_dict:
+            return "", tuple()
+
+        where_clauses = []
+        where_values = []
+        for col, val in where_dict.items():
+            where_clauses.append(f"{col} = ?")
+            where_values.append(val)
+        where_sql = " WHERE " + " AND ".join(where_clauses)
+        return where_sql, tuple(where_values)
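+    # Illustration: {"type": "table", "name": "t1"} ->
+    #   (" WHERE type = ? AND name = ?", ("table", "t1"))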
+
+ @_db_operation
+ def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+ table_name: str, data: List[Tuple], key_list: List[str] = None) -> int:
+        """
+        Insert rows in batches.
+        :param table_name: table name
+        :param data: list of row tuples to insert
+        :param key_list: optional list of column names matching each tuple
+        :return: number of inserted rows
+        """
+        if not data:
+            return 0
+        columns = len(data[0])
+        if key_list and columns != len(key_list):
+            raise ValueError(
+                f"When inserting into table {table_name}, the length of key list ({key_list}) "
+                f"does not match the number of data columns ({columns}).")
+
+ batch_size = self.DEFAULT_INSERT_SIZE
+ placeholders = ", ".join(["?"] * columns)
+ if key_list:
+ keys = ", ".join(key_list)
+ sql = f"INSERT OR IGNORE INTO {table_name} ({keys}) VALUES ({placeholders})"
+ else:
+ sql = f"INSERT OR IGNORE INTO {table_name} VALUES ({placeholders})"
+
+ inserted_rows = 0
+ for i in range(0, len(data), batch_size):
+ batch = data[i:i + batch_size]
+ curs.executemany(sql, batch)
+ inserted_rows += curs.rowcount
+
+ conn.commit()
+ return inserted_rows
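+    # Example (illustrative table and values):
+    #   DBManager("metrics.db").insert_data("metric_stats", [(1, "norm"), (1, "max")],
+    #                                       key_list=["metric_id", "stat_name"])
+    # returns the number of rows actually written (duplicates are ignored).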
+
+ @_db_operation
+ def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+ table_name: str,
+ columns: List[str] = None,
+ where: dict = None) -> List[Dict]:
+        """
+        Query data.
+        :param table_name: table name
+        :param columns: columns to select; all columns if None
+        :param where: WHERE conditions as a {column: value} dict
+        :return: list of result rows as dicts
+        """
+        cols = ", ".join(columns) if columns else "*"
+        sql = f"SELECT {cols} FROM {table_name}"
+
+        where_sql, where_params = self._get_where_sql(where)
+        curs.execute(sql + where_sql, where_params)
+
+        return [dict(row) for row in curs.fetchall()]
+
+ @_db_operation
+ def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+ table_name: str, updates: Dict[str, Any],
+ where: dict = None) -> int:
+        """
+        Update rows.
+        :param table_name: table name
+        :param updates: {column: new_value} dict of fields to update
+        :param where: WHERE conditions as a {column: value} dict
+        :return: number of affected rows
+        """
+        set_clause = ", ".join([f"{k} = ?" for k in updates.keys()])
+        sql = f"UPDATE {table_name} SET {set_clause}"
+
+        params = tuple(updates.values())
+
+        where_sql, where_params = self._get_where_sql(where)
+
+        curs.execute(sql + where_sql, params + where_params)
+        conn.commit()
+        return curs.rowcount
+
+ @_db_operation
+ def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+ sql: str, params: Tuple = None) -> List[Dict]:
+        """
+        Execute a custom SQL statement.
+        :param sql: SQL statement
+        :param params: statement parameters
+        :return: query results for SELECT statements, otherwise an empty list
+        """
+ curs.execute(sql, params or ())
+ if sql.strip().upper().startswith("SELECT"):
+ return [dict(row) for row in curs.fetchall()]
+ conn.commit()
+ return []
+
+ def table_exists(self, table_name: str) -> bool:
+        """
+        Check whether a table exists.
+        :param table_name: table name
+        :return: True if the table exists
+        """
+ result = self.select_data(
+ table_name="sqlite_master",
+ columns=["name"],
+ where={"type": "table", "name": table_name}
+ )
+ return len(result) > 0
+
+ @_db_operation
+ def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor,
+ sql_commands: List[str]) -> List[List[Dict]]:
+        """
+        Execute multiple SQL statements in a single connection.
+        :param sql_commands: [sql1, sql2, ...]
+        :return: a list of result lists, one per SELECT statement
+        """
+ results = []
+ for sql in sql_commands:
+ curs.execute(sql)
+ if sql.strip().upper().startswith("SELECT"):
+ results.append([dict(row) for row in curs.fetchall()])
+ conn.commit()
+ return results
+
+ def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]:
+        """Acquire a database connection and cursor"""
+ check_path_before_create(self.db_path)
+ try:
+ conn = sqlite3.connect(self.db_path)
+            conn.row_factory = sqlite3.Row  # use the Row factory so rows can be converted to dicts
+ curs = conn.cursor()
+ return conn, curs
+ except sqlite3.Error as err:
+ logger.error(f"Database connection failed: {err}")
+ raise
+
+ def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None:
+        """Close the cursor and connection, then set the db file permissions"""
+ try:
+ if curs is not None:
+ curs.close()
+ if conn is not None:
+ conn.close()
+ except sqlite3.Error as err:
+ logger.error(f"Failed to release database connection: {err}")
+ change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY)
diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a1dd7d327269615c088cca35f5ef3c705be3d79
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
@@ -0,0 +1,360 @@
+# Copyright (c) 2025-2026, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import os
+import re
+from collections import OrderedDict, defaultdict
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import pytz
+from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.file_utils import (create_directory, read_csv,
+ recursive_chmod, remove_path)
+from msprobe.core.common.log import logger
+from msprobe.core.common.utils import is_int
+from msprobe.core.monitor.db_utils import MonitorDB, update_ordered_dict
+from msprobe.core.monitor.utils import get_target_output_dir
+from tqdm import tqdm
+
+# Constants
+all_data_type_list = [
+ "actv", "actv_grad", "exp_avg", "exp_avg_sq",
+ "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other"
+]
+DEFAULT_INT_VALUE = 0
+MAX_PROCESS_NUM = 128
+CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv"
+BATCH_SIZE = 10000
+
+
+@dataclass
+class CSV2DBConfig:
+ """Configuration for CSV to database conversion"""
+ monitor_path: str
+ time_start: Optional[str] = None
+ time_end: Optional[str] = None
+ process_num: int = 1
+ data_type_list: Optional[List[str]] = None
+ output_dirpath: Optional[str] = None
+ step_partition: int = 500
+
+
+def validate_process_num(process_num: int) -> None:
+ """Validate process number parameter"""
+ if not is_int(process_num) or process_num <= 0:
+ raise ValueError("process_num must be a positive integer")
+ if process_num > MAX_PROCESS_NUM:
+ raise ValueError(f"Maximum supported process_num is {MAX_PROCESS_NUM}")
+
+
+def validate_step_partition(step_partition: int) -> None:
+ """Validate step partition parameter"""
+ if not is_int(step_partition) or step_partition <= 0:
+ raise ValueError("step_partition must be a positive integer")
+
+
+def validate_data_type_list(data_type_list: Optional[List[str]]) -> None:
+ """Validate data type list parameter"""
+ if data_type_list is None or not data_type_list:
+ logger.info(f"Using default data types: {all_data_type_list}")
+ return
+
+ if not isinstance(data_type_list, list):
+ raise ValueError("data_type_list must be a list")
+
+ invalid_types = [t for t in data_type_list if t not in all_data_type_list]
+ if invalid_types:
+ raise ValueError(f"Unsupported data types: {invalid_types}")
+
+
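+# Illustration: a monitor csv named "grad_unreduced_0-499.csv" yields
+# ("grad_unreduced", "0", "499"); file names that do not match yield ("", 0, 0).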
+def get_info_from_filename(file_name, metric_list=None):
+ metric_name = "_".join(file_name.split('_')[:-1])
+ if metric_list and metric_name not in metric_list:
+ return "", 0, 0
+ match = re.match(f"{metric_name}{CSV_FILE_PATTERN}", file_name)
+ if not match:
+ return "", 0, 0
+ step_start, step_end = match.groups()
+ return metric_name, step_start, step_end
+
+
+def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict:
+ """Pre-scan files for a single rank to collect metadata"""
+ metrics = set()
+ min_step = None
+ max_step = 0
+ metric_stats = defaultdict(set)
+ targets = OrderedDict()
+
+ for file_path in files:
+ file_name = os.path.basename(file_path)
+ metric_name, step_start, step_end = get_info_from_filename(file_name)
+ if not metric_name:
+ continue
+ step_start, step_end = int(step_start), int(step_end)
+
+ metrics.add(metric_name)
+ min_step = min(
+ step_start if min_step is None else min_step, step_start)
+ max_step = max(max_step, step_end)
+
+ data = read_csv(file_path)
+ stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED]
+ metric_stats[metric_name].update(stats)
+
+ for row_id, row in data.iterrows():
+ try:
+ name = row[MonitorConst.HEADER_NAME]
+ vpp_stage = int(row['vpp_stage'])
+ micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE))
+ except (ValueError, KeyError) as e:
+ logger.warning(
+ f"CSV conversion failed | file={file_path}:{row_id+2} | error={str(e)}")
+ continue
+ target = (name, vpp_stage, micro_step)
+ if target not in targets:
+ targets[target] = None
+
+ return {
+ 'max_rank': int(rank),
+ 'metrics': metrics,
+ 'min_step': min_step,
+ 'max_step': max_step,
+ 'metric_stats': metric_stats,
+ 'targets': list(targets.keys())
+ }
+
+
+def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1):
+ """Pre-scan all targets, metrics, and statistics"""
+ logger.info("Scanning dimensions...")
+ rank_files = defaultdict(list)
+
+ # Collect files for each rank
+ for rank, dir_path in data_dirs.items():
+ files = os.listdir(dir_path)
+ for file in files:
+ metric_name, _, _ = get_info_from_filename(
+ file, metric_list=data_type_list)
+ if not metric_name:
+ continue
+ rank_files[rank].append(os.path.join(dir_path, file))
+
+ # Parallel pre-scan
+ with ProcessPoolExecutor(max_workers=workers) as executor:
+ futures = {
+ executor.submit(_pre_scan_single_rank, rank, files): rank
+ for rank, files in rank_files.items()
+ }
+
+ results = []
+ with tqdm(total=len(futures), desc="Pre-scanning ranks") as pbar:
+ for future in as_completed(futures):
+ rank = futures[future]
+ try:
+ result = future.result()
+ results.append(result)
+ except Exception as e:
+ logger.error(
+ f"Error pre-scanning rank {rank}: {str(e)}")
+ pbar.update(1)
+
+ # Aggregate results
+ targets = OrderedDict()
+ metrics = set()
+ min_step = None
+ max_step = 0
+ max_rank = 0
+ metric_stats = defaultdict(set)
+
+ for rank_result in results:
+ max_rank = max(max_rank, rank_result['max_rank'])
+ metrics.update(rank_result['metrics'])
+ min_step = min(
+ min_step if min_step is not None else rank_result['min_step'],
+ rank_result['min_step']
+ )
+ max_step = max(max_step, rank_result['max_step'])
+
+ for metric, stats in rank_result['metric_stats'].items():
+ metric_stats[metric].update(stats)
+
+ targets = update_ordered_dict(targets, rank_result['targets'])
+
+ monitor_db.insert_dimensions(
+ targets, metrics, metric_stats, min_step=min_step, max_step=max_step)
+ monitor_db.update_global_stats(
+ max_rank=max_rank, min_step=min_step, max_step=max_step)
+ return rank_files
+
+
+def process_single_rank(
+ task: Tuple[int, List[str]],
+ metric_id_dict: Dict[str, Tuple[int, List[str]]],
+ target_dict: Dict[Tuple[str, int, int], int],
+ step_partition_size: int,
+ db_path: str
+) -> int:
+ """Process data import for a single rank"""
+ rank, files = task
+ db = MonitorDB(db_path, step_partition_size=step_partition_size)
+ total_inserted = 0
+ table_batches = defaultdict(list)
+
+ for file in files:
+ filename = os.path.basename(file)
+ metric_name, _, _ = get_info_from_filename(filename)
+ if not metric_name:
+ continue
+ metric_info = metric_id_dict.get(metric_name)
+ if not metric_info:
+ continue
+
+ metric_id, stats = metric_info
+
+ for row_id, row in read_csv(file).iterrows():
+ try:
+ # Parse row data
+ name = row.get(MonitorConst.HEADER_NAME)
+ vpp_stage = int(row['vpp_stage'])
+ micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE))
+ target_id = target_dict.get((name, vpp_stage, micro_step))
+ if not target_id:
+ continue
+
+ step = int(row['step'])
+ table_name, _, _ = db.get_metric_table_name(metric_id, step)
+ # Prepare row data
+ row_data = [rank, step, target_id]
+ row_data.extend(
+ float(row[stat]) if stat in row else None
+ for stat in stats
+ )
+ except (ValueError, KeyError) as e:
+ logger.error(
+ f"CSV conversion failed | file={file}:{row_id+2} | error={str(e)}")
+ continue
+
+ table_batches[table_name].append(tuple(row_data))
+ # Batch insert when threshold reached
+ if len(table_batches[table_name]) >= BATCH_SIZE:
+ inserted = db.insert_rows(
+ table_name, table_batches[table_name])
+ if inserted is not None:
+ total_inserted += inserted
+ table_batches[table_name] = []
+
+ # Insert remaining data
+ for table_name, batch in table_batches.items():
+ if batch:
+ inserted = db.insert_rows(table_name, batch)
+ if inserted is not None:
+ total_inserted += inserted
+
+ logger.info(f"Rank {rank} inserted {total_inserted} rows")
+ return total_inserted
+
+
+def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 4) -> bool:
+ """Main method to import data into database"""
+ # 1. Pre-scan to get rank tasks
+ monitor_db.init_schema()
+ rank_tasks = _pre_scan(monitor_db, data_dirs, data_type_list, workers)
+ if not rank_tasks:
+ logger.error("No valid data files found during pre-scan")
+ return False
+
+ # 2. Get metric and target mappings
+ try:
+ metric_id_dict = monitor_db.get_metric_mapping()
+ target_dict = monitor_db.get_target_mapping()
+ except Exception as e:
+ logger.error(f"Failed to get database mappings: {str(e)}")
+ return False
+
+ # 3. Process data for each rank in parallel
+ total_files = sum(len(files) for files in rank_tasks.values())
+    logger.info(f"Starting data import for {len(rank_tasks)} ranks, "
+                f"{total_files} files...")
+
+ with ProcessPoolExecutor(max_workers=workers) as executor:
+ futures = {
+ executor.submit(
+ process_single_rank,
+ (rank, files),
+ metric_id_dict,
+ target_dict,
+ monitor_db.step_partition_size,
+ monitor_db.db_path): rank
+ for rank, files in rank_tasks.items()
+ }
+
+ with tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar:
+ for future in pbar:
+ rank = futures[future]
+ try:
+ inserted = future.result()
+ pbar.set_postfix_str(
+ f"Rank {rank}: inserted {inserted} rows")
+ except Exception as e:
+ logger.error(
+ f"Failed to process Rank {rank}: {str(e)}")
+ return False
+ return True
+
+
+def csv2db(config: CSV2DBConfig) -> None:
+ """Main function to convert CSV files to database"""
+ validate_process_num(config.process_num)
+ validate_step_partition(config.step_partition)
+ validate_data_type_list(config.data_type_list)
+
+ target_output_dirs = get_target_output_dir(
+ config.monitor_path, config.time_start, config.time_end)
+
+ if config.output_dirpath is None:
+ local_tz = pytz.timezone("Asia/Shanghai")
+ cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S")
+ config.output_dirpath = os.path.join(
+ config.monitor_path, f"{cur_time}-csv2db")
+
+ create_directory(config.output_dirpath)
+ db_path = os.path.join(config.output_dirpath, "monitor_metrics.db")
+
+ if os.path.exists(db_path):
+        remove_path(db_path)
+        logger.warning(f"Existing database {db_path} has been removed")
+
+ db = MonitorDB(db_path, step_partition_size=config.step_partition)
+
+ result = import_data(
+ db,
+ target_output_dirs,
+ config.data_type_list if config.data_type_list else all_data_type_list,
+ workers=config.process_num
+ )
+ recursive_chmod(config.output_dirpath)
+ if result:
+ logger.info(
+ f"Data import completed. Output saved to: {config.output_dirpath}")
+ else:
+ logger.warning(
+ f"Data import may be incomplete. Output directory: {config.output_dirpath} "
+ f"(Some records might have failed)"
+ )
diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b135694c420679c1cede2846e0d241b31171a6d1
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+from collections.abc import Iterable
+from typing import Dict, List, Optional, Set, Tuple
+
+from msprobe.core.common.const import MonitorConst
+from msprobe.core.common.db_manager import DBManager
+
+
+def update_ordered_dict(main_dict: OrderedDict, new_list: List) -> OrderedDict:
+ """Update ordered dictionary with new items"""
+ for item in new_list:
+ if item not in main_dict:
+ main_dict[item] = None
+ return main_dict
+
+
+def get_ordered_stats(stats: Iterable) -> List[str]:
+ """Get statistics in predefined order"""
+ if not isinstance(stats, Iterable):
+ return []
+ return [stat for stat in MonitorConst.OP_MONVIS_SUPPORTED if stat in stats]
+
+
+class MonitorSql:
+    """SQL statement definitions for the monitoring database"""
+
+ @staticmethod
+ def create_monitoring_targets_table():
+        """Monitoring targets table"""
+ return """
+ CREATE TABLE IF NOT EXISTS monitoring_targets (
+ target_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ target_name TEXT NOT NULL,
+ vpp_stage INTEGER NOT NULL,
+ micro_step INTEGER NOT NULL DEFAULT 0,
+ UNIQUE(target_name, vpp_stage, micro_step)
+ )"""
+
+ @staticmethod
+ def create_monitoring_metrics_table():
+        """Monitoring metrics table"""
+ return """
+ CREATE TABLE IF NOT EXISTS monitoring_metrics (
+ metric_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ metric_name TEXT UNIQUE NOT NULL
+ )"""
+
+ @staticmethod
+ def get_metric_mapping_sql():
+ return """
+ SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats
+ FROM monitoring_metrics m
+ LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id
+ GROUP BY m.metric_id
+ """
+
+ @staticmethod
+ def create_metric_stats_table():
+        """Metric statistics table"""
+ return """
+ CREATE TABLE IF NOT EXISTS metric_stats (
+ metric_id INTEGER NOT NULL,
+ stat_name TEXT NOT NULL,
+ PRIMARY KEY (metric_id, stat_name),
+ FOREIGN KEY (metric_id) REFERENCES monitoring_metrics(metric_id)
+ ) WITHOUT ROWID"""
+
+ @staticmethod
+ def create_global_stat_table():
+ return """
+ CREATE TABLE IF NOT EXISTS global_stats (
+ stat_name TEXT PRIMARY KEY,
+ stat_value INTEGER NOT NULL
+ ) WITHOUT ROWID"""
+
+ @classmethod
+ def get_table_definition(cls, table_name=""):
+        """
+        Get the CREATE TABLE statement(s)
+        :param table_name: table name; when empty, the statements for all tables are returned
+        :return: CREATE TABLE SQL statement(s)
+        :raises ValueError: if the table name is not supported
+        """
+        table_creators = {
+            "monitoring_targets": cls.create_monitoring_targets_table,
+            "monitoring_metrics": cls.create_monitoring_metrics_table,
+            "metric_stats": cls.create_metric_stats_table,
+            "global_stats": cls.create_global_stat_table,
+        }
+        if not table_name:
+            return [creator() for creator in table_creators.values()]
+ if table_name not in table_creators:
+ raise ValueError(f"Unsupported table name: {table_name}")
+ return table_creators[table_name]()
+
+ @classmethod
+    def get_metric_table_definition(cls, table_name, stats, partition=None):
+        """Build the CREATE TABLE statement for a metric table, optionally constrained to a step partition"""
+        stat_columns = [f"{stat} REAL DEFAULT NULL" for stat in stats]
+        if partition and len(partition) == 2:
+            partition_start_step, partition_end_step = partition
+            step_column = f"""step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step}
+            AND {partition_end_step}),"""
+        else:
+            step_column = "step INTEGER NOT NULL,"
+        create_sql = f"""
+        CREATE TABLE {table_name} (
+            rank INTEGER NOT NULL,
+            {step_column}
+            target_id INTEGER NOT NULL,
+            {', '.join(stat_columns)},
+            PRIMARY KEY (rank, step, target_id),
+            FOREIGN KEY (target_id) REFERENCES monitoring_targets(target_id)
+        ) WITHOUT ROWID
+        """
+        return create_sql
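+    # Illustration: get_metric_table_definition("metric_1_step_0_499", ["norm", "max"], partition=(0, 499))
+    # yields a CREATE TABLE statement with columns (rank, step, target_id, norm, max) and a
+    # CHECK constraint restricting step to the range [0, 499].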
+
+
+class MonitorDB:
+ """Main class for monitoring database operations"""
+
+ def __init__(self, db_path: str, step_partition_size: int = 500):
+ self.db_path = db_path
+ self.db_manager = DBManager(db_path)
+ self.step_partition_size = step_partition_size
+
+    def get_metric_table_name(self, metric_id: int, step: int) -> Tuple[str, int, int]:
+ """Generate metric table name"""
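+        # Illustration: metric_id=3, step=1234, step_partition_size=500
+        #   -> ("metric_3_step_1000_1499", 1000, 1499)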
+ step_start = (
+ step // self.step_partition_size) * self.step_partition_size
+ step_end = step_start + self.step_partition_size - 1
+ return f"metric_{metric_id}_step_{step_start}_{step_end}", step_start, step_end
+
+ def init_schema(self) -> None:
+ """Initialize database schema"""
+ self.db_manager.execute_multi_sql(MonitorSql.get_table_definition())
+
+ # Insert initial global stats
+ global_stats = [
+ ('max_rank', 0),
+ ('min_step', 0),
+ ('max_step', 0),
+ ('step_partition_size', self.step_partition_size)
+ ]
+ self.db_manager.insert_data("global_stats", global_stats)
+
+ def insert_dimensions(
+ self,
+ targets: OrderedDict,
+ metrics: Set[str],
+ metric_stats: Dict[str, Set[str]],
+ min_step: Optional[int] = None,
+ max_step: int = None,
+ ) -> None:
+ """Insert dimension data into database"""
+ # Insert targets
+ self.db_manager.insert_data(
+ "monitoring_targets",
+ [(name, vpp_stage, micro_step)
+ for (name, vpp_stage, micro_step) in targets],
+ key_list=["target_name", "vpp_stage", "micro_step"]
+ )
+
+ # Insert metrics
+ self.db_manager.insert_data(
+ "monitoring_metrics",
+ [(metric,) for metric in metrics],
+ key_list=["metric_name"]
+ )
+
+        # Insert metric-stat relationships and create the partitioned metric tables
+        for metric, stats in metric_stats.items():
+            metric_id = self._get_metric_id(metric)
+            ordered_stats = get_ordered_stats(stats)
+
+            self.db_manager.insert_data(
+                "metric_stats",
+                [(metric_id, stat) for stat in ordered_stats],
+                key_list=["metric_id", "stat_name"]
+            )
+
+            # Create this metric's table for each step partition
+            if min_step is not None and max_step is not None:
+                first_partition = min_step // self.step_partition_size
+                last_partition = max_step // self.step_partition_size
+
+                for partition in range(first_partition, last_partition + 1):
+                    step_start = partition * self.step_partition_size
+                    self.create_metric_table(
+                        metric_id, step_start, ordered_stats)
+
+ def insert_rows(self, table_name, rows):
+ if not self.db_manager.table_exists(table_name):
+            raise RuntimeError(f"Table {table_name} does not exist in {self.db_path}")
+ inserted = self.db_manager.insert_data(table_name, rows)
+ inserted = 0 if inserted is None else inserted
+ return inserted
+
+ def create_metric_table(self, metric_id: int, step: int, stats: List[str]) -> str:
+ """Create metric table for a specific partition"""
+ table_name, partition_start_step, partition_end_step = self.get_metric_table_name(
+ metric_id,
+ step
+ )
+ if self.db_manager.table_exists(table_name):
+ return table_name
+
+        create_sql = MonitorSql.get_metric_table_definition(
+            table_name, stats, partition=(
+                partition_start_step, partition_end_step)
+        )
+ self.db_manager.execute_sql(create_sql)
+ return table_name
+
+ def update_global_stats(self, max_rank: int = None, min_step: Optional[int] = None, max_step: int = None) -> None:
+ """Update global statistics"""
+ updates = [
+ ("max_rank", max_rank),
+ ("min_step", min_step),
+ ("max_step", max_step)
+ ]
+ for stat_name, value in updates:
+            if value is None:
+                continue
+ self.db_manager.update_data(
+ table_name="global_stats",
+ updates={"stat_value": value},
+ where={"stat_name": stat_name}
+ )
+
+ def get_metric_mapping(self) -> Dict[str, Tuple[int, List[str]]]:
+ """Get metric name to ID mapping with statistics"""
+ results = self.db_manager.execute_sql(
+ MonitorSql.get_metric_mapping_sql()
+ )
+
+ return {
+ row["metric_name"]: (
+ row["metric_id"],
+                get_ordered_stats(row["stats"].split(",")) if row["stats"] else []
+ ) for row in results
+ }
+
+ def get_target_mapping(self) -> Dict[Tuple[str, int, int], int]:
+ """Get target mapping dictionary"""
+ results = self.db_manager.select_data(
+ table_name="monitoring_targets",
+ columns=["target_id", "target_name", "vpp_stage", "micro_step"]
+ )
+ if not results:
+ return {}
+ return {
+ (row["target_name"], row["vpp_stage"], row["micro_step"]): row["target_id"]
+ for row in results
+ }
+
+ def _get_metric_id(self, metric_name: str) -> Optional[int]:
+ """Get metric ID by name"""
+ result = self.db_manager.select_data(
+ table_name="monitoring_metrics",
+ columns=["metric_id"],
+ where={"metric_name": metric_name}
+ )
+ return result[0]["metric_id"] if result else None
diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md
index 2374ef7680e59d5e85fe276dc8597ffb2f4bdbfd..e7f28fead66248f8dde6fcebe7bcc762590124a2 100644
--- a/debug/accuracy_tools/msprobe/docs/19.monitor.md
+++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md
@@ -467,6 +467,24 @@ csv2tensorboard_by_step(
)
```
+Convert csv monitoring data into a sqlite database (a `monitor_metrics.db` file in the output directory).
+
+```python
+from msprobe.core.monitor.csv2db import CSV2DBConfig, csv2db
+# output_dirpath specifies the output directory; by default the result is saved to a
+# "{curtime}-csv2db" folder under monitor_path, where curtime is the automatically generated current timestamp
+# step_partition controls the step-partition interval of the database; by default each table covers 500 steps
+config = CSV2DBConfig(
+    monitor_path="~/monitor_output",    # same usage as the tensorboard conversion
+    time_start="Dec03_21-34-40",        # same usage as the tensorboard conversion
+    time_end="Dec03_21-34-42",          # same usage as the tensorboard conversion
+    process_num=8,                      # same usage as the tensorboard conversion
+    data_type_list=["grad_unreduced"],  # same usage as the tensorboard conversion
+    step_partition=500,
+    output_dirpath="~/monitor_output"
+)
+csv2db(config)
+```
+
### 动态启停
动态启停模式:支持用户在训练过程中随时启动/更新监控。
@@ -553,6 +571,24 @@ csv2tensorboard_by_step(monitor_path, time_start, time_end, process_num=1, data_
| process_num | 指定拉起的进程个数,默认为1,更多的进程个数可以加速转换。 | 否 |
| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"]。
不指定就转换全部数据。 | 否 |
| output_dirpath | 指定转换后的输出路径,默认输出到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳。 | 否 |
+
+- CSV-to-sqlite-database conversion interface
+```python
+csv2db(config: CSV2DBConfig) -> None
+```
+Configuration parameters (CSV2DBConfig)
+
+| Parameter | Description | Required |
+| -------------- | ------------------------------------------------------------ | -------- |
+| monitor_path | Directory containing the csv monitor output to convert. | Yes |
+| time_start | Start timestamp, used together with time_end to specify a closed time range; only files within this range are converted. Defaults to None (no limit). | No |
+| time_end | End timestamp, used together with time_start to specify a closed time range; only files within this range are converted. Defaults to None (no limit). | No |
+| process_num | Number of processes to launch, default 1; more processes speed up the conversion. | No |
+| data_type_list | Data types to convert, taken from the output file prefixes. All supported types: ["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other"]. If unspecified, all data is converted. | No |
+| step_partition | Interval of the per-step partitioning in the database; by default each table covers 500 steps. | No |
+| output_dirpath | Output directory of the conversion; defaults to a "{curtime}-csv2db" folder, where curtime is the automatically generated current timestamp. | No |
+
+
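+The generated `monitor_metrics.db` can be inspected with Python's standard `sqlite3` module. Below is a minimal sketch; the database path and the per-metric table name (e.g. `metric_1_step_0_499`) depend on your actual output directory, metric IDs and step range, and are only illustrative.
+
+```python
+import sqlite3
+
+# hypothetical path: adjust to your actual output_dirpath
+db_path = "monitor_output/Dec03_21-40-00-csv2db/monitor_metrics.db"
+conn = sqlite3.connect(db_path)
+conn.row_factory = sqlite3.Row
+
+# list all tables, including the per-metric, per-step-partition tables
+tables = [r["name"] for r in conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")]
+print(tables)
+
+# monitored targets: target_id <-> (target_name, vpp_stage, micro_step)
+for row in conn.execute("SELECT * FROM monitoring_targets LIMIT 5"):
+    print(dict(row))
+
+# one metric partition table: columns are rank, step, target_id plus the recorded statistic columns
+for row in conn.execute("SELECT * FROM metric_1_step_0_499 LIMIT 5"):
+    print(dict(row))
+
+conn.close()
+```
+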
- 在模型任意位置获取当前参数**梯度**统计量
```python
TrainerMon.generate_wgrad_metrics() -> tuple[dict, dict]