[Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)

Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
Cody Yu 2024-05-08 14:44:00 -07:00 committed by GitHub
parent 89579a201f
commit f942efb5a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 227 additions and 39 deletions

View File

@ -42,9 +42,11 @@ def mock_causal_accepted_tensor(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"which_tokens_accepted", "which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() @torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int, def test_correct_output_format(which_tokens_accepted: str,
disable_bonus_tokens: bool, seed: int,
device: str): device: str):
"""Verify the output has correct format given predetermined accepted matrix. """Verify the output has correct format given predetermined accepted matrix.
""" """
@ -82,7 +84,8 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
size=(batch_size, 1), size=(batch_size, 1),
dtype=torch.int64) dtype=torch.int64)
rejection_sampler = RejectionSampler() rejection_sampler = RejectionSampler(
disable_bonus_tokens=disable_bonus_tokens)
rejection_sampler.init_gpu_tensors(rank=0) rejection_sampler.init_gpu_tensors(rank=0)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted, accepted,
@ -91,9 +94,11 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
bonus_token_ids, bonus_token_ids,
) )
# Bonus tokens are currently disabled. Verify they're set to -1. expected_bonus_token_ids = bonus_token_ids.clone()
# If bonus tokens disabled. Verify they are set to -1.
# See https://github.com/vllm-project/vllm/issues/4212 # See https://github.com/vllm-project/vllm/issues/4212
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1 if disable_bonus_tokens:
expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1
if which_tokens_accepted == "all_tokens_accepted": if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens. # Expect all tokens to be equal to draft tokens.

View File

@ -536,6 +536,40 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
force_output_len=True) force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_disable_by_batch_size": 2,
},
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when all sequences disable speculation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{

View File

@ -57,7 +57,7 @@ from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
256, 256,
]) ])
@pytest.mark.parametrize("batch_size", [1, 64]) @pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator, def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int, test_llm_generator, batch_size: int,

View File

@ -0,0 +1,77 @@
from unittest.mock import MagicMock
import pytest
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from .utils import create_batch, mock_worker
@pytest.mark.parametrize('queue_size', [2, 4])
@pytest.mark.parametrize('batch_size', [1, 2, 3, 6])
@pytest.mark.parametrize('k', [1, 2, 5, 7, 10])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
"""Verify that speculative tokens are disabled when the batch size
exceeds the threshold.
"""
disable_by_batch_size = 3
draft_worker = mock_worker(cls=MultiStepWorker)
target_worker = mock_worker()
rejection_sampler = MagicMock(spec=RejectionSampler)
metrics_collector = MagicMock(spec=AsyncMetricsCollector)
worker = SpecDecodeWorker(proposer_worker=draft_worker,
scorer_worker=target_worker,
rejection_sampler=rejection_sampler,
metrics_collector=metrics_collector,
disable_by_batch_size=disable_by_batch_size)
exception_secret = 'artificial stop'
draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
running_queue_size=queue_size)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
# When the batch size is larger than the threshold,
# we expect no speculative tokens (0).
expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
assert seq_group_metadata_list[
0].num_speculative_tokens == expected_num_spec_tokens
draft_worker.sampler_output.side_effect = ValueError(exception_secret)
proposer = Top1Proposer(
worker=draft_worker,
device='cpu', # not used
vocab_size=100, # not used
# Must be long enough to avoid being skipped due to length.
max_proposal_len=1024,
)
if queue_size < disable_by_batch_size:
# Should raise exception when executing the mocked draft model.
with pytest.raises(ValueError, match=exception_secret):
proposer.get_proposals(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
else:
# Should not execute the draft model because spec decode is disabled
# for all requests. Accordingly, the proposal length should be 0.
proposals = proposer.get_proposals(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k), )
assert proposals.proposal_lens.tolist() == [0] * batch_size

View File

