diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index 9e81ac56914c689940e2acfbb4cb5a5e152a337d..9bef9ad2d84647df6aaa27256c71dccfe7732961 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -15,6 +15,7 @@
 
 import os
 import zlib
+import ctypes
 from collections.abc import Iterable
 from dataclasses import asdict
 from typing import List
@@ -122,12 +123,6 @@ class PytorchDataProcessor(BaseDataProcessor):
 
         self.tensor_handler = TensorHandler()
         self._crc_executor = ThreadPoolExecutor(max_workers=os.cpu_count() // 2)
-
-    @staticmethod
-    def compute_crc32_bytes(tensor_bytes):
-        return f"{zlib.crc32(tensor_bytes):08x}"
-
-
     @staticmethod
     def get_md5_for_tensor(x):
         if x.dtype == torch.bfloat16:
@@ -136,6 +131,64 @@ class PytorchDataProcessor(BaseDataProcessor):
         crc32_hash = zlib.crc32(tensor_bytes)
         return f"{crc32_hash:08x}"
+    @staticmethod
+    def tensor_bytes_view_cpu(t: torch.Tensor):
+        """
+        返回 t 在当前 dtype 下的原始字节视图(优先零拷贝)。
+        需保证:t 已在 CPU 且是 contiguous。
+        可能返回 memoryview 或 bytes(兜底拷贝)或者 转为numpy,均可被 zlib.crc32 接受。
+        """
+
+        nbytes = t.numel() * t.element_size()
+        byte_offset = t.storage_offset() * t.element_size()
+
+        if nbytes == 0:
+            return memoryview(b"")
+
+        storage = t.untyped_storage()
+
+        # ctypes 指针构造 memoryview(零拷贝 FFI)
+        try:
+            addr = storage.data_ptr() + byte_offset
+            buf = (ctypes.c_ubyte * nbytes).from_address(addr)
+            mv3 = memoryview(buf)
+
+            return mv3
+        except Exception as e1:
+            logger.warning(f"path_A_failed: {e1}.")
+
+        try:
+            data = ctypes.string_at(storage.data_ptr() + byte_offset, nbytes)
+
+            return data  # bytes 也可直接用于 zlib.crc32
+        except Exception as e2:
+            logger.warning(f"path_B_failed: {e2}.")
+
+        try:
+            if t.dtype == torch.bfloat16:
+                t = t.float()
+            data = t.numpy()
+
+            return data
+        except Exception as e3:
+            logger.warning(f"path_C_failed: {e3}.")
+            return memoryview(b"")
+
+    @staticmethod
+    def compute_crc32_from_tensor(t: torch.Tensor) -> str:
+        """
+        直接对 Tensor 原始字节做 CRC32。
+        :
+        - "raw": 保持 bfloat16 原始 16bit 字节(推荐,避免升精/增容)
+        """
+
+        # 取得字节视图(含多级回退),然后做 CRC
+        mv = PytorchDataProcessor.tensor_bytes_view_cpu(t)
+
+        crc = zlib.crc32(mv)
+
+        return f"{crc:08x}"
+
 
     @staticmethod
     def analyze_device_in_kwargs(element):
         single_arg = {}
@@ -299,14 +352,24 @@ class PytorchDataProcessor(BaseDataProcessor):
         if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
             tensor_md5 = None
             if not self.tensor_handler.is_empty_data(tensor):
-                # 拷贝并搬到 CPU
-                if common_tensor.dtype == torch.bfloat16:
-                    common_tensor = common_tensor.float()
-                tensor_bytes = common_tensor.cpu().detach().numpy()
+                t_cpu = common_tensor
+
+                # 根据设备类型做同步,确保数据已准备好
+                if t_cpu.device.type == "cuda":
+                    t_cpu = t_cpu.to("cpu", non_blocking=True)
+                    torch.cuda.synchronize()
+                # 先异步搬运再进行同步可以显著提升性能
+                elif t_cpu.device.type == "npu":
+                    t_cpu = t_cpu.to("cpu", non_blocking=True)
+                    torch.npu.synchronize()
+
+                t_cpu = t_cpu.detach()
+                if not t_cpu.is_contiguous():
+                    t_cpu = t_cpu.contiguous()
 
                 future = self._crc_executor.submit(
-                    PytorchDataProcessor.compute_crc32_bytes,
-                    tensor_bytes
+                    PytorchDataProcessor.compute_crc32_from_tensor,
+                    t_cpu
                 )
                 crc_placeholder = self.data_writer.append_crc32_to_buffer(future)
 