mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-12 13:06:11 +08:00
[V1] Wrapper which plumbs request-level logits processors into vLLM batch-level logits processing (#23656)
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
This commit is contained in:
parent
e32a0e8678
commit
136d853e65
151
examples/offline_inference/logits_processor/custom_req.py
Normal file
151
examples/offline_inference/logits_processor/custom_req.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""This example demonstrates wrapping a request-level logits processor to be
|
||||||
|
compatible with vLLM's batch-level logits processing
|
||||||
|
|
||||||
|
For demo purposes, a dummy logits processor is employed which, if
|
||||||
|
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
|
||||||
|
will mask out all tokens except `target_token`. This logits processor can be
|
||||||
|
applied to a vector of logits associated with a single decode step for a single
|
||||||
|
request. The logits processor cannot be applied to a request which does not
|
||||||
|
pass in a `target_token` custom argument.
|
||||||
|
|
||||||
|
The request-level dummy logits processor is wrapped to create a batch-level
|
||||||
|
logits processor, which can apply the logits processor to output logits from
|
||||||
|
all requests in the persistent batch in a given decode step. For requests which
|
||||||
|
do not provide a `target_token` argument, the corresponding row of `logits`
|
||||||
|
will not be modified.
|
||||||
|
|
||||||
|
A batch is constructed with `temperature=0.0` and 50% of requests specifying
|
||||||
|
`target_token`, and for these requests - and *only* these requests - we
|
||||||
|
expect the `target_token` to be decoded in each step, yielding an output
|
||||||
|
similar to that shown below:
|
||||||
|
|
||||||
|
Generated Outputs:
|
||||||
|
------------------------------------------------------------
|
||||||
|
Prompt: 'Hello, my name is'
|
||||||
|
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
|
||||||
|
------------------------------------------------------------
|
||||||
|
Prompt: 'The president of the United States is'
|
||||||
|
Output: " not a racist. He is a racist.\nHe's a racist because he"
|
||||||
|
------------------------------------------------------------
|
||||||
|
Prompt: 'The capital of France is'
|
||||||
|
Output: ' also also also also also also also also also also also also also
|
||||||
|
also also also'
|
||||||
|
------------------------------------------------------------
|
||||||
|
Prompt: 'The future of AI is'
|
||||||
|
Output: ' in the hands of the people.\n\nThe future of AI is in the'
|
||||||
|
------------------------------------------------------------
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.sample.logits_processor import (
|
||||||
|
AdapterLogitsProcessor,
|
||||||
|
RequestLogitsProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyPerReqLogitsProcessor:
    """Request-level logits processor that masks out every logit except the
    one at index `target_token`."""

    def __init__(self, target_token: int) -> None:
        """Remember which token id must survive masking."""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        # Preserve the target token's logit, flood everything else with
        # -inf, then restore it — greedy sampling can only pick
        # `target_token` afterwards.
        kept = logits[self.target_token].item()
        logits.fill_(float("-inf"))
        logits[self.target_token] = kept
        return logits
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of wrapping a fake request-level logit processor to create a
    batch-level logits processor"""

    def is_argmax_invariant(self) -> bool:
        # Masking logits can change the argmax, so this processor must run
        # even for greedy (temperature=0.0) sampling.
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """This method returns a new request-level logits processor, customized
        to the `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor the request must have
        a "target_token" custom argument with an integer value.

        Args:
          params: per-request sampling params

        Returns:
          `Callable` request logits processor, or None
        """
        # Guard explicitly on a missing/empty extra_args dict. The previous
        # `params.extra_args and params.extra_args.get(...)` form evaluated
        # to `{}` (not None) for an empty dict, falling through to the
        # isinstance check and logging a spurious warning.
        if not params.extra_args:
            return None
        target_token: Optional[Any] = params.extra_args.get("target_token")
        if target_token is None:
            return None
        if not isinstance(target_token, int):
            logger.warning(
                "target_token value %s is not int; not applying logits"
                " processor to request.",
                target_token,
            )
            return None
        return DummyPerReqLogitsProcessor(target_token)
|
||||||
|
|
||||||
|
|
||||||
|
# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc:
# even-indexed requests supply `target_token`, odd-indexed requests do not.
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": token})
    if token is not None else SamplingParams(temperature=0.0)
    for token in (128, None, 67, None)
]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run a small generation batch through an `LLM` configured with the
    wrapped per-request logits processor, then print each result."""
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[WrappedPerReqLogitsProcessor],
    )
    # Generate texts from the prompts; each RequestOutput pairs a prompt
    # with its generated continuation(s) and other information.
    outputs = llm.generate(prompts, sampling_params_list)

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    separator = "-" * 60
    for request_output in outputs:
        completion = request_output.outputs[0].text
        print(f"Prompt: {request_output.prompt!r}")
        print(f"Output: {completion!r}")
        print(separator)


if __name__ == "__main__":
    main()
|
||||||
165
examples/offline_inference/logits_processor/custom_req_init.py
Normal file
165
examples/offline_inference/logits_processor/custom_req_init.py
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""This example demonstrates a special case of wrapping a request-level logits
|
||||||
|
processor, namely the case where it is necessary to utilize engine config or
|
||||||
|
environment info passed to the constructor. The subclass must override the
|
||||||
|
wrapper base class `__init__()` method to access the engine config, the device
|
||||||
|
identifier, or the flag which indicates whether pinned memory is available.
|
||||||
|
|
||||||
|
For demo purposes, a request-level dummy logits processor is employed which
|
||||||
|
causes the same token (`target_token`) to be decoded in each step. The
|
||||||
|
request-level dummy logits processor is wrapped to create a batch-level logits
|
||||||
|
processor, which can apply the logits processor to output logits from all
|
||||||
|
requests in the persistent batch in a given decode step.
|
||||||
|
|
||||||
|
The wrapped dummy logits processor below models a scenario where we must
|
||||||
|
disable the logits processor on non-"cuda" platforms. The wrapper base class
|
||||||
|
`__init__()` is overridden in order to check this condition and set a flag.
|
||||||
|
|
||||||
|
A batch is constructed with `temperature=0.0` and 50% of requests specifying
|
||||||
|
`target_token`, and for these requests - and *only* these requests - we
|
||||||
|
expect that on a "cuda" device the output will look something like:
|
||||||
|
|
||||||
|
Generated Outputs:
|
||||||
|
------------------------------------------------------------
|
||||||
|
Prompt: 'Hello, my name is'
|
||||||
|
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
|
||||||
|
------------------------------------------------------------
|
||||||
|
Prompt: 'The president of the United States is'
|
||||||
|
Output: " not a racist. He is a racist.\nHe's a racist because he"
|
||||||
|
------------------------------------------------------------
|
||||||
|
Prompt: 'The capital of France is'
|
||||||
|
Output: ' also also also also also also also also also also also also also
|
||||||
|
also also also'
|
||||||
|
------------------------------------------------------------
|
||||||
|
Prompt: 'The future of AI is'
|
||||||
|
Output: ' in the hands of the people.\n\nThe future of AI is in the'
|
||||||
|
------------------------------------------------------------
|
||||||
|
|
||||||
|
which indicates that the logits processor is running. However, on a non-"cuda"
|
||||||
|
device, the first and third requests would not repeat the same token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.sample.logits_processor import (
|
||||||
|
AdapterLogitsProcessor,
|
||||||
|
RequestLogitsProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyPerReqLogitsProcessor:
    """Request-level logits processor: push every logit to -inf except the
    one belonging to `target_token`."""

    def __init__(self, target_token: int) -> None:
        """Store the token id whose logit survives masking."""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        # Save the surviving logit value before masking the whole vector,
        # then write it back so only `target_token` remains selectable.
        survivor = logits[self.target_token].item()
        logits.fill_(float("-inf"))
        logits[self.target_token] = survivor
        return logits
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of overriding the wrapper class `__init__()` in order to utilize
    info about the device type"""

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        # The base-class constructor must run so the wrapper's per-request
        # state bookkeeping is initialized.
        super().__init__(vllm_config, device, is_pin_memory)
        # Model a scenario where the logits processor is only supported on
        # "cuda"-type devices.
        self.is_cuda = device.type == "cuda"

    def is_argmax_invariant(self) -> bool:
        # Masking logits can change the argmax, so this processor must run
        # even for greedy (temperature=0.0) sampling.
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """This method returns a new request-level logits processor, customized
        to the `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor the request must have
        a "target_token" custom argument with an integer value, and the device
        must be "cuda"-type

        Args:
          params: per-request sampling params

        Returns:
          `Callable` request logits processor, or None
        """
        # The previous walrus-in-condition form left `target_token` unbound
        # when the device check short-circuited, and evaluated to `{}` (not
        # None) for an empty extra_args dict, triggering a spurious warning.
        # Split the checks out for correctness and readability.
        if not self.is_cuda or not params.extra_args:
            return None
        target_token = params.extra_args.get("target_token")
        if target_token is None:
            return None
        if not isinstance(target_token, int):
            logger.warning(
                "target_token value %s is not int; not applying logits"
                " processor to request.",
                target_token,
            )
            return None
        return DummyPerReqLogitsProcessor(target_token)
|
||||||
|
|
||||||
|
|
||||||
|
# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc:
# alternate requests supply `target_token`; the others leave it unset.
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": token})
    if token is not None else SamplingParams(temperature=0.0)
    for token in (128, None, 67, None)
]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run a small generation batch through an `LLM` configured with the
    device-aware wrapped per-request logits processor, then print results."""
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[WrappedPerReqLogitsProcessor],
    )
    # Generate texts from the prompts; each RequestOutput pairs a prompt
    # with its generated continuation(s) and other information.
    outputs = llm.generate(prompts, sampling_params_list)

    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    separator = "-" * 60
    for request_output in outputs:
        completion = request_output.outputs[0].text
        print(f"Prompt: {request_output.prompt!r}")
        print(f"Output: {completion!r}")
        print(separator)


if __name__ == "__main__":
    main()
|
||||||
@ -15,6 +15,7 @@ from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
|
|||||||
POOLING_MODEL_NAME, TEMP_GREEDY,
|
POOLING_MODEL_NAME, TEMP_GREEDY,
|
||||||
CustomLogitprocSource,
|
CustomLogitprocSource,
|
||||||
DummyLogitsProcessor,
|
DummyLogitsProcessor,
|
||||||
|
WrappedPerReqLogitsProcessor,
|
||||||
dummy_module)
|
dummy_module)
|
||||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||||
from tests.v1.logits_processors.utils import prompts
|
from tests.v1.logits_processors.utils import prompts
|
||||||
@ -161,6 +162,38 @@ def test_custom_logitsprocs(monkeypatch,
|
|||||||
_run_test(kwargs, logitproc_loaded=True)
|
_run_test(kwargs, logitproc_loaded=True)
|
||||||
|
|
||||||
|
|
||||||
|
@create_new_process_for_each_test()
def test_custom_logitsprocs_req(monkeypatch):
    """Test passing request-level logits processor to offline Python interface

    Wrap a request-level logits processor to create a batch level logits
    processor that has a well-defined behavior (mask out all tokens except one
    `target_token`)

    Construct an `LLM` instance which loads the wrapped logits processor. Pass
    the custom logitproc as a class object.

    Construct a reference `LLM` instance with no custom logitproc

    Pass in a batch of requests, 50% of which pass a `target_token` value
    in through `SamplingParams.extra_args`, 50% of which do not.

    Validate that
    * Requests which do not activate the custom logitproc, yield the same
      results for both `LLM` instances
    * Requests which activate the custom logitproc, only output `target_token`

    Args:
      monkeypatch: for setting env vars
    """

    # Test that logitproc info is passed to workers
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
    # NOTE(review): fixed seed presumably makes `_run_test`'s randomized
    # request construction deterministic across runs — confirm in _run_test.
    random.seed(40)
    _run_test({"logits_processors": [WrappedPerReqLogitsProcessor]},
              logitproc_loaded=True)
|
||||||
|
|
||||||
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
@pytest.mark.parametrize("logitproc_source", [
|
@pytest.mark.parametrize("logitproc_source", [
|
||||||
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
|
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
|
||||||
|
|||||||
@ -3,15 +3,21 @@
|
|||||||
|
|
||||||
import types
|
import types
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
|
from vllm.logger import init_logger
|
||||||
LogitsProcessor)
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP,
|
||||||
|
AdapterLogitsProcessor,
|
||||||
|
BatchUpdate, LogitsProcessor,
|
||||||
|
RequestLogitsProcessor)
|
||||||
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
MODEL_NAME = "facebook/opt-125m"
|
MODEL_NAME = "facebook/opt-125m"
|
||||||
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
|
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
|
||||||
DUMMY_LOGITPROC_ARG = "target_token"
|
DUMMY_LOGITPROC_ARG = "target_token"
|
||||||
@ -104,5 +110,60 @@ class EntryPoints(list):
|
|||||||
self.names = [ep.name for ep in eps]
|
self.names = [ep.name for ep in eps]
|
||||||
|
|
||||||
|
|
||||||
|
class DummyPerReqLogitsProcessor:
    """Request-level logits processor which masks out all logits other than
    the token id identified by `target_token`."""

    def __init__(self, target_token: int) -> None:
        """Record the token id to keep."""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        # Stash the kept logit, mask the entire vector, then restore it so
        # only `target_token` can be selected.
        kept_value = logits[self.target_token].item()
        logits.fill_(float("-inf"))
        logits[self.target_token] = kept_value
        return logits
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of wrapping a fake request-level logit processor to create a
    batch-level logits processor"""

    def is_argmax_invariant(self) -> bool:
        # Masking logits can flip the argmax, so the processor must run even
        # under greedy sampling.
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """This method returns a new request-level logits processor, customized
        to the `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor the request must have
        a "target_token" custom argument with an integer value.

        Args:
          params: per-request sampling params

        Returns:
          `Callable` request logits processor, or None
        """
        # Guard on a missing or empty extra_args dict explicitly: the
        # `extra_args and extra_args.get(...)` form yields `{}` (not None)
        # for an empty dict, which used to trigger a spurious warning below.
        if not params.extra_args:
            return None
        target_token: Optional[Any] = params.extra_args.get("target_token")
        if target_token is None:
            return None
        if not isinstance(target_token, int):
            logger.warning(
                "target_token value %s is not int; not applying logits"
                " processor to request.", target_token)
            return None
        return DummyPerReqLogitsProcessor(target_token)
|
||||||
|
|
||||||
|
|
||||||
"""Fake version of importlib.metadata.entry_points"""
|
"""Fake version of importlib.metadata.entry_points"""
|
||||||
entry_points = lambda group: EntryPoints(group)
|
entry_points = lambda group: EntryPoints(group)
|
||||||
|
|||||||
@ -1,16 +1,22 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
|
from abc import abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
|
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
|
||||||
MinPLogitsProcessor,
|
MinPLogitsProcessor,
|
||||||
MinTokensLogitsProcessor)
|
MinTokensLogitsProcessor,
|
||||||
|
process_dict_updates)
|
||||||
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
|
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
MoveDirectionality)
|
MoveDirectionality)
|
||||||
@ -177,9 +183,112 @@ def build_logitsprocs(
|
|||||||
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
|
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
|
||||||
|
|
||||||
|
|
||||||
|
class AdapterLogitsProcessor(LogitsProcessor):
    """Wrapper for per-request logits processors

    To wrap a specific per-request logits processor,
    * Subclass `AdapterLogitsProcessor`
    * Implement `self.is_argmax_invariant()` base-class method
    * Implement `self.new_req_logits_processor(params)`

    `self.__init__(vllm_config, device, is_pin_memory)` does not need to be
    overridden in general. However, to implement custom constructor behavior -
    especially any logic which operates on or stores `vllm_config`, `device`,
    or `is_pin_memory` - `self.__init__(vllm_config, device, is_pin_memory)`
    must be overridden and the override must call
    `super().__init__(vllm_config, device, is_pin_memory)`
    """

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        """Subclass must invoke
        `super().__init__(vllm_config, device, is_pin_memory)`.

        Subclass constructor may find it useful to utilize the `vllm_config`,
        `device` and `is_pin_memory` argument. However regardless of whether
        these arguments are used, the vLLM logits processor interface requires
        all three arguments to be present.
        """

        # Map req index -> logits processor state
        #
        # State representation is a partial[Tensor] comprising a request-level
        # logits processor with the output token ids argument and (if required)
        # the prompt token ids argument pre-populated
        #
        # Note that the partial carries a *reference* to output token ids, and
        # will thus always operate on the list as it is currently, not as it
        # was when the partial was created.
        self.req_info: dict[int, partial[torch.Tensor]] = {}

    @abstractmethod
    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """Consume request info; return a per-request logits processor.

        Return None if logits processor does not need to be applied to request

        Args:
          params: request sampling params

        Returns:
          None if logits processor should not be applied to request; otherwise
          returns a `RequestLogitsProcessor` instance

        """
        raise NotImplementedError

    def _new_state(
        self,
        params: SamplingParams,
        prompt_ids: list[int],
        output_ids: list[int],
    ) -> Optional[partial[torch.Tensor]]:
        """Return state representation for new request

        Returns None if logits processor is not applicable to request

        Args:
          params: request sampling params
          prompt_ids: request prompt token ids
          output_ids: decoded tokens so far for this request

        Returns:
          logits processor partial[Tensor] or None

        """
        req_lp = self.new_req_logits_processor(params)
        # Compare against None explicitly: the contract is "None means do
        # not apply". A truthiness test (`if req_lp := ...`) would silently
        # drop a valid processor object that happens to be falsy (e.g. one
        # defining `__bool__` or `__len__`).
        if req_lp is None:
            return None
        # Request-level logits processors take either
        # (prompt_ids, output_ids, logits) or (output_ids, logits);
        # detect which via the callable's arity.
        if len(inspect.signature(req_lp).parameters) == 3:
            args = [prompt_ids, output_ids]
        else:
            args = [output_ids]
        return partial(req_lp, *args)

    def update_state(self, batch_update: Optional[BatchUpdate]):
        # Delegate add/remove/move bookkeeping of per-request state to the
        # shared dict-update helper.
        process_dict_updates(
            self.req_info,
            batch_update,
            self._new_state,
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.req_info:
            # Apply per-request logits processors to corresponding rows of
            # logits tensor
            for req_idx, req_lp in self.req_info.items():
                req_logits = logits[req_idx]
                new_logits = req_lp(req_logits)
                if new_logits is not req_logits:
                    # Modify logits tensor row in-place if necessary
                    logits[req_idx] = new_logits
        return logits
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor",
|
"LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor",
|
||||||
"MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder",
|
"MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder",
|
||||||
"MoveDirectionality", "LogitsProcessors", "build_logitsprocs",
|
"MoveDirectionality", "LogitsProcessors", "build_logitsprocs",
|
||||||
"STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP"
|
"STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP",
|
||||||
|
"AdapterLogitsProcessor"
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user