import inspect
import logging
import os
import random
import traceback
from collections import defaultdict
from functools import wraps
from typing import Any, Callable

import numpy as np
import torch

from msprobe.pytorch.config_checking.utils.utils import config_checking_print, get_rank

DEFAULT_RANDOM_LOG_PATH = './random_patch.log'

# Name of the attribute stamped on every wrapper produced by
# RandomCallTracker.track_random_call. It stores a reference to the wrapper
# itself so that a wrapper can be recognised by identity even if some other
# decorator later copies its ``__dict__`` (as functools.wraps does).
_TRACKED_MARKER = '_msprobe_random_tracked'


class RandomCallTracker:
    """Log every call made to a set of patched random-number APIs.

    On construction the tracker configures a per-rank file logger and writes
    one record describing the current state of the Python, NumPy and Torch
    generators. Functions wrapped with :meth:`track_random_call` then log,
    on every invocation, the cumulative per-call-site counters, the calling
    location and the call stack.
    """

    def __init__(self, log_path: str = DEFAULT_RANDOM_LOG_PATH):
        """Create the tracker and immediately log the generator state.

        Args:
            log_path: File that receives the log records; stored as an
                absolute path in ``self.log_path``.
        """
        self.log_path = os.path.abspath(log_path)
        self.logger = self._setup_logger()
        # Maps "<api_name>@<file>:<caller>" to the number of calls seen.
        self.operation_counts = defaultdict(int)
        self._log_random_seeds()

    def _setup_logger(self) -> logging.Logger:
        """Return the rank-specific logger, writing to ``self.log_path``.

        The logger is looked up by name, so a second tracker on the same rank
        receives the same logger object. Handlers left over from an earlier
        tracker are detached *and closed* so their file descriptors are
        released instead of leaking.
        """
        logger = logging.getLogger(f'RandomTracker_rank{get_rank()}')
        logger.setLevel(logging.INFO)
        for stale_handler in list(logger.handlers):
            logger.removeHandler(stale_handler)
            stale_handler.close()

        os.makedirs(os.path.dirname(self.log_path) or '.', exist_ok=True)
        handler = logging.FileHandler(self.log_path)
        handler.setFormatter(logging.Formatter(
            '%(asctime)s - Rank:%(rank)d - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        ))
        # The format string needs a ``rank`` field on every record; this
        # filter injects it and always lets the record through.
        handler.addFilter(lambda record: setattr(record, 'rank', get_rank()) or True)
        logger.addHandler(handler)
        return logger

    def _log_random_seeds(self):
        """Write one INFO record describing the three generators.

        NOTE(review): for Python and NumPy the value logged is the first word
        of the generator's internal state (``getstate()[1][0]`` /
        ``get_state()[1][0]``), which is not necessarily the value that was
        passed to ``seed()`` -- confirm that consumers of this log expect that.
        """
        seeds_msg = (
            f"Initial Random Seeds for Rank {get_rank()}:\n"
            f"Python random: {random.getstate()[1][0]}\n"
            f"NumPy random: {np.random.get_state()[1][0]}\n"
            f"Torch random: {torch.initial_seed()}"
        )
        self.logger.info(seeds_msg)

    def _get_caller_context(self, caller_frame) -> str:
        """Describe the code object that owns ``caller_frame``.

        Returns:
            ``"<abs file>:<Class>.<function>"`` when the frame has a local
            named ``self``, otherwise ``"<abs file>:<function>"``.
        """
        abs_file_path = os.path.abspath(caller_frame.f_code.co_filename)
        caller_name = caller_frame.f_code.co_name
        if 'self' in caller_frame.f_locals:
            caller_class = caller_frame.f_locals['self'].__class__.__name__
            return f"{abs_file_path}:{caller_class}.{caller_name}"
        return f"{abs_file_path}:{caller_name}"

    def _format_operation_counts(self) -> str:
        """Render the counters as sorted ``key: count`` lines."""
        if not self.operation_counts:
            return "No random operations recorded yet"
        return "\n".join(
            f"{key}: {count}" for key, count in sorted(self.operation_counts.items())
        )

    def track_random_call(self, func: Callable, api_name: str) -> Callable:
        """Wrap ``func`` so that every call is counted and logged.

        Args:
            func: The callable to wrap. If it is itself a wrapper previously
                produced by this method, the underlying function is wrapped
                instead, so repeated patching does not stack wrappers (which
                would log each call twice and attribute the inner call to the
                outer wrapper).
            api_name: Name used for the operation in log records and in the
                counter key.

        Returns:
            A function with the same call signature as ``func``. Exceptions
            raised by ``func`` are logged at ERROR level and re-raised.
        """
        # Identity check: only unwrap objects that are our own wrappers, not
        # third-party wrappers that merely inherited the marker attribute.
        while getattr(func, _TRACKED_MARKER, None) is func:
            func = func.__wrapped__

        @wraps(func)
        def wrapper(*args, **kwargs):
            frame = inspect.currentframe()
            caller_frame = frame.f_back

            raw_file_path = caller_frame.f_code.co_filename
            abs_file_path = os.path.abspath(raw_file_path)
            stack_trace = ''.join(traceback.format_stack(caller_frame))
            stack_trace = stack_trace.replace(raw_file_path, abs_file_path)

            operation = api_name
            caller_context = self._get_caller_context(caller_frame)
            self.operation_counts[f"{operation}@{caller_context}"] += 1

            cumulative_msg = (
                f"Random Operation Cumulative Calls for Rank {get_rank()}:\n"
                f"{self._format_operation_counts()}\n"
                f"Current Operation: {operation}, File: {caller_context}, line: {caller_frame.f_lineno}\n"
                f"Call Stack (with absolute path):\n{stack_trace}"
            )
            self.logger.info(cumulative_msg)

            try:
                result = func(*args, **kwargs)
                # Lazy %-formatting: the result (possibly a large tensor) is
                # only converted to text if DEBUG records are actually emitted.
                self.logger.debug("Result: %s", result)
                return result
            except Exception as e:
                error_trace = traceback.format_exc()
                self.logger.error(f"Error in {operation}: {e}\nFull traceback:\n{error_trace}")
                raise
            finally:
                # Drop the frame references so they do not keep the caller's
                # locals alive through a reference cycle.
                del frame, caller_frame

        setattr(wrapper, _TRACKED_MARKER, wrapper)
        return wrapper


def apply_patches(log_path: str = DEFAULT_RANDOM_LOG_PATH) -> RandomCallTracker:
    """Replace common random-number APIs with logging wrappers.

    Patches selected functions of ``random``, ``numpy.random`` and ``torch``
    as well as ``torch.Tensor.exponential_`` in place.

    Args:
        log_path: File that receives the call log.

    Returns:
        The tracker that owns the logger and the call counters.
    """
    tracker = RandomCallTracker(log_path)

    # (object to patch, prefix used in the logged API name, attribute names)
    patch_targets = (
        (random, 'random', ('random', 'randint', 'uniform', 'choice')),
        (np.random, 'np.random', ('rand', 'randn', 'randint', 'choice', 'normal')),
        (torch, 'torch', (
            'rand', 'randint', 'randn',
            'rand_like', 'randint_like', 'randn_like',
            'manual_seed',
        )),
        (torch.Tensor, 'torch.Tensor', ('exponential_',)),
    )
    for owner, prefix, attr_names in patch_targets:
        for attr_name in attr_names:
            original = getattr(owner, attr_name)
            tracked = tracker.track_random_call(original, f"{prefix}.{attr_name}")
            setattr(owner, attr_name, tracked)

    # tracker.log_path is already absolute (see RandomCallTracker.__init__).
    config_checking_print(f"random patches saved to file: {tracker.log_path}")
    return tracker