Support top-k sampling (#94)

Author: Woosuk Kwon, 2023-05-10 12:51:36 -07:00 (committed by GitHub)
parent ae356774ab
commit 9f88db35da
3 changed files with 78 additions and 43 deletions


@@ -46,12 +46,13 @@ class Sampler(nn.Module):
         # Compute the log probabilities (before applying top-p).
         logprobs = torch.log(probs)

-        # Apply top-p truncation.
-        top_ps = _get_top_ps(input_metadata)
-        assert len(top_ps) == probs.shape[0]
-        if any(p < 1.0 for p in top_ps):
+        # Apply top-p and top-k truncation.
+        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
+        assert len(top_ps) == len(top_ks) == probs.shape[0]
+        if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
             p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-            probs = _apply_top_p(probs, p)
+            k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
+            probs = _apply_top_p_top_k(probs, p, k)

         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
@@ -94,31 +95,51 @@ def _get_temperatures(
     return temperatures


-def _get_top_ps(
+def _get_top_p_top_k(
     input_metadata: InputMetadata,
-) -> List[float]:
+    vocab_size: int,
+) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
+    top_ks: List[int] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
+        top_p = sampling_params.top_p
+        # k should not be greater than the vocab size.
+        top_k = min(sampling_params.top_k, vocab_size)
+        # k=-1 means no truncation.
+        top_k = vocab_size if top_k == -1 else top_k
         if i < input_metadata.num_prompts:
             # A prompt input.
-            top_ps.append(sampling_params.top_p)
+            top_ps.append(top_p)
+            top_ks.append(top_k)
         else:
             # A generation token.
-            top_ps += [sampling_params.top_p] * len(seq_ids)
-    return top_ps
+            top_ps += [top_p] * len(seq_ids)
+            top_ks += [top_k] * len(seq_ids)
+    return top_ps, top_ks


-def _apply_top_p(
+def _apply_top_p_top_k(
     probs: torch.Tensor,
     p: torch.Tensor,
+    k: torch.Tensor,
 ) -> torch.Tensor:
     # TODO(woosuk): Optimize.
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+
+    # Apply top-p.
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[mask] = 0.0
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
+    probs_sort[top_p_mask] = 0.0
+
+    # Apply top-k.
+    # Create a mask for the top-k elements.
+    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
+    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
+    probs_sort[top_k_mask] = 0.0
+
+    # Re-sort the probabilities.
     probs = torch.gather(
         probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
     return probs
@@ -160,7 +181,7 @@ def _sample_from_prompt(
         next_token_id = torch.argmax(prob)
         next_token_ids = [next_token_id.item()]
     else:
-        # Neucleus sampling.
+        # Random sampling.
         # Sample n tokens for the prompt.
         n = sampling_params.n
         next_token_ids = torch.multinomial(
@@ -218,7 +239,7 @@ def _sample_from_generation_tokens(
         next_token_ids = [next_token_id.item()]
         parent_seq_ids = seq_ids
     else:
-        # Neucleus sampling.
+        # Random sampling.
         # Sample 1 token for each sequence in the group.
         next_token_ids = torch.multinomial(
             probs, num_samples=1, replacement=True)
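
The combined truncation above can be checked by eye with a minimal standalone sketch: the _apply_top_p_top_k body is copied from the new version of the diff, while the toy 5-token batch and the p/k values are invented for illustration (one row exercises top-p only, the other top-k only).

    import torch

    def _apply_top_p_top_k(probs, p, k):
        # Body copied from the new version in the diff above.
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
        # Top-p: drop tokens whose preceding cumulative mass already exceeds p.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
        probs_sort[top_p_mask] = 0.0
        # Top-k: drop every sorted position at index k or later.
        top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
        top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
        top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
        probs_sort[top_k_mask] = 0.0
        # Undo the sort so the weights line up with token ids again.
        return torch.gather(probs_sort, dim=-1,
                            index=torch.argsort(probs_idx, dim=-1))

    probs = torch.tensor([[0.40, 0.30, 0.15, 0.10, 0.05],
                          [0.40, 0.30, 0.15, 0.10, 0.05]])
    p = torch.tensor([0.6, 1.0])  # row 0: top-p only
    k = torch.tensor([5, 2])      # row 1: top-k only (k=-1 becomes vocab_size upstream)
    print(_apply_top_p_top_k(probs, p, k))
    # tensor([[0.4000, 0.3000, 0.0000, 0.0000, 0.0000],
    #         [0.4000, 0.3000, 0.0000, 0.0000, 0.0000]])

Both rows keep only tokens 0 and 1: row 0 because 0.40 + 0.30 already covers p = 0.6, and row 1 because k = 2. Note the diff also removes the old in-place renormalization (probs_sort.div_(...)); torch.multinomial treats its input rows as unnormalized weights, so the truncated distribution can be sampled from directly.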


@@ -8,69 +8,82 @@ class SamplingParams:
         n: int,
         temperature: float,
         top_p: float,
+        top_k: int,
         use_beam_search: bool,
         stop_token_ids: Set[int],
         max_num_steps: int,
         num_logprobs: int,
     ) -> None:
         if n < 1:
-            raise ValueError(f'n must be at least 1, got {n}.')
+            raise ValueError(f"n must be at least 1, got {n}.")
         if temperature < 0.0:
             raise ValueError(
-                f'temperature must be non-negative, got {temperature}.')
+                f"temperature must be non-negative, got {temperature}.")
         if not 0.0 < top_p <= 1.0:
-            raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
+            raise ValueError(f"top_p must be in (0, 1], got {top_p}.")
+        if top_k < -1 or top_k == 0:
+            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
+                             f"got {top_k}.")
         if max_num_steps < 1:
             raise ValueError(
-                f'max_num_steps must be at least 1, got {max_num_steps}.')
+                f"max_num_steps must be at least 1, got {max_num_steps}.")
         if num_logprobs < 0:
             raise ValueError(
-                f'num_logprobs must be non-negative, got {num_logprobs}.')
+                f"num_logprobs must be non-negative, got {num_logprobs}.")

         if use_beam_search:
             if n == 1:
                 raise ValueError(
-                    'n must be greater than 1 when using beam search.')
+                    "n must be greater than 1 when using beam search.")
             if temperature > 0.0:
                 raise ValueError(
-                    'temperature must be 0 when using beam search.')
+                    "temperature must be 0 when using beam search.")
             if top_p < 1.0:
                 raise ValueError(
-                    'top_p must be 1 when using beam search.')
+                    "top_p must be 1 when using beam search.")
+            if top_k != -1:
+                raise ValueError(
+                    "top_k must be -1 when using beam search.")
         elif temperature == 0.0:
             # Zero temperature means greedy sampling.
             if n > 1:
                 raise ValueError(
-                    'n must be 1 when using greedy sampling.')
+                    "n must be 1 when using greedy sampling.")
             if top_p < 1.0:
                 raise ValueError(
-                    'top_p must be 1 when using greedy sampling.')
+                    "top_p must be 1 when using greedy sampling.")
+            if top_k != -1:
+                raise ValueError(
+                    "top_k must be -1 when using greedy sampling.")

         self.n = n
         self.temperature = temperature
         self.top_p = top_p
+        self.top_k = top_k
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
         self.max_num_steps = max_num_steps
         self.num_logprobs = num_logprobs

     def __repr__(self) -> str:
-        return (f'SamplingParams(n={self.n}, '
-                f'temperature={self.temperature}, '
-                f'top_p={self.top_p}, '
-                f'use_beam_search={self.use_beam_search}, '
-                f'stop_token_ids={self.stop_token_ids}, '
-                f'max_num_steps={self.max_num_steps}, '
-                f'num_logprobs={self.num_logprobs}')
+        return (f"SamplingParams(n={self.n}, "
+                f"temperature={self.temperature}, "
+                f"top_p={self.top_p}, "
+                f"top_k={self.top_k}, "
+                f"use_beam_search={self.use_beam_search}, "
+                f"stop_token_ids={self.stop_token_ids}, "
+                f"max_num_steps={self.max_num_steps}, "
+                f"num_logprobs={self.num_logprobs}")

     @classmethod
-    def from_dict(cls, d: Dict) -> 'SamplingParams':
+    def from_dict(cls, d: Dict) -> "SamplingParams":
         return cls(
-            n=d.get('n', 1),
-            temperature=d.get('temperature', 1.0),
-            top_p=d.get('top_p', 1.0),
-            use_beam_search=d.get('use_beam_search', False),
-            stop_token_ids=set(d.get('stop_token_ids', set())),
-            max_num_steps=d.get('max_num_steps', 16),
-            num_logprobs=d.get('num_logprobs', 0),
+            n=d.get("n", 1),
+            temperature=d.get("temperature", 1.0),
+            top_p=d.get("top_p", 1.0),
+            top_k=d.get("top_k", -1),
+            use_beam_search=d.get("use_beam_search", False),
+            stop_token_ids=set(d.get("stop_token_ids", set())),
+            max_num_steps=d.get("max_num_steps", 16),
+            num_logprobs=d.get("num_logprobs", 0),
         )
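
As a quick sketch of the new parameter at the API surface (the import path below is a guess for illustration; the diff does not show where SamplingParams lives):

    from cacheflow.sampling_params import SamplingParams  # hypothetical import path

    # Omitted keys fall back to the defaults above: n=1, temperature=1.0,
    # top_p=1.0, top_k=-1 (disabled), and so on.
    params = SamplingParams.from_dict({"temperature": 0.8, "top_k": 5})
    print(params.top_k)  # 5

    # top_k=0 is rejected by the new validation check.
    try:
        SamplingParams.from_dict({"top_k": 0})
    except ValueError as e:
        print(e)  # top_k must be -1 (disable), or at least 1, got 0.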


@@ -11,8 +11,9 @@ def main(args: argparse.Namespace):
     # Test the following inputs.
     test_inputs = [
         ("A robot may not injure a human being", {}),  # Use default parameters.
-        ("What is the meaning of life?", {"n": 3, "temperature": 0.8, "top_p": 0.99}),
-        ("It is only with the heart that one can see rightly", {"n": 4, "use_beam_search": True, "temperature": 0.0}),
+        ("To be or not to be,", {"temperature": 0.8, "top_k": 5}),
+        ("What is the meaning of life?", {"n": 2, "temperature": 0.8, "top_p": 0.95}),
+        ("It is only with the heart that one can see rightly", {"n": 3, "use_beam_search": True, "temperature": 0.0}),
     ]
     while True:
         if test_inputs:
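
For intuition, the new "To be or not to be," input asks each decoding step to sample only among the five most likely tokens after temperature scaling. A toy illustration of that semantics (the 10-token vocab and the logits / T formulation are stand-ins, not this server's exact pipeline):

    import torch

    torch.manual_seed(0)
    logits = torch.randn(10)                     # toy next-token logits
    probs = torch.softmax(logits / 0.8, dim=-1)  # temperature 0.8 sharpens the distribution
    topk_probs, topk_ids = probs.topk(5)         # keep only the 5 most likely tokens
    next_token = topk_ids[torch.multinomial(topk_probs, num_samples=1)]
    print(next_token.item())                     # always one of the top-5 token ids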