diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml index eb243d6fa130061091ab60bb434539a32f136fba..efb6f8e251de8455cdbe7ab5c227eca38c61712d 100644 --- a/.jenkins/test/config/dependent_packages.yaml +++ b/.jenkins/test/config/dependent_packages.yaml @@ -1,11 +1,11 @@ mindspore: - 'https://repo.mindspore.cn/mindspore/mindspore/version/202504/20250403/br_infer_deepseek_os_20250403204446_a10f9cf58ea06de7cf6acbec0bde94442992955b_newest/' + 'https://repo.mindspore.cn/mindspore/mindspore/version/202504/20250417/br_infer_deepseek_os_20250417004508_38b6db6c3039b59153d52d5e353cd01fe774dc93_newest/' mindspore_gs: 'https://repo.mindspore.cn/mindspore/golden-stick/version/202503/20250322/master_20250322160019_1aa0a919d27c806700b2399bf965c5f6663c10fd_newest/' msadapter: - 'https://repo.mindspore.cn/mindspore/msadapter/version/202504/20250403/master_20250403171706_61451a9e1a5909cfa7877f72b1286bc0a843a067_newest/' + 'https://repo.mindspore.cn/mindspore/msadapter/version/202504/20250410/master_20250410120007_83e7214eb2b9598179135a4e98dce3b69ba27da2_newest/' vllm: 'https://repo.mindspore.cn/mirrors/vllm/version/202503/20250321/v0.7.3_20250321112504_ed6e9075d31e32c8548b480a47d1ffb77da1f54c_newest/' diff --git a/setup.py b/setup.py index 647dfa531f6bca4654ffc320e15fc837a7873e0a..8e2154b3671e0e6956e5cb90feb3e6d19f3216a1 100644 --- a/setup.py +++ b/setup.py @@ -26,9 +26,9 @@ from typing import List from pathlib import Path from setuptools import find_packages, setup from setuptools.command.build_ext import build_ext -from setuptools.command.install import install from setuptools import Extension import subprocess +import warnings def load_module_from_path(module_name, path): @@ -88,51 +88,63 @@ def get_requirements() -> List[str]: return requirements +def write_commit_id(): + ret_code = os.system("git rev-parse --abbrev-ref HEAD > ./vllm_mindspore/.commit_id " + "&& git log --abbrev-commit -1 >> ./vllm_mindspore/.commit_id") + if ret_code != 0: + sys.stdout.write("Warning: Can not get commit id information. Please make sure git is available.") + os.system("echo 'git is not available while building.' 
> ./vllm_mindspore/.commit_id") + + version = (Path("vllm_mindspore") / "version.txt").read_text() def _get_ascend_home_path(): return os.environ.get("ASCEND_HOME_PATH", "/usr/local/Ascend/ascend-toolkit/latest") +def _get_ascend_env_path(check_exists=True): + env_script_path = os.path.join(_get_ascend_home_path(), "bin", "setenv.bash") + if check_exists and not os.path.exists(env_script_path): + warnings.warn(f"The file '{env_script_path}' is not found, " + "please make sure env variable 'ASCEND_HOME_PATH' is set correctly.") + return None + return env_script_path + class CustomBuildExt(build_ext): ROOT_DIR = os.path.abspath(os.path.dirname(__file__)) - ASCENDC_OPS_DIR = os.path.join(ROOT_DIR, "vllm_mindspore", "ops", "ascendc") def build_extension(self, ext): - if ext.name == "ascendc_kernels_npu": - self.build_ascendc_kernels() - elif ext.name == "npu_ops": + if ext.name == "vllm_mindspore.npu_ops": self.build_npu_ops(ext) else: raise ValueError(f"Unknown extension name: {ext.name}") - def build_ascendc_kernels(self): - kernel_so_name = "libascendc_kernels_npu.so" - print(f"Building {kernel_so_name}...") - tmp_build_dir = os.path.join(self.ASCENDC_OPS_DIR, "build") - if os.path.exists(tmp_build_dir): - print(f"Removing existing build directory: {tmp_build_dir}") - shutil.rmtree(tmp_build_dir) - os.makedirs(tmp_build_dir, exist_ok=True) + def build_npu_ops(self, ext): + # "vllm_mindspore.npu_ops" --> "npu_ops" + ext_name = ext.name.split('.')[-1] + so_name = ext_name + ".so" + print(f"Building {so_name} ...") + OPS_DIR = os.path.join(ROOT_DIR, "vllm_mindspore", "ops") + BUILD_OPS_DIR = os.path.join(ROOT_DIR, "build", "ops") + os.makedirs(BUILD_OPS_DIR, exist_ok=True) ascend_home_path = _get_ascend_home_path() - env_script_path = os.path.join(ascend_home_path, "bin", "setenv.bash") - if not os.path.exists(env_script_path): - raise RuntimeError(f"The file '{env_script_path}' is not found, " - "please make sure env variable 'ASCEND_HOME_PATH' is set correctly.") + env_script_path = _get_ascend_env_path(False) + build_extension_dir = os.path.join(BUILD_OPS_DIR, "kernel_meta", ext_name) # Combine all cmake commands into one string cmake_cmd = ( f"source {env_script_path} && " - f"cmake -S {self.ASCENDC_OPS_DIR} -B {tmp_build_dir} " - f"-DRUN_MODE=npu -DCMAKE_BUILD_TYPE=Debug " - f"-DCMAKE_INSTALL_PREFIX={os.path.join(tmp_build_dir, 'install')} " - f"-DASCEND_CANN_PACKAGE_PATH={ascend_home_path} && " - f"cmake --build {tmp_build_dir} -j --verbose && " - f"cmake --install {tmp_build_dir}" + f"cmake -S {OPS_DIR} -B {BUILD_OPS_DIR}" + f" -DCMAKE_BUILD_TYPE=Release" + f" -DCMAKE_INSTALL_PREFIX={os.path.join(BUILD_OPS_DIR, 'install')}" + f" -DBUILD_EXTENSION_DIR={build_extension_dir}" + f" -DMS_EXTENSION_NAME={ext_name}" + f" -DASCEND_CANN_PACKAGE_PATH={ascend_home_path} && " + f"cmake --build {BUILD_OPS_DIR} -j --verbose" ) try: # Run the combined cmake command - print("Running combined CMake commands:") + print(f"Running combined CMake commands:\n{cmake_cmd}") result = subprocess.run(cmake_cmd, cwd=self.ROOT_DIR, text=True, shell=True, capture_output=True) if result.returncode != 0: print("CMake commands failed:") @@ -140,54 +152,25 @@ class CustomBuildExt(build_ext): print(result.stderr) # Print error output raise RuntimeError(f"Combined CMake commands failed with exit code {result.returncode}") except subprocess.CalledProcessError as e: - raise RuntimeError(f"Failed to build {kernel_so_name}: {e}") + raise RuntimeError(f"Failed to build {so_name}: {e}") - # Move the generated .so file to the 
target directory - src_so_path = os.path.join(tmp_build_dir, "lib", kernel_so_name) - lib_dir = os.path.join(self.ROOT_DIR, self.build_lib, "vllm_mindspore", "lib") - dst_so_path = os.path.join(lib_dir, kernel_so_name) - os.makedirs(lib_dir, exist_ok=True) + # Copy the generated .so file to the target directory + src_so_path = os.path.join(build_extension_dir, so_name) + dst_so_path = self.get_ext_fullpath(ext.name) + os.makedirs(os.path.dirname(dst_so_path), exist_ok=True) if os.path.exists(dst_so_path): os.remove(dst_so_path) - shutil.move(src_so_path, dst_so_path) - print(f"Moved {kernel_so_name} to {lib_dir}.") - # Remove the build directory after building kernels.so - shutil.rmtree(tmp_build_dir) + shutil.copy(src_so_path, dst_so_path) + print(f"Copied {so_name} to {dst_so_path}") - def build_npu_ops(self, ext): - print("Building npu_ops.so ...") - try: - import mindspore as ms - except ImportError: - print("Mindspore is not found, skip building npu_ops.so") - return - try: - src = [os.path.join(self.ASCENDC_OPS_DIR, s) for s in ext.sources] - build_lib_dir = os.path.join(self.ROOT_DIR, self.build_lib, "vllm_mindspore") - ms.ops.CustomOpBuilder( - "npu_ops", - src, - backend="Ascend", - cflags=f"-I{self.ASCENDC_OPS_DIR}", - ldflags=f"-L{os.path.join(build_lib_dir, 'lib')} -lascendc_kernels_npu -Wl,-rpath,'$$ORIGIN/lib'" - ).load() - except ImportError: - pass - # Move the generated .so file to the target directory - kernel_meta_dir = os.path.join(self.ROOT_DIR, "kernel_meta") - src_so_path = os.path.join(kernel_meta_dir, "npu_ops", "npu_ops.so") - dst_so_path = os.path.join(build_lib_dir, "npu_ops.so") - os.makedirs(build_lib_dir, exist_ok=True) - if os.path.exists(dst_so_path): - os.remove(dst_so_path) - shutil.move(src_so_path, build_lib_dir) - print(f"Moved npu_ops.so to {build_lib_dir}.") - shutil.rmtree(kernel_meta_dir) + +write_commit_id() package_data = { "": [ "*.so", "lib/*.so", + ".commit_id" ] } @@ -197,11 +180,8 @@ def _get_ext_modules(): # As a temporary solution, this is controlled via an environment variable. # Once the CI environment adds support for custom operator compilation, # this should be updated to enable compilation by default. 
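The reworked CustomBuildExt above delegates compilation to CMake and then copies the produced shared object to the location setuptools expects via get_ext_fullpath(). A minimal, self-contained sketch of that pattern follows; the class name, directory layout, and CMake options are illustrative assumptions rather than this repository's exact build:

import os
import shutil
import subprocess
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext


class CMakeBuildExt(build_ext):
    # Build an extension with CMake and place the .so where setuptools expects it.
    def build_extension(self, ext):
        # Out-of-tree build directory; the CMake project is assumed to emit
        # "<ext_name>.so" here (hypothetical layout).
        build_dir = os.path.abspath(os.path.join("build", "ops"))
        os.makedirs(build_dir, exist_ok=True)
        subprocess.run(["cmake", "-S", "ops", "-B", build_dir,
                        "-DCMAKE_BUILD_TYPE=Release"], check=True)
        subprocess.run(["cmake", "--build", build_dir, "-j"], check=True)
        ext_name = ext.name.split(".")[-1]        # "pkg.npu_ops" -> "npu_ops"
        src = os.path.join(build_dir, ext_name + ".so")
        dst = self.get_ext_fullpath(ext.name)      # final path inside build_lib
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.copy(src, dst)


# Usage sketch:
# setup(ext_modules=[Extension("vllm_mindspore.npu_ops", sources=[])],
#       cmdclass={"build_ext": CMakeBuildExt})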
- if os.getenv("vLLM_USE_NPU_ADV_STEP_FLASH_OP", "off") == "on": - ext_modules.append(Extension("ascendc_kernels_npu", sources=[])) - ext_modules.append(Extension("npu_ops", sources=[ - "adv_step_flash_adapter.cpp" - ])) + if os.getenv("vLLM_USE_NPU_ADV_STEP_FLASH_OP", "off") == "on" and _get_ascend_env_path() is not None: + ext_modules.append(Extension("vllm_mindspore.npu_ops", sources=[])) # sources are specified in CMakeLists.txt return ext_modules setup( diff --git a/tests/mindformers b/tests/mindformers index ed67bae4e88fa4d01c91cfbe4dfd822165c75d2f..544c4009573051e0e254efab71d212bfc77fc7b2 160000 --- a/tests/mindformers +++ b/tests/mindformers @@ -1 +1 @@ -Subproject commit ed67bae4e88fa4d01c91cfbe4dfd822165c75d2f +Subproject commit 544c4009573051e0e254efab71d212bfc77fc7b2 diff --git a/tests/st/python/config/predict_deepseek_r1_671b.yaml b/tests/st/python/config/predict_deepseek_r1_671b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..112375eff5105d85e32c325b0b3f3d987591a1bd --- /dev/null +++ b/tests/st/python/config/predict_deepseek_r1_671b.yaml @@ -0,0 +1,121 @@ +seed: 0 +output_dir: './output' # path to save checkpoint/strategy +run_mode: 'predict' +use_parallel: True + +load_checkpoint: "/path/to/deepseekr1/model_ckpt" +load_ckpt_format: "safetensors" +auto_trans_ckpt: True # If true, auto transform load_checkpoint to load in distributed model + +# trainer config +trainer: + type: CausalLanguageModelingTrainer + model_name: 'DeepSeekR1' + +# default parallel of device num = 32 for Atlas 800T A2 +parallel_config: + model_parallel: 4 + pipeline_stage: 1 + expert_parallel: 1 + vocab_emb_dp: False + +# mindspore context init config +context: + mode: 0 # 0--Graph Mode; 1--Pynative Mode + max_device_memory: "61GB" + device_id: 0 + affinity_cpu_list: None + +kernel_launch_group: + thread_num: 4 + kernel_group_num: 16 + +# parallel context config +parallel: + parallel_mode: "STAND_ALONE" # use 'STAND_ALONE' mode for inference with parallelism in frontend + full_batch: False + strategy_ckpt_save_file: "./ckpt_strategy.ckpt" + +# model config +model: + model_config: + type: DeepseekV3Config + auto_register: deepseek3_config.DeepseekV3Config + batch_size: 1 # add for incre predict + seq_length: 4096 + hidden_size: 7168 + num_layers: 4 + num_heads: 128 + max_position_embeddings: 163840 + intermediate_size: 18432 + kv_lora_rank: 512 + q_lora_rank: 1536 + qk_rope_head_dim: 64 + v_head_dim: 128 + qk_nope_head_dim: 128 + vocab_size: 129280 + multiple_of: 256 + rms_norm_eps: 1.0e-6 + bos_token_id: 0 + eos_token_id: 1 + pad_token_id: 1 + ignore_token_id: -100 + compute_dtype: "bfloat16" + layernorm_compute_type: "bfloat16" + softmax_compute_type: "bfloat16" + rotary_dtype: "bfloat16" + router_dense_type: "bfloat16" + param_init_type: "bfloat16" + scaling_factor: + beta_fast: 32.0 + beta_slow: 1.0 + factor: 40.0 + mscale: 1.0 + mscale_all_dim: 1.0 + original_max_position_embeddings: 4096 + use_past: True + extend_method: "YARN" + use_flash_attention: True + block_size: 16 + num_blocks: 512 + offset: 0 + checkpoint_name_or_path: "" + repetition_penalty: 1 + max_decode_length: 1024 + top_k: 1 + top_p: 1 + theta: 10000.0 + do_sample: False + is_dynamic: True + qkv_concat: False + ffn_concat: True + auto_map: + AutoConfig: deepseek3_config.DeepseekV3Config + AutoModel: deepseek3.DeepseekV3ForCausalLM + arch: + type: DeepseekV3ForCausalLM + auto_register: deepseek3.DeepseekV3ForCausalLM + +moe_config: + expert_num: 256 + num_experts_chosen: 8 + routing_policy: "TopkRouterV2" + 
shared_expert_num: 1 + routed_scaling_factor: 2.5 + first_k_dense_replace: 3 + moe_intermediate_size: 2048 + topk_group: 4 + n_group: 8 + +processor: + return_tensors: ms + tokenizer: + unk_token: '' + bos_token: '<|begin▁of▁sentence|>' + eos_token: '<|end▁of▁sentence|>' + pad_token: '<|end▁of▁sentence|>' + type: LlamaTokenizerFast + vocab_file: '/path/to/deepseekr1/tokenizer.json' + tokenizer_file: '/path/to/deepseekr1/tokenizer.json' + chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{{'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}" + type: LlamaProcessor diff --git a/tests/st/python/test_shm_broadcast.py b/tests/st/python/test_shm_broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc328810fdf318feadb30cffa735e8be105892f --- /dev/null +++ b/tests/st/python/test_shm_broadcast.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test cpu communicator and share memory""" +import pytest +import multiprocessing +import random +import time +from typing import List + +import numpy as np +import torch.distributed as dist + +import vllm_mindspore + +from vllm.distributed.device_communicators.shm_broadcast import MessageQueue +from vllm.distributed.utils import StatelessProcessGroup +from vllm.utils import get_ip, get_open_port, update_environment_variables, get_distributed_init_method + + +def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]: + np.random.seed(seed) + sizes = np.random.randint(1, 10_000, n) + # on average, each array will have 5k elements + # with int64, each array will have 40kb + return [np.random.randint(1, 100, i) for i in sizes] + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes = [] + + port = get_open_port() + distributed_init_method = get_distributed_init_method("127.0.0.1", port) + + for i in range(number_of_processes): + p = multiprocessing.Process(target=fn, args=(distributed_init_method, i, world_size)) + processes.append(p) + p.start() + + for p in processes: + p.join() + + for p in processes: + assert p.exitcode == 0 + + +def worker_fn_wrapper(fn): + # `multiprocessing.Process` cannot accept environment variables directly + # so we need to pass the environment variables as arguments + # and update the environment variables in the function + def wrapped_fn(distributed_init_method, rank, world_size): + dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + fn() + + return wrapped_fn + + +@worker_fn_wrapper +def worker_fn(): + + rank = dist.get_rank() + if rank == 0: + port = get_open_port() + ip = get_ip() + dist.broadcast_object_list([ip, port], src=0) + else: + recv = [None, None] + dist.broadcast_object_list(recv, src=0) + ip, port = recv + + stateless_pg = dist.new_group([0,1,2,3], backend="gloo") + + for pg in [dist.group.WORLD, stateless_pg]: + + writer_rank = 2 + broadcaster = MessageQueue.create_from_process_group( + pg, 40 * 1024, 2, writer_rank) + if rank == writer_rank: + seed = random.randint(0, 1000) + dist.broadcast_object_list([seed], writer_rank) + else: + recv = [None] + dist.broadcast_object_list(recv, writer_rank) + seed = recv[0] # type: ignore + + if pg == dist.group.WORLD: + dist.barrier() + else: + dist.barrier(group=pg) + + # in case we find a race condition + # print the seed so that we can reproduce the error + print(f"Rank {rank} got seed {seed}") + # test broadcasting with about 400MB of data + N = 10_000 + if rank == writer_rank: + arrs = get_arrays(N, seed) + for x in arrs: + broadcaster.broadcast_object(x) + time.sleep(random.random() / 1000) + else: + arrs = get_arrays(N, seed) + for x in arrs: + y = broadcaster.broadcast_object(None) + assert np.array_equal(x, y) + time.sleep(random.random() / 1000) + + if pg == dist.group.WORLD: + dist.barrier() + print("torch distributed passed the test!") + else: + dist.barrier(group=pg) + print("StatelessProcessGroup passed the test!") + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_single +def test_shm_broadcast(): + distributed_run(worker_fn, 4) diff --git a/tests/st/python/test_vllm_deepseek_bf16_part.py b/tests/st/python/test_vllm_deepseek_bf16_part.py new file mode 100644 index 
0000000000000000000000000000000000000000..c19dd14a66e82fa30ea302723c12497d0b191652 --- /dev/null +++ b/tests/st/python/test_vllm_deepseek_bf16_part.py @@ -0,0 +1,76 @@ +# Copyright 2024 The vLLM team. +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://wwww.apache.org/licenses/LICENSE-2.0 +# +# Unless required by application law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test mf deepseek r1.""" +import pytest +import os +from . import set_env +env_manager = set_env.EnvVarManager() +# def env +env_vars = { + "MINDFORMERS_MODEL_CONFIG": "./config/predict_deepseek_r1_671b.yaml", + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "vLLM_MODEL_BACKEND": "MindFormers", + "MS_ENABLE_LCCL": "off", + "HCCL_OP_EXPANSION_MODE": "AIV", + "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0" +} +# set env +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore +from vllm import LLM, SamplingParams + +class TestDeepSeek: + """ + Test Deepseek. + """ + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_single + def test_deepseek_r1_bf16(self): + """ + test case deepseek r1 bf16 + """ + + # Sample prompts. + prompts = [ + "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n", + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1) + + # Create an LLM. + llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-bf16", + trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=8) + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + except_list=['ugs611ాలు sic辨hara的开璞 SquaresInsp'] + # Print the outputs. + for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert generated_text == except_list[i] + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/test_vllm_deepseek_part.py b/tests/st/python/test_vllm_deepseek_part.py index a0caa3161c6b7be11faa977bd5737479cc32d030..8dfa95635dbffb0e584397cfbea33491ec2862c3 100644 --- a/tests/st/python/test_vllm_deepseek_part.py +++ b/tests/st/python/test_vllm_deepseek_part.py @@ -1,13 +1,15 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. -# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://wwww.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by application law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and @@ -23,8 +25,6 @@ env_vars = { "MINDFORMERS_MODEL_CONFIG": "./config/predict_deepseek_r1_671b_w8a8.yaml", "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), "vLLM_MODEL_BACKEND": "MindFormers", - "vLLM_MODEL_MEMORY_USE_GB": "40", - "ASCEND_TOTAL_MEMORY_GB": "60", "MS_ENABLE_LCCL": "off", "HCCL_OP_EXPANSION_MODE": "AIV", "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7", @@ -66,13 +66,54 @@ class TestDeepSeek: # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) - except_list=['ugs611ాలు哒ాలు mahassisemaSTE的道德'] + except_list=['ugs611ాలు哒ాలు mahassisemaSTE的道德', 'ugs611ాలు哒ాలు mah战区rollerOVERlaid'] # Print the outputs. for i, output in enumerate(outputs): prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - assert generated_text == except_list[i] + assert generated_text in except_list + + # unset env + env_manager.unset_all() + + +class TestDeepSeekMTP: + """ + Test DeepseekMTP. + 大模型用量化(4层),mtp模型用浮点(1层,layer 61)。 + mtp的权重加载默认从配置的num_hidden_layer开始,为了支持减层推理场景mtp权重加载,ci服务器上修改了浮点的权重map文件的layer为4。 + """ + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_single + def test_deepseek_mtp(self): + """ + test case deepseek mtp with main model of r1-w8a8 + """ + + # Sample prompts. + prompts = [ + "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n", + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1) + + # Create an LLM. + llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-MTP", + trust_remote_code=True, gpu_memory_utilization=0.8, tensor_parallel_size=8, + num_speculative_tokens=1) + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + except_list = ['ugs611ాలు哒ాలు mahassisemaSTE的道德', 'ugs611ాలు哒ాలు mah战区rollerOVERlaid'] + # Print the outputs. + for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert generated_text in except_list # unset env env_manager.unset_all() diff --git a/tests/st/python/test_vllm_mf_qwen_7b.py b/tests/st/python/test_vllm_mf_qwen_7b.py index e8c71690f7529b70f34f5f0d974d874bcd7ecdfe..ddb545c78a846482cfba02035520767401e69004 100644 --- a/tests/st/python/test_vllm_mf_qwen_7b.py +++ b/tests/st/python/test_vllm_mf_qwen_7b.py @@ -1,13 +1,15 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. -# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 
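The ST tests in this patch configure the process environment through set_env.EnvVarManager before importing vllm_mindspore and restore it with unset_all() at the end of each case. That helper is not part of this diff; the sketch below is only a plausible minimal stand-in (an assumption, not the project's actual implementation) to make the setup/teardown contract concrete:

import os


class EnvVarManager:
    # Set environment variables for a test and restore the previous values afterwards.
    def __init__(self):
        self._saved = {}

    def setup_ai_environment(self, env_vars):
        for key, value in env_vars.items():
            self._saved[key] = os.environ.get(key)   # remember old value (or None)
            os.environ[key] = value

    def unset_all(self):
        for key, old in self._saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old
        self._saved.clear()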
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://wwww.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by application law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and @@ -23,8 +25,6 @@ env_vars = { "MINDFORMERS_MODEL_CONFIG": "./config/predict_qwen2_5_7b_instruct.yaml", "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), "vLLM_MODEL_BACKEND": "MindFormers", - "vLLM_MODEL_MEMORY_USE_GB": "50", - "ASCEND_TOTAL_MEMORY_GB": "64", "MS_ENABLE_LCCL": "off", "HCCL_OP_EXPANSION_MODE": "AIV", "ASCEND_RT_VISIBLE_DEVICES": "0,1", diff --git a/tests/st/python/test_vllm_mf_qwen_7b_chunk_prefill.py b/tests/st/python/test_vllm_mf_qwen_7b_chunk_prefill.py new file mode 100644 index 0000000000000000000000000000000000000000..1523e46bb119ba665266d87b978d2c9b780ee4db --- /dev/null +++ b/tests/st/python/test_vllm_mf_qwen_7b_chunk_prefill.py @@ -0,0 +1,89 @@ +# Copyright 2024 The vLLM team. +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://wwww.apache.org/licenses/LICENSE-2.0 +# +# Unless required by application law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test mf qwen chunk prefill.""" +import pytest +import os +from . import set_env + +env_manager = set_env.EnvVarManager() +# def env +env_vars = { + "MINDFORMERS_MODEL_CONFIG": "./config/predict_qwen2_5_7b_instruct.yaml", + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "vLLM_MODEL_BACKEND": "MindFormers", + "MS_ENABLE_LCCL": "off", + "HCCL_OP_EXPANSION_MODE": "AIV", + "ASCEND_RT_VISIBLE_DEVICES": "0,1", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0" +} +# set env +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore +from vllm import LLM, SamplingParams + + +class TestMfQwen_chunk_prefill: + """ + Test qwen. + """ + + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_single + def test_mf_qwen_7b_chunk_prefill(self): + """ + test case qwen_7b_chunk_prefill + """ + + # Sample prompts. + batch_datas = [{ + "prompt": "I love Beijing, because it is a city with a long history and profound cultural heritage. Walking through " + "its ancient hutongs, one can almost feel the whispers of the past. The Forbidden City, an architectural " + "marvel that once housed emperors, stands as a testament to the city's imperial past. 
Meanwhile, the Great " + "Wall, though not within the city limits, is easily accessible from Beijing and offers a glimpse into the " + "strategic genius and resilience of ancient China.", + "answer": " The city's blend of traditional and modern architecture, vibrant street life, and rich culinary scene " + "make it a truly unique and captivating destination. I am always eager to"}, + {"prompt": "I love Beijing, because", + "answer": " it is a city with a long history. Which of the following options correctly expresses this sentence?\nA. I love Beijing, because it is a city with a"}, + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=32, top_k=1) + + # Create an LLM. + llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", + max_model_len=8192, max_num_seqs=16, max_num_batched_tokens=32, + block_size=32, gpu_memory_utilization=0.9, tensor_parallel_size=2, + enable_chunked_prefill=True) + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + for batch_data in batch_datas: + prompt = batch_data["prompt"] + answer = batch_data["answer"] + outputs = llm.generate(prompt, sampling_params) + # Print the outputs. + for i, output in enumerate(outputs): + generated_text = output.outputs[0].text + print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") + assert generated_text == answer + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/test_vllm_mf_qwen_7b_cp_pc_mss.py b/tests/st/python/test_vllm_mf_qwen_7b_cp_pc_mss.py new file mode 100644 index 0000000000000000000000000000000000000000..6292b22c6020777ed6c3ee752834b140e2ab13fc --- /dev/null +++ b/tests/st/python/test_vllm_mf_qwen_7b_cp_pc_mss.py @@ -0,0 +1,86 @@ +# Copyright 2024 The vLLM team. +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://wwww.apache.org/licenses/LICENSE-2.0 +# +# Unless required by application law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test mf qwen chunk prefill, prefix cache, mss.""" +import pytest +import os +from . import set_env +env_manager = set_env.EnvVarManager() +# def env +env_vars = { + "MINDFORMERS_MODEL_CONFIG": "./config/predict_qwen2_5_7b_instruct.yaml", + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "vLLM_MODEL_BACKEND": "MindFormers", + "MS_ENABLE_LCCL": "off", + "HCCL_OP_EXPANSION_MODE": "AIV", + "ASCEND_RT_VISIBLE_DEVICES": "0,1", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0" +} +# set env +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore +from vllm import LLM, SamplingParams + +class TestMfQwen_cp_pc_mss: + """ + Test qwen. 
+ """ + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_single + def test_mf_qwen_7b_cp_pc_mss(self): + """ + test case mf_qwen_7b_cp_pc_mss + """ + + # Sample prompts. + batch_datas = [{ + "prompt": "I love Beijing, because it is a city with a long history and profound cultural heritage. Walking through " + "its ancient hutongs, one can almost feel the whispers of the past. The Forbidden City, an architectural " + "marvel that once housed emperors, stands as a testament to the city's imperial past. Meanwhile, the Great " + "Wall, though not within the city limits, is easily accessible from Beijing and offers a glimpse into the " + "strategic genius and resilience of ancient China.", + "answer": ""}, + {"prompt": "I love Beijing, because", + "answer": " it is a city with a long history. Which of the following options correctly expresses this sentence?\nA. I love Beijing, because it is a city with a"}, + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=32, top_k=1) + + # Create an LLM. + llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", + max_model_len=8192, max_num_seqs=16, max_num_batched_tokens=32, + block_size=32, gpu_memory_utilization=0.9, tensor_parallel_size=2, + enable_chunked_prefill=True, enable_prefix_caching=True, num_scheduler_steps=8) + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + for _ in range(3): + for batch_data in batch_datas: + prompt = batch_data["prompt"] + answer = batch_data["answer"] + outputs = llm.generate(prompt, sampling_params) + # Print the outputs. + for i, output in enumerate(outputs): + generated_text = output.outputs[0].text + print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}") + assert generated_text == answer + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/test_vllm_mf_qwen_7b_mss.py b/tests/st/python/test_vllm_mf_qwen_7b_mss.py index 7983d7a88154e6f9526534f987a115c449c478cb..b174804dd50d5dc5e3f090286b29b93164e1515e 100644 --- a/tests/st/python/test_vllm_mf_qwen_7b_mss.py +++ b/tests/st/python/test_vllm_mf_qwen_7b_mss.py @@ -1,13 +1,15 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. -# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://wwww.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by application law or agreed to in writing, software +# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and @@ -23,8 +25,6 @@ env_vars = { "MINDFORMERS_MODEL_CONFIG": "./config/predict_qwen2_5_7b_instruct.yaml", "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), "vLLM_MODEL_BACKEND": "MindFormers", - "vLLM_MODEL_MEMORY_USE_GB": "20", - "ASCEND_TOTAL_MEMORY_GB": "29", "MS_ENABLE_LCCL": "off", "HCCL_OP_EXPANSION_MODE": "AIV", "ASCEND_RT_VISIBLE_DEVICES": "0,1", diff --git a/tests/st/python/test_vllm_mf_qwen_7b_prefix_caching.py b/tests/st/python/test_vllm_mf_qwen_7b_prefix_caching.py new file mode 100644 index 0000000000000000000000000000000000000000..89ba64c0e5032cc64cffaaa468b26823aeac185c --- /dev/null +++ b/tests/st/python/test_vllm_mf_qwen_7b_prefix_caching.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""test mf qwen prefix caching.""" +import pytest +import os +from . import set_env +env_manager = set_env.EnvVarManager() +env_vars = { + "MINDFORMERS_MODEL_CONFIG": "./config/predict_qwen2_5_7b_instruct.yaml", + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "vLLM_MODEL_BACKEND": "MindFormers", + "MS_ENABLE_LCCL": "off", + "HCCL_OP_EXPANSION_MODE": "AIV", + "ASCEND_RT_VISIBLE_DEVICES": "0,1", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0" +} +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore +from vllm import LLM, SamplingParams + + +class TestMfQwen_prefix_caching: + """ + Test qwen7b enable prefix_caching + """ + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_single + def test_mf_qwen_7b_prefix_caching(self): + """ + test case qwen_7b_prefix_caching + """ + + # First prompts. + prompts = [ + "I love Beijing, because it is a city that has so much to offer. I have visited" + ] + #second prompts, the second prompt is a continuation of the first prompts, make sure prefix caching work. + second_prompts = [ + "I love Beijing, because it is a city that has so much to offer. I have visited many places" + ] + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1) + + # Create an LLM. + llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", + max_model_len=8192, block_size=16, enable_prefix_caching=True, + gpu_memory_utilization=0.9, tensor_parallel_size=2) + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. 
+ outputs = llm.generate(prompts, sampling_params) + second_outputs = llm.generate(second_prompts, sampling_params) + except_list=[' many times and each time I have found something new'] + second_except_list=[' to visit, such as the Forbidden City, the'] + for i, (output, second_output) in enumerate(zip(outputs, second_outputs)): + generated_text = output.outputs[i].text + print(f"Output1 - Prompt: {prompts[i]!r}, Generated text: {generated_text!r}") + assert generated_text == except_list[i] + + second_generated_text = second_output.outputs[i].text + print(f"Output2 - Prompt: {second_prompts[i]!r}, Generated text: {second_generated_text!r}") + assert second_generated_text == second_except_list[i] + + env_manager.unset_all() diff --git a/tests/st/python/test_vllm_qwen_7b.py b/tests/st/python/test_vllm_qwen_7b.py new file mode 100644 index 0000000000000000000000000000000000000000..bce75d3e11bc24c73cecc45a90f84954d9b800e0 --- /dev/null +++ b/tests/st/python/test_vllm_qwen_7b.py @@ -0,0 +1,74 @@ +# Copyright 2024 The vLLM team. +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://wwww.apache.org/licenses/LICENSE-2.0 +# +# Unless required by application law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test vllm qwen.""" +import pytest +import os +from . import set_env +env_manager = set_env.EnvVarManager() +# def env +env_vars = { + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "MS_ENABLE_LCCL": "off", + "HCCL_OP_EXPANSION_MODE": "AIV", + "ASCEND_RT_VISIBLE_DEVICES": "0,1", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0" +} +# set env +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore +from vllm import LLM, SamplingParams + + +class TestQwen: + """ + Test Qwen. + """ + @pytest.mark.level0 + @pytest.mark.platform_arm_ascend910b_training + @pytest.mark.env_single + def test_vllm_qwen(self): + """ + test case qwen2.5 7B + """ + + # Sample prompts. + prompts = [ + "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n", + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1) + + # Create an LLM. + llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct", + gpu_memory_utilization=0.9, tensor_parallel_size=2) + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + except_list=['中性<|Assistant|> 这句话'] + # Print the outputs. 
+ for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert generated_text == except_list[i] + + # unset env + env_manager.unset_all() diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 107bc60b733b57ddc110eba4d3b319a02fbd3b48..032415f0a3030116e2c782f68b52441f83ed2ee1 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -167,14 +167,6 @@ vllm.worker.multi_step_model_runner._get_supported_attention_backends = ( _get_supported_attention_backends ) -from vllm_mindspore.distributed.parallel_state import ( - init_model_parallel_group, - init_group_coordinator, -) - -vllm.distributed.parallel_state.init_model_parallel_group = init_model_parallel_group -vllm.distributed.parallel_state.GroupCoordinator.__init__ = init_group_coordinator - from vllm_mindspore.executor.multiproc_worker_utils import ( get_mp_context as ms_get_mp_context, ) @@ -316,4 +308,8 @@ Worker.compile_or_warm_up_model = compile_or_warm_up_model from .utils import check_ready +from vllm_mindspore.engine.multiprocessing.engine import cleanup +import vllm.engine.multiprocessing.engine +vllm.engine.multiprocessing.engine.MQLLMEngine.cleanup = cleanup + check_ready() diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py index 4993212233df8ffb7a1f8a4f2b0f67acb708a065..558882cdffe5caaf6aa621a9ade1678d9a99b04b 100644 --- a/vllm_mindspore/attention/backends/ms_attn.py +++ b/vllm_mindspore/attention/backends/ms_attn.py @@ -23,6 +23,8 @@ from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import os +import numpy as np + import torch from vllm.attention.backends.abstract import ( @@ -55,6 +57,7 @@ import mindspore as ms from mindspore import mutable from mindspore._c_expression import swap_cache + def advance_step_op(sampled_token_ids, model_input, seq_lens_tensor, @@ -390,19 +393,6 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): else: raise AttributeError(f"Invalid attention type {str(attn_type)}") - def keys(self): - return ["num_prefill_tokens", "num_decode_tokens", "slot_mapping", "batch_valid_length", "context_lens", "block_tables"] - - def __getitem__(self, key): - if key == "context_lens": - key = "seq_lens_tensor" - if key == "batch_valid_length": - return mutable(getattr(self, "seq_lens"), dynamic_len=True) - if key == "block_tables": - if getattr(self, key).ndim == 1: - return mutable(getattr(self, key).expand_dims(0)) - return mutable(getattr(self, key)) - return mutable(getattr(self, key)) class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py index 84335349b49f20eabedb8bb0ee90ef1025726e97..4634727b9811601b3879d9c1a62c2a32fa613424 100644 --- a/vllm_mindspore/attention/layer.py +++ b/vllm_mindspore/attention/layer.py @@ -153,15 +153,15 @@ class Attention(nn.Cell): query: Tensor, key: Tensor, value: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: MSMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, + key_cache: Tensor, + value_cache: Tensor, + is_prefill: bool, slot_mapping: Tensor, batch_valid_length: Tuple[int], - context_lens: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, attn_mask: Tensor, + decode_mask: Tensor, ) -> Tensor: """Attention foward, support MHA and GQA. 
@@ -175,13 +175,13 @@ class Attention(nn.Cell): block_tables: shape = [block_size, num_block] """ output = query - key_cache, value_cache = kv_cache cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) query = ops.depend(query, cache_out) - if num_prefill_tokens > 0: + if is_prefill: output = self._run_prefill_forward(query, key, value, attn_mask, batch_valid_length, batch_valid_length) - if num_decode_tokens > 0: - output = self._run_decode_forward(query, key_cache, value_cache, block_tables, context_lens) + else: + output = self._run_decode_forward(query, key_cache, value_cache, block_tables, batch_valid_length, + decode_mask, q_seq_lens) return output def _run_prefill_forward( @@ -206,16 +206,18 @@ class Attention(nn.Cell): query = query.view(-1, self.hidden_size_per_partition) key = key.view(-1, self.kv_hidden_size_per_partition) value = value.view(-1, self.kv_hidden_size_per_partition) - _, _, _, output = self.flash_attention(query, - key, - value, - None, - None, - None, - attn_mask, - None, - actual_seq_qlen, - actual_seq_kvlen) + _, _, _, output = self.flash_attention( + query, + key, + value, + None, + None, + None, + attn_mask, + None, + actual_seq_qlen, + actual_seq_kvlen + ) output = output.view(1, -1, self.hidden_size_per_partition) return output @@ -225,7 +227,9 @@ class Attention(nn.Cell): key_cache: Tensor, value_cache: Tensor, block_tables: Tensor, - context_lens: Tensor, + batch_valid_length: Tensor, + decode_mask: Tensor, + q_seq_lens: Tensor, ) -> Tensor: """Decode with PagedAttention. @@ -236,5 +240,15 @@ class Attention(nn.Cell): block_tables: shape = [block_size, num_block] context_lens: shape = [batch_size, ] """ - output = self.paged_attention(query, key_cache, value_cache, block_tables, context_lens) + output = self.paged_attention( + query, + key_cache, + value_cache, + block_tables, + batch_valid_length, + None, + None, + decode_mask, + q_seq_lens + ) return output diff --git a/vllm_mindspore/config.py b/vllm_mindspore/config.py index 0b6c7ee63f528debc0378fd8be3165aad754f3ea..50c9d8685b2716d70629356fba9092e176c3e1b6 100644 --- a/vllm_mindspore/config.py +++ b/vllm_mindspore/config.py @@ -138,46 +138,46 @@ def _verify_args(self) -> None: "sequences. 
Please increase max_num_batched_tokens or " "decrease max_model_len.") - if self.max_num_batched_tokens < self.max_num_seqs: + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs}).") + + if self.num_lookahead_slots < 0: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be greater than or " + "equal to 0.") + + if self.num_scheduler_steps < 1: + raise ValueError( + "num_scheduler_steps " + f"({self.num_scheduler_steps}) must be greater than or " + "equal to 1.") + + if self.max_num_partial_prefills < 1: + raise ValueError( + f"max_num_partial_prefills ({self.max_num_partial_prefills}) " + "must be greater than or equal to 1.") + elif self.max_num_partial_prefills > 1: + if not self.chunked_prefill_enabled: + raise ValueError("Chunked prefill must be enabled to set " + "max_num_partial_prefills > 1.") + + if self.long_prefill_token_threshold > self.max_model_len: raise ValueError( - f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " - "be greater than or equal to max_num_seqs " - f"({self.max_num_seqs}).") - - if self.num_lookahead_slots < 0: - raise ValueError( - "num_lookahead_slots " - f"({self.num_lookahead_slots}) must be greater than or " - "equal to 0.") - - if self.num_scheduler_steps < 1: - raise ValueError( - "num_scheduler_steps " - f"({self.num_scheduler_steps}) must be greater than or " - "equal to 1.") - - if self.max_num_partial_prefills < 1: - raise ValueError( - f"max_num_partial_prefills ({self.max_num_partial_prefills}) " - "must be greater than or equal to 1.") - elif self.max_num_partial_prefills > 1: - if not self.chunked_prefill_enabled: - raise ValueError("Chunked prefill must be enabled to set " - "max_num_partial_prefills > 1.") - - if self.long_prefill_token_threshold > self.max_model_len: - raise ValueError( - "long_prefill_token_threshold " - f"({self.long_prefill_token_threshold}) cannot be greater " - f"than the max_model_len ({self.max_model_len}).") - - if (self.max_long_partial_prefills - < 1) or (self.max_long_partial_prefills - > self.max_num_partial_prefills): - raise ValueError( - f"max_long_partial_prefills ({self.max_long_partial_prefills}) " - "must be greater than or equal to 1 and less than or equal to " - f"max_num_partial_prefills ({self.max_num_partial_prefills}).") + "long_prefill_token_threshold " + f"({self.long_prefill_token_threshold}) cannot be greater " + f"than the max_model_len ({self.max_model_len}).") + + if (self.max_long_partial_prefills + < 1) or (self.max_long_partial_prefills + > self.max_num_partial_prefills): + raise ValueError( + f"max_long_partial_prefills ({self.max_long_partial_prefills}) " + "must be greater than or equal to 1 and less than or equal to " + f"max_num_partial_prefills ({self.max_num_partial_prefills}).") def model_post_init(self, __context) -> None: diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py deleted file mode 100644 index 42b10d699d3fd2dfd4f1e12c9e6f95eea93d453f..0000000000000000000000000000000000000000 --- a/vllm_mindspore/distributed/parallel_state.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python3 -# encoding: utf-8 -# Copyright 2025 Huawei Technologies Co., Ltd -# Copyright 2024 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -import pickle -from typing import List, Optional, Any, Union - -import numpy as np -import torch -import torch.distributed - -from torch.distributed import Backend - - -def init_model_parallel_group( - group_ranks: List[List[int]], - local_rank: int, - backend: str, - use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, -) -> "GroupCoordinator": - from vllm.distributed.parallel_state import GroupCoordinator - - return GroupCoordinator( - group_ranks=group_ranks, - local_rank=local_rank, - torch_distributed_backend=backend, - use_device_communicator=True, - use_message_queue_broadcaster=False, - group_name=group_name, - ) - - -def init_group_coordinator( - self, - group_ranks: List[List[int]], - local_rank: int, - torch_distributed_backend: Union[str, Backend], - use_device_communicator: bool, - use_message_queue_broadcaster: bool = False, - group_name: Optional[str] = None, -): - from vllm.distributed.parallel_state import _get_unique_name, _register_group - from vllm.utils import resolve_obj_by_qualname - - group_name = group_name or "anonymous" - self.unique_name = _get_unique_name(group_name) - _register_group(self) - - self.rank = torch.distributed.get_rank() - self.local_rank = local_rank - self.device_group = None - self.cpu_group = None - - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend) - # CPU not ready now, use device to communication now. 
- cpu_group = torch.distributed.new_group(ranks, backend="hccl") - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self.device_group = device_group - self.cpu_group = cpu_group - - assert self.cpu_group is not None - assert self.device_group is not None - - from vllm.platforms import current_platform - - # TODO: fix it for other platforms - if current_platform.is_cuda_alike(): - self.device = torch.device(f"cuda:{local_rank}") - else: - self.device = torch.device("cpu") - - self.use_device_communicator = use_device_communicator - - self.device_communicator: DeviceCommunicatorBase = None # type: ignore - if use_device_communicator and self.world_size > 1: - device_comm_cls = resolve_obj_by_qualname( - current_platform.get_device_communicator_cls()) - self.device_communicator = device_comm_cls( - cpu_group=self.cpu_group, - device=self.device, - device_group=self.device_group, - unique_name=self.unique_name, - ) - - from vllm.distributed.device_communicators.shm_broadcast import ( - MessageQueue) - self.mq_broadcaster: Optional[MessageQueue] = None - if use_message_queue_broadcaster and self.world_size > 1: - self.mq_broadcaster = MessageQueue.create_from_process_group( - self.cpu_group, 1 << 22, 6) - - from vllm.platforms import current_platform - self.use_custom_op_call = current_platform.is_cuda_alike() diff --git a/vllm_mindspore/engine/multiprocessing/__init__.py b/vllm_mindspore/engine/multiprocessing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_mindspore/engine/multiprocessing/engine.py b/vllm_mindspore/engine/multiprocessing/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..c91658e38ae8067c022682d7b7fe960105345b12 --- /dev/null +++ b/vllm_mindspore/engine/multiprocessing/engine.py @@ -0,0 +1,4 @@ +def cleanup(self): + self.ctx.destroy(linger=0) + if model_executor := getattr(self.engine, "model_executor", None): + model_executor.shutdown() \ No newline at end of file diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index cc055037884d93b30a3aa960fff988f6a72e1199..647b4ac837fd86ca078bdb01847fa7d88959d614 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -21,8 +21,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional import mindspore.nn as nn -from mindspore import Tensor -from mindspore import mint +from mindspore import Tensor, ops, mint, nn import vllm.envs as envs from vllm.config import get_current_vllm_config @@ -148,7 +147,7 @@ def _prune_hidden_states( # (warmup, profile_run) we might not have selected_token_indices, # so we skip pruning. 
if sampling_metadata.selected_token_indices is not None: - return hidden_states.index_select(0, sampling_metadata.selected_token_indices) + return ops.gather(hidden_states, sampling_metadata.selected_token_indices, 0) else: return hidden_states diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index 257db72bb0ee4cbbe21d3219e9ea03a77579eef8..7903702a96d3e29c5f6560c534ec427f629b9e68 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -156,6 +156,7 @@ class InferRotaryEmbedding(CustomOp): self.freqs_cos = Tensor(freqs_cos, dtype=dtype) self.freqs_sin = Tensor(freqs_sin, dtype=dtype) self.rotary_embedding_op = ops.ApplyRotaryPosEmb(2) + self.gather = ops.Gather() def forward_native( self, @@ -163,14 +164,14 @@ class InferRotaryEmbedding(CustomOp): query: Tensor, key: Tensor, batch_valid_length: Tensor, - num_prefill_tokens: int, + is_prefill: bool, offsets: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: - if num_prefill_tokens > 0: + if is_prefill: return self.rotary_embedding_op(query, key, self.freqs_cos, self.freqs_sin, batch_valid_length) - freqs_cos = self.freqs_cos.index_select(0, positions) - freqs_sin = self.freqs_sin.index_select(0, positions) + freqs_cos = self.gather(self.freqs_cos, positions, 0) + freqs_sin = self.gather(self.freqs_sin, positions, 0) return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length) diff --git a/vllm_mindspore/model_executor/layers/sampler.py b/vllm_mindspore/model_executor/layers/sampler.py index 16ec720c61a99cab40af0ec845bb58925791c25d..edfe62526034bef3d8b60ba8488047628c11288c 100644 --- a/vllm_mindspore/model_executor/layers/sampler.py +++ b/vllm_mindspore/model_executor/layers/sampler.py @@ -508,6 +508,7 @@ def _random_sample( # Find the maximum n value of the prompt phase requests. sample_idx = 0 results: SampleResultType = [] + random_samples = random_samples.asnumpy() for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -600,13 +601,6 @@ def _beam_search_sample( return results -def exponential(x, lambd=1.0, *, generator=None): - if generator is not None: - raise ValueError("`generator` can not be supported.") - output = np.random.exponential(scale=lambd, size=x.shape) - return ms.Tensor(output).astype(x.dtype) - - # torch.multinomial forces a GPU<->CPU sync. # Therefore, we use an optimized implementation instead. # Note that we always sample with replacement. 
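The _multinomial change below keeps vLLM's exponential trick but performs it with in-place tensor ops: drawing q_i ~ Exponential(1) and taking argmax(p_i / q_i) returns index i with probability proportional to p_i, which avoids the host sync that torch.multinomial would force. A small NumPy sketch of the same identity, purely illustrative and independent of the patched code:

import numpy as np


def multinomial_via_exponential(probs: np.ndarray, num_samples: int,
                                rng: np.random.Generator) -> np.ndarray:
    # probs: [batch, vocab]; rows only need to be proportional to the target
    # distribution, exact normalization is not required for the argmax.
    probs = np.repeat(probs, num_samples, axis=0)
    q = rng.exponential(scale=1.0, size=probs.shape)   # q_i ~ Exp(1)
    return (probs / q).argmax(axis=1).reshape(-1, num_samples)


# Example: two independent draws (with replacement) from p = (0.1, 0.2, 0.7).
# rng = np.random.default_rng(0)
# print(multinomial_via_exponential(np.array([[0.1, 0.2, 0.7]]), 2, rng))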
@@ -621,18 +615,17 @@ def _multinomial( probs = probs.repeat_interleave(num_samples, dim=0) q = torch.empty_like(probs) if seq_groups is None: - q = exponential(q) + q.exponential_() else: sample_idx = 0 for seq_group in seq_groups: seq_ids = seq_group.seq_ids stride = len(seq_ids) * num_samples assert seq_group.generator is not None - q[sample_idx : sample_idx + stride] = exponential( - q[sample_idx : sample_idx + stride] - ) + q[sample_idx : sample_idx + + stride].exponential_(generator=seq_group.generator) sample_idx += stride - return probs.div(q).argmax(axis=1).view(-1, num_samples) + return probs.div_(q).argmax(dim=1).view(-1, num_samples) def _top_k_top_p_multinomial_with_flashinfer( diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py index 357259857023f19dbc0a1c6000b66eb5c19da1b1..fac2bf20fc1b33ce831880b2a12f5b84f604ed21 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py @@ -16,15 +16,13 @@ # limitations under the License. # ============================================================================ -from typing import Iterable, Set, Tuple, Optional +from typing import Iterable, Set, Tuple from vllm.config import VllmConfig from vllm.config import get_current_vllm_config from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.sampling_metadata import SamplingMetadata -import mindspore as ms from mindspore import Tensor, JitConfig, Model, mutable from mindspore.nn.utils import no_init_parameters @@ -36,9 +34,9 @@ from research.deepseek3.deepseek3 import ( ) from vllm_mindspore.model_executor.layers.sampler import get_sampler -from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase, Fake_Attention +from vllm_mindspore.model_executor.models.model_base import Fake_MLA +from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase from vllm_mindspore.model_executor.models.mf_models.deepseekv3_weight_processor import DeepseekV3WeightProcessor -from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask logger = init_logger(__name__) @@ -47,7 +45,23 @@ class DeepseekV3MTPForCausalLM(MfModelBase): super(DeepseekV3MTPForCausalLM, self).__init__( vllm_config=vllm_config, prefix=prefix ) + self.mf_kvcaches_init = False + + self.sampler = get_sampler() + self.set_modules({"model": self.network}) + + self.kv_caches = [Fake_MLA() for i in range(self.mf_model_config.num_layers)] + compilation_config = get_current_vllm_config().compilation_config + + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(self.mf_model_config.num_nextn_predict_layers): + compilation_config.static_forward_context[str(i)] = self.kv_caches[i] + + self.set_flags = False + + def _generate_model_config(self): self.mf_config.load_checkpoint = self.get_model_path() self.mf_model_config = DeepseekV3Config_MF(**self.mf_config.model.model_config) @@ -57,33 +71,20 @@ class DeepseekV3MTPForCausalLM(MfModelBase): setattr(self.mf_model_config, 'npu_mem_size', -1) self.mf_model_config.is_mtp_model = True - self.mf_model_config.num_nextn_predict_layers = vllm_config.model_config.hf_config.num_nextn_predict_layers + self.mf_model_config.num_nextn_predict_layers = self.model_config.hf_config.num_nextn_predict_layers if 
self.mf_model_config.num_nextn_predict_layers != 1: raise NotImplementedError("Only support 1 MTP-layer now.") - - self.mf_config.model.model_config = self.mf_model_config - # Initital network - with no_init_parameters(): # Delay initialization - self.network = DeepseekV3ForCausalLM_MF(self.mf_model_config) - self.network._jit_config_dict = JitConfig( - jit_level="O0", infer_boost="on" - ).jit_config_dict - self.mf_kvcaches_init = False + self.mf_config.model.model_config = self.mf_model_config - self.sampler = get_sampler() - self.set_modules({"model": self.network}) - self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_nextn_predict_layers)] - compilation_config = get_current_vllm_config().compilation_config + def _create_network(self): + # Initital network + with no_init_parameters(): # Delay initialization + network = DeepseekV3ForCausalLM_MF(self.mf_model_config) - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(self.mf_model_config.num_nextn_predict_layers): - compilation_config.static_forward_context[str(i)] = self.kv_caches[i] + return network, network.mtp_model.head - self.casual_mask = LowerTriangularMask(mf_model_config=self.mf_model_config) - self.set_flags = False def get_kvcache(self): key_cache = [] @@ -105,23 +106,6 @@ class DeepseekV3MTPForCausalLM(MfModelBase): return model_inputs - def compute_logits( - self, - hidden_states: Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[Tensor]: - selected_token_indices = sampling_metadata.selected_token_indices - if selected_token_indices is not None and selected_token_indices.numel() <= 0: - logits = ms.mint.zeros((0, self.mf_model_config.vocab_size), - dtype=self.mf_model_config.compute_dtype) - else: - hidden_states = hidden_states.index_select(0, selected_token_indices) - logits = self.network.mtp_model.head(hidden_states) - logits = logits.reshape(-1, logits.shape[-1]) - - return logits - - def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, False) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint, is_mtp_model=True) diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index ab1c0a0ce84f5198922b2179f0dd664887893d05..ccbce8fb143c102c0b2d24710e241234717f93f8 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -28,6 +28,7 @@ from vllm.forward_context import get_forward_context from vllm.logger import init_logger import vllm.envs as envs +import mindspore as ms from mindspore import Tensor, JitConfig, Model, mutable from mindspore.common import dtype as msdtype from mindspore.nn.utils import no_init_parameters @@ -47,14 +48,27 @@ from research.deepseek3.deepseek3 import ( ) from vllm_mindspore.model_executor.layers.sampler import get_sampler -from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase, Fake_MLA, Fake_MLA_V1 +from vllm_mindspore.model_executor.models.model_base import Fake_MLA, Fake_MLA_V1 +from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase from vllm_mindspore.model_executor.models.mf_models.deepseekv3_weight_processor import DeepseekV3WeightProcessor -from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask logger 
= init_logger(__name__) +def set_runtime_kernel_launch_group(): + kernel_launch_group = {'thread_num' : 2, 'kernel_group_num' : 8} + env_kernel_launch_group = os.getenv("EXPERIMENTAL_KERNEL_LAUNCH_GROUP", None) + if env_kernel_launch_group is not None: + pairs = env_kernel_launch_group.split(',') + for pair in pairs: + key, val = pair.split(':') + kernel_launch_group[key] = val + thread_num = int(kernel_launch_group.get('thread_num', 2)) + kernel_group_num = int(kernel_launch_group.get('kernel_group_num', 8)) + ms.runtime.set_kernel_launch_group(thread_num=thread_num, kernel_group_num=kernel_group_num) + + class DeepseekV3ForCausalLM(MfModelBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super(DeepseekV3ForCausalLM, self).__init__( @@ -78,8 +92,8 @@ class DeepseekV3ForCausalLM(MfModelBase): for i in range(self.mf_model_config.num_layers): compilation_config.static_forward_context[str(i)] = self.kv_caches[i] - self.casual_mask = LowerTriangularMask(mf_model_config=self.mf_model_config) self.set_flags = False + set_runtime_kernel_launch_group() def _generate_model_config(self): self.mf_config.load_checkpoint = self.get_model_path() @@ -101,7 +115,7 @@ class DeepseekV3ForCausalLM(MfModelBase): if ptq is not None: ptq.apply(network) ptq.convert(network) - return network + return network, network.lm_head def get_kvcache(self): key_cache = [] @@ -125,6 +139,8 @@ class DeepseekV3ForCausalLM(MfModelBase): weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, self.is_quant) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) self.network.set_dynamic_inputs() + dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype) + self.lm_head.set_inputs(dynamic_hidden_states) return None def get_model_path(self): diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index db5575d7f05e2cfaec1e551c8c95ac959e11d62a..97338bd9dd9f162c69a2c3c534db200f03197fa0 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -1134,7 +1134,9 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): for file in os.listdir(src_hf_dir): if file.endswith('index.json'): - if (self.is_quant and 'quant' in file) or (is_mtp_model and 'quant' not in file): + # mtp model do not support quantization, needs to load bf16 weight. 
+ if ('quant' in file and self.is_quant) or \ + ('quant' not in file and (not self.is_quant or is_mtp_model)): param_json_path = os.path.join(src_hf_dir, file) with open(param_json_path, "r") as fp: hf_weight_map = json.load(fp)['weight_map'] diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py index f640583ce5c18758348c01da7f7c04e8ede78c7c..f2affe3f978af24eea9eda8bf5817c1bdceed661 100644 --- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py +++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py @@ -25,25 +25,25 @@ import math from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.config import get_current_vllm_config from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.forward_context import ForwardContext, get_forward_context +from vllm.forward_context import get_forward_context from vllm.sequence import IntermediateTensors from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.attention.backends.abstract import AttentionType from vllm.logger import init_logger from vllm.attention.layer import Attention import vllm.envs as envs import torch import mindspore as ms from mindspore import Tensor, mutable +from mindspore.common.api import _pynative_executor from mindformers.tools.register.config import MindFormerConfig -from mindformers.core.context import build_context +from mindformers.core.context import build_mf_context from mindformers.core.parallel_config import build_parallel_config from vllm_mindspore.model_executor.models.model_base import MsModelBase +from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata logger = init_logger(__name__) @@ -60,71 +60,6 @@ def _batch_seq(input_tokens, prefill): return ms.mint.reshape(input_tokens, (-1, 1)).to(ms.int32) -class Fake_Attention: - def __init__(self): - vllm_config = get_current_vllm_config() - block_size = vllm_config.cache_config.block_size - num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config - ) - head_size = vllm_config.model_config.get_head_size() - num_block = 0 - self.kv_shape = [num_block, block_size, num_kv_heads, head_size] - self.kv_cache = [ - ( - torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"), - torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"), - ) - for _ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] - self.attn_type = AttentionType.DECODER - - -class Fake_MLA(Fake_Attention): - def __init__(self): - super().__init__() - vllm_config = get_current_vllm_config() - self.kv_cache = [ - (torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),) - for _ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] - - -class Fake_Attention_V1(Attention): - def __init__(self): - vllm_config = get_current_vllm_config() - block_size = vllm_config.cache_config.block_size - num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config - ) - head_size = vllm_config.model_config.get_head_size() - num_block = 0 - self.kv_shape = [num_block, block_size, num_kv_heads, head_size] - self.kv_cache = [ - ( - torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"), - torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"), - ) - for 
_ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] - self.attn_type = AttentionType.DECODER - self.num_kv_heads = num_kv_heads - self.head_size = head_size - self.dtype = vllm_config.model_config.dtype - self.block_size = block_size - self.sliding_window = None - - -class Fake_MLA_V1(Fake_Attention_V1): - def __init__(self): - super().__init__() - vllm_config = get_current_vllm_config() - self.kv_cache = [ - (torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),) - for _ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] - - class MfModelBase(MsModelBase): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super(MfModelBase, self).__init__( @@ -132,7 +67,7 @@ class MfModelBase(MsModelBase): ) self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG")) - build_context(self.mf_config, is_set_ms_ctx=False, is_init_ms=False) + build_mf_context(self.mf_config) build_parallel_config(self.mf_config) self.mf_config.model.model_config.parallel_config = ( self.mf_config.parallel_config @@ -141,16 +76,12 @@ class MfModelBase(MsModelBase): get_tensor_model_parallel_world_size() ) self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 - self._generate_model_config() - self.network = self._create_network() - - self.network.construct = MethodType(ms.jit(self.network.__class__.construct, - jit_level='O0', infer_boost='on'), - self.network) - self.network.lm_head.construct = MethodType(ms.jit(self.network.lm_head.__class__.construct, - jit_level='O0', infer_boost='on'), - self.network.lm_head) + self.casual_mask = LowerTriangularMask(mf_model_config=self.mf_model_config) + self.network, self.lm_head = self._create_network() + affinity_config = self.mf_config.get('context', {}).get('affinity_cpu_list', {}) + if isinstance(affinity_config, dict): + ms.runtime.set_cpu_affinity(True, affinity_config) @abstractmethod def _generate_model_config(self): @@ -309,10 +240,10 @@ class MfModelBase(MsModelBase): dtype=self.mf_model_config.compute_dtype) else: hidden_states = hidden_states.index_select(0, selected_token_indices) - logits = self.network.lm_head(hidden_states) + logits = self.lm_head(hidden_states) logits = logits.reshape(-1, logits.shape[-1]) else: - logits = self.network.lm_head(hidden_states) + logits = self.lm_head(hidden_states) logits = logits.reshape(-1, logits.shape[-1]) return logits @@ -322,6 +253,7 @@ class MfModelBase(MsModelBase): sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) + _pynative_executor.sync() return next_tokens def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py index 1d3691e4a9a53c29bed4398400f49ea7478882a7..98028f4f4caf80de0163a7bd7b3b2144bef805fa 100644 --- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py +++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py @@ -33,13 +33,9 @@ from research.qwen2_5.infer.qwen2_5 import ( ) from vllm_mindspore.model_executor.layers.sampler import get_sampler - -from vllm_mindspore.model_executor.models.mf_models.mf_model_base import (MfModelBase, - Fake_Attention, - Fake_Attention_V1) - +from vllm_mindspore.model_executor.models.model_base import Fake_Attention, Fake_Attention_V1 +from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase from 
vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor -from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask logger = init_logger(__name__) @@ -63,7 +59,6 @@ class Qwen2ForCausalLM(MfModelBase): for i in range(self.mf_model_config.num_layers): compilation_config.static_forward_context[str(i)] = self.kv_caches[i] - self.casual_mask = LowerTriangularMask(mf_model_config=self.mf_model_config) self.set_flags = False def _generate_model_config(self): @@ -82,13 +77,13 @@ class Qwen2ForCausalLM(MfModelBase): # Initial network with no_init_parameters(): # Delay initialization network = ParallelQwenForCausalLM_MF(self.mf_model_config) - return network + return network, network.lm_head def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, False) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) self.network.set_dynamic_inputs() - dynamic_hidden_states = ms.Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype) - self.network.lm_head.set_inputs(dynamic_hidden_states) + dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype) + self.lm_head.set_inputs(dynamic_hidden_states) return None diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 7d3ed381aea1aad3914d2c9b5bbe13bb06dd0d67..828b654efd99d37a37c4e49b621c9de10c2560ae 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -21,16 +21,86 @@ from abc import abstractmethod from typing import Iterable, List, Optional, Set, Tuple, Union, Dict from vllm.attention import AttentionMetadata -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_current_vllm_config from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from vllm.attention.backends.abstract import AttentionType +from vllm.forward_context import get_forward_context +from vllm.attention.layer import Attention +import torch + +import mindspore as ms from mindspore import Tensor, nn, mutable from mindspore import dtype as mstype from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE +class Fake_Attention: + def __init__(self): + vllm_config = get_current_vllm_config() + block_size = vllm_config.cache_config.block_size + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + head_size = vllm_config.model_config.get_head_size() + num_block = 0 + self.kv_shape = [num_block, block_size, num_kv_heads, head_size] + self.kv_cache = [ + ( + torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), + torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"), + ) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + self.attn_type = AttentionType.DECODER + + +class Fake_MLA(Fake_Attention): + def __init__(self): + super().__init__() + vllm_config = get_current_vllm_config() + self.kv_cache = [ + (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + + +class Fake_Attention_V1(Attention): + def __init__(self): + vllm_config = get_current_vllm_config() + block_size = vllm_config.cache_config.block_size + num_kv_heads = 
vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + head_size = vllm_config.model_config.get_head_size() + num_block = 0 + self.kv_shape = [num_block, block_size, num_kv_heads, head_size] + self.kv_cache = [ + ( + torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"), + torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"), + ) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + self.attn_type = AttentionType.DECODER + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.dtype = vllm_config.model_config.dtype + self.block_size = block_size + self.sliding_window = None + + +class Fake_MLA_V1(Fake_Attention_V1): + def __init__(self): + super().__init__() + vllm_config = get_current_vllm_config() + self.kv_cache = [ + (torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + class MsModelBase(): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -151,7 +221,7 @@ class MsModelBase(): ) -> Union[Tensor, IntermediateTensors]: raise NotImplementedError - def set_model_inputs(self): + def set_model_inputs(self, is_prefill): dyn_input_ids = Tensor(shape=[None, None], dtype=mstype.int64) dyn_position_ids = Tensor(shape=[None], dtype=mstype.int64) @@ -169,13 +239,11 @@ class MsModelBase(): dyn_key_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) dyn_value_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)) - dyn_kv_cache = mutable((dyn_key_cache, dyn_value_cache)) - dyn_kv_caches = mutable([dyn_kv_cache for _ in range(num_layers)]) + dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)]) + dyn_value_caches = mutable([dyn_value_cache for _ in range(num_layers)]) - dyn_num_prefill_tokens = mutable(1) - dyn_num_decode_tokens = mutable(0) - dyn_context_lens = Tensor(shape=[None, ], dtype=mstype.int32) - dyn_batch_valid_length = mutable([0, 0, 0], dynamic_len=True) + dyn_batch_valid_length = Tensor(shape=[None, ], dtype=mstype.int32) + dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32) dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) dyn_intermediate_tensors = None @@ -184,17 +252,28 @@ class MsModelBase(): self.model.set_inputs( dyn_input_ids, dyn_position_ids, - dyn_kv_caches, - dyn_num_prefill_tokens, - dyn_num_decode_tokens, - dyn_context_lens, - dyn_batch_valid_length, + dyn_key_caches, + dyn_value_caches, + is_prefill, dyn_slot_mapping, + dyn_batch_valid_length, + dyn_q_seq_lens, dyn_block_tables, dyn_intermediate_tensors, dyn_inputs_embeds ) + def get_kvcache(self): + key_cache = [] + value_cache = [] + forward_context = get_forward_context() + for i in range(self.config.num_hidden_layers): + k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0] + v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1] + key_cache.append(k_cache) + value_cache.append(v_cache) + return mutable(key_cache), mutable(value_cache) + @abstractmethod def compute_logits( self, diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 2c3c81d4597f8c59db1b81f8b93a14d61fa5e7e7..32d9da8d91b1a3cf2e8a7f6f51d51e152793bb09 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -15,13 +15,17 @@ # See the License for the specific language governing permissions and 
# limitations under the License. # ============================================================================ +from vllm.config import get_current_vllm_config from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, Iterable if TYPE_CHECKING: from transformers import Qwen2Config else: Qwen2Config = None -from mindspore import Parameter, Tensor, mint, nn, jit, mutable + +import numpy as np + +from mindspore import Parameter, Tensor, mint, nn, jit, ops from mindspore.common import dtype as mstype @@ -33,8 +37,6 @@ from vllm_mindspore.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm_mindspore.model_executor.layers.logits_processor import \ LogitsProcessor -from vllm.model_executor.layers.quantization import \ - QuantizationConfig from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, get_sampler) @@ -46,10 +48,12 @@ from vllm_mindspore.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata -from vllm_mindspore.model_executor.models.model_base import MsModelBase +from vllm_mindspore.model_executor.models.model_base import MsModelBase, Fake_Attention from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.quantization import \ + QuantizationConfig from vllm.sequence import IntermediateTensors from vllm.attention.backends.abstract import AttentionType from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -170,26 +174,26 @@ class Qwen2Attention(nn.Cell): attn_type=attn_type ) self.attn_mask = mint.triu(mint.ones(size=(128, 128), dtype=mstype.bfloat16), 1) + self.hard_mask = Tensor([0], dtype=mstype.bfloat16).reshape(1, 1) @jit def construct( self, positions: Tensor, hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, + key_cache: Tensor, + value_cache: Tensor, + is_prefill: bool, slot_mapping: Tensor, batch_valid_length: Tuple[int], - context_lens: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, ) -> Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1) - q, k = self.rotary_emb(positions, q, k, context_lens, num_prefill_tokens) - attn_output = self.attn(q, k, v, kv_cache, num_prefill_tokens, num_decode_tokens, - slot_mapping, batch_valid_length, context_lens, block_tables, self.attn_mask) + q, k = self.rotary_emb(positions, q, k, q_seq_lens, is_prefill) + attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, slot_mapping, batch_valid_length, + q_seq_lens, block_tables, self.attn_mask, self.hard_mask) output, _ = self.o_proj(attn_output) return output @@ -249,13 +253,12 @@ class Qwen2DecoderLayer(nn.Cell): self, positions: Tensor, hidden_states: Tensor, - kv_cache: Tuple[Tensor, Tensor], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, + key_cache: Tensor, + value_cache: Tensor, + is_prefill: bool, slot_mapping: Tensor, batch_valid_length: Tuple[int], - context_lens: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, residual: Optional[Tensor], ) -> Tuple[Tensor, Tensor]: @@ -268,12 +271,12 @@ class Qwen2DecoderLayer(nn.Cell): hidden_states = self.self_attn( positions, hidden_states, - kv_cache, - 
num_prefill_tokens, - num_decode_tokens, + key_cache, + value_cache, + is_prefill, slot_mapping, batch_valid_length, - context_lens, + q_seq_lens, block_tables ) @@ -335,13 +338,12 @@ class Qwen2Model(nn.Cell): self, input_ids: Optional[Tensor], positions: Tensor, - kv_caches: List[Tuple[Tensor, Tensor]], - # attn_metadata: AttentionMetadata, - num_prefill_tokens: int, - num_decode_tokens: int, + key_caches: List[Tensor], + value_caches: List[Tensor], + is_prefill: bool, slot_mapping: Tensor, - batch_valid_length: Tuple[int], - context_lens: Tensor, + batch_valid_length: Tensor, + q_seq_lens: Tensor, block_tables: Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[Tensor] = None, @@ -361,12 +363,12 @@ class Qwen2Model(nn.Cell): hidden_states, residual = layer( positions, hidden_states, - kv_caches[i - self.start_layer], - num_prefill_tokens, - num_decode_tokens, + key_caches[i - self.start_layer], + value_caches[i - self.start_layer], + is_prefill, slot_mapping, batch_valid_length, - context_lens, + q_seq_lens, block_tables, residual ) @@ -398,16 +400,16 @@ class Qwen2Model(nn.Cell): # the checkpoint. Skip them. continue if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): - # Loading kv cache quantization scales - param = params_dict[scale_name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else - loaded_weight[0]) - weight_loader(param, loaded_weight) - loaded_params.add(scale_name) - continue + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -483,7 +485,15 @@ class Qwen2ForCausalLM(MsModelBase): self.model.make_empty_intermediate_tensors) self.set_modules({"model": self.model, "lm_head": self.lm_head}) - self.set_model_inputs() + self.prefill = True + self.set_model_inputs(self.prefill) + self.kv_caches = [Fake_Attention() for i in range(config.num_hidden_layers)] + compilation_config = vllm_config.compilation_config + + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(config.num_hidden_layers): + compilation_config.static_forward_context[str(i)] = self.kv_caches[i] def get_input_embeddings(self, input_ids: Tensor) -> Tensor: return self.model.get_input_embeddings(input_ids) @@ -498,20 +508,51 @@ class Qwen2ForCausalLM(MsModelBase): inputs_embeds: Tensor = None, **kwargs ) -> Union[Tensor, IntermediateTensors]: - if attn_metadata.num_prefill_tokens > 0: - input_ids = input_ids.expand_dims(0) - if attn_metadata.num_decode_tokens > 0: - input_ids = input_ids.expand_dims(1) + key_cache, value_cache = self.get_kvcache() + seq_lens = attn_metadata.seq_lens + max_query_len = attn_metadata.max_query_len + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes and max_query_len will be 1. 
+ if self.is_multi_step_chunked_prefill and max_query_len == 1: + query_lens = [1] * len(seq_lens) + else: + query_lens = attn_metadata.query_lens + + seq_lens_np = np.array(seq_lens, dtype=np.int32) + query_lens_np = np.array(query_lens, dtype=np.int32) + kv_cache_lens = seq_lens_np - query_lens_np + is_prefill = attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max() == 0 + if is_prefill: + input_ids = ops.expand_dims(input_ids, 0) + if not self.prefill: + self.prefill = True + self.set_model_inputs(self.prefill) + else: + input_ids = ops.expand_dims(input_ids, 1) + if self.prefill: + self.prefill = False + self.set_model_inputs(self.prefill) + + slot_mapping = attn_metadata.slot_mapping + batch_valid_length = Tensor.from_numpy(np.array(attn_metadata.seq_lens, dtype=np.int32)) + q_seq_lens = Tensor.from_numpy(np.array(attn_metadata.query_lens, dtype=np.int32)) + block_tables = attn_metadata.block_tables model_output = self.model(input_ids, positions, - kv_caches, - **dict(attn_metadata), - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) - if attn_metadata.num_prefill_tokens > 0: - model_output = model_output.squeeze(0) - if attn_metadata.num_decode_tokens > 0: - model_output = model_output.squeeze(1) + key_cache, + value_cache, + is_prefill, + slot_mapping, + batch_valid_length, + q_seq_lens, + block_tables, + intermediate_tensors, + inputs_embeds) + if is_prefill: + model_output = ops.squeeze(model_output, 0) + else: + model_output = ops.squeeze(model_output, 1) return model_output def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: diff --git a/vllm_mindspore/ops/CMakeLists.txt b/vllm_mindspore/ops/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c94b2c085b0be5ed4247e4c5829325531648ae9 --- /dev/null +++ b/vllm_mindspore/ops/CMakeLists.txt @@ -0,0 +1,40 @@ +cmake_minimum_required(VERSION 3.16) +project(Ops) + +set(MS_EXTENSION_NAME "" CACHE STRING "Extension Name") +set(BUILD_EXTENSION_DIR "" CACHE STRING "Extension directory") +if (MS_EXTENSION_NAME STREQUAL "") + message(FATAL_ERROR "MS_EXTENSION_NAME must be set. Use -DMS_EXTENSION_NAME=") +endif() +if (BUILD_EXTENSION_DIR STREQUAL "") + message(FATAL_ERROR "BUILD_EXTENSION_DIR must be set. 
Use -DBUILD_EXTENSION_DIR=") +endif() + +# Build ascendc kernels +add_subdirectory(ascendc) + +# Collect source files +file(GLOB SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/module/*.cpp) + +# Generate a temporary python script file to build custom ops with MindSpore's CustomOpBuilder +set(PYTHON_SCRIPT_PATH "${CMAKE_BINARY_DIR}/build_custom_with_ms.py") +file(WRITE ${PYTHON_SCRIPT_PATH} " +import mindspore as ms +src_files = '${SRC_FILES}'.split(';') +ms.ops.CustomOpBuilder( + name='${MS_EXTENSION_NAME}', + sources=src_files, + backend='Ascend', + cflags='-I${CMAKE_CURRENT_SOURCE_DIR}', + ldflags='-L${ASCENDC_TARGET_DIR} -l${ASCENDC_TARGET_NAME}', + build_dir='${BUILD_EXTENSION_DIR}' +).build() +") + +find_package(Python3 COMPONENTS Interpreter REQUIRED) +add_custom_target( + BuildCustomOp ALL + COMMAND cd ${CMAKE_BINARY_DIR} && ${Python3_EXECUTABLE} ${PYTHON_SCRIPT_PATH} + DEPENDS ${ASCENDC_TARGET_NAME} + COMMENT "Building custom operator with MindSpore" +) diff --git a/vllm_mindspore/ops/ascendc/CMakeLists.txt b/vllm_mindspore/ops/ascendc/CMakeLists.txt index ce4a8d2766044e8a195d7cce4e35936045b38272..d6165987c9c0b00e6b31d7b50a053afcc796d9d8 100644 --- a/vllm_mindspore/ops/ascendc/CMakeLists.txt +++ b/vllm_mindspore/ops/ascendc/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.16) project(AscendC_Kernels) # Parameters passed from command line or default values -set(RUN_MODE "npu" CACHE STRING "cpu/sim/npu") +set(RUN_MODE "npu") set(SOC_VERSION "Ascend910B1" CACHE STRING "system on chip type") set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "Build type Release/Debug") @@ -21,11 +21,11 @@ endif() # Include Ascend CANN CMake file include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) -# Add source files -file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/adv_step_flash.c) +# Collect source files +file(GLOB ASCENDC_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/*.c) -# Build shared library -ascendc_library(ascendc_kernels_${RUN_MODE} SHARED ${KERNEL_FILES}) +# Create an object library +ascendc_library(ascendc_kernels_npu STATIC ${ASCENDC_KERNEL_FILES}) -# Set the output directory -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) \ No newline at end of file +set(ASCENDC_TARGET_NAME ascendc_kernels_npu PARENT_SCOPE) +set(ASCENDC_TARGET_DIR "${CMAKE_BINARY_DIR}/lib" PARENT_SCOPE) diff --git a/vllm_mindspore/ops/ascendc/adv_step_flash_adapter.cpp b/vllm_mindspore/ops/module/adv_step_flash.cpp similarity index 97% rename from vllm_mindspore/ops/ascendc/adv_step_flash_adapter.cpp rename to vllm_mindspore/ops/module/adv_step_flash.cpp index d72af3e38e579b4eef8ae074d2ded52e95ac8c03..803abb0a4239065f43083f153d7eea9d6d96b736 100644 --- a/vllm_mindspore/ops/ascendc/adv_step_flash_adapter.cpp +++ b/vllm_mindspore/ops/module/adv_step_flash.cpp @@ -4,7 +4,8 @@ #include "ms_extension.h" -#include "adv_step_flash.h" +#include "ascendc/adv_step_flash.h" +#include "module/module.h" using BaseTensor = mindspore::tensor::BaseTensor; using BaseTensorPtr = mindspore::tensor::BaseTensorPtr; @@ -91,7 +92,7 @@ void AdvStepFlashAscendC(int32_t num_seqs, int32_t num_queries, int32_t block_si seq_lens = caster.RecoveryTensorDtype(seq_lens, "seq_lens"); } -PYBIND11_MODULE(MS_EXTENSION_NAME, m) { +MS_EXTENSION_MODULE(adv_step_flash) { m.def("adv_step_flash", &AdvStepFlashAscendC, "adv_step_flash_ascendc", pybind11::arg("num_seqs"), pybind11::arg("num_queries"), pybind11::arg("block_size"), pybind11::arg("input_tokens"), pybind11::arg("sampled_token_ids"), pybind11::arg("input_positions"), pybind11::arg("seq_lens"), 
diff --git a/vllm_mindspore/ops/module/module.cpp b/vllm_mindspore/ops/module/module.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..45ae8c067e4112a15e2d7554d85235d85c5905cb
--- /dev/null
+++ b/vllm_mindspore/ops/module/module.cpp
@@ -0,0 +1,6 @@
+#include "module/module.h"
+
+PYBIND11_MODULE(MS_EXTENSION_NAME, m) {
+  m.doc() = "A custom module for operators";
+  ModuleRegistry::Instance().RegisterAll(m);
+}
diff --git a/vllm_mindspore/ops/module/module.h b/vllm_mindspore/ops/module/module.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef660e12d335b835af8b0a13b7334ac35b2f310f
--- /dev/null
+++ b/vllm_mindspore/ops/module/module.h
@@ -0,0 +1,54 @@
+#ifndef VLLM_MINDSPORE_OPS_MODULE_MODULE_H
+#define VLLM_MINDSPORE_OPS_MODULE_MODULE_H
+
+#include <pybind11/pybind11.h>
+#include <functional>
+#include <string>
+#include <vector>
+
+// Define the type of module registration functions
+using ModuleRegisterFunction = std::function<void(pybind11::module_ &)>;
+
+// Module registry class
+class ModuleRegistry {
+ public:
+  // Get the singleton instance
+  static ModuleRegistry &Instance() {
+    static ModuleRegistry instance;
+    return instance;
+  }
+
+  // Register a module function
+  void Register(const ModuleRegisterFunction &func) { functions_.push_back(func); }
+
+  // Call all registered module functions
+  void RegisterAll(pybind11::module_ &m) {
+    for (const auto &func : functions_) {
+      func(m);
+    }
+  }
+
+ private:
+  ModuleRegistry() = default;
+  ~ModuleRegistry() = default;
+
+  // Disable copy and assignment
+  ModuleRegistry(const ModuleRegistry &) = delete;
+  ModuleRegistry &operator=(const ModuleRegistry &) = delete;
+
+  // Store all registered functions
+  std::vector<ModuleRegisterFunction> functions_;
+};
+
+// Define a macro to register module functions
+#define MS_EXTENSION_MODULE(func)                                                  \
+  static void func##_register(pybind11::module_ &);                               \
+  namespace {                                                                      \
+  struct func##_registrar {                                                        \
+    func##_registrar() { ModuleRegistry::Instance().Register(func##_register); }   \
+  };                                                                               \
+  static func##_registrar registrar_instance;                                      \
+  }                                                                                \
+  static void func##_register(pybind11::module_ &m)
+
+#endif  // VLLM_MINDSPORE_OPS_MODULE_MODULE_H
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index 689fe9f154c821c11e2809da83beee9581a34512..d32b525ecd92b9e7782861150803cb7b52d7c667 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -223,6 +223,9 @@ def check_ready():
     # Common environment variables of predict.
     set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
 
+    if os.getenv("MS_MEMPOOL_BLOCK_SIZE"):
+        set_context(mempool_block_size=f"{os.environ['MS_MEMPOOL_BLOCK_SIZE']}GB")
+
     if is_mindformers_model_backend():
         logger.info("Run with Mindformers backend!")
         necessary_envs = ("MINDFORMERS_MODEL_CONFIG", )
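For reference, below is a minimal sketch (not part of this patch) of how an additional operator source file could hook into the registration path introduced above: any .cpp placed under vllm_mindspore/ops/module/ is collected by the file(GLOB .../module/*.cpp) rule in ops/CMakeLists.txt, and the MS_EXTENSION_MODULE macro from module.h appends its binding function to ModuleRegistry, which PYBIND11_MODULE in module.cpp later replays via RegisterAll. The operator name "dummy_add" and the host function here are hypothetical and purely illustrative; a real operator would dispatch to an AscendC kernel the way adv_step_flash.cpp does after its switch from a standalone PYBIND11_MODULE to MS_EXTENSION_MODULE.

// Hypothetical file: vllm_mindspore/ops/module/dummy_add.cpp
#include <cstdint>

#include "module/module.h"

namespace {
// Placeholder host-side implementation; a real operator would launch an
// AscendC kernel instead of computing on the host.
int64_t DummyAdd(int64_t a, int64_t b) { return a + b; }
}  // namespace

// Expands to a static registrar object, so this binding is queued in
// ModuleRegistry before PYBIND11_MODULE in module.cpp calls RegisterAll(m).
MS_EXTENSION_MODULE(dummy_add) {
  m.def("dummy_add", &DummyAdd, "hypothetical example op",
        pybind11::arg("a"), pybind11::arg("b"));
}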