mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 18:09:08 +08:00
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to Python source files, as
recommended to the project by the Linux Foundation. These headers provide
a concise way, both human and machine readable, of communicating license
information for each source file. They help avoid any ambiguity about the
license of the code and can also be easily used by tools to help manage
license compliance.
The Linux Foundation runs license scans against the codebase to help
ensure we are in compliance with the licenses of the code we use,
including dependencies. Having these headers in place helps that tool do
its job.
More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
204 lines
7.5 KiB
Python
204 lines
7.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from typing import Dict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm import envs
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import current_platform
|
|
|
|
logger = init_logger(__name__)

# FlashInfer is an optional dependency: probe for it once at import time and
# remember the outcome so the sampler can pick an implementation later.
try:
    import flashinfer.sampling
except ImportError:
    is_flashinfer_available = False
else:
    is_flashinfer_available = True
|
|
|
|
|
|
class TopKTopPSampler(nn.Module):
    """Module that applies top-k / top-p filtering and samples token ids.

    The concrete implementation is chosen once, at construction time, and
    bound to ``self.forward``:

    * ``forward_cuda`` (FlashInfer-based) when running on CUDA, FlashInfer is
      importable, and ``VLLM_USE_FLASHINFER_SAMPLER`` is not explicitly False;
    * ``forward_native`` (pure PyTorch) in every other case.
    """

    def __init__(self):
        super().__init__()
        # NOTE: `is_cuda` has to be *called*. A bare `current_platform.is_cuda`
        # is a bound method object, which is always truthy, so the check would
        # succeed on every platform and the final `else` would be dead code.
        if current_platform.is_cuda():
            if is_flashinfer_available:
                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling.

        Masks ``logits`` with ``apply_top_k_top_p``, converts the result to
        float32 probabilities over the last dimension, and draws one sample
        per row with ``random_sample``.

        Args:
            logits: Logits to sample from; softmax is taken over the last dim.
            generators: Per-row generators, keyed by row index, for requests
                that carry their own seed.
            no_top_k: True when top-k filtering should be skipped.
            k: Per-row top-k values (only used when ``no_top_k`` is False).
            no_top_p: True when top-p filtering should be skipped.
            p: Per-row top-p values (only used when ``no_top_p`` is False).

        Returns:
            The sampled indices, as returned by ``random_sample``.
        """
        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling.

        Takes the same arguments as ``forward_native``. Probabilities are
        computed first (float32 softmax over the last dim); filtering and
        sampling are then delegated to ``flashinfer_sample`` unless neither
        top-k nor top-p is requested.

        Returns:
            The sampled indices, from ``random_sample`` or
            ``flashinfer_sample``.
        """
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if no_top_k and no_top_p:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)
|
|
|
|
|
|
def apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    """Mask out every logit that falls outside the top-k / top-p selection.

    Filtered entries are set to ``-inf``; surviving entries keep their value
    and their original position. When both filters are disabled the input
    tensor itself is returned untouched.

    Note that the logits are sorted along the last dimension, which can be
    slow for large batches.

    Args:
        logits: 2-D tensor of logits (indexed as ``[row, vocab]``).
        no_top_k: Skip the top-k filter when True.
        k: Per-row top-k values; only read when ``no_top_k`` is False.
        no_top_p: Skip the top-p filter when True.
        p: Per-row top-p values; only read when ``no_top_p`` is False.

    Returns:
        A tensor shaped like ``logits`` with filtered entries set to ``-inf``.
    """
    if no_top_k and no_top_p:
        return logits

    neg_inf = -float("inf")
    # Ascending order: the largest logits sit at the end of each row.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Column of the k-th largest value in each (ascending) row.
        kth_pos = sorted_logits.size(1) - k.to(torch.long)
        kth_value = sorted_logits.gather(1, kth_pos.unsqueeze(dim=1))
        # Everything strictly below the k-th largest value is dropped.
        below_kth = sorted_logits < kth_value
        sorted_logits.masked_fill_(below_kth, neg_inf)

    if not no_top_p:
        cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Drop the low-probability tail whose total mass is at most 1 - p.
        tail = cumulative <= 1 - p.unsqueeze(dim=1)
        # Always keep the most likely token so at least one survives.
        tail[:, -1] = False
        sorted_logits.masked_fill_(tail, neg_inf)

    # Undo the sort: place every masked logit back at its original index.
    restored = sorted_logits.scatter(dim=-1,
                                     index=sorted_idx,
                                     src=sorted_logits)
    return restored
