[V1] [P/D] Refactor KV Connector Path (#21980)

Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
This commit is contained in:
David Ben-David 2025-08-03 14:03:40 +03:00 committed by GitHub
parent 24d1dffbeb
commit aefeea0fde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 142 additions and 80 deletions

View File

@ -4,7 +4,7 @@ from concurrent.futures import Future
from typing import Optional
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
class DummyModelRunnerOutput(ModelRunnerOutput):
@ -12,8 +12,16 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
def __init__(self,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None):
self.finished_sending = finished_sending
self.finished_recving = finished_recving
self.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
def __repr__(self):
return (
f"DummyModelRunnerOutput("
f"finished_sending={self.kv_connector_output.finished_sending},"
f"finished_recving={self.kv_connector_output.finished_recving})")
def test_aggregate_workers_output():
@ -27,6 +35,7 @@ def test_aggregate_workers_output():
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
@ -38,6 +47,7 @@ def test_aggregate_workers_output():
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_recving is None
@ -49,6 +59,7 @@ def test_aggregate_workers_output():
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}
@ -70,6 +81,7 @@ def test_async_aggregate_workers_output():
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
@ -87,6 +99,7 @@ def test_async_aggregate_workers_output():
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_recving is None
@ -104,5 +117,6 @@ def test_async_aggregate_workers_output():
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (assert_scheduler_empty, create_model_runner_output,
@ -86,7 +86,8 @@ def test_basic_lifecycle():
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_sending = [request_id]
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_id])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
@ -176,7 +177,8 @@ def test_prefix_cache_lifecycle():
scheduler_output = scheduler.schedule()
scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_sending = [request_remote.request_id]
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_remote.request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (assert_scheduler_empty, create_model_runner_output,
@ -72,7 +72,8 @@ def test_basic_lifecycle():
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_recving = [request_id]
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving=[request_id])
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(scheduler_output,
@ -309,7 +310,8 @@ def test_full_block_prompt():
# # STEP (2): Recv.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.finished_recving = [request_id]
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving=[request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)

View File

@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
@ -188,8 +188,10 @@ def create_model_runner_output(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
finished_sending=finished_sending,
finished_recving=finished_recving,
kv_connector_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
),
)

View File

@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
logger = init_logger(__name__)
@ -129,7 +129,7 @@ class KVOutputAggregator:
def aggregate(self,
outputs: list[ModelRunnerOutput],
output_rank: int = 0) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers
# aggregate kv_connector_output from all workers
def update_finished_set(req_ids: Optional[set[str]],
remaining_count_dict: dict[str, int],
@ -143,6 +143,7 @@ class KVOutputAggregator:
finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
output = output.kv_connector_output
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,
@ -151,13 +152,10 @@ class KVOutputAggregator:
# select output of the worker specified by output_rank
output = outputs[output_rank]
# set the aggregated finished_sending / finished_recving
# if output.finished_sending/recving is not empty, but the other ranks
# still have unfinished send/recv, we want to set the aggregated
# finished_sending/recving to None until all ranks have finished
# send/recv
output.finished_sending = finished_sending if finished_sending else None
output.finished_recving = finished_recving if finished_recving else None
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
)
return output

View File

