diff --git a/profiler/msprof_analyze/precheck/distributed_cluster/distributed_cluster_base.py b/profiler/msprof_analyze/precheck/distributed_cluster/distributed_cluster_base.py
index 7ccd1e542eee2050542a08df62e1720a9cdf4dcb..d990c8656726d6ab1d545a25b45293a4f64b4220 100644
--- a/profiler/msprof_analyze/precheck/distributed_cluster/distributed_cluster_base.py
+++ b/profiler/msprof_analyze/precheck/distributed_cluster/distributed_cluster_base.py
@@ -12,8 +12,226 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import socket
+from datetime import timedelta
+from typing import List, Optional, Tuple
+
+import torch
+import torch_npu
+
+from msprof_analyze.precheck.collect.collector import Collector
+
+# Module-level handles populated by the methods below; each rank only keeps
+# the groups it belongs to.
+_HOSTNAME = None
+_GLOBAL_RANK_ID = None
+_DATA_PARALLEL_GROUP = None
+_DATA_PARALLEL_GLOBAL_RANKS = None
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_GLOBAL_RANKS = None
+_TENSOR_MODEL_PARALLEL_GROUP = None
+_PIPELINE_MODEL_PARALLEL_GROUP = None
+_PIPELINE_GLOBAL_RANKS = None
+_EXPERT_MODEL_PARALLEL_GROUP = None
 
 
 class DistributedClusterBase:
-    def __init__(self):
-        pass
+    def __init__(self, world_size: int, num_nodes: int):
+        self.world_size = world_size
+        self.num_node = num_nodes
+        self.rank = int(os.getenv('RANK', '0'))
+        self.device_count = torch.npu.device_count()
+        if self.device_count == 0:
+            raise RuntimeError("no NPU device is visible to this process")
+        self.local_rank = self.rank % self.device_count
+        self.hostname = socket.gethostname()
+
+    def initialize_cluster_distributed(self):
+        # -----> 1. Global initialization
+        if torch.distributed.is_initialized():
+            if self.rank == 0:
+                print("Distributed cluster is already initialized")
+        else:
+            if self.rank == 0:
+                print("Starting distributed cluster initialization...")
+            torch.distributed.init_process_group(
+                backend="hccl",
+                world_size=self.world_size,
+                rank=self.rank,
+            )
+
+        # -----> 2. Refresh the global rank of the current process
+        self.rank = torch.distributed.get_rank()
+
+        # -----> 3. Bind the NPU device context of this process
+        torch.npu.set_device(self.local_rank)
+
+        # -----> 4. Record the mapping between hostname and global rank id
+        global _HOSTNAME
+        global _GLOBAL_RANK_ID
+        _HOSTNAME = self.hostname
+        _GLOBAL_RANK_ID = self.rank
+
+    def initialize_communication_group(
+        self,
+        tensor_model_parallel_size: int = 1,
+        pipeline_model_parallel_size: int = 1,
+        context_parallel_size: int = 1,
+        expert_model_parallel_size: int = 1,
+        distributed_timeout_minutes: int = 30,
+    ) -> Tuple[List[List[int]], ...]:
+        # -----> 1. Fetch required parameters and validate them
+        # Make sure the default process group exists, then read the world size
+        assert torch.distributed.is_initialized()
+        self.world_size: int = torch.distributed.get_world_size()
+
+        # Check that the given tp / pp / cp sizes are consistent with world_size
+        if (
+            self.world_size
+            % (tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size)
+            != 0
+        ):
+            raise RuntimeError(
+                f"world_size ({self.world_size}) is not divisible by tensor_model_parallel_size "
+                f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size}) "
+                f"x context_parallel_size ({context_parallel_size})"
+            )
+
+        # Derive the data-parallel size
+        data_parallel_size: int = self.world_size // (
+            tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+        )
+
+        # Validate the expert-parallel setting
+        if data_parallel_size % expert_model_parallel_size != 0:
+            raise RuntimeError(
+                f"data_parallel_size ({data_parallel_size}) is not divisible by "
+                f"expert_model_parallel_size ({expert_model_parallel_size})"
+            )
+
+        # Enabling expert parallelism together with context parallelism is not supported
+        if expert_model_parallel_size > 1 and context_parallel_size > 1:
+            raise RuntimeError(
+                "the combination of expert model parallelism and context parallelism is not supported"
+            )
+
+        num_tensor_model_parallel_groups: int = self.world_size // tensor_model_parallel_size
+        num_pipeline_model_parallel_groups: int = self.world_size // pipeline_model_parallel_size
+
+        self.rank = torch.distributed.get_rank()
+        timeout = timedelta(minutes=distributed_timeout_minutes)
+
+        if self.rank == 0:
+            print("data parallel: ", data_parallel_size)
+            print("tensor parallel: ", tensor_model_parallel_size)
+            print("pipeline parallel: ", pipeline_model_parallel_size)
+            print("context parallel: ", context_parallel_size)
+            print("expert parallel: ", expert_model_parallel_size)
+
+        data_parallel_group = []
+        tensor_parallel_group = []
+        pipeline_parallel_group = []
+        context_parallel_group = []
+        expert_parallel_group = []
+
+        # -----> 2. Build data-parallel groups
+        global _DATA_PARALLEL_GROUP
+        global _DATA_PARALLEL_GLOBAL_RANKS
+        for i in range(pipeline_model_parallel_size):
+            start_rank = i * num_pipeline_model_parallel_groups
+            end_rank = (i + 1) * num_pipeline_model_parallel_groups
+            for j in range(context_parallel_size * tensor_model_parallel_size):
+                ranks = range(
+                    start_rank + j, end_rank, context_parallel_size * tensor_model_parallel_size
+                )
+                group = torch.distributed.new_group(ranks, timeout=timeout, backend="hccl")
+                data_parallel_group.append(list(ranks))
+                if self.rank in ranks:
+                    _DATA_PARALLEL_GROUP = group
+                    _DATA_PARALLEL_GLOBAL_RANKS = ranks
+
+        # -----> 3. Build context-parallel groups
+        global _CONTEXT_PARALLEL_GROUP
+        global _CONTEXT_PARALLEL_GLOBAL_RANKS
+        for i in range(pipeline_model_parallel_size):
+            for j in range(data_parallel_size):
+                start_rank = (
+                    i * num_pipeline_model_parallel_groups
+                    + j * tensor_model_parallel_size * context_parallel_size
+                )
+                end_rank = (
+                    i * num_pipeline_model_parallel_groups
+                    + (j + 1) * tensor_model_parallel_size * context_parallel_size
+                )
+                for k in range(tensor_model_parallel_size):
+                    ranks = range(start_rank + k, end_rank, tensor_model_parallel_size)
+                    group = torch.distributed.new_group(ranks, timeout=timeout, backend="hccl")
+                    context_parallel_group.append(list(ranks))
+                    if self.rank in ranks:
+                        _CONTEXT_PARALLEL_GROUP = group
+                        _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
+
+        # -----> 4. Build tensor model-parallel groups
+        global _TENSOR_MODEL_PARALLEL_GROUP
+        for i in range(num_tensor_model_parallel_groups):
+            ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
+            group = torch.distributed.new_group(ranks, timeout=timeout, backend="hccl")
+            tensor_parallel_group.append(list(ranks))
+            if self.rank in ranks:
+                _TENSOR_MODEL_PARALLEL_GROUP = group
+
+        # -----> 5. Build pipeline model-parallel groups
+        global _PIPELINE_MODEL_PARALLEL_GROUP
+        global _PIPELINE_GLOBAL_RANKS
+        for i in range(num_pipeline_model_parallel_groups):
+            ranks = range(i, self.world_size, num_pipeline_model_parallel_groups)
+            group = torch.distributed.new_group(ranks, timeout=timeout, backend="hccl")
+            pipeline_parallel_group.append(list(ranks))
+            if self.rank in ranks:
+                _PIPELINE_MODEL_PARALLEL_GROUP = group
+                _PIPELINE_GLOBAL_RANKS = ranks
+
+        # -----> 6. Build expert model-parallel groups
+        global _EXPERT_MODEL_PARALLEL_GROUP
+        tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size
+        num_tensor_and_data_groups: int = self.world_size // tensor_and_data_group_size
+        tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size
+        num_expert_groups: int = data_parallel_size // expert_model_parallel_size
+        for i in range(num_tensor_and_data_groups):
+            for j in range(num_expert_groups):
+                start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size
+                end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size
+                for k in range(tensor_model_parallel_size * context_parallel_size):
+                    ranks = range(
+                        start_rank + k, end_rank, tensor_model_parallel_size * context_parallel_size
+                    )
+                    group = torch.distributed.new_group(ranks, timeout=timeout, backend="hccl")
+                    expert_parallel_group.append(list(ranks))
+                    if self.rank in ranks:
+                        _EXPERT_MODEL_PARALLEL_GROUP = group
+
+        if self.rank == 0:
+            print("data parallel group: ", data_parallel_group)
+            print("tensor parallel group: ", tensor_parallel_group)
+            print("pipeline parallel group: ", pipeline_parallel_group)
+            print("context parallel group: ", context_parallel_group)
+            print("expert parallel group: ", expert_parallel_group)
+        return data_parallel_group, tensor_parallel_group, pipeline_parallel_group, \
+            context_parallel_group, expert_parallel_group
+
+    def collect_global_info(
+        self,
+        input_file_dir: Optional[str] = None,
+        output_file_dir: Optional[str] = None,
+        num_node: Optional[int] = None,
+        node_rank: Optional[int] = None,
+        master_addr: Optional[str] = None,
+        master_port: Optional[int] = None,
+        master_rank_num: Optional[int] = None,
+        split_file_size: Optional[int] = None,
+        time_out: Optional[int] = None,
+        log_file: Optional[str] = None
+    ):
+        collector = Collector()
+
+        args_dict = {
+            "input_file_dir": input_file_dir,
+            "output_file_dir": output_file_dir,
+            "num_node": num_node,
+            "node_rank": node_rank,
+            "master_addr": master_addr,
+            "master_port": master_port,
+            "master_rank_num": master_rank_num,
+            "split_file_size": split_file_size,
+            "time_out": time_out,
+            "log_file": log_file
+        }
+
+        collector.run(args_dict)
\ No newline at end of file
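
Usage note for reviewers: the snippet below is a minimal sketch of how the new class might be driven. It assumes the process is launched by torchrun (or an equivalent launcher) that exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT, and that torch_npu with an HCCL-capable NPU environment is available; the script name, node count, and parallel sizes are illustrative only and are not part of this patch.

# run_precheck_example.py -- hypothetical driver, e.g. launched with:
#   torchrun --nnodes 2 --nproc_per_node 8 run_precheck_example.py
import os

from msprof_analyze.precheck.distributed_cluster.distributed_cluster_base import (
    DistributedClusterBase,
)

world_size = int(os.getenv("WORLD_SIZE", "16"))  # exported by the launcher
num_nodes = 2  # illustrative; match the actual launch topology

cluster = DistributedClusterBase(world_size=world_size, num_nodes=num_nodes)

# Set up the HCCL process group and bind this rank to its NPU.
cluster.initialize_cluster_distributed()

# Build the parallel communication groups; the sizes below are illustrative
# (world_size 16 = tp 2 x pp 2 x cp 1 x dp 4, with ep 1 dividing dp).
dp, tp, pp, cp, ep = cluster.initialize_communication_group(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    context_parallel_size=1,
    expert_model_parallel_size=1,
)

if int(os.getenv("RANK", "0")) == 0:
    print("data-parallel ranks:", dp)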