diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 1a31bdbfccb34..043b75cc5d385 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -384,3 +384,25 @@ async def test_delayed_generator(async_engine, stop): assert final_output is not None assert len(final_output.outputs[0].token_ids) == 10 assert final_output.finished + + +@pytest.mark.asyncio(scope="module") +async def test_invalid_argument(async_engine): + scheduler_config = await async_engine.get_scheduler_config() + + if scheduler_config.num_scheduler_steps != 1: + pytest.skip("no need to test this one with multistep") + + sampling_params = SamplingParams( + temperature=0, + min_tokens=10, + max_tokens=10, + ) + + # Targeting specific DP rank only supported in v1 multi-instance DP + with pytest.raises(ValueError): + async for _ in async_engine.generate("test", + sampling_params, + request_id=uid(), + data_parallel_rank=0): + pass diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index 957d50d0d9d85..a65fc35e0ffb2 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -250,3 +250,32 @@ async def test_customize_loggers(monkeypatch): assert len(engine.stat_loggers) == 1 assert len(engine.stat_loggers[0]) == 1 engine.stat_loggers[0][0].log.assert_called_once() + + +@pytest.mark.asyncio(scope="module") +async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) + after.callback(engine.shutdown) + + sampling_params = SamplingParams(max_tokens=100, + output_kind=RequestOutputKind.DELTA, + temperature=1.0, + seed=33) + + # Test with valid DP rank. + async for _ in engine.generate(request_id="request-34", + prompt=TEXT_PROMPT, + sampling_params=sampling_params, + data_parallel_rank=0): + pass + + # Test with out-of-range DP rank. + with pytest.raises(ValueError): + async for _ in engine.generate(request_id="request-35", + prompt=TEXT_PROMPT, + sampling_params=sampling_params, + data_parallel_rank=1): + pass diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index 53242180b21ef..075ceb257ab70 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -29,12 +29,14 @@ if not current_platform.supports_v1(engine_args.create_model_config()): allow_module_level=True) -async def generate(engine: AsyncLLM, - request_id: str, - prompt: PromptType, - output_kind: RequestOutputKind, - max_tokens: int, - prompt_logprobs: Optional[int] = None) -> tuple[int, str]: +async def generate( + engine: AsyncLLM, + request_id: str, + prompt: PromptType, + output_kind: RequestOutputKind, + max_tokens: int, + prompt_logprobs: Optional[int] = None, + data_parallel_rank: Optional[int] = None) -> tuple[int, str]: # Ensure generate doesn't complete too fast for cancellation test. await asyncio.sleep(0.2) @@ -46,7 +48,8 @@ async def generate(engine: AsyncLLM, prompt_logprobs=prompt_logprobs) async for out in engine.generate(request_id=request_id, prompt=prompt, - sampling_params=sampling_params): + sampling_params=sampling_params, + data_parallel_rank=data_parallel_rank): num_tokens = len(out.outputs[0].token_ids) if output_kind == RequestOutputKind.DELTA: @@ -89,8 +92,12 @@ async def test_load(output_kind: RequestOutputKind, for request_id in request_ids: tasks.append( asyncio.create_task( - generate(engine, request_id, prompt, output_kind, - NUM_EXPECTED_TOKENS))) + generate(engine, + request_id, + prompt, + output_kind, + NUM_EXPECTED_TOKENS, + data_parallel_rank=0))) # Confirm that we got all the EXPECTED tokens from the requests. done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 59971f5d65afa..72020a8ccf96b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -494,6 +494,10 @@ class _AsyncLLMEngine(LLMEngine): if arrival_time is None: arrival_time = time.time() + if data_parallel_rank is not None: + raise ValueError("Targeting data_parallel_rank only supported " + "in v1 client.") + if (isinstance(prompt, dict) and prompt.get("prompt_embeds", None) is not None and not prompt.get("prompt_token_ids", None)): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index d1b0b300dccb5..7eff377b74b56 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1000,9 +1000,6 @@ class DPAsyncMPClient(AsyncMPClient): ) -> CoreEngine: if dp_rank is not None: # engines are already in rank order - if dp_rank < 0 or dp_rank >= len(self.core_engines): - raise ValueError(f"Requested DP rank {dp_rank} is out of " - f"range [0, {len(self.core_engines)})") return self.core_engines[dp_rank] if not self.lb_engines: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 546fc98d681c6..e28879d40460e 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -226,6 +226,12 @@ class Processor: if prompt_adapter_request is not None: raise ValueError("V1 does not support prompt_adapter_request.") + data_parallel_size = self.vllm_config.parallel_config.data_parallel_size + if data_parallel_rank is not None and not (0 <= data_parallel_rank < + data_parallel_size): + raise ValueError(f"data_parallel_rank {data_parallel_rank} " + f"is out of range [0, {data_parallel_size}).") + if arrival_time is None: arrival_time = time.time()