mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 19:25:01 +08:00
[V1][DP] More robust DP/EP dummy request coordination (#16277)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
bc7c4d206b
commit
1e013fa388
@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
|
||||
# the engines only synchronize stopping every N steps so
|
||||
# allow a small amount of time here.
|
||||
for _ in range(10):
|
||||
if core_client.num_engines_running == 0:
|
||||
if not core_client.engines_running:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
assert core_client.num_engines_running == 0
|
||||
assert not core_client.engines_running
|
||||
assert not core_client.reqs_in_flight
|
||||
|
||||
@ -61,6 +61,11 @@ class EngineCoreRequest(
|
||||
arrival_time: float
|
||||
lora_request: Optional[LoRARequest]
|
||||
|
||||
# Used in DP case to indicate which wave of requests this is expected to
|
||||
# belong to, to cover a race condition where the request is sent before
|
||||
# a wave finished notification is received.
|
||||
current_wave: int = 0
|
||||
|
||||
|
||||
class EngineCoreEventType(enum.IntEnum):
|
||||
"""The type of engine core request event."""
|
||||
@ -139,8 +144,12 @@ class EngineCoreOutputs(
|
||||
utility_output: Optional[UtilityOutput] = None
|
||||
finished_requests: Optional[set[str]] = None
|
||||
|
||||
# In DP case, used to signal that the engine is paused.
|
||||
engine_paused: bool = False
|
||||
# In DP case, used to signal that the current wave of requests
|
||||
# has finished and the engines are paused.
|
||||
wave_complete: Optional[int] = None
|
||||
# In DP case, used to signal that a request was received for an
|
||||
# "old" wave, so the next wave needs to be started in other engines.
|
||||
start_wave: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp == 0.0:
|
||||
@ -154,7 +163,7 @@ class EngineCoreRequestType(enum.Enum):
|
||||
"""
|
||||
ADD = b'\x00'
|
||||
ABORT = b'\x01'
|
||||
START_DP = b'\x02'
|
||||
START_DP_WAVE = b'\x02'
|
||||
UTILITY = b'\x03'
|
||||
# Sentinel used within EngineCoreProc.
|
||||
EXECUTOR_FAILED = b'\x04'
|
||||
|
||||
@ -325,7 +325,7 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
self.step_fn = (self.step if self.batch_queue is None else
|
||||
self.step_with_batch_queue)
|
||||
self.global_unfinished_reqs = False
|
||||
self.engines_running = False
|
||||
|
||||
# Background Threads and Queues for IO. These enable us to
|
||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||
@ -410,8 +410,7 @@ class EngineCoreProc(EngineCore):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
|
||||
waited = False
|
||||
while not self.global_unfinished_reqs and not (
|
||||
self.scheduler.has_requests()):
|
||||
while not self.engines_running and not (self.scheduler.has_requests()):
|
||||
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
|
||||
logger.debug("EngineCore waiting for work.")
|
||||
waited = True
|
||||
@ -419,10 +418,7 @@ class EngineCoreProc(EngineCore):
|
||||
self._handle_client_request(*req)
|
||||
|
||||
if waited:
|
||||
logger.debug(
|
||||
"EngineCore loop active - local unfinished: %s, finished: %s.",
|
||||
self.scheduler.has_unfinished_requests(),
|
||||
self.scheduler.has_finished_requests())
|
||||
logger.debug("EngineCore loop active.")
|
||||
|
||||
# Handle any more client requests.
|
||||
while not self.input_queue.empty():
|
||||
@ -446,10 +442,6 @@ class EngineCoreProc(EngineCore):
|
||||
self.add_request(request)
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.START_DP:
|
||||
if not self.global_unfinished_reqs:
|
||||
logger.debug("EngineCore starting idle loop.")
|
||||
self.global_unfinished_reqs = True
|
||||
elif request_type == EngineCoreRequestType.UTILITY:
|
||||
call_id, method_name, args = request
|
||||
output = UtilityOutput(call_id)
|
||||
@ -548,9 +540,6 @@ class EngineCoreProc(EngineCore):
|
||||
socket.send_multipart(buffers, copy=False)
|
||||
|
||||
|
||||
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
|
||||
|
||||
|
||||
class DPEngineCoreProc(EngineCoreProc):
|
||||
"""ZMQ-wrapper for running EngineCore in background process
|
||||
in a data parallel context."""
|
||||
@ -587,7 +576,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
|
||||
tp_size))
|
||||
|
||||
self.local_dp_rank = local_dp_rank
|
||||
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
|
||||
self.current_wave = 0
|
||||
|
||||
# Initialize the engine after setting up environment.
|
||||
super().__init__(input_path, output_path, vllm_config, executor_class,
|
||||
@ -602,6 +593,31 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
if dp_group := getattr(self, "dp_group", None):
|
||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||
|
||||
def add_request(self, request: EngineCoreRequest):
|
||||
if request.current_wave != self.current_wave:
|
||||
if request.current_wave > self.current_wave:
|
||||
self.current_wave = request.current_wave
|
||||
elif not self.engines_running:
|
||||
# Request received for an already-completed wave, notify
|
||||
# front-end that we need to start the next one.
|
||||
self.output_queue.put_nowait(
|
||||
EngineCoreOutputs(start_wave=self.current_wave))
|
||||
|
||||
super().add_request(request)
|
||||
|
||||
def _handle_client_request(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
if request_type == EngineCoreRequestType.START_DP_WAVE:
|
||||
new_wave: int = request
|
||||
if new_wave >= self.current_wave:
|
||||
self.current_wave = new_wave
|
||||
if not self.engines_running:
|
||||
logger.debug("EngineCore starting idle loop for wave %d.",
|
||||
new_wave)
|
||||
self.engines_running = True
|
||||
else:
|
||||
super()._handle_client_request(request_type, request)
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore for data parallel case."""
|
||||
|
||||
@ -628,7 +644,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
# up-to-date state is returned in the engine outputs.
|
||||
self._process_engine_step()
|
||||
|
||||
if not self.global_unfinished_reqs:
|
||||
if not self.engines_running:
|
||||
# All engines are idle.
|
||||
continue
|
||||
|
||||
@ -637,18 +653,23 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
self.execute_dummy_batch()
|
||||
|
||||
# 3) All-reduce operation to determine global unfinished reqs.
|
||||
self.global_unfinished_reqs = self._has_global_unfinished_reqs(
|
||||
self.engines_running = self._has_global_unfinished_reqs(
|
||||
local_unfinished_reqs)
|
||||
|
||||
if not self.global_unfinished_reqs:
|
||||
if not self.engines_running:
|
||||
if self.local_dp_rank == 0:
|
||||
# Notify client that we are pausing the loop.
|
||||
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
|
||||
logger.debug("Wave %d finished, pausing engine loop.",
|
||||
self.current_wave)
|
||||
self.output_queue.put_nowait(
|
||||
EngineCoreOutputs(wave_complete=self.current_wave))
|
||||
self.current_wave += 1
|
||||
|
||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||
|
||||
# Optimization - only perform finish-sync all-reduce every 16 steps.
|
||||
# Optimization - only perform finish-sync all-reduce every 24 steps.
|
||||
self.counter += 1
|
||||
if self.counter != 16:
|
||||
if self.counter != 24:
|
||||
return True
|
||||
self.counter = 0
|
||||
|
||||
|
||||
@ -792,15 +792,12 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
|
||||
log_stats: bool):
|
||||
|
||||
self.num_engines_running = 0
|
||||
self.current_wave = 0
|
||||
self.engines_running = False
|
||||
self.reqs_in_flight: dict[str, CoreEngine] = {}
|
||||
|
||||
super().__init__(vllm_config, executor_class, log_stats)
|
||||
|
||||
# Control message used for triggering dp idle mode loop.
|
||||
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
|
||||
*self.encoder.encode(None))
|
||||
|
||||
assert len(self.core_engines) > 1
|
||||
|
||||
def _init_core_engines(
|
||||
@ -829,23 +826,23 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
# NOTE: text prompt is not needed in the core engine as it has been
|
||||
# tokenized.
|
||||
request.prompt = None
|
||||
|
||||
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
|
||||
request.current_wave = self.current_wave
|
||||
|
||||
chosen_engine = self.get_core_engine_for_request()
|
||||
self.reqs_in_flight[request.request_id] = chosen_engine
|
||||
chosen_engine.num_reqs_in_flight += 1
|
||||
if self.num_engines_running >= len(self.core_engines):
|
||||
await self._send_input_message(msg, chosen_engine)
|
||||
else:
|
||||
|
||||
to_await = self._send_input(EngineCoreRequestType.ADD, request,
|
||||
chosen_engine)
|
||||
if not self.engines_running:
|
||||
# Send request to chosen engine and dp start loop
|
||||
# control message to all other engines.
|
||||
self.num_engines_running += len(self.core_engines)
|
||||
await asyncio.gather(*[
|
||||
self._send_input_message(
|
||||
msg if engine is chosen_engine else self.start_dp_msg,
|
||||
engine) for engine in self.core_engines
|
||||
])
|
||||
self.engines_running = True
|
||||
to_await = asyncio.gather(
|
||||
to_await, # type: ignore[assignment]
|
||||
*self._start_wave_coros(exclude_index=chosen_engine.index))
|
||||
|
||||
await to_await
|
||||
|
||||
self._ensure_output_queue_task()
|
||||
|
||||
@ -860,21 +857,31 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
if engine := self.reqs_in_flight.pop(req_id, None):
|
||||
engine.num_reqs_in_flight -= 1
|
||||
|
||||
if outputs.engine_paused:
|
||||
assert self.num_engines_running >= 1
|
||||
self.num_engines_running -= 1
|
||||
if not self.num_engines_running and self.reqs_in_flight:
|
||||
# If there are requests in flight here, they must have
|
||||
# been sent after the engines paused. We must make
|
||||
# sure to start the other engines:
|
||||
self.num_engines_running = len(self.core_engines)
|
||||
coros = [
|
||||
self._send_input_message(self.start_dp_msg, engine)
|
||||
for engine in self.core_engines
|
||||
if not engine.num_reqs_in_flight
|
||||
if outputs.wave_complete is not None:
|
||||
# Current wave is complete, move to next wave number
|
||||
# and mark engines as paused.
|
||||
if self.current_wave <= outputs.wave_complete:
|
||||
self.current_wave = outputs.wave_complete + 1
|
||||
self.engines_running = False
|
||||
|
||||
elif outputs.start_wave is not None and (
|
||||
outputs.start_wave > self.current_wave or
|
||||
(outputs.start_wave == self.current_wave
|
||||
and not self.engines_running)):
|
||||
# Engine received request for a non-current wave so we must ensure
|
||||
# that other engines progress to the next wave.
|
||||
self.current_wave = outputs.start_wave
|
||||
self.engines_running = True
|
||||
await asyncio.gather(*self._start_wave_coros(
|
||||
exclude_index=outputs.engine_index))
|
||||
|
||||
def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
|
||||
logger.debug("Sending start DP wave %d.", self.current_wave)
|
||||
return [
|
||||
self._send_input(EngineCoreRequestType.START_DP_WAVE,
|
||||
self.current_wave, engine)
|
||||
for engine in self.core_engines if engine.index != exclude_index
|
||||
]
|
||||
if coros:
|
||||
await asyncio.gather(*coros)
|
||||
|
||||
async def abort_requests_async(self, request_ids: list[str]) -> None:
|
||||
if not request_ids:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user