From 136d853e65a91a21b08227217b51daaba2d5cc71 Mon Sep 17 00:00:00 2001
From: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Date: Tue, 2 Sep 2025 22:52:51 -0400
Subject: [PATCH] [V1] Wrapper which plumbs request-level logits processors
 into vLLM batch-level logits processing (#23656)

Signed-off-by: Andrew Feldman
---
 .../custom.py}                                |   0
 .../logits_processor/custom_req.py            | 151 ++++++++++++++++
 .../logits_processor/custom_req_init.py       | 165 ++++++++++++++++++
 .../logits_processors/test_custom_offline.py  |  33 ++++
 tests/v1/logits_processors/utils.py           |  67 ++++++-
 vllm/v1/sample/logits_processor/__init__.py   | 113 +++++++++++-
 6 files changed, 524 insertions(+), 5 deletions(-)
 rename examples/offline_inference/{logits_processor.py => logits_processor/custom.py} (100%)
 create mode 100644 examples/offline_inference/logits_processor/custom_req.py
 create mode 100644 examples/offline_inference/logits_processor/custom_req_init.py

diff --git a/examples/offline_inference/logits_processor.py b/examples/offline_inference/logits_processor/custom.py
similarity index 100%
rename from examples/offline_inference/logits_processor.py
rename to examples/offline_inference/logits_processor/custom.py
diff --git a/examples/offline_inference/logits_processor/custom_req.py b/examples/offline_inference/logits_processor/custom_req.py
new file mode 100644
index 000000000000..4c19bb4ce2ba
--- /dev/null
+++ b/examples/offline_inference/logits_processor/custom_req.py
@@ -0,0 +1,151 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""This example demonstrates wrapping a request-level logits processor to be
+compatible with vLLM's batch-level logits processing.
+
+For demo purposes, a dummy logits processor is employed which, if
+`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
+will mask out all tokens except `target_token`. This logits processor can be
+applied to a vector of logits associated with a single decode step for a single
+request. The logits processor cannot be applied to a request which does not
+pass in a `target_token` custom argument.
+
+The request-level dummy logits processor is wrapped to create a batch-level
+logits processor, which can apply the logits processor to output logits from
+all requests in the persistent batch in a given decode step. For requests which
+do not provide a `target_token` argument, the corresponding row of `logits`
+will not be modified.
+
+A batch is constructed with `temperature=0.0` and 50% of requests specifying
+`target_token`, and for these requests - and *only* these requests - we
+expect the `target_token` to be decoded in each step, yielding an output
+similar to that shown below:
+
+Generated Outputs:
+------------------------------------------------------------
+Prompt: 'Hello, my name is'
+Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
+------------------------------------------------------------
+Prompt: 'The president of the United States is'
+Output: " not a racist. He is a racist.\nHe's a racist because he"
+------------------------------------------------------------
+Prompt: 'The capital of France is'
+Output: ' also also also also also also also also also also also also also
+ also also also'
+------------------------------------------------------------
+Prompt: 'The future of AI is'
+Output: ' in the hands of the people.\n\nThe future of AI is in the'
+------------------------------------------------------------
+"""
+
+from typing import Any, Optional
+
+import torch
+
+from vllm import LLM, SamplingParams
+from vllm.logger import init_logger
+from vllm.v1.sample.logits_processor import (
+    AdapterLogitsProcessor,
+    RequestLogitsProcessor,
+)
+
+logger = init_logger(__name__)
+
+
+class DummyPerReqLogitsProcessor:
+    """The request-level logits processor masks out all logits except the
+    token id identified by `target_token`"""
+
+    def __init__(self, target_token: int) -> None:
+        """Specify `target_token`"""
+        self.target_token = target_token
+
+    def __call__(
+        self,
+        output_ids: list[int],
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        val_to_keep = logits[self.target_token].item()
+        logits[:] = float("-inf")
+        logits[self.target_token] = val_to_keep
+        return logits
+
+
+class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
+    """Example of wrapping a fake request-level logits processor to create a
+    batch-level logits processor"""
+
+    def is_argmax_invariant(self) -> bool:
+        return False
+
+    def new_req_logits_processor(
+        self,
+        params: SamplingParams,
+    ) -> Optional[RequestLogitsProcessor]:
+        """This method returns a new request-level logits processor, customized
+        to the `target_token` value associated with a particular request.
+
+        Returns None if the logits processor should not be applied to the
+        particular request. To use the logits processor the request must have
+        a "target_token" custom argument with an integer value.
+
+        Args:
+            params: per-request sampling params
+
+        Returns:
+            `Callable` request logits processor, or None
+        """
+        target_token: Optional[Any] = params.extra_args and params.extra_args.get(
+            "target_token"
+        )
+        if target_token is None:
+            return None
+        if not isinstance(target_token, int):
+            logger.warning(
+                "target_token value %s is not int; not applying logits"
+                " processor to request.",
+                target_token,
+            )
+            return None
+        return DummyPerReqLogitsProcessor(target_token)
+
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a mixture of requests which do and don't utilize the dummy logitproc
+sampling_params_list = [
+    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
+    SamplingParams(temperature=0.0),
+    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
+    SamplingParams(temperature=0.0),
+]
+
+
+def main():
+    # Create an LLM.
+    llm = LLM(
+        model="facebook/opt-125m",
+        logits_processors=[WrappedPerReqLogitsProcessor],
+    )
+    # Generate texts from the prompts.
+    # The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params_list)
+    # Print the outputs.
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/logits_processor/custom_req_init.py b/examples/offline_inference/logits_processor/custom_req_init.py new file mode 100644 index 000000000000..62947d122e01 --- /dev/null +++ b/examples/offline_inference/logits_processor/custom_req_init.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""This example demonstrates a special case of wrapping a request-level logits +processor, namely the case where it is necessary to utilize engine config or +environment info passed to the constructor. The subclass must override the +wrapper base class `__init__()` method to access the engine config, the device +identifier, or the flag which indicates whether pinned memory is available. + +For demo purposes, a request-level dummy logits processor is employed which +causes the same token (`target_token`) to be decoded in each step. The +request-level dummy logits processor is wrapped to create a batch-level logits +processor, which can apply the logits processor to output logits from all +requests in the persistent batch in a given decode step. + +The wrapped dummy logits processor below models a scenario where we must +disable the logits processor on non-"cuda" platforms. The wrapper base class +`__init__()` is overridden in order to check this condition and set a flag. + +A batch is constructed with `temperature=0.0` and 50% of requests specifying +`target_token`, and for these requests - and *only* these requests - we +expect that on a "cuda" device the output will look something like: + +Generated Outputs: +------------------------------------------------------------ +Prompt: 'Hello, my name is' +Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '" +------------------------------------------------------------ +Prompt: 'The president of the United States is' +Output: " not a racist. He is a racist.\nHe's a racist because he" +------------------------------------------------------------ +Prompt: 'The capital of France is' +Output: ' also also also also also also also also also also also also also + also also also' +------------------------------------------------------------ +Prompt: 'The future of AI is' +Output: ' in the hands of the people.\n\nThe future of AI is in the' +------------------------------------------------------------ + +which indicates that the logits processor is running. However, on a non-"cuda" +device, the first and third requests would not repeat the same token. 
+""" + +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.sample.logits_processor import ( + AdapterLogitsProcessor, + RequestLogitsProcessor, +) + +logger = init_logger(__name__) + + +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, + output_ids: list[int], + logits: torch.Tensor, + ) -> torch.Tensor: + val_to_keep = logits[self.target_token].item() + logits[:] = float("-inf") + logits[self.target_token] = val_to_keep + return logits + + +class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor): + """Example of overriding the wrapper class `__init__()` in order to utilize + info about the device type""" + + def __init__( + self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool + ): + super().__init__(vllm_config, device, is_pin_memory) + self.is_cuda = device.type == "cuda" + + def is_argmax_invariant(self) -> bool: + return False + + def new_req_logits_processor( + self, + params: SamplingParams, + ) -> Optional[RequestLogitsProcessor]: + """This method returns a new request-level logits processor, customized + to the `target_token` value associated with a particular request. + + Returns None if the logits processor should not be applied to the + particular request. To use the logits processor the request must have + a "target_token" custom argument with an integer value, and the device + must be "cuda"-type + + Args: + params: per-request sampling params + + Returns: + `Callable` request logits processor, or None + """ + if ( + not self.is_cuda + or ( + target_token := params.extra_args + and params.extra_args.get("target_token") + ) + is None + ): + return None + if not isinstance(target_token, int): + logger.warning( + "target_token value %s is not int; not applying logits" + " processor to request.", + target_token, + ) + return None + return DummyPerReqLogitsProcessor(target_token) + + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a mixture of requests which do and don't utilize the dummy logitproc +sampling_params_list = [ + SamplingParams(temperature=0.0, extra_args={"target_token": 128}), + SamplingParams(temperature=0.0), + SamplingParams(temperature=0.0, extra_args={"target_token": 67}), + SamplingParams(temperature=0.0), +] + + +def main(): + # Create an LLM. + llm = LLM( + model="facebook/opt-125m", + logits_processors=[WrappedPerReqLogitsProcessor], + ) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params_list) + # Print the outputs. 
+ print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +if __name__ == "__main__": + main() diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py index a7fde1990f7e..97d96b129ae9 100644 --- a/tests/v1/logits_processors/test_custom_offline.py +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -15,6 +15,7 @@ from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG, POOLING_MODEL_NAME, TEMP_GREEDY, CustomLogitprocSource, DummyLogitsProcessor, + WrappedPerReqLogitsProcessor, dummy_module) from tests.v1.logits_processors.utils import entry_points as fake_entry_points from tests.v1.logits_processors.utils import prompts @@ -161,6 +162,38 @@ def test_custom_logitsprocs(monkeypatch, _run_test(kwargs, logitproc_loaded=True) +@create_new_process_for_each_test() +def test_custom_logitsprocs_req(monkeypatch): + """Test passing request-level logits processor to offline Python interface + + Wrap a request-level logits processor to create a batch level logits + processor that has a well-defined behavior (mask out all tokens except one + `target_token`) + + Construct an `LLM` instance which loads the wrapped logits processor. Pass + the custom logitproc as a class object. + + Construct a reference `LLM` instance with no custom logitproc + + Pass in a batch of requests, 50% of which pass a `target_token` value + in through `SamplingParams.extra_args`, 50% of which do not. + + Validate that + * Requests which do not activate the custom logitproc, yield the same + results for both `LLM` instances + * Requests which activate the custom logitproc, only output `target_token` + + Args: + monkeypatch: for setting env vars + """ + + # Test that logitproc info is passed to workers + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1") + random.seed(40) + _run_test({"logits_processors": [WrappedPerReqLogitsProcessor]}, + logitproc_loaded=True) + + @create_new_process_for_each_test() @pytest.mark.parametrize("logitproc_source", [ CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT, diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index c36f1bd021c7..7ec35bd3eb63 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -3,15 +3,21 @@ import types from enum import Enum, auto -from typing import Optional +from typing import Any, Optional import torch from vllm.config import VllmConfig -from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate, - LogitsProcessor) +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, + AdapterLogitsProcessor, + BatchUpdate, LogitsProcessor, + RequestLogitsProcessor) from vllm.v1.sample.logits_processor.builtin import process_dict_updates +logger = init_logger(__name__) + MODEL_NAME = "facebook/opt-125m" POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5" DUMMY_LOGITPROC_ARG = "target_token" @@ -104,5 +110,60 @@ class EntryPoints(list): self.names = [ep.name for ep in eps] +class DummyPerReqLogitsProcessor: + """The request-level logits processor masks out all logits except the + token id identified by `target_token`""" + + def __init__(self, target_token: int) -> None: + """Specify `target_token`""" + self.target_token = target_token + + def __call__( + self, 
+        output_ids: list[int],
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        val_to_keep = logits[self.target_token].item()
+        logits[:] = float("-inf")
+        logits[self.target_token] = val_to_keep
+        return logits
+
+
+class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
+    """Example of wrapping a fake request-level logits processor to create a
+    batch-level logits processor"""
+
+    def is_argmax_invariant(self) -> bool:
+        return False
+
+    def new_req_logits_processor(
+        self,
+        params: SamplingParams,
+    ) -> Optional[RequestLogitsProcessor]:
+        """This method returns a new request-level logits processor, customized
+        to the `target_token` value associated with a particular request.
+
+        Returns None if the logits processor should not be applied to the
+        particular request. To use the logits processor the request must have
+        a "target_token" custom argument with an integer value.
+
+        Args:
+            params: per-request sampling params
+
+        Returns:
+            `Callable` request logits processor, or None
+        """
+        target_token: Optional[
+            Any] = params.extra_args and params.extra_args.get("target_token")
+        if target_token is None:
+            return None
+        if not isinstance(target_token, int):
+            logger.warning(
+                "target_token value %s is not int; not applying logits"
+                " processor to request.", target_token)
+            return None
+        return DummyPerReqLogitsProcessor(target_token)
+
+
 """Fake version of importlib.metadata.entry_points"""
 entry_points = lambda group: EntryPoints(group)
diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py
index 822026916295..a5f1cadd8524 100644
--- a/vllm/v1/sample/logits_processor/__init__.py
+++ b/vllm/v1/sample/logits_processor/__init__.py
@@ -1,16 +1,22 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import importlib
+import inspect
 import itertools
+from abc import abstractmethod
 from collections.abc import Sequence
+from functools import partial
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
 from vllm.logger import init_logger
+from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
+from vllm.sampling_params import SamplingParams
 from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
                                                      MinPLogitsProcessor,
-                                                     MinTokensLogitsProcessor)
+                                                     MinTokensLogitsProcessor,
+                                                     process_dict_updates)
 from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
                                                        LogitsProcessor,
                                                        MoveDirectionality)
@@ -177,9 +183,112 @@ def build_logitsprocs(
         BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
 
 
+class AdapterLogitsProcessor(LogitsProcessor):
+    """Wrapper for per-request logits processors
+
+    To wrap a specific per-request logits processor,
+    * Subclass `AdapterLogitsProcessor`
+    * Implement the `self.is_argmax_invariant()` base-class method
+    * Implement `self.new_req_logits_processor(params)`
+
+    `self.__init__(vllm_config, device, is_pin_memory)` does not need to be
+    overridden in general. However, to implement custom constructor behavior -
+    especially any logic which operates on or stores `vllm_config`, `device`,
+    or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)`
+    must be overridden and the override must call
+    `super().__init__(vllm_config, device, is_pin_memory)`.
+    """
+
+    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
+                 is_pin_memory: bool):
+        """Subclass must invoke
+        `super().__init__(vllm_config, device, is_pin_memory)`.
+
+        Subclass constructor may find it useful to utilize the `vllm_config`,
+        `device` and `is_pin_memory` arguments. However, regardless of whether
+        these arguments are used, the vLLM logits processor interface requires
+        all three arguments to be present.
+        """
+
+        # Map req index -> logits processor state
+        #
+        # State representation is a partial[Tensor] comprising a request-level
+        # logits processor with the output token ids argument and (if required)
+        # the prompt token ids argument pre-populated
+        #
+        # Note that the partial carries a *reference* to output token ids, and
+        # will thus always operate on the list as it is currently, not as it
+        # was when the partial was created.
+        self.req_info: dict[int, partial[torch.Tensor]] = {}
+
+    @abstractmethod
+    def new_req_logits_processor(
+        self,
+        params: SamplingParams,
+    ) -> Optional[RequestLogitsProcessor]:
+        """Consume request info; return a per-request logits processor.
+
+        Return None if the logits processor does not need to be applied to
+        the request.
+
+        Args:
+            params: request sampling params
+
+        Returns:
+            None if the logits processor should not be applied to the request;
+            otherwise a `RequestLogitsProcessor` instance
+
+        """
+        raise NotImplementedError
+
+    def _new_state(
+        self,
+        params: SamplingParams,
+        prompt_ids: list[int],
+        output_ids: list[int],
+    ) -> Optional[partial[torch.Tensor]]:
+        """Return the state representation for a new request
+
+        Returns None if the logits processor is not applicable to the request
+
+        Args:
+            params: request sampling params
+            prompt_ids: request prompt token ids
+            output_ids: decoded tokens so far for this request
+
+        Returns:
+            logits processor partial[Tensor] or None
+
+        """
+        if req_lp := self.new_req_logits_processor(params):
+            args = [prompt_ids, output_ids] if (len(
+                inspect.signature(req_lp).parameters) == 3) else [output_ids]
+            return partial(req_lp, *args)
+        return None
+
+    def update_state(self, batch_update: Optional[BatchUpdate]):
+        process_dict_updates(
+            self.req_info,
+            batch_update,
+            self._new_state,
+        )
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        if self.req_info:
+            # Apply per-request logits processors to corresponding rows of
+            # logits tensor
+            for req_idx, req_lp in self.req_info.items():
+                req_logits = logits[req_idx]
+                new_logits = req_lp(req_logits)
+                if new_logits is not req_logits:
+                    # Write the returned row back into the logits tensor
+                    logits[req_idx] = new_logits
+        return logits
+
+
 __all__ = [
     "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor",
     "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder",
     "MoveDirectionality", "LogitsProcessors", "build_logitsprocs",
-    "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP"
+    "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP",
+    "AdapterLogitsProcessor"
 ]
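
---- Reviewer note (not part of the patch) ----

The subtlest part of the diff above is the arity dispatch in
`AdapterLogitsProcessor._new_state`: a request-level processor declaring three
parameters is pre-bound with `(prompt_ids, output_ids)`, while a two-parameter
one is pre-bound with `(output_ids)` alone, and the bound `output_ids` list is
held by reference so tokens appended later remain visible. The sketch below is
a minimal restatement of just that mechanism; `two_arg_lp`, `three_arg_lp`,
and `bind` are illustrative names, not vLLM APIs.

import inspect
from functools import partial

import torch


def two_arg_lp(output_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
    # Request-level processor variant taking (output_ids, logits)
    return logits


def three_arg_lp(
    prompt_ids: list[int], output_ids: list[int], logits: torch.Tensor
) -> torch.Tensor:
    # Request-level processor variant taking (prompt_ids, output_ids, logits)
    return logits


def bind(req_lp, prompt_ids: list[int], output_ids: list[int]):
    # Mirrors _new_state: pre-bind prompt_ids only when the processor
    # declares three parameters
    n_params = len(inspect.signature(req_lp).parameters)
    args = [prompt_ids, output_ids] if n_params == 3 else [output_ids]
    return partial(req_lp, *args)


prompt_ids: list[int] = [1, 2, 3]
output_ids: list[int] = []
lp_a = bind(two_arg_lp, prompt_ids, output_ids)
lp_b = bind(three_arg_lp, prompt_ids, output_ids)

# The partials hold a *reference* to output_ids, so a token appended after
# binding is seen by the processor on the next decode step.
output_ids.append(42)
logits = torch.zeros(8)
assert lp_a(logits) is logits
assert lp_b(logits) is logits

Holding a reference rather than a snapshot is what lets the adapter build the
partial once per request instead of rebuilding it on every decode step.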