From a8d53d3c3320123082707d50313cea4bac4f5cce Mon Sep 17 00:00:00 2001 From: twc Date: Thu, 24 Apr 2025 11:49:22 +0800 Subject: [PATCH] model optimizer, support TH compute --- .../models/mf_models/mf_model_base.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index 893d91a51..bbef1315a 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -45,17 +45,6 @@ from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerT logger = init_logger(__name__) -def _pad_to_max(x, max_len): - return x + [-1] * (max_len - len(x)) - - -def _batch_seq(input_tokens, prefill): - if prefill: - return ms.ops.expand_dims(input_tokens, 0).to(ms.int32) - - return ms.mint.reshape(input_tokens, (-1, 1)).to(ms.int32) - - class MfModelBase(MsModelBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super(MfModelBase, self).__init__( @@ -113,8 +102,8 @@ class MfModelBase(MsModelBase): attention_mask = self.casual_mask.gen_attention_mask(is_prefill, position_ids, query_lens) model_inputs = {} - model_inputs["input_ids"] = _batch_seq(input_ids, is_prefill) - model_inputs["batch_valid_length"] = ms.Tensor.from_numpy(np.expand_dims(seq_lens_np, 0)) + model_inputs["input_ids"] = input_ids.astype(ms.int32) + model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np) model_inputs["block_tables"] = attn_metadata.block_tables model_inputs["slot_mapping"] = attn_metadata.slot_mapping model_inputs["position_ids"] = position_ids @@ -167,7 +156,6 @@ class MfModelBase(MsModelBase): else: hidden_states = hidden_states.index_select(0, selected_token_indices) logits = self.lm_head(hidden_states) - logits = logits.reshape(-1, logits.shape[-1]) return logits -- Gitee