From 23027e2daff4e0ef81baf4403b3d9eb452491b38 Mon Sep 17 00:00:00 2001 From: CYJiang <86391540+googs1025@users.noreply.github.com> Date: Thu, 5 Jun 2025 06:37:25 +0800 Subject: [PATCH] [Misc] refactor: simplify EngineCoreClient.make_async_mp_client in AsyncLLM (#18817) Signed-off-by: googs1025 --- vllm/v1/engine/async_llm.py | 12 ++---------- vllm/v1/engine/core_client.py | 25 +++++++++++++++++++------ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 61ea3c4c3dab4..089f15aee5b04 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -28,8 +28,7 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext from vllm.utils import Device, cdiv from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient, - RayDPClient) +from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.output_processor import (OutputProcessor, RequestOutputCollector) @@ -121,15 +120,8 @@ class AsyncLLM(EngineClient): log_stats=self.log_stats) # EngineCore (starts the engine in background process). - core_client_class: type[AsyncMPClient] - if vllm_config.parallel_config.data_parallel_size == 1: - core_client_class = AsyncMPClient - elif vllm_config.parallel_config.data_parallel_backend == "ray": - core_client_class = RayDPClient - else: - core_client_class = DPAsyncMPClient - self.engine_core = core_client_class( + self.engine_core = EngineCoreClient.make_async_mp_client( vllm_config=vllm_config, executor_class=executor_class, log_stats=self.log_stats, diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 0cd58d01df7f7..d1b0b300dccb5 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -68,18 +68,31 @@ class EngineCoreClient(ABC): "is not currently supported.") if multiprocess_mode and asyncio_mode: - if vllm_config.parallel_config.data_parallel_size > 1: - if vllm_config.parallel_config.data_parallel_backend == "ray": - return RayDPClient(vllm_config, executor_class, log_stats) - return DPAsyncMPClient(vllm_config, executor_class, log_stats) - - return AsyncMPClient(vllm_config, executor_class, log_stats) + return EngineCoreClient.make_async_mp_client( + vllm_config, executor_class, log_stats) if multiprocess_mode and not asyncio_mode: return SyncMPClient(vllm_config, executor_class, log_stats) return InprocClient(vllm_config, executor_class, log_stats) + @staticmethod + def make_async_mp_client( + vllm_config: VllmConfig, + executor_class: type[Executor], + log_stats: bool, + client_addresses: Optional[dict[str, str]] = None, + client_index: int = 0, + ) -> "MPClient": + if vllm_config.parallel_config.data_parallel_size > 1: + if vllm_config.parallel_config.data_parallel_backend == "ray": + return RayDPClient(vllm_config, executor_class, log_stats, + client_addresses, client_index) + return DPAsyncMPClient(vllm_config, executor_class, log_stats, + client_addresses, client_index) + return AsyncMPClient(vllm_config, executor_class, log_stats, + client_addresses, client_index) + @abstractmethod def shutdown(self): ...