diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 89a17fb8beca175548d9e173cbfc3e8b509f8ed7..01acf8c8e71db66a8d0a558c18d0cb5479e8f83e 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -52,6 +52,7 @@ from vllm_mindspore.utils import (
     get_dtype_size,
     ascend_device_count_stateless,
     ascend_is_initialized,
+    ms_memory_profiling,
 )
 
 vllm.utils.direct_register_custom_op = direct_register_custom_op
@@ -60,6 +61,7 @@ vllm.utils.async_tensor_h2d = async_tensor_h2d
 vllm.utils.get_dtype_size = get_dtype_size
 vllm.utils.cuda_device_count_stateless = ascend_device_count_stateless
 vllm.utils.cuda_is_initialized = ascend_is_initialized
+vllm.utils.memory_profiling = ms_memory_profiling
 vllm.config.cuda_device_count_stateless = ascend_device_count_stateless
 
 import vllm.executor
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index b67c3cc870ff584ca502fccfd7e1cf8383a8bcc5..c70eee8353cd9c23e82f9a940d42c287045d931b 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
 else:
     Library = None
 
-from vllm.utils import T, TORCH_DTYPE_TO_NUMPY_DTYPE, make_ndarray_with_pad
+from vllm.utils import T, TORCH_DTYPE_TO_NUMPY_DTYPE, make_ndarray_with_pad, MemorySnapshot, MemoryProfilingResult
 
 import mindspore as ms
 from mindspore.common.initializer import Zero
@@ -273,3 +273,86 @@ def convert_np_to_ms_dtype(value):
 def update_modules(name, module):
     logger.info(f"replace module {name} by {module}")
     sys.modules.update({name: module})
+
+
+@contextlib.contextmanager
+def ms_memory_profiling(
+        baseline_snapshot: MemorySnapshot,
+        weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
+    """Memory profiling context manager.
+    baseline_snapshot: the memory snapshot before the current vLLM instance.
+    weights_memory: memory used by PyTorch when loading the model weights.
+        Note that, before loading the model weights, we also initialize the device
+        and distributed environment, which may consume some memory. This part is not
+        included in the weights_memory because PyTorch does not control it.
+
+    The memory in one GPU can be classified into 3 categories:
+    1. memory used by anything other than the current vLLM instance.
+    2. memory used by torch in the current vLLM instance.
+    3. memory used in the current vLLM instance, but not by torch.
+
+    A quantitative example:
+
+    Before creating the current vLLM instance:
+        category 1: 1 GiB
+        category 2: 0 GiB
+        category 3: 0 GiB
+
+    After creating the current vLLM instance and loading the model,
+    (i.e. before profiling):
+        category 1: 1 GiB
+        category 2: 2 GiB (model weights take 2 GiB)
+        category 3: 0.5 GiB (memory used by NCCL)
+
+    During profiling (peak):
+        category 1: 1 GiB
+        category 2: 4 GiB (peak activation tensors take 2 GiB)
+        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
+
+    After profiling:
+        category 1: 1 GiB
+        category 2: 3 GiB (after garbage-collecting activation tensors)
+        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
+
+    In this case, non-kv cache takes 5 GiB in total, including:
+    a. 2 GiB used by the model weights (category 2)
+    b. 2 GiB reserved for the peak activation tensors (category 2)
+    c. 1 GiB used by non-torch components (category 3)
+
+    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
+
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
+
+    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling gives (c.).
+    """ # noqa
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats()
+
+    result = MemoryProfilingResult()
+
+    result.before_create = baseline_snapshot
+    # the part of memory used for holding the model weights
+    result.weights_memory = weights_memory
+
+    result.before_profile.measure()
+
+    yield result
+
+    # measure memory before emptying the cache to get the maximum reserved memory
+    result.after_profile.measure()
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    diff_profile = result.after_profile - result.before_profile
+    diff_from_create = result.after_profile - result.before_create
+
+    # Since memory fragmentation is much higher than expected, the memory
+    # requested by MindSpore is much larger than what is actually needed.
+    # Use reserved memory to describe the increase of torch memory; this
+    # workaround should be removed once MindSpore memory usage is optimized.
+    result.torch_peak_increase = diff_profile.torch_memory
+    result.non_torch_increase = diff_from_create.non_torch_memory
+    result.profile_time = diff_profile.timestamp
+    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
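For context, a minimal sketch of how the patched-in context manager might be driven, mirroring the pattern vLLM's worker uses with `memory_profiling`: take a baseline `MemorySnapshot` before the engine allocates anything, load the weights, then run a dummy forward pass inside the context. This assumes `MemorySnapshot` measures on construction (as in the vLLM versions this patch targets); `load_model()` and `profile_run()` are hypothetical placeholders, not part of this patch.

```python
# Usage sketch only; load_model() and profile_run() are hypothetical stand-ins
# for the worker's real model-loading and dummy-forward steps.
from vllm.utils import MemorySnapshot
from vllm_mindspore.utils import ms_memory_profiling


def load_model() -> int:
    """Placeholder: load the model and return the bytes held by its weights."""
    raise NotImplementedError


def profile_run() -> None:
    """Placeholder: run a dummy forward pass at the maximum batch/sequence size."""
    raise NotImplementedError


# Snapshot taken before this instance allocates any device memory.
baseline_snapshot = MemorySnapshot()

weights_memory = load_model()

with ms_memory_profiling(baseline_snapshot,
                         weights_memory=weights_memory) as result:
    profile_run()

# Everything that is not KV cache: weights + reserved peak activations + non-torch memory.
# The caller typically subtracts result.non_kv_cache_memory from the usable device
# memory budget to size the KV cache.
print(result.weights_memory, result.torch_peak_increase,
      result.non_torch_increase, result.non_kv_cache_memory)
```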