From 2cf8ff64c77f14a1aba2b5147ec7405e7ffc89cd Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Jul 2025 20:17:54 +0000 Subject: [PATCH] updated Signed-off-by: Robert Shaw --- vllm/v1/engine/core_client.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 13d0ee19d16bb..aea0e1e6558d5 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -430,17 +430,22 @@ class MPClient(EngineCoreClient): dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank dp_local_size = parallel_config.data_parallel_size_local - external_dp_lb = parallel_config.data_parallel_external_lb - offline_mode = parallel_config.data_parallel_rank_local is not None - engine_ranks = ([dp_rank] if (offline_mode or external_dp_lb) - else range(dp_rank, dp_rank + dp_local_size)) + + # If External DPLB, Client manages local EngineCores. + # If Internal DPLB, Client manages local+remote EngineCores. + num_ranks = (dp_local_size + if parallel_config.data_parallel_external_lb else + dp_size) + self.engine_ranks_managed = ([dp_rank] if offline_mode else range( + dp_rank, dp_rank + num_ranks)) assert parallel_config.data_parallel_size_local <= len( - engine_ranks) + self.engine_ranks_managed) # ZMQ identity of each engine that this client will talk to. self.core_engines: list[EngineIdentity] = [ - index.to_bytes(2, "little") for index in engine_ranks + rank.to_bytes(2, "little") + for rank in self.engine_ranks_managed ] # Wait for ready messages from each engine on the input socket. 
@@ -895,8 +900,6 @@ class DPAsyncMPClient(AsyncMPClient): return assert self.stats_update_address is not None - dp_start_rank = self.vllm_config.parallel_config.data_parallel_rank - dp_end_rank = dp_start_rank + self.vllm_config.parallel_config.data_parallel_size_local async def run_engine_stats_update_task(): with make_zmq_socket(self.ctx, self.stats_update_address, @@ -961,9 +964,9 @@ class DPAsyncMPClient(AsyncMPClient): counts, wave, running = msgspec.msgpack.decode(buf) self.current_wave = wave self.engines_running = running - # NOTE: counts includes num running for all global - # EngineCores, so need to slide for the local ones. - self.lb_engines = counts[dp_start_rank:dp_end_rank] + # NOTE: counts include all global Cores. Slice + # to get the Cores managed by this client. + self.lb_engines = counts[self.engine_ranks_managed] resources.stats_update_task = asyncio.create_task( run_engine_stats_update_task())