[V1][DP] More robust DP/EP dummy request coordination (#16277)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-04-22 19:12:15 -07:00 committed by GitHub
parent bc7c4d206b
commit 1e013fa388
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 94 additions and 57 deletions

View File

@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so # the engines only synchronize stopping every N steps so
# allow a small amount of time here. # allow a small amount of time here.
for _ in range(10): for _ in range(10):
if core_client.num_engines_running == 0: if not core_client.engines_running:
break break
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
assert core_client.num_engines_running == 0 assert not core_client.engines_running
assert not core_client.reqs_in_flight assert not core_client.reqs_in_flight

View File

@ -61,6 +61,11 @@ class EngineCoreRequest(
arrival_time: float arrival_time: float
lora_request: Optional[LoRARequest] lora_request: Optional[LoRARequest]
# Used in DP case to indicate which wave of requests this is expected to
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
current_wave: int = 0
class EngineCoreEventType(enum.IntEnum): class EngineCoreEventType(enum.IntEnum):
"""The type of engine core request event.""" """The type of engine core request event."""
@ -139,8 +144,12 @@ class EngineCoreOutputs(
utility_output: Optional[UtilityOutput] = None utility_output: Optional[UtilityOutput] = None
finished_requests: Optional[set[str]] = None finished_requests: Optional[set[str]] = None
# In DP case, used to signal that the engine is paused. # In DP case, used to signal that the current wave of requests
engine_paused: bool = False # has finished and the engines are paused.
wave_complete: Optional[int] = None
# In DP case, used to signal that a request was received for an
# "old" wave, so the next wave needs to be started in other engines.
start_wave: Optional[int] = None
def __post_init__(self): def __post_init__(self):
if self.timestamp == 0.0: if self.timestamp == 0.0:
@ -154,7 +163,7 @@ class EngineCoreRequestType(enum.Enum):
""" """
ADD = b'\x00' ADD = b'\x00'
ABORT = b'\x01' ABORT = b'\x01'
START_DP = b'\x02' START_DP_WAVE = b'\x02'
UTILITY = b'\x03' UTILITY = b'\x03'
# Sentinel used within EngineCoreProc. # Sentinel used within EngineCoreProc.
EXECUTOR_FAILED = b'\x04' EXECUTOR_FAILED = b'\x04'

View File

@ -325,7 +325,7 @@ class EngineCoreProc(EngineCore):
self.step_fn = (self.step if self.batch_queue is None else self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue) self.step_with_batch_queue)
self.global_unfinished_reqs = False self.engines_running = False
# Background Threads and Queues for IO. These enable us to # Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL, # overlap ZMQ socket IO with GPU since they release the GIL,
@ -410,8 +410,7 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed.""" """Exits when an engine step needs to be performed."""
waited = False waited = False
while not self.global_unfinished_reqs and not ( while not self.engines_running and not (self.scheduler.has_requests()):
self.scheduler.has_requests()):
if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.") logger.debug("EngineCore waiting for work.")
waited = True waited = True
@ -419,10 +418,7 @@ class EngineCoreProc(EngineCore):
self._handle_client_request(*req) self._handle_client_request(*req)
if waited: if waited:
logger.debug( logger.debug("EngineCore loop active.")
"EngineCore loop active - local unfinished: %s, finished: %s.",
self.scheduler.has_unfinished_requests(),
self.scheduler.has_finished_requests())
# Handle any more client requests. # Handle any more client requests.
while not self.input_queue.empty(): while not self.input_queue.empty():
@ -446,10 +442,6 @@ class EngineCoreProc(EngineCore):
self.add_request(request) self.add_request(request)
elif request_type == EngineCoreRequestType.ABORT: elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request) self.abort_requests(request)
elif request_type == EngineCoreRequestType.START_DP:
if not self.global_unfinished_reqs:
logger.debug("EngineCore starting idle loop.")
self.global_unfinished_reqs = True
elif request_type == EngineCoreRequestType.UTILITY: elif request_type == EngineCoreRequestType.UTILITY:
call_id, method_name, args = request call_id, method_name, args = request
output = UtilityOutput(call_id) output = UtilityOutput(call_id)
@ -548,9 +540,6 @@ class EngineCoreProc(EngineCore):
socket.send_multipart(buffers, copy=False) socket.send_multipart(buffers, copy=False)
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
class DPEngineCoreProc(EngineCoreProc): class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process """ZMQ-wrapper for running EngineCore in background process
in a data parallel context.""" in a data parallel context."""
@ -587,7 +576,9 @@ class DPEngineCoreProc(EngineCoreProc):
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
tp_size)) tp_size))
self.local_dp_rank = local_dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.current_wave = 0
# Initialize the engine after setting up environment. # Initialize the engine after setting up environment.
super().__init__(input_path, output_path, vllm_config, executor_class, super().__init__(input_path, output_path, vllm_config, executor_class,
@ -602,6 +593,31 @@ class DPEngineCoreProc(EngineCoreProc):
if dp_group := getattr(self, "dp_group", None): if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group) stateless_destroy_torch_distributed_process_group(dp_group)
def add_request(self, request: EngineCoreRequest):
if request.current_wave != self.current_wave:
if request.current_wave > self.current_wave:
self.current_wave = request.current_wave
elif not self.engines_running:
# Request received for an already-completed wave, notify
# front-end that we need to start the next one.
self.output_queue.put_nowait(
EngineCoreOutputs(start_wave=self.current_wave))
super().add_request(request)
def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
if request_type == EngineCoreRequestType.START_DP_WAVE:
new_wave: int = request
if new_wave >= self.current_wave:
self.current_wave = new_wave
if not self.engines_running:
logger.debug("EngineCore starting idle loop for wave %d.",
new_wave)
self.engines_running = True
else:
super()._handle_client_request(request_type, request)
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case.""" """Core busy loop of the EngineCore for data parallel case."""
@ -628,7 +644,7 @@ class DPEngineCoreProc(EngineCoreProc):
# up-to-date state is returned in the engine outputs. # up-to-date state is returned in the engine outputs.
self._process_engine_step() self._process_engine_step()
if not self.global_unfinished_reqs: if not self.engines_running:
# All engines are idle. # All engines are idle.
continue continue
@ -637,18 +653,23 @@ class DPEngineCoreProc(EngineCoreProc):
self.execute_dummy_batch() self.execute_dummy_batch()
# 3) All-reduce operation to determine global unfinished reqs. # 3) All-reduce operation to determine global unfinished reqs.
self.global_unfinished_reqs = self._has_global_unfinished_reqs( self.engines_running = self._has_global_unfinished_reqs(
local_unfinished_reqs) local_unfinished_reqs)
if not self.global_unfinished_reqs: if not self.engines_running:
if self.local_dp_rank == 0:
# Notify client that we are pausing the loop. # Notify client that we are pausing the loop.
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS) logger.debug("Wave %d finished, pausing engine loop.",
self.current_wave)
self.output_queue.put_nowait(
EngineCoreOutputs(wave_complete=self.current_wave))
self.current_wave += 1
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 16 steps. # Optimization - only perform finish-sync all-reduce every 24 steps.
self.counter += 1 self.counter += 1
if self.counter != 16: if self.counter != 24:
return True return True
self.counter = 0 self.counter = 0

View File

@ -792,15 +792,12 @@ class DPAsyncMPClient(AsyncMPClient):
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
log_stats: bool): log_stats: bool):
self.num_engines_running = 0 self.current_wave = 0
self.engines_running = False
self.reqs_in_flight: dict[str, CoreEngine] = {} self.reqs_in_flight: dict[str, CoreEngine] = {}
super().__init__(vllm_config, executor_class, log_stats) super().__init__(vllm_config, executor_class, log_stats)
# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
*self.encoder.encode(None))
assert len(self.core_engines) > 1 assert len(self.core_engines) > 1
def _init_core_engines( def _init_core_engines(
@ -829,23 +826,23 @@ class DPAsyncMPClient(AsyncMPClient):
# NOTE: text prompt is not needed in the core engine as it has been # NOTE: text prompt is not needed in the core engine as it has been
# tokenized. # tokenized.
request.prompt = None request.prompt = None
request.current_wave = self.current_wave
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
chosen_engine = self.get_core_engine_for_request() chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine self.reqs_in_flight[request.request_id] = chosen_engine
chosen_engine.num_reqs_in_flight += 1 chosen_engine.num_reqs_in_flight += 1
if self.num_engines_running >= len(self.core_engines):
await self._send_input_message(msg, chosen_engine) to_await = self._send_input(EngineCoreRequestType.ADD, request,
else: chosen_engine)
if not self.engines_running:
# Send request to chosen engine and dp start loop # Send request to chosen engine and dp start loop
# control message to all other engines. # control message to all other engines.
self.num_engines_running += len(self.core_engines) self.engines_running = True
await asyncio.gather(*[ to_await = asyncio.gather(
self._send_input_message( to_await, # type: ignore[assignment]
msg if engine is chosen_engine else self.start_dp_msg, *self._start_wave_coros(exclude_index=chosen_engine.index))
engine) for engine in self.core_engines
]) await to_await
self._ensure_output_queue_task() self._ensure_output_queue_task()
@ -860,21 +857,31 @@ class DPAsyncMPClient(AsyncMPClient):
if engine := self.reqs_in_flight.pop(req_id, None): if engine := self.reqs_in_flight.pop(req_id, None):
engine.num_reqs_in_flight -= 1 engine.num_reqs_in_flight -= 1
if outputs.engine_paused: if outputs.wave_complete is not None:
assert self.num_engines_running >= 1 # Current wave is complete, move to next wave number
self.num_engines_running -= 1 # and mark engines as paused.
if not self.num_engines_running and self.reqs_in_flight: if self.current_wave <= outputs.wave_complete:
# If there are requests in flight here, they must have self.current_wave = outputs.wave_complete + 1
# been sent after the engines paused. We must make self.engines_running = False
# sure to start the other engines:
self.num_engines_running = len(self.core_engines) elif outputs.start_wave is not None and (
coros = [ outputs.start_wave > self.current_wave or
self._send_input_message(self.start_dp_msg, engine) (outputs.start_wave == self.current_wave
for engine in self.core_engines and not self.engines_running)):
if not engine.num_reqs_in_flight # Engine received request for a non-current wave so we must ensure
# that other engines progress to the next wave.
self.current_wave = outputs.start_wave
self.engines_running = True
await asyncio.gather(*self._start_wave_coros(
exclude_index=outputs.engine_index))
def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
logger.debug("Sending start DP wave %d.", self.current_wave)
return [
self._send_input(EngineCoreRequestType.START_DP_WAVE,
self.current_wave, engine)
for engine in self.core_engines if engine.index != exclude_index
] ]
if coros:
await asyncio.gather(*coros)
async def abort_requests_async(self, request_ids: list[str]) -> None: async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids: if not request_ids: