diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor.py
index 7ef20efa7d28c..3e122319169eb 100644
--- a/examples/offline_inference/logits_processor.py
+++ b/examples/offline_inference/logits_processor.py
@@ -42,8 +42,8 @@
 from vllm.config import VllmConfig
 from vllm.v1.sample.logits_processor import (
     BatchUpdate,
     LogitsProcessor,
-    MoveDirectionality,
 )
+from vllm.v1.sample.logits_processor.builtin import process_dict_updates
 
 # Hypothetical custom logits processor
@@ -53,38 +53,22 @@ class DummyLogitsProcessor(LogitsProcessor):
     def __init__(
         self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
     ):
-        self.req_info: dict[int, SamplingParams] = {}
+        self.req_info: dict[int, int] = {}
 
     def is_argmax_invariant(self) -> bool:
         """Never impacts greedy sampling"""
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if not batch_update:
-            return
-
-        # Process added requests.
-        for index, params, _, _ in batch_update.added:
-            assert params is not None
-            if params.extra_args and (
-                target_token := params.extra_args.get("target_token")
-            ):
-                self.req_info[index] = target_token
-
-        if self.req_info:
-            # Process removed requests.
-            for index in batch_update.removed:
-                self.req_info.pop(index, None)
-
-            # Process moved requests, unidirectional move (a->b) and swap
-            # (a<->b)
-            for adx, bdx, direct in batch_update.moved:
-                a_val = self.req_info.pop(adx, None)
-                b_val = self.req_info.pop(bdx, None)
-                if a_val is not None:
-                    self.req_info[bdx] = a_val
-                if direct == MoveDirectionality.SWAP and b_val is not None:
-                    self.req_info[adx] = b_val
+        process_dict_updates(
+            self.req_info,
+            batch_update,
+            # This function returns the logits processor's per-request state
+            # given the request details, or None if the processor does not
+            # apply to the request. `or {}` guards against storing state for
+            # a request whose extra_args dict is present but empty.
+            lambda params, _, __: (params.extra_args or {}).get("target_token"),
+        )
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
         if not self.req_info:
diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py
index c0bfc1a18feca..c36f1bd021c70 100644
--- a/tests/v1/logits_processors/utils.py
+++ b/tests/v1/logits_processors/utils.py
@@ -8,10 +8,9 @@ from typing import Optional
 import torch
 
 from vllm.config import VllmConfig
-from vllm.sampling_params import SamplingParams
 from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
-                                             LogitsProcessor,
-                                             MoveDirectionality)
+                                             LogitsProcessor)
+from vllm.v1.sample.logits_processor.builtin import process_dict_updates
 
 MODEL_NAME = "facebook/opt-125m"
 POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
@@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor):
 
     def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                  is_pin_memory: bool):
-        self.req_info: dict[int, SamplingParams] = {}
+        self.req_info: dict[int, int] = {}
 
     def is_argmax_invariant(self) -> bool:
         """Never impacts greedy sampling"""
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if not batch_update:
-            return
-
-        # Process added requests.
-        for index, params, _, _ in batch_update.added:
-            assert params is not None
-            if params.extra_args and (target_token :=
-                                      params.extra_args.get("target_token")):
-                self.req_info[index] = target_token
-
-        if self.req_info:
-            # Process removed requests.
-            for index in batch_update.removed:
-                self.req_info.pop(index, None)
-
-            # Process moved requests, unidirectional move (a->b) and swap
-            # (a<->b)
-            for adx, bdx, direct in batch_update.moved:
-                a_val = self.req_info.pop(adx, None)
-                b_val = self.req_info.pop(bdx, None)
-                if a_val is not None:
-                    self.req_info[bdx] = a_val
-                if direct == MoveDirectionality.SWAP and b_val is not None:
-                    self.req_info[adx] = b_val
+        process_dict_updates(
+            self.req_info,
+            batch_update,
+            lambda params, _, __:
+            (params.extra_args or {}).get("target_token"),
+        )
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
         if not self.req_info:
diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py
index 00dd757489ca0..60f9c0bdb6313 100644
--- a/vllm/v1/sample/logits_processor/builtin.py
+++ b/vllm/v1/sample/logits_processor/builtin.py
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar
 
 import torch
 
+from vllm import SamplingParams
 from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
                                                        LogitsProcessor,
                                                        MoveDirectionality)
@@ -12,6 +13,8 @@ from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 
+T = TypeVar("T")
+
 
 class MinPLogitsProcessor(LogitsProcessor):
 
@@ -130,49 +133,15 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if not batch_update:
-            return
-
-        needs_update: bool = False
-        # Process added requests.
-        for index, params, _, _ in batch_update.added:
-            if lb := params.logit_bias:
-                self.biases[index] = lb
-                needs_update = True
-            else:
-                # Drop biases metadata at batch index
-                if self.biases.pop(index, None) is not None:
-                    # If a new request replaces an old request which
-                    # specified biases, we should update processor tensors
-                    needs_update = True
-
-        if self.biases:
-            # Process removed requests.
-            for index in batch_update.removed:
-                if self.biases.pop(index, None):
-                    needs_update = True
-
-            # Process moved requests, unidirectional (a->b) and swap (a<->b)
-            for a_index, b_index, direct in batch_update.moved:
-                if direct == MoveDirectionality.UNIDIRECTIONAL:
-                    if (a_entry := self.biases.pop(a_index, None)) is None:
-                        if self.biases.pop(b_index, None) is not None:
-                            needs_update = True
-                    else:
-                        self.biases[b_index] = a_entry
-                        needs_update = True
-                else:
-                    a_entry = self.biases.pop(a_index, None)
-                    if (b_entry := self.biases.pop(b_index, None)) is not None:
-                        self.biases[a_index] = b_entry
-                        needs_update = True
-                    if a_entry is not None:
-                        self.biases[b_index] = a_entry
-                        needs_update = True
+        needs_update = process_dict_updates(
+            self.biases, batch_update,
+            lambda params, _, __: params.logit_bias or None)
 
         # Update tensors if needed.
         if needs_update:
-            reqs, tok_ids, biases = [], [], []
+            reqs: list[int] = []
+            tok_ids: list[int] = []
+            biases: list[float] = []
             for req, lb in self.biases.items():
                 reqs.extend([req] * len(lb))
                 tok_ids.extend(lb.keys())
@@ -216,52 +185,18 @@ class MinTokensLogitsProcessor(LogitsProcessor):
         of the argmax operation in greedy sampling."""
         return False
 
+    @staticmethod
+    def add_request(
+        params: SamplingParams, _: list[int], output_tok_ids: list[int]
+    ) -> Optional[tuple[int, Sequence[int], set[int]]]:
+        min_tokens = params.min_tokens
+        if not min_tokens or len(output_tok_ids) >= min_tokens:
+            return None
+        return min_tokens, output_tok_ids, params.all_stop_token_ids
+
     def update_state(self, batch_update: Optional[BatchUpdate]):
-        needs_update = False
-
-        if batch_update:
-            # Process added requests.
-            for index, params, _, output_tok_ids in batch_update.added:
-                if ((min_tokens := params.min_tokens)
-                        and len(output_tok_ids) < min_tokens):
-                    # Replace request metadata at batch index
-                    self.min_toks[index] = (min_tokens, output_tok_ids,
-                                            params.all_stop_token_ids)
-                    needs_update = True
-                else:
-                    # Drop min_toks metadata at batch index
-                    if self.min_toks.pop(index, None) is not None:
-                        # If a new request replaces an old request which
-                        # specified min_toks, we should update processor tensors
-                        needs_update = True
-
-            if self.min_toks:
-                # Process removed requests.
-                for index in batch_update.removed:
-                    if self.min_toks.pop(index, None):
-                        needs_update = True
-
-                # Process moved requests, unidirectional (a->b) and
-                # swapped (a<->b)
-                for a_index, b_index, direct in batch_update.moved:
-                    if direct == MoveDirectionality.UNIDIRECTIONAL:
-                        if (a_entry := self.min_toks.pop(a_index,
-                                                         None)) is None:
-                            if self.min_toks.pop(b_index, None) is not None:
-                                needs_update = True
-                        else:
-                            self.min_toks[b_index] = a_entry
-                            needs_update = True
-                    else:
-                        a_entry = self.min_toks.pop(a_index, None)
-                        if (b_entry := self.min_toks.pop(b_index,
-                                                         None)) is not None:
-                            self.min_toks[a_index] = b_entry
-                            needs_update = True
-                        if a_entry is not None:
-                            self.min_toks[b_index] = a_entry
-                            needs_update = True
-
+        needs_update = process_dict_updates(self.min_toks, batch_update,
+                                            self.add_request)
         if self.min_toks:
             # Check for any requests that have attained their min tokens.
             to_remove = tuple(index for index, (min_toks, out_tok_ids,
@@ -295,3 +230,56 @@
         # Inhibit EOS token for requests which have not reached min length
         logits[self.logits_slice] = -float("inf")
         return logits
+
+
+def process_dict_updates(
+    req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
+    new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]]
+) -> bool:
+    """Utility function to update dict state for sparse LogitsProcessors.
+
+    Args:
+        req_entries: the logits processor's per-request state, keyed by
+            persistent batch index.
+        batch_update: the current persistent batch state change, if any.
+        new_state: maps a request's (params, prompt token ids, output token
+            ids) to the processor's state for that request, or None if the
+            processor does not apply to the request.
+
+    Returns:
+        True if the contents of `req_entries` changed.
+    """
+
+    if not batch_update:
+        # Nothing to do.
+        return False
+
+    updated = False
+    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+        if (state := new_state(params, prompt_tok_ids,
+                               output_tok_ids)) is not None:
+            req_entries[index] = state
+            updated = True
+        elif req_entries.pop(index, None) is not None:
+            updated = True
+
+    if req_entries:
+        # Process removed requests.
+        for index in batch_update.removed:
+            if req_entries.pop(index, None) is not None:
+                updated = True
+
+        # Process moved requests, unidirectional (a->b) and
+        # swapped (a<->b)
+        for a_index, b_index, direct in batch_update.moved:
+            a_entry = req_entries.pop(a_index, None)
+            b_entry = req_entries.pop(b_index, None)
+            if a_entry is not None:
+                req_entries[b_index] = a_entry
+                updated = True
+            if b_entry is not None:
+                updated = True
+                if direct == MoveDirectionality.SWAP:
+                    req_entries[a_index] = b_entry
+
+    return updated
diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py
index 12b4db24bff88..16cd00943db8d 100644
--- a/vllm/v1/sample/logits_processor/interface.py
+++ b/vllm/v1/sample/logits_processor/interface.py
@@ -44,10 +44,16 @@ class BatchUpdate:
     # Key assumption: the `output_tok_ids` list (which is an element of each
     # tuple in `added`) is a reference to the request's running output tokens
     # list; via this reference, the logits processors always see the latest
-    # list of generated output tokens
+    # list of generated output tokens.
+    #
+    # NOTE:
+    # * Added or moved requests may replace existing requests at the same
+    #   index.
+    # * Operations should be processed in the following order:
+    #   removed, then added, then moved.
     removed: Sequence[RemovedRequest]
-    moved: Sequence[MovedRequest]
     added: Sequence[AddedRequest]
+    moved: Sequence[MovedRequest]
 
 
 class LogitsProcessor(ABC):
@@ -59,6 +65,11 @@ class LogitsProcessor(ABC):
 
     @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        """Apply the logits processor to the batch logits tensor.
+
+        The updated tensor must be returned; the update may be
+        performed in-place on the input tensor.
+        """
        raise NotImplementedError
 
     @abstractmethod
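Note for reviewers: the snippet below is a minimal sketch, not part of the patch, illustrating the semantics of the new process_dict_updates helper. The import paths are the ones introduced above. Because the helper only forwards params to the callback and only reads the removed/added/moved fields of the batch update, a SimpleNamespace with an extra_args attribute stands in for a real SamplingParams, and _StubBatchUpdate is a hypothetical stand-in for BatchUpdate (the real class carries additional fields not needed here).

    from dataclasses import dataclass, field
    from types import SimpleNamespace

    from vllm.v1.sample.logits_processor.builtin import process_dict_updates
    from vllm.v1.sample.logits_processor.interface import MoveDirectionality


    @dataclass
    class _StubBatchUpdate:
        # Hypothetical stand-in: the helper only reads these three fields.
        removed: list = field(default_factory=list)
        added: list = field(default_factory=list)
        moved: list = field(default_factory=list)


    state: dict[int, int] = {}

    # Added requests are (index, params, prompt_tok_ids, output_tok_ids)
    # tuples. Index 0 opts in via extra_args; index 1 does not, so only
    # index 0 receives per-request state.
    added = [
        (0, SimpleNamespace(extra_args={"target_token": 128}), [], []),
        (1, SimpleNamespace(extra_args=None), [], []),
    ]
    assert process_dict_updates(
        state, _StubBatchUpdate(added=added),
        lambda params, _, __: (params.extra_args or {}).get("target_token"))
    assert state == {0: 128}

    # A unidirectional move 0 -> 1 carries the state to the new index and
    # implicitly discards whatever was previously held at index 1.
    moved = [(0, 1, MoveDirectionality.UNIDIRECTIONAL)]
    assert process_dict_updates(state, _StubBatchUpdate(moved=moved),
                                lambda *_: None)
    assert state == {1: 128}

    # Removing the request drops its state; the helper reports that an
    # update (e.g. a tensor rebuild) is needed.
    assert process_dict_updates(state, _StubBatchUpdate(removed=[1]),
                                lambda *_: None)
    assert state == {}

The boolean return value is what lets processors such as LogitBiasLogitsProcessor and MinTokensLogitsProcessor decide whether to rebuild their device tensors, which is why moves and removals report True even when they only discard state.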