From d29785f4c36884d28be41f9d33f6f9e543f2a114 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Thu, 31 Jul 2025 09:54:42 +0800 Subject: [PATCH 01/23] monvis --- .../msprobe/core/common/const.py | 5 + .../msprobe/core/common/db_manager.py | 214 +++++++++++ .../msprobe/core/monitor/csv2db.py | 342 ++++++++++++++++++ .../msprobe/core/monitor/db_utils.py | 264 ++++++++++++++ 4 files changed, 825 insertions(+) create mode 100644 debug/accuracy_tools/msprobe/core/common/db_manager.py create mode 100644 debug/accuracy_tools/msprobe/core/monitor/csv2db.py create mode 100644 debug/accuracy_tools/msprobe/core/monitor/db_utils.py diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 560d939b34..039253180f 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -773,6 +773,11 @@ class MonitorConst: DEFAULT_STEP_INTERVAL = 1 OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"] + OP_MONVIS_SUPPORTED = [ + "norm", "min", "max", "zeros", "nans", "mean", + "entropy", "softmax_max", "sr", "kernel_norm", "std_x", "jacobian", + "proxy", "token_similarity" + ] MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR" DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output" DATABASE = "database" diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py new file mode 100644 index 0000000000..b23fe01437 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -0,0 +1,214 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
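+
+# Module overview: DBManager wraps sqlite3 with reusable helpers
+# (insert_data, select_data, update_data, execute_sql, execute_multi_sql,
+# table_exists). Each public method opens its own connection through the
+# _db_operation decorator, which rolls back on sqlite3 errors and always
+# closes the connection and resets the db file mode afterwards.
+#
+# Illustrative usage, a sketch only: the "metrics" table and its columns
+# are placeholders, and the table must already exist because insert_data
+# does not create tables.
+#
+#     db = DBManager("monitor_metrics.db")
+#     db.insert_data("metrics", [(0, 1.5)], key_list=["step", "loss"])
+#     rows = db.select_data("metrics", columns=["loss"], where={"step": 0})
+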
+import sqlite3 +from typing import List, Tuple, Dict, Any +from functools import wraps + +from msprobe.pytorch.common.log import logger +from msprobe.core.common.file_utils import check_path_before_create, change_mode +from msprobe.core.common.const import FileCheckConst + +class DBManager: + """ + 数据库管理类,封装常用数据库操作 + """ + + DEFAULT_FETCH_SIZE = 10000 + DEFAULT_INSERT_SIZE = 10000 + MAX_ROW_COUNT = 100000000 + + def __init__(self, db_path: str): + """ + 初始化DBManager + :param db_path: 数据库文件路径 + :param table_config: 表配置对象 + """ + self.db_path = db_path + + def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]: + """获取数据库连接和游标""" + check_path_before_create(self.db_path) + try: + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # 使用Row工厂获取字典形式的结果 + curs = conn.cursor() + return conn, curs + except sqlite3.Error as err: + logger.error(f"Database connection failed: {err}") + raise + + def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None: + """释放数据库连接""" + try: + if curs is not None: + curs.close() + if conn is not None: + conn.close() + except sqlite3.Error as err: + logger.error(f"Failed to release database connection: {err}") + change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY) + + def _db_operation(func): + """数据库操作装饰器,自动管理连接""" + @wraps(func) + def wrapper(self, *args, **kwargs): + conn, curs = None, None + try: + conn, curs = self._get_connection() + return func(self, conn, curs, *args, **kwargs) + except sqlite3.Error as err: + logger.error(f"Database operation failed: {err}") + if conn: + conn.rollback() + finally: + self._release_connection(conn, curs) + return wrapper + + @staticmethod + def _get_where_sql(where_list): + if not where_list: + return "", tuple() + + where_clauses = [] + where_values = [] + if where_list: + for col, val in where_list.items(): + where_clauses.append(f"{col} = ?") + where_values.append(val) + if where_clauses: + where_sql = " WHERE " + " AND ".join(where_clauses) + return where_sql, tuple(where_values) + + @_db_operation + def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + table_name: str, data: List[Tuple], key_list: List[str] = None) -> int: + """ + 批量插入数据 + :param table_name: 表名 + :param data: 要插入的数据列表 + :param batch_size: 每批插入的大小 + :return: 插入的行数 + """ + if not data: + return 0 + columns = len(data[0]) + if key_list and columns != len(key_list): + raise ValueError( + f"When inserting into table {table_name}, the length of key list ({key_name})" + f"does not match the data({columns}).") + + batch_size = self.DEFAULT_INSERT_SIZE + placeholders = ", ".join(["?"] * columns) + if key_list: + keys = ", ".join(key_list) + sql = f"INSERT OR IGNORE INTO {table_name} ({keys}) VALUES ({placeholders})" + else: + sql = f"INSERT OR IGNORE INTO {table_name} VALUES ({placeholders})" + + inserted_rows = 0 + for i in range(0, len(data), batch_size): + batch = data[i:i + batch_size] + curs.executemany(sql, batch) + inserted_rows += curs.rowcount + + conn.commit() + return inserted_rows + + @_db_operation + def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + table_name: str, + columns: List[str] = None, + where: dict = None) -> List[Dict]: + """ + 查询数据 + :param table_name: 表名 + :param columns: 要查询的列 + :param where: WHERE条件 + :return: 查询结果列表(字典形式) + """ + + cols = ", ".join(columns) if columns else "*" + sql = f"SELECT {cols} FROM {table_name}" + + where_sql, where_parems = self._get_where_sql(where) + curs.execute(sql+where_sql, where_parems) + + 
return [dict(row) for row in curs.fetchall()] + + @_db_operation + def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + table_name: str, updates: Dict[str, Any], + where: dict = None) -> int: + """ + 更新数据 + :param table_name: 表名 + :param updates: 要更新的字段和值 + :param where: WHERE条件 + :param where_params: WHERE条件参数 + :return: 影响的行数 + """ + set_clause = ", ".join([f"{k} = ?" for k in updates.keys()]) + sql = f"UPDATE {table_name} SET {set_clause}" + + params = tuple(updates.values()) + + where_sql, where_parems = self._get_where_sql(where) + + curs.execute(sql+where_sql, params + where_parems) + conn.commit() + return curs.rowcount + + @_db_operation + def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + sql: str, params: Tuple = None) -> List[Dict]: + """ + 执行自定义SQL查询 + :param sql: SQL语句 + :param params: 参数 + :return: 查询结果 + """ + curs.execute(sql, params or ()) + if sql.strip().upper().startswith("SELECT"): + return [dict(row) for row in curs.fetchall()] + conn.commit() + return [] + + def table_exists(self, table_name: str) -> bool: + """ + :param table_name: 表名 + :return: 查询结果 + """ + result = self.select_data( + table_name="sqlite_master", + columns=["name"], + where={"type": "table", "name": table_name} + ) + return len(result) > 0 + + @_db_operation + def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, + sql_commands: List[str]) -> List[List[Dict]]: + """ + 批量执行多个SQL语句 + :param sql_commands: [sql1, sql2, ...] + :return: 每个SELECT语句的结果列表 + """ + results = [] + for sql in sql_commands: + curs.execute(sql) + if sql.strip().upper().startswith("SELECT"): + results.append([dict(row) for row in curs.fetchall()]) + conn.commit() + return results diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py new file mode 100644 index 0000000000..d3ac0a464f --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -0,0 +1,342 @@ +# Copyright (c) 2025-2026, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
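+
+# Pipeline overview: csv2db() validates the CSV2DBConfig, resolves the
+# per-rank monitor output directories for the requested time window, then
+# _pre_scan() walks every rank's CSV files to register targets, metrics and
+# the observed step range in the SQLite schema, and process_single_rank()
+# bulk-inserts the rows into per-metric tables partitioned by step
+# (see MonitorDB in msprobe.core.monitor.db_utils).
+#
+# Minimal usage sketch; the monitor_path below is a placeholder:
+#
+#     config = CSV2DBConfig(monitor_path="./monitor_output", process_num=4)
+#     csv2db(config)
+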
+ +import datetime +import os +import re +from collections import OrderedDict, defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import pytz +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.file_utils import (create_directory, read_csv, + recursive_chmod, remove_path) +from msprobe.core.common.log import logger +from msprobe.core.common.utils import is_int +from msprobe.core.monitor.db_utils import MonitorDB, update_ordered_dict +from msprobe.core.monitor.utils import get_target_output_dir +from tqdm import tqdm + +# Constants +all_data_type_list = [ + "actv", "actv_grad", "exp_avg", "exp_avg_sq", + "grad_unreduced", "grad_reduced", "param_origin", "param_updated", + "linear_hook", "norm_hook", "proxy_model", "token_hook", "attention_hook" +] +DEFAULT_INT_VALUE = 0 +MAX_PROCESS_NUM = 128 +CSV_FILE_PATTERN = r"(\w+)_(\d+)-(\d+)\.csv" +BATCH_SIZE = 10000 + + +@dataclass +class CSV2DBConfig: + """Configuration for CSV to database conversion""" + monitor_path: str + time_start: Optional[str] = None + time_end: Optional[str] = None + process_num: int = 1 + data_type_list: Optional[List[str]] = None + output_dirpath: Optional[str] = None + step_partition: int = 500 + + +def validate_process_num(process_num: int) -> None: + """Validate process number parameter""" + if not is_int(process_num) or process_num <= 0: + raise ValueError("process_num must be a positive integer") + if process_num > MAX_PROCESS_NUM: + raise ValueError(f"Maximum supported process_num is {MAX_PROCESS_NUM}") + + +def validate_step_partition(step_partition: int) -> None: + """Validate step partition parameter""" + if not is_int(step_partition) or step_partition <= 0: + raise ValueError("step_partition must be a positive integer") + + +def validate_data_type_list(data_type_list: Optional[List[str]]) -> None: + """Validate data type list parameter""" + if data_type_list is None or not data_type_list: + logger.info(f"Using default data types: {all_data_type_list}") + return + + if not isinstance(data_type_list, list): + raise ValueError("data_type_list must be a list") + + invalid_types = [t for t in data_type_list if t not in all_data_type_list] + if invalid_types: + raise ValueError(f"Unsupported data types: {invalid_types}") + + +def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: + """Pre-scan files for a single rank to collect metadata""" + metrics = set() + min_step = None + max_step = 0 + metric_stats = defaultdict(set) + targets = OrderedDict() + + for file_path in files: + file_name = os.path.basename(file_path) + match = re.match(CSV_FILE_PATTERN, file_name) + if not match: + continue + + metric_name, step_start, step_end = match.groups() + step_start, step_end = int(step_start), int(step_end) + + metrics.add(metric_name) + min_step = min(min_step or step_start, step_start) + max_step = max(max_step, step_end) + + data = read_csv(file_path) + stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED] + metric_stats[metric_name].update(stats) + + for _, row in data.iterrows(): + name = row[MonitorConst.HEADER_NAME] + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + target = (name, vpp_stage, micro_step) + if target not in targets: + targets[target] = None + + return { + 'max_rank': int(rank), + 'metrics': metrics, + 'min_step': min_step, + 'max_step': max_step, + 'metric_stats': metric_stats, + 'targets': 
list(targets.keys()) + } + + +def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1) -> Dict[int, List[str]]: + """Pre-scan all targets, metrics, and statistics""" + logger.info("Scanning dimensions...") + rank_files = defaultdict(list) + + # Collect files for each rank + for rank, dir_path in data_dirs.items(): + files = os.listdir(dir_path) + for file in files: + match = re.match(CSV_FILE_PATTERN, file) + if not match: + continue + metric_name, _, _ = match.groups() + if metric_name not in data_type_list: + continue + rank_files[rank].append(os.path.join(dir_path, file)) + + # Parallel pre-scan + with ProcessPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit(_pre_scan_single_rank, rank, files): rank + for rank, files in rank_files.items() + } + + results = [] + with tqdm(total=len(futures), desc="Pre-scanning ranks") as pbar: + for future in as_completed(futures): + rank = futures[future] + try: + result = future.result() + results.append(result) + except Exception as e: + logger.error( + f"Error pre-scanning rank {rank}: {str(e)}") + pbar.update(1) + + # Aggregate results + targets = OrderedDict() + metrics = set() + min_step = None + max_step = 0 + max_rank = 0 + metric_stats = defaultdict(set) + + for rank_result in results: + max_rank = max(max_rank, rank_result['max_rank']) + metrics.update(rank_result['metrics']) + min_step = min( + min_step or rank_result['min_step'], rank_result['min_step']) + max_step = max(max_step, rank_result['max_step']) + + for metric, stats in rank_result['metric_stats'].items(): + metric_stats[metric].update(stats) + + targets = update_ordered_dict(targets, rank_result['targets']) + + monitor_db.insert_dimensions( + targets, metrics, metric_stats, min_step=min_step, max_step=max_step) + monitor_db.update_global_stats( + max_rank=max_rank, min_step=min_step, max_step=max_step) + return rank_files + + +def process_single_rank( + task: Tuple[int, List[str]], + metric_id_dict: Dict[str, Tuple[int, List[str]]], + target_dict: Dict[Tuple[str, int, int], int], + step_partition_size: int, + db_path: str +) -> int: + """Process data import for a single rank""" + rank, files = task + db = MonitorDB(db_path, step_partition_size=step_partition_size) + total_inserted = 0 + table_batches = defaultdict(list) + + for file in files: + filename = os.path.basename(file) + match = re.match(CSV_FILE_PATTERN, filename) + if not match: + continue + + metric_name, _, _ = match.groups() + metric_info = metric_id_dict.get(metric_name) + if not metric_info: + continue + + metric_id, stats = metric_info + + for row_id, row in read_csv(file).iterrows(): + try: + # Parse row data + name = row.get(MonitorConst.HEADER_NAME) + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + target_id = target_dict.get((name, vpp_stage, micro_step)) + if not target_id: + continue + + step = int(row['step']) + table_name = db.get_metric_table_name(metric_id, step) + # Prepare row data + row_data = [rank, step, target_id] + row_data.extend( + float(row[stat]) if stat in row else None + for stat in stats + ) + table_batches[table_name].append(tuple(row_data)) + + # Batch insert when threshold reached + if len(table_batches[table_name]) >= BATCH_SIZE: + inserted = db.insert_rows( + table_name, table_batches[table_name]) + if inserted is not None: + total_inserted += inserted + table_batches[table_name] = [] + + except (ValueError, KeyError) as e: + logger.error( + f"CSV float 
conversion failed | file={file}:{row_id+2} | error={str(e)}") + continue + + # Insert remaining data + for table_name, batch in table_batches.items(): + if batch: + inserted = db.insert_rows(table_name, batch) + if inserted is not None: + total_inserted += inserted + + logger.info(f"Rank {rank} inserted {total_inserted} rows") + return total_inserted + + +def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 4) -> bool: + """Main method to import data into database""" + # 1. Pre-scan to get rank tasks + monitor_db.init_schema() + rank_tasks = _pre_scan(monitor_db, data_dirs, data_type_list, workers) + if not rank_tasks: + logger.error("No valid data files found during pre-scan") + return False + + # 2. Get metric and target mappings + try: + metric_id_dict = monitor_db.get_metric_mapping() + target_dict = monitor_db.get_target_mapping() + except Exception as e: + logger.error(f"Failed to get database mappings: {str(e)}") + return False + + # 3. Process data for each rank in parallel + total_files = sum(len(files) for files in rank_tasks.values()) + logger.info(f"Starting data import for {len(rank_tasks)} ranks," + "{total_files} files..." + ) + + with ProcessPoolExecutor(max_workers=workers) as executor: + futures = { + executor.submit( + process_single_rank, + (rank, files), + metric_id_dict, + target_dict, + monitor_db.step_partition_size, + monitor_db.db_path + ): rank for rank, files in rank_tasks.items() + } + + with tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar: + for future in pbar: + rank = futures[future] + try: + inserted = future.result() + pbar.set_postfix_str( + f"Rank {rank}: inserted {inserted} rows") + except Exception as e: + logger.error( + f"Failed to process Rank {rank}: {str(e)}") + return True + + +def csv2db(config: CSV2DBConfig) -> None: + """Main function to convert CSV files to database""" + validate_process_num(config.process_num) + validate_step_partition(config.step_partition) + validate_data_type_list(config.data_type_list) + + target_output_dirs = get_target_output_dir( + config.monitor_path, config.time_start, config.time_end) + + if config.output_dirpath is None: + local_tz = pytz.timezone("Asia/Shanghai") + cur_time = datetime.datetime.now(local_tz).strftime("%b%d_%H-%M-%S") + config.output_dirpath = os.path.join( + config.monitor_path, f"{cur_time}-csv2db") + + create_directory(config.output_dirpath) + db_path = os.path.join(config.output_dirpath, "monitor_metrics.db") + + if os.path.exists(db_path): + remove_path(db_path) + logger.warning(f"Existing path {db_path} will be recovered") + + db = MonitorDB(db_path, step_partition_size=config.step_partition) + + import_data( + db, + target_output_dirs, + config.data_type_list if config.data_type_list else all_data_type_list, + workers=config.process_num + ) + + recursive_chmod(config.output_dirpath) + logger.info(f"Output has been saved to: {config.output_dirpath}") diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py new file mode 100644 index 0000000000..1096cc209e --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py @@ -0,0 +1,264 @@ +from collections import OrderedDict +from collections.abc import Iterable +from typing import Dict, List, Optional, Set, Tuple + +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.db_manager import DBManager + + +def update_ordered_dict(main_dict: OrderedDict, new_list: 
List) -> OrderedDict: + """Update ordered dictionary with new items""" + for item in new_list: + if item not in main_dict: + main_dict[item] = None + return main_dict + + +def get_ordered_stats(stats: Iterable) -> List[str]: + """Get statistics in predefined order""" + if not isinstance(stats, Iterable): + return [] + return [stat for stat in MonitorConst.OP_MONVIS_SUPPORTED if stat in stats] + + +class MonitorSql: + """数据库表参数类""" + + @staticmethod + def _create_monitoring_targets_table(): + """监控目标表""" + return """ + CREATE TABLE IF NOT EXISTS monitoring_targets ( + target_id INTEGER PRIMARY KEY AUTOINCREMENT, + target_name TEXT NOT NULL, + vpp_stage INTEGER NOT NULL, + micro_step INTEGER NOT NULL DEFAULT 0, + UNIQUE(target_name, vpp_stage, micro_step) + )""" + + @staticmethod + def _create_monitoring_metrics_table(): + """监控指标表""" + return """ + CREATE TABLE IF NOT EXISTS monitoring_metrics ( + metric_id INTEGER PRIMARY KEY AUTOINCREMENT, + metric_name TEXT UNIQUE NOT NULL + )""" + + @staticmethod + def _create_metric_stats_table(): + """指标统计表""" + return """ + CREATE TABLE IF NOT EXISTS metric_stats ( + metric_id INTEGER NOT NULL, + stat_name TEXT NOT NULL, + PRIMARY KEY (metric_id, stat_name), + FOREIGN KEY (metric_id) REFERENCES monitoring_metrics(metric_id) + ) WITHOUT ROWID""" + + @staticmethod + def _create_global_stat_table(): + return """ + CREATE TABLE IF NOT EXISTS global_stats ( + stat_name TEXT PRIMARY KEY, + stat_value INTEGER NOT NULL + ) WITHOUT ROWID""" + + @classmethod + def get_table_definition(cls, table_name=""): + """ + 获取表定义SQL + :param table_name: 表名 + :return: 建表SQL语句 + :raises ValueError: 当表名不存在时 + """ + table_creators = { + "monitoring_targets": cls._create_monitoring_targets_table, + "monitoring_metrics": cls._create_monitoring_metrics_table, + "metric_stats": cls._create_metric_stats_table, + "global_stats": cls._create_global_stat_table, + } + if not table_name: + return [table_creators[table]() for table in table_creators] + if table_name not in table_creators: + raise ValueError(f"Unsupported table name: {table_name}") + return table_creators[table_name]() + + @classmethod + def get_metric_table_definition(cls, table_name, stats, patition=[]): + stat_columns = [f"{stat} REAL DEFAULT NULL" for stat in stats] + if len(patition) == 2: + partition_start_step, partition_end_step = patition + step_column = f"""step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step} + AND {partition_end_step}),""" + else: + step_column = "step INTEGER NOT NULL" + create_sql = f""" + CREATE TABLE {table_name} ( + rank INTEGER NOT NULL, + {step_column} + target_id INTEGER NOT NULL, + {', '.join(stat_columns)}, + PRIMARY KEY (rank, step, target_id), + FOREIGN KEY (target_id) REFERENCES monitoring_targets(target_id) + ) WITHOUT ROWID + """ + return create_sql + + @staticmethod + def get_metric_mapping_sql(): + return """ + SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats + FROM monitoring_metrics m + LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id + GROUP BY m.metric_id + """ + + +class MonitorDB: + """Main class for monitoring database operations""" + + def __init__(self, db_path: str, step_partition_size: int = 500): + self.db_path = db_path + self.db_manager = DBManager(db_path) + self.step_partition_size = step_partition_size + + def get_metric_table_name(self, metric_id: int, step: int) -> str: + """Generate metric table name""" + step_start = ( + step // self.step_partition_size) * self.step_partition_size + step_end = step_start + 
self.step_partition_size - 1 + return f"metric_{metric_id}_step_{step_start}_{step_end}", step_start, step_end + + def init_schema(self) -> None: + """Initialize database schema""" + self.db_manager.execute_multi_sql(MonitorSql.get_table_definition()) + + # Insert initial global stats + global_stats = [ + ('max_rank', 0), + ('min_step', 0), + ('max_step', 0), + ('step_partition_size', self.step_partition_size) + ] + self.db_manager.insert_data("global_stats", global_stats) + + def insert_dimensions( + self, + targets: OrderedDict, + metrics: Set[str], + metric_stats: Dict[str, Set[str]], + min_step: Optional[int] = None, + max_step: int = None, + ) -> None: + """Insert dimension data into database""" + # Insert targets + self.db_manager.insert_data( + "monitoring_targets", + [(name, vpp_stage, micro_step) + for (name, vpp_stage, micro_step) in targets], + key_list=["target_name", "vpp_stage", "micro_step"] + ) + + # Insert metrics + self.db_manager.insert_data( + "monitoring_metrics", + [(metric,) for metric in metrics], + key_list=["metric_name"] + ) + + # Insert metric-stat relationships + for metric, stats in metric_stats.items(): + metric_id = self._get_metric_id(metric) + ordered_stats = get_ordered_stats(stats) + + self.db_manager.insert_data( + "metric_stats", + [(metric_id, stat) for stat in ordered_stats], + key_list=["metric_id", "stat_name"] + ) + + # Create metric tables for each partition + if min_step and max_step: + first_partition = min_step // self.step_partition_size + last_partition = max_step // self.step_partition_size + + for partition in range(first_partition, last_partition + 1): + step_start = partition * self.step_partition_size + self.create_metric_table( + metric_id, step_start, ordered_stats) + + def insert_rows(self, table_name, rows): + if not self.db_manager.table_exists(table_name): + raise RuntimeError(f"{table_name} not existed in {self.db_path}") + inserted = self.db_manager.insert_data(table_name, rows) + inserted = 0 if inserted is None else inserted + return inserted + + def create_metric_table(self, metric_id: int, step: int, stats: List[str]) -> str: + """Create metric table for a specific partition""" + table_name, partition_start_step, partition_end_step = self.get_metric_table_name( + metric_id, + step + ) + if self.db_manager.table_exists(table_name): + return table_name + + create_sql = MonitorSql.get_metric_table_definition( + table_name, stats, patition=( + partition_start_step, partition_end_step) + ) + self.db_manager.execute_sql(create_sql) + return table_name + + def update_global_stats(self, max_rank: int = None, min_step: Optional[int] = None, max_step: int = None) -> None: + """Update global statistics""" + updates = [ + ("max_rank", max_rank), + ("min_step", min_step), + ("max_step", max_step) + ] + for stat_name, value in updates: + if not value: + continue + self.db_manager.update_data( + table_name="global_stats", + updates={"stat_value": value}, + where={"stat_name": stat_name} + ) + + def get_metric_mapping(self) -> Dict[str, Tuple[int, List[str]]]: + """Get metric name to ID mapping with statistics""" + results = self.db_manager.execute_sql( + MonitorSql.get_metric_mapping_sql() + ) + + return { + row["metric_name"]: ( + row["metric_id"], + get_ordered_stats(row["stats"].split(",") + ) if row["stats"] else [] + ) for row in results + } + + def get_target_mapping(self) -> Dict[Tuple[str, int, int], int]: + """Get target mapping dictionary""" + results = self.db_manager.select_data( + table_name="monitoring_targets", + 
columns=["target_id", "target_name", "vpp_stage", "micro_step"] + ) + if not results: + return {} + return { + (row["target_name"], row["vpp_stage"], row["micro_step"]): row["target_id"] + for row in results + } + + def _get_metric_id(self, metric_name: str) -> Optional[int]: + """Get metric ID by name""" + result = self.db_manager.select_data( + table_name="monitoring_metrics", + columns=["metric_id"], + where={"metric_name": metric_name} + ) + return result[0]["metric_id"] if result else None -- Gitee From 98e9f6fa00c86ce3e79a9c8fea73b218c5d7a01e Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 09:22:46 +0800 Subject: [PATCH 02/23] bugfix --- debug/accuracy_tools/msprobe/core/common/db_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index b23fe01437..36b4efeefe 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -106,7 +106,7 @@ class DBManager: columns = len(data[0]) if key_list and columns != len(key_list): raise ValueError( - f"When inserting into table {table_name}, the length of key list ({key_name})" + f"When inserting into table {table_name}, the length of key list ({key_list})" f"does not match the data({columns}).") batch_size = self.DEFAULT_INSERT_SIZE -- Gitee From dcedf183f95a6cd7612d30e9ac953a26eb032da1 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 10:41:57 +0800 Subject: [PATCH 03/23] bugfix --- debug/accuracy_tools/msprobe/core/monitor/db_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py index 1096cc209e..8f6170e255 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py +++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py @@ -179,7 +179,7 @@ class MonitorDB: ) # Create metric tables for each partition - if min_step and max_step: + if min_step is not None and max_step is not None: first_partition = min_step // self.step_partition_size last_partition = max_step // self.step_partition_size -- Gitee From 175031eeee9a5df27a2e268a5fce69dc39aca510 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 11:34:32 +0800 Subject: [PATCH 04/23] bugfix --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index d3ac0a464f..d303f8e6a0 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -226,7 +226,7 @@ def process_single_rank( continue step = int(row['step']) - table_name = db.get_metric_table_name(metric_id, step) + table_name, _, _ = db.get_metric_table_name(metric_id, step) # Prepare row data row_data = [rank, step, target_id] row_data.extend( -- Gitee From 8cc457ba0d1d710112608946df2297adca085e68 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 11:46:50 +0800 Subject: [PATCH 05/23] bugfix --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index d303f8e6a0..3b22482d6a 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ 
b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -101,7 +101,8 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: step_start, step_end = int(step_start), int(step_end) metrics.add(metric_name) - min_step = min(min_step or step_start, step_start) + min_step = min( + step_start if min_step in None else min_step, step_start) max_step = max(max_step, step_end) data = read_csv(file_path) @@ -174,7 +175,9 @@ def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: max_rank = max(max_rank, rank_result['max_rank']) metrics.update(rank_result['metrics']) min_step = min( - min_step or rank_result['min_step'], rank_result['min_step']) + min_step if min_step is not None else rank_result['min_step'], + rank_result['min_step'] + ) max_step = max(max_step, rank_result['max_step']) for metric, stats in rank_result['metric_stats'].items(): -- Gitee From 24f768599a477304d0a6cd929b0e9d5e5a1187e4 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:29:44 +0800 Subject: [PATCH 06/23] md --- .../accuracy_tools/msprobe/docs/19.monitor.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 2374ef7680..8f2ca4f80f 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -467,6 +467,24 @@ csv2tensorboard_by_step( ) ``` +将csv数据转换为sqlite数据。 + +```python +from msprobe.pytorch.monitor.csv2db import CSV2DBConfig, csv2db +# output_dirpath可指定输出目录,默认保存到"{curtime}_csv2db"文件夹,其中curtime为自动获取的当前时间戳 +# step_partition可以控制数据库中按step分区的间隔,默认每500步一个表 +config = CSV2DBConfig( + monitor_path="~/monitor_output",# 与转换为tensorboard用法一致 + time_start="Dec03_21-34-40",# 与转换为tensorboard用法一致 + time_end="Dec03_21-34-42",# 与转换为tensorboard用法一致 + process_num=8,# 与转换为tensorboard用法一致 + data_type_list=["grad_unreduced"],# 与转换为tensorboard用法一致 + step_partition=500, + output_dirpath="~/monitor_output" +) +csv2db(config) +``` + ### 动态启停 动态启停模式:支持用户在训练过程中随时启动/更新监控。 -- Gitee From 7b6a8d90d7da54878ae0de32922159a1a15bb85b Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:32:09 +0800 Subject: [PATCH 07/23] bugfix --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 3b22482d6a..ef9e7439c2 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -102,7 +102,7 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: metrics.add(metric_name) min_step = min( - step_start if min_step in None else min_step, step_start) + step_start if min_step is None else min_step, step_start) max_step = max(max_step, step_end) data = read_csv(file_path) -- Gitee From 78debecc60df8290e986634ec071501b5a2b4708 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:39:53 +0800 Subject: [PATCH 08/23] cleancode dbmanager --- .../msprobe/core/common/db_manager.py | 53 ++++++++++--------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index 36b4efeefe..e74f7db2c0 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -20,6 +20,7 @@ from msprobe.pytorch.common.log 
import logger from msprobe.core.common.file_utils import check_path_before_create, change_mode from msprobe.core.common.const import FileCheckConst + class DBManager: """ 数据库管理类,封装常用数据库操作 @@ -37,29 +38,7 @@ class DBManager: """ self.db_path = db_path - def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]: - """获取数据库连接和游标""" - check_path_before_create(self.db_path) - try: - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row # 使用Row工厂获取字典形式的结果 - curs = conn.cursor() - return conn, curs - except sqlite3.Error as err: - logger.error(f"Database connection failed: {err}") - raise - - def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None: - """释放数据库连接""" - try: - if curs is not None: - curs.close() - if conn is not None: - conn.close() - except sqlite3.Error as err: - logger.error(f"Failed to release database connection: {err}") - change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY) - + @staticmethod def _db_operation(func): """数据库操作装饰器,自动管理连接""" @wraps(func) @@ -74,6 +53,7 @@ class DBManager: conn.rollback() finally: self._release_connection(conn, curs) + return return wrapper @staticmethod @@ -143,7 +123,7 @@ class DBManager: sql = f"SELECT {cols} FROM {table_name}" where_sql, where_parems = self._get_where_sql(where) - curs.execute(sql+where_sql, where_parems) + curs.execute(sql + where_sql, where_parems) return [dict(row) for row in curs.fetchall()] @@ -166,7 +146,7 @@ class DBManager: where_sql, where_parems = self._get_where_sql(where) - curs.execute(sql+where_sql, params + where_parems) + curs.execute(sql + where_sql, params + where_parems) conn.commit() return curs.rowcount @@ -212,3 +192,26 @@ class DBManager: results.append([dict(row) for row in curs.fetchall()]) conn.commit() return results + + def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]: + """获取数据库连接和游标""" + check_path_before_create(self.db_path) + try: + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # 使用Row工厂获取字典形式的结果 + curs = conn.cursor() + return conn, curs + except sqlite3.Error as err: + logger.error(f"Database connection failed: {err}") + raise + + def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None: + """释放数据库连接""" + try: + if curs is not None: + curs.close() + if conn is not None: + conn.close() + except sqlite3.Error as err: + logger.error(f"Failed to release database connection: {err}") + change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY) -- Gitee From dab6f46bf042939029876381c13a45efabab12cd Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:45:12 +0800 Subject: [PATCH 09/23] cleancode db utils --- .../msprobe/core/monitor/db_utils.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py index 8f6170e255..c6476491e2 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py +++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections import OrderedDict from collections.abc import Iterable from typing import Dict, List, Optional, Set, Tuple @@ -44,6 +58,15 @@ class MonitorSql: metric_id INTEGER PRIMARY KEY AUTOINCREMENT, metric_name TEXT UNIQUE NOT NULL )""" + + @staticmethod + def get_metric_mapping_sql(): + return """ + SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats + FROM monitoring_metrics m + LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id + GROUP BY m.metric_id + """ @staticmethod def _create_metric_stats_table(): @@ -79,15 +102,15 @@ class MonitorSql: "global_stats": cls._create_global_stat_table, } if not table_name: - return [table_creators[table]() for table in table_creators] + return [table_creators.get(table, lambda x:"")() for table in table_creators] if table_name not in table_creators: raise ValueError(f"Unsupported table name: {table_name}") return table_creators[table_name]() @classmethod - def get_metric_table_definition(cls, table_name, stats, patition=[]): + def get_metric_table_definition(cls, table_name, stats, patition=None): stat_columns = [f"{stat} REAL DEFAULT NULL" for stat in stats] - if len(patition) == 2: + if patition and len(patition) == 2: partition_start_step, partition_end_step = patition step_column = f"""step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step} AND {partition_end_step}),""" @@ -105,15 +128,6 @@ class MonitorSql: """ return create_sql - @staticmethod - def get_metric_mapping_sql(): - return """ - SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats - FROM monitoring_metrics m - LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id - GROUP BY m.metric_id - """ - class MonitorDB: """Main class for monitoring database operations""" -- Gitee From fec87a3da682873cc3d0639528fc338b82b52035 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:50:23 +0800 Subject: [PATCH 10/23] cleancode csv2db --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index ef9e7439c2..915da904b6 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -127,7 +127,7 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: } -def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1) -> Dict[int, List[str]]: +def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1): """Pre-scan all targets, metrics, and statistics""" logger.info("Scanning dimensions...") rank_files = defaultdict(list) @@ -293,8 +293,8 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list metric_id_dict, target_dict, monitor_db.step_partition_size, - monitor_db.db_path - ): rank for rank, files in rank_tasks.items() + monitor_db.db_path): rank + for rank, files in rank_tasks.items() } with tqdm(as_completed(futures), total=len(futures), desc="Import 
progress") as pbar: @@ -307,7 +307,6 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list except Exception as e: logger.error( f"Failed to process Rank {rank}: {str(e)}") - return True def csv2db(config: CSV2DBConfig) -> None: -- Gitee From cbaf2e85ba1fce5870121e92defb4af1b368fc7f Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:15:44 +0800 Subject: [PATCH 11/23] markdown --- .../accuracy_tools/msprobe/docs/19.monitor.md | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 8f2ca4f80f..40ab643103 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -467,7 +467,7 @@ csv2tensorboard_by_step( ) ``` -将csv数据转换为sqlite数据。 +将csv数据转换为sqlite db数据。 ```python from msprobe.pytorch.monitor.csv2db import CSV2DBConfig, csv2db @@ -571,6 +571,27 @@ csv2tensorboard_by_step(monitor_path, time_start, time_end, process_num=1, data_ | process_num | 指定拉起的进程个数,默认为1,更多的进程个数可以加速转换。 | 否 | | data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"]。
不指定就转换全部数据。 | 否 | | output_dirpath | 指定转换后的输出路径,默认输出到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳。 | 否 | + +- CSV转数据库接口说明 +```python +csv2db(config: CSV2DBConfig) -> None +``` +配置参数 (CSV2DBConfig) + +| 参数 | 说明 | 是否必选 | +| -------------- | ------------------------------------------------------------ | -------- | +| monitor_path | 待转换的csv存盘目录。 | 是 | +| time_start | 起始时间戳。搭配time_end一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。默认为None不限制。 | 否 | +| time_end | 结束时间戳。搭配time_start一起使用。指定一个时间范围,会对这个范围内的文件进行转换。左闭右闭的区间。默认为None不限制。 | 否 | +| process_num | 指定拉起的进程个数,默认为1,更多的进程个数可以加速转换。 | 否 | +| data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"]。
不指定就转换全部数据。 | 否 | +| step_partition | 控制数据库中按step分区的间隔,默认每500步一个表。 | 否 | +| output_dirpath | 指定转换后的输出路径,默认输出到"{curtime}_csv2db"文件夹,其中curtime为自动获取的当前时间戳。 | 否 | + +## 使用示例 + +### 基本用法 + - 在模型任意位置获取当前参数**梯度**统计量 ```python TrainerMon.generate_wgrad_metrics() -> tuple[dict, dict] -- Gitee From 60cf57fb496c09b975267b9bbe8eecc5bc02727f Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:25:44 +0800 Subject: [PATCH 12/23] cleancode dbmanager --- .../msprobe/core/common/db_manager.py | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index e74f7db2c0..970b9e29c4 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -38,24 +38,6 @@ class DBManager: """ self.db_path = db_path - @staticmethod - def _db_operation(func): - """数据库操作装饰器,自动管理连接""" - @wraps(func) - def wrapper(self, *args, **kwargs): - conn, curs = None, None - try: - conn, curs = self._get_connection() - return func(self, conn, curs, *args, **kwargs) - except sqlite3.Error as err: - logger.error(f"Database operation failed: {err}") - if conn: - conn.rollback() - finally: - self._release_connection(conn, curs) - return - return wrapper - @staticmethod def _get_where_sql(where_list): if not where_list: @@ -71,7 +53,28 @@ class DBManager: where_sql = " WHERE " + " AND ".join(where_clauses) return where_sql, tuple(where_values) - @_db_operation + def db_operation(func): + """数据库操作装饰器,自动管理连接""" + @wraps(func) + def wrapper(self, *args, **kwargs): + conn, curs = None, None + try: + conn, curs = self._get_connection() + result = func(self, conn, curs, *args, **kwargs) + return result # 显式返回正常结果 + + except sqlite3.Error as err: + logger.error(f"Database operation failed: {err}") + if conn: + conn.rollback() + return None # 显式返回错误情况下的None + + finally: + self._release_connection(conn, curs) + + return wrapper + + @db_operation def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, data: List[Tuple], key_list: List[str] = None) -> int: """ @@ -106,7 +109,7 @@ class DBManager: conn.commit() return inserted_rows - @_db_operation + @db_operation def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, columns: List[str] = None, @@ -127,7 +130,7 @@ class DBManager: return [dict(row) for row in curs.fetchall()] - @_db_operation + @db_operation def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, updates: Dict[str, Any], where: dict = None) -> int: @@ -150,7 +153,7 @@ class DBManager: conn.commit() return curs.rowcount - @_db_operation + @db_operation def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, sql: str, params: Tuple = None) -> List[Dict]: """ @@ -177,7 +180,7 @@ class DBManager: ) return len(result) > 0 - @_db_operation + @db_operation def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, sql_commands: List[str]) -> List[List[Dict]]: """ @@ -215,3 +218,5 @@ class DBManager: except sqlite3.Error as err: logger.error(f"Failed to release database connection: {err}") change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY) + + \ No newline at end of file -- Gitee From 238ec4dd67e23e8d2f73e39a3612635918aef295 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:27:54 +0800 Subject: [PATCH 13/23] cleancode db utils --- .../msprobe/core/common/db_manager.py | 2 -- 
.../msprobe/core/monitor/db_utils.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index 970b9e29c4..bf3732aa98 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -218,5 +218,3 @@ class DBManager: except sqlite3.Error as err: logger.error(f"Failed to release database connection: {err}") change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY) - - \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py index c6476491e2..b135694c42 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py +++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py @@ -39,7 +39,7 @@ class MonitorSql: """数据库表参数类""" @staticmethod - def _create_monitoring_targets_table(): + def create_monitoring_targets_table(): """监控目标表""" return """ CREATE TABLE IF NOT EXISTS monitoring_targets ( @@ -51,7 +51,7 @@ class MonitorSql: )""" @staticmethod - def _create_monitoring_metrics_table(): + def create_monitoring_metrics_table(): """监控指标表""" return """ CREATE TABLE IF NOT EXISTS monitoring_metrics ( @@ -69,7 +69,7 @@ class MonitorSql: """ @staticmethod - def _create_metric_stats_table(): + def create_metric_stats_table(): """指标统计表""" return """ CREATE TABLE IF NOT EXISTS metric_stats ( @@ -80,7 +80,7 @@ class MonitorSql: ) WITHOUT ROWID""" @staticmethod - def _create_global_stat_table(): + def create_global_stat_table(): return """ CREATE TABLE IF NOT EXISTS global_stats ( stat_name TEXT PRIMARY KEY, @@ -96,10 +96,10 @@ class MonitorSql: :raises ValueError: 当表名不存在时 """ table_creators = { - "monitoring_targets": cls._create_monitoring_targets_table, - "monitoring_metrics": cls._create_monitoring_metrics_table, - "metric_stats": cls._create_metric_stats_table, - "global_stats": cls._create_global_stat_table, + "monitoring_targets": cls.create_monitoring_targets_table, + "monitoring_metrics": cls.create_monitoring_metrics_table, + "metric_stats": cls.create_metric_stats_table, + "global_stats": cls.create_global_stat_table, } if not table_name: return [table_creators.get(table, lambda x:"")() for table in table_creators] -- Gitee From 4ba46e52df03429992436922597d34ad40f09bc1 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:37:28 +0800 Subject: [PATCH 14/23] cleancode csv2db --- .../accuracy_tools/msprobe/core/monitor/csv2db.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 915da904b6..c7ef89d624 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -307,7 +307,8 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list except Exception as e: logger.error( f"Failed to process Rank {rank}: {str(e)}") - + return False + return True def csv2db(config: CSV2DBConfig) -> None: """Main function to convert CSV files to database""" @@ -333,12 +334,18 @@ def csv2db(config: CSV2DBConfig) -> None: db = MonitorDB(db_path, step_partition_size=config.step_partition) - import_data( + result = import_data( db, target_output_dirs, config.data_type_list if config.data_type_list else all_data_type_list, workers=config.process_num ) - 
recursive_chmod(config.output_dirpath) - logger.info(f"Output has been saved to: {config.output_dirpath}") + if result: + logger.info("Data import completed. Output saved to: %s", config.output_dirpath) + else: + logger.warning( + "Data import may be incomplete. Output directory: %s " + "(Some records might have failed)", + config.output_dirpath + ) -- Gitee From 5e916a06e65cc89d0cdcdd44c00db344dfcc6741 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:41:22 +0800 Subject: [PATCH 15/23] markdown --- debug/accuracy_tools/msprobe/docs/19.monitor.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 40ab643103..a5abfd4c51 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -470,7 +470,7 @@ csv2tensorboard_by_step( 将csv数据转换为sqlite db数据。 ```python -from msprobe.pytorch.monitor.csv2db import CSV2DBConfig, csv2db +from msprobe.core.monitor.csv2db import CSV2DBConfig, csv2db # output_dirpath可指定输出目录,默认保存到"{curtime}_csv2db"文件夹,其中curtime为自动获取的当前时间戳 # step_partition可以控制数据库中按step分区的间隔,默认每500步一个表 config = CSV2DBConfig( -- Gitee From 3c8abc428d18c7a603e0864b3617f279a504ff47 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:46:14 +0800 Subject: [PATCH 16/23] markdown --- debug/accuracy_tools/msprobe/core/common/db_manager.py | 2 +- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index bf3732aa98..9e12f89a4f 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -53,7 +53,7 @@ class DBManager: where_sql = " WHERE " + " AND ".join(where_clauses) return where_sql, tuple(where_values) - def db_operation(func): + def db_operation(self, func): """数据库操作装饰器,自动管理连接""" @wraps(func) def wrapper(self, *args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index c7ef89d624..45bfcd81c7 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -310,6 +310,7 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list return False return True + def csv2db(config: CSV2DBConfig) -> None: """Main function to convert CSV files to database""" validate_process_num(config.process_num) -- Gitee From 672ce8affc47287e05f4a5151b35f0beb930e1e0 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:52:16 +0800 Subject: [PATCH 17/23] bugfix --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 45bfcd81c7..05d21d604e 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -343,10 +343,9 @@ def csv2db(config: CSV2DBConfig) -> None: ) recursive_chmod(config.output_dirpath) if result: - logger.info("Data import completed. Output saved to: %s", config.output_dirpath) + logger.info(f"Data import completed. Output saved to: {config.output_dirpath}") else: logger.warning( - "Data import may be incomplete. 
Output directory: %s " - "(Some records might have failed)", - config.output_dirpath + f"Data import may be incomplete. Output directory: {config.output_dirpath} " + f"(Some records might have failed)" ) -- Gitee From b7a2c4d8f2c4c5dac23b0d412d46ea0a1a606c35 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:59:02 +0800 Subject: [PATCH 18/23] mdfix --- debug/accuracy_tools/msprobe/docs/19.monitor.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index a5abfd4c51..9f2f38e846 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -572,7 +572,7 @@ csv2tensorboard_by_step(monitor_path, time_start, time_end, process_num=1, data_ | data_type_list | 指定需要转换的数据类型, 数据类型应来自输出件文件前缀,所有类型数据:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"]。
不指定就转换全部数据。 | 否 | | output_dirpath | 指定转换后的输出路径,默认输出到"{curtime}_csv2tensorboard_by_step"文件夹,其中curtime为自动获取的当前时间戳。 | 否 | -- CSV转数据库接口说明 +- CSV转sqlite数据库接口 ```python csv2db(config: CSV2DBConfig) -> None ``` @@ -588,9 +588,6 @@ csv2db(config: CSV2DBConfig) -> None | step_partition | 控制数据库中按step分区的间隔,默认每500步一个表。 | 否 | | output_dirpath | 指定转换后的输出路径,默认输出到"{curtime}_csv2db"文件夹,其中curtime为自动获取的当前时间戳。 | 否 | -## 使用示例 - -### 基本用法 - 在模型任意位置获取当前参数**梯度**统计量 ```python -- Gitee From aa4c54d1522fae00c725b7b7ac778887e2f4208f Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 16:01:24 +0800 Subject: [PATCH 19/23] bugfix --- .../msprobe/core/common/db_manager.py | 50 +++++++++---------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index 9e12f89a4f..7ca4a8e2f0 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -20,6 +20,25 @@ from msprobe.pytorch.common.log import logger from msprobe.core.common.file_utils import check_path_before_create, change_mode from msprobe.core.common.const import FileCheckConst +def _db_operation(func): + """数据库操作装饰器,自动管理连接""" + @wraps(func) + def wrapper(self, *args, **kwargs): + conn, curs = None, None + try: + conn, curs = self._get_connection() + result = func(self, conn, curs, *args, **kwargs) + return result # 显式返回正常结果 + + except sqlite3.Error as err: + logger.error(f"Database operation failed: {err}") + if conn: + conn.rollback() + return None # 显式返回错误情况下的None + + finally: + self._release_connection(conn, curs) + return wrapper class DBManager: """ @@ -53,28 +72,7 @@ class DBManager: where_sql = " WHERE " + " AND ".join(where_clauses) return where_sql, tuple(where_values) - def db_operation(self, func): - """数据库操作装饰器,自动管理连接""" - @wraps(func) - def wrapper(self, *args, **kwargs): - conn, curs = None, None - try: - conn, curs = self._get_connection() - result = func(self, conn, curs, *args, **kwargs) - return result # 显式返回正常结果 - - except sqlite3.Error as err: - logger.error(f"Database operation failed: {err}") - if conn: - conn.rollback() - return None # 显式返回错误情况下的None - - finally: - self._release_connection(conn, curs) - - return wrapper - - @db_operation + @_db_operation def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, data: List[Tuple], key_list: List[str] = None) -> int: """ @@ -109,7 +107,7 @@ class DBManager: conn.commit() return inserted_rows - @db_operation + @_db_operation def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, columns: List[str] = None, @@ -130,7 +128,7 @@ class DBManager: return [dict(row) for row in curs.fetchall()] - @db_operation + @_db_operation def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, updates: Dict[str, Any], where: dict = None) -> int: @@ -153,7 +151,7 @@ class DBManager: conn.commit() return curs.rowcount - @db_operation + @_db_operation def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, sql: str, params: Tuple = None) -> List[Dict]: """ @@ -180,7 +178,7 @@ class DBManager: ) return len(result) > 0 - @db_operation + @_db_operation def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, sql_commands: List[str]) -> List[List[Dict]]: """ -- Gitee From b3230566309c8ce859134ce3adb2c16620d1c668 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 16:21:34 +0800 Subject: 
From b3230566309c8ce859134ce3adb2c16620d1c668 Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Fri, 1 Aug 2025 16:21:34 +0800
Subject: [PATCH 20/23] bugfix

---
 .../msprobe/core/common/db_manager.py         | 38 ++++++++++---------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py
index 7ca4a8e2f0..28b5fcb2b8 100644
--- a/debug/accuracy_tools/msprobe/core/common/db_manager.py
+++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py
@@ -20,25 +20,27 @@ from msprobe.pytorch.common.log import logger
 from msprobe.core.common.file_utils import check_path_before_create, change_mode
 from msprobe.core.common.const import FileCheckConst
 
+
 def _db_operation(func):
-    """Decorator for database operations; manages the connection automatically."""
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        conn, curs = None, None
-        try:
-            conn, curs = self._get_connection()
-            result = func(self, conn, curs, *args, **kwargs)
-            return result  # explicitly return the normal result
-
-        except sqlite3.Error as err:
-            logger.error(f"Database operation failed: {err}")
-            if conn:
-                conn.rollback()
-            return None  # explicitly return None on error
-
-        finally:
-            self._release_connection(conn, curs)
-    return wrapper
+    """Decorator for database operations; manages the connection automatically."""
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        conn, curs = None, None
+        try:
+            conn, curs = self._get_connection()
+            result = func(self, conn, curs, *args, **kwargs)
+            return result  # explicitly return the normal result
+
+        except sqlite3.Error as err:
+            logger.error(f"Database operation failed: {err}")
+            if conn:
+                conn.rollback()
+            return None  # explicitly return None on error
+
+        finally:
+            self._release_connection(conn, curs)
+    return wrapper
+
 
 class DBManager:
     """
-- Gitee

From 1d9a17e28b2ad876660d3c2e74114b4a72a8c828 Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Mon, 4 Aug 2025 10:01:06 +0800
Subject: [PATCH 21/23] bugfix

---
 .../msprobe/core/monitor/csv2db.py            | 77 +++++++++++--------
 .../accuracy_tools/msprobe/docs/19.monitor.md |  2 +-
 2 files changed, 44 insertions(+), 35 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
index 05d21d604e..3a1dd7d327 100644
--- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
+++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
@@ -34,12 +34,11 @@ from tqdm import tqdm
 # Constants
 all_data_type_list = [
     "actv", "actv_grad", "exp_avg", "exp_avg_sq",
-    "grad_unreduced", "grad_reduced", "param_origin", "param_updated",
-    "linear_hook", "norm_hook", "proxy_model", "token_hook", "attention_hook"
+    "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other"
 ]
 DEFAULT_INT_VALUE = 0
 MAX_PROCESS_NUM = 128
-CSV_FILE_PATTERN = r"(\w+)_(\d+)-(\d+)\.csv"
+CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv"
 BATCH_SIZE = 10000
 
 
@@ -83,6 +82,17 @@ def validate_data_type_list(data_type_list: Optional[List[str]]) -> None:
         raise ValueError(f"Unsupported data types: {invalid_types}")
 
 
+def get_info_from_filename(file_name, metric_list=None):
+    metric_name = "_".join(file_name.split('_')[:-1])
+    if metric_list and metric_name not in metric_list:
+        return "", 0, 0
+    match = re.match(f"{metric_name}{CSV_FILE_PATTERN}", file_name)
+    if not match:
+        return "", 0, 0
+    step_start, step_end = match.groups()
+    return metric_name, step_start, step_end
+
+
 def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict:
     """Pre-scan files for a single rank to collect metadata"""
     metrics = set()
@@ -93,11 +103,9 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict:
 
     for file_path in files:
         file_name = os.path.basename(file_path)
-        match = re.match(CSV_FILE_PATTERN, file_name)
-        if not match:
+        metric_name, step_start, 
step_end = get_info_from_filename(file_name) + if not metric_name: continue - - metric_name, step_start, step_end = match.groups() step_start, step_end = int(step_start), int(step_end) metrics.add(metric_name) @@ -109,10 +117,15 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED] metric_stats[metric_name].update(stats) - for _, row in data.iterrows(): - name = row[MonitorConst.HEADER_NAME] - vpp_stage = int(row['vpp_stage']) - micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + for row_id, row in data.iterrows(): + try: + name = row[MonitorConst.HEADER_NAME] + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + except (ValueError, KeyError) as e: + logger.warning( + f"CSV conversion failed | file={file_path}:{row_id+2} | error={str(e)}") + continue target = (name, vpp_stage, micro_step) if target not in targets: targets[target] = None @@ -136,11 +149,9 @@ def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: for rank, dir_path in data_dirs.items(): files = os.listdir(dir_path) for file in files: - match = re.match(CSV_FILE_PATTERN, file) - if not match: - continue - metric_name, _, _ = match.groups() - if metric_name not in data_type_list: + metric_name, _, _ = get_info_from_filename( + file, metric_list=data_type_list) + if not metric_name: continue rank_files[rank].append(os.path.join(dir_path, file)) @@ -207,11 +218,9 @@ def process_single_rank( for file in files: filename = os.path.basename(file) - match = re.match(CSV_FILE_PATTERN, filename) - if not match: + metric_name, _, _ = get_info_from_filename(filename) + if not metric_name: continue - - metric_name, _, _ = match.groups() metric_info = metric_id_dict.get(metric_name) if not metric_info: continue @@ -236,21 +245,20 @@ def process_single_rank( float(row[stat]) if stat in row else None for stat in stats ) - table_batches[table_name].append(tuple(row_data)) - - # Batch insert when threshold reached - if len(table_batches[table_name]) >= BATCH_SIZE: - inserted = db.insert_rows( - table_name, table_batches[table_name]) - if inserted is not None: - total_inserted += inserted - table_batches[table_name] = [] - except (ValueError, KeyError) as e: logger.error( - f"CSV float conversion failed | file={file}:{row_id+2} | error={str(e)}") + f"CSV conversion failed | file={file}:{row_id+2} | error={str(e)}") continue + table_batches[table_name].append(tuple(row_data)) + # Batch insert when threshold reached + if len(table_batches[table_name]) >= BATCH_SIZE: + inserted = db.insert_rows( + table_name, table_batches[table_name]) + if inserted is not None: + total_inserted += inserted + table_batches[table_name] = [] + # Insert remaining data for table_name, batch in table_batches.items(): if batch: @@ -293,8 +301,8 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list metric_id_dict, target_dict, monitor_db.step_partition_size, - monitor_db.db_path): rank - for rank, files in rank_tasks.items() + monitor_db.db_path): rank + for rank, files in rank_tasks.items() } with tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar: @@ -343,7 +351,8 @@ def csv2db(config: CSV2DBConfig) -> None: ) recursive_chmod(config.output_dirpath) if result: - logger.info(f"Data import completed. Output saved to: {config.output_dirpath}") + logger.info( + f"Data import completed. 
Output saved to: {config.output_dirpath}")
     else:
         logger.warning(
             f"Data import may be incomplete. Output directory: {config.output_dirpath} "
diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md
index 9f2f38e846..e7f28fead6 100644
--- a/debug/accuracy_tools/msprobe/docs/19.monitor.md
+++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md
@@ -584,7 +584,7 @@ csv2db(config: CSV2DBConfig) -> None
 | time_start | Start timestamp. Used together with time_end to specify a time range; files in this range are converted. The interval is closed on both ends. Defaults to None (no limit). | No |
 | time_end | End timestamp. Used together with time_start to specify a time range; files in this range are converted. The interval is closed on both ends. Defaults to None (no limit). | No |
 | process_num | Number of processes to spawn; defaults to 1. More processes speed up the conversion. | No |
-| data_type_list | Data types to convert; the types come from the output file prefixes. All available types:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated"]。
If unspecified, all data types are converted. | No |
+| data_type_list | Data types to convert; the types come from the output file prefixes. All available types:
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other"]。
If unspecified, all data types are converted. | No |
 | step_partition | Interval for step-based partitioning in the database; by default, one table per 500 steps. | No |
 | output_dirpath | Output path for the converted data; defaults to a "{curtime}_csv2db" folder, where curtime is the automatically captured current timestamp. | No |
-- Gitee

From ec69c159f250cf9e4138d3e62298de415456215d Mon Sep 17 00:00:00 2001
From: jiandaobao
Date: Tue, 5 Aug 2025 21:33:20 +0800
Subject: [PATCH 22/23] bugfix

---
 .../msprobe/core/common/const.py              |  9 +++++
 .../msprobe/core/monitor/csv2db.py            | 35 ++++++++++---------
 2 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 039253180f..719fdda298 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -840,3 +840,12 @@ class MonitorConst:
         TRAIN_STAGE[key] = BACKWARD_STAGE
     for key in OPTIMIZER_KEY:
         TRAIN_STAGE[key] = OPTIMIZER_STAGE
+
+    # csv2db
+    DEFAULT_INT_VALUE = 0
+    MAX_PROCESS_NUM = 128
+    CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv"
+    BATCH_SIZE = 10000
+    MAX_PARTITION = 10_000_000
+    MIN_PARTITION = 10
+    
\ No newline at end of file
diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
index 3a1dd7d327..ef8d4e26c3 100644
--- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
+++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py
@@ -36,10 +36,7 @@ all_data_type_list = [
     "actv", "actv_grad", "exp_avg", "exp_avg_sq",
     "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other"
 ]
-DEFAULT_INT_VALUE = 0
-MAX_PROCESS_NUM = 128
-CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv"
-BATCH_SIZE = 10000
+
 
 
 @dataclass
@@ -58,14 +55,18 @@ def validate_process_num(process_num: int) -> None:
     """Validate process number parameter"""
     if not is_int(process_num) or process_num <= 0:
         raise ValueError("process_num must be a positive integer")
-    if process_num > MAX_PROCESS_NUM:
-        raise ValueError(f"Maximum supported process_num is {MAX_PROCESS_NUM}")
+    if process_num > MonitorConst.MAX_PROCESS_NUM:
+        raise ValueError(f"Maximum supported process_num is {MonitorConst.MAX_PROCESS_NUM}")
 
 
 def validate_step_partition(step_partition: int) -> None:
-    """Validate step partition parameter"""
-    if not is_int(step_partition) or step_partition <= 0:
-        raise ValueError("step_partition must be a positive integer")
+    if not isinstance(step_partition, int):
+        raise TypeError("step_partition must be an integer")
+    if not MonitorConst.MIN_PARTITION <= step_partition <= MonitorConst.MAX_PARTITION:
+        raise ValueError(
+            f"step_partition must be between {MonitorConst.MIN_PARTITION} "
+            f"and {MonitorConst.MAX_PARTITION}, got {step_partition}"
+        )
 
 
 def validate_data_type_list(data_type_list: Optional[List[str]]) -> None:
@@ -86,7 +87,7 @@ def get_info_from_filename(file_name, metric_list=None):
     metric_name = "_".join(file_name.split('_')[:-1])
     if metric_list and metric_name not in metric_list:
         return "", 0, 0
-    match = re.match(f"{metric_name}{CSV_FILE_PATTERN}", file_name)
+    match = re.match(f"{metric_name}{MonitorConst.CSV_FILE_PATTERN}", file_name)
     if not match:
         return "", 0, 0
     step_start, step_end = match.groups()
@@ -121,7 +122,7 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict:
             try:
                 name = row[MonitorConst.HEADER_NAME]
                 vpp_stage = int(row['vpp_stage'])
-                micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE))
+                micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE))
             except (ValueError, KeyError) as e:
                 logger.warning(
                     f"CSV conversion failed | file={file_path}:{row_id+2} | error={str(e)}")
@@ -232,7 +233,7 @@ def process_single_rank(
                     # Parse row data
                     name = row.get(MonitorConst.HEADER_NAME)
                     vpp_stage = int(row['vpp_stage'])
-                    micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE))
+                    micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE))
                     target_id = target_dict.get((name, vpp_stage, micro_step))
                     if not target_id:
                         continue
@@ -252,7 +253,7 @@ def process_single_rank(
 
                 table_batches[table_name].append(tuple(row_data))
                 # Batch insert when threshold reached
-                if len(table_batches[table_name]) >= BATCH_SIZE:
+                if len(table_batches[table_name]) >= MonitorConst.BATCH_SIZE:
                     inserted = db.insert_rows(
                         table_name, table_batches[table_name])
                     if inserted is not None:
@@ -290,9 +291,9 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list
     # 3. Process data for each rank in parallel
     total_files = sum(len(files) for files in rank_tasks.values())
     logger.info(f"Starting data import for {len(rank_tasks)} ranks,"
-                "{total_files} files..."
+                f"{total_files} files..."
                 )
-
+    all_succeeded = True
     with ProcessPoolExecutor(max_workers=workers) as executor:
         futures = {
             executor.submit(
@@ -315,8 +316,8 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list
             except Exception as e:
                 logger.error(
                     f"Failed to process Rank {rank}: {str(e)}")
-                return False
-    return True
+                all_succeeded = False
+    return all_succeeded
 
 
 def csv2db(config: CSV2DBConfig) -> None:
-- Gitee
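For orientation, the filename convention that PATCH 21 and PATCH 22 centralize in `MonitorConst.CSV_FILE_PATTERN` is `<metric>_<step_start>-<step_end>.csv`. A small self-contained sketch mirroring `get_info_from_filename` (renamed here to make clear it is an illustration, not the shipped function):

```python
import re

CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv"  # mirrors MonitorConst.CSV_FILE_PATTERN

def parse_monitor_csv_name(file_name: str):
    # Everything before the last "_" is the metric name, so metrics that
    # themselves contain underscores ("actv_grad") survive intact.
    metric_name = "_".join(file_name.split('_')[:-1])
    match = re.match(f"{metric_name}{CSV_FILE_PATTERN}", file_name)
    if not match:
        return "", 0, 0
    step_start, step_end = match.groups()  # strings; callers int() them
    return metric_name, step_start, step_end

assert parse_monitor_csv_name("actv_grad_0-499.csv") == ("actv_grad", "0", "499")
assert parse_monitor_csv_name("param_origin_500-999.csv") == ("param_origin", "500", "999")
```

This also explains the PATCH 21 change to the pattern: the old `(\w+)_(\d+)-(\d+)` prefix group could not match metric names containing underscores, so the metric is now derived by splitting the filename instead.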
["actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other"]。
If unspecified, all data types are converted. | No |
 | step_partition | Interval for step-based partitioning in the database; by default, one table per 500 steps. | No |
-- Gitee
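Closing note for reviewers: the converter's public surface after this series is `CSV2DBConfig` plus `csv2db`. A minimal usage sketch, assuming the module path `msprobe.core.monitor.csv2db` from the diffs and the field names documented in the table above; values are illustrative:

```python
from msprobe.core.monitor.csv2db import CSV2DBConfig, csv2db

config = CSV2DBConfig(
    monitor_path="./monitor_output",       # directory holding the monitor CSV files
    process_num=4,                         # parallel workers, capped by MonitorConst.MAX_PROCESS_NUM (128)
    data_type_list=["actv", "actv_grad"],  # convert only these file prefixes
    step_partition=500,                    # one partition table per 500 steps
    output_dirpath="./monitor_db",         # where the SQLite database is written
)
csv2db(config)
```

With process_num > 1 each rank's files are handled by a separate worker process, and per PATCH 22 import_data now continues past a failed rank and reports overall success only at the end, so a partial import is logged as a warning rather than aborting the run.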