[FEATURE] Enables /score endpoint for embedding models (#12846)
This commit is contained in:
parent
1cdc88614a
commit
1c3c975766
@@ -108,8 +108,7 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/clas
### `LLM.score`

The {class}`~vllm.LLM.score` method outputs similarity scores between sentence pairs.
It is primarily designed for [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html).
These types of models serve as rerankers between candidate query-document pairs in RAG systems.
It is designed for embedding models and cross encoder models. Embedding models use cosine similarity, and [cross-encoder models](https://www.sbert.net/examples/applications/cross-encoder/README.html) serve as rerankers between candidate query-document pairs in RAG systems.

:::{note}
vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG.
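As an illustration of the embedding-model path described above, a minimal offline sketch (not part of this diff; the model name and `task` value are assumptions) might look like this:

```python
# Minimal offline sketch (not part of this diff): scoring sentence pairs with
# an embedding model, where each score is the cosine similarity of the pair.
from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")  # assumed embedding model
outputs = llm.score(
    "What is the capital of France?",
    ["The capital of Brazil is Brasilia.", "The capital of France is Paris."],
)
for output in outputs:
    print(output.outputs.score)
```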
@@ -51,7 +51,7 @@ In addition, we have the following custom APIs:
- [Pooling API](#pooling-api) (`/pooling`)
  - Applicable to all [pooling models](../models/pooling_models.md).
- [Score API](#score-api) (`/score`)
  - Only applicable to [cross-encoder models](../models/pooling_models.md) (`--task score`).
  - Applicable to embedding models and [cross-encoder models](../models/pooling_models.md) (`--task score`).
- [Re-rank API](#rerank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
  - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
  - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
@@ -333,10 +333,10 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>

### Score API

Our Score API applies a cross-encoder model to predict scores for sentence pairs.
Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair.
Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1.

You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).

Code example: <gh-file:examples/online_serving/openai_cross_encoder_score.py>
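To make the change above concrete, here is a minimal client sketch for the `/score` endpoint (host, port, and model name are assumptions; the request and response fields mirror the tests added in this commit):

```python
# Hypothetical client sketch (not part of this diff): POST a query and two
# candidate sentences to /score; with an embedding model the returned scores
# are cosine similarities.
import requests

response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-base-en-v1.5",
        "text_1": "What is the capital of France?",
        "text_2": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    },
)
response.raise_for_status()
for item in response.json()["data"]:
    print(item["score"])
```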
@@ -496,11 +496,11 @@ The following extra parameters are supported:

### Re-rank API

Our Re-rank API applies a cross-encoder model to predict relevant scores between a single query, and
Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and
each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on
a scale of 0 to 1.

You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).
You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html).

The rerank endpoints support popular re-rank models such as `BAAI/bge-reranker-base` and other models supporting the
`score` task. Additionally, `/rerank`, `/v1/rerank`, and `/v2/rerank`
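Similarly, a minimal client sketch for the re-rank endpoint (host and port are assumptions; the fields follow the Jina-style API exercised by the rerank tests below):

```python
# Hypothetical client sketch (not part of this diff): rerank two documents
# against a single query via the Jina-compatible /rerank endpoint.
import requests

response = requests.post(
    "http://localhost:8000/rerank",
    json={
        "model": "BAAI/bge-reranker-base",
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    },
)
response.raise_for_status()
for result in response.json()["results"]:
    print(result["index"], result["relevance_score"])
```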
@@ -8,17 +8,17 @@ from vllm.entrypoints.openai.protocol import RerankResponse
from ...utils import RemoteOpenAIServer

MODEL_NAME = "BAAI/bge-reranker-base"
DTYPE = "bfloat16"


@pytest.fixture(scope="module")
def server():
    args = ["--enforce-eager", "--max-model-len", "100"]
    args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
    query = "What is the capital of France?"
@@ -42,7 +42,6 @@ def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str):
    query = "What is the capital of France?"
@@ -68,7 +67,6 @@ def test_top_n(server: RemoteOpenAIServer, model_name: str):
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
@@ -1,123 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import torch.nn.functional as F
|
||||
from torch import tensor
|
||||
|
||||
from vllm.entrypoints.openai.protocol import ScoreResponse
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "BAAI/bge-reranker-v2-m3"
|
||||
MODELS = [
|
||||
{
|
||||
"name": "BAAI/bge-reranker-v2-m3",
|
||||
"is_cross_encoder": True
|
||||
},
|
||||
{
|
||||
"name": "BAAI/bge-base-en-v1.5",
|
||||
"is_cross_encoder": False
|
||||
},
|
||||
]
|
||||
DTYPE = "half"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = ["--enforce-eager", "--max-model-len", "100"]
|
||||
def run_transformers(hf_model, model, text_pairs):
|
||||
if model["is_cross_encoder"]:
|
||||
return hf_model.predict(text_pairs).tolist()
|
||||
else:
|
||||
hf_embeddings = [
|
||||
hf_model.encode(text_pair) for text_pair in text_pairs
|
||||
]
|
||||
return [
|
||||
F.cosine_similarity(tensor(pair[0]), tensor(pair[1]), dim=0)
|
||||
for pair in hf_embeddings
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
|
||||
@pytest.fixture(scope="class", params=MODELS)
|
||||
def model(request):
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def server(model: dict[str, Any]):
|
||||
args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]
|
||||
|
||||
with RemoteOpenAIServer(model["name"], args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
@pytest.fixture(scope="class")
|
||||
def runner(model: dict[str, Any], hf_runner):
|
||||
kwargs = {
|
||||
"dtype": DTYPE,
|
||||
"is_cross_encoder" if model["is_cross_encoder"]\
|
||||
else "is_sentence_transformer": True
|
||||
}
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
assert score.data[0].score <= 0.01
|
||||
assert score.data[1].score >= 0.9
|
||||
with hf_runner(model["name"], **kwargs) as hf_model:
|
||||
yield hf_model
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = [
|
||||
"What is the capital of the United States?",
|
||||
"What is the capital of France?"
|
||||
]
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
class TestModel:
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
def test_text_1_str_text_2_list(self, server: RemoteOpenAIServer,
|
||||
model: dict[str, Any], runner):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris."
|
||||
]
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
assert score.data[0].score <= 0.01
|
||||
assert score.data[1].score >= 0.9
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model["name"],
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = "The capital of France is Paris."
|
||||
vllm_outputs = [d.score for d in score.data]
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
text_pairs = [[text_1, text_2[0]], [text_1, text_2[1]]]
|
||||
hf_outputs = run_transformers(runner, model, text_pairs)
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 1
|
||||
assert score.data[0].score >= 0.9
|
||||
for i in range(len(vllm_outputs)):
|
||||
assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
|
||||
|
||||
def test_text_1_list_text_2_list(self, server: RemoteOpenAIServer,
|
||||
model: dict[str, Any], runner):
|
||||
text_1 = [
|
||||
"What is the capital of the United States?",
|
||||
"What is the capital of France?"
|
||||
]
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris."
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model["name"],
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
text_1 = "What is the capital of France?" * 20
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
assert score_response.status_code == 400
|
||||
# Assert just a small fragment of the response
|
||||
assert "Please reduce the length of the input." in \
|
||||
score_response.text
|
||||
vllm_outputs = [d.score for d in score.data]
|
||||
|
||||
# Test truncation
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
"truncate_prompt_tokens": 101
|
||||
})
|
||||
assert score_response.status_code == 400
|
||||
assert "Please, select a smaller truncation size." in \
|
||||
score_response.text
|
||||
text_pairs = [[text_1[0], text_2[0]], [text_1[1], text_2[1]]]
|
||||
hf_outputs = run_transformers(runner, model, text_pairs)
|
||||
|
||||
for i in range(len(vllm_outputs)):
|
||||
assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
|
||||
|
||||
def test_text_1_str_text_2_str(self, server: RemoteOpenAIServer,
|
||||
model: dict[str, Any], runner):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = "The capital of France is Paris."
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model["name"],
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 1
|
||||
|
||||
vllm_outputs = [d.score for d in score.data]
|
||||
|
||||
text_pairs = [[text_1, text_2]]
|
||||
hf_outputs = run_transformers(runner, model, text_pairs)
|
||||
|
||||
for i in range(len(vllm_outputs)):
|
||||
assert math.isclose(hf_outputs[i], vllm_outputs[i], rel_tol=0.01)
|
||||
|
||||
def test_score_max_model_len(self, server: RemoteOpenAIServer,
|
||||
model: dict[str, Any]):
|
||||
|
||||
text_1 = "What is the capital of France?" * 20
|
||||
text_2 = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris."
|
||||
]
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model["name"],
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
})
|
||||
assert score_response.status_code == 400
|
||||
# Assert just a small fragment of the response
|
||||
assert "Please reduce the length of the input." in \
|
||||
score_response.text
|
||||
|
||||
# Test truncation
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json={
|
||||
"model": model["name"],
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
"truncate_prompt_tokens": 101
|
||||
})
|
||||
assert score_response.status_code == 400
|
||||
assert "Please, select a smaller truncation size." in \
|
||||
score_response.text
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
                    Tuple, Type, Union, cast, overload)

import cloudpickle
import torch
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import TypeVar, deprecated
@@ -25,6 +24,8 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         apply_mistral_chat_template,
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.logger import init_logger
@@ -1010,40 +1011,25 @@ class LLM:
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:

        encoded_output = self.encode(
        encoded_output: List[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)
        encoded_output_1 = encoded_output[0:len(text_1)]
        encoded_output_2 = encoded_output[len(text_1):]

        encoded_output_1: List[PoolingRequestOutput] = encoded_output[
            0:len(text_1)]
        encoded_output_2: List[PoolingRequestOutput] = encoded_output[
            len(text_1):]

        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        output_pairs = [(t1, t2)
                        for t1, t2 in zip(encoded_output_1, encoded_output_2)]
        scores: List[PoolingRequestOutput] = []

        scores = []
        scorer = torch.nn.CosineSimilarity(0)

        for embed_1, embed_2 in output_pairs:
            pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)

            if (pad_token_id := getattr(tokenizer, "pad_token_id",
                                        None)) is not None:
                tokens = embed_1.prompt_token_ids + [
                    pad_token_id
                ] + embed_2.prompt_token_ids
            else:
                tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids

            scores.append(
                PoolingRequestOutput(
                    request_id=f"{embed_1.request_id}_{embed_2.request_id}",
                    outputs=pair_score,
                    prompt_token_ids=tokens,
                    finished=True))
        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
@@ -1183,12 +1169,7 @@ class LLM:
            text_2 = [text_2]
        input_text_2: List[str] = [ensure_str(t) for t in text_2]

        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(input_text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(input_text_2) == 0:
            raise ValueError("At least one text_pair element must be given")
        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
@@ -1197,7 +1178,6 @@ class LLM:
                                              lora_request,
                                              prompt_adapter_request)
        else:

            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
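The embedding branch above boils down to a plain cosine similarity over the two pooled vectors; a standalone sketch of that computation (the tensor values are made up) is:

```python
# Minimal sketch (not part of this diff): the embedding-model score is the
# cosine similarity between the pooled embeddings of text_1 and text_2.
import torch

scorer = torch.nn.CosineSimilarity(dim=0)
emb_1 = torch.tensor([0.1, 0.3, 0.5])  # assumed pooled embedding of text_1
emb_2 = torch.tensor([0.2, 0.1, 0.4])  # assumed pooled embedding of text_2
print(scorer(emb_1, emb_2).item())  # score in [-1, 1]
```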
@@ -73,8 +73,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import (
@@ -320,12 +319,12 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
    return request.app.state.openai_serving_embedding


def score(request: Request) -> Optional[OpenAIServingScores]:
def score(request: Request) -> Optional[ServingScores]:
    return request.app.state.openai_serving_scores


def rerank(request: Request) -> Optional[JinaAIServingRerank]:
    return request.app.state.jinaai_serving_reranking
def rerank(request: Request) -> Optional[ServingScores]:
    return request.app.state.openai_serving_scores


def tokenization(request: Request) -> OpenAIServingTokenization:
@@ -866,13 +865,13 @@ async def init_app_state(
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    ) if model_config.task == "embed" else None
    state.openai_serving_scores = OpenAIServingScores(
    state.openai_serving_scores = ServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger
    ) if model_config.task == "score" else None
    state.jinaai_serving_reranking = JinaAIServingRerank(
        request_logger=request_logger) if model_config.task in (
            "score", "embed", "pooling") else None
    state.jinaai_serving_reranking = ServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
@@ -26,7 +26,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, random_uuid
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
@@ -342,7 +342,7 @@ async def main(args):
|
||||
chat_template=None,
|
||||
chat_template_content_format="auto",
|
||||
) if model_config.task == "embed" else None
|
||||
openai_serving_scores = (OpenAIServingScores(
|
||||
openai_serving_scores = (ServingScores(
|
||||
engine,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
@@ -364,9 +364,9 @@ async def main(args):
|
||||
|
||||
# Determine the type of request and run it.
|
||||
if request.url == "/v1/chat/completions":
|
||||
handler_fn = (None if openai_serving_chat is None else
|
||||
openai_serving_chat.create_chat_completion)
|
||||
if handler_fn is None:
|
||||
chat_handler_fn = (None if openai_serving_chat is None else
|
||||
openai_serving_chat.create_chat_completion)
|
||||
if chat_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
@@ -375,12 +375,13 @@ async def main(args):
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(handler_fn, request, tracker))
|
||||
response_futures.append(
|
||||
run_request(chat_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/embeddings":
|
||||
handler_fn = (None if openai_serving_embedding is None else
|
||||
openai_serving_embedding.create_embedding)
|
||||
if handler_fn is None:
|
||||
embed_handler_fn = (None if openai_serving_embedding is None else
|
||||
openai_serving_embedding.create_embedding)
|
||||
if embed_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
@@ -388,12 +389,13 @@ async def main(args):
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(handler_fn, request, tracker))
|
||||
response_futures.append(
|
||||
run_request(embed_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
elif request.url == "/v1/score":
|
||||
handler_fn = (None if openai_serving_scores is None else
|
||||
openai_serving_scores.create_score)
|
||||
if handler_fn is None:
|
||||
score_handler_fn = (None if openai_serving_scores is None else
|
||||
openai_serving_scores.create_score)
|
||||
if score_handler_fn is None:
|
||||
response_futures.append(
|
||||
make_async_error_request_output(
|
||||
request,
|
||||
@@ -401,7 +403,8 @@ async def main(args):
|
||||
))
|
||||
continue
|
||||
|
||||
response_futures.append(run_request(handler_fn, request, tracker))
|
||||
response_futures.append(
|
||||
run_request(score_handler_fn, request, tracker))
|
||||
tracker.submitted()
|
||||
else:
|
||||
response_futures.append(
|
||||
|
||||
@@ -52,8 +52,8 @@ from vllm.utils import is_list_of, make_async, random_uuid
logger = init_logger(__name__)

CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, ScoreRequest,
                              TokenizeCompletionRequest]
                              EmbeddingCompletionRequest, RerankRequest,
                              ScoreRequest, TokenizeCompletionRequest]

ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]
@@ -1,208 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
|
||||
RerankRequest, RerankResponse,
|
||||
RerankResult, RerankUsage)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class JinaAIServingRerank(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
async def do_rerank(
|
||||
self,
|
||||
request: RerankRequest,
|
||||
raw_request: Optional[Request] = None
|
||||
) -> Union[RerankResponse, ErrorResponse]:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
query = request.query
|
||||
documents = request.documents
|
||||
request_prompts = []
|
||||
engine_prompts = []
|
||||
top_n = request.top_n if request.top_n > 0 else len(documents)
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for scoring models")
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
if not self.model_config.is_cross_encoder:
|
||||
raise ValueError("Model is not cross encoder.")
|
||||
|
||||
if truncate_prompt_tokens is not None and \
|
||||
truncate_prompt_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({self.max_model_len})."
|
||||
f" Please, select a smaller truncation size.")
|
||||
for doc in documents:
|
||||
request_prompt = f"{query}{tokenizer.sep_token}{doc}"
|
||||
tokenization_kwargs: Dict[str, Any] = {}
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
prompt_inputs = await tokenize_async(text=query,
|
||||
text_pair=doc,
|
||||
**tokenization_kwargs)
|
||||
|
||||
input_ids = prompt_inputs["input_ids"]
|
||||
text_token_prompt = \
|
||||
self._validate_input(request, input_ids, request_prompt)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
|
||||
request_prompts.append(request_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[PoolingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(List[PoolingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_rerank_response(
|
||||
final_res_batch_checked, request_id, model_name, documents,
|
||||
top_n)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self, final_res_batch: List[PoolingRequestOutput], request_id: str,
|
||||
model_name: str, documents: List[str],
|
||||
top_n: int) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rank to a RerankResponse
|
||||
"""
|
||||
results: List[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=RerankDocument(text=documents[idx]),
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(total_tokens=num_prompt_tokens))
|
||||
@@ -1,53 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
|
||||
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Union
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
|
||||
ScoreResponse, ScoreResponseData,
|
||||
UsageInfo)
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
|
||||
RerankRequest, RerankResponse,
|
||||
RerankResult, RerankUsage,
|
||||
ScoreRequest, ScoreResponse,
|
||||
ScoreResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.score_utils import (_cosine_similarity,
|
||||
_validate_score_input_lens)
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
|
||||
str]) -> List:
|
||||
if isinstance(text_1, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_1 = [text_1]
|
||||
text_1 = [t for t in text_1]
|
||||
|
||||
if isinstance(text_2, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_2 = [text_2]
|
||||
text_2 = [t for t in text_2]
|
||||
if len(text_1) > 1 and len(text_1) != len(text_2):
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len(text_1) == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len(text_2) == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
|
||||
if len(text_1) == 1:
|
||||
text_1 = text_1 * len(text_2)
|
||||
|
||||
return [(t1, t2) for t1, t2 in zip(text_1, text_2)]
|
||||
|
||||
|
||||
class OpenAIServingScores(OpenAIServing):
|
||||
class ServingScores(OpenAIServing):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -62,6 +45,248 @@ class OpenAIServingScores(OpenAIServing):
|
||||
models=models,
|
||||
request_logger=request_logger)
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
texts_1: List[str],
|
||||
texts_2: List[str],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
request_id=str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
||||
None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
|
||||
input_texts = texts_1 + texts_2
|
||||
|
||||
engine_prompts: List[TokensPrompt] = []
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts))
|
||||
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
|
||||
text_token_prompt = \
|
||||
self._validate_input(
|
||||
request,
|
||||
tok_result["input_ids"],
|
||||
input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"]))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
input_texts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
generators.append(
|
||||
self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[PoolingRequestOutput] = []
|
||||
|
||||
embeddings: List[Optional[PoolingRequestOutput]] =\
|
||||
[None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
embeddings[i] = res
|
||||
|
||||
emb_texts_1: List[PoolingRequestOutput] = []
|
||||
emb_texts_2: List[PoolingRequestOutput] = []
|
||||
|
||||
for i in range(0, len(texts_1)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_1.append(emb)
|
||||
|
||||
for i in range(len(texts_1), len(embeddings)):
|
||||
assert (emb := embeddings[i]) is not None
|
||||
emb_texts_2.append(emb)
|
||||
|
||||
if len(emb_texts_1) == 1:
|
||||
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
|
||||
|
||||
final_res_batch = _cosine_similarity(tokenizer=tokenizer,
|
||||
embed_1=emb_texts_1,
|
||||
embed_2=emb_texts_2)
|
||||
|
||||
return final_res_batch
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: Union[AnyTokenizer],
|
||||
texts_1: List[str],
|
||||
texts_2: List[str],
|
||||
request: Union[RerankRequest, ScoreRequest],
|
||||
request_id=str,
|
||||
tokenization_kwargs: Optional[dict[str, Any]] = None,
|
||||
lora_request: Optional[Union[LoRARequest, None]] = None,
|
||||
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
|
||||
None]] = None,
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
|
||||
request_prompts: List[str] = []
|
||||
engine_prompts: List[TokensPrompt] = []
|
||||
|
||||
if len(texts_1) == 1:
|
||||
texts_1 = texts_1 * len(texts_2)
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)]
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
|
||||
for t1, t2 in input_pairs))
|
||||
|
||||
for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
|
||||
|
||||
request_prompt = f"{t1}{tokenizer.sep_token}{t2}"
|
||||
|
||||
input_ids = prompt_inputs["input_ids"]
|
||||
text_token_prompt = \
|
||||
self._validate_input(request, input_ids, request_prompt)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
|
||||
request_prompts.append(request_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[
|
||||
Optional[PoolingRequestOutput]] = [None] * len(engine_prompts)
|
||||
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
return [out for out in final_res_batch if out is not None]
|
||||
|
||||
async def _run_scoring(
|
||||
self,
|
||||
texts_1: Union[str, list[str]],
|
||||
texts_2: Union[str, list[str]],
|
||||
request: Union[ScoreRequest, RerankRequest],
|
||||
request_id: str,
|
||||
raw_request: Optional[Request] = None,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
) -> List[PoolingRequestOutput]:
|
||||
|
||||
tokenization_kwargs: Dict[str, Any] = {}
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for scoring models")
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if truncate_prompt_tokens is not None and \
|
||||
truncate_prompt_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({self.max_model_len})."
|
||||
f" Please, select a smaller truncation size.")
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
if isinstance(texts_1, str):
|
||||
texts_1 = [texts_1]
|
||||
if isinstance(texts_2, str):
|
||||
texts_2 = [texts_2]
|
||||
|
||||
_validate_score_input_lens(texts_1, texts_2)
|
||||
|
||||
if self.model_config.is_cross_encoder:
|
||||
return await self._cross_encoding_score(
|
||||
tokenizer=tokenizer,
|
||||
texts_1=texts_1,
|
||||
texts_2=texts_2,
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
else:
|
||||
return await self._embedding_score(
|
||||
tokenizer=tokenizer,
|
||||
texts_1=texts_1,
|
||||
texts_2=texts_2,
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers)
|
||||
|
||||
async def create_score(
|
||||
self,
|
||||
request: ScoreRequest,
|
||||
@@ -76,123 +301,24 @@ class OpenAIServingScores(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
model_name = request.model
|
||||
request_id = f"score-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
|
||||
request_prompts = []
|
||||
engine_prompts = []
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
) = self._maybe_get_adapters(request)
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.text_1,
|
||||
request.text_2,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
request.truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||
|
||||
if prompt_adapter_request is not None:
|
||||
raise NotImplementedError("Prompt adapter is not supported "
|
||||
"for scoring models")
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
if not self.model_config.is_cross_encoder:
|
||||
raise ValueError("Model is not cross encoder.")
|
||||
|
||||
if truncate_prompt_tokens is not None and \
|
||||
truncate_prompt_tokens > self.max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({self.max_model_len})."
|
||||
f" Please, select a smaller truncation size.")
|
||||
|
||||
input_pairs = make_pairs(request.text_1, request.text_2)
|
||||
for q, t in input_pairs:
|
||||
request_prompt = f"{q}{tokenizer.sep_token}{t}"
|
||||
|
||||
tokenization_kwargs: Dict[str, Any] = {}
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
prompt_inputs = await tokenize_async(q,
|
||||
text_pair=t,
|
||||
**tokenization_kwargs)
|
||||
|
||||
input_ids = prompt_inputs["input_ids"]
|
||||
text_token_prompt = \
|
||||
self._validate_input(request, input_ids, request_prompt)
|
||||
engine_prompt = TokensPrompt(
|
||||
prompt_token_ids=text_token_prompt["prompt_token_ids"],
|
||||
token_type_ids=prompt_inputs.get("token_type_ids"))
|
||||
|
||||
request_prompts.append(request_prompt)
|
||||
engine_prompts.append(engine_prompt)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
|
||||
|
||||
try:
|
||||
pooling_params = request.to_pooling_params()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
self._log_inputs(request_id_item,
|
||||
request_prompts[i],
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
trace_headers = (None if raw_request is None else await
|
||||
self._get_trace_headers(raw_request.headers))
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[PoolingRequestOutput]]
|
||||
final_res_batch = [None] * num_prompts
|
||||
|
||||
try:
|
||||
async for i, res in result_generator:
|
||||
final_res_batch[i] = res
|
||||
|
||||
assert all(final_res is not None for final_res in final_res_batch)
|
||||
|
||||
final_res_batch_checked = cast(List[PoolingRequestOutput],
|
||||
final_res_batch)
|
||||
|
||||
response = self.request_output_to_score_response(
|
||||
final_res_batch_checked,
|
||||
return self.request_output_to_score_response(
|
||||
final_res_batch,
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
request.model,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
@@ -200,7 +326,44 @@ class OpenAIServingScores(OpenAIServing):
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
return response
|
||||
async def do_rerank(
|
||||
self,
|
||||
request: RerankRequest,
|
||||
raw_request: Optional[Request] = None
|
||||
) -> Union[RerankResponse, ErrorResponse]:
|
||||
"""
|
||||
Rerank API based on JinaAI's rerank API; implements the same
|
||||
API interface. Designed for compatibility with off-the-shelf
|
||||
tooling, since this is a common standard for reranking APIs
|
||||
|
||||
See example client implementations at
|
||||
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
|
||||
numerous clients use this standard.
|
||||
"""
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
request_id = f"rerank-{self._base_request_id(raw_request)}"
|
||||
documents = request.documents
|
||||
top_n = request.top_n if request.top_n > 0 else len(documents)
|
||||
|
||||
try:
|
||||
final_res_batch = await self._run_scoring(
|
||||
request.query,
|
||||
documents,
|
||||
request,
|
||||
request_id,
|
||||
raw_request,
|
||||
request.truncate_prompt_tokens,
|
||||
)
|
||||
return self.request_output_to_rerank_response(
|
||||
final_res_batch, request_id, request.model, documents, top_n)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def request_output_to_score_response(
|
||||
self,
|
||||
@@ -236,3 +399,35 @@ class OpenAIServingScores(OpenAIServing):
|
||||
data=items,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def request_output_to_rerank_response(
|
||||
self, final_res_batch: List[PoolingRequestOutput], request_id: str,
|
||||
model_name: str, documents: List[str],
|
||||
top_n: int) -> RerankResponse:
|
||||
"""
|
||||
Convert the output of do_rank to a RerankResponse
|
||||
"""
|
||||
results: List[RerankResult] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
classify_res = ScoringRequestOutput.from_base(final_res)
|
||||
|
||||
result = RerankResult(
|
||||
index=idx,
|
||||
document=RerankDocument(text=documents[idx]),
|
||||
relevance_score=classify_res.outputs.score,
|
||||
)
|
||||
results.append(result)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
num_prompt_tokens += len(prompt_token_ids)
|
||||
|
||||
# sort by relevance, then return the top n if set
|
||||
results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
if top_n < len(documents):
|
||||
results = results[:top_n]
|
||||
|
||||
return RerankResponse(
|
||||
id=request_id,
|
||||
model=model_name,
|
||||
results=results,
|
||||
usage=RerankUsage(total_tokens=num_prompt_tokens))
|
||||
|
||||
vllm/entrypoints/score_utils.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Union

from torch.nn import CosineSimilarity

from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
                                               PreTrainedTokenizerFast)


def _cosine_similarity(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    embed_1: List[PoolingRequestOutput],
    embed_2: List[PoolingRequestOutput],
) -> List[PoolingRequestOutput]:

    scorer = CosineSimilarity(0)
    scores: Union[List[PoolingRequestOutput]] = []

    for emb_1, emb_2 in zip(embed_1, embed_2):
        pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

        padding = []
        if (pad_token_id := getattr(tokenizer, "pad_token_id",
                                    None)) is not None:
            padding = [pad_token_id]

        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=pair_score,
                prompt_token_ids=tokens,
                finished=True))

    return scores


def _validate_score_input_lens(
    texts_1: Union[List[str], List[dict]],
    texts_2: Union[List[str], List[dict]],
):
    if len(texts_1) > 1 and len(texts_1) != len(texts_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len(texts_1) == 0:
        raise ValueError("At least one text element must be given")
    if len(texts_2) == 0:
        raise ValueError("At least one text_pair element must be given")
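For reference, the 1:1 / 1:N / N:N pairing rule enforced by `_validate_score_input_lens` above (and previously by the removed `make_pairs` helper in `serving_score.py`) can be illustrated with a small standalone sketch; `make_pairs_example` is a hypothetical name used only for this illustration:

```python
# Hypothetical illustration (not part of this diff) of the accepted pairing
# shapes: 1:1, 1:N (the single text_1 is broadcast), or N:N.
from typing import List, Tuple


def make_pairs_example(texts_1: List[str],
                       texts_2: List[str]) -> List[Tuple[str, str]]:
    if len(texts_1) > 1 and len(texts_1) != len(texts_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if not texts_1 or not texts_2:
        raise ValueError("At least one element must be given on each side")
    if len(texts_1) == 1:
        texts_1 = texts_1 * len(texts_2)  # broadcast the single query
    return list(zip(texts_1, texts_2))


print(make_pairs_example(["q"], ["d1", "d2"]))  # [('q', 'd1'), ('q', 'd2')]
```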