From 25f0c0248565ab908eec140c1b547d0800c64199 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 23 Oct 2024 11:07:34 +0800 Subject: [PATCH] fix json safe problem --- .../msprobe/core/common/file_utils.py | 4 ++-- .../pytorch/online_dispatch/compare.py | 20 +------------------ .../pytorch/online_dispatch/dump_compare.py | 8 +++----- .../test_compare_online_dispatch.py | 18 +---------------- 4 files changed, 7 insertions(+), 43 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py index c7c020130..96fe8ca49 100644 --- a/debug/accuracy_tools/msprobe/core/common/file_utils.py +++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py @@ -365,11 +365,11 @@ def load_json(json_path): return data -def save_json(json_path, data, indent=None): +def save_json(json_path, data, indent=None, mode="w"): check_path_before_create(json_path) json_path = os.path.realpath(json_path) try: - with FileOpen(json_path, 'w') as f: + with FileOpen(json_path, mode) as f: fcntl.flock(f, fcntl.LOCK_EX) json.dump(data, f, indent=indent) fcntl.flock(f, fcntl.LOCK_UN) diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py index 07989907a..008bb2b32 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py @@ -20,7 +20,7 @@ import sys from collections import namedtuple from msprobe.core.common.const import CompareConst, FileCheckConst -from msprobe.core.common.file_utils import FileOpen, change_mode, read_csv +from msprobe.core.common.file_utils import FileOpen, change_mode, read_csv, get_json_contents from msprobe.core.common.utils import CompareException, check_op_str_pattern_valid from msprobe.pytorch.common.log import logger from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap @@ -35,24 +35,6 @@ ResultInfo = namedtuple('ResultInfo', ['api_name', 'is_fwd_success', 'is_bwd_suc 'fwd_compare_alg_results', 'bwd_compare_alg_results']) -def get_file_content_bytes(file): - with FileOpen(file, 'rb') as file_handle: - return file_handle.read() - - -def get_json_contents(file_path): - ops = get_file_content_bytes(file_path) - try: - json_obj = json.loads(ops) - except ValueError as error: - logger.error('Failed to load "%s". %s' % (file_path, str(error))) - raise CompareException(CompareException.INVALID_FILE_ERROR) from error - if not isinstance(json_obj, dict): - logger.error('Json file %s, content is not a dictionary!' % file_path) - raise CompareException(CompareException.INVALID_FILE_ERROR) - return json_obj - - def write_csv(data, filepath): with FileOpen(filepath, 'a', encoding='utf-8-sig') as f: writer = csv.writer(f) diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py index edb9c40d3..b185bc111 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py @@ -19,7 +19,7 @@ import os from datetime import datetime, timezone import torch -from msprobe.core.common.file_utils import FileOpen, save_npy +from msprobe.core.common.file_utils import FileOpen, save_npy, save_json from msprobe.pytorch.common.log import logger @@ -107,10 +107,8 @@ def dump_data(data, prefix, dump_path): def save_temp_summary(api_index, single_api_summary, path, lock): summary_path = os.path.join(path, f'summary.json') lock.acquire() - with FileOpen(summary_path, "a") as f: - json.dump([api_index, single_api_summary], f) - f.write('\n') - lock.release() + data = [api_index, single_api_summary] + save_json(summary_path, data, mode='a') def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_compare_online_dispatch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_compare_online_dispatch.py index e0c3c3368..47db945c2 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_compare_online_dispatch.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_compare_online_dispatch.py @@ -22,7 +22,7 @@ from unittest.mock import Mock, patch import pandas as pd from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.utils import CompareException -from msprobe.pytorch.online_dispatch.compare import get_json_contents, Saver, Comparator +from msprobe.pytorch.online_dispatch.compare import Saver, Comparator from rich.table import Table from io import StringIO from rich.console import Console @@ -41,22 +41,6 @@ class TestCompare(unittest.TestCase): if os.path.exists(self.list_json_path): os.remove(self.list_json_path) - def test_get_json_contents_when_get_json(self): - data = {"one": 1} - with FileOpen(self.dict_json_path, 'w') as f: - json.dump(data, f) - self.assertEqual(get_json_contents(self.dict_json_path), data) - - @patch('msprobe.core.common.log.BaseLogger.error') - def test_get_json_contents_when_get_list(self, mock_error): - data = [1, 2] - with FileOpen(self.list_json_path, 'w') as f: - json.dump(data, f) - with self.assertRaises(CompareException) as context: - get_json_contents(self.list_json_path) - self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR) - mock_error.assert_called_once_with('Json file %s, content is not a dictionary!' % self.list_json_path) - class TestSaver(unittest.TestCase): def setUp(self): -- Gitee