# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This example demonstrates a special case of wrapping a request-level
logits processor, namely the case where it is necessary to utilize engine
config or environment info passed to the constructor. The subclass must
override the wrapper base class `__init__()` method to access the engine
config, the device identifier, or the flag which indicates whether pinned
memory is available.

For demo purposes, a request-level dummy logits processor is employed which
causes the same token (`target_token`) to be decoded in each step. The
request-level dummy logits processor is wrapped to create a batch-level
logits processor, which can apply the logits processor to output logits from
all requests in the persistent batch in a given decode step.

The wrapped dummy logits processor below models a scenario where we must
disable the logits processor on non-"cuda" platforms. The wrapper base class
`__init__()` is overridden in order to check this condition and set a flag.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect that on a "cuda" device the output will look something like:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------

which indicates that the logits processor is running. However, on a
non-"cuda" device, the first and third requests would not repeat the same
token.
"""

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor,
    RequestLogitsProcessor,
)

logger = init_logger(__name__)


class DummyPerReqLogitsProcessor:
    """The request-level logits processor masks out all logits except the
    token id identified by `target_token`"""

    def __init__(self, target_token: int) -> None:
        """Specify `target_token`"""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        val_to_keep = logits[self.target_token].item()
        logits[:] = float("-inf")
        logits[self.target_token] = val_to_keep
        return logits
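
# A quick sanity check of the request-level processor in isolation (a hedged
# sketch; the tensor values below are illustrative only, not taken from the
# example run):
#
#     proc = DummyPerReqLogitsProcessor(target_token=1)
#     logits = torch.tensor([0.1, 0.5, 0.2])
#     proc(output_ids=[], logits=logits)  # -> tensor([-inf, 0.5000, -inf])
#
# Because every logit except `target_token` becomes -inf, greedy sampling
# (`temperature=0.0`) is forced to emit `target_token` at every decode step.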
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of overriding the wrapper class `__init__()` in order to
    utilize info about the device type"""

    @classmethod
    def validate_params(cls, params: SamplingParams):
        target_token = params.extra_args and params.extra_args.get("target_token")
        if target_token is not None and not isinstance(target_token, int):
            raise ValueError(
                f"`target_token` has to be an integer, got {target_token}."
            )

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        super().__init__(vllm_config, device, is_pin_memory)
        self.is_cuda = device.type == "cuda"

    def is_argmax_invariant(self) -> bool:
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """This method returns a new request-level logits processor, customized
        to the `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor, the request must have
        a "target_token" custom argument with an integer value, and the device
        must be "cuda"-type.

        Args:
          params: per-request sampling params

        Returns:
          `Callable` request logits processor, or None
        """
        if (
            not self.is_cuda
            or (
                target_token := params.extra_args
                and params.extra_args.get("target_token")
            )
            is None
        ):
            return None
        return DummyPerReqLogitsProcessor(target_token)


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc.
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
    SamplingParams(temperature=0.0),
]


def main():
    # Create an LLM.
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[WrappedPerReqLogitsProcessor],
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params_list)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
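
# Hedged sketch of a check that could be appended to main(): on a "cuda"
# device, the requests that set `target_token` (indices 0 and 2) should
# decode a single repeated token id, e.g.:
#
#     token_ids = outputs[0].outputs[0].token_ids
#     assert set(token_ids) == {128}  # expected to hold only on "cuda"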