diff --git a/RAGSDK/opensource/tei/Dockerfile b/RAGSDK/opensource/tei/Dockerfile
index 1aa5b99165f2daee496ef0aaa750cd9f6c423157..8ff9b65aa350c435fb22c29155d20f8a61ecb61a 100644
--- a/RAGSDK/opensource/tei/Dockerfile
+++ b/RAGSDK/opensource/tei/Dockerfile
@@ -74,7 +74,7 @@ RUN curl https://sh.rustup.rs -sSf | bash -s -- -y --no-update-default-toolchain
 #安装 TEI
 WORKDIR /home/HwHiAiUser
-RUN git clone https://github.com/huggingface/text-embeddings-inference.git && cd text-embeddings-inference && git checkout v1.6.1
+RUN git clone https://github.com/huggingface/text-embeddings-inference.git && cd text-embeddings-inference && git checkout v1.7.4
 COPY ./package/tei.patch /tmp
 RUN cd /home/HwHiAiUser/text-embeddings-inference && patch -p1 < /tmp/tei.patch
 #RUN sed -i 's/channel = .*/channel = "1.83.0"/g' /home/HwHiAiUser/text-embeddings-inference/rust-toolchain.toml
diff --git a/RAGSDK/opensource/tei/tei.patch b/RAGSDK/opensource/tei/tei.patch
index a730f291f135008911c59bbcfe865f3d5457820d..f3cb7aa8855a9592b576c91deceff4e6744633e6 100644
--- a/RAGSDK/opensource/tei/tei.patch
+++ b/RAGSDK/opensource/tei/tei.patch
@@ -87,25 +87,8 @@ index 6402d63..8ad0028 100644
 gen-server:
 	# Compile protos
 	pip install grpcio-tools==1.62.2 mypy-protobuf==3.6.0 'types-protobuf' --no-cache-dir
-diff --git a/backends/python/server/pyproject.toml b/backends/python/server/pyproject.toml
-index 0654eb7..46c3ca2 100644
---- a/backends/python/server/pyproject.toml
-+++ b/backends/python/server/pyproject.toml
-@@ -29,9 +29,9 @@ grpcio-tools = "^1.51.1"
- pytest = "^7.3.0"
-
- [[tool.poetry.source]]
--name = "pytorch-gpu-src"
--url = "https://download.pytorch.org/whl/cu118"
--priority = "explicit"
-+name = "mirrors"
-+url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
-+priority = "default"
-
- [tool.pytest.ini_options]
- markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"]
 diff --git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt
-index 687ec10..79cee7a 100644
+index 687ec10..e46f148 100644
 --- a/backends/python/server/requirements.txt
 +++ b/backends/python/server/requirements.txt
 @@ -6,10 +6,10 @@ deprecated==1.2.15 ; python_version >= "3.9" and python_version < "3.13"
@@ -132,12 +115,8 @@ index 687ec10..79cee7a 100644
 nvidia-cublas-cu12==12.4.5.8 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
 nvidia-cuda-cupti-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
 nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.9" and python_version < "3.13"
-@@ -48,17 +48,17 @@ protobuf==4.25.5 ; python_version >= "3.9" and python_version < "3.13"
- pyyaml==6.0.2 ; python_version >= "3.9" and python_version < "3.13"
- regex==2024.11.6 ; python_version >= "3.9" and python_version < "3.13"
- requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
--safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
-+safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
+@@ -51,14 +51,14 @@ requests==2.32.3 ; python_version >= "3.9" and python_version < "3.13"
+ safetensors==0.4.5 ; python_version >= "3.9" and python_version < "3.13"
 scikit-learn==1.5.2 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
-sentence-transformers==3.3.1 ; python_version >= "3.9" and python_version < "3.13"
@@ -155,11 +134,19 @@ index 687ec10..79cee7a 100644
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.13"
 diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py
-index 9f56065..0e47676 100644
+index 1e919f2..90d9487 100644
 --- a/backends/python/server/text_embeddings_server/models/__init__.py
 +++ b/backends/python/server/text_embeddings_server/models/__init__.py
-@@ -13,22 +13,16 @@ from text_embeddings_server.models.default_model import DefaultModel
+@@ -11,31 +11,21 @@ from text_embeddings_server.models.model import Model
+ from text_embeddings_server.models.masked_model import MaskedLanguageModel
+ from text_embeddings_server.models.default_model import DefaultModel
 from text_embeddings_server.models.classification_model import ClassificationModel
+-from text_embeddings_server.models.jinaBert_model import FlashJinaBert
+-from text_embeddings_server.models.flash_mistral import FlashMistral
+-from text_embeddings_server.models.flash_qwen3 import FlashQwen3
++from text_embeddings_server.models.qwen3_rerank_model import Qwen3RerankModel
++from text_embeddings_server.models.unixcoder_model import UniXcoderModel
++
 from text_embeddings_server.utils.device import get_device, use_ipex
+from modeling_bert_adapter import enable_bert_speed

 __all__ = ["Model"]

 TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
