diff --git a/ais_bench/benchmark/calculators/base_perf_metric_calculator.py b/ais_bench/benchmark/calculators/base_perf_metric_calculator.py
index a50721bc937624d74f6538cfa0cdb3445f1e3e22..0d61881b3698d772b038904bf4677e8609ef40b8 100644
--- a/ais_bench/benchmark/calculators/base_perf_metric_calculator.py
+++ b/ais_bench/benchmark/calculators/base_perf_metric_calculator.py
@@ -24,4 +24,19 @@ class BasePerfMetricCalculator(ABC):
 
     @abstractmethod
     def calculate(self):
-        pass
\ No newline at end of file
+        pass
+
+    def extract_success_item(self, requests: dict):
+        """Filter every per-request field of ``requests`` down to successful items, in place."""
+        is_success = requests.get("is_success", [])
+        for key, value in requests.items():
+            if key == "is_success":
+                continue
+            if isinstance(value, list):
+                # zip pairs values with flags safely even if lengths ever diverge.
+                value = [v for v, ok in zip(value, is_success) if ok]
+            else:
+                # Non-list fields survive only when at least one request succeeded.
+                value = value if any(is_success) else []
+            requests[key] = value
+        return requests
diff --git a/ais_bench/benchmark/calculators/default_perf_metric_calculator.py b/ais_bench/benchmark/calculators/default_perf_metric_calculator.py
index fd17c8db96684626a78b559ce97ef096145b14ef..21cf08ffe8527b646576dc08a2b3fb083498df61 100644
--- a/ais_bench/benchmark/calculators/default_perf_metric_calculator.py
+++ b/ais_bench/benchmark/calculators/default_perf_metric_calculator.py
@@ -20,6 +20,7 @@ class DefaultPerfMetricCalculator(BasePerfMetricCalculator):
         if sum(perf_details["requests"]["is_success"]) == 0:
             self.logger.error("All requests failed, can't calculate performance results. Please check the ERROR log from every responses!")
             raise ValueError("All requests failed!")
+        self.extract_success_item(perf_details["requests"])
         self.stage_dict = {
             "total": self._get_requests_id(perf_details)
         }
diff --git a/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py b/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py
index 4dbe7f019bc17f3210dfb4e32bafec6611616efc..9e820a27ebc946a4b7df1cb8bd1857a1ce0ec12f 100644
--- a/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py
+++ b/ais_bench/benchmark/calculators/stable_perf_metric_calculator.py
@@ -24,6 +24,7 @@ class StablePerfMetricCalculator(BasePerfMetricCalculator):
         if sum(perf_details["requests"]["is_success"]) == 0:
             self.logger.error("All requests failed, can't calculate performance results. Please check the ERROR log from every responses!")
             raise ValueError("All requests failed!")
+        self.extract_success_item(perf_details["requests"])
         self.stage_dict = {
             "stable": self._get_requests_id(perf_details)
         }
diff --git a/ais_bench/benchmark/clients/__init__.py b/ais_bench/benchmark/clients/__init__.py
index 1a8a134491b951f81b478cf85c4ae7f813406f22..de29edbc390c467e804a2197e61d1fba55780092 100644
--- a/ais_bench/benchmark/clients/__init__.py
+++ b/ais_bench/benchmark/clients/__init__.py
@@ -11,4 +11,5 @@ from ais_bench.benchmark.clients.openai_chat_text_client import OpenAIChatTextCl
 from ais_bench.benchmark.clients.vllm_text_client import VLLMTextClient
 from ais_bench.benchmark.clients.openai_chat_stream_sglang_client import OpenAIChatStreamSglangClient
 from ais_bench.benchmark.clients.openai_prompt_chat_text_client import OpenAIPromptChatTextClient
