diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index 06beef1435da076bb07c7b537f0a64f3c1e93ec3..db58b6bd80199073884fcfe39a91164cbe12b48c 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -52,7 +52,7 @@ from vllm_mindspore.model_executor.models.model_base import MsModelBase from vllm_mindspore.utils import calc_block_num import mindspore as ms -from mindspore import Tensor, JitConfig, Model +from mindspore import Tensor, JitConfig, Model, mutable from vllm_mindspore.model_executor.models.mf_models.deepseekv3_infer_parallelism import DeepseekInferParallelism from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask @@ -120,6 +120,7 @@ class DeepseekV3ForCausalLM(MsModelBase): if self.mf_config.moe_config: self.mf_model_config.moe_config = self.mf_config.moe_config self.mf_model_config.return_hidden_states = True + setattr(self.mf_model_config, 'npu_mem_size', -1) self.is_quant = bool(hasattr(self.mf_model_config, "quantization_config") and self.mf_model_config.quantization_config) @@ -158,8 +159,6 @@ class DeepseekV3ForCausalLM(MsModelBase): self.network._jit_config_dict = JitConfig( jit_level="O0", infer_boost="on" ).jit_config_dict - self.mf_kvcaches_init = False - self.sampler = get_sampler() self.set_modules({"model": self.network}) @@ -174,19 +173,13 @@ class DeepseekV3ForCausalLM(MsModelBase): self.casual_mask = LowerTriangularMask(mf_model_config=self.mf_model_config) self.set_flags = False - def update_mf_kvcaches(self): - if self.mf_kvcaches_init: - return - + def get_key_cache(self): + key_cache = [] forward_context = get_forward_context() for i in range(self.mf_model_config.num_layers): k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] - mf_k_cache, _ = self.network.kvcache(i) - - mf_k_cache.set_device_address( - k_cache._data_ptr(), k_cache.shape, k_cache.dtype - ) - self.mf_kvcaches_init = True + key_cache.append(k_cache) + return mutable(key_cache) def forward( self, @@ -197,7 +190,7 @@ class DeepseekV3ForCausalLM(MsModelBase): intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, ) -> Union[Tensor, IntermediateTensors]: - self.update_mf_kvcaches() + key_cache = self.get_key_cache() query_lens = attn_metadata.query_lens kv_cache_lens = attn_metadata.seq_lens_tensor.asnumpy() - query_lens @@ -223,6 +216,7 @@ class DeepseekV3ForCausalLM(MsModelBase): model_inputs["position_ids"] = position_ids model_inputs["q_seq_lens"] = q_seq_lens model_inputs["attention_mask"] = attention_mask + model_inputs["key_cache"] = key_cache if is_prefill: self.network.phase = "prefill"