[Sampler] Vectorized sampling (simplified) (#1048)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Zhuohan Li 2023-09-22 17:48:04 -07:00 committed by GitHub
parent 8d926e91f1
commit 947b794146
3 changed files with 475 additions and 174 deletions

File 1 of 3 (new): unit tests for the sampler

@@ -0,0 +1,184 @@
import pytest
import random
from typing import Tuple
from unittest.mock import patch
import torch
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker
class MockLogitsSampler(Sampler):
def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size)
self.fake_logits = fake_logits
def forward(self, *args, **kwargs):
with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x):
with patch("vllm.model_executor.layers.sampler._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device="cuda",
dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
worker = Worker(None, None, None)
worker.block_size = 16
return input_tensor, fake_logits, sampler, worker
RANDOM_SEEDS = list(range(128))
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0, ),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token == expected[i].item()
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
for i in range(batch_size):
fake_logits[i, i] = 1e2
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token == i
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
    # No assertion here: verifying beam-search outputs is nontrivial, so this
    # test only checks that the sampler runs without raising an exception on
    # an all-beam-search batch.
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
expected_tokens = []
for i in range(batch_size):
n = 1
sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
elif sampling_type == 1:
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
top_p=min(random.random() + 0.1, 1),
top_k=random.randint(0, 10) or -1,
n=n,
presence_penalty=random.randint(0, 1),
)
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected_tokens.append(i + idx)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output:
assert nth_output.output_token in expected_tokens
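
A note on how these tests get deterministic expectations: the fake logits are uniform at 1e-2, and the random/mixed tests boost one entry per row to 1e2, so both greedy argmax and temperature-1.0 multinomial sampling should return the boosted index. A standalone sketch of that property (not part of the commit; toy batch size, CPU tensors):

# Standalone sketch (not part of the commit): why boosting one logit per row
# to 1e2 makes the expected token effectively deterministic.
import torch

vocab_size = 32000
batch_size = 4
fake_logits = torch.full((batch_size, vocab_size), 1e-2)
for i in range(batch_size):
    fake_logits[i, i] = 1e2  # one dominant token per row

# Greedy: argmax picks the boosted index.
assert torch.argmax(fake_logits, dim=-1).tolist() == [0, 1, 2, 3]

# Random sampling at temperature 1.0: softmax puts essentially all mass on
# the boosted index, so multinomial returns it with overwhelming probability.
probs = torch.softmax(fake_logits, dim=-1)
samples = torch.multinomial(probs, num_samples=1).squeeze(-1)
assert samples.tolist() == [0, 1, 2, 3]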

File 2 of 3 (modified): vllm.model_executor.layers.sampler

@@ -1,15 +1,14 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple

-import numpy as np
 import torch
 import torch.nn as nn

 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     gather_from_tensor_model_parallel_region)
-from vllm.sampling_params import SamplingParams
-from vllm.sequence import SamplerOutput, SequenceOutputs
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs

 _SAMPLING_EPS = 1e-5
@@ -44,12 +43,8 @@ class Sampler(nn.Module):
         hidden_states = _prune_hidden_states(hidden_states, input_metadata)

         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = gather_from_tensor_model_parallel_region(logits)
-        # Remove paddings in vocab (if any).
-        logits = logits[:, :self.vocab_size]
+        logits = _get_logits(hidden_states, embedding, embedding_bias,
+                             self.vocab_size)

         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
@@ -59,7 +54,7 @@ class Sampler(nn.Module):
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties, self.vocab_size)
+                                  frequency_penalties)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -90,19 +85,47 @@ class Sampler(nn.Module):
         return _sample(probs, logprobs, input_metadata)


+def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
+                embedding_bias: Optional[torch.Tensor],
+                vocab_size: int) -> torch.Tensor:
+    # Get the logits for the next tokens.
+    logits = torch.matmul(hidden_states, embedding.t())
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = gather_from_tensor_model_parallel_region(logits)
+    # Remove paddings in vocab (if any).
+    logits = logits[:, :vocab_size]
+    return logits
+
+
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
+    last_token_indices = {t: [] for t in SamplingType}
     start_idx = 0
-    last_token_indicies: List[int] = []
-    for prompt_len in input_metadata.prompt_lens:
-        last_token_indicies.append(start_idx + prompt_len - 1)
-        start_idx += prompt_len
-    last_token_indicies.extend(
-        range(start_idx, start_idx + input_metadata.num_generation_tokens))
-    return hidden_states.index_select(
-        0, torch.tensor(last_token_indicies, device=hidden_states.device))
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        sampling_type = sampling_params.sampling_type
+        if i < input_metadata.num_prompts:
+            assert len(seq_ids) == 1, "Prompt input should have only one seq."
+            prompt_len = input_metadata.prompt_lens[i]
+            last_token_indices[sampling_type].append(start_idx + prompt_len -
+                                                     1)
+            start_idx += prompt_len
+        else:
+            num_seqs = len(seq_ids)
+            last_token_indices[sampling_type].extend(
+                range(start_idx, start_idx + num_seqs))
+            start_idx += num_seqs
+
+    all_last_token_indices = []
+    for sampling_type in SamplingType:
+        all_last_token_indices.extend(last_token_indices[sampling_type])
+    all_last_token_indices = torch.tensor(all_last_token_indices,
+                                          dtype=torch.long,
+                                          device=hidden_states.device)
+    return hidden_states.index_select(0, all_last_token_indices)


 def _get_penalties(
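
The rewritten _prune_hidden_states above does more than pick last-token rows: it orders them by SamplingType so that each category occupies one contiguous block, which is what lets _sample later slice probs and logprobs with plain ranges. A standalone sketch of that reordering (not from the commit; toy tensor and a made-up index layout):

# Standalone sketch (not from the commit): grouping row indices by sampling
# type before index_select makes each category a contiguous slice.
from enum import IntEnum
import torch


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    BEAM = 2


hidden = torch.arange(6, dtype=torch.float32).unsqueeze(-1)  # 6 rows
# Suppose rows 0, 3 are greedy, rows 1, 4 are random, rows 2, 5 are beam.
by_type = {
    SamplingType.GREEDY: [0, 3],
    SamplingType.RANDOM: [1, 4],
    SamplingType.BEAM: [2, 5],
}
order = [i for t in SamplingType for i in by_type[t]]
pruned = hidden.index_select(0, torch.tensor(order))
# Rows are now [greedy..., random..., beam...]; downstream code can slice
# logits/probs per category with plain start:start+num_tokens ranges.
print(pruned.squeeze(-1).tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]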
@@ -149,11 +172,8 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
-    vocab_size: int,
 ) -> torch.Tensor:
-    num_seqs = logits.shape[0]
-    # Collect the indices of sequences that have non-zero penalties.
-    indices = []
+    num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
         if not output_tokens[i]:
             continue
@@ -161,33 +181,40 @@ def _apply_penalties(
         f = frequency_penalties[i]
         if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
             continue
-        indices.append(i)
-
-    # Return early if all sequences have zero penalties.
-    if not indices:
+        break
+    else:
+        # Return early if all sequences have zero penalties.
         return logits

-    bin_counts = []
-    for i in indices:
-        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
-    bin_counts = np.stack(bin_counts, axis=0)
-    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
-                                                 device=logits.device)
+    max_output_len = max(len(tokens) for tokens in output_tokens)
+    padded_output_tokens = [
+        tokens + [vocab_size] * (max_output_len - len(tokens))
+        for tokens in output_tokens
+    ]
+    output_tokens_tensor = torch.tensor(padded_output_tokens,
+                                        dtype=torch.long,
+                                        device=logits.device)
+
+    # Compute the bin counts for the output tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, output_tokens_tensor,
+                            torch.ones_like(output_tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.

-    frequency_penalties = [frequency_penalties[i] for i in indices]
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)
-    presence_penalties = [presence_penalties[i] for i in indices]
     presence_penalties = torch.tensor(presence_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
-    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
     return logits
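
The new _apply_penalties replaces the per-row np.bincount loop with a single scatter_add_ on a (num_seqs, vocab_size + 1) tensor, using vocab_size as a padding id whose bin is dropped afterwards. A standalone sketch (not from the commit; toy vocabulary and token lists) that cross-checks the trick against torch.bincount:

# Standalone sketch (not from the commit): per-row token counts via padding
# plus scatter_add_, equivalent to a per-row bincount without a Python loop.
import torch

vocab_size = 8
output_tokens = [[1, 1, 3], [2], [5, 5, 5, 7]]
num_seqs = len(output_tokens)

max_len = max(len(t) for t in output_tokens)
# Pad with the sentinel id `vocab_size`; it lands in an extra bin we discard.
padded = [t + [vocab_size] * (max_len - len(t)) for t in output_tokens]
tokens = torch.tensor(padded, dtype=torch.long)

bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long)
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin

# Cross-check against torch.bincount applied row by row.
for row, toks in zip(bin_counts, output_tokens):
    expected = torch.bincount(torch.tensor(toks), minlength=vocab_size)
    assert torch.equal(row, expected)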
@@ -268,95 +295,154 @@ def _apply_top_p_top_k(
 def _get_topk_logprobs(
     logprobs: torch.Tensor,
     num_logprobs: Optional[int],
-) -> Dict[int, float]:
+) -> List[Dict[int, float]]:
+    num_seqs = logprobs.size(0)
     if num_logprobs is None or num_logprobs == 0:
-        return {}
+        return [{} for _ in range(num_seqs)]

-    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
-    if num_logprobs == 1:
-        topk_logprobs = [topk_logprobs.item()]
-        topk_ids = [topk_ids.item()]
-    else:
-        topk_logprobs = topk_logprobs.tolist()
-        topk_ids = topk_ids.tolist()
-
-    token_to_logprob: Dict[int, float] = {}
-    for token_id, logprob in zip(topk_ids, topk_logprobs):
-        token_to_logprob[token_id] = logprob
-    return token_to_logprob
+    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
+                                                 num_logprobs,
+                                                 dim=-1)
+    all_topk_logprobs = all_topk_logprobs.cpu()
+    all_topk_ids = all_topk_ids.cpu()
+    all_token_to_logprob = []
+    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
+        token_to_logprob: Dict[int, float] = {}
+        for token_id, logprob in zip(topk_ids, topk_logprobs):
+            token_to_logprob[token_id.item()] = logprob.item()
+        all_token_to_logprob.append(token_to_logprob)
+    return all_token_to_logprob


-def _sample_from_prompt(
-    prob: torch.Tensor,
-    sampling_params: SamplingParams,
-) -> List[int]:
-    if sampling_params.use_beam_search:
-        # Beam search.
-        beam_width = sampling_params.best_of
-        # Sample 2 * beam_width candidates to make sure that with high
-        # probability we can get `beam_width` candidates in addition to
-        # the finished sequences for the next iteration. See
-        # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
-        # for details. See also HF reference:
-        # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
-        _, next_token_ids = torch.topk(prob, 2 * beam_width)
-        next_token_ids = next_token_ids.tolist()
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert sampling_params.best_of == 1
-        next_token_id = torch.argmax(prob)
-        next_token_ids = [next_token_id.item()]
-    else:
-        # Random sampling.
-        # Sample `best_of` tokens for the prompt.
-        num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob,
-                                           num_samples=num_seqs,
-                                           replacement=True)
-        next_token_ids = next_token_ids.tolist()
-    return next_token_ids
+def _build_sequence_outputs(
+    parent_ids: List[int],
+    next_token_ids: List[int],
+    selected_token_logprobs: torch.Tensor,
+    parent_seq_ids: List[int],
+    parent_logprobs: torch.Tensor,
+    num_output_logprobs: Optional[int],
+) -> List[SequenceOutputs]:
+    # Get top-k log probabilities for the next tokens.
+    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
+    seq_outputs: List[SequenceOutputs] = []
+    for parent_id, next_token_id, token_logprob in zip(
+            parent_ids, next_token_ids, selected_token_logprobs):
+        output_logprobs = next_logprobs[parent_id].copy()
+        output_logprobs[next_token_id] = token_logprob
+        seq_outputs.append(
+            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
+                            output_logprobs))
+    return seq_outputs


-def _sample_from_generation_tokens(
-    seq_ids: List[int],
-    probs: torch.Tensor,
-    logprobs: torch.Tensor,
-    seq_logprobs: List[float],
-    sampling_params: SamplingParams,
-) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.best_of can be greater than
-    # len(seq_ids) because some sequences in the group might have
-    # been already terminated.
-    if sampling_params.use_beam_search:
-        # Beam search.
-        # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(seq_logprobs,
-                                    dtype=torch.float,
-                                    device=logprobs.device)
-        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
-
-        vocab_size = logprobs.size(-1)
-        beam_width = len(seq_ids)
-        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
-        topk_ids = topk_ids.tolist()
-        seq_idx = [i // vocab_size for i in topk_ids]
-        parent_seq_ids = [seq_ids[i] for i in seq_idx]
-        next_token_ids = [i % vocab_size for i in topk_ids]
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert len(seq_ids) == 1
-        next_token_id = torch.argmax(probs, dim=-1)
-        next_token_ids = [int(next_token_id.item())]
-        parent_seq_ids = seq_ids
-    else:
-        # Random sampling.
-        # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples=1,
-                                           replacement=True)
-        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
-        parent_seq_ids = seq_ids
-    return parent_seq_ids, next_token_ids
+def _greedy_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    logprobs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    samples = torch.argmax(logprobs, dim=-1).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group in selected_seq_groups:
+        seq_ids, _ = seq_group
+        num_parent_seqs = len(seq_ids)
+        assert num_parent_seqs == 1, (
+            "Greedy sampling should have only one seq.")
+        parent_ids = list(range(num_parent_seqs))
+        next_token_ids = [samples[sample_idx].item()]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results
+
+
+def _random_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
+    probs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    # Find the maximum best_of value of the prompt phase requests.
+    max_best_of = 1
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        if is_prompt:
+            seq_ids, sampling_params = seq_group
+            max_best_of = max(max_best_of, sampling_params.best_of)
+    random_samples = torch.multinomial(probs,
+                                       num_samples=max_best_of,
+                                       replacement=True).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * sampling_params.best_of
+            next_token_ids = random_samples[
+                sample_idx, :sampling_params.best_of].tolist()
+        else:
+            # Generation phase.
+            parent_ids = list(range(num_parent_seqs))
+            next_token_ids = random_samples[sample_idx:sample_idx +
+                                            num_parent_seqs, 0].tolist()
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == probs.size(0)
+    return results
+
+
+def _beam_search_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
+    seq_data: Dict[int, SequenceData],
+    logprobs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    # We sample 2 * beam_width candidates to make sure that with high
+    # probability we can get `beam_width` candidates in addition to
+    # the finished sequences for the next iteration. See
+    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
+    # for details. See also HF reference:
+    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
+    #
+    # Note: Beam search is not vectorized, so its speed can be slower than
+    # other sampling methods.
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        beam_width = sampling_params.best_of
+        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * (2 * beam_width)
+            _, next_token_ids = torch.topk(seq_group_logprobs[0],
+                                           2 * beam_width)
+            next_token_ids = next_token_ids.tolist()
+        else:
+            # Generation phase.
+            cumulative_logprobs = [
+                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
+            ]
+            cumulative_logprobs = torch.tensor(
+                cumulative_logprobs,
+                dtype=torch.float,
+                device=seq_group_logprobs.device)
+            seq_group_logprobs = (seq_group_logprobs +
+                                  cumulative_logprobs.unsqueeze(dim=1))
+            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
+                                     2 * beam_width)
+            topk_ids = topk_ids.tolist()
+            vocab_size = seq_group_logprobs.size(-1)
+            parent_ids = [i // vocab_size for i in topk_ids]
+            next_token_ids = [i % vocab_size for i in topk_ids]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results


 def _sample(
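
In _beam_search_sample above, the generation phase runs one top-k over the flattened (num_parent_seqs, vocab_size) logprobs and then splits each flat index back into a parent row and a token id with // and %. A standalone sketch of that decomposition (not from the commit; toy logprob values):

# Standalone sketch (not from the commit): how flattened top-k indices are
# decomposed into (parent sequence, token) pairs in the beam-search branch.
import torch

vocab_size = 5
# Cumulative-logprob-adjusted logprobs for 2 parent beams.
seq_group_logprobs = torch.tensor([
    [-1.0, -9.0, -0.5, -9.0, -9.0],   # parent 0
    [-0.2, -9.0, -9.0, -3.0, -9.0],   # parent 1
])
beam_width = 2
_, topk_ids = torch.topk(seq_group_logprobs.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()
parent_ids = [i // vocab_size for i in topk_ids]
next_token_ids = [i % vocab_size for i in topk_ids]
# Flat index 5 -> parent 1 token 0 (-0.2), 2 -> parent 0 token 2 (-0.5),
# 0 -> parent 0 token 0 (-1.0), 8 -> parent 1 token 3 (-3.0).
print(list(zip(parent_ids, next_token_ids)))  # [(1, 0), (0, 2), (0, 0), (1, 3)]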
@@ -364,65 +450,80 @@ def _sample(
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
-    seq_outputs: SamplerOutput = []
-
-    # TODO(woosuk): Optimize.
-    idx = 0
-    for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_group_outputs: List[SequenceOutputs] = []
-        seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            parent_seq_id = seq_ids[0]
-            prob = probs[idx]
-            logprob = logprobs[idx]
-            idx += 1
-
-            # Sample the next tokens.
-            next_token_ids = _sample_from_prompt(prob, sampling_params)
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(logprob,
-                                               sampling_params.logprobs)
-
-            # Build the output.
-            for next_token_id in next_token_ids:
-                output_logprobs = next_logprobs.copy()
-                output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_group_outputs.append(
-                    SequenceOutputs(parent_seq_id, next_token_id,
-                                    output_logprobs))
-        else:
-            # Generate the next tokens for generation tokens.
-            num_parent_seqs = len(seq_ids)
-            prob = probs[idx:idx + num_parent_seqs]
-            logprob = logprobs[idx:idx + num_parent_seqs]
-            idx += num_parent_seqs
-
-            # Sample the next tokens.
-            seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
-                seq_ids, prob, logprob, seq_logprobs, sampling_params)
-
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs: Dict[int, Dict[int, float]] = {}
-            for j, seq_id in enumerate(seq_ids):
-                next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[j], sampling_params.logprobs)
-
-            # Build the output.
-            for parent_seq_id, next_token_id in zip(parent_seq_ids,
-                                                    next_token_ids):
-                j = seq_ids.index(parent_seq_id)
-                output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[j,
-                                                         next_token_id].item()
-                seq_group_outputs.append(
-                    SequenceOutputs(parent_seq_id, next_token_id,
-                                    output_logprobs))
-        seq_outputs.append(seq_group_outputs)
-    return seq_outputs
+    categorized_seq_group_ids = {t: [] for t in SamplingType}
+    category_num_tokens = {t: 0 for t in SamplingType}
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        sampling_type = sampling_params.sampling_type
+        categorized_seq_group_ids[sampling_type].append(i)
+        num_seqs = len(seq_ids)
+        category_num_tokens[sampling_type] += num_seqs
+
+    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
+    category_start_idx = 0
+    for sampling_type in SamplingType:
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
+        num_tokens = category_num_tokens[sampling_type]
+        if num_tokens == 0:
+            continue
+        category_logprobs = logprobs[category_start_idx:category_start_idx +
+                                     num_tokens]
+        category_probs = probs[category_start_idx:category_start_idx +
+                               num_tokens]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, category_logprobs)
+        elif sampling_type == SamplingType.RANDOM:
+            sample_results = _random_sample(seq_groups, is_prompts,
+                                            category_probs)
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 input_metadata.seq_data,
+                                                 category_logprobs)
+        else:
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+        # Batched query for logprobs of selected token
+        batched_logprobs_query_seq_indices: List[int] = []
+        batched_logprobs_query_token_indices: List[int] = []
+        sample_idx = 0
+        for seq_group_id, seq_group, sample_result in zip(
+                seq_group_ids, seq_groups, sample_results):
+            seq_ids, sampling_params = seq_group
+            next_token_ids, parent_ids = sample_result
+            num_parent_seqs = len(seq_ids)
+            batched_logprobs_query_seq_indices.extend(
+                [sample_idx + parent_id for parent_id in parent_ids])
+            batched_logprobs_query_token_indices.extend(next_token_ids)
+            sample_idx += num_parent_seqs
+        assert sample_idx == num_tokens
+        batched_logprobs_query_result = category_logprobs[[
+            batched_logprobs_query_seq_indices,
+            batched_logprobs_query_token_indices
+        ]].tolist()
+
+        # Build the sequence outputs.
+        sample_idx = 0
+        result_idx = 0
+        for seq_group_id, seq_group, sample_result in zip(
+                seq_group_ids, seq_groups, sample_results):
+            seq_ids, sampling_params = seq_group
+            next_token_ids, parent_ids = sample_result
+            num_results = len(next_token_ids)
+            num_parent_seqs = len(seq_ids)
+            parent_logprobs = category_logprobs[sample_idx:sample_idx +
+                                                num_parent_seqs]
+            selected_token_logprobs = batched_logprobs_query_result[
+                result_idx:result_idx + num_results]
+            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
+                                                 selected_token_logprobs,
+                                                 seq_ids, parent_logprobs,
+                                                 sampling_params.logprobs)
+            seq_outputs_dict[seq_group_id] = seq_output
+            sample_idx += num_parent_seqs
+            result_idx += num_results
+        assert sample_idx == num_tokens
+        category_start_idx += num_tokens
+
+    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
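
The batched logprobs query in the new _sample fetches the logprob of every sampled token with one advanced-indexing call on the category slice rather than per-token .item() lookups. A standalone sketch (not from the commit; toy shapes and made-up indices) comparing it with the element-wise form:

# Standalone sketch (not from the commit): the batched logprob query gathers
# the logprob of each sampled token in one advanced-indexing call.
import torch

torch.manual_seed(0)
logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)  # 4 rows, vocab 10
seq_indices = [0, 1, 1, 3]    # row (parent) of each sampled token
token_indices = [7, 2, 9, 0]  # sampled token id for each result

batched = logprobs[seq_indices, token_indices]   # pairwise gather
listed = logprobs[[seq_indices, token_indices]]  # list form used in the diff
assert torch.equal(batched, listed)

expected = [logprobs[s, t].item() for s, t in zip(seq_indices, token_indices)]
assert batched.tolist() == expected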

File 3 of 3 (modified): vllm.sampling_params

@@ -1,9 +1,17 @@
 """Sampling parameters for text generation."""
+from enum import IntEnum
+from functools import cached_property
 from typing import List, Optional, Union

 _SAMPLING_EPS = 1e-5


+class SamplingType(IntEnum):
+    GREEDY = 0
+    RANDOM = 1
+    BEAM = 2
+
+
 class SamplingParams:
     """Sampling parameters for text generation.
@@ -166,6 +174,14 @@ class SamplingParams:
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using greedy sampling.")

+    @cached_property
+    def sampling_type(self) -> SamplingType:
+        if self.use_beam_search:
+            return SamplingType.BEAM
+        if self.temperature < _SAMPLING_EPS:
+            return SamplingType.GREEDY
+        return SamplingType.RANDOM
+
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
                 f"best_of={self.best_of}, "