From fe3edb4cf0027c99ff37f891349ed8e6d464b02e Mon Sep 17 00:00:00 2001
From: Maximilien de Bayser
Date: Tue, 14 Oct 2025 01:25:43 -0300
Subject: [PATCH] Add support for the /rerank endpoint in vllm bench serve
 (#26602)

Signed-off-by: Max de Bayser
---
 docs/contributing/benchmarks.md              |  46 +++++++
 vllm/benchmarks/datasets.py                  | 129 +++++++++++++++++++
 vllm/benchmarks/lib/endpoint_request_func.py |  49 ++++++-
 3 files changed, 218 insertions(+), 6 deletions(-)

diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md
index 6b1eabf3d67fa..0f2c4a5d7f069 100644
--- a/docs/contributing/benchmarks.md
+++ b/docs/contributing/benchmarks.md
@@ -35,6 +35,7 @@ th {
 | Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` |
 | Random | ✅ | ✅ | `synthetic` |
 | RandomMultiModal (Image/Video) | 🟡 | 🚧 | `synthetic` |
+| RandomForReranking | ✅ | ✅ | `synthetic` |
 | Prefix Repetition | ✅ | ✅ | `synthetic` |
 | HuggingFace-VisionArena | ✅ | ✅ | `lmarena-ai/VisionArena-Chat` |
 | HuggingFace-MMVU | ✅ | ✅ | `yale-nlp/MMVU` |
@@ -878,6 +879,51 @@ vllm bench serve \
+#### Reranker Benchmark
+
+Benchmark the performance of rerank requests in vLLM.
+
+<details>
+<summary>Show more</summary>
+
+Unlike generative models, which are benchmarked through the Completions or Chat Completions API,
+you should set `--backend vllm-rerank` and `--endpoint /v1/rerank` to use the Reranker API.
+
+For reranking, the only supported dataset is `--dataset-name random-rerank`.
+
+Start the server:
+
+```bash
+vllm serve BAAI/bge-reranker-v2-m3
+```
+
+Run the benchmark:
+
+```bash
+vllm bench serve \
+  --model BAAI/bge-reranker-v2-m3 \
+  --backend vllm-rerank \
+  --endpoint /v1/rerank \
+  --dataset-name random-rerank \
+  --tokenizer BAAI/bge-reranker-v2-m3 \
+  --random-input-len 512 \
+  --num-prompts 10 \
+  --random-batch-size 5
+```
+
+For reranker models, this creates `num_prompts / random_batch_size` requests, each carrying
+`random_batch_size` "documents", where each query/document pair totals close to
+`random_input_len` tokens. In the example above, this results in 2 rerank requests with
+5 "documents" each, where each query/document pair has close to 512 tokens.
+
+Please note that the `/v1/rerank` endpoint is also supported by embedding models. If you are
+running with an embedding model, also set `--no-reranker`. Because the server then treats the
+query as an individual prompt, the benchmark sends `random_batch_size - 1` documents per
+request to account for the extra prompt (the query). The token accounting used to report
+throughput is adjusted accordingly.
+
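+Each benchmark request is a plain JSON POST to the rerank endpoint. As a rough sketch
+(assuming the server started above listens on the default `http://localhost:8000`, and using
+the `requests` library purely for illustration), one request generated by the configuration
+above is equivalent to:
+
+```python
+import requests
+
+# Illustrative payload only: the benchmark fills "query" and "documents" with
+# random token sequences so that each query/document pair stays close to
+# --random-input-len tokens.
+payload = {
+    "model": "BAAI/bge-reranker-v2-m3",
+    "query": "<random query text>",
+    "documents": ["<random document text>"] * 5,  # --random-batch-size documents
+}
+resp = requests.post("http://localhost:8000/v1/rerank", json=payload)
+print(resp.json())  # one relevance score per document
+```
+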
+</details>
+
 [](){ #performance-benchmarks }

 ## Performance Benchmarks

diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 331d31c1d0e63..d610389ddb6b0 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -572,6 +572,7 @@ class RandomDataset(BenchmarkDataset):
         # Ensure the lower bound for output length is at least 1 to
         # prevent sampling 0 tokens.
         output_low = max(output_low, 1)
+        output_high = max(output_high, 1)

         if input_low > input_high:
             raise ValueError(
@@ -638,6 +639,112 @@ class RandomDataset(BenchmarkDataset):
         return prompt, total_input_len, token_mismatch


+# -----------------------------------------------------------------------------
+# Random Dataset for Reranking Implementation (Synthetic Data)
+# -----------------------------------------------------------------------------
+
+
+class RandomDatasetForReranking(RandomDataset):
+    """
+    Random dataset specialized for the needs of scoring/reranking:
+    - requests are batches of inputs
+    - each input is a (query, document) pair
+    """
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__(**kwargs)
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        request_id_prefix: str = "",
+        range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
+        input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
+        batchsize: int = 1,
+        is_reranker: bool = True,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        n_sep_tokens = int(is_reranker)
+
+        query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len
+
+        query_lens, _, query_offsets = self.get_sampling_params(
+            1, range_ratio, query_len_param, 0, tokenizer
+        )
+
+        query_len = int(query_lens[0])
+
+        if not is_reranker:
+            assert num_requests > 1 and batchsize > 1
+            num_requests -= 1
+            batchsize -= 1
+            doc_len_param = input_len
+        else:
+            doc_len_param = input_len - query_len - n_sep_tokens
+
+        doc_lens, _, doc_offsets = self.get_sampling_params(
+            num_requests, range_ratio, doc_len_param, 0, tokenizer
+        )
+        vocab_size = tokenizer.vocab_size
+
+        query_prompt, query_input_len, token_mismatch_total = (
+            self.generate_token_sequence(
+                tokenizer=tokenizer,
+                prefix_token_ids=[],
+                prefix_len=0,
+                vocab_size=vocab_size,
+                input_len=query_len,
+                offset=int(query_offsets[0]),
+                index=0,
+            )
+        )
+
+        requests = []
+        for i in range(num_requests):
+            prompt, total_input_len, token_mismatch = self.generate_token_sequence(  # noqa: E501
+                tokenizer=tokenizer,
+                prefix_token_ids=[],
+                prefix_len=0,
+                vocab_size=vocab_size,
+                input_len=int(doc_lens[i]),
+                offset=int(doc_offsets[i]),
+                index=i + 1,
+            )
+            token_mismatch_total += token_mismatch
+            requests.append((prompt, total_input_len))
+
+        batch_requests = []
+        # Create batched requests
+        for i in range(0, num_requests, batchsize):
+            batch = requests[i : i + batchsize]
+            # For rerankers the query (plus a separator token) is concatenated
+            # with every document, so its length counts once per document.
+            query_contrib = (
+                (query_input_len + n_sep_tokens) * len(batch)
+                if is_reranker
+                else query_input_len
+            )
+            batch_requests.append(
+                SampleRequest(
+                    prompt=[query_prompt] + [req[0] for req in batch],
+                    prompt_len=query_contrib + sum(req[1] for req in batch),
+                    expected_output_len=0,
+                    request_id=request_id_prefix + str(i // batchsize),
+                )
+            )
+
+        if token_mismatch_total != 0:
+            logger.warning(
+                "Across all generated prompts, there were %d %s tokens "
+                "than expected after decoding and re-encoding. This is "
+                "expected due to the imperfect nature of the sampling "
+                "procedure.",
+                abs(token_mismatch_total),
+                "more" if token_mismatch_total > 0 else "fewer",
+            )
+
+        return batch_requests
+
+
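+# Worked example of the accounting above: sample() with num_requests=10,
+# batchsize=5, input_len=512 and is_reranker=True yields 2 SampleRequests.
+# Each prompt is a list of 6 strings (the shared query followed by 5
+# documents), and prompt_len counts the query plus one separator token once
+# per document, so each request reports roughly 5 * 512 input tokens.
+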
 # -----------------------------------------------------------------------------
 # MultiModalDataset Implementation
 # -----------------------------------------------------------------------------
@@ -1149,6 +1256,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
             "sonnet",
             "random",
             "random-mm",
+            "random-rerank",
             "hf",
             "custom",
             "prefix_repetition",
@@ -1292,6 +1400,14 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         default=1,
         help=("Batch size for random sampling. Only used for embeddings benchmark."),
     )
+    random_group.add_argument(
+        "--no-reranker",
+        action="store_true",
+        help=(
+            "Set this flag when the served model is an embedding model that "
+            "does not support reranking natively. Only used for the reranker "
+            "benchmark."
+        ),
+    )

     # random multimodal dataset options
     random_mm_group = parser.add_argument_group(
@@ -1678,6 +1794,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         ),
+        "random-rerank": lambda: RandomDatasetForReranking(
+            random_seed=args.seed,
+            dataset_path=args.dataset_path,
+            disable_shuffle=args.disable_shuffle,
+        ).sample(
+            tokenizer=tokenizer,
+            num_requests=args.num_prompts,
+            input_len=args.random_input_len,
+            range_ratio=args.random_range_ratio,
+            request_id_prefix=args.request_id_prefix,
+            batchsize=args.random_batch_size,
+            is_reranker=not args.no_reranker,
+        ),
         "prefix_repetition": lambda: PrefixRepetitionRandomDataset(
             random_seed=args.seed,
             dataset_path=args.dataset_path,
diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py
index 2e5c100a3031d..4f427a31b9ee1 100644
--- a/vllm/benchmarks/lib/endpoint_request_func.py
+++ b/vllm/benchmarks/lib/endpoint_request_func.py
@@ -64,7 +64,7 @@ class StreamedResponseHandler:
 class RequestFuncInput:
     """The input for the request function."""

-    prompt: str
+    prompt: str | list[str]
     api_url: str
     prompt_len: int
     output_len: int
@@ -484,7 +484,7 @@ async def async_request_openai_audio(
     return output


-async def _run_openai_embeddings(
+async def _run_pooling_request(
     session: aiohttp.ClientSession,
     api_url: str,
     payload: dict[str, Any],
@@ -497,7 +497,7 @@ async def _run_openai_embeddings(
     try:
         async with session.post(url=api_url, headers=headers, json=payload) as response:
             if response.status == 200:
-                output.latency = time.perf_counter() - st
+                output.ttft = output.latency = time.perf_counter() - st
                 data = await response.json()
                 output.success = True
                 output.generated_text = ""
@@ -536,7 +536,43 @@ async def async_request_openai_embeddings(
     }
     _update_headers_common(headers, request_func_input)

-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
         session,
         api_url,
         payload=payload,
         headers=headers,
         pbar=pbar,
     )
+
+
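+# async_request_vllm_rerank expects request_func_input.prompt to be a list of
+# strings whose first element is the query and whose remaining elements are
+# the documents, which is what RandomDatasetForReranking.sample() produces.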
+async def async_request_vllm_rerank(
+    request_func_input: RequestFuncInput,
+    session: aiohttp.ClientSession,
+    pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    _validate_api_url(api_url, "vLLM score API", "rerank")
+
+    assert (
+        isinstance(request_func_input.prompt, list)
+        and len(request_func_input.prompt) > 1
+    )
+
+    payload = {
+        "model": request_func_input.model_name
+        if request_func_input.model_name
+        else request_func_input.model,
+        "query": request_func_input.prompt[0],
+        "documents": request_func_input.prompt[1:],
+    }
+
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+    }
+    _update_headers_common(headers, request_func_input)
+
+    return await _run_pooling_request(
+        session,
+        api_url,
+        payload=payload,
+        headers=headers,
+        pbar=pbar,
+    )
@@ -572,7 +608,7 @@ async def async_request_openai_embeddings_chat(
     }
     _update_headers_common(headers, request_func_input)

-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
         session,
         api_url,
         payload=payload,
@@ -685,7 +721,7 @@ async def async_request_infinity_embeddings(
     }
     _update_headers_common(headers, request_func_input)

-    return await _run_openai_embeddings(
+    return await _run_pooling_request(
         session,
         api_url,
         payload=payload,
@@ -722,6 +758,7 @@ ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
     "infinity-embeddings": async_request_infinity_embeddings,
     "infinity-embeddings-clip": async_request_infinity_embeddings_clip,
     # (Infinity embedding server does not support vlm2vec)
+    "vllm-rerank": async_request_vllm_rerank,
 }

 OPENAI_COMPATIBLE_BACKENDS = [