mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-18 02:37:07 +08:00
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to Python source files, as
recommended to the project by the Linux Foundation. These headers provide
a concise way, both human and machine readable, of communicating license
information for each source file. They help avoid any ambiguity about the
license of the code and can also be easily used by tools to help manage
license compliance.
The Linux Foundation runs license scans against the codebase to help
ensure we are in compliance with the licenses of the code we use,
including dependencies. Having these headers in place helps that tool do
its job.
More information can be found on the SPDX site:
- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
---------
Signed-off-by: Russell Bryant <rbryant@redhat.com>
257 lines
9.9 KiB
Python
257 lines
9.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from abc import abstractmethod
|
|
from typing import Dict, Optional, Union
|
|
|
|
import torch
|
|
import torch.jit
|
|
import torch.nn as nn
|
|
|
|
|
|
class SpecDecodeBaseSampler(nn.Module):
    """Base class for samplers used for Speculative Decoding verification
    step.

    Besides the output-formatting helper shared by concrete samplers, the
    class keeps three running counters (accepted, emitted and draft tokens)
    that are updated on every call to ``_create_output``.
    """

    def __init__(self, strict_mode: bool = False):
        """Base class constructor.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store
        # this value in a variable for readability.
        self._num_bonus_tokens = 1

        # The tensor counters stay None until init_gpu_tensors() or
        # init_tensors() places them on a device.
        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def _allocate_counters(self, device) -> None:
        """Create the zero-valued scalar counters on ``device``."""
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    def init_gpu_tensors(self, device: Union[int, str]) -> None:
        """Allocate the counters, mapping an integer ``device`` to CUDA.

        Args:
            device: Either a device string used as-is, or an integer that
                is expanded to ``"cuda:<device>"``.

        Raises:
            ValueError: If ``device`` is neither an int nor a str.
        """
        # Counters may only be initialised once.
        assert self.num_accepted_tokens is None
        if isinstance(device, int):
            device = f"cuda:{device}"
        elif not isinstance(device, str):
            raise ValueError(f"Device must be int or str, get {type(device)}")
        self._allocate_counters(device)

    def init_tensors(self,
                     device: Union[int, str],
                     device_type: Union[torch.device, str] = 'cuda') -> None:
        """Allocate the counters on an arbitrary device type.

        Args:
            device: A device specifier used as-is, or an integer index that
                is combined with ``device_type`` as ``"<type>:<index>"``.
            device_type: Device type used only when ``device`` is an int.
                A ``torch.device`` is reduced to its ``.type`` string.
        """
        # Counters may only be initialised once.
        assert self.num_accepted_tokens is None
        if isinstance(device_type, torch.device):
            device_type = device_type.type
        if isinstance(device, int):
            device = f"{device_type}:{device}"
        self._allocate_counters(device)

    @property
    def probs_dtype(self):
        """dtype that probability tensors are checked against."""
        return torch.float32

    @property
    def token_id_dtype(self):
        """dtype used for token-id tensors, including the output matrix."""
        return torch.int64

    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            substitute_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via sampling, the rejected position receives
        the substitute token and all subsequent token ids are set to -1
        for the sequence.

        Args:
            accepted: A boolean tensor indicating if the corresponding
                draft token in draft_token_ids should be accepted or not.
            substitute_token_ids: A tensor of token_ids that can be used
                as substitutes for the draft token ids if the proposed token
                is rejected.
            draft_token_ids: A tensor of token ids speculated by the
                draft model.
            bonus_token_ids: Token ids to use as the bonus token if
                all the draft tokens are accepted.

        Returns:
            A tensor containing the accepted token ids. The shape of the
            tensor is [batch_size, k + num_bonus_tokens]
        """
        batch_size, k = substitute_token_ids.shape
        bonus_token_ids = bonus_token_ids.squeeze(-1)

        # Determine the index of the first False value for each row. Rows
        # without any rejection get the sentinel limit k.
        rejected = accepted == 0
        limits = rejected.max(1).indices
        limits[~rejected.any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor, pre-filled with -1.
        output_with_bonus_tokens = torch.full(
            (batch_size, k + self._num_bonus_tokens),
            -1,
            dtype=self.token_id_dtype,
            device=accepted.device)
        # View on the first k columns; writes go through to the full tensor.
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and
        # data tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values
        # inconsistent with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # Fill the recovered token ids: zero the first rejected slot, then
        # add the substitute token at that slot only.
        output.mul_(~after_false_mask).add_(
            substitute_token_ids.mul(after_false_mask))

        # NOTE: the accepted counter sums the raw `accepted` mask, so it
        # also counts True entries located after the first rejection.
        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_input(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Run the shape, dtype, device and vocab-range checks in turn.

        Raises:
            AssertionError: If any of the individual checks fails.
        """
        self._raise_if_incorrect_shape(target_with_bonus_probs,
                                       draft_token_ids, bonus_token_ids,
                                       draft_probs)
        self._raise_if_incorrect_dtype(target_with_bonus_probs,
                                       draft_token_ids, bonus_token_ids,
                                       draft_probs)
        self._raise_if_inconsistent_device(target_with_bonus_probs,
                                           draft_token_ids, bonus_token_ids,
                                           draft_probs)
        self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1],
                                           draft_token_ids, bonus_token_ids)

    def _raise_if_incorrect_shape(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that the inputs have mutually consistent shapes.

        ``target_with_bonus_probs`` is unpacked as three dimensions
        (batch, number of probs, vocab); its second dimension must be one
        larger than the number of draft tokens per sequence.
        """
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_with_bonus_probs.shape

        # Does not count the extra token
        num_target_probs -= 1

        # validate the shape of draft token ids.
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
        assert draft_token_ids_batch_size == target_batch_size
        assert num_draft_token_ids == num_target_probs

        # validate the shape of bonus token ids
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

        # validate the shape of draft probs if it is set
        if draft_probs is not None:
            (draft_batch_size, num_draft_probs,
             draft_vocab_size) = draft_probs.shape
            assert draft_batch_size == target_batch_size
            assert num_draft_probs == num_target_probs
            assert (draft_vocab_size == target_vocab_size
                    ), f"{draft_vocab_size=} {target_vocab_size=}"

    def _raise_if_incorrect_dtype(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert probabilities use ``probs_dtype`` and ids use
        ``token_id_dtype``; ``draft_probs`` is checked only when given.
        """
        assert target_with_bonus_probs.dtype == self.probs_dtype
        assert draft_token_ids.dtype == self.token_id_dtype
        assert bonus_token_ids.dtype == self.token_id_dtype
        if draft_probs is not None:
            assert draft_probs.dtype == self.probs_dtype

    def _raise_if_inconsistent_device(
        self,
        target_with_bonus_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that every provided tensor lives on the same device."""
        devices = [
            t.device for t in [
                target_with_bonus_probs, bonus_token_ids, draft_probs,
                draft_token_ids
            ] if t is not None
        ]
        assert all(devices[0] == device for device in devices)

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
    ) -> None:
        """Assert that all token ids fall in ``[0, vocab_size)``."""
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)
|
|
|
|
|
|
class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
    """Parent class for the deterministic samplers used in the Speculative
    Decoding verification step.
    """

    # NOTE(review): the visible base class derives from nn.Module, which
    # presumably does not use ABCMeta, so @abstractmethod is likely not
    # enforced at instantiation -- confirm. The raise below is the guard
    # that is certain to fire.
    @abstractmethod
    def forward(
            self,
            target_with_bonus_probs: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: torch.Tensor,
            draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the sampler; concrete subclasses must override this.

        The base implementation performs no work and always raises
        NotImplementedError.
        """
        raise NotImplementedError
|
|
|
|
|
|
class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
    """Parent class for the stochastic samplers used in the Speculative
    Decoding verification step.
    """

    # NOTE(review): the visible base class derives from nn.Module, which
    # presumably does not use ABCMeta, so @abstractmethod is likely not
    # enforced at instantiation -- confirm. The raise below is the guard
    # that is certain to fire.
    @abstractmethod
    def forward(
            self,
            target_with_bonus_probs: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: torch.Tensor,
            draft_token_ids: torch.Tensor,
            seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
    ) -> torch.Tensor:
        """Run the sampler; concrete subclasses must override this.

        Compared with the deterministic variant, the signature additionally
        accepts ``seeded_seqs``, an optional mapping from an int key to a
        ``torch.Generator``. The base implementation performs no work and
        always raises NotImplementedError.
        """
        raise NotImplementedError
|