diff --git a/vllm/envs.py b/vllm/envs.py
index 5334667376b24..ea2c8ff0f98c4 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -104,6 +104,7 @@ if TYPE_CHECKING:
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
+    VLLM_TPU_DISABLE_SAMPLER_DEBUG: bool = False
 
 
 def get_default_cache_root():
@@ -673,6 +674,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_TPU_BUCKET_PADDING_GAP":
     lambda: int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
     if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0,
+
+    # If set to 1, bypass the TPU sampler for performance debugging.
+    "VLLM_TPU_DISABLE_SAMPLER_DEBUG":
+    lambda: os.environ.get("VLLM_TPU_DISABLE_SAMPLER_DEBUG", "0") == "1",
 }
 
 # end-env-vars-definition
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 97dfd23163dff..ed0ebb75d1be3 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -612,10 +612,18 @@ class TPUModelRunner:
             kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
-        # Remove padding on cpu and keep dynamic op outside of xla graph.
-        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+
+        if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
+            # Debug-only path: skip sampling and take the compute_logits
+            # output directly so sampler overhead can be isolated.
+            selected_token_ids = self.model.compute_logits(hidden_states,
+                logits_indices, None)
+        else:
+            selected_token_ids = self.model.sample_from_hidden(
+                hidden_states, tpu_sampling_metadata)
+
+        # Remove padding on cpu and keep dynamic op outside of xla graph.
+        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes