[V1] Logits processors extensibility (#19912)
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Signed-off-by: Andrew Feldman <afeld2012@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 4fc722eca4
commit bf7f470b22
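In short, this change makes logits processors user-extensible in vLLM V1. The sketch below is an orientation aid, not part of the diff: it assumes only the interfaces introduced here (the `LogitsProcessor` base class, its `(vllm_config, device, is_pin_memory)` constructor convention, and the new `logits_processors=` argument to `LLM`); the pass-through processor itself is hypothetical.

# Hedged sketch of the new extension points (names taken from the diff; usage assumed).
from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor


class PassThroughLogitsProcessor(LogitsProcessor):
    """Hypothetical minimal subclass implementing the three required hooks."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device,
                 is_pin_memory: bool):
        pass

    def is_argmax_invariant(self) -> bool:
        # True means this processor can never change the greedy (argmax) choice.
        return True

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        # React to requests being added to / removed from / moved within
        # the persistent batch; nothing to track for a pass-through.
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Modify and/or return the logits tensor for the whole batch.
        return logits


llm = LLM(model="facebook/opt-125m",
          logits_processors=[PassThroughLogitsProcessor])
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0))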
@@ -253,6 +253,7 @@ steps:
  - pytest -v -s v1/engine
  - pytest -v -s v1/entrypoints
  - pytest -v -s v1/sample
+ - pytest -v -s v1/logits_processors
  - pytest -v -s v1/worker
  - pytest -v -s v1/structured_output
  - pytest -v -s v1/spec_decode
147  examples/offline_inference/logits_processor.py  Normal file
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""This example demonstrates instantiating vLLM with a custom logits processor
class object.

For a basic example of implementing a custom logits processor, see
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.

For testing purposes, a dummy logits processor is employed which, if
`target_token` is passed via `SamplingParams.extra_args`,
will mask out all tokens except `target_token`.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""

from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import (
    BatchUpdate,
    LogitsProcessor,
    MoveDirectionality,
)


# Hypothetical custom logits processor
class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        # Maps request index -> target token id
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Never impacts greedy sampling"""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            if params.extra_args and (
                target_token := params.extra_args.get("target_token")
            ):
                self.req_info[index] = target_token

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        rows_list = list(self.req_info.keys())
        cols = torch.tensor(
            [self.req_info[i] for i in rows_list],
            dtype=torch.long,
            device=logits.device,
        )
        rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float("-inf")
        logits[rows, cols] = values_to_keep

        return logits


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
    SamplingParams(temperature=0.0),
]


def main():
    # Create an LLM.
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[DummyLogitsProcessor],
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params_list)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
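As an aside (not part of this file), the same processor can also be referenced by fully-qualified class name (FQCN) instead of a class object, both offline and when launching the OpenAI-compatible server. The module path below is a placeholder, not something this commit ships.

# Hedged sketch; the module path "my_pkg.procs" is hypothetical.
llm = LLM(
    model="facebook/opt-125m",
    # FQCN form is <module>:<class name>
    logits_processors=["my_pkg.procs:DummyLogitsProcessor"],
)
# Equivalent server-side flag introduced by this commit (module path again hypothetical):
#   vllm serve facebook/opt-125m --logits-processors my_pkg.procs:DummyLogitsProcessor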
@@ -13,6 +13,7 @@ import tempfile
import time
import warnings
from contextlib import contextmanager, suppress
from multiprocessing import Process
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union

@@ -76,6 +77,23 @@ VLLM_PATH = Path(__file__).parent.parent
class RemoteOpenAIServer:
    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key

    def _start_server(self, model: str, vllm_serve_args: list[str],
                      env_dict: Optional[dict[str, str]]) -> None:
        """Subclasses override this method to customize server process launch
        """
        env = os.environ.copy()
        # the current process might initialize cuda,
        # to be safe, we should use spawn method
        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
        if env_dict is not None:
            env.update(env_dict)
        self.proc: subprocess.Popen = subprocess.Popen(
            ["vllm", "serve", model, *vllm_serve_args],
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )

    def __init__(self,
                 model: str,
                 vllm_serve_args: list[str],
@@ -128,18 +146,7 @@ class RemoteOpenAIServer:
            model_loader = get_model_loader(load_config)
            model_loader.download_model(model_config)

-       env = os.environ.copy()
-       # the current process might initialize cuda,
-       # to be safe, we should use spawn method
-       env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
-       if env_dict is not None:
-           env.update(env_dict)
-       self.proc = subprocess.Popen(
-           ["vllm", "serve", model, *vllm_serve_args],
-           env=env,
-           stdout=sys.stdout,
-           stderr=sys.stderr,
-       )
+       self._start_server(model, vllm_serve_args, env_dict)
        max_wait_seconds = max_wait_seconds or 240
        self._wait_for_server(url=self.url_for("health"),
                              timeout=max_wait_seconds)
@@ -155,6 +162,10 @@ class RemoteOpenAIServer:
            # force kill if needed
            self.proc.kill()

    def _poll(self) -> Optional[int]:
        """Subclasses override this method to customize process polling"""
        return self.proc.poll()

    def _wait_for_server(self, *, url: str, timeout: float):
        # run health check
        start = time.time()
@@ -169,7 +180,7 @@ class RemoteOpenAIServer:
                # which means the server is not ready yet.
                # the stack trace is not useful, so we suppress it
                # by using `raise from None`.
-               result = self.proc.poll()
+               result = self._poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None

@@ -205,6 +216,48 @@ class RemoteOpenAIServer:
                            **kwargs)


class RemoteOpenAIServerCustom(RemoteOpenAIServer):
    """Launch test server with custom child process"""

    def _start_server(self, model: str, vllm_serve_args: list[str],
                      env_dict: Optional[dict[str, str]]) -> None:
        self.proc: Process = Process(
            target=self.child_process_fxn,
            args=(env_dict, model,
                  vllm_serve_args))  # type: ignore[assignment]
        self.proc.start()

    def __init__(self,
                 model: str,
                 vllm_serve_args: list[str],
                 child_process_fxn: Callable[
                     [Optional[dict[str, str]], str, list[str]], None],
                 *,
                 env_dict: Optional[dict[str, str]] = None,
                 seed: Optional[int] = 0,
                 auto_port: bool = True,
                 max_wait_seconds: Optional[float] = None) -> None:
        """Store custom child process function then invoke superclass
        constructor which will indirectly launch it."""
        self.child_process_fxn = child_process_fxn
        super().__init__(model=model,
                         vllm_serve_args=vllm_serve_args,
                         env_dict=env_dict,
                         seed=seed,
                         auto_port=auto_port,
                         max_wait_seconds=max_wait_seconds)

    def _poll(self) -> Optional[int]:
        return self.proc.exitcode

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
        self.proc.join(8)
        if self.proc.is_alive():
            # force kill if needed
            self.proc.kill()


def _test_completion(
    client: openai.OpenAI,
    model: str,
0  tests/v1/logits_processors/__init__.py  Normal file
@@ -9,11 +9,13 @@ import numpy as np
import pytest
import torch

from tests.utils import create_new_process_for_each_test
from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
                                   create_penalty_tensor,
                                   create_prompt_tokens_tensor,
                                   fake_apply_logitsprocs,
                                   fake_update_logitsprocs_state)
from vllm.config import VllmConfig
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
@@ -24,7 +26,7 @@ from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
                                             MinPLogitsProcessor,
                                             MinTokensLogitsProcessor,
                                             MoveDirectionality,
-                                            init_builtin_logitsprocs)
+                                            build_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata

@@ -53,6 +55,7 @@ class LogitsProcsRequestParams:
    workload_index: int
    logitproc_type: LogitprocType  # Logitproc enabled, specified by str id
    out_tokens: list[int]  # Output tokens required for min tokens test
    prompt_tokens: list[int]  # Dummy prompt tokens placeholder
    params: SamplingParams  # Settings customized for logitproc

    def __init__(self, workload_index: int, logitproc_type: LogitprocType):
@@ -63,6 +66,7 @@ class LogitsProcsRequestParams:
        # don't matter *for these tests* so use 0 as a dummy value
        self.out_tokens = ([0] *
                           (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
        self.prompt_tokens = []
        self.params = _sampling_params_from_logitproc(logitproc_type)

    def __str__(self):
@@ -88,11 +92,12 @@ def _generate_fake_sampling_metadata(
                          vocab_size,
                          size=np.random.randint(
                              1, MAX_NUM_PROMPT_TOKENS)).tolist())
-   logitsprocs = init_builtin_logitsprocs(
-       pin_memory_available=PIN_MEMORY_AVAILABLE,
-       max_num_reqs=MAX_NUM_REQS + 1,
-       device=device)
-
+   logitsprocs = build_logitsprocs(
+       vllm_config=VllmConfig(),
+       device=device,
+       is_pin_memory=PIN_MEMORY_AVAILABLE,
+       is_pooling_model=False,
+   )
    fake_sampling_metadata = SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
@@ -462,7 +467,8 @@ def _generate_fake_step_update(
        # Replace as many removed requests as possible with added requests
        add_remove_idx = batch_update_builder.pop_removed()
        batch_update_builder.added.append(
-           (add_remove_idx, add_req_params.params, add_req_params.out_tokens))
+           (add_remove_idx, add_req_params.params,
+            add_req_params.prompt_tokens, add_req_params.out_tokens))
        persistent_batch[add_remove_idx] = add_req_params

    # Append remaining added requests to end of batch
@@ -470,7 +476,8 @@ def _generate_fake_step_update(
                                     num_step_add_replace):(wdx +
                                                            num_step_add)]
    batch_update_builder.added.extend([
-       (adx + batch_size, add_req_params.params, add_req_params.out_tokens)
+       (adx + batch_size, add_req_params.params, add_req_params.prompt_tokens,
+        add_req_params.out_tokens)
        for adx, add_req_params in enumerate(add_reqs_append)
    ])
    persistent_batch.extend(add_reqs_append)
@@ -561,6 +568,7 @@ def _assert_valid(
        step_idx=step_idx)


@create_new_process_for_each_test()
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
237  tests/v1/logits_processors/test_custom_offline.py  Normal file
@@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import sys
from typing import Union

import pytest

from tests.utils import create_new_process_for_each_test
# yapf: disable
from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
                                              DUMMY_LOGITPROC_FQCN,
                                              DUMMY_LOGITPROC_MODULE,
                                              MAX_TOKENS, MODEL_NAME,
                                              POOLING_MODEL_NAME, TEMP_GREEDY,
                                              CustomLogitprocSource,
                                              DummyLogitsProcessor,
                                              dummy_module)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from tests.v1.logits_processors.utils import prompts
# yapf: enable
from vllm import LLM, SamplingParams
from vllm.v1.sample.logits_processor import (STR_POOLING_REJECTS_LOGITSPROCS,
                                             LogitsProcessor)

# Create a mixture of requests which do and don't utilize the dummy logitproc
sampling_params_list = [
    SamplingParams(temperature=TEMP_GREEDY,
                   max_tokens=MAX_TOKENS,
                   extra_args={DUMMY_LOGITPROC_ARG: 128}),
    SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
    SamplingParams(temperature=TEMP_GREEDY,
                   max_tokens=MAX_TOKENS,
                   extra_args={DUMMY_LOGITPROC_ARG: 67}),
    SamplingParams(temperature=TEMP_GREEDY, max_tokens=MAX_TOKENS),
]


def _run_test(kwargs: dict, logitproc_loaded: bool) -> None:
    """Compare `LLM` instance initialized with specified `kwargs` against
    reference `LLM` instance.

    Two scenarios:
    1. Server has loaded dummy logitproc; test that requests which specify
       dummy logitproc arg value behave as if logitproc is operating (output
       token value should repeat), while requests that don't specify dummy
       logitproc arg value should match reference `LLM` output.
    2. Server has *not* loaded dummy logitproc; test that all requests
       behave as if logitproc is *not* operating (output matches reference
       `LLM` output.)

    Args:
      kwargs: `LLM` constructor kwargs
      logitproc_loaded: server has loaded dummy logitproc if True
    """

    # Create a vLLM instance and load custom logitproc
    llm_logitproc = LLM(
        model=MODEL_NAME,
        gpu_memory_utilization=0.1,
        **kwargs,
    )

    # Create a reference vLLM instance without custom logitproc
    llm_ref = LLM(model=MODEL_NAME, gpu_memory_utilization=0.1)

    # Run inference with logitproc loaded
    outputs_logitproc = llm_logitproc.generate(prompts, sampling_params_list)

    # Reference run
    outputs_ref = llm_ref.generate(prompts, sampling_params_list)

    # Validate outputs
    for bdx, (out_lp, out_ref, params) in enumerate(
            zip(outputs_logitproc, outputs_ref, sampling_params_list)):
        lp_toks = out_lp.outputs[0].token_ids
        if logitproc_loaded and params.extra_args:
            # This request exercises custom logitproc; validate that logitproc
            # forces `target_token` to be decoded in each step
            target_token = params.extra_args[DUMMY_LOGITPROC_ARG]
            if not all(x == target_token for x in lp_toks):
                raise AssertionError(
                    f"Request {bdx} generated {lp_toks}, should all be "
                    f"{target_token}")
        else:
            # This request does not exercise custom logitproc (or custom
            # logitproc is not enabled on this server); validate against
            # reference result
            ref_toks = out_ref.outputs[0].token_ids
            if lp_toks != ref_toks:
                raise AssertionError(
                    f"Request {bdx} generated {lp_toks}, should match "
                    f"{ref_toks}")


@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", list(CustomLogitprocSource))
def test_custom_logitsprocs(monkeypatch,
                            logitproc_source: CustomLogitprocSource):
    """Test offline Python interface for passing custom logitsprocs

    Construct an `LLM` instance which loads a custom logitproc that has a
    well-defined behavior (mask out all tokens except one `target_token`)

    Construct a reference `LLM` instance with no custom logitproc

    Pass in a batch of requests, 50% of which pass a `target_token` value
    in through `SamplingParams.extra_args`, 50% of which do not.

    Validate that
    * Requests which do not activate the custom logitproc yield the same
      results for both `LLM` instances
    * Requests which activate the custom logitproc only output `target_token`

    Test four scenarios, corresponding to `logitproc_source` value
    * No logitsprocs loaded - test that generated tokens match reference `LLM`
      instance output
    * Logitproc passed in via {entrypoint, class object, fully-qualified class
      name (FQCN)} - test that dummy logitproc is utilized correctly when
      provided via any of these three possible sources

    Args:
      monkeypatch: for setting env vars
      logitproc_source: what source (entrypoint, fully-qualified class name
                        (FQCN), class object, or None) the user pulls the
                        logitproc from
    """

    # Test that logitproc info is passed to workers
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
    random.seed(40)

    # Choose LLM args based on logitproc source
    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_NONE:
        # Scenario: the server does not load any custom logitproc
        # Every other scenario is a different way of loading a custom logitproc
        _run_test({}, logitproc_loaded=False)
        return

    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
        # Scenario: vLLM loads a logitproc from a preconfigured entrypoint
        # To that end, mock a dummy logitproc entrypoint
        import importlib.metadata
        importlib.metadata.entry_points = fake_entry_points  # type: ignore

        # fork is required for workers to see entrypoint patch
        monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
        _run_test({}, logitproc_loaded=True)
        return

    kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
        # Scenario: load logitproc based on fully-qualified class name (FQCN)
        # Inject dummy module which defines logitproc
        sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
        kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
    elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
        # Scenario: load logitproc from provided class object
        kwargs["logits_processors"] = [DummyLogitsProcessor]

    _run_test(kwargs, logitproc_loaded=True)


@create_new_process_for_each_test()
@pytest.mark.parametrize("logitproc_source", [
    CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
    CustomLogitprocSource.LOGITPROC_SOURCE_FQCN,
    CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
])
def test_pooling_rejects_custom_logitsprocs(
        monkeypatch, logitproc_source: CustomLogitprocSource):
    """Validate that vLLM engine initialization properly rejects custom
    logitsprocs when the model is a pooling model.

    Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
    logitproc is actually loaded.

    Scenario 1:
    * Mock a logitproc entrypoint
    * Validate that `LLM` does not load the logitproc

    Scenario 2:
    * Pass custom logitproc to `LLM` constructor
      * Scenario 2a: via FQCN
      * Scenario 2b: via class object
    * Validate that initialization fails with appropriate exception

    Args:
      monkeypatch: used to set environment variables
      logitproc_source: what source (entrypoint, fully-qualified class name
                        (FQCN), or class object) the user pulls the
                        logitproc from
    """
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    random.seed(40)

    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
        # Scenario: vLLM loads a pooling model and ignores a logitproc that is
        # available at a preconfigured entrypoint

        # Patch in dummy logitproc entrypoint
        import importlib.metadata
        importlib.metadata.entry_points = fake_entry_points  # type: ignore

        # fork is required for entrypoint patch to be visible to workers,
        # although they should ignore the entrypoint patch anyway
        monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")

        llm = LLM(
            runner="pooling",
            model=POOLING_MODEL_NAME,
            gpu_memory_utilization=0.1,
        )
        # Require that no logitsprocs have been loaded
        assert sum([
            1 for _ in llm.llm_engine.model_executor.driver_worker.worker.
            model_runner.input_batch.logitsprocs.all
        ]) == 0
        return

    kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
    if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
        # Scenario: load logitproc based on fully-qualified class name (FQCN)
        kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
    elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
        # Scenario: load logitproc from provided class object
        kwargs["logits_processors"] = [DummyLogitsProcessor]

    with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS):
        # Require that loading a pooling model alongside the logitproc raises
        # the appropriate exception.
        LLM(
            runner="pooling",
            model=POOLING_MODEL_NAME,
            gpu_memory_utilization=0.1,
            **kwargs,
        )
180  tests/v1/logits_processors/test_custom_online.py  Normal file
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import random
import sys
from typing import Any, Optional

import openai
import pytest
import pytest_asyncio

from tests.utils import (RemoteOpenAIServerCustom,
                         create_new_process_for_each_test)
# yapf: disable
from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
                                              DUMMY_LOGITPROC_FQCN,
                                              DUMMY_LOGITPROC_MODULE,
                                              MAX_TOKENS, MODEL_NAME,
                                              TEMP_GREEDY, dummy_module)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from tests.v1.logits_processors.utils import prompts

# yapf: enable


def _server_with_logitproc_entrypoint(
    env_dict: Optional[dict[str, str]],
    model: str,
    vllm_serve_args: list[str],
) -> None:
    """Start vLLM server, inject dummy logitproc entrypoint"""

    # Patch `entry_points` to inject logitproc entrypoint
    import importlib.metadata
    importlib.metadata.entry_points = fake_entry_points  # type: ignore
    from vllm.entrypoints.cli import main

    # fork is required for workers to see entrypoint patch
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
    if env_dict is not None:
        os.environ.update(env_dict)

    # Emulate `vllm serve <model> <CLI args>`
    sys.argv = ["vllm", "serve", model] + vllm_serve_args
    main.main()


def _server_with_logitproc_module(
    env_dict: Optional[dict[str, str]],
    model: str,
    vllm_serve_args: list[str],
) -> None:
    """Start vLLM server, inject module with dummy logitproc"""

    # Patch `modules` to inject dummy logitproc module
    from vllm.entrypoints.cli import main
    sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module

    # fork is required for workers to see entrypoint patch
    os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = "fork"
    if env_dict is not None:
        os.environ.update(env_dict)

    # Emulate `vllm serve <model> <CLI args>`
    sys.argv = ["vllm", "serve", model] + vllm_serve_args
    main.main()


@pytest.fixture(scope="module")
def default_server_args():
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "128",
    ]


@pytest.fixture(scope="function",
                params=[[], ["--logits-processors", DUMMY_LOGITPROC_FQCN]])
def server(default_server_args, request, monkeypatch):
    """Consider two server configurations:
    (1) --logits-processors cli arg specifies dummy logits processor via fully-
        qualified class name (FQCN); patch in a dummy logits processor module
    (2) No --logits-processors cli arg; patch in a dummy logits processor
        entrypoint
    """

    # Test that logitproc info is passed to workers
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")

    if request.param:
        # Launch server, append FQCN argument, inject dummy logitproc module
        args = default_server_args + request.param
        _server_fxn = _server_with_logitproc_module
    else:
        # Launch server, inject dummy logitproc entrypoint
        args = default_server_args
        _server_fxn = _server_with_logitproc_entrypoint

    with RemoteOpenAIServerCustom(MODEL_NAME, args,
                                  _server_fxn) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


# General request argument values for these tests
api_keyword_args = {
    # Greedy sampling ensures that requests which receive the `target_token`
    # arg will decode it in every step
    "temperature": TEMP_GREEDY,
    # Cap output length, since EOS will never be decoded (unless
    # `target_token` is EOS)
    "max_tokens": MAX_TOKENS,
    # Return decoded token logprobs (as a way of getting token id)
    "logprobs": 0,
}


@create_new_process_for_each_test()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
    """Test custom logitsprocs when starting OpenAI server from CLI

    Launch vLLM OpenAI-compatible server, configured to load a custom logitproc
    that has a well-defined behavior (mask out all tokens except one
    `target_token`).

    Pass in requests, 50% of which pass a `target_token` value
    in through `extra_body["vllm_xargs"]`, 50% of which do not.

    Validate that requests which activate the custom logitproc repeat the same
    token.
    """

    use_dummy_logitproc = True
    for prompt in prompts:
        # Build request arguments
        request_keyword_args: dict[str, Any] = {
            **api_keyword_args,
        }
        if use_dummy_logitproc:
            # 50% of requests pass target_token custom arg
            target_token = random.choice([128, 67])
            # For requests which activate the dummy logitproc, choose one of
            # two `target_token` values which are known not to be EOS tokens
            request_keyword_args["extra_body"] = {
                "vllm_xargs": {
                    DUMMY_LOGITPROC_ARG: target_token
                }
            }
        batch = await client.completions.create(
            model=model_name,
            prompt=prompt,
            **request_keyword_args,
        )

        if use_dummy_logitproc:
            # Only for requests which activate dummy logitproc - validate that
            # output token is repeated
            choices: openai.types.CompletionChoice = batch.choices
            toks = choices[0].logprobs.tokens
            if not all([x == toks[0] for x in toks]):
                raise AssertionError(
                    f"Generated {toks} should all be {toks[0]}")

        # Alternate whether to activate dummy logitproc for each request
        use_dummy_logitproc = not use_dummy_logitproc
127  tests/v1/logits_processors/utils.py  Normal file
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types
from enum import Enum, auto
from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
                                             LogitsProcessor,
                                             MoveDirectionality)

MODEL_NAME = "facebook/opt-125m"
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY = 0.0
MAX_TOKENS = 20
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "DummyModule"
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"


class CustomLogitprocSource(Enum):
    """How to source a logitproc for testing purposes"""
    LOGITPROC_SOURCE_NONE = auto()  # No custom logitproc
    LOGITPROC_SOURCE_ENTRYPOINT = auto()  # Via entrypoint
    LOGITPROC_SOURCE_FQCN = auto()  # Via fully-qualified class name (FQCN)
    LOGITPROC_SOURCE_CLASS = auto()  # Via provided class object


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


class DummyLogitsProcessor(LogitsProcessor):
    """Fake logit processor to support unit testing and examples"""

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        # Maps request index -> target token id
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        """Never impacts greedy sampling"""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        # Process added requests.
        for index, params, _, _ in batch_update.added:
            assert params is not None
            if params.extra_args and (target_token :=
                                      params.extra_args.get("target_token")):
                self.req_info[index] = target_token

        if self.req_info:
            # Process removed requests.
            for index in batch_update.removed:
                self.req_info.pop(index, None)

            # Process moved requests, unidirectional move (a->b) and swap
            # (a<->b)
            for adx, bdx, direct in batch_update.moved:
                a_val = self.req_info.pop(adx, None)
                b_val = self.req_info.pop(bdx, None)
                if a_val is not None:
                    self.req_info[bdx] = a_val
                if direct == MoveDirectionality.SWAP and b_val is not None:
                    self.req_info[adx] = b_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.req_info:
            return logits

        # Save target values before modification
        rows_list = list(self.req_info.keys())
        cols = torch.tensor([self.req_info[i] for i in rows_list],
                            dtype=torch.long,
                            device=logits.device)
        rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
        values_to_keep = logits[rows, cols].clone()

        # Mask all but target tokens
        logits[rows] = float('-inf')
        logits[rows, cols] = values_to_keep

        return logits


# Dummy module exposing the dummy logitproc class
dummy_module = types.ModuleType(DUMMY_LOGITPROC_MODULE)
dummy_module.DummyLogitsProcessor = DummyLogitsProcessor  # type: ignore


class EntryPoint:
    """Dummy entrypoint class for logitsprocs testing"""

    def __init__(self):
        self.name = DUMMY_LOGITPROC_ENTRYPOINT
        self.value = DUMMY_LOGITPROC_FQCN

    def load(self):
        return DummyLogitsProcessor


class EntryPoints(list):
    """Dummy EntryPoints class for logitsprocs testing"""

    def __init__(self, group: str):
        # Emulate list-like functionality
        eps = [EntryPoint()] if group == LOGITSPROCS_GROUP else []
        super().__init__(eps)
        # Extra attributes
        self.names = [ep.name for ep in eps]


# Fake version of importlib.metadata.entry_points
entry_points = lambda group: EntryPoints(group)
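The `EntryPoint`/`EntryPoints` fakes above stand in for what a real plugin package would register under the `vllm.logits_processors` entry-point group (see `LOGITSPROCS_GROUP` in the new `logits_processor` package). A hypothetical packaging stanza, sketched only as a comment and not shipped by this commit, might look like:

# Hypothetical pyproject.toml excerpt for a third-party plugin package
# (assumed names; not part of this diff):
#
#   [project.entry-points."vllm.logits_processors"]
#   my_logitproc = "my_pkg.procs:MyLogitsProcessor"
#
# With such an entry installed, _load_logitsprocs_plugins() would discover and
# load MyLogitsProcessor automatically, with no --logits-processors flag needed.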
@@ -7,7 +7,7 @@ import torch
import torch.nn.functional as F

from vllm.platforms import current_platform
-from vllm.v1.sample.logits_processor import LogitsProcessorManager
+from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                              RejectionSampler)
@@ -69,7 +69,7 @@ def create_sampling_metadata(
        output_token_ids=[],
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
-       logitsprocs=LogitsProcessorManager(),
+       logitsprocs=LogitsProcessors(),
    )

@@ -9,7 +9,7 @@ import torch

from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.v1.sample.logits_processor import LogitsProcessorManager
+from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler

@@ -173,7 +173,7 @@ def _create_default_sampling_metadata(
        no_penalties=True,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
-       logitsprocs=LogitsProcessorManager(),
+       logitsprocs=LogitsProcessors(),
    )
    return fake_sampling_metadata

@@ -13,7 +13,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
-from vllm.v1.sample.logits_processor import LogitsProcessorManager
+from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -169,7 +169,7 @@ def _construct_expected_sampling_metadata(
            and all(x == 1 for x in repetition_penalties)),
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=bad_words_token_ids,
-       logitsprocs=LogitsProcessorManager(),
+       logitsprocs=LogitsProcessors(),
    )

@@ -62,6 +62,7 @@ if TYPE_CHECKING:
                                        QuantizationConfig)
    from vllm.model_executor.model_loader import LoadFormats
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
    from vllm.v1.sample.logits_processor import LogitsProcessor

    HfOverrides = Union[dict, Callable[[type], type]]
else:
@@ -72,6 +73,7 @@ else:
    BaseModelLoader = Any
    LoadFormats = Any
    TensorizerConfig = Any
    LogitsProcessor = Any
    HfOverrides = Union[dict[str, Any], Callable[[type], type]]

    me_quant = LazyLoader("model_executor", globals(),
@@ -465,6 +467,9 @@ class ModelConfig:
    - "transformers" will use the Transformers model implementation."""
    override_attention_dtype: Optional[str] = None
    """Override dtype for attention"""
    logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
    """One or more logits processors' fully-qualified class names or class
    definitions"""

    def compute_hash(self) -> str:
        """
@@ -43,6 +43,7 @@ from vllm.transformers_utils.config import is_interleaved
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                        GiB_bytes, get_ip, is_in_ray_actor)
from vllm.v1.sample.logits_processor import LogitsProcessor

# yapf: enable

@@ -435,6 +436,10 @@ class EngineArgs:
    enable_multimodal_encoder_data_parallel: bool = \
        ParallelConfig.enable_multimodal_encoder_data_parallel

    logits_processors: Optional[list[Union[
        str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
    """Custom logitproc types"""

    async_scheduling: bool = SchedulerConfig.async_scheduling
    # DEPRECATED
    enable_prompt_adapter: bool = False
@@ -549,6 +554,8 @@ class EngineArgs:
                                 **model_kwargs["model_impl"])
        model_group.add_argument("--override-attention-dtype",
                                 **model_kwargs["override_attention_dtype"])
        model_group.add_argument("--logits-processors",
                                 **model_kwargs["logits_processors"])

        # Model loading arguments
        load_kwargs = get_kwargs(LoadConfig)
@@ -940,6 +947,7 @@ class EngineArgs:
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
            override_attention_dtype=self.override_attention_dtype,
            logits_processors=self.logits_processors,
        )

    def validate_tensorizer_args(self):
@@ -55,6 +55,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                               get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
from vllm.v1.sample.logits_processor import LogitsProcessor

if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric
@@ -198,6 +199,8 @@ class LLM:
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                                type[LogitsProcessor]]]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor."""
@@ -272,6 +275,7 @@ class LLM:
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

@@ -2562,7 +2562,7 @@ def direct_register_custom_op(

def resolve_obj_by_qualname(qualname: str) -> Any:
    """
-   Resolve an object by its fully qualified name.
+   Resolve an object by its fully-qualified class name.
    """
    module_name, obj_name = qualname.rsplit(".", 1)
    module = importlib.import_module(module_name)
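For illustration (an assumed call, not from the diff), the helper resolves dotted names by splitting on the last "." and importing the module part; note this dotted form differs from the `<module>:<class>` FQCN syntax used for logits processors.

from os.path import join
assert resolve_obj_by_qualname("os.path.join") is join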
185  vllm/v1/sample/logits_processor/__init__.py  Normal file
@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import itertools
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional, Union

import torch

from vllm.logger import init_logger
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
                                                     MinPLogitsProcessor,
                                                     MinTokensLogitsProcessor)
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
                                                       LogitsProcessor,
                                                       MoveDirectionality)
from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder,
                                                   LogitsProcessors)

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)

# Error message when the user tries to initialize vLLM with a pooling model
# and custom logitsprocs
STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
                                   " logits processors.")

LOGITSPROCS_GROUP = 'vllm.logits_processors'

BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
    MinTokensLogitsProcessor,
    LogitBiasLogitsProcessor,
    MinPLogitsProcessor,
]


def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """Load all installed logit processor plugins"""

    import sys

    if sys.version_info < (3, 10):
        from importlib_metadata import entry_points
    else:
        from importlib.metadata import entry_points

    installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP)
    if len(installed_logitsprocs_plugins) == 0:
        logger.debug("No logitsprocs plugins installed (group %s).",
                     LOGITSPROCS_GROUP)
        return []

    # Load logitsprocs plugins
    logger.debug("Loading installed logitsprocs plugins (group %s):",
                 LOGITSPROCS_GROUP)
    classes: list[type[LogitsProcessor]] = []
    for entrypoint in installed_logitsprocs_plugins:
        try:
            logger.debug("- Loading logitproc plugin entrypoint=%s target=%s",
                         entrypoint.name, entrypoint.value)
            classes.append(entrypoint.load())
        except Exception as e:
            raise RuntimeError(
                f"Failed to load LogitsProcessor plugin {entrypoint}") from e
    return classes


def _load_logitsprocs_by_fqcns(
    logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]]
) -> list[type[LogitsProcessor]]:
    """Load logit processor types, identifying them by fully-qualified class
    names (FQCNs).

    Effectively, a mixed list of logitproc types and FQCN strings is converted
    into a list of entirely logitproc types, by loading from the FQCNs.

    FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc

    Already-loaded logitproc types must be subclasses of LogitsProcessor

    Args:
      logits_processors: Potentially mixed list of logitproc types and FQCN
                         strings for logitproc types

    Returns:
      List of logitproc types

    """
    if not logits_processors:
        return []

    logger.debug(
        "%s additional custom logits processors specified, checking whether "
        "they need to be loaded.", len(logits_processors))

    classes: list[type[LogitsProcessor]] = []
    for ldx, logitproc in enumerate(logits_processors):
        if isinstance(logitproc, type):
            logger.debug(" - Already-loaded logit processor: %s",
                         logitproc.__name__)
            if not issubclass(logitproc, LogitsProcessor):
                raise ValueError(
                    f"{logitproc.__name__} is not a subclass of LogitsProcessor"
                )
            classes.append(logitproc)
            continue

        logger.debug("- Loading logits processor %s", logitproc)
        module_path, qualname = logitproc.split(":")

        try:
            # Load module
            module = importlib.import_module(module_path)
        except Exception as e:
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e

        # Walk down dotted name to get logitproc class
        obj = module
        for attr in qualname.split("."):
            obj = getattr(obj, attr)
        if not isinstance(obj, type):
            raise ValueError("Loaded logit processor must be a type.")
        if not issubclass(obj, LogitsProcessor):
            raise ValueError(
                f"{obj.__name__} must be a subclass of LogitsProcessor")
        classes.append(obj)

    return classes


def _load_custom_logitsprocs(
    logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]],
) -> list[type[LogitsProcessor]]:
    """Load all custom logits processors.

    * First load all installed logitproc plugins
    * Second load custom logitsprocs passed by the user at initialization time

    Args:
      logits_processors: potentially mixed list of logitproc types and
                         logitproc type fully-qualified names (FQCNs)
                         which need to be loaded

    Returns:
      A list of all loaded logitproc types
    """
    from vllm.platforms import current_platform
    if current_platform.is_tpu():
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs,
        # so none are loaded here.
        return []

    return (_load_logitsprocs_plugins() +
            _load_logitsprocs_by_fqcns(logits_processors))


def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
    if is_pooling_model:
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        logger.debug("Skipping logits processor loading because pooling models"
                     " do not support logits processors.")
        return LogitsProcessors()
    custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs)
    return LogitsProcessors(
        ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))


__all__ = [
    "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor",
    "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder",
    "MoveDirectionality", "LogitsProcessors", "build_logitsprocs",
    "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP"
]
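A hedged usage sketch of the new factory, mirroring the call made in the updated sampler test above; the custom FQCN is a placeholder:

import torch

from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import build_logitsprocs

procs = build_logitsprocs(
    vllm_config=VllmConfig(),
    device=torch.device("cpu"),
    is_pin_memory=False,
    is_pooling_model=False,
    custom_logitsprocs=["my_pkg.procs:MyLogitsProcessor"],  # hypothetical FQCN
)
# procs is a LogitsProcessors container; procs.all iterates every constructed
# processor (the built-ins plus any custom ones).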
@@ -1,241 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import Optional, Union
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional

import torch
from torch._prims_common import DeviceLikeType

from vllm import PoolingParams, SamplingParams
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
                                                       LogitsProcessor,
                                                       MoveDirectionality)

logger = init_logger(__name__)


class MoveDirectionality(Enum):
    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = 0
    # Two-way i1<->i2 req swap within batch
    SWAP = 1


# (index, params, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int


@dataclasses.dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs"""
    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Note: each added request is represented as
    # (index, params, output_tok_ids)
    # Key assumption: output_tok_ids is a reference to the
    # request's running output tokens list; in this way
    # the logits processors always see the latest list of
    # generated tokens
    removed: Sequence[RemovedRequest]
    moved: Sequence[MovedRequest]
    added: Sequence[AddedRequest]


class BatchUpdateBuilder:
    """Helps track persistent batch state changes and build
    a batch update data structure for logitsprocs

    Assumptions:
    * All information about requests removed from persistent batch
      during a step is aggregated in self._removed through calls to
      self.removed_append() at the beginning of a step. This must happen
      before the first time that self.removed, self.pop_removed()
      or self.peek_removed() are invoked in a given step
    * After the first time that self.removed, self.pop_removed()
      or self.peek_removed() are read in a step, no new removals
      are registered using self.removed_append()
    * Elements of self._removed are never directly modified, added or
      removed (i.e. modification is only via self.removed_append() and
      self.pop_removed())

    Guarantees under above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both return
      the lowest removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    moved: list[MovedRequest]
    added: list[AddedRequest]

    def __init__(
        self,
        removed: Optional[list[RemovedRequest]] = None,
        moved: Optional[list[MovedRequest]] = None,
        added: Optional[list[AddedRequest]] = None,
    ) -> None:
        self._removed = removed or []
        self.moved = moved or []
        self.added = added or []
        self._is_removed_sorted = False

    def _ensure_removed_sorted(self) -> None:
        """Sort removed request indices in
        descending order.

        Idempotent after first call in a
        given step, until reset.
        """
        if not self._is_removed_sorted:
            self._removed.sort(reverse=True)
            self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices sorted in
        descending order"""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Register the removal of a request from
        the persistent batch.

        Must not be called after the first time
        self.removed, self.pop_removed() or
        self.peek_removed() are invoked.

        Args:
          index: request index
        """
        if self._is_removed_sorted:
            raise RuntimeError("Cannot register new removed request after"
                               " self.removed has been read.")
        self._removed.append(index)

    def has_removed(self) -> bool:
        return bool(self._removed)

    def peek_removed(self) -> Optional[int]:
        """Return lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed[-1]
        return None

    def pop_removed(self) -> Optional[int]:
        """Pop lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed.pop()
        return None

    def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
        """Generate a logitsprocs batch update data structure
        and reset internal batch update builder state.

        Args:
          batch_size: current persistent batch size

        Returns:
          Frozen logitsprocs batch update instance; `None` if no updates
        """
        # Reset removal-sorting logic
        self._is_removed_sorted = False
        if not any((self._removed, self.moved, self.added)):
            # No update; short-circuit
            return None
        # Build batch state update
        batch_update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        # Reset removed/moved/added update lists
        self._removed = []
        self.moved = []
        self.added = []
        return batch_update
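(`BatchUpdateBuilder` survives this refactor in `vllm/v1/sample/logits_processor/state.py`; the following is a hedged sketch, not part of the diff, of the ordering guarantee its docstring describes.)

from vllm.v1.sample.logits_processor import BatchUpdateBuilder

builder = BatchUpdateBuilder()
# Register removals first, in any order...
builder.removed_append(7)
builder.removed_append(2)
builder.removed_append(5)
# ...then reads return the lowest index first (descending internal sort).
assert builder.peek_removed() == 2
assert builder.pop_removed() == 2
assert builder.pop_removed() == 5
# get_and_reset() freezes the accumulated changes into a BatchUpdate (or None)
# and clears the builder for the next step.
update = builder.get_and_reset(batch_size=16)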


class LogitsProcessor(ABC):

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.
        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        TODO(andy): won't be utilized until logits
        processors are user-extensible
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional[BatchUpdate],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
          batch_update is non-None iff there have been
          changes to the batch makeup.
        """
        raise NotImplementedError


@dataclass
class LogitsProcessorManager:
    """Encapsulates initialized logitsproc objects."""
    argmax_invariant: list[LogitsProcessor] = field(
        default_factory=list)  # argmax-invariant logitsprocs
    non_argmax_invariant: list[LogitsProcessor] = field(
        default_factory=list)  # non-argmax-invariant logitsprocs

    @property
    def all(self) -> Iterator[LogitsProcessor]:
        """Iterator over all logits processors."""
        return chain(self.argmax_invariant, self.non_argmax_invariant)


###### ----- Built-in LogitsProcessor impls below here
if TYPE_CHECKING:
    from vllm.config import VllmConfig


class MinPLogitsProcessor(LogitsProcessor):

-   def __init__(self, max_num_reqs: int, pin_memory: bool,
-                device: DeviceLikeType):
-       super().__init__()
+   def __init__(self, vllm_config: "VllmConfig", device: torch.device,
+                is_pin_memory: bool):
+       max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.min_p_count: int = 0

        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
-                                           pin_memory=pin_memory)
+                                           pin_memory=is_pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

-       self.use_double_tensor = torch.device("cpu") != torch.device(device)
+       self.use_double_tensor = torch.device(device).type != "cpu"

        if self.use_double_tensor:
            # Pre-allocated device tensor
@@ -260,8 +51,8 @@ class MinPLogitsProcessor(LogitsProcessor):

        needs_update = False
        # Process added requests.
-       for index, params, _ in batch_update.added:
-           min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
+       for index, params, _, _ in batch_update.added:
+           min_p = params.min_p
            if self.min_p_cpu[index] != min_p:
                needs_update = True
            self.min_p_cpu[index] = min_p
@@ -316,11 +107,10 @@ class MinPLogitsProcessor(LogitsProcessor):

class LogitBiasLogitsProcessor(LogitsProcessor):

-   def __init__(self, pin_memory: bool, device: torch.device):
-       super().__init__()
-       self.biases: dict[int, dict[int, float]] = {}
+   def __init__(self, _, device: torch.device, is_pin_memory: bool):
        self.device = device
-       self.pin_memory = pin_memory
+       self.pin_memory = is_pin_memory
+       self.biases: dict[int, dict[int, float]] = {}

        self.bias_tensor: torch.Tensor = torch.tensor(())
        self.logits_slice = (self._device_tensor([], torch.int32),
@@ -337,9 +127,8 @@ class LogitBiasLogitsProcessor(LogitsProcessor):

        needs_update: bool = False
        # Process added requests.
-       for index, params, _ in batch_update.added:
-           if isinstance(params, SamplingParams) and (lb :=
-                                                       params.logit_bias):
+       for index, params, _, _ in batch_update.added:
+           if lb := params.logit_bias:
                self.biases[index] = lb
                needs_update = True
            else:
@@ -400,12 +189,12 @@ class LogitBiasLogitsProcessor(LogitsProcessor):

class MinTokensLogitsProcessor(LogitsProcessor):

-   def __init__(self, pin_memory: bool, device: torch.device):
+   def __init__(self, vllm_config: "VllmConfig", device: torch.device,
+                is_pin_memory: bool):
        # index -> (min_toks, output_token_ids, stop_token_ids)
-       super().__init__()
-       self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
        self.device = device
-       self.pin_memory = pin_memory
+       self.pin_memory = is_pin_memory
+       self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}

        # (req_idx_tensor,eos_tok_id_tensor)
        self.logits_slice: tuple[torch.Tensor,
@@ -424,9 +213,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
|
||||
if batch_update:
|
||||
# Process added requests.
|
||||
for index, params, output_tok_ids in batch_update.added:
|
||||
if (isinstance(params, SamplingParams)
|
||||
and (min_tokens := params.min_tokens)
|
||||
for index, params, _, output_tok_ids in batch_update.added:
|
||||
if ((min_tokens := params.min_tokens)
|
||||
and len(output_tok_ids) < min_tokens):
|
||||
# Replace request metadata at batch index
|
||||
self.min_toks[index] = (min_tokens, output_tok_ids,
|
||||
@ -499,35 +287,3 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
# Inhibit EOS token for requests which have not reached min length
|
||||
logits[self.logits_slice] = -float("inf")
|
||||
return logits
|
||||
|
||||
|
||||
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
|
||||
device: torch.device) -> LogitsProcessorManager:
|
||||
"""Construct 'builtin' vLLM logitsprocs which the engine
|
||||
loads by default.
|
||||
|
||||
Args:
|
||||
      pin_memory_available: pinned memory is available for use by logitsprocs
      max_num_reqs: ceiling on request count in persistent batch
      device: inference device

    Returns:
      Data structure encapsulating loaded logitsprocs
    """
    min_tokens_logitproc = MinTokensLogitsProcessor(
        pin_memory=pin_memory_available, device=device)
    logit_bias_logitproc = LogitBiasLogitsProcessor(
        pin_memory=pin_memory_available, device=device)
    min_p_logitproc = MinPLogitsProcessor(
        pin_memory=pin_memory_available,
        device=device,
        # +1 for temporary swap space
        max_num_reqs=max_num_reqs + 1)
    return LogitsProcessorManager(
        non_argmax_invariant=[
            min_tokens_logitproc,
            logit_bias_logitproc,
        ],
        argmax_invariant=[min_p_logitproc],
    )
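
The grouping above places min-p in the argmax-invariant bucket and min-tokens/logit-bias in the non-argmax-invariant bucket. A small self-contained sketch of why that classification holds (not part of the diff; all values are made up):

# Illustrative sketch only: min-p masking never changes the greedy argmax,
# while a logit bias can.
import torch

logits = torch.tensor([2.0, 1.0, -1.0])
probs = torch.softmax(logits, dim=-1)

# Min-p removes tokens whose probability falls below min_p * max_prob;
# the top token always survives, so the argmax is unchanged.
min_p = 0.2
masked = torch.where(probs < min_p * probs.max(),
                     torch.tensor(float("-inf")), logits)
assert masked.argmax() == logits.argmax()

# A logit bias can promote a different token, changing the argmax.
biased = logits.clone()
biased[1] += 5.0
assert biased.argmax() != logits.argmax()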
86
vllm/v1/sample/logits_processor/interface.py
Normal file
@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, Optional

import torch

from vllm import SamplingParams

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class MoveDirectionality(Enum):
    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = auto()
    # Two-way i1<->i2 req swap within batch
    SWAP = auto()


# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]

# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]

# Batch indices of any removed requests.
RemovedRequest = int


@dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs"""
    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Key assumption: the `output_tok_ids` list (which is an element of each
    # tuple in `added`) is a reference to the request's running output tokens
    # list; via this reference, the logits processors always see the latest
    # list of generated output tokens
    removed: Sequence[RemovedRequest]
    moved: Sequence[MovedRequest]
    added: Sequence[AddedRequest]


class LogitsProcessor(ABC):

    @abstractmethod
    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.
        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional["BatchUpdate"],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update is non-None iff there have been
            changes to the batch makeup.
        """
        raise NotImplementedError
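
As a rough illustration of this interface (not part of the commit), a minimal subclass might look like the sketch below; the class name and the fixed banned token id are hypothetical.

# Hypothetical example only: suppress one fixed token id for every request.
class SuppressTokenLogitsProcessor(LogitsProcessor):

    BANNED_TOKEN_ID = 7  # assumed value for illustration

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool) -> None:
        self.device = device

    def is_argmax_invariant(self) -> bool:
        # Masking a token can change which token has the highest logit.
        return False

    def update_state(self, batch_update: Optional["BatchUpdate"]) -> None:
        # Stateless example: no per-request bookkeeping needed.
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        logits[:, self.BANNED_TOKEN_ID] = float("-inf")
        return logits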
149
vllm/v1/sample/logits_processor/state.py
Normal file
@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from itertools import chain
from typing import TYPE_CHECKING, Optional

from vllm.v1.sample.logits_processor.interface import (AddedRequest,
                                                       BatchUpdate,
                                                       MovedRequest,
                                                       RemovedRequest)

if TYPE_CHECKING:
    from vllm.v1.sample.logits_processor.interface import LogitsProcessor


class BatchUpdateBuilder:
    """Helps track persistent batch state changes and build
    a batch update data structure for logitsprocs
    Assumptions:
    * All information about requests removed from persistent batch
      during a step is aggregated in self._removed through calls to
      self.removed_append() at the beginning of a step. This must happen
      before the first time that self.removed, self.pop_removed()
      or self.peek_removed() are invoked in a given step
    * After the first time that self.removed, self.pop_removed()
      or self.peek_removed() are read in a step, no new removals
      are registered using self.removed_append()
    * Elements of self._removed are never directly modified, added or
      removed (i.e. modification is only via self.removed_append() and
      self.pop_removed())
    Guarantees under above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both return
      the lowest removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    moved: list[MovedRequest]
    added: list[AddedRequest]

    def __init__(
        self,
        removed: Optional[list[RemovedRequest]] = None,
        moved: Optional[list[MovedRequest]] = None,
        added: Optional[list[AddedRequest]] = None,
    ) -> None:
        self._removed = removed or []
        self.moved = moved or []
        self.added = added or []
        self._is_removed_sorted = False

    def _ensure_removed_sorted(self) -> None:
        """Sort removed request indices in
        descending order.
        Idempotent after first call in a
        given step, until reset.
        """
        if not self._is_removed_sorted:
            self._removed.sort(reverse=True)
            self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices sorted in
        descending order"""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Register the removal of a request from the persistent batch.

        Must not be called after the first time self.removed,
        self.pop_removed() or self.peek_removed() are invoked.

        Args:
            index: request index
        """
        if self._is_removed_sorted:
            raise RuntimeError("Cannot register new removed request after"
                               " self.removed has been read.")
        self._removed.append(index)

    def has_removed(self) -> bool:
        return bool(self._removed)

    def peek_removed(self) -> Optional[int]:
        """Return lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed[-1]
        return None

    def pop_removed(self) -> Optional[int]:
        """Pop lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed.pop()
        return None

    def _is_update(self) -> bool:
        """True if there is a batch state change"""
        return any((self._removed, self.moved, self.added))

    def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
        """Generate a logitsprocs batch update data structure and reset
        internal batch update builder state.

        Args:
            batch_size: current persistent batch size

        Returns:
            Frozen logitsprocs batch update instance; `None` if no updates
        """
        # Reset removal-sorting logic
        self._is_removed_sorted = False
        if not self._is_update():
            # No update; short-circuit
            return None
        # Build batch state update
        batch_update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        self._removed = []
        self.moved = []
        self.added = []
        return batch_update


class LogitsProcessors:
    """Encapsulates initialized logitsproc objects."""

    def __init__(
            self,
            logitsprocs: Optional[Iterator["LogitsProcessor"]] = None) -> None:
        self.argmax_invariant: list[LogitsProcessor] = []
        self.non_argmax_invariant: list[LogitsProcessor] = []
        if logitsprocs:
            for logitproc in logitsprocs:
                (self.argmax_invariant if logitproc.is_argmax_invariant() else
                 self.non_argmax_invariant).append(logitproc)

    @property
    def all(self) -> Iterator["LogitsProcessor"]:
        """Iterator over all logits processors."""
        return chain(self.argmax_invariant, self.non_argmax_invariant)
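
A short sketch of the removal-tracking contract documented on BatchUpdateBuilder above (illustrative only; the indices are made up):

builder = BatchUpdateBuilder()
for idx in (5, 1, 3):
    builder.removed_append(idx)

assert builder.peek_removed() == 1       # lowest removed index
assert builder.removed == [5, 3, 1]      # descending order
assert builder.pop_removed() == 1

update = builder.get_and_reset(batch_size=8)
assert update is not None and list(update.removed) == [5, 3]
assert builder.get_and_reset(batch_size=8) is None  # builder state was reset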
@ -6,7 +6,7 @@ from typing import Optional

import torch

from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.logits_processor import LogitsProcessors


@dataclass
@ -40,4 +40,4 @@ class SamplingMetadata:
    bad_words_token_ids: dict[int, list[list[int]]]

    # Loaded logits processors
    logitsprocs: LogitsProcessorManager
    logitsprocs: LogitsProcessors

@ -18,8 +18,8 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
                                             MoveDirectionality,
                                             init_builtin_logitsprocs)
                                             LogitsProcessors,
                                             MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
@ -78,8 +78,11 @@ class InputBatch:
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
    ):
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
@ -221,14 +224,6 @@ class InputBatch:
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # Define logits processors.
        # TODO(andy): logits processor list should be extensible via engine
        # constructor argument; for now the list is fixed.
        self.logitsprocs = init_builtin_logitsprocs(
            pin_memory_available=pin_memory,
            max_num_reqs=max_num_reqs + 1,
            device=device)

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
@ -244,6 +239,10 @@ class InputBatch:

        self.req_output_token_ids: list[Optional[list[int]]] = []

        # Store provided logitsprocs. If none are provided, initialize empty
        # data structure
        self.logitsprocs = logitsprocs or LogitsProcessors()

        # This is updated each time the batch constituents change.
        self.sampling_metadata = self._make_sampling_metadata()

@ -255,28 +254,35 @@ class InputBatch:
        # while performing state updates to the batch.
        return cast(list[str], self._req_ids)

    def _get_next_add_index(self) -> int:
        if (req_index := self.batch_update_builder.pop_removed()) is not None:
            # Fill the empty index.
            return req_index
        # Append to end
        return self.num_reqs

    def _register_add_request(self, request: "CachedRequestState") -> int:
        """Track add-request operations"""
        req_index = self._get_next_add_index()
        assert req_index < self.max_num_reqs
        params = (request.sampling_params
                  if request.sampling_params else request.pooling_params)
        """Track add-request operations for logits processors.
        Not applicable to pooling models.
        """

        # Detailed added request metadata is only required for non-pooling
        # models, to support logitsprocs
        assert request.sampling_params

        # Fill the next empty index if there is one.
        if (new_req_index := self.batch_update_builder.pop_removed()) is None:
            # Append to end otherwise.
            new_req_index = self.num_reqs

        assert new_req_index < self.max_num_reqs
        self.batch_update_builder.added.append(
            (req_index, params, request.output_token_ids))
        return req_index
            (new_req_index, request.sampling_params, request.prompt_token_ids,
             request.output_token_ids))
        return new_req_index

    def add_request(
        self,
        request: "CachedRequestState",
    ) -> int:
        req_index = self._register_add_request(request)
        if not self.is_pooling_model:
            # New request index bookkeeping for autoregressive models.
            req_index = self._register_add_request(request)
        else:
            req_index = self.num_reqs

        req_id = request.req_id
        if req_index == len(self._req_ids):
@ -411,7 +417,10 @@ class InputBatch:
        req_index = self.req_id_to_index.pop(req_id, None)
        if req_index is None:
            return None
        self.batch_update_builder.removed_append(req_index)
        if not self.is_pooling_model:
            # Autoregressive models require bookkeeping of removed requests to
            # support logitsprocs.
            self.batch_update_builder.removed_append(req_index)
        self._req_ids[req_index] = None
        self.req_output_token_ids[req_index] = None

@ -446,6 +455,8 @@ class InputBatch:
        return req_index

    def swap_states(self, i1: int, i2: int) -> None:
        # For autoregressive models, track detailed request reordering info
        # to support logitsprocs
        self.batch_update_builder.moved.append(
            (i1, i2, MoveDirectionality.SWAP))
        old_id_i1 = self._req_ids[i1]
@ -513,11 +524,18 @@ class InputBatch:
            swaps: list of (from,to) swap tuples for moved requests
            empty_req_indices: indices not filled by condensation
        """
        num_reqs = self.num_reqs

        if self.is_pooling_model:
            # Will be contiguous in pooling case, just trim the lists.
            del self._req_ids[num_reqs:]
            del self.req_output_token_ids[num_reqs:]
            return

        if not (empty_req_indices := self.batch_update_builder.removed):
            # All removed requests were replaced by added requests, or else no
            # requests were removed at all. No condense() needed
            return
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
@ -541,6 +559,8 @@ class InputBatch:
                # Move active request down into empty request
                # index.
                self.batch_update_builder.pop_removed()
                # Autoregressive models require detailed tracking of condense
                # operations to support logitsprocs
                self.batch_update_builder.moved.append(
                    (last_req_index, empty_index,
                     MoveDirectionality.UNIDIRECTIONAL))
@ -596,15 +616,20 @@ class InputBatch:
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]

    def refresh_metadata(self):
        """Apply batch updates, reset input batch at end of step
        """Apply any batch updates to sampling metadata."""

        * Apply batch add/remove/permute to logits procs' states
        * If batch state is modified, update sampling metadata
        """
        if self.is_pooling_model:
            # Batch changes every step for pooling models.
            self.sampling_metadata = self._make_sampling_metadata()
            return

        # For non-pooling models - generate and apply logitsprocs update;
        # reset batch update tracking.
        # Update sampling metadata if batch state is changed.
        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
        for logit_proc in self.logitsprocs.all:
            logit_proc.update_state(batch_update)
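
Putting the pieces together, the per-step contract that refresh_metadata relies on looks roughly like the sketch below (illustrative only; `procs`, `builder`, and the tensor shape are assumptions, and apply() is driven by the sampler rather than by this code):

import torch

procs = LogitsProcessors()       # real engines pass constructed processors
builder = BatchUpdateBuilder()
logits = torch.randn(2, 1000)    # assumed [num_reqs, vocab_size]

builder.removed_append(1)        # request at batch index 1 left the batch
update = builder.get_and_reset(batch_size=2)
for proc in procs.all:
    proc.update_state(update)    # observe batch makeup changes first
    logits = proc.apply(logits)  # then transform the logits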

@ -68,6 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
@ -80,7 +81,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

from ..sample.logits_processor import LogitsProcessorManager
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
                    gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
@ -221,6 +221,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.cache_config.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config, self.device, self.pin_memory,
                self.is_pooling_model,
                self.vllm_config.model_config.logits_processors),
            is_pooling_model=self.is_pooling_model,
        )

        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@ -2447,7 +2452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            output_token_ids=[[] for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=LogitsProcessorManager(),
            logitsprocs=LogitsProcessors(),
        )
        try:
            sampler_output = self.sampler(logits=logits,
@ -2968,6 +2973,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=block_sizes,
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=self.input_batch.logitsprocs,
            is_pooling_model=self.is_pooling_model,
        )

    def _allocate_kv_cache_tensors(