[Core] Cleanup TPU model runner for MM (#23894)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-08-30 15:14:58 +08:00 (committed by GitHub)
parent 9748c5198b
commit f1bddbd852


@@ -808,31 +808,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return per_layer_attn_metadata, logits_indices, padded_num_reqs,\
             num_reqs, end_index
-    def _scatter_placeholders(
-        self,
-        embeds: torch.Tensor,
-        is_embed: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        if is_embed is None:
-            return embeds
-
-        placeholders = embeds.new_full(
-            (is_embed.shape[0], embeds.shape[-1]),
-            fill_value=torch.nan,
-        )
-        placeholders[is_embed] = embeds
-        return placeholders
-
-    def _gather_placeholders(
-        self,
-        placeholders: torch.Tensor,
-        is_embed: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        if is_embed is None:
-            return placeholders
-
-        return placeholders[is_embed]
-
     def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
@@ -892,12 +867,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # NOTE (NickLucche) here we diverge from logic in other runners, as we
         # assume to only have whole mm items to process. Hence we avoid the
         # intrinsic dynamism that `scatter_mm_placeholders` introduces.
-        for (mm_hash, pos_info), output in zip(
-                mm_hashes_pos,
-                encoder_outputs,
-        ):
-            if req_id not in self.encoder_cache:
-                self.encoder_cache[req_id] = {}
-
+        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
             assert pos_info.is_embed is None, "Expected all positions to be"\
                 " contiguous and embeddings."
             self.encoder_cache[mm_hash] = output
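
Context on why the removed helpers matter on TPU: `_scatter_placeholders` wrote encoder embeddings into a NaN-filled buffer via boolean-mask indexing, and the number of rows written depends on the mask contents at runtime. XLA compiles one graph per distinct shape, so such data-dependent shapes can force recompilation. The sketch below is a minimal standalone illustration of that behavior, not code from this commit; the tensor sizes and the free-function form are hypothetical.

# Minimal sketch (illustration only): why boolean-mask scatter is
# shape-dynamic, and why the TPU path caches whole encoder outputs instead.
from typing import Optional

import torch


def scatter_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """Mirror of the removed helper: scatter embeddings into a NaN buffer."""
    if is_embed is None:
        return embeds
    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    # Boolean-mask assignment: the number of rows written equals
    # is_embed.sum(), which is only known at runtime. Under XLA this
    # data-dependent shape can trigger a recompile per distinct mask.
    placeholders[is_embed] = embeds
    return placeholders


# Hypothetical sizes for demonstration.
hidden = 8
is_embed = torch.tensor([True, False, True, True])  # 3 embedding slots
embeds = torch.randn(int(is_embed.sum()), hidden)

scattered = scatter_placeholders(embeds, is_embed)  # shape (4, 8), NaN gaps
assert torch.equal(scattered[is_embed], embeds)     # gather round-trips

After this commit, the TPU runner sidesteps the scatter entirely: each multimodal item is assumed whole (`pos_info.is_embed is None`), so every encoder output is cached as-is under its `mm_hash` with a static shape.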