diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_graph_cell_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_graph_cell_dump.py index ea1007950e2b785942c7a804e1ce2cdd28886beb..60a54d9e1523da8b9fa05bf3aef3a02b667e26c3 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_graph_cell_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_graph_cell_dump.py @@ -30,6 +30,7 @@ from msprobe.mindspore.dump.cell_dump_process import cell_construct_wrapper from msprobe.mindspore.dump.cell_dump_process import convert_special_values, sort_filenames from msprobe.mindspore.dump.cell_dump_process import check_relation from msprobe.mindspore.dump.cell_dump_process import process_csv, np_ms_dtype_dict +from msprobe.mindspore.dump.cell_dump_process import create_kbyk_json class TestCellWrapperProcess(unittest.TestCase): @@ -347,3 +348,32 @@ class TestProcessCsv(unittest.TestCase): self.assertIs(tensor_json[CoreConst.MIN], False) self.assertIsNone(tensor_json[CoreConst.MEAN]) self.assertEqual(tensor_json[CoreConst.NORM], 1.23) + + +class TestCreateKbykJsonMultiRank(unittest.TestCase): + @patch("msprobe.mindspore.dump.cell_dump_process.create_directory", lambda path: None) + @patch( + "msprobe.mindspore.dump.cell_dump_process.save_json", + lambda path, data, indent=4: open(path, "w").write("test") + ) + def test_create_kbyk_json_multi_rank(self): + + test_cases = [ + (None, "0kernel_kbyk_dump.json"), + ("1", "1kernel_kbyk_dump.json"), + ("3", "3kernel_kbyk_dump.json"), + ] + + for rank_id_env, expected_prefix in test_cases: + with tempfile.TemporaryDirectory() as dump_path: + summary_mode = ["max"] + step = 0 + # Patch environment variable + if rank_id_env is not None: + with patch.dict(os.environ, {"RANK_ID": rank_id_env}): + config_json_path = create_kbyk_json(dump_path, summary_mode, step) + else: + with patch.dict(os.environ, {}, clear=True): + config_json_path = create_kbyk_json(dump_path, summary_mode, step) + self.assertEqual(os.path.basename(config_json_path), expected_prefix) + self.assertTrue(config_json_path.startswith(dump_path))