diff --git a/tests/v1/sample/test_topk_topp_sampler.py b/tests/v1/sample/test_topk_topp_sampler.py
new file mode 100644
index 0000000000000..8a5076412cfae
--- /dev/null
+++ b/tests/v1/sample/test_topk_topp_sampler.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+from torch import Generator
+
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+
+DEVICE = "cuda"
+
+BATCH_SIZE = 1024
+VOCAB_SIZE = 128 * 1024
+
+
+def test_topk_impl_equivalence():
+
+    with torch.device(DEVICE):
+        generator = Generator(device=DEVICE).manual_seed(33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+        # Random top-k values between 1 and 9.
+        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(
+            torch.randint(0,
+                          2, (BATCH_SIZE, ),
+                          generator=generator,
+                          dtype=bool), VOCAB_SIZE)
+
+        # Top-k only implementation.
+        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+
+        # Top-p + top-k implementation, with top-p set to a no-op.
+        no_op_top_p = torch.tensor([1.0])
+        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+
+        assert torch.allclose(result1, result2)
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index 1dea711874bfd..5dfcae08b170c 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -19,6 +19,12 @@ except ImportError:
 
 
 class TopKTopPSampler(nn.Module):
+    """
+    Module that performs optional top-k and top-p filtering followed by
+    weighted random sampling of logits.
+
+    Implementations may update the logits tensor in-place.
+    """
 
     def __init__(self):
         super().__init__()
@@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        """PyTorch-native implementation of top-k and top-p sampling."""
+        """
+        PyTorch-native implementation of top-k and top-p sampling.
+
+        The logits tensor may be updated in-place.
+        """
         logits = apply_top_k_top_p(logits, k, p)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
@@ -136,10 +146,18 @@ def apply_top_k_top_p(
 ) -> torch.Tensor:
     """Apply top-k and top-p masks to the logits.
 
-    This function sorts the logits tensor, which can be slow for large batches.
+    If top-p is used, this function sorts the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
     """
-    if k is None and p is None:
-        return logits
+    if p is None:
+        if k is None:
+            return logits
+
+        # Avoid sorting the vocab for the top-k only case.
+        return apply_top_k_only(logits, k)
+
     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
 
     if k is not None:
@@ -153,7 +171,7 @@
     if p is not None:
         # Apply top-p.
         probs_sort = logits_sort.softmax(dim=-1)
-        probs_sum = probs_sort.cumsum(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
         top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
         # at least one
         top_p_mask[:, -1] = False
@@ -164,6 +182,33 @@
     return logits
 
 
+def apply_top_k_only(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply top-k mask to the logits.
+
+    This implementation doesn't involve sorting the entire vocab.
+
+    The logits tensor may be updated in-place.
+    """
+    no_top_k_mask = k == logits.shape[1]
+    # Set non-top-k rows to 1 so that we can gather.
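+    # (k == 1 here is only a placeholder that keeps the gather index below
+    # valid; these rows are restored to "no filtering" via no_top_k_mask.)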
+    k = k.masked_fill(no_top_k_mask, 1)
+    max_top_k = k.max()
+    # topk.values tensor has shape [batch_size, max_top_k].
+    # Convert top k to 0-based index in range [0, max_top_k).
+    k_index = k.sub_(1).unsqueeze(1)
+    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
+    # Handle non-top-k rows.
+    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
+    logits.masked_fill_(logits < top_k_mask, -float("inf"))
+    return logits
+
+
 def random_sample(
     probs: torch.Tensor,
     generators: dict[int, torch.Generator],
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 397a049dc2543..004f98496b0d7 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -87,6 +87,12 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
+        """Sample logits based on sampling metadata.
+
+        The various logits processing functions called in this method
+        may update the logits tensor in-place.
+        """
+
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
         if sampling_metadata.all_random:
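Below is a minimal, self-contained sketch (not part of the patch) of the
masking trick that apply_top_k_only introduces: instead of sorting the full
vocab, it computes only the max(k) largest logits per row, gathers each row's
k-th largest value as a cutoff, and masks everything below it. It assumes
stock PyTorch and runs on CPU; the toy sizes and values are illustrative only.

    import torch

    # Toy batch: 4 rows over a vocab of 10.
    logits = torch.randn(4, 10)
    # Per-row k; k == vocab_size (10) means top-k is disabled for that row.
    k = torch.tensor([2, 3, 10, 1])

    no_top_k_mask = k == logits.shape[1]
    # Placeholder k of 1 keeps the gather index valid for disabled rows.
    k = k.masked_fill(no_top_k_mask, 1)
    max_top_k = int(k.max())
    # Per-row cutoff = the k-th largest logit, gathered from topk.values.
    k_index = (k - 1).unsqueeze(1)
    cutoff = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
    # Disabled rows get a -inf cutoff, so none of their logits are masked.
    cutoff.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < cutoff, -float("inf"))

    # Barring ties at the cutoff, each row keeps exactly k finite logits.
    print(torch.isfinite(logits).sum(dim=1))  # tensor([ 2,  3, 10,  1])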