From 9b95251a2e0ed02df77fdaad78b702af0885e1b2 Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Thu, 9 Jan 2025 09:06:30 +0000
Subject: [PATCH 1/9] add config_checking/checkers/hyperparameter_checker.py.

Signed-off-by: sunyiming
---
 .../checkers/hyperparameter_checker.py       | 92 +++++++++++++++++++
 1 file changed, 92 insertions(+)
 create mode 100644 config_checking/checkers/hyperparameter_checker.py

diff --git a/config_checking/checkers/hyperparameter_checker.py b/config_checking/checkers/hyperparameter_checker.py
new file mode 100644
index 0000000..1f35256
--- /dev/null
+++ b/config_checking/checkers/hyperparameter_checker.py
@@ -0,0 +1,92 @@
+import os
+import json
+from config_checking.checkers.base_checker import BaseChecker
+from config_checking.config_checker import register_checker_item
+from config_checking.utils.packing import add_file_to_zip
+from config_checking.utils.utils import load_json, compare_dict, write_list_to_file
+from config_checking.utils.utils import config_checking_print
+
+@register_checker_item("hyperparameter")
+class HyperparameterChecker(BaseChecker):
+    input_needed = "hyperparameter_paths"  # a dict or a list holding the paths of one or more hyperparameter files
+
+    target_name_in_zip = "hyperparameters"  # name of the directory created inside the zip file
+    result_filename = "hyperparameter_diff.txt"
+
+    def pack(pack_input):
+        hyperparameter_paths = pack_input.hyperparameter_paths
+        output_zip_path = pack_input.output_zip_path
+
+        if isinstance(hyperparameter_paths, dict):
+            for dirname, pathname in hyperparameter_paths.items():
+                if os.path.isfile(pathname):
+                    dest_path_in_zip = os.path.join(HyperparameterChecker.target_name_in_zip, dirname, os.path.basename(pathname))
+                    add_file_to_zip(output_zip_path, pathname, dest_path_in_zip)
+                    config_checking_print(f"add hyperparameter {dirname} {pathname} to zip")
+                else:
+                    config_checking_print(f"Warning: Hyperparameter path {pathname} is not a file.")
+        elif isinstance(hyperparameter_paths, list):
+            for pathname in hyperparameter_paths:
+                if os.path.isfile(pathname):
+                    dest_path_in_zip = os.path.join(HyperparameterChecker.target_name_in_zip, os.path.basename(pathname))
+                    add_file_to_zip(output_zip_path, pathname, dest_path_in_zip)
+                    config_checking_print(f"add hyperparameter {pathname} to zip")
+                else:
+                    config_checking_print(f"Warning: Hyperparameter path {pathname} is not a file.")
+        else:
+            raise TypeError("hyperparameter_paths should be a dict or a list of file paths.")
+
+    def compare(bench_dir, cmp_dir, output_path):
+        bench_hyperparameter_dir = os.path.join(bench_dir, HyperparameterChecker.target_name_in_zip)
+        cmp_hyperparameter_dir = os.path.join(cmp_dir, HyperparameterChecker.target_name_in_zip)
+        output_filepath = os.path.join(output_path, HyperparameterChecker.result_filename)
+
+        bench_hyperparameters = {}
+        cmp_hyperparameters = {}
+
+        # read hyperparameters from bench_dir
+        if os.path.exists(bench_hyperparameter_dir):
+            for root, _, files in os.walk(bench_hyperparameter_dir):
+                for file in files:
+                    if file.endswith(('.json', '.yaml', '.yml')):  # hyperparameter files are assumed to be JSON or YAML
+                        filepath = os.path.join(root, file)
+                        try:
+                            with open(filepath, 'r') as f:
+                                if filepath.endswith('.json'):
+                                    bench_hyperparameters[os.path.relpath(filepath, bench_hyperparameter_dir)] = json.load(f)
+                                # YAML support can be added here if needed
+                        except Exception as e:
+                            config_checking_print(f"Error loading hyperparameter file {filepath}: {e}")
+
+        # read hyperparameters from cmp_dir
+        if os.path.exists(cmp_hyperparameter_dir):
+            for root, _, files in os.walk(cmp_hyperparameter_dir):
+                for file in files:
+                    if file.endswith(('.json', '.yaml', '.yml')):
+                        filepath = os.path.join(root, file)
+                        try:
+                            with open(filepath, 'r') as f:
+                                if filepath.endswith('.json'):
+                                    cmp_hyperparameters[os.path.relpath(filepath, cmp_hyperparameter_dir)] = json.load(f)
+                                # YAML support can be added here if needed
+                        except Exception as e:
+                            config_checking_print(f"Error loading hyperparameter file {filepath}: {e}")
+
+        # compare hyperparameters
+        all_diffs = []
+        all_files = set(bench_hyperparameters.keys()) | set(cmp_hyperparameters.keys())
+
+        for filename in all_files:
+            bench_data = bench_hyperparameters.get(filename, None)
+            cmp_data = cmp_hyperparameters.get(filename, None)
+
+            if bench_data is not None and cmp_data is not None:
+                diff = compare_dict(bench_data, cmp_data, prefix=f"File: {filename} -> ")
+                all_diffs.extend(diff)
+            elif bench_data is not None:
+                all_diffs.append(f"[Only in benchmark] File: {filename}")
+            elif cmp_data is not None:
+                all_diffs.append(f"[Only in compare] File: {filename}")
+
+        write_list_to_file(all_diffs, output_filepath)
+        config_checking_print(f"Hyperparameter comparison result written to {output_filepath}")
\ No newline at end of file
-- 
Gitee
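Note on the loader above: the filename filter accepts .json, .yaml and .yml, but only the JSON branch is implemented. A minimal sketch of the YAML branch the comment hints at, assuming PyYAML is available (it is not currently imported by the checker):

    import yaml  # assumed dependency, not part of the checker today

    with open(filepath, 'r') as f:
        if filepath.endswith('.json'):
            params = json.load(f)
        elif filepath.endswith(('.yaml', '.yml')):
            params = yaml.safe_load(f)  # safe_load avoids executing arbitrary YAML tags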
From b75eebfbffaa4e04eb674c7d25cd13dde92fabc9 Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Thu, 9 Jan 2025 09:13:33 +0000
Subject: [PATCH 2/9] add config_checking/checkers/random_instruction_checker.py.

Signed-off-by: sunyiming
---
 .../checkers/random_instruction_checker.py   | 122 ++++++++++++++++++
 1 file changed, 122 insertions(+)
 create mode 100644 config_checking/checkers/random_instruction_checker.py

diff --git a/config_checking/checkers/random_instruction_checker.py b/config_checking/checkers/random_instruction_checker.py
new file mode 100644
index 0000000..806df7d
--- /dev/null
+++ b/config_checking/checkers/random_instruction_checker.py
@@ -0,0 +1,122 @@
+import os
+import json
+import torch
+import inspect
+
+from config_checking.checkers.base_checker import BaseChecker
+from config_checking.config_checker import register_checker_item, register_pre_forward_fun_list
+from config_checking.utils.packing import create_file_in_zip
+from config_checking.utils.utils import write_list_to_file, config_checking_print, get_rank, write_content_to_file
+from config_checking.utils.utils import read_rank_result_to_dict, compare_lists_of_dicts
+
+# history of recorded random operations, one list per rank
+random_operations_history = {}
+
+def get_random_op_info(op_name, *args, **kwargs):
+    """
+    Extract information about a random operation, e.g. shapes and dtypes.
+    """
+    info = {"op_name": op_name}
+    # try to extract argument information; extend as needed
+    if args:
+        info["args"] = [str(arg) if not isinstance(arg, torch.Tensor) else f"Tensor[shape={list(arg.shape)}, dtype={arg.dtype}]" for arg in args]
+    if kwargs:
+        info["kwargs"] = {k: str(v) if not isinstance(v, torch.Tensor) else f"Tensor[shape={list(v.shape)}, dtype={v.dtype}]" for k, v in kwargs.items()}
+    return info
+
+def capture_random_state():
+    """
+    Capture the current random number generator state.
+    """
+    return {
+        "torch.random.get_rng_state": torch.random.get_rng_state().tolist(),
+        "torch.cuda.get_rng_state_all": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
+    }
+
+# random functions to hook; extend as needed
+RANDOM_FUNCTIONS_TO_HOOK = [
+    (torch, 'rand'),
+    (torch, 'randn'),
+    (torch, 'randint'),
+    (torch, 'randperm'),
+    (torch.nn.functional, 'dropout')  # dropout is also a source of randomness
+    # more functions with randomness can be added here
+]
+
+original_functions = {}
+
+def hook_random_functions():
+    """Hook random functions to record their call information."""
+    global original_functions
+    rank = get_rank()
+    if rank not in random_operations_history:
+        random_operations_history[rank] = []
+
+    for module, func_name in RANDOM_FUNCTIONS_TO_HOOK:
+        original_func = getattr(module, func_name)
+        original_functions[(module, func_name)] = original_func
+
+        # bind the loop variables as keyword-only defaults so each hook keeps its own
+        # function instead of closing over the last loop iteration (late-binding bug)
+        def hooked_func(*args, _original_func=original_func, _func_name=func_name, **kwargs):
+            info = get_random_op_info(_func_name, *args, **kwargs)
+            # inspect the call stack to help locate where the random call came from
+            stack_info = inspect.stack()[1][3]  # name of the calling function
+            info['called_by'] = stack_info
+            random_operations_history[rank].append(info)
+            return _original_func(*args, **kwargs)
+        setattr(module, func_name, hooked_func)
+    config_checking_print(f"Rank {rank}: Hooked random functions.")
+
+def unhook_random_functions():
+    """Restore the original random functions."""
+    global original_functions
+    for (module, func_name), original_func in original_functions.items():
+        setattr(module, func_name, original_func)
+    config_checking_print(f"Rank {get_rank()}: Unhooked random functions.")
+
+@register_checker_item("random_instruction")
+class RandomInstructionChecker(BaseChecker):
+    input_needed = "model"
+    multi_rank = True
+
+    target_name_in_zip = "random_instructions"
+    result_filename = "random_instruction_check_result.txt"
+
+    def pack(pack_input):
+        output_zip_path = pack_input.output_zip_path
+
+        def pre_forward_hook(module, input):
+            hook_random_functions()
+            return input
+
+        def post_forward_hook(module, input, output):
+            unhook_random_functions()
+            # save the captured random-operation history into the zip file
+            rank = get_rank()
+            filepath = os.path.join(RandomInstructionChecker.target_name_in_zip, f"rank{rank}.json")
+            create_file_in_zip(output_zip_path, filepath, json.dumps(random_operations_history.get(rank, []), indent=4))
+            config_checking_print(f"Rank {rank}: Added random instructions info to zip")
+
+        register_pre_forward_fun_list(pre_forward_hook, call_every_forward=True)
+        register_pre_forward_fun_list(post_forward_hook, call_every_forward=True)
+
+    def compare(bench_dir, cmp_dir, output_path):
+        bench_random_inst_path = os.path.join(bench_dir, RandomInstructionChecker.target_name_in_zip)
+        cmp_random_inst_path = os.path.join(cmp_dir, RandomInstructionChecker.target_name_in_zip)
+
+        bench_random_instructions = read_rank_result_to_dict(bench_random_inst_path)
+        cmp_random_instructions = read_rank_result_to_dict(cmp_random_inst_path)
+
+        comparison_results = {}
+        for rank in sorted(bench_random_instructions.keys() | cmp_random_instructions.keys()):
+            bench_ops = bench_random_instructions.get(rank, [])
+            cmp_ops = cmp_random_instructions.get(rank, [])
+            deleted, added, changed = compare_lists_of_dicts(bench_ops, cmp_ops)
+            comparison_results[rank] = {
+                "deleted": deleted,
+                "added": added,
+                "changed": changed
+            }
+
+        output_filepath = os.path.join(output_path, RandomInstructionChecker.result_filename)
+        write_content_to_file(json.dumps(comparison_results, indent=4), output_filepath)
+        config_checking_print(f"Random instruction comparison result written to {output_filepath}")
\ No newline at end of file
-- 
Gitee
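For reference, each rank{N}.json written by the post-forward hook is a list of records produced by get_random_op_info. An illustrative (hypothetical) entry for a torch.rand(4, 8) call made inside a module's forward:

    [
        {
            "op_name": "rand",
            "args": ["4", "8"],
            "called_by": "forward"
        }
    ]

Tensor arguments are summarized as Tensor[shape=..., dtype=...] rather than stringified in full, which keeps the per-rank JSON small enough to diff.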
From ec69b5b9061b01d25e0c3b8f175c27cd94624e76 Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Thu, 9 Jan 2025 09:13:33 +0000
Subject: [PATCH 3/9] add config_checking/checkers/random_instruction_checker.py.

Signed-off-by: sunyiming
---
 config_checking/checkers/base_checker.py     |   1 +
 .../checkers/random_instruction_checker.py   | 122 ++++++++++++++++++
 2 files changed, 123 insertions(+)
 create mode 100644 config_checking/checkers/random_instruction_checker.py

diff --git a/config_checking/checkers/base_checker.py b/config_checking/checkers/base_checker.py
index e986b26..be292b0 100644
--- a/config_checking/checkers/base_checker.py
+++ b/config_checking/checkers/base_checker.py
@@ -9,6 +9,7 @@ class PackInput:
         self.ckpt_path = config_dict.get("ckpt path", None)
         self.need_env_args = config_dict.get("env args", None)
         self.need_pip_data = config_dict.get("pip data", None)
+        self.hyperparameters_path = config_dict.get("hyperparameters path", None)
         self.output_zip_path = config_dict.get("output zip path", "./config_check_pack.zip")
         self.model = model
 
diff --git a/config_checking/checkers/random_instruction_checker.py b/config_checking/checkers/random_instruction_checker.py
new file mode 100644
index 0000000..806df7d
--- /dev/null
+++ b/config_checking/checkers/random_instruction_checker.py
@@ -0,0 +1,122 @@
+import os
+import json
+import torch
+import inspect
+
+from config_checking.checkers.base_checker import BaseChecker
+from config_checking.config_checker import register_checker_item, register_pre_forward_fun_list
+from config_checking.utils.packing import create_file_in_zip
+from config_checking.utils.utils import write_list_to_file, config_checking_print, get_rank, write_content_to_file
+from config_checking.utils.utils import read_rank_result_to_dict, compare_lists_of_dicts
+
+# history of recorded random operations, one list per rank
+random_operations_history = {}
+
+def get_random_op_info(op_name, *args, **kwargs):
+    """
+    Extract information about a random operation, e.g. shapes and dtypes.
+    """
+    info = {"op_name": op_name}
+    # try to extract argument information; extend as needed
+    if args:
+        info["args"] = [str(arg) if not isinstance(arg, torch.Tensor) else f"Tensor[shape={list(arg.shape)}, dtype={arg.dtype}]" for arg in args]
+    if kwargs:
+        info["kwargs"] = {k: str(v) if not isinstance(v, torch.Tensor) else f"Tensor[shape={list(v.shape)}, dtype={v.dtype}]" for k, v in kwargs.items()}
+    return info
+
+def capture_random_state():
+    """
+    Capture the current random number generator state.
+    """
+    return {
+        "torch.random.get_rng_state": torch.random.get_rng_state().tolist(),
+        "torch.cuda.get_rng_state_all": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
+    }
+
+# random functions to hook; extend as needed
+RANDOM_FUNCTIONS_TO_HOOK = [
+    (torch, 'rand'),
+    (torch, 'randn'),
+    (torch, 'randint'),
+    (torch, 'randperm'),
+    (torch.nn.functional, 'dropout')  # dropout is also a source of randomness
+    # more functions with randomness can be added here
+]
+
+original_functions = {}
+
+def hook_random_functions():
+    """Hook random functions to record their call information."""
+    global original_functions
+    rank = get_rank()
+    if rank not in random_operations_history:
+        random_operations_history[rank] = []
+
+    for module, func_name in RANDOM_FUNCTIONS_TO_HOOK:
+        original_func = getattr(module, func_name)
+        original_functions[(module, func_name)] = original_func
+
+        # bind the loop variables as keyword-only defaults so each hook keeps its own
+        # function instead of closing over the last loop iteration (late-binding bug)
+        def hooked_func(*args, _original_func=original_func, _func_name=func_name, **kwargs):
+            info = get_random_op_info(_func_name, *args, **kwargs)
+            # inspect the call stack to help locate where the random call came from
+            stack_info = inspect.stack()[1][3]  # name of the calling function
+            info['called_by'] = stack_info
+            random_operations_history[rank].append(info)
+            return _original_func(*args, **kwargs)
+        setattr(module, func_name, hooked_func)
+    config_checking_print(f"Rank {rank}: Hooked random functions.")
+
+def unhook_random_functions():
+    """Restore the original random functions."""
+    global original_functions
+    for (module, func_name), original_func in original_functions.items():
+        setattr(module, func_name, original_func)
+    config_checking_print(f"Rank {get_rank()}: Unhooked random functions.")
+
+@register_checker_item("random_instruction")
+class RandomInstructionChecker(BaseChecker):
+    input_needed = "model"
+    multi_rank = True
+
+    target_name_in_zip = "random_instructions"
+    result_filename = "random_instruction_check_result.txt"
+
+    def pack(pack_input):
+        output_zip_path = pack_input.output_zip_path
+
+        def pre_forward_hook(module, input):
+            hook_random_functions()
+            return input
+
+        def post_forward_hook(module, input, output):
+            unhook_random_functions()
+            # save the captured random-operation history into the zip file
+            rank = get_rank()
+            filepath = os.path.join(RandomInstructionChecker.target_name_in_zip, f"rank{rank}.json")
+            create_file_in_zip(output_zip_path, filepath, json.dumps(random_operations_history.get(rank, []), indent=4))
+            config_checking_print(f"Rank {rank}: Added random instructions info to zip")
+
+        register_pre_forward_fun_list(pre_forward_hook, call_every_forward=True)
+        register_pre_forward_fun_list(post_forward_hook, call_every_forward=True)
+
+    def compare(bench_dir, cmp_dir, output_path):
+        bench_random_inst_path = os.path.join(bench_dir, RandomInstructionChecker.target_name_in_zip)
+        cmp_random_inst_path = os.path.join(cmp_dir, RandomInstructionChecker.target_name_in_zip)
+
+        bench_random_instructions = read_rank_result_to_dict(bench_random_inst_path)
+        cmp_random_instructions = read_rank_result_to_dict(cmp_random_inst_path)
+
+        comparison_results = {}
+        for rank in sorted(bench_random_instructions.keys() | cmp_random_instructions.keys()):
+            bench_ops = bench_random_instructions.get(rank, [])
+            cmp_ops = cmp_random_instructions.get(rank, [])
+            deleted, added, changed = compare_lists_of_dicts(bench_ops, cmp_ops)
+            comparison_results[rank] = {
+                "deleted": deleted,
+                "added": added,
+                "changed": changed
+            }
+
+        output_filepath = os.path.join(output_path, RandomInstructionChecker.result_filename)
+        write_content_to_file(json.dumps(comparison_results, indent=4), output_filepath)
+        config_checking_print(f"Random instruction comparison result written to {output_filepath}")
\ No newline at end of file
-- 
Gitee
From f432c49ac5abd4589c389d422e0fcc711645c8f1 Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Tue, 18 Feb 2025 11:29:52 +0000
Subject: [PATCH 4/9] update config_checking/checkers/hyperparameter_checker.py.

Signed-off-by: sunyiming
---
 .../checkers/hyperparameter_checker.py       | 161 +++++++-----------
 1 file changed, 59 insertions(+), 102 deletions(-)

diff --git a/config_checking/checkers/hyperparameter_checker.py b/config_checking/checkers/hyperparameter_checker.py
index 2141119..297d9df 100644
--- a/config_checking/checkers/hyperparameter_checker.py
+++ b/config_checking/checkers/hyperparameter_checker.py
@@ -9,6 +9,8 @@ from config_checking.utils.utils import config_checking_print
 from typing import Union, List, Dict, Any
 from difflib import SequenceMatcher
 import tempfile
+import re
+import shlex
 
 @register_checker_item("hyperparameter")
 class HyperparameterChecker(BaseChecker):
@@ -24,119 +26,79 @@ class HyperparameterChecker(BaseChecker):
         "dropout_rate": ["dropout", "drop_rate"],
     }
 
+    @staticmethod
     def pack(pack_input):
         model_paths = pack_input.model_paths
         output_zip_path = pack_input.output_zip_path
 
-        if isinstance(model_paths, dict):
-            for dirname, pathname in model_paths.items():
-                if os.path.isfile(pathname):
-                    hyperparameters = HyperparameterChecker._extract_hyperparameters_from_model(pathname)
-                    if hyperparameters:
-                        dest_path_in_zip = os.path.join(HyperparameterChecker.target_name_in_zip, dirname, os.path.splitext(os.path.basename(pathname))[0] + ".json")  # Save as json
-                        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp_file:
-                            json.dump(hyperparameters, tmp_file, indent=4)
-                            tmp_file_path = tmp_file.name
-                        add_file_to_zip(output_zip_path, tmp_file_path, dest_path_in_zip)
-                        os.remove(tmp_file_path)  # Clean up temp file
-                        config_checking_print(f"add hyperparameters from model file {dirname} {pathname} to zip as {dest_path_in_zip}")
-                    else:
-                        config_checking_print(f"Warning: Failed to extract hyperparameters from model file {pathname}")
-                else:
-                    config_checking_print(f"Warning: Model path {pathname} is not a file: {pathname}")
-        elif isinstance(model_paths, list):
-            for pathname in model_paths:
-                if os.path.isfile(pathname):
-                    hyperparameters = HyperparameterChecker._extract_hyperparameters_from_model(pathname)
-                    if hyperparameters:
-                        dest_path_in_zip = os.path.join(HyperparameterChecker.target_name_in_zip, os.path.splitext(os.path.basename(pathname))[0] + ".json")  # Save as json
-                        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp_file:
-                            json.dump(hyperparameters, tmp_file, indent=4)
-                            tmp_file_path = tmp_file.name
-                        add_file_to_zip(output_zip_path, tmp_file_path, dest_path_in_zip)
-                        os.remove(tmp_file_path)  # Clean up temp file
-                        config_checking_print(f"add hyperparameters from model file {pathname} to zip as {dest_path_in_zip}")
-                    else:
-                        config_checking_print(f"Warning: Failed to extract hyperparameters from model file {pathname}")
+        if not isinstance(model_paths, list):
+            raise TypeError("model_paths should be a list of file paths.")
+
+        for script_path in model_paths:
+            if os.path.isfile(script_path):
+                hyperparameters = HyperparameterChecker._extract_hyperparameters_from_script(script_path)
+                if hyperparameters:
+                    dest_path_in_zip = os.path.join(HyperparameterChecker.target_name_in_zip, os.path.splitext(os.path.basename(script_path))[0] + ".json")
+                    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp_file:
+                        json.dump(hyperparameters, tmp_file, indent=4)
+                        tmp_file_path = tmp_file.name
+                    add_file_to_zip(output_zip_path, tmp_file_path, dest_path_in_zip)
+                    os.remove(tmp_file_path)
+                    config_checking_print(f"Added hyperparameters from script {script_path} to zip as {dest_path_in_zip}")
                 else:
-                    config_checking_print(f"Warning: Model path {pathname} is not a file: {pathname}")
-        else:
-            raise TypeError("model_paths should be a dict or a list of file paths.")
+                    config_checking_print(f"Warning: Failed to extract hyperparameters from script {script_path}")
+            else:
+                config_checking_print(f"Warning: Script path {script_path} is not a file.")
 
     @staticmethod
-    def _extract_hyperparameters_from_model(model_path: str) -> Union[Dict[str, Any], None]:
+    def _extract_hyperparameters_from_script(script_path: str) -> Dict[str, Any]:
         """
-        Loads a Python module from model_path and tries to extract hyperparameters by calling parse_args().
-        Assumes the model file contains a class that has a parse_args() method.
+        Extracts arguments from the bash script used to launch a model training run.
         """
-        try:
-            module_name = os.path.basename(model_path).replace('.py', '')
-            spec = importlib.util.spec_from_file_location(module_name, model_path)
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
-
-            model_class = getattr(module, "Model", None)  # Try to get a class named 'Model'
-            if model_class is None:
-                # Try to find the first class defined in the module as a fallback
-                classes = [obj for name, obj in module.__dict__.items() if isinstance(obj, type) and obj.__module__ == module.__name__]
-                if classes:
-                    model_class = classes[0]  # Take the first class found
-                else:
-                    config_checking_print(f"Warning: Could not find a Model class in {model_path}. Trying to call parse_args directly on the module.")
-                    if hasattr(module, "parse_args"):
-                        args = module.parse_args()  # Directly call parse_args on the module
-                        if hasattr(args, "__dict__"):
-                            return vars(args)  # Convert Namespace to dict if needed
-                        elif isinstance(args, dict):
-                            return args
-                        else:
-                            config_checking_print(f"Warning: parse_args in {model_path} did not return a dict-like object. Returning empty dict.")
-                            return {}
-                    else:
-                        config_checking_print(f"Warning: No Model class or parse_args function found in {model_path}. Cannot extract hyperparameters.")
-                        return None
-
-            if hasattr(model_class, "parse_args"):
-                model_instance = model_class()  # Instantiate the model class
-                args = model_instance.parse_args()  # Call parse_args on an instance
-                if hasattr(args, "__dict__"):
-                    return vars(args)  # Convert Namespace to dict
-                elif isinstance(args, dict):
-                    return args
-                else:
-                    config_checking_print(f"Warning: parse_args in {model_path} did not return a dict-like object. Returning empty dict.")
-                    return {}
-            else:
-                config_checking_print(f"Warning: Model class in {model_path} does not have a parse_args() method. Cannot extract hyperparameters.")
-                return None
-
-        except Exception as e:
-            config_checking_print(f"Error extracting hyperparameters from {model_path}: {e}")
-            return None
+        hyperparameters = {}
+        with open(script_path, 'r') as file:
+            script_content = file.read()
+
+        command_line = re.search(r'torchrun\s+(.*?)\s*\|', script_content, re.DOTALL)
+        # bail out early if no torchrun command is found, instead of crashing on None below
+        if not command_line:
+            return hyperparameters
+        command_line = command_line.group(1)
+
+        blocks = re.findall(r'(\w+_ARGS)="(.*?)"', script_content, re.DOTALL)
+        block_contents = {}
+        for block_name, block_content in blocks:
+            block_content = block_content.replace('\n', ' ')
+            block_contents[block_name] = block_content
+            command_line = command_line.replace(f"${block_name}", block_content)
+
+        matches = re.findall(r'--([\w-]+)(?:\s+([^\s]+))?', command_line)
+        for match in matches:
+            key, value = match
+            if value and value.startswith('$'):
+                env_var = re.search(rf'{value[1:]}="?(.*?)"?\s', script_content)
+                if env_var:
+                    value = env_var.group(1)
+            hyperparameters[key] = value if value else True
+
+        return hyperparameters
 
     @staticmethod
     def _fuzzy_match_parameter(param_name: str, available_params: Dict[str, Any]) -> Union[str, None]:
         """
        Fuzzy matches a parameter name against available parameter names using predefined mappings and string similarity.
         """
-        # 1. Check for exact match first
         if param_name in available_params:
             return param_name
 
-        # 2. Check predefined mappings
         if param_name in HyperparameterChecker.PARAMETER_NAME_MAPPING:
             alternatives = HyperparameterChecker.PARAMETER_NAME_MAPPING[param_name]
             for alt_name in alternatives:
                 if alt_name in available_params:
                     return alt_name
-
-        # 3. Fuzzy matching using string similarity (SequenceMatcher) - Optional, but can be helpful
         best_match_name = None
-        best_match_ratio = 0.8  # Threshold for considering it a match, adjust as needed
+        best_match_ratio = 0.8
         for available_param_name in available_params:
             ratio = SequenceMatcher(None, param_name.lower(), available_param_name.lower()).ratio()
-            if ratio > best_match_ratio and ratio > best_match_ratio:  # Higher than threshold and better than current best
+            if ratio > best_match_ratio:
                 best_match_ratio = ratio
                 best_match_name = available_param_name
 
@@ -144,7 +106,7 @@ class HyperparameterChecker(BaseChecker):
             config_checking_print(f"Fuzzy matched parameter '{param_name}' to '{best_match_name}' (similarity: {best_match_ratio:.2f})")
             return best_match_name
 
-        return None  # No match found
+        return None
 
     def compare(bench_dir, cmp_dir, output_path):
         bench_model_dir = os.path.join(bench_dir, HyperparameterChecker.target_name_in_zip)  # Still using the same target_name_in_zip, but now for model files
         cmp_model_dir = os.path.join(cmp_dir, HyperparameterChecker.target_name_in_zip)
         output_filepath = os.path.join(output_path, HyperparameterChecker.result_filename)
 
         bench_hyperparameters = {}
         cmp_hyperparameters = {}
 
-        # From bench_dir, extract hyperparameters by loading json files
         if os.path.exists(bench_model_dir):
             for root, _, files in os.walk(bench_model_dir):
                 for file in files:
-                    if file.endswith('.json'): # Expecting json files now
+                    if file.endswith('.json'):
                         filepath = os.path.join(root, file)
-                        relative_filepath = os.path.relpath(filepath, bench_model_dir) # Use relative path as key
-                        params = load_json(filepath) # Load hyperparameters from json
+                        relative_filepath = os.path.relpath(filepath, bench_model_dir)
+                        params = load_json(filepath)
                         if params:
-                            bench_hyperparameters[relative_filepath] = params # Store parameters with relative file path as key
+                            bench_hyperparameters[relative_filepath] = params
 
-        # From cmp_dir, extract hyperparameters by loading json files
         if os.path.exists(cmp_model_dir):
             for root, _, files in os.walk(cmp_model_dir):
                 for file in files:
-                    if file.endswith('.json'): # Expecting json files now
+                    if file.endswith('.json'):
                         filepath = os.path.join(root, file)
-                        relative_filepath = os.path.relpath(filepath, cmp_model_dir) # Use relative path as key
-                        params = load_json(filepath) # Load hyperparameters from json
+                        relative_filepath = os.path.relpath(filepath, cmp_model_dir)
+                        params = load_json(filepath)
                         if params:
-                            cmp_hyperparameters[relative_filepath] = params # Store parameters with relative file path as key
+                            cmp_hyperparameters[relative_filepath] = params
 
-        # Compare hyperparameters
         all_diffs = []
         all_files = set(bench_hyperparameters.keys()) | set(cmp_hyperparameters.keys())
 
@@ -185,7 +144,6 @@ class HyperparameterChecker(BaseChecker):
             cmp_params = cmp_hyperparameters.get(filename, None)
 
             if bench_params is not None and cmp_params is not None:
-                # Fuzzy parameter comparison within each file
                 file_diffs = []
                 bench_param_names = set(bench_params.keys())
                 cmp_param_names = set(cmp_params.keys())
@@ -199,12 +157,11 @@ class HyperparameterChecker(BaseChecker):
                                                 {matched_cmp_param_name: cmp_param_value})
                         if diff:
                             file_diffs.extend([f"  Parameter '{bench_param_name}' (matched with '{matched_cmp_param_name}'): {d}" for d in diff])
-                        del cmp_params[matched_cmp_param_name] # Remove matched parameter to find unmatched cmp parameters
+                        del cmp_params[matched_cmp_param_name]
                     else:
                         file_diffs.append(f"  [Only in benchmark] Parameter: '{bench_param_name}': {bench_params[bench_param_name]}")
 
-                # Remaining parameters in cmp_params are only in compare
                 for cmp_param_name, cmp_param_value in cmp_params.items():
                     file_diffs.append(f"  [Only in compare] Parameter: '{cmp_param_name}': {cmp_param_value}")
-- 
Gitee
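To make the regex-based extraction concrete, consider a hypothetical Megatron-style launch script (the names below are illustrative, not from the repo):

    GPT_ARGS="--num-layers 24
        --hidden-size 1024
        --lr 1e-4"

    torchrun $DISTRIBUTED_ARGS pretrain_gpt.py $GPT_ARGS --fp16 | tee train.log

The torchrun...| search captures the command line (note the trailing pipe is required for a match), the *_ARGS blocks are inlined in its place, and the --key value scan yields roughly {"num-layers": "24", "hidden-size": "1024", "lr": "1e-4", "fp16": True}. Flags without a value are stored as True, and values written as $VAR are resolved by a second search for a VAR=... assignment elsewhere in the script.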
From 90e99df43fc6f20e83ff535e183eb011dc8c7d4f Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Tue, 18 Feb 2025 11:55:56 +0000
Subject: [PATCH 5/9] update config_checking/checkers/hyperparameter_checker.py.

Signed-off-by: sunyiming
---
 config_checking/checkers/hyperparameter_checker.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/config_checking/checkers/hyperparameter_checker.py b/config_checking/checkers/hyperparameter_checker.py
index 297d9df..3a8594c 100644
--- a/config_checking/checkers/hyperparameter_checker.py
+++ b/config_checking/checkers/hyperparameter_checker.py
@@ -70,7 +70,7 @@ class HyperparameterChecker(BaseChecker):
             block_contents[block_name] = block_content
             command_line = command_line.replace(f"${block_name}", block_content)
 
-        matches = re.findall(r'--([\w-]+)(?:\s+([^\s]+))?', command_line)
+        matches = re.findall(r'--([\w-]+)(?:\s+([^\s\\]+))?', command_line)
         for match in matches:
             key, value = match
             if value and value.startswith('$'):
-- 
Gitee
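This one-character regex change excludes backslashes from captured values, so shell line continuations are no longer mistaken for argument values. A quick demonstration on a hypothetical command fragment:

    import re

    cmd = "--fp16 \\ --lr 1e-4"  # the lone backslash is a leftover line continuation
    re.findall(r'--([\w-]+)(?:\s+([^\s]+))?', cmd)    # [('fp16', '\\'), ('lr', '1e-4')] - backslash captured as the value
    re.findall(r'--([\w-]+)(?:\s+([^\s\\]+))?', cmd)  # [('fp16', ''), ('lr', '1e-4')]  - fp16 becomes a bare flag (True)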
From 4d3838b99fcdbd6da22a4fa73eaf5fa1db67933f Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Tue, 18 Feb 2025 11:56:15 +0000
Subject: [PATCH 6/9] update config_checking/checkers/__init__.py.

Signed-off-by: sunyiming
---
 config_checking/checkers/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/config_checking/checkers/__init__.py b/config_checking/checkers/__init__.py
index 45a60e9..76787d1 100644
--- a/config_checking/checkers/__init__.py
+++ b/config_checking/checkers/__init__.py
@@ -4,7 +4,7 @@ import config_checking.checkers.pip_checker
 import config_checking.checkers.checkpoint_checker
 import config_checking.checkers.dataset_checker
 import config_checking.checkers.weights_checker
-import config_checking.checkers.hyperparameters_checker
+import config_checking.checkers.hyperparameter_checker
 
 from config_checking.checkers.base_checker import BaseChecker
-- 
Gitee
From ef99d7c68ec6af9bea27607016f668255c75d5fe Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Fri, 21 Feb 2025 03:33:30 +0000
Subject: [PATCH 7/9] update config_checking/utils/random_patch.py.

Signed-off-by: sunyiming
---
 config_checking/utils/random_patch.py | 189 +++++++++++++++++---------
 1 file changed, 121 insertions(+), 68 deletions(-)

diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py
index a6b3e4d..70be026 100644
--- a/config_checking/utils/random_patch.py
+++ b/config_checking/utils/random_patch.py
@@ -2,78 +2,131 @@ import logging
 import random
 import traceback
 from functools import wraps
+from typing import Callable, Any
+import inspect
 
 import numpy as np
 import torch
 
 from config_checking.utils.utils import config_checking_print
 
-
 DEFAULT_RANDOM_LOG_PATH = './random_patch.log'
 
-# TODO logging currently goes to random_patch.log by default; open items:
-# 1. handle concurrent writes to the log file
-# 2. make the log file path configurable
-# 3. the stacked decorators could be refactored into a chain of responsibility
-
-
-def __log_stack(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        stack = traceback.format_stack()
-        msg = f"info: random function {func.__name__} called. Call stack:"
-        for line in stack[:-1]:
-            msg += '\n' + line.strip()
-        logging.info(msg)
-        return func(*args, **kwargs)
-
-    return wrapper
-
-
-def __check_torch_with_device(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if 'device' in kwargs:
-            # extract caller info from the stack to determine file and line number
-            stack = traceback.extract_stack()
-            caller = stack[-2]
-            file_name = caller.filename
-            line_number = caller.lineno
-            logging.warning(f"Warning: torch function {func.__name__} called with device specified in {file_name} "
-                            f"at line {line_number}.")
-        return func(*args, **kwargs)
-    return wrapper
-
-
-def __track_func(func):
-    return __log_stack(__check_torch_with_device(func))
-
-
-def apply_patches():
-    # init logging
-    logging.basicConfig(filename=DEFAULT_RANDOM_LOG_PATH, level=logging.INFO)
-
-    # Patch random module
-    random.random = __track_func(random.random)
-    random.randint = __track_func(random.randint)
-    random.uniform = __track_func(random.uniform)
-    random.choice = __track_func(random.choice)
-
-    # Patch numpy.random module
-    np.random.rand = __track_func(np.random.rand)
-    np.random.randint = __track_func(np.random.randint)
-    np.random.choice = __track_func(np.random.choice)
-    np.random.normal = __track_func(np.random.normal)
-
-    # Patch torch random functions
-    torch.rand = __track_func(torch.rand)
-    torch.randint = __track_func(torch.randint)
-    torch.randn = __track_func(torch.randn)
-    torch.rand_like = __track_func(torch.rand_like)
-    torch.randint_like = __track_func(torch.randint_like)
-    torch.randn_like = __track_func(torch.randn_like)
-    torch.manual_seed = __track_func(torch.manual_seed)
-
-    # Patch torch.Tensor random function
-    torch.Tensor.exponential_ = __track_func(torch.Tensor.exponential_)
-
-    config_checking_print(f"random patches saved to file: {DEFAULT_RANDOM_LOG_PATH}")
+class RandomCallTracker:
+    """Centralized class to handle random function tracking and logging"""
+
+    def __init__(self, log_path: str = DEFAULT_RANDOM_LOG_PATH):
+        self.log_path = log_path
+        self.logger = self._setup_logger()
+
+    def _setup_logger(self) -> logging.Logger:
+        """Configure logging with a more detailed format"""
+        logger = logging.getLogger('RandomTracker')
+        logger.setLevel(logging.INFO)
+
+        # Avoid duplicate handlers if already configured
+        if not logger.handlers:
+            # Create directory if it doesn't exist
+            import os
+            os.makedirs(os.path.dirname(self.log_path) or '.', exist_ok=True)
+
+            handler = logging.FileHandler(self.log_path)
+            formatter = logging.Formatter(
+                '%(asctime)s - %(levelname)s - %(message)s',
+                datefmt='%Y-%m-%d %H:%M:%S'
+            )
+            handler.setFormatter(formatter)
+            logger.addHandler(handler)
+        return logger
+
+    def track_random_call(self, func: Callable) -> Callable:
+        """Decorator to track random function calls with detailed stack info"""
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            frame = inspect.currentframe()
+            caller_frame = frame.f_back
+            caller_info = inspect.getframeinfo(caller_frame)
+
+            # Enhanced device handling for torch functions
+            device_info = ""
+            if 'device' in kwargs:
+                device_info = f" [device: {kwargs['device']}]"
+            elif any(isinstance(arg, torch.Tensor) for arg in args):
+                tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
+                if tensors:
+                    device_info = f" [tensor device: {tensors[0].device}]"
+
+            # Format detailed message
+            msg = (
+                f"Random function '{func.__qualname__}' called "
+                f"from {caller_info.filename}:{caller_info.lineno}"
+                f"{device_info}\n"
+                f"Function: {caller_info.function}\n"
+                f"Code context: {caller_info.code_context[0].strip() if caller_info.code_context else 'N/A'}\n"
+                f"Arguments: args={args}, kwargs={kwargs}"
+            )
+
+            self.logger.info(msg)
+
+            try:
+                result = func(*args, **kwargs)
+                self.logger.debug(f"Result: {result}")
+                return result
+            except Exception as e:
+                self.logger.error(f"Error in {func.__qualname__}: {str(e)}\n{traceback.format_exc()}")
+                raise
+            finally:
+                del frame, caller_frame
+
+        return wrapper
+
+def apply_patches(log_path: str = DEFAULT_RANDOM_LOG_PATH) -> RandomCallTracker:
+    """
+    Apply tracking patches to random number generation functions
+
+    Args:
+        log_path (str): Path where the log file will be saved
+
+    Returns:
+        RandomCallTracker: The tracker instance for potential further manipulation
+    """
+    tracker = RandomCallTracker(log_path)
+    track = tracker.track_random_call
+
+    # Random module patches
+    random_patches = {
+        'random': random.random,
+        'randint': random.randint,
+        'uniform': random.uniform,
+        'choice': random.choice
+    }
+    for name, func in random_patches.items():
+        setattr(random, name, track(func))
+
+    # NumPy random patches
+    np_random_patches = {
+        'rand': np.random.rand,
+        'randint': np.random.randint,
+        'choice': np.random.choice,
+        'normal': np.random.normal
+    }
+    for name, func in np_random_patches.items():
+        setattr(np.random, name, track(func))
+
+    # Torch random patches
+    torch_patches = {
+        'rand': torch.rand,
+        'randint': torch.randint,
+        'randn': torch.randn,
+        'rand_like': torch.rand_like,
+        'randint_like': torch.randint_like,
+        'randn_like': torch.randn_like,
+        'manual_seed': torch.manual_seed
+    }
+    for name, func in torch_patches.items():
+        setattr(torch, name, track(func))
+
+    # Torch Tensor method patch
+    torch.Tensor.exponential_ = track(torch.Tensor.exponential_)
+
+    config_checking_print(f"Random patches applied, logs saved to: {log_path}")
+    return tracker
\ No newline at end of file
-- 
Gitee
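A short usage sketch of the rewritten module (the log path below is hypothetical):

    from config_checking.utils.random_patch import apply_patches
    import torch

    tracker = apply_patches(log_path='./logs/random_calls.log')
    x = torch.randn(4, 8)  # logged with caller file, line, code context and arguments

Each intercepted call is written as one INFO record through the '%(asctime)s - %(levelname)s - %(message)s' formatter; apply_patches returns the tracker so callers keep a handle on the logger.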
From f889e65299a3fa6ca38e67aed64f636d3f2a7449 Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Fri, 21 Feb 2025 07:08:57 +0000
Subject: [PATCH 8/9] update config_checking/utils/random_patch.py.

Signed-off-by: sunyiming
---
 config_checking/utils/random_patch.py | 26 +-------------------------
 1 file changed, 1 insertion(+), 25 deletions(-)

diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py
index 70be026..25780b7 100644
--- a/config_checking/utils/random_patch.py
+++ b/config_checking/utils/random_patch.py
@@ -11,21 +11,15 @@ from config_checking.utils.utils import config_checking_print
 
 DEFAULT_RANDOM_LOG_PATH = './random_patch.log'
 
-class RandomCallTracker:
-    """Centralized class to handle random function tracking and logging"""
-
+class RandomCallTracker:
     def __init__(self, log_path: str = DEFAULT_RANDOM_LOG_PATH):
         self.log_path = log_path
         self.logger = self._setup_logger()
 
     def _setup_logger(self) -> logging.Logger:
-        """Configure logging with a more detailed format"""
         logger = logging.getLogger('RandomTracker')
         logger.setLevel(logging.INFO)
-
-        # Avoid duplicate handlers if already configured
         if not logger.handlers:
-            # Create directory if it doesn't exist
             import os
             os.makedirs(os.path.dirname(self.log_path) or '.', exist_ok=True)
 
@@ -39,14 +33,11 @@ class RandomCallTracker:
         return logger
 
     def track_random_call(self, func: Callable) -> Callable:
-        """Decorator to track random function calls with detailed stack info"""
         @wraps(func)
         def wrapper(*args, **kwargs):
             frame = inspect.currentframe()
             caller_frame = frame.f_back
             caller_info = inspect.getframeinfo(caller_frame)
-
-            # Enhanced device handling for torch functions
             device_info = ""
             if 'device' in kwargs:
                 device_info = f" [device: {kwargs['device']}]"
@@ -54,8 +45,6 @@ class RandomCallTracker:
                 tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
                 if tensors:
                     device_info = f" [tensor device: {tensors[0].device}]"
-
-            # Format detailed message
             msg = (
                 f"Random function '{func.__qualname__}' called "
                 f"from {caller_info.filename}:{caller_info.lineno}"
@@ -80,19 +69,9 @@ class RandomCallTracker:
         return wrapper
 
 def apply_patches(log_path: str = DEFAULT_RANDOM_LOG_PATH) -> RandomCallTracker:
-    """
-    Apply tracking patches to random number generation functions
-
-    Args:
-        log_path (str): Path where the log file will be saved
-
-    Returns:
-        RandomCallTracker: The tracker instance for potential further manipulation
-    """
     tracker = RandomCallTracker(log_path)
     track = tracker.track_random_call
 
-    # Random module patches
     random_patches = {
         'random': random.random,
         'randint': random.randint,
@@ -102,7 +81,6 @@ def apply_patches(log_path: str = DEFAULT_RANDOM_LOG_PATH) -> RandomCallTracker:
     for name, func in random_patches.items():
         setattr(random, name, track(func))
 
-    # NumPy random patches
     np_random_patches = {
         'rand': np.random.rand,
         'randint': np.random.randint,
@@ -112,7 +90,6 @@ def apply_patches(log_path: str = DEFAULT_RANDOM_LOG_PATH) -> RandomCallTracker:
     for name, func in np_random_patches.items():
         setattr(np.random, name, track(func))
 
-    # Torch random patches
     torch_patches = {
         'rand': torch.rand,
         'randint': torch.randint,
@@ -125,7 +102,6 @@ def apply_patches(log_path: str = DEFAULT_RANDOM_LOG_PATH) -> RandomCallTracker:
     for name, func in torch_patches.items():
         setattr(torch, name, track(func))
 
-    # Torch Tensor method patch
     torch.Tensor.exponential_ = track(torch.Tensor.exponential_)
 
     config_checking_print(f"Random patches applied, logs saved to: {log_path}")
-- 
Gitee
From 8e3de28d0b84b5d5876f48e104d84f26d861246f Mon Sep 17 00:00:00 2001
From: sunyiming
Date: Fri, 21 Feb 2025 07:34:05 +0000
Subject: [PATCH 9/9] update config_checking/utils/random_patch.py.

Signed-off-by: sunyiming
---
 config_checking/utils/random_patch.py | 40 ++++++++++++++++-----------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py
index 25780b7..2b9ce76 100644
--- a/config_checking/utils/random_patch.py
+++ b/config_checking/utils/random_patch.py
@@ -4,6 +4,7 @@ import traceback
 from functools import wraps
 from typing import Callable, Any
 import inspect
+import os
 
 import numpy as np
 import torch
@@ -13,23 +14,25 @@ DEFAULT_RANDOM_LOG_PATH = './random_patch.log'
 
 class RandomCallTracker:
     def __init__(self, log_path: str = DEFAULT_RANDOM_LOG_PATH):
-        self.log_path = log_path
+        self.log_path = os.path.abspath(log_path)
         self.logger = self._setup_logger()
 
     def _setup_logger(self) -> logging.Logger:
         logger = logging.getLogger('RandomTracker')
         logger.setLevel(logging.INFO)
-        if not logger.handlers:
-            import os
-            os.makedirs(os.path.dirname(self.log_path) or '.', exist_ok=True)
+
+        if logger.handlers:
+            logger.handlers.clear()
 
-            handler = logging.FileHandler(self.log_path)
-            formatter = logging.Formatter(
-                '%(asctime)s - %(levelname)s - %(message)s',
-                datefmt='%Y-%m-%d %H:%M:%S'
-            )
-            handler.setFormatter(formatter)
-            logger.addHandler(handler)
+        os.makedirs(os.path.dirname(self.log_path) or '.', exist_ok=True)
+
+        handler = logging.FileHandler(self.log_path, mode='w')
+        formatter = logging.Formatter(
+            '%(asctime)s - %(levelname)s - %(message)s',
+            datefmt='%Y-%m-%d %H:%M:%S'
+        )
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
         return logger
 
     def track_random_call(self, func: Callable) -> Callable:
@@ -38,6 +41,7 @@ class RandomCallTracker:
             frame = inspect.currentframe()
             caller_frame = frame.f_back
             caller_info = inspect.getframeinfo(caller_frame)
+
             device_info = ""
             if 'device' in kwargs:
                 device_info = f" [device: {kwargs['device']}]"
@@ -45,13 +49,16 @@ class RandomCallTracker:
                 tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
                 if tensors:
                     device_info = f" [tensor device: {tensors[0].device}]"
+
+            stack_trace = ''.join(traceback.format_stack(caller_frame))
+
             msg = (
-                f"Random function '{func.__qualname__}' called "
-                f"from {caller_info.filename}:{caller_info.lineno}"
-                f"{device_info}\n"
+                f"Random function '{func.__qualname__}' called\n"
+                f"Location: {os.path.abspath(caller_info.filename)}:{caller_info.lineno}\n"
                 f"Function: {caller_info.function}\n"
                 f"Code context: {caller_info.code_context[0].strip() if caller_info.code_context else 'N/A'}\n"
-                f"Arguments: args={args}, kwargs={kwargs}"
+                f"Arguments: args={args}, kwargs={kwargs}{device_info}\n"
+                f"Stack trace:\n{stack_trace}"
             )
 
             self.logger.info(msg)
@@ -61,7 +68,8 @@ class RandomCallTracker:
                 self.logger.debug(f"Result: {result}")
                 return result
             except Exception as e:
-                self.logger.error(f"Error in {func.__qualname__}: {str(e)}\n{traceback.format_exc()}")
+                error_trace = traceback.format_exc()
+                self.logger.error(f"Error in {func.__qualname__}: {str(e)}\nFull traceback:\n{error_trace}")
                 raise
             finally:
                 del frame, caller_frame
@@ -104,5 +112,5 @@ def apply_patches(log_path: str = DEFAULT_RANDOM_LOG_PATH) -> RandomCallTracker:
 
     torch.Tensor.exponential_ = track(torch.Tensor.exponential_)
 
-    config_checking_print(f"Random patches applied, logs saved to: {log_path}")
+    config_checking_print(f"Random patches applied, logs saved to: {os.path.abspath(log_path)}")
     return tracker
\ No newline at end of file
-- 
Gitee
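With the final version applied, one intercepted call produces a multi-line record along these lines (paths and values are illustrative):

    2025-02-21 07:40:12 - INFO - Random function 'randn' called
    Location: /home/user/train.py:42
    Function: build_model
    Code context: weight = torch.randn(1024, 1024)
    Arguments: args=(1024, 1024), kwargs={}
    Stack trace:
      File "/home/user/train.py", line 42, in build_model
        weight = torch.randn(1024, 1024)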