Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw 2025-03-28 02:17:42 +00:00
parent c5d963835b
commit 24f68342b4

View File

@@ -604,25 +604,30 @@ class TPUModelRunner:
# avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_input_batch(self.input_batch, logits_indices)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
# Temporary debug pathway.
if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
selected_token_ids = self.model.compute_logits(hidden_states,
logits_indices, None)
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
selected_token_ids = self.model.compute_logits_no_sampler(
hidden_states, logits_indices, None)
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
else:
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
# Update the cache state concurrently. Code above will not block until
# we use `selected_token_ids`. Add mark_step if post-processing changes
@@ -936,6 +941,18 @@ class ModelWrapperV1(nn.Module):
logits = self.model.compute_logits(hidden_states, None)
return logits
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def compute_logits_no_sampler(
    self,
    hidden_states: torch.Tensor,
    logits_indices: torch.Tensor,
    sampling_metadata,
) -> Optional[torch.Tensor]:
    """Pick token ids by argmax over the logits, bypassing the sampler.

    Args:
        hidden_states: Tensor that is indexed by ``logits_indices`` before
            logits are computed.
        logits_indices: Index selecting the entries of ``hidden_states``
            for which logits are computed.
        sampling_metadata: Forwarded unchanged to the wrapped model's
            ``compute_logits``.

    Returns:
        The argmax of the logits over the last dimension, with that
        dimension retained (``keepdim=True``).
    """
    # Only the selected positions are projected to logits.
    selected_states = hidden_states[logits_indices]
    scores = self.model.compute_logits(selected_states, sampling_metadata)
    return torch.argmax(scores, dim=-1, keepdim=True)
def get_multimodal_embeddings(self, *args, **kwargs):
    """Delegate to the wrapped model's ``get_multimodal_embeddings``.

    All positional and keyword arguments are passed through unchanged, and
    the wrapped model's result is returned as-is.
    """
    wrapped_model = self.model
    return wrapped_model.get_multimodal_embeddings(*args, **kwargs)