[Sampler] Vectorized sampling (simplified) (#1048)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Zhuohan Li 2023-09-22 17:48:04 -07:00 committed by GitHub
parent 8d926e91f1
commit 947b794146
3 changed files with 475 additions and 174 deletions


@@ -0,0 +1,184 @@
import pytest
import random
from typing import Tuple
from unittest.mock import patch
import torch
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker
class MockLogitsSampler(Sampler):
def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
super().__init__(vocab_size=vocab_size)
self.fake_logits = fake_logits
def forward(self, *args, **kwargs):
with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
lambda x, y: x):
with patch("vllm.model_executor.layers.sampler._get_logits",
lambda *args, **kwargs: self.fake_logits):
return super().forward(*args, **kwargs)
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device="cuda",
dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
worker = Worker(None, None, None)
worker.block_size = 16
return input_tensor, fake_logits, sampler, worker
RANDOM_SEEDS = list(range(128))
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0, ),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token == expected[i].item()
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
for i in range(batch_size):
fake_logits[i, i] = 1e2
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output:
assert nth_output.output_token == i
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
),
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
# No assertion here: it is hard to verify beam-search outputs
# deterministically, so this test only checks that the sampler
# runs without raising an exception on an all-beam-search batch.
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
expected_tokens = []
for i in range(batch_size):
n = 1
sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
elif sampling_type == 1:
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
top_p=min(random.random() + 0.1, 1),
top_k=random.randint(0, 10) or -1,
n=n,
presence_penalty=random.randint(0, 1),
)
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected_tokens.append(i + idx)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=sampling_params,
block_tables={0: [1]},
))
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
input_metadata=input_metadata)
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output:
assert nth_output.output_token in expected_tokens


@@ -1,15 +1,14 @@
"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Tuple, Optional
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import (
gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceOutputs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
_SAMPLING_EPS = 1e-5
@@ -44,12 +43,8 @@ class Sampler(nn.Module):
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :self.vocab_size]
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
@@ -59,7 +54,7 @@ class Sampler(nn.Module):
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size)
frequency_penalties)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
@@ -90,19 +85,47 @@ class Sampler(nn.Module):
return _sample(probs, logprobs, input_metadata)
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor],
vocab_size: int) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :vocab_size]
return logits
def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
last_token_indices = {t: [] for t in SamplingType}
start_idx = 0
last_token_indicies: List[int] = []
for prompt_len in input_metadata.prompt_lens:
last_token_indicies.append(start_idx + prompt_len - 1)
start_idx += prompt_len
last_token_indicies.extend(
range(start_idx, start_idx + input_metadata.num_generation_tokens))
return hidden_states.index_select(
0, torch.tensor(last_token_indicies, device=hidden_states.device))
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
if i < input_metadata.num_prompts:
assert len(seq_ids) == 1, "Prompt input should have only one seq."
prompt_len = input_metadata.prompt_lens[i]
last_token_indices[sampling_type].append(start_idx + prompt_len -
1)
start_idx += prompt_len
else:
num_seqs = len(seq_ids)
last_token_indices[sampling_type].extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs
all_last_token_indices = []
for sampling_type in SamplingType:
all_last_token_indices.extend(last_token_indices[sampling_type])
all_last_token_indices = torch.tensor(all_last_token_indices,
dtype=torch.long,
device=hidden_states.device)
return hidden_states.index_select(0, all_last_token_indices)
def _get_penalties(
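The reworked _prune_hidden_states above is the heart of the vectorization: instead of collecting last-token indices in request order, it buckets them by SamplingType so that each category (greedy, random, beam) ends up as one contiguous block of rows that _sample can slice and process in a single batched call. Below is a minimal standalone sketch of that idea, reduced to the prompt-only case; reorder_last_tokens is a hypothetical name, and SamplingType simply mirrors the enum added in vllm/sampling_params.py later in this diff.

# Illustrative sketch only, not part of this diff.
from enum import IntEnum
from typing import Dict, List

import torch


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    BEAM = 2


def reorder_last_tokens(hidden_states: torch.Tensor,
                        prompt_lens: List[int],
                        types: List[SamplingType]) -> torch.Tensor:
    """Select each prompt's last hidden state, grouped by sampling type."""
    buckets: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType}
    start = 0
    for prompt_len, sampling_type in zip(prompt_lens, types):
        buckets[sampling_type].append(start + prompt_len - 1)
        start += prompt_len
    order = [idx for t in SamplingType for idx in buckets[t]]
    return hidden_states.index_select(
        0, torch.tensor(order, device=hidden_states.device))


# Three prompts of lengths 2, 3, 2 with mixed sampling types: the greedy row
# comes first, then the random row, then the beam row.
hidden = torch.arange(7, dtype=torch.float32).unsqueeze(1)  # [7, 1]
out = reorder_last_tokens(
    hidden, [2, 3, 2],
    [SamplingType.RANDOM, SamplingType.GREEDY, SamplingType.BEAM])
print(out.squeeze(1).tolist())  # [4.0, 1.0, 6.0]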
@@ -149,11 +172,8 @@ def _apply_penalties(
output_tokens: List[List[int]],
presence_penalties: List[float],
frequency_penalties: List[float],
vocab_size: int,
) -> torch.Tensor:
num_seqs = logits.shape[0]
# Collect the indices of sequences that have non-zero penalties.
indices = []
num_seqs, vocab_size = logits.shape
for i in range(num_seqs):
if not output_tokens[i]:
continue
@@ -161,33 +181,40 @@ def _apply_penalties(
f = frequency_penalties[i]
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
continue
indices.append(i)
# Return early if all sequences have zero penalties.
if not indices:
break
else:
# Return early if all sequences have zero penalties.
return logits
bin_counts = []
for i in indices:
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
bin_counts = np.stack(bin_counts, axis=0)
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
device=logits.device)
max_output_len = max(len(tokens) for tokens in output_tokens)
padded_output_tokens = [
tokens + [vocab_size] * (max_output_len - len(tokens))
for tokens in output_tokens
]
output_tokens_tensor = torch.tensor(padded_output_tokens,
dtype=torch.long,
device=logits.device)
# Compute the bin counts for the output tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=logits.device)
bin_counts.scatter_add_(1, output_tokens_tensor,
torch.ones_like(output_tokens_tensor))
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
return logits
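The rewritten _apply_penalties above replaces the per-sequence NumPy bincount with one batched scatter_add_: every sequence's output tokens are padded to the same length with the id vocab_size, counts are accumulated into a (num_seqs, vocab_size + 1) tensor, the extra padding bin is dropped, and the OpenAI-style penalties (logit -= frequency_penalty * count + presence_penalty * [count > 0]) are applied with broadcasting. The following is a minimal sketch under those assumptions; apply_penalties_sketch is a hypothetical standalone helper, not the vLLM function.

# Minimal illustrative sketch, not part of this diff.
from typing import List

import torch


def apply_penalties_sketch(logits: torch.Tensor,
                           output_tokens: List[List[int]],
                           presence: List[float],
                           frequency: List[float]) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    max_len = max(len(toks) for toks in output_tokens)
    # Pad every token list with the id `vocab_size`, which lands in an extra
    # bin that is discarded afterwards.
    padded = torch.tensor(
        [toks + [vocab_size] * (max_len - len(toks)) for toks in output_tokens],
        dtype=torch.long, device=logits.device)
    counts = torch.zeros((num_seqs, vocab_size + 1),
                         dtype=torch.long, device=logits.device)
    counts.scatter_add_(1, padded, torch.ones_like(padded))
    counts = counts[:, :vocab_size]  # drop the padding bin
    freq = torch.tensor(frequency, dtype=logits.dtype, device=logits.device)
    pres = torch.tensor(presence, dtype=logits.dtype, device=logits.device)
    logits = logits - freq.unsqueeze(1) * counts
    logits = logits - pres.unsqueeze(1) * (counts > 0)
    return logits


# Two sequences over a toy vocab of 4 tokens. In row 0, token 0 appeared
# twice and token 2 once, so both get frequency and presence penalties.
out = apply_penalties_sketch(torch.zeros(2, 4),
                             output_tokens=[[0, 0, 2], [3]],
                             presence=[0.5, 0.0],
                             frequency=[1.0, 2.0])
print(out.tolist())  # [[-2.5, 0.0, -1.5, 0.0], [0.0, 0.0, 0.0, -2.0]]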
@@ -268,95 +295,154 @@ def _apply_top_p_top_k(
def _get_topk_logprobs(
logprobs: torch.Tensor,
num_logprobs: Optional[int],
) -> Dict[int, float]:
) -> List[Dict[int, float]]:
num_seqs = logprobs.size(0)
if num_logprobs is None or num_logprobs == 0:
return {}
return [{} for _ in range(num_seqs)]
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
if num_logprobs == 1:
topk_logprobs = [topk_logprobs.item()]
topk_ids = [topk_ids.item()]
else:
topk_logprobs = topk_logprobs.tolist()
topk_ids = topk_ids.tolist()
token_to_logprob: Dict[int, float] = {}
for token_id, logprob in zip(topk_ids, topk_logprobs):
token_to_logprob[token_id] = logprob
return token_to_logprob
all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
num_logprobs,
dim=-1)
all_topk_logprobs = all_topk_logprobs.cpu()
all_topk_ids = all_topk_ids.cpu()
all_token_to_logprob = []
for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
token_to_logprob: Dict[int, float] = {}
for token_id, logprob in zip(topk_ids, topk_logprobs):
token_to_logprob[token_id.item()] = logprob.item()
all_token_to_logprob.append(token_to_logprob)
return all_token_to_logprob
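_get_topk_logprobs now returns one dict per sequence from a single torch.topk over the whole [num_seqs, vocab_size] matrix rather than looping with per-sequence topk and item() calls. A tiny sketch of that pattern, with made-up toy sizes:

# Illustrative only, not part of this diff.
import torch

logprobs = torch.log_softmax(torch.randn(3, 6), dim=-1)  # [num_seqs, vocab]
topk_vals, topk_ids = torch.topk(logprobs, k=2, dim=-1)
per_seq = [dict(zip(ids.tolist(), vals.tolist()))
           for ids, vals in zip(topk_ids, topk_vals)]
# per_seq[0] maps sequence 0's two most likely token ids to their logprobs.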
def _sample_from_prompt(
prob: torch.Tensor,
sampling_params: SamplingParams,
) -> List[int]:
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.best_of
# Sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
_, next_token_ids = torch.topk(prob, 2 * beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling.
assert sampling_params.best_of == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
# Random sampling.
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(prob,
num_samples=num_seqs,
replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
def _build_sequence_outputs(
parent_ids: List[int],
next_token_ids: List[int],
selected_token_logprobs: torch.Tensor,
parent_seq_ids: List[int],
parent_logprobs: torch.Tensor,
num_output_logprobs: Optional[int],
) -> List[SequenceOutputs]:
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
seq_outputs: List[SequenceOutputs] = []
for parent_id, next_token_id, token_logprob in zip(
parent_ids, next_token_ids, selected_token_logprobs):
output_logprobs = next_logprobs[parent_id].copy()
output_logprobs[next_token_id] = token_logprob
seq_outputs.append(
SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
output_logprobs))
return seq_outputs
def _sample_from_generation_tokens(
seq_ids: List[int],
probs: torch.Tensor,
def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
logprobs: torch.Tensor,
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
# NOTE(woosuk): sampling_params.best_of can be greater than
# len(seq_ids) because some sequences in the group might have
# been already terminated.
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(seq_logprobs,
dtype=torch.float,
device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
) -> List[Tuple[List[int], List[int]]]:
samples = torch.argmax(logprobs, dim=-1).cpu()
sample_idx = 0
results = []
for seq_group in selected_seq_groups:
seq_ids, _ = seq_group
num_parent_seqs = len(seq_ids)
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples[sample_idx].item()]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
return results
vocab_size = logprobs.size(-1)
beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()
seq_idx = [i // vocab_size for i in topk_ids]
parent_seq_ids = [seq_ids[i] for i in seq_idx]
next_token_ids = [i % vocab_size for i in topk_ids]
elif sampling_params.temperature < _SAMPLING_EPS:
# Greedy sampling.
assert len(seq_ids) == 1
next_token_id = torch.argmax(probs, dim=-1)
next_token_ids = [int(next_token_id.item())]
parent_seq_ids = seq_ids
else:
# Random sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids
def _random_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
probs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
# Find the maximum best_of value of the prompt phase requests.
max_best_of = 1
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
if is_prompt:
seq_ids, sampling_params = seq_group
max_best_of = max(max_best_of, sampling_params.best_of)
random_samples = torch.multinomial(probs,
num_samples=max_best_of,
replacement=True).cpu()
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
assert num_parent_seqs == 1, (
"Prompt input should have only one seq.")
parent_ids = [0] * sampling_params.best_of
next_token_ids = random_samples[
sample_idx, :sampling_params.best_of].tolist()
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
next_token_ids = random_samples[sample_idx:sample_idx +
num_parent_seqs, 0].tolist()
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
assert sample_idx == probs.size(0)
return results
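_random_sample above batches all multinomial draws into one call: it samples max(best_of) candidates for every row, then each prompt keeps its first best_of draws while each generation-phase sequence keeps only column 0. A small sketch of the trick with hypothetical toy values:

# Toy sketch, not vLLM code.
import torch

probs = torch.softmax(torch.randn(3, 8), dim=-1)  # 3 rows over a toy vocab
best_of = [3, 1, 1]            # samples actually needed per row
draws = torch.multinomial(probs, num_samples=max(best_of), replacement=True)
picked = [draws[i, :n].tolist() for i, n in enumerate(best_of)]
# picked[0] holds 3 candidate token ids for the first prompt; the other rows
# keep a single draw each, all from one multinomial call for the whole batch.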
def _beam_search_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
seq_data: Dict[int, SequenceData],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
#
# Note: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
if is_prompt:
# Prompt phase.
assert num_parent_seqs == 1, (
"Prompt input should have only one seq.")
parent_ids = [0] * (2 * beam_width)
_, next_token_ids = torch.topk(seq_group_logprobs[0],
2 * beam_width)
next_token_ids = next_token_ids.tolist()
else:
# Generation phase.
cumulative_logprobs = [
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
]
cumulative_logprobs = torch.tensor(
cumulative_logprobs,
dtype=torch.float,
device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs +
cumulative_logprobs.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width)
topk_ids = topk_ids.tolist()
vocab_size = seq_group_logprobs.size(-1)
parent_ids = [i // vocab_size for i in topk_ids]
next_token_ids = [i % vocab_size for i in topk_ids]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
return results
def _sample(
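The generation-phase branch of _beam_search_sample adds each parent beam's cumulative logprob, flattens the (num_parents, vocab_size) score matrix, takes the top 2 * beam_width entries, and recovers (parent, token) pairs with integer division and modulo. A toy sketch of that single step, with made-up numbers:

# Illustrative values only, not vLLM code.
import torch

beam_width, vocab_size = 2, 5
logprobs = torch.log_softmax(torch.randn(beam_width, vocab_size), dim=-1)
cumulative = torch.tensor([-1.0, -2.5])  # running scores of the parent beams

scores = logprobs + cumulative.unsqueeze(1)          # [beam_width, vocab_size]
topk_ids = torch.topk(scores.flatten(), 2 * beam_width).indices.tolist()
parent_ids = [i // vocab_size for i in topk_ids]     # which parent beam
next_token_ids = [i % vocab_size for i in topk_ids]  # which vocab token
print(list(zip(parent_ids, next_token_ids)))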
@@ -364,65 +450,80 @@ def _sample(
logprobs: torch.Tensor,
input_metadata: InputMetadata,
) -> SamplerOutput:
seq_outputs: SamplerOutput = []
# TODO(woosuk): Optimize.
idx = 0
categorized_seq_group_ids = {t: [] for t in SamplingType}
category_num_tokens = {t: 0 for t in SamplingType}
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_group_outputs: List[SequenceOutputs] = []
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == 1, "Prompt input should have only one seq."
parent_seq_id = seq_ids[0]
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
num_seqs = len(seq_ids)
category_num_tokens[sampling_type] += num_seqs
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(logprob,
sampling_params.logprobs)
# Build the output.
for next_token_id in next_token_ids:
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_group_outputs.append(
SequenceOutputs(parent_seq_id, next_token_id,
output_logprobs))
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
category_start_idx = 0
for sampling_type in SamplingType:
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
num_tokens = category_num_tokens[sampling_type]
if num_tokens == 0:
continue
category_logprobs = logprobs[category_start_idx:category_start_idx +
num_tokens]
category_probs = probs[category_start_idx:category_start_idx +
num_tokens]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, category_logprobs)
elif sampling_type == SamplingType.RANDOM:
sample_results = _random_sample(seq_groups, is_prompts,
category_probs)
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
input_metadata.seq_data,
category_logprobs)
else:
# Generate the next tokens for generation tokens.
raise ValueError(f"Unsupported sampling type: {sampling_type}")
# Batched query for the logprobs of the selected tokens.
batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = []
sample_idx = 0
for seq_group_id, seq_group, sample_result in zip(
seq_group_ids, seq_groups, sample_results):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
prob = probs[idx:idx + num_parent_seqs]
logprob = logprobs[idx:idx + num_parent_seqs]
idx += num_parent_seqs
batched_logprobs_query_seq_indices.extend(
[sample_idx + parent_id for parent_id in parent_ids])
batched_logprobs_query_token_indices.extend(next_token_ids)
sample_idx += num_parent_seqs
assert sample_idx == num_tokens
batched_logprobs_query_result = category_logprobs[[
batched_logprobs_query_seq_indices,
batched_logprobs_query_token_indices
]].tolist()
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids
]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Build the sequence outputs.
sample_idx = 0
result_idx = 0
for seq_group_id, seq_group, sample_result in zip(
seq_group_ids, seq_groups, sample_results):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_results = len(next_token_ids)
num_parent_seqs = len(seq_ids)
parent_logprobs = category_logprobs[sample_idx:sample_idx +
num_parent_seqs]
selected_token_logprobs = batched_logprobs_query_result[
result_idx:result_idx + num_results]
seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
selected_token_logprobs,
seq_ids, parent_logprobs,
sampling_params.logprobs)
seq_outputs_dict[seq_group_id] = seq_output
sample_idx += num_parent_seqs
result_idx += num_results
assert sample_idx == num_tokens
category_start_idx += num_tokens
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for j, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[j], sampling_params.logprobs)
# Build the output.
for parent_seq_id, next_token_id in zip(parent_seq_ids,
next_token_ids):
j = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[j,
next_token_id].item()
seq_group_outputs.append(
SequenceOutputs(parent_seq_id, next_token_id,
output_logprobs))
seq_outputs.append(seq_group_outputs)
return seq_outputs
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
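One detail worth highlighting in the rewritten _sample: the logprob of every sampled token is fetched with a single advanced-indexing gather over (row index, token index) pairs, replacing the old per-token logprob[j, next_token_id].item() lookups. A tiny illustration with made-up indices, not vLLM code:

# Illustrative only.
import torch

logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)  # [rows, vocab]
row_indices = [0, 0, 2, 3]       # source row of each sampled token
token_indices = [7, 1, 4, 9]     # sampled token ids
selected = logprobs[row_indices, token_indices].tolist()  # one gather
slow = [logprobs[r, t].item() for r, t in zip(row_indices, token_indices)]
assert all(abs(a - b) < 1e-6 for a, b in zip(selected, slow))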


@@ -1,9 +1,17 @@
"""Sampling parameters for text generation."""
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
BEAM = 2
class SamplingParams:
"""Sampling parameters for text generation.
@@ -166,6 +174,14 @@ class SamplingParams:
if self.top_k != -1:
raise ValueError("top_k must be -1 when using greedy sampling.")
@cached_property
def sampling_type(self) -> SamplingType:
if self.use_beam_search:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
return SamplingType.RANDOM
def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, "
f"best_of={self.best_of}, "