From 9d9072a069202e7892a40ef94e9085019e73f370 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 16 Oct 2023 10:56:50 -0700 Subject: [PATCH] Implement prompt logprobs & Batched topk for computing logprobs (#1328) Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com> --- examples/llm_engine_example.py | 2 +- tests/async_engine/test_request_tracker.py | 2 +- tests/conftest.py | 33 ++ tests/samplers/test_logprobs.py | 55 ++++ vllm/config.py | 2 +- vllm/engine/llm_engine.py | 20 +- vllm/model_executor/layers/attention.py | 2 +- vllm/model_executor/layers/sampler.py | 306 ++++++++++++------ .../parallel_utils/communication_op.py | 2 +- vllm/model_executor/parallel_utils/layers.py | 2 +- vllm/model_executor/parallel_utils/utils.py | 2 +- vllm/outputs.py | 21 +- vllm/sampling_params.py | 12 + vllm/sequence.py | 38 ++- 14 files changed, 369 insertions(+), 130 deletions(-) create mode 100644 tests/samplers/test_logprobs.py diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index cf86a4737727..df97ae5d6a8f 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -11,7 +11,7 @@ def main(args: argparse.Namespace): # Test the following prompts. test_prompts = [ ("A robot may not injure a human being", - SamplingParams(temperature=0.0)), + SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)), ("To be or not to be,", SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)), ("What is the meaning of life?", diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py index 83b306e75950..3e4d53c5cbe2 100644 --- a/tests/async_engine/test_request_tracker.py +++ b/tests/async_engine/test_request_tracker.py @@ -64,7 +64,7 @@ def test_request_tracker(): stream_5 = tracker.add_request("5") assert tracker.new_requests_event.flag tracker.process_request_output( - RequestOutput("2", "output", [], [], finished=True)) + RequestOutput("2", "output", [], [], [], finished=True)) new, finished = tracker.get_new_and_finished_requests() assert not tracker.new_requests_event.flag assert len(finished) == 1 diff --git a/tests/conftest.py b/tests/conftest.py index 24b99a27929d..cc4339849f55 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -107,6 +107,39 @@ class HfRunner: outputs[i] = (output_ids, output_str) return outputs + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + ) -> List[List[torch.Tensor]]: + all_logprobs = [] + for prompt in prompts: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + output = self.model.generate( + input_ids.cuda(), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + ) + seq_logprobs = [] + for hidden_states in output.hidden_states: + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if self.model.get_output_embeddings().bias is not None: + logits += self.model.get_output_embeddings( + ).bias.unsqueeze(0) + logprobs = torch.nn.functional.log_softmax(logits, + dim=-1, + dtype=torch.float32) + seq_logprobs.append(logprobs) + all_logprobs.append(seq_logprobs) + return all_logprobs + @pytest.fixture def hf_runner(): diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py new file mode 100644 index 000000000000..1c67cc5bd739 --- /dev/null +++ b/tests/samplers/test_logprobs.py @@ -0,0 +1,55 @@ +import pytest +import torch + 
+from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_get_prompt_logprobs( + hf_runner, + vllm_runner, + model, + dtype, + example_prompts, +): + max_tokens = 5 + hf_model = hf_runner(model, dtype=dtype) + hf_logprobs = hf_model.generate_greedy_logprobs( + example_prompts, + max_tokens=max_tokens, + ) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_sampling_params = SamplingParams(max_tokens=max_tokens, + logprobs=5, + prompt_logprobs=5, + temperature=0.0) + vllm_results = vllm_model.model.generate( + example_prompts, sampling_params=vllm_sampling_params) + + # Test whether logprobs are included in the results. + for result in vllm_results: + assert result.prompt_logprobs is not None + assert result.outputs[0].logprobs is not None + + # Test whether prompt logprobs are consistent with HF + for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): + # Check prompt logprobs + vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] + for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): + for token_id, logprob in vllm_prompt_logprob_dict.items(): + torch.testing.assert_close(logprob, + hf_logprob[0][i][token_id].item(), + atol=1e-2, + rtol=1e-2) + vllm_sample_logprobs = vllm_result.outputs[0].logprobs + for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs): + for token_id, logprob in vllm_sample_logprob_dict.items(): + torch.testing.assert_close(logprob, + hf_logprob[i][-1][token_id].item(), + atol=1e-2, + rtol=1e-2) diff --git a/vllm/config.py b/vllm/config.py index f66fb291c994..90ffe82235a4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -143,7 +143,7 @@ class ModelConfig: def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: """Returns the number of KV heads per GPU worker.""" # For GPTBigCode & Falcon: - # Note: for falcon, when new_decoder_architecture is True, the + # NOTE: for falcon, when new_decoder_architecture is True, the # multi_query flag is ignored and we use n_head_kv for the number of # KV heads. 
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 74a8905a916d..f0d868d3afec 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -12,8 +12,8 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceOutputs, - SequenceStatus) + SequenceGroupMetadata, SequenceGroupOutputs, + SequenceOutputs, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) from vllm.utils import Counter @@ -350,9 +350,15 @@ class LLMEngine: eos_token_id=self.tokenizer.eos_token_id)) return current_worst_score >= highest_attainable_score - def _process_sequence_group_samples( - self, seq_group: SequenceGroup, - samples: List[SequenceOutputs]) -> None: + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutputs) -> None: + # Process prompt logprobs + prompt_logprobs = outputs.prompt_logprobs + if prompt_logprobs is not None: + seq_group.prompt_logprobs = prompt_logprobs + + # Process samples + samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) existing_finished_seqs = seq_group.get_finished_seqs() parent_child_dict = { @@ -520,8 +526,8 @@ class LLMEngine: scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: # Update the scheduled sequence groups with the model outputs. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for seq_group, samples in zip(scheduled_seq_groups, output): - self._process_sequence_group_samples(seq_group, samples) + for seq_group, outputs in zip(scheduled_seq_groups, output): + self._process_sequence_group_outputs(seq_group, outputs) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 0677ebbae792..084f4d98270e 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -420,7 +420,7 @@ class PagedAttentionWithALiBi(PagedAttention): # Generates ALiBi mask for each prompt. for prompt_len in input_metadata.prompt_lens: bias = torch.arange(prompt_len, dtype=dtype) - # Note(zhuohan): HF uses + # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(prompt_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index a9f036fa2f24..2c8652ee6c3a 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -8,7 +8,8 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather) from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs +from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, + SequenceData, SequenceGroupOutputs, SequenceOutputs) _SAMPLING_EPS = 1e-5 @@ -82,7 +83,12 @@ class Sampler(nn.Module): logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. 
- return _sample(probs, logprobs, input_metadata) + sample_results = _sample(probs, logprobs, input_metadata) + # Get the logprobs query results. + prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, input_metadata, sample_results) + return _build_sampler_output(sample_results, input_metadata, + prompt_logprobs, sample_logprobs) def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, @@ -102,24 +108,28 @@ def _prune_hidden_states( hidden_states: torch.Tensor, input_metadata: InputMetadata, ) -> torch.Tensor: - last_token_indices = [] + selected_token_indices: List[int] = [] start_idx = 0 for i, seq_group in enumerate(input_metadata.seq_groups): - seq_ids, _ = seq_group + seq_ids, sampling_params = seq_group if i < input_metadata.num_prompts: assert len(seq_ids) == 1, "Prompt input should have only one seq." prompt_len = input_metadata.prompt_lens[i] - last_token_indices.append(start_idx + prompt_len - 1) + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(start_idx, start_idx + prompt_len - 1)) + selected_token_indices.append(start_idx + prompt_len - 1) start_idx += prompt_len else: num_seqs = len(seq_ids) - last_token_indices.extend(range(start_idx, start_idx + num_seqs)) + selected_token_indices.extend( + range(start_idx, start_idx + num_seqs)) start_idx += num_seqs - last_token_indices = torch.tensor(last_token_indices, - dtype=torch.long, - device=hidden_states.device) - return hidden_states.index_select(0, last_token_indices) + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long, + device=hidden_states.device) + return hidden_states.index_select(0, selected_token_indices) def _get_penalties( @@ -127,10 +137,17 @@ def _get_penalties( # Collect the presence and frequency penalties. presence_penalties: List[float] = [] frequency_penalties: List[float] = [] - for seq_group in input_metadata.seq_groups: + for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group p = sampling_params.presence_penalty f = sampling_params.frequency_penalty + if (i < input_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + # NOTE: We do not apply presence and frequency penalties for the + # prompt token positions where we don't sample new tokens. + prompt_len = input_metadata.prompt_lens[i] + presence_penalties += [0] * (prompt_len - 1) + frequency_penalties += [0] * (prompt_len - 1) presence_penalties += [p] * len(seq_ids) frequency_penalties += [f] * len(seq_ids) return presence_penalties, frequency_penalties @@ -138,8 +155,14 @@ def _get_penalties( def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: output_tokens: List[List[int]] = [] - for seq_group in input_metadata.seq_groups: - seq_ids, _ = seq_group + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + if (i < input_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + # NOTE: prompt token positions do not need output tokens to + # compute penalties. + prompt_len = input_metadata.prompt_lens[i] + output_tokens.extend([] for _ in range(prompt_len - 1)) for seq_id in seq_ids: seq_data = input_metadata.seq_data[seq_id] output_tokens.append(seq_data.output_token_ids) @@ -200,7 +223,7 @@ def _apply_penalties( def _get_temperatures(input_metadata: InputMetadata) -> List[float]: # Collect the temperatures for the logits. 
temperatures: List[float] = [] - for seq_group in input_metadata.seq_groups: + for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature if temperature < _SAMPLING_EPS: @@ -208,6 +231,10 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]: # (i.e., greedy sampling or beam search). # Set the temperature to 1 to avoid division by zero. temperature = 1.0 + if (i < input_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + prompt_len = input_metadata.prompt_lens[i] + temperatures += [temperature] * (prompt_len - 1) temperatures += [temperature] * len(seq_ids) return temperatures @@ -218,13 +245,18 @@ def _get_top_p_top_k( ) -> Tuple[List[float], List[int]]: top_ps: List[float] = [] top_ks: List[int] = [] - for seq_group in input_metadata.seq_groups: + for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group top_p = sampling_params.top_p # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) # k=-1 means no truncation. top_k = vocab_size if top_k == -1 else top_k + if (i < input_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + prompt_len = input_metadata.prompt_lens[i] + top_ps += [top_p] * (prompt_len - 1) + top_ks += [top_k] * (prompt_len - 1) top_ps += [top_p] * len(seq_ids) top_ks += [top_k] * len(seq_ids) return top_ps, top_ks @@ -259,49 +291,6 @@ def _apply_top_p_top_k( return logits -def _get_topk_logprobs( - logprobs: torch.Tensor, - num_logprobs: Optional[int], -) -> List[Dict[int, float]]: - num_seqs = logprobs.size(0) - if num_logprobs is None or num_logprobs == 0: - return [{} for _ in range(num_seqs)] - - all_topk_logprobs, all_topk_ids = torch.topk(logprobs, - num_logprobs, - dim=-1) - all_topk_logprobs = all_topk_logprobs.cpu() - all_topk_ids = all_topk_ids.cpu() - all_token_to_logprob = [] - for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids): - token_to_logprob: Dict[int, float] = {} - for token_id, logprob in zip(topk_ids, topk_logprobs): - token_to_logprob[token_id.item()] = logprob.item() - all_token_to_logprob.append(token_to_logprob) - return all_token_to_logprob - - -def _build_sequence_outputs( - parent_ids: List[int], - next_token_ids: List[int], - selected_token_logprobs: List[float], - parent_seq_ids: List[int], - parent_logprobs: torch.Tensor, - num_output_logprobs: Optional[int], -) -> List[SequenceOutputs]: - # Get top-k log probabilities for the next tokens. - next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs) - seq_outputs: List[SequenceOutputs] = [] - for parent_id, next_token_id, token_logprob in zip( - parent_ids, next_token_ids, selected_token_logprobs): - output_logprobs = next_logprobs[parent_id].copy() - output_logprobs[next_token_id] = token_logprob - seq_outputs.append( - SequenceOutputs(parent_seq_ids[parent_id], next_token_id, - output_logprobs)) - return seq_outputs - - def _greedy_sample( selected_seq_groups: List[Tuple[List[int], SamplingParams]], logprobs: torch.Tensor, @@ -372,7 +361,7 @@ def _beam_search_sample( # for details. See also HF reference: # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 # - # Note: Beam search is not vectorized, so its speed can be slower than + # NOTE: Beam search is not vectorized, so its speed can be slower than # other sampling methods. 
sample_idx = 0 results = [] @@ -416,79 +405,186 @@ def _sample( probs: torch.Tensor, logprobs: torch.Tensor, input_metadata: InputMetadata, -) -> SamplerOutput: +) -> List[Tuple[List[int], List[int]]]: categorized_seq_group_ids = {t: [] for t in SamplingType} + categorized_sample_indices = {t: [] for t in SamplingType} start_idx = 0 - categorized_seq_ids = {t: [] for t in SamplingType} for i, seq_group in enumerate(input_metadata.seq_groups): seq_ids, sampling_params = seq_group sampling_type = sampling_params.sampling_type + if (i < input_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + # NOTE: prompt token positions do not need sample, skip + prompt_len = input_metadata.prompt_lens[i] + start_idx += prompt_len - 1 categorized_seq_group_ids[sampling_type].append(i) num_seqs = len(seq_ids) - categorized_seq_ids[sampling_type].extend( + categorized_sample_indices[sampling_type].extend( range(start_idx, start_idx + num_seqs)) start_idx += num_seqs - seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {} + + sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} for sampling_type in SamplingType: seq_group_ids = categorized_seq_group_ids[sampling_type] seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids] is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids] - num_tokens = len(categorized_seq_ids[sampling_type]) + sample_indices = categorized_sample_indices[sampling_type] + num_tokens = len(sample_indices) if num_tokens == 0: continue - category_logprobs = logprobs[categorized_seq_ids[sampling_type]] - category_probs = probs[categorized_seq_ids[sampling_type]] if sampling_type == SamplingType.GREEDY: + category_logprobs = logprobs[sample_indices] sample_results = _greedy_sample(seq_groups, category_logprobs) elif sampling_type == SamplingType.RANDOM: + category_probs = probs[sample_indices] sample_results = _random_sample(seq_groups, is_prompts, category_probs) elif sampling_type == SamplingType.BEAM: + category_logprobs = logprobs[sample_indices] sample_results = _beam_search_sample(seq_groups, is_prompts, input_metadata.seq_data, category_logprobs) else: raise ValueError(f"Unsupported sampling type: {sampling_type}") + sample_results_dict.update(zip(seq_group_ids, sample_results)) - # Batched query for logprobs of selected token - batched_logprobs_query_seq_indices: List[int] = [] - batched_logprobs_query_token_indices: List[int] = [] - sample_idx = 0 - for seq_group_id, seq_group, sample_result in zip( - seq_group_ids, seq_groups, sample_results): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result - num_parent_seqs = len(seq_ids) + sample_results = [ + sample_results_dict[i] for i in range(len(input_metadata.seq_groups)) + ] + return sample_results + + +def _get_logprobs( + logprobs: torch.Tensor, + input_metadata: InputMetadata, + sample_results: List[Tuple[List[int], List[int]]], +) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[ + int, float]]]]: + # Prepare query indices + batched_logprobs_query_seq_indices: List[int] = [] + batched_logprobs_query_token_indices: List[int] = [] + largest_num_logprobs = 0 + sample_idx = 0 + for i, (seq_group, sample_result) in enumerate( + zip(input_metadata.seq_groups, sample_results)): + seq_ids, sampling_params = seq_group + next_token_ids, parent_ids = sample_result + num_parent_seqs = len(seq_ids) + if (i < input_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + largest_num_logprobs = 
max(largest_num_logprobs, + sampling_params.prompt_logprobs) + prompt_len = input_metadata.prompt_lens[i] + prompt_tokens = input_metadata.seq_data[ + seq_ids[0]].prompt_token_ids batched_logprobs_query_seq_indices.extend( - [sample_idx + parent_id for parent_id in parent_ids]) - batched_logprobs_query_token_indices.extend(next_token_ids) - sample_idx += num_parent_seqs - assert sample_idx == num_tokens - batched_logprobs_query_result = category_logprobs[[ - batched_logprobs_query_seq_indices, - batched_logprobs_query_token_indices - ]].tolist() + sample_idx + j for j in range(prompt_len - 1)) + batched_logprobs_query_token_indices.extend( + token_id for token_id in prompt_tokens[1:]) + sample_idx += prompt_len - 1 + batched_logprobs_query_seq_indices.extend( + [sample_idx + parent_id for parent_id in parent_ids]) + batched_logprobs_query_token_indices.extend(next_token_ids) + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) + sample_idx += num_parent_seqs + assert sample_idx == logprobs.size(0) - # Build the sequence outputs. - sample_idx = 0 - result_idx = 0 - for seq_group_id, seq_group, sample_result in zip( - seq_group_ids, seq_groups, sample_results): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result - num_results = len(next_token_ids) - num_parent_seqs = len(seq_ids) - parent_logprobs = category_logprobs[sample_idx:sample_idx + - num_parent_seqs] - selected_token_logprobs = batched_logprobs_query_result[ - result_idx:result_idx + num_results] - seq_output = _build_sequence_outputs(parent_ids, next_token_ids, - selected_token_logprobs, - seq_ids, parent_logprobs, - sampling_params.logprobs) - seq_outputs_dict[seq_group_id] = seq_output - sample_idx += num_parent_seqs - result_idx += num_results - assert sample_idx == num_tokens + # Batched query for logprobs of selected token + batched_logprobs_query_result = logprobs[[ + batched_logprobs_query_seq_indices, + batched_logprobs_query_token_indices + ]].cpu() - return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))] + # Batched query for logprobs of topk tokens + if largest_num_logprobs > 0: + top_logprobs, top_token_ids = torch.topk(logprobs, + largest_num_logprobs, + dim=-1) + top_logprobs = top_logprobs.cpu() + top_token_ids = top_token_ids.cpu() + else: + top_logprobs, top_token_ids = None, None + + # Gather results + result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] + result_sample_logprobs: List[SampleLogprobs] = [] + sample_idx = 0 + query_result_idx = 0 + for i, (seq_group, sample_result) in enumerate( + zip(input_metadata.seq_groups, sample_results)): + seq_ids, sampling_params = seq_group + next_token_ids, parent_ids = sample_result + + # Prompt logprobs + if (i < input_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + num_logprobs = sampling_params.prompt_logprobs + prompt_len = input_metadata.prompt_lens[i] + prompt_tokens = input_metadata.seq_data[ + seq_ids[0]].prompt_token_ids + group_prompt_logprobs: PromptLogprobs = [None] + for token_id in prompt_tokens[1:]: + prompt_logprobs_dict = { + token_id: + batched_logprobs_query_result[query_result_idx].item() + } + if num_logprobs > 0: + prompt_logprobs_dict.update( + zip(top_token_ids[sample_idx, :num_logprobs].tolist(), + top_logprobs[sample_idx, :num_logprobs].tolist())) + group_prompt_logprobs.append(prompt_logprobs_dict) + sample_idx += 1 + query_result_idx += 1 + 
result_prompt_logprobs.append(group_prompt_logprobs) + else: + result_prompt_logprobs.append(None) + + # Sample logprobs + num_logprobs = sampling_params.logprobs + if num_logprobs is None: + num_logprobs = 0 + group_sample_logprobs: SampleLogprobs = [] + for next_token_id, parent_id in zip(next_token_ids, parent_ids): + sample_logprobs_dict = { + next_token_id: + batched_logprobs_query_result[query_result_idx].item() + } + query_result_idx += 1 + if num_logprobs > 0: + sample_logprobs_dict.update( + zip( + top_token_ids[sample_idx + + parent_id, :num_logprobs].tolist(), + top_logprobs[sample_idx + + parent_id, :num_logprobs].tolist())) + group_sample_logprobs.append(sample_logprobs_dict) + result_sample_logprobs.append(group_sample_logprobs) + sample_idx += len(seq_ids) + + return result_prompt_logprobs, result_sample_logprobs + + +def _build_sampler_output( + sample_results: List[Tuple[List[int], List[int]]], + input_metadata: InputMetadata, + prompt_logprobs: List[Optional[PromptLogprobs]], + sample_logprobs: List[SampleLogprobs], +) -> SamplerOutput: + sampler_output = [] + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(input_metadata.seq_groups, + sample_results, prompt_logprobs, + sample_logprobs): + seq_ids, _ = seq_group + next_token_ids, parent_ids = sample_result + seq_outputs = [] + for parent_id, next_token_id, logprobs in zip(parent_ids, + next_token_ids, + group_sample_logprobs): + seq_outputs.append( + SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs)) + sampler_output.append( + SequenceGroupOutputs(seq_outputs, group_prompt_logprobs)) + return sampler_output diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index f977397f9d54..b1d5f5b9fb88 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.parallel_state import ( def tensor_model_parallel_all_reduce(input_): """All-reduce the input tensor across model parallel group. - Note: This operation is applied in-place on the input tensor. + NOTE: This operation is applied in-place on the input tensor. """ # Bypass the function if we are using only 1 GPU. if get_tensor_model_parallel_world_size() == 1: diff --git a/vllm/model_executor/parallel_utils/layers.py b/vllm/model_executor/parallel_utils/layers.py index 6b5ecc4c6a92..c1aea2c1d554 100644 --- a/vllm/model_executor/parallel_utils/layers.py +++ b/vllm/model_executor/parallel_utils/layers.py @@ -133,7 +133,7 @@ class ColumnParallelLinear(torch.nn.Module): params_dtype = torch.get_default_dtype() # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result + # NOTE: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. self.create_weights(params_dtype) diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/model_executor/parallel_utils/utils.py index 5d0a7595a4a2..004a81f130a7 100644 --- a/vllm/model_executor/parallel_utils/utils.py +++ b/vllm/model_executor/parallel_utils/utils.py @@ -41,7 +41,7 @@ def split_tensor_along_last_dim( last_dim_size = divide(tensor.size()[last_dim], num_partitions) # Split. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. + # NOTE: torch.split does not create contiguous tensors by default. 
if contiguous_split_chunks: return tuple(chunk.contiguous() for chunk in tensor_list) diff --git a/vllm/outputs.py b/vllm/outputs.py index 64ba8440e3ef..ad6733ff5723 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,6 +1,7 @@ -from typing import Dict, List, Optional +from typing import List, Optional -from vllm.sequence import SequenceGroup, SequenceStatus +from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, + SequenceStatus) class CompletionOutput: @@ -23,7 +24,7 @@ class CompletionOutput: text: str, token_ids: List[int], cumulative_logprob: float, - logprobs: Optional[List[Dict[int, float]]], + logprobs: Optional[SampleLogprobs], finish_reason: Optional[str] = None, ) -> None: self.index = index @@ -61,12 +62,14 @@ class RequestOutput: request_id: str, prompt: str, prompt_token_ids: List[int], + prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, ) -> None: self.request_id = request_id self.prompt = prompt self.prompt_token_ids = prompt_token_ids + self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished @@ -91,7 +94,7 @@ class RequestOutput: # NOTE: We need to take care of this case because the sequence # always has the logprobs of the sampled tokens even if the # logprobs are not requested. - logprobs = {} + logprobs = None finshed_reason = SequenceStatus.get_finished_reason(seq.status) output = CompletionOutput(seqs.index(seq), seq.output_text, seq.get_output_token_ids(), @@ -100,15 +103,17 @@ class RequestOutput: outputs.append(output) # Every sequence in the sequence group should have the same prompt. - prompt = top_n_seqs[0].prompt - prompt_token_ids = top_n_seqs[0].data.prompt_token_ids + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() - return cls(seq_group.request_id, prompt, prompt_token_ids, outputs, - finished) + return cls(seq_group.request_id, prompt, prompt_token_ids, + prompt_logprobs, outputs, finished) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " f"finished={self.finished})") diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 9155c48df067..10e97d1fcb19 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -60,6 +60,12 @@ class SamplingParams: tokens after the EOS token is generated. max_tokens: Maximum number of tokens to generate per output sequence. logprobs: Number of log probabilities to return per output token. + Note that the implementation follows the OpenAI API: The return + result includes the log probabilities on the `logprobs` most likely + tokens, as well the chosen tokens. The API will always return the + log probability of the sampled token, so there may be up to + `logprobs+1` elements in the response. + prompt_logprobs: Number of log probabilities to return per prompt token. skip_special_tokens: Whether to skip special tokens in the output. 
""" @@ -80,6 +86,7 @@ class SamplingParams: ignore_eos: bool = False, max_tokens: int = 16, logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, ) -> None: self.n = n @@ -105,6 +112,7 @@ class SamplingParams: self.ignore_eos = ignore_eos self.max_tokens = max_tokens self.logprobs = logprobs + self.prompt_logprobs = prompt_logprobs self.skip_special_tokens = skip_special_tokens self._verify_args() @@ -142,6 +150,9 @@ class SamplingParams: if self.logprobs is not None and self.logprobs < 0: raise ValueError( f"logprobs must be non-negative, got {self.logprobs}.") + if self.prompt_logprobs is not None and self.prompt_logprobs < 0: + raise ValueError(f"prompt_logprobs must be non-negative, got " + f"{self.prompt_logprobs}.") def _verify_beam_search(self) -> None: if self.best_of == 1: @@ -200,4 +211,5 @@ class SamplingParams: f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " f"skip_special_tokens={self.skip_special_tokens})") diff --git a/vllm/sequence.py b/vllm/sequence.py index c197e3fd9d80..5847626b0306 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,6 +6,9 @@ from typing import Dict, List, Optional, Union from vllm.block import LogicalTokenBlock from vllm.sampling_params import SamplingParams +PromptLogprobs = List[Optional[Dict[int, float]]] +SampleLogprobs = List[Dict[int, float]] + class SequenceStatus(enum.Enum): """Status of a sequence.""" @@ -116,7 +119,7 @@ class Sequence: self.block_size = block_size self.data = SequenceData(prompt_token_ids) - self.output_logprobs: List[Dict[int, float]] = [] + self.output_logprobs: SampleLogprobs = [] self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] @@ -196,7 +199,7 @@ class Sequence: """ if seq_len is None: seq_len = self.get_len() - # Note: HF implementation does not count the EOS token + # NOTE: HF implementation does not count the EOS token # towards the length, we align with that here for testing. if (eos_token_id is not None and self.get_last_token_id() == eos_token_id): @@ -238,6 +241,19 @@ class SequenceGroup: self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.prompt_logprobs: Optional[PromptLogprobs] = None + + @property + def prompt(self) -> str: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).prompt + + @property + def prompt_token_ids(self) -> List[int]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).data.prompt_token_ids def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining @@ -370,6 +386,22 @@ class SequenceOutputs: and self.logprobs == other.logprobs) +class SequenceGroupOutputs: + """The model outputs associated with a sequence group.""" + + def __init__( + self, + samples: List[SequenceOutputs], + prompt_logprobs: Optional[PromptLogprobs], + ) -> None: + self.samples = samples + self.prompt_logprobs = prompt_logprobs + + def __repr__(self) -> str: + return (f"SequenceGroupOutputs(samples={self.samples}, " + f"prompt_logprobs={self.prompt_logprobs})") + + # For each sequence group, we generate a list of SequenceOutputs object, # each of which contains one possible candidate for the next token. 
-SamplerOutput = List[List[SequenceOutputs]] +SamplerOutput = List[SequenceGroupOutputs]
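
For reference, a minimal usage sketch of the `logprobs` / `prompt_logprobs` parameters this patch adds. The model name and prompt are taken from the tests and example file changed above; the surrounding script is illustrative only and not part of the diff:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0,
                        max_tokens=5,
                        logprobs=5,          # top-5 logprobs for each sampled token
                        prompt_logprobs=5)   # top-5 logprobs for each prompt token

for output in llm.generate(["A robot may not injure a human being"], params):
    # output.prompt_logprobs: List[Optional[Dict[int, float]]]. The first entry
    # is None because the first prompt token has no preceding context.
    print(output.prompt_logprobs)
    # output.outputs[0].logprobs: List[Dict[int, float]], one dict per generated
    # token, always including the sampled token itself (so up to logprobs+1
    # entries, mirroring the OpenAI API behavior described in sampling_params.py).
    print(output.outputs[0].logprobs)
```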
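
The "batched topk" in the title refers to the pattern in `_get_logprobs`: a single `torch.topk` call sized to the largest number of logprobs requested across all sequences, after which each sequence slices only the entries it asked for. A simplified, self-contained sketch of that pattern (names are illustrative, not vLLM API):

```python
from typing import Dict, List

import torch


def batched_topk_logprobs(logprobs: torch.Tensor,
                          num_logprobs_per_row: List[int]) -> List[Dict[int, float]]:
    # logprobs: [num_tokens, vocab_size]; num_logprobs_per_row[i] is how many
    # top entries row i requested (0 means none).
    largest = max(num_logprobs_per_row, default=0)
    if largest == 0:
        return [{} for _ in num_logprobs_per_row]
    # One top-k over the whole batch instead of one call per sequence.
    top_vals, top_ids = torch.topk(logprobs, largest, dim=-1)
    top_vals, top_ids = top_vals.cpu(), top_ids.cpu()
    results: List[Dict[int, float]] = []
    for row, k in enumerate(num_logprobs_per_row):
        # Each row keeps only the k entries it actually asked for.
        results.append(dict(zip(top_ids[row, :k].tolist(),
                                top_vals[row, :k].tolist())))
    return results
```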