From f63ef3b6917cb7233912fcebe6626650fbafad51 Mon Sep 17 00:00:00 2001 From: shawn_zhu1 Date: Sat, 14 Sep 2024 15:14:17 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91=E7=BB=9F?= =?UTF-8?q?=E4=B8=80hash=EF=BC=9Brandom=E6=97=A5=E5=BF=97=E6=95=B4?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config_checking/checkers/checkpoint_checker.py | 14 +++++++------- config_checking/utils/hash.py | 7 ++++++- config_checking/utils/packing.py | 4 +++- config_checking/utils/random_patch.py | 7 ++++++- config_checking/utils/utils.py | 7 +++---- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/config_checking/checkers/checkpoint_checker.py b/config_checking/checkers/checkpoint_checker.py index 30e3bf9..be5cb6f 100644 --- a/config_checking/checkers/checkpoint_checker.py +++ b/config_checking/checkers/checkpoint_checker.py @@ -1,19 +1,19 @@ -import os import json +import os + import torch -import hashlib -import zlib + from config_checking.checkers.base_checker import BaseChecker +from config_checking.config_checker import register_checker_item +from config_checking.utils.hash import bytes_hash from config_checking.utils.packing import create_file_in_zip -from config_checking.utils.utils import load_json, compare_dict, write_list_to_file -from config_checking.config_checker import register_checker_item from config_checking.utils.utils import config_checking_print +from config_checking.utils.utils import load_json, compare_dict, write_list_to_file def tensor_to_hash(tensor): tensor_bytes = tensor.cpu().numpy().tobytes() - hash_object = hashlib.sha256(tensor_bytes) - return hash_object.hexdigest() + return bytes_hash(tensor_bytes) def tensor_in_state_dict_to_hash(state_dict): result = {} diff --git a/config_checking/utils/hash.py b/config_checking/utils/hash.py index 153cd05..d5721aa 100644 --- a/config_checking/utils/hash.py +++ b/config_checking/utils/hash.py @@ -2,7 +2,6 @@ import hashlib import os from concurrent.futures import ThreadPoolExecutor - BLOCK_SIZE = 64 << 20 # 64MB MAX_THREAD_WORKERS = 16 @@ -41,3 +40,9 @@ def calculate_hash(file_path, max_workers=MAX_THREAD_WORKERS): def string_hash(input_str): return hashlib.sha256(input_str.encode('utf-8')).hexdigest() + + +def bytes_hash(obj: bytes): + hex_dig = hashlib.sha256(obj).hexdigest() + short_hash = int(hex_dig, 16) % (2 ** 16) + return short_hash diff --git a/config_checking/utils/packing.py b/config_checking/utils/packing.py index 22690b2..621e713 100644 --- a/config_checking/utils/packing.py +++ b/config_checking/utils/packing.py @@ -3,6 +3,8 @@ import zipfile import hashlib import multiprocessing +from config_checking.utils.hash import string_hash + proc_lock = multiprocessing.Lock() @@ -80,7 +82,7 @@ class DirPacker: hash_file_path = f"{rel_path}.hash" target_file_path = os.path.join(self.result_dirname, hash_file_path) with open(file_path, 'rb') as f: - file_hash = hashlib.sha256(f.read()).hexdigest() + file_hash = string_hash(f.read()) zip_info = zipfile.ZipInfo(target_file_path) self.zip_handler.writestr(zip_info, file_hash) diff --git a/config_checking/utils/random_patch.py b/config_checking/utils/random_patch.py index bb24919..a6b3e4d 100644 --- a/config_checking/utils/random_patch.py +++ b/config_checking/utils/random_patch.py @@ -14,7 +14,6 @@ DEFAULT_RANDOM_LOG_PATH = './random_patch.log' # 1、日志写入文件并发处理 # 2、支持日志文件路径设置 # 3、多个装饰器可改为责任链模式 -logging.basicConfig(filename=DEFAULT_RANDOM_LOG_PATH, level=logging.INFO) def __log_stack(func): @@ -50,6 +49,9 @@ def __track_func(func): def apply_patches(): + # init logging + logging.basicConfig(filename=DEFAULT_RANDOM_LOG_PATH, level=logging.INFO) + # Patch random module random.random = __track_func(random.random) random.randint = __track_func(random.randint) @@ -71,4 +73,7 @@ def apply_patches(): torch.randn_like = __track_func(torch.randn_like) torch.manual_seed = __track_func(torch.manual_seed) + # Patch torch.Tensor random function + torch.Tensor.exponential_ = __track_func(torch.Tensor.exponential_) + config_checking_print(f"random patches saved to file: {DEFAULT_RANDOM_LOG_PATH}") diff --git a/config_checking/utils/utils.py b/config_checking/utils/utils.py index 37705ff..d0213c7 100644 --- a/config_checking/utils/utils.py +++ b/config_checking/utils/utils.py @@ -4,7 +4,8 @@ import os import re import torch import torch.distributed as dist -import hashlib + +from config_checking.utils.hash import bytes_hash def load_txt(file_path): @@ -61,9 +62,7 @@ def config_checking_print(msg): def tensor_to_hash(tensor): """Compute the hash value of a tensor""" tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() - hash_object = hashlib.sha256() - hash_object.update(tensor_bytes) - return hash_object.hexdigest() + return bytes_hash(tensor_bytes) features = { -- Gitee