diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py
index fc45eea3a73cf..5ebee31ebaedc 100644
--- a/vllm/v1/engine/coordinator.py
+++ b/vllm/v1/engine/coordinator.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
 import multiprocessing
 import time
 import weakref
@@ -66,18 +67,14 @@ class DPCoordinator:
 
         # Assume coordinator is colocated with front-end procs when not in
         # either external or hybrid DP LB mode.
+        local_only = not (external_lb or hybrid_lb)
         front_publish_address = get_engine_client_zmq_addr(
-            local_only=not external_lb and not hybrid_lb, host=host)
+            local_only=local_only, host=host)
 
         local_only_eng = dp_size == parallel_config.data_parallel_size_local
         back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
         back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
 
-        # When in external LB mode, load stats aren't published, only changes
-        # to request wave / running state, so we don't need to rate-limit the
-        # updates to the front-end proc(s).
-        min_stats_update_interval_ms = 0 if external_lb else 100
-
         context = get_mp_context()
         self.proc: multiprocessing.Process = context.Process(
             target=DPCoordinatorProc.run_coordinator,
@@ -87,7 +84,6 @@ class DPCoordinator:
                 "front_publish_address": front_publish_address,
                 "back_output_address": back_output_address,
                 "back_publish_address": back_publish_address,
-                "min_stats_update_interval_ms": min_stats_update_interval_ms,
             },
             daemon=True)
         self.proc.start()
@@ -126,10 +122,6 @@ class DPCoordinatorProc:
 
         self.stats_update_interval_ms = min_stats_update_interval_ms
 
