Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 10:40:44 +08:00)

[V1] LogitsProcessor programming model (#16728)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

This commit is contained in:
parent c1909e7e8c
commit 48fb076cbc
626
tests/v1/sample/test_logits_processors.py
Normal file
@@ -0,0 +1,626 @@
# SPDX-License-Identifier: Apache-2.0

import random
from collections.abc import Callable
from typing import NamedTuple, Optional, Union

import numpy as np
import pytest
import torch

from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
                                   create_penalty_tensor,
                                   create_prompt_tokens_tensor,
                                   fake_apply_logitsprocs,
                                   fake_update_logitsprocs_state)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
# yapf: disable
from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
                                             LogitBiasLogitsProcessor,
                                             LogitsProcessor,
                                             MinPLogitsProcessor,
                                             MinTokensLogitsProcessor,
                                             MoveDirectionality,
                                             init_builtin_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata

PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [
    f"{current_platform.device_type}:{i}"
    for i in range(1 if current_platform.device_count() == 1 else 2)
]
MAX_NUM_PROMPT_TOKENS = 64
MIN_TOKENS_LEN_THRESHOLD = 5
REQS_PER_LOGITPROC = 50
STR_NO_LOGITPROC = "none"

# LogitsProcessor subclass or "none"
LogitprocType = Union[type[LogitsProcessor], str]


class LogitsProcsRequestParams:
    """Encapsulates key params for a single request in a batch.

    Params can be customized based on the enabled logitproc
    """
    workload_index: int
    logitproc_type: LogitprocType  # Logitproc enabled, specified by str id
    out_tokens: list[int]  # Output tokens required for min tokens test
    params: SamplingParams  # Settings customized for logitproc

    def __init__(self, workload_index: int, logitproc_type: LogitprocType):
        self.workload_index = workload_index
        self.logitproc_type = logitproc_type
        # Number of output tokens is randomly 0x, 1x or 2x the min-tokens
        # threshold which will be used in testing. Output token values
        # don't matter *for these tests* so use 0 as a dummy value
        self.out_tokens = ([0] *
                           (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
        self.params = _sampling_params_from_logitproc(logitproc_type)

    def __str__(self):
        """For debugging"""
        summ = ', '.join(f'{k}={v}' for k, v in vars(self).items())
        return f"LogitsProcsRequestParams({summ})"


def _generate_fake_sampling_metadata(
|
||||
num_output_tokens: int,
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
) -> SamplingMetadata:
|
||||
"""Generate fake sampling metadata with fake logitsprocs"""
|
||||
output_token_ids: list[list[int]] = []
|
||||
prompt_token_ids: list[list[int]] = []
|
||||
for _ in range(batch_size):
|
||||
output_token_ids.append(
|
||||
np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
|
||||
prompt_token_ids.append(
|
||||
np.random.randint(0,
|
||||
vocab_size,
|
||||
size=np.random.randint(
|
||||
1, MAX_NUM_PROMPT_TOKENS)).tolist())
|
||||
logitsprocs = init_builtin_logitsprocs(
|
||||
pin_memory_available=PIN_MEMORY_AVAILABLE,
|
||||
max_num_reqs=MAX_NUM_REQS + 1,
|
||||
device=device)
|
||||
|
||||
fake_sampling_metadata = SamplingMetadata(
|
||||
temperature=torch.full((batch_size, ), 0.0),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
output_token_ids=output_token_ids,
|
||||
frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
|
||||
repetition_penalties=create_penalty_tensor(batch_size, 1.0, device),
|
||||
no_penalties=True,
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
logitsprocs=logitsprocs)
|
||||
return fake_sampling_metadata
|
||||
|
||||
|
||||
def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
|
||||
"""Generate fake logits and sampling metadata"""
|
||||
fake_logits = create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
# Create one dominant token per batch, to support min-p test
|
||||
for i in range(batch_size):
|
||||
fake_logits[i, 0] = 10.0 # High logit for first token
|
||||
fake_logits[i, 1:] = 1e-2 # Others remain low
|
||||
sampling_metadata = _generate_fake_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
return LogitsprocsTestFakes(
|
||||
logits=fake_logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
|
||||
def _sampling_params_from_logitproc(
|
||||
logitproc_type: LogitprocType) -> SamplingParams:
|
||||
"""Customize request SamplingParams for a specified logitproc"""
|
||||
# SamplingParams for req with no logitproc
|
||||
kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0}
|
||||
if fxn := logitsprocs_test_mapping[logitproc_type].gen_request_fxn:
|
||||
fxn(kwargs)
|
||||
return SamplingParams(**kwargs)
|
||||
|
||||
|
||||
def _generate_mixed_logitsprocs_batch_params(
|
||||
reqs_per_logitproc: int,
|
||||
logitsprocs_types: list[str],
|
||||
) -> list[LogitsProcsRequestParams]:
|
||||
"""Define key params for a batch of requests with a different
|
||||
logitproc enabled per request.
|
||||
|
||||
The batch will have `reqs_per_logitproc` repeats for all
|
||||
`logitsprocs_types` under test, including the case where
|
||||
no logitsproc is enabled. The batch is randomly shuffled. The
|
||||
size of the batch is `reqs_per_logitproc` times
|
||||
`n = len(logitsprocs_types)`
|
||||
|
||||
Args:
|
||||
reqs_per_logitproc: number of requests using each logitproc
|
||||
logitsprocs_types: logitsprocs under test
|
||||
|
||||
Returns:
|
||||
List of per-request params which configure the engine for that request's
|
||||
enabled logitproc
|
||||
"""
|
||||
batch_size = len(logitsprocs_types) * reqs_per_logitproc
|
||||
# Generate multiple repeats of key params for each logitproc;
|
||||
# apply random inverse permutation to the iteration
|
||||
# over logitsprocs, such that logitsprocs are shuffled.
|
||||
batch_perm = random.sample(range(batch_size), k=batch_size)
|
||||
return [
|
||||
LogitsProcsRequestParams(
|
||||
workload_index=idx,
|
||||
logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc])
|
||||
for idx, pdx in enumerate(batch_perm)
|
||||
]
|
||||
|
||||
|
||||
def _raise_error_invalid(
|
||||
msg_suffix: str,
|
||||
batch_index: int,
|
||||
request_params: LogitsProcsRequestParams,
|
||||
step_idx: int,
|
||||
err_cls: type[Exception] = ValueError,
|
||||
) -> None:
|
||||
raise err_cls(f"Validation failed for step={step_idx}, "
|
||||
f"batch_index={batch_index}, "
|
||||
f"workload_index={request_params.workload_index}, "
|
||||
f"req_params={request_params}. Reason: {msg_suffix}")
|
||||
|
||||
|
||||
def _logit_bias_params(kwargs: dict) -> None:
|
||||
"""Logit bias config"""
|
||||
kwargs["logit_bias"] = {
|
||||
random.randint(0, VOCAB_SIZE - 1): random.choice([-0.1, 0.2])
|
||||
}
|
||||
|
||||
|
||||
def _logit_bias_validate(
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
persistent_batch: list[LogitsProcsRequestParams],
|
||||
logits_new: torch.Tensor,
|
||||
batch_index: int,
|
||||
request_params: LogitsProcsRequestParams,
|
||||
step_idx: int,
|
||||
) -> None:
|
||||
"""Validate logit bias logitproc applied correctly"""
|
||||
logit_bias = request_params.params.logit_bias
|
||||
logits_old = (
|
||||
test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
|
||||
logits_new = logits_new[batch_index].cpu()
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
logit_old_value = logits_old[token_id]
|
||||
logit_new_value = logits_new[token_id]
|
||||
if token_id in logit_bias:
|
||||
bias_value = logit_bias[token_id]
|
||||
exp_value = bias_value + logit_old_value
|
||||
if logit_new_value != pytest.approx(exp_value):
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
f"Biased token {token_id} logit value {logit_new_value} "
|
||||
f"does not match expected value {exp_value} "
|
||||
f"given bias {bias_value}"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
|
||||
else:
|
||||
if logit_new_value != pytest.approx(logit_old_value):
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
f"Unbiased token {token_id} logit value {logit_new_value} "
|
||||
f"does not match expected value {logit_old_value}"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
|
||||
|
||||
def _min_p_params(kwargs: dict) -> None:
|
||||
"""Min-p logitproc config"""
|
||||
kwargs["min_p"] = 0.1
|
||||
|
||||
|
||||
def _min_p_validate(
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
persistent_batch: list[LogitsProcsRequestParams],
|
||||
logits_new: torch.Tensor,
|
||||
batch_index: int,
|
||||
request_params: LogitsProcsRequestParams,
|
||||
step_idx: int,
|
||||
) -> None:
|
||||
"""Validate min-p logitproc applied correctly"""
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
logits_for_token = logits_new[batch_index][token_id]
|
||||
if token_id == 0:
|
||||
# Dominant token should always be unmasked
|
||||
if logits_for_token == -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix="Invalid: dominant token 0 masked (-inf)",
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
else:
|
||||
if request_params.params.min_p > 0.0:
|
||||
# Non-dominant tokens should be masked when min_p > 0
|
||||
if logits_for_token != -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix=
|
||||
f"Invalid: non-dominant token {token_id} not masked",
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
else:
|
||||
# No masking when min_p is 0
|
||||
if logits_for_token == -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix=
|
||||
f"Invalid: token {token_id} masked when min_p=0.0",
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
|
||||
|
||||
def _min_tokens_params(kwargs: dict) -> None:
|
||||
"""Min-tokens logitproc config"""
|
||||
kwargs["min_tokens"] = MIN_TOKENS_LEN_THRESHOLD
|
||||
kwargs["stop_token_ids"] = [
|
||||
np.random.randint(0, VOCAB_SIZE - 1)
|
||||
for _ in range(np.random.randint(0, VOCAB_SIZE))
|
||||
]
|
||||
|
||||
|
||||
def _min_tokens_validate(
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
persistent_batch: list[LogitsProcsRequestParams],
|
||||
logits_new: torch.Tensor,
|
||||
batch_index: int,
|
||||
request_params: LogitsProcsRequestParams,
|
||||
step_idx: int,
|
||||
) -> None:
|
||||
"""Validate min-tokens logitsproc applied correctly"""
|
||||
ref_num_out_tokens = len(request_params.out_tokens)
|
||||
min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD
|
||||
ref_all_stop_token_ids = request_params.params.all_stop_token_ids
|
||||
mt_lp: MinTokensLogitsProcessor = next(
|
||||
test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor))
|
||||
assert isinstance(mt_lp, MinTokensLogitsProcessor)
|
||||
min_tok = mt_lp.min_toks.get(batch_index, None)
|
||||
|
||||
# Validate min-token logits processor state
|
||||
if min_tok:
|
||||
(_, out_tok, all_stop_token_ids) = min_tok
|
||||
num_out_tokens = len(out_tok)
|
||||
if num_out_tokens != ref_num_out_tokens:
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
"Number of output tokens in min-token logit processor "
|
||||
f"request metadata ({num_out_tokens}) does not match "
|
||||
f"reference ({ref_num_out_tokens})."),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
if ref_all_stop_token_ids != all_stop_token_ids:
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
"Stop token ids do not match reference; all_stop_token_ids: "
|
||||
f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
|
||||
f"{sorted(ref_all_stop_token_ids)}"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
if min_reached:
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
"Expected min-tokens request with min reached, but batch "
|
||||
"index is recognized by min-tokens logits processor."),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
err_cls=RuntimeError)
|
||||
|
||||
elif not min_reached:
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
"Expected min-tokens request with min not reached, but batch "
|
||||
"index is not recognized by min-tokens logits processor."),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx,
|
||||
err_cls=RuntimeError)
|
||||
|
||||
# Validate min-token logits
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
logits_for_token = logits_new[batch_index][token_id]
|
||||
if token_id in ref_all_stop_token_ids and not min_reached:
|
||||
if logits_for_token != -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(f"Token {token_id} is a stop token and "
|
||||
"the sequence has not reached min length, "
|
||||
"but the token is not masked "
|
||||
f"(logit={logits_for_token})"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
else:
|
||||
if logits_for_token == -float("inf"):
|
||||
_raise_error_invalid(
|
||||
msg_suffix=(f"Token {token_id} should not be masked but "
|
||||
f"is (output len={ref_num_out_tokens})"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
|
||||
|
||||
def _none_validate(
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
persistent_batch: list[LogitsProcsRequestParams],
|
||||
logits_new: torch.Tensor,
|
||||
batch_index: int,
|
||||
request_params: LogitsProcsRequestParams,
|
||||
step_idx: int,
|
||||
) -> None:
|
||||
"""Validate that no logits processors are applied"""
|
||||
logits = (
|
||||
test_fakes.logits[persistent_batch[batch_index].workload_index].cpu())
|
||||
ref_logits = logits_new[batch_index]
|
||||
if not torch.all(ref_logits == logits):
|
||||
mismatch_toks = (ref_logits
|
||||
!= logits).nonzero(as_tuple=True)[0].tolist()
|
||||
mismatch_strs = []
|
||||
for token in mismatch_toks:
|
||||
val = float(logits[token])
|
||||
ref_val = float(ref_logits[token])
|
||||
mismatch_strs.append(f"({token=},{val=},{ref_val=})")
|
||||
_raise_error_invalid(msg_suffix=(
|
||||
f"Unexpected modification of logits: {','.join(mismatch_strs)}"),
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
|
||||
|
||||
class LogitsprocTestHelpers(NamedTuple):
|
||||
"""Supports setting up and validating logitsprocs unit tests."""
|
||||
eval_fxn: Callable
|
||||
gen_request_fxn: Optional[Callable] = None
|
||||
|
||||
|
||||
logitsprocs_test_mapping = {
|
||||
STR_NO_LOGITPROC:
|
||||
LogitsprocTestHelpers(eval_fxn=_none_validate),
|
||||
LogitBiasLogitsProcessor:
|
||||
LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params,
|
||||
eval_fxn=_logit_bias_validate),
|
||||
MinPLogitsProcessor:
|
||||
LogitsprocTestHelpers(gen_request_fxn=_min_p_params,
|
||||
eval_fxn=_min_p_validate),
|
||||
MinTokensLogitsProcessor:
|
||||
LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params,
|
||||
eval_fxn=_min_tokens_validate),
|
||||
}
|
||||
|
||||
|
||||
def _get_test_cases() -> list[list[str]]:
|
||||
"""Each test case is a set of logitsprocs"""
|
||||
logitsprocs_types = list(logitsprocs_test_mapping.keys())
|
||||
return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC]
|
||||
for logitproc_type in logitsprocs_types
|
||||
if logitproc_type != STR_NO_LOGITPROC
|
||||
] + [logitsprocs_types]
|
||||
|
||||
|
||||
def _generate_fake_step_update(
|
||||
persistent_batch: list[LogitsProcsRequestParams],
|
||||
workload_params: list[LogitsProcsRequestParams],
|
||||
wdx: int,
|
||||
batch_update_builder: BatchUpdateBuilder,
|
||||
) -> tuple[Optional[BatchUpdate], int, int]:
|
||||
batch_size = len(persistent_batch)
|
||||
workload_size = len(workload_params)
|
||||
workload_reqs_remaining = workload_size - wdx
|
||||
max_add_remove_per_step = max(1, int(0.2 * workload_size))
|
||||
|
||||
# 50% of steps: add no reqs
|
||||
# Other 50%: add a limited number of reqs (less than the number
|
||||
# of workload reqs remaining, less than an arbitrary max)
|
||||
# If no workload reqs remain: 100% of steps have 0 adds
|
||||
num_step_add = random.choice([
|
||||
0,
|
||||
random.randint(1, min(max_add_remove_per_step,
|
||||
workload_reqs_remaining))
|
||||
]) if workload_reqs_remaining else 0
|
||||
|
||||
# 50% of steps: remove no requests
|
||||
# Other 50%: remove a limited number of reqs (less than the number
|
||||
# persistent batch reqs remaining, less than an arbitrary max)
|
||||
# If persistent batch is empty: 100% of steps have 0 removals until
|
||||
# more requests are added. Assume that removed requests are always
|
||||
# drawn from the current batch, before new adds
|
||||
num_step_remove = random.choice([
|
||||
0, random.randint(1, min(max_add_remove_per_step, batch_size))
|
||||
]) if batch_size else 0
|
||||
|
||||
num_step_add_replace = min(num_step_add, num_step_remove)
|
||||
|
||||
# Generate fake removed request indices drawn from persistent batch indices
|
||||
for removal in random.sample(range(batch_size), num_step_remove):
|
||||
batch_update_builder.removed_append(removal)
|
||||
|
||||
# Get added requests from workload
|
||||
for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]:
|
||||
# Replace as many removed requests as possible with added requests
|
||||
add_remove_idx = batch_update_builder.pop_removed()
|
||||
batch_update_builder.added.append(
|
||||
(add_remove_idx, add_req_params.params, add_req_params.out_tokens))
|
||||
persistent_batch[add_remove_idx] = add_req_params
|
||||
|
||||
# Append remaining added requests to end of batch
|
||||
add_reqs_append = workload_params[(wdx +
|
||||
num_step_add_replace):(wdx +
|
||||
num_step_add)]
|
||||
batch_update_builder.added.extend([
|
||||
(adx + batch_size, add_req_params.params, add_req_params.out_tokens)
|
||||
for adx, add_req_params in enumerate(add_reqs_append)
|
||||
])
|
||||
persistent_batch.extend(add_reqs_append)
|
||||
pre_condense_batch_size = len(persistent_batch)
|
||||
wdx += num_step_add # Update workload offset
|
||||
|
||||
# Simulate condensing persistent batch
|
||||
last_nonempty_index = pre_condense_batch_size - 1
|
||||
condensed_to_idxs = set()
|
||||
while batch_update_builder.removed:
|
||||
if (last_nonempty_index in batch_update_builder.removed
|
||||
or last_nonempty_index in condensed_to_idxs):
|
||||
last_nonempty_index -= 1
|
||||
continue
|
||||
# last_nonempty_index is the highest persistent batch index that was
|
||||
# not removed
|
||||
first_empty_index = batch_update_builder.peek_removed()
|
||||
assert first_empty_index is not None
|
||||
if first_empty_index > last_nonempty_index:
|
||||
break
|
||||
# first_empty_index is the lowest removed persistent batch index
|
||||
# that is less than last_nonempty_index
|
||||
#
|
||||
# move last_nonempty_index -> first_empty_index
|
||||
batch_update_builder.pop_removed()
|
||||
condensed_to_idxs.add(first_empty_index)
|
||||
persistent_batch[first_empty_index] = persistent_batch[
|
||||
last_nonempty_index]
|
||||
batch_update_builder.moved.append(
|
||||
(last_nonempty_index, first_empty_index,
|
||||
MoveDirectionality.UNIDIRECTIONAL))
|
||||
|
||||
last_nonempty_index -= 1
|
||||
|
||||
# Now removed requests & gaps left by non-removed requests that got
|
||||
# moved downward are grouped consecutively in the upper indices of
|
||||
# the persistent batch. Truncate them to get condensed persistent batch
|
||||
condensed_batch_size = batch_size + num_step_add - num_step_remove
|
||||
persistent_batch[:] = persistent_batch[0:condensed_batch_size]
|
||||
|
||||
if condensed_batch_size > 1:
|
||||
# Simulate arbitrary reorder_batch() in the kernel backend
|
||||
# Generate a random number k of non-overlapping swap tuples
|
||||
k = random.randint(0, condensed_batch_size // 2)
|
||||
idxs = list(range(condensed_batch_size))
|
||||
random.shuffle(idxs)
|
||||
swaps = [
|
||||
tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)
|
||||
]
|
||||
batch_update_builder.moved.extend([
|
||||
(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps
|
||||
])
|
||||
for adx, bdx in swaps:
|
||||
persistent_batch[adx], persistent_batch[bdx] = persistent_batch[
|
||||
bdx], persistent_batch[adx]
|
||||
|
||||
return (batch_update_builder.get_and_reset(condensed_batch_size), wdx,
|
||||
workload_size - wdx)
|
||||
|
||||
|
||||
def _assert_valid(
|
||||
batch_size: int,
|
||||
persistent_batch: list[LogitsProcsRequestParams],
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
slice_idxs: list[int],
|
||||
logits_w_lp: torch.Tensor,
|
||||
step_idx: int,
|
||||
) -> None:
|
||||
if not slice_idxs:
|
||||
# Trivial case of empty persistent batch
|
||||
assert len(persistent_batch) == 0
|
||||
if logits_w_lp.shape[0] != 0:
|
||||
raise ValueError("Fake persistent batch is empty but logitsprocs "
|
||||
f"output batch has shape {logits_w_lp.shape}")
|
||||
return
|
||||
|
||||
# Validate logits for each fake request
|
||||
for batch_index in range(batch_size):
|
||||
request_params = persistent_batch[batch_index]
|
||||
# Invoke the appropriate validation function for
|
||||
# the logitproc employed by this request
|
||||
fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn
|
||||
fxn(test_fakes=test_fakes,
|
||||
persistent_batch=persistent_batch,
|
||||
logits_new=logits_w_lp,
|
||||
batch_index=batch_index,
|
||||
request_params=request_params,
|
||||
step_idx=step_idx)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
|
||||
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
|
||||
def test_logitsprocs(device: str, reqs_per_logitproc: int,
|
||||
logitsprocs_under_test: list[str]):
|
||||
random.seed(40)
|
||||
torch.set_default_device(device)
|
||||
|
||||
# Define a shuffled batch of requests which individually use a different
|
||||
# logitproc, or no logitproc at all
|
||||
workload_params = _generate_mixed_logitsprocs_batch_params(
|
||||
reqs_per_logitproc=reqs_per_logitproc,
|
||||
logitsprocs_types=logitsprocs_under_test)
|
||||
workload_size = len(workload_params)
|
||||
|
||||
# Create fake test data structures for testing.
|
||||
test_fakes = _generate_test_fakes(workload_size, device)
|
||||
|
||||
wdx = 0 # Next request index in workload to add
|
||||
persistent_batch: list[LogitsProcsRequestParams] = [
|
||||
] # Persistent batch state, as list of workload indices
|
||||
|
||||
# Generate fake removed request indices from current persistent
|
||||
# batch before adds
|
||||
batch_update_builder = BatchUpdateBuilder()
|
||||
|
||||
# Break when entire workload has been added previously and persistent
|
||||
# batch is empty
|
||||
workload_reqs_remaining = workload_size
|
||||
batch_size = 0
|
||||
step_idx = 0
|
||||
while True:
|
||||
if not (workload_reqs_remaining or batch_size):
|
||||
break
|
||||
|
||||
(
|
||||
batch_update,
|
||||
wdx,
|
||||
workload_reqs_remaining,
|
||||
) = _generate_fake_step_update(
|
||||
persistent_batch=persistent_batch,
|
||||
workload_params=workload_params,
|
||||
wdx=wdx,
|
||||
batch_update_builder=batch_update_builder,
|
||||
)
|
||||
batch_size = len(persistent_batch)
|
||||
|
||||
# Apply fake batch update to logitsprocs
|
||||
fake_update_logitsprocs_state(test_fakes, batch_update)
|
||||
|
||||
# Emulate application of logits processors in engine
|
||||
slice_idxs = [req.workload_index for req in persistent_batch]
|
||||
logits_w_lp = fake_apply_logitsprocs(test_fakes, slice_idxs).cpu()
|
||||
|
||||
_assert_valid(
|
||||
batch_size=batch_size,
|
||||
persistent_batch=persistent_batch,
|
||||
test_fakes=test_fakes,
|
||||
slice_idxs=slice_idxs,
|
||||
logits_w_lp=logits_w_lp,
|
||||
step_idx=step_idx,
|
||||
)
|
||||
|
||||
step_idx += 1
|
||||
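For orientation, the flow the test above drives can be reduced to a few calls against the public pieces it imports. The following minimal standalone sketch (illustrative only; the device, sizes and bias values are arbitrary) registers one request with a logit bias and pushes logits through every built-in processor:

import torch

from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
                                             init_builtin_logitsprocs)

device = torch.device("cpu")  # arbitrary choice for illustration
procs = init_builtin_logitsprocs(pin_memory_available=False,
                                 max_num_reqs=8,
                                 device=device)

# One request at persistent batch index 0, biasing token 3 by +2.0.
builder = BatchUpdateBuilder()
params = SamplingParams(logit_bias={3: 2.0}, min_p=0.0, min_tokens=0)
builder.added.append((0, params, []))  # (index, params, output_tok_ids)
batch_update = builder.get_and_reset(batch_size=1)

# Engine-core order: update all processor state, then apply in sequence.
for proc in procs.all:
    proc.update_state(batch_update)
logits = torch.zeros(1, 16)
for proc in procs.all:
    logits = proc.apply(logits)
# Expected: logits[0, 3] == 2.0 while every other logit is unchanged.
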
@ -13,9 +13,10 @@ EXPECTED_VALUE = 0.62
|
||||
|
||||
# FIXME(rob): enable prefix caching once supported.
|
||||
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501
|
||||
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501
|
||||
SERVER_ARGS = [
|
||||
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests"
|
||||
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests",
|
||||
"--gpu-memory-utilization=0.8"
|
||||
]
|
||||
NUM_CONCURRENT = 100
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessorManager
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
|
||||
RejectionSampler)
|
||||
@ -58,7 +59,6 @@ def create_sampling_metadata(
|
||||
all_random=not all_greedy,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=torch.empty(1, ),
|
||||
generators=generators,
|
||||
max_num_logprobs=0,
|
||||
no_penalties=False,
|
||||
@ -67,10 +67,9 @@ def create_sampling_metadata(
|
||||
presence_penalties=torch.tensor([]),
|
||||
repetition_penalties=torch.tensor([]),
|
||||
output_token_ids=[],
|
||||
min_tokens={},
|
||||
logit_bias=[None],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
logitsprocs=LogitsProcessorManager(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -8,10 +8,13 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessorManager
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
PIN_MEMORY_AVAILABLE = is_pin_memory_available()
|
||||
MAX_NUM_REQS = 256
|
||||
VOCAB_SIZE = 1024
|
||||
NUM_OUTPUT_TOKENS = 20
|
||||
CUDA_DEVICES = [
|
||||
@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor(
|
||||
)
|
||||
|
||||
|
||||
def _create_logit_bias(
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
bias_value: float,
|
||||
) -> list[Optional[dict[int, float]]]:
|
||||
res: list[Optional[dict[int, float]]] = []
|
||||
for i in range(batch_size):
|
||||
logit_bias = {min(i, vocab_size - 1): bias_value}
|
||||
res.append(logit_bias)
|
||||
return res
|
||||
|
||||
|
||||
def _create_allowed_token_ids(
|
||||
batch_size: int,
|
||||
vocab_size: int,
|
||||
@ -145,7 +136,6 @@ def _create_default_sampling_metadata(
|
||||
all_random=False,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
|
||||
@ -155,43 +145,13 @@ def _create_default_sampling_metadata(
|
||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
|
||||
no_penalties=True,
|
||||
min_tokens={},
|
||||
logit_bias=[None] * batch_size,
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
logitsprocs=LogitsProcessorManager(),
|
||||
)
|
||||
return fake_sampling_metadata
|
||||
|
||||
|
||||
def _generate_min_token_penalties_and_stop_tokens(
|
||||
num_output_tokens: int, batch_size: int, vocab_size: int,
|
||||
batch_indices_for_min_token_penalty: list[int]
|
||||
) -> dict[int, tuple[int, set[int]]]:
|
||||
"""
|
||||
Generates and returns a dict of minimum token penalties and
|
||||
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
|
||||
batch.
|
||||
|
||||
If a batch index is included in `batch_indices_for_min_token_penalty`,
|
||||
a higher `min_tokens` value is assigned (within a randomized range),
|
||||
and a random set of stop token IDs is created. Otherwise, a lower
|
||||
`min_tokens` value is assigned, and the stop token IDs set is empty.
|
||||
"""
|
||||
min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
for index in range(batch_size):
|
||||
if index in batch_indices_for_min_token_penalty:
|
||||
min_tokens[index] = (
|
||||
np.random.randint(num_output_tokens + 1,
|
||||
2 * num_output_tokens),
|
||||
set(
|
||||
np.random.randint(0, vocab_size - 1)
|
||||
for _ in range(np.random.randint(0, vocab_size))))
|
||||
else:
|
||||
min_tokens[index] = (np.random.randint(0,
|
||||
num_output_tokens), set())
|
||||
return min_tokens
|
||||
|
||||
|
||||
def _create_weighted_output_token_list(
|
||||
batch_size: int,
|
||||
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
|
||||
@ -227,36 +187,6 @@ def _create_weighted_output_token_list(
|
||||
return output_token_ids, sorted_token_ids_in_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
|
||||
"""
|
||||
Tests that if the number of output tokens is less than
|
||||
SamplingParams.min_tokens then we will set the logits for
|
||||
the stop token ids to -inf.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
batch_indices_for_min_token_penalty = np.random.randint(
|
||||
0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
|
||||
min_tokens = _generate_min_token_penalties_and_stop_tokens(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
|
||||
batch_indices_for_min_token_penalty)
|
||||
sampling_metadata.min_tokens = min_tokens
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_penalties(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
_, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
|
||||
if token_id in stop_token_ids:
|
||||
assert logits[batch_idx][token_id] == -float("inf")
|
||||
else:
|
||||
assert logits[batch_idx][token_id] != -float("inf")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
|
||||
@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
||||
or non_penalized_token_id in output_tokens)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("min_p", [0.0, 0.1])
|
||||
def test_sampler_min_p(device: str, batch_size: int, min_p: float):
|
||||
"""
|
||||
Tests that when min_p is applied, tokens with probability below
|
||||
min_p * max_prob are masked with -inf.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
|
||||
# Create one dominant token per batch
|
||||
for i in range(batch_size):
|
||||
fake_logits[i, 0] = 10.0 # High logit for first token
|
||||
fake_logits[i, 1:] = 1e-2 # Others remain low
|
||||
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
|
||||
# Configure min_p parameters
|
||||
sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)
|
||||
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
|
||||
logits = logits.cpu()
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
if token_id == 0:
|
||||
# Dominant token should always be unmasked
|
||||
assert logits[batch_idx][token_id] != -float("inf")
|
||||
else:
|
||||
if min_p > 0.0:
|
||||
# Non-dominant tokens should be masked when min_p > 0
|
||||
assert logits[batch_idx][token_id] == -float("inf")
|
||||
else:
|
||||
# No masking when min_p is 0
|
||||
assert logits[batch_idx][token_id] != -float("inf")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
|
||||
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
|
||||
"""
|
||||
Test to verify that when the repetition penalty is enabled, tokens
|
||||
are penalized based on their presence in the prompt or the existing
|
||||
output.
|
||||
"""
|
||||
torch.set_default_device(device)
|
||||
# Create fake logits where each token is assigned the same
|
||||
# logit value.
|
||||
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
|
||||
sampling_metadata = _create_default_sampling_metadata(
|
||||
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
|
||||
sampling_metadata.logit_bias = _create_logit_bias(
|
||||
batch_size=batch_size,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
bias_value=bias_value,
|
||||
)
|
||||
sampler = Sampler()
|
||||
logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
|
||||
logits = logits.cpu()
|
||||
for batch_idx in range(batch_size):
|
||||
logits_for_req = logits[batch_idx]
|
||||
biased_index = min(batch_idx, VOCAB_SIZE - 1)
|
||||
for token_id in range(VOCAB_SIZE):
|
||||
if biased_index == token_id:
|
||||
assert logits_for_req[token_id] == pytest.approx(bias_value +
|
||||
1e-2)
|
||||
else:
|
||||
assert logits_for_req[token_id] == pytest.approx(1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 32])
|
||||
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterator
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from vllm import CompletionOutput
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
|
||||
class BatchLogprobsComposition(Enum):
|
||||
@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
|
||||
logprobs = completion_output.logprobs
|
||||
assert logprobs is not None
|
||||
return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)])
|
||||
|
||||
|
||||
def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
|
||||
fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
|
||||
return fake_logits
|
||||
|
||||
|
||||
def create_penalty_tensor(batch_size: int, penalty_value: float,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
return torch.full((batch_size, ),
|
||||
fill_value=penalty_value,
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
|
||||
|
||||
def create_prompt_tokens_tensor(
|
||||
prompt_token_ids: list[list[int]],
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
return make_tensor_with_pad(
|
||||
prompt_token_ids,
|
||||
pad=vocab_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
pin_memory=False,
|
||||
)
|
||||
|
||||
|
||||
class LogitsprocsTestFakes(NamedTuple):
|
||||
"""Wraps fake data structures to support testing"""
|
||||
logits: torch.Tensor
|
||||
sampling_metadata: SamplingMetadata
|
||||
|
||||
def get_logitsprocs_by_cls(
|
||||
self,
|
||||
cls: type[LogitsProcessor],
|
||||
) -> Iterator[LogitsProcessor]:
|
||||
"""Yield logits processors of a specific class.
|
||||
|
||||
Args:
|
||||
cls: :class:`LogitsProcessor` subclass
|
||||
|
||||
Returns:
|
||||
Iterator over logits processors
|
||||
"""
|
||||
return (lp for lp in self.sampling_metadata.logitsprocs.all
|
||||
if isinstance(lp, cls))
|
||||
|
||||
def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
|
||||
"""Iterator over all logits processors."""
|
||||
return self.sampling_metadata.logitsprocs.all
|
||||
|
||||
|
||||
def fake_update_logitsprocs_state(
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
batch_update: BatchUpdate,
|
||||
) -> None:
|
||||
"""Imitate logits processors persistent batch state update
|
||||
in engine core"""
|
||||
for logitproc in test_fakes.get_logitsprocs():
|
||||
logitproc.update_state(batch_update)
|
||||
|
||||
|
||||
def fake_apply_logitsprocs(
|
||||
test_fakes: LogitsprocsTestFakes,
|
||||
slice_indices: list[int],
|
||||
) -> torch.Tensor:
|
||||
"""Imitate application of logits processors in engine core"""
|
||||
logits = test_fakes.logits[torch.tensor(slice_indices,
|
||||
dtype=torch.long)].clone()
|
||||
for processor in test_fakes.get_logitsprocs():
|
||||
logits = processor.apply(logits)
|
||||
return logits
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import inspect
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@ -12,6 +13,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessorManager
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
@ -26,13 +28,18 @@ CUDA_DEVICES = [
|
||||
MAX_NUM_PROMPT_TOKENS = 64
|
||||
|
||||
|
||||
def _compare_objs(obj1, obj2):
|
||||
def _compare_objs(obj1,
|
||||
obj2,
|
||||
skip: Sequence = ("logitsprocs", "batch_update_builder")):
|
||||
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
|
||||
attr_names = set([
|
||||
a[0] for a in attrs
|
||||
if not (a[0].startswith('__') and a[0].endswith('__'))
|
||||
])
|
||||
for attr_name in attr_names:
|
||||
if attr_name in skip:
|
||||
continue
|
||||
|
||||
a = getattr(obj1, attr_name)
|
||||
b = getattr(obj2, attr_name)
|
||||
|
||||
@ -58,13 +65,11 @@ def _compare_objs(obj1, obj2):
|
||||
f" in {obj1} and {obj2}: {a} != {b}"
|
||||
|
||||
|
||||
def _remove_requests(
|
||||
input_batch: InputBatch, batch_size: int,
|
||||
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
|
||||
def _remove_requests(input_batch: InputBatch, batch_size: int,
|
||||
reqs: list[CachedRequestState]) -> set[str]:
|
||||
"""
|
||||
Remove some requests randomly from the batch and returns a tuple
|
||||
of 1) set of request removed 2) indices of the requests removed
|
||||
ordered in descending order
|
||||
Remove some requests randomly from the batch and returns
|
||||
set of request removed
|
||||
"""
|
||||
|
||||
num_reqs_to_remove = np.random.randint(0, batch_size)
|
||||
@ -73,13 +78,11 @@ def _remove_requests(
|
||||
req_index_to_remove = np.random.randint(0, batch_size)
|
||||
req_indices_to_remove.add(req_index_to_remove)
|
||||
|
||||
req_indices_to_remove_list = list(req_indices_to_remove)
|
||||
req_indices_to_remove_list.sort(reverse=True)
|
||||
req_ids_to_remove: set[str] = set()
|
||||
for index in req_indices_to_remove:
|
||||
input_batch.remove_request(reqs[index].req_id)
|
||||
req_ids_to_remove.add(reqs[index].req_id)
|
||||
return req_ids_to_remove, req_indices_to_remove_list
|
||||
return req_ids_to_remove
|
||||
|
||||
|
||||
def _construct_expected_sampling_metadata(
|
||||
@ -100,7 +103,6 @@ def _construct_expected_sampling_metadata(
|
||||
repetition_penalties = [1.0 for _ in range(num_reqs)]
|
||||
top_k = [0 for _ in range(num_reqs)]
|
||||
top_p = [0.0 for _ in range(num_reqs)]
|
||||
min_p = [0.0 for _ in range(num_reqs)]
|
||||
temperature = [0.0 for _ in range(num_reqs)]
|
||||
min_tokens = {}
|
||||
logit_bias = [None] * num_reqs
|
||||
@ -123,7 +125,6 @@ def _construct_expected_sampling_metadata(
|
||||
req.sampling_params.repetition_penalty)
|
||||
top_k[index_in_input_batch] = req.sampling_params.top_k
|
||||
top_p[index_in_input_batch] = req.sampling_params.top_p
|
||||
min_p[index_in_input_batch] = req.sampling_params.min_p
|
||||
temperature[index_in_input_batch] = req.sampling_params.temperature
|
||||
min_tokens[index_in_input_batch] = (
|
||||
req.sampling_params.min_tokens,
|
||||
@ -145,8 +146,6 @@ def _construct_expected_sampling_metadata(
|
||||
top_p, dtype=torch.float, device=device),
|
||||
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
|
||||
top_k, dtype=torch.int, device=device),
|
||||
min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
|
||||
min_p, dtype=torch.float, device=device),
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
prompt_token_ids=make_tensor_with_pad(
|
||||
@ -165,13 +164,12 @@ def _construct_expected_sampling_metadata(
|
||||
dtype=torch.float,
|
||||
device=device),
|
||||
output_token_ids=output_token_ids,
|
||||
min_tokens=min_tokens,
|
||||
no_penalties=(all(x == 0 for x in presence_penalties)
|
||||
and all(x == 0 for x in frequency_penalties)
|
||||
and all(x == 1 for x in repetition_penalties)),
|
||||
logit_bias=logit_bias,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=bad_words_token_ids,
|
||||
logitsprocs=LogitsProcessorManager(),
|
||||
)
|
||||
|
||||
|
||||
@ -225,6 +223,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
and the `make_sampling_metadata` method is invoked on the batch. The
|
||||
output of `make_sampling_metadata` is then compared against the expected
|
||||
results to ensure correctness.
|
||||
|
||||
Note: Ignore logits processor logic, which is tested separately
|
||||
"""
|
||||
input_batch: InputBatch = InputBatch(
|
||||
max_num_reqs=batch_size,
|
||||
@ -238,21 +238,22 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
reqs: list[CachedRequestState] = []
|
||||
req_id_reqs = {}
|
||||
req_id_output_token_ids = {}
|
||||
|
||||
# Add requests
|
||||
for req_index in range(batch_size):
|
||||
req: CachedRequestState = _construct_cached_request_state(req_index)
|
||||
input_batch.add_request(req, req_index)
|
||||
assigned_req_index = input_batch.add_request(req)
|
||||
assert req_index == assigned_req_index
|
||||
reqs.append(req)
|
||||
req_id_reqs[req.req_id] = req
|
||||
req_id_output_token_ids[req.req_id] = req.output_token_ids
|
||||
|
||||
# Remove some requests
|
||||
req_ids_to_remove, req_indices_to_remove = _remove_requests(
|
||||
input_batch, batch_size, reqs)
|
||||
req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
|
||||
req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
|
||||
|
||||
# Compact the input batch
|
||||
input_batch.condense(req_indices_to_remove)
|
||||
input_batch.condense()
|
||||
|
||||
# Generate the sampling metadata
|
||||
sampling_metadata = input_batch._make_sampling_metadata()
|
||||
@ -290,10 +291,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
sampling_metadata.prompt_token_ids)
|
||||
assert (expected_sampling_metadata.output_token_ids ==
|
||||
sampling_metadata.output_token_ids)
|
||||
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
|
||||
assert expected_sampling_metadata.no_penalties == \
|
||||
sampling_metadata.no_penalties
|
||||
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
|
||||
if sampling_metadata.allowed_token_ids_mask:
|
||||
assert torch.allclose(
|
||||
expected_sampling_metadata.allowed_token_ids_mask,
|
||||
@ -315,6 +314,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
and the `make_sampling_metadata` method is invoked on the batch. The
|
||||
output of `make_sampling_metadata` is then compared against the expected
|
||||
results to ensure correctness.
|
||||
|
||||
Note: Ignore logits processor logic, which is tested separately
|
||||
"""
|
||||
input_batch: InputBatch = InputBatch(
|
||||
max_num_reqs=batch_size,
|
||||
@ -341,7 +342,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
# Add requests
|
||||
for req_index in range(batch_size):
|
||||
req: CachedRequestState = _construct_cached_request_state(req_index)
|
||||
input_batch.add_request(req, req_index)
|
||||
assigned_req_index = input_batch.add_request(req)
|
||||
assert assigned_req_index == req_index
|
||||
reqs.append(req)
|
||||
req_id_reqs[req.req_id] = req
|
||||
req_id_output_token_ids[req.req_id] = req.output_token_ids
|
||||
@ -354,9 +356,10 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
|
||||
for req_index in range(batch_size):
|
||||
req = reordered_reqs[req_index]
|
||||
ref_input_batch.add_request(req, req_index)
|
||||
assigned_req_index = ref_input_batch.add_request(req)
|
||||
assert assigned_req_index == req_index
|
||||
|
||||
input_batch.refresh_sampling_metadata()
|
||||
ref_input_batch.refresh_sampling_metadata()
|
||||
input_batch.refresh_metadata()
|
||||
ref_input_batch.refresh_metadata()
|
||||
|
||||
_compare_objs(input_batch, ref_input_batch)
|
||||
|
||||
516
vllm/v1/sample/logits_processor.py
Normal file
@@ -0,0 +1,516 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import Optional, Union

import torch
from torch._prims_common import DeviceLikeType

from vllm import PoolingParams, SamplingParams
from vllm.logger import init_logger

logger = init_logger(__name__)


class MoveDirectionality(Enum):
    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = 0
    # Two-way i1<->i2 req swap within batch
    SWAP = 1


# (index, params, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int


@dataclasses.dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs"""
    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Note: each added request is represented as
    # (index, params, output_tok_ids)
    # Key assumption: output_tok_ids is a reference to the
    # request's running output tokens list; in this way
    # the logits processors always see the latest list of
    # generated tokens
    removed: Sequence[RemovedRequest]
    moved: Sequence[MovedRequest]
    added: Sequence[AddedRequest]


class BatchUpdateBuilder:
    """Helps track persistent batch state changes and build
    a batch update data structure for logitsprocs

    Assumptions:
    * All information about requests removed from persistent batch
      during a step is aggregated in self._removed through calls to
      self.removed_append() at the beginning of a step. This must happen
      before the first time that self.removed, self.pop_removed()
      or self.peek_removed() are invoked in a given step
    * After the first time that self.removed, self.pop_removed()
      or self.peek_removed() are read in a step, no new removals
      are registered using self.removed_append()
    * Elements of self._removed are never directly modified, added or
      removed (i.e. modification is only via self.removed_append() and
      self.pop_removed())

    Guarantees under above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both return
      the lowest removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    moved: list[MovedRequest]
    added: list[AddedRequest]

    def __init__(
        self,
        removed: Optional[list[RemovedRequest]] = None,
        moved: Optional[list[MovedRequest]] = None,
        added: Optional[list[AddedRequest]] = None,
    ) -> None:
        self._removed = removed or []
        self.moved = moved or []
        self.added = added or []
        self._is_removed_sorted = False

    def _ensure_removed_sorted(self) -> None:
        """Sort removed request indices in
        descending order.

        Idempotent after first call in a
        given step, until reset.
        """
        if not self._is_removed_sorted:
            self._removed.sort(reverse=True)
            self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices sorted in
        descending order"""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Register the removal of a request from
        the persistent batch.

        Must not be called after the first time
        self.removed, self.pop_removed() or
        self.peek_removed() are invoked.

        Args:
          index: request index
        """
        if self._is_removed_sorted:
            raise RuntimeError("Cannot register new removed request after"
                               " self.removed has been read.")
        self._removed.append(index)

    def has_removed(self) -> bool:
        return bool(self._removed)

    def peek_removed(self) -> Optional[int]:
        """Return lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed[-1]
        return None

    def pop_removed(self) -> Optional[int]:
        """Pop lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed.pop()
        return None

    def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
        """Generate a logitsprocs batch update data structure
        and reset internal batch update builder state.

        Args:
          batch_size: current persistent batch size

        Returns:
          Frozen logitsprocs batch update instance; `None` if no updates
        """
        # Reset removal-sorting logic
        self._is_removed_sorted = False
        if not any((self._removed, self.moved, self.added)):
            # No update; short-circuit
            return None
        # Build batch state update
        batch_update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        # Reset removed/moved/added update lists
        self._removed = []
        self.moved = []
        self.added = []
        return batch_update


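A brief illustration of the intended call pattern (a sketch only; `params` and `output_tok_ids` stand in for a request's SamplingParams and its running output-token list):

builder = BatchUpdateBuilder()
builder.removed_append(7)             # register all removals first ...
builder.removed_append(2)
assert builder.peek_removed() == 2    # ... reads then see the lowest index
slot = builder.pop_removed()          # reuse slot 2 for an added request
builder.added.append((slot, params, output_tok_ids))
update = builder.get_and_reset(batch_size=8)  # frozen BatchUpdate, or None
# The builder is now reset and can accumulate the next step's changes.

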
class LogitsProcessor(ABC):

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.
        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        TODO(andy): won't be utilized until logits
        processors are user-extensible
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional[BatchUpdate],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update is non-None iff there have been
            changes to the batch makeup.
        """
        raise NotImplementedError


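These three methods are the whole programming model: per-request state is keyed by persistent batch index and maintained from BatchUpdate.added/removed/moved, while apply() transforms the in-flight logits tensor. As an illustration of the interface only (per the TODO above, user-defined logits processors are not yet wired up), a hypothetical processor that forces EOS once a request reaches a length budget could look like:

class MaxLenEOSLogitsProcessor(LogitsProcessor):
    """Hypothetical example: once a request has produced `max_len`
    output tokens, mask everything except `eos_token_id`."""

    def __init__(self, eos_token_id: int, max_len: int):
        self.eos_token_id = eos_token_id
        self.max_len = max_len
        # batch index -> reference to the request's output token list
        self.out_tok_ids: dict[int, list[int]] = {}

    def is_argmax_invariant(self) -> bool:
        return False  # masking can change the greedy argmax

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        if not batch_update:
            return
        for index, params, out_tok_ids in batch_update.added:
            self.out_tok_ids[index] = out_tok_ids
        for index in batch_update.removed:
            self.out_tok_ids.pop(index, None)
        for adx, bdx, direct in batch_update.moved:
            a_entry = self.out_tok_ids.pop(adx, None)
            b_entry = self.out_tok_ids.pop(bdx, None)
            if a_entry is not None:
                self.out_tok_ids[bdx] = a_entry
            if b_entry is not None and direct == MoveDirectionality.SWAP:
                self.out_tok_ids[adx] = b_entry

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for index, out_tok_ids in self.out_tok_ids.items():
            if len(out_tok_ids) >= self.max_len:
                eos_logit = logits[index, self.eos_token_id].clone()
                logits[index, :] = -float("inf")
                logits[index, self.eos_token_id] = eos_logit
        return logits

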
@dataclass
class LogitsProcessorManager:
    """Encapsulates initialized logitsproc objects."""
    argmax_invariant: list[LogitsProcessor] = field(
        default_factory=list)  # argmax-invariant logitsprocs
    non_argmax_invariant: list[LogitsProcessor] = field(
        default_factory=list)  # non-argmax-invariant logitsprocs

    @property
    def all(self) -> Iterator[LogitsProcessor]:
        """Iterator over all logits processors."""
        return chain(self.argmax_invariant, self.non_argmax_invariant)


###### ----- Built-in LogitsProcessor impls below here


class MinPLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, max_num_reqs: int, pin_memory: bool,
|
||||
device: DeviceLikeType):
|
||||
super().__init__()
|
||||
self.min_p_count: int = 0
|
||||
|
||||
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
# Pre-allocated device tensor
|
||||
self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
# Current slice of the device tensor
|
||||
self.min_p: torch.Tensor = self.min_p_device[:0]
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""Min-p never impacts greedy sampling"""
|
||||
return True
|
||||
|
||||
def get_min_p_by_index(self, index: int) -> float:
|
||||
return float(self.min_p_cpu[index])
|
||||
|
||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||
if not batch_update:
|
||||
return
|
||||
|
||||
needs_update = False
|
||||
# Process added requests.
|
||||
for index, params, _ in batch_update.added:
|
||||
min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
|
||||
if self.min_p_cpu[index] != min_p:
|
||||
needs_update = True
|
||||
self.min_p_cpu[index] = min_p
|
||||
if min_p:
|
||||
self.min_p_count += 1
|
||||
|
||||
if self.min_p_count:
|
||||
# Process removed requests.
|
||||
needs_update |= bool(batch_update.removed)
|
||||
for index in batch_update.removed:
|
||||
if self.min_p_cpu[index]:
|
||||
self.min_p_count -= 1
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and swap (a<->b)
|
||||
for adx, bdx, direct in batch_update.moved:
|
||||
change = (min_p_a :=
|
||||
self.min_p_cpu[adx]) != (min_p_b :=
|
||||
self.min_p_cpu[bdx])
|
||||
needs_update |= change
|
||||
if change:
|
||||
self.min_p_cpu[bdx] = min_p_a
|
||||
if direct == MoveDirectionality.SWAP:
|
||||
self.min_p_cpu[adx] = min_p_b
|
||||
|
||||
# Update tensors if needed.
|
||||
size = batch_update.batch_size
|
||||
if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
|
||||
self.min_p = self.min_p_device[:size]
|
||||
self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
|
||||
self.min_p.unsqueeze_(1)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.min_p_count:
|
||||
return logits
|
||||
|
||||
# Convert logits to probability distribution
|
||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
||||
# Calculate maximum probabilities per sequence
|
||||
max_probabilities = torch.amax(probability_values,
|
||||
dim=-1,
|
||||
keepdim=True)
|
||||
# Adjust min_p
|
||||
adjusted_min_p = max_probabilities.mul_(self.min_p)
|
||||
# Identify valid tokens using threshold comparison
|
||||
invalid_token_mask = probability_values < adjusted_min_p
|
||||
# Apply mask using boolean indexing
|
||||
logits[invalid_token_mask] = -float('inf')
|
||||
return logits
|
||||
|
||||
|
||||
class LogitBiasLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, pin_memory: bool, device: torch.device):
|
||||
super().__init__()
|
||||
self.biases: dict[int, dict[int, float]] = {}
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
self.bias_tensor: torch.Tensor = torch.tensor(())
|
||||
self.logits_slice = (self._device_tensor([], torch.int32),
|
||||
self._device_tensor([], torch.int32))
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""Logit bias can rebalance token probabilities and change the
|
||||
outcome of argmax in greedy sampling."""
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||
if not batch_update:
|
||||
return
|
||||
|
||||
# Process added requests.
|
||||
needs_update = bool(batch_update.added)
|
||||
for index, params, _ in batch_update.added:
|
||||
if isinstance(params, SamplingParams) and (lb :=
|
||||
params.logit_bias):
|
||||
self.biases[index] = lb
|
||||
else:
|
||||
self.biases.pop(index, None)
|
||||
|
||||
if self.biases:
|
||||
# Process removed requests.
|
||||
for index in batch_update.removed:
|
||||
if self.biases.pop(index, None):
|
||||
needs_update = True
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and swap (a<->b)
|
||||
for a_index, b_index, direct in batch_update.moved:
|
||||
if direct == MoveDirectionality.UNIDIRECTIONAL:
|
||||
if (a_entry := self.biases.pop(a_index, None)) is None:
|
||||
if self.biases.pop(b_index, None) is not None:
|
||||
needs_update = True
|
||||
else:
|
||||
self.biases[b_index] = a_entry
|
||||
needs_update = True
|
||||
else:
|
||||
a_entry = self.biases.pop(a_index, None)
|
||||
if (b_entry := self.biases.pop(b_index, None)) is not None:
|
||||
self.biases[a_index] = b_entry
|
||||
needs_update = True
|
||||
if a_entry is not None:
|
||||
self.biases[b_index] = a_entry
|
||||
needs_update = True
|
||||
|
||||
# Update tensors if needed.
|
||||
if needs_update:
|
||||
reqs, tok_ids, biases = [], [], []
|
||||
for req, lb in self.biases.items():
|
||||
reqs.extend([req] * len(lb))
|
||||
tok_ids.extend(lb.keys())
|
||||
biases.extend(lb.values())
|
||||
|
||||
self.bias_tensor = self._device_tensor(biases, torch.float32)
|
||||
self.logits_slice = (self._device_tensor(reqs, torch.int32),
|
||||
self._device_tensor(tok_ids, torch.int32))
|
||||
|
||||
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (torch.tensor(data,
|
||||
device="cpu",
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory).to(device=self.device,
|
||||
non_blocking=True))
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.biases:
|
||||
logits[self.logits_slice] += self.bias_tensor
|
||||
return logits
|
||||
|
||||
|
||||
class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, pin_memory: bool, device: torch.device):
|
||||
# index -> (min_toks, output_token_ids, stop_token_ids)
|
||||
super().__init__()
|
||||
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
|
||||
# (req_idx_tensor,eos_tok_id_tensor)
|
||||
self.logits_slice: tuple[torch.Tensor,
|
||||
torch.Tensor] = (self._device_tensor(
|
||||
[], torch.int32),
|
||||
self._device_tensor(
|
||||
[], torch.int32))
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""By censoring stop tokens, min-tokens can change the outcome
|
||||
of the argmax operation in greedy sampling."""
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||
needs_update = False
|
||||
|
||||
if batch_update:
|
||||
# Process added requests.
|
||||
needs_update |= bool(batch_update.added)
|
||||
for index, params, output_tok_ids in batch_update.added:
|
||||
if (isinstance(params, SamplingParams)
|
||||
and (min_tokens := params.min_tokens)
|
||||
and len(output_tok_ids) < min_tokens):
|
||||
# Replace request metadata at batch index
|
||||
self.min_toks[index] = (min_tokens, output_tok_ids,
|
||||
params.all_stop_token_ids)
|
||||
else:
|
||||
# Drop request metadata at batch index
|
||||
self.min_toks.pop(index, None)
|
||||
|
||||
if self.min_toks:
|
||||
# Process removed requests.
|
||||
for index in batch_update.removed:
|
||||
if self.min_toks.pop(index, None):
|
||||
needs_update = True
|
||||
|
||||
# Process moved requests, unidirectional (a->b) and
|
||||
# swapped (a<->b)
|
||||
for a_index, b_index, direct in batch_update.moved:
|
||||
if direct == MoveDirectionality.UNIDIRECTIONAL:
|
||||
if (a_entry := self.min_toks.pop(a_index,
|
||||
None)) is None:
|
||||
if self.min_toks.pop(b_index, None) is not None:
|
||||
needs_update = True
|
||||
else:
|
||||
self.min_toks[b_index] = a_entry
|
||||
needs_update = True
|
||||
else:
|
||||
a_entry = self.min_toks.pop(a_index, None)
|
||||
if (b_entry := self.min_toks.pop(b_index,
|
||||
None)) is not None:
|
||||
self.min_toks[a_index] = b_entry
|
||||
needs_update = True
|
||||
if a_entry is not None:
|
||||
self.min_toks[b_index] = a_entry
|
||||
needs_update = True
|
||||
|
||||
if self.min_toks:
|
||||
# Check for any requests that have attained their min tokens.
|
||||
to_remove = tuple(index for index, (min_toks, out_tok_ids,
|
||||
_) in self.min_toks.items()
|
||||
if len(out_tok_ids) >= min_toks)
|
||||
if to_remove:
|
||||
needs_update = True
|
||||
for index in to_remove:
|
||||
del self.min_toks[index]
|
||||
|
||||
# Update tensors if needed.
|
||||
if needs_update:
|
||||
reqs: list[int] = []
|
||||
tok_ids: list[int] = []
|
||||
for req, (_, _, stop_tok_ids) in self.min_toks.items():
|
||||
reqs.extend([req] * len(stop_tok_ids))
|
||||
tok_ids.extend(stop_tok_ids)
|
||||
|
||||
self.logits_slice = (self._device_tensor(reqs, torch.int32),
|
||||
self._device_tensor(tok_ids, torch.int32))
|
||||
|
||||
def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (torch.tensor(data,
|
||||
device="cpu",
|
||||
dtype=dtype,
|
||||
pin_memory=self.pin_memory).to(device=self.device,
|
||||
non_blocking=True))
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.min_toks:
|
||||
# Inhibit EOS token for requests which have not reached min length
|
||||
logits[self.logits_slice] = -float("inf")
|
||||
return logits
|
||||
|
||||
|
||||
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
|
||||
device: torch.device) -> LogitsProcessorManager:
|
||||
"""Construct 'builtin' vLLM logitsprocs which the engine
|
||||
loads by default.
|
||||
|
||||
Args:
|
||||
pin_memory_available: pinned memory is available for use
|
||||
for use by logitsproc
|
||||
max_num_reqs: ceiling on request count in persistent batch
|
||||
device: inference device
|
||||
|
||||
Returns:
|
||||
Data structure encapsulating loaded logitsprocs
|
||||
"""
|
||||
min_tokens_logitproc = MinTokensLogitsProcessor(
|
||||
pin_memory=pin_memory_available, device=device)
|
||||
logit_bias_logitproc = LogitBiasLogitsProcessor(
|
||||
pin_memory=pin_memory_available, device=device)
|
||||
min_p_logitproc = MinPLogitsProcessor(
|
||||
pin_memory=pin_memory_available,
|
||||
device=device,
|
||||
# +1 for temporary swap space
|
||||
max_num_reqs=max_num_reqs + 1)
|
||||
return LogitsProcessorManager(
|
||||
non_argmax_invariant=[
|
||||
min_tokens_logitproc,
|
||||
logit_bias_logitproc,
|
||||
],
|
||||
argmax_invariant=[min_p_logitproc],
|
||||
)
|
||||
@ -6,6 +6,8 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessorManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class SamplingMetadata:
|
||||
@ -16,7 +18,6 @@ class SamplingMetadata:
|
||||
|
||||
top_p: Optional[torch.Tensor]
|
||||
top_k: Optional[torch.Tensor]
|
||||
min_p: Optional[torch.Tensor]
|
||||
|
||||
generators: dict[int, torch.Generator]
|
||||
|
||||
@ -31,14 +32,12 @@ class SamplingMetadata:
|
||||
|
||||
output_token_ids: list[list[int]]
|
||||
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
min_tokens: dict[int, tuple[int, set[int]]]
|
||||
|
||||
logit_bias: list[Optional[dict[int, float]]]
|
||||
|
||||
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
|
||||
# vocab size).
|
||||
allowed_token_ids_mask: Optional[torch.Tensor]
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
bad_words_token_ids: dict[int, list[list[int]]]
|
||||
|
||||
# Loaded logits processors
|
||||
logitsprocs: LogitsProcessorManager
|
||||
|
||||
@ -7,22 +7,6 @@ from vllm.model_executor.layers.utils import apply_penalties
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
|
||||
|
||||
def apply_min_token_penalties(
|
||||
logits: torch.Tensor, output_token_ids: list[list[int]],
|
||||
min_tokens: dict[int, tuple[int, set[int]]]) -> None:
|
||||
"""
|
||||
Applies minimum token penalty by setting the logits of the stop tokens
|
||||
to -inf.
|
||||
"""
|
||||
min_tokens_logits_to_penalize: list[tuple[int, int]] = []
|
||||
for index, (min_token, stop_token_ids) in min_tokens.items():
|
||||
if len(output_token_ids[index]) < min_token:
|
||||
for stop_token_id in stop_token_ids:
|
||||
min_tokens_logits_to_penalize.append((index, stop_token_id))
|
||||
if min_tokens_logits_to_penalize:
|
||||
logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf")
|
||||
|
||||
|
||||
def apply_all_penalties(
|
||||
logits: torch.Tensor,
|
||||
prompt_token_ids: torch.Tensor,
|
||||
|
||||
@ -5,12 +5,11 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.utils import async_tensor_h2d, is_pin_memory_available
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.bad_words import apply_bad_words
|
||||
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
|
||||
apply_min_token_penalties)
|
||||
from vllm.v1.sample.ops.penalties import apply_all_penalties
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
@ -44,8 +43,11 @@ class Sampler(nn.Module):
|
||||
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
|
||||
# Apply bad words exclusion.
|
||||
logits = self.apply_bad_words(logits, sampling_metadata)
|
||||
# Apply logits bias.
|
||||
logits = self.apply_logits_bias(logits, sampling_metadata)
|
||||
|
||||
# Apply logits processors which can impact greedy sampling
|
||||
for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply penalties (e.g., min_tokens, freq_penalties).
|
||||
logits = self.apply_penalties(logits, sampling_metadata)
|
||||
# Sample the next token.
|
||||
@ -110,9 +112,10 @@ class Sampler(nn.Module):
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply min_p.
|
||||
if sampling_metadata.min_p is not None:
|
||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||
# Apply logits processors that only apply to random sampling
|
||||
# (argmax invariant)
|
||||
for processor in sampling_metadata.logitsprocs.argmax_invariant:
|
||||
logits = processor.apply(logits)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled = self.topk_topp_sampler(
|
||||
@ -187,10 +190,6 @@ class Sampler(nn.Module):
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
if sampling_metadata.min_tokens:
|
||||
apply_min_token_penalties(logits,
|
||||
sampling_metadata.output_token_ids,
|
||||
sampling_metadata.min_tokens)
|
||||
if not sampling_metadata.no_penalties:
|
||||
assert sampling_metadata.prompt_token_ids is not None
|
||||
logits = apply_all_penalties(
|
||||
@ -203,65 +202,6 @@ class Sampler(nn.Module):
|
||||
)
|
||||
return logits
|
||||
|
||||
def apply_min_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
min_p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Filters logits using adaptive probability thresholding.
|
||||
"""
|
||||
# Convert logits to probability distribution
|
||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
||||
# Calculate maximum probabilities per sequence
|
||||
max_probabilities = torch.amax(probability_values,
|
||||
dim=-1,
|
||||
keepdim=True)
|
||||
# Reshape min_p for broadcasting
|
||||
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
|
||||
# Identify valid tokens using threshold comparison
|
||||
valid_token_mask = probability_values >= adjusted_min_p
|
||||
# Apply mask using boolean indexing
|
||||
logits[~valid_token_mask] = -float('inf')
|
||||
return logits
|
||||
|
||||
def apply_logits_bias(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
# TODO(houseroad): this implementation is extremely inefficient.
|
||||
# One idea is implement this as a PyTorch C++ op, and we may
|
||||
# even optimize the logit_bias layout.
|
||||
|
||||
rows: list[int] = []
|
||||
cols: list[int] = []
|
||||
vals: list[float] = []
|
||||
|
||||
# Get vocabulary size from logits
|
||||
vocab_size = logits.shape[-1]
|
||||
|
||||
for i, logit_bias in enumerate(sampling_metadata.logit_bias):
|
||||
if logit_bias:
|
||||
for token_id, bias in logit_bias.items():
|
||||
# Check token_id bounds to ensure within vocabulary
|
||||
if token_id < 0 or token_id >= vocab_size:
|
||||
raise ValueError(
|
||||
f"token_id {token_id} in logit_bias contains "
|
||||
f"out-of-vocab token id. Vocabulary size: "
|
||||
f"{vocab_size}")
|
||||
rows.append(i)
|
||||
cols.append(token_id)
|
||||
vals.append(bias)
|
||||
|
||||
if rows:
|
||||
indices = async_tensor_h2d([rows, cols], torch.int64,
|
||||
logits.device, self.pin_memory)
|
||||
values = async_tensor_h2d(vals, torch.float, logits.device,
|
||||
self.pin_memory)
|
||||
logits.index_put_(tuple(indices), values=values, accumulate=True)
|
||||
return logits
|
||||
|
||||
def apply_allowed_token_ids(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
|
||||
@ -1,23 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
|
||||
if req_id in input_batch.min_p_reqs:
|
||||
# Spec decode doesn't support min_p sampling.
|
||||
return False
|
||||
elif (req_id in input_batch.frequency_penalties_reqs
|
||||
or req_id in input_batch.presence_penalties_reqs
|
||||
or req_id in input_batch.repetition_penalties_reqs):
|
||||
# Spec decode doesn't support penalties.
|
||||
return False
|
||||
elif req_id in input_batch.num_logprobs:
|
||||
# Spec decode doesn't support logprobs.
|
||||
return False
|
||||
|
||||
return True
|
||||
def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
|
||||
"""True if request is incompatible with speculative decoding"""
|
||||
return (sampling_params.frequency_penalty != 0.0
|
||||
or sampling_params.presence_penalty != 0.0
|
||||
or sampling_params.repetition_penalty != 1.0
|
||||
or sampling_params.min_p > _SAMPLING_EPS
|
||||
or sampling_params.logprobs is not None)
|
||||
|
||||
|
||||
@triton.jit
|
||||
|
||||
@ -15,12 +15,14 @@ from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
||||
MoveDirectionality,
|
||||
init_builtin_logitsprocs)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
@ -67,8 +69,10 @@ class InputBatch:
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
is_spec_decode: bool = False,
|
||||
logits_processing_needs_token_ids: bool = False,
|
||||
):
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
@ -146,15 +150,8 @@ class InputBatch:
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: set[str] = set()
|
||||
|
||||
self.min_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
self.min_p_reqs: set[str] = set()
|
||||
# IDs of requests which do not support spec decoding
|
||||
self.spec_decode_unsupported_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
||||
@ -194,9 +191,6 @@ class InputBatch:
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
dtype=np.int32)
|
||||
@ -216,8 +210,20 @@ class InputBatch:
|
||||
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
||||
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
||||
|
||||
self.logit_bias: list[Optional[dict[int,
|
||||
float]]] = [None] * max_num_reqs
|
||||
# Internal representation of per-step batch state changes, used for
|
||||
# reordering persistent batch and generating logitsprocs batch state
|
||||
# updates. Should reset each step.
|
||||
self.batch_update_builder = BatchUpdateBuilder()
|
||||
|
||||
# Define logits processors.
|
||||
# TODO(andy): logits processor list should be extensible via engine
|
||||
# constructor argument; for now the list is fixed.
|
||||
self.logitsprocs = init_builtin_logitsprocs(
|
||||
pin_memory_available=pin_memory,
|
||||
max_num_reqs=max_num_reqs + 1,
|
||||
device=device)
|
||||
|
||||
# TODO convert this to LogitsProcessor
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
# the value is False. Since we use masked_fill_ to set -inf.
|
||||
@ -240,14 +246,28 @@ class InputBatch:
|
||||
# while performing state updates to the batch.
|
||||
return cast(list[str], self._req_ids)
|
||||
|
||||
def _get_next_add_index(self) -> int:
|
||||
if (req_index := self.batch_update_builder.pop_removed()) is not None:
|
||||
# Fill the empty index.
|
||||
return req_index
|
||||
# Append to end
|
||||
return self.num_reqs
|
||||
|
||||
def _register_add_request(self, request: "CachedRequestState") -> int:
|
||||
"""Track add-request operations"""
|
||||
req_index = self._get_next_add_index()
|
||||
assert req_index < self.max_num_reqs
|
||||
params = (request.sampling_params
|
||||
if request.sampling_params else request.pooling_params)
|
||||
self.batch_update_builder.added.append(
|
||||
(req_index, params, request.output_token_ids))
|
||||
return req_index
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
req_index: Optional[int] = None,
|
||||
) -> None:
|
||||
if req_index is None:
|
||||
req_index = self.num_reqs
|
||||
assert req_index < self.max_num_reqs
|
||||
) -> int:
|
||||
req_index = self._register_add_request(request)
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
@ -278,6 +298,9 @@ class InputBatch:
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if (self.is_spec_decode
|
||||
and is_spec_decode_unsupported(sampling_params)):
|
||||
self.spec_decode_unsupported_reqs.add(req_id)
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
self.temperature_cpu[req_index] = -1.0
|
||||
@ -295,11 +318,8 @@ class InputBatch:
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
@ -310,10 +330,6 @@ class InputBatch:
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (
|
||||
sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
@ -325,8 +341,6 @@ class InputBatch:
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[
|
||||
req_id] = sampling_params.prompt_logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
@ -368,12 +382,22 @@ class InputBatch:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
return req_index
|
||||
|
||||
def remove_request(self, req_id: str) -> Optional[int]:
|
||||
"""This method must always be followed by a call to condense()."""
|
||||
"""This method must always be followed by a call to condense().
|
||||
|
||||
Args:
|
||||
req_id: request to remove
|
||||
|
||||
Returns:
|
||||
Removed request index, or `None` if `req_id` not recognized
|
||||
"""
|
||||
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
@ -381,8 +405,7 @@ class InputBatch:
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.min_p_reqs.discard(req_id)
|
||||
self.min_tokens.pop(req_index, None)
|
||||
self.spec_decode_unsupported_reqs.discard(req_id)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
@ -400,7 +423,6 @@ class InputBatch:
|
||||
self.lora_id_to_lora_request.pop(lora_id)
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
self.logit_bias[req_index] = None
|
||||
self.has_allowed_token_ids.discard(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
@ -410,6 +432,8 @@ class InputBatch:
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
self.batch_update_builder.moved.append(
|
||||
(i1, i2, MoveDirectionality.SWAP))
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] =\
|
||||
@ -439,8 +463,6 @@ class InputBatch:
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
||||
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
@ -452,13 +474,10 @@ class InputBatch:
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.min_tokens, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
self.logit_bias[i1], self.logit_bias[i2] =\
|
||||
self.logit_bias[i2], self.logit_bias[i1]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
@ -467,12 +486,23 @@ class InputBatch:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
def condense(self, empty_req_indices: list[int]) -> None:
|
||||
"""Move non-empty requests down into lower, empty indices.
|
||||
|
||||
def condense(self) -> None:
|
||||
"""Slide non-empty requests down into lower, empty indices.
|
||||
|
||||
Any consecutive empty indices at the very end of the list are not
|
||||
filled.
|
||||
|
||||
Args:
|
||||
empty_req_indices: empty batch indices, sorted descending.
|
||||
empty_req_indices: empty indices which may be filled.
|
||||
|
||||
Returns:
|
||||
swaps: list of (from,to) swap tuples for moved requests
|
||||
empty_req_indices: indices not filled by condensation
|
||||
"""
|
||||
if not (empty_req_indices := self.batch_update_builder.removed):
|
||||
# All removed requests were replaced by added requests, or else no
|
||||
# requests were removed at all. No condense() needed
|
||||
return
|
||||
num_reqs = self.num_reqs
|
||||
if num_reqs == 0:
|
||||
# The batched states are empty.
|
||||
@ -489,11 +519,17 @@ class InputBatch:
|
||||
last_req_index -= 1
|
||||
|
||||
# Find the smallest empty index.
|
||||
empty_index = empty_req_indices.pop()
|
||||
empty_index = self.batch_update_builder.peek_removed()
|
||||
assert empty_index is not None
|
||||
if empty_index >= last_req_index:
|
||||
break
|
||||
|
||||
# Swap the states.
|
||||
# Move active request down into empty request
|
||||
# index.
|
||||
self.batch_update_builder.pop_removed()
|
||||
self.batch_update_builder.moved.append(
|
||||
(last_req_index, empty_index,
|
||||
MoveDirectionality.UNIDIRECTIONAL))
|
||||
req_id = self._req_ids[last_req_index]
|
||||
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||
assert req_id is not None
|
||||
@ -524,20 +560,14 @@ class InputBatch:
|
||||
empty_index] = self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
min_token = self.min_tokens.pop(last_req_index, None)
|
||||
if min_token is not None:
|
||||
self.min_tokens[empty_index] = min_token
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
||||
|
||||
# TODO convert these to LogitsProcessors
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[
|
||||
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
||||
@ -547,6 +577,7 @@ class InputBatch:
|
||||
last_req_index, None)
|
||||
if bad_words_token_ids is not None:
|
||||
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
||||
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
@ -554,8 +585,17 @@ class InputBatch:
|
||||
del self._req_ids[self.num_reqs:]
|
||||
del self.req_output_token_ids[self.num_reqs:]
|
||||
|
||||
def refresh_sampling_metadata(self):
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
def refresh_metadata(self):
|
||||
"""Apply batch updates, reset input batch at end of step
|
||||
|
||||
* Apply batch add/remove/permute to logits procs' states
|
||||
* If batch state is modified, update sampling metadata
|
||||
"""
|
||||
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
|
||||
for logit_proc in self.logitsprocs.all:
|
||||
logit_proc.update_state(batch_update)
|
||||
if batch_update:
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||
num_reqs = self.num_reqs
|
||||
@ -568,8 +608,6 @@ class InputBatch:
|
||||
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||
if not self.no_top_k:
|
||||
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
||||
if not self.no_min_p:
|
||||
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
|
||||
|
||||
if not self.no_penalties:
|
||||
# Since syncing these tensors is expensive only copy them
|
||||
@ -607,7 +645,6 @@ class InputBatch:
|
||||
all_random=self.all_random,
|
||||
top_p=None if self.no_top_p else self.top_p[:num_reqs],
|
||||
top_k=None if self.no_top_k else self.top_k[:num_reqs],
|
||||
min_p=None if self.no_min_p else self.min_p[:num_reqs],
|
||||
generators=self.generators,
|
||||
max_num_logprobs=self.max_num_logprobs,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
@ -615,11 +652,10 @@ class InputBatch:
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||
min_tokens=self.min_tokens,
|
||||
no_penalties=self.no_penalties,
|
||||
logit_bias=self.logit_bias[:num_reqs],
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
logitsprocs=self.logitsprocs,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -702,10 +738,6 @@ class InputBatch:
|
||||
def no_top_k(self) -> bool:
|
||||
return len(self.top_k_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_min_p(self) -> bool:
|
||||
return len(self.min_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_penalties(self) -> bool:
|
||||
return (len(self.presence_penalties_reqs) == 0
|
||||
|
||||
@ -63,12 +63,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
from ..sample.logits_processor import LogitsProcessorManager
|
||||
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
|
||||
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
||||
|
||||
@ -212,6 +212,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_sizes=[self.cache_config.block_size],
|
||||
is_spec_decode=bool(self.vllm_config.speculative_config),
|
||||
)
|
||||
|
||||
self.use_cuda_graph = (
|
||||
@ -316,7 +317,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# from the KV cache of `shared_kv_cache_layers[layer_name]`.
|
||||
self.shared_kv_cache_layers: dict[str, str] = {}
|
||||
|
||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
|
||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""
|
||||
Update the order of requests in the batch based on the attention
|
||||
backend's needs. For example, some attention backends (namely MLA) may
|
||||
@ -325,21 +326,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
Args:
|
||||
scheduler_output: The scheduler output.
|
||||
|
||||
Returns:
|
||||
True if the batch was reordered, False otherwise.
|
||||
"""
|
||||
batch_reordered = self.attn_metadata_builders[0].reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
|
||||
scheduler_output)
|
||||
|
||||
# For models with multiple KV cache groups, the groups should agree on
|
||||
# the same order of requests. We ensure this by only allowing the first
|
||||
# group to reorder the batch and asserting that all other groups do not
|
||||
# reorder the batch.
|
||||
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
|
||||
assert not self.attn_metadata_builders[i].reorder_batch(
|
||||
batch_reordered = self.attn_metadata_builders[i].reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
return batch_reordered
|
||||
assert not batch_reordered
|
||||
|
||||
# Note: used for model runner override.
|
||||
def _init_device_properties(self) -> None:
|
||||
@ -372,11 +370,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# then resubmitted with the same ID. In this case, we treat them as two
|
||||
# distinct requests - clearing the cached states for the first request
|
||||
# and handling the second as a new request.
|
||||
removed_req_indices: list[int] = []
|
||||
for req_id in scheduler_output.finished_req_ids:
|
||||
req_index = self.input_batch.remove_request(req_id)
|
||||
if req_index is not None:
|
||||
removed_req_indices.append(req_index)
|
||||
self.input_batch.remove_request(req_id)
|
||||
|
||||
# Free the cached encoder outputs.
|
||||
for req_id, input_id in scheduler_output.free_encoder_input_ids:
|
||||
@ -399,9 +394,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# have low request overlap (e.g., alternating between two distinct
|
||||
# sets of requests), this optimization becomes very inefficient.
|
||||
for req_id in unscheduled_req_ids:
|
||||
req_index = self.input_batch.remove_request(req_id)
|
||||
assert req_index is not None
|
||||
removed_req_indices.append(req_index)
|
||||
self.input_batch.remove_request(req_id)
|
||||
|
||||
req_ids_to_add: list[str] = []
|
||||
# Add new requests to the cached states.
|
||||
@ -545,31 +538,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# NOTE(woosuk): `num_tokens` here may include spec tokens.
|
||||
self.input_batch.num_tokens[req_index] = end_token_index
|
||||
|
||||
# Check if the batch has changed. If not, we can skip copying the
|
||||
# sampling metadata from CPU to GPU.
|
||||
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
removed_req_indices.sort(reverse=True)
|
||||
for req_id in req_ids_to_add:
|
||||
req_state = self.requests[req_id]
|
||||
if removed_req_indices:
|
||||
# Fill the empty index.
|
||||
req_index = removed_req_indices.pop()
|
||||
else:
|
||||
# Append to the end.
|
||||
req_index = None
|
||||
self.input_batch.add_request(req_state, req_index)
|
||||
self.input_batch.add_request(req_state)
|
||||
|
||||
# Condense the batched states if there are empty indices.
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
batch_reordered = self._may_reorder_batch(scheduler_output)
|
||||
|
||||
if batch_changed or batch_reordered:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
# Condense the batched states if there are gaps left by removed requests
|
||||
self.input_batch.condense()
|
||||
# Allow attention backend to reorder the batch, potentially
|
||||
self._may_reorder_batch(scheduler_output)
|
||||
# Refresh batch metadata with any pending updates.
|
||||
self.input_batch.refresh_metadata()
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
self,
|
||||
@ -1296,7 +1276,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[ModelRunnerOutput, IntermediateTensors]:
|
||||
|
||||
self._update_states(scheduler_output)
|
||||
if not scheduler_output.total_num_scheduled_tokens:
|
||||
if not has_kv_transfer_group():
|
||||
@ -1765,7 +1744,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Skip requests that require sampling parameters that are not
|
||||
# supported with speculative decoding.
|
||||
req_id = self.input_batch.req_ids[i]
|
||||
if not is_spec_decode_supported(req_id, self.input_batch):
|
||||
if req_id in self.input_batch.spec_decode_unsupported_reqs:
|
||||
draft_token_ids.append([])
|
||||
continue
|
||||
|
||||
@ -2121,7 +2100,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
all_random=False,
|
||||
top_p=dummy_tensors(0.9),
|
||||
top_k=dummy_tensors(logits.size(1) - 1),
|
||||
min_p=None,
|
||||
generators={},
|
||||
max_num_logprobs=None,
|
||||
no_penalties=True,
|
||||
@ -2130,10 +2108,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
presence_penalties=dummy_tensors(0.1),
|
||||
repetition_penalties=dummy_tensors(0.1),
|
||||
output_token_ids=[[] for _ in range(num_reqs)],
|
||||
min_tokens={},
|
||||
logit_bias=[None for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
logitsprocs=LogitsProcessorManager(),
|
||||
)
|
||||
try:
|
||||
sampler_output = self.sampler(logits=logits,
|
||||
@ -2425,6 +2402,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_sizes=block_sizes,
|
||||
is_spec_decode=bool(self.vllm_config.speculative_config),
|
||||
)
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user