mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 18:35:40 +08:00
[BugFix] Improve internal DP load balancing
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
8ed01e32f7
commit
8177e2f02f
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import copy
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
@ -66,18 +67,14 @@ class DPCoordinator:
|
|||||||
|
|
||||||
# Assume coordinator is colocated with front-end procs when not in
|
# Assume coordinator is colocated with front-end procs when not in
|
||||||
# either external or hybrid DP LB mode.
|
# either external or hybrid DP LB mode.
|
||||||
|
local_only = not (external_lb or hybrid_lb)
|
||||||
front_publish_address = get_engine_client_zmq_addr(
|
front_publish_address = get_engine_client_zmq_addr(
|
||||||
local_only=not external_lb and not hybrid_lb, host=host)
|
local_only=local_only, host=host)
|
||||||
|
|
||||||
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
||||||
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||||
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||||
|
|
||||||
# When in external LB mode, load stats aren't published, only changes
|
|
||||||
# to request wave / running state, so we don't need to rate-limit the
|
|
||||||
# updates to the front-end proc(s).
|
|
||||||
min_stats_update_interval_ms = 0 if external_lb else 100
|
|
||||||
|
|
||||||
context = get_mp_context()
|
context = get_mp_context()
|
||||||
self.proc: multiprocessing.Process = context.Process(
|
self.proc: multiprocessing.Process = context.Process(
|
||||||
target=DPCoordinatorProc.run_coordinator,
|
target=DPCoordinatorProc.run_coordinator,
|
||||||
@ -87,7 +84,6 @@ class DPCoordinator:
|
|||||||
"front_publish_address": front_publish_address,
|
"front_publish_address": front_publish_address,
|
||||||
"back_output_address": back_output_address,
|
"back_output_address": back_output_address,
|
||||||
"back_publish_address": back_publish_address,
|
"back_publish_address": back_publish_address,
|
||||||
"min_stats_update_interval_ms": min_stats_update_interval_ms,
|
|
||||||
},
|
},
|
||||||
daemon=True)
|
daemon=True)
|
||||||
self.proc.start()
|
self.proc.start()
|
||||||
@ -126,10 +122,6 @@ class DPCoordinatorProc:
|
|||||||
|
|
||||||
self.stats_update_interval_ms = min_stats_update_interval_ms
|
self.stats_update_interval_ms = min_stats_update_interval_ms
|
||||||
|
|
||||||
self.current_wave = 0
|
|
||||||
self.engines_running = False
|
|
||||||
self.stats_changed = False
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def run_coordinator(
|
def run_coordinator(
|
||||||
engine_count: int,
|
engine_count: int,
|
||||||
@ -156,6 +148,13 @@ class DPCoordinatorProc:
|
|||||||
|
|
||||||
decoder = MsgpackDecoder(EngineCoreOutputs)
|
decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||||
|
|
||||||
|
current_wave = 0
|
||||||
|
engines_running = False
|
||||||
|
|
||||||
|
stats_changed = False
|
||||||
|
last_stats_step = -1
|
||||||
|
last_step_counts: Optional[list[list[int]]] = None
|
||||||
|
|
||||||
with make_zmq_socket(
|
with make_zmq_socket(
|
||||||
path=front_publish_address, # IPC
|
path=front_publish_address, # IPC
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
@ -180,21 +179,33 @@ class DPCoordinatorProc:
|
|||||||
while True:
|
while True:
|
||||||
elapsed = int(time.time() * 1000) - last_publish_time
|
elapsed = int(time.time() * 1000) - last_publish_time
|
||||||
# Send at stats_update_interval_ms interval if the stats have
|
# Send at stats_update_interval_ms interval if the stats have
|
||||||
# changed, or otherwise every 4 seconds.
|
# changed, or otherwise every 5 seconds.
|
||||||
wait_for = (self.stats_update_interval_ms
|
wait_for = (self.stats_update_interval_ms
|
||||||
if self.stats_changed else 4000)
|
if stats_changed else 5000)
|
||||||
events = poller.poll(timeout=max(0, wait_for - elapsed))
|
|
||||||
|
# Wait at least 50ms to ensure we've received all stats for
|
||||||
|
# the current step.
|
||||||
|
min_timeout = 50 if last_step_counts is None else 0
|
||||||
|
|
||||||
|
events = poller.poll(timeout=max(min_timeout, wait_for -
|
||||||
|
elapsed))
|
||||||
if not events:
|
if not events:
|
||||||
# Poller timeout - publish current stats to front-ends.
|
# Poller timeout - publish current stats to front-ends.
|
||||||
engine_req_counts_list = self._get_engine_counts()
|
if last_step_counts is not None:
|
||||||
to_publish = (engine_req_counts_list, self.current_wave,
|
engine_req_counts_list = last_step_counts
|
||||||
self.engines_running)
|
last_step_counts = None
|
||||||
|
else:
|
||||||
|
engine_req_counts_list = self._get_engine_counts()
|
||||||
|
stats_changed = False
|
||||||
|
|
||||||
|
to_publish = (engine_req_counts_list, current_wave,
|
||||||
|
engines_running)
|
||||||
publish_front.send(msgspec.msgpack.encode(to_publish))
|
publish_front.send(msgspec.msgpack.encode(to_publish))
|
||||||
last_publish_time = int(time.time() * 1000)
|
last_publish_time = int(time.time() * 1000)
|
||||||
self.stats_changed = False
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
events = dict(events)
|
events = dict(events)
|
||||||
|
wave_state_changed = False
|
||||||
|
|
||||||
if publish_front in events:
|
if publish_front in events:
|
||||||
buffer = publish_front.recv()
|
buffer = publish_front.recv()
|
||||||
@ -221,7 +232,7 @@ class DPCoordinatorProc:
|
|||||||
# current_wave
|
# current_wave
|
||||||
# we note that 0 is the wave number for the new
|
# we note that 0 is the wave number for the new
|
||||||
# engine
|
# engine
|
||||||
self.engines_running = False
|
engines_running = False
|
||||||
logger.info(
|
logger.info(
|
||||||
"DPCoordinator scaled up from %s to %s "
|
"DPCoordinator scaled up from %s to %s "
|
||||||
"engines", current_count, new_engine_count)
|
"engines", current_count, new_engine_count)
|
||||||
@ -237,15 +248,15 @@ class DPCoordinatorProc:
|
|||||||
# engines are paused, so that we can wake the other
|
# engines are paused, so that we can wake the other
|
||||||
# engines.
|
# engines.
|
||||||
engine_to_exclude, wave = decoded
|
engine_to_exclude, wave = decoded
|
||||||
if not self.engines_running:
|
if not engines_running:
|
||||||
if wave < self.current_wave:
|
if wave < current_wave:
|
||||||
# If the wave number is stale, ensure the message
|
# If the wave number is stale, ensure the message
|
||||||
# is handled by all the engines.
|
# is handled by all the engines.
|
||||||
engine_to_exclude = None
|
engine_to_exclude = None
|
||||||
|
|
||||||
self.engines_running = True
|
engines_running = True
|
||||||
self.stats_changed = True
|
wave_state_changed = True
|
||||||
self._send_start_wave(publish_back, self.current_wave,
|
self._send_start_wave(publish_back, current_wave,
|
||||||
engine_to_exclude)
|
engine_to_exclude)
|
||||||
|
|
||||||
if output_back in events:
|
if output_back in events:
|
||||||
@ -263,36 +274,47 @@ class DPCoordinatorProc:
|
|||||||
# 1. Updated request load stats - update our local
|
# 1. Updated request load stats - update our local
|
||||||
# state with these.
|
# state with these.
|
||||||
stats = self.engines[eng_index].request_counts
|
stats = self.engines[eng_index].request_counts
|
||||||
|
stats_step = scheduler_stats.step_counter
|
||||||
|
if stats_changed and stats_step != last_stats_step:
|
||||||
|
last_step_counts = self._get_engine_counts(
|
||||||
|
do_copy=True)
|
||||||
|
elif stats_step < last_stats_step:
|
||||||
|
logger.warning("Received stats for out-of-order "
|
||||||
|
"step from engine {eng_index}")
|
||||||
stats[0] = scheduler_stats.num_waiting_reqs
|
stats[0] = scheduler_stats.num_waiting_reqs
|
||||||
stats[1] = scheduler_stats.num_running_reqs
|
stats[1] = scheduler_stats.num_running_reqs
|
||||||
self.stats_changed = True
|
last_stats_step = stats_step
|
||||||
|
stats_changed = True
|
||||||
|
|
||||||
if (wave := outputs.wave_complete) is not None:
|
if (wave := outputs.wave_complete) is not None:
|
||||||
# 2. Notification from rank 0 engine that we've
|
# 2. Notification from rank 0 engine that we've
|
||||||
# moved into the global paused state
|
# moved into the global paused state
|
||||||
# (engines_running==False).
|
# (engines_running==False).
|
||||||
if self.current_wave <= wave:
|
if current_wave <= wave:
|
||||||
new_wave = wave + 1
|
new_wave = wave + 1
|
||||||
logger.debug("Moving DP wave from %d to %d.",
|
logger.debug("Moving DP wave from %d to %d.",
|
||||||
self.current_wave, new_wave)
|
current_wave, new_wave)
|
||||||
self.current_wave = new_wave
|
current_wave = new_wave
|
||||||
self.engines_running = False
|
engines_running = False
|
||||||
self.stats_changed = True
|
wave_state_changed = True
|
||||||
elif (wave := outputs.start_wave) is not None and (
|
elif (wave := outputs.start_wave) is not None and (
|
||||||
wave > self.current_wave or
|
wave > current_wave or
|
||||||
(wave == self.current_wave
|
(wave == current_wave and not engines_running)):
|
||||||
and not self.engines_running)):
|
|
||||||
# 3. The engine received request for a non-current wave
|
# 3. The engine received request for a non-current wave
|
||||||
# so we must ensure that other engines progress to the
|
# so we must ensure that other engines progress to the
|
||||||
# next wave (race condition handling).
|
# next wave (race condition handling).
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Starting wave %d after notification of "
|
"Starting wave %d after notification of "
|
||||||
"stale wave request from engine.", wave)
|
"stale wave request from engine.", wave)
|
||||||
self.current_wave = wave
|
current_wave = wave
|
||||||
self.engines_running = True
|
engines_running = True
|
||||||
self.stats_changed = True
|
wave_state_changed = True
|
||||||
self._send_start_wave(publish_back, wave, eng_index)
|
self._send_start_wave(publish_back, wave, eng_index)
|
||||||
|
|
||||||
|
if wave_state_changed:
|
||||||
|
message = (None, current_wave, engines_running)
|
||||||
|
publish_front.send(msgspec.msgpack.encode(message))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _send_start_wave(socket: zmq.Socket, wave: int,
|
def _send_start_wave(socket: zmq.Socket, wave: int,
|
||||||
exclude_engine_index: Optional[int]):
|
exclude_engine_index: Optional[int]):
|
||||||
@ -305,6 +327,8 @@ class DPCoordinatorProc:
|
|||||||
socket.send_multipart(
|
socket.send_multipart(
|
||||||
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
|
(EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
|
||||||
|
|
||||||
def _get_engine_counts(self) -> list[list[int]]:
|
def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
|
||||||
"""Return list of [waiting, running] count lists for each engine."""
|
"""Return list of [waiting, running] count lists for each engine."""
|
||||||
|
if do_copy:
|
||||||
|
return [copy.copy(e.request_counts) for e in self.engines]
|
||||||
return [e.request_counts for e in self.engines]
|
return [e.request_counts for e in self.engines]
|
||||||
|
|||||||
@ -874,7 +874,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
|
|
||||||
# Counts forward-passes of the model so that we can synchronize
|
# Counts forward-passes of the model so that we can synchronize
|
||||||
# finished with DP peers every N steps.
|
# finished with DP peers every N steps.
|
||||||
self.counter = 0
|
self.step_counter = 0
|
||||||
self.current_wave = 0
|
self.current_wave = 0
|
||||||
self.last_counts = (0, 0)
|
self.last_counts = (0, 0)
|
||||||
|
|
||||||
@ -954,7 +954,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
counts = self.scheduler.get_request_counts()
|
counts = self.scheduler.get_request_counts()
|
||||||
if counts != self.last_counts:
|
if counts != self.last_counts:
|
||||||
self.last_counts = counts
|
self.last_counts = counts
|
||||||
stats = SchedulerStats(*counts)
|
stats = SchedulerStats(*counts, step_counter=self.step_counter)
|
||||||
self.output_queue.put_nowait(
|
self.output_queue.put_nowait(
|
||||||
(-1, EngineCoreOutputs(scheduler_stats=stats)))
|
(-1, EngineCoreOutputs(scheduler_stats=stats)))
|
||||||
|
|
||||||
@ -1001,10 +1001,10 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||||
|
|
||||||
# Optimization - only perform finish-sync all-reduce every 32 steps.
|
# Optimization - only perform finish-sync all-reduce every 32 steps.
|
||||||
self.counter += 1
|
self.step_counter += 1
|
||||||
if self.counter != 32:
|
if self.step_counter != 32:
|
||||||
return True
|
return True
|
||||||
self.counter = 0
|
self.step_counter = 0
|
||||||
|
|
||||||
return ParallelConfig.has_unfinished_dp(self.dp_group,
|
return ParallelConfig.has_unfinished_dp(self.dp_group,
|
||||||
local_unfinished)
|
local_unfinished)
|
||||||
|
|||||||
@ -970,7 +970,12 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
counts, wave, running = msgspec.msgpack.decode(buf)
|
counts, wave, running = msgspec.msgpack.decode(buf)
|
||||||
self.current_wave = wave
|
self.current_wave = wave
|
||||||
self.engines_running = running
|
self.engines_running = running
|
||||||
self.lb_engines = counts[count_slice]
|
if counts is not None:
|
||||||
|
sliced_counts = counts[count_slice]
|
||||||
|
self.lb_engines = sliced_counts
|
||||||
|
# TODO: decide whether to keep this debug log.
|
||||||
|
logger.debug("Received counts: %s (%s)",
|
||||||
|
sliced_counts, count_slice)
|
||||||
|
|
||||||
resources.stats_update_task = asyncio.create_task(
|
resources.stats_update_task = asyncio.create_task(
|
||||||
run_engine_stats_update_task())
|
run_engine_stats_update_task())
|
||||||
@ -1019,27 +1024,26 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
|
|||||||
def get_core_engine_for_request(
|
def get_core_engine_for_request(
|
||||||
self, request: EngineCoreRequest) -> EngineIdentity:
|
self, request: EngineCoreRequest) -> EngineIdentity:
|
||||||
# Engines are in rank order.
|
# Engines are in rank order.
|
||||||
|
current_counts = self.lb_engines
|
||||||
if (eng_index := request.data_parallel_rank) is None:
|
if (eng_index := request.data_parallel_rank) is None:
|
||||||
if not self.lb_engines:
|
if not current_counts:
|
||||||
return self.core_engine
|
return self.core_engine
|
||||||
# TODO use P2C alg for larger DP sizes
|
# TODO use P2C alg for larger DP sizes
|
||||||
num_engines = len(self.lb_engines)
|
num_engines = len(current_counts)
|
||||||
min_counts = [sys.maxsize, sys.maxsize]
|
min_score = sys.maxsize
|
||||||
eng_index = 0
|
eng_index = 0
|
||||||
for i in range(num_engines):
|
for i in range(num_engines):
|
||||||
# Start from client_index to help with balancing when engines
|
# Start from client_index to help with balancing when engines
|
||||||
# are empty.
|
# are empty.
|
||||||
idx = (self.client_index + i) % num_engines
|
idx = (self.client_index + i) % num_engines
|
||||||
counts = self.lb_engines[idx]
|
waiting, running = current_counts[idx]
|
||||||
if counts < min_counts:
|
score = waiting * 4 + running
|
||||||
min_counts = counts
|
if score < min_score:
|
||||||
|
min_score = score
|
||||||
eng_index = idx
|
eng_index = idx
|
||||||
# Adjust local counts for better balancing between stats updates
|
# Increment local waiting count for better balancing between stats
|
||||||
# from the coordinator (which happen every 100ms).
|
# updates from the coordinator (which happen every 100ms).
|
||||||
if min_counts[0]:
|
current_counts[eng_index][0] += 1
|
||||||
min_counts[0] += 1
|
|
||||||
else:
|
|
||||||
min_counts[1] += 1
|
|
||||||
|
|
||||||
chosen_engine = self.core_engines[eng_index]
|
chosen_engine = self.core_engines[eng_index]
|
||||||
# Record which engine is chosen for this request, to handle aborts.
|
# Record which engine is chosen for this request, to handle aborts.
|
||||||
|
|||||||
@ -33,6 +33,8 @@ class SchedulerStats:
|
|||||||
num_running_reqs: int = 0
|
num_running_reqs: int = 0
|
||||||
num_waiting_reqs: int = 0
|
num_waiting_reqs: int = 0
|
||||||
|
|
||||||
|
step_counter: int = 0
|
||||||
|
|
||||||
kv_cache_usage: float = 0.0
|
kv_cache_usage: float = 0.0
|
||||||
|
|
||||||
prefix_cache_stats: PrefixCacheStats = field(
|
prefix_cache_stats: PrefixCacheStats = field(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user