diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py
index ca5c067b364e0..05751badc7619 100644
--- a/tests/v1/tpu/test_topk_topp_sampler.py
+++ b/tests/v1/tpu/test_topk_topp_sampler.py
@@ -6,8 +6,12 @@ import pytest
 import torch
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
-                                                  apply_top_k_top_p_tpu)
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+
+# isort: off
+from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as
+                                        apply_top_k_top_p_tpu)
+# isort: on
 
 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 7bd4a5a380ac0..cc5653b10ec1d 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -73,10 +73,8 @@ class TopKTopPSampler(nn.Module):
                 self.forward = self.forward_native
         else:
             self.forward = self.forward_native
-        if current_platform.is_tpu():
-            self.apply_top_k_top_p = apply_top_k_top_p_tpu
-        else:
-            self.apply_top_k_top_p = apply_top_k_top_p
+
+        self.apply_top_k_top_p = apply_top_k_top_p
 
     def forward_native(
         self,
@@ -125,53 +123,6 @@ class TopKTopPSampler(nn.Module):
         return flashinfer_sample(logits.contiguous(), k, p, generators), None
 
 
-def apply_top_k_top_p_tpu(
-    logits: torch.Tensor,
-    k: torch.Tensor,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Apply top-k and top-p optimized for TPU.
-
-    This algorithm avoids using torch.scatter which is extremely slow on TPU.
-    This is achieved by finding a "cut-off" element in the original logit, and
-    after thresholding the logit using this cut-off, the remaining elements
-    shall constitute the top-p set.
-
-    Note: in the case of tie (i.e. multipple cut-off elements present in the
-    logit), all tie elements are included in the top-p set. In other words,
-    this function does not break ties. Instead, these tie tokens have equal
-    chance of being chosen during final sampling, so we can consider the tie
-    being broken then.
-    """
-    probs = logits.softmax(dim=-1)
-    probs_sort, _ = probs.sort(dim=-1, descending=False)
-
-    if k is not None:
-        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
-        top_k_count = top_k_count.unsqueeze(dim=1)
-        top_k_cutoff = probs_sort.gather(-1, top_k_count)
-
-        # Make sure the no top-k rows are no-op.
-        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
-        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
-
-        elements_to_discard = probs < top_k_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    if p is not None:
-        cumprob = torch.cumsum(probs_sort, dim=-1)
-        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
-        top_p_mask[:, -1] = False  # at least one
-
-        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
-        top_p_cutoff = probs_sort.gather(-1, top_p_count)
-        elements_to_discard = probs < top_p_cutoff
-        logits.masked_fill_(elements_to_discard, -float("inf"))
-
-    return logits
-
-
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py
index e84136e3a6d07..17b83a4ba074c 100644
--- a/vllm/v1/sample/tpu/sampler.py
+++ b/vllm/v1/sample/tpu/sampler.py
@@ -2,11 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Sampler layer implementing TPU supported operations."""
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
-from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 
 _SAMPLING_EPS = 1e-5
@@ -17,7 +18,6 @@ class Sampler(nn.Module):
     def __init__(self):
        # TODO(houseroad): Add support for logprobs_mode.
        super().__init__()
-        self.topk_topp_sampler = TopKTopPSampler()
 
     def forward(
         self,
@@ -65,13 +65,17 @@ class Sampler(nn.Module):
             logits = self.apply_min_p(logits, sampling_metadata.min_p)
 
         # Apply top_k and/or top_p.
-        random_sampled, _ = self.topk_topp_sampler(
+        logits = apply_top_k_top_p(
             logits,
-            sampling_metadata.generators,
             sampling_metadata.top_k,
             sampling_metadata.top_p,
         )
 
+        # Random sample.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
+        random_sampled = self.random_sample(probs,
+                                            sampling_metadata.generators)
+
         sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
                               greedy_sampled, random_sampled)
         return sampled
@@ -144,3 +148,66 @@ class Sampler(nn.Module):
         # Apply mask using boolean indexing (xla friendly)
         logits.masked_fill_(~valid_token_mask, -float("inf"))
         return logits
+
+    def random_sample(
+        self,
+        probs: torch.Tensor,
+        generators: dict[int, torch.Generator],
+    ) -> torch.Tensor:
+        q = torch.empty_like(probs)
+        # NOTE(woosuk): To batch-process the requests without their own seeds,
+        # which is the common case, we first assume that every request does
+        # not have its own seed. Then, we overwrite the values for the requests
+        # that have their own seeds.
+        q.exponential_()
+        if generators:
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+        return probs.div_(q).argmax(dim=-1).view(-1)
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    Apply top-k and top-p optimized for TPU.
+
+    This algorithm avoids using torch.scatter which is extremely slow on TPU.
+    This is achieved by finding a "cut-off" element in the original logit, and
+    after thresholding the logit using this cut-off, the remaining elements
+    shall constitute the top-p set.
+
+    Note: in the case of a tie (i.e. multiple cut-off elements present in the
+    logit), all tie elements are included in the top-p set. In other words,
+    this function does not break ties. Instead, these tie tokens have equal
+    chance of being chosen during final sampling, so we can consider the tie
+    being broken then.
+    """
+    probs = logits.softmax(dim=-1)
+    probs_sort, _ = probs.sort(dim=-1, descending=False)
+
+    if k is not None:
+        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch, )
+        top_k_count = top_k_count.unsqueeze(dim=1)
+        top_k_cutoff = probs_sort.gather(-1, top_k_count)
+
+        # Make sure the no top-k rows are no-op.
+        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
+        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
+
+        elements_to_discard = probs < top_k_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    if p is not None:
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    return logits
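
Below is a small, self-contained sketch (plain PyTorch, runs on CPU) of the two techniques this diff consolidates in vllm/v1/sample/tpu/sampler.py: the scatter-free top-k/top-p cut-off and the exponential-noise random sampling used by Sampler.random_sample. The toy tensors and the helper name toy_apply_top_k_top_p are illustrative only and not part of the change; with vLLM installed, the real entry point is "from vllm.v1.sample.tpu.sampler import apply_top_k_top_p", exactly as the updated test imports it.

import torch


def toy_apply_top_k_top_p(logits: torch.Tensor, k: torch.Tensor,
                          p: torch.Tensor) -> torch.Tensor:
    # Same cut-off idea as the moved helper: sort probabilities ascending,
    # pick a per-row threshold, and mask everything strictly below it.
    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)

    # Top-k: the k-th largest probability sits (vocab_size - k) positions
    # from the left of the ascending sort.
    top_k_count = (probs_sort.size(1) - k.to(torch.long)).unsqueeze(1)
    top_k_cutoff = probs_sort.gather(-1, top_k_count)
    # Rows with k == vocab_size get a -inf cut-off, i.e. no filtering.
    top_k_cutoff.masked_fill_((k == logits.shape[1]).unsqueeze(1),
                              -float("inf"))
    logits = logits.masked_fill(probs < top_k_cutoff, -float("inf"))

    # Top-p: count how much of the ascending cumulative mass can be dropped,
    # then threshold at the first probability that must be kept.
    cumprob = torch.cumsum(probs_sort, dim=-1)
    top_p_mask = cumprob <= 1 - p.unsqueeze(1)
    top_p_mask[:, -1] = False  # always keep at least one token
    top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
    top_p_cutoff = probs_sort.gather(-1, top_p_count)
    logits = logits.masked_fill(probs < top_p_cutoff, -float("inf"))
    return logits


torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -2.0],
                       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
k = torch.tensor([2, 6])      # row 0: keep top-2; row 1: k == vocab, no-op
p = torch.tensor([0.9, 1.0])  # row 1: p == 1.0, no-op

filtered = toy_apply_top_k_top_p(logits.clone(), k, p)
print(filtered)  # row 0 keeps only its two largest logits, row 1 is untouched

# Exponential-noise sampling, as in Sampler.random_sample: divide the
# probabilities by i.i.d. Exponential(1) noise and take the argmax.
probs = filtered.softmax(dim=-1)
q = torch.empty_like(probs).exponential_()
print(probs.div(q).argmax(dim=-1))  # masked (-inf) tokens are never selected

The last two lines mirror random_sample: taking the argmax of probs / Exponential(1) noise is an equivalent form of the Gumbel-max trick, so each row is a proper draw from its filtered categorical distribution, and per-request seeding only needs to overwrite the corresponding rows of q.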