From bde3b0bbeb8c135c0db6ae0a0f9973a8d45469b9 Mon Sep 17 00:00:00 2001
From: zhang_xu_hao1230
Date: Fri, 21 Mar 2025 15:01:19 +0800
Subject: [PATCH] Add network ST test cases
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tests/st/python/config/__init__.py            |   0
 .../config/predict_deepseek_r1_671b_w8a8.yaml | 125 +++++++++++++++++
 .../config/predict_qwen2_5_7b_instruct.yaml   | 126 ++++++++++++++++++
 tests/st/python/set_env.py                    |  57 ++++++++
 tests/st/python/test_vllm_deepseek_part.py    |  77 +++++++++++
 tests/st/python/test_vllm_mf_qwen_7b.py       |  77 +++++++++++
 tests/st/python/test_vllm_mf_qwen_7b_mss.py   |  77 +++++++++++
 .../{test_demo.py => test_vllm_qwen_7b.py}    |  44 +++---
 8 files changed, 567 insertions(+), 16 deletions(-)
 create mode 100644 tests/st/python/config/__init__.py
 create mode 100644 tests/st/python/config/predict_deepseek_r1_671b_w8a8.yaml
 create mode 100644 tests/st/python/config/predict_qwen2_5_7b_instruct.yaml
 create mode 100644 tests/st/python/set_env.py
 create mode 100644 tests/st/python/test_vllm_deepseek_part.py
 create mode 100644 tests/st/python/test_vllm_mf_qwen_7b.py
 create mode 100644 tests/st/python/test_vllm_mf_qwen_7b_mss.py
 rename tests/st/python/{test_demo.py => test_vllm_qwen_7b.py} (58%)

diff --git a/tests/st/python/config/__init__.py b/tests/st/python/config/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/st/python/config/predict_deepseek_r1_671b_w8a8.yaml b/tests/st/python/config/predict_deepseek_r1_671b_w8a8.yaml
new file mode 100644
index 0000000..5a5e9d6
--- /dev/null
+++ b/tests/st/python/config/predict_deepseek_r1_671b_w8a8.yaml
@@ -0,0 +1,125 @@
+seed: 0
+output_dir: './output' # path to save checkpoint/strategy
+run_mode: 'predict'
+use_parallel: True
+
+load_checkpoint: "/path/to/deepseekr1/model_w8a8_ckpt"
+load_ckpt_format: "safetensors"
+auto_trans_ckpt: True # If true, auto transform load_checkpoint to load in distributed model
+
+# trainer config
+trainer:
+  type: CausalLanguageModelingTrainer
+  model_name: 'DeepSeekR1-W8A8'
+
+# default parallel of device num = 16 for Atlas 800T A2
+parallel_config:
+  model_parallel: 16
+  pipeline_stage: 1
+  expert_parallel: 1
+  vocab_emb_dp: False
+
+# mindspore context init config
+context:
+  mode: 0 # 0--Graph Mode; 1--Pynative Mode
+  max_device_memory: "61GB"
+  device_id: 0
+  affinity_cpu_list: None
+
+kernel_launch_group:
+  thread_num: 4
+  kernel_group_num: 16
+
+# parallel context config
+parallel:
+  parallel_mode: "STAND_ALONE" # use 'STAND_ALONE' mode for inference with parallelism in frontend
+  full_batch: False
+  strategy_ckpt_save_file: "./ckpt_strategy.ckpt"
+
+# model config
+model:
+  model_config:
+    type: DeepseekV3Config
+    auto_register: deepseek3_config.DeepseekV3Config
+    batch_size: 1 # add for incre predict
+    seq_length: 4096
+    hidden_size: 7168
+    num_layers: 4
+    num_heads: 128
+    max_position_embeddings: 163840
+    intermediate_size: 18432
+    kv_lora_rank: 512
+    q_lora_rank: 1536
+    qk_rope_head_dim: 64
+    v_head_dim: 128
+    qk_nope_head_dim: 128
+    vocab_size: 129280
+    multiple_of: 256
+    rms_norm_eps: 1.0e-6
+    bos_token_id: 0
+    eos_token_id: 1
+    pad_token_id: 1
+    ignore_token_id: -100
+    compute_dtype: "bfloat16"
+    layernorm_compute_type: "bfloat16"
+    softmax_compute_type: "bfloat16"
+    rotary_dtype: "bfloat16"
+    router_dense_type: "bfloat16"
+    param_init_type: "bfloat16"
+    scaling_factor:
+      beta_fast: 32.0
+      beta_slow: 1.0
+      factor: 40.0
+      mscale: 1.0
+      mscale_all_dim: 1.0
+      original_max_position_embeddings: 4096
+    use_past: True
+    extend_method: "YARN"
+    use_flash_attention: True
+    block_size: 16
+    num_blocks: 512
+    offset: 0
+    checkpoint_name_or_path: ""
+    repetition_penalty: 1
+    max_decode_length: 1024
+    top_k: 1
+    top_p: 1
+    theta: 10000.0
+    do_sample: False
+    is_dynamic: True
+    qkv_concat: False
+    ffn_concat: True
+    quantization_config:
+      quant_method: 'ptq'
+      weight_dtype: 'int8'
+      activation_dtype: 'int8'
+    auto_map:
+      AutoConfig: deepseek3_config.DeepseekV3Config
+      AutoModel: deepseek3.DeepseekV3ForCausalLM
+  arch:
+    type: DeepseekV3ForCausalLM
+    auto_register: deepseek3.DeepseekV3ForCausalLM
+
+moe_config:
+  expert_num: 256
+  num_experts_chosen: 8
+  routing_policy: "TopkRouterV2"
+  shared_expert_num: 1
+  routed_scaling_factor: 2.5
+  first_k_dense_replace: 3
+  moe_intermediate_size: 2048
+  topk_group: 4
+  n_group: 8
+
+processor:
+  return_tensors: ms
+  tokenizer:
+    unk_token: ''
+    bos_token: '<|begin▁of▁sentence|>'
+    eos_token: '<|end▁of▁sentence|>'
+    pad_token: '<|end▁of▁sentence|>'
+    type: LlamaTokenizerFast
+    vocab_file: '/path/to/deepseekr1/tokenizer.json'
+    tokenizer_file: '/path/to/deepseekr1/tokenizer.json'
+    chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{{'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
+  type: LlamaProcessor
diff --git a/tests/st/python/config/predict_qwen2_5_7b_instruct.yaml b/tests/st/python/config/predict_qwen2_5_7b_instruct.yaml
new file mode 100644
index 0000000..821e33f
--- /dev/null
+++ b/tests/st/python/config/predict_qwen2_5_7b_instruct.yaml
@@ -0,0 +1,126 @@
+seed: 0
+output_dir: './output' # path to save checkpoint/strategy
+load_checkpoint: ''
+src_strategy_path_or_dir: ''
+auto_trans_ckpt: False # If true, auto transform load_checkpoint to load in distributed model
+only_save_strategy: False
+resume_training: False
+use_parallel: False
+run_mode: 'predict'
+
+# trainer config
+trainer:
+  type: CausalLanguageModelingTrainer
+  model_name: 'qwen2_5_7b'
+
+# runner config
+runner_config:
+  epochs: 5
+  batch_size: 1
+  sink_mode: True
+  sink_size: 2
+runner_wrapper:
+  type: MFTrainOneStepCell
+  scale_sense:
+    type: DynamicLossScaleUpdateCell
+    loss_scale_value: 65536
+    scale_factor: 2
+    scale_window: 1000
+  use_clip_grad: True
+
+# default parallel of device num = 8 for Atlas 800T A2
+parallel_config:
+  data_parallel: 1
+  model_parallel: 1
+  pipeline_stage: 1
+  micro_batch_num: 1
+  vocab_emb_dp: False
+  gradient_aggregation_group: 4
+# when model_parallel is greater than 1, setting micro_batch_interleave_num=2 may accelerate the training process.
+micro_batch_interleave_num: 1
+
+model:
+  model_config:
+    type: LlamaConfig
+    batch_size: 1
+    seq_length: 32768
+    hidden_size: 3584
+    num_layers: 28
+    num_heads: 28
+    n_kv_heads: 4
+    vocab_size: 152064
+    intermediate_size: 18944
+    max_position_embeddings: 32768
+    qkv_has_bias: True
+    rms_norm_eps: 1.0e-6
+    theta: 1000000.0
+    emb_dropout_prob: 0.0
+    eos_token_id: [151645, 151643]
+    pad_token_id: 151643
+    bos_token_id: 151643
+    compute_dtype: "bfloat16"
+    layernorm_compute_type: "float32"
+    softmax_compute_type: "float32"
+    rotary_dtype: "bfloat16"
+    param_init_type: "bfloat16"
+    use_past: True
+    use_flash_attention: True
+    block_size: 32
+    num_blocks: 1024
+    use_past_shard: False
+    offset: 0
+    checkpoint_name_or_path: ""
+    repetition_penalty: 1.05
+    max_decode_length: 512
+    top_k: 20
+    top_p: 0.8
+    temperature: 0.7
+    do_sample: True
+    is_dynamic: True
+    qkv_concat: True
+    auto_map:
+      AutoTokenizer: [qwen2_5_tokenizer.Qwen2Tokenizer, null]
+
+  arch:
+    type: LlamaForCausalLM
+
+processor:
+  return_tensors: ms
+  tokenizer:
+    model_max_length: 131072
+    bos_token: null
+    eos_token: "<|im_end|>"
+    unk_token: null
+    pad_token: "<|endoftext|>"
+    vocab_file: "/path/to/vocab.json"
+    merges_file: "/path/to/merges.txt"
+    chat_template: "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- messages[0]['content'] }}\n    {%- else %}\n        {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n    {%- endif %}\n    {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else %}\n        {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role }}\n        {%- if message.content %}\n            {{- '\\n' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- '\\n<tool_call>\\n{\"name\": \"' }}\n            {{- tool_call.name }}\n            {{- '\", \"arguments\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- '}\\n</tool_call>' }}\n        {%- endfor %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
+    type: Qwen2Tokenizer
+  type: Qwen2Processor
+
+# mindspore context init config
+context:
+  mode: 0 # 0--Graph Mode; 1--Pynative Mode
+  device_target: "Ascend"
+  ascend_config:
+    precision_mode: "must_keep_origin_dtype"
+  max_call_depth: 10000
+  max_device_memory: "59GB"
+  save_graphs: False
+  save_graphs_path: "./graph"
+  device_id: 0
+
+# parallel context config
+parallel:
+  parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel
+  gradients_mean: False
+  enable_alltoall: False
+  full_batch: True
+  search_mode: "sharding_propagation"
+  enable_parallel_optimizer: False
+  strategy_ckpt_config:
+    save_file: "./ckpt_strategy.ckpt"
+    only_trainable_params: False
+  parallel_optimizer_config:
+    gradient_accumulation_shard: False
+    parallel_optimizer_threshold: 64
diff --git a/tests/st/python/set_env.py b/tests/st/python/set_env.py
new file mode 100644
index 0000000..575a844
--- /dev/null
+++ b/tests/st/python/set_env.py
@@ -0,0 +1,57 @@
+import os
+import sys
+from typing import Dict, Optional
+
+mindformers_path = "/home/jenkins/mindspore/testcases/testcases/tests/mindformers"
+
+if mindformers_path not in sys.path:
+    sys.path.insert(0, mindformers_path)
+
+current_pythonpath = os.environ.get("PYTHONPATH", "")
+if current_pythonpath:
+    os.environ["PYTHONPATH"] = f"{mindformers_path}:{current_pythonpath}"
+else:
+    os.environ["PYTHONPATH"] = mindformers_path
+
+class EnvVarManager:
+    def __init__(self):
+        self._original_env: Dict[str, Optional[str]] = {}
+        self._managed_vars: Dict[str, str] = {}
+
+    def set_env_var(self, var_name: str, value: str) -> None:
+        """Set an environment variable and record its original value (if any)."""
+        if var_name not in self._original_env:
+            # Save the original value, even when the variable does not exist (stored as None).
+            self._original_env[var_name] = os.environ.get(var_name)
+
+        os.environ[var_name] = value
+        self._managed_vars[var_name] = value
+
+    def unset_env_var(self, var_name: str) -> None:
+        """Unset a previously set environment variable and restore its original value."""
+        if var_name not in self._original_env:
+            raise ValueError(f"Variable {var_name} was not set by this manager")
+
+        original_value = self._original_env[var_name]
+        if original_value is not None:
+            os.environ[var_name] = original_value
+        else:
+            if var_name in os.environ:
+                del os.environ[var_name]
+
+        del self._original_env[var_name]
+        del self._managed_vars[var_name]
+
+    def unset_all(self) -> None:
+        """Unset all environment variables that were set through this manager."""
+        for var_name in list(self._managed_vars.keys()):
+            self.unset_env_var(var_name)
+
+    def get_managed_vars(self) -> Dict[str, str]:
+        """Return a copy of all environment variables currently managed by this manager."""
+        return self._managed_vars.copy()
+
+    def setup_ai_environment(self, env_vars: Dict[str, str]) -> None:
+        """Set AI-related environment variables from the given mapping."""
+        for var_name, value in env_vars.items():
+            self.set_env_var(var_name, value)
diff --git a/tests/st/python/test_vllm_deepseek_part.py b/tests/st/python/test_vllm_deepseek_part.py
new file mode 100644
index 0000000..ce18a1e
--- /dev/null
+++ b/tests/st/python/test_vllm_deepseek_part.py
@@ -0,0 +1,77 @@
+# Copyright 2024 The vLLM team.
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf deepseek r1."""
+import os
+import pytest
+from . import set_env
+env_manager = set_env.EnvVarManager()
+# define env vars
+env_vars = {
+    "MINDFORMERS_MODEL_CONFIG": "./config/predict_deepseek_r1_671b_w8a8.yaml",
+    "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+    "vLLM_MODEL_BACKEND": "MindFormers",
+    "vLLM_MODEL_MEMORY_USE_GB": "40",
+    "ASCEND_TOTAL_MEMORY_GB": "60",
+    "MS_ENABLE_LCCL": "off",
+    "HCCL_OP_EXPANSION_MODE": "AIV",
+    "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
+    "MS_ALLOC_CONF": "enable_vmm:True",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+
+class TestDeepSeek:
+    """
+    Test DeepSeek.
+    """
+
+    @pytest.mark.level0
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_single
+    def test_deepseek_r1(self):
+        """
+        test case deepseek r1 w8a8
+        """
+
+        # Sample prompts.
+        prompts = [
+            "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n",
+        ]
+
+        # Create a sampling params object.
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+        # Create an LLM.
+        llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-W8A8", trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=8)
+        # Generate texts from the prompts. The output is a list of RequestOutput objects
+        # that contain the prompt, generated text, and other information.
+        outputs = llm.generate(prompts, sampling_params)
+        expect_list = ['ugs611ాలు哒ాలు mahassisemaSTE的道德']
+        # Print the outputs.
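+        # Note: decoding here is greedy (temperature=0.0, top_k=1), so for a
+        # fixed checkpoint the generated text is deterministic and the exact
+        # string comparison against expect_list below is a valid check.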
+        for i, output in enumerate(outputs):
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+            assert generated_text == expect_list[i]
+
+        # unset env
+        env_manager.unset_all()
\ No newline at end of file
diff --git a/tests/st/python/test_vllm_mf_qwen_7b.py b/tests/st/python/test_vllm_mf_qwen_7b.py
new file mode 100644
index 0000000..bdec7cf
--- /dev/null
+++ b/tests/st/python/test_vllm_mf_qwen_7b.py
@@ -0,0 +1,77 @@
+# Copyright 2024 The vLLM team.
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf qwen."""
+import os
+import pytest
+from . import set_env
+env_manager = set_env.EnvVarManager()
+# define env vars
+env_vars = {
+    "MINDFORMERS_MODEL_CONFIG": "./config/predict_qwen2_5_7b_instruct.yaml",
+    "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+    "vLLM_MODEL_BACKEND": "MindFormers",
+    "vLLM_MODEL_MEMORY_USE_GB": "20",
+    "ASCEND_TOTAL_MEMORY_GB": "29",
+    "MS_ENABLE_LCCL": "off",
+    "HCCL_OP_EXPANSION_MODE": "AIV",
+    "ASCEND_RT_VISIBLE_DEVICES": "0,1",
+    "MS_ALLOC_CONF": "enable_vmm:True",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+
+
+class TestMfQwen:
+    """
+    Test Qwen.
+    """
+    @pytest.mark.level0
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_single
+    def test_mf_qwen(self):
+        """
+        test case qwen2.5 7B
+        """
+
+        # Sample prompts.
+        prompts = [
+            "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n",
+        ]
+
+        # Create a sampling params object.
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+        # Create an LLM.
+        llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", gpu_memory_utilization=0.9, tensor_parallel_size=2)
+        # Generate texts from the prompts. The output is a list of RequestOutput objects
+        # that contain the prompt, generated text, and other information.
+        outputs = llm.generate(prompts, sampling_params)
+        expect_list = ['中性<|Assistant|> 这句话']
+        # Print the outputs.
+        for i, output in enumerate(outputs):
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+            assert generated_text == expect_list[i]
+
+        # unset env
+        env_manager.unset_all()
\ No newline at end of file
diff --git a/tests/st/python/test_vllm_mf_qwen_7b_mss.py b/tests/st/python/test_vllm_mf_qwen_7b_mss.py
new file mode 100644
index 0000000..1f7796c
--- /dev/null
+++ b/tests/st/python/test_vllm_mf_qwen_7b_mss.py
@@ -0,0 +1,77 @@
+# Copyright 2024 The vLLM team.
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf qwen mss."""
+import os
+import pytest
+from . import set_env
+env_manager = set_env.EnvVarManager()
+# define env vars
+env_vars = {
+    "MINDFORMERS_MODEL_CONFIG": "./config/predict_qwen2_5_7b_instruct.yaml",
+    "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+    "vLLM_MODEL_BACKEND": "MindFormers",
+    "vLLM_MODEL_MEMORY_USE_GB": "20",
+    "ASCEND_TOTAL_MEMORY_GB": "29",
+    "MS_ENABLE_LCCL": "off",
+    "HCCL_OP_EXPANSION_MODE": "AIV",
+    "ASCEND_RT_VISIBLE_DEVICES": "0,1",
+    "MS_ALLOC_CONF": "enable_vmm:True",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+
+class TestMfQwen_mss:
+    """
+    Test Qwen with multi-step scheduling (mss).
+    """
+    @pytest.mark.level0
+    @pytest.mark.platform_arm_ascend910b_training
+    @pytest.mark.env_single
+    def test_mf_qwen_7b_mss(self):
+        """
+        test case qwen_7b_mss
+        """
+
+        # Sample prompts.
+        prompts = [
+            "I love Beijing, because",
+        ]
+
+        # Create a sampling params object.
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+        # Create an LLM.
+        llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", max_model_len=8192, max_num_batched_tokens=8192,
+                  block_size=32, gpu_memory_utilization=0.9, num_scheduler_steps=8, tensor_parallel_size=2)
+        # Generate texts from the prompts. The output is a list of RequestOutput objects
+        # that contain the prompt, generated text, and other information.
+        outputs = llm.generate(prompts, sampling_params)
+        expect_list = [' it is a city with a long history. Which']
+        # Print the outputs.
+        for i, output in enumerate(outputs):
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+            assert generated_text == expect_list[i]
+
+        # unset env
+        env_manager.unset_all()
\ No newline at end of file
diff --git a/tests/st/python/test_demo.py b/tests/st/python/test_vllm_qwen_7b.py
similarity index 58%
rename from tests/st/python/test_demo.py
rename to tests/st/python/test_vllm_qwen_7b.py
index d6e1fd0..5a9813e 100644
--- a/tests/st/python/test_demo.py
+++ b/tests/st/python/test_vllm_qwen_7b.py
@@ -13,44 +13,56 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""test demo for st."""
+"""test qwen."""
 import pytest
+from . import set_env
+env_manager = set_env.EnvVarManager()
+env_vars = {
+    "ASCEND_TOTAL_MEMORY_GB": "29",
+    "vLLM_MODEL_MEMORY_USE_GB": "20",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from vllm import LLM, SamplingParams
 
-
-class TestDemo:
+class TestQwen:
     """
-    Test Demo for ST.
+    Test Qwen.
     """
 
     @pytest.mark.level0
     @pytest.mark.platform_arm_ascend910b_training
     @pytest.mark.env_single
-    def test_aaa(self):
+    def test_qwen(self):
         """
-        test case aaa
+        test case qwen2.5 7B
         """
-        # pylint: disable=W0611
-        import vllm_mindspore
-        from vllm import LLM, SamplingParams
 
         # Sample prompts.
         prompts = [
-            "I am",
-            "Today is",
-            "Llama is"
+            "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n",
         ]
 
         # Create a sampling params object.
-        sampling_params = SamplingParams(temperature=0.0, top_p=0.95)
+        sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
 
         # Create an LLM.
-        llm = LLM(model="/home/workspace/mindspore_dataset/weight/Llama-2-7b-hf")
+        llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", gpu_memory_utilization=0.9, max_model_len=200, tensor_parallel_size=2)
         # Generate texts from the prompts. The output is a list of RequestOutput objects
         # that contain the prompt, generated text, and other information.
         outputs = llm.generate(prompts, sampling_params)
+        expect_list = ['中性<|Assistant|> 这句话']
         # Print the outputs.
-        for output in outputs:
+        for i, output in enumerate(outputs):
             prompt = output.prompt
             generated_text = output.outputs[0].text
             print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        assert len(outputs) == 3
+            assert generated_text == expect_list[i]
+
+        # unset env
+        env_manager.unset_all()
\ No newline at end of file
-- 
Gitee