-from ais_bench.benchmark.clients.openai_function_chat_text_client import OpenAIFunctionChatTextClient
\ No newline at end of file
+from ais_bench.benchmark.clients.openai_function_chat_text_client import OpenAIFunctionChatTextClient
+from ais_bench.benchmark.clients.sglang_stream_client import SGLangStreamClient
diff --git a/ais_bench/benchmark/clients/sglang_stream_client.py b/ais_bench/benchmark/clients/sglang_stream_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba01e78f46ccce3adbb39cbe1c1a7879ecbb8025
--- /dev/null
+++ b/ais_bench/benchmark/clients/sglang_stream_client.py
@@ -0,0 +1,41 @@
+from abc import ABC
+from typing import Optional
+
+from ais_bench.benchmark.clients.base_client import BaseStreamClient
+from ais_bench.benchmark.utils import MiddleData
+
+
+class SGLangStreamClient(BaseStreamClient, ABC):
+    """Stream client for SGLang's native /generate endpoint."""
+
+    def construct_request_body(
+        self,
+        inputs: str,
+        parameters: Optional[dict] = None,
+    ) -> dict:
+        return dict(text=inputs, stream=True, sampling_params=parameters)
+
+    def process_stream_line(self, json_content: dict) -> dict:
+        """Pick generated text and token counts out of one stream chunk."""
+        response = {}
+        generated_text = json_content.get("text", None)
+        if generated_text:
+            response.update({"generated_text": generated_text})
+        if json_content.get("meta_info"):
+            response.update({"completion_tokens": json_content["meta_info"].get("completion_tokens", 0)})
+        return response
+
+    def update_middle_data(self, res: dict, inputs: MiddleData):
+        """Fold one parsed chunk into the request's middle data."""
+        generated_text = res.get("generated_text", "")
+        if generated_text:
+            inputs.output = generated_text
+            inputs.num_generated_chars = len(generated_text)
+        prefill_time = res.get("prefill_time")
+        if prefill_time:
+            inputs.prefill_latency = prefill_time
+        decode_time = res.get("decode_time")
+        if decode_time:
+            inputs.decode_cost.append(decode_time)
+        if res.get("completion_tokens"):
+            inputs.num_generated_tokens = res.get("completion_tokens")
diff --git a/ais_bench/benchmark/configs/models/sglang_api/sglang_stream_api_general.py b/ais_bench/benchmark/configs/models/sglang_api/sglang_stream_api_general.py
new file mode 100644
index 0000000000000000000000000000000000000000..29892c206ba194e8f168166e25fc5a93f547da21
--- /dev/null
+++ b/ais_bench/benchmark/configs/models/sglang_api/sglang_stream_api_general.py
@@ -0,0 +1,20 @@
+from ais_bench.benchmark.models import SGLangCustomAPIStream
+
+models = [
+    dict(
+        attr="service",
+        type=SGLangCustomAPIStream,
+        abbr="sglang-stream-api-general",
+        path="",
+        request_rate=0,
+        retry=2,
+        host_ip="localhost",
+        host_port=8080,
+        max_out_len=512,
+        batch_size=1,
+        trust_remote_code=False,
+        generation_kwargs=dict(
+            temperature=0.5,
+        ),
+    )
+]
diff --git a/ais_bench/benchmark/models/__init__.py b/ais_bench/benchmark/models/__init__.py
index b95b5e17ce65631a5d621e6ec264826e5da95b0c..d82b89d9d729d1a29869f88bea9d1211c1f0dc0c 100644
--- a/ais_bench/benchmark/models/__init__.py
+++ b/ais_bench/benchmark/models/__init__.py
@@ -8,4 +8,5 @@ from ais_bench.benchmark.models.huggingface_above_v4_33 import HuggingFaceBaseMo
 from ais_bench.benchmark.models.tgi_api import TGICustomAPI, TGICustomAPIStream
 from ais_bench.benchmark.models.triton_api import TritonCustomAPI, TritonCustomAPIStream
 from ais_bench.benchmark.models.vllm_custom_api_chat_multiturn import VllmMultiturnAPIChatStream
