From 7f783b8a4ac858f0188d5ac3755a0129270b61f1 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 14 Oct 2025 22:39:55 +0000
Subject: [PATCH] merge

---
 tests/entrypoints/openai/test_serving_chat.py | 64 ++++++++++++++++++-
 vllm/entrypoints/openai/api_server.py         | 18 ++++++
 vllm/entrypoints/openai/serving_chat.py       |  4 ++
 vllm/entrypoints/openai/serving_completion.py |  5 ++
 vllm/entrypoints/openai/serving_engine.py     | 15 +++++
 5 files changed, 104 insertions(+), 2 deletions(-)

diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index d1367b4eeaf62..2763ecb68680d 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -649,5 +649,65 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     req.cache_salt = "test_salt"
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
-    assert engine_prompt.get("cache_salt") == "test_salt"
+    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
+
+
+@pytest.mark.asyncio
+async def test_serving_chat_data_parallel_rank_extraction():
+    """Test that data_parallel_rank is extracted from the header and passed to the engine."""
+    mock_engine = MagicMock(spec=MQLLMEngineClient)
+    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
+    mock_engine.errored = False
+
+    models = OpenAIServingModels(engine_client=mock_engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
+    serving_chat = OpenAIServingChat(mock_engine,
+                                     MockModelConfig(),
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    # Case 1: data_parallel_rank is present in the header
+    req = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 1+1?"
+        }],
+    )
+
+    # Mock request carrying the X-data-parallel-rank header
+    mock_raw_request = MagicMock()
+    mock_raw_request.headers = {"X-data-parallel-rank": "2"}
+    mock_raw_request.state = MagicMock()
+
+    with suppress(Exception):
+        await serving_chat.create_chat_completion(req, mock_raw_request)
+
+    # Verify that data_parallel_rank was passed to engine.generate
+    assert 'data_parallel_rank' in mock_engine.generate.call_args.kwargs
+    assert mock_engine.generate.call_args.kwargs['data_parallel_rank'] == 2
+
+    # Case 2: data_parallel_rank is not present (defaults to None)
+    req_no_dp = ChatCompletionRequest(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "what is 2+2?"
+        }],
+    )
+
+    # Mock request with no header
+    mock_raw_request_no_dp = MagicMock()
+    mock_raw_request_no_dp.headers = {}
+    mock_raw_request_no_dp.state = MagicMock()
+
+    with suppress(Exception):
+        await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)
+
+    # Verify that data_parallel_rank defaults to None
+    assert 'data_parallel_rank' in mock_engine.generate.call_args.kwargs
+    assert mock_engine.generate.call_args.kwargs['data_parallel_rank'] is None
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index fd80ba7a9afca..056badbd46eba 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -386,6 +386,24 @@ async def get_server_load_metrics(request: Request):
     return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
 
+
+@router.get("/get_server_info")
+async def get_server_info(raw_request: Request):
+    """Returns server information, including the DP size, for the router"""
+    config = raw_request.app.state.vllm_config
+
+    # Extract dp_size from parallel_config, defaulting to 1
+    dp_size = 1
+    if hasattr(config, 'parallel_config') and hasattr(config.parallel_config, 'data_parallel_size'):
+        dp_size = config.parallel_config.data_parallel_size
+
+    server_info = {
+        "vllm_config": str(config),
+        "dp_size": dp_size
+    }
+    return JSONResponse(content=server_info)
+
+
 @router.get("/ping", response_class=Response)
 @router.post("/ping", response_class=Response)
 async def ping(raw_request: Request) -> Response:
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 26027112eb589..68b4362c68c8b 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
         if raw_request:
             raw_request.state.request_metadata = request_metadata
 
+        # Extract data_parallel_rank from the header (a router can inject it)
+        data_parallel_rank = self._get_data_parallel_rank(raw_request)
+
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
                     priority=request.priority,
                     prompt_text=prompt_text,
                     tokenization_kwargs=tokenization_kwargs,
+                    data_parallel_rank=data_parallel_rank,
                 )
 
                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 7cbe9c69435c3..9c98728d5370b 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -140,6 +140,10 @@ class OpenAIServingCompletion(OpenAIServing):
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
+        # Extract data_parallel_rank from the header (a router can inject it)
+        data_parallel_rank = self._get_data_parallel_rank(raw_request)
+
+        # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -223,6 +227,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     priority=request.priority,
                     prompt_text=prompt_text,
                     tokenization_kwargs=tokenization_kwargs,
+                    data_parallel_rank=data_parallel_rank,
                 )
 
                 generators.append(generator)
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 3965d2dac0887..792ee9e159a06 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1298,6 +1298,21 @@ class OpenAIServing:
 
         return raw_request.headers.get("X-Request-Id", default)
 
+    @staticmethod
+    def _get_data_parallel_rank(raw_request: Optional[Request]) -> Optional[int]:
+        """Pulls the data parallel rank from the X-data-parallel-rank header, if provided"""
+        if raw_request is None:
+            return None
+
+        rank_str = raw_request.headers.get("X-data-parallel-rank")
+        if rank_str is None:
+            return None
+
+        try:
+            return int(rank_str)
+        except ValueError:
+            return None
+
     @staticmethod
     def _get_decoded_token(
         logprob: Logprob,
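Usage sketch (illustrative, not part of the patch): with this change applied, a router can query the new /get_server_info endpoint for the server's data parallel size, then pin individual requests to a rank by injecting the X-data-parallel-rank header that OpenAIServing._get_data_parallel_rank reads. The base URL and model name below are placeholder assumptions, and hashing a session id is only one possible rank-selection policy.

import requests

BASE_URL = "http://localhost:8000"  # assumed server address

# Ask the server how many data parallel ranks it runs (new endpoint).
info = requests.get(f"{BASE_URL}/get_server_info").json()
dp_size = info["dp_size"]

# Route this request to a specific rank via the new header. If the
# header is missing or not an integer, _get_data_parallel_rank returns
# None and the engine chooses the rank itself.
rank = hash("session-123") % dp_size  # illustrative routing policy
resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    headers={"X-data-parallel-rank": str(rank)},
    json={
        "model": "my-model",  # placeholder model name
        "messages": [{"role": "user", "content": "what is 1+1?"}],
    },
)
print(resp.json()["choices"][0]["message"]["content"])

Omitting the header preserves the old behavior: the server passes data_parallel_rank=None to engine.generate, exactly as the second test case asserts.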