From d6d9656507e5d0e5fbec188a3f584fe722ad64f8 Mon Sep 17 00:00:00 2001 From: RanZheng <364167184@qq.com> Date: Tue, 7 Jan 2025 10:23:43 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Bugfix=E3=80=91=E3=80=90monitor?= =?UTF-8?q?=E3=80=91=E4=BF=AE=E5=A4=8D=E6=BF=80=E6=B4=BB=E5=80=BC=E5=A4=9A?= =?UTF-8?q?ops=E9=87=87=E9=9B=86=E4=BF=9D=E5=AD=98csv=E6=97=B6header?= =?UTF-8?q?=E9=A1=BA=E5=BA=8F=E4=B8=8E=E5=AE=9E=E9=99=85=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=AF=B9=E4=B8=8D=E4=B8=8A=E7=9A=84=E9=97=AE=E9=A2=98=EF=BC=8C?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E4=BF=AE=E6=94=B9ut?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/pytorch/monitor/anomaly_detect.py | 2 +- .../msprobe/test/pytorch_ut/monitor/config/xy_config.json | 2 +- .../msprobe/test/pytorch_ut/monitor/test_module_hook.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py index fbfcac10f5..128e71d253 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py @@ -379,7 +379,7 @@ class CSVWriterWithAD(BaseWriterWithAD): input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT] else: input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT] - ops_ = [MonitorConst.DOT.join(i[::-1]) for i in itertools.product(ops, input_and_output)] + ops_ = [MonitorConst.DOT.join(i) for i in itertools.product(input_and_output, ops)] csv_header = ["module_name", "step", *ops_] else: csv_header = ["param_name", "step", *ops] diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/xy_config.json b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/xy_config.json index d299be0787..8540929ad2 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/xy_config.json +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/config/xy_config.json @@ -4,5 +4,5 @@ "xy_distribution": true, "all_xy": true, "format": "csv", - "ops": ["norm"] + "ops": ["norm", "nans"] } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py index 6bf8582c02..e31e4829c8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_module_hook.py @@ -73,13 +73,13 @@ class TestModuleHook(unittest.TestCase): self.assertTrue(os.path.exists(actv_grad_0_csv)) # validate columns and lines actv_0 = pd.read_csv(actv_0_csv) - expect_columns = ['vpp_stage', 'module_name', 'step', 'input.norm', 'output.norm'] + expect_columns = ['vpp_stage', 'module_name', 'step', 'input.norm', 'input.nans', 'output.norm', 'output.nans'] self.assertListEqual(list(actv_0.columns), expect_columns) - self.assertEqual(actv_0.shape, tuple([2, 5])) + self.assertEqual(actv_0.shape, tuple([2, 7])) actv_grad_0 = pd.read_csv(actv_grad_0_csv) - expect_columns = ['vpp_stage', 'module_name', 'step', 'input_grad.norm', 'output_grad.norm'] + expect_columns = ['vpp_stage', 'module_name', 'step', 'input_grad.norm', 'input_grad.nans', 'output_grad.norm', 'output_grad.nans'] self.assertListEqual(list(actv_grad_0.columns), expect_columns) - self.assertEqual(actv_0.shape, tuple([2, 5])) + self.assertEqual(actv_0.shape, tuple([2, 7])) def test_wg_distribution(self): self.get_dist_mock(False) -- Gitee