diff --git a/tests/v1/sample/test_logits_processors.py b/tests/v1/sample/test_logits_processors.py new file mode 100644 index 0000000000000..a8e230a97ed54 --- /dev/null +++ b/tests/v1/sample/test_logits_processors.py @@ -0,0 +1,626 @@ +# SPDX-License-Identifier: Apache-2.0 + +import random +from collections.abc import Callable +from typing import NamedTuple, Optional, Union + +import numpy as np +import pytest +import torch + +from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits, + create_penalty_tensor, + create_prompt_tokens_tensor, + fake_apply_logitsprocs, + fake_update_logitsprocs_state) +from vllm.platforms import current_platform +from vllm.sampling_params import SamplingParams +from vllm.utils import is_pin_memory_available +# yapf: disable +from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder, + LogitBiasLogitsProcessor, + LogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + MoveDirectionality, + init_builtin_logitsprocs) +# yapf: enable +from vllm.v1.sample.metadata import SamplingMetadata + +PIN_MEMORY_AVAILABLE = is_pin_memory_available() +MAX_NUM_REQS = 256 +VOCAB_SIZE = 1024 +NUM_OUTPUT_TOKENS = 20 +CUDA_DEVICES = [ + f"{current_platform.device_type}:{i}" + for i in range(1 if current_platform.device_count() == 1 else 2) +] +MAX_NUM_PROMPT_TOKENS = 64 +MIN_TOKENS_LEN_THRESHOLD = 5 +REQS_PER_LOGITPROC = 50 +STR_NO_LOGITPROC = "none" + +# LogitsProcessor subclass or "none" +LogitprocType = Union[type[LogitsProcessor], str] + + +class LogitsProcsRequestParams: + """Encapsulates key params for a single request in a batch. + + Params can be customized based on the enabled logitproc + """ + workload_index: int + logitproc_type: LogitprocType # Logitproc enabled, specified by str id + out_tokens: list[int] # Output tokens required for min tokens test + params: SamplingParams # Settings customized for logitproc + + def __init__(self, workload_index: int, logitproc_type: LogitprocType): + self.workload_index = workload_index + self.logitproc_type = logitproc_type + # Number of output tokens is randomly 0 or twice the min-tokens + # threshold which will be used in testing. 
Output token values + # don't matter *for these tests* so use 0 as a dummy value + self.out_tokens = ([0] * + (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))) + self.params = _sampling_params_from_logitproc(logitproc_type) + + def __str__(self): + """For debugging""" + summ = ', '.join(f'{k}={v}' for k, v in vars(self).items()) + return f"MyClass({summ})" + + +def _generate_fake_sampling_metadata( + num_output_tokens: int, + batch_size: int, + vocab_size: int, + device: torch.device, +) -> SamplingMetadata: + """Generate fake sampling metadata with fake logitsprocs""" + output_token_ids: list[list[int]] = [] + prompt_token_ids: list[list[int]] = [] + for _ in range(batch_size): + output_token_ids.append( + np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) + prompt_token_ids.append( + np.random.randint(0, + vocab_size, + size=np.random.randint( + 1, MAX_NUM_PROMPT_TOKENS)).tolist()) + logitsprocs = init_builtin_logitsprocs( + pin_memory_available=PIN_MEMORY_AVAILABLE, + max_num_reqs=MAX_NUM_REQS + 1, + device=device) + + fake_sampling_metadata = SamplingMetadata( + temperature=torch.full((batch_size, ), 0.0), + all_greedy=True, + all_random=False, + top_p=None, + top_k=None, + generators={}, + max_num_logprobs=0, + prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids, + vocab_size, device), + output_token_ids=output_token_ids, + frequency_penalties=create_penalty_tensor(batch_size, 0.0, device), + presence_penalties=create_penalty_tensor(batch_size, 0.0, device), + repetition_penalties=create_penalty_tensor(batch_size, 1.0, device), + no_penalties=True, + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=logitsprocs) + return fake_sampling_metadata + + +def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes: + """Generate fake logits and sampling metadata""" + fake_logits = create_fake_logits(batch_size, VOCAB_SIZE) + # Create one dominant token per batch, to support min-p test + for i in range(batch_size): + fake_logits[i, 0] = 10.0 # High logit for first token + fake_logits[i, 1:] = 1e-2 # Others remain low + sampling_metadata = _generate_fake_sampling_metadata( + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + return LogitsprocsTestFakes( + logits=fake_logits, + sampling_metadata=sampling_metadata, + ) + + +def _sampling_params_from_logitproc( + logitproc_type: LogitprocType) -> SamplingParams: + """Customize request SamplingParams for a specified logitproc""" + # SamplingParams for req with no logitproc + kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0} + if fxn := logitsprocs_test_mapping[logitproc_type].gen_request_fxn: + fxn(kwargs) + return SamplingParams(**kwargs) + + +def _generate_mixed_logitsprocs_batch_params( + reqs_per_logitproc: int, + logitsprocs_types: list[str], +) -> list[LogitsProcsRequestParams]: + """Define key params for a batch of requests with a different + logitproc enabled per request. + + The batch will have `reqs_per_logitproc` repeats for all + `logitsprocs_types` under test, including the case where + no logitsproc is enabled. The batch is randomly shuffled. 
The + size of the batch is `reqs_per_logitproc` times + `n = len(logitsprocs_types)` + + Args: + reqs_per_logitproc: number of requests using each logitproc + logitsprocs_types: logitsprocs under test + + Returns: + List of per-request params which configure the engine for that request's + enabled logitproc + """ + batch_size = len(logitsprocs_types) * reqs_per_logitproc + # Generate multiple repeats of key params for each logitproc; + # apply random inverse permutation to the iteration + # over logitsprocs, such that logitsprocs are shuffled. + batch_perm = random.sample(range(batch_size), k=batch_size) + return [ + LogitsProcsRequestParams( + workload_index=idx, + logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc]) + for idx, pdx in enumerate(batch_perm) + ] + + +def _raise_error_invalid( + msg_suffix: str, + batch_index: int, + request_params: LogitsProcsRequestParams, + step_idx: int, + err_cls: type[Exception] = ValueError, +) -> None: + raise err_cls(f"Validation failed for step={step_idx}, " + f"batch_index={batch_index}, " + f"workload_index={request_params.workload_index}, " + f"req_params={request_params}. Reason: {msg_suffix}") + + +def _logit_bias_params(kwargs: dict) -> None: + """Logit bias config""" + kwargs["logit_bias"] = { + random.randint(0, VOCAB_SIZE - 1): random.choice([-0.1, 0.2]) + } + + +def _logit_bias_validate( + test_fakes: LogitsprocsTestFakes, + persistent_batch: list[LogitsProcsRequestParams], + logits_new: torch.Tensor, + batch_index: int, + request_params: LogitsProcsRequestParams, + step_idx: int, +) -> None: + """Validate logit bias logitproc applied correctly""" + logit_bias = request_params.params.logit_bias + logits_old = ( + test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()) + logits_new = logits_new[batch_index].cpu() + for token_id in range(VOCAB_SIZE): + logit_old_value = logits_old[token_id] + logit_new_value = logits_new[token_id] + if token_id in logit_bias: + bias_value = logit_bias[token_id] + exp_value = bias_value + logit_old_value + if logit_new_value != pytest.approx(exp_value): + _raise_error_invalid(msg_suffix=( + f"Biased token {token_id} logit value {logit_new_value} " + f"does not match expected value {exp_value} " + f"given bias {bias_value}"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + else: + if logit_new_value != pytest.approx(logit_old_value): + _raise_error_invalid(msg_suffix=( + f"Unbiased token {token_id} logit value {logit_new_value} " + f"does not match expected value {logit_old_value}"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + +def _min_p_params(kwargs: dict) -> None: + """Min-p logitproc config""" + kwargs["min_p"] = 0.1 + + +def _min_p_validate( + test_fakes: LogitsprocsTestFakes, + persistent_batch: list[LogitsProcsRequestParams], + logits_new: torch.Tensor, + batch_index: int, + request_params: LogitsProcsRequestParams, + step_idx: int, +) -> None: + """Validate min-p logitproc applied correctly""" + for token_id in range(VOCAB_SIZE): + logits_for_token = logits_new[batch_index][token_id] + if token_id == 0: + # Dominant token should always be unmasked + if logits_for_token == -float("inf"): + _raise_error_invalid( + msg_suffix="Invalid: dominant token 0 masked (-inf)", + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + else: + if request_params.params.min_p > 0.0: + # Non-dominant tokens should be masked when min_p > 0 + if logits_for_token != -float("inf"): + 
_raise_error_invalid( + msg_suffix= + f"Invalid: non-dominant token {token_id} not masked", + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + else: + # No masking when min_p is 0 + if logits_for_token == -float("inf"): + _raise_error_invalid( + msg_suffix= + f"Invalid: token {token_id} masked when min_p=0.0", + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + +def _min_tokens_params(kwargs: dict) -> None: + """Min-tokens logitproc config""" + kwargs["min_tokens"] = MIN_TOKENS_LEN_THRESHOLD + kwargs["stop_token_ids"] = [ + np.random.randint(0, VOCAB_SIZE - 1) + for _ in range(np.random.randint(0, VOCAB_SIZE)) + ] + + +def _min_tokens_validate( + test_fakes: LogitsprocsTestFakes, + persistent_batch: list[LogitsProcsRequestParams], + logits_new: torch.Tensor, + batch_index: int, + request_params: LogitsProcsRequestParams, + step_idx: int, +) -> None: + """Validate min-tokens logitsproc applied correctly""" + ref_num_out_tokens = len(request_params.out_tokens) + min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD + ref_all_stop_token_ids = request_params.params.all_stop_token_ids + mt_lp: MinTokensLogitsProcessor = next( + test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor)) + assert isinstance(mt_lp, MinTokensLogitsProcessor) + min_tok = mt_lp.min_toks.get(batch_index, None) + + # Validate min-token logits processor state + if min_tok: + (_, out_tok, all_stop_token_ids) = min_tok + num_out_tokens = len(out_tok) + if num_out_tokens != ref_num_out_tokens: + _raise_error_invalid(msg_suffix=( + "Number of output tokens in min-token logit processor " + f"request metadata ({num_out_tokens}) does not match " + f"reference ({ref_num_out_tokens})."), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + if ref_all_stop_token_ids != all_stop_token_ids: + _raise_error_invalid(msg_suffix=( + "Stop token ids do not match reference; all_stop_token_ids: " + f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: " + f"{sorted(ref_all_stop_token_ids)}"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + if min_reached: + _raise_error_invalid(msg_suffix=( + "Expected min-tokens request with min reached, but batch " + "index is recognized by min-tokens logits processor."), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + err_cls=RuntimeError) + + elif not min_reached: + _raise_error_invalid(msg_suffix=( + "Expected min-tokens request with min not reached, but batch " + "index is not recognized by min-tokens logits processor."), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + err_cls=RuntimeError) + + # Validate min-token logits + for token_id in range(VOCAB_SIZE): + logits_for_token = logits_new[batch_index][token_id] + if token_id in ref_all_stop_token_ids and not min_reached: + if logits_for_token != -float("inf"): + _raise_error_invalid( + msg_suffix=(f"Token {token_id} is a stop token and " + "the sequence has not reached min length, " + "but the token is not masked " + f"(logit={logits_for_token})"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + else: + if logits_for_token == -float("inf"): + _raise_error_invalid( + msg_suffix=(f"Token {token_id} should not be masked but " + f"is (output len={ref_num_out_tokens})"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + +def _none_validate( + test_fakes: LogitsprocsTestFakes, + 
persistent_batch: list[LogitsProcsRequestParams], + logits_new: torch.Tensor, + batch_index: int, + request_params: LogitsProcsRequestParams, + step_idx: int, +) -> None: + """Validate that no logits processors are applied""" + logits = ( + test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()) + ref_logits = logits_new[batch_index] + if not torch.all(ref_logits == logits): + mismatch_toks = (ref_logits + != logits).nonzero(as_tuple=True)[0].tolist() + mismatch_strs = [] + for token in mismatch_toks: + val = float(logits[token]) + ref_val = float(ref_logits[token]) + mismatch_strs.append(f"({token=},{val=},{ref_val=})") + _raise_error_invalid(msg_suffix=( + f"Unexpected modification of logits: {','.join(mismatch_strs)}"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + +class LogitsprocTestHelpers(NamedTuple): + """Supports setting up and validating logitsprocs unit tests.""" + eval_fxn: Callable + gen_request_fxn: Optional[Callable] = None + + +logitsprocs_test_mapping = { + STR_NO_LOGITPROC: + LogitsprocTestHelpers(eval_fxn=_none_validate), + LogitBiasLogitsProcessor: + LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params, + eval_fxn=_logit_bias_validate), + MinPLogitsProcessor: + LogitsprocTestHelpers(gen_request_fxn=_min_p_params, + eval_fxn=_min_p_validate), + MinTokensLogitsProcessor: + LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params, + eval_fxn=_min_tokens_validate), +} + + +def _get_test_cases() -> list[list[str]]: + """Each test case is a set of logitsprocs""" + logitsprocs_types = list(logitsprocs_test_mapping.keys()) + return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC] + for logitproc_type in logitsprocs_types + if logitproc_type != STR_NO_LOGITPROC + ] + [logitsprocs_types] + + +def _generate_fake_step_update( + persistent_batch: list[LogitsProcsRequestParams], + workload_params: list[LogitsProcsRequestParams], + wdx: int, + batch_update_builder: BatchUpdateBuilder, +) -> tuple[Optional[BatchUpdate], int, int]: + batch_size = len(persistent_batch) + workload_size = len(workload_params) + workload_reqs_remaining = workload_size - wdx + max_add_remove_per_step = max(1, int(0.2 * workload_size)) + + # 50% of steps: add no reqs + # Other 50%: add a limited number of reqs (less than the number + # of workload reqs remaining, less than an arbitrary max) + # If no workload reqs remain: 100% of steps have 0 adds + num_step_add = random.choice([ + 0, + random.randint(1, min(max_add_remove_per_step, + workload_reqs_remaining)) + ]) if workload_reqs_remaining else 0 + + # 50% of steps: remove no requests + # Other 50%: remove a limited number of reqs (less than the number + # persistent batch reqs remaining, less than an arbitrary max) + # If persistent batch is empty: 100% of steps have 0 removals until + # more requests are added. 
Assume that removed requests are always + # drawn from the current batch, before new adds + num_step_remove = random.choice([ + 0, random.randint(1, min(max_add_remove_per_step, batch_size)) + ]) if batch_size else 0 + + num_step_add_replace = min(num_step_add, num_step_remove) + + # Generate fake removed request indices drawn from persistent batch indices + for removal in random.sample(range(batch_size), num_step_remove): + batch_update_builder.removed_append(removal) + + # Get added requests from workload + for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]: + # Replace as many removed requests as possible with added requests + add_remove_idx = batch_update_builder.pop_removed() + batch_update_builder.added.append( + (add_remove_idx, add_req_params.params, add_req_params.out_tokens)) + persistent_batch[add_remove_idx] = add_req_params + + # Append remaining added requests to end of batch + add_reqs_append = workload_params[(wdx + + num_step_add_replace):(wdx + + num_step_add)] + batch_update_builder.added.extend([ + (adx + batch_size, add_req_params.params, add_req_params.out_tokens) + for adx, add_req_params in enumerate(add_reqs_append) + ]) + persistent_batch.extend(add_reqs_append) + pre_condense_batch_size = len(persistent_batch) + wdx += num_step_add # Update workload offset + + # Simulate condensing persistent batch + last_nonempty_index = pre_condense_batch_size - 1 + condensed_to_idxs = set() + while batch_update_builder.removed: + if (last_nonempty_index in batch_update_builder.removed + or last_nonempty_index in condensed_to_idxs): + last_nonempty_index -= 1 + continue + # last_nonempty_index is the highest persistent batch index that was + # not removed + first_empty_index = batch_update_builder.peek_removed() + assert first_empty_index is not None + if first_empty_index > last_nonempty_index: + break + # first_empty_index is the lowest removed persistent batch index + # that is less than last_nonempty_index + # + # move last_nonempty_index -> first_empty_index + batch_update_builder.pop_removed() + condensed_to_idxs.add(first_empty_index) + persistent_batch[first_empty_index] = persistent_batch[ + last_nonempty_index] + batch_update_builder.moved.append( + (last_nonempty_index, first_empty_index, + MoveDirectionality.UNIDIRECTIONAL)) + + last_nonempty_index -= 1 + + # Now removed requests & gaps left by non-removed requests that got + # moved downward are grouped consecutively in the upper indices of + # the persistent batch. 
Truncate them to get condensed persistent batch + condensed_batch_size = batch_size + num_step_add - num_step_remove + persistent_batch[:] = persistent_batch[0:condensed_batch_size] + + if condensed_batch_size > 1: + # Simulate arbitrary reorder_batch() in the kernel backend + # Generate a random number k of non-overlapping swap tuples + k = random.randint(0, condensed_batch_size // 2) + idxs = list(range(condensed_batch_size)) + random.shuffle(idxs) + swaps = [ + tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k) + ] + batch_update_builder.moved.extend([ + (sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps + ]) + for adx, bdx in swaps: + persistent_batch[adx], persistent_batch[bdx] = persistent_batch[ + bdx], persistent_batch[adx] + + return (batch_update_builder.get_and_reset(condensed_batch_size), wdx, + workload_size - wdx) + + +def _assert_valid( + batch_size: int, + persistent_batch: list[LogitsProcsRequestParams], + test_fakes: LogitsprocsTestFakes, + slice_idxs: list[int], + logits_w_lp: torch.Tensor, + step_idx: int, +) -> None: + if not slice_idxs: + # Trivial case of empty persistent batch + assert len(persistent_batch) == 0 + if logits_w_lp.shape[0] != 0: + raise ValueError("Fake persistent batch is empty but logitsprocs " + f"output batch has shape {logits_w_lp.shape}") + return + + # Validate logits for each fake request + for batch_index in range(batch_size): + request_params = persistent_batch[batch_index] + # Invoke the appropriate validation function for + # the logitproc employed by this request + fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn + fxn(test_fakes=test_fakes, + persistent_batch=persistent_batch, + logits_new=logits_w_lp, + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC]) +@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases()) +def test_logitsprocs(device: str, reqs_per_logitproc: int, + logitsprocs_under_test: list[str]): + random.seed(40) + torch.set_default_device(device) + + # Define a shuffled batch of requests which individually use a different + # logitproc, or no logitproc at all + workload_params = _generate_mixed_logitsprocs_batch_params( + reqs_per_logitproc=reqs_per_logitproc, + logitsprocs_types=logitsprocs_under_test) + workload_size = len(workload_params) + + # Create fake test data structures for testing. 
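+    # (fake logits with one dominant token per row, plus fake sampling
+    # metadata that carries freshly initialized builtin logitsprocs)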
+ test_fakes = _generate_test_fakes(workload_size, device) + + wdx = 0 # Next request index in workload to add + persistent_batch: list[LogitsProcsRequestParams] = [ + ] # Persistent batch state, as list of workload indices + + # Generate fake removed request indices from current persistent + # batch before adds + batch_update_builder = BatchUpdateBuilder() + + # Break when entire workload has been added previously and persistent + # batch is empty + workload_reqs_remaining = workload_size + batch_size = 0 + step_idx = 0 + while True: + if not (workload_reqs_remaining or batch_size): + break + + ( + batch_update, + wdx, + workload_reqs_remaining, + ) = _generate_fake_step_update( + persistent_batch=persistent_batch, + workload_params=workload_params, + wdx=wdx, + batch_update_builder=batch_update_builder, + ) + batch_size = len(persistent_batch) + + # Apply fake batch update to logitsprocs + fake_update_logitsprocs_state(test_fakes, batch_update) + + # Emulate application of logits processors in engine + slice_idxs = [req.workload_index for req in persistent_batch] + logits_w_lp = fake_apply_logitsprocs(test_fakes, slice_idxs).cpu() + + _assert_valid( + batch_size=batch_size, + persistent_batch=persistent_batch, + test_fakes=test_fakes, + slice_idxs=slice_idxs, + logits_w_lp=logits_w_lp, + step_idx=step_idx, + ) + + step_idx += 1 diff --git a/tests/v1/sample/test_logprobs_e2e.py b/tests/v1/sample/test_logprobs_e2e.py index 0b135613ff6bd..50b14a15dc164 100644 --- a/tests/v1/sample/test_logprobs_e2e.py +++ b/tests/v1/sample/test_logprobs_e2e.py @@ -13,9 +13,10 @@ EXPECTED_VALUE = 0.62 # FIXME(rob): enable prefix caching once supported. MODEL = "meta-llama/Llama-3.2-1B-Instruct" -MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501 +MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501 SERVER_ARGS = [ - "--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests" + "--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests", + "--gpu-memory-utilization=0.8" ] NUM_CONCURRENT = 100 diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 1f2bdb3c5ff62..3a4d48afc9d77 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F from vllm.platforms import current_platform +from vllm.v1.sample.logits_processor import LogitsProcessorManager from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, RejectionSampler) @@ -58,7 +59,6 @@ def create_sampling_metadata( all_random=not all_greedy, top_p=top_p, top_k=top_k, - min_p=torch.empty(1, ), generators=generators, max_num_logprobs=0, no_penalties=False, @@ -67,10 +67,9 @@ def create_sampling_metadata( presence_penalties=torch.tensor([]), repetition_penalties=torch.tensor([]), output_token_ids=[], - min_tokens={}, - logit_bias=[None], allowed_token_ids_mask=None, bad_words_token_ids={}, + logitsprocs=LogitsProcessorManager(), ) diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py index a2beb5ad71dbb..ea10661ea1137 100644 --- a/tests/v1/sample/test_sampler.py +++ b/tests/v1/sample/test_sampler.py @@ -8,10 +8,13 @@ import pytest import torch from vllm.platforms import current_platform -from vllm.utils import make_tensor_with_pad +from vllm.utils import is_pin_memory_available, 
make_tensor_with_pad +from vllm.v1.sample.logits_processor import LogitsProcessorManager from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import Sampler +PIN_MEMORY_AVAILABLE = is_pin_memory_available() +MAX_NUM_REQS = 256 VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 CUDA_DEVICES = [ @@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor( ) -def _create_logit_bias( - batch_size: int, - vocab_size: int, - bias_value: float, -) -> list[Optional[dict[int, float]]]: - res: list[Optional[dict[int, float]]] = [] - for i in range(batch_size): - logit_bias = {min(i, vocab_size - 1): bias_value} - res.append(logit_bias) - return res - - def _create_allowed_token_ids( batch_size: int, vocab_size: int, @@ -145,7 +136,6 @@ def _create_default_sampling_metadata( all_random=False, top_p=None, top_k=None, - min_p=None, generators={}, max_num_logprobs=0, prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, @@ -155,43 +145,13 @@ def _create_default_sampling_metadata( presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), no_penalties=True, - min_tokens={}, - logit_bias=[None] * batch_size, allowed_token_ids_mask=None, bad_words_token_ids={}, + logitsprocs=LogitsProcessorManager(), ) return fake_sampling_metadata -def _generate_min_token_penalties_and_stop_tokens( - num_output_tokens: int, batch_size: int, vocab_size: int, - batch_indices_for_min_token_penalty: list[int] -) -> dict[int, tuple[int, set[int]]]: - """ - Generates and returns a dict of minimum token penalties and - corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each - batch. - - If a batch index is included in `batch_indices_for_min_token_penalty`, - a higher `min_tokens` value is assigned (within a randomized range), - and a random set of stop token IDs is created. Otherwise, a lower - `min_tokens` value is assigned, and the stop token IDs set is empty. - """ - min_tokens: dict[int, tuple[int, set[int]]] = {} - for index in range(batch_size): - if index in batch_indices_for_min_token_penalty: - min_tokens[index] = ( - np.random.randint(num_output_tokens + 1, - 2 * num_output_tokens), - set( - np.random.randint(0, vocab_size - 1) - for _ in range(np.random.randint(0, vocab_size)))) - else: - min_tokens[index] = (np.random.randint(0, - num_output_tokens), set()) - return min_tokens - - def _create_weighted_output_token_list( batch_size: int, vocab_size: int) -> tuple[list[list[int]], list[list[int]]]: @@ -227,36 +187,6 @@ def _create_weighted_output_token_list( return output_token_ids, sorted_token_ids_in_output -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("batch_size", [1, 2, 32]) -def test_sampler_min_tokens_penalty(device: str, batch_size: int): - """ - Tests that if the number of output tokens is less than - SamplingParams.min_tokens then we will set the logits for - the stop token ids to -inf. 
- """ - torch.set_default_device(device) - fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) - sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) - batch_indices_for_min_token_penalty = np.random.randint( - 0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist() - min_tokens = _generate_min_token_penalties_and_stop_tokens( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, - batch_indices_for_min_token_penalty) - sampling_metadata.min_tokens = min_tokens - sampler = Sampler() - logits = sampler.apply_penalties(fake_logits, sampling_metadata) - logits = logits.cpu() - for batch_idx in range(batch_size): - for token_id in range(VOCAB_SIZE): - _, stop_token_ids = min_tokens.get(batch_idx, (0, set())) - if token_id in stop_token_ids: - assert logits[batch_idx][token_id] == -float("inf") - else: - assert logits[batch_idx][token_id] != -float("inf") - - @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("presence_penalty", [-2.0, 2.0]) @@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, or non_penalized_token_id in output_tokens) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("batch_size", [1, 2, 32]) -@pytest.mark.parametrize("min_p", [0.0, 0.1]) -def test_sampler_min_p(device: str, batch_size: int, min_p: float): - """ - Tests that when min_p is applied, tokens with probability below - min_p * max_prob are masked with -inf. - """ - torch.set_default_device(device) - fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) - - # Create one dominant token per batch - for i in range(batch_size): - fake_logits[i, 0] = 10.0 # High logit for first token - fake_logits[i, 1:] = 1e-2 # Others remain low - - sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) - - # Configure min_p parameters - sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device) - - sampler = Sampler() - logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p) - logits = logits.cpu() - - for batch_idx in range(batch_size): - for token_id in range(VOCAB_SIZE): - if token_id == 0: - # Dominant token should always be unmasked - assert logits[batch_idx][token_id] != -float("inf") - else: - if min_p > 0.0: - # Non-dominant tokens should be masked when min_p > 0 - assert logits[batch_idx][token_id] == -float("inf") - else: - # No masking when min_p is 0 - assert logits[batch_idx][token_id] != -float("inf") - - -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("batch_size", [1, 2, 32]) -@pytest.mark.parametrize("bias_value", [-0.1, 1.2]) -def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float): - """ - Test to verify that when the repetition penalty is enabled, tokens - are penalized based on their presence in the prompt or the existing - output. - """ - torch.set_default_device(device) - # Create fake logits where each token is assigned the same - # logit value. 
- fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE) - sampling_metadata = _create_default_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) - sampling_metadata.logit_bias = _create_logit_bias( - batch_size=batch_size, - vocab_size=VOCAB_SIZE, - bias_value=bias_value, - ) - sampler = Sampler() - logits = sampler.apply_logits_bias(fake_logits, sampling_metadata) - logits = logits.cpu() - for batch_idx in range(batch_size): - logits_for_req = logits[batch_idx] - biased_index = min(batch_idx, VOCAB_SIZE - 1) - for token_id in range(VOCAB_SIZE): - if biased_index == token_id: - assert logits_for_req[token_id] == pytest.approx(bias_value + - 1e-2) - else: - assert logits_for_req[token_id] == pytest.approx(1e-2) - - @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2]) diff --git a/tests/v1/sample/utils.py b/tests/v1/sample/utils.py index 8c111f846b47e..e33efb413d026 100644 --- a/tests/v1/sample/utils.py +++ b/tests/v1/sample/utils.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterator from enum import Enum -from typing import Optional +from typing import NamedTuple, Optional import regex as re +import torch from vllm import CompletionOutput +from vllm.utils import make_tensor_with_pad +from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor +from vllm.v1.sample.metadata import SamplingMetadata class BatchLogprobsComposition(Enum): @@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob( logprobs = completion_output.logprobs assert logprobs is not None return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)]) + + +def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor: + fake_logits = torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float) + return fake_logits + + +def create_penalty_tensor(batch_size: int, penalty_value: float, + device: torch.device) -> torch.Tensor: + return torch.full((batch_size, ), + fill_value=penalty_value, + dtype=torch.float, + device=device) + + +def create_prompt_tokens_tensor( + prompt_token_ids: list[list[int]], + vocab_size: int, + device: torch.device, +) -> torch.Tensor: + return make_tensor_with_pad( + prompt_token_ids, + pad=vocab_size, + device=device, + dtype=torch.int64, + pin_memory=False, + ) + + +class LogitsprocsTestFakes(NamedTuple): + """Wraps fake data structures to support testing""" + logits: torch.Tensor + sampling_metadata: SamplingMetadata + + def get_logitsprocs_by_cls( + self, + cls: type[LogitsProcessor], + ) -> Iterator[LogitsProcessor]: + """Yield logits processors of a specific class. 
+ + Args: + cls: :class:`LogitsProcessor` subclass + + Returns: + Iterator over logits processors + """ + return (lp for lp in self.sampling_metadata.logitsprocs.all + if isinstance(lp, cls)) + + def get_logitsprocs(self) -> Iterator[LogitsProcessor]: + """Iterator over all logits processors.""" + return self.sampling_metadata.logitsprocs.all + + +def fake_update_logitsprocs_state( + test_fakes: LogitsprocsTestFakes, + batch_update: BatchUpdate, +) -> None: + """Imitate logits processors persistent batch state update + in engine core""" + for logitproc in test_fakes.get_logitsprocs(): + logitproc.update_state(batch_update) + + +def fake_apply_logitsprocs( + test_fakes: LogitsprocsTestFakes, + slice_indices: list[int], +) -> torch.Tensor: + """Imitate application of logits processors in engine core""" + logits = test_fakes.logits[torch.tensor(slice_indices, + dtype=torch.long)].clone() + for processor in test_fakes.get_logitsprocs(): + logits = processor.apply(logits) + return logits diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 59b28e675c252..943a13debada2 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import inspect +from collections.abc import Sequence from typing import Optional import numpy as np @@ -12,6 +13,7 @@ from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import LogitsProcessorManager from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -26,13 +28,18 @@ CUDA_DEVICES = [ MAX_NUM_PROMPT_TOKENS = 64 -def _compare_objs(obj1, obj2): +def _compare_objs(obj1, + obj2, + skip: Sequence = ("logitsprocs", "batch_update_builder")): attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) attr_names = set([ a[0] for a in attrs if not (a[0].startswith('__') and a[0].endswith('__')) ]) for attr_name in attr_names: + if attr_name in skip: + continue + a = getattr(obj1, attr_name) b = getattr(obj2, attr_name) @@ -58,13 +65,11 @@ def _compare_objs(obj1, obj2): f" in {obj1} and {obj2}: {a} != {b}" -def _remove_requests( - input_batch: InputBatch, batch_size: int, - reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]: +def _remove_requests(input_batch: InputBatch, batch_size: int, + reqs: list[CachedRequestState]) -> set[str]: """ - Remove some requests randomly from the batch and returns a tuple - of 1) set of request removed 2) indices of the requests removed - ordered in descending order + Remove some requests randomly from the batch and returns + set of request removed """ num_reqs_to_remove = np.random.randint(0, batch_size) @@ -73,13 +78,11 @@ def _remove_requests( req_index_to_remove = np.random.randint(0, batch_size) req_indices_to_remove.add(req_index_to_remove) - req_indices_to_remove_list = list(req_indices_to_remove) - req_indices_to_remove_list.sort(reverse=True) req_ids_to_remove: set[str] = set() for index in req_indices_to_remove: input_batch.remove_request(reqs[index].req_id) req_ids_to_remove.add(reqs[index].req_id) - return req_ids_to_remove, req_indices_to_remove_list + return req_ids_to_remove def 
_construct_expected_sampling_metadata( @@ -100,7 +103,6 @@ def _construct_expected_sampling_metadata( repetition_penalties = [1.0 for _ in range(num_reqs)] top_k = [0 for _ in range(num_reqs)] top_p = [0.0 for _ in range(num_reqs)] - min_p = [0.0 for _ in range(num_reqs)] temperature = [0.0 for _ in range(num_reqs)] min_tokens = {} logit_bias = [None] * num_reqs @@ -123,7 +125,6 @@ def _construct_expected_sampling_metadata( req.sampling_params.repetition_penalty) top_k[index_in_input_batch] = req.sampling_params.top_k top_p[index_in_input_batch] = req.sampling_params.top_p - min_p[index_in_input_batch] = req.sampling_params.min_p temperature[index_in_input_batch] = req.sampling_params.temperature min_tokens[index_in_input_batch] = ( req.sampling_params.min_tokens, @@ -145,8 +146,6 @@ def _construct_expected_sampling_metadata( top_p, dtype=torch.float, device=device), top_k=None if all(x == 0 for x in top_k) else torch.tensor( top_k, dtype=torch.int, device=device), - min_p=None if all(x == 0.0 for x in min_p) else torch.tensor( - min_p, dtype=torch.float, device=device), generators={}, max_num_logprobs=0, prompt_token_ids=make_tensor_with_pad( @@ -165,13 +164,12 @@ def _construct_expected_sampling_metadata( dtype=torch.float, device=device), output_token_ids=output_token_ids, - min_tokens=min_tokens, no_penalties=(all(x == 0 for x in presence_penalties) and all(x == 0 for x in frequency_penalties) and all(x == 1 for x in repetition_penalties)), - logit_bias=logit_bias, allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=bad_words_token_ids, + logitsprocs=LogitsProcessorManager(), ) @@ -225,6 +223,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): and the `make_sampling_metadata` method is invoked on the batch. The output of `make_sampling_metadata` is then compared against the expected results to ensure correctness. 
+ + Note: Ignore logits processor logic, which is tested separately """ input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, @@ -238,21 +238,22 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): reqs: list[CachedRequestState] = [] req_id_reqs = {} req_id_output_token_ids = {} + # Add requests for req_index in range(batch_size): req: CachedRequestState = _construct_cached_request_state(req_index) - input_batch.add_request(req, req_index) + assigned_req_index = input_batch.add_request(req) + assert req_index == assigned_req_index reqs.append(req) req_id_reqs[req.req_id] = req req_id_output_token_ids[req.req_id] = req.output_token_ids # Remove some requests - req_ids_to_remove, req_indices_to_remove = _remove_requests( - input_batch, batch_size, reqs) + req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs) req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove # Compact the input batch - input_batch.condense(req_indices_to_remove) + input_batch.condense() # Generate the sampling metadata sampling_metadata = input_batch._make_sampling_metadata() @@ -290,10 +291,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): sampling_metadata.prompt_token_ids) assert (expected_sampling_metadata.output_token_ids == sampling_metadata.output_token_ids) - assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens assert expected_sampling_metadata.no_penalties == \ sampling_metadata.no_penalties - assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias if sampling_metadata.allowed_token_ids_mask: assert torch.allclose( expected_sampling_metadata.allowed_token_ids_mask, @@ -315,6 +314,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, and the `make_sampling_metadata` method is invoked on the batch. The output of `make_sampling_metadata` is then compared against the expected results to ensure correctness. 
+ + Note: Ignore logits processor logic, which is tested separately """ input_batch: InputBatch = InputBatch( max_num_reqs=batch_size, @@ -341,7 +342,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, # Add requests for req_index in range(batch_size): req: CachedRequestState = _construct_cached_request_state(req_index) - input_batch.add_request(req, req_index) + assigned_req_index = input_batch.add_request(req) + assert assigned_req_index == req_index reqs.append(req) req_id_reqs[req.req_id] = req req_id_output_token_ids[req.req_id] = req.output_token_ids @@ -354,9 +356,10 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, for req_index in range(batch_size): req = reordered_reqs[req_index] - ref_input_batch.add_request(req, req_index) + assigned_req_index = ref_input_batch.add_request(req) + assert assigned_req_index == req_index - input_batch.refresh_sampling_metadata() - ref_input_batch.refresh_sampling_metadata() + input_batch.refresh_metadata() + ref_input_batch.refresh_metadata() _compare_objs(input_batch, ref_input_batch) diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py new file mode 100644 index 0000000000000..9aa560d30eee6 --- /dev/null +++ b/vllm/v1/sample/logits_processor.py @@ -0,0 +1,516 @@ +# SPDX-License-Identifier: Apache-2.0 +import dataclasses +from abc import ABC, abstractmethod +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, field +from enum import Enum +from itertools import chain +from typing import Optional, Union + +import torch +from torch._prims_common import DeviceLikeType + +from vllm import PoolingParams, SamplingParams +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class MoveDirectionality(Enum): + # One-way i1->i2 req move within batch + UNIDIRECTIONAL = 0 + # Two-way i1<->i2 req swap within batch + SWAP = 1 + + +# (index, params, output_tok_ids) tuples for new +# requests added to the batch. +AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]] +# (index 1, index 2, directionality) tuples representing +# one-way moves or two-way swaps of requests in batch +MovedRequest = tuple[int, int, MoveDirectionality] +# Batch indices of any removed requests. +RemovedRequest = int + + +@dataclasses.dataclass(frozen=True) +class BatchUpdate: + """Persistent batch state change info for logitsprocs""" + batch_size: int # Current num reqs in batch + + # Metadata for requests added to, removed from, and moved + # within the persistent batch. + # + # Note: each added request is represented as + # (index, params, output_tok_ids) + # Key assumption: output_tok_ids is a reference to the + # request's running output tokens list; in this way + # the logits processors always see the latest list of + # generated tokens + removed: Sequence[RemovedRequest] + moved: Sequence[MovedRequest] + added: Sequence[AddedRequest] + + +class BatchUpdateBuilder: + """Helps track persistent batch state changes and build + a batch update data structure for logitsprocs + + Assumptions: + * All information about requests removed from persistent batch + during a step is aggregated in self._removed through calls to + self.removed_append() at the beginning of a step. 
This must happen + before the first time that self.removed, self.pop_removed() + or self.peek_removed() are invoked in a given step + * After the first time that self.removed, self.pop_removed() + or self.peek_removed() are read in a step, no new removals + are registered using self.removed_append() + * Elements of self._removed are never directly modified, added or + removed (i.e. modification is only via self.removed_append() and + self.pop_removed()) + + Guarantees under above assumptions: + * self.removed is always sorted in descending order + * self.pop_removed() and self.peek_removed() both return + the lowest removed request index in the current step + """ + + _removed: list[RemovedRequest] + _is_removed_sorted: bool + moved: list[MovedRequest] + added: list[AddedRequest] + + def __init__( + self, + removed: Optional[list[RemovedRequest]] = None, + moved: Optional[list[MovedRequest]] = None, + added: Optional[list[AddedRequest]] = None, + ) -> None: + self._removed = removed or [] + self.moved = moved or [] + self.added = added or [] + self._is_removed_sorted = False + + def _ensure_removed_sorted(self) -> None: + """Sort removed request indices in + descending order. + + Idempotent after first call in a + given step, until reset. + """ + if not self._is_removed_sorted: + self._removed.sort(reverse=True) + self._is_removed_sorted = True + + @property + def removed(self) -> list[RemovedRequest]: + """Removed request indices sorted in + descending order""" + self._ensure_removed_sorted() + return self._removed + + def removed_append(self, index: int) -> None: + """Register the removal of a request from + the persistent batch. + + Must not be called after the first time + self.removed, self.pop_removed() or + self.peek_removed() are invoked. + + Args: + index: request index + """ + if self._is_removed_sorted: + raise RuntimeError("Cannot register new removed request after" + " self.removed has been read.") + self._removed.append(index) + + def has_removed(self) -> bool: + return bool(self._removed) + + def peek_removed(self) -> Optional[int]: + """Return lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed[-1] + return None + + def pop_removed(self) -> Optional[int]: + """Pop lowest removed request index""" + if self.has_removed(): + self._ensure_removed_sorted() + return self._removed.pop() + return None + + def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: + """Generate a logitsprocs batch update data structure + and reset internal batch update builder state. + + Args: + batch_size: current persistent batch size + + Returns: + Frozen logitsprocs batch update instance; `None` if no updates + """ + # Reset removal-sorting logic + self._is_removed_sorted = False + if not any((self._removed, self.moved, self.added)): + # No update; short-circuit + return None + # Build batch state update + batch_update = BatchUpdate( + batch_size=batch_size, + removed=self._removed, + moved=self.moved, + added=self.added, + ) + # Reset removed/moved/added update lists + self._removed = [] + self.moved = [] + self.added = [] + return batch_update + + +class LogitsProcessor(ABC): + + @abstractmethod + def apply(self, logits: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abstractmethod + def is_argmax_invariant(self) -> bool: + """True if logits processor has no impact on the + argmax computation in greedy sampling. 
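+        For example, min-p only masks tokens whose probability falls
+        below a fraction of the max probability, so the argmax token
+        always survives and greedy output is unchanged, whereas logit
+        bias can promote a different token to the argmax.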
+ NOTE: may or may not have the same value for all + instances of a given LogitsProcessor subclass, + depending on subclass implementation. + TODO(andy): won't be utilized until logits + processors are user-extensible + """ + raise NotImplementedError + + @abstractmethod + def update_state( + self, + batch_update: Optional[BatchUpdate], + ) -> None: + """Called when there are new output tokens, prior + to each forward pass. + + Args: + batch_update is non-None iff there have been + changes to the batch makeup. + """ + raise NotImplementedError + + +@dataclass +class LogitsProcessorManager: + """Encapsulates initialized logitsproc objects.""" + argmax_invariant: list[LogitsProcessor] = field( + default_factory=list) # argmax-invariant logitsprocs + non_argmax_invariant: list[LogitsProcessor] = field( + default_factory=list) # non-argmax-invariant logitsprocs + + @property + def all(self) -> Iterator[LogitsProcessor]: + """Iterator over all logits processors.""" + return chain(self.argmax_invariant, self.non_argmax_invariant) + + +###### ----- Built-in LogitsProcessor impls below here + + +class MinPLogitsProcessor(LogitsProcessor): + + def __init__(self, max_num_reqs: int, pin_memory: bool, + device: DeviceLikeType): + super().__init__() + self.min_p_count: int = 0 + + self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), + dtype=torch.float32, + device="cpu", + pin_memory=pin_memory) + self.min_p_cpu = self.min_p_cpu_tensor.numpy() + # Pre-allocated device tensor + self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + # Current slice of the device tensor + self.min_p: torch.Tensor = self.min_p_device[:0] + + def is_argmax_invariant(self) -> bool: + """Min-p never impacts greedy sampling""" + return True + + def get_min_p_by_index(self, index: int) -> float: + return float(self.min_p_cpu[index]) + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not batch_update: + return + + needs_update = False + # Process added requests. + for index, params, _ in batch_update.added: + min_p = params.min_p if isinstance(params, SamplingParams) else 0.0 + if self.min_p_cpu[index] != min_p: + needs_update = True + self.min_p_cpu[index] = min_p + if min_p: + self.min_p_count += 1 + + if self.min_p_count: + # Process removed requests. + needs_update |= bool(batch_update.removed) + for index in batch_update.removed: + if self.min_p_cpu[index]: + self.min_p_count -= 1 + + # Process moved requests, unidirectional (a->b) and swap (a<->b) + for adx, bdx, direct in batch_update.moved: + change = (min_p_a := + self.min_p_cpu[adx]) != (min_p_b := + self.min_p_cpu[bdx]) + needs_update |= change + if change: + self.min_p_cpu[bdx] = min_p_a + if direct == MoveDirectionality.SWAP: + self.min_p_cpu[adx] = min_p_b + + # Update tensors if needed. 
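+        # Re-slice the preallocated device tensor to the current batch
+        # size, copy the CPU-side values across asynchronously, and add
+        # a trailing dim so min_p broadcasts against (batch, vocab)
+        # logits in apply().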
+ size = batch_update.batch_size + if self.min_p_count and (needs_update or self.min_p.shape[0] != size): + self.min_p = self.min_p_device[:size] + self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True) + self.min_p.unsqueeze_(1) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if not self.min_p_count: + return logits + + # Convert logits to probability distribution + probability_values = torch.nn.functional.softmax(logits, dim=-1) + # Calculate maximum probabilities per sequence + max_probabilities = torch.amax(probability_values, + dim=-1, + keepdim=True) + # Adjust min_p + adjusted_min_p = max_probabilities.mul_(self.min_p) + # Identify valid tokens using threshold comparison + invalid_token_mask = probability_values < adjusted_min_p + # Apply mask using boolean indexing + logits[invalid_token_mask] = -float('inf') + return logits + + +class LogitBiasLogitsProcessor(LogitsProcessor): + + def __init__(self, pin_memory: bool, device: torch.device): + super().__init__() + self.biases: dict[int, dict[int, float]] = {} + self.device = device + self.pin_memory = pin_memory + + self.bias_tensor: torch.Tensor = torch.tensor(()) + self.logits_slice = (self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32)) + + def is_argmax_invariant(self) -> bool: + """Logit bias can rebalance token probabilities and change the + outcome of argmax in greedy sampling.""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if not batch_update: + return + + # Process added requests. + needs_update = bool(batch_update.added) + for index, params, _ in batch_update.added: + if isinstance(params, SamplingParams) and (lb := + params.logit_bias): + self.biases[index] = lb + else: + self.biases.pop(index, None) + + if self.biases: + # Process removed requests. + for index in batch_update.removed: + if self.biases.pop(index, None): + needs_update = True + + # Process moved requests, unidirectional (a->b) and swap (a<->b) + for a_index, b_index, direct in batch_update.moved: + if direct == MoveDirectionality.UNIDIRECTIONAL: + if (a_entry := self.biases.pop(a_index, None)) is None: + if self.biases.pop(b_index, None) is not None: + needs_update = True + else: + self.biases[b_index] = a_entry + needs_update = True + else: + a_entry = self.biases.pop(a_index, None) + if (b_entry := self.biases.pop(b_index, None)) is not None: + self.biases[a_index] = b_entry + needs_update = True + if a_entry is not None: + self.biases[b_index] = a_entry + needs_update = True + + # Update tensors if needed. 
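+        # Flatten the per-request {token_id: bias} dicts into parallel
+        # (request row, token column) index lists and a bias-value list
+        # so that apply() can add every bias with one indexed add.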
+ if needs_update: + reqs, tok_ids, biases = [], [], [] + for req, lb in self.biases.items(): + reqs.extend([req] * len(lb)) + tok_ids.extend(lb.keys()) + biases.extend(lb.values()) + + self.bias_tensor = self._device_tensor(biases, torch.float32) + self.logits_slice = (self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32)) + + def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: + return (torch.tensor(data, + device="cpu", + dtype=dtype, + pin_memory=self.pin_memory).to(device=self.device, + non_blocking=True)) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self.biases: + logits[self.logits_slice] += self.bias_tensor + return logits + + +class MinTokensLogitsProcessor(LogitsProcessor): + + def __init__(self, pin_memory: bool, device: torch.device): + # index -> (min_toks, output_token_ids, stop_token_ids) + super().__init__() + self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} + self.device = device + self.pin_memory = pin_memory + + # (req_idx_tensor,eos_tok_id_tensor) + self.logits_slice: tuple[torch.Tensor, + torch.Tensor] = (self._device_tensor( + [], torch.int32), + self._device_tensor( + [], torch.int32)) + + def is_argmax_invariant(self) -> bool: + """By censoring stop tokens, min-tokens can change the outcome + of the argmax operation in greedy sampling.""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + needs_update = False + + if batch_update: + # Process added requests. + needs_update |= bool(batch_update.added) + for index, params, output_tok_ids in batch_update.added: + if (isinstance(params, SamplingParams) + and (min_tokens := params.min_tokens) + and len(output_tok_ids) < min_tokens): + # Replace request metadata at batch index + self.min_toks[index] = (min_tokens, output_tok_ids, + params.all_stop_token_ids) + else: + # Drop request metadata at batch index + self.min_toks.pop(index, None) + + if self.min_toks: + # Process removed requests. + for index in batch_update.removed: + if self.min_toks.pop(index, None): + needs_update = True + + # Process moved requests, unidirectional (a->b) and + # swapped (a<->b) + for a_index, b_index, direct in batch_update.moved: + if direct == MoveDirectionality.UNIDIRECTIONAL: + if (a_entry := self.min_toks.pop(a_index, + None)) is None: + if self.min_toks.pop(b_index, None) is not None: + needs_update = True + else: + self.min_toks[b_index] = a_entry + needs_update = True + else: + a_entry = self.min_toks.pop(a_index, None) + if (b_entry := self.min_toks.pop(b_index, + None)) is not None: + self.min_toks[a_index] = b_entry + needs_update = True + if a_entry is not None: + self.min_toks[b_index] = a_entry + needs_update = True + + if self.min_toks: + # Check for any requests that have attained their min tokens. + to_remove = tuple(index for index, (min_toks, out_tok_ids, + _) in self.min_toks.items() + if len(out_tok_ids) >= min_toks) + if to_remove: + needs_update = True + for index in to_remove: + del self.min_toks[index] + + # Update tensors if needed. 
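+        # Rebuild the (request row, stop-token column) index pairs for
+        # every request still below its min length so that apply() can
+        # mask all of their stop tokens to -inf in one assignment.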
+ if needs_update: + reqs: list[int] = [] + tok_ids: list[int] = [] + for req, (_, _, stop_tok_ids) in self.min_toks.items(): + reqs.extend([req] * len(stop_tok_ids)) + tok_ids.extend(stop_tok_ids) + + self.logits_slice = (self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32)) + + def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: + return (torch.tensor(data, + device="cpu", + dtype=dtype, + pin_memory=self.pin_memory).to(device=self.device, + non_blocking=True)) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + if self.min_toks: + # Inhibit EOS token for requests which have not reached min length + logits[self.logits_slice] = -float("inf") + return logits + + +def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, + device: torch.device) -> LogitsProcessorManager: + """Construct 'builtin' vLLM logitsprocs which the engine + loads by default. + + Args: + pin_memory_available: pinned memory is available for use + for use by logitsproc + max_num_reqs: ceiling on request count in persistent batch + device: inference device + + Returns: + Data structure encapsulating loaded logitsprocs + """ + min_tokens_logitproc = MinTokensLogitsProcessor( + pin_memory=pin_memory_available, device=device) + logit_bias_logitproc = LogitBiasLogitsProcessor( + pin_memory=pin_memory_available, device=device) + min_p_logitproc = MinPLogitsProcessor( + pin_memory=pin_memory_available, + device=device, + # +1 for temporary swap space + max_num_reqs=max_num_reqs + 1) + return LogitsProcessorManager( + non_argmax_invariant=[ + min_tokens_logitproc, + logit_bias_logitproc, + ], + argmax_invariant=[min_p_logitproc], + ) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index ab13b288a5a9b..1189b12f30776 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -6,6 +6,8 @@ from typing import Optional import torch +from vllm.v1.sample.logits_processor import LogitsProcessorManager + @dataclass class SamplingMetadata: @@ -16,7 +18,6 @@ class SamplingMetadata: top_p: Optional[torch.Tensor] top_k: Optional[torch.Tensor] - min_p: Optional[torch.Tensor] generators: dict[int, torch.Generator] @@ -31,14 +32,12 @@ class SamplingMetadata: output_token_ids: list[list[int]] - # req_index -> (min_tokens, stop_token_ids) - min_tokens: dict[int, tuple[int, set[int]]] - - logit_bias: list[Optional[dict[int, float]]] - # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, # vocab size). allowed_token_ids_mask: Optional[torch.Tensor] # req_index -> bad_words_token_ids bad_words_token_ids: dict[int, list[list[int]]] + + # Loaded logits processors + logitsprocs: LogitsProcessorManager diff --git a/vllm/v1/sample/ops/penalties.py b/vllm/v1/sample/ops/penalties.py index 48423b9b424dd..5d54f6679a1a9 100644 --- a/vllm/v1/sample/ops/penalties.py +++ b/vllm/v1/sample/ops/penalties.py @@ -7,22 +7,6 @@ from vllm.model_executor.layers.utils import apply_penalties from vllm.utils import is_pin_memory_available, make_tensor_with_pad -def apply_min_token_penalties( - logits: torch.Tensor, output_token_ids: list[list[int]], - min_tokens: dict[int, tuple[int, set[int]]]) -> None: - """ - Applies minimum token penalty by setting the logits of the stop tokens - to -inf. 
- """ - min_tokens_logits_to_penalize: list[tuple[int, int]] = [] - for index, (min_token, stop_token_ids) in min_tokens.items(): - if len(output_token_ids[index]) < min_token: - for stop_token_id in stop_token_ids: - min_tokens_logits_to_penalize.append((index, stop_token_id)) - if min_tokens_logits_to_penalize: - logits[tuple(zip(*min_tokens_logits_to_penalize))] = -float("inf") - - def apply_all_penalties( logits: torch.Tensor, prompt_token_ids: torch.Tensor, diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 6bc0cecdd4940..e79e4451a3a3f 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -5,12 +5,11 @@ import torch import torch.nn as nn -from vllm.utils import async_tensor_h2d, is_pin_memory_available +from vllm.utils import is_pin_memory_available from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.bad_words import apply_bad_words -from vllm.v1.sample.ops.penalties import (apply_all_penalties, - apply_min_token_penalties) +from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler _SAMPLING_EPS = 1e-5 @@ -44,8 +43,11 @@ class Sampler(nn.Module): logits = self.apply_allowed_token_ids(logits, sampling_metadata) # Apply bad words exclusion. logits = self.apply_bad_words(logits, sampling_metadata) - # Apply logits bias. - logits = self.apply_logits_bias(logits, sampling_metadata) + + # Apply logits processors which can impact greedy sampling + for processor in (sampling_metadata.logitsprocs.non_argmax_invariant): + logits = processor.apply(logits) + # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) # Sample the next token. @@ -110,9 +112,10 @@ class Sampler(nn.Module): # Apply temperature. logits = self.apply_temperature(logits, sampling_metadata.temperature) - # Apply min_p. - if sampling_metadata.min_p is not None: - logits = self.apply_min_p(logits, sampling_metadata.min_p) + # Apply logits processors that only apply to random sampling + # (argmax invariant) + for processor in sampling_metadata.logitsprocs.argmax_invariant: + logits = processor.apply(logits) # Apply top_k and/or top_p. random_sampled = self.topk_topp_sampler( @@ -187,10 +190,6 @@ class Sampler(nn.Module): logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - if sampling_metadata.min_tokens: - apply_min_token_penalties(logits, - sampling_metadata.output_token_ids, - sampling_metadata.min_tokens) if not sampling_metadata.no_penalties: assert sampling_metadata.prompt_token_ids is not None logits = apply_all_penalties( @@ -203,65 +202,6 @@ class Sampler(nn.Module): ) return logits - def apply_min_p( - self, - logits: torch.Tensor, - min_p: torch.Tensor, - ) -> torch.Tensor: - """ - Filters logits using adaptive probability thresholding. 
- """ - # Convert logits to probability distribution - probability_values = torch.nn.functional.softmax(logits, dim=-1) - # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) - # Reshape min_p for broadcasting - adjusted_min_p = min_p.unsqueeze(1) * max_probabilities - # Identify valid tokens using threshold comparison - valid_token_mask = probability_values >= adjusted_min_p - # Apply mask using boolean indexing - logits[~valid_token_mask] = -float('inf') - return logits - - def apply_logits_bias( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: - # TODO(houseroad): this implementation is extremely inefficient. - # One idea is implement this as a PyTorch C++ op, and we may - # even optimize the logit_bias layout. - - rows: list[int] = [] - cols: list[int] = [] - vals: list[float] = [] - - # Get vocabulary size from logits - vocab_size = logits.shape[-1] - - for i, logit_bias in enumerate(sampling_metadata.logit_bias): - if logit_bias: - for token_id, bias in logit_bias.items(): - # Check token_id bounds to ensure within vocabulary - if token_id < 0 or token_id >= vocab_size: - raise ValueError( - f"token_id {token_id} in logit_bias contains " - f"out-of-vocab token id. Vocabulary size: " - f"{vocab_size}") - rows.append(i) - cols.append(token_id) - vals.append(bias) - - if rows: - indices = async_tensor_h2d([rows, cols], torch.int64, - logits.device, self.pin_memory) - values = async_tensor_h2d(vals, torch.float, logits.device, - self.pin_memory) - logits.index_put_(tuple(indices), values=values, accumulate=True) - return logits - def apply_allowed_token_ids( self, logits: torch.Tensor, diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 5c37333cebc7a..3a86fea146f33 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,23 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.sampling_params import SamplingParams from vllm.triton_utils import tl, triton -from vllm.v1.worker.gpu_input_batch import InputBatch + +_SAMPLING_EPS = 1e-5 -def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: - if req_id in input_batch.min_p_reqs: - # Spec decode doesn't support min_p sampling. - return False - elif (req_id in input_batch.frequency_penalties_reqs - or req_id in input_batch.presence_penalties_reqs - or req_id in input_batch.repetition_penalties_reqs): - # Spec decode doesn't support penalties. - return False - elif req_id in input_batch.num_logprobs: - # Spec decode doesn't support logprobs. 
- return False - - return True +def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: + """True if request is incompatible with speculative decoding""" + return (sampling_params.frequency_penalty != 0.0 + or sampling_params.presence_penalty != 0.0 + or sampling_params.repetition_penalty != 1.0 + or sampling_params.min_p > _SAMPLING_EPS + or sampling_params.logprobs is not None) @triton.jit diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ca2bfe8317468..1a79d72be0a9b 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -15,12 +15,14 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, + MoveDirectionality, + init_builtin_logitsprocs) from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.utils import is_spec_decode_unsupported from vllm.v1.utils import copy_slice from vllm.v1.worker.block_table import MultiGroupBlockTable -_SAMPLING_EPS = 1e-5 - @dataclass class CachedRequestState: @@ -67,8 +69,10 @@ class InputBatch: pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group + is_spec_decode: bool = False, logits_processing_needs_token_ids: bool = False, ): + self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_batched_tokens = max_num_batched_tokens @@ -146,15 +150,8 @@ class InputBatch: self.top_k_cpu = self.top_k_cpu_tensor.numpy() self.top_k_reqs: set[str] = set() - self.min_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.min_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.min_p_cpu = self.min_p_cpu_tensor.numpy() - self.min_p_reqs: set[str] = set() + # IDs of requests which do not support spec decoding + self.spec_decode_unsupported_reqs: set[str] = set() # Frequency penalty related data structures self.frequency_penalties = torch.empty((max_num_reqs, ), @@ -194,9 +191,6 @@ class InputBatch: self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_reqs: set[str] = set() - # req_index -> (min_tokens, stop_token_ids) - self.min_tokens: dict[int, tuple[int, set[int]]] = {} - # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), dtype=np.int32) @@ -216,8 +210,20 @@ class InputBatch: # To accumulate prompt logprobs tensor chunks across prefill steps. self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} - self.logit_bias: list[Optional[dict[int, - float]]] = [None] * max_num_reqs + # Internal representation of per-step batch state changes, used for + # reordering persistent batch and generating logitsprocs batch state + # updates. Should reset each step. + self.batch_update_builder = BatchUpdateBuilder() + + # Define logits processors. + # TODO(andy): logits processor list should be extensible via engine + # constructor argument; for now the list is fixed. + self.logitsprocs = init_builtin_logitsprocs( + pin_memory_available=pin_memory, + max_num_reqs=max_num_reqs + 1, + device=device) + + # TODO convert this to LogitsProcessor self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, # the value is False. Since we use masked_fill_ to set -inf. 
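Editorial note (not part of the patch): the hunks above fix the interface that every logits processor now implements — is_argmax_invariant() tells the sampler whether the processor can be skipped for greedy requests, update_state() reconciles per-request state against the per-step BatchUpdate (added / removed / moved batch indices), and apply() rewrites the logits in place. The sketch below is a minimal, hypothetical third-party processor written only against names visible in this diff, assuming BatchUpdate, LogitsProcessor, and MoveDirectionality are importable from vllm.v1.sample.logits_processor as the import hunks suggest; BANNED_TOKEN_ID is an invented constant used purely for illustration. As the TODO in gpu_input_batch.py notes, InputBatch currently loads only the builtin processors via init_builtin_logitsprocs, so a processor like this cannot yet be registered through the engine.

from typing import Optional

import torch

from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (BatchUpdate, LogitsProcessor,
                                             MoveDirectionality)

BANNED_TOKEN_ID = 7  # hypothetical token id, for illustration only


class BanTokenLogitsProcessor(LogitsProcessor):
    """Sketch: suppress one token for every random-sampling request."""

    def __init__(self, pin_memory: bool, device: torch.device):
        super().__init__()
        self.pin_memory = pin_memory
        self.device = device
        self.req_indices: set[int] = set()  # batch rows this applies to
        self.rows = torch.empty(0, dtype=torch.long, device=device)

    def is_argmax_invariant(self) -> bool:
        # Masking a token can change the argmax, so greedy requests are
        # affected and the sampler must always run this processor.
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        if not batch_update:
            return
        for index, params, _ in batch_update.added:
            # Arbitrary opt-in rule for the sketch: random-sampling requests.
            if isinstance(params, SamplingParams) and params.temperature > 0:
                self.req_indices.add(index)
            else:
                self.req_indices.discard(index)
        for index in batch_update.removed:
            self.req_indices.discard(index)
        for a, b, direction in batch_update.moved:
            a_in, b_in = a in self.req_indices, b in self.req_indices
            if direction == MoveDirectionality.SWAP:
                if a_in != b_in:
                    # Exactly one of the two slots is tracked; toggle both.
                    self.req_indices.symmetric_difference_update((a, b))
            else:
                # UNIDIRECTIONAL: the request at index a now lives at index b.
                self.req_indices.discard(b)
                if a_in:
                    self.req_indices.discard(a)
                    self.req_indices.add(b)
        # Rebuild the row-index tensor whenever the batch composition changed,
        # using the same CPU-tensor-then-async-copy pattern as the builtin
        # processors above.
        self.rows = torch.tensor(sorted(self.req_indices),
                                 dtype=torch.long,
                                 device="cpu",
                                 pin_memory=self.pin_memory).to(
                                     self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.req_indices:
            logits[self.rows, BANNED_TOKEN_ID] = -float("inf")
        return logits

Keying state by batch index (rather than request id) is what lets update_state stay in lockstep with the persistent-batch add/remove/condense/swap operations shown in the InputBatch hunks below.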
@@ -240,14 +246,28 @@ class InputBatch: # while performing state updates to the batch. return cast(list[str], self._req_ids) + def _get_next_add_index(self) -> int: + if (req_index := self.batch_update_builder.pop_removed()) is not None: + # Fill the empty index. + return req_index + # Append to end + return self.num_reqs + + def _register_add_request(self, request: "CachedRequestState") -> int: + """Track add-request operations""" + req_index = self._get_next_add_index() + assert req_index < self.max_num_reqs + params = (request.sampling_params + if request.sampling_params else request.pooling_params) + self.batch_update_builder.added.append( + (req_index, params, request.output_token_ids)) + return req_index + def add_request( self, request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs + ) -> int: + req_index = self._register_add_request(request) req_id = request.req_id if req_index == len(self._req_ids): @@ -278,6 +298,9 @@ class InputBatch: self.block_table.add_row(request.block_ids, req_index) if sampling_params := request.sampling_params: + if (self.is_spec_decode + and is_spec_decode_unsupported(sampling_params)): + self.spec_decode_unsupported_reqs.add(req_id) if sampling_params.sampling_type == SamplingType.GREEDY: # Avoid later division by zero. self.temperature_cpu[req_index] = -1.0 @@ -295,11 +318,8 @@ class InputBatch: else: top_k = self.vocab_size self.top_k_cpu[req_index] = top_k - self.min_p_cpu[req_index] = sampling_params.min_p self.frequency_penalties_cpu[ req_index] = sampling_params.frequency_penalty - if sampling_params.min_p > _SAMPLING_EPS: - self.min_p_reqs.add(req_id) if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) self.presence_penalties_cpu[ @@ -310,10 +330,6 @@ class InputBatch: req_index] = sampling_params.repetition_penalty if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) - if sampling_params.min_tokens: - self.min_tokens[req_index] = ( - sampling_params.min_tokens, - sampling_params.all_stop_token_ids) # NOTE(woosuk): self.generators should not include the requests that # do not have their own generator. @@ -325,8 +341,6 @@ class InputBatch: if sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[ req_id] = sampling_params.prompt_logprobs - if sampling_params.logit_bias is not None: - self.logit_bias[req_index] = sampling_params.logit_bias if sampling_params.allowed_token_ids: self.has_allowed_token_ids.add(req_id) @@ -368,12 +382,22 @@ class InputBatch: # No LoRA self.request_lora_mapping[req_index] = 0 + return req_index + def remove_request(self, req_id: str) -> Optional[int]: - """This method must always be followed by a call to condense().""" + """This method must always be followed by a call to condense(). 
+ + Args: + req_id: request to remove + + Returns: + Removed request index, or `None` if `req_id` not recognized + """ req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None + self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None @@ -381,8 +405,7 @@ class InputBatch: self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) - self.min_p_reqs.discard(req_id) - self.min_tokens.pop(req_index, None) + self.spec_decode_unsupported_reqs.discard(req_id) self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) @@ -400,7 +423,6 @@ class InputBatch: self.lora_id_to_lora_request.pop(lora_id) self.request_lora_mapping[req_index] = 0 - self.logit_bias[req_index] = None self.has_allowed_token_ids.discard(req_id) if self.allowed_token_ids_mask_cpu_tensor is not None: # False means we don't fill with -inf. @@ -410,6 +432,8 @@ class InputBatch: return req_index def swap_states(self, i1: int, i2: int) -> None: + self.batch_update_builder.moved.append( + (i1, i2, MoveDirectionality.SWAP)) old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] self._req_ids[i1], self._req_ids[i2] =\ @@ -439,8 +463,6 @@ class InputBatch: self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] - self.min_p_cpu[i1], self.min_p_cpu[i2] =\ - self.min_p_cpu[i2], self.min_p_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -452,13 +474,10 @@ class InputBatch: self.token_ids_cpu[i2, ...] = tmp swap_dict_values(self.generators, i1, i2) - swap_dict_values(self.min_tokens, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ self.request_lora_mapping[i2], self.request_lora_mapping[i1] - self.logit_bias[i1], self.logit_bias[i2] =\ - self.logit_bias[i2], self.logit_bias[i1] if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[i1], \ @@ -467,12 +486,23 @@ class InputBatch: self.allowed_token_ids_mask_cpu_tensor[i1] self.block_table.swap_row(i1, i2) - def condense(self, empty_req_indices: list[int]) -> None: - """Move non-empty requests down into lower, empty indices. - + def condense(self) -> None: + """Slide non-empty requests down into lower, empty indices. + + Any consecutive empty indices at the very end of the list are not + filled. + Args: - empty_req_indices: empty batch indices, sorted descending. + empty_req_indices: empty indices which may be filled. + + Returns: + swaps: list of (from,to) swap tuples for moved requests + empty_req_indices: indices not filled by condensation """ + if not (empty_req_indices := self.batch_update_builder.removed): + # All removed requests were replaced by added requests, or else no + # requests were removed at all. No condense() needed + return num_reqs = self.num_reqs if num_reqs == 0: # The batched states are empty. @@ -489,11 +519,17 @@ class InputBatch: last_req_index -= 1 # Find the smallest empty index. - empty_index = empty_req_indices.pop() + empty_index = self.batch_update_builder.peek_removed() + assert empty_index is not None if empty_index >= last_req_index: break - # Swap the states. + # Move active request down into empty request + # index. 
+ self.batch_update_builder.pop_removed() + self.batch_update_builder.moved.append( + (last_req_index, empty_index, + MoveDirectionality.UNIDIRECTIONAL)) req_id = self._req_ids[last_req_index] output_token_ids = self.req_output_token_ids[last_req_index] assert req_id is not None @@ -524,20 +560,14 @@ class InputBatch: empty_index] = self.presence_penalties_cpu[last_req_index] self.repetition_penalties_cpu[ empty_index] = self.repetition_penalties_cpu[last_req_index] - self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index] generator = self.generators.pop(last_req_index, None) if generator is not None: self.generators[empty_index] = generator - min_token = self.min_tokens.pop(last_req_index, None) - if min_token is not None: - self.min_tokens[empty_index] = min_token - self.request_lora_mapping[empty_index] = self.request_lora_mapping[ last_req_index] - self.logit_bias[empty_index] = self.logit_bias[last_req_index] - + # TODO convert these to LogitsProcessors if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[ empty_index] = self.allowed_token_ids_mask_cpu_tensor[ @@ -547,6 +577,7 @@ class InputBatch: last_req_index, None) if bad_words_token_ids is not None: self.bad_words_token_ids[empty_index] = bad_words_token_ids + # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -554,8 +585,17 @@ class InputBatch: del self._req_ids[self.num_reqs:] del self.req_output_token_ids[self.num_reqs:] - def refresh_sampling_metadata(self): - self.sampling_metadata = self._make_sampling_metadata() + def refresh_metadata(self): + """Apply batch updates, reset input batch at end of step + + * Apply batch add/remove/permute to logits procs' states + * If batch state is modified, update sampling metadata + """ + batch_update = self.batch_update_builder.get_and_reset(self.num_reqs) + for logit_proc in self.logitsprocs.all: + logit_proc.update_state(batch_update) + if batch_update: + self.sampling_metadata = self._make_sampling_metadata() def _make_sampling_metadata(self) -> SamplingMetadata: num_reqs = self.num_reqs @@ -568,8 +608,6 @@ class InputBatch: copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) if not self.no_top_k: copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) - if not self.no_min_p: - copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs) if not self.no_penalties: # Since syncing these tensors is expensive only copy them @@ -607,7 +645,6 @@ class InputBatch: all_random=self.all_random, top_p=None if self.no_top_p else self.top_p[:num_reqs], top_k=None if self.no_top_k else self.top_k[:num_reqs], - min_p=None if self.no_min_p else self.min_p[:num_reqs], generators=self.generators, max_num_logprobs=self.max_num_logprobs, prompt_token_ids=prompt_token_ids, @@ -615,11 +652,10 @@ class InputBatch: presence_penalties=self.presence_penalties[:num_reqs], repetition_penalties=self.repetition_penalties[:num_reqs], output_token_ids=cast(list[list[int]], self.req_output_token_ids), - min_tokens=self.min_tokens, no_penalties=self.no_penalties, - logit_bias=self.logit_bias[:num_reqs], allowed_token_ids_mask=allowed_token_ids_mask, bad_words_token_ids=self.bad_words_token_ids, + logitsprocs=self.logitsprocs, ) @property @@ -702,10 +738,6 @@ class InputBatch: def no_top_k(self) -> bool: return len(self.top_k_reqs) == 0 - @property - def no_min_p(self) -> bool: - return len(self.min_p_reqs) == 0 - @property def no_penalties(self) -> bool: return (len(self.presence_penalties_reqs) == 0 diff --git a/vllm/v1/worker/gpu_model_runner.py 
b/vllm/v1/worker/gpu_model_runner.py index df9d69006fc57..4786d047acb5a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -63,12 +63,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from ..sample.logits_processor import LogitsProcessorManager from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -212,6 +212,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], + is_spec_decode=bool(self.vllm_config.speculative_config), ) self.use_cuda_graph = ( @@ -316,7 +317,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} - def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention backend's needs. For example, some attention backends (namely MLA) may @@ -325,21 +326,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): Args: scheduler_output: The scheduler output. - - Returns: - True if the batch was reordered, False otherwise. """ - batch_reordered = self.attn_metadata_builders[0].reorder_batch( - self.input_batch, scheduler_output) + self.attn_metadata_builders[0].reorder_batch(self.input_batch, + scheduler_output) # For models with multiple KV cache groups, the groups should agree on # the same order of requests. We ensure this by only allowing the first # group to reorder the batch and asserting that all other groups do not # reorder the batch. for i in range(1, len(self.kv_cache_config.kv_cache_groups)): - assert not self.attn_metadata_builders[i].reorder_batch( + batch_reordered = self.attn_metadata_builders[i].reorder_batch( self.input_batch, scheduler_output) - return batch_reordered + assert not batch_reordered # Note: used for model runner override. def _init_device_properties(self) -> None: @@ -372,11 +370,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # then resubmitted with the same ID. In this case, we treat them as two # distinct requests - clearing the cached states for the first request # and handling the second as a new request. - removed_req_indices: list[int] = [] for req_id in scheduler_output.finished_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) + self.input_batch.remove_request(req_id) # Free the cached encoder outputs. for req_id, input_id in scheduler_output.free_encoder_input_ids: @@ -399,9 +394,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # have low request overlap (e.g., alternating between two distinct # sets of requests), this optimization becomes very inefficient. 
for req_id in unscheduled_req_ids: - req_index = self.input_batch.remove_request(req_id) - assert req_index is not None - removed_req_indices.append(req_index) + self.input_batch.remove_request(req_id) req_ids_to_add: list[str] = [] # Add new requests to the cached states. @@ -545,31 +538,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): # NOTE(woosuk): `num_tokens` here may include spec tokens. self.input_batch.num_tokens[req_index] = end_token_index - # Check if the batch has changed. If not, we can skip copying the - # sampling metadata from CPU to GPU. - batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0 - # Add the new or resumed requests to the persistent batch. # The smaller empty indices are filled first. - removed_req_indices.sort(reverse=True) for req_id in req_ids_to_add: req_state = self.requests[req_id] - if removed_req_indices: - # Fill the empty index. - req_index = removed_req_indices.pop() - else: - # Append to the end. - req_index = None - self.input_batch.add_request(req_state, req_index) + self.input_batch.add_request(req_state) - # Condense the batched states if there are empty indices. - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - - batch_reordered = self._may_reorder_batch(scheduler_output) - - if batch_changed or batch_reordered: - self.input_batch.refresh_sampling_metadata() + # Condense the batched states if there are gaps left by removed requests + self.input_batch.condense() + # Allow attention backend to reorder the batch, potentially + self._may_reorder_batch(scheduler_output) + # Refresh batch metadata with any pending updates. + self.input_batch.refresh_metadata() def _get_cumsum_and_arange( self, @@ -1296,7 +1276,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, IntermediateTensors]: - self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: if not has_kv_transfer_group(): @@ -1765,7 +1744,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Skip requests that require sampling parameters that are not # supported with speculative decoding. req_id = self.input_batch.req_ids[i] - if not is_spec_decode_supported(req_id, self.input_batch): + if req_id in self.input_batch.spec_decode_unsupported_reqs: draft_token_ids.append([]) continue @@ -2121,7 +2100,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): all_random=False, top_p=dummy_tensors(0.9), top_k=dummy_tensors(logits.size(1) - 1), - min_p=None, generators={}, max_num_logprobs=None, no_penalties=True, @@ -2130,10 +2108,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): presence_penalties=dummy_tensors(0.1), repetition_penalties=dummy_tensors(0.1), output_token_ids=[[] for _ in range(num_reqs)], - min_tokens={}, - logit_bias=[None for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, + logitsprocs=LogitsProcessorManager(), ) try: sampler_output = self.sampler(logits=logits, @@ -2425,6 +2402,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, + is_spec_decode=bool(self.vllm_config.speculative_config), ) def _allocate_kv_cache_tensors(
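Editorial note (not part of the patch): with removed_req_indices gone, the runner's per-step batch maintenance collapses to a fixed call sequence against InputBatch, and all logits-processor state is reconciled in one place at the end of the step. The helper below is a condensed, hypothetical sketch (names of its parameters are invented) only to make that ordering explicit; it mirrors the _update_states hunks above rather than reproducing them.

def run_batch_update_step(input_batch, finished_req_ids, req_states_to_add):
    """Per-step ordering implied by the hunks above (sketch only)."""
    # 1. Removals record their batch index in the BatchUpdateBuilder.
    for req_id in finished_req_ids:
        input_batch.remove_request(req_id)
    # 2. Adds reuse a previously freed index when one is available,
    #    otherwise they append at the end of the batch.
    for req_state in req_states_to_add:
        input_batch.add_request(req_state)
    # 3. Remaining gaps are filled by condense(), recorded as
    #    UNIDIRECTIONAL moves.
    input_batch.condense()
    # 4. (Attention backends may further reorder the batch here; swaps are
    #    recorded via swap_states() as SWAP moves.)
    # 5. Every logits processor sees the accumulated BatchUpdate (or None),
    #    and sampling metadata is rebuilt only if the batch actually changed.
    input_batch.refresh_metadata()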