From f1bddbd852f37f98958d636821c45014c05e07a8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 30 Aug 2025 15:14:58 +0800 Subject: [PATCH] [Core] Cleanup TPU model runner for MM (#23894) Signed-off-by: DarkLight1337 --- vllm/v1/worker/tpu_model_runner.py | 32 +----------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2307006127085..985d5ba58c49c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -808,31 +808,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return per_layer_attn_metadata, logits_indices, padded_num_reqs,\ num_reqs, end_index - def _scatter_placeholders( - self, - embeds: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return embeds - - placeholders = embeds.new_full( - (is_embed.shape[0], embeds.shape[-1]), - fill_value=torch.nan, - ) - placeholders[is_embed] = embeds - return placeholders - - def _gather_placeholders( - self, - placeholders: torch.Tensor, - is_embed: Optional[torch.Tensor], - ) -> torch.Tensor: - if is_embed is None: - return placeholders - - return placeholders[is_embed] - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: @@ -892,12 +867,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # NOTE (NickLucche) here we diverge from logic in other runners, as we # assume to only have whole mm items to process. Hence we avoid the # intrinsic dynamism that `scatter_mm_placeholders` introduces. - for (mm_hash, pos_info), output in zip( - mm_hashes_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): assert pos_info.is_embed is None, "Expected all positions to be"\ " contiguous and embeddings." self.encoder_cache[mm_hash] = output