Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Hardware][TPU] Raise errors for unsupported sampling params (#5850)
This commit is contained in:
parent dd793d1de5
commit f178e56c68
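The checks added in this commit make the TPU backend reject sampling parameters it cannot honor, raising NotImplementedError while inputs are being prepared rather than accepting settings that are never actually applied. The standalone sketch below mirrors that fail-fast pattern; TPUSamplingParams and validate_for_tpu are hypothetical names used only for illustration and are not vLLM APIs.

from dataclasses import dataclass
from typing import Optional

_ENABLE_TOP_P = False  # mirrors the module-level flag introduced in this commit


@dataclass
class TPUSamplingParams:
    # Hypothetical stand-in for a sampling-params object, reduced to the
    # fields that the new TPU checks inspect.
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    best_of: int = 1
    use_beam_search: bool = False
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None


def validate_for_tpu(params: TPUSamplingParams) -> None:
    # Same fail-fast pattern as the diff: unsupported knobs raise
    # NotImplementedError instead of being silently dropped.
    if params.top_p != 1 and not _ENABLE_TOP_P:
        raise NotImplementedError("Top-p sampling is disabled for the TPU backend.")
    if params.top_k != -1:
        raise NotImplementedError("Top-k sampling is disabled for the TPU backend.")
    if params.best_of > 1:
        raise NotImplementedError("best_of > 1 is not supported by the TPU backend.")
    if params.use_beam_search:
        raise NotImplementedError("Beam search is not supported by the TPU backend.")
    if params.logprobs is not None or params.prompt_logprobs is not None:
        raise NotImplementedError("logprobs are not supported by the TPU backend.")


try:
    validate_for_tpu(TPUSamplingParams(top_k=50))
except NotImplementedError as exc:
    print(f"rejected: {exc}")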
@@ -20,6 +20,8 @@ from vllm.utils import make_tensor_with_pad
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = 0  # FIXME(woosuk)
+# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
+_ENABLE_TOP_P = False
 
 
 class TPUModelRunner:
@@ -339,9 +341,34 @@ class TPUModelRunner:
             assert seq_group_metadata.sampling_params is not None
             sampling_params = seq_group_metadata.sampling_params
 
+            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
+            # low temperature. This is not accurate.
             t.append(sampling_params.temperature
                      if sampling_params.temperature >= 1e-5 else 1e-5)
+            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
+                raise NotImplementedError(
+                    "Top-p sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
             p.append(sampling_params.top_p)
+            if sampling_params.top_k != -1:
+                raise NotImplementedError(
+                    "Top-k sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
+            if sampling_params.best_of > 1:
+                raise NotImplementedError(
+                    "best_of > 1 is not currently supported by the TPU "
+                    "backend.")
+            if sampling_params.use_beam_search:
+                raise NotImplementedError(
+                    "Beam search is not supported by the TPU backend.")
+            if sampling_params.logprobs is not None:
+                raise NotImplementedError(
+                    "logprobs is not currently supported by the TPU backend.")
+            if sampling_params.prompt_logprobs is not None:
+                raise NotImplementedError(
+                    "prompt_logprobs is not currently supported by the TPU "
+                    "backend.")
+
         num_paddings = padded_batch_size - len(seq_group_metadata_list)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
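The NOTE added in this hunk explains that greedy decoding is emulated by clamping the temperature to 1e-5 rather than taking a true argmax. A small PyTorch sketch (illustrative only, not vLLM code) shows why this works, and why the comment calls it inaccurate: near-ties between logits can still flip under sampling.

import torch

logits = torch.tensor([[2.0, 1.5, 0.5, -1.0]])
greedy = logits.argmax(dim=-1)                      # true argmax: token 0

# Clamp the temperature to 1e-5, as _prepare_sample does for greedy requests.
t = torch.tensor([1e-5])
probs = torch.softmax(logits / t.unsqueeze(dim=1), dim=-1)

print(greedy.item())   # 0
print(probs)           # ~[[1., 0., 0., 0.]]: nearly all mass on the argmax token
# With clearly separated logits, multinomial sampling from `probs` almost always
# returns the argmax; only near-ties behave differently, hence "not accurate".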
@@ -350,35 +377,32 @@ class TPUModelRunner:
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
         return t, p
 
-    def prepare_inputs(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ):
-        assert seq_group_metadata_list is not None
-        assert len(seq_group_metadata_list) > 0
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        if seq_group_metadata_list[0].is_prompt:
-            inputs = self._prepare_prompt(seq_group_metadata_list)
-        else:
-            inputs = self._prepare_decode(seq_group_metadata_list)
-        padded_batch_size = inputs[0].shape[0]
-        sample_inputs = self._prepare_sample(seq_group_metadata_list,
-                                             padded_batch_size)
-        return inputs + sample_inputs
-
     def _execute_model(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> List[CompletionSequenceGroupOutput]:
-        inputs = self.prepare_inputs(seq_group_metadata_list)
+        # Prepare inputs.
+        assert len(seq_group_metadata_list) > 0
+        # NOTE: We assume that all sequences in the group are all prompts or
+        # all decodes.
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        if is_prompt:
+            inputs = self._prepare_prompt(seq_group_metadata_list)
+        else:
+            inputs = self._prepare_decode(seq_group_metadata_list)
+        padded_batch_size = inputs[0].shape[0]
+        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
+
+        # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:])
-        if not self.is_driver_worker:
-            return []
+                                    *inputs[2:], t, p)
+        # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()
 
+        # NOTE(woosuk): Minimal code to construct the sampler outputs.
+        # The TPU backend does not reuse the sampler, since the TPU backend
+        # does not support the advanced sampling parameters such as logprobs.
         i = 0
         sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
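In the restructured _execute_model above, t and p are padded with 1.0 up to the padded batch size, the model call returns one token id per padded slot, and the index-based loop over seq_group_metadata_list suggests that only the leading entries correspond to real sequences. A rough bookkeeping sketch under those assumptions follows; fake_model and the concrete sizes are made up for illustration.

import torch


def fake_model(padded_batch_size: int) -> torch.Tensor:
    # Stand-in for the compiled TPU model call: one token id per padded slot.
    return torch.randint(0, 32000, (padded_batch_size,))


num_real_seqs = 3        # len(seq_group_metadata_list)
padded_batch_size = 8    # batch size after _prepare_prompt/_prepare_decode padding

# Mirrors _prepare_sample: real values first, neutral 1.0 for the padding slots.
t = torch.tensor([0.7] * num_real_seqs + [1.0] * (padded_batch_size - num_real_seqs))
p = torch.tensor([1.0] * padded_batch_size)

next_token_ids = fake_model(padded_batch_size).cpu().tolist()
real_token_ids = next_token_ids[:num_real_seqs]  # padding outputs are discarded
print(real_token_ids)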
@@ -400,6 +424,7 @@ class TPUModelRunner:
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> SamplerOutput:
         assert seq_group_metadata_list is not None
+        assert len(seq_group_metadata_list) > 0
         if seq_group_metadata_list[0].is_prompt:
             # NOTE(woosuk): To reduce the compilation time, we only compile the
             # prefill inputs with batch size 1. Because the scheduler is not
@@ -492,8 +517,8 @@ class ModelWrapper(nn.Module):
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
 
         logits = logits / t.unsqueeze(dim=1)
-        # FIXME(woosuk): Disabled top-p sampling since it's too slow.
-        # logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+        if _ENABLE_TOP_P:
+            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
         # FIXME(woosuk): best_of > 1 is not supported.
         next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
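The hunk above puts the _apply_top_p call behind the _ENABLE_TOP_P flag but does not show that helper. For reference, a generic nucleus (top-p) filter over logits usually looks like the sketch below; this is the standard masking scheme, not vLLM's implementation, and apply_top_p_sketch is a made-up name. The first sorted token is always kept, so at least one candidate survives filtering.

import torch


def apply_top_p_sketch(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Generic nucleus (top-p) masking over logits; p has shape [batch, 1]."""
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    # Probability mass of strictly higher-ranked tokens (exclusive cumsum).
    mass_before = probs.cumsum(dim=-1) - probs
    # Drop a token once the mass before it already covers p; the top token survives.
    sorted_logits = sorted_logits.masked_fill(mass_before > p, float("-inf"))
    # Undo the sort so the mask lands on the original vocabulary positions.
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_indices, sorted_logits)


logits = torch.randn(2, 8)
p = torch.tensor([0.9, 1.0]).unsqueeze(dim=1)  # same shape as p.unsqueeze(dim=1) above
probs = torch.softmax(apply_top_p_sketch(logits, p), dim=-1)
print(probs)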