diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
index c6b492b5a3cc..57c195982ca8 100644
--- a/tests/v1/tpu/test_sampler.py
+++ b/tests/v1/tpu/test_sampler.py
@@ -61,3 +61,51 @@ def test_sampler_different(model_name: str):
     # to have deterministic results over many tokens, tests the first ~20
     # tokens match.
     assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20]
+
+
+@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
+# TODO TPU will appear busy if we fan-out test params here
+@pytest.mark.parametrize("n_prompts", [1])
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This test needs a TPU")
+def test_logprobs(model_name: str, n_prompts: int):
+    """
+    Request top logprobs with different sampling settings and check
+    that results contain the requested number, ordered by rank.
+    """
+
+    def check_num_logprobs(logprobs, expected_num: int):
+        for step in logprobs:
+            prev_logp = 1.0
+            # order by rank
+            sorted_step = dict(
+                sorted(step.items(), key=lambda item: item[1].rank))
+
+            # Can contain the sampled token
+            assert len(step) == expected_num or len(step) == expected_num + 1
+            # Check results are ordered by prob value
+            for rankno, (tid, logp) in enumerate(sorted_step.items()):
+                assert logp.logprob <= prev_logp
+                prev_logp = logp.logprob
+                assert logp.rank == rankno + 1
+
+    llm = LLM(model_name,
+              enforce_eager=False,
+              max_num_seqs=1,
+              max_model_len=128,
+              max_num_batched_tokens=128)
+    prompts = [
+        "Write a short story about a robot that dreams for the first time."
+    ] * n_prompts
+    greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,
+                                            logprobs=4)
+    regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
+                                             logprobs=4)
+    topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
+                                           logprobs=4, top_k=12, top_p=0.5)
+
+    for sp in [greedy_sampling_params, regular_sampling_params,
+               topkp_sampling_params]:
+        output = llm.generate(prompts, sp)
+        for o in output:
+            check_num_logprobs(o.outputs[0].logprobs, 4)
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index d4ea8c2dee07..a1c7dcdb111f 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -31,8 +31,10 @@ class TPUSupportedSamplingMetadata:
 
     all_greedy: bool = True
 
-    # unsupported, you need to return an extra tensor of static size BxV
-    max_num_logprobs = None
+    # Whether logprobs are to be gathered in this batch of requests. To
+    # balance compile time and runtime, a fixed `max_logprobs` value is used
+    # when gathering logprobs, regardless of the values specified per request.
+    logprobs: bool = False
 
     # TODO No penalties for now
     no_penalties: bool = True
@@ -84,10 +86,12 @@ class TPUSupportedSamplingMetadata:
         we want to pre-compile a graph with sampling parameters, even if
         they are not strictly needed for greedy decoding.
""" + needs_logprobs = input_batch.max_num_logprobs>0 if \ + input_batch.max_num_logprobs else False # Early return to avoid unnecessary cpu to tpu copy if (input_batch.all_greedy is True and generate_params_if_all_greedy is False): - return cls(all_greedy=True) + return cls(all_greedy=True, logprobs=needs_logprobs) num_reqs = input_batch.num_reqs @@ -115,4 +119,5 @@ class TPUSupportedSamplingMetadata: top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to( xla_device), min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( - xla_device)) + xla_device), + logprobs=needs_logprobs) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 33526c003a24..7c31a2984b30 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -22,27 +22,18 @@ class Sampler(nn.Module): logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata, ) -> SamplerOutput: - # NOTE(woosuk): Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. - # This is different from the V0 sampler, which uses the logits that - # is used for sampling (after penalties and temperature scaling). - # Use float32 for the logits. logits = logits.to(torch.float32) # Sample the next token. sampled = self.sample(logits, sampling_metadata) - # Use int32 to reduce the tensor size. - sampled = sampled.to(torch.int32) - - # These are GPU tensors. + # These are TPU tensors. sampler_output = SamplerOutput( # The sampled tokens are expanded to 2D tensor with shape # [num_requests, 1], where each row represents one generated # token per request. sampled_token_ids=sampled.unsqueeze(-1), - logprobs_tensors=None, - ) + logprobs_tensors=None) return sampler_output def apply_temperature( @@ -50,7 +41,6 @@ class Sampler(nn.Module): logits: torch.Tensor, temp: torch.Tensor, ) -> torch.Tensor: - # Use in-place division to avoid creating a new tensor. return logits.div_(temp.unsqueeze(dim=1)) def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index d716542f7898..8e162d5170d6 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -791,8 +791,18 @@ class TPUModelRunner: arange) selected_token_ids = self.sample_from_logits(logits, tpu_sampling_metadata) + + # NOTE (NickLucche) Use the original logits (before any penalties or + # temperature scaling) for the top-k logprobs. We can't enforce it due + # to recompilations outside torch.compiled code, so just make sure + # `sample_from_logits` does not modify the logits in-place. + logprobs = self.gather_logprobs(logits, selected_token_ids) \ + if tpu_sampling_metadata.logprobs else None + # Remove padding on cpu and keep dynamic op outside of xla graph. selected_token_ids = selected_token_ids.cpu()[:num_reqs] + logprobs_lists = logprobs.tolists() \ + if tpu_sampling_metadata.logprobs else None # Update the cache state concurrently. Code above will not block until # we use `selected_token_ids`. 
@@ -862,7 +872,7 @@ class TPUModelRunner:
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=None,
-            logprobs=None,
+            logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
 
@@ -1121,6 +1131,22 @@ class TPUModelRunner:
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info("  -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
@@ -1131,6 +1157,7 @@ class TPUModelRunner:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
@@ -1254,6 +1281,10 @@ class TPUModelRunner:
     def sample_from_logits(
            self, logits: torch.Tensor,
            sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
+        """
+        Sample with an xla-friendly function. This function is traced
+        separately from `forward` for lighter compilation overhead.
+        """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
@@ -1261,6 +1292,20 @@ class TPUModelRunner:
                                       sampling_metadata).sampled_token_ids
         return out_tokens
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top logprobs with corresponding tokens. Use a fixed number
+        of logprobs as an alternative to having multiple pre-compiled graphs.
+        The number of logprobs requested by each sequence is selected on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
                           grammar_bitmask: torch.Tensor, logits: torch.Tensor,