Implement prompt logprobs & Batched topk for computing logprobs (#1328)

Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
This commit is contained in:
Zhuohan Li 2023-10-16 10:56:50 -07:00 committed by GitHub
parent 928de46888
commit 9d9072a069
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 369 additions and 130 deletions

View File

@ -11,7 +11,7 @@ def main(args: argparse.Namespace):
# Test the following prompts.
test_prompts = [
("A robot may not injure a human being",
SamplingParams(temperature=0.0)),
SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",

View File

@ -64,7 +64,7 @@ def test_request_tracker():
stream_5 = tracker.add_request("5")
assert tracker.new_requests_event.flag
tracker.process_request_output(
RequestOutput("2", "output", [], [], finished=True))
RequestOutput("2", "output", [], [], [], finished=True))
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert len(finished) == 1

View File

@ -107,6 +107,39 @@ class HfRunner:
outputs[i] = (output_ids, output_str)
return outputs
def generate_greedy_logprobs(
self,
prompts: List[str],
max_tokens: int,
) -> List[List[torch.Tensor]]:
all_logprobs = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
output = self.model.generate(
input_ids.cuda(),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
)
seq_logprobs = []
for hidden_states in output.hidden_states:
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if self.model.get_output_embeddings().bias is not None:
logits += self.model.get_output_embeddings(
).bias.unsqueeze(0)
logprobs = torch.nn.functional.log_softmax(logits,
dim=-1,
dtype=torch.float32)
seq_logprobs.append(logprobs)
all_logprobs.append(seq_logprobs)
return all_logprobs
@pytest.fixture
def hf_runner():

View File

@ -0,0 +1,55 @@
import pytest
import torch
from vllm import SamplingParams
MODELS = ["facebook/opt-125m"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_get_prompt_logprobs(
hf_runner,
vllm_runner,
model,
dtype,
example_prompts,
):
max_tokens = 5
hf_model = hf_runner(model, dtype=dtype)
hf_logprobs = hf_model.generate_greedy_logprobs(
example_prompts,
max_tokens=max_tokens,
)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype)
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
logprobs=5,
prompt_logprobs=5,
temperature=0.0)
vllm_results = vllm_model.model.generate(
example_prompts, sampling_params=vllm_sampling_params)
# Test whether logprobs are included in the results.
for result in vllm_results:
assert result.prompt_logprobs is not None
assert result.outputs[0].logprobs is not None
# Test whether prompt logprobs are consistent with HF
for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
# Check prompt logprobs
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
for token_id, logprob in vllm_prompt_logprob_dict.items():
torch.testing.assert_close(logprob,
hf_logprob[0][i][token_id].item(),
atol=1e-2,
rtol=1e-2)
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
for token_id, logprob in vllm_sample_logprob_dict.items():
torch.testing.assert_close(logprob,
hf_logprob[i][-1][token_id].item(),
atol=1e-2,
rtol=1e-2)

View File

