Add support for the /rerank endpoint in vllm bench serve (#26602)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Maximilien de Bayser 2025-10-14 01:25:43 -03:00 committed by GitHub
parent 29350922c6
commit fe3edb4cf0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 218 additions and 6 deletions

View File

@ -35,6 +35,7 @@ th {
| Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` |
| Random | ✅ | ✅ | `synthetic` |
| RandomMultiModal (Image/Video) | 🟡 | 🚧 | `synthetic` |
| RandomForReranking | ✅ | ✅ | `synthetic` |
| Prefix Repetition | ✅ | ✅ | `synthetic` |
| HuggingFace-VisionArena | ✅ | ✅ | `lmarena-ai/VisionArena-Chat` |
| HuggingFace-MMVU | ✅ | ✅ | `yale-nlp/MMVU` |
@ -878,6 +879,51 @@ vllm bench serve \
</details>
#### Reranker Benchmark
Benchmark the performance of rerank requests in vLLM.
<details class="admonition abstract" markdown="1">
<summary>Show more</summary>
Unlike generative models which use Completions API or Chat Completions API,
you should set `--backend vllm-rerank` and `--endpoint /v1/rerank` to use the Reranker API.
For reranking, the only supported dataset is `--dataset-name random-rerank`
Start the server:
```bash
vllm serve BAAI/bge-reranker-v2-m3
```
Run the benchmark:
```bash
vllm bench serve \
--model BAAI/bge-reranker-v2-m3 \
--backend vllm-rerank \
--endpoint /v1/rerank \
--dataset-name random-rerank \
--tokenizer BAAI/bge-reranker-v2-m3 \
--random-input-len 512 \
--num-prompts 10 \
--random-batch-size 5
```
For reranker models, this will create `num_prompts / random_batch_size` requests with
`random_batch_size` "documents" where each one has close to `random_input_len` tokens.
In the example above, this results in 2 rerank requests with 5 "documents" each where
each document has close to 512 tokens.
Please note that the `/v1/rerank` is also supported by embedding models. So if you're running
with an embedding model, also set `--no_reranker`. Because in this case the query is
treated as a individual prompt by the server, here we send `random_batch_size - 1` documents
to account for the extra prompt which is the query. The token accounting to report the
throughput numbers correctly is also adjusted.
</details>
[](){ #performance-benchmarks }
## Performance Benchmarks

View File

@ -572,6 +572,7 @@ class RandomDataset(BenchmarkDataset):
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low = max(output_low, 1)
output_high = max(output_high, 1)
if input_low > input_high:
raise ValueError(
@ -638,6 +639,112 @@ class RandomDataset(BenchmarkDataset):
return prompt, total_input_len, token_mismatch
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class RandomDatasetForReranking(RandomDataset):
"""
Random dataset specialized for the needs of scoring:
- Batches of inputs
- Inputs composed of pairs
"""
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
batchsize: int = 1,
is_reranker: bool = True,
**kwargs,
) -> list[SampleRequest]:
n_sep_tokens = int(is_reranker)
query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len
query_lens, _, query_offsets = self.get_sampling_params(
1, range_ratio, query_len_param, 0, tokenizer
)
query_len = int(query_lens[0])
if not is_reranker:
assert num_requests > 1 and batchsize > 1
num_requests -= 1
batchsize -= 1
doc_len_param = input_len
else:
doc_len_param = input_len - query_len - n_sep_tokens
doc_lens, _, doc_offsets = self.get_sampling_params(
num_requests, range_ratio, doc_len_param, 0, tokenizer
)
vocab_size = tokenizer.vocab_size
query_prompt, query_input_len, token_mismatch_total = (
self.generate_token_sequence(
tokenizer=tokenizer,
prefix_token_ids=[],
prefix_len=0,
vocab_size=vocab_size,
input_len=query_len,
offset=int(query_offsets[0]),
index=0,
)
)
requests = []
for i in range(num_requests):
prompt, total_input_len, token_mismatch = self.generate_token_sequence( # noqa: E501
tokenizer=tokenizer,
prefix_token_ids=[],
prefix_len=0,
vocab_size=vocab_size,
input_len=int(doc_lens[i]),
offset=int(doc_offsets[i]),
index=i + 1,
)
token_mismatch_total += token_mismatch
requests.append((prompt, total_input_len))
batch_requests = []
# Create batched requests
for i in range(0, num_requests, batchsize):
batch = requests[i : i + batchsize]
query_contrib = (
(query_input_len + n_sep_tokens) * len(batch)
if is_reranker
else query_input_len
)
batch_requests.append(
SampleRequest(
prompt=[query_prompt] + [req[0] for req in batch],
prompt_len=query_contrib + sum(req[1] for req in batch),
expected_output_len=0,
request_id=request_id_prefix + str(i // batchsize),
)
)
if token_mismatch_total != 0:
logger.warning(
"Across all generated prompts, there were %d %s tokens "
"than expected after decoding and re-encoding. This is "
"expected due to the imperfect nature of the sampling "
"procedure.",
abs(token_mismatch_total),
"more" if token_mismatch_total > 0 else "fewer",
)
return batch_requests
# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------
@ -1149,6 +1256,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"sonnet",
"random",
"random-mm",
"random-rerank",
"hf",
"custom",
"prefix_repetition",
@ -1292,6 +1400,14 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
default=1,
help=("Batch size for random sampling. Only used for embeddings benchmark."),
)
random_group.add_argument(
"--no-reranker",
action="store_true",
help=(
"Whether the model supports reranking natively."
" Only used for reranker benchmark."
),
)
# random multimodal dataset options
random_mm_group = parser.add_argument_group(
@ -1678,6 +1794,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample,
),
"random-rerank": lambda: RandomDatasetForReranking(
random_seed=args.seed,
dataset_path=args.dataset_path,
disable_shuffle=args.disable_shuffle,
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
input_len=args.random_input_len,
range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
batchsize=args.random_batch_size,
is_reranker=not args.no_reranker,
),
"prefix_repetition": lambda: PrefixRepetitionRandomDataset(
random_seed=args.seed,
dataset_path=args.dataset_path,

View File

@ -64,7 +64,7 @@ class StreamedResponseHandler:
class RequestFuncInput:
"""The input for the request function."""
prompt: str
prompt: str | list[str]
api_url: str
prompt_len: int
output_len: int
@ -484,7 +484,7 @@ async def async_request_openai_audio(
return output
async def _run_openai_embeddings(
async def _run_pooling_request(
session: aiohttp.ClientSession,
api_url: str,
payload: dict[str, Any],
@ -497,7 +497,7 @@ async def _run_openai_embeddings(
try:
async with session.post(url=api_url, headers=headers, json=payload) as response:
if response.status == 200:
output.latency = time.perf_counter() - st
output.ttft = output.latency = time.perf_counter() - st
data = await response.json()
output.success = True
output.generated_text = ""
@ -536,7 +536,43 @@ async def async_request_openai_embeddings(
}
_update_headers_common(headers, request_func_input)
return await _run_openai_embeddings(
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
async def async_request_vllm_rerank(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "vLLM score API", "rerank")
assert (
isinstance(request_func_input.prompt, list)
and len(request_func_input.prompt) > 1
)
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"query": request_func_input.prompt[0],
"documents": request_func_input.prompt[1:],
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
@ -572,7 +608,7 @@ async def async_request_openai_embeddings_chat(
}
_update_headers_common(headers, request_func_input)
return await _run_openai_embeddings(
return await _run_pooling_request(
session,
api_url,
payload=payload,
@ -685,7 +721,7 @@ async def async_request_infinity_embeddings(
}
_update_headers_common(headers, request_func_input)
return await _run_openai_embeddings(
return await _run_pooling_request(
session,
api_url,
payload=payload,
@ -722,6 +758,7 @@ ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
"infinity-embeddings": async_request_infinity_embeddings,
"infinity-embeddings-clip": async_request_infinity_embeddings_clip,
# (Infinity embedding server does not support vlm2vec)
"vllm-rerank": async_request_vllm_rerank,
}
OPENAI_COMPATIBLE_BACKENDS = [