diff --git a/config_checking/checkers/__init__.py b/config_checking/checkers/__init__.py index ccb5bc2aafd72ea4d68e6a4bdb7bf5e62bc9703d..76787d1926bf05e1a1e6592b038db16b0214c8a6 100644 --- a/config_checking/checkers/__init__.py +++ b/config_checking/checkers/__init__.py @@ -4,6 +4,7 @@ import config_checking.checkers.pip_checker import config_checking.checkers.checkpoint_checker import config_checking.checkers.dataset_checker import config_checking.checkers.weights_checker +import config_checking.checkers.hyperparameter_checker from config_checking.checkers.base_checker import BaseChecker diff --git a/config_checking/checkers/base_checker.py b/config_checking/checkers/base_checker.py index e986b2623d06ab771f8f3f4789cabd2c8570c5b7..1676b0ad36a090a7a445aab511126151d277db13 100644 --- a/config_checking/checkers/base_checker.py +++ b/config_checking/checkers/base_checker.py @@ -9,6 +9,7 @@ class PackInput: self.ckpt_path = config_dict.get("ckpt path", None) self.need_env_args = config_dict.get("env args", None) self.need_pip_data = config_dict.get("pip data", None) + self.model_paths= config_dict.get("model_paths", None) self.output_zip_path = config_dict.get("output zip path", "./config_check_pack.zip") self.model = model diff --git a/config_checking/checkers/hyperparameter_checker.py b/config_checking/checkers/hyperparameter_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8594ce0a9444526286f6969c0b660888b66b99 --- /dev/null +++ b/config_checking/checkers/hyperparameter_checker.py @@ -0,0 +1,178 @@ +import os +import json +import importlib.util +from config_checking.checkers.base_checker import BaseChecker +from config_checking.config_checker import register_checker_item +from config_checking.utils.packing import add_file_to_zip +from config_checking.utils.utils import load_json, compare_dict, write_list_to_file +from config_checking.utils.utils import config_checking_print +from typing import Union, List, Dict, Any +from difflib import SequenceMatcher +import tempfile +import re +import shlex + +@register_checker_item("hyperparameter") +class HyperparameterChecker(BaseChecker): + input_needed = "model_paths" + target_name_in_zip = "hyperparameters" + result_filename = "hyperparameter_diff.txt" + + PARAMETER_NAME_MAPPING = { + "learning_rate": ["lr", "learningrate"], + "batch_size": ["batch", "bs", "batch_size_per_gpu"], + "epochs": ["num_epochs", "max_epochs", "epoch"], + "weight_decay": ["wd", "weightdecay"], + "dropout_rate": ["dropout", "drop_rate"], + } + + @staticmethod + def pack(pack_input): + model_paths = pack_input.model_paths + output_zip_path = pack_input.output_zip_path + + if not isinstance(model_paths, list): + raise TypeError("model_paths should be a list of file paths.") + + for script_path in model_paths: + if os.path.isfile(script_path): + hyperparameters = HyperparameterChecker._extract_hyperparameters_from_script(script_path) + if hyperparameters: + dest_path_in_zip = os.path.join(HyperparameterChecker.target_name_in_zip, os.path.splitext(os.path.basename(script_path))[0] + ".json") + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp_file: + json.dump(hyperparameters, tmp_file, indent=4) + tmp_file_path = tmp_file.name + add_file_to_zip(output_zip_path, tmp_file_path, dest_path_in_zip) + os.remove(tmp_file_path) + config_checking_print(f"Added hyperparameters from script {script_path} to zip as {dest_path_in_zip}") + else: + config_checking_print(f"Warning: Failed to extract hyperparameters from script {script_path}") + else: + config_checking_print(f"Warning: Script path {script_path} is not a file.") + + @staticmethod + def _extract_hyperparameters_from_script(script_path: str) -> Dict[str, Any]: + """ + Extracts arguments from bash script used to run a model training. + """ + hyperparameters = {} + with open(script_path, 'r') as file: + script_content = file.read() + + command_line = re.search(r'torchrun\s+(.*?)\s*\|', script_content, re.DOTALL) + if command_line: + command_line = command_line.group(1) + + blocks = re.findall(r'(\w+_ARGS)="(.*?)"', script_content, re.DOTALL) + block_contents = {} + for block_name, block_content in blocks: + block_content = block_content.replace('\n', ' ') + block_contents[block_name] = block_content + command_line = command_line.replace(f"${block_name}", block_content) + + matches = re.findall(r'--([\w-]+)(?:\s+([^\s\\]+))?', command_line) + for match in matches: + key, value = match + if value and value.startswith('$'): + env_var = re.search(rf'{value[1:]}="?(.*?)"?\s', script_content) + if env_var: + value = env_var.group(1) + hyperparameters[key] = value if value else True + + return hyperparameters + + @staticmethod + def _fuzzy_match_parameter(param_name: str, available_params: Dict[str, Any]) -> Union[str, None]: + """ + Fuzzy matches a parameter name against available parameter names using predefined mappings and string similarity. + """ + if param_name in available_params: + return param_name + + if param_name in HyperparameterChecker.PARAMETER_NAME_MAPPING: + alternatives = HyperparameterChecker.PARAMETER_NAME_MAPPING[param_name] + for alt_name in alternatives: + if alt_name in available_params: + return alt_name + best_match_name = None + best_match_ratio = 0.8 + for available_param_name in available_params: + ratio = SequenceMatcher(None, param_name.lower(), available_param_name.lower()).ratio() + if ratio > best_match_ratio and ratio > best_match_ratio: + best_match_ratio = ratio + best_match_name = available_param_name + + if best_match_name: + config_checking_print(f"Fuzzy matched parameter '{param_name}' to '{best_match_name}' (similarity: {best_match_ratio:.2f})") + return best_match_name + + return None + + def compare(bench_dir, cmp_dir, output_path): + bench_model_dir = os.path.join(bench_dir, HyperparameterChecker.target_name_in_zip) # Still using the same target_name_in_zip, but now for model files + cmp_model_dir = os.path.join(cmp_dir, HyperparameterChecker.target_name_in_zip) + output_filepath = os.path.join(output_path, HyperparameterChecker.result_filename) + + bench_hyperparameters = {} + cmp_hyperparameters = {} + + if os.path.exists(bench_model_dir): + for root, _, files in os.walk(bench_model_dir): + for file in files: + if file.endswith('.json'): + filepath = os.path.join(root, file) + relative_filepath = os.path.relpath(filepath, bench_model_dir) + params = load_json(filepath) + if params: + bench_hyperparameters[relative_filepath] = params + + if os.path.exists(cmp_model_dir): + for root, _, files in os.walk(cmp_model_dir): + for file in files: + if file.endswith('.json'): + filepath = os.path.join(root, file) + relative_filepath = os.path.relpath(filepath, cmp_model_dir) + params = load_json(filepath) + if params: + cmp_hyperparameters[relative_filepath] = params + + all_diffs = [] + all_files = set(bench_hyperparameters.keys()) | set(cmp_hyperparameters.keys()) + + for filename in all_files: + bench_params = bench_hyperparameters.get(filename, None) + cmp_params = cmp_hyperparameters.get(filename, None) + + if bench_params is not None and cmp_params is not None: + file_diffs = [] + bench_param_names = set(bench_params.keys()) + cmp_param_names = set(cmp_params.keys()) + all_param_names = bench_param_names | cmp_param_names + + for bench_param_name in bench_param_names: + matched_cmp_param_name = HyperparameterChecker._fuzzy_match_parameter(bench_param_name, cmp_params) + if matched_cmp_param_name: + cmp_param_value = cmp_params[matched_cmp_param_name] + diff = compare_dict({bench_param_name: bench_params[bench_param_name]}, + {matched_cmp_param_name: cmp_param_value}) + if diff: + file_diffs.extend([f" Parameter '{bench_param_name}' (matched with '{matched_cmp_param_name}'): {d}" for d in diff]) + del cmp_params[matched_cmp_param_name] + + else: + file_diffs.append(f" [Only in benchmark] Parameter: '{bench_param_name}': {bench_params[bench_param_name]}") + + for cmp_param_name, cmp_param_value in cmp_params.items(): + file_diffs.append(f" [Only in compare] Parameter: '{cmp_param_name}': {cmp_param_value}") + + if file_diffs: + all_diffs.append(f"File: {filename}") + all_diffs.extend(file_diffs) + + elif bench_params is not None: + all_diffs.append(f"[Only in benchmark] File: {filename}") + elif cmp_params is not None: + all_diffs.append(f"[Only in compare] File: {filename}") + + write_list_to_file(all_diffs, output_filepath) + config_checking_print(f"Hyperparameter comparison result written to {output_filepath}") \ No newline at end of file diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py index a6b3e4d95a277c5e6a23f57032d3ad34d638d054..2b9ce765c59edca71b7b44054bab4fc5083745a2 100644 --- a/config_checking/utils/random_patch.py +++ b/config_checking/utils/random_patch.py @@ -2,78 +2,115 @@ import logging import random import traceback from functools import wraps +from typing import Callable, Any +import inspect +import os import numpy as np import torch from config_checking.utils.utils import config_checking_print - DEFAULT_RANDOM_LOG_PATH = './random_patch.log' -# TODO 日志打印默认记录在random_patch.log中: -# 1、日志写入文件并发处理 -# 2、支持日志文件路径设置 -# 3、多个装饰器可改为责任链模式 - - -def __log_stack(func): - @wraps(func) - def wrapper(*args, **kwargs): - stack = traceback.format_stack() - msg = f"info: random function {func.__name__} called. Call stack:" - for line in stack[:-1]: - msg += '\n' + line.strip() - logging.info(msg) - return func(*args, **kwargs) - - return wrapper - - -def __check_torch_with_device(func): - @wraps(func) - def wrapper(*args, **kwargs): - if 'device' in kwargs: - # 获取调用栈信息以确定文件和行号 - stack = traceback.extract_stack() - caller = stack[-2] - file_name = caller.filename - line_number = caller.lineno - logging.warning(f"Warning: torch function {func.__name__} called with device specified in {file_name} " - f"at line {line_number}.") - return func(*args, **kwargs) - return wrapper - - -def __track_func(func): - return __log_stack(__check_torch_with_device(func)) - - -def apply_patches(): - # init logging - logging.basicConfig(filename=DEFAULT_RANDOM_LOG_PATH, level=logging.INFO) - - # Patch random module - random.random = __track_func(random.random) - random.randint = __track_func(random.randint) - random.uniform = __track_func(random.uniform) - random.choice = __track_func(random.choice) - - # Patch numpy.random module - np.random.rand = __track_func(np.random.rand) - np.random.randint = __track_func(np.random.randint) - np.random.choice = __track_func(np.random.choice) - np.random.normal = __track_func(np.random.normal) - - # Patch torch random functions - torch.rand = __track_func(torch.rand) - torch.randint = __track_func(torch.randint) - torch.randn = __track_func(torch.randn) - torch.rand_like = __track_func(torch.rand_like) - torch.randint_like = __track_func(torch.randint_like) - torch.randn_like = __track_func(torch.randn_like) - torch.manual_seed = __track_func(torch.manual_seed) - - # Patch torch.Tensor random function - torch.Tensor.exponential_ = __track_func(torch.Tensor.exponential_) - - config_checking_print(f"random patches saved to file: {DEFAULT_RANDOM_LOG_PATH}") +class RandomCallTracker: + def __init__(self, log_path: str = DEFAULT_RANDOM_LOG_PATH): + self.log_path = os.path.abspath(log_path) + self.logger = self._setup_logger() + + def _setup_logger(self) -> logging.Logger: + logger = logging.getLogger('RandomTracker') + logger.setLevel(logging.INFO) + + if logger.handlers: + logger.handlers.clear() + + os.makedirs(os.path.dirname(self.log_path) or '.', exist_ok=True) + + handler = logging.FileHandler(self.log_path, mode='w') + formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger + + def track_random_call(self, func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + frame = inspect.currentframe() + caller_frame = frame.f_back + caller_info = inspect.getframeinfo(caller_frame) + + device_info = "" + if 'device' in kwargs: + device_info = f" [device: {kwargs['device']}]" + elif any(isinstance(arg, torch.Tensor) for arg in args): + tensors = [arg for arg in args if isinstance(arg, torch.Tensor)] + if tensors: + device_info = f" [tensor device: {tensors[0].device}]" + + stack_trace = ''.join(traceback.format_stack(caller_frame)) + + msg = ( + f"Random function '{func.__qualname__}' called " + f"Location: {os.path.abspath(caller_info.filename)}:{caller_info.lineno}\n" + f"Function: {caller_info.function}\n" + f"Code context: {caller_info.code_context[0].strip() if caller_info.code_context else 'N/A'}\n" + f"Arguments: args={args}, kwargs={kwargs}{device_info}\n" + f"Stack trace:\n{stack_trace}" + ) + + self.logger.info(msg) + + try: + result = func(*args, **kwargs) + self.logger.debug(f"Result: {result}") + return result + except Exception as e: + error_trace = traceback.format_exc() + self.logger.error(f"Error in {func.__qualname__}: {str(e)}\nFull traceback:\n{error_trace}") + raise + finally: + del frame, caller_frame + + return wrapper + +def apply_patches(log_path: str = DEFAULT_RANDOM_LOG_PATH) -> RandomCallTracker: + tracker = RandomCallTracker(log_path) + track = tracker.track_random_call + + random_patches = { + 'random': random.random, + 'randint': random.randint, + 'uniform': random.uniform, + 'choice': random.choice + } + for name, func in random_patches.items(): + setattr(random, name, track(func)) + + np_random_patches = { + 'rand': np.random.rand, + 'randint': np.random.randint, + 'choice': np.random.choice, + 'normal': np.random.normal + } + for name, func in np_random_patches.items(): + setattr(np.random, name, track(func)) + + torch_patches = { + 'rand': torch.rand, + 'randint': torch.randint, + 'randn': torch.randn, + 'rand_like': torch.rand_like, + 'randint_like': torch.randint_like, + 'randn_like': torch.randn_like, + 'manual_seed': torch.manual_seed + } + for name, func in torch_patches.items(): + setattr(torch, name, track(func)) + + torch.Tensor.exponential_ = track(torch.Tensor.exponential_) + + config_checking_print(f"Random patches applied, logs saved to: {os.path.abspath(log_path)}") + return tracker \ No newline at end of file