@ -143,7 +143,7 @@ class ModelConfig:
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU worker."""
# For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the
# NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]

View File

@ -12,8 +12,8 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupMetadata, SequenceOutputs,
SequenceStatus)
SequenceGroupMetadata, SequenceGroupOutputs,
SequenceOutputs, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
@ -350,9 +350,15 @@ class LLMEngine:
eos_token_id=self.tokenizer.eos_token_id))
return current_worst_score >= highest_attainable_score
def _process_sequence_group_samples(
self, seq_group: SequenceGroup,
samples: List[SequenceOutputs]) -> None:
def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutputs) -> None:
# Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None:
seq_group.prompt_logprobs = prompt_logprobs
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {
@ -520,8 +526,8 @@ class LLMEngine:
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, samples in zip(scheduled_seq_groups, output):
self._process_sequence_group_samples(seq_group, samples)
for seq_group, outputs in zip(scheduled_seq_groups, output):
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()

View File

@ -420,7 +420,7 @@ class PagedAttentionWithALiBi(PagedAttention):
# Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
# Note(zhuohan): HF uses
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi

View File

@ -8,7 +8,8 @@ from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutputs, SequenceOutputs)
_SAMPLING_EPS = 1e-5
@ -82,7 +83,12 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)
sample_results = _sample(probs, logprobs, input_metadata)
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, input_metadata, sample_results)
return _build_sampler_output(sample_results, input_metadata,
prompt_logprobs, sample_logprobs)
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
@ -102,24 +108,28 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
last_token_indices = []
selected_token_indices: List[int] = []
start_idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, _ = seq_group
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
assert len(seq_ids) == 1, "Prompt input should have only one seq."
prompt_len = input_metadata.prompt_lens[i]
last_token_indices.append(start_idx + prompt_len - 1)
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(start_idx, start_idx + prompt_len - 1))
selected_token_indices.append(start_idx + prompt_len - 1)
start_idx += prompt_len
else:
num_seqs = len(seq_ids)
last_token_indices.extend(range(start_idx, start_idx + num_seqs))
selected_token_indices.extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs
last_token_indices = torch.tensor(last_token_indices,
dtype=torch.long,
device=hidden_states.device)
return hidden_states.index_select(0, last_token_indices)
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device=hidden_states.device)
return hidden_states.index_select(0, selected_token_indices)
def _get_penalties(
@ -127,10 +137,17 @@ def _get_penalties(
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
for seq_group in input_metadata.seq_groups:
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: We do not apply presence and frequency penalties for the
# prompt token positions where we don't sample new tokens.
prompt_len = input_metadata.prompt_lens[i]
presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1)
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
return presence_penalties, frequency_penalties
@ -138,8 +155,14 @@ def _get_penalties(
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
output_tokens: List[List[int]] = []
for seq_group in input_metadata.seq_groups:
seq_ids, _ = seq_group
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: prompt token positions do not need output tokens to
# compute penalties.
prompt_len = input_metadata.prompt_lens[i]
output_tokens.extend([] for _ in range(prompt_len - 1))
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
output_tokens.append(seq_data.output_token_ids)
@ -200,7 +223,7 @@ def _apply_penalties(
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for seq_group in input_metadata.seq_groups:
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature < _SAMPLING_EPS:
@ -208,6 +231,10 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
prompt_len = input_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
temperatures += [temperature] * len(seq_ids)
return temperatures
@ -218,13 +245,18 @@ def _get_top_p_top_k(
) -> Tuple[List[float], List[int]]:
top_ps: List[float] = []
top_ks: List[int] = []
for seq_group in input_metadata.seq_groups:
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
top_p = sampling_params.top_p
# k should not be greater than the vocab size.
top_k = min(sampling_params.top_k, vocab_size)
# k=-1 means no truncation.
top_k = vocab_size if top_k == -1 else top_k
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
prompt_len = input_metadata.prompt_lens[i]
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
return top_ps, top_ks
@ -259,49 +291,6 @@ def _apply_top_p_top_k(
return logits
def _get_topk_logprobs(
logprobs: torch.Tensor,
num_logprobs: Optional[int],
) -> List[Dict[int, float]]:
num_seqs = logprobs.size(0)
if num_logprobs is None or num_logprobs == 0:
return [{} for _ in range(num_seqs)]
all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
num_logprobs,
dim=-1)
all_topk_logprobs = all_topk_logprobs.cpu()
all_topk_ids = all_topk_ids.cpu()
all_token_to_logprob = []
for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
token_to_logprob: Dict[int, float] = {}
for token_id, logprob in zip(topk_ids, topk_logprobs):
token_to_logprob[token_id.item()] = logprob.item()
all_token_to_logprob.append(token_to_logprob)
return all_token_to_logprob
def _build_sequence_outputs(
parent_ids: List[int],
next_token_ids: List[int],
selected_token_logprobs: List[float],
parent_seq_ids: List[int],
parent_logprobs: torch.Tensor,
num_output_logprobs: Optional[int],
) -> List[SequenceOutputs]:
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
seq_outputs: List[SequenceOutputs] = []
for parent_id, next_token_id, token_logprob in zip(
parent_ids, next_token_ids, selected_token_logprobs):
output_logprobs = next_logprobs[parent_id].copy()
output_logprobs[next_token_id] = token_logprob
seq_outputs.append(
SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
output_logprobs))
return seq_outputs
def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
logprobs: torch.Tensor,
@ -372,7 +361,7 @@ def _beam_search_sample(
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
#
# Note: Beam search is not vectorized, so its speed can be slower than
# NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx = 0
results = []
@ -416,79 +405,186 @@ def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
) -> SamplerOutput:
) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = {t: [] for t in SamplingType}
start_idx = 0
categorized_seq_ids = {t: [] for t in SamplingType}
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: prompt token positions do not need sample, skip
prompt_len = input_metadata.prompt_lens[i]
start_idx += prompt_len - 1
categorized_seq_group_ids[sampling_type].append(i)
num_seqs = len(seq_ids)
categorized_seq_ids[sampling_type].extend(
categorized_sample_indices[sampling_type].extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
for sampling_type in SamplingType:
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
num_tokens = len(categorized_seq_ids[sampling_type])
sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
category_probs = probs[categorized_seq_ids[sampling_type]]
if sampling_type == SamplingType.GREEDY:
category_logprobs = logprobs[sample_indices]
sample_results = _greedy_sample(seq_groups, category_logprobs)
elif sampling_type == SamplingType.RANDOM:
category_probs = probs[sample_indices]
sample_results = _random_sample(seq_groups, is_prompts,
category_probs)
elif sampling_type == SamplingType.BEAM:
category_logprobs = logprobs[sample_indices]
sample_results = _beam_search_sample(seq_groups, is_prompts,
input_metadata.seq_data,
category_logprobs)
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
sample_results_dict.update(zip(seq_group_ids, sample_results))
# Batched query for logprobs of selected token
batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = []
sample_idx = 0
for seq_group_id, seq_group, sample_result in zip(
seq_group_ids, seq_groups, sample_results):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
sample_results = [
sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
]
return sample_results
def _get_logprobs(
logprobs: torch.Tensor,
input_metadata: InputMetadata,
sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
int, float]]]]:
# Prepare query indices
batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = []
largest_num_logprobs = 0
sample_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(input_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.prompt_logprobs)
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens = input_metadata.seq_data[
seq_ids[0]].prompt_token_ids
batched_logprobs_query_seq_indices.extend(
[sample_idx + parent_id for parent_id in parent_ids])
batched_logprobs_query_token_indices.extend(next_token_ids)
sample_idx += num_parent_seqs
assert sample_idx == num_tokens
batched_logprobs_query_result = category_logprobs[[
batched_logprobs_query_seq_indices,
batched_logprobs_query_token_indices
]].tolist()
sample_idx + j for j in range(prompt_len - 1))
batched_logprobs_query_token_indices.extend(
token_id for token_id in prompt_tokens[1:])
sample_idx += prompt_len - 1
batched_logprobs_query_seq_indices.extend(
[sample_idx + parent_id for parent_id in parent_ids])
batched_logprobs_query_token_indices.extend(next_token_ids)
if sampling_params.logprobs is not None:
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
# Build the sequence outputs.
sample_idx = 0
result_idx = 0
for seq_group_id, seq_group, sample_result in zip(
seq_group_ids, seq_groups, sample_results):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_results = len(next_token_ids)
num_parent_seqs = len(seq_ids)
parent_logprobs = category_logprobs[sample_idx:sample_idx +
num_parent_seqs]
selected_token_logprobs = batched_logprobs_query_result[
result_idx:result_idx + num_results]
seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
selected_token_logprobs,
seq_ids, parent_logprobs,
sampling_params.logprobs)
seq_outputs_dict[seq_group_id] = seq_output
sample_idx += num_parent_seqs
result_idx += num_results
assert sample_idx == num_tokens
# Batched query for logprobs of selected token
batched_logprobs_query_result = logprobs[[
batched_logprobs_query_seq_indices,
batched_logprobs_query_token_indices
]].cpu()
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]
# Batched query for logprobs of topk tokens
if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
dim=-1)
top_logprobs = top_logprobs.cpu()
top_token_ids = top_token_ids.cpu()
else:
top_logprobs, top_token_ids = None, None
# Gather results
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
result_sample_logprobs: List[SampleLogprobs] = []
sample_idx = 0
query_result_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(input_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
# Prompt logprobs
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
num_logprobs = sampling_params.prompt_logprobs
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens = input_metadata.seq_data[
seq_ids[0]].prompt_token_ids
group_prompt_logprobs: PromptLogprobs = [None]
for token_id in prompt_tokens[1:]:
prompt_logprobs_dict = {
token_id:
batched_logprobs_query_result[query_result_idx].item()
}
if num_logprobs > 0:
prompt_logprobs_dict.update(
zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
top_logprobs[sample_idx, :num_logprobs].tolist()))
group_prompt_logprobs.append(prompt_logprobs_dict)
sample_idx += 1
query_result_idx += 1
result_prompt_logprobs.append(group_prompt_logprobs)
else:
result_prompt_logprobs.append(None)
# Sample logprobs
num_logprobs = sampling_params.logprobs
if num_logprobs is None:
num_logprobs = 0
group_sample_logprobs: SampleLogprobs = []
for next_token_id, parent_id in zip(next_token_ids, parent_ids):
sample_logprobs_dict = {
next_token_id:
batched_logprobs_query_result[query_result_idx].item()
}
query_result_idx += 1
if num_logprobs > 0:
sample_logprobs_dict.update(
zip(
top_token_ids[sample_idx +
parent_id, :num_logprobs].tolist(),
top_logprobs[sample_idx +
parent_id, :num_logprobs].tolist()))
group_sample_logprobs.append(sample_logprobs_dict)
result_sample_logprobs.append(group_sample_logprobs)
sample_idx += len(seq_ids)
return result_prompt_logprobs, result_sample_logprobs
def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
input_metadata: InputMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(input_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids, _ = seq_group
next_token_ids, parent_ids = sample_result
seq_outputs = []
for parent_id, next_token_id, logprobs in zip(parent_ids,
next_token_ids,
group_sample_logprobs):
seq_outputs.append(
SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
return sampler_output

View File

@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
def tensor_model_parallel_all_reduce(input_):
"""All-reduce the input tensor across model parallel group.
Note: This operation is applied in-place on the input tensor.
NOTE: This operation is applied in-place on the input tensor.
"""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:

View File

@ -133,7 +133,7 @@ class ColumnParallelLinear(torch.nn.Module):
params_dtype = torch.get_default_dtype()
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# NOTE: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.create_weights(params_dtype)

View File

@ -41,7 +41,7 @@ def split_tensor_along_last_dim(
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
# NOTE: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)

View File

@ -1,6 +1,7 @@
from typing import Dict, List, Optional
from typing import List, Optional
from vllm.sequence import SequenceGroup, SequenceStatus
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
SequenceStatus)
class CompletionOutput:
@ -23,7 +24,7 @@ class CompletionOutput:
text: str,
token_ids: List[int],
cumulative_logprob: float,
logprobs: Optional[List[Dict[int, float]]],
logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None,
) -> None:
self.index = index
@ -61,12 +62,14 @@ class RequestOutput:
request_id: str,
prompt: str,
prompt_token_ids: List[int],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
finished: bool,
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_logprobs = prompt_logprobs
self.outputs = outputs
self.finished = finished
@ -91,7 +94,7 @@ class RequestOutput:
# NOTE: We need to take care of this case because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
logprobs = {}
logprobs = None
finshed_reason = SequenceStatus.get_finished_reason(seq.status)
output = CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(),
@ -100,15 +103,17 @@ class RequestOutput:
outputs.append(output)
# Every sequence in the sequence group should have the same prompt.
prompt = top_n_seqs[0].prompt
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished()
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
finished)
return cls(seq_group.request_id, prompt, prompt_token_ids,
prompt_logprobs, outputs, finished)
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"outputs={self.outputs}, "
f"finished={self.finished})")

View File

@ -60,6 +60,12 @@ class SamplingParams:
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
logprobs: Number of log probabilities to return per output token.
Note that the implementation follows the OpenAI API: The return
result includes the log probabilities on the `logprobs` most likely
tokens, as well the chosen tokens. The API will always return the
log probability of the sampled token, so there may be up to
`logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token.
skip_special_tokens: Whether to skip special tokens in the output.
"""
@ -80,6 +86,7 @@ class SamplingParams:
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
) -> None:
self.n = n
@ -105,6 +112,7 @@ class SamplingParams:
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs
self.prompt_logprobs = prompt_logprobs
self.skip_special_tokens = skip_special_tokens
self._verify_args()
@ -142,6 +150,9 @@ class SamplingParams:
if self.logprobs is not None and self.logprobs < 0:
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")
def _verify_beam_search(self) -> None:
if self.best_of == 1:
@ -200,4 +211,5 @@ class SamplingParams:
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"logprobs={self.logprobs}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens})")

