From a43df2edc48fd335e87a925a075f21b357e4962f Mon Sep 17 00:00:00 2001 From: hhqx Date: Wed, 11 Dec 2024 16:56:00 +0800 Subject: [PATCH 1/2] fix set exc_info=Constant.ENABLE_STACKTRACE_LOGGING fix fix train_with_profiler.py fix cleancode fix cleancode add --static to profiler examples fix reviews fix fix fix Revert "remove precheck_run_llama2.sh" This reverts commit f1a73c09e0975691d1a2522c3ead9ae39ecda9ce. fix cleancode remove code remove code in train_with_profiler.py fix codes remove unused example model remove setup.py remove precheck_run_llama2.sh remove files for pr: 1. remove multiprocess_utils.py 2. remove args check in train_with_profiler.py and unnecessary comment remove old args parser profiling_cmd should not empty remove tests for pr using torchrun to start profiler example fix add validate_distributed_env() fix tiny dataset shape add cal loss to demo trainer add pbar fix fix fix fix launch.py fix fix optimize profiler examples optimize profiler examples fix optimize arg: host_ips and host_config_file optimize arg: host_ips and host_config_file add conda activate to precheck/__main__.py fix cleancode in args_manager.py add arg check for function in ssh_utils.py add command check for profiling_cmd [archive bomb] add check function to archive bomb remove unnecessary import of task_manager.py fix run_remote_command in multiprocess_run add multiprocess support for remote_run and sync add multiprocess_utils.py set MS_PROF_PRECHECK_CMD = "msprof-analyze precheck" fix task_manager.py add DEFAULT_DP_CONFIG add dynamic_prof.py fix constant.py add init to runner fix MS_PROF_PRECHECK_CMD fix precheck_cli.py add files for build wheel add precheck_cli add __name__ to logging.getLogger() set ARG_MAX_LEN = 255 fix run_llama2_precheck.sh fix run_precheck.sh fix args_manager.py fix advisor fix run_precheck.sh fix runner.py zip fix run_precheck.sh fix cat_files.py fix sync.py fix collector fix and add collector test fix group_manager.py fix collector fix collector fix glob.glob usage fix update precheck.tests update ssh and archive test files add other test files update test files add test files merge squash the precheck changes in wzx_poc merge squash the precheck changes in precheck_hqx_pullup1 create logger.py and add ssh config file for remote hosts add args_manager and main module --- profiler/MANIFEST.in | 1 + profiler/cli/entrance.py | 5 +- profiler/cli/precheck_cli.py | 106 +++++++ profiler/precheck/__init__.py | 0 profiler/precheck/__main__.py | 60 ++-- profiler/precheck/analyze/__init__.py | 0 profiler/precheck/analyze/advisor_adaptor.py | 7 +- profiler/precheck/collect/collector.py | 26 +- profiler/precheck/common/constant.py | 10 +- profiler/precheck/common/utils.py | 59 ++-- profiler/precheck/examples/__init__.py | 0 .../precheck/examples/profiler/__init__.py | 0 .../examples/profiler/dynamic_prof.py | 71 +++++ profiler/precheck/examples/profiler/models.py | 67 ++++ .../examples/profiler/train_with_profiler.py | 286 ++++++++++++++++++ profiler/precheck/examples/run_precheck.sh | 23 -- .../examples/scripts/precheck_run_llama2.sh | 128 ++++++++ .../{ => scripts}/run_llama2_precheck.sh | 17 +- .../precheck/examples/scripts/run_precheck.sh | 37 +++ profiler/precheck/manager/__init__.py | 0 profiler/precheck/manager/args_manager.py | 151 ++++----- profiler/precheck/manager/disk_manager.py | 2 +- profiler/precheck/manager/group_manager.py | 21 +- profiler/precheck/manager/task_manager.py | 36 +-- profiler/precheck/runner/__init__.py | 0 profiler/precheck/runner/__main__.py | 62 ++-- 
profiler/precheck/runner/runners.py | 26 +- profiler/precheck/setup.py | 18 -- profiler/precheck/tools/__init__.py | 0 profiler/precheck/tools/archive_utils.py | 9 +- profiler/precheck/tools/ssh_utils.py | 5 +- 31 files changed, 945 insertions(+), 288 deletions(-) create mode 100644 profiler/cli/precheck_cli.py create mode 100644 profiler/precheck/__init__.py create mode 100644 profiler/precheck/analyze/__init__.py create mode 100644 profiler/precheck/examples/__init__.py create mode 100644 profiler/precheck/examples/profiler/__init__.py create mode 100644 profiler/precheck/examples/profiler/dynamic_prof.py create mode 100644 profiler/precheck/examples/profiler/models.py create mode 100644 profiler/precheck/examples/profiler/train_with_profiler.py delete mode 100644 profiler/precheck/examples/run_precheck.sh create mode 100644 profiler/precheck/examples/scripts/precheck_run_llama2.sh rename profiler/precheck/examples/{ => scripts}/run_llama2_precheck.sh (65%) create mode 100644 profiler/precheck/examples/scripts/run_precheck.sh create mode 100644 profiler/precheck/manager/__init__.py create mode 100644 profiler/precheck/runner/__init__.py delete mode 100644 profiler/precheck/setup.py create mode 100644 profiler/precheck/tools/__init__.py diff --git a/profiler/MANIFEST.in b/profiler/MANIFEST.in index 0550da458f..99ae66e4ec 100644 --- a/profiler/MANIFEST.in +++ b/profiler/MANIFEST.in @@ -3,5 +3,6 @@ recursive-include profiler/cli/ * recursive-include profiler/prof_common/ * recursive-include profiler/compare_tools/ * recursive-include profiler/cluster_analyse/ * +recursive-include profiler/precheck/ * global-exclude */__pycache__/* global-exclude *.pyc diff --git a/profiler/cli/entrance.py b/profiler/cli/entrance.py index a260553031..f1d687ad03 100644 --- a/profiler/cli/entrance.py +++ b/profiler/cli/entrance.py @@ -21,6 +21,7 @@ from profiler.cli.analyze_cli import analyze_cli from profiler.cli.complete_cli import auto_complete_cli from profiler.cli.compare_cli import compare_cli from profiler.cli.cluster_cli import cluster_cli +from profiler.cli.precheck_cli import precheck_cli from profiler.advisor.version import print_version_callback, cli_version logger = logging.getLogger() @@ -31,7 +32,8 @@ COMMAND_PRIORITY = { "advisor": 1, "compare": 2, "cluster": 3, - "auto-completion": 4 + "precheck": 4, + "auto-completion": 5 } @@ -64,4 +66,5 @@ def msprof_analyze_cli(**kwargs): msprof_analyze_cli.add_command(analyze_cli, name="advisor") msprof_analyze_cli.add_command(compare_cli, name="compare") msprof_analyze_cli.add_command(cluster_cli, name="cluster") +msprof_analyze_cli.add_command(precheck_cli, name="precheck") msprof_analyze_cli.add_command(auto_complete_cli, name="auto-completion") diff --git a/profiler/cli/precheck_cli.py b/profiler/cli/precheck_cli.py new file mode 100644 index 0000000000..2928fedcd3 --- /dev/null +++ b/profiler/cli/precheck_cli.py @@ -0,0 +1,106 @@ +import sys +import ipaddress +import logging +from functools import wraps + +import click + +from profiler.precheck.manager.args_manager import PrecheckArgsManager, PrecheckRunnerArgsManager +from profiler.precheck.runner.__main__ import main as runner_main +from profiler.precheck.__main__ import main as precheck_main + +logger = logging.getLogger(__name__) + +CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help']) + + +@click.group(context_settings=CONTEXT_SETTINGS) +def precheck_cli(): + """Profiler precheck tool""" + pass + + +def common_options(f): + """Common options for both precheck and runner commands""" + 
+ @wraps(f) + def wrapper(*args, **kwargs): + return f(*args, **kwargs) + + wrapper = click.option('--master_addr', required=True, + help='IP address of the master node (first node in the cluster)')(wrapper) + wrapper = click.option('--master_port', type=int, default=29500, + help='Port on the master node for communication. Default is 29500')(wrapper) + wrapper = click.option('--nnodes', type=int, required=True, + help='Total number of nodes in the distributed setup')(wrapper) + wrapper = click.option('--nproc_per_node', type=int, required=True, + help='Number of processes to run per node')(wrapper) + wrapper = click.option('--node_prof_save_dir', default='', + help='Directory for saving node profiling data')(wrapper) + wrapper = click.option('--master_prof_gather_dir', default='', + help='Directory for saving gathered profiling data in master node')(wrapper) + wrapper = click.option('--output_dir', default='./output', + help='Directory to save profiling dump data, logs, and advisor reports')(wrapper) + wrapper = click.option('--task_name', default='', + help='Name of the task or experiment')(wrapper) + wrapper = click.option('--static', is_flag=True, + help='If set, run profiling in static mode')(wrapper) + wrapper = click.option('--profiling_cmd', default="", + help='Command to run the profiler script')(wrapper) + return wrapper + + +def validate_ip_list(ctx, param, value): + if not value: + return [] + try: + ips = [ip.strip() for ip in value.split(',')] + # Validate each IP + for ip in ips: + ipaddress.ip_address(ip) + return ips + except ValueError as e: + raise click.BadParameter(f'Invalid IP address in list: {e}') + + +@precheck_cli.command(context_settings=CONTEXT_SETTINGS, + name="start_all", + short_help='Start precheck on all nodes via ssh') +@common_options +@click.option('--host_ips', + callback=validate_ip_list, + help='Comma-separated list of IP addresses for nodes in distributed training (e.g., "192.168.1.1,192.168.1.2")') +@click.option('--python_path', default=sys.executable, + help='Path to the Python interpreter') +@click.option('--host_config_file', default='', + help='Path to the host configuration file (CSV format with node connection details)') +def precheck_start_all(**kwargs): + """Run precheck command""" + # Add validation + if not kwargs.get('host_ips') and not kwargs.get('host_config_file'): + raise click.UsageError('Either --host_ips or --host_config_file must be specified') + + if kwargs.get('host_ips') and kwargs.get('host_config_file'): + raise click.UsageError('Cannot specify both --host_ips and --host_config_file') + + args = PrecheckArgsManager(type('Args', (), kwargs)) + click.echo(args) + precheck_main(args) + + +@precheck_cli.command(context_settings=CONTEXT_SETTINGS, + name="start_node", + short_help='Start one node precheck, if your nnodes > 1, you need to run this command on each node') +@common_options +@click.option('--node_rank', type=int, required=True, + help='Rank of the current node') +def precheck_start_node(**kwargs): + """Run precheck runner command""" + args = PrecheckRunnerArgsManager(type('Args', (), kwargs)) + click.echo(args) + + runner_main(args) + + +if __name__ == '__main__': + precheck_cli() diff --git a/profiler/precheck/__init__.py b/profiler/precheck/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/precheck/__main__.py b/profiler/precheck/__main__.py index 35e0169ded..03c96db145 100644 --- a/profiler/precheck/__main__.py +++ b/profiler/precheck/__main__.py @@ -1,25 +1,21 @@ import os -import 
re -import sys -import threading -import subprocess -import argparse -from datetime import datetime from copy import deepcopy import logging -from typing import Union +from profiler.precheck.common.constant import Constant from profiler.precheck.common.logger import add_file_handler, create_logger from profiler.precheck.common.utils import cn_now -from profiler.precheck.manager.args_manager import PrecheckArgsManager, get_precheck_args +from profiler.precheck.manager.args_manager import PrecheckArgsManager from profiler.precheck.tools.ssh_utils import run_remote_command -from profiler.prof_common.constant import Constant from profiler.prof_common.path_manager import PathManager def get_command_tpl(): cwd = os.getcwd() - EXECUTOR = f'cd {cwd} && {{python_path}} -m profiler.precheck.runner' + from profiler.precheck.runner.__main__ import get_conda_envs_info + conda_env, conda_activate_script = get_conda_envs_info() + + EXECUTOR = f'source {conda_activate_script} {conda_env} && cd {cwd} && {Constant.MS_PROF_PRECHECK_CMD} start_node' ARGS = ('--nnodes={nnodes}', '--nproc_per_node={nproc_per_node}', '--node_rank={node_rank}', '--master_addr={master_addr}', '--master_port={master_port}', @@ -34,21 +30,7 @@ def get_command_tpl(): return TPL -def main(): - logger = create_logger("profiler.precheck", logging.DEBUG, use_memory_handler=True) - args = get_precheck_args() - - PathManager.create_file_safety(args.task_output_dir) - - timestamp = cn_now().strftime('%Y%m%d_%H%M%S') - log_filename = f'precheck_{timestamp}.log' - log_file_path = os.path.join(args.task_output_dir, log_filename) - PathManager.check_path_writeable(log_file_path) - - logger = add_file_handler(logger, log_file_path) - logger.info("Starting precheck with arguments: %s", args) - logger.info("Precheck log file will be saved at %s", log_file_path) - +def start_precheck(args: PrecheckArgsManager, logger): config = dict( nnodes=args.nnodes, node_rank=-1, @@ -65,8 +47,7 @@ def main(): ) hosts_info = [] - for host in args.host_ips: - node_id = args.host_ips.index(host) + for node_id, host in enumerate(args.host_ips): node_config = deepcopy(config) node_config['node_rank'] = node_id @@ -83,17 +64,32 @@ def main(): "port": 22 } - if args.ssh_config_file: + if args.host_config_file: host_info.update(args.ssh_remote_hosts[host]) hosts_info.append(host_info) - logger.info("Prepared command for host %s: %s, log file will be saved at %s:%s", - host, cmd, host, log_file_path) logger.info("Starting remote command execution on %d hosts", len(hosts_info)) run_remote_command(hosts_info) logger.info("Precheck main processes have been started on all hosts") -if __name__ == "__main__": - main() +def main(args=None): + logger = create_logger("profiler.precheck", logging.DEBUG, use_memory_handler=True) + + PathManager.make_dir_safety(args.task_output_dir) + + timestamp = cn_now().strftime('%Y%m%d_%H%M%S') + log_filename = f'precheck_{timestamp}.log' + log_file_path = os.path.join(args.task_output_dir, log_filename) + PathManager.create_file_safety(log_file_path) + PathManager.check_path_writeable(log_file_path) + + logger = add_file_handler(logger, log_file_path) + logger.info("Starting precheck, Precheck log file will be saved at %s", log_file_path) + logger.info("Precheck arguments: %s", args) + + try: + start_precheck(args, logger) + except Exception as e: + logger.error("Precheck runner failed with error: %s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) diff --git a/profiler/precheck/analyze/__init__.py b/profiler/precheck/analyze/__init__.py new 
file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/precheck/analyze/advisor_adaptor.py b/profiler/precheck/analyze/advisor_adaptor.py index e0278e799b..ae635ccb23 100644 --- a/profiler/precheck/analyze/advisor_adaptor.py +++ b/profiler/precheck/analyze/advisor_adaptor.py @@ -3,7 +3,6 @@ import os import logging from pathlib import Path -sys.path.append("../../..") sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "compare_tools")) sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "cluster_analyse")) @@ -11,8 +10,7 @@ from profiler.advisor.analyzer.analyzer_controller import AnalyzerController from profiler.advisor.interface.interface import Interface from profiler.prof_common.path_manager import PathManager -logger = logging.getLogger() -logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) class advisor_adaptor: @@ -38,7 +36,8 @@ class advisor_adaptor: PathManager.check_input_directory_path(output_path) PathManager.input_path_common_check(output_path) - return PathManager.check_path_owner_consistent(output_path) + PathManager.check_path_owner_consistent(output_path) + return True def analyze(self, input_profiling_path, output_path): if self._check_profiling_path_valid(input_profiling_path) and self._check_output_path_valid(output_path): diff --git a/profiler/precheck/collect/collector.py b/profiler/precheck/collect/collector.py index a1a2ce2f8d..e7fb8ce9cc 100644 --- a/profiler/precheck/collect/collector.py +++ b/profiler/precheck/collect/collector.py @@ -1,6 +1,6 @@ import sys import os -from typing import Any +from typing import Any, Dict sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) @@ -216,7 +216,7 @@ class Collector: try: # 设置环境变量,这些会在torch.dist中用到 # 因为master node rank为0, 所以global rank直接等于local rank - GroupManager().set_env(master_env) + master_env.set_env() self.init(master_env) start_event = create_npu_event(self.stream) @@ -232,6 +232,9 @@ class Collector: gather_tensor = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.int64, device=self.device) # 分为 (file_size, file_hash) dist.init_process_group(backend='hccl', rank=self.rank, world_size=self.world_size) + if not (dist.is_available() and dist.is_initialized()): + raise RuntimeError("Distributed environment is not available") + file_sizes_hash, wait_time, transfer_time = self.gather_rank_data(group=dist.group.WORLD, gather_tensor=gather_tensor, all_gather=True) @@ -279,16 +282,16 @@ class Collector: self.logger.info("[Rank %d] master rank not in sub group" % self.rank) dist.barrier() except Exception as e: - self.logger.error("%s", e, exc_info=True) + self.logger.error("%s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) raise e finally: dist.destroy_process_group() - def slave_node_run(self, input_file_dir, master_rank_num, slave_env: EnvGroup): + def slave_node_run(self, slave_env: EnvGroup, input_file_dir, master_rank_num): try: self.logger.debug('Enter slave node run wrapper') # 设置环境变量,这些会在torch.dist中用到 - GroupManager().set_env(slave_env) + slave_env.set_env() self.init(slave_env) torch.npu.set_device(self.device) start_event = create_npu_event(self.stream) @@ -322,6 +325,9 @@ class Collector: gather_tensor = torch.tensor(file_hash_chunks, dtype=torch.int64, device=self.device) dist.init_process_group(backend='hccl', rank=self.rank, world_size=self.world_size) + if not (dist.is_available() and dist.is_initialized()): + raise RuntimeError("Distributed environment is not available") + 
file_sizes_hash, wait_time, transfer_time = self.gather_rank_data(group=dist.group.WORLD, gather_tensor=gather_tensor, all_gather=True) @@ -347,12 +353,12 @@ class Collector: self.logger.warning("[Rank %d] slave rank not in sub group" % (self.rank)) dist.barrier() except Exception as e: - self.logger.error("%s", e, exc_info=True) + self.logger.error("%s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) raise e finally: dist.destroy_process_group() - def run(self, args_dict: dict[str, Any]): + def run(self, args_dict: Dict[str, Any]): input_file_dir = args_dict.get("input_file_dir") output_file_dir = args_dict.get("output_file_dir") nnodes = args_dict.get("nnodes") @@ -399,7 +405,7 @@ class Collector: raise TimeoutError("Timeout reached. Terminating all subprocesses.") except TimeoutError as e: - self.logger.error("%s", e, exc_info=True) + self.logger.error("%s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) for process in processes: if process.is_alive(): process.terminate() @@ -413,7 +419,7 @@ class Collector: rank = node_rank + master_rank_num - 1 slave_env = EnvGroup(rank=rank, local_rank=0, world_size=world_size, master_addr=master_addr, master_port=master_port, group_rank=node_rank, local_world_size=1) - self.slave_node_run(input_file_dir, master_rank_num, slave_env) + self.slave_node_run(slave_env, input_file_dir, master_rank_num) if __name__ == "__main__": @@ -448,5 +454,5 @@ if __name__ == "__main__": try: collector.run(args_dict) except Exception as e: - logger.error("%s", e, exc_info=True) + logger.error("%s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING) raise e diff --git a/profiler/precheck/common/constant.py b/profiler/precheck/common/constant.py index 181e6ce80e..74a7b503cf 100644 --- a/profiler/precheck/common/constant.py +++ b/profiler/precheck/common/constant.py @@ -9,7 +9,7 @@ class Constant: UNZIP_DISK_SIZE_RAIO = 1.0 # 需要x倍压缩文件的空间进行解压操作 DEFAULT_TIME_OUT = 1200 - ARG_MAX_LEN = 100 # 参数最大长度 + ARG_MAX_LEN = 255 # 参数最大长度 ARG_MIN_INT_VALUE = - (1 << 31) # 32位整数最小值 ARG_MAX_INT_VALUE = (1 << 31) - 1 # 32位整数最大值 ARG_MIN_PORT_VALUE = 0 @@ -21,15 +21,19 @@ class Constant: COLLECTOR_DEFAULT_TIMEOUT = 1200 # seconds COLLECTOR_SPLIT_FILE_SIZE = None # 文件传输的split块大小,默认split size设为根据显存自动计算 LOCALHOST_ADDRESSES = {'localhost', '127.0.0.1'} - + MAX_ARCHIVE_SIZE = 20 * 1024 * 1024 * 1024 # 20 GB MAX_ARCHIVE_FILE_COUNT = 10000 MAX_ARCHIVE_RATIO = 10 DEFAULT_PROFILING_COMMANDS = { - "[resnet]": "[resnet]", + "[resnet]": "resnet", } + MS_PROF_PRECHECK_CMD = "msprof-analyze precheck" + + ENABLE_STACKTRACE_LOGGING = False + class TimeConstant: """Time related constants""" diff --git a/profiler/precheck/common/utils.py b/profiler/precheck/common/utils.py index 25cf23bb60..e3f7db1518 100644 --- a/profiler/precheck/common/utils.py +++ b/profiler/precheck/common/utils.py @@ -1,20 +1,16 @@ - import os import sys import hashlib import subprocess import logging from datetime import datetime -import torch_npu +import torch_npu from profiler.precheck.common.constant import TimeConstant - -logger = logging.getLogger() - -current_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.append(os.path.join(current_path, '..')) from profiler.prof_common.path_manager import PathManager +logger = logging.getLogger(__name__) + def get_file_md5(filepath, chunk_size=4096, split_hash_size=4): PathManager.check_input_file_path(filepath) @@ -24,7 +20,7 @@ def get_file_md5(filepath, chunk_size=4096, split_hash_size=4): for chunk in iter(lambda: file.read(chunk_size), b""): 
md5_hash.update(chunk) hash_bytes = int(md5_hash.hexdigest(), 16).to_bytes(16, 'big') - + chunks = [] for i in range(0, 16, split_hash_size): chunks.append(int.from_bytes(hash_bytes[i:i + split_hash_size], 'big')) @@ -45,7 +41,7 @@ def get_quick_hash(file_path, sample_size=65536, hash_spilt_size=4): f.seek(-sample_size, 2) hash_md5.update(f.read(sample_size)) hash_bytes = int(hash_md5.hexdigest(), 16).to_bytes(16, 'big') - + chunks = [] for i in range(0, 16, hash_spilt_size): chunks.append(int.from_bytes(hash_bytes[i:i + hash_spilt_size], 'big')) @@ -60,15 +56,38 @@ def is_equal_file_hash(chunks1, chunks2): def cat_files(output_file, input_files): + """ + Concatenate multiple binary input files into a single output file using cat command. + + Args: + output_file (str): Path to the output file + input_files (list): List of input file paths to concatenate + + Returns: + bool: True if concatenation was successful + + Raises: + subprocess.CalledProcessError: If the cat command fails + """ PathManager.check_input_file_path(output_file) - cmd = ["cat", *list(input_files), ">>", output_file] - with open(output_file, 'w') as outfile: - result = subprocess.run(cmd, stdout=outfile, capture_output=True, text=True) - if result.returncode == 0: - return True - else: - logger.error("Occurred during concatenation. ERROR: {}".format(result.stderr)) - raise subprocess.CalledProcessError(result.returncode, cmd, output=result.stdout, stderr=result.stderr) + cmd = ["cat"] + list(input_files) + + try: + with open(output_file, 'wb') as outfile: + result = subprocess.run(cmd, stdout=outfile, stderr=subprocess.PIPE) + + if result.returncode == 0: + return True + else: + logger.error("Error occurred during concatenation: %s", + result.stderr.decode('utf-8', errors='replace')) + raise subprocess.CalledProcessError(result.returncode, cmd, + output=None, + stderr=result.stderr) + + except OSError as e: + logger.error("OS error occurred during file operation: %s", str(e)) + raise def compress_directory(src_dir, output_file): @@ -158,14 +177,14 @@ def check_file_owner_and_permission(file_path): RuntimeError: If file not found, not owned by current user, or has wrong permissions """ PathManager.check_path_readable(file_path) - + if not os.path.isfile(file_path): raise RuntimeError(f"File not found at {file_path}") - + # Check file owner if os.stat(file_path).st_uid != os.getuid(): raise RuntimeError(f"File {file_path} is not owned by current user") - + # Check file permissions (only owner should have write permission) current_mode = os.stat(file_path).st_mode desired_mode = 0o700 # rwx------ (only owner has read/write/execute) diff --git a/profiler/precheck/examples/__init__.py b/profiler/precheck/examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/precheck/examples/profiler/__init__.py b/profiler/precheck/examples/profiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/precheck/examples/profiler/dynamic_prof.py b/profiler/precheck/examples/profiler/dynamic_prof.py new file mode 100644 index 0000000000..f4b1e9b849 --- /dev/null +++ b/profiler/precheck/examples/profiler/dynamic_prof.py @@ -0,0 +1,71 @@ +import json +import os +import logging +from copy import deepcopy + +logger = logging.getLogger(__name__) + +DEFAULT_DP_CONFIG = { + "activities": ["CPU", "NPU"], + "prof_dir": "./prof_result", + "analyse": False, + "record_shapes": False, + "profile_memory": False, + "with_stack": False, + "with_flops": False, + "with_modules": False, + 
"active": 1, + "is_rank": False, + "rank_list": [], + "experimental_config": { + "profiler_level": "Level0", + "aic_metrics": "AiCoreNone", + "l2_cache": False, + "op_attr": False, + "gc_detect_threshold": None, + "data_simplification": True, + "record_op_args": False, + "export_type": "text", + "msprof_tx": False + } +} + + +def _get_prof_config_json(prof_dp_path): + prof_config_json = os.path.join(prof_dp_path, "profiler_config.json") + return prof_config_json + + +def _set_default_prof_config(prof_config_json): + with open(prof_config_json, "w") as f: + json.dump(DEFAULT_DP_CONFIG, f, indent=4) + + +def get_dynamic_prof_config_path(): + cwd = os.path.dirname(os.path.realpath(__file__)) + prof_dp_path = os.path.join(cwd, './local_config/config_dynamic') + + prof_config_json = _get_prof_config_json(prof_dp_path) + os.makedirs(os.path.dirname(prof_config_json), exist_ok=True) + + if not os.path.exists(prof_config_json): + _set_default_prof_config(prof_config_json) + logger.info("Created default dynamic profiler config file at {}".format(prof_config_json)) + + return prof_dp_path + + +def start_dynamic_profiler(prof_dp_path, prof_save_dir): + prof_config_json = _get_prof_config_json(prof_dp_path) + if prof_save_dir is not None: + if not os.path.exists(prof_config_json): + data = deepcopy(DEFAULT_DP_CONFIG) + else: + with open(prof_config_json, 'r') as f: + data = json.load(f) + data['prof_dir'] = prof_save_dir + + with open(prof_config_json, 'w') as f: + json.dump(data, f, indent=4) + + logger.info('has started dynamic profiling') diff --git a/profiler/precheck/examples/profiler/models.py b/profiler/precheck/examples/profiler/models.py new file mode 100644 index 0000000000..4a0f8cc0de --- /dev/null +++ b/profiler/precheck/examples/profiler/models.py @@ -0,0 +1,67 @@ +import logging +from typing import Dict, Any, Tuple + +import torch +import torch.nn as nn +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + + +# ============= Models ============= +class SimpleResNet(nn.Module): + def __init__(self, num_classes: int = 10): + super().__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.fc = nn.Linear(64 * 56 * 56, num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + return x + + +# ============= Datasets ============= +class DummyImageDataset(Dataset): + def __init__(self, input_shape: Tuple[int, ...], num_samples: int = 100): + self.input_shape = input_shape + self.num_samples = num_samples + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + x = torch.randn(self.input_shape) + y = torch.randint(0, 10, ()) + return x, y + + +# ============= Example Registry ============= +class ExampleRegistry: + @staticmethod + def get_example_config(example_name: str) -> Dict[str, Any]: + configs = { + "resnet": { + "model_class": SimpleResNet, + "model_args": {"num_classes": 10}, + "dataset_class": DummyImageDataset, + "dataset_args": {"input_shape": (3, 224, 224), "num_samples": 800}, + "batch_size": 8, + }, + } + + if example_name not in configs: + available_models = ", ".join(configs.keys()) + raise ValueError( + f"Unknown example name: {example_name}. 
" + f"Available models are: {available_models}" + ) + + return configs[example_name] diff --git a/profiler/precheck/examples/profiler/train_with_profiler.py b/profiler/precheck/examples/profiler/train_with_profiler.py new file mode 100644 index 0000000000..08bef3541c --- /dev/null +++ b/profiler/precheck/examples/profiler/train_with_profiler.py @@ -0,0 +1,286 @@ +""" +Example Usage: +1. Single node training examples: +torchrun --nproc_per_node=8 \ + --nnodes=1 \ + --node_rank=0 \ + --master_addr="127.0.0.1" \ + --master_port=29500 \ + train_with_profiler.py \ + --example_name bert \ + --prof_output_dir ./profiler_output + +2. Distributed training examples: + + # Multiple nodes (2 nodes, 8 GPUs each) + # On node 0 (master node): + torchrun --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=0 \ + --master_addr="192.168.1.1" \ + --master_port=29500 \ + train_with_profiler.py \ + --example_name bert \ + --prof_output_dir ./profiler_output + + # On node 1: + torchrun --nproc_per_node=8 \ + --nnodes=2 \ + --node_rank=1 \ + --master_addr="192.168.1.1" \ + --master_port=29500 \ + train_with_profiler.py \ + --example_name bert \ + --prof_output_dir ./profiler_output + +Distributed Training Parameters: +--nproc_per_node: Number of processes per node (typically number of GPUs) +--nnodes: Total number of nodes +--node_rank: Rank of current node (0 to nnodes-1) +--master_addr: IP address of master node +--master_port: Port for master node communication + +Available Models: +- resnet: ResNet model implementation + +Environment Variables (automatically set by torchrun): +- RANK: Global rank of the process +- WORLD_SIZE: Total number of processes +- LOCAL_RANK: Local rank within the current node +- MASTER_ADDR: Master node address +- MASTER_PORT: Master node port +""" + +import os +import argparse +import ipaddress +import datetime +import logging +from typing import Optional, List + +import torch +import torch_npu +import torch.nn as nn +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm + +try: + from torch_npu.profiler import dynamic_profile as dp +except ImportError: + dp = None + +from profiler.precheck.examples.profiler.models import ExampleRegistry +from profiler.precheck.examples.profiler.dynamic_prof import get_dynamic_prof_config_path +from profiler.precheck.common.constant import Constant + +logger = logging.getLogger(__name__) + + +class ProfilerCallback: + """Callback for handling profiling operations""" + + def __init__(self, prof_save_dir, + is_dynamic=False, dynamic_prof_path=None): + self.profiler = None + self.is_dynamic = is_dynamic + if is_dynamic: + self.dynamic_prof_path = dynamic_prof_path if dynamic_prof_path else get_dynamic_prof_config_path() + self.prof_save_dir = prof_save_dir + + def on_train_begin(self): + if self.is_dynamic: + dp.init(self.dynamic_prof_path) + dist.barrier() + if dist.get_rank() == 0: + from profiler.precheck.examples.profiler.dynamic_prof import start_dynamic_profiler + start_dynamic_profiler(self.dynamic_prof_path, + self.prof_save_dir) + self.profiler = dp + else: + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level2, + l2_cache=False, + data_simplification=False + ) + self.profiler = torch_npu.profiler.profile( + activities=[ + torch_npu.profiler.ProfilerActivity.CPU, + torch_npu.profiler.ProfilerActivity.NPU + ], + with_stack=True, + record_shapes=True, + 
profile_memory=True, + schedule=torch_npu.profiler.schedule( + wait=5, warmup=5, active=20, repeat=1, skip_first=10), + experimental_config=experimental_config, + with_flops=True, + with_modules=True, + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler( + self.prof_save_dir) + ) + self.profiler.__enter__() + + def on_step_end(self): + if self.profiler: + self.profiler.step() + + def on_train_end(self): + if not self.is_dynamic and self.profiler: + self.profiler.__exit__(None, None, None) + + +class Trainer: + def __init__( + self, + model: nn.Module, + dataloader: Optional[Dataset] = None, + callbacks: Optional[List[ProfilerCallback]] = None, + criterion: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + ): + self.model = model + self.dataloader = dataloader + self.callbacks = callbacks or [] + + # Setup loss and optimizer with defaults + self.criterion = criterion or nn.CrossEntropyLoss() + self.optimizer = optimizer or torch.optim.Adam(self.model.parameters()) + + # get dist config from env + self.rank = int(os.environ.get("RANK", 0)) + self.world_size = int(os.environ.get("WORLD_SIZE", 1)) + self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) + self.device = f"npu:{self.local_rank}" + + # Setup device and distributed training + self.setup_distributed(self.rank, self.world_size, self.local_rank) + + # Move model and criterion to device + self.model = self.model.to(self.device) + self.criterion = self.criterion.to(self.device) + + @staticmethod + def setup_distributed(rank, world_size, local_rank): + if dist.is_initialized(): + return + + torch.npu.set_device(local_rank) + dist.init_process_group( + backend='hccl', + rank=rank, + world_size=world_size, + timeout=datetime.timedelta(seconds=1800) + ) + logger.info(f"[Rank {rank}] Initialized distributed training") + + def cleanup(self): + """Explicitly cleanup distributed training resources""" + if dist.is_initialized(): + dist.destroy_process_group() + logger.info(f"[Rank {self.rank}] Destroyed distributed training") + + def train(self, epoch: int = 1): + # Call training start callbacks + for callback in self.callbacks: + callback.on_train_begin() + + # Training loop + for epoch_idx in range(epoch): + if self.rank == 0: + pbar = tqdm( + total=len(self.dataloader), + desc=f'Epoch {epoch_idx + 1}/{epoch}', + unit='batch' + ) + + for step, (inputs, labels) in enumerate(self.dataloader): + # Move data to device + inputs = inputs.to(self.device) + labels = labels.to(self.device) + + # Forward pass + self.optimizer.zero_grad() + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + + # Backward pass + loss.backward() + self.optimizer.step() + + if self.rank == 0: + pbar.update(1) + pbar.set_postfix({ + 'step': f'{step + 1}/{len(self.dataloader)}', + 'loss': f'{loss.item():.4f}' + }) + + dist.barrier() + + # Call step end callbacks + for callback in self.callbacks: + callback.on_step_end() + + if self.rank == 0: + pbar.close() + + # Call training end callbacks + for callback in self.callbacks: + callback.on_train_end() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--example_name', default='resnet', + choices=['resnet'], + help='Name of the example to run') + parser.add_argument('--prof_output_dir', required=True) + parser.add_argument('--static', action='store_true', required=False, default=False) + args = parser.parse_args() + + # Get example configuration + example_config = ExampleRegistry.get_example_config(args.example_name) + + # Create model and 
dataset + model = example_config["model_class"](**example_config["model_args"]) + dataset = example_config["dataset_class"](**example_config["dataset_args"]) + + # Create loss and optimizer (可选,使用默认值也可以) + criterion = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + + # Create profiler callback + profiler_callback = ProfilerCallback( + args.prof_output_dir, + is_dynamic=(not args.static) + ) + + dataloader = DataLoader(dataset, batch_size=example_config["batch_size"]) + + # Initialize trainer + trainer = Trainer( + model=model, + dataloader=dataloader, + callbacks=[profiler_callback], + criterion=criterion, # 可选 + optimizer=optimizer, # 可选 + ) + + try: + trainer.train() + finally: + trainer.cleanup() + + +if __name__ == '__main__': + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + + try: + main() + except Exception as e: + logger.error(f"Unexpected error: {e}", exc_info=Constant.ENABLE_STACKTRACE_LOGGING) + raise diff --git a/profiler/precheck/examples/run_precheck.sh b/profiler/precheck/examples/run_precheck.sh deleted file mode 100644 index 59ed6ce852..0000000000 --- a/profiler/precheck/examples/run_precheck.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -# Define node IPs -nodes_ip1=( - "${NODE_IP_0:-7.210.189.120}" - "${NODE_IP_1:-7.210.189.129}" -) - -echo "Starting distributed precheck with ${#nodes_ip1[@]} nodes" -echo "Master node: ${nodes_ip1[0]}" -echo "All nodes: ${nodes_ip1[@]}" - -# Run precheck with distributed configuration -python3 -m profiler.precheck \ - --host_ips ${nodes_ip1[@]} \ - --master_addr ${nodes_ip1[0]} \ - --master_port 29500 \ - --nnodes ${#nodes_ip1[@]} \ - --nproc_per_node 8 \ - --output_dir ./output \ - --static - -echo "Precheck completed" diff --git a/profiler/precheck/examples/scripts/precheck_run_llama2.sh b/profiler/precheck/examples/scripts/precheck_run_llama2.sh new file mode 100644 index 0000000000..e3bf0859e7 --- /dev/null +++ b/profiler/precheck/examples/scripts/precheck_run_llama2.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True + +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +MASTER_ADDR=${MASTER_ADDR:-"192.168.0.1"} +MASTER_PORT=${MASTER_PORT:-6000} +NNODES=${NNODES:-2} +NODE_RANK=${NODE_RANK:-0} +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CKPT_SAVE_DIR=${CKPT_SAVE_DIR:-"./ckpt/llama-2-7b"} +CKPT_LOAD_DIR=${CKPT_LOAD_DIR:-"./model_weights/llama-2-7b-legacy"} +TOKENIZER_MODEL=${TOKENIZER_MODEL:-"./model_from_hf/llama-2-7b-hf/tokenizer.model"} +DATA_PATH=${DATA_PATH:-"./dataset/enwiki_text_document"} + +TP=${TP:-2} +PP=${PP:-4} + +# Result directory +OUTPUT_DIR=${OUTPUT_DIR:-"./result/precheck/llama2-1129-2130"} + +PROF_NODE_RES_DIR="$OUTPUT_DIR/node_prof_save_dir" +LOG_FILE="$OUTPUT_DIR/precheck.log" + +# Check if profiling output directory exists before running training +# This prevents starting a long training job if the directory is missing +if [ ! -d "$OUTPUT_DIR" ]; then + echo "Error: Result directory $OUTPUT_DIR does not exist." 
\ + "Please create the directory before running training" \ + "(in ${BASH_SOURCE[0]})" >&2 + exit 1 +fi + +# Get the directory of the current script and cd into it +# SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +# echo "Script directory: $SCRIPT_DIR" +# cd "$SCRIPT_DIR" +# echo "Changed working directory to: $(pwd)" + + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +GPT_ARGS=" + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --sequence-parallel \ + --num-layers 32 \ + --hidden-size 4096 \ + --ffn-hidden-size 11008 \ + --num-attention-heads 32 \ + --tokenizer-type Llama2Tokenizer \ + --tokenizer-model ${TOKENIZER_MODEL} \ + --seq-length 4096 \ + --max-position-embeddings 4096 \ + --micro-batch-size 1 \ + --global-batch-size 256 \ + --make-vocab-size-divisible-by 1 \ + --lr 1.25e-6 \ + --train-iters 5 \ + --lr-decay-style cosine \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.01 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rmsnorm \ + --swiglu \ + --use-flash-attn \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --min-lr 1.25e-7 \ + --weight-decay 1e-1 \ + --lr-warmup-fraction 0.01 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --initial-loss-scale 65536 \ + --adam-beta2 0.95 \ + --no-gradient-accumulation-fusion \ + --no-load-optim \ + --no-load-rng \ + --use-distributed-optimizer \ + --use-fused-swiglu \ + --use-fused-rotary-pos-emb \ + --overlap-grad-reduce \ + --bf16" + +DATA_ARGS=" \ + --data-path $DATA_PATH \ + --split 949,50,1" + +PROFILE_ARGS=" \ + --profile \ + --profile-step-start 2 \ + --profile-step-end 4 \ + --profile-ranks -1 \ + --profile-level level0 \ + --profile-with-cpu \ + --profile-save-path $PROF_NODE_RES_DIR" + +OUTPUT_ARGS=" \ + --log-interval 1 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 0" + +# Add precheck arguments +# PRECHECK_ARGS=" \ +# --do_precheck" + +torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + $PROFILE_ARGS \ + --distributed-backend nccl \ + --load $CKPT_LOAD_DIR \ + --save $CKPT_SAVE_DIR \ + | tee $LOG_FILE diff --git a/profiler/precheck/examples/run_llama2_precheck.sh b/profiler/precheck/examples/scripts/run_llama2_precheck.sh similarity index 65% rename from profiler/precheck/examples/run_llama2_precheck.sh rename to profiler/precheck/examples/scripts/run_llama2_precheck.sh index c431e5563e..495dab8ca6 100644 --- a/profiler/precheck/examples/run_llama2_precheck.sh +++ b/profiler/precheck/examples/scripts/run_llama2_precheck.sh @@ -2,17 +2,17 @@ # You should set the IP addresses of the nodes in the NODES_IP variable # Change the IP addresses to the actual IP addresses of your nodes -NODES_IP="7.210.189.120 7.210.189.129" +NODES_IP="${NODES_IP:-192.168.0.1,192.168.0.2}" -# Convert NODES_IP to an array nodes_ip -IFS=' ' read -r -a nodes_ip <<< "$NODES_IP" +# Convert comma-separated NODES_IP to an array nodes_ip +IFS=',' read -r -a nodes_ip <<< "$NODES_IP" echo "Starting distributed precheck with ${#nodes_ip[@]} nodes" echo "Master node: ${nodes_ip[0]}" echo "All nodes: ${nodes_ip[*]}" -output_dir_base="./result/precheck" +output_dir_base="./result/demo_precheck" # Add timestamp to task name timestamp=$(date +"%Y%m%d_%H%M%S") @@ -21,9 +21,12 @@ 
task_name="llama2-demo_${timestamp}" output_dir="${output_dir_base}/${task_name}" node_prof_save_dir="${output_dir}/node_prof_save_dir" +# Join array elements with commas +host_ips=$(IFS=,; echo "${nodes_ip[*]}") + # Run precheck with distributed configuration -python3 -m profiler.precheck \ - --host_ips "${nodes_ip[*]}" \ +msprof-analyze precheck start_all \ + --host_ips "${host_ips}" \ --master_addr "${nodes_ip[0]}" \ --master_port 29500 \ --nnodes ${#nodes_ip[@]} \ @@ -31,7 +34,7 @@ python3 -m profiler.precheck \ --output_dir ${output_dir_base} \ --task_name ${task_name} \ --node_prof_save_dir ${node_prof_save_dir} \ - --profiling_cmd "OUTPUT_DIR=${output_dir} bash ./examples/legacy/llama2/precheck_run_llama2.sh" \ + --profiling_cmd "OUTPUT_DIR=${output_dir} bash ./examples/scripts/precheck_run_llama2.sh" \ --static echo "Precheck completed" diff --git a/profiler/precheck/examples/scripts/run_precheck.sh b/profiler/precheck/examples/scripts/run_precheck.sh new file mode 100644 index 0000000000..bf5b3b89cf --- /dev/null +++ b/profiler/precheck/examples/scripts/run_precheck.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# You should set the IP addresses of the nodes in the NODES_IP variable +# Change the IP addresses to the actual IP addresses of your nodes +NODES_IP="${NODES_IP:-192.168.0.1,192.168.0.2}" + + +# Convert comma-separated NODES_IP to an array nodes_ip +IFS=',' read -r -a nodes_ip <<< "$NODES_IP" + +timestamp=$(date +"%Y%m%d_%H%M%S") +task_name="task_demo_${timestamp}" + +echo "Starting distributed precheck with ${#nodes_ip[@]} nodes" +echo "Master node: ${nodes_ip[0]}" +echo "All nodes: ${nodes_ip[@]}" + +output_dir=./output_test + +PROFILING_CMD="[resnet]" + +# Join array elements with commas +host_ips=$(IFS=,; echo "${nodes_ip[*]}") + +# Run precheck with distributed configuration +msprof-analyze precheck start_all \ + --host_ips "${host_ips}" \ + --master_addr ${nodes_ip[0]} \ + --master_port 29500 \ + --nnodes ${#nodes_ip[@]} \ + --nproc_per_node 8 \ + --output_dir "${output_dir}" \ + --task_name ${task_name} \ + --profiling_cmd "${PROFILING_CMD}" \ + --static + +echo "Precheck completed" diff --git a/profiler/precheck/manager/__init__.py b/profiler/precheck/manager/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/precheck/manager/args_manager.py b/profiler/precheck/manager/args_manager.py index c1094da990..2386b5eb31 100644 --- a/profiler/precheck/manager/args_manager.py +++ b/profiler/precheck/manager/args_manager.py @@ -7,6 +7,7 @@ import shutil import sys import logging from typing import List, Union +from collections import OrderedDict from profiler.precheck.common.constant import Constant from profiler.precheck.common.utils import cn_now @@ -249,9 +250,8 @@ class BaseArgsManager: self._check_profiling_cmd_valid(self.profiling_cmd) def _check_profiling_cmd_valid(self, profiling_cmd: str) -> None: - if not profiling_cmd: - logger.warning('Profiling command is not provided.') - return + if not profiling_cmd.strip(): + logger.error('Profiling command should not be empty.') if profiling_cmd in Constant.DEFAULT_PROFILING_COMMANDS: logger.info('Using default profiling command for {}', profiling_cmd) @@ -276,16 +276,17 @@ class PrecheckArgsManager(BaseArgsManager): self._args = args self._ssh_remote_hosts = {} + self._host_ips = [] self.check_args() @property def host_ips(self): - return self._args.host_ips + return self._host_ips @property - def ssh_config_file(self): - return self._args.ssh_config_file + def host_config_file(self): + return 
self._args.host_config_file @property def ssh_remote_hosts(self): @@ -305,7 +306,7 @@ class PrecheckArgsManager(BaseArgsManager): @classmethod def _check_host_ips_valid(cls, host_ips: List[str]) -> Union[Exception, None]: if not host_ips: - return ValueError("Host IPs must be provided.") + return None for i, ip in enumerate(host_ips): if not ipaddress.ip_address(ip): @@ -316,30 +317,32 @@ class PrecheckArgsManager(BaseArgsManager): return None - def try_to_parse_ssh_config_file(self, ssh_config_file: str) -> Union[Exception, dict]: - if not ssh_config_file: + def try_to_parse_host_config_file(self, host_config_file: str) -> Union[Exception, OrderedDict]: + if not host_config_file: logger.info("SSH config file is not provided.") logger.info("Use default ssh settings for all nodes: ssh_key_file, user, port = ~/.ssh/id_rsa, $USER, 22") return {} - if not os.path.isfile(ssh_config_file): - return FileNotFoundError(f"SSH config file {ssh_config_file} does not exist.") + if not os.path.isfile(host_config_file): + return FileNotFoundError(f"SSH config file {host_config_file} does not exist.") - PathManager.check_path_readable(ssh_config_file) + PathManager.check_path_readable(host_config_file) ssh_remote_hosts = [] required_fields = ['host_ip', 'ssh_key_file', 'user', 'port'] - with open(ssh_config_file, 'r') as f: + with open(host_config_file, 'r') as f: header = f.readline().strip().split(',') if any(field not in header for field in required_fields): - return ValueError(f"SSH config file {ssh_config_file} is missing required fields: {required_fields}") + return ValueError(f"Host config file {host_config_file} is missing required fields: {required_fields}") for line in f: values = line.strip().split(',') if len(values) != len(required_fields): - return ValueError(f"SSH config file {ssh_config_file} has invalid number of fields in line: {line}") + return ValueError( + f"Host config file {host_config_file} has invalid number of fields in line: {line}") host_ip, ssh_key_file, user, port = values + port = int(port) exception = None try: @@ -357,7 +360,7 @@ class PrecheckArgsManager(BaseArgsManager): if exception: return RuntimeError( - f"SSH config file {ssh_config_file} is not valid, invalid line: {line}, error: {exception}") + f"Host config file {host_config_file} is not valid, invalid line: {line}, error: {exception}") ssh_remote_hosts.append({ 'host': host_ip, @@ -366,8 +369,8 @@ class PrecheckArgsManager(BaseArgsManager): 'port': int(port) }) - self._ssh_remote_hosts = {item['host']: item for item in ssh_remote_hosts} - return self._ssh_remote_hosts + ssh_remote_hosts = OrderedDict({item['host']: item for item in ssh_remote_hosts}) + return ssh_remote_hosts def check_args(self): super().check_args() @@ -376,25 +379,46 @@ class PrecheckArgsManager(BaseArgsManager): if error: self.raise_error('Python path {} is not valid: {}', self.python_path, error) - error = self._check_host_ips_valid(self.host_ips) - if error: - self.raise_error('Host ips {} is not valid: {}', self.host_ips, error) + # Ensure either host_ips or host_config_file is provided + if not self.host_config_file and not self._args.host_ips: + self.raise_error('Either host config file or host ips must be provided') + + # If host_ips is provided, validate it first + if self._args.host_ips: + error = self._check_host_ips_valid(self._args.host_ips) + if error: + self.raise_error('Host ips {} is not valid: {}', self._args.host_ips, error) + + # Set the validated host_ips + self._host_ips = self._args.host_ips + + # If config file is 
provided, parse and validate it + if self.host_config_file: + res = self.try_to_parse_host_config_file(self.host_config_file) + if isinstance(res, Exception): + self.raise_error('Host config file {} is not valid: {}', self.host_config_file, res) + self._ssh_remote_hosts = res + config_file_ips = list(self._ssh_remote_hosts.keys()) + + # If host_ips is also provided, verify they match + if self._args.host_ips: + if not set(self._args.host_ips) == set(config_file_ips): + self.raise_error('Host ips does not match the IPs in host config file. Given: {}, In file: {}', + self._args.host_ips, config_file_ips) + else: + # If only config file is provided, use IPs from the config file + self._host_ips = config_file_ips + # Validate number of nodes and master node configuration if self.nnodes != len(self.host_ips): self.raise_error( - 'The number of nodes {} is not equal to the number of host ips {}', self.nnodes, len(self.host_ips)) + 'The number of nodes {} is not equal to the number of host ips {}', + self.nnodes, len(self.host_ips)) if self.master_addr != self.host_ips[0]: self.raise_error( - 'The master address {} is not the first host ip {}', self.master_addr, self.host_ips[0]) - - res = self.try_to_parse_ssh_config_file(self.ssh_config_file) - if isinstance(res, Exception): - self.raise_error('SSH config file {} is not valid: {}', self.ssh_config_file, res) - - if self.ssh_config_file and not set(self.host_ips) == set(self.ssh_remote_hosts.keys()): - self.raise_error('Host ips is not equal to the host ips in ssh config file {}', self.args.host_ips, - self.ssh_config_file) + 'The master address {} is not the first host ip {}', + self.master_addr, self.host_ips[0]) class PrecheckRunnerArgsManager(BaseArgsManager): @@ -414,68 +438,3 @@ class PrecheckRunnerArgsManager(BaseArgsManager): error = self._check_int_range(self.node_rank, min_value=0, max_value=self.nnodes - 1) if error: self.raise_error('Node rank {} is not valid: {}', self.node_rank, error) - - -def _add_precheck_base_args(parser: argparse.ArgumentParser): - parser.add_argument('--master_addr', type=str, required=True, - help='IP address of the master node (first node in the cluster).') - - parser.add_argument('--master_port', type=int, default=29500, - help='Port on the master node for communication. 
Default is 29500.')
-
-    parser.add_argument('--nnodes', type=int, required=True,
-                        help='Total number of nodes in the distributed setup.')
-
-    parser.add_argument('--nproc_per_node', type=int, required=True,
-                        help='Number of processes to run per node '
-                             '(usually corresponds to the number of NPUs per node).')
-
-    parser.add_argument('--node_prof_save_dir', default='', type=str,
-                        help='Directory for saving node profiling data.')
-
-    parser.add_argument('--master_prof_gather_dir', default='', type=str,
-                        help='Directory for saving gathered profiling data in master node.')
-
-    parser.add_argument('--output_dir', default='./output', type=str,
-                        help='Directory to save profiling dump data, logs, and advisor reports.')
-
-    parser.add_argument('--task_name', default='', type=str,
-                        help='Name of the task or experiment, used to organize output directories.')
-
-    parser.add_argument('--static', action='store_true',
-                        help='If set, run profiling in static mode.')
-
-    parser.add_argument('--profiling_cmd', type=str, default="",
-                        help='Command to run the profiler script')
-
-
-def get_precheck_args():
-    parser = argparse.ArgumentParser()
-
-    _add_precheck_base_args(parser)
-
-    parser.add_argument('--host_ips', nargs='+', type=str, default=None,
-                        help='List of IP addresses for nodes in distributed training')
-
-    parser.add_argument('--python_path', default=sys.executable, type=str,
-                        help='Path to the Python interpreter to use for launching the distributed precheck.')
-
-    parser.add_argument('--ssh_configs', default='', type=str,
-                        help='Path to the SSH config file for nodes in distributed training, in csv format with fields: host_ip, ssh_key_file, user, port')
-
-    args = PrecheckArgsManager(parser.parse_args())
-
-    return args
-
-
-def get_precheck_runner_args():
-    parser = argparse.ArgumentParser()
-
-    _add_precheck_base_args(parser)
-
-    parser.add_argument('--node_rank', required=True, type=int,
-                        help='Rank of the current node (unique ID for each node in the cluster).')
-
-    args = PrecheckRunnerArgsManager(parser.parse_args())
-
-    return args
diff --git a/profiler/precheck/manager/disk_manager.py b/profiler/precheck/manager/disk_manager.py
index e149fa277f..a497c992cb 100644
--- a/profiler/precheck/manager/disk_manager.py
+++ b/profiler/precheck/manager/disk_manager.py
@@ -1,7 +1,7 @@
 import os
 import logging
 
-logger = logging.getLogger()
+logger = logging.getLogger(__name__)
 
 
 class DiskManager:
diff --git a/profiler/precheck/manager/group_manager.py b/profiler/precheck/manager/group_manager.py
index 2109a1465f..f78bf6dead 100644
--- a/profiler/precheck/manager/group_manager.py
+++ b/profiler/precheck/manager/group_manager.py
@@ -5,7 +5,6 @@ import torch.distributed as dist
 
 from profiler.advisor.utils.utils import singleton
 
-
 class EnvGroup:
     def __init__(self, rank, local_rank, world_size, master_addr, master_port, group_rank, local_world_size):
         self.rank = rank
@@ -39,8 +38,14 @@ class EnvGroup:
         if not isinstance(self.local_world_size, int):
             raise ValueError('local_world_size must be an integer')
 
-        if not (dist.is_available() and dist.is_initialized()):
-            raise RuntimeError("Distributed environment is not available")
+    def set_env(self):
+        os.environ["RANK"] = str(self.rank)
+        os.environ["LOCAL_RANK"] = str(self.local_rank)
+        os.environ["WORLD_SIZE"] = str(self.world_size)
+        os.environ["MASTER_ADDR"] = self.master_addr
+        os.environ["MASTER_PORT"] = str(self.master_port)
+        os.environ["GROUP_RANK"] = str(self.group_rank)
+        os.environ["LOCAL_WORLD_SIZE"] = str(self.local_world_size)
 
 
 class SubGroup:
@@ -85,16 +90,6 @@ class GroupManager:
         self._node_group = None
         self._sub_group_dict = {}
 
-    @staticmethod
-    def set_env(slave_env: EnvGroup):
-        os.environ["RANK"] = str(slave_env.rank)
-        os.environ["LOCAL_RANK"] = str(slave_env.local_rank)
-        os.environ["WORLD_SIZE"] = str(slave_env.world_size)
-        os.environ["MASTER_ADDR"] = slave_env.master_addr
-        os.environ["MASTER_PORT"] = str(slave_env.master_port)
-        os.environ["GROUP_RANK"] = str(slave_env.group_rank)
-        os.environ["LOCAL_WORLD_SIZE"] = str(slave_env.local_world_size)
-
     def get_rank(self):
         return self._rank
diff --git a/profiler/precheck/manager/task_manager.py b/profiler/precheck/manager/task_manager.py
index fe44c99641..3f0875b46e 100644
--- a/profiler/precheck/manager/task_manager.py
+++ b/profiler/precheck/manager/task_manager.py
@@ -1,21 +1,17 @@
-import sys
 import os
 import logging
 import argparse
 
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from analyze.advisor_adaptor import advisor_adaptor
-from analyze.module1 import module1
+from profiler.precheck.analyze.advisor_adaptor import advisor_adaptor
 from profiler.prof_common.path_manager import PathManager
 
-logger = logging.getLogger()
+logger = logging.getLogger(__name__)
+
 
 class TaskManager:
-    ADVISOR='advisor'
-    MODULE1='module1'
+    ADVISOR = 'advisor'
     supported_analyzer = {
-        ADVISOR : advisor_adaptor,
-        MODULE1 : module1,
+        ADVISOR: advisor_adaptor,
     }
     all_analyzer = list(supported_analyzer.keys())
@@ -34,21 +30,20 @@ class TaskManager:
     def get_result(analyzer_name, input_path, output):
         if analyzer_name not in TaskManager.all_analyzer:
-            logger.error("Error analyzer %s, supported analyzer are %s",analyzer_name ,TaskManager.all_analyzer)
-            raise ValueError("Error analyzer %s, supported analyzer are %s",analyzer_name ,TaskManager.all_analyzer )
-
+            logger.error("Error analyzer %s, supported analyzer are %s", analyzer_name, TaskManager.all_analyzer)
+            raise ValueError("Error analyzer %s, supported analyzer are %s", analyzer_name, TaskManager.all_analyzer)
+
         input_profiling_path_real = PathManager.get_realpath(input_path)
         output_path_real = PathManager.get_realpath(output)
         try:
             analyze = TaskManager.get_analyzer(analyzer_name)
             analyzer_instance = analyze()
-            result = analyzer_instance.analyze(input_profiling_path=input_profiling_path_real, output_path=output_path_real)
-
-        except Exception as e:
-            logger.error("%s is skipped when an exception is encountered. The exception is as follows: %s", analyzer_name, e)
-
-
+            result = analyzer_instance.analyze(input_profiling_path=input_profiling_path_real,
+                                               output_path=output_path_real)
+        except Exception as e:
+            logger.error("%s is skipped when an exception is encountered. The exception is as follows: %s",
+                         analyzer_name, e)
 
 
 def get_args():
@@ -80,6 +75,5 @@ if __name__ == "__main__":
         TaskManager.get_result(analyzer=analyzer, input_profiling_path=input_profiling_path, output_path=output_path)
-    except RuntimeError as error:
-        logger.error("[ERROR] {%s}",error)
-
\ No newline at end of file
+    except Exception as error:
+        logger.error("%s", error)
diff --git a/profiler/precheck/runner/__init__.py b/profiler/precheck/runner/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/profiler/precheck/runner/__main__.py b/profiler/precheck/runner/__main__.py
index d27e8a8e28..f5dde70a85 100644
--- a/profiler/precheck/runner/__main__.py
+++ b/profiler/precheck/runner/__main__.py
@@ -3,14 +3,13 @@ import subprocess
 import sys
 import os
 import logging
-import datetime
+from profiler.precheck.common.constant import Constant
 
 from profiler.precheck.common.logger import add_file_handler, create_logger
 from profiler.precheck.common.utils import check_file_owner_and_permission, cn_now
-from profiler.precheck.manager.args_manager import get_precheck_runner_args
+from profiler.precheck.manager.args_manager import PrecheckRunnerArgsManager
 from profiler.precheck.runner.runners import CollectorRunner, AdvisorRunner
 from profiler.precheck.manager.distribute_manager import DistributeManager
-from profiler.precheck.examples.profiler.profiler_demo import ProfilerRunner
 from profiler.prof_common.path_manager import PathManager
@@ -46,7 +45,7 @@ def get_conda_envs_info(python_path=sys.executable):
     return conda_env, conda_activate_script
 
 
-def main(args, logger):
+def start_precheck_runner(args: PrecheckRunnerArgsManager, logger: logging.Logger):
     logger.info("Starting precheck runner with arguments: %s", args)
 
     dist_config = DistributeManager(args)
@@ -66,27 +65,44 @@ def main(args, logger):
     # start profiling
     logger.info("Starting profiler runner")
 
-    if args.profiling_cmd:
-        # Build command list
-        conda_env, conda_activate_script = get_conda_envs_info()
+    conda_env, conda_activate_script = get_conda_envs_info()
+    profiler_example_name = Constant.DEFAULT_PROFILING_COMMANDS.get(args.profiling_cmd, None)
+    if profiler_example_name is None:
         profiling_cmd = [
             "/bin/bash", "-c",
             f"source {conda_activate_script} {conda_env} && cd {os.getcwd()} && "
-            f"NODE_RANK={dist_config.node_rank} {args.profiling_cmd}"
+            f"MASTER_ADDR={dist_config.master_addr} MASTER_PORT={dist_config.master_port} "
+            f"NNODES={dist_config.nnodes} NODE_RANK={dist_config.node_rank} "
+            f"NPROC_PER_NODE={dist_config.nproc_per_node} "
+            f"{args.profiling_cmd}"
         ]
-
-        logger.info("Using custom profiling command: %s", ' '.join(profiling_cmd))
-        try:
-            logger.info("Executing profiling command...")
-            subprocess.run(profiling_cmd, check=True, capture_output=False, text=True)
-            logger.info("Profiling command completed successfully")
-        except subprocess.CalledProcessError as e:
-            logger.error("Profiling command failed with error: %s", e, exc_info=False)
-            raise RuntimeError("Profiling command failed with error: %s" % e) from e
     else:
-        # Use default ProfilerRunner
-        ProfilerRunner(prof_res_dir=prof_node_res_dir, is_dynamic=(not args.static), config=dist_config).run()
+        profiler_example_base = os.path.join(os.path.dirname(os.path.dirname(__file__)), "examples", "profiler", )
+
+        profiling_cmd = [
+            "/bin/bash", "-c",
+            f"source {conda_activate_script} {conda_env} && cd {os.getcwd()} && "
+            f"torchrun "
+            f"--master_addr={dist_config.master_addr} "
+            f"--master_port={dist_config.master_port} "
+            f"--nproc_per_node={dist_config.nproc_per_node} "
+            f"--nnodes={dist_config.nnodes} "
+            f"--node_rank={dist_config.node_rank} "
+            f"{os.path.join(profiler_example_base, 'train_with_profiler.py')} "
+            f"--example_name {profiler_example_name} "
+            f"--prof_output_dir {prof_node_res_dir}"
+            + (" --static" if args.static else "")
+        ]
+
+    logger.info("Using custom profiling command: %s", ' '.join(profiling_cmd))
+    try:
+        logger.info("Executing profiling command...")
+        subprocess.run(profiling_cmd, check=True, capture_output=False, text=True)
+        logger.info("Profiling command completed successfully")
+    except subprocess.CalledProcessError as e:
+        logger.error("Profiling command failed with error: %s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING)
+        raise
 
     # zip and transport to master
     logger.info("Starting collector runner")
@@ -100,9 +116,8 @@ def main(args, logger):
     logger.info("Completed precheck runner execution")
 
 
-if __name__ == '__main__':
+def main(args=None):
     logger = create_logger("profiler.precheck", logging.DEBUG, use_memory_handler=True)
-    args = get_precheck_runner_args()
 
     output_dir = os.path.join(args.output_dir, args.task_name)
     PathManager.make_dir_safety(output_dir)
@@ -112,6 +127,7 @@ if __name__ == '__main__':
     logger = add_file_handler(logger, log_file_path)
 
     try:
-        main(args, logger)
+        start_precheck_runner(args, logger)
     except Exception as e:
-        logger.error("Precheck runner failed with error: %s", e, exc_info=False)
+        logger.error("Precheck runner failed with error: %s", e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING)
+
diff --git a/profiler/precheck/runner/runners.py b/profiler/precheck/runner/runners.py
index f1a1a086e7..347de9eea3 100644
--- a/profiler/precheck/runner/runners.py
+++ b/profiler/precheck/runner/runners.py
@@ -63,7 +63,7 @@ class CollectorRunner:
         logger.info('%s init', self.__class__.__name__)
 
     @staticmethod
-    def zip_directory(self, src_dir):
+    def zip_directory(src_dir):
        """Zip the specified directory."""
         zip_file_path = f"{src_dir}.zip"
 
@@ -85,6 +85,8 @@ class CollectorRunner:
                 src_dir=src_dir,
                 output_path=zip_file_path,
                 whitelist=Constant.PROFILER_FILE_PATTERNS,
+                use_regex=True,
+                regex_fullmatch=False,
             ))
             logger.info('Successfully created new zip file %s', zip_file_path)
 
@@ -100,17 +102,21 @@ class CollectorRunner:
         """Transport the zip file to the destination."""
 
         def run_collector(input_file_dir, output_file_dir: str, config: DistributeManager):
-            nnodes = config.nnodes
-            node_rank = config.node_rank
-            master_addr = config.master_addr
-            master_port = config.master_port
-            master_rank_num = Constant.COLLECTOR_MASTER_RANK_NUM
-            split_file_size = Constant.COLLECTOR_SPLIT_FILE_SIZE
-            timeout = Constant.COLLECTOR_DEFAULT_TIMEOUT
+            args_dict = {
+                "input_file_dir": input_file_dir,
+                "output_file_dir": output_file_dir,
+                "nnodes": config.nnodes,
+                "node_rank": config.node_rank,
+                "master_addr": config.master_addr,
+                "master_port": config.master_port,
+                "master_rank_num": Constant.COLLECTOR_MASTER_RANK_NUM,
+                "split_file_size": Constant.COLLECTOR_SPLIT_FILE_SIZE,
+                "time_out": Constant.COLLECTOR_DEFAULT_TIMEOUT,
+                "log_file": None
+            }
 
             from profiler.precheck.collect.collector import Collector
-            Collector().run(input_file_dir, output_file_dir, nnodes, node_rank, master_addr, master_port,
-                            master_rank_num, split_file_size, timeout)
+            Collector().run(args_dict)
 
         run_collector(zip_file, self.des_dir, self.config)
diff --git a/profiler/precheck/setup.py b/profiler/precheck/setup.py
deleted file mode 100644
index 8e4012b020..0000000000
--- a/profiler/precheck/setup.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import os
-
-from setuptools import setup, find_packages
-
-cwd = os.path.abspath(os.path.dirname(__file__))
-root_path = os.path.dirname(os.path.dirname(cwd))
-
-setup(
-    name='precheck',
-    version='0.2.2',
-    package_dir={"": root_path},  # 如果源代码都在当前目录下,可以保留这个
-    packages=find_packages(where=root_path),
-    entry_points={
-        "console_scripts": [
-            "precheck-cli=profiler.precheck.__main__:main"  # 确保路径正确
-        ]
-    },
-)
diff --git a/profiler/precheck/tools/__init__.py b/profiler/precheck/tools/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/profiler/precheck/tools/archive_utils.py b/profiler/precheck/tools/archive_utils.py
index 18ef7c13af..bb5ec0646e 100644
--- a/profiler/precheck/tools/archive_utils.py
+++ b/profiler/precheck/tools/archive_utils.py
@@ -108,14 +108,15 @@ def create_archive(archive_args: ArchiveConfig):
         return any(map(pattern_matcher, whitelist))
 
     # Get all files in source directory recursively
-    files = glob.glob('**/*', root_dir=archive_args.src_dir, recursive=True)
+    abs_files = glob.glob(os.path.join(archive_args.src_dir, '**', '*'), recursive=True)
+    files = [os.path.relpath(file, archive_args.src_dir) for file in abs_files]
 
     files_to_add = [
-        file for file in files
-        if should_include_file(file) and os.path.isfile(os.path.join(archive_args.src_dir, file))
+        file for file_abs_path, file in zip(abs_files, files)
+        if should_include_file(file) and os.path.isfile(file_abs_path)
    ]
 
