From 5898b135abc7b7c0ef7107d21a07d54a84314b7c Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 10 Aug 2025 19:33:48 -0700 Subject: [PATCH] [BugFix] Fix KVConnectorOutput TPU breakage (#22598) Signed-off-by: Nick Hill --- tests/v1/kv_connector/unit/utils.py | 12 ++++++++---- vllm/v1/core/sched/scheduler.py | 4 ++-- vllm/v1/worker/tpu_model_runner.py | 13 +++++++++---- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 291c84d117cb..c22d5b861e3f 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -179,6 +179,13 @@ def create_model_runner_output( sampled_token = EOS_TOKEN_ID if use_eos else 0 sampled_token_ids = [[sampled_token] for _ in req_ids] + kv_connector_output = None if ( + finished_sending is None + and finished_recving is None) else KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving, + ) + # Make output data structure. return ModelRunnerOutput( req_ids=req_ids, @@ -188,10 +195,7 @@ def create_model_runner_output( logprobs=None, prompt_logprobs_dict={}, pooler_output=None, - kv_connector_output=KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving, - ), + kv_connector_output=kv_connector_output, ) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 85fc1a4a016a..dcb9f4dd36f5 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1151,8 +1151,8 @@ class Scheduler(SchedulerInterface): scheduler the request during the next step. """ - assert self.connector is not None - self.connector.update_connector_output(kv_connector_output) + if self.connector is not None: + self.connector.update_connector_output(kv_connector_output) # KV Connector:: update recv and send status from last step. for req_id in (kv_connector_output.finished_recving or ()): diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 915869726fbf..ae0219458ecf 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1138,6 +1138,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): i, target_slice] = valid_sampled_token_ids[i] req_state.output_token_ids.extend(valid_sampled_token_ids[i]) + kv_connector_output = None if ( + finished_sending is None + and finished_recving is None) else KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving, + ) + model_runner_output = ModelRunnerOutput( req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, @@ -1146,10 +1153,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], - kv_connector_output=KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving, - )) + kv_connector_output=kv_connector_output, + ) # Check there are no new graphs compiled - all the graphs should be # captured and compiled during warm up.