diff --git a/cli/infer_vllm.py b/cli/infer_vllm.py
index 70e7b0f5ecac7f0ca035ad2500a6df6eaed6796c..35227d61c09e29ac48715cc8eefa4bb657c9f189 100644
--- a/cli/infer_vllm.py
+++ b/cli/infer_vllm.py
@@ -100,6 +100,7 @@ def main():
         prompt_type_path=args.prompt_type_path,
         train_tensor_parallel_size=args.tensor_parallel_size,
         train_pipeline_parallel_size=1,
+        train_context_parallel_size=1,
         infer_tensor_parallel_size=args.tensor_parallel_size,
         infer_pipeline_parallel_size=1,
         max_num_seqs=1,
@@ -109,8 +110,8 @@ def main():
     )
 
     if args.load_format == "megatron":
-        tp_rank = ps._TP.rank_in_group
-        weights_path = os.path.join(args.load, f"iter_0000001/mp_rank_{tp_rank:02}/model_optim_rng.pt")
+        ep_rank = ps._EP.rank_in_group
+        weights_path = os.path.join(args.load, f"iter_0000001/mp_rank_00_{ep_rank:03}/model_optim_rng.pt")
         actor_weights = torch.load(weights_path)['model']
 
         actor_weights = replace_state_dict_name(
diff --git a/configs/envs/runtime_env.yaml b/configs/envs/runtime_env.yaml
index 1525a2f02aa7778be9fbbc37baca1e6468c98cba..b7162a816b4e4747c2a138c00b762648b61e3ff7 100644
--- a/configs/envs/runtime_env.yaml
+++ b/configs/envs/runtime_env.yaml
@@ -8,6 +8,7 @@ env_vars:
   HCCL_IF_BASE_PORT: '48000'
   CUDA_DEVICE_MAX_CONNECTIONS: '1'
   HYDRA_FULL_ERROR: '1'
+  VLLM_DP_SIZE: '8'
   # GLOO_SOCKET_IFNAME: "Your SOCKET IFNAME"
   # TP_SOCKET_IFNAME: "Your SOCKET IFNAME"
   # HCCL_SOCKET_IFNAME: "Your SOCKET IFNAME"
\ No newline at end of file
diff --git a/mindspeed_rl/config_cls/generate_config.py b/mindspeed_rl/config_cls/generate_config.py
index 6cf7642a78408a8961587c6d89b9b3b91eee9c59..6a37f69b65f87f6654815a8a9c094bdd1f8e41cb 100644
--- a/mindspeed_rl/config_cls/generate_config.py
+++ b/mindspeed_rl/config_cls/generate_config.py
@@ -19,6 +19,7 @@ class GenerateConfig(BaseConfig):
         infer_tensor_parallel_size: Tensor parallel size during inference. Default is 8.
         infer_pipeline_parallel_size: Pipeline parallel size during inference. Default is 1.
         infer_expert_parallel_size: Expert parallel size during inference. Default is 1.
+        infer_context_parallel_size: Context parallel size during inference. Default is 1.
         max_num_seqs: Maximum number of sequences to process simultaneously. Default is 256.
         max_model_len: Maximum model length (in tokens). Default is 2048.
@@ -53,6 +54,9 @@ class GenerateConfig(BaseConfig):
         # 推理时的专家并行大小,默认为 1
         self.infer_expert_parallel_size = 1
 
+        # Context parallel size during inference. Default is 1.
+        self.infer_context_parallel_size = 1
+
         # 最大可处理的序列数量,默认为 1
         self.max_num_seqs = 1
diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py
index 00e400faedf94b31293296ddeba8e517eae2d763..ceab667de637f0764bfe1e2feb1fefc11040449d 100644
--- a/mindspeed_rl/config_cls/validate_config.py
+++ b/mindspeed_rl/config_cls/validate_config.py
@@ -23,6 +23,19 @@ def validate_rl_args(
         if rl_config.reward_resource is not None:
             raise ValueError(
                 f" Reward model is not supported when use_integrated_worker mode is on.")
+
+    # resharding EP check
+    # TODO:
+    infer_actual_ep = generate_config.infer_expert_parallel_size * generate_config.infer_tensor_parallel_size
+    train_actual_ep = actor_config.expert_model_parallel_size * actor_config.tensor_model_parallel_size
+    if infer_actual_ep % train_actual_ep == 0:
+        expert_expand_N = infer_actual_ep // train_actual_ep
+        if expert_expand_N > 1:
+            if actor_config.moe_tp_extend_ep is False:
+                raise RuntimeError('resharding with EP expansion requires --moe_tp_extend_ep to be enabled')
+    else:
+        raise RuntimeError('inference EP (infer_expert_parallel_size * infer_tensor_parallel_size) must be an integer multiple of training EP (expert_model_parallel_size * tensor_model_parallel_size)')
+
     # 校验序列长度与模型最大长度
     if generate_config.max_model_len < actor_config.seq_length:
diff --git a/mindspeed_rl/models/base/base_inference_engine.py b/mindspeed_rl/models/base/base_inference_engine.py
index 5719cd8099588cd791a381ada13be941bc2ffe2a..14b4c26b5b699e0dd71773cc7adcc7816b583e1c 100644
--- a/mindspeed_rl/models/base/base_inference_engine.py
+++ b/mindspeed_rl/models/base/base_inference_engine.py
@@ -17,9 +17,11 @@ class BaseInferEngine(ABC):
             prompt_type: str = None,
             prompt_type_path: str = None,
             train_expert_parallel_size: int = 1,
+            train_context_parallel_size: int = 1,
             infer_tensor_parallel_size: int = 8,
             infer_pipeline_parallel_size: int = 1,
             infer_expert_parallel_size: int = 1,
+            infer_context_parallel_size: int = 1,
             max_num_seqs: int = 1,  # Default value set to 1
             max_model_len: int = 2048,  # Default value set to 2048
             dtype: str = "bfloat16",  # Default value set to "bfloat16"
@@ -34,9 +36,11 @@ class BaseInferEngine(ABC):
             train_tensor_parallel_size (int): Tensor parallel size during training.
             train_pipeline_parallel_size (int): Pipeline parallel size during training.
             train_expert_parallel_size (int): Expert parallel size during training.
+            train_context_parallel_size (int): Context parallel size during training.
             infer_tensor_parallel_size (int): Tensor parallel size during inference.
             infer_pipeline_parallel_size (int): Pipeline parallel size during inference.
             infer_expert_parallel_size (int): Expert parallel size during inference.
+            infer_context_parallel_size (int): Context parallel size during inference.
             max_num_seqs (int): Maximum number of sequences to process simultaneously. Default is 1.
             max_model_len (int): Maximum model length (in tokens). Default is 2048.
             dtype (str): Data type for model weights. Default is "bfloat16".
@@ -49,9 +53,11 @@ class BaseInferEngine(ABC):
         self.train_tensor_parallel_size = train_tensor_parallel_size
         self.train_pipeline_parallel_size = train_pipeline_parallel_size
         self.train_expert_parallel_size = train_expert_parallel_size
+        self.train_context_parallel_size = train_context_parallel_size
         self.infer_tensor_parallel_size = infer_tensor_parallel_size
         self.infer_pipeline_parallel_size = infer_pipeline_parallel_size
         self.infer_expert_parallel_size = infer_expert_parallel_size
+        self.infer_context_parallel_size = infer_context_parallel_size
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
         self.dtype = dtype
diff --git a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
index a6264f29f466b73a13a3d78801b282c67ed1556d..e00488a0498386fcd662ad38398cb7f25f7fded5 100644
--- a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
+++ b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
@@ -16,10 +16,11 @@ from vllm.model_executor.models import ModelRegistry
 
 
 class InferParallelConfig:
-    def __init__(self, infer_tensor_parallel_size: int, infer_pipeline_parallel_size: int, infer_expert_parallel_size: int):
+    def __init__(self, infer_tensor_parallel_size: int, infer_pipeline_parallel_size: int, infer_expert_parallel_size: int, infer_context_parallel_size: int):
         self.infer_tensor_parallel_size = infer_tensor_parallel_size
         self.infer_pipeline_parallel_size = infer_pipeline_parallel_size
         self.infer_expert_parallel_size = infer_expert_parallel_size
+        self.infer_context_parallel_size = infer_context_parallel_size
 
 
 def load_megatron_weights(actor_weights: Dict,
                           vllm_model: nn.Module,
@@ -88,9 +89,10 @@ def deepseek_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module,
         if name not in params_dict.keys():
             raise ValueError(f"unexpected key {name} in deepseek_megatron_weight_loader")
         if "mlp.experts.w13_weight" in name:
-            loaded_weight.copy_(loaded_weight.view(hf_config.n_routed_experts, hf_config.hidden_size, -1).transpose(2, 1).contiguous())
+            # TODO: this is a fairly crude approach and may have issues; the view is tied directly to infer_paralle_config.infer_expert_parallel_size
+            loaded_weight.copy_(loaded_weight.view(hf_config.n_routed_experts // infer_paralle_config.infer_expert_parallel_size, hf_config.hidden_size, -1).transpose(2, 1).contiguous())
         if "mlp.experts.w2_weight" in name:
-            loaded_weight.copy_(loaded_weight.view(hf_config.n_routed_experts, -1, hf_config.hidden_size).transpose(2, 1).contiguous())
+            loaded_weight.copy_(loaded_weight.view(hf_config.n_routed_experts // infer_paralle_config.infer_expert_parallel_size, -1, hf_config.hidden_size).transpose(2, 1).contiguous())
         load_single_weight(params_dict, name, loaded_weight)
     return vllm_model
 
@@ -176,6 +178,8 @@ MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY = {
     "LlamaForCausalLM": llama_megatron_core_weight_loader,
     "Qwen2ForCausalLM": qwen_megatron_weight_loader,
     "DeepseekV3ForCausalLM": deepseek_megatron_weight_loader,
+    "DeepseekV2ForCausalLM": deepseek_megatron_weight_loader,
+    "CustomDeepseekV2ForCausalLM": deepseek_megatron_weight_loader,
 }
 import torch
 import torch.distributed
 import vllm.distributed.parallel_state as ps
+import vllm.envs as envs
+from vllm.config import get_current_vllm_config
 
 from vllm.distributed.parallel_state import (
     get_pp_group,
@@ -17,7 +19,6 @@ from vllm.distributed.parallel_state import (
     init_model_parallel_group,
 )
 
-
 """
 This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron.
 - We assume the Megatron tp+dp+pp world is already established before calling this function.
@@ -31,20 +32,31 @@ _DEVICE_MESH = None
 _TP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PP = None
+# Expert model parallel group that the current rank belongs to.
+_EP = None
+# Expert tensor model parallel group that the current rank belongs to.
+_ETP = None
+# Data parallel group that the current rank belongs to.
+_DP = None
 
 
 # This method is for initializing the ParallelGroup when using HybridEngine
 def initialize_parallel_state(
-        distributed_init_method: str = "env://",
-        backend: str = "hccl",
-        infer_tensor_model_parallel_size: int = 1,
-        train_tensor_model_parallel_size: int = 1,
-        infer_pipeline_model_parallel_size: int = 1,
-        train_pipeline_model_parallel_size: int = 1
+    distributed_init_method: str = "env://",
+    backend: str = "hccl",
+    infer_tensor_model_parallel_size: int = 1,
+    train_tensor_model_parallel_size: int = 1,
+    infer_pipeline_model_parallel_size: int = 1,
+    train_pipeline_model_parallel_size: int = 1,
+    infer_expert_tensor_parallel_size: int = 1,
+    train_expert_tensor_parallel_size: int = 1,
+    train_expert_model_parallel_size: int = 1,
+    infer_expert_model_parallel_size: int = 1,
+    infer_context_model_parallel_size: int = 1,
+    train_context_model_parallel_size: int = 1,
 ):
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN.
rank = int(os.getenv("RANK", "-1")) local_rank = int(os.getenv("LOCAL_RANK", "0")) @@ -53,6 +65,8 @@ def initialize_parallel_state( world_size = int(os.getenv("WORLD_SIZE", "-1")) if world_size == -1: raise ValueError("The world_size is set to -1, not initialized by TORCHRUN") + config = get_current_vllm_config() + config.parallel_config.tensor_parallel_size = infer_tensor_model_parallel_size init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) if torch.distributed.get_world_size() > 1: # NOTE: build a sepearate inference group with infer tp & micro dp @@ -60,54 +74,30 @@ def initialize_parallel_state( infer_tensor_model_parallel_size=infer_tensor_model_parallel_size, train_tensor_model_parallel_size=train_tensor_model_parallel_size, infer_pipeline_model_parallel_size=infer_pipeline_model_parallel_size, - train_pipeline_model_parallel_size=train_pipeline_model_parallel_size + train_pipeline_model_parallel_size=train_pipeline_model_parallel_size, + infer_expert_tensor_parallel_size=infer_expert_tensor_parallel_size, + train_expert_tensor_parallel_size=train_expert_tensor_parallel_size, + train_expert_model_parallel_size=train_expert_model_parallel_size, + infer_expert_model_parallel_size=infer_expert_model_parallel_size, + infer_context_model_parallel_size=infer_context_model_parallel_size, + train_context_model_parallel_size=train_context_model_parallel_size ) else: initialize_model_parallel(infer_tensor_model_parallel_size, infer_pipeline_model_parallel_size, backend) -def ensure_model_parallel_initialized( - tensor_model_parallel_size: int, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, -) -> None: - """Helper to initialize model parallel groups if they are not initialized, - or ensure tensor-parallel and pipeline-parallel sizes are equal to expected - values if the model parallel groups are initialized. - """ - # get the backend of _DEVICE_WORLD_GROUP - backend = backend or torch.distributed.get_backend(get_world_group().device_group) - if not model_parallel_is_initialized(): - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) - return - - current_tp_size = get_tensor_model_parallel_world_size() - if current_tp_size != tensor_model_parallel_size: - raise ValueError( - "tensor parallel group already initialized, but of unexpected size: " - f"{current_tp_size=} vs. " - f"{tensor_model_parallel_size=}" - ) - pp_world_size = get_pp_group().world_size - if pp_world_size != pipeline_model_parallel_size: - raise ValueError( - "pipeline parallel group already initialized, but of unexpected size: " - f"{pp_world_size=} vs. 
" - f"{pipeline_model_parallel_size=}" - ) - - -def model_parallel_is_initialized(): - """Check if tensor and pipeline parallel groups are initialized.""" - return ps._TP is not None - # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) - - def initialize_model_parallel_for_vllm( - infer_tensor_model_parallel_size: int, - train_tensor_model_parallel_size: int = 1, - infer_pipeline_model_parallel_size: int = 1, - train_pipeline_model_parallel_size: int = 1 + infer_tensor_model_parallel_size: int, + train_tensor_model_parallel_size: int = 1, + infer_pipeline_model_parallel_size: int = 1, + train_pipeline_model_parallel_size: int = 1, + infer_expert_tensor_parallel_size: int = 1, + train_expert_tensor_parallel_size: int = 1, + train_expert_model_parallel_size: int = 1, + infer_expert_model_parallel_size: int = 1, + infer_context_model_parallel_size: int = 1, + train_context_model_parallel_size: int = 1, + num_process: int = 1, ) -> None: # Get world size and rank. Ensure some consistencies. @@ -142,8 +132,10 @@ def initialize_model_parallel_for_vllm( Returns: list of group_lists [[g0, g1], [g2, g3], [g4, g5], [g6, g7]] ''' - if ((world_size // (train_tensor_model_parallel_size * train_pipeline_model_parallel_size)) * train_tensor_model_parallel_size < infer_tensor_model_parallel_size or - ((world_size // (train_tensor_model_parallel_size * train_pipeline_model_parallel_size)) * train_tensor_model_parallel_size) % infer_tensor_model_parallel_size != 0): + if ((world_size // ( + train_tensor_model_parallel_size * train_pipeline_model_parallel_size)) * train_tensor_model_parallel_size < infer_tensor_model_parallel_size or + ((world_size // ( + train_tensor_model_parallel_size * train_pipeline_model_parallel_size)) * train_tensor_model_parallel_size) % infer_tensor_model_parallel_size != 0): raise ValueError( f"Can't split train tp size {train_tensor_model_parallel_size} to infer tp size {infer_tensor_model_parallel_size} " f"with train dp size {(world_size // (train_tensor_model_parallel_size * train_pipeline_model_parallel_size))}.") @@ -172,16 +164,13 @@ def initialize_model_parallel_for_vllm( [[g0, g2], [g1, g3], [g4, g6], [g5, g7]] ''' if train_tensor_model_parallel_size < infer_tensor_model_parallel_size or train_tensor_model_parallel_size % infer_tensor_model_parallel_size != 0: - raise ValueError(f"Can't gather train tp size {train_tensor_model_parallel_size} to infer tp size {infer_tensor_model_parallel_size}") + raise ValueError( + f"Can't gather train tp size {train_tensor_model_parallel_size} to infer tp size {infer_tensor_model_parallel_size}") num_tensor_model_parallel_groups = world_size // infer_tensor_model_parallel_size - num_tensor_model_parallel_groups_per_train_tp = train_tensor_model_parallel_size // infer_tensor_model_parallel_size group_ranks = [] - for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): - start = train_tensor_model_parallel_size * i - end = train_tensor_model_parallel_size * (i + 1) - for j in range(num_tensor_model_parallel_groups_per_train_tp): - ranks = list(range(start + j, end, num_tensor_model_parallel_groups_per_train_tp)) - group_ranks.append(ranks) + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * infer_tensor_model_parallel_size, (i + 1) * infer_tensor_model_parallel_size)) + group_ranks.append(ranks) return group_ranks @@ -191,9 +180,11 @@ def initialize_model_parallel_for_vllm( else: return get_allgather_tp_group_ranks() + tp_group_ranks = get_tp_group_ranks() + 
print(f">>>>>>>>>>>>>>>>TP rank: {tp_group_ranks}") _TP = init_model_parallel_group( - group_ranks=get_tp_group_ranks(), + group_ranks=tp_group_ranks, local_rank=get_world_group().local_rank, backend=backend, use_message_queue_broadcaster=True, @@ -208,16 +199,109 @@ def initialize_model_parallel_for_vllm( ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) group_ranks.append(ranks) # pipeline parallel does not need custom allreduce + print(f">>>>>>>>>>>>>>>>PP rank: {group_ranks}") _PP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, ) ps._PP = _PP # for verl + data_parallel_size = 1 + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + + if config is not None: + data_parallel_size = config.parallel_config.data_parallel_size + + # all_ranks = torch.arange(world_size).reshape( + # data_parallel_size, infer_pipeline_model_parallel_size, + # infer_tensor_model_parallel_size) + + num_expert_parallel_groups: int = infer_expert_tensor_parallel_size + num_expert_tensor_parallel_groups: int = world_size // infer_expert_tensor_parallel_size + + num_rank_per_process = world_size // num_process + all_ranks = list(range(world_size)) + + # n_process_ranks = [all_ranks[i: (i + 1) * num_rank_per_process] for i in range(num_process)] + + global _EP + assert _EP is None, ("expert parallel group is already initialized") + group_ranks = [] + # TODO: group_ranks + # for process_ranks in range(num_process): + # for i in range(num_expert_parallel_groups): + # ranks = list(range((process_ranks * num_rank_per_process) + i, (process_ranks + 1) * num_rank_per_process, + # num_expert_parallel_groups)) + # group_ranks.append(ranks) + + + tensor_model_parallel_size=infer_tensor_model_parallel_size + context_parallel_size=infer_context_model_parallel_size + expert_model_parallel_size = infer_expert_model_parallel_size + infer_data_parallel_size = world_size // tensor_model_parallel_size // infer_pipeline_model_parallel_size + tensor_and_data_group_size_with_cp: int = tensor_model_parallel_size * infer_data_parallel_size * context_parallel_size + num_tensor_and_data_groups_with_cp: int = world_size // tensor_and_data_group_size_with_cp + num_expert_groups: int = infer_data_parallel_size * context_parallel_size // expert_model_parallel_size + tensor_and_expert_group_size = tensor_model_parallel_size * expert_model_parallel_size + group_ranks = [] + for i in range(num_tensor_and_data_groups_with_cp): + for j in range(num_expert_groups): + start_rank = i * tensor_and_data_group_size_with_cp + j * tensor_and_expert_group_size + end_rank = i * tensor_and_data_group_size_with_cp + (j + 1) * tensor_and_expert_group_size + ranks = range(start_rank, end_rank) + group_ranks.append(list(ranks)) + + print(f">>>>>>>>>>>>>>>>EP rank: {group_ranks}") + + ps._EP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="ep") + + global _ETP + assert _ETP is None, ( + "expert tensor parallel group is already initialized") + + group_ranks = [] + for i in range(num_expert_tensor_parallel_groups): + ranks = list(range(i * infer_expert_tensor_parallel_size, + (i + 1) * infer_expert_tensor_parallel_size)) + group_ranks.append(ranks) + print(f">>>>>>>>>>>>>>>>ETP rank: {group_ranks}") + + ps._ETP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="etp") + + global _DP + assert _DP is None, ("data parallel group is already initialized") + # group_ranks = 
all_ranks.transpose(0, + # 2).reshape(-1, + # data_parallel_size).unbind(0) + # group_ranks = [x.tolist() for x in group_ranks] + dp_group_ranks = torch.tensor(tp_group_ranks).transpose(0, 1).reshape(-1, data_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in dp_group_ranks] + print(f">>>>>>>>>>>>>>>>DP rank: {group_ranks}") + + ps._DP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="dp") + + os.environ["VLLM_DP_RANK"] = str(ps._DP.rank_in_group) + envs.VLLM_DP_RANK = int(os.environ["VLLM_DP_RANK"]) + os.environ["VLLM_DP_MASTER_IP"] = os.environ.get("MASTER_ADDR") + envs.VLLM_DP_MASTER_IP = os.environ["VLLM_DP_MASTER_IP"] + os.environ["VLLM_DP_MASTER_PORT"] = str( + int(os.environ.get("MASTER_PORT")) + ps.get_tensor_model_parallel_rank() + 1) + envs.VLLM_DP_MASTER_PORT = int(os.environ["VLLM_DP_MASTER_PORT"]) + def initialize_model_parallel( - tensor_model_parallel_size: int = 1, - pipeline_model_parallel_size: int = 1, - backend: Optional[str] = None, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, ) -> None: """ NOTE: This method is a hack from the open-sourced version without @@ -250,8 +334,6 @@ def initialize_model_parallel( world_size: int = torch.distributed.get_world_size() backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) - - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size global _TP if _TP is not None: @@ -284,4 +366,3 @@ def initialize_model_parallel( ) ps._PP = _PP # for verl - diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 16ce4c31b84842730bdad8e945f82d8592705417..df6cf96dcd4283918963cc995e3a3dfe4805c7f0 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -48,9 +48,11 @@ class VLLMInferEngine(BaseInferEngine): train_tensor_parallel_size: int, train_pipeline_parallel_size: int, train_expert_parallel_size: int, + train_context_parallel_size: int, infer_tensor_parallel_size: int, infer_pipeline_parallel_size: int, infer_expert_parallel_size: int, + infer_context_parallel_size: int, megatron_config: MegatronConfig, sampling_config: dict, prompt_type: str = None, @@ -73,9 +75,11 @@ class VLLMInferEngine(BaseInferEngine): train_tensor_parallel_size (int): Tensor parallel size during training. train_pipeline_parallel_size (int): Pipeline parallel size during training. train_expert_parallel_size (int): Expert parallel size during training. + train_context_parallel_size (int): Context parallel size during training. infer_tensor_parallel_size (int): Tensor parallel size during inference. infer_pipeline_parallel_size (int): Pipeline parallel size during inference. infer_expert_parallel_size (int): Expert parallel size during inference. + infer_context_parallel_size (int): Context parallel size furing inference. sampling_config (dict): Configuration for text generation sampling. enable_prefix_caching (bool): Whether to enable prefix caching. num_scheduler_steps (int): Num scheduler steps. Default is 1. 
@@ -94,9 +98,11 @@ class VLLMInferEngine(BaseInferEngine):
             train_tensor_parallel_size=train_tensor_parallel_size,
             train_pipeline_parallel_size=train_pipeline_parallel_size,
             train_expert_parallel_size=train_expert_parallel_size,
+            train_context_parallel_size=train_context_parallel_size,
             infer_tensor_parallel_size=infer_tensor_parallel_size,
             infer_pipeline_parallel_size=infer_pipeline_parallel_size,
             infer_expert_parallel_size=infer_expert_parallel_size,
+            infer_context_parallel_size=infer_context_parallel_size,
             max_num_seqs=max_num_seqs,
             max_model_len=max_model_len,
             dtype=dtype,
@@ -145,6 +151,10 @@ class VLLMInferEngine(BaseInferEngine):
                 train_tensor_model_parallel_size=train_tensor_parallel_size,
                 infer_pipeline_model_parallel_size=infer_pipeline_parallel_size,
                 train_pipeline_model_parallel_size=train_pipeline_parallel_size,
+                train_expert_model_parallel_size=train_expert_parallel_size,
+                infer_expert_model_parallel_size=infer_expert_parallel_size,
+                infer_context_model_parallel_size=infer_context_parallel_size,
+                train_context_model_parallel_size=train_context_parallel_size
             )
 
         if load_format == "megatron":
@@ -221,7 +231,7 @@ class VLLMInferEngine(BaseInferEngine):
     def sync_model_weights(self, params, load_format='megatron'):
         infer_parallel_config = InferParallelConfig(self.infer_tensor_parallel_size, self.infer_pipeline_parallel_size,
-                                                    self.infer_expert_parallel_size)
+                                                    self.infer_expert_parallel_size * self.infer_tensor_parallel_size, self.infer_context_parallel_size)
         load_megatron_weights(params,
                               self.model,
                               infer_parallel_config,
@@ -239,7 +249,7 @@ class VLLMInferEngine(BaseInferEngine):
     @torch.no_grad()
     def generate_sequences(self, idx_list, **kwargs):
         self.init_cache_engine()
-        with self.update_sampling_params(**kwargs):
+        with self.update_sampling_params(ignore_eos=True, **kwargs):
             response = self.llm.generate(
                 prompts=None,
                 sampling_params=self.sampling_params,
diff --git a/mindspeed_rl/utils/tokenizer.py b/mindspeed_rl/utils/tokenizer.py
index eb4e9a7dbe702dec16a726368f76299fa5e1abb0..73c75f902ba47e420f4f3f6cd38e252897e66787 100644
--- a/mindspeed_rl/utils/tokenizer.py
+++ b/mindspeed_rl/utils/tokenizer.py
@@ -259,8 +259,8 @@ class _HuggingFaceTokenizer(BaseTokenizer):
     def tokenize(self, text):
         return self.tokenizer(text).input_ids
 
-    def detokenize(self, token_ids):
-        return self.tokenizer.decode(token_ids)
+    def detokenize(self, token_ids, **kwargs):
+        return self.tokenizer.decode(token_ids, **kwargs)
 
     @property
     def eod(self):
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 96987d07fb3cfa44f757813843fb0b945aa50553..08f7db4edb7533d6cfecf7032287eeab080710ac 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -319,9 +319,11 @@ class ActorHybridWorkerBase(BaseWorker):
             train_tensor_parallel_size=self.megatron_config.tensor_model_parallel_size,
             train_pipeline_parallel_size=self.megatron_config.pipeline_model_parallel_size,
             train_expert_parallel_size=self.megatron_config.expert_model_parallel_size,
+            train_context_parallel_size=self.megatron_config.context_parallel_size,
             infer_tensor_parallel_size=self.generate_config.infer_tensor_parallel_size,
             infer_pipeline_parallel_size=self.generate_config.infer_pipeline_parallel_size,
             infer_expert_parallel_size=self.generate_config.infer_expert_parallel_size,
+            infer_context_parallel_size=self.generate_config.infer_context_parallel_size,
             megatron_config=self.megatron_config,
             sampling_config=sampling_config,
             enable_prefix_caching=self.generate_config.enable_prefix_caching,
@@ -332,7 +334,6 @@ class ActorHybridWorkerBase(BaseWorker):
             gpu_memory_utilization=self.generate_config.gpu_memory_utilization,
             trust_remote_code=self.generate_config.trust_remote_code
         )
-
         return rollout
 
     def _build_sharding_manager(self):
diff --git a/mindspeed_rl/workers/resharding/megatron_sharding_manager.py b/mindspeed_rl/workers/resharding/megatron_sharding_manager.py
index b7e200b64a7d69ef384e34d38b693879a52b3fcd..72dd6ad640cf64478bdf983b08b9abea7fae9d32 100644
--- a/mindspeed_rl/workers/resharding/megatron_sharding_manager.py
+++ b/mindspeed_rl/workers/resharding/megatron_sharding_manager.py
@@ -21,13 +21,18 @@ Manager used to shard weight and offload/onload optimizer from training stage to
 from itertools import chain
 from collections import defaultdict
 
+import os
 import torch
-import torch.distributed
+import torch.distributed as dist
+import vllm.distributed.parallel_state as ps
 
+from mindspeed_rl.utils.loggers import Loggers
 from mindspeed_rl.workers.resharding.vllm_weight_container import MegatronStyleVllmWeightContainer
 from mindspeed_rl.workers.resharding.weight_adaptor import get_weight_adaptor
 
-
+logger = Loggers(
+    name="vllm_engine_inference",
+)
 class MegatronOffLoader:
     def __init__(self, megatron_model=None, optimizer=None, wrap_with_ddp=True):
         self.optimizer = optimizer
@@ -47,17 +52,28 @@ class MegatronOffLoader:
             self.swap_tensors_to_device(buffer.grad_data, copy_data=False)
 
     def offload_optimizer(self):
-        for param_group in self.optimizer.optimizer.param_groups:
-            for param in param_group['params']:
-                param.data = param.data.to("cpu", non_blocking=False)
-        self.optimizer.optimizer.state = self._move_to_device(self.optimizer.optimizer.state, "cpu")
+        if hasattr(self.optimizer, "chained_optimizers"):
+            optimizers = self.optimizer.chained_optimizers
+        else:
+            optimizers = [self.optimizer]
+        for optimizer in optimizers:
+            for param_group in optimizer.optimizer.param_groups:
+                for param in param_group['params']:
+                    param.data = param.data.to("cpu", non_blocking=False)
+            optimizer.optimizer.state = self._move_to_device(optimizer.optimizer.state,
+                                                             "cpu")
 
     def onload_optimizer(self):
-        for param_group in self.optimizer.optimizer.param_groups:
-            for param in param_group['params']:
-                param.data = param.data.to(torch.cuda.current_device(), non_blocking=False)
-        self.optimizer.optimizer.state = self._move_to_device(self.optimizer.optimizer.state,
-                                                              torch.cuda.current_device())
+        if hasattr(self.optimizer, "chained_optimizers"):
+            optimizers = self.optimizer.chained_optimizers
+        else:
+            optimizers = [self.optimizer]
+        for optimizer in optimizers:
+            for param_group in optimizer.optimizer.param_groups:
+                for param in param_group['params']:
+                    param.data = param.data.to(torch.cuda.current_device(), non_blocking=False)
+            optimizer.optimizer.state = self._move_to_device(optimizer.optimizer.state,
+                                                             torch.cuda.current_device())
 
     def _move_to_device(self, data, device):
         if isinstance(data, defaultdict):
@@ -148,6 +164,7 @@ class MegatronShardingManager:
 
         weight_adaptor = get_weight_adaptor(self.inference_engine.model.__class__.__name__)
         self.weight_adaptor = weight_adaptor(model_config)
+        # vllm_weight_container is the instantiated container; it holds the memory buffers
         self.vllm_weight_container = MegatronStyleVllmWeightContainer(
             megatron_model=megatron_model,
             vllm_model=self.inference_engine.model,
@@ -165,7 +182,7 @@ class MegatronShardingManager:
         self.grad_offload = grad_offload
         self.train_param_offload = train_param_offload
         self.enable_validate = enable_validate
-        self.use_distributed_optimizer = self.optimizer.config.use_distributed_optimizer
+        # self.use_distributed_optimizer = self.optimizer.config.use_distributed_optimizer
         self.inference_engine.offload_model_weights()
         self.megatron_offloader = megatron_offloader
 
@@ -177,7 +194,10 @@
     def onload_infer_params(self):
         infer_weight_buffers = self.vllm_weight_container.weight_buffers
         for buffer in infer_weight_buffers:
+            print(f"==========buffer{buffer}")
             buffer.rebuild()
+            # buffer.onload()  # patch for the experts' weights
+            # buffer.rebuild()
 
     def enter_infer_mode(self):
         """
@@ -204,7 +224,7 @@
             self.megatron_offloader.offload_param()
 
         self.inference_engine.sync_model_weights(infer_params, load_format='megatron')
-        # torch.cuda.empty_cache()
+
 
     def exit_infer_mode(self):
         """
@@ -272,3 +292,100 @@
             self.megatron_offloader.offload_grad()
         # torch.cuda.empty_cache()
 
+###############################################################
+### The code below must be removed before this change is merged; it is only used to load inference weights and verify accuracy.
+
+def replace_state_dict_name(state_dict, vllm_dict, arch=None):
+    params_mapping = [
+        # (megatron core gpt model name, vllm model name)
+        ("embedding.word_embeddings", "model.embed_tokens"),
+        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
+        ("self_attention.linear_proj", "self_attn.o_proj"),
+        ("input_layernorm", "input_layernorm"),
+        ("pre_mlp_layernorm", "post_attention_layernorm"),
+        ("mlp.linear_fc1", "mlp.gate_up_proj"),
+        ("mlp.linear_fc2", "mlp.down_proj"),
+        ("decoder.final_layernorm", "model.norm"),
+        ("output_layer", "lm_head"),
+        # Deepseek add
+        ("self_attention.linear_qb", "self_attn.q_b_proj"),
+        ("self_attention.linear_kvb", "self_attn.kv_b_proj"),
+        ("mlp.router.weight", "mlp.gate.weight"),
+        ("mlp.router.expert_bias", "mlp.gate.e_score_correction_bias"),
+        ("mlp.shared_experts.linear_fc1", "mlp.shared_experts.gate_up_proj"),
+        ("mlp.shared_experts.linear_fc2", "mlp.shared_experts.down_proj"),
+        ("mlp.experts.weight1", "mlp.experts.w13_weight"),
+        ("mlp.experts.weight2", "mlp.experts.w2_weight"),
+        ("self_attention.q_layernorm", "self_attn.q_a_layernorm"),
+        ("self_attention.k_layernorm", "self_attn.kv_a_layernorm"),
+    ]
+
+
+    new_state_dict = {}
+    for name, loaded_weight in state_dict.items():
+        if "_extra_state" in name:
+            continue
+        if "Deepseek" in arch:
+            name = _replace_name_m2v_deepseek(name, params_mapping)
+        else:
+            name = _replace_name_m2v(name, params_mapping)
+
+        # the router bias in the raw weights is fp32
+        if "e_score_correction_bias" in name:
+            loaded_weight = loaded_weight.to(vllm_dict[name].dtype)
+
+        # adapt to the 'copy_' in the megatron weight loader to save memory
+        if "mlp.experts" in name:
+            loaded_weight = loaded_weight.view(vllm_dict[name].shape)
+
+        new_state_dict[name] = loaded_weight
+    return new_state_dict
+
+
+def _replace_name_m2v(name, name_mapping):
+    """
+    Transfer state dict names from megatron to vllm.
+    """
+    for m_name, v_name in name_mapping:
+        if m_name not in name:
+            continue
+        if "layers" in name:  # deal with decoder layers
+            name = name.replace("decoder", "model")
+            name_list = name.split(".")
+            if "layer_norm_weight" in name_list or "layer_norm_bias" in name_list:
+                param_name_list = name_list[:3]
+                param_name_list.append(v_name)
+                param_name = ".".join(param_name_list)
+            else:
+                param_name_list = name_list[:3]
+                weight_or_bias = name_list[-1]
+                param_name_list.append(v_name)
+                param_name_list.append(weight_or_bias)
+                param_name = ".".join(param_name_list)
+            return param_name
+        else:
+            param_name = name.replace(m_name, v_name)
+            return param_name
+    return name
+
+
+def _replace_name_m2v_deepseek(name, name_mapping):
+    """
+    Transfer state dict names from megatron to vllm.
+    """
+    for m_name, v_name in name_mapping:
+        if m_name not in name:
+            continue
+        if "layers" in name:  # deal with decoder layers
+            name = name.replace("decoder", "model")
+        param_name = name.replace(m_name, v_name)
+        return param_name
+    return name
+
+def print_memory(content, condition=True):
+    if condition:
+        torch.cuda.empty_cache()
+        print(
+            f'{content}',
+            f'torch allocated {torch.cuda.memory_allocated() / 1e9:.4f} GB, '
+            f'reserved {torch.cuda.memory_reserved() / 1e9:.4f} GB')
\ No newline at end of file
diff --git a/mindspeed_rl/workers/resharding/memory_buffer.py b/mindspeed_rl/workers/resharding/memory_buffer.py
index 022a94690bee4da6df387bc198be3d36036203f9..294ec7a445d9be64fcf1942df0e263275aebae2e 100644
--- a/mindspeed_rl/workers/resharding/memory_buffer.py
+++ b/mindspeed_rl/workers/resharding/memory_buffer.py
@@ -71,16 +71,23 @@ class MemoryBuffer:
         if param_name not in self.tensor_indices:
             raise KeyError(f"Parameter {param_name} not found in the buffer.")
 
-        start_index, shape = self.tensor_indices[param_name]
+        start_index, shape = self.tensor_indices[param_name]  # weight_name -> (index, shape)
         return self.get(shape, start_index)
 
 
 def calc_padded_numel(shape: torch.Size, dtype: torch.dtype):
     """for cuda memory alignment, make sure alignment by 128-bits"""
-    align_numel = 128 // torch.finfo(dtype).bits
-    numel = shape.numel()
-    return (numel + align_numel - 1) // align_numel * align_numel
+    align_numel = 128 // torch.finfo(dtype).bits  # number of elements of this dtype that fit in a 128-bit aligned block
+    numel = shape.numel()  # total number of elements in the tensor
+    return (numel + align_numel - 1) // align_numel * align_numel  # round the element count up to keep 128-bit alignment
 
+
+# Build the buffer used when EP grows: construct an experts_weight_buffer_meta that stores the expert weight metadata
+def get_weight_buffer_meta_from_buffer(weight_buffer_meta) -> Dict[str, Dict]:
+    experts_weight_buffer_meta = {}
+    for name, meta_info in sorted(weight_buffer_meta.items()):
+        if "mlp.experts" in name:
+            experts_weight_buffer_meta[name] = meta_info
+    return experts_weight_buffer_meta
 
 def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype, MemoryBuffer]:
     """Build the memory buffer given weight_buffer_meta
@@ -123,8 +130,60 @@ def build_memory_buffer(weight_buffer_meta: Dict[str, Dict]) -> Dict[torch.dtype
     return memory_buffers
 
 
-def build_model_weight_buffer(model: nn.Module, names_per_pp: List[str], get_weight_buffer_meta):
-    memory_buffers = [ModelWeightBuffer(model, weight_names, get_weight_buffer_meta) for weight_names in names_per_pp]
+def build_experts_memory_buffer(experts_weight_buffer_meta: Dict[str, Dict], experts_memory_expend_N) -> Dict[torch.dtype, MemoryBuffer]:
+    """Build the experts memory buffer given experts_weight_buffer_meta
+
+    Args:
+        weight_buffer_meta: contains mapping from name to a dictionary containing shape and dtype of the tensors
+
+    Returns: a large memory buffer for each dtype that can hold all the tensors
+
+    """
+    experts_memory_buffers = {}
+    total_numel_map = {}  # map from dtype to the total numel
+
+    for name, meta_info in sorted(experts_weight_buffer_meta.items()):
+        shape = meta_info['shape']
+        shape = torch.Size([experts_memory_expend_N, shape[0], shape[1], shape[2]])
+        dtype = meta_info['dtype']
+
+        if not isinstance(shape, torch.Size):
+            raise TypeError("Shape must be an instance of torch.Size")
+        if not isinstance(dtype, torch.dtype):
+            raise TypeError("dtype must be an instance of torch.dtype")
+        if dtype not in total_numel_map:
+            total_numel_map[dtype] = 0
+
+        tmp_numel = calc_padded_numel(shape, dtype)
+        total_numel_map[dtype] += tmp_numel
+
+    # Build the experts_memory_buffers dict from the computed dtypes and allocated sizes
+    for dtype, total_numel in total_numel_map.items():
+        # Create a buffer for each dtype with the total numel
+        experts_memory_buffers[dtype] = MemoryBuffer(total_numel, total_numel, dtype)
+
+    # Now, insert each tensor's index and shape for later retrieval by name
+    current_index_map = {}  # This keeps track of the current memory index for each dtype
+    for name, meta_info in sorted(experts_weight_buffer_meta.items()):
+        shape = meta_info['shape']
+
+        shape = torch.Size([experts_memory_expend_N, shape[0], shape[1], shape[2]])
+
+        dtype = meta_info['dtype']
+        buffer = experts_memory_buffers[dtype]
+        tensor_size = calc_padded_numel(shape, dtype)
+
+        # The first lookup for a dtype returns 0; it is then advanced to start_index + tensor_size, tracked in current_index_map keyed by dtype
+        start_index = current_index_map.get(dtype, 0)
+        current_index_map[dtype] = start_index + tensor_size
+
+        buffer.tensor_indices[name] = (start_index, shape)
+
+    return experts_memory_buffers
+
+
+def build_model_weight_buffer(model: nn.Module, names_per_pp: List[str], get_weight_buffer_meta, experts_memory_expend_N):
+    memory_buffers = [ModelWeightBuffer(model, weight_names, get_weight_buffer_meta, experts_memory_expend_N) for weight_names in names_per_pp]
     return memory_buffers
 
 
@@ -133,13 +192,24 @@ class ModelWeightBuffer:
    A factory class that processes a model's state_dict and returns memory buffers for the model parameters.
    It also provides a mapping between model parameter names and their corresponding memory buffer view.
""" - def __init__(self, model: nn.Module, weight_names: List, get_weight_buffer_meta): + def __init__(self, model: nn.Module, weight_names: List, get_weight_buffer_meta, experts_memory_expend_N): self.model = model self.get_weight_buffer_meta = get_weight_buffer_meta self.weight_buffer_meta = self.get_weight_buffer_meta(self.model, weight_names) self.weight_names = list(self.weight_buffer_meta.keys()) self.memory_buffers = None - # self.memory_buffers = build_memory_buffer(self.weight_buffer_meta) + + self.experts_memory_expend_N = experts_memory_expend_N + + if(self.experts_memory_expend_N>0): + # 如果EP增大 + self.experts_weight_buffer_meta = get_weight_buffer_meta_from_buffer(self.weight_buffer_meta) + + self.experts_weight_names = list(self.experts_weight_buffer_meta.keys()) + + self.experts_memory_buffers = build_experts_memory_buffer(self.experts_weight_buffer_meta, self.experts_memory_expend_N) + # self.experts_memory_buffers = None + def __getitem__(self, weight_name: str) -> torch.Tensor: return self.get_weight_by_name(weight_name) @@ -152,14 +222,25 @@ class ModelWeightBuffer: dtype = self.weight_buffer_meta[weight_name]['dtype'] self.memory_buffers[dtype].copy_by_name(weight_name, param) - def offload(self): + def copy_by_experts_name(self, weight_name: str, param): + dtype = self.experts_weight_buffer_meta[weight_name]['dtype'] + self.experts_memory_buffers[dtype].copy_by_name(weight_name, param) + + + def offload(self): # 新代码 直接销毁buffer for memory_buffer in self.memory_buffers.values(): memory_buffer.data = memory_buffer.data.to("cpu", non_blocking=False) - def onload(self): + for experts_memory_buffer in self.experts_memory_buffers.values(): + experts_memory_buffer.data = experts_memory_buffer.data.to("cpu", non_blocking=False) + + def onload(self):# 将缓冲区数据移动到当前 NPU 显存上 for memory_buffer in self.memory_buffers.values(): memory_buffer.data = memory_buffer.data.to(torch.cuda.current_device(), non_blocking=False) + for experts_memory_buffer in self.experts_memory_buffers.values(): + experts_memory_buffer.data = experts_memory_buffer.data.to(torch.cuda.current_device(), non_blocking=False) + def destroy(self): for memory_buffer in self.memory_buffers.values(): memory_buffer = None @@ -167,4 +248,6 @@ class ModelWeightBuffer: def rebuild(self): if self.memory_buffers is None: - self.memory_buffers = build_memory_buffer(self.weight_buffer_meta) \ No newline at end of file + self.memory_buffers = build_memory_buffer(self.weight_buffer_meta) + if(self.experts_memory_expend_N>0): + self.experts_memory_buffers = build_experts_memory_buffer(self.experts_weight_buffer_meta, self.experts_memory_expend_N) \ No newline at end of file diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py index a783e90663e8d0241f617cb15b292043173c7ff5..14aa49c5636730e88c9269c869469d708952b1d6 100644 --- a/mindspeed_rl/workers/resharding/vllm_weight_container.py +++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py @@ -25,8 +25,9 @@ import torch.distributed as dist import numpy as np from torch.distributed import new_group +import vllm.distributed.parallel_state as ps -from mindspeed_rl.workers.resharding.memory_buffer import build_model_weight_buffer +from mindspeed_rl.workers.resharding.memory_buffer import build_model_weight_buffer, calc_padded_numel import mindspeed_rl.workers.resharding.utils from mindspeed_rl.workers.resharding.utils import get_tensor_parallel_partition_dim, tp_md5_validate, \ update_md5_by_rank, 
compute_md5, validate_md5, _build_infer_param_dict, get_tp_allgather_group, \ @@ -66,10 +67,13 @@ class MegatronStyleVllmWeightContainer: self.weight_adaptor = weight_adaptor self._num_hidden_layers = self.model_config.num_hidden_layers + # pp configs self._pp_rank = self.parallel_state.get_pipeline_model_parallel_rank() self._pp_group = self.parallel_state.get_pipeline_model_parallel_group() - self._pp_size = self.parallel_state.get_pipeline_model_parallel_world_size() + self._pp_size = self.parallel_state.get_pipeline_model_parallel_world_size() #PP 内部 + self._world_size = dist.get_world_size() + self.pp_group_size = self._world_size // self._pp_size self._num_layer_list = self._build_num_layer_list(num_layer_list) self._vpp_size = self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK if self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK else 1 self._vpp_rank = self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE if self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE else 0 @@ -97,7 +101,13 @@ class MegatronStyleVllmWeightContainer: self._infer_ep_size = infer_expert_parallel_size self.moe_tp_extend_ep = moe_tp_extend_ep - self._world_size = dist.get_world_size() + + + + self.infer_expert_tensor_parallel_size = 1 + self.num_process = 1 + self._infer_ep_size = self._infer_ep_size * self._infer_tp_size + self.experts_memory_expend_N = self._infer_ep_size//self._ep_size # validate parallel configs self._validate_parallel_config() @@ -116,10 +126,9 @@ class MegatronStyleVllmWeightContainer: def _validate_parallel_config(self): if self._infer_pp_size != 1: raise ValueError("infer_pp_size != 1 not supported yet") - if self._infer_ep_size != 1: - raise ValueError("infer_ep_size != 1 not supported yet") - if self._ep_size > 1 and self._ep_size != self._infer_tp_size: - raise ValueError("For training EP, supports EP -> TP only currently.") + + if self._infer_ep_size % self._ep_size != 0: + raise ValueError("The training expert size should be divisibled by the inference expert size.") if self._ep_size > 1 and not self.moe_tp_extend_ep: raise ValueError("To enable training EP, you need to enable moe_tp_extend_ep and use GroupedMLP.") if self._pp_size < self._infer_pp_size: @@ -146,9 +155,11 @@ class MegatronStyleVllmWeightContainer: return the whole weight state dict for vllm, but in megatron style and names, needs megatron weight loader to further transfer for vllm """ + # TODO: 这里需要做一些处理,获得所有的推理权重 self._update_weight_buffers_intra_pp() self._update_weight_buffers_inter_pp() + self.send_receive_experts() params = self._get_all_params() params = _build_infer_param_dict(params=params) @@ -157,6 +168,7 @@ class MegatronStyleVllmWeightContainer: def _build_num_layer_list(self, num_layer_list): if num_layer_list: return [int(num_layers) for num_layers in num_layer_list.split(',')] + print(f"!!!!!!!!!======self._num_hidden_layers: {self._num_hidden_layers}; self._pp_size{self._pp_size}") if self._num_hidden_layers % self._pp_size != 0: raise ValueError("num_layers % pp_size == 0, please specify num_layer_list") return [self._num_hidden_layers // self._pp_size for _ in range(self._pp_size)] @@ -178,10 +190,17 @@ class MegatronStyleVllmWeightContainer: Build buffers from vllm state dict. Totally build train pp_size buffers, each buffer corresponds to a pack of megatron weight. Return a list of buffers, and a reference dict megatron_param_name->buffer. 
""" - vllm_names = list(dict(self.vllm_model.named_parameters()).keys()) + vllm_names = list(dict(self.vllm_model.named_parameters()).keys()) # 获取每个pp内部的weights name + + print(f"!!!!!!======self._infer_ep_size{self._infer_ep_size}, self._ep_size{self._ep_size}") + self.weight_names_per_pp = self.weight_adaptor.get_weight_names_per_pp(self._num_layer_list, vllm_names) + + print(f"======MegatronStyleVllmWeightContainer————weight_names_per_pp:{self.weight_names_per_pp}") # 相比vllm,多了一层[] [vllm_names] # 当前只有pp,没有vpp,只是多一层[],PP为多少,就有多少个元素 + self.weight_buffers = build_model_weight_buffer(self.vllm_model, self.weight_names_per_pp, - self.weight_adaptor.get_weight_buffer_meta) + self.weight_adaptor.get_weight_buffer_meta, + self.experts_memory_expend_N) def trans_ep_params_to_tp(self, megatron_param, name): """ @@ -264,31 +283,43 @@ class MegatronStyleVllmWeightContainer: async_op=False ) total_experts = self.num_local_experts * tp_size - return torch.cat(output_tensor_list, dim=1).reshape(hidden_size, total_experts, -1).permute(1, 0, 2) + res = torch.cat(output_tensor_list, dim=1).reshape(hidden_size, total_experts, -1) + if 'weight2' in name: + return res.permute(1, 2, 0).contiguous() + return res.permute(1, 0, 2).contiguous() + + # total_experts = self.num_local_experts * tp_size + # return torch.cat(output_tensor_list, dim=1).reshape(hidden_size, total_experts, -1).permute(1, 0, 2) def _update_weight_buffers_intra_pp(self): """ Here, we only update the current training pp_rank's buffer. + 更新当前流水线分区(pp_rank)的权重缓冲区(PP内部) """ def _transfer_from_megatron_division(megatron_param, name): """ Deal with the tp_param form train_tp to infer_tp. + 用于处理训练状态(train_tp)中的张量并将其转换为推理状态(infer_tp)。 """ - infer_param = self.allgather_tp_param(megatron_param, name) - infer_param = self.split_tp_params(infer_param, name) - infer_param = self.trans_ep_params_to_tp(infer_param, name) + infer_param = self.allgather_tp_param(megatron_param, name) # TP缩小 + infer_param = self.split_tp_params(infer_param, name) # TP扩大 与上面的只选做一个 + infer_param = self.trans_ep_params_to_tp(infer_param, name) # 如果expert做了TP,需要走这个函数 return infer_param pp_rank = self._pp_rank weight_buffer = self.weight_buffers[pp_rank] + # print(f"!!!!!!0410======weight_buffer.weight_names{weight_buffer.weight_names}") true_megatron_model = self._unwrap_megatron_model(self.megatron_model) normal_layer_func = partial(self.weight_adaptor.global2local_layer, num_layer_list=self._num_layer_list) name_pairs = sorted(list(set([(name, self.weight_adaptor.replace_name_i2t(normal_layer_func(name))) for name in weight_buffer.weight_names]))) + # print(f"!!!!!!0410==========name_pairs{name_pairs}") if self.enable_validate: self.origin_params_for_md5 = hashlib.md5() self.infer_params_for_md5 = [hashlib.md5() for _ in range(get_tp_allgather_world_size())] + + # 检查 linear_fc1 和 linear_fc2 权重形状是否符合特定关系(fc1 包含门控和扩展参数,因此大小是 fc2 的两倍)。不符合条件的模型不被支持。 for hf_name, megatron_name in name_pairs: if megatron_name.endswith("linear_fc1.weight"): fc2_name = megatron_name.replace("linear_fc1", "linear_fc2") @@ -300,12 +331,23 @@ class MegatronStyleVllmWeightContainer: megatron_params_dict = dict(true_megatron_model.named_buffers()) megatron_params_dict.update(true_megatron_model.named_parameters()) + for hf_name, megatron_name in name_pairs: - megatron_param = megatron_params_dict[megatron_name] - param = _transfer_from_megatron_division(megatron_param, megatron_name) - weight_buffer.copy_by_name(hf_name, param) + if((self._infer_ep_size>1 or self._ep_size>1) and "mlp.experts" in 
megatron_name): + pass + else: + megatron_param = megatron_params_dict[megatron_name] + param = _transfer_from_megatron_division(megatron_param, megatron_name)# 将训练态的参数转换为推理态:先将TP全部聚合 再分割为推理的状态 + weight_buffer.copy_by_name(hf_name, param) # 从训练态,copy 权重到 memory_buffer 将推理态权重拷贝到对应的推理缓冲区(按照推理态名称 hf_name 存储)。 + + + for hf_name, megatron_name in name_pairs: + if((self._infer_ep_size>1 or self._ep_size>1) and "mlp.experts" in megatron_name): + megatron_param = megatron_params_dict[megatron_name] + weight_buffer.copy_by_experts_name(hf_name, megatron_param) # tp md5 validate + # 通过 MD5 哈希校验,验证推理态权重(infer_params_for_md5)与训练态权重(origin_params_for_md5)的一致性,确保全收集和分割操作正确。 if self.enable_validate: tp_md5_validate(self.infer_params_for_md5, self.origin_params_for_md5, f"rank[{self._rank}] tp params allgather") @@ -328,6 +370,83 @@ class MegatronStyleVllmWeightContainer: dist.broadcast(md5_tensor_src, group=self._pp_group, src=global_src, async_op=False) validate_md5(md5_tensor_src, md5_tensor, f"rank[{self._rank}] pp resharding params") + for cur_pp_rank in range(self._pp_size): + # cur_pp_rank 不一定是当前卡 + print(f"======cur_pp_rank: {cur_pp_rank}") + global_src = dist.get_global_rank(group=self._pp_group, group_rank=cur_pp_rank) # 根据_pp_group和cur_pp_rank确定当前卡在所有卡上的rank + + # broadcast专家权重(experts memory buffer中的) + # step1 pp维度进行broadcast, 每张卡上都有完整的experts权重,但是散落在weight_buffers[cur_pp_rank].experts_memory_buffers中 + for dtype, experts_memory_buffer in self.weight_buffers[cur_pp_rank].experts_memory_buffers.items(): # experts_memory_buffer中的专家相关的权重,按照dtype作为索引的buffer + + dist.broadcast(tensor=experts_memory_buffer.data, src=global_src, group=self._pp_group, async_op=False) #每张卡上有完整的experts,但是分布在不同的memory上 + + pp_group_rank = self._rank // self.pp_group_size + + # 获取对应的dtype + for name, tensor_indices_value in sorted(experts_memory_buffer.tensor_indices.items()): + + shape = tensor_indices_value[1] # 是*N的 + + index = pp_group_rank % self.experts_memory_expend_N + + experts_tensor = experts_memory_buffer.get_by_name(name) + + experts_tensor_reshape = experts_tensor.view(shape) + weight_tensor_infer = experts_tensor_reshape[index] + + print(f"===self.experts_memory_expend_N{self.experts_memory_expend_N}; experts_tensor{experts_tensor.shape}; experts_tensor_reshape{experts_tensor_reshape.shape}; weight_tensor_infer{weight_tensor_infer.shape}") + + self.weight_buffers[cur_pp_rank].copy_by_name(name, weight_tensor_infer) + + # 卸载专家的buffer + experts_memory_buffer = None + self.weight_buffers[cur_pp_rank].experts_memory_buffers[dtype] = None + + for memory_buffer in self.weight_buffers[cur_pp_rank].experts_memory_buffers.values(): + memory_buffer = None + self.weight_buffers[cur_pp_rank].experts_memory_buffers = None + + + + def get_expert_router(self, cur_rank, train_tp_ep_size, infer_tp_ep_size, world_size): + for tp_ep_group_id in range(world_size // infer_tp_ep_size): + tp_ep_group = [i for i in range(tp_ep_group_id * infer_tp_ep_size, (tp_ep_group_id + 1) * infer_tp_ep_size)] + + # construct comm group if you have torch + # infer_tp_ep_group = dist.new_group(tp_ep_group) + if cur_rank in tp_ep_group: + global INFER_TP_EP_GROUP + INFER_TP_EP_GROUP = tp_ep_group + stride = infer_tp_ep_size // train_tp_ep_size + dev_array = np.array(INFER_TP_EP_GROUP).reshape(stride, train_tp_ep_size) + src_router = np.squeeze(dev_array.transpose().reshape(1,infer_tp_ep_size)).tolist() + src=src_router[cur_rank % infer_tp_ep_size] + dst=INFER_TP_EP_GROUP[src_router.index(cur_rank)] + print("ranks:",INFER_TP_EP_GROUP) + 
print("src_rank",src_router) + return src,dst + + def send_receive_experts(self): + cur_rank = dist.get_rank() + src_rank, dst_rank = self.get_expert_router(cur_rank, self._ep_size, self._infer_ep_size, self._world_size) + for cur_pp_rank in range(self._pp_size): + for memory_buffer in self.weight_buffers[cur_pp_rank].memory_buffers.values(): + for name in sorted(memory_buffer.tensor_indices.keys()): + if "mlp.experts" in name: + # 做收发 + tensor_to_send = memory_buffer.get_by_name(name) + tensor_to_replace = torch.empty_like(tensor_to_send) + print(f"Rank {cur_rank}: Sent tensor name '{name}' to Rank {dst_rank}") + print(f"Rank {cur_rank}: Replaced tensor name '{name}' from Rank {src_rank}") + send_op = dist.P2POp(dist.isend, tensor_to_send, dst_rank) + recv_op = dist.P2POp(dist.irecv,tensor_to_replace, src_rank) + reqs = dist.batch_isend_irecv([send_op,recv_op]) + for req in reqs: + req.wait() + memory_buffer.copy_by_name(name,tensor_to_replace) + + def _get_all_params(self): """Get all the parameters of the models in all pp ranks @@ -353,7 +472,7 @@ class MegatronStyleVllmWeightContainer: return if self._tp_size % self._infer_tp_size != 0: raise ValueError("self._tp_size must be divisible by self._infer_tp_size") - tp_allgather_size = self._tp_size // self._infer_tp_size + tp_allgather_size = self._tp_size if mindspeed_rl.workers.resharding.utils._TP_ALLGATHER_GROUP is not None: raise RuntimeError("Group for allgather tensor model parallel weight is already initialized") num_groups = self._world_size // tp_allgather_size @@ -432,7 +551,7 @@ class MegatronStyleVllmWeightContainer: 2. split train_tp params into groups (size: infer_tp_size) 3. return the corresponding param from group based on infer tp rank """ - if self._infer_tp_size <= self._tp_size: + if self._infer_tp_size <= self._tp_size or is_fake_tp_param(name, self.moe_tp_extend_ep): return param tp_group = get_tp_group() @@ -494,6 +613,9 @@ class MegatronStyleVllmWeightContainer: torch.distributed.all_gather(infer_param, param, group=tp_allgather_group) if self.enable_validate: update_md5_by_rank(infer_param, param, self.origin_params_for_md5, self.infer_params_for_md5) - infer_param = self._default_tp_concat_fn(name, param, infer_param) + part_len = len(infer_param) // self._infer_tp_size + start = self._rank % self._infer_tp_size + part_param = infer_param[part_len*start: part_len*(start+1)] + infer_param = self._default_tp_concat_fn(name, param, part_param) return infer_param diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py index 19ea46c2857a0bdacb46cffe919cbedee2c516a0..36a5d8c9468663a77cd4e0a0a5622105c5e5a4a8 100644 --- a/mindspeed_rl/workers/resharding/weight_adaptor.py +++ b/mindspeed_rl/workers/resharding/weight_adaptor.py @@ -216,17 +216,19 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): if valid_names and name not in valid_names: continue if 'kv_a_proj_with_mqa' in name: - q_param = dict(model.named_parameters()).get(name.replace('kv_a_proj_with_mqa', 'q_a_proj')) + # 将kv_a_proj_with_mqa和q_a_proj的tensor拼接,并用qkv_proj和拼接的结果替换掉原来kv_a_proj_with_mqa的对应部分 + q_param = dict(model.named_parameters()).get(name.replace('kv_a_proj_with_mqa', 'q_a_proj' if self.model_config.q_lora_rank else "q_proj")) qkv_param_shape = torch.cat([q_param, param], dim=0).shape qkv_name = name.replace('kv_a_proj_with_mqa', 'qkv_proj') weight_buffer_meta[qkv_name] = {'shape': qkv_param_shape, 'dtype': param.dtype} - elif 'q_a_proj' in name: + elif 'q_a_proj' in name or 
'q_proj' in name: continue else: weight_buffer_meta[name] = {'shape': param.shape, 'dtype': param.dtype} return weight_buffer_meta + class QwenMVWeightAdaptor(MegatronVLLMWeightAdaptor): """ Megatron-vLLM WeightAdaptor for Qwen model architectures. @@ -239,6 +241,7 @@ WEIGHT_ADAPTOR_REGISTRY = { "Qwen2ForCausalLM": QwenMVWeightAdaptor, "DeepseekV3ForCausalLM": DeepSeekMVWeightAdaptor, "DeepseekV2ForCausalLM": DeepSeekMVWeightAdaptor, + "CustomDeepseekV2ForCausalLM": DeepSeekMVWeightAdaptor, }