mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 06:14:25 +08:00
Make sampler less blocking (#1889)
This commit is contained in:
parent
f8c688d746
commit
a7347d9a6d
@ -6,13 +6,11 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from vllm.model_executor.parallel_utils.communication_op import (
|
from vllm.model_executor.parallel_utils.communication_op import (
|
||||||
tensor_model_parallel_all_gather)
|
tensor_model_parallel_all_gather)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
|
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
|
||||||
SequenceData, SequenceGroupOutput, SequenceOutput)
|
SequenceData, SequenceGroupOutput, SequenceOutput)
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
|
||||||
|
|
||||||
|
|
||||||
class Sampler(nn.Module):
|
class Sampler(nn.Module):
|
||||||
"""Samples the next tokens from the model's outputs.
|
"""Samples the next tokens from the model's outputs.
|
||||||
@ -32,6 +30,7 @@ class Sampler(nn.Module):
|
|||||||
def __init__(self, vocab_size: int) -> None:
|
def __init__(self, vocab_size: int) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
self._copy_stream: torch.cuda.Stream = torch.cuda.Stream()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -47,40 +46,38 @@ class Sampler(nn.Module):
|
|||||||
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
||||||
self.vocab_size)
|
self.vocab_size)
|
||||||
|
|
||||||
|
_, vocab_size = logits.shape
|
||||||
|
|
||||||
# Apply logits processors (if any).
|
# Apply logits processors (if any).
|
||||||
logits = _apply_logits_processors(logits, sampling_metadata)
|
logits = _apply_logits_processors(logits, sampling_metadata)
|
||||||
|
|
||||||
|
# Prepare sampling tensors in another stream to overlap
|
||||||
|
# CPU<->GPU data transfer with GPU computation in forward pass.
|
||||||
|
with torch.cuda.stream(self._copy_stream):
|
||||||
|
(sampling_tensors, do_penalties, do_top_p_top_k,
|
||||||
|
do_min_p) = SamplingTensors.from_sampling_metadata(
|
||||||
|
sampling_metadata, vocab_size, logits.device, logits.dtype)
|
||||||
|
|
||||||
|
torch.cuda.current_stream().wait_stream(self._copy_stream)
|
||||||
|
|
||||||
# Apply presence and frequency penalties.
|
# Apply presence and frequency penalties.
|
||||||
presence_penalties, frequency_penalties, repetition_penalties = (
|
if do_penalties:
|
||||||
_get_penalties(sampling_metadata))
|
logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
|
||||||
assert len(presence_penalties) == logits.shape[0]
|
sampling_tensors.output_tokens,
|
||||||
assert len(frequency_penalties) == logits.shape[0]
|
sampling_tensors.presence_penalties,
|
||||||
assert len(repetition_penalties) == logits.shape[0]
|
sampling_tensors.frequency_penalties,
|
||||||
logits = _apply_penalties(logits, sampling_metadata,
|
sampling_tensors.repetition_penalties)
|
||||||
presence_penalties, frequency_penalties,
|
|
||||||
repetition_penalties)
|
|
||||||
|
|
||||||
# Apply temperature scaling.
|
# Apply temperature scaling.
|
||||||
temperatures = _get_temperatures(sampling_metadata)
|
# Use in-place division to avoid creating a new tensor.
|
||||||
assert len(temperatures) == logits.shape[0]
|
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
|
||||||
if any(t != 1.0 for t in temperatures):
|
|
||||||
t = torch.tensor(temperatures,
|
|
||||||
dtype=logits.dtype,
|
|
||||||
device=logits.device)
|
|
||||||
# Use in-place division to avoid creating a new tensor.
|
|
||||||
logits.div_(t.unsqueeze(dim=1))
|
|
||||||
|
|
||||||
# Apply top-p and top-k truncation.
|
if do_top_p_top_k:
|
||||||
top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
|
logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
|
||||||
sampling_metadata, self.vocab_size)
|
sampling_tensors.top_ks)
|
||||||
assert len(top_ps) == len(top_ks) == logits.shape[0]
|
|
||||||
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
|
|
||||||
do_top_k = any(k != self.vocab_size for k in top_ks)
|
|
||||||
if do_top_p or do_top_k:
|
|
||||||
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
|
|
||||||
|
|
||||||
do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
|
|
||||||
if do_min_p:
|
if do_min_p:
|
||||||
logits = _apply_min_p(logits, min_ps)
|
logits = _apply_min_p(logits, sampling_tensors.min_ps)
|
||||||
|
|
||||||
# We use float32 for probabilities and log probabilities.
|
# We use float32 for probabilities and log probabilities.
|
||||||
# Compute the probabilities.
|
# Compute the probabilities.
|
||||||
@ -120,32 +117,6 @@ def _prune_hidden_states(
|
|||||||
sampling_metadata.selected_token_indices)
|
sampling_metadata.selected_token_indices)
|
||||||
|
|
||||||
|
|
||||||
def _get_penalties(
|
|
||||||
sampling_metadata: SamplingMetadata
|
|
||||||
) -> Tuple[List[float], List[float], List[float]]:
|
|
||||||
# Collect the presence and frequency penalties.
|
|
||||||
presence_penalties: List[float] = []
|
|
||||||
frequency_penalties: List[float] = []
|
|
||||||
repetition_penalties: List[float] = []
|
|
||||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
|
||||||
seq_ids, sampling_params = seq_group
|
|
||||||
p = sampling_params.presence_penalty
|
|
||||||
f = sampling_params.frequency_penalty
|
|
||||||
r = sampling_params.repetition_penalty
|
|
||||||
if (i < sampling_metadata.num_prompts
|
|
||||||
and sampling_params.prompt_logprobs is not None):
|
|
||||||
# NOTE: We do not apply presence and frequency penalties for the
|
|
||||||
# prompt token positions where we don't sample new tokens.
|
|
||||||
prompt_len = sampling_metadata.prompt_lens[i]
|
|
||||||
presence_penalties += [0] * (prompt_len - 1)
|
|
||||||
frequency_penalties += [0] * (prompt_len - 1)
|
|
||||||
repetition_penalties += [1] * (prompt_len - 1)
|
|
||||||
presence_penalties += [p] * len(seq_ids)
|
|
||||||
frequency_penalties += [f] * len(seq_ids)
|
|
||||||
repetition_penalties += [r] * len(seq_ids)
|
|
||||||
return presence_penalties, frequency_penalties, repetition_penalties
|
|
||||||
|
|
||||||
|
|
||||||
def _get_prompt_and_output_tokens(
|
def _get_prompt_and_output_tokens(
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Tuple[List[List[int]], List[List[int]]]:
|
) -> Tuple[List[List[int]], List[List[int]]]:
|
||||||
@ -168,25 +139,16 @@ def _get_prompt_and_output_tokens(
|
|||||||
|
|
||||||
|
|
||||||
def _get_bin_counts_and_mask(
|
def _get_bin_counts_and_mask(
|
||||||
logits: torch.Tensor,
|
tokens: torch.Tensor,
|
||||||
tokens: List[List[int]],
|
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
num_seqs: int,
|
num_seqs: int,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
max_len = max(len(tokens) for tokens in tokens)
|
|
||||||
padded_tokens = [
|
|
||||||
tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens
|
|
||||||
]
|
|
||||||
tokens_tensor = torch.tensor(padded_tokens,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=logits.device)
|
|
||||||
|
|
||||||
# Compute the bin counts for the tokens.
|
# Compute the bin counts for the tokens.
|
||||||
# vocab_size + 1 for padding.
|
# vocab_size + 1 for padding.
|
||||||
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=logits.device)
|
device=tokens.device)
|
||||||
bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
|
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
|
||||||
bin_counts = bin_counts[:, :vocab_size]
|
bin_counts = bin_counts[:, :vocab_size]
|
||||||
mask = bin_counts > 0
|
mask = bin_counts > 0
|
||||||
|
|
||||||
@ -217,45 +179,16 @@ def _apply_logits_processors(
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _apply_penalties(
|
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
|
||||||
logits: torch.Tensor,
|
output_tokens_tensor: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
presence_penalties: torch.Tensor,
|
||||||
presence_penalties: List[float],
|
frequency_penalties: torch.Tensor,
|
||||||
frequency_penalties: List[float],
|
repetition_penalties: torch.Tensor) -> torch.Tensor:
|
||||||
repetition_penalties: List[float],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
num_seqs, vocab_size = logits.shape
|
num_seqs, vocab_size = logits.shape
|
||||||
for i in range(num_seqs):
|
_, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
|
||||||
p = presence_penalties[i]
|
num_seqs)
|
||||||
f = frequency_penalties[i]
|
|
||||||
r = repetition_penalties[i]
|
|
||||||
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
|
|
||||||
r - 1.0) < _SAMPLING_EPS:
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Return early if all sequences have zero penalties.
|
|
||||||
return logits
|
|
||||||
|
|
||||||
prompt_tokens, output_tokens = (
|
|
||||||
_get_prompt_and_output_tokens(sampling_metadata))
|
|
||||||
assert len(prompt_tokens) == logits.shape[0]
|
|
||||||
assert len(output_tokens) == logits.shape[0]
|
|
||||||
|
|
||||||
prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask(
|
|
||||||
logits, prompt_tokens, vocab_size, num_seqs)
|
|
||||||
output_bin_counts, output_mask = _get_bin_counts_and_mask(
|
output_bin_counts, output_mask = _get_bin_counts_and_mask(
|
||||||
logits, output_tokens, vocab_size, num_seqs)
|
output_tokens_tensor, vocab_size, num_seqs)
|
||||||
|
|
||||||
repetition_penalties = torch.tensor(repetition_penalties,
|
|
||||||
dtype=logits.dtype,
|
|
||||||
device=logits.device)
|
|
||||||
frequency_penalties = torch.tensor(frequency_penalties,
|
|
||||||
dtype=logits.dtype,
|
|
||||||
device=logits.device)
|
|
||||||
presence_penalties = torch.tensor(presence_penalties,
|
|
||||||
dtype=logits.dtype,
|
|
||||||
device=logits.device)
|
|
||||||
|
|
||||||
repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
|
repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
|
||||||
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
|
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
|
||||||
@ -264,109 +197,65 @@ def _apply_penalties(
|
|||||||
|
|
||||||
# We follow the definition in OpenAI API.
|
# We follow the definition in OpenAI API.
|
||||||
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
|
||||||
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
|
logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
|
||||||
logits -= presence_penalties.unsqueeze(dim=1) * output_mask
|
logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
|
|
||||||
# Collect the temperatures for the logits.
|
|
||||||
temperatures: List[float] = []
|
|
||||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
|
||||||
seq_ids, sampling_params = seq_group
|
|
||||||
temperature = sampling_params.temperature
|
|
||||||
if temperature < _SAMPLING_EPS:
|
|
||||||
# NOTE: Zero temperature means deterministic sampling
|
|
||||||
# (i.e., greedy sampling or beam search).
|
|
||||||
# Set the temperature to 1 to avoid division by zero.
|
|
||||||
temperature = 1.0
|
|
||||||
if (i < sampling_metadata.num_prompts
|
|
||||||
and sampling_params.prompt_logprobs is not None):
|
|
||||||
prompt_len = sampling_metadata.prompt_lens[i]
|
|
||||||
temperatures += [temperature] * (prompt_len - 1)
|
|
||||||
temperatures += [temperature] * len(seq_ids)
|
|
||||||
return temperatures
|
|
||||||
|
|
||||||
|
|
||||||
def _get_top_p_top_k_min_p(
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
vocab_size: int,
|
|
||||||
) -> Tuple[List[float], List[int], List[float]]:
|
|
||||||
top_ps: List[float] = []
|
|
||||||
top_ks: List[int] = []
|
|
||||||
min_ps: List[float] = []
|
|
||||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
|
||||||
seq_ids, sampling_params = seq_group
|
|
||||||
top_p = sampling_params.top_p
|
|
||||||
min_p = sampling_params.min_p
|
|
||||||
# k should not be greater than the vocab size.
|
|
||||||
top_k = min(sampling_params.top_k, vocab_size)
|
|
||||||
# k=-1 means no truncation.
|
|
||||||
top_k = vocab_size if top_k == -1 else top_k
|
|
||||||
if (i < sampling_metadata.num_prompts
|
|
||||||
and sampling_params.prompt_logprobs is not None):
|
|
||||||
prompt_len = sampling_metadata.prompt_lens[i]
|
|
||||||
top_ps += [top_p] * (prompt_len - 1)
|
|
||||||
top_ks += [top_k] * (prompt_len - 1)
|
|
||||||
min_ps += [min_p] * (prompt_len - 1)
|
|
||||||
top_ps += [top_p] * len(seq_ids)
|
|
||||||
top_ks += [top_k] * len(seq_ids)
|
|
||||||
min_ps += [min_p] * len(seq_ids)
|
|
||||||
return top_ps, top_ks, min_ps
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_top_p_top_k(
|
def _apply_top_p_top_k(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
top_ps: List[float],
|
p: torch.Tensor,
|
||||||
top_ks: List[int],
|
k: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
|
|
||||||
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
|
|
||||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
|
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
|
||||||
|
|
||||||
# Apply top-p.
|
# Apply top-p.
|
||||||
probs_sort = logits_sort.softmax(dim=-1)
|
probs_sort = logits_sort.softmax(dim=-1)
|
||||||
probs_sum = probs_sort.cumsum(dim=-1)
|
probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
|
||||||
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
|
top_p_mask = probs_sum > p.unsqueeze_(dim=1)
|
||||||
logits_sort[top_p_mask] = -float("inf")
|
|
||||||
|
|
||||||
# Apply top-k.
|
# Apply top-k.
|
||||||
# Create a mask for the top-k elements.
|
# Create a mask for the top-k elements.
|
||||||
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
|
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
|
||||||
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
|
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
|
||||||
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
|
top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
|
||||||
logits_sort[top_k_mask] = -float("inf")
|
|
||||||
|
# Final mask.
|
||||||
|
mask = (top_p_mask | top_k_mask)
|
||||||
|
logits_sort.masked_fill_(mask, -float("inf"))
|
||||||
|
|
||||||
# Re-sort the probabilities.
|
# Re-sort the probabilities.
|
||||||
logits = torch.gather(logits_sort,
|
src = torch.arange(logits_idx.shape[-1],
|
||||||
dim=-1,
|
device=logits_idx.device).expand_as(logits_idx)
|
||||||
index=torch.argsort(logits_idx, dim=-1))
|
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
|
||||||
|
index=logits_idx,
|
||||||
|
src=src)
|
||||||
|
logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _apply_min_p(
|
def _apply_min_p(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
min_ps: List[float],
|
min_p: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Adapted from
|
Adapted from
|
||||||
https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
|
https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
|
||||||
"""
|
"""
|
||||||
min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device)
|
|
||||||
probs = torch.softmax(logits, dim=-1)
|
probs = torch.softmax(logits, dim=-1)
|
||||||
top_probs, _ = probs.max(dim=-1, keepdim=True)
|
top_probs, _ = probs.max(dim=-1, keepdim=True)
|
||||||
scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
|
scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
|
||||||
tokens_to_remove = probs < scaled_min_p
|
tokens_to_remove = probs < scaled_min_p
|
||||||
logits = logits.masked_fill(tokens_to_remove, -float("inf"))
|
logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def _greedy_sample(
|
def _greedy_sample(
|
||||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
logprobs: torch.Tensor,
|
samples: torch.Tensor,
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
samples = torch.argmax(logprobs, dim=-1).cpu()
|
samples = samples.tolist()
|
||||||
sample_idx = 0
|
sample_idx = 0
|
||||||
results = []
|
results = []
|
||||||
for seq_group in selected_seq_groups:
|
for seq_group in selected_seq_groups:
|
||||||
@ -375,27 +264,19 @@ def _greedy_sample(
|
|||||||
assert num_parent_seqs == 1, (
|
assert num_parent_seqs == 1, (
|
||||||
"Greedy sampling should have only one seq.")
|
"Greedy sampling should have only one seq.")
|
||||||
parent_ids = list(range(num_parent_seqs))
|
parent_ids = list(range(num_parent_seqs))
|
||||||
next_token_ids = [samples[sample_idx].item()]
|
next_token_ids = [samples[sample_idx]]
|
||||||
results.append((next_token_ids, parent_ids))
|
results.append((next_token_ids, parent_ids))
|
||||||
sample_idx += num_parent_seqs
|
sample_idx += num_parent_seqs
|
||||||
assert sample_idx == logprobs.size(0)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _random_sample(
|
def _random_sample(
|
||||||
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
is_prompts: List[bool],
|
is_prompts: List[bool],
|
||||||
probs: torch.Tensor,
|
random_samples: torch.Tensor,
|
||||||
) -> List[Tuple[List[int], List[int]]]:
|
) -> List[Tuple[List[int], List[int]]]:
|
||||||
# Find the maximum best_of value of the prompt phase requests.
|
# Find the maximum best_of value of the prompt phase requests.
|
||||||
max_best_of = 1
|
random_samples = random_samples.cpu()
|
||||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
|
||||||
if is_prompt:
|
|
||||||
seq_ids, sampling_params = seq_group
|
|
||||||
max_best_of = max(max_best_of, sampling_params.best_of)
|
|
||||||
random_samples = torch.multinomial(probs,
|
|
||||||
num_samples=max_best_of,
|
|
||||||
replacement=True).cpu()
|
|
||||||
sample_idx = 0
|
sample_idx = 0
|
||||||
results = []
|
results = []
|
||||||
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
|
||||||
@ -403,8 +284,6 @@ def _random_sample(
|
|||||||
num_parent_seqs = len(seq_ids)
|
num_parent_seqs = len(seq_ids)
|
||||||
if is_prompt:
|
if is_prompt:
|
||||||
# Prompt phase.
|
# Prompt phase.
|
||||||
assert num_parent_seqs == 1, (
|
|
||||||
"Prompt input should have only one seq.")
|
|
||||||
parent_ids = [0] * sampling_params.best_of
|
parent_ids = [0] * sampling_params.best_of
|
||||||
next_token_ids = random_samples[
|
next_token_ids = random_samples[
|
||||||
sample_idx, :sampling_params.best_of].tolist()
|
sample_idx, :sampling_params.best_of].tolist()
|
||||||
@ -415,7 +294,6 @@ def _random_sample(
|
|||||||
num_parent_seqs, 0].tolist()
|
num_parent_seqs, 0].tolist()
|
||||||
results.append((next_token_ids, parent_ids))
|
results.append((next_token_ids, parent_ids))
|
||||||
sample_idx += num_parent_seqs
|
sample_idx += num_parent_seqs
|
||||||
assert sample_idx == probs.size(0)
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -472,6 +350,28 @@ def _beam_search_sample(
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# torch.multinomial forces a GPU<->CPU sync.
|
||||||
|
# Therefore, we use an optimized implementation instead.
|
||||||
|
# Note that we always sample with replacement.
|
||||||
|
# probs will be modified in place, but this is fine, as we pass
|
||||||
|
# in a copy already.
|
||||||
|
def _multinomial(
|
||||||
|
probs: torch.Tensor,
|
||||||
|
num_samples: int,
|
||||||
|
):
|
||||||
|
if num_samples > 1:
|
||||||
|
# This is equivalent to torch.repeat_interleaved (which also
|
||||||
|
# forces a GPU<->CPU sync).
|
||||||
|
# This allows us to do sampling with replacement by creating
|
||||||
|
# num_samples copies of each row in the tensor, and then
|
||||||
|
# batch sampling the resulting tensor.
|
||||||
|
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
|
||||||
|
probs.shape[1]).contiguous().view(
|
||||||
|
-1, probs.shape[1])
|
||||||
|
q = torch.empty_like(probs).exponential_(1)
|
||||||
|
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
|
||||||
|
|
||||||
|
|
||||||
def _sample(
|
def _sample(
|
||||||
probs: torch.Tensor,
|
probs: torch.Tensor,
|
||||||
logprobs: torch.Tensor,
|
logprobs: torch.Tensor,
|
||||||
@ -485,28 +385,51 @@ def _sample(
|
|||||||
categorized_seq_group_ids[sampling_type].append(i)
|
categorized_seq_group_ids[sampling_type].append(i)
|
||||||
|
|
||||||
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
|
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
|
||||||
|
sample_metadata = {}
|
||||||
|
|
||||||
|
# Counterintiutively, having two loops here is actually faster.
|
||||||
|
# The first loop can run without waiting on GPU<->CPU sync.
|
||||||
for sampling_type in SamplingType:
|
for sampling_type in SamplingType:
|
||||||
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
|
||||||
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
|
|
||||||
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
|
|
||||||
sample_indices = categorized_sample_indices[sampling_type]
|
sample_indices = categorized_sample_indices[sampling_type]
|
||||||
num_tokens = len(sample_indices)
|
num_tokens = len(sample_indices)
|
||||||
if num_tokens == 0:
|
if num_tokens == 0:
|
||||||
continue
|
continue
|
||||||
|
seq_group_ids = categorized_seq_group_ids[sampling_type]
|
||||||
|
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
|
||||||
|
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
|
||||||
|
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
|
||||||
|
is_prompts, sample_indices)
|
||||||
if sampling_type == SamplingType.GREEDY:
|
if sampling_type == SamplingType.GREEDY:
|
||||||
category_logprobs = logprobs[sample_indices]
|
greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
|
||||||
sample_results = _greedy_sample(seq_groups, category_logprobs)
|
|
||||||
elif sampling_type == SamplingType.RANDOM:
|
elif sampling_type == SamplingType.RANDOM:
|
||||||
category_probs = probs[sample_indices]
|
max_best_of = 1
|
||||||
sample_results = _random_sample(seq_groups, is_prompts,
|
for seq_group, is_prompt in zip(seq_groups, is_prompts):
|
||||||
category_probs)
|
if is_prompt:
|
||||||
|
_, sampling_params = seq_group
|
||||||
|
max_best_of = max(max_best_of, sampling_params.best_of)
|
||||||
|
multinomial_samples = _multinomial(probs[sample_indices],
|
||||||
|
max_best_of)
|
||||||
elif sampling_type == SamplingType.BEAM:
|
elif sampling_type == SamplingType.BEAM:
|
||||||
category_logprobs = logprobs[sample_indices]
|
beam_search_logprobs = logprobs[sample_indices]
|
||||||
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
|
||||||
sampling_metadata.seq_data,
|
|
||||||
category_logprobs)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
raise ValueError(f"Unsupported sampling type: {sampling_type}")
|
||||||
|
|
||||||
|
# GPU<->CPU sync happens in the loop below.
|
||||||
|
|
||||||
|
for sampling_type in SamplingType:
|
||||||
|
if sampling_type not in sample_metadata:
|
||||||
|
continue
|
||||||
|
seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
|
||||||
|
sampling_type]
|
||||||
|
if sampling_type == SamplingType.GREEDY:
|
||||||
|
sample_results = _greedy_sample(seq_groups, greedy_samples)
|
||||||
|
elif sampling_type == SamplingType.RANDOM:
|
||||||
|
sample_results = _random_sample(seq_groups, is_prompts,
|
||||||
|
multinomial_samples)
|
||||||
|
elif sampling_type == SamplingType.BEAM:
|
||||||
|
sample_results = _beam_search_sample(seq_groups, is_prompts,
|
||||||
|
sampling_metadata.seq_data,
|
||||||
|
beam_search_logprobs)
|
||||||
sample_results_dict.update(zip(seq_group_ids, sample_results))
|
sample_results_dict.update(zip(seq_group_ids, sample_results))
|
||||||
|
|
||||||
sample_results = [
|
sample_results = [
|
||||||
@ -557,7 +480,7 @@ def _get_logprobs(
|
|||||||
batched_logprobs_query_result = logprobs[[
|
batched_logprobs_query_result = logprobs[[
|
||||||
batched_logprobs_query_seq_indices,
|
batched_logprobs_query_seq_indices,
|
||||||
batched_logprobs_query_token_indices
|
batched_logprobs_query_token_indices
|
||||||
]].cpu()
|
]]
|
||||||
|
|
||||||
# Batched query for logprobs of topk tokens
|
# Batched query for logprobs of topk tokens
|
||||||
if largest_num_logprobs > 0:
|
if largest_num_logprobs > 0:
|
||||||
@ -569,6 +492,8 @@ def _get_logprobs(
|
|||||||
else:
|
else:
|
||||||
top_logprobs, top_token_ids = None, None
|
top_logprobs, top_token_ids = None, None
|
||||||
|
|
||||||
|
batched_logprobs_query_result = batched_logprobs_query_result.cpu()
|
||||||
|
|
||||||
# Gather results
|
# Gather results
|
||||||
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
|
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
|
||||||
result_sample_logprobs: List[SampleLogprobs] = []
|
result_sample_logprobs: List[SampleLogprobs] = []
|
||||||
|
|||||||
@ -1,9 +1,13 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.sequence import SequenceData
|
from vllm.sequence import SequenceData
|
||||||
|
from vllm.utils import in_wsl
|
||||||
|
|
||||||
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
|
|
||||||
class SamplingMetadata:
|
class SamplingMetadata:
|
||||||
@ -41,3 +45,186 @@ class SamplingMetadata:
|
|||||||
f"prompt_lens={self.prompt_lens}, "
|
f"prompt_lens={self.prompt_lens}, "
|
||||||
f"selected_token_indices={self.selected_token_indices}, "
|
f"selected_token_indices={self.selected_token_indices}, "
|
||||||
f"categorized_sample_indices={self.categorized_sample_indices})")
|
f"categorized_sample_indices={self.categorized_sample_indices})")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SamplingTensors:
|
||||||
|
"""Tensors for sampling."""
|
||||||
|
|
||||||
|
temperatures: torch.Tensor
|
||||||
|
top_ps: torch.Tensor
|
||||||
|
top_ks: torch.Tensor
|
||||||
|
min_ps: torch.Tensor
|
||||||
|
presence_penalties: torch.Tensor
|
||||||
|
frequency_penalties: torch.Tensor
|
||||||
|
repetition_penalties: torch.Tensor
|
||||||
|
prompt_tokens: torch.Tensor
|
||||||
|
output_tokens: torch.Tensor
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_sampling_metadata(
|
||||||
|
cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
|
||||||
|
prompt_tokens: List[List[int]] = []
|
||||||
|
output_tokens: List[List[int]] = []
|
||||||
|
top_ks: List[int] = []
|
||||||
|
temperatures: List[float] = []
|
||||||
|
top_ps: List[float] = []
|
||||||
|
min_ps: List[float] = []
|
||||||
|
presence_penalties: List[float] = []
|
||||||
|
frequency_penalties: List[float] = []
|
||||||
|
repetition_penalties: List[float] = []
|
||||||
|
do_penalties = False
|
||||||
|
do_top_p_top_k = False
|
||||||
|
do_min_p = False
|
||||||
|
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||||
|
seq_ids, sampling_params = seq_group
|
||||||
|
temperature = sampling_params.temperature
|
||||||
|
p = sampling_params.presence_penalty
|
||||||
|
f = sampling_params.frequency_penalty
|
||||||
|
r = sampling_params.repetition_penalty
|
||||||
|
top_p = sampling_params.top_p
|
||||||
|
min_p = sampling_params.min_p
|
||||||
|
# k should not be greater than the vocab size.
|
||||||
|
top_k = min(sampling_params.top_k, vocab_size)
|
||||||
|
top_k = vocab_size if top_k == -1 else top_k
|
||||||
|
if temperature < _SAMPLING_EPS:
|
||||||
|
# NOTE: Zero temperature means deterministic sampling
|
||||||
|
# (i.e., greedy sampling or beam search).
|
||||||
|
# Set the temperature to 1 to avoid division by zero.
|
||||||
|
temperature = 1.0
|
||||||
|
if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
|
||||||
|
or top_k != vocab_size):
|
||||||
|
do_top_p_top_k = True
|
||||||
|
if not do_min_p and min_p > _SAMPLING_EPS:
|
||||||
|
do_min_p = True
|
||||||
|
if not do_penalties and (abs(p) >= _SAMPLING_EPS
|
||||||
|
or abs(f) >= _SAMPLING_EPS
|
||||||
|
or abs(r - 1.0) >= _SAMPLING_EPS):
|
||||||
|
do_penalties = True
|
||||||
|
if (i < sampling_metadata.num_prompts
|
||||||
|
and sampling_params.prompt_logprobs is not None):
|
||||||
|
# For tokens in the prompt that we only need to get their logprobs
|
||||||
|
prompt_len = sampling_metadata.prompt_lens[i]
|
||||||
|
temperatures += [temperature] * (prompt_len - 1)
|
||||||
|
top_ps += [top_p] * (prompt_len - 1)
|
||||||
|
top_ks += [top_k] * (prompt_len - 1)
|
||||||
|
min_ps += [min_p] * (prompt_len - 1)
|
||||||
|
presence_penalties += [0] * (prompt_len - 1)
|
||||||
|
frequency_penalties += [0] * (prompt_len - 1)
|
||||||
|
repetition_penalties += [1] * (prompt_len - 1)
|
||||||
|
prompt_tokens.extend([] for _ in range(prompt_len - 1))
|
||||||
|
output_tokens.extend([] for _ in range(prompt_len - 1))
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
seq_data = sampling_metadata.seq_data[seq_id]
|
||||||
|
prompt_tokens.append(seq_data.prompt_token_ids)
|
||||||
|
output_tokens.append(seq_data.output_token_ids)
|
||||||
|
temperatures += [temperature] * len(seq_ids)
|
||||||
|
top_ps += [top_p] * len(seq_ids)
|
||||||
|
top_ks += [top_k] * len(seq_ids)
|
||||||
|
min_ps += [min_p] * len(seq_ids)
|
||||||
|
presence_penalties += [p] * len(seq_ids)
|
||||||
|
frequency_penalties += [f] * len(seq_ids)
|
||||||
|
repetition_penalties += [r] * len(seq_ids)
|
||||||
|
|
||||||
|
sampling_tensors = SamplingTensors.from_lists(
|
||||||
|
temperatures, top_ps, top_ks, min_ps, presence_penalties,
|
||||||
|
frequency_penalties, repetition_penalties, prompt_tokens,
|
||||||
|
output_tokens, vocab_size, device, dtype)
|
||||||
|
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_lists(cls, temperatures: List[float], top_ps: List[float],
|
||||||
|
top_ks: List[int], min_ps: List[float],
|
||||||
|
presence_penalties: List[float],
|
||||||
|
frequency_penalties: List[float],
|
||||||
|
repetition_penalties: List[float],
|
||||||
|
prompt_tokens: List[List[int]],
|
||||||
|
output_tokens: List[List[int]], vocab_size: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype) -> "SamplingTensors":
|
||||||
|
# Note that the performance will be very bad without
|
||||||
|
# pinned memory.
|
||||||
|
pin_memory = not in_wsl()
|
||||||
|
prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
|
||||||
|
prompt_padded_tokens = [
|
||||||
|
tokens + [vocab_size] * (prompt_max_len - len(tokens))
|
||||||
|
for tokens in prompt_tokens
|
||||||
|
]
|
||||||
|
output_max_len = max(len(tokens) for tokens in output_tokens)
|
||||||
|
output_padded_tokens = [
|
||||||
|
tokens + [vocab_size] * (output_max_len - len(tokens))
|
||||||
|
for tokens in output_tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
temperatures_t = torch.tensor(
|
||||||
|
temperatures,
|
||||||
|
device="cpu",
|
||||||
|
dtype=dtype,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
top_ps_t = torch.tensor(
|
||||||
|
top_ps,
|
||||||
|
device="cpu",
|
||||||
|
dtype=dtype,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
min_ps_t = torch.tensor(
|
||||||
|
min_ps,
|
||||||
|
device="cpu",
|
||||||
|
dtype=dtype,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
presence_penalties_t = torch.tensor(
|
||||||
|
presence_penalties,
|
||||||
|
device="cpu",
|
||||||
|
dtype=dtype,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
frequency_penalties_t = torch.tensor(
|
||||||
|
frequency_penalties,
|
||||||
|
device="cpu",
|
||||||
|
dtype=dtype,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
repetition_penalties_t = torch.tensor(
|
||||||
|
repetition_penalties,
|
||||||
|
device="cpu",
|
||||||
|
dtype=dtype,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
top_ks_t = torch.tensor(
|
||||||
|
top_ks,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
prompt_tensor = torch.tensor(
|
||||||
|
prompt_padded_tokens,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.long,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
output_tensor = torch.tensor(
|
||||||
|
output_padded_tokens,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.long,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
|
# Because the memory is pinned, we can do non-blocking
|
||||||
|
# transfer to device.
|
||||||
|
return cls(
|
||||||
|
temperatures=temperatures_t.to(device=device, non_blocking=True),
|
||||||
|
top_ps=top_ps_t.to(device=device, non_blocking=True),
|
||||||
|
top_ks=top_ks_t.to(device=device, non_blocking=True),
|
||||||
|
min_ps=min_ps_t.to(device=device, non_blocking=True),
|
||||||
|
presence_penalties=presence_penalties_t.to(device=device,
|
||||||
|
non_blocking=True),
|
||||||
|
frequency_penalties=frequency_penalties_t.to(device=device,
|
||||||
|
non_blocking=True),
|
||||||
|
repetition_penalties=repetition_penalties_t.to(device=device,
|
||||||
|
non_blocking=True),
|
||||||
|
prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
|
||||||
|
output_tokens=output_tensor.to(device=device, non_blocking=True),
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user