Support top-k sampling (#94)

parent ae356774ab
commit 9f88db35da
@@ -46,12 +46,13 @@ class Sampler(nn.Module):
         # Compute the log probabilities (before applying top-p).
         logprobs = torch.log(probs)
 
-        # Apply top-p truncation.
-        top_ps = _get_top_ps(input_metadata)
-        assert len(top_ps) == probs.shape[0]
-        if any(p < 1.0 for p in top_ps):
+        # Apply top-p and top-k truncation.
+        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        assert len(top_ps) == len(top_ks) == probs.shape[0]
+        if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
             p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-            probs = _apply_top_p(probs, p)
+            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
+            probs = _apply_top_p_top_k(probs, p, k)
 
         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
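
The gate in this hunk matters for performance: _apply_top_p_top_k sorts the whole vocabulary for every row, so the forward pass skips it entirely when no sequence in the batch requests truncation. A minimal sketch of that gate outside vLLM (the tensor shapes and per-row pairing are assumptions based on this diff):

# Sketch, not vLLM code. Assumes `probs` holds one row per token being
# sampled, with a per-row top_p / top_k taken from that sequence's
# SamplingParams.
import torch

vocab_size = 8
probs = torch.softmax(torch.randn(3, vocab_size), dim=-1)

top_ps = [1.0, 0.9, 1.0]  # only row 1 asks for top-p truncation
top_ks = [-1, -1, 5]      # only row 2 asks for top-k truncation (-1 = off)

if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
    p = torch.tensor(top_ps, dtype=probs.dtype)
    # _get_top_p_top_k (next hunk) maps k=-1 to vocab_size, a k that
    # masks nothing.
    k = torch.tensor([vocab_size if k == -1 else k for k in top_ks],
                     dtype=torch.int)
    # probs = _apply_top_p_top_k(probs, p, k) would run here.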
@@ -94,31 +95,51 @@ def _get_temperatures(
     return temperatures
 
 
-def _get_top_ps(
+def _get_top_p_top_k(
     input_metadata: InputMetadata,
-) -> List[float]:
+    vocab_size: int,
+) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
+    top_ks: List[int] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
+        top_p = sampling_params.top_p
+        # k should not be greater than the vocab size.
+        top_k = min(sampling_params.top_k, vocab_size)
+        # k=-1 means no truncation.
+        top_k = vocab_size if top_k == -1 else top_k
         if i < input_metadata.num_prompts:
             # A prompt input.
-            top_ps.append(sampling_params.top_p)
+            top_ps.append(top_p)
+            top_ks.append(top_k)
         else:
             # A generation token.
-            top_ps += [sampling_params.top_p] * len(seq_ids)
-    return top_ps
+            top_ps += [top_p] * len(seq_ids)
+            top_ks += [top_k] * len(seq_ids)
+    return top_ps, top_ks
 
 
-def _apply_top_p(
+def _apply_top_p_top_k(
     probs: torch.Tensor,
     p: torch.Tensor,
+    k: torch.Tensor,
 ) -> torch.Tensor:
     # TODO(woosuk): Optimize.
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+
+    # Apply top-p.
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[mask] = 0.0
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
+    probs_sort[top_p_mask] = 0.0
+
+    # Apply top-k.
+    # Create a mask for the top-k elements.
+    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
+    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
+    probs_sort[top_k_mask] = 0.0
+
+    # Re-sort the probabilities.
     probs = torch.gather(
         probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
     return probs
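
The rewritten kernel handles both filters with a single descending sort. Once sorted, top-p becomes "zero every entry once the cumulative mass before it exceeds p" (subtracting probs_sort keeps the first token that crosses the threshold, so at least one token always survives), and top-k becomes "zero every sorted position at index >= k". The final gather along argsort(probs_idx) scatters the surviving mass back into vocabulary order. A self-contained sketch of the same trick, with a small worked check:

import torch

def apply_top_p_top_k_sketch(probs: torch.Tensor,
                             p: torch.Tensor,
                             k: torch.Tensor) -> torch.Tensor:
    # Sort each row in descending order so both filters become prefix masks.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

    # Top-p: mask entries whose preceding cumulative mass already exceeds p.
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(probs_sum - probs_sort) > p.unsqueeze(dim=1)] = 0.0

    # Top-k: mask every sorted position at index >= k for that row.
    arange = torch.arange(probs.shape[-1], device=probs.device)
    probs_sort[arange.expand(probs.shape[0], -1) >= k.unsqueeze(dim=1)] = 0.0

    # Scatter the surviving mass back to the original vocabulary order.
    return torch.gather(probs_sort, dim=-1,
                        index=torch.argsort(probs_idx, dim=-1))

# Worked check: with p=1.0 and k=2, only the two largest entries survive,
# back in their original positions.
probs = torch.tensor([[0.1, 0.5, 0.1, 0.3]])
out = apply_top_p_top_k_sketch(probs,
                               p=torch.tensor([1.0]),
                               k=torch.tensor([2]))
print(out)  # tensor([[0.0000, 0.5000, 0.0000, 0.3000]])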
@@ -160,7 +181,7 @@ def _sample_from_prompt(
         next_token_id = torch.argmax(prob)
         next_token_ids = [next_token_id.item()]
     else:
-        # Neucleus sampling.
+        # Random sampling.
         # Sample n tokens for the prompt.
         n = sampling_params.n
         next_token_ids = torch.multinomial(
@@ -218,7 +239,7 @@ def _sample_from_generation_tokens(
         next_token_ids = [next_token_id.item()]
         parent_seq_ids = seq_ids
     else:
-        # Neucleus sampling.
+        # Random sampling.
         # Sample 1 token for each sequence in the group.
         next_token_ids = torch.multinomial(
             probs, num_samples=1, replacement=True)
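
Both renamed branches draw from probs with torch.multinomial, which treats each row as unnormalized non-negative weights and normalizes internally. That is also what makes it safe for the new _apply_top_p_top_k to drop the old probs_sort.div_(...) renormalization step: rescaling the surviving mass does not change the sampled distribution. A quick check:

# Sampling from truncated-but-unnormalized weights matches sampling from
# the renormalized distribution ([0, 0.625, 0, 0.375] here).
import torch

torch.manual_seed(0)
weights = torch.tensor([0.0, 0.5, 0.0, 0.3])
samples = torch.multinomial(weights, num_samples=1000, replacement=True)
print(torch.bincount(samples, minlength=4) / 1000.)  # ~[0, 0.625, 0, 0.375]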
@@ -8,69 +8,82 @@ class SamplingParams:
         n: int,
         temperature: float,
         top_p: float,
+        top_k: int,
         use_beam_search: bool,
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
     ) -> None:
         if n < 1:
-            raise ValueError(f'n must be at least 1, got {n}.')
+            raise ValueError(f"n must be at least 1, got {n}.")
         if temperature < 0.0:
             raise ValueError(
-                f'temperature must be non-negative, got {temperature}.')
+                f"temperature must be non-negative, got {temperature}.")
         if not 0.0 < top_p <= 1.0:
-            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
+            raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
+        if top_k < -1 or top_k == 0:
+            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
+                             f"got {top_k}.")
         if max_num_steps < 1:
             raise ValueError(
-                f'max_num_steps must be at least 1, got {max_num_steps}.')
+                f"max_num_steps must be at least 1, got {max_num_steps}.")
         if num_logprobs < 0:
             raise ValueError(
-                f'num_logprobs must be non-negative, got {num_logprobs}.')
+                f"num_logprobs must be non-negative, got {num_logprobs}.")
 
         if use_beam_search:
             if n == 1:
                 raise ValueError(
-                    'n must be greater than 1 when using beam search.')
+                    "n must be greater than 1 when using beam search.")
             if temperature > 0.0:
                 raise ValueError(
-                    'temperature must be 0 when using beam search.')
+                    "temperature must be 0 when using beam search.")
             if top_p < 1.0:
                 raise ValueError(
-                    'top_p must be 1 when using beam search.')
+                    "top_p must be 1 when using beam search.")
+            if top_k != -1:
+                raise ValueError(
+                    "top_k must be -1 when using beam search.")
         elif temperature == 0.0:
             # Zero temperature means greedy sampling.
             if n > 1:
                 raise ValueError(
-                    'n must be 1 when using greedy sampling.')
+                    "n must be 1 when using greedy sampling.")
             if top_p < 1.0:
                 raise ValueError(
-                    'top_p must be 1 when using greedy sampling.')
+                    "top_p must be 1 when using greedy sampling.")
+            if top_k != -1:
+                raise ValueError(
+                    "top_k must be -1 when using greedy sampling.")
 
         self.n = n
         self.temperature = temperature
         self.top_p = top_p
+        self.top_k = top_k
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs
 
     def __repr__(self) -> str:
-        return (f'SamplingParams(n={self.n}, '
-                f'temperature={self.temperature}, '
-                f'top_p={self.top_p}, '
-                f'use_beam_search={self.use_beam_search}, '
-                f'stop_token_ids={self.stop_token_ids}, '
-                f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs}')
+        return (f"SamplingParams(n={self.n}, "
+                f"temperature={self.temperature}, "
+                f"top_p={self.top_p}, "
+                f"top_k={self.top_k}, "
+                f"use_beam_search={self.use_beam_search}, "
+                f"stop_token_ids={self.stop_token_ids}, "
+                f"max_num_steps={self.max_num_steps}, "
+                f"num_logprobs={self.num_logprobs}")
 
     @classmethod
-    def from_dict(cls, d: Dict) -> 'SamplingParams':
+    def from_dict(cls, d: Dict) -> "SamplingParams":
         return cls(
-            n=d.get('n', 1),
-            temperature=d.get('temperature', 1.0),
-            top_p=d.get('top_p', 1.0),
-            use_beam_search=d.get('use_beam_search', False),
-            stop_token_ids=set(d.get('stop_token_ids', set())),
-            max_num_steps=d.get('max_num_steps', 16),
-            num_logprobs=d.get('num_logprobs', 0),
+            n=d.get("n", 1),
+            temperature=d.get("temperature", 1.0),
+            top_p=d.get("top_p", 1.0),
+            top_k=d.get("top_k", -1),
+            use_beam_search=d.get("use_beam_search", False),
+            stop_token_ids=set(d.get("stop_token_ids", set())),
+            max_num_steps=d.get("max_num_steps", 16),
+            num_logprobs=d.get("num_logprobs", 0),
         )
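
A usage sketch of the new validation rules. The constructor and from_dict signatures come from the diff above; the import path is an assumption, since this page does not show file names:

from cacheflow.sampling_params import SamplingParams  # assumed path

params = SamplingParams.from_dict({"temperature": 0.8, "top_k": 5})
print(params)  # top_k now appears in the repr

# Each of these violates a rule added in this diff and raises ValueError:
# top_k == 0; top_k != -1 under beam search; top_k != -1 under greedy.
for bad in ({"top_k": 0},
            {"use_beam_search": True, "n": 2, "temperature": 0.0,
             "top_k": 5},
            {"temperature": 0.0, "top_k": 5}):
    try:
        SamplingParams.from_dict(bad)
    except ValueError as e:
        print(e)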
@@ -11,8 +11,9 @@ def main(args: argparse.Namespace):
     # Test the following inputs.
     test_inputs = [
         ("A robot may not injure a human being", {}),  # Use default parameters.
-        ("What is the meaning of life?", {"n": 3, "temperature": 0.8, "top_p": 0.99}),
-        ("It is only with the heart that one can see rightly", {"n": 4, "use_beam_search": True, "temperature": 0.0}),
+        ("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
+        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
+        ("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
     ]
     while True:
         if test_inputs:
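
The new first test input exercises top_k=5 at temperature 0.8. Sketched on toy logits with torch.topk (an equivalent single-row formulation, not the batched masking kernel from this diff):

import torch

logits = torch.randn(32)                     # toy vocabulary of 32 tokens
probs = torch.softmax(logits / 0.8, dim=-1)  # temperature 0.8 sharpens

top_vals, top_idx = torch.topk(probs, k=5)   # keep the 5 likeliest tokens
next_token = top_idx[torch.multinomial(top_vals, num_samples=1)]
print(next_token.item())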