Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw 2025-03-28 02:26:20 +00:00
parent 24f68342b4
commit 4c42267293

View File

@ -599,11 +599,6 @@ class TPUModelRunner:
input_ids = self.input_ids
inputs_embeds = None
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, logits_indices)
# Temporary debug pathway.
if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
@ -618,6 +613,11 @@ class TPUModelRunner:
hidden_states, logits_indices, None)
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
else:
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, logits_indices)
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,