[V1] LogitsProcessor programming model (#16728)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
afeldman-nm 2025-07-02 12:10:42 -04:00 committed by GitHub
parent c1909e7e8c
commit 48fb076cbc
13 changed files with 1401 additions and 393 deletions

View File

@@ -0,0 +1,626 @@
# SPDX-License-Identifier: Apache-2.0
import random
from collections.abc import Callable
from typing import NamedTuple, Optional, Union
import numpy as np
import pytest
import torch
from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
create_penalty_tensor,
create_prompt_tokens_tensor,
fake_apply_logitsprocs,
fake_update_logitsprocs_state)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
# yapf: disable
from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
init_builtin_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata
PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
f"{current_platform.device_type}:{i}"
for i in range(1 if current_platform.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64
MIN_TOKENS_LEN_THRESHOLD = 5
REQS_PER_LOGITPROC = 50
STR_NO_LOGITPROC = "none"
# LogitsProcessor subclass or "none"
LogitprocType = Union[type[LogitsProcessor], str]
class LogitsProcsRequestParams:
"""Encapsulates key params for a single request in a batch.
Params can be customized based on the enabled logitproc
"""
workload_index: int
logitproc_type: LogitprocType # Logitproc enabled, specified by str id
out_tokens: list[int] # Output tokens required for min tokens test
params: SamplingParams # Settings customized for logitproc
def __init__(self, workload_index: int, logitproc_type: LogitprocType):
self.workload_index = workload_index
self.logitproc_type = logitproc_type
# Number of output tokens is randomly 0x, 1x or 2x the min-tokens
# threshold which will be used in testing. Output token values
# don't matter *for these tests* so use 0 as a dummy value
self.out_tokens = ([0] *
(MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
self.params = _sampling_params_from_logitproc(logitproc_type)
def __str__(self):
"""For debugging"""
summ = ', '.join(f'{k}={v}' for k, v in vars(self).items())
return f"MyClass({summ})"
def _generate_fake_sampling_metadata(
num_output_tokens: int,
batch_size: int,
vocab_size: int,
device: torch.device,
) -> SamplingMetadata:
"""Generate fake sampling metadata with fake logitsprocs"""
output_token_ids: list[list[int]] = []
prompt_token_ids: list[list[int]] = []
for _ in range(batch_size):
output_token_ids.append(
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
prompt_token_ids.append(
np.random.randint(0,
vocab_size,
size=np.random.randint(
1, MAX_NUM_PROMPT_TOKENS)).tolist())
logitsprocs = init_builtin_logitsprocs(
pin_memory_available=PIN_MEMORY_AVAILABLE,
max_num_reqs=MAX_NUM_REQS + 1,
device=device)
fake_sampling_metadata = SamplingMetadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
all_random=False,
top_p=None,
top_k=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=logitsprocs)
return fake_sampling_metadata
def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
"""Generate fake logits and sampling metadata"""
fake_logits = create_fake_logits(batch_size, VOCAB_SIZE)
# Create one dominant token per batch, to support min-p test
for i in range(batch_size):
fake_logits[i, 0] = 10.0 # High logit for first token
fake_logits[i, 1:] = 1e-2 # Others remain low
sampling_metadata = _generate_fake_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
return LogitsprocsTestFakes(
logits=fake_logits,
sampling_metadata=sampling_metadata,
)
def _sampling_params_from_logitproc(
logitproc_type: LogitprocType) -> SamplingParams:
"""Customize request SamplingParams for a specified logitproc"""
# SamplingParams for req with no logitproc
kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0}
if fxn := logitsprocs_test_mapping[logitproc_type].gen_request_fxn:
fxn(kwargs)
return SamplingParams(**kwargs)
def _generate_mixed_logitsprocs_batch_params(
reqs_per_logitproc: int,
logitsprocs_types: list[str],
) -> list[LogitsProcsRequestParams]:
"""Define key params for a batch of requests with a different
logitproc enabled per request.
The batch will have `reqs_per_logitproc` repeats for all
`logitsprocs_types` under test, including the case where
no logitsproc is enabled. The batch is randomly shuffled. The
size of the batch is `reqs_per_logitproc` times
`n = len(logitsprocs_types)`
Args:
reqs_per_logitproc: number of requests using each logitproc
logitsprocs_types: logitsprocs under test
Returns:
List of per-request params which configure the engine for that request's
enabled logitproc
"""
batch_size = len(logitsprocs_types) * reqs_per_logitproc
# Generate multiple repeats of key params for each logitproc;
# apply random inverse permutation to the iteration
# over logitsprocs, such that logitsprocs are shuffled.
batch_perm = random.sample(range(batch_size), k=batch_size)
return [
LogitsProcsRequestParams(
workload_index=idx,
logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc])
for idx, pdx in enumerate(batch_perm)
]
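# Illustrative sketch (not part of this diff): with hypothetical arguments
# reqs_per_logitproc=2 and two logitproc types, the helper above yields a
# shuffled batch of four request-param objects, two per logitproc type,
# with workload indices assigned in batch order.
def _example_mixed_batch_params() -> None:
    params = _generate_mixed_logitsprocs_batch_params(
        reqs_per_logitproc=2,
        logitsprocs_types=[STR_NO_LOGITPROC, MinPLogitsProcessor])
    assert len(params) == 4
    assert sorted(p.workload_index for p in params) == [0, 1, 2, 3]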
def _raise_error_invalid(
msg_suffix: str,
batch_index: int,
request_params: LogitsProcsRequestParams,
step_idx: int,
err_cls: type[Exception] = ValueError,
) -> None:
raise err_cls(f"Validation failed for step={step_idx}, "
f"batch_index={batch_index}, "
f"workload_index={request_params.workload_index}, "
f"req_params={request_params}. Reason: {msg_suffix}")
def _logit_bias_params(kwargs: dict) -> None:
"""Logit bias config"""
kwargs["logit_bias"] = {
random.randint(0, VOCAB_SIZE - 1): random.choice([-0.1, 0.2])
}
def _logit_bias_validate(
test_fakes: LogitsprocsTestFakes,
persistent_batch: list[LogitsProcsRequestParams],
logits_new: torch.Tensor,
batch_index: int,
request_params: LogitsProcsRequestParams,
step_idx: int,
) -> None:
"""Validate logit bias logitproc applied correctly"""
logit_bias = request_params.params.logit_bias
logits_old = (
test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
logits_new = logits_new[batch_index].cpu()
for token_id in range(VOCAB_SIZE):
logit_old_value = logits_old[token_id]
logit_new_value = logits_new[token_id]
if token_id in logit_bias:
bias_value = logit_bias[token_id]
exp_value = bias_value + logit_old_value
if logit_new_value != pytest.approx(exp_value):
_raise_error_invalid(msg_suffix=(
f"Biased token {token_id} logit value {logit_new_value} "
f"does not match expected value {exp_value} "
f"given bias {bias_value}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
else:
if logit_new_value != pytest.approx(logit_old_value):
_raise_error_invalid(msg_suffix=(
f"Unbiased token {token_id} logit value {logit_new_value} "
f"does not match expected value {logit_old_value}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
def _min_p_params(kwargs: dict) -> None:
"""Min-p logitproc config"""
kwargs["min_p"] = 0.1
def _min_p_validate(
test_fakes: LogitsprocsTestFakes,
persistent_batch: list[LogitsProcsRequestParams],
logits_new: torch.Tensor,
batch_index: int,
request_params: LogitsProcsRequestParams,
step_idx: int,
) -> None:
"""Validate min-p logitproc applied correctly"""
for token_id in range(VOCAB_SIZE):
logits_for_token = logits_new[batch_index][token_id]
if token_id == 0:
# Dominant token should always be unmasked
if logits_for_token == -float("inf"):
_raise_error_invalid(
msg_suffix="Invalid: dominant token 0 masked (-inf)",
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
else:
if request_params.params.min_p > 0.0:
# Non-dominant tokens should be masked when min_p > 0
if logits_for_token != -float("inf"):
_raise_error_invalid(
msg_suffix=
f"Invalid: non-dominant token {token_id} not masked",
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
else:
# No masking when min_p is 0
if logits_for_token == -float("inf"):
_raise_error_invalid(
msg_suffix=
f"Invalid: token {token_id} masked when min_p=0.0",
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
def _min_tokens_params(kwargs: dict) -> None:
"""Min-tokens logitproc config"""
kwargs["min_tokens"] = MIN_TOKENS_LEN_THRESHOLD
kwargs["stop_token_ids"] = [
np.random.randint(0, VOCAB_SIZE - 1)
for _ in range(np.random.randint(0, VOCAB_SIZE))
]
def _min_tokens_validate(
test_fakes: LogitsprocsTestFakes,
persistent_batch: list[LogitsProcsRequestParams],
logits_new: torch.Tensor,
batch_index: int,
request_params: LogitsProcsRequestParams,
step_idx: int,
) -> None:
"""Validate min-tokens logitsproc applied correctly"""
ref_num_out_tokens = len(request_params.out_tokens)
min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD
ref_all_stop_token_ids = request_params.params.all_stop_token_ids
mt_lp: MinTokensLogitsProcessor = next(
test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor))
assert isinstance(mt_lp, MinTokensLogitsProcessor)
min_tok = mt_lp.min_toks.get(batch_index, None)
# Validate min-token logits processor state
if min_tok:
(_, out_tok, all_stop_token_ids) = min_tok
num_out_tokens = len(out_tok)
if num_out_tokens != ref_num_out_tokens:
_raise_error_invalid(msg_suffix=(
"Number of output tokens in min-token logit processor "
f"request metadata ({num_out_tokens}) does not match "
f"reference ({ref_num_out_tokens})."),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
if ref_all_stop_token_ids != all_stop_token_ids:
_raise_error_invalid(msg_suffix=(
"Stop token ids do not match reference; all_stop_token_ids: "
f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
f"{sorted(ref_all_stop_token_ids)}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
if min_reached:
_raise_error_invalid(msg_suffix=(
"Expected min-tokens request with min reached, but batch "
"index is recognized by min-tokens logits processor."),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
err_cls=RuntimeError)
elif not min_reached:
_raise_error_invalid(msg_suffix=(
"Expected min-tokens request with min not reached, but batch "
"index is not recognized by min-tokens logits processor."),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx,
err_cls=RuntimeError)
# Validate min-token logits
for token_id in range(VOCAB_SIZE):
logits_for_token = logits_new[batch_index][token_id]
if token_id in ref_all_stop_token_ids and not min_reached:
if logits_for_token != -float("inf"):
_raise_error_invalid(
msg_suffix=(f"Token {token_id} is a stop token and "
"the sequence has not reached min length, "
"but the token is not masked "
f"(logit={logits_for_token})"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
else:
if logits_for_token == -float("inf"):
_raise_error_invalid(
msg_suffix=(f"Token {token_id} should not be masked but "
f"is (output len={ref_num_out_tokens})"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
def _none_validate(
test_fakes: LogitsprocsTestFakes,
persistent_batch: list[LogitsProcsRequestParams],
logits_new: torch.Tensor,
batch_index: int,
request_params: LogitsProcsRequestParams,
step_idx: int,
) -> None:
"""Validate that no logits processors are applied"""
logits = (
test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
ref_logits = logits_new[batch_index]
if not torch.all(ref_logits == logits):
mismatch_toks = (ref_logits
!= logits).nonzero(as_tuple=True)[0].tolist()
mismatch_strs = []
for token in mismatch_toks:
val = float(logits[token])
ref_val = float(ref_logits[token])
mismatch_strs.append(f"({token=},{val=},{ref_val=})")
_raise_error_invalid(msg_suffix=(
f"Unexpected modification of logits: {','.join(mismatch_strs)}"),
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
class LogitsprocTestHelpers(NamedTuple):
"""Supports setting up and validating logitsprocs unit tests."""
eval_fxn: Callable
gen_request_fxn: Optional[Callable] = None
logitsprocs_test_mapping = {
STR_NO_LOGITPROC:
LogitsprocTestHelpers(eval_fxn=_none_validate),
LogitBiasLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params,
eval_fxn=_logit_bias_validate),
MinPLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_min_p_params,
eval_fxn=_min_p_validate),
MinTokensLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params,
eval_fxn=_min_tokens_validate),
}
def _get_test_cases() -> list[list[str]]:
"""Each test case is a set of logitsprocs"""
logitsprocs_types = list(logitsprocs_test_mapping.keys())
return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC]
for logitproc_type in logitsprocs_types
if logitproc_type != STR_NO_LOGITPROC
] + [logitsprocs_types]
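# For reference (annotation, not part of this diff): with the four entries in
# logitsprocs_test_mapping this produces five cases: the no-logitproc
# baseline, each of the three logitprocs paired with the baseline, and the
# full mixture.
def _example_test_cases() -> None:
    cases = _get_test_cases()
    assert cases[0] == [STR_NO_LOGITPROC]  # baseline only
    assert all(STR_NO_LOGITPROC in case for case in cases)
    assert cases[-1] == list(logitsprocs_test_mapping.keys())  # full mixture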
def _generate_fake_step_update(
persistent_batch: list[LogitsProcsRequestParams],
workload_params: list[LogitsProcsRequestParams],
wdx: int,
batch_update_builder: BatchUpdateBuilder,
) -> tuple[Optional[BatchUpdate], int, int]:
batch_size = len(persistent_batch)
workload_size = len(workload_params)
workload_reqs_remaining = workload_size - wdx
max_add_remove_per_step = max(1, int(0.2 * workload_size))
# 50% of steps: add no reqs
# Other 50%: add a limited number of reqs (at most the number
# of workload reqs remaining, and at most an arbitrary per-step max)
# If no workload reqs remain: 100% of steps have 0 adds
num_step_add = random.choice([
0,
random.randint(1, min(max_add_remove_per_step,
workload_reqs_remaining))
]) if workload_reqs_remaining else 0
# 50% of steps: remove no requests
# Other 50%: remove a limited number of reqs (at most the number
# of persistent batch reqs remaining, and at most an arbitrary per-step max)
# If persistent batch is empty: 100% of steps have 0 removals until
# more requests are added. Assume that removed requests are always
# drawn from the current batch, before new adds
num_step_remove = random.choice([
0, random.randint(1, min(max_add_remove_per_step, batch_size))
]) if batch_size else 0
num_step_add_replace = min(num_step_add, num_step_remove)
# Generate fake removed request indices drawn from persistent batch indices
for removal in random.sample(range(batch_size), num_step_remove):
batch_update_builder.removed_append(removal)
# Get added requests from workload
for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]:
# Replace as many removed requests as possible with added requests
add_remove_idx = batch_update_builder.pop_removed()
batch_update_builder.added.append(
(add_remove_idx, add_req_params.params, add_req_params.out_tokens))
persistent_batch[add_remove_idx] = add_req_params
# Append remaining added requests to end of batch
add_reqs_append = workload_params[(wdx +
num_step_add_replace):(wdx +
num_step_add)]
batch_update_builder.added.extend([
(adx + batch_size, add_req_params.params, add_req_params.out_tokens)
for adx, add_req_params in enumerate(add_reqs_append)
])
persistent_batch.extend(add_reqs_append)
pre_condense_batch_size = len(persistent_batch)
wdx += num_step_add # Update workload offset
# Simulate condensing persistent batch
last_nonempty_index = pre_condense_batch_size - 1
condensed_to_idxs = set()
while batch_update_builder.removed:
if (last_nonempty_index in batch_update_builder.removed
or last_nonempty_index in condensed_to_idxs):
last_nonempty_index -= 1
continue
# last_nonempty_index is the highest persistent batch index that was
# not removed
first_empty_index = batch_update_builder.peek_removed()
assert first_empty_index is not None
if first_empty_index > last_nonempty_index:
break
# first_empty_index is the lowest removed persistent batch index
# that is less than last_nonempty_index
#
# move last_nonempty_index -> first_empty_index
batch_update_builder.pop_removed()
condensed_to_idxs.add(first_empty_index)
persistent_batch[first_empty_index] = persistent_batch[
last_nonempty_index]
batch_update_builder.moved.append(
(last_nonempty_index, first_empty_index,
MoveDirectionality.UNIDIRECTIONAL))
last_nonempty_index -= 1
# Now removed requests & gaps left by non-removed requests that got
# moved downward are grouped consecutively in the upper indices of
# the persistent batch. Truncate them to get condensed persistent batch
condensed_batch_size = batch_size + num_step_add - num_step_remove
persistent_batch[:] = persistent_batch[0:condensed_batch_size]
if condensed_batch_size > 1:
# Simulate arbitrary reorder_batch() in the kernel backend
# Generate a random number k of non-overlapping swap tuples
k = random.randint(0, condensed_batch_size // 2)
idxs = list(range(condensed_batch_size))
random.shuffle(idxs)
swaps = [
tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)
]
batch_update_builder.moved.extend([
(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps
])
for adx, bdx in swaps:
persistent_batch[adx], persistent_batch[bdx] = persistent_batch[
bdx], persistent_batch[adx]
return (batch_update_builder.get_and_reset(condensed_batch_size), wdx,
workload_size - wdx)
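# A concrete illustration (hypothetical indices, not part of this diff) of the
# condense step simulated above: with a batch of five requests and removals at
# indices {1, 3}, the highest surviving request is moved down into the lowest
# gap and the batch is truncated to its condensed size of three.
def _example_condense() -> None:
    builder = BatchUpdateBuilder()
    batch = ["A", "B", "C", "D", "E"]  # hypothetical request ids
    for idx in (1, 3):
        builder.removed_append(idx)
    dst = builder.pop_removed()  # lowest removed index -> 1
    batch[dst] = batch[4]  # move request "E" down into the gap
    builder.moved.append((4, dst, MoveDirectionality.UNIDIRECTIONAL))
    batch = batch[:3]
    assert batch == ["A", "E", "C"]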
def _assert_valid(
batch_size: int,
persistent_batch: list[LogitsProcsRequestParams],
test_fakes: LogitsprocsTestFakes,
slice_idxs: list[int],
logits_w_lp: torch.Tensor,
step_idx: int,
) -> None:
if not slice_idxs:
# Trivial case of empty persistent batch
assert len(persistent_batch) == 0
if logits_w_lp.shape[0] != 0:
raise ValueError("Fake persistent batch is empty but logitsprocs "
f"output batch has shape {logits_w_lp.shape}")
return
# Validate logits for each fake request
for batch_index in range(batch_size):
request_params = persistent_batch[batch_index]
# Invoke the appropriate validation function for
# the logitproc employed by this request
fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn
fxn(test_fakes=test_fakes,
persistent_batch=persistent_batch,
logits_new=logits_w_lp,
batch_index=batch_index,
request_params=request_params,
step_idx=step_idx)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs(device: str, reqs_per_logitproc: int,
logitsprocs_under_test: list[str]):
random.seed(40)
torch.set_default_device(device)
# Define a shuffled batch of requests, each enabling one of the
# logitprocs under test or no logitproc at all
workload_params = _generate_mixed_logitsprocs_batch_params(
reqs_per_logitproc=reqs_per_logitproc,
logitsprocs_types=logitsprocs_under_test)
workload_size = len(workload_params)
# Create fake test data structures for testing.
test_fakes = _generate_test_fakes(workload_size, device)
wdx = 0 # Next request index in workload to add
persistent_batch: list[LogitsProcsRequestParams] = [
]  # Persistent batch state, as a list of per-request params
# Accumulates the adds, removals and moves applied to the persistent
# batch in each step
batch_update_builder = BatchUpdateBuilder()
# Break when entire workload has been added previously and persistent
# batch is empty
workload_reqs_remaining = workload_size
batch_size = 0
step_idx = 0
while True:
if not (workload_reqs_remaining or batch_size):
break
(
batch_update,
wdx,
workload_reqs_remaining,
) = _generate_fake_step_update(
persistent_batch=persistent_batch,
workload_params=workload_params,
wdx=wdx,
batch_update_builder=batch_update_builder,
)
batch_size = len(persistent_batch)
# Apply fake batch update to logitsprocs
fake_update_logitsprocs_state(test_fakes, batch_update)
# Emulate application of logits processors in engine
slice_idxs = [req.workload_index for req in persistent_batch]
logits_w_lp = fake_apply_logitsprocs(test_fakes, slice_idxs).cpu()
_assert_valid(
batch_size=batch_size,
persistent_batch=persistent_batch,
test_fakes=test_fakes,
slice_idxs=slice_idxs,
logits_w_lp=logits_w_lp,
step_idx=step_idx,
)
step_idx += 1

View File

@@ -13,9 +13,10 @@ EXPECTED_VALUE = 0.62
# FIXME(rob): enable prefix caching once supported.
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501
SERVER_ARGS = [
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests"
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests",
"--gpu-memory-utilization=0.8"
]
NUM_CONCURRENT = 100

View File

@@ -7,6 +7,7 @@ import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
RejectionSampler)
@@ -58,7 +59,6 @@ def create_sampling_metadata(
all_random=not all_greedy,
top_p=top_p,
top_k=top_k,
min_p=torch.empty(1, ),
generators=generators,
max_num_logprobs=0,
no_penalties=False,
@@ -67,10 +67,9 @@ def create_sampling_metadata(
presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]),
output_token_ids=[],
min_tokens={},
logit_bias=[None],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
)

View File

@@ -8,10 +8,13 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils import make_tensor_with_pad
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
@@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor(
)
def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> list[Optional[dict[int, float]]]:
res: list[Optional[dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res
def _create_allowed_token_ids(
batch_size: int,
vocab_size: int,
@@ -145,7 +136,6 @@ def _create_default_sampling_metadata(
all_random=False,
top_p=None,
top_k=None,
min_p=None,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
@@ -155,43 +145,13 @@ def _create_default_sampling_metadata(
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True,
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
)
return fake_sampling_metadata
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: list[int]
) -> dict[int, tuple[int, set[int]]]:
"""
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
min_tokens: dict[int, tuple[int, set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens[index] = (
np.random.randint(num_output_tokens + 1,
2 * num_output_tokens),
set(
np.random.randint(0, vocab_size - 1)
for _ in range(np.random.randint(0, vocab_size))))
else:
min_tokens[index] = (np.random.randint(0,
num_output_tokens), set())
return min_tokens
def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
@@ -227,36 +187,6 @@ def _create_weighted_output_token_list(
return output_token_ids, sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
"""
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
torch.set_default_device(device)
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
batch_indices_for_min_token_penalty = np.random.randint(
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
min_tokens = _generate_min_token_penalties_and_stop_tokens(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
batch_indices_for_min_token_penalty)
sampling_metadata.min_tokens = min_tokens
sampler = Sampler()
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
if token_id in stop_token_ids:
assert logits[batch_idx][token_id] == -float("inf")
else:
assert logits[batch_idx][token_id] != -float("inf")
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
@@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
or non_penalized_token_id in output_tokens)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("min_p", [0.0, 0.1])
def test_sampler_min_p(device: str, batch_size: int, min_p: float):
"""
Tests that when min_p is applied, tokens with probability below
min_p * max_prob are masked with -inf.
"""
torch.set_default_device(device)
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
# Create one dominant token per batch
for i in range(batch_size):
fake_logits[i, 0] = 10.0 # High logit for first token
fake_logits[i, 1:] = 1e-2 # Others remain low
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
# Configure min_p parameters
sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)
sampler = Sampler()
logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
logits = logits.cpu()
for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
if token_id == 0:
# Dominant token should always be unmasked
assert logits[batch_idx][token_id] != -float("inf")
else:
if min_p > 0.0:
# Non-dominant tokens should be masked when min_p > 0
assert logits[batch_idx][token_id] == -float("inf")
else:
# No masking when min_p is 0
assert logits[batch_idx][token_id] != -float("inf")
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.logit_bias = _create_logit_bias(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
bias_value=bias_value,
)
sampler = Sampler()
logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
biased_index = min(batch_idx, VOCAB_SIZE - 1)
for token_id in range(VOCAB_SIZE):
if biased_index == token_id:
assert logits_for_req[token_id] == pytest.approx(bias_value +
1e-2)
else:
assert logits_for_req[token_id] == pytest.approx(1e-2)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])

View File

@@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from enum import Enum
from typing import Optional
from typing import NamedTuple, Optional
import regex as re
import torch
from vllm import CompletionOutput
from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
class BatchLogprobsComposition(Enum):
@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
logprobs = completion_output.logprobs
assert logprobs is not None
return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)])
def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
return fake_logits
def create_penalty_tensor(batch_size: int, penalty_value: float,
device: torch.device) -> torch.Tensor:
return torch.full((batch_size, ),
fill_value=penalty_value,
dtype=torch.float,
device=device)
def create_prompt_tokens_tensor(
prompt_token_ids: list[list[int]],
vocab_size: int,
device: torch.device,
) -> torch.Tensor:
return make_tensor_with_pad(
prompt_token_ids,
pad=vocab_size,
device=device,
dtype=torch.int64,
pin_memory=False,
)
class LogitsprocsTestFakes(NamedTuple):
"""Wraps fake data structures to support testing"""
logits: torch.Tensor
sampling_metadata: SamplingMetadata
def get_logitsprocs_by_cls(
self,
cls: type[LogitsProcessor],
) -> Iterator[LogitsProcessor]:
"""Yield logits processors of a specific class.
Args:
cls: :class:`LogitsProcessor` subclass
Returns:
Iterator over logits processors
"""
return (lp for lp in self.sampling_metadata.logitsprocs.all
if isinstance(lp, cls))
def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
"""Iterator over all logits processors."""
return self.sampling_metadata.logitsprocs.all
def fake_update_logitsprocs_state(
test_fakes: LogitsprocsTestFakes,
batch_update: BatchUpdate,
) -> None:
"""Imitate logits processors persistent batch state update
in engine core"""
for logitproc in test_fakes.get_logitsprocs():
logitproc.update_state(batch_update)
def fake_apply_logitsprocs(
test_fakes: LogitsprocsTestFakes,
slice_indices: list[int],
) -> torch.Tensor:
"""Imitate application of logits processors in engine core"""
logits = test_fakes.logits[torch.tensor(slice_indices,
dtype=torch.long)].clone()
for processor in test_fakes.get_logitsprocs():
logits = processor.apply(logits)
return logits
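# The two helpers above mirror the engine-core contract each step: every
# logits processor first receives update_state() with the (possibly None)
# batch update, then may transform the logits via apply(). A minimal sketch
# of that contract (illustrative, not part of this diff):
def run_fake_step(
    processors: Iterator[LogitsProcessor],
    batch_update: Optional[BatchUpdate],
    logits: torch.Tensor,
) -> torch.Tensor:
    processors = list(processors)
    # Phase 1: inform processors of batch membership changes (may be None).
    for lp in processors:
        lp.update_state(batch_update)
    # Phase 2: let each processor transform the logits in sequence.
    for lp in processors:
        logits = lp.apply(logits)
    return logits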

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Sequence
from typing import Optional
import numpy as np
@@ -12,6 +13,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -26,13 +28,18 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS = 64
def _compare_objs(obj1, obj2):
def _compare_objs(obj1,
obj2,
skip: Sequence = ("logitsprocs", "batch_update_builder")):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
a[0] for a in attrs
if not (a[0].startswith('__') and a[0].endswith('__'))
])
for attr_name in attr_names:
if attr_name in skip:
continue
a = getattr(obj1, attr_name)
b = getattr(obj2, attr_name)
@@ -58,13 +65,11 @@ def _compare_objs(obj1, obj2):
f" in {obj1} and {obj2}: {a} != {b}"
def _remove_requests(
input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
def _remove_requests(input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> set[str]:
"""
Remove some requests randomly from the batch and returns a tuple
of 1) set of request removed 2) indices of the requests removed
ordered in descending order
Randomly remove some requests from the batch and return the
set of removed request IDs
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
@@ -73,13 +78,11 @@ def _remove_requests(
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_indices_to_remove_list = list(req_indices_to_remove)
req_indices_to_remove_list.sort(reverse=True)
req_ids_to_remove: set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return req_ids_to_remove, req_indices_to_remove_list
return req_ids_to_remove
def _construct_expected_sampling_metadata(
@@ -100,7 +103,6 @@ def _construct_expected_sampling_metadata(
repetition_penalties = [1.0 for _ in range(num_reqs)]
top_k = [0 for _ in range(num_reqs)]
top_p = [0.0 for _ in range(num_reqs)]
min_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
@@ -123,7 +125,6 @@ def _construct_expected_sampling_metadata(
req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
min_p[index_in_input_batch] = req.sampling_params.min_p
temperature[index_in_input_batch] = req.sampling_params.temperature
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
@@ -145,8 +146,6 @@ def _construct_expected_sampling_metadata(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
min_p, dtype=torch.float, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
@@ -165,13 +164,12 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
min_tokens=min_tokens,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
logitsprocs=LogitsProcessorManager(),
)
@@ -225,6 +223,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@@ -238,21 +238,22 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
assigned_req_index = input_batch.add_request(req)
assert req_index == assigned_req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
# Remove some requests
req_ids_to_remove, req_indices_to_remove = _remove_requests(
input_batch, batch_size, reqs)
req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
# Compact the input batch
input_batch.condense(req_indices_to_remove)
input_batch.condense()
# Generate the sampling metadata
sampling_metadata = input_batch._make_sampling_metadata()
@@ -290,10 +291,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
if sampling_metadata.allowed_token_ids_mask:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
@@ -315,6 +314,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@@ -341,7 +342,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
assigned_req_index = input_batch.add_request(req)
assert assigned_req_index == req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
@@ -354,9 +356,10 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
for req_index in range(batch_size):
req = reordered_reqs[req_index]
ref_input_batch.add_request(req, req_index)
assigned_req_index = ref_input_batch.add_request(req)
assert assigned_req_index == req_index
input_batch.refresh_sampling_metadata()
ref_input_batch.refresh_sampling_metadata()
input_batch.refresh_metadata()
ref_input_batch.refresh_metadata()
_compare_objs(input_batch, ref_input_batch)

View File

@@ -0,0 +1,516 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import Optional, Union
import torch
from torch._prims_common import DeviceLikeType
from vllm import PoolingParams, SamplingParams
from vllm.logger import init_logger
logger = init_logger(__name__)
class MoveDirectionality(Enum):
# One-way i1->i2 req move within batch
UNIDIRECTIONAL = 0
# Two-way i1<->i2 req swap within batch
SWAP = 1
# (index, params, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int
@dataclasses.dataclass(frozen=True)
class BatchUpdate:
"""Persistent batch state change info for logitsprocs"""
batch_size: int # Current num reqs in batch
# Metadata for requests added to, removed from, and moved
# within the persistent batch.
#
# Note: each added request is represented as
# (index, params, output_tok_ids)
# Key assumption: output_tok_ids is a reference to the
# request's running output tokens list; in this way
# the logits processors always see the latest list of
# generated tokens
removed: Sequence[RemovedRequest]
moved: Sequence[MovedRequest]
added: Sequence[AddedRequest]
class BatchUpdateBuilder:
"""Helps track persistent batch state changes and build
a batch update data structure for logitsprocs
Assumptions:
* All information about requests removed from persistent batch
during a step is aggregated in self._removed through calls to
self.removed_append() at the beginning of a step. This must happen
before the first time that self.removed, self.pop_removed()
or self.peek_removed() are invoked in a given step
* After the first time that self.removed, self.pop_removed()
or self.peek_removed() are read in a step, no new removals
are registered using self.removed_append()
* Elements of self._removed are never directly modified, added or
removed (i.e. modification is only via self.removed_append() and
self.pop_removed())
Guarantees under above assumptions:
* self.removed is always sorted in descending order
* self.pop_removed() and self.peek_removed() both return
the lowest removed request index in the current step
"""
_removed: list[RemovedRequest]
_is_removed_sorted: bool
moved: list[MovedRequest]
added: list[AddedRequest]
def __init__(
self,
removed: Optional[list[RemovedRequest]] = None,
moved: Optional[list[MovedRequest]] = None,
added: Optional[list[AddedRequest]] = None,
) -> None:
self._removed = removed or []
self.moved = moved or []
self.added = added or []
self._is_removed_sorted = False
def _ensure_removed_sorted(self) -> None:
"""Sort removed request indices in
descending order.
Idempotent after first call in a
given step, until reset.
"""
if not self._is_removed_sorted:
self._removed.sort(reverse=True)
self._is_removed_sorted = True
@property
def removed(self) -> list[RemovedRequest]:
"""Removed request indices sorted in
descending order"""
self._ensure_removed_sorted()
return self._removed
def removed_append(self, index: int) -> None:
"""Register the removal of a request from
the persistent batch.
Must not be called after the first time
self.removed, self.pop_removed() or
self.peek_removed() are invoked.
Args:
index: request index
"""
if self._is_removed_sorted:
raise RuntimeError("Cannot register new removed request after"
" self.removed has been read.")
self._removed.append(index)
def has_removed(self) -> bool:
return bool(self._removed)
def peek_removed(self) -> Optional[int]:
"""Return lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed[-1]
return None
def pop_removed(self) -> Optional[int]:
"""Pop lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed.pop()
return None
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
"""Generate a logitsprocs batch update data structure
and reset internal batch update builder state.
Args:
batch_size: current persistent batch size
Returns:
Frozen logitsprocs batch update instance; `None` if no updates
"""
# Reset removal-sorting logic
self._is_removed_sorted = False
if not any((self._removed, self.moved, self.added)):
# No update; short-circuit
return None
# Build batch state update
batch_update = BatchUpdate(
batch_size=batch_size,
removed=self._removed,
moved=self.moved,
added=self.added,
)
# Reset removed/moved/added update lists
self._removed = []
self.moved = []
self.added = []
return batch_update
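# Usage sketch (illustrative, not part of this diff) of the contract described
# in the class docstring: removals are registered first, then read back
# lowest-index-first via peek_removed()/pop_removed(), and get_and_reset()
# freezes the accumulated state into a BatchUpdate.
def _example_builder_usage() -> None:
    builder = BatchUpdateBuilder()
    for index in (7, 2, 5):
        builder.removed_append(index)
    assert builder.peek_removed() == 2  # lowest removed index
    assert builder.pop_removed() == 2
    builder.added.append((2, SamplingParams(), []))
    update = builder.get_and_reset(batch_size=8)
    assert update is not None and list(update.removed) == [7, 5]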
class LogitsProcessor(ABC):
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def is_argmax_invariant(self) -> bool:
"""True if logits processor has no impact on the
argmax computation in greedy sampling.
NOTE: may or may not have the same value for all
instances of a given LogitsProcessor subclass,
depending on subclass implementation.
TODO(andy): won't be utilized until logits
processors are user-extensible
"""
raise NotImplementedError
@abstractmethod
def update_state(
self,
batch_update: Optional[BatchUpdate],
) -> None:
"""Called when there are new output tokens, prior
to each forward pass.
Args:
batch_update: non-None iff there have been
changes to the batch makeup.
"""
raise NotImplementedError
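# As a reference point for the programming model (illustrative, not part of
# this diff): a minimal no-op subclass showing the three methods every V1
# logits processor implements.
class _NoOpLogitsProcessor(LogitsProcessor):

    def is_argmax_invariant(self) -> bool:
        # Leaving logits untouched cannot change the greedy argmax.
        return True

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        # No per-request state to track.
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        return logits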
@dataclass
class LogitsProcessorManager:
"""Encapsulates initialized logitsproc objects."""
argmax_invariant: list[LogitsProcessor] = field(
default_factory=list) # argmax-invariant logitsprocs
non_argmax_invariant: list[LogitsProcessor] = field(
default_factory=list) # non-argmax-invariant logitsprocs
@property
def all(self) -> Iterator[LogitsProcessor]:
"""Iterator over all logits processors."""
return chain(self.argmax_invariant, self.non_argmax_invariant)
###### ----- Built-in LogitsProcessor impls below here
class MinPLogitsProcessor(LogitsProcessor):
def __init__(self, max_num_reqs: int, pin_memory: bool,
device: DeviceLikeType):
super().__init__()
self.min_p_count: int = 0
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
# Pre-allocated device tensor
self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
# Current slice of the device tensor
self.min_p: torch.Tensor = self.min_p_device[:0]
def is_argmax_invariant(self) -> bool:
"""Min-p never impacts greedy sampling"""
return True
def get_min_p_by_index(self, index: int) -> float:
return float(self.min_p_cpu[index])
def update_state(self, batch_update: Optional[BatchUpdate]):
if not batch_update:
return
needs_update = False
# Process added requests.
for index, params, _ in batch_update.added:
min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
if self.min_p_cpu[index] != min_p:
needs_update = True
self.min_p_cpu[index] = min_p
if min_p:
self.min_p_count += 1
if self.min_p_count:
# Process removed requests.
needs_update |= bool(batch_update.removed)
for index in batch_update.removed:
if self.min_p_cpu[index]:
self.min_p_count -= 1
# Process moved requests, unidirectional (a->b) and swap (a<->b)
for adx, bdx, direct in batch_update.moved:
change = (min_p_a :=
self.min_p_cpu[adx]) != (min_p_b :=
self.min_p_cpu[bdx])
needs_update |= change
if change:
self.min_p_cpu[bdx] = min_p_a
if direct == MoveDirectionality.SWAP:
self.min_p_cpu[adx] = min_p_b
# Update tensors if needed.
size = batch_update.batch_size
if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
self.min_p = self.min_p_device[:size]
self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
self.min_p.unsqueeze_(1)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.min_p_count:
return logits
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values,
dim=-1,
keepdim=True)
# Adjust min_p
adjusted_min_p = max_probabilities.mul_(self.min_p)
# Identify tokens below the adjusted threshold (to be masked)
invalid_token_mask = probability_values < adjusted_min_p
# Apply mask using boolean indexing
logits[invalid_token_mask] = -float('inf')
return logits
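# Numerical sketch (illustrative, not part of this diff) of the thresholding
# performed by apply() above: with min_p=0.5, any token whose probability is
# below half of the per-row maximum probability is masked to -inf.
def _example_min_p_threshold() -> None:
    logits = torch.tensor([[2.0, 1.0, -3.0]])
    probs = torch.softmax(logits, dim=-1)  # ~[0.73, 0.27, 0.005]
    threshold = probs.amax(dim=-1, keepdim=True) * 0.5  # ~0.36
    masked = logits.masked_fill(probs < threshold, -float("inf"))
    assert masked[0, 0] == 2.0 and torch.isinf(masked[0, 1])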
class LogitBiasLogitsProcessor(LogitsProcessor):
def __init__(self, pin_memory: bool, device: torch.device):
super().__init__()
self.biases: dict[int, dict[int, float]] = {}
self.device = device
self.pin_memory = pin_memory
self.bias_tensor: torch.Tensor = torch.tensor(())
self.logits_slice = (self._device_tensor([], torch.int32),
self._device_tensor([], torch.int32))
def is_argmax_invariant(self) -> bool:
"""Logit bias can rebalance token probabilities and change the
outcome of argmax in greedy sampling."""
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
if not batch_update:
return
# Process added requests.
needs_update = bool(batch_update.added)
for index, params, _ in batch_update.added:
if isinstance(params, SamplingParams) and (lb :=
params.logit_bias):
self.biases[index] = lb
else:
self.biases.pop(index, None)
if self.biases:
# Process removed requests.
for index in batch_update.removed:
if self.biases.pop(index, None):
needs_update = True
# Process moved requests, unidirectional (a->b) and swap (a<->b)
for a_index, b_index, direct in batch_update.moved:
if direct == MoveDirectionality.UNIDIRECTIONAL:
if (a_entry := self.biases.pop(a_index, None)) is None:
if self.biases.pop(b_index, None) is not None:
needs_update = True
else:
self.biases[b_index] = a_entry
needs_update = True
else:
a_entry = self.biases.pop(a_index, None)
if (b_entry := self.biases.pop(b_index, None)) is not None:
self.biases[a_index] = b_entry
needs_update = True
if a_entry is not None:
self.biases[b_index] = a_entry
needs_update = True
# Update tensors if needed.
if needs_update:
reqs, tok_ids, biases = [], [], []
for req, lb in self.biases.items():
reqs.extend([req] * len(lb))
tok_ids.extend(lb.keys())
biases.extend(lb.values())
self.bias_tensor = self._device_tensor(biases, torch.float32)
self.logits_slice = (self._device_tensor(reqs, torch.int32),
self._device_tensor(tok_ids, torch.int32))
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
return (torch.tensor(data,
device="cpu",
dtype=dtype,
pin_memory=self.pin_memory).to(device=self.device,
non_blocking=True))
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.biases:
logits[self.logits_slice] += self.bias_tensor
return logits
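# Sketch (illustrative values, not part of this diff) of the sparse update
# that apply() performs: each (request row, token column) pair accumulated in
# update_state() receives its bias via one advanced-indexing add.
def _example_bias_update() -> None:
    logits = torch.zeros(2, 5)
    rows = torch.tensor([0, 0, 1])  # request indices
    cols = torch.tensor([3, 4, 1])  # biased token ids
    bias = torch.tensor([0.25, -0.5, 1.5])
    logits[rows, cols] += bias
    assert logits[0, 3] == 0.25 and logits[1, 1] == 1.5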
class MinTokensLogitsProcessor(LogitsProcessor):
def __init__(self, pin_memory: bool, device: torch.device):
# index -> (min_toks, output_token_ids, stop_token_ids)
super().__init__()
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
self.device = device
self.pin_memory = pin_memory
# (req_idx_tensor, stop_tok_id_tensor)
self.logits_slice: tuple[torch.Tensor,
torch.Tensor] = (self._device_tensor(
[], torch.int32),
self._device_tensor(
[], torch.int32))
def is_argmax_invariant(self) -> bool:
"""By censoring stop tokens, min-tokens can change the outcome
of the argmax operation in greedy sampling."""
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
needs_update = False
if batch_update:
# Process added requests.
needs_update |= bool(batch_update.added)
for index, params, output_tok_ids in batch_update.added:
if (isinstance(params, SamplingParams)
and (min_tokens := params.min_tokens)
and len(output_tok_ids) < min_tokens):
# Replace request metadata at batch index
self.min_toks[index] = (min_tokens, output_tok_ids,
params.all_stop_token_ids)
else:
# Drop request metadata at batch index
self.min_toks.pop(index, None)
if self.min_toks:
# Process removed requests.
for index in batch_update.removed:
if self.min_toks.pop(index, None):
needs_update = True
# Process moved requests, unidirectional (a->b) and
# swapped (a<->b)
for a_index, b_index, direct in batch_update.moved:
if direct == MoveDirectionality.UNIDIRECTIONAL:
if (a_entry := self.min_toks.pop(a_index,
None)) is None:
if self.min_toks.pop(b_index, None) is not None:
needs_update = True
else:
self.min_toks[b_index] = a_entry
needs_update = True
else:
a_entry = self.min_toks.pop(a_index, None)
if (b_entry := self.min_toks.pop(b_index,
None)) is not None:
self.min_toks[a_index] = b_entry
needs_update = True
if a_entry is not None:
self.min_toks[b_index] = a_entry
needs_update = True
if self.min_toks:
# Check for any requests that have attained their min tokens.
to_remove = tuple(index for index, (min_toks, out_tok_ids,
_) in self.min_toks.items()
if len(out_tok_ids) >= min_toks)
if to_remove:
needs_update = True
for index in to_remove:
del self.min_toks[index]
# Update tensors if needed.
if needs_update:
reqs: list[int] = []
tok_ids: list[int] = []
for req, (_, _, stop_tok_ids) in self.min_toks.items():
reqs.extend([req] * len(stop_tok_ids))
tok_ids.extend(stop_tok_ids)
self.logits_slice = (self._device_tensor(reqs, torch.int32),
self._device_tensor(tok_ids, torch.int32))
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
return (torch.tensor(data,
device="cpu",
dtype=dtype,
pin_memory=self.pin_memory).to(device=self.device,
non_blocking=True))
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.min_toks:
# Inhibit stop tokens for requests which have not reached min length
logits[self.logits_slice] = -float("inf")
return logits
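# Sketch (illustrative, not part of this diff) of the reference semantics the
# processor relies on: the added tuple carries the request's live output-token
# list, so tokens appended elsewhere become visible to update_state() without
# any extra bookkeeping.
def _example_min_tokens_lifecycle() -> None:
    out_tokens: list[int] = []
    params = SamplingParams(min_tokens=2, stop_token_ids=[7])
    proc = MinTokensLogitsProcessor(pin_memory=False,
                                    device=torch.device("cpu"))
    proc.update_state(
        BatchUpdate(batch_size=1, removed=[], moved=[],
                    added=[(0, params, out_tokens)]))
    assert 0 in proc.min_toks  # stop token 7 will be masked for request 0
    out_tokens.extend([11, 12])  # two tokens generated by the request
    proc.update_state(None)  # no batch change, but min_tokens now reached
    assert 0 not in proc.min_toks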
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
device: torch.device) -> LogitsProcessorManager:
"""Construct 'builtin' vLLM logitsprocs which the engine
loads by default.
Args:
pin_memory_available: whether pinned memory is available
for use by logitsprocs
max_num_reqs: ceiling on request count in persistent batch
device: inference device
Returns:
Data structure encapsulating loaded logitsprocs
"""
min_tokens_logitproc = MinTokensLogitsProcessor(
pin_memory=pin_memory_available, device=device)
logit_bias_logitproc = LogitBiasLogitsProcessor(
pin_memory=pin_memory_available, device=device)
min_p_logitproc = MinPLogitsProcessor(
pin_memory=pin_memory_available,
device=device,
# +1 for temporary swap space
max_num_reqs=max_num_reqs + 1)
return LogitsProcessorManager(
non_argmax_invariant=[
min_tokens_logitproc,
logit_bias_logitproc,
],
argmax_invariant=[min_p_logitproc],
)
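# Quick sketch (illustrative, not part of this diff) of what the returned
# manager contains: three built-in processors, with min-p grouped as
# argmax-invariant and min-tokens / logit-bias as non-argmax-invariant.
def _example_builtin_logitsprocs() -> None:
    manager = init_builtin_logitsprocs(pin_memory_available=False,
                                       max_num_reqs=8,
                                       device=torch.device("cpu"))
    assert len(list(manager.all)) == 3
    assert len(manager.argmax_invariant) == 1  # MinPLogitsProcessor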

View File

@@ -6,6 +6,8 @@ from typing import Optional
import torch
from vllm.v1.sample.logits_processor import LogitsProcessorManager
@dataclass
class SamplingMetadata:
@@ -16,7 +18,6 @@ class SamplingMetadata:
top_p: Optional[torch.Tensor]
top_k: Optional[torch.Tensor]
min_p: Optional[torch.Tensor]
generators: dict[int, torch.Generator]
@@ -31,14 +32,12 @@ class SamplingMetadata:
output_token_ids: list[list[int]]
# req_index -> (min_tokens, stop_token_ids)
min_tokens: dict[int, tuple[int, set[int]]]
logit_bias: list[Optional[dict[int, float]]]
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]
# req_index -> bad_words_token_ids
bad_words_token_ids: dict[int, list[list[int]]]
# Loaded logits processors
logitsprocs: LogitsProcessorManager

View File

@@ -7,22 +7,6 @@ from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
def apply_min_token_penalties(
logits: torch.Tensor, output_token_ids: list[list[int]],
min_tokens: dict[int, tuple[int, set[int]]]) -> None:
"""
Applies minimum token penalty by setting the logits of the stop tokens
to -inf.
"""
min_tokens_logits_to_penalize: list[tuple[int, int]] = []
for index, (min_token, stop_token_ids) in min_tokens.items():
if len(output_token_ids[index]) < min_token:
for stop_token_id in stop_token_ids:
min_tokens_logits_to_penalize.append((index, stop_token_id))
if min_tokens_logits_to_penalize:
logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
def apply_all_penalties(
logits: torch.Tensor,
prompt_token_ids: torch.Tensor,

View File

@@ -5,12 +5,11 @@
import torch
import torch.nn as nn
from vllm.utils import async_tensor_h2d, is_pin_memory_available
from vllm.utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
_SAMPLING_EPS = 1e-5
@@ -44,8 +43,11 @@ class Sampler(nn.Module):
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
# Apply bad words exclusion.
logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply logits processors which can impact greedy sampling
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
logits = processor.apply(logits)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
# Sample the next token.
@@ -110,9 +112,10 @@ class Sampler(nn.Module):
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Apply min_p.
if sampling_metadata.min_p is not None:
logits = self.apply_min_p(logits, sampling_metadata.min_p)
# Apply logits processors that only apply to random sampling
# (argmax invariant)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
# Apply top_k and/or top_p.
random_sampled = self.topk_topp_sampler(
@@ -187,10 +190,6 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.min_tokens:
apply_min_token_penalties(logits,
sampling_metadata.output_token_ids,
sampling_metadata.min_tokens)
if not sampling_metadata.no_penalties:
assert sampling_metadata.prompt_token_ids is not None
logits = apply_all_penalties(
@ -203,65 +202,6 @@ class Sampler(nn.Module):
)
return logits
def apply_min_p(
self,
logits: torch.Tensor,
min_p: torch.Tensor,
) -> torch.Tensor:
"""
Filters logits using adaptive probability thresholding.
"""
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values,
dim=-1,
keepdim=True)
# Reshape min_p for broadcasting
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
# Identify valid tokens using threshold comparison
valid_token_mask = probability_values >= adjusted_min_p
# Apply mask using boolean indexing
logits[~valid_token_mask] = -float('inf')
return logits
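
The min-p rule above is argmax-invariant: the highest-probability token always satisfies probs >= min_p * max_prob for min_p <= 1, so skipping it on the greedy path cannot change the selected token. A quick standalone check of that property, using the same thresholding rule (hypothetical helper name):

import torch

def min_p_filter(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)
    return logits.masked_fill(probs < threshold, -float("inf"))

logits = torch.randn(4, 32)
filtered = min_p_filter(logits.clone(), torch.full((4,), 0.2))
# The per-row argmax is unchanged, which is why min-p can live in the
# argmax-invariant group and be skipped for greedy requests.
assert torch.equal(logits.argmax(dim=-1), filtered.argmax(dim=-1))
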
def apply_logits_bias(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
rows: list[int] = []
cols: list[int] = []
vals: list[float] = []
# Get vocabulary size from logits
vocab_size = logits.shape[-1]
for i, logit_bias in enumerate(sampling_metadata.logit_bias):
if logit_bias:
for token_id, bias in logit_bias.items():
# Check token_id bounds to ensure within vocabulary
if token_id < 0 or token_id >= vocab_size:
raise ValueError(
f"token_id {token_id} in logit_bias contains "
f"out-of-vocab token id. Vocabulary size: "
f"{vocab_size}")
rows.append(i)
cols.append(token_id)
vals.append(bias)
if rows:
indices = async_tensor_h2d([rows, cols], torch.int64,
logits.device, self.pin_memory)
values = async_tensor_h2d(vals, torch.float, logits.device,
self.pin_memory)
logits.index_put_(tuple(indices), values=values, accumulate=True)
return logits
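
The per-request bias handling removed here is taken over by LogitBiasLogitsProcessor; for reference, a tiny standalone demonstration of the sparse index_put_ idiom the removed code relies on:

import torch

# Two requests, vocab of 8: request 0 biases tokens 3 and 5, request 1 token 2.
logits = torch.zeros(2, 8)
rows, cols, vals = [0, 0, 1], [3, 5, 2], [1.5, -2.0, 4.0]
logits.index_put_(
    (torch.tensor(rows), torch.tensor(cols)),
    torch.tensor(vals),
    accumulate=True)  # add to the existing logits rather than overwrite
assert logits[0, 3] == 1.5 and logits[1, 2] == 4.0
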
def apply_allowed_token_ids(
self,
logits: torch.Tensor,

View File

@ -1,23 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu_input_batch import InputBatch
_SAMPLING_EPS = 1e-5
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
if req_id in input_batch.min_p_reqs:
# Spec decode doesn't support min_p sampling.
return False
elif (req_id in input_batch.frequency_penalties_reqs
or req_id in input_batch.presence_penalties_reqs
or req_id in input_batch.repetition_penalties_reqs):
# Spec decode doesn't support penalties.
return False
elif req_id in input_batch.num_logprobs:
# Spec decode doesn't support logprobs.
return False
return True
def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
"""True if request is incompatible with speculative decoding"""
return (sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
or sampling_params.repetition_penalty != 1.0
or sampling_params.min_p > _SAMPLING_EPS
or sampling_params.logprobs is not None)
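
A quick usage sketch of the new predicate (assuming it remains importable from vllm.v1.spec_decode.utils, as the input-batch change below imports it):

from vllm.sampling_params import SamplingParams
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported

# Plain greedy/temperature sampling stays eligible for speculative decoding;
# penalties, min_p, or logprobs push the request onto the normal path.
assert not is_spec_decode_unsupported(SamplingParams())
assert is_spec_decode_unsupported(SamplingParams(presence_penalty=0.5))
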
@triton.jit

View File

@ -15,12 +15,14 @@ from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
MoveDirectionality,
init_builtin_logitsprocs)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
_SAMPLING_EPS = 1e-5
@dataclass
class CachedRequestState:
@ -67,8 +69,10 @@ class InputBatch:
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
is_spec_decode: bool = False,
logits_processing_needs_token_ids: bool = False,
):
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
@ -146,15 +150,8 @@ class InputBatch:
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: set[str] = set()
self.min_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
self.min_p_reqs: set[str] = set()
# IDs of requests which do not support spec decoding
self.spec_decode_unsupported_reqs: set[str] = set()
# Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ),
@ -194,9 +191,6 @@ class InputBatch:
self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: set[str] = set()
# req_index -> (min_tokens, stop_token_ids)
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
# lora related
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
dtype=np.int32)
@ -216,8 +210,20 @@ class InputBatch:
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
self.logit_bias: list[Optional[dict[int,
float]]] = [None] * max_num_reqs
# Internal representation of per-step batch state changes, used for
# reordering persistent batch and generating logitsprocs batch state
# updates. Should reset each step.
self.batch_update_builder = BatchUpdateBuilder()
# Define logits processors.
# TODO(andy): logits processor list should be extensible via engine
# constructor argument; for now the list is fixed.
self.logitsprocs = init_builtin_logitsprocs(
pin_memory_available=pin_memory,
max_num_reqs=max_num_reqs + 1,
device=device)
# TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
# the value is False. Since we use masked_fill_ to set -inf.
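
The builder's role becomes clearer from how it is used below (pop_removed/peek_removed for slot reuse, added/moved records, one get_and_reset per step). A simplified hypothetical stand-in, reconstructed only from that usage:

from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class SimpleBatchUpdateBuilder:
    removed: list[int] = field(default_factory=list)  # freed slot indices
    added: list[tuple[int, Any, list[int]]] = field(default_factory=list)
    moved: list[tuple[int, int, Any]] = field(default_factory=list)

    def removed_append(self, index: int) -> None:
        # Keep descending order so the *smallest* free slot is reused first.
        self.removed.append(index)
        self.removed.sort(reverse=True)

    def peek_removed(self) -> Optional[int]:
        return self.removed[-1] if self.removed else None

    def pop_removed(self) -> Optional[int]:
        return self.removed.pop() if self.removed else None

    def get_and_reset(self, batch_size: int) -> Optional[tuple]:
        # Hand one consolidated update to the logits processors, then start
        # accumulating fresh state for the next step.
        update = ((batch_size, self.removed, self.added, self.moved)
                  if (self.removed or self.added or self.moved) else None)
        self.removed, self.added, self.moved = [], [], []
        return update
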
@ -240,14 +246,28 @@ class InputBatch:
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def _get_next_add_index(self) -> int:
if (req_index := self.batch_update_builder.pop_removed()) is not None:
# Fill the empty index.
return req_index
# Append to end
return self.num_reqs
def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations"""
req_index = self._get_next_add_index()
assert req_index < self.max_num_reqs
params = (request.sampling_params
if request.sampling_params else request.pooling_params)
self.batch_update_builder.added.append(
(req_index, params, request.output_token_ids))
return req_index
def add_request(
self,
request: "CachedRequestState",
req_index: Optional[int] = None,
) -> None:
if req_index is None:
req_index = self.num_reqs
assert req_index < self.max_num_reqs
) -> int:
req_index = self._register_add_request(request)
req_id = request.req_id
if req_index == len(self._req_ids):
@ -278,6 +298,9 @@ class InputBatch:
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
if (self.is_spec_decode
and is_spec_decode_unsupported(sampling_params)):
self.spec_decode_unsupported_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.GREEDY:
# Avoid later division by zero.
self.temperature_cpu[req_index] = -1.0
@ -295,11 +318,8 @@ class InputBatch:
else:
top_k = self.vocab_size
self.top_k_cpu[req_index] = top_k
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
@ -310,10 +330,6 @@ class InputBatch:
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.min_tokens:
self.min_tokens[req_index] = (
sampling_params.min_tokens,
sampling_params.all_stop_token_ids)
# NOTE(woosuk): self.generators should not include the requests that
# do not have their own generator.
@ -325,8 +341,6 @@ class InputBatch:
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
@ -368,12 +382,22 @@ class InputBatch:
# No LoRA
self.request_lora_mapping[req_index] = 0
return req_index
def remove_request(self, req_id: str) -> Optional[int]:
"""This method must always be followed by a call to condense()."""
"""This method must always be followed by a call to condense().
Args:
req_id: request to remove
Returns:
Removed request index, or `None` if `req_id` not recognized
"""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
@ -381,8 +405,7 @@ class InputBatch:
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.min_p_reqs.discard(req_id)
self.min_tokens.pop(req_index, None)
self.spec_decode_unsupported_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
@ -400,7 +423,6 @@ class InputBatch:
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
self.logit_bias[req_index] = None
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
@ -410,6 +432,8 @@ class InputBatch:
return req_index
def swap_states(self, i1: int, i2: int) -> None:
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
@ -439,8 +463,6 @@ class InputBatch:
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
self.min_p_cpu[i2], self.min_p_cpu[i1]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
@ -452,13 +474,10 @@ class InputBatch:
self.token_ids_cpu[i2, ...] = tmp
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.min_tokens, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
self.logit_bias[i1], self.logit_bias[i2] =\
self.logit_bias[i2], self.logit_bias[i1]
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[i1], \
@ -467,12 +486,23 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[i1]
self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: list[int]) -> None:
"""Move non-empty requests down into lower, empty indices.
def condense(self) -> None:
"""Slide non-empty requests down into lower, empty indices.
Any consecutive empty indices at the very end of the list are not
filled.
Args:
empty_req_indices: empty batch indices, sorted descending.
empty_req_indices: empty indices which may be filled.
Returns:
swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation
"""
if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
return
num_reqs = self.num_reqs
if num_reqs == 0:
# The batched states are empty.
@ -489,11 +519,17 @@ class InputBatch:
last_req_index -= 1
# Find the smallest empty index.
empty_index = empty_req_indices.pop()
empty_index = self.batch_update_builder.peek_removed()
assert empty_index is not None
if empty_index >= last_req_index:
break
# Swap the states.
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
req_id = self._req_ids[last_req_index]
output_token_ids = self.req_output_token_ids[last_req_index]
assert req_id is not None
@ -524,20 +560,14 @@ class InputBatch:
empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[
empty_index] = self.repetition_penalties_cpu[last_req_index]
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
min_token = self.min_tokens.pop(last_req_index, None)
if min_token is not None:
self.min_tokens[empty_index] = min_token
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
# TODO convert these to LogitsProcessors
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
@ -547,6 +577,7 @@ class InputBatch:
last_req_index, None)
if bad_words_token_ids is not None:
self.bad_words_token_ids[empty_index] = bad_words_token_ids
# Decrement last_req_index since it is now empty.
last_req_index -= 1
@ -554,8 +585,17 @@ class InputBatch:
del self._req_ids[self.num_reqs:]
del self.req_output_token_ids[self.num_reqs:]
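
An abstract illustration of the condensation strategy (hypothetical helper over plain lists, not InputBatch itself): the highest-index live entries slide into the smallest free slots, and each move is recorded so the logits processors can mirror it.

def condense_slots(slots: list, free: list[int]) -> list[tuple[int, int]]:
    free = sorted(free, reverse=True)  # smallest free index popped first
    moves: list[tuple[int, int]] = []
    last = len(slots) - 1
    while free:
        while last >= 0 and slots[last] is None:
            last -= 1
        dst = free[-1]
        if dst >= last:
            break  # remaining gaps sit at the tail; leave them unfilled
        free.pop()
        slots[dst], slots[last] = slots[last], None
        moves.append((last, dst))  # (from, to), a unidirectional move
        last -= 1
    return moves

slots = ["A", None, "C", None, "E"]
assert condense_slots(slots, [1, 3]) == [(4, 1)]
assert slots == ["A", "E", "C", None, None]
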
def refresh_sampling_metadata(self):
self.sampling_metadata = self._make_sampling_metadata()
def refresh_metadata(self):
"""Apply batch updates, reset input batch at end of step
* Apply batch add/remove/permute to logits procs' states
* If batch state is modified, update sampling metadata
"""
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
self.sampling_metadata = self._make_sampling_metadata()
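
Taken together, refresh_metadata() defines the per-step contract a logits processor sees. A minimal sketch of a processor written against that contract (update_state/apply match the calls above; the is_argmax_invariant name is assumed from the argmax-invariant vs non-argmax-invariant grouping):

import torch

class NoopLogitsProcessor:
    """Minimal illustrative processor; real ones subclass LogitsProcessor."""

    def is_argmax_invariant(self) -> bool:
        # An identity transform never changes the argmax, so the sampler may
        # skip it for purely greedy batches.
        return True

    def update_state(self, batch_update) -> None:
        # Called once per step with the consolidated adds/removes/moves,
        # or with a falsy update when the persistent batch did not change.
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Called on the (num_reqs, vocab_size) logits before sampling.
        return logits
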
def _make_sampling_metadata(self) -> SamplingMetadata:
num_reqs = self.num_reqs
@ -568,8 +608,6 @@ class InputBatch:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
if not self.no_min_p:
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
@ -607,7 +645,6 @@ class InputBatch:
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
min_p=None if self.no_min_p else self.min_p[:num_reqs],
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
@ -615,11 +652,10 @@ class InputBatch:
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
min_tokens=self.min_tokens,
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:num_reqs],
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
)
@property
@ -702,10 +738,6 @@ class InputBatch:
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_min_p(self) -> bool:
return len(self.min_p_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0

View File

@ -63,12 +63,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
@ -212,6 +212,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
)
self.use_cuda_graph = (
@ -316,7 +317,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
self.shared_kv_cache_layers: dict[str, str] = {}
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
@ -325,21 +326,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Args:
scheduler_output: The scheduler output.
Returns:
True if the batch was reordered, False otherwise.
"""
batch_reordered = self.attn_metadata_builders[0].reorder_batch(
self.input_batch, scheduler_output)
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
scheduler_output)
# For models with multiple KV cache groups, the groups should agree on
# the same order of requests. We ensure this by only allowing the first
# group to reorder the batch and asserting that all other groups do not
# reorder the batch.
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
assert not self.attn_metadata_builders[i].reorder_batch(
batch_reordered = self.attn_metadata_builders[i].reorder_batch(
self.input_batch, scheduler_output)
return batch_reordered
assert not batch_reordered
# Note: used for model runner override.
def _init_device_properties(self) -> None:
@ -372,11 +370,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# then resubmitted with the same ID. In this case, we treat them as two
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
removed_req_indices: list[int] = []
for req_id in scheduler_output.finished_req_ids:
req_index = self.input_batch.remove_request(req_id)
if req_index is not None:
removed_req_indices.append(req_index)
self.input_batch.remove_request(req_id)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
@ -399,9 +394,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids:
req_index = self.input_batch.remove_request(req_id)
assert req_index is not None
removed_req_indices.append(req_index)
self.input_batch.remove_request(req_id)
req_ids_to_add: list[str] = []
# Add new requests to the cached states.
@ -545,31 +538,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self.input_batch.num_tokens[req_index] = end_token_index
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices.sort(reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
self.input_batch.add_request(req_state)
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
batch_reordered = self._may_reorder_batch(scheduler_output)
if batch_changed or batch_reordered:
self.input_batch.refresh_sampling_metadata()
# Condense the batched states if there are gaps left by removed requests
self.input_batch.condense()
# Allow attention backend to reorder the batch, potentially
self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
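
For readability, the reshaped update flow above as an illustrative outline only (hypothetical wrapper; the real method interleaves this with KV-cache and token bookkeeping):

def apply_batch_updates(input_batch, finished_req_ids, req_states_to_add,
                        reorder_batch) -> None:
    for req_id in finished_req_ids:
        input_batch.remove_request(req_id)  # records the freed slot
    for req_state in req_states_to_add:
        input_batch.add_request(req_state)  # reuses the smallest freed slot
    input_batch.condense()                  # close any remaining gaps
    reorder_batch(input_batch)              # attention backend may permute rows
    input_batch.refresh_metadata()          # push pending updates to logitsprocs
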
def _get_cumsum_and_arange(
self,
@ -1296,7 +1276,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
@ -1765,7 +1744,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = self.input_batch.req_ids[i]
if not is_spec_decode_supported(req_id, self.input_batch):
if req_id in self.input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue
@ -2121,7 +2100,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
all_random=False,
top_p=dummy_tensors(0.9),
top_k=dummy_tensors(logits.size(1) - 1),
min_p=None,
generators={},
max_num_logprobs=None,
no_penalties=True,
@ -2130,10 +2108,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
presence_penalties=dummy_tensors(0.1),
repetition_penalties=dummy_tensors(0.1),
output_token_ids=[[] for _ in range(num_reqs)],
min_tokens={},
logit_bias=[None for _ in range(num_reqs)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
)
try:
sampler_output = self.sampler(logits=logits,
@ -2425,6 +2402,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
)
def _allocate_kv_cache_tensors(