mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 13:47:18 +08:00
[Misc] FlattenLogprobs -> FlatLogprobs (#28335)
This commit is contained in:
parent
57201a6a4c
commit
8d706cca90
@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.logprobs import FlattenLogprobs
|
||||
from vllm.logprobs import FlatLogprobs
|
||||
|
||||
MODELS = ["distilbert/distilgpt2"]
|
||||
MAX_TOKENS = 5
|
||||
@ -16,17 +16,17 @@ MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("greedy", [True, False])
|
||||
@pytest.mark.parametrize("flatten_logprobs", [True, False])
|
||||
@pytest.mark.parametrize("flat_logprobs", [True, False])
|
||||
def test_ranks(
|
||||
vllm_runner,
|
||||
model,
|
||||
dtype,
|
||||
greedy,
|
||||
flatten_logprobs,
|
||||
flat_logprobs,
|
||||
example_prompts,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
|
||||
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
|
||||
tokenizer = vllm_model.llm.get_tokenizer()
|
||||
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
|
||||
@ -44,12 +44,8 @@ def test_ranks(
|
||||
decode_tokens, _, decode_logprobs, prompt_logprobs = result
|
||||
|
||||
# Ensure the return type of logprobs is accurate
|
||||
assert isinstance(
|
||||
prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
|
||||
)
|
||||
assert isinstance(
|
||||
decode_logprobs, FlattenLogprobs if flatten_logprobs else list
|
||||
)
|
||||
assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
|
||||
assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)
|
||||
|
||||
########################
|
||||
# Check prompt logprobs
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.logprobs import (
|
||||
FlattenLogprobs,
|
||||
FlatLogprobs,
|
||||
Logprob,
|
||||
LogprobsOnePosition,
|
||||
append_logprobs_for_next_position,
|
||||
@ -14,8 +14,8 @@ from vllm.logprobs import (
|
||||
)
|
||||
|
||||
|
||||
def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
|
||||
def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
|
||||
|
||||
prompt_logprobs = create_prompt_logprobs()
|
||||
assert isinstance(prompt_logprobs, list)
|
||||
@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert len(sample_logprobs) == 0
|
||||
|
||||
|
||||
def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
|
||||
def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
|
||||
|
||||
prompt_logprobs = create_prompt_logprobs()
|
||||
assert isinstance(prompt_logprobs, FlattenLogprobs)
|
||||
assert isinstance(prompt_logprobs, FlatLogprobs)
|
||||
assert prompt_logprobs.start_indices == [0]
|
||||
assert prompt_logprobs.end_indices == [0]
|
||||
assert len(prompt_logprobs.token_ids) == 0
|
||||
@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert prompt_logprobs[0] == dict()
|
||||
|
||||
sample_logprobs = create_sample_logprobs()
|
||||
assert isinstance(sample_logprobs, FlattenLogprobs)
|
||||
assert isinstance(sample_logprobs, FlatLogprobs)
|
||||
assert len(sample_logprobs.start_indices) == 0
|
||||
assert len(sample_logprobs.end_indices) == 0
|
||||
assert len(sample_logprobs.token_ids) == 0
|
||||
@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
assert len(sample_logprobs) == 0
|
||||
|
||||
|
||||
def test_append_logprobs_for_next_position_none_flatten(
|
||||
def test_append_logprobs_for_next_position_none_flat(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
|
||||
logprobs = create_sample_logprobs()
|
||||
append_logprobs_for_next_position(
|
||||
logprobs,
|
||||
@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
|
||||
]
|
||||
|
||||
|
||||
def test_append_logprobs_for_next_position_flatten(
|
||||
def test_append_logprobs_for_next_position_flat(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
|
||||
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
|
||||
logprobs = create_sample_logprobs()
|
||||
append_logprobs_for_next_position(
|
||||
logprobs,
|
||||
@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
|
||||
rank=11,
|
||||
num_logprobs=-1,
|
||||
)
|
||||
assert isinstance(logprobs, FlattenLogprobs)
|
||||
assert isinstance(logprobs, FlatLogprobs)
|
||||
assert logprobs.start_indices == [0, 1]
|
||||
assert logprobs.end_indices == [1, 3]
|
||||
assert logprobs.token_ids == [1, 2, 3]
|
||||
@ -129,8 +129,8 @@ LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
|
||||
}
|
||||
|
||||
|
||||
def test_flatten_logprobs_append() -> None:
|
||||
logprobs = FlattenLogprobs()
|
||||
def test_flat_logprobs_append() -> None:
|
||||
logprobs = FlatLogprobs()
|
||||
logprobs.append(LOGPROBS_ONE_POSITION_0)
|
||||
logprobs.append(LOGPROBS_ONE_POSITION_1)
|
||||
assert logprobs.start_indices == [0, 1]
|
||||
@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
|
||||
assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]
|
||||
|
||||
|
||||
def test_flatten_logprobs_extend() -> None:
|
||||
logprobs = FlattenLogprobs()
|
||||
def test_flat_logprobs_extend() -> None:
|
||||
logprobs = FlatLogprobs()
|
||||
# Extend with list[LogprobsOnePosition]
|
||||
logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
|
||||
assert logprobs.start_indices == [0, 3]
|
||||
@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
|
||||
assert logprobs.ranks == [40, 50, 60, 10]
|
||||
assert logprobs.decoded_tokens == ["40", "50", "60", "10"]
|
||||
|
||||
other_logprobs = FlattenLogprobs()
|
||||
other_logprobs = FlatLogprobs()
|
||||
other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
|
||||
# Extend with another FlattenLogprobs
|
||||
# Extend with another FlatLogprobs
|
||||
logprobs.extend(other_logprobs)
|
||||
assert logprobs.start_indices == [0, 3, 4, 6]
|
||||
assert logprobs.end_indices == [3, 4, 6, 7]
|
||||
@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
|
||||
assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]
|
||||
|
||||
|
||||
def test_flatten_logprobs_access() -> None:
|
||||
logprobs = FlattenLogprobs()
|
||||
def test_flat_logprobs_access() -> None:
|
||||
logprobs = FlatLogprobs()
|
||||
logprobs.extend(
|
||||
[LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
|
||||
)
|
||||
|
||||
@ -223,7 +223,7 @@ if TYPE_CHECKING:
|
||||
VLLM_GC_DEBUG: str = ""
|
||||
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
|
||||
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
|
||||
VLLM_FLATTEN_LOGPROBS: bool = False
|
||||
VLLM_FLAT_LOGPROBS: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
|
||||
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
|
||||
),
|
||||
# Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than
|
||||
# Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
|
||||
# the original list[dict[int, Logprob]] approach.
|
||||
# After enabled, PromptLogprobs and SampleLogprobs would populated as
|
||||
# FlattenLogprobs.
|
||||
"VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
|
||||
# FlatLogprobs.
|
||||
"VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
|
||||
@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
"""
|
||||
Flatten logprobs of a request into multiple primitive type lists.
|
||||
Flat logprobs of a request into multiple primitive type lists.
|
||||
|
||||
Compared to list[dict[int, Logprob]], this data structure reduced GC
|
||||
overhead significantly. As it flattened logprob information for
|
||||
all positions and ranks in to multiple primitive type lists (i.e.
|
||||
logprobs, token_ids, ranks per token_ids, decoded_tokens).
|
||||
So regardless of the sequence length and top_logprobs setup,
|
||||
FlattenLogprobs would only introduce a constant amount of objects.
|
||||
FlatLogprobs would only introduce a constant amount of objects.
|
||||
|
||||
As each position might contains different amount of ranks,
|
||||
start_indices_per_position would be used to access the logprob ranges
|
||||
@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
def __getitem__(self, position: int) -> LogprobsOnePosition: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
|
||||
def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...
|
||||
|
||||
def __getitem__(self, index: int | slice):
|
||||
"""Extracts logprobs of a given position or slice"""
|
||||
@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
elif isinstance(index, slice):
|
||||
min_index = self.start_indices[index][0]
|
||||
max_index = self.end_indices[index][-1]
|
||||
return FlattenLogprobs(
|
||||
return FlatLogprobs(
|
||||
# Shift updated start_indices and end_indices to
|
||||
# be 0-indexed
|
||||
start_indices=[i - min_index for i in self.start_indices[index]],
|
||||
@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
raise TypeError(f"Invalid index type: {type(index)}")
|
||||
|
||||
def __setitem__(self, item, value) -> None:
|
||||
raise TypeError("Cannot set logprobs in FlattenLogprobs")
|
||||
raise TypeError("Cannot set logprobs in FlatLogprobs")
|
||||
|
||||
def __delitem__(self, item) -> None:
|
||||
raise TypeError("Cannot delete logprobs from FlattenLogprobs")
|
||||
raise TypeError("Cannot delete logprobs from FlatLogprobs")
|
||||
|
||||
def insert(self, item) -> None:
|
||||
raise TypeError("Cannot insert logprobs to FlattenLogprobs")
|
||||
raise TypeError("Cannot insert logprobs to FlatLogprobs")
|
||||
|
||||
def __iter__(self) -> Iterator[LogprobsOnePosition]:
|
||||
"""
|
||||
@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
|
||||
# {token_id -> logprob} per each sequence group. None if the corresponding
|
||||
# sequence group doesn't require prompt logprob.
|
||||
PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
|
||||
PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
|
||||
# {token_id -> logprob} for each sequence group.
|
||||
SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
|
||||
SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
|
||||
|
||||
|
||||
def create_prompt_logprobs() -> PromptLogprobs:
|
||||
"""Creates a container to store prompt logprobs for a request"""
|
||||
logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
|
||||
logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
|
||||
# NOTE: logprob of first prompt token is None.
|
||||
logprobs.append(None)
|
||||
return logprobs
|
||||
@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:
|
||||
|
||||
def create_sample_logprobs() -> SampleLogprobs:
|
||||
"""Creates a container to store decode logprobs for a request"""
|
||||
return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
|
||||
return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
|
||||
|
||||
|
||||
def append_logprobs_for_next_position(
|
||||
@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
|
||||
topk_ranks = range(1, num_logprobs + 1)
|
||||
ranks = itertools.chain((rank,), topk_ranks)
|
||||
|
||||
if isinstance(request_logprobs, FlattenLogprobs):
|
||||
if isinstance(request_logprobs, FlatLogprobs):
|
||||
request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
|
||||
else:
|
||||
request_logprobs.append(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user