[Core] Switch Flat logprob control from environment variable to SamplingParams (#28914)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
commit 40b6b38f2c (parent da94c7c0eb)
Jialin Ouyang, 2025-11-18 18:10:02 -08:00, committed by GitHub
6 changed files with 33 additions and 41 deletions
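
For context, a minimal usage sketch of the new per-request control (model name and token counts are illustrative, not taken from this diff); previously the same behavior was toggled globally via the now-removed VLLM_FLAT_LOGPROBS environment variable:

from vllm import LLM, SamplingParams

# Request flat logprobs per request instead of via an environment variable.
llm = LLM(model="facebook/opt-125m", max_logprobs=5)
sampling_params = SamplingParams(
    max_tokens=16,
    logprobs=5,          # top-5 logprobs per generated token
    prompt_logprobs=5,   # top-5 logprobs per prompt token
    flat_logprobs=True,  # populate FlatLogprobs instead of list[dict[int, Logprob]]
)
outputs = llm.generate(["The capital of France is"], sampling_params)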

View File

@@ -24,9 +24,7 @@ def test_ranks(
     greedy,
     flat_logprobs,
     example_prompts,
-    monkeypatch: pytest.MonkeyPatch,
 ):
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
     with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
         tokenizer = vllm_model.llm.get_tokenizer()
         example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -36,6 +34,7 @@ def test_ranks(
             max_tokens=MAX_TOKENS,
             logprobs=NUM_TOP_LOGPROBS,
             prompt_logprobs=NUM_PROMPT_LOGPROBS,
+            flat_logprobs=flat_logprobs,
         )

         results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

View File

@@ -2,8 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import pytest
-
 from vllm.logprobs import (
     FlatLogprobs,
     Logprob,
@@ -14,24 +12,20 @@ from vllm.logprobs import (
 )


-def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-    prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_non_flat() -> None:
+    prompt_logprobs = create_prompt_logprobs(flat_logprobs=False)
     assert isinstance(prompt_logprobs, list)
     # Ensure first prompt position logprobs is None
     assert len(prompt_logprobs) == 1
     assert prompt_logprobs[0] is None

-    sample_logprobs = create_sample_logprobs()
+    sample_logprobs = create_sample_logprobs(flat_logprobs=False)
     assert isinstance(sample_logprobs, list)
     assert len(sample_logprobs) == 0


-def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-    prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_flat() -> None:
+    prompt_logprobs = create_prompt_logprobs(flat_logprobs=True)
     assert isinstance(prompt_logprobs, FlatLogprobs)
     assert prompt_logprobs.start_indices == [0]
     assert prompt_logprobs.end_indices == [0]
@@ -43,7 +37,7 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(prompt_logprobs) == 1
     assert prompt_logprobs[0] == dict()

-    sample_logprobs = create_sample_logprobs()
+    sample_logprobs = create_sample_logprobs(flat_logprobs=True)
     assert isinstance(sample_logprobs, FlatLogprobs)
     assert len(sample_logprobs.start_indices) == 0
     assert len(sample_logprobs.end_indices) == 0
@@ -54,11 +48,8 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0


-def test_append_logprobs_for_next_position_none_flat(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-    logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_none_flat() -> None:
+    logprobs = create_sample_logprobs(flat_logprobs=False)
     append_logprobs_for_next_position(
         logprobs,
         token_ids=[1],
@@ -85,11 +76,8 @@ def test_append_logprobs_for_next_position_none_flat(
     ]


-def test_append_logprobs_for_next_position_flat(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-    logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_flat() -> None:
+    logprobs = create_sample_logprobs(flat_logprobs=True)
     append_logprobs_for_next_position(
         logprobs,
         token_ids=[1],

View File

@@ -225,7 +225,6 @@ if TYPE_CHECKING:
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-    VLLM_FLAT_LOGPROBS: bool = False


 def get_default_cache_root():
@@ -1499,11 +1498,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
-    # the original list[dict[int, Logprob]] approach.
-    # After enabled, PromptLogprobs and SampleLogprobs would populated as
-    # FlatLogprobs.
-    "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
 }

 # --8<-- [end:env-vars-definition]

View File

@@ -5,8 +5,6 @@ from collections.abc import Iterable, Iterator, MutableSequence
 from dataclasses import dataclass, field
 from typing import overload

-import vllm.envs as envs
-
 # We use dataclass for now because it is used for
 # openai server output, and msgspec is not serializable.
@@ -161,17 +159,17 @@ PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
 SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]


-def create_prompt_logprobs() -> PromptLogprobs:
+def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+    logprobs = FlatLogprobs() if flat_logprobs else []
     # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs


-def create_sample_logprobs() -> SampleLogprobs:
+def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+    return FlatLogprobs() if flat_logprobs else []


 def append_logprobs_for_next_position(
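
The helpers above now take the flag explicitly instead of reading vllm.envs. A small sketch of the resulting contract, based on the tests in this commit:

from vllm.logprobs import FlatLogprobs, create_prompt_logprobs, create_sample_logprobs

# Flat containers are FlatLogprobs; non-flat containers remain plain lists.
assert isinstance(create_sample_logprobs(flat_logprobs=True), FlatLogprobs)
assert isinstance(create_sample_logprobs(flat_logprobs=False), list)

# Prompt containers always hold a placeholder for the first prompt token,
# whose logprob is undefined (None in list form, an empty mapping in flat form).
assert create_prompt_logprobs(flat_logprobs=False)[0] is None
assert create_prompt_logprobs(flat_logprobs=True)[0] == {}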

View File

@@ -204,6 +204,12 @@ class SamplingParams(
     prompt_logprobs: int | None = None
     """Number of log probabilities to return per prompt token.
     When set to -1, return all `vocab_size` log probabilities."""
+    flat_logprobs: bool = False
+    """Whether to return logprobs in the flattened format (i.e. FlatLogprobs)
+    for better performance.
+    NOTE: The GC cost of FlatLogprobs is significantly smaller than that of
+    list[dict[int, Logprob]]. When enabled, PromptLogprobs and
+    SampleLogprobs are populated as FlatLogprobs."""
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
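
As a hedged migration sketch (the environment-variable path is removed elsewhere in this commit, so setting it no longer has any effect):

from vllm import SamplingParams

# Before: os.environ["VLLM_FLAT_LOGPROBS"] = "1" applied globally to every request.
# After: opt in per request via SamplingParams.
params_flat = SamplingParams(logprobs=2, flat_logprobs=True)  # FlatLogprobs containers
params_list = SamplingParams(logprobs=2)                      # default list[dict[int, Logprob]]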

View File

@@ -43,15 +43,22 @@ class LogprobsProcessor:
         tokenizer: AnyTokenizer | None,
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
-        assert request.sampling_params is not None
-        num_logprobs = request.sampling_params.logprobs
-        num_prompt_logprobs = request.sampling_params.prompt_logprobs
+        sampling_params = request.sampling_params
+        assert sampling_params is not None
+        num_logprobs = sampling_params.logprobs
+        num_prompt_logprobs = sampling_params.prompt_logprobs
         return cls(
             tokenizer=tokenizer,
             cumulative_logprob=(None if num_logprobs is None else 0.0),
-            logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+            logprobs=(
+                None
+                if num_logprobs is None
+                else create_sample_logprobs(sampling_params.flat_logprobs)
+            ),
             prompt_logprobs=(
-                None if num_prompt_logprobs is None else create_prompt_logprobs()
+                None
+                if num_prompt_logprobs is None
+                else create_prompt_logprobs(sampling_params.flat_logprobs)
             ),
             num_prompt_logprobs=num_prompt_logprobs,
             num_logprobs=num_logprobs,