[Bugfix] fix DP-aware routing in OpenAI API requests (#29002)
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
parent 686cbaac64
commit 500f26e6d3
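For context, this fix lets a per-request data-parallel rank chosen at the OpenAI API layer actually reach the engine. The request header name "X-data-parallel-rank" is taken from the test added in this commit; the sketch below is illustrative only, with a placeholder URL and model name for a locally running vLLM server.

# Minimal sketch: pin an OpenAI-compatible chat request to DP rank 0.
# Requires a running vLLM OpenAI server; URL and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"X-data-parallel-rank": "0"},
    json={
        "model": "test-model",
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 16,
    },
)
print(resp.status_code, resp.json())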
@@ -76,6 +76,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}
 
@@ -73,6 +73,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}
 
@@ -396,6 +396,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
         lora_request,
         trace_headers,
         priority,
+        data_parallel_rank,
     ):
         return dict(engine_prompt), {}
 
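The three hunks above all update test doubles for OpenAIServing._process_inputs: each stub now accepts the extra data_parallel_rank argument so its signature keeps matching the patched method shown further down. A minimal sketch of such a stub; the name and leading parameters are assumptions for illustration, only the last four parameters and the return shape come from the diff.

# Sketch of a test stub mirroring the patched _process_inputs signature.
async def _fake_process_inputs(  # name and leading parameters are illustrative
    request_id,
    engine_prompt,
    sampling_params,
    lora_request,
    trace_headers,
    priority,
    data_parallel_rank,  # newly accepted so DP-aware callers keep working
):
    # Echo the prompt back with empty tokenization kwargs, matching the
    # (engine_request, tokenization_kwargs) pair returned by the real method.
    return dict(engine_prompt), {}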
@@ -11,6 +11,13 @@ from vllm import SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ErrorResponse,
+)
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
 from vllm.inputs import PromptType
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
@@ -484,6 +491,60 @@ async def test_dp_rank_argument():
         pass
 
 
+@pytest.mark.asyncio(scope="module")
+async def test_header_dp_rank_argument():
+    with ExitStack() as after:
+        with set_default_torch_num_threads(1):
+            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
+        after.callback(engine.shutdown)
+
+        MODEL_NAME = "test-model"
+        BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
+
+        # Create models first
+        models = OpenAIServingModels(
+            engine_client=engine,
+            base_model_paths=BASE_MODEL_PATHS,
+        )
+
+        # Create serving chat instance
+        serving_chat = OpenAIServingChat(
+            engine_client=engine,
+            models=models,
+            response_role="assistant",
+            chat_template=None,
+            chat_template_content_format="auto",
+            request_logger=None,
+        )
+        # Create a chat completion request
+        req = ChatCompletionRequest(
+            model=MODEL_NAME,
+            messages=[{"role": "user", "content": TEXT_PROMPT}],
+            max_tokens=100,
+            temperature=1.0,
+            seed=33,
+        )
+        # Test 1: Valid DP rank (0)
+        mock_raw_request = MagicMock()
+        mock_raw_request.headers = {"X-data-parallel-rank": "0"}
+        mock_raw_request.state = MagicMock()
+
+        # Should succeed with valid rank
+        response = await serving_chat.create_chat_completion(req, mock_raw_request)
+        assert isinstance(response, ChatCompletionResponse), (
+            "Expected a ChatCompletionResponse for valid DP rank"
+        )
+
+        # Test 2: Out-of-range DP rank (1)
+        mock_raw_request.headers = {"X-data-parallel-rank": "1"}
+
+        # should return ErrorResponse for out-of-range rank
+        response2 = await serving_chat.create_chat_completion(req, mock_raw_request)
+        assert isinstance(response2, ErrorResponse), (
+            "Expected an ErrorResponse for out-of-range DP rank"
+        )
+
+
 @pytest.mark.asyncio
 async def test_check_health():
     """Test that check_health returns normally for healthy engine
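The new test drives the rank through the "X-data-parallel-rank" header on the raw request and expects an ErrorResponse once the rank exceeds the data-parallel world size (1 for the single-rank test engine). The extraction and validation logic itself is not part of this diff; the following is only a rough, self-contained sketch of that kind of header handling, with all names assumed.

# Illustrative only: parse and validate a DP rank supplied via header.
# The header name comes from the test above; everything else is an assumption.
def parse_dp_rank(headers: dict[str, str], data_parallel_size: int) -> int | None:
    raw = headers.get("X-data-parallel-rank")
    if raw is None:
        return None  # no pinning requested; let the engine choose a rank
    rank = int(raw)
    if not 0 <= rank < data_parallel_size:
        # Mirrors the test's expectation: an out-of-range rank is rejected
        # (surfaced to the client as an ErrorResponse).
        raise ValueError(f"data_parallel_rank {rank} out of range [0, {data_parallel_size})")
    return rank

# With data_parallel_size == 1 (as in the test engine), "0" is accepted;
# "1" would raise and be reported back as an error.
assert parse_dp_rank({"X-data-parallel-rank": "0"}, 1) == 0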
@@ -381,6 +381,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=request.priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         generator = self.engine_client.generate(
@@ -230,6 +230,7 @@ class OpenAIServingCompletion(OpenAIServing):
             lora_request=lora_request,
             trace_headers=trace_headers,
             priority=request.priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         generator = self.engine_client.generate(
@@ -1231,6 +1231,7 @@ class OpenAIServing:
         lora_request: LoRARequest | None,
         trace_headers: Mapping[str, str] | None,
         priority: int,
+        data_parallel_rank: int | None = None,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
         """Use the Processor to process inputs for AsyncLLM."""
         tokenization_kwargs: dict[str, Any] = {}
@@ -1246,6 +1247,7 @@
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )
         return engine_request, tokenization_kwargs
 
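Taken together, the last two hunks add an optional data_parallel_rank parameter to the shared OpenAIServing._process_inputs helper and forward it into processor.process_inputs, which is what lets a rank chosen at the API layer survive into the engine core request. A self-contained sketch of that pattern follows; the classes are stand-ins, not vLLM's real ones.

# Sketch of the fix's shape: an optional, defaulted parameter that is
# forwarded verbatim from the serving helper into the processor call.
from collections.abc import Mapping
from typing import Any


class _Processor:
    def process_inputs(self, **kwargs: Any) -> dict[str, Any]:
        return kwargs  # stand-in for building an EngineCoreRequest


class _Serving:
    def __init__(self) -> None:
        self.processor = _Processor()

    def _process_inputs(
        self,
        prompt: str,
        lora_request: Any | None,
        trace_headers: Mapping[str, str] | None,
        priority: int,
        data_parallel_rank: int | None = None,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        tokenization_kwargs: dict[str, Any] = {}
        engine_request = self.processor.process_inputs(
            prompt=prompt,
            tokenization_kwargs=tokenization_kwargs,
            trace_headers=trace_headers,
            priority=priority,
            data_parallel_rank=data_parallel_rank,  # the forwarding this bugfix adds
        )
        return engine_request, tokenization_kwargs


# Usage: the rank passed by the caller ends up in the engine request.
print(_Serving()._process_inputs("hi", None, None, 0, data_parallel_rank=0)[0])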