[V1][Sampler] Improve performance of FlashInfer sampling by sampling logits instead of probs (#18608)

Lukas Geiger 2025-05-26 16:49:36 +01:00 committed by GitHub
parent a869baca73
commit e7523c2e03


@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module):
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
         if k is None and p is None:
             # We prefer `random_sample` over `flashinfer_sample` when sorting is
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
             return random_sample(probs, generators)
         if generators:
             logger.warning("FlashInfer 0.2.3+ does not support "
                            "per-request generators. Falling back to "
                            "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
-        return flashinfer_sample(probs, k, p, generators)
+        return flashinfer_sample(logits, k, p, generators)

     def forward_tpu(
         self,
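Before this change, `probs` was computed unconditionally at the top of the method, so every call materialized a full float32 probability tensor even when FlashInfer could consume the raw logits directly. The sketch below (illustrative only, not vLLM code; the batch and vocab sizes are assumptions) shows the size of the buffer the unconditional softmax allocates:

import torch

batch, vocab = 256, 128_000  # illustrative sizes, not from the commit
logits = torch.randn(batch, vocab, dtype=torch.bfloat16)

# Old path: unconditional softmax materializes a float32 [batch, vocab]
# buffer before any sampler runs (~125 MiB at these sizes).
probs = logits.softmax(dim=-1, dtype=torch.float32)
print(f"{probs.numel() * probs.element_size() / 2**20:.0f} MiB")

# New path: pass `logits` straight to flashinfer_sample; the softmax is
# computed only in the branches that still sample from probabilities.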
@@ -254,17 +254,17 @@ def random_sample(
 def flashinfer_sample(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     k: Optional[torch.Tensor],
     p: Optional[torch.Tensor],
     generators: dict[int, torch.Generator],
 ) -> torch.Tensor:
-    """Sample from the probabilities using FlashInfer.
+    """Sample from the logits using FlashInfer.

     Statistically, this function is equivalent to the `random_sample` function.
     However, this function is faster because it avoids sorting the logits tensor
     via rejection sampling.

     NOTE: The outputs of this function do not necessarily match the outputs of
     the `random_sample` function. It only guarantees that the outputs are
     statistically equivalent.
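Because FlashInfer's rejection sampler only matches `random_sample` in distribution, any comparison has to look at sampling frequencies rather than individual tokens. The CPU-only sketch below (an illustration of what "statistically equivalent" means for top-p, not vLLM test code; FlashInfer itself is not invoked) samples from the renormalized nucleus and checks that empirical frequencies converge to it:

import torch

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
probs = logits.softmax(dim=-1)

# Reference top-p (p=0.9) semantics: keep the smallest prefix of the
# sorted probabilities whose mass reaches p, renormalize, then sample.
sorted_p, idx = probs.sort(dim=-1, descending=True)
keep = sorted_p.cumsum(dim=-1) - sorted_p < 0.9
masked = torch.zeros_like(probs).scatter(1, idx, sorted_p * keep)
masked = masked / masked.sum(dim=-1, keepdim=True)

draws = torch.multinomial(masked.squeeze(0), 100_000, replacement=True)
freq = torch.bincount(draws, minlength=4).float() / draws.numel()
print(freq)               # empirical frequencies
print(masked.squeeze(0))  # target distribution they converge to

Any sampler whose empirical frequencies converge to the same masked distribution is "statistically equivalent" in the sense of the docstring, even if no two runs agree token-for-token.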
@@ -274,18 +274,19 @@ def flashinfer_sample(
     the synchronization overhead.
     """
     assert not (k is None and p is None)
     if k is None:
         # Top-p only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
             probs, p, deterministic=True)
     elif p is None:
         # Top-k only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
             probs, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            probs, k, p, deterministic=True))
+        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+            logits, k, p, deterministic=True)
     return next_token_ids.view(-1)
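For reference, the new combined branch can be exercised on its own roughly as follows. This is a sketch, not vLLM code: it assumes a CUDA device, FlashInfer 0.2.3+, and per-request `k`/`p` tensors; the call name and argument order mirror the diff above, but the tensor dtypes here are assumptions.

import torch
import flashinfer

batch, vocab = 4, 32_000
logits = torch.randn(batch, vocab, device="cuda", dtype=torch.float32)
# Per-request top-k / top-p tensors; dtypes are assumptions for this sketch.
k = torch.full((batch,), 50, device="cuda", dtype=torch.int32)
p = torch.full((batch,), 0.9, device="cuda", dtype=torch.float32)

# Single fused call on raw logits, mirroring the new `else` branch above:
# no explicit softmax and no sort of the vocab dimension.
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, k, p, deterministic=True)
print(next_token_ids.view(-1))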