[V1] Use FlashInfer Sampling Kernel for Top-P & Top-K Sampling (#11394)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

parent 0c0c2015c5
commit 371d04d39b
@@ -68,7 +68,7 @@ def _create_default_sampling_metadata(
        no_top_p=True,
        no_top_k=True,
        generators={},
        max_num_logprobs=VOCAB_SIZE,
        max_num_logprobs=0,
        prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                      vocab_size, device),
        output_token_ids=output_token_ids,

@@ -169,20 +169,14 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
    sampling_metadata.min_tokens = min_tokens
    sampling_metadata.stop_token_ids = stop_token_ids
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        for vocab in range(VOCAB_SIZE):
            # Verify that the logprobs for stop token ids is set
            # to -inf.
            logprob_index = torch.where(
                sampler_output.logprob_token_ids[batch_idx] ==
                vocab)[0].item()
            if vocab in stop_token_ids[batch_idx]:
                assert sampler_output.logprobs[batch_idx][
                    logprob_index] == -float("inf")
        for token_id in range(VOCAB_SIZE):
            if token_id in stop_token_ids[batch_idx]:
                assert logits[batch_idx][token_id] == -float("inf")
            else:
                assert sampler_output.logprobs[batch_idx][
                    logprob_index] != -float("inf")
                assert logits[batch_idx][token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)

@@ -205,18 +199,14 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
        batch_size, presence_penalty, torch.device(device))
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        # The logprobs in the SamplerOutput are arranged in descending order.
        # Since all tokens initially have the same logprobs, the non-penalized
        # tokens will appear at the beginning, while the penalized tokens
        # will appear at the end of the list.
        penalized_token_id = sampler_output.logprob_token_ids[batch_idx][
            VOCAB_SIZE - 1]
        penalized_log_prod = sampler_output.logprobs[batch_idx][VOCAB_SIZE - 1]
        non_penalized_token_id = sampler_output.logprob_token_ids[batch_idx][0]
        non_penalized_log_prod = sampler_output.logprobs[batch_idx][0]
        assert non_penalized_log_prod > penalized_log_prod
        # Since all tokens initially have the same logits, the non-penalized
        # token ID will be the one with the highest logit value, while the
        # penalized token ID will be the one with the lowest logit value.
        non_penalized_token_id = logits[batch_idx].argmax().item()
        penalized_token_id = logits[batch_idx].argmin().item()
        if presence_penalty > 0:
            # If `presence_penalty` is set to a value greater than 0, it
            # indicates a preference for new tokens over those already

@@ -256,11 +246,11 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
    sampling_metadata.output_token_ids = output_token_ids
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
        non_penalized_token_id = logprobs_token_ids[0]
        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
        non_penalized_token_id = logits[batch_idx].argmax().item()
        penalized_token_id = logits[batch_idx].argmin().item()
        distinct_sorted_token_ids_in_output = \
            sorted_token_ids_in_output[batch_idx]
        most_frequent_token_id = distinct_sorted_token_ids_in_output[

@@ -305,11 +295,11 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
        batch_size, repetition_penalty, torch.device(device))
    sampling_metadata.no_penalties = False
    sampler = Sampler()
    sampler_output = sampler(fake_logits, sampling_metadata)
    logits = sampler.apply_penalties(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logprobs_token_ids = sampler_output.logprob_token_ids[batch_idx]
        non_penalized_token_id = logprobs_token_ids[0]
        penalized_token_id = logprobs_token_ids[VOCAB_SIZE - 1]
        non_penalized_token_id = logits[batch_idx].argmax().item()
        penalized_token_id = logits[batch_idx].argmin().item()
        prompt_tokens = sampling_metadata.prompt_token_ids[
            batch_idx][:].tolist()
        output_tokens = sampling_metadata.output_token_ids[batch_idx]
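The rewritten penalty tests no longer inspect sampler_output.logprob_token_ids; they call sampler.apply_penalties(...) and check the returned logits directly. A minimal sketch of the underlying reasoning, with made-up numbers that are not part of the test suite:

import torch

# Start from uniform logits, penalize one token that already appeared in the
# output; the penalized token becomes the argmin and any other token can be
# the argmax, which is exactly what the updated assertions check.
VOCAB = 8
logits = torch.full((VOCAB, ), 1e-2)
presence_penalty = 1.5
seen_token = 3
logits[seen_token] -= presence_penalty

assert logits.argmin().item() == seen_token
assert logits.argmax().item() != seen_token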
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
    VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
    VLLM_TRACE_FUNCTION: int = 0
    VLLM_ATTENTION_BACKEND: Optional[str] = None
    VLLM_USE_FLASHINFER_SAMPLER: bool = False
    VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
    VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
    VLLM_PP_LAYER_PARTITION: Optional[str] = None

@@ -277,7 +277,8 @@ environment_variables: Dict[str, Callable[[], Any]] = {

    # If set, vllm will use flashinfer sampler
    "VLLM_USE_FLASHINFER_SAMPLER":
    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
    lambda: bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
    if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None,

    # If set, vllm will force flashinfer to use tensor cores;
    # otherwise will use heuristic based on model architecture.
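A minimal sketch of how the new tri-state flag resolves, using a hypothetical helper that mirrors the lambda above (read_flashinfer_sampler_flag is illustrative and not part of vLLM):

import os
from typing import Optional

def read_flashinfer_sampler_flag() -> Optional[bool]:
    # unset -> None, "0" -> False, "1" -> True
    if "VLLM_USE_FLASHINFER_SAMPLER" not in os.environ:
        return None
    return bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))

# V0 keeps its old behavior (None is treated like False), while the V1
# sampler enables FlashInfer whenever the value is not explicitly False.
use_flashinfer_in_v1 = read_flashinfer_sampler_flag() is not False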
vllm/v1/sample/ops/__init__.py (new file, 0 lines)

vllm/v1/sample/ops/penalties.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from typing import List, Set, Tuple

import torch

from vllm.model_executor.layers.utils import (
    apply_penalties as _apply_penalties)
from vllm.utils import is_pin_memory_available, make_tensor_with_pad


def apply_min_token_penalties(logits: torch.Tensor,
                              output_token_ids: List[List[int]],
                              stop_token_ids: List[Set[int]],
                              min_tokens: List[int]) -> None:
    """
    Applies minimum token penalty by setting the logits of the stop tokens
    to -inf.
    """
    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
    for index, min_token in enumerate(min_tokens):
        if (len(output_token_ids[index]) < min_token):
            for stop_token_id in stop_token_ids[index]:
                min_tokens_logits_to_penalize.append((index, stop_token_id))
    if min_tokens_logits_to_penalize:
        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


def apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor,
                    output_token_ids: List[List[int]]) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return _apply_penalties(logits, prompt_token_ids, output_tokens_t,
                            presence_penalties, frequency_penalties,
                            repetition_penalties)


def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
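For intuition, a small usage sketch of apply_min_token_penalties as defined above; the values are hypothetical and the module path is assumed importable:

import torch
from vllm.v1.sample.ops.penalties import apply_min_token_penalties

# Batch of 2 requests over a vocab of 5 tokens; token 4 is the stop token.
logits = torch.zeros(2, 5)
output_token_ids = [[1], [0, 2, 3]]   # request 0 has produced only 1 token
stop_token_ids = [{4}, {4}]
min_tokens = [3, 3]                   # both requests must emit at least 3 tokens

apply_min_token_penalties(logits, output_token_ids, stop_token_ids, min_tokens)

assert logits[0, 4] == -float("inf")  # below min_tokens: stop token masked
assert logits[1, 4] == 0.0            # min_tokens already satisfied: untouched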
vllm/v1/sample/ops/topk_topp_sampler.py (new file, 201 lines)
@@ -0,0 +1,201 @@
from typing import Dict

import torch
import torch.nn as nn

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

try:
    import flashinfer.sampling
    is_flashinfer_available = True
except ImportError:
    is_flashinfer_available = False


class TopKTopPSampler(nn.Module):

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda:
            if is_flashinfer_available:
                if envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
                    # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
                    # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
                    # default it is unused). For backward compatibility, we set
                    # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
                    # interpret it differently in V0 and V1 samplers: In V0,
                    # None means False, while in V1, None means True. This is
                    # why we use the condition
                    # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
                    logger.info("Using FlashInfer for top-p & top-k sampling.")
                    self.forward = self.forward_cuda
                else:
                    logger.warning(
                        "FlashInfer is available, but it is not enabled. "
                        "Falling back to the PyTorch-native implementation of "
                        "top-p & top-k sampling. For the best performance, "
                        "please set VLLM_USE_FLASHINFER_SAMPLER=1.")
                    self.forward = self.forward_native
            else:
                logger.warning(
                    "FlashInfer is not available. Falling back to the PyTorch-"
                    "native implementation of top-p & top-k sampling. For the "
                    "best performance, please install FlashInfer.")
                self.forward = self.forward_native
        else:
            self.forward = self.forward_native

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation of top-k and top-p sampling."""
        logits = apply_top_k_top_p(logits, no_top_k, k, no_top_p, p)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators)

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: Dict[int, torch.Generator],
        no_top_k: bool,
        k: torch.Tensor,
        no_top_p: bool,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """More optimized implementation for top-k and top-p sampling."""
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        if no_top_k and no_top_p:
            # We prefer `random_sample` over `flashinfer_sample` when sorting is
            # not needed. This is because `random_sample` does not require
            # CPU-GPU synchronization while `flashinfer_sample` does.
            return random_sample(probs, generators)
        return flashinfer_sample(probs, no_top_k, k, no_top_p, p, generators)


def apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    This function sorts the logits tensor, which can be slow for large batches.
    """
    if no_top_k and no_top_p:
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # At least one token is always kept.
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


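A small worked example of the masking performed by apply_top_k_top_p, with hypothetical numbers (illustrative, not part of this file):

import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p

# One request over a vocab of 5 tokens.
logits = torch.tensor([[1.0, 4.0, 2.0, 3.0, 0.0]])
k = torch.tensor([2])    # keep the 2 highest logits (tokens 1 and 3)
p = torch.tensor([1.0])  # top-p disabled in this sketch via no_top_p=True

masked = apply_top_k_top_p(logits, no_top_k=False, k=k, no_top_p=True, p=p)

# Only the top-2 logits survive; everything else is set to -inf.
assert masked[0, 1] == 4.0 and masked[0, 3] == 3.0
assert torch.isinf(masked[0, 0]) and torch.isinf(masked[0, 2])
assert torch.isinf(masked[0, 4])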
def random_sample(
    probs: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)


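The probs.div_(q).argmax() trick is the exponential-race formulation of categorical sampling: dividing each probability by an independent Exp(1) draw and taking the argmax is distributed like torch.multinomial(probs, 1), but it needs no sort and no host synchronization. A seeded usage sketch with hypothetical values:

import torch
from vllm.v1.sample.ops.topk_topp_sampler import random_sample

probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.3, 0.2]])

# Request 1 carries its own seeded generator; request 0 uses the global RNG.
# random_sample divides probs in place, so pass a clone if probs is reused.
gen = torch.Generator().manual_seed(42)
token_ids = random_sample(probs.clone(), generators={1: gen})

assert token_ids.shape == (2, )
assert all(0 <= t < 3 for t in token_ids.tolist())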
def flashinfer_sample(
    probs: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
    generators: Dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the probabilities using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
    assert not (no_top_k and no_top_p)
    max_top_k_round = 32
    batch_size = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if len(generators) != batch_size:
        uniform_samples.uniform_()
    if generators:
        for i, generator in generators.items():
            uniform_samples[:, i].uniform_(generator=generator)

    if no_top_k:
        # Top-p only.
        next_token_ids, success = flashinfer.sampling.top_p_sampling_from_probs(
            probs, uniform_samples, p, deterministic=True)
    elif no_top_p:
        # Top-k only.
        next_token_ids, success = flashinfer.sampling.top_k_sampling_from_probs(
            probs, uniform_samples, k, deterministic=True)
    else:
        # Both top-k and top-p.
        next_token_ids, success = (
            flashinfer.sampling.top_k_top_p_sampling_from_probs(
                probs, uniform_samples, k, p, deterministic=True))

    # NOTE: CPU-GPU synchronization happens here.
    if not success.all():
        if not no_top_k:
            probs = flashinfer.sampling.top_k_renorm_prob(probs, k)
        if not no_top_p:
            probs = flashinfer.sampling.top_p_renorm_prob(probs, p)
        next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0], deterministic=True)
    return next_token_ids.view(-1)
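A guarded usage sketch of the FlashInfer path above; the shapes and values are hypothetical, and the block only runs when the flashinfer package and a CUDA device are actually available:

import torch
from vllm.v1.sample.ops.topk_topp_sampler import (flashinfer_sample,
                                                  is_flashinfer_available)

if is_flashinfer_available and torch.cuda.is_available():
    batch_size, vocab_size = 4, 128
    probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"),
                          dim=-1)
    k = torch.full((batch_size, ), 50, dtype=torch.int32, device="cuda")
    p = torch.full((batch_size, ), 0.9, device="cuda")
    # Rejection-sample with both top-k and top-p restrictions; per-request
    # seeded generators are omitted in this sketch.
    token_ids = flashinfer_sample(probs, False, k, False, p, generators={})
    assert token_ids.shape == (batch_size, )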
@@ -1,53 +1,55 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Set, Tuple
from typing import Tuple

import torch
import torch.nn as nn

from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_min_token_penalties,
                                          apply_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        _apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
                                   sampling_metadata.stop_token_ids,
                                   sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            _apply_penalties(logits, sampling_metadata.prompt_token_ids,
                             sampling_metadata.presence_penalties,
                             sampling_metadata.frequency_penalties,
                             sampling_metadata.repetition_penalties,
                             sampling_metadata.output_token_ids)
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        logits = self.apply_top_k_top_p(logits, sampling_metadata)
        probs = self.get_probs(logits)
        sampled = self.sample(probs, sampling_metadata)
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        if sampling_metadata.max_num_logprobs > 0:
            logprobs = self.get_logprobs(logits)
            # FIXME: Mask the sampled token_id, get topk logprobs,
            # and concatenate the topk with the sampled token_id.
            topk_logprobs, topk_indices = torch.topk(
                logprobs, sampling_metadata.max_num_logprobs, dim=-1)
            # Use int32 to reduce the tensor size.
            topk_indices = topk_indices.to(torch.int32)
        needs_logprobs = sampling_metadata.max_num_logprobs > 0
        if needs_logprobs:
            # NOTE(woosuk): Use the original logits (before any penalties or
            # temperature scaling) for the top-k logprobs.
            # This is different from the V0 sampler, which uses the logits that
            # are used for sampling (after penalties and temperature scaling).
            # NOTE: We compute logprobs first because the below ops may
            # modify the logits tensor in-place (and we don't want to clone
            # the logits tensor for memory efficiency).
            topk_logprobs, topk_indices = self.get_topk_logprobs(
                logits, sampling_metadata)
        else:
            topk_logprobs = None
            topk_indices = None

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)
        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # NOTE: CPU-GPU synchronization happens here.
        sampler_output = SamplerOutput(
            sampled_token_ids=sampled.tolist(),
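The reordering matters because apply_temperature and the penalty ops mutate the logits tensor in place; computing the returned logprobs first keeps them based on the raw model outputs. A tiny sketch of the effect, with made-up numbers:

import torch

logits = torch.tensor([[2.0, 1.0, 0.0]])

# Top-k logprobs taken from the raw logits, as the V1 forward() now does.
raw_logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)

# In-place temperature scaling afterwards changes `logits` itself ...
logits.div_(torch.tensor([[0.5]]))
# ... so logprobs computed now would differ from the ones reported above.
assert not torch.allclose(raw_logprobs, logits.log_softmax(dim=-1))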
@@ -63,71 +65,37 @@ class Sampler(nn.Module):
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # Use float32 to apply temperature scaling.
        logits = logits.to(torch.float32)
        # Avoid division by zero.
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        logits.div_(temp.unsqueeze(dim=1))
        return logits

    def apply_top_k_top_p(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        return _apply_top_k_top_p(
            logits,
            sampling_metadata.no_top_k,
            sampling_metadata.top_k,
            sampling_metadata.no_top_p,
            sampling_metadata.top_p,
        )

    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.softmax(logits, dim=-1, dtype=torch.float32)

    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)

    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
        return probs.argmax(dim=-1).view(-1)

    def random_sample(
        self,
        probs: torch.Tensor,
        generators: Dict[int, torch.Generator],
    ) -> torch.Tensor:
        q = torch.empty_like(probs)
        # NOTE(woosuk): To batch-process the requests without their own seeds,
        # which is the common case, we first assume that every request does
        # not have its own seed. Then, we overwrite the values for the requests
        # that have their own seeds.
        if len(generators) != probs.shape[0]:
            # This might still be done here unnecessarily if there are greedies
            q.exponential_()
        if generators:
            # TODO(woosuk): This can be slow because we handle each request
            # one by one. Optimize this.
            for i, generator in generators.items():
                q[i].exponential_(generator=generator)
        return probs.div_(q).argmax(dim=-1).view(-1)
    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        probs: torch.Tensor,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_greedy:
            return self.greedy_sample(probs)
        if sampling_metadata.all_random:
            return self.random_sample(probs, sampling_metadata.generators)
        return self.greedy_sample(logits)

        greedy_sampled = self.greedy_sample(probs)
        random_sampled = self.random_sample(probs,
                                            sampling_metadata.generators)
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.no_top_k,
            sampling_metadata.top_k,
            sampling_metadata.no_top_p,
            sampling_metadata.top_p,
        )
        if sampling_metadata.all_random:
            return random_sampled

        greedy_sampled = self.greedy_sample(logits)
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
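When the batch mixes greedy and random requests, sample() computes both candidates and selects per request with torch.where on the temperature. A condensed sketch of that selection, with hypothetical values:

import torch

_SAMPLING_EPS = 1e-5
temperature = torch.tensor([0.0, 0.8, 0.0])   # requests 0 and 2 are greedy
greedy_sampled = torch.tensor([11, 22, 33])
random_sampled = torch.tensor([44, 55, 66])

sampled = torch.where(temperature < _SAMPLING_EPS, greedy_sampled,
                      random_sampled)
assert sampled.tolist() == [11, 55, 33]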
@@ -135,86 +103,34 @@ class Sampler(nn.Module):
        )
        return sampled


# TODO(woosuk): Optimize this with a custom kernel.
def _apply_top_k_top_p(
    def get_topk_logprobs(
        self,
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    if no_top_k and no_top_p:
        sampling_metadata: SamplingMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
        # FIXME: Mask the sampled token_id, get topk logprobs,
        # and concatenate the topk with the sampled token_id.
        topk_logprobs, topk_indices = torch.topk(
            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
        # Use int32 to reduce the tensor size.
        topk_indices = topk_indices.to(torch.int32)
        return topk_logprobs, topk_indices

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
                                  sampling_metadata.stop_token_ids,
                                  sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_penalties(logits,
                                     sampling_metadata.prompt_token_ids,
                                     sampling_metadata.presence_penalties,
                                     sampling_metadata.frequency_penalties,
                                     sampling_metadata.repetition_penalties,
                                     sampling_metadata.output_token_ids)
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # at least one
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits


def _apply_min_token_penalties(logits: torch.Tensor,
                               output_token_ids: List[List[int]],
                               stop_token_ids: List[Set[int]],
                               min_tokens: List[int]):
    """
    Applies minimum token penalty by setting the logits of the stop tokens
    to -inf.
    """
    min_tokens_logits_to_penalize: List[Tuple[int, int]] = []
    for index, min_token in enumerate(min_tokens):
        if (len(output_token_ids[index]) < min_token):
            for stop_token_id in stop_token_ids[index]:
                min_tokens_logits_to_penalize.append((index, stop_token_id))
    if min_tokens_logits_to_penalize:
        logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")


def _apply_penalties(logits: torch.Tensor, prompt_token_ids: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor,
                     output_token_ids: List[List[int]]):
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)


def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
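The padding choice in _convert_to_tensors works because vocab_size can never collide with a real token id, so downstream penalty code can discard the pad bucket. A rough equivalent of what the padded tensor looks like, with hypothetical values and plain torch used instead of make_tensor_with_pad:

import torch

vocab_size = 6
output_token_ids = [[0, 3], [5]]          # ragged per-request output tokens
max_len = max(len(ids) for ids in output_token_ids)

padded = torch.full((len(output_token_ids), max_len), vocab_size,
                    dtype=torch.int64)
for row, ids in enumerate(output_token_ids):
    padded[row, :len(ids)] = torch.tensor(ids, dtype=torch.int64)

# The value 6 (== vocab_size) marks padding in the second row.
assert padded.tolist() == [[0, 3], [5, 6]]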