diff --git a/README.en.md b/README.en.md
index 8fa0b8f3aa4c78d61e5ac25dbc5d7ed15dc0c14b..54b7c291d6e2f83866c0070b7e0551addf96be96 100644
--- a/README.en.md
+++ b/README.en.md
@@ -29,6 +29,7 @@ Curious about how it works under the hood? Check out our [System Walkthrough](ht
 ## News
 
+- **[08/25]** We're excited to share that we have added support for fast checkpoint loading on Ascend NPUs (CANN) when used with vLLM, PyTorch and HuggingFace Accelerate. Please refer to the documentation in `/docs/stable/store/ascend_npu_deployment` and `/docs/stable/store/晟腾NPU一站式部署文档` for more details.
 - **[03/25]** We're excited to share that we'll be giving a ServerlessLLM tutorial at the SESAME workshop, co-located with ASPLOS/EuroSys 2025 in Rotterdam, Netherlands, on March 31. [Slides](https://docs.google.com/presentation/d/1ioGCVpsg0x3oCxX19EiE820aMiY22X5MG6jgImZ1W18/edit?usp=sharing) | [More info](https://sesame25.github.io/)
 - **[11/24]** We have added experimental support of fast checkpoint loading for AMD GPUs (ROCm) when using with vLLM, PyTorch and HuggingFace Accelerate. Please refer to the [documentation](https://serverlessllm.github.io/docs/stable/store/rocm_quickstart) for more details.
 - **[10/24]** ServerlessLLM was invited to present at a global AI tech vision forum in Singapore.
diff --git a/README.md b/README.md
index 8fa0b8f3aa4c78d61e5ac25dbc5d7ed15dc0c14b..54b7c291d6e2f83866c0070b7e0551addf96be96 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ Curious about how it works under the hood? Check out our [System Walkthrough](ht
 ## News
 
+- **[08/25]** We're excited to share that we have added support for fast checkpoint loading on Ascend NPUs (CANN) when used with vLLM, PyTorch and HuggingFace Accelerate. Please refer to the documentation in `/docs/stable/store/ascend_npu_deployment` and `/docs/stable/store/晟腾NPU一站式部署文档` for more details.
 - **[03/25]** We're excited to share that we'll be giving a ServerlessLLM tutorial at the SESAME workshop, co-located with ASPLOS/EuroSys 2025 in Rotterdam, Netherlands, on March 31. [Slides](https://docs.google.com/presentation/d/1ioGCVpsg0x3oCxX19EiE820aMiY22X5MG6jgImZ1W18/edit?usp=sharing) | [More info](https://sesame25.github.io/)
 - **[11/24]** We have added experimental support of fast checkpoint loading for AMD GPUs (ROCm) when using with vLLM, PyTorch and HuggingFace Accelerate. Please refer to the [documentation](https://serverlessllm.github.io/docs/stable/store/rocm_quickstart) for more details.
 - **[10/24]** ServerlessLLM was invited to present at a global AI tech vision forum in Singapore.
diff --git a/docs/stable/store/ascend_npu_deployment.md b/docs/stable/store/ascend_npu_deployment.md
new file mode 100644
index 0000000000000000000000000000000000000000..214822b3f151b936fde8cdfdb161cdb6f8f0991b
--- /dev/null
+++ b/docs/stable/store/ascend_npu_deployment.md
@@ -0,0 +1,346 @@
+# Ascend NPU Deployment
+
+ServerlessLLM Store (`sllm-store`) is a Python library that supports fast model checkpoint loading from multi-tier storage (i.e., DRAM, SSD, HDD) into GPUs/NPUs.
+
+ServerlessLLM Store provides a model manager and two key functions:
+
+- `save_model`: Convert a HuggingFace model into a loading-optimized format and save it to a local path.
+- `load_model`: Load a model into the given GPUs/NPUs.
+
+This document provides instructions for deploying the `sllm_store` loading acceleration library from ServerlessLLM in an Ascend NPU environment. It outlines how to use the library together with CANN to speed up model loading.
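+
+For orientation, the end-to-end workflow condenses to the two calls sketched below. This is only a preview of the full, NPU-specific walkthrough in the Usage Examples section; the checkpoint store server described there must be running before `load_model` is called.
+
+```python
+import torch
+from transformers import AutoModelForCausalLM
+from sllm_store.transformers import save_model, load_model
+
+# 1. One-off: convert a HuggingFace checkpoint into the loading-optimized format.
+model = AutoModelForCausalLM.from_pretrained('facebook/opt-1.3b', torch_dtype=torch.float16)
+save_model(model, './models/facebook/opt-1.3b')
+
+# 2. With the checkpoint store server running (see below), load the model onto the NPU(s).
+model = load_model("facebook/opt-1.3b", device_map="auto", torch_dtype=torch.float16,
+                   storage_path="./models/", fully_parallel=True)
+```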
+
+## Environment Setup
+
+- **Check NPU status:** `npu-smi info`
+- **Monitor NPU status in real time:** `watch -d -n 1 npu-smi info`
+- **Check CANN version:** `ascend-dmi -c`
+
+---
+
+### Set CANN Environment
+
+This guide was tested with CANN **`8.0.0`**. Using other versions may cause bugs.
+
+```shell
+# Set up CANN environment
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+# Point the toolkit variables at the 8.0.0 installation
+export ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/8.0.0
+export LD_LIBRARY_PATH=$ASCEND_TOOLKIT_HOME/lib64:$ASCEND_TOOLKIT_HOME/runtime/lib64:$LD_LIBRARY_PATH
+export ASCEND_OPP_PATH=$ASCEND_TOOLKIT_HOME/opp
+export ASCEND_AICPU_PATH=$ASCEND_TOOLKIT_HOME/aicpu
+export PYTHONPATH=$ASCEND_TOOLKIT_HOME/python/site-packages:$PYTHONPATH
+
+# Optional: Set log level for CANN (for debugging)
+export ASCEND_SLOG_PRINT_TO_STDOUT=1 # Enable logging to standard output
+export ASCEND_GLOBAL_LOG_LEVEL=1 # Log level: typically 1 for INFO, 3 for ERROR
+```
+
+### Install `torch` and `torch_npu`
+
+First, install `torch` and `torch_npu` so that the build process can find the necessary `torch_npu` functions.
+
+```shell
+pip install torch==2.4.0
+pip install torch_npu==2.4.0.post2
+```
+
+### Set Up Conda Environment
+
+If you don't have Conda, install it first.
+
+```shell
+conda create -n sllm-worker python=3.10 -y
+conda activate sllm-worker
+conda install -c conda-forge gcc=13 gxx cmake -y
+conda install -c conda-forge ninja
+
+# Set USE_CANN environment variable to enable the CANN build
+export USE_CANN=1
+```
+
+---
+
+### Install from Source
+
+1. Clone the repository and navigate to the `sllm_store` directory.
+
+
+
+```shell
+git clone https://gitee.com/openeuler/ServerlessLLM.git
+cd ServerlessLLM/sllm_store  # adjust if your checkout lives elsewhere
+```
+
+2. Build and install the `store` library from source.
+
+
+
+```shell
+rm -rf build
+pip install .
+```
+
+---
+
+## Usage Examples
+
+1. Convert a model to the ServerlessLLM format and save it locally.
+
+
+
+```python
+from sllm_store.transformers import save_model
+
+# Load a model from the HuggingFace model hub.
+import torch
+from transformers import AutoModelForCausalLM
+model = AutoModelForCausalLM.from_pretrained('facebook/opt-1.3b', torch_dtype=torch.float16)
+
+# Replace './models' with your local path.
+save_model(model, './models/facebook/opt-1.3b')
+```
+
+2. In a separate terminal process, start the checkpoint store server.
+
+
+
+```shell
+# 'mem_pool_size' is the maximum size of the memory pool in GB. It should be larger than the model size.
+export ASCEND_RT_VISIBLE_DEVICES="0" # Set the NPU device(s) to use, e.g., "0,1,..."
+sllm-store start --storage-path ./models --mem-pool-size 4GB
+```
+
+3. Load the model and perform inference.
+
+
+
+```python
+import time
+import torch
+import torch_npu
+from sllm_store.transformers import load_model
+
+# Warm up the NPUs
+num_npus = torch_npu.npu.device_count()
+for i in range(num_npus):
+    torch.ones(1).to(f"npu:{i}")
+    torch_npu.npu.synchronize()
+
+start = time.time()
+model = load_model("facebook/opt-1.3b", device_map="auto", torch_dtype=torch.float16, storage_path="./models/", fully_parallel=True)
+# Please note the loading time depends on the model size and the hardware bandwidth.
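+# Note: repeat loads of a checkpoint already resident in the store server's memory pool (see --mem-pool-size) are typically much faster than the first, disk-bound load.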
+print(f"Model loading time: {time.time() - start:.2f}s")
+
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained('facebook/opt-1.3b')
+inputs = tokenizer('Hello, my dog is cute', return_tensors='pt').to("npu")
+outputs = model.generate(**inputs)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+
+## Built-in Test Cases
+
+Navigate to `sllm_store/tests/python/`. There are four files available for testing `sllm_store` with CANN and NPUs:
+
+```
+|
+|-- test_cann_basic.py
+|-- test_cann_load_model.py
+|-- test_cann_load_vllm_model.py
+|-- test_cann_save_model.py
+```
+
+---
+
+## Applying vLLM Patches on NPU
+
+To use vLLM, your versions must align with the following:
+
+```shell
+vllm-ascend==0.7.3.post1
+torch-npu==2.5.1
+torch==2.5.1
+```
+
+Tested vLLM version: [https://vllm-ascend.readthedocs.io/en/v0.7.3-dev/installation.html](https://vllm-ascend.readthedocs.io/en/v0.7.3-dev/installation.html)
+
+A known issue when installing vLLM on the NPU is that the `torch-npu` and `torch` versions it pulls in can be incompatible with ServerlessLLM on the NPU, so the versions need to be fixed manually (see Troubleshooting below).
+
+### Install vLLM Patches
+
+1. Install vLLM Ascend from the source code; installing it with `pip install` may cause issues.
+2. Check patch status (optional):
+
+
+
+```shell
+./sllm_store/vllm_patch/check_patch.sh
+```
+
+3. Apply patches:
+
+
+
+```shell
+./sllm_store/vllm_patch/patch.sh
+patch -p1 < sllm_load_npu.patch
+patch -p1 < vllm_ascend.patch
+```
+
+Remove patches (if needed):
+
+```shell
+./sllm_store/vllm_patch/remove_patch.sh
+```
+
+---
+
+### Note
+
+> The patch files are located under `sllm_store/vllm_patch/` in the ServerlessLLM repository.
+
+Download a model from HuggingFace and save it in the ServerlessLLM format:
+
+```shell
+python3 sllm_store/example/cann_save_vllm_model.py --model-name facebook/opt-1.3b --storage-path $PWD/models --tensor-parallel-size 1
+```
+
+You can also convert a model from a local path instead of downloading it from the network by passing the `--local-model-path` parameter.
+
+After saving the model, start the checkpoint store server and load the model within vLLM using the `serverless_llm` load format.
+
+Start the checkpoint store server in a separate process:
+
+```shell
+# 'mem_pool_size' is the maximum size of the memory pool in GB. It should be larger than the model size.
+sllm-store start --storage-path $PWD/models --mem-pool-size 4GB
+```
+
+Load the model in vLLM:
+
+```python
+from vllm import LLM, SamplingParams
+
+import os
+
+storage_path = os.getenv("STORAGE_PATH", "./models")
+model_name = "facebook/opt-1.3b"
+model_path = os.path.join(storage_path, model_name)
+
+llm = LLM(
+    model=model_path,
+    load_format="serverless_llm",
+    dtype="float16"
+)
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+outputs = llm.generate(prompts, sampling_params)
+
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+```
+
+### NNAL Issues
+
+**If you encounter NNAL-related issues when running the vLLM example, install the NNAL package (if missing).**
+
+NNAL provides the ATB library, including `libatb.so`. Download and install it. The version must match your CANN version. Use `$(uname -i)` to get the architecture, such as `aarch64` or `x86_64`.
+
+```shell
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.1.RC1/Ascend-cann-nnal_8.1.RC1_linux-"$(uname -i)".run
+
+chmod +x ./Ascend-cann-nnal_8.1.RC1_linux-"$(uname -i)".run
+
+./Ascend-cann-nnal_8.1.RC1_linux-"$(uname -i)".run --install
+```
+
+Load the NNAL/ATB environment script to add the library's path (e.g., `libatb.so`). This must be done in every new terminal session before running any scripts.
+
+```shell
+source /usr/local/Ascend/nnal/atb/set_env.sh
+```
+
+## Troubleshooting
+
+Here are some issues you might encounter during compilation and usage.
+
+### Torch Version Issues
+
+The default `requirements.txt` specifies:
+
+```
+torch==2.5.1
+torch-npu==2.5.1
+```
+
+If you encounter a compilation error related to `torch` or `torch-npu`, switch to:
+
+```shell
+pip install torch==2.4.0
+pip install torch-npu==2.4.0.post2
+```
+
+If you then encounter issues with `torchvision` or CANN when running `sllm_store`, switch back to:
+
+```shell
+pip install torch==2.5.1
+pip install torch-npu==2.5.1
+```
+
+---
+
+### Inference Issues with `load_model()`
+
+If you encounter `torch_npu`-related issues:
+
+1. Downgrade `numpy` to a version below 2.0, e.g. `pip install numpy==1.26.4`.
+2. Install any missing packages via `pip`.
+
+Start the backend server:
+
+```shell
+sllm-store start --storage-path /root/PROJECT/ServerlessLLM-NPU/sllm_store/tests/python/models --mem-pool-size 4GB
+```
+
+When running `test_cann_load_model`, if the gRPC server cannot be reached, set the following environment variables on the server, as proxies can block the connection.
+
+```shell
+export NO_PROXY="localhost,127.0.0.1"
+export no_proxy="localhost,127.0.0.1" # It is often best to set both for compatibility
+```
+
+If you encounter an issue like this:
+
+```shell
+(sllm-worker) [root@devserver-bms-2956fa59 sllm_store]# sllm-store start
+Traceback (most recent call last):
+  File "/root/miniconda3/envs/sllm-worker/bin/sllm-store", line 5, in <module>
+    from sllm_store.cli import main
+  File "/root/PROJECT/ServerlessLLM-NPU/sllm_store/sllm_store/cli.py", line 24, in <module>
+    from sllm_store.server import serve
+  File "/root/PROJECT/ServerlessLLM-NPU/sllm_store/sllm_store/server.py", line 13, in <module>
+    ctypes.CDLL(os.path.join(sllm_store.__path__[0], "libglog.so"))
+  File "/root/miniconda3/envs/sllm-worker/lib/python3.10/ctypes/__init__.py", line 374, in __init__
+    self._handle = _dlopen(self._name, mode)
+OSError: /root/PROJECT/ServerlessLLM-NPU/sllm_store/sllm_store/libglog.so: cannot open shared object file: No such file or directory
+[ERROR] 2025-06-09-23:13:26 (PID:2197687, Device:-1, RankID:-1) ERR99999 UNKNOWN applicaiton exception
+```
+
+Solution (adjust the paths to match your checkout and Python version):
+
+```shell
+ln -s /root/PROJECT/ServerlessLLM-NPU/sllm_store/build/lib.linux-aarch64-cpython-310/sllm_store/libglog.so \
+      /root/PROJECT/ServerlessLLM-NPU/sllm_store/sllm_store/libglog.so
+```
\ No newline at end of file
diff --git "a/docs/stable/store/\346\231\237\350\205\276NPU\344\270\200\347\253\231\345\274\217\351\203\250\347\275\262\346\226\207\346\241\243.md" "b/docs/stable/store/\346\231\237\350\205\276NPU\344\270\200\347\253\231\345\274\217\351\203\250\347\275\262\346\226\207\346\241\243.md"
new file mode 100644
index 0000000000000000000000000000000000000000..1720ca1f69c48a441f7259ddef9ec0dc2f966779
--- /dev/null
+++ "b/docs/stable/store/\346\231\237\350\205\276NPU\344\270\200\347\253\231\345\274\217\351\203\250\347\275\262\346\226\207\346\241\243.md"
@@ -0,0 +1,351 @@
+# 晟腾 NPU 一站式部署文档
+
+ServerlessLLM Store (`sllm-store`) 是一个 Python 库,支持将模型检查点从多层存储(如 DRAM、SSD、HDD)快速加载到 GPU/NPU 中。
+
+ServerlessLLM Store 提供了一个模型管理器和两个关键功能:
+
+- **`save_model`**:将 HuggingFace 模型转换为加载优化的格式并保存到本地路径。
+- **`load_model`**:将模型加载到指定的 GPU/NPU 中。
+
+---
+
+本文档介绍如何在昇腾 NPU 环境中部署 ServerlessLLM 的 `sllm_store` 加载加速库,并结合 CANN 实现模型加载提速。
+
+
+## 环境设置
+
+- 查看 NPU 状态:`npu-smi info`
+- 实时监测 NPU 状态:`watch -d -n 1 npu-smi info`
+- 查看 CANN 版本:`ascend-dmi -c`
+
+---
+
+### 设置 CANN 环境
+
+本文基于 CANN `8.0.0` 版本测试,使用其他版本可能导致 BUG。
+
+```shell
+# Set up CANN environment
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+# Set up CANN environment
+export ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/8.0.0
+export LD_LIBRARY_PATH=$ASCEND_TOOLKIT_HOME/lib64:$ASCEND_TOOLKIT_HOME/runtime/lib64:$LD_LIBRARY_PATH
+export ASCEND_OPP_PATH=$ASCEND_TOOLKIT_HOME/opp
+export ASCEND_AICPU_PATH=$ASCEND_TOOLKIT_HOME/aicpu
+export PYTHONPATH=$ASCEND_TOOLKIT_HOME/python/site-packages:$PYTHONPATH
+
+# Optional: Set log level for CANN (for debugging)
+export ASCEND_SLOG_PRINT_TO_STDOUT=1 # 日志打屏,可选
+export ASCEND_GLOBAL_LOG_LEVEL=1 # 日志级别:常用 1(INFO)、3(ERROR)
+```
+
+### 安装 torch 和 torch_npu
+
+先安装 `torch` 和 `torch_npu`,以保证编译时可以找到 `torch_npu` 的函数。
+
+```shell
+pip install torch==2.4.0
+pip install torch_npu==2.4.0.post2
+```
+
+### 设置 conda 环境
+
+如果尚未安装 Conda,请先安装。
+
+```shell
+conda create -n sllm-worker python=3.10 -y
+conda activate sllm-worker
+conda install -c conda-forge gcc=13 gxx cmake -y
+conda install -c conda-forge ninja
+
+# Set USE_CANN environment variable to enable using CANN
+export USE_CANN=1
+```
+
+---
+
+### 从源码安装
+
+1. 克隆仓库,进入 `sllm_store` 目录。
+
+
+
+```shell
+git clone https://gitee.com/openeuler/ServerlessLLM.git
+cd ServerlessLLM/sllm_store  # 如路径不同请自行调整
+```
+
+2. 从源码编译并安装 `store` 库。
+
+
+
+```shell
+rm -rf build
+pip install .
+```
+
+---
+
+## 使用样例
+
+1. 将模型转换为 ServerlessLLM 格式并保存到本地。
+
+
+
+```python
+from sllm_store.transformers import save_model
+
+# Load a model from HuggingFace model hub.
+import torch
+from transformers import AutoModelForCausalLM
+model = AutoModelForCausalLM.from_pretrained('facebook/opt-1.3b', torch_dtype=torch.float16)
+
+# Replace './models' with your local path.
+save_model(model, './models/facebook/opt-1.3b')
+```
+
+2. 在单独的终端进程中启动 checkpoint store 服务器。
+
+
+
+```shell
+# 'mem_pool_size' is the maximum size of the memory pool in GB. It should be larger than the model size.
+export ASCEND_RT_VISIBLE_DEVICES="0" # 设置需要使用的 NPU Device,如 "0,1,..."
+sllm-store start --storage-path ./models --mem-pool-size 4GB
+```
+
+3. 加载模型并推理。
+
+
+
+```python
+import time
+import torch
+import torch_npu
+from sllm_store.transformers import load_model
+
+# warm up the NPU
+num_npus = torch_npu.npu.device_count()
+for i in range(num_npus):
+    torch.ones(1).to(f"npu:{i}")
+    torch_npu.npu.synchronize()
+
+start = time.time()
+model = load_model("facebook/opt-1.3b", device_map="auto", torch_dtype=torch.float16, storage_path="./models/", fully_parallel=True)
+# Please note the loading time depends on the model size and the hardware bandwidth.
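+# 注:若模型检查点已缓存在 store 服务器的内存池中(见 --mem-pool-size),再次加载通常会比首次从磁盘加载快得多。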
+print(f"Model loading time: {time.time() - start:.2f}s")
+
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained('facebook/opt-1.3b')
+inputs = tokenizer('Hello, my dog is cute', return_tensors='pt').to("npu")
+outputs = model.generate(**inputs)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```
+
+## 内置测试用例
+
+进入 `sllm_store/tests/python/`,有四个文件可以用来测试 CANN 和 NPU 环境下的 `sllm_store`:
+
+```
+|
+|-- test_cann_basic.py
+|-- test_cann_load_model.py
+|-- test_cann_load_vllm_model.py
+|-- test_cann_save_model.py
+```
+
+---
+
+## NPU 上应用 vLLM 补丁
+
+为了使用 vLLM,版本必须与以下对齐:
+
+```shell
+vllm-ascend==0.7.3.post1
+torch-npu==2.5.1
+torch==2.5.1
+```
+
+已测试的 vLLM 版本:
+
+[https://vllm-ascend.readthedocs.io/en/v0.7.3-dev/installation.html](https://vllm-ascend.readthedocs.io/en/v0.7.3-dev/installation.html)
+
+在 NPU 上安装 vLLM 时,已知问题是其引入的 torch-npu 和 torch 版本可能与 NPU 上的 ServerlessLLM 不兼容,因此需要手动修复版本问题(参见下文故障排除)。
+
+### 安装 vLLM 补丁
+
+1. 从源码安装 vLLM Ascend;不要使用 pip install,否则会出问题。
+2. 检查补丁状态(可选):
+
+
+
+```
+./sllm_store/vllm_patch/check_patch.sh
+```
+
+3. 应用补丁:
+
+
+
+```
+./sllm_store/vllm_patch/patch.sh
+patch -p1 < sllm_load_npu.patch
+patch -p1 < vllm_ascend.patch
+```
+
+移除补丁(如果需要):
+
+```
+./sllm_store/vllm_patch/remove_patch.sh
+```
+
+---
+
+### 注意
+
+> 补丁文件位于 ServerlessLLM 仓库的 `sllm_store/vllm_patch/` 目录下。
+
+从 HuggingFace 下载模型并以 ServerlessLLM 格式保存:
+
+```shell
+python3 sllm_store/example/cann_save_vllm_model.py --model-name facebook/opt-1.3b --storage-path $PWD/models --tensor-parallel-size 1
+```
+
+也可以通过传递 `--local-model-path` 参数,从本地路径读取模型,而不是从网络下载。
+
+保存模型后,启动 checkpoint store 服务器,并在 vLLM 中使用 `serverless_llm` 加载格式加载模型。
+
+在单独的进程中启动 checkpoint store 服务器:
+
+```shell
+# 'mem_pool_size' 是内存池的最大大小,单位为 GB。它应大于模型大小。
+sllm-store start --storage-path $PWD/models --mem-pool-size 4GB
+```
+
+在 vLLM 中加载模型:
+
+```python
+from vllm import LLM, SamplingParams
+
+import os
+
+storage_path = os.getenv("STORAGE_PATH", "./models")
+model_name = "facebook/opt-1.3b"
+model_path = os.path.join(storage_path, model_name)
+
+llm = LLM(
+    model=model_path,
+    load_format="serverless_llm",
+    dtype="float16"
+)
+
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+outputs = llm.generate(prompts, sampling_params)
+
+# 打印输出。
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+```
+
+### NNAL 问题
+
+**当运行 vLLM 示例时,如果遇到与 NNAL 相关的问题,请安装 NNAL 包(如果缺失)。**
+
+NNAL 提供 ATB 库,包括 `libatb.so`。下载并安装它(版本要与您的 CANN 匹配;使用 `$(uname -i)` 来获取架构,例如 `aarch64` 或 `x86_64`):
+
+```shell
+wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.1.RC1/Ascend-cann-nnal_8.1.RC1_linux-"$(uname -i)".run
+
+chmod +x ./Ascend-cann-nnal_8.1.RC1_linux-"$(uname -i)".run
+
+./Ascend-cann-nnal_8.1.RC1_linux-"$(uname -i)".run --install
+```
+
+加载 NNAL/ATB 环境脚本,将库(例如 `libatb.so`)所在路径加入环境(每个新的终端会话都必须在运行脚本之前执行此操作):
+
+```shell
+source /usr/local/Ascend/nnal/atb/set_env.sh
+```
+
+## 故障排除
+
+以下为编译和使用过程中可能遇到的问题。
+
+### Torch 版本问题
+
+默认的 `requirements.txt` 设置为:
+
+```
+torch==2.5.1
+torch-npu==2.5.1
+```
+
+如果遇到由于 torch 或 torch-npu 导致的编译错误,请切换到:
+
+```shell
+pip install torch==2.4.0
+pip install torch-npu==2.4.0.post2
+```
+
+之后,如果在运行 `sllm_store` 时遇到与 torchvision 或 CANN 相关的问题,请切换回:
+
+```shell
+pip install torch==2.5.1
+pip install torch-npu==2.5.1
+```
+
+---
+
+### 推理问题:`load_model()`
+
+如果遇到与 `torch_npu` 相关的问题:
+
+1. 将 numpy 降级到 2.0 以下,例如 `pip install numpy==1.26.4`。
+2. 通过 pip 安装缺失的必要依赖包。
+
+启动后端服务器:
+
+```shell
+sllm-store start --storage-path /root/PROJECT/ServerlessLLM-NPU/sllm_store/tests/python/models --mem-pool-size 4GB
+```
+
+当运行 `test_cann_load_model` 时,如果遇到 gRPC 服务器无法连接的问题,请在服务器上设置以下环境变量,因为代理可能会阻止连接。
+
+```shell
+export NO_PROXY="localhost,127.0.0.1"
+export no_proxy="localhost,127.0.0.1" # 通常最好同时设置这两个以保证兼容性
+```
+
+如果遇到以下问题:
+
+```shell
+(sllm-worker) [root@devserver-bms-2956fa59 sllm_store]# sllm-store start
+Traceback (most recent call last):
+  File "/root/miniconda3/envs/sllm-worker/bin/sllm-store", line 5, in <module>
+    from sllm_store.cli import main
+  File "/root/PROJECT/ServerlessLLM-NPU/sllm_store/sllm_store/cli.py", line 24, in <module>
+    from sllm_store.server import serve
+  File "/root/PROJECT/ServerlessLLM-NPU/sllm_store/sllm_store/server.py", line 13, in <module>
+    ctypes.CDLL(os.path.join(sllm_store.__path__[0], "libglog.so"))
+  File "/root/miniconda3/envs/sllm-worker/lib/python3.10/ctypes/__init__.py", line 374, in __init__
+    self._handle = _dlopen(self._name, mode)
+OSError: /root/PROJECT/ServerlessLLM-NPU/sllm_store/sllm_store/libglog.so: cannot open shared object file: No such file or directory
+[ERROR] 2025-06-09-23:13:26 (PID:2197687, Device:-1, RankID:-1) ERR99999 UNKNOWN applicaiton exception
+```
+
+解决方法(请根据实际的仓库路径和 Python 版本调整路径):
+
+```shell
+ln -s /root/PROJECT/ServerlessLLM-NPU/sllm_store/build/lib.linux-aarch64-cpython-310/sllm_store/libglog.so \
+      /root/PROJECT/ServerlessLLM-NPU/sllm_store/sllm_store/libglog.so
+```
\ No newline at end of file
diff --git a/sllm_store/requirements-build.txt b/sllm_store/requirements-build.txt
index bae9656f10d47f0f26d5d4ac97ee6329a25e595d..24c98362f5a72026095e0d3f425394a13430ae3f 100644
--- a/sllm_store/requirements-build.txt
+++ b/sllm_store/requirements-build.txt
@@ -1,6 +1,6 @@
 cmake>=3.20,<4.0.0
 ninja
-numpy
+numpy<2.0.0
 peft==0.15.2
 setuptools
 torch==2.4.0