-        self.current_wave = 0
-        self.engines_running = False
-        self.stats_changed = False
-
     @staticmethod
     def run_coordinator(
         engine_count: int,
@@ -156,6 +148,13 @@ class DPCoordinatorProc:
 
         decoder = MsgpackDecoder(EngineCoreOutputs)
 
+        current_wave = 0
+        engines_running = False
+
+        stats_changed = False
+        last_stats_step = -1
+        last_step_counts: Optional[list[list[int]]] = None
+
         with make_zmq_socket(
                 path=front_publish_address,  # IPC
                 ctx=self.ctx,
@@ -180,21 +179,33 @@ class DPCoordinatorProc:
             while True:
                 elapsed = int(time.time() * 1000) - last_publish_time
                 # Send at stats_update_interval_ms interval if the stats have
-                # changed, or otherwise every 4 seconds.
+                # changed, or otherwise every 5 seconds.
                 wait_for = (self.stats_update_interval_ms
-                            if self.stats_changed else 4000)
-                events = poller.poll(timeout=max(0, wait_for - elapsed))
+                            if stats_changed else 5000)
+
+                # Wait at least 50ms to ensure we've received all stats for
+                # the current step.
+                min_timeout = 50 if last_step_counts is None else 0
+
+                events = poller.poll(timeout=max(min_timeout, wait_for -
+                                                 elapsed))
                 if not events:
                     # Poller timeout - publish current stats to front-ends.
-                    engine_req_counts_list = self._get_engine_counts()
-                    to_publish = (engine_req_counts_list, self.current_wave,
-                                  self.engines_running)
+                    if last_step_counts is not None:
+                        engine_req_counts_list = last_step_counts
+                        last_step_counts = None
+                    else:
+                        engine_req_counts_list = self._get_engine_counts()
+                        stats_changed = False
+
+                    to_publish = (engine_req_counts_list, current_wave,
+                                  engines_running)
                     publish_front.send(msgspec.msgpack.encode(to_publish))
                     last_publish_time = int(time.time() * 1000)
-                    self.stats_changed = False
                     continue
 
                 events = dict(events)
+                wave_state_changed = False
 
                 if publish_front in events:
                     buffer = publish_front.recv()
@@ -221,7 +232,7 @@ class DPCoordinatorProc:
                             # current_wave
                             # we note that 0 is the wave number for the new
                             # engine
-                            self.engines_running = False
+                            engines_running = False
                             logger.info(
                                 "DPCoordinator scaled up from %s to %s "
                                 "engines", current_count, new_engine_count)
@@ -237,15 +248,15 @@ class DPCoordinatorProc:
                     # engines are paused, so that we can wake the other
                     # engines.
                     engine_to_exclude, wave = decoded
-                    if not self.engines_running:
-                        if wave < self.current_wave:
+                    if not engines_running:
+                        if wave < current_wave:
                             # If the wave number is stale, ensure the message
                             # is handled by all the engines.
                             engine_to_exclude = None
 
-                        self.engines_running = True
-                        self.stats_changed = True
-                        self._send_start_wave(publish_back, self.current_wave,
+                        engines_running = True
+                        wave_state_changed = True
+                        self._send_start_wave(publish_back, current_wave,
                                               engine_to_exclude)
 
                 if output_back in events:
@@ -263,36 +274,50 @@ class DPCoordinatorProc:
                         # 1. Updated request load stats - update our local
                         # state with these.
                         stats = self.engines[eng_index].request_counts
+                        stats_step = scheduler_stats.step_counter
+                        # A different step's stats arrived while earlier
+                        # changes are still unpublished: snapshot the counts
+                        # as they stood before applying this update.
+                        if stats_changed and stats_step != last_stats_step:
+                            last_step_counts = self._get_engine_counts(
+                                do_copy=True)
+                        elif stats_step < last_stats_step:
+                            logger.warning("Received stats for out-of-order "
+                                           "step from engine %s", eng_index)
                         stats[0] = scheduler_stats.num_waiting_reqs
                         stats[1] = scheduler_stats.num_running_reqs
-                        self.stats_changed = True
+                        last_stats_step = stats_step
+                        stats_changed = True
 
                     if (wave := outputs.wave_complete) is not None:
                         # 2. Notification from rank 0 engine that we've
                         # moved into the global paused state
                         # (engines_running==False).
-                        if self.current_wave <= wave:
+                        if current_wave <= wave:
                             new_wave = wave + 1
                             logger.debug("Moving DP wave from %d to %d.",
-                                         self.current_wave, new_wave)
-                            self.current_wave = new_wave
-                            self.engines_running = False
-                            self.stats_changed = True
+                                         current_wave, new_wave)
+                            current_wave = new_wave
+                            engines_running = False
+                            wave_state_changed = True
                     elif (wave := outputs.start_wave) is not None and (
-                            wave > self.current_wave or
-                        (wave == self.current_wave
-                         and not self.engines_running)):
+                            wave > current_wave or
+                        (wave == current_wave and not engines_running)):
                         # 3. The engine received request for a non-current wave
                         # so we must ensure that other engines progress to the
                         # next wave (race condition handling).
                         logger.debug(
                             "Starting wave %d after notification of "
                             "stale wave request from engine.", wave)
-                        self.current_wave = wave
-                        self.engines_running = True
-                        self.stats_changed = True
+                        current_wave = wave
+                        engines_running = True
+                        wave_state_changed = True
                         self._send_start_wave(publish_back, wave, eng_index)
 
+                if wave_state_changed:
+                    message = (None, current_wave, engines_running)
+                    publish_front.send(msgspec.msgpack.encode(message))
+
     @staticmethod
     def _send_start_wave(socket: zmq.Socket, wave: int,
                          exclude_engine_index: Optional[int]):
@@ -305,6 +330,8 @@ class DPCoordinatorProc:
         socket.send_multipart(
             (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))
 
-    def _get_engine_counts(self) -> list[list[int]]:
+    def _get_engine_counts(self, do_copy: bool = False) -> list[list[int]]:
         """Return list of [waiting, running] count lists for each engine."""
