diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py index 3054d2ccfba2e12a2bee64fc3df8648580a54ed6..0a02ebf3c483cbbdfa98f17e6935573779885f92 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py @@ -108,7 +108,7 @@ class DeepseekV3MTPForCausalLM(MfModelBase): def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, False) + weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, False, weights) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint, is_mtp_model=True) self.network.set_dynamic_inputs() dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype) diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index 3978ecb82c69d82ff452a2596cf45380cacad50c..5d74b0abf611d12eac08ae1e51fa05f51f4598c8 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -180,7 +180,7 @@ class DeepseekV3ForCausalLM(MfModelBase): self.mf_config, model, self.network, infer_data, do_predict=True ) else: - weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, self.is_quant) + weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, self.is_quant, weights) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) return None diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index 4242ff3b6301571c8213753fd76502ce0f6b71e8..cd291aaa453337ff2e4929127a7030d099e78794 100644 --- 
a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -58,8 +58,8 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): """ - def __init__(self, config, network, is_quant): - super().__init__(config, network, is_quant) + def __init__(self, config, network, is_quant, weights_iter): + super().__init__(config, network, is_quant, weights_iter) self.num_layers = self.config.model.model_config.num_layers self.expert_num = self.config.moe_config.expert_num self.moe_split_tp = self.moe_tp_size > 1 diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py index a24a0f1b98ed603b8f84f424a1b712c0a01f7a83..7ae515bfa8a409ada8d489b46bfb098550e1e103 100644 --- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py +++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py @@ -75,7 +75,7 @@ class Qwen2ForCausalLM(MfModelBase): return network, network.lm_head def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, False) + weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, False, weights) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) return None diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py index 07b55c6cc7d3f984f8998a837398716c8e0f5051..3887241b934704de7c1a4b79dc3a8731aecd9ed7 100644 --- a/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py @@ -39,8 +39,8 @@ class Qwen2WeightProcessor(BaseWeightProcessor): """ - def __init__(self, config, network, is_quant): - super().__init__(config, network, is_quant) + def __init__(self, config, network, is_quant, 
weights_iter): + super().__init__(config, network, is_quant, weights_iter) def infer_convert_outer_weight(self, src_hf_dir, hf_weight_map): """convert weight not in model""" diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py index 85d3d170d6d74e7b1ce7cd8342297c60fead6a2d..aeeb452698c9365c15cc5d76a077835b4594c358 100644 --- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py @@ -43,7 +43,7 @@ class BaseWeightProcessor: """ - def __init__(self, config, network, is_quant): + def __init__(self, config, network, is_quant, weights_iter): self.config = config self.network = network self.is_quant = is_quant @@ -71,6 +71,10 @@ class BaseWeightProcessor: self.parameter_dict = {} self.file_handles = {} + # Consume the first item of the vLLM native weights iterator so its load-time measurement is triggered, keeping the reported weight-loading time cost correct. + for _, _ in weights_iter: + break + def get_file_handles(self, filename): if filename not in self.file_handles: fp = safe_open(filename, framework="np")