[BugFix] Improve internal DP load balancing (#21617)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-08-02 03:45:27 +01:00 committed by GitHub
parent 9f9c38c392
commit 8d524ce79f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 122 additions and 59 deletions

View File

@ -199,6 +199,8 @@ async def build_async_engine_client_from_engine_args(
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None async_llm: Optional[AsyncLLM] = None
client_count = client_config.pop(
"client_count") if client_config else 1
client_index = client_config.pop( client_index = client_config.pop(
"client_index") if client_config else 0 "client_index") if client_config else 0
try: try:
@ -208,6 +210,7 @@ async def build_async_engine_client_from_engine_args(
enable_log_requests=engine_args.enable_log_requests, enable_log_requests=engine_args.enable_log_requests,
disable_log_stats=engine_args.disable_log_stats, disable_log_stats=engine_args.disable_log_stats,
client_addresses=client_config, client_addresses=client_config,
client_count=client_count,
client_index=client_index) client_index=client_index)
# Don't keep the dummy data in memory # Don't keep the dummy data in memory

View File

@ -57,6 +57,7 @@ class AsyncLLM(EngineClient):
start_engine_loop: bool = True, start_engine_loop: bool = True,
stat_loggers: Optional[list[StatLoggerFactory]] = None, stat_loggers: Optional[list[StatLoggerFactory]] = None,
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0, client_index: int = 0,
) -> None: ) -> None:
""" """
@ -120,6 +121,7 @@ class AsyncLLM(EngineClient):
executor_class=executor_class, executor_class=executor_class,
log_stats=self.log_stats, log_stats=self.log_stats,
client_addresses=client_addresses, client_addresses=client_addresses,
client_count=client_count,
client_index=client_index, client_index=client_index,
) )
@ -156,6 +158,7 @@ class AsyncLLM(EngineClient):
enable_log_requests: bool = False, enable_log_requests: bool = False,
disable_log_stats: bool = False, disable_log_stats: bool = False,
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0, client_index: int = 0,
disable_log_requests: bool = True, # Deprecated, will be removed disable_log_requests: bool = True, # Deprecated, will be removed
) -> "AsyncLLM": ) -> "AsyncLLM":
@ -176,6 +179,7 @@ class AsyncLLM(EngineClient):
log_stats=not disable_log_stats, log_stats=not disable_log_stats,
usage_context=usage_context, usage_context=usage_context,
client_addresses=client_addresses, client_addresses=client_addresses,
client_count=client_count,
client_index=client_index, client_index=client_index,
) )

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import multiprocessing import multiprocessing
import time import time
import weakref import weakref
@ -65,18 +66,14 @@ class DPCoordinator:
# Assume coordinator is colocated with front-end procs when not in # Assume coordinator is colocated with front-end procs when not in
# either external or hybrid DP LB mode. # either external or hybrid DP LB mode.
local_only = not (external_lb or hybrid_lb)
front_publish_address = get_engine_client_zmq_addr( front_publish_address = get_engine_client_zmq_addr(
local_only=not external_lb and not hybrid_lb, host=host) local_only=local_only, host=host)
local_only_eng = dp_size == parallel_config.data_parallel_size_local local_only_eng = dp_size == parallel_config.data_parallel_size_local
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host) back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
# When in external LB mode, load stats aren't published, only changes
# to request wave / running state, so we don't need to rate-limit the
# updates to the front-end proc(s).
min_stats_update_interval_ms = 0 if external_lb else 100
context = get_mp_context() context = get_mp_context()
self.proc: multiprocessing.Process = context.Process( self.proc: multiprocessing.Process = context.Process(
target=DPCoordinatorProc.run_coordinator, target=DPCoordinatorProc.run_coordinator,
@ -86,7 +83,6 @@ class DPCoordinator:
"front_publish_address": front_publish_address, "front_publish_address": front_publish_address,
"back_output_address": back_output_address, "back_output_address": back_output_address,
"back_publish_address": back_publish_address, "back_publish_address": back_publish_address,
"min_stats_update_interval_ms": min_stats_update_interval_ms,
}, },
daemon=True) daemon=True)
self.proc.start() self.proc.start()
@ -125,10 +121,6 @@ class DPCoordinatorProc:
self.stats_update_interval_ms = min_stats_update_interval_ms self.stats_update_interval_ms = min_stats_update_interval_ms
self.current_wave = 0
self.engines_running = False
self.stats_changed = False
@staticmethod @staticmethod
def run_coordinator( def run_coordinator(
engine_count: int, engine_count: int,
@ -155,6 +147,16 @@ class DPCoordinatorProc:
decoder = MsgpackDecoder(EngineCoreOutputs) decoder = MsgpackDecoder(EngineCoreOutputs)
# For tracking request wave progression.
current_wave = 0
engines_running = False
# For tracking request counts for internal load-balancing.
stats_changed = False
last_stats_step = -1
last_stats_wave = -1
last_step_counts: Optional[list[list[int]]] = None
with make_zmq_socket( with make_zmq_socket(
path=front_publish_address, # IPC path=front_publish_address, # IPC
ctx=self.ctx, ctx=self.ctx,
@ -191,21 +193,33 @@ class DPCoordinatorProc:
while True: while True:
elapsed = int(time.time() * 1000) - last_publish_time elapsed = int(time.time() * 1000) - last_publish_time
# Send at stats_update_interval_ms interval if the stats have # Send at stats_update_interval_ms interval if the stats have
# changed, or otherwise every 4 seconds. # changed, or otherwise every 5 seconds.
wait_for = (self.stats_update_interval_ms wait_for = (self.stats_update_interval_ms
if self.stats_changed else 4000) if stats_changed else 5000)
events = poller.poll(timeout=max(0, wait_for - elapsed))
# Wait at least 50ms to ensure we've received all stats for
# the current step.
min_timeout = 50 if last_step_counts is None else 0
events = poller.poll(timeout=max(min_timeout, wait_for -
elapsed))
if not events: if not events:
# Poller timeout - publish current stats to front-ends. # Poller timeout - publish current stats to front-ends.
engine_req_counts_list = self._get_engine_counts() if last_step_counts is not None:
to_publish = (engine_req_counts_list, self.current_wave, engine_req_counts_list = last_step_counts
self.engines_running) last_step_counts = None
else:
engine_req_counts_list = self._get_engine_counts()
stats_changed = False
to_publish = (engine_req_counts_list, current_wave,
engines_running)
publish_front.send(msgspec.msgpack.encode(to_publish)) publish_front.send(msgspec.msgpack.encode(to_publish))
last_publish_time = int(time.time() * 1000) last_publish_time = int(time.time() * 1000)
self.stats_changed = False
continue continue
events = dict(events) events = dict(events)
wave_state_changed = False
if publish_front in events: if publish_front in events:
buffer = publish_front.recv() buffer = publish_front.recv()
@ -232,7 +246,7 @@ class DPCoordinatorProc:
# current_wave # current_wave
# we note that 0 is the wave number for the new # we note that 0 is the wave number for the new
# engine # engine
self.engines_running = False engines_running = False
logger.info( logger.info(
"DPCoordinator scaled up from %s to %s " "DPCoordinator scaled up from %s to %s "
"engines", current_count, new_engine_count) "engines", current_count, new_engine_count)
@ -248,15 +262,15 @@ class DPCoordinatorProc:
# engines are paused, so that we can wake the other # engines are paused, so that we can wake the other
# engines. # engines.
engine_to_exclude, wave = decoded engine_to_exclude, wave = decoded
if not self.engines_running: if not engines_running:
if wave < self.current_wave: if wave < current_wave:
# If the wave number is stale, ensure the message # If the wave number is stale, ensure the message
# is handled by all the engines. # is handled by all the engines.
engine_to_exclude = None engine_to_exclude = None
self.engines_running = True engines_running = True
self.stats_changed = True wave_state_changed = True
self._send_start_wave(publish_back, self.current_wave, self._send_start_wave(publish_back, current_wave,
engine_to_exclude) engine_to_exclude)
if output_back in events: if output_back in events:
@ -274,36 +288,56 @@ class DPCoordinatorProc:
# 1. Updated request load stats - update our local # 1. Updated request load stats - update our local
# state with these. # state with these.
stats = self.engines[eng_index].request_counts stats = self.engines[eng_index].request_counts
stats_step = scheduler_stats.step_counter
stats_wave = scheduler_stats.current_wave
if (stats_wave > last_stats_wave
or stats_wave == last_stats_wave
and stats_step > last_stats_step):
if stats_changed:
last_step_counts = self._get_engine_counts(
do_copy=True)
last_stats_step = stats_step
last_stats_wave = stats_wave
elif stats_wave != last_stats_wave or (
stats_step != last_stats_step):
logger.warning(
"Received stats for out-of-order "
"step (%d, %d) from engine %d (expected "
"> (%d, %d))", stats_wave, stats_step,
eng_index, last_stats_wave, last_stats_step)
stats[0] = scheduler_stats.num_waiting_reqs stats[0] = scheduler_stats.num_waiting_reqs
stats[1] = scheduler_stats.num_running_reqs stats[1] = scheduler_stats.num_running_reqs
self.stats_changed = True stats_changed = True
if (wave := outputs.wave_complete) is not None: if (wave := outputs.wave_complete) is not None:
# 2. Notification from rank 0 engine that we've # 2. Notification from rank 0 engine that we've
# moved into the global paused state # moved into the global paused state
# (engines_running==False). # (engines_running==False).
if self.current_wave <= wave: if current_wave <= wave:
new_wave = wave + 1 new_wave = wave + 1
logger.debug("Moving DP wave from %d to %d.", logger.debug("Moving DP wave from %d to %d.",
self.current_wave, new_wave) current_wave, new_wave)
self.current_wave = new_wave current_wave = new_wave
self.engines_running = False engines_running = False
self.stats_changed = True wave_state_changed = True
elif (wave := outputs.start_wave) is not None and ( elif (wave := outputs.start_wave) is not None and (
wave > self.current_wave or wave > current_wave or
(wave == self.current_wave (wave == current_wave and not engines_running)):
and not self.engines_running)):
# 3. The engine received request for a non-current wave # 3. The engine received request for a non-current wave
# so we must ensure that other engines progress to the # so we must ensure that other engines progress to the
# next wave (race condition handling). # next wave (race condition handling).
logger.debug( logger.debug(
"Starting wave %d after notification of " "Starting wave %d after notification of "
"stale wave request from engine.", wave) "stale wave request from engine.", wave)
self.current_wave = wave current_wave = wave
self.engines_running = True engines_running = True
self.stats_changed = True wave_state_changed = True
self._send_start_wave(publish_back, wave, eng_index) self._send_start_wave(publish_back, wave, eng_index)
if wave_state_changed:
message = (None, current_wave, engines_running)
publish_front.send(msgspec.msgpack.encode(message))
@staticmethod @staticmethod
def _send_start_wave(socket: zmq.Socket, wave: int, def _send_start_wave(socket: zmq.Socket, wave: int,
exclude_engine_index: Optional[int]): exclude_engine_index: Optional[int]):
@ -316,6 +350,8 @@ class DPCoordinatorProc:
socket.send_multipart( socket.send_multipart(
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded)) (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
def _get_engine_counts(self) -> list[list[int]]: def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
"""Return list of [waiting, running] count lists for each engine.""" """Return list of [waiting, running] count lists for each engine."""
if do_copy:
return [copy.copy(e.request_counts) for e in self.engines]
return [e.request_counts for e in self.engines] return [e.request_counts for e in self.engines]

View File

@ -928,7 +928,7 @@ class DPEngineCoreProc(EngineCoreProc):
): ):
# Counts forward-passes of the model so that we can synchronize # Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps. # finished with DP peers every N steps.
self.counter = 0 self.step_counter = 0
self.current_wave = 0 self.current_wave = 0
self.last_counts = (0, 0) self.last_counts = (0, 0)
@ -999,7 +999,9 @@ class DPEngineCoreProc(EngineCoreProc):
counts = self.scheduler.get_request_counts() counts = self.scheduler.get_request_counts()
if counts != self.last_counts: if counts != self.last_counts:
self.last_counts = counts self.last_counts = counts
stats = SchedulerStats(*counts) stats = SchedulerStats(*counts,
step_counter=self.step_counter,
current_wave=self.current_wave)
self.output_queue.put_nowait( self.output_queue.put_nowait(
(-1, EngineCoreOutputs(scheduler_stats=stats))) (-1, EngineCoreOutputs(scheduler_stats=stats)))
@ -1041,15 +1043,16 @@ class DPEngineCoreProc(EngineCoreProc):
self.output_queue.put_nowait( self.output_queue.put_nowait(
(client_index, (client_index,
EngineCoreOutputs(wave_complete=self.current_wave))) EngineCoreOutputs(wave_complete=self.current_wave)))
# Increment wave count and reset step counter.
self.current_wave += 1 self.current_wave += 1
self.step_counter = 0
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 32 steps. # Optimization - only perform finish-sync all-reduce every 32 steps.
self.counter += 1 self.step_counter += 1
if self.counter != 32: if self.step_counter % 32 != 0:
return True return True
self.counter = 0
return ParallelConfig.has_unfinished_dp(self.dp_group, return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished) local_unfinished)

View File

@ -86,11 +86,12 @@ class EngineCoreClient(ABC):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0, client_index: int = 0,
) -> "MPClient": ) -> "MPClient":
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
client_args = (vllm_config, executor_class, log_stats, client_args = (vllm_config, executor_class, log_stats,
client_addresses, client_index) client_addresses, client_count, client_index)
if parallel_config.data_parallel_size > 1: if parallel_config.data_parallel_size > 1:
if parallel_config.data_parallel_external_lb: if parallel_config.data_parallel_external_lb:
# External load balancer - client per DP rank. # External load balancer - client per DP rank.
@ -727,6 +728,7 @@ class AsyncMPClient(MPClient):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0): client_index: int = 0):
super().__init__( super().__init__(
asyncio_mode=True, asyncio_mode=True,
@ -929,11 +931,12 @@ class DPAsyncMPClient(AsyncMPClient):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0): client_index: int = 0):
self.current_wave = 0 self.current_wave = 0
super().__init__(vllm_config, executor_class, log_stats, super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index) client_addresses, client_count, client_index)
# List of [waiting, running] pair per engine. # List of [waiting, running] pair per engine.
# Used only by DPLBAsyncMPClient subclass. # Used only by DPLBAsyncMPClient subclass.
@ -1029,7 +1032,11 @@ class DPAsyncMPClient(AsyncMPClient):
counts, wave, running = msgspec.msgpack.decode(buf) counts, wave, running = msgspec.msgpack.decode(buf)
self.current_wave = wave self.current_wave = wave
self.engines_running = running self.engines_running = running
self.lb_engines = counts[count_slice] if counts is not None:
sliced_counts = counts[count_slice]
self.lb_engines = sliced_counts
logger.debug("Received counts: %s (%s)", sliced_counts,
count_slice)
resources.stats_update_task = asyncio.create_task( resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task()) run_engine_stats_update_task())
@ -1065,40 +1072,45 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
executor_class: type[Executor], executor_class: type[Executor],
log_stats: bool, log_stats: bool,
client_addresses: Optional[dict[str, str]] = None, client_addresses: Optional[dict[str, str]] = None,
client_count: int = 1,
client_index: int = 0): client_index: int = 0):
self.client_count = client_count
# To route aborts to the correct engine. # To route aborts to the correct engine.
self.reqs_in_flight: dict[str, EngineIdentity] = {} self.reqs_in_flight: dict[str, EngineIdentity] = {}
super().__init__(vllm_config, executor_class, log_stats, super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index) client_addresses, client_count, client_index)
assert len(self.core_engines) > 1 assert len(self.core_engines) > 1
self.eng_start_index = (len(self.core_engines) *
self.client_index) // client_count
def get_core_engine_for_request( def get_core_engine_for_request(
self, request: EngineCoreRequest) -> EngineIdentity: self, request: EngineCoreRequest) -> EngineIdentity:
# Engines are in rank order. # Engines are in rank order.
current_counts = self.lb_engines
if (eng_index := request.data_parallel_rank) is None: if (eng_index := request.data_parallel_rank) is None:
if not self.lb_engines: if not current_counts:
return self.core_engine return self.core_engine
# TODO use P2C alg for larger DP sizes # TODO use P2C alg for larger DP sizes
num_engines = len(self.lb_engines) num_engines = len(current_counts)
min_counts = [sys.maxsize, sys.maxsize] min_score = sys.maxsize
eng_index = 0 eng_index = 0
for i in range(num_engines): for i in range(num_engines):
# Start from client_index to help with balancing when engines # Start from client_index to help with balancing when engines
# are empty. # are empty.
idx = (self.client_index + i) % num_engines idx = (self.eng_start_index + i) % num_engines
counts = self.lb_engines[idx] waiting, running = current_counts[idx]
if counts < min_counts: score = waiting * 4 + running
min_counts = counts if score < min_score:
min_score = score
eng_index = idx eng_index = idx
# Adjust local counts for better balancing between stats updates # Increment local waiting count for better balancing between stats
# from the coordinator (which happen every 100ms). # updates from the coordinator (which happen every 100ms).
if min_counts[0]: current_counts[eng_index][0] += self.client_count
min_counts[0] += 1
else:
min_counts[1] += 1
chosen_engine = self.core_engines[eng_index] chosen_engine = self.core_engines[eng_index]
# Record which engine is chosen for this request, to handle aborts. # Record which engine is chosen for this request, to handle aborts.

View File

@ -33,6 +33,10 @@ class SchedulerStats:
num_running_reqs: int = 0 num_running_reqs: int = 0
num_waiting_reqs: int = 0 num_waiting_reqs: int = 0
# These are used for internal DP load-balancing.
step_counter: int = 0
current_wave: int = 0
kv_cache_usage: float = 0.0 kv_cache_usage: float = 0.0
prefix_cache_stats: PrefixCacheStats = field( prefix_cache_stats: PrefixCacheStats = field(

View File

@ -154,6 +154,7 @@ class APIServerProcessManager:
client_config = { client_config = {
"input_address": in_addr, "input_address": in_addr,
"output_address": out_addr, "output_address": out_addr,
"client_count": num_servers,
"client_index": i "client_index": i
} }
if stats_update_address is not None: if stats_update_address is not None: