Add support for the /rerank endpoint in vllm bench serve (#26602)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Parent: 29350922c6
Commit: fe3edb4cf0
@@ -35,6 +35,7 @@ th {

| Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` |
| Random | ✅ | ✅ | `synthetic` |
| RandomMultiModal (Image/Video) | 🟡 | 🚧 | `synthetic` |
| RandomForReranking | ✅ | ✅ | `synthetic` |
| Prefix Repetition | ✅ | ✅ | `synthetic` |
| HuggingFace-VisionArena | ✅ | ✅ | `lmarena-ai/VisionArena-Chat` |
| HuggingFace-MMVU | ✅ | ✅ | `yale-nlp/MMVU` |
@@ -878,6 +879,51 @@ vllm bench serve \

</details>

#### Reranker Benchmark

Benchmark the performance of rerank requests in vLLM.

<details class="admonition abstract" markdown="1">
<summary>Show more</summary>

Unlike benchmarks for generative models, which use the Completions or Chat Completions API,
you should set `--backend vllm-rerank` and `--endpoint /v1/rerank` to use the Reranker API.

For reranking, the only supported dataset is `--dataset-name random-rerank`.

Start the server:

```bash
vllm serve BAAI/bge-reranker-v2-m3
```
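
Once the server is up, you can sanity-check the endpoint before benchmarking. The snippet below is a minimal sketch rather than part of this commit: it assumes the default local port 8000 and mirrors the payload shape (`model`, `query`, `documents`) that `async_request_vllm_rerank` builds later in this diff.

```python
# Hypothetical smoke test for the /v1/rerank endpoint (not part of the commit).
import requests

payload = {
    "model": "BAAI/bge-reranker-v2-m3",
    "query": "What is the capital of France?",
    "documents": ["Paris is the capital of France.", "The Nile is a river in Africa."],
}
# Assumes `vllm serve` is listening on the default port 8000.
resp = requests.post("http://localhost:8000/v1/rerank", json=payload)
resp.raise_for_status()
print(resp.json())  # documents with relevance scores; see the vLLM docs for the schema
```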

Run the benchmark:

```bash
vllm bench serve \
    --model BAAI/bge-reranker-v2-m3 \
    --backend vllm-rerank \
    --endpoint /v1/rerank \
    --dataset-name random-rerank \
    --tokenizer BAAI/bge-reranker-v2-m3 \
    --random-input-len 512 \
    --num-prompts 10 \
    --random-batch-size 5
```

For reranker models, this will create `num_prompts / random_batch_size` requests with
`random_batch_size` "documents", where each document has close to `random_input_len` tokens.
In the example above, this results in 2 rerank requests with 5 "documents" each, where
each document has close to 512 tokens.

Note that the `/v1/rerank` endpoint is also supported by embedding models, so if you're running
with an embedding model, also set `--no-reranker`. Because the server then treats the query as
an individual prompt, the benchmark sends `random_batch_size - 1` documents per request to
account for the extra prompt taken up by the query. The token accounting used to report
throughput numbers is adjusted accordingly.
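
The adjusted accounting can be restated compactly. The helper below is an illustrative sketch, not part of this commit; it mirrors the `query_contrib` logic in `RandomDatasetForReranking.sample` further down, assuming one separator token per query-document pair:

```python
def rerank_prompt_len(query_len: int, doc_lens: list[int], is_reranker: bool) -> int:
    """Illustrative restatement of the per-request token accounting."""
    if is_reranker:
        # The query (plus one separator token) is paired with every document.
        return (query_len + 1) * len(doc_lens) + sum(doc_lens)
    # Embedding models: the query is sent as one more standalone prompt.
    return query_len + sum(doc_lens)
```

In the reranker example above, each request therefore reports roughly `random_batch_size * random_input_len` prompt tokens (about 5 × 512 = 2560), since the query and each document target about half of `--random-input-len` apiece.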

</details>

[](){ #performance-benchmarks }

## Performance Benchmarks

@@ -572,6 +572,7 @@ class RandomDataset(BenchmarkDataset):

        # Ensure the lower bound for output length is at least 1 to
        # prevent sampling 0 tokens.
        output_low = max(output_low, 1)
        output_high = max(output_high, 1)

        if input_low > input_high:
            raise ValueError(
@@ -638,6 +639,112 @@ class RandomDataset(BenchmarkDataset):

        return prompt, total_input_len, token_mismatch


# -----------------------------------------------------------------------------
# Random Dataset for Reranking Implementation (Synthetic Data)
# -----------------------------------------------------------------------------

class RandomDatasetForReranking(RandomDataset):
    """
    Random dataset specialized for the needs of scoring:
    - Batches of inputs
    - Inputs composed of pairs
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
        range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
        input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
        batchsize: int = 1,
        is_reranker: bool = True,
        **kwargs,
    ) -> list[SampleRequest]:
        n_sep_tokens = int(is_reranker)
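
        # When reranking, the query and each document share the input_len
        # budget, and one token per pair is reserved for the separator
        # between them; nothing is reserved for plain embedding models.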
        query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len

        query_lens, _, query_offsets = self.get_sampling_params(
            1, range_ratio, query_len_param, 0, tokenizer
        )

        query_len = int(query_lens[0])
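
        # Embedding models treat the query as just another prompt, so one of
        # the requested prompts (and one slot in each batch) is reserved for it.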
        if not is_reranker:
            assert num_requests > 1 and batchsize > 1
            num_requests -= 1
            batchsize -= 1
            doc_len_param = input_len
        else:
            doc_len_param = input_len - query_len - n_sep_tokens

        doc_lens, _, doc_offsets = self.get_sampling_params(
            num_requests, range_ratio, doc_len_param, 0, tokenizer
        )
        vocab_size = tokenizer.vocab_size

        query_prompt, query_input_len, token_mismatch_total = (
            self.generate_token_sequence(
                tokenizer=tokenizer,
                prefix_token_ids=[],
                prefix_len=0,
                vocab_size=vocab_size,
                input_len=query_len,
                offset=int(query_offsets[0]),
                index=0,
            )
        )

        requests = []
        for i in range(num_requests):
            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
                tokenizer=tokenizer,
                prefix_token_ids=[],
                prefix_len=0,
                vocab_size=vocab_size,
                input_len=int(doc_lens[i]),
                offset=int(doc_offsets[i]),
                index=i + 1,
            )
            token_mismatch_total += token_mismatch
            requests.append((prompt, total_input_len))

        batch_requests = []
        # Create batched requests
        for i in range(0, num_requests, batchsize):
            batch = requests[i : i + batchsize]
            query_contrib = (
                (query_input_len + n_sep_tokens) * len(batch)
                if is_reranker
                else query_input_len
            )
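            # query_contrib repeats the query (plus separator) for every
            # document in the batch when reranking; for embedding models the
            # query is counted only once per request.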
            batch_requests.append(
                SampleRequest(
                    prompt=[query_prompt] + [req[0] for req in batch],
                    prompt_len=query_contrib + sum(req[1] for req in batch),
                    expected_output_len=0,
                    request_id=request_id_prefix + str(i // batchsize),
                )
            )

        if token_mismatch_total != 0:
            logger.warning(
                "Across all generated prompts, there were %d %s tokens "
                "than expected after decoding and re-encoding. This is "
                "expected due to the imperfect nature of the sampling "
                "procedure.",
                abs(token_mismatch_total),
                "more" if token_mismatch_total > 0 else "fewer",
            )

        return batch_requests


# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------

@@ -1149,6 +1256,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):

            "sonnet",
            "random",
            "random-mm",
            "random-rerank",
            "hf",
            "custom",
            "prefix_repetition",

@@ -1292,6 +1400,14 @@ def add_dataset_parser(parser: FlexibleArgumentParser):

        default=1,
        help=("Batch size for random sampling. Only used for embeddings benchmark."),
    )
    random_group.add_argument(
        "--no-reranker",
        action="store_true",
        help=(
            "Set this if the model does not support reranking natively "
            "(e.g. an embedding model). Only used for the reranker benchmark."
        ),
    )

    # random multimodal dataset options
    random_mm_group = parser.add_argument_group(

@@ -1678,6 +1794,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:

            request_id_prefix=args.request_id_prefix,
            no_oversample=args.no_oversample,
        ),
        "random-rerank": lambda: RandomDatasetForReranking(
            random_seed=args.seed,
            dataset_path=args.dataset_path,
            disable_shuffle=args.disable_shuffle,
        ).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            input_len=args.random_input_len,
            range_ratio=args.random_range_ratio,
            request_id_prefix=args.request_id_prefix,
            batchsize=args.random_batch_size,
            is_reranker=not args.no_reranker,
        ),
        "prefix_repetition": lambda: PrefixRepetitionRandomDataset(
            random_seed=args.seed,
            dataset_path=args.dataset_path,

@@ -64,7 +64,7 @@ class StreamedResponseHandler:

class RequestFuncInput:
    """The input for the request function."""

-    prompt: str
+    prompt: str | list[str]
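    # For rerank requests the prompt is a list of [query, doc_1, ..., doc_n];
    # all other backends continue to pass a single string.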
    api_url: str
    prompt_len: int
    output_len: int

@@ -484,7 +484,7 @@ async def async_request_openai_audio(

    return output

-async def _run_openai_embeddings(
+async def _run_pooling_request(
    session: aiohttp.ClientSession,
    api_url: str,
    payload: dict[str, Any],

@@ -497,7 +497,7 @@ async def _run_openai_embeddings(

    try:
        async with session.post(url=api_url, headers=headers, json=payload) as response:
            if response.status == 200:
-                output.latency = time.perf_counter() - st
+                output.ttft = output.latency = time.perf_counter() - st
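                # Pooling endpoints return a single response body (no token
                # streaming), so time-to-first-token equals total latency.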
                data = await response.json()
                output.success = True
                output.generated_text = ""

@@ -536,7 +536,43 @@ async def async_request_openai_embeddings(

    }
    _update_headers_common(headers, request_func_input)

-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,
        headers=headers,
        pbar=pbar,
    )


async def async_request_vllm_rerank(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "vLLM score API", "rerank")

    assert (
        isinstance(request_func_input.prompt, list)
        and len(request_func_input.prompt) > 1
    )
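
    # prompt[0] carries the query; prompt[1:] are the documents to rerank.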
    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "query": request_func_input.prompt[0],
        "documents": request_func_input.prompt[1:],
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,

@@ -572,7 +608,7 @@ async def async_request_openai_embeddings_chat(

    }
    _update_headers_common(headers, request_func_input)

-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,

@@ -685,7 +721,7 @@ async def async_request_infinity_embeddings(

    }
    _update_headers_common(headers, request_func_input)

-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
        session,
        api_url,
        payload=payload,

@@ -722,6 +758,7 @@ ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {

    "infinity-embeddings": async_request_infinity_embeddings,
    "infinity-embeddings-clip": async_request_infinity_embeddings_clip,
    # (Infinity embedding server does not support vlm2vec)
    "vllm-rerank": async_request_vllm_rerank,
}

OPENAI_COMPATIBLE_BACKENDS = [