mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 13:41:20 +08:00
Allow AsyncLLMEngine.generate to target a specific DP rank (#19102)
Signed-off-by: Jon Swenson <jmswen@gmail.com>
This commit is contained in:
parent
8f4ffbd373
commit
c8dcc15921
58
examples/online_serving/multi_instance_data_parallel.py
Normal file
58
examples/online_serving/multi_instance_data_parallel.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this example, run the following commands simultaneously with
|
||||||
|
different CUDA_VISIBLE_DEVICES:
|
||||||
|
python examples/online_serving/multi_instance_data_parallel.py
|
||||||
|
|
||||||
|
vllm serve ibm-research/PowerMoE-3b -dp 2 -dpr 1 \
|
||||||
|
--data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 \
|
||||||
|
--data-parallel-size-local 1 --enforce-eager --headless
|
||||||
|
|
||||||
|
Once both instances have completed the handshake, this example will
|
||||||
|
send a request to the instance with DP rank 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Send one prompt to DP rank 1 and print the final generated text.

    Builds an ``AsyncLLMEngine`` from data-parallel engine arguments,
    submits a single generation request with ``data_parallel_rank=1``,
    consumes the output stream, and prints the text of the first
    completion of the last ``RequestOutput`` received. Nothing is printed
    if the stream yields no output.
    """
    # Data-parallel engine settings; the address and RPC port match the
    # `vllm serve ... --headless` command shown in the usage notes of this
    # example.
    dp_engine_args = AsyncEngineArgs(
        model="ibm-research/PowerMoE-3b",
        data_parallel_size=2,
        dtype="auto",
        max_model_len=2048,
        data_parallel_address="127.0.0.1",
        data_parallel_rpc_port=62300,
        data_parallel_size_local=1,
        enforce_eager=True,
    )
    client = AsyncLLMEngine.from_engine_args(dp_engine_args)

    params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=100)

    question = "Who won the 2004 World Series?"
    stream = client.generate(
        prompt=question,
        sampling_params=params,
        request_id="abcdef",
        # Ask for this request to be handled by DP rank 1.
        data_parallel_rank=1,
    )

    # Keep only the most recent output yielded by the stream.
    last: Optional[RequestOutput] = None
    async for chunk in stream:
        last = chunk

    if not last:
        return
    print(last.outputs[0].text)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the example coroutine to completion when executed as a script.
    asyncio.run(main())
|
||||||
@ -70,7 +70,8 @@ def _run_incremental_decode(tokenizer,
|
|||||||
None,
|
None,
|
||||||
0.0,
|
0.0,
|
||||||
None,
|
None,
|
||||||
cache_salt=None)
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None)
|
||||||
|
|
||||||
if fast is None:
|
if fast is None:
|
||||||
detokenizer = IncrementalDetokenizer.from_new_request(
|
detokenizer = IncrementalDetokenizer.from_new_request(
|
||||||
|
|||||||
@ -42,6 +42,7 @@ def make_request() -> EngineCoreRequest:
|
|||||||
arrival_time=time.time(),
|
arrival_time=time.time(),
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -56,6 +56,7 @@ def make_request(
|
|||||||
arrival_time=time.time(),
|
arrival_time=time.time(),
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -59,6 +59,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
|
|||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
sampling_params=SamplingParams(
|
sampling_params=SamplingParams(
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
spaces_between_special_tokens=False,
|
spaces_between_special_tokens=False,
|
||||||
@ -406,6 +407,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
|||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
sampling_params=SamplingParams(
|
sampling_params=SamplingParams(
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
spaces_between_special_tokens=False,
|
spaces_between_special_tokens=False,
|
||||||
@ -569,6 +571,7 @@ def test_stop_token(include_stop_str_in_output: bool,
|
|||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
sampling_params=SamplingParams(
|
sampling_params=SamplingParams(
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
spaces_between_special_tokens=False,
|
spaces_between_special_tokens=False,
|
||||||
@ -666,6 +669,7 @@ def test_stop_string(include_stop_str_in_output: bool,
|
|||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
sampling_params=SamplingParams(
|
sampling_params=SamplingParams(
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
spaces_between_special_tokens=False,
|
spaces_between_special_tokens=False,
|
||||||
@ -780,6 +784,7 @@ def test_iteration_stats(dummy_test_vectors):
|
|||||||
eos_token_id=None,
|
eos_token_id=None,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
cache_salt=None,
|
cache_salt=None,
|
||||||
|
data_parallel_rank=None,
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
|
||||||
]
|
]
|
||||||
|
|||||||
@ -442,6 +442,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -456,6 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
@ -473,6 +475,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
*,
|
*,
|
||||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -902,6 +905,7 @@ class AsyncLLMEngine(EngineClient):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> Coroutine[None, None, AsyncGenerator[Union[
|
) -> Coroutine[None, None, AsyncGenerator[Union[
|
||||||
RequestOutput, PoolingRequestOutput], None]]:
|
RequestOutput, PoolingRequestOutput], None]]:
|
||||||
...
|
...
|
||||||
@ -917,6 +921,7 @@ class AsyncLLMEngine(EngineClient):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> Coroutine[None, None, AsyncGenerator[Union[
|
) -> Coroutine[None, None, AsyncGenerator[Union[
|
||||||
RequestOutput, PoolingRequestOutput], None]]:
|
RequestOutput, PoolingRequestOutput], None]]:
|
||||||
...
|
...
|
||||||
@ -935,6 +940,7 @@ class AsyncLLMEngine(EngineClient):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
*,
|
*,
|
||||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||||
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
|
||||||
@ -967,6 +973,7 @@ class AsyncLLMEngine(EngineClient):
|
|||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
return stream.generator()
|
return stream.generator()
|
||||||
@ -980,6 +987,7 @@ class AsyncLLMEngine(EngineClient):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> AsyncGenerator[RequestOutput, None]:
|
) -> AsyncGenerator[RequestOutput, None]:
|
||||||
"""Generate outputs for a request.
|
"""Generate outputs for a request.
|
||||||
|
|
||||||
@ -999,7 +1007,8 @@ class AsyncLLMEngine(EngineClient):
|
|||||||
for generation, if any.
|
for generation, if any.
|
||||||
priority: The priority of the request.
|
priority: The priority of the request.
|
||||||
Only applicable with priority scheduling.
|
Only applicable with priority scheduling.
|
||||||
|
data_parallel_rank: The (global) data parallel rank that must
|
||||||
|
handle this request. Only applicable if DP is enabled.
|
||||||
Yields:
|
Yields:
|
||||||
The output `RequestOutput` objects from the LLMEngine
|
The output `RequestOutput` objects from the LLMEngine
|
||||||
for the request.
|
for the request.
|
||||||
@ -1057,6 +1066,7 @@ class AsyncLLMEngine(EngineClient):
|
|||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
):
|
):
|
||||||
yield LLMEngine.validate_output(output, RequestOutput)
|
yield LLMEngine.validate_output(output, RequestOutput)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|||||||
@ -55,6 +55,7 @@ class EngineCoreRequest(
|
|||||||
arrival_time: float
|
arrival_time: float
|
||||||
lora_request: Optional[LoRARequest]
|
lora_request: Optional[LoRARequest]
|
||||||
cache_salt: Optional[str]
|
cache_salt: Optional[str]
|
||||||
|
data_parallel_rank: Optional[int]
|
||||||
|
|
||||||
# Index of the client, used to ensure outputs are sent back to the same
|
# Index of the client, used to ensure outputs are sent back to the same
|
||||||
# client for this request when scaling out the front-end.
|
# client for this request when scaling out the front-end.
|
||||||
|
|||||||
@ -229,6 +229,7 @@ class AsyncLLM(EngineClient):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> RequestOutputCollector:
|
) -> RequestOutputCollector:
|
||||||
"""Add new request to the AsyncLLM."""
|
"""Add new request to the AsyncLLM."""
|
||||||
|
|
||||||
@ -245,7 +246,7 @@ class AsyncLLM(EngineClient):
|
|||||||
prompt_str, request = self.processor.process_inputs(
|
prompt_str, request = self.processor.process_inputs(
|
||||||
request_id, prompt, params, arrival_time, lora_request,
|
request_id, prompt, params, arrival_time, lora_request,
|
||||||
tokenization_kwargs, trace_headers, prompt_adapter_request,
|
tokenization_kwargs, trace_headers, prompt_adapter_request,
|
||||||
priority)
|
priority, data_parallel_rank)
|
||||||
|
|
||||||
if params.n == 1:
|
if params.n == 1:
|
||||||
await self._add_request(request, prompt_str, None, 0, queue)
|
await self._add_request(request, prompt_str, None, 0, queue)
|
||||||
@ -291,6 +292,7 @@ class AsyncLLM(EngineClient):
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> AsyncGenerator[RequestOutput, None]:
|
) -> AsyncGenerator[RequestOutput, None]:
|
||||||
"""
|
"""
|
||||||
Main function called by the API server to kick off a request
|
Main function called by the API server to kick off a request
|
||||||
@ -321,6 +323,7 @@ class AsyncLLM(EngineClient):
|
|||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The output_handler task pushes items into the queue.
|
# The output_handler task pushes items into the queue.
|
||||||
|
|||||||
@ -982,7 +982,16 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
resources.stats_update_task = asyncio.create_task(
|
resources.stats_update_task = asyncio.create_task(
|
||||||
run_engine_stats_update_task())
|
run_engine_stats_update_task())
|
||||||
|
|
||||||
def get_core_engine_for_request(self) -> CoreEngine:
|
def get_core_engine_for_request(self,
|
||||||
|
dp_rank: Optional[int] = None
|
||||||
|
) -> CoreEngine:
|
||||||
|
if dp_rank is not None:
|
||||||
|
# engines are already in rank order
|
||||||
|
if dp_rank < 0 or dp_rank >= len(self.core_engines):
|
||||||
|
raise ValueError(f"Requested DP rank {dp_rank} is out of "
|
||||||
|
f"range [0, {len(self.core_engines)})")
|
||||||
|
return self.core_engines[dp_rank]
|
||||||
|
|
||||||
if not self.lb_engines:
|
if not self.lb_engines:
|
||||||
return self.core_engines[0]
|
return self.core_engines[0]
|
||||||
# TODO use P2C alg for larger DP sizes
|
# TODO use P2C alg for larger DP sizes
|
||||||
@ -1018,7 +1027,8 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
request.current_wave = self.current_wave
|
request.current_wave = self.current_wave
|
||||||
request.client_index = self.client_index
|
request.client_index = self.client_index
|
||||||
|
|
||||||
chosen_engine = self.get_core_engine_for_request()
|
chosen_engine = self.get_core_engine_for_request(
|
||||||
|
request.data_parallel_rank)
|
||||||
self.reqs_in_flight[request.request_id] = chosen_engine
|
self.reqs_in_flight[request.request_id] = chosen_engine
|
||||||
|
|
||||||
to_await = self._send_input(EngineCoreRequestType.ADD, request,
|
to_await = self._send_input(EngineCoreRequestType.ADD, request,
|
||||||
|
|||||||
@ -212,6 +212,7 @@ class Processor:
|
|||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> tuple[Optional[str], EngineCoreRequest]:
|
) -> tuple[Optional[str], EngineCoreRequest]:
|
||||||
|
|
||||||
# TODO(woosuk): Support pooling models.
|
# TODO(woosuk): Support pooling models.
|
||||||
@ -328,6 +329,7 @@ class Processor:
|
|||||||
arrival_time=arrival_time,
|
arrival_time=arrival_time,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
cache_salt=decoder_inputs.get("cache_salt"),
|
cache_salt=decoder_inputs.get("cache_salt"),
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_model_inputs(self,
|
def _validate_model_inputs(self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user