+        if do_copy:
+            return [copy.copy(e.request_counts) for e in self.engines]
         return [e.request_counts for e in self.engines]
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 88c511606d7c5..a1642a9990333 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -874,7 +874,7 @@ class DPEngineCoreProc(EngineCoreProc):
 
         # Counts forward-passes of the model so that we can synchronize
         # finished with DP peers every N steps.
-        self.counter = 0
+        self.step_counter = 0
         self.current_wave = 0
         self.last_counts = (0, 0)
 
@@ -954,7 +954,7 @@ class DPEngineCoreProc(EngineCoreProc):
         counts = self.scheduler.get_request_counts()
         if counts != self.last_counts:
             self.last_counts = counts
-            stats = SchedulerStats(*counts)
+            stats = SchedulerStats(*counts, step_counter=self.step_counter)
             self.output_queue.put_nowait(
                 (-1, EngineCoreOutputs(scheduler_stats=stats)))
 
@@ -1001,10 +1001,11 @@ class DPEngineCoreProc(EngineCoreProc):
     def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
 
         # Optimization - only perform finish-sync all-reduce every 32 steps.
-        self.counter += 1
-        if self.counter != 32:
+        # step_counter is never reset: it is also published in SchedulerStats
+        # and the coordinator compares successive values to detect new steps.
+        self.step_counter += 1
+        if self.step_counter % 32 != 0:
             return True
-        self.counter = 0
 
         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 69ae3690d00e9..7f712a0acb81e 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -970,7 +970,12 @@ class DPAsyncMPClient(AsyncMPClient):
                     counts, wave, running = msgspec.msgpack.decode(buf)
                     self.current_wave = wave
                     self.engines_running = running
-                    self.lb_engines = counts[count_slice]
+                    if counts is not None:
+                        sliced_counts = counts[count_slice]
+                        self.lb_engines = sliced_counts
+                        # TODO: decide whether to keep this debug log.
+                        logger.debug("Received counts: %s (%s)",
+                                     sliced_counts, count_slice)
 
         resources.stats_update_task = asyncio.create_task(
             run_engine_stats_update_task())
@@ -1019,27 +1024,27 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
     def get_core_engine_for_request(
             self, request: EngineCoreRequest) -> EngineIdentity:
         # Engines are in rank order.
+        current_counts = self.lb_engines
         if (eng_index := request.data_parallel_rank) is None:
-            if not self.lb_engines:
+            if not current_counts:
                 return self.core_engine
             # TODO use P2C alg for larger DP sizes
-            num_engines = len(self.lb_engines)
-            min_counts = [sys.maxsize, sys.maxsize]
+            num_engines = len(current_counts)
+            min_score = sys.maxsize
             eng_index = 0
             for i in range(num_engines):
                 # Start from client_index to help with balancing when engines
                 # are empty.
                 idx = (self.client_index + i) % num_engines
-                counts = self.lb_engines[idx]
-                if counts < min_counts:
-                    min_counts = counts
+                waiting, running = current_counts[idx]
+                # Waiting requests weigh 4x as much as running ones.
+                score = waiting * 4 + running
+                if score < min_score:
+                    min_score = score
                     eng_index = idx
-            # Adjust local counts for better balancing between stats updates
-            # from the coordinator (which happen every 100ms).
-            if min_counts[0]:
-                min_counts[0] += 1
-            else:
-                min_counts[1] += 1
+            # Increment local waiting count for better balancing between stats
+            # updates from the coordinator (which happen every 100ms).
+            current_counts[eng_index][0] += 1
 
         chosen_engine = self.core_engines[eng_index]
         # Record which engine is chosen for this request, to handle aborts.
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 1eb10ccb6c493..5e2f9e739280a 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -33,6 +33,9 @@ class SchedulerStats:
     num_running_reqs: int = 0
     num_waiting_reqs: int = 0
 
+    # Engine step count at the time these stats were captured.
+    step_counter: int = 0
+
     kv_cache_usage: float = 0.0
 
     prefix_cache_stats: PrefixCacheStats = field(