-from ais_bench.benchmark.models.vllm_function_call_api_chat import VLLMFunctionCallAPIChat
\ No newline at end of file
+from ais_bench.benchmark.models.vllm_function_call_api_chat import VLLMFunctionCallAPIChat
+from ais_bench.benchmark.models.sglang_api import SGLangCustomAPIStream
diff --git a/ais_bench/benchmark/models/sglang_api.py b/ais_bench/benchmark/models/sglang_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..61be3484c6aaf9dccaf7cb0c864c434f629c5f23
--- /dev/null
+++ b/ais_bench/benchmark/models/sglang_api.py
@@ -0,0 +1,142 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from tqdm import tqdm
+
+from ais_bench.benchmark.registry import MODELS
+from ais_bench.benchmark.utils.prompt import PromptList
+
+from ais_bench.benchmark.models.base_api import handle_synthetic_input
+from ais_bench.benchmark.models.performance_api import PerformanceAPIModel
+from ais_bench.benchmark.clients import SGLangStreamClient
+from ais_bench.benchmark.utils.build import build_client_from_cfg
+
+PromptType = Union[PromptList, str, dict]
+
+
+@MODELS.register_module()
+class SGLangCustomAPIStream(PerformanceAPIModel):
+    """Model wrapper around SGLang's /generate endpoint.
+
+    Args:
+        max_seq_len (int): The maximum allowed sequence length of a model.
+            Note that the length of prompt + generated tokens shall not exceed
+            this value. Defaults to 2048.
+        request_rate (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 1.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case the requirement of injecting or
+            wrapping of any meta instructions.
+        host_ip (str): The host ip of custom service, default "localhost".
+        host_port (int): The host port of custom service, default "8080".
+        enable_ssl (bool): Use https instead of http. Defaults to False.
+    """
+
+    is_api: bool = True
+
+    def __init__(self,
+                 path: str = "",
+                 max_seq_len: int = 4096,
+                 request_rate: int = 1,
+                 rpm_verbose: bool = False,
+                 retry: int = 2,
+                 meta_template: Optional[Dict] = None,
+                 verbose: bool = False,
+                 host_ip: str = "localhost",
+                 host_port: int = 8080,
+                 enable_ssl: bool = False,
+                 custom_client: Optional[Dict] = None,
+                 generation_kwargs: Optional[Dict] = None,
+                 trust_remote_code: bool = False):
+        super().__init__(path=path,
+                         max_seq_len=max_seq_len,
+                         meta_template=meta_template,
+                         request_rate=request_rate,
+                         rpm_verbose=rpm_verbose,
+                         retry=retry,
+                         verbose=verbose,
+                         generation_kwargs=generation_kwargs,
+                         trust_remote_code=trust_remote_code)
+        self.host_ip = host_ip
+        self.host_port = host_port
+        self.enable_ssl = enable_ssl
+        self.base_url = self._get_base_url()
+        # Guard against None so _generate can safely call .copy().
+        self.generation_kwargs = generation_kwargs or {}
+        self.endpoint_url = self.base_url + "generate"
+        self.init_client(custom_client)
+
+    def init_client(self, custom_client):
+        # None (the default) falls back to the stock stream client; a copy
+        # is mutated so a caller-supplied dict is never modified in place.
+        if not isinstance(custom_client, dict):
+            if custom_client is not None:
+                self.logger.warning(f"Value of custom_client: {custom_client} is not a dict! Use Default")
+            custom_client = dict(type=SGLangStreamClient)
+        else:
+            custom_client = dict(custom_client)
+        custom_client['url'] = self.endpoint_url
+        custom_client['retry'] = self.retry
+        self.client = build_client_from_cfg(custom_client)
+
+    def generate(self,
+                 inputs: List[PromptType],
+                 max_out_len: int = 512,
+                 **kwargs) -> List[str]:
+        """Generate results given a list of inputs.
+
+        Args:
+            inputs (List[PromptType]): A list of strings or PromptDicts.
+                The PromptDict should be organized in AISBench's
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            List[str]: A list of generated strings.
+        """
+        if not inputs:
+            # ThreadPoolExecutor rejects max_workers=0.
+            return []
+        batch_size = kwargs.get("batch_size", len(inputs))
+        with ThreadPoolExecutor(max_workers=batch_size) as executor:
+            results = list(
+                tqdm(executor.map(self._generate, inputs,
+                                  [max_out_len] * len(inputs)),
+                     total=len(inputs),
+                     desc='Inferencing'))
+        return results
+
+    @handle_synthetic_input
+    def _generate(self, input: PromptType, max_out_len: int) -> str:
+        """Generate result given a input.
+
+        Args:
+            input (PromptType): A string or PromptDict.
+                The PromptDict should be organized in AISBench's
+                API format.
+            max_out_len (int): The maximum length of the output.
+
+        Returns:
+            str: The generated string.
+        """
+        if isinstance(input, dict):
+            data_id = input.get('data_id')
+            input = input.get('prompt')
+        else:
+            data_id = -1
+        if max_out_len <= 0:
+            return ''
+        cache_data = self.prepare_input_data(input, data_id)
+        generation_kwargs = self.generation_kwargs.copy()
+        generation_kwargs.update({"max_new_tokens": max_out_len})
+
+        response = self.client.request(cache_data, generation_kwargs)
+        self.set_result(cache_data)
+
+        return ''.join(response)
+
+    def _get_base_url(self):
+        # Trailing slash so endpoint paths can be appended directly.
+        scheme = "https" if self.enable_ssl else "http"
+        return f"{scheme}://{self.host_ip}:{self.host_port}/"