From 7dab332e62a2b741531f698380cad9890ec59310 Mon Sep 17 00:00:00 2001 From: shawn_zhu1 Date: Wed, 4 Sep 2024 14:33:31 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91?= =?UTF-8?q?=E9=9A=8F=E6=9C=BA=E8=B0=83=E7=94=A8=EF=BC=8C=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E5=88=B0=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_checking/utils/random_patch.py | 63 +++++++++++++++++++-------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py index e48d203..d604fd5 100644 --- a/config_checking/utils/random_patch.py +++ b/config_checking/utils/random_patch.py @@ -1,40 +1,65 @@ +import logging import random -import numpy as np -import torch import traceback from functools import wraps -from config_checking.utils.utils import config_checking_print + +import numpy as np +import torch + +# TODO 日志打印默认记录在random_patch.log中: +# 1、日志写入文件并发处理 +# 2、支持日志文件路径设置 +# 3、多个装饰器可改为责任链模式 +logging.basicConfig(filename='./random_patch.log', level=logging.INFO) def __log_stack(func): @wraps(func) def wrapper(*args, **kwargs): stack = traceback.format_stack() - # TODO 替换为logger - config_checking_print(f"Function {func.__name__} called. Call stack:") + msg = f"info: random function {func.__name__} called. Call stack:" for line in stack[:-1]: - # TODO 替换为logger - config_checking_print(line.strip()) + msg += '\n' + line.strip() + logging.info(msg) return func(*args, **kwargs) return wrapper +def __check_torch_with_device(func): + @wraps(func) + def wrapper(*args, **kwargs): + if 'device' in kwargs: + # 获取调用栈信息以确定文件和行号 + stack = traceback.extract_stack() + caller = stack[-2] + file_name = caller.filename + line_number = caller.lineno + logging.warning(f"Warning: torch function {func.__name__} called with device specified in {file_name} " + f"at line {line_number}.") + return func(*args, **kwargs) + return wrapper + + +def __track_func(func): + return __log_stack(__check_torch_with_device(func)) + + def apply_patches(): # Patch random module - random.random = __log_stack(random.random) - random.randint = __log_stack(random.randint) - random.uniform = __log_stack(random.uniform) - random.choice = __log_stack(random.choice) + random.random = __track_func(random.random) + random.randint = __track_func(random.randint) + random.uniform = __track_func(random.uniform) + random.choice = __track_func(random.choice) # Patch numpy.random module - np.random.rand = __log_stack(np.random.rand) - np.random.randint = __log_stack(np.random.randint) - np.random.choice = __log_stack(np.random.choice) - np.random.normal = __log_stack(np.random.normal) + np.random.rand = __track_func(np.random.rand) + np.random.randint = __track_func(np.random.randint) + np.random.choice = __track_func(np.random.choice) + np.random.normal = __track_func(np.random.normal) # Patch torch random functions - torch.rand = __log_stack(torch.rand) - torch.randint = __log_stack(torch.randint) - torch.randn = __log_stack(torch.randn) - torch.manual_seed = __log_stack(torch.manual_seed) + torch.rand = __track_func(torch.rand) + torch.randint = __track_func(torch.randint) + torch.randn = __track_func(torch.randn) + torch.manual_seed = __track_func(torch.manual_seed) -- Gitee From 28f134716d7860be45020142842c71cf4a1e172d Mon Sep 17 00:00:00 2001 From: shawn_zhu1 Date: Thu, 5 Sep 2024 10:28:55 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91?= =?UTF-8?q?=E9=9A=8F=E6=9C=BA=E8=B0=83=E7=94=A8=EF=BC=8C=E8=BE=93=E5=87=BA?= =?UTF-8?q?=E5=88=B0=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 4 +++- config_checking/utils/random_patch.py | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4063b4a..365fa8d 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,9 @@ from config_checking.utils.random_patch import apply_patches apply_patches() # 针对所有random,打印调用日志 ``` -上述语句会在日志中打印代码执行random操作的堆栈信息,供用户查看 +上述语句会汇总代码中调用的所有随机生成语句,并输出到执行路径下的文件`random_patch.log`中,同时记录如下信息: +a、输出调用栈信息:随机语句调用,如`random`,`np.random`以及`torch.rand`模块 +b、输出调用行号:torch指定device侧生成随机数语句 ## 通过标准 diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py index d604fd5..4690f81 100644 --- a/config_checking/utils/random_patch.py +++ b/config_checking/utils/random_patch.py @@ -5,12 +5,16 @@ from functools import wraps import numpy as np import torch +from config_checking.utils.utils import config_checking_print + + +DEFAULT_RANDOM_LOG_PATH = './random_patch.log' # TODO 日志打印默认记录在random_patch.log中: # 1、日志写入文件并发处理 # 2、支持日志文件路径设置 # 3、多个装饰器可改为责任链模式 -logging.basicConfig(filename='./random_patch.log', level=logging.INFO) +logging.basicConfig(filename=DEFAULT_RANDOM_LOG_PATH, level=logging.INFO) def __log_stack(func): @@ -63,3 +67,5 @@ def apply_patches(): torch.randint = __track_func(torch.randint) torch.randn = __track_func(torch.randn) torch.manual_seed = __track_func(torch.manual_seed) + + config_checking_print(f"random patches saved to file: {DEFAULT_RANDOM_LOG_PATH}") -- Gitee