diff --git a/README.md b/README.md
index 55ab96d200f02b564ecac6dc8b62d4360a02c1c5..c3912062de7097af62955d1302af9ed3a2660ed6 100644
--- a/README.md
+++ b/README.md
@@ -26,10 +26,26 @@
 {
     "training root path": "str", # 训练代码根目录
     "exec cmd filepath": "str", # 启动命令所在文件
-    "config filepath": "list", # 可能有多个配置文件
+    "config_home_dirs": dict, # 多个配置文件路径
     "output zip path": "./config_check_pack.zip" # 打包结果路径
 }
 ```
+
+说明:
+> config_home_dirs【可选,缺省默认为training_root_path】:参数指定了多个配置文件的root地址,代码会自动扫描该文件夹下配置文件(当前支持ini、xml、yaml、yml、json等)并完成打包动作
+
+示例:
+
+```json
+{
+    "training root path": str,
+    "exec cmd filepath": str,
+    "config_home_dirs": {
+        "dir1": "path_to_config_dir1",
+        "dir2": "path_to_config_dir2"
+    }
+}
+```
 2、数据采集,将环境上需要校验的内容打包。
 
 **静态数据采集**
@@ -47,12 +63,22 @@ python -m config_checking -p config_path
 
 ```
 from config_checking.config_checker import ConfigChecker
