# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
from collections.abc import Iterable, Iterator, MutableSequence
from dataclasses import dataclass, field
from typing import overload


# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Info for supporting OpenAI-compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token
        rank: The vocab rank of the chosen token (>=1)
        decoded_token: The decoded chosen token (a string, not an index)
    """

    logprob: float
    rank: int | None = None
    decoded_token: str | None = None
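
# Example (illustrative; values are made up): a sampled token with id 464
# and logprob -0.25 that was the most likely token in the vocabulary:
#   Logprob(logprob=-0.25, rank=1, decoded_token="The")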


LogprobsOnePosition = dict[int, Logprob]


@dataclass
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
    """
    Flat logprobs for a request, stored in multiple primitive-type lists.

    Compared to list[dict[int, Logprob]], this data structure reduces GC
    overhead significantly, as it flattens the logprob information for
    all positions and ranks into multiple primitive-type lists (i.e.
    logprobs, token_ids, ranks and decoded_tokens). So regardless of the
    sequence length and the top_logprobs setting, FlatLogprobs introduces
    only a constant number of objects.

    As each position might contain a different number of ranks,
    start_indices and end_indices are used to access the range of
    logprobs belonging to each position.

    NOTE: To reduce migration overhead and improve backward compatibility,
    the key Sequence APIs of list are supported, so this class can act
    as a list[LogprobsOnePosition].
    """

    # Start / end indices indicating the range of logprobs for each position.
    start_indices: list[int] = field(default_factory=list)
    end_indices: list[int] = field(default_factory=list)

    # Flattened Logprob information for each (position, rank).
    # For position <i>, the logprobs span indices self.start_indices[i]
    # (inclusive) to self.end_indices[i] (exclusive).
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    ranks: list[int | None] = field(default_factory=list)
    decoded_tokens: list[str | None] = field(default_factory=list)
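
    # Illustrative layout (assumed values): after appending two positions,
    # the first holding two ranked tokens and the second holding one,
    #   start_indices == [0, 2], end_indices == [2, 3]
    # and token_ids, logprobs, ranks and decoded_tokens each hold
    # 3 flattened entries.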

    def append(self, logprobs_one_position: LogprobsOnePosition | None) -> None:
        """Appends logprobs for the next position to the container."""
        self.start_indices.append(len(self.logprobs))
        if logprobs_one_position:
            for token_id, logprob in logprobs_one_position.items():
                self.token_ids.append(token_id)
                self.logprobs.append(logprob.logprob)
                self.ranks.append(logprob.rank)
                self.decoded_tokens.append(logprob.decoded_token)
        self.end_indices.append(len(self.logprobs))

    def append_fast(
        self,
        token_ids: list[int],
        logprobs: list[float],
        ranks: Iterable[int],
        decoded_tokens: Iterable[str | None],
    ) -> None:
        """
        Appends logprobs for the next position without creating
        the intermediate logprob dictionary.
        """
        self.start_indices.append(len(self.logprobs))
        for token_id, logprob, rank, decoded_token in zip(
            token_ids, logprobs, ranks, decoded_tokens
        ):
            self.token_ids.append(token_id)
            self.logprobs.append(logprob)
            self.ranks.append(rank)
            self.decoded_tokens.append(decoded_token)
        self.end_indices.append(len(self.logprobs))

    def extend(
        self, logprobs_multi_positions: Iterable[LogprobsOnePosition | None]
    ) -> None:
        """Extends the container with logprobs for multiple subsequent positions."""
        for logprobs_one_position in logprobs_multi_positions:
            self.append(logprobs_one_position)

    def __len__(self) -> int:
        """Gets the number of positions stored in the container."""
        return len(self.start_indices)

    @overload
    def __getitem__(self, position: int) -> LogprobsOnePosition: ...

    @overload
    def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...

    def __getitem__(self, index: int | slice):
        """Extracts the logprobs at a given position or slice."""
        if isinstance(index, int):
            return {
                self.token_ids[i]: Logprob(
                    logprob=self.logprobs[i],
                    rank=self.ranks[i],
                    decoded_token=self.decoded_tokens[i],
                )
                for i in range(self.start_indices[index], self.end_indices[index])
            }
        elif isinstance(index, slice):
            start_indices = self.start_indices[index]
            end_indices = self.end_indices[index]
            if not start_indices:
                # Empty slice: avoid indexing into empty index lists below.
                return FlatLogprobs()
            min_index = start_indices[0]
            max_index = end_indices[-1]
            return FlatLogprobs(
                # Shift start_indices and end_indices to be 0-indexed
                # within the sliced flat lists.
                start_indices=[i - min_index for i in start_indices],
                end_indices=[i - min_index for i in end_indices],
                token_ids=self.token_ids[min_index:max_index],
                logprobs=self.logprobs[min_index:max_index],
                ranks=self.ranks[min_index:max_index],
                decoded_tokens=self.decoded_tokens[min_index:max_index],
            )
        else:
            raise TypeError(f"Invalid index type: {type(index)}")
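
    # Example (illustrative; values are made up): if positions 0, 1 and 2
    # hold 2, 1 and 3 logprobs respectively, then self[1:3] returns a
    # FlatLogprobs with start_indices == [0, 1] and end_indices == [1, 4],
    # built from entries 2..6 (exclusive) of the flat lists.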

    def __setitem__(self, item, value) -> None:
        raise TypeError("Cannot set logprobs in FlatLogprobs")

    def __delitem__(self, item) -> None:
        raise TypeError("Cannot delete logprobs from FlatLogprobs")

    # Matches MutableSequence.insert(index, value); mutation is unsupported.
    def insert(self, index, value) -> None:
        raise TypeError("Cannot insert logprobs into FlatLogprobs")

    def __iter__(self) -> Iterator[LogprobsOnePosition]:
        """
        Iterates the container and yields LogprobsOnePosition for
        each position.
        """
        for i in range(len(self.start_indices)):
            yield self[i]
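
# Example usage (illustrative; values are made up):
#
#   flat = FlatLogprobs()
#   flat.append({464: Logprob(logprob=-0.25, rank=1, decoded_token="The")})
#   flat.append(None)  # a position without logprobs
#   assert len(flat) == 2
#   assert flat[0][464].rank == 1
#   assert flat[1] == {}  # positions without logprobs yield an empty dict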


# {token_id -> logprob} for each prompt position. None if the
# corresponding position doesn't require prompt logprobs (e.g. the
# first prompt token).
PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
# {token_id -> logprob} for each sampled position.
SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]


def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
    """Creates a container to store prompt logprobs for a request."""
    logprobs = FlatLogprobs() if flat_logprobs else []
    # NOTE: logprob of first prompt token is None.
    logprobs.append(None)
    return logprobs
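
# Example (illustrative): the container starts with a placeholder for the
# first prompt token, per the NOTE above.
#
#   prompt_logprobs = create_prompt_logprobs(flat_logprobs=True)
#   assert len(prompt_logprobs) == 1
#   assert prompt_logprobs[0] == {}  # FlatLogprobs stores the None as empty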


def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs:
    """Creates a container to store decode logprobs for a request."""
    return FlatLogprobs() if flat_logprobs else []


def append_logprobs_for_next_position(
    request_logprobs: PromptLogprobs | SampleLogprobs,
    token_ids: list[int],
    logprobs: list[float],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> None:
    """Appends logprobs for the next position."""
    if num_logprobs == -1:
        num_logprobs = len(logprobs)
    # We do not need a special case for the sampled token being in the
    # top-k, since inserting duplicated data into a dictionary twice is
    # the same as doing it once (for FlatLogprobs, duplicate entries
    # collapse into one when a position is materialized as a dict).
    topk_ranks = range(1, num_logprobs + 1)
    ranks = itertools.chain((rank,), topk_ranks)

    if isinstance(request_logprobs, FlatLogprobs):
        request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
    else:
        request_logprobs.append(
            {
                token_id: Logprob(
                    logprob=logprob,
                    rank=rank,
                    decoded_token=token,
                )
                for token_id, logprob, rank, token in zip(
                    token_ids, logprobs, ranks, decoded_tokens
                )
            }
        )
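
# Example (illustrative; values are made up): record one position holding
# the sampled token (id 11, vocab rank 3) plus the top-2 logprobs.
#
#   sample_logprobs = create_sample_logprobs(flat_logprobs=True)
#   append_logprobs_for_next_position(
#       sample_logprobs,
#       token_ids=[11, 464, 262],  # sampled token first, then the top-k
#       logprobs=[-2.1, -0.3, -1.2],
#       decoded_tokens=[",", "The", " the"],
#       rank=3,  # vocab rank of the sampled token
#       num_logprobs=2,
#   )
#   assert len(sample_logprobs) == 1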