"""A layer that samples the next tokens from the model's outputs."""
|
|
from typing import Dict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.v1.outputs import SamplerOutput
|
|
from vllm.v1.sample.metadata import SamplingMetadata
|
|
|
|
_SAMPLING_EPS = 1e-5
|
|
|
|
|
|


class Sampler(nn.Module):

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample the next token ids (and top-k logprobs, if requested)."""
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        logits = self.apply_top_k_top_p(logits, sampling_metadata)

        probs = self.get_probs(logits)
        sampled = self.sample(probs, sampling_metadata)
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        if sampling_metadata.max_num_logprobs > 0:
            logprobs = self.get_logprobs(logits)
            # FIXME: Mask the sampled token_id, get topk logprobs,
            # and concatenate the topk with the sampled token_id.
            topk_logprobs, topk_indices = torch.topk(
                logprobs, sampling_metadata.max_num_logprobs, dim=-1)
            # Use int32 to reduce the tensor size.
            topk_indices = topk_indices.to(torch.int32)
        else:
            topk_logprobs = None
            topk_indices = None

        sampler_output = SamplerOutput(
            sampled_token_ids=sampled,
            logprob_token_ids=topk_indices,
            logprobs=topk_logprobs,
            prompt_logprob_token_ids=None,
            prompt_logprobs=None,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # Use float32 to apply temperature scaling.
        logits = logits.to(torch.float32)
        # Avoid division by zero.
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        logits.div_(temp.unsqueeze(dim=1))
        return logits

    def apply_top_k_top_p(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        return _apply_top_k_top_p(
            logits,
            sampling_metadata.no_top_k,
            sampling_metadata.top_k,
            sampling_metadata.no_top_p,
            sampling_metadata.top_p,
        )

    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.softmax(logits, dim=-1, dtype=torch.float32)

    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)

    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
        return probs.argmax(dim=-1).view(-1)

    def random_sample(
        self,
        probs: torch.Tensor,
        generators: Dict[int, torch.Generator],
    ) -> torch.Tensor:
        """Sample one token per row, honoring per-request seeds."""
        q = torch.empty_like(probs)
        # NOTE(woosuk): To batch-process the requests without their own seeds,
        # which is the common case, we first assume that every request does
        # not have its own seed. Then, we overwrite the values for the requests
        # that have their own seeds.
        if len(generators) != probs.shape[0]:
            # This may still be unnecessary work for the rows that end up
            # being sampled greedily, but it keeps the common case batched.
            q.exponential_()
        if generators:
            # TODO(woosuk): This can be slow because we handle each request
            # one by one. Optimize this.
            for i, generator in generators.items():
                q[i].exponential_(generator=generator)
        # Dividing by Exponential(1) noise and taking the argmax samples from
        # the categorical distribution (equivalent to the Gumbel-max trick).
        return probs.div_(q).argmax(dim=-1).view(-1)

    def sample(
        self,
        probs: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Pick greedy or random sampling per request by temperature."""
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_greedy:
            return self.greedy_sample(probs)
        if sampling_metadata.all_random:
            return self.random_sample(probs, sampling_metadata.generators)

        greedy_sampled = self.greedy_sample(probs)
        random_sampled = self.random_sample(probs,
                                            sampling_metadata.generators)
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
        )
        return sampled


# TODO(woosuk): Optimize this with a custom kernel.
def _apply_top_k_top_p(
    logits: torch.Tensor,
    no_top_k: bool,
    k: torch.Tensor,
    no_top_p: bool,
    p: torch.Tensor,
) -> torch.Tensor:
    if no_top_k and no_top_p:
        return logits
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if not no_top_k:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)
        # Gather the k-th largest logit of each row as the top-k cutoff.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if not no_top_p:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # Keep at least one token per row.
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Scatter the masked logits back to their original positions.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits
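

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the vLLM API): `random_sample` above uses
# the exponential-race trick: dividing the probabilities by i.i.d.
# Exponential(1) noise and taking the argmax draws a sample from the
# categorical distribution, analogously to the Gumbel-max trick. The quick
# check below compares empirical sample frequencies with the target
# probabilities.
if __name__ == "__main__":
    torch.manual_seed(0)
    target = torch.tensor([0.1, 0.2, 0.3, 0.4])
    num_trials = 50_000
    probs = target.expand(num_trials, -1).contiguous()
    noise = torch.empty_like(probs).exponential_()
    samples = probs.div_(noise).argmax(dim=-1)
    empirical = torch.bincount(samples, minlength=4).float() / num_trials
    print("target frequencies:   ", target.tolist())
    print("empirical frequencies:",
          [round(f, 3) for f in empirical.tolist()])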
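

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the vLLM API): exercise the standalone
# `_apply_top_k_top_p` helper on a toy vocabulary of five tokens. With k=2
# and top-p disabled, only the two largest logits in each row stay finite;
# every other logit is masked to -inf and gets zero probability after the
# softmax.
if __name__ == "__main__":
    toy_logits = torch.tensor([[1.0, 4.0, 2.0, 5.0, 3.0]])
    masked = _apply_top_k_top_p(
        toy_logits,
        no_top_k=False,
        k=torch.tensor([2]),
        no_top_p=True,  # p is ignored when top-p is disabled.
        p=torch.tensor([1.0]),
    )
    print(masked)                  # only 4.0 and 5.0 remain finite
    print(masked.softmax(dim=-1))  # they receive all the probability mass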