[Misc] FlattenLogprobs -> FlatLogprobs (#28335)

This commit is contained in:
Zhuohan Li 2025-11-10 19:41:23 -08:00 committed by GitHub
parent 57201a6a4c
commit 8d706cca90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 43 additions and 47 deletions

View File

@ -4,7 +4,7 @@
import pytest import pytest
from vllm import SamplingParams from vllm import SamplingParams
from vllm.logprobs import FlattenLogprobs from vllm.logprobs import FlatLogprobs
MODELS = ["distilbert/distilgpt2"] MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5 MAX_TOKENS = 5
@ -16,17 +16,17 @@ MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("greedy", [True, False]) @pytest.mark.parametrize("greedy", [True, False])
@pytest.mark.parametrize("flatten_logprobs", [True, False]) @pytest.mark.parametrize("flat_logprobs", [True, False])
def test_ranks( def test_ranks(
vllm_runner, vllm_runner,
model, model,
dtype, dtype,
greedy, greedy,
flatten_logprobs, flat_logprobs,
example_prompts, example_prompts,
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
): ):
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0") monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model: with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
tokenizer = vllm_model.llm.get_tokenizer() tokenizer = vllm_model.llm.get_tokenizer()
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts] example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@ -44,12 +44,8 @@ def test_ranks(
decode_tokens, _, decode_logprobs, prompt_logprobs = result decode_tokens, _, decode_logprobs, prompt_logprobs = result
# Ensure the return type of logprobs is accurate # Ensure the return type of logprobs is accurate
assert isinstance( assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
prompt_logprobs, FlattenLogprobs if flatten_logprobs else list assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)
)
assert isinstance(
decode_logprobs, FlattenLogprobs if flatten_logprobs else list
)
######################## ########################
# Check prompt logprobs # Check prompt logprobs

View File

