diff --git a/tests/conftest.py b/tests/conftest.py index f14b1e8780ad9..dc70c98359598 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -48,10 +48,10 @@ from vllm.distributed import (cleanup_dist_env_and_memory, initialize_model_parallel) from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.sequence import Logprob from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils import set_default_torch_num_threads diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index f2e6fbfad6e80..c1305e0ae31ce 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -7,8 +7,8 @@ from typing import Optional import pytest from transformers import AutoModelForSpeechSeq2Seq +from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest -from vllm.sequence import SampleLogprobs from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner) diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 67d35213d6422..77e2b90dd5e96 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -12,10 +12,10 @@ from huggingface_hub import snapshot_download from transformers import AutoTokenizer from vllm.assets.image import ImageAsset +from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.image import convert_image_mode, rescale_image_size from vllm.platforms import current_platform -from vllm.sequence import SampleLogprobs from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, PromptImageInput, VllmRunner) diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index cb3cc1d3d330d..715b08ef90e54 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk from transformers import AutoProcessor from vllm import SamplingParams, TextPrompt, TokensPrompt +from vllm.logprobs import Logprob, SampleLogprobs from vllm.multimodal import MultiModalDataBuiltins -from vllm.sequence import Logprob, SampleLogprobs from ....utils import VLLM_PATH, large_gpu_test from ...utils import check_logprobs_close diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 8b7d051218f14..ba55450ec8a90 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature, GenerationConfig, GenerationMixin) from transformers.video_utils import VideoMetadata -from vllm.sequence import SampleLogprobs +from vllm.logprobs import SampleLogprobs from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 9451131960885..e39ca40fbbf5e 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -12,7 +12,7 @@ from transformers import AutoModelForCausalLM from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import RunnerOption -from vllm.sequence import SampleLogprobs +from vllm.logprobs import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset, diff --git a/tests/models/utils.py b/tests/models/utils.py index 76c6e4823a12c..5da2382cef814 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -12,7 +12,7 @@ from transformers import PretrainedConfig from vllm.config import ModelConfig, ModelDType, RunnerOption from vllm.inputs import InputContext -from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from .registry import HF_EXAMPLE_MODELS diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index bd2b91073d568..fe6c313d2966f 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -8,10 +8,7 @@ import pytest from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -from vllm.inputs import token_inputs -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, @@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast): assert decoded_text == '' assert out_ids == [len(tokenizer)] - - -@pytest.fixture -def detokenizer(tokenizer_name: str) -> Detokenizer: - tokenizer = get_tokenizer( - tokenizer_name, - tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", - trust_remote_code=False, - revision=None, - ) - - return Detokenizer(tokenizer) - - -@pytest.fixture(name="complete_sequence_token_ids") -def create_complete_sequence_token_ids(complete_sequence: str, - tokenizer) -> list[int]: - return tokenizer(complete_sequence, add_special_tokens=False).input_ids - - -def create_sequence(prompt_token_ids=None): - prompt_token_ids = prompt_token_ids or [] - return Sequence( - seq_id=0, - inputs=token_inputs(prompt_token_ids), - block_size=16, - ) - - -def create_dummy_logprobs( - complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]: - return [{ - token_id: Logprob(logprob=0.0), - token_id + 1: Logprob(logprob=0.1) - } for token_id in complete_sequence_token_ids] - - -def create_dummy_prompt_logprobs( - complete_sequence_token_ids: list[int] -) -> list[Optional[dict[int, Any]]]: - # logprob for the first prompt token is None. - logprobs: list[Optional[dict[int, Any]]] = [None] - logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) - return logprobs - - -@pytest.mark.parametrize("complete_sequence", TRUTH) -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) -def test_decode_sequence_logprobs(complete_sequence: str, - complete_sequence_token_ids: list[int], - detokenizer: Detokenizer, - skip_special_tokens: bool): - """Verify Detokenizer decodes logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - logprobs=2) - - # Run sequentially. - seq = create_sequence() - dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) - sequential_logprobs_text_chosen_token: list[str] = [] - sequential_logprobs_text_other_token: list[str] = [] - for new_token, logprobs in zip(complete_sequence_token_ids, - dummy_logprobs): - seq.append_token_id(new_token, logprobs) - detokenizer.decode_sequence_inplace(seq, sampling_params) - sequential_logprobs_text_chosen_token.append( - seq.output_logprobs[-1][new_token].decoded_token) - sequential_logprobs_text_other_token.append( - seq.output_logprobs[-1][new_token + 1].decoded_token) - sequential_result = seq.output_text - - assert sequential_result == "".join(sequential_logprobs_text_chosen_token) - assert sequential_result != "".join(sequential_logprobs_text_other_token) - - if not skip_special_tokens: - # Text for logprobs for the chosen token should be the same as the - # generated text. Note that this will only be true if we skip - # special tokens. - assert sequential_result == complete_sequence - - -@pytest.mark.parametrize("complete_sequence", TRUTH) -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -def test_decode_prompt_logprobs(complete_sequence: str, - complete_sequence_token_ids: list[int], - detokenizer: Detokenizer): - - # We want to use skip_special_tokens=False here but Mistral tokenizers - # don't support that. - if complete_sequence not in SPECIAL_TOKS_TRUTH: - skip_special_tokens = True - elif not isinstance(detokenizer.tokenizer, MistralTokenizer): - skip_special_tokens = False - else: - pytest.skip("MistralTokenizers don't support " - "skip_special_tokens=False") - return - """Verify Detokenizer decodes prompt logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - prompt_logprobs=1) - - # Run sequentially. - seq = create_sequence(complete_sequence_token_ids) - seq_group = SequenceGroup(request_id="1", - seqs=[seq], - sampling_params=sampling_params, - arrival_time=0.0) - dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids) - detokenizer.decode_prompt_logprobs_inplace(seq_group, - dummy_logprobs, - position_offset=0) - # First logprob is None. - decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[ - 1:] # type: ignore - - # decoded_prompt_logprobs doesn't contain the first token. - token_ids = complete_sequence_token_ids - tokenizer = detokenizer.tokenizer - text_full = tokenizer.decode(token_ids, - skip_special_tokens=skip_special_tokens) - text_first = tokenizer.decode(token_ids[0], - skip_special_tokens=skip_special_tokens) - text = text_full[len(text_first):] - - # Text for logprobs for the chosen token should be the same as the - # prompt text. Note that the first logprob is None. - assert text == "".join([ - logprobs[token_id].decoded_token - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) - ]) - assert text != "".join([ - logprobs[token_id + 1].decoded_token - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) - ]) diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py index 35153139350bf..57ace1fa22ac9 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -12,7 +12,7 @@ from partial_json_parser.core.options import Allow from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import JambaToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer MODEL = "ai21labs/Jamba-tiny-dev" diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index ccb2acf512caf..f06fb2b9f2f04 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ToolCall) from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( Qwen3CoderToolParser) -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py index c276a598aa68c..118c7534622e2 100644 --- a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer # Use a common model that is likely to be available diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index 0bc22e4f1031c..c07ca0f56d6b0 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import xLAMToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer # Use a common model that is likely to be available diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index a9632ce54eac8..bdb40be99aa3f 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -12,9 +12,9 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, STOP_STRINGS, DummyOutputProcessorTestVectors, MockEngineCore) +from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.sequence import PromptLogprobs, SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import (OutputProcessor, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index b75b94ad0acc2..fd4b992c3821b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -15,10 +15,10 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.tasks import SupportedTask from vllm.utils import make_async +from vllm.v1.outputs import SamplerOutput from vllm.worker.worker_base import WorkerBase logger = init_logger(__name__) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 78d0ee6c1e3fc..84747575b4960 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -17,12 +17,12 @@ from vllm.executor.msgspec_utils import encode_hook from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, ray) from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.v1.outputs import SamplerOutput if ray is not None: from ray.actor import ActorHandle diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index e9db2a0dc13a8..46f49aaa013da 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -7,15 +7,7 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import (DummyData, InputContext, InputProcessingContext, - InputRegistry) - -INPUT_REGISTRY = InputRegistry() -""" -The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used -by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the -target model. -""" +from .registry import InputContext, InputProcessingContext __all__ = [ "DataPrompt", @@ -36,9 +28,6 @@ __all__ = [ "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", - "INPUT_REGISTRY", - "DummyData", "InputContext", "InputProcessingContext", - "InputRegistry", ] diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index f0b392e9767ae..b5316b6d0574c 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, Union import torch from transformers import BatchFeature, PretrainedConfig, ProcessorMixin @@ -15,16 +15,9 @@ from vllm.utils.jsontree import JSONTree, json_map_leaves if TYPE_CHECKING: from vllm.config import ModelConfig - from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, - MultiModalRegistry) - from vllm.sequence import SequenceData from vllm.transformers_utils.tokenizer import AnyTokenizer else: ModelConfig = Any - MultiModalDataDict = Any - MultiModalPlaceholderDict = Any - MultiModalRegistry = Any - SequenceData = Any AnyTokenizer = Any _T = TypeVar("_T") @@ -191,61 +184,3 @@ class InputProcessingContext(InputContext): f"on data={data} with kwargs={allowed_kwargs}") raise ValueError(msg) from exc - - -class DummyData(NamedTuple): - """ - Dummy data used for profiling. - - Note: This is only used in V0. - """ - - seq_data: SequenceData - multi_modal_data: Optional[MultiModalDataDict] = None - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - - -class InputRegistry: - """ - Note: This is only used in V0. - """ - - def dummy_data_for_profiling( - self, - model_config: ModelConfig, - seq_len: int, - mm_registry: MultiModalRegistry, - is_encoder_data: bool = False, - ) -> DummyData: - """ - Create dummy data for profiling the memory usage of a model. - - The model is identified by ``model_config``. - """ - # Avoid circular import - from vllm.multimodal.cache import processor_only_cache_from_config - from vllm.sequence import SequenceData - - if not model_config.is_multimodal_model: - seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) - return DummyData(seq_data=seq_data) - - cache = processor_only_cache_from_config(model_config, mm_registry) - - # Encoder dummy data does not contain multi-modal data - if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data(model_config, - seq_len, - cache=cache) - seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) - return DummyData(seq_data=seq_data) - - dec_data = mm_registry.get_decoder_dummy_data(model_config, - seq_len, - cache=cache) - - return DummyData( - seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), - multi_modal_data=dec_data.multi_modal_data.get_data(), - multi_modal_placeholders=dec_data.multi_modal_placeholders, - ) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 55dfe8088c8f3..a59aebfac4ff9 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -3,13 +3,11 @@ from vllm.model_executor.parameter import (BasevLLMParameter, PackedvLLMParameter) -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingMetadataCache) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ "SamplingMetadata", - "SamplingMetadataCache", "set_random_seed", "BasevLLMParameter", "PackedvLLMParameter", diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 8a4ac214443eb..8226437cb1898 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that compute logits from hidden_stats.""" -import inspect -from concurrent.futures import ThreadPoolExecutor from typing import Optional import torch -import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.custom_op import CustomOp @@ -16,11 +13,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform -_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None -if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: - _logits_processor_threadpool = ThreadPoolExecutor( - envs.VLLM_LOGITS_PROCESSOR_THREADS) - @CustomOp.register("logits_processor") class LogitsProcessor(CustomOp): @@ -60,15 +52,10 @@ class LogitsProcessor(CustomOp): hidden_states: torch.Tensor, sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, - prune_hidden_states: bool = True, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None and prune_hidden_states: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - # Get the logits for the next tokens. logits = self._get_logits(hidden_states, lm_head, embedding_bias) if logits is not None: @@ -79,12 +66,6 @@ class LogitsProcessor(CustomOp): if self.scale != 1.0: logits *= self.scale - - # Apply logits processors (if any). - if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) - return logits def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: @@ -125,75 +106,3 @@ class LogitsProcessor(CustomOp): s += f", org_vocab_size={self.org_vocab_size}" s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" return s - - -def _prune_hidden_states( - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios - # (warmup, profile_run) we might not have selected_token_indices, - # so we skip pruning. - if sampling_metadata.selected_token_indices is not None: - return hidden_states.index_select( - 0, sampling_metadata.selected_token_indices) - else: - return hidden_states - - -def _apply_logits_processors( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - found_logits_processors = False - logits_processed = 0 - logits_row_ids_and_logits_row_futures = [] - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - - for seq_id, logits_row_idx in zip(seq_ids, - seq_group.sample_indices): - logits_row = logits[logits_row_idx] - past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids - prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids - - if _logits_processor_threadpool is not None: - logits_row_ids_and_logits_row_futures.append( - (logits_row_idx, - _logits_processor_threadpool.submit( - _apply_logits_processors_single_seq, logits_row, - logits_processors, past_tokens_ids, - prompt_tokens_ids))) - else: - logits[logits_row_idx] = \ - _apply_logits_processors_single_seq( - logits_row, logits_processors, past_tokens_ids, - prompt_tokens_ids) - - logits_processed += len(seq_group.sample_indices) + len( - seq_group.prompt_logprob_indices) - - for logits_row_idx, future in logits_row_ids_and_logits_row_futures: - logits[logits_row_idx] = future.result() - - if found_logits_processors: - # verifies that no rows in logits were missed unexpectedly - assert logits_processed == logits.shape[0] - return logits - - -def _apply_logits_processors_single_seq(logits_row, logits_processors, - past_tokens_ids, - prompt_tokens_ids) -> torch.Tensor: - for logits_processor in logits_processors: - parameters = inspect.signature(logits_processor).parameters - if len(parameters) == 3: - logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids, - logits_row) - else: - logits_row = logits_processor(past_tokens_ids, logits_row) - return logits_row diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py deleted file mode 100644 index 9d93cad2420ad..0000000000000 --- a/vllm/model_executor/layers/sampler.py +++ /dev/null @@ -1,1198 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A layer that samples the next tokens from the model's outputs.""" -import itertools -from collections.abc import Iterator -from dataclasses import dataclass -from importlib.util import find_spec -from math import inf -from typing import Optional, Union - -import msgspec -import torch -import torch.nn as nn - -import vllm.envs as envs -from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs -from vllm.model_executor.layers.utils import apply_penalties -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingTensors, - SequenceGroupToSample) -from vllm.sampling_params import SamplingType -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, - CompletionSequenceGroupOutput, SequenceOutput) - -if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): - # yapf: disable - from flashinfer.sampling import ( - top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling) - - # yapf: enable -else: - flashinfer_top_k_top_p_sampling = None - -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -def get_sampler() -> torch.nn.Module: - if envs.VLLM_USE_V1: - # Lazy import: the v1 package isn't distributed - from vllm.v1.sample.sampler import Sampler as V1Sampler - return V1Sampler() - return Sampler() - - -# (num_token_ids, num_parent_ids) per sequence group. -SampleResultType = list[tuple[list[int], list[int]]] - -# Types of temporary data structures used for -# computing sample_result -SampleMetadataType = dict[SamplingType, tuple[list[int], - list[SequenceGroupToSample]]] -MultinomialSamplesType = dict[SamplingType, torch.Tensor] -SampleResultsDictType = dict[int, tuple[list[int], list[int]]] - - -# Encapsulates temporary data structures for computing -# sample_result. -# -# * For multi-step scheduling: must be returned -# by `Sampler.forward()` and used later to compute the pythonized -# sample_result -# -# * For single-step scheduling: consumed immediately -# inside `Sampler.forward()` to compute pythonized sample_result. -@dataclass -class SampleResultArgsType: - sample_metadata: SampleMetadataType - multinomial_samples: MultinomialSamplesType - sample_results_dict: SampleResultsDictType - sampling_metadata: SamplingMetadata - greedy_samples: Optional[torch.Tensor] - - -# Union of non-deferred (single-step scheduling) -# vs deferred (multi-step scheduling) -# sample result types -MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] - -# Abbreviation of the _sample() return type -SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] - - -class SamplerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """For each sequence group, we generate a list of SequenceOutput object, - each of which contains one possible candidate for the next token. - - This data structure implements methods, so it can be used like a list, but - also has optional fields for device tensors. - """ - - outputs: list[CompletionSequenceGroupOutput] - - # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional[torch.Tensor] = None - - # On-device tensor containing the logprobs of each token. - logprobs: Optional["torch.Tensor"] = None - - # Holds either (1) the pythonized sampler result (single-step scheduling) - # or (2) what will be arguments for later deferred pythonization of the - # sampler result (muliti-step scheduling) - deferred_sample_results_args: Optional[SampleResultArgsType] = None - - # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional[torch.Tensor] = None - # CPU tensor containing the sampled token ids. Used during multi-step to - # return the sampled token ids from last rank to AsyncLLMEngine to be - # 'broadcasted' to all other PP ranks for next step. - sampled_token_ids_cpu: Optional[torch.Tensor] = None - - # On-device tensor containing the sampled token embeddings (embeddings - # corresponding to the sampled token ids). Used when prompt embeddings are - # specified in lieu of prompt token ids or text. - sampled_token_embeds: Optional[torch.Tensor] = None - - # Optional last hidden states from the model. - hidden_states: Optional[torch.Tensor] = None - - # Optional prefill hidden states from the model - # (used for models like EAGLE). - prefill_hidden_states: Optional[torch.Tensor] = None - - # Time taken in the forward pass for this across all workers - model_forward_time: Optional[float] = None - - # Time taken in the model execute function. This will include model forward, - # block/sync across workers, cpu-gpu sync time and sampling time. - model_execute_time: Optional[float] = None - - def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput: - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]: - return iter(self.outputs) - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - def __repr__(self) -> str: - """Show the shape of a tensor instead of its values to reduce noise. - """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None - else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else - self.sampled_token_ids.shape) - return (f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr})") - - -class Sampler(nn.Module): - """Samples the next tokens from the model's outputs. - - This layer does the following: - 1. Discard the hidden states that are not used for sampling (i.e., all - tokens except the final one in each prompt). - 2. Compute the logits for the next tokens. - 3. Apply presence, frequency and repetition penalties. - 4. Apply temperature scaling. - 5. Apply top-p and top-k truncation. - 6. Sample the next tokens. - Here, each sequence group within the batch can have different sampling - parameters (e.g., sampling method, temperature, top-p, top-k, etc.). - - The structure of the logits tensor is coupled with the seq_groups in - sampling_metadata. Typically, each sequence in each seq_group has one row in - logits for the next token to be sampled; however, for a seq_group with a - prompt request with the prompt_logprobs sampling parameter, there are rows - in logits for each token in the input prompt. - """ - - def __init__(self): - super().__init__() - - # Whether or not the SamplerOutput should have on-device tensors - # containing the sampled token ids and probabilities. This is used by - # speculative decoding and when prompt embeddings are specified. - self.include_gpu_probs_tensor = False - self.should_modify_greedy_probs_inplace = False - - def _init_sampling_tensors( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ): - """The goal here is to reuse sampling tensors between similar decode - runs. This is possible because sampling logic does not change between - decodes of the same sequences. - """ - _, vocab_size = logits.shape - - # First free any existing stored sampling tensors. - # This is necessary because some sampling tensors may - # have pinned memory. - self._sampling_tensors = None - - # Initialize new sampling tensors - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) - - self._sampling_tensors = sampling_tensors - self._do_penalties = do_penalties - self._do_top_p_top_k = do_top_p_top_k - self._do_min_p = do_min_p - - def forward( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - """ - Single-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Pythonize sampling result & logprobs tensor - - Multi-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Defer Pythonization of sampling result & logprobs - tensor - * Encapsulate arguments required for deferred Pythonization - in the - [`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput] - structure - - Args: - logits: (num_tokens, vocab_size). - sampling_metadata: Metadata for sampling. - """ - assert logits is not None - _, vocab_size = logits.shape - - # Prepare sampling tensors with pinned memory to avoid blocking. - if not sampling_metadata.reuse_sampling_tensors: - self._init_sampling_tensors(logits, sampling_metadata) - elif self._do_penalties: - # In this case, the sampling tensors logic depends on - # "output_tokens" of a sequence. As a result, we cannot - # reuse sampling tensors, since "output_tokens" changes - # between decode runs. - self._init_sampling_tensors(logits, sampling_metadata) - - assert self._sampling_tensors is not None - sampling_tensors = self._sampling_tensors - do_penalties = self._do_penalties - do_top_p_top_k = self._do_top_p_top_k - do_min_p = self._do_min_p - - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - - # Apply presence and frequency penalties. - if do_penalties: - logits = apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) - - # Use float32 to apply temperature scaling. - # Use in-place division to avoid creating a new tensor. - logits = logits.to(torch.float) - logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) - - if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) - - if do_min_p: - logits = _apply_min_p(logits, sampling_tensors.min_ps) - - # We use float32 for probabilities and log probabilities. - # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities. - logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - - # Sample the next tokens. - maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=self.include_gpu_probs_tensor, - modify_greedy_probs=self._should_modify_greedy_probs_inplace, - ) - - if self.include_gpu_probs_tensor: - # Since we will defer sampler result Pythonization, - # preserve GPU-side tensors in support of later - # deferred pythonization of logprobs - assert maybe_sampled_tokens_tensor is not None - on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) - else: - # Since Pythonization has already happened, don't preserve - # GPU-side tensors. - on_device_tensors = None - - # Get the logprobs query results. - prompt_logprobs = None - sample_logprobs = None - if not sampling_metadata.skip_sampler_cpu_output: - # Pythonize logprobs now (GPU -> CPU); do not defer. - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - prompt_logprobs, sample_logprobs = get_logprobs( - logprobs, sampling_metadata, maybe_deferred_sample_results) - - return _build_sampler_output( - maybe_deferred_sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors, - skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) - - @property - def _should_modify_greedy_probs_inplace(self) -> bool: - """Whether or not the sampler should modify the probability distribution - of greedily-sampled tokens such that multinomial sampling would sample - the greedily-sampled token. - - In other words, if True then we set the probability of the greedily- - sampled token to 1. - - This is used by speculative decoding, which requires that the sampling - method be encoded into the probability distribution. - """ - return self.should_modify_greedy_probs_inplace - - -def _apply_min_tokens_penalty( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens - have not been generated yet - """ - # list of indices in logits that will be set to -inf - logits_to_penalize: list[tuple[int, int]] = [] - logits_applied = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - - sample_indices = seq_group.sample_indices - logits_applied += len(sample_indices) + len( - seq_group.prompt_logprob_indices) - if not seq_group.do_sample: - continue - - start_idx = sample_indices[0] - min_tokens = sampling_params.min_tokens - token_ids_to_penalize = sampling_params.all_stop_token_ids - if min_tokens > 0 and token_ids_to_penalize: - seqs_to_penalize: list[int] = [] - for j, seq_id in enumerate(seq_ids): - seq_data = seq_group.seq_data[seq_id] - if len(seq_data.output_token_ids_array) < min_tokens: - seqs_to_penalize.append(j) - - if seqs_to_penalize: - # convert to the index into logits - seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] - # itertools.product pairs each seq index with every token id - logits_to_penalize.extend( - itertools.product(seqs_to_penalize, token_ids_to_penalize)) - - if logits_to_penalize: - # use zip and * to group indices along each dimension - # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) - logits[tuple(zip(*logits_to_penalize))] = -float("inf") - - # verifies that no rows in logits were missed unexpectedly - assert logits_applied == logits.shape[0] - return logits - - -def _apply_top_k_top_p( - logits: torch.Tensor, - p: torch.Tensor, - k: torch.Tensor, -) -> torch.Tensor: - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. - logits = torch.empty_like(logits_sort).scatter_(dim=-1, - index=logits_idx, - src=logits_sort) - return logits - - -def _apply_min_p( - logits: torch.Tensor, - min_p: torch.Tensor, -) -> torch.Tensor: - """ - Adapted from - https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 - """ - probs = torch.softmax(logits, dim=-1) - top_probs, _ = probs.max(dim=-1, keepdim=True) - scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs - tokens_to_remove = probs < scaled_min_p - logits = logits.masked_fill_(tokens_to_remove, -float("inf")) - - return logits - - -def _greedy_sample( - selected_seq_groups: list[SequenceGroupToSample], - samples: torch.Tensor, -) -> SampleResultType: - """Run greedy sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - samples: (num_selected_samples,) A tensor of samples. The length of - samples could be smaller than selected_seq_groups if - seq_group.do_sample is False. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - samples_lst = samples.tolist() - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - seq_ids = seq_group.seq_ids - num_parent_seqs = len(seq_ids) - assert num_parent_seqs == 1, ( - "Greedy sampling should have only one seq.") - parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples_lst[sample_idx]] - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -def _random_sample( - selected_seq_groups: list[SequenceGroupToSample], - random_samples: torch.Tensor, -) -> SampleResultType: - """Run random sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - random_samples: (num_selected_samples,) A tensor of samples. The - length of samples could be smaller than selected_seq_groups if - seq_group.do_sample is False. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - # Find the maximum n value of the prompt phase requests. - random_samples = random_samples.cpu() - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - num_parent_seqs = len(seq_ids) - if is_prompt: - # Prompt phase. - parent_ids = [0] * sampling_params.n - next_token_ids = random_samples[ - sample_idx, :sampling_params.n].tolist() - else: - # Generation phase. - parent_ids = list(range(num_parent_seqs)) - next_token_ids = random_samples[sample_idx:sample_idx + - num_parent_seqs, 0].tolist() - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -# torch.multinomial forces a GPU<->CPU sync. -# Therefore, we use an optimized implementation instead. -# Note that we always sample with replacement. -# probs will be modified in place, but this is fine, as we pass -# in a copy already. -def _multinomial( - probs: torch.Tensor, - num_samples: int, - seq_groups: Optional[list[SequenceGroupToSample]] = None, -) -> torch.Tensor: - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - q = torch.empty_like(probs) - if seq_groups is None: - q.exponential_() - else: - sample_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group.seq_ids - stride = len(seq_ids) * num_samples - assert seq_group.generator is not None - q[sample_idx:sample_idx + - stride].exponential_(generator=seq_group.generator) - sample_idx += stride - return probs.div_(q).argmax(dim=1).view(-1, num_samples) - - -def _top_k_top_p_multinomial_with_flashinfer( - probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, - num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]): - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - top_ks = top_ks.repeat_interleave(num_samples) - top_ps = top_ps.repeat_interleave(num_samples) - batch_next_token_ids = flashinfer_top_k_top_p_sampling( - probs, - top_ks, - top_ps, - ) - return batch_next_token_ids.view(-1, num_samples) - - -def get_pythonized_sample_results( - sample_result_args: SampleResultArgsType) -> SampleResultType: - '''This function consumes GPU-side sampler results and computes - Pythonized CPU-side sampler results (GPU -> CPU sync.) - - Single-step scheduling: this function is invoked at sampling-time - for immediate Pythonization. - - Multi-step scheduling: Pythonization is deferred until after multiple - GPU-side steps have been completed. - - Args: - sample_result_args: GPU-side inputs to the Pythonization process - - Returns: - Pythonized sampler results - ''' - - ( - sample_metadata, - sampling_metadata, - greedy_samples, - multinomial_samples, - sample_results_dict, - ) = ( - sample_result_args.sample_metadata, - sample_result_args.sampling_metadata, - sample_result_args.greedy_samples, - sample_result_args.multinomial_samples, - sample_result_args.sample_results_dict, - ) - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - return [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - - -def _sample_with_torch( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - '''Torch-oriented _sample() implementation. - - Single-step scheduling: - * Perform GPU-side sampling computation - * Immediately Pythonize sampling result - - Multi-step scheduling: - * Perform GPU-side sampling computation - * Defer Pythonization & preserve GPU-side - tensors required for Pythonization - ''' - - categorized_seq_group_ids: dict[SamplingType, list[int]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: SampleResultsDictType = {} - sample_metadata: SampleMetadataType = {} - multinomial_samples: MultinomialSamplesType = {} - greedy_samples: Optional[torch.Tensor] = None - - # Create output tensor for sampled token ids. - if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), - VLLM_INVALID_TOKEN_ID, - dtype=torch.long, - device=logprobs.device) - else: - sampled_token_ids_tensor = None - - # Counterintuitively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups) - long_sample_indices = sample_indices.long() - if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[long_sample_indices], - dim=-1) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[ - long_sample_indices] = greedy_samples.unsqueeze(-1) - - if modify_greedy_probs: - # If required, modify the probabilities such that sampling from - # the modified distribution would always sample the argmax - # token id. - _modify_greedy_probs_inplace(logprobs, probs, - long_sample_indices, - greedy_samples) - - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_n_in_batch = 1 - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_n_in_batch = max(max_n_in_batch, sampling_params.n) - seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else - seq_groups) - - if flashinfer_top_k_top_p_sampling is not None: - logger.warning("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") - - multinomial_samples[sampling_type] = _multinomial( - probs[long_sample_indices], - max_n_in_batch, - seq_groups=seq_groups_arg) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[long_sample_indices] = \ - multinomial_samples[sampling_type].to(torch.long) - - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - # Encapsulate arguments for computing Pythonized sampler - # results, whether deferred or otherwise. - maybe_deferred_args = SampleResultArgsType( - sampling_metadata=sampling_metadata, - sample_metadata=sample_metadata, - multinomial_samples=multinomial_samples, - greedy_samples=greedy_samples, - sample_results_dict=sample_results_dict) - - if not sampling_metadata.skip_sampler_cpu_output: - # GPU<->CPU sync happens here. - # This also converts the sampler output to a Python object. - # Return Pythonized sampler result & sampled token ids - return get_pythonized_sample_results( - maybe_deferred_args), sampled_token_ids_tensor - else: - # Defer sampler result Pythonization; return deferred - # Pythonization args & sampled token ids - return ( - maybe_deferred_args, - sampled_token_ids_tensor, - ) - - -def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - """ - Args: - probs: (num_query_tokens_in_batch, num_vocab) - logprobs: (num_query_tokens_in_batch, num_vocab) - sampling_metadata: The metadata for a batch for sampling. - sampling_tensors: Tensors that include sampling related metadata. - - Returns: - (next_token_ids, parent_seq_ids) for each seq group in a batch. - If sampling is skipped, it returns ([], []) - sampled_token_ids_tensor: A tensor of sampled token ids. - """ - return _sample_with_torch( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=include_gpu_probs_tensor, - modify_greedy_probs=modify_greedy_probs, - ) - - -def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - """ - This function calculates the ranks of the chosen tokens in a logprob tensor. - - Args: - x (torch.Tensor): 2D logprob tensor of shape (N, M) - where N is the no. of tokens and M is the vocab dim. - indices (torch.Tensor): List of chosen token indices. - - Returns: - torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. - Each element in the returned tensor represents the rank - of the chosen token in the input logprob tensor. - """ - vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), - indices] - result = (x > vals[:, None]) - del vals - return result.sum(1).add_(1) - - -def get_logprobs( - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sample_results: SampleResultType, -) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]: - """Return sample logprobs and prompt logprobs. - - The logic consists of 3 parts. - - Select indices to compute logprob from, ranks of token ids, and - the top k token ids from logprobs. - - Compute prompt logprobs if required. - - Compute sample logprobs if required. - - Args: - logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's - logprob per vocab. Sequence groups' query tokens are batched in a - single flattened tensor. For example, assuming there are N - seq groups, it is sorted by prefill tokens for seq_group_1 (if - prompt logprob is enabled), decode tokens for seq_group_1 (if - sampling is required), prefill tokens for seq_group_2, ... - sampling_metadata: The sampling metadata. - sample_results: (num_seq_groups) The tuple of (next_token_ids, - parent_ids) for each sequence group. When beam search is enabled, - sample_results can contain different number of seq_ids from - sampling_metadata.seq_groups. It is because beam search creates - 2 * BEAM_WIDTH number of samples (whereas there are only up to - BEAM_WIDTH number of seq_ids). - - Returns: - A tuple of prompt and sample logprobs per sequence group in a batch. - """ - # The index of query token to calculate logprobs. It includes both - # prompt and sample logprob indices. - query_indices: list[int] = [] - # The next token ids to get the logprob value from. - next_token_ids: list[int] = [] - # The largest requested number of logprobs. We find logprobs as many as the - # largest num logprobs in this API. If every logprobs is None, it will be - # set to -1. - largest_num_logprobs = -1 - - # Select indices to compute logprob from, ranks of token ids, and the top - # k token ids from logprobs. - for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, - sample_results): - sampling_params = seq_group.sampling_params - - # Update indices and tokens for prompt logprobs. - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.prompt_logprobs) - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - query_indices.extend(seq_group.prompt_logprob_indices) - next_token_ids.extend(next_prompt_tokens) - - # Update indices and next tokenes for sample logprob. - if seq_group.do_sample: - token_ids, parent_seq_ids = sample_result - # NOTE: We cannot directly use sample_indices because - # sample_indices only contain parent seq_ids of a previous step. - # The current step may have different number of seq_ids, and - # we can obtain it from `sample_result[1]`. - query_idx = seq_group.sample_indices[0] - query_indices.extend( - [query_idx + parent_id for parent_id in parent_seq_ids]) - next_token_ids.extend(token_ids) - - if sampling_params.logprobs is not None: - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.logprobs) - - assert len(next_token_ids) == len(query_indices) - - if len(query_indices) == 0: - empty_sampled_logprob: SampleLogprobs = [] - empty_prompt_logprob: Optional[PromptLogprobs] = None - num_seq_groups = len(sampling_metadata.seq_groups) - return [empty_prompt_logprob - ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups - - selected_logprobs, ranks = None, None - top_logprobs, top_token_ids = None, None - - # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can - # skip the whole logprob calculation. - if largest_num_logprobs >= 0: - query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) - next_token_ids_gpu = torch.tensor(next_token_ids, - device=logprobs.device) - - # (num_selected_query_tokens, num_logprobs). Note that query_indices can - # contain duplicates if beam search is enabled. - selected_logprobs = logprobs[[ - query_indices_gpu, - next_token_ids_gpu, - ]] - ranks = _get_ranks( - logprobs[query_indices_gpu], - next_token_ids_gpu, - ) - assert selected_logprobs.shape[0] == ranks.shape[0] - - # We need to compute top k only if there exists logprobs > 0. - if largest_num_logprobs > 0: - # Logprobs of topk tokens for a batch of sequence groups. - # (num_query_tokens_across_batch). - top_logprobs, top_token_ids = torch.topk(logprobs, - largest_num_logprobs, - dim=-1) - top_logprobs = top_logprobs.to('cpu') - top_token_ids = top_token_ids.to('cpu') - - selected_logprobs = selected_logprobs.to('cpu') - ranks = ranks.to('cpu') - - # Find prompt/sample logprobs. - prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = [] - sample_logprobs_per_seq_group: list[SampleLogprobs] = [] - top_logprob_idx = 0 - selected_logprobs_idx = 0 - - for seq_group, sample_result in zip(sampling_metadata.seq_groups, - sample_results): - (prompt_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_prompt_logprob_if_needed( - seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, - selected_logprobs_idx, top_logprob_idx) - prompt_logprobs_per_seq_group.append(prompt_logprobs) - - (sampled_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_sampled_logprob_if_needed( - seq_group, sample_result, selected_logprobs, ranks, top_token_ids, - top_logprobs, selected_logprobs_idx, top_logprob_idx) - sample_logprobs_per_seq_group.append(sampled_logprobs) - - return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group - - -def _get_prompt_logprob_if_needed( - seq_group: SequenceGroupToSample, - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the prompt logprob from a sequence group if needed.""" - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - - # Find prompt logprobs - prompt_logprobs: Optional[PromptLogprobs] = None - if is_prompt and sampling_params.prompt_logprobs is not None: - prompt_logprobs = [] - num_logprobs = sampling_params.prompt_logprobs - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - # Pre-select indexes and create a list. It is faster than calling .item - # repetitively. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - - for idx, token_id in enumerate(next_prompt_tokens): - # Calculate the prompt logprob of the real prompt tokens. - # {token_id: (logprob, rank_from_vocab)} - prompt_logprobs_dict: dict[int, tuple[float, int]] = { - token_id: (selected_logprob_items[idx], rank_items[idx]) - } - - # Add top K prompt logprobs along with its rank. - if num_logprobs > 0: - top_ids = top_token_ids[ - top_logprob_idx, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - prompt_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip(top_ids, top_probs, - top_ranks) - }) - prompt_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in prompt_logprobs_dict.items() - }) - # + 1 to go to the next prompt token. - top_logprob_idx += 1 - - # + len(next_prompt_tokens) to go to the next prompt. - selected_logprobs_idx += len(next_prompt_tokens) - return prompt_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _get_sampled_logprob_if_needed( - seq_group: SequenceGroupToSample, - sample_result: tuple[list[int], list[int]], - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the sample logprob if needed.""" - seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.logprobs - sampled_logprobs: SampleLogprobs = [] - next_token_ids, parent_seq_ids = sample_result - - if seq_group.do_sample: - assert len(next_token_ids) > 0 - if num_logprobs is None: - for next_token_id in next_token_ids: - # Use a dummy logprob - sampled_logprobs.append({next_token_id: Logprob(inf)}) - else: - # Pre-select items from tensor. tolist() is faster than repetitive - # `.item()` calls. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - for idx, (next_token_id, parent_id) in enumerate( - zip(next_token_ids, parent_seq_ids)): - # Get the logprob of a sampled token. - sampled_logprobs_dict = { - next_token_id: - (selected_logprob_items[idx], rank_items[idx]) - } - if num_logprobs is not None and num_logprobs > 0: - # Get top K logprobs. - top_ids = top_token_ids[top_logprob_idx + - parent_id, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx + parent_id, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - sampled_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip( - top_ids, top_probs, top_ranks) - }) - - sampled_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in - sampled_logprobs_dict.items() - }) - - # NOTE: This part of code is not intuitive. `selected_logprobs` include - # logprobs for the current step, which has len(next_token_ids) tokens - # per sequence group. `logprobs` includes logprobs from the previous - # steps, which has len(seq_ids) tokens per sequence group. - - # Iterate to the next sequence group in a batch. - selected_logprobs_idx += len(next_token_ids) - # Iterate to the next sequence group in a batch. - top_logprob_idx += len(seq_ids) - return sampled_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, - sample_indices: torch.Tensor, - greedy_samples: torch.Tensor) -> None: - """Modify the probability distributions of the greedily-sampled tokens such - that each sampled token has a "probability" of 1.0. This is required by - speculative decoding, which depends on the sampling method being encoded - within the probability distribution for correctness. - - # Why do we only need to do this for greedy sampling? - - vLLM's sampler performs the following steps for greedy or multinomial - (random) sampling: - 1. Get logits from model. - 2. Modify logits according to per-sequence sampling parameters. - - Multiply by temperature, top-k and top-p masking, penalize tokens - according to their frequency, etc. - 3. Sample a token. - - Random sampling simply samples from the modified probability - distribution. - - Greedy sampling performs `argmax` to obtain the token with the - highest likelihood. - - Ignoring greedy sampling for a moment, we find that the computed probability - distribution has the following property: we can sample from it independently - and find that the token sampled by the Sampler has a frequency corresponding - to how often we see it in our sampling. In other words, for tokens sampled - with vLLM's random SamplingType, the computed probability distribution - encodes the sampling methodology completely. - - Greedy sampling does not normally have this property. vLLM modifies logits - according to sampling params, then performs `argmax`, then returns the - sampled token and the computed probability distribution. If we sample from - the distribution, we'll find the likelihood of the greedily-sampled token - is not always 1.0. - - Since lossless speculative decoding requires that the sampling methodology - be encoded within the probability distribution, we are motivated to modify - the probability distribution such that the sampled token has probability 1 - when speculative decoding is used. - - NOTE: Alternatively, we could use an extremely low temperature to achieve - greedy sampling using multinomial computation and unite the codepaths. This - has implications on the overall design of the sampler, e.g. how to record - accurate logprobs for the user, so this improvement is deferred to later. - """ - # NOTE: logprobs are not modified so they can be returned to the user. - probs[sample_indices, :] = 0 - probs[sample_indices, greedy_samples] = 1.0 - - -def _build_sampler_output( - maybe_deferred_sample_results: MaybeDeferredSampleResultType, - sampling_metadata: SamplingMetadata, - prompt_logprobs: Optional[list[Optional[PromptLogprobs]]], - sample_logprobs: Optional[list[SampleLogprobs]], - on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor, - torch.Tensor]], - skip_sampler_cpu_output: bool = False, -) -> SamplerOutput: - """Construct Python objects with the output of sampling. - - Args: - on_device_tensors: Tuple containing on-device tensors with the - probabilities used in sampling and the sampled token ids. This - allows post-processing without copies to CPU/serialization, e.g. in - speculative decoding rejection sampling. - """ - sampler_output: list[CompletionSequenceGroupOutput] = [] - - if skip_sampler_cpu_output: - assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) - deferred_sample_results_args = maybe_deferred_sample_results - else: - assert prompt_logprobs is not None - assert sample_logprobs is not None - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - assert len(sampling_metadata.seq_groups) \ - == len(maybe_deferred_sample_results) \ - == len(prompt_logprobs) \ - == len(sample_logprobs) - deferred_sample_results_args = None - - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - maybe_deferred_sample_results, - prompt_logprobs, sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: list[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip( - parent_ids, next_token_ids, group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, - group_prompt_logprobs)) - - # If not specified, store None values in SamplerOutput. - if on_device_tensors is not None: - (sampled_token_probs, logprobs_tensor, - sampled_token_ids) = on_device_tensors - else: - sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, - None) - - return SamplerOutput( - outputs=sampler_output, - sampled_token_probs=sampled_token_probs, - sampled_token_ids=sampled_token_ids, - logprobs=logprobs_tensor, - deferred_sample_results_args=deferred_sample_results_args) - - -def _get_next_prompt_tokens( - seq_group: SequenceGroupToSample) -> tuple[int, ...]: - """Get a list of next prompt tokens to compute logprob from a - given sequence group. - - It is used to compute prompt logprob. Imagine you have logprob for each - query token. Query token needs to know the next prompt token id to compute - prompt logprob. This is a helper to obtain next prompt token ids. - - This API has to be used only when the caller knows seq_group is in prefill - stage. - - Returns: - A list of next prompt tokens to compute logprob. - """ - assert seq_group.is_prompt, ( - "Caller should ensure the sequence group is in a prefill stage.") - seq_ids = seq_group.seq_ids - query_len = seq_group.query_len - assert query_len is not None - # prompt has only 1 seq id. - assert len(seq_ids) == 1 - seq_data = seq_group.seq_data[seq_ids[0]] - computed_len = seq_data.get_num_computed_tokens() - prompt_tokens = seq_data.prompt_token_ids - # +1 because we are looking for a next prompt token. - next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + query_len + 1, - len(prompt_tokens)) - next_prompt_tokens = prompt_tokens[ - next_token_index_start:next_token_index_end] - return next_prompt_tokens diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 6ba8ad372c95a..b0a96fca2ff8a 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -2,18 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from .utils import maybe_prefix @@ -105,8 +102,10 @@ class Medusa(nn.Module): return [block(hidden_states) for block in self.blocks] def compute_logits( - self, hidden_states: list[torch.Tensor], - sampling_metadata: SamplingMetadata) -> list[torch.Tensor]: + self, + hidden_states: list[torch.Tensor], + sampling_metadata, + ) -> list[torch.Tensor]: logits_lst: list[torch.Tensor] = [] for hs, lm_head in zip(hidden_states, self.lm_heads): @@ -130,57 +129,6 @@ class Medusa(nn.Module): return logits_lst - def sample( - self, - logits: list[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> list[SamplerOutput]: - logits = torch.stack(logits, dim=0).float() - logprobs = torch.log_softmax(logits, dim=-1) - token_ids = logits.argmax(-1) # support only top-1 for now - probs = torch.softmax(logits, dim=-1) - - token_id_list = [] - token_prob_list = [] - token_logprob_list = [] - - for idx, seq_group in enumerate(sampling_metadata.seq_groups): - token_id_list.append(token_ids[:, seq_group.sample_indices]) - token_prob_list.append(probs[:, seq_group.sample_indices]) - token_logprob_list.append(logprobs[:, seq_group.sample_indices]) - - outputs: list[Optional[SamplerOutput]] = [] - for idx in range(len(sampling_metadata.seq_groups)): - outputs.append( - SamplerOutput( - outputs=None, - sampled_token_probs=token_prob_list[idx].squeeze(1), - logprobs=token_logprob_list[idx].squeeze(1), - sampled_token_ids=token_id_list[idx].squeeze(1), - )) - - return outputs - - def generate_proposals( - self, - previous_hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - # During preemption, we may receive an empty tensor (batch_size=0) - if previous_hidden_states.size(0) == 0: - # Return None to signal the Top1Proposer that no proposals - # were generated for this batch, allowing it to handle this - # special case appropriately - return None - - return self.sample( - logits=self.compute_logits( - hidden_states=self.forward(previous_hidden_states), - sampling_metadata=sampling_metadata, - ), - sampling_metadata=sampling_metadata, - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index c6a97388dc188..d057eb49a62d1 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -8,9 +8,7 @@ import torch import torch.nn as nn from vllm.config import VllmConfig -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -141,55 +139,57 @@ class MLPSpeculator(nn.Module): self.config = config self.logits_processor = LogitsProcessor(config.vocab_size, config.vocab_size, 1.0) - self.sampler = get_sampler() - def generate_proposals( - self, - input_ids: torch.Tensor, - previous_hidden_states: torch.Tensor, - num_predict_tokens: int, - sampling_metadata: SamplingMetadata, - ) -> list[SamplerOutput]: - if num_predict_tokens > self.max_speculative_tokens: - raise ValueError(f"Max speculative tokens for model is " - f"{self.max_speculative_tokens}, but " - f"{num_predict_tokens} were requested") + # NOTE(woosuk): This method is commented out because it is old code + # using V0. We should either port it to V1 or remove it. - # b x 1 x d - previous_hidden_states = previous_hidden_states.unsqueeze(1) + # def generate_proposals( + # self, + # input_ids: torch.Tensor, + # previous_hidden_states: torch.Tensor, + # num_predict_tokens: int, + # sampling_metadata: SamplingMetadata, + # ) -> list[SamplerOutput]: + # if num_predict_tokens > self.max_speculative_tokens: + # raise ValueError(f"Max speculative tokens for model is " + # f"{self.max_speculative_tokens}, but " + # f"{num_predict_tokens} were requested") - if self.scale_input: - previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 + # # b x 1 x d + # previous_hidden_states = previous_hidden_states.unsqueeze(1) - # b x 1 - last_tokens = input_ids.unsqueeze(1) + # if self.scale_input: + # previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 - next_tokens = [] + # # b x 1 + # last_tokens = input_ids.unsqueeze(1) - for head_index in range(num_predict_tokens): + # next_tokens = [] - # Project and predict - z = self.emb[head_index](last_tokens) # b k d - states = self.proj[head_index](previous_hidden_states) + # for head_index in range(num_predict_tokens): - # Weighted add of state_weight*state and emb_weight*z - # Let subsequent LN take care of denominator - # state_weight is close to 1, so shouldn't be any precision issues - states.add_(z, alpha=self.emb_weight / self.state_weight) + # # Project and predict + # z = self.emb[head_index](last_tokens) # b k d + # states = self.proj[head_index](previous_hidden_states) - states = self.activation(self.ln[head_index](states)) # b k d - previous_hidden_states = states - # TODO: not yet supporting top_k_tokens_per_head - states = states.flatten(0, 1) + # # Weighted add of state_weight*state and emb_weight*z + # # Let subsequent LN take care of denominator + # # state_weight is close to 1, so shouldn't be any precision issues + # states.add_(z, alpha=self.emb_weight / self.state_weight) - logits = self.logits_processor(self.head[head_index], states, - sampling_metadata) + # states = self.activation(self.ln[head_index](states)) # b k d + # previous_hidden_states = states + # # TODO: not yet supporting top_k_tokens_per_head + # states = states.flatten(0, 1) - output = self.sampler(logits, sampling_metadata) - last_tokens = output.sampled_token_ids - next_tokens.append(output) + # logits = self.logits_processor(self.head[head_index], states, + # sampling_metadata) - return next_tokens + # output = self.sampler(logits, sampling_metadata) + # last_tokens = output.sampled_token_ids + # next_tokens.append(output) + + # return next_tokens def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index c4548ee168bd7..aa7c434a44aeb 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -697,16 +697,12 @@ class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - # If the shape is the same, it means that we have already - # prune hidden states manually. - prune_hidden_states = hidden_states.size( - 0) != sampling_metadata.selected_token_indices.size(0) processed_logits = self.logits_processor( self.lm_head, hidden_states, sampling_metadata, self.embedding_bias, - prune_hidden_states=prune_hidden_states) + ) return processed_logits def load_weights( diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 2315f9dad5a5a..8c4548ff7f7dc 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,597 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from array import array -from dataclasses import dataclass -from typing import Optional - -import torch - -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, - SequenceGroupMetadata) -from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad) - -_SAMPLING_EPS = 1e-5 - - -@dataclass -class SequenceGroupToSample: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Sequence ids for the sequence group in a previous step. - seq_ids: list[int] - sampling_params: SamplingParams - # seq_id -> sequence data. - seq_data: dict[int, SequenceData] - # The length of the sequence (all tokens seen in the past + new token to - # compute attention) of the sequence group. None if it is in a decode - # stage. - seq_len: Optional[int] - # The length of new query tokens to compute in the current step. None if it - # is in a decode stage. The length of query_len <= seq_len if chunked - # prefill is enabled. - query_len: Optional[int] - # A random number generator for sampling. - generator: Optional[torch.Generator] - # True if the sequence group is in prefill stage. False if it is in a - # decode stage. - is_prompt: bool - # Query token indices from logits. to compute prompt logprob. Empty if - # prompt logprob is not required. - prompt_logprob_indices: list[int] - # Sample token indices from logits. Empty if sampling is not required. - sample_indices: list[int] - - @property - def do_sample(self): - return len(self.sample_indices) > 0 - - def __post_init__(self): - if len(self.prompt_logprob_indices) > 0: - assert self.sampling_params.prompt_logprobs is not None - if self.is_prompt: - assert self.seq_len is not None - assert self.query_len is not None - - -def gen_seq_group_to_sample_builder(num_seqs: int): - return lambda: SequenceGroupToSample( - seq_ids=[0] * num_seqs, - sampling_params=None, - seq_data=None, # type: ignore - seq_len=0, - query_len=0, - generator=None, - is_prompt=True, - prompt_logprob_indices=[], - sample_indices=[], - ) - - -class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations""" - - def __init__(self): - self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {} - - def get_cached_seq_group_to_sample(self, num_seqs): - if num_seqs not in self._seq_group_to_sample_cache: - self._seq_group_to_sample_cache[num_seqs] = PyObjectCache( - gen_seq_group_to_sample_builder(num_seqs)) - - obj = self._seq_group_to_sample_cache[num_seqs].get_object() - return obj - - def reset(self): - for cache in self._seq_group_to_sample_cache.values(): - cache.reset() - class SamplingMetadata: - """Metadata for input sequences. Used in sampler. - - The usage is as follows; - ``` - hidden_states = execute_model(...) - logits = hidden_states[sampling_metadata.selected_token_indices] - sample(logits) - - def sample(logits): - # Use categorized_sample_indices for sampling.... - ``` - - Args: - seq_groups: List of batched sequence groups. - selected_token_indices: (num_query_tokens_to_logprob). Indices to find - logits from the initial model output hidden states. - categorized_sample_indices: SamplingType -> token indices to sample. - Each token indices is 2D tensor of (num_indices, num_indices) where - the first item means the sample index within the returned logit - (before pruning padding), and the second item means the sample - index after pruning using selected_token_indices. - For example, if the returned logit is [1, 2, 3], and we select - [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, - The first tuple is [1, 2] (sampled index within original logit), - and the second tuple is [0, 1] (sampled index within pruned logit). - num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU - serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling - tensors that are part of the sampler forward pass. Currently, - it is mainly used for multi-step decode. - - """ - - def __init__( - self, - seq_groups: list[SequenceGroupToSample], - selected_token_indices: torch.Tensor, - categorized_sample_indices: dict[SamplingType, torch.Tensor], - num_prompts: int, - skip_sampler_cpu_output: bool = False, - reuse_sampling_tensors: bool = False, - ) -> None: - self.seq_groups = seq_groups - self.selected_token_indices = selected_token_indices - self.categorized_sample_indices = categorized_sample_indices - self.num_prompts = num_prompts - self.skip_sampler_cpu_output = skip_sampler_cpu_output - self.reuse_sampling_tensors = reuse_sampling_tensors - - @staticmethod - def prepare( - seq_group_metadata_list: list[SequenceGroupMetadata], - seq_lens: list[int], - query_lens: list[int], - device: str, - pin_memory: bool, - generators: Optional[dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, - ) -> "SamplingMetadata": - ( - seq_groups, - selected_token_indices, - categorized_sample_indices, - num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, - device, generators, cache) - selected_token_indices = async_tensor_h2d( - selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory, - ) - categorized_sample_indices = { - t: - async_tensor_h2d( - seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory, - ) - for t, seq_ids in categorized_sample_indices.items() - } - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - num_prompts=num_prompts, - ) - return sampling_metadata - - def __repr__(self) -> str: - return ( - "SamplingMetadata(" - f"seq_groups={self.seq_groups}, " - f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices})") - - -def _prepare_seq_groups( - seq_group_metadata_list: list[SequenceGroupMetadata], - seq_lens: list[int], - query_lens: list[int], - device: str, - generators: Optional[dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, -) -> tuple[ - list[SequenceGroupToSample], - list[int], - dict[SamplingType, list[int]], - int, -]: - """Prepare sequence groups and indices for sampling. - - Args: - seq_group_metadata_list: A list of sequence group to batch. - seq_lens: A list of sequence lens per sequence group. - Index of prompt len should match with seq_group_metadata_list. - query_lens: A list of query lengths. Prompt lens include the length - of entire prompt tokens, and it could be shorter. - device: A device to use for random number generators, - `SequenceGroupToSample.generator`. - generators: A store of per-request random number generators used - for seeded requests. - - Returns: - seq_groups: A list of sequence group to sample. - selected_token_indices: See the definition from `SamplingMetadata`. - categorized_sample_indices: See the definition from `SamplingMetadata`. - num_prompts: Total number of prompts from `seq_group_metadata_list`. - """ - # Batched sequence groups for the current model forward stsep. - seq_groups: list[SequenceGroupToSample] = [] - # A list of token indices to sample/compute logprob. It is used to - # prune the outcome logits from the model for the performance. - selected_token_indices: list[int] = [] - # Used for selected_token_indices. - model_output_idx = 0 - - # Sampling type -> ( - # indices to sample/prompt logprob within pruned output logits, - # indices to sample within pruned logits) - categorized_sample_indices: dict[SamplingType, list[int]] = { - t: [] - for t in SamplingType - } - # Index of logits to compute logprob. Logits include both prompt logprob - # and sample logprob indices. - logit_idx = 0 - # Total number of prompts from given sequence groups. - num_prompts = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = seq_group_metadata.seq_data.keys() - - if cache is not None: - sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids)) - - for j, seq_id in enumerate(seq_ids): - sample_obj.seq_ids[j] = seq_id - - sample_obj.prompt_logprob_indices.clear() - sample_obj.sample_indices.clear() - - sampling_params = seq_group_metadata.sampling_params - is_prompt = seq_group_metadata.is_prompt - generator: Optional[torch.Generator] = None - # If the current seq group is in decode stage, it is None. - seq_len: Optional[int] = None - query_len: Optional[int] = None - prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices - if cache is not None else []) - sample_indices: list[int] = (sample_obj.sample_indices - if cache is not None else []) - do_sample = seq_group_metadata.do_sample - - if seq_group_metadata.is_prompt: - if sampling_params.seed is not None: - generator = torch.Generator(device=device).manual_seed( - sampling_params.seed) - if generators is not None: - generators[seq_group_metadata.request_id] = generator - - num_prompts += 1 - num_prefill_sample = len(seq_ids) - assert num_prefill_sample == 1 - assert query_lens is not None and seq_lens is not None - query_len, seq_len = query_lens[i], seq_lens[i] - # If we need sampling, exclude num_prefill_sample tokens from - # prompt logprob. - prompt_logprob_len = (query_len - num_prefill_sample - if do_sample else query_len) - sample_len = num_prefill_sample if do_sample else 0 - else: - # Decode - prompt_logprob_len = 0 - query_len = query_lens[i] if query_lens is not None and len( - query_lens) > 0 else 1 - sample_len = len(seq_ids) * query_len if do_sample else 0 - - if sampling_params.seed is not None and generators is not None: - generator = generators.get(seq_group_metadata.request_id) - - # Update indices to select from the model output. - """ - This blocks computes selected_token_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - """ - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + prompt_logprob_len)) - model_output_idx += prompt_logprob_len - if do_sample: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + sample_len)) - model_output_idx += sample_len - - # We now find indices for logprob computation and sampling. - """ - This block computes categorized_sample_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - def sample(logits): - # Use categorized_sample_indices for sampling. - # prompt_logprob_indices to find prompt logprob indices. - # sample_indices to find sample indices. - """ - - if sampling_params.prompt_logprobs is not None: - prompt_logprob_indices.extend( - range(logit_idx, logit_idx + prompt_logprob_len)) - logit_idx += prompt_logprob_len - if do_sample: - sample_indices.extend(range(logit_idx, logit_idx + sample_len)) - categorized_sample_indices[sampling_params.sampling_type].extend( - list(range(logit_idx, logit_idx + sample_len))) - logit_idx += sample_len - - if cache is not None: - sample_obj.sampling_params = sampling_params - sample_obj.seq_data = seq_group_metadata.seq_data - sample_obj.seq_len = seq_len - sample_obj.query_len = query_len - sample_obj.generator = generator - sample_obj.is_prompt = is_prompt - else: - sample_obj = SequenceGroupToSample( - seq_ids=list(seq_ids), - sampling_params=sampling_params, - seq_data=seq_group_metadata.seq_data, - seq_len=seq_len, - query_len=query_len, - generator=generator, - is_prompt=is_prompt, - prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices), - ) - - seq_groups.append(sample_obj) - - if cache is not None: - cache.reset() - - return (seq_groups, selected_token_indices, categorized_sample_indices, - num_prompts) - - -@dataclass -class SamplingTensors: - """Tensors for sampling.""" - - temperatures: torch.Tensor - top_ps: torch.Tensor - top_ks: torch.Tensor - min_ps: torch.Tensor - presence_penalties: torch.Tensor - frequency_penalties: torch.Tensor - repetition_penalties: torch.Tensor - prompt_tokens: torch.Tensor - output_tokens: torch.Tensor - - @classmethod - def from_sampling_metadata( - cls, - sampling_metadata: "SamplingMetadata", - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> tuple["SamplingTensors", bool, bool, bool]: - prompt_tokens: list[array] = [] - output_tokens: list[array] = [] - top_ks: list[int] = [] - temperatures: list[float] = [] - top_ps: list[float] = [] - min_ps: list[float] = [] - presence_penalties: list[float] = [] - frequency_penalties: list[float] = [] - repetition_penalties: list[float] = [] - do_penalties = False - do_top_p_top_k = False - do_min_p = False - - assert sampling_metadata.seq_groups is not None - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - temperature = sampling_params.temperature - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - top_p = sampling_params.top_p - min_p = sampling_params.min_p - - # k should not be greater than the vocab size. - top_k = min(sampling_params.top_k, vocab_size) - top_k = vocab_size if top_k < 1 else top_k - if temperature < _SAMPLING_EPS: - # NOTE: Zero temperature means deterministic sampling - # (i.e., greedy sampling or beam search). - # Set the temperature to 1 to avoid division by zero. - temperature = 1.0 - if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS - or top_k != vocab_size): - do_top_p_top_k = True - if not do_min_p and min_p > _SAMPLING_EPS: - do_min_p = True - if not do_penalties and (abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS): - do_penalties = True - - is_prompt = seq_group.is_prompt - if is_prompt and sampling_params.prompt_logprobs is not None: - # For tokens in the prompt that we only need to get - # their logprobs - query_len = seq_group.query_len - assert query_len is not None - prefill_len = len(seq_group.prompt_logprob_indices) - temperatures += [temperature] * prefill_len - top_ps += [top_p] * prefill_len - top_ks += [top_k] * prefill_len - min_ps += [min_p] * prefill_len - presence_penalties += [0] * prefill_len - frequency_penalties += [0] * prefill_len - repetition_penalties += [1] * prefill_len - - if seq_group.do_sample: - sample_lens = len(seq_group.sample_indices) - assert sample_lens >= len(seq_ids) - temperatures += [temperature] * sample_lens - top_ps += [top_p] * sample_lens - top_ks += [top_k] * sample_lens - min_ps += [min_p] * sample_lens - presence_penalties += [p] * sample_lens - frequency_penalties += [f] * sample_lens - repetition_penalties += [r] * sample_lens - - if do_penalties: - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - prefill_len = len(seq_group.prompt_logprob_indices) - prompt_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - output_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - if seq_group.do_sample: - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids_array) - output_tokens.append(seq_data.output_token_ids_array) - - sampling_tensors = SamplingTensors.from_lists( - temperatures, - top_ps, - top_ks, - min_ps, - presence_penalties, - frequency_penalties, - repetition_penalties, - prompt_tokens, - output_tokens, - vocab_size, - device, - dtype, - ) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) - - @classmethod - def from_lists( - cls, - temperatures: list[float], - top_ps: list[float], - top_ks: list[int], - min_ps: list[float], - presence_penalties: list[float], - frequency_penalties: list[float], - repetition_penalties: list[float], - prompt_tokens: list[array], - output_tokens: list[array], - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> "SamplingTensors": - # Note that the performance will be very bad without - # pinned memory. - pin_memory = is_pin_memory_available() - - do_penalties = prompt_tokens or output_tokens - - if do_penalties: - prompt_t = make_tensor_with_pad( - prompt_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - output_t = make_tensor_with_pad( - output_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - else: - empty_tensor = torch.empty(0, device=device, dtype=torch.long) - prompt_t = empty_tensor - output_t = empty_tensor - - temperatures_t = torch.tensor( - temperatures, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ps_t = torch.tensor( - top_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - min_ps_t = torch.tensor( - min_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - presence_penalties_t = torch.tensor( - presence_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - frequency_penalties_t = torch.tensor( - frequency_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - repetition_penalties_t = torch.tensor( - repetition_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ks_t = torch.tensor( - top_ks, - device="cpu", - dtype=torch.int, - pin_memory=pin_memory, - ) - # Because the memory is pinned, we can do non-blocking - # transfer to device. - - return cls( - temperatures=temperatures_t.to(device=device, non_blocking=True), - top_ps=top_ps_t.to(device=device, non_blocking=True), - top_ks=top_ks_t.to(device=device, non_blocking=True), - min_ps=min_ps_t.to(device=device, non_blocking=True), - presence_penalties=presence_penalties_t.to(device=device, - non_blocking=True), - frequency_penalties=frequency_penalties_t.to(device=device, - non_blocking=True), - repetition_penalties=repetition_penalties_t.to(device=device, - non_blocking=True), - prompt_tokens=prompt_t.to(device=device, non_blocking=True), - output_tokens=output_t.to(device=device, non_blocking=True), - ) + # Placeholder until it can be safely removed. + pass diff --git a/vllm/sequence.py b/vllm/sequence.py index 24114c0bb792e..a6c194fbac0b2 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,28 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sequence and its related classes.""" -import copy -import enum -from abc import ABC, abstractmethod -from array import array -from collections import defaultdict -from collections.abc import Mapping -from collections.abc import Sequence as GenericSequence -from dataclasses import dataclass, field -from functools import reduce -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union import msgspec import torch -from vllm.inputs import SingletonInputs -from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs -from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import RequestOutputKind, SamplingParams - if TYPE_CHECKING: - from vllm.lora.request import LoRARequest from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorOutput) else: @@ -34,50 +19,6 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" VLLM_INVALID_TOKEN_ID = -1 -def array_full(token_id: int, count: int): - """[`array`][] equivalent of [numpy.full][].""" - return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - - -class SequenceStatus(enum.IntEnum): - """Status of a sequence.""" - WAITING = 0 - RUNNING = 1 - SWAPPED = 2 - # Note: anything after SWAPPED (2) will be considered - # as a finished status. - FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 - - @staticmethod - def is_finished(status: "SequenceStatus") -> bool: - return status > SequenceStatus.SWAPPED - - @staticmethod - def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: - if status == SequenceStatus.FINISHED_STOPPED: - finish_reason = "stop" - elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: - finish_reason = "length" - elif status == SequenceStatus.FINISHED_ABORTED: - finish_reason = "abort" - elif status == SequenceStatus.FINISHED_IGNORED: - # The ignored sequences are the sequences whose prompt lengths - # are longer than the model's length cap. Therefore, the stop - # reason should also be "length" as in OpenAI API. - finish_reason = "length" - else: - finish_reason = None - return finish_reason - - -class SequenceStage(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - - @dataclass class RequestMetrics: """Metrics associated with a request. @@ -107,971 +48,12 @@ class RequestMetrics: model_execute_time: Optional[float] = None -class SequenceDataDelta( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta SequenceData to send to workers per step.""" - # A new token to be appended to existing SequenceData. - new_output_token_ids: list[int] - # Overwriting existing `cumulative_logprob` - new_cumulative_logprob: float - # Overwriting existing `num_computed_tokens`. - new_num_computed_tokens: int - # Overwriting existing `stage`. - new_stage: SequenceStage - - -class SequenceData(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence.""" - # NOTE: we cannot use Union[list, array] because msgspec cannot support - # union of 2 list types. - _prompt_token_ids: array - _output_token_ids: array = msgspec.field( - default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - - _prompt_embeds: Optional[torch.Tensor] = None - _output_embeds: Optional[torch.Tensor] = None - - ### The below fields should not be passed as an argument ### - _cumulative_logprob: float = 0.0 - _prompt_token_ids_tuple: tuple[int, - ...] = msgspec.field(default_factory=tuple) - # The number of tokens that are computed (that run against the model). - _num_computed_tokens: int = 0 - # The number of tokens with prefix cache hit. - _num_cached_tokens: int = 0 - _stage: SequenceStage = SequenceStage.PREFILL - _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) - _cached_all_token_embeds: Optional[torch.Tensor] = None - - # It is used to get delta input. It is reset when `get_delta_and_reset` - # is called. - _new_appended_tokens: list[int] = msgspec.field(default_factory=list) - - # It is used to compute mrope_position_ids. - _mrope_position_delta: Optional[int] = None - - @staticmethod - def from_prompt_token_counts( - *token_counts: tuple[int, int]) -> "SequenceData": - """ - Construct a [`SequenceData`][vllm.sequence.SequenceData] instance - by concatenating prompt token sequences. - - Each tuple represents one token sequence, expressed in the form - `(token_id, count)`. - """ - if len(token_counts) == 0: - return SequenceData.from_seqs([]) - - prompt_token_ids_arr = reduce( - array.__iadd__, - (array_full(token_id, count) for token_id, count in token_counts), - ) - - return SequenceData(prompt_token_ids_arr) - - @staticmethod - def from_seqs( - prompt_token_ids: GenericSequence[int], - output_token_ids: Optional[GenericSequence[int]] = None, - *, - prompt_embeds: Optional[torch.Tensor] = None, - ) -> "SequenceData": - """ - Construct a [`SequenceData`][vllm.sequence.SequenceData] instance - from prompt and output token sequences. - """ - prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - prompt_token_ids) - - if output_token_ids is None: - return SequenceData(prompt_token_ids_arr, - _prompt_embeds=prompt_embeds) - - output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - output_token_ids) - - return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr, - _prompt_embeds=prompt_embeds) - - def __post_init__(self) -> None: - assert self._prompt_token_ids.typecode == "l" - assert self._output_token_ids.typecode == "l" - self._prompt_token_ids_tuple: tuple[int, ...] = tuple( - self._prompt_token_ids) - self._update_cached_all_tokens() - if self._prompt_embeds is not None: - self._update_cached_all_token_embeds() - - def _update_cached_all_tokens(self): - assert isinstance(self._prompt_token_ids, array) - assert isinstance(self._output_token_ids, array) - self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + - self._output_token_ids) - - def _update_cached_all_token_embeds(self): - assert isinstance(self._prompt_embeds, torch.Tensor) - self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds - if self._output_embeds is not None: - self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, self._output_embeds), dim=0) - - @property - def cumulative_logprob(self) -> float: - """The cumulative log probability of the output.""" - return self._cumulative_logprob - - @property - def prompt_token_ids(self) -> tuple[int, ...]: - """The token IDs of the prompt.""" - return self._prompt_token_ids_tuple - - @prompt_token_ids.setter - def prompt_token_ids(self, new_prompt_token_ids) -> None: - raise NotImplementedError - - @property - def prompt_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - return self._prompt_token_ids - - @property - def output_token_ids(self) -> tuple[int, ...]: - """The token IDs of the output.""" - return tuple(self._output_token_ids) - - @output_token_ids.setter - def output_token_ids(self, - new_output_token_ids: GenericSequence[int]) -> None: - self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - new_output_token_ids) - self._update_cached_all_tokens() - - @property - def output_embeds(self) -> Optional[torch.Tensor]: - return self._output_embeds - - @output_embeds.setter - def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None: - self._output_token_embeds = new_output_token_embeds - self._update_cached_all_token_embeds() - - @property - def output_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - assert isinstance(self._output_token_ids, array) - return self._output_token_ids - - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self._prompt_embeds - - @prompt_embeds.setter - def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: - self._prompt_embeds = prompt_embeds - self._update_cached_all_token_embeds() - - @property - def mrope_position_delta(self) -> Optional[int]: - return self._mrope_position_delta - - @mrope_position_delta.setter - def mrope_position_delta(self, new_mrope_position_delta): - self._mrope_position_delta = new_mrope_position_delta - - def append_token_id(self, - token_id: int, - logprob: float, - token_embed: Optional[torch.Tensor] = None) -> None: - self._output_token_ids.append(token_id) - self._new_appended_tokens.append(token_id) - self._cached_all_token_ids.append(token_id) - self._cumulative_logprob += logprob - if token_embed is not None: - # Do not pass in with batch or sequence dimensions - assert token_embed.ndim == 1 - token_embed = token_embed.detach().cpu().unsqueeze(0) - if self._output_embeds is None: - self._output_embeds = token_embed - else: - self._output_embeds = torch.cat( - (self._output_embeds, token_embed), dim=0) - assert self._cached_all_token_embeds is not None - self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, - token_embed.to(device=self._cached_all_token_embeds.device)), - dim=0) - - def get_len(self) -> int: - return len(self._output_token_ids) + len(self._prompt_token_ids) - - def get_prompt_len(self) -> int: - return len(self._prompt_token_ids) - - def get_output_len(self) -> int: - return len(self._output_token_ids) - - def get_token_ids(self) -> list[int]: - return self._cached_all_token_ids - - def get_token_embeddings(self) -> Optional[torch.Tensor]: - return self._cached_all_token_embeds - - def get_prefix_token_ids( - self, num_tokens: int - ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: - """Get prefix tokens, and make the return value hashable""" - prompt_length = self.get_prompt_len() - if num_tokens > prompt_length: - return (self._prompt_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - else: - return (self._prompt_token_ids_tuple[:num_tokens], None) - - def get_num_computed_tokens(self) -> int: - """Return the number of prefill tokens that are already computed.""" - return self._num_computed_tokens - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - self._num_computed_tokens += num_new_computed_tokens - assert self._num_computed_tokens <= self.get_len(), ( - self._num_computed_tokens, self.get_len()) - # If all tokens are computed, it means it is in decoding phase. - if self.get_num_uncomputed_tokens() == 0: - self._stage = SequenceStage.DECODE - - def get_num_cached_tokens(self) -> int: - """Return the number of tokens with prefix cache hit.""" - return self._num_cached_tokens - - def update_num_cached_tokens(self, num_cached_tokens: int): - """Update the number of tokens with prefix cache hit.""" - self._num_cached_tokens = num_cached_tokens - - def reset_state_for_recompute(self) -> None: - """Reset the number of computed tokens from this sequence. It is - supposed to be called when a sequence needs to be started from - the beginning again (e.g., sequence is preempted). - """ - self._num_computed_tokens = 0 - self._stage = SequenceStage.PREFILL - self._new_appended_tokens = [] - - def get_num_uncomputed_tokens(self) -> int: - """Return the number of prefill tokens that are not computed.""" - # we use `get_len()` which includes prompt_len + output_len instead - # of prompt_len here. This is because during recompute we need to - # prefill for both prompt and output. - return self.get_len() - self.get_num_computed_tokens() - - def get_last_token_id(self) -> int: - if not self._output_token_ids: - return self._prompt_token_ids[-1] - return self._output_token_ids[-1] - - def get_prompt_token_ids(self) -> tuple[int, ...]: - return self.prompt_token_ids - - def get_output_token_ids(self) -> tuple[int, ...]: - return self.output_token_ids - - def get_delta_and_reset(self) -> SequenceDataDelta: - delta = SequenceDataDelta(self._new_appended_tokens, - self._cumulative_logprob, - self.get_num_computed_tokens(), self.stage) - # Reset delta state. - self._new_appended_tokens = [] - return delta - - def apply_delta(self, delta: SequenceDataDelta): - self._num_computed_tokens = delta.new_num_computed_tokens - self._cumulative_logprob = delta.new_cumulative_logprob - self._stage = delta.new_stage - self._output_token_ids.extend(delta.new_output_token_ids) - self._cached_all_token_ids.extend(delta.new_output_token_ids) - - @property - def stage(self) -> SequenceStage: - return self._stage - - def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self._prompt_token_ids}, " - f"prompt_embeds.shape=" - f"{getattr(self._prompt_embeds, 'shape', None)}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()})") - - -class Sequence: - """Stores the data, status, and block information of a sequence. - - The sequence is constructed from the - [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only) - or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] - (for encoder-decoder) instance passed in through the `inputs` - constructor argument. - - Args: - seq_id: The ID of the sequence. - inputs: The inputs of the sequence. - block_size: The block size of the sequence. Should be the same as the - block size used by the block manager and cache engine. - eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. - lora_request: LoRA request. - """ - - def __init__( - self, - seq_id: int, - inputs: SingletonInputs, - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.seq_id = seq_id - self.inputs = inputs - self.block_size = block_size - self.eos_token_id = eos_token_id - self.lora_request = lora_request - - self.data = SequenceData.from_seqs( - self.prompt_token_ids, - prompt_embeds=self.inputs["prompt_embeds"] - if self.inputs["type"] == "embeds" else None) - self.output_logprobs: SampleLogprobs = [] - self.output_text = "" - - self.status = SequenceStatus.WAITING - self.stop_reason: Union[int, str, None] = None - - # These are used to keep track of delta outputs - self._last_output_token_ids_offset: int = 0 - self._last_output_text_offset: int = 0 - - # Used for incremental detokenization - self.prefix_offset = 0 - self.read_offset = 0 - # Input + output tokens - self.tokens: Optional[list[str]] = None - - @property - def n_blocks(self) -> int: - return (self.get_len() + self.block_size - 1) // self.block_size - - @property - def prompt(self) -> Optional[str]: - if self.inputs["type"] == "embeds": - return None - return self.inputs.get("prompt") - - @property - def prompt_token_ids(self) -> list[int]: - if self.inputs["type"] == "embeds": - return [0] * len(self.inputs["prompt_embeds"]) - return self.inputs["prompt_token_ids"] - - @property - def multi_modal_data(self) -> MultiModalKwargs: - if self.inputs["type"] == "multimodal": - return self.inputs["mm_kwargs"].get_data() - - return MultiModalKwargs() - - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - if self.inputs["type"] == "multimodal": - return self.inputs["mm_placeholders"] - - return {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - def get_output_text_to_return(self, buffer_length: int, - delta: bool) -> str: - """If delta is True, only new text since the last call to - this method is returned""" - - # We return the full output text if the sequence is finished. - truncate = buffer_length and not self.is_finished() - if not delta: - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) - length = len(self.output_text) - if truncate: - length -= buffer_length - last_offset = self._last_output_text_offset - if last_offset < length: - self._last_output_text_offset = length - return self.output_text[last_offset:length] - return "" - - def get_output_token_ids_to_return( - self, delta: bool) -> Union[GenericSequence[int], int]: - """If delta is True, only new tokens since the last call to - this method are returned""" - if not delta: - return self.get_output_token_ids() - - output_len = self.get_output_len() - - # Get the number of new tokens - num_new_tokens = output_len - self._last_output_token_ids_offset - self._last_output_token_ids_offset = output_len - - # Return new tokens - if num_new_tokens == 1: - # Optimization for single decode token case - # (which is what we have most of the time) - return self.data._cached_all_token_ids[-1] - - if num_new_tokens == 0: - return [] - - return self.data._cached_all_token_ids[-num_new_tokens:] - - def hash_of_block(self, logical_idx: int) -> int: - # TODO This can produce incorrect hash when block size > prompt size - - # Compute the number of tokens in the sequence - # TODO: The current hashing function is O(L^2). We should optimize - # this in the future. - num_tokens = self.num_hashed_tokens_of_block(logical_idx) - hashed_tokens = self.data.get_prefix_token_ids(num_tokens) - return hash((hashed_tokens, self.lora_int_id)) - - def extra_hash(self) -> Optional[int]: - """ - This function computes an extra hash for a sequence, specifically - designed for prefix caching mode. The final sequence hash is determined - by applying token_ids from the sequence's blocks. - """ - if self.lora_int_id == 0: - return None - - # NOTE: If there are additional factors influencing the block aside from - # token_ids, include them as input parameters to the hash. - return hash(self.lora_int_id) - - def num_hashed_tokens_of_block(self, logical_idx: int): - return logical_idx * self.block_size + self.block_size - - def reset_state_for_recompute(self): - """Reset the sequence states for recomputation.""" - self.data.reset_state_for_recompute() - - def append_token_id(self, - token_id: int, - logprobs: dict[int, Logprob], - token_embed: Optional[torch.Tensor] = None) -> None: - assert token_id in logprobs - self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob, - token_embed) - - def get_len(self) -> int: - return self.data.get_len() - - def get_prompt_len(self) -> int: - return self.data.get_prompt_len() - - def get_output_len(self) -> int: - return self.data.get_output_len() - - def get_token_ids(self) -> list[int]: - return self.data.get_token_ids() - - def get_prompt_token_ids(self) -> tuple[int, ...]: - return self.data.get_prompt_token_ids() - - def get_last_token_id(self) -> int: - return self.data.get_last_token_id() - - def get_output_token_ids(self) -> tuple[int, ...]: - return self.data.get_output_token_ids() - - def get_cumulative_logprob(self) -> float: - return self.data.cumulative_logprob - - def is_finished(self) -> bool: - return SequenceStatus.is_finished(self.status) - - def fork(self, new_seq_id: int) -> "Sequence": - new_seq = copy.deepcopy(self) - new_seq.seq_id = new_seq_id - return new_seq - - def get_num_new_tokens(self) -> int: - """Get the number of new tokens to be computed. - - Returns: - The new number of tokens to be computed. I.e., 1 for decode, or - the remaining prompt size for prefill. - """ - if self.data.stage == SequenceStage.DECODE: - return 1 - return self.data.get_num_uncomputed_tokens() - - def get_num_computed_tokens(self) -> int: - return self.data.get_num_computed_tokens() - - def is_prefill(self) -> bool: - return self.data.stage == SequenceStage.PREFILL - - def __repr__(self) -> str: - return (f"Sequence(seq_id={self.seq_id}, " - f"status={self.status.name}, " - f"num_blocks={self.n_blocks})") - - -class SequenceGroupState(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Mutable state tied to a specific sequence group""" - - # for multi-step decoding - num_steps: int = 1 - current_step: int = 0 - - @property - def remaining_steps(self) -> int: - return self.num_steps - self.current_step - - -class SequenceGroup: - """A group of sequences that are generated from the same prompt. - - Args: - request_id: The ID of the request. - seqs: The list of sequences. - sampling_params: The sampling parameters used to generate the outputs. - arrival_time: The arrival time of the request. - lora_request: LoRA request. - pooling_params: The parameters used to generate the pooler - for a pooling model. - pooled_data: The extracted hidden states from a pooling model. - encoder_seq: Optional, the single encoder sequence. Should be None - unless you are working with an encoder/decoder model. - trace_headers: OpenTelemetry trace headers. - priority: User-defined priority of the request. - draft_size: The number of speculative tokens plus one from the target - model; equal to max number of tokens a step can generate - for single-draft speculative decoding but larger than - that for multi-draft SD (currently not supported). - """ - - def __init__(self, - request_id: str, - seqs: list[Sequence], - arrival_time: float, - sampling_params: Optional[SamplingParams] = None, - lora_request: Optional[LoRARequest] = None, - pooling_params: Optional[PoolingParams] = None, - pooled_data: Optional[torch.Tensor] = None, - encoder_seq: Optional[Sequence] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - draft_size: int = 1) -> None: - self.request_id = request_id - self.seqs = seqs - self.first_seq = seqs[0] - self.arrival_time = arrival_time - self.is_single_seq = len(seqs) == 1 - self.seqs_dict = {seq.seq_id: seq for seq in seqs} - - self.sampling_params = sampling_params - self.metrics = RequestMetrics(arrival_time=arrival_time, - last_token_time=arrival_time, - first_scheduled_time=None, - first_token_time=None, - time_in_queue=None) - self.last_token_latency = 0.0 - self.lora_request = lora_request - self.prompt_logprobs: Optional[PromptLogprobs] = None - self.state = SequenceGroupState() - self.pooling_params = pooling_params - self.pooled_data = pooled_data - self.encoder_seq = encoder_seq - self.trace_headers = trace_headers - self.priority = priority - - self.cached_request_output = None - - @property - def prompt(self) -> Optional[str]: - return self.first_seq.prompt - - @property - def prompt_token_ids(self) -> list[int]: - return self.first_seq.prompt_token_ids - - @property - def encoder_prompt(self) -> Optional[str]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt is distinct - # from the decoder's. - return (self.encoder_seq.prompt - if self.encoder_seq is not None else None) - - @property - def encoder_prompt_token_ids(self) -> Optional[list[int]]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt token ids are - # distinct from the decoder's. - return (self.encoder_seq.prompt_token_ids - if self.encoder_seq is not None else None) - - @property - def multi_modal_data(self) -> MultiModalKwargs: - if self.first_seq.multi_modal_data: - return self.first_seq.multi_modal_data - elif self.encoder_seq is not None: - return self.encoder_seq.multi_modal_data - return MultiModalKwargs() - - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - if self.first_seq.multi_modal_data: - return self.first_seq.multi_modal_placeholders - elif self.encoder_seq is not None: - return self.encoder_seq.multi_modal_placeholders - return {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - def set_last_token_time(self, now: float) -> None: - """Sets the last token time for Request level timings.""" - # If still in prefill phase, assertion fails. - assert not self.is_prefill(), ( - "seq_group.set_last_token_time() should not be called " - "if the seq_group is in prefill phase.") - self.last_token_latency = now - self.metrics.last_token_time - self.metrics.last_token_time = now - - def get_last_token_latency(self) -> float: - """Returns the latency of the last token.""" - assert not self.is_prefill(), ( - "seq_group.get_last_token_latency() should not be called " - "if the seq_group is in prefill phase.") - return self.last_token_latency - - def maybe_set_first_token_time(self, time: float) -> None: - """Sets the first token time for Request level timings.""" - # Note: in a case where a sequence_group is swapped and - # recomputed, the time between iterations is counted - # in TPOT, rather than recalculating TTFT (since from the ) - # POV of the user, there is simply a long generation delay. - if (self.metrics.first_token_time is None - and self.first_seq.get_output_len() == 1): - self.metrics.first_token_time = time - - def maybe_set_first_scheduled_time(self, time: float) -> None: - """Sets the first scheduled time and time in queue for Request - level timings.""" - if self.metrics.first_scheduled_time is None: - self.metrics.first_scheduled_time = time - self.metrics.time_in_queue = time - self.metrics.arrival_time - - def set_finished_time(self, time: Optional[float]) -> None: - """Sets the finished time for Request level timings.""" - self.metrics.finished_time = time - - def get_max_num_running_seqs(self) -> int: - """The maximum number of sequences running in parallel in the remaining - lifetime of the request.""" - if self.is_single_seq: - return 0 if self.first_seq.is_finished() else 1 - return self.num_seqs() - self.num_finished_seqs() - - def get_seqs( - self, - status: Optional[SequenceStatus] = None, - ) -> list[Sequence]: - if status is None: - return self.seqs - - if self.is_single_seq: - return self.seqs if self.first_seq.status == status else [] - - return [seq for seq in self.seqs if seq.status == status] - - def is_encoder_decoder(self) -> bool: - return self.encoder_seq is not None - - def get_encoder_seq(self) -> Optional[Sequence]: - return self.encoder_seq - - def get_finished_seqs(self) -> list[Sequence]: - if self.is_single_seq: - return self.seqs if self.first_seq.is_finished() else [] - - return [seq for seq in self.seqs if seq.is_finished()] - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - for seq in self.seqs: - if not seq.is_finished(): - seq.data.update_num_computed_tokens(num_new_computed_tokens) - - def get_num_uncomputed_tokens(self) -> int: - num_uncomputed_tokens = 0 - for seq in self.seqs: - if not seq.is_finished(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() - return num_uncomputed_tokens - - def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: - # Optimization. We don't need to call get_seqs if we don't need to - # filter by states. - if status is None: - return len(self.seqs) - - if self.is_single_seq: - return 1 if self.seqs[0].status == status else 0 - - return len(self.get_seqs(status)) - - def num_finished_seqs(self) -> int: - if self.is_single_seq: - return 1 if self.seqs[0].is_finished() else 0 - return len(self.get_finished_seqs()) - - def is_finished(self) -> bool: - if self.is_single_seq: - return self.first_seq.is_finished() - return all(seq.is_finished() for seq in self.seqs) - - def is_prefill(self) -> bool: - return self.first_seq.is_prefill() - - def __repr__(self) -> str: - return (f"SequenceGroup(request_id={self.request_id}, " - f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs)})") - - def uses_prompt_embeds(self) -> bool: - """Returns True if the sequence group uses input embeds.""" - return any(seq.data.prompt_embeds is not None for seq in self.seqs) - - -class SequenceGroupMetadataDelta( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta of SequenceGroupMetadata. - - After sending the first SequenceGroupMetadata, vLLM scheduler - only sends delta to reduce the data payload size. - """ - seq_data_delta: dict[int, SequenceDataDelta] - request_id: str - block_tables: dict[int, list[int]] - is_prompt: bool - do_sample: bool = True - token_chunk_size: Optional[int] = None - computed_block_nums: Optional[list[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - - -class SequenceGroupMetadata( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Metadata for a sequence group. Used to create `AttentionMetadata`. - - Attributes: - request_id: The ID of the request. - is_prompt: Whether the request is at prompt stage. - seq_data: The sequence data. (Seq id -> sequence data) - sampling_params: The sampling parameters used to generate the outputs. - block_tables: The block tables. (Seq id -> list of physical block - numbers) - do_sample: True if sampling is required. Sampling is not required when - e.g., prefill is chunked, and the current iteration only computes - query tokens for prefill, we don't need sampling. - pooling_params: Pooling parameters. - lora_request: LoRA request. - computed_block_nums: The block numbers that are already computed, - used in prefix caching. - state: Internal state tied to this sequence group. - token_type_ids: Token type IDs. - multi_modal_data: Multi modal data. - multi_modal_placeholders: Multi modal placeholders. - encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. - cross_block_table: Optional cross-attention block table associated - with the encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. - """ - - request_id: str - is_prompt: bool - seq_data: dict[int, SequenceData] - sampling_params: Optional[SamplingParams] - block_tables: dict[int, list[int]] - do_sample: bool = True - pooling_params: Optional[PoolingParams] = None - lora_request: Optional[LoRARequest] = None - computed_block_nums: Optional[list[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - multi_modal_data: Optional[MultiModalKwargs] = None - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - encoder_seq_data: Optional[SequenceData] = None - cross_block_table: Optional[list[int]] = None - token_chunk_size: Optional[int] = None - - ### Stateful fields that are lazily defined. ### - # The number of speculative tokens adopted in this request. - # None means specuative decoding is not used. - # Zero means speculative decoding is disabled for some reasons. - # TODO: We should maintain this states out of the sequence group. - num_speculative_tokens: Optional[int] = None - - def __post_init__(self): - if self.seq_data is not None and self.token_chunk_size is None: - if self.is_prompt: - self.token_chunk_size = next(iter( - self.seq_data.values())).get_len() - else: - self.token_chunk_size = 1 - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - # Multi-Step Chunked-Prefill property - @property - def is_single_step_prompt(self) -> bool: - # do_sample is true, only when the token_chunk_size matches the - # num_uncomputed_tokens of the sequence. This indicates that - # the prompt will finish processing in a single `execute_model` - # step. - return self.is_prompt and self.do_sample - - def get_first_seq_id(self) -> int: - # This is an efficient way of fetching the seq_id when - # we know this SequenceGroup has only one sequence. - return next(iter(self.seq_data)) - - def apply_delta(self, - sequence_group_metadata_delta: SequenceGroupMetadataDelta): - for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): - self.seq_data[id].apply_delta(delta) - assert self.request_id == sequence_group_metadata_delta.request_id - self.block_tables = sequence_group_metadata_delta.block_tables - self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size - self.do_sample = sequence_group_metadata_delta.do_sample - self.is_prompt = sequence_group_metadata_delta.is_prompt - - def finish_step(self) -> None: - assert self.state is not None - assert self.state.current_step < self.state.num_steps, \ - f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa - self.state.current_step += 1 - - -class SequenceOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a sequence. - - Attributes: - parent_seq_id: The ID of the parent sequence (for forking in beam - search). - output_token: The output token ID. - logprobs: The logprobs of the output token. - (Token id -> logP(x_i+1 | x_0, ..., x_i)) - output_embed: Optional output embedding tensor. - """ - parent_seq_id: int - output_token: int - logprobs: dict[int, Logprob] - output_embed: Optional[torch.Tensor] = None - - def __repr__(self) -> str: - output_embed_shape = \ - self.output_embed.shape if self.output_embed is not None else None - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " - f"output_token={self.output_token}, " - f"output_embed.shape={output_embed_shape}, " - f"logprobs={self.logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceOutput): - raise NotImplementedError() - equal = (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token) - log_probs_equal = other.logprobs == self.logprobs - return equal and log_probs_equal - - -class SequenceGroupOutput(ABC): - """The base class for model outputs associated with a sequence group.""" - - @abstractmethod - def __repr__(self) -> str: - pass - - @abstractmethod - def __eq__(self, other: object) -> bool: - pass - - -class CompletionSequenceGroupOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a completion sequence group.""" - __metaclass__ = SequenceGroupOutput - samples: list[SequenceOutput] - # Prompt logprob for each prompt query token. - prompt_logprobs: Optional[PromptLogprobs] - step_index: Optional[int] = 0 - - def __repr__(self) -> str: - return (f"CompletionSequenceGroupOutput(samples={self.samples}, " - f"prompt_logprobs={self.prompt_logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CompletionSequenceGroupOutput): - raise NotImplementedError() - return (self.samples == other.samples - and self.prompt_logprobs == other.prompt_logprobs) - - class PoolingSequenceGroupOutput( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True, # type: ignore[call-arg] ): """The model output associated with a pooling sequence group.""" - __metaclass__ = SequenceGroupOutput # Annotated as Any to be compatible with msgspec # The actual type is in SequenceGroup.pooled_data data: Any @@ -1161,305 +143,9 @@ class PoolerOutput( self.__class__) and self.outputs == other.outputs -def get_all_seq_ids( - seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] - - -def get_all_seq_ids_and_request_ids( - seq_group_metadata_list: list[SequenceGroupMetadata] -) -> tuple[list[int], dict[str, set[int]]]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - seq_ids: list[int] = [] - request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set) - for sg in seq_group_metadata_list: - for seq_id in sg.seq_data: - seq_ids.append(seq_id) - request_id_seq_ids_mapping[sg.request_id].add(seq_id) - return seq_ids, request_id_seq_ids_mapping - - -class HiddenStates(msgspec.Struct, array_like=True, - omit_defaults=True): # type: ignore[call-arg] - """Hidden states corresponding to in-progress sequences. - Used in speculative decoding to pass hidden states from - the target model to the proposer model. - - seq_ids are the sequence ids of each entry of the batch - dimension of the hidden_states tensor""" - # Scorer hidden states. For prefill step, it is used for hidden states of - # all tokens, whereas for decode step, it is used for last accepted tokens. - hidden_states: torch.Tensor - # The sequence group metadata list. Only needed for decode step. - seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None - # Scorer hidden states of the 2nd last token proposed by the proposer ( - # irrespective of whether it was accepted or not). Only used for cases when - # last proposed token is accepted (i.e., in case of bonus tokens). For the - # case of no bonus tokens, these are ignored. - second_last_token_hidden_states: Optional[torch.Tensor] = None - - _seq_ids: list[int] = msgspec.field(default_factory=list) - - def __post_init__(self): - if self.seq_group_metadata_list is not None: - assert len(self.seq_group_metadata_list) == len(self.hidden_states) - self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) - - @property - def seq_ids(self) -> list[int]: - return self._seq_ids - - def update(self, - hidden_states: torch.Tensor, - seq_group_metadata_list: list[SequenceGroupMetadata], - second_last_token_hidden_states: Optional[torch.Tensor] = None): - """Update hidden states from target model invocation. Only used for - decode steps""" - assert len(seq_group_metadata_list) == len(hidden_states) - self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) - self.hidden_states = torch.cat([self.hidden_states, hidden_states]) - - if self.second_last_token_hidden_states is not None: - # Adding dummy hidden_states to this to maintain same shape - self.second_last_token_hidden_states = torch.cat([ - self.second_last_token_hidden_states, - torch.zeros_like(hidden_states) - if second_last_token_hidden_states is None else - second_last_token_hidden_states - ]) - - def prune(self, - seq_group_metadata_list: list[SequenceGroupMetadata]) -> None: - """Prune to provided list of sequence ids. Only used for decode steps. - """ - # Currently this prunes all seq_ids not present in - # seq_group_metadata_list which might cause problems where a sequence - # may be "paused" then "resumed" later. This should only prune sequences - # which are confirmed to be aborted. - seq_ids = get_all_seq_ids(seq_group_metadata_list) - # Only keep sequence IDs that exist in self._seq_ids - seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids] - if seq_ids != self._seq_ids: - # Batch contents changed - prune removed sequences. - index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] - self.hidden_states = self.hidden_states[index] - if self.second_last_token_hidden_states is not None: - self.second_last_token_hidden_states = self\ - .second_last_token_hidden_states[index] - self._seq_ids = seq_ids - - def expand_with_bonus_tokens( - self, seq_with_bonus_token_in_last_step: set) -> None: - """Expand hidden states for sequences with bonus tokens. This is in - alignment with `MultiStepWorker._expand_execute_model_request`.""" - if self.second_last_token_hidden_states is None \ - or not seq_with_bonus_token_in_last_step: - return - - index = [] - for seq_id in self._seq_ids: - i = self._seq_ids.index(seq_id) - if seq_id in seq_with_bonus_token_in_last_step: - index.append(i + len(self._seq_ids)) - index.append(i) - - self.hidden_states = torch.cat( - [self.hidden_states, self.second_last_token_hidden_states])[index] - - class ExecuteModelRequest( msgspec.Struct, array_like=True, # type: ignore[call-arg] omit_defaults=True): # type: ignore[call-arg] - """The model execution request, containing CPU metadata only. The LLM - engine should create an instance of this class for each request batch.""" - # The sequence group metadata list. - seq_group_metadata_list: list[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: list[tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: list[tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to copy. Source to dest block. - blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list) - # Virtual engine ID for pipeline parallel. - virtual_engine: int = 0 - # The number of slots for lookahead decoding. - num_lookahead_slots: int = 0 - # The number of requests in the running queue. - running_queue_size: int = 0 - # Optional hidden states from prior step. - previous_hidden_states: Optional[HiddenStates] = None - # The number of forward steps to run. - num_steps: int = 1 - # Finished request ids since last step. - finished_requests_ids: list[str] = msgspec.field(default_factory=list) - # The last sampled token ids for multi step decoding. - last_sampled_token_ids: Optional[torch.Tensor] = None - # Async callback - async_callback: Optional[Callable] = None - - @property - def is_last_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.remaining_steps == 1 - - @property - def current_step(self) -> int: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - state = self.seq_group_metadata_list[0].state - assert state is not None - return state.current_step - - def clone( - self, seq_group_metadata_list: list[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - ) -> "ExecuteModelRequest": - """Clone the request with a new sequence group metadata list.""" - return ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=self.blocks_to_swap_in.copy(), - blocks_to_swap_out=self.blocks_to_swap_out.copy(), - blocks_to_copy=self.blocks_to_copy.copy(), - virtual_engine=self.virtual_engine, - num_lookahead_slots=self.num_lookahead_slots, - running_queue_size=self.running_queue_size, - previous_hidden_states=self.previous_hidden_states, - num_steps=self.num_steps, - finished_requests_ids=self.finished_requests_ids, - last_sampled_token_ids=self.last_sampled_token_ids.clone() - if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) - - -@dataclass -class SequenceGroupBase: - group_id: str # the original request id before splitting - - assembled_seq_group: Optional[SequenceGroup] = None - - # seq id to a unique index inside this group - seq_id_to_index: dict[str, int] = field(default_factory=dict) - - # seq ids to be finished - to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict) - - # seq id to finished sequences - finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict) - - streaming: bool = False - - output_produced: bool = False - - @staticmethod - def add_request(request_id: str, engine, params, *args, **kwargs): - """When we are ready to add a request with request_id and params - into the engine, we can split the request into multiple requests. - """ - raise NotImplementedError - - def finish_seq(self, seq: SequenceGroup): - """The sequence `seq` finishes, we should record the information. - """ - del self.to_be_finished[seq.request_id] - self.finished_reqs[seq.request_id] = seq - - def maybe_assemble_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - """Assemble the sequence group, for producing the final - output, or adding request in the engine again. - """ - raise NotImplementedError - - -class ParallelSampleSequenceGroup(SequenceGroupBase): - - @staticmethod - def add_request(request_id: str, engine, params, **kwargs): - original_params = params - group = ParallelSampleSequenceGroup(request_id) - seqs = [] - for i in range(original_params.n): - request_id_i = f"{request_id}_parallel_sample_{i}" - group.seq_id_to_index[request_id_i] = i - params = original_params.clone() - params.n = 1 - if params.seed is not None: - params.seed += i - seq_group = engine._add_processed_request( - request_id_i, - params=params, - **kwargs, - ) # type: ignore - assert seq_group is not None - engine.seq_id_to_seq_group[request_id_i] = group - group.to_be_finished[request_id_i] = seq_group - seqs.append(seq_group.seqs[0]) - - # for parallel sampling, the `assembled_seq_group` is always - # available, since we have all the sequences ready, and they - # will not change. - group.assembled_seq_group = SequenceGroup( - request_id=request_id, - seqs=seqs, - arrival_time=seq_group.arrival_time, - sampling_params=original_params, - lora_request=seq_group.lora_request, - pooling_params=seq_group.pooling_params, - pooled_data=seq_group.pooled_data, - encoder_seq=seq_group.encoder_seq, - trace_headers=seq_group.trace_headers, - priority=seq_group.priority, - ) - - group.streaming = params.output_kind == RequestOutputKind.DELTA - group.output_produced = False - - def maybe_assemble_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - - # in the streaming mode, we will return the assembled sequence - # for the first remaining sequence, and then return None for the - # rest of sequences - if self.streaming: - first_remaining_id = next(iter(self.to_be_finished)) - if seq_group.request_id == first_remaining_id: - return self.assembled_seq_group - return None - - # in the non-streaming mode, we will return the assembled sequence - # when the last sequences finishes, and then return None for the - # rest of the time - if (len(self.to_be_finished) == 1 - and seq_group.request_id in self.to_be_finished - and seq_group.is_finished()): - assert self.assembled_seq_group is not None - params = self.assembled_seq_group.sampling_params - assert isinstance(params, SamplingParams) - if not self.output_produced: - self.output_produced = True - if params._real_n is not None: - # Get the top-n sequences. - n = params._real_n or params.n - seqs = self.assembled_seq_group.seqs - sorting_key = lambda seq: seq.get_cumulative_logprob() - sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) - top_n_seqs = sorted_seqs[:n] - self.assembled_seq_group.seqs = top_n_seqs - return self.assembled_seq_group - if self.output_produced: - return None - return None + # Placeholder. Remove. + pass diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py deleted file mode 100644 index e2d2846a28073..0000000000000 --- a/vllm/transformers_utils/detokenizer.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -from vllm.logprobs import Logprob -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence, - SequenceGroup) - -from .detokenizer_utils import (convert_prompt_ids_to_tokens, - detokenize_incrementally) -from .tokenizer import AnyTokenizer - - -class Detokenizer: - """Provides methods to decode the output of a model into text.""" - - def __init__(self, tokenizer: AnyTokenizer): - self.tokenizer = tokenizer - - def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, - prompt_logprobs: list[Optional[dict[ - int, Logprob]]], - position_offset: int) -> None: - """Decodes the logprobs for the prompt of a sequence group. - - Args: - seq_group: The sequence group to decode. - prompt_logprobs: The logprobs to decode. - position_offset: Offset of the first index of the logprobs - relative to the start of the sequence (for chunked prefill). - - Returns: - The prompt logprobs with the decoded tokens. - """ - prms = seq_group.sampling_params - assert prms is not None - - # We can pick any sequence for the prompt. - seq = seq_group.get_seqs()[0] - # Only prompt, without the generated token. - all_token_ids = seq.get_token_ids() - prompt_token_ids = all_token_ids[:-1] - prefix_offset = 0 - read_offset = 0 - next_iter_prefix_offset = 0 - next_iter_read_offset = 0 - next_iter_tokens: list[str] = [] - prev_tokens = None - - for token_position_in_logprob, prompt_logprobs_for_token in enumerate( - prompt_logprobs): - - # Absolute token position equals the index in the logprobs - # list plus the offset of the entire logprobs list relative - # to the start of the sequence. - token_position = token_position_in_logprob + position_offset - if not prompt_logprobs_for_token: - continue - for token_id, sample_logprob in prompt_logprobs_for_token.items(): - if (sample_logprob.decoded_token is None - and token_id != VLLM_INVALID_TOKEN_ID): - prompt_token_ids_with_token = ( - prompt_token_ids[:token_position] + [token_id]) - (new_tokens, new_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=prompt_token_ids_with_token, - prev_tokens=prev_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - - sample_logprob.decoded_token = new_text - - # Use the offsets & prev tokens corresponding to - # real tokens to ensure detokenization is consistent - # actual with prompt. - if token_id == all_token_ids[token_position]: - next_iter_prefix_offset = new_prefix_offset - next_iter_read_offset = new_read_offset - next_iter_tokens = new_tokens - - # Advance to the next token position. - prefix_offset = next_iter_prefix_offset - read_offset = next_iter_read_offset - if prev_tokens is None: - prev_tokens = next_iter_tokens.copy() - else: - prev_tokens.extend(next_iter_tokens) - - def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> int: - """Decodes the new token for a sequence. In-place operation. - - Args: - seq: The sequence to decode. - prms: The sampling parameters used to generate the sequence. - - Returns: - The number of characters added to the output text. - """ - all_input_ids = seq.get_token_ids() - token_id_generated_this_iteration = all_input_ids[-1] - - # Convert prompt token IDs to tokens if necessary. - # Do it here so that we don't have to repeat this - # computation for each logprob. - if seq.tokens is None: - (seq.tokens, seq.prefix_offset, - seq.read_offset) = convert_prompt_ids_to_tokens( - tokenizer=self.tokenizer, - prompt_ids=all_input_ids[:-1], - skip_special_tokens=prms.skip_special_tokens, - ) - - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=all_input_ids, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) - - # Decode logprobs - logprobs = seq.output_logprobs[-1] - if logprobs: - previous_tokens = all_input_ids[:-1] - for token_id, sample_logprob in logprobs.items(): - # If the token was generated this iteration, - # use the provided text. - if token_id == token_id_generated_this_iteration: - sample_logprob.decoded_token = new_decoded_token_text - continue - - if (sample_logprob.decoded_token is None - and token_id != VLLM_INVALID_TOKEN_ID): - all_input_ids_with_logprob = previous_tokens + [token_id] - (_, new_text, _, _) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=all_input_ids_with_logprob, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - sample_logprob.decoded_token = new_text - - seq.tokens.extend(new_tokens) - seq.prefix_offset = prefix_offset - seq.read_offset = read_offset - seq.output_text += new_decoded_token_text - - return len(new_decoded_token_text) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index eaab976bf7f75..20fabef4f19b9 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -11,12 +11,12 @@ import torch.nn as nn from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, run_method, update_environment_variables, warn_for_unimplemented_methods) +from vllm.v1.outputs import SamplerOutput logger = init_logger(__name__)