From 3fbce4cb0d7f2a85b0d9606ef1269a060188915d Mon Sep 17 00:00:00 2001 From: WeiCheng Tan Date: Fri, 25 Jul 2025 10:36:24 +0800 Subject: [PATCH] fix bug in weight loader time statistic --- .../model_executor/models/mf_models/deepseek_mtp.py | 2 +- .../model_executor/models/mf_models/deepseek_v3.py | 2 +- .../models/mf_models/deepseekv3_weight_processor.py | 4 ++-- vllm_mindspore/model_executor/models/mf_models/qwen2.py | 2 +- .../models/mf_models/qwen2_weight_processor.py | 4 ++-- .../model_executor/models/mf_models/weight_processor.py | 6 +++++- 6 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py index 3054d2cc..0a02ebf3 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py @@ -108,7 +108,7 @@ class DeepseekV3MTPForCausalLM(MfModelBase): def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, False) + weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, False, weights) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint, is_mtp_model=True) self.network.set_dynamic_inputs() dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype) diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index 3978ecb8..5d74b0ab 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -180,7 +180,7 @@ class DeepseekV3ForCausalLM(MfModelBase): self.mf_config, model, self.network, infer_data, do_predict=True ) else: - weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, self.is_quant) + 
weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, self.is_quant, weights) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) return None diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index 4242ff3b..cd291aaa 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -58,8 +58,8 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): """ - def __init__(self, config, network, is_quant): - super().__init__(config, network, is_quant) + def __init__(self, config, network, is_quant, weights_iter): + super().__init__(config, network, is_quant, weights_iter) self.num_layers = self.config.model.model_config.num_layers self.expert_num = self.config.moe_config.expert_num self.moe_split_tp = self.moe_tp_size > 1 diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py index a24a0f1b..7ae515bf 100644 --- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py +++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py @@ -75,7 +75,7 @@ class Qwen2ForCausalLM(MfModelBase): return network, network.lm_head def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, False) + weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, False, weights) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) return None diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py index 07b55c6c..3887241b 100644 --- a/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py +++ 
b/vllm_mindspore/model_executor/models/mf_models/qwen2_weight_processor.py @@ -39,8 +39,8 @@ class Qwen2WeightProcessor(BaseWeightProcessor): """ - def __init__(self, config, network, is_quant): - super().__init__(config, network, is_quant) + def __init__(self, config, network, is_quant, weights_iter): + super().__init__(config, network, is_quant, weights_iter) def infer_convert_outer_weight(self, src_hf_dir, hf_weight_map): """convert weight not in model""" diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py index 85d3d170..aeeb4526 100644 --- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py @@ -43,7 +43,7 @@ class BaseWeightProcessor: """ - def __init__(self, config, network, is_quant): + def __init__(self, config, network, is_quant, weights_iter): self.config = config self.network = network self.is_quant = is_quant @@ -71,6 +71,10 @@ class BaseWeightProcessor: self.parameter_dict = {} self.file_handles = {} + # Consume one item from the weights iterator to trigger vllm's native weight-loading timer initialization, so the reported "loading weights" time statistics are correct. + for _, _ in weights_iter: + break + def get_file_handles(self, filename): if filename not in self.file_handles: fp = safe_open(filename, framework="np") -- Gitee