mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 00:44:57 +08:00
[V1][DP] More robust DP/EP dummy request coordination (#16277)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
bc7c4d206b
commit
1e013fa388
@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
|
|||||||
# the engines only synchronize stopping every N steps so
|
# the engines only synchronize stopping every N steps so
|
||||||
# allow a small amount of time here.
|
# allow a small amount of time here.
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
if core_client.num_engines_running == 0:
|
if not core_client.engines_running:
|
||||||
break
|
break
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
assert core_client.num_engines_running == 0
|
assert not core_client.engines_running
|
||||||
assert not core_client.reqs_in_flight
|
assert not core_client.reqs_in_flight
|
||||||
|
|||||||
@ -61,6 +61,11 @@ class EngineCoreRequest(
|
|||||||
arrival_time: float
|
arrival_time: float
|
||||||
lora_request: Optional[LoRARequest]
|
lora_request: Optional[LoRARequest]
|
||||||
|
|
||||||
|
# Used in DP case to indicate which wave of requests this is expected to
|
||||||
|
# belong to, to cover a race condition where the request is sent before
|
||||||
|
# a wave finished notification is received.
|
||||||
|
current_wave: int = 0
|
||||||
|
|
||||||
|
|
||||||
class EngineCoreEventType(enum.IntEnum):
|
class EngineCoreEventType(enum.IntEnum):
|
||||||
"""The type of engine core request event."""
|
"""The type of engine core request event."""
|
||||||
@ -139,8 +144,12 @@ class EngineCoreOutputs(
|
|||||||
utility_output: Optional[UtilityOutput] = None
|
utility_output: Optional[UtilityOutput] = None
|
||||||
finished_requests: Optional[set[str]] = None
|
finished_requests: Optional[set[str]] = None
|
||||||
|
|
||||||
# In DP case, used to signal that the engine is paused.
|
# In DP case, used to signal that the current wave of requests
|
||||||
engine_paused: bool = False
|
# has finished and the engines are paused.
|
||||||
|
wave_complete: Optional[int] = None
|
||||||
|
# In DP case, used to signal that a request was received for an
|
||||||
|
# "old" wave, so the next wave needs to be started in other engines.
|
||||||
|
start_wave: Optional[int] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.timestamp == 0.0:
|
if self.timestamp == 0.0:
|
||||||
@ -154,7 +163,7 @@ class EngineCoreRequestType(enum.Enum):
|
|||||||
"""
|
"""
|
||||||
ADD = b'\x00'
|
ADD = b'\x00'
|
||||||
ABORT = b'\x01'
|
ABORT = b'\x01'
|
||||||
START_DP = b'\x02'
|
START_DP_WAVE = b'\x02'
|
||||||
UTILITY = b'\x03'
|
UTILITY = b'\x03'
|
||||||
# Sentinel used within EngineCoreProc.
|
# Sentinel used within EngineCoreProc.
|
||||||
EXECUTOR_FAILED = b'\x04'
|
EXECUTOR_FAILED = b'\x04'
|
||||||
|
|||||||
@ -325,7 +325,7 @@ class EngineCoreProc(EngineCore):
|
|||||||
|
|
||||||
self.step_fn = (self.step if self.batch_queue is None else
|
self.step_fn = (self.step if self.batch_queue is None else
|
||||||
self.step_with_batch_queue)
|
self.step_with_batch_queue)
|
||||||
self.global_unfinished_reqs = False
|
self.engines_running = False
|
||||||
|
|
||||||
# Background Threads and Queues for IO. These enable us to
|
# Background Threads and Queues for IO. These enable us to
|
||||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||||
@ -410,8 +410,7 @@ class EngineCoreProc(EngineCore):
|
|||||||
"""Exits when an engine step needs to be performed."""
|
"""Exits when an engine step needs to be performed."""
|
||||||
|
|
||||||
waited = False
|
waited = False
|
||||||
while not self.global_unfinished_reqs and not (
|
while not self.engines_running and not (self.scheduler.has_requests()):
|
||||||
self.scheduler.has_requests()):
|
|
||||||
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
|
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
|
||||||
logger.debug("EngineCore waiting for work.")
|
logger.debug("EngineCore waiting for work.")
|
||||||
waited = True
|
waited = True
|
||||||
@ -419,10 +418,7 @@ class EngineCoreProc(EngineCore):
|
|||||||
self._handle_client_request(*req)
|
self._handle_client_request(*req)
|
||||||
|
|
||||||
if waited:
|
if waited:
|
||||||
logger.debug(
|
logger.debug("EngineCore loop active.")
|
||||||
"EngineCore loop active - local unfinished: %s, finished: %s.",
|
|
||||||
self.scheduler.has_unfinished_requests(),
|
|
||||||
self.scheduler.has_finished_requests())
|
|
||||||
|
|
||||||
# Handle any more client requests.
|
# Handle any more client requests.
|
||||||
while not self.input_queue.empty():
|
while not self.input_queue.empty():
|
||||||
@ -446,10 +442,6 @@ class EngineCoreProc(EngineCore):
|
|||||||
self.add_request(request)
|
self.add_request(request)
|
||||||
elif request_type == EngineCoreRequestType.ABORT:
|
elif request_type == EngineCoreRequestType.ABORT:
|
||||||
self.abort_requests(request)
|
self.abort_requests(request)
|
||||||
elif request_type == EngineCoreRequestType.START_DP:
|
|
||||||
if not self.global_unfinished_reqs:
|
|
||||||
logger.debug("EngineCore starting idle loop.")
|
|
||||||
self.global_unfinished_reqs = True
|
|
||||||
elif request_type == EngineCoreRequestType.UTILITY:
|
elif request_type == EngineCoreRequestType.UTILITY:
|
||||||
call_id, method_name, args = request
|
call_id, method_name, args = request
|
||||||
output = UtilityOutput(call_id)
|
output = UtilityOutput(call_id)
|
||||||
@ -548,9 +540,6 @@ class EngineCoreProc(EngineCore):
|
|||||||
socket.send_multipart(buffers, copy=False)
|
socket.send_multipart(buffers, copy=False)
|
||||||
|
|
||||||
|
|
||||||
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DPEngineCoreProc(EngineCoreProc):
|
class DPEngineCoreProc(EngineCoreProc):
|
||||||
"""ZMQ-wrapper for running EngineCore in background process
|
"""ZMQ-wrapper for running EngineCore in background process
|
||||||
in a data parallel context."""
|
in a data parallel context."""
|
||||||
@ -587,7 +576,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
|
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
|
||||||
tp_size))
|
tp_size))
|
||||||
|
|
||||||
|
self.local_dp_rank = local_dp_rank
|
||||||
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
|
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
|
||||||
|
self.current_wave = 0
|
||||||
|
|
||||||
# Initialize the engine after setting up environment.
|
# Initialize the engine after setting up environment.
|
||||||
super().__init__(input_path, output_path, vllm_config, executor_class,
|
super().__init__(input_path, output_path, vllm_config, executor_class,
|
||||||
@ -602,6 +593,31 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
if dp_group := getattr(self, "dp_group", None):
|
if dp_group := getattr(self, "dp_group", None):
|
||||||
stateless_destroy_torch_distributed_process_group(dp_group)
|
stateless_destroy_torch_distributed_process_group(dp_group)
|
||||||
|
|
||||||
|
def add_request(self, request: EngineCoreRequest):
|
||||||
|
if request.current_wave != self.current_wave:
|
||||||
|
if request.current_wave > self.current_wave:
|
||||||
|
self.current_wave = request.current_wave
|
||||||
|
elif not self.engines_running:
|
||||||
|
# Request received for an already-completed wave, notify
|
||||||
|
# front-end that we need to start the next one.
|
||||||
|
self.output_queue.put_nowait(
|
||||||
|
EngineCoreOutputs(start_wave=self.current_wave))
|
||||||
|
|
||||||
|
super().add_request(request)
|
||||||
|
|
||||||
|
def _handle_client_request(self, request_type: EngineCoreRequestType,
|
||||||
|
request: Any) -> None:
|
||||||
|
if request_type == EngineCoreRequestType.START_DP_WAVE:
|
||||||
|
new_wave: int = request
|
||||||
|
if new_wave >= self.current_wave:
|
||||||
|
self.current_wave = new_wave
|
||||||
|
if not self.engines_running:
|
||||||
|
logger.debug("EngineCore starting idle loop for wave %d.",
|
||||||
|
new_wave)
|
||||||
|
self.engines_running = True
|
||||||
|
else:
|
||||||
|
super()._handle_client_request(request_type, request)
|
||||||
|
|
||||||
def run_busy_loop(self):
|
def run_busy_loop(self):
|
||||||
"""Core busy loop of the EngineCore for data parallel case."""
|
"""Core busy loop of the EngineCore for data parallel case."""
|
||||||
|
|
||||||
@ -628,7 +644,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
# up-to-date state is returned in the engine outputs.
|
# up-to-date state is returned in the engine outputs.
|
||||||
self._process_engine_step()
|
self._process_engine_step()
|
||||||
|
|
||||||
if not self.global_unfinished_reqs:
|
if not self.engines_running:
|
||||||
# All engines are idle.
|
# All engines are idle.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -637,18 +653,23 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
self.execute_dummy_batch()
|
self.execute_dummy_batch()
|
||||||
|
|
||||||
# 3) All-reduce operation to determine global unfinished reqs.
|
# 3) All-reduce operation to determine global unfinished reqs.
|
||||||
self.global_unfinished_reqs = self._has_global_unfinished_reqs(
|
self.engines_running = self._has_global_unfinished_reqs(
|
||||||
local_unfinished_reqs)
|
local_unfinished_reqs)
|
||||||
|
|
||||||
if not self.global_unfinished_reqs:
|
if not self.engines_running:
|
||||||
|
if self.local_dp_rank == 0:
|
||||||
# Notify client that we are pausing the loop.
|
# Notify client that we are pausing the loop.
|
||||||
self.output_queue.put_nowait(ENGINE_PAUSED_OUTPUTS)
|
logger.debug("Wave %d finished, pausing engine loop.",
|
||||||
|
self.current_wave)
|
||||||
|
self.output_queue.put_nowait(
|
||||||
|
EngineCoreOutputs(wave_complete=self.current_wave))
|
||||||
|
self.current_wave += 1
|
||||||
|
|
||||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||||
|
|
||||||
# Optimization - only perform finish-sync all-reduce every 16 steps.
|
# Optimization - only perform finish-sync all-reduce every 24 steps.
|
||||||
self.counter += 1
|
self.counter += 1
|
||||||
if self.counter != 16:
|
if self.counter != 24:
|
||||||
return True
|
return True
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
|
|
||||||
|
|||||||
@ -792,15 +792,12 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
|
def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
|
||||||
log_stats: bool):
|
log_stats: bool):
|
||||||
|
|
||||||
self.num_engines_running = 0
|
self.current_wave = 0
|
||||||
|
self.engines_running = False
|
||||||
self.reqs_in_flight: dict[str, CoreEngine] = {}
|
self.reqs_in_flight: dict[str, CoreEngine] = {}
|
||||||
|
|
||||||
super().__init__(vllm_config, executor_class, log_stats)
|
super().__init__(vllm_config, executor_class, log_stats)
|
||||||
|
|
||||||
# Control message used for triggering dp idle mode loop.
|
|
||||||
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
|
|
||||||
*self.encoder.encode(None))
|
|
||||||
|
|
||||||
assert len(self.core_engines) > 1
|
assert len(self.core_engines) > 1
|
||||||
|
|
||||||
def _init_core_engines(
|
def _init_core_engines(
|
||||||
@ -829,23 +826,23 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
# NOTE: text prompt is not needed in the core engine as it has been
|
# NOTE: text prompt is not needed in the core engine as it has been
|
||||||
# tokenized.
|
# tokenized.
|
||||||
request.prompt = None
|
request.prompt = None
|
||||||
|
request.current_wave = self.current_wave
|
||||||
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
|
|
||||||
|
|
||||||
chosen_engine = self.get_core_engine_for_request()
|
chosen_engine = self.get_core_engine_for_request()
|
||||||
self.reqs_in_flight[request.request_id] = chosen_engine
|
self.reqs_in_flight[request.request_id] = chosen_engine
|
||||||
chosen_engine.num_reqs_in_flight += 1
|
chosen_engine.num_reqs_in_flight += 1
|
||||||
if self.num_engines_running >= len(self.core_engines):
|
|
||||||
await self._send_input_message(msg, chosen_engine)
|
to_await = self._send_input(EngineCoreRequestType.ADD, request,
|
||||||
else:
|
chosen_engine)
|
||||||
|
if not self.engines_running:
|
||||||
# Send request to chosen engine and dp start loop
|
# Send request to chosen engine and dp start loop
|
||||||
# control message to all other engines.
|
# control message to all other engines.
|
||||||
self.num_engines_running += len(self.core_engines)
|
self.engines_running = True
|
||||||
await asyncio.gather(*[
|
to_await = asyncio.gather(
|
||||||
self._send_input_message(
|
to_await, # type: ignore[assignment]
|
||||||
msg if engine is chosen_engine else self.start_dp_msg,
|
*self._start_wave_coros(exclude_index=chosen_engine.index))
|
||||||
engine) for engine in self.core_engines
|
|
||||||
])
|
await to_await
|
||||||
|
|
||||||
self._ensure_output_queue_task()
|
self._ensure_output_queue_task()
|
||||||
|
|
||||||
@ -860,21 +857,31 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
if engine := self.reqs_in_flight.pop(req_id, None):
|
if engine := self.reqs_in_flight.pop(req_id, None):
|
||||||
engine.num_reqs_in_flight -= 1
|
engine.num_reqs_in_flight -= 1
|
||||||
|
|
||||||
if outputs.engine_paused:
|
if outputs.wave_complete is not None:
|
||||||
assert self.num_engines_running >= 1
|
# Current wave is complete, move to next wave number
|
||||||
self.num_engines_running -= 1
|
# and mark engines as paused.
|
||||||
if not self.num_engines_running and self.reqs_in_flight:
|
if self.current_wave <= outputs.wave_complete:
|
||||||
# If there are requests in flight here, they must have
|
self.current_wave = outputs.wave_complete + 1
|
||||||
# been sent after the engines paused. We must make
|
self.engines_running = False
|
||||||
# sure to start the other engines:
|
|
||||||
self.num_engines_running = len(self.core_engines)
|
elif outputs.start_wave is not None and (
|
||||||
coros = [
|
outputs.start_wave > self.current_wave or
|
||||||
self._send_input_message(self.start_dp_msg, engine)
|
(outputs.start_wave == self.current_wave
|
||||||
for engine in self.core_engines
|
and not self.engines_running)):
|
||||||
if not engine.num_reqs_in_flight
|
# Engine received request for a non-current wave so we must ensure
|
||||||
|
# that other engines progress to the next wave.
|
||||||
|
self.current_wave = outputs.start_wave
|
||||||
|
self.engines_running = True
|
||||||
|
await asyncio.gather(*self._start_wave_coros(
|
||||||
|
exclude_index=outputs.engine_index))
|
||||||
|
|
||||||
|
def _start_wave_coros(self, exclude_index: int) -> list[Awaitable[None]]:
|
||||||
|
logger.debug("Sending start DP wave %d.", self.current_wave)
|
||||||
|
return [
|
||||||
|
self._send_input(EngineCoreRequestType.START_DP_WAVE,
|
||||||
|
self.current_wave, engine)
|
||||||
|
for engine in self.core_engines if engine.index != exclude_index
|
||||||
]
|
]
|
||||||
if coros:
|
|
||||||
await asyncio.gather(*coros)
|
|
||||||
|
|
||||||
async def abort_requests_async(self, request_ids: list[str]) -> None:
|
async def abort_requests_async(self, request_ids: list[str]) -> None:
|
||||||
if not request_ids:
|
if not request_ids:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user