# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import Optional, Union

import torch
from torch._prims_common import DeviceLikeType

from vllm import PoolingParams, SamplingParams
from vllm.logger import init_logger

logger = init_logger(__name__)


class MoveDirectionality(Enum):
    # One-way i1->i2 req move within batch
    UNIDIRECTIONAL = 0
    # Two-way i1<->i2 req swap within batch
    SWAP = 1


# (index, params, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int


@dataclasses.dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs"""
    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Note: each added request is represented as
    # (index, params, output_tok_ids)
    # Key assumption: output_tok_ids is a reference to the
    # request's running output tokens list; in this way
    # the logits processors always see the latest list of
    # generated tokens
    removed: Sequence[RemovedRequest]
    moved: Sequence[MovedRequest]
    added: Sequence[AddedRequest]


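# Illustrative sketch (not part of the module API): how a BatchUpdate for a
# 3-slot batch might look. The scenario and SamplingParams values are
# hypothetical: the request in slot 0 finished, the request in slot 2 was
# moved down to backfill slot 0, and a new request landed in slot 2.
#
#   update = BatchUpdate(
#       batch_size=3,
#       removed=[0],  # request at index 0 left the batch
#       moved=[(2, 0, MoveDirectionality.UNIDIRECTIONAL)],  # backfill slot 0
#       added=[(2, SamplingParams(min_p=0.05), [])],  # new request in slot 2
#   )

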
class BatchUpdateBuilder:
    """Helps track persistent batch state changes and build
    a batch update data structure for logitsprocs

    Assumptions:
    * All information about requests removed from persistent batch
      during a step is aggregated in self._removed through calls to
      self.removed_append() at the beginning of a step. This must happen
      before the first time that self.removed, self.pop_removed()
      or self.peek_removed() are invoked in a given step
    * After the first time that self.removed, self.pop_removed()
      or self.peek_removed() are read in a step, no new removals
      are registered using self.removed_append()
    * Elements of self._removed are never directly modified, added or
      removed (i.e. modification is only via self.removed_append() and
      self.pop_removed())

    Guarantees under above assumptions:
    * self.removed is always sorted in descending order
    * self.pop_removed() and self.peek_removed() both return
      the lowest removed request index in the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    moved: list[MovedRequest]
    added: list[AddedRequest]

    def __init__(
        self,
        removed: Optional[list[RemovedRequest]] = None,
        moved: Optional[list[MovedRequest]] = None,
        added: Optional[list[AddedRequest]] = None,
    ) -> None:
        self._removed = removed or []
        self.moved = moved or []
        self.added = added or []
        self._is_removed_sorted = False

    def _ensure_removed_sorted(self) -> None:
        """Sort removed request indices in
        descending order.

        Idempotent after first call in a
        given step, until reset.
        """
        if not self._is_removed_sorted:
            self._removed.sort(reverse=True)
            self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Removed request indices sorted in
        descending order"""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Register the removal of a request from
        the persistent batch.

        Must not be called after the first time
        self.removed, self.pop_removed() or
        self.peek_removed() are invoked.

        Args:
          index: request index
        """
        if self._is_removed_sorted:
            raise RuntimeError("Cannot register new removed request after"
                               " self.removed has been read.")
        self._removed.append(index)

    def has_removed(self) -> bool:
        return bool(self._removed)

    def peek_removed(self) -> Optional[int]:
        """Return lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed[-1]
        return None

    def pop_removed(self) -> Optional[int]:
        """Pop lowest removed request index"""
        if self.has_removed():
            self._ensure_removed_sorted()
            return self._removed.pop()
        return None

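    # Illustrative ordering behavior (sketch only): removals registered in any
    # order surface lowest-index first once read, per the guarantees in the
    # class docstring above.
    #
    #   builder = BatchUpdateBuilder()
    #   builder.removed_append(5)
    #   builder.removed_append(1)
    #   builder.removed_append(3)
    #   builder.peek_removed()  # -> 1 (lowest removed index)
    #   builder.pop_removed()   # -> 1; subsequent pops yield 3, then 5
    #   builder.removed         # -> [5, 3] (descending order)
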
    def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]:
        """Generate a logitsprocs batch update data structure
        and reset internal batch update builder state.

        Args:
          batch_size: current persistent batch size

        Returns:
          Frozen logitsprocs batch update instance; `None` if no updates
        """
        # Reset removal-sorting logic
        self._is_removed_sorted = False
        if not any((self._removed, self.moved, self.added)):
            # No update; short-circuit
            return None
        # Build batch state update
        batch_update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        # Reset removed/moved/added update lists
        self._removed = []
        self.moved = []
        self.added = []
        return batch_update


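# Illustrative per-step flow (sketch only): the engine-side owner of the
# persistent batch records changes on the builder, then hands the frozen
# BatchUpdate to every logits processor before the forward pass. The names
# finished_index, new_index, params, out_ids, current_batch_size and
# logitsprocs below are hypothetical.
#
#   builder = BatchUpdateBuilder()
#   builder.removed_append(finished_index)
#   builder.added.append((new_index, params, out_ids))
#   batch_update = builder.get_and_reset(batch_size=current_batch_size)
#   for proc in logitsprocs.all:
#       proc.update_state(batch_update)  # batch_update is None if no changes

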
class LogitsProcessor(ABC):

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """True if logits processor has no impact on the
        argmax computation in greedy sampling.
        NOTE: may or may not have the same value for all
        instances of a given LogitsProcessor subclass,
        depending on subclass implementation.
        TODO(andy): won't be utilized until logits
        processors are user-extensible
        """
        raise NotImplementedError

    @abstractmethod
    def update_state(
        self,
        batch_update: Optional[BatchUpdate],
    ) -> None:
        """Called when there are new output tokens, prior
        to each forward pass.

        Args:
            batch_update: non-None iff there have been
                changes to the batch makeup.
        """
        raise NotImplementedError


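# Minimal custom LogitsProcessor sketch (illustrative only, not a built-in):
# bans a fixed token id for every request in the batch. It is not
# argmax-invariant because suppressing a token can change the greedy choice.
#
#   class BanTokenLogitsProcessor(LogitsProcessor):
#
#       def __init__(self, token_id: int):
#           self.token_id = token_id
#
#       def is_argmax_invariant(self) -> bool:
#           return False
#
#       def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
#           pass  # stateless; nothing to track per request
#
#       def apply(self, logits: torch.Tensor) -> torch.Tensor:
#           logits[:, self.token_id] = -float("inf")
#           return logits

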
@dataclass
class LogitsProcessorManager:
    """Encapsulates initialized logitsproc objects."""
    argmax_invariant: list[LogitsProcessor] = field(
        default_factory=list)  # argmax-invariant logitsprocs
    non_argmax_invariant: list[LogitsProcessor] = field(
        default_factory=list)  # non-argmax-invariant logitsprocs

    @property
    def all(self) -> Iterator[LogitsProcessor]:
        """Iterator over all logits processors."""
        return chain(self.argmax_invariant, self.non_argmax_invariant)


###### ----- Built-in LogitsProcessor impls below here


class MinPLogitsProcessor(LogitsProcessor):

    def __init__(self, max_num_reqs: int, pin_memory: bool,
                 device: DeviceLikeType):
        super().__init__()
        self.min_p_count: int = 0

        self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                            dtype=torch.float32,
                                            device="cpu",
                                            pin_memory=pin_memory)
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        self.use_double_tensor = torch.device("cpu") != torch.device(device)

        if self.use_double_tensor:
            # Pre-allocated device tensor
            self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ),
                                                          dtype=torch.float32,
                                                          device=device)
        else:
            self.min_p_device = self.min_p_cpu_tensor
        # Current slice of the device tensor
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def is_argmax_invariant(self) -> bool:
        """Min-p never impacts greedy sampling"""
        return True

    def get_min_p_by_index(self, index: int) -> float:
        return float(self.min_p_cpu[index])

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        needs_update = False
        # Process added requests.
        for index, params, _ in batch_update.added:
            min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
            if self.min_p_cpu[index] != min_p:
                needs_update = True
            self.min_p_cpu[index] = min_p
            if min_p:
                self.min_p_count += 1

        if self.min_p_count:
            # Process removed requests.
            needs_update |= bool(batch_update.removed)
            for index in batch_update.removed:
                if self.min_p_cpu[index]:
                    self.min_p_count -= 1

            # Process moved requests, unidirectional (a->b) and swap (a<->b)
            for adx, bdx, direct in batch_update.moved:
                change = (min_p_a :=
                          self.min_p_cpu[adx]) != (min_p_b :=
                                                   self.min_p_cpu[bdx])
                needs_update |= change
                if change:
                    self.min_p_cpu[bdx] = min_p_a
                    if direct == MoveDirectionality.SWAP:
                        self.min_p_cpu[adx] = min_p_b

        # Update tensors if needed.
        size = batch_update.batch_size
        if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
            self.min_p = self.min_p_device[:size]
            if self.use_double_tensor:
                self.min_p.copy_(self.min_p_cpu_tensor[:size],
                                 non_blocking=True)
            self.min_p.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.min_p_count:
            return logits

        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Scale the min_p threshold by each sequence's max probability
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        # Identify tokens that fall below the scaled threshold
        invalid_token_mask = probability_values < adjusted_min_p
        # Mask out sub-threshold tokens using boolean indexing
        logits[invalid_token_mask] = -float('inf')
        return logits


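# Worked example of the min-p rule above (illustrative numbers): with token
# probabilities [0.5, 0.3, 0.15, 0.05] and min_p = 0.2, the threshold is
# 0.5 * 0.2 = 0.1, so only the 0.05 token is masked to -inf. A sketch driving
# this processor directly (hypothetical host code; the engine does this
# internally):
#
#   proc = MinPLogitsProcessor(max_num_reqs=8, pin_memory=False, device="cpu")
#   proc.update_state(BatchUpdate(
#       batch_size=1,
#       removed=[],
#       moved=[],
#       added=[(0, SamplingParams(min_p=0.2), [])],
#   ))
#   filtered = proc.apply(logits)  # logits: [1, vocab_size] float tensor

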
class LogitBiasLogitsProcessor(LogitsProcessor):

    def __init__(self, pin_memory: bool, device: torch.device):
        super().__init__()
        self.biases: dict[int, dict[int, float]] = {}
        self.device = device
        self.pin_memory = pin_memory

        self.bias_tensor: torch.Tensor = torch.tensor(())
        self.logits_slice = (self._device_tensor([], torch.int32),
                             self._device_tensor([], torch.int32))

    def is_argmax_invariant(self) -> bool:
        """Logit bias can rebalance token probabilities and change the
        outcome of argmax in greedy sampling."""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        if not batch_update:
            return

        # Process added requests.
        needs_update = bool(batch_update.added)
        for index, params, _ in batch_update.added:
            if isinstance(params, SamplingParams) and (lb :=
                                                       params.logit_bias):
                self.biases[index] = lb
            else:
                self.biases.pop(index, None)

        if self.biases:
            # Process removed requests.
            for index in batch_update.removed:
                if self.biases.pop(index, None):
                    needs_update = True

            # Process moved requests, unidirectional (a->b) and swap (a<->b)
            for a_index, b_index, direct in batch_update.moved:
                if direct == MoveDirectionality.UNIDIRECTIONAL:
                    if (a_entry := self.biases.pop(a_index, None)) is None:
                        if self.biases.pop(b_index, None) is not None:
                            needs_update = True
                    else:
                        self.biases[b_index] = a_entry
                        needs_update = True
                else:
                    a_entry = self.biases.pop(a_index, None)
                    if (b_entry := self.biases.pop(b_index, None)) is not None:
                        self.biases[a_index] = b_entry
                        needs_update = True
                    if a_entry is not None:
                        self.biases[b_index] = a_entry
                        needs_update = True

        # Update tensors if needed.
        if needs_update:
            reqs, tok_ids, biases = [], [], []
            for req, lb in self.biases.items():
                reqs.extend([req] * len(lb))
                tok_ids.extend(lb.keys())
                biases.extend(lb.values())

            self.bias_tensor = self._device_tensor(biases, torch.float32)
            self.logits_slice = (self._device_tensor(reqs, torch.int32),
                                 self._device_tensor(tok_ids, torch.int32))

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        return (torch.tensor(data,
                             device="cpu",
                             dtype=dtype,
                             pin_memory=self.pin_memory).to(device=self.device,
                                                            non_blocking=True))

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.biases:
            logits[self.logits_slice] += self.bias_tensor
        return logits


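# Illustrative request-level usage (sketch only; token ids are made up): a
# per-request logit_bias is carried on SamplingParams and lands in
# self.biases when the request is added; apply() then adds the bias values at
# the (request row, token id) pairs held in logits_slice.
#
#   params = SamplingParams(logit_bias={128001: -100.0, 323: 2.5})
#   # After update_state() sees this request added at batch index 0:
#   #   biases == {0: {128001: -100.0, 323: 2.5}}
#   #   logits_slice == (tensor([0, 0]), tensor([128001, 323]))
#   #   apply(logits) adds -100.0 to logits[0, 128001] and 2.5 to logits[0, 323]

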
class MinTokensLogitsProcessor(LogitsProcessor):

    def __init__(self, pin_memory: bool, device: torch.device):
        # index -> (min_toks, output_token_ids, stop_token_ids)
        super().__init__()
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}
        self.device = device
        self.pin_memory = pin_memory

        # (req_idx_tensor, eos_tok_id_tensor)
        self.logits_slice: tuple[torch.Tensor,
                                 torch.Tensor] = (self._device_tensor(
                                     [], torch.int32),
                                                  self._device_tensor(
                                                      [], torch.int32))

    def is_argmax_invariant(self) -> bool:
        """By censoring stop tokens, min-tokens can change the outcome
        of the argmax operation in greedy sampling."""
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]):
        needs_update = False

        if batch_update:
            # Process added requests.
            needs_update |= bool(batch_update.added)
            for index, params, output_tok_ids in batch_update.added:
                if (isinstance(params, SamplingParams)
                        and (min_tokens := params.min_tokens)
                        and len(output_tok_ids) < min_tokens):
                    # Replace request metadata at batch index
                    self.min_toks[index] = (min_tokens, output_tok_ids,
                                            params.all_stop_token_ids)
                else:
                    # Drop request metadata at batch index
                    self.min_toks.pop(index, None)

            if self.min_toks:
                # Process removed requests.
                for index in batch_update.removed:
                    if self.min_toks.pop(index, None):
                        needs_update = True

                # Process moved requests, unidirectional (a->b) and
                # swapped (a<->b)
                for a_index, b_index, direct in batch_update.moved:
                    if direct == MoveDirectionality.UNIDIRECTIONAL:
                        if (a_entry := self.min_toks.pop(a_index,
                                                         None)) is None:
                            if self.min_toks.pop(b_index, None) is not None:
                                needs_update = True
                        else:
                            self.min_toks[b_index] = a_entry
                            needs_update = True
                    else:
                        a_entry = self.min_toks.pop(a_index, None)
                        if (b_entry := self.min_toks.pop(b_index,
                                                         None)) is not None:
                            self.min_toks[a_index] = b_entry
                            needs_update = True
                        if a_entry is not None:
                            self.min_toks[b_index] = a_entry
                            needs_update = True

        if self.min_toks:
            # Check for any requests that have attained their min tokens.
            to_remove = tuple(index for index, (min_toks, out_tok_ids,
                                                _) in self.min_toks.items()
                              if len(out_tok_ids) >= min_toks)
            if to_remove:
                needs_update = True
                for index in to_remove:
                    del self.min_toks[index]

            # Update tensors if needed.
            if needs_update:
                reqs: list[int] = []
                tok_ids: list[int] = []
                for req, (_, _, stop_tok_ids) in self.min_toks.items():
                    reqs.extend([req] * len(stop_tok_ids))
                    tok_ids.extend(stop_tok_ids)

                self.logits_slice = (self._device_tensor(reqs, torch.int32),
                                     self._device_tensor(
                                         tok_ids, torch.int32))

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        return (torch.tensor(data,
                             device="cpu",
                             dtype=dtype,
                             pin_memory=self.pin_memory).to(device=self.device,
                                                            non_blocking=True))

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if self.min_toks:
            # Inhibit EOS token for requests which have not reached min length
            logits[self.logits_slice] = -float("inf")
        return logits


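# Illustrative effect (sketch only; the stop token id is made up): for a
# request added with the params below that has generated fewer than 16
# tokens, every id in params.all_stop_token_ids is forced to -inf in that
# request's logits row, so sampling cannot terminate early. Once the shared
# output_tok_ids list reaches 16 entries, update_state() drops the entry from
# min_toks and the stop tokens become eligible again.
#
#   params = SamplingParams(min_tokens=16, stop_token_ids=[128009])
#   # While len(output_tok_ids) < 16 for the request at batch index i,
#   # apply() sets logits[i, t] = -inf for every t in params.all_stop_token_ids

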
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
                             device: torch.device) -> LogitsProcessorManager:
    """Construct 'builtin' vLLM logitsprocs which the engine
    loads by default.

    Args:
      pin_memory_available: pinned memory is available for use
                            by logitsprocs
      max_num_reqs: ceiling on request count in persistent batch
      device: inference device

    Returns:
      Data structure encapsulating loaded logitsprocs
    """
    min_tokens_logitproc = MinTokensLogitsProcessor(
        pin_memory=pin_memory_available, device=device)
    logit_bias_logitproc = LogitBiasLogitsProcessor(
        pin_memory=pin_memory_available, device=device)
    min_p_logitproc = MinPLogitsProcessor(
        pin_memory=pin_memory_available,
        device=device,
        # +1 for temporary swap space
        max_num_reqs=max_num_reqs + 1)
    return LogitsProcessorManager(
        non_argmax_invariant=[
            min_tokens_logitproc,
            logit_bias_logitproc,
        ],
        argmax_invariant=[min_p_logitproc],
    )
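

# Illustrative wiring (sketch only; the engine normally does this internally):
# build the default logitsprocs, feed them a batch update, and run them over a
# logits tensor. The names builder, num_reqs, all_greedy and logits below are
# hypothetical host-side state.
#
#   manager = init_builtin_logitsprocs(pin_memory_available=False,
#                                      max_num_reqs=8,
#                                      device=torch.device("cpu"))
#   batch_update = builder.get_and_reset(batch_size=num_reqs)
#   for proc in manager.all:
#       proc.update_state(batch_update)
#   procs = manager.non_argmax_invariant if all_greedy else manager.all
#   for proc in procs:
#       logits = proc.apply(logits)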