mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 07:29:08 +08:00
[Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)
Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
parent
89579a201f
commit
f942efb5a3
@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"which_tokens_accepted",
|
"which_tokens_accepted",
|
||||||
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
|
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
|
||||||
|
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
|
||||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
def test_correct_output_format(which_tokens_accepted: str,
|
||||||
|
disable_bonus_tokens: bool, seed: int,
|
||||||
device: str):
|
device: str):
|
||||||
"""Verify the output has correct format given predetermined accepted matrix.
|
"""Verify the output has correct format given predetermined accepted matrix.
|
||||||
"""
|
"""
|
||||||
@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
|||||||
size=(batch_size, 1),
|
size=(batch_size, 1),
|
||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
|
|
||||||
rejection_sampler = RejectionSampler()
|
rejection_sampler = RejectionSampler(
|
||||||
|
disable_bonus_tokens=disable_bonus_tokens)
|
||||||
rejection_sampler.init_gpu_tensors(rank=0)
|
rejection_sampler.init_gpu_tensors(rank=0)
|
||||||
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
|
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
|
||||||
accepted,
|
accepted,
|
||||||
@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
|||||||
bonus_token_ids,
|
bonus_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bonus tokens are currently disabled. Verify they're set to -1.
|
expected_bonus_token_ids = bonus_token_ids.clone()
|
||||||
|
# If bonus tokens are disabled, verify they are set to -1.
|
||||||
# See https://github.com/vllm-project/vllm/issues/4212
|
# See https://github.com/vllm-project/vllm/issues/4212
|
||||||
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
|
if disable_bonus_tokens:
|
||||||
|
expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1
|
||||||
|
|
||||||
if which_tokens_accepted == "all_tokens_accepted":
|
if which_tokens_accepted == "all_tokens_accepted":
|
||||||
# Expect all tokens to be equal to draft tokens.
|
# Expect all tokens to be equal to draft tokens.
|
||||||
|
|||||||
@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
|
|||||||
force_output_len=True)
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model": "JackFram/llama-160m",
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
"speculative_disable_by_batch_size": 2,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
|
@pytest.mark.parametrize("output_len", [10])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
|
||||||
|
batch_size: int, output_len: int):
|
||||||
|
"""Verify greedy equality when all sequences disable speculation.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
force_output_len=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
|
|||||||
@ -57,7 +57,7 @@ from .conftest import run_greedy_equality_correctness_test
|
|||||||
@pytest.mark.parametrize("output_len", [
|
@pytest.mark.parametrize("output_len", [
|
||||||
256,
|
256,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
|
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
|
||||||
test_llm_generator, batch_size: int,
|
test_llm_generator, batch_size: int,
|
||||||
|
|||||||
77
tests/spec_decode/test_dynamic_spec_decode.py
Normal file
77
tests/spec_decode/test_dynamic_spec_decode.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||||
|
from vllm.sequence import ExecuteModelRequest
|
||||||
|
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||||
|
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||||
|
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
|
||||||
|
from vllm.spec_decode.top1_proposer import Top1Proposer
|
||||||
|
|
||||||
|
from .utils import create_batch, mock_worker
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('queue_size', [2, 4])
|
||||||
|
@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
|
||||||
|
@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
|
||||||
|
"""Verify that speculative tokens are disabled when the batch size
|
||||||
|
exceeds the threshold.
|
||||||
|
"""
|
||||||
|
disable_by_batch_size = 3
|
||||||
|
|
||||||
|
draft_worker = mock_worker(cls=MultiStepWorker)
|
||||||
|
target_worker = mock_worker()
|
||||||
|
rejection_sampler = MagicMock(spec=RejectionSampler)
|
||||||
|
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
|
||||||
|
worker = SpecDecodeWorker(proposer_worker=draft_worker,
|
||||||
|
scorer_worker=target_worker,
|
||||||
|
rejection_sampler=rejection_sampler,
|
||||||
|
metrics_collector=metrics_collector,
|
||||||
|
disable_by_batch_size=disable_by_batch_size)
|
||||||
|
|
||||||
|
exception_secret = 'artificial stop'
|
||||||
|
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
|
||||||
|
|
||||||
|
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
|
||||||
|
execute_model_req = ExecuteModelRequest(
|
||||||
|
seq_group_metadata_list=seq_group_metadata_list,
|
||||||
|
num_lookahead_slots=k,
|
||||||
|
running_queue_size=queue_size)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=exception_secret):
|
||||||
|
worker.execute_model(execute_model_req=execute_model_req)
|
||||||
|
|
||||||
|
# When the running queue size reaches or exceeds the threshold,
|
||||||
|
# we expect no speculative tokens (0).
|
||||||
|
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
|
||||||
|
assert seq_group_metadata_list[
|
||||||
|
0].num_speculative_tokens == expected_num_spec_tokens
|
||||||
|
|
||||||
|
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
|
||||||
|
|
||||||
|
proposer = Top1Proposer(
|
||||||
|
worker=draft_worker,
|
||||||
|
device='cpu', # not used
|
||||||
|
vocab_size=100, # not used
|
||||||
|
# Must be long enough to avoid being skipped due to length.
|
||||||
|
max_proposal_len=1024,
|
||||||
|
)
|
||||||
|
|
||||||
|
if queue_size < disable_by_batch_size:
|
||||||
|
# Should raise exception when executing the mocked draft model.
|
||||||
|
with pytest.raises(ValueError, match=exception_secret):
|
||||||
|
proposer.get_proposals(execute_model_req=ExecuteModelRequest(
|
||||||
|
seq_group_metadata_list=seq_group_metadata_list,
|
||||||
|
num_lookahead_slots=k), )
|
||||||
|
else:
|
||||||
|
# Should not execute the draft model because spec decode is disabled
|
||||||
|
# for all requests. Accordingly, the proposal length should be 0.
|
||||||
|
proposals = proposer.get_proposals(
|
||||||
|
execute_model_req=ExecuteModelRequest(
|
||||||
|
seq_group_metadata_list=seq_group_metadata_list,
|
||||||
|
num_lookahead_slots=k), )
|
||||||
|
assert proposals.proposal_lens.tolist() == [0] * batch_size
|
||||||
@ -692,6 +692,7 @@ class SpeculativeConfig:
|
|||||||
speculative_max_model_len: Optional[int],
|
speculative_max_model_len: Optional[int],
|
||||||
enable_chunked_prefill: bool,
|
enable_chunked_prefill: bool,
|
||||||
use_v2_block_manager: bool,
|
use_v2_block_manager: bool,
|
||||||
|
speculative_disable_by_batch_size: Optional[int],
|
||||||
ngram_prompt_lookup_max: Optional[int],
|
ngram_prompt_lookup_max: Optional[int],
|
||||||
ngram_prompt_lookup_min: Optional[int],
|
ngram_prompt_lookup_min: Optional[int],
|
||||||
) -> Optional["SpeculativeConfig"]:
|
) -> Optional["SpeculativeConfig"]:
|
||||||
@ -720,6 +721,9 @@ class SpeculativeConfig:
|
|||||||
use_v2_block_manager (bool): Whether vLLM is configured to use the
|
use_v2_block_manager (bool): Whether vLLM is configured to use the
|
||||||
v2 block manager or not. Used for raising an error since the v2
|
v2 block manager or not. Used for raising an error since the v2
|
||||||
block manager is required with spec decode.
|
block manager is required with spec decode.
|
||||||
|
speculative_disable_by_batch_size (Optional[int]): Disable
|
||||||
|
speculative decoding for new incoming requests when the number
|
||||||
|
of enqueued requests is larger than this value, if provided.
|
||||||
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
|
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
|
||||||
window, if provided.
|
window, if provided.
|
||||||
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
|
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
|
||||||
@ -730,7 +734,7 @@ class SpeculativeConfig:
|
|||||||
the necessary conditions are met, else None.
|
the necessary conditions are met, else None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (speculative_model is None and num_speculative_tokens is None):
|
if speculative_model is None and num_speculative_tokens is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if speculative_model is not None and num_speculative_tokens is None:
|
if speculative_model is not None and num_speculative_tokens is None:
|
||||||
@ -739,6 +743,12 @@ class SpeculativeConfig:
|
|||||||
"num_speculative_tokens to be provided, but found "
|
"num_speculative_tokens to be provided, but found "
|
||||||
f"{speculative_model=} and {num_speculative_tokens=}.")
|
f"{speculative_model=} and {num_speculative_tokens=}.")
|
||||||
|
|
||||||
|
if (speculative_disable_by_batch_size is not None
|
||||||
|
and speculative_disable_by_batch_size < 2):
|
||||||
|
raise ValueError("Expect the batch size threshold of disabling "
|
||||||
|
"speculative decoding is > 1, but got "
|
||||||
|
f"{speculative_disable_by_batch_size=}")
|
||||||
|
|
||||||
assert (speculative_model is not None
|
assert (speculative_model is not None
|
||||||
and num_speculative_tokens is not None)
|
and num_speculative_tokens is not None)
|
||||||
|
|
||||||
@ -807,6 +817,7 @@ class SpeculativeConfig:
|
|||||||
draft_model_config,
|
draft_model_config,
|
||||||
draft_parallel_config,
|
draft_parallel_config,
|
||||||
num_speculative_tokens,
|
num_speculative_tokens,
|
||||||
|
speculative_disable_by_batch_size,
|
||||||
ngram_prompt_lookup_max,
|
ngram_prompt_lookup_max,
|
||||||
ngram_prompt_lookup_min,
|
ngram_prompt_lookup_min,
|
||||||
)
|
)
|
||||||
@ -876,8 +887,9 @@ class SpeculativeConfig:
|
|||||||
draft_model_config: ModelConfig,
|
draft_model_config: ModelConfig,
|
||||||
draft_parallel_config: ParallelConfig,
|
draft_parallel_config: ParallelConfig,
|
||||||
num_speculative_tokens: int,
|
num_speculative_tokens: int,
|
||||||
ngram_prompt_lookup_max: int,
|
speculative_disable_by_batch_size: Optional[int],
|
||||||
ngram_prompt_lookup_min: int,
|
ngram_prompt_lookup_max: Optional[int],
|
||||||
|
ngram_prompt_lookup_min: Optional[int],
|
||||||
):
|
):
|
||||||
"""Create a SpeculativeConfig object.
|
"""Create a SpeculativeConfig object.
|
||||||
|
|
||||||
@ -886,12 +898,19 @@ class SpeculativeConfig:
|
|||||||
draft_parallel_config: ParallelConfig for the draft model.
|
draft_parallel_config: ParallelConfig for the draft model.
|
||||||
num_speculative_tokens: The number of tokens to sample from the
|
num_speculative_tokens: The number of tokens to sample from the
|
||||||
draft model before scoring with the target model.
|
draft model before scoring with the target model.
|
||||||
|
speculative_disable_by_batch_size: Disable speculative
|
||||||
|
decoding for new incoming requests when the number of
|
||||||
|
enqueued requests is larger than this value.
|
||||||
|
ngram_prompt_lookup_max: Max size of ngram token window.
|
||||||
|
ngram_prompt_lookup_min: Min size of ngram token window.
|
||||||
"""
|
"""
|
||||||
self.draft_model_config = draft_model_config
|
self.draft_model_config = draft_model_config
|
||||||
self.draft_parallel_config = draft_parallel_config
|
self.draft_parallel_config = draft_parallel_config
|
||||||
self.num_speculative_tokens = num_speculative_tokens
|
self.num_speculative_tokens = num_speculative_tokens
|
||||||
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
|
self.speculative_disable_by_batch_size = \
|
||||||
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
|
speculative_disable_by_batch_size
|
||||||
|
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
|
||||||
|
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
|
||||||
|
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
|
|
||||||
|
|||||||
@ -83,6 +83,7 @@ class EngineArgs:
|
|||||||
speculative_model: Optional[str] = None
|
speculative_model: Optional[str] = None
|
||||||
num_speculative_tokens: Optional[int] = None
|
num_speculative_tokens: Optional[int] = None
|
||||||
speculative_max_model_len: Optional[int] = None
|
speculative_max_model_len: Optional[int] = None
|
||||||
|
speculative_disable_by_batch_size: Optional[int] = None
|
||||||
ngram_prompt_lookup_max: Optional[int] = None
|
ngram_prompt_lookup_max: Optional[int] = None
|
||||||
ngram_prompt_lookup_min: Optional[int] = None
|
ngram_prompt_lookup_min: Optional[int] = None
|
||||||
|
|
||||||
@ -467,6 +468,13 @@ class EngineArgs:
|
|||||||
'draft model. Sequences over this length will skip '
|
'draft model. Sequences over this length will skip '
|
||||||
'speculation.')
|
'speculation.')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--speculative-disable-by-batch-size',
|
||||||
|
type=int,
|
||||||
|
default=EngineArgs.speculative_disable_by_batch_size,
|
||||||
|
help='Disable speculative decoding for new incoming requests '
|
||||||
|
'if the number of enqueue requests is larger than this value.')
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--ngram-prompt-lookup-max',
|
'--ngram-prompt-lookup-max',
|
||||||
type=int,
|
type=int,
|
||||||
@ -547,6 +555,8 @@ class EngineArgs:
|
|||||||
target_dtype=self.dtype,
|
target_dtype=self.dtype,
|
||||||
speculative_model=self.speculative_model,
|
speculative_model=self.speculative_model,
|
||||||
num_speculative_tokens=self.num_speculative_tokens,
|
num_speculative_tokens=self.num_speculative_tokens,
|
||||||
|
speculative_disable_by_batch_size=self.
|
||||||
|
speculative_disable_by_batch_size,
|
||||||
speculative_max_model_len=self.speculative_max_model_len,
|
speculative_max_model_len=self.speculative_max_model_len,
|
||||||
enable_chunked_prefill=self.enable_chunked_prefill,
|
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||||
use_v2_block_manager=self.use_v2_block_manager,
|
use_v2_block_manager=self.use_v2_block_manager,
|
||||||
|
|||||||
@ -93,6 +93,8 @@ class GPUExecutor(ExecutorBase):
|
|||||||
spec_decode_worker = SpecDecodeWorker.create_worker(
|
spec_decode_worker = SpecDecodeWorker.create_worker(
|
||||||
scorer_worker=target_worker,
|
scorer_worker=target_worker,
|
||||||
draft_worker_kwargs=draft_worker_kwargs,
|
draft_worker_kwargs=draft_worker_kwargs,
|
||||||
|
disable_by_batch_size=self.speculative_config.
|
||||||
|
speculative_disable_by_batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert self.parallel_config.world_size == 1, (
|
assert self.parallel_config.world_size == 1, (
|
||||||
|
|||||||
@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
|
|||||||
https://arxiv.org/pdf/2302.01318.pdf.
|
https://arxiv.org/pdf/2302.01318.pdf.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, strict_mode: bool = False):
|
def __init__(self,
|
||||||
|
disable_bonus_tokens: bool = True,
|
||||||
|
strict_mode: bool = False):
|
||||||
"""Create a rejection sampler.
|
"""Create a rejection sampler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
disable_bonus_tokens: Whether or not to disable the bonus token.
|
||||||
|
Required when bonus tokens would corrupt the KV cache for
|
||||||
|
proposal methods that require KV cache.
|
||||||
strict_mode: Whether or not to perform shape/device/dtype checks
|
strict_mode: Whether or not to perform shape/device/dtype checks
|
||||||
during sampling. This catches correctness issues but adds
|
during sampling. This catches correctness issues but adds
|
||||||
nontrivial latency.
|
nontrivial latency.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self._disable_bonus_tokens = disable_bonus_tokens
|
||||||
self._strict_mode = strict_mode
|
self._strict_mode = strict_mode
|
||||||
|
|
||||||
# NOTE: A "bonus token" is accepted iff all proposal tokens are
|
# NOTE: A "bonus token" is accepted iff all proposal tokens are
|
||||||
@ -312,7 +318,8 @@ class RejectionSampler(nn.Module):
|
|||||||
# proposal methods that require KV cache. We can fix it by "prefilling"
|
# proposal methods that require KV cache. We can fix it by "prefilling"
|
||||||
# the bonus token in the proposer. The following issue tracks the fix.
|
# the bonus token in the proposer. The following issue tracks the fix.
|
||||||
# https://github.com/vllm-project/vllm/issues/4212
|
# https://github.com/vllm-project/vllm/issues/4212
|
||||||
output_with_bonus_tokens[:, -1] = -1
|
if self._disable_bonus_tokens:
|
||||||
|
output_with_bonus_tokens[:, -1] = -1
|
||||||
|
|
||||||
# Fill the recovered token ids.
|
# Fill the recovered token ids.
|
||||||
output.mul_(~after_false_mask).add_(
|
output.mul_(~after_false_mask).add_(
|
||||||
|
|||||||
@ -612,6 +612,12 @@ class SequenceGroupMetadata:
|
|||||||
self._token_chunk_size = token_chunk_size
|
self._token_chunk_size = token_chunk_size
|
||||||
self.do_sample = do_sample
|
self.do_sample = do_sample
|
||||||
|
|
||||||
|
# The number of speculative tokens adopted in this request.
|
||||||
|
# None means speculative decoding is not used.
|
||||||
|
# Zero means speculative decoding is disabled for some reason.
|
||||||
|
# TODO: We should maintain this state outside of the sequence group.
|
||||||
|
self.num_speculative_tokens = None
|
||||||
|
|
||||||
if self._token_chunk_size is None:
|
if self._token_chunk_size is None:
|
||||||
if is_prompt:
|
if is_prompt:
|
||||||
self._token_chunk_size = list(seq_data.values())[0].get_len()
|
self._token_chunk_size = list(seq_data.values())[0].get_len()
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
def create_worker(
|
def create_worker(
|
||||||
cls,
|
cls,
|
||||||
scorer_worker: WorkerBase,
|
scorer_worker: WorkerBase,
|
||||||
draft_worker_kwargs,
|
draft_worker_kwargs: Dict[str, Any],
|
||||||
|
disable_by_batch_size: Optional[int],
|
||||||
) -> "SpecDecodeWorker":
|
) -> "SpecDecodeWorker":
|
||||||
|
|
||||||
ngram_prompt_lookup_max = (
|
ngram_prompt_lookup_max = (
|
||||||
@ -62,7 +63,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
ngram_prompt_lookup_min = (
|
ngram_prompt_lookup_min = (
|
||||||
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
|
||||||
|
|
||||||
|
disable_bonus_tokens = True
|
||||||
if ngram_prompt_lookup_max > 0:
|
if ngram_prompt_lookup_max > 0:
|
||||||
|
disable_bonus_tokens = False
|
||||||
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
proposer_worker = NGramWorker(**draft_worker_kwargs)
|
||||||
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
|
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
|
||||||
ngram_prompt_lookup_max)
|
ngram_prompt_lookup_max)
|
||||||
@ -75,9 +78,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
return SpecDecodeWorker(
|
return SpecDecodeWorker(
|
||||||
proposer_worker,
|
proposer_worker,
|
||||||
scorer_worker,
|
scorer_worker,
|
||||||
# TODO(cade) disable strict mode for speedup.
|
disable_by_batch_size=disable_by_batch_size,
|
||||||
rejection_sampler=RejectionSampler(strict_mode=True),
|
rejection_sampler=RejectionSampler(
|
||||||
)
|
disable_bonus_tokens=disable_bonus_tokens, ))
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -85,6 +88,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
scorer_worker: WorkerBase,
|
scorer_worker: WorkerBase,
|
||||||
rejection_sampler: RejectionSampler,
|
rejection_sampler: RejectionSampler,
|
||||||
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||||
|
disable_by_batch_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a SpecDecodeWorker.
|
Create a SpecDecodeWorker.
|
||||||
@ -97,11 +101,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
Worker.
|
Worker.
|
||||||
rejection_sampler: A Torch module used to perform modified rejection
|
rejection_sampler: A Torch module used to perform modified rejection
|
||||||
sampling for speculative decoding.
|
sampling for speculative decoding.
|
||||||
|
disable_by_batch_size: If the running queue size reaches this value,
|
||||||
|
disable speculative decoding for new incoming requests.
|
||||||
metrics_collector: Helper class for collecting metrics; can be set
|
metrics_collector: Helper class for collecting metrics; can be set
|
||||||
for testing purposes.
|
for testing purposes.
|
||||||
"""
|
"""
|
||||||
self.proposer_worker = proposer_worker
|
self.proposer_worker = proposer_worker
|
||||||
self.scorer_worker = scorer_worker
|
self.scorer_worker = scorer_worker
|
||||||
|
self.disable_by_batch_size = disable_by_batch_size or float("inf")
|
||||||
self.rejection_sampler = rejection_sampler
|
self.rejection_sampler = rejection_sampler
|
||||||
|
|
||||||
self._metrics = AsyncMetricsCollector(
|
self._metrics = AsyncMetricsCollector(
|
||||||
@ -199,27 +206,41 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
"speculative decoding "
|
"speculative decoding "
|
||||||
"requires non-None seq_group_metadata_list")
|
"requires non-None seq_group_metadata_list")
|
||||||
|
|
||||||
|
# When the batch size is too large, disable speculative decoding
|
||||||
|
# to stop trading off throughput for latency.
|
||||||
|
disable_all = (execute_model_req.running_queue_size >=
|
||||||
|
self.disable_by_batch_size)
|
||||||
|
if disable_all:
|
||||||
|
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
|
||||||
|
# Once num_speculative_tokens is set to 0, the spec decode
|
||||||
|
# of this request will be disabled forever.
|
||||||
|
# TODO(comaniac): We currently store spec decoding specific
|
||||||
|
# state in the global data structure, but we should maintain
|
||||||
|
# this state within spec decode worker.
|
||||||
|
seq_group_metadata.num_speculative_tokens = 0
|
||||||
|
|
||||||
# If no spec tokens, call the proposer and scorer workers normally.
|
# If no spec tokens, call the proposer and scorer workers normally.
|
||||||
# Used for prefill.
|
# This happens for prefill, or when the spec decode is disabled
|
||||||
|
# for this batch.
|
||||||
if execute_model_req.num_lookahead_slots == 0 or len(
|
if execute_model_req.num_lookahead_slots == 0 or len(
|
||||||
execute_model_req.seq_group_metadata_list) == 0:
|
execute_model_req.seq_group_metadata_list) == 0:
|
||||||
return self._run_no_spec(execute_model_req)
|
return self._run_no_spec(execute_model_req,
|
||||||
|
skip_proposer=disable_all)
|
||||||
|
|
||||||
return self._run_speculative_decoding_step(execute_model_req)
|
return self._run_speculative_decoding_step(execute_model_req)
|
||||||
|
|
||||||
@nvtx_range("spec_decode_worker._run_no_spec")
|
@nvtx_range("spec_decode_worker._run_no_spec")
|
||||||
def _run_no_spec(
|
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
|
||||||
self,
|
skip_proposer: bool) -> List[SamplerOutput]:
|
||||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
"""Run a prefill step, without any speculation. The input is sent to
|
||||||
"""Run a prefill step, without any speculation. The input is sent to the
|
the proposer and scorer model so that the KV cache is consistent
|
||||||
proposer and scorer model so that the KV cache is consistent between the
|
between the two. When skip_proposer is True, the proposer model is
|
||||||
two.
|
not called, meaning that the kv-cache in proposer for requests is not
|
||||||
|
updated, so they cannot enable spec decode for the rest of decoding.
|
||||||
"""
|
"""
|
||||||
#logger.info("run proposer worker no spec")
|
if not skip_proposer:
|
||||||
|
self.proposer_worker.execute_model(execute_model_req)
|
||||||
|
|
||||||
self.proposer_worker.execute_model(execute_model_req)
|
|
||||||
|
|
||||||
#logger.info("run target worker no spec")
|
|
||||||
sampler_output = self.scorer_worker.execute_model(execute_model_req)
|
sampler_output = self.scorer_worker.execute_model(execute_model_req)
|
||||||
assert len(sampler_output) == 1
|
assert len(sampler_output) == 1
|
||||||
sampler_output = sampler_output[0]
|
sampler_output = sampler_output[0]
|
||||||
@ -244,22 +265,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
sequence.
|
sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
#logger.info("get spec proposals")
|
|
||||||
# Generate proposals using draft worker.
|
# Generate proposals using draft worker.
|
||||||
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
|
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
|
||||||
|
|
||||||
#logger.info("score proposals")
|
|
||||||
proposal_scores = self.scorer.score_proposals(
|
proposal_scores = self.scorer.score_proposals(
|
||||||
execute_model_req,
|
execute_model_req,
|
||||||
proposals,
|
proposals,
|
||||||
)
|
)
|
||||||
|
|
||||||
#logger.info("verify proposals")
|
|
||||||
accepted_token_ids, target_logprobs = self._verify_tokens(
|
accepted_token_ids, target_logprobs = self._verify_tokens(
|
||||||
execute_model_req.seq_group_metadata_list, proposal_scores,
|
execute_model_req.seq_group_metadata_list, proposal_scores,
|
||||||
proposals, execute_model_req.num_lookahead_slots)
|
proposals, execute_model_req.num_lookahead_slots)
|
||||||
|
|
||||||
#logger.info("create output list")
|
|
||||||
return self._create_output_sampler_list(
|
return self._create_output_sampler_list(
|
||||||
execute_model_req.seq_group_metadata_list,
|
execute_model_req.seq_group_metadata_list,
|
||||||
accepted_token_ids,
|
accepted_token_ids,
|
||||||
|
|||||||
@ -56,7 +56,7 @@ class Top1Proposer(SpeculativeProposer):
|
|||||||
proposal_lens,
|
proposal_lens,
|
||||||
nonzero_proposal_len_seqs,
|
nonzero_proposal_len_seqs,
|
||||||
nonzero_proposal_len_indices,
|
nonzero_proposal_len_indices,
|
||||||
) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len)
|
) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
|
||||||
|
|
||||||
if nonzero_proposal_len_seqs:
|
if nonzero_proposal_len_seqs:
|
||||||
# Speculate tokens using the draft worker for the speculative
|
# Speculate tokens using the draft worker for the speculative
|
||||||
@ -97,17 +97,27 @@ class Top1Proposer(SpeculativeProposer):
|
|||||||
|
|
||||||
return proposals
|
return proposals
|
||||||
|
|
||||||
def _split_by_max_model_len(
|
def _split_by_proposal_len(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
proposal_len: int,
|
proposal_len: int,
|
||||||
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
|
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
|
||||||
"""Determine which sequences would exceed the max model length."""
|
"""Split sequences into two groups:
|
||||||
|
1. Sequences with non-zero proposal length.
|
||||||
|
2. Sequences with zero proposal length (due to disabled speculation
|
||||||
|
or exceeding the maximum model length).
|
||||||
|
"""
|
||||||
|
|
||||||
proposal_lens: List[int] = []
|
proposal_lens: List[int] = []
|
||||||
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
|
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
|
||||||
nonzero_proposal_len_indices: List[int] = []
|
nonzero_proposal_len_indices: List[int] = []
|
||||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||||
|
# The speculative decoding for this request has been disabled
|
||||||
|
# (e.g. due to high traffic).
|
||||||
|
if seq_group_metadata.num_speculative_tokens == 0:
|
||||||
|
proposal_lens.append(0)
|
||||||
|
continue
|
||||||
|
|
||||||
seq_data = next(iter(seq_group_metadata.seq_data.values()))
|
seq_data = next(iter(seq_group_metadata.seq_data.values()))
|
||||||
seq_len = seq_data.get_len()
|
seq_len = seq_data.get_len()
|
||||||
|
|
||||||
@ -115,13 +125,14 @@ class Top1Proposer(SpeculativeProposer):
|
|||||||
# are supported.
|
# are supported.
|
||||||
# If max_proposal_len is defined, then we shall not exceed this
|
# If max_proposal_len is defined, then we shall not exceed this
|
||||||
# quota for nonzero_proposal
|
# quota for nonzero_proposal
|
||||||
|
new_k = 0
|
||||||
if (self.max_proposal_len is None
|
if (self.max_proposal_len is None
|
||||||
or seq_len + proposal_len < self.max_proposal_len):
|
or seq_len + proposal_len < self.max_proposal_len):
|
||||||
proposal_lens.append(proposal_len)
|
new_k = proposal_len
|
||||||
nonzero_proposal_len_seqs.append(seq_group_metadata)
|
nonzero_proposal_len_seqs.append(seq_group_metadata)
|
||||||
nonzero_proposal_len_indices.append(i)
|
nonzero_proposal_len_indices.append(i)
|
||||||
else:
|
proposal_lens.append(new_k)
|
||||||
proposal_lens.append(0)
|
seq_group_metadata.num_speculative_tokens = new_k
|
||||||
|
|
||||||
return (
|
return (
|
||||||
proposal_lens,
|
proposal_lens,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user