mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 02:05:01 +08:00
[P/D] Dynamic kv_output_aggregator collect size (#26734)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
58fab50d82
commit
4dfdb821c8
@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation():
|
||||
|
||||
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
|
||||
# done in MultiprocExecutor.execute_model
|
||||
aggregator = KVOutputAggregator(world_size=3)
|
||||
aggregator = KVOutputAggregator(expected_finished_count=3)
|
||||
|
||||
# Create stats for multiple workers with different transfer patterns
|
||||
worker1_stats = NixlKVConnectorStats()
|
||||
@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
KVOutputAggregator (used by MultiprocExecutor).
|
||||
"""
|
||||
|
||||
aggregator = KVOutputAggregator(world_size=3)
|
||||
aggregator = KVOutputAggregator(expected_finished_count=3)
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@ -16,11 +16,13 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
|
||||
finished_sending: set[str] | None = None,
|
||||
finished_recving: set[str] | None = None,
|
||||
invalid_block_ids: set[int] | None = None,
|
||||
expected_finished_count: int = 0,
|
||||
):
|
||||
self.kv_connector_output = KVConnectorOutput(
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
invalid_block_ids=invalid_block_ids or set(),
|
||||
expected_finished_count=expected_finished_count,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@ -33,7 +35,7 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
|
||||
|
||||
|
||||
def test_aggregate_workers_output():
|
||||
aggregator = KVOutputAggregator(world_size=2)
|
||||
aggregator = KVOutputAggregator(expected_finished_count=2)
|
||||
|
||||
output1 = DummyModelRunnerOutput()
|
||||
output2 = DummyModelRunnerOutput()
|
||||
@ -85,7 +87,7 @@ def test_aggregate_workers_output():
|
||||
|
||||
|
||||
def test_async_aggregate_workers_output():
|
||||
aggregator = KVOutputAggregator(world_size=2)
|
||||
aggregator = KVOutputAggregator(expected_finished_count=2)
|
||||
|
||||
future1: Future[DummyModelRunnerOutput] = Future()
|
||||
future2: Future[DummyModelRunnerOutput] = Future()
|
||||
@ -158,3 +160,40 @@ def test_async_aggregate_workers_output():
|
||||
assert aggregated.finished_sending is None
|
||||
assert aggregated.finished_recving == {"req2"}
|
||||
assert aggregated.invalid_block_ids == {3, 4, 5}
|
||||
|
||||
|
||||
def test_aggregate_workers_output_with_expected_finished_count():
|
||||
# We create the aggregator expecting to collect from 4 workers
|
||||
aggregator = KVOutputAggregator(expected_finished_count=4)
|
||||
assert aggregator._expected_finished_count == 4
|
||||
# Some request with default expected finished requests
|
||||
output1 = DummyModelRunnerOutput(finished_sending={"req1"})
|
||||
aggregated = aggregator.aggregate([output1])
|
||||
# still expecting to collect from 4 workers
|
||||
assert aggregator._send_remaining_count["req1"] == 3
|
||||
assert not aggregated.kv_connector_output.finished_sending
|
||||
assert not aggregated.kv_connector_output.finished_recving
|
||||
|
||||
# Workers discover and find that in this setup they only need to
|
||||
# collect from 2
|
||||
output1 = DummyModelRunnerOutput(
|
||||
finished_sending={"req1"}, expected_finished_count=2
|
||||
)
|
||||
output2 = DummyModelRunnerOutput(
|
||||
finished_recving={"req2"}, expected_finished_count=2
|
||||
)
|
||||
output3 = DummyModelRunnerOutput(finished_recving={"req2"})
|
||||
# Req2 only needs 2 acks
|
||||
aggregated = aggregator.aggregate([output1, output2, output3])
|
||||
assert aggregated.kv_connector_output.expected_finished_count == 2
|
||||
|
||||
assert not aggregated.kv_connector_output.finished_sending
|
||||
|
||||
# Req2 is finished
|
||||
assert "req2" not in aggregator._recv_remaining_count
|
||||
assert aggregated.kv_connector_output.finished_recving == {"req2"}
|
||||
|
||||
# Req1 is still waiting for 2 more acks (expected_finished_count has no effect)
|
||||
# NOTE: This is to showcase dynamic update. Workers are responsible for
|
||||
# ensuring "req1" termination in this case
|
||||
assert aggregator._send_remaining_count["req1"] == 2
|
||||
@ -4,10 +4,9 @@
|
||||
KV cache helper for store.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import CancelledError, Future
|
||||
from typing import Literal, cast
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
|
||||
import torch
|
||||
|
||||
@ -18,6 +17,9 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -124,11 +126,16 @@ class KVOutputAggregator:
|
||||
"""Utility class to aggregate the output of all workers into a single
|
||||
output corresponding to Rank 0 for scheduler."""
|
||||
|
||||
def __init__(self, world_size: int):
|
||||
def __init__(self, expected_finished_count: int):
|
||||
# Complete transfer tracker. Used to track finished requests
|
||||
# [req_id -> n_remaining_workers]
|
||||
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
|
||||
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
|
||||
self._recv_remaining_count = dict[str, int]()
|
||||
self._send_remaining_count = dict[str, int]()
|
||||
self._expected_finished_count = expected_finished_count
|
||||
|
||||
@classmethod
|
||||
def from_connector(cls, connector: "KVConnectorBase", world_size: int):
|
||||
return cls(connector.get_finished_count() or world_size)
|
||||
|
||||
def aggregate(
|
||||
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
|
||||
@ -141,7 +148,10 @@ class KVOutputAggregator:
|
||||
finished_set: set[str],
|
||||
) -> None:
|
||||
for req_id in req_ids or ():
|
||||
remaining_count_dict[req_id] -= 1
|
||||
remaining_count = remaining_count_dict.get(
|
||||
req_id, self._expected_finished_count
|
||||
)
|
||||
remaining_count_dict[req_id] = remaining_count - 1
|
||||
if remaining_count_dict[req_id] == 0:
|
||||
finished_set.add(req_id)
|
||||
del remaining_count_dict[req_id]
|
||||
@ -154,6 +164,19 @@ class KVOutputAggregator:
|
||||
kv_output = model_runner_output.kv_connector_output
|
||||
if not kv_output:
|
||||
continue
|
||||
# Allow the worker to dynamically update the expected number of
|
||||
# finished sending/recving for new requests.
|
||||
if (
|
||||
kv_output.expected_finished_count > 0
|
||||
and kv_output.expected_finished_count != self._expected_finished_count
|
||||
):
|
||||
logger.debug(
|
||||
"Expected finished requests updated from %d to %d",
|
||||
self._expected_finished_count,
|
||||
kv_output.expected_finished_count,
|
||||
)
|
||||
self._expected_finished_count = kv_output.expected_finished_count
|
||||
|
||||
update_finished_set(
|
||||
kv_output.finished_sending, self._send_remaining_count, finished_sending
|
||||
)
|
||||
@ -186,6 +209,7 @@ class KVOutputAggregator:
|
||||
finished_recving=finished_recving or None,
|
||||
kv_connector_stats=aggregated_kv_connector_stats or None,
|
||||
invalid_block_ids=invalid_block_ids,
|
||||
expected_finished_count=self._expected_finished_count,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@ -413,7 +413,8 @@ class KVConnectorBase_V1(ABC):
|
||||
def get_finished_count(self) -> int | None:
|
||||
"""
|
||||
Get the count of requests expected to complete send/receive operations
|
||||
via this connector.
|
||||
via this connector. This method is used to initialize the
|
||||
KVOutputAggregator, overwriting the default world_size.
|
||||
|
||||
Returns:
|
||||
int: expected sending or receiving completion count.
|
||||
|
||||
@ -160,9 +160,7 @@ class EngineCore:
|
||||
)
|
||||
self.use_spec_decode = vllm_config.speculative_config is not None
|
||||
if self.scheduler.connector is not None: # type: ignore
|
||||
self.model_executor.init_kv_output_aggregator(
|
||||
self.scheduler.connector.get_finished_count() # type: ignore
|
||||
)
|
||||
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
|
||||
|
||||
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
|
||||
self.mm_receiver_cache = engine_receiver_cache_from_config(
|
||||
|
||||
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from functools import cached_property
|
||||
from typing import Literal, TypeVar, overload
|
||||
from typing import TYPE_CHECKING, Literal, TypeVar, overload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
@ -19,6 +19,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R")
|
||||
@ -233,10 +236,10 @@ class Executor(ABC):
|
||||
"""Shutdown the executor."""
|
||||
self.collective_rpc("shutdown")
|
||||
|
||||
def init_kv_output_aggregator(self, finished_count: int | None) -> None:
|
||||
def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None:
|
||||
"""Init KVOutputAggregator"""
|
||||
self.kv_output_aggregator = KVOutputAggregator(
|
||||
finished_count or self.parallel_config.world_size
|
||||
self.kv_output_aggregator = KVOutputAggregator.from_connector(
|
||||
connector, self.parallel_config.world_size
|
||||
)
|
||||
|
||||
@cached_property # Avoid unnecessary RPC calls
|
||||
|
||||
@ -86,8 +86,14 @@ class KVConnectorOutput:
|
||||
finished_recving: set[str] | None = None
|
||||
kv_connector_stats: KVConnectorStats | None = None
|
||||
# IDs of externally computed KV blocks that failed to load.
|
||||
# Requests referencing these blocks should be rescheduled to recompute them.
|
||||
# Requests referencing these blocks should be rescheduled to recompute them
|
||||
invalid_block_ids: set[int] = field(default_factory=set)
|
||||
# Configuration describing how many finished sending/receiving
|
||||
# notifications should be expected for each request. This allows
|
||||
# handshake-based connectors like Nixl to update the KVOutputAggregator.
|
||||
# It captures a static setup info and should almost always remain constant
|
||||
# for a given connector after discovery. Default value entails no change.
|
||||
expected_finished_count: int = 0
|
||||
|
||||
def is_empty(self):
|
||||
return (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user