mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 00:07:12 +08:00
updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
parent
24f68342b4
commit
4c42267293
@ -599,11 +599,6 @@ class TPUModelRunner:
|
||||
input_ids = self.input_ids
|
||||
inputs_embeds = None
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
|
||||
# are copied to device in chunks of pre-compiled padded shape to
|
||||
# avoid recompilations.
|
||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
||||
from_input_batch(self.input_batch, logits_indices)
|
||||
|
||||
# Temporary debug pathway.
|
||||
if envs.VLLM_TPU_DISABLE_SAMPLER_DEBUG:
|
||||
@ -618,6 +613,11 @@ class TPUModelRunner:
|
||||
hidden_states, logits_indices, None)
|
||||
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
|
||||
else:
|
||||
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
|
||||
# are copied to device in chunks of pre-compiled padded shape to
|
||||
# avoid recompilations.
|
||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
||||
from_input_batch(self.input_batch, logits_indices)
|
||||
with set_forward_context(attn_metadata, self.vllm_config):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user