diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index f0e031969e73..ce4c4d198db5 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind): # the engines only synchronize stopping every N steps so # allow a small amount of time here. for _ in range(10): - if core_client.num_engines_running == 0: + if not core_client.engines_running: break await asyncio.sleep(0.5) - assert core_client.num_engines_running == 0 + assert not core_client.engines_running assert not core_client.reqs_in_flight diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index af4122a51077..5f5675b955fa 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -61,6 +61,11 @@ class EngineCoreRequest( arrival_time: float lora_request: Optional[LoRARequest] + # Used in DP case to indicate which wave of requests this is expected to + # belong to, to cover a race condition where the request is sent before + # a wave finished notification is received. + current_wave: int = 0 + class EngineCoreEventType(enum.IntEnum): """The type of engine core request event.""" @@ -139,8 +144,12 @@ class EngineCoreOutputs( utility_output: Optional[UtilityOutput] = None finished_requests: Optional[set[str]] = None - # In DP case, used to signal that the engine is paused. - engine_paused: bool = False + # In DP case, used to signal that the current wave of requests + # has finished and the engines are paused. + wave_complete: Optional[int] = None + # In DP case, used to signal that a request was received for an + # "old" wave, so the next wave needs to be started in other engines. + start_wave: Optional[int] = None def __post_init__(self): if self.timestamp == 0.0: @@ -154,7 +163,7 @@ class EngineCoreRequestType(enum.Enum): """ ADD = b'\x00' ABORT = b'\x01' - START_DP = b'\x02' + START_DP_WAVE = b'\x02' UTILITY = b'\x03' # Sentinel used within EngineCoreProc. EXECUTOR_FAILED = b'\x04' diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9c4036efd050..2211431fbceb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -325,7 +325,7 @@ class EngineCoreProc(EngineCore): self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) - self.global_unfinished_reqs = False + self.engines_running = False # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, @@ -410,8 +410,7 @@ class EngineCoreProc(EngineCore): """Exits when an engine step needs to be performed.""" waited = False - while not self.global_unfinished_reqs and not ( - self.scheduler.has_requests()): + while not self.engines_running and not (self.scheduler.has_requests()): if logger.isEnabledFor(DEBUG) and self.input_queue.empty(): logger.debug("EngineCore waiting for work.") waited = True @@ -419,10 +418,7 @@ class EngineCoreProc(EngineCore): self._handle_client_request(*req) if waited: - logger.debug( - "EngineCore loop active - local unfinished: %s, finished: %s.", - self.scheduler.has_unfinished_requests(), - self.scheduler.has_finished_requests()) + logger.debug("EngineCore loop active.") # Handle any more client requests. while not self.input_queue.empty(): @@ -446,10 +442,6 @@ class EngineCoreProc(EngineCore): self.add_request(request) elif request_type == EngineCoreRequestType.ABORT: self.abort_requests(request) - elif request_type == EngineCoreRequestType.START_DP: - if not self.global_unfinished_reqs: - logger.debug("EngineCore starting idle loop.") - self.global_unfinished_reqs = True elif request_type == EngineCoreRequestType.UTILITY: call_id, method_name, args = request output = UtilityOutput(call_id) @@ -548,9 +540,6 @@ class EngineCoreProc(EngineCore): socket.send_multipart(buffers, copy=False) -ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True) - - class DPEngineCoreProc(EngineCoreProc): """ZMQ-wrapper for running EngineCore in background process in a data parallel context.""" @@ -587,7 +576,9 @@ class DPEngineCoreProc(EngineCoreProc): for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size)) + self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.current_wave = 0 # Initialize the engine after setting up environment. super().__init__(input_path, output_path, vllm_config, executor_class, @@ -602,6 +593,31 @@ class DPEngineCoreProc(EngineCoreProc): if dp_group := getattr(self, "dp_group", None): stateless_destroy_torch_distributed_process_group(dp_group) + def add_request(self, request: EngineCoreRequest): + if request.current_wave != self.current_wave: + if request.current_wave > self.current_wave: + self.current_wave = request.current_wave + elif not self.engines_running: + # Request received for an already-completed wave, notify + # front-end that we need to start the next one. + self.output_queue.put_nowait( + EngineCoreOutputs(start_wave=self.current_wave)) + + super().add_request(request) + + def _handle_client_request(self, request_type: EngineCoreRequestType, + request: Any) -> None: + if request_type == EngineCoreRequestType.START_DP_WAVE: + new_wave: int = request + if new_wave >= self.current_wave: + self.current_wave = new_wave + if not self.engines_running: + logger.debug("EngineCore starting idle loop for wave %d.", + new_wave) + self.engines_running = True + else: + super()._handle_client_request(request_type, request) + def run_busy_loop(self): """Core busy loop of the EngineCore for data parallel case.""" @@ -628,7 +644,7 @@ class DPEngineCoreProc(EngineCoreProc): # up-to-date state is returned in the engine outputs. self._process_engine_step() - if not self.global_unfinished_reqs: + if not self.engines_running: # All engines are idle. continue @@ -637,18 +653,23 @@ class DPEngineCoreProc(EngineCoreProc): self.execute_dummy_batch() # 3) All-reduce operation to determine global unfinished reqs. - self.global_unfinished_reqs = self._has_global_unfinished_reqs( + self.engines_running = self._has_global_unfinished_reqs( local_unfinished_reqs) - if not self.global_unfinished_reqs: - # Notify client that we are pausing the loop. - self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS) + if not self.engines_running: + if self.local_dp_rank == 0: + # Notify client that we are pausing the loop. + logger.debug("Wave %d finished, pausing engine loop.", + self.current_wave) + self.output_queue.put_nowait( + EngineCoreOutputs(wave_complete=self.current_wave)) + self.current_wave += 1 def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: - # Optimization - only perform finish-sync all-reduce every 16 steps. + # Optimization - only perform finish-sync all-reduce every 24 steps. self.counter += 1 - if self.counter != 16: + if self.counter != 24: return True self.counter = 0 diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index f54b3546f06d..0efb5dfb39b7 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -792,15 +792,12 @@ class DPAsyncMPClient(AsyncMPClient): def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool): - self.num_engines_running = 0 + self.current_wave = 0 + self.engines_running = False self.reqs_in_flight: dict[str, CoreEngine] = {} super().__init__(vllm_config, executor_class, log_stats) - # Control message used for triggering dp idle mode loop. - self.start_dp_msg = (EngineCoreRequestType.START_DP.value, - *self.encoder.encode(None)) - assert len(self.core_engines) > 1 def _init_core_engines( @@ -829,23 +826,23 @@ class DPAsyncMPClient(AsyncMPClient): # NOTE: text prompt is not needed in the core engine as it has been # tokenized. request.prompt = None - - msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request)) + request.current_wave = self.current_wave chosen_engine = self.get_core_engine_for_request() self.reqs_in_flight[request.request_id] = chosen_engine chosen_engine.num_reqs_in_flight += 1 - if self.num_engines_running >= len(self.core_engines): - await self._send_input_message(msg, chosen_engine) - else: + + to_await = self._send_input(EngineCoreRequestType.ADD, request, + chosen_engine) + if not self.engines_running: # Send request to chosen engine and dp start loop # control message to all other engines. - self.num_engines_running += len(self.core_engines) - await asyncio.gather(*[ - self._send_input_message( - msg if engine is chosen_engine else self.start_dp_msg, - engine) for engine in self.core_engines - ]) + self.engines_running = True + to_await = asyncio.gather( + to_await, # type: ignore[assignment] + *self._start_wave_coros(exclude_index=chosen_engine.index)) + + await to_await self._ensure_output_queue_task() @@ -860,21 +857,31 @@ class DPAsyncMPClient(AsyncMPClient): if engine := self.reqs_in_flight.pop(req_id, None): engine.num_reqs_in_flight -= 1 - if outputs.engine_paused: - assert self.num_engines_running >= 1 - self.num_engines_running -= 1 - if not self.num_engines_running and self.reqs_in_flight: - # If there are requests in flight here, they must have - # been sent after the engines paused. We must make - # sure to start the other engines: - self.num_engines_running = len(self.core_engines) - coros = [ - self._send_input_message(self.start_dp_msg, engine) - for engine in self.core_engines - if not engine.num_reqs_in_flight - ] - if coros: - await asyncio.gather(*coros) + if outputs.wave_complete is not None: + # Current wave is complete, move to next wave number + # and mark engines as paused. + if self.current_wave <= outputs.wave_complete: + self.current_wave = outputs.wave_complete + 1 + self.engines_running = False + + elif outputs.start_wave is not None and ( + outputs.start_wave > self.current_wave or + (outputs.start_wave == self.current_wave + and not self.engines_running)): + # Engine received request for a non-current wave so we must ensure + # that other engines progress to the next wave. + self.current_wave = outputs.start_wave + self.engines_running = True + await asyncio.gather(*self._start_wave_coros( + exclude_index=outputs.engine_index)) + + def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]: + logger.debug("Sending start DP wave %d.", self.current_wave) + return [ + self._send_input(EngineCoreRequestType.START_DP_WAVE, + self.current_wave, engine) + for engine in self.core_engines if engine.index != exclude_index + ] async def abort_requests_async(self, request_ids: list[str]) -> None: if not request_ids: