mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:06:03 +08:00

[Sampler] Vectorized sampling (simplified) (#1048)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>

This commit is contained in:
parent 8d926e91f1
commit 947b794146

tests/samplers/test_sampler.py  (new file, 184 lines)

@@ -0,0 +1,184 @@
import pytest
import random
from typing import Tuple
from unittest.mock import patch

import torch

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.worker import Worker


class MockLogitsSampler(Sampler):

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("vllm.model_executor.layers.sampler._prune_hidden_states",
                   lambda x, y: x):
            with patch("vllm.model_executor.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
                return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(32000, fake_logits)
    worker = Worker(None, None, None)
    worker.block_size = 16
    return input_tensor, fake_logits, sampler, worker


RANDOM_SEEDS = list(range(128))


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0, ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler(embedding=None,
            hidden_states=input_tensor,
            input_metadata=input_metadata)
    # no assertion here as I am not sure how to determine whether
    # the outputs are expected - in other words, this just tests
    # whether there are no exceptions in the sampler
    # when handling an all-beam search case.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output:
            assert nth_output.output_token in expected_tokens
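Note: the tests above drive the real Sampler.forward with deterministic logits by patching two module-level helpers, _prune_hidden_states and _get_logits, for the duration of the call. The pattern itself is plain unittest.mock.patch applied to the helper's import path; a minimal self-contained sketch of the same idea follows (the helper and function below are invented for illustration and are not part of vLLM):

from unittest.mock import patch


def _expensive_lookup(x):
    # Stand-in for a module-level helper that the production code calls.
    raise RuntimeError("should not run in tests")


def compute(x):
    # Production path: resolves _expensive_lookup through the module globals,
    # which is what makes it patchable by import path.
    return _expensive_lookup(x) + 1


# Patch the helper only for the duration of the call, just as
# MockLogitsSampler patches vllm.model_executor.layers.sampler._get_logits.
with patch(f"{__name__}._expensive_lookup", lambda x: 41):
    assert compute(1) == 42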
vllm/model_executor/layers/sampler.py

@@ -1,15 +1,14 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple

-import numpy as np
 import torch
 import torch.nn as nn

 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     gather_from_tensor_model_parallel_region)
-from vllm.sampling_params import SamplingParams
-from vllm.sequence import SamplerOutput, SequenceOutputs
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs

 _SAMPLING_EPS = 1e-5

@@ -44,12 +43,8 @@ class Sampler(nn.Module):
         hidden_states = _prune_hidden_states(hidden_states, input_metadata)

         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = gather_from_tensor_model_parallel_region(logits)
-        # Remove paddings in vocab (if any).
-        logits = logits[:, :self.vocab_size]
+        logits = _get_logits(hidden_states, embedding, embedding_bias,
+                             self.vocab_size)

         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
@@ -59,7 +54,7 @@ class Sampler(nn.Module):
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties, self.vocab_size)
+                                  frequency_penalties)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -90,19 +85,47 @@ class Sampler(nn.Module):
         return _sample(probs, logprobs, input_metadata)


+def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
+                embedding_bias: Optional[torch.Tensor],
+                vocab_size: int) -> torch.Tensor:
+    # Get the logits for the next tokens.
+    logits = torch.matmul(hidden_states, embedding.t())
+    if embedding_bias is not None:
+        logits += embedding_bias
+    logits = gather_from_tensor_model_parallel_region(logits)
+    # Remove paddings in vocab (if any).
+    logits = logits[:, :vocab_size]
+    return logits
+
+
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
+    last_token_indices = {t: [] for t in SamplingType}
     start_idx = 0
-    last_token_indicies: List[int] = []
-    for prompt_len in input_metadata.prompt_lens:
-        last_token_indicies.append(start_idx + prompt_len - 1)
-        start_idx += prompt_len
-    last_token_indicies.extend(
-        range(start_idx, start_idx + input_metadata.num_generation_tokens))
-    return hidden_states.index_select(
-        0, torch.tensor(last_token_indicies, device=hidden_states.device))
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        sampling_type = sampling_params.sampling_type
+        if i < input_metadata.num_prompts:
+            assert len(seq_ids) == 1, "Prompt input should have only one seq."
+            prompt_len = input_metadata.prompt_lens[i]
+            last_token_indices[sampling_type].append(start_idx + prompt_len -
+                                                     1)
+            start_idx += prompt_len
+        else:
+            num_seqs = len(seq_ids)
+            last_token_indices[sampling_type].extend(
+                range(start_idx, start_idx + num_seqs))
+            start_idx += num_seqs
+
+    all_last_token_indices = []
+    for sampling_type in SamplingType:
+        all_last_token_indices.extend(last_token_indices[sampling_type])
+    all_last_token_indices = torch.tensor(all_last_token_indices,
+                                          dtype=torch.long,
+                                          device=hidden_states.device)
+    return hidden_states.index_select(0, all_last_token_indices)


 def _get_penalties(
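Note: the rewritten _prune_hidden_states above does more than pick out the last-token hidden states; it also orders them by SamplingType (all greedy rows first, then random, then beam), so that the later _sample pass can slice one contiguous block per type and handle it with a single vectorized call. A toy sketch of that regrouping, using made-up index lists in place of a real InputMetadata:

import torch

# Hypothetical per-type row indices, in the spirit of the last_token_indices
# dict built above (greedy sequences at rows 0 and 3, random at rows 1 and 2).
indices_by_type = {"greedy": [0, 3], "random": [1, 2], "beam": []}

hidden_states = torch.arange(12, dtype=torch.float32).reshape(4, 3)

# Concatenate the indices type by type, then gather all rows in one call.
order = [i for t in ("greedy", "random", "beam") for i in indices_by_type[t]]
regrouped = hidden_states.index_select(0, torch.tensor(order))

# Rows of the same sampling type are now contiguous: the first two rows can
# be argmax-ed together, the next two multinomial-sampled together, and so on.
print(regrouped)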
@@ -149,11 +172,8 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
-    vocab_size: int,
 ) -> torch.Tensor:
-    num_seqs = logits.shape[0]
-    # Collect the indices of sequences that have non-zero penalties.
-    indices = []
+    num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
         if not output_tokens[i]:
             continue
@@ -161,33 +181,40 @@ def _apply_penalties(
         f = frequency_penalties[i]
         if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
             continue
-        indices.append(i)
-
-    # Return early if all sequences have zero penalties.
-    if not indices:
+        break
+    else:
+        # Return early if all sequences have zero penalties.
         return logits

-    bin_counts = []
-    for i in indices:
-        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
-    bin_counts = np.stack(bin_counts, axis=0)
-    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
-                                                 device=logits.device)
+    max_output_len = max(len(tokens) for tokens in output_tokens)
+    padded_output_tokens = [
+        tokens + [vocab_size] * (max_output_len - len(tokens))
+        for tokens in output_tokens
+    ]
+    output_tokens_tensor = torch.tensor(padded_output_tokens,
+                                        dtype=torch.long,
+                                        device=logits.device)
+
+    # Compute the bin counts for the output tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
+                             dtype=torch.long,
+                             device=logits.device)
+    bin_counts.scatter_add_(1, output_tokens_tensor,
+                            torch.ones_like(output_tokens_tensor))
+    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.

-    frequency_penalties = [frequency_penalties[i] for i in indices]
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)
-    presence_penalties = [presence_penalties[i] for i in indices]
     presence_penalties = torch.tensor(presence_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
-    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
     return logits

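Note: the comments above point at the OpenAI parameter definitions: for every token a sequence has already produced, its logit is reduced by frequency_penalty times the token's count, plus presence_penalty once if the count is non-zero. The hunk replaces a per-sequence np.bincount loop with one scatter_add_ over padded token lists. A small standalone sketch of the same idea on toy tensors (vocabulary size, token lists, and penalty values are invented for illustration):

import torch

vocab_size = 6
output_tokens = [[2, 2, 5], [1]]          # tokens generated so far, per sequence
frequency_penalties = torch.tensor([0.5, 0.0])
presence_penalties = torch.tensor([0.1, 0.0])
logits = torch.zeros(2, vocab_size)

# Pad the token lists with index `vocab_size` so a single scatter_add_ can
# count all rows at once; the extra padding bin is dropped afterwards.
max_len = max(len(t) for t in output_tokens)
padded = [t + [vocab_size] * (max_len - len(t)) for t in output_tokens]
padded = torch.tensor(padded, dtype=torch.long)

bin_counts = torch.zeros(2, vocab_size + 1, dtype=torch.long)
bin_counts.scatter_add_(1, padded, torch.ones_like(padded))
bin_counts = bin_counts[:, :vocab_size]

# OpenAI-style penalties: frequency scales with the count, presence is a
# one-off hit for any token that has appeared at least once.
logits -= frequency_penalties.unsqueeze(1) * bin_counts
logits -= presence_penalties.unsqueeze(1) * (bin_counts > 0)

print(bin_counts)    # row 0 counts token 2 twice and token 5 once
print(logits[0, 2])  # -(0.5 * 2 + 0.1) = -1.1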
@@ -268,95 +295,154 @@ def _apply_top_p_top_k(
 def _get_topk_logprobs(
     logprobs: torch.Tensor,
     num_logprobs: Optional[int],
-) -> Dict[int, float]:
+) -> List[Dict[int, float]]:
+    num_seqs = logprobs.size(0)
     if num_logprobs is None or num_logprobs == 0:
-        return {}
+        return [{} for _ in range(num_seqs)]

-    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
-    if num_logprobs == 1:
-        topk_logprobs = [topk_logprobs.item()]
-        topk_ids = [topk_ids.item()]
-    else:
-        topk_logprobs = topk_logprobs.tolist()
-        topk_ids = topk_ids.tolist()
-
-    token_to_logprob: Dict[int, float] = {}
-    for token_id, logprob in zip(topk_ids, topk_logprobs):
-        token_to_logprob[token_id] = logprob
-    return token_to_logprob
+    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
+                                                 num_logprobs,
+                                                 dim=-1)
+    all_topk_logprobs = all_topk_logprobs.cpu()
+    all_topk_ids = all_topk_ids.cpu()
+    all_token_to_logprob = []
+    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
+        token_to_logprob: Dict[int, float] = {}
+        for token_id, logprob in zip(topk_ids, topk_logprobs):
+            token_to_logprob[token_id.item()] = logprob.item()
+        all_token_to_logprob.append(token_to_logprob)
+    return all_token_to_logprob


-def _sample_from_prompt(
-    prob: torch.Tensor,
-    sampling_params: SamplingParams,
-) -> List[int]:
-    if sampling_params.use_beam_search:
-        # Beam search.
-        beam_width = sampling_params.best_of
-        # Sample 2 * beam_width candidates to make sure that with high
-        # probability we can get `beam_width` candidates in addition to
-        # the finished sequences for the next iteration. See
-        # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
-        # for details. See also HF reference:
-        # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
-        _, next_token_ids = torch.topk(prob, 2 * beam_width)
-        next_token_ids = next_token_ids.tolist()
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert sampling_params.best_of == 1
-        next_token_id = torch.argmax(prob)
-        next_token_ids = [next_token_id.item()]
-    else:
-        # Random sampling.
-        # Sample `best_of` tokens for the prompt.
-        num_seqs = sampling_params.best_of
-        next_token_ids = torch.multinomial(prob,
-                                           num_samples=num_seqs,
-                                           replacement=True)
-        next_token_ids = next_token_ids.tolist()
-    return next_token_ids
+def _build_sequence_outputs(
+    parent_ids: List[int],
+    next_token_ids: List[int],
+    selected_token_logprobs: torch.Tensor,
+    parent_seq_ids: List[int],
+    parent_logprobs: torch.Tensor,
+    num_output_logprobs: Optional[int],
+) -> List[SequenceOutputs]:
+    # Get top-k log probabilities for the next tokens.
+    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
+    seq_outputs: List[SequenceOutputs] = []
+    for parent_id, next_token_id, token_logprob in zip(
+            parent_ids, next_token_ids, selected_token_logprobs):
+        output_logprobs = next_logprobs[parent_id].copy()
+        output_logprobs[next_token_id] = token_logprob
+        seq_outputs.append(
+            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
+                            output_logprobs))
+    return seq_outputs


-def _sample_from_generation_tokens(
-    seq_ids: List[int],
-    probs: torch.Tensor,
+def _greedy_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
-    seq_logprobs: List[float],
-    sampling_params: SamplingParams,
-) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.best_of can be greater than
-    # len(seq_ids) because some sequences in the group might have
-    # been already terminated.
-    if sampling_params.use_beam_search:
-        # Beam search.
-        # Add cumulative logprobs for the sequences in the group.
-        seq_logprobs = torch.tensor(seq_logprobs,
-                                    dtype=torch.float,
-                                    device=logprobs.device)
-        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
+) -> List[Tuple[List[int], List[int]]]:
+    samples = torch.argmax(logprobs, dim=-1).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group in selected_seq_groups:
+        seq_ids, _ = seq_group
+        num_parent_seqs = len(seq_ids)
+        assert num_parent_seqs == 1, (
+            "Greedy sampling should have only one seq.")
+        parent_ids = list(range(num_parent_seqs))
+        next_token_ids = [samples[sample_idx].item()]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results

-        vocab_size = logprobs.size(-1)
-        beam_width = len(seq_ids)
-        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
-        topk_ids = topk_ids.tolist()
-        seq_idx = [i // vocab_size for i in topk_ids]
-        parent_seq_ids = [seq_ids[i] for i in seq_idx]
-        next_token_ids = [i % vocab_size for i in topk_ids]
-    elif sampling_params.temperature < _SAMPLING_EPS:
-        # Greedy sampling.
-        assert len(seq_ids) == 1
-        next_token_id = torch.argmax(probs, dim=-1)
-        next_token_ids = [int(next_token_id.item())]
-        parent_seq_ids = seq_ids
-    else:
-        # Random sampling.
-        # Sample 1 token for each sequence in the group.
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples=1,
-                                           replacement=True)
-        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
-        parent_seq_ids = seq_ids
-    return parent_seq_ids, next_token_ids
+
+def _random_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
+    probs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    # Find the maximum best_of value of the prompt phase requests.
+    max_best_of = 1
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        if is_prompt:
+            seq_ids, sampling_params = seq_group
+            max_best_of = max(max_best_of, sampling_params.best_of)
+    random_samples = torch.multinomial(probs,
+                                       num_samples=max_best_of,
+                                       replacement=True).cpu()
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * sampling_params.best_of
+            next_token_ids = random_samples[
+                sample_idx, :sampling_params.best_of].tolist()
+        else:
+            # Generation phase.
+            parent_ids = list(range(num_parent_seqs))
+            next_token_ids = random_samples[sample_idx:sample_idx +
+                                            num_parent_seqs, 0].tolist()
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == probs.size(0)
+    return results
+
+
+def _beam_search_sample(
+    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
+    is_prompts: List[bool],
+    seq_data: Dict[int, SequenceData],
+    logprobs: torch.Tensor,
+) -> List[Tuple[List[int], List[int]]]:
+    # We sample 2 * beam_width candidates to make sure that with high
+    # probability we can get `beam_width` candidates in addition to
+    # the finished sequences for the next iteration. See
+    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
+    # for details. See also HF reference:
+    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
+    #
+    # Note: Beam search is not vectorized, so its speed can be slower than
+    # other sampling methods.
+    sample_idx = 0
+    results = []
+    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
+        seq_ids, sampling_params = seq_group
+        num_parent_seqs = len(seq_ids)
+        beam_width = sampling_params.best_of
+        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
+        if is_prompt:
+            # Prompt phase.
+            assert num_parent_seqs == 1, (
+                "Prompt input should have only one seq.")
+            parent_ids = [0] * (2 * beam_width)
+            _, next_token_ids = torch.topk(seq_group_logprobs[0],
+                                           2 * beam_width)
+            next_token_ids = next_token_ids.tolist()
+        else:
+            # Generation phase.
+            cumulative_logprobs = [
+                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
+            ]
+            cumulative_logprobs = torch.tensor(
+                cumulative_logprobs,
+                dtype=torch.float,
+                device=seq_group_logprobs.device)
+            seq_group_logprobs = (seq_group_logprobs +
+                                  cumulative_logprobs.unsqueeze(dim=1))
+            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
+                                     2 * beam_width)
+            topk_ids = topk_ids.tolist()
+            vocab_size = seq_group_logprobs.size(-1)
+            parent_ids = [i // vocab_size for i in topk_ids]
+            next_token_ids = [i % vocab_size for i in topk_ids]
+        results.append((next_token_ids, parent_ids))
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
+    return results


 def _sample(
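Note: the 2 * beam_width over-sampling described in the comments above is implemented as a single topk over the flattened (num_parent_seqs, vocab_size) log-probability block, with integer division and modulo mapping each flat index back to a (parent, token) pair. A toy illustration of that index arithmetic (shapes and probabilities are invented for the example):

import torch

beam_width = 2
vocab_size = 5
# Log-probabilities for the live parent sequences of one group (toy values).
seq_group_logprobs = torch.log(torch.tensor([
    [0.05, 0.40, 0.30, 0.20, 0.05],   # parent 0
    [0.10, 0.10, 0.10, 0.60, 0.10],   # parent 1
]))

# One topk over the flattened matrix ranks all (parent, token) pairs jointly;
# 2 * beam_width candidates are kept so enough survive after some beams finish.
_, topk_ids = torch.topk(seq_group_logprobs.flatten(), 2 * beam_width)
topk_ids = topk_ids.tolist()
parent_ids = [i // vocab_size for i in topk_ids]
next_token_ids = [i % vocab_size for i in topk_ids]

print(list(zip(parent_ids, next_token_ids)))
# -> [(1, 3), (0, 1), (0, 2), (0, 3)] for the toy values above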
@@ -364,65 +450,80 @@ def _sample(
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> SamplerOutput:
-    seq_outputs: SamplerOutput = []
-
-    # TODO(woosuk): Optimize.
-    idx = 0
+    categorized_seq_group_ids = {t: [] for t in SamplingType}
+    category_num_tokens = {t: 0 for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_group_outputs: List[SequenceOutputs] = []
         seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            parent_seq_id = seq_ids[0]
-            prob = probs[idx]
-            logprob = logprobs[idx]
-            idx += 1
-
-            # Sample the next tokens.
-            next_token_ids = _sample_from_prompt(prob, sampling_params)
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs = _get_topk_logprobs(logprob,
-                                               sampling_params.logprobs)
-
-            # Build the output.
-            for next_token_id in next_token_ids:
-                output_logprobs = next_logprobs.copy()
-                output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_group_outputs.append(
-                    SequenceOutputs(parent_seq_id, next_token_id,
-                                    output_logprobs))
+        sampling_type = sampling_params.sampling_type
+        categorized_seq_group_ids[sampling_type].append(i)
+        num_seqs = len(seq_ids)
+        category_num_tokens[sampling_type] += num_seqs
+
+    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
+    category_start_idx = 0
+    for sampling_type in SamplingType:
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
+        num_tokens = category_num_tokens[sampling_type]
+        if num_tokens == 0:
+            continue
+        category_logprobs = logprobs[category_start_idx:category_start_idx +
+                                     num_tokens]
+        category_probs = probs[category_start_idx:category_start_idx +
+                               num_tokens]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(seq_groups, category_logprobs)
+        elif sampling_type == SamplingType.RANDOM:
+            sample_results = _random_sample(seq_groups, is_prompts,
+                                            category_probs)
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 input_metadata.seq_data,
+                                                 category_logprobs)
         else:
-            # Generate the next tokens for generation tokens.
-            num_parent_seqs = len(seq_ids)
-            prob = probs[idx:idx + num_parent_seqs]
-            logprob = logprobs[idx:idx + num_parent_seqs]
-            idx += num_parent_seqs
-
-            # Sample the next tokens.
-            seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprob
-                for seq_id in seq_ids
-            ]
-            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
-                seq_ids, prob, logprob, seq_logprobs, sampling_params)
-
-            # Get top-k log probabilities for the next tokens.
-            next_logprobs: Dict[int, Dict[int, float]] = {}
-            for j, seq_id in enumerate(seq_ids):
-                next_logprobs[seq_id] = _get_topk_logprobs(
-                    logprob[j], sampling_params.logprobs)
-
-            # Build the output.
-            for parent_seq_id, next_token_id in zip(parent_seq_ids,
-                                                    next_token_ids):
-                j = seq_ids.index(parent_seq_id)
-                output_logprobs = next_logprobs[parent_seq_id].copy()
-                output_logprobs[next_token_id] = logprob[j,
-                                                         next_token_id].item()
-                seq_group_outputs.append(
-                    SequenceOutputs(parent_seq_id, next_token_id,
-                                    output_logprobs))
-        seq_outputs.append(seq_group_outputs)
-
-    return seq_outputs
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+        # Batched query for logprobs of selected token
+        batched_logprobs_query_seq_indices: List[int] = []
+        batched_logprobs_query_token_indices: List[int] = []
+        sample_idx = 0
+        for seq_group_id, seq_group, sample_result in zip(
+                seq_group_ids, seq_groups, sample_results):
+            seq_ids, sampling_params = seq_group
+            next_token_ids, parent_ids = sample_result
+            num_parent_seqs = len(seq_ids)
+            batched_logprobs_query_seq_indices.extend(
+                [sample_idx + parent_id for parent_id in parent_ids])
+            batched_logprobs_query_token_indices.extend(next_token_ids)
+            sample_idx += num_parent_seqs
+        assert sample_idx == num_tokens
+        batched_logprobs_query_result = category_logprobs[[
+            batched_logprobs_query_seq_indices,
+            batched_logprobs_query_token_indices
+        ]].tolist()
+
+        # Build the sequence outputs.
+        sample_idx = 0
+        result_idx = 0
+        for seq_group_id, seq_group, sample_result in zip(
+                seq_group_ids, seq_groups, sample_results):
+            seq_ids, sampling_params = seq_group
+            next_token_ids, parent_ids = sample_result
+            num_results = len(next_token_ids)
+            num_parent_seqs = len(seq_ids)
+            parent_logprobs = category_logprobs[sample_idx:sample_idx +
+                                                num_parent_seqs]
+            selected_token_logprobs = batched_logprobs_query_result[
+                result_idx:result_idx + num_results]
+            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
+                                                 selected_token_logprobs,
+                                                 seq_ids, parent_logprobs,
+                                                 sampling_params.logprobs)
+            seq_outputs_dict[seq_group_id] = seq_output
+            sample_idx += num_parent_seqs
+            result_idx += num_results
+        assert sample_idx == num_tokens
+        category_start_idx += num_tokens
+
+    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]

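Note: the rewritten _sample above processes each SamplingType as one contiguous slice of probs/logprobs, and it reads the log-probabilities of all chosen tokens with a single fancy-indexing query instead of one .item() call per sequence. A toy version of that batched lookup (tensor contents and index lists are made up for illustration):

import torch

# Log-probabilities for one category: 3 rows (sequences), 5-token vocab.
category_logprobs = torch.randn(3, 5)

# Selected (row, token) pairs produced by the per-category sampling step;
# row 2 appears twice, as it would with best_of=2.
seq_indices = [0, 1, 2, 2]
token_indices = [4, 0, 1, 3]

# One advanced-indexing lookup replaces a Python loop of .item() calls.
selected = category_logprobs[seq_indices, token_indices].tolist()

for (row, tok), lp in zip(zip(seq_indices, token_indices), selected):
    print(f"seq {row} chose token {tok} with logprob {lp:.3f}")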
vllm/sampling_params.py

@@ -1,9 +1,17 @@
 """Sampling parameters for text generation."""
+from enum import IntEnum
+from functools import cached_property
 from typing import List, Optional, Union

 _SAMPLING_EPS = 1e-5


+class SamplingType(IntEnum):
+    GREEDY = 0
+    RANDOM = 1
+    BEAM = 2
+
+
 class SamplingParams:
     """Sampling parameters for text generation.

@@ -166,6 +174,14 @@ class SamplingParams:
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using greedy sampling.")

+    @cached_property
+    def sampling_type(self) -> SamplingType:
+        if self.use_beam_search:
+            return SamplingType.BEAM
+        if self.temperature < _SAMPLING_EPS:
+            return SamplingType.GREEDY
+        return SamplingType.RANDOM
+
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
                 f"best_of={self.best_of}, "
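Note: the cached sampling_type property above is what the sampler keys its buckets on: beam search takes precedence, a temperature below _SAMPLING_EPS means greedy, and everything else is random sampling. A small standalone restatement of that decision rule (re-implemented outside vLLM rather than imported, purely for illustration):

from enum import IntEnum

_SAMPLING_EPS = 1e-5


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    BEAM = 2


def classify(temperature: float, use_beam_search: bool) -> SamplingType:
    # Same precedence as SamplingParams.sampling_type: beam search first,
    # then the near-zero-temperature greedy case, otherwise random.
    if use_beam_search:
        return SamplingType.BEAM
    if temperature < _SAMPLING_EPS:
        return SamplingType.GREEDY
    return SamplingType.RANDOM


print(classify(0.0, False))   # SamplingType.GREEDY
print(classify(0.8, False))   # SamplingType.RANDOM
print(classify(0.0, True))    # SamplingType.BEAM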