[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by Woosuk Kwon on 2024-12-27 09:32:38 +09:00; committed by GitHub
parent 0c0c2015c5
commit 371d04d39b
6 changed files with 358 additions and 193 deletions

View File

@@ -68,7 +68,7 @@ def _create_default_sampling_metadata(
         no_top_p=True,
         no_top_k=True,
         generators={},
-        max_num_logprobs=VOCAB_SIZE,
+        max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
@@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     sampling_metadata.min_tokens = min_tokens
     sampling_metadata.stop_token_ids = stop_token_ids
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        for vocab in range(VOCAB_SIZE):
-            # Verify that the logprobs for stop token ids is set
-            # to -inf.
-            logprob_index = torch.where(
-                sampler_output.logprob_token_ids[batch_idx] ==
-                vocab)[0].item()
-            if vocab in stop_token_ids[batch_idx]:
-                assert sampler_output.logprobs[batch_idx][
-                    logprob_index] == -float("inf")
+        for token_id in range(VOCAB_SIZE):
+            if token_id in stop_token_ids[batch_idx]:
+                assert logits[batch_idx][token_id] == -float("inf")
             else:
-                assert sampler_output.logprobs[batch_idx][
-                    logprob_index] != -float("inf")
+                assert logits[batch_idx][token_id] != -float("inf")
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
         batch_size, presence_penalty, torch.device(device))
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        # The logprobs in the SamplerOutput are arranged in descending order.
-        # Since all tokens initially have the same logprobs, the non-penalized
-        # tokens will appear at the beginning, while the penalized tokens
-        # will appear at the end of the list.
-        penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
-            VOCAB_SIZE - 1]
-        penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
-        non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
-        non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
-        assert non_penalized_log_prod > penalized_log_prod
+        # Since all tokens initially have the same logits, the non-penalized
+        # token ID will be the one with the highest logit value, while the
+        # penalized token ID will be the one with the lowest logit value.
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         if presence_penalty > 0:
             # If `presence_penalty` is set to a value greater than 0, it
             # indicates a preference for new tokens over those already
@@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
-        non_penalized_token_id = logprobs_token_ids[0]
-        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         distinct_sorted_token_ids_in_output = \
             sorted_token_ids_in_output[batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
@@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
         batch_size, repetition_penalty, torch.device(device))
     sampling_metadata.no_penalties = False
     sampler = Sampler()
-    sampler_output = sampler(fake_logits, sampling_metadata)
+    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
+    logits = logits.cpu()
     for batch_idx in range(batch_size):
-        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
-        non_penalized_token_id = logprobs_token_ids[0]
-        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
+        non_penalized_token_id = logits[batch_idx].argmax().item()
+        penalized_token_id = logits[batch_idx].argmin().item()
         prompt_tokens = sampling_metadata.prompt_token_ids[
             batch_idx][:].tolist()
         output_tokens = sampling_metadata.output_token_ids[batch_idx]

View File

@@ -30,7 +30,7 @@ if TYPE_CHECKING:
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
-    VLLM_USE_FLASHINFER_SAMPLER: bool = False
+    VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
     VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
     VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
@@ -277,7 +277,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {

     # If set, vllm will use flashinfer sampler
     "VLLM_USE_FLASHINFER_SAMPLER":
-    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
+    lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
+    if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,

     # If set, vllm will force flashinfer to use tensor cores;
     # otherwise will use heuristic based on model architecture.

View File

View File

@@ -0,0 +1,57 @@
from typing import List, Set, Tuple

import torch

from vllm.model_executor.layers.utils import (
    apply_penalties as _apply_penalties)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad


def apply_min_token_penalties(logits: torch.Tensor,
                              output_token_ids: List[List[int]],
                              stop_token_ids: List[Set[int]],
                              min_tokens: List[int]) -> None:
    """
    Applies minimum token penalty by setting the logits of the stop tokens
    to -inf.
    """
    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
    for index, min_token in enumerate(min_tokens):
        if (len(output_token_ids[index]) < min_token):
            for stop_token_id in stop_token_ids[index]:
                min_tokens_logits_to_penalize.append((index, stop_token_id))
    if min_tokens_logits_to_penalize:
        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor,
                    output_token_ids: List[List[int]]) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return _apply_penalties(logits, prompt_token_ids, output_tokens_t,
                            presence_penalties, frequency_penalties,
                            repetition_penalties)


def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
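A minimal usage sketch of apply_min_token_penalties, assuming a vLLM build that includes this module; the toy tensors and values are made up for illustration:

import torch

from vllm.v1.sample.ops.penalties import apply_min_token_penalties

# Batch of 2 requests over a toy vocabulary of 6 tokens.
logits = torch.zeros(2, 6)
output_token_ids = [[1, 2], [3]]   # tokens generated so far per request
stop_token_ids = [{5}, {4, 5}]     # per-request stop-token sets
min_tokens = [3, 1]                # request 0 has not reached its minimum yet

apply_min_token_penalties(logits, output_token_ids, stop_token_ids, min_tokens)

print(logits[0])  # token 5 is now -inf: request 0 may not emit a stop token yet
print(logits[1])  # unchanged: request 1 already satisfied its min_tokens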

View File

@@ -0,0 +1,201 @@
from typing import Dict

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import flashinfer.sampling
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False


class TopKTopPSampler(nn.Module):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda:
            if is_flashinfer_available:
                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling."""
        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling."""
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if no_top_k and no_top_p:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)


def apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    This function sorts the logits tensor, which can be slow for large batches.
    """
    if no_top_k and no_top_p:
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))
    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))
    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def random_sample(
    probs: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


def flashinfer_sample(
    probs: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the probabilities using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it uses rejection sampling and
    thus avoids sorting the logits tensor.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
    assert not (no_top_k and no_top_p)
    max_top_k_round = 32
    batch_size = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if len(generators) != batch_size:
        uniform_samples.uniform_()
    if generators:
        for i, generator in generators.items():
            uniform_samples[:, i].uniform_(generator=generator)

    if no_top_k:
        # Top-p only.
        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
            probs, uniform_samples, p, deterministic=True)
    elif no_top_p:
        # Top-k only.
        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
            probs, uniform_samples, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids, success = (
            flashinfer.sampling.top_k_top_p_sampling_from_probs(
                probs, uniform_samples, k, p, deterministic=True))

    # NOTE: CPU-GPU synchronization happens here.
    if not success.all():
        if not no_top_k:
            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
        if not no_top_p:
            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
        next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0], deterministic=True)
    return next_token_ids.view(-1)
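As a side note on why `random_sample` divides the probabilities by Exponential(1) noise and takes the argmax: this "exponential race" trick draws from the same categorical distribution as torch.multinomial but without a CPU-GPU sync. A small empirical check, illustrative only and not part of vLLM:

import torch

torch.manual_seed(0)

# Target categorical distribution and number of trials (both arbitrary).
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
num_trials = 200_000

# argmax(p_i / E_i) with E_i ~ Exp(1) selects index i with probability p_i.
q = torch.empty(num_trials, probs.numel()).exponential_()
samples = (probs / q).argmax(dim=-1)

empirical = torch.bincount(samples, minlength=probs.numel()).float() / num_trials
print(empirical)  # approximately [0.1, 0.2, 0.3, 0.4]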

View File

@@ -1,53 +1,55 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Set, Tuple
+from typing import Tuple
 
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.layers.utils import apply_penalties
-from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.penalties import (apply_min_token_penalties,
+                                          apply_penalties)
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 
 _SAMPLING_EPS = 1e-5
 
 
 class Sampler(nn.Module):
 
+    def __init__(self):
+        super().__init__()
+        self.topk_topp_sampler = TopKTopPSampler()
+
     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
-        _apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
-                                   sampling_metadata.stop_token_ids,
-                                   sampling_metadata.min_tokens)
-        if not sampling_metadata.no_penalties:
-            assert sampling_metadata.prompt_token_ids is not None
-            _apply_penalties(logits, sampling_metadata.prompt_token_ids,
-                             sampling_metadata.presence_penalties,
-                             sampling_metadata.frequency_penalties,
-                             sampling_metadata.repetition_penalties,
-                             sampling_metadata.output_token_ids)
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
-        logits = self.apply_top_k_top_p(logits, sampling_metadata)
-        probs = self.get_probs(logits)
-        sampled = self.sample(probs, sampling_metadata)
-        # Use int32 to reduce the tensor size.
-        sampled = sampled.to(torch.int32)
-        if sampling_metadata.max_num_logprobs > 0:
-            logprobs = self.get_logprobs(logits)
-            # FIXME: Mask the sampled token_id, get topk logprobs,
-            # and concatenate the topk with the sampled token_id.
-            topk_logprobs, topk_indices = torch.topk(
-                logprobs, sampling_metadata.max_num_logprobs, dim=-1)
-            # Use int32 to reduce the tensor size.
-            topk_indices = topk_indices.to(torch.int32)
+        needs_logprobs = sampling_metadata.max_num_logprobs > 0
+        if needs_logprobs:
+            # NOTE(woosuk): Use the original logits (before any penalties or
+            # temperature scaling) for the top-k logprobs.
+            # This is different from the V0 sampler, which uses the logits
+            # that are used for sampling (after penalties and temperature
+            # scaling).
+            # NOTE: We compute logprobs first because the ops below may
+            # modify the logits tensor in-place (and we don't want to clone
+            # the logits tensor for memory efficiency).
+            topk_logprobs, topk_indices = self.get_topk_logprobs(
+                logits, sampling_metadata)
         else:
             topk_logprobs = None
             topk_indices = None
+
+        # Use float32 for the logits.
+        logits = logits.to(torch.float32)
+        # Apply penalties (e.g., min_tokens, freq_penalties).
+        logits = self.apply_penalties(logits, sampling_metadata)
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        # Sample the next token.
+        sampled = self.sample(logits, sampling_metadata)
+        # Use int32 to reduce the tensor size.
+        sampled = sampled.to(torch.int32)
 
         # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
             sampled_token_ids=sampled.tolist(),
@@ -63,71 +65,37 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Use float32 to apply temperature scaling.
-        logits = logits.to(torch.float32)
         # Avoid division by zero.
         temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         # Use in-place division to avoid creating a new tensor.
         logits.div_(temp.unsqueeze(dim=1))
         return logits
 
-    def apply_top_k_top_p(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
-        return _apply_top_k_top_p(
-            logits,
-            sampling_metadata.no_top_k,
-            sampling_metadata.top_k,
-            sampling_metadata.no_top_p,
-            sampling_metadata.top_p,
-        )
-
-    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.softmax(logits, dim=-1, dtype=torch.float32)
-
-    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
-        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
-
-    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
-        return probs.argmax(dim=-1).view(-1)
-
-    def random_sample(
-        self,
-        probs: torch.Tensor,
-        generators: Dict[int, torch.Generator],
-    ) -> torch.Tensor:
-        q = torch.empty_like(probs)
-        # NOTE(woosuk): To batch-process the requests without their own seeds,
-        # which is the common case, we first assume that every request does
-        # not have its own seed. Then, we overwrite the values for the requests
-        # that have their own seeds.
-        if len(generators) != probs.shape[0]:
-            # This might still be done here unnecessarily if there are greedies
-            q.exponential_()
-        if generators:
-            # TODO(woosuk): This can be slow because we handle each request
-            # one by one. Optimize this.
-            for i, generator in generators.items():
-                q[i].exponential_(generator=generator)
-        return probs.div_(q).argmax(dim=-1).view(-1)
+    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
+        return logits.argmax(dim=-1).view(-1)
 
     def sample(
         self,
-        probs: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
         if sampling_metadata.all_greedy:
-            return self.greedy_sample(probs)
-        if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators)
+            return self.greedy_sample(logits)
 
-        greedy_sampled = self.greedy_sample(probs)
-        random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators)
+        random_sampled = self.topk_topp_sampler(
+            logits,
+            sampling_metadata.generators,
+            sampling_metadata.no_top_k,
+            sampling_metadata.top_k,
+            sampling_metadata.no_top_p,
+            sampling_metadata.top_p,
+        )
+        if sampling_metadata.all_random:
+            return random_sampled
+
+        greedy_sampled = self.greedy_sample(logits)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
@@ -135,86 +103,34 @@
         )
         return sampled
 
+    def get_topk_logprobs(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
+        # FIXME: Mask the sampled token_id, get topk logprobs,
+        # and concatenate the topk with the sampled token_id.
+        topk_logprobs, topk_indices = torch.topk(
+            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
+        # Use int32 to reduce the tensor size.
+        topk_indices = topk_indices.to(torch.int32)
+        return topk_logprobs, topk_indices
+
+    def apply_penalties(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
+                                  sampling_metadata.stop_token_ids,
+                                  sampling_metadata.min_tokens)
+        if not sampling_metadata.no_penalties:
+            assert sampling_metadata.prompt_token_ids is not None
+            logits = apply_penalties(logits,
                                     sampling_metadata.prompt_token_ids,
                                     sampling_metadata.presence_penalties,
                                     sampling_metadata.frequency_penalties,
                                     sampling_metadata.repetition_penalties,
                                     sampling_metadata.output_token_ids)
+        return logits
-
-
-# TODO(woosuk): Optimize this with a custom kernel.
-def _apply_top_k_top_p(
-    logits: torch.Tensor,
-    no_top_k: bool,
-    k: torch.Tensor,
-    no_top_p: bool,
-    p: torch.Tensor,
-) -> torch.Tensor:
-    if no_top_k and no_top_p:
-        return logits
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
-    if not no_top_k:
-        # Apply top-k.
-        top_k_mask = logits_sort.size(1) - k.to(torch.long)
-        # Get all the top_k values.
-        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
-        top_k_mask = logits_sort < top_k_mask
-        logits_sort.masked_fill_(top_k_mask, -float("inf"))
-    if not no_top_p:
-        # Apply top-p.
-        probs_sort = logits_sort.softmax(dim=-1)
-        probs_sum = probs_sort.cumsum(dim=-1)
-        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
-        # at least one
-        top_p_mask[:, -1] = False
-        logits_sort.masked_fill_(top_p_mask, -float("inf"))
-    # Re-sort the probabilities.
-    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
-    return logits
-
-
-def _apply_min_token_penalties(logits: torch.Tensor,
-                               output_token_ids: List[List[int]],
-                               stop_token_ids: List[Set[int]],
-                               min_tokens: List[int]):
-    """
-    Applies minimum token penalty by setting the logits of the stop tokens
-    to -inf.
-    """
-    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
-    for index, min_token in enumerate(min_tokens):
-        if (len(output_token_ids[index]) < min_token):
-            for stop_token_id in stop_token_ids[index]:
-                min_tokens_logits_to_penalize.append((index, stop_token_id))
-    if min_tokens_logits_to_penalize:
-        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
-
-
-def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
-                     presence_penalties: torch.Tensor,
-                     frequency_penalties: torch.Tensor,
-                     repetition_penalties: torch.Tensor,
-                     output_token_ids: List[List[int]]):
-    """
-    Applies presence, frequency and repetition penalties to the logits.
-    """
-    _, vocab_size = logits.shape
-    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
-                                          logits.device)
-    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
-                           presence_penalties, frequency_penalties,
-                           repetition_penalties)
-
-
-def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
-                        device: torch.device) -> torch.Tensor:
-    """
-    Convert the different list data structures to tensors.
-    """
-    output_tokens_tensor = make_tensor_with_pad(
-        output_token_ids,
-        # Use the value of vocab_size as a pad since we don't have a
-        # token_id of this value.
-        pad=vocab_size,
-        device="cpu",
-        dtype=torch.int64,
-        pin_memory=is_pin_memory_available(),
-    )
-    return output_tokens_tensor.to(device, non_blocking=True)
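To illustrate the ordering above (top-k logprobs taken from the raw logits before the in-place penalty and temperature ops), a small sketch with plain tensors rather than the vLLM API; the shapes and temperature value are made up:

import torch

logits = torch.randn(2, 8)

# Top-k logprobs taken from the raw logits; log_softmax returns a new tensor,
# so these values are immune to later in-place edits of `logits`.
raw_topk = torch.topk(logits.log_softmax(dim=-1, dtype=torch.float32), k=2, dim=-1)

# In-place temperature scaling, as apply_temperature does.
logits.div_(0.5)

# Recomputing after the in-place update gives a different (sharpened)
# distribution, which is exactly what taking the logprobs early avoids.
later_topk = torch.topk(logits.log_softmax(dim=-1, dtype=torch.float32), k=2, dim=-1)
print(torch.allclose(raw_topk.values, later_topk.values))  # False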