Implement prompt logprobs & Batched topk for computing logprobs (#1328)
Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
commit 9d9072a069 (parent 928de46888)
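This commit adds two fields to SamplingParams, logprobs (top-k log probabilities for each generated token) and prompt_logprobs (top-k log probabilities for each prompt token), and threads them through the engine, sampler, and output classes. A minimal usage sketch; the model name and prompt here are illustrative and not part of the commit:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0.0, max_tokens=8,
                            logprobs=5, prompt_logprobs=5)
    output = llm.generate(["The capital of France is"], params)[0]
    # One {token_id: logprob} dict per prompt token; the first entry is None.
    print(output.prompt_logprobs)
    # One {token_id: logprob} dict per generated token.
    print(output.outputs[0].logprobs)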
@@ -11,7 +11,7 @@ def main(args: argparse.Namespace):
     # Test the following prompts.
     test_prompts = [
         ("A robot may not injure a human being",
-         SamplingParams(temperature=0.0)),
+         SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
@@ -64,7 +64,7 @@ def test_request_tracker():
     stream_5 = tracker.add_request("5")
     assert tracker.new_requests_event.flag
     tracker.process_request_output(
-        RequestOutput("2", "output", [], [], finished=True))
+        RequestOutput("2", "output", [], [], [], finished=True))
     new, finished = tracker.get_new_and_finished_requests()
     assert not tracker.new_requests_event.flag
     assert len(finished) == 1
@@ -107,6 +107,39 @@ class HfRunner:
             outputs[i] = (output_ids, output_str)
         return outputs

+    def generate_greedy_logprobs(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+    ) -> List[List[torch.Tensor]]:
+        all_logprobs = []
+        for prompt in prompts:
+            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+            output = self.model.generate(
+                input_ids.cuda(),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+            seq_logprobs = []
+            for hidden_states in output.hidden_states:
+                last_hidden_states = hidden_states[-1][0]
+                logits = torch.matmul(
+                    last_hidden_states,
+                    self.model.get_output_embeddings().weight.t(),
+                )
+                if self.model.get_output_embeddings().bias is not None:
+                    logits += self.model.get_output_embeddings(
+                    ).bias.unsqueeze(0)
+                logprobs = torch.nn.functional.log_softmax(logits,
+                                                           dim=-1,
+                                                           dtype=torch.float32)
+                seq_logprobs.append(logprobs)
+            all_logprobs.append(seq_logprobs)
+        return all_logprobs
+

 @pytest.fixture
 def hf_runner():
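The helper above recomputes logits from the hidden states so the reference log probabilities match what the HF model saw at every position. The returned structure is one list per prompt whose first tensor covers all prompt positions and whose later tensors cover one generated token each. A small indexing sketch, assuming the HfRunner instance above; the prompt text and token id are illustrative:

    hf_logprobs = hf_model.generate_greedy_logprobs(["Hello, world"], max_tokens=3)
    prompt_step = hf_logprobs[0][0]   # shape: [prompt_len, vocab_size]
    gen_step = hf_logprobs[0][1]      # shape: [1, vocab_size]
    token_id = 42                     # hypothetical token id
    lp = prompt_step[0, token_id].item()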
tests/samplers/test_logprobs.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+import pytest
+import torch
+
+from vllm import SamplingParams
+
+MODELS = ["facebook/opt-125m"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_get_prompt_logprobs(
+    hf_runner,
+    vllm_runner,
+    model,
+    dtype,
+    example_prompts,
+):
+    max_tokens = 5
+    hf_model = hf_runner(model, dtype=dtype)
+    hf_logprobs = hf_model.generate_greedy_logprobs(
+        example_prompts,
+        max_tokens=max_tokens,
+    )
+    del hf_model
+
+    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
+                                          logprobs=5,
+                                          prompt_logprobs=5,
+                                          temperature=0.0)
+    vllm_results = vllm_model.model.generate(
+        example_prompts, sampling_params=vllm_sampling_params)
+
+    # Test whether logprobs are included in the results.
+    for result in vllm_results:
+        assert result.prompt_logprobs is not None
+        assert result.outputs[0].logprobs is not None
+
+    # Test whether prompt logprobs are consistent with HF
+    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
+        # Check prompt logprobs
+        vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
+        for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
+            for token_id, logprob in vllm_prompt_logprob_dict.items():
+                torch.testing.assert_close(logprob,
+                                           hf_logprob[0][i][token_id].item(),
+                                           atol=1e-2,
+                                           rtol=1e-2)
+        vllm_sample_logprobs = vllm_result.outputs[0].logprobs
+        for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
+            for token_id, logprob in vllm_sample_logprob_dict.items():
+                torch.testing.assert_close(logprob,
+                                           hf_logprob[i][-1][token_id].item(),
+                                           atol=1e-2,
+                                           rtol=1e-2)
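The slice vllm_result.prompt_logprobs[1:] reflects a convention introduced below in the sampler: the first prompt token has no preceding context, so its slot holds None and the list has exactly one entry per prompt token. A quick sanity check, assuming the vllm_results computed above:

    first = vllm_results[0]
    assert first.prompt_logprobs[0] is None
    assert len(first.prompt_logprobs) == len(first.prompt_token_ids)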
@@ -143,7 +143,7 @@ class ModelConfig:
     def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
         """Returns the number of KV heads per GPU worker."""
         # For GPTBigCode & Falcon:
-        # Note: for falcon, when new_decoder_architecture is True, the
+        # NOTE: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
         # KV heads.
         falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
@@ -12,8 +12,8 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                           SequenceGroupMetadata, SequenceOutputs,
-                           SequenceStatus)
+                           SequenceGroupMetadata, SequenceGroupOutputs,
+                           SequenceOutputs, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
 from vllm.utils import Counter
@@ -350,9 +350,15 @@ class LLMEngine:
                              eos_token_id=self.tokenizer.eos_token_id))
         return current_worst_score >= highest_attainable_score

-    def _process_sequence_group_samples(
-            self, seq_group: SequenceGroup,
-            samples: List[SequenceOutputs]) -> None:
+    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
+                                        outputs: SequenceGroupOutputs) -> None:
+        # Process prompt logprobs
+        prompt_logprobs = outputs.prompt_logprobs
+        if prompt_logprobs is not None:
+            seq_group.prompt_logprobs = prompt_logprobs
+
+        # Process samples
+        samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
         existing_finished_seqs = seq_group.get_finished_seqs()
         parent_child_dict = {
@@ -520,8 +526,8 @@ class LLMEngine:
             scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
-        for seq_group, samples in zip(scheduled_seq_groups, output):
-            self._process_sequence_group_samples(seq_group, samples)
+        for seq_group, outputs in zip(scheduled_seq_groups, output):
+            self._process_sequence_group_outputs(seq_group, outputs)

         # Free the finished sequence groups.
         self.scheduler.free_finished_seq_groups()
@@ -420,7 +420,7 @@ class PagedAttentionWithALiBi(PagedAttention):
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len, dtype=dtype)
-            # Note(zhuohan): HF uses
+            # NOTE(zhuohan): HF uses
             #     `bias = bias[None, :].repeat(prompt_len, 1)`
             # here. We find that both biases give the same results, but
             # the bias below more accurately follows the original ALiBi
@@ -8,7 +8,8 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
 from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
+from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
+                           SequenceData, SequenceGroupOutputs, SequenceOutputs)

 _SAMPLING_EPS = 1e-5
@@ -82,7 +83,12 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

         # Sample the next tokens.
-        return _sample(probs, logprobs, input_metadata)
+        sample_results = _sample(probs, logprobs, input_metadata)
+        # Get the logprobs query results.
+        prompt_logprobs, sample_logprobs = _get_logprobs(
+            logprobs, input_metadata, sample_results)
+        return _build_sampler_output(sample_results, input_metadata,
+                                     prompt_logprobs, sample_logprobs)


 def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
@@ -102,24 +108,28 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    last_token_indices = []
+    selected_token_indices: List[int] = []
     start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, _ = seq_group
+        seq_ids, sampling_params = seq_group
         if i < input_metadata.num_prompts:
             assert len(seq_ids) == 1, "Prompt input should have only one seq."
             prompt_len = input_metadata.prompt_lens[i]
-            last_token_indices.append(start_idx + prompt_len - 1)
+            if sampling_params.prompt_logprobs is not None:
+                selected_token_indices.extend(
+                    range(start_idx, start_idx + prompt_len - 1))
+            selected_token_indices.append(start_idx + prompt_len - 1)
             start_idx += prompt_len
         else:
             num_seqs = len(seq_ids)
-            last_token_indices.extend(range(start_idx, start_idx + num_seqs))
+            selected_token_indices.extend(
+                range(start_idx, start_idx + num_seqs))
             start_idx += num_seqs

-    last_token_indices = torch.tensor(last_token_indices,
-                                      dtype=torch.long,
-                                      device=hidden_states.device)
-    return hidden_states.index_select(0, last_token_indices)
+    selected_token_indices = torch.tensor(selected_token_indices,
+                                          dtype=torch.long,
+                                          device=hidden_states.device)
+    return hidden_states.index_select(0, selected_token_indices)


 def _get_penalties(
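To make the change above concrete: when prompt_logprobs is requested, every prompt position is kept in the pruned hidden states (not only the last one), so the sampler can later score each prompt token under its predecessor's distribution. A worked example of the index math, assuming one prompt of length 4 with prompt_logprobs set, followed by a decode group with 2 running sequences:

    # prompt occupies rows 0..3, decode sequences occupy rows 4..5
    prompt_len, num_decode_seqs = 4, 2
    selected = list(range(0, prompt_len - 1))                # [0, 1, 2] -> extra prompt positions
    selected.append(prompt_len - 1)                          # [0, 1, 2, 3] -> last prompt token (sampled)
    start = prompt_len
    selected.extend(range(start, start + num_decode_seqs))   # [0, 1, 2, 3, 4, 5]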
@@ -127,10 +137,17 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
-    for seq_group in input_metadata.seq_groups:
+    for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            # NOTE: We do not apply presence and frequency penalties for the
+            # prompt token positions where we don't sample new tokens.
+            prompt_len = input_metadata.prompt_lens[i]
+            presence_penalties += [0] * (prompt_len - 1)
+            frequency_penalties += [0] * (prompt_len - 1)
         presence_penalties += [p] * len(seq_ids)
         frequency_penalties += [f] * len(seq_ids)
     return presence_penalties, frequency_penalties
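The zero padding keeps the penalty vector aligned with the extra prompt rows selected above; those rows are only scored, never sampled from, so no repetition penalty should apply there. As a sketch for the same example batch (prompt of length 4 with prompt_logprobs, then a decode group of 2 sequences, both groups using presence penalty p):

    p = 0.5  # illustrative value
    presence_penalties = [0, 0, 0] + [p] + [p, p]   # 3 prompt-only rows, 1 sampled prompt row, 2 decode rows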
@@ -138,8 +155,14 @@ def _get_penalties(

 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
     output_tokens: List[List[int]] = []
-    for seq_group in input_metadata.seq_groups:
-        seq_ids, _ = seq_group
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            # NOTE: prompt token positions do not need output tokens to
+            # compute penalties.
+            prompt_len = input_metadata.prompt_lens[i]
+            output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
             seq_data = input_metadata.seq_data[seq_id]
             output_tokens.append(seq_data.output_token_ids)
@@ -200,7 +223,7 @@ def _apply_penalties(
 def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for seq_group in input_metadata.seq_groups:
+    for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:
@@ -208,6 +231,10 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_len = input_metadata.prompt_lens[i]
+            temperatures += [temperature] * (prompt_len - 1)
         temperatures += [temperature] * len(seq_ids)
     return temperatures
@@ -218,13 +245,18 @@ def _get_top_p_top_k(
 ) -> Tuple[List[float], List[int]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
-    for seq_group in input_metadata.seq_groups:
+    for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         # k should not be greater than the vocab size.
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_len = input_metadata.prompt_lens[i]
+            top_ps += [top_p] * (prompt_len - 1)
+            top_ks += [top_k] * (prompt_len - 1)
         top_ps += [top_p] * len(seq_ids)
         top_ks += [top_k] * len(seq_ids)
     return top_ps, top_ks
@@ -259,49 +291,6 @@ def _apply_top_p_top_k(
     return logits


-def _get_topk_logprobs(
-    logprobs: torch.Tensor,
-    num_logprobs: Optional[int],
-) -> List[Dict[int, float]]:
-    num_seqs = logprobs.size(0)
-    if num_logprobs is None or num_logprobs == 0:
-        return [{} for _ in range(num_seqs)]
-
-    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
-                                                 num_logprobs,
-                                                 dim=-1)
-    all_topk_logprobs = all_topk_logprobs.cpu()
-    all_topk_ids = all_topk_ids.cpu()
-    all_token_to_logprob = []
-    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
-        token_to_logprob: Dict[int, float] = {}
-        for token_id, logprob in zip(topk_ids, topk_logprobs):
-            token_to_logprob[token_id.item()] = logprob.item()
-        all_token_to_logprob.append(token_to_logprob)
-    return all_token_to_logprob
-
-
-def _build_sequence_outputs(
-    parent_ids: List[int],
-    next_token_ids: List[int],
-    selected_token_logprobs: List[float],
-    parent_seq_ids: List[int],
-    parent_logprobs: torch.Tensor,
-    num_output_logprobs: Optional[int],
-) -> List[SequenceOutputs]:
-    # Get top-k log probabilities for the next tokens.
-    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
-    seq_outputs: List[SequenceOutputs] = []
-    for parent_id, next_token_id, token_logprob in zip(
-            parent_ids, next_token_ids, selected_token_logprobs):
-        output_logprobs = next_logprobs[parent_id].copy()
-        output_logprobs[next_token_id] = token_logprob
-        seq_outputs.append(
-            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
-                            output_logprobs))
-    return seq_outputs
-
-
 def _greedy_sample(
     selected_seq_groups: List[Tuple[List[int], SamplingParams]],
     logprobs: torch.Tensor,
@@ -372,7 +361,7 @@ def _beam_search_sample(
     # for details. See also HF reference:
     # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
     #
-    # Note: Beam search is not vectorized, so its speed can be slower than
+    # NOTE: Beam search is not vectorized, so its speed can be slower than
     # other sampling methods.
     sample_idx = 0
     results = []
@@ -416,79 +405,186 @@ def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
-) -> SamplerOutput:
+) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
+    categorized_sample_indices = {t: [] for t in SamplingType}
     start_idx = 0
-    categorized_seq_ids = {t: [] for t in SamplingType}
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            # NOTE: prompt token positions do not need sample, skip
+            prompt_len = input_metadata.prompt_lens[i]
+            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
         num_seqs = len(seq_ids)
-        categorized_seq_ids[sampling_type].extend(
+        categorized_sample_indices[sampling_type].extend(
             range(start_idx, start_idx + num_seqs))
         start_idx += num_seqs

-    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
+    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
         seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
         is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
-        num_tokens = len(categorized_seq_ids[sampling_type])
+        sample_indices = categorized_sample_indices[sampling_type]
+        num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
-        category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
-        category_probs = probs[categorized_seq_ids[sampling_type]]
         if sampling_type == SamplingType.GREEDY:
+            category_logprobs = logprobs[sample_indices]
             sample_results = _greedy_sample(seq_groups, category_logprobs)
         elif sampling_type == SamplingType.RANDOM:
+            category_probs = probs[sample_indices]
             sample_results = _random_sample(seq_groups, is_prompts,
                                             category_probs)
         elif sampling_type == SamplingType.BEAM:
+            category_logprobs = logprobs[sample_indices]
             sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                  input_metadata.seq_data,
                                                  category_logprobs)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
+        sample_results_dict.update(zip(seq_group_ids, sample_results))

-        # Batched query for logprobs of selected token
-        batched_logprobs_query_seq_indices: List[int] = []
-        batched_logprobs_query_token_indices: List[int] = []
-        sample_idx = 0
-        for seq_group_id, seq_group, sample_result in zip(
-                seq_group_ids, seq_groups, sample_results):
-            seq_ids, sampling_params = seq_group
-            next_token_ids, parent_ids = sample_result
-            num_parent_seqs = len(seq_ids)
+    sample_results = [
+        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
+    ]
+    return sample_results
+
+
+def _get_logprobs(
+    logprobs: torch.Tensor,
+    input_metadata: InputMetadata,
+    sample_results: List[Tuple[List[int], List[int]]],
+) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
+        int, float]]]]:
+    # Prepare query indices
+    batched_logprobs_query_seq_indices: List[int] = []
+    batched_logprobs_query_token_indices: List[int] = []
+    largest_num_logprobs = 0
+    sample_idx = 0
+    for i, (seq_group, sample_result) in enumerate(
+            zip(input_metadata.seq_groups, sample_results)):
+        seq_ids, sampling_params = seq_group
+        next_token_ids, parent_ids = sample_result
+        num_parent_seqs = len(seq_ids)
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            largest_num_logprobs = max(largest_num_logprobs,
+                                       sampling_params.prompt_logprobs)
+            prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens = input_metadata.seq_data[
+                seq_ids[0]].prompt_token_ids
             batched_logprobs_query_seq_indices.extend(
-                [sample_idx + parent_id for parent_id in parent_ids])
-            batched_logprobs_query_token_indices.extend(next_token_ids)
-            sample_idx += num_parent_seqs
-        assert sample_idx == num_tokens
-        batched_logprobs_query_result = category_logprobs[[
-            batched_logprobs_query_seq_indices,
-            batched_logprobs_query_token_indices
-        ]].tolist()
+                sample_idx + j for j in range(prompt_len - 1))
+            batched_logprobs_query_token_indices.extend(
+                token_id for token_id in prompt_tokens[1:])
+            sample_idx += prompt_len - 1
+        batched_logprobs_query_seq_indices.extend(
+            [sample_idx + parent_id for parent_id in parent_ids])
+        batched_logprobs_query_token_indices.extend(next_token_ids)
+        if sampling_params.logprobs is not None:
+            largest_num_logprobs = max(largest_num_logprobs,
+                                       sampling_params.logprobs)
+        sample_idx += num_parent_seqs
+    assert sample_idx == logprobs.size(0)
-        # Build the sequence outputs.
-        sample_idx = 0
-        result_idx = 0
-        for seq_group_id, seq_group, sample_result in zip(
-                seq_group_ids, seq_groups, sample_results):
-            seq_ids, sampling_params = seq_group
-            next_token_ids, parent_ids = sample_result
-            num_results = len(next_token_ids)
-            num_parent_seqs = len(seq_ids)
-            parent_logprobs = category_logprobs[sample_idx:sample_idx +
-                                                num_parent_seqs]
-            selected_token_logprobs = batched_logprobs_query_result[
-                result_idx:result_idx + num_results]
-            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
-                                                 selected_token_logprobs,
-                                                 seq_ids, parent_logprobs,
-                                                 sampling_params.logprobs)
-            seq_outputs_dict[seq_group_id] = seq_output
-            sample_idx += num_parent_seqs
-            result_idx += num_results
-        assert sample_idx == num_tokens
+    # Batched query for logprobs of selected token
+    batched_logprobs_query_result = logprobs[[
+        batched_logprobs_query_seq_indices,
+        batched_logprobs_query_token_indices
+    ]].cpu()

-    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
+    # Batched query for logprobs of topk tokens
+    if largest_num_logprobs > 0:
+        top_logprobs, top_token_ids = torch.topk(logprobs,
+                                                 largest_num_logprobs,
+                                                 dim=-1)
+        top_logprobs = top_logprobs.cpu()
+        top_token_ids = top_token_ids.cpu()
+    else:
+        top_logprobs, top_token_ids = None, None
+
+    # Gather results
+    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
+    result_sample_logprobs: List[SampleLogprobs] = []
+    sample_idx = 0
+    query_result_idx = 0
+    for i, (seq_group, sample_result) in enumerate(
+            zip(input_metadata.seq_groups, sample_results)):
+        seq_ids, sampling_params = seq_group
+        next_token_ids, parent_ids = sample_result
+
+        # Prompt logprobs
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            num_logprobs = sampling_params.prompt_logprobs
+            prompt_len = input_metadata.prompt_lens[i]
+            prompt_tokens = input_metadata.seq_data[
+                seq_ids[0]].prompt_token_ids
+            group_prompt_logprobs: PromptLogprobs = [None]
+            for token_id in prompt_tokens[1:]:
+                prompt_logprobs_dict = {
+                    token_id:
+                    batched_logprobs_query_result[query_result_idx].item()
+                }
+                if num_logprobs > 0:
+                    prompt_logprobs_dict.update(
+                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
+                            top_logprobs[sample_idx, :num_logprobs].tolist()))
+                group_prompt_logprobs.append(prompt_logprobs_dict)
+                sample_idx += 1
+                query_result_idx += 1
+            result_prompt_logprobs.append(group_prompt_logprobs)
+        else:
+            result_prompt_logprobs.append(None)
+
+        # Sample logprobs
+        num_logprobs = sampling_params.logprobs
+        if num_logprobs is None:
+            num_logprobs = 0
+        group_sample_logprobs: SampleLogprobs = []
+        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
+            sample_logprobs_dict = {
+                next_token_id:
+                batched_logprobs_query_result[query_result_idx].item()
+            }
+            query_result_idx += 1
+            if num_logprobs > 0:
+                sample_logprobs_dict.update(
+                    zip(
+                        top_token_ids[sample_idx +
+                                      parent_id, :num_logprobs].tolist(),
+                        top_logprobs[sample_idx +
+                                     parent_id, :num_logprobs].tolist()))
+            group_sample_logprobs.append(sample_logprobs_dict)
+        result_sample_logprobs.append(group_sample_logprobs)
+        sample_idx += len(seq_ids)
+
+    return result_prompt_logprobs, result_sample_logprobs
+
+
+def _build_sampler_output(
+    sample_results: List[Tuple[List[int], List[int]]],
+    input_metadata: InputMetadata,
+    prompt_logprobs: List[Optional[PromptLogprobs]],
+    sample_logprobs: List[SampleLogprobs],
+) -> SamplerOutput:
+    sampler_output = []
+    for (seq_group, sample_result, group_prompt_logprobs,
+         group_sample_logprobs) in zip(input_metadata.seq_groups,
+                                       sample_results, prompt_logprobs,
+                                       sample_logprobs):
+        seq_ids, _ = seq_group
+        next_token_ids, parent_ids = sample_result
+        seq_outputs = []
+        for parent_id, next_token_id, logprobs in zip(parent_ids,
+                                                      next_token_ids,
+                                                      group_sample_logprobs):
+            seq_outputs.append(
+                SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
+        sampler_output.append(
+            SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
+    return sampler_output
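The core of the batched top-k change is visible above: instead of one topk call per sequence (the removed _get_topk_logprobs), _get_logprobs issues a single advanced-indexing gather for all selected (row, token) pairs and a single torch.topk over the whole batch. A minimal standalone sketch of those two tensor operations, with made-up sizes and indices:

    import torch

    logprobs = torch.randn(6, 32000).log_softmax(dim=-1)    # 6 selected rows, vocab of 32000
    row_indices = [0, 1, 2, 3, 4, 5]
    token_indices = [11, 7, 993, 42, 8, 8]                  # chosen or prompt token per row
    selected = logprobs[[row_indices, token_indices]]       # same as logprobs[row_indices, token_indices]
    top_vals, top_ids = torch.topk(logprobs, k=5, dim=-1)   # one topk for the whole batch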
@@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 def tensor_model_parallel_all_reduce(input_):
     """All-reduce the input tensor across model parallel group.

-    Note: This operation is applied in-place on the input tensor.
+    NOTE: This operation is applied in-place on the input tensor.
     """
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
@@ -133,7 +133,7 @@ class ColumnParallelLinear(torch.nn.Module):
             params_dtype = torch.get_default_dtype()

         # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
+        # NOTE: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
         self.create_weights(params_dtype)

@@ -41,7 +41,7 @@ def split_tensor_along_last_dim(
     last_dim_size = divide(tensor.size()[last_dim], num_partitions)
     # Split.
     tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # Note: torch.split does not create contiguous tensors by default.
+    # NOTE: torch.split does not create contiguous tensors by default.
     if contiguous_split_chunks:
         return tuple(chunk.contiguous() for chunk in tensor_list)

@@ -1,6 +1,7 @@
-from typing import Dict, List, Optional
+from typing import List, Optional

-from vllm.sequence import SequenceGroup, SequenceStatus
+from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
+                           SequenceStatus)


 class CompletionOutput:
@@ -23,7 +24,7 @@ class CompletionOutput:
         text: str,
         token_ids: List[int],
         cumulative_logprob: float,
-        logprobs: Optional[List[Dict[int, float]]],
+        logprobs: Optional[SampleLogprobs],
         finish_reason: Optional[str] = None,
     ) -> None:
         self.index = index
@@ -61,12 +62,14 @@ class RequestOutput:
         request_id: str,
         prompt: str,
         prompt_token_ids: List[int],
+        prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
+        self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished

@@ -91,7 +94,7 @@ class RequestOutput:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
-                logprobs = {}
+                logprobs = None
             finshed_reason = SequenceStatus.get_finished_reason(seq.status)
             output = CompletionOutput(seqs.index(seq), seq.output_text,
                                       seq.get_output_token_ids(),
@@ -100,15 +103,17 @@ class RequestOutput:
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
-        prompt = top_n_seqs[0].prompt
-        prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
+        prompt = seq_group.prompt
+        prompt_token_ids = seq_group.prompt_token_ids
+        prompt_logprobs = seq_group.prompt_logprobs
         finished = seq_group.is_finished()
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
-                   finished)
+        return cls(seq_group.request_id, prompt, prompt_token_ids,
+                   prompt_logprobs, outputs, finished)

     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
+                f"prompt_logprobs={self.prompt_logprobs}, "
                 f"outputs={self.outputs}, "
                 f"finished={self.finished})")
@@ -60,6 +60,12 @@ class SamplingParams:
             tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
+        logprobs: Number of log probabilities to return per output token.
+            Note that the implementation follows the OpenAI API: The return
+            result includes the log probabilities on the `logprobs` most likely
+            tokens, as well as the chosen tokens. The API will always return the
+            log probability of the sampled token, so there may be up to
+            `logprobs+1` elements in the response.
+        prompt_logprobs: Number of log probabilities to return per prompt token.
         skip_special_tokens: Whether to skip special tokens in the output.
     """
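As a concrete reading of the docstring above: with logprobs=2, each per-token dictionary carries the top-2 candidates plus, if the sampled token is not among them, its own entry as well, so up to 3 items. The request below is only illustrative and the values in the comment are not from a real run:

    from vllm import SamplingParams

    params = SamplingParams(temperature=0.8, max_tokens=4, logprobs=2, prompt_logprobs=2)
    # a returned entry might look like {top_id_1: -0.3, top_id_2: -1.9, sampled_id: -2.4}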
@@ -80,6 +86,7 @@ class SamplingParams:
         ignore_eos: bool = False,
         max_tokens: int = 16,
         logprobs: Optional[int] = None,
+        prompt_logprobs: Optional[int] = None,
         skip_special_tokens: bool = True,
     ) -> None:
         self.n = n
@@ -105,6 +112,7 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.max_tokens = max_tokens
         self.logprobs = logprobs
+        self.prompt_logprobs = prompt_logprobs
         self.skip_special_tokens = skip_special_tokens

         self._verify_args()
@@ -142,6 +150,9 @@ class SamplingParams:
         if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
+        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
+            raise ValueError(f"prompt_logprobs must be non-negative, got "
+                             f"{self.prompt_logprobs}.")

     def _verify_beam_search(self) -> None:
         if self.best_of == 1:
@@ -200,4 +211,5 @@ class SamplingParams:
                 f"ignore_eos={self.ignore_eos}, "
                 f"max_tokens={self.max_tokens}, "
                 f"logprobs={self.logprobs}, "
+                f"prompt_logprobs={self.prompt_logprobs}, "
                 f"skip_special_tokens={self.skip_special_tokens})")
@@ -6,6 +6,9 @@ from typing import Dict, List, Optional, Union
 from vllm.block import LogicalTokenBlock
 from vllm.sampling_params import SamplingParams

+PromptLogprobs = List[Optional[Dict[int, float]]]
+SampleLogprobs = List[Dict[int, float]]
+

 class SequenceStatus(enum.Enum):
     """Status of a sequence."""
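The two aliases added above encode the shapes used throughout this commit: prompt logprobs hold one optional dict per prompt token (the first slot is None), sample logprobs hold one dict per generated token. A sketch of concrete instances, with made-up token ids and values:

    from vllm.sequence import PromptLogprobs, SampleLogprobs

    prompt_logprobs: PromptLogprobs = [None, {464: -2.1, 318: -3.0}, {318: -0.7}]
    sample_logprobs: SampleLogprobs = [{464: -1.2}, {11: -0.4}]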
@@ -116,7 +119,7 @@ class Sequence:
         self.block_size = block_size

         self.data = SequenceData(prompt_token_ids)
-        self.output_logprobs: List[Dict[int, float]] = []
+        self.output_logprobs: SampleLogprobs = []
         self.output_text = ""

         self.logical_token_blocks: List[LogicalTokenBlock] = []
@@ -196,7 +199,7 @@ class Sequence:
         """
         if seq_len is None:
             seq_len = self.get_len()
-        # Note: HF implementation does not count the EOS token
+        # NOTE: HF implementation does not count the EOS token
         # towards the length, we align with that here for testing.
         if (eos_token_id is not None
                 and self.get_last_token_id() == eos_token_id):
@@ -238,6 +241,19 @@ class SequenceGroup:
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
+        self.prompt_logprobs: Optional[PromptLogprobs] = None
+
+    @property
+    def prompt(self) -> str:
+        # All sequences in the group should have the same prompt.
+        # We use the prompt of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).prompt
+
+    @property
+    def prompt_token_ids(self) -> List[int]:
+        # All sequences in the group should have the same prompt.
+        # We use the prompt of an arbitrary sequence.
+        return next(iter(self.seqs_dict.values())).data.prompt_token_ids

     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
@@ -370,6 +386,22 @@ class SequenceOutputs:
                 and self.logprobs == other.logprobs)


+class SequenceGroupOutputs:
+    """The model outputs associated with a sequence group."""
+
+    def __init__(
+        self,
+        samples: List[SequenceOutputs],
+        prompt_logprobs: Optional[PromptLogprobs],
+    ) -> None:
+        self.samples = samples
+        self.prompt_logprobs = prompt_logprobs
+
+    def __repr__(self) -> str:
+        return (f"SequenceGroupOutputs(samples={self.samples}, "
+                f"prompt_logprobs={self.prompt_logprobs})")
+
+
 # For each sequence group, we generate a list of SequenceOutputs object,
 # each of which contains one possible candidate for the next token.
-SamplerOutput = List[List[SequenceOutputs]]
+SamplerOutput = List[SequenceGroupOutputs]
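To illustrate the new SamplerOutput shape, a minimal construction sketch using the classes defined above; the ids and logprob values are made up:

    from vllm.sequence import SamplerOutput, SequenceGroupOutputs, SequenceOutputs

    # One SequenceOutputs per sampled child sequence (parent seq id, token id, logprobs).
    seq_out = SequenceOutputs(0, 464, {464: -1.2})
    group_out = SequenceGroupOutputs(samples=[seq_out],
                                     prompt_logprobs=[None, {318: -2.3}])
    sampler_output: SamplerOutput = [group_out]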