diff --git a/debug/accuracy_tools/msprobe/config.json b/debug/accuracy_tools/msprobe/config.json
index 553b7f9ee3b89215647b00fb14b70af44ea5f00c..9bf9579b80770210bdda668b782a41540e7cb763 100644
--- a/debug/accuracy_tools/msprobe/config.json
+++ b/debug/accuracy_tools/msprobe/config.json
@@ -25,7 +25,9 @@
     "run_ut": {
         "white_list": [],
         "black_list": [],
-        "error_data_path": "./"
+        "error_data_path": "./",
+        "master_ip": "127.0.0.1",
+        "master_port": "8888"
     },
     "grad_probe": {
         "grad_level": "L1",
diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index d9623b807121ea129484a535fe8a9e2293e662f3..3a548a213bf1addf0ba0d42c8c477dc93c7df266 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -27,6 +27,8 @@ class Const:
     ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
 
     SEP = "."
+    COLON = ":"
+    DOUBLE_SLASH = "//"
     REGEX_PREFIX_MAX_LENGTH = 20
     REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
     REGEX_FORWARD_BACKWARD = r'\.(forward|backward)\.'
@@ -674,3 +676,32 @@ class MonitorConst:
     CSV = "csv"
     API = "api"
     HEADER_NAME = 'name'
+
+
+class DistributedCheckConst:
+    API_FULL_NAME = "api_full_name"
+    API_NAME = "api_name"
+    GROUP = "group"
+    GROUP_RANKS = "group_ranks"
+    GROUP_INDEX = "group_index"
+    SRC = "src"
+    SRC_INDEX = "src_index"
+    TORCH_PROCESS_GROUP = "torch.ProcessGroup"
+    ALL_ARGS = "all_args"
+    ALL_KWARGS = "all_kwargs"
+    RESULT_FILE_PATH = "result_file_path"
+    BENCHMARK_RESULT = "benchmark_result"
+    MASTER_IP = "master_ip"
+    MASTER_PORT = "master_port"
+    WORLD_SIZE = "world_size"
+    HCCL = "hccl"
+    TCP = "tcp"
+    BROADCAST = "broadcast"
+    BROADCAST_SRC_INDEX = 1
+    FIRST_TENSOR_INDEX = 0
+    API_ARGS_INDEX = {
+        "broadcast": {
+            "group_index": 2,
+            "src_index": 1
+        }
+    }
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
index f2b2d6a30463c62846bcc02e147c9c319f55d1b8..588a1eb349a6223f1c86df04fe3ae590a4e2a1ca 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py
@@ -52,7 +52,9 @@ class Config:
             'host': str,
             'port': int,
             'rank_list': list,
-            'tls_path': str
+            'tls_path': str,
+            'master_ip': str,
+            'master_port': str
         }
         if key not in validators:
            raise ValueError(f"{key} must be one of {validators.keys()}")
@@ -72,6 +74,10 @@ class Config:
            RunUTConfig.check_nfs_path_config(value)
        if key == 'tls_path':
            RunUTConfig.check_tls_path_config(value)
+       if key == 'master_ip':
+           RunUTConfig.check_master_ip_config(value)
+       if key == 'master_port':
+           RunUTConfig.check_master_port_config(value)

        return value

@@ -91,6 +97,8 @@ class CheckerConfig:
         self.port = msCheckerConfig.port
         self.rank_list = msCheckerConfig.rank_list
         self.tls_path = msCheckerConfig.tls_path
+        self.master_ip = msCheckerConfig.master_ip
+        self.master_port = msCheckerConfig.master_port
 
         if task_config:
             self.load_config(task_config)
@@ -105,6 +113,8 @@ class CheckerConfig:
         self.port = task_config.port
         self.rank_list = task_config.rank_list
         self.tls_path = task_config.tls_path
+        self.master_ip = task_config.master_ip
+        self.master_port = task_config.master_port
 
     def get_online_config(self):
         return OnlineConfig(
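Reviewer note: the new 'master_ip'/'master_port' options are declared as strings in the validators above, and the numeric-string rule enforced later in pt_config.py accepts only digits in (0, 65535]. A minimal sketch of what that rule accepts, assuming the checks behave as written:

    port_candidates = ["2688", "8888", "0", "65536", 8888]
    for p in port_candidates:
        ok = isinstance(p, str) and p.isdigit() and 0 < int(p) <= 65535
        print(p, "valid" if ok else "rejected")
    # "2688" valid, "8888" valid, "0" rejected, "65536" rejected, 8888 rejected (int, not str)
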
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
index 2ec9251009e61ef68dbfed987abe457d47b91e9a..30cea3b8e01f1c1a8a3a3d25620ba4bb2c9e709a 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml
@@ -8,3 +8,5 @@ host: ""
 port: -1
 rank_list: [0]
 tls_path: "./"
+master_ip: '127.0.0.1'
+master_port: '2688'
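Reviewer note: config.yaml supplies the defaults (msCheckerConfig) and the run_ut section of config.json overrides them via CheckerConfig.load_config, which is why the yaml default is '2688' while the sample config.json says '8888'. A sketch of the resulting precedence, names illustrative:

    defaults = {"master_ip": "127.0.0.1", "master_port": "2688"}   # from config.yaml
    overrides = {"master_ip": "127.0.0.1", "master_port": "8888"}  # from config.json run_ut
    effective = {**defaults, **overrides}
    print(effective["master_port"])  # "8888" once a task_config is supplied
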
""" check_object_type(args_info, list) args_result = [] @@ -274,6 +282,7 @@ def gen_args(args_info, api_name, func_options): convert_type = func_options.get('convert_type', None) real_data_path = func_options.get('real_data_path', None) depth = func_options.get('depth', 0) + kwargs_params = func_options.get('input_kwargs', {}) if depth > Const.MAX_DEPTH: logger.error("The depth of args is too large, please check the input args.") @@ -284,7 +293,11 @@ def gen_args(args_info, api_name, func_options): func_options['depth'] = depth + 1 data = gen_args(arg, api_name, func_options) elif isinstance(arg, dict): - data = gen_data(arg, api_name, need_grad, convert_type, real_data_path) + if arg.get('type') == DistributedCheckConst.TORCH_PROCESS_GROUP: + data = None + kwargs_params[DistributedCheckConst.GROUP] = arg + else: + data = gen_data(arg, api_name, need_grad, convert_type, real_data_path) elif arg is None: data = None else: @@ -311,6 +324,8 @@ def gen_kwargs(api_info, api_name, convert_type=None, real_data_path=None): kwargs_params[key] = gen_list_kwargs(value, api_name, convert_type, real_data_path) elif value is None: kwargs_params[key] = None + elif key == DistributedCheckConst.GROUP and value.get('type') == DistributedCheckConst.TORCH_PROCESS_GROUP: + kwargs_params[key] = value elif key == 'atten_mask' and api_name == 'npu_fusion_attention': sparse_mode = kwargs_params.get('sparse_mode', {}) if isinstance(sparse_mode, dict): @@ -415,17 +430,19 @@ def gen_api_params(api_info, api_name, need_grad=True, convert_type=None, real_d if convert_type and convert_type not in Const.CONVERT: error_info = f"convert_type params not support {convert_type}." raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) - kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path) + func_options = { 'need_grad': need_grad, 'convert_type': convert_type, 'real_data_path': real_data_path, - 'depth': 0 + 'depth': 0, + 'input_kwargs': api_info.get("input_kwargs", {}) } if api_info.get("input_args"): args_params = gen_args(api_info.get("input_args"), api_name, func_options) else: logger.warning(f'Warning: No args in {api_info} ') args_params = [] + kwargs_params = gen_kwargs(api_info, api_name, convert_type, real_data_path) output_dtype = get_output_dtype(api_info) return args_params, kwargs_params, output_dtype diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py new file mode 100644 index 0000000000000000000000000000000000000000..0c597625a9e61217b3f4c9976c7748c7a3e8b48f --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c597625a9e61217b3f4c9976c7748c7a3e8b48f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_bench_function.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.core.common.const import DistributedCheckConst
+from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_group, get_src
+
+
+def mock_broadcast(api_name, input_args, input_kwargs):
+    check_object_type(input_args, list)
+    check_object_type(input_kwargs, list)
+    if len(input_args) < 1 or len(input_kwargs) < 1:
+        raise ValueError("input_args and input_kwargs should have at least 1 element")
+
+    src = get_src(api_name, input_args[0], input_kwargs[0])
+
+    group = get_group(api_name, input_args[0], input_kwargs[0])
+    group_ranks = group.get(DistributedCheckConst.GROUP_RANKS, [])
+    if not group_ranks:
+        raise ValueError("group_ranks should not be empty")
+    real_src = src - min(group_ranks)
+    if len(input_args) <= real_src:
+        raise ValueError("input_args should have at least {} element".format(real_src + 1))
+
+    return input_args[real_src][0]
+
+
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6d2f6fd4791ef3a121b78895b5452de62cb987
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_compare_function.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from msprobe.core.common.const import CompareConst
+
+
+def compare_broadcast(device_out, bench_out):
+    if len(device_out) < 1:
+        raise ValueError("device_out should not be empty")
+    compare_result = torch.equal(device_out[0].cpu(), bench_out)
+    if not compare_result:
+        return CompareConst.ERROR
+    return CompareConst.PASS
+    
\ No newline at end of file
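Reviewer note: for broadcast, the bench function replays the collective on CPU by picking the source rank's original tensor, and the compare function then asserts each rank converged to it. A toy sketch with hand-built per-rank inputs (positional src at index 1 and group at index 2, mirroring API_ARGS_INDEX; shapes illustrative):

    import torch

    group = {"type": "torch.ProcessGroup", "group_ranks": [0, 1], "group_id": "g0"}
    t0, t1 = torch.ones(2), torch.zeros(2)
    all_args = [[t0, 0, group], [t1, 0, group]]   # one args list per rank
    all_kwargs = [{}, {}]
    bench = all_args[0 - min(group["group_ranks"])][0]  # what mock_broadcast computes: t0
    # After the real broadcast runs, compare_broadcast(device_args, bench) should
    # find every rank's first tensor equal to t0.
    assert torch.equal(all_args[0][0].cpu(), bench)
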
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..9502ab3530d8f62efd90d6c89dd7213cb0e8d42b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/distributed_function_registry.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable
+
+from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_bench_function import \
+    mock_broadcast
+from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_compare_function import \
+    compare_broadcast
+from msprobe.core.common.const import DistributedCheckConst
+
+
+class DistributedFunctionRegistry:
+    def __init__(self):
+        self.compare_functions = {}
+        self.bench_functions = {}
+
+    def register_compare_function(self, api_name: str, function: Callable):
+        self.compare_functions[api_name] = function
+
+    def register_bench_function(self, api_name: str, function: Callable):
+        self.bench_functions[api_name] = function
+
+    def get_compare_function(self, api_name: str) -> Callable:
+        if not self.compare_functions.get(api_name):
+            raise Exception("No compare function registered for api: {}".format(api_name))
+        return self.compare_functions.get(api_name)
+
+    def get_bench_function(self, api_name: str) -> Callable:
+        if not self.bench_functions.get(api_name):
+            raise Exception("No benchmark function registered for api: {}".format(api_name))
+        return self.bench_functions.get(api_name)
+
+
+distributed_func_registry = DistributedFunctionRegistry()
+distributed_func_registry.register_bench_function(DistributedCheckConst.BROADCAST, mock_broadcast)
+distributed_func_registry.register_compare_function(DistributedCheckConst.BROADCAST, compare_broadcast)
+
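Reviewer note: supporting another collective later is then a matter of registering a bench/compare pair. A hypothetical sketch for all_reduce; mock_all_reduce and compare_all_reduce do not exist in this patch:

    from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_function_registry import distributed_func_registry
    # distributed_func_registry.register_bench_function("all_reduce", mock_all_reduce)        # hypothetical helper
    # distributed_func_registry.register_compare_function("all_reduce", compare_all_reduce)   # hypothetical helper
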
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1213ca914463672b4801fd864330e5e80c2688b
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_distributed_check.py
@@ -0,0 +1,236 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import time
+import argparse
+from collections import namedtuple
+
+import tqdm
+import torch
+import torch_npu
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+from msprobe.core.common.file_utils import FileChecker, write_csv, create_directory
+from msprobe.core.common.const import Const, FileCheckConst, DistributedCheckConst, CompareConst
+from msprobe.core.compare.utils import check_and_return_dir_contents
+from msprobe.pytorch.hook_module.wrap_distributed import distributed_func
+from msprobe.pytorch.pt_config import parse_json_config
+from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import generate_device_params, get_group_info, \
+    is_port_in_use
+from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_api_info
+from msprobe.pytorch.api_accuracy_checker.run_ut.distributed_function_registry import distributed_func_registry
+from msprobe.pytorch.api_accuracy_checker.common.utils import extract_basic_api_segments
+from msprobe.pytorch.api_accuracy_checker.common.config import CheckerConfig
+
+
+current_time = time.strftime("%Y%m%d%H%M%S")
+RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
+RESULT_CSV_HEADER = [['API_NAME', 'RANK', 'COMPARE_RESULT', 'MESSAGE']]
+DistributedCheckParams = namedtuple("DistributedCheckParams", ["api_full_name", "all_args", "all_kwargs",
+                                                               "group_ranks", "result_file_path", "checker_config"])
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def distributed_setup(rank, world_size, master_ip, master_port):
+    init_method = DistributedCheckConst.TCP + Const.COLON + Const.DOUBLE_SLASH + master_ip + Const.COLON + master_port
+    dist.init_process_group(backend=DistributedCheckConst.HCCL, init_method=init_method,
+                            world_size=world_size, rank=rank)
+
+
+def parse_distributed_api(forward_content):
+    distributed_api = {}
+    for api_full_name, api_info_dict in forward_content.items():
+        split_name = api_full_name.split(Const.SEP)[0]
+        if split_name == Const.DISTRIBUTED:
+            distributed_api.update({api_full_name: api_info_dict})
+    return distributed_api
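Reviewer note: parse_distributed_api keeps only entries whose first dot-separated field is the distributed prefix. A sketch with illustrative dump keys, assuming Const.DISTRIBUTED == "Distributed":

    forward_content = {
        "Distributed.broadcast.0.forward": {"input_args": []},
        "Tensor.add.0.forward": {"input_args": []},
    }
    kept = {k: v for k, v in forward_content.items() if k.split(".")[0] == "Distributed"}
    print(list(kept))  # ['Distributed.broadcast.0.forward']
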
", + required=True) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The ut task result out path.", + required=False) + parser.add_argument("-config", "--config_path", dest="config_path", default="", type=str, + help=" The path of config.json", required=False) + + +def _run_distributed(parser=None): + if parser is None: + parser = argparse.ArgumentParser() + _run_distributed_parser(parser) + args = parser.parse_args(sys.argv[1:]) + run_distributed_command(args) + + +def run_distributed_command(args): + input_checker = FileChecker(args.api_info_dir, FileCheckConst.DIR, ability=FileCheckConst.READ_ABLE) + api_info_dir = input_checker.common_check() + ranks = sorted(check_and_return_dir_contents(api_info_dir, Const.RANK)) + file_paths = [os.path.join(api_info_dir, rank, 'dump.json') for rank in ranks] + forward_contents = [] + real_data_paths = [] + for file_path in file_paths: + forward_content, _, real_data_path = parse_json_info_forward_backward(file_path) + if real_data_path: + dump_path = os.path.dirname(file_path) + real_data_path = os.path.join(dump_path, Const.DUMP_TENSOR_DATA) + distributed_api = parse_distributed_api(forward_content) + forward_contents.append(distributed_api) + real_data_paths.append(real_data_path) + + out_path = args.out_path if args.out_path else Const.DEFAULT_PATH + create_directory(out_path) + out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) + out_path = out_path_checker.common_check() + result_file_path = os.path.join(out_path, RESULT_FILE_NAME) + write_csv(RESULT_CSV_HEADER, result_file_path) + if args.config_path: + config_path_checker = FileChecker(args.config_path, FileCheckConst.FILE, + FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX) + checked_config_path = config_path_checker.common_check() + _, task_config = parse_json_config(checked_config_path, Const.RUN_UT) + checker_config = CheckerConfig(task_config) + else: + checker_config = CheckerConfig() + run_distributed_check(forward_contents, real_data_paths, result_file_path, checker_config) + + +def run_distributed_check(forward_contents, real_data_paths, result_file_path, checker_config): + for rank, forward_content in enumerate(forward_contents): + logger.info("Start to check distributed api in rank {}.".format(rank)) + + for api_full_name, api_info_dict in forward_content.items(): + _, api_name = extract_basic_api_segments(api_full_name) + + if api_info_dict.get('used'): + continue + + group_ranks, group_id = get_group_info(api_full_name, api_name, api_info_dict) + if not group_ranks or not group_id: + logger.warning("The api {} doesn't support distributed check.".format(api_full_name)) + continue + all_args, all_kwargs = get_distributed_args_kwargs(forward_contents, api_full_name, + real_data_paths, group_ranks) + try: + distributed_check_params = DistributedCheckParams(api_full_name, all_args, all_kwargs, group_ranks, + result_file_path, checker_config) + distributed_check(distributed_check_params) + except Exception as e: + import traceback + traceback.print_exc() + logger.error("The api {} in rank {} distributed check failed.".format(api_full_name, rank)) + result_rows = [] + df_row = list([api_full_name, rank, CompareConst.ERROR, str(e)]) + result_rows.append(df_row) + write_csv(result_rows, result_file_path) + + +def distributed_check(distributed_check_params): + api_full_name = distributed_check_params.api_full_name + all_args = distributed_check_params.all_args + all_kwargs = 
+
+
+def distributed_check(distributed_check_params):
+    api_full_name = distributed_check_params.api_full_name
+    all_args = distributed_check_params.all_args
+    all_kwargs = distributed_check_params.all_kwargs
+    group_ranks = distributed_check_params.group_ranks
+    result_file_path = distributed_check_params.result_file_path
+    checker_config = distributed_check_params.checker_config
+
+    _, api_name = extract_basic_api_segments(api_full_name)
+    nprocs = len(group_ranks)
+    distributed_config = {}
+    distributed_config[DistributedCheckConst.API_FULL_NAME] = api_full_name
+    distributed_config[DistributedCheckConst.API_NAME] = api_name
+    distributed_config[DistributedCheckConst.GROUP_RANKS] = group_ranks
+    distributed_config[DistributedCheckConst.ALL_ARGS] = all_args
+    distributed_config[DistributedCheckConst.ALL_KWARGS] = all_kwargs
+    distributed_config[DistributedCheckConst.RESULT_FILE_PATH] = result_file_path
+    benchmark_function = distributed_func_registry.get_bench_function(api_name)
+    distributed_config[DistributedCheckConst.BENCHMARK_RESULT] = benchmark_function(api_name, all_args, all_kwargs)
+    distributed_config[DistributedCheckConst.MASTER_IP] = checker_config.master_ip
+    distributed_config[DistributedCheckConst.MASTER_PORT] = checker_config.master_port
+    distributed_config[DistributedCheckConst.WORLD_SIZE] = nprocs
+
+    if is_port_in_use(checker_config.master_port, checker_config.master_ip):
+        raise ValueError(
+            f"Port {checker_config.master_port} on host "
+            f"{checker_config.master_ip} is already in use."
+        )
+    logger.info(f"Port {checker_config.master_port} on host {checker_config.master_ip} is available.")
+
+    mp.spawn(run_hccl,
+             args=(distributed_config,),
+             nprocs=nprocs)
+
+
+def run_hccl(rank, distributed_config):
+    local_rank = distributed_config[DistributedCheckConst.GROUP_RANKS][rank]
+    torch_npu.npu.set_device(local_rank)
+    world_size = distributed_config[DistributedCheckConst.WORLD_SIZE]
+    master_ip = distributed_config[DistributedCheckConst.MASTER_IP]
+    master_port = distributed_config[DistributedCheckConst.MASTER_PORT]
+    distributed_setup(rank, world_size, master_ip, master_port)
+    api_full_name = distributed_config[DistributedCheckConst.API_FULL_NAME]
+    api_name = distributed_config[DistributedCheckConst.API_NAME]
+    rank_args = distributed_config[DistributedCheckConst.ALL_ARGS][rank]
+    rank_kwargs = distributed_config[DistributedCheckConst.ALL_KWARGS][rank]
+    result_file_path = distributed_config[DistributedCheckConst.RESULT_FILE_PATH]
+    benchmark_result = distributed_config[DistributedCheckConst.BENCHMARK_RESULT]
+    device_args, _ = generate_device_params(rank_args, rank_kwargs, False, api_name)
+    logger.info("Start to check distributed api {} in rank {}.".format(api_full_name, local_rank))
+    distributed_func.get(api_name)(*device_args)
+
+    compare_function = distributed_func_registry.get_compare_function(api_name)
+    status = compare_function(device_args, benchmark_result)
+    message = ''
+    result_rows = []
+    df_row = list([api_full_name, local_rank, status, message])
+    result_rows.append(df_row)
+    write_csv(result_rows, result_file_path)
+    cleanup()
+
+
+def get_distributed_args_kwargs(forward_contents, api_full_name, real_data_paths, group_ranks):
+    all_args, all_kwargs = [], []
+    _, api_name = extract_basic_api_segments(api_full_name)
+    for group_rank in group_ranks:
+        target_api_info = forward_contents[group_rank].get(api_full_name)
+        if not target_api_info:
+            logger.warning("The api {} doesn't exist in rank {}.".format(api_full_name, group_rank))
+            continue
+        if target_api_info.get('used'):
+            continue
+        target_api_info['used'] = True
+        args, kwargs, _ = get_api_info(target_api_info, api_name, real_data_paths[group_rank])
+        all_args.append(args)
+        all_kwargs.append(kwargs)
+    return all_args, all_kwargs
+
+
+if __name__ == '__main__':
+    logger.info("Start to run distributed ut task.")
+    _run_distributed()
+    logger.info("Finished running distributed ut task.")
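Reviewer note: distributed_check receives its inputs bundled in the DistributedCheckParams namedtuple declared at the top of this file. A small sketch of its construction, with illustrative values:

    from collections import namedtuple

    DistributedCheckParams = namedtuple("DistributedCheckParams",
                                        ["api_full_name", "all_args", "all_kwargs",
                                         "group_ranks", "result_file_path", "checker_config"])
    params = DistributedCheckParams("Distributed.broadcast.0.forward",
                                    [["rank0-args"], ["rank1-args"]], [{}, {}],
                                    [0, 1], "./output/result.csv", None)
    # distributed_check unpacks these fields and spawns len(params.group_ranks) workers.
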
diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
index dc0174212e3f8f8cf70fa1701aadc664138dbcdf..a8570f98d19669e08209e49a4c2e62ebaf728387 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 import os
+import socket
 from collections import namedtuple
 import re
 import torch
@@ -29,7 +30,7 @@ else:
     current_device = "npu"
     from torch_npu.npu.amp import autocast
 
-from msprobe.core.common.const import FileCheckConst, Const, CompareConst
+from msprobe.core.common.const import FileCheckConst, Const, CompareConst, DistributedCheckConst
 from msprobe.core.common.file_utils import FileChecker
 from msprobe.core.common.log import logger
 from msprobe.core.common.utils import CompareException
@@ -252,3 +253,81 @@ def is_unsupported_api(api_name, is_overflow_check=False):
     if flag:
         logger.info(f"{split_name} api is not supported for run ut. SKIP.")
     return flag
+
+
+def get_args_index(api_name, args_name):
+    """
+    Get the index of a positional parameter from the API name and parameter name, e.g. group_index or src_index.
+    :param api_name: API name, e.g. "broadcast" or "all_reduce"
+    :param args_name: parameter name, e.g. "group" or "src"
+    :return: the parameter index, or None if the API name or parameter name is unknown
+    """
+    api_info = DistributedCheckConst.API_ARGS_INDEX.get(api_name)
+    if api_info:
+        return api_info.get(args_name)
+    return None
+
+
+def get_group(api_name, input_args, input_kwargs):
+    group = None
+    group = input_kwargs.get(DistributedCheckConst.GROUP)
+    if group:
+        return group
+    group_index = get_args_index(api_name, DistributedCheckConst.GROUP_INDEX)
+    if group_index is None or len(input_args) <= group_index:
+        return None
+    group = input_args[group_index]
+    if not isinstance(group, dict):
+        return None
+    return group
+
+
+def get_group_info(api_full_name, api_name, api_info_dict):
+    input_args = api_info_dict.get('input_args', {})
+    input_kwargs = api_info_dict.get('input_kwargs', {})
+    group = get_group(api_name, input_args, input_kwargs)
+
+    if not group:
+        logger.warning("The api {} doesn't have group info.".format(api_full_name))
+        return None, None
+    group_ranks = group.get('group_ranks')
+    if not group_ranks:
+        logger.warning("The group of api {} doesn't have group_ranks info.".format(api_full_name))
+        return None, None
+    group_id = group.get('group_id')
+    if not group_id:
+        logger.warning("The group of api {} doesn't have group_id info.".format(api_full_name))
+        return None, None
+    return group_ranks, group_id
+
+
+def get_src(api_name, input_args, input_kwargs):
+    src = None
+    src = input_kwargs.get(DistributedCheckConst.SRC)
+    if isinstance(src, int):
+        return src
+    src_index = get_args_index(api_name, DistributedCheckConst.SRC_INDEX)
+    if src_index is None or len(input_args) <= src_index:
+        return None
+    src = input_args[src_index]
+    if not isinstance(src, int):
+        return None
+    return src
+
+
+def is_port_in_use(port, host):
+    """
+    Check whether the given port on the host is already in use.
+    :param port: the port number to check, as a numeric string
+    :param host: the host address
+    :return: True if the port is in use, False otherwise
+    """
+    if not isinstance(port, str) or not port.isdigit():
+        raise Exception(f"port: {port} is invalid. Port must be a numeric string.")
+    port = int(port)
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        try:
+            s.bind((host, port))
+            return False  # port is free
+        except socket.error:
+            return True  # port is already in use
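Reviewer note: get_src and get_group share the same resolution order, kwargs first, then the positional slot recorded in API_ARGS_INDEX. A simplified, runnable sketch of that lookup (input shapes illustrative):

    input_args = ["tensor-placeholder", 0, {"group_ranks": [0, 1], "group_id": "g0"}]
    input_kwargs = {}
    API_ARGS_INDEX = {"broadcast": {"group_index": 2, "src_index": 1}}
    src = input_kwargs.get("src")
    if not isinstance(src, int):
        src = input_args[API_ARGS_INDEX["broadcast"]["src_index"]]
    print(src)  # 0: kwargs win when present, otherwise the recorded positional slot
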
diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py
index 8293ac969490b103eef630081b6001234ca8bb07..879c0810492da74cb93e565c5f356b8303d885d5 100644
--- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py
+++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py
@@ -70,9 +70,6 @@ class TensorConfig(BaseConfig):
         if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
             raise Exception(f"host: {self.host} is invalid.")
 
-        if not isinstance(self.port, int) or not (0 < self.port <= 65535):
-            raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
-
 
 class StatisticsConfig(BaseConfig):
     def __init__(self, json_config):
@@ -252,6 +249,8 @@ class RunUTConfig(BaseConfig):
         self.port = json_config.get("port", -1)
         self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
         self.tls_path = json_config.get("tls_path", "./")
+        self.master_ip = json_config.get("master_ip", "127.0.0.1")
+        self.master_port = json_config.get("master_port", "8888")
         self.check_run_ut_config()
 
     @classmethod
@@ -278,6 +277,19 @@ class RunUTConfig(BaseConfig):
     def check_tls_path_config(cls, tls_path):
         if tls_path and not os.path.exists(tls_path):
             raise Exception("tls_path: %s does not exist" % tls_path)
+
+    @classmethod
+    def check_master_ip_config(cls, master_ip):
+        if not re.match(Const.ipv4_pattern, master_ip):
+            raise Exception("master_ip: %s is invalid" % master_ip)
+
+    @classmethod
+    def check_master_port_config(cls, master_port):
+        if not isinstance(master_port, str) or not master_port.isdigit():
+            raise Exception(f"port: {master_port} is invalid. Port must be a numeric string.")
+        port_number = int(master_port)
+        if not (0 < port_number <= 65535):
+            raise Exception(f"port: {master_port} is invalid. Port range must be between 1 and 65535.")
 
     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
@@ -285,6 +297,8 @@ class RunUTConfig(BaseConfig):
         RunUTConfig.check_error_data_path_config(self.error_data_path)
         RunUTConfig.check_nfs_path_config(self.nfs_path)
         RunUTConfig.check_tls_path_config(self.tls_path)
+        RunUTConfig.check_master_ip_config(self.master_ip)
+        RunUTConfig.check_master_port_config(self.master_port)
 
 
 class GradToolConfig(BaseConfig):
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
index df03485dc6c77371750fd0b67ca2c37ff7e2ed7b..30fa11d94de0dd4fec483502a51d0474e8b7646a 100644
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py
@@ -16,6 +16,8 @@ class TestUtConfig():
         self.port = 8080
         self.rank_list = [0, 1, 2]
         self.tls_path = '/path/to/tls'
+        self.master_ip = '127.0.0.1'
+        self.master_port = '8888'
 
 
 class TestConfig(unittest.TestCase):
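Reviewer note: with the pieces above in place, the checker can be driven end to end through the argparse entry point. A hypothetical invocation sketch (paths are illustrative, and driving the module programmatically like this is an assumption, not a documented interface):

    import sys
    from msprobe.pytorch.api_accuracy_checker.run_ut.run_distributed_check import _run_distributed

    # Equivalent to: python run_distributed_check.py -api_info ./step0 -o ./output -config ./config.json
    sys.argv = ["run_distributed_check.py",
                "-api_info", "./step0",
                "-o", "./output",
                "-config", "./config.json"]
    _run_distributed()
    # Results land in ./output/accuracy_checking_result_<timestamp>.csv
    # with columns API_NAME, RANK, COMPARE_RESULT, MESSAGE.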