diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8ca3a8fb5968dab3de119e87c47f7d3436655a36 Binary files /dev/null and b/.DS_Store differ diff --git a/profiler/.DS_Store b/profiler/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..207a233c75202f3540e1055e330f46def4ce2ae6 Binary files /dev/null and b/profiler/.DS_Store differ diff --git a/profiler/msprof_analyze/.DS_Store b/profiler/msprof_analyze/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..56ccf3bf95d72fe3a2f9b2a624016d4f45491454 Binary files /dev/null and b/profiler/msprof_analyze/.DS_Store differ diff --git a/profiler/msprof_analyze/precheck/.DS_Store b/profiler/msprof_analyze/precheck/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..317babd10ce50122fb52f0035f1dc564a11cd03a Binary files /dev/null and b/profiler/msprof_analyze/precheck/.DS_Store differ diff --git a/profiler/msprof_analyze/precheck/env_check/analyze.py b/profiler/msprof_analyze/precheck/env_check/analyze.py new file mode 100644 index 0000000000000000000000000000000000000000..ebddca097e25d2ba48169b4c5d6f3b8998b2be01 --- /dev/null +++ b/profiler/msprof_analyze/precheck/env_check/analyze.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +import torch_npu + + +class Timer: + def __init__(self): + self.start_time = None + self.end_time = None + self.delta = 0 + + def start(self): + self.start_time = torch_npu.npu.Event(enable_timing=True) + self.end_time = torch_npu.npu.Event(enable_timing=True) + self.start_time.record() + + def stop(self): + self.end_time.record() + torch_npu.npu.synchronize() + self.delta = self.start_time.elapsed_time(self.end_time) / 1000 \ No newline at end of file diff --git a/profiler/msprof_analyze/precheck/env_check/communication_check.py b/profiler/msprof_analyze/precheck/env_check/communication_check.py index 807d4008115422ff312d2877495273bc25312eea..fd76bf87d4a4b70f5cee3d2dce93f613d9226b7b 100644 --- a/profiler/msprof_analyze/precheck/env_check/communication_check.py +++ b/profiler/msprof_analyze/precheck/env_check/communication_check.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# Copyright (c) 2025, Huawei Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,14 +12,208 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from msprof_analyze.precheck.env_check.environment_check import HardwareCheck +import os +from typing import List, Optional +import torch +from torch import distributed as dist +import torch_npu +from distributed_cluster_base import DistributedClusterBase +import argparse +import json +from analyze import Timer +import logging -class CommunicationCheck(HardwareCheck): + +class CommunicationCheck: CHECK_TYPE = "communication" - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, args): + self.world_size = args.world_size + self.output = args.output if args is not None else "./env_check.json" + self.tmp_path = '/tmp/p2p_test' + self.tp = args.tensor_model_parallel_size + self.pp = args.pipeline_model_parallel_size + self.cp = args.context_model_parallel_size + self.ep = args.expert_model_parallel_size + self._validate_parallel_strategy() + self._initialize() + self.rank = dist.get_rank() + self.no_share_storage = False or args.no_shared_storage + + + def _validate_parallel_strategy(self): + if self.tp * self.pp * self.cp * self.ep == 0: + raise ValueError(f"The value of parallel strategy is not correct, tp:{self.tp}, pp:{self.pp}, cp:{self.cp}, ep:{self.ep}") + + def _initialize(self): + self.dcb = DistributedClusterBase() + if not dist.is_initialized(): + self.dcb.initialize_cluster_distributed() + self.dcb.initialize_communication_group(self.tp, self.pp, self.cp, self.ep) + + # 获取dp大小 + def get_dp_size(self): + # 返回world_size除以tp、pp、cp、ep的乘积 + return self.world_size // (self.tp * self.pp * self.cp * self.ep) + + def comm_group_destroy(self): + """destroy the communication groups""" + self.dcb.destroy_comm() + + def get_rank_to_next(self, ranks): + cur_index = ranks.index(dist.get_rank()) + return ranks[(cur_index + 1) % len(ranks)] + + def get_rank_to_prev(self, ranks): + cur_index = ranks.index(dist.get_rank()) + return ranks[(cur_index - 1) % len(ranks)] + + def get_comm_group(self, group_ranks): + if group_ranks is None: + raise ValueError("The group is None, please check again") + for sub in group_ranks: + if dist.get_rank() in sub: + return sub + raise ValueError("Failed to find the current group") + + def is_pipeline_first_rank(self, group): + return dist.get_rank() == group[0] + + def is_pipeline_last_rank(self, group): + return dist.get_rank() == group[-1] + + def comm_group_create(self): + group_names = [] + if self.get_dp_size() > 1: + group_names.append("dp") + if self.tp > 1: + group_names.append("tp") + if self.pp > 1: + group_names.append("pp") + if self.ep > 1: + group_names.append("ep") + if self.cp > 1: + group_names.append("cp") + return group_names + + def get_specified_group(self, group_name): + if group_name == "dp": + return self.dcb.get_data_parallel_group_ranks(), self.dcb.get_data_parallel_group() + elif group_name == "tp": + return self.dcb.get_tensor_parallel_group_ranks(), self.dcb.get_tensor_parallel_group() + elif group_name == "pp": + return self.dcb.get_pipeline_parallel_group_ranks(), self.dcb.get_pipeline_parallel_group() + elif group_name == "ep": + return self.dcb.get_expert_parallel_group_ranks(), self.dcb.get_expert_parallel_group() + elif group_name == "cp": + return self.dcb.get_context_parallel_group_ranks(), self.dcb.get_context_parallel_group() + else: + raise ValueError(f"Invalid group name: {group_name}") + + def _perform_p2p_communication_async(self, comm_ranks, tensor_send, tensor_recv, proc_group): + """ + Perform point-to-point communication asynchronously. + """ + send_ops = [] + recv_ops = [] + if not self.is_pipeline_last_rank(comm_ranks): + send_ops.append(dist.isend(tensor=tensor_send, dst=self.get_rank_to_next(ranks=comm_ranks), group=proc_group)) + + if not self.is_pipeline_first_rank(comm_ranks): + tensor_recv_next = torch.empty(tensor_send.size(), dtype=torch.float32).npu() + recv_ops.append(dist.irecv(tensor=tensor_recv_next, src=self.get_rank_to_prev(ranks=comm_ranks), group=proc_group)) + + # 等待接收操作完成 + for op in recv_ops: + op.wait() + + if not self.is_pipeline_first_rank(comm_ranks): + send_ops.append(dist.isend(tensor=tensor_recv_next, dst=self.get_rank_to_prev(ranks=comm_ranks), group=proc_group)) + + if not self.is_pipeline_last_rank(comm_ranks): + recv_ops.append(dist.irecv(tensor=tensor_recv, src=self.get_rank_to_next(ranks=comm_ranks), group=proc_group)) + + # 等待所有发送和接收操作完成 + for op in send_ops + recv_ops: + op.wait() + + def p2p_comm(self, iteration: int = 5, batch_size=(32, 4096, 8192)): + """Perform point-to-point communication between ranks in the group""" + group_names = self.comm_group_create() + local_cost = {"rank": dist.get_rank()} + if not group_names: + raise ValueError("The tensor-parallel, pipeline-parallel, data-parallel, \ + expert-parallel, context-parallel is set less than 2, no peer communication can be created") + for group_name in group_names: + ranks, proc_group = self.get_specified_group(group_name=group_name) + local_cost[group_name] = [] + if len(ranks) <= 1: + raise ValueError(f"Failed to start communication group {group_name}, since the group is {ranks}.") + for i in range(iteration): + if dist.get_rank() == 0: + logging.info(f">>Start communication: {group_name}, iteration: {i}") + tensor_send = torch.randn(batch_size, dtype=torch.float32).npu() + tensor_recv = torch.empty(batch_size, dtype=torch.float32).npu() + dist.barrier() + timer = Timer() + timer.start() + self._perform_p2p_communication_async(ranks, tensor_send, tensor_recv, proc_group) + timer.stop() + local_cost[group_name].append(timer.delta) + local_cost[f"{group_name}avg_cost"] = sum(local_cost[group_name]) / iteration + self.dump_file(local_cost, self.output) + self.comm_group_destroy() + if self.no_share_storage: + self.collect_data() + + def get_file_path(self, file_path, file_name): + """get the file path of the given file name""" + if not os.path.exists(file_path): + logging.info(f"path is not exist, creating output dir: {file_path}") + os.makedirs(file_path) + return os.path.join(file_path, file_name) + + def dump_file(self, data, save_dir): + """save data to file as worker_x_rank_y.json""" + cur_rank = dist.get_rank() + worker_id = os.getenv("GROUP_RANK") + if worker_id is None: + raise ValueError("GROUP_ID environment variable is not set.") + if data is None: + raise ValueError(f"data is not created by rank {cur_rank}, please check this") + dump_path = self.get_file_path(self.tmp_path, f"worker_{worker_id}_rank_{cur_rank}.json") + try: + with open(dump_path, 'w') as f: + json.dump(data, f, indent=4) + except Exception as e: + print(f"An error occurred while rank {cur_rank} saving data to {dump_path}: {e}") + + def collect_data(self): + """ + 将全部节点搜集的日志汇聚到master节点,共享存储不需要该操作 + """ + worker_id = os.getenv("GROUP_RANK") + if worker_id is None: + raise ValueError("GROUP_ID environment variable is not set.") + send_file_dir = self.get_file_path(self.tmp_path, f'worker_{worker_id}_rank_{self.rank}.json') + receive_file_dir = os.path.join(self.output, 'collected_data') + if self.rank == 0: + print(f"master node {self.rank} is collecting data from other nodes") + self.dcb.collect_global_info(send_file_dir, receive_file_dir, time_out=1800, log_file="./comm_test.log") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--world_size", type=int, default=16, required=True) + parser.add_argument("--output", type=str, default="./env_check.json") + parser.add_argument("--tensor_model_parallel_size", type=int, default=4, required=True) + parser.add_argument("--pipeline_model_parallel_size", type=int, default=4, required=True) + parser.add_argument("--context_model_parallel_size", type=int, default=1, required=True) + parser.add_argument("--expert_model_parallel_size", type=int, default=1, required=True) + args = parser.parse_args() - def check(self): - pass + comm_check = CommunicationCheck(args=args) + comm_check.p2p_comm() + comm_check.collect_data() + comm_check.comm_group_destroy() \ No newline at end of file diff --git a/profiler/msprof_analyze/precheck/env_check/io_check.py b/profiler/msprof_analyze/precheck/env_check/io_check.py index 5cfd5c425f0d18d7021c8ef8dca7447c9df6dfc6..55c1c51bdc0e202f577a51c141b3eaab28c17cdc 100644 --- a/profiler/msprof_analyze/precheck/env_check/io_check.py +++ b/profiler/msprof_analyze/precheck/env_check/io_check.py @@ -14,12 +14,175 @@ # limitations under the License. from msprof_analyze.precheck.env_check.environment_check import HardwareCheck +import os +import time +import logging +import numpy as np +import shutil +import json + class IOCheck(HardwareCheck): - CHECK_TYPE = "io" + def __init__(self, args=None): + self.base_dir = './data_file' + self.global_worker_id = os.getenv("GROUP_RANK") + self.local_rank = os.getenv("LOCAL_RANK") + self._ensure_base_dir_exists() + self.output=args.output if args is not None else './results.json' + self.results = {'local_rank': self.local_rank, + 'global_rank':self.global_worker_id, + 'read_cost': { + 'ckpt_read': 0, + 'log_read': 0, + }, + 'write_cost': { + 'ckpt_write': 0, + 'log_write': 0, + }} + + def _ensure_base_dir_exists(self): + """确保数据目录存在""" + if not os.path.exists(self.base_dir): + os.makedirs(self.base_dir) + + def _generate_log_message(self, index): + """生成日志消息""" + return (f"this is the create txt file for IO check in cluster, you can ignore this information " + f"and check the result in another file [{time.strftime('%Y-%m-%d %H:%M:%S')}] " + f"这是第 {index + 1} 条日志信息。\n") + + def generate_random_weights(self, shape=(4, 2048, 4096)): + """生成随机权重""" + return np.random.randn(*shape).astype(np.float32) + + def _get_file_path(self, index, file_type): + """生成文件路径""" + if file_type == "txt": + return os.path.join(self.base_dir, f'worker_{self.global_worker_id}_data_{index}.txt') + elif file_type == "npy": + return os.path.join(self.base_dir, f'worker_{self.global_worker_id}_data_{index}.npy') + else: + raise ValueError(f"file type {file_type} is not included in the list [txt, npy]") + + def _is_local_rank_zero(self): + """检查本地排名是否为 0""" + return self.local_rank is not None and int(self.local_rank) == 0 + + def _calculate_speed(self, start_time, end_time, data_size): + """计算读写速度""" + elapsed_time = end_time - start_time + return data_size / elapsed_time / (1024 ** 2) + + def generate_file_like_log(self): + """生成数据文件并计算写入速度""" + index = self.local_rank + file_path = self._get_file_path(index, file_type='txt') + start_time = time.time() + total_data_size = 0 + + with open(file_path, 'w', encoding='utf-8') as file: + for i in range(10000): + log_message = self._generate_log_message(i) * 100 + file.write(log_message) + file.flush() + total_data_size += len(log_message) + end_time = time.time() + write_speed = self._calculate_speed(start_time, end_time, total_data_size) + self.results['write_cost']['log_write'] = write_speed + return file_path + + def read_file_like_log(self, file_path=None): + """读取单个文件并计算读取速度""" + if file_path is None: + file_path = self._get_file_path(self.local_rank, file_type='txt') + try: + start_time = time.time() + data = "" + with open(file_path, 'r', encoding='utf-8') as file: + for line in file: + data += line + end_time = time.time() + data_size = len(data.encode('utf-8')) + read_speed = self._calculate_speed(start_time, end_time, data_size) + self.results['read_cost']['log_read'] = read_speed + return data + except FileNotFoundError: + logging.error(f"File {file_path} not found.") + except UnicodeDecodeError: + logging.error(f"Error decoding {file_path} as UTF-8.") + except Exception as e: + logging.error(f"Unexpected error reading {file_path}: {e}") + return None + + def generate_file_like_ckpt(self): + """生成大数据文件并计算写入速度""" + index = self.local_rank + file_path = self._get_file_path(index, file_type='npy') + weight = self.generate_random_weights() + total_data_size = weight.size * weight.itemsize + start_time = time.time() + np.save(file_path, weight) + end_time = time.time() + write_speed = self._calculate_speed(start_time, end_time, total_data_size) + self.results['write_cost']['ckpt_write'] = write_speed + return file_path + + def read_file_like_ckpt(self, file_path=None): + """读取大的单个文件并计算读取速度""" + if file_path is None: + file_path = self._get_file_path(self.local_rank, file_type='npy') + try: + start_time = time.time() + data = np.load(file_path) + end_time = time.time() + data_size = data.nbytes + read_speed = self._calculate_speed(start_time, end_time, data_size) + self.results['read_cost']['ckpt_read'] = read_speed + return data + except FileNotFoundError: + logging.error(f"File {file_path} not found.") + except UnicodeDecodeError: + logging.error(f"Error decoding {file_path} as UTF-8.") + except Exception as e: + logging.error(f"Unexpected error reading {file_path}: {e}") + return None + + def clean_cache_file(self, file_type=None): + """执行完删除缓存文件,避免磁盘空间不足""" + if file_type == 'txt': + file_path = self._get_file_path(self.local_rank, file_type='txt') + elif file_type == 'npy': + file_path = self._get_file_path(self.local_rank, file_type='npy') + else: + raise ValueError(f"no such file type {file_type} could be clean as cache file loaded temperaly") + try: + os.remove(file_path) + except FileNotFoundError: + logging.error(f"File {file_path} not found, cannot remove.") + except PermissionError: + logging.error(f"Permission denied when trying to remove {file_path}.") + except Exception as e: + logging.error(f"Unexpected error removing {file_path}: {e}") + + def dump_results(self): + """生成临时结果用于分析""" + try: + with open(self.output, 'w') as f: + json.dump(self.results, f, indent=4) + logging.info(f"Results have been saved to {self.output}") + except Exception as e: + logging.error(f"Error saving results to {self.output}: {e}") + + + - def __init__(self, **kwargs): - super().__init__(**kwargs) - def check(self): - pass +if __name__ == "__main__": + io = IOCheck() + io.generate_file_like_log() + io.read_file_like_log() + io.generate_file_like_ckpt() + io.read_file_like_ckpt() + io.clean_cache_file('txt') + io.clean_cache_file('npy') + print(io.results) \ No newline at end of file