vllm/vllm/v1/sample/ops/topk_topp_sampler.py


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from packaging import version

from vllm import envs
from vllm.config.model import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform

logger = init_logger(__name__)


class TopKTopPSampler(nn.Module):
"""
Module that performs optional top-k and top-p filtering followed by
weighted random sampling of logits.
Implementations may update the logits tensor in-place.
"""

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
super().__init__()
self.logprobs_mode = logprobs_mode
# flashinfer optimization does not apply if intermediate
# logprobs/logits after top_k/top_p need to be returned
if (
logprobs_mode not in ("processed_logits", "processed_logprobs")
and current_platform.is_cuda()
):
if envs.VLLM_USE_FLASHINFER_SAMPLER:
# Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
logger.info_once(
"Using FlashInfer for top-p & top-k sampling.",
scope="global",
)
self.forward = self.forward_cuda
else:
logger.debug_once(
"FlashInfer top-p/top-k sampling is available but disabled "
"by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
"after verifying accuracy for your workloads."
)
self.forward = self.forward_native
elif current_platform.is_cpu():
arch = current_platform.get_cpu_architecture()
            # Fall back to the native implementation on PowerPC and RISC-V.
            # On PowerPC, argmax produces incorrect output with torch.compile.
            # PR: https://github.com/vllm-project/vllm/pull/26987
if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC):
self.forward = self.forward_native
else:
self.forward = self.forward_cpu
else:
self.forward = self.forward_native
self.apply_top_k_top_p = apply_top_k_top_p
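
    # Illustrative usage sketch (not part of the original module). The shapes
    # and values below are assumptions: `logits` is [num_requests, vocab_size],
    # and `k`/`p` carry one top-k / top-p value per request, with
    # k == vocab_size meaning "top-k disabled" for that request.
    #
    #   vocab_size = 32000
    #   sampler = TopKTopPSampler()
    #   logits = torch.randn(2, vocab_size)
    #   k = torch.tensor([50, vocab_size])  # Request 1: top-k disabled.
    #   p = torch.tensor([0.9, 1.0])        # Request 1: top-p effectively off.
    #   token_ids, _ = sampler(logits, generators={}, k=k, p=p)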

    def forward_native(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators), logits_to_return

    def forward_cuda(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""More optimized implementation for top-k and top-p sampling."""
# We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does.
if (k is None and p is None) or generators:
if generators:
logger.debug_once(
"FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation."
)
return self.forward_native(logits, generators, k, p)
assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
"FlashInfer does not support returning logits/logprobs"
)
        # flashinfer sampling functions expect contiguous logits.
        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
        # because of a slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators), None

    def forward_cpu(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
PyTorch-native implementation of top-k and top-p sampling for CPU.
The logits tensor may be updated in-place.
"""
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
elif self.logprobs_mode == "processed_logprobs":
logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
        if len(generators) != logits.shape[0]:
            # Common case: not every request carries its own generator, so use
            # the module-level torch.compile'd fast path below.
            return compiled_random_sample(logits), logits_to_return
        else:
            # Every request has its own generator: the unseeded exponential
            # fill is overwritten row by row with draws from each generator.
            probs = logits.softmax(dim=-1, dtype=torch.float32)
            q = torch.empty_like(probs)
            q.exponential_()
            for i, generator in generators.items():
                q[i].exponential_(generator=generator)
            return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return


# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
return probs.div(q).argmax(dim=-1).view(-1)
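

# Illustrative sketch (not part of the original module): because
# `compiled_random_sample` is compiled once at module scope, repeated
# `forward_cpu` calls reuse the same compiled function instead of triggering a
# fresh compile each time, and `dynamic=True` lets different batch sizes share
# it. The shapes below are arbitrary assumptions.
def _demo_compiled_random_sample() -> None:
    logits_a = torch.randn(4, 128)
    logits_b = torch.randn(7, 128)  # Different batch size, same compiled fn.
    assert compiled_random_sample(logits_a).shape == (4,)
    assert compiled_random_sample(logits_b).shape == (7,)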


def apply_top_k_top_p(
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
If a top-p is used, this function will sort the logits tensor,
which can be slow for large batches.
The logits tensor may be updated in-place.
"""
if p is None:
if k is None:
return logits
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
        # Gather the value of the k-th largest logit per row as the cutoff.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
if p is not None:
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # Keep at least one token per row.
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
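

# Illustrative sketch (not part of the original module): a tiny worked example
# of `apply_top_k_top_p`. With k=2 and p=1.0, only the two largest logits of
# the row stay finite; the values below are arbitrary assumptions.
def _demo_apply_top_k_top_p() -> None:
    logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])
    k = torch.tensor([2])
    p = torch.tensor([1.0])
    out = apply_top_k_top_p(logits.clone(), k, p)
    # Expected result: [-inf, 3.0, 2.0, -inf].
    assert torch.isinf(out[0, 0]) and torch.isinf(out[0, 3])
    assert torch.isfinite(out[0, 1]) and torch.isfinite(out[0, 2])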


def apply_top_k_only(
logits: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
The logits tensor may be updated in-place.
"""
no_top_k_mask = k == logits.shape[1]
# Set non-top-k rows to 1 so that we can gather.
k = k.masked_fill(no_top_k_mask, 1)
max_top_k = k.max()
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))
return logits
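

# Illustrative sketch (not part of the original module): `apply_top_k_only`
# treats k == vocab_size as "top-k disabled" for that row. The values below
# are arbitrary assumptions.
def _demo_apply_top_k_only() -> None:
    logits = torch.tensor([[1.0, 3.0, 2.0, 0.0], [1.0, 3.0, 2.0, 0.0]])
    k = torch.tensor([2, 4])  # Row 1: k equals the vocab size, so no masking.
    out = apply_top_k_only(logits.clone(), k)
    assert torch.isinf(out[0]).sum() == 2  # Two logits masked in row 0.
    assert torch.isfinite(out[1]).all()  # Row 1 left untouched.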


def random_sample(
probs: torch.Tensor,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Randomly sample from the probabilities.
We use this function instead of torch.multinomial because torch.multinomial
causes CPU-GPU synchronization.
"""
q = torch.empty_like(probs)
# NOTE(woosuk): To batch-process the requests without their own seeds,
# which is the common case, we first assume that every request does
# not have its own seed. Then, we overwrite the values for the requests
# that have their own seeds.
if len(generators) != probs.shape[0]:
q.exponential_()
if generators:
# TODO(woosuk): This can be slow because we handle each request
# one by one. Optimize this.
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1)
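

# Illustrative sketch (not part of the original module): dividing `probs` by
# i.i.d. Exponential(1) noise and taking the argmax is the "exponential race"
# (Gumbel-max style) trick. probs[i] / q[i] is largest exactly when
# q[i] / probs[i] is smallest, and q[i] / probs[i] ~ Exponential(rate=probs[i]),
# so index i wins with probability probs[i] (probs sums to one after softmax).
# The sample size and tolerance below are arbitrary assumptions.
def _demo_random_sample_distribution() -> None:
    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.7])
    draws = random_sample(probs.repeat(20000, 1), generators={})
    freqs = torch.bincount(draws, minlength=3).float() / draws.numel()
    assert torch.allclose(freqs, probs, atol=0.03)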


def flashinfer_sample(
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from the logits using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it uses rejection sampling
    instead of sorting the logits tensor.
NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
import flashinfer
if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
raise ImportError(
"FlashInfer version >= 0.2.3 required for top-k and top-p sampling. "
)
assert not (k is None and p is None)
if k is None:
# Top-p only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
probs, p, deterministic=True
)
elif p is None:
# Top-k only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
probs, k, deterministic=True
)
else:
# Both top-k and top-p.
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
logits, k, p, deterministic=True
)
return next_token_ids.view(-1)
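

# Illustrative usage sketch (not part of the original module); it assumes a
# CUDA device and flashinfer >= 0.2.3 are available, and the k/p values below
# are arbitrary per-request assumptions:
#
#   logits = torch.randn(2, 32000, device="cuda")
#   k = torch.tensor([50, 100], device="cuda", dtype=torch.int32)
#   p = torch.tensor([0.9, 0.95], device="cuda")
#   token_ids = flashinfer_sample(logits, k, p, generators={})  # shape: [2]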