[P/D] Dynamic kv_output_aggregator collect size (#26734)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-10-22 18:07:58 +02:00 committed by GitHub
parent 58fab50d82
commit 4dfdb821c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 90 additions and 19 deletions

View File

@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation():
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing # Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# done in MultiprocExecutor.execute_model # done in MultiprocExecutor.execute_model
aggregator = KVOutputAggregator(world_size=3) aggregator = KVOutputAggregator(expected_finished_count=3)
# Create stats for multiple workers with different transfer patterns # Create stats for multiple workers with different transfer patterns
worker1_stats = NixlKVConnectorStats() worker1_stats = NixlKVConnectorStats()
@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation():
KVOutputAggregator (used by MultiprocExecutor). KVOutputAggregator (used by MultiprocExecutor).
""" """
aggregator = KVOutputAggregator(world_size=3) aggregator = KVOutputAggregator(expected_finished_count=3)
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -16,11 +16,13 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
finished_sending: set[str] | None = None, finished_sending: set[str] | None = None,
finished_recving: set[str] | None = None, finished_recving: set[str] | None = None,
invalid_block_ids: set[int] | None = None, invalid_block_ids: set[int] | None = None,
expected_finished_count: int = 0,
): ):
self.kv_connector_output = KVConnectorOutput( self.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(), invalid_block_ids=invalid_block_ids or set(),
expected_finished_count=expected_finished_count,
) )
def __repr__(self): def __repr__(self):
@ -33,7 +35,7 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
def test_aggregate_workers_output(): def test_aggregate_workers_output():
aggregator = KVOutputAggregator(world_size=2) aggregator = KVOutputAggregator(expected_finished_count=2)
output1 = DummyModelRunnerOutput() output1 = DummyModelRunnerOutput()
output2 = DummyModelRunnerOutput() output2 = DummyModelRunnerOutput()
@ -85,7 +87,7 @@ def test_aggregate_workers_output():
def test_async_aggregate_workers_output(): def test_async_aggregate_workers_output():
aggregator = KVOutputAggregator(world_size=2) aggregator = KVOutputAggregator(expected_finished_count=2)
future1: Future[DummyModelRunnerOutput] = Future() future1: Future[DummyModelRunnerOutput] = Future()
future2: Future[DummyModelRunnerOutput] = Future() future2: Future[DummyModelRunnerOutput] = Future()
@ -158,3 +160,40 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_sending is None assert aggregated.finished_sending is None
assert aggregated.finished_recving == {"req2"} assert aggregated.finished_recving == {"req2"}
assert aggregated.invalid_block_ids == {3, 4, 5} assert aggregated.invalid_block_ids == {3, 4, 5}
def test_aggregate_workers_output_with_expected_finished_count():
# We create the aggregator expecting to collect from 4 workers
aggregator = KVOutputAggregator(expected_finished_count=4)
assert aggregator._expected_finished_count == 4
# Some request with default expected finished requests
output1 = DummyModelRunnerOutput(finished_sending={"req1"})
aggregated = aggregator.aggregate([output1])
# still expecting to collect from 4 workers
assert aggregator._send_remaining_count["req1"] == 3
assert not aggregated.kv_connector_output.finished_sending
assert not aggregated.kv_connector_output.finished_recving
# Workers discover and find that in this setup they only need to
# collect from 2
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, expected_finished_count=2
)
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, expected_finished_count=2
)
output3 = DummyModelRunnerOutput(finished_recving={"req2"})
# Req2 only needs 2 acks
aggregated = aggregator.aggregate([output1, output2, output3])
assert aggregated.kv_connector_output.expected_finished_count == 2
assert not aggregated.kv_connector_output.finished_sending
# Req2 is finished
assert "req2" not in aggregator._recv_remaining_count
assert aggregated.kv_connector_output.finished_recving == {"req2"}
# Req1 is still waiting for 2 more acks (expected_finished_count has no effect)
# NOTE: This is to showcase dynamic update. Workers are responsible for
# ensuring "req1" termination in this case
assert aggregator._send_remaining_count["req1"] == 2

View File

