diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py
index ed74ba9e38d54f7e507951ea585106f833e83d6b..5460bbbefc6292f347618ba835e991a1a7508afd 100644
--- a/vllm_mindspore/engine/arg_utils.py
+++ b/vllm_mindspore/engine/arg_utils.py
@@ -51,11 +51,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                            recommend_to_remove=True)
         return False
 
-    if self.additional_config != EngineArgs.additional_config:
-        _raise_or_fallback(feature_name="--additional-config",
-                           recommend_to_remove=False)
-        return False
-
     # Xgrammar and Guidance are supported.
     SUPPORTED_GUIDED_DECODING = [
         "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index d09d8d265c2e9daa8114b6e6afa0da1e85cc99f4..819be42faa58d498f6eaead1c10c2ddfe4d74f2d 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -144,6 +144,7 @@ class DeepseekV3ForCausalLM(MfModelBase):
         self.mf_config.load_checkpoint = self.get_model_path()
 
         self.mf_model_config = DeepseekV3Config_MF(**self.mf_config.model.model_config)
+        self.mf_model_config.enable_micro_batch = self.enable_micro_batch
         if self.mf_config.moe_config:
             self.mf_model_config.moe_config = self.mf_config.moe_config
         # dispatch/combine in moe need max_num_seqs as global_max_bs
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 8a5a07778cd2095b24289db1ac309da8c5ad28e4..70577b893c6042cfe65e42240c1f7cadf45c4d6e 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -28,13 +28,15 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
 from vllm.forward_context import get_forward_context
 import vllm.envs as envs
 
 import mindspore as ms
-from mindspore import Tensor
+from mindspore import Tensor, mint
 from mindspore.common.api import _pynative_executor
+from mindspore.communication import get_rank
 
 from mindformers.tools.register.config import MindFormerConfig
 from mindformers.core.context import build_mf_context
@@ -54,6 +56,8 @@ class MfModelBase(MsModelBase):
         )
 
         self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG"))
+        self.rank_id = get_rank()
+        self.dp_size = get_dp_group().world_size
         build_mf_context(self.mf_config)
         build_parallel_config(self.mf_config)
         self.mf_config.model.model_config.parallel_config = (
@@ -196,11 +200,21 @@ class MfModelBase(MsModelBase):
             attn_metadata = self._dummy_attention_metadata(input_ids, positions)
         model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, attn_metadata)
         model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+
+        # enable_mb_split is True only in large-EP serving, when micro-batch is enabled and the per-DP batch size > 1.
+        enable_mb_split = self.is_enable_micro_batch_split(is_prefill, model_inputs["q_seq_lens"])
         if is_prefill:
-            self.network.phase = "prefill"
-            if not self.set_flags or is_pynative():
-                self.network.add_flags_custom(is_first_iteration=True)
+            if self.enable_micro_batch:
+                self.network.phase = "prefill" if not enable_mb_split else "prefill_micro_batch"
+                if not self.set_flags or is_pynative() or enable_mb_split:
+                    self.network.add_flags_custom(is_first_iteration=True)
+                    self.network.add_flags_enable_micro_batch(enable_micro_batch=enable_mb_split)
+            else:
+                self.network.phase = "prefill"
+                if not self.set_flags or is_pynative():
+                    self.network.add_flags_custom(is_first_iteration=True)
+
             hidden_states = self.network(**model_inputs)
             self.network.phase = "increment"
             if not self.set_flags or is_pynative():
@@ -241,3 +255,12 @@ class MfModelBase(MsModelBase):
 
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
         raise NotImplementedError("load_weight not implemented.")
+
+    def is_enable_micro_batch_split(self, is_prefill, q_seq_lens):
+        """Decide whether this prefill step should be split into micro batches."""
+        if self.enable_micro_batch:
+            is_prefill_cur_dp = mint.ones((1), dtype=ms.int8) if is_prefill else mint.zeros((1), dtype=ms.int8)
+            is_prefill_all_dp = get_dp_group().all_gather(is_prefill_cur_dp)
+            return is_prefill_all_dp.sum() == self.dp_size and q_seq_lens.shape[0] > 1
+        else:
+            return False
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 0d933a2db438919cf68388833e1f95c572436c81..c6d6a83ffb1e68ca5be314cc9f684d9186d74dda 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -112,6 +112,9 @@ class MsModelBase():
         self.parallel_config = vllm_config.parallel_config
         self.load_config = vllm_config.load_config
         self.scheduler_config = vllm_config.scheduler_config
+        self.enable_micro_batch = \
+            vllm_config.additional_config.get('enable_micro_batch', 0) == 1 \
+            if vllm_config.additional_config is not None else False
 
         self.modules_dict = None
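
Usage note (not part of the patch): with the --additional-config check removed from _is_v1_supported_oracle, the flag read by MsModelBase.__init__ can be supplied through vLLM's additional config, for example --additional-config '{"enable_micro_batch": 1}' on the CLI. Below is a minimal Python sketch; it assumes LLM() forwards additional_config to EngineArgs, and the model path and prompt are placeholders rather than values taken from this change.

# Minimal sketch, assuming vllm-mindspore is installed and LLM() forwards
# additional_config through EngineArgs into VllmConfig.additional_config.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/deepseek-v3",                  # placeholder checkpoint path
    additional_config={"enable_micro_batch": 1},   # picked up in MsModelBase.__init__
)
print(llm.generate(["hello"], SamplingParams(max_tokens=8)))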