mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 06:37:04 +08:00
[Misc] FlattenLogprobs -> FlatLogprobs (#28335)
This commit is contained in:
parent
57201a6a4c
commit
8d706cca90
@ -4,7 +4,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.logprobs import FlattenLogprobs
|
from vllm.logprobs import FlatLogprobs
|
||||||
|
|
||||||
MODELS = ["distilbert/distilgpt2"]
|
MODELS = ["distilbert/distilgpt2"]
|
||||||
MAX_TOKENS = 5
|
MAX_TOKENS = 5
|
||||||
@ -16,17 +16,17 @@ MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)
|
|||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
@pytest.mark.parametrize("greedy", [True, False])
|
@pytest.mark.parametrize("greedy", [True, False])
|
||||||
@pytest.mark.parametrize("flatten_logprobs", [True, False])
|
@pytest.mark.parametrize("flat_logprobs", [True, False])
|
||||||
def test_ranks(
|
def test_ranks(
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
model,
|
model,
|
||||||
dtype,
|
dtype,
|
||||||
greedy,
|
greedy,
|
||||||
flatten_logprobs,
|
flat_logprobs,
|
||||||
example_prompts,
|
example_prompts,
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
|
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1" if flat_logprobs else "0")
|
||||||
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
|
with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
|
||||||
tokenizer = vllm_model.llm.get_tokenizer()
|
tokenizer = vllm_model.llm.get_tokenizer()
|
||||||
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
|
example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
|
||||||
@ -44,12 +44,8 @@ def test_ranks(
|
|||||||
decode_tokens, _, decode_logprobs, prompt_logprobs = result
|
decode_tokens, _, decode_logprobs, prompt_logprobs = result
|
||||||
|
|
||||||
# Ensure the return type of logprobs is accurate
|
# Ensure the return type of logprobs is accurate
|
||||||
assert isinstance(
|
assert isinstance(prompt_logprobs, FlatLogprobs if flat_logprobs else list)
|
||||||
prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
|
assert isinstance(decode_logprobs, FlatLogprobs if flat_logprobs else list)
|
||||||
)
|
|
||||||
assert isinstance(
|
|
||||||
decode_logprobs, FlattenLogprobs if flatten_logprobs else list
|
|
||||||
)
|
|
||||||
|
|
||||||
########################
|
########################
|
||||||
# Check prompt logprobs
|
# Check prompt logprobs
|
||||||
|
|||||||
@ -5,7 +5,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.logprobs import (
|
from vllm.logprobs import (
|
||||||
FlattenLogprobs,
|
FlatLogprobs,
|
||||||
Logprob,
|
Logprob,
|
||||||
LogprobsOnePosition,
|
LogprobsOnePosition,
|
||||||
append_logprobs_for_next_position,
|
append_logprobs_for_next_position,
|
||||||
@ -14,8 +14,8 @@ from vllm.logprobs import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_create_logprobs_non_flat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
|
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
|
||||||
|
|
||||||
prompt_logprobs = create_prompt_logprobs()
|
prompt_logprobs = create_prompt_logprobs()
|
||||||
assert isinstance(prompt_logprobs, list)
|
assert isinstance(prompt_logprobs, list)
|
||||||
@ -28,11 +28,11 @@ def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
assert len(sample_logprobs) == 0
|
assert len(sample_logprobs) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_create_logprobs_flat(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
|
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
|
||||||
|
|
||||||
prompt_logprobs = create_prompt_logprobs()
|
prompt_logprobs = create_prompt_logprobs()
|
||||||
assert isinstance(prompt_logprobs, FlattenLogprobs)
|
assert isinstance(prompt_logprobs, FlatLogprobs)
|
||||||
assert prompt_logprobs.start_indices == [0]
|
assert prompt_logprobs.start_indices == [0]
|
||||||
assert prompt_logprobs.end_indices == [0]
|
assert prompt_logprobs.end_indices == [0]
|
||||||
assert len(prompt_logprobs.token_ids) == 0
|
assert len(prompt_logprobs.token_ids) == 0
|
||||||
@ -44,7 +44,7 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
assert prompt_logprobs[0] == dict()
|
assert prompt_logprobs[0] == dict()
|
||||||
|
|
||||||
sample_logprobs = create_sample_logprobs()
|
sample_logprobs = create_sample_logprobs()
|
||||||
assert isinstance(sample_logprobs, FlattenLogprobs)
|
assert isinstance(sample_logprobs, FlatLogprobs)
|
||||||
assert len(sample_logprobs.start_indices) == 0
|
assert len(sample_logprobs.start_indices) == 0
|
||||||
assert len(sample_logprobs.end_indices) == 0
|
assert len(sample_logprobs.end_indices) == 0
|
||||||
assert len(sample_logprobs.token_ids) == 0
|
assert len(sample_logprobs.token_ids) == 0
|
||||||
@ -54,10 +54,10 @@ def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
assert len(sample_logprobs) == 0
|
assert len(sample_logprobs) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_append_logprobs_for_next_position_none_flatten(
|
def test_append_logprobs_for_next_position_none_flat(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
) -> None:
|
) -> None:
|
||||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
|
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "0")
|
||||||
logprobs = create_sample_logprobs()
|
logprobs = create_sample_logprobs()
|
||||||
append_logprobs_for_next_position(
|
append_logprobs_for_next_position(
|
||||||
logprobs,
|
logprobs,
|
||||||
@ -85,10 +85,10 @@ def test_append_logprobs_for_next_position_none_flatten(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_append_logprobs_for_next_position_flatten(
|
def test_append_logprobs_for_next_position_flat(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
) -> None:
|
) -> None:
|
||||||
monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
|
monkeypatch.setenv("VLLM_FLAT_LOGPROBS", "1")
|
||||||
logprobs = create_sample_logprobs()
|
logprobs = create_sample_logprobs()
|
||||||
append_logprobs_for_next_position(
|
append_logprobs_for_next_position(
|
||||||
logprobs,
|
logprobs,
|
||||||
@ -106,7 +106,7 @@ def test_append_logprobs_for_next_position_flatten(
|
|||||||
rank=11,
|
rank=11,
|
||||||
num_logprobs=-1,
|
num_logprobs=-1,
|
||||||
)
|
)
|
||||||
assert isinstance(logprobs, FlattenLogprobs)
|
assert isinstance(logprobs, FlatLogprobs)
|
||||||
assert logprobs.start_indices == [0, 1]
|
assert logprobs.start_indices == [0, 1]
|
||||||
assert logprobs.end_indices == [1, 3]
|
assert logprobs.end_indices == [1, 3]
|
||||||
assert logprobs.token_ids == [1, 2, 3]
|
assert logprobs.token_ids == [1, 2, 3]
|
||||||
@ -129,8 +129,8 @@ LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_flatten_logprobs_append() -> None:
|
def test_flat_logprobs_append() -> None:
|
||||||
logprobs = FlattenLogprobs()
|
logprobs = FlatLogprobs()
|
||||||
logprobs.append(LOGPROBS_ONE_POSITION_0)
|
logprobs.append(LOGPROBS_ONE_POSITION_0)
|
||||||
logprobs.append(LOGPROBS_ONE_POSITION_1)
|
logprobs.append(LOGPROBS_ONE_POSITION_1)
|
||||||
assert logprobs.start_indices == [0, 1]
|
assert logprobs.start_indices == [0, 1]
|
||||||
@ -149,8 +149,8 @@ def test_flatten_logprobs_append() -> None:
|
|||||||
assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]
|
assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]
|
||||||
|
|
||||||
|
|
||||||
def test_flatten_logprobs_extend() -> None:
|
def test_flat_logprobs_extend() -> None:
|
||||||
logprobs = FlattenLogprobs()
|
logprobs = FlatLogprobs()
|
||||||
# Extend with list[LogprobsOnePosition]
|
# Extend with list[LogprobsOnePosition]
|
||||||
logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
|
logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
|
||||||
assert logprobs.start_indices == [0, 3]
|
assert logprobs.start_indices == [0, 3]
|
||||||
@ -160,9 +160,9 @@ def test_flatten_logprobs_extend() -> None:
|
|||||||
assert logprobs.ranks == [40, 50, 60, 10]
|
assert logprobs.ranks == [40, 50, 60, 10]
|
||||||
assert logprobs.decoded_tokens == ["40", "50", "60", "10"]
|
assert logprobs.decoded_tokens == ["40", "50", "60", "10"]
|
||||||
|
|
||||||
other_logprobs = FlattenLogprobs()
|
other_logprobs = FlatLogprobs()
|
||||||
other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
|
other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
|
||||||
# Extend with another FlattenLogprobs
|
# Extend with another FlatLogprobs
|
||||||
logprobs.extend(other_logprobs)
|
logprobs.extend(other_logprobs)
|
||||||
assert logprobs.start_indices == [0, 3, 4, 6]
|
assert logprobs.start_indices == [0, 3, 4, 6]
|
||||||
assert logprobs.end_indices == [3, 4, 6, 7]
|
assert logprobs.end_indices == [3, 4, 6, 7]
|
||||||
@ -172,8 +172,8 @@ def test_flatten_logprobs_extend() -> None:
|
|||||||
assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]
|
assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]
|
||||||
|
|
||||||
|
|
||||||
def test_flatten_logprobs_access() -> None:
|
def test_flat_logprobs_access() -> None:
|
||||||
logprobs = FlattenLogprobs()
|
logprobs = FlatLogprobs()
|
||||||
logprobs.extend(
|
logprobs.extend(
|
||||||
[LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
|
[LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -223,7 +223,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_GC_DEBUG: str = ""
|
VLLM_GC_DEBUG: str = ""
|
||||||
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
|
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
|
||||||
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
|
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
|
||||||
VLLM_FLATTEN_LOGPROBS: bool = False
|
VLLM_FLAT_LOGPROBS: bool = False
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
|
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
|
||||||
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
|
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
|
||||||
),
|
),
|
||||||
# Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than
|
# Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
|
||||||
# the original list[dict[int, Logprob]] approach.
|
# the original list[dict[int, Logprob]] approach.
|
||||||
# After enabled, PromptLogprobs and SampleLogprobs would populated as
|
# After enabled, PromptLogprobs and SampleLogprobs would populated as
|
||||||
# FlattenLogprobs.
|
# FlatLogprobs.
|
||||||
"VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
|
"VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
|
||||||
}
|
}
|
||||||
|
|
||||||
# --8<-- [end:env-vars-definition]
|
# --8<-- [end:env-vars-definition]
|
||||||
|
|||||||
@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||||
"""
|
"""
|
||||||
Flatten logprobs of a request into multiple primitive type lists.
|
Flat logprobs of a request into multiple primitive type lists.
|
||||||
|
|
||||||
Compared to list[dict[int, Logprob]], this data structure reduced GC
|
Compared to list[dict[int, Logprob]], this data structure reduced GC
|
||||||
overhead significantly. As it flattened logprob information for
|
overhead significantly. As it flattened logprob information for
|
||||||
all positions and ranks in to multiple primitive type lists (i.e.
|
all positions and ranks in to multiple primitive type lists (i.e.
|
||||||
logprobs, token_ids, ranks per token_ids, decoded_tokens).
|
logprobs, token_ids, ranks per token_ids, decoded_tokens).
|
||||||
So regardless of the sequence length and top_logprobs setup,
|
So regardless of the sequence length and top_logprobs setup,
|
||||||
FlattenLogprobs would only introduce a constant amount of objects.
|
FlatLogprobs would only introduce a constant amount of objects.
|
||||||
|
|
||||||
As each position might contains different amount of ranks,
|
As each position might contains different amount of ranks,
|
||||||
start_indices_per_position would be used to access the logprob ranges
|
start_indices_per_position would be used to access the logprob ranges
|
||||||
@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
|||||||
def __getitem__(self, position: int) -> LogprobsOnePosition: ...
|
def __getitem__(self, position: int) -> LogprobsOnePosition: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
|
def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...
|
||||||
|
|
||||||
def __getitem__(self, index: int | slice):
|
def __getitem__(self, index: int | slice):
|
||||||
"""Extracts logprobs of a given position or slice"""
|
"""Extracts logprobs of a given position or slice"""
|
||||||
@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
|||||||
elif isinstance(index, slice):
|
elif isinstance(index, slice):
|
||||||
min_index = self.start_indices[index][0]
|
min_index = self.start_indices[index][0]
|
||||||
max_index = self.end_indices[index][-1]
|
max_index = self.end_indices[index][-1]
|
||||||
return FlattenLogprobs(
|
return FlatLogprobs(
|
||||||
# Shift updated start_indices and end_indices to
|
# Shift updated start_indices and end_indices to
|
||||||
# be 0-indexed
|
# be 0-indexed
|
||||||
start_indices=[i - min_index for i in self.start_indices[index]],
|
start_indices=[i - min_index for i in self.start_indices[index]],
|
||||||
@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
|||||||
raise TypeError(f"Invalid index type: {type(index)}")
|
raise TypeError(f"Invalid index type: {type(index)}")
|
||||||
|
|
||||||
def __setitem__(self, item, value) -> None:
|
def __setitem__(self, item, value) -> None:
|
||||||
raise TypeError("Cannot set logprobs in FlattenLogprobs")
|
raise TypeError("Cannot set logprobs in FlatLogprobs")
|
||||||
|
|
||||||
def __delitem__(self, item) -> None:
|
def __delitem__(self, item) -> None:
|
||||||
raise TypeError("Cannot delete logprobs from FlattenLogprobs")
|
raise TypeError("Cannot delete logprobs from FlatLogprobs")
|
||||||
|
|
||||||
def insert(self, item) -> None:
|
def insert(self, item) -> None:
|
||||||
raise TypeError("Cannot insert logprobs to FlattenLogprobs")
|
raise TypeError("Cannot insert logprobs to FlatLogprobs")
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[LogprobsOnePosition]:
|
def __iter__(self) -> Iterator[LogprobsOnePosition]:
|
||||||
"""
|
"""
|
||||||
@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
|
|||||||
|
|
||||||
# {token_id -> logprob} per each sequence group. None if the corresponding
|
# {token_id -> logprob} per each sequence group. None if the corresponding
|
||||||
# sequence group doesn't require prompt logprob.
|
# sequence group doesn't require prompt logprob.
|
||||||
PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
|
PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
|
||||||
# {token_id -> logprob} for each sequence group.
|
# {token_id -> logprob} for each sequence group.
|
||||||
SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
|
SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
|
||||||
|
|
||||||
|
|
||||||
def create_prompt_logprobs() -> PromptLogprobs:
|
def create_prompt_logprobs() -> PromptLogprobs:
|
||||||
"""Creates a container to store prompt logprobs for a request"""
|
"""Creates a container to store prompt logprobs for a request"""
|
||||||
logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
|
logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
|
||||||
# NOTE: logprob of first prompt token is None.
|
# NOTE: logprob of first prompt token is None.
|
||||||
logprobs.append(None)
|
logprobs.append(None)
|
||||||
return logprobs
|
return logprobs
|
||||||
@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:
|
|||||||
|
|
||||||
def create_sample_logprobs() -> SampleLogprobs:
|
def create_sample_logprobs() -> SampleLogprobs:
|
||||||
"""Creates a container to store decode logprobs for a request"""
|
"""Creates a container to store decode logprobs for a request"""
|
||||||
return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
|
return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
|
||||||
|
|
||||||
|
|
||||||
def append_logprobs_for_next_position(
|
def append_logprobs_for_next_position(
|
||||||
@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
|
|||||||
topk_ranks = range(1, num_logprobs + 1)
|
topk_ranks = range(1, num_logprobs + 1)
|
||||||
ranks = itertools.chain((rank,), topk_ranks)
|
ranks = itertools.chain((rank,), topk_ranks)
|
||||||
|
|
||||||
if isinstance(request_logprobs, FlattenLogprobs):
|
if isinstance(request_logprobs, FlatLogprobs):
|
||||||
request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
|
request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
|
||||||
else:
|
else:
|
||||||
request_logprobs.append(
|
request_logprobs.append(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user