@ -692,6 +692,7 @@ class SpeculativeConfig:
speculative_max_model_len: Optional[int], speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool, enable_chunked_prefill: bool,
use_v2_block_manager: bool, use_v2_block_manager: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int], ngram_prompt_lookup_min: Optional[int],
) -> Optional["SpeculativeConfig"]: ) -> Optional["SpeculativeConfig"]:
@ -720,6 +721,9 @@ class SpeculativeConfig:
use_v2_block_manager (bool): Whether vLLM is configured to use the use_v2_block_manager (bool): Whether vLLM is configured to use the
v2 block manager or not. Used for raising an error since the v2 v2 block manager or not. Used for raising an error since the v2
block manager is required with spec decode. block manager is required with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueue requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided. window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
@ -730,7 +734,7 @@ class SpeculativeConfig:
the necessary conditions are met, else None. the necessary conditions are met, else None.
""" """
if (speculative_model is None and num_speculative_tokens is None): if speculative_model is None and num_speculative_tokens is None:
return None return None
if speculative_model is not None and num_speculative_tokens is None: if speculative_model is not None and num_speculative_tokens is None:
@ -739,6 +743,12 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found " "num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.") f"{speculative_model=} and {num_speculative_tokens=}.")
if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")
assert (speculative_model is not None assert (speculative_model is not None
and num_speculative_tokens is not None) and num_speculative_tokens is not None)
@ -807,6 +817,7 @@ class SpeculativeConfig:
draft_model_config, draft_model_config,
draft_parallel_config, draft_parallel_config,
num_speculative_tokens, num_speculative_tokens,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max, ngram_prompt_lookup_max,
ngram_prompt_lookup_min, ngram_prompt_lookup_min,
) )
@ -876,8 +887,9 @@ class SpeculativeConfig:
draft_model_config: ModelConfig, draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig, draft_parallel_config: ParallelConfig,
num_speculative_tokens: int, num_speculative_tokens: int,
ngram_prompt_lookup_max: int, speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_min: int, ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
): ):
"""Create a SpeculativeConfig object. """Create a SpeculativeConfig object.
@ -886,12 +898,19 @@ class SpeculativeConfig:
draft_parallel_config: ParallelConfig for the draft model. draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model. draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
""" """
self.draft_model_config = draft_model_config self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens self.num_speculative_tokens = num_speculative_tokens
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max self.speculative_disable_by_batch_size = \
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
self._verify_args() self._verify_args()

View File

@ -83,6 +83,7 @@ class EngineArgs:
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None
@ -467,6 +468,13 @@ class EngineArgs:
'draft model. Sequences over this length will skip ' 'draft model. Sequences over this length will skip '
'speculation.') 'speculation.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument( parser.add_argument(
'--ngram-prompt-lookup-max', '--ngram-prompt-lookup-max',
type=int, type=int,
@ -547,6 +555,8 @@ class EngineArgs:
target_dtype=self.dtype, target_dtype=self.dtype,
speculative_model=self.speculative_model, speculative_model=self.speculative_model,
num_speculative_tokens=self.num_speculative_tokens, num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len, speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
use_v2_block_manager=self.use_v2_block_manager, use_v2_block_manager=self.use_v2_block_manager,

View File

