Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-13 15:15:44 +08:00 (253 lines, 9.8 KiB, Python).
from abc import abstractmethod
|
|
from typing import Dict, Optional, Union
|
|
|
|
import torch
|
|
import torch.jit
|
|
import torch.nn as nn
|
|
|
|
|
|
class SpecDecodeBaseSampler(nn.Module):
    """Base class for samplers used for Speculative Decoding verification
    step.

    Provides the shared output formatting (``_create_output``), the running
    acceptance/emission counters that it updates, and the
    shape/dtype/device/vocab-range validation helpers used by concrete
    samplers.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Base class constructor.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens will cause corrupt KV cache for
                proposal methods that require KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self._disable_bonus_tokens = disable_bonus_tokens
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store
        # this value in a variable for readability.
        self._num_bonus_tokens = 1

        # Running counters updated by `_create_output`. The tensor counters
        # stay None until `init_gpu_tensors` allocates them on a device.
        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, device: Union[int, str]) -> None:
        """Allocate the scalar accepted/emitted token counters on ``device``.

        May be called at most once (asserted), and must be called before
        ``_create_output``, which increments these counters in place.

        Args:
            device: A CUDA device index (mapped to ``"cuda:<index>"``) or a
                device string passed unchanged to ``torch.tensor``.

        Raises:
            ValueError: If ``device`` is neither an ``int`` nor a ``str``.
        """
        assert self.num_accepted_tokens is None
        if isinstance(device, int):
            device = f"cuda:{device}"
        elif not isinstance(device, str):
            raise ValueError(f"Device must be int or str, got {type(device)}")
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    @property
    def probs_dtype(self):
        """dtype required for probability tensors (checked by validation)."""
        return torch.float32

    @property
    def token_id_dtype(self):
        """dtype of token-id tensors, including the formatted output."""
        return torch.int64

    def _create_output(
        self,
        accepted: torch.Tensor,  # [batch_size, k]
        substitute_token_ids: torch.Tensor,  # [batch_size, k]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size] or [batch_size, 1]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids.

        For each row, draft tokens are kept up to (not including) the first
        rejected position. The first rejected position is replaced by the
        corresponding substitute token, and every later position is set to
        -1. The final column holds the bonus token only when all ``k`` draft
        tokens were accepted and bonus tokens are enabled; otherwise it is -1.

        As a side effect, updates ``num_accepted_tokens`` (count of every
        True in ``accepted``, which may include positions after the first
        rejection), ``num_emitted_tokens`` (count of non -1 output entries)
        and ``num_draft_tokens`` (``batch_size * k``).

        Args:
            accepted: A boolean tensor indicating if the corresponding
                draft token in draft_token_ids should be accepted or not.
            substitute_token_ids: A tensor of token_ids that can be used
                as substitutes for the draft token ids if the proposed token
                is rejected.
            draft_token_ids: A tensor of token ids speculated by the
                draft model.
            bonus_token_ids: Token ids to use as the bonus token if
                all the draft tokens are accepted.

        Returns:
            A tensor containing the accepted token ids. The shape of the
            tensor is [batch_size, k + num_bonus_tokens]
        """
        batch_size, k = substitute_token_ids.shape
        # Only drop the trailing bonus-token axis; a bare `squeeze()` would
        # also drop the batch axis when batch_size == 1.
        bonus_token_ids = bonus_token_ids.squeeze(-1)
        # Determine the index of the first False value for each row; rows
        # with no False value get index k (everything accepted).
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor, pre-filled with -1.
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        # View onto the first k columns; writes go into the extended tensor.
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and
        # data tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values
        # inconsistent with causal acceptance. This must happen before the
        # substitute tokens are written below.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # We disable bonus tokens because it causes corrupt KV cache for
        # proposal methods that require KV cache. We can fix it by
        # "prefilling" the bonus token in the proposer. The following issue
        # tracks the fix.
        # https://github.com/vllm-project/vllm/issues/4212
        if self._disable_bonus_tokens:
            output_with_bonus_tokens[:, -1] = -1

        # Fill the recovered token ids: zero out the first rejected position,
        # then add the substitute token there (other positions add 0).
        output.mul_(~after_false_mask).add_(
            substitute_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_input(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Run all shape, dtype, device and vocab-range checks.

        Raises:
            AssertionError: If any check fails.
        """
        self._raise_if_incorrect_shape(target_with_bonus_probs,
                                       draft_token_ids, bonus_token_ids,
                                       draft_probs)
        self._raise_if_incorrect_dtype(target_with_bonus_probs,
                                       draft_token_ids, bonus_token_ids,
                                       draft_probs)
        self._raise_if_inconsistent_device(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids,
                                           draft_probs)
        self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
                                           draft_token_ids, bonus_token_ids)

    def _raise_if_incorrect_shape(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that input shapes agree.

        Expected shapes: target probs ``[B, k + 1, V]``, draft token ids
        ``[B, k]``, bonus token ids ``[B, 1]`` and, if given, draft probs
        ``[B, k, V]``.
        """
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_with_bonus_probs.shape

        # Does not count the extra token
        num_target_probs -= 1

        # validate the shape of draft token ids.
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
        assert draft_token_ids_batch_size == target_batch_size
        assert num_draft_token_ids == num_target_probs

        # validate the shape of bonus token ids
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

        # validate the shape of draft probs if it is set
        if draft_probs is not None:
            (draft_batch_size, num_draft_probs,
             draft_vocab_size) = draft_probs.shape
            assert draft_batch_size == target_batch_size
            assert num_draft_probs == num_target_probs
            assert (draft_vocab_size == target_vocab_size
                    ), f"{draft_vocab_size=} {target_vocab_size=}"

    def _raise_if_incorrect_dtype(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert probabilities use ``probs_dtype`` and ids ``token_id_dtype``.
        """
        assert target_with_bonus_probs.dtype == self.probs_dtype
        assert draft_token_ids.dtype == self.token_id_dtype
        assert bonus_token_ids.dtype == self.token_id_dtype
        if draft_probs is not None:
            assert draft_probs.dtype == self.probs_dtype

    def _raise_if_inconsistent_device(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that all given (non-None) tensors share one device."""
        devices = [
            t.device for t in [
                target_with_bonus_probs, bonus_token_ids, draft_probs,
                draft_token_ids
            ] if t is not None
        ]
        assert all(devices[0] == device for device in devices)

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
    ) -> None:
        """Assert every draft/bonus token id lies in ``[0, vocab_size)``."""
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)
|
|
|
|
|
class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
    """Base class for Speculative Decoding verification samplers that are
    deterministic.

    Concrete subclasses override ``forward``; this class only defines its
    signature, and the inherited default raises ``NotImplementedError``.
    """

    # NOTE(review): nn.Module does not appear to use ABCMeta, so
    # @abstractmethod here likely documents intent rather than blocking
    # instantiation — confirm before relying on it.
    @abstractmethod
    def forward(
            self,
            target_with_bonus_probs: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: torch.Tensor,
            draft_token_ids: torch.Tensor) -> torch.Tensor:
        """Verification-step entry point; must be overridden."""
        raise NotImplementedError()
|
|
|
|
|
|
class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
    """Base class for Speculative Decoding verification samplers that are
    stochastic.

    Concrete subclasses override ``forward``; this class only defines its
    signature, and the inherited default raises ``NotImplementedError``.
    """

    # NOTE(review): nn.Module does not appear to use ABCMeta, so
    # @abstractmethod here likely documents intent rather than blocking
    # instantiation — confirm before relying on it.
    @abstractmethod
    def forward(
            self,
            target_with_bonus_probs: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: torch.Tensor,
            draft_token_ids: torch.Tensor,
            seeded_seqs: Optional[Dict[int, torch.Generator]] = None
    ) -> torch.Tensor:
        """Verification-step entry point; must be overridden.

        ``seeded_seqs`` is an optional mapping from int keys to
        ``torch.Generator`` objects (presumably per-sequence RNGs for seeded
        sampling — confirm against concrete subclasses).
        """
        raise NotImplementedError()
|