Allow AsyncLLMEngine.generate to target a specific DP rank (#19102)
Signed-off-by: Jon Swenson <jmswen@gmail.com>
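
This change threads an optional data_parallel_rank argument from
AsyncLLMEngine.generate (and the v1 AsyncLLM front-end) through
Processor.process_inputs into EngineCoreRequest, and teaches
DPAsyncMPClient.get_core_engine_for_request to route a request to the
engine at that rank, raising ValueError when the rank is out of range.
A new example, examples/online_serving/multi_instance_data_parallel.py,
demonstrates sending a request to DP rank 1 across two engine instances.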
parent 8f4ffbd373
commit c8dcc15921

examples/online_serving/multi_instance_data_parallel.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+from typing import Optional
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+
+"""
+To run this example, run the following commands simultaneously with
+different CUDA_VISIBLE_DEVICES:
+    python examples/online_serving/multi_instance_data_parallel.py
+
+    vllm serve ibm-research/PowerMoE-3b -dp 2 -dpr 1 \
+        --data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 \
+        --data-parallel-size-local 1 --enforce-eager --headless
+
+Once both instances have completed the handshake, this example will
+send a request to the instance with DP rank 1.
+"""
+
+
+async def main():
+    engine_args = AsyncEngineArgs(
+        model="ibm-research/PowerMoE-3b",
+        data_parallel_size=2,
+        dtype="auto",
+        max_model_len=2048,
+        data_parallel_address="127.0.0.1",
+        data_parallel_rpc_port=62300,
+        data_parallel_size_local=1,
+        enforce_eager=True,
+    )
+
+    engine_client = AsyncLLMEngine.from_engine_args(engine_args)
+
+    sampling_params = SamplingParams(
+        temperature=0.7,
+        top_p=0.9,
+        max_tokens=100,
+    )
+
+    prompt = "Who won the 2004 World Series?"
+    final_output: Optional[RequestOutput] = None
+    async for output in engine_client.generate(
+        prompt=prompt,
+        sampling_params=sampling_params,
+        request_id="abcdef",
+        data_parallel_rank=1,
+    ):
+        final_output = output
+    if final_output:
+        print(final_output.outputs[0].text)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
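
Note: with data_parallel_size=2 and data_parallel_size_local=1, the script
above hosts one DP rank locally (presumably rank 0) and waits for the
headless `vllm serve` process, started with `-dpr 1`, to join as rank 1
over the RPC port at 127.0.0.1:62300; only once that handshake completes
can the request pinned to data_parallel_rank=1 be scheduled.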
@@ -70,7 +70,8 @@ def _run_incremental_decode(tokenizer,
                               None,
                               0.0,
                               None,
-                              cache_salt=None)
+                              cache_salt=None,
+                              data_parallel_rank=None)

     if fast is None:
         detokenizer = IncrementalDetokenizer.from_new_request(
@@ -42,6 +42,7 @@ def make_request() -> EngineCoreRequest:
         arrival_time=time.time(),
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
     )

@@ -56,6 +56,7 @@ def make_request(
         arrival_time=time.time(),
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
     )

@@ -59,6 +59,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
         eos_token_id=None,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -406,6 +407,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
         eos_token_id=None,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -569,6 +571,7 @@ def test_stop_token(include_stop_str_in_output: bool,
         eos_token_id=eos_token_id,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -666,6 +669,7 @@ def test_stop_string(include_stop_str_in_output: bool,
         eos_token_id=None,
         lora_request=None,
         cache_salt=None,
+        data_parallel_rank=None,
         sampling_params=SamplingParams(
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
@@ -780,6 +784,7 @@ def test_iteration_stats(dummy_test_vectors):
             eos_token_id=None,
             lora_request=None,
             cache_salt=None,
+            data_parallel_rank=None,
             sampling_params=SamplingParams(),
         ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]

@@ -442,6 +442,7 @@ class _AsyncLLMEngine(LLMEngine):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> None:
         ...

@@ -456,6 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> None:
         ...

@@ -473,6 +475,7 @@ class _AsyncLLMEngine(LLMEngine):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> None:
@@ -902,6 +905,7 @@ class AsyncLLMEngine(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, PoolingRequestOutput], None]]:
         ...
@@ -917,6 +921,7 @@ class AsyncLLMEngine(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> Coroutine[None, None, AsyncGenerator[Union[
             RequestOutput, PoolingRequestOutput], None]]:
         ...
@@ -935,6 +940,7 @@ class AsyncLLMEngine(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
         *,
         inputs: Optional[PromptType] = None,  # DEPRECATED
     ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
@@ -967,6 +973,7 @@ class AsyncLLMEngine(EngineClient):
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )

         return stream.generator()
@@ -980,6 +987,7 @@ class AsyncLLMEngine(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.

@@ -999,7 +1007,8 @@ class AsyncLLMEngine(EngineClient):
                 for generation, if any.
             priority: The priority of the request.
                 Only applicable with priority scheduling.
-
+            data_parallel_rank: The (global) data parallel rank that must
+                handle this request. Only applicable if DP is enabled.
         Yields:
             The output `RequestOutput` objects from the LLMEngine
             for the request.
@@ -1057,6 +1066,7 @@ class AsyncLLMEngine(EngineClient):
                     trace_headers=trace_headers,
                     prompt_adapter_request=prompt_adapter_request,
                     priority=priority,
+                    data_parallel_rank=data_parallel_rank,
                 ):
                     yield LLMEngine.validate_output(output, RequestOutput)
         except asyncio.CancelledError:

@@ -55,6 +55,7 @@ class EngineCoreRequest(
     arrival_time: float
     lora_request: Optional[LoRARequest]
     cache_salt: Optional[str]
+    data_parallel_rank: Optional[int]

     # Index of the client, used to ensure outputs are sent back to the same
     # client for this request when scaling out the front-end.

@@ -229,6 +229,7 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
    ) -> RequestOutputCollector:
        """Add new request to the AsyncLLM."""

@@ -245,7 +246,7 @@ class AsyncLLM(EngineClient):
         prompt_str, request = self.processor.process_inputs(
             request_id, prompt, params, arrival_time, lora_request,
             tokenization_kwargs, trace_headers, prompt_adapter_request,
-            priority)
+            priority, data_parallel_rank)

         if params.n == 1:
             await self._add_request(request, prompt_str, None, 0, queue)
@@ -291,6 +292,7 @@ class AsyncLLM(EngineClient):
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> AsyncGenerator[RequestOutput, None]:
         """
         Main function called by the API server to kick off a request
@@ -321,6 +323,7 @@ class AsyncLLM(EngineClient):
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )

         # The output_handler task pushes items into the queue.

@@ -982,7 +982,16 @@ class DPAsyncMPClient(AsyncMPClient):
         resources.stats_update_task = asyncio.create_task(
             run_engine_stats_update_task())

-    def get_core_engine_for_request(self) -> CoreEngine:
+    def get_core_engine_for_request(self,
+                                    dp_rank: Optional[int] = None
+                                    ) -> CoreEngine:
+        if dp_rank is not None:
+            # engines are already in rank order
+            if dp_rank < 0 or dp_rank >= len(self.core_engines):
+                raise ValueError(f"Requested DP rank {dp_rank} is out of "
+                                 f"range [0, {len(self.core_engines)})")
+            return self.core_engines[dp_rank]
+
         if not self.lb_engines:
             return self.core_engines[0]
         # TODO use P2C alg for larger DP sizes
@@ -1018,7 +1027,8 @@ class DPAsyncMPClient(AsyncMPClient):
         request.current_wave = self.current_wave
         request.client_index = self.client_index

-        chosen_engine = self.get_core_engine_for_request()
+        chosen_engine = self.get_core_engine_for_request(
+            request.data_parallel_rank)
         self.reqs_in_flight[request.request_id] = chosen_engine

         to_await = self._send_input(EngineCoreRequestType.ADD, request,

@@ -212,6 +212,7 @@ class Processor:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        data_parallel_rank: Optional[int] = None,
     ) -> tuple[Optional[str], EngineCoreRequest]:

         # TODO(woosuk): Support pooling models.
@@ -328,6 +329,7 @@ class Processor:
             arrival_time=arrival_time,
             lora_request=lora_request,
             cache_salt=decoder_inputs.get("cache_salt"),
+            data_parallel_rank=data_parallel_rank,
         )

     def _validate_model_inputs(self,
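
For illustration, a minimal sketch of fanning requests out across ranks
with the new parameter. This is not part of the commit: `fan_out`,
`consume`, and `num_ranks` are hypothetical names, and `engine_client` /
`sampling_params` are assumed to be set up as in the example file above.

    import asyncio

    async def fan_out(engine_client, prompts, sampling_params, num_ranks=2):
        # Pin request i to DP rank i % num_ranks; ranks outside
        # [0, num_ranks) make get_core_engine_for_request raise ValueError.
        async def consume(i, prompt):
            final = None
            async for out in engine_client.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                request_id=f"req-{i}",
                data_parallel_rank=i % num_ranks,
            ):
                final = out
            return final

        return await asyncio.gather(
            *(consume(i, p) for i, p in enumerate(prompts)))

Passing data_parallel_rank=None (the default) keeps the existing behavior,
where DPAsyncMPClient picks an engine via its load-balancing path.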