diff --git a/profiler/msprof_analyze/advisor/advisor_backend/compute_advice/npu_slow_advice.py b/profiler/msprof_analyze/advisor/advisor_backend/compute_advice/npu_slow_advice.py index 5f2f123fb6867b6b8fc48e8049f6687c5c9369d9..2a2040e1712bb330bfeaaa5552fe8f959e6c912c 100644 --- a/profiler/msprof_analyze/advisor/advisor_backend/compute_advice/npu_slow_advice.py +++ b/profiler/msprof_analyze/advisor/advisor_backend/compute_advice/npu_slow_advice.py @@ -37,12 +37,11 @@ class NpuSlowAdvice(ComputeAdviceBase, ABC): @staticmethod def save_to_excel(data: pd.DataFrame, file_path: str) -> None: PathManager.check_path_writeable(os.path.dirname(file_path)) - writer = pd.ExcelWriter(file_path, engine="xlsxwriter", mode="w") - data.index.name = Constant.TITLE.INDEX - data.to_excel(writer, index=True, sheet_name=NpuSlowAdvice.OP_PERF_SHEET) - NpuSlowAdvice.color_sheet(data, writer.book, writer.sheets[NpuSlowAdvice.OP_PERF_SHEET]) - writer.sheets[NpuSlowAdvice.OP_PERF_SHEET].freeze_panes = "A2" - writer.close() + with pd.ExcelWriter(file_path, engine="xlsxwriter", mode="w") as writer: + data.index.name = Constant.TITLE.INDEX + data.to_excel(writer, index=True, sheet_name=NpuSlowAdvice.OP_PERF_SHEET) + NpuSlowAdvice.color_sheet(data, writer.book, writer.sheets[NpuSlowAdvice.OP_PERF_SHEET]) + writer.sheets[NpuSlowAdvice.OP_PERF_SHEET].freeze_panes = "A2" @staticmethod def color_sheet(data: pd.DataFrame, workbook, worksheet): @@ -80,7 +79,6 @@ class NpuSlowAdvice(ComputeAdviceBase, ABC): self.data = pd.read_csv(self.kernel_details_path, dtype={"Start Time(us)": str}) # 去除末尾的\t分隔符 self.data["Start Time(us)"] = self.data["Start Time(us)"].apply(lambda x: x[:-1]) - pool = multiprocessing.Pool(multiprocessing.cpu_count()) - result = pool.map(self.update_op_row, self.data.iterrows()) - pool.close() + with multiprocessing.Pool(multiprocessing.cpu_count()) as pool: + result = pool.map(self.update_op_row, self.data.iterrows()) self.data = pd.DataFrame(result) diff --git a/profiler/msprof_analyze/advisor/common/profiling/ge_info.py b/profiler/msprof_analyze/advisor/common/profiling/ge_info.py index f255684290e1935928ba741dec4cfdc55341cfe5..b9ef012a5ae0d504e5edde24c33e6ae2afdd7bf6 100644 --- a/profiler/msprof_analyze/advisor/common/profiling/ge_info.py +++ b/profiler/msprof_analyze/advisor/common/profiling/ge_info.py @@ -18,8 +18,7 @@ import logging import os from typing import Any, List -from sqlalchemy import text -from sqlalchemy.exc import SQLAlchemyError +from msprof_analyze.prof_common.db_manager import DBManager from msprof_analyze.advisor.dataset.profiling.db_manager import ConnectionManager from msprof_analyze.advisor.dataset.profiling.profiling_parser import ProfilingParser @@ -51,14 +50,11 @@ class GeInfo(ProfilingParser): check_path_valid(db_path) if not ConnectionManager.check_db_exists(db_path, [db_file]): return False - try: - conn = ConnectionManager(db_path, db_file) - except SQLAlchemyError as e: - logger.error("Database error: %s", e) - return False - if conn.check_table_exists(['TaskInfo']): - with conn().connect() as sql_conn: - self.op_state_info_list = sql_conn.execute(text("select op_name, op_state from TaskInfo")).fetchall() + conn, cursor = DBManager.create_connect_db(db_path) + if DBManager.judge_table_exists(cursor, 'TaskInfo'): + sql = "select op_name, op_state from TaskInfo" + self.op_state_info_list = DBManager.fetch_all_data(cursor, sql) + DBManager.destroy_db_connect(conn, cursor) return True def get_static_shape_operators(self) -> List[Any]: diff --git a/profiler/msprof_analyze/advisor/dataset/communication/communication_dataset.py b/profiler/msprof_analyze/advisor/dataset/communication/communication_dataset.py index acd44ac137e2e0ebc05ea6a1ffef64f45cc4c628..0f2efcd433cc4352e21d069b2724464816b3cd26 100644 --- a/profiler/msprof_analyze/advisor/dataset/communication/communication_dataset.py +++ b/profiler/msprof_analyze/advisor/dataset/communication/communication_dataset.py @@ -40,9 +40,6 @@ class CommunicationDataset(Dataset): def __init__(self, collection_path, data: dict, **kwargs) -> None: self.collection_path = collection_path - if not collection_path.endswith("ascend_pt") and not collection_path.endswith("ascend_ms"): - return - self.is_pta = collection_path.endswith("ascend_pt") self.communication_file = "" self.hccl_dict = defaultdict(list) self.step = kwargs.get("step") @@ -138,7 +135,8 @@ class CommunicationDataset(Dataset): if not DBManager.check_tables_in_db(self.communication_file, *expected_tables): logger.warning(f"Communication tables: {expected_tables} not found in {self.communication_file}") return False - export = CommunicationInfoExport(self.communication_file, self.is_pta) + is_pta = self.collection_path.endswith("ascend_pt") + export = CommunicationInfoExport(self.communication_file, is_pta) df = export.read_export_db() if TableConstant.STEP not in df.columns: df[TableConstant.STEP] = 'step' diff --git a/profiler/msprof_analyze/advisor/dataset/stack/db_stack_finder.py b/profiler/msprof_analyze/advisor/dataset/stack/db_stack_finder.py index a61bd8d6a18f65214e8a2b8b068f7d745f33c5d3..56c7b9a02a3d6bb0dff04b5396fb56a1cdf48f74 100644 --- a/profiler/msprof_analyze/advisor/dataset/stack/db_stack_finder.py +++ b/profiler/msprof_analyze/advisor/dataset/stack/db_stack_finder.py @@ -144,13 +144,13 @@ class DBStackFinder: if not self._is_db_contains_stack(): self.stack_map[name] = None return False + conn, cursor = None, None try: conn, cursor = DBManager.create_connect_db(self._db_path) if params: df = pd.read_sql(sql, conn, params=params) else: df = pd.read_sql(sql, conn) - DBManager.destroy_db_connect(conn, cursor) if df is None or df.empty: self.stack_map[name] = None return False @@ -160,3 +160,7 @@ class DBStackFinder: logger.error(f"Error loading API stack data: {e}") self.stack_map[name] = None return False + finally: + if conn and cursor: + DBManager.destroy_db_connect(conn, cursor) + diff --git a/profiler/msprof_analyze/advisor/dataset/timeline_event_dataset.py b/profiler/msprof_analyze/advisor/dataset/timeline_event_dataset.py index 016d0b8753d379adaffd25778b92ada977764477..512a7ae16354be0dea91f824e14241c4592438d2 100644 --- a/profiler/msprof_analyze/advisor/dataset/timeline_event_dataset.py +++ b/profiler/msprof_analyze/advisor/dataset/timeline_event_dataset.py @@ -156,7 +156,7 @@ class BaseTimelineEventDataset(Dataset): for event_type in collector.get_event_type(): df = db_helper.query_timeline_event(event_type) collector.add_op_from_db(df) - db_helper.destory_db_connection() + db_helper.destroy_db_connection() return True def parse_data_with_generator(self, func): diff --git a/profiler/msprof_analyze/advisor/dataset/timeline_op_collector/timeline_op_sql.py b/profiler/msprof_analyze/advisor/dataset/timeline_op_collector/timeline_op_sql.py index 2e820af7d2c1aa0365a44ceb5e1cb42559d27acb..e96c4c60934621a86c90edc0d2084814c69d127e 100644 --- a/profiler/msprof_analyze/advisor/dataset/timeline_op_collector/timeline_op_sql.py +++ b/profiler/msprof_analyze/advisor/dataset/timeline_op_collector/timeline_op_sql.py @@ -243,7 +243,7 @@ class TimelineDBHelper: self.init = bool(self.conn and self.curs) return self.init - def destory_db_connection(self): + def destroy_db_connection(self): DBManager.destroy_db_connect(self.conn, self.curs) self.init = False diff --git a/profiler/msprof_analyze/compare_tools/compare_backend/view/excel_view.py b/profiler/msprof_analyze/compare_tools/compare_backend/view/excel_view.py index 6a094fdf3df8828f0a89e3333517459257440d19..8e35bc84e6f853ea4df0a92f5b19f8be3e3400c9 100644 --- a/profiler/msprof_analyze/compare_tools/compare_backend/view/excel_view.py +++ b/profiler/msprof_analyze/compare_tools/compare_backend/view/excel_view.py @@ -34,3 +34,4 @@ class ExcelView(BaseView): WorkSheetCreator(workbook, sheet_name, data, self._args).create_sheet() workbook.close() os.chmod(self._file_path, Constant.FILE_AUTHORITY) + diff --git a/profiler/msprof_analyze/prof_common/database_service.py b/profiler/msprof_analyze/prof_common/database_service.py index 8cd4cdd2a1f414ba6d7945a22dea8fc6c312f85e..45df254c3d8e461067e68292ada6fb3da8d1a89d 100644 --- a/profiler/msprof_analyze/prof_common/database_service.py +++ b/profiler/msprof_analyze/prof_common/database_service.py @@ -98,10 +98,6 @@ class DatabaseService: result_data[table_name] = data except Exception as err: logger.error(err) - return result_data - try: - DBManager.destroy_db_connect(conn, cursor) - except Exception as err: - logger.error(err) - return result_data + break + DBManager.destroy_db_connect(conn, cursor) return result_data diff --git a/profiler/msprof_analyze/prof_common/file_manager.py b/profiler/msprof_analyze/prof_common/file_manager.py index 737d1788417cbc1131b86b82144953b6449ead14..183ef16b4823f00d96d9537b41b494a03e34b732 100644 --- a/profiler/msprof_analyze/prof_common/file_manager.py +++ b/profiler/msprof_analyze/prof_common/file_manager.py @@ -118,7 +118,7 @@ class FileManager: PathManager.check_path_writeable(os.path.dirname(file_path)) try: with os.fdopen( - os.open(file_path, os.O_WRONLY | os.O_CREAT, Constant.FILE_AUTHORITY), + os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, Constant.FILE_AUTHORITY), 'w') as file: file.write(content) except Exception as e: diff --git a/profiler/msprof_analyze/test/ut/advisor/dataset/test_timeline_op_sql.py b/profiler/msprof_analyze/test/ut/advisor/dataset/test_timeline_op_sql.py index 6cc9cd0374e5f31351ac4f935f237859af30fe17..c94a963823f3fc2542811bebb8216d7fb5ed9cd1 100644 --- a/profiler/msprof_analyze/test/ut/advisor/dataset/test_timeline_op_sql.py +++ b/profiler/msprof_analyze/test/ut/advisor/dataset/test_timeline_op_sql.py @@ -131,6 +131,6 @@ class TestTimelineDBHelper(unittest.TestCase): self.db_helper.init = True self.db_helper.conn = MagicMock() self.db_helper.curs = MagicMock() - self.db_helper.destory_db_connection() + self.db_helper.destroy_db_connection() self.assertFalse(self.db_helper.init)