[Bugfix] Make spec. decode respect per-request seed. (#6034)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>

commit d4201e06d5 (parent b5672a112c)
tests/samplers/test_rejection_sampler.py

@@ -150,9 +150,54 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     high=vocab_size,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
+    generators = [None] * batch_size
 
     rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                      draft_token_ids)
+                      draft_token_ids, generators)
+
+
+@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
+@pytest.mark.parametrize("n_rep", [100])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
+                                   frac_seeded: float, n_rep: int,
+                                   device: str):
+    torch.set_default_device(device)
+    rejection_sampler = RejectionSampler()
+    rejection_sampler.init_gpu_tensors(rank=0)
+
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
+
+    results = []
+    for _ in range(n_rep):
+        generators = [
+            torch.Generator(
+                device=device).manual_seed(i) if seeded_mask[i] else None
+            for i in range(batch_size)
+        ]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, generators))
+
+    for i in range(batch_size):
+        if seeded_mask[i]:
+            for j in range(1, n_rep):
+                assert torch.equal(results[j][i], results[0][i])
+
+
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
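Note: the new test_deterministic_when_seeded test hinges on a basic PyTorch property: draws made through a torch.Generator seeded with manual_seed are reproducible, while unseeded draws come from the global RNG. A minimal standalone sketch (illustrative only, not part of the diff):

    import torch

    def draw(seed=None):
        # Seeded requests get a private generator; unseeded ones use the global RNG.
        gen = torch.Generator().manual_seed(seed) if seed is not None else None
        return torch.rand(4, generator=gen)

    # Same seed: bitwise-identical draws across repetitions ...
    assert torch.equal(draw(seed=42), draw(seed=42))
    # ... while unseeded draws differ (almost surely).
    assert not torch.equal(draw(), draw())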
@@ -197,10 +242,11 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
         raise AssertionError()
 
     oob_token_ids[0][0] = rogue_token_id
+    generators = [None] * batch_size
 
     with pytest.raises(AssertionError):
         rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                          draft_token_ids)
+                          draft_token_ids, generators)
 
 
 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -371,11 +417,15 @@ class _CorrectnessTestHelper:
                                        dtype=torch.int64,
                                        device="cuda").repeat(num_samples, 1)
 
+        # unseeded
+        generators = [None]
+
         # Get output tokens via rejection sampling.
         output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                   bonus_token_ids.to("cuda"),
                                                   draft_probs.to("cuda"),
-                                                  draft_token_ids.to("cuda"))
+                                                  draft_token_ids.to("cuda"),
+                                                  generators)
 
         # Remove bonus tokens
         output_token_ids = output_token_ids[:, :-1].flatten()
tests/spec_decode/e2e/conftest.py

@@ -1,6 +1,6 @@
 import asyncio
 from itertools import cycle
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import pytest
 import ray
@@ -128,7 +128,9 @@ class AsyncLLM:
         try:
             for i in range(num_requests):
                 prompt = prompts[i] if prompts is not None else None
-                res = asyncio.run(get_output(prompt, sampling_params))
+                params = sampling_params[i] if isinstance(
+                    sampling_params, Sequence) else sampling_params
+                res = asyncio.run(get_output(prompt, params))
                 outputs.append(res)
         finally:
             ray.shutdown()
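The AsyncLLM change lets callers hand in either a single SamplingParams shared by every request or a per-request sequence. The dispatch pattern in isolation (a generic sketch, not vLLM code; note that str is also a Sequence, which is harmless here because params are objects or lists):

    from collections.abc import Sequence

    def params_for(i, params):
        """Return params[i] when a per-request sequence was given, else the shared value."""
        return params[i] if isinstance(params, Sequence) else params

    # A shared (non-sequence) value is returned unchanged for every index ...
    assert params_for(3, {"temperature": 1.0}) == {"temperature": 1.0}
    # ... while a list yields the i-th per-request entry.
    assert params_for(1, [{"seed": 0}, {"seed": 1}]) == {"seed": 1}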
@@ -267,7 +269,31 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
     """
-    temperature = 0.0
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len,
+                                  temperature=0.0,
+                                  seeded=False,
+                                  print_tokens=print_tokens,
+                                  ensure_all_accepted=ensure_all_accepted)
+
+
+def run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len: bool,
+                                  temperature: float,
+                                  seeded: bool,
+                                  print_tokens: bool = False,
+                                  ensure_all_accepted: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
+    the same when temperature is zero (or when temperature is > 0 and seeded).
+    """
 
     prompts = [
         "Hello, my name is",
@@ -286,6 +312,16 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     # sampling params to ignore eos token.
     ignore_eos = force_output_len
 
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
+    if seeded:
+        sampling_params = [
+            SamplingParams(
+                max_tokens=max_output_len,
+                ignore_eos=ignore_eos,
+                temperature=temperature,
+                seed=i,
+            ) for i in range(len(prompts))
+        ]
+    else:
+        sampling_params = SamplingParams(
+            max_tokens=max_output_len,
+            ignore_eos=ignore_eos,
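With seeded=True the helper now builds one SamplingParams per prompt, each carrying its own seed, so a batch can be re-run reproducibly even at temperature > 0. Roughly (a sketch; the prompt strings are placeholders):

    from vllm import SamplingParams

    prompts = ["Hello, my name is", "The capital of France is"]

    # One seeded params object per prompt; here the seed is simply the prompt index.
    sampling_params = [
        SamplingParams(max_tokens=32, temperature=1.0, seed=i)
        for i in range(len(prompts))
    ]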
tests/spec_decode/e2e/test_seed.py (new file, 44 lines)

@@ -0,0 +1,44 @@
+import pytest
+
+from .conftest import run_equality_correctness_test
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # speculative model
+        "speculative_model": "JackFram/llama-160m",
+
+        # num speculative tokens
+        "num_speculative_tokens": 3,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [1, 8, 32])
+@pytest.mark.parametrize("temperature", [0.1, 1.0])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        10,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_seeded_consistency(baseline_llm_generator, batch_size: int,
+                            temperature: float, output_len: int):
+    """Verify outputs are consistent across multiple runs with same seed
+    """
+    run_equality_correctness_test(baseline_llm_generator,
+                                  baseline_llm_generator,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  temperature=temperature,
+                                  seeded=True,
+                                  force_output_len=True)
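End to end, the behavior this PR fixes is that a seed supplied through SamplingParams is now respected by the rejection sampler when speculative decoding is enabled. A hedged usage sketch (model names and engine kwargs copied from the test above; exact engine arguments may differ across vLLM versions):

    from vllm import LLM, SamplingParams

    llm = LLM(model="JackFram/llama-68m",
              speculative_model="JackFram/llama-160m",
              num_speculative_tokens=3,
              use_v2_block_manager=True,
              enforce_eager=True)

    params = SamplingParams(temperature=1.0, seed=123, max_tokens=16)

    # With the fix, two generations with the same seed should produce identical
    # text even though temperature > 0 and tokens are verified stochastically.
    out1 = llm.generate(["Hello, my name is"], params)[0].outputs[0].text
    out2 = llm.generate(["Hello, my name is"], params)[0].outputs[0].text
    assert out1 == out2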
vllm/model_executor/layers/rejection_sampler.py

@@ -1,14 +1,14 @@
 from functools import cached_property
-from typing import Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.jit
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeStochasticBaseSampler)
 
 
-class RejectionSampler(SpecDecodeBaseSampler):
+class RejectionSampler(SpecDecodeStochasticBaseSampler):
     """Apply modified rejection sampling as described in "Accelerating Large
     Language Model Decoding with Speculative Sampling"
     https://arxiv.org/pdf/2302.01318.pdf.
@@ -36,6 +36,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
+        generators: List[Optional[torch.Generator]],
     ) -> torch.Tensor:
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
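For context, the verification implemented below follows the modified rejection sampling scheme from the paper cited in the class docstring: a drafted token x is accepted with probability min(1, p_target(x) / p_draft(x)); on rejection, a replacement is drawn from the recovered distribution proportional to max(0, p_target - p_draft). A single-sequence, single-position sketch of the math (illustrative, not vLLM code):

    import torch

    def verify_one(p_target: torch.Tensor, p_draft: torch.Tensor,
                   draft_token: int, generator: torch.Generator) -> int:
        """Accept or replace one drafted token.

        Assumes p_draft[draft_token] > 0 (the draft actually sampled it) and
        that the two distributions are not identical.
        """
        u = torch.rand((), generator=generator)
        if u < torch.clamp(p_target[draft_token] / p_draft[draft_token], max=1.0):
            return draft_token
        # Rejected: resample from (p_target - p_draft) clamped at zero, renormalized.
        recovered = torch.clamp(p_target - p_draft, min=0.0)
        recovered = recovered / recovered.sum()
        return int(torch.multinomial(recovered, 1, generator=generator))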
@@ -82,6 +83,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
                 target_probs,
                 draft_probs,
                 draft_token_ids,
+                generators,
             ))
 
         output_token_ids = self._create_output(
@@ -98,6 +100,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
+        generators: List[Optional[torch.Generator]],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
@@ -114,15 +117,25 @@ class RejectionSampler(SpecDecodeBaseSampler):
 
         # shape [batch_size, k]
         accepted = self._get_accepted(target_probs, draft_probs,
-                                      draft_token_ids)
+                                      draft_token_ids, generators)
 
         recovered_probs = self._get_recovered_probs(
             target_probs, draft_probs).reshape(batch_size * k, vocab_size)
 
+        seed_indices, non_seed_indices = self._split_batch_by_seeded(
+            generators, k=k)
+
         # NOTE: the recovered_probs are overwritten by this method.
-        recovered_token_ids = _multinomial(recovered_probs,
-                                           num_samples=1).reshape(
-                                               batch_size, k)
+        recovered_token_ids = _multinomial(
+            recovered_probs,
+            num_samples=1,
+            k=k,
+            generators=generators,
+            seed_indices=seed_indices,
+            # this arg is unused when None but torch.jit requires a list
+            non_seed_indices=non_seed_indices or [],
+        ).reshape(batch_size, k)
 
         return accepted, recovered_token_ids
 
     def _get_accepted(
@@ -130,6 +143,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
+        generators: List[Optional[torch.Generator]],
     ) -> torch.Tensor:
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
@@ -164,10 +178,28 @@ class RejectionSampler(SpecDecodeBaseSampler):
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]
 
-        uniform_rand = torch.rand(batch_size,
-                                  k,
-                                  dtype=self.probs_dtype,
-                                  device=target_probs.device)
+        seed_indices, non_seed_indices = self._split_batch_by_seeded(
+            generators)
+
+        if len(seed_indices) == 0:
+            uniform_rand = torch.rand_like(selected_target_probs)
+        else:
+            uniform_rand = torch.empty_like(selected_target_probs)
+
+            for idx in seed_indices:
+                uniform_rand[idx, :] = torch.rand(1,
+                                                  k,
+                                                  dtype=self.probs_dtype,
+                                                  device=target_probs.device,
+                                                  generator=generators[idx])
+
+            if non_seed_indices:
+                uniform_rand[non_seed_indices, :] = torch.rand(
+                    len(non_seed_indices),
+                    k,
+                    dtype=self.probs_dtype,
+                    device=target_probs.device)
 
         capped_ratio = torch.minimum(
             selected_target_probs / selected_draft_probs,
             torch.full((1, ), 1, device=target_probs.device))
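The essential property of the seeded branch above is that a seeded row's uniforms depend only on that request's generator, not on the rest of the batch, so acceptance decisions stay reproducible under batching. A small standalone illustration of the same fill pattern (not vLLM code):

    import torch

    def row_uniforms(generators, k):
        # Seeded rows are filled from their own generator, unseeded rows from
        # the global RNG; this mirrors the fill pattern in _get_accepted above.
        out = torch.empty(len(generators), k)
        for i, g in enumerate(generators):
            if g is not None:
                out[i] = torch.rand(k, generator=g)
        unseeded = [i for i, g in enumerate(generators) if g is None]
        if unseeded:
            out[unseeded] = torch.rand(len(unseeded), k)
        return out

    # A seeded row is identical no matter what else is in the batch.
    a = row_uniforms([torch.Generator().manual_seed(0), None], k=4)[0]
    b = row_uniforms([torch.Generator().manual_seed(0), None, None], k=4)[0]
    assert torch.equal(a, b)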
@@ -240,6 +272,27 @@ class RejectionSampler(SpecDecodeBaseSampler):
         """
         return torch.finfo(self.probs_dtype).tiny
 
+    # partition batch into indices for which a generator is provided
+    # and indices for which no generator is provided
+    @staticmethod
+    def _split_batch_by_seeded(
+        generators: List[Optional[torch.Generator]],
+        k: int = 1,
+    ) -> Tuple[List[int], Optional[List[int]]]:
+
+        if all(generator is None for generator in generators):
+            seed_indices: List[int] = []
+            non_seed_indices: Optional[List[int]] = None
+        else:
+            seed_indices, non_seed_indices = [], []
+            for i, generator in enumerate(generators):
+                if generator is None:
+                    non_seed_indices.extend(range(k * i, k * (i + 1)))
+                else:
+                    seed_indices.extend(range(k * i, k * (i + 1)))
+
+        return seed_indices, non_seed_indices
+
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
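A quick illustration of what the new _split_batch_by_seeded helper returns (assuming it behaves exactly as shown in the hunk above): each batch index is expanded into its k flattened row indices, and an all-unseeded batch short-circuits to ([], None) so callers can keep the fast path.

    import torch

    from vllm.model_executor.layers.rejection_sampler import RejectionSampler

    g = torch.Generator()
    seed_idx, non_seed_idx = RejectionSampler._split_batch_by_seeded(
        [None, g, None], k=2)
    assert seed_idx == [2, 3] and non_seed_idx == [0, 1, 4, 5]

    # All-unseeded batches return ([], None).
    assert RejectionSampler._split_batch_by_seeded([None, None]) == ([], None)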
@@ -250,12 +303,25 @@ class RejectionSampler(SpecDecodeBaseSampler):
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
+    k: int,
+    generators: List[Optional[torch.Generator]],
+    seed_indices: List[int],
+    non_seed_indices: List[int],
 ) -> torch.Tensor:
 
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1.0)
+
+    q = torch.empty_like(probs)
+    if len(seed_indices) == 0:
+        q.exponential_(1.0)
+    else:
+        q[non_seed_indices].exponential_(1.0)
+        for idx in seed_indices:
+            q[idx].exponential_(1.0, generator=generators[idx // k])
+
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
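The probs.div_(q).argmax(dim=1) trick in _multinomial works because, for independent q_i ~ Exp(1), argmax_i p_i / q_i equals i with probability proportional to p_i (the minimum of exponentials with rates p_i); seeding therefore only needs to control which generator fills q. A standalone sanity check (illustrative, CPU):

    import torch

    torch.manual_seed(0)
    p = torch.tensor([0.1, 0.2, 0.7])
    counts = torch.zeros(3)
    for _ in range(20_000):
        q = torch.empty_like(p).exponential_(1.0)
        counts[(p / q).argmax()] += 1
    print(counts / counts.sum())  # approximately tensor([0.1, 0.2, 0.7])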
vllm/model_executor/layers/spec_decode_base_sampler.py

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.jit
@@ -54,16 +54,6 @@ class SpecDecodeBaseSampler(nn.Module):
     def token_id_dtype(self):
         return torch.int64
 
-    @abstractmethod
-    def forward(
-        self,
-        target_probs: torch.Tensor,
-        bonus_token_ids: torch.Tensor,
-        draft_probs: torch.Tensor,
-        draft_token_ids: torch.Tensor,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
     def _create_output(
         self,
         accepted: torch.Tensor,  # [batch_size, k]
@@ -217,3 +207,36 @@ class SpecDecodeBaseSampler(nn.Module):
         assert torch.all(bonus_token_ids >= 0)
         assert torch.all(draft_token_ids < vocab_size)
         assert torch.all(draft_token_ids >= 0)
+
+
+class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
+    """Base class for samplers used for Speculative Decoding verification
+       step which are deterministic.
+    """
+
+    @abstractmethod
+    def forward(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
+    """Base class for samplers used for Speculative Decoding verification
+       step which are stochastic.
+    """
+
+    @abstractmethod
+    def forward(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+        generators: List[Optional[torch.Generator]],
+    ) -> torch.Tensor:
+        raise NotImplementedError
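The split means any sampler that consumes randomness now takes a generators argument in forward, while deterministic samplers keep the four-tensor signature. A minimal, purely hypothetical subclass showing the stochastic contract (not part of vLLM; it only illustrates the interface declared above):

    from typing import List, Optional

    import torch

    from vllm.model_executor.layers.spec_decode_base_sampler import (
        SpecDecodeStochasticBaseSampler)


    class AlwaysAcceptSampler(SpecDecodeStochasticBaseSampler):
        """Toy sampler that accepts every draft token (hypothetical)."""

        def forward(
            self,
            target_probs: torch.Tensor,
            bonus_token_ids: torch.Tensor,
            draft_probs: torch.Tensor,
            draft_token_ids: torch.Tensor,
            generators: List[Optional[torch.Generator]],
        ) -> torch.Tensor:
            # Accept every drafted token and append the bonus token; a real
            # sampler would draw from `generators` for its random decisions.
            return torch.cat([draft_token_ids, bonus_token_ids], dim=1)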
vllm/model_executor/layers/typical_acceptance_sampler.py

@@ -2,10 +2,10 @@ import torch
 import torch.jit
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeDeterministicBaseSampler)
 
 
-class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
+class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
     """Apply typical acceptance sampling as described in section 3.3.1 in
     "MEDUSA: Simple LLM Inference Acceleration Framework with
     Multiple Decoding Heads"
vllm/spec_decode/batch_expansion.py

@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
 import torch
 
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata, get_all_seq_ids)
+                           SequenceGroupMetadata, SequenceGroupState,
+                           get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         for data in new_seq_data_dict.values():
             data.update_num_computed_tokens(data.get_len() - 1)
 
+        if (seq_group_metadata.state is not None
+                and seq_group_metadata.state.generator is not None):
+            generator = torch.Generator(
+                device=seq_group_metadata.state.generator.device)
+            generator.set_state(seq_group_metadata.state.generator.get_state())
+            state = SequenceGroupState(generator=generator)
+        else:
+            state = None
+
         return SequenceGroupMetadata(
             request_id=seq_group_metadata.request_id,
             is_prompt=seq_group_metadata.is_prompt,
@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             },
             lora_request=None,
             token_chunk_size=1,
+            state=state,
         )
 
     def _split_scoring_output(
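A plausible reading of the state copy above: the expanded (scoring) sequence groups get their own copy of the request's generator state, so consuming randomness for them does not perturb the original request's generator. The underlying get_state / set_state mechanics in isolation (a sketch):

    import torch

    # Cloning a generator: consuming randomness through the clone leaves the
    # original's stream untouched.
    orig = torch.Generator().manual_seed(1234)
    clone = torch.Generator(device=orig.device)
    clone.set_state(orig.get_state())

    a = torch.rand(3, generator=clone)  # advance the clone
    b = torch.rand(3, generator=orig)   # the original still yields the same draws
    assert torch.equal(a, b)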
vllm/spec_decode/spec_decode_worker.py

@@ -9,7 +9,7 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
@@ -521,11 +521,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # Get proposed tokens.
         proposal_token_ids = proposals.proposal_token_ids[spec_indices]
 
+        # Sampler arguments
+        sampler_extra_kwargs = {}
+        if isinstance(self.spec_decode_sampler,
+                      SpecDecodeStochasticBaseSampler):
+
+            # Get sequence group state
+            generators = []
+            for seq_group_metadata in seq_group_metadata_list:
+                if (seq_group_metadata.state is not None
+                        and seq_group_metadata.state.generator is not None):
+                    generators.append(seq_group_metadata.state.generator)
+                else:
+                    generators.append(None)
+
+            sampler_extra_kwargs["generators"] = generators
+
         accepted_token_ids = self.spec_decode_sampler(
             target_probs=proposal_verifier_probs,
             bonus_token_ids=bonus_token_ids,
             draft_probs=proposal_probs,
             draft_token_ids=proposal_token_ids,
+            **sampler_extra_kwargs,
         )
 
         # Append output tokens from non-speculative sequences to
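Only stochastic verification samplers accept per-request generators, so the worker forwards them conditionally via **kwargs; a deterministic sampler such as TypicalAcceptanceSampler would reject the extra argument. A self-contained toy mirroring that dispatch (all names here are invented for illustration):

    from typing import List, Optional

    import torch


    def deterministic_verify(draft_token_ids: torch.Tensor) -> torch.Tensor:
        # Stand-in for a deterministic sampler: no generators parameter at all.
        return draft_token_ids


    def stochastic_verify(
            draft_token_ids: torch.Tensor,
            generators: List[Optional[torch.Generator]]) -> torch.Tensor:
        # Stand-in for a stochastic sampler: one optional generator per request.
        keep = torch.stack([
            torch.rand(draft_token_ids.shape[1], generator=g) < 0.5
            for g in generators
        ])
        return torch.where(keep, draft_token_ids,
                           torch.full_like(draft_token_ids, -1))


    draft = torch.arange(6).reshape(2, 3)
    gens = [torch.Generator().manual_seed(7), None]  # request 0 seeded, request 1 not

    # Mirror of the worker's isinstance() dispatch: only the stochastic path
    # receives the generators keyword.
    for fn in (deterministic_verify, stochastic_verify):
        kwargs = {"generators": gens} if fn is stochastic_verify else {}
        print(fn(draft, **kwargs))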