[P/D] Dynamic kv_output_aggregator collect size (#26734)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-10-22 18:07:58 +02:00 committed by GitHub
parent 58fab50d82
commit 4dfdb821c8
7 changed files with 90 additions and 19 deletions

View File

@@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation():
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# done in MultiprocExecutor.execute_model
aggregator = KVOutputAggregator(world_size=3)
aggregator = KVOutputAggregator(expected_finished_count=3)
# Create stats for multiple workers with different transfer patterns
worker1_stats = NixlKVConnectorStats()
@@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation():
KVOutputAggregator (used by MultiprocExecutor).
"""
aggregator = KVOutputAggregator(world_size=3)
aggregator = KVOutputAggregator(expected_finished_count=3)
from dataclasses import dataclass

View File

@@ -16,11 +16,13 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
finished_sending: set[str] | None = None,
finished_recving: set[str] | None = None,
invalid_block_ids: set[int] | None = None,
expected_finished_count: int = 0,
):
self.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
expected_finished_count=expected_finished_count,
)
def __repr__(self):
@@ -33,7 +35,7 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
def test_aggregate_workers_output():
aggregator = KVOutputAggregator(world_size=2)
aggregator = KVOutputAggregator(expected_finished_count=2)
output1 = DummyModelRunnerOutput()
output2 = DummyModelRunnerOutput()
@@ -85,7 +87,7 @@ def test_aggregate_workers_output():
def test_async_aggregate_workers_output():
aggregator = KVOutputAggregator(world_size=2)
aggregator = KVOutputAggregator(expected_finished_count=2)
future1: Future[DummyModelRunnerOutput] = Future()
future2: Future[DummyModelRunnerOutput] = Future()
@@ -158,3 +160,40 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {"req2"}
assert aggregated.invalid_block_ids == {3, 4, 5}
def test_aggregate_workers_output_with_expected_finished_count():
# We create the aggregator expecting to collect from 4 workers
aggregator = KVOutputAggregator(expected_finished_count=4)
assert aggregator._expected_finished_count == 4
# A request using the default expected finished count
output1 = DummyModelRunnerOutput(finished_sending={"req1"})
aggregated = aggregator.aggregate([output1])
# still expecting to collect from 4 workers
assert aggregator._send_remaining_count["req1"] == 3
assert not aggregated.kv_connector_output.finished_sending
assert not aggregated.kv_connector_output.finished_recving
# Workers discover that in this setup they only need to
# collect from 2 workers
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, expected_finished_count=2
)
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, expected_finished_count=2
)
output3 = DummyModelRunnerOutput(finished_recving={"req2"})
# Req2 only needs 2 acks
aggregated = aggregator.aggregate([output1, output2, output3])
assert aggregated.kv_connector_output.expected_finished_count == 2
assert not aggregated.kv_connector_output.finished_sending
# Req2 is finished
assert "req2" not in aggregator._recv_remaining_count
assert aggregated.kv_connector_output.finished_recving == {"req2"}
# Req1 is still waiting for 2 more acks (the expected_finished_count
# update does not apply retroactively to requests already being tracked)
# NOTE: This showcases the dynamic update. Workers are responsible for
# ensuring "req1" terminates in this case
assert aggregator._send_remaining_count["req1"] == 2
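To make the counter arithmetic in this test easier to follow, here is a minimal standalone sketch of the per-request bookkeeping it exercises. The `ack` helper and the plain dicts are illustrative stand-ins for the aggregator's internal `_send_remaining_count`/`_recv_remaining_count`, not the vLLM API itself:

```python
# Illustrative stand-in for the aggregator's per-request bookkeeping.
def ack(remaining: dict[str, int], expected: int, req_id: str) -> bool:
    """Count one worker ack for req_id; return True once all acks arrived."""
    left = remaining.get(req_id, expected) - 1
    if left == 0:
        remaining.pop(req_id, None)
        return True
    remaining[req_id] = left
    return False

send_remaining: dict[str, int] = {}
expected = 4                           # aggregator created for 4 workers
ack(send_remaining, expected, "req1")  # first ack: 3 left
expected = 2                           # a worker reports expected_finished_count=2
ack(send_remaining, expected, "req1")  # req1 keeps counting down from its old total
assert send_remaining["req1"] == 2

recv_remaining: dict[str, int] = {}    # req2 is first seen after the update
assert not ack(recv_remaining, expected, "req2")
assert ack(recv_remaining, expected, "req2")  # only 2 acks needed
```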

View File

@@ -4,10 +4,9 @@
KV cache helper for store.
"""
from collections import defaultdict
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import Literal, cast
from typing import TYPE_CHECKING, Literal, cast
import torch
@@ -18,6 +17,9 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
logger = init_logger(__name__)
@@ -124,11 +126,16 @@ class KVOutputAggregator:
"""Utility class to aggregate the output of all workers into a single
output corresponding to Rank 0 for scheduler."""
def __init__(self, world_size: int):
def __init__(self, expected_finished_count: int):
# Complete transfer tracker. Used to track finished requests
# [req_id -> n_remaining_workers]
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
self._recv_remaining_count = dict[str, int]()
self._send_remaining_count = dict[str, int]()
self._expected_finished_count = expected_finished_count
@classmethod
def from_connector(cls, connector: "KVConnectorBase", world_size: int):
return cls(connector.get_finished_count() or world_size)
def aggregate(
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
@@ -141,7 +148,10 @@
finished_set: set[str],
) -> None:
for req_id in req_ids or ():
remaining_count_dict[req_id] -= 1
remaining_count = remaining_count_dict.get(
req_id, self._expected_finished_count
)
remaining_count_dict[req_id] = remaining_count - 1
if remaining_count_dict[req_id] == 0:
finished_set.add(req_id)
del remaining_count_dict[req_id]
@@ -154,6 +164,19 @@
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
# Allow the worker to dynamically update the expected number of
# finished sending/recving for new requests.
if (
kv_output.expected_finished_count > 0
and kv_output.expected_finished_count != self._expected_finished_count
):
logger.debug(
"Expected finished requests updated from %d to %d",
self._expected_finished_count,
kv_output.expected_finished_count,
)
self._expected_finished_count = kv_output.expected_finished_count
update_finished_set(
kv_output.finished_sending, self._send_remaining_count, finished_sending
)
@@ -186,6 +209,7 @@
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count,
)
return output
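A note on the update rule introduced in `aggregate()`: a worker-reported value of 0 is treated as "no change", and any positive value replaces the aggregator's current expectation. A minimal sketch of that rule (the helper name is illustrative, not part of vLLM):

```python
def merge_expected_count(current: int, reported: int) -> int:
    """Mirror the dynamic update in aggregate(): 0 means "no change",
    any other reported value replaces the current expectation."""
    if reported > 0 and reported != current:
        return reported
    return current

assert merge_expected_count(4, 0) == 4  # default output: keep the initial value
assert merge_expected_count(4, 2) == 2  # handshake discovered only 2 acks are needed
assert merge_expected_count(2, 2) == 2  # repeated reports are idempotent
```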

View File

@@ -413,7 +413,8 @@ class KVConnectorBase_V1(ABC):
def get_finished_count(self) -> int | None:
"""
Get the count of requests expected to complete send/receive operations
via this connector.
via this connector. This method is used to initialize the
KVOutputAggregator, overriding the default world_size.
Returns:
int: expected sending or receiving completion count.
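For connector authors, the contract is: return None to keep the world_size-based default, or a positive count once the connector knows how many finished notifications to expect per request. A hedged sketch under that assumption (the class and its `remote_tp_size` attribute are illustrative, not an existing vLLM connector):

```python
class HandshakeConnector:
    """Toy connector that learns the required ack count during handshake."""

    def __init__(self) -> None:
        # Filled in once the remote side's configuration is discovered.
        self.remote_tp_size = None

    def get_finished_count(self) -> int | None:
        # None -> KVOutputAggregator keeps its world_size-based default.
        return self.remote_tp_size

connector = HandshakeConnector()
assert connector.get_finished_count() is None
connector.remote_tp_size = 2  # e.g. the remote prefill side runs TP=2
assert connector.get_finished_count() == 2
```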

View File

@@ -160,9 +160,7 @@ class EngineCore:
)
self.use_spec_decode = vllm_config.speculative_config is not None
if self.scheduler.connector is not None: # type: ignore
self.model_executor.init_kv_output_aggregator(
self.scheduler.connector.get_finished_count() # type: ignore
)
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = engine_receiver_cache_from_config(

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from concurrent.futures import Future
from functools import cached_property
from typing import Literal, TypeVar, overload
from typing import TYPE_CHECKING, Literal, TypeVar, overload
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
@@ -19,6 +19,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
logger = init_logger(__name__)
_R = TypeVar("_R")
@@ -233,10 +236,10 @@ class Executor(ABC):
"""Shutdown the executor."""
self.collective_rpc("shutdown")
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None:
"""Init KVOutputAggregator"""
self.kv_output_aggregator = KVOutputAggregator(
finished_count or self.parallel_config.world_size
self.kv_output_aggregator = KVOutputAggregator.from_connector(
connector, self.parallel_config.world_size
)
@cached_property # Avoid unnecessary RPC calls
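The executor now just hands the connector to `KVOutputAggregator.from_connector`, which applies the same fallback as before: use `connector.get_finished_count()` when it returns a value, otherwise `world_size`. Assuming a vLLM development environment with this commit applied, the resolution can be checked directly (`StubConnector` is an illustrative stand-in for a real connector):

```python
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator

class StubConnector:
    """Illustrative stand-in for a KVConnectorBase implementation."""

    def get_finished_count(self) -> int | None:
        return 2  # e.g. discovered during handshake

aggregator = KVOutputAggregator.from_connector(StubConnector(), world_size=4)
assert aggregator._expected_finished_count == 2  # connector value wins over world_size
```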

View File

@@ -86,8 +86,14 @@ class KVConnectorOutput:
finished_recving: set[str] | None = None
kv_connector_stats: KVConnectorStats | None = None
# IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them.
# Requests referencing these blocks should be rescheduled to recompute them
invalid_block_ids: set[int] = field(default_factory=set)
# Configuration describing how many finished sending/receiving
# notifications should be expected for each request. This allows
# handshake-based connectors like Nixl to update the KVOutputAggregator.
# It captures static setup info and should almost always remain constant
# for a given connector after discovery. The default value (0) means no change.
expected_finished_count: int = 0
def is_empty(self):
return (
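Assuming a vLLM development environment with this change applied, a worker that discovers during handshake that only two notifications are needed per request could report it alongside its other results, for example:

```python
from vllm.v1.outputs import KVConnectorOutput

out = KVConnectorOutput(
    finished_sending={"req1"},
    expected_finished_count=2,  # 0 (the default) means "no change" to the aggregator
)
print(out.expected_finished_count)  # 2
```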