From e182933abcd2d4901d4be566b5903cac54b3f1e5 Mon Sep 17 00:00:00 2001
From: one_east
Date: Sat, 26 Jul 2025 10:38:44 +0800
Subject: [PATCH] CPU bind for 910B and 910C

Bind each worker process to the CPUs that share a NUMA node with its
NPU, for both the V0 and V1 workers. The topology is read from lscpu
and `npu-smi info -t topo`; when the platform reports no affinity,
CPUs are distributed evenly across the NPU chips, with a few CPUs
reserved for CANN.

---
 vllm_mindspore/__init__.py      |  29 +++----
 vllm_mindspore/worker/worker.py | 140 +++++++++++++++++++++++++++++---
 2 files changed, 143 insertions(+), 26 deletions(-)

diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 4e9a1717e..c9cf4c7ae 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -152,16 +152,18 @@ from vllm_mindspore.model_executor.model_loader.weight_utils import (
 vllm.model_executor.model_loader.loader.safetensors_weights_iterator = (
     safetensors_weights_iterator)
 
-from vllm_mindspore.worker.worker import _warm_up_model
+from vllm_mindspore.worker.worker import (_warm_up_model,
+                                          wrapper_worker_bind_cpu)
 from vllm_mindspore.worker.profile import (
     wrapper_worker_init,
     wrapper_worker_init_device,
 )
 
-from vllm.worker.worker import Worker
+from vllm.worker.worker import Worker as V0Worker
 
-Worker._warm_up_model = _warm_up_model
-Worker.__init__ = wrapper_worker_init(Worker.__init__)
-Worker.init_device = wrapper_worker_init_device(Worker.init_device)
+V0Worker._warm_up_model = _warm_up_model
+V0Worker.__init__ = (wrapper_worker_bind_cpu(
+    wrapper_worker_init(V0Worker.__init__)))
+V0Worker.init_device = wrapper_worker_init_device(V0Worker.init_device)
 
 from vllm_mindspore.worker.model_runner import (
     _get_cuda_graph_pad_size,
@@ -320,12 +322,6 @@ vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampl
 vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor
 vllm.v1.worker.gpu_model_runner.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor
 
-from vllm.v1.worker.gpu_worker import Worker
-from vllm_mindspore.v1.worker.gpu_worker import init_device
-
-Worker.__init__ = wrapper_worker_init(Worker.__init__)
-Worker.init_device = wrapper_worker_init_device(init_device)
-
 import vllm.v1.utils
 from vllm_mindspore.v1.utils import copy_slice
@@ -363,10 +359,15 @@ from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer
 
 ShmRingBuffer.__init__ = initialize_ShmRingBuffer
 
-from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model
-from vllm.v1.worker.gpu_worker import Worker
+from vllm_mindspore.v1.worker.gpu_worker import (compile_or_warm_up_model,
+                                                 init_device)
+
+from vllm.v1.worker.gpu_worker import Worker as V1Worker
 
-Worker.compile_or_warm_up_model = compile_or_warm_up_model
+V1Worker.__init__ = (wrapper_worker_bind_cpu(
+    wrapper_worker_init(V1Worker.__init__)))
+V1Worker.init_device = wrapper_worker_init_device(init_device)
+V1Worker.compile_or_warm_up_model = compile_or_warm_up_model
 
 from vllm_mindspore.v1.core.sched.scheduler import update_from_output
 from vllm.v1.core.sched.scheduler import Scheduler
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index c09d997b8..8a7a2789f 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -15,21 +15,11 @@
 # limitations under the License.
"""Adapted functions for mindspore in Worker.""" -import gc -import os import math -from typing import Tuple, Optional +import subprocess +import psutil import torch - -from vllm.config import VllmConfig -from vllm.distributed import ( - ensure_kv_transfer_initialized, - ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce, -) - from vllm.logger import init_logger from vllm_mindspore.utils import get_valid_dtype @@ -40,6 +30,132 @@ from vllm.sampling_params import SamplingParams logger = init_logger(__name__) +def execute_command(cmd_list): + try: + with subprocess.Popen(cmd_list, + shell=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) as p: + out, _ = p.communicate(timeout=1000) + res = out.decode() + return res + except FileNotFoundError as e: + message = f"Failed to execute command, because {e}." + raise RuntimeError(message) from e + + +def get_numa_map(): + npu_to_core_map = {} + + # Get quantity of CPUs and NUMA nodes. + total_cpu_count = 0 + numa_node_count = 0 + numa_info = execute_command("lscpu").strip().split("\n") + for val in numa_info: + if val.startswith("CPU(s):"): + total_cpu_count = int(val.split(" ")[-1]) + if val.startswith("NUMA"): + numa_node_count = int(val.split(" ")[-1]) + break + + # Get chip count of NPU. + chip_info = execute_command(["npu-smi", "info", "-l"]).strip().split("\n") + chip_count = 0 + npu_count = 0 + for val in chip_info: + if val.strip().startswith("Total"): + npu_count = int(val.split(" ")[-1]) + if val.strip().startswith("Chip"): + chip_count = int(val.split(" ")[-1]) + break + + # Get affinity relationship between CPUs and NPUs. + numa_topo_info = execute_command(["npu-smi", "info", "-t", + "topo"]).strip().split("\n") + numa_to_npu_map = {} + max_affinity_cpu = 0 + if "Affinity" not in numa_topo_info[0]: + # If the device does not provide affinity, + # the CPUs will be evenly distributed. + cpu_num_per_npu = total_cpu_count // (npu_count * chip_count) + for i in range(npu_count * chip_count): + cpu_start = i * cpu_num_per_npu + # 4 CPUs are reserved for CANN + npu_to_core_map[i] = [cpu_start, cpu_start + cpu_num_per_npu - 4] + return npu_to_core_map + else: + npu_num = 0 + for val in numa_topo_info[1:]: + line = val.split(" ") + if line and line[0].startswith("NPU"): + cpu_affinity = line[-1] + max_affinity_cpu = max(max_affinity_cpu, + int(cpu_affinity.split("-")[1])) + if numa_to_npu_map.get(cpu_affinity) is None: + numa_to_npu_map[cpu_affinity] = list() + # If each NPU has multiple chips, + # assign them to the same NUMA node. + for i in range(chip_count): + numa_to_npu_map[cpu_affinity].append(npu_num * chip_count + + i) + npu_num += 1 + + # If the number of NUMA nodes with affinity is less than + # or equal to half of the total, the offset is introduced, + # and no extra CPU is reserved for CANN. 
+    if numa_node_count >= 2 * len(numa_to_npu_map):
+        offset_mode = True
+        cpu_reserved_for_cann = 0
+    else:
+        offset_mode = False
+        cpu_reserved_for_cann = 4
+
+    for key, val in numa_to_npu_map.items():
+        cpu_range = key.split("-")
+        cpu_start = int(cpu_range[0])
+        cpu_end = int(cpu_range[1])
+        cpu_count = cpu_end - cpu_start + 1
+        if offset_mode:
+            # Shift into the lower neighboring node when the affined
+            # range ends at the last CPU, otherwise into the higher one.
+            if max_affinity_cpu == total_cpu_count - 1:
+                cpu_start = cpu_start - cpu_count
+            else:
+                cpu_start = cpu_start + cpu_count
+        shared_npu_count = len(val)
+        cpu_num_per_npu = cpu_count // shared_npu_count
+        for npu in val:
+            npu_to_core_map[npu] = [
+                cpu_start, cpu_start + cpu_num_per_npu - cpu_reserved_for_cann
+            ]
+            cpu_start += cpu_num_per_npu
+
+    return npu_to_core_map
+
+
+def bind_cpu(rank):
+    rank_cpu_maps = get_numa_map()
+
+    local_rank = rank % len(rank_cpu_maps)
+    cpu_range = rank_cpu_maps[local_rank]
+    # The range is half-open: [start, end).
+    cpu_list = list(range(cpu_range[0], cpu_range[1]))
+    current_process = psutil.Process()
+    current_process.cpu_affinity(cpu_list)
+    logger.info("Bound process %d of local rank %d to CPUs: %s",
+                current_process.pid, local_rank, cpu_list)
+
+
+def wrapper_worker_bind_cpu(fun):
+
+    def new_fun(*args, **kwargs):
+        # Bind the worker process to its NUMA-affine CPUs while the
+        # workers are initializing.
+        local_rank = kwargs.get("local_rank")
+        parallel_config = kwargs.get("vllm_config").parallel_config
+        # data_parallel_rank_local may be None when data parallelism
+        # is not in use; treat it as data-parallel rank 0 in that case.
+        dp_rank_local = parallel_config.data_parallel_rank_local or 0
+        local_rank = dp_rank_local * parallel_config.world_size + local_rank
+        bind_cpu(local_rank)
+        fun(*args, **kwargs)
+
+    return new_fun
+
+
 def _prepare_input_for_warmup(model_config, model_runner, cache_engine,
                               is_prefill, is_mtp_model=False):
     bs = 1
--
Gitee
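
Reviewer note: below is a minimal, self-contained sketch of the binding
mechanism this patch relies on. The npu_to_core_map literal is a
hypothetical example for a host with 48 CPUs, 2 NUMA nodes, and 2 NPUs
(it is not real npu-smi output); psutil.Process.cpu_affinity is the
actual API the patch calls, and the ranges follow the same half-open
[start, end) convention that get_numa_map() produces.

    import psutil

    # Hypothetical topology: 2 NPUs on a 48-CPU, 2-NUMA-node host, with
    # 4 CPUs per node left free for CANN. Values are half-open [start, end).
    npu_to_core_map = {0: [0, 20], 1: [24, 44]}

    def demo_bind(rank: int) -> None:
        local_rank = rank % len(npu_to_core_map)
        start, end = npu_to_core_map[local_rank]
        proc = psutil.Process()
        # Restrict this process to the CPUs affine to its NPU.
        proc.cpu_affinity(list(range(start, end)))
        print(f"pid {proc.pid} bound to CPUs {proc.cpu_affinity()}")

    if __name__ == "__main__":
        demo_bind(rank=0)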