diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 327278fd4738c6d767af7be643874ae5e80b6396..a6e3aba0fdc2078859eecd8d2970a01b4021a0d3 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -24,6 +24,8 @@ class Const: Class for const """ TOOL_NAME = "msprobe" + MD5_INDEX = "md5_index" + MD5 = "md5" ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$" SEP = "." diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index b188fe5cbe63917f58cdec903bb67ecccf6a0bf3..c43dc9deee7b7d531e725a9b92ae8985cbaf4cad 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -94,6 +94,8 @@ class BaseDataProcessor: def __init__(self, config, data_writer): self.data_writer = data_writer self.config = config + if self.data_writer is not None: + self.data_writer.config = config self.api_info_struct = {} self.stack_info_struct = {} self.current_api_or_module_name = None diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 1e8cb322f9f4dc518a10d690168e0b80b84fa18e..bc72d32b6ae06b04b7102476771537cdf5ab51c0 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -13,7 +13,9 @@ # limitations under the License. # ============================================================================ +import os import zlib +from concurrent.futures import ThreadPoolExecutor import mindspore as ms from mindspore import mint, ops, hal @@ -53,6 +55,11 @@ class MindsporeDataProcessor(BaseDataProcessor): } self._async_dump_cache = {} self.api_register = get_api_register() + self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2) + + @staticmethod + def compute_crc32_bytes(tensor_bytes): + return f"{zlib.crc32(tensor_bytes):08x}" @staticmethod def get_md5_for_tensor(x): @@ -188,8 +195,18 @@ class MindsporeDataProcessor(BaseDataProcessor): tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index}) if self.config.summary_mode == Const.MD5 and not self.config.async_dump: - tensor_md5 = self.get_md5_for_tensor(tensor) - tensor_json.update({Const.MD5: tensor_md5}) + tensor = convert_bf16_to_fp32(tensor) + # 拷贝并搬到 CPU + tensor_bytes = tensor.asnumpy().tobytes() + + future = self._crc_executor.submit( + MindsporeDataProcessor.compute_crc32_bytes, + tensor_bytes + ) + + crc_placeholder = self.data_writer.append_crc32_to_buffer(future) + tensor_json[Const.MD5_INDEX] = crc_placeholder + return tensor_json def _analyze_and_save_tensor(self, tensor, suffix): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 8ab864b232565a37500e601253c98ad1d90730c6..c305ee766a749c0a97ebb1dbbd69071ae4c7ea30 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import zlib from dataclasses import asdict from typing import List +from concurrent.futures import ThreadPoolExecutor import numpy as np import torch @@ -65,6 +67,12 @@ class PytorchDataProcessor(BaseDataProcessor): "dtype": self.analyze_dtype_in_kwargs } self._async_dump_cache = {} + self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2) + + + @staticmethod + def compute_crc32_bytes(tensor_bytes): + return f"{zlib.crc32(tensor_bytes):08x}" @staticmethod def get_md5_for_tensor(x): @@ -248,8 +256,18 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_json.update({"requires_grad": tensor.requires_grad}) if self.config.summary_mode == Const.MD5 and not self.config.async_dump: - tensor_md5 = self.get_md5_for_tensor(tensor) - tensor_json.update({Const.MD5: tensor_md5}) + # 拷贝并搬到 CPU + if tensor.dtype == torch.bfloat16: + tensor = tensor.float() + tensor_bytes = tensor.cpu().detach().numpy().tobytes() + + future = self._crc_executor.submit( + PytorchDataProcessor.compute_crc32_bytes, + tensor_bytes + ) + + crc_placeholder = self.data_writer.append_crc32_to_buffer(future) + tensor_json[Const.MD5_INDEX] = crc_placeholder return tensor_json def _analyze_and_save_tensor(self, tensor, suffix): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 0119a692a8565f1b4bf01dfc3e9b8ee01e21173d..31f5fc5cc124b79bbf243a95cbc9a9851087e054 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -18,6 +18,7 @@ import os import copy import threading +import concurrent from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.file_utils import change_mode, FileOpen, save_json, load_json from msprobe.core.common.log import logger @@ -42,6 +43,9 @@ class DataWriter: self.cache_construct = {} self.cache_debug = {} self.stat_stack_list = [] + self._error_log_initialized = False + self._cache_logged_error_types = set() + self.crc32_stack_list = [] @staticmethod def write_data_to_csv(result: list, result_header: tuple, file_path: str): @@ -104,6 +108,24 @@ class DataWriter: self.cache_construct = {} self.cache_debug = {} + def append_crc32_to_buffer(self, future: concurrent.futures.Future) -> int: + """ + 把一个计算 CRC32 的 Future 放入队列,返回占位符索引 + """ + idx = len(self.crc32_stack_list) + self.crc32_stack_list.append(future) + return idx + + def flush_crc32_stack(self): + """ + 等待所有 CRC32 计算完成,返回结果列表 + """ + if not self.crc32_stack_list: + return [] + results = [f.result() for f in self.crc32_stack_list] + self.crc32_stack_list = [] + return results + def initialize_json_file(self, **kwargs): if kwargs["level"] == Const.LEVEL_DEBUG and not self.cache_debug: # debug level case only create debug.json @@ -137,7 +159,15 @@ class DataWriter: length = len(dump_data) - threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size + # 1) 先取到 config(如果没有,就拿 None) + cfg = getattr(self, "config", None) + # 2) 再取 summary_mode(如果 cfg 是 None 或者没 summary_mode,就拿 None) + summary_mode = getattr(cfg, "summary_mode", None) + + if summary_mode == Const.MD5: + threshold = self.flush_size + else: + threshold = self.flush_size if length < self.larger_flush_size else self.larger_flush_size if length % threshold == 0: self.write_json() @@ -241,6 +271,14 @@ class DataWriter: self._replace_stat_placeholders(self.cache_data, stat_result) if self.cache_debug: self._replace_stat_placeholders(self.cache_debug, stat_result) + + # 2) 再 flush CRC32 + crc32_result = self.flush_crc32_stack() + if crc32_result: + self._replace_crc32_placeholders(self.cache_data, crc32_result) + if self.cache_debug: + self._replace_crc32_placeholders(self.cache_debug, crc32_result) + if self.cache_data: self.write_data_json(self.dump_file_path) if self.cache_stack: @@ -250,3 +288,21 @@ class DataWriter: if self.cache_debug: self.write_debug_info_json(self.debug_file_path) + def _replace_crc32_placeholders(self, data, crc32_results): + """ + 遍历 JSON 结构,将所有 md5_index 占位符替换成真实的 CRC32 + """ + if isinstance(data, dict): + for k, v in list(data.items()): + if k == Const.MD5_INDEX and isinstance(v, int): + idx = v + # 防越界 + crc = crc32_results[idx] if idx < len(crc32_results) else None + # 删除占位符,改成真实字段 + del data[k] + data[Const.MD5] = crc + else: + self._replace_crc32_placeholders(v, crc32_results) + elif isinstance(data, (list, tuple)): + for item in data: + self._replace_crc32_placeholders(item, crc32_results) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py index 25141a9e774d6a6ca05be9b668de96a7d57cb373..99dc1d5eee2d2ca5f9a3176f1a674d7ee8a6d260 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py @@ -138,12 +138,12 @@ class TestMindsporeDataProcessor(unittest.TestCase): expected_result = { 'type': 'mindspore.Tensor', 'dtype': 'Int32', - 'shape': (3,), - 'md5': 'test_md5', + 'shape': (3,) } result = self.processor._analyze_tensor(tensor, suffix) # 删除不必要的字段 result.pop('tensor_stat_index', None) + result.pop('md5_index', None) self.assertEqual(result, expected_result) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py index 33071c0e01c800858a892061927d470c06bcff73..7bb081e52b79e2b67563bcfaf39bc4d77b867574 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py @@ -316,10 +316,10 @@ class TestPytorchDataProcessor(unittest.TestCase): 'type': 'torch.Tensor', 'dtype': str(tensor.dtype), 'shape': tensor.shape, - 'requires_grad': tensor.requires_grad, - 'md5': 'mocked_md5' + 'requires_grad': tensor.requires_grad } result.pop('tensor_stat_index', None) + result.pop('md5_index', None) self.assertDictEqual(expected, result) def test_analyze_tensor_with_empty_tensor(self):