vllm/vllm/outputs.py
Woosuk Kwon 20235c1822 [V0 Deprecation] Remove from_seq_group methods (#25330)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-03 13:35:53 -07:00

325 lines
12 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass
from typing import Any, Generic, Optional, Union
import torch
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sequence import RequestMetrics
# Module-level logger. RequestOutput.__init__ calls its `warning_once` method
# to report unexpected keyword arguments that it ignores.
logger = init_logger(__name__)
@dataclass
class CompletionOutput:
    """A single generated completion belonging to a request.

    Attributes:
        index: Position of this output within the request.
        text: The generated output text.
        token_ids: Token IDs of the generated output text.
        cumulative_logprob: Cumulative log probability of the generated
            output text.
        logprobs: Log probabilities of the top probability words at each
            position, if logprobs were requested.
        finish_reason: Why the sequence finished; ``None`` while it has
            not finished.
        stop_reason: The stop string or token id that caused the
            completion to stop. ``None`` if the completion finished for
            some other reason, including encountering the EOS token.
        lora_request: The LoRA request that was used to generate the
            output.
    """

    index: int
    text: str
    token_ids: GenericSequence[int]
    cumulative_logprob: Optional[float]
    logprobs: Optional[SampleLogprobs]
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    lora_request: Optional[LoRARequest] = None

    def finished(self) -> bool:
        """Return ``True`` once a finish reason has been recorded."""
        still_running = self.finish_reason is None
        return not still_running

    def __repr__(self) -> str:
        """Render every field except ``lora_request``."""
        rendered_fields = (
            f"index={self.index}",
            f"text={self.text!r}",
            f"token_ids={self.token_ids}",
            f"cumulative_logprob={self.cumulative_logprob}",
            f"logprobs={self.logprobs}",
            f"finish_reason={self.finish_reason}",
            f"stop_reason={self.stop_reason}",
        )
        return "CompletionOutput(" + ", ".join(rendered_fields) + ")"
@dataclass
class PoolingOutput:
    """The output data of one pooling output of a request.

    Args:
        data: The extracted hidden states.
    """

    data: torch.Tensor

    def __repr__(self) -> str:
        """Return a string showing the wrapped tensor."""
        return f"PoolingOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        """Return ``True`` when ``other`` holds a tensor of the same shape
        whose elements all compare equal to this one's.

        Objects that are not instances of this class compare unequal.
        """
        if not isinstance(other, self.__class__):
            return False
        # Check shapes before the elementwise comparison. Without this,
        # `==` would broadcast (so e.g. a 1-element tensor would compare
        # equal to a longer tensor of the same value) or raise for shapes
        # that cannot be broadcast together.
        if self.data.shape != other.data.shape:
            return False
        return bool((self.data == other.data).all())
class RequestOutput:
    """Result of a completion request submitted to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request. For encoder/decoder
            models, this is the decoder input prompt.
        prompt_token_ids: The token IDs of the prompt. For
            encoder/decoder models, this is the decoder input prompt
            token ids.
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
        metrics: Metrics associated with the request.
        lora_request: The LoRA request that was used to generate the
            output.
        encoder_prompt: The encoder prompt string of the request.
            None if decoder-only.
        encoder_prompt_token_ids: The token IDs of the encoder prompt.
            None if decoder-only.
        num_cached_tokens: The number of tokens with prefix cache hit.
        multi_modal_placeholders: Stored as given; a falsy or omitted
            value is stored as an empty dict.
        kv_transfer_params: The params for remote K/V transfer.
    """

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str],
        prompt_token_ids: Optional[list[int]],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: list[CompletionOutput],
        finished: bool,
        metrics: Optional[RequestMetrics] = None,
        lora_request: Optional[LoRARequest] = None,
        encoder_prompt: Optional[str] = None,
        encoder_prompt_token_ids: Optional[list[int]] = None,
        num_cached_tokens: Optional[int] = None,
        *,
        multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
        kv_transfer_params: Optional[dict[str, Any]] = None,
        # Forward compatibility, code that uses args added in new release can
        # still run with older versions of vLLM without breaking.
        **kwargs: Any,
    ) -> None:
        # Unknown keyword arguments are reported and then dropped.
        if kwargs:
            logger.warning_once("RequestOutput: Ignoring extra arguments: %s",
                                str(kwargs))

        # Request identity and prompt information.
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.multi_modal_placeholders = multi_modal_placeholders or {}
        self.prompt_logprobs = prompt_logprobs
        self.encoder_prompt = encoder_prompt
        self.encoder_prompt_token_ids = encoder_prompt_token_ids

        # Generation results and status.
        self.outputs = outputs
        self.finished = finished

        # Auxiliary information.
        self.metrics = metrics
        self.lora_request = lora_request
        self.num_cached_tokens = num_cached_tokens
        self.kv_transfer_params = kv_transfer_params

    def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
        """Merge a subsequent RequestOutput into this one.

        Args:
            next_output: The later output to fold into this object.
            aggregate: When true, a completion whose ``index`` matches an
                existing one has its text, token ids and logprobs
                appended to the existing completion, and its
                ``cumulative_logprob``, ``finish_reason`` and
                ``stop_reason`` overwrite the existing values. When
                false, the matching existing completion is replaced.

        Completions whose ``index`` matches nothing already stored are
        appended to ``self.outputs``.
        """
        self.finished |= next_output.finished
        self.kv_transfer_params = next_output.kv_transfer_params

        for incoming in next_output.outputs:
            # Position of the first stored completion with the same index.
            slot = next(
                (pos for pos, existing in enumerate(self.outputs)
                 if existing.index == incoming.index),
                None,
            )

            if slot is None:
                self.outputs.append(incoming)
                continue

            if not aggregate:
                # Replace the output with the new one
                self.outputs[slot] = incoming
                continue

            # Merge outputs with same index
            current = self.outputs[slot]
            current.text += incoming.text
            # token_ids may be an immutable sequence; make it extendable.
            if not isinstance(current.token_ids, MutableSequence):
                current.token_ids = list(current.token_ids)
            current.token_ids.extend(incoming.token_ids)
            if incoming.logprobs:
                assert current.logprobs is not None
                current.logprobs.extend(incoming.logprobs)
            current.cumulative_logprob = incoming.cumulative_logprob
            current.finish_reason = incoming.finish_reason
            current.stop_reason = incoming.stop_reason

    def __repr__(self) -> str:
        """Render the stored fields; ``kv_transfer_params`` is not shown."""
        rendered_fields = (
            f"request_id={self.request_id}",
            f"prompt={self.prompt!r}",
            f"prompt_token_ids={self.prompt_token_ids}",
            f"encoder_prompt={self.encoder_prompt!r}",
            f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}",
            f"prompt_logprobs={self.prompt_logprobs}",
            f"outputs={self.outputs}",
            f"finished={self.finished}",
            f"metrics={self.metrics}",
            f"lora_request={self.lora_request}",
            f"num_cached_tokens={self.num_cached_tokens}",
            f"multi_modal_placeholders={self.multi_modal_placeholders}",
        )
        return "RequestOutput(" + ", ".join(rendered_fields) + ")"
# Type variable for the `outputs` payload of PoolingRequestOutput; it
# defaults to PoolingOutput when the generic class is left unparameterized.
_O = TypeVar("_O", default=PoolingOutput)
class PoolingRequestOutput(Generic[_O]):
    """
    Result of a pooling request submitted to the LLM.

    Args:
        request_id (str): A unique identifier for the pooling request.
        outputs (PoolingOutput): The pooling results for the given input.
        prompt_token_ids (list[int]): A list of token IDs used in the prompt.
        finished (bool): A flag indicating whether the pooling is completed.
    """

    def __init__(self, request_id: str, outputs: _O,
                 prompt_token_ids: list[int], finished: bool):
        # Plain attribute storage; no validation is performed here.
        self.outputs = outputs
        self.finished = finished
        self.prompt_token_ids = prompt_token_ids
        self.request_id = request_id

    def __repr__(self):
        # The runtime class name is used so subclasses render their own name.
        class_name = type(self).__name__
        rendered_fields = ", ".join((
            f"request_id={self.request_id!r}",
            f"outputs={self.outputs!r}",
            f"prompt_token_ids={self.prompt_token_ids}",
            f"finished={self.finished}",
        ))
        return f"{class_name}({rendered_fields})"
@dataclass
class EmbeddingOutput:
    """One embedding result of a request.

    Args:
        embedding: The embedding vector, which is a list of floats.
            Its length depends on the hidden dimension of the model.
    """

    embedding: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Build an EmbeddingOutput from a PoolingOutput.

        Raises:
            ValueError: If the pooled tensor is not one-dimensional.
        """
        vector = pooling_output.data
        if vector.ndim == 1:
            return EmbeddingOutput(vector.tolist())
        raise ValueError("pooled_data should be a 1-D embedding vector")

    @property
    def hidden_size(self) -> int:
        """Number of entries in the embedding vector."""
        return len(self.embedding)

    def __repr__(self) -> str:
        """Show only the vector length, not its contents."""
        size = self.hidden_size
        return f"EmbeddingOutput(hidden_size={size})"
class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]):
    """PoolingRequestOutput whose ``outputs`` is an EmbeddingOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a PoolingRequestOutput, turning its ``outputs`` into an
        EmbeddingOutput and copying the remaining fields unchanged."""
        req_id = request_output.request_id
        embedding_output = EmbeddingOutput.from_base(request_output.outputs)
        token_ids = request_output.prompt_token_ids
        is_finished = request_output.finished
        return EmbeddingRequestOutput(
            request_id=req_id,
            outputs=embedding_output,
            prompt_token_ids=token_ids,
            finished=is_finished,
        )
