From f178e56c68d97e3a29a8a885a09dd61f8d534732 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Tue, 25 Jun 2024 16:58:23 -0700
Subject: [PATCH] [Hardware][TPU] Raise errors for unsupported sampling params
 (#5850)

---
 vllm/worker/tpu_model_runner.py | 71 ++++++++++++++++++++++-----------
 1 file changed, 48 insertions(+), 23 deletions(-)

diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 2d8fffe5ac16..2c70c1f917a0 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -20,6 +20,8 @@ from vllm.utils import make_tensor_with_pad
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = 0  # FIXME(woosuk)
+# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
+_ENABLE_TOP_P = False
 
 
 class TPUModelRunner:
@@ -339,9 +341,34 @@ class TPUModelRunner:
             assert seq_group_metadata.sampling_params is not None
             sampling_params = seq_group_metadata.sampling_params
+
+            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
+            # low temperature. This is not accurate.
             t.append(sampling_params.temperature
                      if sampling_params.temperature >= 1e-5 else 1e-5)
+            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
+                raise NotImplementedError(
+                    "Top-p sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
             p.append(sampling_params.top_p)
+            if sampling_params.top_k != -1:
+                raise NotImplementedError(
+                    "Top-k sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
+            if sampling_params.best_of > 1:
+                raise NotImplementedError(
+                    "best_of > 1 is not currently supported by the TPU "
+                    "backend.")
+            if sampling_params.use_beam_search:
+                raise NotImplementedError(
+                    "Beam search is not supported by the TPU backend.")
+            if sampling_params.logprobs is not None:
+                raise NotImplementedError(
+                    "logprobs is not currently supported by the TPU backend.")
+            if sampling_params.prompt_logprobs is not None:
+                raise NotImplementedError(
+                    "prompt_logprobs is not currently supported by the TPU "
+                    "backend.")
 
         num_paddings = padded_batch_size - len(seq_group_metadata_list)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
@@ -350,35 +377,32 @@ class TPUModelRunner:
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
         return t, p
 
-    def prepare_inputs(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ):
-        assert seq_group_metadata_list is not None
-        assert len(seq_group_metadata_list) > 0
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        if seq_group_metadata_list[0].is_prompt:
-            inputs = self._prepare_prompt(seq_group_metadata_list)
-        else:
-            inputs = self._prepare_decode(seq_group_metadata_list)
-        padded_batch_size = inputs[0].shape[0]
-        sample_inputs = self._prepare_sample(seq_group_metadata_list,
-                                             padded_batch_size)
-        return inputs + sample_inputs
-
     def _execute_model(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> List[CompletionSequenceGroupOutput]:
-        inputs = self.prepare_inputs(seq_group_metadata_list)
+        # Prepare inputs.
+        assert len(seq_group_metadata_list) > 0
+        # NOTE: We assume that all sequences in the group are all prompts or
+        # all decodes.
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        if is_prompt:
+            inputs = self._prepare_prompt(seq_group_metadata_list)
+        else:
+            inputs = self._prepare_decode(seq_group_metadata_list)
+        padded_batch_size = inputs[0].shape[0]
+        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
+
+        # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:])
-        if not self.is_driver_worker:
-            return []
+                                    *inputs[2:], t, p)
+        # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()
+        # NOTE(woosuk): Minimal code to construct the sampler outputs.
+        # The TPU backend does not reuse the sampler, since the TPU backend
+        # does not support the advanced sampling parameters such as logprobs.
         i = 0
         sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
@@ -400,6 +424,7 @@ class TPUModelRunner:
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> SamplerOutput:
         assert seq_group_metadata_list is not None
+        assert len(seq_group_metadata_list) > 0
         if seq_group_metadata_list[0].is_prompt:
             # NOTE(woosuk): To reduce the compilation time, we only compile the
             # prefill inputs with batch size 1. Because the scheduler is not
@@ -492,8 +517,8 @@ class ModelWrapper(nn.Module):
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
 
         logits = logits / t.unsqueeze(dim=1)
-        # FIXME(woosuk): Disabled top-p sampling since it's too slow.
-        # logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+        if _ENABLE_TOP_P:
+            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
         # FIXME(woosuk): best_of > 1 is not supported.
         next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
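
Note (not part of the patch): the NOTE in _prepare_sample clamps the temperature to 1e-5 because ModelWrapper.forward always samples with torch.multinomial, so greedy decoding is only approximated by making the softmax nearly one-hot. A minimal standalone sketch of that effect, with made-up logit values for illustration:

    import torch

    # Dividing logits by a tiny temperature makes the softmax effectively
    # one-hot, so multinomial sampling almost always returns the argmax token.
    logits = torch.tensor([[2.0, 1.5, 0.5]])  # hypothetical logits for one sequence
    t = torch.tensor([1e-5])                  # clamped temperature, as in the patch
    probs = torch.softmax(logits / t.unsqueeze(dim=1), dim=-1)
    print(probs)                                    # ~[[1., 0., 0.]]
    print(torch.multinomial(probs, num_samples=1))  # index 0 with near certainty

This is why the comment calls the trick "not accurate": it approximates argmax rather than computing it exactly.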
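
Note (not part of the patch): _apply_top_p, which the _ENABLE_TOP_P flag now gates in ModelWrapper.forward, is not shown in these hunks. The sketch below is an assumption about what a batched top-p (nucleus) mask with the same call shape could look like; the helper name _top_p_mask is hypothetical and the code is not copied from the file:

    import torch


    def _top_p_mask(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        """Mask tokens outside the top-p nucleus.

        logits is (batch, vocab); p is (batch, 1), matching the
        p.unsqueeze(dim=1) call shape used in the patch.
        """
        # Sort each row so cumulative probabilities can be computed.
        sorted_logits = torch.sort(logits, dim=-1, descending=True).values
        cum_probs = torch.cumsum(sorted_logits.softmax(dim=-1), dim=-1)
        # Find the last logit that still belongs to the nucleus per row.
        cutoff_index = torch.searchsorted(cum_probs, p, right=True)
        cutoff_index = cutoff_index.clamp(max=logits.shape[-1] - 1)
        cutoff_logit = torch.gather(sorted_logits, dim=-1, index=cutoff_index)
        # Tokens below the cutoff get -inf so softmax assigns them zero mass.
        return logits.masked_fill(logits < cutoff_logit, float("-inf"))

Keeping the mask as a few sort/cumsum/gather ops is TPU-friendly in principle, but per the FIXME it was still too slow in practice, which is why the flag defaults to False and top_p != 1 is rejected in _prepare_sample instead.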