@ -10,7 +10,7 @@ from collections.abc import Mapping
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass, field
from functools import reduce
from typing import Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import msgspec
import torch
@ -21,6 +21,10 @@ from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
if TYPE_CHECKING:
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorOutput)
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
VLLM_INVALID_TOKEN_ID = -1
@ -1159,14 +1163,11 @@ class IntermediateTensors:
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
Each stage also needs to handle its own finished_sending and
finished_recving in case of kv transfer.
Each stage also needs to handle its own kv_connector_output.
"""
tensors: dict[str, torch.Tensor]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
kv_connector_output: Optional["KVConnectorOutput"]
def __init__(self, tensors):
# manually define this function, so that

View File

@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
@ -884,7 +884,9 @@ class Scheduler(SchedulerInterface):
self.waiting.remove_requests(stopped_preempted_reqs)
# KV Connector: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)
if model_runner_output.kv_connector_output:
self._update_from_kv_xfer_finished(
model_runner_output.kv_connector_output)
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
@ -1128,7 +1130,7 @@ class Scheduler(SchedulerInterface):
return True
def _update_from_kv_xfer_finished(self,
model_runner_output: ModelRunnerOutput):
kv_connector_output: KVConnectorOutput):
"""
KV Connector: update the scheduler state based on the output.
@ -1139,9 +1141,9 @@ class Scheduler(SchedulerInterface):
scheduler the request during the next step.
"""
# KV Connector:: update recv and send status from last step.
for req_id in (model_runner_output.finished_recving or ()):
for req_id in (kv_connector_output.finished_recving or ()):
logger.debug("Finished recving KV transfer for request %s", req_id)
self.finished_recving_kv_req_ids.add(req_id)
for req_id in (model_runner_output.finished_sending or ()):
for req_id in (kv_connector_output.finished_sending or ()):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])

View File

@ -71,6 +71,13 @@ class SamplerOutput:
logprobs_tensors: Optional[LogprobsTensors]
@dataclass
class KVConnectorOutput:
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
# ModelRunnerOutput is serialized and sent to the scheduler process.
# This is expensive for torch.Tensor so prefer to use list instead.
@dataclass
@ -104,9 +111,7 @@ class ModelRunnerOutput:
# [num_reqs, hidden_size]
pooler_output: list[Optional[torch.Tensor]]
# [req_ids]
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
kv_connector_output: Optional[KVConnectorOutput] = None
# req_id -> num_nans_in_logits
num_nans_in_logits: Optional[dict[str, int]] = None
@ -119,6 +124,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
finished_sending=None,
finished_recving=None,
num_nans_in_logits=None)

View File

@ -69,7 +69,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager
@ -1423,8 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
finished_sending: Optional[set[str]],
finished_recving: Optional[set[str]],
kv_connector_output: Optional[KVConnectorOutput],
) -> ModelRunnerOutput:
assert self.input_batch.num_reqs ==\
len(self.input_batch.pooling_params), \
@ -1459,8 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
finished_sending=finished_sending,
finished_recving=finished_recving,
kv_connector_output=kv_connector_output,
)
@torch.inference_mode()
@ -1564,8 +1562,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs,
):
self.maybe_setup_kv_connector(scheduler_output)
), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output:
model_output = self.model(
input_ids=input_ids,
@ -1578,10 +1576,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
),
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
@ -1597,20 +1591,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
== "external_launcher" and len(get_pp_group().ranks) > 0
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
if finished_sending or finished_recving:
hidden_states.finished_sending = finished_sending
hidden_states.finished_recving = finished_recving
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
if not broadcast_pp_output:
hidden_states.kv_connector_output = kv_connector_output
return hidden_states
get_pp_group().send_tensor_dict(hidden_states.tensors,
all_gather_group=get_tp_group())
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving)
num_scheduled_tokens_np, kv_connector_output)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
@ -1760,8 +1751,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
)

View File

@ -16,8 +16,7 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -369,17 +368,20 @@ class Worker(WorkerBase):
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
if not has_kv_transfer_group():
kv_connector_output = output.kv_connector_output
if not kv_connector_output:
return None
# In case of PP with kv transfer, we need to pass through the
# finished_sending and finished_recving buffers.
new_output = EMPTY_MODEL_RUNNER_OUTPUT
if output.finished_sending or output.finished_recving:
new_output = copy.copy(new_output)
new_output.finished_sending = output.finished_sending
new_output.finished_recving = output.finished_recving
output = new_output
# kv_connector_output
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
assert isinstance(output, ModelRunnerOutput)
return output

View File

@ -4,6 +4,8 @@
Define KV connector functionality mixin for model runners.
"""
import copy
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import Generator # noqa: UP035
from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
@ -12,7 +14,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
ModelRunnerOutput)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@ -53,18 +56,60 @@ class KVConnectorModelRunnerMixin:
scheduler_output.finished_req_ids)
return None, None
def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
@staticmethod
def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
vllm_config: VllmConfig) -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
with set_forward_context(
None, vllm_config
), KVConnectorModelRunnerMixin._get_kv_connector_output(
scheduler_output, wait_for_save=False) as kv_connector_output:
pass
if not finished_sending and not finished_recving:
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
output.kv_connector_output = kv_connector_output
return output
@staticmethod
def maybe_get_kv_connector_output(
scheduler_output: "SchedulerOutput"
) -> AbstractContextManager[Optional[KVConnectorOutput]]:
return KVConnectorModelRunnerMixin._get_kv_connector_output(
scheduler_output) if has_kv_transfer_group() else nullcontext()
# This context manager must be used within an active forward context.
# It encapsulates the entire KV conector lifecycle within execute_model
@staticmethod
@contextmanager
def _get_kv_connector_output(
scheduler_output: "SchedulerOutput",
wait_for_save: bool = True
) -> Generator[KVConnectorOutput, None, None]:
output = KVConnectorOutput()
# Update KVConnector with the KVConnector metadata forward().
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
try:
yield output
finally:
if wait_for_save:
kv_connector.wait_for_save()
output.finished_sending, output.finished_recving = (
kv_connector.get_finished(scheduler_output.finished_req_ids))
kv_connector.clear_connector_metadata()

View File

@ -51,7 +51,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin)
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
@ -1175,9 +1175,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
)
kv_connector_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
))
# Check there are no new graphs compiled - all the graphs should be
# captured and compiled during warm up.