@ -5,7 +5,7 @@
import pytest import pytest
from vllm.logprobs import ( from vllm.logprobs import (
FlattenLogprobs, FlatLogprobs,
Logprob, Logprob,
LogprobsOnePosition, LogprobsOnePosition,
append_logprobs_for_next_position, append_logprobs_for_next_position,
@ -14,8 +14,8 @@ from vllm.logprobs import (
) )
def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None: def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0") monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
prompt_logprobs = create_prompt_logprobs() prompt_logprobs = create_prompt_logprobs()
assert isinstance(prompt_logprobs, list) assert isinstance(prompt_logprobs, list)
@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(sample_logprobs) == 0 assert len(sample_logprobs) == 0
def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None: def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
prompt_logprobs = create_prompt_logprobs() prompt_logprobs = create_prompt_logprobs()
assert isinstance(prompt_logprobs, FlattenLogprobs) assert isinstance(prompt_logprobs, FlatLogprobs)
assert prompt_logprobs.start_indices == [0] assert prompt_logprobs.start_indices == [0]
assert prompt_logprobs.end_indices == [0] assert prompt_logprobs.end_indices == [0]
assert len(prompt_logprobs.token_ids) == 0 assert len(prompt_logprobs.token_ids) == 0
@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert prompt_logprobs[0] == dict() assert prompt_logprobs[0] == dict()
sample_logprobs = create_sample_logprobs() sample_logprobs = create_sample_logprobs()
assert isinstance(sample_logprobs, FlattenLogprobs) assert isinstance(sample_logprobs, FlatLogprobs)
assert len(sample_logprobs.start_indices) == 0 assert len(sample_logprobs.start_indices) == 0
assert len(sample_logprobs.end_indices) == 0 assert len(sample_logprobs.end_indices) == 0
assert len(sample_logprobs.token_ids) == 0 assert len(sample_logprobs.token_ids) == 0
@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(sample_logprobs) == 0 assert len(sample_logprobs) == 0
def test_append_logprobs_for_next_position_none_flatten( def test_append_logprobs_for_next_position_none_flat(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0") monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
logprobs = create_sample_logprobs() logprobs = create_sample_logprobs()
append_logprobs_for_next_position( append_logprobs_for_next_position(
logprobs, logprobs,
@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
] ]
def test_append_logprobs_for_next_position_flatten( def test_append_logprobs_for_next_position_flat(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
) -> None: ) -> None:
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1") monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
logprobs = create_sample_logprobs() logprobs = create_sample_logprobs()
append_logprobs_for_next_position( append_logprobs_for_next_position(
logprobs, logprobs,
@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
rank=11, rank=11,
num_logprobs=-1, num_logprobs=-1,
) )
assert isinstance(logprobs, FlattenLogprobs) assert isinstance(logprobs, FlatLogprobs)
assert logprobs.start_indices == [0, 1] assert logprobs.start_indices == [0, 1]
assert logprobs.end_indices == [1, 3] assert logprobs.end_indices == [1, 3]
assert logprobs.token_ids == [1, 2, 3] assert logprobs.token_ids == [1, 2, 3]
@ -129,8 +129,8 @@ LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
} }
def test_flatten_logprobs_append() -> None: def test_flat_logprobs_append() -> None:
logprobs = FlattenLogprobs() logprobs = FlatLogprobs()
logprobs.append(LOGPROBS_ONE_POSITION_0) logprobs.append(LOGPROBS_ONE_POSITION_0)
logprobs.append(LOGPROBS_ONE_POSITION_1) logprobs.append(LOGPROBS_ONE_POSITION_1)
assert logprobs.start_indices == [0, 1] assert logprobs.start_indices == [0, 1]
@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"] assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]
def test_flatten_logprobs_extend() -> None: def test_flat_logprobs_extend() -> None:
logprobs = FlattenLogprobs() logprobs = FlatLogprobs()
# Extend with list[LogprobsOnePosition] # Extend with list[LogprobsOnePosition]
logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]) logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
assert logprobs.start_indices == [0, 3] assert logprobs.start_indices == [0, 3]
@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
assert logprobs.ranks == [40, 50, 60, 10] assert logprobs.ranks == [40, 50, 60, 10]
assert logprobs.decoded_tokens == ["40", "50", "60", "10"] assert logprobs.decoded_tokens == ["40", "50", "60", "10"]
other_logprobs = FlattenLogprobs() other_logprobs = FlatLogprobs()
other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0]) other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
# Extend with another FlattenLogprobs # Extend with another FlatLogprobs
logprobs.extend(other_logprobs) logprobs.extend(other_logprobs)
assert logprobs.start_indices == [0, 3, 4, 6] assert logprobs.start_indices == [0, 3, 4, 6]
assert logprobs.end_indices == [3, 4, 6, 7] assert logprobs.end_indices == [3, 4, 6, 7]
@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"] assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]
def test_flatten_logprobs_access() -> None: def test_flat_logprobs_access() -> None:
logprobs = FlattenLogprobs() logprobs = FlatLogprobs()
logprobs.extend( logprobs.extend(
[LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0] [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
) )

View File

@ -223,7 +223,7 @@ if TYPE_CHECKING:
VLLM_GC_DEBUG: str = "" VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_FLATTEN_LOGPROBS: bool = False VLLM_FLAT_LOGPROBS: bool = False
def get_default_cache_root(): def get_default_cache_root():
@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices( "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"] "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
), ),
# Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
# the original list[dict[int, Logprob]] approach. # the original list[dict[int, Logprob]] approach.
# After enabled, PromptLogprobs and SampleLogprobs would populated as # After enabled, PromptLogprobs and SampleLogprobs would populated as
# FlattenLogprobs. # FlatLogprobs.
"VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))), "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]

View File

