diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index ed0ebb75d1be3..e16ccc8a65efe 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -604,25 +604,30 @@ class TPUModelRunner: # avoid recompilations. tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ from_input_batch(self.input_batch, logits_indices) - # Run the decoder - with set_forward_context(attn_metadata, self.vllm_config): - hidden_states = self.model( - input_ids=input_ids, - positions=self.position_ids, - kv_caches=self.kv_caches, - inputs_embeds=inputs_embeds, - ) - + + # Temporary debug pathway. if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG: - selected_token_ids = self.model.compute_logits(hidden_states, - logits_indices, None) + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + positions=self.position_ids, + kv_caches=self.kv_caches, + inputs_embeds=inputs_embeds, + ) + selected_token_ids = self.model.compute_logits_no_sampler( + hidden_states, logits_indices, None) selected_token_ids = selected_token_ids.cpu()[:num_reqs] else: - selected_token_ids = self.model.sample_from_hidden( - hidden_states, tpu_sampling_metadata) - - # Remove padding on cpu and keep dynamic op outside of xla graph. - selected_token_ids = selected_token_ids.cpu()[:num_reqs] + with set_forward_context(attn_metadata, self.vllm_config): + hidden_states = self.model( + input_ids=input_ids, + positions=self.position_ids, + kv_caches=self.kv_caches, + inputs_embeds=inputs_embeds, + ) + selected_token_ids = self.model.sample_from_hidden( + hidden_states, tpu_sampling_metadata) + selected_token_ids = selected_token_ids.cpu()[:num_reqs] # Update the cache state concurrently. Code above will not block until # we use `selected_token_ids`. Add mark_step if post-processing changes @@ -936,6 +941,18 @@ class ModelWrapperV1(nn.Module): logits = self.model.compute_logits(hidden_states, None) return logits + @torch.compile(backend="openxla", fullgraph=True, dynamic=False) + def compute_logits_no_sampler( + self, + hidden_states: torch.Tensor, + logits_indices: torch.Tensor, + sampling_metadata, + ) -> Optional[torch.Tensor]: + hidden_states = hidden_states[logits_indices] + logits = self.model.compute_logits(hidden_states, sampling_metadata) + selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + return selected_token_ids + def get_multimodal_embeddings(self, *args, **kwargs): return self.model.get_multimodal_embeddings(*args, **kwargs)