diff --git a/debug/accuracy_tools/msprobe/docs/19.monitor.md b/debug/accuracy_tools/msprobe/docs/19.monitor.md index 075e62a22a35d71b94cc0152582a664605a8b282..313db831e3fa6e5bd2edcee9ac9b72e1da169b20 100644 --- a/debug/accuracy_tools/msprobe/docs/19.monitor.md +++ b/debug/accuracy_tools/msprobe/docs/19.monitor.md @@ -9,7 +9,7 @@ 要求: -- PyTorch场景:torch不低于**2.0** +- PyTorch场景:torch不低于**1.11** - MindSpore场景:mindspore不低于**2.4.10**,仅支持**MindSpore动态图**,暂不支持**msadapter**套件 ## 功能介绍 diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py index ce84e6b35b74e55a90915350ff3ef2da3f7ba441..da783bcfedf4f6a0a8eaddaee17f9c23b233859a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -19,6 +19,4 @@ from .compare.pt_compare import compare from .common.utils import seed_all from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' -if torch_version_above_or_equal_2: - from msprobe.pytorch.monitor.module_hook import TrainerMon +from msprobe.pytorch.monitor.module_hook import TrainerMon diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py index b2fa26a58e702120fcabd5d82f8e1e0ed27f3bc4..eee79d39da12df7f454051386f24ca9b835364fe 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/distributed/wrap_distributed.py @@ -20,6 +20,7 @@ import re import torch import torch.distributed as dist import torch.nn as nn +from torch._C._distributed_c10d import Work from msprobe.core.common.const import MonitorConst from msprobe.core.common.file_utils import load_yaml @@ -42,7 +43,7 @@ distributed_func = {} for f in dir(dist): distributed_func[f] = getattr(dist, f) -ORIGIN_WAIT = getattr(dist.Work, 'wait') +ORIGIN_WAIT = getattr(Work, 'wait') PENDING_ASYNC_CC_BY_HANDLE = {} diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index eea2bdbc2d2b1676e69b4528aeb28f280e5efc10..99148e853bb1003b356b139920ca4a05c5f429ee 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -42,10 +42,6 @@ from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, val get_output_base_dir, get_target_output_dir from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' -if not torch_version_above_or_equal_2: - raise ValueError("monitor require torch>=2.0") - FORMAT_MAPPING = { MonitorConst.TENSORBOARD: SummaryWriterWithAD,