diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 7e0b927c5b78..70c07eac6304 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -355,3 +355,14 @@ class KVConnectorBase_V1(ABC):
             raise TypeError("get_required_kvcache_layout should not be called "
                             "on the abstract base class")
         return None
+
+    def get_finished_count(self) -> Optional[int]:
+        """
+        Get the count of requests expected to complete send/receive operations
+        via this connector.
+
+        Returns:
+            int: expected sending or receiving completion count.
+        """
+
+        return None
\ No newline at end of file
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index a3c1d79a58b2..d18bef1256af 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -13,6 +13,7 @@ from typing_extensions import TypeVar
 
 import vllm.platforms
 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -54,6 +55,7 @@ class ExecutorBase(ABC):
         self._init_executor()
         self.is_sleeping = False
         self.sleeping_tags: set[str] = set()
+        self.kv_output_aggregator = None
 
     @abstractmethod
     def _init_executor(self) -> None:
@@ -252,6 +254,11 @@ class ExecutorBase(ABC):
         exception."""
         self.check_health()
 
+    def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None:
+        """Init KVOutputAggregator"""
+        self.kv_output_aggregator = KVOutputAggregator(
+            finished_count or self.parallel_config.world_size)
+
 
 class DistributedExecutorBase(ExecutorBase):
     """Abstract superclass of distributed executor implementations."""
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 64a67f3b438e..a022e9c0d705 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -128,6 +128,9 @@ class EngineCore:
             log_stats=self.log_stats,
         )
         self.use_spec_decode = vllm_config.speculative_config is not None
+        if self.scheduler.connector is not None:  # type: ignore
+            self.model_executor.init_kv_output_aggregator(
+                self.scheduler.connector.get_finished_count())  # type: ignore
 
         self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
         self.mm_receiver_cache = engine_receiver_cache_from_config(
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index f566c9aee0c5..3aa373f12b60 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -26,7 +26,6 @@ from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                  MessageQueue)
-from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_pp_group, get_tp_group)
 from vllm.executor.multiproc_worker_utils import (
@@ -135,8 +134,6 @@ class MultiprocExecutor(Executor):
 
         self.output_rank = self._get_output_rank()
         self.has_connector = self.vllm_config.kv_transfer_config is not None
-        self.kv_output_aggregator = KVOutputAggregator(
-            self.parallel_config.world_size)
 
     def start_worker_monitor(self):
         workers = self.workers
diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py
index 59c9b56625a9..aadb5fd1dddd 100644
--- a/vllm/v1/executor/ray_distributed_executor.py
+++ b/vllm/v1/executor/ray_distributed_executor.py
@@ -51,8 +51,6 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
 
         # KV connector setup
         self.has_connector = self.vllm_config.kv_transfer_config is not None
-        self.kv_output_aggregator = KVOutputAggregator(
-            self.parallel_config.world_size)
 
     @property
     def max_concurrent_batches(self) -> int:
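
Taken together, the patch replaces the per-executor construction of KVOutputAggregator with a hook: a KV connector may override get_finished_count() to report how many worker completions the aggregator should expect, EngineCore forwards that value to the executor, and ExecutorBase.init_kv_output_aggregator() falls back to parallel_config.world_size when the connector returns None. Below is a minimal, self-contained sketch of that sizing handshake; the Fake* stand-in classes, the example connectors, and the numbers (world_size=8, a connector reporting 2) are illustrative assumptions, not vLLM code.

# Minimal sketch of the new sizing handshake (stand-ins, not vLLM internals).
from typing import Optional


class FakeAggregator:
    """Stand-in for KVOutputAggregator: collects `expected` worker outputs."""

    def __init__(self, expected: int) -> None:
        self.expected = expected


class FakeExecutor:
    """Stand-in for ExecutorBase with the new init_kv_output_aggregator()."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        self.kv_output_aggregator: Optional[FakeAggregator] = None

    def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None:
        # Same fallback as the patch: prefer the connector-provided count,
        # otherwise aggregate over the full worker world size.
        self.kv_output_aggregator = FakeAggregator(
            finished_count or self.world_size)


class DefaultConnector:
    """Connector that keeps the base-class behaviour (returns None)."""

    def get_finished_count(self) -> Optional[int]:
        return None


class TwoRankConnector:
    """Hypothetical connector that only expects completions from two ranks."""

    def get_finished_count(self) -> Optional[int]:
        return 2


# What EngineCore now does when a KV connector is configured:
executor = FakeExecutor(world_size=8)
executor.init_kv_output_aggregator(DefaultConnector().get_finished_count())
assert executor.kv_output_aggregator.expected == 8  # fallback to world size

executor.init_kv_output_aggregator(TwoRankConnector().get_finished_count())
assert executor.kv_output_aggregator.expected == 2  # connector override wins

The `or` fallback in the sketch mirrors the patch: a connector that returns None (the base-class default added in base.py) keeps the previous behaviour of aggregating across the full worker world size, so existing connectors are unaffected.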