-    logger.info("Has find %d files to add at path: %s", len(files_to_add), archive_args.src_dir)
+    logger.info("Has found %d files to add at path: %s", len(files_to_add), archive_args.src_dir)
 
     # Process files based on archive type (tar or zip)
     def add_files_to_tar(files_to_add):
diff --git a/profiler/precheck/tools/ssh_utils.py b/profiler/precheck/tools/ssh_utils.py
index 9a084e5ed9..e996676f15 100644
--- a/profiler/precheck/tools/ssh_utils.py
+++ b/profiler/precheck/tools/ssh_utils.py
@@ -117,7 +117,7 @@ def execute_ssh_command(config: SSHConfig, command: str) -> dict:
             'output': result.stdout
         }
     except subprocess.CalledProcessError as e:
-        logger.error("SSH command failed on %s: %s", config.host, e, exc_info=True)
+        logger.error("SSH command failed on %s: %s", config.host, e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING)
         return {
             'success': False,
             'output': e.stderr
@@ -183,7 +183,7 @@ def execute_ssh_command_in_tmux(config: SSHConfig, session_name: str, command: s
                     attach_cmd, config.username, config.host)
     except Exception as e:
-        logger.error("Failed to connect to %s: %s", config.host, e, exc_info=True)
+        logger.error("Failed to connect to %s: %s", config.host, e, exc_info=Constant.ENABLE_STACKTRACE_LOGGING)
         raise RuntimeError(f"Fail to start host {config.host}") from e
 
     return dict(
@@ -223,6 +223,7 @@ def run_remote_command(hosts_info: List[dict], session_name: str = None, using_t
         session_name = f"auto_{user}_{cn_now().strftime('%m%d')}"
 
     results = []
+
     for host_info in hosts_info:
         config = SSHConfig(
             host=host_info["host"],
-- 
Gitee

From 2d46e267a6408b882cfb905c8c222837c399f58b Mon Sep 17 00:00:00 2001
From: hhqx
Date: Sat, 21 Dec 2024 18:51:42 +0800
Subject: [PATCH 2/2] fix precheck_cli.py

fix precheck_cli.py
---
 profiler/cli/precheck_cli.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/profiler/cli/precheck_cli.py b/profiler/cli/precheck_cli.py
index 2928fedcd3..7b58d5924c 100644
--- a/profiler/cli/precheck_cli.py
+++ b/profiler/cli/precheck_cli.py
@@ -5,12 +5,7 @@ from functools import wraps
 
 import click
 
-from profiler.precheck.manager.args_manager import PrecheckArgsManager, PrecheckRunnerArgsManager
-from profiler.precheck.runner.__main__ import main as runner_main
-from profiler.precheck.__main__ import main as precheck_main
-
 logger = logging.getLogger(__name__)
-
 CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help'])
@@ -83,6 +78,9 @@ def precheck_start_all(**kwargs):
     if kwargs.get('host_ips') and kwargs.get('host_config_file'):
         raise click.UsageError('Cannot specify both --host_ips and --host_config_file')
 
+    from profiler.precheck.manager.args_manager import PrecheckArgsManager
+    from profiler.precheck.__main__ import main as precheck_main
+
     args = PrecheckArgsManager(type('Args', (), kwargs))
     click.echo(args)
     precheck_main(args)
@@ -96,6 +94,9 @@ def precheck_start_all(**kwargs):
               help='Rank of the current node')
 def precheck_start_node(**kwargs):
     """Run precheck runner command"""
+    from profiler.precheck.manager.args_manager import PrecheckRunnerArgsManager
+    from profiler.precheck.runner.__main__ import main as runner_main
+
     args = PrecheckRunnerArgsManager(type('Args', (), kwargs))
     click.echo(args)
-- 
Gitee