From 8ea1d048eeac5283f23adfa3d813f4cd10cdfda0 Mon Sep 17 00:00:00 2001 From: gaohua Date: Wed, 3 Sep 2025 09:57:58 +0800 Subject: [PATCH 1/2] sglang api support --- ais_bench/benchmark/clients/__init__.py | 3 +- .../benchmark/clients/sglang_stream_client.py | 36 +++++ .../sglang_api/sglang_stream_api_general.py | 20 +++ ais_bench/benchmark/models/__init__.py | 3 +- ais_bench/benchmark/models/sglang_api.py | 133 ++++++++++++++++++ 5 files changed, 193 insertions(+), 2 deletions(-) create mode 100644 ais_bench/benchmark/clients/sglang_stream_client.py create mode 100644 ais_bench/benchmark/configs/models/sglang_api/sglang_stream_api_general.py create mode 100644 ais_bench/benchmark/models/sglang_api.py diff --git a/ais_bench/benchmark/clients/__init__.py b/ais_bench/benchmark/clients/__init__.py index 1a8a134..de29edb 100644 --- a/ais_bench/benchmark/clients/__init__.py +++ b/ais_bench/benchmark/clients/__init__.py @@ -11,4 +11,5 @@ from ais_bench.benchmark.clients.openai_chat_text_client import OpenAIChatTextCl from ais_bench.benchmark.clients.vllm_text_client import VLLMTextClient from ais_bench.benchmark.clients.openai_chat_stream_sglang_client import OpenAIChatStreamSglangClient from ais_bench.benchmark.clients.openai_prompt_chat_text_client import OpenAIPromptChatTextClient -from ais_bench.benchmark.clients.openai_function_chat_text_client import OpenAIFunctionChatTextClient \ No newline at end of file +from ais_bench.benchmark.clients.openai_function_chat_text_client import OpenAIFunctionChatTextClient +from ais_bench.benchmark.clients.sglang_stream_client import SGLangStreamClient \ No newline at end of file diff --git a/ais_bench/benchmark/clients/sglang_stream_client.py b/ais_bench/benchmark/clients/sglang_stream_client.py new file mode 100644 index 0000000..ba01e78 --- /dev/null +++ b/ais_bench/benchmark/clients/sglang_stream_client.py @@ -0,0 +1,36 @@ +from abc import ABC + +from ais_bench.benchmark.clients.base_client import BaseStreamClient +from 
class SGLangStreamClient(BaseStreamClient, ABC):
    """Streaming client for SGLang's native ``/generate`` endpoint.

    Builds the request body in SGLang's native format, picks the relevant
    fields out of each parsed stream event, and folds them into the running
    per-request ``MiddleData`` record.
    """

    def construct_request_body(
        self,
        inputs: str,
        parameters: dict = None,
    ) -> dict:
        # SGLang's native API takes the prompt under "text" and the decoding
        # options under "sampling_params"; streaming is always requested.
        return {"text": inputs, "stream": True, "sampling_params": parameters}

    def process_stream_line(self, json_content: dict) -> dict:
        """Extract generated text and token count from one parsed stream event."""
        response = {}
        text = json_content.get("text")
        if text:
            response["generated_text"] = text
        meta = json_content.get("meta_info")
        if meta:
            response["completion_tokens"] = meta.get("completion_tokens", 0)
        return response

    def update_middle_data(self, res: dict, inputs: MiddleData):
        """Fold one processed stream event into the ``MiddleData`` record.

        Zero/empty values are deliberately skipped by truthiness tests, so an
        event without timing info leaves the previous measurements untouched.
        """
        text = res.get("generated_text", "")
        if text:
            inputs.output = text
            inputs.num_generated_chars = len(text)
        if res.get("prefill_time"):
            inputs.prefill_latency = res["prefill_time"]
        if res.get("decode_time"):
            inputs.decode_cost.append(res["decode_time"])
        tokens = res.get("completion_tokens")
        if tokens:
            inputs.num_generated_tokens = tokens
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union

from tqdm import tqdm

from ais_bench.benchmark.registry import MODELS
from ais_bench.benchmark.utils.prompt import PromptList

from ais_bench.benchmark.models.base_api import handle_synthetic_input
from ais_bench.benchmark.models.performance_api import PerformanceAPIModel
from ais_bench.benchmark.clients import SGLangStreamClient
from ais_bench.benchmark.utils.build import build_client_from_cfg

PromptType = Union[PromptList, str, dict]


