diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py
index b194e9b74d874..1f30d8cf1e8cc 100644
--- a/tests/entrypoints/openai/test_chat_error.py
+++ b/tests/entrypoints/openai/test_chat_error.py
@@ -76,6 +76,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}
 
diff --git a/tests/entrypoints/openai/test_completion_error.py b/tests/entrypoints/openai/test_completion_error.py
index ca56cc2ddb6a7..6643aa471321b 100644
--- a/tests/entrypoints/openai/test_completion_error.py
+++ b/tests/entrypoints/openai/test_completion_error.py
@@ -73,6 +73,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}
 
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 2befa40d636da..69d7b1ceedf59 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -396,6 +396,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}
 
diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py
index 25af55baa91f4..224e5d741024b 100644
--- a/tests/v1/engine/test_async_llm.py
+++ b/tests/v1/engine/test_async_llm.py
@@ -11,6 +11,13 @@ from vllm import SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ErrorResponse,
+)
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 from vllm.inputs import PromptType
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
@@ -484,6 +491,60 @@ async def test_dp_rank_argument():
                 pass
 
 
+@pytest.mark.asyncio(scope="module")
+async def test_header_dp_rank_argument():
+    with ExitStack() as after:
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        after.callback(engine.shutdown)
+
+        MODEL_NAME = "test-model"
+        BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+
+        # Create models first
+        models = OpenAIServingModels(
+            engine_client=engine,
+            base_model_paths=BASE_MODEL_PATHS,
+        )
+
+        # Create serving chat instance
+        serving_chat = OpenAIServingChat(
+            engine_client=engine,
+            models=models,
+            response_role="assistant",
+            chat_template=None,
+            chat_template_content_format="auto",
+            request_logger=None,
+        )
+        # Create a chat completion request
+        req = ChatCompletionRequest(
+            model=MODEL_NAME,
+            messages=[{"role": "user", "content": TEXT_PROMPT}],
+            max_tokens=100,
+            temperature=1.0,
+            seed=33,
+        )
+        # Test 1: Valid DP rank (0)
+        mock_raw_request = MagicMock()
+        mock_raw_request.headers = {"X-data-parallel-rank": "0"}
+        mock_raw_request.state = MagicMock()
+
+        # Should succeed with valid rank
+        response = await serving_chat.create_chat_completion(req, mock_raw_request)
+        assert isinstance(response, ChatCompletionResponse), (
+            "Expected a ChatCompletionResponse for valid DP rank"
+        )
+
+        # Test 2: Out-of-range DP rank (1)
+        mock_raw_request.headers = {"X-data-parallel-rank": "1"}
+
+        # should return ErrorResponse for out-of-range rank
+        response2 = await serving_chat.create_chat_completion(req, mock_raw_request)
+        assert isinstance(response2, ErrorResponse), (
+            "Expected an ErrorResponse for out-of-range DP rank"
+        )
+
+
 @pytest.mark.asyncio
 async def test_check_health():
     """Test that check_health returns normally for healthy engine
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 95df373502bfd..04967cbe268dd 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -381,6 +381,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=request.priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         generator = self.engine_client.generate(
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 1be0afc8c74e5..265ca9915e5db 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -230,6 +230,7 @@ class OpenAIServingCompletion(OpenAIServing):
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=request.priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         generator = self.engine_client.generate(
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 5f7cfaa53ec18..b9771963c6d4c 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1231,6 +1231,7 @@ class OpenAIServing:
         lora_request: LoRARequest | None,
         trace_headers: Mapping[str, str] | None,
         priority: int,
+        data_parallel_rank: int | None = None,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
         """Use the Processor to process inputs for AsyncLLM."""
         tokenization_kwargs: dict[str, Any] = {}
@@ -1246,6 +1247,7 @@ class OpenAIServing:
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         return engine_request, tokenization_kwargs
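
For anyone who wants to try the new behavior end to end (outside the unit tests above), here is a minimal client-side sketch, not part of this patch: it sends a chat completion request with the `X-data-parallel-rank` header that the new test exercises. The base URL, port, and model name are placeholders for an OpenAI-compatible vLLM server assumed to be already running locally with data parallelism enabled.

```python
# Minimal sketch (not part of this PR): pin a chat completion request to
# data-parallel rank 0 via the X-data-parallel-rank header.
# "http://localhost:8000" and "test-model" are placeholders for a locally
# running OpenAI-compatible vLLM server started with data parallelism enabled.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",   # placeholder endpoint
    headers={"X-data-parallel-rank": "0"},          # header exercised by the new test
    json={
        "model": "test-model",                      # placeholder model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 16,
    },
    timeout=60,
)

# An out-of-range rank (e.g. "1" on a single-rank deployment) should produce
# an error response instead of a completion, mirroring Test 2 above.
print(resp.status_code, resp.json())
```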