diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 41b7f3da1e83..57d6d2a410ee 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -9,15 +9,26 @@ MODELS = ["facebook/opt-125m"] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1]) +@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size def test_get_prompt_logprobs( hf_runner, vllm_runner, model, dtype, + chunked_prefill_token_size: int, + num_top_logprobs: int, example_prompts, ): + max_num_seqs = 256 + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_seqs = min(chunked_prefill_token_size, max_num_seqs) + max_num_batched_tokens = chunked_prefill_token_size + max_tokens = 5 - num_top_logprobs = 6 hf_model = hf_runner(model, dtype=dtype) hf_logprobs = hf_model.generate_greedy_logprobs( example_prompts, @@ -25,10 +36,17 @@ def test_get_prompt_logprobs( ) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) + vllm_model = vllm_runner( + model, + dtype=dtype, + max_logprobs=num_top_logprobs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + ) vllm_sampling_params = SamplingParams(max_tokens=max_tokens, logprobs=num_top_logprobs, - prompt_logprobs=5, + prompt_logprobs=num_top_logprobs, temperature=0.0) vllm_results = vllm_model.model.generate( example_prompts, sampling_params=vllm_sampling_params) @@ -52,9 +70,18 @@ def test_get_prompt_logprobs( "The output text from the top logprob for each token position " "should be the same as the output text in the result.") + # The first prompt logprob is always None + assert result.prompt_logprobs[0] is None + for prompt_logprobs in result.prompt_logprobs[1:]: + # If the prompt token is not included in the top X + # logprob, it can return 1 more data + assert (len(prompt_logprobs) == num_top_logprobs + or len(prompt_logprobs) == num_top_logprobs + 1) + # Test whether prompt logprobs are consistent with HF for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs): # Check prompt logprobs + # The first prompt logprob is always None, so we compare it from 1:. vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:] for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs): for token_id, logprob in vllm_prompt_logprob_dict.items(): @@ -74,6 +101,17 @@ def test_get_prompt_logprobs( "The token should be decoded by the time it is returned " " to the user.") + # Test if prompt logprobs are correctly set. + for vllm_result in vllm_results: + token_ids = vllm_result.prompt_token_ids + prompt_logprobs = vllm_result.prompt_logprobs + + # The first token doesn't have logprob. 
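# Illustrative sketch (not part of the patch): the invariant the test above
# checks, written as a standalone helper. `check_prompt_logprobs` is a
# hypothetical name; `prompt_logprobs` is assumed to be the per-request list of
# {token_id: Logprob} dicts that vLLM returns, with None for the first token.
from typing import Dict, List, Optional


def check_prompt_logprobs(prompt_logprobs: List[Optional[Dict[int, object]]],
                          num_top_logprobs: int) -> None:
    # The first prompt token has no preceding context, so it has no logprob.
    assert prompt_logprobs[0] is None
    for entry in prompt_logprobs[1:]:
        # Each entry holds the top-k candidates, plus one extra slot when the
        # actual prompt token is not already among the top-k.
        assert len(entry) in (num_top_logprobs, num_top_logprobs + 1)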
+ assert prompt_logprobs[0] is None + + for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]): + assert token_id in logprob_dict + def test_max_logprobs(): runner = VllmRunner("facebook/opt-125m", max_logprobs=1) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 52a2b0ca52aa..6f2145f8cdcf 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -8,6 +8,7 @@ import torch from transformers import GenerationConfig, GenerationMixin from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import Counter @@ -54,6 +55,7 @@ def _do_sample( sampler: MockLogitsSampler, model_runner: ModelRunner, sampling_params: SamplingParams, + device: str, ): seq_group_metadata_list = [] prompt_lens = [] @@ -68,9 +70,12 @@ def _do_sample( )) prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=device, + pin_memory=model_runner.pin_memory) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -85,7 +90,7 @@ def test_sampler_all_greedy(seed: int, device: str): sampling_params = SamplingParams(temperature=0) sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampling_params, device) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -111,7 +116,7 @@ def test_sampler_all_random(seed: int, device: str): n=random.randint(1, 10), ) sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -137,7 +142,7 @@ def test_sampler_all_random_seed(seed: int, device: str): seed=random.randint(0, 10000), ) sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner, - sampling_params) + sampling_params, device) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -160,10 +165,10 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): seed=random.randint(0, 10000), ) first_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params) + model_runner, sampling_params, device) second_sampler_output = _do_sample(batch_size, fake_logits, sampler, - model_runner, sampling_params) + model_runner, sampling_params, device) assert first_sampler_output == second_sampler_output @@ -183,7 +188,8 @@ def test_sampler_all_beam(seed: int, device: str): best_of=2, use_beam_search=True, ) - _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params) + _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params, + device) # no assertion here as I am not sure how to determine whether # the outputs are expected - in other words, this just tests # whether there are no exceptions in the sampler @@ -443,10 +449,12 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): "batch size") _, fake_logits, sampler, model_runner = 
_prepare_test(batch_size) - sampling_metadata = model_runner._prepare_sample( + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, prompt_lens=prompt_lens if prompt_lens else None, - subquery_lens=prompt_lens if prompt_lens else None) + subquery_lens=prompt_lens if prompt_lens else None, + device=device, + pin_memory=model_runner.pin_memory) # the logits tensor is modified in-place by the sampler _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -530,8 +538,12 @@ def test_sampler_mixed(seed: int, device: str): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) def test_sampling(model_runner: ModelRunner): - sampling_metadata = model_runner._prepare_sample( - seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=device, + pin_memory=model_runner.pin_memory) sampler_output = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) @@ -627,9 +639,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): )) prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=device, + pin_memory=model_runner.pin_memory) sample_probs = None diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 5bb93ca74855..dbaeb4de1825 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -6,6 +6,7 @@ import pytest import torch from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner @@ -82,9 +83,12 @@ def test_logits_processors(seed: int, device: str): )) prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=model_runner.device, + pin_memory=model_runner.pin_memory) logits_processor_output = logits_processor( embedding=None, hidden_states=input_tensor, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 59bed2ce0dad..abb401f25c10 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -2,6 +2,7 @@ import pytest import torch from vllm.config import ModelConfig, SchedulerConfig +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -97,9 +98,12 @@ def test_prepare_prompt(batch_size): assert len(input_positions) == sum(prompt_lens) torch.testing.assert_close(input_tokens, input_positions) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=model_runner.device, + 
pin_memory=model_runner.pin_memory) assert len(input_tokens) == sum(prompt_lens) assert len(input_positions) == sum(prompt_lens) actual = sampling_metadata.selected_token_indices @@ -195,9 +199,12 @@ def test_prepare_decode_cuda_graph(batch_size): for prompt_len in prompt_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + subquery_lens=prompt_lens, + device=model_runner.device, + pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ac3bd7d228e9..7439f7dc33e8 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -915,6 +915,20 @@ class Scheduler: self.block_manager.get_common_computed_block_ids( seq_group.get_seqs(status=SequenceStatus.RUNNING))) + do_sample = True + if seq_group.is_prefill(): + seqs = seq_group.get_seqs() + # Prefill has only 1 sequence. + assert len(seqs) == 1 + # In the next iteration, all prompt tokens are not computed. + # It means the prefill is chunked, and we don't need sampling. + # NOTE: We use get_len instead of get_prompt_len because when + # a sequence is preempted, prefill includes previous generated + # output tokens. + if (token_chunk_size + seqs[0].data.get_num_computed_tokens() < + seqs[0].data.get_len()): + do_sample = False + # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. is_prompt = seq_group.is_prefill() @@ -924,6 +938,7 @@ class Scheduler: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + do_sample=do_sample, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 518532e4a280..89ee3f0db491 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -219,7 +219,7 @@ class _AsyncLLMEngine(LLMEngine): request_outputs = self._process_model_outputs( output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups) + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) # Log stats. 
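# Illustrative sketch (not part of the patch): the chunked-prefill condition
# added to the scheduler above, reduced to plain arithmetic. A prefill only
# needs to sample when its scheduled chunk reaches the end of the sequence;
# earlier chunks produce no output token. `needs_sampling` is a hypothetical
# helper name.
def needs_sampling(token_chunk_size: int, num_computed_tokens: int,
                   seq_len: int) -> bool:
    return token_chunk_size + num_computed_tokens >= seq_len


# A 12-token prompt prefilled in chunks of 4: only the last chunk samples.
assert not needs_sampling(4, 0, 12)   # covers tokens 0-3, prompt not done
assert not needs_sampling(4, 4, 12)   # covers tokens 4-7, prompt not done
assert needs_sampling(4, 8, 12)       # covers tokens 8-11, do_sample stays True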
if self.log_stats: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d2f5379e621c..741d3bcd8089 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,7 +22,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceStage) + SequenceGroup, SequenceGroupMetadata) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -476,9 +476,12 @@ class LLMEngine: return self.scheduler.has_unfinished_seqs() def _process_model_outputs( - self, output: List[SamplerOutput], - scheduled_seq_groups: List[SequenceGroup], - ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]: + self, + output: List[SamplerOutput], + scheduled_seq_groups: List[SequenceGroup], + ignored_seq_groups: List[SequenceGroup], + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> List[RequestOutput]: """Apply the model output to the sequences in the scheduled seq groups. Returns RequestOutputs that can be returned to the client. @@ -492,17 +495,15 @@ class LLMEngine: sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) # Update the scheduled sequence groups with the model outputs. - for scheduled_seq_group, outputs in zip(scheduled_seq_groups, - output_by_sequence_group): + for scheduled_seq_group, outputs, seq_group_meta in zip( + scheduled_seq_groups, output_by_sequence_group, + seq_group_metadata_list): seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) - # If all sequences in the sequence group are in DECODE, then we can - # process the output tokens. Otherwise, they are (chunked) prefill - # samples and should not be processed. - stages = [seq.data._stage for seq in seq_group.seqs_dict.values()] - if all(stage == SequenceStage.DECODE for stage in stages): + self.output_processor.process_prompt_logprob(seq_group, outputs) + if seq_group_meta.do_sample: self.output_processor.process_outputs(seq_group, outputs) # Free the finished sequence groups. @@ -585,7 +586,7 @@ class LLMEngine: request_outputs = self._process_model_outputs( output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups) + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) # Log stats. if self.log_stats: diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index f307ea4da301..9ddb6a3648b8 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC): scheduler. 
""" pass + + @abstractmethod + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Update prompt logprobs received from outputs to seq_group.""" + pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 39e99d06ed87..9abd87a4d5a9 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -44,6 +44,15 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): self.get_tokenizer_for_seq = get_tokenizer_for_seq self.stop_checker = stop_checker + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + # TODO(sang): Prompt logprob currently not implemented in multi step + # workers. + logger.warning( + "Prompt logprob is not supported by multi step workers. " + "(e.g., speculative decode uses multi step workers).") + pass + def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: """Append new tokens in the outputs to sequences in the sequence group. diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 7e9d65244670..07b140584bbe 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ), f"{type(self)} does not support multiple outputs per step" return self._process_sequence_group_outputs(sequence_group, outputs[0]) - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: - - # Process prompt logprobs - prompt_logprobs = outputs.prompt_logprobs - if prompt_logprobs is not None and \ - seq_group.sampling_params.detokenize and self.detokenizer: + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + assert len(outputs) == 1, ("Single step should only has 1 output.") + output = outputs[0] + prompt_logprobs = output.prompt_logprobs + if (prompt_logprobs is not None + and seq_group.sampling_params.detokenize and self.detokenizer): self.detokenizer.decode_prompt_logprobs_inplace( seq_group, prompt_logprobs) - seq_group.prompt_logprobs = prompt_logprobs + if not seq_group.prompt_logprobs: + # The first prompt token's logprob is None because it doesn't + # have tokens that are precedent. + seq_group.prompt_logprobs = [None] + seq_group.prompt_logprobs.extend(prompt_logprobs) + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput) -> None: # Process samples samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index d076fee8c2a3..9816e966c1e3 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -1,10 +1,11 @@ from typing import List -from vllm.sequence import SamplerOutput +from vllm.sequence import SamplerOutput, SequenceGroupOutput -def create_output_by_sequence_group(sampler_outputs: List[SamplerOutput], - num_seq_groups: int): +def create_output_by_sequence_group( + sampler_outputs: List[SamplerOutput], + num_seq_groups: int) -> List[List[SequenceGroupOutput]]: """Helper method which transforms a 2d list organized by [step][sequence group] into [sequence group][step]. 
""" diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index e556e31f9937..22620d9fc86d 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -83,30 +83,27 @@ def _apply_logits_processors( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - logits_row_idx = 0 found_logits_processors = False - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group + logits_processed = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params logits_processors = sampling_params.logits_processors - # handle prompt_logprobs by skipping rows in logits added for - # the prompt tokens (prompt logprobs are not processed) - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - assert len(seq_ids) == 1 - logits_row_idx += sampling_metadata.prompt_lens[i] - 1 if logits_processors: found_logits_processors = True - for seq_id in seq_ids: + for seq_id, logits_row_idx in zip(seq_ids, + seq_group.sample_indices): logits_row = logits[logits_row_idx] - token_ids = sampling_metadata.seq_data[seq_id].output_token_ids + token_ids = seq_group.seq_data[seq_id].output_token_ids for logits_processor in logits_processors: logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row - logits_row_idx += 1 - else: - logits_row_idx += len(seq_ids) + + logits_processed += len(seq_group.sample_indices) + len( + seq_group.prompt_logprob_indices) + if found_logits_processors: # verifies that no rows in logits were missed unexpectedly - assert logits_row_idx == logits.shape[0] + assert logits_processed == logits.shape[0] return logits diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c4b11cb33a67..2ffa8227cc4e 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -7,11 +7,11 @@ import torch.nn as nn from vllm.model_executor.layers.ops.sample import sample as sample_triton from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingTensors) -from vllm.sampling_params import SamplingParams, SamplingType + SamplingTensors, + SequenceGroupToSample) +from vllm.sampling_params import SamplingType from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, - SamplerOutput, SequenceData, SequenceGroupOutput, - SequenceOutput) + SamplerOutput, SequenceGroupOutput, SequenceOutput) class Sampler(nn.Module): @@ -48,11 +48,14 @@ class Sampler(nn.Module): logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: + """ + Args: + logits: (num_tokens, vocab_size). + sampling_metadata: Metadata for sampling. + """ assert logits is not None _, vocab_size = logits.shape - # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens - # have not been generated yet logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Prepare sampling tensors with pinned memory to avoid blocking. @@ -83,7 +86,6 @@ class Sampler(nn.Module): # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) # Compute the log probabilities. - # Use log_softmax to ensure numerical stability. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. 
@@ -149,24 +151,28 @@ def _apply_min_tokens_penalty( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: + """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens + have not been generated yet + """ # list of indices in logits that will be set to -inf logits_to_penalize = [] - start_idx = 0 - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group + logits_applied = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params - # handle prompt_logprobs by skipping rows in logits added for the prompt - # tokens (prompt logprobs are not penalized) - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - assert len(seq_ids) == 1 - start_idx += sampling_metadata.prompt_lens[i] - 1 + sample_indices = seq_group.sample_indices + logits_applied += len(sample_indices) + len( + seq_group.prompt_logprob_indices) + if not seq_group.do_sample: + continue + start_idx = sample_indices[0] min_tokens = sampling_params.min_tokens if min_tokens > 0: seqs_to_penalize = [] for i, seq_id in enumerate(seq_ids): - seq_data = sampling_metadata.seq_data[seq_id] + seq_data = seq_group.seq_data[seq_id] if len(seq_data.output_token_ids) < min_tokens: seqs_to_penalize.append(i) @@ -180,15 +186,13 @@ def _apply_min_tokens_penalty( logits_to_penalize.extend( itertools.product(seqs_to_penalize, token_ids_to_penalize)) - start_idx += len(seq_ids) - if logits_to_penalize: # use zip and * to group indices along each dimension # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) logits[tuple(zip(*logits_to_penalize))] = -float("inf") # verifies that no rows in logits were missed unexpectedly - assert start_idx == logits.shape[0] + assert logits_applied == logits.shape[0] return logits @@ -265,14 +269,30 @@ def _apply_min_p( def _greedy_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], + selected_seq_groups: List[SequenceGroupToSample], samples: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: + """Run greedy sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + samples: (num_selected_samples,) A tensor of samples. The length of + samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ samples = samples.tolist() sample_idx = 0 results = [] for seq_group in selected_seq_groups: - seq_ids, _ = seq_group + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids num_parent_seqs = len(seq_ids) assert num_parent_seqs == 1, ( "Greedy sampling should have only one seq.") @@ -284,16 +304,33 @@ def _greedy_sample( def _random_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], + selected_seq_groups: List[SequenceGroupToSample], random_samples: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: + """Run random sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + random_samples: (num_selected_samples,) A tensor of samples. The + length of samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). 
The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ # Find the maximum best_of value of the prompt phase requests. random_samples = random_samples.cpu() sample_idx = 0 results = [] - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - seq_ids, sampling_params = seq_group + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt num_parent_seqs = len(seq_ids) if is_prompt: # Prompt phase. @@ -311,11 +348,20 @@ def _random_sample( def _beam_search_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - seq_data: Dict[int, SequenceData], + selected_seq_groups: List[SequenceGroupToSample], logprobs: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: + """Run beam sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + logprobs: (num_selected_samples, vocab_size,) A tensor of logprob + on selected sample indices. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ # We sample 2 * beam_width candidates to make sure that with high # probability we can get `beam_width` candidates in addition to # the finished sequences for the next iteration. See @@ -327,8 +373,13 @@ def _beam_search_sample( # other sampling methods. sample_idx = 0 results = [] - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - seq_ids, sampling_params = seq_group + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + is_prompt = seq_group.is_prompt + seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params num_parent_seqs = len(seq_ids) beam_width = sampling_params.best_of seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] @@ -343,7 +394,8 @@ def _beam_search_sample( else: # Generation phase. 
cumulative_logprobs = [ - seq_data[seq_id].cumulative_logprob for seq_id in seq_ids + seq_group.seq_data[seq_id].cumulative_logprob + for seq_id in seq_ids ] cumulative_logprobs = torch.tensor( cumulative_logprobs, @@ -371,8 +423,7 @@ def _beam_search_sample( def _multinomial( probs: torch.Tensor, num_samples: int, - seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None, - generators: Optional[List[torch.Generator]] = None, + seq_groups: Optional[List[SequenceGroupToSample]] = None, ) -> torch.Tensor: if num_samples > 1: # This is equivalent to torch.repeat_interleaved (which also @@ -388,9 +439,11 @@ def _multinomial( q.exponential_() else: sample_idx = 0 - for (seq_ids, _), generator in zip(seq_groups, generators): + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids next_sample_idx = sample_idx + len(seq_ids) * num_samples - q[sample_idx:next_sample_idx].exponential_(generator=generator) + q[sample_idx:next_sample_idx].exponential_( + generator=seq_group.generator) sample_idx = next_sample_idx return probs.div_(q).argmax(dim=1).view(-1, num_samples) @@ -405,7 +458,7 @@ def _sample_with_torch( categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): - _, sampling_params = seq_group + sampling_params = seq_group.sampling_params sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) @@ -429,13 +482,11 @@ def _sample_with_torch( num_tokens = len(sample_indices) if num_tokens == 0: continue - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] - sample_metadata[sampling_type] = (seq_group_ids, seq_groups, - is_prompts, sample_indices) - long_sample_indices = sample_indices.long() + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups) + long_sample_indices = sample_indices.long() if sampling_type == SamplingType.GREEDY: greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) @@ -455,14 +506,13 @@ def _sample_with_torch( elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): max_best_of_in_batch = 1 - for seq_group, is_prompt in zip(seq_groups, is_prompts): - if is_prompt: - _, sampling_params = seq_group + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) seeded_args = {} if sampling_type == SamplingType.RANDOM else { "seq_groups": seq_groups, - "generators": sampling_metadata.generators, } multinomial_samples[sampling_type] = _multinomial( @@ -481,25 +531,22 @@ def _sample_with_torch( # GPU<->CPU sync happens in the loop below. # This also converts the sample output to Python objects. 
- for sampling_type in SamplingType: if sampling_type not in sample_metadata: continue - seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ - sampling_type] + (seq_group_id, seq_groups) = sample_metadata[sampling_type] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample(seq_groups, greedy_samples) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, is_prompts, + sample_results = _random_sample(seq_groups, multinomial_samples[sampling_type]) elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, is_prompts, - sampling_metadata.seq_data, + sample_results = _beam_search_sample(seq_groups, beam_search_logprobs) - sample_results_dict.update(zip(seq_group_ids, sample_results)) + sample_results_dict.update(zip(seq_group_id, sample_results)) sample_results = [ - sample_results_dict[i] + sample_results_dict.get(i, ([], [])) for i in range(len(sampling_metadata.seq_groups)) ] return sample_results, sampled_token_ids_tensor @@ -514,7 +561,7 @@ def _sample_with_triton_kernel( categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): - _, sampling_params = seq_group + sampling_params = seq_group.sampling_params sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) @@ -530,17 +577,16 @@ def _sample_with_triton_kernel( num_tokens = len(sample_indices) if num_tokens == 0: continue - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] - sample_metadata[sampling_type] = (seq_group_ids, seq_groups, - is_prompts, sample_indices, + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups, + sample_indices, sampled_token_indices) if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, SamplingType.RANDOM_SEED): - for seq_group, is_prompt in zip(seq_groups, is_prompts): - if is_prompt: - _, sampling_params = seq_group + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) elif sampling_type == SamplingType.BEAM: @@ -564,22 +610,21 @@ def _sample_with_triton_kernel( for sampling_type in SamplingType: if sampling_type not in sample_metadata: continue - (seq_group_ids, seq_groups, is_prompts, sample_indices, + (seq_group_id, seq_groups, sample_indices, sampled_token_indices) = sample_metadata[sampling_type] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample( seq_groups, sampled_tokens[sampled_token_indices][:, 0]) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): sample_results = _random_sample( - seq_groups, is_prompts, sampled_tokens[sampled_token_indices]) + seq_groups, sampled_tokens[sampled_token_indices]) elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, is_prompts, - sampling_metadata.seq_data, + sample_results = _beam_search_sample(seq_groups, beam_search_logprobs) - sample_results_dict.update(zip(seq_group_ids, sample_results)) + sample_results_dict.update(zip(seq_group_id, sample_results)) 
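# Illustrative sketch (not part of the patch): the sampling trick used by
# _multinomial above. Dividing the probabilities by i.i.d. Exponential(1)
# noise and taking argmax draws index i with probability proportional to
# probs[i]; the per-group torch.Generator (seq_group.generator in the patch)
# is what makes seeded requests reproducible. Values below are made up.
import torch

probs = torch.tensor([[0.1, 0.2, 0.7],
                      [0.5, 0.3, 0.2]])
gen = torch.Generator().manual_seed(0)

q = torch.empty_like(probs)
q.exponential_(generator=gen)         # q ~ Exp(1), same shape as probs
samples = probs.div(q).argmax(dim=1)  # one categorical draw per row
print(samples)                        # a tensor with one sampled index per row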
sample_results = [ - sample_results_dict[i] + sample_results_dict.get(i, ([], [])) for i in range(len(sampling_metadata.seq_groups)) ] return sample_results @@ -590,6 +635,18 @@ def _sample( sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: + """ + Args: + probs: (num_query_tokens_in_batch, num_vocab) + logprobs: (num_query_tokens_in_batch, num_vocab) + sampling_metadata: The metadata for a batch for sampling. + sampling_tensors: Tensors that include sampling related metadata. + + Returns: + (next_token_ids, parent_seq_ids) for each seq group in a batch. + If sampling is skipped, it returns ([], []) + sampled_token_ids_tensor: A tensor of sampled token ids. + """ return _sample_with_torch( probs, logprobs, @@ -626,56 +683,97 @@ def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sample_results: List[Tuple[List[int], List[int]]], -) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[ - int, float]]]]: - # Prepare query indices - batched_logprobs_query_seq_indices: List[int] = [] - batched_logprobs_query_token_indices: List[int] = [] - # at least get one logprob for each token +) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: + """Return sample lobprobs and prompt logprobs. + + The logic consists of 3 parts. + - Select indices to compute logprob from, ranks of token ids, and + the top k token ids from logprobs. + - Compute prompt logprobs if required. + - Compute sample logprobs if required. + + Args: + logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's + logprob per vocab. Sequence groups' query tokens are batched in a + single flattened tensor. For example, assuming there are N + seq groups, it is sorted by prefill tokens for seq_group_1 (if + prompt logprob is enabled), decode tokens for seq_group_1 (if + sampling is required), prefill tokens for seq_group_2, ... + sampling_metadata: The sampling metadata. + sample_results: (num_seq_groups) The tuple of (next_token_ids, + parent_ids) for each sequence group. When beam search is enabled, + sample_results can contain different number of seq_ids from + sampling_metadata.seq_groups. It is because beam search creates + 2 * BEAM_WIDTH number of samples (whereas there are only up to + BEAM_WIDTH number of seq_ids). + + Returns: + A tuple of prompt and sample logprobs per sequence group in a batch. + """ + # The index of query token to calculate logprobs. It includes both + # prompt and sample logprob indices. + query_indices: List[int] = [] + # The next token ids to get the logprob value from. + next_token_ids: List[int] = [] + # The largest requested number of logprobs. We find logprobs as many as the + # largest num logprobs in this API. largest_num_logprobs = 1 - sample_idx = 0 - for i, (seq_group, sample_result) in enumerate( - zip(sampling_metadata.seq_groups, sample_results)): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result - num_parent_seqs = len(seq_ids) - if (i < sampling_metadata.num_prompts + + # Select indices to compute logprob from, ranks of token ids, and the top + # k token ids from logprobs. + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + sample_results): + sampling_params = seq_group.sampling_params + + # Update indices and tokens for prompt logprobs. 
+ if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): largest_num_logprobs = max(largest_num_logprobs, sampling_params.prompt_logprobs) - prompt_len = sampling_metadata.prompt_lens[i] - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - batched_logprobs_query_seq_indices.extend( - sample_idx + j for j in range(prompt_len - 1)) - batched_logprobs_query_token_indices.extend( - token_id for token_id in prompt_tokens[1:]) - sample_idx += prompt_len - 1 - batched_logprobs_query_seq_indices.extend( - [sample_idx + parent_id for parent_id in parent_ids]) - batched_logprobs_query_token_indices.extend(next_token_ids) - if sampling_params.logprobs is not None: - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.logprobs) - sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + query_indices.extend(seq_group.prompt_logprob_indices) + next_token_ids.extend(next_prompt_tokens) - batched_logprobs_query_seq_indices_gpu = torch.tensor( - batched_logprobs_query_seq_indices, device=logprobs.device) - batched_logprobs_query_token_indices_gpu = torch.tensor( - batched_logprobs_query_token_indices, device=logprobs.device) + # Update indices and next tokenes for sample logprob. + if seq_group.do_sample: + token_ids, parent_seq_ids = sample_result + # NOTE: We cannot directly use sample_indices because + # sample_indices only contain parent seq_ids of a previous step. + # The current step may have different number of seq_ids, and + # we can obtain it from `sample_result[1]`. + query_idx = seq_group.sample_indices[0] + query_indices.extend( + [query_idx + parent_id for parent_id in parent_seq_ids]) + next_token_ids.extend(token_ids) - # Batched query for logprobs of selected token - batched_logprobs_query_result = logprobs[[ - batched_logprobs_query_seq_indices_gpu, - batched_logprobs_query_token_indices_gpu + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) + + assert len(next_token_ids) == len(query_indices) + + if len(query_indices) == 0: + empty_sampled_logprob = [] + empty_prompt_logprob = None + return [empty_prompt_logprob], [empty_sampled_logprob] + + query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) + next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device) + + # (num_selected_query_tokens, num_logprobs). Note that query_indices can + # contain duplicates if beam search is enabled. + selected_logprobs = logprobs[[ + query_indices_gpu, + next_token_ids_gpu, ]] + ranks = _get_ranks( + logprobs[query_indices_gpu], + next_token_ids_gpu, + ) + assert selected_logprobs.shape[0] == ranks.shape[0] - batched_ranks_query_result = _get_ranks( - logprobs[batched_logprobs_query_seq_indices_gpu], - batched_logprobs_query_token_indices_gpu) - - # Batched query for logprobs of topk tokens + # Logprobs of topk tokens for a batch of sequence groups. + # (num_query_tokens_across_batch). 
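# Illustrative sketch (not part of the patch): the gather step performed by
# _get_logprobs above. Advanced indexing with (row, token) pairs pulls one
# logprob per query token, and a token's rank is the number of vocab entries
# with a strictly higher logprob, plus one. `toy_get_ranks` is a hypothetical
# stand-in for the module's _get_ranks helper.
import torch

logprobs = torch.log_softmax(torch.randn(4, 10), dim=-1)  # (num_query_tokens, vocab)
query_indices = torch.tensor([0, 1, 2, 3])
next_token_ids = torch.tensor([7, 3, 3, 9])

selected_logprobs = logprobs[query_indices, next_token_ids]  # (4,)


def toy_get_ranks(rows: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    chosen = rows.gather(1, token_ids.unsqueeze(1))  # (n, 1)
    return (rows > chosen).sum(dim=1) + 1            # rank 1 = most likely token


ranks = toy_get_ranks(logprobs[query_indices], next_token_ids)
assert selected_logprobs.shape[0] == ranks.shape[0]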
if largest_num_logprobs > 0: top_logprobs, top_token_ids = torch.topk(logprobs, largest_num_logprobs, @@ -685,79 +783,136 @@ def _get_logprobs( else: top_logprobs, top_token_ids = None, None - batched_logprobs_query_result = batched_logprobs_query_result.cpu() - batched_ranks_query_result = batched_ranks_query_result.cpu() + selected_logprobs = selected_logprobs.cpu() + ranks = ranks.cpu() - # Gather results - result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] - result_sample_logprobs: List[SampleLogprobs] = [] - sample_idx = 0 - query_result_idx = 0 - for i, (seq_group, sample_result) in enumerate( - zip(sampling_metadata.seq_groups, sample_results)): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result + # Find prompt/sample logprobs. + prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] + sample_logprobs_per_seq_group: List[SampleLogprobs] = [] + top_logprob_idx = 0 + selected_logprobs_idx = 0 - # Prompt logprobs - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - num_logprobs = sampling_params.prompt_logprobs - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - group_prompt_logprobs: PromptLogprobs = [None] - for token_id in prompt_tokens[1:]: - prompt_logprobs_dict = { - token_id: - (batched_logprobs_query_result[query_result_idx].item(), - batched_ranks_query_result[query_result_idx].item()) - } - if num_logprobs > 0: - prompt_logprobs_dict.update( - zip( - top_token_ids[sample_idx, :num_logprobs].tolist(), - zip( - top_logprobs[ - sample_idx, :num_logprobs].tolist(), - range(1, num_logprobs + 1)))) - group_prompt_logprobs.append({ - token_id: Logprob(*logprob_rank) - for token_id, logprob_rank in prompt_logprobs_dict.items() - }) - sample_idx += 1 - query_result_idx += 1 - result_prompt_logprobs.append(group_prompt_logprobs) - else: - result_prompt_logprobs.append(None) + for seq_group, sample_result in zip(sampling_metadata.seq_groups, + sample_results): + (prompt_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_prompt_logprob_if_needed( + seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, + selected_logprobs_idx, top_logprob_idx) + prompt_logprobs_per_seq_group.append(prompt_logprobs) - # Sample logprobs - num_logprobs = sampling_params.logprobs - if num_logprobs is None: - num_logprobs = 0 - group_sample_logprobs: SampleLogprobs = [] - for next_token_id, parent_id in zip(next_token_ids, parent_ids): - sample_logprobs_dict = { - next_token_id: - (batched_logprobs_query_result[query_result_idx].item(), - batched_ranks_query_result[query_result_idx].item()) + (sampled_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_sampled_logprob_if_needed( + seq_group, sample_result, selected_logprobs, ranks, top_token_ids, + top_logprobs, selected_logprobs_idx, top_logprob_idx) + sample_logprobs_per_seq_group.append(sampled_logprobs) + + return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group + + +def _get_prompt_logprob_if_needed( + seq_group: SequenceGroupToSample, + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the prompt logprob from a sequence group if needed.""" + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + + # Find prompt logprobs + prompt_logprobs: Optional[PromptLogprobs] = None + if (is_prompt and sampling_params.prompt_logprobs is not None): + 
prompt_logprobs = [] + num_logprobs = sampling_params.prompt_logprobs + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + for token_id in next_prompt_tokens: + # Calculate the prompt logprob of the real prompt tokens. + # Use tuple here for performance (to use to_list()). + # {token_id: (logprob, rank_from_vocab)} + prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { + token_id: (selected_logprobs[selected_logprobs_idx].item(), + ranks[selected_logprobs_idx].item()) } - query_result_idx += 1 - if num_logprobs >= 0: - sample_logprobs_dict.update( + + # Add top K prompt logprobs along with its rank. + if num_logprobs > 0: + prompt_logprobs_dict.update( zip( - top_token_ids[sample_idx + + top_token_ids[top_logprob_idx, :num_logprobs].tolist(), + zip( + top_logprobs[ + top_logprob_idx, :num_logprobs].tolist(), + # This is ranks. Since top_logprob is sorted, + # we can just use a range here. + range(1, num_logprobs + 1)))) + prompt_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in prompt_logprobs_dict.items() + }) + # + 1 to go to the next prompt token. + top_logprob_idx += 1 + selected_logprobs_idx += 1 + return prompt_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _get_sampled_logprob_if_needed( + seq_group: SequenceGroupToSample, + sample_result: Tuple[List[int], List[int]], + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the sample logprob if needed.""" + seq_ids = seq_group.seq_ids + num_logprobs = seq_group.sampling_params.logprobs + if num_logprobs is None: + num_logprobs = 0 + sampled_logprobs: SampleLogprobs = [] + next_token_ids, parent_seq_ids = sample_result + + if seq_group.do_sample: + assert len(next_token_ids) > 0 + for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids): + # Calculate the sample logprob of the real sampled tokens. + # Use tuple here for performance (to use to_list()). + # token_id: (logprob, rank_from_vocab) + sampled_logprobs_dict: Dict[int, Tuple[float, int]] = { + next_token_id: + (selected_logprobs[selected_logprobs_idx].item(), + ranks[selected_logprobs_idx].item()) + } + # +1 to go to the next sampled token. Note that + # selected_logprobs can contain duplicates unlike top_logprobs + # when beam search is enabled. + selected_logprobs_idx += 1 + + # Second, add top K logprobs along with its rank. + if num_logprobs >= 0: + sampled_logprobs_dict.update( + zip( + top_token_ids[top_logprob_idx + parent_id, :num_logprobs].tolist(), zip( - top_logprobs[sample_idx + + top_logprobs[top_logprob_idx + parent_id, :num_logprobs].tolist(), + # This is rank. Since top_logprob is sorted, we + # can just use a range here. range(1, num_logprobs + 1)))) - group_sample_logprobs.append({ - token_id: Logprob(*logprob_rank) - for token_id, logprob_rank in sample_logprobs_dict.items() + sampled_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in + sampled_logprobs_dict.items() }) - result_sample_logprobs.append(group_sample_logprobs) - sample_idx += len(seq_ids) - - return result_prompt_logprobs, result_sample_logprobs + # There are len(seq_ids) number of sampled tokens for the current + # sequence group in top_logprobs. Jump to the next seq_group. 
+ top_logprob_idx += len(seq_ids) + return sampled_logprobs, top_logprob_idx, selected_logprobs_idx def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, @@ -832,7 +987,7 @@ def _build_sampler_output( group_sample_logprobs) in zip(sampling_metadata.seq_groups, sample_results, prompt_logprobs, sample_logprobs): - seq_ids, _ = seq_group + seq_ids = seq_group.seq_ids next_token_ids, parent_ids = sample_result seq_outputs = [] for parent_id, next_token_id, logprobs in zip(parent_ids, @@ -854,3 +1009,36 @@ def _build_sampler_output( sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, ) + + +def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]: + """Get a list of next prompt tokens to compute logprob from a + given sequence group. + + It is used to compute prompt logprob. Imagine you have logprob for each + query token. Query token needs to know the next prompt token id to compute + prompt logprob. This is a helper to obtain next prompt token ids. + + This API has to be used only when the caller knows seq_group is in prefill + stage. + + Returns: + A list of next prompt tokens to compute logprob. + """ + assert seq_group.is_prompt, ( + "Caller should ensure the sequence group is in a prefill stage.") + seq_ids = seq_group.seq_ids + subquery_len = seq_group.subquery_len + assert subquery_len is not None + # prompt has only 1 seq id. + assert len(seq_ids) == 1 + seq_data = seq_group.seq_data[seq_ids[0]] + computed_len = seq_data.get_num_computed_tokens() + prompt_tokens = seq_data.prompt_token_ids + # +1 because we are looking for a next prompt token. + next_token_index_start = computed_len + 1 + next_token_index_end = min(computed_len + subquery_len + 1, + len(prompt_tokens)) + next_prompt_tokens = prompt_tokens[ + next_token_index_start:next_token_index_end] + return next_prompt_tokens diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 31032c4cead2..12156b2ba1aa 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -6,57 +6,275 @@ import torch from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SequenceData -from vllm.utils import is_pin_memory_available +from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.utils import (async_tensor_h2d, is_pin_memory_available, + maybe_expand_dim) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 +@dataclass +class SequenceGroupToSample: + # Sequence ids for the sequence group in a previous step. + seq_ids: List[int] + sampling_params: SamplingParams + # seq_id -> sequence data. + seq_data: Dict[int, SequenceData] + # The length of the prompt of the sequence group. None if it is in a decode + # stage. + prompt_len: Optional[int] + # The length of the query tokens to compute in the current step. None if it + # is in a decode stage. The length of subquery_len <= prompt_len. + subquery_len: Optional[int] + # A random number generator for sampling. + generator: Optional[torch.Generator] + # True if the sequence group is in prefill stage. False if it is in a + # decode stage. + is_prompt: bool + # Query token indices from logits. to compute prompt logprob. Empty if + # prompt logprob is not required. + prompt_logprob_indices: List[int] + # Sample token indices from logits. Empty if sampling is not required. 
+ sample_indices: List[int] + + @property + def do_sample(self): + return len(self.sample_indices) > 0 + + def __post_init__(self): + if len(self.prompt_logprob_indices) > 0: + assert self.sampling_params.prompt_logprobs is not None + if self.is_prompt: + assert self.prompt_len is not None + assert self.subquery_len is not None + + class SamplingMetadata: """Metadata for input sequences. Used in sampler. + The usage is as follow; + ``` + hidden_states = execute_model(...) + logits = hidden_states[sampling_metadata.selected_token_indices] + sample(logits) + + def sample(logits): + # Use categorized_sample_indices for sampling.... + ``` + Args: - seq_groups: List of (seq_ids, sampling_params). - seq_data: Seq_id -> SequenceData. - prompt_lens: Lengths of prompts. - selected_token_indices: Token indices selected for sampling. + seq_groups: List of batched sequence groups. + selected_token_indices: (num_query_tokens_to_logprob). Indices to find + logits from the initial model output hidden states. categorized_sample_indices: SamplingType -> token indices to sample. - generators: List of torch.Generators to use for seeded sampling - perform_sampling: Whether to perform sampling. This option is used to - make the sampling only happens in the driver worker, and disable - sampling in other worker processes. + Each token indices is 2D tensor of (num_indices, num_indices) where + the first item means the sample index within the returned logit + (before pruning padding), and the second item means the sample + index after pruning using selected_token_indices. + For example, if the returned logit is [1, 2, 3], and we select + [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, + The first tuple is [1, 2] (sampled index within original logit), + and the second tuple is [0, 1] (sampled index within pruned logit). + num_prompts: Number of prompt sequence groups in seq_groups. 
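# Illustrative sketch (not part of the patch): how the two index levels in this
# docstring relate, with made-up shapes. Group A is a 3-token prefill that
# requests prompt logprobs, group B is a 3-token prefill that does not, and
# group C is a single decode token.
import torch

hidden_states = torch.randn(7, 16)        # model output: 7 scheduled tokens
# A keeps rows 0-1 (prompt logprobs) and 2 (its sample); B keeps only its
# sample row 5; C keeps its decode row 6. Rows 3 and 4 are pruned away.
selected_token_indices = torch.tensor([0, 1, 2, 5, 6])
logits_input = hidden_states[selected_token_indices]  # (5, 16) -> LM head

# These indices address rows of the pruned tensor, as in SequenceGroupToSample.
group_a = dict(prompt_logprob_indices=[0, 1], sample_indices=[2])
group_b = dict(prompt_logprob_indices=[], sample_indices=[3])
group_c = dict(prompt_logprob_indices=[], sample_indices=[4])
assert sum(len(g["prompt_logprob_indices"]) + len(g["sample_indices"])
           for g in (group_a, group_b, group_c)) == logits_input.shape[0]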
""" def __init__( self, - seq_groups: Optional[List[Tuple[List[int], SamplingParams]]], - seq_data: Optional[Dict[int, SequenceData]], - prompt_lens: Optional[List[int]], + seq_groups: List[SequenceGroupToSample], selected_token_indices: torch.Tensor, - categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]], - generators: Optional[List[torch.Generator]] = None, - perform_sampling: bool = True, + categorized_sample_indices: Dict[SamplingType, torch.Tensor], + num_prompts: int, ) -> None: self.seq_groups = seq_groups - self.seq_data = seq_data - self.prompt_lens = prompt_lens self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices - self.generators = generators - self.perform_sampling = perform_sampling + self.num_prompts = num_prompts - self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0 + @staticmethod + def prepare( + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + subquery_lens: Optional[List[int]], + device: str, + pin_memory: bool, + ) -> "SamplingMetadata": + ( + seq_groups, + selected_token_indices, + categorized_sample_indices, + num_prompts, + ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens, + subquery_lens, device) + selected_token_indices = async_tensor_h2d(selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + categorized_sample_indices = { + t: maybe_expand_dim( + async_tensor_h2d(seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory), 2, 2) + for t, seq_ids in categorized_sample_indices.items() + } + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + num_prompts=num_prompts, + ) + return sampling_metadata def __repr__(self) -> str: return ( "SamplingMetadata(" f"seq_groups={self.seq_groups}, " - f"seq_data={self.seq_data}, " - f"prompt_lens={self.prompt_lens}, " f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices}), " - f"perform_sampling={self.perform_sampling})") + f"categorized_sample_indices={self.categorized_sample_indices}), ") + + +def _prepare_seq_groups( + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + subquery_lens: Optional[List[int]], + device: str, +) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ + SamplingType, List[Tuple[int, int]]], int]: + """Prepare sequence groups and indices for sampling. + + Args: + seq_group_metadata_list: A list of sequence group to batch. + prompt_lens: A list of prompt lens per sequence group. + Index of prompt len should match with seq_group_metadata_list. + subquery_lens: A list of query lengths. Prompt lens include the length + of entire prompt tokens, and it could be shorter. + device: A device to use for random number generator, + `SequenceGroupToSample.generator`. + + Returns: + seq_groups: A list of sequence group to sample. + selected_token_indices: See the definition from `SamplingMetadata`. + categorized_sample_indices: See the definition from `SamplingMetadata`. + num_prompts: Total number of prompts from `seq_group_metadata_list`. + """ + # Batched sequence groups for the current model forward stsep. + seq_groups: List[SequenceGroupToSample] = [] + # A list of token indices to sample/compute logprob. It is used to + # prune the outcome logits from the model for the performance. 
+    selected_token_indices: List[int] = []
+    # Running index into the model output; used to build
+    # selected_token_indices.
+    model_output_idx = 0
+
+    # Sampling type -> list of pairs:
+    #   (index to sample/compute prompt logprobs within the pruned logits,
+    #    index to sample within the sample tensor).
+    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
+        t: []
+        for t in SamplingType
+    }
+    # Index into the pruned logits. It covers both prompt-logprob rows and
+    # sample rows.
+    logit_idx = 0
+    # Index to sample from a sample tensor. It is used by the triton sample
+    # kernel. See `_sample_with_triton_kernel` for more details.
+    sample_idx = 0
+    # Total number of prompts from the given sequence groups.
+    num_prompts = 0
+
+    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+        seq_ids = list(seq_group_metadata.seq_data.keys())
+        sampling_params = seq_group_metadata.sampling_params
+        is_prompt = seq_group_metadata.is_prompt
+        generator: Optional[torch.Generator] = None
+        # If the current seq group is in the decode stage, these are None.
+        prompt_len: Optional[int] = None
+        subquery_len: Optional[int] = None
+        prompt_logprob_indices: List[int] = []
+        sample_indices: List[int] = []
+        do_sample = seq_group_metadata.do_sample
+
+        if seq_group_metadata.is_prompt:
+            if sampling_params.seed is not None:
+                seq_group_metadata.state.generator = torch.Generator(
+                    device=device).manual_seed(sampling_params.seed)
+
+            num_prompts += 1
+            num_prefill_sample = len(seq_ids)
+            assert num_prefill_sample == 1
+            assert subquery_lens is not None and prompt_lens is not None
+            subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
+            # If we need sampling, exclude num_prefill_sample tokens from
+            # the prompt logprob positions.
+            prompt_logprob_len = (subquery_len - num_prefill_sample
+                                  if do_sample else subquery_len)
+            sample_len = num_prefill_sample if do_sample else 0
+        else:
+            # Decode
+            prompt_logprob_len = 0
+            sample_len = len(seq_ids) if do_sample else 0
+
+        # Update indices to select from the model output.
+        """
+        This block computes selected_token_indices which is used in the
+        following way.
+
+        hidden_states = model(...)
+        logits = hidden_states[selected_token_indices]
+        """
+
+        if sampling_params.prompt_logprobs is not None:
+            selected_token_indices.extend(
+                range(model_output_idx, model_output_idx + prompt_logprob_len))
+        model_output_idx += prompt_logprob_len
+        if do_sample:
+            selected_token_indices.extend(
+                range(model_output_idx, model_output_idx + sample_len))
+        model_output_idx += sample_len
+
+        # We now find indices for logprob computation and sampling.
+        """
+        This block computes categorized_sample_indices which is used in the
+        following way.
+
+        hidden_states = model(...)
+        logits = hidden_states[selected_token_indices]
+        def sample(logits):
+           # Use categorized_sample_indices for sampling.
+           # prompt_logprob_indices to find prompt logprob indices.
+           # sample_indices to find sample indices.
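+           # E.g., for a prefill group at the start of the batch with 3
+           # prompt-logprob rows and 1 sample row:
+           #   prompt_logprob_indices == [0, 1, 2], sample_indices == [3].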
+ """ + + if sampling_params.prompt_logprobs is not None: + prompt_logprob_indices.extend( + range(logit_idx, logit_idx + prompt_logprob_len)) + logit_idx += prompt_logprob_len + if do_sample: + sample_indices.extend(range(logit_idx, logit_idx + sample_len)) + categorized_sample_indices[sampling_params.sampling_type].extend( + list( + zip(range(logit_idx, logit_idx + sample_len), + range(sample_idx, sample_idx + sample_len)))) + logit_idx += sample_len + sample_idx += sample_len + + if sampling_params.seed is not None: + generator = seq_group_metadata.state.generator + + seq_groups.append( + SequenceGroupToSample( + seq_ids=seq_ids, + sampling_params=sampling_params, + seq_data=seq_group_metadata.seq_data, + prompt_len=prompt_len, + subquery_len=subquery_len, + generator=generator, + is_prompt=is_prompt, + prompt_logprob_indices=list(prompt_logprob_indices), + sample_indices=list(sample_indices))) + return (seq_groups, selected_token_indices, categorized_sample_indices, + num_prompts) @dataclass @@ -112,11 +330,10 @@ class SamplingTensors: seeds_to_generate = (extra_seeds_to_generate + get_num_triton_sampler_splits(vocab_size)) - sample_indices_start_idx = 0 assert sampling_metadata.seq_groups is not None - assert sampling_metadata.seq_data is not None - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params temperature = sampling_params.temperature p = sampling_params.presence_penalty f = sampling_params.frequency_penalty @@ -145,45 +362,46 @@ class SamplingTensors: or abs(r - 1.0) >= _SAMPLING_EPS): do_penalties = True - if (i < sampling_metadata.num_prompts + is_prompt = seq_group.is_prompt + if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs - assert sampling_metadata.prompt_lens is not None - prompt_len = sampling_metadata.prompt_lens[i] - temperatures += [temperature] * (prompt_len - 1) - top_ps += [top_p] * (prompt_len - 1) - top_ks += [top_k] * (prompt_len - 1) - min_ps += [min_p] * (prompt_len - 1) - presence_penalties += [0] * (prompt_len - 1) - frequency_penalties += [0] * (prompt_len - 1) - repetition_penalties += [1] * (prompt_len - 1) - prompt_tokens.extend([] for _ in range(prompt_len - 1)) - output_tokens.extend([] for _ in range(prompt_len - 1)) - for seq_id in seq_ids: - seq_data = sampling_metadata.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids) - output_tokens.append(seq_data.output_token_ids) - temperatures += [temperature] * len(seq_ids) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) + subquery_len = seq_group.subquery_len + assert subquery_len is not None + prefill_len = len(seq_group.prompt_logprob_indices) + temperatures += [temperature] * prefill_len + top_ps += [top_p] * prefill_len + top_ks += [top_k] * prefill_len + min_ps += [min_p] * prefill_len + presence_penalties += [0] * prefill_len + frequency_penalties += [0] * prefill_len + repetition_penalties += [1] * prefill_len + prompt_tokens.extend([] for _ in range(prefill_len)) + output_tokens.extend([] for _ in range(prefill_len)) + + if seq_group.do_sample: + sample_lens = len(seq_group.sample_indices) + assert sample_lens == len(seq_ids) + for 
seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) + temperatures += [temperature] * len(seq_ids) + top_ps += [top_p] * len(seq_ids) + top_ks += [top_k] * len(seq_ids) + min_ps += [min_p] * len(seq_ids) + presence_penalties += [p] * len(seq_ids) + frequency_penalties += [f] * len(seq_ids) + repetition_penalties += [r] * len(seq_ids) - is_prompt = i < sampling_metadata.num_prompts if is_prompt: prompt_best_of.append(sampling_params.best_of) - assert sampling_metadata.prompt_lens is not None - prompt_len = sampling_metadata.prompt_lens[i] + subquery_len = seq_group.subquery_len + assert subquery_len is not None - if sampling_params.prompt_logprobs is not None: - # NOTE: the sampling position is the last token - # in the prompt - sample_indices_start_idx += prompt_len - 1 for seq_id in seq_ids: - seq_data = sampling_metadata.seq_data[seq_id] + seq_data = seq_group.seq_data[seq_id] extra_entropy = extra_entropy or () seq_seeds = cls._get_sequence_seeds( seed, @@ -193,8 +411,7 @@ class SamplingTensors: seeds_to_generate=seeds_to_generate, is_greedy=is_greedy) sampling_seeds.append(seq_seeds) - sample_indices.append(sample_indices_start_idx) - sample_indices_start_idx += 1 + sample_indices.extend(seq_group.sample_indices) sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, @@ -217,12 +434,14 @@ class SamplingTensors: # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() - prompt_max_len = max(len(tokens) for tokens in prompt_tokens) + prompt_max_len = max([len(tokens) for tokens in prompt_tokens], + default=0) prompt_padded_tokens = [ tokens + [vocab_size] * (prompt_max_len - len(tokens)) for tokens in prompt_tokens ] - output_max_len = max(len(tokens) for tokens in output_tokens) + output_max_len = max([len(tokens) for tokens in output_tokens], + default=0) output_padded_tokens = [ tokens + [vocab_size] * (output_max_len - len(tokens)) for tokens in output_tokens diff --git a/vllm/sequence.py b/vllm/sequence.py index b296b37a84f1..567fca570951 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -28,7 +28,10 @@ class Logprob: decoded_token: Optional[str] = None +# {token_id -> logprob} per each sequence group. None if the corresponding +# sequence group doesn't require prompt logprob. PromptLogprobs = List[Optional[Dict[int, Logprob]]] +# {token_id -> logprob} for each sequence group. SampleLogprobs = List[Dict[int, Logprob]] @@ -215,7 +218,7 @@ class Sequence: self.eos_token_id = eos_token_id self.lora_request = lora_request - self.data = SequenceData(prompt_token_ids) + self.data: SequenceData = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -559,6 +562,9 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + do_sample: True if sampling is required. Sampling is not required when + e.g., prefill is chunked, and the current iteration only computes + query tokens for prefill, we don't need sampling. token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. state: Internal state tied to this sequence group. 
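
The new `do_sample` flag is what lets chunked prefill skip sampling for every prompt chunk except the last one. Below is a minimal sketch of that bookkeeping under assumed names (`chunk_plan` is hypothetical; the real decision is made by the scheduler when it builds `SequenceGroupMetadata`), mirroring the `prompt_logprob_len` / `sample_len` logic in `_prepare_seq_groups`:

```python
# Minimal sketch, not vLLM's scheduler code: decide per prefill chunk whether
# the chunk reaches the end of the prompt (and therefore needs to sample) and
# how many of its query tokens are prompt-logprob-only positions.
def chunk_plan(num_computed_tokens: int, chunk_size: int, prompt_len: int):
    subquery_len = min(chunk_size, prompt_len - num_computed_tokens)
    do_sample = num_computed_tokens + subquery_len >= prompt_len
    prompt_logprob_len = subquery_len - 1 if do_sample else subquery_len
    sample_len = 1 if do_sample else 0
    return do_sample, prompt_logprob_len, sample_len


# An 8-token prompt processed in chunks of 4 query tokens:
print(chunk_plan(0, 4, 8))  # (False, 4, 0) -> SequenceGroupMetadata(do_sample=False, ...)
print(chunk_plan(4, 4, 8))  # (True, 3, 1)  -> SequenceGroupMetadata(do_sample=True, ...)
```

With `do_sample=False`, `_prepare_seq_groups` still records the chunk's prompt-logprob positions but adds nothing to `categorized_sample_indices`, so no token is sampled for that group in this step.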
@@ -573,6 +579,7 @@ class SequenceGroupMetadata: seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + do_sample: bool = True, token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, @@ -589,6 +596,7 @@ class SequenceGroupMetadata: self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state self._token_chunk_size = token_chunk_size + self.do_sample = do_sample if self._token_chunk_size is None: if is_prompt: @@ -650,6 +658,7 @@ class SequenceGroupOutput: prompt_logprobs: Optional[PromptLogprobs], ) -> None: self.samples = samples + # Prompt logprob for each prompt query token. self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index bf0a6c84e6f0..34d7d3dffea1 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -10,9 +10,8 @@ from vllm.distributed import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad, maybe_expand_dim +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad logger = init_logger(__name__) @@ -38,6 +37,8 @@ class CPUModelRunner: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + # Currently, CPU worker doesn't support chunked prefill. 
+ assert self.scheduler_config.chunked_prefill_enabled is False self.lora_config = lora_config self.vision_language_config = vision_language_config self.load_config = load_config @@ -252,99 +253,6 @@ class CPUModelRunner: attn_metadata, ) - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> SamplingMetadata: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - selected_token_indices: List[int] = [] - generators: List[torch.Generator] = [] - selected_token_start_idx = 0 - categorized_sample_indices: Dict[SamplingType, - List[Tuple[int, int]]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices_start_idx = 0 - categorized_sampled_token_indices_start_idx = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - if seq_group_metadata.is_prompt: - assert len(seq_ids) == 1 - subquery_len = prompt_lens[i] - if sampling_params.prompt_logprobs is not None: - # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += subquery_len - 1 - - categorized_sample_indices[ - sampling_params.sampling_type].append( - (categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx)) - categorized_sample_indices_start_idx += 1 - categorized_sampled_token_indices_start_idx += 1 - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + subquery_len - 1)) - selected_token_indices.append(selected_token_start_idx + - subquery_len - 1) - selected_token_start_idx += subquery_len - - if sampling_params.seed is not None: - seq_group_metadata.state.generator = torch.Generator( - device=self.device).manual_seed(sampling_params.seed) - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs - - categorized_sample_indices[ - sampling_params.sampling_type].extend( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx + - num_seqs))) - categorized_sample_indices_start_idx += num_seqs - categorized_sampled_token_indices_start_idx += num_seqs - - if sampling_params.seed is not None: - generators.append(seq_group_metadata.state.generator) - - selected_token_indices = torch.tensor(selected_token_indices, - dtype=torch.long) - - categorized_sample_indices = { - t: maybe_expand_dim(torch.tensor(seq_ids, dtype=torch.int), 2, 2) - for t, seq_ids in categorized_sample_indices.items() - } - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - generators=generators, - ) - return sampling_metadata - def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -364,8 +272,15 @@ class CPUModelRunner: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] - 
sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + # subquery_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use prompt_lens instead. + prompt_lens, + self.device, + pin_memory=False) # Broadcast the metadata. metadata_dict = { "input_tokens": input_tokens, @@ -389,7 +304,6 @@ class CPUModelRunner: selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, - perform_sampling=False, ) return (input_tokens, input_positions, attn_metadata, @@ -421,7 +335,7 @@ class CPUModelRunner: logits = self.model.compute_logits(hidden_states, sampling_metadata) # Only perform sampling in the driver worker. - if not sampling_metadata.perform_sampling: + if not self.is_driver_worker: return None # Sample the next token. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c6da28f11032..0704f5fec54d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -20,12 +20,11 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model -from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip, - is_pin_memory_available, make_tensor_with_pad, - maybe_expand_dim) +from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available, + make_tensor_with_pad) logger = init_logger(__name__) @@ -547,108 +546,6 @@ class ModelRunner: slot_mapping=slot_mapping, ) - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], - ) -> SamplingMetadata: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - selected_token_indices: List[int] = [] - generators: List[torch.Generator] = [] - selected_token_start_idx = 0 - categorized_sample_indices: Dict[SamplingType, - List[Tuple[int, int]]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices_start_idx = 0 - categorized_sampled_token_indices_start_idx = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - if seq_group_metadata.is_prompt: - assert len(seq_ids) == 1 - assert subquery_lens is not None - subquery_len = subquery_lens[i] - if sampling_params.prompt_logprobs is not None: - # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += subquery_len - 1 - - categorized_sample_indices[ - sampling_params.sampling_type].append( - (categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx)) - categorized_sample_indices_start_idx += 1 - categorized_sampled_token_indices_start_idx += 1 - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + subquery_len - 1)) - selected_token_indices.append(selected_token_start_idx + - subquery_len - 1) - selected_token_start_idx += subquery_len - - if sampling_params.seed 
is not None: - seq_group_metadata.state.generator = torch.Generator( - device=self.device).manual_seed(sampling_params.seed) - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs - - categorized_sample_indices[ - sampling_params.sampling_type].extend( - list( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx - + num_seqs)))) - categorized_sample_indices_start_idx += num_seqs - categorized_sampled_token_indices_start_idx += num_seqs - - if sampling_params.seed is not None: - generators.append(seq_group_metadata.state.generator) - - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=self.device, - pin_memory=self.pin_memory) - - categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=self.pin_memory), 2, 2) - for t, seq_ids in categorized_sample_indices.items() - } - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - generators=generators, - ) - return sampling_metadata - def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -685,9 +582,9 @@ class ModelRunner: decode_lora_requests, decode_slot_mapping, ) = self._prepare_decode(decode_reqs) - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens, - subquery_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, prompt_lens, subquery_lens, + self.device, self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 @@ -788,12 +685,9 @@ class ModelRunner: **metadata_dict) sampling_metadata = SamplingMetadata( seq_groups=None, - seq_data=None, - prompt_lens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, - generators=None, - perform_sampling=False, + num_prompts=0, ) # if it is a mixed batch, decode attn_metadata is broadcasted @@ -852,7 +746,7 @@ class ModelRunner: logits = self.model.compute_logits(hidden_states, sampling_metadata) # Only perform sampling in the driver worker. - if not sampling_metadata.perform_sampling: + if not self.is_driver_worker: return None # Sample the next token. 
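
Both the GPU and CPU runners now gate sampling on `self.is_driver_worker` instead of the removed `perform_sampling` flag. A rough sketch of the pattern, with hypothetical names (in the real code the prepared index tensors are broadcast to the non-driver ranks):

```python
# Sketch of the driver-only sampling gate: every rank computes logits, but
# only the driver worker samples; the other ranks return None.
from typing import Optional

import torch


def execute_step(is_driver_worker: bool,
                 logits: torch.Tensor) -> Optional[torch.Tensor]:
    # compute_logits(...) has already run on every tensor-parallel rank.
    if not is_driver_worker:
        return None
    # Stand-in for the real sampler: greedy argmax over the vocab dimension.
    return logits.argmax(dim=-1)


logits = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
print(execute_step(True, logits))   # tensor([1, 0])
print(execute_step(False, logits))  # None
```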
@@ -860,6 +754,7 @@ class ModelRunner: logits=logits, sampling_metadata=sampling_metadata, ) + return output @torch.inference_mode() diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 487df334d73e..a974e85c22f4 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn @@ -8,10 +8,8 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.utils import (async_tensor_h2d, is_pin_memory_available, - make_tensor_with_pad, maybe_expand_dim) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import is_pin_memory_available, make_tensor_with_pad logger = init_logger(__name__) @@ -141,106 +139,6 @@ class NeuronModelRunner: return input_tokens, input_positions, input_block_ids - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> SamplingMetadata: - seq_groups: List[Tuple[List[int], SamplingParams]] = [] - selected_token_indices: List[int] = [] - generators: List[torch.Generator] = [] - selected_token_start_idx = 0 - categorized_sample_indices: Dict[SamplingType, - List[Tuple[int, int]]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices_start_idx = 0 - categorized_sampled_token_indices_start_idx = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - sampling_params = seq_group_metadata.sampling_params - seq_groups.append((seq_ids, sampling_params)) - - if seq_group_metadata.is_prompt: - assert len(seq_ids) == 1 - assert prompt_lens is not None - prompt_len = prompt_lens[i] - if sampling_params.prompt_logprobs is not None: - # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += prompt_len - 1 - - categorized_sample_indices[ - sampling_params.sampling_type].append( - (categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx)) - categorized_sample_indices_start_idx += 1 - categorized_sampled_token_indices_start_idx += 1 - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + prompt_len - 1)) - selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += prompt_len - - if sampling_params.seed is not None: - seq_group_metadata.state.generator = torch.Generator( - device=self.device).manual_seed(sampling_params.seed) - else: - num_seqs = len(seq_ids) - selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) - selected_token_start_idx += num_seqs - - categorized_sample_indices[ - sampling_params.sampling_type].extend( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx + - num_seqs))) - categorized_sample_indices_start_idx += num_seqs - categorized_sampled_token_indices_start_idx += num_seqs - 
- if sampling_params.seed is not None: - generators.append(seq_group_metadata.state.generator) - - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=self.device, - pin_memory=self.pin_memory) - - categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=self.pin_memory), 2, 2) - for t, seq_ids in categorized_sample_indices.items() - } - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - generators=generators, - ) - return sampling_metadata - def prepare_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -256,8 +154,15 @@ class NeuronModelRunner: (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + prompt_lens, + # subquery_lens is not needed if chunked prefill is not + # supported. Since neuron worker doesn't support chunked prefill + # just use prompt_lens instead. + prompt_lens, + self.device, + self.pin_memory) return (input_tokens, input_positions, input_block_ids, sampling_metadata)
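
Putting the pieces together, the following self-contained toy (made-up tensors and names, not vLLM code) mirrors the flow described in the `SamplingMetadata` docstring: `selected_token_indices` gathers the hidden states that need logits, prompt-logprob rows are scored with `log_softmax`, and the `(logit index, sample index)` pairs pick out the rows that are actually sampled.

```python
# Toy end-to-end flow; all tensors and numbers are made up for illustration.
import torch

torch.manual_seed(0)

# One hidden state per input token: a 4-token prefill (prompt_logprobs
# requested, do_sample=True) followed by one decode token.
hidden_states = torch.randn(5, 16)
lm_head = torch.randn(16, 32)  # hidden_size=16, vocab_size=32

# Rows 0-2: prompt logprobs, row 3: prefill sample, row 4: decode sample.
selected_token_indices = torch.tensor([0, 1, 2, 3, 4])
logits = hidden_states[selected_token_indices] @ lm_head     # (5, 32)

# (index within `logits`, index within the sample-only tensor).
greedy_pairs = torch.tensor([[3, 0], [4, 1]], dtype=torch.int)

sample_logits = logits[greedy_pairs[:, 0].long()]            # (2, 32)
next_tokens = sample_logits.argmax(dim=-1)                   # greedy stand-in
prompt_logprobs = torch.log_softmax(logits[:3], dim=-1)      # (3, 32)
print(next_tokens.shape, prompt_logprobs.shape)
```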