|
|
|
|
|
|
def random_sample(
    probs: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Draw one index per row of ``probs``.

    Each row is divided by exponentially distributed noise and the argmax is
    taken. This is used in place of ``torch.multinomial`` because
    ``torch.multinomial`` causes CPU-GPU synchronization.

    Args:
        probs: Probabilities, one row per request. NOTE: this tensor is
            divided in place and therefore modified by the call.
        generators: Generators keyed by row index, for the requests that
            have their own seed.

    Returns:
        A flattened tensor holding the argmax index of every row.
    """
    noise = torch.empty_like(probs)
    every_row_seeded = len(generators) == probs.shape[0]
    # NOTE(woosuk): Requests without their own seed are the common case, so
    # fill the whole noise tensor in one batched call first (skipped only when
    # every row is about to be overwritten anyway) ...
    if not every_row_seeded:
        noise.exponential_()
    # ... and then redraw the rows of the requests that carry their own seed.
    # TODO(woosuk): This can be slow because we handle each request
    # one by one. Optimize this.
    for row, row_generator in generators.items():
        noise[row].exponential_(generator=row_generator)
    picked = probs.div_(noise).argmax(dim=-1)
    return picked.view(-1)
|
|
|
|
|
|
def flashinfer_sample(
    probs: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the probabilities with FlashInfer's top-k/top-p kernels.

    Statistically, this function is equivalent to the `random_sample`
    function. However, it is faster because it uses rejection sampling and
    thereby avoids sorting the logits tensor.

    NOTE: The outputs of this function do not necessarily match the outputs
    of the `random_sample` function; only statistical equivalence is
    guaranteed.

    NOTE: This function includes CPU-GPU synchronization, while
    `random_sample` does not. Call this function at the end of the forward
    pass to minimize the synchronization overhead.

    Args:
        probs: Probabilities, one row per request.
        no_top_k: True when top-k filtering is disabled.
        k: Per-row top-k values.
        no_top_p: True when top-p filtering is disabled.
        p: Per-row top-p values.
        generators: Generators keyed by row index, for the requests that
            have their own seed.

    Returns:
        A flattened tensor of sampled token ids.
    """
    # At least one of the two filters must be active.
    assert not (no_top_k and no_top_p)
    use_top_k = not no_top_k
    use_top_p = not no_top_p

    max_top_k_round = 32
    batch_size = probs.shape[0]
    # One row of uniform noise per round, one column per request.
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if len(generators) != batch_size:
        uniform_samples.uniform_()
    # Requests with their own seed get their column redrawn.
    for row, row_generator in generators.items():
        uniform_samples[:, row].uniform_(generator=row_generator)

    if use_top_k and use_top_p:
        # Both top-k and top-p.
        next_token_ids, success = (
            flashinfer.sampling.top_k_top_p_sampling_from_probs(
                probs, uniform_samples, k, p, deterministic=True))
    elif use_top_p:
        # Top-p only.
        next_token_ids, success = (
            flashinfer.sampling.top_p_sampling_from_probs(
                probs, uniform_samples, p, deterministic=True))
    else:
        # Top-k only.
        next_token_ids, success = (
            flashinfer.sampling.top_k_sampling_from_probs(
                probs, uniform_samples, k, deterministic=True))

    # NOTE: CPU-GPU synchronization happens here.
    if success.all():
        return next_token_ids.view(-1)

    # Some rows were not reported as successful: renormalize the
    # probabilities under the active filters and sample from them directly,
    # using the first row of uniform noise.
    if use_top_k:
        probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
    if use_top_p:
        probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
    next_token_ids = flashinfer.sampling.sampling_from_probs(
        probs, uniform_samples[0], deterministic=True)
    return next_token_ids.view(-1)
|