[Frontend] Rerank API (Jina- and Cohere-compatible API) (#12376)
Signed-off-by: Kyle Mistele <kyle@mistele.com>
This commit is contained in:
parent
72bac73067
commit
0034b09ceb
@@ -50,6 +50,11 @@ In addition, we have the following custom APIs:
  - Applicable to all [pooling models](../models/pooling_models.md).
- [Score API](#score-api) (`/score`)
  - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
  - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
  - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
  - Jina and Cohere's APIs are very similar; Jina's includes extra information in the rerank endpoint's response.
  - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).

(chat-template)=

@@ -473,3 +478,90 @@ The following extra parameters are supported:
:start-after: begin-score-extra-params
:end-before: end-score-extra-params
```

(rerank-api)=

### Re-rank API

Our Re-rank API applies a cross-encoder model to predict relevance scores between a single query and
each document in a list of documents. Usually, the score for a sentence pair refers to the similarity between the two
sentences, on a scale of 0 to 1.

You can find the documentation for these kinds of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
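
As a point of reference, a minimal offline sketch of what such a cross-encoder computes (assuming the `sentence-transformers` package is installed; this runs outside of vLLM):

```python
from sentence_transformers import CrossEncoder

# Cross-encoders score (query, document) pairs jointly instead of
# embedding each text separately and comparing the embeddings.
model = CrossEncoder("BAAI/bge-reranker-base")
scores = model.predict([
    ("What is the capital of France?", "The capital of France is Paris."),
    ("What is the capital of France?", "Horses and cows are both animals"),
])
print(scores)  # a higher score means a more relevant document
```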

The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base`, as well as other models that
support the `score` task. Additionally, the `/rerank`, `/v1/rerank`, and `/v2/rerank`
endpoints are compatible with both [Jina AI's re-rank API interface](https://jina.ai/reranker/) and
[Cohere's re-rank API interface](https://docs.cohere.com/v2/reference/rerank), so they work with
popular open-source tools out of the box.

Code example: <gh-file:examples/online_serving/jinaai_rerank_client.py>

#### Example Request

Note that the `top_n` request parameter is optional and defaults to the length of the `documents` field.
Result documents will be sorted by relevance score, and the `index` property can be used to determine the original order.

Request:

```bash
curl -X 'POST' \
  'http://127.0.0.1:8000/v1/rerank' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "model": "BAAI/bge-reranker-base",
  "query": "What is the capital of France?",
  "documents": [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
    "Horses and cows are both animals"
  ]
}'
```

Response:

```json
{
  "id": "rerank-fae51b2b664d4ed38f5969b612edff77",
  "model": "BAAI/bge-reranker-base",
  "usage": {
    "total_tokens": 56
  },
  "results": [
    {
      "index": 1,
      "document": {
        "text": "The capital of France is Paris."
      },
      "relevance_score": 0.99853515625
    },
    {
      "index": 0,
      "document": {
        "text": "The capital of Brazil is Brasilia."
      },
      "relevance_score": 0.0005860328674316406
    }
  ]
}
```
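
Since the results come back sorted by `relevance_score`, here is a small sketch of restoring the original document order (assuming the response above has been parsed into a Python dict named `response_json`):

```python
# Hypothetical post-processing: map reranked results back to input order.
results = response_json["results"]
by_original_order = sorted(results, key=lambda r: r["index"])
for r in by_original_order:
    print(r["index"], round(r["relevance_score"], 4), r["document"]["text"])
```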

#### Extra parameters

The following [pooling parameters](#pooling-params) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-rerank-pooling-params
:end-before: end-rerank-pooling-params
```

The following extra parameters are supported:

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
:start-after: begin-rerank-extra-params
:end-before: end-rerank-extra-params
```
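
As an illustration, a request combining `top_n` with the `priority` extra parameter might look like the sketch below (per the field description above, any priority other than 0 is only accepted when the served model uses priority scheduling):

```python
import requests

# Sketch: rerank with top_n plus the `priority` extra parameter documented above.
response = requests.post(
    "http://127.0.0.1:8000/v1/rerank",
    json={
        "model": "BAAI/bge-reranker-base",
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
        "top_n": 1,
        "priority": 0,  # non-zero values require priority scheduling on the server
    },
)
print(response.json())
```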

examples/online_serving/cohere_rerank_client.py (new file, 32 lines)
@@ -0,0 +1,32 @@
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python

run: vllm serve BAAI/bge-reranker-base
"""
import cohere

# cohere v1 client
co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
rerank_v1_result = co.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris", "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ])

print(rerank_v1_result)

# or the v2 client
co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")

v2_rerank_result = co2.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris", "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ])

print(v2_rerank_result)

examples/online_serving/jinaai_rerank_client.py (new file, 33 lines)
@@ -0,0 +1,33 @@
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
Jina and Cohere https://jina.ai/reranker

run: vllm serve BAAI/bge-reranker-base
"""
import json

import requests

url = "http://127.0.0.1:8000/rerank"

headers = {"accept": "application/json", "Content-Type": "application/json"}

data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.", "Horses and cows are both animals"
    ]
}
response = requests.post(url, headers=headers, json=data)

# Check the response
if response.status_code == 200:
    print("Request successful!")
    print(json.dumps(response.json(), indent=2))
else:
    print(f"Request failed with status code: {response.status_code}")
    print(response.text)

tests/entrypoints/openai/test_rerank.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import pytest
import requests

from vllm.entrypoints.openai.protocol import RerankResponse

from ...utils import RemoteOpenAIServer

MODEL_NAME = "BAAI/bge-reranker-base"


@pytest.fixture(scope="module")
def server():
    args = ["--enforce-eager", "--max-model-len", "100"]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                    })
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str):
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.", "Cross-encoder models are neat"
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents,
                                        "top_n": 2
                                    })
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):

    query = "What is the capital of France?" * 100
    documents = [
        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
    ]

    rerank_response = requests.post(server.url_for("rerank"),
                                    json={
                                        "model": model_name,
                                        "query": query,
                                        "documents": documents
                                    })
    assert rerank_response.status_code == 400
    # Assert just a small fragment of the response
    assert "Please reduce the length of the input." in \
        rerank_response.text

@@ -10,12 +10,7 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3"

@pytest.fixture(scope="module")
def server():
-    args = [
-        "--enforce-eager",
-        # Will be used on tests to compare prompt input length
-        "--max-model-len",
-        "100"
-    ]
+    args = ["--enforce-eager", "--max-model-len", "100"]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server

@@ -56,6 +56,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              PoolingChatRequest,
                                              PoolingCompletionRequest,
                                              PoolingRequest, PoolingResponse,
                                              RerankRequest, RerankResponse,
                                              ScoreRequest, ScoreResponse,
                                              TokenizeRequest,
                                              TokenizeResponse,
@@ -68,6 +69,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)

@@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
    return request.app.state.openai_serving_scores


def rerank(request: Request) -> Optional[JinaAIServingRerank]:
    return request.app.state.jinaai_serving_reranking


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization

@@ -502,6 +508,40 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
    return await create_score(request, raw_request)


@router.post("/rerank")
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
    handler = rerank(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Rerank (Score) API")
    generator = await handler.do_rerank(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, RerankResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post("/v1/rerank")
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    logger.warning(
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)")

    return await do_rerank(request, raw_request)


@router.post("/v2/rerank")
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    return await do_rerank(request, raw_request)


TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
    "generate": {
        "messages": (ChatCompletionRequest, create_chat_completion),

@@ -512,7 +552,10 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
        "default": (EmbeddingCompletionRequest, create_embedding),
    },
    "score": {
-        "default": (ScoreRequest, create_score),
+        "default": (RerankRequest, do_rerank)
    },
+    "rerank": {
+        "default": (RerankRequest, do_rerank)
+    },
    "reward": {
        "messages": (PoolingChatRequest, create_pooling),

@@ -759,6 +802,12 @@ async def init_app_state(
        state.openai_serving_models,
        request_logger=request_logger
    ) if model_config.task == "score" else None
    state.jinaai_serving_reranking = JinaAIServingRerank(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger
    ) if model_config.task == "score" else None
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,

@@ -1018,6 +1018,52 @@ class ScoreRequest(OpenAIBaseModel):
        return PoolingParams(additional_data=self.additional_data)


class RerankRequest(OpenAIBaseModel):
    model: str
    query: str
    documents: List[str]
    top_n: int = Field(default_factory=lambda: 0)
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-rerank-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-rerank-pooling-params

    # doc: begin-rerank-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-rerank-extra-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)


class RerankDocument(BaseModel):
    text: str


class RerankResult(BaseModel):
    index: int
    document: RerankDocument
    relevance_score: float


class RerankUsage(BaseModel):
    total_tokens: int


class RerankResponse(OpenAIBaseModel):
    id: str
    model: str
    usage: RerankUsage
    results: List[RerankResult]


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)

@@ -26,7 +26,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
-                                              ErrorResponse, ScoreRequest,
+                                              ErrorResponse, RerankRequest,
+                                              ScoreRequest,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels

@@ -204,9 +205,9 @@ class OpenAIServing:
        token_num = len(input_ids)

        # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
-        if isinstance(
-                request,
-                (EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest)):
+        if isinstance(request,
+                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
+                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"

vllm/entrypoints/openai/serving_rerank.py (new file, 206 lines)
@@ -0,0 +1,206 @@
import asyncio
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
                                              RerankRequest, RerankResponse,
                                              RerankResult, RerankUsage)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators

logger = init_logger(__name__)


class JinaAIServingRerank(OpenAIServing):

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

    async def do_rerank(
        self,
        request: RerankRequest,
        raw_request: Optional[Request] = None
    ) -> Union[RerankResponse, ErrorResponse]:
        """
        Rerank API based on JinaAI's rerank API; implements the same
        API interface. Designed for compatibility with off-the-shelf
        tooling, since this is a common standard for reranking APIs.

        See example client implementations at
        https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
        (numerous clients use this standard).
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        model_name = request.model
        request_id = f"rerank-{self._base_request_id(raw_request)}"
        truncate_prompt_tokens = request.truncate_prompt_tokens
        query = request.query
        documents = request.documents
        request_prompts = []
        engine_prompts = []
        top_n = request.top_n if request.top_n > 0 else len(documents)

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for scoring models")

            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "MistralTokenizer not supported for cross-encoding")

            if not self.model_config.is_cross_encoder:
                raise ValueError("Model is not cross encoder.")

            if truncate_prompt_tokens is not None and \
                    truncate_prompt_tokens > self.max_model_len:
                raise ValueError(
                    f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
                    f"is greater than max_model_len ({self.max_model_len})."
                    f" Please, select a smaller truncation size.")
            for doc in documents:
                request_prompt = f"{query}{tokenizer.sep_token}{doc}"
                tokenization_kwargs: Dict[str, Any] = {}
                if truncate_prompt_tokens is not None:
                    tokenization_kwargs["truncation"] = True
                    tokenization_kwargs["max_length"] = truncate_prompt_tokens

                tokenize_async = make_async(tokenizer.__call__,
                                            executor=self._tokenizer_executor)
                prompt_inputs = await tokenize_async(text=query,
                                                     text_pair=doc,
                                                     **tokenization_kwargs)

                input_ids = prompt_inputs["input_ids"]
                text_token_prompt = \
                    self._validate_input(request, input_ids, request_prompt)
                engine_prompt = TokensPrompt(
                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids"))

                request_prompts.append(request_prompt)
                engine_prompts.append(engine_prompt)

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []

        try:
            pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))
        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts

        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_rerank_response(
                final_res_batch_checked, request_id, model_name, documents,
                top_n)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_rerank_response(
            self, final_res_batch: List[PoolingRequestOutput], request_id: str,
            model_name: str, documents: List[str],
            top_n: int) -> RerankResponse:
        """
        Convert the output of do_rerank to a RerankResponse
        """
        results: List[RerankResult] = []
        num_prompt_tokens = 0
        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            result = RerankResult(
                index=idx,
                document=RerankDocument(text=documents[idx]),
                relevance_score=classify_res.outputs.score,
            )
            results.append(result)
            prompt_token_ids = final_res.prompt_token_ids
            num_prompt_tokens += len(prompt_token_ids)

        # sort by relevance, then return the top n if set
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        if top_n < len(documents):
            results = results[:top_n]

        return RerankResponse(
            id=request_id,
            model=model_name,
            results=results,
            usage=RerankUsage(total_tokens=num_prompt_tokens))