[V1][Sampler] Improve performance of FlashInfer sampling by sampling logits instead of probs (#18608)
commit e7523c2e03 (parent a869baca73)
@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module):
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
         """More optimized implementation for top-k and top-p sampling."""
-        probs = logits.softmax(dim=-1, dtype=torch.float32)
         if k is None and p is None:
             # We prefer `random_sample` over `flashinfer_sample` when sorting is
             # not needed. This is because `random_sample` does not require
             # CPU-GPU synchronization while `flashinfer_sample` does.
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
             return random_sample(probs, generators)
         if generators:
             logger.warning("FlashInfer 0.2.3+ does not support "
                            "per-request generators. Falling back to "
                            "PyTorch-native implementation.")
             return self.forward_native(logits, generators, k, p)
-        return flashinfer_sample(probs, k, p, generators)
+        return flashinfer_sample(logits, k, p, generators)
 
     def forward_tpu(
         self,
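The comment kept in this hunk states the key trade-off: `random_sample` stays entirely on the GPU, while the FlashInfer sampling path involves a CPU-GPU synchronization. For context, a minimal sketch of a sync-free categorical sampler in plain PyTorch (an illustrative helper, not the `random_sample` defined in this file):

```python
import torch

def random_sample_sketch(probs: torch.Tensor) -> torch.Tensor:
    # Exponential-race trick: argmax(probs / E) with E ~ Exp(1) i.i.d.
    # draws from Categorical(probs) without any host round-trip.
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div(q).argmax(dim=-1).view(-1)
```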
@@ -254,12 +254,12 @@ def random_sample(
 
 
 def flashinfer_sample(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     k: Optional[torch.Tensor],
     p: Optional[torch.Tensor],
     generators: dict[int, torch.Generator],
 ) -> torch.Tensor:
-    """Sample from the probabilities using FlashInfer.
+    """Sample from the logits using FlashInfer.
 
     Statistically, this function is equivalent to the `random_sample` function.
     However, this function is faster because it avoids sorting the logits tensor
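The docstring's equivalence claim rests on a standard identity: adding i.i.d. Gumbel noise to logits and taking the argmax samples the same categorical distribution as softmax-then-multinomial, which is why sampling directly from logits is statistically the same as sampling from probs. A quick empirical check in plain PyTorch (a hypothetical test, not part of this commit):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(8)
probs = logits.softmax(dim=-1)

# Gumbel-max: argmax(logits + G) with G ~ Gumbel(0, 1) is distributed
# exactly like torch.multinomial(probs, 1).
u = torch.rand(100_000, 8)
gumbel = -torch.log(-torch.log(u))
freq = torch.bincount((logits + gumbel).argmax(dim=-1),
                      minlength=8).float() / 100_000
print(torch.allclose(freq, probs, atol=1e-2))  # expect True
```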
@@ -274,18 +274,19 @@ def flashinfer_sample(
     the synchronization overhead.
     """
     assert not (k is None and p is None)
 
     if k is None:
         # Top-p only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
             probs, p, deterministic=True)
     elif p is None:
         # Top-k only.
+        probs = logits.softmax(dim=-1, dtype=torch.float32)
         next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
             probs, k, deterministic=True)
     else:
         # Both top-k and top-p.
-        next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs(
-            probs, k, p, deterministic=True))
+        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
+            logits, k, p, deterministic=True)
 
     return next_token_ids.view(-1)
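With this hunk, the softmax runs only on the two branches whose FlashInfer kernels still consume probabilities; the combined top-k/top-p branch feeds raw logits straight to the kernel. A hedged usage sketch of the logits-based call named in the diff (assumes a CUDA build of FlashInfer 0.2.3+; tensor shapes and dtypes here are illustrative):

```python
import torch
import flashinfer

batch, vocab = 4, 32_000
logits = torch.randn(batch, vocab, device="cuda")
k = torch.full((batch,), 50, dtype=torch.int32, device="cuda")  # per-request top-k
p = torch.full((batch,), 0.9, device="cuda")                    # per-request top-p

# Both filters at once, no softmax: the kernel takes raw logits.
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
    logits, k, p, deterministic=True)
print(next_token_ids.view(-1))
```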