From 4c42267293563e36e01bcf5805c7c208c06392bc Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Fri, 28 Mar 2025 02:26:20 +0000 Subject: [PATCH] updated Signed-off-by: Robert Shaw --- vllm/v1/worker/tpu_model_runner.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e16ccc8a65efe..c0548f22f8154 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -599,11 +599,6 @@ class TPUModelRunner: input_ids = self.input_ids inputs_embeds = None num_reqs = self.input_batch.num_reqs - # NOTE (NickLucche) here we sync with TPU: sampling params tensors - # are copied to device in chunks of pre-compiled padded shape to - # avoid recompilations. - tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ - from_input_batch(self.input_batch, logits_indices) # Temporary debug pathway. if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG: @@ -618,6 +613,11 @@ class TPUModelRunner: hidden_states, logits_indices, None) selected_token_ids = selected_token_ids.cpu()[:num_reqs] else: + # NOTE (NickLucche) here we sync with TPU: sampling params tensors + # are copied to device in chunks of pre-compiled padded shape to + # avoid recompilations. + tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ + from_input_batch(self.input_batch, logits_indices) with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( input_ids=input_ids,