Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:55:55 +08:00)
[TPU] Fix greedy decoding (#6933)

parent: af647fb8b3
commit: 6e063ea35b
@@ -28,7 +28,9 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

-_PAD_SLOT_ID = -1  # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
+# Here we utilize the behavior that out-of-bound index is ignored.
+# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
+_PAD_SLOT_ID = 1_000_000_000
 # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
 _ENABLE_TOP_P = False
 # FIXME(woosuk): A temporary hack to support `n > 1`.
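For context on the new sentinel: a minimal sketch (hypothetical helper, not from this commit) of how a huge pad slot ID is used to pad slot mappings to a static length, relying on the behavior noted above that out-of-bound indices are ignored on XLA.

import torch

_PAD_SLOT_ID = 1_000_000_000

def pad_slot_mapping(slot_mapping: list[int], max_len: int) -> torch.Tensor:
    # Pad to a fixed length so tensor shapes stay static for XLA compilation.
    # KV-cache writes addressed to the pad slots point far past the end of the
    # cache; on TPU/XLA such out-of-bound writes are dropped, so the padded
    # positions are effectively no-ops (the behavior the commit's comment
    # refers to).
    padded = slot_mapping + [_PAD_SLOT_ID] * (max_len - len(slot_mapping))
    return torch.tensor(padded, dtype=torch.int64)

print(pad_slot_mapping([3, 4, 5], max_len=8))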
@@ -414,10 +416,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         best_of = []
         for seq_group_metadata in seq_group_metadata_list:
             sampling_params = seq_group_metadata.sampling_params
-            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
-            # low temperature. This is not accurate.
-            t.append(sampling_params.temperature
-                     if sampling_params.temperature >= 1e-5 else 1e-5)
+            t.append(sampling_params.temperature)
             if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                 raise NotImplementedError(
                     "Top-p sampling is currently disabled for the TPU backend "
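The removed lines approximated greedy decoding by clamping the temperature to 1e-5 and then sampling. A standalone illustration (mine, not from the commit) of why that is not exact when logits are nearly tied:

import torch

# Two near-tied logits: even at temperature 1e-5 the distribution is far from
# one-hot, so multinomial sampling can return a token that argmax never would.
logits = torch.tensor([[10.000001, 10.0, 1.0]])
probs = torch.softmax(logits / 1e-5, dim=-1)
print(probs)                         # roughly [[0.52, 0.48, 0.00]]
print(torch.argmax(logits, dim=-1))  # tensor([0]): the true greedy token
print(torch.multinomial(probs, num_samples=1))  # index 1 almost half the time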
@@ -678,13 +677,23 @@ class ModelWrapper(nn.Module):
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)

-        logits = logits / t.unsqueeze(dim=1)
+        # Argmax sampling.
+        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
+
+        # Zero temperature means greedy decoding. Avoid division by zero.
+        nonzero_t = torch.where(t != 0, t, 1.0)
+        logits = logits / nonzero_t.unsqueeze(dim=1)
         if _ENABLE_TOP_P:
             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+
+        # Random sampling.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples,
-                                           replacement=True)
+        sampled_token_ids = torch.multinomial(probs,
+                                              num_samples,
+                                              replacement=True)
+        next_token_ids = torch.where(t != 0, sampled_token_ids,
+                                     argmax_token_ids)
         return next_token_ids
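For reference, a self-contained sketch of the selection logic above (illustrative names and shapes, not the actual vLLM ModelWrapper): both the argmax tokens and the multinomial samples are computed unconditionally, and torch.where picks per sequence, which keeps the XLA graph free of data-dependent control flow.

import torch

def sample(logits: torch.Tensor, t: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Argmax sampling, replicated so the shape matches the sampled tokens.
    argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

    # Zero temperature means greedy decoding; substitute 1.0 to avoid a
    # division by zero (those rows' samples are discarded below anyway).
    nonzero_t = torch.where(t != 0, t, torch.ones_like(t))
    probs = torch.softmax(logits / nonzero_t.unsqueeze(dim=1), dim=-1)
    sampled_token_ids = torch.multinomial(probs, num_samples, replacement=True)

    # Greedy rows (t == 0) take the argmax token; all others keep the sample.
    return torch.where(t.unsqueeze(dim=1) != 0, sampled_token_ids, argmax_token_ids)

logits = torch.tensor([[1.0, 3.0, 2.0], [0.1, 0.2, 0.3]])
t = torch.tensor([0.0, 0.8])             # first sequence greedy, second sampled
print(sample(logits, t, num_samples=1))  # first row is always token 1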