diff --git a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggreagator.py
index cad73f68e9f1..5d2b27a9eb4d 100644
--- a/tests/v1/kv_connector/unit/test_output_aggreagator.py
+++ b/tests/v1/kv_connector/unit/test_output_aggreagator.py
@@ -4,7 +4,7 @@ from concurrent.futures import Future
 from typing import Optional
 
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 
 
 class DummyModelRunnerOutput(ModelRunnerOutput):
@@ -12,8 +12,16 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
     def __init__(self,
                  finished_sending: Optional[set[str]] = None,
                  finished_recving: Optional[set[str]] = None):
-        self.finished_sending = finished_sending
-        self.finished_recving = finished_recving
+        self.kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+        )
+
+    def __repr__(self):
+        return (
+            f"DummyModelRunnerOutput("
+            f"finished_sending={self.kv_connector_output.finished_sending},"
+            f"finished_recving={self.kv_connector_output.finished_recving})")
 
 
 def test_aggregate_workers_output():
@@ -27,6 +35,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
 
@@ -38,6 +47,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
 
@@ -49,6 +59,7 @@ def test_aggregate_workers_output():
     aggregated = aggregator.aggregate([output1, output2])
 
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
 
@@ -70,6 +81,7 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
 
@@ -87,6 +99,7 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
 
@@ -104,5 +117,6 @@ def test_async_aggregate_workers_output():
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
index 12a71d97e8d2..76394a540aac 100644
--- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import FinishReason, RequestStatus
 
 from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -86,7 +86,8 @@ def test_basic_lifecycle():
 
     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_id])
 
     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -176,7 +177,8 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_sending = [request_remote.request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_remote.request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     _ = scheduler.schedule()
     assert_scheduler_empty(scheduler)
diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
index f89970bf2c80..3d52ea526d96 100644
--- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
 from vllm.v1.request import FinishReason, RequestStatus
 
 from .utils import (assert_scheduler_empty, create_model_runner_output,
@@ -72,7 +72,8 @@ def test_basic_lifecycle():
 
     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
 
     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -309,7 +310,8 @@ def test_full_block_prompt():
     # # STEP (2): Recv.
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    model_runner_output.finished_recving = [request_id]
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index 480a7074cdf4..291c84d117cb 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request
 from vllm.v1.structured_output import StructuredOutputManager
 
@@ -188,8 +188,10 @@ def create_model_runner_output(
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=None,
-        finished_sending=finished_sending,
-        finished_recving=finished_recving,
+        kv_connector_output=KVConnectorOutput(
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
+        ),
     )
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 559c233947ce..1a11cb6d0189 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -16,7 +16,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.logger import init_logger
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 
 logger = init_logger(__name__)
 
@@ -129,7 +129,7 @@ class KVOutputAggregator:
     def aggregate(self,
                   outputs: list[ModelRunnerOutput],
                   output_rank: int = 0) -> ModelRunnerOutput:
-        # aggregate finished_sending, finished_recving from all workers
+        # aggregate kv_connector_output from all workers
 
         def update_finished_set(req_ids: Optional[set[str]],
                                 remaining_count_dict: dict[str, int],
@@ -143,6 +143,7 @@ class KVOutputAggregator:
         finished_sending = set[str]()
         finished_recving = set[str]()
         for output in outputs:
+            output = output.kv_connector_output
             update_finished_set(output.finished_sending,
                                 self._send_remaining_count, finished_sending)
             update_finished_set(output.finished_recving,
@@ -151,13 +152,10 @@ class KVOutputAggregator:
 
         # select output of the worker specified by output_rank
         output = outputs[output_rank]
 
-        # set the aggregated finished_sending / finished_recving
-        # if output.finished_sending/recving is not empty, but the other ranks
-        # still have unfinished send/recv, we want to set the aggregated
-        # finished_sending/recving to None until all ranks have finished
-        # send/recv
-        output.finished_sending = finished_sending if finished_sending else None
-        output.finished_recving = finished_recving if finished_recving else None
+        output.kv_connector_output = KVConnectorOutput(
+            finished_sending=finished_sending or None,
+            finished_recving=finished_recving or None,
+        )
 
         return output
diff --git a/vllm/sequence.py b/vllm/sequence.py
index fe87b52f9df1..6e65a2bd0318 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -10,7 +10,7 @@ from collections.abc import Mapping
 from collections.abc import Sequence as GenericSequence
 from dataclasses import dataclass, field
 from functools import reduce
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import msgspec
 import torch
@@ -21,6 +21,10 @@ from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 
+if TYPE_CHECKING:
+    from vllm.v1.worker.kv_connector_model_runner_mixin import (
+        KVConnectorOutput)
+
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
 
 VLLM_INVALID_TOKEN_ID = -1
@@ -1159,14 +1163,11 @@ class IntermediateTensors:
     states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
 
-    Each stage also needs to handle its own finished_sending and
-    finished_recving in case of kv transfer.
+    Each stage also needs to handle its own kv_connector_output.
     """
 
     tensors: dict[str, torch.Tensor]
-    # [req_ids]
-    finished_sending: Optional[set[str]] = None
-    finished_recving: Optional[set[str]] = None
+    kv_connector_output: Optional["KVConnectorOutput"]
 
     def __init__(self, tensors):
         # manually define this function, so that
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 446f98034cb8..49a744cfec69 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
                             EngineCoreOutputs)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
@@ -884,7 +884,9 @@ class Scheduler(SchedulerInterface):
             self.waiting.remove_requests(stopped_preempted_reqs)
 
         # KV Connector: update state for finished KV Transfers.
-        self._update_from_kv_xfer_finished(model_runner_output)
+        if model_runner_output.kv_connector_output:
+            self._update_from_kv_xfer_finished(
+                model_runner_output.kv_connector_output)
 
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
@@ -1128,7 +1130,7 @@ class Scheduler(SchedulerInterface):
         return True
 
     def _update_from_kv_xfer_finished(self,
-                                      model_runner_output: ModelRunnerOutput):
+                                      kv_connector_output: KVConnectorOutput):
         """
         KV Connector: update the scheduler state based on the output.
 
@@ -1139,9 +1141,9 @@ class Scheduler(SchedulerInterface):
             scheduler the request during the next step.
         """
         # KV Connector:: update recv and send status from last step.
-        for req_id in (model_runner_output.finished_recving or ()):
+        for req_id in (kv_connector_output.finished_recving or ()):
             logger.debug("Finished recving KV transfer for request %s", req_id)
             self.finished_recving_kv_req_ids.add(req_id)
-        for req_id in (model_runner_output.finished_sending or ()):
+        for req_id in (kv_connector_output.finished_sending or ()):
             logger.debug("Finished sending KV transfer for request %s", req_id)
             self._free_blocks(self.requests[req_id])
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f78623f571b2..7d7cd0c94dd0 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -71,6 +71,13 @@ class SamplerOutput:
     logprobs_tensors: Optional[LogprobsTensors]
 
 
+@dataclass
+class KVConnectorOutput:
+    # [req_ids]
+    finished_sending: Optional[set[str]] = None
+    finished_recving: Optional[set[str]] = None
+
+
 # ModelRunnerOutput is serialized and sent to the scheduler process.
 # This is expensive for torch.Tensor so prefer to use list instead.
 @dataclass
@@ -104,9 +111,7 @@ class ModelRunnerOutput:
     # [num_reqs, hidden_size]
     pooler_output: list[Optional[torch.Tensor]]
 
-    # [req_ids]
-    finished_sending: Optional[set[str]] = None
-    finished_recving: Optional[set[str]] = None
+    kv_connector_output: Optional[KVConnectorOutput] = None
 
     # req_id -> num_nans_in_logits
     num_nans_in_logits: Optional[dict[str, int]] = None
@@ -119,6 +124,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               logprobs=None,
                                               prompt_logprobs_dict={},
                                               pooler_output=[],
-                                              finished_sending=None,
-                                              finished_recving=None,
                                               num_nans_in_logits=None)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 42cef6c5733d..041687ae28b2 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -69,7 +69,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin)
+    KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from ..sample.logits_processor import LogitsProcessorManager
@@ -1423,8 +1423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
+        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1459,8 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_output=kv_connector_output,
         )
 
     @torch.inference_mode()
@@ -1564,8 +1562,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
                 skip_cuda_graphs=skip_cuda_graphs,
-        ):
-            self.maybe_setup_kv_connector(scheduler_output)
+        ), self.maybe_get_kv_connector_output(
+                scheduler_output) as kv_connector_output:
 
             model_output = self.model(
                 input_ids=input_ids,
@@ -1578,10 +1576,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 ),
             )
 
-            self.maybe_wait_for_kv_save()
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
-
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
         else:
@@ -1597,20 +1591,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             == "external_launcher" and len(get_pp_group().ranks) > 0
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
-            if not broadcast_pp_output:
-                if finished_sending or finished_recving:
-                    hidden_states.finished_sending = finished_sending
-                    hidden_states.finished_recving = finished_recving
-                return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
+            if not broadcast_pp_output:
+                hidden_states.kv_connector_output = kv_connector_output
+                return hidden_states
             get_pp_group().send_tensor_dict(hidden_states.tensors,
                                             all_gather_group=get_tp_group())
             logits = None
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, finished_sending,
-                                  finished_recving)
+                                  num_scheduled_tokens_np, kv_connector_output)
             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
 
@@ -1760,8 +1751,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_output=kv_connector_output,
             num_nans_in_logits=num_nans_in_logits,
         )
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 4bc4ece9a0df..7fca245c1bef 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -16,8 +16,7 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
-                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -369,17 +368,20 @@ class Worker(WorkerBase):
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            if not has_kv_transfer_group():
+
+            kv_connector_output = output.kv_connector_output
+            if not kv_connector_output:
                 return None
 
             # In case of PP with kv transfer, we need to pass through the
-            # finished_sending and finished_recving buffers.
-            new_output = EMPTY_MODEL_RUNNER_OUTPUT
-            if output.finished_sending or output.finished_recving:
-                new_output = copy.copy(new_output)
-                new_output.finished_sending = output.finished_sending
-                new_output.finished_recving = output.finished_recving
-            output = new_output
+            # kv_connector_output
+            if (not kv_connector_output.finished_sending
+                    and not kv_connector_output.finished_recving):
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            output.kv_connector_output = kv_connector_output
+            return output
 
         assert isinstance(output, ModelRunnerOutput)
         return output
diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
index 5a3186058fcf..343befe17679 100644
--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -4,6 +4,8 @@
 Define KV connector functionality mixin for model runners.
""" import copy +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Generator # noqa: UP035 from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig @@ -12,7 +14,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, + ModelRunnerOutput) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -53,18 +56,60 @@ class KVConnectorModelRunnerMixin: scheduler_output.finished_req_ids) return None, None - def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput", + @staticmethod + def kv_connector_no_forward(scheduler_output: "SchedulerOutput", vllm_config: VllmConfig) -> ModelRunnerOutput: # KV send/recv even if no work to do. - with set_forward_context(None, vllm_config): - self.maybe_setup_kv_connector(scheduler_output) - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) + with set_forward_context( + None, vllm_config + ), KVConnectorModelRunnerMixin._get_kv_connector_output( + scheduler_output, wait_for_save=False) as kv_connector_output: + pass - if not finished_sending and not finished_recving: + if (not kv_connector_output.finished_sending + and not kv_connector_output.finished_recving): return EMPTY_MODEL_RUNNER_OUTPUT output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.finished_sending = finished_sending - output.finished_recving = finished_recving + output.kv_connector_output = kv_connector_output return output + + @staticmethod + def maybe_get_kv_connector_output( + scheduler_output: "SchedulerOutput" + ) -> AbstractContextManager[Optional[KVConnectorOutput]]: + return KVConnectorModelRunnerMixin._get_kv_connector_output( + scheduler_output) if has_kv_transfer_group() else nullcontext() + + # This context manager must be used within an active forward context. + # It encapsulates the entire KV conector lifecycle within execute_model + @staticmethod + @contextmanager + def _get_kv_connector_output( + scheduler_output: "SchedulerOutput", + wait_for_save: bool = True + ) -> Generator[KVConnectorOutput, None, None]: + output = KVConnectorOutput() + + # Update KVConnector with the KVConnector metadata forward(). + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. 
+        kv_connector.start_load_kv(get_forward_context())
+        try:
+            yield output
+        finally:
+            if wait_for_save:
+                kv_connector.wait_for_save()
+
+            output.finished_sending, output.finished_recving = (
+                kv_connector.get_finished(scheduler_output.finished_req_ids))
+
+            kv_connector.clear_connector_metadata()
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 59cbb0150570..67cb2f9dd810 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -51,7 +51,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsLists,
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin)
+    KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
 
@@ -1175,9 +1175,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
-        )
+            kv_connector_output=KVConnectorOutput(
+                finished_sending=finished_sending,
+                finished_recving=finished_recving,
+            ))
 
     # Check there are no new graphs compiled - all the graphs should be
     # captured and compiled during warm up.
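
Note on the aggregation semantics in vllm/distributed/kv_transfer/kv_connector/utils.py: KVOutputAggregator.aggregate() now folds the per-worker outputs into a single KVConnectorOutput attached to the output rank's ModelRunnerOutput, and a request id only appears in the aggregated finished_sending/finished_recving once every worker has reported it. The standalone sketch below illustrates that counting scheme; KVConnectorOutput matches the dataclass added in vllm/v1/outputs.py, but ToyAggregator and its remaining-count dicts are illustrative stand-ins for KVOutputAggregator, not vLLM's implementation.

    # Standalone sketch of the per-worker completion counting, assuming a
    # fixed world_size known at construction time.
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class KVConnectorOutput:
        # [req_ids]
        finished_sending: Optional[set[str]] = None
        finished_recving: Optional[set[str]] = None


    class ToyAggregator:

        def __init__(self, world_size: int):
            self._world_size = world_size
            # req_id -> number of workers that have not reported it yet,
            # mirroring _send_remaining_count / _recv_remaining_count.
            self._send_remaining: dict[str, int] = {}
            self._recv_remaining: dict[str, int] = {}

        def _update(self, req_ids: Optional[set[str]],
                    remaining: dict[str, int], finished: set[str]) -> None:
            for req_id in req_ids or ():
                count = remaining.setdefault(req_id, self._world_size) - 1
                if count == 0:
                    # Every worker has reported this request: it is now
                    # globally finished.
                    finished.add(req_id)
                    del remaining[req_id]
                else:
                    remaining[req_id] = count

        def aggregate(self,
                      outputs: list[KVConnectorOutput]) -> KVConnectorOutput:
            finished_sending: set[str] = set()
            finished_recving: set[str] = set()
            for output in outputs:
                self._update(output.finished_sending, self._send_remaining,
                             finished_sending)
                self._update(output.finished_recving, self._recv_remaining,
                             finished_recving)
            # Empty sets collapse to None, like `finished_sending or None`
            # in the diff.
            return KVConnectorOutput(
                finished_sending=finished_sending or None,
                finished_recving=finished_recving or None)


    # 'req1' is only reported once both workers have reported it.
    aggregator = ToyAggregator(world_size=2)
    step1 = aggregator.aggregate([KVConnectorOutput(finished_sending={"req1"}),
                                  KVConnectorOutput()])
    assert step1.finished_sending is None
    step2 = aggregator.aggregate([KVConnectorOutput(),
                                  KVConnectorOutput(finished_sending={"req1"})])
    assert step2.finished_sending == {"req1"}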
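
Note on the context-manager refactor in kv_connector_model_runner_mixin.py: the model runner now enters _get_kv_connector_output() around the forward pass and reads the populated KVConnectorOutput only after the with-block exits; kv_connector_no_forward reuses the same manager with an empty body. A minimal sketch of the pattern follows, with a hypothetical StubConnector in place of the real KVConnectorBase_V1 API (which additionally takes a forward context and finished request ids). The point it shows: the yielded object is still empty during forward and is filled in the finally block, so send/recv completion is collected even if forward raises.

    # Minimal sketch of the context-manager lifecycle, under the stub
    # connector assumption described above.
    from collections.abc import Generator
    from contextlib import contextmanager
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class KVConnectorOutput:
        finished_sending: Optional[set[str]] = None
        finished_recving: Optional[set[str]] = None


    class StubConnector:
        """Stand-in for a KV connector; methods are illustrative."""

        def start_load_kv(self) -> None:
            pass  # kick off async KV loads before the forward pass

        def wait_for_save(self) -> None:
            pass  # block until KV writes have been handed off

        def get_finished(self):
            return {"req-a"}, None  # (finished_sending, finished_recving)


    @contextmanager
    def kv_connector_output(connector: StubConnector,
                            wait_for_save: bool = True
                            ) -> Generator[KVConnectorOutput, None, None]:
        # The caller receives a still-empty output object on entry; its
        # fields are populated in the finally block after forward has run.
        output = KVConnectorOutput()
        connector.start_load_kv()
        try:
            yield output
        finally:
            if wait_for_save:
                connector.wait_for_save()
            output.finished_sending, output.finished_recving = (
                connector.get_finished())


    with kv_connector_output(StubConnector()) as kv_out:
        pass  # the model forward pass would run here
    assert kv_out.finished_sending == {"req-a"}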