@dataclass
class ClassificationOutput:
    """One classification result of a request.

    Args:
        probs: The probability vector, which is a list of floats.
            Its length depends on the number of classes.
    """

    probs: list[float]

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Build a ClassificationOutput from a PoolingOutput.

        Raises:
            ValueError: If the pooled tensor is not one-dimensional.
        """
        # pooling_output shape: (num_classes)
        class_probs = pooling_output.data
        if class_probs.ndim == 1:
            return ClassificationOutput(class_probs.tolist())
        raise ValueError("pooled_data should be a 1-D probability vector")

    @property
    def num_classes(self) -> int:
        """Number of entries in the probability vector."""
        return len(self.probs)

    def __repr__(self) -> str:
        """Show only the number of classes, not the probabilities."""
        count = self.num_classes
        return f"ClassificationOutput(num_classes={count})"
class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]):
    """PoolingRequestOutput whose ``outputs`` is a ClassificationOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a PoolingRequestOutput, turning its ``outputs`` into a
        ClassificationOutput and copying the remaining fields unchanged."""
        req_id = request_output.request_id
        classification = ClassificationOutput.from_base(
            request_output.outputs)
        token_ids = request_output.prompt_token_ids
        is_finished = request_output.finished
        return ClassificationRequestOutput(
            request_id=req_id,
            outputs=classification,
            prompt_token_ids=token_ids,
            finished=is_finished,
        )
@dataclass
class ScoringOutput:
    """One scoring result of a request.

    Args:
        score: The similarity score, which is a scalar value.
    """

    score: float

    @staticmethod
    def from_base(pooling_output: PoolingOutput):
        """Build a ScoringOutput from a PoolingOutput.

        Raises:
            ValueError: If the pooled tensor is not zero-dimensional
                after ``squeeze()``.
        """
        # pooling_output shape:
        #   classify task: (num_classes) num_classes == 1
        #   embed task: a scalar value
        squeezed = pooling_output.data.squeeze()
        if squeezed.ndim == 0:
            return ScoringOutput(squeezed.item())
        raise ValueError("pooled_data should be a scalar score")

    def __repr__(self) -> str:
        """Show the stored score."""
        value = self.score
        return f"ScoringOutput(score={value})"
class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]):
    """PoolingRequestOutput whose ``outputs`` is a ScoringOutput."""

    @staticmethod
    def from_base(request_output: PoolingRequestOutput):
        """Convert a PoolingRequestOutput, turning its ``outputs`` into a
        ScoringOutput and copying the remaining fields unchanged."""
        req_id = request_output.request_id
        scoring_output = ScoringOutput.from_base(request_output.outputs)
        token_ids = request_output.prompt_token_ids
        is_finished = request_output.finished
        return ScoringRequestOutput(
            request_id=req_id,
            outputs=scoring_output,
            prompt_token_ids=token_ids,
            finished=is_finished,
        )