+-DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
+-    "true",
+-    "1",
+-]
 # Disable gradients
 torch.set_grad_enabled(False)

-FLASH_ATTENTION = True
-try:
-    from text_embeddings_server.models.flash_bert import FlashBert
-except ImportError as e:
-    logger.warning(f"Could not import Flash Attention enabled models: {e}")
-    FLASH_ATTENTION = False
-
-if FLASH_ATTENTION:
-    __all__.append(FlashBert)
-
-
 def wrap_model_if_hpu(model_handle, device):
     """Wrap the model in HPU graph if the device is HPU."""
+@@ -70,22 +60,21 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
 else:
     raise RuntimeError(f"Unknown dtype {dtype}")

+    device = get_device()
+    enable_bert_speed(device)
 logger.info(f"backend device: {device}")

 config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

-    if (
-        hasattr(config, "auto_map")
-        and isinstance(config.auto_map, dict)
-        and "AutoModel" in config.auto_map
-        and config.auto_map["AutoModel"]
-        == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
-    ):
-        # Add specific offline modeling for model "jinaai/jina-embeddings-v2-base-code" which uses "autoMap" to reference code in other repository
-        return create_model(FlashJinaBert, model_path, device, datatype)
-
-    if config.model_type == "bert":
+    if config.model_type == "bert" or config.model_type == "qwen2" or config.model_type == "qwen3" or \
+        config.model_type == "roberta" or config.model_type == "xlm-roberta":
 config: BertConfig
 if (
     use_ipex()
@@ -111,8 +100,12 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):

 if config.architectures[0].endswith("Classification"):
     return create_model(ClassificationModel, model_path, device, datatype)
++        elif os.getenv("IS_RERANK", None):
++            return create_model(Qwen3RerankModel, model_path, device, datatype)
 elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
     return create_model(MaskedLanguageModel, model_path, device, datatype)
++        elif str(model_path).endswith("unixcoder-base"):
++            return create_model(UniXcoderModel, model_path, device, datatype)
 else:
     return create_model(DefaultModel, model_path, device, datatype, pool)

diff --git a/backends/python/server/text_embeddings_server/models/classification_model.py b/backends/python/server/text_embeddings_server/models/classification_model.py
index 91f33ed..a1395ef 100644
--- a/backends/python/server/text_embeddings_server/models/classification_model.py
+++ b/backends/python/server/text_embeddings_server/models/classification_model.py
@@ -70,3 +70,4 @@ class ClassificationModel(Model):
 output = self.model(**kwargs, return_dict=True)
 all_scores = output.logits.tolist()
 return [Score(values=scores) for scores in all_scores]
++
diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py
-index f5c569f..20c13b3 100644
+index 66a9b6b..c9385e3 100644
 --- a/backends/python/server/text_embeddings_server/models/default_model.py
 +++ b/backends/python/server/text_embeddings_server/models/default_model.py
 @@ -1,17 +1,24 @@
 import inspect
 import torch
+import os

 from pathlib import Path
 from typing import Type, List
 from transformers import AutoModel
 from opentelemetry import trace
+from sentence_transformers.models import Pooling
+from loguru import logger

 from text_embeddings_server.models import Model
 from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+from text_embeddings_server.models.pooling import DefaultPooling, SpladePooling

 tracer = trace.get_tracer(__name__)

 class DefaultModel(Model):
     def __init__(
+@@ -19,16 +26,18 @@ class DefaultModel(Model):
 model_path: Path,
 device: torch.device,
 dtype: torch.dtype,
-        pool: str = "cls",
+        pool: str,
 trust_remote: bool = False,
     ):
 model = (
     AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote)
     .to(dtype)
     .to(device)
+    .eval()
 )
+        self.pool_method = pool
+        self.hidden_size = model.config.hidden_size

 super(DefaultModel, self).__init__(model=model, dtype=dtype, device=device)

 @property
+@@ -63,19 +84,76 @@ class DefaultModel(Model):
 kwargs["token_type_ids"] = batch.token_type_ids
 if self.has_position_ids:
     kwargs["position_ids"] = batch.position_ids
diff --git a/backends/python/server/text_embeddings_server/models/pooling.py b/backends/python/server/text_embeddings_server/models/pooling.py
index 43f77b1..69c69f9 100644
--- a/backends/python/server/text_embeddings_server/models/pooling.py
+++ b/backends/python/server/text_embeddings_server/models/pooling.py
 }
 return self.pooling.forward(pooling_features)["sentence_embedding"]
diff --git a/backends/python/server/text_embeddings_server/models/qwen3_rerank_model.py b/backends/python/server/text_embeddings_server/models/qwen3_rerank_model.py
new file mode 100644
index 0000000..fb009f6
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/qwen3_rerank_model.py
@@ -0,0 +1,97 @@
++import inspect
++import torch
++import os
++
++from pathlib import Path
++from typing import Type, List
++from transformers import AutoModelForSequenceClassification, Qwen3ForCausalLM, AutoTokenizer, AutoModelForCausalLM
++from opentelemetry import trace
++
++from text_embeddings_server.models import Model
++from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
++
++tracer = trace.get_tracer(__name__)
++
++
++class Qwen3RerankModel(Model):
++    def __init__(
++        self,
++        model_path: Path,
++        device: torch.device,
++        dtype: torch.dtype,
++        pool: str = "cls",
++        trust_remote: bool = False,
++    ):
++
++        # Check environment variable to decide reranker mode
++        self.qwen3_mode = os.environ.get("IS_RERANK", "0") == "1"
++        position_offset = 0
++        # Load tokenizer (for both modes)
++        self.tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left",
++                                                       trust_remote_code=trust_remote)
++        # Set a pad_token for Qwen3 models so batched inference works
++        if self.tokenizer.pad_token is None:
++            self.tokenizer.pad_token = self.tokenizer.eos_token
++        # -------------------------------
++        # Qwen3-Reranker initialization logic
++        # -------------------------------
++        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=trust_remote)
++        self.model = self.model.to(dtype).to(device).eval()
++        self.model.config.pad_token_id = self.tokenizer.pad_token_id
++        # Used to extract the "yes" / "no" scores from the logits
++        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
++        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
++        prefix = "<|im_start|>system\nDetermine whether the problem is related to the document. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
++        suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
++        self.prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
++        self.suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
++
++        if hasattr(self.model.config, "max_seq_length"):
++            self.max_input_length = self.model.config.max_seq_length
++        else:
++            self.max_input_length = (
++                self.model.config.max_position_embeddings - position_offset
++            )
++
++        super(Qwen3RerankModel, self).__init__(
++            model=self.model, dtype=dtype, device=device
++        )
++
++
++    @property
++    def batch_type(self) -> Type[PaddedBatch]:
++        return PaddedBatch
++
++    @tracer.start_as_current_span("embed")
++    def embed(self, batch: PaddedBatch) -> List[Embedding]:
++        pass
++
++    @tracer.start_as_current_span("predict")
++    def predict(self, batch: PaddedBatch) -> List[Score]:
++        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
++        # Qwen3-Reranker scoring logic (prompt construction + yes/no logits)
++        input_ids = []
++        for ele in batch.input_ids:
++            ids = self.prefix_tokens + ele.tolist() + self.suffix_tokens
++            input_ids.append(ids)
++
++        # Note: a list is passed in here, not a tensor
++        tokenized = self.tokenizer.pad(
++            {"input_ids": input_ids},
++            padding=True,
++            return_tensors="pt",
++            max_length=self.max_input_length
++        )
++        inputs = {k: v.to(self.model.device) for k, v in tokenized.items()}
++
++        with torch.no_grad():
++            outputs = self.model(**inputs)
++            logits = outputs.logits[:, -1, :]  # prediction distribution of the last token
++
++        # Extract the "yes" / "no" probabilities as the relevance judgment
++        true_logits = logits[:, self.token_true_id]
++        false_logits = logits[:, self.token_false_id]
++        logit_diff = true_logits - false_logits
++        return [Score(values=[p.item()]) for p in logit_diff]
++
++
diff --git a/backends/python/server/text_embeddings_server/models/types.py b/backends/python/server/text_embeddings_server/models/types.py
-index 4f2cfa4..fbd8c2d 100644
+index f27572a..71db970 100644
 --- a/backends/python/server/text_embeddings_server/models/types.py
 +++ b/backends/python/server/text_embeddings_server/models/types.py
 @@ -7,7 +7,7 @@ from dataclasses import dataclass

 tracer = trace.get_tracer(__name__)
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
-@@ -34,6 +34,7 @@ class PaddedBatch(Batch):
+@@ -36,6 +36,7 @@ class PaddedBatch(Batch):
 token_type_ids: torch.Tensor
 position_ids: torch.Tensor
 attention_mask: torch.Tensor
+    past_key_values: torch.Tensor

 @classmethod
 @tracer.start_as_current_span("from_pb")
-@@ -78,6 +79,7 @@ class PaddedBatch(Batch):
+@@ -82,6 +83,7 @@ class PaddedBatch(Batch):
 token_type_ids=all_tensors[1],
 position_ids=all_tensors[2],
 attention_mask=all_tensors[3],
+    past_key_values=all_tensors[3],
 )

 def __len__(self):
