[BugFix] Fix KVConnectorOutput TPU breakage (#22598)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-08-10 19:33:48 -07:00 committed by GitHub
parent b799f4b9ea
commit 5898b135ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 10 deletions

View File

@ -179,6 +179,13 @@ def create_model_runner_output(
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]
kv_connector_output = None if (
finished_sending is None
and finished_recving is None) else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
# Make output data structure.
return ModelRunnerOutput(
req_ids=req_ids,
@ -188,10 +195,7 @@ def create_model_runner_output(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
kv_connector_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
),
kv_connector_output=kv_connector_output,
)

View File

@ -1151,8 +1151,8 @@ class Scheduler(SchedulerInterface):
scheduler the request during the next step.
"""
assert self.connector is not None
self.connector.update_connector_output(kv_connector_output)
if self.connector is not None:
self.connector.update_connector_output(kv_connector_output)
# KV Connector:: update recv and send status from last step.
for req_id in (kv_connector_output.finished_recving or ()):

View File

@ -1138,6 +1138,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
kv_connector_output = None if (
finished_sending is None
and finished_recving is None) else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
@ -1146,10 +1153,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
))
kv_connector_output=kv_connector_output,
)
# Check there are no new graphs compiled - all the graphs should be
# captured and compiled during warm up.