Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-04-06 01:57:02 +08:00)
commit 24f68342b4 (parent c5d963835b)

    updated

    Signed-off-by: Robert Shaw <robshaw@redhat.com>
@@ -604,25 +604,30 @@ class TPUModelRunner:
         # avoid recompilations.
         tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
             from_input_batch(self.input_batch, logits_indices)
         # Run the decoder
-        with set_forward_context(attn_metadata, self.vllm_config):
-            hidden_states = self.model(
-                input_ids=input_ids,
-                positions=self.position_ids,
-                kv_caches=self.kv_caches,
-                inputs_embeds=inputs_embeds,
-            )
-
         # Temporary debug pathway.
         if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
-            selected_token_ids = self.model.compute_logits(hidden_states,
-                                                           logits_indices, None)
+            with set_forward_context(attn_metadata, self.vllm_config):
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=self.position_ids,
+                    kv_caches=self.kv_caches,
+                    inputs_embeds=inputs_embeds,
+                )
+                selected_token_ids = self.model.compute_logits_no_sampler(
+                    hidden_states, logits_indices, None)
+            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
         else:
-            selected_token_ids = self.model.sample_from_hidden(
-                hidden_states, tpu_sampling_metadata)
-
-        # Remove padding on cpu and keep dynamic op outside of xla graph.
-        selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+            with set_forward_context(attn_metadata, self.vllm_config):
+                hidden_states = self.model(
+                    input_ids=input_ids,
+                    positions=self.position_ids,
+                    kv_caches=self.kv_caches,
+                    inputs_embeds=inputs_embeds,
+                )
+                selected_token_ids = self.model.sample_from_hidden(
+                    hidden_states, tpu_sampling_metadata)
+            selected_token_ids = selected_token_ids.cpu()[:num_reqs]
 
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
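As a minimal usage sketch, the debug pathway above could be toggled from user code roughly as follows, assuming VLLM_TPU_DISABLE_SAMPLER_DEBUG is wired into vllm.envs as a boolean read from the process environment (that wiring is not part of this diff) and using a placeholder model name:

    import os

    # Assumption: the flag is backed by an environment variable of the same name;
    # set it before vLLM reads its env config so the TPU runner takes the
    # compute_logits_no_sampler (greedy argmax) path instead of the sampler.
    os.environ["VLLM_TPU_DISABLE_SAMPLER_DEBUG"] = "1"

    from vllm import LLM

    llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")  # placeholder model
    out = llm.generate("Hello, TPU!")
    print(out[0].outputs[0].text)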
@@ -936,6 +941,18 @@ class ModelWrapperV1(nn.Module):
         logits = self.model.compute_logits(hidden_states, None)
         return logits
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def compute_logits_no_sampler(
+        self,
+        hidden_states: torch.Tensor,
+        logits_indices: torch.Tensor,
+        sampling_metadata,
+    ) -> Optional[torch.Tensor]:
+        hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+        selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        return selected_token_ids
+
     def get_multimodal_embeddings(self, *args, **kwargs):
         return self.model.get_multimodal_embeddings(*args, **kwargs)
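To make the shape behavior of the new helper concrete, here is a standalone CPU sketch of the same greedy selection: a plain linear head stands in for self.model.compute_logits, the openxla torch.compile decorator is omitted, and all sizes are illustrative rather than taken from a real model:

    import torch

    hidden_size, vocab_size, padded_tokens = 16, 32, 8
    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)  # stand-in for compute_logits

    hidden_states = torch.randn(padded_tokens, hidden_size)  # padded decoder output
    logits_indices = torch.tensor([0, 3, 5])                 # last-token position per request

    # Gather only the positions we need logits for, project to the vocabulary,
    # and take the greedy argmax with keepdim=True so the result is [num_reqs, 1].
    gathered = hidden_states[logits_indices]
    logits = lm_head(gathered)
    selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    print(selected_token_ids.shape)  # torch.Size([3, 1])

The caller then moves this result to the host and slices off padding with selected_token_ids.cpu()[:num_reqs], keeping that dynamic-shape step outside the compiled XLA graph, as the comment in the first hunk notes.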