mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 15:25:28 +08:00
[BugFix] Fix KVConnectorOutput TPU breakage (#22598)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
b799f4b9ea
commit
5898b135ab
@ -179,6 +179,13 @@ def create_model_runner_output(
|
|||||||
sampled_token = EOS_TOKEN_ID if use_eos else 0
|
sampled_token = EOS_TOKEN_ID if use_eos else 0
|
||||||
sampled_token_ids = [[sampled_token] for _ in req_ids]
|
sampled_token_ids = [[sampled_token] for _ in req_ids]
|
||||||
|
|
||||||
|
kv_connector_output = None if (
|
||||||
|
finished_sending is None
|
||||||
|
and finished_recving is None) else KVConnectorOutput(
|
||||||
|
finished_sending=finished_sending,
|
||||||
|
finished_recving=finished_recving,
|
||||||
|
)
|
||||||
|
|
||||||
# Make output data structure.
|
# Make output data structure.
|
||||||
return ModelRunnerOutput(
|
return ModelRunnerOutput(
|
||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
@ -188,10 +195,7 @@ def create_model_runner_output(
|
|||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=None,
|
pooler_output=None,
|
||||||
kv_connector_output=KVConnectorOutput(
|
kv_connector_output=kv_connector_output,
|
||||||
finished_sending=finished_sending,
|
|
||||||
finished_recving=finished_recving,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1151,8 +1151,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
scheduler the request during the next step.
|
scheduler the request during the next step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert self.connector is not None
|
if self.connector is not None:
|
||||||
self.connector.update_connector_output(kv_connector_output)
|
self.connector.update_connector_output(kv_connector_output)
|
||||||
|
|
||||||
# KV Connector:: update recv and send status from last step.
|
# KV Connector:: update recv and send status from last step.
|
||||||
for req_id in (kv_connector_output.finished_recving or ()):
|
for req_id in (kv_connector_output.finished_recving or ()):
|
||||||
|
|||||||
@ -1138,6 +1138,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
i, target_slice] = valid_sampled_token_ids[i]
|
i, target_slice] = valid_sampled_token_ids[i]
|
||||||
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
|
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
|
||||||
|
|
||||||
|
kv_connector_output = None if (
|
||||||
|
finished_sending is None
|
||||||
|
and finished_recving is None) else KVConnectorOutput(
|
||||||
|
finished_sending=finished_sending,
|
||||||
|
finished_recving=finished_recving,
|
||||||
|
)
|
||||||
|
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
@ -1146,10 +1153,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
kv_connector_output=KVConnectorOutput(
|
kv_connector_output=kv_connector_output,
|
||||||
finished_sending=finished_sending,
|
)
|
||||||
finished_recving=finished_recving,
|
|
||||||
))
|
|
||||||
|
|
||||||
# Check there are no new graphs compiled - all the graphs should be
|
# Check there are no new graphs compiled - all the graphs should be
|
||||||
# captured and compiled during warm up.
|
# captured and compiled during warm up.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user