Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[V1][Sampler] Improve performance of FlashInfer sampling by sampling logits instead of probs (#18608)
commit e7523c2e03
parent a869baca73
@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module):
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
         if k is None and p is None:
             # We prefer `random_sample` over `flashinfer_sample` when sorting is
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
             return random_sample(probs, generators)
         if generators:
             logger.warning("FlashInfer 0.2.3+ does not support "
                            "per-request generators. Falling back to "
                            "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
-        return flashinfer_sample(probs, k, p, generators)
+        return flashinfer_sample(logits, k, p, generators)
 
     def forward_tpu(
         self,
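
The hunk above hoists the softmax into the `random_sample` branch: only the PyTorch-native sampler still needs an explicit probability tensor, while the FlashInfer path now receives raw logits. As a hedged illustration of why sampling can operate on logits directly, the toy script below (not part of the diff, and not FlashInfer's actual rejection-sampling kernel) compares multinomial sampling from softmax(logits) with Gumbel-max sampling on the logits themselves; the two produce statistically equivalent token distributions.

    # Toy check (assumptions: PyTorch only; an illustration of statistical
    # equivalence, not the FlashInfer kernel used by this commit).
    import torch

    torch.manual_seed(0)
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    n = 200_000

    # Path 1: explicit softmax, then sample token ids from the probabilities.
    probs = logits.softmax(dim=-1)
    ids_from_probs = torch.multinomial(probs, n, replacement=True)

    # Path 2: Gumbel-max sampling directly from the logits, no softmax needed.
    gumbel = -torch.log(-torch.log(torch.rand(n, 4)))
    ids_from_logits = (logits + gumbel).argmax(dim=-1)

    # Both empirical distributions converge to softmax(logits),
    # roughly [0.61, 0.22, 0.14, 0.03] for this toy example.
    print(torch.bincount(ids_from_probs, minlength=4) / n)
    print(torch.bincount(ids_from_logits, minlength=4) / n)
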
@@ -254,17 +254,17 @@ def random_sample(
 
 
 def flashinfer_sample(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     k: Optional[torch.Tensor],
     p: Optional[torch.Tensor],
     generators: dict[int, torch.Generator],
 ) -> torch.Tensor:
-    """Sample from the probabilities using FlashInfer.
+    """Sample from the logits using FlashInfer.
 
     Statistically, this function is equivalent to the `random_sample` function.
     However, this function is faster because it avoids sorting the logits tensor
     via rejection sampling.
 
     NOTE: The outputs of this function do not necessarily match the outputs of
     the `random_sample` function. It only guarantees that the outputs are
     statistically equivalent.
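
The docstring claims statistical equivalence with `random_sample` while avoiding a per-step sort of the vocabulary. For contrast, here is a minimal sort-based top-p sampler (a sketch, not vLLM or FlashInfer code); the full-vocabulary sort and cumulative sum it runs on every decoding step is exactly the work the rejection-sampling kernels sidestep.

    import torch

    def naive_top_p_sample(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        """Sort-based top-p (nucleus) sampling reference, one token per row."""
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)  # full sort per row
        cumsum = sorted_probs.cumsum(dim=-1)
        # Drop tokens whose preceding cumulative mass already exceeds p
        # (the top-ranked token is always kept).
        mask = cumsum - sorted_probs > p.unsqueeze(-1)
        sorted_probs = sorted_probs.masked_fill(mask, 0.0)
        sampled = torch.multinomial(sorted_probs, num_samples=1)
        return sorted_idx.gather(-1, sampled).view(-1)

    logits = torch.randn(4, 32000)        # batch of 4 requests, 32k-token vocab
    p = torch.full((4,), 0.9)
    print(naive_top_p_sample(logits, p))  # one sampled token id per request
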
@@ -274,18 +274,19 @@ def flashinfer_sample(
     the synchronization overhead.
     """
     assert not (k is None and p is None)
-
     if k is None:
         # Top-p only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
             probs, p, deterministic=True)
     elif p is None:
         # Top-k only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
             probs, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            probs, k, p, deterministic=True))
+        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+            logits, k, p, deterministic=True)
 
     return next_token_ids.view(-1)
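
Call-site sketch under stated assumptions: `flashinfer_sample` is the function patched above, FlashInfer is installed, `logits` is a [num_reqs, vocab_size] CUDA tensor, and `k`/`p` are per-request tensors. The import path, shapes, and dtypes here are illustrative, not taken from the diff.

    import torch
    # Assumed import path for the patched function (hypothetical in this sketch).
    from vllm.v1.sample.ops.topk_topp_sampler import flashinfer_sample

    num_reqs, vocab_size = 8, 32000
    logits = torch.randn(num_reqs, vocab_size, device="cuda")
    k = torch.full((num_reqs,), 50, dtype=torch.int32, device="cuda")
    p = torch.full((num_reqs,), 0.95, device="cuda")

    # Before this commit the caller had to materialize probabilities first:
    #   probs = logits.softmax(dim=-1, dtype=torch.float32)
    #   next_token_ids = flashinfer_sample(probs, k, p, generators={})
    # After this commit the raw logits are passed through, and a softmax is only
    # computed inside flashinfer_sample for the top-p-only / top-k-only branches.
    next_token_ids = flashinfer_sample(logits, k, p, generators={})
    assert next_token_ids.shape == (num_reqs,)
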