mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 01:16:59 +08:00
[LogitsProcs] Deduplicate built-in LP implementation logic (#23362)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
83f555f637
commit
3ce8285d6d
@ -42,8 +42,8 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.v1.sample.logits_processor import (
|
from vllm.v1.sample.logits_processor import (
|
||||||
BatchUpdate,
|
BatchUpdate,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
MoveDirectionality,
|
|
||||||
)
|
)
|
||||||
|
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
||||||
|
|
||||||
|
|
||||||
# Hypothetical custom logits processor
|
# Hypothetical custom logits processor
|
||||||
@ -53,38 +53,22 @@ class DummyLogitsProcessor(LogitsProcessor):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||||
):
|
):
|
||||||
self.req_info: dict[int, SamplingParams] = {}
|
self.req_info: dict[int, int] = {}
|
||||||
|
|
||||||
def is_argmax_invariant(self) -> bool:
|
def is_argmax_invariant(self) -> bool:
|
||||||
"""Never impacts greedy sampling"""
|
"""Never impacts greedy sampling"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||||
if not batch_update:
|
process_dict_updates(
|
||||||
return
|
self.req_info,
|
||||||
|
batch_update,
|
||||||
# Process added requests.
|
# This function returns the LP's per-request state based on the
|
||||||
for index, params, _, _ in batch_update.added:
|
# request details, or None if this LP does not apply to the
|
||||||
assert params is not None
|
# request.
|
||||||
if params.extra_args and (
|
lambda params, _, __: params.extra_args
|
||||||
target_token := params.extra_args.get("target_token")
|
and (params.extra_args.get("target_token")),
|
||||||
):
|
)
|
||||||
self.req_info[index] = target_token
|
|
||||||
|
|
||||||
if self.req_info:
|
|
||||||
# Process removed requests.
|
|
||||||
for index in batch_update.removed:
|
|
||||||
self.req_info.pop(index, None)
|
|
||||||
|
|
||||||
# Process moved requests, unidirectional move (a->b) and swap
|
|
||||||
# (a<->b)
|
|
||||||
for adx, bdx, direct in batch_update.moved:
|
|
||||||
a_val = self.req_info.pop(adx, None)
|
|
||||||
b_val = self.req_info.pop(bdx, None)
|
|
||||||
if a_val is not None:
|
|
||||||
self.req_info[bdx] = a_val
|
|
||||||
if direct == MoveDirectionality.SWAP and b_val is not None:
|
|
||||||
self.req_info[adx] = b_val
|
|
||||||
|
|
||||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
if not self.req_info:
|
if not self.req_info:
|
||||||
|
|||||||
@ -8,10 +8,9 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
|
from vllm.v1.sample.logits_processor import (LOGITSPROCS_GROUP, BatchUpdate,
|
||||||
LogitsProcessor,
|
LogitsProcessor)
|
||||||
MoveDirectionality)
|
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
||||||
|
|
||||||
MODEL_NAME = "facebook/opt-125m"
|
MODEL_NAME = "facebook/opt-125m"
|
||||||
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
|
POOLING_MODEL_NAME = "BAAI/bge-base-en-v1.5"
|
||||||
@ -45,37 +44,19 @@ class DummyLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
||||||
is_pin_memory: bool):
|
is_pin_memory: bool):
|
||||||
self.req_info: dict[int, SamplingParams] = {}
|
self.req_info: dict[int, int] = {}
|
||||||
|
|
||||||
def is_argmax_invariant(self) -> bool:
|
def is_argmax_invariant(self) -> bool:
|
||||||
"""Never impacts greedy sampling"""
|
"""Never impacts greedy sampling"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||||
if not batch_update:
|
process_dict_updates(
|
||||||
return
|
self.req_info,
|
||||||
|
batch_update,
|
||||||
# Process added requests.
|
lambda params, _, __: params.extra_args and
|
||||||
for index, params, _, _ in batch_update.added:
|
(params.extra_args.get("target_token")),
|
||||||
assert params is not None
|
)
|
||||||
if params.extra_args and (target_token :=
|
|
||||||
params.extra_args.get("target_token")):
|
|
||||||
self.req_info[index] = target_token
|
|
||||||
|
|
||||||
if self.req_info:
|
|
||||||
# Process removed requests.
|
|
||||||
for index in batch_update.removed:
|
|
||||||
self.req_info.pop(index, None)
|
|
||||||
|
|
||||||
# Process moved requests, unidirectional move (a->b) and swap
|
|
||||||
# (a<->b)
|
|
||||||
for adx, bdx, direct in batch_update.moved:
|
|
||||||
a_val = self.req_info.pop(adx, None)
|
|
||||||
b_val = self.req_info.pop(bdx, None)
|
|
||||||
if a_val is not None:
|
|
||||||
self.req_info[bdx] = a_val
|
|
||||||
if direct == MoveDirectionality.SWAP and b_val is not None:
|
|
||||||
self.req_info[adx] = b_val
|
|
||||||
|
|
||||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
if not self.req_info:
|
if not self.req_info:
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Callable, Optional, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm import SamplingParams
|
||||||
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
|
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
MoveDirectionality)
|
MoveDirectionality)
|
||||||
@ -12,6 +13,8 @@ from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class MinPLogitsProcessor(LogitsProcessor):
|
class MinPLogitsProcessor(LogitsProcessor):
|
||||||
|
|
||||||
@ -130,49 +133,15 @@ class LogitBiasLogitsProcessor(LogitsProcessor):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||||
if not batch_update:
|
needs_update = process_dict_updates(
|
||||||
return
|
self.biases, batch_update,
|
||||||
|
lambda params, _, __: params.logit_bias or None)
|
||||||
needs_update: bool = False
|
|
||||||
# Process added requests.
|
|
||||||
for index, params, _, _ in batch_update.added:
|
|
||||||
if lb := params.logit_bias:
|
|
||||||
self.biases[index] = lb
|
|
||||||
needs_update = True
|
|
||||||
else:
|
|
||||||
# Drop biases metadata at batch index
|
|
||||||
if self.biases.pop(index, None) is not None:
|
|
||||||
# If a new request replaces an old request which
|
|
||||||
# specified biases, we should update processor tensors
|
|
||||||
needs_update = True
|
|
||||||
|
|
||||||
if self.biases:
|
|
||||||
# Process removed requests.
|
|
||||||
for index in batch_update.removed:
|
|
||||||
if self.biases.pop(index, None):
|
|
||||||
needs_update = True
|
|
||||||
|
|
||||||
# Process moved requests, unidirectional (a->b) and swap (a<->b)
|
|
||||||
for a_index, b_index, direct in batch_update.moved:
|
|
||||||
if direct == MoveDirectionality.UNIDIRECTIONAL:
|
|
||||||
if (a_entry := self.biases.pop(a_index, None)) is None:
|
|
||||||
if self.biases.pop(b_index, None) is not None:
|
|
||||||
needs_update = True
|
|
||||||
else:
|
|
||||||
self.biases[b_index] = a_entry
|
|
||||||
needs_update = True
|
|
||||||
else:
|
|
||||||
a_entry = self.biases.pop(a_index, None)
|
|
||||||
if (b_entry := self.biases.pop(b_index, None)) is not None:
|
|
||||||
self.biases[a_index] = b_entry
|
|
||||||
needs_update = True
|
|
||||||
if a_entry is not None:
|
|
||||||
self.biases[b_index] = a_entry
|
|
||||||
needs_update = True
|
|
||||||
|
|
||||||
# Update tensors if needed.
|
# Update tensors if needed.
|
||||||
if needs_update:
|
if needs_update:
|
||||||
reqs, tok_ids, biases = [], [], []
|
reqs: list[int] = []
|
||||||
|
tok_ids: list[int] = []
|
||||||
|
biases: list[float] = []
|
||||||
for req, lb in self.biases.items():
|
for req, lb in self.biases.items():
|
||||||
reqs.extend([req] * len(lb))
|
reqs.extend([req] * len(lb))
|
||||||
tok_ids.extend(lb.keys())
|
tok_ids.extend(lb.keys())
|
||||||
@ -216,52 +185,18 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
|||||||
of the argmax operation in greedy sampling."""
|
of the argmax operation in greedy sampling."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
def add_request(
    params: SamplingParams, _: list[int], output_tok_ids: list[int]
) -> Optional[tuple[int, Sequence[int], set[int]]]:
    """Build the per-request min-tokens state for a newly added request.

    Args:
        params: Sampling parameters of the request; ``min_tokens`` and
            ``all_stop_token_ids`` are read from it.
        _: Unused (second positional argument of the callback signature).
        output_tok_ids: The request's output token ids so far.

    Returns:
        ``(min_tokens, output_tok_ids, all_stop_token_ids)`` while the
        request has produced fewer than ``min_tokens`` output tokens, or
        None when ``min_tokens`` is unset/zero or has already been reached.
    """
    required = params.min_tokens
    # Only track requests which still need to generate more tokens.
    if required and len(output_tok_ids) < required:
        return required, output_tok_ids, params.all_stop_token_ids
    return None
|
||||||
|
|
||||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||||
needs_update = False
|
needs_update = process_dict_updates(self.min_toks, batch_update,
|
||||||
|
self.add_request)
|
||||||
if batch_update:
|
|
||||||
# Process added requests.
|
|
||||||
for index, params, _, output_tok_ids in batch_update.added:
|
|
||||||
if ((min_tokens := params.min_tokens)
|
|
||||||
and len(output_tok_ids) < min_tokens):
|
|
||||||
# Replace request metadata at batch index
|
|
||||||
self.min_toks[index] = (min_tokens, output_tok_ids,
|
|
||||||
params.all_stop_token_ids)
|
|
||||||
needs_update = True
|
|
||||||
else:
|
|
||||||
# Drop min_toks metadata at batch index
|
|
||||||
if self.min_toks.pop(index, None) is not None:
|
|
||||||
# If a new request replaces an old request which
|
|
||||||
# specified min_toks, we should update processor tensors
|
|
||||||
needs_update = True
|
|
||||||
|
|
||||||
if self.min_toks:
|
|
||||||
# Process removed requests.
|
|
||||||
for index in batch_update.removed:
|
|
||||||
if self.min_toks.pop(index, None):
|
|
||||||
needs_update = True
|
|
||||||
|
|
||||||
# Process moved requests, unidirectional (a->b) and
|
|
||||||
# swapped (a<->b)
|
|
||||||
for a_index, b_index, direct in batch_update.moved:
|
|
||||||
if direct == MoveDirectionality.UNIDIRECTIONAL:
|
|
||||||
if (a_entry := self.min_toks.pop(a_index,
|
|
||||||
None)) is None:
|
|
||||||
if self.min_toks.pop(b_index, None) is not None:
|
|
||||||
needs_update = True
|
|
||||||
else:
|
|
||||||
self.min_toks[b_index] = a_entry
|
|
||||||
needs_update = True
|
|
||||||
else:
|
|
||||||
a_entry = self.min_toks.pop(a_index, None)
|
|
||||||
if (b_entry := self.min_toks.pop(b_index,
|
|
||||||
None)) is not None:
|
|
||||||
self.min_toks[a_index] = b_entry
|
|
||||||
needs_update = True
|
|
||||||
if a_entry is not None:
|
|
||||||
self.min_toks[b_index] = a_entry
|
|
||||||
needs_update = True
|
|
||||||
|
|
||||||
if self.min_toks:
|
if self.min_toks:
|
||||||
# Check for any requests that have attained their min tokens.
|
# Check for any requests that have attained their min tokens.
|
||||||
to_remove = tuple(index for index, (min_toks, out_tok_ids,
|
to_remove = tuple(index for index, (min_toks, out_tok_ids,
|
||||||
@ -295,3 +230,44 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
|||||||
# Inhibit EOS token for requests which have not reached min length
|
# Inhibit EOS token for requests which have not reached min length
|
||||||
logits[self.logits_slice] = -float("inf")
|
logits[self.logits_slice] = -float("inf")
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def process_dict_updates(
    req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
    new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]]
) -> bool:
    """Utility function to update dict state for sparse LogitsProcessors.

    Applies a single batch update to ``req_entries`` in place.

    Args:
        req_entries: Per-request state keyed by batch index; mutated in
            place.
        batch_update: The update to apply. A falsy value means there is
            nothing to do.
        new_state: Called for every added request as
            ``new_state(params, prompt_tok_ids, output_tok_ids)``. It
            returns the state to store for that request, or None if no
            state should be kept for it.

    Returns:
        True if any entry of ``req_entries`` was set, dropped or moved,
        False otherwise.
    """

    if not batch_update:
        # Nothing to do.
        return False

    updated = False
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids,
                               output_tok_ids)) is not None:
            req_entries[index] = state
            updated = True
        elif req_entries.pop(index, None) is not None:
            # The added request replaced one which had state at this index.
            updated = True

    if req_entries:
        # Process removed requests.
        for index in batch_update.removed:
            # Compare against None rather than truthiness: a falsy state
            # (e.g. 0 or an empty container) is still a stored entry whose
            # removal must be reported, consistent with the other branches.
            if req_entries.pop(index, None) is not None:
                updated = True

        # Process moved requests, unidirectional (a->b) and
        # swapped (a<->b)
        for a_index, b_index, direct in batch_update.moved:
            a_entry = req_entries.pop(a_index, None)
            b_entry = req_entries.pop(b_index, None)
            if a_entry is not None:
                req_entries[b_index] = a_entry
                updated = True
            if b_entry is not None:
                # b's previous state is either overwritten/dropped (a->b)
                # or relocated to a (a<->b); either way the dict changed.
                updated = True
                if direct == MoveDirectionality.SWAP:
                    req_entries[a_index] = b_entry

    return updated
|
||||||
|
|||||||
@ -44,10 +44,16 @@ class BatchUpdate:
|
|||||||
# Key assumption: the `output_tok_ids` list (which is an element of each
|
# Key assumption: the `output_tok_ids` list (which is an element of each
|
||||||
# tuple in `added`) is a reference to the request's running output tokens
|
# tuple in `added`) is a reference to the request's running output tokens
|
||||||
# list; via this reference, the logits processors always see the latest
|
# list; via this reference, the logits processors always see the latest
|
||||||
# list of generated output tokens
|
# list of generated output tokens.
|
||||||
|
#
|
||||||
|
# NOTE:
|
||||||
|
# * Added or moved requests may replace existing requests with the same
|
||||||
|
# index.
|
||||||
|
# * Operations should be processed in the following order:
|
||||||
|
# - removed, added, moved
|
||||||
removed: Sequence[RemovedRequest]
|
removed: Sequence[RemovedRequest]
|
||||||
moved: Sequence[MovedRequest]
|
|
||||||
added: Sequence[AddedRequest]
|
added: Sequence[AddedRequest]
|
||||||
|
moved: Sequence[MovedRequest]
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessor(ABC):
|
class LogitsProcessor(ABC):
|
||||||
@ -59,6 +65,11 @@ class LogitsProcessor(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Apply LogitsProcessor to batch logits tensor.
|
||||||
|
|
||||||
|
The updated tensor must be returned but may be
|
||||||
|
modified in-place.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user