[Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)

Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
Cody Yu 2024-05-08 14:44:00 -07:00 committed by GitHub
parent 89579a201f
commit f942efb5a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 227 additions and 39 deletions

View File

@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int,
def test_correct_output_format(which_tokens_accepted: str,
disable_bonus_tokens: bool, seed: int,
device: str):
"""Verify the output has correct format given predetermined accepted matrix.
"""
@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size=(batch_size, 1),
dtype=torch.int64)
rejection_sampler = RejectionSampler()
rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens)
rejection_sampler.init_gpu_tensors(rank=0)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids,
)
# Bonus tokens are currently disabled. Verify they're set to -1.
expected_bonus_token_ids = bonus_token_ids.clone()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
if disable_bonus_tokens:
expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1
if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.

View File

@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when all sequences disable speculation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{

View File

@ -57,7 +57,7 @@ from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,

View File

@ -0,0 +1,77 @@
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from .utils import create_batch, mock_worker
@pytest.mark.parametrize('queue_size', [2, 4])
@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
rejection_sampler=rejection_sampler,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
running_queue_size=queue_size)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
assert seq_group_metadata_list[
0].num_speculative_tokens == expected_num_spec_tokens
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
proposer = Top1Proposer(
worker=draft_worker,
device='cpu', # not used
vocab_size=100, # not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len=1024,
)
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
assert proposals.proposal_lens.tolist() == [0] * batch_size

View File

@ -692,6 +692,7 @@ class SpeculativeConfig:
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
use_v2_block_manager: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
) -> Optional["SpeculativeConfig"]:
@ -720,6 +721,9 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueue requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
@ -730,7 +734,7 @@ class SpeculativeConfig:
the necessary conditions are met, else None.
"""
if (speculative_model is None and num_speculative_tokens is None):
if speculative_model is None and num_speculative_tokens is None:
return None
if speculative_model is not None and num_speculative_tokens is None:
@ -739,6 +743,12 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")
if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")
assert (speculative_model is not None
and num_speculative_tokens is not None)
@ -807,6 +817,7 @@ class SpeculativeConfig:
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
)
@ -876,8 +887,9 @@ class SpeculativeConfig:
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
ngram_prompt_lookup_max: int,
ngram_prompt_lookup_min: int,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
):
"""Create a SpeculativeConfig object.
@ -886,12 +898,19 @@ class SpeculativeConfig:
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
self.speculative_disable_by_batch_size = \
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
self._verify_args()

View File

@ -83,6 +83,7 @@ class EngineArgs:
speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
@ -467,6 +468,13 @@ class EngineArgs:
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
@ -547,6 +555,8 @@ class EngineArgs:
target_dtype=self.dtype,
speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager,

View File

@ -93,6 +93,8 @@ class GPUExecutor(ExecutorBase):
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=self.speculative_config.
speculative_disable_by_batch_size,
)
assert self.parallel_config.world_size == 1, (

View File

@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
https://arxiv.org/pdf/2302.01318.pdf.
"""
def __init__(self, strict_mode: bool = False):
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
"""Create a rejection sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
@ -312,7 +318,8 @@ class RejectionSampler(nn.Module):
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(

View File

@ -612,6 +612,12 @@ class SequenceGroupMetadata:
self._token_chunk_size = token_chunk_size
self.do_sample = do_sample
# The number of speculative tokens adopted in this request.
# None means specuative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
# TODO: We should maintain this states out of the sequence group.
self.num_speculative_tokens = None
if self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = list(seq_data.values())[0].get_len()

View File

@ -1,5 +1,5 @@
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def create_worker(
cls,
scorer_worker: WorkerBase,
draft_worker_kwargs,
draft_worker_kwargs: Dict[str, Any],
disable_by_batch_size: Optional[int],
) -> "SpecDecodeWorker":
ngram_prompt_lookup_max = (
@ -62,7 +63,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
disable_bonus_tokens = True
if ngram_prompt_lookup_max > 0:
disable_bonus_tokens = False
proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max)
@ -75,9 +78,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
return SpecDecodeWorker(
proposer_worker,
scorer_worker,
# TODO(cade) disable strict mode for speedup.
rejection_sampler=RejectionSampler(strict_mode=True),
)
disable_by_batch_size=disable_by_batch_size,
rejection_sampler=RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens, ))
def __init__(
self,
@ -85,6 +88,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
):
"""
Create a SpecDecodeWorker.
@ -97,11 +101,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Worker.
rejection_sampler: A Torch module used to perform modified rejection
sampling for speculative decoding.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set
for testing purposes.
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.rejection_sampler = rejection_sampler
self._metrics = AsyncMetricsCollector(
@ -199,27 +206,41 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding "
"requires non-None seq_group_metadata_list")
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
disable_all = (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
if disable_all:
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
# This happens for prefill, or when the spec decode is disabled
# for this batch.
if execute_model_req.num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req)
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all)
return self._run_speculative_decoding_step(execute_model_req)
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the
two.
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to
the proposer and scorer model so that the KV cache is consistent
between the two. When skip_proposer is True, the proposer model is
not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
"""
#logger.info("run proposer worker no spec")
if not skip_proposer:
self.proposer_worker.execute_model(execute_model_req)
self.proposer_worker.execute_model(execute_model_req)
#logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model(execute_model_req)
assert len(sampler_output) == 1
sampler_output = sampler_output[0]
@ -244,22 +265,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence.
"""
#logger.info("get spec proposals")
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
#logger.info("score proposals")
proposal_scores = self.scorer.score_proposals(
execute_model_req,
proposals,
)
#logger.info("verify proposals")
accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
#logger.info("create output list")
return self._create_output_sampler_list(
execute_model_req.seq_group_metadata_list,
accepted_token_ids,

View File

@ -56,7 +56,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_lens,
nonzero_proposal_len_seqs,
nonzero_proposal_len_indices,
) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)
) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative
@ -97,17 +97,27 @@ class Top1Proposer(SpeculativeProposer):
return proposals
def _split_by_max_model_len(
def _split_by_proposal_len(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length."""
"""Split sequences by two groups:
1. Sequences with non-zero proposal length.
2. Sequences with zero proposal length (due to disabled speculation
or exceed the maximum model length).
"""
proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
# The speculative decoding for this request has been disabled
# (e.g. due to high traffic).
if seq_group_metadata.num_speculative_tokens == 0:
proposal_lens.append(0)
continue
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()
@ -115,13 +125,14 @@ class Top1Proposer(SpeculativeProposer):
# are supported.
# If max_proposal_len is defined, then we shall no exccess this
# quota for nonzero_proposal
new_k = 0
if (self.max_proposal_len is None
or seq_len + proposal_len < self.max_proposal_len):
proposal_lens.append(proposal_len)
new_k = proposal_len
nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i)
else:
proposal_lens.append(0)
proposal_lens.append(new_k)
seq_group_metadata.num_speculative_tokens = new_k
return (
proposal_lens,