# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from collections.abc import Iterable, Iterator, MutableSequence from dataclasses import dataclass, field from typing import overload # We use dataclass for now because it is used for # openai server output, and msgspec is not serializable. # TODO(sang): Fix it. @dataclass class Logprob: """Infos for supporting OpenAI compatible logprobs and token ranks. Attributes: logprob: The logprob of chosen token rank: The vocab rank of chosen token (>=1) decoded_token: The decoded chosen token index """ logprob: float rank: int | None = None decoded_token: str | None = None LogprobsOnePosition = dict[int, Logprob] @dataclass class FlatLogprobs(MutableSequence[LogprobsOnePosition]): """ Flat logprobs of a request into multiple primitive type lists. Compared to list[dict[int, Logprob]], this data structure reduced GC overhead significantly. As it flattened logprob information for all positions and ranks in to multiple primitive type lists (i.e. logprobs, token_ids, ranks per token_ids, decoded_tokens). So regardless of the sequence length and top_logprobs setup, FlatLogprobs would only introduce a constant amount of objects. As each position might contains different amount of ranks, start_indices_per_position would be used to access the logprob ranges for different positions. NOTE: To reduce the migration overhead and improve backward compatibility, we support the key Sequence APIs of list, so it could act as list[LogprobsOnePosition] """ # Start / end indices to indicate the range of logprobs for each position. start_indices: list[int] = field(default_factory=list) end_indices: list[int] = field(default_factory=list) # Flatten Logprob information for (each position, rank). # For position , the logprobs are ranged # from self.start_indices[i] to self.end_indices[i] (exclusive). token_ids: list[int] = field(default_factory=list) logprobs: list[float] = field(default_factory=list) ranks: list[int | None] = field(default_factory=list) decoded_tokens: list[str | None] = field(default_factory=list) def append(self, logprobs_one_position: LogprobsOnePosition | None) -> None: """Appends the container with logprobs for the next position""" self.start_indices.append(len(self.logprobs)) if logprobs_one_position: for token_id, logprob in logprobs_one_position.items(): self.token_ids.append(token_id) self.logprobs.append(logprob.logprob) self.ranks.append(logprob.rank) self.decoded_tokens.append(logprob.decoded_token) self.end_indices.append(len(self.logprobs)) def append_fast( self, token_ids: list[int], logprobs: list[float], ranks: itertools.chain[int], decoded_tokens: Iterable[str | None], ) -> None: """ Appends logprobs for the next position without creating the intermediate logprob dictionary. """ self.start_indices.append(len(self.logprobs)) for token_id, logprob, rank, decoded_token in zip( token_ids, logprobs, ranks, decoded_tokens ): self.token_ids.append(token_id) self.logprobs.append(logprob) self.ranks.append(rank) self.decoded_tokens.append(decoded_token) self.end_indices.append(len(self.logprobs)) def extend(self, logprobs_multi_positions) -> None: """Extends the container with logprobs for the next multiple positions""" for logprobs_one_position in logprobs_multi_positions: self.append(logprobs_one_position) def __len__(self) -> int: """Gets number of positions stored in the container""" return len(self.start_indices) @overload def __getitem__(self, position: int) -> LogprobsOnePosition: ... @overload def __getitem__(self, s: slice, /) -> "FlatLogprobs": ... def __getitem__(self, index: int | slice): """Extracts logprobs of a given position or slice""" if isinstance(index, int): return { self.token_ids[i]: Logprob( logprob=self.logprobs[i], rank=self.ranks[i], decoded_token=self.decoded_tokens[i], ) for i in range(self.start_indices[index], self.end_indices[index]) } elif isinstance(index, slice): min_index = self.start_indices[index][0] max_index = self.end_indices[index][-1] return FlatLogprobs( # Shift updated start_indices and end_indices to # be 0-indexed start_indices=[i - min_index for i in self.start_indices[index]], end_indices=[i - min_index for i in self.end_indices[index]], token_ids=self.token_ids[min_index:max_index], logprobs=self.logprobs[min_index:max_index], ranks=self.ranks[min_index:max_index], decoded_tokens=self.decoded_tokens[min_index:max_index], ) else: raise TypeError(f"Invalid index type: {type(index)}") def __setitem__(self, item, value) -> None: raise TypeError("Cannot set logprobs in FlatLogprobs") def __delitem__(self, item) -> None: raise TypeError("Cannot delete logprobs from FlatLogprobs") def insert(self, item) -> None: raise TypeError("Cannot insert logprobs to FlatLogprobs") def __iter__(self) -> Iterator[LogprobsOnePosition]: """ Iterates the container and yields LogprobsOnePosition for each position. """ for i in range(0, len(self.start_indices)): yield self.__getitem__(i) # {token_id -> logprob} per each sequence group. None if the corresponding # sequence group doesn't require prompt logprob. PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None] # {token_id -> logprob} for each sequence group. SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition] def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs: """Creates a container to store prompt logprobs for a request""" logprobs = FlatLogprobs() if flat_logprobs else [] # NOTE: logprob of first prompt token is None. logprobs.append(None) return logprobs def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs: """Creates a container to store decode logprobs for a request""" return FlatLogprobs() if flat_logprobs else [] def append_logprobs_for_next_position( request_logprobs: PromptLogprobs | SampleLogprobs, token_ids: list[int], logprobs: list[float], decoded_tokens: Iterable[str | None], rank: int, num_logprobs: int, ) -> None: """Appends logprobs for the next position""" if num_logprobs == -1: num_logprobs = len(logprobs) # We do not need a special case for the sampled token # being in the topk, since inserting duplicated data # into a dictionary twice is the same as doing it once. topk_ranks = range(1, num_logprobs + 1) ranks = itertools.chain((rank,), topk_ranks) if isinstance(request_logprobs, FlatLogprobs): request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens) else: request_logprobs.append( { token_id: Logprob( logprob=logprob, rank=rank, decoded_token=token, ) for token_id, logprob, rank, token in zip( token_ids, logprobs, ranks, decoded_tokens ) } )