[V1] [P/D] Refactor KV Connector Path (#21980)

Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
Authored by David Ben-David on 2025-08-03 14:03:40 +03:00; committed by GitHub
parent 24d1dffbeb
commit aefeea0fde
12 changed files with 142 additions and 80 deletions

View File

@@ -4,7 +4,7 @@ from concurrent.futures import Future
 from typing import Optional

 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput


 class DummyModelRunnerOutput(ModelRunnerOutput):
@@ -12,8 +12,16 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
     def __init__(self,
                  finished_sending: Optional[set[str]] = None,
                  finished_recving: Optional[set[str]] = None):
-        self.finished_sending = finished_sending
-        self.finished_recving = finished_recving
+        self.kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+        )
+
+    def __repr__(self):
+        return (
+            f"DummyModelRunnerOutput("
+            f"finished_sending={self.kv_connector_output.finished_sending},"
+            f"finished_recving={self.kv_connector_output.finished_recving})")


 def test_aggregate_workers_output():
@@ -27,6 +35,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])

     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
@@ -38,6 +47,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])

     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
@@ -49,6 +59,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])

     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
@@ -70,6 +81,7 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
@@ -87,6 +99,7 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
@@ -104,5 +117,6 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
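
All of the test hunks above apply one mechanical change: ModelRunnerOutput no longer carries finished_sending/finished_recving directly, so assertions first dereference kv_connector_output. A minimal sketch of the resulting pattern (the aggregator's constructor argument is an assumption; DummyModelRunnerOutput is the helper defined in this file):

    # Sketch only: assumes a two-worker aggregator, as in these tests.
    aggregator = KVOutputAggregator(world_size=2)
    output1 = DummyModelRunnerOutput(finished_sending={'req1'})
    output2 = DummyModelRunnerOutput()

    aggregated = aggregator.aggregate([output1, output2])
    assert aggregated is output1  # aggregate() reuses rank 0's output object
    kv_out = aggregated.kv_connector_output
    # 'req1' is withheld until every rank reports it finished.
    assert kv_out.finished_sending is None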

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import copy

-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import FinishReason, RequestStatus

 from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -86,7 +86,8 @@ def test_basic_lifecycle():
     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_id])

     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -176,7 +177,8 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     scheduler.schedule()

     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_remote.request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_remote.request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     _ = scheduler.schedule()
     assert_scheduler_empty(scheduler)

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import copy

-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import FinishReason, RequestStatus

 from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -72,7 +72,8 @@ def test_basic_lifecycle():
     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])

     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -309,7 +310,8 @@ def test_full_block_prompt():
     # # STEP (2): Recv.
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)

View File

@@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request
 from vllm.v1.structured_output import StructuredOutputManager
@@ -188,8 +188,10 @@ def create_model_runner_output(
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=None,
-        finished_sending=finished_sending,
-        finished_recving=finished_recving,
+        kv_connector_output=KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+        ),
     )

View File

@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.logger import init_logger
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

 logger = init_logger(__name__)
@@ -129,7 +129,7 @@ class KVOutputAggregator:
     def aggregate(self,
                   outputs: list[ModelRunnerOutput],
                   output_rank: int = 0) -> ModelRunnerOutput:
-        # aggregate finished_sending, finished_recving from all workers
+        # aggregate kv_connector_output from all workers

         def update_finished_set(req_ids: Optional[set[str]],
                                 remaining_count_dict: dict[str, int],
@@ -143,6 +143,7 @@ class KVOutputAggregator:
         finished_sending = set[str]()
         finished_recving = set[str]()
         for output in outputs:
+            output = output.kv_connector_output
             update_finished_set(output.finished_sending,
                                 self._send_remaining_count, finished_sending)
             update_finished_set(output.finished_recving,
@@ -151,13 +152,10 @@ class KVOutputAggregator:
         # select output of the worker specified by output_rank
         output = outputs[output_rank]

-        # set the aggregated finished_sending / finished_recving
-        # if output.finished_sending/recving is not empty, but the other ranks
-        # still have unfinished send/recv, we want to set the aggregated
-        # finished_sending/recving to None until all ranks have finished
-        # send/recv
-        output.finished_sending = finished_sending if finished_sending else None
-        output.finished_recving = finished_recving if finished_recving else None
+        output.kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending or None,
+            finished_recving=finished_recving or None,
+        )

         return output
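
The deleted comment explained why the aggregated sets can be None: a request only counts as globally finished once every rank has reported it, so partial completions are withheld. That countdown semantics is unchanged by the refactor; a standalone sketch of the idea (illustrative names, not the vLLM implementation):

    from typing import Optional

    def update_finished_set(req_ids: Optional[set[str]],
                            remaining_count: dict[str, int],
                            finished: set[str],
                            world_size: int) -> None:
        # Each request must be reported by all world_size workers before
        # it is treated as globally finished.
        for req_id in req_ids or ():
            count = remaining_count.get(req_id, world_size) - 1
            if count == 0:
                finished.add(req_id)
                remaining_count.pop(req_id, None)
            else:
                remaining_count[req_id] = count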

View File

@@ -10,7 +10,7 @@ from collections.abc import Mapping
 from collections.abc import Sequence as GenericSequence
 from dataclasses import dataclass, field
 from functools import reduce
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import msgspec
 import torch
@@ -21,6 +21,10 @@ from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams

+if TYPE_CHECKING:
+    from vllm.v1.worker.kv_connector_model_runner_mixin import (
+        KVConnectorOutput)
+
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
 VLLM_INVALID_TOKEN_ID = -1
@@ -1159,14 +1163,11 @@ class IntermediateTensors:
     states and residuals to be sent to the next stage. This data structure
     contains the hidden states and residuals for a request.

-    Each stage also needs to handle its own finished_sending and
-    finished_recving in case of kv transfer.
+    Each stage also needs to handle its own kv_connector_output.
     """

     tensors: dict[str, torch.Tensor]
-    # [req_ids]
-    finished_sending: Optional[set[str]] = None
-    finished_recving: Optional[set[str]] = None
+    kv_connector_output: Optional["KVConnectorOutput"]

     def __init__(self, tensors):
         # manually define this function, so that
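
For pipeline parallelism, IntermediateTensors now ferries a single optional KVConnectorOutput between stages instead of two loose sets. A condensed sketch of the attachment on a mid-pipeline rank (field names from the diff; kv_connector_output may be None when no connector is configured):

    # Before: two fields set independently on the IntermediateTensors.
    #     hidden_states.finished_sending = finished_sending
    #     hidden_states.finished_recving = finished_recving
    # After: one optional, self-describing object.
    hidden_states.kv_connector_output = kv_connector_output
    return hidden_states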

View File

@@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
                             EngineCoreOutputs)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
@@ -884,7 +884,9 @@ class Scheduler(SchedulerInterface):
         self.waiting.remove_requests(stopped_preempted_reqs)

         # KV Connector: update state for finished KV Transfers.
-        self._update_from_kv_xfer_finished(model_runner_output)
+        if model_runner_output.kv_connector_output:
+            self._update_from_kv_xfer_finished(
+                model_runner_output.kv_connector_output)

         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
@@ -1128,7 +1130,7 @@ class Scheduler(SchedulerInterface):
         return True

     def _update_from_kv_xfer_finished(self,
-                                      model_runner_output: ModelRunnerOutput):
+                                      kv_connector_output: KVConnectorOutput):
         """
         KV Connector: update the scheduler state based on the output.
@@ -1139,9 +1141,9 @@ class Scheduler(SchedulerInterface):
         scheduler the request during the next step.
         """
         # KV Connector: update recv and send status from last step.
-        for req_id in (model_runner_output.finished_recving or ()):
+        for req_id in (kv_connector_output.finished_recving or ()):
             logger.debug("Finished recving KV transfer for request %s", req_id)
             self.finished_recving_kv_req_ids.add(req_id)
-        for req_id in (model_runner_output.finished_sending or ()):
+        for req_id in (kv_connector_output.finished_sending or ()):
             logger.debug("Finished sending KV transfer for request %s", req_id)
             self._free_blocks(self.requests[req_id])
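
The scheduler-side semantics are unchanged: a finished receive makes the request schedulable on the next step, and a finished send lets its KV blocks be freed. The new guard in update_from_output simply skips this work entirely on steps where no connector produced output. A condensed summary of the refactored method (member names are the scheduler's own, logging elided):

    def _update_from_kv_xfer_finished(self,
                                      kv_connector_output: KVConnectorOutput):
        # Recv done -> request can be scheduled during the next step.
        for req_id in (kv_connector_output.finished_recving or ()):
            self.finished_recving_kv_req_ids.add(req_id)
        # Send done -> the request's KV blocks can be released.
        for req_id in (kv_connector_output.finished_sending or ()):
            self._free_blocks(self.requests[req_id])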

View File

@@ -71,6 +71,13 @@ class SamplerOutput:
     logprobs_tensors: Optional[LogprobsTensors]


+@dataclass
+class KVConnectorOutput:
+    # [req_ids]
+    finished_sending: Optional[set[str]] = None
+    finished_recving: Optional[set[str]] = None
+
+
 # ModelRunnerOutput is serialized and sent to the scheduler process.
 # This is expensive for torch.Tensor so prefer to use list instead.
 @dataclass
@@ -104,9 +111,7 @@ class ModelRunnerOutput:
     # [num_reqs, hidden_size]
     pooler_output: list[Optional[torch.Tensor]]

-    # [req_ids]
-    finished_sending: Optional[set[str]] = None
-    finished_recving: Optional[set[str]] = None
+    kv_connector_output: Optional[KVConnectorOutput] = None

     # req_id -> num_nans_in_logits
     num_nans_in_logits: Optional[dict[str, int]] = None
@@ -119,6 +124,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               logprobs=None,
                                               prompt_logprobs_dict={},
                                               pooler_output=[],
-                                              finished_sending=None,
-                                              finished_recving=None,
                                               num_nans_in_logits=None)
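
KVConnectorOutput is a plain dataclass, so call sites that previously threaded two keyword arguments through ModelRunnerOutput now set one optional field. A sketch of constructing a connector-only output, following the same pattern the updated tests use:

    import copy

    from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput

    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    output.kv_connector_output = KVConnectorOutput(
        finished_sending={'req1'},  # reqs whose KV blocks finished sending
        finished_recving=None,
    )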

View File

@@ -69,7 +69,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin)
+    KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from ..sample.logits_processor import LogitsProcessorManager
@@ -1423,8 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
+        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1459,8 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_output=kv_connector_output,
         )

     @torch.inference_mode()
@@ -1564,8 +1562,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_tokens=num_input_tokens,
             num_tokens_across_dp=num_tokens_across_dp,
             skip_cuda_graphs=skip_cuda_graphs,
-        ):
-            self.maybe_setup_kv_connector(scheduler_output)
+        ), self.maybe_get_kv_connector_output(
+                scheduler_output) as kv_connector_output:

             model_output = self.model(
                 input_ids=input_ids,
@@ -1578,10 +1576,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 ),
             )

-            self.maybe_wait_for_kv_save()
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
-
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
         else:
@@ -1597,20 +1591,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             == "external_launcher" and len(get_pp_group().ranks) > 0

         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
-            if not broadcast_pp_output:
-                if finished_sending or finished_recving:
-                    hidden_states.finished_sending = finished_sending
-                    hidden_states.finished_recving = finished_recving
-                return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
+            if not broadcast_pp_output:
+                hidden_states.kv_connector_output = kv_connector_output
+                return hidden_states
             get_pp_group().send_tensor_dict(hidden_states.tensors,
                                             all_gather_group=get_tp_group())
             logits = None
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, finished_sending,
-                                  finished_recving)
+                                  num_scheduled_tokens_np, kv_connector_output)

             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1760,8 +1751,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_output=kv_connector_output,
             num_nans_in_logits=num_nans_in_logits,
         )
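
The net effect in execute_model: three scattered connector calls (maybe_setup_kv_connector, maybe_wait_for_kv_save, get_finished_kv_transfers) collapse into one context manager that brackets the forward pass. A condensed sketch of the new shape (argument lists abbreviated; attn_metadata is a stand-in name):

    with set_forward_context(attn_metadata, self.vllm_config), \
            self.maybe_get_kv_connector_output(
                scheduler_output) as kv_connector_output:
        # start_load_kv() has already run; wait_for_save() and
        # get_finished() run when this block exits.
        model_output = self.model(input_ids=input_ids)
    # kv_connector_output is None when no KV transfer group is configured.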

View File

@@ -16,8 +16,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
-                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -369,17 +368,20 @@ class Worker(WorkerBase):
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            if not has_kv_transfer_group():
+
+            kv_connector_output = output.kv_connector_output
+            if not kv_connector_output:
                 return None

             # In case of PP with kv transfer, we need to pass through the
-            # finished_sending and finished_recving buffers.
-            new_output = EMPTY_MODEL_RUNNER_OUTPUT
-            if output.finished_sending or output.finished_recving:
-                new_output = copy.copy(new_output)
-                new_output.finished_sending = output.finished_sending
-                new_output.finished_recving = output.finished_recving
-            output = new_output
+            # kv_connector_output
+            if (not kv_connector_output.finished_sending
+                    and not kv_connector_output.finished_recving):
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            output.kv_connector_output = kv_connector_output
+            return output

         assert isinstance(output, ModelRunnerOutput)
         return output

View File

@@ -4,6 +4,8 @@
 Define KV connector functionality mixin for model runners.
 """
 import copy
+from contextlib import AbstractContextManager, contextmanager, nullcontext
+from typing import Generator  # noqa: UP035
 from typing import TYPE_CHECKING, Optional

 from vllm.config import VllmConfig
@@ -12,7 +14,8 @@
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput,
+                             ModelRunnerOutput)

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -53,18 +56,60 @@ class KVConnectorModelRunnerMixin:
                 scheduler_output.finished_req_ids)
         return None, None

-    def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
+    @staticmethod
+    def kv_connector_no_forward(scheduler_output: "SchedulerOutput",
                                 vllm_config: VllmConfig) -> ModelRunnerOutput:
         # KV send/recv even if no work to do.
-        with set_forward_context(None, vllm_config):
-            self.maybe_setup_kv_connector(scheduler_output)
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
+        with set_forward_context(
+                None, vllm_config
+        ), KVConnectorModelRunnerMixin._get_kv_connector_output(
+                scheduler_output, wait_for_save=False) as kv_connector_output:
+            pass

-        if not finished_sending and not finished_recving:
+        if (not kv_connector_output.finished_sending
+                and not kv_connector_output.finished_recving):
             return EMPTY_MODEL_RUNNER_OUTPUT

         output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.finished_sending = finished_sending
-        output.finished_recving = finished_recving
+        output.kv_connector_output = kv_connector_output
         return output

+    @staticmethod
+    def maybe_get_kv_connector_output(
+        scheduler_output: "SchedulerOutput"
+    ) -> AbstractContextManager[Optional[KVConnectorOutput]]:
+        return KVConnectorModelRunnerMixin._get_kv_connector_output(
+            scheduler_output) if has_kv_transfer_group() else nullcontext()
+
+    # This context manager must be used within an active forward context.
+    # It encapsulates the entire KV connector lifecycle within execute_model.
+    @staticmethod
+    @contextmanager
+    def _get_kv_connector_output(
+        scheduler_output: "SchedulerOutput",
+        wait_for_save: bool = True
+    ) -> Generator[KVConnectorOutput, None, None]:
+        output = KVConnectorOutput()
+
+        # Update KVConnector with the KVConnector metadata before forward().
+        kv_connector = get_kv_transfer_group()
+        assert isinstance(kv_connector, KVConnectorBase_V1)
+        assert scheduler_output.kv_connector_metadata is not None
+        kv_connector.bind_connector_metadata(
+            scheduler_output.kv_connector_metadata)
+
+        # Background KV cache transfers happen here.
+        # These transfers are designed to be async and the requests
+        # involved may be disjoint from the running requests.
+        # Do this here to save a collective_rpc.
+        kv_connector.start_load_kv(get_forward_context())
+        try:
+            yield output
+        finally:
+            if wait_for_save:
+                kv_connector.wait_for_save()
+
+            output.finished_sending, output.finished_recving = (
+                kv_connector.get_finished(scheduler_output.finished_req_ids))
+
+            kv_connector.clear_connector_metadata()
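
The try/finally in _get_kv_connector_output is what lets kv_connector_no_forward reuse it with an empty body: the bookkeeping (wait_for_save, get_finished, clear_connector_metadata) is guaranteed to run even if the wrapped forward pass raises. A standalone sketch of the pattern, with a hypothetical collect() stand-in for the connector calls:

    from contextlib import contextmanager
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class KVConnectorOutput:  # mirrors the dataclass added in this PR
        finished_sending: Optional[set[str]] = None
        finished_recving: Optional[set[str]] = None

    @contextmanager
    def kv_connector_lifecycle(collect):
        """collect() is a hypothetical stand-in returning the finished sets."""
        output = KVConnectorOutput()
        try:
            yield output  # the caller's forward pass runs here
        finally:
            # Runs even if the forward pass raises, so per-step connector
            # state is always collected (and can be cleared) exactly once.
            output.finished_sending, output.finished_recving = collect()

    # No-forward usage: the body is intentionally empty.
    with kv_connector_lifecycle(lambda: ({'req1'}, None)) as kv_out:
        pass
    assert kv_out.finished_sending == {'req1'}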

View File

@@ -51,7 +51,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin)
+    KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
@@ -1175,9 +1175,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
-        )
+            kv_connector_output=KVConnectorOutput(
+                finished_sending=finished_sending,
+                finished_recving=finished_recving,
+            ))

         # Check there are no new graphs compiled - all the graphs should be
         # captured and compiled during warm up.