[Perf] Introduce FlattenLogprobs to store logprobs results to reduce GC overhead (#28171)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
parent
21b82f4ea2
commit
ccd98b59c1
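For orientation, a minimal sketch of the idea behind this change (illustrative values, not code from the commit): the previous list[dict[int, Logprob]] layout allocates one dict per position plus one small object per (position, rank) for the GC to track, while the flattened layout stores the same information in a constant number of primitive lists.

# Illustrative only -- names and values are hypothetical, not code
# from this commit.

# Nested layout: P positions x K ranks => P dicts + P*K small objects.
nested = [
    {7: (-0.5, 1, "a")},                      # position 0
    {8: (-1.0, 1, "b"), 9: (-2.0, 2, "c")},   # position 1
]

# Flattened layout: the same data in a fixed number of flat lists;
# position i owns token_ids[start_indices[i]:end_indices[i]].
flat = {
    "start_indices": [0, 1],
    "end_indices": [1, 3],
    "token_ids": [7, 8, 9],
    "logprobs": [-0.5, -1.0, -2.0],
    "ranks": [1, 1, 2],
    "decoded_tokens": ["a", "b", "c"],
}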
96 tests/samplers/test_logprobs.py Normal file
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm import SamplingParams
from vllm.logprobs import FlattenLogprobs

MODELS = ["distilbert/distilgpt2"]
MAX_TOKENS = 5
NUM_TOP_LOGPROBS = 5
NUM_PROMPT_LOGPROBS = 7
MAX_LOGPROBS = max(NUM_TOP_LOGPROBS, NUM_PROMPT_LOGPROBS)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("greedy", [True, False])
@pytest.mark.parametrize("flatten_logprobs", [True, False])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    greedy,
    flatten_logprobs,
    example_prompts,
    monkeypatch: pytest.MonkeyPatch,
):
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1" if flatten_logprobs else "0")
    with vllm_runner(model, dtype=dtype, max_logprobs=MAX_LOGPROBS) as vllm_model:
        tokenizer = vllm_model.llm.get_tokenizer()
        example_prompt_tokens = [tokenizer.encode(prompt) for prompt in example_prompts]
        sampling_params = SamplingParams(
            temperature=0.0 if greedy else 1.0,
            top_p=1.0,
            max_tokens=MAX_TOKENS,
            logprobs=NUM_TOP_LOGPROBS,
            prompt_logprobs=NUM_PROMPT_LOGPROBS,
        )
        results = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

    assert len(results) == len(example_prompt_tokens)
    for result, prompt_tokens in zip(results, example_prompt_tokens):
        decode_tokens, _, decode_logprobs, prompt_logprobs = result

        # Ensure the return type of logprobs is accurate
        assert isinstance(
            prompt_logprobs, FlattenLogprobs if flatten_logprobs else list
        )
        assert isinstance(
            decode_logprobs, FlattenLogprobs if flatten_logprobs else list
        )

        ########################
        # Check prompt logprobs
        ########################
        assert len(prompt_tokens) == len(prompt_logprobs)
        # No logprob for the first prompt token
        assert not prompt_logprobs[0]
        for token, logprobs in zip(prompt_tokens[1:], prompt_logprobs[1:]):
            # Ensure the logprob of the prompt token itself is always returned
            logprob = logprobs.get(token)
            assert logprob is not None
            assert logprob.rank >= 1
            # The number of returned logprobs should be either
            # NUM_PROMPT_LOGPROBS or NUM_PROMPT_LOGPROBS + 1 (the extra
            # entry appears when the prompt token falls outside the top-k)
            assert NUM_PROMPT_LOGPROBS <= len(logprobs) <= NUM_PROMPT_LOGPROBS + 1
            # Ensure the top NUM_PROMPT_LOGPROBS ranks are always extracted
            assert set(range(1, NUM_PROMPT_LOGPROBS + 1)).issubset(
                {logprob.rank for logprob in logprobs.values()}
            )

        ########################
        # Check sample logprobs
        ########################
        assert len(decode_tokens) == len(decode_logprobs)
        for token, logprobs in zip(decode_tokens, decode_logprobs):
            # Ensure the logprob of the chosen token is always returned
            logprob = logprobs.get(token)
            assert logprob is not None
            if greedy:
                # For greedy sampling, the chosen token is always top ranked
                assert logprob.rank == 1
            else:
                assert logprob.rank >= 1
            # The number of returned logprobs should be either
            # NUM_TOP_LOGPROBS or NUM_TOP_LOGPROBS + 1 (the extra entry
            # appears when the sampled token falls outside the top-k)
            assert NUM_TOP_LOGPROBS <= len(logprobs) <= NUM_TOP_LOGPROBS + 1
            # Ensure the top NUM_TOP_LOGPROBS ranks are always extracted
            assert set(range(1, NUM_TOP_LOGPROBS + 1)).issubset(
                {logprob.rank for logprob in logprobs.values()}
            )
@@ -1,59 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm import SamplingParams

MODELS = ["distilbert/distilgpt2"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_ranks(
    vllm_runner,
    model,
    dtype,
    example_prompts,
):
    max_tokens = 5
    num_top_logprobs = 5
    num_prompt_logprobs = 5

    with vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) as vllm_model:
        ## Test greedy logprobs ranks
        vllm_sampling_params = SamplingParams(
            temperature=0.0,
            top_p=1.0,
            max_tokens=max_tokens,
            logprobs=num_top_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        vllm_results = vllm_model.generate_w_logprobs(
            example_prompts, vllm_sampling_params
        )

        ## Test non-greedy logprobs ranks
        sampling_params = SamplingParams(
            temperature=1.0,
            top_p=1.0,
            max_tokens=max_tokens,
            logprobs=num_top_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        res = vllm_model.generate_w_logprobs(example_prompts, sampling_params)

    for result in vllm_results:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks = 1
        for token, logprobs in zip(result[0], result[2]):
            assert token in logprobs
            assert logprobs[token].rank == 1

    for result in res:
        assert result[2] is not None
        assert len(result[2]) == len(result[0])
        # check whether all chosen tokens have ranks
        for token, logprobs in zip(result[0], result[2]):
            assert logprobs[token].rank >= 1
222 tests/test_logprobs.py Normal file
@@ -0,0 +1,222 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest

from vllm.logprobs import (
    FlattenLogprobs,
    Logprob,
    LogprobsOnePosition,
    append_logprobs_for_next_position,
    create_prompt_logprobs,
    create_sample_logprobs,
)


def test_create_logprobs_non_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")

    prompt_logprobs = create_prompt_logprobs()
    assert isinstance(prompt_logprobs, list)
    # Ensure the first prompt position's logprobs entry is None
    assert len(prompt_logprobs) == 1
    assert prompt_logprobs[0] is None

    sample_logprobs = create_sample_logprobs()
    assert isinstance(sample_logprobs, list)
    assert len(sample_logprobs) == 0


def test_create_logprobs_flatten(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")

    prompt_logprobs = create_prompt_logprobs()
    assert isinstance(prompt_logprobs, FlattenLogprobs)
    assert prompt_logprobs.start_indices == [0]
    assert prompt_logprobs.end_indices == [0]
    assert len(prompt_logprobs.token_ids) == 0
    assert len(prompt_logprobs.logprobs) == 0
    assert len(prompt_logprobs.ranks) == 0
    assert len(prompt_logprobs.decoded_tokens) == 0
    # Ensure the first prompt position's logprobs are empty
    assert len(prompt_logprobs) == 1
    assert prompt_logprobs[0] == dict()

    sample_logprobs = create_sample_logprobs()
    assert isinstance(sample_logprobs, FlattenLogprobs)
    assert len(sample_logprobs.start_indices) == 0
    assert len(sample_logprobs.end_indices) == 0
    assert len(sample_logprobs.token_ids) == 0
    assert len(sample_logprobs.logprobs) == 0
    assert len(sample_logprobs.ranks) == 0
    assert len(sample_logprobs.decoded_tokens) == 0
    assert len(sample_logprobs) == 0


def test_append_logprobs_for_next_position_none_flatten(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "0")
    logprobs = create_sample_logprobs()
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[1],
        logprobs=[0.1],
        decoded_tokens=["1"],
        rank=10,
        num_logprobs=-1,
    )
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[2, 3],
        logprobs=[0.2, 0.3],
        decoded_tokens=["2", "3"],
        rank=11,
        num_logprobs=-1,
    )
    assert isinstance(logprobs, list)
    assert logprobs == [
        {1: Logprob(logprob=0.1, rank=10, decoded_token="1")},
        {
            2: Logprob(logprob=0.2, rank=11, decoded_token="2"),
            3: Logprob(logprob=0.3, rank=1, decoded_token="3"),
        },
    ]


def test_append_logprobs_for_next_position_flatten(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.setenv("VLLM_FLATTEN_LOGPROBS", "1")
    logprobs = create_sample_logprobs()
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[1],
        logprobs=[0.1],
        decoded_tokens=["1"],
        rank=10,
        num_logprobs=-1,
    )
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[2, 3],
        logprobs=[0.2, 0.3],
        decoded_tokens=["2", "3"],
        rank=11,
        num_logprobs=-1,
    )
    assert isinstance(logprobs, FlattenLogprobs)
    assert logprobs.start_indices == [0, 1]
    assert logprobs.end_indices == [1, 3]
    assert logprobs.token_ids == [1, 2, 3]
    assert logprobs.logprobs == [0.1, 0.2, 0.3]
    assert logprobs.ranks == [10, 11, 1]
    assert logprobs.decoded_tokens == ["1", "2", "3"]


LOGPROBS_ONE_POSITION_0: LogprobsOnePosition = {
    1: Logprob(logprob=0.1, rank=10, decoded_token="10")
}
LOGPROBS_ONE_POSITION_1: LogprobsOnePosition = {
    2: Logprob(logprob=0.2, rank=20, decoded_token="20"),
    3: Logprob(logprob=0.3, rank=30, decoded_token="30"),
}
LOGPROBS_ONE_POSITION_2: LogprobsOnePosition = {
    4: Logprob(logprob=0.4, rank=40, decoded_token="40"),
    5: Logprob(logprob=0.5, rank=50, decoded_token="50"),
    6: Logprob(logprob=0.6, rank=60, decoded_token="60"),
}


def test_flatten_logprobs_append() -> None:
    logprobs = FlattenLogprobs()
    logprobs.append(LOGPROBS_ONE_POSITION_0)
    logprobs.append(LOGPROBS_ONE_POSITION_1)
    assert logprobs.start_indices == [0, 1]
    assert logprobs.end_indices == [1, 3]
    assert logprobs.token_ids == [1, 2, 3]
    assert logprobs.logprobs == [0.1, 0.2, 0.3]
    assert logprobs.ranks == [10, 20, 30]
    assert logprobs.decoded_tokens == ["10", "20", "30"]

    logprobs.append(LOGPROBS_ONE_POSITION_2)
    assert logprobs.start_indices == [0, 1, 3]
    assert logprobs.end_indices == [1, 3, 6]
    assert logprobs.token_ids == [1, 2, 3, 4, 5, 6]
    assert logprobs.logprobs == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    assert logprobs.ranks == [10, 20, 30, 40, 50, 60]
    assert logprobs.decoded_tokens == ["10", "20", "30", "40", "50", "60"]


def test_flatten_logprobs_extend() -> None:
    logprobs = FlattenLogprobs()
    # Extend with list[LogprobsOnePosition]
    logprobs.extend([LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0])
    assert logprobs.start_indices == [0, 3]
    assert logprobs.end_indices == [3, 4]
    assert logprobs.token_ids == [4, 5, 6, 1]
    assert logprobs.logprobs == [0.4, 0.5, 0.6, 0.1]
    assert logprobs.ranks == [40, 50, 60, 10]
    assert logprobs.decoded_tokens == ["40", "50", "60", "10"]

    other_logprobs = FlattenLogprobs()
    other_logprobs.extend([LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_0])
    # Extend with another FlattenLogprobs
    logprobs.extend(other_logprobs)
    assert logprobs.start_indices == [0, 3, 4, 6]
    assert logprobs.end_indices == [3, 4, 6, 7]
    assert logprobs.token_ids == [4, 5, 6, 1, 2, 3, 1]
    assert logprobs.logprobs == [0.4, 0.5, 0.6, 0.1, 0.2, 0.3, 0.1]
    assert logprobs.ranks == [40, 50, 60, 10, 20, 30, 10]
    assert logprobs.decoded_tokens == ["40", "50", "60", "10", "20", "30", "10"]


def test_flatten_logprobs_access() -> None:
    logprobs = FlattenLogprobs()
    logprobs.extend(
        [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0]
    )
    assert logprobs.start_indices == [0, 2, 5]
    assert logprobs.end_indices == [2, 5, 6]
    assert logprobs.token_ids == [2, 3, 4, 5, 6, 1]
    assert logprobs.logprobs == [0.2, 0.3, 0.4, 0.5, 0.6, 0.1]
    assert logprobs.ranks == [20, 30, 40, 50, 60, 10]
    assert logprobs.decoded_tokens == ["20", "30", "40", "50", "60", "10"]

    # Test __len__
    assert len(logprobs) == 3

    # Test __iter__
    for actual_logprobs, expected_logprobs in zip(
        logprobs,
        [LOGPROBS_ONE_POSITION_1, LOGPROBS_ONE_POSITION_2, LOGPROBS_ONE_POSITION_0],
    ):
        assert actual_logprobs == expected_logprobs

    # Test __getitem__: single item
    assert logprobs[0] == LOGPROBS_ONE_POSITION_1
    assert logprobs[1] == LOGPROBS_ONE_POSITION_2
    assert logprobs[2] == LOGPROBS_ONE_POSITION_0

    # Test __getitem__: slice
    logprobs02 = logprobs[:2]
    assert len(logprobs02) == 2
    assert logprobs02[0] == LOGPROBS_ONE_POSITION_1
    assert logprobs02[1] == LOGPROBS_ONE_POSITION_2
    assert logprobs02.start_indices == [0, 2]
    assert logprobs02.end_indices == [2, 5]
    assert logprobs02.token_ids == [2, 3, 4, 5, 6]
    assert logprobs02.logprobs == [0.2, 0.3, 0.4, 0.5, 0.6]
    assert logprobs02.ranks == [20, 30, 40, 50, 60]
    assert logprobs02.decoded_tokens == ["20", "30", "40", "50", "60"]

    logprobs_last2 = logprobs[-2:]
    assert len(logprobs_last2) == 2
    assert logprobs_last2[0] == LOGPROBS_ONE_POSITION_2
    assert logprobs_last2[1] == LOGPROBS_ONE_POSITION_0
    assert logprobs_last2.start_indices == [0, 3]
    assert logprobs_last2.end_indices == [3, 4]
    assert logprobs_last2.token_ids == [4, 5, 6, 1]
    assert logprobs_last2.logprobs == [0.4, 0.5, 0.6, 0.1]
    assert logprobs_last2.ranks == [40, 50, 60, 10]
    assert logprobs_last2.decoded_tokens == ["40", "50", "60", "10"]
@@ -220,6 +220,7 @@ if TYPE_CHECKING:
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
+    VLLM_FLATTEN_LOGPROBS: bool = False


 def get_default_cache_root():
@@ -1463,6 +1464,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
+    # Flag to enable FlattenLogprobs, whose GC overhead is significantly
+    # smaller than that of the original list[dict[int, Logprob]] approach.
+    # When enabled, PromptLogprobs and SampleLogprobs are populated as
+    # FlattenLogprobs.
+    "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
 }

 # --8<-- [end:env-vars-definition]
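As a usage note, a minimal sketch of opting in to the new container via the flag above (the model and sampling values are placeholders borrowed from the tests):

import os

# Set before logprob containers are created; vllm.envs reads
# environment variables lazily, on attribute access.
os.environ["VLLM_FLATTEN_LOGPROBS"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="distilbert/distilgpt2")
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(max_tokens=5, logprobs=5),
)
# With the flag on, per-request logprobs are backed by FlattenLogprobs,
# which still supports list-style indexing and iteration.
per_position = outputs[0].outputs[0].logprobs
print(type(per_position), per_position[0])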
186 vllm/logprobs.py
@@ -1,6 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass
+import itertools
+from collections.abc import Iterable, Iterator, MutableSequence
+from dataclasses import dataclass, field
+from typing import overload
+
+import vllm.envs as envs


 # We use dataclass for now because it is used for
@@ -21,8 +26,183 @@ class Logprob:
     decoded_token: str | None = None


+LogprobsOnePosition = dict[int, Logprob]
+
+
+@dataclass
+class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
+    """
+    Flattens the logprobs of a request into multiple primitive-type lists.
+
+    Compared to list[dict[int, Logprob]], this data structure reduces GC
+    overhead significantly: it flattens the logprob information for all
+    positions and ranks into multiple primitive-type lists (i.e. logprobs,
+    token_ids, ranks and decoded_tokens), so regardless of the sequence
+    length and top_logprobs setup, FlattenLogprobs introduces only a
+    constant number of objects.
+
+    As each position might contain a different number of ranks,
+    start_indices and end_indices are used to access the logprob
+    range of each position.
+
+    NOTE: To reduce migration overhead and improve backward compatibility,
+    we support the key Sequence APIs of list, so this container can act
+    as a list[LogprobsOnePosition].
+    """
+
+    # Start / end indices of the range of logprobs for each position.
+    start_indices: list[int] = field(default_factory=list)
+    end_indices: list[int] = field(default_factory=list)
+
+    # Flattened Logprob information for each (position, rank).
+    # For position i, the logprobs span
+    # self.start_indices[i] to self.end_indices[i] (exclusive).
+    token_ids: list[int] = field(default_factory=list)
+    logprobs: list[float] = field(default_factory=list)
+    ranks: list[int | None] = field(default_factory=list)
+    decoded_tokens: list[str | None] = field(default_factory=list)
+
+    def append(self, logprobs_one_position: LogprobsOnePosition | None) -> None:
+        """Appends logprobs for the next position to the container."""
+        self.start_indices.append(len(self.logprobs))
+        if logprobs_one_position:
+            for token_id, logprob in logprobs_one_position.items():
+                self.token_ids.append(token_id)
+                self.logprobs.append(logprob.logprob)
+                self.ranks.append(logprob.rank)
+                self.decoded_tokens.append(logprob.decoded_token)
+        self.end_indices.append(len(self.logprobs))
+
+    def append_fast(
+        self,
+        token_ids: list[int],
+        logprobs: list[float],
+        ranks: itertools.chain[int],
+        decoded_tokens: Iterable[str | None],
+    ) -> None:
+        """
+        Appends logprobs for the next position without creating
+        the intermediate logprob dictionary.
+        """
+        self.start_indices.append(len(self.logprobs))
+        for token_id, logprob, rank, decoded_token in zip(
+            token_ids, logprobs, ranks, decoded_tokens
+        ):
+            self.token_ids.append(token_id)
+            self.logprobs.append(logprob)
+            self.ranks.append(rank)
+            self.decoded_tokens.append(decoded_token)
+        self.end_indices.append(len(self.logprobs))
+
+    def extend(self, logprobs_multi_positions) -> None:
+        """Extends the container with logprobs for multiple subsequent positions."""
+        for logprobs_one_position in logprobs_multi_positions:
+            self.append(logprobs_one_position)
+
+    def __len__(self) -> int:
+        """Returns the number of positions stored in the container."""
+        return len(self.start_indices)
+
+    @overload
+    def __getitem__(self, position: int) -> LogprobsOnePosition: ...
+
+    @overload
+    def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
+
+    def __getitem__(self, index: int | slice):
+        """Extracts the logprobs of a given position or slice."""
+        if isinstance(index, int):
+            return {
+                self.token_ids[i]: Logprob(
+                    logprob=self.logprobs[i],
+                    rank=self.ranks[i],
+                    decoded_token=self.decoded_tokens[i],
+                )
+                for i in range(self.start_indices[index], self.end_indices[index])
+            }
+        elif isinstance(index, slice):
+            min_index = self.start_indices[index][0]
+            max_index = self.end_indices[index][-1]
+            return FlattenLogprobs(
+                # Shift the updated start_indices and end_indices to
+                # be 0-indexed
+                start_indices=[i - min_index for i in self.start_indices[index]],
+                end_indices=[i - min_index for i in self.end_indices[index]],
+                token_ids=self.token_ids[min_index:max_index],
+                logprobs=self.logprobs[min_index:max_index],
+                ranks=self.ranks[min_index:max_index],
+                decoded_tokens=self.decoded_tokens[min_index:max_index],
+            )
+        else:
+            raise TypeError(f"Invalid index type: {type(index)}")
+
+    def __setitem__(self, item, value) -> None:
+        raise TypeError("Cannot set logprobs in FlattenLogprobs")
+
+    def __delitem__(self, item) -> None:
+        raise TypeError("Cannot delete logprobs from FlattenLogprobs")
+
+    def insert(self, item) -> None:
+        raise TypeError("Cannot insert logprobs into FlattenLogprobs")
+
+    def __iter__(self) -> Iterator[LogprobsOnePosition]:
+        """
+        Iterates over the container, yielding LogprobsOnePosition for
+        each position.
+        """
+        for i in range(0, len(self.start_indices)):
+            yield self.__getitem__(i)
+
+
 # {token_id -> logprob} per each sequence group. None if the corresponding
 # sequence group doesn't require prompt logprob.
-PromptLogprobs = list[dict[int, Logprob] | None]
+PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
 # {token_id -> logprob} for each sequence group.
-SampleLogprobs = list[dict[int, Logprob]]
+SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
+
+
+def create_prompt_logprobs() -> PromptLogprobs:
+    """Creates a container to store prompt logprobs for a request."""
+    logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    # NOTE: the logprob of the first prompt token is None.
+    logprobs.append(None)
+    return logprobs
+
+
+def create_sample_logprobs() -> SampleLogprobs:
+    """Creates a container to store decode logprobs for a request."""
+    return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+
+
+def append_logprobs_for_next_position(
+    request_logprobs: PromptLogprobs | SampleLogprobs,
+    token_ids: list[int],
+    logprobs: list[float],
+    decoded_tokens: Iterable[str | None],
+    rank: int,
+    num_logprobs: int,
+) -> None:
+    """Appends logprobs for the next position."""
+    if num_logprobs == -1:
+        num_logprobs = len(logprobs)
+    # We do not need a special case for the sampled token
+    # being in the topk, since inserting duplicated data
+    # into a dictionary twice is the same as doing it once.
+    topk_ranks = range(1, num_logprobs + 1)
+    ranks = itertools.chain((rank,), topk_ranks)
+
+    if isinstance(request_logprobs, FlattenLogprobs):
+        request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
+    else:
+        request_logprobs.append(
+            {
+                token_id: Logprob(
+                    logprob=logprob,
+                    rank=rank,
+                    decoded_token=token,
+                )
+                for token_id, logprob, rank, token in zip(
+                    token_ids, logprobs, ranks, decoded_tokens
+                )
+            }
+        )
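A brief usage sketch of the container defined above (values illustrative), showing how the flat index bookkeeping maps back to per-position dict views:

from vllm.logprobs import FlattenLogprobs, Logprob

fl = FlattenLogprobs()
fl.append({7: Logprob(logprob=-0.5, rank=1, decoded_token="a")})
fl.append(
    {
        8: Logprob(logprob=-1.0, rank=1, decoded_token="b"),
        9: Logprob(logprob=-2.0, rank=2, decoded_token="c"),
    }
)

assert len(fl) == 2                # two positions stored
assert fl.start_indices == [0, 1]  # flat range owned by each position
assert fl.end_indices == [1, 3]
assert fl[1][9].rank == 2          # dict view is rebuilt on access

tail = fl[1:]                      # slicing returns a FlattenLogprobs
assert tail.start_indices == [0]   # indices re-based to zero
assert tail.token_ids == [8, 9]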
@@ -2,11 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import itertools
 from collections.abc import Iterable
 from dataclasses import dataclass

 from vllm.logger import init_logger
-from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.logprobs import (
+    PromptLogprobs,
+    SampleLogprobs,
+    append_logprobs_for_next_position,
+    create_prompt_logprobs,
+    create_sample_logprobs,
+)
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer,
     convert_ids_list_to_tokens,
@@ -44,9 +49,10 @@ class LogprobsProcessor:
         return cls(
             tokenizer=tokenizer,
             cumulative_logprob=(None if num_logprobs is None else 0.0),
-            logprobs=(None if num_logprobs is None else []),
-            # NOTE: logprob of first prompt token is None.
-            prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
+            logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+            prompt_logprobs=(
+                None if num_prompt_logprobs is None else create_prompt_logprobs()
+            ),
             num_prompt_logprobs=num_prompt_logprobs,
             num_logprobs=num_logprobs,
         )
@@ -80,15 +86,14 @@ class LogprobsProcessor:
         sampled_token_logprob = logprobs[0]
         self.cumulative_logprob += sampled_token_logprob

-        # Update with the Logprob dictionary for this pos.
-        self.logprobs.append(
-            self._make_logprob_dict(
-                logprobs,
-                token_ids,
-                decoded_tokens,
-                rank,
-                self.num_logprobs,
-            )
-        )
+        # Update with the Logprob container for this pos.
+        append_logprobs_for_next_position(
+            self.logprobs,
+            token_ids,
+            logprobs,
+            decoded_tokens,
+            rank,
+            self.num_logprobs,
+        )

     def _update_prompt_logprobs(
@@ -136,15 +141,14 @@ class LogprobsProcessor:
             NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
         )

-        # Update with the Logprob dictionary for this pos.
-        self.prompt_logprobs.append(
-            self._make_logprob_dict(
-                prompt_logprobs[pos],
-                token_ids[pos],
-                decoded_tokens_for_pos,
-                prompt_token_ranks[pos],
-                self.num_prompt_logprobs,
-            )
-        )
+        # Update with the Logprob container for this pos.
+        append_logprobs_for_next_position(
+            self.prompt_logprobs,
+            token_ids[pos],
+            prompt_logprobs[pos],
+            decoded_tokens_for_pos,
+            prompt_token_ranks[pos],
+            self.num_prompt_logprobs,
+        )

     def pop_prompt_logprobs(self) -> PromptLogprobs | None:
@@ -166,46 +170,6 @@ class LogprobsProcessor:
         self.prompt_logprobs = []
         return plp

-    @staticmethod
-    def _make_logprob_dict(
-        logprobs: list[float],
-        logprob_token_ids: list[int],
-        decoded_tokens: Iterable[str | None],
-        rank: int,
-        num_logprobs: int,
-    ) -> dict[int, Logprob]:
-        """Make a Logprob dictionary for a position.
-
-        Args:
-          logprobs: list of log probabilities
-          logprob_token_ids: list of top token ids
-          decoded_tokens: list of decoded top tokens
-          rank: rank of the sampled token
-          num_logprobs: number of logprobs requested
-            by the user (in addition to sampled logprob)
-
-        Returns:
-          dict[token id, Logprob]
-        """
-        if num_logprobs == -1:
-            num_logprobs = len(logprobs)
-        # We do not need a special case for the sampled token
-        # being in the topk, since inserting duplicated data
-        # into a dictionary twice is the same as doing it once.
-        topk_ranks = range(1, num_logprobs + 1)
-        ranks = itertools.chain((rank,), topk_ranks)
-
-        return {
-            token_id: Logprob(
-                logprob=logprob,
-                rank=rank,
-                decoded_token=token,
-            )
-            for token_id, logprob, rank, token in zip(
-                logprob_token_ids, logprobs, ranks, decoded_tokens
-            )
-        }
-
     def update_from_output(self, output: EngineCoreOutput) -> None:
         if output.new_logprobs is not None:
             self._update_sample_logprobs(output.new_logprobs)
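Finally, a hedged sketch of why this refactor needs no downstream changes: append_logprobs_for_next_position fills either container type with the same logical content, and both expose the same per-position dict view (relying on vllm.envs evaluating VLLM_FLATTEN_LOGPROBS lazily, as the new tests also do):

import os

from vllm.logprobs import (
    append_logprobs_for_next_position,
    create_sample_logprobs,
)

for flag in ("0", "1"):
    os.environ["VLLM_FLATTEN_LOGPROBS"] = flag
    logprobs = create_sample_logprobs()  # plain list or FlattenLogprobs
    append_logprobs_for_next_position(
        logprobs,
        token_ids=[2, 3],
        logprobs=[-0.2, -0.3],
        decoded_tokens=["b", "c"],
        rank=2,         # true rank of the sampled token (id 2)
        num_logprobs=1,
    )
    # Same dict view regardless of the backing container.
    assert logprobs[0][2].rank == 2
    assert logprobs[0][3].rank == 1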