@ -4,10 +4,9 @@
KV cache helper for store. KV cache helper for store.
""" """
from collections import defaultdict
from collections.abc import Sequence from collections.abc import Sequence
from concurrent.futures import CancelledError, Future from concurrent.futures import CancelledError, Future
from typing import Literal, cast from typing import TYPE_CHECKING, Literal, cast
import torch import torch
@ -18,6 +17,9 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
logger = init_logger(__name__) logger = init_logger(__name__)
@ -124,11 +126,16 @@ class KVOutputAggregator:
"""Utility class to aggregate the output of all workers into a single """Utility class to aggregate the output of all workers into a single
output corresponding to Rank 0 for scheduler.""" output corresponding to Rank 0 for scheduler."""
def __init__(self, world_size: int): def __init__(self, expected_finished_count: int):
# Complete transfer tracker. Used to track finished requests # Complete transfer tracker. Used to track finished requests
# [req_id -> n_remaining_workers] # [req_id -> n_remaining_workers]
self._recv_remaining_count = defaultdict[str, int](lambda: world_size) self._recv_remaining_count = dict[str, int]()
self._send_remaining_count = defaultdict[str, int](lambda: world_size) self._send_remaining_count = dict[str, int]()
self._expected_finished_count = expected_finished_count
@classmethod
def from_connector(cls, connector: "KVConnectorBase", world_size: int):
return cls(connector.get_finished_count() or world_size)
def aggregate( def aggregate(
self, outputs: list[ModelRunnerOutput], output_rank: int = 0 self, outputs: list[ModelRunnerOutput], output_rank: int = 0
@ -141,7 +148,10 @@ class KVOutputAggregator:
finished_set: set[str], finished_set: set[str],
) -> None: ) -> None:
for req_id in req_ids or (): for req_id in req_ids or ():
remaining_count_dict[req_id] -= 1 remaining_count = remaining_count_dict.get(
req_id, self._expected_finished_count
)
remaining_count_dict[req_id] = remaining_count - 1
if remaining_count_dict[req_id] == 0: if remaining_count_dict[req_id] == 0:
finished_set.add(req_id) finished_set.add(req_id)
del remaining_count_dict[req_id] del remaining_count_dict[req_id]
@ -154,6 +164,19 @@ class KVOutputAggregator:
kv_output = model_runner_output.kv_connector_output kv_output = model_runner_output.kv_connector_output
if not kv_output: if not kv_output:
continue continue
# Allow the worker to dynamically update the expected number of
# finished sending/recving for new requests.
if (
kv_output.expected_finished_count > 0
and kv_output.expected_finished_count != self._expected_finished_count
):
logger.debug(
"Expected finished requests updated from %d to %d",
self._expected_finished_count,
kv_output.expected_finished_count,
)
self._expected_finished_count = kv_output.expected_finished_count
update_finished_set( update_finished_set(
kv_output.finished_sending, self._send_remaining_count, finished_sending kv_output.finished_sending, self._send_remaining_count, finished_sending
) )
@ -186,6 +209,7 @@ class KVOutputAggregator:
finished_recving=finished_recving or None, finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None, kv_connector_stats=aggregated_kv_connector_stats or None,
invalid_block_ids=invalid_block_ids, invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count,
) )
return output return output

View File

@ -413,7 +413,8 @@ class KVConnectorBase_V1(ABC):
def get_finished_count(self) -> int | None: def get_finished_count(self) -> int | None:
""" """
Get the count of requests expected to complete send/receive operations Get the count of requests expected to complete send/receive operations
via this connector. via this connector. This method is used to initialize the
KVOutputAggregator, overwriting the default world_size.
Returns: Returns:
int: expected sending or receiving completion count. int: expected sending or receiving completion count.

View File

@ -160,9 +160,7 @@ class EngineCore:
) )
self.use_spec_decode = vllm_config.speculative_config is not None self.use_spec_decode = vllm_config.speculative_config is not None
if self.scheduler.connector is not None: # type: ignore if self.scheduler.connector is not None: # type: ignore
self.model_executor.init_kv_output_aggregator( self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
self.scheduler.connector.get_finished_count() # type: ignore
)
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = engine_receiver_cache_from_config( self.mm_receiver_cache = engine_receiver_cache_from_config(

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import Future from concurrent.futures import Future
from functools import cached_property from functools import cached_property
from typing import Literal, TypeVar, overload from typing import TYPE_CHECKING, Literal, TypeVar, overload
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
@ -19,6 +19,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
logger = init_logger(__name__) logger = init_logger(__name__)
_R = TypeVar("_R") _R = TypeVar("_R")
@ -233,10 +236,10 @@ class Executor(ABC):
"""Shutdown the executor.""" """Shutdown the executor."""
self.collective_rpc("shutdown") self.collective_rpc("shutdown")
def init_kv_output_aggregator(self, finished_count: int | None) -> None: def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None:
"""Init KVOutputAggregator""" """Init KVOutputAggregator"""
self.kv_output_aggregator = KVOutputAggregator( self.kv_output_aggregator = KVOutputAggregator.from_connector(
finished_count or self.parallel_config.world_size connector, self.parallel_config.world_size
) )
@cached_property # Avoid unnecessary RPC calls @cached_property # Avoid unnecessary RPC calls

View File

@ -86,8 +86,14 @@ class KVConnectorOutput:
finished_recving: set[str] | None = None finished_recving: set[str] | None = None
kv_connector_stats: KVConnectorStats | None = None kv_connector_stats: KVConnectorStats | None = None
# IDs of externally computed KV blocks that failed to load. # IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them. # Requests referencing these blocks should be rescheduled to recompute them
invalid_block_ids: set[int] = field(default_factory=set) invalid_block_ids: set[int] = field(default_factory=set)
# Configuration describing how many finished sending/receiving
# notifications should be expected for each request. This allows
# handshake-based connectors like Nixl to update the KVOutputAggregator.
# It captures a static setup info and should almost always remain constant
# for a given connector after discovery. Default value entails no change.
expected_finished_count: int = 0
def is_empty(self): def is_empty(self):
return ( return (