[Bugfix] fix DP-aware routing in OpenAI API requests (#29002)

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
Author: inkcherry, 2025-12-19 01:50:42 +08:00 (committed by GitHub)
Parent: 686cbaac64
Commit: 500f26e6d3
7 changed files with 68 additions and 0 deletions


@@ -76,6 +76,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}


@@ -73,6 +73,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}


@@ -396,6 +396,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}

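The three hunks above all patch the same kind of test double: a fake stand-in for OpenAIServing._process_inputs (the method changed at the bottom of this diff), which must now accept the new data_parallel_rank keyword so the patched call does not fail with a TypeError. A minimal sketch of that pattern follows; the parameter names ahead of lora_request (request_id, engine_prompt, sampling_params) and the keyword-only marker are assumptions for illustration, not copied from the diff.

# Hypothetical test double, assuming _process_inputs takes request_id,
# engine_prompt and sampling_params positionally and the rest keyword-only.
async def _fake_process_inputs(
    request_id,
    engine_prompt,
    sampling_params,
    *,
    lora_request,
    trace_headers,
    priority,
    data_parallel_rank,  # newly threaded-through argument from this PR
):
    # Mirror the real return shape: (engine request, tokenization kwargs).
    return dict(engine_prompt), {}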

@@ -11,6 +11,13 @@ from vllm import SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ErrorResponse,
+)
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 from vllm.inputs import PromptType
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
@@ -484,6 +491,60 @@ async def test_dp_rank_argument():
                 pass
 
 
+@pytest.mark.asyncio(scope="module")
+async def test_header_dp_rank_argument():
+    with ExitStack() as after:
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        after.callback(engine.shutdown)
+
+        MODEL_NAME = "test-model"
+        BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+
+        # Create models first
+        models = OpenAIServingModels(
+            engine_client=engine,
+            base_model_paths=BASE_MODEL_PATHS,
+        )
+
+        # Create serving chat instance
+        serving_chat = OpenAIServingChat(
+            engine_client=engine,
+            models=models,
+            response_role="assistant",
+            chat_template=None,
+            chat_template_content_format="auto",
+            request_logger=None,
+        )
+
+        # Create a chat completion request
+        req = ChatCompletionRequest(
+            model=MODEL_NAME,
+            messages=[{"role": "user", "content": TEXT_PROMPT}],
+            max_tokens=100,
+            temperature=1.0,
+            seed=33,
+        )
+
+        # Test 1: Valid DP rank (0)
+        mock_raw_request = MagicMock()
+        mock_raw_request.headers = {"X-data-parallel-rank": "0"}
+        mock_raw_request.state = MagicMock()
+
+        # Should succeed with valid rank
+        response = await serving_chat.create_chat_completion(req, mock_raw_request)
+        assert isinstance(response, ChatCompletionResponse), (
+            "Expected a ChatCompletionResponse for valid DP rank"
+        )
+
+        # Test 2: Out-of-range DP rank (1)
+        mock_raw_request.headers = {"X-data-parallel-rank": "1"}
+
+        # Should return ErrorResponse for out-of-range rank
+        response2 = await serving_chat.create_chat_completion(req, mock_raw_request)
+        assert isinstance(response2, ErrorResponse), (
+            "Expected an ErrorResponse for out-of-range DP rank"
+        )
+
+
 @pytest.mark.asyncio
 async def test_check_health():
     """Test that check_health returns normally for healthy engine

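The new test exercises the header path by putting X-data-parallel-rank on a mocked raw request object. Over HTTP the same routing hint is simply an extra header on an ordinary OpenAI-compatible call; a minimal client-side sketch, assuming a DP-enabled vLLM server at http://localhost:8000/v1 and a placeholder model name:

# Client-side sketch only; base_url, api_key and model are placeholders for a
# locally running, data-parallel vLLM OpenAI-compatible server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="test-model",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=100,
    # Ask the server to route this request to data-parallel rank 0; per the
    # test above, an out-of-range rank yields an error response instead.
    extra_headers={"X-data-parallel-rank": "0"},
)
print(response.choices[0].message.content)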

@@ -381,6 +381,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=request.priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         generator = self.engine_client.generate(


@@ -230,6 +230,7 @@ class OpenAIServingCompletion(OpenAIServing):
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=request.priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         generator = self.engine_client.generate(


@@ -1231,6 +1231,7 @@ class OpenAIServing:
         lora_request: LoRARequest | None,
         trace_headers: Mapping[str, str] | None,
         priority: int,
+        data_parallel_rank: int | None = None,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
         """Use the Processor to process inputs for AsyncLLM."""
         tokenization_kwargs: dict[str, Any] = {}
@@ -1246,6 +1247,7 @@ class OpenAIServing:
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         return engine_request, tokenization_kwargs
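
Taken together, the serving-layer hunks thread data_parallel_rank from the incoming request through _process_inputs and into the engine request, which is what makes the header-based routing in the test take effect. The validation behaviour the test expects (rank 0 accepted, rank 1 rejected when only one DP rank exists) can be pictured with the standalone sketch below; the helper name, the dp_size argument and the ValueError are illustrative assumptions rather than vLLM's actual code.

# Illustrative sketch of header-driven DP routing; not the vLLM implementation.
from collections.abc import Mapping


def resolve_dp_rank(headers: Mapping[str, str], dp_size: int) -> int | None:
    """Parse X-data-parallel-rank and validate it against the DP world size."""
    raw = headers.get("X-data-parallel-rank")
    if raw is None:
        # No header: leave the choice of rank to the engine's own scheduling.
        return None
    rank = int(raw)
    if not 0 <= rank < dp_size:
        # Mirrors the error the new test expects for rank 1 on a dp_size=1 engine.
        raise ValueError(f"data_parallel_rank {rank} is out of range [0, {dp_size})")
    return rank


# The resolved value would then be forwarded the way the hunks above do, e.g.
# _process_inputs(..., data_parallel_rank=resolve_dp_rank(headers, dp_size)).
assert resolve_dp_rank({"X-data-parallel-rank": "0"}, dp_size=1) == 0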