diff --git a/backends/python/server/text_embeddings_server/models/unixcoder_model.py b/backends/python/server/text_embeddings_server/models/unixcoder_model.py
new file mode 100644
index 0000000..ea95871
--- /dev/null
+++ b/backends/python/server/text_embeddings_server/models/unixcoder_model.py
@@ -0,0 +1,88 @@
++import os
++import inspect
++import torch
++
++from pathlib import Path
++from typing import Type, List
++from transformers import RobertaModel, RobertaConfig, RobertaTokenizer
++from opentelemetry import trace
++
++from collections import defaultdict
++import numpy as np
++from loguru import logger
++
++from text_embeddings_server.models.pooling import DefaultPooling
++
++from text_embeddings_server.models import Model
++from text_embeddings_server.models.types import PaddedBatch, Embedding, Score, TokenEmbedding
++
++tracer = trace.get_tracer(__name__)
++
++is_unixcode = os.getenv("IS_UNIXCODE", None)
++
++class UniXcoderModel(Model):
++    def __init__(
++        self,
++        model_path: Path,
++        device: torch.device,
++        dtype: torch.dtype,
++        pool: str,
++        trust_remote: bool = False,
++    ):
++        self.config = RobertaConfig.from_pretrained(model_path)
++        self.config.is_decoder = True
++        model = (
++            RobertaModel.from_pretrained(model_path, config=self.config)
++            .to(dtype)
++            .to(device)
++            .eval()
++        )
++
++        self.hidden_size = model.config.hidden_size
++
++        position_offset = 0
++        model_type = model.config.model_type
++        if model_type in ["xlm-roberta", "camembert", "roberta"]:
++            position_offset = model.config.pad_token_id + 1
++        if hasattr(model.config, "max_seq_length"):
++            self.max_input_length = model.config.max_seq_length
++        else:
++            self.max_input_length = (
++                model.config.max_position_embeddings - position_offset
++            )
++
++        self.tokenizer = RobertaTokenizer.from_pretrained(model_path, local_files_only=True)
++
++        super(UniXcoderModel, self).__init__(model=model, dtype=dtype, device=device)
++
++    @property
++    def batch_type(self) -> Type[PaddedBatch]:
++        return PaddedBatch
++
++    @tracer.start_as_current_span("embed")
++    def embed(self, batch: PaddedBatch) -> List[Embedding]:
++        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
++        return self._process_unixcode(batch, kwargs)
++
++    def _process_unixcode(self, batch: PaddedBatch, kwargs: dict):
++        tokens_ids = []
++        mode = ""
++        mode_id = self.tokenizer.convert_tokens_to_ids(mode)
++        for tokens_id in batch.input_ids:
++            tokens_id = tokens_id[:self.max_input_length - 4]
++            tokens_id = tokens_id.tolist()[1:-1]
++            tokens_id = [self.tokenizer.cls_token_id, mode_id, self.tokenizer.sep_token_id] + tokens_id + [self.tokenizer.sep_token_id]
++
++            tokens_ids.append(tokens_id)
++        tokens_ids = torch.tensor(tokens_ids).to(self.device)
++        mask = tokens_ids.ne(self.config.pad_token_id)
++        attention_mask=mask.unsqueeze(1) * mask.unsqueeze(2).to(self.device)
++        token_embeddings = self.model(tokens_ids, attention_mask=attention_mask)[0]
++        sentence_embeddings = (token_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)
++        return [
++            Embedding(
++                values=sentence_embeddings[i]
++            )
++            for i in range(len(batch))
++        ]
++
diff --git a/backends/python/server/text_embeddings_server/server.py b/backends/python/server/text_embeddings_server/server.py
index 646d79b..da8f6a1 100644
--- a/backends/python/server/text_embeddings_server/server.py
+++ b/backends/python/server/text_embeddings_server/server.py
 )
 embed_pb2_grpc.add_EmbeddingServiceServicer_to_server(
diff --git a/backends/python/server/text_embeddings_server/utils/device.py b/backends/python/server/text_embeddings_server/utils/device.py
-index 3f3b04d..2168cf6 100644
+index 3f3b04d..04a25b3 100644
 --- a/backends/python/server/text_embeddings_server/utils/device.py
 +++ b/backends/python/server/text_embeddings_server/utils/device.py
 @@ -4,6 +4,7 @@ import importlib.metadata
 import subprocess

 ALLOW_REDUCED_PRECISION = os.getenv(
+@@ -59,6 +60,14 @@ def get_device():
 device = torch.device("cpu")
 if torch.cuda.is_available():
     device = torch.device("cuda")
index 53255b0..4d24016 100644
 }
diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs
-index 3b08e92..960148e 100644
+index 80292be..253f322 100644
 --- a/backends/src/dtype.rs
 +++ b/backends/src/dtype.rs
-@@ -59,7 +59,7 @@ impl Default for DType {
+@@ -53,7 +53,7 @@ impl Default for DType {
 }
 #[cfg(feature = "python")]
 {
index 3fd8b77..61b38dc 100644
 Some((
     metadata,
     Batch {
diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs
index 7636afa..ab9a222 100644
--- a/core/src/tokenization.rs
+++ b/core/src/tokenization.rs
@@ -293,7 +293,7 @@ fn tokenize_input(
 prompts: Option<&HashMap>,
 tokenizer: &mut Tokenizer,
 ) -> Result<(Option, RawEncoding), TextEmbeddingsError> {
-    let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?;
+    let pre_prompt = prepare_pre_prompt(default_prompt.clone(), prompt_name, prompts)?;

 let input_chars = inputs.count_chars();
 let limit = max_input_length * MAX_CHAR_MULTIPLIER;
@@ -322,20 +322,41 @@ fn tokenize_input(

 (Some(s), encoding)
 }
+
 EncodingInput::Dual(s1, s2) => {
-            if pre_prompt.is_some() {
+            let is_rerank = std::env::var("IS_RERANK").ok().as_deref() == Some("1");
+
+            if is_rerank {
+                let default_prompt = default_prompt.ok_or_else(|| {
+                    TextEmbeddingsError::Validation(
+                        "In rerank mode, `--default-prompt` must be set.".to_string(),
+                    )
+                })?;
+
+                let prompt = default_prompt
+                    .replace("\\n", "\n")
+                    .replace("", &s1)
+                    .replace("", &s2);
+
+                let encoding = tokenizer
+                    .with_truncation(truncate_params)?
+                    .encode::<&str>(&prompt, add_special_tokens)?;
+
+                (Some(prompt), encoding)
+            } else if pre_prompt.is_some() {
 return Err(TextEmbeddingsError::Validation(
     "`prompt_name` cannot be set with dual inputs".to_string(),
 ));
+            } else {
+                (
+                    None,
+                    tokenizer
+                        .with_truncation(truncate_params)?
+                        .encode::<(String, String)>((s1, s2), add_special_tokens)?,
+                )
 }
-
-            (
-                None,
-                tokenizer
-                    .with_truncation(truncate_params)?
-                    .encode::<(String, String)>((s1, s2), add_special_tokens)?,
-            )
 }
+
 // input is encoded -> convert to tokenizers Encoding
 EncodingInput::Ids(ids) => {
     if let Some(mut pre_prompt) = pre_prompt {
diff --git a/router/src/http/server.rs b/router/src/http/server.rs
index f805744..1a5244a 100644
--- a/router/src/http/server.rs
+++ b/router/src/http/server.rs
@@ -7,7 +7,7 @@ use crate::http::types::{
 RerankRequest, RerankResponse, Sequence, SimilarityInput, SimilarityParameters,
 SimilarityRequest, SimilarityResponse, SimpleToken, SparseValue, TokenizeInput,
 TokenizeRequest, TokenizeResponse, TruncationDirection, VertexPrediction, VertexRequest,
-    VertexResponse,
+    VertexResponse, OpenAICompatRerankRequest, OpenAICompatRerankResponse, OpenAIRank, DocumentWrapper
 };
 use crate::{
 logging, shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -1296,6 +1296,207 @@ async fn openai_embed(
 Ok((headers, Json(response)))
 }

+
+/// Get Ranks. Returns a 424 status code if the model is not a Sequence Classification model with
+/// a single class.
+#[utoipa::path(
+    post,
+    tag = "Text Embeddings Inference",
+    path = "/v1/rerank",
+    request_body = OpenAICompatRerankRequest,
+    responses(
+    (status = 200, description = "Ranks", body = OpenAICompatRerankResponse),
+    (status = 424, description = "Rerank Error", body = ErrorResponse,
+    example = json ! ({"error": "Inference failed", "error_type": "backend"})),
+    (status = 429, description = "Model is overloaded", body = ErrorResponse,
+    example = json ! ({"error": "Model is overloaded", "error_type": "overloaded"})),
+    (status = 422, description = "Tokenization error", body = ErrorResponse,
+    example = json ! ({"error": "Tokenization error", "error_type": "tokenizer"})),
+    (status = 400, description = "Batch is empty", body = ErrorResponse,
+    example = json ! ({"error": "Batch is empty", "error_type": "empty"})),
+    (status = 413, description = "Batch size error", body = ErrorResponse,
+    example = json ! ({"error": "Batch size error", "error_type": "validation"})),
+    )
+ )]
+ #[instrument(
+    skip_all,
+    fields(total_time, tokenization_time, queue_time, inference_time,)
+ )]
+ async fn openai_rerank(
+    infer: Extension,
+    info: Extension,
+    Extension(context): Extension>,
+    Json(req): Json,
+ ) -> Result<(HeaderMap, Json), (StatusCode, Json)> {
+    let span = tracing::Span::current();
+    if let Some(context) = context {
+        span.set_parent(context);
+    }
+
+    let start_time = Instant::now();
+
+    if req.documents.is_empty() {
+        let message = "`documents` cannot be empty".to_string();
+        tracing::error!("{message}");
+        let err = ErrorResponse {
+            error: message,
+            error_type: ErrorType::Empty,
+        };
+        let counter = metrics::counter!("te_request_failure", "err" => "validation");
+        counter.increment(1);
+        Err(err)?;
+    }
+
+    match &info.model_type {
+        ModelType::Reranker(_) => Ok(()),
+        ModelType::Classifier(_) | ModelType::Embedding(_) => {
+            let counter = metrics::counter!("te_request_failure", "err" => "model_type");
+            counter.increment(1);
+            let message = "model is not a re-ranker model".to_string();
+            Err(TextEmbeddingsError::Backend(BackendError::Inference(
+                message,
+            )))
+        }
+    }
+    .map_err(|err| {
+        tracing::error!("{err}");
+        ErrorResponse::from(err)
+    })?;
+
+    // Closure for rerank
+    let rerank_inner = move |query: String, document: String, truncate: bool, infer: Infer| async move {
+        let permit = infer.acquire_permit().await;
+
+        let response = infer
+            .predict(
+                (query, document),
+                truncate,
+                req.truncation_direction.into(),
+                req.raw_scores,
+                permit,
+            )
+            .await
+            .map_err(ErrorResponse::from)?;
+
+        let score = response.results[0];
+
+        Ok::<(usize, Duration, Duration, Duration, f32), ErrorResponse>((
+            response.metadata.prompt_tokens,
+            response.metadata.tokenization,
+            response.metadata.queue,
+            response.metadata.inference,
+            score,
+        ))
+    };
+
+    let truncate = req.truncate.unwrap_or(info.auto_truncate);
+
+    let (response, metadata) = {
+        let counter = metrics::counter!("te_request_count", "method" => "batch");
+        counter.increment(1);
+
+        let batch_size = req.documents.len();
+        if batch_size > info.max_client_batch_size {
+            let message = format!(
+                "batch size {batch_size} > maximum allowed batch size {}",
+                info.max_client_batch_size
+            );
+            tracing::error!("{message}");
+            let err = ErrorResponse {
+                error: message,
+                error_type: ErrorType::Validation,
+            };
+            let counter = metrics::counter!("te_request_failure", "err" => "batch_size");
+            counter.increment(1);
+            Err(err)?;
+        }
+
+        let mut futures = Vec::with_capacity(batch_size);
+        let query_chars = req.query.chars().count();
+        let mut compute_chars = query_chars * batch_size;
+
+        for document in &req.documents {
+            compute_chars += document.chars().count();
+            let local_infer = infer.clone();
+            futures.push(rerank_inner(
+                req.query.clone(),
+                document.clone(),
+                truncate,
+                local_infer.0,
+            ))
+        }
+        let res = join_all(futures)
+            .await
+            .into_iter()
+            .collect::, ErrorResponse>>()?;
+
+        let mut results = Vec::with_capacity(batch_size);
+        let mut total_tokenization_time = 0;
+        let mut total_queue_time = 0;
+        let mut total_inference_time = 0;
+        let mut total_compute_tokens = 0;
+
+        for (index, r) in res.into_iter().enumerate() {
+            total_compute_tokens += r.0;
+            total_tokenization_time += r.1.as_nanos() as u64;
+            total_queue_time += r.2.as_nanos() as u64;
+            total_inference_time += r.3.as_nanos() as u64;
+            let document = if req.return_documents {
+                Some(DocumentWrapper {
+                    text: req.documents[index].clone(),
+                })
+            } else {
+                None
+            };
+
+            let relevance_score = r.4;
+            // Check that s is not NaN or the partial_cmp below will panic
+            if relevance_score.is_nan() {
+                Err(ErrorResponse {
+                    error: "score is NaN".to_string(),
+                    error_type: ErrorType::Backend,
+                })?;
+            }
+
+            results.push(OpenAIRank { index, document, relevance_score })
+        }
+
+        // Reverse sort
+        results.sort_by(|x, y| x.relevance_score.partial_cmp(&y.relevance_score).unwrap());
+        results.reverse();
+
+        if let Some(top_n) = req.top_n {
+            results.truncate(top_n);
+        }
+
+        let batch_size = batch_size as u64;
+
+        let counter = metrics::counter!("te_request_success", "method" => "batch");
+        counter.increment(1);
+
+        (
+            OpenAICompatRerankResponse{results},
+            ResponseMetadata::new(
+                compute_chars,
+                total_compute_tokens,
+                start_time,
+                Duration::from_nanos(total_tokenization_time / batch_size),
+                Duration::from_nanos(total_queue_time / batch_size),
+                Duration::from_nanos(total_inference_time / batch_size),
+            ),
+        )
+    };
+
+    metadata.record_span(&span);
+    metadata.record_metrics();
+
+    let headers = HeaderMap::from(metadata);
+
+    tracing::info!("Success");
+
+    Ok((headers, Json(response)))
+ }
+
 /// Tokenize inputs
 #[utoipa::path(
 post,
@@ -1641,7 +1842,10 @@ pub async fn run(
 EmbedSparseResponse,
 RerankRequest,
 Rank,
+    OpenAIRank,
+    DocumentWrapper,
 RerankResponse,
+    OpenAICompatRerankResponse,
 EmbedRequest,
 EmbedResponse,
 ErrorResponse,
@@ -1733,6 +1937,7 @@ pub async fn run(
 .route("/embed_sparse", post(embed_sparse))
 .route("/predict", post(predict))
 .route("/rerank", post(rerank))
+    .route("/v1/rerank", post(openai_rerank))
 .route("/similarity", post(similarity))
 .route("/tokenize", post(tokenize))
 .route("/decode", post(decode))
-@@ -1785,8 +1785,7 @@ pub async fn run(
+@@ -1825,8 +2030,7 @@ pub async fn run(
 routes = routes.layer(axum::middleware::from_fn(auth));
 }

 .merge(routes)
 .merge(public_routes)
 .layer(Extension(infer))
-@@ -1796,6 +1795,14 @@ pub async fn run(
+@@ -1839,6 +2043,14 @@ pub async fn run(
 .layer(DefaultBodyLimit::max(payload_limit))
 .layer(cors_layer);

 // Run server
 let listener = tokio::net::TcpListener::bind(&addr)
     .await
+@@ -1899,3 +2111,4 @@ impl From for ErrorResponse {
 }
 }
 }
++
diff --git a/router/src/http/types.rs b/router/src/http/types.rs
index 6012288..85aca6b 100644
--- a/router/src/http/types.rs
+++ b/router/src/http/types.rs
@@ -271,6 +271,26 @@ pub(crate) struct Rank {
 #[derive(Serialize, ToSchema)]
 pub(crate) struct RerankResponse(pub Vec);

++#[derive(Serialize, ToSchema)]
++pub(crate) struct DocumentWrapper {
++    #[schema(example = "Deep Learning is ...")]
++    pub text: String,
++}
++
++#[derive(Serialize, ToSchema)]
++pub(crate) struct OpenAIRank {
++    #[schema(example = "0")]
++    pub index: usize,
++    #[schema(nullable = true, example = r#"{"text": "Deep Learning is ..."}"#, default = "null")]
++    #[serde(skip_serializing_if = "Option::is_none")]
++    pub document: Option,
++    #[schema(example = "1.0")]
++    pub relevance_score: f32,
++}
++
++#[derive(Serialize, ToSchema)]
++pub(crate) struct OpenAICompatRerankResponse{pub results: Vec}
++
 #[derive(Deserialize, ToSchema, Debug)]
 #[serde(untagged)]
 pub(crate) enum InputType {
@@ -325,6 +345,29 @@ pub(crate) struct OpenAICompatRequest {
 pub encoding_format: EncodingFormat,
 }

++#[derive(Deserialize, ToSchema)]
++pub(crate) struct OpenAICompatRerankRequest {
++    #[schema(example = "What is Deep Learning?")]
++    pub query: String,
++    #[schema(example = json!(["Deep Learning is ..."]))]
++    pub documents: Vec,
++    #[serde(default)]
++    #[schema(default = "false", example = "false", nullable = true)]
++    pub truncate: Option,
++    #[serde(default)]
++    #[schema(default = "right", example = "right")]
++    pub truncation_direction: TruncationDirection,
++    #[serde(default)]
++    #[schema(default = "false", example = "false")]
++    pub raw_scores: bool,
++    #[serde(default)]
++    #[schema(default = "false", example = "false")]
++    pub return_documents: bool,
++    #[serde(default)]
++    #[schema(example = 3, minimum = 1, nullable = true)]
++    pub top_n: Option,
++}
++
 #[derive(Serialize, ToSchema)]
 #[serde(untagged)]
 pub(crate) enum Embedding {
+@@ -590,3 +633,4 @@ pub(crate) enum VertexPrediction {
 pub(crate) struct VertexResponse {
 pub predictions: Vec,
 }
++
diff --git a/router/src/lib.rs b/router/src/lib.rs
-index 49e0581..044eb21 100644
+index f1b8ba2..0cdb23c 100644
 --- a/router/src/lib.rs
 +++ b/router/src/lib.rs
+@@ -36,6 +36,7 @@ use tokenizers::processors::sequence::Sequence;
 use tokenizers::processors::template::TemplateProcessing;
 use tokenizers::{PostProcessorWrapper, Tokenizer};
 use tracing::Span;
++use std::env;

 pub use logging::init_logging;

-@@ -63,6 +63,7 @@ pub async fn run(
- api_key: Option,
- otlp_endpoint: Option,
- otlp_service_name: String,
-+ prometheus_port: u16,
- cors_allow_origin: Option>,
- ) -> Result<()> {
- let model_id_path = Path::new(&model_id);
+@@ -109,24 +110,42 @@ pub async fn run(
 let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?;

 // Info model type
++    let is_rerank = std::env::var("IS_RERANK").unwrap_or_default() == "1";
 let model_type = match &backend_model_type {
++        text_embeddings_backend::ModelType::Classifier if is_rerank => {
++            // If the IS_RERANK environment flag is set, go straight to the reranker branch without relying on config.json fields
++            let mut id2label = std::collections::HashMap::new();
++            id2label.insert("0".to_string(), "LABEL_0".to_string());
++
++            let mut label2id = std::collections::HashMap::new();
++            label2id.insert("LABEL_0".to_string(), 0);
++
++            let classifier_model = ClassifierModel { id2label, label2id };
++            ModelType::Reranker(classifier_model)
++        }
++
 text_embeddings_backend::ModelType::Classifier => {
++            // Original classifier branch: check that the fields exist
 let id2label = config
     .id2label
     .context("`config.json` does not contain `id2label`")?;
++            let label2id = config
++                .label2id
++                .context("`config.json` does not contain `label2id`")?;
++
 let n_classes = id2label.len();
 let classifier_model = ClassifierModel {
     id2label,
+-                label2id: config
+-                    .label2id
+-                    .context("`config.json` does not contain `label2id`")?,
++                label2id,
 };
++
 if n_classes > 1 {
     ModelType::Classifier(classifier_model)
 } else {
     ModelType::Reranker(classifier_model)
 }
 }
++
 text_embeddings_backend::ModelType::Embedding(pool) => {
 ModelType::Embedding(EmbeddingModel {
     pooling: pool.to_string(),
-@@ -250,8 +251,9 @@ pub async fn run(
+@@ -249,11 +268,14 @@ pub async fn run(
 .await
 .context("Model backend is not healthy")?;

-     if !backend.padded_model {
-         tracing::info!("Warming up model");
-+        let max_batch_requests = Some(3);
-         backend
--            .warmup(max_input_length, max_batch_tokens, max_batch_requests)
-+            .warmup(4, 4, max_batch_requests)
-             .await
-             .context("Model backend is not healthy")?;
-     }
+-    tracing::info!("Warming up model");
+-    backend
+-        .warmup(max_input_length, max_batch_tokens, max_batch_requests)
+-        .await
+-        .context("Model backend is not healthy")?;
++    if !backend.padded_model {
++        tracing::info!("Warming up model");
++        let max_batch_requests = Some(3);
++        backend
++            .warmup(4, 4, max_batch_requests)
++            .await
++            .context("Model backend is not healthy")?;
++    }

 let max_batch_requests = backend
 .max_batch_size
-@@ -314,7 +316,7 @@ pub async fn run(
 }
 };
--    let prom_builder = prometheus::prometheus_builer(info.max_input_length)?;
-+    let prom_builder = prometheus::prometheus_builer(addr, prometheus_port, info.max_input_length)?;

 #[cfg(all(feature = "grpc", feature = "http"))]
 compile_error!("Features `http` and `grpc` cannot be enabled at the same time.");
-@@ -363,7 +365,7 @@ fn get_backend_model_type(
+@@ -362,11 +384,11 @@ fn get_backend_model_type(
 continue;
 }

 if Some(text_embeddings_backend::Pool::Splade) == pooling {
 return Ok(text_embeddings_backend::ModelType::Embedding(
     text_embeddings_backend::Pool::Splade,
 ));
+-        } else if arch.ends_with("Classification") {
++        } else if arch.ends_with("Classification") || env::var("IS_RERANK").is_ok() {
 if pooling.is_some() {
     tracing::warn!(
         "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg."
diff --git a/router/src/main.rs b/router/src/main.rs
-index e4a902d..7b152e4 100644
+index 39b975d..2966cd4 100644
 --- a/router/src/main.rs
 +++ b/router/src/main.rs
 @@ -48,7 +48,7 @@ struct Args {
 #[clap(default_value = "512", long, env)]
 max_concurrent_requests: usize,
 /// **IMPORTANT** This is one critical control to allow maximum usage
-@@ -164,6 +164,10 @@ struct Args {
- #[clap(default_value = "text-embeddings-inference.server", long, env)]
- otlp_service_name: String,
-
-+ /// The Prometheus port to listen on.
-+ #[clap(default_value = "9000", long, short, env)]
-+ prometheus_port: u16,
-+
- /// Unused for gRPC servers
- #[clap(long, env)]
- cors_allow_origin: Option>,
-@@ -227,6 +231,7 @@ async fn main() -> Result<()> {
- args.api_key,
- args.otlp_endpoint,
- args.otlp_service_name,
-+ args.prometheus_port,
- args.cors_allow_origin,
- )
- .await?;
-diff --git a/router/src/prometheus.rs b/router/src/prometheus.rs
-index bded390..4c5fb38 100644
---- a/router/src/prometheus.rs
-+++ b/router/src/prometheus.rs
-@@ -1,6 +1,13 @@
-+use std::net::SocketAddr;
- use metrics_exporter_prometheus::{BuildError, Matcher, PrometheusBuilder};
-
--pub(crate) fn prometheus_builer(max_input_length: usize) -> Result {
-+pub(crate) fn prometheus_builer(
-+ addr: SocketAddr,
-+ port: u16,
-+ max_input_length: usize,
-+) -> Result {
-+ let mut addr = addr;
-+ addr.set_port(port);
- // Duration buckets
- let duration_matcher = Matcher::Suffix(String::from("duration"));
- let n_duration_buckets = 35;
-@@ -30,6 +37,7 @@ pub(crate) fn prometheus_builer(max_input_length: usize) -> Result