From ccd98b59c15b50478bb214d6824889570c6f4b8c Mon Sep 17 00:00:00 2001
From: Jialin Ouyang
Date: Fri, 7 Nov 2025 00:27:12 -0800
Subject: [PATCH] [Perf] Introduce FlattenLogprobs to store logprobs results to
 reduce GC overhead (#28171)

Signed-off-by: Jialin Ouyang
---
 tests/samplers/test_logprobs.py |  96 ++++++++++++++
 tests/samplers/test_ranks.py    |  59 ---------
 tests/test_logprobs.py          | 222 ++++++++++++++++++++++++++++++++
 vllm/envs.py                    |   6 +
 vllm/logprobs.py                | 186 +++++++++++++++++++++++++-
 vllm/v1/engine/logprobs.py      |  90 ++++---------
 6 files changed, 534 insertions(+), 125 deletions(-)
 create mode 100644 tests/samplers/test_logprobs.py
 delete mode 100644 tests/samplers/test_ranks.py
 create mode 100644 tests/test_logprobs.py

diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
new file mode 100644
index 0000000000000..87f5d40ac1da7
--- /dev/null
+++ b/tests/samplers/test_logprobs.py
@@ -0,0 +1,96 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm import SamplingParams
+from vllm.logprobs import FlattenLogprobs
+
+MODELS = ["distilbert/distilgpt2"]
+MAX_TOKENS = 5
+NUM_TOP_LOGPROBS = 5
+NUM_PROMPT_LOGPROBS = 7
+MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("greedy", [True, False])
+@pytest.mark.parametrize("flatten_logprobs", [True, False])
+def test_ranks(
+    vllm_runner,
+    model,
+    dtype,
+    greedy,
+    flatten_logprobs,
+    example_prompts,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
+    with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
+        tokenizer = vllm_model.llm.get_tokenizer()
+        example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
+        sampling_params = SamplingParams(
+            temperature=0.0 if greedy else 1.0,
+            top_p=1.0,
+            max_tokens=MAX_TOKENS,
+            logprobs=NUM_TOP_LOGPROBS,
+            prompt_logprobs=NUM_PROMPT_LOGPROBS,
+        )
+        results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
+
+    assert len(results) == len(example_prompt_tokens)
+    for i, (result, prompt_tokens) in enumerate(zip(results, example_prompt_tokens)):
+        decode_tokens, _, decode_logprobs, prompt_logprobs = result
+
+        # Ensure the return type of logprobs is accurate
+        assert isinstance(
+            prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
+        )
+        assert isinstance(
+            decode_logprobs, FlattenLogprobs if flatten_logprobs else list
+        )
+
+        ########################
+        # Check prompt logprobs
+        ########################
+        assert len(prompt_tokens) == len(prompt_logprobs)
+        # No logprob for the first prompt token
+        assert not prompt_logprobs[0]
+        for position, (token, logprobs) in enumerate(
+            zip(prompt_tokens[1:], prompt_logprobs[1:]), start=1
+        ):
+            # Ensure the logprob of the prompt token is always returned
+            logprob = logprobs.get(token)
+            assert logprob is not None
+            assert logprob.rank >= 1
+            # Ensure the number of returned logprobs is either
+            # NUM_PROMPT_LOGPROBS or NUM_PROMPT_LOGPROBS + 1
+            assert NUM_PROMPT_LOGPROBS <= len(logprobs) <= NUM_PROMPT_LOGPROBS + 1
+            # Ensure the top NUM_PROMPT_LOGPROBS ranks are always extracted
+            assert set(range(1, NUM_PROMPT_LOGPROBS + 1)).issubset(
+                {logprob.rank for logprob in logprobs.values()}
+            )
+
+        ########################
+        # Check sample logprobs
+        ########################
+        assert len(decode_tokens) == len(decode_logprobs)
+        for position, (token, logprobs) in enumerate(
+            zip(decode_tokens, decode_logprobs)
+        ):
+            # Ensure the logprob of the chosen token is always returned
+            logprob = logprobs.get(token)
+            assert logprob is not None
+            if greedy:
+                # For greedy sampling, the chosen token should always be top ranked
+                assert logprob.rank == 1
+            else:
+                assert logprob.rank >= 1
+            # Ensure the number of returned logprobs is either
+            # NUM_TOP_LOGPROBS or NUM_TOP_LOGPROBS + 1
+            assert NUM_TOP_LOGPROBS <= len(logprobs) <= NUM_TOP_LOGPROBS + 1
+            # Ensure the top NUM_TOP_LOGPROBS ranks are always extracted
+            assert set(range(1, NUM_TOP_LOGPROBS + 1)).issubset(
+                {logprob.rank for logprob in logprobs.values()}
+            )
diff --git a/tests/samplers/test_ranks.py b/tests/samplers/test_ranks.py
deleted file mode 100644
index 1359e6403e4c3..0000000000000
--- a/tests/samplers/test_ranks.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import pytest
-
-from vllm import SamplingParams
-
-MODELS = ["distilbert/distilgpt2"]
-
-
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
-def test_ranks(
-    vllm_runner,
-    model,
-    dtype,
-    example_prompts,
-):
-    max_tokens = 5
-    num_top_logprobs = 5
-    num_prompt_logprobs = 5
-
-    with vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) as vllm_model:
-        ## Test greedy logprobs ranks
-        vllm_sampling_params = SamplingParams(
-            temperature=0.0,
-            top_p=1.0,
-            max_tokens=max_tokens,
-            logprobs=num_top_logprobs,
-            prompt_logprobs=num_prompt_logprobs,
-        )
-        vllm_results = vllm_model.generate_w_logprobs(
-            example_prompts, vllm_sampling_params
-        )
-
-        ## Test non-greedy logprobs ranks
-        sampling_params = SamplingParams(
-            temperature=1.0,
-            top_p=1.0,
-            max_tokens=max_tokens,
-            logprobs=num_top_logprobs,
-            prompt_logprobs=num_prompt_logprobs,
-        )
-        res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
-
-        for result in vllm_results:
-            assert result[2] is not None
-            assert len(result[2]) == len(result[0])
-            # check whether all chosen tokens have ranks = 1
-            for token, logprobs in zip(result[0], result[2]):
-                assert token in logprobs
-                assert logprobs[token].rank == 1
-
-        for result in res:
-            assert result[2] is not None
-            assert len(result[2]) == len(result[0])
-            # check whether all chosen tokens have ranks
-            for token, logprobs in zip(result[0], result[2]):
-                assert logprobs[token].rank >= 1
diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py
new file mode 100644
index 0000000000000..1799d36381786
--- /dev/null
+++ b/tests/test_logprobs.py
@@ -0,0 +1,222 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import pytest
+
+from vllm.logprobs import (
+    FlattenLogprobs,
+    Logprob,
+    LogprobsOnePosition,
+    append_logprobs_for_next_position,
+    create_prompt_logprobs,
+    create_sample_logprobs,
+)
+
+
+def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+
+    prompt_logprobs = create_prompt_logprobs()
+    assert isinstance(prompt_logprobs, list)
+    # Ensure the first prompt position's logprobs entry is None
+    assert len(prompt_logprobs) == 1
+    assert prompt_logprobs[0] is None
+
+    sample_logprobs = create_sample_logprobs()
+    assert isinstance(sample_logprobs, list)
+    assert len(sample_logprobs) == 0
+
+
+def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+
+    prompt_logprobs = create_prompt_logprobs()
+    assert isinstance(prompt_logprobs, FlattenLogprobs)
+    assert prompt_logprobs.start_indices == [0]
+    assert prompt_logprobs.end_indices == [0]
+    assert len(prompt_logprobs.token_ids) == 0
+    assert len(prompt_logprobs.logprobs) == 0
+    assert len(prompt_logprobs.ranks) == 0
+    assert len(prompt_logprobs.decoded_tokens) == 0
+    # Ensure the first prompt position's logprobs are empty
+    assert len(prompt_logprobs) == 1
+    assert prompt_logprobs[0] == dict()
+
+    sample_logprobs = create_sample_logprobs()
+    assert isinstance(sample_logprobs, FlattenLogprobs)
+    assert len(sample_logprobs.start_indices) == 0
+    assert len(sample_logprobs.end_indices) == 0
+    assert len(sample_logprobs.token_ids) == 0
+    assert len(sample_logprobs.logprobs) == 0
+    assert len(sample_logprobs.ranks) == 0
+    assert len(sample_logprobs.decoded_tokens) == 0
+    assert len(sample_logprobs) == 0
+
+
+def test_append_logprobs_for_next_position_non_flatten(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+    logprobs = create_sample_logprobs()
+    append_logprobs_for_next_position(
+        logprobs,
+        token_ids=[1],
+        logprobs=[0.1],
+        decoded_tokens=["1"],
+        rank=10,
+        num_logprobs=-1,
+    )
+    append_logprobs_for_next_position(
+        logprobs,
+        token_ids=[2, 3],
+        logprobs=[0.2, 0.3],
+        decoded_tokens=["2", "3"],
+        rank=11,
+        num_logprobs=-1,
+    )
+    assert isinstance(logprobs, list)
+    assert logprobs == [
+        {1: Logprob(logprob=0.1, rank=10, decoded_token="1")},
+        {
+            2: Logprob(logprob=0.2, rank=11, decoded_token="2"),
+            3: Logprob(logprob=0.3, rank=1, decoded_token="3"),
+        },
+    ]
+
+
+def test_append_logprobs_for_next_position_flatten(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+    logprobs = create_sample_logprobs()
+    append_logprobs_for_next_position(
+        logprobs,
+        token_ids=[1],
+        logprobs=[0.1],
+        decoded_tokens=["1"],
+        rank=10,
+        num_logprobs=-1,
+    )
+    append_logprobs_for_next_position(
+        logprobs,
+        token_ids=[2, 3],
+        logprobs=[0.2, 0.3],
+        decoded_tokens=["2", "3"],
+        rank=11,
+        num_logprobs=-1,
+    )
+    assert isinstance(logprobs, FlattenLogprobs)
+    assert logprobs.start_indices == [0, 1]
+    assert logprobs.end_indices == [1, 3]
+    assert logprobs.token_ids == [1, 2, 3]
+    assert logprobs.logprobs == [0.1, 0.2, 0.3]
+    assert logprobs.ranks == [10, 11, 1]
+    assert logprobs.decoded_tokens == ["1", "2", "3"]
+
+
+LOGPROBS_ONE_POSITION_0: LogprobsOnePosition = {
+    1: Logprob(logprob=0.1, rank=10, decoded_token="10")
+}
+LOGPROBS_ONE_POSITION_1: LogprobsOnePosition = {
+    2: Logprob(logprob=0.2, rank=20, decoded_token="20"),
+    3: Logprob(logprob=0.3, rank=30, decoded_token="30"),
+}
+LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
+    4: Logprob(logprob=0.4, rank=40, decoded_token="40"),
+    5: Logprob(logprob=0.5, rank=50, decoded_token="50"),
+    6: Logprob(logprob=0.6, rank=60, decoded_token="60"),
+}
+
+
+def test_flatten_logprobs_append() -> None:
+    logprobs = FlattenLogprobs()
+    logprobs.append(LOGPROBS_ONE_POSITION_0)
+    logprobs.append(LOGPROBS_ONE_POSITION_1)
+    assert logprobs.start_indices == [0, 1]
+    assert logprobs.end_indices == [1, 3]
+    assert logprobs.token_ids == [1, 2, 3]
+    assert logprobs.logprobs == [0.1, 0.2, 0.3]
+    assert logprobs.ranks == [10, 20, 30]
+    assert logprobs.decoded_tokens == ["10", "20", "30"]
+
+    logprobs.append(LOGPROBS_ONE_POSITION_2)
+    assert logprobs.start_indices == [0, 1, 3]
+    assert logprobs.end_indices == [1, 3, 6]
+    assert logprobs.token_ids == [1, 2, 3, 4, 5, 6]
+    assert logprobs.logprobs == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
+    assert logprobs.ranks == [10, 20, 30, 40, 50, 60]
+    assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]
+
+
+def test_flatten_logprobs_extend() -> None:
+    logprobs = FlattenLogprobs()
+    # Extend with list[LogprobsOnePosition]
+    logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
+    assert logprobs.start_indices == [0, 3]
+    assert logprobs.end_indices == [3, 4]
+    assert logprobs.token_ids == [4, 5, 6, 1]
+    assert logprobs.logprobs == [0.4, 0.5, 0.6, 0.1]
+    assert logprobs.ranks == [40, 50, 60, 10]
+    assert logprobs.decoded_tokens == ["40", "50", "60", "10"]
+
+    other_logprobs = FlattenLogprobs()
+    other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
+    # Extend with another FlattenLogprobs
+    logprobs.extend(other_logprobs)
+    assert logprobs.start_indices == [0, 3, 4, 6]
+    assert logprobs.end_indices == [3, 4, 6, 7]
+    assert logprobs.token_ids == [4, 5, 6, 1, 2, 3, 1]
+    assert logprobs.logprobs == [0.4, 0.5, 0.6, 0.1, 0.2, 0.3, 0.1]
+    assert logprobs.ranks == [40, 50, 60, 10, 20, 30, 10]
+    assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]
+
+
+def test_flatten_logprobs_access() -> None:
+    logprobs = FlattenLogprobs()
+    logprobs.extend(
+        [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
+    )
+    assert logprobs.start_indices == [0, 2, 5]
+    assert logprobs.end_indices == [2, 5, 6]
+    assert logprobs.token_ids == [2, 3, 4, 5, 6, 1]
+    assert logprobs.logprobs == [0.2, 0.3, 0.4, 0.5, 0.6, 0.1]
+    assert logprobs.ranks == [20, 30, 40, 50, 60, 10]
+    assert logprobs.decoded_tokens == ["20", "30", "40", "50", "60", "10"]
+
+    # Test __len__
+    assert len(logprobs) == 3
+
+    # Test __iter__
+    for actual_logprobs, expected_logprobs in zip(
+        logprobs,
+        [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0],
+    ):
+        assert actual_logprobs == expected_logprobs
+
+    # Test __getitem__: single item
+    assert logprobs[0] == LOGPROBS_ONE_POSITION_1
+    assert logprobs[1] == LOGPROBS_ONE_POSITION_2
+    assert logprobs[2] == LOGPROBS_ONE_POSITION_0
+
+    # Test __getitem__: slice
+    logprobs02 = logprobs[:2]
+    assert len(logprobs02) == 2
+    assert logprobs02[0] == LOGPROBS_ONE_POSITION_1
+    assert logprobs02[1] == LOGPROBS_ONE_POSITION_2
+    assert logprobs02.start_indices == [0, 2]
+    assert logprobs02.end_indices == [2, 5]
+    assert logprobs02.token_ids == [2, 3, 4, 5, 6]
+    assert logprobs02.logprobs == [0.2, 0.3, 0.4, 0.5, 0.6]
+    assert logprobs02.ranks == [20, 30, 40, 50, 60]
+    assert logprobs02.decoded_tokens == ["20", "30", "40", "50", "60"]
+
+    logprobs_last2 = logprobs[-2:]
+    assert len(logprobs_last2) == 2
+    assert logprobs_last2[0] == LOGPROBS_ONE_POSITION_2
+    assert logprobs_last2[1] == LOGPROBS_ONE_POSITION_0
+    assert logprobs_last2.start_indices == [0, 3]
+    assert logprobs_last2.end_indices == [3, 4]
+    assert logprobs_last2.token_ids == [4, 5, 6, 1]
+    assert logprobs_last2.logprobs == [0.4, 0.5, 0.6, 0.1]
+    assert logprobs_last2.ranks == [40, 50, 60, 10]
+    assert logprobs_last2.decoded_tokens == ["40", "50", "60", "10"]
diff --git a/vllm/envs.py b/vllm/envs.py
index 99f2ad2bc3d00..eb50ea6e5dbe5 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -220,6 +220,7 @@ if TYPE_CHECKING:
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
+    VLLM_FLATTEN_LOGPROBS: bool = False
 
 
 def get_default_cache_root():
@@ -1463,6 +1464,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
+    # Flag to enable FlattenLogprobs, whose GC overhead is significantly
+    # smaller than that of the original list[dict[int, Logprob]] approach.
+    # When enabled, PromptLogprobs and SampleLogprobs are populated as
+    # FlattenLogprobs.
+    "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
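Because every entry in environment_variables is a lambda, the flag is re-read on
each envs.VLLM_FLATTEN_LOGPROBS access, so it can be toggled per test via
monkeypatch (as the new tests do) or set in the shell before starting vLLM. A
minimal sketch of the opt-in flow, illustrative rather than part of the change,
using only names introduced in this patch and mirroring
test_create_logprobs_flatten:

    import os

    # Equivalent to launching the process with VLLM_FLATTEN_LOGPROBS=1.
    os.environ["VLLM_FLATTEN_LOGPROBS"] = "1"

    from vllm.logprobs import FlattenLogprobs, create_sample_logprobs

    # With the flag on, the factory helpers added in vllm/logprobs.py (below)
    # return the flattened container; with it off (the default), a plain list.
    assert isinstance(create_sample_logprobs(), FlattenLogprobs)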
+ """ + self.start_indices.append(len(self.logprobs)) + for token_id, logprob, rank, decoded_token in zip( + token_ids, logprobs, ranks, decoded_tokens + ): + self.token_ids.append(token_id) + self.logprobs.append(logprob) + self.ranks.append(rank) + self.decoded_tokens.append(decoded_token) + self.end_indices.append(len(self.logprobs)) + + def extend(self, logprobs_multi_positions) -> None: + """Extends the container with logprobs for the next multiple positions""" + for logprobs_one_position in logprobs_multi_positions: + self.append(logprobs_one_position) + + def __len__(self) -> int: + """Gets number of positions stored in the container""" + return len(self.start_indices) + + @overload + def __getitem__(self, position: int) -> LogprobsOnePosition: ... + + @overload + def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ... + + def __getitem__(self, index: int | slice): + """Extracts logprobs of a given position or slice""" + if isinstance(index, int): + return { + self.token_ids[i]: Logprob( + logprob=self.logprobs[i], + rank=self.ranks[i], + decoded_token=self.decoded_tokens[i], + ) + for i in range(self.start_indices[index], self.end_indices[index]) + } + elif isinstance(index, slice): + min_index = self.start_indices[index][0] + max_index = self.end_indices[index][-1] + return FlattenLogprobs( + # Shift updated start_indices and end_indices to + # be 0-indexed + start_indices=[i - min_index for i in self.start_indices[index]], + end_indices=[i - min_index for i in self.end_indices[index]], + token_ids=self.token_ids[min_index:max_index], + logprobs=self.logprobs[min_index:max_index], + ranks=self.ranks[min_index:max_index], + decoded_tokens=self.decoded_tokens[min_index:max_index], + ) + else: + raise TypeError(f"Invalid index type: {type(index)}") + + def __setitem__(self, item, value) -> None: + raise TypeError("Cannot set logprobs in FlattenLogprobs") + + def __delitem__(self, item) -> None: + raise TypeError("Cannot delete logprobs from FlattenLogprobs") + + def insert(self, item) -> None: + raise TypeError("Cannot insert logprobs to FlattenLogprobs") + + def __iter__(self) -> Iterator[LogprobsOnePosition]: + """ + Iterates the container and yields LogprobsOnePosition for + each position. + """ + for i in range(0, len(self.start_indices)): + yield self.__getitem__(i) + + # {token_id -> logprob} per each sequence group. None if the corresponding # sequence group doesn't require prompt logprob. -PromptLogprobs = list[dict[int, Logprob] | None] +PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None] # {token_id -> logprob} for each sequence group. -SampleLogprobs = list[dict[int, Logprob]] +SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition] + + +def create_prompt_logprobs() -> PromptLogprobs: + """Creates a container to store prompt logprobs for a request""" + logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] + # NOTE: logprob of first prompt token is None. 
+    logprobs.append(None)
+    return logprobs
+
+
+def create_sample_logprobs() -> SampleLogprobs:
+    """Creates a container to store decode logprobs for a request"""
+    return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+
+
+def append_logprobs_for_next_position(
+    request_logprobs: PromptLogprobs | SampleLogprobs,
+    token_ids: list[int],
+    logprobs: list[float],
+    decoded_tokens: Iterable[str | None],
+    rank: int,
+    num_logprobs: int,
+) -> None:
+    """Appends logprobs for the next position"""
+    if num_logprobs == -1:
+        num_logprobs = len(logprobs)
+    # We do not need a special case for the sampled token
+    # being in the topk, since inserting duplicated data
+    # into a dictionary twice is the same as doing it once.
+    topk_ranks = range(1, num_logprobs + 1)
+    ranks = itertools.chain((rank,), topk_ranks)
+
+    if isinstance(request_logprobs, FlattenLogprobs):
+        request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
+    else:
+        request_logprobs.append(
+            {
+                token_id: Logprob(
+                    logprob=logprob,
+                    rank=rank,
+                    decoded_token=token,
+                )
+                for token_id, logprob, rank, token in zip(
+                    token_ids, logprobs, ranks, decoded_tokens
+                )
+            }
+        )
diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py
index 48bb5312f5d94..4c5955d7ee2e5 100644
--- a/vllm/v1/engine/logprobs.py
+++ b/vllm/v1/engine/logprobs.py
@@ -2,11 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import itertools
-from collections.abc import Iterable
 from dataclasses import dataclass
 
 from vllm.logger import init_logger
-from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.logprobs import (
+    PromptLogprobs,
+    SampleLogprobs,
+    append_logprobs_for_next_position,
+    create_prompt_logprobs,
+    create_sample_logprobs,
+)
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer,
     convert_ids_list_to_tokens,
@@ -44,9 +49,10 @@ class LogprobsProcessor:
         return cls(
             tokenizer=tokenizer,
             cumulative_logprob=(None if num_logprobs is None else 0.0),
-            logprobs=(None if num_logprobs is None else []),
-            # NOTE: logprob of first prompt token is None.
-            prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
+            logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+            prompt_logprobs=(
+                None if num_prompt_logprobs is None else create_prompt_logprobs()
+            ),
             num_prompt_logprobs=num_prompt_logprobs,
             num_logprobs=num_logprobs,
         )
@@ -80,15 +86,14 @@ class LogprobsProcessor:
             sampled_token_logprob = logprobs[0]
             self.cumulative_logprob += sampled_token_logprob
 
-        # Update with the Logprob dictionary for this pos.
-        self.logprobs.append(
-            self._make_logprob_dict(
-                logprobs,
-                token_ids,
-                decoded_tokens,
-                rank,
-                self.num_logprobs,
-            )
+        # Update with the Logprob container for this pos.
+        append_logprobs_for_next_position(
+            self.logprobs,
+            token_ids,
+            logprobs,
+            decoded_tokens,
+            rank,
+            self.num_logprobs,
         )
 
     def _update_prompt_logprobs(
@@ -136,15 +141,14 @@ class LogprobsProcessor:
             NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
         )
 
-        # Update with the Logprob dictionary for this pos.
-        self.prompt_logprobs.append(
-            self._make_logprob_dict(
-                prompt_logprobs[pos],
-                token_ids[pos],
-                decoded_tokens_for_pos,
-                prompt_token_ranks[pos],
-                self.num_prompt_logprobs,
-            )
+        # Update with the Logprob container for this pos.
+        append_logprobs_for_next_position(
+            self.prompt_logprobs,
+            token_ids[pos],
+            prompt_logprobs[pos],
+            decoded_tokens_for_pos,
+            prompt_token_ranks[pos],
+            self.num_prompt_logprobs,
         )
 
     def pop_prompt_logprobs(self) -> PromptLogprobs | None:
@@ -166,46 +170,6 @@ class LogprobsProcessor:
             self.prompt_logprobs = []
         return plp
 
-    @staticmethod
-    def _make_logprob_dict(
-        logprobs: list[float],
-        logprob_token_ids: list[int],
-        decoded_tokens: Iterable[str | None],
-        rank: int,
-        num_logprobs: int,
-    ) -> dict[int, Logprob]:
-        """Make a Logprob dictionary for a position.
-
-        Args:
-            logprobs: list of log probabilities
-            logprob_token_ids: list of top token ids
-            decoded_tokens: list of decoded top tokens
-            rank: rank of the sampled token
-            num_logprobs: number of logprobs requested
-                by the user (in addition to sampled logprob)
-
-        Returns:
-            dict[token id, Logprob]
-        """
-        if num_logprobs == -1:
-            num_logprobs = len(logprobs)
-        # We do not need a special case for the sampled token
-        # being in the topk, since inserting duplicated data
-        # into a dictionary twice is the same as doing it once.
-        topk_ranks = range(1, num_logprobs + 1)
-        ranks = itertools.chain((rank,), topk_ranks)
-
-        return {
-            token_id: Logprob(
-                logprob=logprob,
-                rank=rank,
-                decoded_token=token,
-            )
-            for token_id, logprob, rank, token in zip(
-                logprob_token_ids, logprobs, ranks, decoded_tokens
-            )
-        }
-
     def update_from_output(self, output: EngineCoreOutput) -> None:
         if output.new_logprobs is not None:
             self._update_sample_logprobs(output.new_logprobs)
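For intuition about the flattened layout: appending two positions stores all
(token, rank) entries in shared primitive lists, with start_indices and
end_indices delimiting each position, while sequence-style reads still yield
one dict per position. The values below follow the style of the new tests; the
snippet is an illustrative sketch, not part of the patch:

    from vllm.logprobs import FlattenLogprobs, Logprob

    flat = FlattenLogprobs()
    flat.append({1: Logprob(logprob=0.1, rank=10, decoded_token="10")})
    flat.append(
        {
            2: Logprob(logprob=0.2, rank=20, decoded_token="20"),
            3: Logprob(logprob=0.3, rank=30, decoded_token="30"),
        }
    )

    # Two positions, three entries total: position i spans
    # [start_indices[i], end_indices[i]) in each flat list, so a request
    # owns a fixed set of six lists instead of one dict plus one Logprob
    # object per (position, rank).
    assert flat.start_indices == [0, 1]
    assert flat.end_indices == [1, 3]
    assert flat.token_ids == [1, 2, 3]

    # __getitem__ reconstructs the per-position dict view on demand.
    assert flat[1] == {
        2: Logprob(logprob=0.2, rank=20, decoded_token="20"),
        3: Logprob(logprob=0.3, rank=30, decoded_token="30"),
    }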