From c8dcc159214a20650451dcd64b226f56671763f1 Mon Sep 17 00:00:00 2001
From: jmswen
Date: Wed, 4 Jun 2025 08:26:47 -0700
Subject: [PATCH] Allow AsyncLLMEngine.generate to target a specific DP rank (#19102)

Signed-off-by: Jon Swenson
---
 .../multi_instance_data_parallel.py        | 58 +++++++++++++++++++
 tests/tokenization/test_detokenize.py      |  3 +-
 tests/v1/engine/test_engine_core.py        |  1 +
 tests/v1/engine/test_engine_core_client.py |  1 +
 tests/v1/engine/test_output_processor.py   |  5 ++
 vllm/engine/async_llm_engine.py            | 12 +++-
 vllm/v1/engine/__init__.py                 |  1 +
 vllm/v1/engine/async_llm.py                |  5 +-
 vllm/v1/engine/core_client.py              | 14 ++++-
 vllm/v1/engine/processor.py                |  2 +
 10 files changed, 97 insertions(+), 5 deletions(-)
 create mode 100644 examples/online_serving/multi_instance_data_parallel.py

diff --git a/examples/online_serving/multi_instance_data_parallel.py b/examples/online_serving/multi_instance_data_parallel.py
new file mode 100644
index 0000000000000..62b1ec71af14d
--- /dev/null
+++ b/examples/online_serving/multi_instance_data_parallel.py
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+from typing import Optional
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+
+"""
+To run this example, run the following commands simultaneously with
+different CUDA_VISIBLE_DEVICES:
+    python examples/online_serving/multi_instance_data_parallel.py
+
+    vllm serve ibm-research/PowerMoE-3b -dp 2 -dpr 1 \
+        --data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 \
+        --data-parallel-size-local 1 --enforce-eager --headless
+
+Once both instances have completed the handshake, this example will
+send a request to the instance with DP rank 1.
+"""
+
+
+async def main():
+    engine_args = AsyncEngineArgs(
+        model="ibm-research/PowerMoE-3b",
+        data_parallel_size=2,
+        dtype="auto",
+        max_model_len=2048,
+        data_parallel_address="127.0.0.1",
+        data_parallel_rpc_port=62300,
+        data_parallel_size_local=1,
+        enforce_eager=True,
+    )
+
+    engine_client = AsyncLLMEngine.from_engine_args(engine_args)
+
+    sampling_params = SamplingParams(
+        temperature=0.7,
+        top_p=0.9,
+        max_tokens=100,
+    )
+
+    prompt = "Who won the 2004 World Series?"
+    final_output: Optional[RequestOutput] = None
+    async for output in engine_client.generate(
+        prompt=prompt,
+        sampling_params=sampling_params,
+        request_id="abcdef",
+        data_parallel_rank=1,
+    ):
+        final_output = output
+    if final_output:
+        print(final_output.outputs[0].text)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py
index b289dc972c89b..9f2414eca24f3 100644
--- a/tests/tokenization/test_detokenize.py
+++ b/tests/tokenization/test_detokenize.py
@@ -70,7 +70,8 @@ def _run_incremental_decode(tokenizer,
                                 None,
                                 0.0,
                                 None,
-                                cache_salt=None)
+                                cache_salt=None,
+                                data_parallel_rank=None)
 
     if fast is None:
         detokenizer = IncrementalDetokenizer.from_new_request(
diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 3d7632a6037f7..1cbbf30371afd 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -42,6 +42,7 @@ def make_request() -> EngineCoreRequest:
         arrival_time=time.time(),
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
     )
 
 
diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index 47181d36f4ccc..c2dc3b4731b5a 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -56,6 +56,7 @@ def make_request(
         arrival_time=time.time(),
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
     )
 
 
diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py
index a83454ee67e73..6b88b0cf17e32 100644
--- a/tests/v1/engine/test_output_processor.py
+++ b/tests/v1/engine/test_output_processor.py
@@ -59,6 +59,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
             eos_token_id=None,
             lora_request=None,
             cache_salt=None,
+            data_parallel_rank=None,
             sampling_params=SamplingParams(
                 skip_special_tokens=False,
                 spaces_between_special_tokens=False,
@@ -406,6 +407,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
             eos_token_id=None,
             lora_request=None,
             cache_salt=None,
+            data_parallel_rank=None,
             sampling_params=SamplingParams(
                 skip_special_tokens=False,
                 spaces_between_special_tokens=False,
@@ -569,6 +571,7 @@ def test_stop_token(include_stop_str_in_output: bool,
             eos_token_id=eos_token_id,
             lora_request=None,
             cache_salt=None,
+            data_parallel_rank=None,
             sampling_params=SamplingParams(
                 skip_special_tokens=False,
                 spaces_between_special_tokens=False,
@@ -666,6 +669,7 @@ def test_stop_string(include_stop_str_in_output: bool,
             eos_token_id=None,
             lora_request=None,
             cache_salt=None,
+            data_parallel_rank=None,
             sampling_params=SamplingParams(
                 skip_special_tokens=False,
                 spaces_between_special_tokens=False,
@@ -780,6 +784,7 @@ def test_iteration_stats(dummy_test_vectors):
             eos_token_id=None,
             lora_request=None,
             cache_salt=None,
+            data_parallel_rank=None,
             sampling_params=SamplingParams(),
         ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 6d8d97cf5feba..59971f5d65afa 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -442,6 +442,7 @@ class _AsyncLLMEngine(LLMEngine):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> None:
         ...
 
@@ -456,6 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> None:
         ...
 
@@ -473,6 +475,7 @@ class _AsyncLLMEngine(LLMEngine):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -902,6 +905,7 @@ class AsyncLLMEngine(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, PoolingRequestOutput], None]]:
         ...
@@ -917,6 +921,7 @@ class AsyncLLMEngine(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, PoolingRequestOutput], None]]:
         ...
@@ -935,6 +940,7 @@ class AsyncLLMEngine(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
@@ -967,6 +973,7 @@ class AsyncLLMEngine(EngineClient):
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )
 
         return stream.generator()
@@ -980,6 +987,7 @@ class AsyncLLMEngine(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
@@ -999,7 +1007,8 @@ class AsyncLLMEngine(EngineClient):
                 for generation, if any.
             priority: The priority of the request.
                 Only applicable with priority scheduling.
-
+            data_parallel_rank: The (global) data parallel rank that must
+                handle this request. Only applicable if DP is enabled.
         Yields:
             The output `RequestOutput` objects from the LLMEngine
             for the request.
@@ -1057,6 +1066,7 @@ class AsyncLLMEngine(EngineClient):
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
                 priority=priority,
+                data_parallel_rank=data_parallel_rank,
             ):
                 yield LLMEngine.validate_output(output, RequestOutput)
         except asyncio.CancelledError:
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index d1bec25237d62..59463f1ba99f5 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -55,6 +55,7 @@ class EngineCoreRequest(
     arrival_time: float
     lora_request: Optional[LoRARequest]
     cache_salt: Optional[str]
+    data_parallel_rank: Optional[int]
 
     # Index of the client, used to ensure outputs are sent back to the same
     # client for this request when scaling out the front-end.
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 0e369632156bd..61ea3c4c3dab4 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -229,6 +229,7 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""
 
@@ -245,7 +246,7 @@ class AsyncLLM(EngineClient):
         prompt_str, request = self.processor.process_inputs(
             request_id, prompt, params, arrival_time, lora_request,
             tokenization_kwargs, trace_headers, prompt_adapter_request,
-            priority)
+            priority, data_parallel_rank)
 
         if params.n == 1:
             await self._add_request(request, prompt_str, None, 0, queue)
@@ -291,6 +292,7 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """
         Main function called by the API server to kick off a request
@@ -321,6 +323,7 @@ class AsyncLLM(EngineClient):
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
                 priority=priority,
+                data_parallel_rank=data_parallel_rank,
             )
 
             # The output_handler task pushes items into the queue.
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index adb0709c828a7..0cd58d01df7f7 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -982,7 +982,16 @@ class DPAsyncMPClient(AsyncMPClient):
         resources.stats_update_task = asyncio.create_task(
             run_engine_stats_update_task())
 
-    def get_core_engine_for_request(self) -> CoreEngine:
+    def get_core_engine_for_request(self,
+                                    dp_rank: Optional[int] = None
+                                    ) -> CoreEngine:
+        if dp_rank is not None:
+            # engines are already in rank order
+            if dp_rank < 0 or dp_rank >= len(self.core_engines):
+                raise ValueError(f"Requested DP rank {dp_rank} is out of "
+                                 f"range [0, {len(self.core_engines)})")
+            return self.core_engines[dp_rank]
+
         if not self.lb_engines:
             return self.core_engines[0]
         # TODO use P2C alg for larger DP sizes
@@ -1018,7 +1027,8 @@ class DPAsyncMPClient(AsyncMPClient):
         request.current_wave = self.current_wave
         request.client_index = self.client_index
 
-        chosen_engine = self.get_core_engine_for_request()
+        chosen_engine = self.get_core_engine_for_request(
+            request.data_parallel_rank)
         self.reqs_in_flight[request.request_id] = chosen_engine
 
         to_await = self._send_input(EngineCoreRequestType.ADD, request,
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 5c0d01d9b6f61..546fc98d681c6 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -212,6 +212,7 @@ class Processor:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> tuple[Optional[str], EngineCoreRequest]:
 
         # TODO(woosuk): Support pooling models.
@@ -328,6 +329,7 @@ class Processor:
             arrival_time=arrival_time,
             lora_request=lora_request,
             cache_salt=decoder_inputs.get("cache_salt"),
+            data_parallel_rank=data_parallel_rank,
         )
 
     def _validate_model_inputs(self,
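
Usage sketch (not part of the patch): one way a caller might exercise the new
data_parallel_rank argument, assuming the two-instance setup from the example
file above; the model name, request ids, prompts, and sampling values are
illustrative only.

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.sampling_params import SamplingParams


    async def main():
        # DP rank 0 runs in this process; rank 1 is the separately started
        # headless `vllm serve ... --headless` instance (see docstring above).
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model="ibm-research/PowerMoE-3b",
                data_parallel_size=2,
                data_parallel_size_local=1,
                data_parallel_address="127.0.0.1",
                data_parallel_rpc_port=62300,
                enforce_eager=True,
            ))
        params = SamplingParams(max_tokens=32)
        final = None

        # Pin this request to DP rank 1. With this patch, an out-of-range rank
        # raises ValueError in DPAsyncMPClient.get_core_engine_for_request.
        async for out in engine.generate(prompt="Who won the 2004 World Series?",
                                         sampling_params=params,
                                         request_id="pinned-req",
                                         data_parallel_rank=1):
            final = out

        # Omitting data_parallel_rank keeps the existing load-balanced
        # engine selection.
        async for out in engine.generate(prompt="And in 2007?",
                                         sampling_params=params,
                                         request_id="balanced-req"):
            final = out

        if final:
            print(final.outputs[0].text)


    if __name__ == "__main__":
        asyncio.run(main())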