[Perf] Introduce FlattenLogprobs to store logprob results and reduce GC overhead (#28171)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Jialin Ouyang 2025-11-07 00:27:12 -08:00 committed by GitHub
parent 21b82f4ea2
commit ccd98b59c1
6 changed files with 534 additions and 125 deletions


@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm import SamplingParams
from vllm.logprobs import FlattenLogprobs

MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5
NUM_TOP_LOGPROBS = 5
NUM_PROMPT_LOGPROBS = 7
MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("greedy", [True, False])
@pytest.mark.parametrize("flatten_logprobs", [True, False])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    greedy,
    flatten_logprobs,
    example_prompts,
    monkeypatch: pytest.MonkeyPatch,
):
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
    with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
        tokenizer = vllm_model.llm.get_tokenizer()
        example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
        sampling_params = SamplingParams(
            temperature=0.0 if greedy else 1.0,
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            logprobs=NUM_TOP_LOGPROBS,
            prompt_logprobs=NUM_PROMPT_LOGPROBS,
        )
        results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

        assert len(results) == len(example_prompt_tokens)
        for i, (result, prompt_tokens) in enumerate(
            zip(results, example_prompt_tokens)
        ):
            decode_tokens, _, decode_logprobs, prompt_logprobs = result

            # Ensure the return type of logprobs matches the flag
            assert isinstance(
                prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
            )
            assert isinstance(
                decode_logprobs, FlattenLogprobs if flatten_logprobs else list
            )

            ########################
            # Check prompt logprobs
            ########################
            assert len(prompt_tokens) == len(prompt_logprobs)
            # No logprob for the first prompt token
            assert not prompt_logprobs[0]
            for position, (token, logprobs) in enumerate(
                zip(prompt_tokens[1:], prompt_logprobs[1:]), start=1
            ):
                # Ensure the logprob of the prompt token is always returned
                logprob = logprobs.get(token)
                assert logprob is not None
                assert logprob.rank >= 1
                # Ensure the number of returned logprobs is either
                # NUM_PROMPT_LOGPROBS or NUM_PROMPT_LOGPROBS + 1
                assert NUM_PROMPT_LOGPROBS <= len(logprobs) <= NUM_PROMPT_LOGPROBS + 1
                # Ensure the top NUM_PROMPT_LOGPROBS ranks are always extracted
                assert set(range(1, NUM_PROMPT_LOGPROBS + 1)).issubset(
                    {logprob.rank for logprob in logprobs.values()}
                )

            ########################
            # Check sample logprobs
            ########################
            assert len(decode_tokens) == len(decode_logprobs)
            for position, (token, logprobs) in enumerate(
                zip(decode_tokens, decode_logprobs)
            ):
                # Ensure the logprob of the chosen token is always returned
                logprob = logprobs.get(token)
                assert logprob is not None
                if greedy:
                    # For greedy sampling, the chosen token should always be rank 1
                    assert logprob.rank == 1
                else:
                    assert logprob.rank >= 1
                # Ensure the number of returned logprobs is either
                # NUM_TOP_LOGPROBS or NUM_TOP_LOGPROBS + 1
                assert NUM_TOP_LOGPROBS <= len(logprobs) <= NUM_TOP_LOGPROBS + 1
                # Ensure the top NUM_TOP_LOGPROBS ranks are always extracted
                assert set(range(1, NUM_TOP_LOGPROBS + 1)).issubset(
                    {logprob.rank for logprob in logprobs.values()}
                )
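
Since FlattenLogprobs keeps the per-position dict interface that the test above exercises, consumer code does not need to change when the flag is on. A minimal usage sketch against the public LLM API (the model and prompt are placeholders, not from this commit):

import os

os.environ["VLLM_FLATTEN_LOGPROBS"] = "1"  # opt in; read lazily by vllm.envs

from vllm import LLM, SamplingParams

llm = LLM(model="distilbert/distilgpt2", max_logprobs=5)
out = llm.generate(["Hello"], SamplingParams(max_tokens=5, logprobs=5))[0]
# CompletionOutput.logprobs behaves like list[dict[int, Logprob]]
# whether it is a plain list or a FlattenLogprobs.
for position, logprobs_one_position in enumerate(out.outputs[0].logprobs):
    for token_id, logprob in logprobs_one_position.items():
        print(position, token_id, logprob.rank, logprob.logprob)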


@@ -1,59 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm import SamplingParams

MODELS = ["distilbert/distilgpt2"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    max_tokens = 5
    num_top_logprobs = 5
    num_prompt_logprobs = 5
    with vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) as vllm_model:
        ## Test greedy logprobs ranks
        vllm_sampling_params = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            max_tokens=max_tokens,
            logprobs=num_top_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        vllm_results = vllm_model.generate_w_logprobs(
            example_prompts, vllm_sampling_params
        )

        ## Test non-greedy logprobs ranks
        sampling_params = SamplingParams(
            temperature=1.0,
            top_p=1.0,
            max_tokens=max_tokens,
            logprobs=num_top_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

    for result in vllm_results:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks = 1
        for token, logprobs in zip(result[0], result[2]):
            assert token in logprobs
            assert logprobs[token].rank == 1

    for result in res:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks
        for token, logprobs in zip(result[0], result[2]):
            assert logprobs[token].rank >= 1

tests/test_logprobs.py (new file, +222 lines)

@@ -0,0 +1,222 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.logprobs import (
    FlattenLogprobs,
    Logprob,
    LogprobsOnePosition,
    append_logprobs_for_next_position,
    create_prompt_logprobs,
    create_sample_logprobs,
)


def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
    prompt_logprobs = create_prompt_logprobs()
    assert isinstance(prompt_logprobs, list)
    # Ensure first prompt position logprobs is None
    assert len(prompt_logprobs) == 1
    assert prompt_logprobs[0] is None

    sample_logprobs = create_sample_logprobs()
    assert isinstance(sample_logprobs, list)
    assert len(sample_logprobs) == 0


def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
    prompt_logprobs = create_prompt_logprobs()
    assert isinstance(prompt_logprobs, FlattenLogprobs)
    assert prompt_logprobs.start_indices == [0]
    assert prompt_logprobs.end_indices == [0]
    assert len(prompt_logprobs.token_ids) == 0
    assert len(prompt_logprobs.logprobs) == 0
    assert len(prompt_logprobs.ranks) == 0
    assert len(prompt_logprobs.decoded_tokens) == 0
    # Ensure first prompt position logprobs is empty
    assert len(prompt_logprobs) == 1
    assert prompt_logprobs[0] == dict()

    sample_logprobs = create_sample_logprobs()
    assert isinstance(sample_logprobs, FlattenLogprobs)
    assert len(sample_logprobs.start_indices) == 0
    assert len(sample_logprobs.end_indices) == 0
    assert len(sample_logprobs.token_ids) == 0
    assert len(sample_logprobs.logprobs) == 0
    assert len(sample_logprobs.ranks) == 0
    assert len(sample_logprobs.decoded_tokens) == 0
    assert len(sample_logprobs) == 0


def test_append_logprobs_for_next_position_none_flatten(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
    logprobs = create_sample_logprobs()
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[1],
        logprobs=[0.1],
        decoded_tokens=["1"],
        rank=10,
        num_logprobs=-1,
    )
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[2, 3],
        logprobs=[0.2, 0.3],
        decoded_tokens=["2", "3"],
        rank=11,
        num_logprobs=-1,
    )
    assert isinstance(logprobs, list)
    assert logprobs == [
        {1: Logprob(logprob=0.1, rank=10, decoded_token="1")},
        {
            2: Logprob(logprob=0.2, rank=11, decoded_token="2"),
            3: Logprob(logprob=0.3, rank=1, decoded_token="3"),
        },
    ]


def test_append_logprobs_for_next_position_flatten(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
    logprobs = create_sample_logprobs()
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[1],
        logprobs=[0.1],
        decoded_tokens=["1"],
        rank=10,
        num_logprobs=-1,
    )
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[2, 3],
        logprobs=[0.2, 0.3],
        decoded_tokens=["2", "3"],
        rank=11,
        num_logprobs=-1,
    )
    assert isinstance(logprobs, FlattenLogprobs)
    assert logprobs.start_indices == [0, 1]
    assert logprobs.end_indices == [1, 3]
    assert logprobs.token_ids == [1, 2, 3]
    assert logprobs.logprobs == [0.1, 0.2, 0.3]
    assert logprobs.ranks == [10, 11, 1]
    assert logprobs.decoded_tokens == ["1", "2", "3"]


LOGPROBS_ONE_POSITION_0: LogprobsOnePosition = {
    1: Logprob(logprob=0.1, rank=10, decoded_token="10")
}
LOGPROBS_ONE_POSITION_1: LogprobsOnePosition = {
    2: Logprob(logprob=0.2, rank=20, decoded_token="20"),
    3: Logprob(logprob=0.3, rank=30, decoded_token="30"),
}
LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
    4: Logprob(logprob=0.4, rank=40, decoded_token="40"),
    5: Logprob(logprob=0.5, rank=50, decoded_token="50"),
    6: Logprob(logprob=0.6, rank=60, decoded_token="60"),
}


def test_flatten_logprobs_append() -> None:
    logprobs = FlattenLogprobs()
    logprobs.append(LOGPROBS_ONE_POSITION_0)
    logprobs.append(LOGPROBS_ONE_POSITION_1)
    assert logprobs.start_indices == [0, 1]
    assert logprobs.end_indices == [1, 3]
    assert logprobs.token_ids == [1, 2, 3]
    assert logprobs.logprobs == [0.1, 0.2, 0.3]
    assert logprobs.ranks == [10, 20, 30]
    assert logprobs.decoded_tokens == ["10", "20", "30"]

    logprobs.append(LOGPROBS_ONE_POSITION_2)
    assert logprobs.start_indices == [0, 1, 3]
    assert logprobs.end_indices == [1, 3, 6]
    assert logprobs.token_ids == [1, 2, 3, 4, 5, 6]
    assert logprobs.logprobs == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    assert logprobs.ranks == [10, 20, 30, 40, 50, 60]
    assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]


def test_flatten_logprobs_extend() -> None:
    logprobs = FlattenLogprobs()
    # Extend with list[LogprobsOnePosition]
    logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
    assert logprobs.start_indices == [0, 3]
    assert logprobs.end_indices == [3, 4]
    assert logprobs.token_ids == [4, 5, 6, 1]
    assert logprobs.logprobs == [0.4, 0.5, 0.6, 0.1]
    assert logprobs.ranks == [40, 50, 60, 10]
    assert logprobs.decoded_tokens == ["40", "50", "60", "10"]

    other_logprobs = FlattenLogprobs()
    other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
    # Extend with another FlattenLogprobs
    logprobs.extend(other_logprobs)
    assert logprobs.start_indices == [0, 3, 4, 6]
    assert logprobs.end_indices == [3, 4, 6, 7]
    assert logprobs.token_ids == [4, 5, 6, 1, 2, 3, 1]
    assert logprobs.logprobs == [0.4, 0.5, 0.6, 0.1, 0.2, 0.3, 0.1]
    assert logprobs.ranks == [40, 50, 60, 10, 20, 30, 10]
    assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]


def test_flatten_logprobs_access() -> None:
    logprobs = FlattenLogprobs()
    logprobs.extend(
        [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
    )
    assert logprobs.start_indices == [0, 2, 5]
    assert logprobs.end_indices == [2, 5, 6]
    assert logprobs.token_ids == [2, 3, 4, 5, 6, 1]
    assert logprobs.logprobs == [0.2, 0.3, 0.4, 0.5, 0.6, 0.1]
    assert logprobs.ranks == [20, 30, 40, 50, 60, 10]
    assert logprobs.decoded_tokens == ["20", "30", "40", "50", "60", "10"]

    # Test __len__
    assert len(logprobs) == 3

    # Test __iter__
    for actual_logprobs, expected_logprobs in zip(
        logprobs,
        [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0],
    ):
        assert actual_logprobs == expected_logprobs

    # Test __getitem__ : single item
    assert logprobs[0] == LOGPROBS_ONE_POSITION_1
    assert logprobs[1] == LOGPROBS_ONE_POSITION_2
    assert logprobs[2] == LOGPROBS_ONE_POSITION_0

    # Test __getitem__ : slice
    logprobs02 = logprobs[:2]
    assert len(logprobs02) == 2
    assert logprobs02[0] == LOGPROBS_ONE_POSITION_1
    assert logprobs02[1] == LOGPROBS_ONE_POSITION_2
    assert logprobs02.start_indices == [0, 2]
    assert logprobs02.end_indices == [2, 5]
    assert logprobs02.token_ids == [2, 3, 4, 5, 6]
    assert logprobs02.logprobs == [0.2, 0.3, 0.4, 0.5, 0.6]
    assert logprobs02.ranks == [20, 30, 40, 50, 60]
    assert logprobs02.decoded_tokens == ["20", "30", "40", "50", "60"]

    logprobs_last2 = logprobs[-2:]
    assert len(logprobs_last2) == 2
    assert logprobs_last2[0] == LOGPROBS_ONE_POSITION_2
    assert logprobs_last2[1] == LOGPROBS_ONE_POSITION_0
    assert logprobs_last2.start_indices == [0, 3]
    assert logprobs_last2.end_indices == [3, 4]
    assert logprobs_last2.token_ids == [4, 5, 6, 1]
    assert logprobs_last2.logprobs == [0.4, 0.5, 0.6, 0.1]
    assert logprobs_last2.ranks == [40, 50, 60, 10]
    assert logprobs_last2.decoded_tokens == ["40", "50", "60", "10"]

vllm/envs.py

@@ -220,6 +220,7 @@ if TYPE_CHECKING:
    VLLM_GC_DEBUG: str = ""
    VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
    VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
+   VLLM_FLATTEN_LOGPROBS: bool = False


def get_default_cache_root():


@@ -1463,6 +1464,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
        "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
    ),
+   # Flag to enable FlattenLogprobs, whose GC overhead is significantly
+   # smaller than that of the original list[dict[int, Logprob]] approach.
+   # When enabled, PromptLogprobs and SampleLogprobs are populated as
+   # FlattenLogprobs.
+   "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
}
# --8<-- [end:env-vars-definition]
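
Because the flag is parsed with bool(int(...)), only integer strings are valid values. A tiny standalone sketch of the same parsing rule (the helper name here is ours, not vLLM's):

import os

def flatten_logprobs_enabled() -> bool:
    # "0" -> False, "1" (or any non-zero integer string) -> True;
    # non-integer values such as "true" raise ValueError.
    return bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0")))

os.environ["VLLM_FLATTEN_LOGPROBS"] = "1"
assert flatten_logprobs_enabled()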

vllm/logprobs.py

@@ -1,6 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass
+import itertools
+from collections.abc import Iterable, Iterator, MutableSequence
+from dataclasses import dataclass, field
+from typing import overload
+
+import vllm.envs as envs
# We use dataclass for now because it is used for
@@ -21,8 +26,183 @@ class Logprob:
    decoded_token: str | None = None


LogprobsOnePosition = dict[int, Logprob]

@dataclass
class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
    """
    Flattened logprobs of a request, stored in multiple primitive-type lists.

    Compared to list[dict[int, Logprob]], this data structure significantly
    reduces GC overhead, as it flattens the logprob information for all
    positions and ranks into a few primitive-type lists (i.e. logprobs,
    token_ids, ranks, decoded_tokens). So regardless of sequence length and
    the top_logprobs setup, a FlattenLogprobs introduces only a constant
    number of Python objects.

    As each position might contain a different number of ranks,
    start_indices and end_indices record the range of logprobs
    belonging to each position.

    NOTE: To reduce migration overhead and improve backward compatibility,
    the key Sequence APIs of list are supported, so this container can act
    as a list[LogprobsOnePosition].
    """

    # Start / end indices indicating the range of logprobs for each position.
    start_indices: list[int] = field(default_factory=list)
    end_indices: list[int] = field(default_factory=list)
    # Flattened Logprob information for each (position, rank).
    # For position <i>, the logprobs range from self.start_indices[i]
    # to self.end_indices[i] (exclusive).
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    ranks: list[int | None] = field(default_factory=list)
    decoded_tokens: list[str | None] = field(default_factory=list)
    def append(self, logprobs_one_position: LogprobsOnePosition | None) -> None:
        """Appends logprobs for the next position to the container."""
        self.start_indices.append(len(self.logprobs))
        if logprobs_one_position:
            for token_id, logprob in logprobs_one_position.items():
                self.token_ids.append(token_id)
                self.logprobs.append(logprob.logprob)
                self.ranks.append(logprob.rank)
                self.decoded_tokens.append(logprob.decoded_token)
        self.end_indices.append(len(self.logprobs))

    def append_fast(
        self,
        token_ids: list[int],
        logprobs: list[float],
        ranks: itertools.chain[int],
        decoded_tokens: Iterable[str | None],
    ) -> None:
        """
        Appends logprobs for the next position without creating
        the intermediate logprob dictionary.
        """
        self.start_indices.append(len(self.logprobs))
        for token_id, logprob, rank, decoded_token in zip(
            token_ids, logprobs, ranks, decoded_tokens
        ):
            self.token_ids.append(token_id)
            self.logprobs.append(logprob)
            self.ranks.append(rank)
            self.decoded_tokens.append(decoded_token)
        self.end_indices.append(len(self.logprobs))
    def extend(
        self, logprobs_multi_positions: Iterable[LogprobsOnePosition | None]
    ) -> None:
        """Extends the container with logprobs for multiple subsequent positions."""
        for logprobs_one_position in logprobs_multi_positions:
            self.append(logprobs_one_position)

    def __len__(self) -> int:
        """Gets the number of positions stored in the container."""
        return len(self.start_indices)

    @overload
    def __getitem__(self, position: int) -> LogprobsOnePosition: ...

    @overload
    def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...

    def __getitem__(self, index: int | slice):
        """Extracts logprobs of a given position or slice."""
        if isinstance(index, int):
            return {
                self.token_ids[i]: Logprob(
                    logprob=self.logprobs[i],
                    rank=self.ranks[i],
                    decoded_token=self.decoded_tokens[i],
                )
                for i in range(self.start_indices[index], self.end_indices[index])
            }
        elif isinstance(index, slice):
            min_index = self.start_indices[index][0]
            max_index = self.end_indices[index][-1]
            return FlattenLogprobs(
                # Shift the sliced start_indices and end_indices so that
                # they are 0-indexed in the new container.
                start_indices=[i - min_index for i in self.start_indices[index]],
                end_indices=[i - min_index for i in self.end_indices[index]],
                token_ids=self.token_ids[min_index:max_index],
                logprobs=self.logprobs[min_index:max_index],
                ranks=self.ranks[min_index:max_index],
                decoded_tokens=self.decoded_tokens[min_index:max_index],
            )
        else:
            raise TypeError(f"Invalid index type: {type(index)}")

    def __setitem__(self, item, value) -> None:
        raise TypeError("Cannot set logprobs in FlattenLogprobs")

    def __delitem__(self, item) -> None:
        raise TypeError("Cannot delete logprobs from FlattenLogprobs")

    def insert(self, index: int, value: LogprobsOnePosition) -> None:
        # Signature matches MutableSequence.insert; mutation is unsupported.
        raise TypeError("Cannot insert logprobs into FlattenLogprobs")

    def __iter__(self) -> Iterator[LogprobsOnePosition]:
        """
        Iterates over the container, yielding LogprobsOnePosition for
        each position.
        """
        for i in range(len(self.start_indices)):
            yield self.__getitem__(i)

# {token_id -> logprob} for each sequence group. None if the corresponding
# sequence group doesn't require prompt logprobs.
-PromptLogprobs = list[dict[int, Logprob] | None]
+PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
# {token_id -> logprob} for each sequence group.
-SampleLogprobs = list[dict[int, Logprob]]
+SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]

def create_prompt_logprobs() -> PromptLogprobs:
    """Creates a container to store prompt logprobs for a request"""
    logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
    # NOTE: logprob of the first prompt token is None.
    logprobs.append(None)
    return logprobs


def create_sample_logprobs() -> SampleLogprobs:
    """Creates a container to store decode logprobs for a request"""
    return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []


def append_logprobs_for_next_position(
    request_logprobs: PromptLogprobs | SampleLogprobs,
    token_ids: list[int],
    logprobs: list[float],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> None:
    """Appends logprobs for the next position"""
    if num_logprobs == -1:
        num_logprobs = len(logprobs)
    # We do not need a special case for the sampled token
    # being in the topk, since inserting duplicated data
    # into a dictionary twice is the same as doing it once.
    topk_ranks = range(1, num_logprobs + 1)
    ranks = itertools.chain((rank,), topk_ranks)
    if isinstance(request_logprobs, FlattenLogprobs):
        request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
    else:
        request_logprobs.append(
            {
                token_id: Logprob(
                    logprob=logprob,
                    rank=rank,
                    decoded_token=token,
                )
                for token_id, logprob, rank, token in zip(
                    token_ids, logprobs, ranks, decoded_tokens
                )
            }
        )
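
To make the flattened layout concrete, here is a small worked example against the class above (token ids and values are illustrative):

from vllm.logprobs import FlattenLogprobs, Logprob

logprobs = FlattenLogprobs()
# Position 0 holds one rank, position 1 holds two.
logprobs.append({7: Logprob(logprob=-0.5, rank=3, decoded_token="a")})
logprobs.append(
    {
        8: Logprob(logprob=-0.1, rank=1, decoded_token="b"),
        9: Logprob(logprob=-2.3, rank=2, decoded_token="c"),
    }
)
# Everything lives in flat primitive lists...
assert logprobs.start_indices == [0, 1]
assert logprobs.end_indices == [1, 3]
assert logprobs.token_ids == [7, 8, 9]
# ...while list-like access rebuilds the per-position dict on demand.
assert logprobs[1][9].rank == 2
assert len(logprobs) == 2

Note that __getitem__ materializes fresh Logprob objects on every access, so hot paths should read the flat lists directly; the dict view exists for backward compatibility.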

vllm/v1/engine/logprobs.py

@@ -2,11 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import itertools
from collections.abc import Iterable
from dataclasses import dataclass

from vllm.logger import init_logger
-from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.logprobs import (
+    PromptLogprobs,
+    SampleLogprobs,
+    append_logprobs_for_next_position,
+    create_prompt_logprobs,
+    create_sample_logprobs,
+)
from vllm.transformers_utils.detokenizer_utils import (
    AnyTokenizer,
    convert_ids_list_to_tokens,
@@ -44,9 +49,10 @@ class LogprobsProcessor:
        return cls(
            tokenizer=tokenizer,
            cumulative_logprob=(None if num_logprobs is None else 0.0),
-           logprobs=(None if num_logprobs is None else []),
-           # NOTE: logprob of first prompt token is None.
-           prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
+           logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+           prompt_logprobs=(
+               None if num_prompt_logprobs is None else create_prompt_logprobs()
+           ),
            num_prompt_logprobs=num_prompt_logprobs,
            num_logprobs=num_logprobs,
        )
@@ -80,15 +86,14 @@
            sampled_token_logprob = logprobs[0]
            self.cumulative_logprob += sampled_token_logprob

-           # Update with the Logprob dictionary for this pos.
-           self.logprobs.append(
-               self._make_logprob_dict(
-                   logprobs,
-                   token_ids,
-                   decoded_tokens,
-                   rank,
-                   self.num_logprobs,
-               )
+           # Update with the Logprob container for this pos.
+           append_logprobs_for_next_position(
+               self.logprobs,
+               token_ids,
+               logprobs,
+               decoded_tokens,
+               rank,
+               self.num_logprobs,
            )
    def _update_prompt_logprobs(
@@ -136,15 +141,14 @@
                NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
            )

-           # Update with the Logprob dictionary for this pos.
-           self.prompt_logprobs.append(
-               self._make_logprob_dict(
-                   prompt_logprobs[pos],
-                   token_ids[pos],
-                   decoded_tokens_for_pos,
-                   prompt_token_ranks[pos],
-                   self.num_prompt_logprobs,
-               )
+           # Update with the Logprob container for this pos.
+           append_logprobs_for_next_position(
+               self.prompt_logprobs,
+               token_ids[pos],
+               prompt_logprobs[pos],
+               decoded_tokens_for_pos,
+               prompt_token_ranks[pos],
+               self.num_prompt_logprobs,
            )
    def pop_prompt_logprobs(self) -> PromptLogprobs | None:
@@ -166,46 +170,6 @@
        self.prompt_logprobs = []
        return plp

-   @staticmethod
-   def _make_logprob_dict(
-       logprobs: list[float],
-       logprob_token_ids: list[int],
-       decoded_tokens: Iterable[str | None],
-       rank: int,
-       num_logprobs: int,
-   ) -> dict[int, Logprob]:
-       """Make a Logprob dictionary for a position.
-
-       Args:
-         logprobs: list of log probabilities
-         logprob_token_ids: list of top token ids
-         decoded_tokens: list of decoded top tokens
-         rank: rank of the sampled token
-         num_logprobs: number of logprobs requested
-           by the user (in addition to sampled logprob)
-
-       Returns:
-         dict[token id, Logprob]
-       """
-       if num_logprobs == -1:
-           num_logprobs = len(logprobs)
-       # We do not need a special case for the sampled token
-       # being in the topk, since inserting duplicated data
-       # into a dictionary twice is the same as doing it once.
-       topk_ranks = range(1, num_logprobs + 1)
-       ranks = itertools.chain((rank,), topk_ranks)
-       return {
-           token_id: Logprob(
-               logprob=logprob,
-               rank=rank,
-               decoded_token=token,
-           )
-           for token_id, logprob, rank, token in zip(
-               logprob_token_ids, logprobs, ranks, decoded_tokens
-           )
-       }

    def update_from_output(self, output: EngineCoreOutput) -> None:
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
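
The GC win claimed in the commit title comes from object counts: for N positions with K ranks each, the old representation allocates on the order of N dicts plus N*K Logprob instances, all GC-tracked, while FlattenLogprobs keeps a constant handful of lists. A rough way to observe this (a sketch, not a rigorous benchmark; exact counts vary by Python version):

import gc

from vllm.logprobs import FlattenLogprobs, Logprob

N_POSITIONS, N_RANKS = 256, 20

def build_dicts() -> list[dict[int, Logprob]]:
    # ~N_POSITIONS dicts + N_POSITIONS * N_RANKS Logprob objects
    return [
        {t: Logprob(logprob=-1.0, rank=t + 1, decoded_token=str(t)) for t in range(N_RANKS)}
        for _ in range(N_POSITIONS)
    ]

def build_flatten() -> FlattenLogprobs:
    # Six flat lists, regardless of N_POSITIONS and N_RANKS
    flat = FlattenLogprobs()
    flat.extend(build_dicts())  # intermediate dicts are freed afterwards
    return flat

for build in (build_dicts, build_flatten):
    gc.collect()
    before = len(gc.get_objects())
    kept = build()
    print(build.__name__, "retains ~", len(gc.get_objects()) - before, "tracked objects")
    del kept
    gc.collect()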