From dacfef799a751a17eff11accbcb853b8ba071547 Mon Sep 17 00:00:00 2001 From: yang guodong Date: Mon, 14 Jul 2025 14:57:13 +0800 Subject: [PATCH] fix load safetensors --- vllm_mindspore.patch | 55 +++++++++++++++++++ .../models/mf_models/mf_model_base.py | 4 +- .../models/mf_models/weight_processor.py | 18 +++--- 3 files changed, 66 insertions(+), 11 deletions(-) create mode 100644 vllm_mindspore.patch diff --git a/vllm_mindspore.patch b/vllm_mindspore.patch new file mode 100644 index 00000000..76d586f7 --- /dev/null +++ b/vllm_mindspore.patch @@ -0,0 +1,55 @@ +diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +index 16f4a52..1d9792d 100644 +--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py ++++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +@@ -155,7 +155,7 @@ class MfModelBase(MsModelBase): + attention_mask = self.casual_mask.gen_attention_mask(is_prefill, position_ids, query_lens) + + model_inputs = {} +- model_inputs["input_ids"] = input_ids.astype(ms.int32) ++ model_inputs["input_ids"] = input_ids.astype(ms.int32) * 1 + model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np) + model_inputs["block_tables"] = attn_metadata.block_tables * 1 + model_inputs["slot_mapping"] = attn_metadata.slot_mapping +@@ -174,7 +174,7 @@ class MfModelBase(MsModelBase): + attention_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np) + + model_inputs = {} +- model_inputs["input_ids"] = input_ids.astype(ms.int32) ++ model_inputs["input_ids"] = input_ids.astype(ms.int32) * 1 + model_inputs["batch_valid_length"] = ms.from_numpy(attn_metadata.seq_lens_np) + model_inputs["block_tables"] = attn_metadata.block_tables * 1 + model_inputs["slot_mapping"] = attn_metadata.slot_mapping +diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +index 89d786e..7fe04b9 100644 +--- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py ++++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +@@ -131,19 +131,19 @@ class BaseWeightProcessor: + np_data = sf_file.get_slice(hf_param_name) + shape = np_data.get_shape() + if split_axis == 0: +- split_size = shape[0] // self.global_group_size +- start = self.global_rank_id * split_size +- stop = (self.global_rank_id + 1) * split_size ++ split_size = shape[0] // self.tp_group_size ++ start = self.tp_rank_id * split_size ++ stop = (self.tp_rank_id + 1) * split_size + split_data = np_data[start:stop] + elif split_axis == 1: +- split_size = shape[1] // self.global_group_size +- start = self.global_rank_id * split_size +- stop = (self.global_rank_id + 1) * split_size ++ split_size = shape[1] // self.tp_group_size ++ start = self.tp_rank_id * split_size ++ stop = (self.tp_rank_id + 1) * split_size + split_data = np_data[:, start:stop] + elif split_axis == 2: +- split_size = shape[2] // self.global_group_size +- start = self.global_rank_id * split_size +- stop = (self.global_rank_id + 1) * split_size ++ split_size = shape[2] // self.tp_group_size ++ start = self.tp_rank_id * split_size ++ stop = (self.tp_rank_id + 1) * split_size + split_data = np_data[:, :, start:stop] + else: + raise ValueError("split_axis:{} is not supported.".format(split_axis)) diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index 16f4a523..1d9792df 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -155,7 +155,7 @@ class MfModelBase(MsModelBase): attention_mask = self.casual_mask.gen_attention_mask(is_prefill, position_ids, query_lens) model_inputs = {} - model_inputs["input_ids"] = input_ids.astype(ms.int32) + model_inputs["input_ids"] = input_ids.astype(ms.int32) * 1 model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np) model_inputs["block_tables"] = attn_metadata.block_tables * 1 model_inputs["slot_mapping"] = attn_metadata.slot_mapping @@ -174,7 +174,7 @@ class MfModelBase(MsModelBase): attention_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np) model_inputs = {} - model_inputs["input_ids"] = input_ids.astype(ms.int32) + model_inputs["input_ids"] = input_ids.astype(ms.int32) * 1 model_inputs["batch_valid_length"] = ms.from_numpy(attn_metadata.seq_lens_np) model_inputs["block_tables"] = attn_metadata.block_tables * 1 model_inputs["slot_mapping"] = attn_metadata.slot_mapping diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py index 89d786eb..7fe04b97 100644 --- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py @@ -131,19 +131,19 @@ class BaseWeightProcessor: np_data = sf_file.get_slice(hf_param_name) shape = np_data.get_shape() if split_axis == 0: - split_size = shape[0] // self.global_group_size - start = self.global_rank_id * split_size - stop = (self.global_rank_id + 1) * split_size + split_size = shape[0] // self.tp_group_size + start = self.tp_rank_id * split_size + stop = (self.tp_rank_id + 1) * split_size split_data = np_data[start:stop] elif split_axis == 1: - split_size = shape[1] // self.global_group_size - start = self.global_rank_id * split_size - stop = (self.global_rank_id + 1) * split_size + split_size = shape[1] // self.tp_group_size + start = self.tp_rank_id * split_size + stop = (self.tp_rank_id + 1) * split_size split_data = np_data[:, start:stop] elif split_axis == 2: - split_size = shape[2] // self.global_group_size - start = self.global_rank_id * split_size - stop = (self.global_rank_id + 1) * split_size + split_size = shape[2] // self.tp_group_size + start = self.tp_rank_id * split_size + stop = (self.tp_rank_id + 1) * split_size split_data = np_data[:, :, start:stop] else: raise ValueError("split_axis:{} is not supported.".format(split_axis)) -- Gitee