Allow AsyncLLMEngine.generate to target a specific DP rank (#19102)

Signed-off-by: Jon Swenson <jmswen@gmail.com>
jmswen 2025-06-04 08:26:47 -07:00 committed by GitHub
parent 8f4ffbd373
commit c8dcc15921
10 changed files with 97 additions and 5 deletions

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import Optional

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

"""
To run this example, run the following commands simultaneously with
different CUDA_VISIBLE_DEVICES:
    python examples/online_serving/multi_instance_data_parallel.py

    vllm serve ibm-research/PowerMoE-3b -dp 2 -dpr 1 \
        --data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 \
        --data-parallel-size-local 1 --enforce-eager --headless

Once both instances have completed the handshake, this example will
send a request to the instance with DP rank 1.
"""


async def main():
    engine_args = AsyncEngineArgs(
        model="ibm-research/PowerMoE-3b",
        data_parallel_size=2,
        dtype="auto",
        max_model_len=2048,
        data_parallel_address="127.0.0.1",
        data_parallel_rpc_port=62300,
        data_parallel_size_local=1,
        enforce_eager=True,
    )

    engine_client = AsyncLLMEngine.from_engine_args(engine_args)

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=100,
    )

    prompt = "Who won the 2004 World Series?"
    final_output: Optional[RequestOutput] = None
    async for output in engine_client.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id="abcdef",
            data_parallel_rank=1,
    ):
        final_output = output

    if final_output:
        print(final_output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
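
As a follow-up to the example above (an editor's sketch, not part of the committed file): the same engine_client can fan one prompt out so that every DP rank handles exactly one copy of it, which is handy for smoke-testing both instances after the handshake. The helper name broadcast_to_all_ranks is illustrative.

import asyncio  # Already imported in the example file above.

async def broadcast_to_all_ranks(engine_client, prompt, sampling_params,
                                 num_ranks: int):
    """Send one copy of the prompt to every global DP rank."""

    async def run_on(rank: int):
        final = None
        # Request IDs must stay unique across the fanned-out requests.
        async for out in engine_client.generate(prompt=prompt,
                                                sampling_params=sampling_params,
                                                request_id=f"req-rank-{rank}",
                                                data_parallel_rank=rank):
            final = out
        return final

    return await asyncio.gather(*(run_on(r) for r in range(num_ranks)))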

View File

@@ -70,7 +70,8 @@ def _run_incremental_decode(tokenizer,
None,
0.0,
None,
cache_salt=None)
cache_salt=None,
data_parallel_rank=None)
if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request(

View File

@@ -42,6 +42,7 @@ def make_request() -> EngineCoreRequest:
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)

View File

@@ -56,6 +56,7 @@ def make_request(
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)

View File

@@ -59,6 +59,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
eos_token_id=None,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -406,6 +407,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
eos_token_id=None,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -569,6 +571,7 @@ def test_stop_token(include_stop_str_in_output: bool,
eos_token_id=eos_token_id,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -666,6 +669,7 @@ def test_stop_string(include_stop_str_in_output: bool,
eos_token_id=None,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -780,6 +784,7 @@ def test_iteration_stats(dummy_test_vectors):
eos_token_id=None,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(),
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]

View File

@@ -442,6 +442,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> None:
...
@@ -456,6 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> None:
...
@@ -473,6 +475,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -902,6 +905,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, PoolingRequestOutput], None]]:
...
@@ -917,6 +921,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, PoolingRequestOutput], None]]:
...
@@ -935,6 +940,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
@@ -967,6 +973,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return stream.generator()
@@ -980,6 +987,7 @@ class AsyncLLMEngine(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
@@ -999,7 +1007,8 @@
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
data_parallel_rank: The (global) data parallel rank that must
handle this request. Only applicable if DP is enabled.
Yields:
The output `RequestOutput` objects from the LLMEngine
for the request.
@@ -1057,6 +1066,7 @@
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
):
yield LLMEngine.validate_output(output, RequestOutput)
except asyncio.CancelledError:
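
The docstring above leaves the choice of data_parallel_rank entirely to the caller. A hedged sketch of one possible client-side policy (illustrative only; the commit itself only adds the parameter): keep all requests belonging to the same session on the same engine instance.

import zlib

def rank_for_session(session_id: str, data_parallel_size: int) -> int:
    # Deterministic hash so a given session always lands on the same DP rank.
    return zlib.crc32(session_id.encode()) % data_parallel_size

# e.g. engine.generate(..., data_parallel_rank=rank_for_session(sid, dp_size))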

View File

@@ -55,6 +55,7 @@ class EngineCoreRequest(
arrival_time: float
lora_request: Optional[LoRARequest]
cache_salt: Optional[str]
data_parallel_rank: Optional[int]
# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
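
For orientation, the semantics the rest of this commit gives the new field, shown on a stripped-down stand-in (this is not the real EngineCoreRequest definition; the DPAsyncMPClient change below is where the value is consumed):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _RoutingFieldSketch:
    # None     -> the DP client picks an engine (existing load-balanced path).
    # 0..DP-1  -> the request must be handled by that global DP rank.
    data_parallel_rank: Optional[int] = None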

View File

@@ -229,6 +229,7 @@ class AsyncLLM(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""
@@ -245,7 +246,7 @@
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority)
priority, data_parallel_rank)
if params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
@@ -291,6 +292,7 @@
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""
Main function called by the API server to kick off a request
@@ -321,6 +323,7 @@
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
# The output_handler task pushes items into the queue.

View File

@@ -982,7 +982,16 @@ class DPAsyncMPClient(AsyncMPClient):
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task())
def get_core_engine_for_request(self) -> CoreEngine:
def get_core_engine_for_request(self,
dp_rank: Optional[int] = None
) -> CoreEngine:
if dp_rank is not None:
# engines are already in rank order
if dp_rank < 0 or dp_rank >= len(self.core_engines):
raise ValueError(f"Requested DP rank {dp_rank} is out of "
f"range [0, {len(self.core_engines)})")
return self.core_engines[dp_rank]
if not self.lb_engines:
return self.core_engines[0]
# TODO use P2C alg for larger DP sizes
@@ -1018,7 +1027,8 @@ class DPAsyncMPClient(AsyncMPClient):
request.current_wave = self.current_wave
request.client_index = self.client_index
chosen_engine = self.get_core_engine_for_request()
chosen_engine = self.get_core_engine_for_request(
request.data_parallel_rank)
self.reqs_in_flight[request.request_id] = chosen_engine
to_await = self._send_input(EngineCoreRequestType.ADD, request,
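
A self-contained sketch of the routing rule added above, using illustrative names rather than vLLM internals: core engines are held in rank order, so an explicit DP rank indexes straight into that list after a bounds check, while None defers to the existing load-balanced selection.

from typing import Optional, Sequence, TypeVar

T = TypeVar("T")

def pick_engine_for_rank(engines: Sequence[T],
                         dp_rank: Optional[int]) -> Optional[T]:
    if dp_rank is None:
        return None  # Caller falls back to its load-balancing path.
    if dp_rank < 0 or dp_rank >= len(engines):
        raise ValueError(f"Requested DP rank {dp_rank} is out of "
                         f"range [0, {len(engines)})")
    return engines[dp_rank]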

View File

@@ -212,6 +212,7 @@ class Processor:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]:
# TODO(woosuk): Support pooling models.
@@ -328,6 +329,7 @@ class Processor:
arrival_time=arrival_time,
lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"),
data_parallel_rank=data_parallel_rank,
)
def _validate_model_inputs(self,