@MODELS.register_module()
class SGLangCustomAPIStream(PerformanceAPIModel):
    """Model wrapper around SGLang's native ``/generate`` streaming endpoint.

    Args:
        path (str): Model path or identifier forwarded to the base class.
            Defaults to "".
        max_seq_len (int): The maximum allowed sequence length of a model.
            Note that the length of prompt + generated tokens shall not
            exceed this value. Defaults to 4096.
        request_rate (int): The maximum queries allowed per second between
            two consecutive calls of the API. Defaults to 1.
        rpm_verbose (bool): Forwarded to the base class. Defaults to False.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        meta_template (Dict, optional): The model's meta prompt template if
            needed, in case of injecting or wrapping any meta instructions.
        verbose (bool): Forwarded to the base class. Defaults to False.
        host_ip (str): The host ip of the custom service. Defaults to
            "localhost".
        host_port (int): The host port of the custom service. Defaults to
            8080.
        enable_ssl (bool): Build an "https://" base URL instead of "http://".
            Defaults to False.
        custom_client (Dict, optional): Config for ``build_client_from_cfg``;
            when omitted (or not a dict) ``SGLangStreamClient`` is used.
        generation_kwargs (Dict, optional): Extra sampling parameters sent
            with every request.
        trust_remote_code (bool): Forwarded to the base class. Defaults to
            False.
    """

    is_api: bool = True

    def __init__(self,
                 path: str = "",
                 max_seq_len: int = 4096,
                 request_rate: int = 1,
                 rpm_verbose: bool = False,
                 retry: int = 2,
                 meta_template: Optional[Dict] = None,
                 verbose: bool = False,
                 host_ip: str = "localhost",
                 host_port: int = 8080,
                 enable_ssl: bool = False,
                 custom_client: Optional[Dict] = None,
                 generation_kwargs: Optional[Dict] = None,
                 trust_remote_code: bool = False):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template,
                         request_rate=request_rate,
                         rpm_verbose=rpm_verbose,
                         retry=retry,
                         verbose=verbose,
                         generation_kwargs=generation_kwargs,
                         trust_remote_code=trust_remote_code)
        self.host_ip = host_ip
        self.host_port = host_port
        self.enable_ssl = enable_ssl
        self.base_url = self._get_base_url()
        # Normalize to a dict so _generate can copy it unconditionally
        # (a raw None here would crash on .copy()).
        self.generation_kwargs = generation_kwargs if generation_kwargs is not None else {}
        # base_url always ends with "/". Plain concatenation is used instead
        # of os.path.join, which targets filesystem paths (it would insert
        # "\\" on Windows) and is not meant for URLs.
        self.endpoint_url = self.base_url + "generate"
        self.init_client(custom_client)

    def init_client(self, custom_client):
        """Build ``self.client`` from ``custom_client``, filling in url/retry.

        A fresh default config dict is created per call and caller-supplied
        dicts are copied, so no dict is shared or mutated across instances
        (the previous mutable default argument was mutated in place here).
        """
        if custom_client is None:
            custom_client = dict(type=SGLangStreamClient)
        elif not isinstance(custom_client, dict):
            self.logger.warning(f"Value of custom_client: {custom_client} is not a dict! Use Default")
            custom_client = dict(type=SGLangStreamClient)
        else:
            # Copy so the caller's config dict is not mutated below.
            custom_client = dict(custom_client)
        custom_client['url'] = self.endpoint_url
        custom_client['retry'] = self.retry
        self.client = build_client_from_cfg(custom_client)

    def generate(self,
                 inputs: List[PromptType],
                 max_out_len: int = 512,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[PromptType]): A list of strings or PromptDicts.
                The PromptDict should be organized in AISBench's API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        # ThreadPoolExecutor requires max_workers >= 1; guard the empty-input
        # (or explicit batch_size=0) case, which would raise ValueError.
        batch_size = max(1, kwargs.get("batch_size", len(inputs)))
        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            results = list(
                tqdm(executor.map(self._generate, inputs,
                                  [max_out_len] * len(inputs)),
                     total=len(inputs),
                     desc='Inferencing'))
        return results

    @handle_synthetic_input
    def _generate(self, input: PromptType, max_out_len: int) -> str:
        """Generate a result for a single input.

        Args:
            input (PromptType): A string or PromptDict. The PromptDict should
                be organized in AISBench's API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string ('' when max_out_len <= 0).
        """
        if isinstance(input, dict):
            data_id = input.get('data_id')
            input = input.get('prompt')
        else:
            # Plain-string prompts carry no id; -1 marks "unknown".
            data_id = -1
        if max_out_len <= 0:
            return ''
        cache_data = self.prepare_input_data(input, data_id)
        generation_kwargs = dict(self.generation_kwargs)
        generation_kwargs["max_new_tokens"] = max_out_len

        response = self.client.request(cache_data, generation_kwargs)
        self.set_result(cache_data)

        # The streaming client yields text chunks; join into the final string.
        return ''.join(response)

    def _get_base_url(self):
        """Return "http(s)://<host_ip>:<host_port>/" per ``enable_ssl``."""
        scheme = "https" if self.enable_ssl else "http"
        return f"{scheme}://{self.host_ip}:{self.host_port}/"
def extract_success_item(self, requests: dict):
    """Filter per-request metric entries down to the successful requests.

    ``requests`` maps metric names to per-request lists aligned index-for-index
    with the boolean list under "is_success"; non-list entries describe the run
    as a whole. List entries are rewritten in place keeping only the positions
    whose request succeeded; scalar entries are kept only when at least one
    request succeeded (otherwise replaced with []). "is_success" itself is
    left untouched.

    Args:
        requests (dict): Raw per-request performance details.

    Returns:
        dict: The same dict, mutated in place.
    """
    is_success = requests.get("is_success", [])
    # any() is the evident intent; the previous truthiness test on the list
    # itself was True for a non-empty all-False list.
    any_success = any(is_success)
    for key, value in requests.items():
        if key == "is_success":
            continue
        if isinstance(value, list):
            # zip pairs each metric with its success flag and tolerates a
            # length mismatch instead of raising IndexError on is_success[i].
            requests[key] = [v for v, ok in zip(value, is_success) if ok]
        else:
            requests[key] = value if any_success else []
    return requests
Please check the ERROR log from every responses!") raise ValueError("All requests failed!") + self.extract_success_item(perf_details.get("requests", {})) self.stage_dict = { "stable": self._get_requests_id(perf_details) } -- Gitee