[V1] Logits processors extensibility (#19912)

Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Signed-off-by: Andrew Feldman <afeld2012@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
afeldman-nm 2025-08-16 15:59:17 -04:00 committed by GitHub
parent 4fc722eca4
commit bf7f470b22
22 changed files with 1312 additions and 334 deletions

View File

@ -253,6 +253,7 @@ steps:
- pytest -v -s v1/engine
- pytest -v -s v1/entrypoints
- pytest -v -s v1/sample
- pytest -v -s v1/logits_processors
- pytest -v -s v1/worker
- pytest -v -s v1/structured_output
- pytest -v -s v1/spec_decode

View File

@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This example demonstrates instantiating vLLM with a custom logits processor
class object.
For a basic example of implementing a custom logits processor, see
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.
For testing purposes, a dummy logits processor is employed which, if
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
will mask out all tokens except `target_token`.
A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:
Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""
from typing import Optional
import torch
from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
BatchUpdate,
LogitsProcessor,
MoveDirectionality,
)
# Hypothetical custom logits processor
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""
def __init__(
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
):
        self.req_info: dict[int, int] = {}
def is_argmax_invariant(self) -> bool:
"""Never impacts greedy sampling"""
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
if not batch_update:
return
# Process added requests.
for index, params, _, _ in batch_update.added:
assert params is not None
if params.extra_args and (
target_token := params.extra_args.get("target_token")
):
self.req_info[index] = target_token
if self.req_info:
# Process removed requests.
for index in batch_update.removed:
self.req_info.pop(index, None)
# Process moved requests, unidirectional move (a->b) and swap
# (a<->b)
for adx, bdx, direct in batch_update.moved:
a_val = self.req_info.pop(adx, None)
b_val = self.req_info.pop(bdx, None)
if a_val is not None:
self.req_info[bdx] = a_val
if direct == MoveDirectionality.SWAP and b_val is not None:
self.req_info[adx] = b_val
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.req_info:
return logits
# Save target values before modification
rows_list = list(self.req_info.keys())
cols = torch.tensor(
[self.req_info[i] for i in rows_list],
dtype=torch.long,
device=logits.device,
)
rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
values_to_keep = logits[rows, cols].clone()
# Mask all but target tokens
logits[rows] = float("-inf")
logits[rows, cols] = values_to_keep
return logits
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
SamplingParams(temperature=0.0),
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
SamplingParams(temperature=0.0),
]
def main():
# Create an LLM.
llm = LLM(
model="facebook/opt-125m",
logits_processors=[DummyLogitsProcessor],
)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params_list)
# Print the outputs.
print("\nGenerated Outputs:\n" + "-" * 60)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Output: {generated_text!r}")
print("-" * 60)
if __name__ == "__main__":
main()
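A minimal sketch of the alternative spelling (not part of the example file): `logits_processors` also accepts fully-qualified class name (FQCN) strings of the form `<module>:<ClassName>`, so the same processor could be referenced by name; the module path below is hypothetical and assumes `DummyLogitsProcessor` is defined in an importable module.

# Sketch: pass the logits processor by FQCN string instead of class object.
# "my_logitsprocs" is a hypothetical importable module defining the class.
llm_by_fqcn = LLM(
    model="facebook/opt-125m",
    logits_processors=["my_logitsprocs:DummyLogitsProcessor"],
)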

View File

@ -13,6 +13,7 @@ import tempfile
import time
import warnings
from contextlib import contextmanager, suppress
from multiprocessing import Process
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
@ -76,6 +77,23 @@ VLLM_PATH = Path(__file__).parent.parent
class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
def _start_server(self, model: str, vllm_serve_args: list[str],
env_dict: Optional[dict[str, str]]) -> None:
"""Subclasses override this method to customize server process launch
"""
env = os.environ.copy()
# the current process might initialize cuda,
# to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
self.proc: subprocess.Popen = subprocess.Popen(
["vllm", "serve", model, *vllm_serve_args],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
def __init__(self,
model: str,
vllm_serve_args: list[str],
@ -128,18 +146,7 @@ class RemoteOpenAIServer:
model_loader = get_model_loader(load_config)
model_loader.download_model(model_config)
env = os.environ.copy()
# the current process might initialize cuda,
# to be safe, we should use spawn method
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
if env_dict is not None:
env.update(env_dict)
self.proc = subprocess.Popen(
["vllm", "serve", model, *vllm_serve_args],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
self._start_server(model, vllm_serve_args, env_dict)
max_wait_seconds = max_wait_seconds or 240
self._wait_for_server(url=self.url_for("health"),
timeout=max_wait_seconds)
@ -155,6 +162,10 @@ class RemoteOpenAIServer:
# force kill if needed
self.proc.kill()
def _poll(self) -> Optional[int]:
"""Subclasses override this method to customize process polling"""
return self.proc.poll()
def _wait_for_server(self, *, url: str, timeout: float):
# run health check
start = time.time()
@ -169,7 +180,7 @@ class RemoteOpenAIServer:
# which means the server is not ready yet.
# the stack trace is not useful, so we suppress it
# by using `raise from None`.
result = self.proc.poll()
result = self._poll()
if result is not None and result != 0:
raise RuntimeError("Server exited unexpectedly.") from None
@ -205,6 +216,48 @@ class RemoteOpenAIServer:
**kwargs)
class RemoteOpenAIServerCustom(RemoteOpenAIServer):
"""Launch test server with custom child process"""
def _start_server(self, model: str, vllm_serve_args: list[str],
env_dict: Optional[dict[str, str]]) -> None:
self.proc: Process = Process(
target=self.child_process_fxn,
args=(env_dict, model,
vllm_serve_args)) # type: ignore[assignment]
self.proc.start()
def __init__(self,
model: str,
vllm_serve_args: list[str],
child_process_fxn: Callable[
[Optional[dict[str, str]], str, list[str]], None],
*,
env_dict: Optional[dict[str, str]] = None,
seed: Optional[int] = 0,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None:
"""Store custom child process function then invoke superclass
constructor which will indirectly launch it."""
self.child_process_fxn = child_process_fxn
super().__init__(model=model,
vllm_serve_args=vllm_serve_args,
env_dict=env_dict,
seed=seed,
auto_port=auto_port,
max_wait_seconds=max_wait_seconds)
def _poll(self) -> Optional[int]:
return self.proc.exitcode
def __exit__(self, exc_type, exc_value, traceback):
self.proc.terminate()
self.proc.join(8)
if self.proc.is_alive():
# force kill if needed
self.proc.kill()
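A minimal construction sketch for this helper (the `my_server_fxn` callable is hypothetical; the logits processors API tests later in this PR define concrete functions that patch in a dummy logitproc before invoking `vllm serve`):

def my_server_fxn(env_dict, model, vllm_serve_args):
    # Hypothetical child-process function: apply any in-process patches here,
    # then launch the server in this process.
    import os
    import sys
    if env_dict:
        os.environ.update(env_dict)
    from vllm.entrypoints.cli import main
    sys.argv = ["vllm", "serve", model] + vllm_serve_args
    main.main()

with RemoteOpenAIServerCustom("facebook/opt-125m", ["--dtype", "bfloat16"],
                              my_server_fxn) as server:
    print(server.url_for("health"))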
def _test_completion(
client: openai.OpenAI,
model: str,

View File

@ -9,11 +9,13 @@ import numpy as np
import pytest
import torch
from tests.utils import create_new_process_for_each_test
from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
create_penalty_tensor,
create_prompt_tokens_tensor,
fake_apply_logitsprocs,
fake_update_logitsprocs_state)
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
@ -24,7 +26,7 @@ from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
init_builtin_logitsprocs)
build_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata
@ -53,6 +55,7 @@ class LogitsProcsRequestParams:
workload_index: int
logitproc_type: LogitprocType # Logitproc enabled, specified by str id
out_tokens: list[int] # Output tokens required for min tokens test
prompt_tokens: list[int] # Dummy prompt tokens placeholder
params: SamplingParams # Settings customized for logitproc
def __init__(self, workload_index: int, logitproc_type: LogitprocType):
@ -63,6 +66,7 @@ class LogitsProcsRequestParams:
# don't matter *for these tests* so use 0 as a dummy value
self.out_tokens = ([0] *
(MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
self.prompt_tokens = []
self.params = _sampling_params_from_logitproc(logitproc_type)
def __str__(self):
@ -88,11 +92,12 @@ def _generate_fake_sampling_metadata(
vocab_size,
size=np.random.randint(
1, MAX_NUM_PROMPT_TOKENS)).tolist())
logitsprocs = init_builtin_logitsprocs(
pin_memory_available=PIN_MEMORY_AVAILABLE,
max_num_reqs=MAX_NUM_REQS + 1,
device=device)
logitsprocs = build_logitsprocs(
vllm_config=VllmConfig(),
device=device,
is_pin_memory=PIN_MEMORY_AVAILABLE,
is_pooling_model=False,
)
fake_sampling_metadata = SamplingMetadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
@ -462,7 +467,8 @@ def _generate_fake_step_update(
# Replace as many removed requests as possible with added requests
add_remove_idx = batch_update_builder.pop_removed()
batch_update_builder.added.append(
(add_remove_idx, add_req_params.params, add_req_params.out_tokens))
(add_remove_idx, add_req_params.params,
add_req_params.prompt_tokens, add_req_params.out_tokens))
persistent_batch[add_remove_idx] = add_req_params
# Append remaining added requests to end of batch
@ -470,7 +476,8 @@ def _generate_fake_step_update(
num_step_add_replace):(wdx +
num_step_add)]
batch_update_builder.added.extend([
(adx + batch_size, add_req_params.params, add_req_params.out_tokens)
(adx + batch_size, add_req_params.params, add_req_params.prompt_tokens,
add_req_params.out_tokens)
for adx, add_req_params in enumerate(add_reqs_append)
])
persistent_batch.extend(add_reqs_append)
@ -561,6 +568,7 @@ def _assert_valid(
step_idx=step_idx)
@create_new_process_for_each_test()
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())

View File

@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import sys
from typing import Union
import pytest
from tests.utils import create_new_process_for_each_test
# yapf: disable
from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS, MODEL_NAME,
POOLING_MODEL_NAME, TEMP_GREEDY,
CustomLogitprocSource,
DummyLogitsProcessor,
dummy_module)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from tests.v1.logits_processors.utils import prompts
# yapf: enable
from vllm import LLM, SamplingParams
from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS,
LogitsProcessor)
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
SamplingParams(temperature=TEMP_GREEDY,
max_tokens=MAX_TOKENS,
extra_args={DUMMY_LOGITPROC_ARG: 128}),
SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
SamplingParams(temperature=TEMP_GREEDY,
max_tokens=MAX_TOKENS,
extra_args={DUMMY_LOGITPROC_ARG: 67}),
SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
]
def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
"""Compare `LLM` instance initialized with specified `kwargs` against
reference `LLM` instance.
Two scenarios:
1. Server has loaded dummy logitproc; test that requests which specify
dummy logitproc arg value behave as if logitproc is operating (output
token value should repeat), while requests that don't specify dummy
logitproc arg value should match reference `LLM` output.
2. Server has *not* loaded dummy logitproc; test that all requests
behave as if logitproc is *not* operating (output matches reference
`LLM` output.)
Args:
kwargs: `LLM` constructor kwargs
logitproc_loaded: server has loaded dummy logitproc if True
"""
# Create a vLLM instance and load custom logitproc
llm_logitproc = LLM(
model=MODEL_NAME,
gpu_memory_utilization=0.1,
**kwargs,
)
# Create a reference vLLM instance without custom logitproc
llm_ref = LLM(model=MODEL_NAME, gpu_memory_utilization=0.1)
# Run inference with logitproc loaded
outputs_logitproc = llm_logitproc.generate(prompts, sampling_params_list)
# Reference run
outputs_ref = llm_ref.generate(prompts, sampling_params_list)
# Validate outputs
for bdx, (out_lp, out_ref, params) in enumerate(
zip(outputs_logitproc, outputs_ref, sampling_params_list)):
lp_toks = out_lp.outputs[0].token_ids
if logitproc_loaded and params.extra_args:
# This request exercises custom logitproc; validate that logitproc
# forces `target_token` to be decoded in each step
target_token = params.extra_args[DUMMY_LOGITPROC_ARG]
if not all(x == target_token for x in lp_toks):
raise AssertionError(
f"Request {bdx} generated {lp_toks}, shoud all be "
f"{target_token}")
else:
# This request does not exercise custom logitproc (or custom
# logitproc is not enabled on this server); validate against
# reference result
ref_toks = out_ref.outputs[0].token_ids
if lp_toks != ref_toks:
raise AssertionError(
f"Request {bdx} generated {lp_toks}, should match "
f"{ref_toks}")
@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource))
def test_custom_logitsprocs(monkeypatch,
logitproc_source: CustomLogitprocSource):
"""Test offline Python interface for passing custom logitsprocs
Construct an `LLM` instance which loads a custom logitproc that has a
well-defined behavior (mask out all tokens except one `target_token`)
Construct a reference `LLM` instance with no custom logitproc
Pass in a batch of requests, 50% of which pass a `target_token` value
in through `SamplingParams.extra_args`, 50% of which do not.
Validate that
    * Requests which do not activate the custom logitproc yield the same
      results for both `LLM` instances
    * Requests which activate the custom logitproc only output `target_token`
Test four scenarios, corresponding to `logitproc_source` value
* No logitsprocs loaded - test that generated tokens match reference `LLM`
instance output
* Logitproc passed in via {entrypoint, class object, fully-qualified class
name (FQCN)} - test that dummy logitproc is utilized correctly when
provided via any of these three possible sources
Args:
monkeypatch: for setting env vars
logitproc_source: what source (entrypoint, fully-qualified class name
(FQCN), class object, or None) the user pulls the
logitproc from
"""
# Test that logitproc info is passed to workers
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
random.seed(40)
# Choose LLM args based on logitproc source
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE:
# Scenario: the server does not load any custom logitproc
# Every other scenario is a different way of loading a custom logitproc
_run_test({}, logitproc_loaded=False)
return
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
# Scenario: vLLM loads a logitproc from a preconfigured entrypoint
# To that end, mock a dummy logitproc entrypoint
import importlib.metadata
importlib.metadata.entry_points = fake_entry_points # type: ignore
# fork is required for workers to see entrypoint patch
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
_run_test({}, logitproc_loaded=True)
return
kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
# Scenario: load logitproc based on fully-qualified class name (FQCN)
# Inject dummy module which defines logitproc
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
# Scenario: load logitproc from provided class object
kwargs["logits_processors"] = [DummyLogitsProcessor]
_run_test(kwargs, logitproc_loaded=True)
@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", [
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
CustomLogitprocSource.LOGITPROC_SOURCE_FQCN,
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
])
def test_pooling_rejects_custom_logitsprocs(
monkeypatch, logitproc_source: CustomLogitprocSource):
"""Validate that vLLM engine initialization properly rejects custom
logitsprocs when the model is a pooling model.
Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
logitproc is actually loaded.
Scenario 1:
* Mock a logitproc entrypoint
* Validate that `LLM` does not load the logitproc
Scenario 2:
* Pass custom logitproc to `LLM` constructor
* Scenario 2a: via FQCN
* Scenario 2b: via class object
* Validate that initialization fails with appropriate exception
Args:
monkeypatch: used to set environment variables
logitproc_source: what source (entrypoint, fully-qualified class name
(FQCN), or class object) the user pulls the
logitproc from
"""
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
random.seed(40)
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
# Scenario: vLLM loads a pooling model and ignores a logitproc that is
# available at a preconfigured entrypoint
# Patch in dummy logitproc entrypoint
import importlib.metadata
importlib.metadata.entry_points = fake_entry_points # type: ignore
# fork is required for entrypoint patch to be visible to workers,
# although they should ignore the entrypoint patch anyway
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
llm = LLM(
runner="pooling",
model=POOLING_MODEL_NAME,
gpu_memory_utilization=0.1,
)
# Require that no logitsprocs have been loaded
assert sum([
1 for _ in llm.llm_engine.model_executor.driver_worker.worker.
model_runner.input_batch.logitsprocs.all
]) == 0
return
kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
# Scenario: load logitproc based on fully-qualified class name (FQCN)
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
# Scenario: load logitproc from provided class object
kwargs["logits_processors"] = [DummyLogitsProcessor]
with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS):
# Require that loading a pooling model alongside the logitproc raises
# the appropriate exception.
LLM(
runner="pooling",
model=POOLING_MODEL_NAME,
gpu_memory_utilization=0.1,
**kwargs,
)

View File

@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random
import sys
from typing import Any, Optional
import openai
import pytest
import pytest_asyncio
from tests.utils import (RemoteOpenAIServerCustom,
create_new_process_for_each_test)
# yapf: disable
from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS, MODEL_NAME,
TEMP_GREEDY, dummy_module)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from tests.v1.logits_processors.utils import prompts
# yapf: enable
def _server_with_logitproc_entrypoint(
env_dict: Optional[dict[str, str]],
model: str,
vllm_serve_args: list[str],
) -> None:
"""Start vLLM server, inject dummy logitproc entrypoint"""
# Patch `entry_points` to inject logitproc entrypoint
import importlib.metadata
importlib.metadata.entry_points = fake_entry_points # type: ignore
from vllm.entrypoints.cli import main
# fork is required for workers to see entrypoint patch
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
if env_dict is not None:
os.environ.update(env_dict)
# Emulate `vllm serve <model> <CLI args>`
sys.argv = ["vllm", "serve", model] + vllm_serve_args
main.main()
def _server_with_logitproc_module(
env_dict: Optional[dict[str, str]],
model: str,
vllm_serve_args: list[str],
) -> None:
"""Start vLLM server, inject module with dummy logitproc"""
# Patch `modules` to inject dummy logitproc module
from vllm.entrypoints.cli import main
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
    # fork is required for workers to see the injected dummy module
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
if env_dict is not None:
os.environ.update(env_dict)
# Emulate `vllm serve <model> <CLI args>`
sys.argv = ["vllm", "serve", model] + vllm_serve_args
main.main()
@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
]
@pytest.fixture(scope="function",
params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]])
def server(default_server_args, request, monkeypatch):
"""Consider two server configurations:
(1) --logits-processors cli arg specifies dummy logits processor via fully-
qualified class name (FQCN); patch in a dummy logits processor module
(2) No --logits-processors cli arg; patch in a dummy logits processor
entrypoint
"""
# Test that logitproc info is passed to workers
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
if request.param:
# Launch server, append FQCN argument, inject dummy logitproc module
args = default_server_args + request.param
_server_fxn = _server_with_logitproc_module
else:
# Launch server, inject dummy logitproc entrypoint
args = default_server_args
_server_fxn = _server_with_logitproc_entrypoint
with RemoteOpenAIServerCustom(MODEL_NAME, args,
_server_fxn) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
# General request argument values for these tests
api_keyword_args = {
# Greedy sampling ensures that requests which receive the `target_token`
# arg will decode it in every step
"temperature": TEMP_GREEDY,
    # Bound generation length, since EOS will never be decoded (unless
    # `target_token` is EOS)
"max_tokens": MAX_TOKENS,
# Return decoded token logprobs (as a way of getting token id)
"logprobs": 0,
}
@create_new_process_for_each_test()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
"""Test custom logitsprocs when starting OpenAI server from CLI
Launch vLLM OpenAI-compatible server, configured to load a custom logitproc
that has a well-defined behavior (mask out all tokens except one
`target_token`).
Pass in requests, 50% of which pass a `target_token` value
in through `extra_body["vllm_xargs"]`, 50% of which do not.
    Validate that requests which activate the custom logitproc repeat the same
    token.
"""
use_dummy_logitproc = True
for prompt in prompts:
# Build request arguments
request_keyword_args: dict[str, Any] = {
**api_keyword_args,
}
if use_dummy_logitproc:
# 50% of requests pass target_token custom arg
target_token = random.choice([128, 67])
# For requests which activate the dummy logitproc, choose one of
# two `target_token` values which are known not to be EOS tokens
request_keyword_args["extra_body"] = {
"vllm_xargs": {
DUMMY_LOGITPROC_ARG: target_token
}
}
batch = await client.completions.create(
model=model_name,
prompt=prompt,
**request_keyword_args,
)
if use_dummy_logitproc:
# Only for requests which activate dummy logitproc - validate that
# output token is repeated
            choices: list[openai.types.CompletionChoice] = batch.choices
toks = choices[0].logprobs.tokens
if not all([x == toks[0] for x in toks]):
raise AssertionError(
f"Generated {toks} should all be {toks[0]}")
# Alternate whether to activate dummy logitproc for each request
use_dummy_logitproc = not use_dummy_logitproc
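Outside of the test harness, the same per-request argument travels through the OpenAI client's `extra_body` exactly as above; a minimal usage sketch, assuming a vLLM OpenAI-compatible server is already running locally with the dummy logitproc loaded (the URL, port, and API key here are assumptions):

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
completion = client.completions.create(
    model="facebook/opt-125m",
    prompt="Hello, my name is",
    temperature=0.0,
    max_tokens=20,
    logprobs=0,
    # Custom logitproc argument is forwarded to SamplingParams.extra_args
    extra_body={"vllm_xargs": {"target_token": 128}},
)
print(completion.choices[0].text)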

View File

@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
from enum import Enum, auto
from typing import Optional
import torch
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
LogitsProcessor,
MoveDirectionality)
MODEL_NAME = "facebook/opt-125m"
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY = 0.0
MAX_TOKENS = 20
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "DummyModule"
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"
class CustomLogitprocSource(Enum):
"""How to source a logitproc for testing purposes"""
LOGITPROC_SOURCE_NONE = auto() # No custom logitproc
LOGITPROC_SOURCE_ENTRYPOINT = auto() # Via entrypoint
LOGITPROC_SOURCE_FQCN = auto() # Via fully-qualified class name (FQCN)
LOGITPROC_SOURCE_CLASS = auto() # Via provided class object
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
        self.req_info: dict[int, int] = {}
def is_argmax_invariant(self) -> bool:
"""Never impacts greedy sampling"""
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
if not batch_update:
return
# Process added requests.
for index, params, _, _ in batch_update.added:
assert params is not None
if params.extra_args and (target_token :=
params.extra_args.get("target_token")):
self.req_info[index] = target_token
if self.req_info:
# Process removed requests.
for index in batch_update.removed:
self.req_info.pop(index, None)
# Process moved requests, unidirectional move (a->b) and swap
# (a<->b)
for adx, bdx, direct in batch_update.moved:
a_val = self.req_info.pop(adx, None)
b_val = self.req_info.pop(bdx, None)
if a_val is not None:
self.req_info[bdx] = a_val
if direct == MoveDirectionality.SWAP and b_val is not None:
self.req_info[adx] = b_val
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not self.req_info:
return logits
# Save target values before modification
rows_list = list(self.req_info.keys())
cols = torch.tensor([self.req_info[i] for i in rows_list],
dtype=torch.long,
device=logits.device)
rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
values_to_keep = logits[rows, cols].clone()
# Mask all but target tokens
logits[rows] = float('-inf')
logits[rows, cols] = values_to_keep
return logits
"""Dummy module with dummy logitproc class"""
dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE)
dummy_module.DummyLogitsProcessor = DummyLogitsProcessor # type: ignore
class EntryPoint:
"""Dummy entrypoint class for logitsprocs testing"""
def __init__(self):
self.name = DUMMY_LOGITPROC_ENTRYPOINT
self.value = DUMMY_LOGITPROC_FQCN
def load(self):
return DummyLogitsProcessor
class EntryPoints(list):
"""Dummy EntryPoints class for logitsprocs testing"""
def __init__(self, group: str):
# Emulate list-like functionality
eps = [EntryPoint()] if group == LOGITSPROCS_GROUP else []
super().__init__(eps)
# Extra attributes
self.names = [ep.name for ep in eps]
"""Fake version of importlib.metadata.entry_points"""
entry_points = lambda group: EntryPoints(group)
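For reference, a sketch of how a real plugin package would expose a logits processor on the `vllm.logits_processors` entry point group that the fake above emulates; the package and module names are hypothetical, and this is not part of the test code:

# setup.py of a hypothetical plugin package
from setuptools import setup

setup(
    name="my-logitsprocs-plugin",
    py_modules=["my_logitsprocs"],
    entry_points={
        "vllm.logits_processors": [
            "dummy_logitproc = my_logitsprocs:DummyLogitsProcessor",
        ],
    },
)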

View File

@ -7,7 +7,7 @@ import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
RejectionSampler)
@ -69,7 +69,7 @@ def create_sampling_metadata(
output_token_ids=[],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
logitsprocs=LogitsProcessors(),
)

View File

@ -9,7 +9,7 @@ import torch
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
@ -173,7 +173,7 @@ def _create_default_sampling_metadata(
no_penalties=True,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
logitsprocs=LogitsProcessors(),
)
return fake_sampling_metadata

View File

@ -13,7 +13,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@ -169,7 +169,7 @@ def _construct_expected_sampling_metadata(
and all(x == 1 for x in repetition_penalties)),
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
logitsprocs=LogitsProcessorManager(),
logitsprocs=LogitsProcessors(),
)

View File

@ -62,6 +62,7 @@ if TYPE_CHECKING:
QuantizationConfig)
from vllm.model_executor.model_loader import LoadFormats
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.sample.logits_processor import LogitsProcessor
HfOverrides = Union[dict, Callable[[type], type]]
else:
@ -72,6 +73,7 @@ else:
BaseModelLoader = Any
LoadFormats = Any
TensorizerConfig = Any
LogitsProcessor = Any
HfOverrides = Union[dict[str, Any], Callable[[type], type]]
me_quant = LazyLoader("model_executor", globals(),
@ -465,6 +467,9 @@ class ModelConfig:
- "transformers" will use the Transformers model implementation."""
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
"""One or more logits processors' fully-qualified class names or class
definitions"""
def compute_hash(self) -> str:
"""

View File

@ -43,6 +43,7 @@ from vllm.transformers_utils.config import is_interleaved
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
from vllm.v1.sample.logits_processor import LogitsProcessor
# yapf: enable
@ -435,6 +436,10 @@ class EngineArgs:
enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
logits_processors: Optional[list[Union[
str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
"""Custom logitproc types"""
async_scheduling: bool = SchedulerConfig.async_scheduling
# DEPRECATED
enable_prompt_adapter: bool = False
@ -549,6 +554,8 @@ class EngineArgs:
**model_kwargs["model_impl"])
model_group.add_argument("--override-attention-dtype",
**model_kwargs["override_attention_dtype"])
model_group.add_argument("--logits-processors",
**model_kwargs["logits_processors"])
# Model loading arguments
load_kwargs = get_kwargs(LoadConfig)
@ -940,6 +947,7 @@ class EngineArgs:
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors,
)
def validate_tensorizer_args(self):

View File

@ -55,6 +55,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
from vllm.v1.metrics.reader import Metric
@ -198,6 +199,8 @@ class LLM:
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, dict[str, Any],
CompilationConfig]] = None,
logits_processors: Optional[list[Union[str,
type[LogitsProcessor]]]] = None,
**kwargs,
) -> None:
"""LLM constructor."""
@ -272,6 +275,7 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
compilation_config=compilation_config_instance,
logits_processors=logits_processors,
**kwargs,
)

View File

@ -2562,7 +2562,7 @@ def direct_register_custom_op(
def resolve_obj_by_qualname(qualname: str) -> Any:
"""
Resolve an object by its fully qualified name.
Resolve an object by its fully-qualified class name.
"""
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)

View File

@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import itertools
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional, Union
import torch
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor)
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
LogitsProcessor,
MoveDirectionality)
from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder,
LogitsProcessors)
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
# Error message when the user tries to initialize vLLM with a pooling model
# and custom logitsprocs
STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
" logits processors.")
LOGITSPROCS_GROUP = 'vllm.logits_processors'
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
MinTokensLogitsProcessor,
LogitBiasLogitsProcessor,
MinPLogitsProcessor,
]
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
"""Load all installed logit processor plugins"""
import sys
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points
installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP)
if len(installed_logitsprocs_plugins) == 0:
logger.debug("No logitsprocs plugins installed (group %s).",
LOGITSPROCS_GROUP)
return []
# Load logitsprocs plugins
logger.debug("Loading installed logitsprocs plugins (group %s):",
LOGITSPROCS_GROUP)
classes: list[type[LogitsProcessor]] = []
for entrypoint in installed_logitsprocs_plugins:
try:
logger.debug("- Loading logitproc plugin entrypoint=%s target=%s",
entrypoint.name, entrypoint.value)
classes.append(entrypoint.load())
except Exception as e:
raise RuntimeError(
f"Failed to load LogitsProcessor plugin {entrypoint}") from e
return classes
def _load_logitsprocs_by_fqcns(
logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]]
) -> list[type[LogitsProcessor]]:
"""Load logit processor types, identifying them by fully-qualified class
names (FQCNs).
Effectively, a mixed list of logitproc types and FQCN strings is converted
into a list of entirely logitproc types, by loading from the FQCNs.
    FQCN syntax is <module>:<type>, e.g. x.y.z:CustomLogitProc
Already-loaded logitproc types must be subclasses of LogitsProcessor
Args:
logits_processors: Potentially mixed list of logitsprocs types and FQCN
strings for logitproc types
Returns:
List of logitproc types
"""
if not logits_processors:
return []
logger.debug(
"%s additional custom logits processors specified, checking whether "
"they need to be loaded.", len(logits_processors))
classes: list[type[LogitsProcessor]] = []
for ldx, logitproc in enumerate(logits_processors):
if isinstance(logitproc, type):
logger.debug(" - Already-loaded logit processor: %s",
logitproc.__name__)
if not issubclass(logitproc, LogitsProcessor):
raise ValueError(
f"{logitproc.__name__} is not a subclass of LogitsProcessor"
)
classes.append(logitproc)
continue
logger.debug("- Loading logits processor %s", logitproc)
module_path, qualname = logitproc.split(":")
try:
# Load module
module = importlib.import_module(module_path)
except Exception as e:
raise RuntimeError(
f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
) from e
# Walk down dotted name to get logitproc class
obj = module
for attr in qualname.split("."):
obj = getattr(obj, attr)
if not isinstance(obj, type):
raise ValueError("Loaded logit processor must be a type.")
if not issubclass(obj, LogitsProcessor):
raise ValueError(
f"{obj.__name__} must be a subclass of LogitsProcessor")
classes.append(obj)
return classes
def _load_custom_logitsprocs(
logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]],
) -> list[type[LogitsProcessor]]:
"""Load all custom logits processors.
    * First, load all installed logitproc plugins
    * Second, load custom logitsprocs passed by the user at initialization time
Args:
logits_processors: potentially mixed list of logitproc types and
logitproc type fully-qualified names (FQCNs)
which need to be loaded
Returns:
A list of all loaded logitproc types
"""
from vllm.platforms import current_platform
if current_platform.is_tpu():
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs
return []
return (_load_logitsprocs_plugins() +
_load_logitsprocs_by_fqcns(logits_processors))
def build_logitsprocs(
vllm_config: "VllmConfig",
device: torch.device,
is_pin_memory: bool,
is_pooling_model: bool,
custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
if is_pooling_model:
if custom_logitsprocs:
raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
logger.debug("Skipping logits processor loading because pooling models"
" do not support logits processors.")
return LogitsProcessors()
custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
return LogitsProcessors(
ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
__all__ = [
"LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor",
"MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder",
"MoveDirectionality", "LogitsProcessors", "build_logitsprocs",
"STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP"
]
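A rough usage sketch of `build_logitsprocs` (values are illustrative; the model runner is expected to call this once and hand the result to `InputBatch`, passing any custom logitsprocs taken from `ModelConfig.logits_processors`):

import torch
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import build_logitsprocs

logitsprocs = build_logitsprocs(
    vllm_config=VllmConfig(),
    device=torch.device("cpu"),
    is_pin_memory=False,
    is_pooling_model=False,
    # Hypothetical FQCN; already-loaded class objects are accepted as well
    custom_logitsprocs=["my_logitsprocs:DummyLogitsProcessor"],
)
for proc in logitsprocs.all:
    print(type(proc).__name__)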

View File

@ -1,241 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import Optional, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
import torch
from torch._prims_common import DeviceLikeType
from vllm import PoolingParams, SamplingParams
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
LogitsProcessor,
MoveDirectionality)
logger = init_logger(__name__)
class MoveDirectionality(Enum):
# One-way i1->i2 req move within batch
UNIDIRECTIONAL = 0
# Two-way i1<->i2 req swap within batch
SWAP = 1
# (index, params, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int
@dataclasses.dataclass(frozen=True)
class BatchUpdate:
"""Persistent batch state change info for logitsprocs"""
batch_size: int # Current num reqs in batch
# Metadata for requests added to, removed from, and moved
# within the persistent batch.
#
# Note: each added request is represented as
# (index, params, output_tok_ids)
# Key assumption: output_tok_ids is a reference to the
# request's running output tokens list; in this way
# the logits processors always see the latest list of
# generated tokens
removed: Sequence[RemovedRequest]
moved: Sequence[MovedRequest]
added: Sequence[AddedRequest]
class BatchUpdateBuilder:
"""Helps track persistent batch state changes and build
a batch update data structure for logitsprocs
Assumptions:
* All information about requests removed from persistent batch
during a step is aggregated in self._removed through calls to
self.removed_append() at the beginning of a step. This must happen
before the first time that self.removed, self.pop_removed()
or self.peek_removed() are invoked in a given step
* After the first time that self.removed, self.pop_removed()
or self.peek_removed() are read in a step, no new removals
are registered using self.removed_append()
* Elements of self._removed are never directly modified, added or
removed (i.e. modification is only via self.removed_append() and
self.pop_removed())
Guarantees under above assumptions:
* self.removed is always sorted in descending order
* self.pop_removed() and self.peek_removed() both return
the lowest removed request index in the current step
"""
_removed: list[RemovedRequest]
_is_removed_sorted: bool
moved: list[MovedRequest]
added: list[AddedRequest]
def __init__(
self,
removed: Optional[list[RemovedRequest]] = None,
moved: Optional[list[MovedRequest]] = None,
added: Optional[list[AddedRequest]] = None,
) -> None:
self._removed = removed or []
self.moved = moved or []
self.added = added or []
self._is_removed_sorted = False
def _ensure_removed_sorted(self) -> None:
"""Sort removed request indices in
descending order.
Idempotent after first call in a
given step, until reset.
"""
if not self._is_removed_sorted:
self._removed.sort(reverse=True)
self._is_removed_sorted = True
@property
def removed(self) -> list[RemovedRequest]:
"""Removed request indices sorted in
descending order"""
self._ensure_removed_sorted()
return self._removed
def removed_append(self, index: int) -> None:
"""Register the removal of a request from
the persistent batch.
Must not be called after the first time
self.removed, self.pop_removed() or
self.peek_removed() are invoked.
Args:
index: request index
"""
if self._is_removed_sorted:
raise RuntimeError("Cannot register new removed request after"
" self.removed has been read.")
self._removed.append(index)
def has_removed(self) -> bool:
return bool(self._removed)
def peek_removed(self) -> Optional[int]:
"""Return lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed[-1]
return None
def pop_removed(self) -> Optional[int]:
"""Pop lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed.pop()
return None
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
"""Generate a logitsprocs batch update data structure
and reset internal batch update builder state.
Args:
batch_size: current persistent batch size
Returns:
Frozen logitsprocs batch update instance; `None` if no updates
"""
# Reset removal-sorting logic
self._is_removed_sorted = False
if not any((self._removed, self.moved, self.added)):
# No update; short-circuit
return None
# Build batch state update
batch_update = BatchUpdate(
batch_size=batch_size,
removed=self._removed,
moved=self.moved,
added=self.added,
)
# Reset removed/moved/added update lists
self._removed = []
self.moved = []
self.added = []
return batch_update
class LogitsProcessor(ABC):
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def is_argmax_invariant(self) -> bool:
"""True if logits processor has no impact on the
argmax computation in greedy sampling.
NOTE: may or may not have the same value for all
instances of a given LogitsProcessor subclass,
depending on subclass implementation.
TODO(andy): won't be utilized until logits
processors are user-extensible
"""
raise NotImplementedError
@abstractmethod
def update_state(
self,
batch_update: Optional[BatchUpdate],
) -> None:
"""Called when there are new output tokens, prior
to each forward pass.
Args:
batch_update is non-None iff there have been
changes to the batch makeup.
"""
raise NotImplementedError
@dataclass
class LogitsProcessorManager:
"""Encapsulates initialized logitsproc objects."""
argmax_invariant: list[LogitsProcessor] = field(
default_factory=list) # argmax-invariant logitsprocs
non_argmax_invariant: list[LogitsProcessor] = field(
default_factory=list) # non-argmax-invariant logitsprocs
@property
def all(self) -> Iterator[LogitsProcessor]:
"""Iterator over all logits processors."""
return chain(self.argmax_invariant, self.non_argmax_invariant)
###### ----- Built-in LogitsProcessor impls below here
if TYPE_CHECKING:
from vllm.config import VllmConfig
class MinPLogitsProcessor(LogitsProcessor):
def __init__(self, max_num_reqs: int, pin_memory: bool,
device: DeviceLikeType):
super().__init__()
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.min_p_count: int = 0
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
pin_memory=is_pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
self.use_double_tensor = torch.device("cpu") != torch.device(device)
self.use_double_tensor = torch.device(device).type != "cpu"
if self.use_double_tensor:
# Pre-allocated device tensor
@ -260,8 +51,8 @@ class MinPLogitsProcessor(LogitsProcessor):
needs_update = False
# Process added requests.
for index, params, _ in batch_update.added:
min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
for index, params, _, _ in batch_update.added:
min_p = params.min_p
if self.min_p_cpu[index] != min_p:
needs_update = True
self.min_p_cpu[index] = min_p
@ -316,11 +107,10 @@ class MinPLogitsProcessor(LogitsProcessor):
class LogitBiasLogitsProcessor(LogitsProcessor):
def __init__(self, pin_memory: bool, device: torch.device):
super().__init__()
self.biases: dict[int, dict[int, float]] = {}
def __init__(self, _, device: torch.device, is_pin_memory: bool):
self.device = device
self.pin_memory = pin_memory
self.pin_memory = is_pin_memory
self.biases: dict[int, dict[int, float]] = {}
self.bias_tensor: torch.Tensor = torch.tensor(())
self.logits_slice = (self._device_tensor([], torch.int32),
@ -337,9 +127,8 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
needs_update: bool = False
# Process added requests.
for index, params, _ in batch_update.added:
if isinstance(params, SamplingParams) and (lb :=
params.logit_bias):
for index, params, _, _ in batch_update.added:
if lb := params.logit_bias:
self.biases[index] = lb
needs_update = True
else:
@ -400,12 +189,12 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
class MinTokensLogitsProcessor(LogitsProcessor):
def __init__(self, pin_memory: bool, device: torch.device):
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool):
# index -> (min_toks, output_token_ids, stop_token_ids)
super().__init__()
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
self.device = device
self.pin_memory = pin_memory
self.pin_memory = is_pin_memory
self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
# (req_idx_tensor,eos_tok_id_tensor)
self.logits_slice: tuple[torch.Tensor,
@ -424,9 +213,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
if batch_update:
# Process added requests.
for index, params, output_tok_ids in batch_update.added:
if (isinstance(params, SamplingParams)
and (min_tokens := params.min_tokens)
for index, params, _, output_tok_ids in batch_update.added:
if ((min_tokens := params.min_tokens)
and len(output_tok_ids) < min_tokens):
# Replace request metadata at batch index
self.min_toks[index] = (min_tokens, output_tok_ids,
@ -499,35 +287,3 @@ class MinTokensLogitsProcessor(LogitsProcessor):
# Inhibit EOS token for requests which have not reached min length
logits[self.logits_slice] = -float("inf")
return logits
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
device: torch.device) -> LogitsProcessorManager:
"""Construct 'builtin' vLLM logitsprocs which the engine
loads by default.
Args:
        pin_memory_available: pinned memory is available for use by logitsprocs
max_num_reqs: ceiling on request count in persistent batch
device: inference device
Returns:
Data structure encapsulating loaded logitsprocs
"""
min_tokens_logitproc = MinTokensLogitsProcessor(
pin_memory=pin_memory_available, device=device)
logit_bias_logitproc = LogitBiasLogitsProcessor(
pin_memory=pin_memory_available, device=device)
min_p_logitproc = MinPLogitsProcessor(
pin_memory=pin_memory_available,
device=device,
# +1 for temporary swap space
max_num_reqs=max_num_reqs + 1)
return LogitsProcessorManager(
non_argmax_invariant=[
min_tokens_logitproc,
logit_bias_logitproc,
],
argmax_invariant=[min_p_logitproc],
)

View File

@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, Optional
import torch
from vllm import SamplingParams
if TYPE_CHECKING:
from vllm.config import VllmConfig
class MoveDirectionality(Enum):
# One-way i1->i2 req move within batch
UNIDIRECTIONAL = auto()
# Two-way i1<->i2 req swap within batch
SWAP = auto()
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int
@dataclass(frozen=True)
class BatchUpdate:
"""Persistent batch state change info for logitsprocs"""
batch_size: int # Current num reqs in batch
# Metadata for requests added to, removed from, and moved
# within the persistent batch.
#
# Key assumption: the `output_tok_ids` list (which is an element of each
# tuple in `added`) is a reference to the request's running output tokens
# list; via this reference, the logits processors always see the latest
# list of generated output tokens
removed: Sequence[RemovedRequest]
moved: Sequence[MovedRequest]
added: Sequence[AddedRequest]
class LogitsProcessor(ABC):
@abstractmethod
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool) -> None:
raise NotImplementedError
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def is_argmax_invariant(self) -> bool:
"""True if logits processor has no impact on the
argmax computation in greedy sampling.
NOTE: may or may not have the same value for all
instances of a given LogitsProcessor subclass,
depending on subclass implementation.
"""
raise NotImplementedError
@abstractmethod
def update_state(
self,
batch_update: Optional["BatchUpdate"],
) -> None:
"""Called when there are new output tokens, prior
to each forward pass.
Args:
batch_update is non-None iff there have been
changes to the batch makeup.
"""
raise NotImplementedError

View File

@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from itertools import chain
from typing import TYPE_CHECKING, Optional
from vllm.v1.sample.logits_processor.interface import (AddedRequest,
BatchUpdate,
MovedRequest,
RemovedRequest)
if TYPE_CHECKING:
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
class BatchUpdateBuilder:
"""Helps track persistent batch state changes and build
a batch update data structure for logitsprocs
Assumptions:
* All information about requests removed from persistent batch
during a step is aggregated in self._removed through calls to
self.removed_append() at the beginning of a step. This must happen
before the first time that self.removed, self.pop_removed()
or self.peek_removed() are invoked in a given step
* After the first time that self.removed, self.pop_removed()
or self.peek_removed() are read in a step, no new removals
are registered using self.removed_append()
* Elements of self._removed are never directly modified, added or
removed (i.e. modification is only via self.removed_append() and
self.pop_removed())
Guarantees under above assumptions:
* self.removed is always sorted in descending order
* self.pop_removed() and self.peek_removed() both return
the lowest removed request index in the current step
"""
_removed: list[RemovedRequest]
_is_removed_sorted: bool
moved: list[MovedRequest]
added: list[AddedRequest]
def __init__(
self,
removed: Optional[list[RemovedRequest]] = None,
moved: Optional[list[MovedRequest]] = None,
added: Optional[list[AddedRequest]] = None,
) -> None:
self._removed = removed or []
self.moved = moved or []
self.added = added or []
self._is_removed_sorted = False
def _ensure_removed_sorted(self) -> None:
"""Sort removed request indices in
descending order.
Idempotent after first call in a
given step, until reset.
"""
if not self._is_removed_sorted:
self._removed.sort(reverse=True)
self._is_removed_sorted = True
@property
def removed(self) -> list[RemovedRequest]:
"""Removed request indices sorted in
descending order"""
self._ensure_removed_sorted()
return self._removed
def removed_append(self, index: int) -> None:
"""Register the removal of a request from the persistent batch.
Must not be called after the first time self.removed,
self.pop_removed() or self.peek_removed() are invoked.
Args:
index: request index
"""
if self._is_removed_sorted:
raise RuntimeError("Cannot register new removed request after"
" self.removed has been read.")
self._removed.append(index)
def has_removed(self) -> bool:
return bool(self._removed)
def peek_removed(self) -> Optional[int]:
"""Return lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed[-1]
return None
def pop_removed(self) -> Optional[int]:
"""Pop lowest removed request index"""
if self.has_removed():
self._ensure_removed_sorted()
return self._removed.pop()
return None
def _is_update(self) -> bool:
"""True if there is a batch state change"""
return any((self._removed, self.moved, self.added))
def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
"""Generate a logitsprocs batch update data structure and reset
internal batch update builder state.
Args:
batch_size: current persistent batch size
Returns:
Frozen logitsprocs batch update instance; `None` if no updates
"""
# Reset removal-sorting logic
self._is_removed_sorted = False
if not self._is_update():
# No update; short-circuit
return None
# Build batch state update
batch_update = BatchUpdate(
batch_size=batch_size,
removed=self._removed,
moved=self.moved,
added=self.added,
)
self._removed = []
self.moved = []
self.added = []
return batch_update
class LogitsProcessors:
"""Encapsulates initialized logitsproc objects."""
def __init__(
self,
logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None:
self.argmax_invariant: list[LogitsProcessor] = []
self.non_argmax_invariant: list[LogitsProcessor] = []
if logitsprocs:
for logitproc in logitsprocs:
(self.argmax_invariant if logitproc.is_argmax_invariant() else
self.non_argmax_invariant).append(logitproc)
@property
def all(self) -> Iterator["LogitsProcessor"]:
"""Iterator over all logits processors."""
return chain(self.argmax_invariant, self.non_argmax_invariant)
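# Illustrative sketch (not part of the original change): how the container
# partitions processors by argmax invariance. `proc_a` and `proc_b` are
# hypothetical LogitsProcessor instances supplied by the caller.
def _example_partition(proc_a: "LogitsProcessor",
                       proc_b: "LogitsProcessor") -> LogitsProcessors:
    procs = LogitsProcessors(iter([proc_a, proc_b]))
    # Argmax-invariant processors never change the greedy (argmax) choice,
    # so a purely greedy batch may skip them; the others must always run.
    assert all(p.is_argmax_invariant() for p in procs.argmax_invariant)
    assert not any(p.is_argmax_invariant() for p in procs.non_argmax_invariant)
    return procs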

View File

@ -6,7 +6,7 @@ from typing import Optional
import torch
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.logits_processor import LogitsProcessors
@dataclass
@ -40,4 +40,4 @@ class SamplingMetadata:
bad_words_token_ids: dict[int, list[list[int]]]
# Loaded logits processors
logitsprocs: LogitsProcessorManager
logitsprocs: LogitsProcessors
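# Illustrative sketch (not part of the original change): how a sampler might
# consume this field. apply() and is_argmax_invariant() come from the
# LogitsProcessor interface; the all_greedy flag and helper name are assumed.
def _apply_logitsprocs(logits: torch.Tensor,
                       metadata: "SamplingMetadata",
                       all_greedy: bool) -> torch.Tensor:
    # Under pure greedy sampling, argmax-invariant processors cannot change
    # the selected token, so only the non-argmax-invariant ones need to run.
    procs = (metadata.logitsprocs.non_argmax_invariant
             if all_greedy else metadata.logitsprocs.all)
    for proc in procs:
        logits = proc.apply(logits)
    return logits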

View File

@ -18,8 +18,8 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
MoveDirectionality,
init_builtin_logitsprocs)
LogitsProcessors,
MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
@ -78,8 +78,11 @@ class InputBatch:
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@ -221,14 +224,6 @@ class InputBatch:
# updates. Should reset each step.
self.batch_update_builder = BatchUpdateBuilder()
# Define logits processors.
# TODO(andy): logits processor list should be extensible via engine
# constructor argument; for now the list is fixed.
self.logitsprocs = init_builtin_logitsprocs(
pin_memory_available=pin_memory,
max_num_reqs=max_num_reqs + 1,
device=device)
# TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set()
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
@ -244,6 +239,10 @@ class InputBatch:
self.req_output_token_ids: list[Optional[list[int]]] = []
# Store provided logitsprocs. If none are provided, initialize empty
# data structure
self.logitsprocs = logitsprocs or LogitsProcessors()
# This is updated each time the batch constituents change.
self.sampling_metadata = self._make_sampling_metadata()
@ -255,28 +254,35 @@ class InputBatch:
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def _get_next_add_index(self) -> int:
if (req_index := self.batch_update_builder.pop_removed()) is not None:
# Fill the empty index.
return req_index
# Append to end
return self.num_reqs
def _register_add_request(self, request: "CachedRequestState") -> int:
"""Track add-request operations"""
req_index = self._get_next_add_index()
assert req_index < self.max_num_reqs
params = (request.sampling_params
if request.sampling_params else request.pooling_params)
"""Track add-request operations for logits processors.
Not applicable to pooling models.
"""
# Detailed added request metadata is only required for non-pooling
# models, to support logitsprocs
assert request.sampling_params
# Fill the next empty index if there is one.
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
# Append to end otherwise.
new_req_index = self.num_reqs
assert new_req_index < self.max_num_reqs
self.batch_update_builder.added.append(
(req_index, params, request.output_token_ids))
return req_index
(new_req_index, request.sampling_params, request.prompt_token_ids,
request.output_token_ids))
return new_req_index
def add_request(
self,
request: "CachedRequestState",
) -> int:
req_index = self._register_add_request(request)
if not self.is_pooling_model:
# New request index bookkeeping for autoregressive models.
req_index = self._register_add_request(request)
else:
req_index = self.num_reqs
req_id = request.req_id
if req_index == len(self._req_ids):
@ -411,7 +417,10 @@ class InputBatch:
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self.batch_update_builder.removed_append(req_index)
if not self.is_pooling_model:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
@ -446,6 +455,8 @@ class InputBatch:
return req_index
def swap_states(self, i1: int, i2: int) -> None:
# For autoregressive models, track detailed request reordering info
# to support logitsprocs
self.batch_update_builder.moved.append(
(i1, i2, MoveDirectionality.SWAP))
old_id_i1 = self._req_ids[i1]
@ -513,11 +524,18 @@ class InputBatch:
swaps: list of (from,to) swap tuples for moved requests
empty_req_indices: indices not filled by condensation
"""
num_reqs = self.num_reqs
if self.is_pooling_model:
# Will be contiguous in pooling case, just trim the lists.
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
return
if not (empty_req_indices := self.batch_update_builder.removed):
# All removed requests were replaced by added requests, or else no
# requests were removed at all. No condense() needed
return
num_reqs = self.num_reqs
if num_reqs == 0:
# The batched states are empty.
self._req_ids.clear()
@ -541,6 +559,8 @@ class InputBatch:
# Move active request down into empty request
# index.
self.batch_update_builder.pop_removed()
# Autoregressive models require detailed tracking of condense
# operations to support logitsprocs
self.batch_update_builder.moved.append(
(last_req_index, empty_index,
MoveDirectionality.UNIDIRECTIONAL))
@ -596,15 +616,20 @@ class InputBatch:
last_req_index -= 1
# Trim lists to the batch size.
del self._req_ids[self.num_reqs:]
del self.req_output_token_ids[self.num_reqs:]
del self._req_ids[num_reqs:]
del self.req_output_token_ids[num_reqs:]
def refresh_metadata(self):
"""Apply batch updates, reset input batch at end of step
"""Apply any batch updates to sampling metadata."""
* Apply batch add/remove/permute to logits procs' states
* If batch state is modified, update sampling metadata
"""
if self.is_pooling_model:
# Batch changes every step for pooling models.
self.sampling_metadata = self._make_sampling_metadata()
return
# For non-pooling models - generate and apply logitsprocs update;
# reset batch update tracking.
# Update sampling metadata if batch state is changed.
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
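# Illustrative sketch (not part of the original change): the per-step ordering
# the model runner is expected to follow around InputBatch. Scheduler-output
# handling is elided; only the relative ordering of these calls matters.
def _example_step(input_batch: "InputBatch", finished_req_id: str,
                  new_request: "CachedRequestState") -> None:
    input_batch.remove_request(finished_req_id)  # registers the freed index
    input_batch.add_request(new_request)         # may reuse the freed index
    input_batch.condense()                       # compact any remaining gaps
    input_batch.refresh_metadata()               # hand BatchUpdate to logitsprocs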

View File

@ -68,6 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
@ -80,7 +81,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from ..sample.logits_processor import LogitsProcessorManager
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
@ -221,6 +221,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config, self.device, self.pin_memory,
self.is_pooling_model,
self.vllm_config.model_config.logits_processors),
is_pooling_model=self.is_pooling_model,
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
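# Illustrative note (not part of the original change): the
# model_config.logits_processors entries passed to build_logitsprocs above are
# assumed to hold user-supplied LogitsProcessor subclasses (or their
# fully-qualified class paths), e.g. with a hypothetical class name:
#
#   llm = LLM(model="facebook/opt-125m",
#             logits_processors=[MyCustomLogitsProcessor])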
@ -2447,7 +2452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
output_token_ids=[[] for _ in range(num_reqs)],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
logitsprocs=LogitsProcessors(),
)
try:
sampler_output = self.sampler(logits=logits,
@ -2968,6 +2973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=self.input_batch.logitsprocs,
is_pooling_model=self.is_pooling_model,
)
def _allocate_kv_cache_tensors(