[Core] Switch Flat logprob control from environment variable to SamplingParams (#28914)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent da94c7c0eb
commit 40b6b38f2c
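In short, the flat-logprobs container is now selected per request through SamplingParams rather than through the process-wide VLLM_FLAT_LOGPROBS environment variable. A minimal sketch of the new usage, assuming the standard offline LLM entry point; the model name and other parameters here are illustrative, and only flat_logprobs is the field added by this commit:

from vllm import LLM, SamplingParams

# Illustrative model and limits; flat_logprobs is the new per-request switch.
llm = LLM(model="facebook/opt-125m", max_logprobs=5)
params = SamplingParams(
    max_tokens=16,
    logprobs=5,            # top-k logprobs per generated token
    prompt_logprobs=5,     # top-k logprobs per prompt token
    flat_logprobs=True,    # populate FlatLogprobs instead of list[dict[int, Logprob]]
)
outputs = llm.generate(["Hello, my name is"], params)
# With flat_logprobs=True the logprobs containers are FlatLogprobs instances.
print(type(outputs[0].outputs[0].logprobs))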
@@ -24,9 +24,7 @@ def test_ranks(
     greedy,
     flat_logprobs,
     example_prompts,
-    monkeypatch: pytest.MonkeyPatch,
 ):
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
     with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
         tokenizer = vllm_model.llm.get_tokenizer()
         example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
@@ -36,6 +34,7 @@ def test_ranks(
             max_tokens=MAX_TOKENS,
             logprobs=NUM_TOP_LOGPROBS,
             prompt_logprobs=NUM_PROMPT_LOGPROBS,
+            flat_logprobs=flat_logprobs,
         )
         results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)
 
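Since the flat/non-flat choice now travels with SamplingParams, the test above no longer needs a monkeypatch fixture; the behaviour can be selected purely by parametrization. A rough sketch of that pattern (the real test also parametrizes model, dtype, greedy, and the prompts, which are omitted here):

import pytest
from vllm import SamplingParams

@pytest.mark.parametrize("flat_logprobs", [False, True])
def test_flat_logprobs_is_per_request(flat_logprobs: bool) -> None:
    # No environment variable involved: the flag rides on the request itself.
    params = SamplingParams(logprobs=5, flat_logprobs=flat_logprobs)
    assert params.flat_logprobs is flat_logprobs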
@@ -2,8 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import pytest
-
 from vllm.logprobs import (
     FlatLogprobs,
     Logprob,
@@ -14,24 +12,20 @@ from vllm.logprobs import (
 )
 
 
-def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-
-    prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_non_flat() -> None:
+    prompt_logprobs = create_prompt_logprobs(flat_logprobs=False)
     assert isinstance(prompt_logprobs, list)
     # Ensure first prompt position logprobs is None
     assert len(prompt_logprobs) == 1
     assert prompt_logprobs[0] is None
 
-    sample_logprobs = create_sample_logprobs()
+    sample_logprobs = create_sample_logprobs(flat_logprobs=False)
     assert isinstance(sample_logprobs, list)
     assert len(sample_logprobs) == 0
 
 
-def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-
-    prompt_logprobs = create_prompt_logprobs()
+def test_create_logprobs_flat() -> None:
+    prompt_logprobs = create_prompt_logprobs(flat_logprobs=True)
     assert isinstance(prompt_logprobs, FlatLogprobs)
     assert prompt_logprobs.start_indices == [0]
     assert prompt_logprobs.end_indices == [0]
@@ -43,7 +37,7 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(prompt_logprobs) == 1
     assert prompt_logprobs[0] == dict()
 
-    sample_logprobs = create_sample_logprobs()
+    sample_logprobs = create_sample_logprobs(flat_logprobs=True)
     assert isinstance(sample_logprobs, FlatLogprobs)
     assert len(sample_logprobs.start_indices) == 0
     assert len(sample_logprobs.end_indices) == 0
@@ -54,11 +48,8 @@ def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert len(sample_logprobs) == 0
 
 
-def test_append_logprobs_for_next_position_none_flat(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
-    logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_none_flat() -> None:
+    logprobs = create_sample_logprobs(flat_logprobs=False)
     append_logprobs_for_next_position(
         logprobs,
         token_ids=[1],
@@ -85,11 +76,8 @@ def test_append_logprobs_for_next_position_none_flat(
     ]
 
 
-def test_append_logprobs_for_next_position_flat(
-    monkeypatch: pytest.MonkeyPatch,
-) -> None:
-    monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
-    logprobs = create_sample_logprobs()
+def test_append_logprobs_for_next_position_flat() -> None:
+    logprobs = create_sample_logprobs(flat_logprobs=True)
     append_logprobs_for_next_position(
         logprobs,
         token_ids=[1],
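The updated tests above also document the shape of the two container flavours. Restated as a standalone sketch, inferred only from the assertions in this diff rather than from the full FlatLogprobs definition:

from vllm.logprobs import FlatLogprobs, create_prompt_logprobs, create_sample_logprobs

# Non-flat: plain Python lists, with position 0 of the prompt logprobs set to None.
assert create_prompt_logprobs(flat_logprobs=False) == [None]
assert create_sample_logprobs(flat_logprobs=False) == []

# Flat: a FlatLogprobs container tracking per-position slices via start/end indices.
flat = create_prompt_logprobs(flat_logprobs=True)
assert isinstance(flat, FlatLogprobs)
assert flat.start_indices == [0] and flat.end_indices == [0]
assert len(flat) == 1 and flat[0] == {}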
@@ -225,7 +225,6 @@ if TYPE_CHECKING:
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-    VLLM_FLAT_LOGPROBS: bool = False
 
 
 def get_default_cache_root():
@@ -1499,11 +1498,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
-    # the original list[dict[int, Logprob]] approach.
-    # After enabled, PromptLogprobs and SampleLogprobs would populated as
-    # FlatLogprobs.
-    "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
 }
 
 # --8<-- [end:env-vars-definition]
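With VLLM_FLAT_LOGPROBS removed from vllm/envs.py, setting the variable no longer has any effect. A migration sketch for callers that previously relied on it:

import os
from vllm import SamplingParams

# Old (removed): a process-wide switch read at container-creation time.
os.environ["VLLM_FLAT_LOGPROBS"] = "1"  # now ignored by vLLM

# New: opt in per request on SamplingParams.
params = SamplingParams(logprobs=5, flat_logprobs=True)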
@@ -5,8 +5,6 @@ from collections.abc import Iterable, Iterator, MutableSequence
 from dataclasses import dataclass, field
 from typing import overload
 
-import vllm.envs as envs
-
 
 # We use dataclass for now because it is used for
 # openai server output, and msgspec is not serializable.
@@ -161,17 +159,17 @@ PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
 SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
 
 
-def create_prompt_logprobs() -> PromptLogprobs:
+def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+    logprobs = FlatLogprobs() if flat_logprobs else []
     # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs
 
 
-def create_sample_logprobs() -> SampleLogprobs:
+def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
+    return FlatLogprobs() if flat_logprobs else []
 
 
 def append_logprobs_for_next_position(
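These helpers now take the choice as an explicit argument instead of reading envs.VLLM_FLAT_LOGPROBS at call time, which is why the import of vllm.envs is dropped in the first hunk above. A minimal sketch of calling them directly (make_containers is an illustrative helper, not part of the codebase):

from vllm.logprobs import create_prompt_logprobs, create_sample_logprobs

def make_containers(flat_logprobs: bool):
    # The caller (e.g. the LogprobsProcessor change further below) decides,
    # per request, which container flavour to allocate.
    return (
        create_prompt_logprobs(flat_logprobs),
        create_sample_logprobs(flat_logprobs),
    )

prompt_lp, sample_lp = make_containers(flat_logprobs=True)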
@@ -204,6 +204,12 @@ class SamplingParams(
     prompt_logprobs: int | None = None
     """Number of log probabilities to return per prompt token.
     When set to -1, return all `vocab_size` log probabilities."""
+    flat_logprobs: bool = False
+    """Whether to return logprobs in flatten format (i.e. FlatLogprob)
+    for better performance.
+    NOTE: GC costs of FlatLogprobs is significantly smaller than
+    list[dict[int, Logprob]]. After enabled, PromptLogprobs and
+    SampleLogprobs would populated as FlatLogprobs."""
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
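The new field defaults to False, so nothing changes for callers that never opt in. A quick sketch of the two settings:

from vllm import SamplingParams

default_params = SamplingParams(logprobs=2, prompt_logprobs=2)
assert default_params.flat_logprobs is False  # legacy list[dict[int, Logprob]] containers

flat_params = SamplingParams(logprobs=2, prompt_logprobs=2, flat_logprobs=True)
assert flat_params.flat_logprobs is True      # FlatLogprobs containers for this request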
@@ -43,15 +43,22 @@ class LogprobsProcessor:
         tokenizer: AnyTokenizer | None,
         request: EngineCoreRequest,
     ) -> "LogprobsProcessor":
-        assert request.sampling_params is not None
-        num_logprobs = request.sampling_params.logprobs
-        num_prompt_logprobs = request.sampling_params.prompt_logprobs
+        sampling_params = request.sampling_params
+        assert sampling_params is not None
+        num_logprobs = sampling_params.logprobs
+        num_prompt_logprobs = sampling_params.prompt_logprobs
         return cls(
             tokenizer=tokenizer,
             cumulative_logprob=(None if num_logprobs is None else 0.0),
-            logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+            logprobs=(
+                None
+                if num_logprobs is None
+                else create_sample_logprobs(sampling_params.flat_logprobs)
+            ),
             prompt_logprobs=(
-                None if num_prompt_logprobs is None else create_prompt_logprobs()
+                None
+                if num_prompt_logprobs is None
+                else create_prompt_logprobs(sampling_params.flat_logprobs)
             ),
             num_prompt_logprobs=num_prompt_logprobs,
             num_logprobs=num_logprobs,