+
 ConfigChecker(config_path, model)
 ```
 会得到一个zip包,里面除了静态信息,还包括训练使用的数据集、权重等动态信息,并会进行随机性设置的告警。
 
 采集到第一个迭代使用的训练数据后会直接退出。
 
+**随机语句设置patch**
+在训练代码初始阶段插入如下代码:
+```python
+from config_checking.utils import random_patch
+
+random_patch.apply_patches() # 针对所有random,打印调用日志
+```
+上述语句会在日志中打印代码执行random操作的堆栈信息,供用户查看
+
 3、npu和gpu上分别执行上述操作,获得两个zip包
 
 4、将两个zip包传到同一个环境下,使用如下命令进行比对:
diff --git a/config_checking/__main__.py b/config_checking/__main__.py
index e14f045eefcdceb73ff75dfd5c7f4f62222034ed..9d6f4ae582c2048f75bfb7ce729237abd453b466 100644
--- a/config_checking/__main__.py
+++ b/config_checking/__main__.py
@@ -19,12 +19,13 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Configuration Checker!')
     parser.add_argument('-p', '--pack', help='Pack a directory into a zip file')
     parser.add_argument('-c', '--compare', nargs=2, help='Compare two zip files')
-    parser.add_argument('-o', '--output', help='output path')
+    parser.add_argument('-o', '--output', help='output path, default is current directory')
     args = parser.parse_args()
     if args.pack:
         pack(args.pack)
     elif args.compare:
-        output_dirpath = args.output if args.output else "./config_check_result"
-        compare(args.compare[0], args.compare[1], output_dirpath)
+        # Run the comparison exactly once. NOTE(review): relies on `os` being imported above this hunk - confirm.
+        output_dirpath = args.output if args.output else os.getcwd()
+        compare(args.compare[0], args.compare[1], output_dirpath)
     else:
-        parser.print_help()
\ No newline at end of file
+        parser.print_help()
diff --git a/config_checking/checkers/__init__.py b/config_checking/checkers/__init__.py
index 712bab3ce682f33d260696b8c8e0f5da7d69e309..b3db1816526c5187f3cec5c171e88594efd19cdb 100644
--- a/config_checking/checkers/__init__.py
+++ b/config_checking/checkers/__init__.py
@@ -3,4 +3,10 @@ import config_checking.checkers.env_args_checker
 import config_checking.checkers.exec_cmd_checker
 import config_checking.checkers.pip_checker
 import config_checking.checkers.dataset_checker
-import config_checking.checkers.weights_checker
\ No newline at end of file
+import config_checking.checkers.weights_checker
+import config_checking.checkers.config_file_checker
+
+
+from config_checking.checkers.base_checker import BaseChecker
+
+__all__ = ['BaseChecker',]
diff --git a/config_checking/checkers/base_checker.py b/config_checking/checkers/base_checker.py
index 03ded1e25f90f6d839ae53843a5a71c6e9e48ea1..b2a04bb3a9a9699dee3c6f643d0c00dd2bf247e0 100644
--- a/config_checking/checkers/base_checker.py
+++ b/config_checking/checkers/base_checker.py
@@ -7,7 +7,7 @@ class PackInput:
     def __init__(self, config_dict=None, model=None):
         self.training_root_path = config_dict.get("training root path", None)
         self.exec_cmd_filepath = config_dict.get("exec cmd filepath", None)
-        self.config_filepath = config_dict.get("config filepath", None)
+        self.config_home_dirs = config_dict.get("config_home_dirs", None)
         self.output_zip_path = config_dict.get("output zip path", "./config_check_pack.zip")
         self.model = model
 
@@ -16,10 +16,12 @@ class BaseChecker(ABC):
     input_needed = None
     target_name_in_zip = None
 
+    @staticmethod
     @abstractmethod
     def pack(pack_input):
         pass
 
+    @staticmethod
     @abstractmethod
     def compare(bench_dir, cmp_dir, output_path):
         pass
diff --git a/config_checking/checkers/config_file_checker.py b/config_checking/checkers/config_file_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..d463000e449866e35eb39111ee10106036f209d0
--- /dev/null
+++ b/config_checking/checkers/config_file_checker.py
@@ -0,0 +1,126 @@
+import os
+import shutil
+import yaml
+
+from config_checking.config_checker import register_checker_item
+from config_checking.utils.config_compare import CONFIG_EXTENSIONS, ConfigComparator
+from config_checking.utils.packing import add_dir_to_zip
+
+from config_checking.checkers.base_checker import BaseChecker
+
+
+def _copy_and_rename_files_with_given_extension(src_folder, dest_folder, file_extension=None, prefix=''):
+    """
+    Copy every file under src_folder that has one of the given extensions into dest_folder.
+
+    The directory tree is flattened; each copy is named
+    prefix + (directory path relative to src_folder with separators replaced by '_') + '_' + file name.
+
+    TODO: refactor into a generic helper that takes a filter function for the files to copy.
+
+    Args:
+        src_folder: directory that is walked recursively.
+        dest_folder: directory that receives the renamed copies.
+        file_extension: extension or tuple of extensions, with or without a leading dot;
+            nothing is copied when it is empty.
+        prefix: string prepended to every generated file name.
+    Returns:
+        None
+    """
+    if not file_extension:
+        return
+    if isinstance(file_extension, str):
+        file_extension = (file_extension,)
+    # Match a real ".ext" suffix so that e.g. "notjson" is not picked up as a json file.
+    suffixes = tuple('.' + ext.lstrip('.') for ext in file_extension)
+    for root, _, files in os.walk(src_folder):
+        for file in files:
+            if file.endswith(suffixes):
+                relative_path = os.path.relpath(root, src_folder)
+                # New file name: <prefix><relative path>_<file name>
+                new_file_name = prefix + relative_path.replace(os.sep, '_') + '_' + file
+                src_file = os.path.join(root, file)
+                dest_file = os.path.join(dest_folder, new_file_name)
+                # Make sure the destination directory exists
+                os.makedirs(os.path.dirname(dest_file), exist_ok=True)
+                shutil.copy2(src_file, dest_file)
+                # TODO change to logger
+                print(f"Copied and renamed: {src_file} -> {dest_file}")
+
+
+@register_checker_item("config_file")
+class ConfigFileChecker(BaseChecker):
+    """Packs the config files found under config_home_dirs and compares two packed sets."""
+
+    input_needed = "config_home_dirs"
+    target_name_in_zip = "config_files"
+
+    EXTENSIONS = CONFIG_EXTENSIONS
+
+    comparator = ConfigComparator()
+
+    @staticmethod
+    def _load_config_files_dict(root_dir):
+        """Map file name -> full path for the regular files in <root_dir>/config_files."""
+        conf_dir = os.path.join(root_dir, ConfigFileChecker.target_name_in_zip)
+        files = dict()
+        if not os.path.isdir(conf_dir):
+            # Nothing to load: this package contains no config_files directory.
+            return files
+        for entry in os.listdir(conf_dir):
+            full_path = os.path.join(conf_dir, entry)
+            if os.path.isdir(full_path):
+                continue
+            files[entry] = full_path
+        return files
+
+    @staticmethod
+    def _compare_cfgs(bench_conf_dict, cmp_conf_dict):
+        """Compare the files of both packages by name; files present on one side only are reported."""
+        rs = dict()
+        for key in sorted(set(bench_conf_dict) | set(cmp_conf_dict)):
+            if key not in cmp_conf_dict:
+                rs[key] = "file only exists in bench package"
+            elif key not in bench_conf_dict:
+                rs[key] = "file only exists in cmp package"
+            else:
+                rs[key] = ConfigFileChecker.comparator.compare(bench_conf_dict[key], cmp_conf_dict[key])
+        return rs
+
+    @staticmethod
+    def _write_files(diff_results, output_dir, file_name='config_file_checker.diff.yaml'):
+        """Dump the comparison result as YAML into <output_dir>/<file_name>."""
+        out_file_path = os.path.join(output_dir, file_name)
+        with open(out_file_path, 'w', encoding='utf-8') as file:
+            yaml.dump(diff_results, file, default_flow_style=False, allow_unicode=True)
+
+    @staticmethod
+    def pack(configs):
+        """Collect the config files of every configured directory and add them to the output zip."""
+        output_zip_path = configs.output_zip_path
+        training_config_dirs = configs.config_home_dirs
+        if not training_config_dirs:
+            print("[warning] invalid config_home_dirs")
+            return
+
+        # Staging directory next to the zip file; always removed before returning.
+        dest_dir = os.path.join(os.path.dirname(output_zip_path), ConfigFileChecker.target_name_in_zip)
+        # Create it up front so packing and cleanup also work when no config file matches.
+        os.makedirs(dest_dir, exist_ok=True)
+        try:
+            for _name, _conf_dir in training_config_dirs.items():
+                _copy_and_rename_files_with_given_extension(
+                    _conf_dir, dest_dir, file_extension=ConfigFileChecker.EXTENSIONS, prefix=_name)
+            # add to zip file
+            add_dir_to_zip(output_zip_path, dest_dir)
+        finally:
+            shutil.rmtree(dest_dir, ignore_errors=True)
+
+    @staticmethod
+    def compare(bench_dir, cmp_dir, output_path):
+        """Compare the packed config files of bench_dir and cmp_dir and write the diff into output_path."""
+        bench_files_dict = ConfigFileChecker._load_config_files_dict(bench_dir)
+        cmp_files_dict = ConfigFileChecker._load_config_files_dict(cmp_dir)
+        diff_results = ConfigFileChecker._compare_cfgs(bench_files_dict, cmp_files_dict)
+
+        ConfigFileChecker._write_files(diff_results, output_path)
diff --git a/config_checking/utils/__init__.py b/config_checking/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f7e6b9eca515e85e39b54d6bca31187a9c18c92
--- /dev/null
+++ b/config_checking/utils/__init__.py
@@ -0,0 +1,4 @@
+from .random_patch import apply_patches
+
+__all__ = ['apply_patches']
+
diff --git a/config_checking/utils/config_compare.py b/config_checking/utils/config_compare.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bec365087ca6ebadefc190a462fdf13485e13c
--- /dev/null
+++ b/config_checking/utils/config_compare.py
@@ -0,0 +1,127 @@
+import json
+import os
+import yaml
+import xml.etree.ElementTree as ET
+import configparser
+from abc import ABC, abstractmethod
+
+from .compare import compare_json
+
+
+CONFIG_EXTENSIONS = ('json', 'yaml', 'xml', 'ini', 'yml')
+
+
+class Parser(ABC):
+    """Base class of the parsers that load one config file into plain Python data."""
+
+    @abstractmethod
+    def parse(self, file_path: str) -> dict:
+        pass
+
+
+class JsonParser(Parser):
+    def parse(self, file_path: str) -> dict:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return json.load(f)
+
+
+class IniParser(Parser):
+    def parse(self, file_path: str) -> dict:
+        config = configparser.ConfigParser()
+        config.read(file_path, encoding='utf-8')
+        return {section: dict(config.items(section)) for section in config.sections()}
+
+
+class YamlParser(Parser):
+    def parse(self, file_path: str) -> dict:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return yaml.safe_load(f)
+
+
+class XmlParser(Parser):
+    def parse(self, file_path: str) -> dict:
+        tree = ET.parse(file_path)
+        root = tree.getroot()
+        return self._element_to_dict(root)
+
+    def _element_to_dict(self, element):
+        # Leaf elements map their tag to their text; other elements map their tag to a dict keyed by child tag.
+        # NOTE: sibling elements that share a tag overwrite each other in that dict.
+        return {
+            element.tag: {child.tag: self._element_to_dict(child) for child in element}
+            if list(element)
+            else element.text
+        }
+
+
+class ParserFactory:
+    """Maps a file extension to the parser that handles it."""
+
+    __ParserDict = {
+        'json': JsonParser,
+        'ini': IniParser,
+        'yaml': YamlParser,
+        'yml': YamlParser,
+        'xml': XmlParser
+    }
+
+    def get_parser(self, file_type: str) -> Parser:
+        """
+        Return a parser instance for file_type.
+        Args:
+            file_type: file extension such as 'json' or '.json' (case-insensitive).
+        Returns:
+            Parser: a new parser instance.
+        Raises:
+            ValueError: if no parser is registered for file_type.
+        """
+        # os.path.splitext() yields a dotted extension, so drop the dot before the lookup.
+        parser_cls = self.__ParserDict.get(file_type.lstrip('.').lower())
+        if parser_cls is None:
+            raise ValueError(f'Invalid parser type: {file_type}')
+        # Hand out an instance: callers invoke parser.parse(path) directly.
+        return parser_cls()
+
+
+class ConfigComparator:
+    def __init__(self):
+        self.parser_factory = ParserFactory()
+
+    def __compare_files(self, file_path1: str, file_type1: str, file_path2: str, file_type2: str) -> dict:
+        """
+        Compare two config files.
+        Args:
+            file_path1: path of the first config file.
+            file_type1: extension of the first file, used to select its parser.
+            file_path2: path of the second config file.
+            file_type2: extension of the second file, used to select its parser.
+        Returns:
+            dict: a set of AUD changes of comparison of obj1 and obj2
+        """
+        parser1 = self.parser_factory.get_parser(file_type1)
+        parser2 = self.parser_factory.get_parser(file_type2)
+
+        config1 = parser1.parse(file_path1)
+        config2 = parser2.parse(file_path2)
+
+        return compare_json(config1, config2)
+
+    def compare(self, file_a, file_b):
+        """Compare two config files that share the same extension and return the compare_json result."""
+        tya = self.__file_extension(file_a)
+        tyb = self.__file_extension(file_b)
+        if tya != tyb:
+            raise ValueError(f'File extensions do not match: {tya} != {tyb}')
+        return self.__compare_files(file_a, tya, file_b, tyb)
+
+    @staticmethod
+    def __file_extension(filepath):
+        """Return the lower-cased extension of filepath without the leading dot."""
+        _, file_extension = os.path.splitext(filepath)
+        return file_extension.lstrip('.').lower()
+
+
+if __name__ == "__main__":
+    comparator = ConfigComparator()
+    differences = comparator.compare('config1.ini', 'config2.ini')
+    print(json.dumps(differences, indent=4))
diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b825f5ee9d8aef0d681abe42d734e4043c0d519
--- /dev/null
+++ b/config_checking/utils/random_patch.py
@@ -0,0 +1,50 @@
+import random
+import numpy as np
+import torch
+import traceback
+from functools import wraps
+
+
+# Attribute set on wrappers so that calling apply_patches() again does not wrap (and log) twice.
+_PATCHED_FLAG = '_config_checking_patched'
+
+
+def __log_stack(func):
+    """Wrap func so that every call first prints the function name and the call stack."""
+    if getattr(func, _PATCHED_FLAG, False):
+        return func
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        stack = traceback.format_stack()
+        # TODO: switch to logger
+        print(f"Function {func.__name__} called. Call stack:")
+        # The last entry is this wrapper itself, so it is left out.
+        for line in stack[:-1]:
+            # TODO: switch to logger
+            print(line.strip())
+        return func(*args, **kwargs)
+
+    setattr(wrapper, _PATCHED_FLAG, True)
+    return wrapper
+
+
+def apply_patches():
+    """Replace common random functions of random, numpy.random and torch with logging wrappers."""
+    # Patch random module
+    random.random = __log_stack(random.random)
+    random.randint = __log_stack(random.randint)
+    random.uniform = __log_stack(random.uniform)
+    random.choice = __log_stack(random.choice)
+
+    # Patch numpy.random module
+    np.random.rand = __log_stack(np.random.rand)
+    np.random.randint = __log_stack(np.random.randint)
+    np.random.choice = __log_stack(np.random.choice)
+    np.random.normal = __log_stack(np.random.normal)
+
+    # Patch torch random functions
+    torch.rand = __log_stack(torch.rand)
+    torch.randint = __log_stack(torch.randint)
+    torch.randn = __log_stack(torch.randn)
+    torch.manual_seed = __log_stack(torch.manual_seed)