diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 84c933ab86c345ccc0b8bbbc561b70eb0146de8d..6aa1166a9520df1bec000a59fe685f2322786479 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -22,6 +22,9 @@ import os
 import sys
 import warnings
 import msadapter  # noqa F401
+from vllm_mindspore.ray_patch import patch_ray
+
+patch_ray()
 
 if "vllm" in sys.modules:
     # Check models variable in sub process, cannot raise here.
@@ -526,9 +529,4 @@ from vllm.model_executor.models.registry import _ModelRegistry
 
 _ModelRegistry._normalize_archs = _normalize_archs
 
-from vllm_mindspore.utils import view
-from mindspore import Tensor
-
-Tensor.view = view
-
 check_ready()
diff --git a/vllm_mindspore/ray_patch.py b/vllm_mindspore/ray_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..088d1c2a6bcb248aa65ff6904cb32bcdbf8e0437
--- /dev/null
+++ b/vllm_mindspore/ray_patch.py
@@ -0,0 +1,121 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/ray-project/ray/blob/ray-2.49.0/python/ray/experimental/channel/serialization_context.py
+#
+# Copyright 2025 Huawei Technologies Co., Ltd.
+# Copyright 2025 The Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Patching ray functions to use view(dtype)"""
+import numpy as np
+import torch
+
+try:
+    from ray.experimental.util.types import Device
+except ImportError:
+    Device = None
+
+
+def _view(tensor, target_dtype):
+    """ function for view(dtype) """
+    ori_shape = tensor.shape
+    target_shape = (-1, )
+    if len(ori_shape) > 1:
+        target_shape = ori_shape[:-1] + target_shape
+    out = np.frombuffer(
+        tensor.numpy(),
+        torch.ops.creation._TypeDict.get(target_dtype, np.float32))
+    if not out.flags.aligned:
+        out = np.require(out, requirements=["ALIGNED"])
+    if target_dtype == torch.bfloat16:
+        return torch.tensor(out.astype(
+            np.float32)).astype(target_dtype).reshape(target_shape)
+    return torch.tensor(out).reshape(target_shape)
+
+
+def serialize_to_numpy_or_scalar(self, tensor):
+    """
+    patch for
+    ray.experimental.channel.serialization_context._SerializationContext
+    """
+    tensor_device_type = tensor.device.type
+    if tensor_device_type != "cpu":
+        tensor = tensor.to("cpu")
+    if tensor.dim() > 0:
+        return (_view(tensor,
+                      torch.uint8).numpy(), tensor.dtype, tensor_device_type)
+    else:
+        return (tensor.item(), tensor.dtype, tensor_device_type)
+
+
+def deserialize_from_numpy_or_scalar(self, np_array, dtype, tensor_device_type,
+                                     target_device):
+    """
+    patch for
+    ray.experimental.channel.serialization_context._SerializationContext
+    """
+
+    if target_device == Device.DEFAULT:
+        target_device_type = tensor_device_type
+    elif target_device in [Device.GPU, Device.CUDA]:
+        target_device_type = "cuda"
+    else:
+        target_device_type = target_device.value
+
+    if target_device_type != "cpu":
+
+        def convert_numpy_to_tensor(np_array):
+            if not isinstance(np_array, np.ndarray):
+                # For scalar tensors, create the 0-dim tensor.
+                return torch.tensor(np_array,
+                                    device=target_device_type,
+                                    dtype=dtype)
+            else:
+                # For non-scalar tensors, view as the original dtype.
+                # It does zero-copy convert np_array inside shared memory to
+                # a tensor. Since we move data to GPU immediately, it is safe.
+                cpu_tensor = torch.from_numpy(np_array)
+                cpu_tensor = _view(cpu_tensor, dtype)
+                return cpu_tensor.to(device=target_device_type)
+
+        gpu_tensor = convert_numpy_to_tensor(np_array)
+
+        return gpu_tensor
+
+    if not isinstance(np_array, np.ndarray):
+        # For scalar tensors, create the 0-dim tensor.
+        return torch.tensor(np_array, device=target_device_type, dtype=dtype)
+    else:
+        # For non-scalar tensors, view as the original dtype.
+        return _view(torch.tensor(np_array, device=target_device_type), dtype)
+
+
+def patch_ray():
+    """patch for ray serialization context to use view(dtype) """
+    try:
+        from ray._version import version
+        from ray.experimental.channel.serialization_context import (
+            _SerializationContext)
+        if version >= "2.47.0":
+            _SerializationContext.deserialize_from_numpy_or_scalar = \
+                deserialize_from_numpy_or_scalar
+            _SerializationContext.serialize_to_numpy_or_scalar = \
+                serialize_to_numpy_or_scalar
+        else:
+            _SerializationContext.deserialize_from_numpy = \
+                deserialize_from_numpy_or_scalar
+            _SerializationContext.serialize_to_numpy = \
+                serialize_to_numpy_or_scalar
+    except ImportError:
+        pass
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index 8690fa57b4a5451d18f2c4ec59cfc33d40c64e07..3f719bbc06137d5c716c21abf5b7329cab713fdf 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -429,34 +429,3 @@ def ms_memory_profiling(
     result.non_torch_increase = diff_from_create.non_torch_memory
     result.profile_time = diff_profile.timestamp
     result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory  # noqa
-
-
-def view(self, *shape_or_dtype):
-    from mindspore._c_expression import typing
-    if len(shape_or_dtype) == 1 and isinstance(shape_or_dtype[0], typing.Type):
-        target_dtype = shape_or_dtype[0]
-        ori_shape = self.shape
-        target_shape = (-1, )
-        if len(ori_shape) > 1:
-            target_shape = ori_shape[:-1] + target_shape
-        out = np.frombuffer(
-            self.numpy(),
-            torch.ops.creation._TypeDict.get(target_dtype, np.float32))
-        if not out.flags.aligned:
-            out = np.require(out, requirements=["ALIGNED"])
-        if target_dtype == ms.bfloat16:
-            return ms.Tensor.from_numpy(out.astype(
-                np.float32)).astype(target_dtype).reshape(target_shape)
-        return ms.Tensor.from_numpy(out).reshape(target_shape)
-    result = []
-    if type(shape_or_dtype) is tuple:
-        for items in shape_or_dtype:
-            if not isinstance(items, int):
-                for item in items:
-                    if not isinstance(item, int):
-                        result.append(item.item())
-                    else:
-                        result.append(item)
-            else:
-                result.append(items)
-    return ms.ops.reshape(self, result)
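
Note: the core trick in `_view` above is a NumPy bitcast: `np.frombuffer` re-reads the tensor's CPU buffer under a different dtype without copying, and only the trailing dimension changes (bfloat16 is special-cased because NumPy has no native bfloat16 type). A minimal standalone sketch of the same round trip in plain NumPy (no msadapter or Ray required; the `bitcast` helper name is illustrative, not part of the patch):

import numpy as np

def bitcast(arr, target_dtype):
    # Reinterpret arr's underlying bytes without copying; only the
    # length of the last axis changes, mirroring _view() above.
    out = np.frombuffer(arr, dtype=target_dtype)
    if arr.ndim > 1:
        out = out.reshape(arr.shape[:-1] + (-1, ))
    return out

x = np.arange(4, dtype=np.float32)   # 4 elements, 16 bytes
u8 = bitcast(x, np.uint8)            # shape (16,), same memory
back = bitcast(u8, np.float32)       # original values again
assert np.array_equal(x, back)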
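
Because `patch_ray` rebinds methods on `_SerializationContext` itself, it only needs to run once, before Ray's channel machinery serializes any tensor; importing `vllm_mindspore` does this automatically via the `patch_ray()` call added to `__init__.py`. A quick way to check that the patch took effect (a sketch assuming ray >= 2.47.0 is installed; older releases expose `serialize_to_numpy`/`deserialize_from_numpy` instead, as handled in `patch_ray`):

from ray.experimental.channel.serialization_context import (
    _SerializationContext)

from vllm_mindspore.ray_patch import patch_ray, serialize_to_numpy_or_scalar

patch_ray()  # idempotent: rebinding the same functions is harmless

# The attribute on the class should now be the patched function.
assert _SerializationContext.serialize_to_numpy_or_scalar is \
    serialize_to_numpy_or_scalar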