From b7fd03a3705b4ac3598391e2c934fd9ebc37745e Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 22 Jan 2025 10:03:26 +0800 Subject: [PATCH 01/24] add distributed check --- .../run_ut/run_distributed_check.py | 180 ++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py new file mode 100644 index 00000000000..1643cbefc12 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -0,0 +1,180 @@ +import os +import sys +import time +import argparse + +import tqdm +import torch +import torch_npu +import torch.distributed as dist +import torch.multiprocessing as mp + + +from msprobe.core.common.file_utils import FileChecker, write_csv, create_directory +from msprobe.core.common.const import Const, FileCheckConst, CompareConst +from msprobe.core.compare.utils import check_and_return_dir_contents +from msprobe.pytorch.hook_module.wrap_distributed import distributed_func +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_api_info +from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments + + +current_time = time.strftime("%Y%m%d%H%M%S") +RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" + + +def cleanup(): + dist.destroy_process_group() + + +def distributed_setup(rank, world_size, master_ip, master_port): + init_method = 'tcp://' + master_ip + ':' + master_port + dist.init_process_group(backend='hccl', init_method=init_method, + world_size=world_size, rank=rank) + + +def parse_distributed_api(forward_content): + distributed_api = [] + return distributed_api + +def _run_distributed_parser(parser): + parser.add_argument("-api_info", "--api_info_dir", dest="api_info_dir", default="", type=str, + help=" The api param tool result dir: generate from api param tool. ", + required=True) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The ut task result out path.", + required=False) + + +def _run_distributed(parser=None): + if parser is None: + parser = argparse.ArgumentParser() + _run_distributed_parser(parser) + args = parser.parse_args(sys.argv[1:]) + run_distributed_command(args) + + +def run_distributed_command(args): + input_checker = FileChecker(args.api_info_dir, FileCheckConst.DIR, ability=FileCheckConst.READ_ABLE) + api_info_dir = input_checker.common_check() + ranks = sorted(check_and_return_dir_contents(api_info_dir, Const.RANK)) + file_paths = [os.path.join(api_info_dir, rank, 'dump.json') for rank in ranks] + forward_contents = [] + real_data_paths = [] + for file_path in file_paths: + forward_content, _, real_data_path = parse_json_info_forward_backward(file_path) + distributed_api = parse_distributed_api(forward_content) + forward_contents.append(distributed_api) + real_data_paths.append(real_data_path) + + out_path = args.out_path if args.out_path else Const.DEFAULT_OUT_PATH + create_directory(out_path) + out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) + out_path = out_path_checker.common_check() + result_file_path = os.path.join(out_path, RESULT_FILE_NAME) + run_distributed_check(forward_contents, real_data_paths, result_file_path) + + +def run_distributed_check(forward_contents, real_data_paths, result_file_path): + for rank, forward_content in enumerate(forward_contents): + logger.info("Start to check distributed api in rank {}.".format(rank)) + + for api_full_name, api_info_dict in forward_content.items(): + if api_info_dict.get('used'): + continue + + group_ranks, group_id = get_group_info(api_full_name, api_info_dict) + if not group_ranks or not group_id: + logger.warning("The api {} doesn't support distributed check.".format(api_full_name)) + continue + all_args, all_kwargs = get_distributed_args_kwargs(forward_contents, api_full_name, + real_data_paths, group_ranks) + try: + distributed_check(api_full_name, all_args, all_kwargs, group_ranks, result_file_path) + except Exception as e: + logger.error("The api {} in rank {} distributed check failed.".format(api_full_name, rank)) + logger.error(e) + + +def distributed_check(api_full_name, all_args, all_kwargs, group_ranks, result_file_path): + _, api_name = extract_basic_api_segments(api_full_name) + nprocs = len(group_ranks) + distributed_config = {} + distributed_config['api_full_name'] = api_full_name + distributed_config['group_ranks'] = group_ranks + distributed_config['all_args'] = all_args + distributed_config['all_kwargs'] = all_kwargs + distributed_config['result_file_path'] = result_file_path + + distributed_config['master_ip'] = '127.0.0.1' + distributed_config['master_port'] = '2688' + distributed_config['world_size'] = nprocs + + mp.spawn(run_hccl, + args=(distributed_config,), + nprocs=nprocs) + +def run_hccl(rank, distributed_config): + local_rank = distributed_config['group_ranks'][rank] + torch_npu.npu.set_device(local_rank) + world_size = distributed_config['world_size'] + master_ip = distributed_config['master_ip'] + master_port = distributed_config['master_port'] + distributed_setup(rank, world_size, master_ip, master_port) + api_full_name = distributed_config['api_full_name'] + api_name = distributed_config['api_name'] + rank_args = distributed_config['all_args'][rank] + rank_kwargs = distributed_config['all_kwargs'][rank] + result_file_path = distributed_config['result_file_path'] + device_args, device_kwargs = generate_device_params(rank_args, rank_kwargs, False, api_name, local_rank) + logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, rank)) + distributed_func.get(api_name)(*device_args, **device_kwargs) + dist.barrier() + status = 'pass' + message = '' + result_rows = [] + df_row = list([api_full_name, local_rank, status, message]) + result_rows.append(df_row) + write_csv(result_rows, result_file_path) + cleanup() + + +def get_distributed_args_kwargs(forward_contents, api_full_name, real_data_paths, group_ranks): + all_args, all_kwargs = [], [] + _, api_name = extract_basic_api_segments(api_full_name) + for group_rank in group_ranks: + target_api_info = forward_contents[group_rank].get(api_full_name) + if not target_api_info: + logger.warning("The api {} doesn't exist in rank {}.".format(api_full_name, group_rank)) + continue + if target_api_info.get('used'): + continue + target_api_info['used'] = True + args, kwargs = get_api_info(target_api_info, api_name, real_data_paths[group_rank]) + all_args.append(args) + all_kwargs.append(kwargs) + return all_args, all_kwargs + + +def get_group_info(api_full_name, api_info_dict): + group = api_info_dict.get('input_kwargs', {}).get('group') + if not group: + logger.warning("The api {} doesn't have group info.".format(api_full_name)) + return None, None + group_ranks = group.get('group_ranks') + if not group_ranks: + logger.warning("The group of api {} doesn't have group_ranks info.".format(api_full_name)) + return None, None + group_id = group.get('group_id') + if not group_id: + logger.warning("The group of api {} doesn't have group_id info.".format(api_full_name)) + return None, None + return group_ranks, group_id + + +if __name__ == '__main__': + logger.info("Start to run distributed ut task.") + _run_distributed() + logger.info("End to run distributed ut task.") \ No newline at end of file -- Gitee From 59d3318a801561f97cc0d1acbc51188d7fb9f8b2 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 22 Jan 2025 10:06:39 +0800 Subject: [PATCH 02/24] fix bug --- .../api_accuracy_checker/run_ut/run_distributed_check.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 1643cbefc12..46042985001 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -69,7 +69,7 @@ def run_distributed_command(args): forward_contents.append(distributed_api) real_data_paths.append(real_data_path) - out_path = args.out_path if args.out_path else Const.DEFAULT_OUT_PATH + out_path = args.out_path if args.out_path else Const.DEFAULT_PATH create_directory(out_path) out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() @@ -103,6 +103,7 @@ def distributed_check(api_full_name, all_args, all_kwargs, group_ranks, result_f nprocs = len(group_ranks) distributed_config = {} distributed_config['api_full_name'] = api_full_name + distributed_config['api_name'] = api_name distributed_config['group_ranks'] = group_ranks distributed_config['all_args'] = all_args distributed_config['all_kwargs'] = all_kwargs -- Gitee From 89fb601c9b13bc609e3cb555fedf1fb0ae996995 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 22 Jan 2025 10:26:50 +0800 Subject: [PATCH 03/24] add parse_distributed_api --- .../api_accuracy_checker/run_ut/run_distributed_check.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 46042985001..336bdbaafcb 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -36,9 +36,14 @@ def distributed_setup(rank, world_size, master_ip, master_port): def parse_distributed_api(forward_content): - distributed_api = [] + distributed_api = {} + for api_full_name, api_info_dict in forward_content.items(): + split_name = api_full_name.split(Const.SEP)[0] + if split_name == Const.DISTRIBUTED: + distributed_api.update({api_full_name: api_info_dict}) return distributed_api - + + def _run_distributed_parser(parser): parser.add_argument("-api_info", "--api_info_dir", dest="api_info_dir", default="", type=str, help=" The api param tool result dir: generate from api param tool. ", -- Gitee From 2f4892931cb2ce85a1eb98e85fc07ac01f872c9a Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 22 Jan 2025 10:29:20 +0800 Subject: [PATCH 04/24] fix bug --- .../api_accuracy_checker/run_ut/run_distributed_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 336bdbaafcb..021641bd90e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -158,7 +158,7 @@ def get_distributed_args_kwargs(forward_contents, api_full_name, real_data_paths if target_api_info.get('used'): continue target_api_info['used'] = True - args, kwargs = get_api_info(target_api_info, api_name, real_data_paths[group_rank]) + args, kwargs, _ = get_api_info(target_api_info, api_name, real_data_paths[group_rank]) all_args.append(args) all_kwargs.append(kwargs) return all_args, all_kwargs -- Gitee From c71d62ececf8dba1df73fbe3c12642065af3400d Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 22 Jan 2025 10:31:10 +0800 Subject: [PATCH 05/24] fix bug --- .../api_accuracy_checker/run_ut/run_distributed_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 021641bd90e..40092dee99f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -134,7 +134,7 @@ def run_hccl(rank, distributed_config): rank_args = distributed_config['all_args'][rank] rank_kwargs = distributed_config['all_kwargs'][rank] result_file_path = distributed_config['result_file_path'] - device_args, device_kwargs = generate_device_params(rank_args, rank_kwargs, False, api_name, local_rank) + device_args, device_kwargs = generate_device_params(rank_args, rank_kwargs, False, api_name) logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, rank)) distributed_func.get(api_name)(*device_args, **device_kwargs) dist.barrier() -- Gitee From 3ac33649fa343a0a429df19e706795f92c6cbde2 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 22 Jan 2025 10:36:06 +0800 Subject: [PATCH 06/24] fix bug --- .../pytorch/api_accuracy_checker/run_ut/data_generate.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 59415a387e6..f3d54f62448 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -27,6 +27,11 @@ from msprobe.core.common.file_utils import FileChecker, load_npy from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import load_pt from msprobe.core.common.const import Const, FileCheckConst, CompareConst +from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops + + +distribute_api_set = get_distributed_ops() +distribute_api_list = list(distribute_api_set) TORCH_TYPE = ["torch.device", "torch.dtype"] @@ -68,7 +73,7 @@ def gen_data(info, api_name, need_grad, convert_type, real_data_path=None): data = gen_random_tensor(info, convert_type) if api_name in hf_32_standard_api and data.dtype == torch.float32: data = fp32_to_hf32_to_fp32(data) - if info.get('requires_grad') and need_grad: + if info.get('requires_grad') and need_grad and api_name not in distribute_api_list: data.requires_grad_(True) temp_data = data * 1 data = temp_data.type_as(data) -- Gitee From bf91171f37f14e0d9d38c6e5fd4c1a35dee9e337 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 22 Jan 2025 11:30:21 +0800 Subject: [PATCH 07/24] bugfix --- .../api_accuracy_checker/run_ut/run_distributed_check.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 40092dee99f..f288c6eb89e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -83,10 +83,14 @@ def run_distributed_command(args): def run_distributed_check(forward_contents, real_data_paths, result_file_path): + white_list = [] for rank, forward_content in enumerate(forward_contents): logger.info("Start to check distributed api in rank {}.".format(rank)) for api_full_name, api_info_dict in forward_content.items(): + _, api_name = extract_basic_api_segments(api_full_name) + if api_name not in white_list: + continue if api_info_dict.get('used'): continue @@ -98,6 +102,7 @@ def run_distributed_check(forward_contents, real_data_paths, result_file_path): real_data_paths, group_ranks) try: distributed_check(api_full_name, all_args, all_kwargs, group_ranks, result_file_path) + import traceback; traceback.print_exc() except Exception as e: logger.error("The api {} in rank {} distributed check failed.".format(api_full_name, rank)) logger.error(e) -- Gitee From 54b7e35f046dcfee7a69604c994ec7aa099e1eed Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 23 Jan 2025 09:33:31 +0800 Subject: [PATCH 08/24] fix bug --- .../api_accuracy_checker/run_ut/run_distributed_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index f288c6eb89e..b9666c84535 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -140,7 +140,7 @@ def run_hccl(rank, distributed_config): rank_kwargs = distributed_config['all_kwargs'][rank] result_file_path = distributed_config['result_file_path'] device_args, device_kwargs = generate_device_params(rank_args, rank_kwargs, False, api_name) - logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, rank)) + logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, local_rank)) distributed_func.get(api_name)(*device_args, **device_kwargs) dist.barrier() status = 'pass' -- Gitee From 70f84e2a4bc1ed27516a5e82d4431f83f7b0a2ff Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 11 Feb 2025 16:33:26 +0800 Subject: [PATCH 09/24] add register --- .../run_ut/data_generate.py | 2 + .../run_ut/distributed_bench_function.py | 27 ++++++++++ .../run_ut/distributed_compare_function.py | 28 ++++++++++ .../run_ut/distributed_function_registry.py | 53 +++++++++++++++++++ .../run_ut/run_distributed_check.py | 4 +- 5 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index f3d54f62448..45fa8873b5c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -316,6 +316,8 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None): kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path) elif value is None: kwargs_params[key] = None + elif key == 'group' and value.get('type') == 'torch.ProcessGroup': + kwargs_params[key] = value elif key == 'atten_mask' and api_name == 'npu_fusion_attention': sparse_mode = kwargs_params.get('sparse_mode', {}) if isinstance(sparse_mode, dict): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py new file mode 100644 index 00000000000..0c43033125b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def mock_broadcast(input_args, input_kwargs): + src = input_args[0][1] + group = input_kwargs[0].get('group', None) + group_ranks = group.get('group_ranks') + real_src = src - min(group_ranks) + + return input_args[real_src][0] \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py new file mode 100644 index 00000000000..802fba45c86 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from msprobe.core.common.const import Const, CompareConst + + +def compare_broadcast(device_out, bench_out): + compare_result = torch.equal(device_out[0].cpu(), bench_out) + if not compare_result: + return CompareConst.ERROR + return CompareConst.PASS + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py new file mode 100644 index 00000000000..8b6581f0553 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +from debug.accuracy_tools.msprobe.pytorch.api_accuracy_checker.run_ut.distributed_bench_function import \ + mock_broadcast +from debug.accuracy_tools.msprobe.pytorch.api_accuracy_checker.run_ut.distributed_compare_function import \ + compare_broadcast +from msprobe.core.common.const import CompareConst + + +class DistributedFunctionRegistry: + def __init__(self): + self.compare_functions = {} + self.bench_functions = {} + + def register_compare_function(self, api_name: str, function: Callable): + self.compare_functions[api_name] = function + + def register_bench_function(self, api_name: str, function: Callable): + self.bench_functions[api_name] = function + + def get_compare_function(self, api_name: str) -> Callable: + if api_name in self.compare_functions: + return self.compare_functions[api_name] + else: + return compare_broadcast + + def get_bench_function(self, api_name: str) -> Callable: + if api_name in self.bench_functions: + return self.bench_functions[api_name] + else: + return mock_broadcast + + +distributed_func_registry = DistributedFunctionRegistry() +distributed_func_registry.register_bench_function('broadcast', mock_broadcast) +distributed_func_registry.register_compare_function('broadcast', compare_broadcast) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index b9666c84535..3ed953dc22b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -142,7 +142,7 @@ def run_hccl(rank, distributed_config): device_args, device_kwargs = generate_device_params(rank_args, rank_kwargs, False, api_name) logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, local_rank)) distributed_func.get(api_name)(*device_args, **device_kwargs) - dist.barrier() + # dist.barrier() status = 'pass' message = '' result_rows = [] @@ -188,4 +188,4 @@ def get_group_info(api_full_name, api_info_dict): if __name__ == '__main__': logger.info("Start to run distributed ut task.") _run_distributed() - logger.info("End to run distributed ut task.") \ No newline at end of file + logger.info("End to run distributed ut task.") -- Gitee From 0f5378365314d7119b983c33481b109c3048850d Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 11 Feb 2025 17:05:52 +0800 Subject: [PATCH 10/24] add csv header --- .../api_accuracy_checker/run_ut/run_distributed_check.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 3ed953dc22b..1aceb34f192 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -23,6 +23,7 @@ from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_ current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" +RESULT_CSV_HEADER = [['API_NAME', 'RANK', 'COMPARE_RESULT', 'MESSAGE']] def cleanup(): @@ -33,7 +34,7 @@ def distributed_setup(rank, world_size, master_ip, master_port): init_method = 'tcp://' + master_ip + ':' + master_port dist.init_process_group(backend='hccl', init_method=init_method, world_size=world_size, rank=rank) - + def parse_distributed_api(forward_content): distributed_api = {} @@ -79,6 +80,7 @@ def run_distributed_command(args): out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() result_file_path = os.path.join(out_path, RESULT_FILE_NAME) + write_csv(RESULT_CSV_HEADER, result_file_path) run_distributed_check(forward_contents, real_data_paths, result_file_path) @@ -150,7 +152,7 @@ def run_hccl(rank, distributed_config): result_rows.append(df_row) write_csv(result_rows, result_file_path) cleanup() - + def get_distributed_args_kwargs(forward_contents, api_full_name, real_data_paths, group_ranks): all_args, all_kwargs = [], [] -- Gitee From fda77972b580a25e7c2f32ab9cbbee11a919b23d Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 14 Feb 2025 09:49:03 +0800 Subject: [PATCH 11/24] add bench and compare func --- .../run_ut/distributed_compare_function.py | 2 +- .../api_accuracy_checker/run_ut/run_distributed_check.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py index 802fba45c86..f7f0fa581b1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py @@ -17,7 +17,7 @@ import torch -from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.const import CompareConst def compare_broadcast(device_out, bench_out): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 1aceb34f192..f349e245112 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -18,6 +18,7 @@ from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_api_info +from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_function_registry import distributed_func_registry from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments @@ -120,7 +121,8 @@ def distributed_check(api_full_name, all_args, all_kwargs, group_ranks, result_f distributed_config['all_args'] = all_args distributed_config['all_kwargs'] = all_kwargs distributed_config['result_file_path'] = result_file_path - + benchmark_function = distributed_func_registry.compare_functions(api_name) + distributed_config['benchmark_result'] = benchmark_function(all_args, all_kwargs) distributed_config['master_ip'] = '127.0.0.1' distributed_config['master_port'] = '2688' distributed_config['world_size'] = nprocs @@ -141,11 +143,13 @@ def run_hccl(rank, distributed_config): rank_args = distributed_config['all_args'][rank] rank_kwargs = distributed_config['all_kwargs'][rank] result_file_path = distributed_config['result_file_path'] + benchmark_result = distributed_config['benchmark_result'] device_args, device_kwargs = generate_device_params(rank_args, rank_kwargs, False, api_name) logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, local_rank)) distributed_func.get(api_name)(*device_args, **device_kwargs) # dist.barrier() - status = 'pass' + compare_function = distributed_func_registry.get_compare_function(api_name) + status = compare_function(device_args, benchmark_result) message = '' result_rows = [] df_row = list([api_full_name, local_rank, status, message]) -- Gitee From 87bd6515e95e280f9d40b47587aa6cd0ea107b64 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 20 Feb 2025 17:32:34 +0800 Subject: [PATCH 12/24] update const and config --- debug/accuracy_tools/msprobe/config.json | 4 +- .../msprobe/core/common/const.py | 16 ++++ .../api_accuracy_checker/common/config.py | 12 ++- .../pytorch/api_accuracy_checker/config.yaml | 2 + .../run_ut/run_distributed_check.py | 82 ++++++++++++------- .../msprobe/pytorch/pt_config.py | 14 ++++ 6 files changed, 99 insertions(+), 31 deletions(-) diff --git a/debug/accuracy_tools/msprobe/config.json b/debug/accuracy_tools/msprobe/config.json index 3b6c930fdd7..98c4c8699fd 100644 --- a/debug/accuracy_tools/msprobe/config.json +++ b/debug/accuracy_tools/msprobe/config.json @@ -24,7 +24,9 @@ "run_ut": { "white_list": [], "black_list": [], - "error_data_path": "./" + "error_data_path": "./", + "master_ip": "127.0.0.1", + "master_port": "8888" }, "grad_probe": { "grad_level": "L1", diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 5a165443be4..d1d66815ab7 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -27,6 +27,7 @@ class Const: ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$" SEP = "." + DOUBLE_SLASH = "//" REGEX_PREFIX_MAX_LENGTH = 20 REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$" FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' @@ -600,3 +601,18 @@ class MonitorConst: API = "api" OPS_START_INDEX = 3 HEADER_NAME_INDEX = 1 + + +class DistributedCheckConst: + API_FULL_NAME = "api_full_name" + API_NAME = "api_name" + GROUP_RANKS = "group_ranks" + ALL_ARGS = "all_args" + ALL_KWARGS = "all_kwargs" + RESULT_FILE_PATH = "result_file_path" + BENCHMARK_RESULT = "benchmark_result" + MASTER_IP = "master_ip" + MASTER_PORT = "master_port" + WORLD_SIZE = "world_size" + HCCL = "hccl" + TCP = "tcp" \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py index f2b2d6a3046..588a1eb349a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py @@ -52,7 +52,9 @@ class Config: 'host': str, 'port': int, 'rank_list': list, - 'tls_path': str + 'tls_path': str, + 'master_ip': str, + 'master_port': str } if key not in validators: raise ValueError(f"{key} must be one of {validators.keys()}") @@ -72,6 +74,10 @@ class Config: RunUTConfig.check_nfs_path_config(value) if key == 'tls_path': RunUTConfig.check_tls_path_config(value) + if key == 'master_ip': + RunUTConfig.check_master_ip_config(value) + if key == 'master_port': + RunUTConfig.check_master_port_config(value) return value @@ -91,6 +97,8 @@ class CheckerConfig: self.port = msCheckerConfig.port self.rank_list = msCheckerConfig.rank_list self.tls_path = msCheckerConfig.tls_path + self.master_ip = msCheckerConfig.master_ip + self.master_port = msCheckerConfig.master_port if task_config: self.load_config(task_config) @@ -105,6 +113,8 @@ class CheckerConfig: self.port = task_config.port self.rank_list = task_config.rank_list self.tls_path = task_config.tls_path + self.master_ip = task_config.master_ip + self.master_port = task_config.master_port def get_online_config(self): return OnlineConfig( diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml index 2ec9251009e..30cea3b8e01 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml @@ -8,3 +8,5 @@ host: "" port: -1 rank_list: [0] tls_path: "./" +master_ip: '127.0.0.1' +master_port: '2688' diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index f349e245112..77af666a1e6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -2,6 +2,7 @@ import os import sys import time import argparse +from collections import namedtuple import tqdm import torch @@ -11,20 +12,24 @@ import torch.multiprocessing as mp from msprobe.core.common.file_utils import FileChecker, write_csv, create_directory -from msprobe.core.common.const import Const, FileCheckConst, CompareConst +from msprobe.core.common.const import Const, FileCheckConst, DistributedCheckConst, MonitorConst from msprobe.core.compare.utils import check_and_return_dir_contents from msprobe.pytorch.hook_module.wrap_distributed import distributed_func +from msprobe.pytorch.pt_config import parse_json_config from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_api_info from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_function_registry import distributed_func_registry from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments +from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" RESULT_CSV_HEADER = [['API_NAME', 'RANK', 'COMPARE_RESULT', 'MESSAGE']] +Distributed_Check_Params = namedtuple("Distributed_Check_Params", ["api_full_name", "all_args", "all_kwargs", + "group_ranks", "result_file_path", "checker_config"]) def cleanup(): @@ -32,8 +37,8 @@ def cleanup(): def distributed_setup(rank, world_size, master_ip, master_port): - init_method = 'tcp://' + master_ip + ':' + master_port - dist.init_process_group(backend='hccl', init_method=init_method, + init_method = DistributedCheckConst.TCP + MonitorConst.VPP_SEP + Const.DOUBLE_SLASH + master_ip + MonitorConst.VPP_SEP + master_port + dist.init_process_group(backend=DistributedCheckConst.HCCL, init_method=init_method, world_size=world_size, rank=rank) @@ -53,6 +58,8 @@ def _run_distributed_parser(parser): parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The ut task result out path.", required=False) + parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str, + help=" The path of config.json", required=False) def _run_distributed(parser=None): @@ -82,10 +89,18 @@ def run_distributed_command(args): out_path = out_path_checker.common_check() result_file_path = os.path.join(out_path, RESULT_FILE_NAME) write_csv(RESULT_CSV_HEADER, result_file_path) - run_distributed_check(forward_contents, real_data_paths, result_file_path) - - -def run_distributed_check(forward_contents, real_data_paths, result_file_path): + if args.config_path: + config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE, + FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX) + checked_config_path = config_path_checker.common_check() + _, task_config = parse_json_config(checked_config_path, Const.RUN_UT) + checker_config = CheckerConfig(task_config) + else: + checker_config = CheckerConfig() + run_distributed_check(forward_contents, real_data_paths, result_file_path, checker_config) + + +def run_distributed_check(forward_contents, real_data_paths, result_file_path, checker_config): white_list = [] for rank, forward_content in enumerate(forward_contents): logger.info("Start to check distributed api in rank {}.".format(rank)) @@ -104,46 +119,55 @@ def run_distributed_check(forward_contents, real_data_paths, result_file_path): all_args, all_kwargs = get_distributed_args_kwargs(forward_contents, api_full_name, real_data_paths, group_ranks) try: - distributed_check(api_full_name, all_args, all_kwargs, group_ranks, result_file_path) + distributed_check_params = Distributed_Check_Params(api_full_name, all_args, all_kwargs, group_ranks, + result_file_path, checker_config) + distributed_check(distributed_check_params) import traceback; traceback.print_exc() except Exception as e: logger.error("The api {} in rank {} distributed check failed.".format(api_full_name, rank)) logger.error(e) -def distributed_check(api_full_name, all_args, all_kwargs, group_ranks, result_file_path): +def distributed_check(distributed_check_params): + api_full_name = distributed_check_params.api_full_name + all_args = distributed_check_params.all_args + all_kwargs = distributed_check_params.all_kwargs + group_ranks = distributed_check_params.group_ranks + result_file_path = distributed_check_params.result_file_path + checker_config = distributed_check_params.checker_config + _, api_name = extract_basic_api_segments(api_full_name) nprocs = len(group_ranks) distributed_config = {} - distributed_config['api_full_name'] = api_full_name - distributed_config['api_name'] = api_name - distributed_config['group_ranks'] = group_ranks - distributed_config['all_args'] = all_args - distributed_config['all_kwargs'] = all_kwargs - distributed_config['result_file_path'] = result_file_path + distributed_config[DistributedCheckConst.API_FULL_NAME] = api_full_name + distributed_config[DistributedCheckConst.API_NAME] = api_name + distributed_config[DistributedCheckConst.GROUP_RANKS] = group_ranks + distributed_config[DistributedCheckConst.ALL_ARGS] = all_args + distributed_config[DistributedCheckConst.ALL_KWARGS] = all_kwargs + distributed_config[DistributedCheckConst.RESULT_FILE_PATH] = result_file_path benchmark_function = distributed_func_registry.compare_functions(api_name) - distributed_config['benchmark_result'] = benchmark_function(all_args, all_kwargs) - distributed_config['master_ip'] = '127.0.0.1' - distributed_config['master_port'] = '2688' - distributed_config['world_size'] = nprocs + distributed_config[DistributedCheckConst.BENCHMARK_RESULT] = benchmark_function(all_args, all_kwargs) + distributed_config[DistributedCheckConst.MASTER_IP] = checker_config.master_ip + distributed_config[DistributedCheckConst.MASTER_PORT] = checker_config.master_port + distributed_config[DistributedCheckConst.WORLD_SIZE] = nprocs mp.spawn(run_hccl, args=(distributed_config,), nprocs=nprocs) def run_hccl(rank, distributed_config): - local_rank = distributed_config['group_ranks'][rank] + local_rank = distributed_config[DistributedCheckConst.GROUP_RANKS][rank] torch_npu.npu.set_device(local_rank) - world_size = distributed_config['world_size'] - master_ip = distributed_config['master_ip'] - master_port = distributed_config['master_port'] + world_size = distributed_config[DistributedCheckConst.WORLD_SIZE] + master_ip = distributed_config[DistributedCheckConst.MASTER_IP] + master_port = distributed_config[DistributedCheckConst.MASTER_PORT] distributed_setup(rank, world_size, master_ip, master_port) - api_full_name = distributed_config['api_full_name'] - api_name = distributed_config['api_name'] - rank_args = distributed_config['all_args'][rank] - rank_kwargs = distributed_config['all_kwargs'][rank] - result_file_path = distributed_config['result_file_path'] - benchmark_result = distributed_config['benchmark_result'] + api_full_name = distributed_config[DistributedCheckConst.API_FULL_NAME] + api_name = distributed_config[DistributedCheckConst.API_NAME] + rank_args = distributed_config[DistributedCheckConst.ALL_ARGS][rank] + rank_kwargs = distributed_config[DistributedCheckConst.ALL_KWARGS][rank] + result_file_path = distributed_config[DistributedCheckConst.RESULT_FILE_PATH] + benchmark_result = distributed_config[DistributedCheckConst.BENCHMARK_RESULT] device_args, device_kwargs = generate_device_params(rank_args, rank_kwargs, False, api_name) logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, local_rank)) distributed_func.get(api_name)(*device_args, **device_kwargs) diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index 01cff973dfb..f718efcaa7f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -252,6 +252,8 @@ class RunUTConfig(BaseConfig): self.port = json_config.get("port", -1) self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST) self.tls_path = json_config.get("tls_path", "./") + self.master_ip = json_config.get("master_ip", "127.0.0.1") + self.master_port = json_config.get("master_port", "8888") self.check_run_ut_config() @classmethod @@ -278,6 +280,16 @@ class RunUTConfig(BaseConfig): def check_tls_path_config(cls, tls_path): if tls_path and not os.path.exists(tls_path): raise Exception("tls_path: %s does not exist" % tls_path) + + @classmethod + def check_master_ip_config(cls, master_ip): + if not re.match(Const.ipv4_pattern, master_ip): + raise Exception("master_ip: %s is invalid" % master_ip) + + @classmethod + def check_master_port_config(cls, master_port): + if not is_int(master_port) or not (0 < int(master_port) <= 65535): + raise Exception("master_port: %s is invalid, port range 0-65535" % master_port) def check_run_ut_config(self): RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list) @@ -285,6 +297,8 @@ class RunUTConfig(BaseConfig): RunUTConfig.check_error_data_path_config(self.error_data_path) RunUTConfig.check_nfs_path_config(self.nfs_path) RunUTConfig.check_tls_path_config(self.tls_path) + RunUTConfig.check_master_ip_config(self.master_ip) + RunUTConfig.check_master_port_config(self.master_port) class GradToolConfig(BaseConfig): -- Gitee From 68863f6b75db50ad42b7f353de7dce25bd673b1a Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 21 Feb 2025 17:44:20 +0800 Subject: [PATCH 13/24] add const --- .../msprobe/core/common/const.py | 9 ++++- .../run_ut/distributed_bench_function.py | 11 ++++-- .../run_ut/distributed_function_registry.py | 7 ++-- .../run_ut/run_distributed_check.py | 34 +++++----------- .../run_ut/run_ut_utils.py | 39 ++++++++++++++++++- 5 files changed, 66 insertions(+), 34 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index d1d66815ab7..ab7636068c4 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -606,6 +606,7 @@ class MonitorConst: class DistributedCheckConst: API_FULL_NAME = "api_full_name" API_NAME = "api_name" + GROUP = "group" GROUP_RANKS = "group_ranks" ALL_ARGS = "all_args" ALL_KWARGS = "all_kwargs" @@ -615,4 +616,10 @@ class DistributedCheckConst: MASTER_PORT = "master_port" WORLD_SIZE = "world_size" HCCL = "hccl" - TCP = "tcp" \ No newline at end of file + TCP = "tcp" + BROADCAST = "broadcast" + BROADCAST_SRC_INDEX = 1 + FIRST_TENSOR_INDEX = 0 + GROUP_POSITION = { + "broadcast": 2 + } diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py index 0c43033125b..3d153e2f8fd 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py @@ -15,13 +15,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch +from msprobe.core.common.const import DistributedCheckConst def mock_broadcast(input_args, input_kwargs): + src = input_args[0][1] - group = input_kwargs[0].get('group', None) - group_ranks = group.get('group_ranks') + group = input_kwargs[0].get(DistributedCheckConst.GROUP) + group_ranks = group.get(DistributedCheckConst.GROUP_RANKS) real_src = src - min(group_ranks) - return input_args[real_src][0] \ No newline at end of file + return input_args[real_src][0] + + diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py index 8b6581f0553..8deb3309fde 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py @@ -21,7 +21,7 @@ from debug.accuracy_tools.msprobe.pytorch.api_accuracy_checker.run_ut.distribute mock_broadcast from debug.accuracy_tools.msprobe.pytorch.api_accuracy_checker.run_ut.distributed_compare_function import \ compare_broadcast -from msprobe.core.common.const import CompareConst +from msprobe.core.common.const import DistributedCheckConst class DistributedFunctionRegistry: @@ -49,5 +49,6 @@ class DistributedFunctionRegistry: distributed_func_registry = DistributedFunctionRegistry() -distributed_func_registry.register_bench_function('broadcast', mock_broadcast) -distributed_func_registry.register_compare_function('broadcast', compare_broadcast) +distributed_func_registry.register_bench_function(DistributedCheckConst.BROADCAST, mock_broadcast) +distributed_func_registry.register_compare_function(DistributedCheckConst.BROADCAST, compare_broadcast) + diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 77af666a1e6..02cb513732f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -18,7 +18,7 @@ from msprobe.pytorch.hook_module.wrap_distributed import distributed_func from msprobe.pytorch.pt_config import parse_json_config from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params, get_group_info from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_api_info from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_function_registry import distributed_func_registry from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments @@ -68,8 +68,8 @@ def _run_distributed(parser=None): _run_distributed_parser(parser) args = parser.parse_args(sys.argv[1:]) run_distributed_command(args) - - + + def run_distributed_command(args): input_checker = FileChecker(args.api_info_dir, FileCheckConst.DIR, ability=FileCheckConst.READ_ABLE) api_info_dir = input_checker.common_check() @@ -112,7 +112,7 @@ def run_distributed_check(forward_contents, real_data_paths, result_file_path, c if api_info_dict.get('used'): continue - group_ranks, group_id = get_group_info(api_full_name, api_info_dict) + group_ranks, group_id = get_group_info(api_full_name, api_name, api_info_dict) if not group_ranks or not group_id: logger.warning("The api {} doesn't support distributed check.".format(api_full_name)) continue @@ -122,8 +122,8 @@ def run_distributed_check(forward_contents, real_data_paths, result_file_path, c distributed_check_params = Distributed_Check_Params(api_full_name, all_args, all_kwargs, group_ranks, result_file_path, checker_config) distributed_check(distributed_check_params) - import traceback; traceback.print_exc() except Exception as e: + import traceback; traceback.print_exc() logger.error("The api {} in rank {} distributed check failed.".format(api_full_name, rank)) logger.error(e) @@ -145,7 +145,7 @@ def distributed_check(distributed_check_params): distributed_config[DistributedCheckConst.ALL_ARGS] = all_args distributed_config[DistributedCheckConst.ALL_KWARGS] = all_kwargs distributed_config[DistributedCheckConst.RESULT_FILE_PATH] = result_file_path - benchmark_function = distributed_func_registry.compare_functions(api_name) + benchmark_function = distributed_func_registry.get_bench_function(api_name) distributed_config[DistributedCheckConst.BENCHMARK_RESULT] = benchmark_function(all_args, all_kwargs) distributed_config[DistributedCheckConst.MASTER_IP] = checker_config.master_ip distributed_config[DistributedCheckConst.MASTER_PORT] = checker_config.master_port @@ -168,10 +168,10 @@ def run_hccl(rank, distributed_config): rank_kwargs = distributed_config[DistributedCheckConst.ALL_KWARGS][rank] result_file_path = distributed_config[DistributedCheckConst.RESULT_FILE_PATH] benchmark_result = distributed_config[DistributedCheckConst.BENCHMARK_RESULT] - device_args, device_kwargs = generate_device_params(rank_args, rank_kwargs, False, api_name) + device_args, _ = generate_device_params(rank_args, rank_kwargs, False, api_name) logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, local_rank)) - distributed_func.get(api_name)(*device_args, **device_kwargs) - # dist.barrier() + distributed_func.get(api_name)(*device_args) + compare_function = distributed_func_registry.get_compare_function(api_name) status = compare_function(device_args, benchmark_result) message = '' @@ -199,22 +199,6 @@ def get_distributed_args_kwargs(forward_contents, api_full_name, real_data_paths return all_args, all_kwargs -def get_group_info(api_full_name, api_info_dict): - group = api_info_dict.get('input_kwargs', {}).get('group') - if not group: - logger.warning("The api {} doesn't have group info.".format(api_full_name)) - return None, None - group_ranks = group.get('group_ranks') - if not group_ranks: - logger.warning("The group of api {} doesn't have group_ranks info.".format(api_full_name)) - return None, None - group_id = group.get('group_id') - if not group_id: - logger.warning("The group of api {} doesn't have group_id info.".format(api_full_name)) - return None, None - return group_ranks, group_id - - if __name__ == '__main__': logger.info("Start to run distributed ut task.") _run_distributed() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index dc0174212e3..2099796479f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -29,7 +29,7 @@ else: current_device = "npu" from torch_npu.npu.amp import autocast -from msprobe.core.common.const import FileCheckConst, Const, CompareConst +from msprobe.core.common.const import FileCheckConst, Const, CompareConst, DistributedCheckConst from msprobe.core.common.file_utils import FileChecker from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException @@ -252,3 +252,40 @@ def is_unsupported_api(api_name, is_overflow_check=False): if flag: logger.info(f"{split_name} api is not supported for run ut. SKIP.") return flag + + +def get_group_position(api_name): + return DistributedCheckConst.GROUP_POSITION.get(api_name, None) + + +def get_group(api_name, input_args, input_kwargs): + group = None + group = input_kwargs.get(DistributedCheckConst.GROUP) + if group: + return group + group_position = get_group_position(api_name) + if not group_position or len(input_args) <= group_position: + return None + group = input_args[group_position] + if not isinstance(group, dict): + return None + return group + + +def get_group_info(api_full_name, api_name, api_info_dict): + input_args = api_info_dict.get('input_args', {}) + input_kwargs = api_info_dict.get('input_kwargs', {}) + group = get_group(api_name, input_args, input_kwargs) + + if not group: + logger.warning("The api {} doesn't have group info.".format(api_full_name)) + return None, None + group_ranks = group.get('group_ranks') + if not group_ranks: + logger.warning("The group of api {} doesn't have group_ranks info.".format(api_full_name)) + return None, None + group_id = group.get('group_id') + if not group_id: + logger.warning("The group of api {} doesn't have group_id info.".format(api_full_name)) + return None, None + return group_ranks, group_id -- Gitee From bb428509c9f72dde72aae54bc2cb353b9223b177 Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 22 Feb 2025 10:51:45 +0800 Subject: [PATCH 14/24] add const --- .../msprobe/core/common/const.py | 1 + .../run_ut/data_generate.py | 22 +++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index ab7636068c4..b7484485ec7 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -608,6 +608,7 @@ class DistributedCheckConst: API_NAME = "api_name" GROUP = "group" GROUP_RANKS = "group_ranks" + TORCH_PROCESS_GROUP = "torch.ProcessGroup" ALL_ARGS = "all_args" ALL_KWARGS = "all_kwargs" RESULT_FILE_PATH = "result_file_path" diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 45fa8873b5c..5075fee6056 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -26,7 +26,7 @@ from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, from msprobe.core.common.file_utils import FileChecker, load_npy from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import load_pt -from msprobe.core.common.const import Const, FileCheckConst, CompareConst +from msprobe.core.common.const import Const, FileCheckConst, CompareConst, DistributedCheckConst from msprobe.pytorch.hook_module.wrap_distributed import get_distributed_ops @@ -261,17 +261,19 @@ def gen_bool_tensor(low, high, shape): return data -def gen_args(args_info, api_name, func_options): +def gen_args(api_info, api_name, func_options): """ Function Description: Based on API basic information, generate input parameters: args, for API forward running Parameter: - api_info: API basic information. List + api_info: API basic information. DICT api_name: API name need_grad: set Tensor grad for backward convert_type: convert ori_type to dist_type flag. real_data_path: the root directory for storing real data. """ + check_object_type(api_info, dict) + args_info = api_info.get("input_args") check_object_type(args_info, list) args_result = [] @@ -289,7 +291,12 @@ def gen_args(args_info, api_name, func_options): func_options['depth'] = depth + 1 data = gen_args(arg, api_name, func_options) elif isinstance(arg, dict): - data = gen_data(arg, api_name, need_grad, convert_type, real_data_path) + if arg.get('type') == DistributedCheckConst.TORCH_PROCESS_GROUP: + data = None + kwargs_params = api_info.get(Const.INPUT_KWARGS) + kwargs_params[DistributedCheckConst.GROUP] = arg + else: + data = gen_data(arg, api_name, need_grad, convert_type, real_data_path) elif arg is None: data = None else: @@ -316,7 +323,7 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None): kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path) elif value is None: kwargs_params[key] = None - elif key == 'group' and value.get('type') == 'torch.ProcessGroup': + elif key == DistributedCheckConst.GROUP and value.get('type') == DistributedCheckConst.TORCH_PROCESS_GROUP: kwargs_params[key] = value elif key == 'atten_mask' and api_name == 'npu_fusion_attention': sparse_mode = kwargs_params.get('sparse_mode', {}) @@ -421,7 +428,7 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d if convert_type and convert_type not in Const.CONVERT: error_info = f"convert_type params not support {convert_type}." raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) - kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path) + func_options = { 'need_grad': need_grad, 'convert_type': convert_type, @@ -429,9 +436,10 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d 'depth': 0 } if api_info.get("input_args"): - args_params = gen_args(api_info.get("input_args"), api_name, func_options) + args_params = gen_args(api_info, api_name, func_options) else: logger.warning(f'Warning: No args in {api_info} ') args_params = [] + kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path) output_dtype = get_output_dtype(api_info) return args_params, kwargs_params, output_dtype -- Gitee From ab10ea06961e9c55f0cc18922954a2761bfc7798 Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 22 Feb 2025 10:59:37 +0800 Subject: [PATCH 15/24] add colon const --- debug/accuracy_tools/msprobe/core/common/const.py | 1 + .../api_accuracy_checker/run_ut/run_distributed_check.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index b7484485ec7..30b8173c7f2 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -27,6 +27,7 @@ class Const: ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$" SEP = "." + COLON = ":" DOUBLE_SLASH = "//" REGEX_PREFIX_MAX_LENGTH = 20 REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$" diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 02cb513732f..02f216f3482 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -37,7 +37,7 @@ def cleanup(): def distributed_setup(rank, world_size, master_ip, master_port): - init_method = DistributedCheckConst.TCP + MonitorConst.VPP_SEP + Const.DOUBLE_SLASH + master_ip + MonitorConst.VPP_SEP + master_port + init_method = DistributedCheckConst.TCP + Const.COLON + Const.DOUBLE_SLASH + master_ip + Const.COLON + master_port dist.init_process_group(backend=DistributedCheckConst.HCCL, init_method=init_method, world_size=world_size, rank=rank) -- Gitee From 147f9ba457a20329b954f683a639d4d0a591911c Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 22 Feb 2025 11:08:14 +0800 Subject: [PATCH 16/24] fix check bug --- debug/accuracy_tools/msprobe/pytorch/pt_config.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index ce147f2bae3..d6dc6e9fcd8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -70,8 +70,11 @@ class TensorConfig(BaseConfig): if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host): raise Exception(f"host: {self.host} is invalid.") - if not isinstance(self.port, int) or not (0 < self.port <= 65535): - raise Exception(f"port: {self.port} is invalid, port range 0-65535.") + if not isinstance(self.port, str) or not self.port.isdigit(): + raise Exception(f"port: {self.port} is invalid. Port must be a numeric string.") + port_number = int(self.port) + if not (0 < port_number <= 65535): + raise Exception(f"port: {self.port} is invalid. Port range must be between 1 and 65535.") class StatisticsConfig(BaseConfig): -- Gitee From 2189726c8368336ac21c180bf1a28c85eb6dad33 Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 22 Feb 2025 11:28:13 +0800 Subject: [PATCH 17/24] fix check bug --- debug/accuracy_tools/msprobe/pytorch/pt_config.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index d6dc6e9fcd8..eb0df5eec52 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -291,8 +291,11 @@ class RunUTConfig(BaseConfig): @classmethod def check_master_port_config(cls, master_port): - if not is_int(master_port) or not (0 < int(master_port) <= 65535): - raise Exception("master_port: %s is invalid, port range 0-65535" % master_port) + if not isinstance(master_port, str) or not master_port.isdigit(): + raise Exception(f"port: {master_port} is invalid. Port must be a numeric string.") + port_number = int(master_port) + if not (0 < port_number <= 65535): + raise Exception(f"port: {master_port} is invalid. Port range must be between 1 and 65535.") def check_run_ut_config(self): RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list) -- Gitee From 5cc905e4476778855d4ac0a7916b4018489ea439 Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 22 Feb 2025 14:39:55 +0800 Subject: [PATCH 18/24] add group to kwargs --- .../run_ut/data_generate.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 53c9ca68d05..682c8c510b5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -261,19 +261,20 @@ def gen_bool_tensor(low, high, shape): return data -def gen_args(api_info, api_name, func_options): +def gen_args(args_info, api_name, func_options): """ Function Description: Based on API basic information, generate input parameters: args, for API forward running Parameter: - api_info: API basic information. DICT + args_info: API basic information. DICT api_name: API name - need_grad: set Tensor grad for backward - convert_type: convert ori_type to dist_type flag. - real_data_path: the root directory for storing real data. + func_options: the options for generating args. Dict + need_grad: set Tensor grad for backward + convert_type: convert ori_type to dist_type flag. + real_data_path: the root directory for storing real data. + depth: the depth of recursion. + kwargs_params: the input kwargs parameters. """ - check_object_type(api_info, dict) - args_info = api_info.get("input_args") check_object_type(args_info, list) args_result = [] @@ -281,6 +282,7 @@ def gen_args(api_info, api_name, func_options): convert_type = func_options.get('convert_type', None) real_data_path = func_options.get('real_data_path', None) depth = func_options.get('depth', 0) + kwargs_params = func_options.get('input_kwargs', {}) if depth > Const.MAX_DEPTH: logger.error("The depth of args is too large, please check the input args.") @@ -289,11 +291,11 @@ def gen_args(api_info, api_name, func_options): for arg in args_info: if isinstance(arg, (list, tuple)): func_options['depth'] = depth + 1 + #zhelilyouwenti data = gen_args(arg, api_name, func_options) elif isinstance(arg, dict): if arg.get('type') == DistributedCheckConst.TORCH_PROCESS_GROUP: data = None - kwargs_params = api_info.get(Const.INPUT_KWARGS) kwargs_params[DistributedCheckConst.GROUP] = arg else: data = gen_data(arg, api_name, need_grad, convert_type, real_data_path) @@ -434,10 +436,11 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d 'need_grad': need_grad, 'convert_type': convert_type, 'real_data_path': real_data_path, - 'depth': 0 + 'depth': 0, + 'input_kwargs': api_info.get("input_kwargs", {}) } if api_info.get("input_args"): - args_params = gen_args(api_info, api_name, func_options) + args_params = gen_args(api_info.get("input_args"), api_name, func_options) else: logger.warning(f'Warning: No args in {api_info} ') args_params = [] -- Gitee From 728b24f3084c52d27087ff9b923a138d82948ccb Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 22 Feb 2025 16:40:05 +0800 Subject: [PATCH 19/24] add check --- .../msprobe/core/common/const.py | 12 +++++ .../run_ut/distributed_bench_function.py | 17 ++++-- .../run_ut/distributed_compare_function.py | 2 + .../run_ut/distributed_function_registry.py | 4 +- .../run_ut/run_distributed_check.py | 48 +++++++++++++---- .../run_ut/run_ut_utils.py | 52 +++++++++++++++++-- 6 files changed, 115 insertions(+), 20 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 6429d812560..6c09e5262c5 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -683,6 +683,9 @@ class DistributedCheckConst: API_NAME = "api_name" GROUP = "group" GROUP_RANKS = "group_ranks" + GROUP_INDEX = "group_index" + SRC = "src" + SRC_INDEX = "src_index" TORCH_PROCESS_GROUP = "torch.ProcessGroup" ALL_ARGS = "all_args" ALL_KWARGS = "all_kwargs" @@ -699,3 +702,12 @@ class DistributedCheckConst: GROUP_POSITION = { "broadcast": 2 } + SRC_POSITION = { + "broadcast": 1 + } + API_ARGS_INDEX = { + "broadcast": { + "group_index": 2, + "src_index": 1 + } +} diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py index 3d153e2f8fd..f5996a4f663 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py @@ -16,13 +16,22 @@ # limitations under the License. from msprobe.core.common.const import DistributedCheckConst +from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_group, get_src -def mock_broadcast(input_args, input_kwargs): +def mock_broadcast(api_name, input_args, input_kwargs): + check_object_type(input_args, list) + check_object_type(input_kwargs, list) + if len(input_args) < 1 or len(input_kwargs) < 1: + raise ValueError("input_args and input_kwargs should have at least 1 element") - src = input_args[0][1] - group = input_kwargs[0].get(DistributedCheckConst.GROUP) - group_ranks = group.get(DistributedCheckConst.GROUP_RANKS) + src = get_src(api_name, input_args[0], input_kwargs[0]) + + group = get_group(api_name, input_args[0], input_kwargs[0]) + group_ranks = group.get(DistributedCheckConst.GROUP_RANKS, []) + if not group_ranks: + raise ValueError("group_ranks should not be empty") real_src = src - min(group_ranks) return input_args[real_src][0] diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py index f7f0fa581b1..8f6d2f6fd47 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py @@ -21,6 +21,8 @@ from msprobe.core.common.const import CompareConst def compare_broadcast(device_out, bench_out): + if len(device_out) < 1: + raise ValueError("device_out should not be empty") compare_result = torch.equal(device_out[0].cpu(), bench_out) if not compare_result: return CompareConst.ERROR diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py index 8deb3309fde..d6529a2b83a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py @@ -17,9 +17,9 @@ from typing import Callable -from debug.accuracy_tools.msprobe.pytorch.api_accuracy_checker.run_ut.distributed_bench_function import \ +from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_bench_function import \ mock_broadcast -from debug.accuracy_tools.msprobe.pytorch.api_accuracy_checker.run_ut.distributed_compare_function import \ +from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_compare_function import \ compare_broadcast from msprobe.core.common.const import DistributedCheckConst diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 02f216f3482..ffdccd0e190 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -1,3 +1,20 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import sys import time @@ -12,13 +29,14 @@ import torch.multiprocessing as mp from msprobe.core.common.file_utils import FileChecker, write_csv, create_directory -from msprobe.core.common.const import Const, FileCheckConst, DistributedCheckConst, MonitorConst +from msprobe.core.common.const import Const, FileCheckConst, DistributedCheckConst, CompareConst from msprobe.core.compare.utils import check_and_return_dir_contents from msprobe.pytorch.hook_module.wrap_distributed import distributed_func from msprobe.pytorch.pt_config import parse_json_config from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params, get_group_info +from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params, get_group_info, \ + is_port_in_use from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_api_info from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_function_registry import distributed_func_registry from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments @@ -28,7 +46,7 @@ from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" RESULT_CSV_HEADER = [['API_NAME', 'RANK', 'COMPARE_RESULT', 'MESSAGE']] -Distributed_Check_Params = namedtuple("Distributed_Check_Params", ["api_full_name", "all_args", "all_kwargs", +DistributedCheckParams = namedtuple("DistributedCheckParams", ["api_full_name", "all_args", "all_kwargs", "group_ranks", "result_file_path", "checker_config"]) @@ -119,13 +137,17 @@ def run_distributed_check(forward_contents, real_data_paths, result_file_path, c all_args, all_kwargs = get_distributed_args_kwargs(forward_contents, api_full_name, real_data_paths, group_ranks) try: - distributed_check_params = Distributed_Check_Params(api_full_name, all_args, all_kwargs, group_ranks, + distributed_check_params = DistributedCheckParams(api_full_name, all_args, all_kwargs, group_ranks, result_file_path, checker_config) distributed_check(distributed_check_params) except Exception as e: - import traceback; traceback.print_exc() + import traceback + traceback.print_exc() logger.error("The api {} in rank {} distributed check failed.".format(api_full_name, rank)) - logger.error(e) + result_rows = [] + df_row = list([api_full_name, rank, CompareConst.ERROR, str(e)]) + result_rows.append(df_row) + write_csv(result_rows, result_file_path) def distributed_check(distributed_check_params): @@ -135,7 +157,7 @@ def distributed_check(distributed_check_params): group_ranks = distributed_check_params.group_ranks result_file_path = distributed_check_params.result_file_path checker_config = distributed_check_params.checker_config - + _, api_name = extract_basic_api_segments(api_full_name) nprocs = len(group_ranks) distributed_config = {} @@ -146,15 +168,23 @@ def distributed_check(distributed_check_params): distributed_config[DistributedCheckConst.ALL_KWARGS] = all_kwargs distributed_config[DistributedCheckConst.RESULT_FILE_PATH] = result_file_path benchmark_function = distributed_func_registry.get_bench_function(api_name) - distributed_config[DistributedCheckConst.BENCHMARK_RESULT] = benchmark_function(all_args, all_kwargs) + distributed_config[DistributedCheckConst.BENCHMARK_RESULT] = benchmark_function(api_name, all_args, all_kwargs) distributed_config[DistributedCheckConst.MASTER_IP] = checker_config.master_ip distributed_config[DistributedCheckConst.MASTER_PORT] = checker_config.master_port distributed_config[DistributedCheckConst.WORLD_SIZE] = nprocs + if is_port_in_use(checker_config.master_port, checker_config.master_ip): + raise ValueError( + f"Warning: Port {checker_config.master_port} on host " + f"{checker_config.master_ip} is already in use." + ) + logger.info(f"Port {checker_config.master_port} on host {checker_config.master_ip} is available.") + mp.spawn(run_hccl, args=(distributed_config,), nprocs=nprocs) - + + def run_hccl(rank, distributed_config): local_rank = distributed_config[DistributedCheckConst.GROUP_RANKS][rank] torch_npu.npu.set_device(local_rank) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py index 2099796479f..a8570f98d19 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py @@ -16,6 +16,7 @@ # limitations under the License. import os +import socket from collections import namedtuple import re import torch @@ -254,8 +255,17 @@ def is_unsupported_api(api_name, is_overflow_check=False): return flag -def get_group_position(api_name): - return DistributedCheckConst.GROUP_POSITION.get(api_name, None) +def get_args_index(api_name, args_name): + """ + 根据 API 名字和参数名获取参数索引。获取 group_index 或者 src_index。 + :param api_name: API 名字,如 "broadcast" 或 "all_reduce" + :param args_name: 参数名,如 "group" 或 "src" + :return: 参数索引 或 None(如果 API 名字或参数名不存在) + """ + api_info = DistributedCheckConst.API_ARGS_INDEX.get(api_name) + if api_info: + return api_info.get(args_name) + return None def get_group(api_name, input_args, input_kwargs): @@ -263,10 +273,10 @@ def get_group(api_name, input_args, input_kwargs): group = input_kwargs.get(DistributedCheckConst.GROUP) if group: return group - group_position = get_group_position(api_name) - if not group_position or len(input_args) <= group_position: + group_index = get_args_index(api_name, DistributedCheckConst.GROUP_INDEX) + if not group_index or len(input_args) <= group_index: return None - group = input_args[group_position] + group = input_args[group_index] if not isinstance(group, dict): return None return group @@ -289,3 +299,35 @@ def get_group_info(api_full_name, api_name, api_info_dict): logger.warning("The group of api {} doesn't have group_id info.".format(api_full_name)) return None, None return group_ranks, group_id + + +def get_src(api_name, input_args, input_kwargs): + src = None + src = input_kwargs.get(DistributedCheckConst.SRC) + if isinstance(src, int): + return src + src_index = get_args_index(api_name, DistributedCheckConst.SRC_INDEX) + if not src_index or len(input_args) <= src_index: + return None + src = input_args[src_index] + if not isinstance(src, int): + return None + return src + + +def is_port_in_use(port, host): + """ + 检测指定端口是否被占用。 + :param port: 要检测的端口号 + :param host: 主机地址 + :return: 如果端口被占用返回 True,否则返回 False + """ + if not isinstance(port, str) or not port.isdigit(): + raise Exception(f"port: {port} is invalid. Port must be a numeric string.") + port = int(port) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind((host, port)) + return False # 端口未被占用 + except socket.error: + return True # 端口已被占用 -- Gitee From 5b89a7f8ca8984b58bd7b6c128b10dc368b42f07 Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 22 Feb 2025 17:04:14 +0800 Subject: [PATCH 20/24] fix bug --- .../pytorch/api_accuracy_checker/run_ut/data_generate.py | 1 - .../api_accuracy_checker/run_ut/run_distributed_check.py | 4 +--- debug/accuracy_tools/msprobe/pytorch/pt_config.py | 6 ------ .../pytorch_ut/api_accuracy_checker/common/test_config.py | 2 ++ 4 files changed, 3 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index 682c8c510b5..53130b6d9c3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -291,7 +291,6 @@ def gen_args(args_info, api_name, func_options): for arg in args_info: if isinstance(arg, (list, tuple)): func_options['depth'] = depth + 1 - #zhelilyouwenti data = gen_args(arg, api_name, func_options) elif isinstance(arg, dict): if arg.get('type') == DistributedCheckConst.TORCH_PROCESS_GROUP: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index ffdccd0e190..06cc0c81610 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -119,14 +119,12 @@ def run_distributed_command(args): def run_distributed_check(forward_contents, real_data_paths, result_file_path, checker_config): - white_list = [] for rank, forward_content in enumerate(forward_contents): logger.info("Start to check distributed api in rank {}.".format(rank)) for api_full_name, api_info_dict in forward_content.items(): _, api_name = extract_basic_api_segments(api_full_name) - if api_name not in white_list: - continue + if api_info_dict.get('used'): continue diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index eb0df5eec52..879c0810492 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -70,12 +70,6 @@ class TensorConfig(BaseConfig): if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host): raise Exception(f"host: {self.host} is invalid.") - if not isinstance(self.port, str) or not self.port.isdigit(): - raise Exception(f"port: {self.port} is invalid. Port must be a numeric string.") - port_number = int(self.port) - if not (0 < port_number <= 65535): - raise Exception(f"port: {self.port} is invalid. Port range must be between 1 and 65535.") - class StatisticsConfig(BaseConfig): def __init__(self, json_config): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py index df03485dc6c..30fa11d94de 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py @@ -16,6 +16,8 @@ class TestUtConfig(): self.port = 8080 self.rank_list = [0, 1, 2] self.tls_path = '/path/to/tls' + self.master_ip = '127.0.0.1' + self.master_port = 8888 class TestConfig(unittest.TestCase): -- Gitee From fe32dcab5fb642d57a9ce3d40f83d777e3b62bc3 Mon Sep 17 00:00:00 2001 From: gitee Date: Sat, 22 Feb 2025 17:20:08 +0800 Subject: [PATCH 21/24] fix bug --- .../run_ut/distributed_function_registry.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py index d6529a2b83a..9502ab3530d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py @@ -36,16 +36,14 @@ class DistributedFunctionRegistry: self.bench_functions[api_name] = function def get_compare_function(self, api_name: str) -> Callable: - if api_name in self.compare_functions: - return self.compare_functions[api_name] - else: - return compare_broadcast + if not self.compare_functions.get(api_name): + raise Exception("No compare function registered for api: {}".format(api_name)) + return self.compare_functions.get(api_name) def get_bench_function(self, api_name: str) -> Callable: - if api_name in self.bench_functions: - return self.bench_functions[api_name] - else: - return mock_broadcast + if not self.bench_functions.get(api_name): + raise Exception("No benchmark function registered for api: {}".format(api_name)) + return self.bench_functions.get(api_name) distributed_func_registry = DistributedFunctionRegistry() -- Gitee From 7214146015157a58b7fd16f5800d149afaec20a7 Mon Sep 17 00:00:00 2001 From: gitee Date: Mon, 24 Feb 2025 16:55:08 +0800 Subject: [PATCH 22/24] fix real_data_path --- .../api_accuracy_checker/run_ut/run_distributed_check.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py index 06cc0c81610..e1213ca9144 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py @@ -97,6 +97,9 @@ def run_distributed_command(args): real_data_paths = [] for file_path in file_paths: forward_content, _, real_data_path = parse_json_info_forward_backward(file_path) + if real_data_path: + dump_path = os.path.dirname(file_path) + real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA) distributed_api = parse_distributed_api(forward_content) forward_contents.append(distributed_api) real_data_paths.append(real_data_path) -- Gitee From 2b880ea95d3d84ee0c8d9b7a38e0d49d15d4b4c8 Mon Sep 17 00:00:00 2001 From: gitee Date: Mon, 24 Feb 2025 17:15:46 +0800 Subject: [PATCH 23/24] delete useless const --- debug/accuracy_tools/msprobe/core/common/const.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 6c09e5262c5..3a548a213bf 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -699,15 +699,9 @@ class DistributedCheckConst: BROADCAST = "broadcast" BROADCAST_SRC_INDEX = 1 FIRST_TENSOR_INDEX = 0 - GROUP_POSITION = { - "broadcast": 2 - } - SRC_POSITION = { - "broadcast": 1 - } API_ARGS_INDEX = { "broadcast": { "group_index": 2, "src_index": 1 } -} + } -- Gitee From f2c1f9330a2681c77e9ddb5a21b609f0eb08ec63 Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 25 Feb 2025 11:33:11 +0800 Subject: [PATCH 24/24] add check --- .../api_accuracy_checker/run_ut/distributed_bench_function.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py index f5996a4f663..0c597625a9e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py @@ -33,6 +33,8 @@ def mock_broadcast(api_name, input_args, input_kwargs): if not group_ranks: raise ValueError("group_ranks should not be empty") real_src = src - min(group_ranks) + if len(input_args) <= real_src: + raise ValueError("input_args should have at least {} element".format(real_src + 1)) return input_args[real_src][0] -- Gitee