vllm/vllm/logprobs.py
Jialin Ouyang 40b6b38f2c
[Core] Switch Flat logprob control from environment variable to SamplingParams (#28914)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
2025-11-19 02:10:02 +00:00

207 lines
7.7 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable, Iterator, MutableSequence
from dataclasses import dataclass, field
from typing import overload
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
"""Infos for supporting OpenAI compatible logprobs and token ranks.
Attributes:
logprob: The logprob of chosen token
rank: The vocab rank of chosen token (>=1)
decoded_token: The decoded chosen token index
"""
logprob: float
rank: int | None = None
decoded_token: str | None = None
# Logprob information of a single position, keyed by token id.
LogprobsOnePosition = dict[int, Logprob]
@dataclass
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
    """
    Flat logprobs of a request stored in multiple primitive type lists.

    Compared to list[dict[int, Logprob]], this data structure reduces GC
    overhead significantly, as it flattens logprob information for
    all positions and ranks into multiple primitive type lists (i.e.
    logprobs, token_ids, ranks per token_ids, decoded_tokens).
    So regardless of the sequence length and top_logprobs setup,
    FlatLogprobs only introduces a constant amount of objects.

    As each position might contain a different amount of ranks,
    start_indices / end_indices are used to access the logprob ranges
    for different positions.

    NOTE: To reduce the migration overhead and improve backward compatibility,
    we support the key Sequence APIs of list, so it could act as
    list[LogprobsOnePosition]. The container is append-only: item
    assignment, deletion and insertion raise TypeError.
    """

    # Start / end indices to indicate the range of logprobs for each position.
    start_indices: list[int] = field(default_factory=list)
    end_indices: list[int] = field(default_factory=list)

    # Flatten Logprob information for (each position, rank).
    # For position <i>, the logprobs are ranged
    # from self.start_indices[i] to self.end_indices[i] (exclusive).
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    ranks: list[int | None] = field(default_factory=list)
    decoded_tokens: list[str | None] = field(default_factory=list)

    def append(self, logprobs_one_position: LogprobsOnePosition | None) -> None:
        """Appends the container with logprobs for the next position.

        A None (or empty) entry still occupies a position; it is stored as
        an empty range and reads back as an empty dict.
        """
        self.start_indices.append(len(self.logprobs))
        if logprobs_one_position:
            for token_id, logprob in logprobs_one_position.items():
                self.token_ids.append(token_id)
                self.logprobs.append(logprob.logprob)
                self.ranks.append(logprob.rank)
                self.decoded_tokens.append(logprob.decoded_token)
        self.end_indices.append(len(self.logprobs))

    def append_fast(
        self,
        token_ids: list[int],
        logprobs: list[float],
        ranks: Iterable[int],
        decoded_tokens: Iterable[str | None],
    ) -> None:
        """
        Appends logprobs for the next position without creating
        the intermediate logprob dictionary.

        The four inputs are consumed in lockstep; entries beyond the
        shortest input are ignored.
        """
        self.start_indices.append(len(self.logprobs))
        for token_id, logprob, rank, decoded_token in zip(
            token_ids, logprobs, ranks, decoded_tokens
        ):
            self.token_ids.append(token_id)
            self.logprobs.append(logprob)
            self.ranks.append(rank)
            self.decoded_tokens.append(decoded_token)
        self.end_indices.append(len(self.logprobs))

    def extend(
        self, logprobs_multi_positions: Iterable[LogprobsOnePosition | None]
    ) -> None:
        """Extends the container with logprobs for the next multiple positions"""
        for logprobs_one_position in logprobs_multi_positions:
            self.append(logprobs_one_position)

    def __len__(self) -> int:
        """Gets number of positions stored in the container"""
        return len(self.start_indices)

    @overload
    def __getitem__(self, position: int) -> LogprobsOnePosition: ...

    @overload
    def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...

    def __getitem__(self, index: int | slice):
        """Extracts logprobs of a given position or slice.

        An int index returns a freshly built {token_id -> Logprob} dict for
        that position (IndexError if out of range). A slice returns a new
        FlatLogprobs holding copies of the selected positions; a slice that
        selects nothing returns an empty FlatLogprobs.

        Raises:
            TypeError: If index is neither an int nor a slice.
        """
        if isinstance(index, int):
            return {
                self.token_ids[i]: Logprob(
                    logprob=self.logprobs[i],
                    rank=self.ranks[i],
                    decoded_token=self.decoded_tokens[i],
                )
                for i in range(self.start_indices[index], self.end_indices[index])
            }

        if isinstance(index, slice):
            starts = self.start_indices[index]
            ends = self.end_indices[index]

            if not starts:
                # Nothing selected (e.g. flat[0:0] or flat[n:] past the end).
                # Mirror list semantics instead of failing on starts[0].
                return FlatLogprobs()

            if index.step in (None, 1):
                # Selected positions are adjacent, so the flattened data
                # can be copied with one list slice per field.
                min_index = starts[0]
                max_index = ends[-1]
                return FlatLogprobs(
                    # Shift updated start_indices and end_indices to
                    # be 0-indexed
                    start_indices=[i - min_index for i in starts],
                    end_indices=[i - min_index for i in ends],
                    token_ids=self.token_ids[min_index:max_index],
                    logprobs=self.logprobs[min_index:max_index],
                    ranks=self.ranks[min_index:max_index],
                    decoded_tokens=self.decoded_tokens[min_index:max_index],
                )

            # Strided or reversed slice: the selected ranges are not one
            # ascending contiguous span, so copy them position by position.
            selected = FlatLogprobs()
            for start, end in zip(starts, ends):
                selected.start_indices.append(len(selected.logprobs))
                selected.token_ids.extend(self.token_ids[start:end])
                selected.logprobs.extend(self.logprobs[start:end])
                selected.ranks.extend(self.ranks[start:end])
                selected.decoded_tokens.extend(self.decoded_tokens[start:end])
                selected.end_indices.append(len(selected.logprobs))
            return selected

        raise TypeError(f"Invalid index type: {type(index)}")

    def __setitem__(self, item, value) -> None:
        """Unsupported: always raises TypeError."""
        raise TypeError("Cannot set logprobs in FlatLogprobs")

    def __delitem__(self, item) -> None:
        """Unsupported: always raises TypeError."""
        raise TypeError("Cannot delete logprobs from FlatLogprobs")

    def insert(self, item, value=None) -> None:
        """Unsupported: always raises TypeError.

        Accepts the (index, value) call shape of MutableSequence.insert so
        that callers get the explicit message below rather than an
        argument-count error.
        """
        raise TypeError("Cannot insert logprobs to FlatLogprobs")

    def __iter__(self) -> Iterator[LogprobsOnePosition]:
        """
        Iterates the container and yields LogprobsOnePosition for
        each position.
        """
        for i in range(len(self.start_indices)):
            yield self[i]
# Prompt logprobs of a request, one {token_id -> logprob} entry per position:
# either a FlatLogprobs container or a plain list. A list entry is None when
# there is no prompt logprob for that position (e.g. the first prompt token,
# see create_prompt_logprobs).
PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
# Sample (decode) logprobs of a request, one {token_id -> logprob} entry per
# position: either a FlatLogprobs container or a plain list.
SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
    """Creates a container to store prompt logprobs for a request.

    Returns a FlatLogprobs when flat_logprobs is true, otherwise a plain
    list. Either way the container already holds one None entry.
    """
    # NOTE: logprob of first prompt token is None.
    if not flat_logprobs:
        return [None]

    container = FlatLogprobs()
    container.append(None)
    return container
def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs:
    """Creates an empty container to store decode logprobs for a request.

    Returns a FlatLogprobs when flat_logprobs is true, otherwise a plain list.
    """
    if flat_logprobs:
        return FlatLogprobs()
    return []
def append_logprobs_for_next_position(
    request_logprobs: PromptLogprobs | SampleLogprobs,
    token_ids: list[int],
    logprobs: list[float],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> None:
    """Appends logprobs for the next position.

    Args:
        request_logprobs: Container to append to; FlatLogprobs takes the
            dictionary-free fast path, anything else receives a dict.
        token_ids: Token ids for this position. The first entry is paired
            with `rank`, the following ones with ranks 1, 2, ...
        logprobs: Logprob values aligned with token_ids.
        decoded_tokens: Decoded tokens aligned with token_ids.
        rank: Rank assigned to the first entry.
        num_logprobs: Number of top-k ranks to generate; -1 means
            len(logprobs).
    """
    top_k = len(logprobs) if num_logprobs == -1 else num_logprobs

    # We do not need a special case for the sampled token
    # being in the topk, since inserting duplicated data
    # into a dictionary twice is the same as doing it once.
    ranks = itertools.chain((rank,), range(1, top_k + 1))

    if isinstance(request_logprobs, FlatLogprobs):
        request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
        return

    position_logprobs = {}
    for tok_id, tok_logprob, tok_rank, tok_text in zip(
        token_ids, logprobs, ranks, decoded_tokens
    ):
        position_logprobs[tok_id] = Logprob(
            logprob=tok_logprob,
            rank=tok_rank,
            decoded_token=tok_text,
        )
    request_logprobs.append(position_logprobs)