Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 01:45:02 +08:00)
[P/D] kv_output_aggregator: support prefill TP > decode TP (#23917)

Signed-off-by: LCAIZJ <leichao139636@163.com>
Co-authored-by: leichao.lc <leichao.lc@antgroup.com>
This commit is contained in:
parent a0d8b9738d
commit 8de261b04a
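Summary (inferred from the diff below): in a disaggregated prefill/decode (P/D) deployment, the prefill instance's tensor-parallel size can exceed the decode instance's, so the number of workers expected to report KV send/receive completion no longer matches the executor's world size. The KV connector can now report that expected count through a new get_finished_count() hook, and the KVOutputAggregator is sized from it, falling back to world_size when the connector returns None.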
@@ -355,3 +355,14 @@ class KVConnectorBase_V1(ABC):
         raise TypeError("get_required_kvcache_layout should not be called "
                         "on the abstract base class")
         return None
+
+    def get_finished_count(self) -> Optional[int]:
+        """
+        Get the count of requests expected to complete send/receive operations
+        via this connector.
+
+        Returns:
+            int: expected sending or receiving completion count.
+        """
+
+        return None
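A minimal sketch of how a concrete connector might implement the new hook. The subclass name ToyPDConnector and the hard-coded count are illustrative only (they are not part of this commit), and the other abstract methods of KVConnectorBase_V1 are omitted for brevity:

from typing import Optional

from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1


class ToyPDConnector(KVConnectorBase_V1):
    """Hypothetical P/D connector where prefill TP exceeds decode TP."""

    # Illustrative: a real connector would derive this from its
    # kv_transfer_config rather than hard-coding a value.
    EXPECTED_FINISHED_COUNT = 2

    def get_finished_count(self) -> Optional[int]:
        # Report how many workers are expected to signal send/receive
        # completion per request, so the aggregator is not sized to the
        # executor's full world size.
        return self.EXPECTED_FINISHED_COUNT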
@@ -13,6 +13,7 @@ from typing_extensions import TypeVar
 
 import vllm.platforms
 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -54,6 +55,7 @@ class ExecutorBase(ABC):
         self._init_executor()
         self.is_sleeping = False
         self.sleeping_tags: set[str] = set()
+        self.kv_output_aggregator = None
 
     @abstractmethod
     def _init_executor(self) -> None:
@@ -252,6 +254,11 @@ class ExecutorBase(ABC):
         exception."""
         self.check_health()
 
+    def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None:
+        """Init KVOutputAggregator"""
+        self.kv_output_aggregator = KVOutputAggregator(
+            finished_count or self.parallel_config.world_size)
+
 
 class DistributedExecutorBase(ExecutorBase):
     """Abstract superclass of distributed executor implementations."""
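A small self-contained illustration of the sizing fallback in init_kv_output_aggregator; connector_count and world_size are stand-ins for the connector's get_finished_count() result and self.parallel_config.world_size:

from typing import Optional


def expected_reports(connector_count: Optional[int], world_size: int) -> int:
    # Mirrors `finished_count or self.parallel_config.world_size`:
    # when the connector gives no hint (None), keep the previous
    # behavior of waiting for every worker in the executor.
    return connector_count or world_size


assert expected_reports(None, 8) == 8  # no connector hint -> world size
assert expected_reports(2, 8) == 2     # heterogeneous-TP hint from connector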
@@ -128,6 +128,9 @@ class EngineCore:
             log_stats=self.log_stats,
         )
         self.use_spec_decode = vllm_config.speculative_config is not None
+        if self.scheduler.connector is not None:  # type: ignore
+            self.model_executor.init_kv_output_aggregator(
+                self.scheduler.connector.get_finished_count())  # type: ignore
 
         self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
         self.mm_receiver_cache = engine_receiver_cache_from_config(
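The call order in this hunk can be exercised in isolation with stand-in objects; _StubConnector and _StubExecutor below are hypothetical and only mimic the two methods the hunk touches:

from typing import Optional


class _StubConnector:
    """Stand-in for the scheduler's KV connector."""

    def get_finished_count(self) -> Optional[int]:
        return 2


class _StubExecutor:
    """Stand-in for the model executor."""

    kv_output_aggregator = None

    def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None:
        # The real executor builds a KVOutputAggregator here; record the
        # count so the wiring can be checked without vLLM installed.
        self.kv_output_aggregator = finished_count


connector = _StubConnector()
executor = _StubExecutor()

# Mirrors EngineCore.__init__: initialize the aggregator only when a
# connector is configured, and let the connector choose the count.
if connector is not None:
    executor.init_kv_output_aggregator(connector.get_finished_count())

assert executor.kv_output_aggregator == 2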
@@ -26,7 +26,6 @@ from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                   MessageQueue)
-from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_pp_group, get_tp_group)
 from vllm.executor.multiproc_worker_utils import (
@@ -135,8 +134,6 @@ class MultiprocExecutor(Executor):
 
         self.output_rank = self._get_output_rank()
         self.has_connector = self.vllm_config.kv_transfer_config is not None
-        self.kv_output_aggregator = KVOutputAggregator(
-            self.parallel_config.world_size)
 
     def start_worker_monitor(self):
         workers = self.workers
@@ -51,8 +51,6 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
 
         # KV connector setup
         self.has_connector = self.vllm_config.kv_transfer_config is not None
-        self.kv_output_aggregator = KVOutputAggregator(
-            self.parallel_config.world_size)
 
     @property
     def max_concurrent_batches(self) -> int:
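Design note: with the connector-driven initialization above, MultiprocExecutor and RayDistributedExecutor no longer build a KVOutputAggregator unconditionally in __init__; the aggregator is created by ExecutorBase.init_kv_output_aggregator, which EngineCore invokes only when a KV connector is configured.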