From a38f72c8b1288ad8ea4b7e17f5d97a6d2b5553d2 Mon Sep 17 00:00:00 2001
From: curry3 <485078529@qq.com>
Date: Mon, 7 Jul 2025 15:51:59 +0800
Subject: [PATCH] [bugfix] Remove fp8 support feature
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Remove the fp8 (torch.float8_e5m2 / torch.float8_e4m3fn) and HiFloat8
handling from the dump, overflow-check, kernel-dump and monitor paths:
drop the is_float8_tensor / is_hifloat8_tensor helpers and the
_cast_to_float_if_fp8 casting logic, and delete the corresponding unit
tests.

---
 .../data_processor/pytorch_processor.py       | 26 ++-----------------
 .../msprobe/pytorch/common/utils.py           | 16 ------------
 .../pytorch/dump/module_dump/hook_wrapper.py  |  4 +--
 .../pytorch/hook_module/hook_module.py        |  4 +--
 .../msprobe/pytorch/monitor/module_hook.py    |  6 ++---
 .../msprobe/pytorch/monitor/module_metric.py  |  3 ---
 .../data_processor/test_pytorch_processor.py  | 10 -------
 .../test/pytorch_ut/common/test_pt_utils.py   | 26 +------------------
 .../pytorch_ut/dump/test_pt_hook_wrapper.py   | 11 --------
 9 files changed, 8 insertions(+), 98 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index 2cd93b3caee..0117e6cd35b 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -30,7 +30,7 @@ from msprobe.core.common.utils import convert_tuple
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
     ModuleForwardInputsOutputs, TensorStatInfo
-from msprobe.pytorch.common.utils import Const as PtConst, save_pt, is_hifloat8_tensor, is_float8_tensor
+from msprobe.pytorch.common.utils import save_pt
 from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
 
 is_gpu = False
@@ -210,18 +210,6 @@ class PytorchDataProcessor(BaseDataProcessor):
             logger.warning(f"Failed to get value of torch.distributed.ReduceOp with error info: {e}.")
         return {"type": "torch.distributed.ReduceOp", "value": op_type}
 
-    @staticmethod
-    def _cast_to_float_if_fp8(tensor):
-        dtype = str(tensor.dtype)
-        if is_float8_tensor(tensor):
-            dtype = PtConst.HIFLOAT8_TYPE if is_hifloat8_tensor(tensor) else dtype
-            logger.debug(
-                f"The {dtype} tensor analyzing/saving is unsupported in dump function."
-                f"Casting to float for processing."
-            )
-            tensor = tensor.float()
-        return tensor, dtype
-
     @classmethod
     def get_special_types(cls):
         return super().get_special_types() + cls.pytorch_special_type
@@ -268,11 +256,10 @@ class PytorchDataProcessor(BaseDataProcessor):
         return p2pop_info
 
     def _analyze_tensor(self, tensor, suffix):
-        tensor, dtype = self._cast_to_float_if_fp8(tensor)
         tensor_stat = self.get_stat_info(tensor, self.config.async_dump)
         tensor_json = {}
         tensor_json.update({'type': 'torch.Tensor'})
-        tensor_json.update({'dtype': dtype})
+        tensor_json.update({'dtype': str(tensor.dtype)})
         tensor_json.update({"shape": tensor.shape})
 
         stat_values = [
@@ -295,7 +282,6 @@ class PytorchDataProcessor(BaseDataProcessor):
         dump_data_name, file_path = self.get_save_file_path(suffix)
         single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
         single_arg.update({"data_name": dump_data_name})
-        tensor, _ = self._cast_to_float_if_fp8(tensor)
         if self.config.async_dump:
             self._async_dump_cache[file_path] = tensor.clone().detach()
         else:
@@ -396,7 +382,6 @@ class OverflowCheckDataProcessor(PytorchDataProcessor):
             self._analyze_maybe_overflow_flag()
             if self.has_overflow:
                 for file_path, tensor in self.cached_tensors_and_file_paths.items():
-                    tensor, _ = self._cast_to_float_if_fp8(tensor)
                     save_pt(tensor.clone().contiguous().detach(), file_path)
                 self.real_overflow_nums += 1
             if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums:
@@ -588,11 +573,6 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
     )
     def clone_and_detach_tensor(self, input_params):
         if isinstance(input_params, torch.Tensor):
