mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-05 10:35:39 +08:00
[V1][Sampler] Faster top-k only implementation (#15478)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
733e7c9e95
commit
35fad35a48
37
tests/v1/sample/test_topk_topp_sampler.py
Normal file
37
tests/v1/sample/test_topk_topp_sampler.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import torch
|
||||||
|
from torch import Generator
|
||||||
|
|
||||||
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||||
|
|
||||||
|
DEVICE = "cuda"
|
||||||
|
|
||||||
|
BATCH_SIZE = 1024
|
||||||
|
VOCAB_SIZE = 128 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
def test_topk_impl_equivalance():
|
||||||
|
|
||||||
|
with torch.device(DEVICE):
|
||||||
|
generator = Generator(device=DEVICE).manual_seed(33)
|
||||||
|
|
||||||
|
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
|
||||||
|
|
||||||
|
# Random top-k values between 1 and 9.
|
||||||
|
k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
|
||||||
|
|
||||||
|
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
|
||||||
|
k.masked_fill_(
|
||||||
|
torch.randint(0,
|
||||||
|
2, (BATCH_SIZE, ),
|
||||||
|
generator=generator,
|
||||||
|
dtype=bool), VOCAB_SIZE)
|
||||||
|
|
||||||
|
# Top-k only implementation
|
||||||
|
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
|
||||||
|
|
||||||
|
# Top-p + top-k
|
||||||
|
no_op_top_p = torch.tensor([1.0])
|
||||||
|
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
|
||||||
|
|
||||||
|
assert torch.allclose(result1, result2)
|
||||||
@ -19,6 +19,12 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
class TopKTopPSampler(nn.Module):
|
class TopKTopPSampler(nn.Module):
|
||||||
|
"""
|
||||||
|
Module that performs optional top-k and top-p filtering followed by
|
||||||
|
weighted random sampling of logits.
|
||||||
|
|
||||||
|
Implementations may update the logits tensor in-place.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -84,7 +90,11 @@ class TopKTopPSampler(nn.Module):
|
|||||||
k: Optional[torch.Tensor],
|
k: Optional[torch.Tensor],
|
||||||
p: Optional[torch.Tensor],
|
p: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""PyTorch-native implementation of top-k and top-p sampling."""
|
"""
|
||||||
|
PyTorch-native implementation of top-k and top-p sampling.
|
||||||
|
|
||||||
|
The logits tensor may be updated in-place.
|
||||||
|
"""
|
||||||
logits = apply_top_k_top_p(logits, k, p)
|
logits = apply_top_k_top_p(logits, k, p)
|
||||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||||
return random_sample(probs, generators)
|
return random_sample(probs, generators)
|
||||||
@ -136,10 +146,18 @@ def apply_top_k_top_p(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Apply top-k and top-p masks to the logits.
|
"""Apply top-k and top-p masks to the logits.
|
||||||
|
|
||||||
This function sorts the logits tensor, which can be slow for large batches.
|
If a top-p is used, this function will sort the logits tensor,
|
||||||
|
which can be slow for large batches.
|
||||||
|
|
||||||
|
The logits tensor may be updated in-place.
|
||||||
"""
|
"""
|
||||||
if k is None and p is None:
|
if p is None:
|
||||||
return logits
|
if k is None:
|
||||||
|
return logits
|
||||||
|
|
||||||
|
# Avoid sorting vocab for top-k only case.
|
||||||
|
return apply_top_k_only(logits, k)
|
||||||
|
|
||||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
@ -153,7 +171,7 @@ def apply_top_k_top_p(
|
|||||||
if p is not None:
|
if p is not None:
|
||||||
# Apply top-p.
|
# Apply top-p.
|
||||||
probs_sort = logits_sort.softmax(dim=-1)
|
probs_sort = logits_sort.softmax(dim=-1)
|
||||||
probs_sum = probs_sort.cumsum(dim=-1)
|
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
|
||||||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
||||||
# at least one
|
# at least one
|
||||||
top_p_mask[:, -1] = False
|
top_p_mask[:, -1] = False
|
||||||
@ -164,6 +182,31 @@ def apply_top_k_top_p(
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def apply_top_k_only(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply top-k mask to the logits.
|
||||||
|
|
||||||
|
This implementation doesn't involve sorting the entire vocab.
|
||||||
|
|
||||||
|
The logits tensor may be updated in-place.
|
||||||
|
"""
|
||||||
|
no_top_k_mask = k == logits.shape[1]
|
||||||
|
# Set non-top-k rows to 1 so that we can gather.
|
||||||
|
k = k.masked_fill(no_top_k_mask, 1)
|
||||||
|
max_top_k = k.max()
|
||||||
|
# topk.values tensor has shape [batch_size, max_top_k].
|
||||||
|
# Convert top k to 0-based index in range [0, max_top_k).
|
||||||
|
k_index = k.sub_(1).unsqueeze(1)
|
||||||
|
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index)
|
||||||
|
# Handle non-topk rows.
|
||||||
|
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||||
|
logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def random_sample(
|
def random_sample(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
generators: dict[int, torch.Generator],
|
generators: dict[int, torch.Generator],
|
||||||
|
|||||||
@ -87,6 +87,12 @@ class Sampler(nn.Module):
|
|||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Sample logits based on sampling metadata.
|
||||||
|
|
||||||
|
The various logits processing functions called in this method
|
||||||
|
may update the logits tensor in-place.
|
||||||
|
"""
|
||||||
|
|
||||||
assert not (sampling_metadata.all_greedy
|
assert not (sampling_metadata.all_greedy
|
||||||
and sampling_metadata.all_random)
|
and sampling_metadata.all_random)
|
||||||
if sampling_metadata.all_random:
|
if sampling_metadata.all_random:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user