From 5685370271d7f3e8222e26efb854e72e826b9af7 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 1 Sep 2025 12:07:53 -0700
Subject: [PATCH] [Chore][V0 Deprecation] Move LogProb to a separate file (#24055)

Signed-off-by: Woosuk Kwon
---
 vllm/beam_search.py                           |  2 +-
 vllm/entrypoints/openai/protocol.py           |  2 +-
 vllm/entrypoints/openai/serving_chat.py       |  2 +-
 vllm/entrypoints/openai/serving_completion.py |  2 +-
 vllm/entrypoints/openai/serving_engine.py     |  2 +-
 vllm/entrypoints/openai/serving_responses.py  |  4 +--
 vllm/logprobs.py                              | 28 +++++++++++++++++++
 vllm/model_executor/layers/sampler.py         |  4 +--
 vllm/model_executor/model_loader/neuron.py    |  4 +--
 .../model_loader/neuronx_distributed.py       |  4 +--
 vllm/outputs.py                               |  5 ++--
 vllm/sequence.py                              | 25 +----------------
 vllm/transformers_utils/detokenizer.py        |  5 ++--
 vllm/v1/engine/logprobs.py                    |  2 +-
 14 files changed, 49 insertions(+), 42 deletions(-)
 create mode 100644 vllm/logprobs.py

diff --git a/vllm/beam_search.py b/vllm/beam_search.py
index 5a2e79e1b5c7..01124872e98c 100644
--- a/vllm/beam_search.py
+++ b/vllm/beam_search.py
@@ -4,8 +4,8 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional, Union
 
+from vllm.logprobs import Logprob
 from vllm.lora.request import LoRARequest
-from vllm.sequence import Logprob
 
 if TYPE_CHECKING:
     from vllm.multimodal import MultiModalDataDict
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 30c3a8269615..488102232562 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -43,10 +43,10 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
 from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                           ScoreMultiModalParam)
 from vllm.logger import init_logger
+from vllm.logprobs import Logprob
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
-from vllm.sequence import Logprob
 from vllm.utils import random_uuid, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 6300d0758c3d..35edd2f85cd0 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -43,10 +43,10 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
 from vllm.entrypoints.utils import get_max_tokens
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
+from vllm.logprobs import Logprob
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import BeamSearchParams, SamplingParams
-from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
                                                 truncate_tool_call_ids,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 11effba8f9eb..b26140d4b9d7 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -38,9 +38,9 @@ from vllm.entrypoints.utils import get_max_tokens
 from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                               is_tokens_prompt)
 from vllm.logger import init_logger
+from vllm.logprobs import Logprob
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
-from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import as_list, merge_async_iterators
 
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index b6a18760115a..796b8ab5fc2c 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -67,13 +67,13 @@ from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
+from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
     MultiModalDataDict, MultiModalUUIDDict)
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import BeamSearchParams, SamplingParams
-from vllm.sequence import Logprob, PromptLogprobs
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 899cb07b2b37..6a676cfe1b38 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -58,11 +58,11 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.tool_server import MCPToolServer, ToolServer
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
+from vllm.logprobs import Logprob as SampleLogprob
+from vllm.logprobs import SampleLogprobs
 from vllm.outputs import CompletionOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import Logprob as SampleLogprob
-from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import random_uuid
 
diff --git a/vllm/logprobs.py b/vllm/logprobs.py
new file mode 100644
index 000000000000..e58ca142c00a
--- /dev/null
+++ b/vllm/logprobs.py
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+from typing import Optional
+
+
+# We use dataclass for now because it is used for
+# openai server output, and msgspec is not serializable.
+# TODO(sang): Fix it.
+@dataclass
+class Logprob:
+    """Infos for supporting OpenAI compatible logprobs and token ranks.
+
+    Attributes:
+        logprob: The logprob of chosen token
+        rank: The vocab rank of chosen token (>=1)
+        decoded_token: The decoded chosen token index
+    """
+    logprob: float
+    rank: Optional[int] = None
+    decoded_token: Optional[str] = None
+
+
+# {token_id -> logprob} per each sequence group. None if the corresponding
+# sequence group doesn't require prompt logprob.
+PromptLogprobs = list[Optional[dict[int, Logprob]]]
+# {token_id -> logprob} for each sequence group.
+SampleLogprobs = list[dict[int, Logprob]]
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index e77eb637c894..829dd82b0bd4 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -13,14 +13,14 @@ import torch
 import torch.nn as nn
 
 import vllm.envs as envs
+from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
 from vllm.model_executor.layers.utils import apply_penalties
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)
 from vllm.sampling_params import SamplingType
 from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
-                           CompletionSequenceGroupOutput, Logprob,
-                           PromptLogprobs, SampleLogprobs, SequenceOutput)
+                           CompletionSequenceGroupOutput, SequenceOutput)
 
 if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
     # yapf: disable
diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index fad97aba84b6..ee484e9a7b0a 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -14,12 +14,12 @@ from transformers import PretrainedConfig
 
 from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
                          SpeculativeConfig)
+from vllm.logprobs import Logprob
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
-                           SequenceOutput)
+from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput
 
 TORCH_DTYPE_TO_NEURON_AMP = {
     "auto": "f32",
diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py
index f450961c64ff..34bf43fe7b57 100644
--- a/vllm/model_executor/model_loader/neuronx_distributed.py
+++ b/vllm/model_executor/model_loader/neuronx_distributed.py
@@ -27,11 +27,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
 from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.logger import init_logger
+from vllm.logprobs import Logprob
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
-                           SequenceOutput)
+from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput
 # yapf: enable
 
 logger = init_logger(__name__)
diff --git a/vllm/outputs.py b/vllm/outputs.py
index acdb2f89ce73..64bcfd472f2a 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -11,11 +11,12 @@ import torch
 from typing_extensions import TypeVar
 
 from vllm.logger import init_logger
+from vllm.logprobs import PromptLogprobs, SampleLogprobs
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalPlaceholderDict
 from vllm.sampling_params import RequestOutputKind
-from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
-                           SequenceGroup, SequenceGroupBase, SequenceStatus)
+from vllm.sequence import (RequestMetrics, SequenceGroup, SequenceGroupBase,
+                           SequenceStatus)
 
 logger = init_logger(__name__)
 
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 7b48b7be9f51..4b8e1f4641f7 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -16,6 +16,7 @@ import msgspec
 import torch
 
 from vllm.inputs import SingletonInputs
+from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
 from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -38,30 +39,6 @@ def array_full(token_id: int, count: int):
     return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
 
 
-# We use dataclass for now because it is used for
-# openai server output, and msgspec is not serializable.
-# TODO(sang): Fix it.
-@dataclass
-class Logprob:
-    """Infos for supporting OpenAI compatible logprobs and token ranks.
-
-    Attributes:
-        logprob: The logprob of chosen token
-        rank: The vocab rank of chosen token (>=1)
-        decoded_token: The decoded chosen token index
-    """
-    logprob: float
-    rank: Optional[int] = None
-    decoded_token: Optional[str] = None
-
-
-# {token_id -> logprob} per each sequence group. None if the corresponding
-# sequence group doesn't require prompt logprob.
-PromptLogprobs = list[Optional[dict[int, Logprob]]]
-# {token_id -> logprob} for each sequence group.
-SampleLogprobs = list[dict[int, Logprob]]
-
-
 class SequenceStatus(enum.IntEnum):
     """Status of a sequence."""
     WAITING = 0
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index 380c62a141f0..56b01ecf78c4 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -3,8 +3,9 @@
 
 from typing import Optional
 
-from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
-                           Sequence, SequenceGroup)
+from vllm.logprobs import Logprob
+from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence,
+                           SequenceGroup)
 
 from .detokenizer_utils import (convert_prompt_ids_to_tokens,
                                 detokenize_incrementally)
diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py
index 3de7fa6889e5..133122b6fcc0 100644
--- a/vllm/v1/engine/logprobs.py
+++ b/vllm/v1/engine/logprobs.py
@@ -7,7 +7,7 @@ from dataclasses import dataclass
 from typing import Optional
 
 from vllm.logger import init_logger
-from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_ids_list_to_tokens)
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
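
The patch only relocates the Logprob dataclass and the PromptLogprobs/SampleLogprobs aliases from vllm.sequence into the new vllm.logprobs module; their shape is unchanged, so downstream code only updates the import path. The sketch below is illustrative only and is not part of the diff above: it assumes a vLLM build that includes this commit, and the token ids, values, and variable names are invented for the example.

# Illustrative sketch, not part of the patch: assumes Logprob now lives in
# vllm.logprobs (as introduced by this commit). Token ids and values are
# made up.
from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs

# One {token_id -> Logprob} dict per prompt position; an entry is None
# where no prompt logprob is required.
prompt_logprobs: PromptLogprobs = [
    None,
    {1734: Logprob(logprob=-0.25, rank=1, decoded_token=" the")},
]

# One {token_id -> Logprob} dict per sampled token.
sample_logprobs: SampleLogprobs = [
    {42: Logprob(logprob=-1.05, rank=2, decoded_token=" answer")},
]

for position, candidates in enumerate(sample_logprobs):
    for token_id, lp in candidates.items():
        print(position, token_id, lp.logprob, lp.rank, lp.decoded_token)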