diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py
index c9d227599cde..ea40c4802720 100644
--- a/tests/samplers/test_logprobs.py
+++ b/tests/samplers/test_logprobs.py
@@ -24,9 +24,7 @@ def test_ranks(
     greedy,
     flat_logprobs,
     example_prompts,
-    monkeypatch: pytest.MonkeyPatch,
 ):
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
     with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
         tokenizer = vllm_model.llm.get_tokenizer()
         example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -36,6 +34,7 @@ def test_ranks(
             max_tokens=MAX_TOKENS,
             logprobs=NUM_TOP_LOGPROBS,
             prompt_logprobs=NUM_PROMPT_LOGPROBS,
+            flat_logprobs=flat_logprobs,
         )
 
         results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py
index d26a460d2bca..75e9d337aa24 100644
--- a/tests/test_logprobs.py
+++ b/tests/test_logprobs.py
@@ -2,8 +2,6 @@
 
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import pytest
-
 from vllm.logprobs import (
     FlatLogprobs,
     Logprob,
@@ -14,24 +12,20 @@ from vllm.logprobs import (
 )
 
 
-def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-
-    prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_non_flat() -> None:
+    prompt_logprobs = create_prompt_logprobs(flat_logprobs=False)
     assert isinstance(prompt_logprobs, list)
     # Ensure first prompt position logprobs is None
     assert len(prompt_logprobs) == 1
     assert prompt_logprobs[0] is None
 
-    sample_logprobs = create_sample_logprobs()
+    sample_logprobs = create_sample_logprobs(flat_logprobs=False)
     assert isinstance(sample_logprobs, list)
     assert len(sample_logprobs) == 0
 
 
-def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-
-    prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_flat() -> None:
+    prompt_logprobs = create_prompt_logprobs(flat_logprobs=True)
     assert isinstance(prompt_logprobs, FlatLogprobs)
     assert prompt_logprobs.start_indices == [0]
     assert prompt_logprobs.end_indices == [0]
@@ -43,7 +37,7 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(prompt_logprobs) == 1
     assert prompt_logprobs[0] == dict()
 
-    sample_logprobs = create_sample_logprobs()
+    sample_logprobs = create_sample_logprobs(flat_logprobs=True)
     assert isinstance(sample_logprobs, FlatLogprobs)
     assert len(sample_logprobs.start_indices) == 0
     assert len(sample_logprobs.end_indices) == 0
@@ -54,11 +48,8 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0
 
 
-def test_append_logprobs_for_next_position_none_flat(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-    logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_none_flat() -> None:
+    logprobs = create_sample_logprobs(flat_logprobs=False)
     append_logprobs_for_next_position(
         logprobs,
         token_ids=[1],
@@ -85,11 +76,8 @@ def test_append_logprobs_for_next_position_none_flat(
     ]
 
 
-def test_append_logprobs_for_next_position_flat(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-    logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_flat() -> None:
+    logprobs = create_sample_logprobs(flat_logprobs=True)
     append_logprobs_for_next_position(
         logprobs,
         token_ids=[1],
diff --git a/vllm/envs.py b/vllm/envs.py
index 6bf05803e14e..62b3344ccd85 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -225,7 +225,6 @@ if TYPE_CHECKING:
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-    VLLM_FLAT_LOGPROBS: bool = False
 
 
 def get_default_cache_root():
@@ -1499,11 +1498,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
-    # the original list[dict[int, Logprob]] approach.
-    # After enabled, PromptLogprobs and SampleLogprobs would populated as
-    # FlatLogprobs.
-    "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
diff --git a/vllm/logprobs.py b/vllm/logprobs.py
index a34398db2c96..6a820308f523 100644
--- a/vllm/logprobs.py
+++ b/vllm/logprobs.py
@@ -5,8 +5,6 @@ from collections.abc import Iterable, Iterator, MutableSequence
 from dataclasses import dataclass, field
 from typing import overload
 
-import vllm.envs as envs
-
 # We use dataclass for now because it is used for
 # openai server output, and msgspec is not serializable.
 
@@ -161,17 +159,17 @@ PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
 SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
 
 
-def create_prompt_logprobs() -> PromptLogprobs:
+def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+    logprobs = FlatLogprobs() if flat_logprobs else []
     # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs
 
 
-def create_sample_logprobs() -> SampleLogprobs:
+def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+    return FlatLogprobs() if flat_logprobs else []
 
 
 def append_logprobs_for_next_position(
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 901d66163452..0fb1d67687c8 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -204,6 +204,12 @@ class SamplingParams(
     prompt_logprobs: int | None = None
     """Number of log probabilities to return per prompt token. When set to -1,
     return all `vocab_size` log probabilities."""
+    flat_logprobs: bool = False
+    """Whether to return logprobs in a flattened format (i.e. FlatLogprobs)
+    for better performance.
+    NOTE: The GC cost of FlatLogprobs is significantly smaller than that of
+    list[dict[int, Logprob]]. When enabled, PromptLogprobs and
+    SampleLogprobs are populated as FlatLogprobs."""
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
diff --git a/vllm/v1/engine/logprobs.py b/vllm/v1/engine/logprobs.py
index b618d2347265..63064a2c65d6 100644
--- a/vllm/v1/engine/logprobs.py
+++ b/vllm/v1/engine/logprobs.py
@@ -43,15 +43,22 @@ class LogprobsProcessor:
         tokenizer: AnyTokenizer | None,
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
-        assert request.sampling_params is not None
-        num_logprobs = request.sampling_params.logprobs
-        num_prompt_logprobs = request.sampling_params.prompt_logprobs
+        sampling_params = request.sampling_params
+        assert sampling_params is not None
+        num_logprobs = sampling_params.logprobs
+        num_prompt_logprobs = sampling_params.prompt_logprobs
         return cls(
             tokenizer=tokenizer,
             cumulative_logprob=(None if num_logprobs is None else 0.0),
-            logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+            logprobs=(
+                None
+                if num_logprobs is None
+                else create_sample_logprobs(sampling_params.flat_logprobs)
+            ),
             prompt_logprobs=(
-                None if num_prompt_logprobs is None else create_prompt_logprobs()
+                None
+                if num_prompt_logprobs is None
+                else create_prompt_logprobs(sampling_params.flat_logprobs)
             ),
             num_prompt_logprobs=num_prompt_logprobs,
             num_logprobs=num_logprobs,
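
Usage note: with this change, flat logprobs are requested per call through SamplingParams rather than through the removed VLLM_FLAT_LOGPROBS environment variable. Below is a minimal sketch of the intended call pattern through the engine-level LLM entry point; the model name and prompt are placeholders for illustration and not part of this diff.

    from vllm import LLM, SamplingParams

    # Opt into FlatLogprobs per request instead of via an environment variable.
    llm = LLM(model="facebook/opt-125m", max_logprobs=5)  # placeholder model
    params = SamplingParams(
        max_tokens=8,
        logprobs=5,          # top logprobs per generated token
        prompt_logprobs=5,   # top logprobs per prompt token
        flat_logprobs=True,  # populate FlatLogprobs instead of list[dict[int, Logprob]]
    )
    outputs = llm.generate(["Hello, my name is"], params)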