View File

@ -6,6 +6,9 @@ from typing import Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock
from vllm.sampling_params import SamplingParams
PromptLogprobs = List[Optional[Dict[int, float]]]
SampleLogprobs = List[Dict[int, float]]
class SequenceStatus(enum.Enum):
"""Status of a sequence."""
@ -116,7 +119,7 @@ class Sequence:
self.block_size = block_size
self.data = SequenceData(prompt_token_ids)
self.output_logprobs: List[Dict[int, float]] = []
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = []
@ -196,7 +199,7 @@ class Sequence:
"""
if seq_len is None:
seq_len = self.get_len()
# Note: HF implementation does not count the EOS token
# NOTE: HF implementation does not count the EOS token
# towards the length, we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
@ -238,6 +241,19 @@ class SequenceGroup:
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params
self.arrival_time = arrival_time
self.prompt_logprobs: Optional[PromptLogprobs] = None
@property
def prompt(self) -> str:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt
@property
def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
@ -370,6 +386,22 @@ class SequenceOutputs:
and self.logprobs == other.logprobs)
class SequenceGroupOutputs:
"""The model outputs associated with a sequence group."""
def __init__(
self,
samples: List[SequenceOutputs],
prompt_logprobs: Optional[PromptLogprobs],
) -> None:
self.samples = samples
self.prompt_logprobs = prompt_logprobs
def __repr__(self) -> str:
return (f"SequenceGroupOutputs(samples={self.samples}, "
f"prompt_logprobs={self.prompt_logprobs})")
# For each sequence group, we generate a list of SequenceOutputs object,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[List[SequenceOutputs]]
SamplerOutput = List[SequenceGroupOutputs]