@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]
@dataclass @dataclass
class FlattenLogprobs(MutableSequence[LogprobsOnePosition]): class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
""" """
Flatten logprobs of a request into multiple primitive type lists. Flat logprobs of a request into multiple primitive type lists.
Compared to list[dict[int, Logprob]], this data structure reduced GC Compared to list[dict[int, Logprob]], this data structure reduced GC
overhead significantly. As it flattened logprob information for overhead significantly. As it flattened logprob information for
all positions and ranks in to multiple primitive type lists (i.e. all positions and ranks in to multiple primitive type lists (i.e.
logprobs, token_ids, ranks per token_ids, decoded_tokens). logprobs, token_ids, ranks per token_ids, decoded_tokens).
So regardless of the sequence length and top_logprobs setup, So regardless of the sequence length and top_logprobs setup,
FlattenLogprobs would only introduce a constant amount of objects. FlatLogprobs would only introduce a constant amount of objects.
As each position might contains different amount of ranks, As each position might contains different amount of ranks,
start_indices_per_position would be used to access the logprob ranges start_indices_per_position would be used to access the logprob ranges
@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
def __getitem__(self, position: int) -> LogprobsOnePosition: ... def __getitem__(self, position: int) -> LogprobsOnePosition: ...
@overload @overload
def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ... def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...
def __getitem__(self, index: int | slice): def __getitem__(self, index: int | slice):
"""Extracts logprobs of a given position or slice""" """Extracts logprobs of a given position or slice"""
@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
elif isinstance(index, slice): elif isinstance(index, slice):
min_index = self.start_indices[index][0] min_index = self.start_indices[index][0]
max_index = self.end_indices[index][-1] max_index = self.end_indices[index][-1]
return FlattenLogprobs( return FlatLogprobs(
# Shift updated start_indices and end_indices to # Shift updated start_indices and end_indices to
# be 0-indexed # be 0-indexed
start_indices=[i - min_index for i in self.start_indices[index]], start_indices=[i - min_index for i in self.start_indices[index]],
@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
raise TypeError(f"Invalid index type: {type(index)}") raise TypeError(f"Invalid index type: {type(index)}")
def __setitem__(self, item, value) -> None: def __setitem__(self, item, value) -> None:
raise TypeError("Cannot set logprobs in FlattenLogprobs") raise TypeError("Cannot set logprobs in FlatLogprobs")
def __delitem__(self, item) -> None: def __delitem__(self, item) -> None:
raise TypeError("Cannot delete logprobs from FlattenLogprobs") raise TypeError("Cannot delete logprobs from FlatLogprobs")
def insert(self, item) -> None: def insert(self, item) -> None:
raise TypeError("Cannot insert logprobs to FlattenLogprobs") raise TypeError("Cannot insert logprobs to FlatLogprobs")
def __iter__(self) -> Iterator[LogprobsOnePosition]: def __iter__(self) -> Iterator[LogprobsOnePosition]:
""" """
@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
# {token_id -> logprob} per each sequence group. None if the corresponding # {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob. # sequence group doesn't require prompt logprob.
PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None] PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
# {token_id -> logprob} for each sequence group. # {token_id -> logprob} for each sequence group.
SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition] SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
def create_prompt_logprobs() -> PromptLogprobs: def create_prompt_logprobs() -> PromptLogprobs:
"""Creates a container to store prompt logprobs for a request""" """Creates a container to store prompt logprobs for a request"""
logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
# NOTE: logprob of first prompt token is None. # NOTE: logprob of first prompt token is None.
logprobs.append(None) logprobs.append(None)
return logprobs return logprobs
@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:
def create_sample_logprobs() -> SampleLogprobs: def create_sample_logprobs() -> SampleLogprobs:
"""Creates a container to store decode logprobs for a request""" """Creates a container to store decode logprobs for a request"""
return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else [] return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
def append_logprobs_for_next_position( def append_logprobs_for_next_position(
@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
topk_ranks = range(1, num_logprobs + 1) topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank,), topk_ranks) ranks = itertools.chain((rank,), topk_ranks)
if isinstance(request_logprobs, FlattenLogprobs): if isinstance(request_logprobs, FlatLogprobs):
request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens) request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
else: else:
request_logprobs.append( request_logprobs.append(