diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index b8ec5461f7719..9bf4702320788 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -199,6 +199,8 @@ async def build_async_engine_client_from_engine_args(
 
         from vllm.v1.engine.async_llm import AsyncLLM
         async_llm: Optional[AsyncLLM] = None
+        client_count = client_config.pop(
+            "client_count") if client_config else 1
         client_index = client_config.pop(
             "client_index") if client_config else 0
         try:
@@ -208,6 +210,7 @@ async def build_async_engine_client_from_engine_args(
                 enable_log_requests=engine_args.enable_log_requests,
                 disable_log_stats=engine_args.disable_log_stats,
                 client_addresses=client_config,
+                client_count=client_count,
                 client_index=client_index)
 
             # Don't keep the dummy data in memory
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 308ca32105ba9..45f450291ab63 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -57,6 +57,7 @@ class AsyncLLM(EngineClient):
         start_engine_loop: bool = True,
         stat_loggers: Optional[list[StatLoggerFactory]] = None,
         client_addresses: Optional[dict[str, str]] = None,
+        client_count: int = 1,
         client_index: int = 0,
     ) -> None:
         """
@@ -120,6 +121,7 @@ class AsyncLLM(EngineClient):
             executor_class=executor_class,
             log_stats=self.log_stats,
             client_addresses=client_addresses,
+            client_count=client_count,
             client_index=client_index,
         )
 
@@ -156,6 +158,7 @@ class AsyncLLM(EngineClient):
         enable_log_requests: bool = False,
         disable_log_stats: bool = False,
         client_addresses: Optional[dict[str, str]] = None,
+        client_count: int = 1,
         client_index: int = 0,
         disable_log_requests: bool = True,  # Deprecated, will be removed
     ) -> "AsyncLLM":
@@ -176,6 +179,7 @@ class AsyncLLM(EngineClient):
             log_stats=not disable_log_stats,
             usage_context=usage_context,
             client_addresses=client_addresses,
+            client_count=client_count,
             client_index=client_index,
         )
 
diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py
index 8d8d1689e61e3..596edfdbe24f8 100644
--- a/vllm/v1/engine/coordinator.py
+++ b/vllm/v1/engine/coordinator.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import multiprocessing
 import time
 import weakref
@@ -65,18 +66,14 @@ class DPCoordinator:
 
         # Assume coordinator is colocated with front-end procs when not in
         # either external or hybrid DP LB mode.
+        local_only = not (external_lb or hybrid_lb)
         front_publish_address = get_engine_client_zmq_addr(
-            local_only=not external_lb and not hybrid_lb, host=host)
+            local_only=local_only, host=host)
 
         local_only_eng = dp_size == parallel_config.data_parallel_size_local
         back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
         back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
 
-        # When in external LB mode, load stats aren't published, only changes
-        # to request wave / running state, so we don't need to rate-limit the
-        # updates to the front-end proc(s).
-        min_stats_update_interval_ms = 0 if external_lb else 100
-
         context = get_mp_context()
         self.proc: multiprocessing.Process = context.Process(
             target=DPCoordinatorProc.run_coordinator,
@@ -86,7 +83,6 @@ class DPCoordinator:
                 "front_publish_address": front_publish_address,
                 "back_output_address": back_output_address,
                 "back_publish_address": back_publish_address,
-                "min_stats_update_interval_ms": min_stats_update_interval_ms,
             },
             daemon=True)
         self.proc.start()
@@ -125,10 +121,6 @@ class DPCoordinatorProc:
 
         self.stats_update_interval_ms = min_stats_update_interval_ms
 
-        self.current_wave = 0
-        self.engines_running = False
-        self.stats_changed = False
-
     @staticmethod
     def run_coordinator(
         engine_count: int,
@@ -155,6 +147,16 @@ class DPCoordinatorProc:
 
         decoder = MsgpackDecoder(EngineCoreOutputs)
 
+        # For tracking request wave progression.
+        current_wave = 0
+        engines_running = False
+
+        # For tracking request counts for internal load-balancing.
+        stats_changed = False
+        last_stats_step = -1
+        last_stats_wave = -1
+        last_step_counts: Optional[list[list[int]]] = None
+
         with make_zmq_socket(
                 path=front_publish_address,  # IPC
                 ctx=self.ctx,
@@ -191,21 +193,33 @@ class DPCoordinatorProc:
             while True:
                 elapsed = int(time.time() * 1000) - last_publish_time
                 # Send at stats_update_interval_ms interval if the stats have
-                # changed, or otherwise every 4 seconds.
+                # changed, or otherwise every 5 seconds.
                 wait_for = (self.stats_update_interval_ms
-                            if self.stats_changed else 4000)
-                events = poller.poll(timeout=max(0, wait_for - elapsed))
+                            if stats_changed else 5000)
+
+                # Wait at least 50ms to ensure we've received all stats for
+                # the current step.
+                min_timeout = 50 if last_step_counts is None else 0
+
+                events = poller.poll(timeout=max(min_timeout, wait_for -
+                                                 elapsed))
                 if not events:
                     # Poller timeout - publish current stats to front-ends.
-                    engine_req_counts_list = self._get_engine_counts()
-                    to_publish = (engine_req_counts_list, self.current_wave,
-                                  self.engines_running)
+                    if last_step_counts is not None:
+                        engine_req_counts_list = last_step_counts
+                        last_step_counts = None
+                    else:
+                        engine_req_counts_list = self._get_engine_counts()
+                        stats_changed = False
+
+                    to_publish = (engine_req_counts_list, current_wave,
+                                  engines_running)
                     publish_front.send(msgspec.msgpack.encode(to_publish))
                     last_publish_time = int(time.time() * 1000)
-                    self.stats_changed = False
                     continue
 
                 events = dict(events)
+                wave_state_changed = False
 
                 if publish_front in events:
                     buffer = publish_front.recv()
@@ -232,7 +246,7 @@ class DPCoordinatorProc:
                             #   current_wave
                             # we note that 0 is the wave number for the new
                             # engine
-                            self.engines_running = False
+                            engines_running = False
                             logger.info(
                                 "DPCoordinator scaled up from %s to %s "
                                 "engines", current_count, new_engine_count)
@@ -248,15 +262,15 @@ class DPCoordinatorProc:
                     # engines are paused, so that we can wake the other
                     # engines.
                     engine_to_exclude, wave = decoded
-                    if not self.engines_running:
-                        if wave < self.current_wave:
+                    if not engines_running:
+                        if wave < current_wave:
                             # If the wave number is stale, ensure the message
                             # is handled by all the engines.
                             engine_to_exclude = None
 
-                        self.engines_running = True
-                        self.stats_changed = True
-                        self._send_start_wave(publish_back, self.current_wave,
+                        engines_running = True
+                        wave_state_changed = True
+                        self._send_start_wave(publish_back, current_wave,
                                               engine_to_exclude)
 
                 if output_back in events:
@@ -274,36 +288,56 @@ class DPCoordinatorProc:
                         # 1. Updated request load stats - update our local
                         # state with these.
                         stats = self.engines[eng_index].request_counts
+                        stats_step = scheduler_stats.step_counter
+                        stats_wave = scheduler_stats.current_wave
+                        if (stats_wave > last_stats_wave
+                                or stats_wave == last_stats_wave
+                                and stats_step > last_stats_step):
+                            if stats_changed:
+                                last_step_counts = self._get_engine_counts(
+                                    do_copy=True)
+                            last_stats_step = stats_step
+                            last_stats_wave = stats_wave
+                        elif stats_wave != last_stats_wave or (
+                                stats_step != last_stats_step):
+                            logger.warning(
+                                "Received stats for out-of-order "
+                                "step (%d, %d) from engine %d (expected "
+                                "> (%d, %d))", stats_wave, stats_step,
+                                eng_index, last_stats_wave, last_stats_step)
                         stats[0] = scheduler_stats.num_waiting_reqs
                         stats[1] = scheduler_stats.num_running_reqs
-                        self.stats_changed = True
+                        stats_changed = True
 
                     if (wave := outputs.wave_complete) is not None:
                         # 2. Notification from rank 0 engine that we've
                         # moved into the global paused state
                         # (engines_running==False).
-                        if self.current_wave <= wave:
+                        if current_wave <= wave:
                             new_wave = wave + 1
                             logger.debug("Moving DP wave from %d to %d.",
-                                         self.current_wave, new_wave)
-                            self.current_wave = new_wave
-                            self.engines_running = False
-                            self.stats_changed = True
+                                         current_wave, new_wave)
+                            current_wave = new_wave
+                            engines_running = False
+                            wave_state_changed = True
                     elif (wave := outputs.start_wave) is not None and (
-                            wave > self.current_wave or
-                            (wave == self.current_wave
-                             and not self.engines_running)):
+                            wave > current_wave or
+                            (wave == current_wave and not engines_running)):
                         # 3. The engine received request for a non-current wave
                         # so we must ensure that other engines progress to the
                         # next wave (race condition handling).
                         logger.debug(
                             "Starting wave %d after notification of "
                             "stale wave request from engine.", wave)
-                        self.current_wave = wave
-                        self.engines_running = True
-                        self.stats_changed = True
+                        current_wave = wave
+                        engines_running = True
+                        wave_state_changed = True
                         self._send_start_wave(publish_back, wave, eng_index)
 
+                if wave_state_changed:
+                    message = (None, current_wave, engines_running)
+                    publish_front.send(msgspec.msgpack.encode(message))
+
     @staticmethod
     def _send_start_wave(socket: zmq.Socket, wave: int,
                          exclude_engine_index: Optional[int]):
@@ -316,6 +350,8 @@ class DPCoordinatorProc:
         socket.send_multipart(
             (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
 
-    def _get_engine_counts(self) -> list[list[int]]:
+    def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
         """Return list of [waiting, running] count lists for each engine."""
+        if do_copy:
+            return [copy.copy(e.request_counts) for e in self.engines]
         return [e.request_counts for e in self.engines]
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 0a889b2a0a184..79c47e1028882 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -928,7 +928,7 @@ class DPEngineCoreProc(EngineCoreProc):
     ):
         # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
-        self.counter = 0
+        self.step_counter = 0
         self.current_wave = 0
         self.last_counts = (0, 0)
 
@@ -999,7 +999,9 @@ class DPEngineCoreProc(EngineCoreProc):
         counts = self.scheduler.get_request_counts()
         if counts != self.last_counts:
             self.last_counts = counts
-            stats = SchedulerStats(*counts)
+            stats = SchedulerStats(*counts,
+                                   step_counter=self.step_counter,
+                                   current_wave=self.current_wave)
             self.output_queue.put_nowait(
                 (-1, EngineCoreOutputs(scheduler_stats=stats)))
 
@@ -1041,15 +1043,16 @@ class DPEngineCoreProc(EngineCoreProc):
             self.output_queue.put_nowait(
                 (client_index,
                  EngineCoreOutputs(wave_complete=self.current_wave)))
+            # Increment wave count and reset step counter.
             self.current_wave += 1
+            self.step_counter = 0
 
     def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
 
         # Optimization - only perform finish-sync all-reduce every 32 steps.
-        self.counter += 1
-        if self.counter != 32:
+        self.step_counter += 1
+        if self.step_counter % 32 != 0:
             return True
-        self.counter = 0
 
         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 26985df6f62df..4d30bb6b74466 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -86,11 +86,12 @@ class EngineCoreClient(ABC):
         executor_class: type[Executor],
         log_stats: bool,
         client_addresses: Optional[dict[str, str]] = None,
+        client_count: int = 1,
         client_index: int = 0,
     ) -> "MPClient":
         parallel_config = vllm_config.parallel_config
         client_args = (vllm_config, executor_class, log_stats,
-                       client_addresses, client_index)
+                       client_addresses, client_count, client_index)
         if parallel_config.data_parallel_size > 1:
             if parallel_config.data_parallel_external_lb:
                 # External load balancer - client per DP rank.
@@ -727,6 +728,7 @@ class AsyncMPClient(MPClient):
                  executor_class: type[Executor],
                  log_stats: bool,
                  client_addresses: Optional[dict[str, str]] = None,
+                 client_count: int = 1,
                  client_index: int = 0):
         super().__init__(
             asyncio_mode=True,
@@ -929,11 +931,12 @@ class DPAsyncMPClient(AsyncMPClient):
                  executor_class: type[Executor],
                  log_stats: bool,
                  client_addresses: Optional[dict[str, str]] = None,
+                 client_count: int = 1,
                  client_index: int = 0):
         self.current_wave = 0
 
         super().__init__(vllm_config, executor_class, log_stats,
-                         client_addresses, client_index)
+                         client_addresses, client_count, client_index)
 
         # List of [waiting, running] pair per engine.
         # Used only by DPLBAsyncMPClient subclass.
@@ -1029,7 +1032,11 @@ class DPAsyncMPClient(AsyncMPClient):
                     counts, wave, running = msgspec.msgpack.decode(buf)
                     self.current_wave = wave
                     self.engines_running = running
-                    self.lb_engines = counts[count_slice]
+                    if counts is not None:
+                        sliced_counts = counts[count_slice]
+                        self.lb_engines = sliced_counts
+                        logger.debug("Received counts: %s (%s)", sliced_counts,
+                                     count_slice)
 
         resources.stats_update_task = asyncio.create_task(
             run_engine_stats_update_task())
@@ -1065,40 +1072,45 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
                  executor_class: type[Executor],
                  log_stats: bool,
                  client_addresses: Optional[dict[str, str]] = None,
+                 client_count: int = 1,
                  client_index: int = 0):
 
+        self.client_count = client_count
+
         # To route aborts to the correct engine.
         self.reqs_in_flight: dict[str, EngineIdentity] = {}
 
         super().__init__(vllm_config, executor_class, log_stats,
-                         client_addresses, client_index)
+                         client_addresses, client_count, client_index)
 
         assert len(self.core_engines) > 1
 
+        self.eng_start_index = (len(self.core_engines) *
+                                self.client_index) // client_count
+
     def get_core_engine_for_request(
             self, request: EngineCoreRequest) -> EngineIdentity:
         # Engines are in rank order.
+        current_counts = self.lb_engines
         if (eng_index := request.data_parallel_rank) is None:
-            if not self.lb_engines:
+            if not current_counts:
                 return self.core_engine
             # TODO use P2C alg for larger DP sizes
-            num_engines = len(self.lb_engines)
-            min_counts = [sys.maxsize, sys.maxsize]
+            num_engines = len(current_counts)
+            min_score = sys.maxsize
            eng_index = 0
             for i in range(num_engines):
                 # Start from client_index to help with balancing when engines
                 # are empty.
-                idx = (self.client_index + i) % num_engines
-                counts = self.lb_engines[idx]
-                if counts < min_counts:
-                    min_counts = counts
+                idx = (self.eng_start_index + i) % num_engines
+                waiting, running = current_counts[idx]
+                score = waiting * 4 + running
+                if score < min_score:
+                    min_score = score
                     eng_index = idx
-            # Adjust local counts for better balancing between stats updates
-            # from the coordinator (which happen every 100ms).
-            if min_counts[0]:
-                min_counts[0] += 1
-            else:
-                min_counts[1] += 1
+            # Increment local waiting count for better balancing between stats
+            # updates from the coordinator (which happen every 100ms).
+            current_counts[eng_index][0] += self.client_count
 
         chosen_engine = self.core_engines[eng_index]
         # Record which engine is chosen for this request, to handle aborts.
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 1eb10ccb6c493..9a80460261e02 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -33,6 +33,10 @@ class SchedulerStats:
     num_running_reqs: int = 0
     num_waiting_reqs: int = 0
 
+    # These are used for internal DP load-balancing.
+    step_counter: int = 0
+    current_wave: int = 0
+
     kv_cache_usage: float = 0.0
 
     prefix_cache_stats: PrefixCacheStats = field(
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index c74d8c543f76c..d0175695c1d0f 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -154,6 +154,7 @@ class APIServerProcessManager:
             client_config = {
                 "input_address": in_addr,
                 "output_address": out_addr,
+                "client_count": num_servers,
                 "client_index": i
            }
             if stats_update_address is not None:
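
Note: the following standalone sketch (the choose_engine helper is hypothetical and not part of the diff above) illustrates the selection heuristic that DPLBAsyncMPClient.get_core_engine_for_request now uses. Each engine is scored as waiting * 4 + running, the scan starts at a per-client offset derived from client_index and client_count so that multiple API servers don't all pick the same idle engine, and the chosen engine's local waiting count is bumped by client_count to approximate the load added by all clients between coordinator stat refreshes.

import sys
from typing import Optional


def choose_engine(counts: list[list[int]],
                  client_index: int,
                  client_count: int,
                  requested_rank: Optional[int] = None) -> int:
    """Pick an engine index the way the DP load-balancing client does.

    counts holds one [waiting, running] pair per engine.
    """
    if requested_rank is not None:
        return requested_rank
    num_engines = len(counts)
    # Each API server starts scanning at a different offset so that idle
    # engines are not all claimed by the same client at the same time.
    start = (num_engines * client_index) // client_count
    best_index, best_score = 0, sys.maxsize
    for i in range(num_engines):
        idx = (start + i) % num_engines
        waiting, running = counts[idx]
        # Waiting requests are weighted more heavily than running ones.
        score = waiting * 4 + running
        if score < best_score:
            best_score, best_index = score, idx
    # Bump the local waiting count by client_count so that, until the next
    # coordinator update, this client also accounts for the requests the
    # other API server processes are likely to be dispatching.
    counts[best_index][0] += client_count
    return best_index


if __name__ == "__main__":
    # Two API servers (client_count=2) balancing over four engines.
    engine_counts = [[0, 3], [2, 1], [0, 0], [5, 0]]
    print(choose_engine(engine_counts, client_index=0, client_count=2))  # -> 2
    print(engine_counts)  # -> [[0, 3], [2, 1], [2, 0], [5, 0]]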