-            if is_float8_tensor(input_params):
-                raise MsprobeException(
-                    MsprobeException.UNSUPPORTED_TYPE_ERROR,
-                    f"L2 backward dump does not support float8 type."
-                )
             if input_params.requires_grad:
                 return input_params.clone().detach().requires_grad_()
             return input_params.clone()
@@ -607,8 +587,6 @@ class KernelDumpDataProcessor(PytorchDataProcessor):
 
     def analyze_single_element(self, element, suffix_stack):
         if isinstance(element, torch.Tensor):
-            if is_float8_tensor(element):
-                return {}
             if not self.is_found_output_tensor:
                 if element.requires_grad:
                     self.forward_output_tensor = element
diff --git a/debug/accuracy_tools/msprobe/pytorch/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
index 8f10660c713..2aeb585fc6c 100644
--- a/debug/accuracy_tools/msprobe/pytorch/common/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/common/utils.py
@@ -264,10 +264,6 @@ class Const:
     NPU = 'NPU'
     DISTRIBUTED = 'Distributed'
 
-    HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
-    FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
-    FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
-
     RAISE_PRECISION = {
         torch.float16: torch.float32,
         torch.bfloat16: torch.float32,
@@ -483,18 +479,6 @@ def is_torch_nn_module(variable):
     return isinstance(variable, torch.nn.Module) and not isinstance(variable, torch.jit.ScriptModule)
 
 
-def is_hifloat8_tensor(tensor):
-    if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor):
-        return True
-    return False
-
-
-def is_float8_tensor(tensor):
-    if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
-        return True
-    return is_hifloat8_tensor(tensor)
-
-
 def register_forward_pre_hook(module, forward_pre_hook):
     if torch_version_above_or_equal_2:
         module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py
index 0434e3e6268..f41bf674602 100644
--- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py
+++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/hook_wrapper.py
@@ -21,13 +21,11 @@ from torch.utils.hooks import BackwardHook
 from msprobe.core.common.const import Const
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import is_float8_tensor
 
 
 def wrap_setup_backward_hook(func):
     def requires_clone(tensor):
-        return isinstance(tensor, torch.Tensor) and not is_float8_tensor(tensor) and \
-            tensor.requires_grad and torch.is_grad_enabled()
+        return isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
 
     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
     def parse_tensor(item, tensor_list):
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
index 0a55f6a9deb..f9fc0c10fe6 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py
@@ -22,7 +22,7 @@ import torch.nn as nn
 import torch.utils.hooks as full_hooks
 
 from msprobe.core.common.runtime import Runtime
-from msprobe.pytorch.common.utils import is_float8_tensor, register_forward_pre_hook, register_forward_hook
+from msprobe.pytorch.common.utils import register_forward_pre_hook, register_forward_hook
 
 
 class HOOKModule(nn.Module):
@@ -104,7 +104,7 @@ class HOOKModule(nn.Module):
             else:
                 return result
 
-        if is_float8_tensor(var) or not (var.requires_grad and torch.is_grad_enabled()):
+        if not (var.requires_grad and torch.is_grad_enabled()):
             return result
 
         grad_fn = var.grad_fn
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index e34fb21d5ee..30c34bfa943 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -32,7 +32,7 @@ from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFa
 from msprobe.core.common.file_utils import write_df_to_csv
 from msprobe.core.common.utils import analyze_api_call_stack
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
+from msprobe.pytorch.common.utils import is_recomputation
 from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
@@ -763,7 +763,7 @@ class TrainerMon:
         def clone_if_tensor(args):
             if isinstance(args, tuple):
                 return tuple([clone_if_tensor(arg) for arg in args])
-            elif isinstance(args, torch.Tensor) and not is_float8_tensor(args):
+            elif isinstance(args, torch.Tensor):
                 return args.clone()
             else:
                 return args
@@ -1171,8 +1171,6 @@ class TrainerMon:
                    grad = param.main_grad
                else:
                    grad = param.grad
-                if is_float8_tensor(grad):
-                    grad = grad.float()
                context_dict[key] = grad.clone()
 
            if param.micro_step == self.micro_batch_number:
diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
index 48d241c5f61..9dd4d0b71f2 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py
@@ -16,7 +16,6 @@ import re
 
 import torch
 
-from msprobe.pytorch.common.utils import is_float8_tensor
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
 from msprobe.pytorch.monitor.utils import get_nan_tensor
 
@@ -181,8 +180,6 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
             # Non-tensor in/output filled with nan.
             out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
             continue
-        if is_float8_tensor(tensor):
-            tensor = tensor.float()
         for metric_name in ops:
             fun_metric = config_metric_registry.get(metric_name)
             out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py
index ad933870c68..43369317d3d 100644
--- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py
@@ -369,16 +369,6 @@ class TestPytorchDataProcessor(unittest.TestCase):
         self.assertEqual(result['shape'], torch.Size([0]))
         self.assertEqual(result['requires_grad'], False)
 
-    def test_cast_to_float_if_fp8(self):
-        tensor = MagicMock()
-        tensor.dtype = "torch.float8_e5m2"
-        _, dtype = self.processor._cast_to_float_if_fp8(tensor)
-        self.assertEqual(dtype, "torch.float8_e5m2")
-
-        tensor.dtype = "torch.float8_e4m3fn"
-        _, dtype = self.processor._cast_to_float_if_fp8(tensor)
-        self.assertEqual(dtype, "torch.float8_e4m3fn")
-
 
 class TestTensorDataProcessor(unittest.TestCase):
 
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py
index 0a25e6edf59..c28557f8fca 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/common/test_pt_utils.py
@@ -22,7 +22,6 @@ from unittest.mock import MagicMock, patch
 import torch
 import torch.distributed as dist
 from msprobe.core.common.exceptions import DistributedNotInitializedError
-from msprobe.core.common.file_utils import FileCheckConst
 from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
 from msprobe.pytorch.common.utils import (
     parameter_adapter,
@@ -35,9 +34,7 @@ from msprobe.pytorch.common.utils import (
     save_api_data,
     load_api_data,
     save_pkl,
-    load_pkl,
-    is_float8_tensor,
-    is_hifloat8_tensor
+    load_pkl
 )
 
 
@@ -311,24 +308,3 @@ class TestSavePkl(unittest.TestCase):
             load_pkl(self.filepath)
         self.assertIn("Unsupported object type: os.system", str(context.exception))
         os.remove(self.filepath)
-
-class TestFloat8Tensor(unittest.TestCase):
-    def setUp(self):
-        self.tensor = MagicMock()
-
-    def test_is_float8_tensor(self):
-        self.tensor.dtype = "torch.float8_e5m2"
-        res = is_float8_tensor(self.tensor)
-        self.assertTrue(res)
-
-        self.tensor.dtype = "torch.float8_e4m3fn"
-        res = is_float8_tensor(self.tensor)
-        self.assertTrue(res)
-
-    def test_is_not_float8_tensor(self):
-        self.tensor.dtype = 123
-        res = is_float8_tensor(self.tensor)
-        self.assertFalse(res)
-
-        res = is_hifloat8_tensor(self.tensor)
-        self.assertFalse(res)
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py
index 88039390f19..c070846a604 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_pt_hook_wrapper.py
@@ -79,14 +79,3 @@ class TestWrapSetupBackwardHook(unittest.TestCase):
         self.assertTrue(torch.equal(result[1]["dict"], test_tensor1))
         self.assertTrue(torch.equal(result[2][0], test_tensor2))
         self.assertTrue(torch.equal(result[3][0], test_tensor3))
-
-    @patch('msprobe.pytorch.common.utils.is_float8_tensor', return_value=True)
-    def test_float8_tensor_handling(self, _):
-        test_data = [torch.randn(3, requires_grad=True)]
-
-        mock_self = MagicMock()
-        self.mock_func.return_value = []
-        result = self.decorated_func(mock_self, test_data)
-
-        self.assertIsInstance(result, list)
-        self.assertListEqual(result, test_data)
-- 
Gitee