diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 2cbd1750f1fb0adcbe22b930bb1f4b4278f6056c..0f76069a7f9e81cd6ddf08a467c5cdf2f7a55f5c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -1 +1,57 @@ -# 基于api——info信息,将其落盘为json文件 \ No newline at end of file +import fcntl +import json +import os +import threading +import numpy as np + +from .api_info import ForwardAPIInfo, BackwardAPIInfo +from .utils import DumpUtil +from ..common.utils import check_file_or_directory_path + +lock = threading.Lock() + +def write_api_info_json(api_info): + dump_path = DumpUtil.dump_path + initialize_output_json() + if isinstance(api_info, ForwardAPIInfo): + file_path = os.path.join(dump_path, 'forward_info.json') + stack_file_path = os.path.join(dump_path, 'stack_info.json') + write_json(file_path, api_info.api_info_struct) + write_json(stack_file_path, api_info.stack_info_struct, indent=4) + + elif isinstance(api_info, BackwardAPIInfo): + file_path = os.path.join(dump_path, 'backward_info.json') + write_json(file_path, api_info.grad_info_struct) + else: + raise ValueError(f"Invalid api_info type {type(api_info)}") + +def write_json(file_path, data, indent=None): + check_file_or_directory_path(file_path,False) + with open(file_path, 'w') as f: + f.write("{\n}") + try: + lock.acquire() + with open(file_path, 'a+') as f: + fcntl.flock(f, fcntl.LOCK_EX) + f.seek(0, os.SEEK_END) + f.seek(f.tell() - 1, os.SEEK_SET) + f.truncate() + if f.tell() > 3: + f.seek(f.tell() - 1, os.SEEK_SET) + f.truncate() + f.write(',\n') + f.write(json.dumps(data, indent=indent)[1:-1] + '\n}') + except Exception as e: + raise ValueError(f"Json save failed:{e}") + finally: + fcntl.flock(f, fcntl.LOCK_UN) + lock.release() + +def initialize_output_json(): + dump_path = DumpUtil.dump_path + check_file_or_directory_path(dump_path,True) + files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] + for file in files: + file_path = os.path.join(dump_path, file) + if os.path.exists(file_path): + raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") \ No newline at end of file