@ -93,6 +93,8 @@ class GPUExecutor(ExecutorBase):
spec_decode_worker = SpecDecodeWorker.create_worker( spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker, scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs, draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=self.speculative_config.
speculative_disable_by_batch_size,
) )
assert self.parallel_config.world_size == 1, ( assert self.parallel_config.world_size == 1, (

View File

@ -12,15 +12,21 @@ class RejectionSampler(nn.Module):
https://arxiv.org/pdf/2302.01318.pdf. https://arxiv.org/pdf/2302.01318.pdf.
""" """
def __init__(self, strict_mode: bool = False): def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
"""Create a rejection sampler. """Create a rejection sampler.
Args: Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds during sampling. This catches correctness issues but adds
nontrivial latency. nontrivial latency.
""" """
super().__init__() super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are # NOTE: A "bonus token" is accepted iff all proposal tokens are
@ -312,7 +318,8 @@ class RejectionSampler(nn.Module):
# proposal methods that require KV cache. We can fix it by "prefilling" # proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix. # the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212 # https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1 if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids. # Fill the recovered token ids.
output.mul_(~after_false_mask).add_( output.mul_(~after_false_mask).add_(

View File

@ -612,6 +612,12 @@ class SequenceGroupMetadata:
self._token_chunk_size = token_chunk_size self._token_chunk_size = token_chunk_size
self.do_sample = do_sample self.do_sample = do_sample
# The number of speculative tokens adopted in this request.
# None means specuative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
# TODO: We should maintain this states out of the sequence group.
self.num_speculative_tokens = None
if self._token_chunk_size is None: if self._token_chunk_size is None:
if is_prompt: if is_prompt:
self._token_chunk_size = list(seq_data.values())[0].get_len() self._token_chunk_size = list(seq_data.values())[0].get_len()

View File

@ -1,5 +1,5 @@
from functools import cached_property from functools import cached_property
from typing import List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
@ -54,7 +54,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def create_worker( def create_worker(
cls, cls,
scorer_worker: WorkerBase, scorer_worker: WorkerBase,
draft_worker_kwargs, draft_worker_kwargs: Dict[str, Any],
disable_by_batch_size: Optional[int],
) -> "SpecDecodeWorker": ) -> "SpecDecodeWorker":
ngram_prompt_lookup_max = ( ngram_prompt_lookup_max = (
@ -62,7 +63,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
ngram_prompt_lookup_min = ( ngram_prompt_lookup_min = (
draft_worker_kwargs.pop("ngram_prompt_lookup_min")) draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
disable_bonus_tokens = True
if ngram_prompt_lookup_max > 0: if ngram_prompt_lookup_max > 0:
disable_bonus_tokens = False
proposer_worker = NGramWorker(**draft_worker_kwargs) proposer_worker = NGramWorker(**draft_worker_kwargs)
proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min,
ngram_prompt_lookup_max) ngram_prompt_lookup_max)
@ -75,9 +78,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
return SpecDecodeWorker( return SpecDecodeWorker(
proposer_worker, proposer_worker,
scorer_worker, scorer_worker,
# TODO(cade) disable strict mode for speedup. disable_by_batch_size=disable_by_batch_size,
rejection_sampler=RejectionSampler(strict_mode=True), rejection_sampler=RejectionSampler(
) disable_bonus_tokens=disable_bonus_tokens, ))
def __init__( def __init__(
self, self,
@ -85,6 +88,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
scorer_worker: WorkerBase, scorer_worker: WorkerBase,
rejection_sampler: RejectionSampler, rejection_sampler: RejectionSampler,
metrics_collector: Optional[AsyncMetricsCollector] = None, metrics_collector: Optional[AsyncMetricsCollector] = None,
disable_by_batch_size: Optional[int] = None,
): ):
""" """
Create a SpecDecodeWorker. Create a SpecDecodeWorker.
@ -97,11 +101,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Worker. Worker.
rejection_sampler: A Torch module used to perform modified rejection rejection_sampler: A Torch module used to perform modified rejection
sampling for speculative decoding. sampling for speculative decoding.
disable_by_batch_size: If the batch size is larger than this,
disable speculative decoding for new incoming requests.
metrics_collector: Helper class for collecting metrics; can be set metrics_collector: Helper class for collecting metrics; can be set
for testing purposes. for testing purposes.
""" """
self.proposer_worker = proposer_worker self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker self.scorer_worker = scorer_worker
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.rejection_sampler = rejection_sampler self.rejection_sampler = rejection_sampler
self._metrics = AsyncMetricsCollector( self._metrics = AsyncMetricsCollector(
@ -199,27 +206,41 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"speculative decoding " "speculative decoding "
"requires non-None seq_group_metadata_list") "requires non-None seq_group_metadata_list")
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
disable_all = (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
if disable_all:
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
# If no spec tokens, call the proposer and scorer workers normally. # If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill. # This happens for prefill, or when the spec decode is disabled
# for this batch.
if execute_model_req.num_lookahead_slots == 0 or len( if execute_model_req.num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0: execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req) return self._run_no_spec(execute_model_req,
skip_proposer=disable_all)
return self._run_speculative_decoding_step(execute_model_req) return self._run_speculative_decoding_step(execute_model_req)
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec( def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
self, skip_proposer: bool) -> List[SamplerOutput]:
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Run a prefill step, without any speculation. The input is sent to
"""Run a prefill step, without any speculation. The input is sent to the the proposer and scorer model so that the KV cache is consistent
proposer and scorer model so that the KV cache is consistent between the between the two. When skip_proposer is True, the proposer model is
two. not called, meaning that the kv-cache in proposer for requests is not
updated, so they cannot enable spec decode in the rest decoding.
""" """
#logger.info("run proposer worker no spec") if not skip_proposer:
self.proposer_worker.execute_model(execute_model_req)
self.proposer_worker.execute_model(execute_model_req)
#logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model(execute_model_req) sampler_output = self.scorer_worker.execute_model(execute_model_req)
assert len(sampler_output) == 1 assert len(sampler_output) == 1
sampler_output = sampler_output[0] sampler_output = sampler_output[0]
@ -244,22 +265,18 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sequence. sequence.
""" """
#logger.info("get spec proposals")
# Generate proposals using draft worker. # Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(execute_model_req) proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
#logger.info("score proposals")
proposal_scores = self.scorer.score_proposals( proposal_scores = self.scorer.score_proposals(
execute_model_req, execute_model_req,
proposals, proposals,
) )
#logger.info("verify proposals")
accepted_token_ids, target_logprobs = self._verify_tokens( accepted_token_ids, target_logprobs = self._verify_tokens(
execute_model_req.seq_group_metadata_list, proposal_scores, execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots) proposals, execute_model_req.num_lookahead_slots)
#logger.info("create output list")
return self._create_output_sampler_list( return self._create_output_sampler_list(
execute_model_req.seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
accepted_token_ids, accepted_token_ids,

View File

@ -56,7 +56,7 @@ class Top1Proposer(SpeculativeProposer):
proposal_lens, proposal_lens,
nonzero_proposal_len_seqs, nonzero_proposal_len_seqs,
nonzero_proposal_len_indices, nonzero_proposal_len_indices,
) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len) ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len)
if nonzero_proposal_len_seqs: if nonzero_proposal_len_seqs:
# Speculate tokens using the draft worker for the speculative # Speculate tokens using the draft worker for the speculative
@ -97,17 +97,27 @@ class Top1Proposer(SpeculativeProposer):
return proposals return proposals
def _split_by_max_model_len( def _split_by_proposal_len(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_len: int, proposal_len: int,
) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
"""Determine which sequences would exceed the max model length.""" """Split sequences by two groups:
1. Sequences with non-zero proposal length.
2. Sequences with zero proposal length (due to disabled speculation
or exceed the maximum model length).
"""
proposal_lens: List[int] = [] proposal_lens: List[int] = []
nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
nonzero_proposal_len_indices: List[int] = [] nonzero_proposal_len_indices: List[int] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list): for i, seq_group_metadata in enumerate(seq_group_metadata_list):
# The speculative decoding for this request has been disabled
# (e.g. due to high traffic).
if seq_group_metadata.num_speculative_tokens == 0:
proposal_lens.append(0)
continue
seq_data = next(iter(seq_group_metadata.seq_data.values())) seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len() seq_len = seq_data.get_len()
@ -115,13 +125,14 @@ class Top1Proposer(SpeculativeProposer):
# are supported. # are supported.
# If max_proposal_len is defined, then we shall no exccess this # If max_proposal_len is defined, then we shall no exccess this
# quota for nonzero_proposal # quota for nonzero_proposal
new_k = 0
if (self.max_proposal_len is None if (self.max_proposal_len is None
or seq_len + proposal_len < self.max_proposal_len): or seq_len + proposal_len < self.max_proposal_len):
proposal_lens.append(proposal_len) new_k = proposal_len
nonzero_proposal_len_seqs.append(seq_group_metadata) nonzero_proposal_len_seqs.append(seq_group_metadata)
nonzero_proposal_len_indices.append(i) nonzero_proposal_len_indices.append(i)
else: proposal_lens.append(new_k)
proposal_lens.append(0) seq_group_metadata.num_speculative_tokens = new_k
return ( return (
proposal_lens, proposal_lens,