[Misc] FlattenLogprobs -> FlatLogprobs (#28335)

Zhuohan Li 2025-11-10 19:41:23 -08:00 committed by GitHub
parent 57201a6a4c
commit 8d706cca90
4 changed files with 43 additions and 47 deletions
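The change is a pure rename: the class, the VLLM_FLATTEN_LOGPROBS environment variable, and the derived test and parameter names all move from "flatten" to "flat", with no behavioral change. For external code that must run on both sides of the rename, a hypothetical compatibility shim (illustrative only, not part of this commit) could look like:

    # Illustrative only: alias the new name on vLLM versions before this commit.
    try:
        from vllm.logprobs import FlatLogprobs
    except ImportError:  # older vLLM still exports the old name
        from vllm.logprobs import FlattenLogprobs as FlatLogprobs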


@@ -4,7 +4,7 @@
 import pytest

 from vllm import SamplingParams
-from vllm.logprobs import FlattenLogprobs
+from vllm.logprobs import FlatLogprobs

 MODELS = ["distilbert/distilgpt2"]
 MAX_TOKENS = 5
@@ -16,17 +16,17 @@ MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("greedy", [True, False])
-@pytest.mark.parametrize("flatten_logprobs", [True, False])
+@pytest.mark.parametrize("flat_logprobs", [True, False])
 def test_ranks(
     vllm_runner,
     model,
     dtype,
     greedy,
-    flatten_logprobs,
+    flat_logprobs,
     example_prompts,
     monkeypatch: pytest.MonkeyPatch,
 ):
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
     with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
         tokenizer = vllm_model.llm.get_tokenizer()
         example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -44,12 +44,8 @@ def test_ranks(
         decode_tokens, _, decode_logprobs, prompt_logprobs = result

         # Ensure the return type of logprobs is accurate
-        assert isinstance(
-            prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
-        )
-        assert isinstance(
-            decode_logprobs, FlattenLogprobs if flatten_logprobs else list
-        )
+        assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
+        assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)

         ########################
         # Check prompt logprobs
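The test above toggles VLLM_FLAT_LOGPROBS per parametrization and asserts on the container type of both prompt and decode logprobs. The same effect is visible from the offline API; a minimal sketch, assuming the standard LLM/RequestOutput interface and the test's own model:

    import os
    os.environ["VLLM_FLAT_LOGPROBS"] = "1"  # set before any logprob containers are created

    from vllm import LLM, SamplingParams
    from vllm.logprobs import FlatLogprobs

    llm = LLM(model="distilbert/distilgpt2")
    params = SamplingParams(max_tokens=5, logprobs=5, prompt_logprobs=5)
    out = llm.generate(["Hello world"], params)[0]

    # With the flag on, both containers are FlatLogprobs instead of plain lists.
    assert isinstance(out.prompt_logprobs, FlatLogprobs)
    assert isinstance(out.outputs[0].logprobs, FlatLogprobs)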


@@ -5,7 +5,7 @@
 import pytest

 from vllm.logprobs import (
-    FlattenLogprobs,
+    FlatLogprobs,
     Logprob,
     LogprobsOnePosition,
     append_logprobs_for_next_position,
@@ -14,8 +14,8 @@ from vllm.logprobs import (
 )


-def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")

     prompt_logprobs = create_prompt_logprobs()
     assert isinstance(prompt_logprobs, list)
@@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0


-def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")

     prompt_logprobs = create_prompt_logprobs()
-    assert isinstance(prompt_logprobs, FlattenLogprobs)
+    assert isinstance(prompt_logprobs, FlatLogprobs)
     assert prompt_logprobs.start_indices == [0]
     assert prompt_logprobs.end_indices == [0]
     assert len(prompt_logprobs.token_ids) == 0
@@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert prompt_logprobs[0] == dict()

     sample_logprobs = create_sample_logprobs()
-    assert isinstance(sample_logprobs, FlattenLogprobs)
+    assert isinstance(sample_logprobs, FlatLogprobs)
     assert len(sample_logprobs.start_indices) == 0
     assert len(sample_logprobs.end_indices) == 0
     assert len(sample_logprobs.token_ids) == 0
@@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0


-def test_append_logprobs_for_next_position_none_flatten(
+def test_append_logprobs_for_next_position_none_flat(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
     logprobs = create_sample_logprobs()
     append_logprobs_for_next_position(
         logprobs,
@@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
     ]


-def test_append_logprobs_for_next_position_flatten(
+def test_append_logprobs_for_next_position_flat(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
+    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
     logprobs = create_sample_logprobs()
     append_logprobs_for_next_position(
         logprobs,
@@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
         rank=11,
         num_logprobs=-1,
     )
-    assert isinstance(logprobs, FlattenLogprobs)
+    assert isinstance(logprobs, FlatLogprobs)
     assert logprobs.start_indices == [0, 1]
     assert logprobs.end_indices == [1, 3]
     assert logprobs.token_ids == [1, 2, 3]
@@ -129,8 +129,8 @@ LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
 }


-def test_flatten_logprobs_append() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_append() -> None:
+    logprobs = FlatLogprobs()
     logprobs.append(LOGPROBS_ONE_POSITION_0)
     logprobs.append(LOGPROBS_ONE_POSITION_1)
     assert logprobs.start_indices == [0, 1]
@@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
     assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]


-def test_flatten_logprobs_extend() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_extend() -> None:
+    logprobs = FlatLogprobs()
     # Extend with list[LogprobsOnePosition]
     logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
     assert logprobs.start_indices == [0, 3]
@@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
     assert logprobs.ranks == [40, 50, 60, 10]
     assert logprobs.decoded_tokens == ["40", "50", "60", "10"]

-    other_logprobs = FlattenLogprobs()
+    other_logprobs = FlatLogprobs()
     other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
-    # Extend with another FlattenLogprobs
+    # Extend with another FlatLogprobs
     logprobs.extend(other_logprobs)
     assert logprobs.start_indices == [0, 3, 4, 6]
     assert logprobs.end_indices == [3, 4, 6, 7]
@@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
     assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]


-def test_flatten_logprobs_access() -> None:
-    logprobs = FlattenLogprobs()
+def test_flat_logprobs_access() -> None:
+    logprobs = FlatLogprobs()
     logprobs.extend(
         [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
     )
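These tests pin down the indexing contract of FlatLogprobs: the logprobs for position i occupy the half-open range [start_indices[i], end_indices[i]) of each parallel flat list. A sketch of that lookup (the helper name is hypothetical; the real __getitem__ rebuilds a dict[int, Logprob] the same way):

    def ranks_at_position(flat, i):
        # Position i spans [start_indices[i], end_indices[i]) in every flat list.
        s, e = flat.start_indices[i], flat.end_indices[i]
        return {
            flat.token_ids[j]: (flat.logprobs[j], flat.ranks[j], flat.decoded_tokens[j])
            for j in range(s, e)
        }

For example, after extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]) above, start_indices == [0, 3]: position 0 holds three ranks at flat indices 0 through 2 and position 1 holds one rank at flat index 3.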


@@ -223,7 +223,7 @@ if TYPE_CHECKING:
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-    VLLM_FLATTEN_LOGPROBS: bool = False
+    VLLM_FLAT_LOGPROBS: bool = False


 def get_default_cache_root():
@@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than
+    # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
     # the original list[dict[int, Logprob]] approach.
     # After enabled, PromptLogprobs and SampleLogprobs would populated as
-    # FlattenLogprobs.
-    "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
+    # FlatLogprobs.
+    "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
 }
 # --8<-- [end:env-vars-definition]
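The registered parser is bool(int(...)), so the flag only accepts integer literals; values like "true" or "yes" would raise ValueError inside int(). A one-line equivalent of what envs.VLLM_FLAT_LOGPROBS evaluates:

    import os

    # Same parsing as the lambda above: "0" -> False, "1" (or any nonzero int) -> True.
    enabled = bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0")))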


@@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]


 @dataclass
-class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
+class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
     """
-    Flatten logprobs of a request into multiple primitive type lists.
+    Flat logprobs of a request into multiple primitive type lists.

     Compared to list[dict[int, Logprob]], this data structure reduced GC
     overhead significantly. As it flattened logprob information for
     all positions and ranks in to multiple primitive type lists (i.e.
     logprobs, token_ids, ranks per token_ids, decoded_tokens).
     So regardless of the sequence length and top_logprobs setup,
-    FlattenLogprobs would only introduce a constant amount of objects.
+    FlatLogprobs would only introduce a constant amount of objects.

     As each position might contains different amount of ranks,
     start_indices_per_position would be used to access the logprob ranges
@@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
     def __getitem__(self, position: int) -> LogprobsOnePosition: ...

     @overload
-    def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
+    def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...

     def __getitem__(self, index: int | slice):
         """Extracts logprobs of a given position or slice"""
@@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
         elif isinstance(index, slice):
             min_index = self.start_indices[index][0]
             max_index = self.end_indices[index][-1]
-            return FlattenLogprobs(
+            return FlatLogprobs(
                 # Shift updated start_indices and end_indices to
                 # be 0-indexed
                 start_indices=[i - min_index for i in self.start_indices[index]],
@@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
             raise TypeError(f"Invalid index type: {type(index)}")

     def __setitem__(self, item, value) -> None:
-        raise TypeError("Cannot set logprobs in FlattenLogprobs")
+        raise TypeError("Cannot set logprobs in FlatLogprobs")

     def __delitem__(self, item) -> None:
-        raise TypeError("Cannot delete logprobs from FlattenLogprobs")
+        raise TypeError("Cannot delete logprobs from FlatLogprobs")

     def insert(self, item) -> None:
-        raise TypeError("Cannot insert logprobs to FlattenLogprobs")
+        raise TypeError("Cannot insert logprobs to FlatLogprobs")

     def __iter__(self) -> Iterator[LogprobsOnePosition]:
         """
@@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):

 # {token_id -> logprob} per each sequence group. None if the corresponding
 # sequence group doesn't require prompt logprob.
-PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
+PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
 # {token_id -> logprob} for each sequence group.
-SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
+SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]


 def create_prompt_logprobs() -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
     # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs
@@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:

 def create_sample_logprobs() -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []


 def append_logprobs_for_next_position(
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
     topk_ranks = range(1, num_logprobs + 1)
     ranks = itertools.chain((rank,), topk_ranks)

-    if isinstance(request_logprobs, FlattenLogprobs):
+    if isinstance(request_logprobs, FlatLogprobs):
         request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
     else:
         request_logprobs.append(
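Taken together, these hunks keep the renamed class's behavior intact: append/extend grow the parallel lists, and the slice branch of __getitem__ copies the selected ranges and shifts the copied indices back to 0. A small sketch of that re-indexing, assuming Logprob(logprob=..., rank=..., decoded_token=...) matches the dataclass in vllm.logprobs:

    from vllm.logprobs import FlatLogprobs, Logprob

    flat = FlatLogprobs()
    flat.extend(
        [
            # Position 0: one rank; position 1: two ranks.
            {1: Logprob(logprob=-0.1, rank=1, decoded_token="a")},
            {
                2: Logprob(logprob=-0.2, rank=1, decoded_token="b"),
                3: Logprob(logprob=-0.3, rank=2, decoded_token="c"),
            },
        ]
    )
    # flat.start_indices == [0, 1], flat.end_indices == [1, 3]

    tail = flat[1:]  # a new FlatLogprobs with indices shifted to be 0-based
    # tail.start_indices == [0], tail.end_indices == [2]
    assert tail[0] == flat[1]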